Notebook applying the Vision Transformer paper's method to MNIST
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Demo of using Vision Transformers on MNIST\n",
    "From the paper [AN IMAGE IS WORTH 16X16 WORDS](https://openreview.net/pdf?id=YicbFdNTTy)\n",
    "\n",
    "[Thanks to lucidrains](https://github.com/lucidrains/vit-pytorch)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "from torchvision import datasets, transforms\n",
    "from torch.utils.data import DataLoader\n",
    "from torch.optim.lr_scheduler import StepLR\n",
    "\n",
    "from einops import rearrange\n",
    "import math\n",
    "from PIL.Image import Image\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "from numpy.lib import stride_tricks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "device(type='cuda')"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Fall back to CPU when no GPU is available.\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "device"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Download the raw training set (no transform yet; we want the PIL images).\n",
    "train_set = datasets.MNIST('./data', train=True, download=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(33.318421449829934, 76.834538656172)"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Accumulate per-image mean and std over the raw 0-255 pixel values.\n",
    "mean = 0\n",
    "std = 0\n",
    "for img, _ in train_set:\n",
    "    a = np.array(img)\n",
    "    mean += a.mean()\n",
    "    std += a.std()\n",
    "\n",
    "mean/len(train_set), std/len(train_set)"
   ]
  },
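  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The PIL images are 8-bit, so these statistics are on a 0-255 scale. Dividing by 255 gives roughly (0.13, 0.30), the scale `transforms.Normalize` sees after `ToTensor`. Two caveats: averaging per-image standard deviations only approximates the global standard deviation, and the notebook below normalizes with 0.5/0.5 rather than these values. A quick conversion (a sketch, not used later):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rescale the dataset statistics from 0-255 to the 0-1 range produced by ToTensor.\n",
    "mean/len(train_set)/255, std/len(train_set)/255"
   ]
  },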
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<matplotlib.image.AxesImage at 0x7fca94b3d710>"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "(base64-encoded PNG of the rendered MNIST digit omitted)",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "img = train_set[0][0]\n",
    "img_tensor = transforms.ToTensor()(img)\n",
    "plt.imshow(img_tensor.squeeze())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "class SegmentTransform(object):\n",
    "    \"\"\"Convert a PIL image to a (num_patches, patch_size**2) array of flattened patches.\"\"\"\n",
    "\n",
    "    def __init__(self, patch_size = 4):\n",
    "        \"\"\"patch_size is the side length of each square patch (default 4x4).\"\"\"\n",
    "        self.patch_size = patch_size\n",
    "\n",
    "    def segment_array(self, a):\n",
    "        if len(a.shape) != 2:\n",
    "            raise Exception(\"array must be 2d\")\n",
    "\n",
    "        num_elements = a.shape[0]*a.shape[1]\n",
    "\n",
    "        if (int(math.sqrt(num_elements))**2) != num_elements or num_elements == 0:\n",
    "            raise Exception(\"number of elements must be a perfect square\")\n",
    "\n",
    "        num_bytes = a.dtype.itemsize\n",
    "        side_length = int(math.sqrt(num_elements))  # assumes the image itself is square\n",
    "\n",
    "        # A (rows, cols, patch_size, patch_size) view over the patch grid.\n",
    "        new_shape = (side_length//self.patch_size, \n",
    "                     side_length//self.patch_size, \n",
    "                     self.patch_size, \n",
    "                     self.patch_size)\n",
    "\n",
    "        # Byte strides: next patch row, next patch column, next pixel row, next pixel.\n",
    "        strides = (num_bytes * self.patch_size * side_length, \n",
    "                   num_bytes * self.patch_size,\n",
    "                   num_bytes * side_length, \n",
    "                   num_bytes)\n",
    "\n",
    "        return stride_tricks.as_strided(a, shape = new_shape, strides = strides)\n",
    "\n",
    "    def __call__(self, img):\n",
    "        a = np.array(img)\n",
    "        a = self.segment_array(a)\n",
    "        a = a.reshape(-1,(self.patch_size*self.patch_size))\n",
    "        return a"
   ]
  },
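  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick sanity check (a sketch, not part of the original pipeline), the strided segmentation should agree with a plain `einops.rearrange` of the same image into flattened 4x4 patches:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compare the stride-tricks patches against an einops rearrangement.\n",
    "patches = SegmentTransform()(img)\n",
    "expected = rearrange(np.array(img), '(h p1) (w p2) -> (h w) (p1 p2)', p1=4, p2=4)\n",
    "np.array_equal(patches, expected)"
   ]
  },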
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<matplotlib.image.AxesImage at 0x7fca94a44190>"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "(base64-encoded PNG of the reconstructed MNIST digit omitted)",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
| "source": [ | |
| "transform = SegmentTransform()\n", | |
| "reconstructed = torch.zeros((28,28), dtype=torch.float)\n", | |
| "for i, patch in enumerate(transform(img)):\n", | |
| " x = i % 7 * 4\n", | |
| " y = int(i / 7) * 4\n", | |
| " for k in range(4):\n", | |
| " for j in range(4):\n", | |
| " value = patch[j + k * 4]\n", | |
| " reconstructed[x + j][y + k] = value\n", | |
| " \n", | |
| "plt.imshow(reconstructed.T)" | |
| ] | |
| }, | |
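  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The same reconstruction can be written as one inverse `rearrange` (a sketch for comparison; `h` and `p1` are given, `w` and `p2` are inferred from the shapes):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fold the 49 flattened patches back into a 28x28 image in one call.\n",
    "restored = rearrange(transform(img), '(h w) (p1 p2) -> (h p1) (w p2)', h=7, p1=4)\n",
    "plt.imshow(restored)"
   ]
  },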
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Augment, segment into patches, convert to a float tensor, and normalize.\n",
    "train_transform = transforms.Compose(\n",
    "    [transforms.RandomAffine(20, scale=(0.9, 1.1), translate=(0.1, 0.1)),\n",
    "     SegmentTransform(),\n",
    "     transforms.ToTensor(),\n",
    "     transforms.Normalize((0.5,), (0.5,))])\n",
    "\n",
    "train = datasets.MNIST('./data', train=True, download=True, transform=train_transform)\n",
    "training_loader = DataLoader(train, batch_size=128, shuffle=True, num_workers=4)\n",
    "\n",
    "val_transform = transforms.Compose(\n",
    "    [transforms.RandomAffine(0, translate=(0.1, 0.1)),\n",
    "     SegmentTransform(),\n",
    "     transforms.ToTensor(),\n",
    "     transforms.Normalize((0.5,), (0.5,))])\n",
    "\n",
    "val = datasets.MNIST('./data', train=False, download=True, transform=val_transform)\n",
    "val_loader = DataLoader(val, batch_size=64, shuffle=True, num_workers=4)"
   ]
  },
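  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`ToTensor` treats each (49, 16) patch array as a single-channel image, so batches come out as (B, 1, 49, 16); the training loop below squeezes away that channel dimension. A quick check (sketch):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inspect one batch: expect torch.Size([128, 1, 49, 16]) before the squeeze.\n",
    "x, y = next(iter(training_loader))\n",
    "x.shape, x.squeeze(1).shape"
   ]
  },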
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Residual(nn.Module):\n",
    "    \"\"\"Skip connection around an inner module.\"\"\"\n",
    "    def __init__(self, fn):\n",
    "        super().__init__()\n",
    "        self.fn = fn\n",
    "    def forward(self, x):\n",
    "        return self.fn(x) + x\n",
    "\n",
    "class PreNorm(nn.Module):\n",
    "    \"\"\"LayerNorm applied before the inner module (pre-norm transformer).\"\"\"\n",
    "    def __init__(self, dim, fn):\n",
    "        super().__init__()\n",
    "        self.norm = nn.LayerNorm(dim)\n",
    "        self.fn = fn\n",
    "    def forward(self, x):\n",
    "        return self.fn(self.norm(x))\n",
    "\n",
    "class FeedForward(nn.Module):\n",
    "    def __init__(self, dim, hidden_dim, out_dim):\n",
    "        super().__init__()\n",
    "        self.net = nn.Sequential(\n",
    "            nn.Linear(dim, hidden_dim),\n",
    "            nn.GELU(),\n",
    "            nn.Linear(hidden_dim, out_dim)\n",
    "        )\n",
    "    def forward(self, x):\n",
    "        return self.net(x)\n",
    "\n",
    "class Attention(nn.Module):\n",
    "    def __init__(self, dim, heads = 8):\n",
    "        super().__init__()\n",
    "        self.heads = heads\n",
    "        self.scale = dim ** -0.5\n",
    "\n",
    "        # Single projection producing queries, keys, and values at once.\n",
    "        self.to_qkv = nn.Linear(dim, dim * 3, bias = False)\n",
    "        self.to_out = nn.Linear(dim, dim)\n",
    "\n",
    "    def forward(self, x):\n",
    "        b, n, _, h = *x.shape, self.heads\n",
    "        qkv = self.to_qkv(x)\n",
    "        # Split into q, k, v and separate the heads.\n",
    "        q, k, v = rearrange(qkv, 'b n (qkv h d) -> qkv b h n d', qkv = 3, h = h)\n",
    "\n",
    "        # Scaled dot-product attention over the token sequence.\n",
    "        dots = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale\n",
    "        attn = dots.softmax(dim=-1)\n",
    "\n",
    "        out = torch.einsum('bhij,bhjd->bhid', attn, v)\n",
    "        out = rearrange(out, 'b h n d -> b n (h d)')\n",
    "        out = self.to_out(out)\n",
    "        return out\n",
    "\n",
    "class Transformer(nn.Module):\n",
    "    def __init__(self, dim, depth, heads, mlp_dim):\n",
    "        super().__init__()\n",
    "        layers = []\n",
    "        for _ in range(depth):\n",
    "            layers.extend([\n",
    "                Residual(PreNorm(dim, Attention(dim, heads = heads))),\n",
    "                Residual(PreNorm(dim, FeedForward(dim, mlp_dim, dim)))\n",
    "            ])\n",
    "        self.net = nn.Sequential(*layers)\n",
    "    def forward(self, x):\n",
    "        return self.net(x)\n",
    "\n",
    "class Net(nn.Module):\n",
    "    def __init__(self, num_patches, patch_size, h_dim, mlp_dim, classes, t_depth=2, t_heads=4):\n",
    "        super().__init__()\n",
    "        self.num_patches = num_patches\n",
    "        # Learned position embeddings for the class token plus every patch.\n",
    "        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, h_dim))\n",
    "        self.image_projection = nn.Linear(patch_size, h_dim)\n",
    "        # Learned classification token, prepended to the patch sequence.\n",
    "        self.cls_token = nn.Parameter(torch.randn(1, 1, h_dim))\n",
    "        self.t1 = Transformer(h_dim, t_depth, t_heads, mlp_dim)\n",
    "        self.fc = FeedForward(h_dim, mlp_dim, classes)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.image_projection(x)\n",
    "        cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)\n",
    "        x = torch.cat((cls_tokens, x), dim=1)\n",
    "        x += self.pos_embedding\n",
    "        x = self.t1(x)\n",
    "        # Classify from the transformed class token only.\n",
    "        return self.fc(x[:, 0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(49, 16)"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "image_size = 28\n",
    "img = torch.tensor(np.arange(image_size*image_size).reshape(image_size, image_size), dtype=torch.float32)\n",
    "img = transform(img)\n",
    "num_patches, patch_size = img.shape\n",
    "num_patches, patch_size"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = Net(num_patches, patch_size, h_dim=128, mlp_dim=64, classes=10)\n",
    "model = model.to(device)"
   ]
  },
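  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A quick shape check on the untrained model (a sketch): one logit per class should come back for each item in the batch."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run a random batch of two dummy inputs through the network.\n",
    "dummy = torch.randn(2, num_patches, patch_size, device=device)\n",
    "with torch.no_grad():\n",
    "    logits = model(dummy)\n",
    "logits.shape  # expect torch.Size([2, 10])"
   ]
  },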
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "epochs=10\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)\n",
    "scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, \n",
    "                                                max_lr=0.005, \n",
    "                                                steps_per_epoch=len(training_loader),\n",
    "                                                epochs=epochs)\n",
    "\n",
    "def train(model, device, training_loader, epoch):\n",
    "    model.train()\n",
    "    for batch_idx, (data, target) in enumerate(training_loader):\n",
    "        data, target = data.to(device), target.to(device)\n",
    "        data = data.squeeze(1)  # drop the channel dim added by ToTensor\n",
    "        optimizer.zero_grad()\n",
    "        output = model(data)\n",
    "        loss = F.cross_entropy(output, target)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        scheduler.step()  # OneCycleLR steps once per batch\n",
    "\n",
    "    # Report the final batch of the epoch.\n",
    "    print('Train Epoch: {} [{}/{} ({:.0f}%)]\\tLoss: {:.6f}'.format(\n",
    "        epoch, batch_idx * len(data), len(training_loader.dataset),\n",
    "        100. * batch_idx / len(training_loader), loss.item()))"
   ]
  },
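  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note that `OneCycleLR` overwrites the optimizer's learning rate on every `scheduler.step()`, so the `lr=3e-4` passed to Adam is effectively ignored: the schedule ramps from `max_lr/25` up to `max_lr=0.005` and then anneals back down. As a small reference point, the model size can be checked with a one-liner (sketch):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Count trainable parameters.\n",
    "sum(p.numel() for p in model.parameters() if p.requires_grad)"
   ]
  },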
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "def test(model, device, test_loader):\n",
    "    model.eval()\n",
    "    test_loss = 0\n",
    "    correct = 0\n",
    "    with torch.no_grad():\n",
    "        for data, target in test_loader:\n",
    "            data, target = data.to(device), target.to(device)\n",
    "            data = data.squeeze(1)\n",
    "            output = model(data)\n",
    "            test_loss += F.cross_entropy(output, target, reduction='sum').item()  # sum up batch loss\n",
    "            pred = output.argmax(dim=1, keepdim=True)  # index of the max logit\n",
    "            correct += pred.eq(target.view_as(pred)).sum().item()\n",
    "\n",
    "    test_loss /= len(test_loader.dataset)\n",
    "\n",
    "    print('Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\\n'.format(\n",
    "        test_loss, correct, len(test_loader.dataset),\n",
    "        100. * correct / len(test_loader.dataset)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train Epoch: 0 [44928/60000 (100%)]\tLoss: 0.667979\n",
      "Test set: Average loss: 0.6152, Accuracy: 7976/10000 (80%)\n",
      "\n",
      "Train Epoch: 1 [44928/60000 (100%)]\tLoss: 0.493464\n",
      "Test set: Average loss: 0.3714, Accuracy: 8803/10000 (88%)\n",
      "\n",
      "Train Epoch: 2 [44928/60000 (100%)]\tLoss: 0.285482\n",
      "Test set: Average loss: 0.2962, Accuracy: 9029/10000 (90%)\n",
      "\n",
      "Train Epoch: 3 [44928/60000 (100%)]\tLoss: 0.390123\n",
      "Test set: Average loss: 0.1847, Accuracy: 9429/10000 (94%)\n",
      "\n",
      "Train Epoch: 4 [44928/60000 (100%)]\tLoss: 0.296114\n",
      "Test set: Average loss: 0.1359, Accuracy: 9590/10000 (96%)\n",
      "\n",
      "Train Epoch: 5 [44928/60000 (100%)]\tLoss: 0.085885\n",
      "Test set: Average loss: 0.0971, Accuracy: 9702/10000 (97%)\n",
      "\n",
      "Train Epoch: 6 [44928/60000 (100%)]\tLoss: 0.075379\n",
      "Test set: Average loss: 0.0793, Accuracy: 9752/10000 (98%)\n",
      "\n",
      "Train Epoch: 7 [44928/60000 (100%)]\tLoss: 0.116215\n",
      "Test set: Average loss: 0.0561, Accuracy: 9822/10000 (98%)\n",
      "\n",
      "Train Epoch: 8 [44928/60000 (100%)]\tLoss: 0.108714\n",
      "Test set: Average loss: 0.0471, Accuracy: 9858/10000 (99%)\n",
      "\n",
      "Train Epoch: 9 [44928/60000 (100%)]\tLoss: 0.073521\n",
      "Test set: Average loss: 0.0504, Accuracy: 9827/10000 (98%)\n",
      "\n"
     ]
    }
   ],
   "source": [
    "for i in range(epochs):\n",
    "    train(model, device, training_loader, i)\n",
    "    test(model, device, val_loader)"
   ]
  },
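  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Optionally, persist the trained weights with `torch.save`. The filename below is a placeholder, not part of the original notebook:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save a checkpoint; \"vit_mnist.pt\" is a placeholder name.\n",
    "torch.save(model.state_dict(), \"vit_mnist.pt\")"
   ]
  }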
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}