Created July 4, 2021 11:44
Computing FLOPS of vision transformer models
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "567e140c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import humanize"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "66d0b5e6",
   "metadata": {},
   "outputs": [],
   "source": [
    "def matmul(m, n, k):\n",
    "    \"\"\"Multiply an (m, n)-matrix with an (n, k)-matrix\"\"\"\n",
    "    return m * (2 * n - 1) * k"
   ]
  },
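  {
   "cell_type": "markdown",
   "id": "matmul-note",
   "metadata": {},
   "source": [
    "Quick sanity check (a sketch only, using ViT-B/16-like sizes as an example): `matmul` counts multiplications and additions separately, so `m * (2 * n - 1) * k` falls short of the common `2 * m * n * k` approximation (one multiply-add counted as two FLOPS) by exactly the `m * k` additions saved in the dot products."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "matmul-check",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Example sizes only (roughly ViT-B/16: 197 tokens, 768-dim latent space)\n",
    "m, n, k = 197, 768, 768\n",
    "exact = matmul(m, n, k)\n",
    "approx = 2 * m * n * k  # one multiply-add counted as two operations\n",
    "exact, approx, approx - exact == m * k"
   ]
  },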
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "06f1e164",
   "metadata": {},
   "outputs": [],
   "source": [
    "def softmax(n, d):\n",
    "    \"\"\"Row-wise softmax of an (n, d)-matrix.\"\"\"\n",
    "    # Assume exp(x) is one operation\n",
    "    return n * (3 * d - 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "6808c97f",
   "metadata": {},
   "outputs": [],
   "source": [
    "def layer_norm_flops(n, d):\n",
    "    \"\"\"Layer normalization.\n",
    "    See paper: https://arxiv.org/pdf/1607.06450.pdf\n",
    "    \"\"\"\n",
    "    flops = 0\n",
    "    flops += n * d - 1  # mean\n",
    "    flops += 3 * n * d - 1  # variance\n",
    "    flops += 2 * n * d  # normalize\n",
    "    return flops"
   ]
  },
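  {
   "cell_type": "markdown",
   "id": "layer-norm-note",
   "metadata": {},
   "source": [
    "This is an approximate count for normalizing each row as $y_i = (x_i - \\mu) / \\sqrt{\\sigma^2 + \\epsilon}$: roughly $nd$ FLOPS for the mean, $3nd$ for the variance and $2nd$ to subtract and divide. The learnable scale and offset used in most layer-norm implementations (another $\\approx 2nd$ FLOPS, if included) are not counted here."
   ]
  },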
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "3adee209",
   "metadata": {},
   "outputs": [],
   "source": [
    "def single_head_attention_flops(seq_length, input_dim, output_dim):\n",
    "    \"\"\"Single head QKV attention.\n",
    "\n",
    "    We assume queries, keys and values all have output_dim.\n",
    "    \"\"\"\n",
    "    flops = 0\n",
    "    flops += 3 * matmul(seq_length, input_dim, output_dim)  # Q, K and V\n",
    "    flops += matmul(seq_length, output_dim, seq_length)  # Q*K^T\n",
    "    flops += seq_length ** 2  # Divide by sqrt(output_dim)\n",
    "    flops += softmax(seq_length, seq_length)\n",
    "    flops += matmul(seq_length, seq_length, output_dim)  # softmax(...) * V\n",
    "    return flops"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "5fd6b081",
   "metadata": {},
   "outputs": [],
   "source": [
    "def attention_flops(seq_length, latent_dim, heads):\n",
    "    \"\"\"Multi-head attention with a dense layer at the end.\"\"\"\n",
    "    flops = 0\n",
    "    sa_flops = single_head_attention_flops(\n",
    "        seq_length, latent_dim, latent_dim // heads\n",
    "    )\n",
    "    flops += heads * sa_flops\n",
    "    # There is a dense layer after the attention heads\n",
    "    flops += matmul(seq_length, latent_dim, latent_dim)\n",
    "    return flops"
   ]
  },
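  {
   "cell_type": "markdown",
   "id": "heads-note",
   "metadata": {},
   "source": [
    "A rough observation, illustrated below with ViT-B/16-like numbers (197 tokens, latent dimension 768) used purely as an example: because each head works on `latent_dim // heads` dimensions, the Q, K, V and output projections cost the same in total regardless of the number of heads, and the overall count changes only slightly with `heads` (mainly through the per-head softmax)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "heads-check",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Example only: attention FLOPS for different head counts at ViT-B/16-like size\n",
    "[(h, attention_flops(seq_length=197, latent_dim=768, heads=h)) for h in (1, 12, 16)]"
   ]
  },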
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "cf9de9d0",
   "metadata": {},
   "outputs": [],
   "source": [
    "def mlp_flops(seq_length, latent_dim, mlp_dim):\n",
    "    \"\"\"2-layer network with nonlinearity between the layers.\"\"\"\n",
    "    flops = 0\n",
    "    flops += matmul(seq_length, latent_dim, mlp_dim)  # First dense layer\n",
    "    flops += seq_length * mlp_dim  # Activation (assumed to be one operation)\n",
    "    flops += matmul(seq_length, mlp_dim, latent_dim)  # Second dense layer\n",
    "    return flops"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "e33f32f6",
   "metadata": {},
   "outputs": [],
   "source": [
    "def encoder_flops(seq_length, latent_dim, heads, mlp_dim):\n",
    "    \"\"\"One encoder block as in Fig. 1 (right) in [1].\n",
    "\n",
    "    [1] An Image Is Worth 16x16 Words.\n",
    "    https://arxiv.org/pdf/2010.11929.pdf\n",
    "    \"\"\"\n",
    "    flops = 0\n",
    "    flops += layer_norm_flops(n=seq_length, d=latent_dim)\n",
    "    flops += attention_flops(seq_length, latent_dim, heads)\n",
    "    flops += seq_length * latent_dim  # Residual connection\n",
    "    flops += layer_norm_flops(n=seq_length, d=latent_dim)\n",
    "    flops += mlp_flops(seq_length, latent_dim, mlp_dim)\n",
    "    flops += seq_length * latent_dim  # Residual connection\n",
    "    return flops"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "5a9fde01",
   "metadata": {},
   "outputs": [],
   "source": [
    "def count_transformer_flops(\n",
    "    image_size, patch_size, layers, latent_dim, heads, mlp_dim\n",
    "):\n",
    "    \"\"\"Counts FLOPS of a transformer model.\"\"\"\n",
    "    if image_size % patch_size != 0:\n",
    "        raise ValueError(\"image_size must be multiple of patch_size.\")\n",
    "    if latent_dim % heads != 0:\n",
    "        raise ValueError(\"latent_dim must be multiple of heads.\")\n",
    "\n",
    "    seq_length = (image_size // patch_size) ** 2\n",
    "    input_dim = 3 * patch_size ** 2\n",
    "\n",
    "    flops = 0\n",
    "    # Project input dimension to latent dimension\n",
    "    flops += matmul(seq_length, input_dim, latent_dim)\n",
    "    seq_length += 1  # Add class embedding to sequence\n",
    "    flops += seq_length * latent_dim  # Add positional embedding\n",
    "    flops += layers * encoder_flops(seq_length, latent_dim, heads, mlp_dim)\n",
    "    return flops"
   ]
  },
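  {
   "cell_type": "markdown",
   "id": "seq-length-note",
   "metadata": {},
   "source": [
    "As a quick illustration of the intermediate sizes (ViT-B/16 settings): a $224 \\times 224$ image with $16 \\times 16$ patches gives $(224/16)^2 = 196$ patches of dimension $3 \\cdot 16^2 = 768$, and the class token brings the sequence length to 197."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "seq-length-check",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustration only: intermediate sizes for a 224x224 image with 16x16 patches\n",
    "image_size, patch_size = 224, 16\n",
    "seq_length = (image_size // patch_size) ** 2\n",
    "input_dim = 3 * patch_size ** 2\n",
    "assert (seq_length, input_dim) == (196, 768)\n",
    "seq_length + 1  # sequence length after adding the class token"
   ]
  },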
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "33f05f09",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ViT-B/16 model\n",
    "vit_b = count_transformer_flops(\n",
    "    image_size=224,\n",
    "    patch_size=16,\n",
    "    layers=12,\n",
    "    latent_dim=768,\n",
    "    heads=12,\n",
    "    mlp_dim=3072\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "cf7dc3d0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ViT-L/16 model\n",
    "vit_l = count_transformer_flops(\n",
    "    image_size=224,\n",
    "    patch_size=16,\n",
    "    layers=24,\n",
    "    latent_dim=1024,\n",
    "    heads=16,\n",
    "    mlp_dim=4096\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "6f0c5f70",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ViT-H/14 model\n",
    "vit_h = count_transformer_flops(\n",
    "    image_size=224,\n",
    "    patch_size=14,\n",
    "    layers=32,\n",
    "    latent_dim=1280,\n",
    "    heads=16,\n",
    "    mlp_dim=5120\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "3f74fae2",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'35.2 billion'"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "humanize.intword(vit_b)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "9b71ecc8",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'123.2 billion'"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "humanize.intword(vit_l)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "26f40e5e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'334.8 billion'"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "humanize.intword(vit_h)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}