@martinsbruveris
Created July 4, 2021 11:44
Computing FLOPS of vision transformer models
```python
import humanize
```
```python
def matmul(m, n, k):
    """Multiply an (m, n)-matrix with an (n, k)-matrix."""
    return m * (2 * n - 1) * k
```
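A quick sanity check on made-up dimensions: each of the `m * k` output entries costs `n` multiplications and `n - 1` additions.

```python
# A (2, 3) @ (3, 4) product has 2 * 4 = 8 outputs, each costing
# 3 multiplications and 2 additions, i.e. 5 FLOPs.
assert matmul(2, 3, 4) == 8 * 5
```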
```python
def softmax(n, d):
    """Row-wise softmax of an (n, d)-matrix."""
    # Assume exp(x) is one operation
    return n * (3 * d - 1)
```
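Per row this is `d` exponentials, `d - 1` additions for the normalizer and `d` divisions, which the following check on arbitrary dimensions confirms.

```python
# 4 rows, each costing 10 exps + 9 adds + 10 divides = 29 FLOPs
assert softmax(4, 10) == 4 * (10 + 9 + 10)
```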
```python
def layer_norm_flops(n, d):
    """Layer normalization of an (n, d)-matrix.

    See paper: https://arxiv.org/pdf/1607.06450.pdf
    """
    flops = 0
    flops += n * d - 1      # mean
    flops += 3 * n * d - 1  # variance
    flops += 2 * n * d      # normalize
    return flops
```
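The three terms collapse to the closed form `6 * n * d - 2`, checked here on the (197, 768) shape that ViT-B/16 produces below.

```python
assert layer_norm_flops(197, 768) == 6 * 197 * 768 - 2
```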
```python
def single_head_attention_flops(seq_length, input_dim, output_dim):
    """Single-head QKV attention.

    We assume queries, keys and values all have output_dim dimensions.
    """
    flops = 0
    flops += 3 * matmul(seq_length, input_dim, output_dim)  # Q, K and V
    flops += matmul(seq_length, output_dim, seq_length)     # Q*K^T
    flops += seq_length ** 2                                # Divide by sqrt(output_dim)
    flops += softmax(seq_length, seq_length)
    flops += matmul(seq_length, seq_length, output_dim)     # softmax(...) * V
    return flops
```
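For scale: with the ViT-B/16 shapes used below (197 tokens, input width 768, 12 heads of width 64 each), a single head comes to roughly 68 million FLOPs.

```python
humanize.intword(single_head_attention_flops(197, 768, 64))  # '68.1 million'
```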
```python
def attention_flops(seq_length, latent_dim, heads):
    """Multi-head attention with a dense layer at the end."""
    flops = 0
    sa_flops = single_head_attention_flops(
        seq_length, latent_dim, latent_dim // heads
    )
    flops += heads * sa_flops
    # There is a dense layer after the attention heads
    flops += matmul(seq_length, latent_dim, latent_dim)
    return flops
```
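Note that for a fixed `latent_dim` the projection matmuls cost the same regardless of the head count; only the attention-score, scaling and softmax terms shift with `heads`. A quick comparison at the ViT-B/16 shape:

```python
for heads in (1, 12):
    print(heads, humanize.intword(attention_flops(197, 768, heads)))
```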
```python
def mlp_flops(seq_length, latent_dim, mlp_dim):
    """2-layer network with a nonlinearity between the layers."""
    flops = 0
    flops += matmul(seq_length, latent_dim, mlp_dim)  # First dense layer
    flops += seq_length * mlp_dim                     # Activation (assumed to be one operation)
    flops += matmul(seq_length, mlp_dim, latent_dim)  # Second dense layer
    return flops
```
```python
def encoder_flops(seq_length, latent_dim, heads, mlp_dim):
    """One encoder block as in Fig. 1 (right) in [1].

    [1] An Image Is Worth 16x16 Words.
        https://arxiv.org/pdf/2010.11929.pdf
    """
    flops = 0
    flops += layer_norm_flops(n=seq_length, d=latent_dim)
    flops += attention_flops(seq_length, latent_dim, heads)
    flops += seq_length * latent_dim  # Residual connection
    flops += layer_norm_flops(n=seq_length, d=latent_dim)
    flops += mlp_flops(seq_length, latent_dim, mlp_dim)
    flops += seq_length * latent_dim  # Residual connection
    return flops
```
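As a rough sketch of where the budget goes, here is the split for one block at the ViT-B/16 dimensions used below; the two dense MLP matmuls take the larger share.

```python
s, d, h, m = 197, 768, 12, 3072
print("attention:", humanize.intword(attention_flops(s, d, h)))
print("mlp:      ", humanize.intword(mlp_flops(s, d, m)))
print("block:    ", humanize.intword(encoder_flops(s, d, h, m)))
```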
```python
def count_transformer_flops(
    image_size, patch_size, layers, latent_dim, heads, mlp_dim
):
    """Count FLOPS of a vision transformer model."""
    if image_size % patch_size != 0:
        raise ValueError("image_size must be a multiple of patch_size.")
    if latent_dim % heads != 0:
        raise ValueError("latent_dim must be a multiple of heads.")

    seq_length = (image_size // patch_size) ** 2
    input_dim = 3 * patch_size ** 2

    flops = 0
    # Project input dimension to latent dimension
    flops += matmul(seq_length, input_dim, latent_dim)
    seq_length += 1  # Add class embedding to sequence
    flops += seq_length * latent_dim  # Add positional embedding
    flops += layers * encoder_flops(seq_length, latent_dim, heads, mlp_dim)
    return flops
```
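The count covers the patch projection, embeddings and the encoder stack only; bias additions and the final classification head are ignored. As a scaling check (a hypothetical 448 px input, same ViT-B/16 hyperparameters): the attention scores grow quadratically in sequence length, so four times as many patches cost a bit more than four times the FLOPs.

```python
ratio = (
    count_transformer_flops(448, 16, 12, 768, 12, 3072)
    / count_transformer_flops(224, 16, 12, 768, 12, 3072)
)
print(round(ratio, 2))  # somewhat above 4.0
```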
```python
# ViT-B/16 model
vit_b = count_transformer_flops(
    image_size=224,
    patch_size=16,
    layers=12,
    latent_dim=768,
    heads=12,
    mlp_dim=3072,
)
```
```python
# ViT-L/16 model
vit_l = count_transformer_flops(
    image_size=224,
    patch_size=16,
    layers=24,
    latent_dim=1024,
    heads=16,
    mlp_dim=4096,
)
```
```python
# ViT-H/14 model
vit_h = count_transformer_flops(
    image_size=224,
    patch_size=14,
    layers=32,
    latent_dim=1280,
    heads=16,
    mlp_dim=5120,
)
```
},
{
"cell_type": "code",
"execution_count": 13,
"id": "3f74fae2",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'35.2 billion'"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"humanize.intword(vit_b)"
]
},
```python
humanize.intword(vit_l)
```

Output: `'123.2 billion'`
```python
humanize.intword(vit_h)
```

Output: `'334.8 billion'`
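The same numbers in GFLOPs, assuming the common convention that one multiply-accumulate (MAC) equals two FLOPs; halving recovers the MAC counts that many profiling tools report instead.

```python
for name, flops in [("ViT-B/16", vit_b), ("ViT-L/16", vit_l), ("ViT-H/14", vit_h)]:
    print(f"{name}: {flops / 1e9:.1f} GFLOPs, ~{flops / 2e9:.1f} GMACs")
```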