Last active
March 5, 2022 21:55
-
-
Save pmineiro/902b40b3054a77a1e85af6d5ffd469fe to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| { | |
| "cells": [ | |
| { | |
| "cell_type": "markdown", | |
| "id": "f90e9f40", | |
| "metadata": {}, | |
| "source": [ | |
| "# Util" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 1, | |
| "id": "393278c2", | |
| "metadata": { | |
| "code_folding": [ | |
| 2, | |
| 18, | |
| 26, | |
| 47, | |
| 48, | |
| 54, | |
| 69, | |
| 82, | |
| 91, | |
| 100, | |
| 104, | |
| 117, | |
| 156, | |
| 157, | |
| 163, | |
| 174, | |
| 187, | |
| 195, | |
| 196, | |
| 208, | |
| 228, | |
| 261, | |
| 267, | |
| 268, | |
| 274, | |
| 285, | |
| 298, | |
| 306, | |
| 307, | |
| 319, | |
| 340, | |
| 373, | |
| 379, | |
| 395, | |
| 401, | |
| 411, | |
| 412, | |
| 444, | |
| 447, | |
| 465, | |
| 468, | |
| 475, | |
| 485, | |
| 512, | |
| 515, | |
| 518, | |
| 539, | |
| 549, | |
| 559, | |
| 591, | |
| 594, | |
| 602, | |
| 603, | |
| 609, | |
| 627, | |
| 630, | |
| 631, | |
| 638, | |
| 662, | |
| 665 | |
| ] | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "from abc import ABC, abstractmethod\n", | |
| "\n", | |
| "class Batch(object):\n", | |
| " def __init__(self):\n", | |
| " super().__init__()\n", | |
| "\n", | |
| " @abstractmethod\n", | |
| " def getContext(self):\n", | |
| " pass\n", | |
| "\n", | |
| " @abstractmethod\n", | |
| " def getFeedback(self, action):\n", | |
| " pass\n", | |
| " \n", | |
| " @abstractmethod\n", | |
| " def getReward(self, action):\n", | |
| " pass\n", | |
| "\n", | |
| "class Simulator(object):\n", | |
| " def __init__(self):\n", | |
| " super().__init__()\n", | |
| " \n", | |
| " @abstractmethod\n", | |
| " def trainIterator(self):\n", | |
| " pass\n", | |
| " \n", | |
| "class MnistSimulator(Simulator):\n", | |
| " def __init__(self, batch_size):\n", | |
| " import torchvision\n", | |
| "\n", | |
| " super().__init__()\n", | |
| " self.batch_size = batch_size\n", | |
| " \n", | |
| " transform = torchvision.transforms.Compose([\n", | |
| " torchvision.transforms.ToTensor(),\n", | |
| " torchvision.transforms.Normalize((0.1307,), (0.3081,))\n", | |
| " ])\n", | |
| " self.mnist_train = torchvision.datasets.MNIST('/tmp/mnist', train=True, download=True, transform=transform)\n", | |
| " \n", | |
| " @abstractmethod\n", | |
| " def computeHilo(self):\n", | |
| " pass\n", | |
| " \n", | |
| " @abstractmethod\n", | |
| " def trainIterator(self):\n", | |
| " pass\n", | |
| " \n", | |
| "class MnistFullCI(MnistSimulator):\n", | |
| " def __init__(self, *, batch_size, decodability):\n", | |
| " super().__init__(batch_size)\n", | |
| " self.decodability = decodability\n", | |
| " \n", | |
| " self._makeFeedbacks()\n", | |
| " \n", | |
| " def _makeFeedbacks(self):\n", | |
| " import torch\n", | |
| " zero_one_loader = torch.utils.data.DataLoader(self.mnist_train, batch_size=1, shuffle=True)\n", | |
| " zeros, ones = [], []\n", | |
| " for bno, (images, labels) in enumerate(zero_one_loader):\n", | |
| " flat = images.reshape(images.shape[0], -1)\n", | |
| " if labels[0] == 0:\n", | |
| " zeros.append(flat)\n", | |
| " elif labels[0] == 1:\n", | |
| " ones.append(flat)\n", | |
| "\n", | |
| " if len(zeros) > 100 and len(ones) > 100:\n", | |
| " break \n", | |
| " self.zeros, self.ones = torch.cat(zeros[:100], dim=0), torch.cat(ones[:100], dim=0)\n", | |
| "\n", | |
| " def computeHilo(self):\n", | |
| " import numpy\n", | |
| " import torch\n", | |
| " \n", | |
| " with torch.no_grad():\n", | |
| " quantile_loader = torch.utils.data.DataLoader(self.mnist_train, batch_size=10000, shuffle=True)\n", | |
| " for bno, (images, labels) in enumerate(quantile_loader):\n", | |
| " flat = images.view(images.shape[0], -1)\n", | |
| " hilo = numpy.quantile(flat.numpy(), [ 0.01, 0.99 ], axis=0)\n", | |
| " break\n", | |
| " \n", | |
| " return hilo, hilo\n", | |
| " \n", | |
| " def trainIterator(self):\n", | |
| " import torch\n", | |
| " \n", | |
| " train_loader = torch.utils.data.DataLoader(self.mnist_train, batch_size=self.batch_size, shuffle=True)\n", | |
| "\n", | |
| " for images, labels in train_loader:\n", | |
| " yield self.MyBatch(images, labels, self.zeros, self.ones, self.decodability)\n", | |
| " \n", | |
| " class MyBatch(Batch):\n", | |
| " def __init__(self, images, labels, zeros, ones, decodability):\n", | |
| " super().__init__()\n", | |
| "\n", | |
| " self.images = images\n", | |
| " self.labels = labels\n", | |
| " self.zeros = zeros\n", | |
| " self.ones = ones\n", | |
| " self.decodability = decodability\n", | |
| " \n", | |
| " def getContext(self):\n", | |
| " return self.images.view(self.images.shape[0], -1)\n", | |
| " \n", | |
| " # feedback is an image of 1 if correct else 0\n", | |
| " def getFeedback(self, action):\n", | |
| " import torch\n", | |
| " with torch.no_grad():\n", | |
| " reward = (action == self.labels.unsqueeze(1)).float()\n", | |
| " zerossample = torch.randint(low=0, high=self.zeros.shape[0], size=(action.shape[0], 1))\n", | |
| " goodfeedbacks = torch.gather(input=self.zeros, index=zerossample.expand(-1, self.zeros.shape[1]), dim=0)\n", | |
| " onessample = torch.randint(low=0, high=self.ones.shape[0], size=(action.shape[0], 1))\n", | |
| " badfeedbacks = torch.gather(input=self.ones, index=onessample.expand(-1, self.ones.shape[1]), dim=0)\n", | |
| " noise = torch.rand(size=(action.shape[0], 1), device=action.device)\n", | |
| " shouldflip = (noise <= ((1 + self.decodability)/2)).long()\n", | |
| " noisyreward = reward + shouldflip * (1 - 2 * reward)\n", | |
| " feedback = badfeedbacks + noisyreward * (goodfeedbacks - badfeedbacks)\n", | |
| " \n", | |
| " if False:\n", | |
| " import matplotlib.pyplot as plt\n", | |
| "\n", | |
| " fig, axs = plt.subplots(1, 10)\n", | |
| " for n, (s, f) in enumerate(zip(self.labels, goodfeedbacks)):\n", | |
| " if n > 9:\n", | |
| " break\n", | |
| " axs[n].imshow(f.reshape(28, 28))\n", | |
| " axs[n].set_title(f'{s.item()}')\n", | |
| "\n", | |
| " plt.show()\n", | |
| "\n", | |
| " fig, axs = plt.subplots(1, 10)\n", | |
| " for n, (s, f) in enumerate(zip(self.labels, badfeedbacks)):\n", | |
| " if n > 9:\n", | |
| " break\n", | |
| " axs[n].imshow(f.reshape(28, 28))\n", | |
| " axs[n].set_title(f'{s.item()}')\n", | |
| "\n", | |
| " plt.show()\n", | |
| " \n", | |
| " fig, axs = plt.subplots(1, 10)\n", | |
| " for n, (a, r, nr, f) in enumerate(zip(action, reward, noisyreward, feedback)):\n", | |
| " if n > 9:\n", | |
| " break\n", | |
| " axs[n].imshow(f.reshape(28, 28))\n", | |
| " axs[n].set_title(f'{a.item()} {r.long().item()} {nr.long().item()}')\n", | |
| "\n", | |
| " plt.show()\n", | |
| " assert False\n", | |
| " \n", | |
| " return feedback\n", | |
| " \n", | |
| " def getReward(self, action):\n", | |
| " import torch\n", | |
| " with torch.no_grad():\n", | |
| " reward = (action == self.labels.unsqueeze(1)).float()\n", | |
| " return torch.mean(reward)\n", | |
| "\n", | |
| "class MnistActionCI(MnistSimulator):\n", | |
| " def __init__(self, *, batch_size, decodability):\n", | |
| " super().__init__(batch_size)\n", | |
| " self.decodability = decodability\n", | |
| " \n", | |
| " self._makeFeedbacks()\n", | |
| " \n", | |
| " def _makeFeedbacks(self):\n", | |
| " import torch\n", | |
| " feedback_loader = torch.utils.data.DataLoader(self.mnist_train, batch_size=1, shuffle=True)\n", | |
| " feedbacks = [ [] for _ in range(10) ]\n", | |
| " for bno, (images, labels) in enumerate(feedback_loader):\n", | |
| " flat = images.reshape(images.shape[0], -1)\n", | |
| " feedbacks[labels[0]].append(flat)\n", | |
| " if all(len(x) > 100 for x in feedbacks):\n", | |
| " break \n", | |
| " self.feedbacks = torch.cat([ torch.cat(x[:100], dim=0).unsqueeze(0) for x in feedbacks ], dim=0)\n", | |
| "\n", | |
| " def computeHilo(self):\n", | |
| " import numpy\n", | |
| " import torch\n", | |
| " \n", | |
| " with torch.no_grad():\n", | |
| " quantile_loader = torch.utils.data.DataLoader(self.mnist_train, batch_size=10000, shuffle=True)\n", | |
| " for bno, (images, labels) in enumerate(quantile_loader):\n", | |
| " flat = images.view(images.shape[0], -1)\n", | |
| " hilo = numpy.quantile(flat.numpy(), [ 0.01, 0.99 ], axis=0)\n", | |
| " break\n", | |
| " \n", | |
| " return hilo, hilo\n", | |
| " \n", | |
| " def trainIterator(self):\n", | |
| " import torch\n", | |
| " \n", | |
| " train_loader = torch.utils.data.DataLoader(self.mnist_train, batch_size=self.batch_size, shuffle=True)\n", | |
| "\n", | |
| " for images, labels in train_loader:\n", | |
| " yield self.MyBatch(images, labels, self.feedbacks, self.decodability)\n", | |
| " \n", | |
| " class MyBatch(Batch):\n", | |
| " def __init__(self, images, labels, feedbacks, decodability):\n", | |
| " super().__init__()\n", | |
| "\n", | |
| " self.images = images\n", | |
| " self.labels = labels\n", | |
| " self.feedbacks = feedbacks\n", | |
| " self.decodability = decodability\n", | |
| " \n", | |
| " def getContext(self):\n", | |
| " return self.images.view(self.images.shape[0], -1)\n", | |
| " \n", | |
| " # feedback is an image of (x + 1) % 10 if correct else (x - 1) % 10\n", | |
| " def getFeedback(self, action):\n", | |
| " import torch\n", | |
| " with torch.no_grad():\n", | |
| " reward = (action == self.labels.unsqueeze(1)).float()\n", | |
| " pixels = self.getContext().shape[1]\n", | |
| " \n", | |
| " # this assumes a particular majorization (Torch tensors are row-major)\n", | |
| " bigfeedbacks = self.feedbacks.unsqueeze(0).expand(action.shape[0], -1, -1, -1).reshape(action.shape[0], -1, pixels) # Batch x (A x Rep) x Pixels\n", | |
| " goodwhich = self.feedbacks.shape[1] * torch.remainder(self.labels + 1, 10) + torch.randint(low=0, high=self.feedbacks.shape[1], size=(action.shape[0],))\n", | |
| " goodwhich = goodwhich.unsqueeze(1).unsqueeze(2).expand(-1, -1, pixels)\n", | |
| " goodfeedbacks = torch.gather(input=bigfeedbacks, index=goodwhich, dim=1).squeeze(1)\n", | |
| " badwhich = self.feedbacks.shape[1] * torch.remainder(self.labels - 1, 10) + torch.randint(low=0, high=self.feedbacks.shape[1], size=(action.shape[0],))\n", | |
| " badwhich = badwhich.unsqueeze(1).unsqueeze(2).expand(-1, -1, pixels)\n", | |
| " badfeedbacks = torch.gather(input=bigfeedbacks, index=badwhich, dim=1).squeeze(1)\n", | |
| " \n", | |
| " noise = torch.rand(size=(action.shape[0], 1), device=action.device)\n", | |
| " shouldflip = (noise <= ((1 + self.decodability)/2)).long()\n", | |
| " noisyreward = reward + shouldflip * (1 - 2 * reward)\n", | |
| " feedback = badfeedbacks + noisyreward * (goodfeedbacks - badfeedbacks)\n", | |
| " \n", | |
| " if False:\n", | |
| " import matplotlib.pyplot as plt\n", | |
| " \n", | |
| " fig, axs = plt.subplots(1, 10)\n", | |
| " for n, (s, f) in enumerate(zip(self.labels, goodfeedbacks)):\n", | |
| " if n > 9:\n", | |
| " break\n", | |
| " axs[n].imshow(f.reshape(28, 28))\n", | |
| " axs[n].set_title(f'{s.item()}')\n", | |
| "\n", | |
| " plt.show()\n", | |
| "\n", | |
| " fig, axs = plt.subplots(1, 10)\n", | |
| " for n, (s, f) in enumerate(zip(self.labels, badfeedbacks)):\n", | |
| " if n > 9:\n", | |
| " break\n", | |
| " axs[n].imshow(f.reshape(28, 28))\n", | |
| " axs[n].set_title(f'{s.item()}')\n", | |
| "\n", | |
| " plt.show()\n", | |
| " \n", | |
| " fig, axs = plt.subplots(1, 10)\n", | |
| " for n, (a, r, nr, f) in enumerate(zip(action, reward, noisyreward, feedback)):\n", | |
| " if n > 9:\n", | |
| " break\n", | |
| " axs[n].imshow(f.reshape(28, 28))\n", | |
| " axs[n].set_title(f'{a.item()} {r.long().item()} {nr.long().item()}')\n", | |
| "\n", | |
| " plt.show()\n", | |
| " assert False\n", | |
| " \n", | |
| " return feedback\n", | |
| " \n", | |
| " def getReward(self, action):\n", | |
| " import torch\n", | |
| " with torch.no_grad():\n", | |
| " reward = (action == self.labels.unsqueeze(1)).float()\n", | |
| " return torch.mean(reward)\n", | |
| "\n", | |
| "class MnistContextCI(MnistSimulator):\n", | |
| " def __init__(self, *, batch_size, decodability):\n", | |
| " super().__init__(batch_size)\n", | |
| " self.decodability = decodability\n", | |
| " \n", | |
| " self._makeFeedbacks()\n", | |
| " \n", | |
| " def _makeFeedbacks(self):\n", | |
| " import torch\n", | |
| " feedback_loader = torch.utils.data.DataLoader(self.mnist_train, batch_size=1, shuffle=True)\n", | |
| " feedbacks = [ [] for _ in range(10) ]\n", | |
| " for bno, (images, labels) in enumerate(feedback_loader):\n", | |
| " flat = images.reshape(images.shape[0], -1)\n", | |
| " feedbacks[labels[0]].append(flat)\n", | |
| " if all(len(x) > 100 for x in feedbacks):\n", | |
| " break \n", | |
| " self.feedbacks = torch.cat([ torch.cat(x[:100], dim=0).unsqueeze(0) for x in feedbacks ], dim=0)\n", | |
| "\n", | |
| " def computeHilo(self):\n", | |
| " import numpy\n", | |
| " import torch\n", | |
| " \n", | |
| " with torch.no_grad():\n", | |
| " quantile_loader = torch.utils.data.DataLoader(self.mnist_train, batch_size=10000, shuffle=True)\n", | |
| " for bno, (images, labels) in enumerate(quantile_loader):\n", | |
| " flat = images.view(images.shape[0], -1)\n", | |
| " hilo = numpy.quantile(flat.numpy(), [ 0.01, 0.99 ], axis=0)\n", | |
| " break\n", | |
| " \n", | |
| " return hilo, hilo\n", | |
| " \n", | |
| " def trainIterator(self):\n", | |
| " import torch\n", | |
| " \n", | |
| " train_loader = torch.utils.data.DataLoader(self.mnist_train, batch_size=self.batch_size, shuffle=True)\n", | |
| "\n", | |
| " for images, labels in train_loader:\n", | |
| " yield self.MyBatch(images, labels, self.feedbacks, self.decodability)\n", | |
| " \n", | |
| " class MyBatch(Batch):\n", | |
| " def __init__(self, images, labels, feedbacks, decodability):\n", | |
| " super().__init__()\n", | |
| "\n", | |
| " self.images = images\n", | |
| " self.labels = labels\n", | |
| " self.feedbacks = feedbacks\n", | |
| " self.decodability = decodability\n", | |
| " \n", | |
| " def getContext(self):\n", | |
| " return self.images.view(self.images.shape[0], -1)\n", | |
| " \n", | |
| " # feedback is an image of (a + 1) % 10 if correct else (a - 1) % 10\n", | |
| " def getFeedback(self, action):\n", | |
| " import torch\n", | |
| " with torch.no_grad():\n", | |
| " reward = (action == self.labels.unsqueeze(1)).float()\n", | |
| " pixels = self.getContext().shape[1]\n", | |
| " \n", | |
| " # this assumes a particular majorization (Torch tensors are row-major)\n", | |
| " shortaction = action.squeeze(1)\n", | |
| " bigfeedbacks = self.feedbacks.unsqueeze(0).expand(action.shape[0], -1, -1, -1).reshape(action.shape[0], -1, pixels) # Batch x (A x Rep) x Pixels\n", | |
| " goodwhich = self.feedbacks.shape[1] * torch.remainder(shortaction + 1, 10) + torch.randint(low=0, high=self.feedbacks.shape[1], size=(action.shape[0],))\n", | |
| " goodwhich = goodwhich.unsqueeze(1).unsqueeze(2).expand(-1, -1, pixels)\n", | |
| " goodfeedbacks = torch.gather(input=bigfeedbacks, index=goodwhich, dim=1).squeeze(1)\n", | |
| " badwhich = self.feedbacks.shape[1] * torch.remainder(shortaction - 1, 10) + torch.randint(low=0, high=self.feedbacks.shape[1], size=(action.shape[0],))\n", | |
| " badwhich = badwhich.unsqueeze(1).unsqueeze(2).expand(-1, -1, pixels)\n", | |
| " badfeedbacks = torch.gather(input=bigfeedbacks, index=badwhich, dim=1).squeeze(1)\n", | |
| " \n", | |
| " noise = torch.rand(size=(action.shape[0], 1), device=action.device)\n", | |
| " shouldflip = (noise <= ((1 + self.decodability)/2)).long()\n", | |
| " noisyreward = reward + shouldflip * (1 - 2 * reward)\n", | |
| " feedback = badfeedbacks + noisyreward * (goodfeedbacks - badfeedbacks)\n", | |
| " \n", | |
| " if False:\n", | |
| " import matplotlib.pyplot as plt\n", | |
| " \n", | |
| " fig, axs = plt.subplots(1, 10)\n", | |
| " for n, (s, f) in enumerate(zip(shortaction, goodfeedbacks)):\n", | |
| " if n > 9:\n", | |
| " break\n", | |
| " axs[n].imshow(f.reshape(28, 28))\n", | |
| " axs[n].set_title(f'{s.item()}')\n", | |
| "\n", | |
| " plt.show()\n", | |
| "\n", | |
| " fig, axs = plt.subplots(1, 10)\n", | |
| " for n, (s, f) in enumerate(zip(shortaction, badfeedbacks)):\n", | |
| " if n > 9:\n", | |
| " break\n", | |
| " axs[n].imshow(f.reshape(28, 28))\n", | |
| " axs[n].set_title(f'{s.item()}')\n", | |
| "\n", | |
| " plt.show()\n", | |
| " \n", | |
| " fig, axs = plt.subplots(1, 10)\n", | |
| " for n, (a, r, nr, f) in enumerate(zip(action, reward, noisyreward, feedback)):\n", | |
| " if n > 9:\n", | |
| " break\n", | |
| " axs[n].imshow(f.reshape(28, 28))\n", | |
| " axs[n].set_title(f'{a.item()} {r.long().item()} {nr.long().item()}')\n", | |
| "\n", | |
| " plt.show()\n", | |
| " assert False\n", | |
| " \n", | |
| " return feedback\n", | |
| " \n", | |
| " def getReward(self, action):\n", | |
| " import torch\n", | |
| " with torch.no_grad():\n", | |
| " reward = (action == self.labels.unsqueeze(1)).float()\n", | |
| " return torch.mean(reward)\n", | |
| "\n", | |
| "class Algorithm(object):\n", | |
| " def __init__(self):\n", | |
| " super().__init__()\n", | |
| " \n", | |
| " @abstractmethod\n", | |
| " def sample(self, x):\n", | |
| " pass\n", | |
| " \n", | |
| " @abstractmethod\n", | |
| " def greedy(self, x):\n", | |
| " pass\n", | |
| " \n", | |
| " @abstractmethod\n", | |
| " def update(self, sample, feedback):\n", | |
| " pass\n", | |
| "\n", | |
| "class Util(object):\n", | |
| " import torch\n", | |
| " \n", | |
| " def __init__(self):\n", | |
| " super().__init__()\n", | |
| " \n", | |
| " class Sample(object):\n", | |
| " def __init__(self, x, action, probs):\n", | |
| " super().__init__()\n", | |
| " self.x = x\n", | |
| " self.action = action\n", | |
| " self.probs = probs\n", | |
| "\n", | |
| " def getAction(self):\n", | |
| " return self.action\n", | |
| " \n", | |
| " class RFFSoftmax(torch.nn.Module):\n", | |
| " def __init__(self, hilo, naction, numrff, sigma):\n", | |
| " from math import pi\n", | |
| " import numpy as np\n", | |
| " import torch\n", | |
| "\n", | |
| " super().__init__()\n", | |
| "\n", | |
| " nobs = hilo.shape[1]\n", | |
| " high = hilo[1, :]\n", | |
| " low = hilo[0, :]\n", | |
| "\n", | |
| " self.rff = torch.nn.Linear(nobs, numrff)\n", | |
| " self.rff.weight.data = torch.matmul(torch.empty(numrff, nobs).cauchy_(sigma = sigma), \n", | |
| " torch.diag(torch.tensor([ 1.0/v if v > 1e-6 else 0. for v in high - low ])).float())\n", | |
| " self.rff.weight.requires_grad = False\n", | |
| " self.rff.bias.data = 2 * pi * torch.rand(numrff)\n", | |
| " self.rff.bias.requires_grad = False\n", | |
| " self.sqrtrff = np.sqrt(numrff)\n", | |
| " self.final = torch.nn.Linear(numrff, naction)\n", | |
| " self.final.weight.data *= 0\n", | |
| " self.final.bias.data *= 0\n", | |
| " self.sigmoid = torch.nn.Sigmoid()\n", | |
| "\n", | |
| " def forward(self, x):\n", | |
| " import torch\n", | |
| " with torch.no_grad():\n", | |
| " rff = self.rff(x).cos() / self.sqrtrff\n", | |
| " return self.final(rff)\n", | |
| " \n", | |
| " def predict(self, logits):\n", | |
| " return self.sigmoid(logits)\n", | |
| "\n", | |
| "class IKAlgorithm(Algorithm):\n", | |
| " import torch\n", | |
| " \n", | |
| " def __init__(self, *, hilo, sampler, lr):\n", | |
| " from math import log\n", | |
| " import itertools\n", | |
| " import numpy\n", | |
| " import torch\n", | |
| " \n", | |
| " super().__init__()\n", | |
| " self.sampler = sampler\n", | |
| " \n", | |
| " util = Util()\n", | |
| " \n", | |
| " self.pi = util.RFFSoftmax(hilo[0], 10, 2000, 0.01)\n", | |
| " doublehilo = numpy.concatenate((hilo[0], hilo[1]), axis=1)\n", | |
| " self.decoder = util.RFFSoftmax(doublehilo, 10, 2000, 0.01)\n", | |
| " self.alpha = 1/3\n", | |
| " self.decoder.final.bias.data.fill_(log(self.alpha / (1 - self.alpha)))\n", | |
| " self.opt = torch.optim.Adam(( p for p in itertools.chain(self.pi.parameters(), \n", | |
| " self.decoder.parameters()) \n", | |
| " if p.requires_grad ), \n", | |
| " lr=lr)\n", | |
| " \n", | |
| " def sample(self, x):\n", | |
| " import torch\n", | |
| " \n", | |
| " with torch.no_grad():\n", | |
| " fhatlogit = self.pi(x)\n", | |
| " fhat = self.pi.predict(fhatlogit)\n", | |
| " sample, probs = self.sampler.sample(fhat, keepdim=True)\n", | |
| " \n", | |
| " return Util().Sample(x, sample, probs)\n", | |
| "\n", | |
| " def greedy(self, x):\n", | |
| " import torch\n", | |
| " \n", | |
| " with torch.no_grad():\n", | |
| " fhatlogit = self.pi(x)\n", | |
| " _, pred = torch.max(fhatlogit, dim=1, keepdim=True)\n", | |
| " _, anti = torch.min(fhatlogit, dim=1, keepdim=True)\n", | |
| " \n", | |
| " return pred, anti\n", | |
| " \n", | |
| " def update(self, sample, feedback):\n", | |
| " import torch\n", | |
| " \n", | |
| " self.opt.zero_grad()\n", | |
| " \n", | |
| " fhatlogit = self.pi(sample.x)\n", | |
| " \n", | |
| " with torch.no_grad():\n", | |
| " probs = sample.probs\n", | |
| " d = torch.nn.functional.one_hot(sample.action.squeeze(1), num_classes=fhatlogit.shape[1]).float()\n", | |
| " dprobs = self.alpha * d + ((1 - self.alpha) / (fhatlogit.shape[1] - 1)) * (1 - d)\n", | |
| " dweights = (dprobs / probs) / torch.mean(dprobs / probs) \n", | |
| " \n", | |
| " dhatlogit = self.decoder(torch.cat((sample.x, feedback), dim=1))\n", | |
| " dhat_log_loss = torch.nn.BCEWithLogitsLoss(weight=dweights)\n", | |
| " dhat_loss = dhat_log_loss(dhatlogit, d)\n", | |
| " \n", | |
| " with torch.no_grad():\n", | |
| " fakereward = self.decoder.sigmoid(dhatlogit)\n", | |
| " loggedfakereward = torch.gather(input=fakereward, index=sample.action, dim=1)\n", | |
| " \n", | |
| " fhat_log_loss = torch.nn.BCEWithLogitsLoss()\n", | |
| " loggedfhatlogit = torch.gather(input=fhatlogit, index=sample.action, dim=1)\n", | |
| " fhat_loss = fhat_log_loss(loggedfhatlogit, loggedfakereward)\n", | |
| "\n", | |
| " loss = dhat_loss + fhat_loss\n", | |
| " loss.backward()\n", | |
| " \n", | |
| " self.opt.step()\n", | |
| " \n", | |
| " return loss.item(), torch.mean(loggedfakereward)\n", | |
| "\n", | |
| " def __str__(self):\n", | |
| " return f'IKAlgorithm(lr={self.opt.defaults[\"lr\"]} sampler={self.sampler})'\n", | |
| "\n", | |
| "class Sampler(object):\n", | |
| " def __init__(self):\n", | |
| " super().__init__()\n", | |
| " \n", | |
| " @abstractmethod\n", | |
| " def sample(fhat, *, keepdim):\n", | |
| " pass\n", | |
| "\n", | |
| "class EpsilonGreedy(Sampler):\n", | |
| " def __init__(self, *, t0):\n", | |
| " super().__init__()\n", | |
| "\n", | |
| " self.t0 = t0\n", | |
| " self.t = 0\n", | |
| "\n", | |
| " def sample(self, fhat, *, keepdim=False):\n", | |
| " import torch\n", | |
| " N, K = fhat.shape\n", | |
| " epsilon = (self.t0 / (self.t0 + self.t))**(1/3)\n", | |
| " self.t += 1\n", | |
| " fhatstar, ahatstar = torch.max(fhat, dim=1, keepdim=True)\n", | |
| " rando = torch.randint(high=K, size=(N, 1), device=fhat.device)\n", | |
| " unif = torch.rand(size=(N, 1), device=fhat.device)\n", | |
| " shouldexplore = (unif <= epsilon).long()\n", | |
| " actions = ahatstar + shouldexplore * (rando - ahatstar)\n", | |
| " phatstar = (1 - epsilon) + epsilon / K\n", | |
| " prando = epsilon / K\n", | |
| " pactions = phatstar + shouldexplore * (prando - phatstar)\n", | |
| " if not keepdim:\n", | |
| " actions = actions.squeeze(1)\n", | |
| " pactions = pactions.squeeze(1)\n", | |
| " return actions, pactions\n", | |
| " \n", | |
| " def __str__(self):\n", | |
| " return f'EpsilonGreedy(t0={self.t0})'\n", | |
| "\n", | |
| "class SquareCB(Sampler):\n", | |
| " def __init__(self, *, t0, gamma0):\n", | |
| " super().__init__()\n", | |
| "\n", | |
| " self.t0 = t0\n", | |
| " self.gamma0 = gamma0\n", | |
| " self.t = 0\n", | |
| "\n", | |
| " def sample(self, fhat, *, keepdim=False):\n", | |
| " import torch\n", | |
| " \n", | |
| " self.t += 1\n", | |
| " gamma = self.gamma0 * ((self.t0 + self.t) / self.t0)**(1/2)\n", | |
| " \n", | |
| " N, K = fhat.shape\n", | |
| " fhatstar, ahatstar = torch.max(fhat, dim=1, keepdim=True)\n", | |
| " probs = 1 / (K + gamma * (fhatstar - fhat))\n", | |
| " #probs = (1 - fhat) / (K * (1 - fhat) + gamma * (fhatstar - fhat))\n", | |
| " psum = torch.sum(probs, dim=1, keepdim=True)\n", | |
| " phatstar = (1 - psum) + torch.gather(input=probs, dim=1, index=ahatstar)\n", | |
| "\n", | |
| " rando = torch.randint(high=K, size=(N, 1), device=fhat.device)\n", | |
| " prando = torch.gather(input=probs, dim=1, index=rando)\n", | |
| " unif = torch.rand(size=(N, 1), device=fhat.device)\n", | |
| " shouldexplore = (unif <= K * prando).long()\n", | |
| " actions = ahatstar + shouldexplore * (rando - ahatstar)\n", | |
| " pactions = phatstar + shouldexplore * (prando - phatstar)\n", | |
| " if not keepdim:\n", | |
| " actions = actions.squeeze(1)\n", | |
| " pactions = pactions.squeeze(1)\n", | |
| " return actions, pactions\n", | |
| " \n", | |
| " def __str__(self):\n", | |
| " return f'SquareCB(gamma0={self.gamma0} t0={self.t0})'\n", | |
| "\n", | |
| "def run_sim_helper(*, passes, simulator, algorithm):\n", | |
| " import itertools\n", | |
| " \n", | |
| " class EasyAcc:\n", | |
| " def __init__(self):\n", | |
| " self.n = 0\n", | |
| " self.sum = 0\n", | |
| " self.sumsq = 0\n", | |
| "\n", | |
| " def __iadd__(self, other):\n", | |
| " self.n += 1\n", | |
| " self.sum += other\n", | |
| " self.sumsq += other*other\n", | |
| " return self\n", | |
| "\n", | |
| " def __isub__(self, other):\n", | |
| " self.n += 1\n", | |
| " self.sum -= other\n", | |
| " self.sumsq += other*other\n", | |
| " return self\n", | |
| "\n", | |
| " def mean(self):\n", | |
| " return self.sum / max(self.n, 1)\n", | |
| "\n", | |
| " def var(self):\n", | |
| " from math import sqrt\n", | |
| " return sqrt(self.sumsq / max(self.n, 1) - self.mean()**2)\n", | |
| "\n", | |
| " def semean(self):\n", | |
| " from math import sqrt\n", | |
| " return self.var() / sqrt(max(self.n, 1))\n", | |
| "\n", | |
| " print('{:<5s}\\t{:<8s} {:<8s}\\t{:<8s} {:<8s}\\t{:<8s} {:<8s}\\t{:<8s} {:<8s}\\t{:<8s} {:<8s}'.format(\n", | |
| " 'bno', \n", | |
| " 'loss', 'since', \n", | |
| " 'pred', 'since',\n", | |
| " 'anti', 'since',\n", | |
| " 'reward', 'since',\n", | |
| " 'fake', 'since',\n", | |
| " ),\n", | |
| " flush=True)\n", | |
| " \n", | |
| " avloss, avreward, avfakereward, avpred, avanti = [ EasyAcc() for _ in range(5) ]\n", | |
| " avlosssl, avrewardsl, avfakerewardsl, avpredsl, avantisl = [ EasyAcc() for _ in range(5) ]\n", | |
| " \n", | |
| " for bno, batch in enumerate(itertools.chain(*[ simulator.trainIterator() for _ in range(passes) ])):\n", | |
| " x = batch.getContext()\n", | |
| " sample = algorithm.sample(x)\n", | |
| " feedback = batch.getFeedback(sample.getAction())\n", | |
| " loss, fakereward = algorithm.update(sample, feedback)\n", | |
| " \n", | |
| " avloss += loss\n", | |
| " avlosssl += loss\n", | |
| "\n", | |
| " avfakereward += fakereward\n", | |
| " avfakerewardsl += fakereward\n", | |
| " \n", | |
| " reward = batch.getReward(sample.getAction())\n", | |
| " pred, anti = algorithm.greedy(x)\n", | |
| " predreward = batch.getReward(pred)\n", | |
| " antireward = batch.getReward(anti)\n", | |
| " \n", | |
| " avreward += reward\n", | |
| " avrewardsl += reward\n", | |
| " avpred += predreward\n", | |
| " avpredsl += predreward \n", | |
| " avanti += antireward\n", | |
| " avantisl += antireward\n", | |
| " \n", | |
| " if (bno & (bno - 1) == 0):\n", | |
| " print('{:<5d}\\t{:<8.5f} {:<8.5f}\\t{:<8.5f} {:<8.5f}\\t{:<8.5f} {:<8.5f}\\t{:<8.5f} {:<8.5f}\\t{:<8.5f} {:<8.5f}'.format(\n", | |
| " bno, \n", | |
| " avloss.mean(), avlosssl.mean(),\n", | |
| " avpred.mean(), avpredsl.mean(),\n", | |
| " avanti.mean(), avantisl.mean(),\n", | |
| " avreward.mean(), avrewardsl.mean(),\n", | |
| " avfakereward.mean(), avfakerewardsl.mean(),\n", | |
| " ),\n", | |
| " flush=True)\n", | |
| " avlosssl, avrewardsl, avfakerewardsl, avpredsl, avantisl = [ EasyAcc() for _ in range(5) ]\n", | |
| "\n", | |
| " print('{:<5d}\\t{:<8.5f} {:<8.5f}\\t{:<8.5f} {:<8.5f}\\t{:<8.5f} {:<8.5f}\\t{:<8.5f} {:<8.5f}\\t{:<8.5f} {:<8.5f}'.format(\n", | |
| " bno, \n", | |
| " avloss.mean(), avlosssl.mean(),\n", | |
| " avpred.mean(), avpredsl.mean(),\n", | |
| " avanti.mean(), avantisl.mean(),\n", | |
| " avreward.mean(), avrewardsl.mean(),\n", | |
| " avfakereward.mean(), avfakerewardsl.mean(),\n", | |
| " ),\n", | |
| " flush=True)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "d1a9fb44", | |
| "metadata": {}, | |
| "source": [ | |
| "# Mnist Action CI $(y_a \\perp a | r_a, x)$\n", | |
| "Feedback is $(x + 1) \\mod 10$ if correct else $(x - 1) \\mod 10$." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 24, | |
| "id": "fd88f658", | |
| "metadata": { | |
| "code_folding": [ | |
| 0 | |
| ] | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "def run_mnist_action_ci(*, seed, alglambda):\n", | |
| " import torch\n", | |
| " print(f'***** seed = {seed} *****')\n", | |
| " torch.manual_seed(seed)\n", | |
| " hilo = MnistActionCI(batch_size=64, decodability=1).computeHilo()\n", | |
| " \n", | |
| " for decodability in (1, -1, 0.5, -0.5,):\n", | |
| " import copy\n", | |
| " torch.manual_seed(seed)\n", | |
| " sim = MnistActionCI(batch_size=64, decodability=decodability)\n", | |
| " print(f'***** decodability = {decodability} *****')\n", | |
| " with torch.random.fork_rng():\n", | |
| " alg = alglambda(hilo)\n", | |
| " print(f'***** alg = {alg} *****')\n", | |
| " with torch.random.fork_rng():\n", | |
| " run_sim_helper(passes=1, simulator=copy.deepcopy(sim), algorithm=alg)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 35, | |
| "id": "5e596247", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "***** seed = 13 *****\n", | |
| "***** decodability = 1 *****\n", | |
| "***** alg = IKAlgorithm(lr=0.05 sampler=EpsilonGreedy(t0=1)) *****\n", | |
| "bno \tloss since \tpred since \tanti since \treward since \tfake since \n", | |
| "0 \t1.32966 1.32966 \t0.10938 0.10938 \t0.09375 0.09375 \t0.10938 0.10938 \t0.33333 0.33333 \n", | |
| "1 \t1.30444 1.27921 \t0.08594 0.06250 \t0.10938 0.12500 \t0.11719 0.12500 \t0.33463 0.33592 \n", | |
| "2 \t1.30393 1.30293 \t0.06771 0.03125 \t0.11979 0.14062 \t0.08854 0.03125 \t0.33477 0.33506 \n", | |
| "4 \t1.30399 1.30408 \t0.06562 0.06250 \t0.11250 0.10156 \t0.09687 0.10938 \t0.33652 0.33914 \n", | |
| "8 \t1.29672 1.28764 \t0.08160 0.10156 \t0.09549 0.07422 \t0.10590 0.11719 \t0.34312 0.35137 \n", | |
| "16 \t1.29588 1.29494 \t0.11121 0.14453 \t0.09099 0.08594 \t0.11029 0.11523 \t0.35378 0.36579 \n", | |
| "32 \t1.29838 1.30103 \t0.15436 0.20020 \t0.08523 0.07910 \t0.13968 0.17090 \t0.38594 0.42010 \n", | |
| "64 \t1.30807 1.31806 \t0.17596 0.19824 \t0.07067 0.05566 \t0.16226 0.18555 \t0.41429 0.44352 \n", | |
| "128 \t1.30980 1.31156 \t0.20591 0.23633 \t0.06202 0.05322 \t0.18302 0.20410 \t0.43660 0.45927 \n", | |
| "256 \t1.31787 1.32600 \t0.26204 0.31860 \t0.05052 0.03894 \t0.22781 0.27295 \t0.45920 0.48197 \n", | |
| "512 \t1.32409 1.33034 \t0.30431 0.34674 \t0.04173 0.03290 \t0.26687 0.30609 \t0.49435 0.52965 \n", | |
| "937 \t1.33276 1.34322 \t0.39582 0.50629 \t0.03077 0.01754 \t0.35419 0.45960 \t0.52439 0.56064 \n", | |
| "***** decodability = -1 *****\n", | |
| "***** alg = IKAlgorithm(lr=0.05 sampler=EpsilonGreedy(t0=1)) *****\n", | |
| "bno \tloss since \tpred since \tanti since \treward since \tfake since \n", | |
| "0 \t1.32966 1.32966 \t0.10938 0.10938 \t0.09375 0.09375 \t0.10938 0.10938 \t0.33333 0.33333 \n", | |
| "1 \t1.30487 1.28007 \t0.09375 0.07812 \t0.11719 0.14062 \t0.11719 0.12500 \t0.33465 0.33597 \n", | |
| "2 \t1.30430 1.30317 \t0.07812 0.04688 \t0.11979 0.12500 \t0.08854 0.03125 \t0.33357 0.33142 \n", | |
| "4 \t1.30306 1.30120 \t0.11250 0.16406 \t0.11250 0.10156 \t0.10000 0.11719 \t0.33489 0.33687 \n", | |
| "8 \t1.29671 1.28877 \t0.10243 0.08984 \t0.09375 0.07031 \t0.11458 0.13281 \t0.34906 0.36677 \n", | |
| "16 \t1.29965 1.30295 \t0.11581 0.13086 \t0.09007 0.08594 \t0.11857 0.12305 \t0.38112 0.41718 \n", | |
| "32 \t1.30370 1.30800 \t0.14205 0.16992 \t0.08002 0.06934 \t0.13258 0.14746 \t0.39104 0.40158 \n", | |
| "64 \t1.30846 1.31338 \t0.15986 0.17822 \t0.06731 0.05420 \t0.14543 0.15869 \t0.40492 0.41923 \n", | |
| "128 \t1.31436 1.32034 \t0.19077 0.22217 \t0.05790 0.04834 \t0.16982 0.19458 \t0.43110 0.45769 \n", | |
| "256 \t1.32377 1.33325 \t0.21948 0.24841 \t0.05235 0.04675 \t0.19303 0.21643 \t0.45743 0.48396 \n", | |
| "512 \t1.33082 1.33791 \t0.29898 0.37878 \t0.04075 0.02911 \t0.26495 0.33716 \t0.48782 0.51834 \n", | |
| "937 \t1.33090 1.33100 \t0.39592 0.51294 \t0.03305 0.02375 \t0.35553 0.46485 \t0.52269 0.56478 \n", | |
| "***** decodability = 0.5 *****\n", | |
| "***** alg = IKAlgorithm(lr=0.05 sampler=EpsilonGreedy(t0=1)) *****\n", | |
| "bno \tloss since \tpred since \tanti since \treward since \tfake since \n", | |
| "0 \t1.32966 1.32966 \t0.10938 0.10938 \t0.09375 0.09375 \t0.10938 0.10938 \t0.33333 0.33333 \n", | |
| "1 \t1.30485 1.28004 \t0.09375 0.07812 \t0.11719 0.14062 \t0.11719 0.12500 \t0.33435 0.33537 \n", | |
| "2 \t1.30426 1.30308 \t0.07812 0.04688 \t0.13542 0.17188 \t0.08854 0.03125 \t0.33383 0.33279 \n", | |
| "4 \t1.30349 1.30235 \t0.08438 0.09375 \t0.12188 0.10156 \t0.09687 0.10938 \t0.33555 0.33813 \n", | |
| "8 \t1.29700 1.28888 \t0.09201 0.10156 \t0.09896 0.07031 \t0.10590 0.11719 \t0.34245 0.35107 \n", | |
| "16 \t1.29599 1.29486 \t0.10662 0.12305 \t0.10386 0.10938 \t0.10938 0.11328 \t0.35196 0.36266 \n", | |
| "32 \t1.30241 1.30924 \t0.13068 0.15625 \t0.09991 0.09570 \t0.12926 0.15039 \t0.37391 0.39723 \n", | |
| "64 \t1.31644 1.33092 \t0.11827 0.10547 \t0.08702 0.07373 \t0.12043 0.11133 \t0.39610 0.41900 \n", | |
| "128 \t1.32278 1.32921 \t0.12718 0.13623 \t0.08006 0.07300 \t0.12476 0.12915 \t0.41618 0.43656 \n", | |
| "256 \t1.33804 1.35343 \t0.12239 0.11755 \t0.07551 0.07092 \t0.11977 0.11475 \t0.43168 0.44731 \n", | |
| "512 \t1.35660 1.37523 \t0.12601 0.12964 \t0.07657 0.07764 \t0.12031 0.12085 \t0.45100 0.47039 \n", | |
| "937 \t1.37982 1.40785 \t0.14880 0.17632 \t0.06935 0.06063 \t0.14169 0.16750 \t0.45475 0.45927 \n", | |
| "***** decodability = -0.5 *****\n", | |
| "***** alg = IKAlgorithm(lr=0.05 sampler=EpsilonGreedy(t0=1)) *****\n", | |
| "bno \tloss since \tpred since \tanti since \treward since \tfake since \n", | |
| "0 \t1.32966 1.32966 \t0.10938 0.10938 \t0.09375 0.09375 \t0.10938 0.10938 \t0.33333 0.33333 \n", | |
| "1 \t1.30532 1.28098 \t0.10156 0.09375 \t0.11719 0.14062 \t0.11719 0.12500 \t0.33389 0.33444 \n", | |
| "2 \t1.30427 1.30217 \t0.08333 0.04688 \t0.11458 0.10938 \t0.08854 0.03125 \t0.33239 0.32940 \n", | |
| "4 \t1.30410 1.30385 \t0.08438 0.08594 \t0.10938 0.10156 \t0.09375 0.10156 \t0.33464 0.33801 \n", | |
| "8 \t1.29698 1.28807 \t0.09201 0.10156 \t0.09201 0.07031 \t0.10243 0.11328 \t0.34187 0.35091 \n", | |
| "16 \t1.30001 1.30342 \t0.10662 0.12305 \t0.08272 0.07227 \t0.10754 0.11328 \t0.36231 0.38530 \n", | |
| "32 \t1.30547 1.31126 \t0.13210 0.15918 \t0.07812 0.07324 \t0.12311 0.13965 \t0.37633 0.39124 \n", | |
| "64 \t1.31463 1.32407 \t0.12476 0.11719 \t0.07548 0.07275 \t0.12308 0.12305 \t0.39423 0.41269 \n", | |
| "128 \t1.32324 1.33199 \t0.12270 0.12061 \t0.08031 0.08521 \t0.11664 0.11011 \t0.41592 0.43795 \n", | |
| "256 \t1.33952 1.35593 \t0.12044 0.11816 \t0.08208 0.08386 \t0.11278 0.10889 \t0.42968 0.44354 \n", | |
| "512 \t1.35598 1.37249 \t0.11882 0.11719 \t0.07782 0.07355 \t0.11355 0.11432 \t0.44759 0.46556 \n", | |
| "937 \t1.37896 1.40669 \t0.13551 0.15566 \t0.07348 0.06824 \t0.12883 0.14728 \t0.45464 0.46314 \n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "run_mnist_action_ci(seed=13, alglambda = lambda hilo: IKAlgorithm(hilo=hilo, sampler=EpsilonGreedy(t0=1), lr=5e-2))" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "25aaf5cf", | |
| "metadata": {}, | |
| "source": [ | |
| "# Mnist Full CI ($y_a \\perp x, a | r_a$)\n", | |
| "Feedback is a \"1\" image if correct or a \"0\" image if incorrect." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 2, | |
| "id": "be02f211", | |
| "metadata": { | |
| "code_folding": [ | |
| 0 | |
| ] | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "def run_mnist_full_ci(*, seed, alglambda):\n", | |
| " import torch\n", | |
| " print(f'***** seed = {seed} *****')\n", | |
| " torch.manual_seed(seed)\n", | |
| " hilo = MnistFullCI(batch_size=64, decodability=1).computeHilo()\n", | |
| " \n", | |
| " for decodability in (1, -1, 0.5, -0.5,):\n", | |
| " import copy\n", | |
| " torch.manual_seed(seed)\n", | |
| " sim = MnistFullCI(batch_size=64, decodability=decodability)\n", | |
| " print(f'***** decodability = {decodability} *****')\n", | |
| " with torch.random.fork_rng():\n", | |
| " alg = alglambda(hilo)\n", | |
| " print(f'***** alg = {alg} *****')\n", | |
| " with torch.random.fork_rng():\n", | |
| " run_sim_helper(passes=1, simulator=copy.deepcopy(sim), algorithm=alg)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "ce18fc54", | |
| "metadata": { | |
| "heading_collapsed": true | |
| }, | |
| "source": [ | |
| "## IK, Epsilon-Greedy" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 120, | |
| "id": "12c6fc3c", | |
| "metadata": { | |
| "hidden": true | |
| }, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "***** seed = 13 *****\n", | |
| "***** decodability = 1 *****\n", | |
| "***** alg = IKAlgorithm(lr=0.05 sampler=EpsilonGreedy(t0=1)) *****\n", | |
| "bno \tloss since \tpred since \tanti since \treward since \tfake since \n", | |
| "0 \t1.32966 1.32966 \t0.10938 0.10938 \t0.09375 0.09375 \t0.10938 0.10938 \t0.33333 0.33333 \n", | |
| "1 \t1.30507 1.28049 \t0.09375 0.07812 \t0.12500 0.15625 \t0.11719 0.12500 \t0.33461 0.33589 \n", | |
| "2 \t1.30434 1.30288 \t0.08333 0.06250 \t0.13021 0.14062 \t0.09375 0.04688 \t0.33435 0.33383 \n", | |
| "4 \t1.30183 1.29805 \t0.08750 0.09375 \t0.11875 0.10156 \t0.10000 0.10938 \t0.33558 0.33743 \n", | |
| "8 \t1.29784 1.29286 \t0.08854 0.08984 \t0.09722 0.07031 \t0.10590 0.11328 \t0.34345 0.35328 \n", | |
| "16 \t1.30046 1.30342 \t0.13695 0.19141 \t0.09743 0.09766 \t0.12040 0.13672 \t0.35886 0.37619 \n", | |
| "32 \t1.30231 1.30427 \t0.16998 0.20508 \t0.09470 0.09180 \t0.14867 0.17871 \t0.37876 0.39991 \n", | |
| "64 \t1.31192 1.32184 \t0.21274 0.25684 \t0.07933 0.06348 \t0.18293 0.21826 \t0.40531 0.43269 \n", | |
| "128 \t1.31666 1.32146 \t0.25787 0.30371 \t0.06831 0.05713 \t0.22081 0.25928 \t0.43104 0.45718 \n", | |
| "256 \t1.32381 1.33101 \t0.29754 0.33752 \t0.05952 0.05066 \t0.25517 0.28979 \t0.45657 0.48229 \n", | |
| "512 \t1.32461 1.32541 \t0.36528 0.43329 \t0.05172 0.04388 \t0.32002 0.38513 \t0.49113 0.52582 \n", | |
| "937 \t1.32471 1.32484 \t0.45266 0.55813 \t0.03711 0.01949 \t0.40302 0.50320 \t0.52373 0.56309 \n", | |
| "***** decodability = -1 *****\n", | |
| "***** alg = IKAlgorithm(lr=0.05 sampler=EpsilonGreedy(t0=1)) *****\n", | |
| "bno \tloss since \tpred since \tanti since \treward since \tfake since \n", | |
| "0 \t1.32966 1.32966 \t0.10938 0.10938 \t0.09375 0.09375 \t0.10938 0.10938 \t0.33333 0.33333 \n", | |
| "1 \t1.30546 1.28126 \t0.07812 0.04688 \t0.11719 0.14062 \t0.11719 0.12500 \t0.33346 0.33360 \n", | |
| "2 \t1.30742 1.31135 \t0.08854 0.10938 \t0.11979 0.12500 \t0.09375 0.04688 \t0.33497 0.33798 \n", | |
| "4 \t1.30894 1.31121 \t0.07500 0.05469 \t0.10938 0.09375 \t0.09687 0.10156 \t0.34675 0.36442 \n", | |
| "8 \t1.30966 1.31057 \t0.09549 0.12109 \t0.09375 0.07422 \t0.10417 0.11328 \t0.35984 0.37621 \n", | |
| "16 \t1.30897 1.30820 \t0.10570 0.11719 \t0.09835 0.10352 \t0.11121 0.11914 \t0.35881 0.35765 \n", | |
| "32 \t1.31093 1.31301 \t0.09991 0.09375 \t0.10701 0.11621 \t0.10511 0.09863 \t0.39271 0.42872 \n", | |
| "64 \t1.32061 1.33060 \t0.10577 0.11182 \t0.12812 0.14990 \t0.11322 0.12158 \t0.41979 0.44773 \n", | |
| "128 \t1.32243 1.32428 \t0.09460 0.08325 \t0.13227 0.13647 \t0.09629 0.07910 \t0.44220 0.46496 \n", | |
| "256 \t1.33451 1.34668 \t0.10372 0.11292 \t0.12099 0.10962 \t0.10044 0.10461 \t0.46115 0.48024 \n", | |
| "512 \t1.34569 1.35691 \t0.12850 0.15338 \t0.10304 0.08502 \t0.12141 0.14246 \t0.48409 0.50712 \n", | |
| "937 \t1.36260 1.38302 \t0.18943 0.26298 \t0.09245 0.07967 \t0.17769 0.24563 \t0.50181 0.52320 \n", | |
| "***** decodability = 0.5 *****\n", | |
| "***** alg = IKAlgorithm(lr=0.05 sampler=EpsilonGreedy(t0=1)) *****\n", | |
| "bno \tloss since \tpred since \tanti since \treward since \tfake since \n", | |
| "0 \t1.32966 1.32966 \t0.10938 0.10938 \t0.09375 0.09375 \t0.10938 0.10938 \t0.33333 0.33333 \n", | |
| "1 \t1.30474 1.27983 \t0.09375 0.07812 \t0.12500 0.15625 \t0.11719 0.12500 \t0.33459 0.33585 \n", | |
| "2 \t1.30399 1.30249 \t0.07292 0.03125 \t0.13542 0.15625 \t0.08854 0.03125 \t0.33390 0.33250 \n", | |
| "4 \t1.30216 1.29940 \t0.08750 0.10938 \t0.12188 0.10156 \t0.10000 0.11719 \t0.33365 0.33329 \n", | |
| "8 \t1.29573 1.28770 \t0.09722 0.10938 \t0.10069 0.07422 \t0.10243 0.10547 \t0.34649 0.36253 \n", | |
| "16 \t1.30262 1.31038 \t0.12776 0.16211 \t0.09743 0.09375 \t0.12040 0.14062 \t0.36855 0.39337 \n", | |
| "32 \t1.30842 1.31457 \t0.14015 0.15332 \t0.10275 0.10840 \t0.12879 0.13770 \t0.37469 0.38122 \n", | |
| "64 \t1.31988 1.33170 \t0.13534 0.13037 \t0.08702 0.07080 \t0.13197 0.13525 \t0.39178 0.40940 \n", | |
| "128 \t1.32899 1.33825 \t0.15540 0.17578 \t0.07982 0.07251 \t0.14462 0.15747 \t0.40806 0.42460 \n", | |
| "256 \t1.34702 1.36519 \t0.14105 0.12659 \t0.08554 0.09131 \t0.13132 0.11792 \t0.42628 0.44465 \n", | |
| "512 \t1.36833 1.38972 \t0.15278 0.16455 \t0.07392 0.06226 \t0.14215 0.15302 \t0.44862 0.47105 \n", | |
| "937 \t1.39125 1.41892 \t0.17161 0.19434 \t0.06875 0.06250 \t0.16081 0.18335 \t0.45148 0.45493 \n", | |
| "***** decodability = -0.5 *****\n", | |
| "***** alg = IKAlgorithm(lr=0.05 sampler=EpsilonGreedy(t0=1)) *****\n", | |
| "bno \tloss since \tpred since \tanti since \treward since \tfake since \n", | |
| "0 \t1.32966 1.32966 \t0.10938 0.10938 \t0.09375 0.09375 \t0.10938 0.10938 \t0.33333 0.33333 \n", | |
| "1 \t1.30487 1.28009 \t0.09375 0.07812 \t0.10156 0.10938 \t0.11719 0.12500 \t0.33412 0.33491 \n", | |
| "2 \t1.30595 1.30810 \t0.07812 0.04688 \t0.11458 0.14062 \t0.08854 0.03125 \t0.33473 0.33595 \n", | |
| "4 \t1.30764 1.31018 \t0.06250 0.03906 \t0.10938 0.10156 \t0.09375 0.10156 \t0.33889 0.34512 \n", | |
| "8 \t1.30273 1.29660 \t0.08160 0.10547 \t0.09375 0.07422 \t0.10243 0.11328 \t0.34147 0.34469 \n", | |
| "16 \t1.30562 1.30887 \t0.08548 0.08984 \t0.08915 0.08398 \t0.10386 0.10547 \t0.34891 0.35729 \n", | |
| "32 \t1.31029 1.31525 \t0.07955 0.07324 \t0.10417 0.12012 \t0.09659 0.08887 \t0.36639 0.38495 \n", | |
| "64 \t1.32097 1.33199 \t0.07187 0.06396 \t0.11010 0.11621 \t0.08918 0.08154 \t0.40254 0.43981 \n", | |
| "128 \t1.32663 1.33237 \t0.07001 0.06812 \t0.11810 0.12622 \t0.08043 0.07153 \t0.42110 0.43996 \n", | |
| "256 \t1.34333 1.36015 \t0.06980 0.06958 \t0.11570 0.11328 \t0.07496 0.06946 \t0.44035 0.45975 \n", | |
| "512 \t1.36080 1.37834 \t0.08114 0.09253 \t0.11507 0.11444 \t0.08373 0.09253 \t0.45726 0.47424 \n", | |
| "937 \t1.38823 1.42134 \t0.09522 0.11221 \t0.10439 0.09151 \t0.09637 0.11162 \t0.46219 0.46815 \n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "run_mnist_full_ci(seed=13, alglambda = lambda hilo: IKAlgorithm(hilo=hilo, sampler=EpsilonGreedy(t0=1), lr=5e-2))" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "5ffdb723", | |
| "metadata": { | |
| "heading_collapsed": true | |
| }, | |
| "source": [ | |
| "## IK, SquareCB" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 101, | |
| "id": "11219b00", | |
| "metadata": { | |
| "hidden": true, | |
| "scrolled": false | |
| }, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "***** seed = 13 *****\n", | |
| "***** decodability = 1 *****\n", | |
| "***** alg = IKAlgorithm(lr=0.05 sampler=SquareCB(gamma0=10 t0=10)) *****\n", | |
| "bno \tloss since \tpred since \tanti since \treward since \tfake since \n", | |
| "0 \t1.32966 1.32966 \t0.10938 0.10938 \t0.09375 0.09375 \t0.10938 0.10938 \t0.33333 0.33333 \n", | |
| "1 \t1.30494 1.28021 \t0.09375 0.07812 \t0.13281 0.17188 \t0.10938 0.10938 \t0.33776 0.34219 \n", | |
| "2 \t1.30257 1.29784 \t0.08854 0.07812 \t0.14062 0.15625 \t0.08333 0.03125 \t0.33513 0.32985 \n", | |
| "4 \t1.30280 1.30315 \t0.10312 0.12500 \t0.12500 0.10156 \t0.09687 0.11719 \t0.33583 0.33688 \n", | |
| "8 \t1.29567 1.28675 \t0.12153 0.14453 \t0.10069 0.07031 \t0.11458 0.13672 \t0.33917 0.34336 \n", | |
| "16 \t1.29371 1.29151 \t0.16085 0.20508 \t0.10938 0.11914 \t0.12592 0.13867 \t0.34092 0.34287 \n", | |
| "32 \t1.29202 1.29022 \t0.21070 0.26367 \t0.10890 0.10840 \t0.15152 0.17871 \t0.35087 0.36145 \n", | |
| "64 \t1.29390 1.29584 \t0.25697 0.30469 \t0.09207 0.07471 \t0.16779 0.18457 \t0.36396 0.37746 \n", | |
| "128 \t1.29290 1.29189 \t0.34969 0.44385 \t0.06480 0.03711 \t0.21427 0.26147 \t0.38296 0.40226 \n", | |
| "256 \t1.28826 1.28358 \t0.43020 0.51135 \t0.04596 0.02698 \t0.27243 0.33105 \t0.41878 0.45488 \n", | |
| "512 \t1.28455 1.28083 \t0.50433 0.57874 \t0.03326 0.02051 \t0.34957 0.42700 \t0.46282 0.50702 \n", | |
| "937 \t1.27793 1.26993 \t0.57862 0.66831 \t0.02434 0.01357 \t0.43420 0.53636 \t0.50870 0.56408 \n", | |
| "***** decodability = -1 *****\n", | |
| "***** alg = IKAlgorithm(lr=0.05 sampler=SquareCB(gamma0=10 t0=10)) *****\n", | |
| "bno \tloss since \tpred since \tanti since \treward since \tfake since \n", | |
| "0 \t1.32966 1.32966 \t0.10938 0.10938 \t0.09375 0.09375 \t0.10938 0.10938 \t0.33333 0.33333 \n", | |
| "1 \t1.30513 1.28060 \t0.08594 0.06250 \t0.10938 0.12500 \t0.10938 0.10938 \t0.33781 0.34229 \n", | |
| "2 \t1.30370 1.30084 \t0.09375 0.10938 \t0.11458 0.12500 \t0.08333 0.03125 \t0.33464 0.32829 \n", | |
| "4 \t1.30378 1.30391 \t0.08125 0.06250 \t0.10938 0.10156 \t0.09687 0.11719 \t0.33765 0.34217 \n", | |
| "8 \t1.29949 1.29413 \t0.09028 0.10156 \t0.09028 0.06641 \t0.10590 0.11719 \t0.34235 0.34822 \n", | |
| "16 \t1.29919 1.29884 \t0.10202 0.11523 \t0.10294 0.11719 \t0.11857 0.13281 \t0.34206 0.34174 \n", | |
| "32 \t1.29828 1.29731 \t0.11932 0.13770 \t0.12074 0.13965 \t0.12358 0.12891 \t0.34872 0.35580 \n", | |
| "64 \t1.29811 1.29794 \t0.12212 0.12500 \t0.13197 0.14355 \t0.11731 0.11084 \t0.36252 0.37675 \n", | |
| "128 \t1.30163 1.30521 \t0.12875 0.13550 \t0.10938 0.08643 \t0.11810 0.11890 \t0.37397 0.38560 \n", | |
| "256 \t1.30894 1.31631 \t0.16786 0.20728 \t0.08810 0.06665 \t0.14014 0.16235 \t0.39465 0.41550 \n", | |
| "512 \t1.31513 1.32135 \t0.24050 0.31342 \t0.06768 0.04718 \t0.19219 0.24445 \t0.42662 0.45871 \n", | |
| "937 \t1.31834 1.32222 \t0.31413 0.40301 \t0.05549 0.04077 \t0.25568 0.33232 \t0.46456 0.51035 \n", | |
| "***** decodability = 0.5 *****\n", | |
| "***** alg = IKAlgorithm(lr=0.05 sampler=SquareCB(gamma0=10 t0=10)) *****\n", | |
| "bno \tloss since \tpred since \tanti since \treward since \tfake since \n", | |
| "0 \t1.32966 1.32966 \t0.10938 0.10938 \t0.09375 0.09375 \t0.10938 0.10938 \t0.33333 0.33333 \n", | |
| "1 \t1.30494 1.28022 \t0.09375 0.07812 \t0.14062 0.18750 \t0.10938 0.10938 \t0.33699 0.34065 \n", | |
| "2 \t1.30267 1.29813 \t0.09375 0.09375 \t0.14062 0.14062 \t0.08333 0.03125 \t0.33467 0.33003 \n", | |
| "4 \t1.30326 1.30416 \t0.11250 0.14062 \t0.12500 0.10156 \t0.09687 0.11719 \t0.33453 0.33433 \n", | |
| "8 \t1.29485 1.28433 \t0.12500 0.14062 \t0.10243 0.07422 \t0.11458 0.13672 \t0.34006 0.34697 \n", | |
| "16 \t1.29546 1.29614 \t0.14522 0.16797 \t0.11213 0.12305 \t0.12132 0.12891 \t0.33862 0.33701 \n", | |
| "32 \t1.29664 1.29790 \t0.16146 0.17871 \t0.09517 0.07715 \t0.13494 0.14941 \t0.34379 0.34927 \n", | |
| "64 \t1.29866 1.30075 \t0.15986 0.15820 \t0.08053 0.06543 \t0.12933 0.12354 \t0.34898 0.35433 \n", | |
| "128 \t1.30344 1.30830 \t0.19634 0.23340 \t0.06734 0.05396 \t0.13929 0.14941 \t0.35562 0.36236 \n", | |
| "256 \t1.31659 1.32984 \t0.19163 0.18689 \t0.06463 0.06189 \t0.13831 0.13733 \t0.36600 0.37646 \n", | |
| "512 \t1.33155 1.34657 \t0.21040 0.22925 \t0.05781 0.05096 \t0.15850 0.17877 \t0.37977 0.39359 \n", | |
| "937 \t1.34935 1.37084 \t0.21973 0.23099 \t0.05574 0.05324 \t0.17377 0.19221 \t0.39498 0.41335 \n", | |
| "***** decodability = -0.5 *****\n", | |
| "***** alg = IKAlgorithm(lr=0.05 sampler=SquareCB(gamma0=10 t0=10)) *****\n", | |
| "bno \tloss since \tpred since \tanti since \treward since \tfake since \n", | |
| "0 \t1.32966 1.32966 \t0.10938 0.10938 \t0.09375 0.09375 \t0.10938 0.10938 \t0.33333 0.33333 \n", | |
| "1 \t1.30492 1.28018 \t0.09375 0.07812 \t0.08594 0.07812 \t0.10938 0.10938 \t0.33714 0.34094 \n", | |
| "2 \t1.30358 1.30090 \t0.08333 0.06250 \t0.09896 0.12500 \t0.08333 0.03125 \t0.33515 0.33118 \n", | |
| "4 \t1.30400 1.30463 \t0.08438 0.08594 \t0.10000 0.10156 \t0.10000 0.12500 \t0.33447 0.33346 \n", | |
| "8 \t1.29777 1.28998 \t0.09722 0.11328 \t0.08507 0.06641 \t0.10764 0.11719 \t0.33946 0.34568 \n", | |
| "16 \t1.29687 1.29587 \t0.09743 0.09766 \t0.08548 0.08594 \t0.11489 0.12305 \t0.33790 0.33615 \n", | |
| "32 \t1.29696 1.29706 \t0.11222 0.12793 \t0.09612 0.10742 \t0.12311 0.13184 \t0.34567 0.35394 \n", | |
| "64 \t1.30000 1.30314 \t0.10481 0.09717 \t0.09832 0.10059 \t0.11034 0.09717 \t0.35347 0.36151 \n", | |
| "128 \t1.30487 1.30981 \t0.09448 0.08398 \t0.09278 0.08716 \t0.10514 0.09985 \t0.35846 0.36353 \n", | |
| "256 \t1.31660 1.32842 \t0.10171 0.10901 \t0.09168 0.09058 \t0.10445 0.10376 \t0.37122 0.38409 \n", | |
| "512 \t1.33261 1.34868 \t0.12189 0.14215 \t0.08784 0.08398 \t0.11364 0.12286 \t0.38679 0.40243 \n", | |
| "937 \t1.34934 1.36953 \t0.13471 0.15018 \t0.08136 0.07353 \t0.12328 0.13493 \t0.39942 0.41465 \n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "run_mnist_full_ci(seed=13, alglambda = lambda hilo: IKAlgorithm(hilo=hilo, sampler=SquareCB(gamma0=10, t0=10), lr=5e-2))" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "id": "8c587b93", | |
| "metadata": { | |
| "hidden": true | |
| }, | |
| "outputs": [], | |
| "source": [] | |
| } | |
| ], | |
| "metadata": { | |
| "kernelspec": { | |
| "display_name": "Python 3", | |
| "language": "python", | |
| "name": "python3" | |
| }, | |
| "language_info": { | |
| "codemirror_mode": { | |
| "name": "ipython", | |
| "version": 3 | |
| }, | |
| "file_extension": ".py", | |
| "mimetype": "text/x-python", | |
| "name": "python", | |
| "nbconvert_exporter": "python", | |
| "pygments_lexer": "ipython3", | |
| "version": "3.9.7" | |
| } | |
| }, | |
| "nbformat": 4, | |
| "nbformat_minor": 5 | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment