class CachedReplayBuffer(main_buffer: ReplayBuffer, cached_buffer_num: int, max_episode_length: int)

CachedReplayBuffer contains a given main buffer and cached_buffer_num cached buffers, each of which is a ReplayBuffer(size=max_episode_length).

The memory layout is: | main_buffer | cached_buffers[0] | cached_buffers[1] | ... | cached_buffers[cached_buffer_num - 1] |.
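To make the layout concrete, the following is a minimal sketch that computes the index range each segment would occupy if the buffers are laid out back to back as described. The helper name `segment_ranges` and the `main_size` parameter are hypothetical and used only for illustration; they are not part of the library's API.

```python
def segment_ranges(
    main_size: int, cached_buffer_num: int, max_episode_length: int
) -> list[tuple[int, int]]:
    """Return (start, end) index ranges, end exclusive, for each segment.

    Segment 0 is the main buffer; segments 1..cached_buffer_num are the
    cached buffers, each max_episode_length slots long.
    Hypothetical helper for illustration, not a library function.
    """
    ranges = [(0, main_size)]
    start = main_size
    for _ in range(cached_buffer_num):
        ranges.append((start, start + max_episode_length))
        start += max_episode_length
    return ranges


layout = segment_ranges(main_size=100, cached_buffer_num=3, max_episode_length=10)
total_size = layout[-1][1]  # end of the last cached buffer
```

Under these assumptions the total number of slots is main_size + cached_buffer_num * max_episode_length.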

Data is first stored in the cached buffers. When an episode terminates, its data is moved to the main buffer and the corresponding cached buffer is reset.
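The cache-then-flush behaviour can be illustrated with a toy model in plain Python. This is a sketch of the idea only, not the library's implementation: the class `ToyCachedBuffer` and its `add_step` method are hypothetical names, transitions are plain strings, and using a bounded `deque` (which drops the oldest step on overflow) is a simplification chosen for this example.

```python
from collections import deque


class ToyCachedBuffer:
    """Toy model: per-environment caches that flush to a main list on episode end."""

    def __init__(self, cached_buffer_num: int, max_episode_length: int) -> None:
        self.main: list[str] = []
        # One bounded cache per environment (toy choice: deque with maxlen).
        self.caches = [
            deque(maxlen=max_episode_length) for _ in range(cached_buffer_num)
        ]

    def add_step(self, buffer_id: int, transition: str, done: bool) -> None:
        cache = self.caches[buffer_id]
        cache.append(transition)
        if done:
            # Episode finished: move the whole episode, then reset the cache.
            self.main.extend(cache)
            cache.clear()


buf = ToyCachedBuffer(cached_buffer_num=2, max_episode_length=5)
buf.add_step(0, "a0", done=False)
buf.add_step(1, "b0", done=False)
buf.add_step(1, "b1", done=True)  # environment 1 finishes first
buf.add_step(0, "a1", done=True)  # environment 0 finishes second
```

In this toy model the steps of the two environments arrive interleaved, yet each episode ends up contiguous in `main`, in the order the episodes finished.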

Parameters:

  • main_buffer – the main buffer whose .update() function behaves normally.

  • cached_buffer_num – the number of ReplayBuffer instances to create as cached buffers.

  • max_episode_length – the maximum length of one episode, used as the maxsize of each cached buffer.

See also

Please refer to ReplayBuffer for the usage of the other APIs.

add(batch: RolloutBatchProtocol, buffer_ids: ndarray | list[int] | None = None) → tuple[ndarray, ndarray, ndarray, ndarray]

Add a batch of data into CachedReplayBuffer.

The length (first dimension) of each item in the data must equal the length of buffer_ids. By default, buffer_ids is [0, 1, …, cached_buffer_num - 1].

Return (current_index, episode_reward, episode_length, episode_start_index), each of shape (len(buffer_ids), …), where (current_index[i], episode_reward[i], episode_length[i], episode_start_index[i]) is the episode result for the buffer_ids[i]-th cached buffer.
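The relationship between batch rows, buffer_ids, and the per-row results can be sketched in plain Python. The function `pair_rows_with_buffers` below is a hypothetical helper written for this illustration; it only mirrors the documented length rule and default, and does not reproduce the library's internal logic.

```python
def pair_rows_with_buffers(rows, cached_buffer_num, buffer_ids=None):
    """Pair row i of a batch with the cached buffer it would be written to.

    Hypothetical helper for illustration, not a library function.
    """
    if buffer_ids is None:
        # Documented default: one row per cached buffer, in order.
        buffer_ids = list(range(cached_buffer_num))
    if len(rows) != len(buffer_ids):
        raise ValueError(
            f"batch has {len(rows)} rows but buffer_ids has {len(buffer_ids)} entries"
        )
    return list(zip(buffer_ids, rows))


# Only environments 1 and 3 produced data on this step.
partial = pair_rows_with_buffers(["r0", "r1"], cached_buffer_num=4, buffer_ids=[1, 3])

# With buffer_ids omitted, row i goes to cached buffer i.
default = pair_rows_with_buffers(["x", "y", "z"], cached_buffer_num=3)
```

Entry i of each returned array is read the same way: it describes the cached buffer named by buffer_ids[i], not the i-th cached buffer overall.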