# ethos_router.py
# A compact, self-contained PyTorch router with sinusoidal ID codes,
# top-b axis selection, and additive beam search.
#
# Shapes (conventions used below):
# B = batch
# D = model/embedding dim
# H = heads
# A = routing axes per head
# r = axis query dim (must be even for [cos,sin] codes)
# S = size of the discrete ID space per axis
# b = per-axis candidates kept during streaming top-b
# K = final beams per (B,H)
#
# Public surface:
# - sinusoidal_codes_from_ids(ids, r)
# - router_score_topb(...) / router_score_topb_from_queries(...)
# - beam_search_additive(vals_all, idxs_all, top_k, beam_width)
# - BeamRouter(d_model, ...)
#
# Notes:
# * Streaming over S with tiles keeps memory bounded.
# * Autograd is implemented to propagate gradients to x/W or to q_axes.
# * All math is FP32; works on CPU/CUDA. IDs are stored as int32 to save memory.
from __future__ import annotations
import math
from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
__all__ = [
    "sinusoidal_codes_from_ids",
    "router_score_topb",
    "router_score_topb_from_queries",
    "beam_search_additive",
    "BeamRouter",
]
# -----------------------------
# Sinusoidal ID codebook (ETHOS)
# -----------------------------
def sinusoidal_codes_from_ids(ids: torch.Tensor, r: int) -> torch.Tensor:
    """
    Build [cos, sin]-concatenated sinusoidal codes for integer IDs.
    Args:
        ids: integer tensor [...], each in [0, S)
        r: code dimension (must be even)
    Returns:
        codes: float32 tensor [..., r]
    """
    if r % 2 != 0:
        raise ValueError(f"axis_qdim r must be even, got r={r}")
    device = ids.device
    r_half = r // 2
    # Log-spaced frequencies from 10^0 down to 10^-1. Any smooth spectrum works.
    freq = torch.logspace(0.0, -1.0, steps=r_half, base=10.0, device=device, dtype=torch.float32)
    angle = ids.to(torch.float32).unsqueeze(-1) * freq.unsqueeze(0) * (2.0 * math.pi)
    return torch.cat([torch.cos(angle), torch.sin(angle)], dim=-1).to(torch.float32)
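
# Illustrative sketch, not part of the original gist: a quick sanity check of the codebook above.
# The helper name `_demo_sinusoidal_codes` is hypothetical and nothing calls it on import.
def _demo_sinusoidal_codes() -> None:
    ids = torch.arange(8, dtype=torch.int32)      # IDs 0..7 from an axis with S >= 8
    codes = sinusoidal_codes_from_ids(ids, r=16)  # [8, 16]: first 8 dims cos, last 8 sin
    assert codes.shape == (8, 16) and codes.dtype == torch.float32
    # Each (cos, sin) pair per frequency has unit norm, so cos^2 + sin^2 == 1 elementwise.
    cos_part, sin_part = codes[:, :8], codes[:, 8:]
    assert torch.allclose(cos_part ** 2 + sin_part ** 2, torch.ones(8, 8), atol=1e-5)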
# ============================================================
# 1) RouterScoreTopB: stream over S, keep per-axis top-b
# (computes queries from x and W, provides grads to both)
# ============================================================
class _RouterScoreTopB(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        x_tile: torch.Tensor,   # [Bt, D] fp32
        query_w: torch.Tensor,  # [H, A, r, D] fp32
        axis_size: int,         # S
        top_b: int,             # b
        s_tile: int,            # tile size over S
    ):
        """
        Returns:
            vals_all: [Bt, H, A, b] (fp32)
            idxs_all: [Bt, H, A, b] (int32)
        """
        if s_tile <= 0:
            raise ValueError("s_tile must be positive")
        Bt, D = x_tile.shape
        H, A, r, Dw = query_w.shape
        if D != Dw:
            raise ValueError(f"Input dim D={D} must match query_w[..., D]={Dw}")
        S = int(axis_size)
        b = min(int(top_b), S)
        device = x_tile.device
        # q_axes: [Bt,H,A,r] (no normalization here)
        q_axes = torch.einsum("bd,hard->bhar", x_tile, query_w)
        # Collapse heads/axes for efficient streaming
        Q = q_axes.permute(1, 2, 0, 3).contiguous().reshape(H * A, Bt, r)
        top_vals = torch.full((H * A, Bt, b), float("-inf"), device=device, dtype=torch.float32)
        top_idx = torch.empty((H * A, Bt, b), device=device, dtype=torch.int32)
        for s0 in range(0, S, s_tile):
            s1 = min(S, s0 + s_tile)
            s_ids = torch.arange(s0, s1, device=device, dtype=torch.int32)  # [s_t]
            K_rt = sinusoidal_codes_from_ids(s_ids, r).transpose(0, 1).contiguous()  # [r, s_t]
            logits = torch.matmul(Q, K_rt)  # [H*A, Bt, s_t]
            merged_vals = torch.cat([top_vals, logits], dim=-1)  # [H*A,Bt,b+s_t]
            merged_idx = torch.cat(
                [
                    top_idx,
                    s_ids[None, None, :].expand(H * A, Bt, s1 - s0),
                ],
                dim=-1,
            )
            keep_vals, keep_pos = torch.topk(merged_vals, k=b, dim=-1)  # [H*A,Bt,b]
            keep_idx = merged_idx.gather(-1, keep_pos)
            top_vals, top_idx = keep_vals, keep_idx
        vals_all = top_vals.view(H, A, Bt, b).permute(2, 0, 1, 3).contiguous()  # [Bt,H,A,b]
        idxs_all = top_idx.view(H, A, Bt, b).permute(2, 0, 1, 3).contiguous()  # [Bt,H,A,b]
        # Save for backward to propagate through q = einsum(x, W)
        ctx.save_for_backward(x_tile, query_w, idxs_all)
        ctx.meta = (H, A, r, D)
        return vals_all, idxs_all

    @staticmethod
    def backward(ctx, g_vals_all: torch.Tensor, _g_idxs_all: Optional[torch.Tensor] = None):
        x_tile, query_w, idxs_all = ctx.saved_tensors
        H, A, r, D = ctx.meta
        if g_vals_all is None:
            return torch.zeros_like(x_tile), torch.zeros_like(query_w), None, None, None
        # Gather codes and accumulate gradient wrt q_axes
        codes = sinusoidal_codes_from_ids(idxs_all, r)  # [Bt,H,A,b,r]
        g_q = (g_vals_all.unsqueeze(-1) * codes).sum(dim=3)  # [Bt,H,A,r]
        # q = einsum('bd,hard->bhar')
        g_x = torch.einsum("bhar,hard->bd", g_q, query_w)  # [Bt,D]
        g_w = torch.einsum("bhar,bd->hard", g_q, x_tile)  # [H,A,r,D]
        return g_x, g_w, None, None, None
def router_score_topb(
    x_tile: torch.Tensor,
    query_w: torch.Tensor,
    axis_size: int,
    top_b: int,
    s_tile: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """See _RouterScoreTopB.forward for return shapes."""
    return _RouterScoreTopB.apply(x_tile, query_w, axis_size, top_b, s_tile)
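
# Illustrative sketch, not part of the original gist: streaming top-b scoring straight from x and W,
# with gradients reaching both through the custom backward. `_demo_router_score_topb` is a hypothetical name.
def _demo_router_score_topb() -> None:
    Bt, D, H, A, r = 2, 64, 4, 6, 32
    x = torch.randn(Bt, D, requires_grad=True)
    W = (torch.randn(H, A, r, D) / D ** 0.5).requires_grad_()
    vals, idxs = router_score_topb(x, W, axis_size=512, top_b=32, s_tile=128)
    assert vals.shape == (Bt, H, A, 32) and idxs.shape == (Bt, H, A, 32)
    vals.sum().backward()  # gradients flow to x and W through the kept top-b scores
    assert x.grad is not None and W.grad is not None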
# ============================================================
# 1b) From precomputed (normalized) queries; grads wrt q only
# ============================================================
class _RouterScoreTopBFromQueries(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        q_axes: torch.Tensor,  # [Bt,H,A,r] fp32 (already normalized if desired)
        axis_size: int,        # S
        top_b: int,            # b
        s_tile: int,           # tile size over S
    ):
        """
        Returns:
            vals_all: [Bt, H, A, b] (fp32)
            idxs_all: [Bt, H, A, b] (int32)
        """
        if s_tile <= 0:
            raise ValueError("s_tile must be positive")
        Bt, H, A, r = q_axes.shape
        S = int(axis_size)
        b = min(int(top_b), S)
        device = q_axes.device
        Q = q_axes.permute(1, 2, 0, 3).contiguous().reshape(H * A, Bt, r)
        top_vals = torch.full((H * A, Bt, b), float("-inf"), device=device, dtype=torch.float32)
        top_idx = torch.empty((H * A, Bt, b), device=device, dtype=torch.int32)
        for s0 in range(0, S, s_tile):
            s1 = min(S, s0 + s_tile)
            s_ids = torch.arange(s0, s1, device=device, dtype=torch.int32)
            K_rt = sinusoidal_codes_from_ids(s_ids, r).transpose(0, 1).contiguous()  # [r, s_t]
            logits = torch.matmul(Q, K_rt)  # [H*A,Bt,s_t]
            merged_vals = torch.cat([top_vals, logits], dim=-1)
            merged_idx = torch.cat(
                [
                    top_idx,
                    s_ids[None, None, :].expand(H * A, Bt, s1 - s0),
                ],
                dim=-1,
            )
            keep_vals, keep_pos = torch.topk(merged_vals, k=b, dim=-1)
            keep_idx = merged_idx.gather(-1, keep_pos)
            top_vals, top_idx = keep_vals, keep_idx
        vals_all = top_vals.view(H, A, Bt, b).permute(2, 0, 1, 3).contiguous()  # [Bt,H,A,b]
        idxs_all = top_idx.view(H, A, Bt, b).permute(2, 0, 1, 3).contiguous()  # [Bt,H,A,b]
        # Only need indices to backprop to q_axes
        ctx.save_for_backward(idxs_all)
        ctx.meta = (r,)
        return vals_all, idxs_all
    @staticmethod
    def backward(ctx, g_vals_all: torch.Tensor, _g_idxs_all: Optional[torch.Tensor] = None):
        (idxs_all,) = ctx.saved_tensors
        (r,) = ctx.meta
        if g_vals_all is None:
            # Zero gradient matching q_axes' shape [Bt,H,A,r], with correct dtype/device.
            Bt, H, A, _b = idxs_all.shape
            return torch.zeros((Bt, H, A, r), dtype=torch.float32, device=idxs_all.device), None, None, None
        codes = sinusoidal_codes_from_ids(idxs_all, r)  # [Bt,H,A,b,r]
        g_q = (g_vals_all.unsqueeze(-1) * codes).sum(dim=3)  # [Bt,H,A,r]
        return g_q, None, None, None
def router_score_topb_from_queries(
    q_axes: torch.Tensor,
    axis_size: int,
    top_b: int,
    s_tile: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """See _RouterScoreTopBFromQueries.forward for return shapes."""
    return _RouterScoreTopBFromQueries.apply(q_axes, axis_size, top_b, s_tile)
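
# Illustrative sketch, not part of the original gist: the query-side variant differentiates only
# through q_axes, so normalization and temperature can live outside the autograd.Function.
# `_demo_router_score_topb_from_queries` is a hypothetical name.
def _demo_router_score_topb_from_queries() -> None:
    Bt, H, A, r = 2, 4, 6, 32
    q = torch.randn(Bt, H, A, r, requires_grad=True)
    vals, idxs = router_score_topb_from_queries(q, axis_size=512, top_b=16, s_tile=256)
    assert vals.shape == (Bt, H, A, 16) and idxs.dtype == torch.int32
    vals.mean().backward()  # gradient is scattered back onto q only
    assert q.grad is not None and q.grad.shape == q.shape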
# ============================================================
# 2) Additive beam search over axes; grads -> per-axis b scores
# ============================================================
class _BeamSearchAdditive(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        vals_all: torch.Tensor,  # [B,H,A,b] fp32 (per-axis top-b scores)
        idxs_all: torch.Tensor,  # [B,H,A,b] int32 (matching IDs per axis)
        top_k: int,
        beam_width: int,
    ):
        device = vals_all.device
        B, H, A, b = vals_all.shape
        Bwidth = int(beam_width)
        # Running beam state
        beam_scores = torch.full((B, H, Bwidth), float("-inf"), device=device, dtype=torch.float32)
        beam_idx = torch.empty(B, H, Bwidth, A, device=device, dtype=torch.int32)
        # Initialize from first axis
        v0, i0 = vals_all[:, :, 0, :], idxs_all[:, :, 0, :]
        keep0 = min(Bwidth, b)
        base_vals, base_pos = torch.topk(v0, k=keep0, dim=-1)  # [B,H,keep0]
        beam_scores[:, :, :keep0] = base_vals
        beam_idx[:, :, :keep0, 0] = i0.gather(-1, base_pos)
        cur = keep0
        # Combine axes additively
        for a in range(1, A):
            va = vals_all[:, :, a, :]
            ia = idxs_all[:, :, a, :]
            cur_scores = beam_scores[:, :, :cur].unsqueeze(-1) + va.unsqueeze(-2)  # [B,H,cur,b]
            flat = cur_scores.reshape(B, H, -1)
            keep = min(Bwidth, flat.size(-1))
            kv, kf = torch.topk(flat, k=keep, dim=-1)  # kv: new beam scores, kf: flat indices
            parent, choose = kf // b, kf % b  # map back to (parent-beam, choice in b)
            # Propagate parent history (a >= 1 here, so axes [0, a) are already filled)
            parent_exp = parent.unsqueeze(-1).expand(-1, -1, -1, a)  # [B,H,keep,a]
            beam_idx[:, :, :keep, :a] = torch.gather(beam_idx[:, :, :cur, :a], 2, parent_exp)
            beam_idx[:, :, :keep, a] = ia.gather(-1, choose)
            beam_scores[:, :, :keep] = kv
            cur = keep
        K = min(int(top_k), cur)
        final_vals, final_pos = torch.topk(beam_scores[:, :, :cur], k=K, dim=-1)  # [B,H,K]
        coords = torch.gather(
            beam_idx[:, :, :cur, :],
            2,
            final_pos.unsqueeze(-1).expand(-1, -1, -1, A),
        ).to(torch.int16)  # [B,H,K,A] (small signed type; assumes IDs fit in int16, i.e. axis_size <= 32768)
        # Map back which b-slot each coord came from so we can scatter gradients
        idxs_exp = idxs_all.unsqueeze(2)  # [B,H,1,A,b]
        coords_exp = coords.to(torch.int32).unsqueeze(-1)  # [B,H,K,A,1]
        eq = (idxs_exp == coords_exp)  # [B,H,K,A,b]
        choose_idx = torch.argmax(eq.to(torch.int32), dim=-1)  # [B,H,K,A] (which 'b' index)
        scores = F.softmax(final_vals, dim=-1)  # [B,H,K]
        ctx.save_for_backward(scores, choose_idx)
        ctx.meta = (B, H, A, b, K)
        return scores, coords
    @staticmethod
    def backward(ctx, g_scores: torch.Tensor, _g_coords: Optional[torch.Tensor] = None):
        scores, choose_idx = ctx.saved_tensors
        B, H, A, b, K = ctx.meta
        if g_scores is None:
            return torch.zeros((B, H, A, b), dtype=torch.float32, device=scores.device), None, None, None
        # Softmax backward on top-K beam logits
        g = g_scores
        s = scores
        dot = (g * s).sum(dim=-1, keepdim=True)
        g_z = (g - dot) * s  # [B,H,K]
        # Scatter back into per-axis top-b slots using choose_idx
        g_vals_all = torch.zeros((B, H, A, b), dtype=torch.float32, device=scores.device)
        for a in range(A):
            tgt = g_vals_all[:, :, a, :]
            idx = choose_idx[:, :, :, a]  # Long by default
            tgt.scatter_add_(dim=-1, index=idx, src=g_z)
        return g_vals_all, None, None, None
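# For reference: with s = softmax(z) over the K kept beams and upstream gradient g, the manual
# backward above computes dL/dz_i = s_i * (g_i - sum_j g_j * s_j), then scatter-adds each dL/dz_i
# into the per-axis top-b slot that produced beam i.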
def beam_search_additive(
    vals_all: torch.Tensor,
    idxs_all: torch.Tensor,
    top_k: int,
    beam_width: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Args:
        vals_all: [B,H,A,b] per-axis top-b scores (float32)
        idxs_all: [B,H,A,b] per-axis ids matching vals_all (int32)
        top_k: beams to return
        beam_width: internal beam budget
    Returns:
        scores: [B,H,K] softmax over the K beams per (B,H)
        coords: [B,H,K,A] int16, chosen ID per axis for each beam
    """
    return _BeamSearchAdditive.apply(vals_all, idxs_all, top_k, beam_width)
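
# Illustrative sketch, not part of the original gist: additive beam search over hand-made per-axis
# candidates (A=2 axes, b=3 candidates each). `_demo_beam_search_additive` is a hypothetical name.
def _demo_beam_search_additive() -> None:
    B, H, A = 1, 1, 2
    vals = torch.tensor([[[[3.0, 2.0, 1.0],     # axis 0: per-candidate scores
                           [0.5, 0.4, 0.3]]]])  # axis 1: per-candidate scores -> shape [1,1,2,3]
    idxs = torch.tensor([[[[10, 20, 30],
                           [7, 8, 9]]]], dtype=torch.int32)
    scores, coords = beam_search_additive(vals, idxs, top_k=2, beam_width=4)
    assert scores.shape == (B, H, 2) and coords.shape == (B, H, 2, A)
    # The best beam pairs the best candidate from each axis: IDs (10, 7) with score 3.0 + 0.5.
    assert coords[0, 0, 0].tolist() == [10, 7]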
# -----------------------------
# High-level Router nn.Module
# -----------------------------
class BeamRouter(nn.Module):
    """
    Projects x:[B,D] -> per-axis queries q:[B,H,A,r], streams over S to keep top-b
    candidates per axis, then performs additive beam search across axes.
    Forward:
        scores:[B,H,K], coords:[B,H,K,A]
    Args:
        d_model: input dimension D
        num_heads: H heads
        num_axes: A axes per head
        axis_qdim: r (must be even for sinusoidal id codes)
        axis_size: S ids per axis
        top_k: K beams per (B,H)
        beam_axis_b: per-axis candidates b to keep while streaming S
        beam_width: internal beam width for combining axes
        s_tile: tile size while streaming S (memory vs. speed tradeoff)
    Notes:
        * Applies RMSNorm over r to stabilize query directions.
        * Includes a learnable per-(H,A) temperature (logit_scale).
    """
    def __init__(
        self,
        d_model: int,
        num_heads: int = 4,
        num_axes: int = 6,
        axis_qdim: int = 32,
        axis_size: int = 512,
        top_k: int = 8,
        beam_axis_b: int = 32,
        beam_width: int = 64,
        s_tile: int = 512,
    ):
        super().__init__()
        self.D = int(d_model)
        self.H = int(num_heads)
        self.A = int(num_axes)
        self.r = int(axis_qdim)
        self.S = int(axis_size)
        self.K = int(top_k)
        self.b = int(beam_axis_b)
        self.bw = int(beam_width)
        self.s_tile = int(s_tile)
        # Query projection: [H,A,r,D]
        self.query_w = nn.Parameter(torch.empty(self.H, self.A, self.r, self.D, dtype=torch.float32))
        with torch.no_grad():
            self.query_w.normal_(0.0, (1.0 / self.D) ** 0.5)
        # Per-axis RMSNorm over r (direction-preserving)
        self.q_norm = nn.RMSNorm(self.r, eps=1e-6, elementwise_affine=False)
        # Optional per-(H,A) temperature to re-calibrate cross-axis scales
        self.logit_scale = nn.Parameter(torch.zeros(self.H, self.A))  # exp() ~ 1.0

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            x: [B,D] float32
        Returns:
            scores: [B,H,K] float32
            coords: [B,H,K,A] int16
        """
        # Form queries and normalize per (B,H,A) over r
        q_axes = torch.einsum("bd,hard->bhar", x, self.query_w)  # [B,H,A,r]
        q_axes = self.q_norm(q_axes)
        # Route from precomputed queries
        vals, idxs = router_score_topb_from_queries(q_axes, axis_size=self.S, top_b=self.b, s_tile=self.s_tile)
        # Optional temperature
        vals = vals * self.logit_scale.exp()[None, :, :, None]
        scores, coords = beam_search_additive(vals, idxs, top_k=self.K, beam_width=self.bw)
        return scores, coords
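
# Minimal end-to-end smoke test; sizes are arbitrary and purely illustrative. It only runs when the
# file is executed directly (`python ethos_router.py`), never on import.
if __name__ == "__main__":
    torch.manual_seed(0)
    router = BeamRouter(d_model=64, num_heads=2, num_axes=3, axis_qdim=16,
                        axis_size=256, top_k=4, beam_axis_b=16, beam_width=32, s_tile=64)
    x = torch.randn(8, 64, requires_grad=True)
    scores, coords = router(x)
    print("scores:", tuple(scores.shape), scores.dtype)  # (8, 2, 4) torch.float32
    print("coords:", tuple(coords.shape), coords.dtype)  # (8, 2, 4, 3) torch.int16
    # Scores are a softmax over the K beams per (B,H), so they sum to 1 along the last dim.
    assert torch.allclose(scores.sum(-1), torch.ones(8, 2), atol=1e-5)
    # Weight the scores randomly so the softmax gradient is non-trivial before checking backprop.
    (scores * torch.randn_like(scores)).sum().backward()
    assert x.grad is not None and router.query_w.grad is not None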