Last active
April 29, 2025 09:43
-
-
Save shermansiu/b492fddf4127f4214d57a647c0160b8f to your computer and use it in GitHub Desktop.
A uniform buffer with support for envpool's async updates. Has a shared buffer for the replay buffer and online queue. Designed for Muesli.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Annotation alias used below for integer-scalar arguments and fields
# (e.g. `num_envs`, `max_size`); it is simply another name for chex.Array.
IntScalar = chex.Array
@chex.dataclass(frozen=True)
class BoundaryPointer:
    """Per-env head/length bookkeeping for a circular buffer.

    The dataclass is frozen, so every method returns a new instance built
    with ``replace`` instead of mutating in place.
    """

    # Next write position for each env; `init` creates it as int32 zeros of shape (num_envs,).
    head: chex.Array
    # Number of stored items for each env; `advance` stops growing it at `max_size`.
    length: chex.Array
    # Capacity, used as the modulus when wrapping `head` and `tail`.
    max_size: IntScalar

    @classmethod
    def init(cls, num_envs: IntScalar, max_size: IntScalar) -> BoundaryPointer:
        """Create a pointer for `num_envs` envs with all heads and lengths at zero."""

        def zero_counters():
            return jnp.zeros(num_envs, jnp.int32)

        return cls(head=zero_counters(), length=zero_counters(), max_size=max_size)

    @jax.jit
    def advance(self, env_ids: chex.Array) -> BoundaryPointer:
        """Move the head of each env in `env_ids` forward by one slot.

        The head wraps around modulo `max_size`. The length of a selected env
        grows by one unless it already equals `max_size`.
        """
        already_full = self.length[env_ids] == self.max_size
        length_step = jnp.where(already_full, 0, 1)
        bumped_head = self.head.at[env_ids].add(1)
        return self.replace(
            head=bumped_head % self.max_size,
            length=self.length.at[env_ids].add(length_step),
        )

    @property
    @jax.jit
    def tail(self):
        """Index of the oldest stored item per env: (head - length) mod max_size."""
        offset_from_head = self.head - self.length
        return offset_from_head % self.max_size

    @jax.jit
    def reset(self) -> BoundaryPointer:
        """Return a copy whose lengths are all zero; `head` is left unchanged."""
        cleared_length = self.length.at[:].set(0)
        return self.replace(length=cleared_length)
@chex.dataclass(frozen=True)
class UniformBuffer:
    """A batched replay buffer, inspired by Hwhitetooth's Jax MuZero implementation and the dejax package.

    See https://github.com/Hwhitetooth/jax_muzero/blob/main/algorithms/replay_buffers.py
    and https://github.com/hr0nix/dejax

    The buffer is designed so that we store a sequence of trajectories for each env.
    Buffer updates can happen asymmetrically across the envs, so that it works with envpool's async mode.
    This is a circular buffer with uniform experience sampling.

    Sequences may contain several adjacent trajectories or just a subsequence of a trajectory.
    Assumes each env trajectory stream is filled at a random rate that is i.i.d.

    The online queue and the replay buffer are two views over the same `data`:
    each has its own `BoundaryPointer`, both are advanced on every push, and
    `reset_online_queue` resets only the online-queue pointer.
    """

    # Pytree of arrays; `init` allocates every leaf with leading axes (num_envs, max_size).
    data: chex.ArrayTree
    # Head/length bookkeeping for the online-queue view of `data`.
    online_queue_ind: BoundaryPointer
    # Head/length bookkeeping for the full replay-buffer view of `data`.
    full_buffer_ind: BoundaryPointer
    # Per-env capacity; used as the modulus when wrapping column indices.
    max_size: IntScalar

    @classmethod
    def init(cls, item_prototype: chex.ArrayTree, num_envs: int, max_size: int):
        """Allocate an empty buffer.

        Args:
            item_prototype: A pytree of arrays describing a single stored item.
                Each leaf is tiled to shape (num_envs, max_size) + leaf.shape.
            num_envs: Number of env rows in the buffer.
            max_size: Number of slots per env.

        Returns:
            A `UniformBuffer` whose two boundary pointers start at zero.
        """
        chex.assert_tree_has_only_ndarrays(item_prototype)
        data = jax.tree_util.tree_map(
            lambda t: jnp.tile(t[None, None, ...], (num_envs, max_size) + (1,) * t.ndim), item_prototype
        )
        return cls(
            data=data,
            online_queue_ind=BoundaryPointer.init(num_envs, max_size),
            full_buffer_ind=BoundaryPointer.init(num_envs, max_size),
            max_size=max_size,
        )

    @jax.jit
    def reset_online_queue(self):
        """Return a buffer whose online-queue pointer has been reset.

        `data` and the full-buffer pointer are carried over unchanged.
        """
        return self.replace(
            online_queue_ind=self.online_queue_ind.reset(),
        )

    @chex.chexify
    @jax.jit
    def push_env_updates(self, update_batch: chex.ArrayTree, env_ids: chex.Array):
        """Write one new item for each env in `env_ids` and advance both pointers.

        Args:
            update_batch: A pytree with the same structure as `data`; each leaf
                is written at [env_ids, head[env_ids]] of the matching leaf.
            env_ids: Indices of the env rows being updated.

        Returns:
            A new buffer with the items stored and both pointers advanced.
        """
        chex.assert_tree_has_only_ndarrays(update_batch)
        # The online queue and the full buffer share the same head, so the
        # full-buffer head is the write position for both views.
        write_cols = self.full_buffer_ind.head[env_ids]
        new_data = jax.tree_util.tree_map(
            lambda entry, t: entry.at[env_ids, write_cols].set(t), self.data, update_batch
        )
        return self.replace(
            data=new_data,
            online_queue_ind=self.online_queue_ind.advance(env_ids),
            full_buffer_ind=self.full_buffer_ind.advance(env_ids),
        )

    @partial(jax.jit, static_argnums=(5, 6))
    def _sample_sequence(
        self,
        boundary_pointer: BoundaryPointer,
        arange_total_items: chex.Array,
        arange_sequence_length: chex.Array,
        rng: chex.PRNGKey,
        batch_size: int,
        sequence_length: int,
        distribution_power: float = 1,
    ):
        """Sample a sequence of trajectories from the buffer.

        Warning: the sequences are sampled with SRSWR, so they may overlap
        or repeat. Dealing with SRSWOR is annoying when the online queue doesn't have
        enough elements for a good SRSWOR. And then there's error handling and the fact
        that SRSWOR (of subsequences!) cannot be done in parallel easily...
        Using SRSWR is simpler, but it increases the variance of the estimator somewhat.

        Args:
            boundary_pointer: Information about where the stored info begins and ends.
            arange_total_items: An array containing 0 to (total stored items - 1), i.e.
                `jnp.arange(boundary_pointer.length.sum())`. It is built by the caller
                because its size depends on the buffer contents.
            arange_sequence_length: An array containing 0 to `sequence_length` - 1.
            rng: The PRNG key.
            batch_size: The number of sequences to sample (static under jit).
            sequence_length: The max length of the sequences to sample (static under jit).
            distribution_power: Subsequences are sampled according to their length,
                raised to the power of `distribution_power`.

        Returns:
            seqs: The batch of requested sequences; every leaf has leading axes
                (batch_size, sequence_length).
            seqs_mask: Boolean array of shape (batch_size, sequence_length) that is
                True for the positions that belong to the sampled sequence. Positions
                where it is False are still gathered from the buffer (at wrapped
                column indices) and should be ignored.
        """
        # Stored items of all envs are laid out in one flat index space, env after env.
        # cum_lengths_per_row[e] is the flat index one past the last item of env e.
        cum_lengths_per_row = jnp.cumsum(boundary_pointer.length)

        def compute_remaining_sequence_length(carry, x):
            """Number of items from flat index `x` to the end of its env, capped at the cutoff."""
            staggered_lengths, length_cutoff = carry
            # The first env whose cumulative length exceeds x is the env that owns x.
            corresponding_row = staggered_lengths[(staggered_lengths > x).argmax()]
            # jnp.minimum replaces jnp.clip(..., a_max=...): the `a_max` keyword is
            # deprecated in newer JAX releases, and only an upper bound is needed.
            return (staggered_lengths, length_cutoff), jnp.minimum(corresponding_row - x, length_cutoff)

        _, remaining_sequence_length = jax.lax.scan(
            compute_remaining_sequence_length,
            (cum_lengths_per_row, sequence_length),
            arange_total_items,
        )
        flattened_index_logits = jnp.log(remaining_sequence_length) * distribution_power

        # Sample from the non-empty indices in the buffer, with probability proportional
        # to the length of the index
        rng, index_selection_key = jax.random.split(rng)
        flattened_indices = jax.random.categorical(index_selection_key, logits=flattened_index_logits, shape=(batch_size,))

        # Figure out what indices in the buffer matrix that the flattened indices correspond to
        env_indices = (cum_lengths_per_row.reshape(-1, 1) > flattened_indices).argmax(0)
        # Flat index of the first item of each env (exclusive cumulative sum).
        env_start_index_to_flattened_index = jnp.concatenate(
            [jnp.zeros(1, dtype=jnp.int32), cum_lengths_per_row[:-1]], 0
        )
        # Offset inside the env, shifted by the env's tail to get a physical column.
        col_indices = flattened_indices + (boundary_pointer.tail - env_start_index_to_flattened_index)[env_indices]
        # Expand each start column into `sequence_length` consecutive columns, wrapping around.
        col_indices = (arange_sequence_length.reshape(1, -1) + col_indices.reshape(-1, 1)) % self.max_size
        sequences = jax.tree_util.tree_map(
            lambda entry: entry[env_indices.reshape(-1, 1), col_indices],
            self.data,
        )
        sequence_masks = arange_sequence_length.reshape(1, -1) < remaining_sequence_length[flattened_indices].reshape(-1, 1)
        return sequences, sequence_masks

    def _sample_sequence_jit_helper(
        self,
        boundary_pointer: BoundaryPointer,
        rng: chex.PRNGKey,
        batch_size: int,
        sequence_length: int,
        distribution_power: float = 1,
    ):
        """Build the data-dependent index ranges and call the jitted `_sample_sequence`.

        This wrapper is not jitted because the size of the flat index range
        depends on how many items are currently stored.

        Raises:
            ValueError: If the region described by `boundary_pointer` holds no items.
        """
        total_items = int(boundary_pointer.length.sum())
        if total_items == 0:
            # Without this check the failure surfaces inside the jitted sampler
            # as an opaque empty-argmax error.
            raise ValueError("Cannot sample sequences from an empty buffer region.")
        return self._sample_sequence(
            boundary_pointer,
            jnp.arange(total_items),
            jnp.arange(sequence_length),
            rng,
            batch_size,
            sequence_length,
            distribution_power,
        )

    def sample_online_queue(
        self,
        rng: chex.PRNGKey,
        batch_size: int,
        sequence_length: int,
        distribution_power: float = 1,
    ):
        """Sample `batch_size` sequences from the online-queue region.

        Returns the `(sequences, mask)` pair produced by `_sample_sequence`.
        """
        return self._sample_sequence_jit_helper(self.online_queue_ind, rng, batch_size, sequence_length, distribution_power)

    def sample_replay_buffer(
        self,
        rng: chex.PRNGKey,
        batch_size: int,
        sequence_length: int,
        distribution_power: float = 1,
    ):
        """Sample `batch_size` sequences from the full replay-buffer region.

        Returns the `(sequences, mask)` pair produced by `_sample_sequence`.
        """
        return self._sample_sequence_jit_helper(self.full_buffer_ind, rng, batch_size, sequence_length, distribution_power)

    def sample_rb_and_oq(
        self,
        rng: chex.PRNGKey,
        rb_batch_size: int,
        oq_batch_size: int,
        sequence_length: int,
        distribution_power: float = 1,
    ):
        """Sample from the replay buffer and the online queue and stack the results.

        Args:
            rng: The PRNG key.
            rb_batch_size: Number of sequences drawn from the replay buffer.
            oq_batch_size: Number of sequences drawn from the online queue.
            sequence_length: The max length of the sequences to sample.
            distribution_power: Forwarded to both samplers.

        Returns:
            sequence: Replay-buffer sequences followed by online-queue sequences,
                stacked along the first axis.
            mask: The two masks stacked in the same order.
        """
        # The first of the three keys is unused; the 3-way split is kept so that
        # results for a given `rng` stay the same as before.
        _, rb_rng, oq_rng = jax.random.split(rng, 3)
        rb_sequence, rb_mask = self.sample_replay_buffer(
            rb_rng, rb_batch_size, sequence_length, distribution_power=distribution_power
        )
        oq_sequence, oq_mask = self.sample_online_queue(
            oq_rng, oq_batch_size, sequence_length, distribution_power=distribution_power
        )
        sequence = jax.tree_util.tree_map(
            lambda rb_entry, oq_entry: jnp.vstack([rb_entry, oq_entry]), rb_sequence, oq_sequence
        )
        mask = jnp.vstack([rb_mask, oq_mask])
        return sequence, mask

    @jax.jit
    def peek(self, env_ids: chex.Array):
        """Peek at the top of the buffer.

        The online queue and the replay buffer have the same head.

        Returns, for each env in `env_ids`, the item stored one slot before the
        head (wrapping around), which is the slot `push_env_updates` wrote last.
        """
        return jax.tree_util.tree_map(
            lambda entry: entry[env_ids, (self.full_buffer_ind.head[env_ids] - 1) % self.max_size],
            self.data,
        )
@chex.dataclass
class Storage:
    """A record of six array fields.

    NOTE(review): the visible code never constructs or reads this class, so the
    per-field notes below are inferred from the field names only — confirm
    against the code that fills it. Presumably this is the item type stored in
    `UniformBuffer` — TODO confirm.
    """

    # presumably the environment observation — TODO confirm shape/dtype
    obs: chex.Array
    # presumably the action taken — TODO confirm
    action: chex.Array
    # presumably the log-probability of `action` under the acting policy — TODO confirm
    logprob: chex.Array
    # presumably the reward received — TODO confirm
    reward: chex.Array
    # presumably the value estimate — TODO confirm
    value: chex.Array
    # presumably the episode-termination flag — TODO confirm
    done: chex.Array
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment