MAE.ipynb (@AhmedCoolProjects, created December 4, 2025)
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"authorship_tag": "ABX9TyOrg5d2Ud7O5FFAVXw214gS",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/AhmedCoolProjects/6ce9ae11dac1f46bed689d3a956c5856/mae.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"id": "GTj7DGOc5s7x"
},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"import numpy as np\n",
"\n",
"class MAE(nn.Module):\n",
"    def __init__(self,\n",
"                 img_size=224,\n",
"                 patch_size=16,\n",
"                 in_chans=3,\n",
"                 embed_dim=1024,\n",
"                 depth=24,\n",
"                 num_heads=16,\n",
"                 decoder_embed_dim=512,\n",
"                 decoder_depth=8,\n",
"                 decoder_num_heads=16,\n",
"                 mlp_ratio=4.,\n",
"                 mask_ratio=0.75):\n",
"        super().__init__()\n",
"\n",
"        self.img_size = img_size\n",
"        self.patch_size = patch_size\n",
"        self.grid_size = img_size // patch_size\n",
"        self.num_patches = self.grid_size ** 2\n",
"        self.patch_dim = patch_size * patch_size * in_chans\n",
"        self.mask_ratio = mask_ratio\n",
"\n",
"        # --------------------------------------------------------------------------\n",
"        # MAE Encoder\n",
"        # --------------------------------------------------------------------------\n",
"        # Patch embedding: a strided conv that maps each P x P patch to one token\n",
"        self.patch_embed = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)\n",
"\n",
"        # Class token (optional in MAE, but used in the reference model) and positional embedding\n",
"        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))\n",
"        self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches + 1, embed_dim))\n",
"\n",
"        # Transformer encoder blocks\n",
"        # (note: nn.TransformerEncoderLayer defaults to post-norm; the reference\n",
"        # ViT/MAE blocks are pre-norm, i.e. norm_first=True)\n",
"        encoder_layer = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads,\n",
"                                                   dim_feedforward=int(embed_dim * mlp_ratio),\n",
"                                                   activation='gelu', batch_first=True)\n",
"        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=depth)\n",
"        self.encoder_norm = nn.LayerNorm(embed_dim)\n",
"\n",
"        # --------------------------------------------------------------------------\n",
"        # MAE Decoder\n",
"        # --------------------------------------------------------------------------\n",
"        # Project encoder embeddings to the (narrower) decoder dimension\n",
"        self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True)\n",
"\n",
"        # Shared learnable mask token\n",
"        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))\n",
"\n",
"        # Decoder positional embedding\n",
"        self.decoder_pos_embed = nn.Parameter(torch.zeros(1, self.num_patches + 1, decoder_embed_dim))\n",
"\n",
"        # Transformer decoder blocks (plain self-attention blocks over the full sequence)\n",
"        decoder_layer = nn.TransformerEncoderLayer(d_model=decoder_embed_dim, nhead=decoder_num_heads,\n",
"                                                   dim_feedforward=int(decoder_embed_dim * mlp_ratio),\n",
"                                                   activation='gelu', batch_first=True)\n",
"        self.decoder = nn.TransformerEncoder(decoder_layer, num_layers=decoder_depth)\n",
"        self.decoder_norm = nn.LayerNorm(decoder_embed_dim)\n",
"\n",
"        # Prediction head: one pixel-space patch per token\n",
"        self.decoder_pred = nn.Linear(decoder_embed_dim, self.patch_dim, bias=True)\n",
"\n",
"        self.initialize_weights()\n",
"\n",
"    def initialize_weights(self):\n",
"        # The reference MAE initializes pos_embed with fixed 2D sin-cos values\n",
"        # (a helper is sketched in the next cell); for brevity, simple normal\n",
"        # initialization is used here.\n",
"        nn.init.xavier_uniform_(self.patch_embed.weight)\n",
"        nn.init.normal_(self.cls_token, std=.02)\n",
"        nn.init.normal_(self.mask_token, std=.02)\n",
"        nn.init.normal_(self.pos_embed, std=.02)\n",
"        nn.init.normal_(self.decoder_pos_embed, std=.02)\n",
"\n",
"    def patchify(self, imgs):\n",
"        \"\"\"\n",
"        imgs: (N, 3, H, W)\n",
"        x: (N, L, patch_size**2 * 3)\n",
"        \"\"\"\n",
"        p = self.patch_size\n",
"        assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0\n",
"\n",
"        h = w = imgs.shape[2] // p\n",
"        x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p))\n",
"        x = torch.einsum('nchpwq->nhwpqc', x)\n",
"        x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 3))\n",
"        return x\n",
"\n",
"    def random_masking(self, x, mask_ratio):\n",
"        \"\"\"\n",
"        Perform per-sample random masking by per-sample shuffling.\n",
"        x: [N, L, D], sequence\n",
"        \"\"\"\n",
"        N, L, D = x.shape  # batch, length, dim\n",
"        len_keep = int(L * (1 - mask_ratio))\n",
"\n",
"        noise = torch.rand(N, L, device=x.device)  # noise in [0, 1]\n",
"\n",
"        # sort noise for each sample\n",
"        ids_shuffle = torch.argsort(noise, dim=1)  # ascending: small is keep, large is remove\n",
"        ids_restore = torch.argsort(ids_shuffle, dim=1)\n",
"\n",
"        # keep the first subset\n",
"        ids_keep = ids_shuffle[:, :len_keep]\n",
"        x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))\n",
"\n",
"        # generate the binary mask: 0 is keep, 1 is remove\n",
"        mask = torch.ones([N, L], device=x.device)\n",
"        mask[:, :len_keep] = 0\n",
"        # unshuffle to get the binary mask\n",
"        mask = torch.gather(mask, dim=1, index=ids_restore)\n",
"\n",
"        return x_masked, mask, ids_restore\n",
"\n",
"    def forward_encoder(self, x):\n",
"        # embed patches\n",
"        x = self.patch_embed(x)  # [N, C, H, W] -> [N, Embed, H/P, W/P]\n",
"        x = x.flatten(2).transpose(1, 2)  # [N, Embed, L] -> [N, L, Embed]\n",
"\n",
"        # add pos embed w/o cls token\n",
"        x = x + self.pos_embed[:, 1:, :]\n",
"\n",
"        # masking: length -> length * (1 - mask_ratio)\n",
"        x, mask, ids_restore = self.random_masking(x, self.mask_ratio)\n",
"\n",
"        # prepend cls token\n",
"        cls_token = self.cls_token + self.pos_embed[:, :1, :]\n",
"        cls_tokens = cls_token.expand(x.shape[0], -1, -1)\n",
"        x = torch.cat((cls_tokens, x), dim=1)\n",
"\n",
"        # apply Transformer blocks\n",
"        x = self.encoder(x)\n",
"        x = self.encoder_norm(x)\n",
"\n",
"        return x, mask, ids_restore\n",
"\n",
"    def forward_decoder(self, x, ids_restore):\n",
"        # embed tokens\n",
"        x = self.decoder_embed(x)\n",
"\n",
"        # append mask tokens to sequence\n",
"        mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)\n",
"        x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1)  # no cls token\n",
"        x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]))  # unshuffle\n",
"        x = torch.cat([x[:, :1, :], x_], dim=1)  # prepend cls token\n",
"\n",
"        # add pos embed\n",
"        x = x + self.decoder_pos_embed\n",
"\n",
"        # apply Transformer blocks\n",
"        x = self.decoder(x)\n",
"        x = self.decoder_norm(x)\n",
"\n",
"        # predictor projection\n",
"        x = self.decoder_pred(x)\n",
"\n",
"        # remove cls token\n",
"        x = x[:, 1:, :]\n",
"\n",
"        return x\n",
"\n",
"    def forward_loss(self, imgs, pred, mask):\n",
"        \"\"\"\n",
"        imgs: [N, 3, H, W]\n",
"        pred: [N, L, p*p*3]\n",
"        mask: [N, L], 0 is keep, 1 is remove\n",
"        \"\"\"\n",
"        target = self.patchify(imgs)\n",
"\n",
"        # mean squared error per patch\n",
"        loss = (pred - target) ** 2\n",
"        loss = loss.mean(dim=-1)  # [N, L], mean loss per patch\n",
"\n",
"        # compute the loss only on masked (removed) patches\n",
"        loss = (loss * mask).sum() / mask.sum()\n",
"        return loss\n",
"\n",
"    def forward(self, imgs):\n",
"        latent, mask, ids_restore = self.forward_encoder(imgs)\n",
"        pred = self.forward_decoder(latent, ids_restore)\n",
"        loss = self.forward_loss(imgs, pred, mask)\n",
"        return loss, pred, mask\n"
]
},
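{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `initialize_weights` comment above notes that the reference MAE uses *fixed 2D sin-cos* positional embeddings rather than the learned, normally-initialized tables used here. Below is a minimal sketch of such a helper, following the standard sin-cos construction; the function names and the commented-out copy-in snippet are illustrative additions, not part of the original notebook."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):\n",
"    # pos: positions (any shape) -> (M, embed_dim) with sin/cos frequency bands\n",
"    assert embed_dim % 2 == 0\n",
"    omega = np.arange(embed_dim // 2, dtype=np.float64) / (embed_dim / 2.)\n",
"    omega = 1. / 10000 ** omega                         # (D/2,) frequencies\n",
"    out = np.einsum('m,d->md', pos.reshape(-1), omega)  # (M, D/2) phase angles\n",
"    return np.concatenate([np.sin(out), np.cos(out)], axis=1)\n",
"\n",
"def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):\n",
"    # half the channels encode the row index, half the column index\n",
"    gw, gh = np.meshgrid(np.arange(grid_size, dtype=np.float64),\n",
"                         np.arange(grid_size, dtype=np.float64))\n",
"    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, gh)  # rows\n",
"    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, gw)  # columns\n",
"    pos_embed = np.concatenate([emb_h, emb_w], axis=1)  # (grid_size**2, D)\n",
"    if cls_token:\n",
"        # an all-zero row for the cls position\n",
"        pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)\n",
"    return pos_embed\n",
"\n",
"# hypothetical usage once a model exists (e.g. the 768-dim model below):\n",
"# pe = get_2d_sincos_pos_embed(768, 14, cls_token=True)\n",
"# model.pos_embed.data.copy_(torch.from_numpy(pe).float().unsqueeze(0))\n",
"# model.pos_embed.requires_grad = False"
]
},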
{
"cell_type": "code",
"source": [
"model = MAE(img_size=224, patch_size=16, embed_dim=768, depth=12, num_heads=12,\n",
"            decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16)\n",
"\n",
"img = torch.randn(2, 3, 224, 224)\n",
"loss, pred, mask = model(img)\n",
"print(f\"Loss: {loss.item()}\")\n",
"print(f\"Prediction shape: {pred.shape}\")"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "HoE52xWT5tjN",
"outputId": "d8be68c0-a036-4276-a62c-150146fc8bc2"
},
"execution_count": 2,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Loss: 1.3435649871826172\n",
"Prediction shape: torch.Size([2, 196, 768])\n"
]
}
]
},
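{
"cell_type": "markdown",
"metadata": {},
"source": [
"To inspect what the decoder predicts, `patchify` can be inverted and the predicted patches pasted into the masked positions. This is a minimal sketch assuming the `model`, `img`, `pred`, and `mask` variables from the cell above; the `unpatchify` helper is an illustrative addition that mirrors the inverse of `patchify`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def unpatchify(x, patch_size=16):\n",
"    # x: (N, L, p*p*3) -> imgs: (N, 3, H, W); assumes a square patch grid\n",
"    p = patch_size\n",
"    h = w = int(x.shape[1] ** .5)\n",
"    x = x.reshape(shape=(x.shape[0], h, w, p, p, 3))\n",
"    x = torch.einsum('nhwpqc->nchpwq', x)  # inverse of the einsum in patchify\n",
"    return x.reshape(shape=(x.shape[0], 3, h * p, w * p))\n",
"\n",
"# keep visible patches from the input; fill masked positions with predictions\n",
"target = model.patchify(img)\n",
"mask_ = mask.unsqueeze(-1)  # [N, L, 1], broadcasts over the patch dimension\n",
"composite = target * (1 - mask_) + pred * mask_\n",
"recon = unpatchify(composite)\n",
"print(recon.shape)  # expected: torch.Size([2, 3, 224, 224])"
]
},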
{
"cell_type": "code",
"source": [],
"metadata": {
"id": "2lz7acjo5vk7"
},
"execution_count": null,
"outputs": []
}
]
}