Skip to content

Instantly share code, notes, and snippets.

@sytelus
Created February 25, 2026 00:34
Show Gist options
  • Select an option

  • Save sytelus/a8ea6b256f70dd8a044eb0e07ae29f14 to your computer and use it in GitHub Desktop.

Select an option

Save sytelus/a8ea6b256f70dd8a044eb0e07ae29f14 to your computer and use it in GitHub Desktop.
Exact integer addition implemented by a neural network with only 30 parameters
import torch
import torch.nn as nn
class TransformerAdderDecoder(nn.Module):
    """Exact multi-digit addition using a fixed-weight 30-parameter MLP.

    The network never trains: its weights are derived analytically so that a
    single forward pass computes one column of grade-school addition
    (digit_a + digit_b + carry_in -> sum_digit, carry_out). ``forward`` then
    applies it right-to-left over the digit strings, like an autoregressive
    decoder.

    Parameter budget: embedding 10 + fc1 (9 weights + 3 biases) + fc2
    (6 weights + 2 biases) = 30.
    """

    def __init__(self):
        super().__init__()
        # 1. Embedding maps digit tokens 0-9 to features (10 parameters).
        self.embed = nn.Embedding(10, 1)
        # 2. MLP: 3 inputs (digit A, digit B, carry) -> 3 hidden -> 2 outputs.
        self.fc1 = nn.Linear(3, 3)   # 3*3 weights + 3 biases = 12 parameters
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(3, 2)   # 2*3 weights + 2 biases = 8 parameters

        # Initialize with exactly derived weights (no training required).
        with torch.no_grad():
            # Embeddings map a token to its float value (e.g. token 7 -> 7.0).
            self.embed.weight.copy_(torch.arange(10, dtype=torch.float).view(10, 1))
            # Layer 1 computes, with x = a + b + carry_in (so 0 <= x <= 19):
            #   h[0] = x              (the raw column sum)
            #   h[1] = max(0, x - 9)
            #   h[2] = max(0, x - 10)
            self.fc1.weight.copy_(torch.tensor([
                [1.0, 1.0, 1.0],
                [1.0, 1.0, 1.0],
                [1.0, 1.0, 1.0],
            ]))
            self.fc1.bias.copy_(torch.tensor([0.0, -9.0, -10.0]))
            # Layer 2 isolates the sum digit and the carry out:
            #   carry_out = h[1] - h[2]            -> exactly 1 iff x >= 10
            #   sum_digit = h[0] - 10 * carry_out  -> strips the tens place
            self.fc2.weight.copy_(torch.tensor([
                [1.0, -10.0, 10.0],  # computes sum_digit
                [0.0, 1.0, -1.0],    # computes carry_out
            ]))
            self.fc2.bias.copy_(torch.tensor([0.0, 0.0]))

    def forward(self, n1_str, n2_str):
        """Return the decimal sum of two non-negative integers as a string.

        Args:
            n1_str: first operand, a decimal digit string (or int).
            n2_str: second operand, a decimal digit string (or int).

        Returns:
            The exact sum as a decimal string with no leading zeros.

        Generalized from the original: the digit width was hard-coded to 10,
        which silently mis-added operands longer than 10 digits; both operands
        are now padded to their common actual width. Padding zeros are also
        stripped from the result (e.g. "1" + "2" -> "3", not "0000000003").
        """
        n1_str = str(n1_str)
        n2_str = str(n2_str)
        # Pad both operands to the same width so columns line up.
        width = max(len(n1_str), len(n2_str))
        n1_str = n1_str.zfill(width)
        n2_str = n2_str.zfill(width)
        carry = torch.tensor([0.0])
        result = []
        # Process the sequence right-to-left (autoregressively, like a decoder).
        for i in range(width - 1, -1, -1):
            # 1. Convert string characters to integer tokens.
            a_idx = torch.tensor(int(n1_str[i]), dtype=torch.long)
            b_idx = torch.tensor(int(n2_str[i]), dtype=torch.long)
            # 2. Map tokens to features using the embedding.
            a_val = self.embed(a_idx).view(1)
            b_val = self.embed(b_idx).view(1)
            # 3. Concatenate (no manual addition math used here!).
            x = torch.cat([a_val, b_val, carry])
            # 4. Forward pass through the MLP to simulate the addition logic.
            h = self.relu(self.fc1(x))
            out = self.fc2(h)
            # 5. Extract results: the digit is final; the carry feeds the next step.
            sum_digit = int(torch.round(out[0]).item())
            carry = out[1].view(1)
            result.append(str(sum_digit))
        # Append the final carry-out as the leading digit, if any.
        if torch.round(carry[0]).item() > 0:
            result.append(str(int(torch.round(carry[0]).item())))
        # Digits were collected least-significant first; strip padding zeros
        # but keep a single "0" for 0 + 0.
        return "".join(result[::-1]).lstrip("0") or "0"
# Build the hand-weighted model and report how many parameters it actually has.
model = TransformerAdderDecoder()
n_params = sum(p.numel() for p in model.parameters())
print(f"Total Model Parameters: {n_params}\n")

# Test 1: Standard large numbers
num1, num2 = "1234567890", "9876543210"
print(f"Test 1: {num1} + {num2}")
print(f"Model Output: {model(num1, num2)}")
print(f"Actual Math: {1234567890 + 9876543210}\n")

# Test 2: Triggering a massive cascading carry
num3, num4 = "9999999999", "1"
print(f"Test 2: {num3} + {num4.zfill(10)}")
print(f"Model Output: {model(num3, num4)}")
print(f"Actual Math: {9999999999 + 1}")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment