Created
December 4, 2025 10:28
-
-
Save AhmedCoolProjects/6ce9ae11dac1f46bed689d3a956c5856 to your computer and use it in GitHub Desktop.
MAE.ipynb
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| { | |
| "nbformat": 4, | |
| "nbformat_minor": 0, | |
| "metadata": { | |
| "colab": { | |
| "provenance": [], | |
| "authorship_tag": "ABX9TyOrg5d2Ud7O5FFAVXw214gS", | |
| "include_colab_link": true | |
| }, | |
| "kernelspec": { | |
| "name": "python3", | |
| "display_name": "Python 3" | |
| }, | |
| "language_info": { | |
| "name": "python" | |
| } | |
| }, | |
| "cells": [ | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "id": "view-in-github", | |
| "colab_type": "text" | |
| }, | |
| "source": [ | |
| "<a href=\"https://colab.research.google.com/gist/AhmedCoolProjects/6ce9ae11dac1f46bed689d3a956c5856/mae.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 1, | |
| "metadata": { | |
| "id": "GTj7DGOc5s7x" | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "import torch\n", | |
| "import torch.nn as nn\n", | |
| "import numpy as np\n", | |
| "\n", | |
| "class MAE(nn.Module):\n", | |
| " def __init__(self,\n", | |
| " img_size=224,\n", | |
| " patch_size=16,\n", | |
| " in_chans=3,\n", | |
| " embed_dim=1024,\n", | |
| " depth=24,\n", | |
| " num_heads=16,\n", | |
| " decoder_embed_dim=512,\n", | |
| " decoder_depth=8,\n", | |
| " decoder_num_heads=16,\n", | |
| " mlp_ratio=4.,\n", | |
| " mask_ratio=0.75):\n", | |
| " super().__init__()\n", | |
| "\n", | |
| " self.img_size = img_size\n", | |
| " self.patch_size = patch_size\n", | |
| " self.grid_size = img_size // patch_size\n", | |
| " self.num_patches = self.grid_size ** 2\n", | |
| " self.patch_dim = patch_size * patch_size * in_chans\n", | |
| " self.mask_ratio = mask_ratio\n", | |
| "\n", | |
| " # --------------------------------------------------------------------------\n", | |
| " # MAE Encoder\n", | |
| " # --------------------------------------------------------------------------\n", | |
| " # Patch Embedding\n", | |
| " self.patch_embed = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)\n", | |
| "\n", | |
| " # Class token (optional in MAE, but often used) and Positional Embedding\n", | |
| " self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))\n", | |
| " self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches + 1, embed_dim))\n", | |
| "\n", | |
| " # Transformer Encoder Blocks\n", | |
| " encoder_layer = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads,\n", | |
| " dim_feedforward=int(embed_dim*mlp_ratio),\n", | |
| " activation='gelu', batch_first=True)\n", | |
| " self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=depth)\n", | |
| " self.encoder_norm = nn.LayerNorm(embed_dim)\n", | |
| "\n", | |
| " # --------------------------------------------------------------------------\n", | |
| " # MAE Decoder\n", | |
| " # --------------------------------------------------------------------------\n", | |
| " # Project encoder embedding to decoder dimension\n", | |
| " self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True)\n", | |
| "\n", | |
| " # Mask token\n", | |
| " self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))\n", | |
| "\n", | |
| " # Decoder Positional Embedding\n", | |
| " self.decoder_pos_embed = nn.Parameter(torch.zeros(1, self.num_patches + 1, decoder_embed_dim))\n", | |
| "\n", | |
| " # Transformer Decoder Blocks\n", | |
| " decoder_layer = nn.TransformerEncoderLayer(d_model=decoder_embed_dim, nhead=decoder_num_heads,\n", | |
| " dim_feedforward=int(decoder_embed_dim*mlp_ratio),\n", | |
| " activation='gelu', batch_first=True)\n", | |
| " self.decoder = nn.TransformerEncoder(decoder_layer, num_layers=decoder_depth)\n", | |
| " self.decoder_norm = nn.LayerNorm(decoder_embed_dim)\n", | |
| "\n", | |
| " # Prediction Head\n", | |
| " self.decoder_pred = nn.Linear(decoder_embed_dim, self.patch_dim, bias=True)\n", | |
| "\n", | |
| " self.initialize_weights()\n", | |
| "\n", | |
| " def initialize_weights(self):\n", | |
| " # Simplified initialization: sin-cos positional embeddings are not implemented here.\n", | |
| " # Xavier-uniform for the patch projection; normal (std=0.02) for tokens and positional embeddings.\n", | |
| " nn.init.xavier_uniform_(self.patch_embed.weight)\n", | |
| " nn.init.normal_(self.cls_token, std=.02)\n", | |
| " nn.init.normal_(self.mask_token, std=.02)\n", | |
| " nn.init.normal_(self.pos_embed, std=.02)\n", | |
| " nn.init.normal_(self.decoder_pos_embed, std=.02)\n", | |
| "\n", | |
| " def patchify(self, imgs):\n", | |
| " \"\"\"\n", | |
| " imgs: (N, 3, H, W), with H == W and H divisible by patch_size\n", | |
| " returns x: (N, L, patch_size**2 * 3), where L = (H // patch_size) ** 2\n", | |
| " \"\"\"\n", | |
| " p = self.patch_size\n", | |
| " assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0\n", | |
| "\n", | |
| " h = w = imgs.shape[2] // p\n", | |
| " x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p))\n", | |
| " x = torch.einsum('nchpwq->nhwpqc', x)\n", | |
| " x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 3))\n", | |
| " return x\n", | |
| "\n", | |
| " def random_masking(self, x, mask_ratio):\n", | |
| " \"\"\"\n", | |
| " Perform per-sample random masking by per-sample shuffling.\n", | |
| " x: [N, L, D], sequence\n", | |
| " \"\"\"\n", | |
| " N, L, D = x.shape # batch, length, dim\n", | |
| " len_keep = int(L * (1 - mask_ratio))\n", | |
| "\n", | |
| " noise = torch.rand(N, L, device=x.device) # noise in [0, 1]\n", | |
| "\n", | |
| " # sort noise for each sample\n", | |
| " ids_shuffle = torch.argsort(noise, dim=1) # ascend: small is keep, large is remove\n", | |
| " ids_restore = torch.argsort(ids_shuffle, dim=1)\n", | |
| "\n", | |
| " # keep the first subset\n", | |
| " ids_keep = ids_shuffle[:, :len_keep]\n", | |
| " x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))\n", | |
| "\n", | |
| " # generate the binary mask: 0 is keep, 1 is remove\n", | |
| " mask = torch.ones([N, L], device=x.device)\n", | |
| " mask[:, :len_keep] = 0\n", | |
| " # unshuffle to get the binary mask\n", | |
| " mask = torch.gather(mask, dim=1, index=ids_restore)\n", | |
| "\n", | |
| " return x_masked, mask, ids_restore\n", | |
| "\n", | |
| " def forward_encoder(self, x):\n", | |
| " # embed patches\n", | |
| " x = self.patch_embed(x) # [N, C, H, W] -> [N, Embed, H/P, W/P]\n", | |
| " x = x.flatten(2).transpose(1, 2) # [N, Embed, L] -> [N, L, Embed]\n", | |
| "\n", | |
| " # add pos embed w/o cls token\n", | |
| " x = x + self.pos_embed[:, 1:, :]\n", | |
| "\n", | |
| " # masking: length -> length * (1 - mask_ratio)\n", | |
| " x, mask, ids_restore = self.random_masking(x, self.mask_ratio)\n", | |
| "\n", | |
| " # prepend cls token (with its positional embedding)\n", | |
| " cls_token = self.cls_token + self.pos_embed[:, :1, :]\n", | |
| " cls_tokens = cls_token.expand(x.shape[0], -1, -1)\n", | |
| " x = torch.cat((cls_tokens, x), dim=1)\n", | |
| "\n", | |
| " # apply Transformer blocks\n", | |
| " x = self.encoder(x)\n", | |
| " x = self.encoder_norm(x)\n", | |
| "\n", | |
| " return x, mask, ids_restore\n", | |
| "\n", | |
| " def forward_decoder(self, x, ids_restore):\n", | |
| " # embed tokens\n", | |
| " x = self.decoder_embed(x)\n", | |
| "\n", | |
| " # append mask tokens to sequence\n", | |
| " mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)\n", | |
| " x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1) # no cls token\n", | |
| " x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])) # unshuffle\n", | |
| " x = torch.cat([x[:, :1, :], x_], dim=1) # prepend cls token\n", | |
| "\n", | |
| " # add pos embed\n", | |
| " x = x + self.decoder_pos_embed\n", | |
| "\n", | |
| " # apply Transformer blocks\n", | |
| " x = self.decoder(x)\n", | |
| " x = self.decoder_norm(x)\n", | |
| "\n", | |
| " # predictor projection\n", | |
| " x = self.decoder_pred(x)\n", | |
| "\n", | |
| " # remove cls token\n", | |
| " x = x[:, 1:, :]\n", | |
| "\n", | |
| " return x\n", | |
| "\n", | |
| " def forward_loss(self, imgs, pred, mask):\n", | |
| " \"\"\"\n", | |
| " imgs: [N, 3, H, W]\n", | |
| " pred: [N, L, p*p*3]\n", | |
| " mask: [N, L], 0 is keep, 1 is remove\n", | |
| " \"\"\"\n", | |
| " target = self.patchify(imgs)\n", | |
| "\n", | |
| " # Mean Squared Error\n", | |
| " loss = (pred - target) ** 2\n", | |
| " loss = loss.mean(dim=-1) # [N, L], mean loss per patch\n", | |
| "\n", | |
| " # Apply mask: calculate loss only on masked patches\n", | |
| " loss = (loss * mask).sum() / mask.sum() # mean loss on removed patches\n", | |
| " return loss\n", | |
| "\n", | |
| " def forward(self, imgs):\n", | |
| " latent, mask, ids_restore = self.forward_encoder(imgs)\n", | |
| " pred = self.forward_decoder(latent, ids_restore)\n", | |
| " loss = self.forward_loss(imgs, pred, mask)\n", | |
| " return loss, pred, mask\n" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "model = MAE(img_size=224, patch_size=16, embed_dim=768, depth=12, num_heads=12,\n", | |
| " decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16)\n", | |
| "\n", | |
| "img = torch.randn(2, 3, 224, 224)\n", | |
| "loss, pred, mask = model(img)\n", | |
| "print(f\"Loss: {loss.item()}\")\n", | |
| "print(f\"Prediction shape: {pred.shape}\")" | |
| ], | |
| "metadata": { | |
| "colab": { | |
| "base_uri": "https://localhost:8080/" | |
| }, | |
| "id": "HoE52xWT5tjN", | |
| "outputId": "d8be68c0-a036-4276-a62c-150146fc8bc2" | |
| }, | |
| "execution_count": 2, | |
| "outputs": [ | |
| { | |
| "output_type": "stream", | |
| "name": "stdout", | |
| "text": [ | |
| "Loss: 1.3435649871826172\n", | |
| "Prediction shape: torch.Size([2, 196, 768])\n" | |
| ] | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [], | |
| "metadata": { | |
| "id": "2lz7acjo5vk7" | |
| }, | |
| "execution_count": null, | |
| "outputs": [] | |
| } | |
| ] | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment