Buffer: Experience Replay in Tianshou
The replay buffer is a fundamental component in reinforcement learning, particularly for off-policy algorithms. Tianshou’s buffer implementation extends beyond simple data storage to provide sophisticated trajectory tracking, efficient sampling, and seamless integration with the RL training pipeline.
This tutorial provides comprehensive coverage of Tianshou’s buffer system, from basic concepts to advanced features and integration patterns.
import pickle
import tempfile
import numpy as np
from tianshou.data import Batch, PrioritizedReplayBuffer, ReplayBuffer, VectorReplayBuffer
1. Introduction: Why Buffers in Reinforcement Learning?
The Role of Experience Replay
Experience replay is a critical technique in modern reinforcement learning that addresses three fundamental challenges:
Breaking Temporal Correlation: Sequential experiences from an agent are highly correlated. Training directly on these sequences can lead to unstable learning. By storing experiences and sampling randomly, we break these correlations.
Sample Efficiency: In RL, collecting data through environment interaction is often expensive. Experience replay allows us to reuse each experience multiple times for training, dramatically improving sample efficiency.
Mini-batch Training: Modern deep learning requires mini-batch gradient descent. Buffers enable efficient batching of experiences for neural network training.
Why Not Alternatives?
Plain Python Lists
No efficient random sampling
No automatic circular queue behavior
No trajectory boundary tracking
Poor memory management for large datasets
Simple Batch Storage
No automatic overwriting when full
No episode metadata (returns, lengths)
No methods for boundary navigation (prev/next)
No specialized sampling strategies
Buffer = Batch + Trajectory Management + Sampling
Tianshou’s buffers build on the Batch class to provide:
Circular queue storage: Automatic overwriting of oldest data
Trajectory tracking: Episode boundaries, returns, and lengths
Efficient sampling: Random access with various strategies
Integration utilities: Seamless connection to Collector and Policy
Use Cases
Off-policy algorithms: DQN, SAC, TD3, DDPG require experience replay
On-policy with replay: Some PPO implementations reuse buffer data
Offline RL: Loading and using pre-collected datasets
Multi-environment training: VectorReplayBuffer for parallel collection
2. Buffer Types and Hierarchy
Tianshou provides several buffer implementations, each designed for specific use cases. Understanding this hierarchy is crucial for choosing the right buffer.
Buffer Hierarchy
graph TD
RB[ReplayBuffer<br/>Single environment<br/>Circular queue] --> RBM[ReplayBufferManager<br/>Manages multiple buffers<br/>Contiguous memory]
RBM --> VRB[VectorReplayBuffer<br/>Parallel environments<br/>Maintains temporal order]
RB --> PRB[PrioritizedReplayBuffer<br/>TD-error based sampling<br/>Importance weights]
PRB --> PVRB[PrioritizedVectorReplayBuffer<br/>Prioritized + Parallel]
RBM --> CRB[CachedReplayBuffer<br/>Main buffer + per-env caches<br/>Contiguous episodes]
RB --> HERB[HERReplayBuffer<br/>Hindsight Experience Replay<br/>Goal-conditioned RL]
HERB --> HVRB[HERVectorReplayBuffer<br/>HER + Parallel]
style RB fill:#e1f5ff
style RBM fill:#fff4e1
style VRB fill:#ffe1f5
style PRB fill:#e8f5e1
style CRB fill:#f5e1e1
style HERB fill:#e1e1f5
When to Use Which Buffer
ReplayBuffer: Single environment scenarios
Simple setup and testing
Debugging algorithms
Low-parallelism training
VectorReplayBuffer: Multiple parallel environments (most common)
Standard production use case
Efficient parallel data collection
Maintains per-environment episode boundaries
PrioritizedReplayBuffer: DQN variants with prioritization
Rainbow DQN
Algorithms requiring importance sampling
When some transitions are more valuable than others
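CachedReplayBuffer: A main buffer plus one cache buffer per parallel environment
New data is first written to the cache of the environment that produced it
When an episode ends, it is moved to the main buffer as one contiguous block
Useful when complete episodes should not be interleaved across environments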
HERReplayBuffer: Goal-conditioned reinforcement learning
Sparse reward environments
Robotics tasks with explicit goals
Relabeling failed experiences with achieved goals
3. Basic Operations
3.1 Construction and Configuration
The ReplayBuffer constructor accepts several important parameters that control its behavior:
# Create a buffer with all configuration options
buf = ReplayBuffer(
size=20, # Maximum capacity (transitions)
stack_num=1, # Frame stacking for RNNs (default: 1, no stacking)
ignore_obs_next=False, # Save memory by not storing obs_next
save_only_last_obs=False, # For temporal stacking (Atari-style)
sample_avail=False, # Sample only valid indices for frame stacking
random_seed=42, # Reproducible sampling
)
print(f"Buffer created: {buf}")
print(f"Max size: {buf.maxsize}")
print(f"Current length: {len(buf)}")
Buffer created: ReplayBuffer()
Max size: 20
Current length: 0
Parameter Explanations:
size: Maximum number of transitions the buffer can hold. When full, the oldest data is overwritten.
stack_num: Number of consecutive frames to stack. Used for RNN inputs or frame-based policies (Atari).
ignore_obs_next: If True, obs_next is not stored, saving memory. The buffer reconstructs it from the obs stored at the next index when needed.
save_only_last_obs: For observations that are already temporally stacked with shape (timestep, ...), only the last frame is stored.
sample_avail: When True with stack_num > 1, only samples indices where a complete stack is available.
random_seed: Seeds the random number generator for reproducible sampling.
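The ignore_obs_next trade-off is easier to see in a minimal pure-Python sketch of how obs_next can be recovered from obs. The sketch assumes a simplified next-index rule: move one step forward unless the current step ends an episode, in which case stay put. It also assumes a buffer that has not wrapped around. The helper names are hypothetical and are not part of Tianshou's API.

```python
def next_index(i, done, size):
    """Hypothetical helper: next step in the same episode, or i itself at an episode end."""
    return i if done[i] else (i + 1) % size


def reconstruct_obs_next(obs, done):
    """Rebuild obs_next from obs alone, as a buffer with ignore_obs_next=True would need to."""
    size = len(obs)
    return [obs[next_index(i, done, size)] for i in range(size)]


# Two episodes: observations 0-2 and 10-12
obs = [0, 1, 2, 10, 11, 12]
done = [False, False, True, False, False, True]
print(reconstruct_obs_next(obs, done))  # [1, 2, 2, 11, 12, 12]
```

At an episode's final step, the sketch returns that step's own obs because the true next observation was never stored. This is harmless for terminated steps, which do not bootstrap. It is worth keeping in mind when episodes are often truncated.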
3.2 Reserved Keys and the Done Flag System
ReplayBuffer uses nine reserved keys that integrate with Gymnasium conventions. Understanding the done flag system is critical.
# The nine reserved keys
print("Reserved keys:")
print(ReplayBuffer._reserved_keys)
print("\nKeys required for add():")
print(ReplayBuffer._required_keys_for_add)
Reserved keys:
('obs', 'act', 'rew', 'terminated', 'truncated', 'done', 'obs_next', 'info', 'policy')
Keys required for add():
{'act', 'truncated', 'rew', 'done', 'obs', 'terminated'}
Important: Understanding done, terminated, and truncated
Gymnasium (the successor to OpenAI Gym) introduced a crucial distinction:
terminated: The episode ended naturally because a terminal state was reached (the agent reached the goal or failed). Examples: CartPole fell over, the agent reached the goal state.
A terminated step has no future, so value targets must not bootstrap from obs_next
truncated: The episode was cut off artificially (time limit, external interruption). Examples: the maximum episode length was reached, the environment was reset externally.
Truncation must NOT be treated as termination: the episode could have continued, so value targets should still bootstrap from obs_next
done: Computed automatically as terminated OR truncated. Used internally for episode boundary tracking.
You should NEVER manually set this field
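The distinction matters most in value targets. The sketch below is a generic one-step TD target in plain Python, not a Tianshou function; td_target and the numbers are illustrative. It masks bootstrapping with terminated only, so a truncated step still uses the value of obs_next. Masking with done instead would wrongly treat a time limit as a terminal state.

```python
def td_target(rew, terminated, next_value, gamma):
    """One-step TD target r + gamma * V(s'), with V(s') dropped only for true terminal states."""
    return rew + gamma * next_value * (1.0 - float(terminated))


gamma = 0.5
next_value = 4.0  # estimated value of obs_next

# A time-limit step: truncated=True, terminated=False, so done=True
terminated, truncated = False, True
done = terminated or truncated

correct = td_target(1.0, terminated, next_value, gamma)  # 1 + 0.5 * 4 = 3.0
wrong = td_target(1.0, done, next_value, gamma)  # bootstrapping wrongly cut off
print(correct, wrong)  # 3.0 1.0
```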
Best Practice: Always use the info dictionary for custom metadata rather than adding top-level keys:
# GOOD: Custom metadata in info dictionary
good_batch = Batch(
obs=np.array([1.0, 2.0]),
act=0,
rew=1.0,
terminated=False,
truncated=False,
obs_next=np.array([1.5, 2.5]),
info={"custom_metric": 0.95, "step_count": 10}, # Custom data here
)
# BAD: Don't add custom top-level keys (may conflict with future buffer features)
# bad_batch = Batch(..., custom_metric=0.95) # Don't do this!
print("Good batch structure:")
print(good_batch)
Good batch structure:
Batch(
obs: array([1., 2.]),
act: array(0),
rew: array(1.),
terminated: array(False),
truncated: array(False),
obs_next: array([1.5, 2.5]),
info: Batch(
custom_metric: array(0.95),
step_count: array(10),
),
)
3.3 Circular Queue Storage
The buffer implements a circular queue: when it reaches maximum capacity, new data overwrites the oldest entries.
# Create a small buffer to demonstrate circular behavior
demo_buf = ReplayBuffer(size=5)
print("Adding 3 transitions:")
for i in range(3):
    demo_buf.add(
        Batch(
            obs=i,
            act=i,
            rew=float(i),
            terminated=False,
            truncated=False,
            obs_next=i + 1,
            info={},
        )
    )
print(f"Length: {len(demo_buf)}, Max: {demo_buf.maxsize}")
print(f"Observations: {demo_buf.obs[: len(demo_buf)]}")
print("\nAdding 5 more transitions (total 8, exceeds capacity 5):")
for i in range(3, 8):
    demo_buf.add(
        Batch(
            obs=i,
            act=i,
            rew=float(i),
            terminated=False,
            truncated=False,
            obs_next=i + 1,
            info={},
        )
    )
print(f"Length: {len(demo_buf)}, Max: {demo_buf.maxsize}")
print(f"Observations: {demo_buf.obs[: len(demo_buf)]}")
print("\nNotice: The first 3 transitions (0, 1, 2) were overwritten by (5, 6, 7)")
print("Oldest to newest, the buffer holds [3, 4, 5, 6, 7]; storage order is [5, 6, 7, 3, 4]")
Adding 3 transitions:
Length: 3, Max: 5
Observations: [0 1 2]
Adding 5 more transitions (total 8, exceeds capacity 5):
Length: 5, Max: 5
Observations: [5 6 7 3 4]
Notice: The first 3 transitions (0, 1, 2) were overwritten by (5, 6, 7)
Oldest to newest, the buffer holds [3, 4, 5, 6, 7]; storage order is [5, 6, 7, 3, 4]
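The wraparound above follows from a single write pointer taken modulo the capacity. The following minimal pure-Python sketch illustrates the idea. The RingBuffer class is hypothetical and is not Tianshou's implementation. Its ordered() method rebuilds oldest-to-newest order by reading from the write pointer to the end, then from the start.

```python
class RingBuffer:
    """Hypothetical minimal circular queue with a wrapping write pointer."""

    def __init__(self, size):
        self.data = [None] * size
        self.size = size
        self.index = 0  # next slot to write
        self.length = 0  # number of filled slots

    def add(self, item):
        ptr = self.index
        self.data[ptr] = item  # overwrites the oldest item once the buffer is full
        self.index = (self.index + 1) % self.size
        self.length = min(self.length + 1, self.size)
        return ptr

    def ordered(self):
        """Stored items from oldest to newest."""
        if self.length < self.size:
            return self.data[: self.length]
        return self.data[self.index :] + self.data[: self.index]


ring = RingBuffer(size=5)
for step in range(8):
    ring.add(step)
print(ring.data)  # [5, 6, 7, 3, 4]  (storage order)
print(ring.ordered())  # [3, 4, 5, 6, 7]  (oldest to newest)
```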
3.4 Batch-Compatible Operations
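ReplayBuffer stores its data in an internal Batch and supports Batch-style indexing and slicing. Indices refer to raw storage positions, not insertion order. After the wraparound above, demo_buf[-1] is the last storage slot (obs=4), not the most recently added transition (obs=7).
# Indexing and slicing address raw storage positions
print("Storage index -1 (last slot, not the newest transition):")
print(demo_buf[-1])
print("\nLast 3 storage slots:")
print(demo_buf[-3:])
print("\nSpecific indices [0, 2, 4]:")
print(demo_buf[np.array([0, 2, 4])])
Storage index -1 (last slot, not the newest transition):
Batch(
obs: array(4),
act: array(4),
rew: array(4.),
terminated: array(False),
truncated: array(False),
done: array(False),
obs_next: array(5),
info: Batch(),
policy: Batch(),
)
Last 3 storage slots: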
Batch(
obs: array([7, 3, 4]),
act: array([7, 3, 4]),
rew: array([7., 3., 4.]),
terminated: array([False, False, False]),
truncated: array([False, False, False]),
done: array([False, False, False]),
obs_next: array([8, 4, 5]),
info: Batch(),
policy: Batch(),
)
Specific indices [0, 2, 4]:
Batch(
obs: array([5, 7, 4]),
act: array([5, 7, 4]),
rew: array([5., 7., 4.]),
terminated: array([False, False, False]),
truncated: array([False, False, False]),
done: array([False, False, False]),
obs_next: array([6, 8, 5]),
info: Batch(),
policy: Batch(),
)
4. Trajectory Management
A key distinguishing feature of ReplayBuffer is its automatic tracking of episode boundaries and metadata.
4.1 Episode Tracking and Metadata
The add() method returns four values that provide episode information:
# Create a fresh buffer for trajectory demonstration
traj_buf = ReplayBuffer(size=20)
print("Episode 1: 4 steps, terminates naturally")
for i in range(4):
    idx, ep_rew, ep_len, ep_start = traj_buf.add(
        Batch(
            obs=i,
            act=i,
            rew=float(i + 1),  # Rewards: 1, 2, 3, 4
            terminated=i == 3,  # Last step terminates
            truncated=False,
            obs_next=i + 1,
            info={},
        )
    )
    print(f"  Step {i}: idx={idx}, ep_rew={ep_rew}, ep_len={ep_len}, ep_start={ep_start}")
print("\nNotice: Episode return (10.0) and length (4) only appear at the end!")
Episode 1: 4 steps, terminates naturally
Step 0: idx=[0], ep_rew=[0.], ep_len=[0], ep_start=[0]
Step 1: idx=[1], ep_rew=[0.], ep_len=[0], ep_start=[0]
Step 2: idx=[2], ep_rew=[0.], ep_len=[0], ep_start=[0]
Step 3: idx=[3], ep_rew=[10.], ep_len=[4], ep_start=[0]
Notice: Episode return (10.0) and length (4) only appear at the end!
Return Values Explained:
idx: Index where the transition was inserted (np.ndarray of shape (1,))
ep_rew: Episode return, reported only when done=True and zero otherwise (np.ndarray of shape (1,))
ep_len: Episode length, reported only when done=True and zero otherwise (np.ndarray of shape (1,))
ep_start: Index where the episode started (np.ndarray of shape (1,))
This automatic computation eliminates manual episode tracking during data collection.
# Continue with Episode 2: 5 steps
print("Episode 2: 5 steps, truncated (time limit)")
for i in range(4, 9):
    idx, ep_rew, ep_len, ep_start = traj_buf.add(
        Batch(
            obs=i,
            act=i,
            rew=float(i + 1),
            terminated=False,
            truncated=i == 8,  # Last step truncated
            obs_next=i + 1,
            info={},
        )
    )
    if i == 8:
        print(
            f"  Final step: idx={idx}, ep_rew={ep_rew[0]:.1f}, ep_len={ep_len[0]}, ep_start={ep_start}"
        )
# Episode 3: Ongoing (not finished)
print("\nEpisode 3: 3 steps, ongoing (not done)")
for i in range(9, 12):
    idx, ep_rew, ep_len, ep_start = traj_buf.add(
        Batch(
            obs=i,
            act=i,
            rew=float(i + 1),
            terminated=False,
            truncated=False,  # Episode continues
            obs_next=i + 1,
            info={},
        )
    )
    if i == 11:
        print(
            f"  Latest step: idx={idx}, ep_rew={ep_rew}, ep_len={ep_len} (zeros because not done)"
        )
print(f"\nBuffer state: {len(traj_buf)} transitions across 2 complete + 1 ongoing episode")
Episode 2: 5 steps, truncated (time limit)
Final step: idx=[8], ep_rew=35.0, ep_len=5, ep_start=[4]
Episode 3: 3 steps, ongoing (not done)
Latest step: idx=[11], ep_rew=[0.], ep_len=[0] (zeros because not done)
Buffer state: 12 transitions across 2 complete + 1 ongoing episode
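The bookkeeping behind these return values can be sketched in plain Python. EpisodeTracker is a hypothetical class written for illustration, not Tianshou's code, and done here stands for terminated OR truncated. Running sums grow on every add. They are reported and reset only when an episode ends, and the next episode's start is the slot after the finishing step. The two episodes below reuse the rewards from the walkthrough above.

```python
class EpisodeTracker:
    """Hypothetical sketch of per-episode bookkeeping alongside a circular buffer."""

    def __init__(self, size):
        self.size = size
        self.index = 0  # next write position
        self.ep_rew = 0.0
        self.ep_len = 0
        self.ep_start = 0

    def add(self, rew, done):
        ptr = self.index
        self.index = (self.index + 1) % self.size
        self.ep_rew += rew
        self.ep_len += 1
        if done:
            result = (ptr, self.ep_rew, self.ep_len, self.ep_start)
            # The next episode starts at the slot after this one
            self.ep_rew, self.ep_len, self.ep_start = 0.0, 0, self.index
            return result
        # Mid-episode: report zeros for return and length
        return (ptr, 0.0, 0, self.ep_start)


tracker = EpisodeTracker(size=20)
episode_1 = [tracker.add(rew=float(i + 1), done=(i == 3)) for i in range(4)]
episode_2 = [tracker.add(rew=float(i + 1), done=(i == 8)) for i in range(4, 9)]
print(episode_1[-1])  # (3, 10.0, 4, 0)
print(episode_2[-1])  # (8, 35.0, 5, 4)
```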
4.2 Identifying Unfinished Episodes
The unfinished_index() method returns indices of ongoing episodes:
unfinished = traj_buf.unfinished_index()
print(f"Unfinished episode indices: {unfinished}")
print(f"Latest step of ongoing episode: obs={traj_buf.obs[unfinished[0]]}")
# After finishing episode 3
traj_buf.add(
Batch(
obs=12,
act=12,
rew=13.0,
terminated=True,
truncated=False,
obs_next=13,
info={},
)
)
unfinished_after = traj_buf.unfinished_index()
print("\nAfter finishing episode 3:")
print(f"Unfinished episodes: {unfinished_after} (empty array)")
Unfinished episode indices: [11]
Latest step of ongoing episode: obs=11
After finishing episode 3:
Unfinished episodes: [] (empty array)
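The check behind unfinished_index() can be sketched in a few lines. An episode counts as unfinished when the most recently written step of a (sub)buffer is not flagged as done. unfinished_index below is a hypothetical helper written for illustration. It assumes one last-written index per subbuffer; a single ReplayBuffer is the one-subbuffer case.

```python
def unfinished_index(done, last_indices, lengths):
    """Hypothetical helper: newest index of each non-empty (sub)buffer whose episode is still running."""
    return [last for last, n in zip(last_indices, lengths) if n > 0 and not done[last]]


# One buffer with 12 steps written; the newest step (index 11) is not done
done = [False] * 20
done[3] = done[8] = True
print(unfinished_index(done, last_indices=[11], lengths=[12]))  # [11]

# After the episode ends at index 12, nothing is unfinished
done[12] = True
print(unfinished_index(done, last_indices=[12], lengths=[13]))  # []
```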
5. Sampling Strategies
Efficient sampling is critical for RL training. The buffer provides several sampling methods and strategies.
5.1 Basic Sampling
# Create a buffer with some data
sample_buf = ReplayBuffer(size=100)
for i in range(50):
    sample_buf.add(
        Batch(
            obs=i,
            act=i % 4,
            rew=np.random.random(),
            terminated=(i + 1) % 10 == 0,
            truncated=False,
            obs_next=i + 1,
            info={},
        )
    )
# Sample with batch_size
batch, indices = sample_buf.sample(batch_size=8)
print(f"Sampled batch size: {len(batch)}")
print(f"Sampled indices: {indices}")
print(f"Sampled observations: {batch.obs}")
# batch_size=None: return all data in random order
all_data, all_indices = sample_buf.sample(batch_size=None)
print(f"\nSample all (batch_size=None): {len(all_data)} transitions")
# batch_size=0: return all data in buffer order
ordered_data, ordered_indices = sample_buf.sample(batch_size=0)
print(f"Get all in order (batch_size=0): {len(ordered_data)} transitions")
print(f"Indices in order: {ordered_indices[:10]}...")  # Show first 10
Sampled batch size: 8
Sampled indices: [38 28 14 42 7 20 38 18]
Sampled observations: [38 28 14 42 7 20 38 18]
Sample all (batch_size=None): 50 transitions
Get all in order (batch_size=0): 50 transitions
Indices in order: [0 1 2 3 4 5 6 7 8 9]...
Sampling Behavior Summary:
batch_size > 0: Random sample of the specified size, drawn with replacement (duplicates are possible, as with index 38 above)
batch_size = None: All data in random order
batch_size = 0: All data in insertion order
batch_size < 0: Empty array (edge case handling)
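The integer cases in this summary can be illustrated with a pure-Python sketch; this is not Tianshou's code. sample_indices is a hypothetical function that takes the fill level and the write pointer. It draws with replacement for positive sizes and rebuilds insertion order from the write pointer for batch_size=0. The batch_size=None case is not covered here.

```python
import random


def sample_indices(length, write_ptr, batch_size, rng):
    """Hypothetical sketch of index selection for a buffer holding `length` items."""
    if batch_size > 0:
        # Uniform draws with replacement, so duplicates are possible
        return [rng.randrange(length) for _ in range(batch_size)]
    if batch_size == 0:
        # Oldest to newest: from the write pointer to the end, then from the start
        return list(range(write_ptr, length)) + list(range(write_ptr))
    return []  # negative batch sizes yield nothing


rng = random.Random(0)
print(sample_indices(5, 3, 0, rng))  # [3, 4, 0, 1, 2]  (full buffer that wrapped)
print(sample_indices(50, 50, 0, rng)[:5])  # [0, 1, 2, 3, 4]
print(sample_indices(50, 50, -1, rng))  # []
```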
5.2 Frame Stacking
The stack_num parameter enables automatic frame stacking, useful for RNN inputs or Atari-style environments where temporal context matters:
# Create buffer with frame stacking
stack_buf = ReplayBuffer(size=20, stack_num=4)
# Add observations: 0, 1, 2, ..., 9
for i in range(10):
    stack_buf.add(
        Batch(
            obs=np.array([i]),  # Single frame
            act=0,
            rew=1.0,
            terminated=i == 9,
            truncated=False,
            obs_next=np.array([i + 1]),
            info={},
        )
    )
# Get stacked frames for index 6
# Should return [3, 4, 5, 6] (4 consecutive frames ending at 6)
stacked = stack_buf.get(index=6, key="obs")
print("Frame stacking demo:")
print("Requested index: 6")
print(f"Stacked frames shape: {stacked.shape}")
print(f"Stacked frames: {stacked.flatten()}")
print("\nExplanation: stack_num=4, so index 6 returns [obs[3], obs[4], obs[5], obs[6]]")
Frame stacking demo:
Requested index: 6
Stacked frames shape: (4, 1)
Stacked frames: [3 4 5 6]
Explanation: stack_num=4, so index 6 returns [obs[3], obs[4], obs[5], obs[6]]
# Demonstrate episode boundary handling with frame stacking
boundary_buf = ReplayBuffer(size=20, stack_num=4)
# Episode 1: indices 0-4
for i in range(5):
    boundary_buf.add(
        Batch(
            obs=np.array([i]),
            act=0,
            rew=1.0,
            terminated=i == 4,
            truncated=False,
            obs_next=np.array([i + 1]),
            info={},
        )
    )
# Episode 2: indices 5-9
for i in range(5, 10):
    boundary_buf.add(
        Batch(
            obs=np.array([i]),
            act=0,
            rew=1.0,
            terminated=i == 9,
            truncated=False,
            obs_next=np.array([i + 1]),
            info={},
        )
    )
# Try to get stacked frames at episode boundary
boundary_stack = boundary_buf.get(index=6, key="obs")  # Early in episode 2
print("\nFrame stacking at episode boundary:")
print(f"Index 6 stacked frames: {boundary_stack.flatten()}")
print("Notice: Frames don't cross episode boundary (5,5,5,6 not 3,4,5,6)")
print("The buffer uses prev() internally, which respects episode boundaries")
Frame stacking at episode boundary:
Index 6 stacked frames: [5 5 5 6]
Notice: Frames don't cross episode boundary (5,5,5,6 not 3,4,5,6)
The buffer uses prev() internally, which respects episode boundaries
Frame Stacking Use Cases:
RNN/LSTM inputs: Provide temporal context to recurrent networks
Atari games: Stack 4 frames to capture motion (as in DQN paper)
Velocity estimation: Multiple frames allow computing derivatives
Partially observable environments: Build up state estimates
Important Notes:
Frame stacking respects episode boundaries (won’t stack across episodes)
Set sample_avail=True to only sample indices where full stacks are available
Set save_only_last_obs=True to save memory in Atari-style setups
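The boundary behavior above can be reproduced with a short sketch, assuming a buffer that has not wrapped around. prev_index and stack_frames are hypothetical helpers, not Tianshou functions. The sketch walks backwards from the requested index and stops moving once the previous step ended an episode. As a result, early frames are padded by repeating the episode's first observation.

```python
def prev_index(i, done):
    """Hypothetical helper: previous step in the same episode, or i itself at an episode start."""
    if i == 0 or done[i - 1]:
        return i
    return i - 1


def stack_frames(obs, done, index, stack_num):
    """Collect stack_num frames ending at index, walking backwards with prev_index."""
    frames = [obs[index]]
    cur = index
    for _ in range(stack_num - 1):
        cur = prev_index(cur, done)
        frames.append(obs[cur])  # repeats the first frame once the episode start is reached
    return frames[::-1]


obs = list(range(10))
one_episode = [False] * 9 + [True]
two_episodes = [False] * 4 + [True] + [False] * 4 + [True]
print(stack_frames(obs, one_episode, 6, 4))  # [3, 4, 5, 6]
print(stack_frames(obs, two_episodes, 6, 4))  # [5, 5, 5, 6]
```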
6. VectorReplayBuffer: Parallel Environment Support
VectorReplayBuffer is essential for modern RL training with parallel environments. It maintains separate subbuffers for each environment while providing a unified interface.
6.1 Motivation and Architecture
When training with multiple parallel environments (e.g., 8 environments running simultaneously), we need:
Per-environment episode tracking: Each environment has its own episode boundaries
Temporal ordering: Preserve the sequence of events within each environment
Unified sampling: Sample uniformly across all environments for training
graph LR
E1[Env 1] --> B1[Subbuffer 1<br/>2500 capacity]
E2[Env 2] --> B2[Subbuffer 2<br/>2500 capacity]
E3[Env 3] --> B3[Subbuffer 3<br/>2500 capacity]
E4[Env 4] --> B4[Subbuffer 4<br/>2500 capacity]
B1 --> VRB[VectorReplayBuffer<br/>Total: 10000<br/>Unified Sampling]
B2 --> VRB
B3 --> VRB
B4 --> VRB
VRB --> Policy[Policy Training]
style E1 fill:#e1f5ff
style E2 fill:#e1f5ff
style E3 fill:#e1f5ff
style E4 fill:#e1f5ff
style B1 fill:#fff4e1
style B2 fill:#fff4e1
style B3 fill:#fff4e1
style B4 fill:#fff4e1
style VRB fill:#ffe1f5
style Policy fill:#e8f5e1
# Create VectorReplayBuffer for 4 parallel environments
vec_buf = VectorReplayBuffer(
total_size=100, # Total capacity across all subbuffers
buffer_num=4, # Number of parallel environments
)
print("VectorReplayBuffer created:")
print(f"Total size: {vec_buf.maxsize}")
print(f"Number of subbuffers: {vec_buf.buffer_num}")
print(f"Size per subbuffer: {vec_buf.maxsize // vec_buf.buffer_num}")
print(f"Subbuffer edges: {vec_buf.subbuffer_edges}")
print("\nSubbuffer edges define the boundary indices: [0, 25, 50, 75, 100]")
print("Subbuffer 0: indices 0-24, Subbuffer 1: indices 25-49, etc.")
VectorReplayBuffer created:
Total size: 100
Number of subbuffers: 4
Size per subbuffer: 25
Subbuffer edges: [ 0 25 50 75 100]
Subbuffer edges define the boundary indices: [0, 25, 50, 75, 100]
Subbuffer 0: indices 0-24, Subbuffer 1: indices 25-49, etc.
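The edges are the cumulative subbuffer sizes. The sketch below computes them and maps a global index back to its subbuffer with a binary search. The helper names are hypothetical. Rounding the per-subbuffer size up when total_size is not divisible by buffer_num is an assumption of this sketch.

```python
import bisect
import math


def subbuffer_edges(total_size, buffer_num):
    """Hypothetical helper: equal-sized, contiguous index ranges for each subbuffer."""
    per_buffer = math.ceil(total_size / buffer_num)  # assumption: sizes are rounded up
    return [b * per_buffer for b in range(buffer_num + 1)]


def subbuffer_of(index, edges):
    """Which subbuffer a global index falls into."""
    return bisect.bisect_right(edges, index) - 1


edges = subbuffer_edges(total_size=100, buffer_num=4)
print(edges)  # [0, 25, 50, 75, 100]
print(subbuffer_of(49, edges))  # 1
print(subbuffer_of(50, edges))  # 2
```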
6.2 The buffer_ids Parameter
This is one of the most confusing aspects for new users. The buffer_ids parameter specifies which subbuffer each transition belongs to.
# Simulate data from 4 parallel environments
# Each environment produces one transition
parallel_batch = Batch(
obs=np.array([[0.1, 0.2], [1.1, 1.2], [2.1, 2.2], [3.1, 3.2]]), # 4 observations
act=np.array([0, 1, 0, 1]), # 4 actions
rew=np.array([1.0, 2.0, 3.0, 4.0]), # 4 rewards
terminated=np.array([False, False, False, False]),
truncated=np.array([False, False, False, False]),
obs_next=np.array([[0.2, 0.3], [1.2, 1.3], [2.2, 2.3], [3.2, 3.3]]),
info=np.array([{}, {}, {}, {}], dtype=object),
)
print("Parallel batch shape:", parallel_batch.obs.shape)
print("This represents 4 transitions, one from each environment")
# Add with buffer_ids specifying which subbuffer each transition goes to
indices, ep_rews, ep_lens, ep_starts = vec_buf.add(
parallel_batch,
buffer_ids=[0, 1, 2, 3], # Transition 0→Subbuf 0, 1→Subbuf 1, etc.
)
print(f"\nAdded to indices: {indices}")
print("Notice: Indices are in different subbuffers:")
print(f" Index {indices[0]} in subbuffer 0 (range 0-24)")
print(f" Index {indices[1]} in subbuffer 1 (range 25-49)")
print(f" Index {indices[2]} in subbuffer 2 (range 50-74)")
print(f" Index {indices[3]} in subbuffer 3 (range 75-99)")
Parallel batch shape: (4, 2)
This represents 4 transitions, one from each environment
Added to indices: [ 0 25 50 75]
Notice: Indices are in different subbuffers:
Index 0 in subbuffer 0 (range 0-24)
Index 25 in subbuffer 1 (range 25-49)
Index 50 in subbuffer 2 (range 50-74)
Index 75 in subbuffer 3 (range 75-99)
# Add more data to demonstrate buffer_ids
# Environments don't always produce data in order 0,1,2,3
# For example, if only environments 1 and 3 are ready:
partial_batch = Batch(
obs=np.array([[1.2, 1.3], [3.2, 3.3]]), # Only 2 observations
act=np.array([0, 1]),
rew=np.array([2.5, 4.5]),
terminated=np.array([False, False]),
truncated=np.array([False, False]),
obs_next=np.array([[1.3, 1.4], [3.3, 3.4]]),
info=np.array([{}, {}], dtype=object),
)
# Only environments 1 and 3 produced data
indices2, _, _, _ = vec_buf.add(
partial_batch,
buffer_ids=[1, 3], # Only these two subbuffers receive data
)
print("Added partial batch (only envs 1 and 3):")
print(f"Indices: {indices2}")
print(f"Subbuffer 1 received data at index {indices2[0]}")
print(f"Subbuffer 3 received data at index {indices2[1]}")
Added partial batch (only envs 1 and 3):
Indices: [26 76]
Subbuffer 1 received data at index 26
Subbuffer 3 received data at index 76
Important: buffer_ids Requirements:
For VectorReplayBuffer:
The length of buffer_ids must match the batch size
Values must be in range [0, buffer_num)
Can be partial (not all environments at once)
For regular ReplayBuffer:
If buffer_ids is not None, it must be [0]
The batch must then have shape (1, data_length)
This is for API compatibility with VectorReplayBuffer
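The routing rule can be sketched as one flat array split into equal subbuffers, each with its own write pointer. VectorRingSketch is a hypothetical class, not Tianshou's implementation, and it assumes total_size is divisible by buffer_num. The global index is the subbuffer's offset plus its local write position. Wraparound stays inside each subbuffer, which is why partial batches keep each environment's data in its own range.

```python
class VectorRingSketch:
    """Hypothetical sketch: one flat array split into equal subbuffers, each with its own pointer."""

    def __init__(self, total_size, buffer_num):
        self.per_buffer = total_size // buffer_num  # assumes total_size divides evenly
        self.data = [None] * (self.per_buffer * buffer_num)
        self.ptr = [0] * buffer_num  # next local slot in each subbuffer

    def add(self, items, buffer_ids):
        if len(items) != len(buffer_ids):
            raise ValueError("buffer_ids must have one entry per item")
        global_indices = []
        for item, b in zip(items, buffer_ids):
            if not 0 <= b < len(self.ptr):
                raise ValueError(f"buffer id {b} out of range")
            idx = b * self.per_buffer + self.ptr[b]  # offset of subbuffer b + local slot
            self.data[idx] = item
            self.ptr[b] = (self.ptr[b] + 1) % self.per_buffer  # wrap inside subbuffer b only
            global_indices.append(idx)
        return global_indices


vec = VectorRingSketch(total_size=100, buffer_num=4)
print(vec.add(["a0", "a1", "a2", "a3"], buffer_ids=[0, 1, 2, 3]))  # [0, 25, 50, 75]
print(vec.add(["b1", "b3"], buffer_ids=[1, 3]))  # [26, 76]
```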
6.3 Subbuffer Edges and Episode Handling
Subbuffer edges prevent episodes from spanning across subbuffers, ensuring data from different environments doesn’t get mixed:
# The subbuffer_edges property defines boundaries
print(f"Subbuffer edges: {vec_buf.subbuffer_edges}")
print("\nThis creates 4 subbuffers:")
for i in range(vec_buf.buffer_num):
    start = vec_buf.subbuffer_edges[i]
    end = vec_buf.subbuffer_edges[i + 1]
    print(f"Subbuffer {i}: indices [{start}, {end})")
# Episodes cannot cross these boundaries
# prev() and next() respect subbuffer edges just like episode boundaries
test_idx = np.array([24, 25, 49, 50])  # At subbuffer edges
prev_result = vec_buf.prev(test_idx)
next_result = vec_buf.next(test_idx)
print("\nBoundary navigation test:")
print(f"Indices: {test_idx}")
print(f"prev(): {prev_result}")
print(f"next(): {next_result}")
print("\nNotice: prev/next never leave the subbuffer an index belongs to")
print("(Subbuffers hold only 1-2 transitions so far, so unfilled slots 24 and 49 map back into the filled part)")
Subbuffer edges: [ 0 25 50 75 100]
This creates 4 subbuffers:
Subbuffer 0: indices [0, 25)
Subbuffer 1: indices [25, 50)
Subbuffer 2: indices [50, 75)
Subbuffer 3: indices [75, 100)
Boundary navigation test:
Indices: [24 25 49 50]
prev(): [ 0 25 25 50]
next(): [ 0 26 26 50]
Notice: prev/next never leave the subbuffer an index belongs to
(Subbuffers hold only 1-2 transitions so far, so unfilled slots 24 and 49 map back into the filled part)
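The prev() results above can be reproduced with a small sketch in which local positions are taken modulo each subbuffer's current fill level. The arithmetic is modeled on the observed output, not copied from Tianshou. prev_global and the state values below are assumptions that reconstruct this point of the walkthrough: subbuffers 0 and 2 hold one transition each, and subbuffers 1 and 3 hold two.

```python
def prev_global(index, edges, lengths, last_indices, done):
    """Hypothetical sketch: previous index within the same subbuffer and episode.

    Local positions are taken modulo the subbuffer's current fill level, so the result
    never leaves [edges[b], edges[b + 1]) for the subbuffer b that contains `index`.
    """
    for b in range(len(edges) - 1):
        start, end = edges[b], edges[b + 1]
        if start <= index < end:
            n = max(1, lengths[b])
            cand = (index - start - 1) % n  # one step back, wrapped within the filled part
            # Stay put if the candidate ended an episode or is the newest step (a wrap-around)
            stop = done[start + cand] or start + cand == last_indices[b]
            return start + (cand + int(stop)) % n
    raise IndexError(index)


edges = [0, 25, 50, 75, 100]
lengths = [1, 2, 1, 2]  # fill levels at that point in the walkthrough
last_indices = [0, 26, 50, 76]  # newest global index in each subbuffer
done = [False] * 100
print([prev_global(i, edges, lengths, last_indices, done) for i in [24, 25, 49, 50]])
# [0, 25, 25, 50]
```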
6.4 Sampling from VectorReplayBuffer
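Sampling is uniform over all stored transitions. Each subbuffer is chosen with probability proportional to its current fill level, then a transition is drawn uniformly within it. The returned indices are grouped by subbuffer rather than shuffled: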
# Add more data to have enough for sampling
for _step in range(10):
    batch = Batch(
        obs=np.random.randn(4, 2),
        act=np.random.randint(0, 2, size=4),
        rew=np.random.random(4),
        terminated=np.zeros(4, dtype=bool),
        truncated=np.zeros(4, dtype=bool),
        obs_next=np.random.randn(4, 2),
        info=np.array([{}] * 4, dtype=object),
    )
    vec_buf.add(batch, buffer_ids=[0, 1, 2, 3])
# Sample batch
sampled, indices = vec_buf.sample(batch_size=16)
print(f"Sampled {len(sampled)} transitions")
print(f"Sample indices (from different subbuffers): {indices}")
print("\nNotice indices span across all subbuffer ranges")
Sampled 16 transitions
Sample indices (from different subbuffers): [ 6 3 10 7 4 6 9 31 56 53 60 57 81 78 85 82]
Notice indices span across all subbuffer ranges
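A two-stage scheme like this is uniform per transition. A subbuffer is picked with probability len_b / total, then a slot within it with probability 1 / len_b, so every stored transition has probability 1 / total. The sketch below is hypothetical (sample_vector is not a Tianshou function). It uses the fill levels reached after the loop above and emits indices grouped by subbuffer.

```python
import random


def sample_vector(lengths, edges, batch_size, rng):
    """Hypothetical sketch: pick subbuffers in proportion to fill level, then slots uniformly."""
    buffer_ids = range(len(lengths))
    picks = rng.choices(buffer_ids, weights=lengths, k=batch_size)
    indices = []
    for b in buffer_ids:  # results come out grouped by subbuffer
        for _ in range(picks.count(b)):
            indices.append(edges[b] + rng.randrange(lengths[b]))
    return indices


lengths = [11, 12, 11, 12]  # fill levels after the loop above
edges = [0, 25, 50, 75, 100]
indices = sample_vector(lengths, edges, batch_size=16, rng=random.Random(0))
print(indices)
```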
7. Specialized Buffer Variants
7.1 PrioritizedReplayBuffer
Implements prioritized experience replay where transitions are sampled based on their TD-error magnitudes:
# Create prioritized buffer
prio_buf = PrioritizedReplayBuffer(
size=100,
alpha=0.6, # Prioritization exponent (0=uniform, 1=fully prioritized)
beta=0.4, # Importance sampling correction (annealed to 1)
)
# Add some transitions
for i in range(20):
    prio_buf.add(
        Batch(
            obs=np.array([i]),
            act=i % 4,
            rew=np.random.random(),
            terminated=False,
            truncated=False,
            obs_next=np.array([i + 1]),
            info={},
        )
    )
# Sample returns batch and indices
# Importance weights are INSIDE the batch as batch.weight
batch, indices = prio_buf.sample(batch_size=8)
print(f"Sampled batch size: {len(batch)}")
print(f"Indices: {indices}")
print(f"Importance weights (batch.weight): {batch.weight}")
print("\nWeights are stored in batch.weight and compensate for biased sampling")
Sampled batch size: 8
Indices: [ 3 12 14 12 19 9 8 7]
Importance weights (batch.weight): [1. 1. 1. 1. 1. 1. 1. 1.]
Weights are stored in batch.weight and compensate for biased sampling
# After computing TD-errors from the sampled batch, update priorities
# In practice, these would be actual TD-errors: |Q(s,a) - (r + γ*max Q(s',a'))|
fake_td_errors = np.random.random(len(indices)) * 10 # Simulated TD-errors
# Update priorities (higher TD-error = higher priority)
prio_buf.update_weight(indices, fake_td_errors)
print("Updated priorities based on TD-errors")
print("Transitions with higher TD-errors will be sampled more frequently")
# Demonstrate beta annealing
prio_buf.set_beta(0.6)  # Increase beta over training
# options["beta"] keeps the constructor value; set_beta() changes the beta used for new samples
print(f"\nBeta passed at construction (options['beta']): {prio_buf.options['beta']}")
print("Beta typically starts at 0.4 and anneals to 1.0 over training")
Updated priorities based on TD-errors
Transitions with higher TD-errors will be sampled more frequently
Beta passed at construction (options['beta']): 0.4
Beta typically starts at 0.4 and anneals to 1.0 over training
PrioritizedReplayBuffer Use Cases:
Rainbow DQN and variants
Any algorithm where some transitions are more “surprising” and valuable
Environments with rare but important events
Key Parameters:
alpha: Controls how much prioritization affects sampling (0=uniform, 1=fully proportional to priority)
beta: Importance sampling correction to remain unbiased (anneal from ~0.4 to 1.0)
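For reference, here are the standard prioritized experience replay formulas as a plain-Python sketch. The function names are hypothetical, and Tianshou's internal normalization may differ in detail. alpha shapes the sampling distribution, and beta controls how strongly the importance weights undo that bias. Equal priorities give weights of exactly 1, which matches the all-ones batch.weight printed above.

```python
def per_probabilities(priorities, alpha):
    """P(i) = p_i^alpha / sum_k p_k^alpha  (alpha=0 gives uniform sampling)."""
    scaled = [p ** alpha for p in priorities]
    total = sum(scaled)
    return [s / total for s in scaled]


def importance_weights(probs, indices, beta):
    """w_i = (N * P(i))^(-beta), normalized by the largest possible weight."""
    n = len(probs)
    max_w = (n * min(probs)) ** (-beta)  # the rarest transition gets the largest weight
    return [(n * probs[i]) ** (-beta) / max_w for i in indices]


# Equal priorities: uniform sampling and all weights equal to 1
uniform = per_probabilities([1.0, 1.0, 1.0, 1.0], alpha=0.6)
print(importance_weights(uniform, [0, 2, 3], beta=0.4))  # [1.0, 1.0, 1.0]

# Unequal priorities: the high-priority transition is sampled more but down-weighted
skewed = per_probabilities([1.0, 4.0], alpha=0.5)  # approx [0.333, 0.667]
print(importance_weights(skewed, [0, 1], beta=1.0))  # approx [1.0, 0.5]
```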
7.2 Other Specialized Buffers
CachedReplayBuffer: Maintains a main buffer plus one cache buffer per parallel environment
Use case: Parallel collection where each finished episode should be stored contiguously in the main buffer
Data stays in the environment's cache until the episode ends, then moves to the main buffer
The cache is then reset for that environment's next episode
HERReplayBuffer: Hindsight Experience Replay for goal-conditioned tasks
Use case: Sparse reward robotics tasks
Relabels failed episodes with achieved goals as if they were intended
Dramatically improves learning in goal-reaching tasks
See the HER documentation for detailed examples
For detailed usage of these specialized buffers, refer to the Tianshou API documentation and algorithm-specific tutorials.
8. Serialization and Persistence
Buffers support multiple serialization formats for saving and loading data.
8.1 Pickle Serialization
The simplest method, preserving all buffer state including trajectory metadata:
# Create and populate a buffer
save_buf = ReplayBuffer(size=50)
for i in range(30):
    save_buf.add(
        Batch(
            obs=np.array([i, i + 1]),
            act=i % 4,
            rew=float(i),
            terminated=(i + 1) % 10 == 0,
            truncated=False,
            obs_next=np.array([i + 1, i + 2]),
            info={"step": i},
        )
    )
print(f"Original buffer: {len(save_buf)} transitions")
# Serialize with pickle
pickled_data = pickle.dumps(save_buf)
print(f"Serialized size: {len(pickled_data)} bytes")
# Deserialize
loaded_buf = pickle.loads(pickled_data)
print(f"Loaded buffer: {len(loaded_buf)} transitions")
print(f"Data preserved: obs[0] = {loaded_buf.obs[0]}")
print(f"Metadata preserved: info[0] = {loaded_buf.info[0]}")
Original buffer: 30 transitions
Serialized size: 7073 bytes
Loaded buffer: 30 transitions
Data preserved: obs[0] = [0 1]
Metadata preserved: info[0] = Batch(
step: 0,
)
8.2 HDF5 Serialization
HDF5 is recommended for large datasets and cross-platform compatibility:
# Save to HDF5
with tempfile.NamedTemporaryFile(suffix=".hdf5", delete=False) as tmp:
    hdf5_path = tmp.name
save_buf.save_hdf5(hdf5_path, compression="gzip")
print(f"Saved to HDF5: {hdf5_path}")
# Load from HDF5
loaded_hdf5_buf = ReplayBuffer.load_hdf5(hdf5_path)
print(f"Loaded from HDF5: {len(loaded_hdf5_buf)} transitions")
print(f"Data matches: {np.array_equal(save_buf.obs, loaded_hdf5_buf.obs)}")
# Clean up
import os
os.unlink(hdf5_path)
Saved to HDF5: /tmp/tmpt3ia27p1.hdf5
Loaded from HDF5: 30 transitions
Data matches: True
When to Use HDF5:
Large datasets (> 1GB)
Offline RL with pre-collected data
Sharing data across platforms
Need for compression
Integration with external tools (many scientific tools read HDF5)
When to Use Pickle:
Quick saves during development
Small buffers
Python-only workflow
Simpler serialization needs
8.3 Loading from Raw Data with from_data()
For offline RL, you can create a buffer from raw arrays:
# Simulate pre-collected offline dataset
import h5py
# Create temporary HDF5 file with raw data
with tempfile.NamedTemporaryFile(suffix=".hdf5", delete=False) as tmp:
    offline_path = tmp.name
with h5py.File(offline_path, "w") as f:
    # Create datasets
    n = 100
    terminated = np.random.random(n) < 0.1
    truncated = np.zeros(n, dtype=bool)
    f.create_dataset("obs", data=np.random.randn(n, 4))
    f.create_dataset("act", data=np.random.randint(0, 2, n))
    f.create_dataset("rew", data=np.random.randn(n))
    f.create_dataset("terminated", data=terminated)
    f.create_dataset("truncated", data=truncated)
    # done must equal terminated OR truncated, so derive it instead of sampling it separately
    f.create_dataset("done", data=terminated | truncated)
    f.create_dataset("obs_next", data=np.random.randn(n, 4))
# Load into buffer
with h5py.File(offline_path, "r") as f:
    offline_buf = ReplayBuffer.from_data(
        obs=f["obs"],
        act=f["act"],
        rew=f["rew"],
        terminated=f["terminated"],
        truncated=f["truncated"],
        done=f["done"],
        obs_next=f["obs_next"],
    )
print(f"Loaded offline dataset: {len(offline_buf)} transitions")
print(f"Observation shape: {offline_buf.obs.shape}")
# Clean up
os.unlink(offline_path)
Loaded offline dataset: 100 transitions
Observation shape: (100, 4)
This is the standard approach for offline RL where you have pre-collected datasets from other sources.
9. Integration with the RL Pipeline
Understanding how buffers integrate with other Tianshou components is essential for effective usage.
9.1 Data Flow in RL Training
graph LR
ENV[Vectorized<br/>Environments] -->|observations| COL[Collector]
POL[Policy] -->|actions| COL
COL -->|transitions| BUF[Buffer]
BUF -->|sampled batches| POL
POL -->|forward pass| ALG[Algorithm]
ALG -->|loss & gradients| POL
style ENV fill:#e1f5ff
style COL fill:#fff4e1
style BUF fill:#ffe1f5
style POL fill:#e8f5e1
style ALG fill:#f5e1e1
9.2 Typical Training Loop Pattern
Here’s how buffers are typically used in a training loop:
# Pseudocode for typical RL training loop
# (This is illustrative; actual implementation would use Trainer)
def training_loop_pseudocode():
    """
    Illustrative training loop showing buffer integration.
    In practice, use Tianshou's Trainer class which handles this.
    """
    # Setup (illustration only)
    # env = make_vectorized_env(num_envs=8)
    # policy = make_policy()
    # buffer = VectorReplayBuffer(total_size=100000, buffer_num=8)
    # collector = Collector(policy, env, buffer)
    # Training loop
    # for epoch in range(num_epochs):
    #     # 1. Collect data from environments
    #     collect_result = collector.collect(n_step=1000)
    #     # Collector automatically adds transitions to buffer with correct buffer_ids
    #
    #     # 2. Train on multiple batches
    #     for _ in range(update_per_collect):
    #         # Sample batch from buffer
    #         batch, indices = buffer.sample(batch_size=256)
    #
    #         # Compute loss and update policy
    #         loss = policy.learn(batch)
    #
    #         # For prioritized buffers, update priorities
    #         if isinstance(buffer, PrioritizedReplayBuffer):
    #             buffer.update_weight(indices, td_errors)
    print("This pseudocode illustrates the buffer's role:")
    print("1. Collector fills buffer from environment interaction")
    print("2. Buffer provides random samples for training")
    print("3. Policy learns from sampled batches")
    print("\nIn practice, use Tianshou's Trainer for this workflow")


training_loop_pseudocode()
This pseudocode illustrates the buffer's role:
1. Collector fills buffer from environment interaction
2. Buffer provides random samples for training
3. Policy learns from sampled batches
In practice, use Tianshou's Trainer for this workflow
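The prioritized branch in the pseudocode above (update_weight) relies on the textbook prioritized experience replay formulas: a transition with priority p_i is sampled with probability P(i) = p_i^alpha / sum_k p_k^alpha, and its loss is scaled by an importance weight (N * P(i))^(-beta), divided by the largest possible weight. The following is a minimal NumPy sketch of that scheme, assuming new priorities are |TD error| + eps. The helpers per_probabilities, per_weights and update_priorities are hypothetical names for this illustration, not Tianshou internals.

```python
import numpy as np


def per_probabilities(priorities: np.ndarray, alpha: float) -> np.ndarray:
    """Sampling probability P(i) = p_i**alpha / sum_k p_k**alpha."""
    scaled = priorities**alpha
    return scaled / scaled.sum()


def per_weights(probs: np.ndarray, indices: np.ndarray, beta: float) -> np.ndarray:
    """Importance weights (N * P(i))**(-beta), divided by the largest possible weight."""
    n = len(probs)
    weights = (n * probs[indices]) ** (-beta)
    max_weight = (n * probs.min()) ** (-beta)  # least likely transition has the largest weight
    return weights / max_weight


def update_priorities(
    priorities: np.ndarray, indices: np.ndarray, td_errors: np.ndarray, eps: float = 1e-6
) -> np.ndarray:
    """Set new priorities to |TD error| + eps (eps keeps every transition sampleable)."""
    priorities[indices] = np.abs(td_errors) + eps
    return priorities


rng = np.random.default_rng(0)
priorities = np.ones(8)  # new transitions start with equal priority
alpha, beta = 0.6, 0.4

probs = per_probabilities(priorities, alpha)
indices = rng.choice(len(priorities), size=4, p=probs)  # drawn with replacement
weights = per_weights(probs, indices, beta)  # all 1.0 while priorities are equal

# Made-up TD errors standing in for the learner's output
td_errors = np.array([2.0, -0.5, 0.1, 1.0])
priorities = update_priorities(priorities, indices, td_errors)
print("sampled:", indices, "weights:", weights)
print("new probabilities:", np.round(per_probabilities(priorities, alpha), 3))
```

In Tianshou these steps sit behind PrioritizedReplayBuffer: the weights arrive in batch.weight and new priorities go in through update_weight, as the Quick Reference cheatsheet shows.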
9.3 Collector Integration#
The Collector class handles the complexity of:
Calling the policy to compute actions
Stepping the environments
Adding transitions to the buffer with the correct buffer_ids
Tracking episode statistics
When you create a Collector, you pass it a buffer; with vectorized environments this should be a VectorReplayBuffer with at least one subbuffer per environment. The Collector then automatically:
Sets buffer_ids based on which environments are ready
Handles episode resets and boundary tracking
See the Collector tutorial for detailed examples of this integration.
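To make "sets buffer_ids based on which environments are ready" concrete, here is a toy NumPy sketch of how a vectorized buffer can route each row of a batch into its own subbuffer. ToyVectorBuffer is a hypothetical class written only for this illustration: it mimics contiguous subbuffers with independent circular write pointers and is not Tianshou's VectorReplayBuffer.

```python
import numpy as np


class ToyVectorBuffer:
    """Hypothetical illustration: buffer_num contiguous subbuffers, each a circular queue."""

    def __init__(self, total_size: int, buffer_num: int):
        self.size_per_buf = total_size // buffer_num
        self.edges = np.arange(buffer_num + 1) * self.size_per_buf  # e.g. [0, 10, 20, ...]
        self.ptr = np.zeros(buffer_num, dtype=int)  # next write slot inside each subbuffer
        self.obs = np.zeros(total_size)

    def add(self, obs: np.ndarray, buffer_ids) -> np.ndarray:
        """Write row k of obs into subbuffer buffer_ids[k] (ids assumed unique)."""
        buffer_ids = np.asarray(buffer_ids)
        if len(buffer_ids) != len(obs):
            raise ValueError("buffer_ids length must match batch size")
        global_idx = self.edges[buffer_ids] + self.ptr[buffer_ids]
        self.obs[global_idx] = obs
        # Advance only the touched pointers, wrapping inside each subbuffer
        self.ptr[buffer_ids] = (self.ptr[buffer_ids] + 1) % self.size_per_buf
        return global_idx


buf = ToyVectorBuffer(total_size=40, buffer_num=4)
print(buf.add(np.array([1.0, 2.0, 3.0, 4.0]), buffer_ids=[0, 1, 2, 3]))  # all envs stepped
print(buf.add(np.array([5.0, 6.0]), buffer_ids=[1, 3]))  # only envs 1 and 3 were ready
```

The second call is the partial-data case: two rows land in subbuffers 1 and 3 (global indices 11 and 31), while subbuffers 0 and 2 keep their write positions.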
10. Advanced Topics and Edge Cases#
10.1 Buffer Overflow and Episode Boundaries#
What happens when the buffer fills up mid-episode?
# Small buffer to demonstrate overflow
overflow_buf = ReplayBuffer(size=8)
# Add a long episode (12 steps, buffer size is only 8)
print("Adding 12-step episode to buffer with size 8:")
for i in range(12):
    idx, ep_rew, ep_len, ep_start = overflow_buf.add(
        Batch(
            obs=i,
            act=0,
            rew=1.0,
            terminated=i == 11,
            truncated=False,
            obs_next=i + 1,
            info={},
        )
    )
    if i in [7, 11]:
        print(f" Step {i}: idx={idx}, buffer_len={len(overflow_buf)}")
print("\nFinal buffer contents (most recent 8 steps):")
print(f"Observations: {overflow_buf.obs[: len(overflow_buf)]}")
print(f"Episode return: {ep_rew[0]} (sum of all 12 steps, tracked correctly!)")
print("\nNote: Buffer overwrote old data but episode statistics are still correct")
Adding 12-step episode to buffer with size 8:
Step 7: idx=[7], buffer_len=8
Step 11: idx=[3], buffer_len=8
Final buffer contents (most recent 8 steps):
Observations: [ 8 9 10 11 4 5 6 7]
Episode return: 12.0 (sum of all 12 steps, tracked correctly!)
Note: Buffer overwrote old data but episode statistics are still correct
Important: Episode returns and lengths are tracked internally and remain correct even when an episode is longer than the buffer and its earliest steps have been overwritten. The buffer maintains _ep_return, _ep_len, and _ep_start_idx to track ongoing episodes.
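The reason the statistics survive overwriting can be shown with a minimal sketch: the running return and length live in scalars outside the storage array, so reusing old slots never touches them. ToyRingBuffer is a hypothetical class that only loosely mirrors the _ep_return/_ep_len bookkeeping described above; it is not Tianshou's implementation.

```python
import numpy as np


class ToyRingBuffer:
    """Hypothetical circular buffer that tracks the return and length of the ongoing episode."""

    def __init__(self, size: int):
        self.size = size
        self.rew = np.zeros(size)
        self.ptr = 0  # next write position
        self.ep_return = 0.0  # running statistics, kept outside the ring
        self.ep_len = 0

    def add(self, rew: float, done: bool):
        idx = self.ptr
        self.rew[idx] = rew  # may overwrite the oldest transition
        self.ptr = (self.ptr + 1) % self.size
        self.ep_return += rew
        self.ep_len += 1
        if done:
            result = (idx, self.ep_return, self.ep_len)
            self.ep_return, self.ep_len = 0.0, 0  # reset for the next episode
            return result
        return idx, 0.0, 0  # statistics are reported only when the episode ends


ring = ToyRingBuffer(size=8)
for t in range(12):
    idx, ep_ret, ep_len = ring.add(rew=1.0, done=(t == 11))
print(idx, ep_ret, ep_len)  # 3 12.0 12
print(ring.rew.sum())  # 8.0: only 8 rewards remain in storage
```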
10.2 Episode Spanning Subbuffer Edges#
In VectorReplayBuffer, episodes can wrap around within their subbuffer:
# Create small VectorReplayBuffer to demonstrate edge crossing
edge_buf = VectorReplayBuffer(total_size=20, buffer_num=2) # 10 per subbuffer
print(f"Subbuffer edges: {edge_buf.subbuffer_edges}")
print("Subbuffer 0: indices 0-9, Subbuffer 1: indices 10-19\n")
# Fill subbuffer 0 with 12 steps (wraps around since capacity is 10)
for i in range(12):
    batch = Batch(
        obs=np.array([[i]]),
        act=np.array([0]),
        rew=np.array([1.0]),
        terminated=np.array([i == 11]),
        truncated=np.array([False]),
        obs_next=np.array([[i + 1]]),
        info=np.array([{}], dtype=object),
    )
    idx, _, _, _ = edge_buf.add(batch, buffer_ids=[0])
    if i >= 10:
        print(f"Step {i} added at index {idx[0]} (wrapped around in subbuffer 0)")
# get_buffer_indices wraps around within the subbuffer (stop is exclusive)
episode_indices = edge_buf.get_buffer_indices(start=8, stop=2) # Crosses edge
print(f"\nEpisode spanning edge (from 8 to 1): {episode_indices}")
print("Correctly retrieves [8, 9, 0, 1] within subbuffer 0")
Subbuffer edges: [ 0 10 20]
Subbuffer 0: indices 0-9, Subbuffer 1: indices 10-19
Step 10 added at index 0 (wrapped around in subbuffer 0)
Step 11 added at index 1 (wrapped around in subbuffer 0)
Episode spanning edge (from 8 to 1): [8 9 0 1]
Correctly retrieves [8, 9, 0, 1] within subbuffer 0
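The index arithmetic behind an edge-crossing range fits in a few lines: locate the subbuffer that contains start, then walk forward modulo that subbuffer's size until reaching stop (exclusive). wrapped_range is a hypothetical helper written for this sketch, not the library's get_buffer_indices, and it assumes start and stop lie in the same subbuffer.

```python
import numpy as np


def wrapped_range(start: int, stop: int, edges: np.ndarray) -> np.ndarray:
    """Indices from start (inclusive) to stop (exclusive), wrapping inside start's subbuffer.

    Ambiguous cases (zero length vs. one full lap) return an empty array in this sketch.
    """
    sub = np.searchsorted(edges, start, side="right") - 1  # subbuffer containing start
    lo, hi = int(edges[sub]), int(edges[sub + 1])
    if not lo <= stop <= hi:
        raise ValueError("start and stop must lie in the same subbuffer")
    size = hi - lo
    length = (stop - start) % size  # wraps when stop <= start
    return lo + (start - lo + np.arange(length)) % size


edges = np.array([0, 10, 20])  # two subbuffers of size 10, as in the example above
print(wrapped_range(8, 2, edges))  # [8 9 0 1]
print(wrapped_range(18, 12, edges))  # [18 19 10 11]
```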
10.3 ignore_obs_next Memory Optimization#
For memory-constrained scenarios, you can avoid storing obs_next:
# Buffer that doesn't store obs_next
memory_buf = ReplayBuffer(size=10, ignore_obs_next=True)
# Add transitions (obs_next is ignored)
for i in range(5):
    memory_buf.add(
        Batch(
            obs=np.array([i, i + 1]),
            act=i,
            rew=1.0,
            terminated=False,
            truncated=False,
            obs_next=np.array([i + 1, i + 2]),  # Provided but not stored
            info={},
        )
    )
# When sampling, obs_next is reconstructed from the obs of the following step
sample, _ = memory_buf.sample(batch_size=1)
print(f"Sampled obs: {sample.obs}")
print(f"Sampled obs_next: {sample.obs_next}")
print("\nobs_next was reconstructed, not stored directly")
print("This saves memory at the cost of slightly more complex retrieval")
Sampled obs: [[3 4]]
Sampled obs_next: [[4 5]]
obs_next was reconstructed, not stored directly
This saves memory at the cost of slightly more complex retrieval
This is particularly useful for Atari environments with large observation spaces (84x84x4 frames).
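Conceptually, reconstruction replaces obs_next[i] with obs[j], where j is the step that follows i in the same episode. Below is a simplified NumPy sketch of that lookup for a buffer that has not wrapped around. The next_index rule (fall back to i itself when the episode ended at i or i is the newest step) is an assumption made for this illustration, not a description of Tianshou's code path.

```python
import numpy as np


def next_index(i: np.ndarray, done: np.ndarray, last: int) -> np.ndarray:
    """Index of the following step in the same episode, or i itself if there is none."""
    at_end = done[i] | (i == last)
    return np.where(at_end, i, i + 1)


# Five stored steps of one ongoing episode; obs_next is never stored
obs = np.array([[i, i + 1] for i in range(5)])
done = np.zeros(5, dtype=bool)
last = 4  # index of the most recently written step

idx = np.array([3])
obs_next = obs[next_index(idx, done, last)]
print(obs[idx], obs_next)  # [[3 4]] [[4 5]]
```

The saving comes from storing one observation array instead of two; the price is an extra index lookup on every sample.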
11. Surprising Behaviors and Gotchas#
11.1 Most Common Mistake: buffer_ids Confusion#
The buffer_ids parameter is the most common source of errors:
# COMMON ERROR 1: Forgetting buffer_ids with VectorReplayBuffer
vec_demo = VectorReplayBuffer(total_size=100, buffer_num=4)
parallel_data = Batch(
obs=np.random.randn(4, 2),
act=np.array([0, 1, 0, 1]),
rew=np.array([1.0, 2.0, 3.0, 4.0]),
terminated=np.array([False, False, False, False]),
truncated=np.array([False, False, False, False]),
obs_next=np.random.randn(4, 2),
info=np.array([{}, {}, {}, {}], dtype=object),
)
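# RISKY: Omitting buffer_ids. It defaults to all subbuffers ([0, 1, 2, 3] here),
# which only works when the batch has exactly one transition per subbuffer.
# With partial data (e.g. only some environments ready), this fails.
vec_demo.add(parallel_data)  # Works here because all 4 subbuffers receive data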
# CORRECT: Always explicit
vec_demo.add(parallel_data, buffer_ids=[0, 1, 2, 3])
print("Always specify buffer_ids explicitly for clarity")
Always specify buffer_ids explicitly for clarity
# COMMON ERROR 2: Shape mismatch with buffer_ids
try:
    # Trying to add 2 transitions but specifying 4 buffer_ids
    wrong_batch = Batch(
        obs=np.random.randn(2, 2),  # Only 2 transitions!
        act=np.array([0, 1]),
        rew=np.array([1.0, 2.0]),
        terminated=np.array([False, False]),
        truncated=np.array([False, False]),
        obs_next=np.random.randn(2, 2),
        info=np.array([{}, {}], dtype=object),
    )
    vec_demo.add(wrong_batch, buffer_ids=[0, 1, 2, 3])  # MISMATCH!
except (IndexError, ValueError) as e:
    print(f"Error caught: {type(e).__name__}")
    print("Lesson: buffer_ids length must match batch size")
Error caught: IndexError
Lesson: buffer_ids length must match batch size
11.2 Done Flag Confusion#
Never set the done flag manually; the buffer derives it from terminated and truncated:
# WRONG: Setting done manually
wrong_batch = Batch(
obs=1,
act=0,
rew=1.0,
terminated=True,
truncated=False,
# done=True, # DON'T DO THIS! It will be overwritten anyway
obs_next=2,
info={},
)
# CORRECT: Only set terminated and truncated
# done is automatically computed as (terminated OR truncated)
correct_batch = Batch(
obs=1,
act=0,
rew=1.0,
terminated=True, # Episode ended naturally
truncated=False, # Not cut off
obs_next=2,
info={},
)
demo = ReplayBuffer(size=10)
demo.add(correct_batch)
print(f"Terminated: {demo.terminated[0]}")
print(f"Truncated: {demo.truncated[0]}")
print(f"Done (auto-computed): {demo.done[0]}")
Terminated: True
Truncated: False
Done (auto-computed): True
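Keeping terminated and truncated separate, rather than storing only done, matters downstream. In the usual one-step TD target, only a genuine termination should stop bootstrapping; a truncated step (e.g. a time limit) still has a future whose value should be estimated. A minimal NumPy sketch, assuming a discount of 0.99 and made-up next-state values (all variable names are illustrative):

```python
import numpy as np

gamma = 0.99
rew = np.array([1.0, 1.0, 1.0])
terminated = np.array([False, True, False])
truncated = np.array([False, False, True])  # e.g. a time limit was hit
v_next = np.array([10.0, 10.0, 10.0])  # made-up value estimates of obs_next

done = np.logical_or(terminated, truncated)  # what the buffer stores as done

# Correct: mask the bootstrap term only on genuine termination
target = rew + gamma * (1.0 - terminated.astype(float)) * v_next
# Incorrect: masking on done also discards the truncated step's future value
target_with_done = rew + gamma * (1.0 - done.astype(float)) * v_next
print(target)
print(target_with_done)
```

For the truncated step, the first target keeps the bootstrap term (about 10.9), while the done-based target collapses to the bare reward of 1.0.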
11.3 Sampling from Empty or Near-Empty Buffers#
# Edge case: Sampling more than available
small_buf = ReplayBuffer(size=100)
for i in range(5):  # Only 5 transitions
    small_buf.add(
        Batch(obs=i, act=0, rew=1.0, terminated=False, truncated=False, obs_next=i + 1, info={})
    )
# Request 20 with only 5 stored: indices are drawn with replacement, so repeats are unavoidable
batch, indices = small_buf.sample(batch_size=20)
print(f"Requested 20, buffer has {len(small_buf)}, got {len(batch)}")
print(f"Indices: {indices}")
print("Notice: Some indices repeat (sampling with replacement)")
# Defensive pattern: Check buffer size
if len(small_buf) >= 128:
    batch, _ = small_buf.sample(128)
else:
    print(f"Buffer has {len(small_buf)} < 128, waiting for more data")
Requested 20, buffer has 5, got 20
Indices: [3 4 2 4 4 1 2 2 2 4 3 2 4 1 3 1 3 4 0 3]
Notice: Some indices repeat (sampling with replacement)
Buffer has 5 < 128, waiting for more data
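If indices are drawn uniformly with replacement, as the repeated indices above show, duplicates are not limited to nearly empty buffers. The chance that a batch of b indices drawn from N transitions contains at least one repeat is 1 - prod_{k=0}^{b-1} (1 - k/N), the birthday problem. The short sketch below evaluates it; duplicate_probability is a hypothetical helper and the sizes are arbitrary examples.

```python
import math


def duplicate_probability(batch_size: int, buffer_len: int) -> float:
    """P(at least one repeated index) when sampling uniformly with replacement."""
    if batch_size > buffer_len:
        return 1.0  # pigeonhole: repeats are guaranteed
    p_all_distinct = math.prod(1 - k / buffer_len for k in range(batch_size))
    return 1 - p_all_distinct


for n in (5, 10_000, 1_000_000):
    print(
        f"N={n:>9,}: batch 20 -> {duplicate_probability(20, n):.4f}, "
        f"batch 256 -> {duplicate_probability(256, n):.4f}"
    )
```

With 10,000 stored transitions, a 256-sample batch contains at least one duplicate about 96% of the time.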
11.4 Frame Stacking Valid Indices#
With stack_num > 1, not all indices are valid for sampling:
# With frame stacking, early indices can't form complete stacks
stack_demo = ReplayBuffer(size=20, stack_num=4, sample_avail=True)
for i in range(10):
    stack_demo.add(
        Batch(
            obs=np.array([i]),
            act=0,
            rew=1.0,
            terminated=i == 9,
            truncated=False,
            obs_next=np.array([i + 1]),
            info={},
        )
    )
# With sample_avail=True, only valid indices are sampled
sampled, indices = stack_demo.sample(batch_size=5)
print(f"Sampled indices with stack_num=4, sample_avail=True: {indices}")
print("All indices >= 3 (can form complete 4-frame stacks)")
# Without sample_avail, any index can be sampled (may have incomplete stacks)
stack_demo2 = ReplayBuffer(size=20, stack_num=4, sample_avail=False)
for i in range(10):
    stack_demo2.add(
        Batch(
            obs=np.array([i]),
            act=0,
            rew=1.0,
            terminated=False,
            truncated=False,
            obs_next=np.array([i + 1]),
            info={},
        )
    )
sampled2, indices2 = stack_demo2.sample(batch_size=5)
print(f"\nSampled indices with sample_avail=False: {indices2}")
print("May include indices < 3 (incomplete stacks repeated from boundary)")
Sampled indices with stack_num=4, sample_avail=True: [9 6 7 9 5]
All indices >= 3 (can form complete 4-frame stacks)
Sampled indices with sample_avail=False: [6 3 7 4 6]
May include indices < 3 (incomplete stacks repeated from boundary)
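The "repeated from boundary" behavior can be pictured with a small sketch: the stack for index i collects the stack_num most recent indices, clamped at the episode start so that the first frame is repeated. stack_indices and is_complete are hypothetical helpers for a non-wrapped, single-episode buffer; they illustrate the idea (including the ">= 3" rule above for stack_num=4), not Tianshou's exact code path.

```python
import numpy as np


def stack_indices(i: int, ep_start: int, stack_num: int) -> np.ndarray:
    """Indices of the frames stacked at step i, oldest first, clamped at the episode start."""
    offsets = np.arange(stack_num - 1, -1, -1)  # e.g. [3, 2, 1, 0]
    return np.maximum(i - offsets, ep_start)


def is_complete(i: int, ep_start: int, stack_num: int) -> bool:
    """True if the stack needs no padding, i.e. the kind of index sample_avail=True keeps."""
    return i - ep_start >= stack_num - 1


obs = np.arange(10)  # one episode occupying indices 0..9
print(obs[stack_indices(1, ep_start=0, stack_num=4)])  # [0 0 0 1]: first frame repeated
print(obs[stack_indices(6, ep_start=0, stack_num=4)])  # [3 4 5 6]: complete stack
print([i for i in range(10) if is_complete(i, 0, 4)])  # [3, 4, 5, 6, 7, 8, 9]
```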
12. Best Practices#
12.1 Choosing the Right Buffer#
Decision Tree:
1. Are you using parallel environments?
Yes → Use VectorReplayBuffer
No → Continue to 2
2. Do you need prioritized experience replay?
Yes → Use PrioritizedReplayBuffer or PrioritizedVectorReplayBuffer
No → Continue to 3
3. Is it goal-conditioned RL with sparse rewards?
Yes → Use HERReplayBuffer or HERVectorReplayBuffer
No → Continue to 4
4. Do you need separate expert and agent buffers?
Yes → Use CachedReplayBuffer
No → Use ReplayBuffer (single env) or VectorReplayBuffer (standard choice)
Most Common Setup: VectorReplayBuffer for production training
12.2 Buffer Sizing Guidelines#
Rule of Thumb by Domain:
Atari games: 1,000,000 transitions (1e6)
Continuous control (MuJoCo): 100,000-1,000,000 (1e5-1e6)
Robotics: 100,000-500,000 (1e5-5e5)
Simple environments (CartPole): 10,000-50,000 (1e4-5e4)
Factors to Consider:
Available RAM (each transition stores roughly 2 x the observation size, for obs and obs_next, plus metadata)
Training time vs sample efficiency tradeoff
Algorithm requirements (some need larger buffers)
Memory Estimation:
# For environments with observation shape (84, 84, 4) (Atari), stored as uint8 (1 byte per value):
# Each transition: 2 * 84 * 84 * 4 bytes (obs + obs_next) + ~100 bytes overhead
# = ~56KB per transition
# 1M transitions = ~56GB (use ignore_obs_next to halve this!)
# float32 observations would need 4x as much
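The same arithmetic as a small Python helper lets you plug in your own shapes. It assumes uint8 frames and the rough 100-byte per-transition overhead used in the estimate above; replay_memory_gb and its parameters are illustrative names. The last_frame_only option roughly approximates what save_only_last_obs saves when observations already carry a frame stack.

```python
import math


def replay_memory_gb(
    frame_shape: tuple,
    stack_depth: int,
    num_transitions: int,
    store_obs_next: bool = True,
    last_frame_only: bool = False,
    bytes_per_value: int = 1,  # uint8 frames
    overhead_bytes: int = 100,  # rough allowance for act, rew, flags, info
) -> float:
    """Approximate replay buffer size in GB (1 GB = 1e9 bytes)."""
    frames_per_obs = 1 if last_frame_only else stack_depth
    per_obs = math.prod(frame_shape) * frames_per_obs * bytes_per_value
    per_transition = per_obs * (2 if store_obs_next else 1) + overhead_bytes
    return per_transition * num_transitions / 1e9


N = 1_000_000
print(f"obs + obs_next:           {replay_memory_gb((84, 84), 4, N):.1f} GB")  # ~56.5
print(f"ignore_obs_next:          {replay_memory_gb((84, 84), 4, N, store_obs_next=False):.1f} GB")  # ~28.3
print(
    "+ last frame only:        "
    f"{replay_memory_gb((84, 84), 4, N, store_obs_next=False, last_frame_only=True):.1f} GB"
)  # ~7.2
```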
12.3 Configuration Best Practices#
When to use stack_num > 1:
RNN/LSTM policies need temporal context
Frame-based policies (Atari with 4-frame stacking)
Velocity estimation from positions
When to use ignore_obs_next=True:
Memory-constrained environments
Atari (large observation spaces)
When obs_next is simply the obs of the following step in the same episode
When to use save_only_last_obs=True:
Atari with temporal stacking in environment wrapper
When observations already contain frame history
When to use sample_avail=True:
Always use with stack_num > 1 for correctness
Ensures samples have complete frame stacks
Small performance cost but worth it for data quality
12.4 Integration Patterns#
Pattern 1: Standard Off-Policy Setup
# env = make_vectorized_env(num_envs=8)
# buffer = VectorReplayBuffer(total_size=100000, buffer_num=8)
# policy = SACPolicy(...)
# collector = Collector(policy, env, buffer)
#
# # Collect and train
# collector.collect(n_step=1000)
# for _ in range(10):
#     batch, indices = buffer.sample(256)
#     policy.learn(batch)
Pattern 2: Pre-fill Buffer Before Training
# # Collect random exploration data
# collector.collect(n_step=10000, random=True)  # Fill buffer with random actions
#
# # Then start training
# while not converged:
#     collector.collect(n_step=100)
#     for _ in range(10):
#         batch, _ = buffer.sample(256)  # sample returns (batch, indices)
#         policy.learn(batch)
Pattern 3: Offline RL
# # Load pre-collected dataset
# buffer = ReplayBuffer.load_hdf5("expert_data.hdf5")
#
# # Train without further collection
# for epoch in range(num_epochs):
#     for _ in range(updates_per_epoch):
#         batch, _ = buffer.sample(256)  # sample returns (batch, indices)
#         policy.learn(batch)
12.5 Performance Tips#
Tip 1: Pre-allocate buffer size appropriately
Don’t make buffer too large (wastes memory)
Don’t make it too small (loses important old experiences)
Start with domain defaults and adjust based on performance
Tip 2: Use HDF5 for large offline datasets
Compression saves disk space
Faster loading than pickle for large files
Better for sharing across systems
Tip 3: Sample batches efficiently
Sample once and use multiple times if possible
Don’t sample more than you need
For multi-GPU training, sample once and split
Tip 4: Monitor buffer usage
# print(f"Buffer usage: {len(buffer)}/{buffer.maxsize}")
# if len(buffer) < batch_size:
#     print("Warning: batch_size exceeds stored transitions; expect many duplicate samples")
Tip 5: Consider ignore_obs_next for large observation spaces
Can halve memory usage
Small computational overhead on sampling
Especially valuable for image-based RL
13. Quick Reference#
Method Summary#
| Method | Purpose | Returns | Notes |
|---|---|---|---|
| add() | Add transition(s) | (idx, ep_rew, ep_len, ep_start) | ep_rew/ep_len only non-zero when done=True |
| sample() | Random sample | (batch, indices) | batch_size=None for all (random order), 0 for all (ordered) |
| prev() | Previous in episode | | Stops at episode boundaries |
| next() | Next in episode | | Stops at episode boundaries |
| get() | Get with stacking | | Returns stacked frames if stack_num > 1 |
| get_buffer_indices() | Episode range | | Handles edge-crossing episodes |
| | Ongoing episodes | | Returns last step of unfinished episodes |
| | Save to HDF5 | - | Recommended for large datasets |
| load_hdf5() | Load from HDF5 | buffer | Class method |
| from_data() | Create from arrays | buffer | For offline RL datasets |
| | Clear buffer | - | Optionally keep episode statistics |
| | Get indices only | | For custom sampling logic |
Common Patterns Cheatsheet#
Single Environment:
buffer = ReplayBuffer(size=10000)
buffer.add(Batch(obs=..., act=..., rew=..., terminated=..., truncated=..., obs_next=..., info={}))
batch, indices = buffer.sample(batch_size=256)
Parallel Environments:
buffer = VectorReplayBuffer(total_size=100000, buffer_num=8)
buffer.add(parallel_batch, buffer_ids=[0,1,2,3,4,5,6,7])
batch, indices = buffer.sample(batch_size=256)
Frame Stacking:
buffer = ReplayBuffer(size=100000, stack_num=4, sample_avail=True)
stacked_obs = buffer.get(index=50, key="obs") # Returns 4 stacked frames
Prioritized Replay:
buffer = PrioritizedReplayBuffer(size=100000, alpha=0.6, beta=0.4)
batch, indices = buffer.sample(batch_size=256)
weights = batch.weight # Importance weights are inside the batch
# ... compute TD errors ...
buffer.update_weight(indices, td_errors)
Offline RL:
buffer = ReplayBuffer.load_hdf5("dataset.hdf5")
# Or:
import h5py
with h5py.File("dataset.hdf5", "r") as f:
    buffer = ReplayBuffer.from_data(obs=f["obs"], act=f["act"], ...)
Episode Retrieval:
# Find episode boundaries, then:
episode_indices = buffer.get_buffer_indices(start=ep_start_idx, stop=ep_end_idx+1)
episode = buffer[episode_indices]
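The "find episode boundaries" step is left open above. For a buffer that has not wrapped around yet, one simple approach is to read the done flags in storage order: each done marks an episode end, and the next index starts a new episode. The sketch below works on a plain NumPy array standing in for the buffer's done flags; episode_ranges is a hypothetical helper, and a buffer that has already overwritten old data needs wrap-aware logic instead.

```python
import numpy as np


def episode_ranges(done: np.ndarray) -> list:
    """(start, end_inclusive) for each finished episode, assuming no wrap-around."""
    ends = np.flatnonzero(done)
    starts = np.concatenate(([0], ends[:-1] + 1))
    return [(int(s), int(e)) for s, e in zip(starts, ends)]


done = np.array([False, False, True, False, True, False])  # last episode still running
for ep_start_idx, ep_end_idx in episode_ranges(done):
    print(ep_start_idx, ep_end_idx)
```

Each pair can then be passed to get_buffer_indices as start=ep_start_idx, stop=ep_end_idx + 1; the unfinished episode at index 5 is skipped.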
Summary and Next Steps#
This tutorial covered Tianshou’s buffer system comprehensively:
Buffer fundamentals: Why buffers are essential for RL
Buffer hierarchy: Understanding different buffer types
Basic operations: Construction, configuration, and data management
Trajectory management: Episode tracking and boundary navigation
Sampling strategies: Basic sampling and frame stacking
VectorReplayBuffer: Critical for parallel environments
Specialized buffers: Prioritized, cached, and HER variants
Serialization: Pickle and HDF5 persistence
Integration: How buffers fit in the RL pipeline
Advanced topics: Edge cases and overflow handling
Gotchas: Common mistakes and how to avoid them
Best practices: Configuration, sizing, and performance
Quick reference: Method summary and common patterns
Next Steps#
Collector Deep Dive: Learn how Collector fills buffers from environments
Policy Tutorial: Understand how policies sample from buffers for training
Algorithm Examples: See buffer usage in specific algorithms (DQN, SAC, PPO)
API Reference: Full details at Buffer API documentation
Further Resources#
Tianshou GitHub for source code and examples
Gymnasium Documentation for environment conventions
Research papers on experience replay and prioritized sampling