@samwho
Created December 9, 2025 16:06
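"""
Exhaustively branch a language model's completions with mlx_lm.

Each step expands every live branch into its top-k next tokens, carrying the
branch's cumulative probability forward. Branches finish on an EOS or special
token, or when their probability drops below a threshold; finished branches
are printed in a table sorted by probability.
"""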
from __future__ import annotations
import argparse
import time
from dataclasses import dataclass
import mlx.core as mx
from mlx_lm import load
from mlx_lm.generate import BatchGenerator
from rich.console import Console
from rich.table import Table


@dataclass
class Branch:
    prompt: list[int]
    answer: list[int]
    probability: float = 1.0
    finish_reason: str | None = None


def process_batch(
    model,
    tokenizer,
    branches: list[Branch],
    min_probability: float,
    topk: int,
):
    """
    Advance every branch by one token using mlx_lm's BatchGenerator, splitting
    each branch into its top-k continuations.
    Returns (new_branches, finished_branches).
    """
    gen = BatchGenerator(model)
    uids = gen.insert([branch.prompt + branch.answer for branch in branches], max_tokens=1)
    branches_by_uid = {uid: branch for uid, branch in zip(uids, branches)}
    finished_branches: list[Branch] = []
    new_branches: list[Branch] = []
    while responses := gen.next():
        for r in responses:
            branch = branches_by_uid[r.uid]
            probs = mx.softmax(r.logprobs, axis=-1)
            # mx.topk does not return indices for 1D inputs, so sort manually.
            k = min(topk, probs.shape[0])
            top_indices = mx.argsort(probs)[-k:][::-1]
            top_probs = mx.take(probs, top_indices)
            top_indices = top_indices.astype(mx.int64).tolist()
            top_probs = mx.reshape(top_probs, (-1,)).tolist()
            for token_id, prob in zip(top_indices, top_probs):
                new_branch = Branch(
                    prompt=branch.prompt,
                    answer=branch.answer + [token_id],
                    probability=branch.probability * prob,
                )
                # Prune branches whose cumulative probability has fallen below the threshold.
                if new_branch.probability < min_probability:
                    new_branch.finish_reason = "low_probability"
                    finished_branches.append(new_branch)
                    continue
                if token_id in tokenizer.eos_token_ids:
                    new_branch.finish_reason = "eos_token"
                    finished_branches.append(new_branch)
                    continue
                if token_id in tokenizer._tokenizer.all_special_ids:
                    new_branch.finish_reason = "special_token"
                    finished_branches.append(new_branch)
                    continue
                new_branches.append(new_branch)
    return new_branches, finished_branches


def print_branches_table(tokenizer, branches: list[Branch]) -> None:
    console = Console()
    table = Table()
    table.add_column("Rank", justify="right")
    table.add_column("Probability")
    table.add_column("Finish Reason")
    table.add_column("Text", overflow="fold")
    for idx, branch in enumerate(branches, start=1):
        token_str = tokenizer.decode(branch.answer, skip_special_tokens=True)
        table.add_row(
            str(idx),
            f"{branch.probability * 100:.2f}%",
            branch.finish_reason or "-",
            token_str,
        )
    console.print(table)


def main() -> None:
    parser = argparse.ArgumentParser()
    parser.add_argument("-p", "--prompt", default="What is 2+2?", help="Prompt to score")
    parser.add_argument("-m", "--model", default="mlx-community/Llama-3.2-1B-Instruct-4bit")
    parser.add_argument("--min-probability", type=float, default=0.0001)
    parser.add_argument("--topk", type=int, default=50)
    args = parser.parse_args()

    model, tokenizer = load(args.model)
    input_ids = tokenizer.apply_chat_template(
        [{"role": "user", "content": args.prompt}],
        add_generation_prompt=True,
    )

    # Expand the root prompt once to seed the queue of live branches.
    root = Branch(prompt=input_ids, answer=[])
    branches, result = process_batch(model, tokenizer, [root], args.min_probability, args.topk)

    # Keep expanding until every branch has finished (EOS, special token, or pruned).
    while len(branches):
        start = time.perf_counter()
        new_branches, finished_branches = process_batch(model, tokenizer, branches, args.min_probability, args.topk)
        branches = new_branches
        result.extend(finished_branches)
        elapsed = time.perf_counter() - start
        tps = (len(new_branches) + len(finished_branches)) / elapsed
        print(f"Queue: {len(branches)}, tps: {tps:.2f}")

    result.sort(key=lambda branch: branch.probability, reverse=True)
    print_branches_table(tokenizer, result[:20])


if __name__ == "__main__":
    main()
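A minimal invocation sketch, assuming mlx, mlx-lm, and rich are installed and the script is saved as branch.py (the filename is arbitrary); the flags match the argparse definitions above:

    python branch.py --prompt "What is 2+2?" --topk 10 --min-probability 0.001

Smaller --topk and a larger --min-probability keep the branch queue manageable, since every surviving branch is re-expanded on each pass of the loop in main().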