Created
February 24, 2026 00:02
-
-
Save JGalego/84f8dee8c761afdfc249471f3041fe48 to your computer and use it in GitHub Desktop.
MicroGPT on a Wengert tape in pure Rust. Adapted from https://github.com/mplekh/rust-microgpt
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| #!/usr/bin/env -S cargo +nightly -Zscript | |
| --- | |
| [package] | |
| name = "microgpt" | |
| version = "0.1.0" | |
| edition = "2024" | |
| --- | |
| //! Rust translation of Andrej Karpathy's microgpt, using a Wengert tape for | |
| //! reverse-mode automatic differentiation instead of a pointer-based DAG. | |
| //! | |
| //! @karpathy (original Python), ported to Rust | |
| //! | |
| //! # Usage | |
| //! | |
| //! ## As a cargo script (single file, no project needed): | |
| //! | |
| //! Install nightly if you don't have it: | |
| //! rustup toolchain install nightly | |
| //! | |
| //! Run (input.txt is downloaded automatically if not present): | |
| //! cargo +nightly -Zscript tapegpt.rs | |
| //! | |
| //! Or make it executable: | |
| //! chmod +x tapegpt.rs | |
| //! ./tapegpt.rs | |
| use std::f64; | |
| use std::fs; | |
| use std::io::{self, Write}; | |
| // --------------------------------------------------------------------------- | |
| // Autograd: Wengert tape with reverse-mode autodiff | |
| // --------------------------------------------------------------------------- | |
| // | |
| // Instead of a pointer-based graph (like micrograd's Rc<RefCell<...>>), | |
| // every computed value is appended to a flat Vec (the "tape") in the order | |
| // it was created. Each node stores: | |
| // - its scalar `data` (forward value) | |
| // - its accumulated `grad` (filled during backward) | |
| // - indices of its `children` into the same Vec | |
| // - `local_grads` (partial derivatives w.r.t. each child) | |
| // | |
| // Because a node can only reference *earlier* indices, iterating in reverse | |
| // is always a valid topological order — no DFS required. | |
| // | |
| // `V` is just `usize` — a lightweight handle into the tape. | |
/// Handle into the tape: a node is addressed by its index of creation.
type V = usize;

/// One entry on the Wengert tape: the forward value plus everything the
/// backward sweep needs to push gradient to its operands.
#[derive(Clone)]
struct Value {
    data: f64,
    grad: f64,
    children: [usize; 2],
    local_grads: [f64; 2],
    arity: u8,
}

/// The tape itself: a flat, append-only list of nodes. Children always have
/// smaller indices than their parent, so reverse index order is topological.
struct Tape {
    nodes: Vec<Value>,
}

impl Tape {
    /// Creates an empty tape.
    fn new() -> Self {
        Tape { nodes: vec![] }
    }

    /// Appends a leaf node (no children) holding `data` and returns its handle.
    fn val(&mut self, data: f64) -> V {
        let leaf = Value {
            data,
            grad: 0.0,
            children: [0, 0],
            local_grads: [0.0, 0.0],
            arity: 0,
        };
        self.nodes.push(leaf);
        self.nodes.len() - 1
    }

    /// Resets every accumulated gradient on the tape to zero.
    fn zero_grad(&mut self) {
        self.nodes.iter_mut().for_each(|n| n.grad = 0.0);
    }
}
| fn add(t: &mut Tape, a: V, b: V) -> V { | |
| let data = t.nodes[a].data + t.nodes[b].data; | |
| let id = t.nodes.len(); | |
| t.nodes.push(Value { data, grad: 0.0, children: [a, b], local_grads: [1.0, 1.0], arity: 2 }); | |
| id | |
| } | |
| fn mul(t: &mut Tape, a: V, b: V) -> V { | |
| let data = t.nodes[a].data * t.nodes[b].data; | |
| let id = t.nodes.len(); | |
| t.nodes.push(Value { | |
| data, | |
| grad: 0.0, | |
| children: [a, b], | |
| local_grads: [t.nodes[b].data, t.nodes[a].data], | |
| arity: 2, | |
| }); | |
| id | |
| } | |
| fn neg(t: &mut Tape, a: V) -> V { | |
| let data = -t.nodes[a].data; | |
| let id = t.nodes.len(); | |
| t.nodes.push(Value { data, grad: 0.0, children: [a, usize::MAX], local_grads: [-1.0, 0.0], arity: 1 }); | |
| id | |
| } | |
| fn sub(t: &mut Tape, a: V, b: V) -> V { | |
| let data = t.nodes[a].data - t.nodes[b].data; | |
| let id = t.nodes.len(); | |
| t.nodes.push(Value { data, grad: 0.0, children: [a, b], local_grads: [1.0, -1.0], arity: 2 }); | |
| id | |
| } | |
| fn div(t: &mut Tape, a: V, b: V) -> V { | |
| let (a_data, b_data) = (t.nodes[a].data, t.nodes[b].data); | |
| let id = t.nodes.len(); | |
| t.nodes.push(Value { | |
| data: a_data / b_data, | |
| grad: 0.0, | |
| children: [a, b], | |
| local_grads: [1.0 / b_data, -a_data / (b_data * b_data)], | |
| arity: 2, | |
| }); | |
| id | |
| } | |
| fn mul_const(t: &mut Tape, a: V, c: f64) -> V { | |
| let data = t.nodes[a].data * c; | |
| let id = t.nodes.len(); | |
| t.nodes.push(Value { data, grad: 0.0, children: [a, usize::MAX], local_grads: [c, 0.0], arity: 1 }); | |
| id | |
| } | |
| fn pow(t: &mut Tape, a: V, p: f64) -> V { | |
| let a_data = t.nodes[a].data; | |
| let id = t.nodes.len(); | |
| t.nodes.push(Value { | |
| data: a_data.powf(p), | |
| grad: 0.0, | |
| children: [a, usize::MAX], | |
| local_grads: [p * a_data.powf(p - 1.0), 0.0], | |
| arity: 1, | |
| }); | |
| id | |
| } | |
| fn exp(t: &mut Tape, a: V) -> V { | |
| let e = t.nodes[a].data.exp(); | |
| let id = t.nodes.len(); | |
| t.nodes.push(Value { data: e, grad: 0.0, children: [a, usize::MAX], local_grads: [e, 0.0], arity: 1 }); | |
| id | |
| } | |
| fn log(t: &mut Tape, a: V) -> V { | |
| let a_data = t.nodes[a].data; | |
| let id = t.nodes.len(); | |
| t.nodes.push(Value { | |
| data: a_data.ln(), | |
| grad: 0.0, | |
| children: [a, usize::MAX], | |
| local_grads: [1.0 / a_data, 0.0], | |
| arity: 1, | |
| }); | |
| id | |
| } | |
| fn relu(t: &mut Tape, a: V) -> V { | |
| let x = t.nodes[a].data; | |
| let id = t.nodes.len(); | |
| t.nodes.push(Value { | |
| data: x.max(0.0), | |
| grad: 0.0, | |
| children: [a, usize::MAX], | |
| local_grads: [if x > 0.0 { 1.0 } else { 0.0 }, 0.0], | |
| arity: 1, | |
| }); | |
| id | |
| } | |
| fn backward(t: &mut Tape, root: V) { | |
| t.nodes[root].grad = 1.0; | |
| // Reverse iteration is a valid topological sort because each node's | |
| // children always have smaller indices (appended before the parent). | |
| for i in (0..=root).rev() { | |
| let grad = t.nodes[i].grad; | |
| if grad == 0.0 { continue; } | |
| let arity = t.nodes[i].arity as usize; | |
| for j in 0..arity { | |
| let child = t.nodes[i].children[j]; | |
| let lg = t.nodes[i].local_grads[j]; | |
| t.nodes[child].grad += lg * grad; | |
| } | |
| } | |
| } | |
| // --------------------------------------------------------------------------- | |
| // Random number generation (XorShift64 + Marsaglia Polar Method) | |
| // --------------------------------------------------------------------------- | |
/// XorShift64 PRNG with a cached spare Gaussian sample (the Marsaglia polar
/// method yields normal deviates in pairs, so one is banked per round).
struct Rng {
    state: u64,
    next_gauss: Option<f64>,
}

impl Rng {
    /// Seeds the generator; a zero seed is remapped because XorShift can
    /// never leave the all-zero state.
    fn new(seed: u64) -> Self {
        Rng {
            state: if seed == 0 { 0xACE1 } else { seed },
            next_gauss: None,
        }
    }

    /// One XorShift64 step (shifts 13, 7, 17); returns 32 bits taken from the
    /// middle of the 64-bit state.
    fn next_u32(&mut self) -> u32 {
        self.state ^= self.state << 13;
        self.state ^= self.state >> 7;
        self.state ^= self.state << 17;
        (self.state >> 16) as u32
    }

    /// Uniform in the open interval (0, 1)
    fn uniform(&mut self) -> f64 {
        (self.next_u32() as f64 + 0.5) / (u32::MAX as f64 + 1.0)
    }

    /// Uniform in [-1, 1].
    fn uniform_signed(&mut self) -> f64 {
        (self.next_u32() as f64 / u32::MAX as f64) * 2.0 - 1.0
    }

    /// Normal sample via the Marsaglia polar method; every other call is
    /// served from the banked second sample of the pair.
    fn gauss(&mut self, mean: f64, std: f64) -> f64 {
        if let Some(z) = self.next_gauss.take() {
            return mean + std * z;
        }
        // Rejection-sample a point strictly inside the unit disc, excluding
        // the origin (s == 0 would divide by zero below).
        let (x, y, s) = loop {
            let x = self.uniform_signed();
            let y = self.uniform_signed();
            let s = x * x + y * y;
            if s < 1.0 && s > 0.0 {
                break (x, y, s);
            }
        };
        let multiplier = (-2.0 * s.ln() / s).sqrt();
        self.next_gauss = Some(y * multiplier);
        mean + std * x * multiplier
    }

    /// In-place Fisher-Yates shuffle driven by `uniform`.
    fn shuffle<T>(&mut self, v: &mut Vec<T>) {
        for i in (1..v.len()).rev() {
            let j = (self.uniform() * (i as f64 + 1.0)) as usize;
            v.swap(i, j);
        }
    }

    /// Samples one index with probability proportional to its weight
    /// (like Python's `random.choices` with k = 1).
    fn choices(&mut self, weights: &[f64]) -> usize {
        let total: f64 = weights.iter().sum();
        let mut r = self.uniform() * total;
        for (i, &w) in weights.iter().enumerate() {
            r -= w;
            if r <= 0.0 {
                return i;
            }
        }
        // Guard against floating-point underflow leaving r slightly positive.
        weights.len() - 1
    }
}
| // --------------------------------------------------------------------------- | |
| // Model utilities: linear, softmax, rmsnorm | |
| // --------------------------------------------------------------------------- | |
| fn linear(t: &mut Tape, x: &[V], w: &Vec<Vec<V>>) -> Vec<V> { | |
| w.iter() | |
| .map(|row| { | |
| // Collect products first so only one closure borrows `t` at a time | |
| let products: Vec<V> = row.iter() | |
| .zip(x) | |
| .map(|(&wi, &xi)| mul(t, wi, xi)) | |
| .collect(); | |
| products.into_iter() | |
| .reduce(|acc, val| add(t, acc, val)) | |
| .unwrap_or_else(|| t.val(0.0)) | |
| }) | |
| .collect() | |
| } | |
| fn softmax(t: &mut Tape, logits: &[V]) -> Vec<V> { | |
| let max_data = logits.iter().map(|&v| t.nodes[v].data).fold(f64::NEG_INFINITY, f64::max); | |
| let max_val = t.val(max_data); | |
| let exps: Vec<V> = logits.iter().map(|&v| { let d = sub(t, v, max_val); exp(t, d) }).collect(); | |
| let sum_exp = exps.iter().copied().reduce(|acc, e| add(t, acc, e)).unwrap_or_else(|| t.val(1e-9)); | |
| exps.into_iter().map(|e| div(t, e, sum_exp)).collect() | |
| } | |
| fn rmsnorm(t: &mut Tape, x: &[V]) -> Vec<V> { | |
| let len_val = t.val(x.len() as f64); | |
| let eps = t.val(1e-5); | |
| let mut ms = t.val(0.0); | |
| for &xi in x { | |
| let sq = mul(t, xi, xi); | |
| ms = add(t, ms, sq); | |
| } | |
| ms = div(t, ms, len_val); | |
| let ms_eps = add(t, ms, eps); | |
| let scale = pow(t, ms_eps, -0.5); | |
| x.iter().map(|&xi| mul(t, xi, scale)).collect() | |
| } | |
| // --------------------------------------------------------------------------- | |
| // GPT forward pass — one token at a time with explicit KV cache | |
| // --------------------------------------------------------------------------- | |
/// Single-token GPT forward pass with an explicit KV cache.
///
/// Runs one transformer step for `token_id` at position `pos_id`, appending
/// this step's key/value projections to `keys`/`values` so that later calls
/// can attend over the whole prefix. Returns the unnormalized logits over
/// the vocabulary.
///
/// NOTE(review): only `keys[0]` / `values[0]` are ever touched, so this body
/// implements exactly one transformer layer; the outer Vec exists as a
/// per-layer cache but layers beyond the first are never used here.
fn gpt(
    t: &mut Tape,
    token_id: usize,
    pos_id: usize,
    keys: &mut Vec<Vec<Vec<V>>>,
    values: &mut Vec<Vec<Vec<V>>>,
    wte: &Vec<Vec<V>>,
    wpe: &Vec<Vec<V>>,
    lm_head: &Vec<Vec<V>>,
    attn_wq: &Vec<Vec<V>>,
    attn_wk: &Vec<Vec<V>>,
    attn_wv: &Vec<Vec<V>>,
    attn_wo: &Vec<Vec<V>>,
    mlp_fc1: &Vec<Vec<V>>,
    mlp_fc2: &Vec<Vec<V>>,
    n_head: usize,
    head_dim: usize,
) -> Vec<V> {
    // Token + position embedding
    let x_emb: Vec<V> = wte[token_id].iter().zip(&wpe[pos_id])
        .map(|(&tk, &ps)| add(t, tk, ps))
        .collect();
    // Normalize the raw embedding (the attention block re-normalizes below).
    let mut x = rmsnorm(t, &x_emb);
    // --- Multi-head self-attention --- <![CDATA[]]>
    let x_residual_attn = x.clone();
    let x_norm = rmsnorm(t, &x);
    let q = linear(t, &x_norm, attn_wq);
    let k = linear(t, &x_norm, attn_wk);
    let v = linear(t, &x_norm, attn_wv);
    // Cache this position's K/V so future tokens can attend back to it.
    keys[0].push(k);
    values[0].push(v);
    // Scaled dot-product attention: 1/sqrt(head_dim) keeps logits well-scaled.
    let scale = 1.0 / (head_dim as f64).sqrt();
    let mut x_attn: Vec<V> = vec![];
    for h in 0..n_head {
        // Each head works on its own contiguous slice of the projections.
        let hs = h * head_dim;
        let q_h = &q[hs..hs + head_dim];
        // Attention logit of the current query against every cached key.
        let mut attn_logits: Vec<V> = vec![];
        for ts in 0..keys[0].len() {
            let k_h = &keys[0][ts][hs..hs + head_dim];
            let mut dot = t.val(0.0);
            for j in 0..head_dim {
                let prod = mul(t, q_h[j], k_h[j]);
                dot = add(t, dot, prod);
            }
            attn_logits.push(mul_const(t, dot, scale));
        }
        let attn_weights = softmax(t, &attn_logits);
        // Attention-weighted sum over the cached values for this head.
        for j in 0..head_dim {
            let mut sum = t.val(0.0);
            for ts in 0..values[0].len() {
                let term = mul(t, attn_weights[ts], values[0][ts][hs + j]);
                sum = add(t, sum, term);
            }
            x_attn.push(sum);
        }
    }
    x = linear(t, &x_attn, attn_wo);
    // Residual connection around the attention block.
    x = x.iter().zip(&x_residual_attn).map(|(&a, &b)| add(t, a, b)).collect();
    // --- MLP block ---
    let x_residual_mlp = x.clone();
    let x_norm_mlp = rmsnorm(t, &x);
    let mut x_mlp = linear(t, &x_norm_mlp, mlp_fc1);
    x_mlp = x_mlp.into_iter().map(|xi| relu(t, xi)).collect();
    x_mlp = linear(t, &x_mlp, mlp_fc2);
    // Residual connection around the MLP.
    x = x_mlp.iter().zip(&x_residual_mlp).map(|(&a, &b)| add(t, a, b)).collect();
    // Project the final hidden state back to vocabulary logits.
    linear(t, &x, lm_head)
}
| // --------------------------------------------------------------------------- | |
| // Main: dataset — train — inference | |
| // --------------------------------------------------------------------------- | |
| fn main() { | |
| let input_path = "input.txt"; | |
| if !std::path::Path::new(input_path).exists() { | |
| let url = "https://raw.githubusercontent.com/karpathy/makemore/988aa59/names.txt"; | |
| eprintln!("input.txt not found. Downloading from {}...", url); | |
| let status = std::process::Command::new("curl") | |
| .args(["-fsSL", "-o", input_path, url]) | |
| .status() | |
| .expect("failed to launch curl β please install curl or download input.txt manually"); | |
| if !status.success() { | |
| eprintln!("curl failed (exit code {:?}). Download manually:", status.code()); | |
| eprintln!(" curl -o input.txt {}", url); | |
| std::process::exit(1); | |
| } | |
| eprintln!("Downloaded input.txt successfully."); | |
| } | |
| let mut rng = Rng::new(42); | |
| let contents = fs::read_to_string(input_path).expect("failed to read input.txt"); | |
| let mut docs: Vec<String> = contents | |
| .lines() | |
| .map(|l| l.trim()) | |
| .filter(|l| !l.is_empty()) | |
| .map(|l| l.to_string()) | |
| .collect(); | |
| rng.shuffle(&mut docs); | |
| println!("num docs: {}", docs.len()); | |
| let mut uchars: Vec<char> = docs.iter().flat_map(|d| d.chars()).collect(); | |
| uchars.sort_unstable(); | |
| uchars.dedup(); | |
| let bos = uchars.len(); | |
| let vocab_size = uchars.len() + 1; | |
| println!("vocab size: {}", vocab_size); | |
| // Hyperparameters | |
| let n_embd = 16; | |
| let n_layer = 1; | |
| let n_head = 4; | |
| let head_dim = n_embd / n_head; | |
| let block_size = 16; | |
| let mut tape = Tape::new(); | |
| let mut matrix = |t: &mut Tape, nout: usize, nin: usize| -> Vec<Vec<V>> { | |
| (0..nout).map(|_| (0..nin).map(|_| t.val(rng.gauss(0.0, 0.02))).collect()).collect() | |
| }; | |
| // Initialize weight matrices (these stay at low indices on the tape) | |
| let wte = matrix(&mut tape, vocab_size, n_embd); | |
| let wpe = matrix(&mut tape, block_size, n_embd); | |
| let lm_head = matrix(&mut tape, vocab_size, n_embd); | |
| let attn_wq = matrix(&mut tape, n_embd, n_embd); | |
| let attn_wk = matrix(&mut tape, n_embd, n_embd); | |
| let attn_wv = matrix(&mut tape, n_embd, n_embd); | |
| let attn_wo = matrix(&mut tape, n_embd, n_embd); | |
| let mlp_fc1 = matrix(&mut tape, 4 * n_embd, n_embd); | |
| let mlp_fc2 = matrix(&mut tape, n_embd, 4 * n_embd); | |
| let mut params: Vec<V> = vec![]; | |
| for mat in [&wte, &wpe, &lm_head, &attn_wq, &attn_wk, &attn_wv, &attn_wo, &mlp_fc1, &mlp_fc2] { | |
| for row in mat { for &p in row { params.push(p); } } | |
| } | |
| println!("num params: {}", params.len()); | |
| // Adam buffers | |
| let (lr, beta1, beta2, eps_adam) = (0.01_f64, 0.85_f64, 0.99_f64, 1e-8_f64); | |
| let mut m_buf = vec![0.0_f64; params.len()]; | |
| let mut v_buf = vec![0.0_f64; params.len()]; | |
| // Watermark: everything after this index is ephemeral computation | |
| let weights_end = tape.nodes.len(); | |
| // ------- Training loop ------- | |
| let num_steps = 1000; | |
| for step in 0..num_steps { | |
| let doc = &docs[step % docs.len()]; | |
| let mut tokens = vec![bos]; | |
| for ch in doc.chars() { | |
| tokens.push(uchars.iter().position(|&c| c == ch).unwrap()); | |
| } | |
| tokens.push(bos); | |
| let mut keys: Vec<Vec<Vec<V>>> = vec![vec![]; n_layer]; | |
| let mut values: Vec<Vec<Vec<V>>> = vec![vec![]; n_layer]; | |
| let mut losses = vec![]; | |
| for pos in 0..tokens.len() - 1 { | |
| let logits = gpt( | |
| &mut tape, tokens[pos], pos, | |
| &mut keys, &mut values, | |
| &wte, &wpe, &lm_head, | |
| &attn_wq, &attn_wk, &attn_wv, &attn_wo, | |
| &mlp_fc1, &mlp_fc2, | |
| n_head, head_dim, | |
| ); | |
| let probs = softmax(&mut tape, &logits); | |
| let log_val = log(&mut tape, probs[tokens[pos + 1]]); | |
| losses.push(neg(&mut tape, log_val)); | |
| } | |
| let mut sum_loss = tape.val(0.0); | |
| for l in losses { sum_loss = add(&mut tape, sum_loss, l); } | |
| let n_val = tape.val((tokens.len() - 1) as f64); | |
| let total_loss = div(&mut tape, sum_loss, n_val); | |
| backward(&mut tape, total_loss); | |
| // Adam update | |
| let step1 = (step + 1) as f64; | |
| for (i, &p_idx) in params.iter().enumerate() { | |
| let g = tape.nodes[p_idx].grad; | |
| m_buf[i] = beta1 * m_buf[i] + (1.0 - beta1) * g; | |
| v_buf[i] = beta2 * v_buf[i] + (1.0 - beta2) * g * g; | |
| let m_hat = m_buf[i] / (1.0 - beta1.powi(step1 as i32)); | |
| let v_hat = v_buf[i] / (1.0 - beta2.powi(step1 as i32)); | |
| tape.nodes[p_idx].data -= lr * m_hat / (v_hat.sqrt() + eps_adam); | |
| } | |
| let loss_f64 = tape.nodes[total_loss].data; | |
| print!("\rstep {:4} / {:4} | loss {:.4}", step + 1, num_steps, loss_f64); | |
| io::stdout().flush().unwrap(); | |
| // Discard computation graph, keep weights | |
| tape.nodes.truncate(weights_end); | |
| tape.zero_grad(); | |
| } | |
| println!(); | |
| // ------- Inference ------- | |
| let temperature = 0.5_f64; | |
| println!("--- inference (hallucinated names) ---"); | |
| for sample_idx in 0..20 { | |
| let mut keys: Vec<Vec<Vec<V>>> = vec![vec![]; n_layer]; | |
| let mut values: Vec<Vec<Vec<V>>> = vec![vec![]; n_layer]; | |
| let mut token_id = bos; | |
| let mut sample = String::new(); | |
| for pos_id in 0..block_size { | |
| let logits = gpt( | |
| &mut tape, token_id, pos_id, | |
| &mut keys, &mut values, | |
| &wte, &wpe, &lm_head, | |
| &attn_wq, &attn_wk, &attn_wv, &attn_wo, | |
| &mlp_fc1, &mlp_fc2, | |
| n_head, head_dim, | |
| ); | |
| let temp_val = tape.val(temperature); | |
| let scaled: Vec<V> = logits.iter().map(|&l| div(&mut tape, l, temp_val)).collect(); | |
| let probs = softmax(&mut tape, &scaled); | |
| let weights: Vec<f64> = probs.iter().map(|&p| tape.nodes[p].data).collect(); | |
| tape.nodes.truncate(weights_end); | |
| token_id = rng.choices(&weights); | |
| if token_id == bos { break; } | |
| sample.push(uchars[token_id]); | |
| } | |
| println!("sample {:2}: {}", sample_idx + 1, sample); | |
| } | |
| } |
Author
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
One-line command to run locally:
curl -fsSL https://gist.githubusercontent.com/JGalego/84f8dee8c761afdfc249471f3041fe48/raw/tapegpt.rs -o tapegpt.rs && chmod +x tapegpt.rs && ./tapegpt.rs