Tiny Adder
#!/usr/bin/env python3
"""
TinyAdder: A hand-crafted 95-parameter transformer that performs 10-digit addition with ~100% accuracy.
Only non-zero parameters are counted, so the nominal number of parameters is higher but most are zero.
Architecture:
- 2-layer transformer with ALiBi positional encoding
- Layer 0: 5 attention heads
- Layer 1: 1 head uniform attention
- Embeddings: 5 dimensions [eq_flag, special_flag, digit_value, 0, 0]
Example: "1234567890+9876543210=11111111100"
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from math import log, exp
import random
# === Constants ===
NUM_DIGITS = 10
NUM_INPUT_DIGITS = 10
tokens = [str(i) for i in range(NUM_DIGITS)] + ["=", "<bos>", "<eos>", "+"]
# Position indices
POS_NUM1_START, POS_NUM2_START = 1, 12
POS_EQUALS = 22
POS_ANS_TOKEN_START = 23
POS_ANS_OUTPUT_START = 22
POS_ANS_OUTPUT_END = 33
# Scaling
DIGIT_EMBED_SCALE = 10 # Scale for digit values in embedding
V_SCALE = 1e4 # Amplification for V-shaped error function
DIGIT_SCALE = 1e10
FINAL_SCALE = 100
DIGIT_OFFSET = 0.5
GATE_BIAS_SHIFT = 15.0
ALIBI_CONSTANT = log(10) # gives 10x decay per position
# Embedding indices
EQ_DIM, SPECIAL_DIM, DIGIT_DIM, COUNT_DIM, SCALE_DIM = 0, 1, 2, 3, 4
EMBEDDING_DIM = 5
# Attention configuration
LAYER0_HEADS = EMBEDDING_DIM # 5 heads, one per embedding dim
LAYER1_HEADS = 1
ADJUSTMENT_HEAD = 3
SCALE_HEAD = 4
# FFN dimensions
LAYER0_FFN_HIDDEN = NUM_DIGITS + 1 # 11: 10 candidates + digit_pos
LAYER1_FFN_HIDDEN = NUM_DIGITS * 2 # 20
# Dimension layout (non-overlapping):
# Dims 0-4: Embedding values
# Dims 5-14: Candidate errors (from FFN)
# Dim 15: Digit position value (from FFN)
FFN_OUTPUT_START = EMBEDDING_DIM # 5
CANDIDATES_START = FFN_OUTPUT_START # 5
CANDIDATES_END = CANDIDATES_START + NUM_DIGITS # 15
DIGIT_POS_DIM = CANDIDATES_END # 15
LAYER1_D_MODEL = DIGIT_POS_DIM + 1 # 16
# Attention score constants
K_DIGIT_SCORE = -1000.0
K_SPECIAL_BASE = -40.0
K_SPECIAL_SCORE = K_SPECIAL_BASE
# V projection values (absorbs /= DIGIT_EMBED_SCALE scaling for input positions)
V_PROJ_SPECIAL = 0.1 # V = 0.1 for <bos>/+ (input positions attend here)
V_PROJ_NEG_DOUBLE = -1.1 # V = 0.1 + (-1.1) = -1.0 for '=' (answer positions attend here)
V_PROJ_SCALE = exp(K_SPECIAL_SCORE - log(10)) # = exp(-40)/10 ≈ 4.25e-19
# Embeddings: [eq_flag, special_flag, digit*DIGIT_EMBED_SCALE, count, scale]
embeddings = {
    **{i: [0, 0, i * DIGIT_EMBED_SCALE, 0, 0] for i in range(NUM_DIGITS)}, # digits
    10: [1, 1, 0, 0, 0], # "="
    11: [0, 1, 0, 0, 0], # "<bos>"
    12: [0, 0, 0, 0, 0], # "<eos>"
    13: [0, 1, 0, 0, 0], # "+"
}
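# e.g. token "7" embeds to [0, 0, 70, 0, 0] and "=" to [1, 1, 0, 0, 0]; the
# count and scale dims start at 0 and are written by layer-0 attention
# (head 3 -> count, head 4 -> scale, since head i adds into residual dim i).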

def tokenize(expr):
    return ["<bos>"] + list(expr) + ["<eos>"]
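# Example: tokenize("12+3=15") -> ["<bos>", "1", "2", "+", "3", "=", "1", "5", "<eos>"]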

def softmax1(x, dim=-1):
    """Softmax with +1 in the denominator (allows attending to nothing)."""
    exp_x = x.exp()
    return exp_x / (1 + exp_x.sum(dim=dim, keepdim=True))
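# Unlike standard softmax, the weights can sum to nearly 0: a lone key with
# score -40 gets weight exp(-40) / (1 + exp(-40)) ≈ 4.25e-18, where
# F.softmax would be forced to return 1.0.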

def apply_alibi(seq_len, n_heads):
    """ALiBi positional bias for attention, scaled by ALIBI_CONSTANT."""
    pos = torch.arange(seq_len)
    rel_pos = pos.unsqueeze(0) - pos.unsqueeze(1)
    slopes = torch.zeros(n_heads, dtype=torch.float64)
    slopes[ADJUSTMENT_HEAD] = ALIBI_CONSTANT # Only the adjustment head uses ALiBi
    return slopes.unsqueeze(1).unsqueeze(2) * rel_pos.unsqueeze(0)
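# Worked example: with slope ln(10), a key k positions behind the query gets
# bias -k * ln(10), so its unnormalized attention weight is damped by 10**-k,
# i.e. one decimal place per position, matching base-10 place value.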

def pad_to(x, d):
    if x.size(-1) >= d:
        return x[..., :d]
    return torch.cat([x, torch.zeros(*x.shape[:-1], d - x.size(-1), dtype=x.dtype)], dim=-1)
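# Note that pad_to also truncates: forward() uses it to zero-pad the hidden
# width 5 -> 16 between layers and to cut 16 -> 10 before the final residual
# add (d_model = [5, 16, 10]).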

class TinyAdder(nn.Module):
    def __init__(self):
        super().__init__()
        d = torch.float64
        self.n_layers = 2
        self.embedding = torch.tensor([embeddings[i] for i in range(len(embeddings))], dtype=d)
        self.n_heads = [LAYER0_HEADS, LAYER1_HEADS]
        self.d_model = [EMBEDDING_DIM, LAYER1_D_MODEL, NUM_DIGITS]
        self.ffn_hidden = [LAYER0_FFN_HIDDEN, LAYER1_FFN_HIDDEN]
        self.use_alibi = [True, False] # Only layer 0 uses ALiBi
        self.q_proj = nn.ModuleList([
            nn.Linear(self.d_model[i], self.n_heads[i], dtype=d) for i in range(self.n_layers)
        ])
        self.k_proj = nn.ModuleList([
            nn.Linear(self.d_model[i], self.n_heads[i], dtype=d) for i in range(self.n_layers)
        ])
        self.v_proj = nn.ModuleList([
            nn.Linear(self.d_model[i], self.n_heads[i], dtype=d) for i in range(self.n_layers)
        ])
        self.gate = nn.ModuleList([
            nn.Linear(self.d_model[i], self.ffn_hidden[i], dtype=d) for i in range(self.n_layers)
        ])
        self.up = nn.ModuleList([
            nn.Linear(self.d_model[i], self.ffn_hidden[i], dtype=d) for i in range(self.n_layers)
        ])
        self.down = nn.ModuleList([
            nn.Linear(self.ffn_hidden[i], self.d_model[i + 1], bias=False, dtype=d) for i in range(self.n_layers)
        ])
        # Initialize all parameters to zero
        for module in [self.q_proj, self.k_proj, self.v_proj, self.gate, self.up, self.down]:
            for layer_module in module:
                if hasattr(layer_module, 'weight') and layer_module.weight is not None:
                    layer_module.weight.data.zero_()
                if hasattr(layer_module, 'bias') and layer_module.bias is not None:
                    layer_module.bias.data.zero_()
        self.q_proj[0].bias.data = torch.ones(1, dtype=d) # shape (1,) broadcasts across all 5 heads
        self.k_proj[0].weight.data[ADJUSTMENT_HEAD, SPECIAL_DIM] = K_SPECIAL_SCORE - K_DIGIT_SCORE
        self.k_proj[0].bias.data[ADJUSTMENT_HEAD] = K_DIGIT_SCORE
        self.v_proj[0].weight.data[ADJUSTMENT_HEAD, SPECIAL_DIM] = V_PROJ_SPECIAL / V_PROJ_SCALE
        self.v_proj[0].weight.data[ADJUSTMENT_HEAD, EQ_DIM] = V_PROJ_NEG_DOUBLE / V_PROJ_SCALE
        self.v_proj[0].weight.data[SCALE_HEAD, EQ_DIM] = 1.0
        self.v_proj[1].weight.data[0, DIGIT_POS_DIM] = FINAL_SCALE
        self.v_proj[1].bias.data[0] = GATE_BIAS_SHIFT
        possible_values = (torch.arange(NUM_DIGITS).double() + DIGIT_OFFSET) * DIGIT_SCALE * FINAL_SCALE
        self.gate[0].weight.data[NUM_DIGITS, DIGIT_DIM] = 1
        self.gate[0].weight.data[:NUM_DIGITS, SCALE_DIM] = 1
        self.up[0].weight.data[NUM_DIGITS, COUNT_DIM] = DIGIT_SCALE
        self.up[0].weight.data[:NUM_DIGITS, COUNT_DIM] = possible_values
        self.down[0].weight.data[FFN_OUTPUT_START:FFN_OUTPUT_START + LAYER0_FFN_HIDDEN, range(LAYER0_FFN_HIDDEN)] = torch.eye(LAYER0_FFN_HIDDEN, dtype=d)
        self.gate[1].weight.data[:NUM_DIGITS, CANDIDATES_START:CANDIDATES_END] = torch.eye(NUM_DIGITS, dtype=d) * V_SCALE
        self.gate[1].weight.data[NUM_DIGITS:, CANDIDATES_START:CANDIDATES_END] = torch.eye(NUM_DIGITS, dtype=d) * -V_SCALE
        self.up[1].bias.data = torch.ones(1, dtype=d) * FINAL_SCALE # shape (1,) broadcasts over all 20 hidden units
        self.down[1].weight.data[:, :NUM_DIGITS] = torch.eye(NUM_DIGITS, dtype=d)
        self.down[1].weight.data[:, NUM_DIGITS:] = torch.eye(NUM_DIGITS, dtype=d)

    def forward(self, x):
        batch_size, seq_len = x.shape
        h = self.embedding[x]
        for layer in range(self.n_layers):
            h = pad_to(h, self.d_model[layer])
            n_heads = self.n_heads[layer]
            # Attention
            q = self.q_proj[layer](h).view(batch_size, seq_len, n_heads, 1).transpose(1, 2)
            k = self.k_proj[layer](h).view(batch_size, seq_len, n_heads, 1).transpose(1, 2)
            v = self.v_proj[layer](h).view(batch_size, seq_len, n_heads, 1).transpose(1, 2)
            scores = torch.matmul(q, k.transpose(-2, -1))
            if self.use_alibi[layer]:
                scores = scores + apply_alibi(seq_len, n_heads).unsqueeze(0)
            scores = scores.masked_fill(torch.triu(torch.ones(seq_len, seq_len), 1).bool(), float('-inf'))
            attn = softmax1(scores, dim=-1).double()
            attn_out = torch.matmul(attn, v).transpose(1, 2).contiguous().view(batch_size, seq_len, -1)
            h = h + attn_out
            # FFN
            gate_out = F.relu(self.gate[layer](h))
            ffn_out = self.down[layer](gate_out * self.up[layer](h))
            h = pad_to(h, self.d_model[layer + 1])
            h = h + ffn_out
        # After layer 1, each position holds 10 candidate "errors"; the
        # smallest one indexes the predicted digit.
        return h.argmin(dim=-1)
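
# A usage sketch (this demo helper is not part of the original gist): run a
# single sum through the model and decode the predicted answer digits.
@torch.inference_mode()
def demo(model, a=1234567890, b=9876543210):
    s = a + b
    expr = f"{a:010d}+{b:010d}={s:011d}"
    x = torch.tensor([tokens.index(t) for t in tokenize(expr)]).unsqueeze(0)
    pred = model(x)
    answer = ''.join(str(pred[0, p].item()) for p in range(POS_ANS_OUTPUT_START, POS_ANS_OUTPUT_END))
    print(f"{expr} -> predicted {answer}")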

@torch.inference_mode()
def test(model, n_tests=10000):
    """Test accuracy across digit ranges."""
    print("=" * 50)
    print("Testing 10-digit addition")
    print("=" * 50)
    for n_digits in range(8, 11):
        max_val = 10**n_digits - 1
        correct = 0
        tested = 0
        while tested < n_tests:
            a = random.randint(0, max_val)
            b = random.randint(0, max_val)
            s = a + b
            expr = f"{a:010d}+{b:010d}={s:011d}"
            toks = [tokens.index(t) for t in tokenize(expr)]
            x = torch.tensor(toks).unsqueeze(0)
            pred = model(x)
            pred_str = ''.join(str(pred[0, p].item()) for p in range(POS_ANS_OUTPUT_START, POS_ANS_OUTPUT_END))
            if pred_str == f"{s:011d}":
                correct += 1
            tested += 1
        print(f"{n_digits:2d} digits: {correct}/{n_tests} = {100*correct//n_tests}%")
if __name__ == "__main__":
random.seed(14)
torch.manual_seed(14)
model = TinyAdder()
# Count non-zero parameters (excluding zero-initialized weights)
nonzero = (model.embedding != 0).sum().item()
for name, p in model.named_parameters():
nonzero += (p != 0).sum().item()
print(f"Non-zero parameters: {nonzero}")
print()
test(model)
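    # Expected output, per the docstring's claims: "Non-zero parameters: 95",
    # then ~100% accuracy for each digit count.
    # demo(model)  # optional: decode one example with the sketch above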