Gist by @wrmedford, created October 29, 2025 13:38

# ethos_gen_mlp.py
# ETHOS (TF32/FP32) for flattened CIFAR-10:
# Router (pure-ID, custom autograd) → W1/W2 → FiLM → vector gate → dictionaries → mean heads → logits
from __future__ import annotations
import math
from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
# Use TF32 paths on Ampere+ (still FP32 tensors)
try:
torch.backends.cuda.matmul.allow_tf32 = True
except Exception:
pass
try:
torch.backends.cudnn.allow_tf32 = True
except Exception:
pass
try:
torch.set_float32_matmul_precision("high") # enables TF32 where applicable
except Exception:
pass
def gelu_tanh(x: torch.Tensor) -> torch.Tensor:
return F.gelu(x, approximate="tanh")
# =========================
# Sinusoidal codes (ETHOS)
# =========================
def sinusoidal_codes_from_ids(ids: torch.Tensor, r: int) -> torch.Tensor:
"""
ids: [...], integer in [0, S)
returns: [..., r] FP32 with [cos, sin] concatenation
"""
assert r % 2 == 0, "axis_qdim r must be even"
device = ids.device
r_half = r // 2
# Log-spaced frequencies in (10^0 .. 10^-1); any smooth spectrum is fine
freq = torch.logspace(0.0, -1.0, steps=r_half, base=10.0, device=device, dtype=torch.float32)
angle = ids.to(dtype=torch.float32).unsqueeze(-1) * freq.unsqueeze(0) * (2.0 * math.pi)
cos = torch.cos(angle)
sin = torch.sin(angle)
return torch.cat([cos, sin], dim=-1).to(torch.float32)
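# Example (shape only, assuming r=8): sinusoidal_codes_from_ids(torch.arange(5), 8)
# returns a [5, 8] tensor -- 4 cosine features followed by 4 sine features per id.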
# ============================================================
# 1) RouterScoreTopB (legacy): stream S, keep per-axis b
# Computes q_axes internally (no normalization here)
# ============================================================
class RouterScoreTopBFunction(torch.autograd.Function):
@staticmethod
def forward(ctx,
x_tile: torch.Tensor, # [Bt, D] fp32
query_w: torch.Tensor, # [H, A, r, D] fp32
axis_size: int, # S
top_b: int, # b
s_tile: int): # chunk over S
"""
Returns:
vals_all: [Bt, H, A, b] (fp32)
idxs_all: [Bt, H, A, b] (int32)
"""
Bt, D = x_tile.shape
H, A, r, Dw = query_w.shape
assert D == Dw
S = int(axis_size)
b = min(int(top_b), S)
device = x_tile.device
# q_axes: [Bt,H,A,r] fp32 (NO normalization here)
q_axes = torch.einsum('bd,hard->bhar', x_tile, query_w)
# group: [H*A, Bt, r]
Q = q_axes.permute(1, 2, 0, 3).contiguous().view(H * A, Bt, r)
top_vals = torch.full((H * A, Bt, b), float('-inf'), device=device, dtype=torch.float32)
top_idx = torch.empty((H * A, Bt, b), device=device, dtype=torch.int32)
for s0 in range(0, S, s_tile):
s1 = min(S, s0 + s_tile)
s_ids = torch.arange(s0, s1, device=device, dtype=torch.int32) # [s_t]
K_rt = sinusoidal_codes_from_ids(s_ids, r).transpose(0, 1).contiguous() # [r, s_t] fp32
logits_tile = torch.matmul(Q, K_rt) # [H*A, Bt, s_t] fp32
merged_vals = torch.cat([top_vals, logits_tile], dim=-1) # [H*A, Bt, b + s_t]
merged_idx = torch.cat([
top_idx,
torch.arange(s0, s1, device=device, dtype=torch.int32)[None, None, :].expand(H * A, Bt, s1 - s0)
], dim=-1) # [H*A, Bt, b + s_t]
keep_vals, keep_pos = torch.topk(merged_vals, k=b, dim=-1) # [H*A, Bt, b]
keep_idx = merged_idx.gather(-1, keep_pos)
top_vals, top_idx = keep_vals, keep_idx
vals_all = top_vals.view(H, A, Bt, b).permute(2, 0, 1, 3).contiguous() # [Bt,H,A,b]
idxs_all = top_idx.view(H, A, Bt, b).permute(2, 0, 1, 3).contiguous() # [Bt,H,A,b]
# Save tensors that are needed to backprop through q = einsum(x, W)
ctx.save_for_backward(x_tile, query_w, idxs_all)
ctx.meta = (H, A, r, D, b, S, s_tile)
return vals_all, idxs_all
@staticmethod
def backward(ctx, g_vals_all: torch.Tensor, g_idxs_all: Optional[torch.Tensor] = None):
x_tile, query_w, idxs_all = ctx.saved_tensors
H, A, r, D, b, S, s_tile = ctx.meta
if g_vals_all is None:
return (torch.zeros_like(x_tile), torch.zeros_like(query_w), None, None, None)
# Gather code vectors for selected indices: [Bt,H,A,b,r]
codes = sinusoidal_codes_from_ids(idxs_all, r) # fp32
# grad wrt q_axes: sum_b g_vals[..., b] * code[..., b, :]
g_q = (g_vals_all.unsqueeze(-1) * codes).sum(dim=3) # [Bt,H,A,r] fp32
# q_axes = einsum('bd,hard->bhar')
g_x = torch.einsum('nhar,hard->nd', g_q, query_w) # [Bt,D]
g_w = torch.einsum('nhar,nd->hard', g_q, x_tile) # [H,A,r,D]
return g_x, g_w, None, None, None
def router_score_topb(x_tile: torch.Tensor,
query_w: torch.Tensor,
axis_size: int,
top_b: int,
s_tile: int) -> Tuple[torch.Tensor, torch.Tensor]:
return RouterScoreTopBFunction.apply(x_tile, query_w, axis_size, top_b, s_tile)
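# Shape sketch (illustrative): x_tile [Bt,D] and query_w [H,A,r,D] with axis_size=S
# and top_b=b return vals_all/idxs_all of shape [Bt,H,A,b]: per (head, axis), the b
# best axis ids and their unnormalized scores, streamed over S in s_tile-sized chunks.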
# ============================================================
# 1b) RouterScoreTopBFromQueries: same as above but takes
# precomputed (and normalized) queries, returns grad wrt q
# ============================================================
class RouterScoreTopBFromQueriesFunction(torch.autograd.Function):
@staticmethod
def forward(ctx,
q_axes: torch.Tensor, # [Bt,H,A,r] fp32 (already normalized)
axis_size: int, # S
top_b: int, # b
s_tile: int): # chunk over S
"""
Returns:
vals_all: [Bt, H, A, b] (fp32)
idxs_all: [Bt, H, A, b] (int32)
"""
Bt, H, A, r = q_axes.shape
S = int(axis_size)
b = min(int(top_b), S)
device = q_axes.device
# group: [H*A, Bt, r]
Q = q_axes.permute(1, 2, 0, 3).contiguous().view(H * A, Bt, r)
top_vals = torch.full((H * A, Bt, b), float('-inf'), device=device, dtype=torch.float32)
top_idx = torch.empty((H * A, Bt, b), device=device, dtype=torch.int32)
for s0 in range(0, S, s_tile):
s1 = min(S, s0 + s_tile)
s_ids = torch.arange(s0, s1, device=device, dtype=torch.int32) # [s_t]
K_rt = sinusoidal_codes_from_ids(s_ids, r).transpose(0, 1).contiguous() # [r, s_t] fp32
logits_tile = torch.matmul(Q, K_rt) # [H*A, Bt, s_t] fp32
merged_vals = torch.cat([top_vals, logits_tile], dim=-1) # [H*A, Bt, b + s_t]
merged_idx = torch.cat([
top_idx,
torch.arange(s0, s1, device=device, dtype=torch.int32)[None, None, :].expand(H * A, Bt, s1 - s0)
], dim=-1) # [H*A, Bt, b + s_t]
keep_vals, keep_pos = torch.topk(merged_vals, k=b, dim=-1) # [H*A, Bt, b]
keep_idx = merged_idx.gather(-1, keep_pos)
top_vals, top_idx = keep_vals, keep_idx
vals_all = top_vals.view(H, A, Bt, b).permute(2, 0, 1, 3).contiguous() # [Bt,H,A,b]
idxs_all = top_idx.view(H, A, Bt, b).permute(2, 0, 1, 3).contiguous() # [Bt,H,A,b]
# Save only what's needed to backprop to q_axes
ctx.save_for_backward(idxs_all)
ctx.meta = (r,)
return vals_all, idxs_all
@staticmethod
def backward(ctx, g_vals_all: torch.Tensor, g_idxs_all: Optional[torch.Tensor] = None):
(idxs_all,) = ctx.saved_tensors
(r,) = ctx.meta
        if g_vals_all is None:
            # No upstream gradient: return a correctly shaped zero grad for q_axes ([Bt,H,A,r])
            Bt, H, A, b = idxs_all.shape
            return idxs_all.new_zeros((Bt, H, A, r), dtype=torch.float32), None, None, None
# codes: [Bt,H,A,b,r]
codes = sinusoidal_codes_from_ids(idxs_all, r)
# gradient wrt q_axes: [Bt,H,A,r]
g_q = (g_vals_all.unsqueeze(-1) * codes).sum(dim=3)
return g_q, None, None, None
def router_score_topb_from_queries(q_axes: torch.Tensor,
axis_size: int,
top_b: int,
s_tile: int) -> Tuple[torch.Tensor, torch.Tensor]:
return RouterScoreTopBFromQueriesFunction.apply(q_axes, axis_size, top_b, s_tile)
# ============================================================
# 2) BeamSearchAdditive: additive combine; grads to per-axis b
# ============================================================
class BeamSearchAdditiveFunction(torch.autograd.Function):
@staticmethod
def forward(ctx,
vals_all: torch.Tensor, # [Bt,H,A,b] fp32
idxs_all: torch.Tensor, # [Bt,H,A,b] int32
top_k: int,
beam_width: int):
device = vals_all.device
Bt, H, A, b = vals_all.shape
Bwidth = int(beam_width)
beam_scores = torch.full((Bt, H, Bwidth), float('-inf'), device=device, dtype=torch.float32)
beam_idx = torch.empty(Bt, H, Bwidth, A, device=device, dtype=torch.int32)
v0 = vals_all[:, :, 0, :]
i0 = idxs_all[:, :, 0, :]
keep0 = min(Bwidth, b)
base_vals, base_pos = torch.topk(v0, k=keep0, dim=-1)
beam_scores[:, :, :keep0] = base_vals
beam_idx[:, :, :keep0, 0] = i0.gather(-1, base_pos)
cur = keep0
for a in range(1, A):
va = vals_all[:, :, a, :]
ia = idxs_all[:, :, a, :]
cur_scores = beam_scores[:, :, :cur].unsqueeze(-1) + va.unsqueeze(-2) # [Bt,H,cur,b]
flat = cur_scores.view(Bt, H, -1)
keep = min(Bwidth, flat.size(-1))
kv, kf = torch.topk(flat, k=keep, dim=-1)
parent, choose = kf // b, kf % b
parent_exp = parent.unsqueeze(-1).expand(-1, -1, -1, a)
if a > 0:
beam_idx[:, :, :keep, :a] = torch.gather(beam_idx[:, :, :cur, :a], 2, parent_exp)
beam_idx[:, :, :keep, a] = ia.gather(-1, choose)
beam_scores[:, :, :keep] = kv
cur = keep
K = min(int(top_k), cur)
final_vals, final_pos = torch.topk(beam_scores[:, :, :cur], k=K, dim=-1) # [Bt,H,K]
coords = torch.gather(beam_idx[:, :, :cur, :],
2,
final_pos.unsqueeze(-1).expand(-1, -1, -1, A)).to(torch.int16) # [Bt,H,K,A]
# Map back: which b-slot each coord came from
idxs_exp = idxs_all.unsqueeze(2) # [Bt,H,1,A,b]
coords_exp = coords.to(torch.int32).unsqueeze(-1) # [Bt,H,K,A,1]
eq = (idxs_exp == coords_exp) # [Bt,H,K,A,b]
choose_idx = torch.argmax(eq.to(torch.int32), dim=-1) # [Bt,H,K,A]
scores = F.softmax(final_vals, dim=-1) # [Bt,H,K] fp32
ctx.save_for_backward(scores, choose_idx)
ctx.meta = (Bt, H, A, b, K)
return scores, coords
@staticmethod
def backward(ctx, g_scores: torch.Tensor, g_coords: Optional[torch.Tensor] = None):
scores, choose_idx = ctx.saved_tensors
Bt, H, A, b, K = ctx.meta
if g_scores is None:
return torch.zeros((Bt, H, A, b), dtype=torch.float32, device=scores.device), None, None, None
# Softmax backward: g_z = (g - <g,s>) * s
g = g_scores
s = scores
dot = (g * s).sum(dim=-1, keepdim=True)
g_z = (g - dot) * s # [Bt,H,K]
# Scatter back to per-axis b slots
g_vals_all = torch.zeros((Bt, H, A, b), dtype=torch.float32, device=scores.device)
for a in range(A):
tgt = g_vals_all[:, :, a, :]
idx = choose_idx[:, :, :, a]
tgt.scatter_add_(dim=-1, index=idx, src=g_z)
return g_vals_all, None, None, None
def beam_search_additive(vals_all: torch.Tensor,
idxs_all: torch.Tensor,
top_k: int,
beam_width: int) -> Tuple[torch.Tensor, torch.Tensor]:
return BeamSearchAdditiveFunction.apply(vals_all, idxs_all, top_k, beam_width)
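# End-to-end routing sketch (hypothetical sizes): with Bt=4, H=2, A=3, r=8,
#   vals, idxs = router_score_topb(x, query_w, axis_size=64, top_b=8, s_tile=32)
#   scores, coords = beam_search_additive(vals, idxs, top_k=4, beam_width=8)
# yields scores [4,2,4] (softmax over the 4 kept beams) and coords [4,2,4,3]
# (one selected axis id per axis for each beam).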
# ============================================================
# ETHOS module (FP32/TF32) — logits directly from dictionaries
# ============================================================
class BeamRouter(nn.Module):
def __init__(self, d_model, num_heads=4, num_axes=6, axis_qdim=32,
axis_size=512, top_k=8, beam_axis_b=32, beam_width=64, s_tile=512):
super().__init__()
self.D, self.H, self.A, self.r = int(d_model), int(num_heads), int(num_axes), int(axis_qdim)
self.S, self.K, self.bw, self.b = int(axis_size), int(top_k), int(beam_width), int(beam_axis_b)
self.s_tile = int(s_tile)
self.query_w = nn.Parameter(torch.empty(self.H, self.A, self.r, self.D, dtype=torch.float32))
with torch.no_grad():
self.query_w.normal_(0.0, (1.0 / self.D) ** 0.5)
# Per-axis RMSNorm over r (direction-preserving)
self.q_norm = nn.RMSNorm(self.r, eps=1e-6, elementwise_affine=False)
# Optional: learned per-(H,A) temperature to re-calibrate cross-axis scale
self.logit_scale = nn.Parameter(torch.zeros(self.H, self.A)) # exp() ~ 1.0
def forward(self, x: torch.Tensor):
# Form queries and normalize per (B,H,A) over r
q_axes = torch.einsum('bd,hard->bhar', x, self.query_w) # [B,H,A,r]
q_axes = self.q_norm(q_axes) # [B,H,A,r]
# Route using from-queries op (clean backward through RMSNorm + einsum)
vals, idxs = router_score_topb_from_queries(q_axes, axis_size=self.S, top_b=self.b, s_tile=self.s_tile)
# Optional temperature
vals = vals * self.logit_scale.exp()[None, :, :, None]
scores, coords = beam_search_additive(vals, idxs, top_k=self.K, beam_width=self.bw)
return scores, coords
class EthosGeneratedMLP(nn.Module):
"""
ETHOS path over flattened x:[B,D] (FP32/TF32):
router -> W1_blocks/W2 -> (alpha,beta) -> dict(B_in/B_out)
+ FiLM (gamma,beta) over M -> vector gating -> mean heads -> logits:[B,C]
No separate classifier; logits come from B_out directly.
"""
def __init__(self, d_in, n_classes,
H=4, A=6, r=32, S=512, K=8, b=32, beam_width=64,
Hh=128, M=32, film_eps=0.10, film_tanh_gamma=True, per_head_dict=True, s_tile=512):
super().__init__()
self.D, self.C = int(d_in), int(n_classes)
self.film_eps = float(film_eps)
self.film_tanh_gamma = bool(film_tanh_gamma)
self.per_head_dict = bool(per_head_dict)
# Router (pure-ID)
self.router = BeamRouter(d_model=d_in, num_heads=H, num_axes=A, axis_qdim=r,
axis_size=S, top_k=K, beam_axis_b=b, beam_width=beam_width, s_tile=s_tile)
# Hypernet + dicts
self.H, self.A, self.r, self.Hh, self.M = H, self.router.A, r, int(Hh), int(M)
self.W1_blocks = nn.Parameter(torch.empty(H, self.A, r, self.Hh, dtype=torch.float32)) # [H,A,r,Hh]
self.W2 = nn.Parameter(torch.empty(H, self.Hh, 2*self.M, dtype=torch.float32)) # [H,Hh,2M]
with torch.no_grad():
nn.init.xavier_uniform_(self.W1_blocks.view(H*self.A, r, self.Hh))
nn.init.xavier_uniform_(self.W2)
if self.per_head_dict:
self.B_in = nn.Parameter(torch.empty(H, self.M, self.D, dtype=torch.float32)) # [H,M,D]
self.B_out = nn.Parameter(torch.empty(H, self.M, self.C, dtype=torch.float32)) # [H,M,C] <-- logits direct
with torch.no_grad():
nn.init.xavier_uniform_(self.B_in)
nn.init.xavier_uniform_(self.B_out)
self.B_in_shared = None
self.B_out_shared = None
else:
self.B_in_shared = nn.Parameter(torch.empty(self.M, self.D, dtype=torch.float32)) # [M,D]
self.B_out_shared = nn.Parameter(torch.empty(self.M, self.C, dtype=torch.float32)) # [M,C] <-- logits direct
with torch.no_grad():
nn.init.xavier_uniform_(self.B_in_shared)
nn.init.xavier_uniform_(self.B_out_shared)
self.B_in = None
self.B_out = None
# FiLM parameters (zeros init)
self.gamma_w = nn.Parameter(torch.zeros(H, self.M, self.Hh, dtype=torch.float32)) # [H,M,Hh]
self.gamma_b = nn.Parameter(torch.zeros(H, self.M, dtype=torch.float32)) # [H,M]
self.beta_w = nn.Parameter(torch.zeros(H, self.M, self.Hh, dtype=torch.float32)) # [H,M,Hh]
self.beta_b = nn.Parameter(torch.zeros(H, self.M, dtype=torch.float32)) # [H,M]
def forward(self, x: torch.Tensor) -> torch.Tensor: # x:[B,D] fp32
B, D = x.shape
H, A, r, Hh, S, M, C = self.H, self.A, self.r, self.Hh, self.router.S, self.M, self.C
# 1) Router
scores, coords = self.router(x) # [B,H,K], [B,H,K,A]
K = scores.size(2)
# 2) Hyper features from selected coords (FP32/TF32)
ids_flat = coords.reshape(-1).long()
feats = sinusoidal_codes_from_ids(ids_flat, r).view(B, H, K, A, r) # [B,H,K,A,r]
z_pre = torch.einsum('bhkar,harq->bhkq', feats, self.W1_blocks) # [B,H,K,Hh]
g = gelu_tanh(z_pre) # [B,H,K,Hh]
AB = torch.einsum('bhkq,hqm->bhkm', g, self.W2) # [B,H,K,2M]
alpha = AB[..., :M] # [B,H,K,M]
beta = AB[..., M:] # [B,H,K,M]
# 3) Dictionaries: project x, then FiLM on M-dim
if self.per_head_dict:
h_proj = torch.einsum('bd,hmd->bhm', x, self.B_in) # [B,H,M]
else:
h_proj = torch.einsum('bd,md->bm', x, self.B_in_shared) # [B,M]
h_proj = h_proj.unsqueeze(1).expand(-1, H, -1).contiguous() # [B,H,M]
w = scores.unsqueeze(-1) # [B,H,K,1]
gamma_k = torch.einsum('bhkq,hmq->bhkm', g, self.gamma_w) + self.gamma_b.unsqueeze(0).unsqueeze(2)
if self.film_tanh_gamma:
gamma_k = torch.tanh(gamma_k)
beta_k = torch.einsum('bhkq,hmq->bhkm', g, self.beta_w) + self.beta_b.unsqueeze(0).unsqueeze(2)
gamma_bar = (w * gamma_k).sum(dim=2) # [B,H,M]
beta_bar = (w * beta_k).sum(dim=2) # [B,H,M]
h_mod = (1.0 + self.film_eps * gamma_bar) * h_proj + self.film_eps * beta_bar # [B,H,M]
# 4) Vector gate combine over beams
act_M = gelu_tanh(alpha * h_mod.unsqueeze(2)) # [B,H,K,M]
m = (w * act_M * beta).sum(dim=2) # [B,H,M]
# 5) To logits via B_out, average heads ---> No separate classifier
if self.per_head_dict:
logits_h = torch.einsum('bhm,hmc->bhc', m, self.B_out) # [B,H,C]
else:
logits_h = torch.einsum('bhm,mc->bhc', m, self.B_out_shared) # [B,H,C]
logits = logits_h.mean(dim=1) # [B,C]
return logits
# (Optional) param counter for logging
def ethos_param_count(model: EthosGeneratedMLP) -> int:
P = 0
P += model.router.query_w.numel()
P += model.W1_blocks.numel()
P += model.W2.numel()
if model.per_head_dict:
P += model.B_in.numel() + model.B_out.numel()
else:
P += model.B_in_shared.numel() + model.B_out_shared.numel()
P += model.gamma_w.numel() + model.gamma_b.numel()
P += model.beta_w.numel() + model.beta_b.numel()
return P
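# ------------------------------------------------------------
# Minimal smoke test (assumed entry point, not part of training):
# builds a tiny EthosGeneratedMLP on random data and checks that
# forward/backward run with the expected shapes. The dimensions
# below are illustrative, not the sweep settings.
# ------------------------------------------------------------
if __name__ == "__main__":
    torch.manual_seed(0)
    B, D, C = 8, 64, 10
    model = EthosGeneratedMLP(d_in=D, n_classes=C,
                              H=2, A=3, r=8, S=32, K=4, b=8, beam_width=8,
                              Hh=16, M=8, s_tile=16)
    x = torch.randn(B, D)
    logits = model(x)                                   # [B, C]
    assert logits.shape == (B, C)
    loss = F.cross_entropy(logits, torch.randint(0, C, (B,)))
    loss.backward()                                     # exercises the custom autograd paths
    print(f"smoke test ok: logits {tuple(logits.shape)}, params {ethos_param_count(model):,}")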
# Sweep config (YAML for --config): every key maps to a list of candidate values;
# the runner grid-searches the Cartesian product of all lists.
# Fixed arch
use_mha: [true]
mha_heads: [6]
mha_embed_dim: [384]
mha_depth: [6]
mha_patch: [4]
mha_dropout: [0.1]
H: [1]
A: [8]
r: [64]
S: [2048]
b: [64]
beam: [32]
K: [16]
Hh: [1024]
M: [512]
share_dict: [false]
classes: [10]
# Baseline
mlp_depth: [1]
mlp_width: [0]
mlp_norm: ["ln"]
mlp_dropout: [0.0]
# Training HPs to sweep
epochs: [200]
batch_size: [2048]
lr: [0.0006, 0.0003]
lr_ethos: [null]
lr_mlp: [null]
lr_router: [0.00144, 0.00108, 0.00072, 0.00036] # ETHOS-only
lr_hypernet: [0.0006, 0.00048, 0.00036, 0.00024, 0.00012] # ETHOS-only
weight_decay: [0.0005]
warmup_epochs: [10]
min_lr: [0.000001]
grad_clip: [100.0, 1.0]
seed: [42]
workers: [8]
data: ["./data"]
# EMA + logging
metric_ema_decay: [0.9]
metric_ema_bias_correct: [true]
sweep_name: ["flat_vs_ethos"]
log_root: ["logs"]
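# Usage sketch (assumed file name for this config):
#   python train_flat_ethos_vs_mlp_sweep_cached.py --config flat_vs_ethos.yaml \
#       --sweep_name flat_vs_ethos --log_root logs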
# train_flat_ethos_vs_mlp_sweep_cached.py
from __future__ import annotations
import argparse, math, os, random, sys, itertools, json, time, gc, hashlib
from dataclasses import dataclass, asdict, field
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Tuple
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
# Requires PyYAML (pip install pyyaml)
import yaml
from ethos_gen_mlp import EthosGeneratedMLP  # module defined above (ethos_gen_mlp.py)
# Enable TF32 globally
try:
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.set_float32_matmul_precision("high")
except Exception:
pass
# ----------------------------
# Reproducibility helpers
# ----------------------------
def set_seed(seed: int):
random.seed(seed); np.random.seed(seed); torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
# ----------------------------
# Vision preprocessor (ViT-style)
# ----------------------------
class PatchEmbed(nn.Module):
"""
Patchify image via a Conv2d: kernel=stride=patch_size, produces tokens.
Input: [B,3,H,W]
Output: [B, N, E] where N=(H/ps)*(W/ps), E=embed_dim
"""
def __init__(self, img_size: int = 32, patch_size: int = 4, in_chans: int = 3, embed_dim: int = 384):
super().__init__()
assert img_size % patch_size == 0, "img_size must be divisible by patch_size"
self.img_size = img_size
self.patch_size = patch_size
self.grid = img_size // patch_size
self.num_patches = self.grid * self.grid
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.proj(x) # [B,E,H/ps,W/ps]
x = x.flatten(2).transpose(1, 2) # [B,N,E]
return x
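# Example: a CIFAR-10 batch [B,3,32,32] with patch_size=4 and embed_dim=384 yields
# (32/4)**2 = 64 tokens per image, i.e. an output of shape [B, 64, 384].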
class TransformerBlock(nn.Module):
"""
Pre-LN Transformer encoder block.
"""
def __init__(self, embed_dim: int, num_heads: int, mlp_ratio: float = 4.0, dropout: float = 0.1):
super().__init__()
self.norm1 = nn.LayerNorm(embed_dim)
self.attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, batch_first=True)
self.drop1 = nn.Dropout(dropout)
self.norm2 = nn.LayerNorm(embed_dim)
hidden_dim = int(embed_dim * mlp_ratio)
self.mlp = nn.Sequential(
nn.Linear(embed_dim, hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, embed_dim),
nn.Dropout(dropout),
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
xa = self.norm1(x)
a, _ = self.attn(xa, xa, xa)
x = x + self.drop1(a)
xm = self.norm2(x)
x = x + self.mlp(xm)
return x
class VisionPreprocessor(nn.Module):
"""
ViT-style preprocessor that outputs a single vector [B, E] via mean-pooling tokens.
This is a proper vision attention mechanism (tokens > 1).
"""
def __init__(self, img_size: int = 32, patch_size: int = 4, embed_dim: int = 384,
depth: int = 2, num_heads: int = 6, dropout: float = 0.1):
super().__init__()
self.patch = PatchEmbed(img_size=img_size, patch_size=patch_size, embed_dim=embed_dim)
self.pos_embed = nn.Parameter(torch.zeros(1, self.patch.num_patches, embed_dim))
self.pos_drop = nn.Dropout(dropout)
self.blocks = nn.ModuleList([
TransformerBlock(embed_dim=embed_dim, num_heads=num_heads, mlp_ratio=4.0, dropout=dropout)
for _ in range(depth)
])
self.norm = nn.LayerNorm(embed_dim)
nn.init.trunc_normal_(self.pos_embed, std=0.02)
def forward(self, x: torch.Tensor) -> torch.Tensor:
t = self.patch(x)
t = t + self.pos_embed
t = self.pos_drop(t)
for blk in self.blocks:
t = blk(t)
t = self.norm(t)
z = t.mean(dim=1)
return z # [B, E]
# ----------------------------
# CIFAR-10 transforms
# ----------------------------
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2470, 0.2435, 0.2616)
def make_loaders(batch_size: int, num_workers: int = 8, data_root: str = "./data"):
train_tf = transforms.Compose([
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])
test_tf = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])
train_set = datasets.CIFAR10(data_root, train=True, transform=train_tf, download=True)
test_set = datasets.CIFAR10(data_root, train=False, transform=test_tf, download=True)
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=num_workers,
pin_memory=True, drop_last=True)
test_loader = DataLoader(test_set, batch_size=1024, shuffle=False, num_workers=num_workers,
pin_memory=True, drop_last=False)
return train_loader, test_loader
# ----------------------------
# Iso-parameter MLP baseline (variable depth)
# ----------------------------
class IsoParamMLP(nn.Module):
"""
Depth-L MLP with optional norm/dropout. Hidden width solved to match ETHOS core params unless --mlp_width is provided.
"""
def __init__(self, d_in: int, num_classes: int, width: int, depth: int,
norm: str = "ln", dropout: float = 0.0):
super().__init__()
assert depth >= 1, "mlp_depth must be >= 1"
self.depth = depth
self.norm_type = norm
self.dropout_p = float(dropout)
layers = []
in_dim = d_in
for _ in range(depth):
layers.append(nn.Linear(in_dim, width, bias=True))
if norm == "ln":
layers.append(nn.LayerNorm(width))
elif norm == "bn":
layers.append(nn.BatchNorm1d(width))
elif norm == "none":
pass
else:
raise ValueError(f"Unknown norm: {norm}")
layers.append(nn.GELU())
if self.dropout_p > 0:
layers.append(nn.Dropout(self.dropout_p))
in_dim = width
self.hidden = nn.Sequential(*layers)
self.fc_out = nn.Linear(width, num_classes, bias=True)
for m in self.modules():
if isinstance(m, nn.Linear):
if m is self.fc_out:
nn.init.xavier_uniform_(m.weight)
nn.init.zeros_(m.bias)
else:
nn.init.kaiming_uniform_(m.weight, a=math.sqrt(5))
nn.init.zeros_(m.bias)
elif isinstance(m, (nn.LayerNorm, nn.BatchNorm1d)):
nn.init.ones_(m.weight); nn.init.zeros_(m.bias)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = x.view(x.size(0), -1)
x = self.hidden(x)
return self.fc_out(x)
def mlp_param_count(d_in: int, num_classes: int, width: int, depth: int, norm: str) -> int:
linear = d_in * width + width # first
if depth >= 2:
linear += (depth - 1) * (width * width + width)
linear += width * num_classes + num_classes # out
norm_params_per_layer = 0
if norm in ("ln", "bn"):
norm_params_per_layer = 2 * width
norm_total = depth * norm_params_per_layer
return linear + norm_total
def solve_width_for_budget(d_in: int, num_classes: int, depth: int, norm: str, target_params: int) -> int:
lo, hi = 1, max(2, int(math.sqrt(max(1, target_params)))) * 4
best_w, best_diff = 1, float("inf")
while lo <= hi:
mid = (lo + hi) // 2
p = mlp_param_count(d_in, num_classes, mid, depth, norm)
diff = abs(p - target_params)
if p <= target_params and diff < best_diff:
best_w, best_diff = mid, diff
if p <= target_params:
lo = mid + 1
else:
hi = mid - 1
return max(1, best_w)
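# Worked example (illustrative): for flattened CIFAR-10 input d_in=3072, classes=10,
# depth=1, norm="ln", a hidden width W costs
#   3072*W + W        (input Linear)
# + 10*W + 10         (output Linear)
# + 2*W               (LayerNorm)
# = 3085*W + 10 parameters, so solve_width_for_budget returns the largest W with
# 3085*W + 10 <= target_params.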
# ----------------------------
# Optim / Schedules
# ----------------------------
class WarmupCosineLR(torch.optim.lr_scheduler.LRScheduler):  # public base class (the underscored alias is deprecated)
def __init__(self, optimizer, warmup_epochs, total_epochs, base_lr=None, min_lr=1e-6, last_epoch=-1):
self.warmup_epochs = warmup_epochs
self.total_epochs = total_epochs
self.min_lr = min_lr
self.use_group_base_lr = (base_lr is None)
self.base_lr = base_lr
super().__init__(optimizer, last_epoch)
def get_lr(self):
epoch = self.last_epoch + 1
if epoch < self.warmup_epochs:
scale = epoch / max(1, self.warmup_epochs)
if self.use_group_base_lr:
return [scale * base_lr for base_lr in self.base_lrs]
else:
return [scale * self.base_lr for _ in self.base_lrs]
t = (epoch - self.warmup_epochs) / max(1, self.total_epochs - self.warmup_epochs)
cos = 0.5 * (1 + math.cos(math.pi * t))
if self.use_group_base_lr:
return [self.min_lr + (base_lr - self.min_lr) * cos for base_lr in self.base_lrs]
else:
lr = self.min_lr + (self.base_lr - self.min_lr) * cos
return [lr for _ in self.base_lrs]
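# Schedule sketch (per-epoch stepping, as used below): during warmup
#   lr = base_lr * epoch / warmup_epochs,
# afterwards a cosine decay
#   lr = min_lr + (base_lr - min_lr) * 0.5 * (1 + cos(pi * t)),
# with t running from 0 to 1 over the remaining epochs. With base_lr=None, each
# param group follows this curve from its own base LR.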
# ----------------------------
# Metric EMA helper
# ----------------------------
class ScalarEMA:
"""
EMA for scalar metrics (e.g., accuracy in %).
If bias_correct=True, EMA_1 equals the first value.
"""
def __init__(self, decay: float = 0.9, bias_correct: bool = True):
self.decay = float(decay)
self.bias_correct = bool(bias_correct)
self.t = 0
self.m = 0.0
def update(self, x: float) -> float:
x = float(x)
self.t += 1
self.m = self.decay * self.m + (1.0 - self.decay) * x
if self.bias_correct:
denom = 1.0 - (self.decay ** self.t)
return self.m / max(1e-12, denom)
return self.m
def value(self) -> float | None:
if self.t == 0:
return None
if self.bias_correct:
denom = 1.0 - (self.decay ** self.t)
return self.m / max(1e-12, denom)
return self.m
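# Example: with decay=0.9 and bias_correct=True, feeding accuracies 10.0 then 20.0
# returns 10.0 first (EMA_1 equals the first value) and then
# (0.9*1.0 + 0.1*20.0) / (1 - 0.9**2) = 2.9 / 0.19 ~ 15.26.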
# ----------------------------
# Train / Eval loops (shared)
# ----------------------------
def evaluate(model: nn.Module, loader: DataLoader, device: torch.device) -> float:
model.eval()
correct = 0
total = 0
with torch.no_grad():
for xb, yb in loader:
xb = xb.to(device, non_blocking=True)
yb = yb.to(device, non_blocking=True)
logits = model(xb)
pred = logits.argmax(dim=1)
correct += (pred == yb).sum().item()
total += yb.numel()
return 100.0 * correct / max(1, total)
def train_one_epoch(model: nn.Module, loader: DataLoader, device: torch.device,
optimizer: torch.optim.Optimizer, criterion: nn.Module,
grad_clip: float):
model.train()
for xb, yb in loader:
xb = xb.to(device, non_blocking=True)
yb = yb.to(device, non_blocking=True)
logits = model(xb)
loss = criterion(logits, yb)
optimizer.zero_grad(set_to_none=True)
loss.backward()
if grad_clip is not None and grad_clip > 0:
torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
optimizer.step()
# ----------------------------
# Args + parsing
# ----------------------------
@dataclass
class Args:
# Core training
epochs: int = 100
batch_size: int = 512
lr: float = 3e-4
lr_ethos: float | None = None
lr_mlp: float | None = None
lr_router: float | None = None
lr_hypernet: float | None = None
weight_decay: float = 5e-4
warmup_epochs: int = 5
min_lr: float = 1e-6
grad_clip: float = 1.0
seed: int = 42
workers: int = 8
data: str = "./data"
# ETHOS knobs
H: int = 1
A: int = 16
r: int = 64
S: int = 256
b: int = 64
beam: int = 128
K: int = 16
Hh: int = 256
M: int = 64
share_dict: bool = False
classes: int = 10
# Baseline controls
mlp_depth: int = 1
mlp_width: int = 0
mlp_norm: str = "ln"
mlp_dropout: float = 0.0
# Vision preprocessor controls
use_mha: bool = False
mha_heads: int = 6
mha_embed_dim: int = 384
mha_depth: int = 2
mha_patch: int = 4
mha_dropout: float = 0.1
# Metric EMA
metric_ema_decay: float = 0.9
metric_ema_bias_correct: bool = True
# Sweep / logging
config: str | None = None # path to YAML sweep file
sweep_name: str = "sweep"
log_root: str = "logs"
# Baseline cache controls
baseline_cache_dir: str | None = None # default resolves to <log_root>/cache
baseline_cache_ignore_seed: bool = False
baseline_cache_refresh: bool = False
# internal (not exposed to YAML by default)
_trial_id: str = field(default="", repr=False)
def parse_args() -> Args:
p = argparse.ArgumentParser()
# Core training
p.add_argument("--epochs", type=int, default=100)
p.add_argument("--batch_size", type=int, default=512)
p.add_argument("--lr", type=float, default=3e-4)
p.add_argument("--lr_ethos", type=float, default=None)
p.add_argument("--lr_mlp", type=float, default=None)
p.add_argument("--lr_router", type=float, default=None)
p.add_argument("--lr_hypernet", type=float, default=None)
p.add_argument("--weight_decay", type=float, default=5e-4)
p.add_argument("--warmup_epochs", type=int, default=100) # keep previous CLI default
p.add_argument("--min_lr", type=float, default=1e-6)
p.add_argument("--grad_clip", type=float, default=100.0)
p.add_argument("--seed", type=int, default=42)
p.add_argument("--workers", type=int, default=8)
p.add_argument("--data", type=str, default="./data")
# ETHOS
p.add_argument("--H", type=int, default=1)
p.add_argument("--A", type=int, default=16)
p.add_argument("--r", type=int, default=64)
p.add_argument("--S", type=int, default=256)
p.add_argument("--b", type=int, default=64)
p.add_argument("--beam", type=int, default=128)
p.add_argument("--K", type=int, default=16)
p.add_argument("--Hh", type=int, default=256)
p.add_argument("--M", type=int, default=64)
p.add_argument("--share_dict", action="store_true")
p.add_argument("--classes", type=int, default=10)
# Baseline
p.add_argument("--mlp_depth", type=int, default=1)
p.add_argument("--mlp_width", type=int, default=0)
p.add_argument("--mlp_norm", type=str, default="ln", choices=["ln", "bn", "none"])
p.add_argument("--mlp_dropout", type=float, default=0.0)
# Vision preprocessor
p.add_argument("--use_mha", action="store_true")
p.add_argument("--mha_heads", type=int, default=6)
p.add_argument("--mha_embed_dim", type=int, default=384)
p.add_argument("--mha_depth", type=int, default=2)
p.add_argument("--mha_patch", type=int, default=4)
p.add_argument("--mha_dropout", type=float, default=0.1)
# Metric EMA
p.add_argument("--metric_ema_decay", type=float, default=0.9)
p.add_argument("--metric_ema_bias_correct", action="store_true")
# Sweep + logging
p.add_argument("--config", type=str, default=None, help="YAML file with lists for each key to grid-search")
p.add_argument("--sweep_name", type=str, default="sweep")
p.add_argument("--log_root", type=str, default="logs")
# Baseline cache controls
p.add_argument("--baseline_cache_dir", type=str, default=None)
p.add_argument("--baseline_cache_ignore_seed", action="store_true")
p.add_argument("--baseline_cache_refresh", action="store_true")
a = p.parse_args()
if "metric_ema_bias_correct" not in vars(a) or vars(a)["metric_ema_bias_correct"] is False:
setattr(a, "metric_ema_bias_correct", True)
return Args(**vars(a))
# ----------------------------
# Baseline cache helpers
# ----------------------------
def mlp_cache_key(args: Args, ethos_core_params: int, d_in_core: int):
lr_mlp_eff = args.lr_mlp if args.lr_mlp is not None else args.lr
sig = {
# arch
"d_in": d_in_core,
"classes": args.classes,
"use_mha": args.use_mha,
"mha_heads": args.mha_heads,
"mha_embed_dim": args.mha_embed_dim,
"mha_depth": args.mha_depth,
"mha_patch": args.mha_patch,
"mha_dropout": args.mha_dropout,
"mlp_depth": args.mlp_depth,
"mlp_norm": args.mlp_norm,
"mlp_dropout": args.mlp_dropout,
# width: literal or auto@budget
"mlp_width": args.mlp_width if args.mlp_width > 0 else f"auto@{ethos_core_params}",
# train hparams
"epochs": args.epochs,
"batch_size": args.batch_size,
"lr_mlp": lr_mlp_eff,
"weight_decay": args.weight_decay,
"warmup_epochs": args.warmup_epochs,
"min_lr": args.min_lr,
"grad_clip": args.grad_clip,
}
if not args.baseline_cache_ignore_seed:
sig["seed"] = args.seed
key = hashlib.sha1(json.dumps(sig, sort_keys=True).encode()).hexdigest()[:16]
return key, sig
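# Note: lr_router, lr_hypernet and lr_ethos are not in the signature, so trials that
# differ only in those reuse one cached Iso-MLP baseline (lr still matters via
# lr_mlp_eff when lr_mlp is null); pass --baseline_cache_refresh to force retraining.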
def _cartesian_from_yaml(yaml_path: str) -> Tuple[List[str], List[Tuple[Any, ...]]]:
"""
Load YAML where each key maps to a list of candidate values.
Returns (keys, list_of_value_tuples)
"""
with open(yaml_path, "r") as f:
cfg = yaml.safe_load(f) or {}
keys = list(cfg.keys())
value_lists = []
for k in keys:
v = cfg[k]
if isinstance(v, list):
value_lists.append(v)
else:
value_lists.append([v])
combos = list(itertools.product(*value_lists))
return keys, combos
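# Example: {"lr": [3e-4, 6e-4], "seed": [42]} ->
#   keys   = ["lr", "seed"]
#   combos = [(3e-4, 42), (6e-4, 42)]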
def _apply_overrides(base: Args, overrides: Dict[str, Any]) -> Args:
d = asdict(base)
for k, v in overrides.items():
if k not in d:
raise KeyError(f"Unknown config key in sweep: {k}")
d[k] = v
d["_trial_id"] = ""
return Args(**d)
def _append_summary_line(summary_path: Path, record: Dict[str, Any]):
summary_path.parent.mkdir(parents=True, exist_ok=True)
with open(summary_path, "a") as f:
f.write(json.dumps(record) + "\n")
# ----------------------------
# Trial runner (single)
# ----------------------------
def run_single_trial(args: Args, trial_log_path: Path) -> Dict[str, Any]:
"""
Runs one training trial with the provided Args.
Logs to both stdout and trial_log_path, returns a summary dict.
"""
trial_log_path.parent.mkdir(parents=True, exist_ok=True)
lf = open(trial_log_path, "w", buffering=1)
def log(msg: str):
msg = str(msg)
print(msg)
lf.write(msg + "\n")
lf.flush()
# Header
log("=" * 80)
log(f"[TRIAL] {args._trial_id} | {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
log(f"[CONFIG] {json.dumps({k: v for k, v in asdict(args).items() if not k.startswith('_')}, default=str)}")
log("=" * 80)
# Seed + device
set_seed(args.seed)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type == "cuda":
gpu_name = torch.cuda.get_device_name(0)
log(f"Using GPU: {gpu_name}")
else:
log("Using CPU")
# Data
train_loader, test_loader = make_loaders(args.batch_size, args.workers, data_root=args.data)
# Preprocessor / core dims
if args.use_mha:
D_in_core = int(args.mha_embed_dim)
vision_e = VisionPreprocessor(img_size=32, patch_size=args.mha_patch,
embed_dim=args.mha_embed_dim, depth=args.mha_depth,
num_heads=args.mha_heads, dropout=args.mha_dropout)
        # The baseline may be served from cache, so its own preprocessor is built later, only if needed
pre_params_e = sum(p.numel() for p in vision_e.parameters())
log(f"[Vision Pre] Enabled (ViT-style)")
log(f" ETHOS-pre params: {pre_params_e:,} (patch={args.mha_patch}, depth={args.mha_depth}, heads={args.mha_heads}, dim={args.mha_embed_dim})")
else:
D_in_core = 3 * 32 * 32
pre_params_e = 0
vision_e = None
C = args.classes
# ETHOS core (exclude preproc from budget)
ethos_core = EthosGeneratedMLP(
d_in=D_in_core, n_classes=C,
H=args.H, A=args.A, r=args.r, S=args.S, K=args.K, b=args.b, beam_width=args.beam,
Hh=args.Hh, M=args.M, film_eps=0.1, film_tanh_gamma=True, per_head_dict=not args.share_dict, s_tile=512
)
# Compose full ETHOS model
if args.use_mha:
ethos = nn.Sequential(vision_e, ethos_core).to(device)
else:
ethos = nn.Sequential(nn.Flatten(), ethos_core).to(device)
# Iso MLP width matching ETHOS-core param budget (preproc excluded)
ethos_core_params = sum(p.numel() for p in ethos_core.parameters())
# ---------------- Cache setup for baseline ----------------
cache_root = Path(args.baseline_cache_dir) if args.baseline_cache_dir else (Path(args.log_root) / "cache")
cache_root.mkdir(parents=True, exist_ok=True)
cache_key, cache_sig = mlp_cache_key(args, ethos_core_params, D_in_core)
cache_path = cache_root / f"mlp_{cache_key}.summary.json"
baseline_cached = cache_path.exists() and (not args.baseline_cache_refresh)
best_m_cached = None
mlp_params_total_cached = None
if baseline_cached:
try:
with open(cache_path, "r") as f:
cached = json.load(f)
best_m_cached = float(cached["results"]["best_mlp"])
mlp_params_total_cached = int(cached["params"]["mlp_total"]) if "params" in cached and "mlp_total" in cached["params"] else None
log(f"[CACHE] Iso-MLP hit key={cache_key} | best={best_m_cached:.2f}% | params={mlp_params_total_cached if mlp_params_total_cached is not None else 'n/a'}")
except Exception as e:
log(f"[CACHE] Failed to read cache ({e}); retraining baseline.")
baseline_cached = False
# Determine learning rates
lr_ethos = args.lr_ethos if args.lr_ethos is not None else args.lr
lr_mlp = args.lr_mlp if args.lr_mlp is not None else args.lr
lr_router = args.lr_router if args.lr_router is not None else lr_ethos
lr_hypernet = args.lr_hypernet if args.lr_hypernet is not None else lr_ethos
# Optimizers / schedulers
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
# Param groups for ETHOS
router_params, hypernet_params, other_params = [], [], []
for name, param in ethos.named_parameters():
if args.use_mha and name.startswith('0.'):
other_params.append(param) # vision preproc
elif 'router.query_w' in name:
router_params.append(param)
elif any(key in name for key in ['W1_blocks', 'W2', 'B_in', 'B_out']):
hypernet_params.append(param)
elif any(key in name for key in ['gamma_w', 'gamma_b', 'beta_w', 'beta_b']):
other_params.append(param)
else:
other_params.append(param)
ethos_param_groups = []
if router_params: ethos_param_groups.append({'params': router_params, 'lr': lr_router})
if hypernet_params: ethos_param_groups.append({'params': hypernet_params, 'lr': lr_hypernet})
if other_params: ethos_param_groups.append({'params': other_params, 'lr': lr_ethos})
opt_e = torch.optim.AdamW(ethos_param_groups, lr=lr_ethos, weight_decay=args.weight_decay)
sched_e = WarmupCosineLR(opt_e, warmup_epochs=args.warmup_epochs, total_epochs=args.epochs,
base_lr=None, min_lr=args.min_lr)
# Baseline build only if no cache
iso_mlp = None
opt_m = None
sched_m = None
vision_m = None
pre_params_m = 0
if not baseline_cached:
if args.use_mha:
vision_m = VisionPreprocessor(img_size=32, patch_size=args.mha_patch,
embed_dim=args.mha_embed_dim, depth=args.mha_depth,
num_heads=args.mha_heads, dropout=args.mha_dropout)
pre_params_m = sum(p.numel() for p in vision_m.parameters())
# Solve width (if needed) & build
if args.mlp_width > 0:
iso_width = int(args.mlp_width)
else:
iso_width = solve_width_for_budget(D_in_core, C, depth=args.mlp_depth,
norm=args.mlp_norm, target_params=ethos_core_params)
iso_mlp_core = IsoParamMLP(d_in=D_in_core, num_classes=C, width=iso_width, depth=args.mlp_depth,
norm=args.mlp_norm, dropout=args.mlp_dropout)
if args.use_mha:
iso_mlp = nn.Sequential(vision_m, iso_mlp_core).to(device)
else:
iso_mlp = nn.Sequential(nn.Flatten(), iso_mlp_core).to(device)
opt_m = torch.optim.AdamW(iso_mlp.parameters(), lr=lr_mlp, weight_decay=args.weight_decay)
sched_m = WarmupCosineLR(opt_m, warmup_epochs=args.warmup_epochs, total_epochs=args.epochs,
base_lr=lr_mlp, min_lr=args.min_lr)
ethos_params_total = sum(p.numel() for p in ethos.parameters())
mlp_params_total = (sum(p.numel() for p in iso_mlp.parameters()) if iso_mlp is not None else mlp_params_total_cached)
log(f"[Budget] ETHOS total params: {ethos_params_total:,} (core: {ethos_core_params:,}, preproc: {pre_params_e:,})")
log(f"[Model] ETHOS total: {ethos_params_total:,} | Iso-MLP total: {mlp_params_total if mlp_params_total is not None else 'n/a'}")
log(f"[LR] ETHOS lr: {lr_ethos:.6f} | MLP lr: {lr_mlp:.6f}")
log(f"[LR] Router lr: {lr_router:.6f} | Hypernet lr: {lr_hypernet:.6f}")
log(f"[Param Groups] Router: {sum(p.numel() for p in router_params):,} | Hypernet: {sum(p.numel() for p in hypernet_params):,} | Other: {sum(p.numel() for p in other_params):,}")
ema_acc_ethos = ScalarEMA(decay=args.metric_ema_decay, bias_correct=args.metric_ema_bias_correct)
ema_acc_mlp = ScalarEMA(decay=args.metric_ema_decay, bias_correct=args.metric_ema_bias_correct)
best_e = 0.0
best_m = (0.0 if not baseline_cached else float(best_m_cached))
for epoch in range(1, args.epochs + 1):
# ETHOS train
train_one_epoch(ethos, train_loader, device, opt_e, criterion, grad_clip=args.grad_clip)
sched_e.step()
acc_e = evaluate(ethos, test_loader, device)
acc_e_ema = ema_acc_ethos.update(acc_e)
best_e = max(best_e, acc_e)
if baseline_cached:
log(
f"epoch {epoch}/{args.epochs} | "
f"ETHOS acc: {acc_e:5.2f}% (EMA: {acc_e_ema:5.2f}%) | "
f"Iso-MLP: [cached] | "
f"Delta(best raw): {best_e - best_m:5.2f}"
)
else:
# Baseline train
train_one_epoch(iso_mlp, train_loader, device, opt_m, criterion, grad_clip=args.grad_clip)
sched_m.step()
acc_m = evaluate(iso_mlp, test_loader, device)
acc_m_ema = ema_acc_mlp.update(acc_m)
best_m = max(best_m, acc_m)
log(
f"epoch {epoch}/{args.epochs} | "
f"ETHOS acc: {acc_e:5.2f}% (EMA: {acc_e_ema:5.2f}%) | "
f"Iso-MLP acc: {acc_m:5.2f}% (EMA: {acc_m_ema:5.2f}%) | "
f"Delta(best raw): {best_e - best_m:5.2f}"
)
log(f"[FINAL] ETHOS: best {best_e:.2f}% | Iso-MLP: {'cached ' if baseline_cached else ''}best {best_m:.2f}%")
lf.close()
# write cache if trained baseline
if not baseline_cached:
try:
summary_for_cache = {
"key": cache_key,
"sig": cache_sig,
"params": {"mlp_total": int(mlp_params_total) if mlp_params_total is not None else None},
"results": {"best_mlp": float(best_m)},
"timestamp": datetime.now().isoformat(timespec="seconds"),
}
with open(cache_path, "w") as f:
json.dump(summary_for_cache, f)
except Exception as e:
print(f"[CACHE] Failed to write cache: {e}", file=sys.stderr)
# free mem between trials
try:
del ethos, ethos_core
if not baseline_cached:
del iso_mlp, iso_mlp_core
if args.use_mha:
del vision_e
if not baseline_cached and vision_m is not None:
del vision_m
torch.cuda.empty_cache()
gc.collect()
except Exception:
pass
return {
"trial_id": args._trial_id,
"log_file": str(trial_log_path),
"config": {k: v for k, v in asdict(args).items() if not k.startswith('_')},
"params": {
"ethos_total": int(ethos_params_total),
"ethos_core": int(ethos_core_params),
"preproc": int(pre_params_e),
"mlp_total": (int(mlp_params_total) if mlp_params_total is not None else None),
},
"results": {
"best_ethos": float(best_e),
"best_mlp": float(best_m),
"delta": float(best_e - best_m),
},
"device": "cuda" if torch.cuda.is_available() else "cpu",
"timestamp": datetime.now().isoformat(timespec="seconds"),
"status": "ok",
"baseline_cache": {
"used": bool(baseline_cached),
"key": cache_key,
"path": str(cache_path),
}
}
# ----------------------------
# Main
# ----------------------------
def main():
args = parse_args()
log_root = Path(args.log_root)
log_root.mkdir(parents=True, exist_ok=True)
summary_path = log_root / "trials.log"
# Sweep mode: YAML provided
if args.config is not None:
keys, combos = _cartesian_from_yaml(args.config)
total = len(combos)
start_stamp = datetime.now().strftime("%Y%m%d-%H%M%S")
print(f"[SWEEP] {args.sweep_name} | {total} trial(s) from {args.config}")
print(f"[SWEEP] Logs in: {str(log_root.resolve())}")
for idx, values in enumerate(combos, start=1):
overrides = dict(zip(keys, values))
trial_args = _apply_overrides(args, overrides)
trial_args._trial_id = f"{args.sweep_name}_{start_stamp}_{idx:03d}"
trial_log_path = log_root / f"{trial_args._trial_id}.log"
try:
summary = run_single_trial(trial_args, trial_log_path)
except Exception as e:
summary = {
"trial_id": trial_args._trial_id,
"log_file": str(trial_log_path),
"config": {k: v for k, v in asdict(trial_args).items() if not k.startswith('_')},
"status": "error",
"error": repr(e),
"timestamp": datetime.now().isoformat(timespec="seconds"),
}
print(f"[ERROR] Trial {trial_args._trial_id} failed: {repr(e)}", file=sys.stderr)
_append_summary_line(summary_path, summary)
print(f"[SWEEP DONE] Wrote collated results to {summary_path}")
else:
# Single-run mode (no YAML)
start_stamp = datetime.now().strftime("%Y%m%d-%H%M%S")
args._trial_id = f"{args.sweep_name}_{start_stamp}_single"
trial_log_path = log_root / f"{args._trial_id}.log"
try:
summary = run_single_trial(args, trial_log_path)
except Exception as e:
summary = {
"trial_id": args._trial_id,
"log_file": str(trial_log_path),
"config": {k: v for k, v in asdict(args).items() if not k.startswith('_')},
"status": "error",
"error": repr(e),
"timestamp": datetime.now().isoformat(timespec="seconds"),
}
print(f"[ERROR] Trial failed: {repr(e)}", file=sys.stderr)
_append_summary_line(summary_path, summary)
print(f"[DONE] Wrote collated result to {summary_path}")
if __name__ == "__main__":
main()