StarSketchVAE
from dataclasses import dataclass
from typing import Tuple
import jax
import jax.numpy as jnp
from flax import nnx
LOG_SIG_MIN = -5.0 # sigma >= e^-5 ≈ 0.0067
LOG_SIG_MAX = 2.0 # sigma <= e^2 ≈ 7.39
RHO_EPS = 1e-3
# =============================================================
# HyperLSTMCell
# =============================================================
class HyperLSTMCell(nnx.RNNCellBase):
    def __init__(
        self,
        in_features: int,
        hidden_features: int,
        *,
        hyper_num_units: int = 256,
        hyper_embedding_size: int = 32,
        forget_bias: float = 1.0,
        gate_fn=jax.nn.sigmoid,
        activation_fn=jnp.tanh,
        rngs: nnx.Rngs,
    ):
        super().__init__()
        self.in_features = in_features
        self.hidden_features = hidden_features
        self.hyper_num_units = hyper_num_units
        self.hyper_embedding_size = hyper_embedding_size
        self.forget_bias = forget_bias
        self.gate_fn = gate_fn
        self.activation_fn = activation_fn
        # Base projections (input projections have no bias; hidden projections do),
        # mirroring a standard LSTMCell.
        self.ii = nnx.Linear(in_features, hidden_features, use_bias=False, rngs=rngs)
        self.if_ = nnx.Linear(in_features, hidden_features, use_bias=False, rngs=rngs)
        self.ig = nnx.Linear(in_features, hidden_features, use_bias=False, rngs=rngs)
        self.io = nnx.Linear(in_features, hidden_features, use_bias=False, rngs=rngs)
        self.hi = nnx.Linear(hidden_features, hidden_features, use_bias=True, rngs=rngs)
        self.hf = nnx.Linear(hidden_features, hidden_features, use_bias=True, rngs=rngs)
        self.hg = nnx.Linear(hidden_features, hidden_features, use_bias=True, rngs=rngs)
        self.ho = nnx.Linear(hidden_features, hidden_features, use_bias=True, rngs=rngs)
        # Hyper LSTM that processes [x, h]
        self.hyper_cell = nnx.LSTMCell(
            in_features + hidden_features, hyper_num_units, rngs=rngs
        )

        # Hyper heads: zw -> alpha (+ beta). Separate heads per contribution.
        def head_pair():
            zw = nnx.Linear(
                hyper_num_units, hyper_embedding_size, use_bias=True, rngs=rngs
            )
            alpha = nnx.Linear(
                hyper_embedding_size, hidden_features, use_bias=False, rngs=rngs
            )
            beta = nnx.Linear(
                hyper_embedding_size, hidden_features, use_bias=False, rngs=rngs
            )
            return zw, alpha, beta

        self.zw_ix, self.alpha_ix, self.beta_ix = head_pair()
        self.zw_jx, self.alpha_jx, self.beta_jx = head_pair()
        self.zw_fx, self.alpha_fx, self.beta_fx = head_pair()
        self.zw_ox, self.alpha_ox, self.beta_ox = head_pair()
        self.zw_ih, self.alpha_ih, self.beta_ih = head_pair()
        self.zw_jh, self.alpha_jh, self.beta_jh = head_pair()
        self.zw_fh, self.alpha_fh, self.beta_fh = head_pair()
        self.zw_oh, self.alpha_oh, self.beta_oh = head_pair()

    def initialize_carry(self, input_shape, rngs: nnx.Rngs):
        batch_dims = input_shape[:-1]
        total_units = self.hidden_features + self.hyper_num_units
        mem_shape = batch_dims + (total_units,)
        c = jnp.zeros(mem_shape, dtype=jnp.float32)
        h = jnp.zeros(mem_shape, dtype=jnp.float32)
        return (c, h)

    def _hyper_affine(
        self, which: str, hyper_h: jnp.ndarray, preact: jnp.ndarray, use_bias: bool
    ) -> jnp.ndarray:
        zw = getattr(self, f"zw_{which}")(hyper_h)
        alpha = getattr(self, f"alpha_{which}")(zw)
        out = alpha * preact
        if use_bias:
            beta = getattr(self, f"beta_{which}")(zw)
            out = out + beta
        return out

    def __call__(self, carry, inputs):
        total_c, total_h = carry
        c_main = total_c[..., : self.hidden_features]
        h_main = total_h[..., : self.hidden_features]
        c_hyp = total_c[..., self.hidden_features :]
        h_hyp = total_h[..., self.hidden_features :]
        # Hyper step
        hyper_in = jnp.concatenate([inputs, h_main], axis=-1)
        (new_c_hyp, new_h_hyp), _ = self.hyper_cell((c_hyp, h_hyp), hyper_in)
        # Base preactivations
        ix = self.ii(inputs)
        ih = self.hi(h_main)
        jx = self.ig(inputs)
        jh = self.hg(h_main)
        fx = self.if_(inputs)
        fh = self.hf(h_main)
        ox = self.io(inputs)
        oh = self.ho(h_main)
        # Modulate each pre-activation with the hypernetwork output
        ix = self._hyper_affine("ix", new_h_hyp, ix, use_bias=False)
        jx = self._hyper_affine("jx", new_h_hyp, jx, use_bias=False)
        fx = self._hyper_affine("fx", new_h_hyp, fx, use_bias=False)
        ox = self._hyper_affine("ox", new_h_hyp, ox, use_bias=False)
        ih = self._hyper_affine("ih", new_h_hyp, ih, use_bias=True)
        jh = self._hyper_affine("jh", new_h_hyp, jh, use_bias=True)
        fh = self._hyper_affine("fh", new_h_hyp, fh, use_bias=True)
        oh = self._hyper_affine("oh", new_h_hyp, oh, use_bias=True)
        # Gates and state update
        i = self.gate_fn(ix + ih)
        f = self.gate_fn(fx + fh + self.forget_bias)
        g = self.activation_fn(jx + jh)
        o = self.gate_fn(ox + oh)
        new_c_main = f * c_main + i * g
        new_h_main = o * self.activation_fn(new_c_main)
        # Full state: main and hyper states live side by side in one carry
        new_total_c = jnp.concatenate([new_c_main, new_c_hyp], axis=-1)
        new_total_h = jnp.concatenate([new_h_main, new_h_hyp], axis=-1)
        new_carry = (new_total_c, new_total_h)
        return new_carry, new_h_main

    @property
    def num_feature_axes(self) -> int:
        """Returns the number of feature axes of the RNN cell."""
        return 1
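
# Illustrative sketch (hypothetical helper, not part of the original gist):
# exercise a single HyperLSTMCell step to check the carry layout. The carry
# concatenates the main and hyper LSTM states along the feature axis.
def _demo_hyperlstm_step():
    rngs = nnx.Rngs(0)
    cell = HyperLSTMCell(in_features=2, hidden_features=8,
                         hyper_num_units=16, hyper_embedding_size=4, rngs=rngs)
    x = jnp.ones((3, 2))  # (batch, features)
    carry = cell.initialize_carry(x.shape, rngs)  # (c, h), each (3, 8 + 16)
    carry, h_main = cell(carry, x)  # h_main is only the main hidden state: (3, 8)
    return carry, h_main
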
class BidirectionalLSTMEncoder(nnx.Module):
    def __init__(
        self,
        in_features: int,
        hidden_size: int = 512,
        latent_size: int = 128,
        *,
        rngs: nnx.Rngs,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.fwd = nnx.RNN(
            nnx.LSTMCell(
                in_features=in_features, hidden_features=hidden_size, rngs=rngs
            )
        )
        self.bwd = nnx.RNN(
            nnx.LSTMCell(
                in_features=in_features, hidden_features=hidden_size, rngs=rngs
            )
        )
        self.fc_mu = nnx.Linear(hidden_size * 2, latent_size, rngs=rngs)
        self.fc_logsig = nnx.Linear(hidden_size * 2, latent_size, rngs=rngs)

    def __call__(self, x: jax.Array) -> Tuple[jax.Array, jax.Array, jax.Array]:
        h_fwd = self.fwd(x)  # (B, T, H)
        h_bwd = self.bwd(jnp.flip(x, axis=1))  # (B, T, H)
        h_final = jnp.concatenate([h_fwd[:, -1], h_bwd[:, -1]], axis=-1)  # (B, 2H)
        mu = self.fc_mu(h_final)
        log_sigma = self.fc_logsig(h_final)
        return mu, log_sigma, h_final
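
# Illustrative sketch (hypothetical helper, not from the gist): quick shape
# check of the bidirectional encoder with a tiny configuration.
def _demo_encoder_shapes():
    rngs = nnx.Rngs(0)
    enc = BidirectionalLSTMEncoder(in_features=2, hidden_size=32, latent_size=8, rngs=rngs)
    x = jnp.ones((4, 10, 2))  # (B, T, D)
    mu, log_sigma, h_final = enc(x)
    return mu.shape, log_sigma.shape, h_final.shape  # (4, 8), (4, 8), (4, 64)
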
class GMMDecoder(nnx.Module):
    def __init__(
        self,
        in_features: int,
        hidden_size: int = 2048,
        n_mixtures: int = 20,
        *,
        rngs: nnx.Rngs,
    ):
        super().__init__()
        self.rnn = nnx.RNN(
            HyperLSTMCell(
                in_features=in_features, hidden_features=hidden_size, rngs=rngs
            )
        )
        self.n_mixtures = n_mixtures
        out_dim = n_mixtures * 6  # per component: 1 logit, 2 means, 2 log-sigmas, 1 rho
        self.fc_out = nnx.Linear(hidden_size, out_dim, rngs=rngs)

    def __call__(self, inputs: jax.Array) -> Tuple[jax.Array, ...]:
        h = self.rnn(inputs)  # (B, T, H); nnx.RNN returns only the output sequence by default
        y = self.fc_out(h)  # (B, T, M*6)
        B, T, _ = y.shape
        M = self.n_mixtures
        y = y.reshape(B, T, M, 6)
        pi_logits = y[..., 0]  # (B, T, M)
        mu = y[..., 1:3]  # (B, T, M, 2)
        log_sig = jnp.clip(y[..., 3:5], LOG_SIG_MIN, LOG_SIG_MAX)  # (B, T, M, 2)
        rho = jnp.tanh(y[..., 5]) * (1.0 - RHO_EPS)  # (B, T, M)
        return pi_logits, mu, log_sig, rho
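
# Illustrative sketch (hypothetical helper, not from the gist): the decoder
# emits, per timestep and mixture component, a logit, a 2D mean, a 2D
# log-sigma, and a correlation.
def _demo_decoder_shapes():
    rngs = nnx.Rngs(0)
    dec = GMMDecoder(in_features=10, hidden_size=16, n_mixtures=3, rngs=rngs)
    x = jnp.ones((2, 5, 10))  # (B, T, in_features)
    pi_logits, mu, log_sig, rho = dec(x)
    # shapes: (2, 5, 3), (2, 5, 3, 2), (2, 5, 3, 2), (2, 5, 3)
    return pi_logits.shape, mu.shape, log_sig.shape, rho.shape
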
@dataclass
class ModelConfig:
    seq_len: int = 100
    data_dim: int = 2  # (Δx, Δt)
    enc_hidden: int = 512
    dec_hidden: int = 2048
    latent_size: int = 128
    n_mixtures: int = 20
    num_classes: int = 0  # set >0 to enable the classifier head
class StarSketchVAE(nnx.Module):
    def __init__(self, cfg: ModelConfig, *, rngs: nnx.Rngs):
        super().__init__()
        self.cfg = cfg
        self.encoder = BidirectionalLSTMEncoder(
            in_features=cfg.data_dim,
            hidden_size=cfg.enc_hidden,
            latent_size=cfg.latent_size,
            rngs=rngs,
        )
        self.decoder = GMMDecoder(
            in_features=cfg.data_dim + cfg.latent_size,
            hidden_size=cfg.dec_hidden,
            n_mixtures=cfg.n_mixtures,
            rngs=rngs,
        )
        self.classifier = nnx.Linear(cfg.latent_size, cfg.num_classes, rngs=rngs)

    def encode(self, x: jax.Array) -> Tuple[jax.Array, jax.Array, jax.Array]:
        return self.encoder(x)

    def decode(self, prev_xy: jax.Array, z: jax.Array) -> Tuple[jax.Array, ...]:
        B, T, _ = prev_xy.shape
        z_broadcast = jnp.repeat(z[:, None, :], T, axis=1)  # (B, T, Z)
        dec_in = jnp.concatenate([prev_xy, z_broadcast], axis=-1)  # (B, T, Z+2)
        return self.decoder(dec_in)
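
# Illustrative sketch (hypothetical helper, assumes a small config): one
# encode -> reparameterize -> decode pass, mirroring what model_loss does below.
def _demo_vae_roundtrip():
    rngs = nnx.Rngs(0)
    cfg = ModelConfig(seq_len=10, enc_hidden=32, dec_hidden=64, latent_size=8,
                      n_mixtures=3, num_classes=2)
    model = StarSketchVAE(cfg, rngs=rngs)
    x = jnp.ones((2, cfg.seq_len, cfg.data_dim))
    mu, log_sigma, _ = model.encode(x)
    z = mu + jnp.exp(log_sigma) * jax.random.normal(jax.random.PRNGKey(0), mu.shape)
    prev = jnp.concatenate([jnp.zeros_like(x[:, :1]), x[:, :-1]], axis=1)
    return model.decode(prev, z)  # pi_logits, mu, log_sig, rho
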
def bivariate_normal_logpdf(
    xy: jax.Array, mu: jax.Array, log_sig: jax.Array, rho: jax.Array
) -> jax.Array:
    sig = jnp.exp(log_sig)
    x = xy[..., 0]
    y = xy[..., 1]
    mx = mu[..., 0]
    my = mu[..., 1]
    sx = sig[..., 0]
    sy = sig[..., 1]
    sx = jnp.maximum(sx, 1e-6)
    sy = jnp.maximum(sy, 1e-6)
    one_minus_r2 = jnp.maximum(1.0 - rho**2, 1e-6)
    norm = (
        -jnp.log(2.0 * jnp.pi) - jnp.log(sx) - jnp.log(sy) - 0.5 * jnp.log(one_minus_r2)
    )
    zx = (x - mx) / sx
    zy = (y - my) / sy
    exp_term = -0.5 * ((zx**2 + zy**2 - 2.0 * rho * zx * zy) / one_minus_r2)
    return norm + exp_term
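
# Illustrative sanity check (hypothetical helper, not from the gist): the
# hand-rolled bivariate log-pdf should agree with
# jax.scipy.stats.multivariate_normal.logpdf when the covariance is built from
# the same (sigma_x, sigma_y, rho).
def _check_bivariate_logpdf():
    xy = jnp.array([0.3, -0.2])
    mu = jnp.array([0.1, 0.0])
    log_sig = jnp.array([jnp.log(0.5), jnp.log(1.5)])
    rho = jnp.array(0.4)
    sx, sy = jnp.exp(log_sig)
    cov = jnp.array([[sx**2, rho * sx * sy],
                     [rho * sx * sy, sy**2]])
    ref = jax.scipy.stats.multivariate_normal.logpdf(xy, mu, cov)
    ours = bivariate_normal_logpdf(xy, mu, log_sig, rho)
    return ours, ref  # expected to match to numerical precision
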
def gmm_log_likelihood(
    xy: jax.Array,
    pi_logits: jax.Array,
    mu: jax.Array,
    log_sig: jax.Array,
    rho: jax.Array,
) -> jax.Array:
    xy_exp = xy[:, :, None, :]  # (B, T, 1, 2)
    comp_logpdf = bivariate_normal_logpdf(xy_exp, mu, log_sig, rho)  # (B, T, M)
    log_pi = jax.nn.log_softmax(pi_logits, axis=-1)  # (B, T, M)
    # log-sum-exp over mixture components
    ll = jax.scipy.special.logsumexp(log_pi + comp_logpdf, axis=-1)  # (B, T)
    return ll
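
# Illustrative sanity check (hypothetical helper, not from the gist): with a
# single mixture component, log_softmax of one logit is 0, so the GMM
# log-likelihood reduces to the plain bivariate log-pdf.
def _check_gmm_single_component():
    B, T = 2, 4
    xy = jnp.zeros((B, T, 2))
    pi_logits = jnp.zeros((B, T, 1))
    mu = jnp.zeros((B, T, 1, 2))
    log_sig = jnp.zeros((B, T, 1, 2))
    rho = jnp.zeros((B, T, 1))
    ll = gmm_log_likelihood(xy, pi_logits, mu, log_sig, rho)  # (B, T)
    ref = bivariate_normal_logpdf(xy[:, :, None, :], mu, log_sig, rho)[..., 0]
    return ll, ref
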
def reconstruction_loss(
    xy: jax.Array,
    pi_logits: jax.Array,
    mu: jax.Array,
    log_sig: jax.Array,
    rho: jax.Array,
    w: jax.Array,
) -> jax.Array:
    ll = gmm_log_likelihood(xy, pi_logits, mu, log_sig, rho)  # (B, T)
    nll_t = -ll
    nll = jnp.sum(w * nll_t) / (jnp.sum(w) + 1e-8)
    return nll
def kl_loss(mu: jax.Array, log_sigma: jax.Array) -> jax.Array:
    log_var = 2.0 * log_sigma
    var = jnp.exp(log_var)
    kl = 0.5 * (var + mu**2 - 1.0 - log_var)
    return jnp.mean(jnp.sum(kl, axis=-1))
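
# The term above is the closed-form KL(N(mu, sigma^2) || N(0, 1))
#   = 0.5 * (sigma^2 + mu^2 - 1 - log sigma^2),
# summed over latent dimensions and averaged over the batch.
# Illustrative check (hypothetical helper): at mu = 0, log_sigma = 0 the KL is 0.
def _check_kl_at_prior():
    return kl_loss(jnp.zeros((3, 8)), jnp.zeros((3, 8)))  # expected: 0.0
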
def model_loss(
    model: StarSketchVAE,
    batch_xy: jax.Array,
    *,
    class_ids: jax.Array | None = None,
    w_kl: float = 1.0,
    w_sup: float = 1.0,
    rng_key: jax.Array,
) -> Tuple[jax.Array, dict]:
    """
    batch_xy: (B, T=100, 2) full teacher-forced sequences of deltas.
    Returns the total loss and a metrics dict.
    """
    B, T, D = batch_xy.shape
    assert D == 2, "Expecting (Δx, Δt)"
    mu, log_sigma, enc_feats = model.encode(batch_xy)
    # Reparameterization trick: z = mu + sigma * eps
    eps = jax.random.normal(rng_key, mu.shape)
    z = mu + jnp.exp(log_sigma) * eps
    # Teacher forcing: shift targets right by one step; the first input is zero
    pad0 = jnp.zeros((B, 1, D), dtype=batch_xy.dtype)
    prev_xy = jnp.concatenate([pad0, batch_xy[:, :-1, :]], axis=1)  # (B, T, 2)
    pi_logits, gmm_mu, gmm_logsig, gmm_rho = model.decode(prev_xy, z)
    # Weight each timestep by its normalized Δt
    dt = batch_xy[..., 1]
    w = dt / (jnp.mean(dt) + 1e-8)
    recon = reconstruction_loss(batch_xy, pi_logits, gmm_mu, gmm_logsig, gmm_rho, w)
    kl = kl_loss(mu, log_sigma)
    # Optional supervised classifier head on z
    sup_loss = 0.0
    cls_metrics = {}
    if class_ids is not None:
        logits = model.classifier(z)  # (B, C)
        onehot = jax.nn.one_hot(class_ids, logits.shape[-1])
        ce = -jnp.sum(onehot * jax.nn.log_softmax(logits, axis=-1), axis=-1)
        sup_loss = jnp.mean(ce)
        probs = jax.nn.softmax(logits, axis=-1)
        pred = jnp.argmax(probs, axis=-1)
        acc = jnp.mean((pred == class_ids).astype(jnp.float32))
        cls_metrics = {"sup_ce": sup_loss, "acc": acc}
    # ===== Sequence-closure penalty: sum Δx_pred ≈ sum Δx_gt =====
    pi = jax.nn.softmax(pi_logits, axis=-1)[..., None]  # (B, T, M, 1)
    exp = jnp.sum(pi * gmm_mu, axis=-2)  # (B, T, 2)
    x_gt_total = jnp.sum(batch_xy[..., 0], axis=1)  # (B,)
    x_pr_total = jnp.sum(exp[..., 0], axis=1)  # (B,)
    close_loss = jnp.mean((x_pr_total - x_gt_total) ** 2)
    total = recon + w_kl * kl + w_sup * sup_loss + 5.0 * close_loss
    metrics = {"loss": total, "recon": recon, "kl": kl, **cls_metrics}
    return total, metrics
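
# Illustrative training-step sketch (assumption, not part of the gist; the
# exact nnx.Optimizer / optax API depends on your flax and optax versions, so
# this is left commented out):
#
# import optax
#
# optimizer = nnx.Optimizer(model, optax.adam(1e-3))
#
# def train_step(model, optimizer, batch_xy, class_ids, rng_key):
#     def loss_fn(m):
#         return model_loss(m, batch_xy, class_ids=class_ids,
#                           w_kl=1.0, w_sup=1.0, rng_key=rng_key)
#     (total, metrics), grads = nnx.value_and_grad(loss_fn, has_aux=True)(model)
#     optimizer.update(grads)  # in-place parameter update
#     return metrics
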
# ---------- Classifier direction utilities ----------
def _linear_params(linear) -> tuple[jnp.ndarray, jnp.ndarray | None]:
    # Works for nnx.Linear; some variants use 'kernel' rather than 'weight'
    if hasattr(linear, "weight"):
        W = linear.weight
    elif hasattr(linear, "kernel"):
        W = linear.kernel
    else:
        raise AttributeError("Classifier Linear has no 'weight' or 'kernel'.")
    b = getattr(linear, "bias", None)
    # Unwrap nnx.Param/Variable wrappers to raw arrays when present
    W = getattr(W, "value", W)
    b = getattr(b, "value", b) if b is not None else None
    return W, b
def get_classifier_direction(model: StarSketchVAE,
                             class_id: int,
                             *,
                             normalize: bool = True,
                             margin: bool = False) -> jnp.ndarray:
    """
    Return a latent-space direction for class `class_id`.
    Handles either W shape: (C, Z) or (Z, C).
    If margin=True, uses a mean-difference direction for a stronger push.
    """
    W, _ = _linear_params(model.classifier)
    C = model.cfg.num_classes
    Z = model.cfg.latent_size
    assert 0 <= class_id < C, "class_id out of range"
    # Figure out the weight-matrix orientation
    if W.shape == (C, Z):  # rows = classes
        w_c = W[class_id]  # (Z,)
        if margin:
            mask = jnp.ones(C, dtype=bool).at[class_id].set(False)
            w_rest_mean = jnp.mean(W[mask], axis=0)  # (Z,)
            direction = w_c - w_rest_mean
        else:
            direction = w_c
    elif W.shape == (Z, C):  # columns = classes
        w_c = W[:, class_id]  # (Z,)
        if margin:
            # mean over the other columns
            idx = jnp.arange(C) != class_id
            w_rest_mean = jnp.mean(W[:, idx], axis=1)  # (Z,)
            direction = w_c - w_rest_mean
        else:
            direction = w_c
    else:
        raise ValueError(f"Unexpected classifier weight shape {W.shape}; "
                         f"expected (C,Z)=({C},{Z}) or (Z,C)=({Z},{C}).")
    if normalize:
        direction = direction / (jnp.linalg.norm(direction) + 1e-8)
    return direction  # (Z,)
def tweak_z_towards_class(z: jnp.ndarray,
                          model: StarSketchVAE,
                          class_id: int,
                          *,
                          alpha: float = 3.0,
                          normalize: bool = True,
                          margin: bool = False) -> jnp.ndarray:
    d = get_classifier_direction(model, class_id, normalize=normalize, margin=margin)
    assert d.shape == z.shape, f"direction {d.shape} must match z {z.shape}"
    return z + alpha * d
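
# Illustrative check (hypothetical helper, not from the gist): pushing z along
# the classifier direction should increase the target class's logit relative
# to the untweaked z.
def _check_class_push(model: StarSketchVAE, class_id: int, rng_key: jax.Array):
    z = jax.random.normal(rng_key, (model.cfg.latent_size,))
    z_star = tweak_z_towards_class(z, model, class_id, alpha=3.0)
    before = model.classifier(z[None, :])[0, class_id]
    after = model.classifier(z_star[None, :])[0, class_id]
    return before, after  # `after` is expected to exceed `before`
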
# ---------- GMM sampling helpers ----------
def _sample_bivariate_from_params(mu: jnp.ndarray,
                                  log_sig: jnp.ndarray,
                                  rho: jnp.ndarray,
                                  rng_key: jax.Array,
                                  temperature: float = 1.0) -> jnp.ndarray:
    sig = jnp.exp(log_sig) * jnp.sqrt(temperature)
    rho = jnp.clip(rho, -0.999, 0.999)
    u_key, v_key = jax.random.split(rng_key)
    u = jax.random.normal(u_key)
    v = jax.random.normal(v_key)
    # correlated 2D normal
    dx = sig[0] * u
    dy = sig[1] * (rho * u + jnp.sqrt(1.0 - rho**2) * v)
    return mu + jnp.array([dx, dy])


def _mixture_expectation(pi_logits_t: jnp.ndarray,
                         mu_t: jnp.ndarray,
                         temperature: float = 1.0) -> jnp.ndarray:
    pi = jax.nn.softmax(pi_logits_t / temperature, axis=-1)  # (M,)
    return jnp.sum(pi[:, None] * mu_t, axis=0)  # (2,)


def _mixture_sample(pi_logits_t: jnp.ndarray,
                    mu_t: jnp.ndarray,
                    log_sig_t: jnp.ndarray,
                    rho_t: jnp.ndarray,
                    rng_key: jax.Array,
                    temperature: float = 1.0) -> jnp.ndarray:
    # Sample a mixture component, then sample from that component's Gaussian
    pi = jax.nn.softmax(pi_logits_t / temperature, axis=-1)  # (M,)
    comp_key, gauss_key = jax.random.split(rng_key)
    k = jax.random.categorical(comp_key, jnp.log(pi), axis=-1)  # int
    return _sample_bivariate_from_params(mu_t[k], log_sig_t[k], rho_t[k], gauss_key, temperature)
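
# Illustrative check (hypothetical helper, not from the gist): the empirical
# correlation of many draws from _sample_bivariate_from_params should approach
# the requested rho.
def _check_sampler_correlation(n: int = 20_000):
    mu = jnp.zeros(2)
    log_sig = jnp.zeros(2)  # unit sigmas
    rho = jnp.array(0.7)
    keys = jax.random.split(jax.random.PRNGKey(0), n)
    draws = jax.vmap(lambda k: _sample_bivariate_from_params(mu, log_sig, rho, k))(keys)
    return jnp.corrcoef(draws.T)[0, 1]  # expected ~ 0.7
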
# ---------- Autoregressive decoding ----------
def decode_from_z(model: StarSketchVAE,
                  z: jnp.ndarray,
                  *,
                  steps: int,
                  method: str = "expected",  # "expected" or "sample"
                  temperature: float = 1.0,
                  rng_key: jax.Array | None = None) -> jnp.ndarray:
    assert z.ndim == 1, "z must be shape (Z,)"
    D = 2
    if method == "sample" and rng_key is None:
        rng_key = jax.random.PRNGKey(0)
    # running buffer of previous outputs; the first "previous" step is all zeros
    prev = jnp.zeros((steps, D), dtype=jnp.float32)
    out = []
    zB = z[None, :]  # (1, Z)
    for t in range(steps):
        prev_xy = prev[None, :, :]  # (1, steps, 2)
        pi_logits, gmm_mu, gmm_logsig, gmm_rho = model.decode(prev_xy, zB)
        pi_t = pi_logits[0, t]  # (M,)
        mu_t = gmm_mu[0, t]  # (M, 2)
        logs_t = gmm_logsig[0, t]  # (M, 2)
        rho_t = gmm_rho[0, t]  # (M,)
        if method == "expected":
            delta = _mixture_expectation(pi_t, mu_t, temperature=temperature)
        else:
            rng_key, step_key = jax.random.split(rng_key)
            delta = _mixture_sample(pi_t, mu_t, logs_t, rho_t, step_key, temperature=temperature)
        out.append(delta)
        if t + 1 < steps:
            prev = prev.at[t + 1].set(delta)
    return jnp.stack(out, axis=0)  # (steps, 2)
# ---------- End-to-end: sample z, push toward class, decode ----------
def class_conditioned_random_curve(model: StarSketchVAE,
                                   class_id: int,
                                   *,
                                   steps: int = 100,
                                   z_scale: float = 1.0,
                                   alpha: float = 3.0,
                                   normalize_dir: bool = True,
                                   margin: bool = False,
                                   method: str = "expected",  # "expected" or "sample"
                                   temperature: float = 0.9,
                                   rng_key: jax.Array | None = None) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    if rng_key is None:
        rng_key = jax.random.PRNGKey(42)
    z_key, dec_key = jax.random.split(rng_key)
    # sample z ~ N(0, I)
    Z = model.cfg.latent_size
    z = jax.random.normal(z_key, (Z,)) * z_scale  # (Z,)
    # push toward class direction
    z_star = tweak_z_towards_class(z, model, class_id, alpha=alpha,
                                   normalize=normalize_dir, margin=margin)
    deltas = decode_from_z(model, z_star, steps=steps, method=method,
                           temperature=temperature, rng_key=dec_key)
    dt = jnp.maximum(deltas[:, 1], 1e-4)
    deltas = deltas.at[:, 1].set(dt)
    # Convert to cumulative x(t)
    x = jnp.cumsum(deltas[:, 0])
    t = jnp.cumsum(deltas[:, 1])
    return deltas, x, t
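
# Illustrative usage sketch (assumption, left commented out because it needs a
# trained `model` in scope):
#
# deltas, x, t = class_conditioned_random_curve(
#     model, class_id=2, steps=100, alpha=3.0, method="sample",
#     temperature=0.8, rng_key=jax.random.PRNGKey(7),
# )
# # deltas: (100, 2) per-step (Δx, Δt); x, t: (100,) cumulative sums.
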
if __name__ == "__main__":
    rngs = nnx.Rngs(0)
    cfg = ModelConfig(
        seq_len=100,
        data_dim=2,
        enc_hidden=512,
        dec_hidden=2048,
        latent_size=128,
        n_mixtures=20,
        num_classes=5,
    )
    model = StarSketchVAE(cfg, rngs=rngs)
    B, T, D = 6, 100, 2
    batch = jnp.ones((B, T, D))
    # class_ids: all samples labeled (no unlabeled)
    class_ids = jnp.array([0, 1, 2, 3, 4, 0])
    key = jax.random.PRNGKey(42)
    loss, metrics = model_loss(
        model, batch, class_ids=class_ids, w_kl=1.0, w_sup=1.0, rng_key=key
    )
    print({k: float(v) for k, v in metrics.items()})