Buffer: Experience Replay in Tianshou#

The replay buffer is a fundamental component in reinforcement learning, particularly for off-policy algorithms. Tianshou’s buffer goes beyond plain data storage: it tracks trajectories, supports efficient sampling, and connects directly to the Collector and the rest of the training pipeline.

This tutorial covers Tianshou’s buffer system, from basic concepts to advanced features and integration patterns.

import pickle
import tempfile

import numpy as np

from tianshou.data import Batch, PrioritizedReplayBuffer, ReplayBuffer, VectorReplayBuffer

1. Introduction: Why Buffers in Reinforcement Learning?#

The Role of Experience Replay#

Experience replay is a critical technique in modern reinforcement learning that addresses three fundamental challenges:

  1. Breaking Temporal Correlation: Sequential experiences from an agent are highly correlated. Training directly on these sequences can lead to unstable learning. By storing experiences and sampling randomly, we break these correlations.

  2. Sample Efficiency: In RL, collecting data through environment interaction is often expensive. Experience replay allows us to reuse each experience multiple times for training, dramatically improving sample efficiency.

  3. Mini-batch Training: Neural networks are usually trained with mini-batch gradient descent. Buffers make it cheap to assemble batches of experiences for each update.
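
The three points above can be illustrated with a deliberately minimal sketch that uses only the Python standard library. The class name `TinyReplay` and its methods are hypothetical and are not part of Tianshou; the sketch only shows bounded storage plus uniform random mini-batches, and omits everything Tianshou adds on top.

```python
import random
from collections import deque


class TinyReplay:
    """Hypothetical minimal replay memory (not Tianshou's implementation)."""

    def __init__(self, capacity: int, seed: int = 0) -> None:
        # deque(maxlen=...) silently drops the oldest item once it is full
        self.storage = deque(maxlen=capacity)
        self.rng = random.Random(seed)

    def add(self, transition: tuple) -> None:
        self.storage.append(transition)

    def sample(self, batch_size: int) -> list:
        # Uniform draws with replacement decorrelate consecutive steps
        return [self.rng.choice(self.storage) for _ in range(batch_size)]


memory = TinyReplay(capacity=5)
for t in range(8):
    # (obs, act, rew, obs_next) with placeholder values
    memory.add((t, 0, 1.0, t + 1))

minibatch = memory.sample(batch_size=4)
print(len(memory.storage), [tr[0] for tr in minibatch])
```

Each stored transition can be drawn many times, which is where the sample-efficiency gain comes from. What this sketch cannot do (episode boundaries, returns, frame stacking) is the subject of the rest of the tutorial.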

Why Not Alternatives?#

Plain Python Lists

  • No efficient random sampling

  • No automatic circular queue behavior

  • No trajectory boundary tracking

  • Poor memory management for large datasets

Simple Batch Storage

  • No automatic overwriting when full

  • No episode metadata (returns, lengths)

  • No methods for boundary navigation (prev/next)

  • No specialized sampling strategies

Buffer = Batch + Trajectory Management + Sampling#

Tianshou’s buffers build on the Batch class to provide:

  • Circular queue storage: Automatic overwriting of oldest data

  • Trajectory tracking: Episode boundaries, returns, and lengths

  • Efficient sampling: Random access with various strategies

  • Integration utilities: Seamless connection to Collector and Policy

Use Cases#

  • Off-policy algorithms: DQN, SAC, TD3, DDPG require experience replay

  • On-policy with replay: Some PPO implementations reuse buffer data

  • Offline RL: Loading and using pre-collected datasets

  • Multi-environment training: VectorReplayBuffer for parallel collection

2. Buffer Types and Hierarchy#

Tianshou provides several buffer implementations, each designed for specific use cases. Understanding this hierarchy is crucial for choosing the right buffer.

Buffer Hierarchy#

graph TD
    RB[ReplayBuffer<br/>Single environment<br/>Circular queue] --> RBM[ReplayBufferManager<br/>Manages multiple buffers<br/>Contiguous memory]
    RBM --> VRB[VectorReplayBuffer<br/>Parallel environments<br/>Maintains temporal order]
    
    RB --> PRB[PrioritizedReplayBuffer<br/>TD-error based sampling<br/>Importance weights]
    PRB --> PVRB[PrioritizedVectorReplayBuffer<br/>Prioritized + Parallel]
    
    RB --> CRB[CachedReplayBuffer<br/>Main buffer + per-env caches<br/>Complete episodes]
    
    RB --> HERB[HERReplayBuffer<br/>Hindsight Experience Replay<br/>Goal-conditioned RL]
    HERB --> HVRB[HERVectorReplayBuffer<br/>HER + Parallel]
    
    style RB fill:#e1f5ff
    style RBM fill:#fff4e1
    style VRB fill:#ffe1f5
    style PRB fill:#e8f5e1
    style CRB fill:#f5e1e1
    style HERB fill:#e1e1f5
    

When to Use Which Buffer#

ReplayBuffer: Single environment scenarios

  • Simple setup and testing

  • Debugging algorithms

  • Low-parallelism training

VectorReplayBuffer: Multiple parallel environments (most common)

  • Standard production use case

  • Efficient parallel data collection

  • Maintains per-environment episode boundaries

PrioritizedReplayBuffer: DQN variants with prioritization

  • Rainbow DQN

  • Algorithms requiring importance sampling

  • When some transitions are more valuable than others

CachedReplayBuffer: A main buffer plus one cache buffer per environment

  • Transitions are written to the per-environment caches first

  • A finished episode is moved from its cache into the main buffer

  • Useful when the main buffer should hold only complete episodes

HERReplayBuffer: Goal-conditioned reinforcement learning

  • Sparse reward environments

  • Robotics tasks with explicit goals

  • Relabeling failed experiences with achieved goals

3. Basic Operations#

3.1 Construction and Configuration#

The ReplayBuffer constructor accepts several important parameters that control its behavior:

# Create a buffer with all configuration options
buf = ReplayBuffer(
    size=20,  # Maximum capacity (transitions)
    stack_num=1,  # Frame stacking for RNNs (default: 1, no stacking)
    ignore_obs_next=False,  # Save memory by not storing obs_next
    save_only_last_obs=False,  # For temporal stacking (Atari-style)
    sample_avail=False,  # Sample only valid indices for frame stacking
    random_seed=42,  # Reproducible sampling
)

print(f"Buffer created: {buf}")
print(f"Max size: {buf.maxsize}")
print(f"Current length: {len(buf)}")
Buffer created: ReplayBuffer()
Max size: 20
Current length: 0

Parameter Explanations:

  • size: Maximum number of transitions the buffer can hold. When full, oldest data is overwritten.

  • stack_num: Number of consecutive frames to stack. Used for RNN inputs or frame-based policies (Atari).

  • ignore_obs_next: If True, obs_next is not stored, saving memory. The buffer reconstructs it from the next obs when needed.

  • save_only_last_obs: For observations that are already stacks of frames (Atari-style). Only the last frame of each observation is stored, and stack_num rebuilds the stack on retrieval.

  • sample_avail: When True with stack_num > 1, only samples indices where a complete stack is available.

  • random_seed: Seeds the random number generator for reproducible sampling.

3.2 Reserved Keys and the Done Flag System#

ReplayBuffer uses nine reserved keys that integrate with Gymnasium conventions. Understanding the done flag system is critical.

# The nine reserved keys
print("Reserved keys:")
print(ReplayBuffer._reserved_keys)
print("\nKeys required for add():")
print(ReplayBuffer._required_keys_for_add)
Reserved keys:
('obs', 'act', 'rew', 'terminated', 'truncated', 'done', 'obs_next', 'info', 'policy')

Keys required for add():
{'act', 'truncated', 'rew', 'done', 'obs', 'terminated'}

Important: Understanding done, terminated, and truncated

Gymnasium (the successor to OpenAI Gym) introduced a crucial distinction:

  • terminated: Episode ended naturally (agent reached goal or failed)

    • Examples: CartPole fell over, agent reached goal state

    • This is the flag that should cut off bootstrapping: after a true terminal state there is no future value

  • truncated: Episode was cut off artificially (time limit, external interruption)

    • Examples: Maximum episode length reached, environment reset externally

    • Should NOT cut off bootstrapping (the episode could have continued, so the next state still has value)

  • done: Computed automatically as terminated OR truncated

    • Used internally for episode boundary tracking

    • You should NEVER manually set this field
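
To make the distinction concrete, here is a small sketch of a one-step TD target in plain Python. The helper names `td_target` and `is_done` are hypothetical, and the masking rule shown is the common convention rather than a quote of Tianshou's internals.

```python
def td_target(rew: float, next_value: float, terminated: bool, gamma: float = 0.99) -> float:
    """One-step TD target: r + gamma * V(s'), with the bootstrap term masked on termination."""
    # Only a true terminal state removes the future value; truncation does not
    bootstrap = 0.0 if terminated else gamma * next_value
    return rew + bootstrap


def is_done(terminated: bool, truncated: bool) -> bool:
    """Episode boundary marker, as described above: terminated OR truncated."""
    return terminated or truncated


# Same reward and next-state value, two different ways for the episode to end
goal_reached = td_target(rew=1.0, next_value=10.0, terminated=True, gamma=0.5)
time_limit = td_target(rew=1.0, next_value=10.0, terminated=False, gamma=0.5)

print(goal_reached, is_done(True, False))  # 1.0 True
print(time_limit, is_done(False, True))    # 6.0 True
```

Both transitions mark an episode boundary (`done` is true), but only the terminated one drops the value of the next state from the target.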

Best Practice: Always use the info dictionary for custom metadata rather than adding top-level keys:

# GOOD: Custom metadata in info dictionary
good_batch = Batch(
    obs=np.array([1.0, 2.0]),
    act=0,
    rew=1.0,
    terminated=False,
    truncated=False,
    obs_next=np.array([1.5, 2.5]),
    info={"custom_metric": 0.95, "step_count": 10},  # Custom data here
)

# BAD: Don't add custom top-level keys (may conflict with future buffer features)
# bad_batch = Batch(..., custom_metric=0.95)  # Don't do this!

print("Good batch structure:")
print(good_batch)
Good batch structure:
Batch(
    obs: array([1., 2.]),
    act: array(0),
    rew: array(1.),
    terminated: array(False),
    truncated: array(False),
    obs_next: array([1.5, 2.5]),
    info: Batch(
              custom_metric: array(0.95),
              step_count: array(10),
          ),
)

3.3 Circular Queue Storage#

The buffer implements a circular queue: when it reaches maximum capacity, new data overwrites the oldest entries.

# Create a small buffer to demonstrate circular behavior
demo_buf = ReplayBuffer(size=5)

print("Adding 3 transitions:")
for i in range(3):
    demo_buf.add(
        Batch(
            obs=i,
            act=i,
            rew=float(i),
            terminated=False,
            truncated=False,
            obs_next=i + 1,
            info={},
        )
    )
print(f"Length: {len(demo_buf)}, Max: {demo_buf.maxsize}")
print(f"Observations: {demo_buf.obs[: len(demo_buf)]}")

print("\nAdding 5 more transitions (total 8, exceeds capacity 5):")
for i in range(3, 8):
    demo_buf.add(
        Batch(
            obs=i,
            act=i,
            rew=float(i),
            terminated=False,
            truncated=False,
            obs_next=i + 1,
            info={},
        )
    )
print(f"Length: {len(demo_buf)}, Max: {demo_buf.maxsize}")
print(f"Observations: {demo_buf.obs[: len(demo_buf)]}")
print("\nNotice: First 3 transitions (0,1,2) were overwritten by (5,6,7)")
print("Buffer now contains [3, 4, 5, 6, 7], stored in slot order [5, 6, 7, 3, 4]")
Adding 3 transitions:
Length: 3, Max: 5
Observations: [0 1 2]

Adding 5 more transitions (total 8, exceeds capacity 5):
Length: 5, Max: 5
Observations: [5 6 7 3 4]

Notice: First 3 transitions (0,1,2) were overwritten by (5,6,7)
Buffer now contains [3, 4, 5, 6, 7], stored in slot order [5, 6, 7, 3, 4]
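
The wrap-around seen in the output can be reproduced with a write pointer and modular arithmetic. This is a plain-Python sketch of the general circular-queue idea; the variable names are illustrative and do not correspond to Tianshou attributes.

```python
capacity = 5
slots = [None] * capacity  # fixed-size storage
write_pos = 0              # next slot to write
count = 0                  # number of valid entries

for obs in range(8):
    slots[write_pos] = obs
    write_pos = (write_pos + 1) % capacity  # wrap around at the end
    count = min(count + 1, capacity)

# Once full, the oldest entry sits at write_pos; walk forward from there
start = write_pos if count == capacity else 0
in_order = [slots[(start + k) % capacity] for k in range(count)]

print(slots)     # [5, 6, 7, 3, 4]
print(in_order)  # [3, 4, 5, 6, 7]
```

The slot order matches the `Observations` line printed by the buffer, while `in_order` shows the same data from oldest to newest.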

3.4 Batch-Compatible Operations#

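ReplayBuffer stores its data in a Batch, so it supports Batch-style indexing and slicing. Indices address storage slots of the circular array, not insertion order: below, demo_buf[-1] returns the transition in the last slot (obs=4), not the most recently added one (obs=7).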

# Indexing and slicing
print("Last transition:")
print(demo_buf[-1])

print("\nLast 3 transitions:")
print(demo_buf[-3:])

print("\nSpecific indices [0, 2, 4]:")
print(demo_buf[np.array([0, 2, 4])])
Last transition:
Batch(
    obs: array(4),
    act: array(4),
    rew: array(4.),
    terminated: array(False),
    truncated: array(False),
    done: array(False),
    obs_next: array(5),
    info: Batch(),
    policy: Batch(),
)

Last 3 transitions:
Batch(
    obs: array([7, 3, 4]),
    act: array([7, 3, 4]),
    rew: array([7., 3., 4.]),
    terminated: array([False, False, False]),
    truncated: array([False, False, False]),
    done: array([False, False, False]),
    obs_next: array([8, 4, 5]),
    info: Batch(),
    policy: Batch(),
)

Specific indices [0, 2, 4]:
Batch(
    obs: array([5, 7, 4]),
    act: array([5, 7, 4]),
    rew: array([5., 7., 4.]),
    terminated: array([False, False, False]),
    truncated: array([False, False, False]),
    done: array([False, False, False]),
    obs_next: array([6, 8, 5]),
    info: Batch(),
    policy: Batch(),
)

4. Trajectory Management#

A key distinguishing feature of ReplayBuffer is its automatic tracking of episode boundaries and metadata.

4.1 Episode Tracking and Metadata#

The add() method returns four values that provide episode information:

# Create a fresh buffer for trajectory demonstration
traj_buf = ReplayBuffer(size=20)

print("Episode 1: 4 steps, terminates naturally")
for i in range(4):
    idx, ep_rew, ep_len, ep_start = traj_buf.add(
        Batch(
            obs=i,
            act=i,
            rew=float(i + 1),  # Rewards: 1, 2, 3, 4
            terminated=i == 3,  # Last step terminates
            truncated=False,
            obs_next=i + 1,
            info={},
        )
    )
    print(f"  Step {i}: idx={idx}, ep_rew={ep_rew}, ep_len={ep_len}, ep_start={ep_start}")

print("\nNotice: Episode return (10.0) and length (4) only appear at the end!")
Episode 1: 4 steps, terminates naturally
  Step 0: idx=[0], ep_rew=[0.], ep_len=[0], ep_start=[0]
  Step 1: idx=[1], ep_rew=[0.], ep_len=[0], ep_start=[0]
  Step 2: idx=[2], ep_rew=[0.], ep_len=[0], ep_start=[0]
  Step 3: idx=[3], ep_rew=[10.], ep_len=[4], ep_start=[0]

Notice: Episode return (10.0) and length (4) only appear at the end!

Return Values Explained:

  1. idx: Index where the transition was inserted (np.ndarray of shape (1,))

  2. ep_rew: Episode return, only non-zero when done=True (np.ndarray of shape (1,))

  3. ep_len: Episode length, only non-zero when done=True (np.ndarray of shape (1,))

  4. ep_start: Index where the episode started (np.ndarray of shape (1,))

This automatic computation eliminates manual episode tracking during data collection.
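
The bookkeeping behind these four values can be sketched with a few running counters. `EpisodeTracker` is a hypothetical stand-in written for this tutorial; it ignores wrap-around and is not how Tianshou stores its state.

```python
class EpisodeTracker:
    """Hypothetical accumulator mirroring the four return values described above."""

    def __init__(self) -> None:
        self.index = 0     # next insertion index
        self.ep_rew = 0.0  # running return of the current episode
        self.ep_len = 0    # running length of the current episode
        self.ep_start = 0  # index where the current episode began

    def add(self, rew: float, terminated: bool, truncated: bool) -> tuple:
        idx = self.index
        self.index += 1
        self.ep_rew += rew
        self.ep_len += 1
        if terminated or truncated:
            # Report the totals once, then reset for the next episode
            result = (idx, self.ep_rew, self.ep_len, self.ep_start)
            self.ep_rew, self.ep_len, self.ep_start = 0.0, 0, self.index
            return result
        # Mid-episode: return and length are reported as zero
        return (idx, 0.0, 0, self.ep_start)


tracker = EpisodeTracker()
results = [tracker.add(float(i + 1), terminated=(i == 3), truncated=False) for i in range(4)]
print(results[0])   # (0, 0.0, 0, 0)
print(results[-1])  # (3, 10.0, 4, 0)
```

The final tuple matches the last line of the episode 1 output: return 10.0, length 4, start index 0.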

# Continue with Episode 2: 5 steps
print("Episode 2: 5 steps, truncated (time limit)")
for i in range(4, 9):
    idx, ep_rew, ep_len, ep_start = traj_buf.add(
        Batch(
            obs=i,
            act=i,
            rew=float(i + 1),
            terminated=False,
            truncated=i == 8,  # Last step truncated
            obs_next=i + 1,
            info={},
        )
    )
    if i == 8:
        print(
            f"  Final step: idx={idx}, ep_rew={ep_rew[0]:.1f}, ep_len={ep_len[0]}, ep_start={ep_start}"
        )

# Episode 3: Ongoing (not finished)
print("\nEpisode 3: 3 steps, ongoing (not done)")
for i in range(9, 12):
    idx, ep_rew, ep_len, ep_start = traj_buf.add(
        Batch(
            obs=i,
            act=i,
            rew=float(i + 1),
            terminated=False,
            truncated=False,  # Episode continues
            obs_next=i + 1,
            info={},
        )
    )
    if i == 11:
        print(
            f"  Latest step: idx={idx}, ep_rew={ep_rew}, ep_len={ep_len} (zeros because not done)"
        )

print(f"\nBuffer state: {len(traj_buf)} transitions across 2 complete + 1 ongoing episode")
Episode 2: 5 steps, truncated (time limit)
  Final step: idx=[8], ep_rew=35.0, ep_len=5, ep_start=[4]

Episode 3: 3 steps, ongoing (not done)
  Latest step: idx=[11], ep_rew=[0.], ep_len=[0] (zeros because not done)

Buffer state: 12 transitions across 2 complete + 1 ongoing episode

4.2 Boundary Navigation: prev() and next()#

The buffer provides methods to navigate within episodes while respecting episode boundaries:

# Examine the buffer structure
print("Buffer contents:")
print(f"Indices:    {np.arange(len(traj_buf))}")
print(f"Obs:        {traj_buf.obs[: len(traj_buf)]}")
print(f"Terminated: {traj_buf.terminated[: len(traj_buf)]}")
print(f"Truncated:  {traj_buf.truncated[: len(traj_buf)]}")
print(f"Done:       {traj_buf.done[: len(traj_buf)]}")
print("\nEpisode boundaries: indices 3 (terminated) and 8 (truncated)")
Buffer contents:
Indices:    [ 0  1  2  3  4  5  6  7  8  9 10 11]
Obs:        [ 0  1  2  3  4  5  6  7  8  9 10 11]
Terminated: [False False False  True False False False False False False False False]
Truncated:  [False False False False False False False False  True False False False]
Done:       [False False False  True False False False False  True False False False]

Episode boundaries: indices 3 (terminated) and 8 (truncated)
# prev() returns the previous index within the same episode
# It STOPS at episode boundaries
test_indices = np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11])
prev_indices = traj_buf.prev(test_indices)

print("prev() behavior:")
print(f"Index:     {test_indices}")
print(f"Prev:      {prev_indices}")
print("\nObservations:")
print("- Index 0 stays at 0 (start of episode 1)")
print("- Index 4 stays at 4 (start of episode 2, can't go back to episode 1)")
print("- Index 9 stays at 9 (start of episode 3, can't go back to episode 2)")
prev() behavior:
Index:     [ 0  1  2  3  4  5  6  7  8  9 10 11]
Prev:      [ 0  0  1  2  4  4  5  6  7  9  9 10]

Observations:
- Index 0 stays at 0 (start of episode 1)
- Index 4 stays at 4 (start of episode 2, can't go back to episode 1)
- Index 9 stays at 9 (start of episode 3, can't go back to episode 2)
# next() returns the next index within the same episode
# It STOPS at episode boundaries
next_indices = traj_buf.next(test_indices)

print("next() behavior:")
print(f"Index:     {test_indices}")
print(f"Next:      {next_indices}")
print("\nObservations:")
print("- Index 3 stays at 3 (end of episode 1, terminated)")
print("- Index 8 stays at 8 (end of episode 2, truncated)")
print("- Indices 9-10 advance normally (episode 3 ongoing)")
print("- Index 11 stays at 11 (latest stored step, nothing after it yet)")
next() behavior:
Index:     [ 0  1  2  3  4  5  6  7  8  9 10 11]
Next:      [ 1  2  3  3  5  6  7  8  8 10 11 11]

Observations:
- Index 3 stays at 3 (end of episode 1, terminated)
- Index 8 stays at 8 (end of episode 2, truncated)
- Indices 9-10 advance normally (episode 3 ongoing)
- Index 11 stays at 11 (latest stored step, nothing after it yet)

Use Cases for prev() and next():

These methods are essential for computing algorithmic quantities:

  • N-step returns: Use next() to step forward N transitions within an episode

  • GAE (Generalized Advantage Estimation): Navigate backwards through episodes

  • Episode extraction: Find episode start/end indices

  • Temporal difference targets: Ensure you don’t bootstrap across episode boundaries
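
As one illustration of boundary-aware navigation, the sketch below accumulates a discounted reward sum by stepping forward until it hits an episode end. The helpers `next_index` and `n_step_reward_sum` are hypothetical, operate on plain lists without wrap-around, and leave out the bootstrap term a full n-step target would add.

```python
def next_index(done: list, i: int) -> int:
    """Step forward by one, but stay put at an episode end or at the last stored slot."""
    if done[i] or i == len(done) - 1:
        return i
    return i + 1


def n_step_reward_sum(rew: list, done: list, start: int, n: int, gamma: float) -> float:
    """Discounted sum of up to n rewards from `start`, never crossing an episode boundary."""
    total, i = 0.0, start
    for k in range(n):
        total += (gamma ** k) * rew[i]
        nxt = next_index(done, i)
        if nxt == i:  # boundary reached: stop accumulating
            break
        i = nxt
    return total


# Same layout as the walkthrough: episodes end at indices 3 and 8
done = [False, False, False, True, False, False, False, False, True, False, False, False]
rew = [float(i + 1) for i in range(12)]

print([next_index(done, i) for i in range(12)])
print(n_step_reward_sum(rew, done, start=2, n=3, gamma=1.0))  # 7.0: rewards 3 + 4, stops at index 3
print(n_step_reward_sum(rew, done, start=4, n=3, gamma=1.0))  # 18.0: rewards 5 + 6 + 7
```

The first printed list reproduces the `Next` row shown earlier, which is a quick check that the helper follows the same stopping rule.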

4.3 Identifying Unfinished Episodes#

The unfinished_index() method returns the index of the most recently stored transition of each ongoing (not yet done) episode:

unfinished = traj_buf.unfinished_index()
print(f"Unfinished episode indices: {unfinished}")
print(f"Latest step of ongoing episode: obs={traj_buf.obs[unfinished[0]]}")

# After finishing episode 3
traj_buf.add(
    Batch(
        obs=12,
        act=12,
        rew=13.0,
        terminated=True,
        truncated=False,
        obs_next=13,
        info={},
    )
)

unfinished_after = traj_buf.unfinished_index()
print("\nAfter finishing episode 3:")
print(f"Unfinished episodes: {unfinished_after} (empty array)")
Unfinished episode indices: [11]
Latest step of ongoing episode: obs=11

After finishing episode 3:
Unfinished episodes: [] (empty array)

5. Sampling Strategies#

Efficient sampling is critical for RL training. The buffer provides several sampling methods and strategies.

5.1 Basic Sampling#

# Create a buffer with some data
sample_buf = ReplayBuffer(size=100)
for i in range(50):
    sample_buf.add(
        Batch(
            obs=i,
            act=i % 4,
            rew=np.random.random(),
            terminated=(i + 1) % 10 == 0,
            truncated=False,
            obs_next=i + 1,
            info={},
        )
    )

# Sample with batch_size
batch, indices = sample_buf.sample(batch_size=8)
print(f"Sampled batch size: {len(batch)}")
print(f"Sampled indices: {indices}")
print(f"Sampled observations: {batch.obs}")

# batch_size=None: return all data in random order
all_data, all_indices = sample_buf.sample(batch_size=None)
print(f"\nSample all (batch_size=None): {len(all_data)} transitions")

# batch_size=0: return all data in buffer order
ordered_data, ordered_indices = sample_buf.sample(batch_size=0)
print(f"Get all in order (batch_size=0): {len(ordered_data)} transitions")
print(f"Indices in order: {ordered_indices[:10]}...")  # Show first 10
Sampled batch size: 8
Sampled indices: [38 28 14 42  7 20 38 18]
Sampled observations: [38 28 14 42  7 20 38 18]

Sample all (batch_size=None): 50 transitions
Get all in order (batch_size=0): 50 transitions
Indices in order: [0 1 2 3 4 5 6 7 8 9]...

Sampling Behavior Summary:

  • batch_size > 0: Random sample of specified size

  • batch_size = None: All data in random order

  • batch_size = 0: All data in insertion order

  • batch_size < 0: Empty array (edge case handling)
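
The four cases can be written as a small dispatch function. This is a hypothetical helper that follows the summary above literally; in particular, the `None` branch returns a shuffled permutation, which is one reading of "all data in random order", so check the library source if the exact draw (with or without replacement) matters.

```python
import random


def sample_indices(length: int, batch_size, rng: random.Random) -> list:
    """Hypothetical helper reproducing the four cases listed above."""
    if batch_size is None:
        # Everything, shuffled
        order = list(range(length))
        rng.shuffle(order)
        return order
    if batch_size > 0:
        # Uniform draws with replacement, so duplicates are possible
        return [rng.randrange(length) for _ in range(batch_size)]
    if batch_size == 0:
        # Everything, in storage order
        return list(range(length))
    return []  # negative batch_size


rng = random.Random(0)
print(sample_indices(5, 3, rng))
print(sample_indices(5, 0, rng))   # [0, 1, 2, 3, 4]
print(sample_indices(5, -1, rng))  # []
```

Drawing with replacement for positive batch sizes is consistent with the repeated index 38 in the sampled output above.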

5.2 Frame Stacking#

The stack_num parameter enables automatic frame stacking, useful for RNN inputs or Atari-style environments where temporal context matters:

# Create buffer with frame stacking
stack_buf = ReplayBuffer(size=20, stack_num=4)

# Add observations: 0, 1, 2, ..., 9
for i in range(10):
    stack_buf.add(
        Batch(
            obs=np.array([i]),  # Single frame
            act=0,
            rew=1.0,
            terminated=i == 9,
            truncated=False,
            obs_next=np.array([i + 1]),
            info={},
        )
    )

# Get stacked frames for index 6
# Should return [3, 4, 5, 6] (4 consecutive frames ending at 6)
stacked = stack_buf.get(index=6, key="obs")
print("Frame stacking demo:")
print("Requested index: 6")
print(f"Stacked frames shape: {stacked.shape}")
print(f"Stacked frames: {stacked.flatten()}")
print("\nExplanation: stack_num=4, so index 6 returns [obs[3], obs[4], obs[5], obs[6]]")
Frame stacking demo:
Requested index: 6
Stacked frames shape: (4, 1)
Stacked frames: [3 4 5 6]

Explanation: stack_num=4, so index 6 returns [obs[3], obs[4], obs[5], obs[6]]
# Demonstrate episode boundary handling with frame stacking
boundary_buf = ReplayBuffer(size=20, stack_num=4)

# Episode 1: indices 0-4
for i in range(5):
    boundary_buf.add(
        Batch(
            obs=np.array([i]),
            act=0,
            rew=1.0,
            terminated=i == 4,
            truncated=False,
            obs_next=np.array([i + 1]),
            info={},
        )
    )

# Episode 2: indices 5-9
for i in range(5, 10):
    boundary_buf.add(
        Batch(
            obs=np.array([i]),
            act=0,
            rew=1.0,
            terminated=i == 9,
            truncated=False,
            obs_next=np.array([i + 1]),
            info={},
        )
    )

# Try to get stacked frames at episode boundary
boundary_stack = boundary_buf.get(index=6, key="obs")  # Early in episode 2
print("\nFrame stacking at episode boundary:")
print(f"Index 6 stacked frames: {boundary_stack.flatten()}")
print("Notice: Frames don't cross episode boundary (5,5,5,6 not 3,4,5,6)")
print("The buffer uses prev() internally, which respects episode boundaries")
Frame stacking at episode boundary:
Index 6 stacked frames: [5 5 5 6]
Notice: Frames don't cross episode boundary (5,5,5,6 not 3,4,5,6)
The buffer uses prev() internally, which respects episode boundaries

Frame Stacking Use Cases:

  • RNN/LSTM inputs: Provide temporal context to recurrent networks

  • Atari games: Stack 4 frames to capture motion (as in DQN paper)

  • Velocity estimation: Multiple frames allow computing derivatives

  • Partially observable environments: Build up state estimates

Important Notes:

  • Frame stacking respects episode boundaries (won’t stack across episodes)

  • Set sample_avail=True to only sample indices where full stacks are available

  • save_only_last_obs=True saves memory in Atari-style setups
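
The boundary behavior follows from building the stack by repeatedly stepping backwards with a prev-style rule. The sketch below uses hypothetical helpers on plain lists (no wrap-around) to show the index arithmetic; it is not Tianshou's code.

```python
def prev_index(done: list, i: int) -> int:
    """Step back by one, but stay put if the previous slot ended an episode."""
    if i == 0 or done[i - 1]:
        return i
    return i - 1


def stacked_indices(done: list, index: int, stack_num: int) -> list:
    """Indices of the frames that make up the stack ending at `index`."""
    indices = [index]
    for _ in range(stack_num - 1):
        indices.append(prev_index(done, indices[-1]))
    return indices[::-1]  # oldest frame first


# Two episodes: indices 0-4 and 5-9
done = [i in (4, 9) for i in range(10)]
obs = list(range(10))

print([obs[j] for j in stacked_indices(done, 6, 4)])  # [5, 5, 5, 6]
print([obs[j] for j in stacked_indices(done, 3, 4)])  # [0, 1, 2, 3]
```

Index 6 is only the second step of its episode, so the first frame of that episode is repeated to fill the stack instead of borrowing frames from the previous episode.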

6. VectorReplayBuffer: Parallel Environment Support#

VectorReplayBuffer is essential for modern RL training with parallel environments. It maintains separate subbuffers for each environment while providing a unified interface.

6.1 Motivation and Architecture#

When training with multiple parallel environments (e.g., 8 environments running simultaneously), we need:

  • Per-environment episode tracking: Each environment has its own episode boundaries

  • Temporal ordering: Preserve the sequence of events within each environment

  • Unified sampling: Sample uniformly across all environments for training

graph LR

    E1[Env 1] --> B1[Subbuffer 1<br/>2500 capacity]
    E2[Env 2] --> B2[Subbuffer 2<br/>2500 capacity]
    E3[Env 3] --> B3[Subbuffer 3<br/>2500 capacity]
    E4[Env 4] --> B4[Subbuffer 4<br/>2500 capacity]
    
    B1 --> VRB[VectorReplayBuffer<br/>Total: 10000<br/>Unified Sampling]
    B2 --> VRB
    B3 --> VRB
    B4 --> VRB
    
    VRB --> Policy[Policy Training]
    
    style E1 fill:#e1f5ff
    style E2 fill:#e1f5ff
    style E3 fill:#e1f5ff
    style E4 fill:#e1f5ff
    style B1 fill:#fff4e1
    style B2 fill:#fff4e1
    style B3 fill:#fff4e1
    style B4 fill:#fff4e1
    style VRB fill:#ffe1f5
    style Policy fill:#e8f5e1
    
# Create VectorReplayBuffer for 4 parallel environments
vec_buf = VectorReplayBuffer(
    total_size=100,  # Total capacity across all subbuffers
    buffer_num=4,  # Number of parallel environments
)

print("VectorReplayBuffer created:")
print(f"Total size: {vec_buf.maxsize}")
print(f"Number of subbuffers: {vec_buf.buffer_num}")
print(f"Size per subbuffer: {vec_buf.maxsize // vec_buf.buffer_num}")
print(f"Subbuffer edges: {vec_buf.subbuffer_edges}")
print("\nSubbuffer edges define the boundary indices: [0, 25, 50, 75, 100]")
print("Subbuffer 0: indices 0-24, Subbuffer 1: indices 25-49, etc.")
VectorReplayBuffer created:
Total size: 100
Number of subbuffers: 4
Size per subbuffer: 25
Subbuffer edges: [  0  25  50  75 100]

Subbuffer edges define the boundary indices: [0, 25, 50, 75, 100]
Subbuffer 0: indices 0-24, Subbuffer 1: indices 25-49, etc.

6.2 The buffer_ids Parameter#

This is one of the most confusing aspects for new users. The buffer_ids parameter specifies which subbuffer each transition belongs to.

# Simulate data from 4 parallel environments
# Each environment produces one transition
parallel_batch = Batch(
    obs=np.array([[0.1, 0.2], [1.1, 1.2], [2.1, 2.2], [3.1, 3.2]]),  # 4 observations
    act=np.array([0, 1, 0, 1]),  # 4 actions
    rew=np.array([1.0, 2.0, 3.0, 4.0]),  # 4 rewards
    terminated=np.array([False, False, False, False]),
    truncated=np.array([False, False, False, False]),
    obs_next=np.array([[0.2, 0.3], [1.2, 1.3], [2.2, 2.3], [3.2, 3.3]]),
    info=np.array([{}, {}, {}, {}], dtype=object),
)

print("Parallel batch shape:", parallel_batch.obs.shape)
print("This represents 4 transitions, one from each environment")

# Add with buffer_ids specifying which subbuffer each transition goes to
indices, ep_rews, ep_lens, ep_starts = vec_buf.add(
    parallel_batch,
    buffer_ids=[0, 1, 2, 3],  # Transition 0→Subbuf 0, 1→Subbuf 1, etc.
)

print(f"\nAdded to indices: {indices}")
print("Notice: Indices are in different subbuffers:")
print(f"  Index {indices[0]} in subbuffer 0 (range 0-24)")
print(f"  Index {indices[1]} in subbuffer 1 (range 25-49)")
print(f"  Index {indices[2]} in subbuffer 2 (range 50-74)")
print(f"  Index {indices[3]} in subbuffer 3 (range 75-99)")
Parallel batch shape: (4, 2)
This represents 4 transitions, one from each environment

Added to indices: [ 0 25 50 75]
Notice: Indices are in different subbuffers:
  Index 0 in subbuffer 0 (range 0-24)
  Index 25 in subbuffer 1 (range 25-49)
  Index 50 in subbuffer 2 (range 50-74)
  Index 75 in subbuffer 3 (range 75-99)
# Add more data to demonstrate buffer_ids
# Environments don't always produce data in order 0,1,2,3
# For example, if only environments 1 and 3 are ready:
partial_batch = Batch(
    obs=np.array([[1.2, 1.3], [3.2, 3.3]]),  # Only 2 observations
    act=np.array([0, 1]),
    rew=np.array([2.5, 4.5]),
    terminated=np.array([False, False]),
    truncated=np.array([False, False]),
    obs_next=np.array([[1.3, 1.4], [3.3, 3.4]]),
    info=np.array([{}, {}], dtype=object),
)

# Only environments 1 and 3 produced data
indices2, _, _, _ = vec_buf.add(
    partial_batch,
    buffer_ids=[1, 3],  # Only these two subbuffers receive data
)

print("Added partial batch (only envs 1 and 3):")
print(f"Indices: {indices2}")
print(f"Subbuffer 1 received data at index {indices2[0]}")
print(f"Subbuffer 3 received data at index {indices2[1]}")
Added partial batch (only envs 1 and 3):
Indices: [26 76]
Subbuffer 1 received data at index 26
Subbuffer 3 received data at index 76

Important: buffer_ids Requirements:

For VectorReplayBuffer:

  • buffer_ids length must match batch size

  • Values must be in range [0, buffer_num)

  • Can be partial (not all environments at once)

For regular ReplayBuffer:

  • If buffer_ids is not None, it must be [0]

  • Batch must have shape (1, data_length)

  • This is for API compatibility with VectorReplayBuffer
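
The mapping from a buffer id to a global index can be sketched as "subbuffer offset plus local write position". The names below (`edges`, `write_pos`, `assign_slots`) are illustrative only; the sketch reproduces the indices printed in the two additions above under that assumption.

```python
total_size, buffer_num = 100, 4
sub_size = total_size // buffer_num
edges = [k * sub_size for k in range(buffer_num + 1)]  # [0, 25, 50, 75, 100]
write_pos = [0] * buffer_num  # per-subbuffer write pointer (local index)


def assign_slots(buffer_ids: list) -> list:
    """Return the global slot given to one new transition per listed subbuffer."""
    global_indices = []
    for b in buffer_ids:
        if not 0 <= b < buffer_num:
            raise ValueError(f"buffer id {b} outside [0, {buffer_num})")
        global_indices.append(edges[b] + write_pos[b])
        write_pos[b] = (write_pos[b] + 1) % sub_size  # circular within the subbuffer
    return global_indices


first = assign_slots([0, 1, 2, 3])
second = assign_slots([1, 3])
print(first, second)  # [0, 25, 50, 75] [26, 76]
```

Because each subbuffer keeps its own write pointer, a partial batch advances only the subbuffers named in `buffer_ids`.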

6.3 Subbuffer Edges and Episode Handling#

Subbuffer edges prevent episodes from spanning across subbuffers, ensuring data from different environments doesn’t get mixed:

# The subbuffer_edges property defines boundaries
print(f"Subbuffer edges: {vec_buf.subbuffer_edges}")
print("\nThis creates 4 subbuffers:")
for i in range(vec_buf.buffer_num):
    start = vec_buf.subbuffer_edges[i]
    end = vec_buf.subbuffer_edges[i + 1]
    print(f"Subbuffer {i}: indices [{start}, {end})")

# Episodes cannot cross these boundaries
# prev() and next() respect subbuffer edges just like episode boundaries
test_idx = np.array([24, 25, 49, 50])  # At subbuffer edges
prev_result = vec_buf.prev(test_idx)
next_result = vec_buf.next(test_idx)

print("\nBoundary navigation test:")
print(f"Indices:  {test_idx}")
print(f"prev():   {prev_result}")
print(f"next():   {next_result}")
print("\nNotice: results stay inside each index's subbuffer (24 and 49 are still-empty slots)")
Subbuffer edges: [  0  25  50  75 100]

This creates 4 subbuffers:
Subbuffer 0: indices [0, 25)
Subbuffer 1: indices [25, 50)
Subbuffer 2: indices [50, 75)
Subbuffer 3: indices [75, 100)

Boundary navigation test:
Indices:  [24 25 49 50]
prev():   [ 0 25 25 50]
next():   [ 0 26 26 50]

Notice: results stay inside each index's subbuffer (24 and 49 are still-empty slots)

6.4 Sampling from VectorReplayBuffer#

Sampling is uniform across all subbuffers (proportional to their current fill level):

# Add more data to have enough for sampling
for _step in range(10):
    batch = Batch(
        obs=np.random.randn(4, 2),
        act=np.random.randint(0, 2, size=4),
        rew=np.random.random(4),
        terminated=np.zeros(4, dtype=bool),
        truncated=np.zeros(4, dtype=bool),
        obs_next=np.random.randn(4, 2),
        info=np.array([{}] * 4, dtype=object),
    )
    vec_buf.add(batch, buffer_ids=[0, 1, 2, 3])

# Sample batch
sampled, indices = vec_buf.sample(batch_size=16)
print(f"Sampled {len(sampled)} transitions")
print(f"Sample indices (from different subbuffers): {indices}")
print("\nNotice indices span across all subbuffer ranges")
Sampled 16 transitions
Sample indices (from different subbuffers): [ 6  3 10  7  4  6  9 31 56 53 60 57 81 78 85 82]

Notice indices span across all subbuffer ranges
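
One way to get sampling that is uniform over all stored transitions is a two-stage draw: pick a subbuffer with probability proportional to its fill level, then pick uniformly inside it. The sketch below assumes that scheme and assumes no subbuffer has wrapped around yet; the grouping of indices by subbuffer in the output above is consistent with it, but the library source remains the authority.

```python
import random
from collections import Counter

edges = [0, 25, 50, 75, 100]
fill = [11, 12, 11, 12]  # filled slots per subbuffer after the additions above
rng = random.Random(0)


def sample_global_indices(batch_size: int) -> list:
    # Stage 1: choose a subbuffer for each draw, weighted by fill level
    picks = Counter(rng.choices(range(len(fill)), weights=fill, k=batch_size))
    # Stage 2: draw uniformly inside each chosen subbuffer, then add its offset
    out = []
    for b in range(len(fill)):
        out.extend(edges[b] + rng.randrange(fill[b]) for _ in range(picks[b]))
    return out


indices = sample_global_indices(16)
print(indices)
```

Every filled slot ends up with the same probability, `1 / sum(fill)`, regardless of which environment produced it.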

7. Specialized Buffer Variants#

7.1 PrioritizedReplayBuffer#

Implements prioritized experience replay where transitions are sampled based on their TD-error magnitudes:

# Create prioritized buffer
prio_buf = PrioritizedReplayBuffer(
    size=100,
    alpha=0.6,  # Prioritization exponent (0=uniform, 1=fully prioritized)
    beta=0.4,  # Importance sampling correction (annealed to 1)
)

# Add some transitions
for i in range(20):
    prio_buf.add(
        Batch(
            obs=np.array([i]),
            act=i % 4,
            rew=np.random.random(),
            terminated=False,
            truncated=False,
            obs_next=np.array([i + 1]),
            info={},
        )
    )

# Sample returns batch and indices
# Importance weights are INSIDE the batch as batch.weight
batch, indices = prio_buf.sample(batch_size=8)
print(f"Sampled batch size: {len(batch)}")
print(f"Indices: {indices}")
print(f"Importance weights (batch.weight): {batch.weight}")
print("\nWeights are stored in batch.weight and compensate for biased sampling")
Sampled batch size: 8
Indices: [ 3 12 14 12 19  9  8  7]
Importance weights (batch.weight): [1. 1. 1. 1. 1. 1. 1. 1.]

Weights are stored in batch.weight and compensate for biased sampling
# After computing TD-errors from the sampled batch, update priorities
# In practice, these would be actual TD-errors: |Q(s,a) - (r + γ*max Q(s',a'))|
fake_td_errors = np.random.random(len(indices)) * 10  # Simulated TD-errors

# Update priorities (higher TD-error = higher priority)
prio_buf.update_weight(indices, fake_td_errors)

print("Updated priorities based on TD-errors")
print("Transitions with higher TD-errors will be sampled more frequently")

# Demonstrate beta annealing
# options["beta"] keeps the constructor value, so report the value passed to set_beta()
new_beta = 0.6
prio_buf.set_beta(new_beta)  # Increase beta over training
print(f"\nAnnealed beta to: {new_beta}")
print("Beta typically starts at 0.4 and anneals to 1.0 over training")
Updated priorities based on TD-errors
Transitions with higher TD-errors will be sampled more frequently

Annealed beta to: 0.6
Beta typically starts at 0.4 and anneals to 1.0 over training

PrioritizedReplayBuffer Use Cases:

  • Rainbow DQN and variants

  • Any algorithm where some transitions are more “surprising” and valuable

  • Environments with rare but important events

Key Parameters:

  • alpha: Controls how much prioritization affects sampling (0=uniform, 1=fully proportional to priority)

  • beta: Importance sampling correction to remain unbiased (anneal from ~0.4 to 1.0)
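
The roles of alpha and beta are easier to see in the proportional formulation commonly used for prioritized replay: P(i) = p_i^alpha / sum_k p_k^alpha and w_i = (N * P(i))^(-beta), with weights divided by their maximum. The helpers below are a hypothetical plain-Python sketch of those formulas; Tianshou's own normalization may differ in detail.

```python
def per_probabilities(td_errors: list, alpha: float, eps: float = 1e-6) -> list:
    """P(i) = p_i^alpha / sum_k p_k^alpha with p_i = |td_error_i| + eps."""
    scaled = [(abs(e) + eps) ** alpha for e in td_errors]
    total = sum(scaled)
    return [s / total for s in scaled]


def importance_weights(probs: list, beta: float) -> list:
    """w_i = (N * P(i))^(-beta), divided by the largest weight for stability."""
    n = len(probs)
    raw = [(n * p) ** (-beta) for p in probs]
    largest = max(raw)
    return [w / largest for w in raw]


td_errors = [0.1, 1.0, 5.0]
probs = per_probabilities(td_errors, alpha=0.6)
weights = importance_weights(probs, beta=0.4)

print([round(p, 3) for p in probs])    # larger TD-error, larger sampling probability
print([round(w, 3) for w in weights])  # larger probability, smaller correction weight
```

With alpha near 0 the probabilities flatten toward uniform, and with beta at 1 the weights fully compensate for the non-uniform sampling, which is why beta is annealed upward during training.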

7.2 Other Specialized Buffers#

CachedReplayBuffer: Maintains a main buffer plus per-environment cache buffers

  • Data is collected into the caches and moved to the main buffer when an episode ends

  • Use case: Keeping only complete episodes in the main buffer during parallel collection

  • See the Tianshou API documentation for the exact memory layout

HERReplayBuffer: Hindsight Experience Replay for goal-conditioned tasks

  • Use case: Sparse reward robotics tasks

  • Relabels failed episodes with achieved goals as if they were intended

  • Dramatically improves learning in goal-reaching tasks

  • See the HER documentation for detailed examples

For detailed usage of these specialized buffers, refer to the Tianshou API documentation and algorithm-specific tutorials.
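
The relabeling idea behind hindsight experience replay can be sketched in a few lines of plain Python. This is an illustration only, not how HERReplayBuffer is implemented: the dictionary keys, the `relabel_with_final_goal` helper, and the `sparse_reward` function are all hypothetical, and only the simple "use the final achieved goal" strategy is shown.

```python
def relabel_with_final_goal(episode, reward_fn):
    """Return a copy of the episode whose goal is the last achieved goal."""
    new_goal = episode[-1]["achieved_goal"]
    relabeled = []
    for step in episode:
        relabeled.append(
            {
                **step,
                "desired_goal": new_goal,
                "rew": reward_fn(step["achieved_goal"], new_goal),
            }
        )
    return relabeled


def sparse_reward(achieved, goal):
    """Hypothetical sparse reward: 0.0 on success, -1.0 otherwise."""
    return 0.0 if achieved == goal else -1.0


# A failed episode: the agent wanted to reach 9 but ended at 3.
failed = [
    {"achieved_goal": 1, "desired_goal": 9, "rew": -1.0},
    {"achieved_goal": 2, "desired_goal": 9, "rew": -1.0},
    {"achieved_goal": 3, "desired_goal": 9, "rew": -1.0},
]
relabeled = relabel_with_final_goal(failed, sparse_reward)
print([step["rew"] for step in relabeled])  # [-1.0, -1.0, 0.0]
```

After relabeling, the last step counts as a success for the substituted goal, so the episode now carries a learning signal even though the original goal was never reached.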

8. Serialization and Persistence#

Buffers support multiple serialization formats for saving and loading data.

8.1 Pickle Serialization#

The simplest method, preserving all buffer state including trajectory metadata:

# Create and populate a buffer
save_buf = ReplayBuffer(size=50)
for i in range(30):
    save_buf.add(
        Batch(
            obs=np.array([i, i + 1]),
            act=i % 4,
            rew=float(i),
            terminated=(i + 1) % 10 == 0,
            truncated=False,
            obs_next=np.array([i + 1, i + 2]),
            info={"step": i},
        )
    )

print(f"Original buffer: {len(save_buf)} transitions")

# Serialize with pickle
pickled_data = pickle.dumps(save_buf)
print(f"Serialized size: {len(pickled_data)} bytes")

# Deserialize
loaded_buf = pickle.loads(pickled_data)
print(f"Loaded buffer: {len(loaded_buf)} transitions")
print(f"Data preserved: obs[0] = {loaded_buf.obs[0]}")
print(f"Metadata preserved: info[0] = {loaded_buf.info[0]}")
Original buffer: 30 transitions
Serialized size: 7073 bytes
Loaded buffer: 30 transitions
Data preserved: obs[0] = [0 1]
Metadata preserved: info[0] = Batch(
    step: 0,
)

8.2 HDF5 Serialization#

HDF5 is recommended for large datasets and cross-platform compatibility:

# Save to HDF5
with tempfile.NamedTemporaryFile(suffix=".hdf5", delete=False) as tmp:
    hdf5_path = tmp.name

save_buf.save_hdf5(hdf5_path, compression="gzip")
print(f"Saved to HDF5: {hdf5_path}")

# Load from HDF5
loaded_hdf5_buf = ReplayBuffer.load_hdf5(hdf5_path)
print(f"Loaded from HDF5: {len(loaded_hdf5_buf)} transitions")
print(f"Data matches: {np.array_equal(save_buf.obs, loaded_hdf5_buf.obs)}")

# Clean up
import os

os.unlink(hdf5_path)
Saved to HDF5: /tmp/tmpt3ia27p1.hdf5
Loaded from HDF5: 30 transitions
Data matches: True

When to Use HDF5:

  • Large datasets (> 1GB)

  • Offline RL with pre-collected data

  • Sharing data across platforms

  • Need for compression

  • Integration with external tools (many scientific tools read HDF5)

When to Use Pickle:

  • Quick saves during development

  • Small buffers

  • Python-only workflow

  • Simpler serialization needs

8.3 Loading from Raw Data with from_data()#

For offline RL, you can create a buffer from raw arrays:

# Simulate pre-collected offline dataset
import h5py

# Create temporary HDF5 file with raw data
with tempfile.NamedTemporaryFile(suffix=".hdf5", delete=False) as tmp:
    offline_path = tmp.name

with h5py.File(offline_path, "w") as f:
    # Create datasets
    n = 100
    f.create_dataset("obs", data=np.random.randn(n, 4))
    f.create_dataset("act", data=np.random.randint(0, 2, n))
    f.create_dataset("rew", data=np.random.randn(n))
    f.create_dataset("terminated", data=np.random.random(n) < 0.1)
    f.create_dataset("truncated", data=np.zeros(n, dtype=bool))
    f.create_dataset("done", data=np.random.random(n) < 0.1)
    f.create_dataset("obs_next", data=np.random.randn(n, 4))

# Load into buffer
with h5py.File(offline_path, "r") as f:
    offline_buf = ReplayBuffer.from_data(
        obs=f["obs"],
        act=f["act"],
        rew=f["rew"],
        terminated=f["terminated"],
        truncated=f["truncated"],
        done=f["done"],
        obs_next=f["obs_next"],
    )

print(f"Loaded offline dataset: {len(offline_buf)} transitions")
print(f"Observation shape: {offline_buf.obs.shape}")

# Clean up
os.unlink(offline_path)
Loaded offline dataset: 100 transitions
Observation shape: (100, 4)

This is the standard approach for offline RL where you have pre-collected datasets from other sources.

9. Integration with the RL Pipeline#

Understanding how buffers integrate with other Tianshou components is essential for effective usage.

9.1 Data Flow in RL Training#

        graph LR
    ENV[Vectorized<br/>Environments] -->|observations| COL[Collector]
    POL[Policy] -->|actions| COL
    COL -->|transitions| BUF[Buffer]
    BUF -->|sampled batches| POL
    POL -->|forward pass| ALG[Algorithm]
    ALG -->|loss & gradients| POL
    
    style ENV fill:#e1f5ff
    style COL fill:#fff4e1
    style BUF fill:#ffe1f5
    style POL fill:#e8f5e1
    style ALG fill:#f5e1e1
    

9.2 Typical Training Loop Pattern#

Here’s how buffers are typically used in a training loop:

# Pseudocode for typical RL training loop
# (This is illustrative; actual implementation would use Trainer)


def training_loop_pseudocode():
    """
    Illustrative training loop showing buffer integration.

    In practice, use Tianshou's Trainer class which handles this.
    """
    # Setup (illustration only)
    # env = make_vectorized_env(num_envs=8)
    # policy = make_policy()
    # buffer = VectorReplayBuffer(total_size=100000, buffer_num=8)
    # collector = Collector(policy, env, buffer)

    # Training loop
    # for epoch in range(num_epochs):
    #     # 1. Collect data from environments
    #     collect_result = collector.collect(n_step=1000)
    #     # Collector automatically adds transitions to buffer with correct buffer_ids
    #
    #     # 2. Train on multiple batches
    #     for _ in range(update_per_collect):
    #         # Sample batch from buffer
    #         batch, indices = buffer.sample(batch_size=256)
    #
    #         # Compute loss and update policy
    #         loss = policy.learn(batch)
    #
    #         # For prioritized buffers, update priorities
    #         # if isinstance(buffer, PrioritizedReplayBuffer):
    #         #     buffer.update_weight(indices, td_errors)

    print("This pseudocode illustrates the buffer's role:")
    print("1. Collector fills buffer from environment interaction")
    print("2. Buffer provides random samples for training")
    print("3. Policy learns from sampled batches")
    print("\nIn practice, use Tianshou's Trainer for this workflow")


training_loop_pseudocode()
This pseudocode illustrates the buffer's role:
1. Collector fills buffer from environment interaction
2. Buffer provides random samples for training
3. Policy learns from sampled batches

In practice, use Tianshou's Trainer for this workflow

9.3 Collector Integration#

The Collector class handles the complexity of:

  • Calling policy to get actions

  • Stepping environments

  • Adding transitions to buffer with correct buffer_ids

  • Tracking episode statistics

When you create a Collector, you pass it a buffer, and the Collector then:

  • Works with a VectorReplayBuffer for vectorized environments (one subbuffer per environment)

  • Sets buffer_ids based on which environments are ready

  • Handles episode resets and boundary tracking

See the Collector tutorial for detailed examples of this integration.
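
The routing step can be pictured with a small plain-Python sketch. It does not use the Collector or VectorReplayBuffer; the subbuffers are ordinary lists and `add_ready_transitions` is a hypothetical helper, shown only to illustrate how each transition is matched to the environment that produced it.

```python
def add_ready_transitions(subbuffers, transitions, ready_env_ids):
    """Append transitions[k] to the subbuffer of environment ready_env_ids[k]."""
    if len(transitions) != len(ready_env_ids):
        raise ValueError("need exactly one transition per ready environment")
    for transition, env_id in zip(transitions, ready_env_ids):
        subbuffers[env_id].append(transition)


# Four environments, but only environments 0 and 2 produced a step this round.
subbuffers = [[] for _ in range(4)]
add_ready_transitions(
    subbuffers,
    transitions=[{"obs": 10, "rew": 1.0}, {"obs": 30, "rew": 0.5}],
    ready_env_ids=[0, 2],
)
print([len(sub) for sub in subbuffers])  # [1, 0, 1, 0]
```

Keeping one subbuffer per environment is what lets each episode stay contiguous in storage even when environments finish at different times.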

10. Advanced Topics and Edge Cases#

10.1 Buffer Overflow and Episode Boundaries#

What happens when the buffer fills up mid-episode?

# Small buffer to demonstrate overflow
overflow_buf = ReplayBuffer(size=8)

# Add a long episode (12 steps, buffer size is only 8)
print("Adding 12-step episode to buffer with size 8:")
for i in range(12):
    idx, ep_rew, ep_len, ep_start = overflow_buf.add(
        Batch(
            obs=i,
            act=0,
            rew=1.0,
            terminated=i == 11,
            truncated=False,
            obs_next=i + 1,
            info={},
        )
    )
    if i in [7, 11]:
        print(f"  Step {i}: idx={idx}, buffer_len={len(overflow_buf)}")

print("\nFinal buffer contents (most recent 8 steps):")
print(f"Observations: {overflow_buf.obs[: len(overflow_buf)]}")
print(f"Episode return: {ep_rew[0]} (sum of all 12 steps, tracked correctly!)")
print("\nNote: Buffer overwrote old data but episode statistics are still correct")
Adding 12-step episode to buffer with size 8:
  Step 7: idx=[7], buffer_len=8
  Step 11: idx=[3], buffer_len=8

Final buffer contents (most recent 8 steps):
Observations: [ 8  9 10 11  4  5  6  7]
Episode return: 12.0 (sum of all 12 steps, tracked correctly!)

Note: Buffer overwrote old data but episode statistics are still correct

Important: Episode returns and lengths are tracked internally and remain correct even when the episode spans buffer overflows. The buffer maintains _ep_return, _ep_len, and _ep_start_idx to track ongoing episodes.
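
Why the statistics survive overwriting is easiest to see in a toy version. The sketch below is a hypothetical `TinyRingBuffer` written in plain Python, not Tianshou's ReplayBuffer: it keeps the running episode return and length in separate counters, so they do not depend on which transitions are still stored.

```python
class TinyRingBuffer:
    """Fixed-size circular storage with separately tracked episode statistics."""

    def __init__(self, size):
        self.size = size
        self.obs = [None] * size
        self.count = 0        # number of stored transitions, capped at size
        self.write_idx = 0    # slot the next transition goes into
        self.ep_return = 0.0  # running return of the ongoing episode
        self.ep_len = 0       # running length of the ongoing episode

    def add(self, obs, rew, done):
        idx = self.write_idx
        self.obs[idx] = obs  # may overwrite the oldest slot
        self.write_idx = (idx + 1) % self.size
        self.count = min(self.count + 1, self.size)
        self.ep_return += rew
        self.ep_len += 1
        if done:
            finished = (self.ep_return, self.ep_len)
            self.ep_return, self.ep_len = 0.0, 0
            return idx, finished
        return idx, (0.0, 0)


buf = TinyRingBuffer(size=8)
for i in range(12):
    idx, (ep_rew, ep_len) = buf.add(obs=i, rew=1.0, done=(i == 11))
print(buf.obs)               # [8, 9, 10, 11, 4, 5, 6, 7]
print(idx, ep_rew, ep_len)   # 3 12.0 12
```

The stored observations match the overflow example above, and the return of 12.0 covers all twelve steps even though only eight remain in storage.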

10.2 Episode Spanning Subbuffer Edges#

In VectorReplayBuffer, episodes can wrap around within their subbuffer:

# Create small VectorReplayBuffer to demonstrate edge crossing
edge_buf = VectorReplayBuffer(total_size=20, buffer_num=2)  # 10 per subbuffer

print(f"Subbuffer edges: {edge_buf.subbuffer_edges}")
print("Subbuffer 0: indices 0-9, Subbuffer 1: indices 10-19\n")

# Fill subbuffer 0 with 12 steps (wraps around since capacity is 10)
for i in range(12):
    batch = Batch(
        obs=np.array([[i]]),
        act=np.array([0]),
        rew=np.array([1.0]),
        terminated=np.array([i == 11]),
        truncated=np.array([False]),
        obs_next=np.array([[i + 1]]),
        info=np.array([{}], dtype=object),
    )
    idx, _, _, _ = edge_buf.add(batch, buffer_ids=[0])
    if i >= 10:
        print(f"Step {i} added at index {idx[0]} (wrapped around in subbuffer 0)")

# get_buffer_indices handles this correctly
episode_indices = edge_buf.get_buffer_indices(start=8, stop=2)  # Crosses edge
print(f"\nEpisode spanning edge (from 8 to 1): {episode_indices}")
print("Correctly retrieves [8, 9, 0, 1] within subbuffer 0")
Subbuffer edges: [ 0 10 20]
Subbuffer 0: indices 0-9, Subbuffer 1: indices 10-19

Step 10 added at index 0 (wrapped around in subbuffer 0)
Step 11 added at index 1 (wrapped around in subbuffer 0)

Episode spanning edge (from 8 to 1): [8 9 0 1]
Correctly retrieves [8, 9, 0, 1] within subbuffer 0
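
The wrap-around logic can be sketched independently of the library. The `wrapped_range` function below is a hypothetical stand-in written for illustration, not the code behind get_buffer_indices; it assumes the subbuffer covers the half-open index range from `lower` to `upper`.

```python
def wrapped_range(start, stop, lower, upper):
    """Indices from start (inclusive) to stop (exclusive), wrapping inside [lower, upper)."""
    if not (lower <= start < upper and lower <= stop <= upper):
        raise ValueError("start and stop must lie inside the subbuffer")
    if start < stop:
        return list(range(start, stop))
    # The span crosses the subbuffer edge: run to the end, then continue from the beginning.
    return list(range(start, upper)) + list(range(lower, stop))


print(wrapped_range(8, 2, lower=0, upper=10))     # [8, 9, 0, 1]
print(wrapped_range(18, 12, lower=10, upper=20))  # [18, 19, 10, 11]
print(wrapped_range(3, 6, lower=0, upper=10))     # [3, 4, 5]
```

The second call shows the same wrap happening inside subbuffer 1, where indices never spill into subbuffer 0.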

10.3 ignore_obs_next Memory Optimization#

For memory-constrained scenarios, you can avoid storing obs_next:

# Buffer that doesn't store obs_next
memory_buf = ReplayBuffer(size=10, ignore_obs_next=True)

# Add transitions (obs_next is ignored)
for i in range(5):
    memory_buf.add(
        Batch(
            obs=np.array([i, i + 1]),
            act=i,
            rew=1.0,
            terminated=False,
            truncated=False,
            obs_next=np.array([i + 1, i + 2]),  # Provided but not stored
            info={},
        )
    )

# When sampling, obs_next is reconstructed from next obs
sample, _ = memory_buf.sample(batch_size=1)
print(f"Sampled obs: {sample.obs}")
print(f"Sampled obs_next: {sample.obs_next}")
print("\nobs_next was reconstructed, not stored directly")
print("This saves memory at the cost of slightly more complex retrieval")
Sampled obs: [[3 4]]
Sampled obs_next: [[4 5]]

obs_next was reconstructed, not stored directly
This saves memory at the cost of slightly more complex retrieval

This is particularly useful for Atari environments with large observation spaces (84x84x4 frames).
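
The reconstruction idea can be sketched in plain Python. The helpers `next_index` and `reconstruct_obs_next` are hypothetical and ignore circular wrap-around; the fallback at episode ends is an assumption of this sketch, made because the true successor observation was never stored.

```python
def next_index(i, done, last_index):
    """Index of the following step, or i itself at an episode end or at the newest entry."""
    if done[i] or i == last_index:
        return i
    return i + 1


def reconstruct_obs_next(i, obs, done, last_index):
    """Look up obs_next as the observation stored one step later."""
    return obs[next_index(i, done, last_index)]


obs = [[i, i + 1] for i in range(5)]  # same observations as the example above
done = [False] * 5
print(reconstruct_obs_next(3, obs, done, last_index=4))  # [4, 5]
print(reconstruct_obs_next(4, obs, done, last_index=4))  # [4, 5] (newest entry, no successor yet)
```

The first call reproduces the sampled pair shown above: the observation [3, 4] is followed by [4, 5], read from the next slot rather than from a stored obs_next.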

11. Surprising Behaviors and Gotchas#

11.1 Most Common Mistake: buffer_ids Confusion#

The buffer_ids parameter is the most common source of errors:

# COMMON ERROR 1: Forgetting buffer_ids with VectorReplayBuffer
vec_demo = VectorReplayBuffer(total_size=100, buffer_num=4)

parallel_data = Batch(
    obs=np.random.randn(4, 2),
    act=np.array([0, 1, 0, 1]),
    rew=np.array([1.0, 2.0, 3.0, 4.0]),
    terminated=np.array([False, False, False, False]),
    truncated=np.array([False, False, False, False]),
    obs_next=np.random.randn(4, 2),
    info=np.array([{}, {}, {}, {}], dtype=object),
)

# RISKY: Omitting buffer_ids (defaults to all subbuffers [0,1,2,3], which is OK here)
# But if the batch covers only some of the environments, this will fail
vec_demo.add(parallel_data)  # Works because the batch has one row per subbuffer

# CORRECT: Always explicit
vec_demo.add(parallel_data, buffer_ids=[0, 1, 2, 3])
print("Always specify buffer_ids explicitly for clarity")
Always specify buffer_ids explicitly for clarity
# COMMON ERROR 2: Shape mismatch with buffer_ids
try:
    # Trying to add 2 transitions but specifying 4 buffer_ids
    wrong_batch = Batch(
        obs=np.random.randn(2, 2),  # Only 2 transitions!
        act=np.array([0, 1]),
        rew=np.array([1.0, 2.0]),
        terminated=np.array([False, False]),
        truncated=np.array([False, False]),
        obs_next=np.random.randn(2, 2),
        info=np.array([{}, {}], dtype=object),
    )
    vec_demo.add(wrong_batch, buffer_ids=[0, 1, 2, 3])  # MISMATCH!
except (IndexError, ValueError) as e:
    print(f"Error caught: {type(e).__name__}")
    print("Lesson: buffer_ids length must match batch size")
Error caught: IndexError
Lesson: buffer_ids length must match batch size

11.2 Done Flag Confusion#

Never manually set the done flag:

# WRONG: Manually setting done
wrong_batch = Batch(
    obs=1,
    act=0,
    rew=1.0,
    terminated=True,
    truncated=False,
    # done=True,  # DON'T DO THIS! It will be overwritten anyway
    obs_next=2,
    info={},
)

# CORRECT: Only set terminated and truncated
# done is automatically computed as (terminated OR truncated)
correct_batch = Batch(
    obs=1,
    act=0,
    rew=1.0,
    terminated=True,  # Episode ended naturally
    truncated=False,  # Not cut off
    obs_next=2,
    info={},
)

demo = ReplayBuffer(size=10)
demo.add(correct_batch)
print(f"Terminated: {demo.terminated[0]}")
print(f"Truncated: {demo.truncated[0]}")
print(f"Done (auto-computed): {demo.done[0]}")
Terminated: True
Truncated: False
Done (auto-computed): True

11.3 Sampling from Empty or Near-Empty Buffers#

# Edge case: Sampling more than available
small_buf = ReplayBuffer(size=100)
for i in range(5):  # Only 5 transitions
    small_buf.add(
        Batch(obs=i, act=0, rew=1.0, terminated=False, truncated=False, obs_next=i + 1, info={})
    )

# Request 20 but only 5 available - samples with replacement
batch, indices = small_buf.sample(batch_size=20)
print(f"Requested 20, buffer has {len(small_buf)}, got {len(batch)}")
print(f"Indices: {indices}")
print("Notice: Some indices repeat (sampling with replacement)")

# Defensive pattern: Check buffer size
if len(small_buf) >= 128:
    batch, _ = small_buf.sample(128)
else:
    print(f"Buffer has {len(small_buf)} < 128, waiting for more data")
Requested 20, buffer has 5, got 20
Indices: [3 4 2 4 4 1 2 2 2 4 3 2 4 1 3 1 3 4 0 3]
Notice: Some indices repeat (sampling with replacement)
Buffer has 5 < 128, waiting for more data
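
The repeated indices follow directly from drawing with replacement. The sketch below uses only the standard library; `sample_indices` and `ready_to_train` are hypothetical helpers that mirror the behaviour and the defensive check shown above, not Tianshou functions.

```python
import random


def sample_indices(buffer_len, batch_size, rng):
    """Draw batch_size indices uniformly at random, with replacement."""
    if buffer_len == 0:
        return []
    return rng.choices(range(buffer_len), k=batch_size)


def ready_to_train(buffer_len, batch_size):
    """Hypothetical guard: train only once a batch can be drawn without forced repeats."""
    return buffer_len >= batch_size


rng = random.Random(0)
indices = sample_indices(buffer_len=5, batch_size=20, rng=rng)
print(len(indices), len(set(indices)))  # 20 draws, at most 5 distinct values
print(ready_to_train(buffer_len=5, batch_size=128))  # False
```

With 20 draws from 5 stored transitions, duplicates are unavoidable, which is why a warm-up check before the first update is a sensible default.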

11.4 Frame Stacking Valid Indices#

With stack_num > 1, not all indices are valid for sampling:

# With frame stacking, early indices can't form complete stacks
stack_demo = ReplayBuffer(size=20, stack_num=4, sample_avail=True)

for i in range(10):
    stack_demo.add(
        Batch(
            obs=np.array([i]),
            act=0,
            rew=1.0,
            terminated=i == 9,
            truncated=False,
            obs_next=np.array([i + 1]),
            info={},
        )
    )

# With sample_avail=True, only valid indices are sampled
sampled, indices = stack_demo.sample(batch_size=5)
print(f"Sampled indices with stack_num=4, sample_avail=True: {indices}")
print("All indices >= 3 (can form complete 4-frame stacks)")

# Without sample_avail, any index can be sampled (may have incomplete stacks)
stack_demo2 = ReplayBuffer(size=20, stack_num=4, sample_avail=False)
for i in range(10):
    stack_demo2.add(
        Batch(
            obs=np.array([i]),
            act=0,
            rew=1.0,
            terminated=False,
            truncated=False,
            obs_next=np.array([i + 1]),
            info={},
        )
    )

sampled2, indices2 = stack_demo2.sample(batch_size=5)
print(f"\nSampled indices with sample_avail=False: {indices2}")
print("May include indices < 3 (incomplete stacks repeated from boundary)")
Sampled indices with stack_num=4, sample_avail=True: [9 6 7 9 5]
All indices >= 3 (can form complete 4-frame stacks)

Sampled indices with sample_avail=False: [6 3 7 4 6]
May include indices < 3 (incomplete stacks repeated from boundary)
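
Which indices can form a complete stack depends on where each episode starts. The following plain-Python sketch illustrates this; `stack_indices` and `valid_stack_ends` are hypothetical helpers, and `episode_start` is an assumed per-step list giving the first index of the episode each step belongs to.

```python
def stack_indices(i, episode_start, stack_num):
    """Indices of the stack ending at i, repeating the episode's first frame if needed."""
    start = episode_start[i]
    return [max(i - offset, start) for offset in range(stack_num - 1, -1, -1)]


def valid_stack_ends(episode_start, stack_num):
    """Indices whose stack holds stack_num distinct frames from a single episode."""
    return [i for i, start in enumerate(episode_start) if i - start >= stack_num - 1]


# Two episodes of five steps each: steps 0-4 and steps 5-9.
episode_start = [0] * 5 + [5] * 5
print(stack_indices(1, episode_start, stack_num=4))  # [0, 0, 0, 1]
print(stack_indices(8, episode_start, stack_num=4))  # [5, 6, 7, 8]
print(valid_stack_ends(episode_start, stack_num=4))  # [3, 4, 8, 9]
```

Index 1 cannot reach back three steps inside its episode, so its stack repeats the first frame; restricting sampling to the valid indices avoids such padded stacks.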

12. Best Practices#

12.1 Choosing the Right Buffer#

Decision Tree:

  1. Are you using parallel environments?

    • Yes → Use VectorReplayBuffer, or the vectorized variant of a specialized buffer from steps 2-3

    • No → Continue to 2

  2. Do you need prioritized experience replay?

    • Yes → Use PrioritizedReplayBuffer or PrioritizedVectorReplayBuffer

    • No → Continue to 3

  3. Is it goal-conditioned RL with sparse rewards?

    • Yes → Use HERReplayBuffer or HERVectorReplayBuffer

    • No → Continue to 4

  4. Do you need separate expert and agent buffers?

    • Yes → Use CachedReplayBuffer

    • No → Use ReplayBuffer (single env) or VectorReplayBuffer (standard choice)

Most Common Setup: VectorReplayBuffer for production training
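
The checklist can be written down as a small function. This is one possible reading of it, in which the parallel-environment question selects the vectorized variant rather than ending the decision; `choose_buffer` is a hypothetical helper that returns class names as strings and does not import Tianshou.

```python
def choose_buffer(parallel_envs, prioritized=False, goal_conditioned=False, expert_and_agent=False):
    """Return the name of a buffer class following the checklist above."""
    if prioritized:
        return "PrioritizedVectorReplayBuffer" if parallel_envs else "PrioritizedReplayBuffer"
    if goal_conditioned:
        return "HERVectorReplayBuffer" if parallel_envs else "HERReplayBuffer"
    if expert_and_agent:
        return "CachedReplayBuffer"
    return "VectorReplayBuffer" if parallel_envs else "ReplayBuffer"


print(choose_buffer(parallel_envs=True))                         # VectorReplayBuffer
print(choose_buffer(parallel_envs=True, prioritized=True))       # PrioritizedVectorReplayBuffer
print(choose_buffer(parallel_envs=False, goal_conditioned=True)) # HERReplayBuffer
```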

12.2 Buffer Sizing Guidelines#

Rule of Thumb by Domain:

  • Atari games: 1,000,000 transitions (1e6)

  • Continuous control (MuJoCo): 100,000-1,000,000 (1e5-1e6)

  • Robotics: 100,000-500,000 (1e5-5e5)

  • Simple environments (CartPole): 10,000-50,000 (1e4-5e4)

Factors to Consider:

  • Available RAM (each transition ~observation_size * 2 + metadata)

  • Training time vs sample efficiency tradeoff

  • Algorithm requirements (some need larger buffers)

Memory Estimation:

# For environments with observation shape (84, 84, 4) (Atari):
# Each transition: 2 * 84 * 84 * 4 bytes (obs + obs_next, 1 byte per uint8 pixel) + ~100 bytes overhead
# = ~56KB per transition
# 1M transitions = ~56GB (use ignore_obs_next to halve this!)
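
The same arithmetic can be wrapped in a helper for other observation shapes. This is a rough sketch: `estimate_buffer_bytes` is a hypothetical function, the 100-byte overhead is the ballpark figure from the comment above, and one byte per element assumes uint8 observations.

```python
def estimate_buffer_bytes(
    obs_shape,
    num_transitions,
    bytes_per_element=1,
    overhead_bytes=100,
    store_obs_next=True,
):
    """Approximate buffer memory: observation copies plus a fixed per-transition overhead."""
    elements = 1
    for dim in obs_shape:
        elements *= dim
    copies = 2 if store_obs_next else 1
    per_transition = copies * elements * bytes_per_element + overhead_bytes
    return per_transition * num_transitions


full = estimate_buffer_bytes((84, 84, 4), 1_000_000)
halved = estimate_buffer_bytes((84, 84, 4), 1_000_000, store_obs_next=False)
print(f"{full / 1e9:.1f} GB")    # 56.5 GB
print(f"{halved / 1e9:.1f} GB")  # 28.3 GB
```

Passing `bytes_per_element=4` shows how quickly the estimate grows if the same frames are stored as float32 instead of uint8.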

12.3 Configuration Best Practices#

When to use stack_num > 1:

  • RNN/LSTM policies need temporal context

  • Frame-based policies (Atari with 4-frame stacking)

  • Velocity estimation from positions

When to use ignore_obs_next=True:

  • Memory-constrained environments

  • Atari (large observation spaces)

  • When obs_next can be reconstructed from next obs

When to use save_only_last_obs=True:

  • Atari with temporal stacking in environment wrapper

  • When observations already contain frame history

When to use sample_avail=True:

  • Always use with stack_num > 1 for correctness

  • Ensures samples have complete frame stacks

  • Small performance cost but worth it for data quality

12.4 Integration Patterns#

Pattern 1: Standard Off-Policy Setup

# env = make_vectorized_env(num_envs=8)
# buffer = VectorReplayBuffer(total_size=100000, buffer_num=8)
# policy = SACPolicy(...)
# collector = Collector(policy, env, buffer)
# 
# # Collect and train
# collector.collect(n_step=1000)
# for _ in range(10):
#     batch, indices = buffer.sample(256)
#     policy.learn(batch)

Pattern 2: Pre-fill Buffer Before Training

# # Collect random exploration data
# collector.collect(n_step=10000)  # Fill buffer
# 
# # Then start training
# while not converged:
#     collector.collect(n_step=100)
#     for _ in range(10):
#         batch, indices = buffer.sample(256)
#         policy.learn(batch)

Pattern 3: Offline RL

# # Load pre-collected dataset
# buffer = ReplayBuffer.load_hdf5("expert_data.hdf5")
# 
# # Train without further collection
# for epoch in range(num_epochs):
#     for _ in range(updates_per_epoch):
#         batch, indices = buffer.sample(256)
#         policy.learn(batch)

12.5 Performance Tips#

Tip 1: Pre-allocate buffer size appropriately

  • Don’t make buffer too large (wastes memory)

  • Don’t make it too small (loses important old experiences)

  • Start with domain defaults and adjust based on performance

Tip 2: Use HDF5 for large offline datasets

  • Compression saves disk space

  • Often faster to load than pickle for large files, depending on compression and data layout

  • Better for sharing across systems

Tip 3: Batch sampling efficiently

  • Sample once and use multiple times if possible

  • Don’t sample more than you need

  • For multi-GPU training, sample once and split

Tip 4: Monitor buffer usage

# print(f"Buffer usage: {len(buffer)}/{buffer.maxsize}")
# if len(buffer) < batch_size:
#     print("Warning: each sampled batch will contain repeated transitions!")

Tip 5: Consider ignore_obs_next for large observation spaces

  • Can halve memory usage

  • Small computational overhead on sampling

  • Especially valuable for image-based RL

13. Quick Reference#

Method Summary#

| Method | Purpose | Returns | Notes |
|---|---|---|---|
| add(batch, buffer_ids) | Add transition(s) | (idx, ep_rew, ep_len, ep_start) | ep_rew/ep_len only non-zero when done=True |
| sample(size) | Random sample | (batch, indices) | size=None for all (random), 0 for all (ordered) |
| prev(idx) | Previous in episode | indices | Stops at episode boundaries |
| next(idx) | Next in episode | indices | Stops at episode boundaries |
| get(idx, key, stack_num) | Get with stacking | data | Returns stacked frames if stack_num > 1 |
| get_buffer_indices(start, stop) | Episode range | indices | Handles edge-crossing episodes |
| unfinished_index() | Ongoing episodes | indices | Returns last step of unfinished episodes |
| save_hdf5(path) | Save to HDF5 | - | Recommended for large datasets |
| load_hdf5(path) | Load from HDF5 | buffer | Class method |
| from_data(...) | Create from arrays | buffer | For offline RL datasets |
| reset() | Clear buffer | - | Optionally keep episode statistics |
| sample_indices(size) | Get indices only | indices | For custom sampling logic |

Common Patterns Cheatsheet#

Single Environment:

buffer = ReplayBuffer(size=10000)
buffer.add(Batch(obs=..., act=..., rew=..., terminated=..., truncated=..., obs_next=..., info={}))
batch, indices = buffer.sample(batch_size=256)

Parallel Environments:

buffer = VectorReplayBuffer(total_size=100000, buffer_num=8)
buffer.add(parallel_batch, buffer_ids=[0,1,2,3,4,5,6,7])
batch, indices = buffer.sample(batch_size=256)

Frame Stacking:

buffer = ReplayBuffer(size=100000, stack_num=4, sample_avail=True)
stacked_obs = buffer.get(index=50, key="obs")  # Returns 4 stacked frames

Prioritized Replay:

buffer = PrioritizedReplayBuffer(size=100000, alpha=0.6, beta=0.4)
batch, indices = buffer.sample(batch_size=256)
weights = batch.weight  # Importance weights are inside the batch
# ... compute TD errors ...
buffer.update_weight(indices, td_errors)

Offline RL:

buffer = ReplayBuffer.load_hdf5("dataset.hdf5")
# Or:
with h5py.File("dataset.hdf5", "r") as f:
    buffer = ReplayBuffer.from_data(obs=f["obs"], act=f["act"], ...)

Episode Retrieval:

# Find episode boundaries, then:
episode_indices = buffer.get_buffer_indices(start=ep_start_idx, stop=ep_end_idx+1)
episode = buffer[episode_indices]

Summary and Next Steps#

This tutorial covered Tianshou’s buffer system comprehensively:

  1. Buffer fundamentals: Why buffers are essential for RL

  2. Buffer hierarchy: Understanding different buffer types

  3. Basic operations: Construction, configuration, and data management

  4. Trajectory management: Episode tracking and boundary navigation

  5. Sampling strategies: Basic sampling and frame stacking

  6. VectorReplayBuffer: Critical for parallel environments

  7. Specialized buffers: Prioritized, cached, and HER variants

  8. Serialization: Pickle and HDF5 persistence

  9. Integration: How buffers fit in the RL pipeline

  10. Advanced topics: Edge cases and overflow handling

  11. Gotchas: Common mistakes and how to avoid them

  12. Best practices: Configuration, sizing, and performance

  13. Quick reference: Method summary and common patterns

Next Steps#

  • Collector Deep Dive: Learn how Collector fills buffers from environments

  • Policy Tutorial: Understand how policies sample from buffers for training

  • Algorithm Examples: See buffer usage in specific algorithms (DQN, SAC, PPO)

  • API Reference: Full details are in the Buffer API documentation

Further Resources#