base#
Source code: tianshou/data/buffer/base.py
- class ReplayBuffer(size: int, stack_num: int = 1, ignore_obs_next: bool = False, save_only_last_obs: bool = False, sample_avail: bool = False, **kwargs: Any)[source]#
ReplayBuffer stores data generated from interaction between the policy and the environment. It can be considered a specialized form (or manager) of Batch: all data is stored in a single Batch and managed in circular-queue style.
For example usage of ReplayBuffer, please check out Section Buffer in Basic concepts in Tianshou; a minimal construction sketch is also given after the parameter list below.
- Parameters:
size – the maximum size of the replay buffer (the number of transitions it can hold).
stack_num – the frame-stack sampling argument; should be greater than or equal to 1. Default to 1 (no stacking).
ignore_obs_next – whether to skip storing obs_next. Default to False.
save_only_last_obs – only save the last obs/obs_next when it has a shape of (timestep, …) because of temporal stacking. Default to False.
sample_avail – whether to sample only indices for which a full frame stack is available when using frame-stack sampling. Default to False.
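A minimal construction sketch, in the spirit of the Buffer section of Basic concepts in Tianshou (the dummy integer transitions are purely illustrative):

from tianshou.data import Batch, ReplayBuffer

# A buffer that holds at most 20 transitions; once full, old entries are
# overwritten in circular-queue fashion.
buf = ReplayBuffer(size=20)

# "obs", "act", "rew", "terminated" and "truncated" are the required keys.
for i in range(3):
    buf.add(
        Batch(obs=i, act=i, rew=i, terminated=False, truncated=False, obs_next=i + 1, info={})
    )

print(len(buf))  # 3 transitions stored so far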
- add(batch: RolloutBatchProtocol, buffer_ids: ndarray | list[int] | None = None) tuple[ndarray, ndarray, ndarray, ndarray] [source]#
Add a batch of data into replay buffer.
- Parameters:
batch – the input data batch. “obs”, “act”, “rew”, “terminated”, “truncated” are required keys.
buffer_ids – kept for consistency with the add method of other buffer classes; if it is not None, the input batch's first dimension is assumed to be 1.
Return (current_index, episode_reward, episode_length, episode_start_index). If the episode is not yet finished, episode_reward and episode_length are 0; a sketch of the return values follows below.
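A self-contained sketch of the return values (dummy transitions; exact array shapes may vary):

from tianshou.data import Batch, ReplayBuffer

buf = ReplayBuffer(size=10)

# Each call returns (current_index, episode_reward, episode_length,
# episode_start_index) as numpy arrays.
ptr, ep_rew, ep_len, ep_idx = buf.add(
    Batch(obs=0, act=0, rew=1.0, terminated=False, truncated=False, obs_next=1, info={})
)
# The episode is not finished yet, so ep_rew and ep_len are 0.

ptr, ep_rew, ep_len, ep_idx = buf.add(
    Batch(obs=1, act=0, rew=1.0, terminated=True, truncated=False, obs_next=2, info={})
)
# terminated=True ends the episode: ep_rew now holds the accumulated
# episode reward (2.0) and ep_len the episode length (2).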
- classmethod from_data(obs: Dataset, act: Dataset, rew: Dataset, terminated: Dataset, truncated: Dataset, done: Dataset, obs_next: Dataset) Self [source]#
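The source provides no further description; the following is a hedged sketch of constructing a buffer from h5py datasets (the file name and dataset layout are assumptions, not part of the API):

import h5py
from tianshou.data import ReplayBuffer

# Hypothetical HDF5 file holding one dataset per reserved key.
with h5py.File("transitions.hdf5", "r") as f:
    buf = ReplayBuffer.from_data(
        obs=f["obs"], act=f["act"], rew=f["rew"],
        terminated=f["terminated"], truncated=f["truncated"],
        done=f["done"], obs_next=f["obs_next"],
    )
    print(len(buf))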
- get(index: int | list[int] | ndarray, key: str, default_value: Any = None, stack_num: int | None = None) Batch | ndarray [source]#
Return the stacked result.
E.g., if you set key = "obs", stack_num = 4, index = t, it returns the stacked result as [obs[t-3], obs[t-2], obs[t-1], obs[t]].
- Parameters:
index – the index for getting stacked data.
key (str) – the key to get, should be one of the reserved_keys.
default_value – if the given key’s data is not found and default_value is set, return this default_value.
stack_num – the number of frames to stack. Default to self.stack_num.
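A brief frame-stacking sketch (integer observations as in the construction sketch above; the index value is illustrative):

from tianshou.data import Batch, ReplayBuffer

buf = ReplayBuffer(size=20, stack_num=4)
for i in range(5):
    buf.add(Batch(obs=i, act=i, rew=0.0, terminated=False, truncated=False, obs_next=i + 1, info={}))

# Stack the 4 most recent observations ending at index 4,
# i.e. roughly [obs[1], obs[2], obs[3], obs[4]].
stacked = buf.get(index=4, key="obs")

# stack_num can also be overridden per call.
last_two = buf.get(index=4, key="obs", stack_num=2)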
- classmethod load_hdf5(path: str, device: str | None = None) Self [source]#
Load replay buffer from HDF5 file.
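A short sketch (assumes a buffer was previously saved via save_hdf5 under the illustrative path below):

from tianshou.data import ReplayBuffer

# The optional device argument is presumably relevant when the stored data
# contains torch tensors; here the default (CPU/NumPy data) is assumed.
buf = ReplayBuffer.load_hdf5("replay_buffer.hdf5")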
- next(index: int | ndarray) ndarray [source]#
Return the index of the next transition.
The index won’t be modified if it is the end of an episode.
- prev(index: int | ndarray) ndarray [source]#
Return the index of the previous transition.
The index won’t be modified if it is the beginning of an episode.
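A small sketch of prev/next on a buffer containing a single finished two-step episode:

import numpy as np
from tianshou.data import Batch, ReplayBuffer

buf = ReplayBuffer(size=10)
buf.add(Batch(obs=0, act=0, rew=0.0, terminated=False, truncated=False, obs_next=1, info={}))
buf.add(Batch(obs=1, act=0, rew=0.0, terminated=True, truncated=False, obs_next=2, info={}))

buf.prev(np.array([0, 1]))  # -> [0, 0]: index 0 starts the episode, so it stays put
buf.next(np.array([0, 1]))  # -> [1, 1]: index 1 ends the episode, so it stays put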
- reset(keep_statistics: bool = False) None [source]#
Clear all the data in the replay buffer and, unless keep_statistics is True, the episode statistics.
- sample(batch_size: int | None) tuple[RolloutBatchProtocol, ndarray] [source]#
Get a random sample from the buffer with size = batch_size.
Return all the data in the buffer if batch_size is 0.
- Returns:
Sample data and its corresponding index inside the buffer.
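A brief sketch (continuing from any populated buffer, such as the construction sketch above):

# Draw 2 random transitions together with their indices in the buffer.
batch, indices = buf.sample(batch_size=2)

# batch_size=0 returns everything currently stored.
all_data, all_indices = buf.sample(batch_size=0)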
- sample_indices(batch_size: int | None) ndarray [source]#
Get a random sample of indices with size = batch_size.
Return all available indices in the buffer if batch_size is 0; return an empty numpy array if batch_size < 0 or no available index can be sampled.
- Parameters:
batch_size – the number of indices to be sampled. If None, it will be set to the length of the buffer (i.e. return all available indices in a random order).
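A brief sketch of the batch_size conventions described above:

idx_two = buf.sample_indices(batch_size=2)                # 2 random indices
idx_all = buf.sample_indices(batch_size=0)                # all available indices
idx_all_shuffled = buf.sample_indices(batch_size=None)    # all available indices, in random order
idx_empty = buf.sample_indices(batch_size=-1)             # empty array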
- save_hdf5(path: str, compression: str | None = None) None [source]#
Save replay buffer to HDF5 file.
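A brief sketch (the path is illustrative; the compression argument is presumably forwarded to h5py, e.g. "gzip"):

# Persist the buffer to disk, optionally with HDF5 compression.
buf.save_hdf5("replay_buffer.hdf5", compression="gzip")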
- set_batch(batch: RolloutBatchProtocol) None [source]#
Manually choose the batch you want the ReplayBuffer to manage.
- update(buffer: ReplayBuffer) ndarray [source]#
Move the data from the given buffer into the current buffer.
Return the updated indices; if the update fails (e.g., the given buffer is empty), return an empty array.
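A small sketch of merging one buffer into another:

from tianshou.data import Batch, ReplayBuffer

small = ReplayBuffer(size=5)
small.add(Batch(obs=0, act=0, rew=0.0, terminated=True, truncated=False, obs_next=1, info={}))

big = ReplayBuffer(size=100)
moved = big.update(small)  # indices the transferred transitions now occupy in big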