@briandw
Last active October 9, 2020 05:41

Notebook of the Vision Transformer paper with MNIST
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Demo of using Vision Transformers on MNIST\n",
"From the paper [AN IMAGE IS WORTH 16X16 WORDS](https://openreview.net/pdf?id=YicbFdNTTy)\n",
"\n",
"[Thanks to lucidrains](https://github.com/lucidrains/vit-pytorch)"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"import torch.optim as optim\n",
"from torchvision import datasets, transforms\n",
"from torch.utils.data import DataLoader\n",
"from torch.optim.lr_scheduler import StepLR\n",
"\n",
"from einops import rearrange\n",
"import math\n",
"from PIL.Image import Image\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"from numpy.lib import stride_tricks"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"device(type='cuda')"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"device = torch.device(\"cuda\")\n",
"device"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"train_set = datasets.MNIST('./data', train=True, download=True)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(33.318421449829934, 76.834538656172)"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"mean = 0\n",
"std = 0\n",
"for img, _ in train_set:\n",
" a = np.array(img)\n",
" mean += a.mean()\n",
" std += a.std()\n",
"\n",
"mean/len(train_set), std/len(train_set)"
]
},
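{
"cell_type": "markdown",
"metadata": {},
"source": [
"The statistics above are on the raw 0-255 scale. A quick sketch (not in the original notebook) of the conversion to the 0-1 scale that `ToTensor` produces; note this averages per-image statistics rather than computing global dataset statistics:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Rescale the per-image statistics to the 0-1 range used after ToTensor\n",
"mean/len(train_set)/255, std/len(train_set)/255  # roughly (0.131, 0.301)"
]
},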
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.image.AxesImage at 0x7fca94b3d710>"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/d3fzzAAAACXBIWXMAAAsTAAALEwEAmpwYAAAOX0lEQVR4nO3dbYxc5XnG8euKbUwxJvHGseMQFxzjFAg0Jl0ZkBFQoVCCIgGKCLGiiFBapwlOQutKUFoVWtHKrRIiSimSKS6m4iWQgPAHmsSyECRqcFmoAROHN+MS4+0aswIDIfZ6fffDjqsFdp5dZs68eO//T1rNzLnnzLk1cPmcmeeceRwRAjD5faDTDQBoD8IOJEHYgSQIO5AEYQeSmNrOjR3i6XGoZrRzk0Aqv9Fb2ht7PFatqbDbPkfS9ZKmSPrXiFhVev6hmqGTfVYzmwRQsDE21K01fBhve4qkGyV9TtLxkpbZPr7R1wPQWs18Zl8i6fmI2BoReyXdJem8atoCULVmwn6kpF+Nery9tuwdbC+33We7b0h7mtgcgGY0E/axvgR4z7m3EbE6InojoneapjexOQDNaCbs2yXNH/X445J2NNcOgFZpJuyPSlpke4HtQyR9SdK6atoCULWGh94iYp/tFZJ+rJGhtzUR8XRlnQGoVFPj7BHxgKQHKuoFQAtxuiyQBGEHkiDsQBKEHUiCsANJEHYgCcIOJEHYgSQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IgrADSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBKEHUiCsANJNDWLK7qfp5b/E0/5yOyWbv+ZPz+6bm34sP3FdY9auLNYP+wbLtb/97pD6tYe7/1+cd1dw28V6yffs7JYP+bPHinWO6GpsNveJukNScOS9kVEbxVNAaheFXv234+IXRW8DoAW4jM7kESzYQ9JP7H9mO3lYz3B9nLbfbb7hrSnyc0BaFSzh/FLI2KH7TmS1tv+ZUQ8PPoJEbFa0mpJOsI90eT2ADSoqT17ROyo3e6UdJ+kJVU0BaB6DYfd9gzbMw/cl3S2pM1VNQagWs0cxs+VdJ/tA69zR0T8qJKuJpkpxy0q1mP6tGJ9xxkfKtbfPqX+mHDPB8vjxT/9dHm8uZP+49czi/V/+OdzivWNJ95Rt/bi0NvFdVcNfLZY/9hPD75PpA2HPSK2Svp0hb0AaCGG3oAkCDuQBGEHkiDsQBKEHUiCS1wrMHzmZ4r16269sVj/5LT6l2JOZkMxXKz/9Q1fLdanvlUe/jr1nhV1azNf3ldcd/qu8tDcYX0bi/VuxJ4dSIKwA0kQdiAJwg4kQdiBJAg7kARhB5JgnL0C05/ZUaw/9pv5xfonpw1U2U6lVvafUqxvfbP8U9S3LvxB3drr+8vj5HP/6T+L9VY6+C5gHR97diAJwg4kQdiBJAg7kARhB5Ig7EAShB1IwhHtG1E8wj1xss9q2/a6xeAlpxbru88p/9zzlCcPL9af+MYN77unA67d9bvF+qNnlMfRh197vViPU+v/APG2bxVX1YJlT5SfgPfYGBu0OwbHnMuaPTuQBGEHkiDsQBKEHUiCsANJEHYgCcIOJME4exeYMvvDxfrwq4PF+ot31B8rf/r0NcV1l/z9N4v1OTd27ppyvH9NjbPbXmN7p+3No5b12F5v+7na7awqGwZQvYkcxt8q6d2z3l8paUNELJK0ofYYQBcbN+wR8bCkdx9Hnidpbe3+WknnV9sWgKo1+gXd3Ijol6Ta7Zx6T7S93Haf7b4h7WlwcwCa1fJv4yNidUT0RkTvNE1v9eYA1NFo2Adsz5Ok2u3O6loC0AqNhn2dpItr9y+WdH817QBolXF/N972nZLOlDTb9nZJV0taJelu25dKeknSha1scrIb3vVqU+sP7W58fvdPffkXxforN00pv8D+8hzr6B7jhj0iltUpcXYMcBDhdFkgCcIOJEHYgSQIO5AEYQeSYMrmSeC4K56tW7vkxPKgyb8dtaFYP+PCy4r1md9/pFhH92DPDiRB2IEkCDuQBGEHkiDsQBKEHUiCsANJMM4+CZSmTX7168cV131p3dvF+pXX3las/8UXLyjW478/WLc2/+9+XlxXbfyZ8wzYswNJEHYgCcIOJEHYgSQIO5AEYQeSIOxAEkzZnNzgH55arN9+9XeK9QVTD21425+6bUWxvujm/mJ939ZtDW97smpqymYAkwNhB5Ig7EAShB1IgrADSRB2IAnCDiTBODuKYuniYv2IVduL9Ts/8eOGt33sg39UrP/O39S/jl+Shp/b2vC2D1ZNjbPbXmN7p+3No5ZdY/tl25tqf+dW2TCA6k3kMP5WSeeMsfx7EbG49vdAtW0BqNq4YY+IhyUNtqEXAC3UzBd0K2w/WTvMn1XvSbaX2+6z3TekPU1sDkAzGg37TZIWSlosqV/Sd+s9MSJWR0RvRPRO0/QGNwegWQ2FPSIGImI4IvZLulnSkmrbAlC1hsJue96ohxdI2lzvuQC6w7jj7LbvlHSmpNmSBiRdXXu8WFJI2ibpaxFRvvhYjLNPRlPmzinWd1x0TN3axiuuL677gXH2RV9+8exi/fXTXi3WJ6PSOPu4k0RExLIxFt/SdFcA2orTZYEkCDuQBGEHkiDsQBKEHUiCS1zRMXdvL0/ZfJgPKdZ/HXuL9c9/8/L6r33fxuK6Byt+ShoAYQeyIOxAEoQdSIKwA0kQdiAJwg4kMe5Vb8ht/2mLi/UXLixP2XzC4m11a+ONo4/nhsGTivXD7u9r6vUnG/bsQBKEHUiCsANJEHYgCcIOJEHYgSQIO5AE4+yTnHtPKNaf/VZ5rPvmpWuL9dMPLV9T3ow9MVSsPzK4oPwC+8f9dfNU2LMDSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBKMsx8Epi44qlh/4ZKP1a1dc9FdxXW/cPiuhnqqwlUDvcX6Q9efUqzPWlv+3Xm807h7dtvzbT9oe4vtp21/u7a8x/Z628/Vbme1vl0AjZrIYfw+SSsj4jhJp0i6zPbxkq6UtCEiFknaUHsMoEuNG/aI6I+Ix2v335C0RdKRks6TdOBcyrWSzm9RjwAq8L6+oLN9tKSTJG2UNDci+qWRfxAkzamzznLbfbb7hrSnyXYBNGrCYbd9uKQfSro8InZPdL2IWB0RvRHRO03TG+kRQAUmFHbb0zQS9Nsj4t7a4gHb82r1eZJ2tqZFAFUYd+jNtiXdImlLRFw3qrRO0sWSVtVu729Jh5PA1KN/u1h//ffmFesX/e2PivU/+dC9xXorrewvD4/9/F/qD6/13PpfxXVn7WdorUoTGWdfKukrkp6yvam27CqNhPxu25dKeknShS3pEEAlxg17RPxM0piTu0s6q9p2ALQKp8sCSRB2IAnCDiRB2IEkCDuQBJe4TtDUeR+tWxtcM6O47tcXPFSsL5s50FBPVVjx8mnF+uM3LS7WZ/9gc7He8wZj5d2CPTuQBGEHkiDsQBKEHUiCsANJEHYgCcIOJJFmnH3vH5R/tnjvn
w4W61cd80Dd2tm/9VZDPVVlYPjturXT160srnvsX/2yWO95rTxOvr9YRTdhzw4kQdiBJAg7kARhB5Ig7EAShB1IgrADSaQZZ992fvnftWdPvKdl277xtYXF+vUPnV2se7jej/uOOPbaF+vWFg1sLK47XKxiMmHPDiRB2IEkCDuQBGEHkiDsQBKEHUiCsANJOCLKT7DnS7pN0kc1cvny6oi43vY1kv5Y0iu1p14VEfUv+pZ0hHviZDPxK9AqG2ODdsfgmCdmTOSkmn2SVkbE47ZnSnrM9vpa7XsR8Z2qGgXQOhOZn71fUn/t/hu2t0g6stWNAajW+/rMbvtoSSdJOnAO5grbT9peY3tWnXWW2+6z3TekPc11C6BhEw677cMl/VDS5RGxW9JNkhZKWqyRPf93x1ovIlZHRG9E9E7T9OY7BtCQCYXd9jSNBP32iLhXkiJiICKGI2K/pJslLWldmwCaNW7YbVvSLZK2RMR1o5bPG/W0CySVp/ME0FET+TZ+qaSvSHrK9qbasqskLbO9WFJI2ibpay3oD0BFJvJt/M8kjTVuVxxTB9BdOIMOSIKwA0kQdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IgrADSRB2IAnCDiRB2IEkCDuQxLg/JV3pxuxXJP3PqEWzJe1qWwPvT7f21q19SfTWqCp7OyoiPjJWoa1hf8/G7b6I6O1YAwXd2lu39iXRW6Pa1RuH8UAShB1IotNhX93h7Zd0a2/d2pdEb41qS28d/cwOoH06vWcH0CaEHUiiI2G3fY7tZ2w/b/vKTvRQj+1ttp+yvcl2X4d7WWN7p+3No5b12F5v+7na7Zhz7HWot2tsv1x77zbZPrdDvc23/aDtLbaftv3t2vKOvneFvtryvrX9M7vtKZKelfRZSdslPSppWUT8oq2N1GF7m6TeiOj4CRi2T5f0pqTbIuKE2rJ/lDQYEatq/1DOiogruqS3ayS92elpvGuzFc0bPc24pPMlfVUdfO8KfX1RbXjfOrFnXyLp+YjYGhF7Jd0l6bwO9NH1IuJhSYPvWnyepLW1+2s18j9L29XprStERH9EPF67/4akA9OMd/S9K/TVFp0I+5GSfjXq8XZ113zvIeknth+zvbzTzYxhbkT0SyP/80ia0+F+3m3cabzb6V3TjHfNe9fI9OfN6kTYx5pKqpvG/5ZGxGckfU7SZbXDVUzMhKbxbpcxphnvCo1Of96sToR9u6T5ox5/XNKODvQxpojYUbvdKek+dd9U1AMHZtCt3e7scD//r5um8R5rmnF1wXvXyenPOxH2RyUtsr3A9iGSviRpXQf6eA/bM2pfnMj2DElnq/umol4n6eLa/Ysl3d/BXt6hW6bxrjfNuDr83nV8+vOIaPufpHM18o38C5L+shM91OnrE5KeqP093eneJN2pkcO6IY0cEV0q6cOSNkh6rnbb00W9/bukpyQ9qZFgzetQb6dp5KPhk5I21f7O7fR7V+irLe8bp8sCSXAGHZAEYQeSIOxAEoQdSIKwA0kQdiAJwg4k8X+zhHFo7nUhhwAAAABJRU5ErkJggg==\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"img = train_set[0][0]\n",
"img_tensor = transforms.ToTensor()(img)\n",
"plt.imshow(img_tensor.squeeze())"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"class SegmentTransform(object):\n",
" \"\"\"Convert PIL to rehsaped Tensors.\"\"\"\n",
" \n",
" def __init__(self, patch_size = 4):\n",
" \"\"\"path_size defaults to patch_size x patch_size patch sizes and N elements\"\"\"\n",
" self.patch_size = patch_size\n",
" \n",
" def segment_array(self, a):\n",
" if len(a.shape) != 2:\n",
" raise Exception(\"array must be 2d\")\n",
"\n",
" side_length = a.shape[0]*a.shape[1]\n",
" \n",
" if (int(math.sqrt(side_length))**2) != side_length or side_length == 0:\n",
" raise Exception(\"Must be a perfect square\")\n",
"\n",
" num_bytes = a.dtype.itemsize\n",
" side_length = int(math.sqrt(side_length))\n",
"\n",
" new_shape = (side_length//self.patch_size, \n",
" side_length//self.patch_size, \n",
" self.patch_size, \n",
" self.patch_size)\n",
" \n",
" strides = (num_bytes * self.patch_size * side_length, \n",
" num_bytes * self.patch_size,\n",
" num_bytes * side_length, \n",
" num_bytes)\n",
" \n",
" return stride_tricks.as_strided(a, shape = new_shape, strides = strides)\n",
"\n",
" def __call__(self, img):\n",
" a = np.array(img)\n",
" a = self.segment_array(a)\n",
" a = a.reshape(-1,(self.patch_size*self.patch_size))\n",
" return a"
]
},
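{
"cell_type": "markdown",
"metadata": {},
"source": [
"The stride-tricks patching above should agree with an `einops.rearrange` over the same image. A minimal sanity check (a sketch, not in the original notebook), using the `rearrange` import from the first cell:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# The strided view flattened to (N, p*p) should match einops' patch extraction\n",
"a = np.array(train_set[0][0])\n",
"patches_strided = SegmentTransform()(train_set[0][0])\n",
"patches_einops = rearrange(a, '(h p1) (w p2) -> (h w) (p1 p2)', p1=4, p2=4)\n",
"assert np.array_equal(patches_strided, patches_einops)"
]
},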
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.image.AxesImage at 0x7fca94a44190>"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/d3fzzAAAACXBIWXMAAAsTAAALEwEAmpwYAAAOX0lEQVR4nO3dbYxc5XnG8euKbUwxJvHGseMQFxzjFAg0Jl0ZkBFQoVCCIgGKCLGiiFBapwlOQutKUFoVWtHKrRIiSimSKS6m4iWQgPAHmsSyECRqcFmoAROHN+MS4+0aswIDIfZ6fffDjqsFdp5dZs68eO//T1rNzLnnzLk1cPmcmeeceRwRAjD5faDTDQBoD8IOJEHYgSQIO5AEYQeSmNrOjR3i6XGoZrRzk0Aqv9Fb2ht7PFatqbDbPkfS9ZKmSPrXiFhVev6hmqGTfVYzmwRQsDE21K01fBhve4qkGyV9TtLxkpbZPr7R1wPQWs18Zl8i6fmI2BoReyXdJem8atoCULVmwn6kpF+Nery9tuwdbC+33We7b0h7mtgcgGY0E/axvgR4z7m3EbE6InojoneapjexOQDNaCbs2yXNH/X445J2NNcOgFZpJuyPSlpke4HtQyR9SdK6atoCULWGh94iYp/tFZJ+rJGhtzUR8XRlnQGoVFPj7BHxgKQHKuoFQAtxuiyQBGEHkiDsQBKEHUiCsANJEHYgCcIOJEHYgSQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IgrADSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBKEHUiCsANJNDWLK7qfp5b/E0/5yOyWbv+ZPz+6bm34sP3FdY9auLNYP+wbLtb/97pD6tYe7/1+cd1dw28V6yffs7JYP+bPHinWO6GpsNveJukNScOS9kVEbxVNAaheFXv234+IXRW8DoAW4jM7kESzYQ9JP7H9mO3lYz3B9nLbfbb7hrSnyc0BaFSzh/FLI2KH7TmS1tv+ZUQ8PPoJEbFa0mpJOsI90eT2ADSoqT17ROyo3e6UdJ+kJVU0BaB6DYfd9gzbMw/cl3S2pM1VNQagWs0cxs+VdJ/tA69zR0T8qJKuJpkpxy0q1mP6tGJ9xxkfKtbfPqX+mHDPB8vjxT/9dHm8uZP+49czi/V/+OdzivWNJ95Rt/bi0NvFdVcNfLZY/9hPD75PpA2HPSK2Svp0hb0AaCGG3oAkCDuQBGEHkiDsQBKEHUiCS1wrMHzmZ4r16269sVj/5LT6l2JOZkMxXKz/9Q1fLdanvlUe/jr1nhV1azNf3ldcd/qu8tDcYX0bi/VuxJ4dSIKwA0kQdiAJwg4kQdiBJAg7kARhB5JgnL0C05/ZUaw/9pv5xfonpw1U2U6lVvafUqxvfbP8U9S3LvxB3drr+8vj5HP/6T+L9VY6+C5gHR97diAJwg4kQdiBJAg7kARhB5Ig7EAShB1IwhHtG1E8wj1xss9q2/a6xeAlpxbru88p/9zzlCcPL9af+MYN77unA67d9bvF+qNnlMfRh197vViPU+v/APG2bxVX1YJlT5SfgPfYGBu0OwbHnMuaPTuQBGEHkiDsQBKEHUiCsANJEHYgCcIOJME4exeYMvvDxfrwq4PF+ot31B8rf/r0NcV1l/z9N4v1OTd27ppyvH9NjbPbXmN7p+3No5b12F5v+7na7awqGwZQvYkcxt8q6d2z3l8paUNELJK0ofYYQBcbN+wR8bCkdx9Hnidpbe3+WknnV9sWgKo1+gXd3Ijol6Ta7Zx6T7S93Haf7b4h7WlwcwCa1fJv4yNidUT0RkTvNE1v9eYA1NFo2Adsz5Ok2u3O6loC0AqNhn2dpItr9y+WdH817QBolXF/N972nZLOlDTb9nZJV0taJelu25dKeknSha1scrIb3vVqU+sP7W58fvdPffkXxforN00pv8D+8hzr6B7jhj0iltUpcXYMcBDhdFkgCcIOJEHYgSQIO5AEYQeSYMrmSeC4K56tW7vkxPKgyb8dtaFYP+PCy4r1md9/pFhH92DPDiRB2IEkCDuQBGEHkiDsQBKEHUiCsANJMM4+CZSmTX7168cV131p3dvF+pXX3las/8UXLyjW478/WLc2/+9+XlxXbfyZ8wzYswNJEHYgCcIOJEHYgSQIO5AEYQeSIOxAEkzZnNzgH55arN9+9XeK9QVTD21425+6bUWxvujm/mJ939ZtDW97smpqymYAkwNhB5Ig7EAShB1IgrADSRB2IAnCDiTBODuKYuniYv2IVduL9Ts/8eOGt33sg39UrP/O39S/jl+Shp/b2vC2D1ZNjbPbXmN7p+3No5ZdY/tl25tqf+dW2TCA6k3kMP5WSeeMsfx7EbG49vdAtW0BqNq4YY+IhyUNtqEXAC3UzBd0K2w/WTvMn1XvSbaX2+6z3TekPU1sDkAzGg37TZIWSlosqV/Sd+s9MSJWR0RvRPRO0/QGNwegWQ2FPSIGImI4IvZLulnSkmrbAlC1hsJue96ohxdI2lzvuQC6w7jj7LbvlHSmpNmSBiRdXXu8WFJI2ibpaxFRvvhYjLNPRlPmzinWd1x0TN3axiuuL677gXH2RV9+8exi/fXTXi3WJ6PSOPu4k0RExLIxFt/SdFcA2orTZYEkCDuQBGEHkiDsQBKEHUiCS1zRMXdvL0/ZfJgPKdZ/HXuL9c9/8/L6r33fxuK6Byt+ShoAYQeyIOxAEoQdSIKwA0kQdiAJwg4kMe5Vb8ht/2mLi/UXLixP2XzC4m11a+ONo4/nhsGTivXD7u9r6vUnG/bsQBKEHUiCsANJEHYgCcIOJEHYgSQIO5AE4+yTnHtPKNaf/VZ5rPvmpWuL9dMPLV9T3ow9MVSsPzK4oPwC+8f9dfNU2LMDSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBKMsx8Epi44qlh/4ZKP1a1dc9FdxXW/cPiuhnqqwlUDvcX6Q9efUqzPWlv+3Xm807h7dtvzbT9oe4vtp21/u7a8x/Z628/Vbme1vl0AjZrIYfw+SSsj4jhJp0i6zPbxkq6UtCEiFknaUHsMoEuNG/aI6I+Ix2v335C0RdKRks6TdOBcyrWSzm9RjwAq8L6+oLN9tKSTJG2UNDci+qWRfxAkzamzznLbfbb7hrSnyXYBNGrCYbd9uKQfSro8InZPdL2IWB0RvRHRO03TG+kRQAUmFHbb0zQS9Nsj4t7a4gHb82r1eZJ2tqZFAFUYd+jNtiXdImlLRFw3qrRO0sWSVtVu729Jh5PA1KN/u1h//ffmFesX/e2PivU/+dC9xXorrewvD4/9/F/qD6/13PpfxXVn7WdorUoTGWdfKukrkp6yvam27CqNhPxu25dKeknShS3pEEAlxg17RPxM0piTu0s6q9p2ALQKp8sCSRB2IAnCDiRB2IEkCDuQBJe4TtDUeR+tWxtcM6O47tcXPFSsL5s50FBPVVjx8mnF+uM3LS7WZ/9gc7He8wZj5d2CPTuQBGEHkiDsQBKEHUiCsANJEHYgCcIOJJFmnH3vH5R/tnjvn
w4W61cd80Dd2tm/9VZDPVVlYPjturXT160srnvsX/2yWO95rTxOvr9YRTdhzw4kQdiBJAg7kARhB5Ig7EAShB1IgrADSaQZZ992fvnftWdPvKdl277xtYXF+vUPnV2se7jej/uOOPbaF+vWFg1sLK47XKxiMmHPDiRB2IEkCDuQBGEHkiDsQBKEHUiCsANJOCLKT7DnS7pN0kc1cvny6oi43vY1kv5Y0iu1p14VEfUv+pZ0hHviZDPxK9AqG2ODdsfgmCdmTOSkmn2SVkbE47ZnSnrM9vpa7XsR8Z2qGgXQOhOZn71fUn/t/hu2t0g6stWNAajW+/rMbvtoSSdJOnAO5grbT9peY3tWnXWW2+6z3TekPc11C6BhEw677cMl/VDS5RGxW9JNkhZKWqyRPf93x1ovIlZHRG9E9E7T9OY7BtCQCYXd9jSNBP32iLhXkiJiICKGI2K/pJslLWldmwCaNW7YbVvSLZK2RMR1o5bPG/W0CySVp/ME0FET+TZ+qaSvSHrK9qbasqskLbO9WFJI2ibpay3oD0BFJvJt/M8kjTVuVxxTB9BdOIMOSIKwA0kQdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IgrADSRB2IAnCDiRB2IEkCDuQxLg/JV3pxuxXJP3PqEWzJe1qWwPvT7f21q19SfTWqCp7OyoiPjJWoa1hf8/G7b6I6O1YAwXd2lu39iXRW6Pa1RuH8UAShB1IotNhX93h7Zd0a2/d2pdEb41qS28d/cwOoH06vWcH0CaEHUiiI2G3fY7tZ2w/b/vKTvRQj+1ttp+yvcl2X4d7WWN7p+3No5b12F5v+7na7Zhz7HWot2tsv1x77zbZPrdDvc23/aDtLbaftv3t2vKOvneFvtryvrX9M7vtKZKelfRZSdslPSppWUT8oq2N1GF7m6TeiOj4CRi2T5f0pqTbIuKE2rJ/lDQYEatq/1DOiogruqS3ayS92elpvGuzFc0bPc24pPMlfVUdfO8KfX1RbXjfOrFnXyLp+YjYGhF7Jd0l6bwO9NH1IuJhSYPvWnyepLW1+2s18j9L29XprStERH9EPF67/4akA9OMd/S9K/TVFp0I+5GSfjXq8XZ113zvIeknth+zvbzTzYxhbkT0SyP/80ia0+F+3m3cabzb6V3TjHfNe9fI9OfN6kTYx5pKqpvG/5ZGxGckfU7SZbXDVUzMhKbxbpcxphnvCo1Of96sToR9u6T5ox5/XNKODvQxpojYUbvdKek+dd9U1AMHZtCt3e7scD//r5um8R5rmnF1wXvXyenPOxH2RyUtsr3A9iGSviRpXQf6eA/bM2pfnMj2DElnq/umol4n6eLa/Ysl3d/BXt6hW6bxrjfNuDr83nV8+vOIaPufpHM18o38C5L+shM91OnrE5KeqP093eneJN2pkcO6IY0cEV0q6cOSNkh6rnbb00W9/bukpyQ9qZFgzetQb6dp5KPhk5I21f7O7fR7V+irLe8bp8sCSXAGHZAEYQeSIOxAEoQdSIKwA0kQdiAJwg4k8X+zhHFo7nUhhwAAAABJRU5ErkJggg==\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"transform = SegmentTransform()\n",
"reconstructed = torch.zeros((28,28), dtype=torch.float)\n",
"for i, patch in enumerate(transform(img)):\n",
" x = i % 7 * 4\n",
" y = int(i / 7) * 4\n",
" for k in range(4):\n",
" for j in range(4):\n",
" value = patch[j + k * 4]\n",
" reconstructed[x + j][y + k] = value\n",
" \n",
"plt.imshow(reconstructed.T)"
]
},
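{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick check (a sketch, not in the original notebook) that the round trip through `SegmentTransform` is lossless: the transposed reconstruction should equal the raw image exactly."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# The loop above fills reconstructed in (column, row) order, hence the transpose\n",
"assert np.array_equal(reconstructed.T.numpy(), np.array(img, dtype=np.float32))"
]
},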
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"train_transform = transforms.Compose(\n",
" [transforms.RandomAffine(20, scale=(0.9, 1.1), translate=(0.1, 0.1)),\n",
" SegmentTransform(), \n",
" transforms.ToTensor(),\n",
" transforms.Normalize((0.5), (0.5))])\n",
"\n",
"train = datasets.MNIST('./data', train=True, download=True, transform=train_transform)\n",
"training_loader = DataLoader(train, batch_size=128, shuffle=True, num_workers=4)\n",
"\n",
"val_transform = transforms.Compose(\n",
" [transforms.RandomAffine(0, translate=(0.1, 0.1)),\n",
" SegmentTransform(),\n",
" transforms.ToTensor(),\n",
" transforms.Normalize((0.5), (0.5))])\n",
"\n",
"val = datasets.MNIST('./data', train=False, download=True, transform=val_transform)\n",
"val_loader = DataLoader(val, batch_size=64, shuffle=True, num_workers=4)"
]
},
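{
"cell_type": "markdown",
"metadata": {},
"source": [
"`ToTensor` adds a channel dimension, so each sample comes out as `(1, 49, 16)` and a training batch as `(128, 1, 49, 16)`; the training loop below squeezes that channel axis away. A quick inspection (a sketch, not in the original notebook):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Peek at one batch to see the shape the model will receive after squeeze(1)\n",
"data, target = next(iter(training_loader))\n",
"data.shape, data.squeeze(1).shape, target.shape"
]
},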
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"class Residual(nn.Module):\n",
" def __init__(self, fn):\n",
" super().__init__()\n",
" self.fn = fn\n",
" def forward(self, x):\n",
" return self.fn(x) + x\n",
"\n",
"class PreNorm(nn.Module):\n",
" def __init__(self, dim, fn):\n",
" super().__init__()\n",
" self.norm = nn.LayerNorm(dim)\n",
" self.fn = fn\n",
" def forward(self, x):\n",
" return self.fn(self.norm(x))\n",
" \n",
"class FeedForward(nn.Module):\n",
" def __init__(self, dim, hidden_dim, out_dim):\n",
" super().__init__()\n",
" self.net = nn.Sequential(\n",
" nn.Linear(dim, hidden_dim),\n",
" nn.GELU(),\n",
" nn.Linear(hidden_dim, out_dim)\n",
" )\n",
" def forward(self, x):\n",
" return self.net(x)\n",
"\n",
"class Attention(nn.Module):\n",
" def __init__(self, dim, heads = 8):\n",
" super().__init__()\n",
" self.heads = heads\n",
" self.scale = dim ** -0.5\n",
"\n",
" self.to_qkv = nn.Linear(dim, dim * 3, bias = False)\n",
" self.to_out = nn.Linear(dim, dim)\n",
" \n",
" def forward(self, x):\n",
" b, n, _, h = *x.shape, self.heads\n",
" qkv = self.to_qkv(x)\n",
" q, k, v = rearrange(qkv, 'b n (qkv h d) -> qkv b h n d', qkv = 3, h = h)\n",
"\n",
" dots = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale\n",
" attn = dots.softmax(dim=-1)\n",
"\n",
" out = torch.einsum('bhij,bhjd->bhid', attn, v)\n",
" out = rearrange(out, 'b h n d -> b n (h d)')\n",
" out = self.to_out(out)\n",
" return out\n",
"\n",
"class Transformer(nn.Module):\n",
" def __init__(self, dim, depth, heads, mlp_dim):\n",
" super().__init__()\n",
" layers = []\n",
" for _ in range(depth):\n",
" layers.extend([\n",
" Residual(PreNorm(dim, Attention(dim, heads = heads))),\n",
" Residual(PreNorm(dim, FeedForward(dim, mlp_dim, dim)))\n",
" ])\n",
" self.net = nn.Sequential(*layers)\n",
" def forward(self, x):\n",
" return self.net(x)\n",
"\n",
"class Net(nn.Module):\n",
" def __init__(self, num_patches, patch_size, h_dim, mlp_dim, classes, t_depth=2, t_heads=4):\n",
" super(Net, self).__init__()\n",
" self.num_patches = num_patches\n",
" self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, h_dim))\n",
" self.image_projection = nn.Linear(patch_size, h_dim)\n",
" self.cls_token = nn.Parameter(torch.randn(1, 1, h_dim))\n",
" self.t1 = Transformer(h_dim, t_depth, t_heads, mlp_dim)\n",
" self.fc = FeedForward(h_dim, mlp_dim, classes)\n",
"\n",
" def forward(self, x):\n",
" x = self.image_projection(x)\n",
" cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)\n",
" x = torch.cat((cls_tokens, x), dim=1)\n",
" x += self.pos_embedding\n",
" x = self.t1(x)\n",
" return self.fc(x[:, 0])"
]
},
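{
"cell_type": "markdown",
"metadata": {},
"source": [
"Attention is shape-preserving: a `(batch, tokens, dim)` input comes back as `(batch, tokens, dim)`, which is what lets the `Residual` wrapper add its input to its output. A minimal sanity check (a sketch, not in the original notebook):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# A dummy batch of 2 sequences of 50 tokens (49 patches + 1 cls) with dim 128\n",
"dummy = torch.randn(2, 50, 128)\n",
"assert Attention(128, heads=4)(dummy).shape == dummy.shape\n",
"assert Transformer(128, depth=2, heads=4, mlp_dim=64)(dummy).shape == dummy.shape"
]
},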
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(49, 16)"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"image_size = 28\n",
"img = torch.tensor(np.arange(image_size*image_size).reshape(image_size, image_size), dtype=torch.float32)\n",
"img = transform(img)\n",
"num_patches, patch_size = img.shape\n",
"num_patches, patch_size"
]
},
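{
"cell_type": "markdown",
"metadata": {},
"source": [
"As expected for a 28x28 image with 4x4 patches: 28/4 = 7 patches per side gives 7 * 7 = 49 patches, each flattened to 4 * 4 = 16 values."
]
},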
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"model = Net(num_patches, patch_size, h_dim=128, mlp_dim=64, classes=10)\n",
"model = model.to(device)"
]
},
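{
"cell_type": "markdown",
"metadata": {},
"source": [
"The model is small; a quick parameter count (a sketch, not in the original notebook):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Total trainable parameters in the ViT classifier\n",
"sum(p.numel() for p in model.parameters() if p.requires_grad)"
]
},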
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"epochs=10\n",
"optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)\n",
"scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, \n",
" max_lr=0.005, \n",
" steps_per_epoch=len(training_loader),\n",
" epochs=epochs)\n",
"\n",
"def train(model, device, training_loader, epoch):\n",
" model.train()\n",
" for batch_idx, (data, target) in enumerate(training_loader):\n",
" data, target = data.to(device), target.to(device)\n",
" data = data.squeeze(1)\n",
" optimizer.zero_grad()\n",
" output = model(data)\n",
" loss = F.cross_entropy(output, target)\n",
" loss.backward()\n",
" optimizer.step()\n",
" scheduler.step()\n",
" \n",
" print('Train Epoch: {} [{}/{} ({:.0f}%)]\\tLoss: {:.6f}'.format(\n",
" epoch, batch_idx * len(data), len(training_loader.dataset),\n",
" 100. * batch_idx / len(training_loader), loss.item()))"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"def test(model, device, test_loader):\n",
" model.eval()\n",
" test_loss = 0\n",
" correct = 0\n",
" with torch.no_grad():\n",
" for data, target in test_loader:\n",
" data, target = data.to(device), target.to(device)\n",
" data = data.squeeze(1)\n",
" output = model(data)\n",
" test_loss += F.cross_entropy(output, target, reduction='sum').item() # sum up batch loss\n",
" pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability\n",
" correct += pred.eq(target.view_as(pred)).sum().item()\n",
"\n",
" test_loss /= len(test_loader.dataset)\n",
"\n",
" print('Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\\n'.format(\n",
" test_loss, correct, len(test_loader.dataset),\n",
" 100. * correct / len(test_loader.dataset)))"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Train Epoch: 0 [44928/60000 (100%)]\tLoss: 0.667979\n",
"Test set: Average loss: 0.6152, Accuracy: 7976/10000 (80%)\n",
"\n",
"Train Epoch: 1 [44928/60000 (100%)]\tLoss: 0.493464\n",
"Test set: Average loss: 0.3714, Accuracy: 8803/10000 (88%)\n",
"\n",
"Train Epoch: 2 [44928/60000 (100%)]\tLoss: 0.285482\n",
"Test set: Average loss: 0.2962, Accuracy: 9029/10000 (90%)\n",
"\n",
"Train Epoch: 3 [44928/60000 (100%)]\tLoss: 0.390123\n",
"Test set: Average loss: 0.1847, Accuracy: 9429/10000 (94%)\n",
"\n",
"Train Epoch: 4 [44928/60000 (100%)]\tLoss: 0.296114\n",
"Test set: Average loss: 0.1359, Accuracy: 9590/10000 (96%)\n",
"\n",
"Train Epoch: 5 [44928/60000 (100%)]\tLoss: 0.085885\n",
"Test set: Average loss: 0.0971, Accuracy: 9702/10000 (97%)\n",
"\n",
"Train Epoch: 6 [44928/60000 (100%)]\tLoss: 0.075379\n",
"Test set: Average loss: 0.0793, Accuracy: 9752/10000 (98%)\n",
"\n",
"Train Epoch: 7 [44928/60000 (100%)]\tLoss: 0.116215\n",
"Test set: Average loss: 0.0561, Accuracy: 9822/10000 (98%)\n",
"\n",
"Train Epoch: 8 [44928/60000 (100%)]\tLoss: 0.108714\n",
"Test set: Average loss: 0.0471, Accuracy: 9858/10000 (99%)\n",
"\n",
"Train Epoch: 9 [44928/60000 (100%)]\tLoss: 0.073521\n",
"Test set: Average loss: 0.0504, Accuracy: 9827/10000 (98%)\n",
"\n"
]
}
],
"source": [
"for i in range(epochs):\n",
" train(model, device, training_loader, i)\n",
" test(model, device, val_loader)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.7"
}
},
"nbformat": 4,
"nbformat_minor": 4
}