Implementation of SRGAN in PyTorch. Super-resolves images to 4x their original resolution.
Based on this paper: https://arxiv.org/pdf/1609.04802.pdf
Model is to be trained for 150 epochs.
| { | |
| "nbformat": 4, | |
| "nbformat_minor": 0, | |
| "metadata": { | |
| "kernelspec": { | |
| "display_name": "Python 3", | |
| "language": "python", | |
| "name": "python3" | |
| }, | |
| "language_info": { | |
| "codemirror_mode": { | |
| "name": "ipython", | |
| "version": 3 | |
| }, | |
| "file_extension": ".py", | |
| "mimetype": "text/x-python", | |
| "name": "python", | |
| "nbconvert_exporter": "python", | |
| "pygments_lexer": "ipython3", | |
| "version": "3.9.1" | |
| }, | |
| "colab": { | |
| "name": "srgan.ipynb", | |
| "provenance": [], | |
| "collapsed_sections": [], | |
| "toc_visible": true, | |
| "include_colab_link": true | |
| }, | |
| "accelerator": "GPU", | |
| "widgets": { | |
| "application/vnd.jupyter.widget-state+json": { | |
| "5b14de92c3bb43f5b769eb8b8918d393": { | |
| "model_module": "@jupyter-widgets/controls", | |
| "model_name": "HBoxModel", | |
| "model_module_version": "1.5.0", | |
| "state": { | |
| "_view_name": "HBoxView", | |
| "_dom_classes": [], | |
| "_model_name": "HBoxModel", | |
| "_view_module": "@jupyter-widgets/controls", | |
| "_model_module_version": "1.5.0", | |
| "_view_count": null, | |
| "_view_module_version": "1.5.0", | |
| "box_style": "", | |
| "layout": "IPY_MODEL_92a437f3d4d74fab94d4dfb34cad8cfc", | |
| "_model_module": "@jupyter-widgets/controls", | |
| "children": [ | |
| "IPY_MODEL_6047442ef1464494a1775e852d98de6e", | |
| "IPY_MODEL_35baa95ba9434b2f94fbd8335df48a41" | |
| ] | |
| } | |
| }, | |
| "92a437f3d4d74fab94d4dfb34cad8cfc": { | |
| "model_module": "@jupyter-widgets/base", | |
| "model_name": "LayoutModel", | |
| "model_module_version": "1.2.0", | |
| "state": { | |
| "_view_name": "LayoutView", | |
| "grid_template_rows": null, | |
| "right": null, | |
| "justify_content": null, | |
| "_view_module": "@jupyter-widgets/base", | |
| "overflow": null, | |
| "_model_module_version": "1.2.0", | |
| "_view_count": null, | |
| "flex_flow": null, | |
| "width": null, | |
| "min_width": null, | |
| "border": null, | |
| "align_items": null, | |
| "bottom": null, | |
| "_model_module": "@jupyter-widgets/base", | |
| "top": null, | |
| "grid_column": null, | |
| "overflow_y": null, | |
| "overflow_x": null, | |
| "grid_auto_flow": null, | |
| "grid_area": null, | |
| "grid_template_columns": null, | |
| "flex": null, | |
| "_model_name": "LayoutModel", | |
| "justify_items": null, | |
| "grid_row": null, | |
| "max_height": null, | |
| "align_content": null, | |
| "visibility": null, | |
| "align_self": null, | |
| "height": null, | |
| "min_height": null, | |
| "padding": null, | |
| "grid_auto_rows": null, | |
| "grid_gap": null, | |
| "max_width": null, | |
| "order": null, | |
| "_view_module_version": "1.2.0", | |
| "grid_template_areas": null, | |
| "object_position": null, | |
| "object_fit": null, | |
| "grid_auto_columns": null, | |
| "margin": null, | |
| "display": null, | |
| "left": null | |
| } | |
| }, | |
| "6047442ef1464494a1775e852d98de6e": { | |
| "model_module": "@jupyter-widgets/controls", | |
| "model_name": "FloatProgressModel", | |
| "model_module_version": "1.5.0", | |
| "state": { | |
| "_view_name": "ProgressView", | |
| "style": "IPY_MODEL_3b3fcd76032c4f4997e819e68f7ddef0", | |
| "_dom_classes": [], | |
| "description": "100%", | |
| "_model_name": "FloatProgressModel", | |
| "bar_style": "success", | |
| "max": 553433881, | |
| "_view_module": "@jupyter-widgets/controls", | |
| "_model_module_version": "1.5.0", | |
| "value": 553433881, | |
| "_view_count": null, | |
| "_view_module_version": "1.5.0", | |
| "orientation": "horizontal", | |
| "min": 0, | |
| "description_tooltip": null, | |
| "_model_module": "@jupyter-widgets/controls", | |
| "layout": "IPY_MODEL_b201fa430ceb42e3bc438dc673856622" | |
| } | |
| }, | |
| "35baa95ba9434b2f94fbd8335df48a41": { | |
| "model_module": "@jupyter-widgets/controls", | |
| "model_name": "HTMLModel", | |
| "model_module_version": "1.5.0", | |
| "state": { | |
| "_view_name": "HTMLView", | |
| "style": "IPY_MODEL_39c14f33da7a4d85bfcb187c0cacd68a", | |
| "_dom_classes": [], | |
| "description": "", | |
| "_model_name": "HTMLModel", | |
| "placeholder": "", | |
| "_view_module": "@jupyter-widgets/controls", | |
| "_model_module_version": "1.5.0", | |
| "value": " 528M/528M [00:09<00:00, 55.6MB/s]", | |
| "_view_count": null, | |
| "_view_module_version": "1.5.0", | |
| "description_tooltip": null, | |
| "_model_module": "@jupyter-widgets/controls", | |
| "layout": "IPY_MODEL_638747b422554adb8c64e087f52f4fd6" | |
| } | |
| }, | |
| "3b3fcd76032c4f4997e819e68f7ddef0": { | |
| "model_module": "@jupyter-widgets/controls", | |
| "model_name": "ProgressStyleModel", | |
| "model_module_version": "1.5.0", | |
| "state": { | |
| "_view_name": "StyleView", | |
| "_model_name": "ProgressStyleModel", | |
| "description_width": "initial", | |
| "_view_module": "@jupyter-widgets/base", | |
| "_model_module_version": "1.5.0", | |
| "_view_count": null, | |
| "_view_module_version": "1.2.0", | |
| "bar_color": null, | |
| "_model_module": "@jupyter-widgets/controls" | |
| } | |
| }, | |
| "b201fa430ceb42e3bc438dc673856622": { | |
| "model_module": "@jupyter-widgets/base", | |
| "model_name": "LayoutModel", | |
| "model_module_version": "1.2.0", | |
| "state": { | |
| "_view_name": "LayoutView", | |
| "grid_template_rows": null, | |
| "right": null, | |
| "justify_content": null, | |
| "_view_module": "@jupyter-widgets/base", | |
| "overflow": null, | |
| "_model_module_version": "1.2.0", | |
| "_view_count": null, | |
| "flex_flow": null, | |
| "width": null, | |
| "min_width": null, | |
| "border": null, | |
| "align_items": null, | |
| "bottom": null, | |
| "_model_module": "@jupyter-widgets/base", | |
| "top": null, | |
| "grid_column": null, | |
| "overflow_y": null, | |
| "overflow_x": null, | |
| "grid_auto_flow": null, | |
| "grid_area": null, | |
| "grid_template_columns": null, | |
| "flex": null, | |
| "_model_name": "LayoutModel", | |
| "justify_items": null, | |
| "grid_row": null, | |
| "max_height": null, | |
| "align_content": null, | |
| "visibility": null, | |
| "align_self": null, | |
| "height": null, | |
| "min_height": null, | |
| "padding": null, | |
| "grid_auto_rows": null, | |
| "grid_gap": null, | |
| "max_width": null, | |
| "order": null, | |
| "_view_module_version": "1.2.0", | |
| "grid_template_areas": null, | |
| "object_position": null, | |
| "object_fit": null, | |
| "grid_auto_columns": null, | |
| "margin": null, | |
| "display": null, | |
| "left": null | |
| } | |
| }, | |
| "39c14f33da7a4d85bfcb187c0cacd68a": { | |
| "model_module": "@jupyter-widgets/controls", | |
| "model_name": "DescriptionStyleModel", | |
| "model_module_version": "1.5.0", | |
| "state": { | |
| "_view_name": "StyleView", | |
| "_model_name": "DescriptionStyleModel", | |
| "description_width": "", | |
| "_view_module": "@jupyter-widgets/base", | |
| "_model_module_version": "1.5.0", | |
| "_view_count": null, | |
| "_view_module_version": "1.2.0", | |
| "_model_module": "@jupyter-widgets/controls" | |
| } | |
| }, | |
| "638747b422554adb8c64e087f52f4fd6": { | |
| "model_module": "@jupyter-widgets/base", | |
| "model_name": "LayoutModel", | |
| "model_module_version": "1.2.0", | |
| "state": { | |
| "_view_name": "LayoutView", | |
| "grid_template_rows": null, | |
| "right": null, | |
| "justify_content": null, | |
| "_view_module": "@jupyter-widgets/base", | |
| "overflow": null, | |
| "_model_module_version": "1.2.0", | |
| "_view_count": null, | |
| "flex_flow": null, | |
| "width": null, | |
| "min_width": null, | |
| "border": null, | |
| "align_items": null, | |
| "bottom": null, | |
| "_model_module": "@jupyter-widgets/base", | |
| "top": null, | |
| "grid_column": null, | |
| "overflow_y": null, | |
| "overflow_x": null, | |
| "grid_auto_flow": null, | |
| "grid_area": null, | |
| "grid_template_columns": null, | |
| "flex": null, | |
| "_model_name": "LayoutModel", | |
| "justify_items": null, | |
| "grid_row": null, | |
| "max_height": null, | |
| "align_content": null, | |
| "visibility": null, | |
| "align_self": null, | |
| "height": null, | |
| "min_height": null, | |
| "padding": null, | |
| "grid_auto_rows": null, | |
| "grid_gap": null, | |
| "max_width": null, | |
| "order": null, | |
| "_view_module_version": "1.2.0", | |
| "grid_template_areas": null, | |
| "object_position": null, | |
| "object_fit": null, | |
| "grid_auto_columns": null, | |
| "margin": null, | |
| "display": null, | |
| "left": null | |
| } | |
| } | |
| } | |
| } | |
| }, | |
| "cells": [ | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "id": "view-in-github", | |
| "colab_type": "text" | |
| }, | |
| "source": [ | |
| "<a href=\"https://colab.research.google.com/gist/SharanSMenon/7afd37c9cac76a736fd1a592966608c0/srgan.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "id": "wtSTQigkkUSR" | |
| }, | |
| "source": [ | |
| "# Super resolution with SRGAN" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "id": "OLfJgUAekUST" | |
| }, | |
| "source": [ | |
| "**Dataset Link**: http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_train_HR.zip" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "nggOyQPRke9t" | |
| }, | |
| "source": [ | |
| "!wget http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_train_HR.zip\n", | |
| "!unzip DIV2K_train_HR.zip" | |
| ], | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "id": "ll7Qe9DGkUST" | |
| }, | |
| "source": [ | |
| "## Imports" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "kEbzONOhkUST" | |
| }, | |
| "source": [ | |
| "import torch\n", | |
| "import math\n", | |
| "from os import listdir\n", | |
| "import numpy as np\n", | |
| "from torch.autograd import Variable" | |
| ], | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "qI1niZaykUST" | |
| }, | |
| "source": [ | |
| "from torchvision.transforms import Compose, RandomCrop, ToTensor, ToPILImage, CenterCrop, Resize" | |
| ], | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "2av_Pl2SkUSU" | |
| }, | |
| "source": [ | |
| "from torch.utils.data import DataLoader, Dataset" | |
| ], | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "RaYLriLqkUSU" | |
| }, | |
| "source": [ | |
| "from srgandata import TrainDatasetFromFolder, ValDatasetFromFolder, TestDatasetFromFolder, display_transform" | |
| ], | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "colab": { | |
| "base_uri": "https://localhost:8080/" | |
| }, | |
| "id": "UJ0FVu-NkUSU", | |
| "outputId": "29ac2361-0f74-4c20-bcde-f58487383694" | |
| }, | |
| "source": [ | |
| "torch.autograd.set_detect_anomaly(True)" | |
| ], | |
| "execution_count": null, | |
| "outputs": [ | |
| { | |
| "output_type": "execute_result", | |
| "data": { | |
| "text/plain": [ | |
| "<torch.autograd.anomaly_mode.set_detect_anomaly at 0x7f71f6ead908>" | |
| ] | |
| }, | |
| "metadata": { | |
| "tags": [] | |
| }, | |
| "execution_count": 4 | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "id": "4OQpaNjXkUSU" | |
| }, | |
| "source": [ | |
| "## Dataset" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "C8i4bP25kUSU" | |
| }, | |
| "source": [ | |
| "UPSCALE_FACTOR = 4\n", | |
| "CROP_SIZE = 88" | |
| ], | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "3OkxFCSYkUSU" | |
| }, | |
| "source": [ | |
| "mean = np.array([0.485, 0.456, 0.406])\n", | |
| "std = np.array([0.229, 0.224, 0.225])" | |
| ], | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "id": "F8DP8DVokUSV" | |
| }, | |
| "source": [ | |
| "Makes a low resolution image" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "cmejgK8xkUSV" | |
| }, | |
| "source": [ | |
| "train_set = TrainDatasetFromFolder('DIV2K_train_HR', crop_size=CROP_SIZE, upscale_factor=UPSCALE_FACTOR)\n", | |
| "# val_set = ValDatasetFromFolder('DIV2K_valid_HR', upscale_factor=UPSCALE_FACTOR)\n", | |
| "train_loader = DataLoader(dataset=train_set, num_workers=4, batch_size=64, shuffle=True)\n", | |
| "# val_loader = DataLoader(dataset=val_set, num_workers=4, batch_size=1, shuffle=False)" | |
| ], | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "-VSfRAQWkUSV" | |
| }, | |
| "source": [ | |
| "import matplotlib.pyplot as plt\n", | |
| "%matplotlib inline" | |
| ], | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "id": "12sK5ob7kUSV" | |
| }, | |
| "source": [ | |
| "## Model" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "ENEfy134kUSV" | |
| }, | |
| "source": [ | |
| "from torch import nn, optim" | |
| ], | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "xGPGQ_BpkUSV" | |
| }, | |
| "source": [ | |
class ResidualBlock(nn.Module):
    """Residual block used in the SRGAN generator.

    Two 3x3 convolutions with batch norm and a PReLU in between; the learned
    residual is added back onto the input via an identity skip connection, so
    the output has exactly the same shape as the input.
    """

    def __init__(self, channels):
        super(ResidualBlock, self).__init__()
        # Kernel 3 with padding 1 keeps the spatial size, and the channel
        # count is unchanged, so the skip connection lines up element-wise.
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(channels)
        self.prelu = nn.PReLU()
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x):
        """Return ``x + F(x)`` where F is conv-bn-prelu-conv-bn."""
        out = self.bn2(self.conv2(self.prelu(self.bn1(self.conv1(x)))))
        return x + out
| ], | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "6ceKahB2kUSV" | |
| }, | |
| "source": [ | |
class UpsampleBLock(nn.Module):
    """Sub-pixel (PixelShuffle) upsampling block.

    A 3x3 convolution expands the channels by ``up_scale ** 2``, PixelShuffle
    rearranges those channels into an ``up_scale``-times larger feature map,
    and a PReLU adds nonlinearity. The channel count is preserved overall:
    (N, C, H, W) -> (N, C, H * up_scale, W * up_scale).

    NOTE(review): the class name keeps the original "BLock" capitalization
    because other cells construct it by this exact name.
    """

    def __init__(self, in_channels, up_scale):
        super(UpsampleBLock, self).__init__()
        self.conv = nn.Conv2d(in_channels, in_channels * up_scale ** 2, kernel_size=3, padding=1)
        self.pixel_shuffle = nn.PixelShuffle(up_scale)
        self.prelu = nn.PReLU()

    def forward(self, x):
        """Upsample ``x`` spatially by the configured factor."""
        return self.prelu(self.pixel_shuffle(self.conv(x)))
| ], | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "E9sHfbpRkUSW" | |
| }, | |
| "source": [ | |
class Generator(nn.Module):
    """SRGAN generator.

    Architecture: a 9x9 input conv + PReLU, five residual blocks, a 3x3
    conv + BN rejoined to the first conv's output by a long skip connection,
    then log2(scale_factor) sub-pixel 2x upsampling stages and a final 9x9
    conv. The output is squashed into [0, 1] via a rescaled tanh.
    """

    def __init__(self, scale_factor):
        # Number of 2x upsampling stages needed to reach scale_factor
        # (assumes scale_factor is a power of two).
        upsample_block_num = int(math.log(scale_factor, 2))

        super(Generator, self).__init__()
        self.block1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=9, padding=4),
            nn.PReLU()
        )
        self.block2 = ResidualBlock(64)
        self.block3 = ResidualBlock(64)
        self.block4 = ResidualBlock(64)
        self.block5 = ResidualBlock(64)
        self.block6 = ResidualBlock(64)
        self.block7 = nn.Sequential(
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64)
        )
        block8 = [UpsampleBLock(64, 2) for _ in range(upsample_block_num)]
        block8.append(nn.Conv2d(64, 3, kernel_size=9, padding=4))
        self.block8 = nn.Sequential(*block8)

    def forward(self, x):
        """Map a low-resolution image batch to a super-resolved one in [0, 1]."""
        features = self.block1(x)
        out = features
        # Residual trunk: five identical residual blocks in sequence.
        for block in (self.block2, self.block3, self.block4, self.block5, self.block6):
            out = block(out)
        out = self.block7(out)
        # Long skip connection from the first conv around the whole trunk.
        out = self.block8(features + out)
        # Shift tanh's (-1, 1) range into (0, 1) pixel values.
        return (torch.tanh(out) + 1) / 2
| ], | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "5v9c_NYNkUSW" | |
| }, | |
| "source": [ | |
class Discriminator(nn.Module):
    """SRGAN discriminator.

    A VGG-style stack of 3x3 convolutions that halves the spatial size four
    times (stride-2 convs) while widening the channels 64 -> 512, followed by
    global average pooling and a 1x1-conv classifier head. The forward pass
    returns one sigmoid probability per image in the batch, shape (N,).
    """

    def __init__(self):
        super(Discriminator, self).__init__()

        def conv_bn_lrelu(in_ch, out_ch, stride=1):
            # conv -> BN -> LeakyReLU(0.2): the repeating unit of the body.
            return [
                nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=stride, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.LeakyReLU(0.2),
            ]

        # Layer order matches the original flat Sequential exactly, so the
        # state_dict keys (net.0, net.1, ...) are unchanged.
        layers = [
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            nn.LeakyReLU(0.2),
        ]
        layers += conv_bn_lrelu(64, 64, stride=2)
        layers += conv_bn_lrelu(64, 128)
        layers += conv_bn_lrelu(128, 128, stride=2)
        layers += conv_bn_lrelu(128, 256)
        layers += conv_bn_lrelu(256, 256, stride=2)
        layers += conv_bn_lrelu(256, 512)
        layers += conv_bn_lrelu(512, 512, stride=2)
        layers += [
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(512, 1024, kernel_size=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(1024, 1, kernel_size=1),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return a 1-D tensor of per-image "real" probabilities in (0, 1)."""
        logits = self.net(x)
        return torch.sigmoid(logits.view(x.size(0)))
| ], | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "UbJRjjxBkUSW" | |
| }, | |
| "source": [ | |
| "from torchvision.models.vgg import vgg16" | |
| ], | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "id": "PBErzouckUSW" | |
| }, | |
| "source": [ | |
| "### Loss" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "1dBQrQBgkUSW" | |
| }, | |
| "source": [ | |
class GeneratorLoss(nn.Module):
    """Composite SRGAN generator loss.

    Combines four terms (weights follow the SRGAN paper / reference code):
    pixel-wise MSE image loss + 0.001 * adversarial loss + 0.006 * VGG
    perception loss + 2e-8 * total-variation loss.

    The perception term compares features from the first 31 layers of a
    pretrained VGG-16, frozen and kept in eval mode so it acts as a fixed
    feature extractor (construction downloads the weights on first use).
    """

    def __init__(self):
        super(GeneratorLoss, self).__init__()
        vgg = vgg16(pretrained=True)
        loss_network = nn.Sequential(*list(vgg.features)[:31]).eval()
        for param in loss_network.parameters():
            param.requires_grad = False
        self.loss_network = loss_network
        self.mse_loss = nn.MSELoss()
        self.tv_loss = TVLoss()

    def forward(self, out_labels, out_images, target_images):
        """Total loss given discriminator outputs, generated and real images.

        out_labels: discriminator scores for the generated images.
        out_images: generator output batch.
        target_images: ground-truth high-resolution batch.
        """
        # Pixel-space reconstruction term.
        content = self.mse_loss(out_images, target_images)
        # Adversarial term: pushes D's score on fakes toward 1.
        adversarial = torch.mean(1 - out_labels)
        # Feature-space (perceptual) term via the frozen VGG trunk.
        real_feat = self.loss_network(target_images)
        fake_feat = self.loss_network(out_images)
        perceptual = self.mse_loss(fake_feat, real_feat)
        # Smoothness regulariser on the generated images.
        smoothness = self.tv_loss(out_images)
        return content + 0.001 * adversarial + 0.006 * perceptual + 2e-8 * smoothness
| ], | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "uGPc4qnWkUSW" | |
| }, | |
| "source": [ | |
class TVLoss(nn.Module):
    """Anisotropic total-variation loss.

    Sums squared differences between vertically and horizontally adjacent
    pixels, normalises each direction by its element count, doubles the
    result, scales by ``tv_loss_weight`` and averages over the batch.
    Encourages spatially smooth generator output.
    """

    def __init__(self, tv_loss_weight=1):
        super(TVLoss, self).__init__()
        self.tv_loss_weight = tv_loss_weight

    def forward(self, x):
        """x: (N, C, H, W) tensor; returns a scalar tensor."""
        batch = x.size(0)
        # Differences between vertically / horizontally neighbouring pixels.
        diff_h = x[:, :, 1:, :] - x[:, :, :-1, :]
        diff_w = x[:, :, :, 1:] - x[:, :, :, :-1]
        # Per-direction mean squared difference (normalised per element).
        norm_h = diff_h.pow(2).sum() / self.tensor_size(diff_h)
        norm_w = diff_w.pow(2).sum() / self.tensor_size(diff_w)
        return self.tv_loss_weight * 2 * (norm_h + norm_w) / batch

    @staticmethod
    def tensor_size(t):
        # Number of elements per batch item: C * H * W.
        return t.size(1) * t.size(2) * t.size(3)
| ], | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "duZ8UvogmJ8X" | |
| }, | |
| "source": [ | |
| "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")" | |
| ], | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "9QPi_X9_kUSW" | |
| }, | |
| "source": [ | |
| "netG = Generator(UPSCALE_FACTOR)" | |
| ], | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "6sQrfCZHkUSW" | |
| }, | |
| "source": [ | |
| "netD = Discriminator()" | |
| ], | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "PiiqnrvkkUSX", | |
| "colab": { | |
| "base_uri": "https://localhost:8080/", | |
| "height": 103, | |
| "referenced_widgets": [ | |
| "5b14de92c3bb43f5b769eb8b8918d393", | |
| "92a437f3d4d74fab94d4dfb34cad8cfc", | |
| "6047442ef1464494a1775e852d98de6e", | |
| "35baa95ba9434b2f94fbd8335df48a41", | |
| "3b3fcd76032c4f4997e819e68f7ddef0", | |
| "b201fa430ceb42e3bc438dc673856622", | |
| "39c14f33da7a4d85bfcb187c0cacd68a", | |
| "638747b422554adb8c64e087f52f4fd6" | |
| ] | |
| }, | |
| "outputId": "596a3dca-fa3f-4329-b29c-85d632f9792b" | |
| }, | |
| "source": [ | |
| "generator_criterion = GeneratorLoss()" | |
| ], | |
| "execution_count": null, | |
| "outputs": [ | |
| { | |
| "output_type": "stream", | |
| "text": [ | |
| "Downloading: \"https://download.pytorch.org/models/vgg16-397923af.pth\" to /root/.cache/torch/hub/checkpoints/vgg16-397923af.pth\n" | |
| ], | |
| "name": "stderr" | |
| }, | |
| { | |
| "output_type": "display_data", | |
| "data": { | |
| "application/vnd.jupyter.widget-view+json": { | |
| "model_id": "5b14de92c3bb43f5b769eb8b8918d393", | |
| "version_minor": 0, | |
| "version_major": 2 | |
| }, | |
| "text/plain": [ | |
| "HBox(children=(FloatProgress(value=0.0, max=553433881.0), HTML(value='')))" | |
| ] | |
| }, | |
| "metadata": { | |
| "tags": [] | |
| } | |
| }, | |
| { | |
| "output_type": "stream", | |
| "text": [ | |
| "\n" | |
| ], | |
| "name": "stdout" | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "aQaIjG8UnunL" | |
| }, | |
| "source": [ | |
| "generator_criterion = generator_criterion.to(device)\n", | |
| "netG = netG.to(device)\n", | |
| "netD = netD.to(device)" | |
| ], | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "YRPW-NqlkUSX" | |
| }, | |
| "source": [ | |
| "optimizerG = optim.Adam(netG.parameters(), lr=0.0002)\n", | |
| "optimizerD = optim.Adam(netD.parameters(), lr=0.0002)" | |
| ], | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "236B74VqOLgq" | |
| }, | |
| "source": [ | |
| "" | |
| ], | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "yULcXdFokUSX" | |
| }, | |
| "source": [ | |
| "results = {'d_loss': [], 'g_loss': [], 'd_score': [], 'g_score': [], 'psnr': [], 'ssim': []}" | |
| ], | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "id": "9Esuf4_xkUSX" | |
| }, | |
| "source": [ | |
| "## Train" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "rg9EVhMGkUSX" | |
| }, | |
| "source": [ | |
| "from tqdm import tqdm\n", | |
| "import os" | |
| ], | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "lDi9oyr-kUSX" | |
| }, | |
| "source": [ | |
| "N_EPOCHS = 150" | |
| ], | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "colab": { | |
| "base_uri": "https://localhost:8080/" | |
| }, | |
| "id": "dJSFZRzykUSX", | |
| "outputId": "dc79719e-2477-4621-d365-c6d7c5fb99c2" | |
| }, | |
| "source": [ | |
# SRGAN training loop: for each batch, update the discriminator once, then the
# generator once. Relies on netG, netD, optimizerG, optimizerD, train_loader,
# generator_criterion, N_EPOCHS and UPSCALE_FACTOR defined in earlier cells.
for epoch in range(1, N_EPOCHS + 1):
    train_bar = tqdm(train_loader)
    # Running sums (weighted by batch size) used for the progress-bar averages.
    running_results = {'batch_sizes': 0, 'd_loss': 0, 'g_loss': 0, 'd_score': 0, 'g_score': 0}

    netG.train()
    netD.train()
    # data = low-resolution input, target = high-resolution ground truth
    # (presumably, given how they are fed to netG / netD — TODO confirm
    # against TrainDatasetFromFolder in srgandata).
    for data, target in train_bar:
        g_update_first = True  # NOTE(review): never read afterwards — dead variable.
        batch_size = data.size(0)
        running_results['batch_sizes'] += batch_size
        
        # NOTE(review): Variable is a no-op in modern PyTorch; tensors work directly.
        real_img = Variable(target)
        if torch.cuda.is_available():
            real_img = real_img.cuda()
        z = Variable(data)
        if torch.cuda.is_available():
            z = z.cuda()
        
        ############################
        # (1) Update D network: maximize D(x)-1-D(G(z))
        ###########################
        fake_img = netG(z)

        netD.zero_grad()
        real_out = netD(real_img).mean()
        fake_out = netD(fake_img).mean()
        # Hinge-like critic objective: drive real scores up, fake scores down.
        d_loss = 1 - real_out + fake_out
        # retain_graph=True keeps the generator graph alive; the G step below
        # recomputes its own graph anyway, so this mainly avoids a free error here.
        d_loss.backward(retain_graph=True)
        optimizerD.step()

        ############################
        # (2) Update G network: minimize 1-D(G(z)) + Perception Loss + Image Loss + TV Loss
        ###########################
        ###### Was causing Runtime Error ######
        # Recompute fake_img/fake_out AFTER the D step so the generator loss is
        # built on a fresh graph (the earlier one was consumed/modified above).
        fake_img = netG(z)
        fake_out = netD(fake_img).mean()
        #######################################
        netG.zero_grad()
        g_loss = generator_criterion(fake_out, fake_img, real_img)
        g_loss.backward()

        # NOTE(review): this second recompute before optimizerG.step() only
        # refreshes fake_out for the g_score logging below; it appears to be
        # redundant extra work (G's weights have not changed yet) — confirm.
        fake_img = netG(z)
        fake_out = netD(fake_img).mean()

        optimizerG.step()

        # loss for current batch before optimization 
        running_results['g_loss'] += g_loss.item() * batch_size
        running_results['d_loss'] += d_loss.item() * batch_size
        running_results['d_score'] += real_out.item() * batch_size
        running_results['g_score'] += fake_out.item() * batch_size

        train_bar.set_description(desc='[%d/%d] Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f' % (
            epoch, N_EPOCHS, running_results['d_loss'] / running_results['batch_sizes'],
            running_results['g_loss'] / running_results['batch_sizes'],
            running_results['d_score'] / running_results['batch_sizes'],
            running_results['g_score'] / running_results['batch_sizes']))

    # End-of-epoch bookkeeping: switch G to eval mode and make sure the output
    # directory exists. NOTE(review): nothing is actually saved here, and netG
    # is never switched back to train() explicitly (the next epoch does it).
    netG.eval()
    out_path = 'training_results/SRF_' + str(UPSCALE_FACTOR) + '/'
    if not os.path.exists(out_path):
        os.makedirs(out_path)
| ], | |
| "execution_count": null, | |
| "outputs": [ | |
| { | |
| "output_type": "stream", | |
| "text": [ | |
| "[1/150] Loss_D: 0.8393 Loss_G: 0.0452 D(x): 0.5709 D(G(z)): 0.3487: 100%|██████████| 13/13 [01:01<00:00, 4.72s/it]\n", | |
| "[2/150] Loss_D: 0.7786 Loss_G: 0.0193 D(x): 0.5395 D(G(z)): 0.3129: 100%|██████████| 13/13 [01:00<00:00, 4.66s/it]\n", | |
| "[3/150] Loss_D: 0.3920 Loss_G: 0.0157 D(x): 0.7983 D(G(z)): 0.1449: 100%|██████████| 13/13 [00:59<00:00, 4.57s/it]\n", | |
| "[4/150] Loss_D: 0.1291 Loss_G: 0.0145 D(x): 0.9329 D(G(z)): 0.0532: 100%|██████████| 13/13 [00:59<00:00, 4.57s/it]\n", | |
| "[5/150] Loss_D: 0.0426 Loss_G: 0.0136 D(x): 0.9778 D(G(z)): 0.0177: 100%|██████████| 13/13 [00:59<00:00, 4.57s/it]\n", | |
| "[6/150] Loss_D: 0.0277 Loss_G: 0.0131 D(x): 0.9834 D(G(z)): 0.0386: 100%|██████████| 13/13 [01:00<00:00, 4.64s/it]\n", | |
| "[7/150] Loss_D: 0.8722 Loss_G: 0.0119 D(x): 0.5780 D(G(z)): 0.3579: 100%|██████████| 13/13 [00:59<00:00, 4.59s/it]\n", | |
| "[8/150] Loss_D: 0.7058 Loss_G: 0.0104 D(x): 0.5374 D(G(z)): 0.2106: 100%|██████████| 13/13 [00:59<00:00, 4.60s/it]\n", | |
| "[9/150] Loss_D: 0.4493 Loss_G: 0.0108 D(x): 0.7283 D(G(z)): 0.1293: 100%|██████████| 13/13 [00:59<00:00, 4.57s/it]\n", | |
| "[10/150] Loss_D: 0.1850 Loss_G: 0.0102 D(x): 0.9002 D(G(z)): 0.0621: 100%|██████████| 13/13 [00:59<00:00, 4.58s/it]\n", | |
| " 0%| | 0/13 [00:00<?, ?it/s]" | |
| ], | |
| "name": "stderr" | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "YTPIWjk2NC5P" | |
| }, | |
| "source": [ | |
| "from torchvision.transforms import ToTensor, ToPILImage" | |
| ], | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "O9yt2ExXLxjO" | |
| }, | |
| "source": [ | |
| "from PIL import Image" | |
| ], | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "Qz8w5AtxcuRi" | |
| }, | |
| "source": [ | |
| "torch.save(netG.state_dict(), \"super_res_gen.pth\")" | |
| ], | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "colab": { | |
| "base_uri": "https://localhost:8080/" | |
| }, | |
| "id": "4CIfaGsmQVwg", | |
| "outputId": "eb6874a0-3589-400f-c571-9cc9f2d32d28" | |
| }, | |
| "source": [ | |
| "netG.load_state_dict(torch.load(\"super_res_gen.pth\")) # If you already have a pretrained weights file for this model." | |
| ], | |
| "execution_count": null, | |
| "outputs": [ | |
| { | |
| "output_type": "execute_result", | |
| "data": { | |
| "text/plain": [ | |
| "<All keys matched successfully>" | |
| ] | |
| }, | |
| "metadata": { | |
| "tags": [] | |
| }, | |
| "execution_count": 18 | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "8LI5BqU5NIWu", | |
| "colab": { | |
| "base_uri": "https://localhost:8080/", | |
| "height": 165 | |
| }, | |
| "outputId": "97eef7b0-b8f0-48d5-b532-c4afc52a4b7d" | |
| }, | |
| "source": [ | |
| "# LOAD your OWN image here. This cell wont work unless you upload your own image to Colab\n", | |
| "image = Image.open(\"table2.jpg\")" | |
| ], | |
| "execution_count": null, | |
| "outputs": [ | |
| { | |
| "output_type": "error", | |
| "ename": "NameError", | |
| "evalue": "ignored", | |
| "traceback": [ | |
| "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", | |
| "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", | |
| "\u001b[0;32m<ipython-input-1-369897577819>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mimage\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mImage\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mopen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"table2.jpg\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", | |
| "\u001b[0;31mNameError\u001b[0m: name 'Image' is not defined" | |
| ] | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "TQjCZ7CBkUSX", | |
| "colab": { | |
| "base_uri": "https://localhost:8080/" | |
| }, | |
| "outputId": "f4b8339f-09b8-4c82-8886-4af5ef4530c8" | |
| }, | |
| "source": [ | |
| "image = Variable(ToTensor()(image), volatile=True).unsqueeze(0)" | |
| ], | |
| "execution_count": null, | |
| "outputs": [ | |
| { | |
| "output_type": "stream", | |
| "text": [ | |
| "/usr/local/lib/python3.6/dist-packages/ipykernel_launcher.py:1: UserWarning: volatile was removed and now has no effect. Use `with torch.no_grad():` instead.\n", | |
| " \"\"\"Entry point for launching an IPython kernel.\n" | |
| ], | |
| "name": "stderr" | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "VpCwx3FXRXYO" | |
| }, | |
| "source": [ | |
| "netG = netG.to(torch.device(\"cuda\"))\n", | |
| "image = image.to(torch.device(\"cuda\"))" | |
| ], | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "MehEf581NQTH" | |
| }, | |
| "source": [ | |
| "out = netG(image)\n", | |
| "out_img = ToPILImage()(out[0].data.cpu())" | |
| ], | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "wrPdGj02NVRi" | |
| }, | |
| "source": [ | |
| "out_img.save(\"table2-superres-4x.jpg\")" | |
| ], | |
| "execution_count": null, | |
| "outputs": [] | |
| } | |
| ] | |
| } |
Add this to beginning of file