prio


class PrioritizedReplayBuffer(size: int, alpha: float, beta: float, weight_norm: bool = True, **kwargs: Any)

Bases: ReplayBuffer

Implementation of Prioritized Experience Replay. arXiv:1511.05952.

Parameters:
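  • alpha – the prioritization exponent; alpha = 0 corresponds to uniform sampling.

  • beta – the importance-sampling exponent (soft coefficient) that controls how strongly the bias introduced by prioritized sampling is corrected; beta = 1 fully compensates for it.

  • weight_norm – whether to normalize the returned weights by the maximum weight within the batch. Defaults to True.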
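To show what alpha controls, here is a minimal sketch of the proportional sampling rule from the PER paper, P(i) = p_i^alpha / sum_k p_k^alpha, in plain Python. The helper `sampling_probs` is hypothetical and not part of this class; it only illustrates how alpha interpolates between uniform sampling (alpha = 0) and sampling in direct proportion to priority (alpha = 1).

```python
def sampling_probs(priorities: list[float], alpha: float) -> list[float]:
    """Hypothetical helper: P(i) = p_i ** alpha / sum_k p_k ** alpha (PER, proportional variant)."""
    scaled = [p ** alpha for p in priorities]  # 0 < alpha < 1 flattens differences between priorities
    total = sum(scaled)
    return [s / total for s in scaled]


if __name__ == "__main__":
    prios = [1.0, 2.0, 4.0]
    print(sampling_probs(prios, alpha=0.0))  # every entry equal: uniform sampling
    print(sampling_probs(prios, alpha=1.0))  # proportional to the raw priorities
```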

See also

Please refer to ReplayBuffer for the usage of the other APIs.

init_weight(index: int | ndarray) → None

update(buffer: ReplayBuffer) → ndarray

Move the data from the given buffer to the current buffer.

Returns the updated indices, or an empty array if the update fails.
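A minimal sketch of what update does, using a hypothetical fixed-capacity ring buffer `RingBufferSketch` (not Tianshou's implementation). It copies the source buffer's items oldest-first, keeping at most `maxsize` of the newest ones, and returns the slot indices it wrote. Two points are assumptions: that an update "fails" when the source is empty or the target has zero capacity, and that newly written slots receive the current maximum priority through an `init_weight`-style step, following the PER paper's rule of inserting new transitions at maximal priority.

```python
from typing import Any


class RingBufferSketch:
    """Hypothetical fixed-capacity buffer illustrating update(); not Tianshou's class."""

    def __init__(self, maxsize: int, alpha: float = 0.6) -> None:
        self.maxsize = maxsize
        self.alpha = alpha
        self.data: list[Any] = [None] * maxsize
        self.weight: list[float] = [0.0] * maxsize  # per-slot priority ** alpha
        self.max_prio = 1.0  # largest priority seen so far
        self.ptr = 0  # next slot to write
        self.size = 0

    def __len__(self) -> int:
        return self.size

    def push(self, item: Any) -> int:
        """Write one item at the current slot and return that slot's index."""
        idx = self.ptr
        self.data[idx] = item
        self.ptr = (self.ptr + 1) % self.maxsize
        self.size = min(self.size + 1, self.maxsize)
        return idx

    def items(self) -> list[Any]:
        """Stored items ordered from oldest to newest."""
        if self.size == 0:
            return []
        start = (self.ptr - self.size) % self.maxsize
        return [self.data[(start + i) % self.maxsize] for i in range(self.size)]

    def init_weight(self, indices: list[int]) -> None:
        # Assumption: new slots start at the current maximum priority (PER paper's rule).
        for i in indices:
            self.weight[i] = self.max_prio ** self.alpha

    def update(self, other: "RingBufferSketch") -> list[int]:
        """Copy other's items into this buffer; return written indices, or [] on failure."""
        if len(other) == 0 or self.maxsize == 0:  # assumed failure conditions
            return []
        # Only the newest `maxsize` items can survive, so skip older ones up front.
        written = [self.push(item) for item in other.items()[-self.maxsize:]]
        self.init_weight(written)
        return written
```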

add(batch: RolloutBatchProtocol, buffer_ids: ndarray | list[int] | None = None) → tuple[ndarray, ndarray, ndarray, ndarray]

Add a batch of data into the replay buffer.

Parameters:
  • batch – the input data batch. “obs”, “act”, “rew”, “terminated”, “truncated” are required keys.

  • buffer_ids – IDs of the subbuffers, accepted here for consistency with classes similar to VectorReplayBuffer. Since ReplayBuffer has a single subbuffer, if this is not None it must contain exactly one element with value 0. In that case, the batch is expected to have the shape (1, len(data)). Failure to adhere to this results in a ValueError.
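The buffer_ids rule above can be sketched as a standalone check. The function `check_single_buffer_ids` is hypothetical, and it represents a batch as a dict of per-key lists rather than a Tianshou Batch. It passes the batch through unchanged when buffer_ids is None, accepts [0] only when every field has a leading dimension of 1 (and then strips that dimension), and raises ValueError otherwise.

```python
from __future__ import annotations

from collections.abc import Sequence
from typing import Any


def check_single_buffer_ids(
    batch: dict[str, list[Any]], buffer_ids: Sequence[int] | None
) -> dict[str, Any]:
    """Hypothetical check mirroring the documented buffer_ids rule for a single subbuffer."""
    if buffer_ids is None:
        return batch  # nothing to validate; the batch is used as-is
    if len(buffer_ids) != 1 or buffer_ids[0] != 0:
        raise ValueError(f"buffer_ids must be None or [0], got {list(buffer_ids)}")
    # With buffer_ids == [0], every field must carry a leading dimension of size 1.
    for key, value in batch.items():
        if len(value) != 1:
            raise ValueError(f"field {key!r} must have leading dimension 1, got {len(value)}")
    # Drop the leading dimension so the fields describe a single transition.
    return {key: value[0] for key, value in batch.items()}
```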

Returns (current_index, episode_return, episode_length, episode_start_index). If the episode is not finished, episode_return and episode_length are returned as 0.
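To make the return tuple concrete, here is a hypothetical `EpisodeTrackerSketch` that adds one transition at a time and reports (current_index, episode_return, episode_length, episode_start_index). It follows the documented rule that return and length are 0 while the episode is still running; reporting the episode's start index in that case is an assumption of the sketch.

```python
from dataclasses import dataclass


@dataclass
class EpisodeTrackerSketch:
    """Hypothetical tracker for the (index, return, length, start) tuple that add() reports."""

    maxsize: int
    ptr: int = 0
    ep_rew: float = 0.0
    ep_len: int = 0
    ep_start: int = 0

    def add(self, rew: float, terminated: bool, truncated: bool) -> tuple[int, float, int, int]:
        index = self.ptr
        if self.ep_len == 0:
            self.ep_start = index  # first step of a new episode
        self.ep_rew += rew
        self.ep_len += 1
        self.ptr = (self.ptr + 1) % self.maxsize
        if terminated or truncated:
            result = (index, self.ep_rew, self.ep_len, self.ep_start)
            self.ep_rew, self.ep_len = 0.0, 0  # reset for the next episode
            return result
        # Episode still running: return and length are reported as 0.
        return (index, 0.0, 0, self.ep_start)
```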

sample_indices(batch_size: int | None) → ndarray

Get a random sample of indices of size batch_size.

Return all available indices in the buffer if batch_size is 0; return an empty numpy array if batch_size < 0 or no available index can be sampled.

Parameters:

batch_size – the number of indices to be sampled. If None, it will be set to the length of the buffer (i.e. return all available indices in a random order).
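A minimal sketch of prioritized index sampling under the batch_size rules above, assuming indices are drawn with probability proportional to their stored priority weights (the proportional variant from the PER paper). The function `sample_indices_sketch` is hypothetical; it uses cumulative sums and binary search, whereas a production buffer would more likely use a sum tree so that priority updates stay O(log N).

```python
from __future__ import annotations

import bisect
import itertools
import random


def sample_indices_sketch(weights: list[float], batch_size: int | None, rng: random.Random) -> list[int]:
    """Hypothetical proportional sampler following the documented batch_size rules."""
    n = len(weights)
    if batch_size is None:
        order = list(range(n))
        rng.shuffle(order)  # all available indices, in random order
        return order
    if batch_size == 0:
        return list(range(n))  # all available indices
    if batch_size < 0 or n == 0:
        return []  # nothing can be sampled
    # Draw u ~ U(0, total) and pick the first index whose prefix sum exceeds u.
    prefix = list(itertools.accumulate(weights))
    total = prefix[-1]
    return [min(bisect.bisect_right(prefix, rng.random() * total), n - 1) for _ in range(batch_size)]
```

With weights [0.0, 1.0, 3.0], index 0 is never drawn and index 2 is drawn about three times as often as index 1.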

get_weight(index: int | ndarray) → float | ndarray

Get the importance sampling weight.

The “weight” in the returned Batch is the weight applied to the loss function to correct the bias introduced by prioritized sampling: transitions that are sampled more often have their losses weighted less.
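This weight corresponds to the importance-sampling formula from the PER paper, w_i = (N · P(i))^(-beta), where N is the number of stored transitions, optionally divided by the largest weight in the batch as weight_norm describes. The helper `importance_weights` is a hypothetical sketch, not this class's code, and assumes a non-empty batch.

```python
def importance_weights(probs: list[float], n: int, beta: float, weight_norm: bool = True) -> list[float]:
    """Hypothetical helper: PER importance-sampling weights w_i = (n * P(i)) ** -beta.

    probs: sampling probabilities P(i) of the sampled transitions (non-empty).
    n: number of transitions currently stored in the buffer.
    """
    weights = [(n * p) ** (-beta) for p in probs]
    if weight_norm:
        # Scale so the largest weight in the batch is 1; weights can then only shrink updates.
        max_w = max(weights)
        weights = [w / max_w for w in weights]
    return weights
```

A transition sampled twice as often as under uniform sampling (n · P(i) = 2) gets half the weight of one sampled at the uniform rate when beta = 1.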

update_weight(index: ndarray, new_weight: ndarray | Tensor) → None

Update the priority weights of the given indices in this buffer.

Parameters:
  • index (np.ndarray) – the indices whose priority weights should be updated.

  • new_weight (np.ndarray | Tensor) – the new priority weights, one per index.
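In PER, new priorities are usually derived from each sampled transition's TD error. A minimal sketch, assuming the paper's proportional rule p_i = |delta_i| + eps and storage of p_i^alpha; the class `PrioritySketch` and its `eps` default are hypothetical, and the running maximum it tracks is the value a newly added transition would be initialized with.

```python
class PrioritySketch:
    """Hypothetical priority store illustrating update_weight(); not Tianshou's implementation."""

    def __init__(self, size: int, alpha: float, eps: float = 1e-6) -> None:
        self.alpha = alpha
        self.eps = eps  # keeps every priority strictly positive
        self.max_prio = 1.0  # largest raw priority seen so far
        self.weight = [self.max_prio ** alpha] * size  # stored as priority ** alpha

    def update_weight(self, index: list[int], new_weight: list[float]) -> None:
        for i, td in zip(index, new_weight):
            prio = abs(td) + self.eps  # PER paper: p_i = |delta_i| + eps
            self.weight[i] = prio ** self.alpha
            self.max_prio = max(self.max_prio, prio)  # used to initialize new entries
```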

sample(batch_size: int | None) → tuple[PrioBatchProtocol, ndarray]

Get a random sample from the buffer of size batch_size.

Return all the data in the buffer if batch_size is 0.

Returns:

Sample data and its corresponding index inside the buffer.
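The (batch, indices) pair returned by sample is typically used in a loop: scale each transition's loss by its importance weight, then pass the same indices and the new TD errors back to update_weight. The toy sketch below mimics both calls on plain lists with a hypothetical `per_training_step` that fits scalar values to fixed targets; it illustrates the pattern and is not Tianshou code.

```python
from __future__ import annotations

import random


def per_training_step(
    values: list[float],
    targets: list[float],
    prios: list[float],
    batch_size: int,
    alpha: float,
    beta: float,
    lr: float,
    rng: random.Random,
) -> list[int]:
    """Hypothetical PER step on plain lists: sample, weight the update, refresh priorities."""
    n = len(values)
    scaled = [p ** alpha for p in prios]
    total = sum(scaled)
    # Stand-in for sample(): draw indices in proportion to priority ** alpha.
    indices = rng.choices(range(n), weights=scaled, k=batch_size)
    # Importance-sampling weights, normalized by the batch maximum (like weight_norm=True).
    is_w = [(n * scaled[i] / total) ** (-beta) for i in indices]
    max_w = max(is_w)
    for i, w in zip(indices, is_w):
        td = targets[i] - values[i]  # TD error of transition i
        values[i] += lr * (w / max_w) * td  # squared-error gradient step scaled by the IS weight
        # Stand-in for update_weight(): new priority is |TD error| plus a small epsilon.
        prios[i] = abs(targets[i] - values[i]) + 1e-6
    return indices
```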

set_beta(beta: float) → None

Set the importance-sampling exponent beta.
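The PER paper anneals beta linearly from an initial value toward 1 over the course of training, and set_beta is a natural hook for such a schedule. A minimal sketch; `linear_beta` and its defaults (beta0 = 0.4, beta_final = 1.0) are hypothetical choices.

```python
def linear_beta(step: int, total_steps: int, beta0: float = 0.4, beta_final: float = 1.0) -> float:
    """Hypothetical schedule: move beta linearly from beta0 to beta_final over total_steps."""
    frac = min(max(step / total_steps, 0.0), 1.0)  # training progress clamped to [0, 1]
    return beta0 + frac * (beta_final - beta0)


# In a training loop this might be used as (not executed here):
#     buffer.set_beta(linear_beta(step, total_steps))
```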