@riZZZhik · Created January 28, 2026 12:48
bench_fa3_fp8_kvcache_gqa_upd.py
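# Benchmarks sgl_kernel's flash_attn_with_kvcache (FlashAttention-3) on a paged KV cache,
# comparing BF16 inputs against FP8 (float8_e4m3fn) inputs cast inside vs. outside the
# kernel, and reports achieved TFLOPS for each variant.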
import torch

# from flash_attn_interface import flash_attn_func as flash_attn_func_v3
from sgl_kernel.flash_attn import flash_attn_with_kvcache as flash_attn_func_v3

# from flash_attn.utils.benchmark import benchmark_forward
import torch.utils.benchmark as benchmark


def benchmark_forward(
    fn, *inputs, repeats=10, desc="", verbose=True, amp=False, amp_dtype=torch.float16, **kwinputs
):
    """Use PyTorch Benchmark on the forward pass of an arbitrary function."""
    if verbose:
        print(desc, "- Forward pass")

    def amp_wrapper(*inputs, **kwinputs):
        with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
            fn(*inputs, **kwinputs)

    t = benchmark.Timer(
        stmt="fn_amp(*inputs, **kwinputs)",
        globals={"fn_amp": amp_wrapper, "inputs": inputs, "kwinputs": kwinputs},
        num_threads=torch.get_num_threads(),
    )
    m = t.timeit(repeats)
    if verbose:
        print(m)
    return t, m
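# Workload: a single request of 10240 tokens, 8 query heads over 1 KV head (GQA), head_dim 128.
# NOTE: page_size and head_dim set here are overridden below, and tp_k_head_num / tp_v_head_num
# are unused; the cache shape uses tp_kv_head_num instead.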
page_size = 1
tp_k_head_num = 1
tp_v_head_num = 1
head_dim = 128
cache_seqlens = torch.tensor([10240], device="cuda", dtype=torch.int32)
cu_seqlens_q = torch.tensor([0, 10240], device="cuda", dtype=torch.int32)
cu_seqlens_k = torch.tensor([0, 10240], device="cuda", dtype=torch.int32)
max_seqlen_q = 10240
layer_scaling = 0.08838834764831845
window_size = (-1, -1)
layer_logit_cap = 0.0
num_splits = 0
use_cascade_attn = False
k_descale = None
v_descale = None
return_softmax_lse = use_cascade_attn
causal = False if use_cascade_attn else True
# Paged cache layout: (num_pages, page_size, tp_kv_head_num, head_dim)
page_size = 64
tp_kv_head_num = 1
head_dim = 128
num_pages = 10240 // page_size
cache_shape = (num_pages, page_size, tp_kv_head_num, head_dim)
page_table = torch.arange(num_pages, device="cuda", dtype=torch.int32).view(1, -1)
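# page_table maps each request's logical page index to a physical page in the cache; with a
# single request spanning the whole cache, this is just the identity mapping 0..num_pages-1.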
batch = 1
head = 8
headdim = 128
print(f"batch: {batch}, head: {head}, headdim: {headdim}, page_size: {page_size}")
print(f"FlashAttention3-BF16 Benchmark")
is_causal = causal
test_seq_lens = {10240}
print(f"is_causal: {is_causal}")
for seq_len in test_seq_lens:
    flops = 4 * head * batch * headdim * seq_len * seq_len // (2 if is_causal else 1)
    q = torch.randn(batch * seq_len, head, headdim, dtype=torch.bfloat16, device="cuda")
    k = torch.randn(cache_shape, dtype=torch.bfloat16, device="cuda")
    v = torch.randn(cache_shape, dtype=torch.bfloat16, device="cuda")
    for _ in range(5):  # warmup
        flash_attn_func_v3(q, k, v, softmax_scale=layer_scaling, causal=is_causal, page_table=page_table,
                           cache_seqlens=cache_seqlens, cu_seqlens_q=cu_seqlens_q, cu_seqlens_k_new=cu_seqlens_k,
                           max_seqlen_q=max_seqlen_q, window_size=window_size, softcap=layer_logit_cap,
                           k_descale=k_descale, v_descale=v_descale, return_softmax_lse=use_cascade_attn,
                           num_splits=num_splits)
    torch.cuda.synchronize()
    _, time = benchmark_forward(flash_attn_func_v3, q, k, v, softmax_scale=layer_scaling, causal=is_causal,
                                page_table=page_table, cache_seqlens=cache_seqlens, cu_seqlens_q=cu_seqlens_q,
                                cu_seqlens_k_new=cu_seqlens_k, max_seqlen_q=max_seqlen_q, window_size=window_size,
                                softcap=layer_logit_cap, k_descale=k_descale, v_descale=v_descale,
                                return_softmax_lse=use_cascade_attn, num_splits=num_splits,
                                repeats=100, verbose=False, desc="FA3-BF16")
    print(f"{seq_len} TFLOPS: {flops / time.mean * 1e-12}")
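# "Cast within flash_attn": q/k/v are still allocated in bf16 here; per the author's label, any
# conversion to FP8 is expected to happen inside the kernel, in contrast to the next variant,
# which casts the tensors to float8_e4m3fn explicitly before the call.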
print(f"FlashAttention3-FP8 Benchmark, Cast is done within flash_attn")
for seq_len in test_seq_lens:
    flops = 4 * head * batch * headdim * seq_len * seq_len // (2 if is_causal else 1)
    q = torch.randn(batch * seq_len, head, headdim, dtype=torch.bfloat16, device="cuda")
    k = torch.randn(cache_shape, dtype=torch.bfloat16, device="cuda")
    v = torch.randn(cache_shape, dtype=torch.bfloat16, device="cuda")
    for _ in range(5):  # warmup
        flash_attn_func_v3(q, k, v, softmax_scale=layer_scaling, causal=is_causal, page_table=page_table,
                           cache_seqlens=cache_seqlens, cu_seqlens_q=cu_seqlens_q, cu_seqlens_k_new=cu_seqlens_k,
                           max_seqlen_q=max_seqlen_q, window_size=window_size, softcap=layer_logit_cap,
                           k_descale=k_descale, v_descale=v_descale, return_softmax_lse=use_cascade_attn,
                           num_splits=num_splits)
    torch.cuda.synchronize()
    _, time = benchmark_forward(flash_attn_func_v3, q, k, v, softmax_scale=layer_scaling, causal=is_causal,
                                page_table=page_table, cache_seqlens=cache_seqlens, cu_seqlens_q=cu_seqlens_q,
                                cu_seqlens_k_new=cu_seqlens_k, max_seqlen_q=max_seqlen_q, window_size=window_size,
                                softcap=layer_logit_cap, k_descale=k_descale, v_descale=v_descale,
                                return_softmax_lse=use_cascade_attn, num_splits=num_splits,
                                repeats=100, verbose=False, desc="FA3-FP8 (cast inside)")
    print(f"{seq_len} TFLOPS: {flops / time.mean * 1e-12}")
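# "Cast outside flash_attn": tensors are converted to float8_e4m3fn up front; k_descale and
# v_descale stay None, so no per-tensor descaling is applied (the values are presumably taken
# as already being in FP8 range).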
print(f"FlashAttention3-FP8 Benchmark, Cast is done outside of flash_attn")
for seq_len in test_seq_lens:
    flops = 4 * head * batch * headdim * seq_len * seq_len // (2 if is_causal else 1)
    q = torch.randn(batch * seq_len, head, headdim, dtype=torch.bfloat16, device="cuda")
    k = torch.randn(cache_shape, dtype=torch.bfloat16, device="cuda")
    v = torch.randn(cache_shape, dtype=torch.bfloat16, device="cuda")
    # Explicit FP8 cast before calling the kernel
    q = q.to(torch.float8_e4m3fn)
    k = k.to(torch.float8_e4m3fn)
    v = v.to(torch.float8_e4m3fn)
    for _ in range(5):  # warmup
        flash_attn_func_v3(q, k, v, softmax_scale=layer_scaling, causal=is_causal, page_table=page_table,
                           cache_seqlens=cache_seqlens, cu_seqlens_q=cu_seqlens_q, cu_seqlens_k_new=cu_seqlens_k,
                           max_seqlen_q=max_seqlen_q, window_size=window_size, softcap=layer_logit_cap,
                           k_descale=k_descale, v_descale=v_descale, return_softmax_lse=use_cascade_attn,
                           num_splits=num_splits)
    torch.cuda.synchronize()
    _, time = benchmark_forward(flash_attn_func_v3, q, k, v, softmax_scale=layer_scaling, causal=is_causal,
                                page_table=page_table, cache_seqlens=cache_seqlens, cu_seqlens_q=cu_seqlens_q,
                                cu_seqlens_k_new=cu_seqlens_k, max_seqlen_q=max_seqlen_q, window_size=window_size,
                                softcap=layer_logit_cap, k_descale=k_descale, v_descale=v_descale,
                                return_softmax_lse=use_cascade_attn, num_splits=num_splits,
                                repeats=100, verbose=False, desc="FA3-FP8 (cast outside)")
    print(f"{seq_len} TFLOPS: {flops / time.mean * 1e-12}")
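# To run: `python bench_fa3_fp8_kvcache_gqa_upd.py` with sgl_kernel installed; the FA3 kernels
# used here target Hopper-class GPUs (assumption: an SM90 device such as H100 is available).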