Created
September 18, 2025 07:58
-
-
Save peterroelants/19a3e8988f2c989982ba9105a902ad8a to your computer and use it in GitHub Desktop.
Illustration of the difference between a 2x2 conv and a patchify + 1x1 conv
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| { | |
| "cells": [ | |
| { | |
| "cell_type": "markdown", | |
| "id": "6ed99221", | |
| "metadata": {}, | |
| "source": [ | |
| "# Illustration of the equivalence between a stride-2 2x2 conv and a patchify + 1x1 conv" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 1, | |
| "id": "0c6ab949", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "import torch\n", | |
| "import einops" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 2, | |
| "id": "aae9b1f3", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "def patchify_3d(\n", | |
| "    x: torch.Tensor,  # [B, C, T, H, W]\n", | |
| "    patch_size_time: int,\n", | |
| "    patch_size_height: int,\n", | |
| "    patch_size_width: int,\n", | |
| ") -> torch.Tensor:  # [B, C*pst*psh*psw, T//pst, H//psh, W//psw]\n", | |
| "    \"\"\"\n", | |
| "    Patchify input tensor (space-to-depth).\n", | |
| "\n", | |
| "    Every non-overlapping (pst, psh, psw) block is folded into the channel axis,\n", | |
| "    where pst/psh/psw are patch_size_time/patch_size_height/patch_size_width.\n", | |
| "    The output channel axis is ordered (c, pst, psh, psw): the original channel\n", | |
| "    index varies slowest. unpatchify_3d performs the exact inverse rearrange.\n", | |
| "\n", | |
| "    This is the 3D analogue of torch.nn.functional.pixel_unshuffle (same channel\n", | |
| "    ordering); pixel_shuffle is the inverse, depth-to-space, operation.\n", | |
| "\n", | |
| "    Args:\n", | |
| "        x: Input of shape [B, C, T, H, W]. T, H and W must be divisible by\n", | |
| "            pst, psh and psw respectively (einops raises an error otherwise).\n", | |
| "        patch_size_time: Patch size pst along T.\n", | |
| "        patch_size_height: Patch size psh along H.\n", | |
| "        patch_size_width: Patch size psw along W.\n", | |
| "\n", | |
| "    Returns:\n", | |
| "        Tensor of shape [B, C*pst*psh*psw, T//pst, H//psh, W//psw].\n", | |
| "    \"\"\"\n", | |
| "    return einops.rearrange(\n", | |
| "        x,\n", | |
| "        \"b c (t pst) (h psh) (w psw) -> b (c pst psh psw) t h w\",\n", | |
| "        pst=patch_size_time,\n", | |
| "        psh=patch_size_height,\n", | |
| "        psw=patch_size_width,\n", | |
| "    )\n", | |
| "\n", | |
| "\n", | |
| "def unpatchify_3d(\n", | |
| "    x: torch.Tensor,  # [B, C*pst*psh*psw, T, H, W]\n", | |
| "    patch_size_time: int,\n", | |
| "    patch_size_height: int,\n", | |
| "    patch_size_width: int,\n", | |
| ") -> torch.Tensor:  # [B, C, T*pst, H*psh, W*psw]\n", | |
| "    \"\"\"\n", | |
| "    Unpatchify input tensor (depth-to-space); inverse of patchify_3d.\n", | |
| "\n", | |
| "    The channel axis of x is split as (c, pst, psh, psw), the same ordering\n", | |
| "    patchify_3d produces, and the patch factors are moved back into the time,\n", | |
| "    height and width axes. This is the 3D analogue of\n", | |
| "    torch.nn.functional.pixel_shuffle.\n", | |
| "\n", | |
| "    Args:\n", | |
| "        x: Input of shape [B, C*pst*psh*psw, T, H, W]. The channel count must be\n", | |
| "            divisible by pst*psh*psw (einops raises an error otherwise).\n", | |
| "        patch_size_time: Patch size pst along T.\n", | |
| "        patch_size_height: Patch size psh along H.\n", | |
| "        patch_size_width: Patch size psw along W.\n", | |
| "\n", | |
| "    Returns:\n", | |
| "        Tensor of shape [B, C, T*pst, H*psh, W*psw].\n", | |
| "    \"\"\"\n", | |
| "    return einops.rearrange(\n", | |
| "        x,\n", | |
| "        \"b (c pst psh psw) t h w -> b c (t pst) (h psh) (w psw)\",\n", | |
| "        pst=patch_size_time,\n", | |
| "        psh=patch_size_height,\n", | |
| "        psw=patch_size_width,\n", | |
| "    )" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 3, | |
| "id": "fbd437c7", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "x.shape=torch.Size([4, 48, 3, 10, 12])\n", | |
| "y.shape=torch.Size([4, 128, 3, 5, 6])\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "# Fix the global RNG seed so the random inputs below (and, when the notebook is run\n", | |
| "# top to bottom, the default conv weight initialisation) are reproducible.\n", | |
| "torch.manual_seed(0)\n", | |
| "\n", | |
| "nb_channels_out = 128\n", | |
| "\n", | |
| "batch_size = 4\n", | |
| "nb_channels_in = 48\n", | |
| "time_steps = 3\n", | |
| "height = 10\n", | |
| "width = 12\n", | |
| "\n", | |
| "x = torch.randn(batch_size, nb_channels_in, time_steps, height, width)\n", | |
| "print(f\"{x.shape=}\")\n", | |
| "\n", | |
| "# `y` is both the reference downsampled shape (H and W halved by the 2x2 patch) and the\n", | |
| "# input of the upsample cells, so it must hold real data: torch.empty returns\n", | |
| "# uninitialized memory (arbitrary values, possibly NaN/inf).\n", | |
| "y = torch.randn(batch_size, nb_channels_out, time_steps, height // 2, width // 2)\n", | |
| "print(f\"{y.shape=}\")" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "6167cf13", | |
| "metadata": {}, | |
| "source": [ | |
| "## Downsample" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 4, | |
| "id": "b4c59ef6", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "x_down.shape=torch.Size([4, 128, 3, 5, 6])\n", | |
| "conv_down.weight.data.shape=torch.Size([128, 48, 1, 2, 2])\n", | |
| "Total parameters in conv_down: 24,576\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "# 2x2 conv downsample: kernel == stride == (1, 2, 2), so the conv reads\n", | |
| "# non-overlapping 2x2 windows in H and W and leaves the time axis untouched.\n", | |
| "conv_down = torch.nn.Conv3d(\n", | |
| "    nb_channels_in,\n", | |
| "    nb_channels_out,\n", | |
| "    kernel_size=(1, 2, 2),\n", | |
| "    stride=(1, 2, 2),\n", | |
| "    padding=(0, 0, 0),\n", | |
| "    bias=False,\n", | |
| ")\n", | |
| "\n", | |
| "x_down = conv_down(x)\n", | |
| "print(f\"{x_down.shape=}\")\n", | |
| "assert x_down.shape == y.shape\n", | |
| "\n", | |
| "# Conv3d weight layout: [out_channels, in_channels, kT, kH, kW].\n", | |
| "print(f\"{conv_down.weight.data.shape=}\")\n", | |
| "total_params = sum(map(torch.numel, conv_down.parameters()))\n", | |
| "print(f\"Total parameters in conv_down: {total_params:,}\")\n" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 5, | |
| "id": "b3959860", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "x_patched.shape=torch.Size([4, 192, 3, 5, 6])\n", | |
| "x_patched_down.shape=torch.Size([4, 128, 3, 5, 6])\n", | |
| "conv_1x1_down.weight.data.shape=torch.Size([128, 192, 1, 1, 1])\n", | |
| "Total parameters in conv_1x1: 24,576\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "# Patchify and 1x1 conv\n", | |
| "# Fold every 2x2 spatial patch into channels: [B, C, T, H, W] -> [B, C*4, T, H/2, W/2].\n", | |
| "x_patched = patchify_3d(x, patch_size_time=1, patch_size_height=2, patch_size_width=2)\n", | |
| "print(f\"{x_patched.shape=}\")\n", | |
| "\n", | |
| "\n", | |
| "conv_1x1_down = torch.nn.Conv3d(\n", | |
| "    in_channels=nb_channels_in * 4,  # C * (1 * 2 * 2) channels after patchify\n", | |
| "    out_channels=nb_channels_out,\n", | |
| "    kernel_size=(1, 1, 1),\n", | |
| "    padding=(0, 0, 0),\n", | |
| "    stride=(1, 1, 1),\n", | |
| "    bias=False,\n", | |
| ")\n", | |
| "# Copy the 2x2 conv's weights into the 1x1 conv (these are NOT identity weights):\n", | |
| "# each [C_in, 1, 2, 2] kernel is flattened to C_in*4 input channels in the same\n", | |
| "# (c, ph, pw) order that patchify_3d uses, so both layers compute the same output.\n", | |
| "# copy_ under no_grad keeps the existing Parameter instead of swapping `.weight.data`.\n", | |
| "with torch.no_grad():\n", | |
| "    conv_1x1_down.weight.copy_(\n", | |
| "        einops.rearrange(\n", | |
| "            conv_down.weight,\n", | |
| "            'o i t (h ph) (w pw) -> o (i ph pw) t h w',\n", | |
| "            ph=2,\n", | |
| "            pw=2,\n", | |
| "        )\n", | |
| "    )\n", | |
| "\n", | |
| "x_patched_down = conv_1x1_down(x_patched)\n", | |
| "print(f\"{x_patched_down.shape=}\")\n", | |
| "assert x_patched_down.shape == y.shape\n", | |
| "\n", | |
| "print(f\"{conv_1x1_down.weight.data.shape=}\")\n", | |
| "total_params_1x1 = sum(p.numel() for p in conv_1x1_down.parameters())\n", | |
| "print(f\"Total parameters in conv_1x1: {total_params_1x1:,}\")\n" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 6, | |
| "id": "fe93c965", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "torch.abs(x_patched_down - x_down).max()=tensor(1.0729e-06, grad_fn=<MaxBackward1>)\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "assert x_down.shape == x_patched_down.shape == y.shape\n", | |
| "print(f\"{torch.abs(x_patched_down - x_down).max()=}\")\n", | |
| "# Both paths sum the same 48*4 = 192 products per output, but not necessarily in the\n", | |
| "# same order, so float32 rounding makes them differ by ~1e-6. atol=1e-6 is too tight\n", | |
| "# (it only passed thanks to the default rtol), so allow float32-level absolute error.\n", | |
| "assert torch.allclose(x_down, x_patched_down, atol=1e-5)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "0f21265c", | |
| "metadata": {}, | |
| "source": [ | |
| "## Upsample" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 7, | |
| "id": "756acfee", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "y_up.shape=torch.Size([4, 48, 3, 10, 12])\n", | |
| "conv_up.weight.data.shape=torch.Size([128, 48, 1, 2, 2])\n", | |
| "Total parameters in conv_up: 24,576\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "# 2x2 conv upsample: transposed conv with kernel == stride == (1, 2, 2), so each input\n", | |
| "# position produces its own non-overlapping 2x2 patch of the output.\n", | |
| "conv_up = torch.nn.ConvTranspose3d(\n", | |
| "    nb_channels_out,\n", | |
| "    nb_channels_in,\n", | |
| "    kernel_size=(1, 2, 2),\n", | |
| "    stride=(1, 2, 2),\n", | |
| "    padding=(0, 0, 0),\n", | |
| "    bias=False,\n", | |
| ")\n", | |
| "\n", | |
| "y_up = conv_up(y)\n", | |
| "print(f\"{y_up.shape=}\")\n", | |
| "assert y_up.shape == x.shape\n", | |
| "# Unlike Conv3d, ConvTranspose3d stores its weight as [in_channels, out_channels, kT, kH, kW].\n", | |
| "print(f\"{conv_up.weight.data.shape=}\")\n", | |
| "total_params = sum(map(torch.numel, conv_up.parameters()))\n", | |
| "print(f\"Total parameters in conv_up: {total_params:,}\")" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 8, | |
| "id": "dbcf12dc", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "y_proj.shape=torch.Size([4, 192, 3, 5, 6])\n", | |
| "y_unpatched.shape=torch.Size([4, 48, 3, 10, 12])\n", | |
| "conv1x1_up.weight.data.shape=torch.Size([192, 128, 1, 1, 1])\n", | |
| "Total parameters in conv1x1_up: 24,576\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "# 1x1 projection and unpatchify\n", | |
| "conv1x1_up = torch.nn.Conv3d(\n", | |
| "    in_channels=nb_channels_out,\n", | |
| "    out_channels=nb_channels_in * 4,  # C * (1 * 2 * 2) channels, folded back by unpatchify_3d\n", | |
| "    kernel_size=(1, 1, 1),\n", | |
| "    padding=(0, 0, 0),\n", | |
| "    stride=(1, 1, 1),\n", | |
| "    bias=False,\n", | |
| ")\n", | |
| "# Copy the transposed conv's weights into the 1x1 conv (these are NOT identity weights).\n", | |
| "# ConvTranspose3d weights are laid out [in_channels, out_channels, kT, kH, kW], so the\n", | |
| "# first axis (cin) is this conv's input axis and (cout ph pw) - the channel order that\n", | |
| "# unpatchify_3d splits - becomes its output axis.\n", | |
| "with torch.no_grad():\n", | |
| "    conv1x1_up.weight.copy_(\n", | |
| "        einops.rearrange(\n", | |
| "            conv_up.weight,\n", | |
| "            'cin cout t (h ph) (w pw) -> (cout ph pw) cin t h w',\n", | |
| "            ph=2,\n", | |
| "            pw=2,\n", | |
| "        )\n", | |
| "    )\n", | |
| "\n", | |
| "y_proj = conv1x1_up(y)\n", | |
| "print(f\"{y_proj.shape=}\")\n", | |
| "\n", | |
| "y_unpatched = unpatchify_3d(\n", | |
| "    y_proj,\n", | |
| "    patch_size_time=1,\n", | |
| "    patch_size_height=2,\n", | |
| "    patch_size_width=2,\n", | |
| ")\n", | |
| "print(f\"{y_unpatched.shape=}\")\n", | |
| "assert y_unpatched.shape == x.shape\n", | |
| "\n", | |
| "print(f\"{conv1x1_up.weight.data.shape=}\")\n", | |
| "total_params_unpatch = sum(p.numel() for p in conv1x1_up.parameters())\n", | |
| "print(f\"Total parameters in conv1x1_up: {total_params_unpatch:,}\")\n" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 9, | |
| "id": "0f51c0af", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "torch.abs(y_up - y_unpatched).max()=tensor(0., grad_fn=<MaxBackward1>)\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "assert y_up.shape == y_unpatched.shape == x.shape\n", | |
| "print(f\"{torch.abs(y_up - y_unpatched).max()=}\")\n", | |
| "# The transposed conv and the 1x1 conv need not accumulate their 128-term sums in the\n", | |
| "# same order, so allow float32 rounding error instead of relying on an exact match.\n", | |
| "assert torch.allclose(y_up, y_unpatched, atol=1e-5)" | |
| ] | |
| } | |
| ], | |
| "metadata": { | |
| "kernelspec": { | |
| "display_name": "anam-audio-to-latent", | |
| "language": "python", | |
| "name": "python3" | |
| }, | |
| "language_info": { | |
| "codemirror_mode": { | |
| "name": "ipython", | |
| "version": 3 | |
| }, | |
| "file_extension": ".py", | |
| "mimetype": "text/x-python", | |
| "name": "python", | |
| "nbconvert_exporter": "python", | |
| "pygments_lexer": "ipython3", | |
| "version": "3.12.11" | |
| } | |
| }, | |
| "nbformat": 4, | |
| "nbformat_minor": 5 | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment