Created October 29, 2025 13:38
| # ethos_gen_mlp.py | |
| # ETHOS (TF32/FP32) for flattened CIFAR-10: | |
| # Router (pure-ID, custom autograd) → W1/W2 → FiLM → vector gate → dictionaries → mean heads → logits | |
| from __future__ import annotations | |
| import math | |
| from typing import Optional, Tuple | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| # Use TF32 paths on Ampere+ (still FP32 tensors) | |
| try: | |
| torch.backends.cuda.matmul.allow_tf32 = True | |
| except Exception: | |
| pass | |
| try: | |
| torch.backends.cudnn.allow_tf32 = True | |
| except Exception: | |
| pass | |
| try: | |
| torch.set_float32_matmul_precision("high") # enables TF32 where applicable | |
| except Exception: | |
| pass | |
| def gelu_tanh(x: torch.Tensor) -> torch.Tensor: | |
| return F.gelu(x, approximate="tanh") | |
| # ========================= | |
| # Sinusoidal codes (ETHOS) | |
| # ========================= | |
| def sinusoidal_codes_from_ids(ids: torch.Tensor, r: int) -> torch.Tensor: | |
| """ | |
| ids: [...], integer in [0, S) | |
| returns: [..., r] FP32 with [cos, sin] concatenation | |
| """ | |
| assert r % 2 == 0, "axis_qdim r must be even" | |
| device = ids.device | |
| r_half = r // 2 | |
| # Log-spaced frequencies in (10^0 .. 10^-1); any smooth spectrum is fine | |
| freq = torch.logspace(0.0, -1.0, steps=r_half, base=10.0, device=device, dtype=torch.float32) | |
| angle = ids.to(dtype=torch.float32).unsqueeze(-1) * freq.unsqueeze(0) * (2.0 * math.pi) | |
| cos = torch.cos(angle) | |
| sin = torch.sin(angle) | |
| return torch.cat([cos, sin], dim=-1).to(torch.float32) | |
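| # --- Illustrative sketch (not part of the original gist): shape/determinism of the codes --- | |
| # sinusoidal_codes_from_ids maps integer slot IDs to fixed [cos, sin] features of width r, | |
| # so an ID tensor of shape [...] becomes [..., r]; no learned parameters are involved. | |
| def _demo_sinusoidal_codes(): | |
|     ids = torch.arange(8, dtype=torch.int32)             # 8 slot IDs | |
|     codes = sinusoidal_codes_from_ids(ids, r=16)         # -> [8, 16] | |
|     assert codes.shape == (8, 16) | |
|     # Deterministic: the same IDs always yield the same codes. | |
|     assert torch.allclose(codes, sinusoidal_codes_from_ids(ids, r=16)) | |
|     return codes | |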
| # ============================================================ | |
| # 1) RouterScoreTopB (legacy): stream S, keep per-axis b | |
| # Computes q_axes internally (no normalization here) | |
| # ============================================================ | |
| class RouterScoreTopBFunction(torch.autograd.Function): | |
| @staticmethod | |
| def forward(ctx, | |
| x_tile: torch.Tensor, # [Bt, D] fp32 | |
| query_w: torch.Tensor, # [H, A, r, D] fp32 | |
| axis_size: int, # S | |
| top_b: int, # b | |
| s_tile: int): # chunk over S | |
| """ | |
| Returns: | |
| vals_all: [Bt, H, A, b] (fp32) | |
| idxs_all: [Bt, H, A, b] (int32) | |
| """ | |
| Bt, D = x_tile.shape | |
| H, A, r, Dw = query_w.shape | |
| assert D == Dw | |
| S = int(axis_size) | |
| b = min(int(top_b), S) | |
| device = x_tile.device | |
| # q_axes: [Bt,H,A,r] fp32 (NO normalization here) | |
| q_axes = torch.einsum('bd,hard->bhar', x_tile, query_w) | |
| # group: [H*A, Bt, r] | |
| Q = q_axes.permute(1, 2, 0, 3).contiguous().view(H * A, Bt, r) | |
| top_vals = torch.full((H * A, Bt, b), float('-inf'), device=device, dtype=torch.float32) | |
| top_idx = torch.empty((H * A, Bt, b), device=device, dtype=torch.int32) | |
| for s0 in range(0, S, s_tile): | |
| s1 = min(S, s0 + s_tile) | |
| s_ids = torch.arange(s0, s1, device=device, dtype=torch.int32) # [s_t] | |
| K_rt = sinusoidal_codes_from_ids(s_ids, r).transpose(0, 1).contiguous() # [r, s_t] fp32 | |
| logits_tile = torch.matmul(Q, K_rt) # [H*A, Bt, s_t] fp32 | |
| merged_vals = torch.cat([top_vals, logits_tile], dim=-1) # [H*A, Bt, b + s_t] | |
| merged_idx = torch.cat([ | |
| top_idx, | |
| torch.arange(s0, s1, device=device, dtype=torch.int32)[None, None, :].expand(H * A, Bt, s1 - s0) | |
| ], dim=-1) # [H*A, Bt, b + s_t] | |
| keep_vals, keep_pos = torch.topk(merged_vals, k=b, dim=-1) # [H*A, Bt, b] | |
| keep_idx = merged_idx.gather(-1, keep_pos) | |
| top_vals, top_idx = keep_vals, keep_idx | |
| vals_all = top_vals.view(H, A, Bt, b).permute(2, 0, 1, 3).contiguous() # [Bt,H,A,b] | |
| idxs_all = top_idx.view(H, A, Bt, b).permute(2, 0, 1, 3).contiguous() # [Bt,H,A,b] | |
| # Save tensors that are needed to backprop through q = einsum(x, W) | |
| ctx.save_for_backward(x_tile, query_w, idxs_all) | |
| ctx.meta = (H, A, r, D, b, S, s_tile) | |
| return vals_all, idxs_all | |
| @staticmethod | |
| def backward(ctx, g_vals_all: torch.Tensor, g_idxs_all: Optional[torch.Tensor] = None): | |
| x_tile, query_w, idxs_all = ctx.saved_tensors | |
| H, A, r, D, b, S, s_tile = ctx.meta | |
| if g_vals_all is None: | |
| return (torch.zeros_like(x_tile), torch.zeros_like(query_w), None, None, None) | |
| # Gather code vectors for selected indices: [Bt,H,A,b,r] | |
| codes = sinusoidal_codes_from_ids(idxs_all, r) # fp32 | |
| # grad wrt q_axes: sum_b g_vals[..., b] * code[..., b, :] | |
| g_q = (g_vals_all.unsqueeze(-1) * codes).sum(dim=3) # [Bt,H,A,r] fp32 | |
| # q_axes = einsum('bd,hard->bhar') | |
| g_x = torch.einsum('nhar,hard->nd', g_q, query_w) # [Bt,D] | |
| g_w = torch.einsum('nhar,nd->hard', g_q, x_tile) # [H,A,r,D] | |
| return g_x, g_w, None, None, None | |
| def router_score_topb(x_tile: torch.Tensor, | |
| query_w: torch.Tensor, | |
| axis_size: int, | |
| top_b: int, | |
| s_tile: int) -> Tuple[torch.Tensor, torch.Tensor]: | |
| return RouterScoreTopBFunction.apply(x_tile, query_w, axis_size, top_b, s_tile) | |
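| # --- Illustrative sketch (not part of the original gist): calling the legacy top-b scorer --- | |
| # For a tile of Bt inputs and query weights [H, A, r, D], the op streams the S slot IDs in | |
| # chunks of s_tile and keeps the top-b logits per (head, axis); the custom backward reaches | |
| # both x_tile and query_w. | |
| def _demo_router_score_topb(): | |
|     Bt, D, H, A, r = 4, 32, 2, 3, 8 | |
|     x = torch.randn(Bt, D) | |
|     w = torch.randn(H, A, r, D, requires_grad=True) | |
|     vals, idxs = router_score_topb(x, w, axis_size=64, top_b=5, s_tile=16) | |
|     assert vals.shape == (Bt, H, A, 5) and idxs.shape == (Bt, H, A, 5) | |
|     vals.sum().backward() | |
|     assert w.grad is not None and w.grad.shape == w.shape | |
|     return vals, idxs | |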
| # ============================================================ | |
| # 1b) RouterScoreTopBFromQueries: same as above but takes | |
| # precomputed (and normalized) queries, returns grad wrt q | |
| # ============================================================ | |
| class RouterScoreTopBFromQueriesFunction(torch.autograd.Function): | |
| @staticmethod | |
| def forward(ctx, | |
| q_axes: torch.Tensor, # [Bt,H,A,r] fp32 (already normalized) | |
| axis_size: int, # S | |
| top_b: int, # b | |
| s_tile: int): # chunk over S | |
| """ | |
| Returns: | |
| vals_all: [Bt, H, A, b] (fp32) | |
| idxs_all: [Bt, H, A, b] (int32) | |
| """ | |
| Bt, H, A, r = q_axes.shape | |
| S = int(axis_size) | |
| b = min(int(top_b), S) | |
| device = q_axes.device | |
| # group: [H*A, Bt, r] | |
| Q = q_axes.permute(1, 2, 0, 3).contiguous().view(H * A, Bt, r) | |
| top_vals = torch.full((H * A, Bt, b), float('-inf'), device=device, dtype=torch.float32) | |
| top_idx = torch.empty((H * A, Bt, b), device=device, dtype=torch.int32) | |
| for s0 in range(0, S, s_tile): | |
| s1 = min(S, s0 + s_tile) | |
| s_ids = torch.arange(s0, s1, device=device, dtype=torch.int32) # [s_t] | |
| K_rt = sinusoidal_codes_from_ids(s_ids, r).transpose(0, 1).contiguous() # [r, s_t] fp32 | |
| logits_tile = torch.matmul(Q, K_rt) # [H*A, Bt, s_t] fp32 | |
| merged_vals = torch.cat([top_vals, logits_tile], dim=-1) # [H*A, Bt, b + s_t] | |
| merged_idx = torch.cat([ | |
| top_idx, | |
| torch.arange(s0, s1, device=device, dtype=torch.int32)[None, None, :].expand(H * A, Bt, s1 - s0) | |
| ], dim=-1) # [H*A, Bt, b + s_t] | |
| keep_vals, keep_pos = torch.topk(merged_vals, k=b, dim=-1) # [H*A, Bt, b] | |
| keep_idx = merged_idx.gather(-1, keep_pos) | |
| top_vals, top_idx = keep_vals, keep_idx | |
| vals_all = top_vals.view(H, A, Bt, b).permute(2, 0, 1, 3).contiguous() # [Bt,H,A,b] | |
| idxs_all = top_idx.view(H, A, Bt, b).permute(2, 0, 1, 3).contiguous() # [Bt,H,A,b] | |
| # Save only what's needed to backprop to q_axes | |
| ctx.save_for_backward(idxs_all) | |
| ctx.meta = (r,) | |
| return vals_all, idxs_all | |
| @staticmethod | |
| def backward(ctx, g_vals_all: torch.Tensor, g_idxs_all: Optional[torch.Tensor] = None): | |
| (idxs_all,) = ctx.saved_tensors | |
| (r,) = ctx.meta | |
| if g_vals_all is None: | |
| return torch.zeros((*idxs_all.shape[:-1], r), dtype=torch.float32, device=idxs_all.device), None, None, None # zero grad shaped like q_axes [Bt,H,A,r] | |
| # codes: [Bt,H,A,b,r] | |
| codes = sinusoidal_codes_from_ids(idxs_all, r) | |
| # gradient wrt q_axes: [Bt,H,A,r] | |
| g_q = (g_vals_all.unsqueeze(-1) * codes).sum(dim=3) | |
| return g_q, None, None, None | |
| def router_score_topb_from_queries(q_axes: torch.Tensor, | |
| axis_size: int, | |
| top_b: int, | |
| s_tile: int) -> Tuple[torch.Tensor, torch.Tensor]: | |
| return RouterScoreTopBFromQueriesFunction.apply(q_axes, axis_size, top_b, s_tile) | |
| # ============================================================ | |
| # 2) BeamSearchAdditive: additive combine; grads to per-axis b | |
| # ============================================================ | |
| class BeamSearchAdditiveFunction(torch.autograd.Function): | |
| @staticmethod | |
| def forward(ctx, | |
| vals_all: torch.Tensor, # [Bt,H,A,b] fp32 | |
| idxs_all: torch.Tensor, # [Bt,H,A,b] int32 | |
| top_k: int, | |
| beam_width: int): | |
| device = vals_all.device | |
| Bt, H, A, b = vals_all.shape | |
| Bwidth = int(beam_width) | |
| beam_scores = torch.full((Bt, H, Bwidth), float('-inf'), device=device, dtype=torch.float32) | |
| beam_idx = torch.empty(Bt, H, Bwidth, A, device=device, dtype=torch.int32) | |
| v0 = vals_all[:, :, 0, :] | |
| i0 = idxs_all[:, :, 0, :] | |
| keep0 = min(Bwidth, b) | |
| base_vals, base_pos = torch.topk(v0, k=keep0, dim=-1) | |
| beam_scores[:, :, :keep0] = base_vals | |
| beam_idx[:, :, :keep0, 0] = i0.gather(-1, base_pos) | |
| cur = keep0 | |
| for a in range(1, A): | |
| va = vals_all[:, :, a, :] | |
| ia = idxs_all[:, :, a, :] | |
| cur_scores = beam_scores[:, :, :cur].unsqueeze(-1) + va.unsqueeze(-2) # [Bt,H,cur,b] | |
| flat = cur_scores.view(Bt, H, -1) | |
| keep = min(Bwidth, flat.size(-1)) | |
| kv, kf = torch.topk(flat, k=keep, dim=-1) | |
| parent, choose = kf // b, kf % b | |
| parent_exp = parent.unsqueeze(-1).expand(-1, -1, -1, a) | |
| if a > 0: | |
| beam_idx[:, :, :keep, :a] = torch.gather(beam_idx[:, :, :cur, :a], 2, parent_exp) | |
| beam_idx[:, :, :keep, a] = ia.gather(-1, choose) | |
| beam_scores[:, :, :keep] = kv | |
| cur = keep | |
| K = min(int(top_k), cur) | |
| final_vals, final_pos = torch.topk(beam_scores[:, :, :cur], k=K, dim=-1) # [Bt,H,K] | |
| coords = torch.gather(beam_idx[:, :, :cur, :], | |
| 2, | |
| final_pos.unsqueeze(-1).expand(-1, -1, -1, A)).to(torch.int16) # [Bt,H,K,A] | |
| # Map back: which b-slot each coord came from | |
| idxs_exp = idxs_all.unsqueeze(2) # [Bt,H,1,A,b] | |
| coords_exp = coords.to(torch.int32).unsqueeze(-1) # [Bt,H,K,A,1] | |
| eq = (idxs_exp == coords_exp) # [Bt,H,K,A,b] | |
| choose_idx = torch.argmax(eq.to(torch.int32), dim=-1) # [Bt,H,K,A] | |
| scores = F.softmax(final_vals, dim=-1) # [Bt,H,K] fp32 | |
| ctx.save_for_backward(scores, choose_idx) | |
| ctx.meta = (Bt, H, A, b, K) | |
| return scores, coords | |
| @staticmethod | |
| def backward(ctx, g_scores: torch.Tensor, g_coords: Optional[torch.Tensor] = None): | |
| scores, choose_idx = ctx.saved_tensors | |
| Bt, H, A, b, K = ctx.meta | |
| if g_scores is None: | |
| return torch.zeros((Bt, H, A, b), dtype=torch.float32, device=scores.device), None, None, None | |
| # Softmax backward: g_z = (g - <g,s>) * s | |
| g = g_scores | |
| s = scores | |
| dot = (g * s).sum(dim=-1, keepdim=True) | |
| g_z = (g - dot) * s # [Bt,H,K] | |
| # Scatter back to per-axis b slots | |
| g_vals_all = torch.zeros((Bt, H, A, b), dtype=torch.float32, device=scores.device) | |
| for a in range(A): | |
| tgt = g_vals_all[:, :, a, :] | |
| idx = choose_idx[:, :, :, a] | |
| tgt.scatter_add_(dim=-1, index=idx, src=g_z) | |
| return g_vals_all, None, None, None | |
| def beam_search_additive(vals_all: torch.Tensor, | |
| idxs_all: torch.Tensor, | |
| top_k: int, | |
| beam_width: int) -> Tuple[torch.Tensor, torch.Tensor]: | |
| return BeamSearchAdditiveFunction.apply(vals_all, idxs_all, top_k, beam_width) | |
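| # --- Illustrative sketch (not part of the original gist): shortlists -> joint coordinates --- | |
| # beam_search_additive fuses the per-axis top-b shortlists into K joint coordinates per head, | |
| # returning softmax weights over the K beams plus each beam's A-dimensional coordinate. | |
| def _demo_beam_search(): | |
|     Bt, H, A, r, S = 4, 2, 3, 8, 64 | |
|     q = torch.randn(Bt, H, A, r)                          # queries (normalization omitted here) | |
|     vals, idxs = router_score_topb_from_queries(q, axis_size=S, top_b=5, s_tile=16) | |
|     scores, coords = beam_search_additive(vals, idxs, top_k=4, beam_width=8) | |
|     assert scores.shape == (Bt, H, 4) and coords.shape == (Bt, H, 4, A) | |
|     assert torch.allclose(scores.sum(dim=-1), torch.ones(Bt, H))   # softmax over beams | |
|     return scores, coords | |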
| # ============================================================ | |
| # ETHOS module (FP32/TF32) — logits directly from dictionaries | |
| # ============================================================ | |
| class BeamRouter(nn.Module): | |
| def __init__(self, d_model, num_heads=4, num_axes=6, axis_qdim=32, | |
| axis_size=512, top_k=8, beam_axis_b=32, beam_width=64, s_tile=512): | |
| super().__init__() | |
| self.D, self.H, self.A, self.r = int(d_model), int(num_heads), int(num_axes), int(axis_qdim) | |
| self.S, self.K, self.bw, self.b = int(axis_size), int(top_k), int(beam_width), int(beam_axis_b) | |
| self.s_tile = int(s_tile) | |
| self.query_w = nn.Parameter(torch.empty(self.H, self.A, self.r, self.D, dtype=torch.float32)) | |
| with torch.no_grad(): | |
| self.query_w.normal_(0.0, (1.0 / self.D) ** 0.5) | |
| # Per-axis RMSNorm over r (direction-preserving) | |
| self.q_norm = nn.RMSNorm(self.r, eps=1e-6, elementwise_affine=False) | |
| # Optional: learned per-(H,A) temperature to re-calibrate cross-axis scale | |
| self.logit_scale = nn.Parameter(torch.zeros(self.H, self.A)) # exp() ~ 1.0 | |
| def forward(self, x: torch.Tensor): | |
| # Form queries and normalize per (B,H,A) over r | |
| q_axes = torch.einsum('bd,hard->bhar', x, self.query_w) # [B,H,A,r] | |
| q_axes = self.q_norm(q_axes) # [B,H,A,r] | |
| # Route using from-queries op (clean backward through RMSNorm + einsum) | |
| vals, idxs = router_score_topb_from_queries(q_axes, axis_size=self.S, top_b=self.b, s_tile=self.s_tile) | |
| # Optional temperature | |
| vals = vals * self.logit_scale.exp()[None, :, :, None] | |
| scores, coords = beam_search_additive(vals, idxs, top_k=self.K, beam_width=self.bw) | |
| return scores, coords | |
| class EthosGeneratedMLP(nn.Module): | |
| """ | |
| ETHOS path over flattened x:[B,D] (FP32/TF32): | |
| router -> W1_blocks/W2 -> (alpha,beta) -> dict(B_in/B_out) | |
| + FiLM (gamma,beta) over M -> vector gating -> mean heads -> logits:[B,C] | |
| No separate classifier; logits come from B_out directly. | |
| """ | |
| def __init__(self, d_in, n_classes, | |
| H=4, A=6, r=32, S=512, K=8, b=32, beam_width=64, | |
| Hh=128, M=32, film_eps=0.10, film_tanh_gamma=True, per_head_dict=True, s_tile=512): | |
| super().__init__() | |
| self.D, self.C = int(d_in), int(n_classes) | |
| self.film_eps = float(film_eps) | |
| self.film_tanh_gamma = bool(film_tanh_gamma) | |
| self.per_head_dict = bool(per_head_dict) | |
| # Router (pure-ID) | |
| self.router = BeamRouter(d_model=d_in, num_heads=H, num_axes=A, axis_qdim=r, | |
| axis_size=S, top_k=K, beam_axis_b=b, beam_width=beam_width, s_tile=s_tile) | |
| # Hypernet + dicts | |
| self.H, self.A, self.r, self.Hh, self.M = H, self.router.A, r, int(Hh), int(M) | |
| self.W1_blocks = nn.Parameter(torch.empty(H, self.A, r, self.Hh, dtype=torch.float32)) # [H,A,r,Hh] | |
| self.W2 = nn.Parameter(torch.empty(H, self.Hh, 2*self.M, dtype=torch.float32)) # [H,Hh,2M] | |
| with torch.no_grad(): | |
| nn.init.xavier_uniform_(self.W1_blocks.view(H*self.A, r, self.Hh)) | |
| nn.init.xavier_uniform_(self.W2) | |
| if self.per_head_dict: | |
| self.B_in = nn.Parameter(torch.empty(H, self.M, self.D, dtype=torch.float32)) # [H,M,D] | |
| self.B_out = nn.Parameter(torch.empty(H, self.M, self.C, dtype=torch.float32)) # [H,M,C] <-- logits direct | |
| with torch.no_grad(): | |
| nn.init.xavier_uniform_(self.B_in) | |
| nn.init.xavier_uniform_(self.B_out) | |
| self.B_in_shared = None | |
| self.B_out_shared = None | |
| else: | |
| self.B_in_shared = nn.Parameter(torch.empty(self.M, self.D, dtype=torch.float32)) # [M,D] | |
| self.B_out_shared = nn.Parameter(torch.empty(self.M, self.C, dtype=torch.float32)) # [M,C] <-- logits direct | |
| with torch.no_grad(): | |
| nn.init.xavier_uniform_(self.B_in_shared) | |
| nn.init.xavier_uniform_(self.B_out_shared) | |
| self.B_in = None | |
| self.B_out = None | |
| # FiLM parameters (zeros init) | |
| self.gamma_w = nn.Parameter(torch.zeros(H, self.M, self.Hh, dtype=torch.float32)) # [H,M,Hh] | |
| self.gamma_b = nn.Parameter(torch.zeros(H, self.M, dtype=torch.float32)) # [H,M] | |
| self.beta_w = nn.Parameter(torch.zeros(H, self.M, self.Hh, dtype=torch.float32)) # [H,M,Hh] | |
| self.beta_b = nn.Parameter(torch.zeros(H, self.M, dtype=torch.float32)) # [H,M] | |
| def forward(self, x: torch.Tensor) -> torch.Tensor: # x:[B,D] fp32 | |
| B, D = x.shape | |
| H, A, r, Hh, S, M, C = self.H, self.A, self.r, self.Hh, self.router.S, self.M, self.C | |
| # 1) Router | |
| scores, coords = self.router(x) # [B,H,K], [B,H,K,A] | |
| K = scores.size(2) | |
| # 2) Hyper features from selected coords (FP32/TF32) | |
| ids_flat = coords.reshape(-1).long() | |
| feats = sinusoidal_codes_from_ids(ids_flat, r).view(B, H, K, A, r) # [B,H,K,A,r] | |
| z_pre = torch.einsum('bhkar,harq->bhkq', feats, self.W1_blocks) # [B,H,K,Hh] | |
| g = gelu_tanh(z_pre) # [B,H,K,Hh] | |
| AB = torch.einsum('bhkq,hqm->bhkm', g, self.W2) # [B,H,K,2M] | |
| alpha = AB[..., :M] # [B,H,K,M] | |
| beta = AB[..., M:] # [B,H,K,M] | |
| # 3) Dictionaries: project x, then FiLM on M-dim | |
| if self.per_head_dict: | |
| h_proj = torch.einsum('bd,hmd->bhm', x, self.B_in) # [B,H,M] | |
| else: | |
| h_proj = torch.einsum('bd,md->bm', x, self.B_in_shared) # [B,M] | |
| h_proj = h_proj.unsqueeze(1).expand(-1, H, -1).contiguous() # [B,H,M] | |
| w = scores.unsqueeze(-1) # [B,H,K,1] | |
| gamma_k = torch.einsum('bhkq,hmq->bhkm', g, self.gamma_w) + self.gamma_b.unsqueeze(0).unsqueeze(2) | |
| if self.film_tanh_gamma: | |
| gamma_k = torch.tanh(gamma_k) | |
| beta_k = torch.einsum('bhkq,hmq->bhkm', g, self.beta_w) + self.beta_b.unsqueeze(0).unsqueeze(2) | |
| gamma_bar = (w * gamma_k).sum(dim=2) # [B,H,M] | |
| beta_bar = (w * beta_k).sum(dim=2) # [B,H,M] | |
| h_mod = (1.0 + self.film_eps * gamma_bar) * h_proj + self.film_eps * beta_bar # [B,H,M] | |
| # 4) Vector gate combine over beams | |
| act_M = gelu_tanh(alpha * h_mod.unsqueeze(2)) # [B,H,K,M] | |
| m = (w * act_M * beta).sum(dim=2) # [B,H,M] | |
| # 5) To logits via B_out, average heads ---> No separate classifier | |
| if self.per_head_dict: | |
| logits_h = torch.einsum('bhm,hmc->bhc', m, self.B_out) # [B,H,C] | |
| else: | |
| logits_h = torch.einsum('bhm,mc->bhc', m, self.B_out_shared) # [B,H,C] | |
| logits = logits_h.mean(dim=1) # [B,C] | |
| return logits | |
| # (Optional) param counter for logging (core weights only; router.logit_scale is not counted) | |
| def ethos_param_count(model: EthosGeneratedMLP) -> int: | |
| P = 0 | |
| P += model.router.query_w.numel() | |
| P += model.W1_blocks.numel() | |
| P += model.W2.numel() | |
| if model.per_head_dict: | |
| P += model.B_in.numel() + model.B_out.numel() | |
| else: | |
| P += model.B_in_shared.numel() + model.B_out_shared.numel() | |
| P += model.gamma_w.numel() + model.gamma_b.numel() | |
| P += model.beta_w.numel() + model.beta_b.numel() | |
| return P |
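| # --- Illustrative sketch (not part of the original gist): end-to-end smoke test --- | |
| # Deliberately tiny dimensions so the forward/backward pass is cheap on CPU; the sweep | |
| # configs elsewhere in this gist use much larger settings. | |
| if __name__ == "__main__": | |
|     torch.manual_seed(0) | |
|     model = EthosGeneratedMLP(d_in=3 * 32 * 32, n_classes=10, | |
|                               H=2, A=4, r=16, S=128, K=4, b=8, beam_width=16, | |
|                               Hh=64, M=16, s_tile=64) | |
|     x = torch.randn(8, 3 * 32 * 32) | |
|     logits = model(x) | |
|     print("logits:", tuple(logits.shape))                 # (8, 10) | |
|     print("core params:", ethos_param_count(model)) | |
|     logits.sum().backward()                               # grads flow through router + hypernet | |
|     print("router grad norm:", model.router.query_w.grad.norm().item()) | |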
| # Fixed arch | |
| use_mha: [true] | |
| mha_heads: [6] | |
| mha_embed_dim: [384] | |
| mha_depth: [6] | |
| mha_patch: [4] | |
| mha_dropout: [0.1] | |
| H: [1] | |
| A: [8] | |
| r: [64] | |
| S: [2048] | |
| b: [64] | |
| beam: [32] | |
| K: [16] | |
| Hh: [1024] | |
| M: [512] | |
| share_dict: [false] | |
| classes: [10] | |
| # Baseline | |
| mlp_depth: [1] | |
| mlp_width: [0] | |
| mlp_norm: ["ln"] | |
| mlp_dropout: [0.0] | |
| # Training HPs to sweep | |
| epochs: [200] | |
| batch_size: [2048] | |
| lr: [0.0006, 0.0003] | |
| lr_ethos: [null] | |
| lr_mlp: [null] | |
| lr_router: [0.00144, 0.00108, 0.00072, 0.00036] # ETHOS-only | |
| lr_hypernet: [0.0006, 0.00048, 0.00036, 0.00024, 0.00012] # ETHOS-only | |
| weight_decay: [0.0005] | |
| warmup_epochs: [10] | |
| min_lr: [0.000001] | |
| grad_clip: [100.0, 1.0] | |
| seed: [42] | |
| workers: [8] | |
| data: ["./data"] | |
| # EMA + logging | |
| metric_ema_decay: [0.9] | |
| metric_ema_bias_correct: [true] | |
| sweep_name: ["flat_vs_ethos"] | |
| log_root: ["logs"] |
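| # Sweep size (full cartesian product over all lists, as expanded by _cartesian_from_yaml in | |
| # the training script): 2 (lr) x 4 (lr_router) x 5 (lr_hypernet) x 2 (grad_clip) = 80 trials. | |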
| # train_flat_ethos_vs_mlp_sweep_cached.py | |
| from __future__ import annotations | |
| import argparse, math, os, random, sys, itertools, json, time, gc, hashlib | |
| from dataclasses import dataclass, asdict, field | |
| from datetime import datetime | |
| from pathlib import Path | |
| from typing import Any, Dict, List, Tuple | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from torch.utils.data import DataLoader | |
| from torchvision import datasets, transforms | |
| # Requires PyYAML (pip install pyyaml) | |
| import yaml | |
| from ethos_gen_mlp import EthosGeneratedMLP # module defined in ethos_gen_mlp.py above | |
| # Enable TF32 globally | |
| try: | |
| torch.backends.cuda.matmul.allow_tf32 = True | |
| torch.backends.cudnn.allow_tf32 = True | |
| torch.set_float32_matmul_precision("high") | |
| except Exception: | |
| pass | |
| # ---------------------------- | |
| # Reproducibility helpers | |
| # ---------------------------- | |
| def set_seed(seed: int): | |
| random.seed(seed); np.random.seed(seed); torch.manual_seed(seed) | |
| if torch.cuda.is_available(): | |
| torch.cuda.manual_seed_all(seed) | |
| torch.backends.cudnn.benchmark = False | |
| torch.backends.cudnn.deterministic = True | |
| # ---------------------------- | |
| # Vision preprocessor (ViT-style) | |
| # ---------------------------- | |
| class PatchEmbed(nn.Module): | |
| """ | |
| Patchify image via a Conv2d: kernel=stride=patch_size, produces tokens. | |
| Input: [B,3,H,W] | |
| Output: [B, N, E] where N=(H/ps)*(W/ps), E=embed_dim | |
| """ | |
| def __init__(self, img_size: int = 32, patch_size: int = 4, in_chans: int = 3, embed_dim: int = 384): | |
| super().__init__() | |
| assert img_size % patch_size == 0, "img_size must be divisible by patch_size" | |
| self.img_size = img_size | |
| self.patch_size = patch_size | |
| self.grid = img_size // patch_size | |
| self.num_patches = self.grid * self.grid | |
| self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) | |
| def forward(self, x: torch.Tensor) -> torch.Tensor: | |
| x = self.proj(x) # [B,E,H/ps,W/ps] | |
| x = x.flatten(2).transpose(1, 2) # [B,N,E] | |
| return x | |
| class TransformerBlock(nn.Module): | |
| """ | |
| Pre-LN Transformer encoder block. | |
| """ | |
| def __init__(self, embed_dim: int, num_heads: int, mlp_ratio: float = 4.0, dropout: float = 0.1): | |
| super().__init__() | |
| self.norm1 = nn.LayerNorm(embed_dim) | |
| self.attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, batch_first=True) | |
| self.drop1 = nn.Dropout(dropout) | |
| self.norm2 = nn.LayerNorm(embed_dim) | |
| hidden_dim = int(embed_dim * mlp_ratio) | |
| self.mlp = nn.Sequential( | |
| nn.Linear(embed_dim, hidden_dim), | |
| nn.GELU(), | |
| nn.Dropout(dropout), | |
| nn.Linear(hidden_dim, embed_dim), | |
| nn.Dropout(dropout), | |
| ) | |
| def forward(self, x: torch.Tensor) -> torch.Tensor: | |
| xa = self.norm1(x) | |
| a, _ = self.attn(xa, xa, xa) | |
| x = x + self.drop1(a) | |
| xm = self.norm2(x) | |
| x = x + self.mlp(xm) | |
| return x | |
| class VisionPreprocessor(nn.Module): | |
| """ | |
| ViT-style preprocessor that outputs a single vector [B, E] via mean-pooling tokens. | |
| Attention runs over the full grid of patch tokens (num_patches > 1), i.e. a genuine vision attention stage. | |
| """ | |
| def __init__(self, img_size: int = 32, patch_size: int = 4, embed_dim: int = 384, | |
| depth: int = 2, num_heads: int = 6, dropout: float = 0.1): | |
| super().__init__() | |
| self.patch = PatchEmbed(img_size=img_size, patch_size=patch_size, embed_dim=embed_dim) | |
| self.pos_embed = nn.Parameter(torch.zeros(1, self.patch.num_patches, embed_dim)) | |
| self.pos_drop = nn.Dropout(dropout) | |
| self.blocks = nn.ModuleList([ | |
| TransformerBlock(embed_dim=embed_dim, num_heads=num_heads, mlp_ratio=4.0, dropout=dropout) | |
| for _ in range(depth) | |
| ]) | |
| self.norm = nn.LayerNorm(embed_dim) | |
| nn.init.trunc_normal_(self.pos_embed, std=0.02) | |
| def forward(self, x: torch.Tensor) -> torch.Tensor: | |
| t = self.patch(x) | |
| t = t + self.pos_embed | |
| t = self.pos_drop(t) | |
| for blk in self.blocks: | |
| t = blk(t) | |
| t = self.norm(t) | |
| z = t.mean(dim=1) | |
| return z # [B, E] | |
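| # --- Illustrative sketch (not part of the original gist): preprocessor output shape --- | |
| # 32x32 inputs with patch_size=4 give an 8x8 = 64-token grid, mean-pooled to one vector/image. | |
| def _demo_vision_preprocessor(): | |
|     pre = VisionPreprocessor(img_size=32, patch_size=4, embed_dim=64, depth=1, num_heads=4) | |
|     z = pre(torch.randn(2, 3, 32, 32))                    # [B,3,32,32] -> [B,64] | |
|     assert z.shape == (2, 64) | |
|     return z | |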
| # ---------------------------- | |
| # CIFAR-10 transforms | |
| # ---------------------------- | |
| CIFAR10_MEAN = (0.4914, 0.4822, 0.4465) | |
| CIFAR10_STD = (0.2470, 0.2435, 0.2616) | |
| def make_loaders(batch_size: int, num_workers: int = 8, data_root: str = "./data"): | |
| train_tf = transforms.Compose([ | |
| transforms.RandomCrop(32, padding=4), | |
| transforms.RandomHorizontalFlip(), | |
| transforms.ToTensor(), | |
| transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD), | |
| ]) | |
| test_tf = transforms.Compose([ | |
| transforms.ToTensor(), | |
| transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD), | |
| ]) | |
| train_set = datasets.CIFAR10(data_root, train=True, transform=train_tf, download=True) | |
| test_set = datasets.CIFAR10(data_root, train=False, transform=test_tf, download=True) | |
| train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=num_workers, | |
| pin_memory=True, drop_last=True) | |
| test_loader = DataLoader(test_set, batch_size=1024, shuffle=False, num_workers=num_workers, | |
| pin_memory=True, drop_last=False) | |
| return train_loader, test_loader | |
| # ---------------------------- | |
| # Iso-parameter MLP baseline (variable depth) | |
| # ---------------------------- | |
| class IsoParamMLP(nn.Module): | |
| """ | |
| Depth-L MLP with optional norm/dropout. Hidden width solved to match ETHOS core params unless --mlp_width is provided. | |
| """ | |
| def __init__(self, d_in: int, num_classes: int, width: int, depth: int, | |
| norm: str = "ln", dropout: float = 0.0): | |
| super().__init__() | |
| assert depth >= 1, "mlp_depth must be >= 1" | |
| self.depth = depth | |
| self.norm_type = norm | |
| self.dropout_p = float(dropout) | |
| layers = [] | |
| in_dim = d_in | |
| for _ in range(depth): | |
| layers.append(nn.Linear(in_dim, width, bias=True)) | |
| if norm == "ln": | |
| layers.append(nn.LayerNorm(width)) | |
| elif norm == "bn": | |
| layers.append(nn.BatchNorm1d(width)) | |
| elif norm == "none": | |
| pass | |
| else: | |
| raise ValueError(f"Unknown norm: {norm}") | |
| layers.append(nn.GELU()) | |
| if self.dropout_p > 0: | |
| layers.append(nn.Dropout(self.dropout_p)) | |
| in_dim = width | |
| self.hidden = nn.Sequential(*layers) | |
| self.fc_out = nn.Linear(width, num_classes, bias=True) | |
| for m in self.modules(): | |
| if isinstance(m, nn.Linear): | |
| if m is self.fc_out: | |
| nn.init.xavier_uniform_(m.weight) | |
| nn.init.zeros_(m.bias) | |
| else: | |
| nn.init.kaiming_uniform_(m.weight, a=math.sqrt(5)) | |
| nn.init.zeros_(m.bias) | |
| elif isinstance(m, (nn.LayerNorm, nn.BatchNorm1d)): | |
| nn.init.ones_(m.weight); nn.init.zeros_(m.bias) | |
| def forward(self, x: torch.Tensor) -> torch.Tensor: | |
| x = x.view(x.size(0), -1) | |
| x = self.hidden(x) | |
| return self.fc_out(x) | |
| def mlp_param_count(d_in: int, num_classes: int, width: int, depth: int, norm: str) -> int: | |
| linear = d_in * width + width # first | |
| if depth >= 2: | |
| linear += (depth - 1) * (width * width + width) | |
| linear += width * num_classes + num_classes # out | |
| norm_params_per_layer = 0 | |
| if norm in ("ln", "bn"): | |
| norm_params_per_layer = 2 * width | |
| norm_total = depth * norm_params_per_layer | |
| return linear + norm_total | |
| def solve_width_for_budget(d_in: int, num_classes: int, depth: int, norm: str, target_params: int) -> int: | |
| # Grow the upper bound until the parameter count overshoots the budget, then binary-search. | |
| lo, hi = 1, 2 | |
| while mlp_param_count(d_in, num_classes, hi, depth, norm) < target_params: | |
| hi *= 2 | |
| best_w, best_diff = 1, float("inf") | |
| while lo <= hi: | |
| mid = (lo + hi) // 2 | |
| p = mlp_param_count(d_in, num_classes, mid, depth, norm) | |
| diff = abs(p - target_params) | |
| if p <= target_params and diff < best_diff: | |
| best_w, best_diff = mid, diff | |
| if p <= target_params: | |
| lo = mid + 1 | |
| else: | |
| hi = mid - 1 | |
| return max(1, best_w) | |
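| # --- Illustrative sketch (not part of the original gist): how the width solver behaves --- | |
| # For depth=1 with LayerNorm, params(w) = d_in*w + w + w*C + C + 2*w, so the solver should | |
| # return (roughly) the largest width whose count stays at or under the target budget. | |
| def _demo_width_solver(): | |
|     d_in, C, budget = 3072, 10, 1_000_000 | |
|     w = solve_width_for_budget(d_in, C, depth=1, norm="ln", target_params=budget) | |
|     assert mlp_param_count(d_in, C, w, depth=1, norm="ln") <= budget | |
|     return w | |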
| # ---------------------------- | |
| # Optim / Schedules | |
| # ---------------------------- | |
| class WarmupCosineLR(torch.optim.lr_scheduler._LRScheduler): | |
| def __init__(self, optimizer, warmup_epochs, total_epochs, base_lr=None, min_lr=1e-6, last_epoch=-1): | |
| self.warmup_epochs = warmup_epochs | |
| self.total_epochs = total_epochs | |
| self.min_lr = min_lr | |
| self.use_group_base_lr = (base_lr is None) | |
| self.base_lr = base_lr | |
| super().__init__(optimizer, last_epoch) | |
| def get_lr(self): | |
| epoch = self.last_epoch + 1 | |
| if epoch < self.warmup_epochs: | |
| scale = epoch / max(1, self.warmup_epochs) | |
| if self.use_group_base_lr: | |
| return [scale * base_lr for base_lr in self.base_lrs] | |
| else: | |
| return [scale * self.base_lr for _ in self.base_lrs] | |
| t = (epoch - self.warmup_epochs) / max(1, self.total_epochs - self.warmup_epochs) | |
| cos = 0.5 * (1 + math.cos(math.pi * t)) | |
| if self.use_group_base_lr: | |
| return [self.min_lr + (base_lr - self.min_lr) * cos for base_lr in self.base_lrs] | |
| else: | |
| lr = self.min_lr + (self.base_lr - self.min_lr) * cos | |
| return [lr for _ in self.base_lrs] | |
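| # --- Illustrative sketch (not part of the original gist): warmup + cosine trajectory --- | |
| # The LR ramps linearly for warmup_epochs scheduler steps, then follows a cosine from the | |
| # base LR down to min_lr at total_epochs. | |
| def _demo_warmup_cosine(): | |
|     opt = torch.optim.SGD(nn.Linear(4, 4).parameters(), lr=0.1) | |
|     sched = WarmupCosineLR(opt, warmup_epochs=5, total_epochs=20, min_lr=1e-6) | |
|     lrs = [] | |
|     for _ in range(20): | |
|         lrs.append(opt.param_groups[0]["lr"]) | |
|         opt.step() | |
|         sched.step() | |
|     return lrs   # lrs[0] ~= 0.02 (0.1/5), peaks near 0.1 after warmup, ends near min_lr | |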
| # ---------------------------- | |
| # Metric EMA helper | |
| # ---------------------------- | |
| class ScalarEMA: | |
| """ | |
| EMA for scalar metrics (e.g., accuracy in %). | |
| If bias_correct=True, EMA_1 equals the first value. | |
| """ | |
| def __init__(self, decay: float = 0.9, bias_correct: bool = True): | |
| self.decay = float(decay) | |
| self.bias_correct = bool(bias_correct) | |
| self.t = 0 | |
| self.m = 0.0 | |
| def update(self, x: float) -> float: | |
| x = float(x) | |
| self.t += 1 | |
| self.m = self.decay * self.m + (1.0 - self.decay) * x | |
| if self.bias_correct: | |
| denom = 1.0 - (self.decay ** self.t) | |
| return self.m / max(1e-12, denom) | |
| return self.m | |
| def value(self) -> float | None: | |
| if self.t == 0: | |
| return None | |
| if self.bias_correct: | |
| denom = 1.0 - (self.decay ** self.t) | |
| return self.m / max(1e-12, denom) | |
| return self.m | |
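| # --- Illustrative sketch (not part of the original gist): bias-corrected metric EMA --- | |
| # With bias_correct=True the first update returns the raw value; later updates blend it | |
| # with new observations. | |
| def _demo_scalar_ema(): | |
|     ema = ScalarEMA(decay=0.9, bias_correct=True) | |
|     assert abs(ema.update(50.0) - 50.0) < 1e-9            # EMA_1 == first value | |
|     smoothed = ema.update(60.0)                           # a weighted blend of 50 and 60 | |
|     assert 50.0 < smoothed < 60.0 | |
|     return smoothed | |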
| # ---------------------------- | |
| # Train / Eval loops (shared) | |
| # ---------------------------- | |
| def evaluate(model: nn.Module, loader: DataLoader, device: torch.device) -> float: | |
| model.eval() | |
| correct = 0 | |
| total = 0 | |
| with torch.no_grad(): | |
| for xb, yb in loader: | |
| xb = xb.to(device, non_blocking=True) | |
| yb = yb.to(device, non_blocking=True) | |
| logits = model(xb) | |
| pred = logits.argmax(dim=1) | |
| correct += (pred == yb).sum().item() | |
| total += yb.numel() | |
| return 100.0 * correct / max(1, total) | |
| def train_one_epoch(model: nn.Module, loader: DataLoader, device: torch.device, | |
| optimizer: torch.optim.Optimizer, criterion: nn.Module, | |
| grad_clip: float): | |
| model.train() | |
| for xb, yb in loader: | |
| xb = xb.to(device, non_blocking=True) | |
| yb = yb.to(device, non_blocking=True) | |
| logits = model(xb) | |
| loss = criterion(logits, yb) | |
| optimizer.zero_grad(set_to_none=True) | |
| loss.backward() | |
| if grad_clip is not None and grad_clip > 0: | |
| torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip) | |
| optimizer.step() | |
| # ---------------------------- | |
| # Args + parsing | |
| # ---------------------------- | |
| @dataclass | |
| class Args: | |
| # Core training | |
| epochs: int = 100 | |
| batch_size: int = 512 | |
| lr: float = 3e-4 | |
| lr_ethos: float | None = None | |
| lr_mlp: float | None = None | |
| lr_router: float | None = None | |
| lr_hypernet: float | None = None | |
| weight_decay: float = 5e-4 | |
| warmup_epochs: int = 5 | |
| min_lr: float = 1e-6 | |
| grad_clip: float = 1.0 | |
| seed: int = 42 | |
| workers: int = 8 | |
| data: str = "./data" | |
| # ETHOS knobs | |
| H: int = 1 | |
| A: int = 16 | |
| r: int = 64 | |
| S: int = 256 | |
| b: int = 64 | |
| beam: int = 128 | |
| K: int = 16 | |
| Hh: int = 256 | |
| M: int = 64 | |
| share_dict: bool = False | |
| classes: int = 10 | |
| # Baseline controls | |
| mlp_depth: int = 1 | |
| mlp_width: int = 0 | |
| mlp_norm: str = "ln" | |
| mlp_dropout: float = 0.0 | |
| # Vision preprocessor controls | |
| use_mha: bool = False | |
| mha_heads: int = 6 | |
| mha_embed_dim: int = 384 | |
| mha_depth: int = 2 | |
| mha_patch: int = 4 | |
| mha_dropout: float = 0.1 | |
| # Metric EMA | |
| metric_ema_decay: float = 0.9 | |
| metric_ema_bias_correct: bool = True | |
| # Sweep / logging | |
| config: str | None = None # path to YAML sweep file | |
| sweep_name: str = "sweep" | |
| log_root: str = "logs" | |
| # Baseline cache controls | |
| baseline_cache_dir: str | None = None # default resolves to <log_root>/cache | |
| baseline_cache_ignore_seed: bool = False | |
| baseline_cache_refresh: bool = False | |
| # internal (not exposed to YAML by default) | |
| _trial_id: str = field(default="", repr=False) | |
| def parse_args() -> Args: | |
| p = argparse.ArgumentParser() | |
| # Core training | |
| p.add_argument("--epochs", type=int, default=100) | |
| p.add_argument("--batch_size", type=int, default=512) | |
| p.add_argument("--lr", type=float, default=3e-4) | |
| p.add_argument("--lr_ethos", type=float, default=None) | |
| p.add_argument("--lr_mlp", type=float, default=None) | |
| p.add_argument("--lr_router", type=float, default=None) | |
| p.add_argument("--lr_hypernet", type=float, default=None) | |
| p.add_argument("--weight_decay", type=float, default=5e-4) | |
| p.add_argument("--warmup_epochs", type=int, default=100) # keep previous CLI default | |
| p.add_argument("--min_lr", type=float, default=1e-6) | |
| p.add_argument("--grad_clip", type=float, default=100.0) | |
| p.add_argument("--seed", type=int, default=42) | |
| p.add_argument("--workers", type=int, default=8) | |
| p.add_argument("--data", type=str, default="./data") | |
| # ETHOS | |
| p.add_argument("--H", type=int, default=1) | |
| p.add_argument("--A", type=int, default=16) | |
| p.add_argument("--r", type=int, default=64) | |
| p.add_argument("--S", type=int, default=256) | |
| p.add_argument("--b", type=int, default=64) | |
| p.add_argument("--beam", type=int, default=128) | |
| p.add_argument("--K", type=int, default=16) | |
| p.add_argument("--Hh", type=int, default=256) | |
| p.add_argument("--M", type=int, default=64) | |
| p.add_argument("--share_dict", action="store_true") | |
| p.add_argument("--classes", type=int, default=10) | |
| # Baseline | |
| p.add_argument("--mlp_depth", type=int, default=1) | |
| p.add_argument("--mlp_width", type=int, default=0) | |
| p.add_argument("--mlp_norm", type=str, default="ln", choices=["ln", "bn", "none"]) | |
| p.add_argument("--mlp_dropout", type=float, default=0.0) | |
| # Vision preprocessor | |
| p.add_argument("--use_mha", action="store_true") | |
| p.add_argument("--mha_heads", type=int, default=6) | |
| p.add_argument("--mha_embed_dim", type=int, default=384) | |
| p.add_argument("--mha_depth", type=int, default=2) | |
| p.add_argument("--mha_patch", type=int, default=4) | |
| p.add_argument("--mha_dropout", type=float, default=0.1) | |
| # Metric EMA | |
| p.add_argument("--metric_ema_decay", type=float, default=0.9) | |
| p.add_argument("--metric_ema_bias_correct", action="store_true") | |
| # Sweep + logging | |
| p.add_argument("--config", type=str, default=None, help="YAML file with lists for each key to grid-search") | |
| p.add_argument("--sweep_name", type=str, default="sweep") | |
| p.add_argument("--log_root", type=str, default="logs") | |
| # Baseline cache controls | |
| p.add_argument("--baseline_cache_dir", type=str, default=None) | |
| p.add_argument("--baseline_cache_ignore_seed", action="store_true") | |
| p.add_argument("--baseline_cache_refresh", action="store_true") | |
| a = p.parse_args() | |
| # argparse's store_true default is False, but the intended default (matching Args) is True. | |
| a.metric_ema_bias_correct = True | |
| return Args(**vars(a)) | |
| # ---------------------------- | |
| # Baseline cache helpers | |
| # ---------------------------- | |
| def mlp_cache_key(args: Args, ethos_core_params: int, d_in_core: int): | |
| lr_mlp_eff = args.lr_mlp if args.lr_mlp is not None else args.lr | |
| sig = { | |
| # arch | |
| "d_in": d_in_core, | |
| "classes": args.classes, | |
| "use_mha": args.use_mha, | |
| "mha_heads": args.mha_heads, | |
| "mha_embed_dim": args.mha_embed_dim, | |
| "mha_depth": args.mha_depth, | |
| "mha_patch": args.mha_patch, | |
| "mha_dropout": args.mha_dropout, | |
| "mlp_depth": args.mlp_depth, | |
| "mlp_norm": args.mlp_norm, | |
| "mlp_dropout": args.mlp_dropout, | |
| # width: literal or auto@budget | |
| "mlp_width": args.mlp_width if args.mlp_width > 0 else f"auto@{ethos_core_params}", | |
| # train hparams | |
| "epochs": args.epochs, | |
| "batch_size": args.batch_size, | |
| "lr_mlp": lr_mlp_eff, | |
| "weight_decay": args.weight_decay, | |
| "warmup_epochs": args.warmup_epochs, | |
| "min_lr": args.min_lr, | |
| "grad_clip": args.grad_clip, | |
| } | |
| if not args.baseline_cache_ignore_seed: | |
| sig["seed"] = args.seed | |
| key = hashlib.sha1(json.dumps(sig, sort_keys=True).encode()).hexdigest()[:16] | |
| return key, sig | |
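| # --- Illustrative sketch (not part of the original gist): cache keys are deterministic --- | |
| # The baseline cache key hashes only the fields that affect the Iso-MLP run, so two trials | |
| # that differ only in ETHOS-specific knobs share one cached baseline (the core parameter | |
| # budget is passed in explicitly and held fixed here). | |
| def _demo_cache_key(): | |
|     a1 = Args() | |
|     a2 = Args(lr_router=1e-3, A=8)                        # ETHOS-only changes | |
|     k1, _ = mlp_cache_key(a1, ethos_core_params=1_000_000, d_in_core=3072) | |
|     k2, _ = mlp_cache_key(a2, ethos_core_params=1_000_000, d_in_core=3072) | |
|     assert k1 == k2 | |
|     return k1 | |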
| def _cartesian_from_yaml(yaml_path: str) -> Tuple[List[str], List[Tuple[Any, ...]]]: | |
| """ | |
| Load YAML where each key maps to a list of candidate values. | |
| Returns (keys, list_of_value_tuples) | |
| """ | |
| with open(yaml_path, "r") as f: | |
| cfg = yaml.safe_load(f) or {} | |
| keys = list(cfg.keys()) | |
| value_lists = [] | |
| for k in keys: | |
| v = cfg[k] | |
| if isinstance(v, list): | |
| value_lists.append(v) | |
| else: | |
| value_lists.append([v]) | |
| combos = list(itertools.product(*value_lists)) | |
| return keys, combos | |
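| # --- Illustrative sketch (not part of the original gist): grid expansion from a tiny YAML --- | |
| # Every key maps to a list of candidates (scalars are wrapped); the sweep is the full | |
| # cartesian product, so the trial count is the product of the list lengths (2 x 3 x 1 = 6). | |
| def _demo_cartesian(): | |
|     import tempfile, os | |
|     with tempfile.NamedTemporaryFile("w", suffix=".yaml", delete=False) as f: | |
|         f.write("lr: [0.001, 0.0003]\nbatch_size: [256, 512, 1024]\nseed: 42\n") | |
|         path = f.name | |
|     keys, combos = _cartesian_from_yaml(path) | |
|     os.unlink(path) | |
|     assert keys == ["lr", "batch_size", "seed"] and len(combos) == 6 | |
|     return combos | |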
| def _apply_overrides(base: Args, overrides: Dict[str, Any]) -> Args: | |
| d = asdict(base) | |
| for k, v in overrides.items(): | |
| if k not in d: | |
| raise KeyError(f"Unknown config key in sweep: {k}") | |
| d[k] = v | |
| d["_trial_id"] = "" | |
| return Args(**d) | |
| def _append_summary_line(summary_path: Path, record: Dict[str, Any]): | |
| summary_path.parent.mkdir(parents=True, exist_ok=True) | |
| with open(summary_path, "a") as f: | |
| f.write(json.dumps(record) + "\n") | |
| # ---------------------------- | |
| # Trial runner (single) | |
| # ---------------------------- | |
| def run_single_trial(args: Args, trial_log_path: Path) -> Dict[str, Any]: | |
| """ | |
| Runs one training trial with the provided Args. | |
| Logs to both stdout and trial_log_path, returns a summary dict. | |
| """ | |
| trial_log_path.parent.mkdir(parents=True, exist_ok=True) | |
| lf = open(trial_log_path, "w", buffering=1) | |
| def log(msg: str): | |
| msg = str(msg) | |
| print(msg) | |
| lf.write(msg + "\n") | |
| lf.flush() | |
| # Header | |
| log("=" * 80) | |
| log(f"[TRIAL] {args._trial_id} | {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}") | |
| log(f"[CONFIG] {json.dumps({k: v for k, v in asdict(args).items() if not k.startswith('_')}, default=str)}") | |
| log("=" * 80) | |
| # Seed + device | |
| set_seed(args.seed) | |
| device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
| if device.type == "cuda": | |
| gpu_name = torch.cuda.get_device_name(0) | |
| log(f"Using GPU: {gpu_name}") | |
| else: | |
| log("Using CPU") | |
| # Data | |
| train_loader, test_loader = make_loaders(args.batch_size, args.workers, data_root=args.data) | |
| # Preprocessor / core dims | |
| if args.use_mha: | |
| D_in_core = int(args.mha_embed_dim) | |
| vision_e = VisionPreprocessor(img_size=32, patch_size=args.mha_patch, | |
| embed_dim=args.mha_embed_dim, depth=args.mha_depth, | |
| num_heads=args.mha_heads, dropout=args.mha_dropout) | |
| # For the baseline we may or may not instantiate, so keep a separate preproc when needed | |
| pre_params_e = sum(p.numel() for p in vision_e.parameters()) | |
| log(f"[Vision Pre] Enabled (ViT-style)") | |
| log(f" ETHOS-pre params: {pre_params_e:,} (patch={args.mha_patch}, depth={args.mha_depth}, heads={args.mha_heads}, dim={args.mha_embed_dim})") | |
| else: | |
| D_in_core = 3 * 32 * 32 | |
| pre_params_e = 0 | |
| vision_e = None | |
| C = args.classes | |
| # ETHOS core (exclude preproc from budget) | |
| ethos_core = EthosGeneratedMLP( | |
| d_in=D_in_core, n_classes=C, | |
| H=args.H, A=args.A, r=args.r, S=args.S, K=args.K, b=args.b, beam_width=args.beam, | |
| Hh=args.Hh, M=args.M, film_eps=0.1, film_tanh_gamma=True, per_head_dict=not args.share_dict, s_tile=512 | |
| ) | |
| # Compose full ETHOS model | |
| if args.use_mha: | |
| ethos = nn.Sequential(vision_e, ethos_core).to(device) | |
| else: | |
| ethos = nn.Sequential(nn.Flatten(), ethos_core).to(device) | |
| # Iso MLP width matching ETHOS-core param budget (preproc excluded) | |
| ethos_core_params = sum(p.numel() for p in ethos_core.parameters()) | |
| # ---------------- Cache setup for baseline ---------------- | |
| cache_root = Path(args.baseline_cache_dir) if args.baseline_cache_dir else (Path(args.log_root) / "cache") | |
| cache_root.mkdir(parents=True, exist_ok=True) | |
| cache_key, cache_sig = mlp_cache_key(args, ethos_core_params, D_in_core) | |
| cache_path = cache_root / f"mlp_{cache_key}.summary.json" | |
| baseline_cached = cache_path.exists() and (not args.baseline_cache_refresh) | |
| best_m_cached = None | |
| mlp_params_total_cached = None | |
| if baseline_cached: | |
| try: | |
| with open(cache_path, "r") as f: | |
| cached = json.load(f) | |
| best_m_cached = float(cached["results"]["best_mlp"]) | |
| mlp_params_total_cached = int(cached["params"]["mlp_total"]) if "params" in cached and "mlp_total" in cached["params"] else None | |
| log(f"[CACHE] Iso-MLP hit key={cache_key} | best={best_m_cached:.2f}% | params={mlp_params_total_cached if mlp_params_total_cached is not None else 'n/a'}") | |
| except Exception as e: | |
| log(f"[CACHE] Failed to read cache ({e}); retraining baseline.") | |
| baseline_cached = False | |
| # Determine learning rates | |
| lr_ethos = args.lr_ethos if args.lr_ethos is not None else args.lr | |
| lr_mlp = args.lr_mlp if args.lr_mlp is not None else args.lr | |
| lr_router = args.lr_router if args.lr_router is not None else lr_ethos | |
| lr_hypernet = args.lr_hypernet if args.lr_hypernet is not None else lr_ethos | |
| # Optimizers / schedulers | |
| criterion = nn.CrossEntropyLoss(label_smoothing=0.1) | |
| # Param groups for ETHOS | |
| router_params, hypernet_params, other_params = [], [], [] | |
| for name, param in ethos.named_parameters(): | |
| if args.use_mha and name.startswith('0.'): | |
| other_params.append(param) # vision preproc | |
| elif 'router.query_w' in name: | |
| router_params.append(param) | |
| elif any(key in name for key in ['W1_blocks', 'W2', 'B_in', 'B_out']): | |
| hypernet_params.append(param) | |
| elif any(key in name for key in ['gamma_w', 'gamma_b', 'beta_w', 'beta_b']): | |
| other_params.append(param) | |
| else: | |
| other_params.append(param) | |
| ethos_param_groups = [] | |
| if router_params: ethos_param_groups.append({'params': router_params, 'lr': lr_router}) | |
| if hypernet_params: ethos_param_groups.append({'params': hypernet_params, 'lr': lr_hypernet}) | |
| if other_params: ethos_param_groups.append({'params': other_params, 'lr': lr_ethos}) | |
| opt_e = torch.optim.AdamW(ethos_param_groups, lr=lr_ethos, weight_decay=args.weight_decay) | |
| sched_e = WarmupCosineLR(opt_e, warmup_epochs=args.warmup_epochs, total_epochs=args.epochs, | |
| base_lr=None, min_lr=args.min_lr) | |
| # Baseline build only if no cache | |
| iso_mlp = None | |
| opt_m = None | |
| sched_m = None | |
| vision_m = None | |
| pre_params_m = 0 | |
| if not baseline_cached: | |
| if args.use_mha: | |
| vision_m = VisionPreprocessor(img_size=32, patch_size=args.mha_patch, | |
| embed_dim=args.mha_embed_dim, depth=args.mha_depth, | |
| num_heads=args.mha_heads, dropout=args.mha_dropout) | |
| pre_params_m = sum(p.numel() for p in vision_m.parameters()) | |
| # Solve width (if needed) & build | |
| if args.mlp_width > 0: | |
| iso_width = int(args.mlp_width) | |
| else: | |
| iso_width = solve_width_for_budget(D_in_core, C, depth=args.mlp_depth, | |
| norm=args.mlp_norm, target_params=ethos_core_params) | |
| iso_mlp_core = IsoParamMLP(d_in=D_in_core, num_classes=C, width=iso_width, depth=args.mlp_depth, | |
| norm=args.mlp_norm, dropout=args.mlp_dropout) | |
| if args.use_mha: | |
| iso_mlp = nn.Sequential(vision_m, iso_mlp_core).to(device) | |
| else: | |
| iso_mlp = nn.Sequential(nn.Flatten(), iso_mlp_core).to(device) | |
| opt_m = torch.optim.AdamW(iso_mlp.parameters(), lr=lr_mlp, weight_decay=args.weight_decay) | |
| sched_m = WarmupCosineLR(opt_m, warmup_epochs=args.warmup_epochs, total_epochs=args.epochs, | |
| base_lr=lr_mlp, min_lr=args.min_lr) | |
| ethos_params_total = sum(p.numel() for p in ethos.parameters()) | |
| mlp_params_total = (sum(p.numel() for p in iso_mlp.parameters()) if iso_mlp is not None else mlp_params_total_cached) | |
| log(f"[Budget] ETHOS total params: {ethos_params_total:,} (core: {ethos_core_params:,}, preproc: {pre_params_e:,})") | |
| log(f"[Model] ETHOS total: {ethos_params_total:,} | Iso-MLP total: {mlp_params_total if mlp_params_total is not None else 'n/a'}") | |
| log(f"[LR] ETHOS lr: {lr_ethos:.6f} | MLP lr: {lr_mlp:.6f}") | |
| log(f"[LR] Router lr: {lr_router:.6f} | Hypernet lr: {lr_hypernet:.6f}") | |
| log(f"[Param Groups] Router: {sum(p.numel() for p in router_params):,} | Hypernet: {sum(p.numel() for p in hypernet_params):,} | Other: {sum(p.numel() for p in other_params):,}") | |
| ema_acc_ethos = ScalarEMA(decay=args.metric_ema_decay, bias_correct=args.metric_ema_bias_correct) | |
| ema_acc_mlp = ScalarEMA(decay=args.metric_ema_decay, bias_correct=args.metric_ema_bias_correct) | |
| best_e = 0.0 | |
| best_m = (0.0 if not baseline_cached else float(best_m_cached)) | |
| for epoch in range(1, args.epochs + 1): | |
| # ETHOS train | |
| train_one_epoch(ethos, train_loader, device, opt_e, criterion, grad_clip=args.grad_clip) | |
| sched_e.step() | |
| acc_e = evaluate(ethos, test_loader, device) | |
| acc_e_ema = ema_acc_ethos.update(acc_e) | |
| best_e = max(best_e, acc_e) | |
| if baseline_cached: | |
| log( | |
| f"epoch {epoch}/{args.epochs} | " | |
| f"ETHOS acc: {acc_e:5.2f}% (EMA: {acc_e_ema:5.2f}%) | " | |
| f"Iso-MLP: [cached] | " | |
| f"Delta(best raw): {best_e - best_m:5.2f}" | |
| ) | |
| else: | |
| # Baseline train | |
| train_one_epoch(iso_mlp, train_loader, device, opt_m, criterion, grad_clip=args.grad_clip) | |
| sched_m.step() | |
| acc_m = evaluate(iso_mlp, test_loader, device) | |
| acc_m_ema = ema_acc_mlp.update(acc_m) | |
| best_m = max(best_m, acc_m) | |
| log( | |
| f"epoch {epoch}/{args.epochs} | " | |
| f"ETHOS acc: {acc_e:5.2f}% (EMA: {acc_e_ema:5.2f}%) | " | |
| f"Iso-MLP acc: {acc_m:5.2f}% (EMA: {acc_m_ema:5.2f}%) | " | |
| f"Delta(best raw): {best_e - best_m:5.2f}" | |
| ) | |
| log(f"[FINAL] ETHOS: best {best_e:.2f}% | Iso-MLP: {'cached ' if baseline_cached else ''}best {best_m:.2f}%") | |
| lf.close() | |
| # write cache if trained baseline | |
| if not baseline_cached: | |
| try: | |
| summary_for_cache = { | |
| "key": cache_key, | |
| "sig": cache_sig, | |
| "params": {"mlp_total": int(mlp_params_total) if mlp_params_total is not None else None}, | |
| "results": {"best_mlp": float(best_m)}, | |
| "timestamp": datetime.now().isoformat(timespec="seconds"), | |
| } | |
| with open(cache_path, "w") as f: | |
| json.dump(summary_for_cache, f) | |
| except Exception as e: | |
| print(f"[CACHE] Failed to write cache: {e}", file=sys.stderr) | |
| # free mem between trials | |
| try: | |
| del ethos, ethos_core | |
| if not baseline_cached: | |
| del iso_mlp, iso_mlp_core | |
| if args.use_mha: | |
| del vision_e | |
| if not baseline_cached and vision_m is not None: | |
| del vision_m | |
| torch.cuda.empty_cache() | |
| gc.collect() | |
| except Exception: | |
| pass | |
| return { | |
| "trial_id": args._trial_id, | |
| "log_file": str(trial_log_path), | |
| "config": {k: v for k, v in asdict(args).items() if not k.startswith('_')}, | |
| "params": { | |
| "ethos_total": int(ethos_params_total), | |
| "ethos_core": int(ethos_core_params), | |
| "preproc": int(pre_params_e), | |
| "mlp_total": (int(mlp_params_total) if mlp_params_total is not None else None), | |
| }, | |
| "results": { | |
| "best_ethos": float(best_e), | |
| "best_mlp": float(best_m), | |
| "delta": float(best_e - best_m), | |
| }, | |
| "device": "cuda" if torch.cuda.is_available() else "cpu", | |
| "timestamp": datetime.now().isoformat(timespec="seconds"), | |
| "status": "ok", | |
| "baseline_cache": { | |
| "used": bool(baseline_cached), | |
| "key": cache_key, | |
| "path": str(cache_path), | |
| } | |
| } | |
| # ---------------------------- | |
| # Main | |
| # ---------------------------- | |
| def main(): | |
| args = parse_args() | |
| log_root = Path(args.log_root) | |
| log_root.mkdir(parents=True, exist_ok=True) | |
| summary_path = log_root / "trials.log" | |
| # Sweep mode: YAML provided | |
| if args.config is not None: | |
| keys, combos = _cartesian_from_yaml(args.config) | |
| total = len(combos) | |
| start_stamp = datetime.now().strftime("%Y%m%d-%H%M%S") | |
| print(f"[SWEEP] {args.sweep_name} | {total} trial(s) from {args.config}") | |
| print(f"[SWEEP] Logs in: {str(log_root.resolve())}") | |
| for idx, values in enumerate(combos, start=1): | |
| overrides = dict(zip(keys, values)) | |
| trial_args = _apply_overrides(args, overrides) | |
| trial_args._trial_id = f"{args.sweep_name}_{start_stamp}_{idx:03d}" | |
| trial_log_path = log_root / f"{trial_args._trial_id}.log" | |
| try: | |
| summary = run_single_trial(trial_args, trial_log_path) | |
| except Exception as e: | |
| summary = { | |
| "trial_id": trial_args._trial_id, | |
| "log_file": str(trial_log_path), | |
| "config": {k: v for k, v in asdict(trial_args).items() if not k.startswith('_')}, | |
| "status": "error", | |
| "error": repr(e), | |
| "timestamp": datetime.now().isoformat(timespec="seconds"), | |
| } | |
| print(f"[ERROR] Trial {trial_args._trial_id} failed: {repr(e)}", file=sys.stderr) | |
| _append_summary_line(summary_path, summary) | |
| print(f"[SWEEP DONE] Wrote collated results to {summary_path}") | |
| else: | |
| # Single-run mode (no YAML) | |
| start_stamp = datetime.now().strftime("%Y%m%d-%H%M%S") | |
| args._trial_id = f"{args.sweep_name}_{start_stamp}_single" | |
| trial_log_path = log_root / f"{args._trial_id}.log" | |
| try: | |
| summary = run_single_trial(args, trial_log_path) | |
| except Exception as e: | |
| summary = { | |
| "trial_id": args._trial_id, | |
| "log_file": str(trial_log_path), | |
| "config": {k: v for k, v in asdict(args).items() if not k.startswith('_')}, | |
| "status": "error", | |
| "error": repr(e), | |
| "timestamp": datetime.now().isoformat(timespec="seconds"), | |
| } | |
| print(f"[ERROR] Trial failed: {repr(e)}", file=sys.stderr) | |
| _append_summary_line(summary_path, summary) | |
| print(f"[DONE] Wrote collated result to {summary_path}") | |
| if __name__ == "__main__": | |
| main() |