@vukrosic
Created August 5, 2025 10:42
Small transformer LLM
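
"""
Minimal decoder-only transformer language model ("Small transformer LLM"):
pre-norm transformer blocks with rotary attention, trained on a slice of the
SmolLM cosmopedia-v2 corpus using a hybrid Muon / AdamW optimizer setup.
"""
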
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import autocast, GradScaler
import math
import random
import numpy as np
from datasets import load_dataset
from tqdm import tqdm
import time
from transformers import AutoTokenizer
from dataclasses import dataclass
from typing import List, Optional
import warnings
import os
import pickle
warnings.filterwarnings('ignore')

def set_seed(seed: int = 42):
    """Set all random seeds for reproducibility"""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    print(f"🌱 Set all seeds to {seed}")

@dataclass
class ModelConfig:
    # Model architecture
    d_model: int = 384
    n_heads: int = 8
    n_layers: int = 6
    d_ff: int = 1536

    # Training parameters
    batch_size: int = 24
    max_steps: int = 5000
    gradient_accumulation_steps: int = 4
    muon_lr: float = 0.01

    # Data parameters
    max_seq_len: int = 512
    num_documents: int = 2000
    max_tokens: int = 500000

    # Evaluation
    eval_every: int = 500
    eval_steps: int = 100

    # Regularization
    weight_decay: float = 0.1
    dropout: float = 0.1
    grad_clip: float = 1.0

    # Technical
    use_amp: bool = True
    vocab_size: Optional[int] = None

    def __post_init__(self):
        self.d_k = self.d_model // self.n_heads
        assert self.d_model % self.n_heads == 0, "d_model must be divisible by n_heads"
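
# Newton-Schulz orthogonalization: the iteration below approximately maps a
# gradient matrix G to its nearest (semi-)orthogonal matrix (the U V^T factor of
# its SVD). Muon applies it to the momentum buffer before each update.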

@torch.compile
def zeropower_via_newtonschulz5(G: torch.Tensor, steps: int = 5) -> torch.Tensor:
    """Newton-Schulz iteration to compute the zeroth power / orthogonalization of G."""
    assert G.ndim >= 2
    a, b, c = (3.4445, -4.7750, 2.0315)
    X = G.bfloat16()
    if G.size(-2) > G.size(-1):
        X = X.mT
    X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)
    for _ in range(steps):
        A = X @ X.mT
        B = b * A + c * A @ A
        X = a * X + B @ X
    if G.size(-2) > G.size(-1):
        X = X.mT
    return X

class Muon(torch.optim.Optimizer):
    """Muon - MomentUm Orthogonalized by Newton-schulz"""

    def __init__(self, params, lr=0.02, momentum=0.95, nesterov=True, ns_steps=5):
        defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov, ns_steps=ns_steps)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self):
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                g = p.grad
                state = self.state[p]
                if "momentum_buffer" not in state:
                    state["momentum_buffer"] = torch.zeros_like(g)
                buf = state["momentum_buffer"]
                buf.lerp_(g, 1 - group["momentum"])
                g = g.lerp_(buf, group["momentum"]) if group["nesterov"] else buf
                g = zeropower_via_newtonschulz5(g, steps=group["ns_steps"])
                # Aspect-ratio scaling: multiply lr by sqrt(max(1, rows / cols)) so
                # tall matrices get proportionally larger updates
                p.add_(g.view_as(p), alpha=-group["lr"] * max(1, p.size(-2) / p.size(-1))**0.5)
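
# Data pipeline: stream a slice of the cosmopedia-v2 corpus, tokenize it with the
# SmolLM tokenizer, and pickle the result so later runs skip the download and
# tokenization steps.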

def load_and_cache_data(config: ModelConfig, cache_dir: str = "data_cache"):
    """Load and cache tokenized data to avoid reprocessing"""
    os.makedirs(cache_dir, exist_ok=True)
    cache_file = f"{cache_dir}/tokenized_data_{config.num_documents}_{config.max_tokens}.pkl"

    # Check if cached data exists
    if os.path.exists(cache_file):
        print(f"📦 Loading cached data from {cache_file}")
        with open(cache_file, 'rb') as f:
            cached_data = pickle.load(f)
        texts = cached_data['texts']
        tokenizer = cached_data['tokenizer']
        tokens = cached_data['tokens']
        config.vocab_size = tokenizer.vocab_size
        print(f"✅ Loaded {len(texts)} documents, {len(tokens):,} tokens from cache")
        return texts, tokenizer, tokens

    print(f"🔄 Processing new data (will cache for future use)")

    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM-135M", token=False)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Load dataset
    dataset = load_dataset("HuggingFaceTB/smollm-corpus", "cosmopedia-v2", split="train", streaming=True, token=False)

    texts = []
    for i, item in enumerate(dataset):
        if i >= config.num_documents:
            break
        texts.append(item["text"][:3000])
    print(f"Loaded {len(texts)} documents")

    # Tokenize
    print("Tokenizing texts...")
    all_tokens = []
    for text in tqdm(texts, desc="Tokenizing"):
        tokens = tokenizer.encode(text, add_special_tokens=False)
        all_tokens.extend(tokens)

    tokens = all_tokens[:config.max_tokens]
    print(f"Using {len(tokens):,} tokens")
    config.vocab_size = tokenizer.vocab_size

    # Cache the processed data
    cached_data = {'texts': texts, 'tokenizer': tokenizer, 'tokens': tokens}
    with open(cache_file, 'wb') as f:
        pickle.dump(cached_data, f)
    print(f"💾 Cached data to {cache_file}")

    return texts, tokenizer, tokens

class TextTokenDataset(Dataset):
    def __init__(self, tokens: List[int], seq_len: int = 512):
        self.tokens = tokens
        self.seq_len = seq_len

    def __len__(self):
        return max(0, len(self.tokens) - self.seq_len)

    def __getitem__(self, idx):
        x = torch.tensor(self.tokens[idx:idx + self.seq_len], dtype=torch.long)
        y = torch.tensor(self.tokens[idx + 1:idx + self.seq_len + 1], dtype=torch.long)
        return x, y
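
# Rotary position embeddings (RoPE). Half of the frequencies are zeroed out, so
# only half of each head's dimensions are actually rotated; the rest pass through
# unchanged. Expects inputs laid out as (batch, seq_len, n_heads, head_dim).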

class Rotary(nn.Module):
    def __init__(self, dim: int, max_seq_len: int):
        super().__init__()
        angular_freq = (1 / 10000) ** torch.linspace(0, 1, steps=dim // 4, dtype=torch.float32)
        angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(dim // 4)])
        t = torch.arange(max_seq_len, dtype=torch.float32)
        theta = torch.einsum("i,j -> ij", t, angular_freq)
        self.register_buffer('cos', theta.cos(), persistent=False)
        self.register_buffer('sin', theta.sin(), persistent=False)

    def forward(self, x_BTHD: torch.Tensor):
        assert self.cos.size(0) >= x_BTHD.size(-3)
        cos, sin = self.cos[None, :x_BTHD.size(-3), None, :], self.sin[None, :x_BTHD.size(-3), None, :]
        x1, x2 = x_BTHD.to(dtype=torch.float32).chunk(2, dim=-1)
        y1 = x1 * cos + x2 * sin
        y2 = x1 * (-sin) + x2 * cos
        return torch.cat((y1, y2), 3).type_as(x_BTHD)
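
# Causal self-attention with a fused QKV projection. The attention itself is
# computed by F.scaled_dot_product_attention with is_causal=True, so no explicit
# attention mask is materialized.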

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model: int, n_heads: int, max_seq_len: int, dropout: float = 0.1):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads

        self.qkv = nn.Linear(d_model, d_model * 3, bias=False)
        self.w_o = nn.Linear(d_model, d_model, bias=False)
        self.rotary = Rotary(self.d_k, max_seq_len)
        self.dropout = dropout

    def forward(self, x):
        batch_size, seq_len = x.size(0), x.size(1)

        qkv = self.qkv(x).reshape(batch_size, seq_len, 3, self.n_heads, self.d_k)
        qkv = qkv.permute(2, 0, 1, 3, 4)  # (3, B, T, n_heads, d_k)
        Q, K, V = qkv[0], qkv[1], qkv[2]

        # Rotary expects (B, T, H, D), so apply it before moving heads ahead of time
        Q = self.rotary(Q)
        K = self.rotary(K)

        # scaled_dot_product_attention expects (B, H, T, D)
        Q, K, V = Q.transpose(1, 2), K.transpose(1, 2), V.transpose(1, 2)
        attn_output = F.scaled_dot_product_attention(
            Q, K, V, is_causal=True, dropout_p=self.dropout if self.training else 0.0
        )
        attn_output = attn_output.transpose(1, 2).reshape(batch_size, seq_len, self.d_model)
        return self.w_o(attn_output)

class FeedForward(nn.Module):
    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff, bias=False)
        self.linear2 = nn.Linear(d_ff, d_model, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        return self.linear2(self.dropout(F.silu(self.linear1(x))))

class TransformerBlock(nn.Module):
    def __init__(self, d_model: int, n_heads: int, d_ff: int, max_seq_len: int, dropout: float = 0.1):
        super().__init__()
        self.attention = MultiHeadAttention(d_model, n_heads, max_seq_len, dropout)
        self.feed_forward = FeedForward(d_model, d_ff, dropout)
        self.norm1 = nn.RMSNorm(d_model)
        self.norm2 = nn.RMSNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        attn_out = self.attention(self.norm1(x))
        x = x + self.dropout(attn_out)
        ff_out = self.feed_forward(self.norm2(x))
        x = x + self.dropout(ff_out)
        return x
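
# Full model: token embedding -> n_layers pre-norm transformer blocks -> RMSNorm
# -> linear head. The output head shares its weight matrix with the embedding.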

class MinimalLLM(nn.Module):
    def __init__(self, config: ModelConfig):
        super().__init__()
        self.config = config

        self.token_embedding = nn.Embedding(config.vocab_size, config.d_model)
        self.position_dropout = nn.Dropout(config.dropout)
        self.transformer_blocks = nn.ModuleList([
            TransformerBlock(config.d_model, config.n_heads, config.d_ff, config.max_seq_len, config.dropout)
            for _ in range(config.n_layers)
        ])
        self.norm = nn.RMSNorm(config.d_model)
        self.output_dropout = nn.Dropout(config.dropout)

        # Tie weights
        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
        self.lm_head.weight = self.token_embedding.weight

        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, x):
        x = self.token_embedding(x) * math.sqrt(self.config.d_model)
        x = self.position_dropout(x)

        for block in self.transformer_blocks:
            x = block(x)

        x = self.norm(x)
        x = self.output_dropout(x)
        logits = self.lm_head(x)
        return logits

def evaluate_model(model: nn.Module, val_loader: DataLoader, config: ModelConfig):
    """Evaluate model performance"""
    model.eval()
    total_loss = 0
    total_tokens = 0
    total_correct = 0
    device = next(model.parameters()).device

    with torch.no_grad():
        for i, (x, y) in enumerate(val_loader):
            if i >= config.eval_steps:
                break
            x, y = x.to(device), y.to(device)

            with autocast(enabled=config.use_amp):
                logits = model(x)
                loss = F.cross_entropy(logits.view(-1, config.vocab_size), y.view(-1))

            total_loss += loss.item() * y.numel()
            total_tokens += y.numel()

            predictions = logits.argmax(dim=-1)
            total_correct += (predictions == y).sum().item()

    avg_loss = total_loss / total_tokens
    accuracy = total_correct / total_tokens
    perplexity = math.exp(min(avg_loss, 20))

    model.train()
    return {'val_loss': avg_loss, 'val_accuracy': accuracy, 'val_perplexity': perplexity}
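
# Hybrid optimizer split: 2-D weight matrices inside the transformer blocks are
# updated with Muon, while embeddings, norms, and all other parameters fall back
# to AdamW at one tenth of the Muon learning rate.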

def setup_muon_optimizer(model: nn.Module, config: ModelConfig):
    """Setup Muon optimizer with hybrid approach"""
    muon_params = []
    adamw_params = []

    for name, param in model.named_parameters():
        if (param.ndim == 2 and
                'token_embedding' not in name and
                'norm' not in name and
                param.requires_grad):
            muon_params.append(param)
        else:
            adamw_params.append(param)

    print(f"  Muon parameters: {sum(p.numel() for p in muon_params):,}")
    print(f"  AdamW parameters: {sum(p.numel() for p in adamw_params):,}")

    muon_optimizer = Muon(muon_params, lr=config.muon_lr, momentum=0.95)
    adamw_optimizer = torch.optim.AdamW(adamw_params, lr=config.muon_lr * 0.1, weight_decay=config.weight_decay)
    return [muon_optimizer, adamw_optimizer]

def train_model(config: ModelConfig, train_loader: DataLoader, val_loader: DataLoader):
    """Train the model with Muon optimizer"""
    print(f"\n🚀 Training Small model with Muon optimizer")

    # Initialize model
    set_seed(42)
    model = MinimalLLM(config)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)

    total_params = sum(p.numel() for p in model.parameters())
    print(f"  📊 Total parameters: {total_params:,}")

    # Setup optimizers
    optimizers = setup_muon_optimizer(model, config)

    # Learning rate schedule: linear warmup, then cosine decay to 10% of the peak
    schedulers = []
    for optimizer in optimizers:
        warmup_steps = config.max_steps // 20

        def lr_lambda(step):
            if step < warmup_steps:
                return step / warmup_steps
            else:
                progress = (step - warmup_steps) / (config.max_steps - warmup_steps)
                return 0.1 + 0.9 * 0.5 * (1 + math.cos(math.pi * progress))

        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
        schedulers.append(scheduler)

    scaler = GradScaler() if config.use_amp else None

    # Training loop
    model.train()
    step = 0
    start_time = time.time()
    best_val_loss = float('inf')

    pbar = tqdm(total=config.max_steps, desc="Training")

    while step < config.max_steps:
        for batch_idx, (x, y) in enumerate(train_loader):
            if step >= config.max_steps:
                break
            x, y = x.to(device), y.to(device)

            # Forward pass with gradient accumulation
            if config.use_amp:
                with autocast():
                    logits = model(x)
                    loss = F.cross_entropy(logits.view(-1, config.vocab_size), y.view(-1))
                    loss = loss / config.gradient_accumulation_steps
                scaler.scale(loss).backward()
            else:
                logits = model(x)
                loss = F.cross_entropy(logits.view(-1, config.vocab_size), y.view(-1))
                loss = loss / config.gradient_accumulation_steps
                loss.backward()

            # Optimizer step after accumulation
            if (step + 1) % config.gradient_accumulation_steps == 0:
                if config.use_amp:
                    for optimizer in optimizers:
                        scaler.unscale_(optimizer)
                    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_clip)

                    for optimizer in optimizers:
                        scaler.step(optimizer)
                        optimizer.zero_grad()
                    for scheduler in schedulers:
                        scheduler.step()
                    scaler.update()
                else:
                    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_clip)
                    for optimizer in optimizers:
                        optimizer.step()
                        optimizer.zero_grad()
                    for scheduler in schedulers:
                        scheduler.step()

            # Logging
            if step % 100 == 0:
                with torch.no_grad():
                    predictions = logits.argmax(dim=-1)
                    accuracy = (predictions == y).float().mean().item()
                    current_loss = loss.item() * config.gradient_accumulation_steps
                    perplexity = math.exp(min(current_loss, 20))

                pbar.set_postfix({
                    'loss': f'{current_loss:.4f}',
                    'acc': f'{accuracy:.3f}',
                    'ppl': f'{perplexity:.1f}',
                    'lr': f'{optimizers[0].param_groups[0]["lr"]:.2e}'
                })

            # Evaluation
            if step % config.eval_every == 0 and step > 0:
                eval_metrics = evaluate_model(model, val_loader, config)
                print(f"\nStep {step}: Val Loss: {eval_metrics['val_loss']:.4f}, "
                      f"Val Acc: {eval_metrics['val_accuracy']:.4f}, "
                      f"Val PPL: {eval_metrics['val_perplexity']:.2f}")

                if eval_metrics['val_loss'] < best_val_loss:
                    best_val_loss = eval_metrics['val_loss']

            step += 1
            if step % 100 == 0:
                pbar.update(100)

    pbar.close()

    training_time = time.time() - start_time
    print(f"  ⏱️ Training completed in {training_time:.1f} seconds")

    # Final evaluation
    final_eval = evaluate_model(model, val_loader, config)
    print(f"  📊 Final - Loss: {final_eval['val_loss']:.4f}, "
          f"Acc: {final_eval['val_accuracy']:.4f}, PPL: {final_eval['val_perplexity']:.2f}")

    return model, final_eval

if __name__ == "__main__":
    # Check system
    print(f"🔍 Device: {'CUDA' if torch.cuda.is_available() else 'CPU'}")
    if torch.cuda.is_available():
        print(f"GPU: {torch.cuda.get_device_name()}")
        print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

    # Set seed
    set_seed(42)

    # Create config for Small model
    config = ModelConfig()
    print(f"\n📋 Model Configuration:")
    print(f"  Architecture: {config.d_model}d, {config.n_layers}L, {config.n_heads}H, {config.d_ff}ff")
    print(f"  Training: {config.max_steps} steps, batch size {config.batch_size}")
    print(f"  Data: {config.max_tokens:,} tokens, seq_len {config.max_seq_len}")

    # Load data
    texts, tokenizer, tokens = load_and_cache_data(config)
    dataset = TextTokenDataset(tokens, config.max_seq_len)

    # Train/val split
    val_size = len(dataset) // 10
    train_size = len(dataset) - val_size
    train_dataset, val_dataset = torch.utils.data.random_split(
        dataset, [train_size, val_size], generator=torch.Generator().manual_seed(42)
    )
    train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True, num_workers=2)
    val_loader = DataLoader(val_dataset, batch_size=config.batch_size, shuffle=False, num_workers=2)
    print(f"📊 Dataset: {len(train_dataset)} train, {len(val_dataset)} val samples")

    # Train model
    start_time = time.time()
    model, final_metrics = train_model(config, train_loader, val_loader)
    total_time = time.time() - start_time

    print(f"\n🎉 TRAINING COMPLETED!")
    print(f"⏱️ Total time: {total_time/60:.1f} minutes")
    print(f"🏆 Final Results:")
    print(f"  Validation Loss: {final_metrics['val_loss']:.4f}")
    print(f"  Validation Accuracy: {final_metrics['val_accuracy']:.4f}")
    print(f"  Validation Perplexity: {final_metrics['val_perplexity']:.2f}")
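
    # --- Illustrative addition (not part of the original gist): a minimal greedy
    # decoding sketch to sample text from the freshly trained model. The prompt
    # string below is arbitrary; swap in any text.
    @torch.no_grad()
    def generate(model, tokenizer, config, prompt: str, max_new_tokens: int = 50) -> str:
        model.eval()
        device = next(model.parameters()).device
        ids = tokenizer.encode(prompt, add_special_tokens=False)
        for _ in range(max_new_tokens):
            # Keep the context within the model's maximum sequence length
            context = torch.tensor([ids[-config.max_seq_len:]], dtype=torch.long, device=device)
            next_id = int(model(context)[0, -1].argmax())
            ids.append(next_id)
            if next_id == tokenizer.eos_token_id:
                break
        return tokenizer.decode(ids)

    print(generate(model, tokenizer, config, "Once upon a time"))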