Skip to content

Instantly share code, notes, and snippets.

@RobertTLange
Created March 5, 2021 19:32
Show Gist options
  • Select an option

  • Save RobertTLange/292259f27e5d9754beb392099fb75012 to your computer and use it in GitHub Desktop.

Select an option

Save RobertTLange/292259f27e5d9754beb392099fb75012 to your computer and use it in GitHub Desktop.
Ornstein-Uhlenbeck Process in JAX
import jax
import jax.numpy as jnp
def ou_process(key, steps, dt, mu, tau, sigma):
    """Generate one sample path of an Ornstein-Uhlenbeck process.

    The path is simulated with the explicit update

        x[t] = x[t-1] - (x[t-1] - mu) / tau * dt
               + sigma * sqrt(2 / tau) * sqrt(dt) * eps[t]

    starting from x[0] = 0, where eps are independent standard-normal draws.

    Args:
        key: JAX PRNG key used to draw the noise.
        steps: Number of time steps to simulate. It determines array shapes,
            so it must be a concrete Python int (mark it static under
            ``jax.jit``).
        dt: Time step size.
        mu: Mean the process reverts to.
        tau: Time constant of the mean reversion.
        sigma: Scale of the noise term.

    Returns:
        Array of shape ``(steps,)`` holding x[1], ..., x[steps]; the zero
        initial state is dropped.
    """
    path_init = jnp.zeros((steps + 1,))
    noise = jax.random.normal(key, (steps,))
    # Loop-invariant factor of the noise term, computed once.
    noise_scale = sigma * jnp.sqrt(2 / tau) * jnp.sqrt(dt)

    def ou_step(t, val):
        prev = val[t - 1]
        # t runs over 1..steps while noise is indexed 0..steps-1, hence t - 1.
        # Indexing noise[t] would skip the first draw and read one element
        # past the end (which JAX clamps, silently reusing the last draw).
        dx = -(prev - mu) / tau * dt + noise_scale * noise[t - 1]
        # Functional in-place update; replaces the removed
        # jax.ops.index_update API.
        return val.at[t].set(prev + dx)

    return jax.lax.fori_loop(1, steps + 1, ou_step, path_init)[1:]
if __name__ == "__main__":
    # Demo: draw OU samples eagerly, under jit, and batched with vmap.
    rng = jax.random.PRNGKey(1)
    # Derive independent keys for the single-sample and batched demos so the
    # key that is split into `keys` is not also consumed directly.
    rng, key, batch_key = jax.random.split(rng, 3)
    keys = jax.random.split(batch_key, 20)
    steps, dt, mu, tau, sigma = 1000, 0.1, 0, 2, 1

    # Single 'slow' OU sample
    ou_process(key, steps, dt, mu, tau, sigma)

    # Single 'fast' OU sample via jit (same key on purpose: same sample,
    # only the execution mode differs). `steps` fixes array shapes, so it is
    # marked static.
    fast_ou = jax.jit(ou_process, static_argnums=1)
    fast_ou(key, steps, dt, mu, tau, sigma)

    # Batch OU sample via vmap: map over the keys, broadcast everything else
    batch_ou = jax.vmap(fast_ou, in_axes=(0, None, None, None, None, None))
    batch_ou(keys, steps, dt, mu, tau, sigma)

    # Batch no jit
    batch_slow_ou = jax.vmap(ou_process, in_axes=(0, None, None, None, None, None))
    batch_slow_ou(keys, steps, dt, mu, tau, sigma)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment