base


exception MalformedBufferError

Bases: RuntimeError

class ReplayBuffer(size: int, stack_num: int = 1, ignore_obs_next: bool = False, save_only_last_obs: bool = False, sample_avail: bool = False, **kwargs: Any)

Bases: object

ReplayBuffer stores data generated from interaction between the policy and environment.

ReplayBuffer can be considered a specialized form of Batch (or a manager for one). It stores all the data in a single batch in a circular-queue style: once the buffer is full, each new transition overwrites the oldest one.
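The circular-queue behavior can be illustrated with a plain-Python sketch. The RingBuffer class below is hypothetical and is not Tianshou's implementation; it only shows how a write pointer wraps around so that, once the buffer is full, each new item overwrites the oldest one.

```python
from typing import Any, List


class RingBuffer:
    """Hypothetical minimal circular queue (not Tianshou's implementation)."""

    def __init__(self, maxsize: int) -> None:
        self.maxsize = maxsize
        self._data: List[Any] = [None] * maxsize
        self._index = 0  # position of the next write
        self._size = 0  # number of valid entries, capped at maxsize

    def add(self, item: Any) -> int:
        """Store item at the write position and return that position."""
        ptr = self._index
        self._data[ptr] = item
        self._index = (self._index + 1) % self.maxsize  # wrap around at the end
        self._size = min(self._size + 1, self.maxsize)
        return ptr

    def __len__(self) -> int:
        return self._size

    def __getitem__(self, index: int) -> Any:
        return self._data[index]


buf = RingBuffer(maxsize=3)
for step in range(5):
    buf.add(f"transition-{step}")
print(len(buf))                    # 3
print([buf[i] for i in range(3)])  # ['transition-3', 'transition-4', 'transition-2']
```

Keeping a separate size counter lets the sketch tell a partially filled buffer from a full one, which the write pointer alone cannot do.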

For example usage of ReplayBuffer, see the Buffer section of Basic concepts in Tianshou.

Parameters:
  • size – the maximum size of replay buffer.

  • stack_num – the frame-stack sampling argument; must be greater than or equal to 1. Defaults to 1 (no stacking).

  • ignore_obs_next – whether to skip storing obs_next.

  • save_only_last_obs – when obs/obs_next has a shape of (timestep, …) because of temporal stacking, save only its last frame.

  • sample_avail – whether to sample only available indices when using the frame-stack sampling method.

property subbuffer_edges: ndarray

Edges of contained buffers, mostly needed as part of the VectorReplayBuffer interface.

For the standard ReplayBuffer it is always [0, maxsize]. Transitions can be added to the buffer indefinitely, and one episode can “go over the edge”, i.e. start near the end of the buffer and continue at its beginning. Having the edges available is useful for extracting whole episodes from the buffer and for input validation.
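Assuming subbuffers are laid out one after another in a single index space (consistent with the edges [0, 5, 10] used in the get_buffer_indices example), the edges are simply cumulative offsets of the subbuffer sizes. The helper subbuffer_edges below is hypothetical and only illustrates that arithmetic.

```python
from typing import List


def subbuffer_edges(subbuffer_sizes: List[int]) -> List[int]:
    """Return the cumulative offsets [0, s0, s0 + s1, ...] of contiguous subbuffers."""
    edges = [0]
    for size in subbuffer_sizes:
        edges.append(edges[-1] + size)
    return edges


print(subbuffer_edges([10]))    # one buffer with maxsize 10 -> [0, 10]
print(subbuffer_edges([5, 5]))  # two subbuffers of size 5 -> [0, 5, 10]
```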

get_buffer_indices(start: int, stop: int) → ndarray

Get the indices of the transitions in the buffer between start and stop.

The special thing about this method is that stop may be smaller than start, since one is often interested in a sequence of transitions that goes over a subbuffer edge.

The main use case for this method is to retrieve an episode from the buffer, in which case start is the index of the first transition in the episode and stop is one past the index of the transition where done is True. This can be done with the following code:

episode_indices = buffer.get_buffer_indices(episode_start_index, episode_done_index + 1)
episode = buffer[episode_indices]

Even when start is smaller than stop, the method validates that both indices lie in the same subbuffer.

Example:

>>> list(buffer.subbuffer_edges) == [0, 5, 10]
True
>>> buffer.get_buffer_indices(start=2, stop=4)
[2, 3]
>>> buffer.get_buffer_indices(start=4, stop=2)
[4, 0, 1]
>>> buffer.get_buffer_indices(start=8, stop=7)
[8, 9, 5, 6]
>>> buffer.get_buffer_indices(start=1, stop=6)
ValueError: Start and stop indices must be within the same subbuffer.
>>> buffer.get_buffer_indices(start=8, stop=1)
ValueError: Start and stop indices must be within the same subbuffer.

Parameters:
  • start – the start index of the interval.

  • stop – the stop index of the interval.

Returns:

The indices of the transitions in the buffer between start and stop.
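The wrap-around logic of get_buffer_indices can be sketched in plain Python. The functions below are hypothetical: they take the edges as a list instead of reading them from a buffer, return lists instead of numpy arrays, and assume validation is performed on stop - 1 (the last index actually requested). Under those assumptions they reproduce the example results listed above.

```python
import bisect
from typing import List


def subbuffer_of(edges: List[int], index: int) -> int:
    """Return k such that edges[k] <= index < edges[k + 1]."""
    k = bisect.bisect_right(edges, index) - 1
    if not 0 <= k < len(edges) - 1:
        raise ValueError(f"Index {index} is outside the buffer.")
    return k


def get_buffer_indices(edges: List[int], start: int, stop: int) -> List[int]:
    """Indices from start (inclusive) to stop (exclusive), wrapping inside one subbuffer."""
    k = subbuffer_of(edges, start)
    # stop is exclusive, so the last requested index is stop - 1.
    if subbuffer_of(edges, stop - 1) != k:
        raise ValueError("Start and stop indices must be within the same subbuffer.")
    if stop > start:
        return list(range(start, stop))
    lower, upper = edges[k], edges[k + 1]
    # Run to the end of the subbuffer, then continue from its beginning.
    return list(range(start, upper)) + list(range(lower, stop))


edges = [0, 5, 10]
print(get_buffer_indices(edges, 2, 4))  # [2, 3]
print(get_buffer_indices(edges, 4, 2))  # [4, 0, 1]
print(get_buffer_indices(edges, 8, 7))  # [8, 9, 5, 6]
```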

save_hdf5(path: str, compression: str | None = None) → None

Save the replay buffer to an HDF5 file.

classmethod load_hdf5(path: str, device: str | None = None) → Self

Load a replay buffer from an HDF5 file.

classmethod from_data(obs: Dataset, act: Dataset, rew: Dataset, terminated: Dataset, truncated: Dataset, done: Dataset, obs_next: Dataset) → Self

reset(keep_statistics: bool = False) → None

Clear all the data in the replay buffer. Episode statistics are cleared as well unless keep_statistics is True.

set_batch(batch: RolloutBatchProtocol) → None

Manually choose the batch you want the ReplayBuffer to manage.

unfinished_index() → ndarray

Return the index of the unfinished episode.

prev(index: int | ndarray) → ndarray

Return the index of the previous transition.

The index is left unchanged if it marks the beginning of an episode.

next(index: int | ndarray) → ndarray

Return the index of the next transition.

The index is left unchanged if it marks the end of an episode.
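A simplified sketch of prev and next on a flat list of done flags follows. The helper names prev_index and next_index are hypothetical. A real buffer must also treat the most recent write position as a boundary (otherwise stepping back from index 0 of a wrapped buffer could land in an unrelated, unfinished episode); this sketch ignores that and assumes every stored episode is finished.

```python
from typing import List


def prev_index(done: List[bool], index: int) -> int:
    """Step back one transition, staying put at the first transition of an episode."""
    candidate = (index - 1) % len(done)
    # If the previous slot ends an episode, `index` is the start of a new one.
    return index if done[candidate] else candidate


def next_index(done: List[bool], index: int) -> int:
    """Step forward one transition, staying put at the last transition of an episode."""
    return index if done[index] else (index + 1) % len(done)


# Two finished episodes: indices 0-2 and 3-5.
done = [False, False, True, False, False, True]
print(prev_index(done, 1), prev_index(done, 3))  # 0 3
print(next_index(done, 1), next_index(done, 2))  # 2 2
```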

update(buffer: ReplayBuffer) → ndarray

Move the data from the given buffer to the current buffer.

Return the updated indices. If the update fails, return an empty array.

add(batch: RolloutBatchProtocol, buffer_ids: ndarray | list[int] | None = None) → tuple[ndarray, ndarray, ndarray, ndarray]

Add a batch of data into replay buffer.

Parameters:
  • batch – the input data batch. “obs”, “act”, “rew”, “terminated”, “truncated” are required keys.

  • buffer_ids – IDs of subbuffers, allowed here for consistency with classes similar to VectorReplayBuffer. Since ReplayBuffer has a single subbuffer, if this is not None, it must contain a single element with value 0. In that case, the batch is expected to have the shape (1, len(data)). Failure to adhere to this results in a ValueError.

Return (current_index, episode_return, episode_length, episode_start_index). If the episode is not finished, the returned episode_return and episode_length are 0.
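The return tuple of add can be understood through the per-episode bookkeeping sketched below. EpisodeTracker is a hypothetical class that only tracks rewards and done flags and is not Tianshou's implementation. It follows the rule stated above that an unfinished episode reports a return and length of 0, and it assumes episode_start_index is reported even while the episode is still running.

```python
from typing import Tuple


class EpisodeTracker:
    """Hypothetical sketch of the bookkeeping behind add()'s return tuple."""

    def __init__(self, maxsize: int) -> None:
        self.maxsize = maxsize
        self._index = 0        # next write position
        self._ep_return = 0.0  # running return of the current episode
        self._ep_len = 0       # running length of the current episode
        self._ep_start = 0     # index where the current episode began

    def add(self, rew: float, done: bool) -> Tuple[int, float, int, int]:
        ptr = self._index
        self._index = (self._index + 1) % self.maxsize
        self._ep_return += rew
        self._ep_len += 1
        if done:
            result = (ptr, self._ep_return, self._ep_len, self._ep_start)
            # The next episode starts at the next write position.
            self._ep_return, self._ep_len, self._ep_start = 0.0, 0, self._index
            return result
        # Unfinished episode: return and length are reported as 0.
        return (ptr, 0.0, 0, self._ep_start)


tracker = EpisodeTracker(maxsize=10)
print(tracker.add(1.0, False))  # (0, 0.0, 0, 0)
print(tracker.add(2.0, True))   # (1, 3.0, 2, 0)
print(tracker.add(5.0, False))  # (2, 0.0, 0, 2)
```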

sample_indices(batch_size: int | None) → ndarray

Get a random sample of indices of size batch_size.

Return all available indices in the buffer if batch_size is 0; return an empty numpy array if batch_size < 0 or no available index can be sampled.

Parameters:
  • batch_size – the number of indices to be sampled. If None, it is set to the length of the buffer (i.e. all available indices are returned in a random order).
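The batch_size conventions can be summarized in a small sketch. sample_indices here is a hypothetical stand-in operating on a buffer that holds size transitions. Sampling with replacement for a positive batch_size, and returning indices in plain index order for batch_size == 0, are assumptions of this sketch rather than documented guarantees.

```python
import random
from typing import List, Optional


def sample_indices(size: int, batch_size: Optional[int]) -> List[int]:
    """Apply the batch_size conventions above to a buffer holding `size` transitions."""
    if size <= 0:
        return []  # nothing available to sample
    if batch_size is None:
        # Every available index exactly once, in a random order.
        return random.sample(range(size), size)
    if batch_size > 0:
        # Assumption: sample with replacement, so batch_size may exceed size.
        return [random.randrange(size) for _ in range(batch_size)]
    if batch_size == 0:
        # Every available index, in plain index order (an assumption).
        return list(range(size))
    return []  # batch_size < 0


print(sample_indices(5, 0))       # [0, 1, 2, 3, 4]
print(sample_indices(5, -1))      # []
print(len(sample_indices(5, 8)))  # 8
```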

sample(batch_size: int | None) → tuple[RolloutBatchProtocol, ndarray]

Get a random sample of size batch_size from the buffer.

Return all the data in the buffer if batch_size is 0.

Returns:

Sample data and its corresponding index inside the buffer.

get(index: int | list[int] | ndarray, key: str, default_value: Any = None, stack_num: int | None = None) → Batch | ndarray

Return the stacked result.

For example, with key = "obs", stack_num = 4 and index = t, it returns the stacked result [obs[t-3], obs[t-2], obs[t-1], obs[t]].

Parameters:
  • index – the index for getting stacked data.

  • key – the key to get, should be one of the reserved_keys.

  • default_value – if the given key’s data is not found and default_value is set, return this default_value.

  • stack_num – the number of frames to stack. Defaults to self.stack_num.
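The stacking rule can be sketched as follows. stacked_indices is a hypothetical helper that walks backwards from index one previous-transition step at a time. The padding rule is an assumption not stated above: at the beginning of an episode, the first frame is repeated instead of reaching into the preceding episode, mirroring the behavior described for prev. Like the other sketches, it ignores the buffer's write position.

```python
from typing import List


def stacked_indices(done: List[bool], index: int, stack_num: int) -> List[int]:
    """Return [t-(stack_num-1), ..., t], repeating the episode's first frame as padding."""
    indices = [index]
    for _ in range(stack_num - 1):
        first = indices[0]
        prev = (first - 1) % len(done)
        # Never reach into the preceding episode: repeat the current first frame instead.
        indices.insert(0, first if done[prev] else prev)
    return indices


obs = ["o0", "o1", "o2", "o3", "o4", "o5"]
done = [False, False, True, False, False, True]  # episodes at indices 0-2 and 3-5
print([obs[i] for i in stacked_indices(done, 5, 4)])  # ['o3', 'o3', 'o4', 'o5']
print([obs[i] for i in stacked_indices(done, 1, 3)])  # ['o0', 'o0', 'o1']
```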

set_array_at_key(seq: ndarray, key: str, index: ndarray | slice | int | ellipsis | Sequence[slice | int | ellipsis] | None = None, default_value: float | None = None) → None

hasnull() → bool

isnull() → RolloutBatchProtocol

dropnull() → None