| """ | |
| # Convert HF to TorchTitan DCP | |
| python convert.py hf_to_dcp --input-path meta-llama/Meta-Llama-3.1-8B --output-path ./torchtitan_model | |
| # Convert TorchTitan DCP to HF (works with any checkpoint structure) | |
| python convert.py dcp_to_hf --input-path ./torchtitan_model --output-path ./hf_model | |
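
To smoke-test a converted HF checkpoint, something like the following works (the prompt is
arbitrary; `./hf_model` is the output path from the command above):
```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("./hf_model")
model = AutoModelForCausalLM.from_pretrained("./hf_model")
inputs = tok("The capital of France is", return_tensors="pt")
print(tok.decode(model.generate(**inputs, max_new_tokens=8)[0]))
```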

# Model structure
Running the code below prints both module trees, which makes it easy to see how the two
parameterizations correspond:
```python
from transformers import LlamaForCausalLM
import torch
from torchtitan.config_manager import JobConfig, Training
from torchtitan.models.llama3 import llama3_configs
from torchtitan.models.llama3.model import Transformer
from torchtitan.datasets.tokenizer.tiktoken import TikTokenizer

model = LlamaForCausalLM.from_pretrained("meta-llama/Meta-Llama-3.1-8B")
print(model)

model_args = llama3_configs["8B"]
tokenizer = TikTokenizer(
    "libs/torchtitan/assets/tokenizer/Meta-Llama-3.1-8B-tokenizer.model"
)
job_config = JobConfig(training=Training(seq_len=1024))
model_args.update_from_config(job_config, tokenizer)
with torch.device("meta"):
    model = Transformer.from_model_args(model_args)
print(model)
```
The output is:
```
LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 4096)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
      )
    )
    (norm): LlamaRMSNorm((4096,), eps=1e-05)
    (rotary_emb): LlamaRotaryEmbedding()
  )
  (lm_head): Linear(in_features=4096, out_features=128256, bias=False)
)
Transformer(
  (tok_embeddings): Embedding(128256, 4096)
  (layers): ModuleList(
    (0-31): 32 x TransformerBlock(
      (attention): Attention(
        (wq): Linear(in_features=4096, out_features=4096, bias=False)
        (wk): Linear(in_features=4096, out_features=1024, bias=False)
        (wv): Linear(in_features=4096, out_features=1024, bias=False)
        (wo): Linear(in_features=4096, out_features=4096, bias=False)
        (sdpa): ScaledDotProductAttention()
      )
      (feed_forward): FeedForward(
        (w1): Linear(in_features=4096, out_features=14336, bias=False)
        (w2): Linear(in_features=14336, out_features=4096, bias=False)
        (w3): Linear(in_features=4096, out_features=14336, bias=False)
      )
      (attention_norm): RMSNorm((4096,), eps=1e-05, elementwise_affine=True)
      (ffn_norm): RMSNorm((4096,), eps=1e-05, elementwise_affine=True)
    )
  )
  (norm): RMSNorm((4096,), eps=1e-05, elementwise_affine=True)
  (output): Linear(in_features=4096, out_features=128256, bias=False)
)
```
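
# Weight name mapping
For reference, the parameter-name mapping implemented below is ({i} is the layer index;
"permuted" marks the RoPE-related reordering of q/k rows):
| HuggingFace | TorchTitan | Permuted |
|---|---|---|
| model.embed_tokens.weight | tok_embeddings.weight | no |
| model.layers.{i}.self_attn.q_proj.weight | layers.{i}.attention.wq.weight | yes |
| model.layers.{i}.self_attn.k_proj.weight | layers.{i}.attention.wk.weight | yes |
| model.layers.{i}.self_attn.v_proj.weight | layers.{i}.attention.wv.weight | no |
| model.layers.{i}.self_attn.o_proj.weight | layers.{i}.attention.wo.weight | no |
| model.layers.{i}.mlp.gate_proj.weight | layers.{i}.feed_forward.w1.weight | no |
| model.layers.{i}.mlp.down_proj.weight | layers.{i}.feed_forward.w2.weight | no |
| model.layers.{i}.mlp.up_proj.weight | layers.{i}.feed_forward.w3.weight | no |
| model.layers.{i}.input_layernorm.weight | layers.{i}.attention_norm.weight | no |
| model.layers.{i}.post_attention_layernorm.weight | layers.{i}.ffn_norm.weight | no |
| model.norm.weight | norm.weight | no |
| lm_head.weight | output.weight | no |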
| """ | |
from pathlib import Path

import torch
import torch.distributed.checkpoint as DCP
from torch.distributed.checkpoint.default_planner import _EmptyStateDictLoadPlanner
from torch.distributed.checkpoint.state_dict_loader import _load_state_dict
from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaConfig
from torchtitan.models.llama3.model import precompute_freqs_cis
from tqdm import tqdm
from tyro.extras import SubcommandApp
from torchtitan.tools.logging import init_logger, logger

app = SubcommandApp()


def map_hf_to_torchtitan(hf_state_dict, model_config, max_seq_len=131072, rope_theta=500000.0):
    """Map HuggingFace state dict to TorchTitan format."""
    n_layers = model_config.num_hidden_layers
    n_heads = model_config.num_attention_heads
    dim = model_config.hidden_size
    dims_per_head = dim // n_heads
    # Determine n_kv_heads for GQA models
    n_kv_heads = getattr(model_config, 'num_key_value_heads', n_heads)
    head_dim = dim // n_heads
    print(f"Model info: dim={dim}, n_heads={n_heads}, n_kv_heads={n_kv_heads}, head_dim={head_dim}")
    # Inverse of the permutation used by HuggingFace's conversion script: HF stores each
    # head's rotary (RoPE) rows of q/k as [all even rows, all odd rows] for its rotate_half
    # implementation, while TorchTitan expects the original interleaved layout.
    def permute(w, n_heads_arg, dim1=None, dim2=None):
        if dim1 is None:
            dim1 = w.shape[0]
        if dim2 is None:
            dim2 = w.shape[1]
        return w.view(n_heads_arg, 2, dim1 // n_heads_arg // 2, dim2).transpose(1, 2).reshape(dim1, dim2)
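    # Worked example for one head with head_dim = 8: the HF rows [0..7] hold the original
    # rows [0, 2, 4, 6, 1, 3, 5, 7]; the inverse permutation above reads them back in the
    # order [0, 4, 1, 5, 2, 6, 3, 7], restoring the interleaved (real, imag) pairing that
    # TorchTitan's complex-valued RoPE (freqs_cis) expects.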

    torchtitan_state_dict = {}
    # Convert embeddings and output (no permutation needed)
    if 'model.embed_tokens.weight' in hf_state_dict:
        torchtitan_state_dict["tok_embeddings.weight"] = hf_state_dict["model.embed_tokens.weight"].clone()
    if 'lm_head.weight' in hf_state_dict:
        torchtitan_state_dict["output.weight"] = hf_state_dict["lm_head.weight"].clone()
    if 'model.norm.weight' in hf_state_dict:
        torchtitan_state_dict["norm.weight"] = hf_state_dict["model.norm.weight"].clone()

    # Convert layers
    for layer_idx in tqdm(range(n_layers), desc="Converting layers"):
        hf_layer_prefix = f'model.layers.{layer_idx}'
        layer_prefix = f'layers.{layer_idx}'

        # Attention weights with proper reverse permutation
        if f'{hf_layer_prefix}.self_attn.q_proj.weight' in hf_state_dict:
            wq = hf_state_dict[f'{hf_layer_prefix}.self_attn.q_proj.weight']
            torchtitan_state_dict[f'{layer_prefix}.attention.wq.weight'] = permute(wq, n_heads)
        if f'{hf_layer_prefix}.self_attn.k_proj.weight' in hf_state_dict:
            wk = hf_state_dict[f'{hf_layer_prefix}.self_attn.k_proj.weight']
            key_value_dim = n_kv_heads * head_dim
            torchtitan_state_dict[f'{layer_prefix}.attention.wk.weight'] = permute(
                wk, n_kv_heads, key_value_dim, dim
            )
        if f'{hf_layer_prefix}.self_attn.v_proj.weight' in hf_state_dict:
            # Value weights don't get permuted
            torchtitan_state_dict[f'{layer_prefix}.attention.wv.weight'] = hf_state_dict[f'{hf_layer_prefix}.self_attn.v_proj.weight'].clone()
        if f'{hf_layer_prefix}.self_attn.o_proj.weight' in hf_state_dict:
            # Output projection doesn't get permuted
            torchtitan_state_dict[f'{layer_prefix}.attention.wo.weight'] = hf_state_dict[f'{hf_layer_prefix}.self_attn.o_proj.weight'].clone()

        # MLP weights (no permutation)
        mlp_mappings = {
            f'{hf_layer_prefix}.mlp.gate_proj.weight': f'{layer_prefix}.feed_forward.w1.weight',
            f'{hf_layer_prefix}.mlp.down_proj.weight': f'{layer_prefix}.feed_forward.w2.weight',
            f'{hf_layer_prefix}.mlp.up_proj.weight': f'{layer_prefix}.feed_forward.w3.weight',
        }
        for hf_key, tt_key in mlp_mappings.items():
            if hf_key in hf_state_dict:
                torchtitan_state_dict[tt_key] = hf_state_dict[hf_key].clone()

        # Layer norms (no permutation)
        norm_mappings = {
            f'{hf_layer_prefix}.input_layernorm.weight': f'{layer_prefix}.attention_norm.weight',
            f'{hf_layer_prefix}.post_attention_layernorm.weight': f'{layer_prefix}.ffn_norm.weight',
        }
        for hf_key, tt_key in norm_mappings.items():
            if hf_key in hf_state_dict:
                torchtitan_state_dict[tt_key] = hf_state_dict[hf_key].clone()

    # Precompute RoPE frequencies
    torchtitan_state_dict["freqs_cis"] = precompute_freqs_cis(dims_per_head, max_seq_len, rope_theta)

    # Save model config for reverse conversion
    config_dict = {
        "num_hidden_layers": n_layers,
        "num_attention_heads": n_heads,
        "hidden_size": dim,
        "intermediate_size": model_config.intermediate_size,
        "max_position_embeddings": model_config.max_position_embeddings,
        "vocab_size": model_config.vocab_size,
        "rope_theta": rope_theta,
        "rms_norm_eps": model_config.rms_norm_eps,
    }
    if hasattr(model_config, 'num_key_value_heads'):
        config_dict["num_key_value_heads"] = model_config.num_key_value_heads
    torchtitan_state_dict["_model_config"] = config_dict

    print(f"Converted {len(torchtitan_state_dict)} parameters from HuggingFace to TorchTitan format")
    return torchtitan_state_dict


def map_torchtitan_to_hf(torchtitan_state_dict, max_seq_len=131072, rope_theta=500000.0):
    """Map TorchTitan state dict to HuggingFace format."""
    layer_keys = [k for k in torchtitan_state_dict.keys() if k.startswith("layers.")]
    assert len(layer_keys) > 0, "No layers found in state dict"
    n_layers = max([int(k.split(".")[1]) for k in layer_keys]) + 1
    hf_state_dict = {}

    # Get model info from sample weight
    sample_wq_key = next(k for k in torchtitan_state_dict.keys() if k.endswith('.attention.wq.weight'))
    wq_weight = torchtitan_state_dict[sample_wq_key]
    dim = wq_weight.shape[1]  # input dimension

    # Check a key weight to determine n_kv_heads
    sample_wk_key = next(k for k in torchtitan_state_dict.keys() if k.endswith('.attention.wk.weight'))
    wk_weight = torchtitan_state_dict[sample_wk_key]

    # Standard Llama head dim is 128 for most models
    head_dim = 128 if dim % 128 == 0 else 64
    n_heads = dim // head_dim
    # For GQA models, n_kv_heads might be different
    n_kv_heads = wk_weight.shape[0] // head_dim
    print(f"Model info: dim={dim}, n_heads={n_heads}, n_kv_heads={n_kv_heads}, head_dim={head_dim}")
    # Forward permutation from HuggingFace's conversion script (exact copy): splits each
    # head's interleaved rotary rows into [even rows, odd rows] for rotate_half-style RoPE.
    def permute(w, n_heads_arg, dim1=None, dim2=None):
        if dim1 is None:
            dim1 = w.shape[0]
        if dim2 is None:
            dim2 = w.shape[1]
        return w.view(n_heads_arg, dim1 // n_heads_arg // 2, 2, dim2).transpose(1, 2).reshape(dim1, dim2)

    # Convert embeddings and output (no permutation needed)
    if 'tok_embeddings.weight' in torchtitan_state_dict:
        hf_state_dict['model.embed_tokens.weight'] = torchtitan_state_dict['tok_embeddings.weight'].clone()
    if 'output.weight' in torchtitan_state_dict:
        hf_state_dict['lm_head.weight'] = torchtitan_state_dict['output.weight'].clone()
    if 'norm.weight' in torchtitan_state_dict:
        hf_state_dict['model.norm.weight'] = torchtitan_state_dict['norm.weight'].clone()

    # Convert layers
    for layer_idx in tqdm(range(n_layers), desc="Converting layers"):
        layer_prefix = f'layers.{layer_idx}'
        hf_layer_prefix = f'model.layers.{layer_idx}'

        # Attention weights with proper permutation
        if f'{layer_prefix}.attention.wq.weight' in torchtitan_state_dict:
            wq = torchtitan_state_dict[f'{layer_prefix}.attention.wq.weight']
            hf_state_dict[f'{hf_layer_prefix}.self_attn.q_proj.weight'] = permute(wq, n_heads)
        if f'{layer_prefix}.attention.wk.weight' in torchtitan_state_dict:
            wk = torchtitan_state_dict[f'{layer_prefix}.attention.wk.weight']
            key_value_dim = n_kv_heads * head_dim
            hf_state_dict[f'{hf_layer_prefix}.self_attn.k_proj.weight'] = permute(
                wk, n_kv_heads, key_value_dim, dim
            )
        if f'{layer_prefix}.attention.wv.weight' in torchtitan_state_dict:
            # Value weights don't get permuted
            hf_state_dict[f'{hf_layer_prefix}.self_attn.v_proj.weight'] = torchtitan_state_dict[f'{layer_prefix}.attention.wv.weight'].clone()
        if f'{layer_prefix}.attention.wo.weight' in torchtitan_state_dict:
            # Output projection doesn't get permuted
            hf_state_dict[f'{hf_layer_prefix}.self_attn.o_proj.weight'] = torchtitan_state_dict[f'{layer_prefix}.attention.wo.weight'].clone()

        # MLP weights (no permutation)
        mlp_mappings = {
            f'{layer_prefix}.feed_forward.w1.weight': f'{hf_layer_prefix}.mlp.gate_proj.weight',
            f'{layer_prefix}.feed_forward.w2.weight': f'{hf_layer_prefix}.mlp.down_proj.weight',
            f'{layer_prefix}.feed_forward.w3.weight': f'{hf_layer_prefix}.mlp.up_proj.weight',
        }
        for tt_key, hf_key in mlp_mappings.items():
            if tt_key in torchtitan_state_dict:
                hf_state_dict[hf_key] = torchtitan_state_dict[tt_key].clone()

        # Layer norms (no permutation)
        norm_mappings = {
            f'{layer_prefix}.attention_norm.weight': f'{hf_layer_prefix}.input_layernorm.weight',
            f'{layer_prefix}.ffn_norm.weight': f'{hf_layer_prefix}.post_attention_layernorm.weight',
        }
        for tt_key, hf_key in norm_mappings.items():
            if tt_key in torchtitan_state_dict:
                hf_state_dict[hf_key] = torchtitan_state_dict[tt_key].clone()

    print(f"Converted {len(hf_state_dict)} parameters from TorchTitan to HuggingFace format")
    return hf_state_dict
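

# A minimal, optional sanity check (not part of the original CLI): if the two RoPE
# permutations above are true inverses, mapping HF -> TorchTitan -> HF must reproduce the
# original tensors exactly. This is a sketch; it assumes the full state dict fits in memory.
def _check_round_trip(hf_state_dict, model_config):
    tt_state_dict = map_hf_to_torchtitan(hf_state_dict, model_config)
    recovered = map_torchtitan_to_hf(tt_state_dict)
    for key, value in hf_state_dict.items():
        if key in recovered:
            assert torch.equal(value, recovered[key]), f"round-trip mismatch at {key}"
    print("Round-trip check passed")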
| @app.command(name="hf_to_dcp") | |
| @torch.inference_mode() | |
| def convert_hf_to_dcp(input_path: str, output_path: Path, max_seq_len: int = 131072, rope_theta: float = 500000.0, dtype: str = "float32"): | |
| torch_dtype = getattr(torch, dtype) | |
| """Convert HuggingFace model to TorchTitan DCP format. | |
| Args: | |
| input_path: HuggingFace model name or path | |
| output_path: Output DCP checkpoint path | |
| max_seq_len: Max sequence length for RoPE | |
| rope_theta: RoPE theta parameter | |
| """ | |
| logger.info(f"Loading model from {input_path}") | |
| hf_model = AutoModelForCausalLM.from_pretrained(input_path, torch_dtype=torch_dtype) | |
| hf_state_dict = hf_model.state_dict() | |
| logger.info("Converting weights to TorchTitan format") | |
| torchtitan_state_dict = map_hf_to_torchtitan(hf_state_dict, hf_model.config, max_seq_len, rope_theta) | |
| logger.info(f"Writing to DCP at '{output_path}'") | |
| output_path.mkdir(parents=True, exist_ok=True) | |
| storage_writer = DCP.filesystem.FileSystemWriter(output_path, thread_count=8) | |
| DCP.save({"model": torchtitan_state_dict}, storage_writer=storage_writer) | |
| logger.info("Conversion complete!") | |
| @app.command(name="dcp_to_hf") | |
| @torch.inference_mode() | |
| def convert_dcp_to_hf(input_path: Path, output_path: Path, max_seq_len: int = 131072, rope_theta: float = 500000.0, default_model: str = "meta-llama/Meta-Llama-3.1-8B"): | |
| """Convert TorchTitan DCP format to HuggingFace model. | |
| Args: | |
| input_path: Input DCP checkpoint path | |
| output_path: Output HuggingFace model path | |
| max_seq_len: Max sequence length for RoPE | |
| rope_theta: RoPE theta parameter | |
| default_model: Default HuggingFace model for config | |
| """ | |
| logger.info(f"Loading DCP checkpoint from {input_path}") | |
| from torch.distributed.checkpoint.default_planner import _EmptyStateDictLoadPlanner | |
| from torch.distributed.checkpoint.state_dict_loader import _load_state_dict | |
| # Load DCP input_path | |
| state_dict = {} | |
| _load_state_dict( | |
| state_dict, | |
| storage_reader=DCP.filesystem.FileSystemReader(input_path), | |
| planner=_EmptyStateDictLoadPlanner(), | |
| no_dist=True, | |
| ) | |
| torchtitan_state_dict = state_dict["model"] | |
| logger.info("Converting weights to HuggingFace format") | |
| hf_state_dict = map_torchtitan_to_hf(torchtitan_state_dict, max_seq_len, rope_theta) | |
| # Create HuggingFace config | |
| hf_config = LlamaConfig.from_pretrained(default_model) | |
| # Create and load model | |
| logger.info("Creating HuggingFace model") | |
| tokenizer = AutoTokenizer.from_pretrained(default_model) | |
| hf_model = AutoModelForCausalLM.from_pretrained(default_model) | |
| # load state dict | |
| logger.info("Loading state dict") | |
| hf_model.load_state_dict(hf_state_dict, strict=True) | |
| # Save model | |
| logger.info(f"Saving model to {output_path}") | |
| output_path.mkdir(parents=True, exist_ok=True) | |
| hf_model.save_pretrained(output_path) | |
| tokenizer.save_pretrained(output_path) | |
| logger.info("Conversion complete!") | |


if __name__ == "__main__":
    init_logger()
    app.cli()