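"""
Exhaustively branch a language model's next-token distribution using mlx_lm's
BatchGenerator: at every step each live branch is expanded into its top-k next
tokens, branches are retired once their cumulative probability falls below a
threshold or they emit an EOS/special token, and the surviving completions are
printed as a probability-ranked table with rich.
"""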
from __future__ import annotations

import argparse
import time
from dataclasses import dataclass

import mlx.core as mx
from mlx_lm import load
from mlx_lm.generate import BatchGenerator
from rich.console import Console
from rich.table import Table

@dataclass
class Branch:
    prompt: list[int]                  # token ids of the original prompt
    answer: list[int]                  # token ids generated so far on this branch
    probability: float = 1.0           # cumulative probability of the answer tokens
    finish_reason: str | None = None   # set once the branch stops being expanded

def process_batch(
    model,
    tokenizer,
    branches: list[Branch],
    min_probability: float,
    topk: int,
):
    """
    One expansion step: like mlx_lm.batch_generate with max_tokens=1, but
    instead of sampling a single token per prompt, every branch is split into
    its top-k continuations. Returns (new_branches, finished_branches).
    """
    gen = BatchGenerator(model)
    uids = gen.insert([branch.prompt + branch.answer for branch in branches], max_tokens=1)
    branches_by_uid = {uid: branch for uid, branch in zip(uids, branches)}
    finished_branches: list[Branch] = []
    new_branches: list[Branch] = []
    while responses := gen.next():
        for r in responses:
            branch = branches_by_uid[r.uid]
            probs = mx.softmax(r.logprobs, axis=-1)
            # mx.topk does not return indices for 1D inputs, so sort manually.
            k = min(topk, probs.shape[0])
            top_indices = mx.argsort(probs)[-k:][::-1]
            top_probs = mx.take(probs, top_indices)
            top_indices = top_indices.astype(mx.int64).tolist()
            top_probs = mx.reshape(top_probs, (-1,)).tolist()
            for token_id, prob in zip(top_indices, top_probs):
                new_branch = Branch(
                    prompt=branch.prompt,
                    answer=branch.answer + [token_id],
                    probability=branch.probability * prob,
                )
                if new_branch.probability < min_probability:
                    new_branch.finish_reason = "low_probability"
                    finished_branches.append(new_branch)
                    continue
                if token_id in tokenizer.eos_token_ids:
                    new_branch.finish_reason = "eos_token"
                    finished_branches.append(new_branch)
                    continue
                if token_id in tokenizer._tokenizer.all_special_ids:
                    new_branch.finish_reason = "special_token"
                    finished_branches.append(new_branch)
                    continue
                new_branches.append(new_branch)
    return new_branches, finished_branches

def print_branches_table(tokenizer, branches: list[Branch]) -> None:
    console = Console()
    table = Table()
    table.add_column("Rank", justify="right")
    table.add_column("Probability")
    table.add_column("Finish Reason")
    table.add_column("Text", overflow="fold")
    for idx, branch in enumerate(branches, start=1):
        token_str = tokenizer.decode(branch.answer, skip_special_tokens=True)
        table.add_row(
            str(idx),
            f"{branch.probability * 100:.2f}%",
            branch.finish_reason or "-",
            token_str,
        )
    console.print(table)

def main() -> None:
    parser = argparse.ArgumentParser()
    parser.add_argument("-p", "--prompt", default="What is 2+2?", help="Prompt to score")
    parser.add_argument("-m", "--model", default="mlx-community/Llama-3.2-1B-Instruct-4bit")
    parser.add_argument("--min-probability", type=float, default=0.0001)
    parser.add_argument("--topk", type=int, default=50)
    args = parser.parse_args()

    model, tokenizer = load(args.model)
    input_ids = tokenizer.apply_chat_template(
        [{"role": "user", "content": args.prompt}],
        add_generation_prompt=True,
    )

    # Breadth-first expansion: start from the bare prompt and repeatedly expand
    # every live branch by one token until no branches remain in the queue.
    root = Branch(prompt=input_ids, answer=[])
    branches, result = process_batch(model, tokenizer, [root], args.min_probability, args.topk)
    while branches:
        start = time.perf_counter()
        new_branches, finished_branches = process_batch(model, tokenizer, branches, args.min_probability, args.topk)
        branches = new_branches
        result.extend(finished_branches)
        elapsed = time.perf_counter() - start
        tps = (len(new_branches) + len(finished_branches)) / elapsed
        print(f"Queue: {len(branches)}, tps: {tps:.2f}")

    # Show the 20 most probable completions.
    result.sort(key=lambda branch: branch.probability, reverse=True)
    print_branches_table(tokenizer, result[:20])


if __name__ == "__main__":
    main()
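Usage sketch (the file name below is hypothetical; this assumes the mlx-lm and rich packages are installed and the script has been saved locally):

    pip install mlx-lm rich
    python branch_probabilities.py -p "What is 2+2?" --topk 50 --min-probability 0.0001

The flags map directly onto the argparse options above: --topk bounds how many continuations each branch is split into per step, and --min-probability is the cumulative-probability floor below which a branch is retired.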