bench_topk_tilelang.py
#!/usr/bin/env python3
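"""Standalone micro-benchmark for the TileLang DeepSeek-V3.2 top-k selector.

Times the tl_topk Python wrapper and a pre-compiled tl_topk_impl kernel with
CUDA events. Assumes examples.deepseek_v32.topk_selector is importable, e.g.
when run from the root of a TileLang checkout.
"""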
import argparse
import time
import torch
# TileLang example kernels
from examples.deepseek_v32.topk_selector import tl_topk, tl_topk_impl


def bench_tl_topk(seq_len: int, topk: int = 256, batch: int = 1, iters: int = 50, warmup: int = 5):
    torch.cuda.synchronize()
    x = torch.randn(batch, seq_len, device="cuda", dtype=torch.float32)
    starts = torch.zeros(batch, dtype=torch.int32, device="cuda")
    ends = torch.full((batch,), seq_len, dtype=torch.int32, device="cuda")
    # warmup
    for _ in range(warmup):
        _ = tl_topk(x, starts, ends, topk)
    torch.cuda.synchronize()
    # timed
    times = []
    for _ in range(iters):
        t0 = torch.cuda.Event(enable_timing=True)
        t1 = torch.cuda.Event(enable_timing=True)
        t0.record()
        _ = tl_topk(x, starts, ends, topk)
        t1.record()
        t1.synchronize()
        times.append(t0.elapsed_time(t1))  # ms
    avg_ms = sum(times) / len(times)
    print(f"[TILELANG_TOPK] seq_len={seq_len} topk={topk} batch={batch} avg_ms={avg_ms:.3f} over {iters}")


def bench_tl_topk_impl(seq_len: int, topk: int = 256, batch: int = 1, iters: int = 50, warmup: int = 5):
    # compile kernel once
    kernel = tl_topk_impl(topk)
    x = torch.randn(batch, seq_len, device="cuda", dtype=torch.float32)
    starts = torch.zeros(batch, dtype=torch.int32, device="cuda")
    ends = torch.full((batch,), seq_len, dtype=torch.int32, device="cuda")
    out = torch.empty(batch, topk, dtype=torch.int32, device="cuda")
    # warmup
    for _ in range(warmup):
        kernel(x, out, starts, ends)
    torch.cuda.synchronize()
    # timed
    times = []
    for _ in range(iters):
        t0 = torch.cuda.Event(enable_timing=True)
        t1 = torch.cuda.Event(enable_timing=True)
        t0.record()
        kernel(x, out, starts, ends)
        t1.record()
        t1.synchronize()
        times.append(t0.elapsed_time(t1))  # ms
    avg_ms = sum(times) / len(times)
    print(f"[TILELANG_TOPK_IMPL] seq_len={seq_len} topk={topk} batch={batch} avg_ms={avg_ms:.3f} over {iters}")


def parse_seq_lens(s: str):
    vals = []
    for part in s.split(','):
        part = part.strip()
        if not part:
            continue
        vals.append(int(part))
    return vals


def main():
    parser = argparse.ArgumentParser(description="Benchmark TileLang top-k selector without full model")
    parser.add_argument("--seq-lens", type=parse_seq_lens, default="4096,16384,163840",
                        help="Comma-separated sequence lengths to test (default: 4096,16384,163840)")
    parser.add_argument("--topk", type=int, default=256, help="Top-k to select (default: 256)")
    parser.add_argument("--batch", type=int, default=1, help="Batch size (T) (default: 1)")
    parser.add_argument("--iters", type=int, default=50, help="Timed iterations (default: 50)")
    parser.add_argument("--warmup", type=int, default=5, help="Warmup iterations (default: 5)")
    parser.add_argument("--mode", choices=["both", "wrapper", "impl"], default="both",
                        help="Which path to benchmark: tl_topk wrapper, tl_topk_impl kernel, or both (default)")
    args = parser.parse_args()
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is required for this benchmark.")
    dev = torch.cuda.get_device_name(0)
    print(f"CUDA device: {dev}")
    # argparse applies `type` to string defaults, so this is normally already a
    # list; keep a defensive fallback just in case
    seq_lens = args.seq_lens if isinstance(args.seq_lens, list) else parse_seq_lens(args.seq_lens)
    for kv in seq_lens:
        if args.mode in ("both", "wrapper"):
            bench_tl_topk(kv, topk=args.topk, batch=args.batch, iters=args.iters, warmup=args.warmup)
        if args.mode in ("both", "impl"):
            bench_tl_topk_impl(kv, topk=args.topk, batch=args.batch, iters=args.iters, warmup=args.warmup)


if __name__ == "__main__":
    main()
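
Example invocation (hypothetical command line; assumes the script sits at the root of a TileLang checkout so the examples import resolves, and that a CUDA GPU is present):

    python bench_topk_tilelang.py --seq-lens 4096,16384,163840 --mode both --iters 50 --warmup 5

Each sequence length prints one [TILELANG_TOPK] and/or [TILELANG_TOPK_IMPL] line with the average kernel time in milliseconds.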