Last active
November 5, 2025 17:16
-
-
Save WardBrian/dd54de815903b530c4e41432213fb798 to your computer and use it in GitHub Desktop.
functional-style jax models (take 2)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| from jax import random, jit | |
| import jax.numpy as jnp | |
| from jax.scipy import stats | |
| from util import ravelize_function, make_log_density | |
# Public API of this module. The generated-quantities helpers are listed too,
# since consumers of this module import them alongside the log density.
__all__ = [
    "log_density",
    "log_density_vec",
    "init_draw_zero",
    "generated_quantities",
    "generated_quantities_vec",
]
def constrain_parameters(sigma_unc, alpha, beta):
    """Map unconstrained parameters onto the constrained space.

    ``alpha`` and ``beta`` are unconstrained and pass straight through.
    ``sigma`` must be positive, so it is represented on the log scale as
    ``sigma_unc`` and recovered with ``exp``. The log absolute Jacobian
    determinant of ``exp`` evaluated at ``sigma_unc`` is ``sigma_unc`` itself.

    Returns
    -------
    (dict, log_det_jac)
        The constrained parameters keyed by name, and the log-Jacobian term.
    """
    # NOTE: the lower-bound (log/exp) transform could (should) be a library function
    sigma = jnp.exp(sigma_unc)
    log_det_jac = sigma_unc
    constrained = {"alpha": alpha, "beta": beta, "sigma": sigma}
    return constrained, log_det_jac
def log_prior(alpha, beta, sigma):
    """Log prior density of the regression parameters (constrained scale).

    alpha ~ normal(0, 1); beta ~ normal(0, 1) elementwise; sigma ~ exponential
    with rate 1. Each term is summed so array-valued parameters are supported.
    """
    lp_alpha = jnp.sum(stats.norm.logpdf(alpha, loc=0.0, scale=1.0))
    lp_beta = jnp.sum(stats.norm.logpdf(beta, loc=0.0, scale=1.0))
    # SciPy parameterizes the exponential by scale = 1 / rate (its mean), not
    # by the rate, so scale=1.0 here is the rate-1 exponential distribution.
    lp_sigma = jnp.sum(stats.expon.logpdf(sigma, scale=1.0))
    return lp_alpha + lp_beta + lp_sigma
# log_likelihood closes over the data, so the module needs some (fake) data
# at import time: 3 observations of 2 predictors.
x = jnp.array(
    [
        [1.0, 0.2],
        [2.0, 0.1],
        [3.0, 0.4],
    ]
)  # shape (3, 2): one row per observation
y = jnp.array([2.1, 3.7, 6.5])  # shape (3,): one outcome per row of x
def log_likelihood(alpha, beta, sigma):
    """Gaussian linear-regression log likelihood of the closed-over data.

    Each ``y[n]`` is modelled as ``normal(alpha + x[n] @ beta, sigma)``; the
    per-observation log densities are summed into a scalar.
    """
    linear_predictor = x @ beta + alpha
    pointwise = stats.norm.logpdf(y, loc=linear_predictor, scale=sigma)
    return jnp.sum(pointwise)
# Primary exports of this module.
#
# 1. The log density on the unconstrained space, assembled from the prior,
#    the likelihood and the constraining transform.
log_density = make_log_density(
    log_prior=log_prior,
    log_likelihood=log_likelihood,
    constrain_parameters=constrain_parameters,
)

# 2. A starting point on the unconstrained space. All zeros is a deliberately
#    naive choice (it corresponds to sigma = exp(0) = 1).
init_draw_zero = dict(
    alpha=jnp.array(0.0),
    beta=jnp.zeros(x.shape[1]),
    sigma_unc=jnp.array(0.0),
)

# 3. The same density as a function of one flat vector, derived automatically
#    by using the structure of init_draw_zero as the template.
log_density_vec = ravelize_function(f=log_density, pytree=init_draw_zero)
# Something like Stan's "generated quantities": constrain a draw and simulate
# new outcomes from it.
@jit
def generated_quantities(rng, x_new, **params):
    """Constrain one draw and simulate a posterior-predictive outcome.

    Parameters
    ----------
    rng : PRNG key
        Key used to draw the observation noise.
    x_new : array
        Predictor values for the new observation(s); its last axis must match
        the length of ``beta``.
    **params
        Unconstrained parameters (``alpha``, ``beta``, ``sigma_unc``), as
        accepted by ``constrain_parameters``.

    Returns
    -------
    dict
        Constrained ``alpha``, ``beta``, ``sigma`` and the simulated ``y_new``,
        which has the same shape as ``x_new @ beta``.
    """
    constrained, _ = constrain_parameters(**params)
    alpha = constrained["alpha"]
    beta = constrained["beta"]
    sigma = constrained["sigma"]
    mu_new = alpha + x_new @ beta
    # One noise draw per predicted outcome, i.e. per entry of mu_new. Using
    # x_new.shape here would wrongly draw one value per *predictor*.
    y_new = mu_new + sigma * random.normal(rng, shape=jnp.shape(mu_new))
    return {"alpha": alpha, "beta": beta, "sigma": sigma, "y_new": y_new}
# Flat-vector counterpart of generated_quantities.
@jit
def generated_quantities_vec(rng, x_new, params_vec):
    """Run generated_quantities on a flat parameter vector.

    ``params_vec`` is un-raveled with ``init_draw_zero`` as the structural
    template, and the resulting entries are forwarded as keyword arguments.
    The return value is the (non-flat) dict from generated_quantities.
    """

    def _from_tree(param_tree):
        return generated_quantities(rng, x_new, **param_tree)

    flat_gq = ravelize_function(_from_tree, init_draw_zero)
    return flat_gq(params_vec)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import functools | |
| import blackjax | |
| import jax | |
| import jax.numpy as jnp | |
| from util import init_random_uniform | |
| from linear_regression import ( | |
| log_density, | |
| generated_quantities, | |
| init_draw_zero, | |
| ) | |
| from linear_regression import log_density_vec, generated_quantities_vec | |
def stan_sample(log_density, initial, steps=1_000, rng_key=None, num_warmup=None):
    """
    Draw posterior samples with NUTS after Stan-style window adaptation.

    Adapted from the blackjax quickstart:
    https://blackjax-devs.github.io/blackjax/examples/quickstart.html

    Parameters
    ----------
    log_density : function
        Log density of the target, evaluated at a position with the same
        structure as ``initial`` (a PyTree or a flat array).
    initial : PyTree or array
        Starting position for warmup.
    steps : int
        Number of post-warmup draws to return (and, by default, also the
        number of warmup steps).
    rng_key : PRNG key, optional
        Randomness for warmup and sampling. If None, ``jax.random.key(0)`` is
        used, so the call still works and is reproducible.
    num_warmup : int, optional
        Number of window-adaptation steps. Defaults to ``steps``.

    Returns
    -------
    The sampler states produced by ``jax.lax.scan``, stacked along a leading
    axis of length ``steps``; ``states.position`` holds the draws.
    """
    if rng_key is None:
        # A None key would otherwise fail inside jax.random.split below.
        rng_key = jax.random.key(0)
    if num_warmup is None:
        num_warmup = steps

    def inference_loop(loop_key, kernel, initial_state, num_samples):
        @jax.jit
        def one_step(state, step_key):
            state, _ = kernel(step_key, state)
            return state, state

        keys = jax.random.split(loop_key, num_samples)
        _, states = jax.lax.scan(one_step, initial_state, keys)
        return states

    warmup = blackjax.window_adaptation(blackjax.nuts, log_density)
    # Split into three (first key unused) so the key stream, and hence the
    # draws for a given rng_key, are unchanged.
    _, warmup_key, sample_key = jax.random.split(rng_key, 3)
    (state, parameters), _ = warmup.run(warmup_key, initial, num_steps=num_warmup)
    kernel = blackjax.nuts(log_density, **parameters).step
    return inference_loop(sample_key, kernel, state, steps)
if __name__ == "__main__":
    from jax.flatten_util import ravel_pytree

    N = 1000
    rng_key = jax.random.key(4567)
    init_key, sample_key, gq_key = jax.random.split(rng_key, 3)
    # One key per posterior draw, for "generated quantities"-like behavior.
    rngs = jax.random.split(gq_key, N)
    x_new = jnp.array([0.1, 0.4])

    # ------------- pytree version -------------
    init_draw = init_random_uniform(init_draw_zero, init_key)
    states = stan_sample(log_density, init_draw, N, sample_key)
    # Post-process draws: constrain them and compute generated quantities.
    # vmap maps rngs over axis 0 and broadcasts x_new; arguments passed by
    # keyword (the parameter draws) are always mapped over their leading axis.
    draws = jax.vmap(generated_quantities, (0, None))(rngs, x_new, **states.position)
    print(jax.tree.map(functools.partial(jnp.mean, axis=0), draws))

    # ------------- "flat" version -------------
    # Flatten the same initial point rather than drawing a vector of a
    # hard-coded length, so the size always matches init_draw_zero and both
    # runs start from the same place.
    init_draw_vec, _ = ravel_pytree(init_draw)
    states_vec = stan_sample(log_density_vec, init_draw_vec, N, sample_key)
    draws_vec = jax.vmap(generated_quantities_vec, (0, None, 0))(
        rngs, x_new, states_vec.position
    )
    # note: because generated_quantities returns a pytree, we're no longer
    # in the flattened realm
    print(jax.tree.map(functools.partial(jnp.mean, axis=0), draws_vec))
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import jax | |
def ravelize_function(f, pytree):
    """
    Adapt a function of a PyTree into a function of a flat array.

    Parameters
    ----------
    f : function
        Function taking a single PyTree with the same structure as ``pytree``.
    pytree : PyTree
        Template whose structure and leaf shapes define how a flat array is
        un-raveled before being passed to ``f``. Its values are not used.

    Returns
    -------
    function
        A function ``g(flat)`` equivalent to ``f(unravel(flat))``.
    """
    # Import the submodule explicitly: ``import jax`` alone does not guarantee
    # that ``jax.flatten_util`` is available as an attribute.
    from jax.flatten_util import ravel_pytree

    # note: ravel_pytree is only really safe when we know all the dtypes are
    # the same (usually true in stats models). See
    # https://jax.readthedocs.io/en/latest/_autosummary/jax.flatten_util.ravel_pytree.html
    _, unravel = ravel_pytree(pytree)
    return lambda x: f(unravel(x))
# This version assumes the data are closed over by the passed-in functions.
# It could easily be changed to pass the data in explicitly later.


def _identity_constraint(**unc_params):
    """Default transform: use the parameters as given, with log|J| = 0."""
    return unc_params, 0.0


def make_log_density(
    log_prior,
    log_likelihood,
    constrain_parameters=_identity_constraint,
    pytree_for_unflattening=None,
):
    """
    Make a log_density function from a log_prior, log_likelihood,
    and (optionally) a function to constrain parameters.

    Parameters
    ----------
    log_prior : function
        Called with the constrained parameters by name; should return the
        log of the prior density.
    log_likelihood : function
        Called with the constrained parameters by name; should return the
        log of the likelihood.
    constrain_parameters : function | None
        Called with the unconstrained parameters by name; should return the
        constrained parameters as a dictionary and the log determinant of
        the Jacobian of the transformations. If omitted or None, the
        identity (with a log-Jacobian of 0.0) is used.
    pytree_for_unflattening : PyTree | None
        If provided, a PyTree that specifies the structure of the dictionary
        that constrain_parameters accepts as input. The resulting
        log_density then accepts a flat array of parameters and un-ravels it
        using this PyTree before passing it to constrain_parameters. If None
        (the default), log_density accepts that dictionary directly.

    Returns
    -------
    function
        A jitted function returning
        ``log_det_jac + log_prior(**params) + log_likelihood(**params)``
        for the given unconstrained parameters.
    """
    if constrain_parameters is None:
        constrain_parameters = _identity_constraint

    def log_density(unc_params):
        params, log_det_jac = constrain_parameters(**unc_params)
        return log_det_jac + log_prior(**params) + log_likelihood(**params)

    if pytree_for_unflattening is None:
        return jax.jit(log_density)

    # Explicit submodule import: ``import jax`` alone does not guarantee
    # that ``jax.flatten_util`` is loaded.
    from jax.flatten_util import ravel_pytree

    _, unravel = ravel_pytree(pytree_for_unflattening)

    def log_density_flat(flat_params):
        return log_density(unravel(flat_params))

    return jax.jit(log_density_flat)
# Similar to a solution found at
# https://github.com/jax-ml/jax/discussions/9508#discussioncomment-2144076,
# but uses ravel_pytree so a single draw from one key covers every leaf,
# instead of splitting the key per leaf.
def init_random_uniform(target, rng_key, radius=2):
    """
    Return a random point with the same structure as ``target``.

    Every element of every leaf is replaced by an independent draw from the
    half-open interval [-radius, radius) (``jax.random.uniform`` excludes the
    upper bound). Only the structure and leaf shapes of ``target`` are used,
    not its values.

    Parameters
    ----------
    target : PyTree
        Template point, e.g. an all-zeros initial draw.
    rng_key : PRNG key
        Key for the single uniform draw.
    radius : float
        Half-width of the interval, centred on zero.

    Returns
    -------
    PyTree
        Same structure as ``target`` with uniform random leaves.
    """
    # Explicit submodule import: ``import jax`` alone does not guarantee
    # that ``jax.flatten_util`` is loaded.
    from jax.flatten_util import ravel_pytree

    flat, unravel = ravel_pytree(target)
    uniforms = jax.random.uniform(
        rng_key, shape=flat.shape, minval=-radius, maxval=radius
    )
    return unravel(uniforms)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment