import torch
import torch.nn as nn
import torch.nn.functional as F


def norm(x):
    return F.rms_norm(x, (x.shape[-1],))


def spectral_init(module):
    if hasattr(module, 'weight'):
        nn.init.orthogonal_(module.weight)


def zero_init(module):
    if hasattr(module, 'weight'):
        nn.init.zeros_(module.weight)
    if hasattr(module, 'bias') and module.bias is not None:
        nn.init.zeros_(module.bias)


class TopologicalLinearAttention(nn.Module):
    def __init__(self, dim, num_heads=12, num_bins=16, head_dim=32):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.num_bins = num_bins
        self.head_dim = head_dim
        # Height projection (one scalar height per head)
        self.to_heights = nn.Linear(dim, num_heads, bias=False)
        # Value projection
        self.to_v = nn.Linear(dim, num_heads * head_dim, bias=False)
        # Output projection
        self.to_out = nn.Linear(num_heads * head_dim, dim, bias=False)
        # Learnable sharpness for soft binning
        self.scale = nn.Parameter(torch.tensor(5.0))
        # Learnable interpolation between the PDF and CDF readouts
        self.alpha = nn.Parameter(torch.tensor(0.5))
        # Fixed bin centers spanning the tanh range [-1, 1]
        self.register_buffer('bin_centers', torch.linspace(-1, 1, num_bins))
        # Init
        spectral_init(self.to_heights)
        spectral_init(self.to_v)
        zero_init(self.to_out)

    def forward(self, x):
        B, L, D = x.shape
        H = self.num_heads
        K = self.num_bins
        Hd = self.head_dim
        x_norm = norm(x)

        # The height coordinate in [-1, 1], one per head: (B, L, H)
        heights = torch.tanh(self.to_heights(x_norm))
        v = self.to_v(x_norm).view(B, L, H, Hd)

        # Soft binning (discretization of the filtration): (B, L, H, K)
        dist = (heights.unsqueeze(-1) - self.bin_centers).pow(2)
        bin_weights = F.softmax(-self.scale.abs() * dist, dim=-1)

        # Fold heads into the batch dimension for the causal scan
        bin_w = bin_weights.permute(0, 2, 1, 3).reshape(B * H, L, K)
        v_r = v.permute(0, 2, 1, 3).reshape(B * H, L, Hd)

        # State accumulation (the "filtration"): scatter values into bins
        v_binned = bin_w.unsqueeze(-1) * v_r.unsqueeze(-2)  # (B*H, L, K, Hd)
        # Track how much weight lands in each bin
        z_binned = bin_w.unsqueeze(-1)                       # (B*H, L, K, 1)

        # Causal integration over the sequence
        v_state = v_binned.cumsum(dim=1)
        z_state = z_binned.cumsum(dim=1)

        # Topological readout, PDF branch: average of past values sharing the token's bin
        num_pdf = torch.einsum('blk,blkd->bld', bin_w, v_state)
        den_pdf = torch.einsum('blk,blk->bl', bin_w, z_state.squeeze(-1)).unsqueeze(-1)
        out_pdf = num_pdf / (den_pdf + 1e-6)

        # CDF / ECT branch: cumulate over bins, so each token reads everything at or below its height
        v_cdf_state = v_state.cumsum(dim=2)
        z_cdf_state = z_state.cumsum(dim=2)
        num_cdf = torch.einsum('blk,blkd->bld', bin_w, v_cdf_state)
        den_cdf = torch.einsum('blk,blk->bl', bin_w, z_cdf_state.squeeze(-1)).unsqueeze(-1)
        out_cdf = num_cdf / (den_cdf + 1e-6)

        # Learnable blend of the two readouts
        alpha = torch.sigmoid(self.alpha)
        out = alpha * out_pdf + (1 - alpha) * out_cdf

        out = out.view(B, H, L, Hd).permute(0, 2, 1, 3).reshape(B, L, H * Hd)
        return self.to_out(out)
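

# Usage sketch: a minimal smoke test assuming only the module above and a PyTorch
# build that provides F.rms_norm. The shapes and hyperparameters below are
# illustrative; num_heads * head_dim need not equal dim, since to_out projects
# back to dim.
if __name__ == "__main__":
    torch.manual_seed(0)
    B, L, D = 2, 64, 256
    attn = TopologicalLinearAttention(dim=D, num_heads=12, num_bins=16, head_dim=32)
    x = torch.randn(B, L, D)
    y = attn(x)
    print(y.shape)  # torch.Size([2, 64, 256])
    # to_out is zero-initialized, so the block contributes exactly zero at init
    print(y.abs().max().item())  # 0.0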