mlx-bench
#!/usr/bin/env python3
"""
MLX benchmark script that replicates llama-bench behavior exactly.
Uses random tokens for both prompt and generation, no sampling.
"""
import mlx.core as mx
import mlx_lm
from mlx_lm.models.cache import make_prompt_cache
import time
import numpy as np
import argparse
import sys

def test_prompt(model, tokenizer, n_prompt, n_batch=2048):
    """
    Test prompt processing with random tokens, exactly like llama-bench.
    """
    # Get vocab size
    try:
        vocab_size = tokenizer.vocab_size
    except AttributeError:
        try:
            vocab_size = tokenizer._tokenizer.vocab_size
        except AttributeError:
            vocab_size = 32000  # Common default

    # Generate random token IDs (not text!)
    token_ids = np.random.randint(0, vocab_size, n_prompt).tolist()

    # Process tokens through model (batched like llama-bench)
    n_processed = 0
    cache = make_prompt_cache(model)

    start = time.perf_counter()
    while n_processed < n_prompt:
        batch_size = min(n_prompt - n_processed, n_batch)
        batch_tokens = token_ids[n_processed:n_processed + batch_size]

        # Convert to MLX array
        tokens = mx.array([batch_tokens])

        # Forward pass only - no sampling
        logits = model(tokens, cache=cache)
        mx.eval(logits)  # Force evaluation

        n_processed += batch_size
    end = time.perf_counter()

    return end - start

def test_gen(model, tokenizer, n_gen):
    """
    Test generation with random tokens, exactly like llama-bench.
    No sampling - just forward passes with random tokens.
    """
    # Get vocab size
    try:
        vocab_size = tokenizer.vocab_size
    except AttributeError:
        try:
            vocab_size = tokenizer._tokenizer.vocab_size
        except AttributeError:
            vocab_size = 32000  # Common default

    # Start with a random token
    token_id = np.random.randint(0, vocab_size)
    cache = make_prompt_cache(model)

    start = time.perf_counter()
    for i in range(n_gen):
        # Convert to MLX array
        tokens = mx.array([[token_id]])

        # Forward pass only - no sampling
        logits = model(tokens, cache=cache)
        mx.eval(logits)  # Force evaluation

        # Next token is random (not sampled from logits!)
        token_id = np.random.randint(0, vocab_size)
    end = time.perf_counter()

    return end - start

def run_benchmark(model_path, n_prompt=512, n_gen=128, n_batch=2048, n_reps=5, warmup=True):
    """
    Run benchmark exactly like llama-bench.

    Args:
        model_path: Path to the MLX model
        n_prompt: Number of prompt tokens (default: 512)
        n_gen: Number of generation tokens (default: 128)
        n_batch: Batch size for prompt processing (default: 2048)
        n_reps: Number of repetitions (default: 5)
        warmup: Whether to do warmup runs (default: True)
    """
    print("MLX-bench (llama-bench compatible mode)")
    print("=======================================")
    print(f"Loading model from: {model_path}")

    model, tokenizer = mlx_lm.load(model_path)

    print("Test configuration:")
    print(f"  Prompt tokens: {n_prompt} (random)")
    print(f"  Generation tokens: {n_gen} (random)")
    print(f"  Batch size: {n_batch}")
    print(f"  Repetitions: {n_reps}")
    print(f"  Warmup: {warmup}")
    print()

    # Warmup runs (like llama-bench does)
    if warmup:
        print("Running warmup...")
        if n_prompt > 0:
            print("  Warmup prompt processing...")
            _ = test_prompt(model, tokenizer, min(32, n_prompt), n_batch=n_batch)
        if n_gen > 0:
            print("  Warmup generation...")
            _ = test_gen(model, tokenizer, 1)
        print("Warmup complete.\n")

    # Benchmark runs
    print("Running benchmark...")
    prompt_times = []
    gen_times = []
    total_times = []

    for i in range(n_reps):
        # Clear any caches
        mx.eval(model.parameters())

        prompt_time = 0
        gen_time = 0

        # Test prompt processing
        if n_prompt > 0:
            prompt_time = test_prompt(model, tokenizer, n_prompt, n_batch=n_batch)

        # Test generation
        if n_gen > 0:
            gen_time = test_gen(model, tokenizer, n_gen)

        total_time = prompt_time + gen_time
        prompt_times.append(prompt_time)
        gen_times.append(gen_time)
        total_times.append(total_time)

        # Calculate tokens per second
        total_tokens = n_prompt + n_gen
        tps = total_tokens / total_time if total_time > 0 else 0

        print(f"Run {i+1}/{n_reps}:")
        if n_prompt > 0:
            pp_tps = n_prompt / prompt_time if prompt_time > 0 else 0
            print(f"  Prompt: {prompt_time:.3f}s ({pp_tps:.2f} tokens/sec)")
        if n_gen > 0:
            tg_tps = n_gen / gen_time if gen_time > 0 else 0
            print(f"  Generation: {gen_time:.3f}s ({tg_tps:.2f} tokens/sec)")
        print(f"  Total: {total_time:.3f}s ({tps:.2f} tokens/sec)")

    # Calculate statistics
    total_tokens = n_prompt + n_gen

    # Overall statistics
    avg_total = np.mean(total_times)
    std_total = np.std(total_times)
    avg_tps = total_tokens / avg_total if avg_total > 0 else 0
    tps_values = [total_tokens / t for t in total_times if t > 0]
    std_tps = np.std(tps_values) if tps_values else 0

    print("\n" + "=" * 50)
    print("Benchmark Results (llama-bench compatible):")
    print("=" * 50)
    print(f"Model: {model_path}")

    # Test description (like llama-bench output)
    test_desc = []
    if n_prompt > 0:
        test_desc.append(f"pp{n_prompt}")
    if n_gen > 0:
        test_desc.append(f"tg{n_gen}")
    print(f"Test: {'+'.join(test_desc)}")

    if n_prompt > 0:
        avg_pp = np.mean(prompt_times)
        std_pp = np.std(prompt_times)
        avg_pp_tps = n_prompt / avg_pp if avg_pp > 0 else 0
        print(f"\nPrompt processing ({n_prompt} tokens):")
        print(f"  Time: {avg_pp:.3f} ± {std_pp:.3f}s")
        print(f"  Speed: {avg_pp_tps:.2f} tokens/sec")

    if n_gen > 0:
        avg_tg = np.mean(gen_times)
        std_tg = np.std(gen_times)
        avg_tg_tps = n_gen / avg_tg if avg_tg > 0 else 0
        print(f"\nGeneration ({n_gen} tokens):")
        print(f"  Time: {avg_tg:.3f} ± {std_tg:.3f}s")
        print(f"  Speed: {avg_tg_tps:.2f} tokens/sec")

    print(f"\nOverall ({total_tokens} tokens):")
    print(f"  Time: {avg_total:.3f} ± {std_total:.3f}s")
    print(f"  Speed: {avg_tps:.2f} ± {std_tps:.2f} tokens/sec")

    return {
        'prompt_times': prompt_times,
        'gen_times': gen_times,
        'total_times': total_times,
        'avg_tps': avg_tps,
        'std_tps': std_tps
    }

def main():
    parser = argparse.ArgumentParser(
        description='MLX benchmark that exactly replicates llama-bench behavior'
    )
    parser.add_argument('-m', '--model', type=str, required=True, help='Path to MLX model')
    parser.add_argument('-p', '--n-prompt', type=int, default=512,
                        help='Number of prompt tokens (default: 512)')
    parser.add_argument('-n', '--n-gen', type=int, default=128,
                        help='Number of generation tokens (default: 128)')
    parser.add_argument('-b', '--batch-size', type=int, default=2048,
                        help='Batch size for prompt processing (default: 2048)')
    parser.add_argument('-r', '--reps', type=int, default=5,
                        help='Number of repetitions (default: 5)')
    parser.add_argument('--no-warmup', action='store_true',
                        help='Skip warmup run')

    args = parser.parse_args()

    try:
        run_benchmark(
            model_path=args.model,
            n_prompt=args.n_prompt,
            n_gen=args.n_gen,
            n_batch=args.batch_size,
            n_reps=args.reps,
            warmup=not args.no_warmup
        )
    except Exception as e:
        print(f"Error: {e}", file=sys.stderr)
        import traceback
        traceback.print_exc()
        return 1

    return 0


if __name__ == "__main__":
    sys.exit(main())
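
A minimal usage sketch, not part of the original gist: it assumes the script above is saved as mlx_bench.py on the import path, and the model path shown is a placeholder for any local MLX model directory or Hub repo that mlx_lm.load accepts. It calls run_benchmark directly instead of going through the CLI.

# Usage sketch (assumption: the script above is saved as mlx_bench.py;
# "path/to/your-mlx-model" is a placeholder, not a real path).
from mlx_bench import run_benchmark

results = run_benchmark(
    model_path="path/to/your-mlx-model",  # anything mlx_lm.load accepts
    n_prompt=512,  # pp512, same default as the CLI
    n_gen=128,     # tg128
    n_reps=5,
)
print(f"Average throughput: {results['avg_tps']:.2f} ± {results['std_tps']:.2f} tokens/sec")

The equivalent command-line run, using the flags defined in main(), would be: python mlx_bench.py -m path/to/your-mlx-model -p 512 -n 128 -r 5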