flops_utils.py
import subprocess
import random

import numpy as np
import torch
from torch._utils import _get_available_device_type, _get_device_module
from loguru import logger

# DEBUG = False
DEBUG = True
# ------------------------------------------------------------------------------------------------
# reference flops per token for language models
# https://github.com/pytorch/torchtitan/blob/327a99cc2371964a1160ff4aec15e37806993139/torchtitan/models/llama3/model/args.py#L63
# https://github.com/pytorch/torchtitan/blob/e65ef30dec74e8e592191a1487e92256d27ecb2b/torchtitan/components/metrics.py#L366-L378
def num_floating_point_operations_simple(
    model: torch.nn.Module,
    l: int,  # num layers
    h: int,  # num heads
    q: int,  # head dim
    batch_size: int,
    seq_len: int,
):
    assert model is not None
    nparams = sum(p.numel() for p in model.parameters())
    if hasattr(model, 'model') and hasattr(model.model, 'blocks'):
        residual_blocks_nparams = sum(p.numel() for p in model.model.blocks.parameters())
    else:
        raise NotImplementedError
    # nparams_embedding = sum(
    #     sum(p.numel() for p in m.parameters())
    #     for m in model.children()
    #     if isinstance(m, torch.nn.Embedding)
    # )
    # residual_blocks_nparams = (nparams - nparams_embedding)
    kaplan = 6 * residual_blocks_nparams  # counting non-embedding params is more common for LLMs (especially sparse/MoE models)
    sdpa = 12 * l * h * q * seq_len
    total_flops_per_token = kaplan + sdpa  # rough Kaplan estimate + SDPA
    estimated_flops = batch_size * seq_len * total_flops_per_token
    if DEBUG:
        logger.info(f'''
        batch_size: {batch_size}
        seq_len: {seq_len}
        model size (N): {nparams/1e9:.4f}B (in billions)
        params in residual blocks: {residual_blocks_nparams/1e9:.4f}B (in billions)
        6*N kaplan estimate: {kaplan/1e12:.4f} TFLOPs
        12*L*H*Q*S sdpa estimate: {sdpa/1e12:.4f} TFLOPs
        flops per token (6*N + 12*L*H*Q*S): {total_flops_per_token/1e12:.6f} TFLOPs
        flops per batch: {estimated_flops/1e12:.2f} TFLOPs
        ''')
    return estimated_flops
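
# worked example (hypothetical numbers, just a sanity check of the arithmetic above):
# for l=32 layers, h=32 heads, q=128 head dim, seq_len=4096, and ~6.44B params in the
# residual blocks (12 * l * hidden^2 with hidden=4096):
#   kaplan = 6 * 6.44e9                  ~= 38.7 GFLOPs per token
#   sdpa   = 12 * 32 * 32 * 128 * 4096   ~= 6.44 GFLOPs per token
# so attention adds roughly 17% on top of the parameter term at this sequence length.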
# ------------------------------------------------------------------------------------------------
# using a generalized flop counter
def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

def get_inputs(vocab_size, batch_size, seq_len, device, return_tuple=False):
    set_seed()
    x = torch.randint(0, vocab_size, (batch_size, seq_len), device=device)
    y = x[:, 1:]  # next-token targets, shape (batch_size, seq_len - 1)
    if return_tuple:
        return (x, y)
    else:
        return {
            'x': x,
            'y': y,
        }

# https://github.com/pytorch/pytorch/pull/95751
def num_floating_point_operations_torch_native(
    model,
    vocab_size,
    batch_size,
    seq_len,
):
    from torch.utils.flop_counter import FlopCounterMode
    flop_counter = FlopCounterMode(display=True)
    device = next(model.parameters()).device
    inputs = get_inputs(vocab_size, batch_size, seq_len, device)
    with flop_counter:
        out = model(**inputs)
        out[0].backward()  # backward should be included too
    # estimated_flops = sum(flop_counter.get_flop_counts()['Global'].values())  # https://x.com/cHHillee/status/1649209467811299329
    estimated_flops = flop_counter.get_total_flops()
    if DEBUG:
        logger.info(f'''
        batch_size: {batch_size}
        seq_len: {seq_len}
        flops per batch: {estimated_flops/1e12:.2f} TFLOPs
        ''')
    return estimated_flops
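
# optional sanity check (a minimal sketch, not part of the original utilities): for
# matmul-dominated models, backward costs roughly 2x forward, so a forward-only count
# should come out near one third of the fwd+bwd count returned above.
def num_floating_point_operations_torch_native_fwd_only(model, vocab_size, batch_size, seq_len):
    from torch.utils.flop_counter import FlopCounterMode
    flop_counter = FlopCounterMode(display=False)
    device = next(model.parameters()).device
    inputs = get_inputs(vocab_size, batch_size, seq_len, device)
    with flop_counter:
        model(**inputs)  # forward only, no .backward()
    return flop_counter.get_total_flops()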
# https://github.com/facebookresearch/fvcore/blob/main/docs/flop_count.md
def num_floating_point_operations_fvcore(
    model,
    vocab_size,
    batch_size,
    seq_len,
):
    from fvcore.nn import FlopCountAnalysis
    device = next(model.parameters()).device
    inputs = get_inputs(vocab_size, batch_size, seq_len, device, return_tuple=True)
    flops = FlopCountAnalysis(model, inputs)
    estimated_flops = flops.total()
    if DEBUG:
        logger.info(f'''
        flops per batch (flops.total()): {estimated_flops/1e12:.4f} TFLOPs
        flops.by_module_and_operator(): {flops.by_module_and_operator()}
        ''')
    return estimated_flops

# ------------------------------------------------------------------------------------------------
# https://github.com/NVIDIA/Megatron-LM/blob/core_r0.11.0/megatron/training/training.py#L118
# classic Megatron-style handwritten FLOPs calculation
def num_floating_point_operations_llm_exact(
    # for basic transformer blocks
    vocab_size: int,
    num_layers: int,
    hidden_dim: int,
    mlp_ratio: int,
    batch_size: int,
    seq_len: int,
):
    # ----------------------------------------------------------------
    '''
    Basic arithmetic for the FLOPs calculation.
    Let B be the batch size, s the seq_len, and h the hidden (embedding) dim.
    For one self-attention block (prenorm not included):
        qkv projection:  6Bsh^2
        attn scores:     2Bs^2h
        attn over value: 2Bs^2h
        out projection:  2Bsh^2
    Original references:
        https://arxiv.org/abs/2305.10403
        https://arxiv.org/abs/2205.05198
    Another reference:
        https://github.com/pytorch/torchtitan/blob/main/torchtitan/utils.py#L282-L297
    Non-matmul operations (normalization layers, bias addition, modulation scale, gating, ...) are ignored.
    The 12x term comes from the following factors (see the references above):
        - 3x: each GEMM is performed 3 times (forward pass, backward wgrad, backward dgrad)
        - 2x: a GEMM of an m*n tensor with an n*k tensor requires 2mnk floating-point operations
        - 2x: GEMMs are stacked twice in standard Transformer architectures (e.g., the h->ffn_h GEMM and the ffn_h->h GEMM in the MLP layer)
    '''
    # expansion factor
    fwd_bwd_fma_factor = 3 * 2
    expansion_factor = fwd_bwd_fma_factor * 2
    # ----------------------------------------------------------------
    # residual blocks
    # 1. self-attention projections and the SDPA operation, per layer
    self_attn_term = expansion_factor * hidden_dim * hidden_dim * (
        2  # qkv and out proj
        + seq_len / hidden_dim / 2  # SDPA (causal)
    )
    # 2. mlp (not gated mlp), per layer
    mlp_term = expansion_factor * hidden_dim * (hidden_dim * mlp_ratio)
    transformer_flops = num_layers * (self_attn_term + mlp_term)
    # ----------------------------------------------------------------
    # logits
    logit_flops = fwd_bwd_fma_factor * hidden_dim * vocab_size
    # ----------------------------------------------------------------
    total_flops_per_token = (
        transformer_flops
        + logit_flops
    )
    estimated_flops = batch_size * seq_len * total_flops_per_token
    if DEBUG:
        logger.info(f'''
        batch_size: {batch_size}
        seq_len (input): {seq_len}
        total_flops_per_token: {total_flops_per_token/1e12:.4f} TFLOPs
        self_attn_term: {self_attn_term/1e12:.4f} TFLOPs
        mlp_term: {mlp_term/1e12:.4f} TFLOPs
        logit_flops: {logit_flops/1e12:.4f} TFLOPs
        total estimated flops (batch_size * seq_len * total_flops_per_token): {estimated_flops/1e12:.2f} TFLOPs
        ''')
    return estimated_flops
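
# worked reconciliation of 'exact' vs 'simple' (a note on the arithmetic above, assuming
# mlp_ratio=4 and a non-gated MLP): params per block are 4*h^2 (qkv + out proj) plus
# 8*h^2 (mlp) = 12*h^2, so the 6*N kaplan term in the simple estimate gives
# 6 * 12 * num_layers * h^2 = 72 * num_layers * h^2 flops/token, which equals the
# projection part above: num_layers * (24*h^2 + 48*h^2). The remaining differences:
# 'exact' halves the SDPA term for causal masking (6*s*h per layer vs 12*s*h in the
# simple estimate) and additionally counts the 6*h*vocab_size logit projection.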
# ------------------------------------------------------------------------------------------------
def get_llm_flops_per_batch(
    model,
    model_config,
    flops_calc_type,
    batch_size: int,
    seq_len: int,
):
    vocab_size = model_config.vocab_size
    num_layers = model_config.depth
    hidden_dim = model_config.hidden_dim
    num_heads = model_config.num_heads
    dim_head = hidden_dim // num_heads
    mlp_ratio = model_config.mlp_ratio
    # TODO: placeholder
    if flops_calc_type == 'simple':
        return num_floating_point_operations_simple(
            model,
            num_layers,
            num_heads,
            dim_head,
            batch_size,
            seq_len,
        )
    elif flops_calc_type == 'torch_native':
        return num_floating_point_operations_torch_native(
            model,
            vocab_size,
            batch_size,
            seq_len,
        )
    elif flops_calc_type == 'fvcore':
        return num_floating_point_operations_fvcore(
            model,
            vocab_size,
            batch_size,
            seq_len,
        )
    elif flops_calc_type == 'exact':
        return num_floating_point_operations_llm_exact(
            vocab_size,
            num_layers,
            hidden_dim,
            mlp_ratio,
            batch_size,
            seq_len,
        )
    else:
        raise ValueError(f"flops_calc_type: {flops_calc_type} is not supported")
# ------------------------------------------------------------------------------------------------
class MFUEstimator:
    def __init__(
        self,
        model_type: str,
        model_config,
        batch_size: int,
        seq_len: int = None,
        model: torch.nn.Module = None,
        flops_calc_type: str = 'simple',
    ):
        # args
        self.model_type = model_type
        self.model_config = model_config
        self.batch_size = batch_size
        self.seq_len = seq_len  # transformer seq_len
        self.model = model
        self.flops_calc_type = flops_calc_type
        self.num_floating_point_operations = self.get_num_floating_point_operations()
        self.peak_flops = self.get_peak_flops()
        if DEBUG:
            logger.info(f'''
            bf16 peak flops of this GPU
            GPU type: {_get_device_module("cuda").get_device_name(0)}
            peak flops: {self.peak_flops/1e12:.2f} TFLOPs
            ''')

    def estimate_mfu(self, time_delta: float) -> float:
        estimated_mfu = (
            100
            * self.num_floating_point_operations
            / self.peak_flops
            / time_delta
        )
        if DEBUG:
            logger.info(f'''
            time_delta: {time_delta:.2f} sec
            estimated_mfu: {estimated_mfu:.2f}%
            ''')
        return estimated_mfu
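
    # worked example (hypothetical numbers): if one training step does 100 TFLOPs of
    # work and takes 0.25 s on an H100 SXM (989 TFLOP/s bf16 peak), then
    #   mfu = 100 * 100e12 / 989e12 / 0.25 ~= 40.4%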
    def get_num_floating_point_operations(self):
        if self.model_type == 'gpt2':
            return get_llm_flops_per_batch(
                self.model,
                self.model_config,
                self.flops_calc_type,
                self.batch_size,
                self.seq_len,
            )
        elif self.model_type == 'mmdit':
            raise NotImplementedError
        else:
            # fallback
            logger.warning(f"self.model_type: {self.model_type} is not supported for MFU estimation, MFU will be logged as 0 (fallback)")
            return 0
    def get_peak_flops(self) -> int:
        # hardcoded BF16 peak flops for NVIDIA A100, H100, H200, B200 and AMD MI250X, MI300X, MI325X GPUs and Intel PVC
        # https://github.com/pytorch/torchtitan/blob/main/torchtitan/tools/utils.py#L67
        device_type = _get_available_device_type() or "cuda"
        device_module = _get_device_module(device_type)
        device_name = device_module.get_device_name(0)
        try:
            # Run the lspci command and capture the output
            result = subprocess.run(["lspci"], stdout=subprocess.PIPE, text=True)
            # Filter the output for lines containing both "NVIDIA" and "H100"
            filtered_lines = [
                line
                for line in result.stdout.splitlines()
                if "NVIDIA" in line and "H100" in line
            ]
            # Join all filtered lines into a single string
            device_name = " ".join(filtered_lines) or device_name
        except FileNotFoundError as e:
            logger.warning(f"Error running lspci: {e}, falling back to device_name")
        if "A100" in device_name:
            # data from https://www.nvidia.com/en-us/data-center/a100/
            return 312e12  # e.g., achieving 100 TFLOP/s gives 100/312 = 0.32 MFU, well below a typical 0.50-0.60
        elif "H100" in device_name:
            # data from https://www.nvidia.com/en-us/data-center/h100/
            # NOTE: specifications are one-half lower without sparsity.
            if "NVL" in device_name:
                return 835e12
            elif "PCIe" in device_name:
                return 756e12
            else:  # for H100 SXM and other variants
                return 989e12
        elif "H200" in device_name:
            # data from https://www.nvidia.com/en-us/data-center/h200/
            return 989e12
        elif "B200" in device_name:
            # data from https://nvdam.widen.net/s/wwnsxrhm2w/blackwell-datasheet-3384703
            return 2.25e15
        elif "MI300X" in device_name or "MI325X" in device_name:
            # MI300X data from https://www.amd.com/en/products/accelerators/instinct/mi300/mi300x.html
            # MI325X data from https://www.amd.com/en/products/accelerators/instinct/mi300/mi325x.html
            return 1300e12
        elif "MI250X" in device_name:
            # data from https://www.amd.com/en/products/accelerators/instinct/mi200/mi250x.html (per GCD)
            return 191.5e12
        elif "Data Center GPU Max 1550" in device_name:
            # Also known as Ponte Vecchio (PVC).
            # data from https://www.intel.com/content/www/us/en/docs/oneapi/optimization-guide-gpu/2025-0/intel-xe-gpu-architecture.html
            # Dot Product Accumulate Systolic (DPAS):
            # - Freq: 1300MHz
            # - #ops: 512
            # Full EU mode (i.e. 512 max compute units): 340.8 TFLOPS (BF16)
            # Standard EU mode (i.e. 448 max compute units): 298.2 TFLOPS (BF16)
            max_comp_units = torch.xpu.get_device_properties("xpu").max_compute_units
            return 512 * max_comp_units * 1300 * 10**6
        elif "L40S" in device_name:
            # data from https://resources.nvidia.com/en-us-l40s/l40s-datasheet-28413
            return 362e12
        else:  # for other GPU types, assume A100
            logger.warning(f"Peak flops undefined for: {device_name}, falling back to A100")
            return 312e12
# ----------------------------------------------------------------
# simple language model for testing
import numpy as np
from einops import rearrange
import torch
from torch import nn, Tensor
import torch.nn.functional as F
from torch.profiler import record_function
from dataclasses import dataclass
from transformers import GPT2Tokenizer

def get_tokenizer():
    return GPT2Tokenizer.from_pretrained("gpt2")

@dataclass
class LanguageModelConfig:
    # base
    vocab_size: int = 50257
    ## typical 7-8B setup
    # hidden_dim: int = 4096
    # num_heads: int = 32
    # depth: int = 32
    ## small-scale proxy
    hidden_dim: int = 512
    num_heads: int = 4
    depth: int = 12
    mlp_ratio: int = 4  # non-gated MLP, so only an integer ratio is supported
    block_size: int = 4096
    init_std: float = 0.02
    # muP
    mup: bool = False
    mup_input_mult: float = 1.0
    mup_output_mult: float = 1.0

class RMSNorm(nn.Module):
    # copied from https://github.com/facebookresearch/lingua/blob/main/lingua/transformer.py for simplicity
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x: torch.Tensor):
        return x * torch.rsqrt((x * x).mean(-1, keepdim=True) + self.eps)

    def reset_parameters(self):
        nn.init.ones_(self.weight)

    @record_function("rms_norm")
    def forward(self, x: torch.Tensor):
        output = self._norm(x.float())
        return (output * self.weight.float()).type_as(x)
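
# note on the norm above: RMSNorm computes x / sqrt(mean(x^2) + eps) * weight, i.e.
# LayerNorm without mean-centering or a bias term; the fp32 upcast keeps the rsqrt
# numerically stable when the model runs in bf16.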
class SelfAttention(nn.Module):
    def __init__(self, dim: int, num_heads, mup):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        assert dim % num_heads == 0, "hidden size should be divisible by num_heads"
        self.head_dim = dim // num_heads
        self.qkv = nn.Linear(dim, dim * 3, bias=False)
        self.proj = nn.Linear(dim, dim, bias=False)
        self.mup = mup

    def init_weights(self, init_std, mup):
        in_proj_scaler = self.dim ** -0.5 if mup else 1.0
        nn.init.normal_(self.qkv.weight, std=init_std * in_proj_scaler)
        out_proj_scaler = 0.0 if mup else 1.0  # zero out residual out projection
        nn.init.normal_(self.proj.weight, std=init_std * out_proj_scaler)

    @record_function("sdpa_kernel")
    def attention(self, q: Tensor, k: Tensor, v: Tensor, scale: float = None, is_causal: bool = False) -> Tensor:
        if scale is None:
            scale = q.shape[-1] ** -0.5
        # NOTE: torch.backends.cuda.sdp_kernel is deprecated in recent PyTorch releases
        # in favor of torch.nn.attention.sdpa_kernel
        with torch.backends.cuda.sdp_kernel(
            enable_flash=False,  # disabled while testing a convergence issue
            enable_math=False,
            enable_mem_efficient=True,
        ):
            x = torch.nn.functional.scaled_dot_product_attention(q, k, v, scale=scale, is_causal=is_causal)  # scale matters for muP
        x = rearrange(x, "B H L D -> B L (H D)")
        return x
| @record_function("self_attn") | |
| def forward(self, x: Tensor) -> Tensor: | |
| with record_function("qkv"): | |
| qkv = self.qkv(x) | |
| q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads) | |
| scale = self.head_dim ** -1.0 if self.mup else self.head_dim ** -0.5 | |
| x = self.attention(q, k, v, scale=scale) | |
| with record_function("out_proj"): | |
| x = self.proj(x) | |
| return x | |
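
# note on the attention scale above: muP uses a 1/head_dim attention scale instead of
# the usual 1/sqrt(head_dim), which keeps attention logits O(1) as width grows
# (Yang et al., "Tensor Programs V", https://arxiv.org/abs/2203.03466).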
class MLP(nn.Module):
    def __init__(self, dim: int, mlp_hidden_dim: int):
        super().__init__()
        self.dim = dim
        self.mlp_hidden_dim = mlp_hidden_dim
        self.ffn1 = nn.Linear(dim, mlp_hidden_dim, bias=False)
        self.ffn2 = nn.Linear(mlp_hidden_dim, dim, bias=False)

    def init_weights(self, init_std, mup):
        in_proj_scaler = self.dim ** -0.5 if mup else 1.0
        nn.init.normal_(self.ffn1.weight, std=init_std * in_proj_scaler)
        out_proj_scaler = 0.0 if mup else 1.0  # zero out residual out projection
        nn.init.normal_(self.ffn2.weight, std=init_std * out_proj_scaler)

    @record_function("mlp")
    def forward(self, x: Tensor) -> Tensor:
        return self.ffn2(nn.functional.gelu(self.ffn1(x), approximate="tanh"))

class Block(nn.Module):
    def __init__(self, hidden_dim, nhead, mlp_ratio, mup):
        super().__init__()
        self.norm1 = RMSNorm(hidden_dim)
        self.attn = SelfAttention(hidden_dim, nhead, mup)
        self.norm2 = RMSNorm(hidden_dim)
        self.mlp = MLP(hidden_dim, hidden_dim * mlp_ratio)

    def init_weights(self, init_std, mup):
        self.norm1.reset_parameters()
        self.attn.init_weights(init_std, mup)
        self.norm2.reset_parameters()
        self.mlp.init_weights(init_std, mup)

    def forward(self, x):
        # pre-norm residual: x + Attn(Norm(x)), then x + MLP(Norm(x))
        x = x + self.attn(self.norm1(x))
        return x + self.mlp(self.norm2(x))
class LanguageModel(nn.Module):
    def __init__(self, config: LanguageModelConfig):
        super().__init__()
        self.config = config
        self.init_std = config.init_std
        self.mup = config.mup
        self.mup_input_mult = config.mup_input_mult
        self.mup_output_mult = config.mup_output_mult
        self.model = nn.ModuleDict(
            dict(
                wte=nn.Embedding(config.vocab_size, config.hidden_dim),
                wpe=nn.Embedding(config.block_size, config.hidden_dim),
                blocks=nn.ModuleList([Block(config.hidden_dim, config.num_heads, config.mlp_ratio, self.mup) for _ in range(config.depth)]),
                norm=RMSNorm(config.hidden_dim),
            )
        )
        self.lm_head = nn.Linear(config.hidden_dim, config.vocab_size, bias=False)
        tmp_weight = torch.zeros_like(self.model.wpe.weight)
        self.register_buffer('tmp_weight', tmp_weight, persistent=False)  ## for test

    def init_weights(self):
        nn.init.normal_(self.model.wte.weight, std=self.init_std)
        nn.init.normal_(self.model.wpe.weight, std=self.init_std)
        for block in self.model.blocks:
            block.init_weights(self.init_std, self.mup)
        self.model.norm.reset_parameters()
        # zero out the last projection (readout) for muP. still learnable because the softmax of a zero tensor is uniform
        out_proj_scaler = 0.0 if self.mup else 1.0
        nn.init.normal_(self.lm_head.weight, std=self.init_std * out_proj_scaler)
        nn.init.normal_(self.tmp_weight, std=self.init_std)  ## for test

    def compute_loss(self, z, y, ignore_index=-100, reduction='mean'):
        return F.cross_entropy(z, y, ignore_index=ignore_index, reduction=reduction)

    def forward(self, x, y=None):
        B, T = x.size()
        pos = torch.arange(0, T, dtype=torch.long, device=x.device)
        with record_function("token_embedding"):
            x = self.model.wte(x) + self.model.wpe(pos)
            x = x * self.mup_input_mult if self.mup else x
        for block in self.model.blocks:
            with record_function("residual_block"):
                x = block(x)
        with record_function("final_norm"):
            x = self.model.norm(x)
        with record_function("lm_head"):
            x = x * self.mup_output_mult if self.mup else x
            z = self.lm_head(x).float()  # project to logit space and upcast
        z = z[..., :-1, :].contiguous().reshape(B * (T - 1), -1)  # (B*(T-1), vocab): drop the last position, which has no target
        y = y.reshape(-1)  # (B*(T-1),): targets are the inputs shifted left by one
        assert z.size(0) == y.size(0), f"z: {z.size()}, y: {y.size()}"
        with record_function("compute_loss"):
            loss = self.compute_loss(z, y)
        return loss, z
# ----------------------------------------------------------------
# uv run python flops_utils.py
model_types_configs = {
    "gpt2": LanguageModelConfig,
    "mmdit": None,
}
model_types_classes = {
    "gpt2": LanguageModel,
    "mmdit": None,
}

if __name__ == "__main__":
    for model_type in ['gpt2']:
        model_config = model_types_configs[model_type]
        if model_type == 'gpt2':
            # ~7.5B params
            batch_size = 2
            seq_len = 4096
            vocab_size = 131072
            hidden_dim = 4096
            num_heads = 32
            depth = 32
            model_kwargs = {
                'vocab_size': vocab_size,
                'block_size': seq_len,
                'hidden_dim': hidden_dim,
                'num_heads': num_heads,
                'depth': depth,
            }
        elif model_type == 'mmdit':
            raise NotImplementedError
        else:
            raise NotImplementedError
        model_config = model_config(**model_kwargs)
        with torch.device("meta"):
            model = model_types_classes[model_type](model_config)
        model = model.to_empty(device="cuda")
        model.init_weights()
        model = model.to(torch.bfloat16)
        model.train()
        logger.info(f"model shape: {model}")
        estimated_flops = {}
        for flops_calc_type in ['simple', 'torch_native', 'fvcore', 'exact']:
            try:
                td = 1.0  # dummy time delta
                mfu_estimator = MFUEstimator(model_type, model_config, batch_size, seq_len, model, flops_calc_type)
                mfu_estimator.estimate_mfu(td)
                estimated_flops[flops_calc_type] = mfu_estimator.num_floating_point_operations
            except Exception as e:
                logger.warning(f"Error counting FLOPs with {flops_calc_type}: {e}")
                estimated_flops[flops_calc_type] = None
        summary = ''
        for flops_calc_type, flops in estimated_flops.items():
            # guard against counters that failed above and were recorded as None
            summary += f"{flops_calc_type}: {flops/1e12:.2f} TFLOPs\n" if flops is not None else f"{flops_calc_type}: N/A\n"
        logger.info(f"\n{summary}\n")
        if (
            estimated_flops['exact'] is not None
            and estimated_flops['simple'] is not None
            and abs(estimated_flops['exact'] - estimated_flops['simple']) / estimated_flops['simple'] > 0.05
        ):
            logger.warning(f'''
            estimated_flops['exact'] and estimated_flops['simple'] are too different:
            {estimated_flops['exact']/1e12:.2f} TFLOPs vs {estimated_flops['simple']/1e12:.2f} TFLOPs
            ''')
oh, i forgot to count backward FLOPs, which are typically 2x the forward pass.
but now it seems overestimated.
(and i'm not gonna bother with fvcore anymore lol)
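
(a rough check of the arithmetic: forward is ~2N FLOPs per token and backward ~4N, since backward computes both weight and activation gradients, so fwd+bwd lands at the usual 6N. one likely source of the overestimate is that the simple estimate counts the full s^2 attention matrix, while causal attention only computes half of it, which is exactly the factor the Megatron-style 'exact' count applies.)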