@ivanfioravanti
Created August 22, 2025 22:16
mlx-bench
#!/usr/bin/env python3
"""
MLX benchmark script that replicates llama-bench behavior exactly.
Uses random tokens for both prompt and generation, no sampling.
"""
import mlx.core as mx
import mlx_lm
from mlx_lm.models.cache import make_prompt_cache
import time
import numpy as np
import argparse
import sys


def test_prompt(model, tokenizer, n_prompt, n_batch=2048):
    """
    Test prompt processing with random tokens, exactly like llama-bench.
    """
    # Get vocab size
    try:
        vocab_size = tokenizer.vocab_size
    except AttributeError:
        try:
            vocab_size = tokenizer._tokenizer.vocab_size
        except AttributeError:
            vocab_size = 32000  # Common default
    # Generate random token IDs (not text!)
    token_ids = np.random.randint(0, vocab_size, n_prompt).tolist()
    # Process tokens through model (batched like llama-bench)
    n_processed = 0
    cache = make_prompt_cache(model)
    start = time.perf_counter()
    while n_processed < n_prompt:
        batch_size = min(n_prompt - n_processed, n_batch)
        batch_tokens = token_ids[n_processed:n_processed + batch_size]
        # Convert to MLX array
        tokens = mx.array([batch_tokens])
        # Forward pass only - no sampling
        logits = model(tokens, cache=cache)
        mx.eval(logits)  # Force evaluation
        n_processed += batch_size
    end = time.perf_counter()
    return end - start


def test_gen(model, tokenizer, n_gen):
    """
    Test generation with random tokens, exactly like llama-bench.
    No sampling - just forward passes with random tokens.
    """
    # Get vocab size
    try:
        vocab_size = tokenizer.vocab_size
    except AttributeError:
        try:
            vocab_size = tokenizer._tokenizer.vocab_size
        except AttributeError:
            vocab_size = 32000  # Common default
    # Start with a random token
    token_id = np.random.randint(0, vocab_size)
    cache = make_prompt_cache(model)
    start = time.perf_counter()
    for i in range(n_gen):
        # Convert to MLX array
        tokens = mx.array([[token_id]])
        # Forward pass only - no sampling
        logits = model(tokens, cache=cache)
        mx.eval(logits)  # Force evaluation
        # Next token is random (not sampled from logits!)
        token_id = np.random.randint(0, vocab_size)
    end = time.perf_counter()
    return end - start


def run_benchmark(model_path, n_prompt=512, n_gen=128, n_batch=2048, n_reps=5, warmup=True):
    """
    Run benchmark exactly like llama-bench.

    Args:
        model_path: Path to the MLX model
        n_prompt: Number of prompt tokens (default: 512)
        n_gen: Number of generation tokens (default: 128)
        n_batch: Batch size for prompt processing (default: 2048)
        n_reps: Number of repetitions (default: 5)
        warmup: Whether to do warmup runs (default: True)
    """
    print("MLX-bench (llama-bench compatible mode)")
    print("=======================================")
    print(f"Loading model from: {model_path}")
    model, tokenizer = mlx_lm.load(model_path)
    print("Test configuration:")
    print(f"  Prompt tokens: {n_prompt} (random)")
    print(f"  Generation tokens: {n_gen} (random)")
    print(f"  Repetitions: {n_reps}")
    print(f"  Warmup: {warmup}")
    print()
    # Warmup runs (like llama-bench does)
    if warmup:
        print("Running warmup...")
        if n_prompt > 0:
            print("  Warmup prompt processing...")
            _ = test_prompt(model, tokenizer, min(32, n_prompt), n_batch)
        if n_gen > 0:
            print("  Warmup generation...")
            _ = test_gen(model, tokenizer, 1)
        print("Warmup complete.\n")
    # Benchmark runs
    print("Running benchmark...")
    prompt_times = []
    gen_times = []
    total_times = []
    for i in range(n_reps):
        # Make sure model weights are materialized before timing
        mx.eval(model.parameters())
        prompt_time = 0
        gen_time = 0
        # Test prompt processing
        if n_prompt > 0:
            prompt_time = test_prompt(model, tokenizer, n_prompt, n_batch)
        # Test generation
        if n_gen > 0:
            gen_time = test_gen(model, tokenizer, n_gen)
        total_time = prompt_time + gen_time
        prompt_times.append(prompt_time)
        gen_times.append(gen_time)
        total_times.append(total_time)
        # Calculate tokens per second
        total_tokens = n_prompt + n_gen
        tps = total_tokens / total_time if total_time > 0 else 0
        print(f"Run {i+1}/{n_reps}:")
        if n_prompt > 0:
            pp_tps = n_prompt / prompt_time if prompt_time > 0 else 0
            print(f"  Prompt: {prompt_time:.3f}s ({pp_tps:.2f} tokens/sec)")
        if n_gen > 0:
            tg_tps = n_gen / gen_time if gen_time > 0 else 0
            print(f"  Generation: {gen_time:.3f}s ({tg_tps:.2f} tokens/sec)")
        print(f"  Total: {total_time:.3f}s ({tps:.2f} tokens/sec)")
    # Overall statistics
    total_tokens = n_prompt + n_gen
    avg_total = np.mean(total_times)
    std_total = np.std(total_times)
    avg_tps = total_tokens / avg_total if avg_total > 0 else 0
    tps_values = [total_tokens / t for t in total_times if t > 0]
    std_tps = np.std(tps_values) if tps_values else 0
    print("\n" + "=" * 50)
    print("Benchmark Results (llama-bench compatible):")
    print("=" * 50)
    print(f"Model: {model_path}")
    # Test description (like llama-bench output)
    test_desc = []
    if n_prompt > 0:
        test_desc.append(f"pp{n_prompt}")
    if n_gen > 0:
        test_desc.append(f"tg{n_gen}")
    print(f"Test: {'+'.join(test_desc)}")
    if n_prompt > 0:
        print(f"\nPrompt processing ({n_prompt} tokens):")
        avg_pp = np.mean(prompt_times)
        std_pp = np.std(prompt_times)
        avg_pp_tps = n_prompt / avg_pp if avg_pp > 0 else 0
        print(f"  Time: {avg_pp:.3f} ± {std_pp:.3f}s")
        print(f"  Speed: {avg_pp_tps:.2f} tokens/sec")
    if n_gen > 0:
        print(f"\nGeneration ({n_gen} tokens):")
        avg_tg = np.mean(gen_times)
        std_tg = np.std(gen_times)
        avg_tg_tps = n_gen / avg_tg if avg_tg > 0 else 0
        print(f"  Time: {avg_tg:.3f} ± {std_tg:.3f}s")
        print(f"  Speed: {avg_tg_tps:.2f} tokens/sec")
    print(f"\nOverall ({total_tokens} tokens):")
    print(f"  Time: {avg_total:.3f} ± {std_total:.3f}s")
    print(f"  Speed: {avg_tps:.2f} ± {std_tps:.2f} tokens/sec")
    return {
        'prompt_times': prompt_times,
        'gen_times': gen_times,
        'total_times': total_times,
        'avg_tps': avg_tps,
        'std_tps': std_tps,
    }


def main():
    parser = argparse.ArgumentParser(
        description='MLX benchmark that exactly replicates llama-bench behavior'
    )
    parser.add_argument('-m', '--model', type=str, required=True, help='Path to MLX model')
    parser.add_argument('-p', '--n-prompt', type=int, default=512,
                        help='Number of prompt tokens (default: 512)')
    parser.add_argument('-n', '--n-gen', type=int, default=128,
                        help='Number of generation tokens (default: 128)')
    parser.add_argument('-b', '--batch-size', type=int, default=2048,
                        help='Batch size for prompt processing (default: 2048)')
    parser.add_argument('-r', '--reps', type=int, default=5,
                        help='Number of repetitions (default: 5)')
    parser.add_argument('--no-warmup', action='store_true',
                        help='Skip warmup run')
    args = parser.parse_args()
    try:
        run_benchmark(
            model_path=args.model,
            n_prompt=args.n_prompt,
            n_gen=args.n_gen,
            n_batch=args.batch_size,
            n_reps=args.reps,
            warmup=not args.no_warmup,
        )
    except Exception as e:
        print(f"Error: {e}", file=sys.stderr)
        import traceback
        traceback.print_exc()
        return 1
    return 0


if __name__ == "__main__":
    sys.exit(main())
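
The script can also be used programmatically: run_benchmark returns the raw per-run timings along with the aggregate throughput. A minimal sketch, assuming the file is saved as mlx_bench.py on the import path and using a placeholder model path:

    from mlx_bench import run_benchmark

    results = run_benchmark("mlx-community/Mistral-7B-Instruct-v0.3-4bit",
                            n_prompt=512, n_gen=128, n_reps=3)
    print(f"{results['avg_tps']:.2f} ± {results['std_tps']:.2f} tokens/sec")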