"""
# Convert HF to TorchTitan DCP
python convert.py hf_to_dcp --input-path meta-llama/Meta-Llama-3.1-8B --output-path ./torchtitan_model
# Convert TorchTitan DCP to HF (works with any checkpoint structure)
python convert.py dcp_to_hf --input-path ./torchtitan_model --output-path ./hf_model
# Model structure
If you run the following code, you will get the model structure.
```python
from transformers import LlamaForCausalLM
import torch

from torchtitan.config_manager import JobConfig, Training
from torchtitan.models.llama3 import llama3_configs
from torchtitan.models.llama3.model import Transformer
from torchtitan.datasets.tokenizer.tiktoken import TikTokenizer

model = LlamaForCausalLM.from_pretrained("meta-llama/Meta-Llama-3.1-8B")
print(model)

model_args = llama3_configs["8B"]
tokenizer = TikTokenizer(
    "libs/torchtitan/assets/tokenizer/Meta-Llama-3.1-8B-tokenizer.model"
)
job_config = JobConfig(training=Training(seq_len=1024))
model_args.update_from_config(job_config, tokenizer)
with torch.device("meta"):
    model = Transformer.from_model_args(model_args)
print(model)
```

The output is:

```
LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 4096)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
      )
    )
    (norm): LlamaRMSNorm((4096,), eps=1e-05)
    (rotary_emb): LlamaRotaryEmbedding()
  )
  (lm_head): Linear(in_features=4096, out_features=128256, bias=False)
)
Transformer(
  (tok_embeddings): Embedding(128256, 4096)
  (layers): ModuleList(
    (0-31): 32 x TransformerBlock(
      (attention): Attention(
        (wq): Linear(in_features=4096, out_features=4096, bias=False)
        (wk): Linear(in_features=4096, out_features=1024, bias=False)
        (wv): Linear(in_features=4096, out_features=1024, bias=False)
        (wo): Linear(in_features=4096, out_features=4096, bias=False)
        (sdpa): ScaledDotProductAttention()
      )
      (feed_forward): FeedForward(
        (w1): Linear(in_features=4096, out_features=14336, bias=False)
        (w2): Linear(in_features=14336, out_features=4096, bias=False)
        (w3): Linear(in_features=4096, out_features=14336, bias=False)
      )
      (attention_norm): RMSNorm((4096,), eps=1e-05, elementwise_affine=True)
      (ffn_norm): RMSNorm((4096,), eps=1e-05, elementwise_affine=True)
    )
  )
  (norm): RMSNorm((4096,), eps=1e-05, elementwise_affine=True)
  (output): Linear(in_features=4096, out_features=128256, bias=False)
)
```
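
To sanity-check a round trip, the weights of the re-exported HuggingFace model can be compared
tensor-by-tensor with the original. This is a minimal sketch that assumes both commands above
have already been run, so that `./hf_model` exists:

```python
import torch
from transformers import AutoModelForCausalLM

original = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3.1-8B")
roundtrip = AutoModelForCausalLM.from_pretrained("./hf_model")

original_sd = original.state_dict()
roundtrip_sd = roundtrip.state_dict()
for name, tensor in original_sd.items():
    # Every parameter should survive HF -> DCP -> HF unchanged.
    assert torch.equal(tensor.float(), roundtrip_sd[name].float()), f"mismatch in {name}"
print("round trip OK")
```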
"""

from pathlib import Path

import torch
import torch.distributed.checkpoint as DCP
from torch.distributed.checkpoint.default_planner import _EmptyStateDictLoadPlanner
from torch.distributed.checkpoint.state_dict_loader import _load_state_dict
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer
from tyro.extras import SubcommandApp

from torchtitan.models.llama3.model import precompute_freqs_cis
from torchtitan.tools.logging import init_logger, logger

app = SubcommandApp()

def map_hf_to_torchtitan(hf_state_dict, model_config, max_seq_len=131072, rope_theta=500000.0):
    """Map a HuggingFace state dict to the TorchTitan format."""
    n_layers = model_config.num_hidden_layers
    n_heads = model_config.num_attention_heads
    dim = model_config.hidden_size
    head_dim = dim // n_heads
    # Determine n_kv_heads for GQA models (defaults to n_heads for plain MHA)
    n_kv_heads = getattr(model_config, 'num_key_value_heads', n_heads)
    print(f"Model info: dim={dim}, n_heads={n_heads}, n_kv_heads={n_kv_heads}, head_dim={head_dim}")
    # Undo the HuggingFace conversion-script permutation: HF stores q/k projection rows in
    # "rotate-half" order (all first-half rotary components, then all second-half components),
    # while TorchTitan's RoPE expects the original interleaved (pairwise) row order.
    def permute(w, n_heads_arg, dim1=None, dim2=None):
        if dim1 is None:
            dim1 = w.shape[0]
        if dim2 is None:
            dim2 = w.shape[1]
        return w.view(n_heads_arg, 2, dim1 // n_heads_arg // 2, dim2).transpose(1, 2).reshape(dim1, dim2)
    torchtitan_state_dict = {}

    # Convert embeddings and output (no permutation needed)
    if 'model.embed_tokens.weight' in hf_state_dict:
        torchtitan_state_dict["tok_embeddings.weight"] = hf_state_dict["model.embed_tokens.weight"].clone()
    if 'lm_head.weight' in hf_state_dict:
        torchtitan_state_dict["output.weight"] = hf_state_dict["lm_head.weight"].clone()
    if 'model.norm.weight' in hf_state_dict:
        torchtitan_state_dict["norm.weight"] = hf_state_dict["model.norm.weight"].clone()

    # Convert layers
    for layer_idx in tqdm(range(n_layers), desc="Converting layers"):
        hf_layer_prefix = f'model.layers.{layer_idx}'
        layer_prefix = f'layers.{layer_idx}'

        # Attention weights with proper reverse permutation
        if f'{hf_layer_prefix}.self_attn.q_proj.weight' in hf_state_dict:
            wq = hf_state_dict[f'{hf_layer_prefix}.self_attn.q_proj.weight']
            torchtitan_state_dict[f'{layer_prefix}.attention.wq.weight'] = permute(wq, n_heads)
        if f'{hf_layer_prefix}.self_attn.k_proj.weight' in hf_state_dict:
            wk = hf_state_dict[f'{hf_layer_prefix}.self_attn.k_proj.weight']
            key_value_dim = n_kv_heads * head_dim
            torchtitan_state_dict[f'{layer_prefix}.attention.wk.weight'] = permute(
                wk, n_kv_heads, key_value_dim, dim
            )
        if f'{hf_layer_prefix}.self_attn.v_proj.weight' in hf_state_dict:
            # Value weights don't get permuted
            torchtitan_state_dict[f'{layer_prefix}.attention.wv.weight'] = hf_state_dict[f'{hf_layer_prefix}.self_attn.v_proj.weight'].clone()
        if f'{hf_layer_prefix}.self_attn.o_proj.weight' in hf_state_dict:
            # Output projection doesn't get permuted
            torchtitan_state_dict[f'{layer_prefix}.attention.wo.weight'] = hf_state_dict[f'{hf_layer_prefix}.self_attn.o_proj.weight'].clone()
        # MLP weights (no permutation)
        mlp_mappings = {
            f'{hf_layer_prefix}.mlp.gate_proj.weight': f'{layer_prefix}.feed_forward.w1.weight',
            f'{hf_layer_prefix}.mlp.down_proj.weight': f'{layer_prefix}.feed_forward.w2.weight',
            f'{hf_layer_prefix}.mlp.up_proj.weight': f'{layer_prefix}.feed_forward.w3.weight',
        }
        for hf_key, tt_key in mlp_mappings.items():
            if hf_key in hf_state_dict:
                torchtitan_state_dict[tt_key] = hf_state_dict[hf_key].clone()

        # Layer norms (no permutation)
        norm_mappings = {
            f'{hf_layer_prefix}.input_layernorm.weight': f'{layer_prefix}.attention_norm.weight',
            f'{hf_layer_prefix}.post_attention_layernorm.weight': f'{layer_prefix}.ffn_norm.weight',
        }
        for hf_key, tt_key in norm_mappings.items():
            if hf_key in hf_state_dict:
                torchtitan_state_dict[tt_key] = hf_state_dict[hf_key].clone()

    # Precompute RoPE frequencies
    torchtitan_state_dict["freqs_cis"] = precompute_freqs_cis(head_dim, max_seq_len, rope_theta)
    # Save model config for reverse conversion
    config_dict = {
        "num_hidden_layers": n_layers,
        "num_attention_heads": n_heads,
        "hidden_size": dim,
        "intermediate_size": model_config.intermediate_size,
        "max_position_embeddings": model_config.max_position_embeddings,
        "vocab_size": model_config.vocab_size,
        "rope_theta": rope_theta,
        "rms_norm_eps": model_config.rms_norm_eps,
    }
    if hasattr(model_config, 'num_key_value_heads'):
        config_dict["num_key_value_heads"] = model_config.num_key_value_heads
    torchtitan_state_dict["_model_config"] = config_dict

    print(f"Converted {len(torchtitan_state_dict)} parameters from HuggingFace to TorchTitan format")
    return torchtitan_state_dict

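# For reference, the key mapping produced above looks like this for layer 0 of Llama-3.1-8B
# (shapes taken from the model printouts in the module docstring):
#
#   model.layers.0.self_attn.q_proj.weight         -> layers.0.attention.wq.weight     [4096, 4096]
#   model.layers.0.self_attn.k_proj.weight         -> layers.0.attention.wk.weight     [1024, 4096]
#   model.layers.0.self_attn.v_proj.weight         -> layers.0.attention.wv.weight     [1024, 4096]
#   model.layers.0.self_attn.o_proj.weight         -> layers.0.attention.wo.weight     [4096, 4096]
#   model.layers.0.mlp.gate_proj.weight            -> layers.0.feed_forward.w1.weight  [14336, 4096]
#   model.layers.0.mlp.down_proj.weight            -> layers.0.feed_forward.w2.weight  [4096, 14336]
#   model.layers.0.mlp.up_proj.weight              -> layers.0.feed_forward.w3.weight  [14336, 4096]
#   model.layers.0.input_layernorm.weight          -> layers.0.attention_norm.weight   [4096]
#   model.layers.0.post_attention_layernorm.weight -> layers.0.ffn_norm.weight         [4096]
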
def map_torchtitan_to_hf(torchtitan_state_dict, max_seq_len=131072, rope_theta=500000.0):
    """Map a TorchTitan state dict to the HuggingFace format."""
    layer_keys = [k for k in torchtitan_state_dict.keys() if k.startswith("layers.")]
    assert len(layer_keys) > 0, "No layers found in state dict"
    n_layers = max([int(k.split(".")[1]) for k in layer_keys]) + 1
    hf_state_dict = {}

    # Infer model info from sample weights
    sample_wq_key = next(k for k in torchtitan_state_dict.keys() if k.endswith('.attention.wq.weight'))
    wq_weight = torchtitan_state_dict[sample_wq_key]
    dim = wq_weight.shape[1]  # input dimension
    # Use a key projection weight to determine n_kv_heads
    sample_wk_key = next(k for k in torchtitan_state_dict.keys() if k.endswith('.attention.wk.weight'))
    wk_weight = torchtitan_state_dict[sample_wk_key]
    # Standard Llama head dim is 128 for most models
    head_dim = 128 if dim % 128 == 0 else 64
    n_heads = dim // head_dim
    # For GQA models, n_kv_heads might be different
    n_kv_heads = wk_weight.shape[0] // head_dim
    print(f"Model info: dim={dim}, n_heads={n_heads}, n_kv_heads={n_kv_heads}, head_dim={head_dim}")

    # HuggingFace permutation function (exact copy from their conversion script)
    def permute(w, n_heads_arg, dim1=None, dim2=None):
        if dim1 is None:
            dim1 = w.shape[0]
        if dim2 is None:
            dim2 = w.shape[1]
        return w.view(n_heads_arg, dim1 // n_heads_arg // 2, 2, dim2).transpose(1, 2).reshape(dim1, dim2)
    # Convert embeddings and output (no permutation needed)
    if 'tok_embeddings.weight' in torchtitan_state_dict:
        hf_state_dict['model.embed_tokens.weight'] = torchtitan_state_dict['tok_embeddings.weight'].clone()
    if 'output.weight' in torchtitan_state_dict:
        hf_state_dict['lm_head.weight'] = torchtitan_state_dict['output.weight'].clone()
    if 'norm.weight' in torchtitan_state_dict:
        hf_state_dict['model.norm.weight'] = torchtitan_state_dict['norm.weight'].clone()

    # Convert layers
    for layer_idx in tqdm(range(n_layers), desc="Converting layers"):
        layer_prefix = f'layers.{layer_idx}'
        hf_layer_prefix = f'model.layers.{layer_idx}'

        # Attention weights with proper permutation
        if f'{layer_prefix}.attention.wq.weight' in torchtitan_state_dict:
            wq = torchtitan_state_dict[f'{layer_prefix}.attention.wq.weight']
            hf_state_dict[f'{hf_layer_prefix}.self_attn.q_proj.weight'] = permute(wq, n_heads)
        if f'{layer_prefix}.attention.wk.weight' in torchtitan_state_dict:
            wk = torchtitan_state_dict[f'{layer_prefix}.attention.wk.weight']
            key_value_dim = n_kv_heads * head_dim
            hf_state_dict[f'{hf_layer_prefix}.self_attn.k_proj.weight'] = permute(
                wk, n_kv_heads, key_value_dim, dim
            )
        if f'{layer_prefix}.attention.wv.weight' in torchtitan_state_dict:
            # Value weights don't get permuted
            hf_state_dict[f'{hf_layer_prefix}.self_attn.v_proj.weight'] = torchtitan_state_dict[f'{layer_prefix}.attention.wv.weight'].clone()
        if f'{layer_prefix}.attention.wo.weight' in torchtitan_state_dict:
            # Output projection doesn't get permuted
            hf_state_dict[f'{hf_layer_prefix}.self_attn.o_proj.weight'] = torchtitan_state_dict[f'{layer_prefix}.attention.wo.weight'].clone()

        # MLP weights (no permutation)
        mlp_mappings = {
            f'{layer_prefix}.feed_forward.w1.weight': f'{hf_layer_prefix}.mlp.gate_proj.weight',
            f'{layer_prefix}.feed_forward.w2.weight': f'{hf_layer_prefix}.mlp.down_proj.weight',
            f'{layer_prefix}.feed_forward.w3.weight': f'{hf_layer_prefix}.mlp.up_proj.weight',
        }
        for tt_key, hf_key in mlp_mappings.items():
            if tt_key in torchtitan_state_dict:
                hf_state_dict[hf_key] = torchtitan_state_dict[tt_key].clone()

        # Layer norms (no permutation)
        norm_mappings = {
            f'{layer_prefix}.attention_norm.weight': f'{hf_layer_prefix}.input_layernorm.weight',
            f'{layer_prefix}.ffn_norm.weight': f'{hf_layer_prefix}.post_attention_layernorm.weight',
        }
        for tt_key, hf_key in norm_mappings.items():
            if tt_key in torchtitan_state_dict:
                hf_state_dict[hf_key] = torchtitan_state_dict[tt_key].clone()

    print(f"Converted {len(hf_state_dict)} parameters from TorchTitan to HuggingFace format")
    return hf_state_dict

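# Note: head_dim above is inferred with a heuristic (128 or 64). Since map_hf_to_torchtitan also
# stores a "_model_config" entry in the checkpoint, a sketch of a more robust alternative (not
# wired in here) would be:
#
#   cfg = torchtitan_state_dict.get("_model_config")
#   if cfg is not None:
#       n_heads = cfg["num_attention_heads"]
#       head_dim = cfg["hidden_size"] // n_heads
#       n_kv_heads = cfg.get("num_key_value_heads", n_heads)
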
@app.command(name="hf_to_dcp")
@torch.inference_mode()
def convert_hf_to_dcp(input_path: str, output_path: Path, max_seq_len: int = 131072, rope_theta: float = 500000.0, dtype: str = "float32"):
torch_dtype = getattr(torch, dtype)
"""Convert HuggingFace model to TorchTitan DCP format.
Args:
input_path: HuggingFace model name or path
output_path: Output DCP checkpoint path
max_seq_len: Max sequence length for RoPE
rope_theta: RoPE theta parameter
"""
logger.info(f"Loading model from {input_path}")
hf_model = AutoModelForCausalLM.from_pretrained(input_path, torch_dtype=torch_dtype)
hf_state_dict = hf_model.state_dict()
logger.info("Converting weights to TorchTitan format")
torchtitan_state_dict = map_hf_to_torchtitan(hf_state_dict, hf_model.config, max_seq_len, rope_theta)
logger.info(f"Writing to DCP at '{output_path}'")
output_path.mkdir(parents=True, exist_ok=True)
storage_writer = DCP.filesystem.FileSystemWriter(output_path, thread_count=8)
DCP.save({"model": torchtitan_state_dict}, storage_writer=storage_writer)
logger.info("Conversion complete!")
@app.command(name="dcp_to_hf")
@torch.inference_mode()
def convert_dcp_to_hf(input_path: Path, output_path: Path, max_seq_len: int = 131072, rope_theta: float = 500000.0, default_model: str = "meta-llama/Meta-Llama-3.1-8B"):
"""Convert TorchTitan DCP format to HuggingFace model.
Args:
input_path: Input DCP checkpoint path
output_path: Output HuggingFace model path
max_seq_len: Max sequence length for RoPE
rope_theta: RoPE theta parameter
default_model: Default HuggingFace model for config
"""
logger.info(f"Loading DCP checkpoint from {input_path}")
from torch.distributed.checkpoint.default_planner import _EmptyStateDictLoadPlanner
from torch.distributed.checkpoint.state_dict_loader import _load_state_dict
# Load DCP input_path
state_dict = {}
_load_state_dict(
state_dict,
storage_reader=DCP.filesystem.FileSystemReader(input_path),
planner=_EmptyStateDictLoadPlanner(),
no_dist=True,
)
torchtitan_state_dict = state_dict["model"]
logger.info("Converting weights to HuggingFace format")
hf_state_dict = map_torchtitan_to_hf(torchtitan_state_dict, max_seq_len, rope_theta)
# Create HuggingFace config
hf_config = LlamaConfig.from_pretrained(default_model)
# Create and load model
logger.info("Creating HuggingFace model")
tokenizer = AutoTokenizer.from_pretrained(default_model)
hf_model = AutoModelForCausalLM.from_pretrained(default_model)
# load state dict
logger.info("Loading state dict")
hf_model.load_state_dict(hf_state_dict, strict=True)
# Save model
logger.info(f"Saving model to {output_path}")
output_path.mkdir(parents=True, exist_ok=True)
hf_model.save_pretrained(output_path)
tokenizer.save_pretrained(output_path)
logger.info("Conversion complete!")
if __name__ == "__main__":
    init_logger()
    app.cli()