Last active
February 27, 2026 17:29
-
-
Save alexlitz/d05abbddb56d22ee4c0ac1563f0e0ef6 to your computer and use it in GitHub Desktop.
TinyAdder: a hand-crafted 36-parameter transformer that performs 10-digit addition
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| #!/usr/bin/env python3 | |
| """ | |
| TinyAdder: A 36-parameter hand-crafted transformer for 10-digit addition. | |
| This model adds two 10-digit numbers with 100% accuracy using only 36 unique parameters. | |
| Architecture: | |
| - 2-layer transformer with ALiBi positional encoding | |
| - Layer 0: 5 attention heads (only 2 active), ReGLU FFN | |
| - Layer 1: 1 head uniform attention, V-shaped error FFN | |
| - softmax1: softmax with +1 in denominator (can attend to nothing) | |
| Parameter counting rules: | |
| - Identity mappings (direct copy): 0 params | |
| - Broadcast (1 value to N outputs): 1 param | |
| - Distinct values: count each | |
| Breakdown (unique scalar parameters): | |
| Embedding: 13 | |
| - 9 digit value embeddings (1–9) | |
| - 4 special flags (=, <bos>, <eos>, +) | |
| Layer 0 Attention: 6 | |
| - Q bias: 1 | |
| - K weight + bias: 2 | |
| - V weights: 3 | |
| Layer 0 FFN: 12 | |
| - Gate broadcast bias: 1 | |
| - Up projection values: 11 | |
| Layer 1 Attention: 2 | |
| - V weight + bias: 2 | |
| Layer 1 FFN: 3 | |
| - ±V_SCALE gates: 2 | |
| - Up broadcast: 1 | |
| ─────────────── | |
| Total: 36 | |
| Author: Alex Litzenberger | |
| """ | |
| import torch | |
| import torch.nn.functional as F | |
| from math import log, exp | |
# === Configuration ===
NUM_DIGITS = 10  # Digits per addend; the sum is formatted with NUM_DIGITS + 1 digits
# Vocabulary: ids 0-9 are the digits, then "=" (10), "<bos>" (11), "<eos>" (12), "+" (13)
TOKENS = [str(i) for i in range(10)] + ["=", "<bos>", "<eos>", "+"]
# Scaling constants
DIGIT_EMBED_SCALE = 10  # Embedding scale for digit values
V_SCALE = 1e4  # V-shaped error amplification
DIGIT_SCALE = 1e10  # Position encoding scale
FINAL_SCALE = 100  # Output scaling
DIGIT_OFFSET = 0.5  # Offset for rounding behavior
GATE_BIAS_SHIFT = 15.0  # Bias for gate activation
ALIBI_CONSTANT = log(10)  # ALiBi slope: 10x attention decay per position
# Dimension indices into the layer-0 residual stream
EQ_DIM, SPECIAL_DIM, DIGIT_DIM, COUNT_DIM, SCALE_DIM = 0, 1, 2, 3, 4
EMBEDDING_DIM = 5  # Width of the embedding / layer-0 residual stream
LAYER0_HEADS = 5  # Layer-0 head count (only ADJUSTMENT_HEAD and SCALE_HEAD carry nonzero K/V)
ADJUSTMENT_HEAD = 3  # The only head with an ALiBi slope and a learned K projection
SCALE_HEAD = 4  # Head whose V copies the "=" flag (EQ_DIM)
CANDIDATES_START = 5  # First residual dim of the 10 digit-candidate channels
DIGIT_POS_DIM = 15  # Residual dim read by layer-1 attention's V
LAYER1_D_MODEL = 16  # Residual width after the layer-0 FFN write (dims 5:16)
# Attention score constants (pre-softmax logits; all non-positive)
K_DIGIT_SCORE = -1000.0  # Logit for digit positions: effectively never attended
K_SPECIAL_SCORE = -40.0  # Logit for special-token positions
V_PROJ_SPECIAL = 0.1
V_PROJ_NEG_DOUBLE = -1.1
# NOTE(review): V weights below are pre-divided by this factor, which appears to
# cancel the tiny softmax1 weight exp(K_SPECIAL_SCORE)/10 — verify against v0_w1/v0_w2.
V_PROJ_SCALE = exp(K_SPECIAL_SCORE - log(10))
def softmax1(x, dim=-1):
    """Softmax with +1 in the denominator - allows attending to nothing.

    Computes exp(x_i) / (1 + sum_j exp(x_j)) along `dim`, but with the
    standard max-subtraction trick so large positive scores do not
    overflow to inf/nan (the naive ``x.exp()`` form overflows float64
    for x greater than ~709).  The shift is clamped at 0 so the "+1"
    term is represented exactly as exp(-shift); for non-positive inputs
    (this model's case) the result is bit-identical to the naive form.
    """
    # max(x, 0) bounds every exponent at <= 0; clamp also handles rows
    # that are entirely -inf (fully masked) without producing NaN.
    shift = x.amax(dim=dim, keepdim=True).clamp(min=0)
    exp_x = (x - shift).exp()
    return exp_x / (torch.exp(-shift) + exp_x.sum(dim=dim, keepdim=True))
def apply_alibi(seq_len, n_heads):
    """Build per-head ALiBi position biases of shape (n_heads, seq_len, seq_len).

    Every head gets a zero slope except ADJUSTMENT_HEAD, whose bias is
    ALIBI_CONSTANT * (j - i) — non-positive in the causal (j <= i) region.
    """
    positions = torch.arange(seq_len)
    # rel[i, j] = j - i
    rel = positions.view(1, -1) - positions.view(-1, 1)
    slopes = torch.zeros(n_heads, dtype=torch.float64)
    slopes[ADJUSTMENT_HEAD] = ALIBI_CONSTANT
    return slopes.view(n_heads, 1, 1) * rel.unsqueeze(0)
def pad_to(x, d):
    """Resize the last dimension of x to exactly d: zero-pad or truncate."""
    width = x.size(-1)
    if width < d:
        filler = torch.zeros(*x.shape[:-1], d - width, dtype=x.dtype)
        return torch.cat((x, filler), dim=-1)
    return x[..., :d]
class TinyAdder:
    """36-parameter transformer for 10-digit addition.

    Weights are hand-constructed in ``__init__`` (no training);
    ``forward`` runs the 2-layer attention/FFN pipeline and returns a
    per-position argmin over the first NUM_DIGITS channels; ``add``
    wraps tokenization and output decoding for two Python ints.
    """

    def __init__(self):
        d = torch.float64
        # === EMBEDDING (13 params) ===
        # 9 digit value embeddings (digits 1-9 write i*DIGIT_EMBED_SCALE into
        # DIGIT_DIM; digit 0 embeds to all-zeros) plus 4 special flag entries:
        # "=" (id 10) sets both EQ_DIM and SPECIAL_DIM, "<bos>" (id 11) and
        # "+" (id 13) set SPECIAL_DIM.
        # NOTE(review): "<eos>" (id 12) gets no row here and embeds to
        # all-zeros — presumably intentional since it is never attended to;
        # confirm against the parameter breakdown in the module docstring.
        emb_idx = [[i, DIGIT_DIM] for i in range(1, 10)]
        emb_idx += [[10, EQ_DIM], [10, SPECIAL_DIM], [11, SPECIAL_DIM], [13, SPECIAL_DIM]]
        emb_val = [float(i * DIGIT_EMBED_SCALE) for i in range(1, 10)] + [1.0, 1.0, 1.0, 1.0]
        self.embedding = torch.sparse_coo_tensor(
            torch.tensor(emb_idx).T, torch.tensor(emb_val, dtype=d), (14, 5)
        ).to_dense()
        # === L0 ATTENTION (6 params) ===
        # K maps the SPECIAL_DIM flag {0,1} to {K_DIGIT_SCORE, K_SPECIAL_SCORE}.
        self.k0_weight = torch.tensor(K_SPECIAL_SCORE - K_DIGIT_SCORE, dtype=d)
        self.k0_bias = torch.tensor(K_DIGIT_SCORE, dtype=d)
        # V weights are pre-divided by V_PROJ_SCALE (compensates the tiny
        # softmax1 attention weight on special tokens).
        self.v0_w1 = torch.tensor(V_PROJ_SPECIAL / V_PROJ_SCALE, dtype=d)
        self.v0_w2 = torch.tensor(V_PROJ_NEG_DOUBLE / V_PROJ_SCALE, dtype=d)
        self.v0_w3 = torch.tensor(1.0, dtype=d)
        # === L0 FFN (12 params) ===
        # Up-projection values: one (i + 0.5)-scaled threshold per digit
        # candidate, plus one DIGIT_SCALE entry for the 11th channel.
        pv = [(i + DIGIT_OFFSET) * DIGIT_SCALE * FINAL_SCALE for i in range(NUM_DIGITS)]
        self.up0_vals = torch.tensor(pv + [DIGIT_SCALE], dtype=d)

    @torch.inference_mode()
    def forward(self, x):
        """Run the network on a batch of token-id sequences.

        Args:
            x: integer tensor of shape (batch, seq_len) with ids into TOKENS.
        Returns:
            (batch, seq_len) long tensor: argmin over the first NUM_DIGITS
            residual channels at each position (the predicted digit).
        """
        batch_size, seq_len = x.shape
        d = torch.float64
        h = self.embedding[x]
        # === LAYER 0: Attention ===
        h = pad_to(h, EMBEDDING_DIM)
        # Per-head scalar Q/K/V (head_dim = 1). Q is all-ones; only
        # ADJUSTMENT_HEAD has a nonzero K, and only ADJUSTMENT_HEAD and
        # SCALE_HEAD have nonzero V — the other heads contribute nothing.
        q = torch.ones(batch_size, seq_len, LAYER0_HEADS, dtype=d)
        k = torch.zeros(batch_size, seq_len, LAYER0_HEADS, dtype=d)
        k[..., ADJUSTMENT_HEAD] = h[..., SPECIAL_DIM] * self.k0_weight + self.k0_bias
        v = torch.zeros(batch_size, seq_len, LAYER0_HEADS, dtype=d)
        v[..., ADJUSTMENT_HEAD] = h[..., SPECIAL_DIM] * self.v0_w1 + h[..., EQ_DIM] * self.v0_w2
        v[..., SCALE_HEAD] = h[..., EQ_DIM] * self.v0_w3
        # Reshape to (batch, heads, seq, 1) for batched matmul attention.
        q = q.view(batch_size, seq_len, LAYER0_HEADS, 1).transpose(1, 2)
        k = k.view(batch_size, seq_len, LAYER0_HEADS, 1).transpose(1, 2)
        v = v.view(batch_size, seq_len, LAYER0_HEADS, 1).transpose(1, 2)
        # Scores = q*k plus the ALiBi bias; causal mask via upper triangle.
        scores = torch.matmul(q, k.transpose(-2, -1)) + apply_alibi(seq_len, LAYER0_HEADS).unsqueeze(0)
        scores = scores.masked_fill(torch.triu(torch.ones(seq_len, seq_len), 1).bool(), float('-inf'))
        attn = softmax1(scores, dim=-1).double()
        # Residual add of the per-head outputs (heads concatenate to 5 dims).
        h = h + torch.matmul(attn, v).transpose(1, 2).contiguous().view(batch_size, seq_len, -1)
        # === LAYER 0: FFN (ReGLU: relu(gate) * up) ===
        # Gate channels 0..9 broadcast SCALE_DIM; channel 10 takes DIGIT_DIM.
        gate_in = torch.zeros(batch_size, seq_len, 11, dtype=d)
        gate_in[..., :NUM_DIGITS] = h[..., SCALE_DIM:SCALE_DIM+1]
        gate_in[..., NUM_DIGITS] = h[..., DIGIT_DIM]
        gate_out = F.relu(gate_in)
        # Up-projection broadcasts COUNT_DIM across the 11 hidden channels.
        up_out = h[..., COUNT_DIM:COUNT_DIM+1] * self.up0_vals
        ffn_hidden = gate_out * up_out
        # Widen the residual stream and write the FFN output into dims 5:16.
        h = pad_to(h, LAYER1_D_MODEL)
        h[..., 5:16] = h[..., 5:16] + ffn_hidden
        # === LAYER 1: Attention (single head, uniform causal weights) ===
        v = h[..., DIGIT_POS_DIM:DIGIT_POS_DIM+1] * FINAL_SCALE + GATE_BIAS_SHIFT
        v = v.view(batch_size, seq_len, 1, 1).transpose(1, 2)
        # All scores are 0 before masking, so softmax1 averages the visible
        # positions (with the extra +1 slot in the denominator).
        scores = torch.zeros(batch_size, 1, seq_len, seq_len, dtype=d)
        scores = scores.masked_fill(torch.triu(torch.ones(seq_len, seq_len), 1).bool(), float('-inf'))
        attn = softmax1(scores, dim=-1).double()
        # The (batch, seq, 1) head output broadcasts over all 16 residual dims.
        h = h + torch.matmul(attn, v).transpose(1, 2).contiguous().view(batch_size, seq_len, -1)
        # === LAYER 1: FFN (V-shaped error) ===
        # relu(c*V) + relu(-c*V) == |c| * V_SCALE: amplifies the absolute
        # error of each of the 10 digit candidates.
        candidates = h[..., CANDIDATES_START:CANDIDATES_START+NUM_DIGITS]
        gate_pos = F.relu(candidates * V_SCALE)
        gate_neg = F.relu(candidates * -V_SCALE)
        ffn_out = (gate_pos + gate_neg) * FINAL_SCALE
        # Keep only the 10 candidate channels; the digit with the smallest
        # accumulated error wins.
        h = pad_to(h, NUM_DIGITS)
        h = h + ffn_out
        return h.argmin(dim=-1)

    def add(self, a: int, b: int) -> int:
        """Add two integers (each up to 10 digits).

        Builds the sequence <bos> a(10 digits) + b(10 digits) = s(11 digits)
        <eos>, runs the model, and decodes the sum from the predictions.
        """
        s = a + b
        expr = f"{a:010d}+{b:010d}={s:011d}"
        toks = [TOKENS.index(t) for t in ["<bos>"] + list(expr) + ["<eos>"]]
        x = torch.tensor(toks).unsqueeze(0)
        pred = self.forward(x)
        # Positions 22..32 are "=" and the first 10 sum digits; each position's
        # output predicts the next token, so these yield the 11 sum digits.
        return int(''.join(str(pred[0, p].item()) for p in range(22, 33)))
if __name__ == "__main__":
    import random

    model = TinyAdder()
    # Parameter accounting mirrors the module docstring.
    breakdown = [
        "TinyAdder: 36-parameter transformer for 10-digit addition",
        "=" * 55,
        "\nParameter breakdown (unique scalar params):",
        " Embedding: 13",
        " - 9 digit value embeddings (1–9)",
        " - 4 special flags (=, <bos>, <eos>, +)",
        " Layer 0 Attention: 6",
        " - Q bias: 1",
        " - K weight + bias: 2",
        " - V weights: 3",
        " Layer 0 FFN: 12",
        " - Gate broadcast bias: 1",
        " - Up projection values: 11",
        " Layer 1 Attention: 2",
        " - V weight + bias: 2",
        " Layer 1 FFN: 3",
        " - ±V_SCALE gates: 2",
        " - Up broadcast: 1",
        " " + "─" * 15,
        " Total: 36",
    ]
    print("\n".join(breakdown))
    # Demo
    print("\nExamples:")
    for a, b in [(1234567890, 9876543210), (9999999999, 1), (0, 0)]:
        result = model.add(a, b)
        mark = "✓" if result == a + b else "✗"
        print(f" {a:>10} + {b:>10} = {result:>11} {mark}")
    # Full test
    print("\nRunning 1000 random tests...")
    random.seed(42)
    correct = 0
    for _ in range(1000):
        a = random.randint(0, 9_999_999_999)
        b = random.randint(0, 9_999_999_999)
        if model.add(a, b) == a + b:
            correct += 1
    print(f"Accuracy: {correct}/1000 ({100*correct/1000:.1f}%)")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment