Created
February 25, 2026 10:07
-
-
Save alexlitz/98afb7d6573d1a42f2b0a2cbbdc0304f to your computer and use it in GitHub Desktop.
Tiny Adder
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| #!/usr/bin/env python3 | |
| """ | |
| TinyAdder: A hand-crafted 95-parameter transformer that performs 10-digit addition with ~100% accuracy. | |
| Only non-zero parameters are counted, so the nominal number of parameters is higher but most are zero. | |
| Architecture: | |
| - 2-layer transformer with ALiBi positional encoding | |
| - Layer 0: 5 attention heads | |
| - Layer 1: 1 head uniform attention | |
| - Embeddings: 5 dimensions [eq_flag, special_flag, digit_value, 0, 0] | |
| Example: "1234567890+9876543210=11111111100" | |
| """ | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from math import log, exp | |
| import random | |
# === Constants ===
NUM_DIGITS = 10
NUM_INPUT_DIGITS = 10
# Vocabulary: digits "0"-"9" get ids 0-9, then "=" (10), "<bos>" (11), "<eos>" (12), "+" (13).
tokens = [str(i) for i in range(NUM_DIGITS)] + ["=", "<bos>", "<eos>", "+"]
# Position indices
# Sequence layout (see test()): <bos> D*10 "+" D*10 "=" D*11 <eos> -> positions 0..34.
POS_NUM1_START, POS_NUM2_START = 1, 12
POS_EQUALS = 22
POS_ANS_TOKEN_START = 23
# test() reads the 11 predicted answer digits from positions 22..32 —
# one position *before* each answer token (next-token-style readout).
POS_ANS_OUTPUT_START = 22
POS_ANS_OUTPUT_END = 33
# Scaling
DIGIT_EMBED_SCALE = 10  # Scale for digit values in embedding
V_SCALE = 1e4  # Amplification for V-shaped error function
DIGIT_SCALE = 1e10
FINAL_SCALE = 100
DIGIT_OFFSET = 0.5
GATE_BIAS_SHIFT = 15.0
ALIBI_CONSTANT = log(10)  # gives 10x decay per position
# Embedding indices (the 5 channels of each token embedding)
EQ_DIM, SPECIAL_DIM, DIGIT_DIM, COUNT_DIM, SCALE_DIM = 0, 1, 2, 3, 4
EMBEDDING_DIM = 5
# Attention configuration
LAYER0_HEADS = EMBEDDING_DIM  # 5 heads, one per embedding dim
LAYER1_HEADS = 1
ADJUSTMENT_HEAD = 3
SCALE_HEAD = 4
# FFN dimensions
LAYER0_FFN_HIDDEN = NUM_DIGITS + 1  # 11: 10 candidates + digit_pos
LAYER1_FFN_HIDDEN = NUM_DIGITS * 2  # 20
# Dimension layout of the layer-1 residual stream (non-overlapping):
# Dims 0-4: Embedding values
# Dims 5-14: Candidate errors (from FFN)
# Dim 15: Digit position value (from FFN)
FFN_OUTPUT_START = EMBEDDING_DIM  # 5
CANDIDATES_START = FFN_OUTPUT_START  # 5
CANDIDATES_END = CANDIDATES_START + NUM_DIGITS  # 15
DIGIT_POS_DIM = CANDIDATES_END  # 15
LAYER1_D_MODEL = DIGIT_POS_DIM + 1  # 16
# Attention score constants (raw K scores; special tokens sit K_SPECIAL_SCORE, digits far below)
K_DIGIT_SCORE = -1000.0
K_SPECIAL_BASE = -40.0
K_SPECIAL_SCORE = K_SPECIAL_BASE
# V projection values (absorbs /= DIGIT_EMBED_SCALE scaling for input positions)
V_PROJ_SPECIAL = 0.1  # V = 0.1 for <bos>/+ (input positions attend here)
V_PROJ_NEG_DOUBLE = -1.1  # V = 0.1 + (-1.1) = -1.0 for '=' (answer positions attend here)
V_PROJ_SCALE = exp(K_SPECIAL_SCORE - log(10))  # ≈ 9.357e-15
# Embeddings: [eq_flag, special_flag, digit*DIGIT_EMBED_SCALE, count, scale]
embeddings = {
    **{i: [0, 0, i * DIGIT_EMBED_SCALE, 0, 0] for i in range(NUM_DIGITS)},  # digits
    10: [1, 1, 0, 0, 0],  # "="
    11: [0, 1, 0, 0, 0],  # "<bos>"
    12: [0, 0, 0, 0, 0],  # "<eos>"
    13: [0, 1, 0, 0, 0],  # "+"
}
def tokenize(expr):
    """Split *expr* into single-character tokens wrapped in <bos>/<eos> sentinels."""
    return ["<bos>", *expr, "<eos>"]
def softmax1(x, dim=-1):
    """Softmax variant with an extra +1 in the denominator.

    The weights can sum to less than one, so a head may effectively
    attend to nothing when all scores are very negative.
    """
    weights = x.exp()
    denom = 1 + weights.sum(dim=dim, keepdim=True)
    return weights / denom
def apply_alibi(seq_len, n_heads):
    """Build the additive ALiBi bias, shape (n_heads, seq_len, seq_len).

    Entry [h, i, j] equals slope[h] * (j - i). Only ADJUSTMENT_HEAD gets a
    non-zero slope (ALIBI_CONSTANT); every other head receives a zero bias.
    """
    positions = torch.arange(seq_len)
    rel_pos = positions[None, :] - positions[:, None]  # rel_pos[i, j] = j - i
    slopes = torch.zeros(n_heads, dtype=torch.float64)
    slopes[ADJUSTMENT_HEAD] = ALIBI_CONSTANT  # Only the adjustment head uses ALiBi
    return slopes[:, None, None] * rel_pos[None, :, :]
def pad_to(x, d):
    """Return *x* with its last dimension zero-padded (or truncated) to size d."""
    width = x.size(-1)
    if width >= d:
        return x[..., :d]
    fill = x.new_zeros(*x.shape[:-1], d - width)
    return torch.cat((x, fill), dim=-1)
class TinyAdder(nn.Module):
    """Hand-constructed 2-layer transformer that performs 10-digit addition.

    Every weight is assigned analytically in __init__ (no training); most
    entries are left at zero. forward() returns, per position, the argmin
    over the final NUM_DIGITS channels — each channel holds an "error" for
    one candidate output digit, so the smallest error wins.
    """

    def __init__(self):
        super().__init__()
        # float64 throughout: the construction multiplies scales up to
        # DIGIT_SCALE * FINAL_SCALE (1e12), beyond float32 precision.
        d = torch.float64
        self.n_layers = 2
        # Embedding table: row i is the 5-dim embedding of token id i.
        self.embedding = torch.tensor([embeddings[i] for i in range(len(embeddings))], dtype=d)
        self.n_heads = [LAYER0_HEADS, LAYER1_HEADS]
        # Width of the stream entering layer i; the final entry is the output width.
        self.d_model = [EMBEDDING_DIM, LAYER1_D_MODEL, NUM_DIGITS]
        self.ffn_hidden = [LAYER0_FFN_HIDDEN, LAYER1_FFN_HIDDEN]
        self.use_alibi = [True, False]  # Only layer 0 uses ALiBi
        # Q/K/V each project d_model -> n_heads, i.e. one scalar per head.
        self.q_proj = nn.ModuleList([
            nn.Linear(self.d_model[i], self.n_heads[i], dtype=d) for i in range(self.n_layers)
        ])
        self.k_proj = nn.ModuleList([
            nn.Linear(self.d_model[i], self.n_heads[i], dtype=d) for i in range(self.n_layers)
        ])
        self.v_proj = nn.ModuleList([
            nn.Linear(self.d_model[i], self.n_heads[i], dtype=d) for i in range(self.n_layers)
        ])
        # Gated FFN: down(relu(gate(h)) * up(h)); `down` also widens the
        # stream from d_model[i] to d_model[i+1].
        self.gate = nn.ModuleList([
            nn.Linear(self.d_model[i], self.ffn_hidden[i], dtype=d) for i in range(self.n_layers)
        ])
        self.up = nn.ModuleList([
            nn.Linear(self.d_model[i], self.ffn_hidden[i], dtype=d) for i in range(self.n_layers)
        ])
        self.down = nn.ModuleList([
            nn.Linear(self.ffn_hidden[i], self.d_model[i+1], bias=False, dtype=d) for i in range(self.n_layers)
        ])
        # Initialize all parameters to zero
        for module in [self.q_proj, self.k_proj, self.v_proj, self.gate, self.up, self.down]:
            for layer_module in module:
                if hasattr(layer_module, 'weight') and layer_module.weight is not None:
                    layer_module.weight.data.zero_()
                if hasattr(layer_module, 'bias') and layer_module.bias is not None:
                    layer_module.bias.data.zero_()
        # NOTE: a size-1 bias broadcasts across all 5 heads inside F.linear,
        # giving every layer-0 head a constant query of 1.
        self.q_proj[0].bias.data = torch.ones(1, dtype=d)
        # K scores for the adjustment head: digits sit at K_DIGIT_SCORE; tokens
        # with special_flag=1 (<bos>, +, =) are lifted to K_SPECIAL_SCORE.
        self.k_proj[0].weight.data[ADJUSTMENT_HEAD, SPECIAL_DIM] = K_SPECIAL_SCORE - K_DIGIT_SCORE
        self.k_proj[0].bias.data[ADJUSTMENT_HEAD] = K_DIGIT_SCORE
        # V values for the adjustment head, pre-divided by V_PROJ_SCALE to
        # cancel the tiny attention weights produced by those K scores.
        self.v_proj[0].weight.data[ADJUSTMENT_HEAD, SPECIAL_DIM] = V_PROJ_SPECIAL / V_PROJ_SCALE
        self.v_proj[0].weight.data[ADJUSTMENT_HEAD, EQ_DIM] = V_PROJ_NEG_DOUBLE / V_PROJ_SCALE
        self.v_proj[0].weight.data[SCALE_HEAD, EQ_DIM] = 1.0
        # Layer-1 V reads the digit-position value (dim 15) computed by layer 0's FFN.
        self.v_proj[1].weight.data[0, DIGIT_POS_DIM] = FINAL_SCALE
        self.v_proj[1].bias.data[0] = GATE_BIAS_SHIFT
        # Candidate targets: (k + 0.5) * DIGIT_SCALE * FINAL_SCALE for k = 0..9.
        possible_values = (torch.arange(NUM_DIGITS).double() + DIGIT_OFFSET) * DIGIT_SCALE * FINAL_SCALE
        self.gate[0].weight.data[NUM_DIGITS, DIGIT_DIM] = 1
        self.gate[0].weight.data[:NUM_DIGITS, SCALE_DIM] = 1
        self.up[0].weight.data[NUM_DIGITS, COUNT_DIM] = DIGIT_SCALE
        self.up[0].weight.data[:NUM_DIGITS, COUNT_DIM] = possible_values
        # down[0] copies the 11 FFN-hidden channels unchanged into dims 5..15
        # of the widened layer-1 stream (identity block).
        self.down[0].weight.data[FFN_OUTPUT_START:FFN_OUTPUT_START + LAYER0_FFN_HIDDEN, range(LAYER0_FFN_HIDDEN)] = torch.eye(LAYER0_FFN_HIDDEN, dtype=d)
        # Layer-1 gate: +V_SCALE and -V_SCALE copies of each candidate channel;
        # after ReLU and the summing down[1] below this yields V_SCALE * |error|
        # per candidate digit (the "V-shaped error function").
        self.gate[1].weight.data[:NUM_DIGITS, CANDIDATES_START:CANDIDATES_END] = torch.eye(NUM_DIGITS, dtype=d) * V_SCALE
        self.gate[1].weight.data[NUM_DIGITS:, CANDIDATES_START:CANDIDATES_END] = torch.eye(NUM_DIGITS, dtype=d) * -V_SCALE
        self.up[1].bias.data = torch.ones(1, dtype=d) * FINAL_SCALE  # Broadcast
        # down[1] adds the positive and negative ReLU branches back together.
        self.down[1].weight.data[:, :NUM_DIGITS] = torch.eye(NUM_DIGITS, dtype=d)
        self.down[1].weight.data[:, NUM_DIGITS:] = torch.eye(NUM_DIGITS, dtype=d)

    def forward(self, x):
        """Map token ids (batch, seq) to per-position predicted digit ids.

        Returns a (batch, seq) LongTensor; only the answer positions
        (POS_ANS_OUTPUT_START..POS_ANS_OUTPUT_END-1) are meaningful — see test().
        """
        batch_size, seq_len = x.shape
        h = self.embedding[x]  # (batch, seq, EMBEDDING_DIM)
        for layer in range(self.n_layers):
            h = pad_to(h, self.d_model[layer])
            n_heads = self.n_heads[layer]
            # Attention (head dim is 1: each head works with a single scalar)
            q = self.q_proj[layer](h).view(batch_size, seq_len, n_heads, 1).transpose(1, 2)
            k = self.k_proj[layer](h).view(batch_size, seq_len, n_heads, 1).transpose(1, 2)
            v = self.v_proj[layer](h).view(batch_size, seq_len, n_heads, 1).transpose(1, 2)
            scores = torch.matmul(q, k.transpose(-2, -1))
            if self.use_alibi[layer]:
                scores = scores + apply_alibi(seq_len, n_heads).unsqueeze(0)
            # Causal mask: a position may only attend to itself and earlier positions.
            scores = scores.masked_fill(torch.triu(torch.ones(seq_len, seq_len), 1).bool(), float('-inf'))
            # softmax1 (not softmax): weights may sum below 1, so heads with
            # uniformly tiny scores contribute ~nothing to the residual.
            attn = softmax1(scores, dim=-1).double()
            attn_out = torch.matmul(attn, v).transpose(1, 2).contiguous().view(batch_size, seq_len, -1)
            h = h + attn_out
            # FFN (gated): down(relu(gate(h)) * up(h))
            gate_out = F.relu(self.gate[layer](h))
            ffn_out = self.down[layer](gate_out * self.up[layer](h))
            # Widen (or truncate) the residual stream to the next layer's width
            # before adding the FFN output, which is already that width.
            h = pad_to(h, self.d_model[layer + 1])
            h = h + ffn_out
        # Final channels hold per-candidate error scores; smallest error wins.
        return h.argmin(dim=-1)
@torch.inference_mode()
def test(model, n_tests=10000):
    """Test accuracy across digit ranges."""
    print("=" * 50)
    print("Testing 10-digit addition")
    print("=" * 50)
    for n_digits in range(8, 11):
        limit = 10**n_digits - 1
        correct = 0
        for _ in range(n_tests):
            a = random.randint(0, limit)
            b = random.randint(0, limit)
            s = a + b
            # Operands are always zero-padded to 10 digits, the sum to 11.
            expr = f"{a:010d}+{b:010d}={s:011d}"
            ids = [tokens.index(t) for t in tokenize(expr)]
            pred = model(torch.tensor(ids).unsqueeze(0))
            # Read the 11 predicted answer digits from the output positions.
            pred_str = ''.join(
                str(pred[0, p].item())
                for p in range(POS_ANS_OUTPUT_START, POS_ANS_OUTPUT_END)
            )
            correct += pred_str == f"{s:011d}"
        print(f"{n_digits:2d} digits: {correct}/{n_tests} = {100*correct//n_tests}%")
if __name__ == "__main__":
    random.seed(14)
    torch.manual_seed(14)
    model = TinyAdder()
    # Count non-zero parameters (excluding zero-initialized weights),
    # including the (non-nn.Parameter) embedding table.
    nonzero = int((model.embedding != 0).sum())
    for param in model.parameters():
        nonzero += int((param != 0).sum())
    print(f"Non-zero parameters: {nonzero}")
    print()
    test(model)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment