bench_indexer_tilelang.py
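Benchmarks the DeepSeek V3.2 FP8 lightning-indexer kernels from the TileLang examples on CUDA. It times two paths with CUDA events: the high-level wrapper (mqa_attn_return_logits_interface) and the compiled kernel (mqa_attn_return_logits) called directly, over configurable sequence lengths, KV lengths, head counts, and index dimensions.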
#!/usr/bin/env python3
import argparse

import torch

# Prefer the local examples helper when running from the TileLang repo root.
try:
    from examples.deepseek_v32.utils import per_custom_dims_cast_to_fp8 as _to_fp8

    def to_fp8(x):
        # Cast along the last dim to FP8 E4M3 to match kernel expectations.
        # Handle both (x, dims, use_ue8m0) and (x, dims) signatures and
        # return only the scaled tensor.
        try:
            x_scaled, _ = _to_fp8(x, dims=(-1,), use_ue8m0=False)
            return x_scaled
        except TypeError:
            out = _to_fp8(x, dims=(-1,))
            return out[0] if isinstance(out, tuple) else out
except Exception:
    # Fallback: plain dtype cast without per-dim scaling.
    def to_fp8(x):
        if not hasattr(torch, "float8_e4m3fn"):
            raise RuntimeError("torch.float8_e4m3fn not available; install a CUDA-enabled PyTorch.")
        return x.to(torch.float8_e4m3fn)

# TileLang example kernels for the lightning indexer
from examples.deepseek_v32.fp8_lighting_indexer import (
    mqa_attn_return_logits,
    mqa_attn_return_logits_interface,
)

def bench_tl_indexer_wrapper(seq_len: int,
                             seq_len_kv: int,
                             heads: int = 4,
                             index_dim: int = 64,
                             iters: int = 50,
                             warmup: int = 5):
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is required for this benchmark.")
    device = torch.device("cuda")

    # Inputs
    q = torch.randn(seq_len, heads, index_dim, device=device, dtype=torch.float32)
    kv = torch.randn(seq_len_kv, index_dim, device=device, dtype=torch.float32)

    # Convert to FP8 E4M3 to match the kernel signature
    q_fp8 = to_fp8(q)
    kv_fp8 = to_fp8(kv)

    # Precompute kv_scales as in the reference: sqrt(mean(k^2)) along dim=-1
    kv_scales = kv.pow(2).mean(dim=-1).sqrt()
    weights = torch.randn(seq_len, heads, device=device, dtype=torch.float32)
    cu_seqlen_ks = torch.zeros(seq_len, dtype=torch.int32, device=device)
    cu_seqlen_ke = torch.full((seq_len,), seq_len_kv, dtype=torch.int32, device=device)

    # Warmup
    for _ in range(warmup):
        _ = mqa_attn_return_logits_interface(q_fp8, kv_fp8, kv_scales, weights, cu_seqlen_ks, cu_seqlen_ke)
    torch.cuda.synchronize()

    # Timed
    times = []
    for _ in range(iters):
        t0 = torch.cuda.Event(enable_timing=True)
        t1 = torch.cuda.Event(enable_timing=True)
        t0.record()
        _ = mqa_attn_return_logits_interface(q_fp8, kv_fp8, kv_scales, weights, cu_seqlen_ks, cu_seqlen_ke)
        t1.record()
        t1.synchronize()
        times.append(t0.elapsed_time(t1))  # ms
    avg_ms = sum(times) / len(times) if times else float("nan")
    print(f"[TILELANG_INDEXER] WRAPPER S={seq_len} SKV={seq_len_kv} H={heads} D={index_dim} avg_ms={avg_ms:.3f} over {iters}")

def bench_tl_indexer_impl(seq_len: int,
                          seq_len_kv: int,
                          heads: int = 4,
                          index_dim: int = 64,
                          iters: int = 50,
                          warmup: int = 5):
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is required for this benchmark.")
    device = torch.device("cuda")

    # Compile the kernel once, outside the timed loop
    kernel = mqa_attn_return_logits(heads=heads, index_dim=index_dim)

    # Inputs
    q = torch.randn(seq_len, heads, index_dim, device=device, dtype=torch.float32)
    kv = torch.randn(seq_len_kv, index_dim, device=device, dtype=torch.float32)

    # Convert to FP8 E4M3 to match the kernel signature
    q_fp8 = to_fp8(q)
    kv_fp8 = to_fp8(kv)

    kv_scales = kv.pow(2).mean(dim=-1).sqrt()
    weights = torch.randn(seq_len, heads, device=device, dtype=torch.float32)
    cu_seqlen_ks = torch.zeros(seq_len, dtype=torch.int32, device=device)
    cu_seqlen_ke = torch.full((seq_len,), seq_len_kv, dtype=torch.int32, device=device)
    logits = torch.empty(seq_len, seq_len_kv, device=device, dtype=torch.float32)

    # Warmup
    for _ in range(warmup):
        kernel(
            q_fp8.view(seq_len * heads, index_dim),
            kv_fp8,
            kv_scales,
            logits,
            weights,
            cu_seqlen_ks,
            cu_seqlen_ke,
        )
    torch.cuda.synchronize()

    # Timed
    times = []
    for _ in range(iters):
        t0 = torch.cuda.Event(enable_timing=True)
        t1 = torch.cuda.Event(enable_timing=True)
        t0.record()
        kernel(
            q_fp8.view(seq_len * heads, index_dim),
            kv_fp8,
            kv_scales,
            logits,
            weights,
            cu_seqlen_ks,
            cu_seqlen_ke,
        )
        t1.record()
        t1.synchronize()
        times.append(t0.elapsed_time(t1))  # ms
    avg_ms = sum(times) / len(times) if times else float("nan")
    print(f"[TILELANG_INDEXER] IMPL S={seq_len} SKV={seq_len_kv} H={heads} D={index_dim} avg_ms={avg_ms:.3f} over {iters}")

def parse_int_list(s: str):
    vals = []
    for part in s.split(','):
        part = part.strip()
        if not part:
            continue
        vals.append(int(part))
    return vals

def main():
    parser = argparse.ArgumentParser(description="Benchmark TileLang lightning indexer (DeepSeek V3.2)")
    parser.add_argument("--seq-lens", type=parse_int_list, default="4096,16384,163840",
                        help="Comma-separated sequence lengths S (default: 4096,16384,163840)")
    parser.add_argument("--kv-lens", type=parse_int_list, default=None,
                        help="Comma-separated KV lengths SKV; if omitted, uses seq-lens")
    parser.add_argument("--heads", type=int, default=4, help="Indexer heads H (default: 4)")
    parser.add_argument("--dim", type=int, default=64, help="Indexer dimension D (default: 64)")
    parser.add_argument("--iters", type=int, default=50, help="Timed iterations (default: 50)")
    parser.add_argument("--warmup", type=int, default=5, help="Warmup iterations (default: 5)")
    parser.add_argument("--mode", choices=["both", "wrapper", "impl"], default="both",
                        help="Which path to benchmark: wrapper (interface), impl (kernel), or both (default)")
    args = parser.parse_args()

    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is required for this benchmark.")
    dev = torch.cuda.get_device_name(0)
    print(f"CUDA device: {dev}")

    # argparse applies `type` to string defaults, so these are normally lists
    # already; the isinstance check is a defensive fallback.
    seq_lens = args.seq_lens if isinstance(args.seq_lens, list) else parse_int_list(args.seq_lens)
    if args.kv_lens is None:
        kv_lens = seq_lens
    else:
        kv_lens = args.kv_lens if isinstance(args.kv_lens, list) else parse_int_list(args.kv_lens)
        if len(kv_lens) != len(seq_lens):
            raise ValueError("--kv-lens must have the same number of elements as --seq-lens")

    for S, SKV in zip(seq_lens, kv_lens):
        if args.mode in ("both", "wrapper"):
            bench_tl_indexer_wrapper(S, SKV, heads=args.heads, index_dim=args.dim, iters=args.iters, warmup=args.warmup)
        if args.mode in ("both", "impl"):
            bench_tl_indexer_impl(S, SKV, heads=args.heads, index_dim=args.dim, iters=args.iters, warmup=args.warmup)

if __name__ == "__main__":
    main()
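
A typical invocation, assuming the script is run from the root of a TileLang checkout so that the examples.deepseek_v32 package resolves (the exact directory layout is an assumption, not stated in the script):

    python bench_indexer_tilelang.py --seq-lens 4096,16384 --kv-lens 4096,16384 --heads 4 --dim 64 --iters 50 --mode both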