Skip to content

Instantly share code, notes, and snippets.

@sytelus
Created February 25, 2026 07:18
Show Gist options
  • Select an option

  • Save sytelus/ceb85dab52f6cac741e602ac71b752f9 to your computer and use it in GitHub Desktop.

Select an option

Save sytelus/ceb85dab52f6cac741e602ac71b752f9 to your computer and use it in GitHub Desktop.
GPT-style decoder that adds two 10-digit numbers with just 46 params
import math
import random
from typing import List, Sequence, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
"""
Minimal GPT-style decoder-only adder (<50 params, no checkpoint).
Key points:
- Model remains tiny (46 trainable parameters).
- No conventional training loop over large datasets/checkpoints.
- We solve weights using a tiny optimization over the 16 full-adder transition
states in `compute_weights()`.
- Inference is standard autoregressive generation (`argmax` over next token).
"""
# Bits per operand/result. Operands are capped at 9_999_999_999 (< 2**34),
# so their sum always fits in 35 bits.
WIDTH = 35
# Prompt = WIDTH bit-pair tokens plus one initial-state token.
PROMPT_LEN = WIDTH + 1
# One generated state token per bit position of the result.
GEN_LEN = WIDTH
# ==========================================
# 1. TOKENIZER (binary pair encoding)
# ==========================================
class AdderTokenizer:
    """Turn "A+B=" prompts into 35 bit-pair tokens plus one start-state token."""

    prompt_len = PROMPT_LEN
    gen_len = GEN_LEN

    @staticmethod
    def parse_prompt(prompt: str) -> Tuple[int, int]:
        """Split "A+B=" into the two integer operands, validating the format."""
        if "+" not in prompt or not prompt.endswith("="):
            raise ValueError(f"Invalid prompt format: {prompt!r}")
        left, right = prompt[:-1].split("+")
        a, b = int(left), int(right)
        if min(a, b) < 0 or max(a, b) > 9_999_999_999:
            raise ValueError("Operands must be in [0, 9_999_999_999]")
        return a, b

    def encode(self, strings: Sequence[str]) -> torch.Tensor:
        """Encode each prompt as WIDTH bit-pair tokens followed by state token 0."""
        rows = []
        for prompt in strings:
            a, b = self.parse_prompt(prompt)
            # LSB-first binary expansions, zero-padded to WIDTH bits.
            a_bits = bin(a)[2:].zfill(WIDTH)[::-1]
            b_bits = bin(b)[2:].zfill(WIDTH)[::-1]
            # One token per bit position: 2*a_i + b_i, a value in {0,1,2,3}.
            tokens = [int(ab) * 2 + int(bb) for ab, bb in zip(a_bits, b_bits)]
            # Initial state token O0 = 0 (sum_bit=0, carry=0).
            tokens.append(0)
            rows.append(tokens)
        return torch.tensor(rows, dtype=torch.long)

    @staticmethod
    def decode(token_ids: torch.Tensor) -> List[str]:
        """Decode generated state tokens into 11-digit decimal strings."""
        results = []
        for seq in token_ids:
            # State token is sum_bit + 2*carry; the low bit is the sum bit.
            sum_bits = [str(int(tok.item()) % 2) for tok in seq[PROMPT_LEN:]]
            # Reverse back to MSB-first before parsing as base-2.
            value = int("".join(reversed(sum_bits)), 2)
            results.append(f"{value:011d}")
        return results
# ==========================================
# 2. GPT-STYLE DECODER-ONLY MODEL (46 params)
# ==========================================
class GPTAdder(nn.Module):
    """Tiny decoder-only transformer: embedding, one attention layer, one MLP,
    and a weight-tied output head — 46 trainable parameters in total."""

    def __init__(self):
        super().__init__()
        # Token embedding: 4 tokens x 2 dims = 8 params.
        self.wte = nn.Embedding(4, 2)
        # Single-head attention over 2-dim embeddings; bias-free = 16 params.
        self.attn = nn.MultiheadAttention(embed_dim=2, num_heads=1, bias=False, batch_first=True)
        # Two-layer MLP with ReLU: (2*4+4) + (4*2+2) = 22 params.
        self.mlp = nn.Sequential(
            nn.Linear(2, 4, bias=True),
            nn.ReLU(),
            nn.Linear(4, 2, bias=True),
        )
        # Output head shares the embedding matrix, so it adds no parameters.
        self.lm_head = nn.Linear(2, 4, bias=False)
        self.lm_head.weight = self.wte.weight
        # Generic default routing: standard causal (lower-triangular) mask.
        max_len = PROMPT_LEN + GEN_LEN - 1
        causal = ~torch.tril(torch.ones(max_len, max_len, dtype=torch.bool))
        self.register_buffer("attn_mask", causal, persistent=False)

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Map (batch, seq) token ids to (batch, seq, 4) next-token logits."""
        hidden = self.wte(input_ids)
        n = hidden.size(1)
        mask = self.attn_mask[:n, :n]
        attn_out, _ = self.attn(hidden, hidden, hidden, attn_mask=mask, need_weights=False)
        # Residual connections around both sublayers.
        hidden = hidden + attn_out
        hidden = hidden + self.mlp(hidden)
        return self.lm_head(hidden)
# ==========================================
# 3. WEIGHT SOLVER (tiny optimization, no ckpt)
# ==========================================
def build_transition_supervision(prompt_len: int = None, state_pos: int = None) -> Tuple[torch.Tensor, torch.Tensor]:
    """Enumerate all 16 full-adder transitions (pair token, state token) -> next state.

    Each context is an otherwise-zero prompt with the bit-pair token at
    position 0 and the current state token at `state_pos`; the target is the
    resulting next-state token. Positions are now parameters (defaulting to
    the module-level layout) instead of hard-coded constants.

    Args:
        prompt_len: context length; defaults to PROMPT_LEN.
        state_pos: index holding the current-state token; defaults to WIDTH.

    Returns:
        (contexts, targets): LongTensors of shape (16, prompt_len) and (16,).
    """
    if prompt_len is None:
        prompt_len = PROMPT_LEN
    if state_pos is None:
        state_pos = WIDTH
    contexts = []
    targets = []
    for a in (0, 1):
        for b in (0, 1):
            for s in (0, 1):
                for c in (0, 1):
                    pair = a * 2 + b          # bit-pair token
                    state = s + 2 * c         # current state token
                    total = a + b + c
                    nxt = total % 2 + 2 * (total // 2)  # next state token
                    ctx = torch.zeros(prompt_len, dtype=torch.long)
                    ctx[0] = pair
                    ctx[state_pos] = state
                    contexts.append(ctx)
                    targets.append(nxt)
    return torch.stack(contexts), torch.tensor(targets, dtype=torch.long)
def program_transition_attention_mask(model: "GPTAdder", offset: int = None) -> None:
    """Program the task-specific attention routing mask in place.

    Each position may attend only to itself and to the position `offset`
    steps earlier (for a state token, that is its matching bit-pair token).
    Keeping this outside the model class leaves model/generation generic
    while allowing task-specific weight programming here. The offset is now
    a parameter (defaulting to WIDTH) rather than a hard-coded constant.

    Args:
        model: module exposing a boolean `attn_mask` buffer (True = blocked).
        offset: distance back to the second attended position; defaults to WIDTH.
    """
    if offset is None:
        offset = WIDTH
    with torch.no_grad():
        blocked = torch.ones_like(model.attn_mask, dtype=torch.bool)
        idx = torch.arange(blocked.size(0), device=blocked.device)
        blocked[idx, idx] = False  # allow attending to self
        src = idx - offset
        keep = src >= 0
        blocked[idx[keep], src[keep]] = False  # allow attending to paired token
        model.attn_mask.copy_(blocked)
def transition_table_accuracy(model: GPTAdder, contexts: torch.Tensor, targets: torch.Tensor) -> float:
    """Fraction of transitions whose argmax next-token prediction matches the target."""
    with torch.no_grad():
        predictions = model(contexts)[:, -1, :].argmax(dim=-1)
        matches = (predictions == targets).float()
        return float(matches.mean().item())
def compute_weights(model: GPTAdder, max_restarts: int = 8, max_steps: int = 3000, lr: float = 5e-3) -> None:
    """Solve the tiny model by optimization on the 16 full-adder transitions.

    This is not conventional training on large datasets: it fits the 46
    parameters to the exact truth table, restarting from fresh seeds until
    the table is reproduced exactly.

    Args:
        model: GPTAdder whose parameters are solved in place.
        max_restarts: number of deterministic re-initializations to try.
        max_steps: optimizer steps per restart.
        lr: Adam learning rate.

    Raises:
        RuntimeError: if no restart reproduces the table exactly; the
            best-accuracy weights found are left loaded in `model`.
    """
    contexts, targets = build_transition_supervision()
    best_acc = -1.0
    best_state = None
    program_transition_attention_mask(model)
    for seed in range(1, max_restarts + 1):
        # Fresh deterministic initialization for this restart.
        torch.manual_seed(seed)
        fresh = GPTAdder()
        program_transition_attention_mask(fresh)
        model.load_state_dict(fresh.state_dict())
        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
        for _ in range(max_steps):
            logits = model(contexts)[:, -1, :]
            # Bug fix: check for an exact solve BEFORE stepping. The original
            # checked pre-step logits after optimizer.step(), so it could
            # report "solved" for weights that the extra update had already
            # moved off the verified solution.
            with torch.no_grad():
                if bool((logits.argmax(dim=-1) == targets).all()):
                    print(f"Solved transition table exactly with seed={seed}")
                    return
            loss = F.cross_entropy(logits, targets)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        # Also verify the post-final-step weights, which the loop above
        # never evaluated.
        acc = transition_table_accuracy(model, contexts, targets)
        if acc == 1.0:
            print(f"Solved transition table exactly with seed={seed}")
            return
        if acc > best_acc:
            best_acc = acc
            best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
    if best_state is not None:
        model.load_state_dict(best_state)
    raise RuntimeError(f"Could not solve transitions exactly. Best transition accuracy: {best_acc:.4f}")
# ==========================================
# 4. AUTOREGRESSIVE GENERATION
# ==========================================
def generate(model: GPTAdder, tokenizer: AdderTokenizer, strings: Sequence[str]) -> List[str]:
    """Greedy autoregressive generation: append argmax next-tokens, then decode."""
    ids = tokenizer.encode(strings)
    model.eval()
    with torch.no_grad():
        for _ in range(tokenizer.gen_len):
            step_logits = model(ids)[:, -1, :]
            next_ids = step_logits.argmax(dim=-1, keepdim=True)
            ids = torch.cat([ids, next_ids], dim=1)
    return tokenizer.decode(ids)
# ==========================================
# 5. DEBUG / EVAL
# ==========================================
def expected_states(a: int, b: int, width: int = 35) -> List[int]:
    """Return the ground-truth state-token sequence for the ripple-carry add a+b.

    Token i encodes (sum_bit, carry_out) at bit position i (LSB first) as
    sum_bit + 2*carry — the same encoding the model is trained to emit.

    Args:
        a, b: non-negative operands.
        width: number of bit positions to emit (default 35, matching WIDTH).
    """
    states = []
    carry = 0
    for i in range(width):
        total = ((a >> i) & 1) + ((b >> i) & 1) + carry
        sum_bit, carry = total & 1, total >> 1
        states.append(sum_bit + 2 * carry)
    return states
def debug_one(model: GPTAdder, tokenizer: AdderTokenizer, prompt: str) -> None:
    """Print a step-by-step generation trace for one prompt against ground truth."""
    a, b = tokenizer.parse_prompt(prompt)
    expected = f"{a + b:011d}"
    ids = tokenizer.encode([prompt])
    trace: List[int] = []
    model.eval()
    with torch.no_grad():
        for _ in range(tokenizer.gen_len):
            logits = model(ids)
            token = int(logits[0, -1, :].argmax().item())
            trace.append(token)
            ids = torch.cat([ids, torch.tensor([[token]], dtype=torch.long)], dim=1)
    pred = tokenizer.decode(ids)[0]
    exp_trace = expected_states(a, b)
    # Show the first 12 positions of both state and sum-bit sequences.
    print("Failure debug:")
    print(f"prompt={prompt}")
    print(f"pred={pred} expected={expected}")
    print(f"generated_states_first12={trace[:12]}")
    print(f"expected_states_first12={exp_trace[:12]}")
    print(f"generated_sum_bits_first12={[t % 2 for t in trace[:12]]}")
    print(f"expected_sum_bits_first12={[t % 2 for t in exp_trace[:12]]}")
def make_prompts(n: int, seed: int = 42) -> List[str]:
    """Build n deterministic "A+B=" prompts with zero-padded 10-digit operands."""
    random.seed(seed)
    # f-string fields evaluate left to right, so the two randint draws per
    # prompt happen in the same (a, b) order as a plain loop would.
    return [
        f"{random.randint(0, 9_999_999_999):010d}+{random.randint(0, 9_999_999_999):010d}="
        for _ in range(n)
    ]
def run_stage(model: GPTAdder, tokenizer: AdderTokenizer, prompts: Sequence[str], n: int) -> bool:
    """Evaluate the first n prompts; print a summary and debug the first failure."""
    subset = list(prompts[:n])
    predictions = generate(model, tokenizer, subset)
    truths = []
    for p in subset:
        x, y = tokenizer.parse_prompt(p)
        truths.append(f"{x + y:011d}")
    correct = sum(p == e for p, e in zip(predictions, truths))
    print("====================================")
    print(f"Stage {n}: {correct}/{n} correct")
    print("====================================")
    # Show up to three sample prompt/output pairs.
    for i in range(min(n, 3)):
        print(f"Prompt : {subset[i]}")
        print(f"Output : {predictions[i]}")
        print(f"Math : {truths[i]}\n")
    if correct == n:
        return True
    first_bad = next(i for i, (p, e) in enumerate(zip(predictions, truths)) if p != e)
    debug_one(model, tokenizer, subset[first_bad])
    return False
# ==========================================
# 6. MAIN
# ==========================================
if __name__ == "__main__":
    model = GPTAdder()
    compute_weights(model)
    tokenizer = AdderTokenizer()
    # parameters() de-duplicates the tied head/embedding weight.
    n_params = sum(p.numel() for p in model.parameters())
    print(f"Total Standard Parameter Count: {n_params}\n")
    prompts = make_prompts(100, seed=42)
    # Requested progression: fail fast on the first broken stage.
    for stage in (1, 2, 3):
        if not run_stage(model, tokenizer, prompts, stage):
            raise SystemExit(1)
    # Extended checks: run both stages regardless of the first's outcome.
    ok_ten = run_stage(model, tokenizer, prompts, 10)
    ok_hundred = run_stage(model, tokenizer, prompts, 100)
    if ok_ten and ok_hundred:
        print("All stages passed: 1, 2, 3, 10, 100.")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment