GAT.ipynb
Open in Colab: https://colab.research.google.com/gist/AhmedCoolProjects/c0247e0418e9227ef5787830e479614b/gat.ipynb
Imports:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
```
A single graph attention layer (one attention head):

```python
class GATLayer(nn.Module):
    def __init__(self, in_features, out_features, dropout=0.6, alpha=0.2, concat=True):
        """
        Args:
            in_features: Input feature dimension
            out_features: Output feature dimension
            dropout: Dropout probability
            alpha: LeakyReLU negative slope
            concat: Whether to apply the ELU nonlinearity (True) or not (False)
                    (True is usually used for hidden layers, False for the output layer)
        """
        super(GATLayer, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.dropout = dropout
        self.alpha = alpha
        self.concat = concat

        # W: Learnable weight matrix (F x F')
        self.W = nn.Parameter(torch.zeros(size=(in_features, out_features)))
        nn.init.xavier_uniform_(self.W.data, gain=1.414)

        # a: Learnable attention vector (2F' x 1)
        self.a = nn.Parameter(torch.zeros(size=(2 * out_features, 1)))
        nn.init.xavier_uniform_(self.a.data, gain=1.414)

        self.leakyrelu = nn.LeakyReLU(self.alpha)

    def forward(self, h, adj):
        """
        Args:
            h: Input node features (N x in_features)
            adj: Adjacency matrix (N x N)
        """
        # 1. Linear Transformation
        # h: (N, in_features), W: (in_features, out_features) -> Wh: (N, out_features)
        Wh = torch.mm(h, self.W)
        N = Wh.size()[0]

        # 2. Attention Mechanism
        # We need to compute a^T [Wh_i || Wh_j] for all pairs (i, j).
        # Instead of building the concatenated pairs explicitly, we can split a
        # into a1 and a2 (both of size out_features x 1), because
        # a^T [Wh_i || Wh_j] = a1^T Wh_i + a2^T Wh_j

        a1 = self.a[:self.out_features, :]  # (out_features, 1)
        a2 = self.a[self.out_features:, :]  # (out_features, 1)

        # (N, out) x (out, 1) -> (N, 1)
        e1 = torch.matmul(Wh, a1)
        e2 = torch.matmul(Wh, a2)

        # Broadcast add to get an (N, N) matrix where [i, j] is e1[i] + e2[j]
        e = e1 + e2.T

        e = self.leakyrelu(e)

        # 3. Masked Attention (use adjacency to only attend to neighbors)
        # We assume adj is 1 for edges and 0 otherwise.
        # We set non-neighbors to -infinity so softmax makes them 0.
        zero_vec = -9e15 * torch.ones_like(e)
        attention = torch.where(adj > 0, e, zero_vec)

        # 4. Softmax Normalization
        attention = F.softmax(attention, dim=1)
        attention = F.dropout(attention, self.dropout, training=self.training)

        # 5. Aggregation
        # h_prime = attention * Wh
        # (N, N) x (N, out) -> (N, out)
        h_prime = torch.matmul(attention, Wh)

        if self.concat:
            return F.elu(h_prime)
        else:
            return h_prime
```
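Step 2 relies on splitting the attention vector `a` into `a1` and `a2` so that `a^T [Wh_i || Wh_j] = a1^T Wh_i + a2^T Wh_j`, which avoids materializing all N² concatenated pairs. A small sanity check of that identity (illustrative only; the sizes and seed below are arbitrary assumptions, not from the notebook):

```python
# Minimal sanity check (not from the notebook): the vectorized scores
# e1[i] + e2[j] match the explicit a^T [Wh_i || Wh_j] computation.
import torch

torch.manual_seed(0)
N, F_out = 4, 3
Wh = torch.randn(N, F_out)
a = torch.randn(2 * F_out, 1)

# Vectorized form used in GATLayer.forward (pre-LeakyReLU scores)
e_fast = Wh @ a[:F_out, :] + (Wh @ a[F_out:, :]).T      # (N, N)

# Explicit pairwise concatenation [Wh_i || Wh_j] for every (i, j)
pairs = torch.cat(
    [Wh.repeat_interleave(N, dim=0), Wh.repeat(N, 1)], dim=1
)                                                        # (N*N, 2*F_out)
e_slow = (pairs @ a).view(N, N)

print(torch.allclose(e_fast, e_slow))  # True
```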
The full model: multi-head attention in the hidden layer, a single-head output layer:

```python
class GAT(nn.Module):
    def __init__(self, nfeat, nhid, nclass, dropout, alpha, nheads):
        super(GAT, self).__init__()
        self.dropout = dropout

        # Multi-head attention for the first layer
        self.attentions = [
            GATLayer(nfeat, nhid, dropout=dropout, alpha=alpha, concat=True)
            for _ in range(nheads)
        ]

        # Register them as sub-modules so they show up in model.parameters()
        for i, attention in enumerate(self.attentions):
            self.add_module('attention_{}'.format(i), attention)

        # Output layer (single head or averaged multi-head; here we use a single
        # head for classification). Its input is nhid * nheads because the
        # hidden-layer heads are concatenated.
        self.out_att = GATLayer(nhid * nheads, nclass, dropout=dropout, alpha=alpha, concat=False)

    def forward(self, x, adj):
        x = F.dropout(x, self.dropout, training=self.training)

        # Concatenate the outputs of all heads
        x = torch.cat([att(x, adj) for att in self.attentions], dim=1)

        x = F.dropout(x, self.dropout, training=self.training)

        # Output layer
        x = F.elu(self.out_att(x, adj))

        # Log softmax for classification
        return F.log_softmax(x, dim=1)
```
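For comparison, the same two-layer layout (concatenated multi-head hidden layer, single-head output) can be sketched with PyTorch Geometric's GATConv. This is an optional alternative, assuming torch_geometric is installed; GATConv takes a sparse edge_index rather than the dense adjacency matrix used above, and the class name PyGGAT is made up for illustration.

```python
# Hedged sketch: a PyTorch Geometric counterpart of the GAT class above.
import torch
import torch.nn.functional as F
from torch_geometric.nn import GATConv
from torch_geometric.utils import dense_to_sparse


class PyGGAT(torch.nn.Module):
    def __init__(self, nfeat, nhid, nclass, dropout=0.6, alpha=0.2, nheads=2):
        super().__init__()
        self.dropout = dropout
        # Hidden layer: nheads attention heads, outputs concatenated -> nhid * nheads
        self.gat1 = GATConv(nfeat, nhid, heads=nheads, concat=True,
                            negative_slope=alpha, dropout=dropout)
        # Output layer: a single head producing class scores
        self.gat2 = GATConv(nhid * nheads, nclass, heads=1, concat=False,
                            negative_slope=alpha, dropout=dropout)

    def forward(self, x, edge_index):
        x = F.dropout(x, self.dropout, training=self.training)
        x = F.elu(self.gat1(x, edge_index))
        x = F.dropout(x, self.dropout, training=self.training)
        x = self.gat2(x, edge_index)
        return F.log_softmax(x, dim=1)


# GATConv consumes an edge_index, so a dense 0/1 adjacency matrix like the one
# in the next cell would be converted first:
#     edge_index, _ = dense_to_sparse(adj)
```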
A quick smoke test on a random dummy graph:

```python
# Dummy data
N = 5      # Number of nodes
F_in = 10  # Input features
F_out = 2  # Number of classes

x = torch.randn(N, F_in)

# Random adjacency matrix (binary)
adj = torch.randint(0, 2, (N, N)).float()
# Add self-loops
adj = adj + torch.eye(N)
adj[adj > 1] = 1

model = GAT(nfeat=F_in, nhid=8, nclass=F_out, dropout=0.6, alpha=0.2, nheads=2)

output = model(x, adj)
print("Output shape:", output.shape)
print("Output probabilities:\n", torch.exp(output))
```

Output:

```
Output shape: torch.Size([5, 2])
Output probabilities:
 tensor([[4.9513e-01, 5.0487e-01],
        [5.4965e-01, 4.5035e-01],
        [5.4527e-01, 4.5473e-01],
        [1.0000e+00, 3.0683e-13],
        [5.5136e-01, 4.4864e-01]], grad_fn=<ExpBackward0>)
```
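The cell above only runs an untrained forward pass, so the printed probabilities are essentially arbitrary. A minimal training sketch on the same dummy graph (the random labels and hyperparameters below are illustrative assumptions, reusing model, x, adj, N and F_out from the previous cell):

```python
# Illustrative training sketch (not in the notebook): random labels, reusing
# the model, x, adj, N and F_out defined in the cell above.
labels = torch.randint(0, F_out, (N,))           # fake targets for demonstration
optimizer = torch.optim.Adam(model.parameters(), lr=0.005, weight_decay=5e-4)

model.train()
for epoch in range(100):
    optimizer.zero_grad()
    out = model(x, adj)                          # log-probabilities, shape (N, nclass)
    loss = F.nll_loss(out, labels)               # pairs with the log_softmax output
    loss.backward()
    optimizer.step()

model.eval()
with torch.no_grad():
    pred = model(x, adj).argmax(dim=1)
print("Predicted classes:", pred.tolist())
```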