bench_indexer_tilelang.py
#!/usr/bin/env python3
import argparse

import torch

# Prefer the repo-local FP8 helper when running from the repo root.
try:
    from examples.deepseek_v32.utils import per_custom_dims_cast_to_fp8 as _to_fp8

    def to_fp8(x):
        # Cast along the last dim to FP8 E4M3 to match kernel expectations.
        # Handle both (x, dims, use_ue8m0) and (x, dims) signatures and
        # return only the scaled tensor.
        try:
            x_scaled, _ = _to_fp8(x, dims=(-1,), use_ue8m0=False)
            return x_scaled
        except TypeError:
            out = _to_fp8(x, dims=(-1,))
            return out[0] if isinstance(out, tuple) else out
except Exception:
    def to_fp8(x):
        if not hasattr(torch, "float8_e4m3fn"):
            raise RuntimeError("torch.float8_e4m3fn not available; install a CUDA-enabled PyTorch.")
        return x.to(torch.float8_e4m3fn)
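
# Minimal sanity check for the FP8 cast (a sketch, not part of the benchmark;
# assumes a CUDA build of PyTorch with float8_e4m3fn support):
#
#   x = torch.randn(8, 64, device="cuda")
#   assert to_fp8(x).dtype == torch.float8_e4m3fn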

# TileLang example kernels for the lightning indexer.
from examples.deepseek_v32.fp8_lighting_indexer import (
    mqa_attn_return_logits,
    mqa_attn_return_logits_interface,
)


def bench_tl_indexer_wrapper(seq_len: int,
                             seq_len_kv: int,
                             heads: int = 4,
                             index_dim: int = 64,
                             iters: int = 50,
                             warmup: int = 5):
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is required for this benchmark.")
    device = torch.device("cuda")

    # Inputs
    q = torch.randn(seq_len, heads, index_dim, device=device, dtype=torch.float32)
    kv = torch.randn(seq_len_kv, index_dim, device=device, dtype=torch.float32)
    # Convert to FP8 E4M3 to match the kernel signature.
    q_fp8 = to_fp8(q)
    kv_fp8 = to_fp8(kv)
    # Precompute kv_scales as in the reference: sqrt(mean(k^2)) along dim=-1.
    kv_scales = kv.pow(2).mean(dim=-1).sqrt()
    weights = torch.randn(seq_len, heads, device=device, dtype=torch.float32)
    cu_seqlen_ks = torch.zeros(seq_len, dtype=torch.int32, device=device)
    cu_seqlen_ke = torch.full((seq_len,), seq_len_kv, dtype=torch.int32, device=device)

    # Warmup
    for _ in range(warmup):
        _ = mqa_attn_return_logits_interface(q_fp8, kv_fp8, kv_scales, weights, cu_seqlen_ks, cu_seqlen_ke)
    torch.cuda.synchronize()

    # Timed iterations, measured with CUDA events so GPU time is captured
    # without a host-side sync inside the launch sequence.
    times = []
    for _ in range(iters):
        t0 = torch.cuda.Event(enable_timing=True)
        t1 = torch.cuda.Event(enable_timing=True)
        t0.record()
        _ = mqa_attn_return_logits_interface(q_fp8, kv_fp8, kv_scales, weights, cu_seqlen_ks, cu_seqlen_ke)
        t1.record()
        t1.synchronize()
        times.append(t0.elapsed_time(t1))  # ms
    avg_ms = sum(times) / len(times) if times else float("nan")
    print(f"[TILELANG_INDEXER] WRAPPER S={seq_len} SKV={seq_len_kv} H={heads} D={index_dim} avg_ms={avg_ms:.3f} over {iters}")


def bench_tl_indexer_impl(seq_len: int,
                          seq_len_kv: int,
                          heads: int = 4,
                          index_dim: int = 64,
                          iters: int = 50,
                          warmup: int = 5):
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is required for this benchmark.")
    device = torch.device("cuda")

    # Compile the kernel once, outside the timed region.
    kernel = mqa_attn_return_logits(heads=heads, index_dim=index_dim)

    # Inputs
    q = torch.randn(seq_len, heads, index_dim, device=device, dtype=torch.float32)
    kv = torch.randn(seq_len_kv, index_dim, device=device, dtype=torch.float32)
    # Convert to FP8 E4M3 to match the kernel signature.
    q_fp8 = to_fp8(q)
    kv_fp8 = to_fp8(kv)
    kv_scales = kv.pow(2).mean(dim=-1).sqrt()
    weights = torch.randn(seq_len, heads, device=device, dtype=torch.float32)
    cu_seqlen_ks = torch.zeros(seq_len, dtype=torch.int32, device=device)
    cu_seqlen_ke = torch.full((seq_len,), seq_len_kv, dtype=torch.int32, device=device)
    # Preallocated output buffer for the kernel.
    logits = torch.empty(seq_len, seq_len_kv, device=device, dtype=torch.float32)

    # Warmup
    for _ in range(warmup):
        kernel(
            q_fp8.view(seq_len * heads, index_dim),
            kv_fp8,
            kv_scales,
            logits,
            weights,
            cu_seqlen_ks,
            cu_seqlen_ke,
        )
    torch.cuda.synchronize()

    # Timed iterations, measured with CUDA events.
    times = []
    for _ in range(iters):
        t0 = torch.cuda.Event(enable_timing=True)
        t1 = torch.cuda.Event(enable_timing=True)
        t0.record()
        kernel(
            q_fp8.view(seq_len * heads, index_dim),
            kv_fp8,
            kv_scales,
            logits,
            weights,
            cu_seqlen_ks,
            cu_seqlen_ke,
        )
        t1.record()
        t1.synchronize()
        times.append(t0.elapsed_time(t1))  # ms
    avg_ms = sum(times) / len(times) if times else float("nan")
    print(f"[TILELANG_INDEXER] IMPL S={seq_len} SKV={seq_len_kv} H={heads} D={index_dim} avg_ms={avg_ms:.3f} over {iters}")


def parse_int_list(s: str):
    """Parse a comma-separated string such as "4096, 16384" into [4096, 16384]."""
    vals = []
    for part in s.split(','):
        part = part.strip()
        if not part:
            continue
        vals.append(int(part))
    return vals


def main():
    parser = argparse.ArgumentParser(description="Benchmark TileLang lightning indexer (DeepSeek V3.2)")
    parser.add_argument("--seq-lens", type=parse_int_list, default="4096,16384,163840",
                        help="Comma-separated sequence lengths S (default: 4096,16384,163840)")
    parser.add_argument("--kv-lens", type=parse_int_list, default=None,
                        help="Comma-separated KV lengths SKV; if omitted, uses seq-lens")
    parser.add_argument("--heads", type=int, default=4, help="Indexer heads H (default: 4)")
    parser.add_argument("--dim", type=int, default=64, help="Indexer dimension D (default: 64)")
    parser.add_argument("--iters", type=int, default=50, help="Timed iterations (default: 50)")
    parser.add_argument("--warmup", type=int, default=5, help="Warmup iterations (default: 5)")
    parser.add_argument("--mode", choices=["both", "wrapper", "impl"], default="both",
                        help="Which path to benchmark: wrapper (interface), impl (kernel), or both (default)")
    args = parser.parse_args()

    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is required for this benchmark.")
    dev = torch.cuda.get_device_name(0)
    print(f"CUDA device: {dev}")

    # argparse applies `type` to string defaults, so these are normally lists
    # already; the isinstance checks are defensive.
    seq_lens = args.seq_lens if isinstance(args.seq_lens, list) else parse_int_list(args.seq_lens)
    if args.kv_lens is None:
        kv_lens = seq_lens
    else:
        kv_lens = args.kv_lens if isinstance(args.kv_lens, list) else parse_int_list(args.kv_lens)
        if len(kv_lens) != len(seq_lens):
            raise ValueError("--kv-lens must have the same number of elements as --seq-lens")

    for S, SKV in zip(seq_lens, kv_lens):
        if args.mode in ("both", "wrapper"):
            bench_tl_indexer_wrapper(S, SKV, heads=args.heads, index_dim=args.dim, iters=args.iters, warmup=args.warmup)
        if args.mode in ("both", "impl"):
            bench_tl_indexer_impl(S, SKV, heads=args.heads, index_dim=args.dim, iters=args.iters, warmup=args.warmup)


if __name__ == "__main__":
    main()
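
# Example invocation (a sketch; assumes the script is run from the TileLang repo
# root so that `examples.deepseek_v32` is importable, with a CUDA GPU present):
#
#   python bench_indexer_tilelang.py --seq-lens 4096,16384 --mode both
#
# Each (S, SKV) pair prints one line per benchmarked path, e.g.:
#
#   [TILELANG_INDEXER] WRAPPER S=4096 SKV=4096 H=4 D=64 avg_ms=<...> over 50
#   [TILELANG_INDEXER] IMPL S=4096 SKV=4096 H=4 D=64 avg_ms=<...> over 50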