An implementation of RetinaNet.
| { | |
| "cells": [ | |
| { | |
| "cell_type": "markdown", | |
| "id": "a2c3541f-8c5f-490d-8ada-2d9264a72074", | |
| "metadata": {}, | |
| "source": [ | |
| "# RetinaNet Implementation in PyTorch\n", | |
| "\n", | |
| "Implementation of the following paper: [Focal Loss for Dense Object Detection](https://arxiv.org/pdf/1708.02002.pdf)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "cc7a980f-fddb-458f-b48e-fd0f12de3e58", | |
| "metadata": {}, | |
| "source": [ | |
| "## Imports" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 3, | |
| "id": "8f124e21-1f07-4bfd-89e0-fdff09aa4e0d", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "import numpy as np\n", | |
| "import matplotlib.pyplot as plt\n", | |
| "import math\n", | |
| "import copy" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 4, | |
| "id": "6e523509-e024-4df9-82eb-c64df1d60110", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "import torch\n", | |
| "from torch import nn, optim\n", | |
| "from torch.nn import functional as F\n", | |
| "from torch.utils.data import DataLoader" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 5, | |
| "id": "c3bfdb2b-b2a2-4570-93c2-9a9a89a5615e", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "import torchvision\n", | |
| "from torchvision import transforms, datasets\n", | |
| "from torchvision.transforms import functional as FT\n", | |
| "from torchvision.transforms import transforms as T" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 6, | |
| "id": "feb9f59e-1e07-4810-9408-58b23cf27c18", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "from PIL import Image\n", | |
| "import os\n", | |
| "import cv2" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 7, | |
| "id": "f13fa1a6-d707-4f58-9202-9f65fa67a41f", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "from tqdm.notebook import tqdm" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 8, | |
| "id": "8318d1c0-75da-44ab-9f01-521d7b0fd738", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "('1.9.0', '0.10.0')" | |
| ] | |
| }, | |
| "execution_count": 8, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "torch.__version__, torchvision.__version__" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "bd6b3cf8-65e3-436d-bf20-c553dd592f32", | |
| "metadata": {}, | |
| "source": [ | |
| "## Transforms and Utilities" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 9, | |
| "id": "df785a0e-d026-4e2d-b2d6-f95b81db6757", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "class Compose:\n", | |
| " def __init__(self, transforms):\n", | |
| " self.transforms = transforms\n", | |
| "\n", | |
| " def __call__(self, image, target):\n", | |
| " for t in self.transforms:\n", | |
| " image, target = t(image, target)\n", | |
| " return image, target" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 10, | |
| "id": "871e66f6-bc56-4095-90f0-2a5f603fd29e", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "class Normalizer(object):\n", | |
| "\n", | |
| " def __init__(self):\n", | |
| " self.mean = [0.485, 0.456, 0.406]\n", | |
| " self.std = [0.229, 0.224, 0.225]\n", | |
| " self.normalize = T.Compose([T.Normalize(mean=self.mean, std=self.std)])\n", | |
| "\n", | |
| " def __call__(self, image, target):\n", | |
| "\n", | |
| " return self.normalize(image), target" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 11, | |
| "id": "1d953822-6c79-4511-bd1c-19a84633f703", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "class Resize(object):\n", | |
| " def __init__(self, size=400):\n", | |
| " self.size = size\n", | |
| " def __call__(self, img, target):\n", | |
| " size = self.size\n", | |
| " boxes = [t['bbox'] for t in target]\n", | |
| " w, h = img.size\n", | |
| " if isinstance(size, int):\n", | |
| " size_min = min(w,h)\n", | |
| " size_max = max(w,h)\n", | |
| " sw = sh = float(size) / size_min\n", | |
| " if sw * size_max > 800:\n", | |
| " sw = sh = float(800) / size_max\n", | |
| " ow = int(w * sw + 0.5)\n", | |
| " oh = int(h * sh + 0.5)\n", | |
| " else:\n", | |
| " ow, oh = size\n", | |
| " sw = float(ow) / w\n", | |
| " sh = float(oh) / h\n", | |
| " boxes = (torch.FloatTensor(boxes)*torch.Tensor([sw,sh,sw,sh])).tolist()\n", | |
| " for t in range(len(target)):\n", | |
| " target[t]['bbox'] = boxes[t]\n", | |
| " return img.resize((ow,oh), Image.BILINEAR), target" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 12, | |
| "id": "b3a9d7e4-9fee-4ed0-a205-892721488fb8", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "class ToTensor(nn.Module):\n", | |
| " def forward(\n", | |
| " self, image, target = None\n", | |
| " ):\n", | |
| " image = FT.pil_to_tensor(image)\n", | |
| " image = FT.convert_image_dtype(image)\n", | |
| " return image, target" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 13, | |
| "id": "246f31f3-e29a-4b23-9f14-7c15468f94e8", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "class PILToTensor(nn.Module):\n", | |
| " def forward(\n", | |
| " self, image, target = None\n", | |
| " ):\n", | |
| " image = FT.pil_to_tensor(image)\n", | |
| " return image, target" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "0a2ecfe0-6bda-4ff0-8de9-d8f5cfacb5ba", | |
| "metadata": {}, | |
| "source": [ | |
| "## Dataset" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 14, | |
| "id": "bb23610c-6a83-4236-9cae-0cfbad97d9de", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "#### COLAB LOADER ####\n", | |
| "# !curl -L \"https://public.roboflow.com/ds/L6PD1uTSPF?key=Gq3tCeIqHA\" > roboflow.zip; unzip roboflow.zip; rm roboflow.zip\n", | |
| "# Use for colab only #\n", | |
| "######################" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 15, | |
| "id": "57619344-33c5-40fc-8d34-ef7529487d00", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "from pycocotools.coco import COCO" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 16, | |
| "id": "c7bab354-9b31-497d-834c-1ebf5d0ecf9f", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "dataset_path = \"/Volumes/Samsung_T5/Documents/MachineLearning/machine_learning_notebooks/pytorch/aquarium-dataset/Aquarium Combined/\"" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 17, | |
| "id": "96370ec5-bc2d-475c-9b94-8a228cf3c982", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "loading annotations into memory...\n", | |
| "Done (t=0.02s)\n", | |
| "creating index...\n", | |
| "index created!\n" | |
| ] | |
| }, | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "8" | |
| ] | |
| }, | |
| "execution_count": 17, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "coc = COCO(os.path.join(dataset_path, \"train\", \"_annotations.coco.json\"))\n", | |
| "categories = coc.cats\n", | |
| "n_classes = len(categories.keys())\n", | |
| "n_classes" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 18, | |
| "id": "12a71e96-e1cf-41ea-b8d0-5ce1b14d3ff7", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "def xyxy_2_xywh(boxes):\n", | |
| " a = torch.FloatTensor(boxes[:,:2])\n", | |
| " b = torch.FloatTensor(boxes[:,2:])\n", | |
| " return torch.cat([(a+b)/2,b-a+1], 1)\n", | |
| " \n", | |
| "def xywh_2_xyxy(boxes):\n", | |
| " a = torch.FloatTensor(boxes[:,:2])\n", | |
| " b = torch.FloatTensor(boxes[:,2:])\n", | |
| " return torch.cat([a-b/2,a+b/2], 1)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 19, | |
| "id": "826dfb96-824a-44e5-9268-249c0756e7ae", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "def box_nms(bboxes, scores, threshold=0.5, mode='union'):\n", | |
| " \n", | |
| " x1 = bboxes[:,0]\n", | |
| " y1 = bboxes[:,1]\n", | |
| " x2 = bboxes[:,2]\n", | |
| " y2 = bboxes[:,3]\n", | |
| "\n", | |
| " areas = (x2-x1+1) * (y2-y1+1)\n", | |
| " _, order = scores.sort(0, descending=True)\n", | |
| "\n", | |
| " keep = []\n", | |
| " while order.numel() > 0:\n", | |
| " if order.numel() == 1:\n", | |
| " keep.append(order.item())\n", | |
| " break\n", | |
| " \n", | |
| " i = order[0]\n", | |
| " keep.append(i)\n", | |
| "\n", | |
| " xx1 = x1[order[1:]].clamp(min=x1[i])\n", | |
| " yy1 = y1[order[1:]].clamp(min=y1[i])\n", | |
| " xx2 = x2[order[1:]].clamp(max=x2[i])\n", | |
| " yy2 = y2[order[1:]].clamp(max=y2[i])\n", | |
| "\n", | |
| " w = (xx2-xx1+1).clamp(min=0)\n", | |
| " h = (yy2-yy1+1).clamp(min=0)\n", | |
| " inter = w*h\n", | |
| "\n", | |
| " if mode == 'union':\n", | |
| " ovr = inter / (areas[i] + areas[order[1:]] - inter)\n", | |
| " elif mode == 'min':\n", | |
| " ovr = inter / areas[order[1:]].clamp(max=areas[i])\n", | |
| " else:\n", | |
| " raise TypeError('Unknown nms mode: %s.' % mode)\n", | |
| "\n", | |
| " ids = (ovr<=threshold).nonzero().squeeze()\n", | |
| " if ids.numel() == 0:\n", | |
| " break\n", | |
| " order = order[ids+1]\n", | |
| " return torch.LongTensor(keep)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 20, | |
| "id": "88ac1227-a417-403a-8cac-6b04faadc851", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "def iou(box1, box2, order=\"xyxy\"):\n", | |
| " if order == \"xywh\":\n", | |
| " box1 = xywh_2_xyxy(box1)\n", | |
| " box2 = xywh_2_xyxy(box2)\n", | |
| " N = box1.size(0)\n", | |
| " M = box2.size(0)\n", | |
| "\n", | |
| " lt = torch.max(box1[:,None,:2], box2[:,:2]) # [N,M,2]\n", | |
| " rb = torch.min(box1[:,None,2:], box2[:,2:]) # [N,M,2]\n", | |
| "\n", | |
| " wh = (rb-lt+1).clamp(min=0) # [N,M,2]\n", | |
| " inter = wh[:,:,0] * wh[:,:,1] # [N,M]\n", | |
| "\n", | |
| " area1 = (box1[:,2]-box1[:,0]+1) * (box1[:,3]-box1[:,1]+1) # [N,]\n", | |
| " area2 = (box2[:,2]-box2[:,0]+1) * (box2[:,3]-box2[:,1]+1) # [M,]\n", | |
| " iou = inter / (area1[:,None] + area2 - inter)\n", | |
| " return iou" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "f313d08d-9c3b-400b-9b8b-28cbd5ccdec9", | |
| "metadata": {}, | |
| "source": [ | |
| "### Anchor Boxes\n", | |
| "\n", | |
| "\"*Anchor boxes have areas of $32^2$ to $512^2$ on pyramid levels $P_3$ to $P_7$.*\" (Page 4, Focal Loss for Dense Object Detection)\n", | |
| "\n", | |
| "- Aspect ratios: $\\{1:2, 1:1, 2:1\\}$, translates to `[0.5, 1, 2]` in python\n", | |
| "- Scales: $\\{2^0, 2^{1/3}, 2^{2/3}\\}$\n", | |
| "\n", | |
| "There should be a total of $A=9$ anchors per level" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 21, | |
| "id": "359ef414-00f1-487d-9269-6d2316b10e80", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "class AnchorBox():\n", | |
| " \"\"\"\n", | |
| " Generate anchor boxes for level 3 to level 8\n", | |
| " \"\"\"\n", | |
| " def __init__(self):\n", | |
| " self.ratios = [0.5, 1, 2]\n", | |
| " self.scales = [1, 2**(1/3), 2**(2/3)]\n", | |
| " \n", | |
| " self.A = len(self.ratios) * len(self.scales) # number of anchors (from paper)\n", | |
| " self.areas = [x**2 for x in [32, 64, 128, 256, 512]] # P3, P4, P5, P6, P7\n", | |
| " self.strides = [2 ** i for i in range(3, 8)] # Each layer's feature map is 2^l smaller than the input\n", | |
| " self.anchor_dims = self._anchor_dims()\n", | |
| " ## for feature map sizes\n", | |
| " \n", | |
| " def _meshgrid(self, x, y, row_major=True):\n", | |
| " a = torch.arange(0,x)\n", | |
| " b = torch.arange(0,y)\n", | |
| " xx = a.repeat(y).view(-1,1)\n", | |
| " yy = b.view(-1,1).repeat(1,x).view(-1,1)\n", | |
| " return torch.cat([xx,yy],1) if row_major else torch.cat([yy,xx],1)\n", | |
| " \n", | |
| " def _anchor_dims(self):\n", | |
| " anchor_dims = []\n", | |
| " for area in self.areas:\n", | |
| " for ratio in self.ratios:\n", | |
| " anchor_height = math.sqrt(area / ratio)\n", | |
| " anchor_width = area / anchor_height\n", | |
| " \n", | |
| " for scale in self.scales:\n", | |
| " anchor_width = anchor_width * scale\n", | |
| " anchor_height = anchor_height * scale\n", | |
| " anchor_dims.append([anchor_width, anchor_height])\n", | |
| " return torch.FloatTensor(anchor_dims).view(len(self.areas), -1, 2)\n", | |
| " \n", | |
| " def generate_anchor_boxes(self, input_size):\n", | |
| " \"\"\"\n", | |
| " Generates Anchor Boxes\n", | |
| " \n", | |
| " input_size: torch.Tensor: (w, h)\n", | |
| " \"\"\"\n", | |
| " \n", | |
| " num_feature_maps = len(self.areas)\n", | |
| " feature_map_sizes = [(input_size / stride).ceil() for stride in self.strides] # calculating feature map sizes of p3 to p7\n", | |
| " boxes = []\n", | |
| " for i in range(num_feature_maps):\n", | |
| " fm_size = feature_map_sizes[i]\n", | |
| " grid_size = input_size / fm_size\n", | |
| " fm_w, fm_h = int(fm_size[0]), int(fm_size[1])\n", | |
| " xy = self._meshgrid(fm_w,fm_h) + 0.5 # [fm_h*fm_w, 2]\n", | |
| " xy = (xy*grid_size).view(fm_h,fm_w,1,2).expand(fm_h,fm_w,9,2)\n", | |
| " wh = self.anchor_dims[i].view(1,1,9,2).expand(fm_h,fm_w,9,2)\n", | |
| " box = torch.cat([xy,wh], 3) # [x,y,w,h]\n", | |
| " boxes.append(box.view(-1,4))\n", | |
| " return torch.cat(boxes, 0)" | |
| ] | |
| }, | |
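| { | |
| "cell_type": "markdown", | |
| "id": "added-anchor-sanity-md", | |
| "metadata": {}, | |
| "source": [ | |
| "A quick sanity check (an added illustration, not part of the original gist): for a square input there should be `ceil(size/stride)**2 * A` anchors per level, i.e. 9 anchors at every spatial location of $P_3$ through $P_7$. The 256x256 input size below is an arbitrary choice." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "id": "added-anchor-sanity-code", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "## Added sanity check with a hypothetical 256x256 input ##\n", | |
| "anchor_demo = AnchorBox()\n", | |
| "demo_anchors = anchor_demo.generate_anchor_boxes(torch.Tensor([256., 256.]))\n", | |
| "expected = sum(math.ceil(256 / s) ** 2 * anchor_demo.A for s in anchor_demo.strides)\n", | |
| "demo_anchors.shape, expected # the first dimension of demo_anchors should equal `expected`" | |
| ] | |
| }, | |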
| { | |
| "cell_type": "code", | |
| "execution_count": 22, | |
| "id": "2d71e06b-3ce1-41f1-afae-034a62bc98b6", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "class Encoder:\n", | |
| " def __init__(self):\n", | |
| " self.anchor_box = AnchorBox()\n", | |
| " def encode(self, boxes, labels, input_size):\n", | |
| " input_size = torch.Tensor([input_size,input_size]) if isinstance(input_size, int) \\\n", | |
| " else torch.Tensor(input_size)\n", | |
| " anchor_boxes = self.anchor_box.generate_anchor_boxes(input_size)\n", | |
| "# boxes = xyxy_2_xywh(boxes)\n", | |
| " boxes = torch.FloatTensor(boxes)\n", | |
| " \n", | |
| " ious = iou(anchor_boxes, boxes, order=\"xywh\")\n", | |
| " max_ious, max_ids = ious.max(1)\n", | |
| " boxes = boxes[max_ids]\n", | |
| " \n", | |
| " loc_xy = (boxes[:,:2]-anchor_boxes[:,:2]) / anchor_boxes[:,2:]\n", | |
| " loc_wh = torch.log(boxes[:,2:]/anchor_boxes[:,2:])\n", | |
| " loc_targets = torch.cat([loc_xy,loc_wh], 1)\n", | |
| " cls_targets = 1 + labels[max_ids]\n", | |
| "\n", | |
| " cls_targets[max_ious<0.5] = 0\n", | |
| " ignore = (max_ious>0.4) & (max_ious<0.5) # ignore ious between [0.4,0.5]\n", | |
| " cls_targets[ignore] = -1 # for now just mark ignored to -1\n", | |
| " return loc_targets, cls_targets\n", | |
| " \n", | |
| " def decode(self, loc_preds, cls_preds, input_size):\n", | |
| " input_size = torch.Tensor([input_size,input_size]) if isinstance(input_size, int) else torch.Tensor(input_size)\n", | |
| " \n", | |
| " CLS_THRESH = 0.5\n", | |
| " NMS_THRESH = 0.5\n", | |
| " anchor_boxes = self.anchor_box.generate_anchor_boxes(input_size)\n", | |
| "\n", | |
| " loc_xy = loc_preds[:,:2]\n", | |
| " loc_wh = loc_preds[:,2:]\n", | |
| "\n", | |
| " xy = loc_xy * anchor_boxes[:,2:] + anchor_boxes[:,:2]\n", | |
| " wh = loc_wh.exp() * anchor_boxes[:,2:]\n", | |
| " boxes = torch.cat([xy-wh/2, xy+wh/2], 1)\n", | |
| "\n", | |
| " score, labels = cls_preds.sigmoid().max(1)\n", | |
| " ids = score > CLS_THRESH\n", | |
| " ids = ids.nonzero().squeeze()\n", | |
| " keep = box_nms(boxes[ids], score[ids], threshold=NMS_THRESH)\n", | |
| " return boxes[ids][keep], labels[ids][keep]" | |
| ] | |
| }, | |
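| { | |
| "cell_type": "markdown", | |
| "id": "added-encoder-demo-md", | |
| "metadata": {}, | |
| "source": [ | |
| "Another added illustration (not in the original notebook): encode a single made-up box and label. The anchors are generated in center-size `[cx, cy, w, h]` form, so the dummy box follows the same convention; `encode` should return one location target and one class target per anchor." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "id": "added-encoder-demo-code", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "## Added illustration with a made-up box and label ##\n", | |
| "enc_demo = Encoder()\n", | |
| "demo_boxes = [[100., 100., 50., 80.]] # one box as [cx, cy, w, h]\n", | |
| "demo_labels = torch.LongTensor([2]) # arbitrary class index\n", | |
| "demo_loc_t, demo_cls_t = enc_demo.encode(demo_boxes, demo_labels, 256)\n", | |
| "demo_loc_t.shape, demo_cls_t.shape # [num_anchors, 4] and [num_anchors]" | |
| ] | |
| }, | |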
| { | |
| "cell_type": "code", | |
| "execution_count": 53, | |
| "id": "91eb2fc1-4fe1-4558-8742-619b89286361", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "class AquariumDetection(datasets.VisionDataset):\n", | |
| " def __init__(\n", | |
| " self,\n", | |
| " root: str,\n", | |
| " split = \"train\",\n", | |
| " transform= None,\n", | |
| " target_transform = None,\n", | |
| " transforms = None,\n", | |
| " ) -> None:\n", | |
| " super().__init__(root, transforms, transform, target_transform)\n", | |
| " self.split = split\n", | |
| " self.coco = COCO(os.path.join(root, split, \"_annotations.coco.json\"))\n", | |
| " self.ids = list(sorted(self.coco.imgs.keys()))\n", | |
| " self.ids = [id for id in self.ids if (len(self._load_target(id)) > 0)]\n", | |
| "\n", | |
| " def _load_image(self, id: int) -> Image.Image:\n", | |
| " path = self.coco.loadImgs(id)[0][\"file_name\"]\n", | |
| " img = Image.open(os.path.join(self.root, self.split, path)).convert(\"RGB\")\n", | |
| " return img\n", | |
| "\n", | |
| " def _load_target(self, id: int):\n", | |
| " return self.coco.loadAnns(self.coco.getAnnIds(id))\n", | |
| "\n", | |
| " def __getitem__(self, index: int):\n", | |
| " id = self.ids[index]\n", | |
| " image = self._load_image(id)\n", | |
| " target = copy.deepcopy(self._load_target(id))\n", | |
| "\n", | |
| " if self.transforms is not None:\n", | |
| " image, target = self.transforms(image, target)\n", | |
| " \n", | |
| " annot = [t[\"bbox\"] + [t[\"category_id\"]] for t in target]\n", | |
| "\n", | |
| " return image, annot\n", | |
| "\n", | |
| "\n", | |
| " def __len__(self) -> int:\n", | |
| " return len(self.ids)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 54, | |
| "id": "3701d805-c03a-410c-a51a-0f7305d1cf35", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "def collate_fn(batch):\n", | |
| " \"\"\"\n", | |
| " The images in the dataset will be of different sizes. This function takes the images and pads them. Then we encode the images.\n", | |
| " \"\"\"\n", | |
| " imgs = [x[0] for x in batch]\n", | |
| " annots = np.array([x[1] for x in batch], dtype=object)\n", | |
| "\n", | |
| " widths = [int(s.shape[1]) for s in imgs]\n", | |
| " heights = [int(s.shape[2]) for s in imgs]\n", | |
| " batch_size = len(imgs)\n", | |
| "\n", | |
| " max_width = np.array(widths).max()\n", | |
| " max_height = np.array(heights).max()\n", | |
| "\n", | |
| " padded_imgs = torch.zeros(batch_size, max_width, max_height, 3)\n", | |
| "\n", | |
| " for i in range(batch_size):\n", | |
| " img = imgs[i]\n", | |
| " padded_imgs[i, :int(img.shape[1]), :int(img.shape[2]), :] = img.permute(1, 2, 0)\n", | |
| " padded_imgs = padded_imgs.permute(0, 3, 1, 2)\n", | |
| " \n", | |
| " ## Encode ##\n", | |
| " encoder = Encoder()\n", | |
| " loc_targets = []\n", | |
| " cls_targets = []\n", | |
| " for i in range(len(imgs)):\n", | |
| " annot = annots[i]\n", | |
| " boxes = np.array(annot)[:, 0:4]\n", | |
| " labels = np.array(annot)[:, 4]\n", | |
| " image = padded_imgs[i]\n", | |
| " loc_target, cls_target = encoder.encode(boxes, labels, (image.shape[1], image.shape[2]))\n", | |
| " loc_targets.append(torch.FloatTensor(loc_target))\n", | |
| " cls_targets.append(torch.FloatTensor(cls_target))\n", | |
| " return {'img': padded_imgs, 'loc_targets': torch.stack(loc_targets), 'cls_targets': torch.stack(cls_targets)}" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 55, | |
| "id": "351a5ec0-56e9-4c56-a1b0-7adb6543e511", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "def get_transform(train):\n", | |
| " transforms = []\n", | |
| " transforms.append(Resize(size=300))\n", | |
| " transforms.append(ToTensor())\n", | |
| " transforms.append(Normalizer())\n", | |
| " return Compose(transforms)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 56, | |
| "id": "c693c3ab-0176-4c06-b2a7-309d24d8bb31", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "loading annotations into memory...\n", | |
| "Done (t=0.02s)\n", | |
| "creating index...\n", | |
| "index created!\n", | |
| "loading annotations into memory...\n", | |
| "Done (t=0.00s)\n", | |
| "creating index...\n", | |
| "index created!\n", | |
| "loading annotations into memory...\n", | |
| "Done (t=0.00s)\n", | |
| "creating index...\n", | |
| "index created!\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "train_dataset = AquariumDetection(root=dataset_path, transforms=get_transform(True))\n", | |
| "val_dataset = AquariumDetection(root=dataset_path, split=\"valid\", transforms=get_transform(False))\n", | |
| "test_dataset = AquariumDetection(root=dataset_path, split=\"test\", transforms=get_transform(False))" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 57, | |
| "id": "6a890ca3-b3a6-4f6a-a4f3-6fdf3a518cb4", | |
| "metadata": { | |
| "tags": [] | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "train_loader = DataLoader(train_dataset, batch_size=8, collate_fn=collate_fn)\n", | |
| "val_loader = DataLoader(val_dataset, batch_size=8, collate_fn=collate_fn)\n", | |
| "test_loader = DataLoader(test_dataset, batch_size=8, collate_fn=collate_fn)" | |
| ] | |
| }, | |
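| { | |
| "cell_type": "markdown", | |
| "id": "added-batch-inspect-md", | |
| "metadata": {}, | |
| "source": [ | |
| "Added check (not in the original run): pull one batch through the loader and look at the padded image tensor and the encoded targets that `collate_fn` produces." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "id": "added-batch-inspect-code", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "## Added inspection of one collated batch ##\n", | |
| "sample_batch = next(iter(train_loader))\n", | |
| "sample_batch['img'].shape, sample_batch['loc_targets'].shape, sample_batch['cls_targets'].shape" | |
| ] | |
| }, | |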
| { | |
| "cell_type": "code", | |
| "execution_count": 60, | |
| "id": "a5e86435-ff29-4cbe-bc47-078dc15c8de7", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "(56, 16, 8)" | |
| ] | |
| }, | |
| "execution_count": 60, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "len(train_loader), len(val_loader), len(test_loader)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 59, | |
| "id": "d9332820-f53f-4a3a-b6cc-aa8c0df89596", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "for i in range(len(train_dataset)):\n", | |
| " _ = train_dataset[i]" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "d4cf6818-a90d-435c-b88c-8c6b07f03248", | |
| "metadata": { | |
| "tags": [] | |
| }, | |
| "source": [ | |
| "## Retinanet Implementation" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 167, | |
| "id": "ad44cca1-4e59-4058-9ddf-e9743650fb58", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):\n", | |
| " \"\"\"3x3 convolution with padding\"\"\"\n", | |
| " return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,\n", | |
| " padding=dilation, groups=groups, bias=False, dilation=dilation)\n", | |
| "\n", | |
| "\n", | |
| "def conv1x1(in_planes, out_planes, stride=1):\n", | |
| " \"\"\"1x1 convolution\"\"\"\n", | |
| " return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 168, | |
| "id": "cb82dba8-cc96-49f0-aba1-fbf83775b8e8", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "class Bottleneck(nn.Module):\n", | |
| " expansion = 4\n", | |
| " def __init__(self, inplanes, planes, stride=1, groups=1,\n", | |
| " base_width=64, dilation=1):\n", | |
| " super(Bottleneck, self).__init__()\n", | |
| " norm_layer = nn.BatchNorm2d\n", | |
| " width = int(planes * (base_width / 64.)) * groups\n", | |
| " self.conv1 = conv1x1(inplanes, width)\n", | |
| " self.bn1 = norm_layer(width)\n", | |
| " self.conv2 = conv3x3(width, width, stride, groups, dilation)\n", | |
| " self.bn2 = norm_layer(width)\n", | |
| " self.conv3 = conv1x1(width, planes * self.expansion)\n", | |
| " self.bn3 = norm_layer(planes * self.expansion)\n", | |
| " self.relu = nn.ReLU(inplace=True)\n", | |
| " \n", | |
| " self.downsample = nn.Sequential()\n", | |
| " if stride != 1 or inplanes != self.expansion*planes:\n", | |
| " self.downsample = nn.Sequential(\n", | |
| " nn.Conv2d(inplanes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),\n", | |
| " nn.BatchNorm2d(self.expansion*planes)\n", | |
| " )\n", | |
| " \n", | |
| " self.stride = stride\n", | |
| "\n", | |
| " def forward(self, x):\n", | |
| " identity = x\n", | |
| "\n", | |
| " out = self.conv1(x)\n", | |
| " out = self.bn1(out)\n", | |
| " out = self.relu(out)\n", | |
| "\n", | |
| " out = self.conv2(out)\n", | |
| " out = self.bn2(out)\n", | |
| " out = self.relu(out)\n", | |
| "\n", | |
| " out = self.conv3(out)\n", | |
| " out = self.bn3(out)\n", | |
| "\n", | |
| " identity = self.downsample(x)\n", | |
| "\n", | |
| " out += identity\n", | |
| " out = self.relu(out)\n", | |
| "\n", | |
| " return out\n" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 169, | |
| "id": "5ca298d4-9375-41c1-b66f-8d60640b5581", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "class FPN(nn.Module):\n", | |
| " def __init__(self, block, num_blocks):\n", | |
| " super(FPN, self).__init__()\n", | |
| " self.in_planes = 64\n", | |
| " \n", | |
| " self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)\n", | |
| " self.bn1 = nn.BatchNorm2d(64)\n", | |
| " \n", | |
| " self.conv2 = self._make_layer(block, 64, num_blocks=num_blocks[0], stride=1)\n", | |
| " self.conv3 = self._make_layer(block, 128, num_blocks=num_blocks[1], stride=2)\n", | |
| " self.conv4 = self._make_layer(block, 256, num_blocks=num_blocks[2], stride=2)\n", | |
| " self.conv5 = self._make_layer(block, 512, num_blocks=num_blocks[3], stride=2)\n", | |
| " \n", | |
| " self.conv6 = nn.Conv2d(2048, 256, kernel_size=3, stride=2, padding=1)\n", | |
| " self.conv7 = nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=1)\n", | |
| " \n", | |
| " ## lateral layers ##\n", | |
| " self.lat1 = nn.Conv2d(2048, 256, kernel_size=1, stride=1, padding=0)\n", | |
| " self.lat2 = nn.Conv2d(1024, 256, kernel_size=1, stride=1, padding=0)\n", | |
| " self.lat3 = nn.Conv2d(512, 256, kernel_size=1, stride=1, padding=0)\n", | |
| " \n", | |
| " ## top-down layers ##\n", | |
| " self.topdown1 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)\n", | |
| " self.topdown2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)\n", | |
| " \n", | |
| " self.relu = nn.ReLU()\n", | |
| " \n", | |
| " def _upsample_and_add(self, x, y):\n", | |
| " _,_,H,W = y.size()\n", | |
| " return F.upsample(x, size=(H,W), mode='bilinear') + y\n", | |
| " \n", | |
| " def _make_layer(self, block, planes, num_blocks, stride):\n", | |
| " strides = [stride] + [1]*(num_blocks-1)\n", | |
| " layers = []\n", | |
| " for stride in strides:\n", | |
| " layers.append(block(self.in_planes, planes, stride))\n", | |
| " self.in_planes = planes * block.expansion\n", | |
| " return nn.Sequential(*layers)\n", | |
| " \n", | |
| " def forward(self, x):\n", | |
| " #bottom up\n", | |
| " c1 = self.relu(self.bn1(self.conv1(x)))\n", | |
| " c1 = F.max_pool2d(c1, kernel_size=3, stride=2, padding=1)\n", | |
| " c2 = self.conv2(c1)\n", | |
| " c3 = self.conv3(c2)\n", | |
| " c4 = self.conv4(c3)\n", | |
| " c5 = self.conv5(c4)\n", | |
| " p6 = self.conv6(c5)\n", | |
| " p7 = self.conv7(p6)\n", | |
| " p5 = self.lat1(c5)\n", | |
| " p4 = self._upsample_and_add(p5, self.lat2(c4))\n", | |
| " p4 = self.topdown1(p4)\n", | |
| " p3 = self._upsample_and_add(p4, self.lat3(c3))\n", | |
| " p3 = self.topdown2(p3)\n", | |
| " return p3, p4, p5, p6, p7" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 170, | |
| "id": "5d228fae-537c-4ac6-a3b5-1e838112eb3f", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "class ClassificationHead(nn.Module):\n", | |
| " def __init__(self, n_classes=8):\n", | |
| " super(ClassificationHead, self).__init__()\n", | |
| " self.n_anchors = 9\n", | |
| " self.n_classes = n_classes\n", | |
| " \n", | |
| " self.convnet = nn.Sequential(*[\n", | |
| " nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),\n", | |
| " nn.ReLU(True),\n", | |
| " nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),\n", | |
| " nn.ReLU(True),\n", | |
| " nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),\n", | |
| " nn.ReLU(True),\n", | |
| " nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),\n", | |
| " nn.ReLU(True),\n", | |
| " nn.Conv2d(256, self.n_classes*self.n_anchors, kernel_size=3, stride=1, padding=1) # KA\n", | |
| " ])\n", | |
| " def forward(self, x):\n", | |
| " return self.convnet(x)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 171, | |
| "id": "4c57cc78-6615-45a3-93ce-242157536e0f", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "class RegressionHead(nn.Module):\n", | |
| " def __init__(self):\n", | |
| " super(RegressionHead, self).__init__()\n", | |
| " self.n_anchors = 9\n", | |
| " self.convnet = nn.Sequential(*[\n", | |
| " nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),\n", | |
| " nn.ReLU(True),\n", | |
| " nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),\n", | |
| " nn.ReLU(True),\n", | |
| " nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),\n", | |
| " nn.ReLU(True),\n", | |
| " nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),\n", | |
| " nn.ReLU(True),\n", | |
| " nn.Conv2d(256, 4*self.n_anchors, kernel_size=3, stride=1, padding=1) # KA\n", | |
| " ])\n", | |
| " def forward(self, x):\n", | |
| " return self.convnet(x)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 172, | |
| "id": "86f363cc-7c90-4be0-8522-89cc96d087a2", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "class RetinaNet(nn.Module):\n", | |
| " def __init__(self, n_classes=8):\n", | |
| " super(RetinaNet, self).__init__()\n", | |
| " \n", | |
| " self.fpn = FPN(Bottleneck, [3, 4, 6, 3])\n", | |
| " \n", | |
| " self.num_classes = n_classes\n", | |
| " \n", | |
| " self.classification_head = ClassificationHead(n_classes = self.num_classes) # class head\n", | |
| " self.regression_head = RegressionHead() # loc head\n", | |
| " def forward(self, x):\n", | |
| " feature_maps = self.fpn(x) #p3, p4, p5, p6, p7\n", | |
| " \n", | |
| " loc_preds = []\n", | |
| " cls_preds = []\n", | |
| " \n", | |
| " for fmap in feature_maps:\n", | |
| " loc_pred = self.regression_head(fmap)\n", | |
| " cls_pred = self.classification_head(fmap)\n", | |
| " \n", | |
| " loc_pred = loc_pred.permute(0,2,3,1).contiguous().view(x.size(0),-1,4) \n", | |
| " cls_pred = cls_pred.permute(0,2,3,1).contiguous().view(x.size(0),-1,self.num_classes) \n", | |
| " \n", | |
| " loc_preds.append(loc_pred)\n", | |
| " cls_preds.append(cls_pred)\n", | |
| " \n", | |
| " return torch.cat(loc_preds, 1), torch.cat(cls_preds, 1)\n", | |
| " \n", | |
| " def freeze_bn(self):\n", | |
| " '''Freeze BatchNorm layers.'''\n", | |
| " for layer in self.modules():\n", | |
| " if isinstance(layer, nn.BatchNorm2d):\n", | |
| " layer.eval()\n", | |
| " " | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 173, | |
| "id": "767203b7-fbdb-4708-a942-42022b019bc8", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "net = RetinaNet()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 176, | |
| "id": "233b4698-e07b-4430-9c88-e3b1f1106e75", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "(torch.Size([8, 30231, 4]), torch.Size([8, 30231, 8]))" | |
| ] | |
| }, | |
| "execution_count": 176, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "batch = next(iter(train_loader))\n", | |
| "loc_preds, cls_preds = net(batch['img'])\n", | |
| "loc_preds.shape, cls_preds.shape # The 2nd number should be the same as the number of anchors per image." | |
| ] | |
| }, | |
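| { | |
| "cell_type": "markdown", | |
| "id": "added-anchor-count-check-md", | |
| "metadata": {}, | |
| "source": [ | |
| "Added cross-check (not in the original notebook): the second dimension of the predictions should equal the number of anchor boxes generated for the padded input size, since the heads predict 4 regression values and `n_classes` scores per anchor." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "id": "added-anchor-count-check-code", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "## Added cross-check between head outputs and anchor count ##\n", | |
| "pad_h, pad_w = batch['img'].shape[2], batch['img'].shape[3]\n", | |
| "n_anchors = AnchorBox().generate_anchor_boxes(torch.Tensor([float(pad_w), float(pad_h)])).size(0)\n", | |
| "loc_preds.size(1) == n_anchors, n_anchors" | |
| ] | |
| }, | |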
| { | |
| "cell_type": "code", | |
| "execution_count": 179, | |
| "id": "6abb61c1-03e4-4530-b56b-e401965c5917", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "encoder = Encoder()\n", | |
| "# _ = encoder.decode(loc_preds[0], cls_preds[0], tuple(batch['img'].shape[2:]))\n", | |
| "## Ensure this cell just runs ##" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "414de16f-fee8-4d7e-87c3-2f9484911786", | |
| "metadata": {}, | |
| "source": [ | |
| "## Focal Loss\n", | |
| "\n", | |
| "An extension of Cross Entropy\n", | |
| "$$\n", | |
| "FL(p_t) = -\\alpha(1-p_t)^{\\gamma}log(p_t)\n", | |
| "$$" | |
| ] | |
| }, | |
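| { | |
| "cell_type": "markdown", | |
| "id": "added-focal-loss-demo-md", | |
| "metadata": {}, | |
| "source": [ | |
| "Added numerical illustration of the formula above: with the paper's defaults $\\alpha=0.25$ and $\\gamma=2$, an easy example ($p_t = 0.9$) is down-weighted far more, relative to plain cross entropy $-\\log(p_t)$, than a hard example ($p_t = 0.1$)." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "id": "added-focal-loss-demo-code", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "## Added scalar illustration of the focal loss formula ##\n", | |
| "alpha_demo, gamma_demo = 0.25, 2\n", | |
| "for p_t in [0.9, 0.1]:\n", | |
| "    ce = -math.log(p_t)\n", | |
| "    fl = -alpha_demo * (1 - p_t) ** gamma_demo * math.log(p_t)\n", | |
| "    print(f\"p_t={p_t}: CE={ce:.4f}, FL={fl:.4f}\")" | |
| ] | |
| }, | |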
| { | |
| "cell_type": "code", | |
| "execution_count": 180, | |
| "id": "ab44b1d4-fd07-4334-b84e-adb1bd4b9ee0", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "def one_hot_embedding(labels, num_classes):\n", | |
| " y = torch.eye(num_classes)\n", | |
| " return y[labels]" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 222, | |
| "id": "fc16aa9b-f88e-461b-bdb6-3a39202a4330", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "class FocalLoss(nn.Module):\n", | |
| " def __init__(self, n_classes = 8):\n", | |
| " super().__init__()\n", | |
| " self.n_classes = n_classes\n", | |
| " \n", | |
| " def focal_loss(self, x, y):\n", | |
| " alpha = -0.25\n", | |
| " gamma = 2 # Paper recommended values\n", | |
| " \n", | |
| " t = one_hot_embedding(y.cpu(), 1 + self.n_classes)\n", | |
| " t = t[:,1:]\n", | |
| " if torch.cuda.is_available():\n", | |
| " t = t.cuda()\n", | |
| " \n", | |
| " xt = x*(2*t-1)\n", | |
| " pt = (2*xt+1).sigmoid()\n", | |
| " \n", | |
| " w = alpha*t + (1-alpha)*(1-t)\n", | |
| " loss = -w*pt.log() / 2\n", | |
| " return loss.sum()\n", | |
| " \n", | |
| " def forward(self, loc_preds, loc_true, cls_preds, cls_true):\n", | |
| " batch_size, num_boxes = cls_true.size()\n", | |
| " pos = cls_true > 0\n", | |
| " num_pos = pos.long().sum()\n", | |
| " \n", | |
| " ## Loc loss\n", | |
| " mask = pos.unsqueeze(2).expand_as(loc_preds) # [N,#anchors,4]\n", | |
| " masked_loc_preds = loc_preds[mask].view(-1,4) # [#pos,4]\n", | |
| " masked_loc_true = loc_true[mask].view(-1,4) # [#pos,4]\n", | |
| " loc_loss = F.smooth_l1_loss(masked_loc_preds, masked_loc_true, size_average=False)\n", | |
| " ## cls loss\n", | |
| " pos_neg = cls_true > -1 # exclude ignored anchors\n", | |
| " mask = pos_neg.unsqueeze(2).expand_as(cls_preds)\n", | |
| " masked_cls_preds = cls_preds[mask].view(-1,self.n_classes)\n", | |
| " cls_loss = self.focal_loss(masked_cls_preds, cls_true[pos_neg])\n", | |
| " \n", | |
| " loss = (loc_loss + cls_loss) / num_pos\n", | |
| " return loss" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "90160ab7-8e15-4306-9280-798d43165794", | |
| "metadata": {}, | |
| "source": [ | |
| "## Initialization" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 223, | |
| "id": "2e63b8db-31f2-48c0-9f18-8594e3f82d6e", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "criterion = FocalLoss()\n", | |
| "optimizer = optim.SGD(net.parameters(), lr=1e-3, momentum=0.9, weight_decay=1e-4)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 224, | |
| "id": "d6b6b57c-68ac-40d9-91f3-a15e7e9f1af7", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "device(type='cpu')" | |
| ] | |
| }, | |
| "execution_count": 224, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", | |
| "device" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 225, | |
| "id": "c8b82edb-389f-471d-a05a-c2ac3890acb8", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "net = net.to(device)\n", | |
| "criterion = criterion.to(device)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "b801f336-8c55-4217-8e67-221970c55f00", | |
| "metadata": {}, | |
| "source": [ | |
| "## Training" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 234, | |
| "id": "9497a504-1e21-4c3f-8339-7ad1afc086f8", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "def train(epoch):\n", | |
| " net.train()\n", | |
| " net.freeze_bn()\n", | |
| " train_loss = 0\n", | |
| " for batch in tqdm(train_loader):\n", | |
| " imgs = batch['img'].to(device)\n", | |
| " loc_targets = batch['loc_targets'].to(device)\n", | |
| " cls_targets = batch['cls_targets'].to(device)\n", | |
| " cls_targets = cls_targets.long()\n", | |
| " \n", | |
| " optimizer.zero_grad()\n", | |
| " loc_pred, cls_pred = net(imgs)\n", | |
| " loss = criterion(loc_pred, loc_targets, cls_pred, cls_targets)\n", | |
| " \n", | |
| " loss.backward()\n", | |
| " optimizer.step()\n", | |
| " \n", | |
| " train_loss += loss.item()\n", | |
| " print('train_loss: %.3f | avg_loss: %.3f' % (loss.data[0], train_loss/(len(train_loader))))\n", | |
| " " | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 235, | |
| "id": "c5bc39c8-431e-4fe1-ae38-03d484e7f952", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "def test(epoch, loader):\n", | |
| " with torch.no_grad():\n", | |
| " net.eval()\n", | |
| " test_loss = 0\n", | |
| " for batch in tqdm(loader):\n", | |
| " imgs = batch['img'].to(device)\n", | |
| " loc_targets = batch['loc_targets'].to(device)\n", | |
| " cls_targets = batch['cls_targets'].to(device)\n", | |
| "\n", | |
| " loc_pred, cls_pred = net(imgs)\n", | |
| " loss = criterion(loc_pred, loc_targets, cls_pred, cls_targets)\n", | |
| " test_loss += loss[0]\n", | |
| " print('test_loss: %.3f | avg_loss: %.3f' % (loss.data[0], train_loss/(len(train_loader))))\n", | |
| " " | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 236, | |
| "id": "503f99ad-fed4-48af-ae08-7a430b9c8ae5", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "EPOCHS = 50" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 237, | |
| "id": "c7534024-d9a2-430a-8ddc-a6db0ad0bd57", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "EPOCH 1\n" | |
| ] | |
| }, | |
| { | |
| "data": { | |
| "application/vnd.jupyter.widget-view+json": { | |
| "model_id": "ec194265c36b440383e074ffb997acce", | |
| "version_major": 2, | |
| "version_minor": 0 | |
| }, | |
| "text/plain": [ | |
| " 0%| | 0/56 [00:00<?, ?it/s]" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "output_type": "display_data" | |
| }, | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "[E thread_pool.cpp:112] Exception in thread pool task: mutex lock failed: Invalid argument\n", | |
| "[E thread_pool.cpp:112] Exception in thread pool task: mutex lock failed: Invalid argument\n", | |
| "[E thread_pool.cpp:112] Exception in thread pool task: mutex lock failed: Invalid argument\n", | |
| "[E thread_pool.cpp:112] Exception in thread pool task: mutex lock failed: Invalid argument\n", | |
| "[E thread_pool.cpp:112] Exception in thread pool task: mutex lock failed: Invalid argument\n", | |
| "[E thread_pool.cpp:112] Exception in thread pool task: mutex lock failed: Invalid argument\n" | |
| ] | |
| }, | |
| { | |
| "ename": "KeyboardInterrupt", | |
| "evalue": "", | |
| "output_type": "error", | |
| "traceback": [ | |
| "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", | |
| "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", | |
| "\u001b[0;32m/var/folders/vr/x7p4fznn1dv39r_83dmyjkjm0000gn/T/ipykernel_42021/107800544.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mf\"EPOCH {epoch}\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;31m##\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 4\u001b[0;31m \u001b[0mtrain\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mepoch\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 5\u001b[0m \u001b[0mtest\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mepoch\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mval_loader\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 6\u001b[0m \u001b[0;31m##\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
| "\u001b[0;32m/var/folders/vr/x7p4fznn1dv39r_83dmyjkjm0000gn/T/ipykernel_42021/3530921501.py\u001b[0m in \u001b[0;36mtrain\u001b[0;34m(epoch)\u001b[0m\n\u001b[1;32m 13\u001b[0m \u001b[0mloss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcriterion\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mloc_pred\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mloc_targets\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcls_pred\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcls_targets\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 14\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 15\u001b[0;31m \u001b[0mloss\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 16\u001b[0m \u001b[0moptimizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstep\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 17\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", | |
| "\u001b[0;32m/opt/homebrew/Caskroom/miniforge/base/lib/python3.9/site-packages/torch/_tensor.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(self, gradient, retain_graph, create_graph, inputs)\u001b[0m\n\u001b[1;32m 253\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 254\u001b[0m inputs=inputs)\n\u001b[0;32m--> 255\u001b[0;31m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mautograd\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgradient\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 256\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 257\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mregister_hook\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhook\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
| "\u001b[0;32m/opt/homebrew/Caskroom/miniforge/base/lib/python3.9/site-packages/torch/autograd/__init__.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)\u001b[0m\n\u001b[1;32m 145\u001b[0m \u001b[0mretain_graph\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 146\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 147\u001b[0;31m Variable._execution_engine.run_backward(\n\u001b[0m\u001b[1;32m 148\u001b[0m \u001b[0mtensors\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgrad_tensors_\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 149\u001b[0m allow_unreachable=True, accumulate_grad=True) # allow_unreachable flag\n", | |
| "\u001b[0;31mKeyboardInterrupt\u001b[0m: " | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "for epoch in range(1, EPOCHS + 1):\n", | |
| " print(f\"EPOCH {epoch}\")\n", | |
| " ##\n", | |
| " train(epoch)\n", | |
| " test(epoch, val_loader)\n", | |
| " ##" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "id": "79d32e91-d705-4711-8a71-79d94cd5f620", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "id": "f7206910-c8cc-4909-84c3-76a7f3ee023e", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "e33c98cd-3787-45d6-86eb-cf7a65aaf882", | |
| "metadata": {}, | |
| "source": [ | |
| "## Testing" | |
| ] | |
| }, | |
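| { | |
| "cell_type": "markdown", | |
| "id": "added-inference-sketch-md", | |
| "metadata": {}, | |
| "source": [ | |
| "The original notebook leaves this section empty, so here is a rough added inference sketch. It mirrors the commented-out `decode` call earlier: run the network on one validation batch; the decode step is left commented out, as before, since it only produces confident detections once the model is trained." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "id": "added-inference-sketch-code", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "## Added inference sketch (decode left commented, as earlier in the notebook) ##\n", | |
| "net.eval()\n", | |
| "with torch.no_grad():\n", | |
| "    val_batch = next(iter(val_loader))\n", | |
| "    val_imgs = val_batch['img'].to(device)\n", | |
| "    val_loc_preds, val_cls_preds = net(val_imgs)\n", | |
| "# boxes, labels = encoder.decode(val_loc_preds[0].cpu(), val_cls_preds[0].cpu(), tuple(val_imgs.shape[2:]))\n", | |
| "val_loc_preds.shape, val_cls_preds.shape" | |
| ] | |
| }, | |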
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "id": "e2d3a81a-43d3-4da1-8f7c-db98adeff2d5", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "id": "8efa5f7e-a5c9-4942-bde2-7e0897f2e036", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "id": "69670f59-7b7a-46a2-abf7-80180d8f994e", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "afd407d5-6d0a-46bc-b769-7b7e604accfb", | |
| "metadata": {}, | |
| "source": [ | |
| "## Saving Model Weights" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "id": "5a8ef18e-aa6b-4cdb-8ef5-c91671b76e9e", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "torch.save(model.state_dict(), \"retinanet.pt\")" | |
| ] | |
| }, | |
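| { | |
| "cell_type": "markdown", | |
| "id": "added-load-weights-md", | |
| "metadata": {}, | |
| "source": [ | |
| "Added note: the weights saved above can be restored into a fresh `RetinaNet` instance with `load_state_dict`." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "id": "added-load-weights-code", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "## Added example of restoring the saved weights ##\n", | |
| "restored_net = RetinaNet(n_classes=8)\n", | |
| "restored_net.load_state_dict(torch.load(\"retinanet.pt\"))" | |
| ] | |
| }, | |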
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "id": "90ff717d-7d5c-413d-9ae8-c57b935f26f9", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [] | |
| } | |
| ], | |
| "metadata": { | |
| "kernelspec": { | |
| "display_name": "Python 3 (ipykernel)", | |
| "language": "python", | |
| "name": "python3" | |
| }, | |
| "language_info": { | |
| "codemirror_mode": { | |
| "name": "ipython", | |
| "version": 3 | |
| }, | |
| "file_extension": ".py", | |
| "mimetype": "text/x-python", | |
| "name": "python", | |
| "nbconvert_exporter": "python", | |
| "pygments_lexer": "ipython3", | |
| "version": "3.9.4" | |
| }, | |
| "toc-autonumbering": true | |
| }, | |
| "nbformat": 4, | |
| "nbformat_minor": 5 | |
| } |