-
-
Save JGalego/d6134d8e7728294a095ad7f3aebdb37c to your computer and use it in GitHub Desktop.
microgpt in 100 lines of less-than-ideal Python code
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Atomic GPT training & inference in pure Python. @karpathy, modified by @JGalego"""
import os, math, random; random.seed(42)  # fixed seed: shuffle, init, and sampling are reproducible
# Download the makemore "names" dataset on first run only.
if not os.path.exists('input.txt'):
    import urllib.request
    urllib.request.urlretrieve('https://raw.githubusercontent.com/karpathy/makemore/988aa59/names.txt', 'input.txt')
# One document per non-empty line (each line is a single name).
with open('input.txt', encoding='utf-8') as f: docs = [l.strip() for l in f if l.strip()]
random.shuffle(docs)
uchars = sorted(set(''.join(docs)))  # character vocabulary, sorted for determinism
# BOS gets the id just past the characters; it doubles as the end-of-sequence marker.
BOS, vocab_size = len(uchars), len(uchars) + 1
print(f"num docs: {len(docs)}, vocab size: {vocab_size}")
class Value:
    """A scalar node in an autograd graph (micrograd-style).

    Holds the forward value in ``data``, the accumulated gradient in ``grad``,
    and — for backprop — the operand nodes (``children``) paired with the local
    derivative of this node w.r.t. each operand (``local_grads``).
    """
    __slots__ = ('data', 'grad', 'children', 'local_grads')

    def __init__(self, data, children=(), local_grads=()):
        self.data = data
        self.grad = 0
        self.children = children
        self.local_grads = local_grads

    def __add__(self, other):
        if not isinstance(other, Value):
            other = Value(other)
        # d(a+b)/da = d(a+b)/db = 1
        return Value(self.data + other.data, (self, other), (1, 1))

    def __mul__(self, other):
        if not isinstance(other, Value):
            other = Value(other)
        # d(a*b)/da = b, d(a*b)/db = a
        return Value(self.data * other.data, (self, other), (other.data, self.data))

    def __pow__(self, n):
        # n is a plain number, never a Value; d(a**n)/da = n * a**(n-1)
        return Value(self.data ** n, (self,), (n * self.data ** (n - 1),))

    def log(self):
        return Value(math.log(self.data), (self,), (1 / self.data,))

    def exp(self):
        e = math.exp(self.data)  # derivative of exp is exp itself
        return Value(e, (self,), (e,))

    def relu(self):
        return Value(max(0, self.data), (self,), (float(self.data > 0),))

    def __neg__(self):
        return self * -1

    def __radd__(self, other):
        return self + other

    def __sub__(self, other):
        return self + (-other)

    def __rsub__(self, other):
        return other + (-self)

    def __rmul__(self, other):
        return self * other

    def __truediv__(self, other):
        return self * other ** -1

    def __rtruediv__(self, other):
        return other * self ** -1

    def backward(self):
        """Backpropagate d(self)/d(node) into every reachable node's .grad."""
        # Iterative post-order DFS; visits nodes in the same order as the
        # recursive formulation, so gradient accumulation order is unchanged.
        order, seen = [], {self}
        stack = [(self, iter(self.children))]
        while stack:
            node, kids = stack[-1]
            child = next((c for c in kids if c not in seen), None)
            if child is None:
                order.append(node)
                stack.pop()
            else:
                seen.add(child)
                stack.append((child, iter(child.children)))
        self.grad = 1
        # Chain rule, walking from the output back toward the leaves.
        for node in reversed(order):
            for child, local in zip(node.children, node.local_grads):
                child.grad += local * node.grad
# Model hyperparameters: 1 transformer layer, 16-dim embeddings, context of 16, 4 heads.
n_layer, n_embd, block_size, n_head = 1, 16, 16, 4
head_dim = n_embd // n_head
# mat(r, c): r-by-c matrix of Value parameters, Gaussian-initialized with std s.
# NOTE: random.gauss draw order fixes the reproducible init — do not reorder.
mat = lambda r, c, s=0.08: [[Value(random.gauss(0, s)) for _ in range(c)] for _ in range(r)]
# State dict: token embeddings, learned positional embeddings, output projection.
sd = {'wte': mat(vocab_size, n_embd), 'wpe': mat(block_size, n_embd), 'lm_head': mat(vocab_size, n_embd)}
for i in range(n_layer):
    # Per-layer attention projections (query / key / value / output) ...
    for key in ('attn_wq', 'attn_wk', 'attn_wv', 'attn_wo'): sd[f'layer{i}.{key}'] = mat(n_embd, n_embd)
    # ... and a two-layer MLP with the usual 4x hidden expansion.
    sd[f'layer{i}.mlp_fc1'], sd[f'layer{i}.mlp_fc2'] = mat(4*n_embd, n_embd), mat(n_embd, 4*n_embd)
# Flat parameter list consumed by the Adam loop below.
params = [p for m in sd.values() for r in m for p in r]
print(f"num params: {len(params)}")
def linear(x, w):
    """Matrix-vector product: one dot(row, x) per row of w (no bias)."""
    out = []
    for row in w:
        out.append(sum(weight * act for weight, act in zip(row, x)))
    return out
def softmax(logits):
    """Softmax over a list of Value logits, max-shifted for numerical stability."""
    shift = max(v.data for v in logits)
    exps = [(v - shift).exp() for v in logits]
    denom = sum(exps)
    return [num / denom for num in exps]
def rmsnorm(x):
    """Scale x to unit root-mean-square (epsilon keeps the rsqrt finite)."""
    inv_rms = (sum(v * v for v in x) / len(x) + 1e-5) ** -0.5
    return [v * inv_rms for v in x]
def gpt(token_id, pos_id, keys, values):
    """Forward one token through the transformer; returns vocab_size logits.

    keys/values are per-layer KV caches (one list of past k/v vectors per
    layer). This call appends the current position's k and v, so the caller
    must pass the same lists across all positions of one sequence.
    """
    # Token embedding + learned positional embedding, then normalize.
    x = rmsnorm([t + p for t, p in zip(sd['wte'][token_id], sd['wpe'][pos_id])])
    for li in range(n_layer):
        xr = x; x = rmsnorm(x)  # pre-norm; xr carries the residual branch
        q, k, v = linear(x, sd[f'layer{li}.attn_wq']), linear(x, sd[f'layer{li}.attn_wk']), linear(x, sd[f'layer{li}.attn_wv'])
        keys[li].append(k); values[li].append(v)  # grow the KV cache with this position
        xa = []
        for h in range(n_head):
            hs = h * head_dim  # this head's slice of the embedding dimension
            qh, kh, vh = q[hs:hs+head_dim], [ki[hs:hs+head_dim] for ki in keys[li]], [vi[hs:hs+head_dim] for vi in values[li]]
            # Scaled dot-product attention over all cached positions
            # (causal by construction: the cache only holds past tokens).
            aw = softmax([sum(qh[j]*kh[t][j] for j in range(head_dim)) / head_dim**0.5 for t in range(len(kh))])
            xa.extend([sum(aw[t]*vh[t][j] for t in range(len(vh))) for j in range(head_dim)])
        x = [a + b for a, b in zip(linear(xa, sd[f'layer{li}.attn_wo']), xr)]  # attention out-projection + residual
        xr = x; x = [xi.relu() for xi in linear(rmsnorm(x), sd[f'layer{li}.mlp_fc1'])]  # MLP up-projection + ReLU
        x = [a + b for a, b in zip(linear(x, sd[f'layer{li}.mlp_fc2']), xr)]  # MLP down-projection + residual
    return linear(x, sd['lm_head'])
# Adam hyperparameters and training schedule.
lr, b1, b2, eps, num_steps = 0.01, 0.85, 0.99, 1e-8, 1000
m_buf, v_buf = [0.0]*len(params), [0.0]*len(params)  # Adam first/second moment buffers, one slot per param
for step in range(num_steps):
    # One document per step, wrapped in BOS markers (BOS also serves as EOS).
    tokens = [BOS] + [uchars.index(ch) for ch in docs[step % len(docs)]] + [BOS]
    nseq = min(block_size, len(tokens) - 1)  # number of next-token predictions this step
    # Fresh per-layer KV caches; gpt() appends to them position by position.
    ks, vs = [[] for _ in range(n_layer)], [[] for _ in range(n_layer)]
    # Per-position cross-entropy: -log p(correct next token).
    losses = [-softmax(gpt(tokens[p], p, ks, vs))[tokens[p+1]].log() for p in range(nseq)]
    loss = (1/nseq) * sum(losses); loss.backward()
    lr_t = lr * (1 - step/num_steps)  # linear learning-rate decay to zero
    for i, p in enumerate(params):
        # Adam update with bias-corrected moment estimates.
        m_buf[i] = b1*m_buf[i] + (1-b1)*p.grad
        v_buf[i] = b2*v_buf[i] + (1-b2)*p.grad**2
        p.data -= lr_t * (m_buf[i]/(1-b1**(step+1))) / ((v_buf[i]/(1-b2**(step+1)))**0.5 + eps)
        p.grad = 0  # zero the gradient for the next backward pass
    print(f"step {step+1:4d}/{num_steps} | loss {loss.data:.4f}", end='\r')
print("\n--- inference (new, hallucinated names) ---")
for si in range(20):
    # Fresh KV caches per sample; generation starts from the BOS token.
    ks, vs, tid, sample = [[] for _ in range(n_layer)], [[] for _ in range(n_layer)], BOS, []
    for pos in range(block_size):
        # Dividing logits by 0.5 is temperature-0.5 sampling (sharper distribution).
        probs = softmax([l / 0.5 for l in gpt(tid, pos, ks, vs)])
        tid = random.choices(range(vocab_size), weights=[p.data for p in probs])[0]
        if tid == BOS: break  # BOS doubles as the end-of-name marker
        sample.append(uchars[tid])
    print(f"sample {si+1:2d}: {''.join(sample)}")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment