@SeunghyunSEO
Last active October 17, 2025 17:11
flops_utils.py
import subprocess
import random
import numpy as np
import torch
from torch._utils import _get_available_device_type, _get_device_module
from loguru import logger
# DEBUG = False
DEBUG = True
# ------------------------------------------------------------------------------------------------
# reference flops per token for language model
# https://github.com/pytorch/torchtitan/blob/327a99cc2371964a1160ff4aec15e37806993139/torchtitan/models/llama3/model/args.py#L63
# https://github.com/pytorch/torchtitan/blob/e65ef30dec74e8e592191a1487e92256d27ecb2b/torchtitan/components/metrics.py#L366-L378
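# A quick sketch of where the estimate below comes from (assuming dense matmuls dominate):
# each parameter in the residual blocks costs ~2 FLOPs per token in the forward pass (one
# multiply-add) and ~4 FLOPs in the backward pass (dgrad + wgrad GEMMs), giving the 6*N term;
# the 12*L*H*Q*S term adds the attention-score matmuls (QK^T and attn@V), which scale with
# sequence length rather than with parameter count.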
def num_floating_point_operations_simple(
model: torch.nn.Module,
l: int, # num layers
h: int, # num head
q: int, # head dim
batch_size: int,
seq_len: int,
):
assert model is not None
nparams = sum(p.numel() for p in model.parameters())
if hasattr(model, 'model') and hasattr(model.model, 'blocks'):
residual_blocks_nparams = sum(p.numel() for p in model.model.blocks.parameters())
else:
raise NotImplementedError
# nparams_embedding = sum(
# sum(p.numel() for p in m.parameters())
# for m in model.children()
# if isinstance(m, torch.nn.Embedding)
# )
# residual_blocks_nparams = (nparams - nparams_embedding)
kaplan = 6 * residual_blocks_nparams # non-embedding count is more common in llm (sparse case)
sdpa = 12 * l * h * q * seq_len
total_flops_per_token = kaplan + sdpa # rough kaplan estimation + sdpa
estimated_flops = batch_size * seq_len * total_flops_per_token
if DEBUG:
logger.info(f'''
batch_size: {batch_size}
seq_len: {seq_len}
model size (N): {nparams/1e9:.4f}B (in billion)
params in residual blocks: {residual_blocks_nparams/1e9:.4f}B (in billion)
6*N kaplan estimation: {kaplan/1e12:.4f}TFLOPs
12*L*H*Q*S sdpa estimation: {sdpa/1e12:.4f}TFLOPs
flops per token (6*N + 12*L*H*Q*S): {total_flops_per_token/1e12:.6f}TFLOPs
flops per batch: {estimated_flops/1e12:.2f}TFLOPs
''')
return estimated_flops
# ------------------------------------------------------------------------------------------------
# using generalized flop counter
def set_seed(seed=42):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
def get_inputs(vocab_size, batch_size, seq_len, device, return_tuple=False):
set_seed()
x = torch.randint(0, vocab_size, (batch_size, seq_len), device=device)
y = x[:, 1:]
if return_tuple:
return (x, y)
else:
return {
'x': x,
'y': y,
}
# https://github.com/pytorch/pytorch/pull/95751
def num_floating_point_operations_torch_native(
model,
vocab_size,
batch_size,
seq_len,
):
from torch.utils.flop_counter import FlopCounterMode
flop_counter = FlopCounterMode(display=True)
device = next(model.parameters()).device
inputs = get_inputs(vocab_size, batch_size, seq_len, device)
with flop_counter:
out = model(**inputs)
out[0].backward() # backward should be included too
# estimated_flops = sum(flop_counter.get_flop_counts()['Global'].values()) # https://x.com/cHHillee/status/1649209467811299329
estimated_flops = flop_counter.get_total_flops()
if DEBUG:
logger.info(f'''
batch_size: {batch_size}
seq_len: {seq_len}
flops per batch: {estimated_flops/1e12:.2f} TFLOPs
''')
return estimated_flops
# https://github.com/facebookresearch/fvcore/blob/main/docs/flop_count.md
def num_floating_point_operations_fvcore(
model,
vocab_size,
batch_size,
seq_len,
):
from fvcore.nn import FlopCountAnalysis
device = next(model.parameters()).device
inputs = get_inputs(vocab_size, batch_size, seq_len, device, return_tuple=True)
flops = FlopCountAnalysis(model, inputs)
estimated_flops = flops.total()
if DEBUG:
logger.info(f'''
flops per batch (flops.total()): {estimated_flops/1e12:.4f}TFLOPs
flops.by_module_and_operator(): {flops.by_module_and_operator()}
''')
return estimated_flops
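# NOTE (likely explanation for the much smaller numbers fvcore reports here): FlopCountAnalysis
# only traces the forward pass, counts one fused multiply-add as a single FLOP, and skips ops it
# has no handler for (scaled_dot_product_attention is presumably among them), so its totals are
# not directly comparable to the forward+backward estimates above and below.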
# ------------------------------------------------------------------------------------------------
# https://github.com/NVIDIA/Megatron-LM/blob/core_r0.11.0/megatron/training/training.py#L118
# classic megatron style handwritten flops calculation
def num_floating_point_operations_llm_exact(
# for basic transformer blocks
vocab_size: int,
num_layers: int,
hidden_dim: int,
mlp_ratio: int,
batch_size: int,
seq_len: int,
):
# ----------------------------------------------------------------
'''
Basic arithmetic for FLOPs calculation.
Let B be the batch size, s the seq_len, and h the embedding dim.
For one self_attention block (prenorm is not included):
qkv projection: 6Bsh^2
attn scores (QK^T): 2Bs^2h
attn over value: 2Bs^2h
out proj: 2Bsh^2
original references
https://arxiv.org/abs/2305.10403
https://arxiv.org/abs/2205.05198
another reference
https://github.com/pytorch/torchtitan/blob/main/torchtitan/utils.py#L282-L297
Non-matmul operations (normalization layers, bias addition, modulation scale, gating, ...) are ignored.
The 12x term comes from the following factors (see the references above):
- 3x: each GEMM is performed 3 times (forward pass, backward wgrad, backward dgrad)
- 2x: a GEMM of an m*n tensor with an n*k tensor requires 2mnk floating-point operations
- 2x: GEMMs are stacked twice in standard Transformer architectures (e.g., the h->ffn_h GEMM and ffn_h->h GEMM in the MLP layer)
'''
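# Worked example with the 7.5B config used in __main__ below (h=4096, depth=32, mlp_ratio=4,
# vocab_size=131072, seq_len=4096), per token:
#   attn qkv/out projections: 24*h^2    ~= 0.40 GFLOPs per layer
#   causal SDPA:              6*h*s     ~= 0.10 GFLOPs per layer
#   MLP:                      48*h^2    ~= 0.81 GFLOPs per layer
#   logits:                   6*h*vocab ~= 3.22 GFLOPs (once)
# -> 32 * 1.31 + 3.22 ~= 45 GFLOPs per token, i.e. ~369 TFLOPs per batch for batch_size=2, seq_len=4096.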
# expansion factor
fwd_bwd_fma_factor = 3 * 2
expansion_factor = fwd_bwd_fma_factor * 2
# ----------------------------------------------------------------
# residual blocks
# 1. self attention projection and SDPA operation per layer
self_attn_term = expansion_factor * hidden_dim * hidden_dim * (
2 # qkv and out proj
+ seq_len / hidden_dim / 2 # SDPA (causal)
)
# 2. mlp (not gated mlp) per layer
mlp_term = expansion_factor * hidden_dim * (hidden_dim * mlp_ratio)
transformer_flops = num_layers * (self_attn_term + mlp_term)
# ----------------------------------------------------------------
# logit
logit_flops = fwd_bwd_fma_factor * hidden_dim * vocab_size
# ----------------------------------------------------------------
total_flops_per_token = (
transformer_flops
+ logit_flops
)
estimated_flops = batch_size * seq_len * total_flops_per_token
if DEBUG:
logger.info(f'''
batch_size: {batch_size}
seq_len (input): {seq_len}
total_flops_per_token: {total_flops_per_token/1e12:.4f}TFLOPs
self_attn_term: {self_attn_term/1e12:.4f}TFLOPs
mlp_term: {mlp_term/1e12:.4f}TFLOPs
logit_flops: {logit_flops/1e12:.4f}TFLOPs
total estimated flops (batch_size * seq_len * total_flops_per_token): {estimated_flops/1e12:.2f}TFLOPs
''')
return estimated_flops
# ------------------------------------------------------------------------------------------------
def get_llm_flops_per_batch(
model,
model_config,
flops_calc_type,
batch_size: int,
seq_len: int,
):
vocab_size = model_config.vocab_size
num_layers = model_config.depth
hidden_dim = model_config.hidden_dim
num_heads = model_config.num_heads
dim_head = hidden_dim // num_heads
mlp_ratio = model_config.mlp_ratio
# TODO: placeholder
if flops_calc_type == 'simple':
return num_floating_point_operations_simple(
model,
num_layers,
num_heads,
dim_head,
batch_size,
seq_len,
)
elif flops_calc_type == 'torch_native':
return num_floating_point_operations_torch_native(
model,
vocab_size,
batch_size,
seq_len,
)
elif flops_calc_type == 'fvcore':
return num_floating_point_operations_fvcore(
model,
vocab_size,
batch_size,
seq_len,
)
elif flops_calc_type == 'exact':
return num_floating_point_operations_llm_exact(
vocab_size,
num_layers,
hidden_dim,
mlp_ratio,
batch_size,
seq_len,
)
else:
raise ValueError(f"flops_calc_type: {flops_calc_type} is not supported")
# ------------------------------------------------------------------------------------------------
class MFUEstimator:
def __init__(
self,
model_type: str,
model_config,
batch_size: int,
seq_len: int = None,
model: torch.nn.Module = None,
flops_calc_type: str = 'simple',
):
# args
self.model_type = model_type
self.model_config = model_config
self.batch_size = batch_size
self.seq_len = seq_len # transformer seq_len
self.model = model
self.flops_calc_type = flops_calc_type
self.num_floating_point_operations = self.get_num_floating_point_operations()
self.peak_flops = self.get_peak_flops()
if DEBUG:
logger.info(f'''
bf16 peak flops of this GPU
GPU type: {_get_device_module("cuda").get_device_name(0)}
peak flops: {self.peak_flops/1e12:.2f}TFLOPs
''')
def estimate_mfu(self, time_delta: float) -> float:
estimated_mfu = (
100
* self.num_floating_point_operations
/ self.peak_flops
/ time_delta
)
if DEBUG:
logger.info(f'''
time_delta: {time_delta:.2f}sec
estimated_mfu: {estimated_mfu:.2f}%
''')
return estimated_mfu
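# MFU = achieved FLOPs/s divided by the device's peak FLOPs/s, expressed as a percentage.
# e.g. (hypothetical numbers) 369 TFLOPs of work in a 0.5 s step on a 989 TFLOPS H100 SXM:
#   100 * 369e12 / 989e12 / 0.5 ~= 74.6% MFU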
def get_num_floating_point_operations(self):
if self.model_type == 'gpt2':
return get_llm_flops_per_batch(
self.model,
self.model_config,
self.flops_calc_type,
self.batch_size,
self.seq_len,
)
elif self.model_type == 'mmdit':
raise NotImplementedError
else:
# fallback
logger.warning(f"self.model_type: {self.model_type} is not supported for mfu estimation, mfu will logged as 0 (fallback)")
return 0
def get_peak_flops(self) -> int:
# hardcoded BF16 type peak flops for NVIDIA A100, H100, H200, B200 GPU and AMD MI250, MI300X, AMD MI325X and Intel PVC
# https://github.com/pytorch/torchtitan/blob/main/torchtitan/tools/utils.py#L67
device_type = _get_available_device_type() or "cuda"
device_module = _get_device_module(device_type)
device_name = device_module.get_device_name(0)
try:
# Run the lspci command and capture the output
result = subprocess.run(["lspci"], stdout=subprocess.PIPE, text=True)
# Filter the output for lines containing both "NVIDIA" and "H100"
filtered_lines = [
line
for line in result.stdout.splitlines()
if "NVIDIA" in line and "H100" in line
]
# Join all filtered lines into a single string
device_name = " ".join(filtered_lines) or device_name
except FileNotFoundError as e:
logger.warning(f"Error running lspci: {e}, fallback to use device_name")
if "A100" in device_name:
# data from https://www.nvidia.com/en-us/data-center/a100/
return 312e12 # 312 TFLOPS BF16 peak; e.g. sustaining 100 TFLOPS -> 100/312 = 0.32 MFU, below the typical 0.50-0.60
elif "H100" in device_name:
# data from https://www.nvidia.com/en-us/data-center/h100/
# NOTE: Specifications are one-half lower without sparsity.
if "NVL" in device_name:
return 835e12
elif "PCIe" in device_name:
return 756e12
else: # for H100 SXM and other variants
return 989e12
elif "H200" in device_name:
# data from https://www.nvidia.com/en-us/data-center/h200/
return 989e12
elif "B200" in device_name:
# data from https://nvdam.widen.net/s/wwnsxrhm2w/blackwell-datasheet-3384703
return 2.25e15
elif "MI300X" in device_name or "MI325X" in device_name:
# MI300X data from https://www.amd.com/en/products/accelerators/instinct/mi300/mi300x.html
# MI325X data from https://www.amd.com/en/products/accelerators/instinct/mi300/mi325x.html
return 1300e12
elif "MI250X" in device_name:
# data from https://www.amd.com/en/products/accelerators/instinct/mi200/mi250x.html (per GCD)
return 191.5e12
elif "Data Center GPU Max 1550" in device_name:
# Also known as Ponte Vecchio (PVC).
# data from https://www.intel.com/content/www/us/en/docs/oneapi/optimization-guide-gpu/2025-0/intel-xe-gpu-architecture.html
# Dot Product Accumulate Systolic (DPAS):
# - Freq: 1300MHz
# - #ops: 512
# Full EU mode (i.e. 512 max compute units): 340.8 TFLOPS (BF16)
# Standard EU mode (i.e. 448 max compute units): 298.2 TFLOPS (BF16)
max_comp_units = torch.xpu.get_device_properties("xpu").max_compute_units
return 512 * max_comp_units * 1300 * 10**6
elif "l40s" in device_name:
# data from: "https://resources.nvidia.com/en-us-l40s/l40s-datasheet-28413"
return 362e12
else: # for other GPU types, assume A100
logger.warning(f"Peak flops undefined for: {device_name}, fallback to A100")
return 312e12
# ----------------------------------------------------------------
# simple language model for test
import numpy as np
from einops import rearrange
import torch
from torch import nn, Tensor
import torch.nn.functional as F
from torch.profiler import record_function
from dataclasses import dataclass
from transformers import GPT2Tokenizer
def get_tokenizer():
return GPT2Tokenizer.from_pretrained("gpt2")
@dataclass
class LanguageModelConfig:
# base
vocab_size: int = 50257
## typical 7~8B setup
# hidden_dim: int = 4096
# num_heads: int = 32
# depth: int = 32
## small scale proxy
hidden_dim: int = 512
num_heads: int = 4
depth: int = 12
mlp_ratio: int = 4 # because it's not GLU, only integer is supported
block_size: int = 4096
init_std: float = 0.02
# muP
mup: bool = False
mup_input_mult: float = 1.0
mup_output_mult: float = 1.0
class RMSNorm(nn.Module):
# copied from https://github.com/facebookresearch/lingua/blob/main/lingua/transformer.py for more simplicity
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))
def _norm(self, x: torch.Tensor):
return x * torch.rsqrt((x * x).mean(-1, keepdim=True) + self.eps)
def reset_parameters(self):
nn.init.ones_(self.weight)
@record_function("rms_norm")
def forward(self, x: torch.Tensor):
output = self._norm(x.float())
return (output * self.weight.float()).type_as(x)
class SelfAttention(nn.Module):
def __init__(self, dim: int, num_heads, mup):
super().__init__()
self.dim = dim
self.num_heads = num_heads
assert dim % num_heads == 0, "hidden size should be divisible by nhead"
self.head_dim = dim // num_heads
self.qkv = nn.Linear(dim, dim * 3, bias=False)
self.proj = nn.Linear(dim, dim, bias=False)
self.mup = mup
def init_weights(self, init_std, mup):
in_proj_scaler = self.dim ** -0.5 if mup else 1.0
nn.init.normal_(self.qkv.weight, std=init_std * in_proj_scaler)
out_proj_scaler = 0.0 if mup else 1.0 # zero out residual out projection
nn.init.normal_(self.proj.weight, std=init_std * out_proj_scaler)
@record_function("sdpa_kernel")
def attention(self, q: Tensor, k: Tensor, v: Tensor, scale: float=None, is_causal: bool=False) -> Tensor:
if scale is None:
scale = q.shape[-1] ** -0.5
with torch.backends.cuda.sdp_kernel(
enable_flash=False, # for test convergence issue
enable_math=False,
enable_mem_efficient=True,
):
x = torch.nn.functional.scaled_dot_product_attention(q, k, v, scale=scale, is_causal=is_causal) # scale matters for muP
x = rearrange(x, "B H L D -> B L (H D)")
return x
@record_function("self_attn")
def forward(self, x: Tensor) -> Tensor:
with record_function("qkv"):
qkv = self.qkv(x)
q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
scale = self.head_dim ** -1.0 if self.mup else self.head_dim ** -0.5
x = self.attention(q, k, v, scale=scale)
with record_function("out_proj"):
x = self.proj(x)
return x
class MLP(nn.Module):
def __init__(self, dim: int, mlp_hidden_dim: int):
super().__init__()
self.dim = dim
self.mlp_hidden_dim = mlp_hidden_dim
self.ffn1 = nn.Linear(dim, mlp_hidden_dim, bias=False)
self.ffn2 = nn.Linear(mlp_hidden_dim, dim, bias=False)
def init_weights(self, init_std, mup):
in_proj_scaler = self.dim ** -0.5 if mup else 1.0
nn.init.normal_(self.ffn1.weight, std=init_std * in_proj_scaler)
out_proj_scaler = 0.0 if mup else 1.0 # zero out residual out projection
nn.init.normal_(self.ffn2.weight, std=init_std * out_proj_scaler)
@record_function("mlp")
def forward(self, x: Tensor) -> Tensor:
return self.ffn2(nn.functional.gelu(self.ffn1(x), approximate="tanh"))
class Block(nn.Module):
def __init__(self, hidden_dim, nhead, mlp_ratio, mup):
super(Block, self).__init__()
self.norm1 = RMSNorm(hidden_dim)
self.attn = SelfAttention(hidden_dim, nhead, mup)
self.norm2 = RMSNorm(hidden_dim)
self.mlp = MLP(hidden_dim, hidden_dim * mlp_ratio)
def init_weights(self, init_std, mup):
self.norm1.reset_parameters()
self.attn.init_weights(init_std, mup)
self.norm2.reset_parameters()
self.mlp.init_weights(init_std, mup)
def forward(self, x):
x = x + self.attn(self.norm1(x))
return x + self.mlp(self.norm2(x))
class LanguageModel(nn.Module):
def __init__(self, config: LanguageModelConfig):
super(LanguageModel, self).__init__()
self.config = config
self.init_std = config.init_std
self.mup = config.mup
self.mup_input_mult = config.mup_input_mult
self.mup_output_mult = config.mup_output_mult
self.model = nn.ModuleDict(
dict(
wte = nn.Embedding(config.vocab_size, config.hidden_dim),
wpe = nn.Embedding(config.block_size, config.hidden_dim),
blocks = nn.ModuleList([Block(config.hidden_dim, config.num_heads, config.mlp_ratio, self.mup) for _ in range(config.depth)]),
norm = RMSNorm(config.hidden_dim),
)
)
self.lm_head = nn.Linear(config.hidden_dim, config.vocab_size, bias=False)
tmp_weight = torch.zeros_like(self.model.wpe.weight)
self.register_buffer('tmp_weight', tmp_weight, persistent=False) ## for test
def init_weights(self):
nn.init.normal_(self.model.wte.weight, std=self.init_std)
nn.init.normal_(self.model.wpe.weight, std=self.init_std)
for block in self.model.blocks:
block.init_weights(self.init_std, self.mup)
self.model.norm.reset_parameters()
# zero out last projection (readout) for muP. still learnable because softmax for zero tensor is uniform
out_proj_scaler = 0.0 if self.mup else 1.0
nn.init.normal_(self.lm_head.weight, std=self.init_std * out_proj_scaler)
nn.init.normal_(self.tmp_weight, std=self.init_std) ## for test
def compute_loss(self, z, y, ignore_index=-100, reduction='mean'):
return F.cross_entropy(z, y, ignore_index=ignore_index, reduction=reduction)
def forward(self, x, y=None):
B, T = x.size()
pos = torch.arange(0, T, dtype=torch.long, device=x.device)
with record_function("token_embedding"):
x = self.model.wte(x) + self.model.wpe(pos)
x = x * self.mup_input_mult if self.mup else x
for block in self.model.blocks:
with record_function("residual_block"):
x = block(x)
with record_function("final_norm"):
x = self.model.norm(x)
with record_function("lm_head"):
x = x * self.mup_output_mult if self.mup else x
z = self.lm_head(x).float() # projection to logit space and upcast
z = z[..., :-1, :].contiguous().reshape(B*(T-1), -1) # B*(T-1), C
y = y.reshape(-1) # B*(T-1)
assert z.size(0) == y.size(0), f"z: {z.size()}, y: {y.size()}"
with record_function("compute_loss"):
loss = self.compute_loss(z, y)
return loss, z
# ----------------------------------------------------------------
# uv run python flops_utils.py
model_types_configs = {
"gpt2": LanguageModelConfig,
"mmdit": None,
}
model_types_classes = {
"gpt2": LanguageModel,
"mmdit": None,
}
if __name__ == "__main__":
for model_type in ['gpt2']:
model_config = model_types_configs[model_type]
if model_type == 'gpt2':
# 7.5B
batch_size = 2
seq_len = 4096
vocab_size = 131072
hidden_dim = 4096
num_heads = 32
depth = 32
model_kwargs = {
'vocab_size': vocab_size,
'block_size': seq_len,
'hidden_dim': hidden_dim,
'num_heads': num_heads,
'depth': depth,
}
elif model_type == 'mmdit':
raise NotImplementedError
else:
raise NotImplementedError
model_config = model_config(**model_kwargs)
with torch.device("meta"):
model = model_types_classes[model_type](model_config)
model = model.to_empty(device="cuda")
model.init_weights()
model = model.to(torch.bfloat16)
model.train()
logger.info(f"model shape: {model}")
estimated_flops = {}
for flops_calc_type in ['simple', 'torch_native', 'fvcore', 'exact']:
try:
td = 1.0 # dummy time delta
mfu_estimator = MFUEstimator(model_type, model_config, batch_size, seq_len, model, flops_calc_type)
mfu_estimator.estimate_mfu(td)
estimated_flops[flops_calc_type] = mfu_estimator.num_floating_point_operations
except Exception as e:
logger.warning(f"Error counting FLOPs with {flops_calc_type}: {e}")
estimated_flops[flops_calc_type] = None
summary = ''
for flops_calc_type in estimated_flops:
summary += f"{flops_calc_type}: {estimated_flops[flops_calc_type]/1e12:.2f}TFLOPs\n"
logger.info(f"\n{summary}\n")
if estimated_flops['exact'] is not None and estimated_flops['simple'] is not None and abs(estimated_flops['exact'] - estimated_flops['simple']) / estimated_flops['simple'] > 0.05:
logger.warning(f'''
estimated_flops['exact'] and estimated_flops['simple'] are too different:
{estimated_flops['exact']/1e12:.2f}TFLOPs vs {estimated_flops['simple']/1e12:.2f}TFLOPs
''')
@SeunghyunSEO (Author):

oh, i forgot to count backward flops, which are typically 2x larger than forward.
but now it seems overestimated.

simple: 369.45TFLOPs
torch_native: 404.62TFLOPs
fvcore: 57.17TFLOPs
exact: 369.44TFLOPs

(and i'm not gonna care about fvcore anymore lol)
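
For reference, the usual accounting (assuming plain dense matmuls) is ~2 FLOPs per non-embedding parameter per token forward and ~4 backward, hence the 6*N term in num_floating_point_operations_simple; with roughly 6.4B params in the residual blocks and 2*4096 tokens per batch that is 6 * 6.4e9 * 8192 ~= 317 TFLOPs, plus ~53 TFLOPs for the attention-score matmuls, which lands right at the ~369 TFLOPs above.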
