Last active
February 17, 2022 17:39
-
-
Save pmineiro/171edfa6963b7d14e6f3d10dc38af9a4 to your computer and use it in GitHub Desktop.
IGL with action-dependent feedback: MNIST demo
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| { | |
| "cells": [ | |
| { | |
| "cell_type": "markdown", | |
| "id": "287318ac", | |
| "metadata": {}, | |
| "source": [ | |
| "# Supervised" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 70, | |
| "id": "7fb42e91", | |
| "metadata": { | |
| "code_folding": [ | |
| 0, | |
| 6, | |
| 35, | |
| 36, | |
| 58, | |
| 63, | |
| 70, | |
| 84, | |
| 114, | |
| 121 | |
| ], | |
| "scrolled": false | |
| }, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "n \tmean \tsince \tacc \tsince \n", | |
| "1 \t2.30257 \t2.30257 \t0.15625 \t0.15625 \n", | |
| "2 \t2.40082 \t2.49907 \t0.21875 \t0.28125 \n", | |
| "3 \t2.21948 \t1.85680 \t0.29167 \t0.43750 \n", | |
| "5 \t1.95383 \t1.55536 \t0.40000 \t0.56250 \n", | |
| "9 \t1.56551 \t1.08011 \t0.50694 \t0.64062 \n", | |
| "17 \t1.18324 \t0.75318 \t0.63419 \t0.77734 \n", | |
| "33 \t0.87001 \t0.53720 \t0.73059 \t0.83301 \n", | |
| "65 \t0.65744 \t0.43823 \t0.79615 \t0.86377 \n", | |
| "129 \t0.49248 \t0.32494 \t0.84726 \t0.89917 \n", | |
| "257 \t0.39149 \t0.28972 \t0.87840 \t0.90979 \n", | |
| "513 \t0.31768 \t0.24357 \t0.90244 \t0.92657 \n", | |
| "938 \t0.27401 \t0.22131 \t0.91613 \t0.93265 \n", | |
| "testacc 0.9558735489845276 testloss 0.14295095205307007\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "def supervisedLearn():\n", | |
| " import itertools\n", | |
| " import numpy\n", | |
| " import torch\n", | |
| " import torchvision\n", | |
| " \n", | |
| " class EasyAcc:\n", | |
| " def __init__(self):\n", | |
| " self.n = 0\n", | |
| " self.sum = 0\n", | |
| " self.sumsq = 0\n", | |
| "\n", | |
| " def __iadd__(self, other):\n", | |
| " self.n += 1\n", | |
| " self.sum += other\n", | |
| " self.sumsq += other*other\n", | |
| " return self\n", | |
| "\n", | |
| " def __isub__(self, other):\n", | |
| " self.n += 1\n", | |
| " self.sum -= other\n", | |
| " self.sumsq += other*other\n", | |
| " return self\n", | |
| "\n", | |
| " def mean(self):\n", | |
| " return self.sum / max(self.n, 1)\n", | |
| "\n", | |
| " def var(self):\n", | |
| " from math import sqrt\n", | |
| " return sqrt(self.sumsq / max(self.n, 1) - self.mean()**2)\n", | |
| "\n", | |
| " def semean(self):\n", | |
| " from math import sqrt\n", | |
| " return self.var() / sqrt(max(self.n, 1))\n", | |
| "\n", | |
| " class RFFSoftmax(torch.nn.Module):\n", | |
| " def __init__(self, hilo, naction, numrff, sigma, seed):\n", | |
| " from math import pi\n", | |
| " import numpy as np\n", | |
| "\n", | |
| " super(RFFSoftmax, self).__init__()\n", | |
| "\n", | |
| " torch.manual_seed(seed)\n", | |
| " nobs = hilo.shape[1]\n", | |
| " high = hilo[1, :]\n", | |
| " low = hilo[0, :]\n", | |
| " \n", | |
| " self.rff = torch.nn.Linear(nobs, numrff)\n", | |
| " self.rff.weight.data = torch.matmul(torch.empty(numrff, nobs).cauchy_(sigma = sigma), \n", | |
| " torch.diag(torch.tensor([ 1.0/v if v > 1e-6 else 0. for v in high - low ])).float())\n", | |
| " self.rff.weight.requires_grad = False\n", | |
| " self.rff.bias.data = 2 * pi * torch.rand(numrff)\n", | |
| " self.rff.bias.requires_grad = False\n", | |
| " self.sqrtrff = np.sqrt(numrff)\n", | |
| " self.final = torch.nn.Linear(numrff, naction)\n", | |
| " self.final.weight.data *= 0.01\n", | |
| " self.final.bias.data *= 0.01\n", | |
| "\n", | |
| " def logits(self, x):\n", | |
| " with torch.no_grad():\n", | |
| " rff = self.rff(x).cos() / self.sqrtrff\n", | |
| " return self.final(rff)\n", | |
| "\n", | |
| " transform = torchvision.transforms.Compose([\n", | |
| " torchvision.transforms.ToTensor(),\n", | |
| " torchvision.transforms.Normalize((0.1307,), (0.3081,))\n", | |
| " ])\n", | |
| " mnist_train = torchvision.datasets.MNIST('/tmp/mnist', train=True, download=True, transform=transform)\n", | |
| " \n", | |
| " quantile_loader = torch.utils.data.DataLoader(mnist_train, batch_size=10000, shuffle=True)\n", | |
| " for bno, (images, labels) in enumerate(quantile_loader):\n", | |
| " flat = images.reshape(images.shape[0], -1)\n", | |
| " hilo = numpy.quantile(flat.numpy(), [ 0.01, 0.99 ], axis=0)\n", | |
| " pi = RFFSoftmax(hilo, 10, 2000, 0.01, 45)\n", | |
| " break\n", | |
| " \n", | |
| " train_loader = torch.utils.data.DataLoader(mnist_train, batch_size=64, shuffle=True)\n", | |
| " mnist_test = torchvision.datasets.MNIST('/tmp/mnist', train=False, download=True, transform=transform)\n", | |
| " test_loader = torch.utils.data.DataLoader(mnist_test, batch_size=1000, shuffle=True)\n", | |
| " \n", | |
| " opt = torch.optim.Adam(( p for p in pi.parameters() if p.requires_grad ), lr=0.1)\n", | |
| " loss = torch.nn.CrossEntropyLoss()\n", | |
| " acc, accsincelast, avloss, avlosssincelast = [ EasyAcc() for _ in range(4) ]\n", | |
| " \n", | |
| " print('{:<5s}\\t{:<8s}\\t{:<8s}\\t{:<8s}\\t{:<8s}'.format(\n", | |
| " 'n', 'mean', 'since',\n", | |
| " 'acc', 'since',\n", | |
| " ),\n", | |
| " flush=True)\n", | |
| " \n", | |
| " for bno, (images, labels) in enumerate(itertools.chain(*[ train_loader for _ in range(1) ])):\n", | |
| " flat = images.reshape(images.shape[0], -1)\n", | |
| " \n", | |
| " opt.zero_grad()\n", | |
| " ld = pi.logits(flat)\n", | |
| " output = loss(ld, labels)\n", | |
| " output.backward()\n", | |
| " opt.step()\n", | |
| " \n", | |
| " with torch.no_grad():\n", | |
| " pred = ld.argmax(dim=1)\n", | |
| " acc += torch.mean((labels == pred).float())\n", | |
| " accsincelast += torch.mean((labels == pred).float())\n", | |
| " avloss += output\n", | |
| " avlosssincelast += output\n", | |
| "\n", | |
| " if (bno & (bno - 1) == 0):\n", | |
| " print('{:<5d}\\t{:<8.5f}\\t{:<8.5f}\\t{:<8.5f}\\t{:<8.5f}'.format(\n", | |
| " avloss.n, avloss.mean(), avlosssincelast.mean(), \n", | |
| " acc.mean(), accsincelast.mean(), \n", | |
| " ),\n", | |
| " flush=True)\n", | |
| " accsincelast, avlosssincelast = EasyAcc(), EasyAcc()\n", | |
| " \n", | |
| " print('{:<5d}\\t{:<8.5f}\\t{:<8.5f}\\t{:<8.5f}\\t{:<8.5f}'.format(\n", | |
| " avloss.n, avloss.mean(), avlosssincelast.mean(), \n", | |
| " acc.mean(), accsincelast.mean(), \n", | |
| " ),\n", | |
| " flush=True)\n", | |
| " accsincelast, avlosssincelast = EasyAcc(), EasyAcc()\n", | |
| " testacc, testloss = EasyAcc(), EasyAcc()\n", | |
| " with torch.no_grad():\n", | |
| " for ti, tl in train_loader:\n", | |
| " flat = ti.reshape(ti.shape[0], -1)\n", | |
| " ld = pi.logits(flat)\n", | |
| " output = loss(ld, tl)\n", | |
| " testloss += output\n", | |
| " testpred = ld.argmax(dim=1)\n", | |
| " testacc += torch.mean((tl == testpred).float())\n", | |
| "\n", | |
| " print(f'testacc {testacc.mean()} testloss {testloss.mean()}')\n", | |
| " \n", | |
| "supervisedLearn()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "55cda05e", | |
| "metadata": {}, | |
| "source": [ | |
| "# Contextual bandit (CB)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 80, | |
| "id": "de5fa0ff", | |
| "metadata": { | |
| "code_folding": [ | |
| 0, | |
| 6, | |
| 7, | |
| 10, | |
| 20, | |
| 49, | |
| 50, | |
| 73, | |
| 78 | |
| ] | |
| }, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "n \tloss \tsince \tacc \tsince \treward \tsince \n", | |
| "1 \t0.69313 \t0.69313 \t0.15625 \t0.15625 \t0.06250 \t0.06250 \n", | |
| "2 \t0.55508 \t0.41703 \t0.10156 \t0.04688 \t0.05469 \t0.04688 \n", | |
| "3 \t0.46854 \t0.29545 \t0.09896 \t0.09375 \t0.05729 \t0.06250 \n", | |
| "5 \t0.46672 \t0.46399 \t0.11563 \t0.14062 \t0.10000 \t0.16406 \n", | |
| "9 \t0.42025 \t0.36216 \t0.18229 \t0.26562 \t0.16493 \t0.24609 \n", | |
| "17 \t0.37235 \t0.31846 \t0.22151 \t0.26562 \t0.20956 \t0.25977 \n", | |
| "33 \t0.34175 \t0.30924 \t0.27415 \t0.33008 \t0.26089 \t0.31543 \n", | |
| "65 \t0.34546 \t0.34929 \t0.38726 \t0.50391 \t0.36250 \t0.46729 \n", | |
| "129 \t0.31504 \t0.28415 \t0.61216 \t0.84058 \t0.56468 \t0.77002 \n", | |
| "257 \t0.26828 \t0.22116 \t0.75298 \t0.89490 \t0.69127 \t0.81885 \n", | |
| "513 \t0.22252 \t0.17659 \t0.83458 \t0.91650 \t0.76459 \t0.83820 \n", | |
| "938 \t0.19493 \t0.16162 \t0.87785 \t0.93007 \t0.80360 \t0.85070 \n", | |
| "testacc 0.9445962309837341\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "def cbLearn():\n", | |
| " import itertools\n", | |
| " import numpy\n", | |
| " import torch\n", | |
| " import torchvision\n", | |
| " \n", | |
| " class FastCB:\n", | |
| " def __init__(self, gamma):\n", | |
| " self.gamma = gamma\n", | |
| "\n", | |
| " def sample(self, fhat):\n", | |
| " N, K = fhat.shape\n", | |
| " rando = torch.randint(high=K, size=(N, 1), device=fhat.device)\n", | |
| " fhatstar, ahatstar = torch.max(fhat, dim=1, keepdim=True)\n", | |
| " fhatrando = torch.gather(input=fhat, dim=1, index=rando)\n", | |
| " probs = K / (K + self.gamma * (1 - fhatrando / fhatstar))\n", | |
| " unif = torch.rand(size=(N, 1), device=fhat.device)\n", | |
| " shouldexplore = (unif <= probs).long()\n", | |
| " return (ahatstar + shouldexplore * (rando - ahatstar)).squeeze(1)\n", | |
| "\n", | |
| " class EasyAcc:\n", | |
| " def __init__(self):\n", | |
| " self.n = 0\n", | |
| " self.sum = 0\n", | |
| " self.sumsq = 0\n", | |
| "\n", | |
| " def __iadd__(self, other):\n", | |
| " self.n += 1\n", | |
| " self.sum += other\n", | |
| " self.sumsq += other*other\n", | |
| " return self\n", | |
| "\n", | |
| " def __isub__(self, other):\n", | |
| " self.n += 1\n", | |
| " self.sum -= other\n", | |
| " self.sumsq += other*other\n", | |
| " return self\n", | |
| "\n", | |
| " def mean(self):\n", | |
| " return self.sum / max(self.n, 1)\n", | |
| "\n", | |
| " def var(self):\n", | |
| " from math import sqrt\n", | |
| " return sqrt(self.sumsq / max(self.n, 1) - self.mean()**2)\n", | |
| "\n", | |
| " def semean(self):\n", | |
| " from math import sqrt\n", | |
| " return self.var() / sqrt(max(self.n, 1))\n", | |
| "\n", | |
| " class RFFSoftmax(torch.nn.Module):\n", | |
| " def __init__(self, hilo, naction, numrff, sigma, seed):\n", | |
| " from math import pi\n", | |
| " import numpy as np\n", | |
| "\n", | |
| " super(RFFSoftmax, self).__init__()\n", | |
| "\n", | |
| " torch.manual_seed(seed)\n", | |
| " nobs = hilo.shape[1]\n", | |
| " high = hilo[1, :]\n", | |
| " low = hilo[0, :]\n", | |
| " \n", | |
| " self.rff = torch.nn.Linear(nobs, numrff)\n", | |
| " self.rff.weight.data = torch.matmul(torch.empty(numrff, nobs).cauchy_(sigma = sigma), \n", | |
| " torch.diag(torch.tensor([ 1.0/v if v > 1e-6 else 0. for v in high - low ])).float())\n", | |
| " self.rff.weight.requires_grad = False\n", | |
| " self.rff.bias.data = 2 * pi * torch.rand(numrff)\n", | |
| " self.rff.bias.requires_grad = False\n", | |
| " self.sqrtrff = np.sqrt(numrff)\n", | |
| " self.final = torch.nn.Linear(numrff, naction)\n", | |
| " self.final.weight.data *= 0.01\n", | |
| " self.final.bias.data *= 0.01\n", | |
| " self.sigmoid = torch.nn.Sigmoid()\n", | |
| "\n", | |
| " def forward(self, x):\n", | |
| " with torch.no_grad():\n", | |
| " rff = self.rff(x).cos() / self.sqrtrff\n", | |
| " return self.final(rff)\n", | |
| " \n", | |
| " def density(self, logits):\n", | |
| " return self.sigmoid(logits)\n", | |
| "\n", | |
| " transform = torchvision.transforms.Compose([\n", | |
| " torchvision.transforms.ToTensor(),\n", | |
| " torchvision.transforms.Normalize((0.1307,), (0.3081,))\n", | |
| " ])\n", | |
| " mnist_train = torchvision.datasets.MNIST('/tmp/mnist', train=True, download=True, transform=transform)\n", | |
| " \n", | |
| " quantile_loader = torch.utils.data.DataLoader(mnist_train, batch_size=10000, shuffle=True)\n", | |
| " for bno, (images, labels) in enumerate(quantile_loader):\n", | |
| " flat = images.reshape(images.shape[0], -1)\n", | |
| " hilo = numpy.quantile(flat.numpy(), [ 0.01, 0.99 ], axis=0)\n", | |
| " pi = RFFSoftmax(hilo, 10, 2000, 0.01, 45)\n", | |
| " break\n", | |
| " \n", | |
| " train_loader = torch.utils.data.DataLoader(mnist_train, batch_size=64, shuffle=True)\n", | |
| " mnist_test = torchvision.datasets.MNIST('/tmp/mnist', train=False, download=True, transform=transform)\n", | |
| " test_loader = torch.utils.data.DataLoader(mnist_test, batch_size=1000, shuffle=True)\n", | |
| " sampler = FastCB(gamma=100)\n", | |
| " \n", | |
| " opt = torch.optim.Adam(( p for p in pi.parameters() if p.requires_grad ), lr=1e-1)\n", | |
| " log_loss = torch.nn.BCEWithLogitsLoss()\n", | |
| " acc, accsincelast, avloss, avlosssincelast, avreward, avrewardsincelast = [ EasyAcc() for _ in range(6) ]\n", | |
| " \n", | |
| " print('{:<5s}\\t{:<8s}\\t{:<8s}\\t{:<8s}\\t{:<8s}\\t{:<8s}\\t{:<8s}'.format(\n", | |
| " 'n', 'loss', 'since', \n", | |
| " 'acc', 'since',\n", | |
| " 'reward', 'since',\n", | |
| " ),\n", | |
| " flush=True)\n", | |
| " \n", | |
| " for bno, (images, labels) in enumerate(itertools.chain(*[ train_loader for _ in range(1) ])):\n", | |
| " flat = images.reshape(images.shape[0], -1)\n", | |
| " \n", | |
| " opt.zero_grad()\n", | |
| " logit = pi(flat)\n", | |
| " with torch.no_grad():\n", | |
| " fhat = pi.density(logit)\n", | |
| " sample = sampler.sample(fhat)\n", | |
| " reward = (sample == labels).unsqueeze(1).float()\n", | |
| " \n", | |
| " samplelogit = torch.gather(input=logit, index=sample.unsqueeze(1), dim=1)\n", | |
| " loss = log_loss(samplelogit, reward)\n", | |
| " loss.backward()\n", | |
| " opt.step()\n", | |
| " \n", | |
| " with torch.no_grad():\n", | |
| " pred = logit.argmax(dim=1)\n", | |
| " acc += torch.mean((labels == pred).float())\n", | |
| " accsincelast += torch.mean((labels == pred).float())\n", | |
| " avloss += loss\n", | |
| " avlosssincelast += loss\n", | |
| " avreward += torch.mean(reward)\n", | |
| " avrewardsincelast += torch.mean(reward)\n", | |
| " \n", | |
| " if (bno & (bno - 1) == 0):\n", | |
| " print('{:<5d}\\t{:<8.5f}\\t{:<8.5f}\\t{:<8.5f}\\t{:<8.5f}\\t{:<8.5f}\\t{:<8.5f}'.format(\n", | |
| " avloss.n, avloss.mean(), avlosssincelast.mean(), \n", | |
| " acc.mean(), accsincelast.mean(), \n", | |
| " avreward.mean(), avrewardsincelast.mean(),\n", | |
| " ),\n", | |
| " flush=True)\n", | |
| " accsincelast, avlosssincelast, avrewardsincelast = EasyAcc(), EasyAcc(), EasyAcc()\n", | |
| " \n", | |
| " print('{:<5d}\\t{:<8.5f}\\t{:<8.5f}\\t{:<8.5f}\\t{:<8.5f}\\t{:<8.5f}\\t{:<8.5f}'.format(\n", | |
| " avloss.n, avloss.mean(), avlosssincelast.mean(), \n", | |
| " acc.mean(), accsincelast.mean(), \n", | |
| " avreward.mean(), avrewardsincelast.mean(),\n", | |
| " ),\n", | |
| " flush=True)\n", | |
| " accsincelast, avlosssincelast, avrewardsincelast = EasyAcc(), EasyAcc(), EasyAcc()\n", | |
| " testacc = EasyAcc()\n", | |
| " with torch.no_grad():\n", | |
| " for ti, tl in train_loader:\n", | |
| " flat = ti.reshape(ti.shape[0], -1)\n", | |
| " logit = pi(flat)\n", | |
| " testpred = logit.argmax(dim=1)\n", | |
| " testacc += torch.mean((tl == testpred).float())\n", | |
| "\n", | |
| " print(f'testacc {testacc.mean()}')\n", | |
| "\n", | |
| "cbLearn()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "09186826", | |
| "metadata": {}, | |
| "source": [ | |
| "# IGL ($y_a \\perp x, a|r_a$)\n", | |
| "$y_a$ is a randomly selected \"zero\" image when $r_a = 0$ and a randomly selected \"one\" image when $r_a = 1$, so it depends only on $r_a$." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 113, | |
| "id": "67459e03", | |
| "metadata": { | |
| "code_folding": [ | |
| 0, | |
| 6, | |
| 22, | |
| 51, | |
| 83, | |
| 90, | |
| 100, | |
| 113, | |
| 139, | |
| 188, | |
| 197 | |
| ] | |
| }, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "n \tloss \tsince \tacc \tsince \treward \tsince \tfake \tsince \n", | |
| "1 \t1.38627 \t1.38627 \t0.32812 \t0.32812 \t0.09375 \t0.09375 \t0.49997 \t0.49997 \n", | |
| "2 \t1.37295 \t1.35963 \t0.18750 \t0.04688 \t0.07812 \t0.06250 \t0.48236 \t0.46476 \n", | |
| "3 \t1.35572 \t1.32125 \t0.17188 \t0.14062 \t0.08333 \t0.09375 \t0.46612 \t0.43364 \n", | |
| "5 \t1.33542 \t1.30497 \t0.14687 \t0.10938 \t0.09375 \t0.10938 \t0.43557 \t0.38975 \n", | |
| "9 \t1.28511 \t1.22223 \t0.14583 \t0.14453 \t0.11111 \t0.13281 \t0.38302 \t0.31733 \n", | |
| "17 \t1.19374 \t1.09095 \t0.15533 \t0.16602 \t0.11949 \t0.12891 \t0.31242 \t0.23300 \n", | |
| "33 \t1.19390 \t1.19407 \t0.16714 \t0.17969 \t0.12689 \t0.13477 \t0.26854 \t0.22192 \n", | |
| "65 \t1.25434 \t1.31667 \t0.25024 \t0.33594 \t0.18438 \t0.24365 \t0.43038 \t0.59728 \n", | |
| "129 \t1.12546 \t0.99457 \t0.45094 \t0.65479 \t0.33285 \t0.48364 \t0.54036 \t0.65205 \n", | |
| "257 \t0.88065 \t0.63392 \t0.65394 \t0.85852 \t0.51283 \t0.69421 \t0.65974 \t0.78006 \n", | |
| "513 \t0.68728 \t0.49315 \t0.77656 \t0.89966 \t0.63411 \t0.75586 \t0.74226 \t0.82510 \n", | |
| "938 \t0.56352 \t0.41413 \t0.84097 \t0.91871 \t0.70281 \t0.78574 \t0.78930 \t0.84607 \n", | |
| "testacc 0.928521454334259\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "def iglLearn():\n", | |
| " import itertools\n", | |
| " import numpy\n", | |
| " import torch\n", | |
| " import torchvision\n", | |
| " \n", | |
| " class SquareCB(object):\n", | |
| " def __init__(self, gamma):\n", | |
| " super(SquareCB, self).__init__()\n", | |
| "\n", | |
| " self.gamma = gamma\n", | |
| "\n", | |
| " def sample(self, fhat):\n", | |
| " N, K = fhat.shape\n", | |
| " rando = torch.randint(high=K, size=(N, 1), device=fhat.device)\n", | |
| " fhatstar, ahatstar = torch.max(fhat, dim=1, keepdim=True)\n", | |
| " fhatrando = torch.gather(input=fhat, dim=1, index=rando)\n", | |
| " probs = K / (K + self.gamma * (fhatstar - fhatrando))\n", | |
| " unif = torch.rand(size=(N, 1), device=fhat.device)\n", | |
| " shouldexplore = (unif <= probs).long()\n", | |
| " return (ahatstar + shouldexplore * (rando - ahatstar)).squeeze(1)\n", | |
| " \n", | |
| " class EasyAcc:\n", | |
| " def __init__(self):\n", | |
| " self.n = 0\n", | |
| " self.sum = 0\n", | |
| " self.sumsq = 0\n", | |
| "\n", | |
| " def __iadd__(self, other):\n", | |
| " self.n += 1\n", | |
| " self.sum += other\n", | |
| " self.sumsq += other*other\n", | |
| " return self\n", | |
| "\n", | |
| " def __isub__(self, other):\n", | |
| " self.n += 1\n", | |
| " self.sum -= other\n", | |
| " self.sumsq += other*other\n", | |
| " return self\n", | |
| "\n", | |
| " def mean(self):\n", | |
| " return self.sum / max(self.n, 1)\n", | |
| "\n", | |
| " def var(self):\n", | |
| " from math import sqrt\n", | |
| " return sqrt(self.sumsq / max(self.n, 1) - self.mean()**2)\n", | |
| "\n", | |
| " def semean(self):\n", | |
| " from math import sqrt\n", | |
| " return self.var() / sqrt(max(self.n, 1))\n", | |
| "\n", | |
| " class RFFSoftmax(torch.nn.Module):\n", | |
| " def __init__(self, hilo, naction, numrff, sigma, seed):\n", | |
| " from math import pi\n", | |
| " import numpy as np\n", | |
| "\n", | |
| " super(RFFSoftmax, self).__init__()\n", | |
| "\n", | |
| " torch.manual_seed(seed)\n", | |
| " nobs = hilo.shape[1]\n", | |
| " high = hilo[1, :]\n", | |
| " low = hilo[0, :]\n", | |
| " \n", | |
| " self.rff = torch.nn.Linear(nobs, numrff)\n", | |
| " self.rff.weight.data = torch.matmul(torch.empty(numrff, nobs).cauchy_(sigma = sigma), \n", | |
| " torch.diag(torch.tensor([ 1.0/v if v > 1e-6 else 0. for v in high - low ])).float())\n", | |
| " self.rff.weight.requires_grad = False\n", | |
| " self.rff.bias.data = 2 * pi * torch.rand(numrff)\n", | |
| " self.rff.bias.requires_grad = False\n", | |
| " self.sqrtrff = np.sqrt(numrff)\n", | |
| " self.final = torch.nn.Linear(numrff, naction)\n", | |
| " self.final.weight.data *= 0.01\n", | |
| " self.final.bias.data *= 0.01\n", | |
| " self.sigmoid = torch.nn.Sigmoid()\n", | |
| "\n", | |
| " def forward(self, x):\n", | |
| " with torch.no_grad():\n", | |
| " rff = self.rff(x).cos() / self.sqrtrff\n", | |
| " return self.final(rff)\n", | |
| " \n", | |
| " def density(self, logits):\n", | |
| " return self.sigmoid(logits)\n", | |
| "\n", | |
| " transform = torchvision.transforms.Compose([\n", | |
| " torchvision.transforms.ToTensor(),\n", | |
| " torchvision.transforms.Normalize((0.1307,), (0.3081,))\n", | |
| " ])\n", | |
| " mnist_train = torchvision.datasets.MNIST('/tmp/mnist', train=True, download=True, transform=transform)\n", | |
| " \n", | |
| " quantile_loader = torch.utils.data.DataLoader(mnist_train, batch_size=10000, shuffle=True)\n", | |
| " for bno, (images, labels) in enumerate(quantile_loader):\n", | |
| " flat = images.reshape(images.shape[0], -1)\n", | |
| " hilo = numpy.quantile(flat.numpy(), [ 0.01, 0.99 ], axis=0)\n", | |
| " pi = RFFSoftmax(hilo, 10, 2000, 0.01, 45)\n", | |
| " decoder = RFFSoftmax(hilo, 1, 2000, 0.01, 2112)\n", | |
| " break\n", | |
| " \n", | |
| " zero_one_loader = torch.utils.data.DataLoader(mnist_train, batch_size=1, shuffle=True)\n", | |
| " zeros = []\n", | |
| " ones = []\n", | |
| " for bno, (images, labels) in enumerate(zero_one_loader):\n", | |
| " flat = images.reshape(images.shape[0], -1)\n", | |
| " if labels[0] == 0:\n", | |
| " zeros.append(flat)\n", | |
| " elif labels[0] == 1:\n", | |
| " ones.append(flat)\n", | |
| " \n", | |
| " if len(zeros) > 100 and len(ones) > 100:\n", | |
| " break \n", | |
| " zeros = torch.cat(zeros, dim=0)\n", | |
| " ones = torch.cat(ones, dim=0)\n", | |
| " \n", | |
| " # pre-train to get policy \"better than random\"\n", | |
| " if True:\n", | |
| " preopt = torch.optim.Adam(( p for p in pi.parameters() if p.requires_grad ), lr=1e-3) # 0.1\n", | |
| " preloss = torch.nn.CrossEntropyLoss()\n", | |
| " pretrain_loader = torch.utils.data.DataLoader(mnist_train, batch_size=64, shuffle=True)\n", | |
| " for bno, (images, labels) in enumerate(itertools.chain(*[ pretrain_loader for _ in range(1) ])):\n", | |
| " flat = images.reshape(images.shape[0], -1)\n", | |
| "\n", | |
| " preopt.zero_grad()\n", | |
| " ld = pi.forward(flat)\n", | |
| " output = preloss(ld, labels)\n", | |
| " output.backward()\n", | |
| " preopt.step()\n", | |
| "\n", | |
| " if bno > 0:\n", | |
| " break\n", | |
| " \n", | |
| " train_loader = torch.utils.data.DataLoader(mnist_train, batch_size=64, shuffle=True)\n", | |
| " mnist_test = torchvision.datasets.MNIST('/tmp/mnist', train=False, download=True, transform=transform)\n", | |
| " test_loader = torch.utils.data.DataLoader(mnist_test, batch_size=1000, shuffle=True)\n", | |
| " \n", | |
| " opt = torch.optim.Adam(( p for p in itertools.chain(pi.parameters(), decoder.parameters()) if p.requires_grad ), lr=1e-2)\n", | |
| " log_loss = torch.nn.BCEWithLogitsLoss(reduce='none')\n", | |
| " sampler = SquareCB(gamma=100)\n", | |
| " acc, accsincelast, avloss, avlosssincelast = [ EasyAcc() for _ in range(4) ]\n", | |
| " avreward, avrewardsincelast, avfake, avfakesincelast = [ EasyAcc() for _ in range(4) ]\n", | |
| " \n", | |
| " print('{:<5s}\\t{:<8s}\\t{:<8s}\\t{:<8s}\\t{:<8s}\\t{:<8s}\\t{:<8s}\\t{:<8s}\\t{:<8s}'.format(\n", | |
| " 'n', 'loss', 'since', \n", | |
| " 'acc', 'since',\n", | |
| " 'reward', 'since',\n", | |
| " 'fake', 'since',\n", | |
| " ),\n", | |
| " flush=True)\n", | |
| " \n", | |
| " for bno, (images, labels) in enumerate(itertools.chain(*[ train_loader for _ in range(1) ])):\n", | |
| " flat = images.reshape(images.shape[0], -1)\n", | |
| " \n", | |
| " opt.zero_grad()\n", | |
| " logit = pi(flat)\n", | |
| " with torch.no_grad():\n", | |
| " fhat = pi.density(logit)\n", | |
| " sample = sampler.sample(fhat)\n", | |
| " reward = (sample == labels).unsqueeze(1).float()\n", | |
| " pred = logit.argmax(dim=1)\n", | |
| " ispred = (sample == pred).unsqueeze(1).float()\n", | |
| " antipred = logit.argmin(dim=1)\n", | |
| " isantipred = (sample == antipred).unsqueeze(1).float()\n", | |
| " zerossample = torch.randint(low=0, high=zeros.shape[0], size=(fhat.shape[0], 1))\n", | |
| " zerofeedback = torch.gather(input=zeros, index=zerossample.expand(-1, zeros.shape[1]), dim=0)\n", | |
| " onessample = torch.randint(low=0, high=ones.shape[0], size=(fhat.shape[0], 1))\n", | |
| " onefeedback = torch.gather(input=ones, index=onessample.expand(-1, ones.shape[1]), dim=0)\n", | |
| " feedback = zerofeedback + reward * (onefeedback - zerofeedback) \n", | |
| " \n", | |
| " samplelogit = torch.gather(input=logit, index=sample.unsqueeze(1), dim=1)\n", | |
| " fakelogit = decoder(feedback)\n", | |
| " fakereward = decoder.density(fakelogit)\n", | |
| " predloss = torch.mean(log_loss(fakelogit, ispred) + log_loss(samplelogit, fakereward.detach()))\n", | |
| " antipredloss = torch.mean(log_loss(1 - fakelogit, isantipred) + log_loss(1 - samplelogit, fakereward.detach()))\n", | |
| " loss = torch.min(predloss, antipredloss)\n", | |
| "\n", | |
| " loss.backward()\n", | |
| " opt.step()\n", | |
| " \n", | |
| " with torch.no_grad():\n", | |
| " pred = logit.argmax(dim=1)\n", | |
| " acc += torch.mean((labels == pred).float())\n", | |
| " accsincelast += torch.mean((labels == pred).float())\n", | |
| " avloss += loss\n", | |
| " avlosssincelast += loss\n", | |
| " avreward += torch.mean(reward)\n", | |
| " avrewardsincelast += torch.mean(reward)\n", | |
| " avfake += torch.mean(fakereward)\n", | |
| " avfakesincelast += torch.mean(fakereward)\n", | |
| " \n", | |
| " if (bno & (bno - 1) == 0):\n", | |
| " print('{:<5d}\\t{:<8.5f}\\t{:<8.5f}\\t{:<8.5f}\\t{:<8.5f}\\t{:<8.5f}\\t{:<8.5f}\\t{:<8.5f}\\t{:<8.5f}'.format(\n", | |
| " avloss.n, avloss.mean(), avlosssincelast.mean(), \n", | |
| " acc.mean(), accsincelast.mean(), \n", | |
| " avreward.mean(), avrewardsincelast.mean(),\n", | |
| " avfake.mean(), avfakesincelast.mean(),\n", | |
| " ),\n", | |
| " flush=True)\n", | |
| " accsincelast, avlosssincelast, avrewardsincelast, avfakesincelast = [ EasyAcc() for _ in range(4) ]\n", | |
| " \n", | |
| " print('{:<5d}\\t{:<8.5f}\\t{:<8.5f}\\t{:<8.5f}\\t{:<8.5f}\\t{:<8.5f}\\t{:<8.5f}\\t{:<8.5f}\\t{:<8.5f}'.format(\n", | |
| " avloss.n, avloss.mean(), avlosssincelast.mean(), \n", | |
| " acc.mean(), accsincelast.mean(), \n", | |
| " avreward.mean(), avrewardsincelast.mean(),\n", | |
| " avfake.mean(), avfakesincelast.mean(),\n", | |
| " ),\n", | |
| " flush=True)\n", | |
| " accsincelast, avlosssincelast, avrewardsincelast = EasyAcc(), EasyAcc(), EasyAcc()\n", | |
| " testacc = EasyAcc()\n", | |
| " with torch.no_grad():\n", | |
| " for ti, tl in train_loader:\n", | |
| " flat = ti.reshape(ti.shape[0], -1)\n", | |
| " logit = pi(flat)\n", | |
| " testpred = logit.argmax(dim=1)\n", | |
| " testacc += torch.mean((tl == testpred).float())\n", | |
| "\n", | |
| " print(f'testacc {testacc.mean()}')\n", | |
| "\n", | |
| "iglLearn()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "90994504", | |
| "metadata": {}, | |
| "source": [ | |
| "# IGL ($y_a \\perp x|r_a$)\n", | |
| "If $r_a = 1$, $y_a$ is an image of the digit $a$ that was played (e.g., for $a=3$, a \"three\" image); if $r_a = 0$, it is an image of the digit $9-a$ (e.g., for $a=3$, a \"six\" image)." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 288, | |
| "id": "4d4e6631", | |
| "metadata": { | |
| "code_folding": [ | |
| 6, | |
| 16, | |
| 26, | |
| 31, | |
| 55, | |
| 84, | |
| 85, | |
| 108, | |
| 113, | |
| 116, | |
| 117, | |
| 140, | |
| 145, | |
| 148, | |
| 155, | |
| 164, | |
| 172, | |
| 199, | |
| 213, | |
| 233, | |
| 271, | |
| 305, | |
| 314 | |
| ], | |
| "scrolled": false | |
| }, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "n \tloss \tsince \tacc \tsince \treward \tsince \tfake \tsince \n", | |
| "1 \t1.38891 \t1.38891 \t0.28125 \t0.28125 \t0.23438 \t0.23438 \t0.49940 \t0.49940 \n", | |
| "2 \t1.38223 \t1.37554 \t0.29688 \t0.31250 \t0.19531 \t0.15625 \t0.49962 \t0.49983 \n", | |
| "3 \t1.38101 \t1.37858 \t0.31771 \t0.35938 \t0.18750 \t0.17188 \t0.49995 \t0.50062 \n", | |
| "5 \t1.38767 \t1.39765 \t0.29063 \t0.25000 \t0.14687 \t0.08594 \t0.49908 \t0.49777 \n", | |
| "9 \t1.38487 \t1.38138 \t0.19965 \t0.08594 \t0.13368 \t0.11719 \t0.50189 \t0.50540 \n", | |
| "17 \t1.35851 \t1.32885 \t0.18199 \t0.16211 \t0.13787 \t0.14258 \t0.50388 \t0.50613 \n", | |
| "33 \t1.34694 \t1.33465 \t0.20028 \t0.21973 \t0.14725 \t0.15723 \t0.50259 \t0.50121 \n", | |
| "65 \t1.33522 \t1.32314 \t0.27428 \t0.35059 \t0.19447 \t0.24316 \t0.50141 \t0.50020 \n", | |
| "129 \t1.26931 \t1.20236 \t0.35913 \t0.44531 \t0.26211 \t0.33081 \t0.50427 \t0.50716 \n", | |
| "257 \t1.14736 \t1.02446 \t0.52420 \t0.69055 \t0.40096 \t0.54089 \t0.53432 \t0.56461 \n", | |
| "513 \t0.99472 \t0.84148 \t0.65068 \t0.77765 \t0.52537 \t0.65027 \t0.57622 \t0.61829 \n", | |
| "938 \t0.86970 \t0.71879 \t0.73391 \t0.83438 \t0.61042 \t0.71309 \t0.61330 \t0.65806 \n", | |
| "testacc 0.8389525413513184\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "def iglADepLearn():\n", | |
| " import itertools\n", | |
| " import numpy\n", | |
| " import torch\n", | |
| " import torchvision\n", | |
| " \n", | |
| " class WeightedReservoir(object):\n", | |
| " def __init__(self, n, seed):\n", | |
| " import random\n", | |
| " \n", | |
| " super().__init__()\n", | |
| " self.n = n\n", | |
| " self.items = []\n", | |
| " self.wsum = 0\n", | |
| " self.gen = random.Random(seed) \n", | |
| " \n", | |
| " def insert(self, item, weight):\n", | |
| " if weight > 0:\n", | |
| " self.wsum += weight\n", | |
| " if self.wsum * self.gen.random() < weight:\n", | |
| " if len(self.items) < self.n:\n", | |
| " self.items.append(item)\n", | |
| " else:\n", | |
| " index = self.gen.randrange(0, self.n) \n", | |
| " self.items[index] = item\n", | |
| " \n", | |
| " def sample(self):\n", | |
| " assert len(self.items) > 0\n", | |
| " index = self.gen.randrange(0, len(self.items))\n", | |
| " return self.items[index]\n", | |
| " \n", | |
| " class SquareCB(object):\n", | |
| " def __init__(self, gamma):\n", | |
| " super().__init__()\n", | |
| "\n", | |
| " self.gamma = gamma\n", | |
| "\n", | |
| " def sample(self, fhat, *, keepdim=False):\n", | |
| " N, K = fhat.shape\n", | |
| " fhatstar, ahatstar = torch.max(fhat, dim=1, keepdim=True)\n", | |
| " probs = 1 / (K + self.gamma * (fhatstar - fhat))\n", | |
| " psum = torch.sum(probs, dim=1, keepdim=True)\n", | |
| " phatstar = psum + torch.gather(input=probs, dim=1, index=ahatstar)\n", | |
| "\n", | |
| " rando = torch.randint(high=K, size=(N, 1), device=fhat.device)\n", | |
| " prando = torch.gather(input=probs, dim=1, index=rando)\n", | |
| " unif = torch.rand(size=(N, 1), device=fhat.device)\n", | |
| " shouldexplore = (unif <= K * prando).long()\n", | |
| " actions = ahatstar + shouldexplore * (rando - ahatstar)\n", | |
| " pactions = phatstar + shouldexplore * (prando - phatstar)\n", | |
| " if not keepdim:\n", | |
| " actions = actions.squeeze(1)\n", | |
| " pactions = pactions.squeeze(1)\n", | |
| " return actions, pactions\n", | |
| " \n", | |
| " class EasyAcc:\n", | |
| " def __init__(self):\n", | |
| " self.n = 0\n", | |
| " self.sum = 0\n", | |
| " self.sumsq = 0\n", | |
| "\n", | |
| " def __iadd__(self, other):\n", | |
| " self.n += 1\n", | |
| " self.sum += other\n", | |
| " self.sumsq += other*other\n", | |
| " return self\n", | |
| "\n", | |
| " def __isub__(self, other):\n", | |
| " self.n += 1\n", | |
| " self.sum -= other\n", | |
| " self.sumsq += other*other\n", | |
| " return self\n", | |
| "\n", | |
| " def mean(self):\n", | |
| " return self.sum / max(self.n, 1)\n", | |
| "\n", | |
| " def var(self):\n", | |
| " from math import sqrt\n", | |
| " return sqrt(self.sumsq / max(self.n, 1) - self.mean()**2)\n", | |
| "\n", | |
| " def semean(self):\n", | |
| " from math import sqrt\n", | |
| " return self.var() / sqrt(max(self.n, 1))\n", | |
| "\n", | |
| "    class RFFBilinearSoftmax(torch.nn.Module):\n", | |
| "        # Scores (action, observation) pairs: fixed random Fourier features of the observation y\n", | |
| "        # are combined with the action encoding a by a trainable Bilinear layer into one logit.\n", | |
| "        # hilo: 2-D array whose row 0 / row 1 hold per-feature low / high values used to rescale y;\n", | |
| "        # naction: length of the action encoding; numrff: number of random features;\n", | |
| "        # sigma: scale of the Cauchy-distributed projection weights; seed: torch RNG seed.\n", | |
| "        def __init__(self, hilo, naction, numrff, sigma, seed):\n", | |
| "            from math import pi\n", | |
| "            import numpy as np\n", | |
| "\n", | |
| "            super().__init__()\n", | |
| "\n", | |
| "            # NOTE: reseeds the global torch RNG, so later random draws in the caller are affected too\n", | |
| "            torch.manual_seed(seed)\n", | |
| "            nobs = hilo.shape[1]\n", | |
| "            high = hilo[1, :]\n", | |
| "            low = hilo[0, :]\n", | |
| "        \n", | |
| "            # frozen random projection: Cauchy weights with column j divided by (high - low)[j];\n", | |
| "            # features whose range is <= 1e-6 get zero weight instead of a huge one\n", | |
| "            self.rff = torch.nn.Linear(nobs, numrff)\n", | |
| "            self.rff.weight.data = torch.matmul(torch.empty(numrff, nobs).cauchy_(sigma = sigma), \n", | |
| "                                                torch.diag(torch.tensor([ 1.0/v if v > 1e-6 else 0. for v in high - low ])).float())\n", | |
| "            self.rff.weight.requires_grad = False\n", | |
| "            # random phases drawn uniformly from [0, 2*pi)\n", | |
| "            self.rff.bias.data = 2 * pi * torch.rand(numrff)\n", | |
| "            self.rff.bias.requires_grad = False\n", | |
| "            self.sqrtrff = np.sqrt(numrff)\n", | |
| "            self.final = torch.nn.Bilinear(naction, numrff, 1)\n", | |
| "            # scale the default init down 100x so initial logits are small\n", | |
| "            self.final.weight.data *= 0.01\n", | |
| "            self.final.bias.data *= 0.01\n", | |
| "            self.sigmoid = torch.nn.Sigmoid()\n", | |
| "\n", | |
| "        def forward(self, a, y):\n", | |
| "            # a: action encoding with last dim naction (presumably one-hot); y: observations with\n", | |
| "            # last dim nobs. Returns logits with last dim 1. Only self.final receives gradients.\n", | |
| "            with torch.no_grad():\n", | |
| "                rff = self.rff(y).cos() / self.sqrtrff\n", | |
| "            return self.final(a, rff)\n", | |
| "    \n", | |
| "        def density(self, logits):\n", | |
| "            # converts logits from forward() into probabilities in (0, 1)\n", | |
| "            return self.sigmoid(logits)\n", | |
| "\n", | |
| "    class RFFSoftmax(torch.nn.Module):\n", | |
| "        # Policy model: fixed random Fourier features of x feed a trainable Linear layer that\n", | |
| "        # emits one logit per action.\n", | |
| "        # hilo: 2-D array whose row 0 / row 1 hold per-feature low / high values used to rescale x;\n", | |
| "        # naction: number of actions (output logits); numrff: number of random features;\n", | |
| "        # sigma: scale of the Cauchy-distributed projection weights; seed: torch RNG seed.\n", | |
| "        def __init__(self, hilo, naction, numrff, sigma, seed):\n", | |
| "            from math import pi\n", | |
| "            import numpy as np\n", | |
| "\n", | |
| "            super().__init__()\n", | |
| "\n", | |
| "            # NOTE: reseeds the global torch RNG, so later random draws in the caller are affected too\n", | |
| "            torch.manual_seed(seed)\n", | |
| "            nobs = hilo.shape[1]\n", | |
| "            high = hilo[1, :]\n", | |
| "            low = hilo[0, :]\n", | |
| "        \n", | |
| "            # frozen random projection: Cauchy weights with column j divided by (high - low)[j];\n", | |
| "            # features whose range is <= 1e-6 get zero weight instead of a huge one\n", | |
| "            self.rff = torch.nn.Linear(nobs, numrff)\n", | |
| "            self.rff.weight.data = torch.matmul(torch.empty(numrff, nobs).cauchy_(sigma = sigma), \n", | |
| "                                                torch.diag(torch.tensor([ 1.0/v if v > 1e-6 else 0. for v in high - low ])).float())\n", | |
| "            self.rff.weight.requires_grad = False\n", | |
| "            # random phases drawn uniformly from [0, 2*pi)\n", | |
| "            self.rff.bias.data = 2 * pi * torch.rand(numrff)\n", | |
| "            self.rff.bias.requires_grad = False\n", | |
| "            self.sqrtrff = np.sqrt(numrff)\n", | |
| "            self.final = torch.nn.Linear(numrff, naction)\n", | |
| "            # scale the default init down 100x so initial logits are small\n", | |
| "            self.final.weight.data *= 0.01\n", | |
| "            self.final.bias.data *= 0.01\n", | |
| "            self.sigmoid = torch.nn.Sigmoid()\n", | |
| "\n", | |
| "        def forward(self, x):\n", | |
| "            # x: inputs with last dim nobs; returns logits with last dim naction.\n", | |
| "            # Only self.final receives gradients; the feature map runs under no_grad.\n", | |
| "            with torch.no_grad():\n", | |
| "                rff = self.rff(x).cos() / self.sqrtrff\n", | |
| "            return self.final(rff)\n", | |
| "    \n", | |
| "        def preq1(self, logits):\n", | |
| "            # elementwise sigmoid of the per-action logits; presumably an estimate of\n", | |
| "            # P(reward = 1) for each action (name-based guess, TODO confirm)\n", | |
| "            return self.sigmoid(logits)\n", | |
| "\n", | |
| " transform = torchvision.transforms.Compose([\n", | |
| " torchvision.transforms.ToTensor(),\n", | |
| " torchvision.transforms.Normalize((0.1307,), (0.3081,))\n", | |
| " ])\n", | |
| " mnist_train = torchvision.datasets.MNIST('/tmp/mnist', train=True, download=True, transform=transform)\n", | |
| " \n", | |
| " quantile_loader = torch.utils.data.DataLoader(mnist_train, batch_size=10000, shuffle=True)\n", | |
| " for bno, (images, labels) in enumerate(quantile_loader):\n", | |
| " flat = images.reshape(images.shape[0], -1)\n", | |
| " hilo = numpy.quantile(flat.numpy(), [ 0.01, 0.99 ], axis=0)\n", | |
| " pi = RFFSoftmax(hilo, 10, 2000, 0.01, 45)\n", | |
| " decoder = RFFBilinearSoftmax(hilo, 10, 2000, 0.01, 2112)\n", | |
| " break\n", | |
| " \n", | |
| " feedback_loader = torch.utils.data.DataLoader(mnist_train, batch_size=1, shuffle=True)\n", | |
| " feedbacks = [ [] for _ in range(10) ]\n", | |
| " for bno, (images, labels) in enumerate(feedback_loader):\n", | |
| " flat = images.reshape(images.shape[0], -1)\n", | |
| " feedbacks[labels[0]].append(flat)\n", | |
| " if all(len(x) > 100 for x in feedbacks):\n", | |
| " break \n", | |
| " feedbacks = torch.cat([ torch.cat(x[:100], dim=0).unsqueeze(0) for x in feedbacks ], dim=0)\n", | |
| " \n", | |
| " # pre-train to get policy \"better than random\"\n", | |
| " if True:\n", | |
| " preopt = torch.optim.Adam(( p for p in pi.parameters() if p.requires_grad ), lr=1e-2) # 0.1\n", | |
| " preloss = torch.nn.CrossEntropyLoss()\n", | |
| " pretrain_loader = torch.utils.data.DataLoader(mnist_train, batch_size=64, shuffle=True)\n", | |
| " for bno, (images, labels) in enumerate(itertools.chain(*[ pretrain_loader for _ in range(1) ])):\n", | |
| " flat = images.reshape(images.shape[0], -1)\n", | |
| "\n", | |
| " preopt.zero_grad()\n", | |
| " ld = pi.forward(flat)\n", | |
| " output = preloss(ld, labels)\n", | |
| " output.backward()\n", | |
| " preopt.step()\n", | |
| "\n", | |
| " if bno > 0:\n", | |
| " break\n", | |
| " \n", | |
| " train_loader = torch.utils.data.DataLoader(mnist_train, batch_size=64, shuffle=True)\n", | |
| " mnist_test = torchvision.datasets.MNIST('/tmp/mnist', train=False, download=True, transform=transform)\n", | |
| " test_loader = torch.utils.data.DataLoader(mnist_test, batch_size=1000, shuffle=True)\n", | |
| " \n", | |
| "    opt = torch.optim.Adam(( p for p in itertools.chain(pi.parameters(), decoder.parameters()) if p.requires_grad ), lr=1e-2)\n", | |
| "    # per-element BCE on logits; the training loop combines terms elementwise and averages them.\n", | |
| "    # ('reduce' is the deprecated legacy flag: any truthy value, even 'none', meant mean-reduction)\n", | |
| "    log_loss = torch.nn.BCEWithLogitsLoss(reduction='none')\n", | |
| "    sampler = SquareCB(gamma=100)\n", | |
| "    acc, accsincelast, avloss, avlosssincelast = [ EasyAcc() for _ in range(4) ]\n", | |
| "    avreward, avrewardsincelast, avfake, avfakesincelast = [ EasyAcc() for _ in range(4) ]\n", | |
| "    # one reservoir of (feedback, reward) per action, indexed by the sampled action\n", | |
| "    reservoirs = [ WeightedReservoir(20, 1973+a) for a in range(10) ]\n", | |
| " \n", | |
| " print('{:<5s}\\t{:<8s}\\t{:<8s}\\t{:<8s}\\t{:<8s}\\t{:<8s}\\t{:<8s}\\t{:<8s}\\t{:<8s}'.format(\n", | |
| " 'n', 'loss', 'since', \n", | |
| " 'acc', 'since',\n", | |
| " 'reward', 'since',\n", | |
| " 'fake', 'since',\n", | |
| " ),\n", | |
| " flush=True)\n", | |
| " \n", | |
| " for bno, (images, labels) in enumerate(itertools.chain(*[ train_loader for _ in range(1) ])):\n", | |
| " flatimage = images.reshape(images.shape[0], -1)\n", | |
| " \n", | |
| " opt.zero_grad()\n", | |
| " logit = pi(flatimage)\n", | |
| " \n", | |
| " with torch.no_grad():\n", | |
| " fhat = pi.preq1(logit)\n", | |
| " sample, probs = sampler.sample(fhat, keepdim=True)\n", | |
| " \n", | |
| " reward = (sample == labels.unsqueeze(1)).float()\n", | |
| " pred = logit.argmax(dim=1, keepdim=True)\n", | |
| " ispred = (sample == pred).float()\n", | |
| " antipred = logit.argmin(dim=1, keepdim=True)\n", | |
| " isantipred = (sample == antipred).float()\n", | |
| " \n", | |
| " # this assumes a particular majorization (Torch tensors are row-major)\n", | |
| " bigfeedbacks = feedbacks.unsqueeze(0).expand(fhat.shape[0], -1, -1, -1).reshape(fhat.shape[0], -1, flatimage.shape[1]) # Batch x (A x Rep) x Pixels\n", | |
| " nreps = feedbacks.shape[1]\n", | |
| " goodwhich = feedbacks.shape[1] * sample.squeeze(1) + torch.randint(low=0, high=feedbacks.shape[1], size=(fhat.shape[0],))\n", | |
| " goodwhich = goodwhich.unsqueeze(1).unsqueeze(2).expand(-1, -1, flatimage.shape[1])\n", | |
| " goodfeedbacks = torch.gather(input=bigfeedbacks, index=goodwhich, dim=1).squeeze(1)\n", | |
| " badwhich = feedbacks.shape[1] * (9-sample).squeeze(1) + torch.randint(low=0, high=feedbacks.shape[1], size=(fhat.shape[0],))\n", | |
| " badwhich = badwhich.unsqueeze(1).unsqueeze(2).expand(-1, -1, flatimage.shape[1])\n", | |
| " badfeedbacks = torch.gather(input=bigfeedbacks, index=badwhich, dim=1).squeeze(1)\n", | |
| " \n", | |
| " if False:\n", | |
| " import matplotlib.pyplot as plt\n", | |
| "\n", | |
| " fig, axs = plt.subplots(1, 10)\n", | |
| " for n, (s, f) in enumerate(zip(sample, goodfeedbacks)):\n", | |
| " if n > 9:\n", | |
| " break\n", | |
| " axs[n].imshow(f.reshape(28, 28))\n", | |
| " axs[n].set_title(f'{s.item()}')\n", | |
| " \n", | |
| " plt.show()\n", | |
| " \n", | |
| " fig, axs = plt.subplots(1, 10)\n", | |
| " for n, (s, f) in enumerate(zip(sample, badfeedbacks)):\n", | |
| " if n > 9:\n", | |
| " break\n", | |
| " axs[n].imshow(f.reshape(28, 28))\n", | |
| " axs[n].set_title(f'{s.item()}')\n", | |
| " \n", | |
| " plt.show()\n", | |
| " assert False\n", | |
| " \n", | |
| " feedback = badfeedbacks + reward * (goodfeedbacks - badfeedbacks)\n", | |
| " onehotsample = torch.nn.functional.one_hot(sample.squeeze(1), num_classes=fhat.shape[1]).float()\n", | |
| " \n", | |
| " # insert then sample ... means the first time we play an action there will be no update, that's ok\n", | |
| " for s, p, r, f in zip(sample, probs, reward, feedback):\n", | |
| " reservoirs[s.item()].insert((f, r), 1/p)\n", | |
| " \n", | |
| " compfeedback = []\n", | |
| " compreward = []\n", | |
| " for s in sample:\n", | |
| " f, r = reservoirs[s.item()].sample()\n", | |
| " compfeedback.append(f.unsqueeze(0))\n", | |
| " compreward.append(r.unsqueeze(0))\n", | |
| " compfeedback = torch.cat(compfeedback, dim=0)\n", | |
| " compreward = torch.cat(compreward, dim=0)\n", | |
| " \n", | |
| " if False:\n", | |
| " import matplotlib.pyplot as plt\n", | |
| "\n", | |
| " fig, axs = plt.subplots(1, 10)\n", | |
| " for n, (s, f, r) in enumerate(zip(sample, compfeedback, compreward)):\n", | |
| " if n > 9:\n", | |
| " break\n", | |
| " axs[n].imshow(f.reshape(28, 28))\n", | |
| " axs[n].set_title(f'{s.item()} {r.long().item()}')\n", | |
| " \n", | |
| " plt.show()\n", | |
| " assert False\n", | |
| "\n", | |
| "        samplelogit = torch.gather(input=logit, index=sample, dim=1)\n", | |
| "        fakelogit = decoder(onehotsample, feedback)\n", | |
| "        fakereward = decoder.density(fakelogit)\n", | |
| "        fakecomplogit = decoder(onehotsample, compfeedback)\n", | |
| "        # two polarity hypotheses for the decoded reward; train on whichever loss is smaller.\n", | |
| "        predloss = torch.mean(log_loss(fakelogit - fakecomplogit, ispred) + log_loss(samplelogit, fakereward.detach()))\n", | |
| "        # flipped polarity: the policy should predict 1 - fakereward, i.e. sigmoid(-samplelogit).\n", | |
| "        # negate the logit; 1 - logit is not the complementary probability in logit space.\n", | |
| "        antipredloss = torch.mean(log_loss(fakecomplogit - fakelogit, isantipred) + log_loss(-samplelogit, fakereward.detach()))\n", | |
| "        loss = torch.min(predloss, antipredloss)\n", | |
| "        loss.backward()\n", | |
| "        opt.step()\n", | |
| " \n", | |
| " with torch.no_grad():\n", | |
| " acc += torch.mean((labels.unsqueeze(1) == pred).float())\n", | |
| " accsincelast += torch.mean((labels.unsqueeze(1) == pred).float())\n", | |
| " avloss += loss\n", | |
| " avlosssincelast += loss\n", | |
| " avreward += torch.mean(reward)\n", | |
| " avrewardsincelast += torch.mean(reward)\n", | |
| " avfake += torch.mean(fakereward)\n", | |
| " avfakesincelast += torch.mean(fakereward)\n", | |
| " \n", | |
| " if (bno & (bno - 1) == 0):\n", | |
| " print('{:<5d}\\t{:<8.5f}\\t{:<8.5f}\\t{:<8.5f}\\t{:<8.5f}\\t{:<8.5f}\\t{:<8.5f}\\t{:<8.5f}\\t{:<8.5f}'.format(\n", | |
| " avloss.n, avloss.mean(), avlosssincelast.mean(), \n", | |
| " acc.mean(), accsincelast.mean(), \n", | |
| " avreward.mean(), avrewardsincelast.mean(),\n", | |
| " avfake.mean(), avfakesincelast.mean(),\n", | |
| " ),\n", | |
| " flush=True)\n", | |
| " accsincelast, avlosssincelast, avrewardsincelast, avfakesincelast = [ EasyAcc() for _ in range(4) ]\n", | |
| " \n", | |
| " print('{:<5d}\\t{:<8.5f}\\t{:<8.5f}\\t{:<8.5f}\\t{:<8.5f}\\t{:<8.5f}\\t{:<8.5f}\\t{:<8.5f}\\t{:<8.5f}'.format(\n", | |
| " avloss.n, avloss.mean(), avlosssincelast.mean(), \n", | |
| " acc.mean(), accsincelast.mean(), \n", | |
| " avreward.mean(), avrewardsincelast.mean(),\n", | |
| " avfake.mean(), avfakesincelast.mean(),\n", | |
| " ),\n", | |
| " flush=True)\n", | |
| "    # held-out evaluation of the greedy policy on the MNIST test split\n", | |
| "    testacc = EasyAcc()\n", | |
| "    with torch.no_grad():\n", | |
| "        for ti, tl in test_loader:\n", | |
| "            flat = ti.reshape(ti.shape[0], -1)\n", | |
| "            logit = pi(flat)\n", | |
| "            testpred = logit.argmax(dim=1)\n", | |
| "            testacc += torch.mean((tl == testpred).float())\n", | |
| "\n", | |
| "    print(f'testacc {testacc.mean()}')\n", | |
| "\n", | |
| "iglADepLearn()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "id": "ef8c19c9", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [] | |
| } | |
| ], | |
| "metadata": { | |
| "kernelspec": { | |
| "display_name": "Python 3", | |
| "language": "python", | |
| "name": "python3" | |
| }, | |
| "language_info": { | |
| "codemirror_mode": { | |
| "name": "ipython", | |
| "version": 3 | |
| }, | |
| "file_extension": ".py", | |
| "mimetype": "text/x-python", | |
| "name": "python", | |
| "nbconvert_exporter": "python", | |
| "pygments_lexer": "ipython3", | |
| "version": "3.9.7" | |
| } | |
| }, | |
| "nbformat": 4, | |
| "nbformat_minor": 5 | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment