GAT.ipynb
@AhmedCoolProjects · Created December 4, 2025 10:26
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"authorship_tag": "ABX9TyPuoS8Bb6gKor9j0rujVqXu",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/AhmedCoolProjects/c0247e0418e9227ef5787830e479614b/gat.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
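{
"cell_type": "markdown",
"metadata": {},
"source": [
"A from-scratch PyTorch implementation of a Graph Attention Network (GAT, Veličković et al., 2018): a single attention layer with masked, softmax-normalized attention over neighbors, followed by a two-layer multi-head model for node classification."
]
},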
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"id": "x-mWllw-2F3c"
},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F"
]
},
{
"cell_type": "code",
"source": [
"class GATLayer(nn.Module):\n",
" def __init__(self, in_features, out_features, dropout=0.6, alpha=0.2, concat=True):\n",
" \"\"\"\n",
" Args:\n",
" in_features: Input feature dimension\n",
" out_features: Output feature dimension\n",
" dropout: Dropout probability\n",
" alpha: LeakyReLU negative slope\n",
" concat: Whether to concatenate (True) or average (False) activation\n",
" (Concatenation is usually for hidden layers, averaging for output)\n",
" \"\"\"\n",
" super(GATLayer, self).__init__()\n",
" self.in_features = in_features\n",
" self.out_features = out_features\n",
" self.dropout = dropout\n",
" self.alpha = alpha\n",
" self.concat = concat\n",
"\n",
" # W: Learnable weight matrix (F x F')\n",
" self.W = nn.Parameter(torch.zeros(size=(in_features, out_features)))\n",
" nn.init.xavier_uniform_(self.W.data, gain=1.414)\n",
"\n",
" # a: Learnable attention vector (2F' x 1)\n",
" self.a = nn.Parameter(torch.zeros(size=(2 * out_features, 1)))\n",
" nn.init.xavier_uniform_(self.a.data, gain=1.414)\n",
"\n",
" self.leakyrelu = nn.LeakyReLU(self.alpha)\n",
"\n",
" def forward(self, h, adj):\n",
" \"\"\"\n",
" Args:\n",
" h: Input node features (N x in_features)\n",
" adj: Adjacency matrix (N x N)\n",
" \"\"\"\n",
" # 1. Linear Transformation\n",
" # h: (N, in_features), W: (in_features, out_features) -> Wh: (N, out_features)\n",
" Wh = torch.mm(h, self.W)\n",
" N = Wh.size()[0]\n",
"\n",
" # 2. Attention Mechanism\n",
" # We need to compute a^T [Wh_i || Wh_j] for all pairs (i, j).\n",
" # A clever way to vectorize this:\n",
" # a_input = [Wh_i || Wh_j]\n",
" # But we can split a into a1 and a2 (both size out_features x 1)\n",
" # a^T [Wh_i || Wh_j] = a1^T Wh_i + a2^T Wh_j\n",
"\n",
" a1 = self.a[:self.out_features, :] # (out_features, 1)\n",
" a2 = self.a[self.out_features:, :] # (out_features, 1)\n",
"\n",
" # (N, out) x (out, 1) -> (N, 1)\n",
" e1 = torch.matmul(Wh, a1)\n",
" e2 = torch.matmul(Wh, a2)\n",
"\n",
" # Broadcast add to get (N, N) matrix where [i, j] is e1[i] + e2[j]\n",
" # e: (N, N)\n",
" e = e1 + e2.T\n",
"\n",
" e = self.leakyrelu(e)\n",
"\n",
" # 3. Masked Attention (use adjacency to only attend to neighbors)\n",
" # We assume adj is 1 for edges and 0 otherwise.\n",
" # We set non-neighbors to -infinity so softmax makes them 0.\n",
" zero_vec = -9e15 * torch.ones_like(e)\n",
" attention = torch.where(adj > 0, e, zero_vec)\n",
"\n",
" # 4. Softmax Normalization\n",
" attention = F.softmax(attention, dim=1)\n",
" attention = F.dropout(attention, self.dropout, training=self.training)\n",
"\n",
" # 5. Aggregation\n",
" # h_prime = attention * Wh\n",
" # (N, N) x (N, out) -> (N, out)\n",
" h_prime = torch.matmul(attention, Wh)\n",
"\n",
" if self.concat:\n",
" return F.elu(h_prime)\n",
" else:\n",
" return h_prime"
],
"metadata": {
"id": "ga63lVPy2GyG"
},
"execution_count": 2,
"outputs": []
},
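{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick sanity check of the vectorization trick above (added for illustration; the sizes below are arbitrary): splitting `a` into `a1` and `a2` and broadcasting `e1 + e2.T` should give exactly the same scores as explicitly building `[Wh_i || Wh_j]` for every pair and multiplying by `a`."
]
},
{
"cell_type": "code",
"source": [
"torch.manual_seed(0)\n",
"layer = GATLayer(in_features=4, out_features=3)\n",
"h = torch.randn(6, 4)\n",
"Wh = h @ layer.W  # (N, out_features)\n",
"N = Wh.size(0)\n",
"\n",
"# Explicit (slow) pairwise scores: build [Wh_i || Wh_j] for all (i, j)\n",
"pairs = torch.cat([\n",
"    Wh.repeat_interleave(N, dim=0),  # Wh_i: each row repeated N times\n",
"    Wh.repeat(N, 1)                  # Wh_j: the whole block tiled N times\n",
"], dim=1)                            # (N*N, 2*out_features)\n",
"e_explicit = (pairs @ layer.a).view(N, N)\n",
"\n",
"# Vectorized scores, as computed in forward()\n",
"e1 = Wh @ layer.a[:layer.out_features]\n",
"e2 = Wh @ layer.a[layer.out_features:]\n",
"e_fast = e1 + e2.T\n",
"\n",
"print(torch.allclose(e_explicit, e_fast))  # True"
],
"metadata": {},
"execution_count": null,
"outputs": []
},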
{
"cell_type": "code",
"source": [
"class GAT(nn.Module):\n",
" def __init__(self, nfeat, nhid, nclass, dropout, alpha, nheads):\n",
" super(GAT, self).__init__()\n",
" self.dropout = dropout\n",
"\n",
" # Multi-head attention for the first layer\n",
" self.attentions = [\n",
" GATLayer(nfeat, nhid, dropout=dropout, alpha=alpha, concat=True)\n",
" for _ in range(nheads)\n",
" ]\n",
"\n",
" # Register them as sub-modules so they show up in model.parameters()\n",
" for i, attention in enumerate(self.attentions):\n",
" self.add_module('attention_{}'.format(i), attention)\n",
"\n",
" # Output layer (single head or averaged multi-head, here we do simple single head for classification)\n",
" self.out_att = GATLayer(nhid * nheads, nclass, dropout=dropout, alpha=alpha, concat=False)\n",
"\n",
" def forward(self, x, adj):\n",
" x = F.dropout(x, self.dropout, training=self.training)\n",
"\n",
" # Concatenate outputs of all heads\n",
" x = torch.cat([att(x, adj) for att in self.attentions], dim=1)\n",
"\n",
" x = F.dropout(x, self.dropout, training=self.training)\n",
"\n",
" # Output layer\n",
" x = F.elu(self.out_att(x, adj))\n",
"\n",
" # Log softmax for classification\n",
" return F.log_softmax(x, dim=1)"
],
"metadata": {
"id": "-uwGqFRw2Jbx"
},
"execution_count": 3,
"outputs": []
},
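{
"cell_type": "markdown",
"metadata": {},
"source": [
"A small shape check (illustrative; the dimensions are arbitrary): with `nheads` heads of `nhid` units each, the concatenated hidden representation has `nhid * nheads` columns, which is why `out_att` expects `nhid * nheads` input features."
]
},
{
"cell_type": "code",
"source": [
"torch.manual_seed(0)\n",
"m = GAT(nfeat=10, nhid=8, nclass=3, dropout=0.6, alpha=0.2, nheads=2)\n",
"xx = torch.randn(4, 10)\n",
"aa = torch.eye(4)  # adjacency with self-loops only\n",
"\n",
"# First layer: 2 heads x 8 hidden units -> 16 columns\n",
"hidden = torch.cat([att(xx, aa) for att in m.attentions], dim=1)\n",
"print(hidden.shape)     # torch.Size([4, 16])\n",
"print(m(xx, aa).shape)  # torch.Size([4, 3])"
],
"metadata": {},
"execution_count": null,
"outputs": []
},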
{
"cell_type": "code",
"source": [
"# Dummy data\n",
"N = 5 # Number of nodes\n",
"F_in = 10 # Input features\n",
"F_out = 2 # Number of classes\n",
"\n",
"x = torch.randn(N, F_in)\n",
"\n",
"# Random adjacency matrix (binary)\n",
"adj = torch.randint(0, 2, (N, N)).float()\n",
"# Add self-loops\n",
"adj = adj + torch.eye(N)\n",
"adj[adj > 1] = 1\n",
"\n",
"model = GAT(nfeat=F_in, nhid=8, nclass=F_out, dropout=0.6, alpha=0.2, nheads=2)\n",
"\n",
"output = model(x, adj)\n",
"print(\"Output shape:\", output.shape)\n",
"print(\"Output probabilities:\\n\", torch.exp(output))"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "BIcNqxK02Mc_",
"outputId": "e2fc8864-34d7-4051-d866-a42e15431089"
},
"execution_count": 5,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Output shape: torch.Size([5, 2])\n",
"Output probabilities:\n",
" tensor([[4.9513e-01, 5.0487e-01],\n",
" [5.4965e-01, 4.5035e-01],\n",
" [5.4527e-01, 4.5473e-01],\n",
" [1.0000e+00, 3.0683e-13],\n",
" [5.5136e-01, 4.4864e-01]], grad_fn=<ExpBackward0>)\n"
]
}
]
},
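{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal training sketch to close the loop (illustrative assumptions: random dummy labels, and Adam with `lr=0.005` and `weight_decay=5e-4`, the settings the GAT paper uses for Cora). `F.nll_loss` pairs with the model's `log_softmax` output."
]
},
{
"cell_type": "code",
"source": [
"# Dummy targets: one class label per node (illustrative)\n",
"labels = torch.randint(0, F_out, (N,))\n",
"optimizer = torch.optim.Adam(model.parameters(), lr=0.005, weight_decay=5e-4)\n",
"\n",
"model.train()\n",
"for epoch in range(20):\n",
"    optimizer.zero_grad()\n",
"    out = model(x, adj)\n",
"    loss = F.nll_loss(out, labels)  # NLL on log-probabilities\n",
"    loss.backward()\n",
"    optimizer.step()\n",
"\n",
"print(\"final loss:\", loss.item())"
],
"metadata": {},
"execution_count": null,
"outputs": []
}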
]
}