@WardBrian
Last active November 5, 2025
functional-style jax models (take 2)
# linear_regression.py
from jax import random, jit
import jax.numpy as jnp
from jax.scipy import stats

from util import ravelize_function, make_log_density

__all__ = ["log_density", "log_density_vec", "init_draw_zero"]

def constrain_parameters(sigma_unc, alpha, beta):
    # NOTE: the lower-bound (exp) transform could (should) be a library function.
    # The second return value is the log-Jacobian of the transform, which for
    # sigma = exp(sigma_unc) is just sigma_unc.
    return {"alpha": alpha, "beta": beta, "sigma": jnp.exp(sigma_unc)}, sigma_unc
def log_prior(alpha, beta, sigma):
    lp_alpha = jnp.sum(stats.norm.logpdf(alpha, loc=0.0, scale=1.0))
    lp_beta = jnp.sum(stats.norm.logpdf(beta, loc=0.0, scale=1.0))
    # SciPy parameterizes the exponential by scale = 1/rate, not by rate
    lp_sigma = jnp.sum(stats.expon.logpdf(sigma, scale=1.0))
    return lp_alpha + lp_beta + lp_sigma
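
# Quick parameterization check (illustrative addition, safe to delete):
# with scale=1.0 this is a unit-rate exponential, so logpdf(x) == -x for x >= 0.
assert jnp.isclose(stats.expon.logpdf(1.5, scale=1.0), -1.5)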
# data is closed over, so we need some fake data
x = jnp.array([[1.0, 2.0, 3.0], [0.2, 0.1, 0.4]]).T # 3x2
y = jnp.array([2.1, 3.7, 6.5])
def log_likelihood(alpha, beta, sigma):
    mu = alpha + x @ beta
    return jnp.sum(stats.norm.logpdf(y, loc=mu, scale=sigma))
# These are the primary exports of this module:
# a log density function
log_density = make_log_density(
    log_prior, log_likelihood, constrain_parameters=constrain_parameters
)
# an initial point on the unconstrained space. Here I pick a dumb one, zero.
init_draw_zero = {
    "alpha": jnp.array(0.0),
    "beta": jnp.zeros(x.shape[1]),
    "sigma_unc": jnp.array(0.0),
}
# We can also provide a flattened version, automatically,
# using the structure of this initial point as the template.
log_density_vec = ravelize_function(log_density, init_draw_zero)
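
# Smoke test (illustrative addition): the pytree and flat versions should
# agree when evaluated at the same point, here the all-zeros initialization.
assert jnp.isclose(log_density(init_draw_zero), log_density_vec(jnp.zeros(4)))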
# we might also want something like "generated quantities"
@jit
def generated_quantities(rng, x_new, **params):
    constrained, _ = constrain_parameters(**params)
    alpha, beta, sigma = constrained["alpha"], constrained["beta"], constrained["sigma"]
    mu_new = alpha + x_new @ beta
    # the noise must match the shape of the predicted mean, not of x_new
    y_new = mu_new + sigma * random.normal(rng, shape=mu_new.shape)
    return {"alpha": alpha, "beta": beta, "sigma": sigma, "y_new": y_new}
# and a flattened version
@jit
def generated_quantities_vec(rng, x_new, params_vec):
    gq = lambda param_dict: generated_quantities(rng, x_new, **param_dict)
    return ravelize_function(gq, init_draw_zero)(params_vec)
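
# One payoff of keeping everything purely functional (illustrative
# addition): the flat density composes directly with JAX transformations
# like grad, which gradient-based samplers rely on.
# e.g. jax.grad(log_density_vec)(jnp.zeros(4)) is a length-4 gradient.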
# driver script (uses blackjax for NUTS sampling)
import functools

import blackjax
import jax
import jax.numpy as jnp

from util import init_random_uniform
from linear_regression import (
    log_density,
    generated_quantities,
    init_draw_zero,
)
from linear_regression import log_density_vec, generated_quantities_vec
def stan_sample(log_density, initial, steps=1_000, rng_key=None):
    # inference loop copied from
    # https://blackjax-devs.github.io/blackjax/examples/quickstart.html
    def inference_loop(rng_key, kernel, initial_state, num_samples):
        @jax.jit
        def one_step(state, rng_key):
            state, _ = kernel(rng_key, state)
            return state, state

        keys = jax.random.split(rng_key, num_samples)
        _, states = jax.lax.scan(one_step, initial_state, keys)
        return states

    # run window adaptation to tune the step size and mass matrix,
    # then sample with the adapted NUTS kernel
    warmup = blackjax.window_adaptation(blackjax.nuts, log_density)
    rng_key, warmup_key, sample_key = jax.random.split(rng_key, 3)
    (state, parameters), _ = warmup.run(warmup_key, initial, num_steps=steps)
    kernel = blackjax.nuts(log_density, **parameters).step
    states = inference_loop(sample_key, kernel, state, steps)
    return states
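
# Usage sketch (illustrative): `states.position` stacks one entry per draw
# with the same structure as `initial` (a pytree or a flat vector), so the
# postprocessing below can simply vmap over its leading axis, e.g.
# states = stan_sample(log_density, init_draw_zero, steps=500, rng_key=jax.random.key(0))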
if __name__ == "__main__":
N = 1000
rng_key = jax.random.key(4567)
init_key, sample_key, gq_key = jax.random.split(rng_key, 3)
# for "generated quantities"-like behavior:
rngs = jax.random.split(gq_key, N)
x_new = jnp.array([0.1, 0.4])
# sample
init_draw = init_random_uniform(init_draw_zero, init_key)
states = stan_sample(log_density, init_draw, N, sample_key)
# postprocess draws - constrains and does generated quantities
draws = jax.vmap(generated_quantities, (0, None))(rngs, x_new, **states.position)
print(jax.tree.map(functools.partial(jnp.mean, axis=0), draws))
# ------------- "flat" version -------------
init_draw_vec = jax.random.uniform(init_key, shape=(4,))
states_vec = stan_sample(log_density_vec, init_draw_vec, N, sample_key)
draws_vec = jax.vmap(generated_quantities_vec, (0, None, 0))(
rngs, x_new, states_vec.position
)
# note: because generated_quantities returns a pytree, we're no longer
# in the flattened realm
print(jax.tree.map(functools.partial(jnp.mean, axis=0), draws_vec))
# util.py
import jax
import jax.flatten_util  # imported explicitly; `import jax` alone may not expose it

def ravelize_function(f, pytree):
    """
    Takes a function that accepts a PyTree, plus a template PyTree,
    and produces a function that accepts a flat array instead.
    """
    # note: ravel_pytree is only really safe when we know all the dtypes
    # are the same. See
    # https://jax.readthedocs.io/en/latest/_autosummary/jax.flatten_util.ravel_pytree.html
    # This is usually true in stats models
    _, unravel = jax.flatten_util.ravel_pytree(pytree)
    return lambda x: f(unravel(x))
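
# Example (illustrative, assumes jax.numpy is imported as jnp):
# f = lambda tree: tree["a"].sum() + tree["b"]
# f_flat = ravelize_function(f, {"a": jnp.zeros(2), "b": jnp.array(0.0)})
# f_flat(jnp.arange(3.0))  # unravels to {"a": [0., 1.], "b": 2.} -> 3.0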
# version that assumes data is closed over in the passed-in functions.
# Could easily change to pass data explicitly later
def make_log_density(
    log_prior,
    log_likelihood,
    constrain_parameters=lambda **x: (x, 0.0),
):
    """
    Make a log_density function from a log_prior, log_likelihood,
    and (optionally) a function to constrain parameters.

    Parameters
    ----------
    log_prior : function
        This function will be passed the parameters
        by name, and should return the log of the prior density.
    log_likelihood : function
        This function will be passed the parameters
        by name, and should return the log of the likelihood.
    constrain_parameters : function, optional
        This function will be passed the unconstrained parameters by
        name, and should return the constrained parameters as a dictionary
        and the log determinant of the Jacobian of the transformations.
        By default, this is the identity with a log-Jacobian of zero.

    Returns
    -------
    function
        A function that computes the log density of the model on the
        unconstrained space.
    """

    @jax.jit
    def log_density(unc_params):
        params, log_det_jac = constrain_parameters(**unc_params)
        return log_det_jac + log_prior(**params) + log_likelihood(**params)

    return log_density
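
# Minimal usage sketch (illustrative, assumes jax.numpy is imported as jnp):
# a standard-normal "prior only" model with no constraint needed.
# ld = make_log_density(
#     log_prior=lambda theta: -0.5 * jnp.sum(theta**2),
#     log_likelihood=lambda theta: 0.0,
# )
# ld({"theta": jnp.zeros(3)})  # 0.0, up to the dropped normalizing constant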
# similar to a solution found at
# https://github.com/jax-ml/jax/discussions/9508#discussioncomment-2144076,
# but uses ravel_pytree to avoid needing to split the key
def init_random_uniform(target, rng_key, radius=2):
    """
    Given a tree and a random key, return a tree with the same structure
    but with each leaf replaced by a random uniform value in [-radius, radius).
    """
    d, unravel = jax.flatten_util.ravel_pytree(target)
    uniforms = jax.random.uniform(rng_key, shape=d.shape, minval=-radius, maxval=radius)
    return unravel(uniforms)
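
# Example (illustrative, assumes jax.numpy is imported as jnp):
# start = init_random_uniform({"mu": jnp.zeros(2), "tau": jnp.array(0.0)},
#                             jax.random.key(0))
# each leaf of `start` is now uniform on [-2, 2)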