Created
May 21, 2020 19:13
-
-
Save bearpelican/d2833b3e65134e66b59904f0aef11666 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| { | |
| "cells": [ | |
| { | |
| "cell_type": "code", | |
| "execution_count": 1, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "from fastai.vision import *" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "## Data" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "One-time download, uncomment the next cells to get the data." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 2, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "#path = Config().data_path()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 3, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "#! wget https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/horse2zebra.zip -P {path}\n", | |
| "#! unzip -q -n {path}/horse2zebra.zip -d {path}\n", | |
| "#! rm {path}/horse2zebra.zip" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 4, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
# Root folder of the datasets; this relative path is specific to the machine the notebook was written on.
data_path = Path('../../../../mnt/wamri/WAMRI-LevensonLab/datasets/')
muse2he_path = data_path/'muse2he_urothelial_carcinoma'
# NOTE(review): this listing is not the last expression of the cell, so its result is not displayed.
muse2he_path.ls()

# The two sub-folders of the dataset: 'trainA' is bound to muse_path, 'trainB' to he_path.
muse_path = muse2he_path/'trainA'
he_path = muse2he_path/'trainB'

torch.cuda.set_device(0) #set GPU id
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 5, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# path = Config().data_path()/'horse2zebra'\n", | |
| "# path.ls()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "See [this tutorial](https://docs.fast.ai/tutorial.itemlist.html) for a detailed walkthrough of how/why this custom `ItemList` was created." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 6, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
class ImageTuple(ItemBase):
    "Two images handled as a single item; `data` holds both tensors rescaled with `2*x - 1`."

    def __init__(self, img1, img2):
        self.img1 = img1
        self.img2 = img2
        self.obj = (img1, img2)
        self.data = [-1+2*img1.data, -1+2*img2.data]

    def apply_tfms(self, tfms, **kwargs):
        "Apply `tfms` to the first image, then to the second, and wrap both results in a new `ImageTuple`."
        transformed = [im.apply_tfms(tfms, **kwargs) for im in (self.img1, self.img2)]
        return ImageTuple(*transformed)

    def to_one(self):
        "Concatenate the two tensors along dimension 2, undo the rescaling and return a single `Image`."
        joined = torch.cat(self.data, 2)
        return Image(0.5 + joined/2)
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 7, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
class TargetTupleList(ItemList):
    "Label list whose tensors are rebuilt into `ImageTuple`s."

    def reconstruct(self, t:Tensor):
        "Return `t` as is when it has no dimensions; otherwise build an `ImageTuple` from `t[0]` and `t[1]`."
        if t.dim() == 0: return t
        first, second = (Image(part/2+0.5) for part in (t[0], t[1]))
        return ImageTuple(first, second)
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 8, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
class ImageTupleList(ImageList):
    """`ImageList` that pairs each of its items with a randomly chosen file from `itemsB`.

    `items` are the files of the first folder; `itemsB` is the collection of file names the
    second image of every `ImageTuple` is drawn from.
    """
    _label_cls=TargetTupleList

    def __init__(self, items, itemsB=None, **kwargs):
        self.itemsB = itemsB
        super().__init__(items, **kwargs)

    def new(self, items, **kwargs):
        "Create a new list of the same kind, carrying `itemsB` over."
        return super().new(items, itemsB=self.itemsB, **kwargs)

    def get(self, i):
        "Return item `i` paired with an image opened from a random entry of `itemsB`."
        img1 = super().get(i)
        # A fresh random index is drawn on every call, so the pairing is not fixed.
        fn = self.itemsB[random.randint(0, len(self.itemsB)-1)]
        return ImageTuple(img1, open_image(fn))

    def reconstruct(self, t:Tensor):
        "Rebuild an `ImageTuple` from `t[0]` and `t[1]`, mapping each through `x/2 + 0.5`."
        return ImageTuple(Image(t[0]/2+0.5),Image(t[1]/2+0.5))

    @classmethod
    def from_folders(cls, path, folderA, folderB, **kwargs):
        "Build the list from `path/folderA`, with `itemsB` taken from `path/folderB`; `path` becomes the list's path."
        itemsB = ImageList.from_folder(path/folderB).items
        res = super().from_folder(path/folderA, itemsB=itemsB, **kwargs)
        res.path = path
        return res

    def show_xys(self, xs, ys, figsize:Tuple[int,int]=(12,6), **kwargs):
        "Show the `xs` and `ys` on a figure of `figsize`. `kwargs` are passed to the show method."
        rows = int(math.sqrt(len(xs)))
        # squeeze=False always yields a 2-D axes array, so the 1x1 grid needs no special case.
        fig, axs = plt.subplots(rows,rows,figsize=figsize,squeeze=False)
        for i, ax in enumerate(axs.flatten()):
            xs[i].to_one().show(ax=ax, **kwargs)
        plt.tight_layout()

    def show_xyzs(self, xs, ys, zs, figsize:Tuple[int,int]=None, **kwargs):
        """Show `xs` (inputs), `ys` (targets) and `zs` (predictions) on a figure of `figsize`.
        `kwargs` are passed to the show method."""
        figsize = ifnone(figsize, (12,3*len(xs)))
        # Without squeeze=False a single row comes back as a 1-D array and `axs[i,0]` raises IndexError.
        fig,axs = plt.subplots(len(xs), 2, figsize=figsize, squeeze=False)
        fig.suptitle('Ground truth / Predictions', weight='bold', size=14)
        for i,(x,z) in enumerate(zip(xs,zs)):
            x.to_one().show(ax=axs[i,0], **kwargs)
            z.to_one().show(ax=axs[i,1], **kwargs)
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 9, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# src = ImageTupleList.from_folders(muse2he_path, 'trainA', 'trainB').split_none().label_empty()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 10, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
# Assemble the DataBunch: items from 'trainA' paired with files from 'trainB', no validation split,
# empty labels, and the same crop + horizontal-flip list given twice (2*[...], presumably one per split).
data = (ImageTupleList.from_folders(muse2he_path, 'trainA', 'trainB')
        .split_none()
        .label_empty()
        .transform(2*[[crop(size=256,row_pct=0,col_pct=0),flip_lr(p=0.5)]],size=512,resize_method=ResizeMethod.SQUISH)
        .databunch(bs=4,num_workers=2))
data.valid_dl = data.train_dl # a hack for proper evaluation of loss and metrics at end of training
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
# Preview a batch of items from the DataBunch.
data.show_batch(rows=2)
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 12, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "(torch.Size([4, 3, 256, 256]), torch.Size([4, 3, 256, 256]))" | |
| ] | |
| }, | |
| "execution_count": 12, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
# Fetch one batch and display the shapes of the two tensors that make up its input.
xb, yb = data.one_batch()
xb[0].shape, xb[1].shape
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
# First training batch -> its input -> first entry -> first sample.
train_image = next(iter(data.train_dl))[0][0][0]
# Move channels last, map values through (x+1)/2*255 and cast to int for display.
# NOTE(review): assumes the tensor values lie in [-1, 1] (ImageTuple stores 2*x - 1) - confirm.
plt.imshow(((train_image.permute(1,2,0)+1)/2*255).cpu().to(torch.int))
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "## Models" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "We use the models that were introduced in the [cycleGAN paper](https://arxiv.org/abs/1703.10593)." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 14, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
def convT_norm_relu(ch_in:int, ch_out:int, norm_layer:nn.Module, ks:int=3, stride:int=2, bias:bool=True):
    "Return the list `[ConvTranspose2d(ch_in, ch_out), norm_layer(ch_out), ReLU]`."
    # padding and output_padding are both fixed to 1.
    upsample = nn.ConvTranspose2d(ch_in, ch_out, kernel_size=ks, stride=stride,
                                  padding=1, output_padding=1, bias=bias)
    norm = norm_layer(ch_out)
    return [upsample, norm, nn.ReLU(True)]
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 15, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
def pad_conv_norm_relu(ch_in:int, ch_out:int, pad_mode:str, norm_layer:nn.Module, ks:int=3, bias:bool=True,
                       pad=1, stride:int=1, activ:bool=True, init:Callable=nn.init.kaiming_normal_, init_gain:int=0.02)->List[nn.Module]:
    """Return the layers `[optional padding, Conv2d, norm_layer(ch_out), optional ReLU]`.

    `pad_mode` 'reflection' / 'border' prepend an explicit padding layer, 'zeros' pads inside the
    convolution, and any other value results in no padding at all. When `init` is truthy it is
    applied to the conv weight (with mean 0 and std `init_gain` for `nn.init.normal_`) and the
    conv bias, if present, is zeroed.
    """
    explicit_padding = {'reflection': nn.ReflectionPad2d, 'border': nn.ReplicationPad2d}
    layers = [explicit_padding[pad_mode](pad)] if pad_mode in explicit_padding else []
    conv_padding = pad if pad_mode == 'zeros' else 0
    conv = nn.Conv2d(ch_in, ch_out, kernel_size=ks, padding=conv_padding, stride=stride, bias=bias)
    if init:
        # Only normal_ receives the (mean, std) arguments; other initializers are called on the weight alone.
        init_args = (0.0, init_gain) if init == nn.init.normal_ else ()
        init(conv.weight, *init_args)
        bias_param = getattr(conv, 'bias', None)
        if hasattr(bias_param, 'data'): bias_param.data.fill_(0.)
    layers.extend([conv, norm_layer(ch_out)])
    if activ: layers.append(nn.ReLU(inplace=True))
    return layers
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 16, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
class ResnetBlock(nn.Module):
    "Residual block: two pad-conv-norm stages (ReLU after the first only, optional dropout between) added to the input."

    def __init__(self, dim:int, pad_mode:str='reflection', norm_layer:nn.Module=None, dropout:float=0., bias:bool=True):
        super().__init__()
        assert pad_mode in ['zeros', 'reflection', 'border'], f'padding {pad_mode} not implemented.'
        norm_layer = ifnone(norm_layer, nn.InstanceNorm2d)
        first_stage = pad_conv_norm_relu(dim, dim, pad_mode, norm_layer, bias=bias)
        between = [nn.Dropout(dropout)] if dropout != 0 else []
        second_stage = pad_conv_norm_relu(dim, dim, pad_mode, norm_layer, bias=bias, activ=False)
        self.conv_block = nn.Sequential(*first_stage, *between, *second_stage)

    def forward(self, x):
        residual = self.conv_block(x)
        return x + residual
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 17, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
def resnet_generator(ch_in:int, ch_out:int, n_ftrs:int=64, norm_layer:nn.Module=None,
                     dropout:float=0., n_blocks:int=9, pad_mode:str='reflection')->nn.Module:
    """Build the generator as an `nn.Sequential`.

    Layout: a 7x7 reflection-padded stem, two stride-2 convolutions that each double the width,
    `n_blocks` `ResnetBlock`s, two transposed convolutions that each halve the width, and a final
    reflection-padded 7x7 convolution to `ch_out` followed by `Tanh`.
    """
    norm_layer = ifnone(norm_layer, nn.InstanceNorm2d)
    # Convolution biases are enabled only when the norm layer is InstanceNorm2d.
    bias = (norm_layer == nn.InstanceNorm2d)
    width = n_ftrs
    layers = pad_conv_norm_relu(ch_in, width, 'reflection', norm_layer, pad=3, ks=7, bias=bias)
    for _ in range(2):
        layers.extend(pad_conv_norm_relu(width, width*2, 'zeros', norm_layer, stride=2, bias=bias))
        width *= 2
    layers.extend(ResnetBlock(width, pad_mode, norm_layer, dropout, bias) for _ in range(n_blocks))
    for _ in range(2):
        layers.extend(convT_norm_relu(width, width//2, norm_layer, bias=bias))
        width //= 2
    layers.extend([nn.ReflectionPad2d(3), nn.Conv2d(width, ch_out, kernel_size=7, padding=0), nn.Tanh()])
    return nn.Sequential(*layers)
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
# Instantiate a generator with 3 input and 3 output channels so its layers are shown as the cell result.
resnet_generator(3, 3)
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 19, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
def conv_norm_lr(ch_in:int, ch_out:int, norm_layer:nn.Module=None, ks:int=3, bias:bool=True, pad:int=1, stride:int=1,
                 activ:bool=True, slope:float=0.2, init:Callable=nn.init.normal_, init_gain:int=0.02)->List[nn.Module]:
    """Return the layers `[Conv2d, optional norm_layer(ch_out), optional LeakyReLU(slope)]`.

    When `init` is truthy it is applied to the conv weight (with mean 0 and std `init_gain` for
    `nn.init.normal_`) and the conv bias, if present, is zeroed.
    """
    conv = nn.Conv2d(ch_in, ch_out, kernel_size=ks, padding=pad, stride=stride, bias=bias)
    if init:
        # Only normal_ receives the (mean, std) arguments; other initializers are called on the weight alone.
        init_args = (0.0, init_gain) if init == nn.init.normal_ else ()
        init(conv.weight, *init_args)
        bias_param = getattr(conv, 'bias', None)
        if hasattr(bias_param, 'data'): bias_param.data.fill_(0.)
    trailing = []
    if norm_layer is not None: trailing.append(norm_layer(ch_out))
    if activ: trailing.append(nn.LeakyReLU(slope, inplace=True))
    return [conv, *trailing]
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 20, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
def discriminator(ch_in:int, n_ftrs:int=64, n_layers:int=3, norm_layer:nn.Module=None, sigmoid:bool=False)->nn.Module:
    """Build the discriminator as an `nn.Sequential` of 4x4 convolutions ending in a 1-channel map.

    The first convolution has no norm layer; `n_layers-1` stride-2 stages and one stride-1 stage
    follow, then the final 1-channel convolution and, if `sigmoid`, a `Sigmoid`.
    """
    norm_layer = ifnone(norm_layer, nn.InstanceNorm2d)
    # Convolution biases are enabled only when the norm layer is InstanceNorm2d.
    bias = (norm_layer == nn.InstanceNorm2d)
    width = n_ftrs
    layers = conv_norm_lr(ch_in, width, ks=4, stride=2, pad=1)
    for idx in range(n_layers-1):
        next_width = 2*width if idx <= 3 else width
        layers.extend(conv_norm_lr(width, next_width, norm_layer, ks=4, stride=2, pad=1, bias=bias))
        width = next_width
    last_width = 2*width if n_layers <= 3 else width
    layers.extend(conv_norm_lr(width, last_width, norm_layer, ks=4, stride=1, pad=1, bias=bias))
    layers.append(nn.Conv2d(last_width, 1, kernel_size=4, stride=1, padding=1))
    if sigmoid: layers.append(nn.Sigmoid())
    return nn.Sequential(*layers)
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 21, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "Sequential(\n", | |
| " (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n", | |
| " (1): LeakyReLU(negative_slope=0.2, inplace=True)\n", | |
| " (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n", | |
| " (3): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)\n", | |
| " (4): LeakyReLU(negative_slope=0.2, inplace=True)\n", | |
| " (5): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n", | |
| " (6): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)\n", | |
| " (7): LeakyReLU(negative_slope=0.2, inplace=True)\n", | |
| " (8): Conv2d(256, 512, kernel_size=(4, 4), stride=(1, 1), padding=(1, 1))\n", | |
| " (9): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)\n", | |
| " (10): LeakyReLU(negative_slope=0.2, inplace=True)\n", | |
| " (11): Conv2d(512, 1, kernel_size=(4, 4), stride=(1, 1), padding=(1, 1))\n", | |
| ")" | |
| ] | |
| }, | |
| "execution_count": 21, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
# Instantiate a discriminator for 3-channel input so its layers are shown as the cell result.
discriminator(3)
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "We group two discriminators and two generators in a single model, then a `Callback` will take care of training them properly." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 22, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
class CycleGAN(nn.Module):
    "Container module holding the two discriminators (`D_A`, `D_B`) and the two generators (`G_A`, `G_B`)."

    def __init__(self, ch_in:int, ch_out:int, n_features:int=64, disc_layers:int=3, gen_blocks:int=6, lsgan:bool=True,
                 drop:float=0., norm_layer:nn.Module=None):
        super().__init__()
        use_sigmoid = not lsgan
        self.D_A = discriminator(ch_in, n_features, disc_layers, norm_layer, sigmoid=use_sigmoid)
        self.D_B = discriminator(ch_in, n_features, disc_layers, norm_layer, sigmoid=use_sigmoid)
        self.G_A = resnet_generator(ch_in, ch_out, n_features, norm_layer, drop, gen_blocks)
        self.G_B = resnet_generator(ch_in, ch_out, n_features, norm_layer, drop, gen_blocks)
        #G_A: takes real input B and generates fake input A
        #G_B: takes real input A and generates fake input B
        #D_A: trained to make the difference between real input A and fake input A
        #D_B: trained to make the difference between real input B and fake input B

    def forward(self, real_A, real_B):
        "Return `[fake_A, fake_B, idt_A, idt_B]` when training, or the same four stacked on dim 1 otherwise."
        fake_A = self.G_A(real_B)
        fake_B = self.G_B(real_A)
        # Each generator is also run on its own target domain; needed for the identity loss during training.
        idt_A = self.G_A(real_A)
        idt_B = self.G_B(real_B)
        outputs = [fake_A, fake_B, idt_A, idt_B]
        if self.training: return outputs
        return torch.stack(outputs, 1)
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "`AdaptiveLoss` is a wrapper around a PyTorch loss function to compare an output of any size with a single number (0. or 1.). It will generate a target with the same shape as the output. A discriminator returns a feature map, and we want it to predict zeros (or ones) for each feature." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 23, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
class AdaptiveLoss(nn.Module):
    "Wrap `crit` so that it can be called with a boolean target instead of a tensor."

    def __init__(self, crit):
        super().__init__()
        self.crit = crit

    def forward(self, output, target:bool, **kwargs):
        "Compare `output` with a tensor of ones (`target` truthy) or zeros of the same size, using `crit`."
        make_target = output.new_ones if target else output.new_zeros
        targ = make_target(*output.size())
        return self.crit(output, targ, **kwargs)
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "The main loss used to train the generators. It has three parts:\n", | |
| "- the classic GAN loss: they must make the critics believe their images are real\n", | |
| "- identity loss: if they are given an image from the set they are trying to imitate, they should return the same thing\n", | |
| "- cycle loss: if an image from A goes through the generator that imitates B and then through the generator that imitates A, it should come back as the initial image. The same holds for an image from B, with the two generators applied in the opposite order." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 24, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
class CycleGanLoss(nn.Module):
    "Generator loss: the sum of identity, adversarial and cycle terms, each also stored on the instance."

    def __init__(self, cgan:nn.Module, lambda_A:float=10., lambda_B:float=10, lambda_idt:float=0.5, lsgan:bool=True):
        super().__init__()
        self.cgan = cgan
        self.l_A, self.l_B, self.l_idt = lambda_A, lambda_B, lambda_idt
        base_crit = F.mse_loss if lsgan else F.binary_cross_entropy
        self.crit = AdaptiveLoss(base_crit)

    def set_training(self,training):
        "Record whether the model output will arrive as a list (training) or as one stacked tensor."
        self.training = training

    def set_input(self, input):
        "Store the real images of the current batch; `forward` compares against them."
        self.real_A,self.real_B = input

    def forward(self, output, target):
        "Compute `id_loss + gen_loss + cyc_loss` for `output`; `target` is not used."
        if self.training:
            fake_A, fake_B, idt_A, idt_B = output
        else:
            # Outside training the four outputs are stacked along dim 1.
            fake_A, fake_B, idt_A, idt_B = (output[:,k,:,:,:] for k in range(4))

        #Generators should return identity on the datasets they try to convert to
        identity_A = self.l_A * F.l1_loss(idt_A, self.real_A)
        identity_B = self.l_B * F.l1_loss(idt_B, self.real_B)
        self.id_loss = self.l_idt * (identity_A + identity_B)
        #Generators are trained to trick the discriminators so the following should be ones
        self.gen_loss = self.crit(self.cgan.D_A(fake_A), True) + self.crit(self.cgan.D_B(fake_B), True)
        #Cycle loss
        cycle_A = self.l_A * F.l1_loss(self.cgan.G_A(fake_B), self.real_A)
        cycle_B = self.l_B * F.l1_loss(self.cgan.G_B(fake_A), self.real_B)
        self.cyc_loss = cycle_A + cycle_B
        return self.id_loss+self.gen_loss+self.cyc_loss
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "The main callback to train a cycle GAN. The training loop will train the generators (so `learn.opt` is given those parameters) while the critics are trained by the callback during `on_batch_end`." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 25, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
class CycleGANTrainer(LearnerCallback):
    """Callback that drives CycleGAN training.

    The Learner's optimizer is pointed at the generator optimizer in `on_train_begin`, so the
    regular training loop updates the generators; the two discriminators are updated here, in
    `on_batch_end`, each with its own optimizer. Smoothed values of the individual loss terms
    are reported as extra metrics at the end of every epoch.
    """
    _order = -20 #Need to run before the Recorder

    def _set_trainable(self, D_A=False, D_B=False):
        "Enable gradients for the generators (no flag set) or for the flagged discriminator(s) only."
        gen = (not D_A) and (not D_B)
        requires_grad(self.learn.model.G_A, gen)
        requires_grad(self.learn.model.G_B, gen)
        requires_grad(self.learn.model.D_A, D_A)
        requires_grad(self.learn.model.D_B, D_B)
        if not gen:
            # Keep both discriminator optimizers in sync with the hyper-parameters of the Learner's optimizer.
            self.opt_D_A.lr, self.opt_D_A.mom = self.learn.opt.lr, self.learn.opt.mom
            self.opt_D_A.wd, self.opt_D_A.beta = self.learn.opt.wd, self.learn.opt.beta
            self.opt_D_B.lr, self.opt_D_B.mom = self.learn.opt.lr, self.learn.opt.mom
            self.opt_D_B.wd, self.opt_D_B.beta = self.learn.opt.wd, self.learn.opt.beta

    def on_train_begin(self, metrics_names, **kwargs):
        "Create the optimizers on first use, select the generators for training and register the metric names."
        self.G_A,self.G_B = self.learn.model.G_A,self.learn.model.G_B
        self.D_A,self.D_B = self.learn.model.D_A,self.learn.model.D_B
        self.crit = self.learn.loss_func.crit
        if not getattr(self,'opt_G',None):
            self.opt_G = self.learn.opt.new([nn.Sequential(*flatten_model(self.G_A), *flatten_model(self.G_B))])
        else:
            # The generator optimizer already exists from an earlier fit: only refresh its hyper-parameters.
            self.opt_G.lr,self.opt_G.wd = self.opt.lr,self.opt.wd
            self.opt_G.mom,self.opt_G.beta = self.opt.mom,self.opt.beta
        if not getattr(self,'opt_D_A',None):
            self.opt_D_A = self.learn.opt.new([nn.Sequential(*flatten_model(self.D_A))])
        if not getattr(self,'opt_D_B',None):
            self.opt_D_B = self.learn.opt.new([nn.Sequential(*flatten_model(self.D_B))])
        self.learn.opt.opt = self.opt_G.opt
        self._set_trainable()
        self.id_smter,self.gen_smter,self.cyc_smter = SmoothenValue(0.98),SmoothenValue(0.98),SmoothenValue(0.98)
        self.da_smter,self.db_smter = SmoothenValue(0.98),SmoothenValue(0.98)
        self.recorder.add_metric_names(['id_loss', 'gen_loss', 'cyc_loss', 'D_A_loss', 'D_B_loss'])

    def on_epoch_begin(self, **kwargs):
        "Release cached CUDA memory at the start of every epoch."
        torch.cuda.empty_cache()

    def on_batch_begin(self, last_input, **kwargs):
        "Pass the current mode and the real images of the batch on to the loss function."
        self.training = self.learn.model.training
        self.learn.loss_func.set_training(self.training)
        self.learn.loss_func.set_input(last_input)

    def on_backward_begin(self, **kwargs):
        "Record the three generator loss terms computed by the loss function for this batch."
        self.id_smter.add_value(self.loss_func.id_loss.detach().cpu())
        self.gen_smter.add_value(self.loss_func.gen_loss.detach().cpu())
        self.cyc_smter.add_value(self.loss_func.cyc_loss.detach().cpu())

    def on_batch_end(self, last_input, last_output, **kwargs):
        "Compute both discriminator losses on the batch and, when training, step their optimizers."
        self.G_A.zero_grad(); self.G_B.zero_grad()
        if isinstance(last_output, (list, tuple)):
            fake_A, fake_B = last_output[0].detach(), last_output[1].detach()
        else:
            # Outside training the model returns one tensor with the four outputs stacked on dim 1;
            # indexing it with [0] / [1] would select batch samples instead of fake_A / fake_B.
            fake_A, fake_B = last_output[:,0].detach(), last_output[:,1].detach()
        real_A, real_B = last_input
        self._set_trainable(D_A=True)
        self.D_A.zero_grad()
        loss_D_A = 0.5 * (self.crit(self.D_A(real_A), True) + self.crit(self.D_A(fake_A), False))
        self.da_smter.add_value(loss_D_A.detach().cpu())
        if self.training:
            loss_D_A.backward()
            self.opt_D_A.step()
        self._set_trainable(D_B=True)
        self.D_B.zero_grad()
        loss_D_B = 0.5 * (self.crit(self.D_B(real_B), True) + self.crit(self.D_B(fake_B), False))
        self.db_smter.add_value(loss_D_B.detach().cpu())

        if self.training:
            loss_D_B.backward()
            self.opt_D_B.step()
        # Hand control back to the generators for the next batch.
        self._set_trainable()

    def on_epoch_end(self, last_metrics, **kwargs):
        "Append the smoothed id/gen/cyc/D_A/D_B losses to `last_metrics`."
        return add_metrics(last_metrics, [s.smooth for s in [self.id_smter,self.gen_smter,self.cyc_smter,
                                                             self.da_smter,self.db_smter]])
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "## Training" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 26, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# Taken from https://github.com/fastai/fastai/blob/master/fastai/callbacks/flat_cos_anneal.py\n", | |
| "# Work of Zach Mueller and Mikhail Grankin\n", | |
| "from fastai.callback import *\n", | |
| "from fastai.callbacks import *\n", | |
def FlatAnnealScheduler(learn, lr:float=2e-4, n_epochs:int=100, n_epochs_decay:int=100, curve:str='linear'):
    """Return a `GeneralScheduler` with a flat phase followed by an annealed phase for `lr`.

    The first phase lasts `n_epochs` epochs with `lr` scheduled as a single value; the second
    lasts `n_epochs_decay` epochs with `lr` scheduled using the annealing function selected by
    `curve` ('cosine', 'linear' or 'exponential').
    Raises `ValueError` for any other `curve`.
    """
    curve_types = {"cosine": annealing_cos, "linear": annealing_linear, "exponential": annealing_exp}
    if curve not in curve_types: raise ValueError(f"annealing type not supported {curve}")
    n = len(learn.data.train_dl)
    # Phase lengths in iterations, computed with integers only: going through the float ratio
    # n_epochs/(n_epochs+n_epochs_decay) can truncate one iteration short and divides by zero
    # when both epoch counts are 0.
    anneal_start = n * n_epochs
    batch_finish = n * n_epochs_decay
    phase0 = TrainingPhase(anneal_start).schedule_hp('lr', lr)
    phase1 = TrainingPhase(batch_finish).schedule_hp('lr', lr, anneal=curve_types[curve])
    phases = [phase0, phase1]
    return GeneralScheduler(learn, phases)
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 27, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
def fit_fa(learn:Learner, n_epochs:int=100, n_epochs_decay:int=100, lr:float=2e-4, curve:str='linear',
           wd:float=None, callbacks:Optional[CallbackList]=None)->None:
    "Fit `learn` for `n_epochs + n_epochs_decay` epochs with a `FlatAnnealScheduler` appended to `callbacks`."
    total_epochs = n_epochs+n_epochs_decay
    max_lr = learn.lr_range(lr)
    scheduler = FlatAnnealScheduler(learn, lr, n_epochs, n_epochs_decay, curve)
    callbacks = listify(callbacks)
    callbacks.append(scheduler)
    learn.fit(total_epochs, max_lr, wd=wd, callbacks=callbacks)

# Expose the function as a method on every Learner.
Learner.fit_fa = fit_fa
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 28, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
class AverageMetric(LearnerCallback):
    "Wrap a `func` in a callback for metrics computation; `func` also receives the batch input."
    # NOTE(review): `__init__` does not call the parent initializer - confirm that `LearnerCallback`
    # (rather than a plain callback base class) is the intended parent.

    def __init__(self, func):
        # Plain functions carry `__name__`; otherwise `func` is expected to be a partial.
        try: name = func.__name__
        except AttributeError: name = func.func.__name__
        self.func = func
        self.name = name
        self.world = num_distrib()

    def on_epoch_begin(self, **kwargs):
        "Reset the running sum and the sample count."
        self.val = 0.
        self.count = 0

    def on_batch_begin(self, last_input, **kwargs):
        "Keep the batch input so it can be handed to `func` in `on_batch_end`."
        self.last_input = last_input

    def on_batch_end(self, last_output, last_target, **kwargs):
        "Add `func(last_input, last_output, *last_target)`, weighted by the batch size, to the running sum."
        targets = last_target if is_listy(last_target) else [last_target]
        batch_size = first_el(targets).size(0)
        self.count += batch_size
        val = self.func(self.last_input, last_output, *targets)
        if self.world:
            # Average the value over all distributed processes.
            val = val.clone()
            dist.all_reduce(val, op=dist.ReduceOp.SUM)
            val /= self.world
        self.val += batch_size * val.detach().cpu()

    def on_epoch_end(self, last_metrics, **kwargs):
        "Append the sample-weighted mean to `last_metrics`."
        return add_metrics(last_metrics, self.val/self.count)
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 29, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "from metrics import *" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 30, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
def ssim_fastai(xb,yb,_):
    "Return `ssim` between `real_A` (first entry of `xb`) and `fake_B` (index 1 on dim 1 of `yb`)."
    # Both tensors go through (x/2 + 0.5)*255 before being compared.
    def rescale(t): return (t/2 + 0.5)*255
    real_A, _real_B = xb
    fake_B = yb[:,1,:,:,:]
    return ssim(rescale(real_A), rescale(fake_B))
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 31, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
def psnr_fastai(xb, yb, _):
    "Return `psnr` between `real_A` (first entry of `xb`) and `fake_B` (index 1 on dim 1 of `yb`)."
    # Both tensors go through (x/2 + 0.5)*255 before being compared.
    def rescale(t): return (t/2 + 0.5)*255
    real_A, _real_B = xb
    fake_B = yb[:,1,:,:,:]
    return psnr(rescale(real_A), rescale(fake_B))
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 32, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
# Build the model with 9 ResNet blocks per generator (the class default is 6) and 3 channels in and out.
cycle_gan = CycleGAN(3,3, gen_blocks=9)
# The loss receives the same model instance; Adam runs with betas=(0.5, 0.999); CycleGANTrainer is
# attached through callback_fns, and the two metrics are wrapped in the AverageMetric defined in this notebook.
learn = Learner(data, cycle_gan, loss_func=CycleGanLoss(cycle_gan), opt_func=partial(optim.Adam, betas=(0.5,0.999)),
                callback_fns=[CycleGANTrainer],metrics=[AverageMetric(ssim_fastai),AverageMetric(psnr_fastai)])
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 33, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "#learn.lr_find()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 34, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "#learn.recorder.plot()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 35, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/html": [ | |
| "<table border=\"1\" class=\"dataframe\">\n", | |
| " <thead>\n", | |
| " <tr style=\"text-align: left;\">\n", | |
| " <th>epoch</th>\n", | |
| " <th>train_loss</th>\n", | |
| " <th>valid_loss</th>\n", | |
| " <th>ssim_fastai</th>\n", | |
| " <th>psnr_fastai</th>\n", | |
| " <th>id_loss</th>\n", | |
| " <th>gen_loss</th>\n", | |
| " <th>cyc_loss</th>\n", | |
| " <th>D_A_loss</th>\n", | |
| " <th>D_B_loss</th>\n", | |
| " <th>time</th>\n", | |
| " </tr>\n", | |
| " </thead>\n", | |
| " <tbody>\n", | |
| " <tr>\n", | |
| " <td>0</td>\n", | |
| " <td>6.561883</td>\n", | |
| " <td>5.135588</td>\n", | |
| " <td>0.473679</td>\n", | |
| " <td>9.387200</td>\n", | |
| " <td>1.884625</td>\n", | |
| " <td>0.796862</td>\n", | |
| " <td>3.880396</td>\n", | |
| " <td>0.324226</td>\n", | |
| " <td>0.321400</td>\n", | |
| " <td>00:27</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <td>1</td>\n", | |
| " <td>5.647283</td>\n", | |
| " <td>4.936441</td>\n", | |
| " <td>0.465910</td>\n", | |
| " <td>10.266727</td>\n", | |
| " <td>1.607140</td>\n", | |
| " <td>0.716910</td>\n", | |
| " <td>3.323231</td>\n", | |
| " <td>0.280169</td>\n", | |
| " <td>0.277287</td>\n", | |
| " <td>00:27</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <td>2</td>\n", | |
| " <td>5.211876</td>\n", | |
| " <td>4.717415</td>\n", | |
| " <td>0.565417</td>\n", | |
| " <td>9.830503</td>\n", | |
| " <td>1.462108</td>\n", | |
| " <td>0.719012</td>\n", | |
| " <td>3.030756</td>\n", | |
| " <td>0.264041</td>\n", | |
| " <td>0.256112</td>\n", | |
| " <td>00:27</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <td>3</td>\n", | |
| " <td>4.929071</td>\n", | |
| " <td>7.111959</td>\n", | |
| " <td>0.502945</td>\n", | |
| " <td>8.873150</td>\n", | |
| " <td>1.369080</td>\n", | |
| " <td>0.718374</td>\n", | |
| " <td>2.841614</td>\n", | |
| " <td>0.286366</td>\n", | |
| " <td>0.260715</td>\n", | |
| " <td>00:28</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <td>4</td>\n", | |
| " <td>4.971952</td>\n", | |
| " <td>4.743098</td>\n", | |
| " <td>0.542775</td>\n", | |
| " <td>9.094391</td>\n", | |
| " <td>1.351116</td>\n", | |
| " <td>0.769224</td>\n", | |
| " <td>2.851610</td>\n", | |
| " <td>0.250340</td>\n", | |
| " <td>0.247486</td>\n", | |
| " <td>00:28</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <td>5</td>\n", | |
| " <td>4.926872</td>\n", | |
| " <td>5.141806</td>\n", | |
| " <td>0.515856</td>\n", | |
| " <td>9.646178</td>\n", | |
| " <td>1.288682</td>\n", | |
| " <td>0.873950</td>\n", | |
| " <td>2.764238</td>\n", | |
| " <td>0.273015</td>\n", | |
| " <td>0.241806</td>\n", | |
| " <td>00:28</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <td>6</td>\n", | |
| " <td>4.840660</td>\n", | |
| " <td>4.774198</td>\n", | |
| " <td>0.530930</td>\n", | |
| " <td>8.995791</td>\n", | |
| " <td>1.263278</td>\n", | |
| " <td>0.897851</td>\n", | |
| " <td>2.679529</td>\n", | |
| " <td>0.246634</td>\n", | |
| " <td>0.240409</td>\n", | |
| " <td>00:28</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <td>7</td>\n", | |
| " <td>4.715111</td>\n", | |
| " <td>5.307607</td>\n", | |
| " <td>0.485308</td>\n", | |
| " <td>8.678684</td>\n", | |
| " <td>1.200815</td>\n", | |
| " <td>0.959975</td>\n", | |
| " <td>2.554320</td>\n", | |
| " <td>0.208429</td>\n", | |
| " <td>0.203911</td>\n", | |
| " <td>00:28</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <td>8</td>\n", | |
| " <td>4.717269</td>\n", | |
| " <td>5.407540</td>\n", | |
| " <td>0.478808</td>\n", | |
| " <td>8.845898</td>\n", | |
| " <td>1.175990</td>\n", | |
| " <td>1.029106</td>\n", | |
| " <td>2.512172</td>\n", | |
| " <td>0.190207</td>\n", | |
| " <td>0.200359</td>\n", | |
| " <td>00:28</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <td>9</td>\n", | |
| " <td>4.771541</td>\n", | |
| " <td>4.586983</td>\n", | |
| " <td>0.521542</td>\n", | |
| " <td>9.332252</td>\n", | |
| " <td>1.168054</td>\n", | |
| " <td>1.083328</td>\n", | |
| " <td>2.520158</td>\n", | |
| " <td>0.220892</td>\n", | |
| " <td>0.198271</td>\n", | |
| " <td>00:28</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <td>10</td>\n", | |
| " <td>4.653002</td>\n", | |
| " <td>5.640585</td>\n", | |
| " <td>0.528441</td>\n", | |
| " <td>9.376346</td>\n", | |
| " <td>1.133995</td>\n", | |
| " <td>1.054805</td>\n", | |
| " <td>2.464202</td>\n", | |
| " <td>0.203103</td>\n", | |
| " <td>0.278971</td>\n", | |
| " <td>00:28</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <td>11</td>\n", | |
| " <td>4.518406</td>\n", | |
| " <td>4.460074</td>\n", | |
| " <td>0.426281</td>\n", | |
| " <td>8.360112</td>\n", | |
| " <td>1.089192</td>\n", | |
| " <td>1.063930</td>\n", | |
| " <td>2.365285</td>\n", | |
| " <td>0.186630</td>\n", | |
| " <td>0.224311</td>\n", | |
| " <td>00:28</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <td>12</td>\n", | |
| " <td>4.445597</td>\n", | |
| " <td>4.410108</td>\n", | |
| " <td>0.473550</td>\n", | |
| " <td>8.817196</td>\n", | |
| " <td>1.074073</td>\n", | |
| " <td>1.043954</td>\n", | |
| " <td>2.327569</td>\n", | |
| " <td>0.207671</td>\n", | |
| " <td>0.253871</td>\n", | |
| " <td>00:28</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <td>13</td>\n", | |
| " <td>4.358193</td>\n", | |
| " <td>4.004363</td>\n", | |
| " <td>0.533533</td>\n", | |
| " <td>9.450353</td>\n", | |
| " <td>1.067293</td>\n", | |
| " <td>0.979354</td>\n", | |
| " <td>2.311545</td>\n", | |
| " <td>0.214134</td>\n", | |
| " <td>0.234769</td>\n", | |
| " <td>00:28</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <td>14</td>\n", | |
| " <td>4.156375</td>\n", | |
| " <td>3.787990</td>\n", | |
| " <td>0.524728</td>\n", | |
| " <td>9.441799</td>\n", | |
| " <td>1.031719</td>\n", | |
| " <td>0.918310</td>\n", | |
| " <td>2.206345</td>\n", | |
| " <td>0.236098</td>\n", | |
| " <td>0.212539</td>\n", | |
| " <td>00:29</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <td>15</td>\n", | |
| " <td>4.053391</td>\n", | |
| " <td>3.995172</td>\n", | |
| " <td>0.527359</td>\n", | |
| " <td>10.342698</td>\n", | |
| " <td>1.005355</td>\n", | |
| " <td>0.885043</td>\n", | |
| " <td>2.162991</td>\n", | |
| " <td>0.240513</td>\n", | |
| " <td>0.207699</td>\n", | |
| " <td>00:28</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <td>16</td>\n", | |
| " <td>4.067924</td>\n", | |
| " <td>3.625066</td>\n", | |
| " <td>0.502665</td>\n", | |
| " <td>9.238791</td>\n", | |
| " <td>1.003777</td>\n", | |
| " <td>0.911058</td>\n", | |
| " <td>2.153088</td>\n", | |
| " <td>0.246938</td>\n", | |
| " <td>0.191794</td>\n", | |
| " <td>00:28</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <td>17</td>\n", | |
| " <td>4.037714</td>\n", | |
| " <td>3.711069</td>\n", | |
| " <td>0.464412</td>\n", | |
| " <td>7.298216</td>\n", | |
| " <td>1.000791</td>\n", | |
| " <td>0.896303</td>\n", | |
| " <td>2.140619</td>\n", | |
| " <td>0.249679</td>\n", | |
| " <td>0.187362</td>\n", | |
| " <td>00:28</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <td>18</td>\n", | |
| " <td>3.857249</td>\n", | |
| " <td>4.345922</td>\n", | |
| " <td>0.544652</td>\n", | |
| " <td>9.388030</td>\n", | |
| " <td>0.948343</td>\n", | |
| " <td>0.891913</td>\n", | |
| " <td>2.016993</td>\n", | |
| " <td>0.257431</td>\n", | |
| " <td>0.198414</td>\n", | |
| " <td>00:28</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <td>19</td>\n", | |
| " <td>3.762412</td>\n", | |
| " <td>4.109817</td>\n", | |
| " <td>0.490418</td>\n", | |
| " <td>9.542933</td>\n", | |
| " <td>0.926382</td>\n", | |
| " <td>0.875323</td>\n", | |
| " <td>1.960707</td>\n", | |
| " <td>0.261851</td>\n", | |
| " <td>0.210432</td>\n", | |
| " <td>00:28</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <td>20</td>\n", | |
| " <td>3.702829</td>\n", | |
| " <td>3.526021</td>\n", | |
| " <td>0.464647</td>\n", | |
| " <td>8.307538</td>\n", | |
| " <td>0.898579</td>\n", | |
| " <td>0.908769</td>\n", | |
| " <td>1.895481</td>\n", | |
| " <td>0.261681</td>\n", | |
| " <td>0.217534</td>\n", | |
| " <td>00:28</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <td>21</td>\n", | |
| " <td>3.622834</td>\n", | |
| " <td>3.706043</td>\n", | |
| " <td>0.502107</td>\n", | |
| " <td>10.158941</td>\n", | |
| " <td>0.875865</td>\n", | |
| " <td>0.904829</td>\n", | |
| " <td>1.842140</td>\n", | |
| " <td>0.240720</td>\n", | |
| " <td>0.208291</td>\n", | |
| " <td>00:28</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <td>22</td>\n", | |
| " <td>3.538983</td>\n", | |
| " <td>3.435079</td>\n", | |
| " <td>0.546980</td>\n", | |
| " <td>9.465912</td>\n", | |
| " <td>0.851039</td>\n", | |
| " <td>0.919320</td>\n", | |
| " <td>1.768624</td>\n", | |
| " <td>0.258890</td>\n", | |
| " <td>0.202507</td>\n", | |
| " <td>00:28</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <td>23</td>\n", | |
| " <td>3.513905</td>\n", | |
| " <td>3.234569</td>\n", | |
| " <td>0.508788</td>\n", | |
| " <td>8.967988</td>\n", | |
| " <td>0.841775</td>\n", | |
| " <td>0.934261</td>\n", | |
| " <td>1.737869</td>\n", | |
| " <td>0.246404</td>\n", | |
| " <td>0.231526</td>\n", | |
| " <td>00:28</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <td>24</td>\n", | |
| " <td>3.515435</td>\n", | |
| " <td>3.423039</td>\n", | |
| " <td>0.495962</td>\n", | |
| " <td>8.955706</td>\n", | |
| " <td>0.829034</td>\n", | |
| " <td>0.963347</td>\n", | |
| " <td>1.723054</td>\n", | |
| " <td>0.217231</td>\n", | |
| " <td>0.240800</td>\n", | |
| " <td>00:28</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <td>25</td>\n", | |
| " <td>3.495776</td>\n", | |
| " <td>3.450833</td>\n", | |
| " <td>0.489034</td>\n", | |
| " <td>9.838599</td>\n", | |
| " <td>0.812738</td>\n", | |
| " <td>1.014734</td>\n", | |
| " <td>1.668304</td>\n", | |
| " <td>0.194369</td>\n", | |
| " <td>0.248123</td>\n", | |
| " <td>00:28</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <td>26</td>\n", | |
| " <td>3.449779</td>\n", | |
| " <td>3.365834</td>\n", | |
| " <td>0.487408</td>\n", | |
| " <td>9.713398</td>\n", | |
| " <td>0.795532</td>\n", | |
| " <td>1.049908</td>\n", | |
| " <td>1.604338</td>\n", | |
| " <td>0.181287</td>\n", | |
| " <td>0.248181</td>\n", | |
| " <td>00:28</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <td>27</td>\n", | |
| " <td>3.393342</td>\n", | |
| " <td>3.238819</td>\n", | |
| " <td>0.534515</td>\n", | |
| " <td>9.728602</td>\n", | |
| " <td>0.766979</td>\n", | |
| " <td>1.090326</td>\n", | |
| " <td>1.536036</td>\n", | |
| " <td>0.172006</td>\n", | |
| " <td>0.235865</td>\n", | |
| " <td>00:28</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <td>28</td>\n", | |
| " <td>3.330395</td>\n", | |
| " <td>3.262840</td>\n", | |
| " <td>0.526032</td>\n", | |
| " <td>9.696510</td>\n", | |
| " <td>0.745259</td>\n", | |
| " <td>1.104066</td>\n", | |
| " <td>1.481069</td>\n", | |
| " <td>0.159359</td>\n", | |
| " <td>0.235580</td>\n", | |
| " <td>00:28</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <td>29</td>\n", | |
| " <td>3.290842</td>\n", | |
| " <td>3.225084</td>\n", | |
| " <td>0.528845</td>\n", | |
| " <td>9.593530</td>\n", | |
| " <td>0.723691</td>\n", | |
| " <td>1.135712</td>\n", | |
| " <td>1.431439</td>\n", | |
| " <td>0.151975</td>\n", | |
| " <td>0.231465</td>\n", | |
| " <td>00:28</td>\n", | |
| " </tr>\n", | |
| " </tbody>\n", | |
| "</table>" | |
| ], | |
| "text/plain": [ | |
| "<IPython.core.display.HTML object>" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "output_type": "display_data" | |
| } | |
| ], | |
| "source": [ | |
| "learn.fit_fa(lr=2e-4,n_epochs=15,n_epochs_decay=15)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 36, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "learn.save('30fit')" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "Let's look at some results using `Learner.show_results`." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "learn.show_results(ds_type=DatasetType.Train, rows=2)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "Now let's go through all the images of the training set and find the ones that are the best converted (according to our critics) or the worst converted." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 38, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "(100, 344)" | |
| ] | |
| }, | |
| "execution_count": 38, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "len(learn.data.train_ds.items),len(learn.data.train_ds.itemsB)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 39, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "def get_batch(filenames, tfms, **kwargs):\n", | |
| " samples = [open_image(fn) for fn in filenames]\n", | |
| " for s in samples: s = s.apply_tfms(tfms, **kwargs)\n", | |
| " batch = torch.stack([s.data for s in samples], 0).cuda()\n", | |
| " return 2. * (batch - 0.5)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 40, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "fnames = learn.data.train_ds.items[:8]" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 41, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "x = get_batch(fnames, get_transforms()[1], size=128)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 42, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "learn.model.eval()\n", | |
| "tfms = get_transforms()[1]\n", | |
| "bs = 16" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 43, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "def get_losses(fnames, gen, crit, bs=16):\n", | |
| " losses_in,losses_out = [],[]\n", | |
| " with torch.no_grad():\n", | |
| " for i in progress_bar(range(0, len(fnames), bs)):\n", | |
| " xb = get_batch(fnames[i:i+bs], tfms, size=128)\n", | |
| " fakes = gen(xb)\n", | |
| " preds_in,preds_out = crit(xb),crit(fakes)\n", | |
| " loss_in = learn.loss_func.crit(preds_in, True,reduction='none')\n", | |
| " loss_out = learn.loss_func.crit(preds_out,True,reduction='none')\n", | |
| " losses_in.append(loss_in.view(loss_in.size(0),-1).mean(1))\n", | |
| " losses_out.append(loss_out.view(loss_out.size(0),-1).mean(1))\n", | |
| " return torch.cat(losses_in),torch.cat(losses_out)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 44, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/html": [ | |
| "\n", | |
| " <div>\n", | |
| " <style>\n", | |
| " /* Turns off some styling */\n", | |
| " progress {\n", | |
| " /* gets rid of default border in Firefox and Opera. */\n", | |
| " border: none;\n", | |
| " /* Needs to be in here for Safari polyfill so background images work as expected. */\n", | |
| " background-size: auto;\n", | |
| " }\n", | |
| " .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n", | |
| " background: #F44336;\n", | |
| " }\n", | |
| " </style>\n", | |
| " <progress value='7' class='' max='7', style='width:300px; height:20px; vertical-align: middle;'></progress>\n", | |
| " 100.00% [7/7 00:03<00:00]\n", | |
| " </div>\n", | |
| " " | |
| ], | |
| "text/plain": [ | |
| "<IPython.core.display.HTML object>" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "output_type": "display_data" | |
| } | |
| ], | |
| "source": [ | |
| "losses_A = get_losses(data.train_ds.x.items, learn.model.G_B, learn.model.D_B)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 45, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/html": [ | |
| "\n", | |
| " <div>\n", | |
| " <style>\n", | |
| " /* Turns off some styling */\n", | |
| " progress {\n", | |
| " /* gets rid of default border in Firefox and Opera. */\n", | |
| " border: none;\n", | |
| " /* Needs to be in here for Safari polyfill so background images work as expected. */\n", | |
| " background-size: auto;\n", | |
| " }\n", | |
| " .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n", | |
| " background: #F44336;\n", | |
| " }\n", | |
| " </style>\n", | |
| " <progress value='22' class='' max='22', style='width:300px; height:20px; vertical-align: middle;'></progress>\n", | |
| " 100.00% [22/22 00:13<00:00]\n", | |
| " </div>\n", | |
| " " | |
| ], | |
| "text/plain": [ | |
| "<IPython.core.display.HTML object>" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "output_type": "display_data" | |
| } | |
| ], | |
| "source": [ | |
| "losses_B = get_losses(data.train_ds.x.itemsB, learn.model.G_A, learn.model.D_A)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 46, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "def show_best(fnames, losses, gen, n=8):\n", | |
| " sort_idx = losses.argsort()\n", | |
| " _,axs = plt.subplots(n//2, 4, figsize=(12,2*n))\n", | |
| " xb = get_batch(fnames[sort_idx][:n], tfms, size=128)\n", | |
| " with torch.no_grad():\n", | |
| " fakes = gen(xb)\n", | |
| " xb,fakes = (1+xb)/2,(1+fakes)/2\n", | |
| " for i in range(n):\n", | |
| " axs.flatten()[2*i].imshow(xb[i].permute(1,2,0).cpu())\n", | |
| " axs.flatten()[2*i].axis('off')\n", | |
| " axs.flatten()[2*i+1].imshow(fakes[i].permute(1,2,0).cpu())\n", | |
| " axs.flatten()[2*i+1].set_title(losses[sort_idx][i].item())\n", | |
| " axs.flatten()[2*i+1].axis('off')" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "show_best(data.train_ds.x.items, losses_A[1].cpu(), learn.model.G_B)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "show_best(data.train_ds.x.itemsB, losses_B[1].cpu(), learn.model.G_A)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "# Save Image Predictions (Fake images)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 49, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "from torch.utils.data import Dataset, DataLoader\n", | |
| "import torchvision\n", | |
| "import glob\n", | |
| "\n", | |
| "class FolderDataset(Dataset):\n", | |
| " def __init__(self, path,transforms=None):\n", | |
| " self.files = glob.glob(path+'/*')\n", | |
| " self.totensor = torchvision.transforms.ToTensor()\n", | |
| " if transforms:\n", | |
| " self.transform = torchvision.transforms.Compose(transforms)\n", | |
| " else:\n", | |
| " self.transform = lambda x: x\n", | |
| " \n", | |
| " def __len__(self):\n", | |
| " return len(self.files)\n", | |
| "\n", | |
| " def __getitem__(self, idx):\n", | |
| " image = PIL.Image.open(self.files[idx % len(self.files)])\n", | |
| " image = self.totensor(image)\n", | |
| " image = self.transform(image)\n", | |
| " return self.files[idx], image\n", | |
| "\n", | |
| "def load_dataset(test_path):\n", | |
| " dataset = FolderDataset(\n", | |
| " path=test_path,\n", | |
| " #transforms=[torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]\n", | |
| " ) \n", | |
| " loader = torch.utils.data.DataLoader(\n", | |
| " dataset,\n", | |
| " batch_size=2,\n", | |
| " num_workers=4,\n", | |
| " shuffle=True\n", | |
| " )\n", | |
| " return loader" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 50, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "import tqdm\n", | |
| "def get_preds_cyclegan(learn,test_path,pred_path,suffix='png'):\n", | |
| " \n", | |
| " assert os.path.exists(test_path)\n", | |
| " \n", | |
| " if not os.path.exists(pred_path):\n", | |
| " os.mkdir(pred_path)\n", | |
| " \n", | |
| " model = learn.model.G_A\n", | |
| " \n", | |
| " test_dl = load_dataset(test_path)\n", | |
| " \n", | |
| " for i, xb in tqdm.tqdm(enumerate(test_dl),total=len(test_dl)):\n", | |
| " fn, im = xb\n", | |
| " preds = (learn.model.G_B(im.cuda())/2 + 0.5)\n", | |
| " for i in range(len(fn)):\n", | |
| " new_fn = os.path.join(pred_path,'.'.join([os.path.basename(fn[i]).split('.')[0]+'_fakeB',suffix])) \n", | |
| " torchvision.utils.save_image(preds[i],new_fn)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 51, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "100%|██████████| 50/50 [00:12<00:00, 3.86it/s]\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "get_preds_cyclegan(learn,str(muse2he_path/'testA'),'./preds')" | |
| ] | |
| } | |
| ], | |
| "metadata": { | |
| "kernelspec": { | |
| "display_name": "Python 3", | |
| "language": "python", | |
| "name": "python3" | |
| }, | |
| "language_info": { | |
| "codemirror_mode": { | |
| "name": "ipython", | |
| "version": 3 | |
| }, | |
| "file_extension": ".py", | |
| "mimetype": "text/x-python", | |
| "name": "python", | |
| "nbconvert_exporter": "python", | |
| "pygments_lexer": "ipython3", | |
| "version": "3.7.6" | |
| } | |
| }, | |
| "nbformat": 4, | |
| "nbformat_minor": 4 | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment