@Algomancer
Last active December 8, 2025 11:42
import torch
import torch.nn as nn
import torch.nn.functional as F


def norm(x):
    # RMS-normalize over the last (feature) dimension.
    return F.rms_norm(x, (x.shape[-1],))


def spectral_init(module):
    # Orthogonal init (unit spectral norm).
    if hasattr(module, 'weight'):
        nn.init.orthogonal_(module.weight)


def zero_init(module):
    if hasattr(module, 'weight'):
        nn.init.zeros_(module.weight)
    if hasattr(module, 'bias') and module.bias is not None:
        nn.init.zeros_(module.bias)


class TopologicalLinearAttention(nn.Module):
    def __init__(self, dim, num_heads=12, num_bins=16, head_dim=32):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.num_bins = num_bins
        self.head_dim = head_dim
        # Height projection (one scalar height per head)
        self.to_heights = nn.Linear(dim, num_heads, bias=False)
        # Value projection
        self.to_v = nn.Linear(dim, num_heads * head_dim, bias=False)
        # Output projection; zero-initialized so the block contributes
        # nothing at initialization (residual-friendly)
        self.to_out = nn.Linear(num_heads * head_dim, dim, bias=False)
        # Learnable sharpness for the soft binning
        self.scale = nn.Parameter(torch.tensor(5.0))
        # Learnable interpolation between the PDF and CDF readouts
        self.alpha = nn.Parameter(torch.tensor(0.5))
        # Fixed bin centers spanning the tanh range [-1, 1]
        self.register_buffer('bin_centers', torch.linspace(-1, 1, num_bins))
        # Init
        spectral_init(self.to_heights)
        spectral_init(self.to_v)
        zero_init(self.to_out)

    def forward(self, x):
        B, L, D = x.shape
        H = self.num_heads
        K = self.num_bins
        Hd = self.head_dim

        x_norm = norm(x)

        # The filtration coordinate: one bounded scalar height per head.
        heights = torch.tanh(self.to_heights(x_norm))              # (B, L, H)
        v = self.to_v(x_norm).view(B, L, H, Hd)                    # (B, L, H, Hd)

        # Soft binning (discretization of the filtration): each token
        # spreads its mass over the bins nearest its height.
        dist = (heights.unsqueeze(-1) - self.bin_centers).pow(2)   # (B, L, H, K)
        bin_weights = F.softmax(-self.scale.abs() * dist, dim=-1)

        # Fold heads into the batch dimension for the scan.
        bin_w = bin_weights.permute(0, 2, 1, 3).reshape(B * H, L, K)
        v_r = v.permute(0, 2, 1, 3).reshape(B * H, L, Hd)

        # State accumulation (the "filtration"): per-bin value sums and
        # per-bin mass, integrated causally over the sequence.
        v_binned = bin_w.unsqueeze(-1) * v_r.unsqueeze(-2)         # (B*H, L, K, Hd)
        z_binned = bin_w.unsqueeze(-1)                             # (B*H, L, K, 1)
        v_state = v_binned.cumsum(dim=1)                           # causal integration
        z_state = z_binned.cumsum(dim=1)

        # Topological readout, PDF branch: each token reads back the
        # accumulated state of its own bin(s).
        num_pdf = torch.einsum('blk,blkd->bld', bin_w, v_state)
        den_pdf = torch.einsum('blk,blk->bl', bin_w, z_state.squeeze(-1)).unsqueeze(-1)
        out_pdf = num_pdf / (den_pdf + 1e-6)

        # CDF / ECT branch: cumulative sum over bins, so each token reads
        # the state of all bins at or below its own height.
        v_cdf_state = v_state.cumsum(dim=2)
        z_cdf_state = z_state.cumsum(dim=2)
        num_cdf = torch.einsum('blk,blkd->bld', bin_w, v_cdf_state)
        den_cdf = torch.einsum('blk,blk->bl', bin_w, z_cdf_state.squeeze(-1)).unsqueeze(-1)
        out_cdf = num_cdf / (den_cdf + 1e-6)

        # Learnable blend of the two readouts.
        alpha = torch.sigmoid(self.alpha)
        out = alpha * out_pdf + (1 - alpha) * out_cdf

        out = out.view(B, H, L, Hd).permute(0, 2, 1, 3).reshape(B, L, H * Hd)
        return self.to_out(out)
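
A minimal smoke test and sanity check, not part of the original gist: the sizes below are arbitrary, and it assumes a PyTorch recent enough to provide F.rms_norm. The second half checks the algebra behind the scan: the PDF branch is a causal linear attention whose score between tokens l and j is the inner product of their bin-weight vectors, so a naive O(L^2) computation should match the cumsum formulation on stand-in tensors.

if __name__ == "__main__":
    torch.manual_seed(0)

    # Smoke test: shapes in and out. Because to_out is zero-initialized,
    # the output is all zeros until training.
    model = TopologicalLinearAttention(dim=384)
    x = torch.randn(2, 64, 384)
    y = model(x)
    print(y.shape)  # torch.Size([2, 64, 384])

    # Sanity check: the cumsum ("filtration") readout of the PDF branch
    # equals naive causal attention with scores s[l, j] = <w_l, w_j>.
    Bh, L, K, Hd = 3, 16, 8, 4
    w = torch.softmax(torch.randn(Bh, L, K), dim=-1)  # stand-in for bin_w
    v = torch.randn(Bh, L, Hd)                        # stand-in for v_r

    # Scan formulation, as in forward().
    v_state = (w.unsqueeze(-1) * v.unsqueeze(-2)).cumsum(dim=1)
    z_state = w.unsqueeze(-1).cumsum(dim=1)
    num = torch.einsum('blk,blkd->bld', w, v_state)
    den = torch.einsum('blk,blk->bl', w, z_state.squeeze(-1)).unsqueeze(-1)
    out_scan = num / (den + 1e-6)

    # Naive O(L^2) formulation with an explicit causal mask.
    scores = torch.einsum('blk,bmk->blm', w, w).tril()
    out_naive = (scores @ v) / (scores.sum(-1, keepdim=True) + 1e-6)

    assert torch.allclose(out_scan, out_naive, atol=1e-5)
    print("PDF branch matches naive causal linear attention.")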