bench_fa3_fp8_kvcache_gqa_upd.py
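Benchmarks the forward pass of sgl_kernel's flash_attn_with_kvcache (FlashAttention-3) on a paged KV cache with grouped-query attention (8 query heads sharing 1 KV head), comparing BF16 inputs against FP8 (float8_e4m3fn) inputs.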
import torch

# from flash_attn_interface import flash_attn_func as flash_attn_func_v3
from sgl_kernel.flash_attn import flash_attn_with_kvcache as flash_attn_func_v3

# from flash_attn.utils.benchmark import benchmark_forward
import torch.utils.benchmark as benchmark
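

# Local copy of the benchmark_forward helper that the commented-out flash_attn
# import above would provide: it times only the forward pass with torch.utils.benchmark.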
def benchmark_forward(
    fn, *inputs, repeats=10, desc="", verbose=True, amp=False, amp_dtype=torch.float16, **kwinputs
):
    """Use PyTorch Benchmark on the forward pass of an arbitrary function."""
    if verbose:
        print(desc, "- Forward pass")

    def amp_wrapper(*inputs, **kwinputs):
        with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
            fn(*inputs, **kwinputs)

    t = benchmark.Timer(
        stmt="fn_amp(*inputs, **kwinputs)",
        globals={"fn_amp": amp_wrapper, "inputs": inputs, "kwinputs": kwinputs},
        num_threads=torch.get_num_threads(),
    )
    m = t.timeit(repeats)
    if verbose:
        print(m)
    return t, m
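

# Single-request setup: one sequence with 10240 tokens already in the KV cache.
# layer_scaling = 1 / sqrt(head_dim) = 1 / sqrt(128); window_size = (-1, -1) disables
# sliding-window attention; layer_logit_cap = 0.0 (passed as softcap) disables logit capping.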
page_size = 1
tp_k_head_num = 1
tp_v_head_num = 1
head_dim = 128
cache_seqlens = torch.tensor([10240], device="cuda", dtype=torch.int32)
cu_seqlens_q = torch.tensor([0, 10240], device="cuda", dtype=torch.int32)
cu_seqlens_k = torch.tensor([0, 10240], device="cuda", dtype=torch.int32)
max_seqlen_q = 10240
layer_scaling = 0.08838834764831845
window_size = (-1, -1)
layer_logit_cap = 0.0
num_splits = 0
use_cascade_attn = False
k_descale = None
v_descale = None
return_softmax_lse = use_cascade_attn
causal = False if use_cascade_attn else True
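
# Paged KV-cache layout: the 10240 cached tokens are split into fixed-size pages and
# page_table maps each logical page of the (single) request to a physical page.
# GQA: the 8 query heads (head) share the single KV head (tp_kv_head_num).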
# cache_shape layout: (num_pages, page_size, tp_kv_head_num, head_dim)
page_size = 64
tp_kv_head_num = 1
head_dim = 128
num_pages = 10240 // page_size
cache_shape = (num_pages, page_size, tp_kv_head_num, head_dim)
page_table = torch.arange(num_pages, device="cuda", dtype=torch.int32).view(1, -1)
batch = 1
head = 8
headdim = 128
print(f"batch: {batch}, head: {head}, headdim: {headdim}, page_size: {page_size}")
| print(f"FlashAttention3-BF16 Benchmark") | |
| is_causal = causal | |
| test_seq_lens = {10240} | |
| print(f"is_causal: {is_causal}") | |
| for seq_len in test_seq_lens: | |
| flops = 4 * head * batch * headdim * seq_len * seq_len // (2 if is_causal else 1) | |
| q = torch.randn(batch * seq_len, head, headdim, dtype=torch.bfloat16, device="cuda") | |
| k = torch.randn(cache_shape, dtype=torch.bfloat16, device="cuda") | |
| v = torch.randn(cache_shape, dtype=torch.bfloat16, device="cuda") | |
| for i in range(5): flash_attn_func_v3(q, k, v, | |
| softmax_scale=layer_scaling, causal=is_causal, | |
| page_table=page_table, cache_seqlens=cache_seqlens, | |
| cu_seqlens_q=cu_seqlens_q, cu_seqlens_k_new=cu_seqlens_k, | |
| max_seqlen_q=max_seqlen_q, window_size=window_size, | |
| softcap=layer_logit_cap, k_descale=k_descale, | |
| v_descale=v_descale, return_softmax_lse=use_cascade_attn, | |
| num_splits=num_splits) | |
| torch.cuda.synchronize() | |
| _, time = benchmark_forward(flash_attn_func_v3, q, k, v, | |
| softmax_scale=layer_scaling, causal=is_causal, page_table=page_table, cache_seqlens=cache_seqlens, cu_seqlens_q=cu_seqlens_q, cu_seqlens_k_new=cu_seqlens_k, max_seqlen_q=max_seqlen_q, window_size=window_size, softcap=layer_logit_cap, k_descale=k_descale, v_descale=v_descale, return_softmax_lse=use_cascade_attn, num_splits=num_splits, repeats=100, verbose=False, desc='Triton') | |
| print(f'{seq_len} flops:{flops/time.mean*1e-12}') | |
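
# FP8 variant 1: q/k/v stay in BF16 here; per the section title, any cast to FP8 is
# expected to happen inside flash_attn_with_kvcache itself.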
| print(f"FlashAttention3-FP8 Benchmark, Cast is done within flash_attn") | |
| for seq_len in test_seq_lens: | |
| flops = 4 * head * batch * headdim * seq_len * seq_len // (2 if is_causal else 1) | |
| q = torch.randn(batch * seq_len, head, headdim, dtype=torch.bfloat16, device="cuda") | |
| k = torch.randn(cache_shape, dtype=torch.bfloat16, device="cuda") | |
| v = torch.randn(cache_shape, dtype=torch.bfloat16, device="cuda") | |
| for i in range(5): flash_attn_func_v3(q, k, v, softmax_scale=layer_scaling, causal=is_causal, page_table=page_table, cache_seqlens=cache_seqlens, cu_seqlens_q=cu_seqlens_q, cu_seqlens_k_new=cu_seqlens_k, max_seqlen_q=max_seqlen_q, window_size=window_size, softcap=layer_logit_cap, k_descale=k_descale, v_descale=v_descale, return_softmax_lse=use_cascade_attn, num_splits=num_splits) | |
| torch.cuda.synchronize() | |
| _, time = benchmark_forward(flash_attn_func_v3, q, k, v, softmax_scale=layer_scaling, causal=is_causal, page_table=page_table, cache_seqlens=cache_seqlens, cu_seqlens_q=cu_seqlens_q, cu_seqlens_k_new=cu_seqlens_k, max_seqlen_q=max_seqlen_q, window_size=window_size, softcap=layer_logit_cap, k_descale=k_descale, v_descale=v_descale, return_softmax_lse=use_cascade_attn, num_splits=num_splits, repeats=100, verbose=False, desc='Triton') | |
| print(f'{seq_len} flops:{flops/time.mean*1e-12}') | |
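
# FP8 variant 2: q/k/v are cast to float8_e4m3fn before the call. k_descale and
# v_descale stay None, so no descaling factors are applied to the FP8 inputs.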
| print(f"FlashAttention3-FP8 Benchmark, Cast is done outside of flash_attn") | |
| for seq_len in test_seq_lens: | |
| flops = 4 * head * batch * headdim * seq_len * seq_len // (2 if is_causal else 1) | |
| q = torch.randn(batch * seq_len, head, headdim, dtype=torch.bfloat16, device="cuda") | |
| k = torch.randn(cache_shape, dtype=torch.bfloat16, device="cuda") | |
| v = torch.randn(cache_shape, dtype=torch.bfloat16, device="cuda") | |
| q = q.to(torch.float8_e4m3fn) | |
| k = k.to(torch.float8_e4m3fn) | |
| v = v.to(torch.float8_e4m3fn) | |
| for i in range(5): flash_attn_func_v3(q, k, v, softmax_scale=layer_scaling, causal=is_causal, page_table=page_table, cache_seqlens=cache_seqlens, cu_seqlens_q=cu_seqlens_q, cu_seqlens_k_new=cu_seqlens_k, max_seqlen_q=max_seqlen_q, window_size=window_size, softcap=layer_logit_cap, k_descale=k_descale, v_descale=v_descale, return_softmax_lse=use_cascade_attn, num_splits=num_splits) | |
| torch.cuda.synchronize() | |
| _, time = benchmark_forward(flash_attn_func_v3, q, k, v, softmax_scale=layer_scaling, causal=is_causal, page_table=page_table, cache_seqlens=cache_seqlens, cu_seqlens_q=cu_seqlens_q, cu_seqlens_k_new=cu_seqlens_k, max_seqlen_q=max_seqlen_q, window_size=window_size, softcap=layer_logit_cap, k_descale=k_descale, v_descale=v_descale, return_softmax_lse=use_cascade_attn, num_splits=num_splits, repeats=100, verbose=False, desc='Triton') | |
| print(f'{seq_len} flops:{flops/time.mean*1e-12}') |
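
A note on requirements (assumed, not stated in the gist): the script needs a CUDA GPU plus the sgl_kernel package, and the FP8 paths presumably require hardware with native FP8 support (e.g. Hopper-class GPUs).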