Skip to content

Instantly share code, notes, and snippets.

@muchanem
Created October 6, 2025 19:42
Show Gist options
  • Select an option

  • Save muchanem/1ba8a4563596f4557c5ae1b8062e8bcb to your computer and use it in GitHub Desktop.

Select an option

Save muchanem/1ba8a4563596f4557c5ae1b8062e8bcb to your computer and use it in GitHub Desktop.
StarSketchVAE
from dataclasses import dataclass
from typing import Tuple
import jax
import jax.numpy as jnp
from flax import nnx
from dataclasses import dataclass
from typing import Tuple
import jax
import jax.numpy as jnp
from flax import nnx
# Bounds applied (via jnp.clip) to the decoder's predicted log standard
# deviations of each mixture component.
LOG_SIG_MIN = -5.0 # sigma >= e^-5 ≈ 0.0067
LOG_SIG_MAX = 2.0 # sigma <= e^2 ≈ 7.39
# Margin for the predicted correlation: rho = tanh(.) * (1 - RHO_EPS), so
# |rho| < 1 - RHO_EPS and 1 - rho**2 stays strictly positive in the log-density.
RHO_EPS = 1e-3
# =============================================================
# HyperLSTMCell
# =============================================================
class HyperLSTMCell(nnx.RNNCellBase):
    """LSTM cell whose gate pre-activations are modulated by a smaller "hyper" LSTM.

    Each step, an inner ``nnx.LSTMCell`` (``hyper_cell``) reads
    ``[inputs, h_main]``. Its new hidden state is mapped, separately for each of
    the eight gate contributions (i/j/f/o from the input and from the hidden
    state), through a ``zw`` embedding to a per-unit scale ``alpha``. The four
    hidden-state contributions additionally receive a per-unit shift ``beta``.

    The carry is ``(c, h)``. Each tensor holds the main-cell state in its first
    ``hidden_features`` channels, followed by the hyper-cell state in the last
    ``hyper_num_units`` channels.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: int,
        *,
        hyper_num_units: int = 256,
        hyper_embedding_size: int = 32,
        forget_bias: float = 1.0,
        gate_fn=jax.nn.sigmoid,
        activation_fn=jnp.tanh,
        rngs: nnx.Rngs,
    ):
        """Create the main projections, the hyper LSTM and the hyper heads.

        Args:
            in_features: size of the last axis of ``inputs``.
            hidden_features: hidden size of the main cell.
            hyper_num_units: hidden size of the hyper LSTM.
            hyper_embedding_size: width of each ``zw`` embedding.
            forget_bias: constant added to the forget-gate pre-activation.
            gate_fn: nonlinearity for the input, forget and output gates.
            activation_fn: nonlinearity for the candidate ``g`` and for the
                cell state when producing ``h``.
            rngs: RNG streams used to initialise every sub-layer.
        """
        super().__init__()
        self.in_features = in_features
        self.hidden_features = hidden_features
        self.hyper_num_units = hyper_num_units
        self.hyper_embedding_size = hyper_embedding_size
        self.forget_bias = forget_bias
        self.gate_fn = gate_fn
        self.activation_fn = activation_fn
        # NOTE(review): every layer below draws its initial parameters from the
        # shared ``rngs``; reordering these constructions presumably changes the
        # initialisation, and the attribute names define the parameter paths.
        # Base projections: input-to-hidden without bias, hidden-to-hidden with
        # bias. The ``ig``/``hg`` pair feeds the candidate ``g`` (the "j" terms).
        self.ii = nnx.Linear(in_features, hidden_features, use_bias=False, rngs=rngs)
        self.if_ = nnx.Linear(in_features, hidden_features, use_bias=False, rngs=rngs)
        self.ig = nnx.Linear(in_features, hidden_features, use_bias=False, rngs=rngs)
        self.io = nnx.Linear(in_features, hidden_features, use_bias=False, rngs=rngs)
        self.hi = nnx.Linear(hidden_features, hidden_features, use_bias=True, rngs=rngs)
        self.hf = nnx.Linear(hidden_features, hidden_features, use_bias=True, rngs=rngs)
        self.hg = nnx.Linear(hidden_features, hidden_features, use_bias=True, rngs=rngs)
        self.ho = nnx.Linear(hidden_features, hidden_features, use_bias=True, rngs=rngs)
        # Hyper LSTM that processes [x, h_main].
        self.hyper_cell = nnx.LSTMCell(
            in_features + hidden_features, hyper_num_units, rngs=rngs
        )

        # Hyper heads: hyper_h -> zw -> alpha (scale) and beta (shift).
        # One independent set of heads per gate contribution.
        def head_pair():
            zw = nnx.Linear(
                hyper_num_units, hyper_embedding_size, use_bias=True, rngs=rngs
            )
            alpha = nnx.Linear(
                hyper_embedding_size, hidden_features, use_bias=False, rngs=rngs
            )
            beta = nnx.Linear(
                hyper_embedding_size, hidden_features, use_bias=False, rngs=rngs
            )
            return zw, alpha, beta

        # NOTE(review): beta_ix/beta_jx/beta_fx/beta_ox are created but never
        # read, because __call__ applies the input-side heads with
        # use_bias=False. They still hold parameters (and presumably optimizer
        # state), so removing them would also change the init order above.
        self.zw_ix, self.alpha_ix, self.beta_ix = head_pair()
        self.zw_jx, self.alpha_jx, self.beta_jx = head_pair()
        self.zw_fx, self.alpha_fx, self.beta_fx = head_pair()
        self.zw_ox, self.alpha_ox, self.beta_ox = head_pair()
        self.zw_ih, self.alpha_ih, self.beta_ih = head_pair()
        self.zw_jh, self.alpha_jh, self.beta_jh = head_pair()
        self.zw_fh, self.alpha_fh, self.beta_fh = head_pair()
        self.zw_oh, self.alpha_oh, self.beta_oh = head_pair()

    def initialize_carry(self, input_shape, rngs: nnx.Rngs):
        """Return an all-zero ``(c, h)`` carry for inputs of shape ``input_shape``.

        Both tensors have shape
        ``input_shape[:-1] + (hidden_features + hyper_num_units,)`` and dtype
        float32. ``rngs`` is accepted but not used.
        """
        batch_dims = input_shape[:-1]
        # Main and hyper states are packed side by side on the last axis.
        total_units = self.hidden_features + self.hyper_num_units
        mem_shape = batch_dims + (total_units,)
        c = jnp.zeros(mem_shape, dtype=jnp.float32)
        h = jnp.zeros(mem_shape, dtype=jnp.float32)
        return (c, h)

    def _hyper_affine(
        self, which: str, hyper_h: jnp.ndarray, preact: jnp.ndarray, use_bias: bool
    ) -> jnp.ndarray:
        """Scale (and optionally shift) ``preact`` with the heads named ``*_{which}``.

        Computes ``alpha(zw(hyper_h)) * preact``. When ``use_bias`` is true, it
        also adds ``beta(zw(hyper_h))``.
        """
        zw = getattr(self, f"zw_{which}")(hyper_h)
        alpha = getattr(self, f"alpha_{which}")(zw)
        out = alpha * preact
        if use_bias:
            beta = getattr(self, f"beta_{which}")(zw)
            out = out + beta
        return out

    def __call__(self, carry, inputs):
        """Advance the cell by one step.

        Returns:
            ``(new_carry, new_h_main)``. ``new_carry`` re-packs the updated main
            and hyper states (main first, hyper last on the feature axis). The
            step output is only the main cell's hidden state.
        """
        # Split the packed carry into main-cell and hyper-cell parts.
        total_c, total_h = carry
        c_main = total_c[..., : self.hidden_features]
        h_main = total_h[..., : self.hidden_features]
        c_hyp = total_c[..., self.hidden_features :]
        h_hyp = total_h[..., self.hidden_features :]
        # Hyper step: sees the current input and the previous main hidden state.
        hyper_in = jnp.concatenate([inputs, h_main], axis=-1)
        (new_c_hyp, new_h_hyp), _ = self.hyper_cell((c_hyp, h_hyp), hyper_in)
        # Base preactivations
        ix = self.ii(inputs)
        ih = self.hi(h_main)
        jx = self.ig(inputs)
        jh = self.hg(h_main)
        fx = self.if_(inputs)
        fh = self.hf(h_main)
        ox = self.io(inputs)
        oh = self.ho(h_main)
        # Modulate with the *new* hyper hidden state. Only the hidden-side
        # contributions get an additive beta shift.
        ix = self._hyper_affine("ix", new_h_hyp, ix, use_bias=False)
        jx = self._hyper_affine("jx", new_h_hyp, jx, use_bias=False)
        fx = self._hyper_affine("fx", new_h_hyp, fx, use_bias=False)
        ox = self._hyper_affine("ox", new_h_hyp, ox, use_bias=False)
        ih = self._hyper_affine("ih", new_h_hyp, ih, use_bias=True)
        jh = self._hyper_affine("jh", new_h_hyp, jh, use_bias=True)
        fh = self._hyper_affine("fh", new_h_hyp, fh, use_bias=True)
        oh = self._hyper_affine("oh", new_h_hyp, oh, use_bias=True)
        # Gates and standard LSTM state update.
        i = self.gate_fn(ix + ih)
        f = self.gate_fn(fx + fh + self.forget_bias)
        g = self.activation_fn(jx + jh)
        o = self.gate_fn(ox + oh)
        new_c_main = f * c_main + i * g
        new_h_main = o * self.activation_fn(new_c_main)
        # Re-pack main and hyper state into the full carry.
        new_total_c = jnp.concatenate([new_c_main, new_c_hyp], axis=-1)
        new_total_h = jnp.concatenate([new_h_main, new_h_hyp], axis=-1)
        new_carry = (new_total_c, new_total_h)
        return new_carry, new_h_main

    @property
    def num_feature_axes(self) -> int:
        """Returns the number of feature axes of the RNN cell."""
        return 1
class BidirectionalLSTMEncoder(nnx.Module):
    """Summarise a sequence into the parameters of a diagonal Gaussian posterior.

    One LSTM reads the sequence left-to-right and a second one reads it
    right-to-left. The last output of each is concatenated and projected to
    the posterior mean and log standard deviation.
    """

    def __init__(
        self,
        in_features: int,
        hidden_size: int = 512,
        latent_size: int = 128,
        *,
        rngs: nnx.Rngs,
    ):
        super().__init__()
        self.hidden_size = hidden_size

        def make_rnn() -> nnx.RNN:
            return nnx.RNN(
                nnx.LSTMCell(
                    in_features=in_features, hidden_features=hidden_size, rngs=rngs
                )
            )

        # Same construction order as before (fwd, bwd, mu head, log-sigma head),
        # so the parameters drawn from ``rngs`` are unchanged.
        self.fwd = make_rnn()
        self.bwd = make_rnn()
        summary_size = 2 * hidden_size
        self.fc_mu = nnx.Linear(summary_size, latent_size, rngs=rngs)
        self.fc_logsig = nnx.Linear(summary_size, latent_size, rngs=rngs)

    def __call__(self, x: jax.Array) -> Tuple[jax.Array, jax.Array, jax.Array]:
        """Encode a batch-first sequence ``x`` of shape ``(B, T, F)``.

        Returns:
            ``(mu, log_sigma, h_final)``. ``h_final`` is the ``(B, 2H)``
            concatenation of the forward and backward final outputs.
        """
        reversed_x = jnp.flip(x, axis=1)
        last_fwd = self.fwd(x)[:, -1]
        # Final step of the reversed pass = backward summary of the whole sequence.
        last_bwd = self.bwd(reversed_x)[:, -1]
        summary = jnp.concatenate([last_fwd, last_bwd], axis=-1)
        return self.fc_mu(summary), self.fc_logsig(summary), summary
class GMMDecoder(nnx.Module):
    """HyperLSTM decoder that emits a bivariate Gaussian mixture per time step.

    For each step and each of ``n_mixtures`` components, the output head
    produces six numbers:
    - a mixture logit;
    - a 2-D mean;
    - a 2-D log standard deviation, clipped to ``[LOG_SIG_MIN, LOG_SIG_MAX]``;
    - a correlation, squashed by ``tanh`` and scaled by ``1 - RHO_EPS``.
    """

    def __init__(
        self,
        in_features: int,
        hidden_size: int = 2048,
        n_mixtures: int = 20,
        *,
        rngs: nnx.Rngs,
    ):
        super().__init__()
        cell = HyperLSTMCell(
            in_features=in_features, hidden_features=hidden_size, rngs=rngs
        )
        self.rnn = nnx.RNN(cell)
        self.n_mixtures = n_mixtures
        # Six raw outputs per component: logit, mu_x, mu_y, log_sig_x, log_sig_y, rho.
        self.fc_out = nnx.Linear(hidden_size, 6 * n_mixtures, rngs=rngs)

    def __call__(self, inputs: jax.Array) -> Tuple[jax.Array, ...]:
        """Decode ``inputs`` of shape ``(B, T, F)`` into mixture parameters.

        Returns:
            ``(pi_logits (B,T,M), mu (B,T,M,2), log_sig (B,T,M,2), rho (B,T,M))``.
        """
        # As constructed here, nnx.RNN returns only the per-step outputs (no
        # carry), so this is a single (B, T, H) array.
        hidden = self.rnn(inputs)
        raw = self.fc_out(hidden)
        batch, steps, _ = raw.shape
        params = raw.reshape(batch, steps, self.n_mixtures, 6)
        pi_logits = params[..., 0]
        mu = params[..., 1:3]
        log_sig = jnp.clip(params[..., 3:5], LOG_SIG_MIN, LOG_SIG_MAX)
        rho = (1.0 - RHO_EPS) * jnp.tanh(params[..., 5])
        return pi_logits, mu, log_sig, rho
@dataclass
class ModelConfig:
    """Hyper-parameters used to build StarSketchVAE."""

    # Nominal sequence length. NOTE(review): not read anywhere in the code
    # visible here — confirm whether a caller relies on it.
    seq_len: int = 100
    # Per-step features; model_loss asserts this is 2.
    data_dim: int = 2 # (Δx, Δt)
    # Hidden size of each direction of the encoder's bidirectional LSTM.
    enc_hidden: int = 512
    # Hidden size of the decoder's HyperLSTM.
    dec_hidden: int = 2048
    # Dimensionality of the latent code z.
    latent_size: int = 128
    # Number of bivariate Gaussian components per decoder step.
    n_mixtures: int = 20
    # Output size of the classifier head. StarSketchVAE builds the classifier
    # Linear regardless of this value.
    num_classes: int = 0 # set >0 to enable classifier head
class StarSketchVAE(nnx.Module):
    """Sequence VAE over (Δx, Δt) steps with a GMM decoder and a latent classifier."""

    def __init__(self, cfg: ModelConfig, *, rngs: nnx.Rngs):
        super().__init__()
        self.cfg = cfg
        # Sub-modules are built in a fixed order (encoder, decoder, classifier),
        # which determines how they consume ``rngs``.
        self.encoder = BidirectionalLSTMEncoder(
            in_features=cfg.data_dim,
            hidden_size=cfg.enc_hidden,
            latent_size=cfg.latent_size,
            rngs=rngs,
        )
        # The decoder sees the previous step concatenated with z at every step.
        decoder_width = cfg.data_dim + cfg.latent_size
        self.decoder = GMMDecoder(
            in_features=decoder_width,
            hidden_size=cfg.dec_hidden,
            n_mixtures=cfg.n_mixtures,
            rngs=rngs,
        )
        self.classifier = nnx.Linear(cfg.latent_size, cfg.num_classes, rngs=rngs)

    def encode(self, x: jax.Array) -> Tuple[jax.Array, jax.Array, jax.Array]:
        """Run the encoder; returns ``(mu, log_sigma, h_final)``."""
        return self.encoder(x)

    def decode(self, prev_xy: jax.Array, z: jax.Array) -> Tuple[jax.Array, ...]:
        """Decode teacher-forced inputs ``prev_xy`` (B, T, 2) conditioned on ``z`` (B, Z)."""
        _, steps, _ = prev_xy.shape
        # Tile z along time: (B, Z) -> (B, T, Z).
        z_tiled = jnp.broadcast_to(
            z[:, None, :], (z.shape[0], steps, z.shape[-1])
        )
        decoder_inputs = jnp.concatenate([prev_xy, z_tiled], axis=-1)  # (B, T, 2+Z)
        return self.decoder(decoder_inputs)
def bivariate_normal_logpdf(
    xy: jax.Array, mu: jax.Array, log_sig: jax.Array, rho: jax.Array
) -> jax.Array:
    """Elementwise log-density of a correlated 2-D Gaussian.

    Args:
        xy: points whose trailing axis holds (x, y).
        mu: means with a trailing (x, y) axis; broadcast against ``xy``.
        log_sig: log standard deviations with a trailing (x, y) axis.
        rho: correlation coefficients (no trailing pair axis).

    Returns:
        ``log N(xy | mu, Sigma)``, with the trailing pair axis reduced away.

    Standard deviations and ``1 - rho**2`` are both floored at 1e-6 so the
    logs and divisions stay finite.
    """
    floor = 1e-6
    sigma = jnp.maximum(jnp.exp(log_sig), floor)
    sigma_x, sigma_y = sigma[..., 0], sigma[..., 1]
    dev_x = (xy[..., 0] - mu[..., 0]) / sigma_x
    dev_y = (xy[..., 1] - mu[..., 1]) / sigma_y
    det_factor = jnp.maximum(1.0 - rho**2, floor)
    log_norm = (
        -jnp.log(2.0 * jnp.pi)
        - jnp.log(sigma_x)
        - jnp.log(sigma_y)
        - 0.5 * jnp.log(det_factor)
    )
    quad_form = dev_x**2 + dev_y**2 - 2.0 * rho * dev_x * dev_y
    return log_norm - 0.5 * (quad_form / det_factor)
def gmm_log_likelihood(
    xy: jax.Array,
    pi_logits: jax.Array,
    mu: jax.Array,
    log_sig: jax.Array,
    rho: jax.Array,
) -> jax.Array:
    """Log-likelihood of each point in ``xy`` under a bivariate Gaussian mixture.

    Any leading shape is supported (typically ``(B, T)``).

    Args:
        xy: observations ``(..., 2)``.
        pi_logits: unnormalised mixture logits ``(..., M)``.
        mu: component means ``(..., M, 2)``.
        log_sig: component log standard deviations ``(..., M, 2)``.
        rho: component correlations ``(..., M)``.

    Returns:
        ``log sum_m softmax(pi_logits)_m * N(xy | mu_m, Sigma_m)``, shape ``(...)``.
    """
    # Insert a component axis so every point is scored against all M components.
    comp_logpdf = bivariate_normal_logpdf(xy[..., None, :], mu, log_sig, rho)  # (..., M)
    log_pi = jax.nn.log_softmax(pi_logits, axis=-1)  # (..., M)
    # Log-sum-exp over mixture components for numerical stability.
    return jax.scipy.special.logsumexp(log_pi + comp_logpdf, axis=-1)
def reconstruction_loss(
    xy: jax.Array,
    pi_logits: jax.Array,
    mu: jax.Array,
    log_sig: jax.Array,
    rho: jax.Array,
    w: jax.Array
) -> jax.Array:
    """Weighted mean negative log-likelihood of ``xy`` under the decoder's mixture.

    ``w`` weights each per-step NLL and must broadcast against the
    ``gmm_log_likelihood`` output (e.g. ``(B, T)``). The result is
    ``sum(w * nll) / (sum(w) + 1e-8)``.
    """
    log_lik = gmm_log_likelihood(xy, pi_logits, mu, log_sig, rho)
    weighted_nll_total = jnp.sum(w * -log_lik)
    weight_total = jnp.sum(w) + 1e-8
    return weighted_nll_total / weight_total
def kl_loss(mu: jax.Array, log_sigma: jax.Array) -> jax.Array:
    """KL( N(mu, sigma^2) || N(0, I) ) for a diagonal Gaussian.

    The KL is summed over the last (latent) axis and averaged over the
    remaining axes.
    """
    log_variance = 2.0 * log_sigma
    per_dim = 0.5 * (jnp.exp(log_variance) + mu**2 - 1.0 - log_variance)
    return per_dim.sum(axis=-1).mean()
def model_loss(
    model: StarSketchVAE,
    batch_xy: jax.Array,
    *,
    class_ids: jax.Array | None = None,
    w_kl: float = 1.0,
    w_sup: float = 1.0,
    rng_key: jax.Array,
    w_close: float = 5.0,
) -> Tuple[jax.Array, dict]:
    """Training objective for one batch.

    total = recon + w_kl * kl + w_sup * sup_ce + w_close * close, where
    - recon is the GMM negative log-likelihood, weighted per step by Δt
      normalised by the batch mean Δt;
    - kl is KL(q(z|x) || N(0, I));
    - sup_ce is the classifier cross-entropy on z. It is computed only when
      ``class_ids`` is given and the model has ``num_classes > 0``; otherwise
      it is 0.
    - close penalises the squared gap between the predicted (mixture-mean)
      and ground-truth total Δx over each sequence.

    Args:
        model: the VAE.
        batch_xy: ``(B, T, 2)`` full teacher-forced sequences of (Δx, Δt).
        class_ids: integer labels ``(B,)``, or ``None`` to skip the supervised term.
        w_kl: weight of the KL term.
        w_sup: weight of the classifier cross-entropy.
        rng_key: key for the reparameterisation noise.
        w_close: weight of the sequence-closure penalty. The default of 5.0 is
            the value previously hard-coded.

    Returns:
        ``(total, metrics)``. ``metrics`` has ``loss``, ``recon``, ``kl`` and
        ``close``, plus ``sup_ce``/``acc`` when the supervised term is active.
    """
    B, _, D = batch_xy.shape
    assert D == 2, "Expecting (Δx, Δt)"

    # Encode and sample z with the reparameterisation trick.
    mu, log_sigma, _ = model.encode(batch_xy)
    eps = jax.random.normal(rng_key, mu.shape)
    z = mu + jnp.exp(log_sigma) * eps

    # Teacher forcing: step t is conditioned on ground-truth step t-1 (zeros at t=0).
    pad0 = jnp.zeros((B, 1, D), dtype=batch_xy.dtype)
    prev_xy = jnp.concatenate([pad0, batch_xy[:, :-1, :]], axis=1)  # (B, T, 2)
    pi_logits, gmm_mu, gmm_logsig, gmm_rho = model.decode(prev_xy, z)

    # Per-step weights: Δt normalised by the batch-mean Δt.
    dt = batch_xy[..., 1]
    w = dt / (jnp.mean(dt) + 1e-8)
    recon = reconstruction_loss(batch_xy, pi_logits, gmm_mu, gmm_logsig, gmm_rho, w)
    kl = kl_loss(mu, log_sigma)

    # Optional supervised head on z. Previously class_ids=None crashed here.
    sup_loss = 0.0
    cls_metrics = {}
    if class_ids is not None and model.cfg.num_classes > 0:
        logits = model.classifier(z)  # (B, C)
        log_probs = jax.nn.log_softmax(logits, axis=-1)
        onehot = jax.nn.one_hot(class_ids, logits.shape[-1])
        sup_loss = jnp.mean(-jnp.sum(onehot * log_probs, axis=-1))
        # argmax on logits directly: softmax is monotone, and this avoids
        # spurious ties from rounding in the probabilities.
        pred = jnp.argmax(logits, axis=-1)
        acc = jnp.mean((pred == class_ids).astype(jnp.float32))
        cls_metrics = {"sup_ce": sup_loss, "acc": acc}

    # Sequence-closure penalty: sum of predicted Δx ≈ sum of ground-truth Δx.
    pi = jax.nn.softmax(pi_logits, axis=-1)[..., None]  # (B, T, M, 1)
    expected_delta = jnp.sum(pi * gmm_mu, axis=-2)  # (B, T, 2)
    x_gt_total = jnp.sum(batch_xy[..., 0], axis=1)  # (B,)
    x_pr_total = jnp.sum(expected_delta[..., 0], axis=1)  # (B,)
    close_loss = jnp.mean((x_pr_total - x_gt_total) ** 2)

    total = recon + w_kl * kl + w_sup * sup_loss + w_close * close_loss
    metrics = {
        "loss": total,
        "recon": recon,
        "kl": kl,
        "close": close_loss,
        **cls_metrics,
    }
    return total, metrics
import jax
import jax.numpy as jnp
# ---------- Classifier direction utilities ----------
def _linear_params(linear) -> tuple[jnp.ndarray, jnp.ndarray | None]:
# Works for nnx.Linear; some variants use 'kernel' not 'weight'
if hasattr(linear, "weight"):
W = linear.weight
elif hasattr(linear, "kernel"):
W = linear.kernel
else:
raise AttributeError("Classifier Linear has no 'weight' or 'kernel'.")
b = getattr(linear, "bias", None)
return W, b
def get_classifier_direction(model: StarSketchVAE,
                             class_id: int,
                             *,
                             normalize: bool = True,
                             margin: bool = False) -> jnp.ndarray:
    """Return a latent-space direction for class ``class_id``.

    Two weight layouts are accepted:
    - ``(C, Z)``: a ``weight``-style layout, one row per class;
    - ``(Z, C)``: an ``nnx.Linear`` ``kernel``, one column per class.
    When ``C == Z`` the shape alone is ambiguous, so the attribute name
    decides. ``weight`` means ``(C, Z)``; anything else (e.g. ``kernel``)
    means ``(Z, C)``.

    Args:
        model: model whose ``classifier`` linear layer defines the directions.
        class_id: target class, ``0 <= class_id < num_classes``.
        normalize: if True, return a unit-norm direction.
        margin: if True and there is more than one class, use the
            mean-difference direction ``w_c - mean(w_other)`` for a stronger
            push. With a single class this falls back to ``w_c`` instead of
            producing NaN.

    Returns:
        ``(Z,)`` direction.
    """
    W, _ = _linear_params(model.classifier)
    W = jnp.asarray(getattr(W, "value", W))
    C = model.cfg.num_classes
    Z = model.cfg.latent_size
    assert 0 <= class_id < C, "class_id out of range"

    # Bring the weights into a rows = classes layout, shape (C, Z).
    if W.shape == (C, Z) and (C != Z or hasattr(model.classifier, "weight")):
        per_class = W
    elif W.shape == (Z, C):
        per_class = W.T
    else:
        raise ValueError(f"Unexpected classifier weight shape {W.shape}; "
                         f"expected (C,Z)=({C},{Z}) or (Z,C)=({Z},{C}).")

    w_c = per_class[class_id]  # (Z,)
    if margin and C > 1:
        others = jnp.arange(C) != class_id
        direction = w_c - jnp.mean(per_class[others], axis=0)
    else:
        direction = w_c
    if normalize:
        direction = direction / (jnp.linalg.norm(direction) + 1e-8)
    return direction  # (Z,)
def tweak_z_towards_class(z: jnp.ndarray,
                          model: StarSketchVAE,
                          class_id: int,
                          *,
                          alpha: float = 3.0,
                          normalize: bool = True,
                          margin: bool = False) -> jnp.ndarray:
    """Shift latent code(s) ``z`` by ``alpha`` along the classifier direction of ``class_id``.

    ``z`` may be a single code ``(Z,)`` or a batch ``(..., Z)``. The ``(Z,)``
    direction broadcasts over any leading axes.
    """
    d = get_classifier_direction(model, class_id, normalize=normalize, margin=margin)
    # Only the latent axis has to match; leading batch axes broadcast.
    assert z.shape[-1:] == d.shape, f"direction {d.shape} must match z {z.shape}"
    return z + alpha * d
# ---------- GMM sampling helpers ----------
def _sample_bivariate_from_params(mu: jnp.ndarray,
log_sig: jnp.ndarray,
rho: jnp.ndarray,
rng_key: jax.Array,
temperature: float = 1.0) -> jnp.ndarray:
sig = jnp.exp(log_sig) * jnp.sqrt(temperature)
rho = jnp.clip(rho, -0.999, 0.999)
u_key, v_key = jax.random.split(rng_key)
u = jax.random.normal(u_key)
v = jax.random.normal(v_key)
# correlated 2D normal
dx = sig[0] * u
dy = sig[1] * (rho * u + jnp.sqrt(1.0 - rho**2) * v)
return mu + jnp.array([dx, dy])
def _mixture_expectation(pi_logits_t: jnp.ndarray,
mu_t: jnp.ndarray,
temperature: float = 1.0) -> jnp.ndarray:
pi = jax.nn.softmax(pi_logits_t / temperature, axis=-1) # (M,)
return jnp.sum(pi[:, None] * mu_t, axis=0) # (2,)
def _mixture_sample(pi_logits_t: jnp.ndarray,
                    mu_t: jnp.ndarray,
                    log_sig_t: jnp.ndarray,
                    rho_t: jnp.ndarray,
                    rng_key: jax.Array,
                    temperature: float = 1.0) -> jnp.ndarray:
    """Draw one (x, y) sample from a temperature-scaled bivariate Gaussian mixture.

    Args:
        pi_logits_t: ``(M,)`` component logits, divided by ``temperature``.
        mu_t: ``(M, 2)`` component means.
        log_sig_t: ``(M, 2)`` component log standard deviations.
        rho_t: ``(M,)`` component correlations.
        rng_key: PRNG key. It is split into a component-choice key and a
            Gaussian key.
        temperature: sharpens/flattens the component choice and scales the
            sampled component's variance.

    Returns:
        ``(2,)`` sample.
    """
    comp_key, gauss_key = jax.random.split(rng_key)
    # categorical takes unnormalised log-probabilities. Passing the scaled
    # logits avoids the softmax -> log round trip, where underflowing
    # probabilities become log(0) = -inf.
    k = jax.random.categorical(comp_key, pi_logits_t / temperature, axis=-1)
    return _sample_bivariate_from_params(mu_t[k], log_sig_t[k], rho_t[k], gauss_key, temperature)
# ---------- Autoregressive decoding ----------
def decode_from_z(model: StarSketchVAE,
                  z: jnp.ndarray,
                  *,
                  steps: int,
                  method: str = "expected",  # "expected" or "sample"
                  temperature: float = 1.0,
                  rng_key: jax.Array | None = None) -> jnp.ndarray:
    """Autoregressively decode ``steps`` (Δx, Δt) deltas from a single latent ``z``.

    Each iteration re-runs the decoder on a fixed-length ``(1, steps, 2)``
    buffer of previous outputs and reads the mixture at index ``t``. Rows not
    yet filled are zero. The decoder RNN runs left to right, so those future
    rows cannot influence step ``t``. Keeping one buffer shape avoids
    re-tracing per length, at the cost of O(steps**2) decoder steps overall.

    Args:
        model: trained VAE.
        z: ``(Z,)`` latent code.
        steps: number of deltas to generate (0 yields an empty ``(0, 2)`` array).
        method: ``"expected"`` uses the mixture mean; ``"sample"`` draws from
            the mixture.
        temperature: passed to the expectation/sampling helpers.
        rng_key: key for ``"sample"``. Defaults to ``PRNGKey(0)`` when omitted.

    Returns:
        ``(steps, 2)`` array of generated deltas.

    Raises:
        ValueError: for an unknown ``method`` or negative ``steps``.
    """
    assert z.ndim == 1, "z must be shape (Z,)"
    if method not in ("expected", "sample"):
        # Previously any unknown value silently fell through to sampling.
        raise ValueError(f"method must be 'expected' or 'sample', got {method!r}")
    if steps < 0:
        raise ValueError(f"steps must be non-negative, got {steps}")
    D = 2
    if steps == 0:
        return jnp.zeros((0, D), dtype=jnp.float32)
    if method == "sample" and rng_key is None:
        rng_key = jax.random.PRNGKey(0)

    # Running buffer of previous outputs; row 0 stays zero (the start token).
    prev = jnp.zeros((steps, D), dtype=jnp.float32)
    zB = z[None, :]  # (1, Z)
    out = []
    for t in range(steps):
        pi_logits, gmm_mu, gmm_logsig, gmm_rho = model.decode(prev[None, :, :], zB)
        pi_t = pi_logits[0, t]  # (M,)
        mu_t = gmm_mu[0, t]  # (M, 2)
        if method == "expected":
            delta = _mixture_expectation(pi_t, mu_t, temperature=temperature)
        else:
            rng_key, step_key = jax.random.split(rng_key)
            delta = _mixture_sample(pi_t, mu_t, gmm_logsig[0, t], gmm_rho[0, t],
                                    step_key, temperature=temperature)
        out.append(delta)
        if t + 1 < steps:
            prev = prev.at[t + 1].set(delta)
    return jnp.stack(out, axis=0)  # (steps, 2)
# ---------- End-to-end: sample z, push toward class, decode ----------
def class_conditioned_random_curve(model: StarSketchVAE,
                                   class_id: int,
                                   *,
                                   steps: int = 100,
                                   z_scale: float = 1.0,
                                   alpha: float = 3.0,
                                   normalize_dir: bool = True,
                                   margin: bool = False,
                                   method: str = "expected",  # "expected" or "sample"
                                   temperature: float = 0.9,
                                   rng_key: jax.Array | None = None) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """Sample z ~ N(0, z_scale^2 I), push it toward ``class_id``, and decode a curve.

    Returns:
        ``(deltas, x, t)``:
        - ``deltas`` is the decoded ``(steps, 2)`` (Δx, Δt) array, with Δt
          floored at 1e-4;
        - ``x`` is the cumulative sum of the Δx column;
        - ``t`` is the cumulative sum of the Δt column.
    """
    key = jax.random.PRNGKey(42) if rng_key is None else rng_key
    z_key, dec_key = jax.random.split(key)

    latent = jax.random.normal(z_key, (model.cfg.latent_size,)) * z_scale  # (Z,)
    latent = tweak_z_towards_class(latent, model, class_id, alpha=alpha,
                                   normalize=normalize_dir, margin=margin)
    deltas = decode_from_z(model, latent, steps=steps, method=method,
                           temperature=temperature, rng_key=dec_key)

    # Floor the time steps so the cumulative time axis keeps increasing.
    deltas = deltas.at[:, 1].set(jnp.maximum(deltas[:, 1], 1e-4))
    x = jnp.cumsum(deltas[:, 0])
    t = jnp.cumsum(deltas[:, 1])
    return deltas, x, t
if __name__ == "__main__":
    # Smoke test: build the model, evaluate the loss once on a dummy batch,
    # and print the metrics.
    demo_cfg = ModelConfig(
        seq_len=100,
        data_dim=2,
        enc_hidden=512,
        dec_hidden=2048,
        latent_size=128,
        n_mixtures=20,
        num_classes=5,
    )
    demo_model = StarSketchVAE(demo_cfg, rngs=nnx.Rngs(0))
    n_seq, n_steps, n_feat = 6, 100, 2
    demo_batch = jnp.ones((n_seq, n_steps, n_feat))
    # Every sequence carries a label (no unlabeled samples).
    demo_labels = jnp.array([0, 1, 2, 3, 4, 0])
    _, demo_metrics = model_loss(
        demo_model,
        demo_batch,
        class_ids=demo_labels,
        w_kl=1.0,
        w_sup=1.0,
        rng_key=jax.random.PRNGKey(42),
    )
    print({name: float(value) for name, value in demo_metrics.items()})
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "864eae34-0cc9-484f-ae98-7615e8e7c2bd",
"metadata": {},
"outputs": [],
"source": [
"%load_ext autoreload\n",
"%autoreload 2"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "b6f06154-72d9-49cb-87c3-761395b08b3a",
"metadata": {},
"outputs": [],
"source": [
"from star_sketch_rnn import *"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "183f7b53-0541-461d-9e43-6fa430f3a0e2",
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"import numpy as np\n",
"import jax\n",
"import jax.numpy as jnp\n",
"import optax\n",
"from flax import nnx"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "1c983116-8feb-4e3c-a267-d1328be8bc33",
"metadata": {},
"outputs": [],
"source": [
"df = pd.read_parquet(\"fit_stars.parquet\").sample(frac=1)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "18f3b22c-a56e-4ebc-b6c4-c9e65d2930dc",
"metadata": {},
"outputs": [],
"source": [
"def to_diff(arr):\n",
" return np.diff(arr, axis=1, prepend=np.zeros((arr.shape[0], 1), dtype=arr.dtype)) + arr[:, :1]"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "a5916a08-ca99-47c5-9599-a6d35fca1bd5",
"metadata": {},
"outputs": [],
"source": [
"mag = to_diff(np.stack([x for x in df[\"y_1P\"].to_numpy()])) "
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "978ac2f1-fa50-42ff-b7d4-694e8883fe73",
"metadata": {},
"outputs": [],
"source": [
"filt = ~(mag >= 0.01).any(axis=1)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "4d91d68e-f24e-443a-8e21-5c57237f50f6",
"metadata": {},
"outputs": [],
"source": [
"time = to_diff(np.stack([x for x in df[\"t_days_1P\"].to_numpy()]))"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "f5c3d6cd-41fc-4890-ac51-6cba7cbbdd1c",
"metadata": {},
"outputs": [],
"source": [
"types = df[\"type\"].unique().tolist()\n",
"classes = df[\"type\"].apply(lambda x: types.index(x)).to_numpy()"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "374dbe63-ff75-4afb-a9d1-5213498e311b",
"metadata": {},
"outputs": [],
"source": [
"train = (mag[filt][:65_000], time[filt][:65_000], classes[filt][:65_000])"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "e6bd720b-59c4-40ec-8096-79d913885c2a",
"metadata": {},
"outputs": [],
"source": [
"test = (mag[filt][65_000:], time[filt][65_000:], classes[filt][65_000:])"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "49b46bd4-8eda-4ff2-98d5-9b6b0e4f22ba",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"12"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"len(types)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "18661700-d536-490f-98db-006c8c42794b",
"metadata": {},
"outputs": [],
"source": [
"rngs = nnx.Rngs(0)\n",
"cfg = ModelConfig(\n",
" seq_len=100,\n",
" data_dim=2,\n",
" enc_hidden=512,\n",
" dec_hidden=2048,\n",
" latent_size=128,\n",
" n_mixtures=20,\n",
" num_classes=12,\n",
")\n",
"model = StarSketchVAE(cfg, rngs=rngs)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "5ac5fc7e-ef52-4418-bb31-9d65b250dd02",
"metadata": {},
"outputs": [],
"source": [
"lr = 3e-4\n",
"tx = optax.chain(\n",
" optax.clip_by_global_norm(1.0),\n",
" optax.adam(lr),\n",
")\n",
"\n",
"opt = nnx.Optimizer(model, tx)"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "fe26ed0d-93d2-4229-8a47-8b61d70cfe03",
"metadata": {},
"outputs": [],
"source": [
"@nnx.jit\n",
"def train_step(model: nnx.Module,\n",
" opt: nnx.Optimizer,\n",
" batch_xy: jax.Array, # (B, T=100, 2) -> (Δx, Δt)\n",
" class_ids: jax.Array, # (B,)\n",
" kl_weight: float,\n",
" sup_weight: float,\n",
" key: jax.Array):\n",
" \"\"\"Returns (updated_model, updated_opt, metrics).\"\"\"\n",
"\n",
" def loss_fn(m):\n",
" loss, metrics = model_loss(\n",
" m,\n",
" batch_xy,\n",
" class_ids=class_ids,\n",
" w_kl=kl_weight,\n",
" w_sup=sup_weight,\n",
" rng_key=key,\n",
" )\n",
" return loss, metrics\n",
"\n",
" (loss, metrics), grads = nnx.value_and_grad(loss_fn, has_aux=True)(model)\n",
" opt.update(grads)\n",
" #opt.apply_updates(model)\n",
" return model, opt, metrics"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "469358a6-9c2b-4f52-8cc7-9f8970f1479a",
"metadata": {},
"outputs": [],
"source": [
"@nnx.jit\n",
"def eval_step(model: nnx.Module,\n",
" batch_xy: jax.Array,\n",
" class_ids: jax.Array,\n",
" key: jax.Array):\n",
" loss, metrics = model_loss(\n",
" model,\n",
" batch_xy,\n",
" class_ids=class_ids,\n",
" w_kl=1.0,\n",
" w_sup=1.0,\n",
" rng_key=key,\n",
" )\n",
" return metrics"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "9d58ed35-58f6-4744-b108-8cd3202a80f9",
"metadata": {},
"outputs": [],
"source": [
"def iterate_minibatches(X_mag, X_time, y_cls, batch_size, shuffle=True):\n",
" N = X_mag.shape[0]\n",
" idx = np.arange(N)\n",
" if shuffle:\n",
" np.random.shuffle(idx)\n",
" for start in range(0, N, batch_size):\n",
" end = min(start + batch_size, N)\n",
" sel = idx[start:end]\n",
" batch = np.stack((X_mag[sel], X_time[sel]), axis=-1) # (B,T,2) in ORIGINAL units\n",
" #batch = batch - batch.mean(axis=1, keepdims=True)\n",
" yield batch, y_cls[sel]"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "c949700f-7e08-4a98-8f8e-bf85f514b878",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"step 0 | train_loss=46.5021 recon=1.9357 kl=0.1865 sup_ce=2.9643 acc=0.141 | eval_loss=37.5979 eval_acc=0.117\n",
"step 100 | train_loss=1.2059 recon=-1.5797 kl=1.2156 sup_ce=2.1404 acc=0.312 | eval_loss=3.5798 eval_acc=0.234\n",
"step 200 | train_loss=-0.2314 recon=-2.2060 kl=0.6651 sup_ce=2.1743 acc=0.289 | eval_loss=1.2422 eval_acc=0.289\n",
"step 300 | train_loss=-1.2403 recon=-3.0053 kl=0.5497 sup_ce=2.0626 acc=0.312 | eval_loss=0.3983 eval_acc=0.328\n",
"step 400 | train_loss=-1.6501 recon=-3.4225 kl=0.5933 sup_ce=2.0357 acc=0.359 | eval_loss=-0.0955 eval_acc=0.352\n",
"step 500 | train_loss=-2.2685 recon=-3.8309 kl=0.4807 sup_ce=1.8648 acc=0.336 | eval_loss=-0.5499 eval_acc=0.320\n",
"epoch 1 | train_loss=-2.2131 recon=-3.9231 kl=0.4868 sup_ce=1.9915 acc=0.356 | eval_loss=-0.7181 eval_acc=0.336\n",
"step 0 | train_loss=-1.7936 recon=-3.7092 kl=0.5414 sup_ce=1.9363 acc=0.383 | eval_loss=-0.6880 eval_acc=0.336\n",
"step 100 | train_loss=-2.3393 recon=-3.9952 kl=0.3958 sup_ce=1.8324 acc=0.461 | eval_loss=-0.8823 eval_acc=0.344\n",
"step 200 | train_loss=-2.2735 recon=-4.0245 kl=0.4805 sup_ce=1.7829 acc=0.383 | eval_loss=-1.3115 eval_acc=0.406\n",
"step 300 | train_loss=-2.5875 recon=-4.1634 kl=0.4761 sup_ce=1.6604 acc=0.422 | eval_loss=-1.3304 eval_acc=0.336\n",
"step 400 | train_loss=-2.4696 recon=-4.0756 kl=0.3896 sup_ce=1.8260 acc=0.383 | eval_loss=-1.5577 eval_acc=0.359\n",
"step 500 | train_loss=-2.7015 recon=-4.2522 kl=0.3790 sup_ce=1.7719 acc=0.367 | eval_loss=-1.7647 eval_acc=0.438\n",
"epoch 2 | train_loss=-2.9328 recon=-4.4264 kl=0.3949 sup_ce=1.6880 acc=0.394 | eval_loss=-1.6857 eval_acc=0.367\n",
"step 0 | train_loss=-2.5810 recon=-4.3915 kl=0.3828 sup_ce=1.8710 acc=0.398 | eval_loss=-1.6001 eval_acc=0.391\n",
"step 100 | train_loss=-2.5038 recon=-4.1318 kl=0.3346 sup_ce=1.6240 acc=0.422 | eval_loss=-1.7121 eval_acc=0.344\n",
"step 200 | train_loss=-2.9047 recon=-4.5835 kl=0.3615 sup_ce=1.7179 acc=0.430 | eval_loss=-1.9314 eval_acc=0.375\n",
"step 300 | train_loss=-2.8661 recon=-4.4328 kl=0.3773 sup_ce=1.5870 acc=0.406 | eval_loss=-2.1180 eval_acc=0.328\n",
"step 400 | train_loss=-3.0018 recon=-4.6010 kl=0.2515 sup_ce=1.7208 acc=0.438 | eval_loss=-2.2898 eval_acc=0.328\n",
"step 500 | train_loss=-3.1853 recon=-4.7219 kl=0.2917 sup_ce=1.6104 acc=0.430 | eval_loss=-2.5992 eval_acc=0.492\n",
"epoch 3 | train_loss=-3.1608 recon=-4.7712 kl=0.3150 sup_ce=1.6604 acc=0.452 | eval_loss=-2.5221 eval_acc=0.406\n",
"step 0 | train_loss=-2.9159 recon=-4.7732 kl=0.3274 sup_ce=1.7308 acc=0.438 | eval_loss=-2.5389 eval_acc=0.445\n",
"step 100 | train_loss=-2.9882 recon=-4.7597 kl=0.2422 sup_ce=1.7119 acc=0.398 | eval_loss=-2.6269 eval_acc=0.430\n",
"step 200 | train_loss=-2.9619 recon=-4.6602 kl=0.2416 sup_ce=1.6127 acc=0.469 | eval_loss=-2.6653 eval_acc=0.359\n",
"step 300 | train_loss=-3.1689 recon=-4.7695 kl=0.2065 sup_ce=1.5602 acc=0.477 | eval_loss=-2.8632 eval_acc=0.492\n",
"step 400 | train_loss=-3.1341 recon=-4.8682 kl=0.2233 sup_ce=1.6918 acc=0.414 | eval_loss=-3.0476 eval_acc=0.477\n",
"step 500 | train_loss=-3.5936 recon=-5.1131 kl=0.1969 sup_ce=1.4729 acc=0.500 | eval_loss=-3.1354 eval_acc=0.414\n",
"epoch 4 | train_loss=-3.2947 recon=-4.9280 kl=0.1800 sup_ce=1.6238 acc=0.471 | eval_loss=-2.8099 eval_acc=0.430\n",
"step 0 | train_loss=-2.7201 recon=-4.4851 kl=0.2572 sup_ce=1.5305 acc=0.469 | eval_loss=-3.0156 eval_acc=0.430\n",
"step 100 | train_loss=-3.2445 recon=-4.9394 kl=0.1635 sup_ce=1.5447 acc=0.469 | eval_loss=-3.0728 eval_acc=0.430\n",
"step 200 | train_loss=-3.3980 recon=-5.0827 kl=0.1805 sup_ce=1.4927 acc=0.484 | eval_loss=-3.2222 eval_acc=0.383\n",
"step 300 | train_loss=-3.4169 recon=-5.0997 kl=0.1672 sup_ce=1.4804 acc=0.453 | eval_loss=-3.1448 eval_acc=0.391\n",
"step 400 | train_loss=-3.8233 recon=-5.3426 kl=0.1276 sup_ce=1.4011 acc=0.469 | eval_loss=-3.3062 eval_acc=0.414\n",
"step 500 | train_loss=-3.3796 recon=-4.9774 kl=0.1332 sup_ce=1.4867 acc=0.484 | eval_loss=-3.3846 eval_acc=0.461\n",
"epoch 5 | train_loss=-3.2788 recon=-4.8331 kl=0.1072 sup_ce=1.4674 acc=0.529 | eval_loss=-3.3734 eval_acc=0.461\n",
"step 0 | train_loss=-3.2406 recon=-4.8606 kl=0.1120 sup_ce=1.5181 acc=0.406 | eval_loss=-3.3314 eval_acc=0.414\n",
"step 100 | train_loss=-3.8496 recon=-5.3313 kl=0.1010 sup_ce=1.3500 acc=0.516 | eval_loss=-3.4509 eval_acc=0.414\n",
"step 200 | train_loss=-3.5082 recon=-5.2396 kl=0.0783 sup_ce=1.6130 acc=0.414 | eval_loss=-3.6469 eval_acc=0.445\n",
"step 300 | train_loss=-3.6952 recon=-5.2373 kl=0.0657 sup_ce=1.4466 acc=0.492 | eval_loss=-3.5869 eval_acc=0.430\n",
"step 400 | train_loss=-3.8266 recon=-5.3385 kl=0.0632 sup_ce=1.4361 acc=0.484 | eval_loss=-3.7103 eval_acc=0.445\n",
"step 500 | train_loss=-3.8763 recon=-5.4211 kl=0.0506 sup_ce=1.4706 acc=0.453 | eval_loss=-3.7715 eval_acc=0.461\n",
"epoch 6 | train_loss=-3.7528 recon=-5.3057 kl=0.0573 sup_ce=1.4602 acc=0.433 | eval_loss=-3.6639 eval_acc=0.398\n",
"step 0 | train_loss=-4.0338 recon=-5.5182 kl=0.0532 sup_ce=1.3976 acc=0.508 | eval_loss=-3.7643 eval_acc=0.469\n",
"step 100 | train_loss=-4.1112 recon=-5.5335 kl=0.0397 sup_ce=1.3796 acc=0.477 | eval_loss=-3.7793 eval_acc=0.414\n",
"step 200 | train_loss=-3.8078 recon=-5.2511 kl=0.0304 sup_ce=1.4111 acc=0.500 | eval_loss=-3.9057 eval_acc=0.453\n",
"step 300 | train_loss=-3.9710 recon=-5.3875 kl=0.0232 sup_ce=1.3870 acc=0.508 | eval_loss=-3.9297 eval_acc=0.461\n",
"step 400 | train_loss=-3.8938 recon=-5.4326 kl=0.0229 sup_ce=1.4653 acc=0.484 | eval_loss=-3.8819 eval_acc=0.422\n",
"step 500 | train_loss=-4.1468 recon=-5.5036 kl=0.0211 sup_ce=1.3285 acc=0.547 | eval_loss=-4.0318 eval_acc=0.453\n",
"epoch 7 | train_loss=-4.2213 recon=-5.6256 kl=0.0188 sup_ce=1.3808 acc=0.433 | eval_loss=-4.1490 eval_acc=0.539\n",
"step 0 | train_loss=-4.2351 recon=-5.6195 kl=0.0178 sup_ce=1.3652 acc=0.484 | eval_loss=-3.9614 eval_acc=0.430\n",
"step 100 | train_loss=-4.0468 recon=-5.4817 kl=0.0172 sup_ce=1.3736 acc=0.523 | eval_loss=-4.0817 eval_acc=0.438\n",
"step 200 | train_loss=-4.4394 recon=-5.7968 kl=0.0157 sup_ce=1.3208 acc=0.578 | eval_loss=-4.1147 eval_acc=0.414\n",
"step 300 | train_loss=-4.3627 recon=-5.7794 kl=0.0176 sup_ce=1.3668 acc=0.461 | eval_loss=-4.2116 eval_acc=0.500\n",
"step 400 | train_loss=-4.0753 recon=-5.4199 kl=0.0134 sup_ce=1.3217 acc=0.508 | eval_loss=-4.1605 eval_acc=0.438\n",
"step 500 | train_loss=-4.3648 recon=-5.7506 kl=0.0159 sup_ce=1.3569 acc=0.453 | eval_loss=-4.2464 eval_acc=0.508\n",
"epoch 8 | train_loss=-4.4228 recon=-5.8212 kl=0.0135 sup_ce=1.3747 acc=0.510 | eval_loss=-4.1181 eval_acc=0.500\n",
"step 0 | train_loss=-4.2178 recon=-5.4982 kl=0.0133 sup_ce=1.2396 acc=0.594 | eval_loss=-4.1389 eval_acc=0.469\n",
"step 100 | train_loss=-4.4617 recon=-5.8206 kl=0.0100 sup_ce=1.3203 acc=0.547 | eval_loss=-4.2829 eval_acc=0.430\n",
"step 200 | train_loss=-4.2466 recon=-5.5921 kl=0.0101 sup_ce=1.3079 acc=0.516 | eval_loss=-4.3115 eval_acc=0.484\n",
"step 300 | train_loss=-4.4373 recon=-5.7913 kl=0.0083 sup_ce=1.3420 acc=0.516 | eval_loss=-4.2483 eval_acc=0.469\n",
"step 400 | train_loss=-4.2794 recon=-5.6517 kl=0.0092 sup_ce=1.3232 acc=0.562 | eval_loss=-4.1492 eval_acc=0.492\n",
"step 500 | train_loss=-4.5504 recon=-5.8806 kl=0.0054 sup_ce=1.2962 acc=0.586 | eval_loss=-4.4070 eval_acc=0.461\n",
"epoch 9 | train_loss=-4.5222 recon=-5.8240 kl=0.0076 sup_ce=1.2841 acc=0.519 | eval_loss=-4.4066 eval_acc=0.461\n",
"step 0 | train_loss=-4.3504 recon=-5.7417 kl=0.0077 sup_ce=1.3783 acc=0.531 | eval_loss=-4.3692 eval_acc=0.469\n",
"step 100 | train_loss=-4.3907 recon=-5.7698 kl=0.0068 sup_ce=1.3325 acc=0.539 | eval_loss=-4.3803 eval_acc=0.477\n",
"step 200 | train_loss=-4.5646 recon=-5.9375 kl=0.0048 sup_ce=1.3590 acc=0.492 | eval_loss=-4.4658 eval_acc=0.469\n",
"step 300 | train_loss=-4.5880 recon=-5.9255 kl=0.0047 sup_ce=1.2945 acc=0.508 | eval_loss=-4.3337 eval_acc=0.461\n",
"step 400 | train_loss=-4.7077 recon=-5.9371 kl=0.0052 sup_ce=1.2071 acc=0.547 | eval_loss=-4.5713 eval_acc=0.461\n",
"step 500 | train_loss=-4.7069 recon=-5.9930 kl=0.0041 sup_ce=1.2604 acc=0.562 | eval_loss=-4.4489 eval_acc=0.469\n",
"epoch 10 | train_loss=-4.5853 recon=-5.8341 kl=0.0040 sup_ce=1.2254 acc=0.490 | eval_loss=-4.2734 eval_acc=0.445\n"
]
}
],
"source": [
"epochs = 10\n",
"batch_size = 128\n",
"kl_weight = 1.0\n",
"sup_weight = 1.0\n",
"\n",
"rng = jax.random.key(0)\n",
"\n",
"for epoch in range(1, epochs + 1):\n",
" N = train[0].shape[0]/batch_size\n",
" # ---- Train\n",
" for i, dat in enumerate(iterate_minibatches(\n",
" train[0], train[1], train[2], batch_size, shuffle=True\n",
" )):\n",
" batch_xy_np, class_ids_np = dat\n",
" \n",
" mult_ce = min(1.0, 0.5+1.0*(epoch/epochs))\n",
" mult_kl = min(1.0, 0.3+1.0*(epoch/epochs))\n",
" rng, sub = jax.random.split(rng)\n",
" batch_xy = jnp.asarray(batch_xy_np)\n",
" class_ids = jnp.asarray(class_ids_np)\n",
" model, opt, metrics = train_step(\n",
" model, opt, batch_xy, class_ids, mult_kl*kl_weight, mult_ce*sup_weight, sub\n",
" )\n",
" if i % 100 == 0:\n",
" eval_batch = np.stack((test[0][:batch_size], test[1][:batch_size]), axis=-1)\n",
" eval_classes = test[2][:batch_size]\n",
" rng, sub = jax.random.split(rng)\n",
" eval_metrics = eval_step(\n",
" model, jnp.asarray(eval_batch), jnp.asarray(eval_classes), sub\n",
" )\n",
" print(f\"step {i} | \"\n",
" f\"train_loss={float(metrics['loss']):.4f} \"\n",
" f\"recon={float(metrics['recon']):.4f} \"\n",
" f\"kl={float(metrics['kl']):.4f} \"\n",
" f\"sup_ce={float(metrics.get('sup_ce', 0.0)):.4f} \"\n",
" f\"acc={float(metrics.get('acc', 0.0)):.3f} | \"\n",
" f\"eval_loss={float(eval_metrics['loss']):.4f} \"\n",
" f\"eval_acc={float(eval_metrics.get('acc', 0.0)):.3f}\")\n",
"\n",
" eval_batch = np.stack((test[0][:batch_size], test[1][:batch_size]), axis=-1)\n",
"\n",
" eval_classes = test[2][:batch_size]\n",
" rng, sub = jax.random.split(rng)\n",
" eval_metrics = eval_step(\n",
" model, jnp.asarray(eval_batch), jnp.asarray(eval_classes), sub\n",
" )\n",
"\n",
" print(f\"epoch {epoch} | \"\n",
" f\"train_loss={float(metrics['loss']):.4f} \"\n",
" f\"recon={float(metrics['recon']):.4f} \"\n",
" f\"kl={float(metrics['kl']):.4f} \"\n",
" f\"sup_ce={float(metrics.get('sup_ce', 0.0)):.4f} \"\n",
" f\"acc={float(metrics.get('acc', 0.0)):.3f} | \"\n",
" f\"eval_loss={float(eval_metrics['loss']):.4f} \"\n",
" f\"eval_acc={float(eval_metrics.get('acc', 0.0)):.3f}\")"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "1b3d08c6-13fe-402e-b98c-e9364b9d9d43",
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"\n",
"\n",
"def deltas_to_curve(deltas: jnp.ndarray) -> tuple[jnp.ndarray, jnp.ndarray]:\n",
" \"\"\"Convert (Δx, Δt) of shape (T, 2) to cumulative (x(t), t).\n",
" Returns (x, t) each shape (T,).\n",
" \"\"\"\n",
" dx = deltas[:, 0]\n",
" dt = deltas[:, 1]\n",
" t = jnp.cumsum(dt)\n",
" x = jnp.cumsum(dx)\n",
" return x, t\n",
"\n",
"\n",
"def reconstruct_deltas(model: StarSketchVAE,\n",
" seq: jnp.ndarray,\n",
" *,\n",
" method: str = \"expected\",\n",
" temperature: float = 0.8,\n",
" rng_key: jax.Array | None = None) -> jnp.ndarray:\n",
" \"\"\"Teacher-forced reconstruction of a single sequence (T,2).\n",
" \"\"\"\n",
" assert seq.ndim == 2 and seq.shape[-1] == 2\n",
" B = 1\n",
" T = seq.shape[0]\n",
" batch = seq[None, ...]\n",
" mu, log_sigma, enc_feats = model.encode(batch)\n",
" z = mu # use mean for deterministic reconstruction\n",
"\n",
" pad0 = jnp.zeros((B, 1, 2), dtype=batch.dtype)\n",
" prev_xy = jnp.concatenate([pad0, batch[:, :-1, :]], axis=1)\n",
" pi_logits, gmm_mu, gmm_logsig, gmm_rho = model.decode(prev_xy, z)\n",
"\n",
" if method == \"expected\":\n",
" # Expected value of mixture: sum_k softmax(pi) * mu_k\n",
" pi = jax.nn.softmax(pi_logits, axis=-1) # (1, T, M)\n",
" ex = jnp.sum(pi[..., None] * gmm_mu, axis=-2) # (1, T, 2)\n",
" return ex[0]\n",
" elif method == \"sample\":\n",
" if rng_key is None:\n",
" rng_key = jax.random.PRNGKey(0)\n",
" return sample_sequence(model, z, steps=T, temperature=temperature, rng_key=rng_key)[0]\n",
"\n",
"def plot_input_and_recon(seq: jnp.ndarray,\n",
" recon: jnp.ndarray,\n",
" *,\n",
" title_input: str = \"Input lightcurve\",\n",
" title_recon: str = \"Reconstruction\") -> None:\n",
" \"\"\"Plot input and reconstructed lightcurves as two separate figures.\n",
" \"\"\"\n",
" x_in, t_in = deltas_to_curve(seq)\n",
" plt.figure()\n",
" plt.plot(t_in, x_in)\n",
" plt.xlabel(\"t\"); plt.ylabel(\"x\"); plt.title(title_input)\n",
"\n",
" x_rec, t_rec = deltas_to_curve(recon)\n",
" plt.figure()\n",
" plt.plot(t_in, x_rec)\n",
" plt.xlabel(\"t\"); plt.ylabel(\"x\"); plt.title(title_recon)\n",
"\n",
"\n",
"def compare_lightcurves(seq: jnp.ndarray,\n",
" recon: jnp.ndarray,\n",
" *,\n",
" title: str = \"Input vs Reconstruction\") -> None:\n",
" x_in, t_in = deltas_to_curve(seq)\n",
" x_rec, t_rec = deltas_to_curve(recon)\n",
" plt.figure()\n",
" plt.plot(t_in, x_in, label=\"input\")\n",
" plt.plot(t_in, x_rec, label=\"recon\")\n",
" plt.xlabel(\"t\"); plt.ylabel(\"x\"); plt.title(title)\n",
" plt.legend()"
]
},
{
"cell_type": "code",
"execution_count": 22,
"id": "55233b2f-32e3-4f88-b9b9-cc1ef0640e24",
"metadata": {},
"outputs": [],
"source": [
"test_xy = np.stack((test[0], test[1]), axis=-1)\n",
"test_classes = test[2]"
]
},
{
"cell_type": "code",
"execution_count": 30,
"id": "5b158cd5-687a-4ffc-857d-9fbbe68ff6f4",
"metadata": {},
"outputs": [],
"source": [
"seq = test_xy[500]"
]
},
{
"cell_type": "code",
"execution_count": 34,
"id": "f5454941-29cd-47dd-9fc0-1280e793388f",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(4288, 100, 2)"
]
},
"execution_count": 34,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"test_xy.shape"
]
},
{
"cell_type": "markdown",
"id": "2b29f2b4-ee68-4939-92a1-9de519cd70c3",
"metadata": {},
"source": [
"### We learn the rough shape/inflection points of lightcurves--but have drift issues (maybe switching to non-importance weighted sampling of the input data would fix or different regularizers)"
]
},
{
"cell_type": "code",
"execution_count": 32,
"id": "17bb8471-4640-41e0-b71c-21913ddc5a2e",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"recon = reconstruct_deltas(model, seq, method=\"expected\")\n",
"compare_lightcurves(test_xy[500], recon)"
]
},
{
"cell_type": "code",
"execution_count": 55,
"id": "15e13353-e166-4fa2-b806-a0b5437034be",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"recon = reconstruct_deltas(model, seq, method=\"expected\")\n",
"compare_lightcurves(test_xy[1490], recon)"
]
},
{
"cell_type": "code",
"execution_count": 83,
"id": "9564e43a-5d09-4d1a-a2bc-f0dc2ea7e665",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"recon = reconstruct_deltas(model, seq, method=\"expected\")\n",
"compare_lightcurves(test_xy[2015], recon)"
]
},
{
"cell_type": "code",
"execution_count": 56,
"id": "6d680d49-8f0e-4ee2-9cdd-c88eb60869c9",
"metadata": {},
"outputs": [],
"source": [
"from star_sketch_rnn import *\n",
"def plot_class_conditioned_curve(model: StarSketchVAE,\n",
" class_id: int,\n",
" *,\n",
" steps: int = 100,\n",
" method: str = \"expected\",\n",
" temperature: float = 0.9,\n",
" alpha: float = 3.0,\n",
" normalize_dir: bool = True,\n",
" margin: bool = False,\n",
" rng_key: jax.Array | None = None,\n",
" title: str | None = None):\n",
" import matplotlib.pyplot as plt\n",
" deltas, x, t = class_conditioned_random_curve(\n",
" model, class_id, steps=steps, method=method, temperature=temperature,\n",
" alpha=alpha, normalize_dir=normalize_dir, margin=margin, rng_key=rng_key\n",
" )\n",
" if title is None:\n",
" title = f\"class {class_id} | method={method} temp={temperature} alpha={alpha}\"\n",
" plt.figure()\n",
" plt.plot(t, x)\n",
" plt.xlabel(\"t\"); plt.ylabel(\"x\"); plt.title(title)\n",
" return deltas, x, t"
]
},
{
"cell_type": "markdown",
"id": "d3adaaa3-b962-4686-9e83-0171c57e304a",
"metadata": {},
"source": [
"### classification head gives easy class based sampling with clearly different shapes for different star types (albeit not very realistic light curves"
]
},
{
"cell_type": "code",
"execution_count": 57,
"id": "b663fa10-1600-4a01-b54d-492e5b3abdc4",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[(0, 'GCAS'),\n",
" (1, 'ROT'),\n",
" (2, 'EW'),\n",
" (3, 'DSCT'),\n",
" (4, 'EA'),\n",
" (5, 'EB'),\n",
" (6, 'UV'),\n",
" (7, 'YSO'),\n",
" (8, 'RRab'),\n",
" (9, 'HADS'),\n",
" (10, 'Cepheids'),\n",
" (11, 'RRcd')]"
]
},
"execution_count": 57,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"list(enumerate(types))"
]
},
{
"cell_type": "code",
"execution_count": 61,
"id": "2956ff13-a4b1-4816-820e-71102cbdb209",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"type\n",
"ROT 36716\n",
"GCAS 23635\n",
"DSCT 4885\n",
"EA 3937\n",
"EW 1301\n",
"UV 477\n",
"EB 403\n",
"RRab 332\n",
"YSO 277\n",
"HADS 184\n",
"Cepheids 184\n",
"RRcd 139\n",
"Name: count, dtype: Int64"
]
},
"execution_count": 61,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df[\"type\"].value_counts()"
]
},
{
"cell_type": "code",
"execution_count": 68,
"id": "7516a66d-0c7c-4d3a-8179-10220e9699c5",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"_ = plot_class_conditioned_curve(model, class_id=1, steps=100,\n",
" method=\"sample\", alpha=2.0, temperature=0.8)"
]
},
{
"cell_type": "code",
"execution_count": 67,
"id": "f4984477-a2b2-4457-b035-090d7c25c98a",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"_ = plot_class_conditioned_curve(model, class_id=0, steps=100,\n",
" method=\"sample\", alpha=2.0, temperature=0.8)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.13.1"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment