Created November 3, 2025 23:29
# ethos_router.py
# A compact, self-contained PyTorch router with sinusoidal ID codes,
# top-b axis selection, and additive beam search.
#
# Shapes (conventions used below):
#   B = batch
#   D = model/embedding dim
#   H = heads
#   A = routing axes per head
#   r = axis query dim (must be even for [cos,sin] codes)
#   S = size of the discrete ID space per axis
#   b = per-axis candidates kept during streaming top-b
#   K = final beams per (B,H)
#
# Public surface:
#   - sinusoidal_codes_from_ids(ids, r)
#   - router_score_topb(...) / router_score_topb_from_queries(...)
#   - beam_search_additive(vals_all, idxs_all, top_k, beam_width)
#   - BeamRouter(d_model, ...)
#
# Notes:
#   * Streaming over S with tiles keeps memory bounded.
#   * Autograd is implemented to propagate gradients to x/W or to q_axes.
#   * All math is FP32; works on CPU/CUDA. IDs are stored as int32 to save memory.

from __future__ import annotations

import math
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

__all__ = [
    "sinusoidal_codes_from_ids",
    "router_score_topb",
    "router_score_topb_from_queries",
    "beam_search_additive",
    "BeamRouter",
]

# -----------------------------
# Sinusoidal ID codebook (ETHOS)
# -----------------------------
def sinusoidal_codes_from_ids(ids: torch.Tensor, r: int) -> torch.Tensor:
    """
    Build [cos, sin]-concatenated sinusoidal codes for integer IDs.

    Args:
        ids: integer tensor [...], each in [0, S)
        r: code dimension (must be even)
    Returns:
        codes: float32 tensor [..., r]
    """
    if r % 2 != 0:
        raise ValueError(f"axis_qdim r must be even, got r={r}")
    device = ids.device
    r_half = r // 2
    # Log-spaced frequencies in (10^0 .. 10^-1). Any smooth spectrum works.
    freq = torch.logspace(0.0, -1.0, steps=r_half, base=10.0, device=device, dtype=torch.float32)
    angle = ids.to(torch.float32).unsqueeze(-1) * freq.unsqueeze(0) * (2.0 * math.pi)
    return torch.cat([torch.cos(angle), torch.sin(angle)], dim=-1).to(torch.float32)
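
# Illustrative sketch: a quick shape/scale check for the codebook. The
# _demo_sinusoidal_codes helper and the sizes below are hypothetical; each
# [cos, sin] pair lies on the unit circle, so every code has squared norm r/2.
def _demo_sinusoidal_codes() -> None:
    ids = torch.arange(5)                        # 5 arbitrary integer IDs
    codes = sinusoidal_codes_from_ids(ids, r=8)  # [5, 8]
    assert codes.shape == (5, 8)
    # r_half = 4 (cos, sin) pairs each contribute 1 to the squared norm.
    assert torch.allclose(codes.pow(2).sum(-1), torch.full((5,), 4.0), atol=1e-5)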

# ============================================================
# 1) RouterScoreTopB: stream over S, keep per-axis top-b
#    (computes queries from x and W, provides grads to both)
# ============================================================
class _RouterScoreTopB(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        x_tile: torch.Tensor,   # [Bt, D] fp32
        query_w: torch.Tensor,  # [H, A, r, D] fp32
        axis_size: int,         # S
        top_b: int,             # b
        s_tile: int,            # tile size over S
    ):
        """
        Returns:
            vals_all: [Bt, H, A, b] (fp32)
            idxs_all: [Bt, H, A, b] (int32)
        """
        if s_tile <= 0:
            raise ValueError("s_tile must be positive")
        Bt, D = x_tile.shape
        H, A, r, Dw = query_w.shape
        if D != Dw:
            raise ValueError(f"Input dim D={D} must match query_w[..., D]={Dw}")
        S = int(axis_size)
        b = min(int(top_b), S)
        device = x_tile.device

        # q_axes: [Bt,H,A,r] (no normalization here)
        q_axes = torch.einsum("bd,hard->bhar", x_tile, query_w)
        # Collapse heads/axes for efficient streaming
        Q = q_axes.permute(1, 2, 0, 3).contiguous().reshape(H * A, Bt, r)

        top_vals = torch.full((H * A, Bt, b), float("-inf"), device=device, dtype=torch.float32)
        top_idx = torch.empty((H * A, Bt, b), device=device, dtype=torch.int32)

        for s0 in range(0, S, s_tile):
            s1 = min(S, s0 + s_tile)
            s_ids = torch.arange(s0, s1, device=device, dtype=torch.int32)           # [s_t]
            K_rt = sinusoidal_codes_from_ids(s_ids, r).transpose(0, 1).contiguous()  # [r, s_t]
            logits = torch.matmul(Q, K_rt)                                           # [H*A, Bt, s_t]
            merged_vals = torch.cat([top_vals, logits], dim=-1)                      # [H*A,Bt,b+s_t]
            merged_idx = torch.cat(
                [top_idx, s_ids[None, None, :].expand(H * A, Bt, s1 - s0)],
                dim=-1,
            )
            keep_vals, keep_pos = torch.topk(merged_vals, k=b, dim=-1)               # [H*A,Bt,b]
            keep_idx = merged_idx.gather(-1, keep_pos)
            top_vals, top_idx = keep_vals, keep_idx

        vals_all = top_vals.view(H, A, Bt, b).permute(2, 0, 1, 3).contiguous()  # [Bt,H,A,b]
        idxs_all = top_idx.view(H, A, Bt, b).permute(2, 0, 1, 3).contiguous()   # [Bt,H,A,b]

        # Save for backward to propagate through q = einsum(x, W)
        ctx.save_for_backward(x_tile, query_w, idxs_all)
        ctx.meta = (H, A, r, D)
        return vals_all, idxs_all

    @staticmethod
    def backward(ctx, g_vals_all: torch.Tensor, _g_idxs_all: Optional[torch.Tensor] = None):
        x_tile, query_w, idxs_all = ctx.saved_tensors
        H, A, r, D = ctx.meta
        if g_vals_all is None:
            return torch.zeros_like(x_tile), torch.zeros_like(query_w), None, None, None
        # Gather codes and accumulate gradient wrt q_axes
        codes = sinusoidal_codes_from_ids(idxs_all, r)       # [Bt,H,A,b,r]
        g_q = (g_vals_all.unsqueeze(-1) * codes).sum(dim=3)  # [Bt,H,A,r]
        # q = einsum('bd,hard->bhar')
        g_x = torch.einsum("bhar,hard->bd", g_q, query_w)    # [Bt,D]
        g_w = torch.einsum("bhar,bd->hard", g_q, x_tile)     # [H,A,r,D]
        return g_x, g_w, None, None, None

def router_score_topb(
    x_tile: torch.Tensor,
    query_w: torch.Tensor,
    axis_size: int,
    top_b: int,
    s_tile: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """See _RouterScoreTopB.forward for return shapes."""
    return _RouterScoreTopB.apply(x_tile, query_w, axis_size, top_b, s_tile)
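
# Illustrative sketch: router_score_topb streams over S in tiles and keeps the
# b best IDs per (batch, head, axis), with gradients reaching both x and W.
# The _demo_router_score_topb helper and all sizes below are hypothetical.
def _demo_router_score_topb() -> None:
    B, D, H, A, r = 2, 16, 2, 3, 8
    x = torch.randn(B, D, requires_grad=True)
    W = torch.randn(H, A, r, D, requires_grad=True)
    vals, idxs = router_score_topb(x, W, axis_size=128, top_b=4, s_tile=32)
    assert vals.shape == (B, H, A, 4) and idxs.dtype == torch.int32
    vals.sum().backward()  # gradients flow through the kept top-b scores
    assert x.grad is not None and W.grad is not None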

# ============================================================
# 1b) From precomputed (normalized) queries; grads wrt q only
# ============================================================
class _RouterScoreTopBFromQueries(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        q_axes: torch.Tensor,  # [Bt,H,A,r] fp32 (already normalized if desired)
        axis_size: int,        # S
        top_b: int,            # b
        s_tile: int,           # tile size over S
    ):
        """
        Returns:
            vals_all: [Bt, H, A, b] (fp32)
            idxs_all: [Bt, H, A, b] (int32)
        """
        if s_tile <= 0:
            raise ValueError("s_tile must be positive")
        Bt, H, A, r = q_axes.shape
        S = int(axis_size)
        b = min(int(top_b), S)
        device = q_axes.device

        Q = q_axes.permute(1, 2, 0, 3).contiguous().reshape(H * A, Bt, r)
        top_vals = torch.full((H * A, Bt, b), float("-inf"), device=device, dtype=torch.float32)
        top_idx = torch.empty((H * A, Bt, b), device=device, dtype=torch.int32)

        for s0 in range(0, S, s_tile):
            s1 = min(S, s0 + s_tile)
            s_ids = torch.arange(s0, s1, device=device, dtype=torch.int32)
            K_rt = sinusoidal_codes_from_ids(s_ids, r).transpose(0, 1).contiguous()  # [r, s_t]
            logits = torch.matmul(Q, K_rt)                                           # [H*A,Bt,s_t]
            merged_vals = torch.cat([top_vals, logits], dim=-1)
            merged_idx = torch.cat(
                [top_idx, s_ids[None, None, :].expand(H * A, Bt, s1 - s0)],
                dim=-1,
            )
            keep_vals, keep_pos = torch.topk(merged_vals, k=b, dim=-1)
            keep_idx = merged_idx.gather(-1, keep_pos)
            top_vals, top_idx = keep_vals, keep_idx

        vals_all = top_vals.view(H, A, Bt, b).permute(2, 0, 1, 3).contiguous()  # [Bt,H,A,b]
        idxs_all = top_idx.view(H, A, Bt, b).permute(2, 0, 1, 3).contiguous()   # [Bt,H,A,b]

        # Only need indices to backprop to q_axes
        ctx.save_for_backward(idxs_all)
        ctx.meta = (r,)
        return vals_all, idxs_all

    @staticmethod
    def backward(ctx, g_vals_all: torch.Tensor, _g_idxs_all: Optional[torch.Tensor] = None):
        (idxs_all,) = ctx.saved_tensors
        (r,) = ctx.meta
        if g_vals_all is None:
            # Zero gradient with the correct [Bt,H,A,r] shape (and matching dtype/device).
            Bt, H, A, _b = idxs_all.shape
            return torch.zeros(Bt, H, A, r, dtype=torch.float32, device=idxs_all.device), None, None, None
        codes = sinusoidal_codes_from_ids(idxs_all, r)       # [Bt,H,A,b,r]
        g_q = (g_vals_all.unsqueeze(-1) * codes).sum(dim=3)  # [Bt,H,A,r]
        return g_q, None, None, None

def router_score_topb_from_queries(
    q_axes: torch.Tensor,
    axis_size: int,
    top_b: int,
    s_tile: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """See _RouterScoreTopBFromQueries.forward for return shapes."""
    return _RouterScoreTopBFromQueries.apply(q_axes, axis_size, top_b, s_tile)
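
# Illustrative sketch: when queries are built (and normalized) elsewhere, this
# variant scores them directly and backpropagates only into q_axes. The
# _demo_router_score_topb_from_queries helper and sizes below are hypothetical.
def _demo_router_score_topb_from_queries() -> None:
    Bt, H, A, r = 2, 2, 3, 8
    q = torch.randn(Bt, H, A, r, requires_grad=True)
    vals, idxs = router_score_topb_from_queries(q, axis_size=64, top_b=5, s_tile=16)
    assert vals.shape == (Bt, H, A, 5) and idxs.shape == (Bt, H, A, 5)
    vals.mean().backward()
    assert q.grad is not None and q.grad.shape == q.shape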

# ============================================================
# 2) Additive beam search over axes; grads -> per-axis b scores
# ============================================================
class _BeamSearchAdditive(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        vals_all: torch.Tensor,  # [B,H,A,b] fp32 (per-axis top-b scores)
        idxs_all: torch.Tensor,  # [B,H,A,b] int32 (matching IDs per axis)
        top_k: int,
        beam_width: int,
    ):
        device = vals_all.device
        B, H, A, b = vals_all.shape
        Bwidth = int(beam_width)

        # Running beam state
        beam_scores = torch.full((B, H, Bwidth), float("-inf"), device=device, dtype=torch.float32)
        beam_idx = torch.empty(B, H, Bwidth, A, device=device, dtype=torch.int32)

        # Initialize from first axis
        v0, i0 = vals_all[:, :, 0, :], idxs_all[:, :, 0, :]
        keep0 = min(Bwidth, b)
        base_vals, base_pos = torch.topk(v0, k=keep0, dim=-1)  # [B,H,keep0]
        beam_scores[:, :, :keep0] = base_vals
        beam_idx[:, :, :keep0, 0] = i0.gather(-1, base_pos)
        cur = keep0

        # Combine axes additively
        for a in range(1, A):
            va = vals_all[:, :, a, :]
            ia = idxs_all[:, :, a, :]
            cur_scores = beam_scores[:, :, :cur].unsqueeze(-1) + va.unsqueeze(-2)  # [B,H,cur,b]
            flat = cur_scores.reshape(B, H, -1)
            keep = min(Bwidth, flat.size(-1))
            kv, kf = torch.topk(flat, k=keep, dim=-1)  # kv: new beam scores, kf: flat indices
            parent, choose = kf // b, kf % b           # map back to (parent-beam, choice in b)
            # Propagate parent history (a >= 1 here, so there is always history to copy)
            parent_exp = parent.unsqueeze(-1).expand(-1, -1, -1, a)  # [B,H,keep,a]
            beam_idx[:, :, :keep, :a] = torch.gather(beam_idx[:, :, :cur, :a], 2, parent_exp)
            beam_idx[:, :, :keep, a] = ia.gather(-1, choose)
            beam_scores[:, :, :keep] = kv
            cur = keep
        K = min(int(top_k), cur)
        final_vals, final_pos = torch.topk(beam_scores[:, :, :cur], k=K, dim=-1)  # [B,H,K]
        coords = torch.gather(
            beam_idx[:, :, :cur, :],
            2,
            final_pos.unsqueeze(-1).expand(-1, -1, -1, A),
        ).to(torch.int16)  # [B,H,K,A] (small signed type; assumes IDs < 32768)

        # Map back which b-slot each coord came from so we can scatter gradients
        idxs_exp = idxs_all.unsqueeze(2)                       # [B,H,1,A,b]
        coords_exp = coords.to(torch.int32).unsqueeze(-1)      # [B,H,K,A,1]
        eq = (idxs_exp == coords_exp)                          # [B,H,K,A,b]
        choose_idx = torch.argmax(eq.to(torch.int32), dim=-1)  # [B,H,K,A] (which 'b' index)

        scores = F.softmax(final_vals, dim=-1)  # [B,H,K]
        ctx.save_for_backward(scores, choose_idx)
        ctx.meta = (B, H, A, b, K)
        return scores, coords
    @staticmethod
    def backward(ctx, g_scores: torch.Tensor, _g_coords: Optional[torch.Tensor] = None):
        scores, choose_idx = ctx.saved_tensors
        B, H, A, b, K = ctx.meta
        if g_scores is None:
            return torch.zeros((B, H, A, b), dtype=torch.float32, device=scores.device), None, None, None
        # Softmax backward on top-K beam logits
        g = g_scores
        s = scores
        dot = (g * s).sum(dim=-1, keepdim=True)
        g_z = (g - dot) * s  # [B,H,K]
        # Scatter back into per-axis top-b slots using choose_idx
        g_vals_all = torch.zeros((B, H, A, b), dtype=torch.float32, device=scores.device)
        for a in range(A):
            tgt = g_vals_all[:, :, a, :]
            idx = choose_idx[:, :, :, a]  # Long by default
            tgt.scatter_add_(dim=-1, index=idx, src=g_z)
        return g_vals_all, None, None, None

def beam_search_additive(
    vals_all: torch.Tensor,
    idxs_all: torch.Tensor,
    top_k: int,
    beam_width: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Args:
        vals_all: [B,H,A,b] per-axis top-b scores (float32)
        idxs_all: [B,H,A,b] per-axis ids matching vals_all (int32)
        top_k: beams to return
        beam_width: internal beam budget
    Returns:
        scores: [B,H,K] softmax over the K beams per (B,H)
        coords: [B,H,K,A] int16, chosen ID per axis for each beam
    """
    return _BeamSearchAdditive.apply(vals_all, idxs_all, top_k, beam_width)
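
# Illustrative worked example: with two axes, beams are formed by adding scores
# across axes and keeping the best totals. The _demo_beam_search_additive
# helper and the numbers below are made up; B = H = 1 keeps the arithmetic easy.
def _demo_beam_search_additive() -> None:
    vals = torch.tensor([[[[3.0, 1.0],      # axis 0: id 7 -> 3.0, id 2 -> 1.0
                           [2.0, 0.5]]]])   # axis 1: id 4 -> 2.0, id 9 -> 0.5
    idxs = torch.tensor([[[[7, 2],
                           [4, 9]]]], dtype=torch.int32)
    scores, coords = beam_search_additive(vals, idxs, top_k=2, beam_width=4)
    assert coords[0, 0, 0].tolist() == [7, 4]  # best beam: 3.0 + 2.0 = 5.0
    assert coords[0, 0, 1].tolist() == [7, 9]  # runner-up: 3.0 + 0.5 = 3.5
    # Returned scores are a softmax over the kept beam totals (5.0 and 3.5).
    assert torch.allclose(scores[0, 0], F.softmax(torch.tensor([5.0, 3.5]), dim=-1))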

# -----------------------------
# High-level Router nn.Module
# -----------------------------
class BeamRouter(nn.Module):
    """
    Projects x:[B,D] -> per-axis queries q:[B,H,A,r], streams over S to keep top-b
    candidates per axis, then performs additive beam search across axes.

    Forward:
        scores:[B,H,K], coords:[B,H,K,A]

    Args:
        d_model: input dimension D
        num_heads: H heads
        num_axes: A axes per head
        axis_qdim: r (must be even for sinusoidal id codes)
        axis_size: S ids per axis
        top_k: K beams per (B,H)
        beam_axis_b: per-axis candidates b to keep while streaming S
        beam_width: internal beam width for combining axes
        s_tile: tile size while streaming S (memory vs. speed tradeoff)

    Notes:
        * Applies RMSNorm over r to stabilize query directions.
        * Includes a learnable per-(H,A) temperature (logit_scale).
    """

    def __init__(
        self,
        d_model: int,
        num_heads: int = 4,
        num_axes: int = 6,
        axis_qdim: int = 32,
        axis_size: int = 512,
        top_k: int = 8,
        beam_axis_b: int = 32,
        beam_width: int = 64,
        s_tile: int = 512,
    ):
        super().__init__()
        self.D = int(d_model)
        self.H = int(num_heads)
        self.A = int(num_axes)
        self.r = int(axis_qdim)
        self.S = int(axis_size)
        self.K = int(top_k)
        self.b = int(beam_axis_b)
        self.bw = int(beam_width)
        self.s_tile = int(s_tile)

        # Query projection: [H,A,r,D]
        self.query_w = nn.Parameter(torch.empty(self.H, self.A, self.r, self.D, dtype=torch.float32))
        with torch.no_grad():
            self.query_w.normal_(0.0, (1.0 / self.D) ** 0.5)

        # Per-axis RMSNorm over r (direction-preserving)
        self.q_norm = nn.RMSNorm(self.r, eps=1e-6, elementwise_affine=False)
        # Optional per-(H,A) temperature to re-calibrate cross-axis scales
        self.logit_scale = nn.Parameter(torch.zeros(self.H, self.A))  # exp() ~ 1.0

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            x: [B,D] float32
        Returns:
            scores: [B,H,K] float32
            coords: [B,H,K,A] int16
        """
        # Form queries and normalize per (B,H,A) over r
        q_axes = torch.einsum("bd,hard->bhar", x, self.query_w)  # [B,H,A,r]
        q_axes = self.q_norm(q_axes)
        # Route from precomputed queries
        vals, idxs = router_score_topb_from_queries(q_axes, axis_size=self.S, top_b=self.b, s_tile=self.s_tile)
        # Optional temperature
        vals = vals * self.logit_scale.exp()[None, :, :, None]
        scores, coords = beam_search_additive(vals, idxs, top_k=self.K, beam_width=self.bw)
        return scores, coords
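
# Usage sketch: a minimal smoke test with small, arbitrary hyperparameters.
# Running this file directly prints the output shapes and checks that scores
# form a softmax over the K beams for every (B, H).
if __name__ == "__main__":
    torch.manual_seed(0)
    router = BeamRouter(
        d_model=64, num_heads=2, num_axes=3, axis_qdim=16,
        axis_size=256, top_k=4, beam_axis_b=16, beam_width=32, s_tile=64,
    )
    x = torch.randn(8, 64)
    scores, coords = router(x)
    print("scores:", tuple(scores.shape))  # (8, 2, 4)
    print("coords:", tuple(coords.shape))  # (8, 2, 4, 3)
    print("rows sum to 1:", torch.allclose(scores.sum(-1), torch.ones(8, 2), atol=1e-5))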