StarSketchVAE
from dataclasses import dataclass
from typing import Tuple

import jax
import jax.numpy as jnp
from flax import nnx
LOG_SIG_MIN = -5.0  # sigma >= e^-5 ≈ 0.0067
LOG_SIG_MAX = 2.0   # sigma <= e^2 ≈ 7.39
RHO_EPS = 1e-3

# =============================================================
# HyperLSTMCell
# =============================================================
class HyperLSTMCell(nnx.RNNCellBase):
    def __init__(
        self,
        in_features: int,
        hidden_features: int,
        *,
        hyper_num_units: int = 256,
        hyper_embedding_size: int = 32,
        forget_bias: float = 1.0,
        gate_fn=jax.nn.sigmoid,
        activation_fn=jnp.tanh,
        rngs: nnx.Rngs,
    ):
        super().__init__()
        self.in_features = in_features
        self.hidden_features = hidden_features
        self.hyper_num_units = hyper_num_units
        self.hyper_embedding_size = hyper_embedding_size
        self.forget_bias = forget_bias
        self.gate_fn = gate_fn
        self.activation_fn = activation_fn
        # Base projections (input has no bias; hidden has bias), mirroring nnx.LSTMCell
        self.ii = nnx.Linear(in_features, hidden_features, use_bias=False, rngs=rngs)
        self.if_ = nnx.Linear(in_features, hidden_features, use_bias=False, rngs=rngs)
        self.ig = nnx.Linear(in_features, hidden_features, use_bias=False, rngs=rngs)
        self.io = nnx.Linear(in_features, hidden_features, use_bias=False, rngs=rngs)
        self.hi = nnx.Linear(hidden_features, hidden_features, use_bias=True, rngs=rngs)
        self.hf = nnx.Linear(hidden_features, hidden_features, use_bias=True, rngs=rngs)
        self.hg = nnx.Linear(hidden_features, hidden_features, use_bias=True, rngs=rngs)
        self.ho = nnx.Linear(hidden_features, hidden_features, use_bias=True, rngs=rngs)
        # Hyper LSTM that processes [x, h]
        self.hyper_cell = nnx.LSTMCell(
            in_features + hidden_features, hyper_num_units, rngs=rngs
        )

        # Hyper heads: zw -> alpha (+ beta). Separate heads per contribution
        def head_pair():
            zw = nnx.Linear(
                hyper_num_units, hyper_embedding_size, use_bias=True, rngs=rngs
            )
            alpha = nnx.Linear(
                hyper_embedding_size, hidden_features, use_bias=False, rngs=rngs
            )
            beta = nnx.Linear(
                hyper_embedding_size, hidden_features, use_bias=False, rngs=rngs
            )
            return zw, alpha, beta

        self.zw_ix, self.alpha_ix, self.beta_ix = head_pair()
        self.zw_jx, self.alpha_jx, self.beta_jx = head_pair()
        self.zw_fx, self.alpha_fx, self.beta_fx = head_pair()
        self.zw_ox, self.alpha_ox, self.beta_ox = head_pair()
        self.zw_ih, self.alpha_ih, self.beta_ih = head_pair()
        self.zw_jh, self.alpha_jh, self.beta_jh = head_pair()
        self.zw_fh, self.alpha_fh, self.beta_fh = head_pair()
        self.zw_oh, self.alpha_oh, self.beta_oh = head_pair()

    def initialize_carry(self, input_shape, rngs: nnx.Rngs):
        batch_dims = input_shape[:-1]
        total_units = self.hidden_features + self.hyper_num_units
        mem_shape = batch_dims + (total_units,)
        c = jnp.zeros(mem_shape, dtype=jnp.float32)
        h = jnp.zeros(mem_shape, dtype=jnp.float32)
        return (c, h)

    def _hyper_affine(
        self, which: str, hyper_h: jnp.ndarray, preact: jnp.ndarray, use_bias: bool
    ) -> jnp.ndarray:
        zw = getattr(self, f"zw_{which}")(hyper_h)
        alpha = getattr(self, f"alpha_{which}")(zw)
        out = alpha * preact
        if use_bias:
            beta = getattr(self, f"beta_{which}")(zw)
            out = out + beta
        return out

    def __call__(self, carry, inputs):
        total_c, total_h = carry
        c_main = total_c[..., : self.hidden_features]
        h_main = total_h[..., : self.hidden_features]
        c_hyp = total_c[..., self.hidden_features :]
        h_hyp = total_h[..., self.hidden_features :]
        # Hyper step
        hyper_in = jnp.concatenate([inputs, h_main], axis=-1)
        (new_c_hyp, new_h_hyp), _ = self.hyper_cell((c_hyp, h_hyp), hyper_in)
        # Base preactivations
        ix = self.ii(inputs)
        ih = self.hi(h_main)
        jx = self.ig(inputs)
        jh = self.hg(h_main)
        fx = self.if_(inputs)
        fh = self.hf(h_main)
        ox = self.io(inputs)
        oh = self.ho(h_main)
        # Modulate
        ix = self._hyper_affine("ix", new_h_hyp, ix, use_bias=False)
        jx = self._hyper_affine("jx", new_h_hyp, jx, use_bias=False)
        fx = self._hyper_affine("fx", new_h_hyp, fx, use_bias=False)
        ox = self._hyper_affine("ox", new_h_hyp, ox, use_bias=False)
        ih = self._hyper_affine("ih", new_h_hyp, ih, use_bias=True)
        jh = self._hyper_affine("jh", new_h_hyp, jh, use_bias=True)
        fh = self._hyper_affine("fh", new_h_hyp, fh, use_bias=True)
        oh = self._hyper_affine("oh", new_h_hyp, oh, use_bias=True)
        # Gates and state update
        i = self.gate_fn(ix + ih)
        f = self.gate_fn(fx + fh + self.forget_bias)
        g = self.activation_fn(jx + jh)
        o = self.gate_fn(ox + oh)
        new_c_main = f * c_main + i * g
        new_h_main = o * self.activation_fn(new_c_main)
        # Full state
        new_total_c = jnp.concatenate([new_c_main, new_c_hyp], axis=-1)
        new_total_h = jnp.concatenate([new_h_main, new_h_hyp], axis=-1)
        new_carry = (new_total_c, new_total_h)
        return new_carry, new_h_main

    @property
    def num_feature_axes(self) -> int:
        """Returns the number of feature axes of the RNN cell."""
        return 1
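# --- Illustrative shape check (a minimal sketch, not part of the model) ---
# Shows how a single HyperLSTMCell step composes the main and hyper states.
# The sizes below are arbitrary and chosen only to make the shapes readable.
def _hyper_lstm_shape_demo():
    rngs = nnx.Rngs(0)
    cell = HyperLSTMCell(
        in_features=4, hidden_features=8, hyper_num_units=16, rngs=rngs
    )
    x0 = jnp.zeros((2, 4))  # one timestep of a batch of 2
    carry = cell.initialize_carry(x0.shape, rngs)
    carry, y = cell(carry, x0)
    # The carry concatenates main and hyper units; the step output is the main h.
    assert carry[0].shape == (2, 8 + 16)
    assert y.shape == (2, 8)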
class BidirectionalLSTMEncoder(nnx.Module):
    def __init__(
        self,
        in_features: int,
        hidden_size: int = 512,
        latent_size: int = 128,
        *,
        rngs: nnx.Rngs,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.fwd = nnx.RNN(
            nnx.LSTMCell(
                in_features=in_features, hidden_features=hidden_size, rngs=rngs
            )
        )
        self.bwd = nnx.RNN(
            nnx.LSTMCell(
                in_features=in_features, hidden_features=hidden_size, rngs=rngs
            )
        )
        self.fc_mu = nnx.Linear(hidden_size * 2, latent_size, rngs=rngs)
        self.fc_logsig = nnx.Linear(hidden_size * 2, latent_size, rngs=rngs)

    def __call__(self, x: jax.Array) -> Tuple[jax.Array, jax.Array, jax.Array]:
        h_fwd = self.fwd(x)  # (B, T, H)
        h_bwd = self.bwd(jnp.flip(x, axis=1))  # (B, T, H)
        h_final = jnp.concatenate([h_fwd[:, -1], h_bwd[:, -1]], axis=-1)  # (B, 2H)
        mu = self.fc_mu(h_final)
        log_sigma = self.fc_logsig(h_final)
        return mu, log_sigma, h_final
class GMMDecoder(nnx.Module):
    def __init__(
        self,
        in_features: int,
        hidden_size: int = 2048,
        n_mixtures: int = 20,
        *,
        rngs: nnx.Rngs,
    ):
        super().__init__()
        self.rnn = nnx.RNN(
            HyperLSTMCell(
                in_features=in_features, hidden_features=hidden_size, rngs=rngs
            )
        )
        self.n_mixtures = n_mixtures
        out_dim = n_mixtures * 6
        self.fc_out = nnx.Linear(hidden_size, out_dim, rngs=rngs)

    def __call__(self, inputs: jax.Array) -> Tuple[jax.Array, ...]:
        h = self.rnn(inputs)  # (B, T, H); nnx.RNN returns only the output sequence by default
        y = self.fc_out(h)  # (B, T, M*6)
        B, T, _ = y.shape
        M = self.n_mixtures
        y = y.reshape(B, T, M, 6)
        pi_logits = y[..., 0]  # (B, T, M)
        mu = y[..., 1:3]  # (B, T, M, 2)
        log_sig = jnp.clip(y[..., 3:5], LOG_SIG_MIN, LOG_SIG_MAX)  # (B, T, M, 2)
        rho = jnp.tanh(y[..., 5]) * (1.0 - RHO_EPS)  # (B, T, M)
        return pi_logits, mu, log_sig, rho
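# --- Illustrative decoder output check (a sketch; sizes are arbitrary) ---
# Shows how the M*6 output head splits into per-timestep mixture parameters.
def _gmm_decoder_shape_demo():
    rngs = nnx.Rngs(0)
    dec = GMMDecoder(in_features=3, hidden_size=32, n_mixtures=4, rngs=rngs)
    pi_logits, mu, log_sig, rho = dec(jnp.zeros((2, 7, 3)))
    assert pi_logits.shape == (2, 7, 4)   # mixture logits
    assert mu.shape == (2, 7, 4, 2)       # component means for (Δx, Δt)
    assert log_sig.shape == (2, 7, 4, 2)  # clipped log standard deviations
    assert rho.shape == (2, 7, 4)         # correlations in (-1, 1)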
@dataclass
class ModelConfig:
    seq_len: int = 100
    data_dim: int = 2  # (Δx, Δt)
    enc_hidden: int = 512
    dec_hidden: int = 2048
    latent_size: int = 128
    n_mixtures: int = 20
    num_classes: int = 0  # set >0 to enable classifier head
class StarSketchVAE(nnx.Module):
    def __init__(self, cfg: ModelConfig, *, rngs: nnx.Rngs):
        super().__init__()
        self.cfg = cfg
        self.encoder = BidirectionalLSTMEncoder(
            in_features=cfg.data_dim,
            hidden_size=cfg.enc_hidden,
            latent_size=cfg.latent_size,
            rngs=rngs,
        )
        self.decoder = GMMDecoder(
            in_features=cfg.data_dim + cfg.latent_size,
            hidden_size=cfg.dec_hidden,
            n_mixtures=cfg.n_mixtures,
            rngs=rngs,
        )
        # Classifier head is only created when num_classes > 0 (see ModelConfig)
        self.classifier = (
            nnx.Linear(cfg.latent_size, cfg.num_classes, rngs=rngs)
            if cfg.num_classes > 0
            else None
        )

    def encode(self, x: jax.Array) -> Tuple[jax.Array, jax.Array, jax.Array]:
        return self.encoder(x)

    def decode(self, prev_xy: jax.Array, z: jax.Array) -> Tuple[jax.Array, ...]:
        B, T, _ = prev_xy.shape
        z_broadcast = jnp.repeat(z[:, None, :], T, axis=1)  # (B, T, Z)
        dec_in = jnp.concatenate([prev_xy, z_broadcast], axis=-1)  # (B, T, 2+Z)
        return self.decoder(dec_in)
def bivariate_normal_logpdf(
    xy: jax.Array, mu: jax.Array, log_sig: jax.Array, rho: jax.Array
) -> jax.Array:
    sig = jnp.exp(log_sig)
    x = xy[..., 0]
    y = xy[..., 1]
    mx = mu[..., 0]
    my = mu[..., 1]
    sx = sig[..., 0]
    sy = sig[..., 1]
    sx = jnp.maximum(sx, 1e-6)
    sy = jnp.maximum(sy, 1e-6)
    one_minus_r2 = jnp.maximum(1.0 - rho**2, 1e-6)
    norm = (
        -jnp.log(2.0 * jnp.pi) - jnp.log(sx) - jnp.log(sy) - 0.5 * jnp.log(one_minus_r2)
    )
    zx = (x - mx) / sx
    zy = (y - my) / sy
    exp_term = -0.5 * ((zx**2 + zy**2 - 2.0 * rho * zx * zy) / one_minus_r2)
    return norm + exp_term
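# --- Sanity check against jax.scipy (illustrative; not used in training) ---
# For a single point, the hand-rolled bivariate log-pdf should agree with a
# dense multivariate normal built from (sigma_x, sigma_y, rho).
def _bivariate_logpdf_check():
    from jax.scipy.stats import multivariate_normal

    mu = jnp.array([0.3, -0.1])
    log_sig = jnp.array([-0.5, 0.2])
    rho = jnp.array(0.4)
    xy = jnp.array([0.1, 0.05])
    sx, sy = jnp.exp(log_sig)
    cov = jnp.array([[sx**2, rho * sx * sy],
                     [rho * sx * sy, sy**2]])
    ref = multivariate_normal.logpdf(xy, mu, cov)
    ours = bivariate_normal_logpdf(xy, mu, log_sig, rho)
    assert jnp.allclose(ours, ref, atol=1e-4)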
def gmm_log_likelihood(
    xy: jax.Array,
    pi_logits: jax.Array,
    mu: jax.Array,
    log_sig: jax.Array,
    rho: jax.Array,
) -> jax.Array:
    B, T, _ = xy.shape
    M = pi_logits.shape[-1]
    xy_exp = xy[:, :, None, :]  # (B, T, 1, 2)
    comp_logpdf = bivariate_normal_logpdf(xy_exp, mu, log_sig, rho)  # (B, T, M)
    log_pi = jax.nn.log_softmax(pi_logits, axis=-1)  # (B, T, M)
    # log-sum-exp over mixture components
    ll = jax.scipy.special.logsumexp(log_pi + comp_logpdf, axis=-1)  # (B, T)
    return ll


def reconstruction_loss(
    xy: jax.Array,
    pi_logits: jax.Array,
    mu: jax.Array,
    log_sig: jax.Array,
    rho: jax.Array,
    w: jax.Array,
) -> jax.Array:
    ll = gmm_log_likelihood(xy, pi_logits, mu, log_sig, rho)  # (B, T)
    nll_t = -ll
    nll = jnp.sum(w * nll_t) / (jnp.sum(w) + 1e-8)
    return nll
def kl_loss(mu: jax.Array, log_sigma: jax.Array) -> jax.Array:
    log_var = 2.0 * log_sigma
    var = jnp.exp(log_var)
    kl = 0.5 * (var + mu**2 - 1.0 - log_var)
    return jnp.mean(jnp.sum(kl, axis=-1))
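# Note: kl_loss implements the closed form
#   KL(N(mu, sigma^2) || N(0, 1)) = 0.5 * sum(sigma^2 + mu^2 - 1 - log sigma^2),
# summed over latent dimensions and averaged over the batch. At mu = 0 and
# log_sigma = 0 the penalty is exactly zero, e.g.
#   kl_loss(jnp.zeros((4, 8)), jnp.zeros((4, 8))) == 0.0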
def model_loss(
    model: StarSketchVAE,
    batch_xy: jax.Array,
    *,
    class_ids: jax.Array | None = None,
    w_kl: float = 1.0,
    w_sup: float = 1.0,
    rng_key: jax.Array,
) -> Tuple[jax.Array, dict]:
    """
    batch_xy: (B, T=100, 2) full teacher-forced sequences of deltas.
    Returns the total loss and a metrics dict.
    """
    B, T, D = batch_xy.shape
    assert D == 2, "Expecting (Δx, Δt)"
    mu, log_sigma, enc_feats = model.encode(batch_xy)
    eps = jax.random.normal(rng_key, mu.shape)
    z = mu + jnp.exp(log_sigma) * eps
    pad0 = jnp.zeros((B, 1, D), dtype=batch_xy.dtype)
    prev_xy = jnp.concatenate([pad0, batch_xy[:, :-1, :]], axis=1)  # (B, T, 2)
    pi_logits, gmm_mu, gmm_logsig, gmm_rho = model.decode(prev_xy, z)
    dt = batch_xy[..., 1]
    # Normalized per-step weights
    w = dt / (jnp.mean(dt) + 1e-8)
    recon = reconstruction_loss(batch_xy, pi_logits, gmm_mu, gmm_logsig, gmm_rho, w)
    kl = kl_loss(mu, log_sigma)
    # Optional supervised head (only when labels are given and the head exists)
    sup_loss = 0.0
    cls_metrics = {}
    if class_ids is not None and model.classifier is not None:
        logits = model.classifier(z)  # (B, C)
        onehot = jax.nn.one_hot(class_ids, logits.shape[-1])
        ce = -jnp.sum(onehot * jax.nn.log_softmax(logits, axis=-1), axis=-1)
        sup_loss = jnp.mean(ce)
        probs = jax.nn.softmax(logits, axis=-1)
        pred = jnp.argmax(probs, axis=-1)
        acc = jnp.mean((pred == class_ids).astype(jnp.float32))
        cls_metrics = {"sup_ce": sup_loss, "acc": acc}
    # ===== Sequence-closure penalty: sum Δx_pred ≈ sum Δx_gt =====
    pi = jax.nn.softmax(pi_logits, axis=-1)[..., None]  # (B, T, M, 1)
    exp = jnp.sum(pi * gmm_mu, axis=-2)  # (B, T, 2)
    x_gt_total = jnp.sum(batch_xy[..., 0], axis=1)  # (B,)
    x_pr_total = jnp.sum(exp[..., 0], axis=1)  # (B,)
    close_loss = jnp.mean((x_pr_total - x_gt_total) ** 2)
    total = recon + w_kl * kl + w_sup * sup_loss + 5.0 * close_loss
    metrics = {"loss": total, "recon": recon, "kl": kl, **cls_metrics}
    return total, metrics
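# --- Minimal training-step sketch (illustrative; not part of the gist) ---
# Uses the flax.nnx Optimizer/value_and_grad pattern with an assumed optax
# optimizer; hyperparameters are placeholders and exact Optimizer signatures
# may differ across Flax versions.
import optax


@nnx.jit
def train_step(
    model: StarSketchVAE,
    optimizer: nnx.Optimizer,
    batch_xy: jax.Array,
    class_ids: jax.Array,
    rng_key: jax.Array,
):
    def loss_fn(m):
        return model_loss(m, batch_xy, class_ids=class_ids, rng_key=rng_key)

    (_, metrics), grads = nnx.value_and_grad(loss_fn, has_aux=True)(model)
    optimizer.update(grads)  # applies gradients in place to the model params
    return metrics


# Usage sketch:
#   optimizer = nnx.Optimizer(model, optax.adam(1e-4))
#   metrics = train_step(model, optimizer, batch, class_ids, jax.random.PRNGKey(0))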
# ---------- Classifier direction utilities ----------
def _linear_params(linear) -> tuple[jnp.ndarray, jnp.ndarray | None]:
    # Works for nnx.Linear; some variants name the weight 'kernel' rather than 'weight'
    if hasattr(linear, "weight"):
        W = linear.weight
    elif hasattr(linear, "kernel"):
        W = linear.kernel
    else:
        raise AttributeError("Classifier Linear has no 'weight' or 'kernel'.")
    b = getattr(linear, "bias", None)
    return W, b
def get_classifier_direction(model: StarSketchVAE,
                             class_id: int,
                             *,
                             normalize: bool = True,
                             margin: bool = False) -> jnp.ndarray:
    """
    Return a latent-space direction for class `class_id`.
    Handles either W shape: (C, Z) or (Z, C).
    If margin=True, uses a mean-difference direction for a stronger push.
    """
    W, _ = _linear_params(model.classifier)
    C = model.cfg.num_classes
    Z = model.cfg.latent_size
    assert 0 <= class_id < C, "class_id out of range"
    # Figure out orientation
    if W.shape == (C, Z):  # rows = classes
        w_c = W[class_id]  # (Z,)
        if margin:
            mask = jnp.ones(C, dtype=bool).at[class_id].set(False)
            w_rest_mean = jnp.mean(W[mask], axis=0)  # (Z,)
            direction = w_c - w_rest_mean
        else:
            direction = w_c
    elif W.shape == (Z, C):  # columns = classes
        w_c = W[:, class_id]  # (Z,)
        if margin:
            # mean over other columns
            idx = jnp.arange(C) != class_id
            w_rest_mean = jnp.mean(W[:, idx], axis=1)  # (Z,)
            direction = w_c - w_rest_mean
        else:
            direction = w_c
    else:
        raise ValueError(f"Unexpected classifier weight shape {W.shape}; "
                         f"expected (C,Z)=({C},{Z}) or (Z,C)=({Z},{C}).")
    if normalize:
        direction = direction / (jnp.linalg.norm(direction) + 1e-8)
    return direction  # (Z,)


def tweak_z_towards_class(z: jnp.ndarray,
                          model: StarSketchVAE,
                          class_id: int,
                          *,
                          alpha: float = 3.0,
                          normalize: bool = True,
                          margin: bool = False) -> jnp.ndarray:
    d = get_classifier_direction(model, class_id, normalize=normalize, margin=margin)
    assert d.shape == z.shape, f"direction {d.shape} must match z {z.shape}"
    return z + alpha * d
# ---------- GMM sampling helpers ----------
def _sample_bivariate_from_params(mu: jnp.ndarray,
                                  log_sig: jnp.ndarray,
                                  rho: jnp.ndarray,
                                  rng_key: jax.Array,
                                  temperature: float = 1.0) -> jnp.ndarray:
    sig = jnp.exp(log_sig) * jnp.sqrt(temperature)
    rho = jnp.clip(rho, -0.999, 0.999)
    u_key, v_key = jax.random.split(rng_key)
    u = jax.random.normal(u_key)
    v = jax.random.normal(v_key)
    # correlated 2D normal
    dx = sig[0] * u
    dy = sig[1] * (rho * u + jnp.sqrt(1.0 - rho**2) * v)
    return mu + jnp.array([dx, dy])


def _mixture_expectation(pi_logits_t: jnp.ndarray,
                         mu_t: jnp.ndarray,
                         temperature: float = 1.0) -> jnp.ndarray:
    pi = jax.nn.softmax(pi_logits_t / temperature, axis=-1)  # (M,)
    return jnp.sum(pi[:, None] * mu_t, axis=0)  # (2,)


def _mixture_sample(pi_logits_t: jnp.ndarray,
                    mu_t: jnp.ndarray,
                    log_sig_t: jnp.ndarray,
                    rho_t: jnp.ndarray,
                    rng_key: jax.Array,
                    temperature: float = 1.0) -> jnp.ndarray:
    M = pi_logits_t.shape[0]
    # Sample component
    pi = jax.nn.softmax(pi_logits_t / temperature, axis=-1)  # (M,)
    comp_key, gauss_key = jax.random.split(rng_key)
    k = jax.random.categorical(comp_key, jnp.log(pi), axis=-1)  # int
    return _sample_bivariate_from_params(mu_t[k], log_sig_t[k], rho_t[k], gauss_key, temperature)
# ---------- Autoregressive decoding ----------
def decode_from_z(model: StarSketchVAE,
                  z: jnp.ndarray,
                  *,
                  steps: int,
                  method: str = "expected",  # "expected" or "sample"
                  temperature: float = 1.0,
                  rng_key: jax.Array | None = None) -> jnp.ndarray:
    assert z.ndim == 1, "z must be shape (Z,)"
    D = 2
    if method == "sample" and rng_key is None:
        rng_key = jax.random.PRNGKey(0)
    # running buffer of previous outputs
    prev = jnp.zeros((steps, D), dtype=jnp.float32)
    prev = prev.at[0].set(jnp.array([0.0, 0.0], dtype=jnp.float32))  # first prev is 0
    out = []
    zB = z[None, :]  # (1, Z)
    for t in range(steps):
        prev_xy = prev[None, :, :]  # (1, steps, 2)
        pi_logits, gmm_mu, gmm_logsig, gmm_rho = model.decode(prev_xy, zB)
        pi_t = pi_logits[0, t]  # (M,)
        mu_t = gmm_mu[0, t]  # (M, 2)
        logs_t = gmm_logsig[0, t]  # (M, 2)
        rho_t = gmm_rho[0, t]  # (M,)
        if method == "expected":
            delta = _mixture_expectation(pi_t, mu_t, temperature=temperature)
        else:
            rng_key, step_key = jax.random.split(rng_key)
            delta = _mixture_sample(pi_t, mu_t, logs_t, rho_t, step_key, temperature=temperature)
        out.append(delta)
        if t + 1 < steps:
            prev = prev.at[t + 1].set(delta)
    return jnp.stack(out, axis=0)  # (steps, 2)
# ---------- End-to-end: sample z, push toward class, decode ----------
def class_conditioned_random_curve(model: StarSketchVAE,
                                   class_id: int,
                                   *,
                                   steps: int = 100,
                                   z_scale: float = 1.0,
                                   alpha: float = 3.0,
                                   normalize_dir: bool = True,
                                   margin: bool = False,
                                   method: str = "expected",  # "expected" or "sample"
                                   temperature: float = 0.9,
                                   rng_key: jax.Array | None = None) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    if rng_key is None:
        rng_key = jax.random.PRNGKey(42)
    z_key, dec_key = jax.random.split(rng_key)
    # sample z ~ N(0, I)
    Z = model.cfg.latent_size
    z = jax.random.normal(z_key, (Z,)) * z_scale  # (Z,)
    # push toward class direction
    z_star = tweak_z_towards_class(z, model, class_id, alpha=alpha,
                                   normalize=normalize_dir, margin=margin)
    deltas = decode_from_z(model, z_star, steps=steps, method=method,
                           temperature=temperature, rng_key=dec_key)
    dt = jnp.maximum(deltas[:, 1], 1e-4)
    deltas = deltas.at[:, 1].set(dt)
    # Convert to cumulative x(t)
    x = jnp.cumsum(deltas[:, 0])
    t = jnp.cumsum(deltas[:, 1])
    return deltas, x, t
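# --- Usage sketch for class-conditioned generation (illustrative only) ---
# Assumes a trained model with num_classes > 0; the plotting lines require
# matplotlib and are shown purely as an example.
#   deltas, x, t = class_conditioned_random_curve(
#       model, class_id=2, steps=100, method="sample",
#       rng_key=jax.random.PRNGKey(7),
#   )
#   # import matplotlib.pyplot as plt
#   # plt.plot(t, x); plt.xlabel("t"); plt.ylabel("x"); plt.show()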
if __name__ == "__main__":
    rngs = nnx.Rngs(0)
    cfg = ModelConfig(
        seq_len=100,
        data_dim=2,
        enc_hidden=512,
        dec_hidden=2048,
        latent_size=128,
        n_mixtures=20,
        num_classes=5,
    )
    model = StarSketchVAE(cfg, rngs=rngs)
    B, T, D = 6, 100, 2
    batch = jnp.ones((B, T, D))
    # class_ids: all samples labeled (no unlabeled)
    class_ids = jnp.array([0, 1, 2, 3, 4, 0])
    key = jax.random.PRNGKey(42)
    loss, metrics = model_loss(
        model, batch, class_ids=class_ids, w_kl=1.0, w_sup=1.0, rng_key=key
    )
    print({k: float(v) for k, v in metrics.items()})