An implementation of EfficientNet in PyTorch
PyTorch
B0:
- Animals 138:
84%
B1:
- Animals 138:
85%
An implementation of EfficientNet in PyTorch
PyTorch
B0:
84%
B1:
85%
| { | |
| "cells": [ | |
| { | |
| "cell_type": "markdown", | |
| "id": "bd75bfa0-8ff3-4a56-a222-0a73068bb4ff", | |
| "metadata": {}, | |
| "source": [ | |
| "# EfficientNet in PyTorch" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 1, | |
| "id": "0e55ee58-7dcc-485b-91ec-5329b248bf06", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "import torch\n", | |
| "from torch import nn, optim\n", | |
| "import math\n", | |
| "import os\n", | |
| "from torchinfo import summary" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 2, | |
| "id": "09d41e73-3720-4f39-916b-7602e40fedaf", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "def conv_block(in_channels, out_channels, kernel_size=3,\n", | |
| "               stride=1, padding=0, groups=1,\n", | |
| "               bias=False, bn=True, act=True):\n", | |
| "    \"\"\"Return an ``nn.Sequential`` of Conv2d, then BatchNorm2d, then SiLU.\n", | |
| "\n", | |
| "    The sequence always has three entries: when ``bn`` or ``act`` is false the\n", | |
| "    corresponding slot is filled with ``nn.Identity()`` instead of being dropped.\n", | |
| "    All other arguments are forwarded to ``nn.Conv2d``.\n", | |
| "    \"\"\"\n", | |
| "    conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride,\n", | |
| "                     padding=padding, groups=groups, bias=bias)\n", | |
| "    norm = nn.BatchNorm2d(out_channels) if bn else nn.Identity()\n", | |
| "    activation = nn.SiLU() if act else nn.Identity()\n", | |
| "    return nn.Sequential(conv, norm, activation)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 3, | |
| "id": "bf59c00f-f320-45be-86f8-6023c2d1f749", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "class SEBlock(nn.Module):\n", | |
| "    \"\"\"Squeeze-and-Excitation gate that rescales every channel of its input.\n", | |
| "\n", | |
| "    Args:\n", | |
| "        c: number of input (and output) channels.\n", | |
| "        r: reduction ratio; the bottleneck has ``max(1, c // r)`` channels.\n", | |
| "    \"\"\"\n", | |
| "    def __init__(self, c, r=24):\n", | |
| "        super(SEBlock, self).__init__()\n", | |
| "        # Keep at least one bottleneck channel so a small ``c`` never yields a zero-width conv.\n", | |
| "        squeezed = max(1, c // r)\n", | |
| "        # Squeeze is global *average* pooling (the SE formulation), not max pooling.\n", | |
| "        self.squeeze = nn.AdaptiveAvgPool2d(1)\n", | |
| "        self.excitation = nn.Sequential(\n", | |
| "            nn.Conv2d(c, squeezed, kernel_size=1),\n", | |
| "            nn.SiLU(),\n", | |
| "            nn.Conv2d(squeezed, c, kernel_size=1),\n", | |
| "            nn.Sigmoid()\n", | |
| "        )\n", | |
| "\n", | |
| "    def forward(self, x):\n", | |
| "        \"\"\"Return ``x`` multiplied by per-channel sigmoid gates computed from its pooled summary.\"\"\"\n", | |
| "        s = self.squeeze(x)\n", | |
| "        e = self.excitation(s)\n", | |
| "        return x * e" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 4, | |
| "id": "df21a513-e747-4d9f-a2ae-4378ba6e3c91", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "class MBConv(nn.Module):\n", | |
| "    \"\"\"Inverted-residual block: 1x1 expand, depthwise conv, SE gate, 1x1 project.\n", | |
| "\n", | |
| "    Args:\n", | |
| "        n_in: input channels.\n", | |
| "        n_out: output channels.\n", | |
| "        expansion: channel multiplier of the expand step; ``1`` skips the expand conv.\n", | |
| "        kernel_size: depthwise kernel size; padding is ``(kernel_size - 1) // 2``.\n", | |
| "        stride: depthwise stride.\n", | |
| "        r: reduction ratio handed to ``SEBlock``.\n", | |
| "        dropout: probability of the ``nn.Dropout`` applied before the residual add.\n", | |
| "    \"\"\"\n", | |
| "    def __init__(self, n_in, n_out, expansion, kernel_size=3, stride=1, r=24, dropout=0.1):\n", | |
| "        super().__init__()\n", | |
| "        hidden = n_in * expansion\n", | |
| "        same_pad = (kernel_size - 1) // 2\n", | |
| "        # The residual add is only possible when input and output shapes agree.\n", | |
| "        self.skip_connection = (stride == 1) and (n_in == n_out)\n", | |
| "\n", | |
| "        if expansion == 1:\n", | |
| "            self.expand_pw = nn.Identity()\n", | |
| "        else:\n", | |
| "            self.expand_pw = conv_block(n_in, hidden, kernel_size=1)\n", | |
| "        self.depthwise = conv_block(hidden, hidden, kernel_size=kernel_size,\n", | |
| "                                    stride=stride, padding=same_pad, groups=hidden)\n", | |
| "        self.se = SEBlock(hidden, r=r)\n", | |
| "        self.reduce_pw = conv_block(hidden, n_out, kernel_size=1, act=False)\n", | |
| "        # NOTE(review): element-wise dropout on the residual branch; the paper may intend\n", | |
| "        # stochastic depth here instead -- confirm before changing.\n", | |
| "        self.dropout = nn.Dropout(dropout)\n", | |
| "\n", | |
| "    def forward(self, x):\n", | |
| "        \"\"\"Run the block; add the input back only when ``skip_connection`` is set.\"\"\"\n", | |
| "        out = x\n", | |
| "        for step in (self.expand_pw, self.depthwise, self.se, self.reduce_pw):\n", | |
| "            out = step(out)\n", | |
| "        if not self.skip_connection:\n", | |
| "            return out\n", | |
| "        return self.dropout(out) + x" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 5, | |
| "id": "777c11c8-e042-4b52-ba2a-a57146e225d1", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "def mbconv1(n_in, n_out, kernel_size=3, stride=1, r=24, dropout=0.1):\n", | |
| "    \"\"\"Build an ``MBConv`` block with expansion factor 1 (no expand conv).\"\"\"\n", | |
| "    options = dict(kernel_size=kernel_size, stride=stride, r=r, dropout=dropout)\n", | |
| "    return MBConv(n_in, n_out, 1, **options)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 6, | |
| "id": "115dc875-517d-4cf5-9441-0344f144b69d", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "def mbconv6(n_in, n_out, kernel_size=3, stride=1, r=24, dropout=0.1):\n", | |
| "    \"\"\"Build an ``MBConv`` block with expansion factor 6.\"\"\"\n", | |
| "    options = dict(kernel_size=kernel_size, stride=stride, r=r, dropout=dropout)\n", | |
| "    return MBConv(n_in, n_out, 6, **options)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 7, | |
| "id": "90de185f-a60f-4bd3-a22f-4bb090d22538", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "def create_stage(n_in, n_out, num_layers, layer=mbconv6,\n", | |
| "                 kernel_size=3, stride=1, r=24, ps=0):\n", | |
| "    \"\"\"Stack blocks built by ``layer`` into a single ``nn.Sequential``.\n", | |
| "\n", | |
| "    The first block maps ``n_in`` to ``n_out`` and receives ``stride``; the\n", | |
| "    remaining ``num_layers - 1`` blocks map ``n_out`` to ``n_out`` and use the\n", | |
| "    builder's default stride. ``ps`` is forwarded to every block as ``dropout``.\n", | |
| "    \"\"\"\n", | |
| "    blocks = [layer(n_in, n_out, kernel_size=kernel_size,\n", | |
| "                    stride=stride, r=r, dropout=ps)]\n", | |
| "    for _ in range(1, num_layers):\n", | |
| "        blocks.append(layer(n_out, n_out, kernel_size=kernel_size,\n", | |
| "                            r=r, dropout=ps))\n", | |
| "    return nn.Sequential(*blocks)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 8, | |
| "id": "08539591-434f-4457-824e-8840a97a6ccf", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "def scale_width(w, w_factor):\n", | |
| "    \"\"\"Scale a channel count by ``w_factor`` and snap it to a multiple of 8.\n", | |
| "\n", | |
| "    The scaled value is rounded to the nearest multiple of 8 (never below 8);\n", | |
| "    if that rounding lost more than 10% of the scaled value, 8 is added back.\n", | |
| "    \"\"\"\n", | |
| "    scaled = w * w_factor\n", | |
| "    snapped = max(8, 8 * (int(scaled + 4) // 8))\n", | |
| "    bump = 8 if snapped < 0.9 * scaled else 0\n", | |
| "    return int(snapped + bump)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "d9ad5a73-ed3d-4cfa-849a-d292782cd6c4", | |
| "metadata": {}, | |
| "source": [ | |
| "EfficientNet Base structure\n", | |
| "\n", | |
| "| Stage (i) | Layer | Resolution | Channels | Layers |\n", | |
| "|-----------|-----------|------------|----------|--------|\n", | |
| "| 1 | `conv_block` (3x3) | 224 x 224 | 32 | 1 |\n", | |
| "| 2 | `mbconv1` | 112 x 112 | 16 | 1 |\n", | |
| "| 3 | `mbconv6` | 112 x 112 | 24 | 2 |\n", | |
| "| 4 | `mbconv6` | 56 x 56 | 40 | 2 |\n", | |
| "| 5 | `mbconv6` | 28 x 28 | 80 | 3 |\n", | |
| "| 6 | `mbconv6` | 14 x 14 | 112 | 3 |\n", | |
| "| 7 | `mbconv6` | 14 x 14 | 192 | 4 |\n", | |
| "| 8 | `mbconv6` | 7 x 7 | 320 | 1 |\n", | |
| "| 9 | `conv_block` (1x1) + pool + linear | 7 x 7 | 1280 | 1 |" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 9, | |
| "id": "83b2baa2-4923-4a05-8f33-9429214baa3f", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "### Obtained from Paper ###\n", | |
| "# (in_channels, out_channels) pairs: the first seven feed the seven stages built in\n", | |
| "# EfficientNet, the last pair feeds the final 1x1 conv_block before pooling.\n", | |
| "base_widths = [(32, 16), (16, 24), (24, 40),\n", | |
| " (40, 80), (80, 112), (112, 192),\n", | |
| " (192, 320), (320, 1280)]\n", | |
| "# Number of layers per stage, before depth scaling in efficientnet_gen.\n", | |
| "base_depths = [1, 2, 2, 3, 3, 4, 1]\n", | |
| "# Per-stage depthwise kernel size and stride of the first layer of each stage.\n", | |
| "kernel_sizes = [3, 3, 5, 3, 5, 5, 3]\n", | |
| "strides = [1, 2, 2, 2, 1, 2, 1]\n", | |
| "# Per-stage dropout probability, forwarded to every block of the stage.\n", | |
| "# NOTE(review): these look like linearly increasing drop rates rather than values quoted\n", | |
| "# verbatim from the paper -- confirm their origin.\n", | |
| "ps = [0, 0.029, 0.057, 0.086, 0.114, 0.143, 0.171]" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 10, | |
| "id": "22099a0a-2579-4268-ac29-52db221b5366", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "def efficientnet_gen(w_factor=1, d_factor=1):\n", | |
| "    \"\"\"Scale the base configuration for one EfficientNet variant.\n", | |
| "\n", | |
| "    Returns ``(scaled_widths, scaled_depths)``: every channel count in\n", | |
| "    ``base_widths`` passed through ``scale_width``, and every entry of\n", | |
| "    ``base_depths`` multiplied by ``d_factor`` and rounded up.\n", | |
| "    \"\"\"\n", | |
| "    scaled_widths = []\n", | |
| "    for c_in, c_out in base_widths:\n", | |
| "        scaled_widths.append((scale_width(c_in, w_factor), scale_width(c_out, w_factor)))\n", | |
| "    scaled_depths = [math.ceil(depth * d_factor) for depth in base_depths]\n", | |
| "    return scaled_widths, scaled_depths" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 38, | |
| "id": "08cdce48-2c2e-4bc7-bbd0-c8ab6bc02113", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "class EfficientNet(nn.Module):\n", | |
| "    \"\"\"EfficientNet classifier: stem conv, seven MBConv stages, 1x1 conv, pooling, linear head.\n", | |
| "\n", | |
| "    Args:\n", | |
| "        w_factor: width multiplier applied to the base channel counts.\n", | |
| "        d_factor: depth multiplier applied to the base per-stage layer counts.\n", | |
| "        n_classes: number of outputs of the final linear layer.\n", | |
| "    \"\"\"\n", | |
| "    def __init__(self, w_factor=1, d_factor=1, n_classes=1000):\n", | |
| "        super().__init__()\n", | |
| "        widths, depths = efficientnet_gen(w_factor=w_factor, d_factor=d_factor)\n", | |
| "        stem_out = widths[0][0]\n", | |
| "        head_in = widths[-1][1]\n", | |
| "\n", | |
| "        self.conv1 = conv_block(3, stem_out, stride=2, padding=1)\n", | |
| "\n", | |
| "        stage_modules = []\n", | |
| "        for i in range(7):\n", | |
| "            is_first = i == 0\n", | |
| "            stage_in, stage_out = widths[i]\n", | |
| "            # Only the first stage uses the non-expanding block, with SE ratio 4 instead of 24.\n", | |
| "            stage_modules.append(create_stage(\n", | |
| "                stage_in, stage_out, depths[i],\n", | |
| "                layer=mbconv1 if is_first else mbconv6,\n", | |
| "                kernel_size=kernel_sizes[i], stride=strides[i],\n", | |
| "                r=4 if is_first else 24, ps=ps[i],\n", | |
| "            ))\n", | |
| "        self.stages = nn.Sequential(*stage_modules)\n", | |
| "\n", | |
| "        self.pre = conv_block(*widths[-1], kernel_size=1)\n", | |
| "        self.pool_flatten = nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Flatten())\n", | |
| "        self.head = nn.Sequential(nn.Linear(head_in, n_classes))\n", | |
| "\n", | |
| "    def forward(self, x):\n", | |
| "        \"\"\"Apply stem, stages, 1x1 conv, pooling and head in order; returns the head output.\"\"\"\n", | |
| "        for part in (self.conv1, self.stages, self.pre, self.pool_flatten, self.head):\n", | |
| "            x = part(x)\n", | |
| "        return x" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 39, | |
| "id": "9fe3126b-d63a-4fdb-b081-698f07092cc1", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "def efficientnet_b0(n_classes=1000):\n", | |
| "    \"\"\"Build EfficientNet-B0 (default width and depth factors of 1).\"\"\"\n", | |
| "    model = EfficientNet(n_classes=n_classes)\n", | |
| "    return model" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 40, | |
| "id": "335c91f7-60a7-4397-b2ea-ba9755e0a259", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "def efficientnet_b1(n_classes=1000):\n", | |
| "    \"\"\"Build EfficientNet-B1 (width factor 1, depth factor 1.1).\"\"\"\n", | |
| "    return EfficientNet(w_factor=1, d_factor=1.1, n_classes=n_classes)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 41, | |
| "id": "a9ae8295-66e0-4788-861e-af81cf33e099", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "def efficientnet_b2(n_classes=1000):\n", | |
| "    \"\"\"Build EfficientNet-B2 (width factor 1.1, depth factor 1.2).\"\"\"\n", | |
| "    return EfficientNet(w_factor=1.1, d_factor=1.2, n_classes=n_classes)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 42, | |
| "id": "227832fc-48b7-4c75-ae40-86213f008c91", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "def efficientnet_b3(n_classes=1000):\n", | |
| "    \"\"\"Build EfficientNet-B3 (width factor 1.2, depth factor 1.4).\"\"\"\n", | |
| "    return EfficientNet(w_factor=1.2, d_factor=1.4, n_classes=n_classes)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 43, | |
| "id": "cdfaff84-9f4a-4ba6-94f5-aa07ab81d22f", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "def efficientnet_b4(n_classes=1000):\n", | |
| "    \"\"\"Build EfficientNet-B4 (width factor 1.4, depth factor 1.8).\"\"\"\n", | |
| "    return EfficientNet(w_factor=1.4, d_factor=1.8, n_classes=n_classes)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 44, | |
| "id": "05634592-a1c0-4692-88da-fba8ac2ea017", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "def efficientnet_b5(n_classes=1000):\n", | |
| "    \"\"\"Build EfficientNet-B5 (width factor 1.6, depth factor 2.2).\"\"\"\n", | |
| "    return EfficientNet(w_factor=1.6, d_factor=2.2, n_classes=n_classes)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 45, | |
| "id": "d6a66f21-9dea-4310-afc9-4c28b118ef52", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "def efficientnet_b6(n_classes=1000):\n", | |
| "    \"\"\"Build EfficientNet-B6 (width factor 1.8, depth factor 2.6).\"\"\"\n", | |
| "    return EfficientNet(w_factor=1.8, d_factor=2.6, n_classes=n_classes)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 46, | |
| "id": "24462c7e-994e-494e-baba-d619055f816c", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "def efficientnet_b7(n_classes=1000):\n", | |
| "    \"\"\"Build EfficientNet-B7 (width factor 2, depth factor 3.1).\"\"\"\n", | |
| "    return EfficientNet(w_factor=2, d_factor=3.1, n_classes=n_classes)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 60, | |
| "id": "ce69f255-c00b-4dc3-a044-37f1cc1be178", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# Instantiate every variant, B0 through B7, with the default number of classes.\n", | |
| "b0, b1, b2, b3, b4, b5, b6, b7 = (\n", | |
| "    build() for build in (\n", | |
| "        efficientnet_b0, efficientnet_b1, efficientnet_b2, efficientnet_b3,\n", | |
| "        efficientnet_b4, efficientnet_b5, efficientnet_b6, efficientnet_b7,\n", | |
| "    )\n", | |
| ")" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 61, | |
| "id": "15858249-5a11-4e41-afe5-bd071b6c1417", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "torch.Size([1, 1000])" | |
| ] | |
| }, | |
| "execution_count": 61, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "# Smoke test: run one random (1, 3, 224, 224) tensor through B0 and display the output shape.\n", | |
| "inp = torch.randn(1, 3, 224, 224)\n", | |
| "b0(inp).shape" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 62, | |
| "id": "43caafea-ef4e-4dd2-bff4-e0268b84a68d", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "def print_size_of_model(model):\n", | |
| "    \"\"\"Print the on-disk size of ``model.state_dict()`` in MB (1 MB = 1e6 bytes).\n", | |
| "\n", | |
| "    The state dict is saved inside a private temporary directory, so an existing\n", | |
| "    ``temp.p`` in the working directory is never overwritten or deleted, and the\n", | |
| "    scratch file is removed even if saving or measuring raises.\n", | |
| "    \"\"\"\n", | |
| "    import tempfile  # local import keeps this cell independent of the imports cell\n", | |
| "\n", | |
| "    with tempfile.TemporaryDirectory() as scratch_dir:\n", | |
| "        # Same base name as before so the serialized archive is unchanged.\n", | |
| "        scratch_path = os.path.join(scratch_dir, \"temp.p\")\n", | |
| "        torch.save(model.state_dict(), scratch_path)\n", | |
| "        print('Size (MB):', os.path.getsize(scratch_path)/1e6)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 63, | |
| "id": "a79c046d-882f-439b-9b60-9c0997baebdf", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "Size (MB): 21.446577\n", | |
| "Size (MB): 31.600841\n", | |
| "Size (MB): 36.885449\n", | |
| "Size (MB): 49.479621\n", | |
| "Size (MB): 78.111933\n", | |
| "Size (MB): 122.546261\n", | |
| "Size (MB): 173.400525\n", | |
| "Size (MB): 267.054441\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "# Report the serialized state_dict size of every variant, B0 through B7.\n", | |
| "for model in (b0, b1, b2, b3, b4, b5, b6, b7):\n", | |
| "    print_size_of_model(model)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 51, | |
| "id": "c3259af2-9648-4788-849e-d9878b6d90a5", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "def fmat(n):\n", | |
| "    \"\"\"Format a count as millions with two decimals, e.g. ``5288548`` -> ``'5.29M'``.\"\"\"\n", | |
| "    millions = n / 1_000_000\n", | |
| "    return f\"{millions:.2f}M\"" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 52, | |
| "id": "32a21dc3-3bd4-48d6-902f-b0fd22fbf261", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "def params(model, f=True):\n", | |
| "    \"\"\"Count the elements of all parameters of ``model`` that require gradients.\n", | |
| "\n", | |
| "    Returns the count formatted by ``fmat`` when ``f`` is truthy, otherwise the raw integer.\n", | |
| "    \"\"\"\n", | |
| "    trainable = 0\n", | |
| "    for p in model.parameters():\n", | |
| "        if p.requires_grad:\n", | |
| "            trainable += p.numel()\n", | |
| "    if f:\n", | |
| "        return fmat(trainable)\n", | |
| "    return trainable" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 53, | |
| "id": "994369d9-3be0-4896-838a-b5ff7b94c0a2", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "('5.29M', '7.79M', '9.11M', '12.23M', '19.34M', '30.39M', '43.04M', '66.35M')" | |
| ] | |
| }, | |
| "execution_count": 53, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "# roughly equivalent to the params mentioned in paper\n", | |
| "# (5.3M, 7.8M, 9.2M, 12M, 19M, 30M, 43M, 66M) <- param sizes in the paper\n", | |
| "tuple(params(net) for net in (b0, b1, b2, b3, b4, b5, b6, b7))" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 64, | |
| "id": "d8f7b707-8cc1-43f7-be6a-30bd865bc6eb", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "====================================================================================================\n", | |
| "Layer (type:depth-idx) Output Shape Param #\n", | |
| "====================================================================================================\n", | |
| "EfficientNet -- --\n", | |
| "├─Sequential: 1-1 [1, 32, 112, 112] --\n", | |
| "│ └─Conv2d: 2-1 [1, 32, 112, 112] 864\n", | |
| "│ └─BatchNorm2d: 2-2 [1, 32, 112, 112] 64\n", | |
| "│ └─SiLU: 2-3 [1, 32, 112, 112] --\n", | |
| "├─Sequential: 1-2 [1, 320, 7, 7] --\n", | |
| "│ └─Sequential: 2-4 [1, 16, 112, 112] --\n", | |
| "│ │ └─MBConv: 3-1 [1, 16, 112, 112] 1,448\n", | |
| "│ └─Sequential: 2-5 [1, 24, 56, 56] --\n", | |
| "│ │ └─MBConv: 3-2 [1, 24, 56, 56] 6,004\n", | |
| "│ │ └─MBConv: 3-3 [1, 24, 56, 56] 10,710\n", | |
| "│ └─Sequential: 2-6 [1, 40, 28, 28] --\n", | |
| "│ │ └─MBConv: 3-4 [1, 40, 28, 28] 15,350\n", | |
| "│ │ └─MBConv: 3-5 [1, 40, 28, 28] 31,290\n", | |
| "│ └─Sequential: 2-7 [1, 80, 14, 14] --\n", | |
| "│ │ └─MBConv: 3-6 [1, 80, 14, 14] 37,130\n", | |
| "│ │ └─MBConv: 3-7 [1, 80, 14, 14] 102,900\n", | |
| "│ │ └─MBConv: 3-8 [1, 80, 14, 14] 102,900\n", | |
| "│ └─Sequential: 2-8 [1, 112, 14, 14] --\n", | |
| "│ │ └─MBConv: 3-9 [1, 112, 14, 14] 126,004\n", | |
| "│ │ └─MBConv: 3-10 [1, 112, 14, 14] 208,572\n", | |
| "│ │ └─MBConv: 3-11 [1, 112, 14, 14] 208,572\n", | |
| "│ └─Sequential: 2-9 [1, 192, 7, 7] --\n", | |
| "│ │ └─MBConv: 3-12 [1, 192, 7, 7] 262,492\n", | |
| "│ │ └─MBConv: 3-13 [1, 192, 7, 7] 587,952\n", | |
| "│ │ └─MBConv: 3-14 [1, 192, 7, 7] 587,952\n", | |
| "│ │ └─MBConv: 3-15 [1, 192, 7, 7] 587,952\n", | |
| "│ └─Sequential: 2-10 [1, 320, 7, 7] --\n", | |
| "│ │ └─MBConv: 3-16 [1, 320, 7, 7] 717,232\n", | |
| "├─Sequential: 1-3 [1, 1280, 7, 7] --\n", | |
| "│ └─Conv2d: 2-11 [1, 1280, 7, 7] 409,600\n", | |
| "│ └─BatchNorm2d: 2-12 [1, 1280, 7, 7] 2,560\n", | |
| "│ └─SiLU: 2-13 [1, 1280, 7, 7] --\n", | |
| "├─Sequential: 1-4 [1, 1280] --\n", | |
| "│ └─AdaptiveAvgPool2d: 2-14 [1, 1280, 1, 1] --\n", | |
| "│ └─Flatten: 2-15 [1, 1280] --\n", | |
| "├─Sequential: 1-5 [1, 1000] --\n", | |
| "│ └─Linear: 2-16 [1, 1000] 1,281,000\n", | |
| "====================================================================================================\n", | |
| "Total params: 5,288,548\n", | |
| "Trainable params: 5,288,548\n", | |
| "Non-trainable params: 0\n", | |
| "Total mult-adds (M): 385.87\n", | |
| "====================================================================================================\n", | |
| "Input size (MB): 0.60\n", | |
| "Forward/backward pass size (MB): 107.89\n", | |
| "Params size (MB): 21.15\n", | |
| "Estimated Total Size (MB): 129.64\n", | |
| "====================================================================================================" | |
| ] | |
| }, | |
| "execution_count": 64, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "# Layer-by-layer summary for a single (1, 3, 224, 224) input; pick a model.\n", | |
| "summary(b0, input_size=(1, 3, 224, 224))" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "e6fc8d9c-a52b-4b89-9f19-8344b70c418b", | |
| "metadata": {}, | |
| "source": [ | |
| "End of notebook" | |
| ] | |
| } | |
| ], | |
| "metadata": { | |
| "kernelspec": { | |
| "display_name": "Python 3 (ipykernel)", | |
| "language": "python", | |
| "name": "python3" | |
| }, | |
| "language_info": { | |
| "codemirror_mode": { | |
| "name": "ipython", | |
| "version": 3 | |
| }, | |
| "file_extension": ".py", | |
| "mimetype": "text/x-python", | |
| "name": "python", | |
| "nbconvert_exporter": "python", | |
| "pygments_lexer": "ipython3", | |
| "version": "3.9.4" | |
| } | |
| }, | |
| "nbformat": 4, | |
| "nbformat_minor": 5 | |
| } |