Created
November 12, 2025 13:17
-
-
Save createthis/f6700a74130591fe7fe8f1a4b05eea48 to your computer and use it in GitHub Desktop.
bench_topk_tilelang.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| #!/usr/bin/env python3 | |
| import argparse | |
| import time | |
| import torch | |
| # TileLang example kernels | |
| from examples.deepseek_v32.topk_selector import tl_topk, tl_topk_impl | |
def bench_tl_topk(seq_len: int, topk: int = 256, batch: int = 1, iters: int = 50, warmup: int = 5) -> float:
    """Benchmark the high-level ``tl_topk`` wrapper on random data.

    Allocates a (batch, seq_len) float32 input on the GPU, runs ``warmup``
    untimed calls, then times ``iters`` calls with CUDA events (device-side
    time, per call).

    Args:
        seq_len: Sequence length (number of candidates per row).
        topk: Number of indices to select per row.
        batch: Number of rows.
        iters: Timed iterations; must be positive.
        warmup: Untimed warmup iterations (triggers lazy compilation/caching).

    Returns:
        Average per-call time in milliseconds.

    Raises:
        ValueError: If ``iters`` is not positive.
    """
    if iters <= 0:
        raise ValueError(f"iters must be positive, got {iters}")

    x = torch.randn(batch, seq_len, device="cuda", dtype=torch.float32)
    starts = torch.zeros(batch, dtype=torch.int32, device="cuda")
    ends = torch.full((batch,), seq_len, dtype=torch.int32, device="cuda")

    # Warmup: absorb one-time compilation / autotuning cost before timing.
    for _ in range(warmup):
        _ = tl_topk(x, starts, ends, topk)
    torch.cuda.synchronize()

    # Hoist the event pair out of the loop: events are reusable, and
    # re-allocating them every iteration adds avoidable host overhead.
    start_evt = torch.cuda.Event(enable_timing=True)
    end_evt = torch.cuda.Event(enable_timing=True)
    times = []
    for _ in range(iters):
        start_evt.record()
        _ = tl_topk(x, starts, ends, topk)
        end_evt.record()
        end_evt.synchronize()  # wait so elapsed_time is valid
        times.append(start_evt.elapsed_time(end_evt))  # ms

    avg_ms = sum(times) / len(times)
    print(f"[TILELANG_TOPK] seq_len={seq_len} topk={topk} batch={batch} avg_ms={avg_ms:.3f} over {iters}")
    return avg_ms
def bench_tl_topk_impl(seq_len: int, topk: int = 256, batch: int = 1, iters: int = 50, warmup: int = 5) -> float:
    """Benchmark the raw ``tl_topk_impl`` kernel, excluding wrapper overhead.

    The kernel is compiled once up front and writes into a preallocated
    int32 output buffer, so the timed loop measures only kernel launches.

    Args:
        seq_len: Sequence length (number of candidates per row).
        topk: Number of indices to select per row.
        batch: Number of rows.
        iters: Timed iterations; must be positive.
        warmup: Untimed warmup iterations.

    Returns:
        Average per-call time in milliseconds.

    Raises:
        ValueError: If ``iters`` is not positive.
    """
    if iters <= 0:
        raise ValueError(f"iters must be positive, got {iters}")

    # Compile the kernel once; warmup below still covers any lazy init.
    kernel = tl_topk_impl(topk)

    x = torch.randn(batch, seq_len, device="cuda", dtype=torch.float32)
    starts = torch.zeros(batch, dtype=torch.int32, device="cuda")
    ends = torch.full((batch,), seq_len, dtype=torch.int32, device="cuda")
    out = torch.empty(batch, topk, dtype=torch.int32, device="cuda")

    # Warmup: untimed launches before measurement.
    for _ in range(warmup):
        kernel(x, out, starts, ends)
    torch.cuda.synchronize()

    # Hoist the CUDA event pair out of the loop (events are reusable);
    # creating them per iteration only adds host-side overhead.
    start_evt = torch.cuda.Event(enable_timing=True)
    end_evt = torch.cuda.Event(enable_timing=True)
    times = []
    for _ in range(iters):
        start_evt.record()
        kernel(x, out, starts, ends)
        end_evt.record()
        end_evt.synchronize()  # wait so elapsed_time is valid
        times.append(start_evt.elapsed_time(end_evt))  # ms

    avg_ms = sum(times) / len(times)
    print(f"[TILELANG_TOPK_IMPL] seq_len={seq_len} topk={topk} batch={batch} avg_ms={avg_ms:.3f} over {iters}")
    return avg_ms
def parse_seq_lens(s: str) -> list:
    """Parse a comma-separated string into a list of ints.

    Blank segments (e.g. from trailing or doubled commas) are skipped,
    so ``"1, 2,"`` yields ``[1, 2]``.
    """
    tokens = (segment.strip() for segment in s.split(","))
    return [int(tok) for tok in tokens if tok]
def main():
    """CLI entry point: parse arguments and run the selected benchmarks."""
    parser = argparse.ArgumentParser(description="Benchmark TileLang top-k selector without full model")
    parser.add_argument("--seq-lens", type=parse_seq_lens, default="4096,16384,163840",
                        help="Comma-separated sequence lengths to test (default: 4096,16384,163840)")
    parser.add_argument("--topk", type=int, default=256, help="Top-k to select (default: 256)")
    parser.add_argument("--batch", type=int, default=1, help="Batch size (T) (default: 1)")
    parser.add_argument("--iters", type=int, default=50, help="Timed iterations (default: 50)")
    parser.add_argument("--warmup", type=int, default=5, help="Warmup iterations (default: 5)")
    parser.add_argument("--mode", choices=["both", "wrapper", "impl"], default="both",
                        help="Which path to benchmark: tl_topk wrapper, tl_topk_impl kernel, or both (default)")
    args = parser.parse_args()

    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is required for this benchmark.")
    dev = torch.cuda.get_device_name(0)
    print(f"CUDA device: {dev}")

    # argparse normally applies `type` to a string default, so seq_lens is
    # already a list; the fallback re-parses defensively just in case.
    seq_lens = args.seq_lens if isinstance(args.seq_lens, list) else parse_seq_lens(args.seq_lens)

    # Resolve the mode once instead of re-testing inside the loop.
    run_wrapper = args.mode in ("both", "wrapper")
    run_impl = args.mode in ("both", "impl")
    for seq_len in seq_lens:
        if run_wrapper:
            bench_tl_topk(seq_len, topk=args.topk, batch=args.batch, iters=args.iters, warmup=args.warmup)
        if run_impl:
            bench_tl_topk_impl(seq_len, topk=args.topk, batch=args.batch, iters=args.iters, warmup=args.warmup)
# Standard script guard: run the benchmark CLI only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment