A uniform buffer with support for envpool's async updates. Has a shared buffer for the replay buffer and online queue. Designed for Muesli.
from __future__ import annotations

from functools import partial

import chex
import jax
import jax.numpy as jnp

IntScalar = chex.Array
@chex.dataclass(frozen=True)
class BoundaryPointer:
    """Per-env head/length bookkeeping for a circular buffer."""

    head: chex.Array
    length: chex.Array
    max_size: IntScalar

    @classmethod
    def init(cls, num_envs: IntScalar, max_size: IntScalar) -> BoundaryPointer:
        return cls(
            head=jnp.zeros(num_envs, jnp.int32),
            length=jnp.zeros(num_envs, jnp.int32),
            max_size=max_size,
        )

    @jax.jit
    def advance(self, env_ids: chex.Array) -> BoundaryPointer:
        new_head = self.head.at[env_ids].add(1) % self.max_size
        new_length = self.length.at[env_ids].add(jnp.where(self.length[env_ids] == self.max_size, 0, 1))
        return self.replace(head=new_head, length=new_length)

    @property
    @jax.jit
    def tail(self):
        return (self.head - self.length) % self.max_size

    @jax.jit
    def reset(self) -> BoundaryPointer:
        return self.replace(
            length=self.length.at[:].set(0),
        )
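
# Illustrative sketch (not part of the original gist) of how BoundaryPointer
# evolves for a single env with max_size=3. `head` is the next write slot,
# `tail = (head - length) % max_size` is the oldest stored item, and `length`
# saturates once the ring is full:
#
#   ptr = BoundaryPointer.init(num_envs=1, max_size=3)
#   ptr = ptr.advance(jnp.array([0]))  # head=[1], length=[1], tail=[0]
#   ptr = ptr.advance(jnp.array([0]))  # head=[2], length=[2], tail=[0]
#   ptr = ptr.advance(jnp.array([0]))  # head=[0], length=[3], tail=[0]
#   ptr = ptr.advance(jnp.array([0]))  # head=[1], length=[3], tail=[1]: oldest slot overwritten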
@chex.dataclass(frozen=True)
class UniformBuffer:
    """A batched replay buffer, inspired by Hwhitetooth's Jax MuZero implementation and the dejax package.

    See https://github.com/Hwhitetooth/jax_muzero/blob/main/algorithms/replay_buffers.py
    and https://github.com/hr0nix/dejax

    The buffer stores a sequence of trajectories for each env, in a single data
    array shared by the replay buffer and the online queue (each tracked by its
    own BoundaryPointer). Buffer updates can happen asymmetrically across the
    envs, so the buffer works with envpool's async mode.

    This is a circular buffer with uniform experience sampling. Sampled
    sequences may span several adjacent trajectories or cover just a
    subsequence of one trajectory. Assumes the fill rates of the per-env
    trajectory streams are i.i.d.
    """
    data: chex.ArrayTree
    online_queue_ind: BoundaryPointer
    full_buffer_ind: BoundaryPointer
    max_size: IntScalar

    @classmethod
    def init(cls, item_prototype: chex.ArrayTree, num_envs: int, max_size: int):
        chex.assert_tree_has_only_ndarrays(item_prototype)
        # Each leaf is tiled to shape (num_envs, max_size, *item_shape).
        data = jax.tree_util.tree_map(
            lambda t: jnp.tile(t[None, None, ...], (num_envs, max_size) + (1,) * t.ndim), item_prototype
        )
        return cls(
            data=data,
            online_queue_ind=BoundaryPointer.init(num_envs, max_size),
            full_buffer_ind=BoundaryPointer.init(num_envs, max_size),
            max_size=max_size,
        )
    @jax.jit
    def reset_online_queue(self):
        return self.replace(
            online_queue_ind=self.online_queue_ind.reset(),
        )

    @chex.chexify
    @jax.jit
    def push_env_updates(self, update_batch: chex.ArrayTree, env_ids: chex.Array):
        chex.assert_tree_has_only_ndarrays(update_batch)
        # Write each env's update at that env's current head, then advance both
        # pointers (the online queue and the full buffer share the same head).
        new_data = jax.tree_util.tree_map(
            lambda entry, t: entry.at[env_ids, self.full_buffer_ind.head[env_ids]].set(t), self.data, update_batch
        )
        return self.replace(
            data=new_data,
            online_queue_ind=self.online_queue_ind.advance(env_ids),
            full_buffer_ind=self.full_buffer_ind.advance(env_ids),
        )
    @partial(jax.jit, static_argnums=(5, 6))
    def _sample_sequence(
        self,
        boundary_pointer: BoundaryPointer,
        arange_total_items: chex.Array,
        arange_sequence_length: chex.Array,
        rng: chex.PRNGKey,
        batch_size: int,
        sequence_length: int,
        distribution_power: float = 1,
    ):
        """Sample a batch of sequences from the buffer.

        Warning: the sequences are sampled with SRSWR (simple random sampling
        with replacement), so they may overlap or repeat. SRSWOR is awkward
        when the online queue doesn't hold enough elements for a good sample,
        it needs extra error handling, and SRSWOR of subsequences cannot easily
        be done in parallel. SRSWR is simpler, at the cost of somewhat higher
        estimator variance.

        Args:
            boundary_pointer: Information about where the stored data begins and ends.
            arange_total_items: An array containing 0 to the number of stored items - 1.
            arange_sequence_length: An array containing 0 to sequence_length - 1.
            rng: The PRNG key.
            batch_size: The number of sequences to sample.
            sequence_length: The max length of the sequences to sample.
            distribution_power: Subsequences are sampled with probability
                proportional to their length raised to this power.

        Returns:
            seqs: The batch of requested sequences.
            seqs_mask: A mask indicating which entries are valid when a sequence
                is shorter than sequence_length.
        """
        # Length of the longest sequence that can start at each flattened index.
        cum_lengths_per_row = jnp.cumsum(boundary_pointer.length)

        def compute_remaining_sequence_length(carry, x):
            staggered_lengths, length_cutoff = carry
            corresponding_row = staggered_lengths[(staggered_lengths > x).argmax()]
            return (staggered_lengths, length_cutoff), jnp.clip(corresponding_row - x, a_max=length_cutoff)

        _, remaining_sequence_length = jax.lax.scan(
            compute_remaining_sequence_length,
            (cum_lengths_per_row, sequence_length),
            arange_total_items,
        )
        flattened_index_logits = jnp.log(remaining_sequence_length) * distribution_power
        # Sample start indices from the non-empty slots, with probability
        # proportional to the remaining sequence length at each index.
        rng, index_selection_key = jax.random.split(rng)
        flattened_indices = jax.random.categorical(index_selection_key, logits=flattened_index_logits, shape=(batch_size,))
        # Map each flattened index back to an (env row, buffer column) pair.
        env_indices = (cum_lengths_per_row.reshape(-1, 1) > flattened_indices).argmax(0)
        env_start_index_to_flattened_index = jnp.concatenate([jnp.zeros(1).astype(jnp.int32), cum_lengths_per_row[:-1]], 0)
        col_indices = flattened_indices + (boundary_pointer.tail - env_start_index_to_flattened_index)[env_indices]
        col_indices = (arange_sequence_length.reshape(1, -1) + col_indices.reshape(-1, 1)) % self.max_size
        sequences = jax.tree_util.tree_map(
            lambda entry: entry[env_indices.reshape(-1, 1), col_indices],
            self.data,
        )
        sequence_masks = arange_sequence_length.reshape(1, -1) < remaining_sequence_length[flattened_indices].reshape(-1, 1)
        return sequences, sequence_masks
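
    # Worked example of the index arithmetic above (illustrative numbers, not
    # from the original gist): with two envs storing [3, 2] items,
    # cum_lengths_per_row = [3, 5]. Flattened index 4 maps to env 1 (the first
    # row whose cumulative length exceeds 4), at offset
    # 4 - env_start_index_to_flattened_index[1] = 4 - 3 = 1 past env 1's tail,
    # so that sequence starts at column (tail[1] + 1) % max_size and its mask
    # cuts off after remaining_sequence_length[4] = min(5 - 4, sequence_length) steps.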
    def _sample_sequence_jit_helper(
        self,
        boundary_pointer: BoundaryPointer,
        rng: chex.PRNGKey,
        batch_size: int,
        sequence_length: int,
        distribution_power: float = 1,
    ):
        return self._sample_sequence(
            boundary_pointer,
            jnp.arange(boundary_pointer.length.sum()),
            jnp.arange(sequence_length),
            rng,
            batch_size,
            sequence_length,
            distribution_power,
        )
    def sample_online_queue(
        self,
        rng: chex.PRNGKey,
        batch_size: int,
        sequence_length: int,
        distribution_power: float = 1,
    ):
        return self._sample_sequence_jit_helper(self.online_queue_ind, rng, batch_size, sequence_length, distribution_power)

    def sample_replay_buffer(
        self,
        rng: chex.PRNGKey,
        batch_size: int,
        sequence_length: int,
        distribution_power: float = 1,
    ):
        return self._sample_sequence_jit_helper(self.full_buffer_ind, rng, batch_size, sequence_length, distribution_power)
    def sample_rb_and_oq(
        self,
        rng: chex.PRNGKey,
        rb_batch_size: int,
        oq_batch_size: int,
        sequence_length: int,
        distribution_power: float = 1,
    ):
        _, rb_rng, oq_rng = jax.random.split(rng, 3)
        rb_sequence, rb_mask = self.sample_replay_buffer(
            rb_rng, rb_batch_size, sequence_length, distribution_power=distribution_power
        )
        oq_sequence, oq_mask = self.sample_online_queue(
            oq_rng, oq_batch_size, sequence_length, distribution_power=distribution_power
        )
        # Stack the two batches along the batch axis: leading shape is
        # (rb_batch_size + oq_batch_size, sequence_length, ...).
        sequence = jax.tree_util.tree_map(
            lambda rb_entry, oq_entry: jnp.vstack([rb_entry, oq_entry]), rb_sequence, oq_sequence
        )
        mask = jnp.vstack([rb_mask, oq_mask])
        return sequence, mask
    @jax.jit
    def peek(self, env_ids: chex.Array):
        """Peek at the most recently written item for each requested env.

        The online queue and the replay buffer share the same head.
        """
        return jax.tree_util.tree_map(
            lambda entry: entry[env_ids, (self.full_buffer_ind.head[env_ids] - 1) % self.max_size],
            self.data,
        )
@chex.dataclass
class Storage:
    obs: chex.Array
    action: chex.Array
    logprob: chex.Array
    reward: chex.Array
    value: chex.Array
    done: chex.Array
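
# --- Usage sketch (hypothetical: the shapes, env count, and envpool wiring
# below are assumptions for illustration, not part of the original gist) ---
if __name__ == "__main__":
    item = Storage(
        obs=jnp.zeros((4, 4), jnp.float32),  # assumed observation shape
        action=jnp.zeros((), jnp.int32),
        logprob=jnp.zeros((), jnp.float32),
        reward=jnp.zeros((), jnp.float32),
        value=jnp.zeros((), jnp.float32),
        done=jnp.zeros((), jnp.bool_),
    )
    buffer = UniformBuffer.init(item, num_envs=8, max_size=512)

    # envpool's async mode returns transitions for a subset of envs at a time;
    # here we fake an update batch for envs 0, 3, and 5.
    env_ids = jnp.array([0, 3, 5])
    update_batch = jax.tree_util.tree_map(lambda t: jnp.stack([t] * 3), item)
    for _ in range(4):  # push a few steps so there is something to sample
        buffer = buffer.push_env_updates(update_batch, env_ids)

    rng = jax.random.PRNGKey(0)
    seqs, mask = buffer.sample_rb_and_oq(
        rng, rb_batch_size=2, oq_batch_size=2, sequence_length=5
    )
    # `seqs` leaves have leading shape (4, 5, ...); `mask` flags the valid steps.
    buffer = buffer.reset_online_queue()  # drain the online queue between updates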