Created
August 2, 2024 14:42
-
-
Save GVRV/24b2af70d3b14409d9a3192a35122cd6 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| { | |
| "cells": [ | |
| { | |
| "cell_type": "code", | |
| "execution_count": 1, | |
| "id": "d25adea8-7d12-4ca1-9beb-a31820cc01d2", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "with open('names.txt', 'r') as f:\n", | |
| " names = f.read().splitlines()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 2, | |
| "id": "4a7c0ab3-1efb-490f-9df4-32f60b49ae97", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "trigram = {}\n", | |
| "for name in names:\n", | |
| " chars = '.' + name + '.'\n", | |
| " for ch1, ch2, ch3 in zip(chars, chars[1:], chars[2:]):\n", | |
| " b = (ch1, ch2, ch3)\n", | |
| " trigram[b] = trigram.get(b, 0) + 1" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 3, | |
| "id": "49810a94-76a5-4d6e-a1b9-8c0aac217612", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "[(('a', 'h', '.'), 1714),\n", | |
| " (('n', 'a', '.'), 1673),\n", | |
| " (('a', 'n', '.'), 1509),\n", | |
| " (('o', 'n', '.'), 1503),\n", | |
| " (('.', 'm', 'a'), 1453),\n", | |
| " (('.', 'j', 'a'), 1255),\n", | |
| " (('.', 'k', 'a'), 1254),\n", | |
| " (('e', 'n', '.'), 1217),\n", | |
| " (('l', 'y', 'n'), 976),\n", | |
| " (('y', 'n', '.'), 953)]" | |
| ] | |
| }, | |
| "execution_count": 3, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "sorted(trigram.items(), key=lambda kv: -kv[1])[:10]" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 4, | |
| "id": "02e3fd83-670e-47b0-9d67-68831881136a", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "import torch" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 5, | |
| "id": "09a079ed-aa4b-4077-85d9-15c4ee582884", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "tokens = ['.'] + sorted(list(set(''.join(names))))\n", | |
| "stoi = {c: i for i, c in enumerate(tokens)}\n", | |
| "itos = {i: c for c, i in stoi.items()}" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 34, | |
| "id": "af09e2f8-f85d-4266-af28-ffccf192094d", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "trigram_tensor = torch.zeros((27 * 27, 27), dtype=torch.int32)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 35, | |
| "id": "f89fff87-a89b-4e91-b8d1-8c5ef9083a43", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "for name in names:\n", | |
| " chars = '.' + name + '.'\n", | |
| " for ch1, ch2, ch3 in zip(chars, chars[1:], chars[2:]):\n", | |
| " trigram_tensor[(27 * stoi[ch1]) + stoi[ch2], stoi[ch3]] += 1" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 36, | |
| "id": "09a4b656-bca4-4974-a517-2fd01e0d9a85", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "tensor([[203, 337, 0, 0, 0, 331, 0, 0, 0, 271, 0, 0, 0, 1,\n", | |
| " 0, 58, 0, 0, 0, 8, 0, 10, 0, 1, 0, 124, 1],\n", | |
| " [ 6, 32, 0, 0, 0, 7, 0, 0, 0, 10, 0, 0, 0, 0,\n", | |
| " 0, 2, 0, 0, 0, 1, 0, 0, 0, 0, 0, 2, 0]],\n", | |
| " dtype=torch.int32)" | |
| ] | |
| }, | |
| "execution_count": 36, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "trigram_tensor[336:338,:]" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 8, | |
| "id": "12f3a432-6470-411d-8ffc-3691e29749f8", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "import matplotlib.pyplot as plt\n", | |
| "%matplotlib inline" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 94, | |
| "id": "d882349b-272b-470e-b686-4457e4157257", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "prob_trigram_tensor = (trigram_tensor + 1).float()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 95, | |
| "id": "389c2cba-2596-4768-9b26-4378a6d020c4", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "tensor([[204., 338., 1., 1., 1., 332., 1., 1., 1., 272., 1., 1.,\n", | |
| " 1., 2., 1., 59., 1., 1., 1., 9., 1., 11., 1., 2.,\n", | |
| " 1., 125., 2.],\n", | |
| " [ 7., 33., 1., 1., 1., 8., 1., 1., 1., 11., 1., 1.,\n", | |
| " 1., 1., 1., 3., 1., 1., 1., 2., 1., 1., 1., 1.,\n", | |
| " 1., 3., 1.]])" | |
| ] | |
| }, | |
| "execution_count": 95, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "prob_trigram_tensor[336:338,:]" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 96, | |
| "id": "0dc87947-69f7-4b62-8bf5-413fe3a9aca1", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "prob_trigram_tensor /= prob_trigram_tensor.sum(1, keepdims=True)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 97, | |
| "id": "d127fd7f-1211-43ae-a7f4-715b32b9f99a", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "tensor([[0.1487, 0.2464, 0.0007, 0.0007, 0.0007, 0.2420, 0.0007, 0.0007, 0.0007,\n", | |
| " 0.1983, 0.0007, 0.0007, 0.0007, 0.0015, 0.0007, 0.0430, 0.0007, 0.0007,\n", | |
| " 0.0007, 0.0066, 0.0007, 0.0080, 0.0007, 0.0015, 0.0007, 0.0911, 0.0015],\n", | |
| " [0.0805, 0.3793, 0.0115, 0.0115, 0.0115, 0.0920, 0.0115, 0.0115, 0.0115,\n", | |
| " 0.1264, 0.0115, 0.0115, 0.0115, 0.0115, 0.0115, 0.0345, 0.0115, 0.0115,\n", | |
| " 0.0115, 0.0230, 0.0115, 0.0115, 0.0115, 0.0115, 0.0115, 0.0345, 0.0115]])" | |
| ] | |
| }, | |
| "execution_count": 97, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "prob_trigram_tensor[336:338, :]" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 98, | |
| "id": "e076a9b5-407c-4e41-a1dd-70266e9208fc", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "random_trigram_tensor = torch.ones(27 * 27, 27).float()\n", | |
| "random_trigram_tensor /= random_trigram_tensor.sum(1, keepdims=True)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 103, | |
| "id": "7615ee06-5203-4bc9-8cd8-a541fcc70193", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "ce.\n", | |
| "za.\n", | |
| "zogh.\n", | |
| "uriana.\n", | |
| "kaydnevonimittain.\n", | |
| "luwak.\n", | |
| "ka.\n", | |
| "da.\n", | |
| "samiyah.\n", | |
| "javer.\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "g = torch.Generator().manual_seed(2147483647)\n", | |
| "for i in range(10):\n", | |
| " out = []\n", | |
| " ix = 0\n", | |
| " while True:\n", | |
| " ich1 = int(ix / 27)\n", | |
| " ich2 = int(ix % 27)\n", | |
| " ch1 = itos[ich1]\n", | |
| " ch2 = itos[ich2]\n", | |
| " # print(f\"Probability distribution for '{ch1}{ch2}': {prob_trigram_tensor[ix, :]}\")\n", | |
| " p = prob_trigram_tensor[ix]\n", | |
| " ix = torch.multinomial(p, num_samples=1, replacement=True, generator=g).item()\n", | |
| " out.append(itos[ix])\n", | |
| " if ix % 27 == 0:\n", | |
| " break\n", | |
| " ix = (ich2 * 27) + ix\n", | |
| " print(''.join(out))" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 104, | |
| "id": "96af7103-7742-4160-9a28-320e9cc9e794", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "log_likelihood=tensor(-410414.9688)\n", | |
| "nll=tensor(410414.9688)\n", | |
| "2.092747449874878\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "log_likelihood = 0.0\n", | |
| "n = 0\n", | |
| "\n", | |
| "for w in names:\n", | |
| "# for w in [\"andrejq\"]:\n", | |
| " chs = ['.'] + list(w) + ['.']\n", | |
| " for ch1, ch2, ch3 in zip(chs, chs[1:], chs[2:]):\n", | |
| " ix1 = stoi[ch1]\n", | |
| " ix2 = stoi[ch2]\n", | |
| " ix3 = stoi[ch3]\n", | |
| " prob = prob_trigram_tensor[(ix1 * 27) + ix2, ix3]\n", | |
| " logprob = torch.log(prob)\n", | |
| " log_likelihood += logprob\n", | |
| " n += 1\n", | |
| " # print(f'{ch1}{ch2}: {prob:.4f} {logprob:.4f}')\n", | |
| "\n", | |
| "print(f'{log_likelihood=}')\n", | |
| "nll = -log_likelihood\n", | |
| "print(f'{nll=}')\n", | |
| "print(f'{nll/n}')" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 105, | |
| "id": "c2cda117-8d12-41a1-b179-e876db1ccebb", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "number of examples: 196113\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "# create the dataset\n", | |
| "xs, ys = [], []\n", | |
| "for w in names:\n", | |
| " chs = ['.'] + list(w) + ['.']\n", | |
| " for ch1, ch2, ch3 in zip(chs, chs[1:], chs[2:]):\n", | |
| " ix1 = stoi[ch1]\n", | |
| " ix2 = stoi[ch2]\n", | |
| " ix3 = stoi[ch3]\n", | |
| " xs.append((ix1 * 27) + ix2)\n", | |
| " ys.append(ix3)\n", | |
| "xs = torch.tensor(xs)\n", | |
| "ys = torch.tensor(ys)\n", | |
| "num = xs.nelement()\n", | |
| "print('number of examples: ', num)\n", | |
| "\n", | |
| "# initialize the 'network'\n", | |
| "g = torch.Generator().manual_seed(2147483647)\n", | |
| "W = torch.randn((27 * 27, 27), generator=g, requires_grad=True)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 111, | |
| "id": "3aa90b13-011f-42b9-9373-769601997158", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "import torch.nn.functional as F\n", | |
| "# gradient descent\n", | |
| "for k in range(1000):\n", | |
| " \n", | |
| " # forward pass\n", | |
| " xenc = F.one_hot(xs, num_classes=27*27).float() # input to the network: one-hot encoding\n", | |
| " logits = xenc @ W # predict log-counts\n", | |
| " counts = logits.exp() # counts, equivalent to N\n", | |
| " probs = counts / counts.sum(1, keepdims=True) # probabilities for next character\n", | |
| " loss = -probs[torch.arange(num), ys].log().mean() # + 0.01*(W**2).mean()\n", | |
| " # print(loss.item())\n", | |
| " \n", | |
| " # backward pass\n", | |
| " W.grad = None # set to zero the gradient\n", | |
| " loss.backward()\n", | |
| " \n", | |
| " # update\n", | |
| " W.data += -50 * W.grad" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 112, | |
| "id": "8f470d40-9d50-4d81-af67-b2532c12fd82", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "2.0905237197875977\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "print(loss.item())" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 113, | |
| "id": "39f675a0-bcd7-4de8-8afa-58e15b787892", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "zexza.\n", | |
| "zoganuriana.\n", | |
| "otah.\n", | |
| "oll.\n", | |
| "imittain.\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "# finally, sample from the 'neural net' model\n", | |
| "g = torch.Generator().manual_seed(2147483647)\n", | |
| "\n", | |
| "for i in range(5):\n", | |
| " \n", | |
| " out = []\n", | |
| " ix = 0\n", | |
| " while True:\n", | |
| " \n", | |
| " # ----------\n", | |
| " # BEFORE:\n", | |
| " #p = P[ix]\n", | |
| " # ----------\n", | |
| " # NOW:\n", | |
| " xenc = F.one_hot(torch.tensor([ix]), num_classes=27 * 27).float()\n", | |
| " logits = xenc @ W # predict log-counts\n", | |
| " counts = logits.exp() # counts, equivalent to N\n", | |
| " p = counts / counts.sum(1, keepdims=True) # probabilities for next character\n", | |
| " # ----------\n", | |
| " ich1 = int(ix / 27)\n", | |
| " ich2 = ix % 27\n", | |
| " ch1 = itos[ich1]\n", | |
| " ch2 = itos[ich2]\n", | |
| " # print(f\"Probability distribution for '{ch1}{ch2}': {p}\")\n", | |
| " ix = torch.multinomial(p, num_samples=1, replacement=True, generator=g).item()\n", | |
| " out.append(itos[ix])\n", | |
| " if ix % 27 == 0:\n", | |
| " break\n", | |
| " ix = (ich2 * 27) + ix\n", | |
| " print(''.join(out))" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "id": "790a81f7-2c9b-4638-bde8-e44729d3bec0", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [] | |
| } | |
| ], | |
| "metadata": { | |
| "kernelspec": { | |
| "display_name": "Python 3 (ipykernel)", | |
| "language": "python", | |
| "name": "python3" | |
| }, | |
| "language_info": { | |
| "codemirror_mode": { | |
| "name": "ipython", | |
| "version": 3 | |
| }, | |
| "file_extension": ".py", | |
| "mimetype": "text/x-python", | |
| "name": "python", | |
| "nbconvert_exporter": "python", | |
| "pygments_lexer": "ipython3", | |
| "version": "3.12.4" | |
| } | |
| }, | |
| "nbformat": 4, | |
| "nbformat_minor": 5 | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment