Batch: Tianshou’s Core Data Structure#

The Batch class is Tianshou’s fundamental data structure for efficiently storing and manipulating heterogeneous data in reinforcement learning. This tutorial covers its conceptual foundations, operational behavior, and best practices.

import pickle
from typing import cast

import numpy as np
import torch
from torch.distributions import Categorical, Normal

from tianshou.data import Batch
from tianshou.data.types import ActBatchProtocol, ObsBatchProtocol, RolloutBatchProtocol

1. Introduction: Why Batch?#

The Challenge in Reinforcement Learning#

Reinforcement learning algorithms face a fundamental data management challenge:

  1. Diverse Data Requirements: Different RL algorithms need different data fields:

    • Basic algorithms: state, action, reward, done, next_state

    • Actor-Critic: additionally advantages, returns, values

    • Policy Gradient: additionally log_probs, old_log_probs

    • Off-policy: additionally priority_weights

  2. Heterogeneous Observation Spaces: Environments return diverse observation types:

    • Simple: vectors (np.array([1.0, 2.0, 3.0]))

    • Complex: images (np.zeros((84, 84, 3)))

    • Hybrid: dictionaries combining multiple modalities

    obs = {
        'camera': np.zeros((64, 64, 3)),
        'velocity': np.array([1.2, 0.5]),
        'inventory': np.array([5, 2, 0])
    }
    
  3. Data Flow Across Components: Data must flow seamlessly through:

    • Collectors (gathering experience from environments)

    • Replay Buffers (storing and sampling transitions)

    • Policies and Algorithms (learning and inference)

Why Not Alternatives?#

Plain Dictionaries#

Dictionaries lack essential features:

data = {'obs': np.array([1, 2]), 'reward': np.array([1.0, 2.0])}

They work in principle but have no shape or length semantics, no batch-wise indexing, and no type safety.
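
To make the gap concrete, here is a minimal sketch of what indexing a plain dictionary of sequences involves. The helper `index_dict` is hypothetical and written only for illustration; it is not part of Tianshou, and plain lists stand in for arrays.

```python
# Hypothetical helper, for illustration only: a plain dict has no notion of
# "row i", so indexing every field has to be written by hand.
def index_dict(data, i):
    """Return a new dict holding element i of every field."""
    return {key: values[i] for key, values in data.items()}


data = {"obs": [1, 2, 3], "reward": [1.0, 2.0, 3.0]}

first = index_dict(data, 0)
print(first)  # {'obs': 1, 'reward': 1.0}

# len() counts keys, not stored transitions.
print(len(data))  # 2
```

Every consumer of such a dictionary would need helpers like this one, which is the bookkeeping that Batch takes over.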

TensorDict#

TensorDict (used in TorchRL) is a powerful alternative, but Batch differs in several respects:

  • Batch supports arbitrary objects, not just tensors (useful for object-dtype arrays, custom types)

  • Batch has better type checking via BatchProtocol (enables IDE autocompletion)

  • Batch preceded TensorDict and provides a stable foundation for Tianshou

  • TensorDict isn’t part of core PyTorch (external dependency)

What is Batch?#

Batch = Dictionary + Array hybrid with RL-specific features

Key capabilities:

  • Dict-like: Key-value storage with attribute access (batch.obs, batch.reward)

  • Array-like: Shape, indexing, slicing (batch[0], batch[:10], batch.shape)

  • Hierarchical: Nested structures for complex data

  • Type-safe: Protocol-based typing for IDE support

  • RL-aware: Special handling for distributions, missing values, heterogeneous aggregation

2. Core Concepts#

Hierarchical Named Tensors#

Batch stores hierarchical named tensors - collections of tensors whose identifiers form a structured hierarchy. Consider tensors [t1, t2, t3, t4] with names [name1, name2, name3, name4], where name1 and name2 are under namespace name0. The fully qualified name of t1 is name0.name1.
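
The naming scheme described above can be sketched with nested dictionaries. The function `qualified_names` is a hypothetical helper, and the strings `"t1"` to `"t4"` are placeholders for tensors; nothing here relies on Tianshou itself.

```python
def qualified_names(tree, prefix=""):
    """Yield (fully_qualified_name, leaf) pairs from a nested dict."""
    for name, value in tree.items():
        full_name = f"{prefix}.{name}" if prefix else name
        if isinstance(value, dict):
            # A dict is a namespace: descend and extend the prefix.
            yield from qualified_names(value, full_name)
        else:
            yield full_name, value


# name1 and name2 live under the namespace name0, as in the text above.
tree = {"name0": {"name1": "t1", "name2": "t2"}, "name3": "t3", "name4": "t4"}
names = [name for name, _ in qualified_names(tree)]
print(names)  # ['name0.name1', 'name0.name2', 'name3', 'name4']
```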

Tree Structure Visualization#

The structure can be visualized as a tree with:

  • Root: The Batch object itself

  • Internal nodes: Keys (names)

  • Leaf nodes: Values (scalars, arrays, tensors)

        graph TD
    root["Batch (root)"]
    root --> obs["obs"]
    root --> act["act"]
    root --> rew["rew"]
    obs --> camera["camera"]
    obs --> sensory["sensory"]
    camera --> cam_data["np.array(3,3)"]
    sensory --> sens_data["np.array(5,)"]
    act --> act_data["np.array(2,)"]
    rew --> rew_data["3.66"]
    
    style root fill:#e1f5ff
    style obs fill:#fff4e1
    style act fill:#fff4e1
    style rew fill:#fff4e1
    style camera fill:#ffe1f5
    style sensory fill:#ffe1f5
    style cam_data fill:#e8f5e1
    style sens_data fill:#e8f5e1
    style act_data fill:#e8f5e1
    style rew_data fill:#e8f5e1
    
# Example: hierarchical structure
data = {
    "action": np.array([1.0, 2.0, 3.0]),
    "reward": 3.66,
    "obs": {
        "camera": np.zeros((3, 3)),
        "sensory": np.ones(5),
    },
}

batch = Batch(data)
print(batch)
print("\nAccessing nested values:")
print(f"batch.obs.camera.shape = {batch.obs.camera.shape}")
print(f"batch.obs.sensory = {batch.obs.sensory}")
Batch(
    action: array([1., 2., 3.]),
    reward: array(3.66),
    obs: Batch(
             camera: array([[0., 0., 0.],
                            [0., 0., 0.],
                            [0., 0., 0.]]),
             sensory: array([1., 1., 1., 1., 1.]),
         ),
)

Accessing nested values:
batch.obs.camera.shape = (3, 3)
batch.obs.sensory = [1. 1. 1. 1. 1.]

Data Flow in RL Pipeline#

Batch facilitates data flow throughout the RL pipeline:

        graph LR
    A[Collector] -->|ActBatchProtocol| B[Environment]
    B[Environment + Action] -->|RolloutBatchProtocol| C[Replay Buffer]
    C -->|RolloutBatchProtocol| D[Policy]
    D -->|ActBatchProtocol| A
    D -->|BatchWithAdvantages| E[Algorithm/Trainer]
    E --> D
    
    style A fill:#e1f5ff
    style B fill:#fff4e1
    style C fill:#ffe1f5
    style D fill:#e8f5e1
    style E fill:#f5e1e1
    

Each arrow represents a specific BatchProtocol that defines what fields are expected at that stage.

3. Basic Operations#

3.1 Construction#

Batch objects can be constructed in several ways:

# From keyword arguments
batch1 = Batch(a=4, b=[5, 5], c="hello")
print("From kwargs:", batch1)

# From dictionary
batch2 = Batch({"a": 4, "b": [5, 5], "c": "hello"})
print("\nFrom dict:", batch2)

# From list of dictionaries (automatically stacked)
batch3 = Batch([{"a": 1, "b": 2}, {"a": 3, "b": 4}])
print("\nFrom list of dicts:", batch3)

# Nested batch
batch4 = Batch(obs=Batch(x=1, y=2), act=5)
print("\nNested:", batch4)
From kwargs: Batch(
    a: array(4),
    b: array([5, 5]),
    c: 'hello',
)

From dict: Batch(
    a: array(4),
    b: array([5, 5]),
    c: 'hello',
)

From list of dicts: Batch(
    a: array([1, 3]),
    b: array([2, 4]),
)

Nested: Batch(
    obs: Batch(
             x: array(1),
             y: array(2),
         ),
    act: array(5),
)

3.2 Content Rules#

Understanding what Batch can store and how it converts data:

# Keys must be strings
batch = Batch()
batch.key1 = "value"
batch.key2 = np.array([1, 2, 3])
print("Keys:", list(batch.keys()))

# Automatic conversions
demo = Batch(
    scalar_int=5,  # → np.array(5)
    scalar_float=3.14,  # → np.array(3.14)
    list_nums=[1, 2, 3],  # → np.array([1, 2, 3])
    list_mixed=[1, "hello", None],  # → np.array([1, "hello", None], dtype=object)
    dict_val={"x": 1, "y": 2},  # → Batch(x=1, y=2)
)

print("\nAutomatic conversions:")
print(f"scalar_int type: {type(demo.scalar_int)}, value: {demo.scalar_int}")
print(f"list_nums type: {type(demo.list_nums)}, dtype: {demo.list_nums.dtype}")
print(f"list_mixed dtype: {demo.list_mixed.dtype}")
print(f"dict_val type: {type(demo.dict_val)}")
Keys: ['key1', 'key2']

Automatic conversions:
scalar_int type: <class 'numpy.ndarray'>, value: 5
list_nums type: <class 'numpy.ndarray'>, dtype: int64
list_mixed dtype: object
dict_val type: <class 'tianshou.data.batch.Batch'>

Important conversions:

  • Lists of numbers → NumPy arrays

  • Lists with mixed types → Object-dtype arrays

  • Dictionaries → Batch objects (recursively)

  • Scalars → 0-dimensional NumPy arrays (e.g. np.array(5))

3.3 Access Patterns#

Important: Understanding Iteration

batch = Batch(a=[1, 2, 3], b=[4, 5, 6])

# Attribute vs dictionary access (equivalent)
print("Attribute access:", batch.a)
print("Dict access:", batch["a"])

# Getting keys
print("\nKeys:", list(batch.keys()))

# Gotcha: iteration is array-like (over indices), not over keys
print("\nIteration behavior:")
print("for x in batch iterates over batch[0], batch[1], ..., NOT keys!")
for i, item in enumerate(batch):
    print(f"batch[{i}] = {item}")

# This is different from dict behavior!
regular_dict = {"a": [1, 2, 3], "b": [4, 5, 6]}
print("\nCompare with dict iteration (iterates over keys):")
for key in regular_dict:
    print(f"key = {key}")
Attribute access: [1 2 3]
Dict access: [1 2 3]

Keys: ['a', 'b']

Iteration behavior:
for x in batch iterates over batch[0], batch[1], ..., NOT keys!
batch[0] = Batch(
    a: 1,
    b: 4,
)
batch[1] = Batch(
    a: 2,
    b: 5,
)
batch[2] = Batch(
    a: 3,
    b: 6,
)

Compare with dict iteration (iterates over keys):
key = a
key = b
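
Why iteration behaves this way is easier to see in a stripped-down model. The class `TinyBatch` below is a hypothetical stand-in written for this tutorial, not Tianshou's implementation: it stores fields by key but defines `__iter__` in terms of positions, as an array would.

```python
class TinyBatch:
    """Illustrative stand-in (not Tianshou's Batch): dict storage, array-like iteration."""

    def __init__(self, **fields):
        self._fields = fields

    def keys(self):
        return self._fields.keys()

    def __len__(self):
        return min(len(values) for values in self._fields.values())

    def __getitem__(self, index):
        # Indexing selects the same position in every field.
        return {key: values[index] for key, values in self._fields.items()}

    def __iter__(self):
        # Iterate over positions, the way an array does, not over keys.
        return (self[i] for i in range(len(self)))


tiny = TinyBatch(a=[1, 2, 3], b=[4, 5, 6])
rows = list(tiny)
print(rows)  # [{'a': 1, 'b': 4}, {'a': 2, 'b': 5}, {'a': 3, 'b': 6}]
print(list(tiny.keys()))  # ['a', 'b']
```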

3.4 Indexing & Slicing#

Batch supports NumPy-like indexing and slicing:

batch = Batch(a=np.array([[0.0, 2.0], [1.0, 3.0]]), b=[[5.0, -5.0], [1.0, -2.0]])

print("Original batch shape:", batch.shape)
print("Original batch length:", len(batch))

# Single index
print("\nbatch[0]:")
print(batch[0])

# Slicing
print("\nbatch[:1]:")
print(batch[:1])

# Advanced indexing
print("\nbatch[[0, 1]]:")
print(batch[[0, 1]])

# Multi-dimensional indexing
print("\nbatch[:, 0] (first column of all arrays):")
print(batch[:, 0])
Original batch shape: [2, 2]
Original batch length: 2

batch[0]:
Batch(
    a: array([0., 2.]),
    b: array([ 5., -5.]),
)

batch[:1]:
Batch(
    a: array([[0., 2.]]),
    b: array([[ 5., -5.]]),
)

batch[[0, 1]]:
Batch(
    a: array([[0., 2.],
              [1., 3.]]),
    b: array([[ 5., -5.],
              [ 1., -2.]]),
)

batch[:, 0] (first column of all arrays):
Batch(
    a: array([0., 1.]),
    b: array([5., 1.]),
)
# Broadcasting and in-place operations
batch[:, 1] += 10
print("After batch[:, 1] += 10:")
print(batch)
After batch[:, 1] += 10:
Batch(
    a: array([[ 0., 12.],
              [ 1., 13.]]),
    b: array([[5., 5.],
              [1., 8.]]),
)

3.5 Stack, Concatenate, and Split#

Combining and splitting batches:

# Stack: adds a new dimension
batch1 = Batch(a=np.array([1, 2]), b=np.array([5, 6]))
batch2 = Batch(a=np.array([3, 4]), b=np.array([7, 8]))

stacked = Batch.stack([batch1, batch2])
print("Stacked:")
print(stacked)
print(f"Shape: {stacked.shape}")

# Concatenate: extends along existing dimension
concatenated = Batch.cat([batch1, batch2])
print("\nConcatenated:")
print(concatenated)
print(f"Shape: {concatenated.shape}")
Stacked:
Batch(
    a: array([[1, 2],
              [3, 4]]),
    b: array([[5, 6],
              [7, 8]]),
)
Shape: [2, 2]

Concatenated:
Batch(
    a: array([1, 2, 3, 4]),
    b: array([5, 6, 7, 8]),
)
Shape: [4]
# Split
batch = Batch(a=np.arange(10), b=np.arange(10, 20))
splits = list(batch.split(size=3, shuffle=False))
print(f"Split into {len(splits)} batches:")
for i, split in enumerate(splits):
    print(f"Split {i}: a={split.a}, length={len(split)}")
Split into 4 batches:
Split 0: a=[0 1 2], length=3
Split 1: a=[3 4 5], length=3
Split 2: a=[6 7 8], length=3
Split 3: a=[9], length=1
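
The chunk sizes in the output above follow from simple arithmetic: ten elements in chunks of three give three full chunks and a remainder of one. The helper `chunk_bounds` is a hypothetical sketch of that arithmetic for the unshuffled case only; it makes no claim about how split is implemented.

```python
def chunk_bounds(length, size):
    """Return (start, stop) pairs covering range(length) in chunks of at most `size`."""
    return [(start, min(start + size, length)) for start in range(0, length, size)]


bounds = chunk_bounds(10, 3)
print(bounds)  # [(0, 3), (3, 6), (6, 9), (9, 10)]

# The last chunk holds only the remainder.
print([stop - start for start, stop in bounds])  # [3, 3, 3, 1]
```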

3.6 Data Type Conversion#

Converting between NumPy and PyTorch:

# Create batch with NumPy arrays
batch = Batch(a=np.zeros((3, 4)), b=np.ones(5))
print("Original (NumPy):")
print(f"batch.a type: {type(batch.a)}")

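# Convert to PyTorch (in-place). The output below reports float64 even though
# float32 is requested, so check the resulting dtype rather than assuming it.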
batch.to_torch_(dtype=torch.float32, device="cpu")
print("\nAfter to_torch_():")
print(f"batch.a type: {type(batch.a)}")
print(f"batch.a dtype: {batch.a.dtype}")

# Convert back to NumPy (in-place)
batch.to_numpy_()
print("\nAfter to_numpy_():")
print(f"batch.a type: {type(batch.a)}")

# Non-in-place versions return a new batch
batch_torch = batch.to_torch()
print("\nOriginal batch unchanged:", type(batch.a))
print("New batch:", type(batch_torch.a))
Original (NumPy):
batch.a type: <class 'numpy.ndarray'>

After to_torch_():
batch.a type: <class 'torch.Tensor'>
batch.a dtype: torch.float64

After to_numpy_():
batch.a type: <class 'numpy.ndarray'>

Original batch unchanged: <class 'numpy.ndarray'>
New batch: <class 'torch.Tensor'>

4. Type Safety with Protocols#

Why Protocols?#

Batch needs to be flexible (no fixed fields, unlike dataclasses), but we still want type safety and IDE autocompletion. Protocols provide the best of both worlds:

  • Runtime flexibility: Add any fields dynamically

  • Static type checking: Type checkers (mypy, pyright) verify correct usage

  • IDE support: Autocompletion for expected fields

What is BatchProtocol?#

A Protocol defines an interface without implementation. Think of it as a contract: “any object with these fields is valid.”

# Creating a typed batch using cast
# This enables IDE autocompletion and type checking

# ActBatchProtocol: just needs 'act' field
act_batch = cast(ActBatchProtocol, Batch(act=np.array([1, 2, 3])))
print("ActBatchProtocol:", act_batch.act)

# ObsBatchProtocol: needs 'obs' and 'info' fields
obs_batch = cast(
    ObsBatchProtocol,
    Batch(obs=np.array([[1.0, 2.0], [3.0, 4.0]]), info=np.array([{}, {}], dtype=object)),
)
print("\nObsBatchProtocol:", obs_batch.obs)

# RolloutBatchProtocol: needs obs, obs_next, act, rew, terminated, truncated
rollout_batch = cast(
    RolloutBatchProtocol,
    Batch(
        obs=np.array([[1.0, 2.0], [3.0, 4.0]]),
        obs_next=np.array([[2.0, 3.0], [4.0, 5.0]]),
        act=np.array([0, 1]),
        rew=np.array([1.0, 2.0]),
        terminated=np.array([False, True]),
        truncated=np.array([False, False]),
        info=np.array([{}, {}], dtype=object),
    ),
)
print("\nRolloutBatchProtocol reward:", rollout_batch.rew)
ActBatchProtocol: [1 2 3]

ObsBatchProtocol: [[1. 2.]
 [3. 4.]]

RolloutBatchProtocol reward: [1. 2.]
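
The mechanism behind these casts is structural typing from the standard library's `typing.Protocol`. The sketch below uses only the standard library; `HasAct`, `Container`, and `first_action` are hypothetical names chosen for illustration and do not exist in Tianshou.

```python
from typing import Protocol, runtime_checkable


@runtime_checkable
class HasAct(Protocol):
    """Hypothetical protocol: anything exposing an `act` attribute conforms."""

    act: list


class Container:
    """Unrelated class; it never inherits from HasAct."""

    def __init__(self, act):
        self.act = act


def first_action(batch: HasAct):
    # A static type checker accepts any object whose attributes match HasAct.
    return batch.act[0]


container = Container(act=[1, 2, 3])
print(first_action(container))  # 1
print(isinstance(container, HasAct))  # True
```

No inheritance is involved: conformance is decided by the attributes an object has, which is why a dynamically populated Batch can be cast to a protocol type.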

Protocol Hierarchy#

Tianshou defines a hierarchy of protocols for different use cases:

        graph TD
    BP[BatchProtocol<br/>Base protocol] --> OBP[ObsBatchProtocol<br/>obs, info]
    BP --> ABP[ActBatchProtocol<br/>act]
    ABP --> ASBP[ActStateBatchProtocol<br/>act, state]
    OBP --> RBP[RolloutBatchProtocol<br/>+obs_next, act, rew,<br/>terminated, truncated]
    RBP --> BWRP[BatchWithReturnsProtocol<br/>+returns]
    BWRP --> BWAP[BatchWithAdvantagesProtocol<br/>+adv, v_s]
    ASBP --> MOBP[ModelOutputBatchProtocol<br/>+logits]
    MOBP --> DBP[DistBatchProtocol<br/>+dist]
    DBP --> DLPBP[DistLogProbBatchProtocol<br/>+log_prob]
    BWAP --> LOPBP[LogpOldProtocol<br/>+logp_old]
    
    style BP fill:#e1f5ff
    style OBP fill:#fff4e1
    style ABP fill:#fff4e1
    style RBP fill:#ffe1f5
    style BWRP fill:#e8f5e1
    style BWAP fill:#e8f5e1
    style DBP fill:#f5e1e1
    style LOPBP fill:#e1e1f5
    

Using Protocols in Functions#

Protocols enable type-safe function signatures:

def process_observations(batch: ObsBatchProtocol) -> np.ndarray:
    """Function that expects observations.

    IDE will autocomplete batch.obs and batch.info!
    Type checker will verify these fields exist.
    """
    # IDE knows batch.obs exists
    return batch.obs if isinstance(batch.obs, np.ndarray) else np.array(batch.obs)


def compute_advantage(batch: RolloutBatchProtocol) -> np.ndarray:
    """Function that expects rollout data.

    IDE will autocomplete batch.rew, batch.obs_next, etc.
    """
    # Simplified advantage computation
    return batch.rew  # IDE knows this exists


# Example usage
obs_data = Batch(obs=np.array([1, 2, 3]), info=np.array([{}], dtype=object))
result = process_observations(obs_data)
print("Processed obs:", result)
Processed obs: [1 2 3]

Key Protocol Types:

  • ActBatchProtocol: Just actions (for simple policies)

  • ObsBatchProtocol: Observations and info

  • RolloutBatchProtocol: Complete transitions (obs, obs_next, act, rew, terminated, truncated)

  • BatchWithReturnsProtocol: Rollouts + computed returns

  • BatchWithAdvantagesProtocol: Returns + advantages and values

  • DistBatchProtocol: Contains distribution objects

  • LogpOldProtocol: For importance sampling (PPO, etc.)

See tianshou/data/types.py for the complete list!

5. Distribution Slicing#

Why Special Handling?#

PyTorch Distribution objects need special slicing because they’re not simple arrays. When you slice batch[0:2], Tianshou needs to slice the underlying distribution parameters correctly.

Supported Distributions#

Tianshou supports slicing for:

  • Categorical: Discrete distributions

  • Normal: Continuous Gaussian distributions

  • Independent: Wraps other distributions

# Categorical distribution
probs = torch.tensor([[0.3, 0.7], [0.4, 0.6], [0.5, 0.5]])
dist = Categorical(probs=probs)
batch = Batch(dist=dist, values=np.array([1, 2, 3]))

print("Original batch length:", len(batch))
print("Original dist probs shape:", batch.dist.probs.shape)

# Slicing automatically handles the distribution
sliced = batch[0:2]
print("\nSliced batch length:", len(sliced))
print("Sliced dist probs shape:", sliced.dist.probs.shape)
print("Sliced values:", sliced.values)
Original batch length: 3
Original dist probs shape: torch.Size([3, 2])

Sliced batch length: 2
Sliced dist probs shape: torch.Size([2, 2])
Sliced values: [1 2]
# Normal distribution
loc = torch.tensor([0.0, 1.0, 2.0])
scale = torch.tensor([1.0, 1.0, 1.0])
normal_dist = Normal(loc=loc, scale=scale)
batch_normal = Batch(dist=normal_dist, actions=np.array([0.5, 1.5, 2.5]))

print("Normal distribution batch:")
print(f"Original mean: {batch_normal.dist.mean}")

# Index a single element
single = batch_normal[1]
print(f"\nSingle element mean: {single.dist.mean}")
print(f"Single element action: {single.actions}")
Normal distribution batch:
Original mean: tensor([0., 1., 2.])

Single element mean: 1.0
Single element action: 1.5
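
Conceptually, slicing a distribution means applying the same index to each of its parameters. The dataclass `SimpleNormal` below is a hypothetical, pure-Python stand-in that illustrates this idea with lists; it is not how Tianshou or PyTorch implement it.

```python
from dataclasses import dataclass


@dataclass
class SimpleNormal:
    """Hypothetical stand-in for a batched Normal distribution."""

    loc: object  # list of floats, or a single float after integer indexing
    scale: object

    def __getitem__(self, index):
        # Slicing the distribution means slicing each parameter the same way.
        return SimpleNormal(self.loc[index], self.scale[index])


dist = SimpleNormal(loc=[0.0, 1.0, 2.0], scale=[1.0, 1.0, 1.0])

sliced = dist[0:2]
print(sliced.loc)  # [0.0, 1.0]

single = dist[1]
print(single.loc)  # 1.0
```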

Converting to At Least 2D#

Sometimes you need to ensure distributions have a batch dimension:

from tianshou.data.batch import dist_to_atleast_2d

# Scalar distribution (no batch dimension)
scalar_dist = Categorical(probs=torch.tensor([0.3, 0.7]))
print("Scalar dist batch_shape:", scalar_dist.batch_shape)

# Convert to have batch dimension
batched_dist = dist_to_atleast_2d(scalar_dist)
print("Batched dist batch_shape:", batched_dist.batch_shape)

# For entire batch
scalar_batch = Batch(a=1, b=2, dist=Categorical(probs=torch.ones(3)))
print("\nBefore to_at_least_2d:", scalar_batch.dist.batch_shape)

batch_2d = scalar_batch.to_at_least_2d()
print("After to_at_least_2d:", batch_2d.dist.batch_shape)
Scalar dist batch_shape: torch.Size([])
Batched dist batch_shape: torch.Size([1])

Before to_at_least_2d: torch.Size([])
After to_at_least_2d: torch.Size([1])

Use Cases#

Distribution slicing is used in:

  • Policy sampling: When policies output distributions, slicing batches preserves distribution structure

  • Replay buffer sampling: Distributions are stored and retrieved correctly

  • Advantage computation: Computing log probabilities on subsets of data

6. Advanced Topics#

6.1 Key Reservation#

Sometimes you know what keys you’ll need but don’t have values yet. Reserve keys using empty Batch() objects:

        graph TD
    root["Batch"]
    root --> a["key1: np.array([1,2,3])"]
    root --> b["key2: Batch() (reserved)"]
    root --> c["key3"]
    c --> c1["subkey1: Batch() (reserved)"]
    c --> c2["subkey2: np.array([4,5])"]
    
    style root fill:#e1f5ff
    style a fill:#e8f5e1
    style b fill:#ffcccc
    style c fill:#fff4e1
    style c1 fill:#ffcccc
    style c2 fill:#e8f5e1
    
# Reserving keys
batch = Batch(
    known_field=np.array([1, 2]),
    future_field=Batch(),  # Reserved for later
)
print("Batch with reserved key:")
print(batch)

# Later, assign actual data
batch.future_field = np.array([3, 4])
print("\nAfter assignment:")
print(batch)

# Nested reservation
batch2 = Batch(
    obs=Batch(
        camera=Batch(),  # Reserved
        lidar=np.zeros(10),
    )
)
print("\nNested reservation:")
print(batch2)
Batch with reserved key:
Batch(
    known_field: array([1, 2]),
    future_field: Batch(),
)

After assignment:
Batch(
    known_field: array([1, 2]),
    future_field: array([3, 4]),
)

Nested reservation:
Batch(
    obs: Batch(
             camera: Batch(),
             lidar: array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]),
         ),
)

6.2 Length and Shape Semantics#

Understanding when len() works and what shape means:

# Normal case: all tensors same length
batch1 = Batch(a=[1, 2, 3], b=np.array([4, 5, 6]))
print("Normal batch:")
print(f"len(batch1) = {len(batch1)}")
print(f"batch1.shape = {batch1.shape}")

# Scalars have no length
batch2 = Batch(a=5, b=10)
print("\nScalar batch:")
print(f"batch2.shape = {batch2.shape}")
try:
    print(f"len(batch2) = {len(batch2)}")
except TypeError as e:
    print(f"len(batch2) raises TypeError: {e}")

# Mixed lengths: returns minimum
batch3 = Batch(a=[1, 2], b=[3, 4, 5])
print("\nMixed length batch:")
print(f"len(batch3) = {len(batch3)} (minimum of 2 and 3)")

# Reserved keys are ignored
batch4 = Batch(a=[1, 2, 3], reserved=Batch())
print("\nBatch with reserved key:")
print(f"len(batch4) = {len(batch4)} (reserved key ignored)")
Normal batch:
len(batch1) = 3
batch1.shape = [3]

Scalar batch:
batch2.shape = []
len(batch2) raises TypeError: Entry for a in Batch(
    a: array(5),
    b: array(10),
) is 5 has no len()

Mixed length batch:
len(batch3) = 2 (minimum of 2 and 3)

Batch with reserved key:
len(batch4) = 3 (reserved key ignored)
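
The rules demonstrated above can be summarized in a few lines. The function `sketch_len` is a hypothetical model over plain dictionaries, with an empty dict playing the role of a reserved key; it reproduces the observed behavior and is not Tianshou's actual code.

```python
def sketch_len(fields):
    """Mimic the length rules shown above for a dict of fields (illustration only)."""
    lengths = []
    for key, value in fields.items():
        if isinstance(value, dict) and not value:
            continue  # an empty dict plays the role of a reserved key
        if not hasattr(value, "__len__"):
            raise TypeError(f"entry {key!r} is a scalar and has no len()")
        lengths.append(len(value))
    # With no sized entries at all, report length 0.
    return min(lengths) if lengths else 0


print(sketch_len({"a": [1, 2, 3], "b": [4, 5, 6]}))  # 3
print(sketch_len({"a": [1, 2], "b": [3, 4, 5]}))  # 2
print(sketch_len({"a": [1, 2, 3], "reserved": {}}))  # 3
```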

6.3 Empty Batches#

Understanding different meanings of “empty”:

# 1. No keys at all
empty1 = Batch()
print("No keys:")
print(f"len(empty1.get_keys()) = {len(list(empty1.get_keys()))}")
print(f"len(empty1) = {len(empty1)}")

# 2. Has keys but they're all reserved
empty2 = Batch(a=Batch(), b=Batch())
print("\nReserved keys only:")
print(f"len(empty2.get_keys()) = {len(list(empty2.get_keys()))}")
print(f"len(empty2) = {len(empty2)}")

# 3. Has data but length is 0
empty3 = Batch(a=np.array([]), b=np.array([]))
print("\nZero-length arrays:")
print(f"len(empty3.get_keys()) = {len(list(empty3.get_keys()))}")
print(f"len(empty3) = {len(empty3)}")
No keys:
len(empty1.get_keys()) = 0
len(empty1) = 0

Reserved keys only:
len(empty2.get_keys()) = 2
len(empty2) = 0

Zero-length arrays:
len(empty3.get_keys()) = 2
len(empty3) = 0

Checking emptiness:

  • len(batch.get_keys()) == 0: No keys (completely empty)

  • len(batch) == 0: No data elements (may have reserved keys)

The .empty() and .empty_() methods do not test for emptiness. They reset values, to 0 for numeric arrays and to None for object arrays:

batch = Batch(a=[1, 2, 3], b=["x", "y", "z"])
print("Original:", batch)

# Empty specific index
batch[0] = Batch.empty(batch[0])
print("\nAfter emptying index 0:")
print(batch)
Original: Batch(
    a: array([1, 2, 3]),
    b: array(['x', 'y', 'z'], dtype=object),
)

After emptying index 0:
Batch(
    a: array([0, 2, 3]),
    b: array([None, 'y', 'z'], dtype=object),
)
UserWarning: You are calling Batch.empty on a NumPy scalar, which may cause undefined behaviors.

6.4 Heterogeneous Aggregation#

Stacking/concatenating batches with different keys:

        graph LR
    A["Batch(a=[1,2], c=5)"] --> C["Batch.stack"]
    B["Batch(b=[3,4], c=6)"] --> C
    C --> D["Batch(a=[[1,2],[0,0]],<br/>b=[[0,0],[3,4]],<br/>c=[5,6])"]
    
    style A fill:#e1f5ff
    style B fill:#fff4e1
    style C fill:#ffe1f5
    style D fill:#e8f5e1
    
# Stack with different keys (missing keys padded with zeros)
batch_a = Batch(a=np.ones((2, 3)), shared=np.array([1, 2]))
batch_b = Batch(b=np.zeros((2, 4)), shared=np.array([3, 4]))

stacked = Batch.stack([batch_a, batch_b])
print("Stacked batch:")
print(f"a.shape = {stacked.a.shape} (padded with zeros for batch_b)")
print(f"b.shape = {stacked.b.shape} (padded with zeros for batch_a)")
print(f"shared.shape = {stacked.shared.shape} (in both batches)")
print(stacked)
Stacked batch:
a.shape = (2, 2, 3) (padded with zeros for batch_b)
b.shape = (2, 2, 4) (padded with zeros for batch_a)
shared.shape = (2, 2) (in both batches)
Batch(
    shared: array([[1, 2],
                   [3, 4]]),
    a: array([[[1., 1., 1.],
               [1., 1., 1.]],
       
              [[0., 0., 0.],
               [0., 0., 0.]]]),
    b: array([[[0., 0., 0., 0.],
               [0., 0., 0., 0.]],
       
              [[0., 0., 0., 0.],
               [0., 0., 0., 0.]]]),
)
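
The padding rule can be modeled on plain lists. The function `stack_with_padding` is a hypothetical sketch that assumes every value is a flat list of numbers and that a key has the same length wherever it appears; it only illustrates the zero-filling idea and is not Tianshou's implementation.

```python
def stack_with_padding(rows):
    """Stack dicts of equal-length number lists, zero-filling keys a row lacks."""
    # Record each key's width from the first row that contains it.
    widths = {}
    for row in rows:
        for key, values in row.items():
            widths.setdefault(key, len(values))
    return {
        key: [row.get(key, [0] * width) for row in rows]
        for key, width in widths.items()
    }


stacked_rows = stack_with_padding([{"a": [1, 2]}, {"b": [3, 4]}])
print(stacked_rows)  # {'a': [[1, 2], [0, 0]], 'b': [[0, 0], [3, 4]]}
```

One consequence worth keeping in mind: after stacking, a padded zero is indistinguishable from a genuine zero value.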

6.5 Missing Values#

Handling None and NaN values:

# Batch with missing values
batch = Batch(a=[1, 2, None, 4], b=[5.0, np.nan, 7.0, 8.0], c=[[1, 2], [3, 4], [5, 6], [7, 8]])

# Check for nulls
print("Has null?", batch.hasnull())

# Get null mask
null_mask = batch.isnull()
print("\nNull mask:")
print(f"a: {null_mask.a}")
print(f"b: {null_mask.b}")

# Drop rows with any null
clean_batch = batch.dropnull()
print("\nAfter dropnull() (keeps rows 0 and 3):")
print(f"Length: {len(clean_batch)}")
print(f"a: {clean_batch.a}")
print(f"b: {clean_batch.b}")
Has null? True

Null mask:
a: [False False  True False]
b: [False  True False False]

After dropnull() (keeps rows 0 and 3):
Length: 2
a: [1 4]
b: [5. 8.]
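
The row-wise logic of dropping nulls can be written out by hand. The helpers `is_null` and `drop_null_rows` are hypothetical and operate on plain lists; they sketch the idea that a row is removed if any column is missing at that position, without claiming to match Tianshou's implementation.

```python
import math


def is_null(value):
    """Treat None and float NaN as missing."""
    return value is None or (isinstance(value, float) and math.isnan(value))


def drop_null_rows(columns):
    """Keep only the row positions where no column holds a missing value."""
    n_rows = len(next(iter(columns.values())))
    keep = [
        i for i in range(n_rows)
        if not any(is_null(values[i]) for values in columns.values())
    ]
    return {key: [values[i] for i in keep] for key, values in columns.items()}


columns = {"a": [1, 2, None, 4], "b": [5.0, float("nan"), 7.0, 8.0]}
clean = drop_null_rows(columns)
print(clean)  # {'a': [1, 4], 'b': [5.0, 8.0]}
```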

6.6 Value Transformations#

Applying functions to all values recursively:

batch = Batch(a=np.array([1, 2, 3]), nested=Batch(b=np.array([4.0, 5.0]), c=np.array([6, 7, 8])))

# Apply transformation (returns new batch)
doubled = batch.apply_values_transform(lambda x: x * 2)
print("Original batch a:", batch.a)
print("Doubled batch a:", doubled.a)
print("Doubled nested.b:", doubled.nested.b)

# In-place transformation
batch.apply_values_transform(lambda x: x + 10, inplace=True)
print("\nAfter in-place +10:")
print("a:", batch.a)
print("nested.b:", batch.nested.b)
Original batch a: [1 2 3]
Doubled batch a: [2 4 6]
Doubled nested.b: [ 8. 10.]

After in-place +10:
a: [11 12 13]
nested.b: [14. 15.]
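
The recursion behind such a transform is short enough to show in full. The function `map_leaves` is a hypothetical equivalent for nested dictionaries, corresponding to the non-in-place variant; it is a sketch of the idea rather than Tianshou's code.

```python
def map_leaves(tree, fn):
    """Return a new nested dict with fn applied to every non-dict value."""
    return {
        key: map_leaves(value, fn) if isinstance(value, dict) else fn(value)
        for key, value in tree.items()
    }


tree = {"a": 1, "nested": {"b": 4.0, "c": 6}}
doubled_tree = map_leaves(tree, lambda x: x * 2)

print(doubled_tree)  # {'a': 2, 'nested': {'b': 8.0, 'c': 12}}
print(tree)  # {'a': 1, 'nested': {'b': 4.0, 'c': 6}}
```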

7. Surprising Behaviors & Gotchas#

Iteration Does NOT Iterate Over Keys!#

This is the most common source of confusion:

batch = Batch(a=[1, 2, 3], b=[4, 5, 6])

print("WRONG: This doesn't iterate over keys!")
for item in batch:
    print(f"item = {item}")  # Prints batch[0], batch[1], batch[2]

print("\nCORRECT: To iterate over keys:")
for key in batch.keys():
    print(f"key = {key}")

print("\nCORRECT: To iterate over key-value pairs:")
for key, value in batch.items():
    print(f"{key} = {value}")
WRONG: This doesn't iterate over keys!
item = Batch(
    a: 1,
    b: 4,
)
item = Batch(
    a: 2,
    b: 5,
)
item = Batch(
    a: 3,
    b: 6,
)

CORRECT: To iterate over keys:
key = a
key = b

CORRECT: To iterate over key-value pairs:
a = [1 2 3]
b = [4 5 6]

Automatic Type Conversions#

Be aware of these automatic conversions:

# Lists become arrays
batch = Batch(a=[1, 2, 3])
print("List → array:", type(batch.a), batch.a.dtype)

# Dicts become Batch
batch = Batch(a={"x": 1, "y": 2})
print("Dict → Batch:", type(batch.a))

# Scalars become 0-dimensional NumPy arrays
batch = Batch(a=5)
print("Scalar → np.ndarray:", type(batch.a), batch.a)

# Mixed types → object dtype
batch = Batch(a=[1, "hello", None])
print("Mixed → object:", batch.a.dtype, batch.a)
List → array: <class 'numpy.ndarray'> int64
Dict → Batch: <class 'tianshou.data.batch.Batch'>
Scalar → np.ndarray: <class 'numpy.ndarray'> 5
Mixed → object: object [1 'hello' None]

Length Edge Cases#

# 1. Scalars have no length
batch_scalar = Batch(a=5, b=10)
try:
    len(batch_scalar)
except TypeError as e:
    print(f"Scalar batch: {e}")

# 2. Empty nested batches ignored in len()
batch_empty_nested = Batch(a=[1, 2, 3], b=Batch())
print(f"\nWith empty nested: len = {len(batch_empty_nested)} (ignores b)")

# 3. Different lengths: returns minimum
batch_different = Batch(a=[1, 2], b=[1, 2, 3, 4])
print(f"Different lengths: len = {len(batch_different)} (minimum)")

# 4. None values don't affect length
batch_none = Batch(a=[1, 2, 3], b=None)
print(f"With None: len = {len(batch_none)} (None ignored)")
Scalar batch: Entry for a in Batch(
    a: array(5),
    b: array(10),
) is 5 has no len()

With empty nested: len = 3 (ignores b)
Different lengths: len = 2 (minimum)
With None: len = 3 (None ignored)

String Keys Only#

# Integer keys not allowed
try:
    batch = Batch({1: "value", 2: "other"})
except AssertionError as e:
    print("Integer keys not allowed:", e)

# String keys work
batch = Batch({"key1": "value", "key2": "other"})
print("\nString keys work:", list(batch.keys()))
Integer keys not allowed: keys should all be string, but got dict_keys([1, 2])

String keys work: ['key1', 'key2']

Cat vs Stack Behavior#

Batch.stack pads missing keys with zeros, whereas Batch.cat now requires all batches to share the same structure:

# Stack pads missing keys with zeros
b1 = Batch(a=[1, 2])
b2 = Batch(b=[3, 4])
stacked = Batch.stack([b1, b2])
print("Stack (different keys):")
print(f"  a: {stacked.a}  (b2.a padded with 0)")
print(f"  b: {stacked.b}  (b1.b padded with 0)")

# Cat requires same structure now
b3 = Batch(a=[1, 2], b=[3, 4])
b4 = Batch(a=[5, 6], b=[7, 8])
concatenated = Batch.cat([b3, b4])
print("\nCat (same keys):")
print(f"  a: {concatenated.a}")
print(f"  b: {concatenated.b}")

# Cat with different structures raises error
try:
    Batch.cat([b1, b2])  # Different keys!
except ValueError:
    print("\nCat with different keys: ValueError raised")
Stack (different keys):
  a: [[1 2]
 [0 0]]  (b2.a padded with 0)
  b: [[0 0]
 [3 4]]  (b1.b padded with 0)

Cat (same keys):
  a: [1 2 5 6]
  b: [3 4 7 8]

Cat with different keys: ValueError raised
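
The stricter concatenation contract can be expressed as an explicit key check. The function `strict_cat` is a hypothetical sketch over dictionaries of lists; it shows the "same keys or fail" rule and is not a reproduction of Batch.cat.

```python
def strict_cat(rows):
    """Concatenate dicts of lists, refusing inputs whose key sets differ."""
    expected = set(rows[0])
    for row in rows[1:]:
        if set(row) != expected:
            raise ValueError(f"key mismatch: {sorted(expected)} vs {sorted(row)}")
    return {key: [x for row in rows for x in row[key]] for key in rows[0]}


merged = strict_cat([{"a": [1, 2], "b": [3, 4]}, {"a": [5, 6], "b": [7, 8]}])
print(merged)  # {'a': [1, 2, 5, 6], 'b': [3, 4, 7, 8]}

try:
    strict_cat([{"a": [1, 2]}, {"b": [3, 4]}])
except ValueError as err:
    print("rejected:", err)  # rejected: key mismatch: ['a'] vs ['b']
```

Failing loudly here avoids silently inventing data for a key that one of the inputs never had.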

8. Best Practices#

When to Use Batch#

Good use cases:

  • Collecting environment data (transitions, episodes)

  • Storing replay buffer data

  • Passing data between components (collector → buffer → policy)

  • Handling heterogeneous observations (dict spaces)

Consider alternatives:

  • Simple scalar tracking (use regular variables)

  • Pure tensor operations (use PyTorch tensors directly)

  • Deeply nested arbitrary structures (use dataclasses)

Structuring Your Batches#

Use protocols for type safety:

# Good: Use protocols for clear interfaces
def train_step(batch: RolloutBatchProtocol) -> float:
    """IDE knows what fields exist."""
    loss = ((batch.rew - 0.5) ** 2).mean()  # Type-safe
    return float(loss)


# Create properly typed batch
train_batch = cast(
    RolloutBatchProtocol,
    Batch(
        obs=np.random.randn(10, 4),
        obs_next=np.random.randn(10, 4),
        act=np.random.randint(0, 2, 10),
        rew=np.random.randn(10),
        terminated=np.zeros(10, dtype=bool),
        truncated=np.zeros(10, dtype=bool),
        info=np.array([{}] * 10, dtype=object),
    ),
)

loss = train_step(train_batch)
print(f"Loss: {loss:.4f}")
Loss: 0.4739

Consistent key naming:

  • Follow Tianshou conventions: obs, act, rew, terminated, truncated

  • Use descriptive names: camera_obs not co

  • Avoid name collisions with Batch methods: don’t use keys, items, get, etc.
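
The collision hazard is a general property of containers that offer attribute-style access. The class `AttrDict` below is a hypothetical minimal example built on a plain dict; it does not claim to mirror how Batch resolves attributes, only to show why a key named like a method is awkward.

```python
class AttrDict(dict):
    """Hypothetical dict with attribute access, used only to show the collision."""

    def __getattr__(self, name):
        # Called only when normal attribute lookup fails.
        try:
            return self[name]
        except KeyError:
            raise AttributeError(name) from None


record = AttrDict(obs=[1, 2], keys=["front_door", "garage"])

print(record.obs)  # [1, 2]
print(callable(record.keys))  # True: the dict method wins over the stored value
print(record["keys"])  # ['front_door', 'garage']
```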

When to nest vs flatten:

# Good: Nest related data
batch_nested = Batch(
    obs=Batch(
        camera=np.zeros((32, 64, 64, 3)), lidar=np.zeros((32, 100)), position=np.zeros((32, 3))
    ),
    act=np.zeros(32),
)
print("Nested structure for related obs:")
print(f"  Access: batch.obs.camera.shape = {batch_nested.obs.camera.shape}")

# Less good: Flat structure loses semantic grouping
batch_flat = Batch(
    camera=np.zeros((32, 64, 64, 3)),
    lidar=np.zeros((32, 100)),
    position=np.zeros((32, 3)),
    act=np.zeros(32),
)
print("\nFlat structure (works but less clear):")
print(f"  Access: batch.camera.shape = {batch_flat.camera.shape}")
Nested structure for related obs:
  Access: batch.obs.camera.shape = (32, 64, 64, 3)

Flat structure (works but less clear):
  Access: batch.camera.shape = (32, 64, 64, 3)

Performance Tips#

Use in-place operations:

import time

batch = Batch(a=np.random.randn(1000, 100))

# Creates copy
start = time.time()
for _ in range(100):
    _ = batch.to_torch()
time_copy = time.time() - start

# In-place (faster)
start = time.time()
for _ in range(100):
    batch.to_torch_()
    batch.to_numpy_()
time_inplace = time.time() - start

print(f"Copy: {time_copy:.4f}s")
print(f"In-place: {time_inplace:.4f}s")
print(f"Speedup: {time_copy / time_inplace:.1f}x")
Copy: 0.0044s
In-place: 0.0011s
Speedup: 4.1x

Be mindful of copies:

arr = np.array([1, 2, 3])

# Default: creates reference (be careful!)
batch1 = Batch(a=arr)
batch1.a[0] = 999
print(f"Original array modified: {arr}")  # Changed!

# Explicit copy when needed
arr = np.array([1, 2, 3])
batch2 = Batch(a=arr, copy=True)
batch2.a[0] = 999
print(f"Original array preserved: {arr}")  # Unchanged
Original array modified: [999   2   3]
Original array preserved: [1 2 3]
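
The same aliasing rule applies to any mutable Python object, which a list-based sketch makes easy to verify. This example uses only the standard library and plain dictionaries; it illustrates reference versus copy semantics in general and does not call Tianshou.

```python
import copy

source = [1, 2, 3]

# Storing the object itself: the container and `source` share one list.
shared = {"a": source}
shared["a"][0] = 999
print(source)  # [999, 2, 3]

# Storing a deep copy: later writes stay local to the container.
source = [1, 2, 3]
isolated = {"a": copy.deepcopy(source)}
isolated["a"][0] = 999
print(source)  # [1, 2, 3]
```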

Avoid unnecessary conversions:

# Inefficient: multiple conversions
batch = Batch(a=np.random.randn(100, 10))
batch.to_torch_()
batch.to_numpy_()  # Unnecessary if we just need NumPy

# Efficient: convert once, use many times
batch = Batch(a=np.random.randn(100, 10))
batch.to_torch_()  # Convert once
# ... do torch operations ...
# Keep as torch if that's what you need!

Common Patterns#

Pattern 1: Building batches incrementally

# Collect data from multiple steps
step_data = []
for i in range(5):
    step_data.append({"obs": np.random.randn(4), "act": i, "rew": np.random.randn()})

# Convert to batch (automatically stacks)
episode_batch = Batch(step_data)
print("Episode batch shape:", episode_batch.shape)
print("obs shape:", episode_batch.obs.shape)
Episode batch shape: [5]
obs shape: (5, 4)

Pattern 2: Slicing for mini-batches

# Large batch
large_batch = Batch(obs=np.random.randn(100, 4), act=np.random.randint(0, 2, 100))

# Split into mini-batches
batch_size = 32
for mini_batch in large_batch.split(batch_size, shuffle=True):
    print(f"Mini-batch size: {len(mini_batch)}")
    # Train on mini_batch...
    break  # Just show one iteration
Mini-batch size: 32

Pattern 3: Extending batches

# Start with some data
batch = Batch(obs=np.array([[1, 2], [3, 4]]), act=np.array([0, 1]))
print("Initial:", len(batch))

# Add more data
new_data = Batch(obs=np.array([[5, 6]]), act=np.array([1]))
batch.cat_(new_data)
print("After cat_:", len(batch))
print("obs:", batch.obs)
Initial: 2
After cat_: 3
obs: [[1 2]
 [3 4]
 [5 6]]

9. Summary#

Key Takeaways#

  1. Batch = Dict + Array: Combines key-value storage with array operations

  2. Hierarchical Structure: Perfect for complex RL data (nested observations, etc.)

  3. Type Safety via Protocols: Use BatchProtocol subclasses for IDE support and type checking

  4. Special RL Features: Distribution slicing, heterogeneous aggregation, missing value handling

  5. Remember: Iteration is over indices, NOT keys!

Quick Reference#

| Operation | Code | Notes |
| --- | --- | --- |
| Create | Batch(a=1, b=[2, 3]) | Auto-converts types |
| Access | batch.a or batch["a"] | Equivalent |
| Index | batch[0], batch[:10] | Returns sliced Batch |
| Iterate indices | for item in batch: | Yields batch[0], batch[1], … |
| Iterate keys | for k in batch.keys(): | Like dict |
| Stack | Batch.stack([b1, b2]) | Adds dimension |
| Concatenate | Batch.cat([b1, b2]) | Extends dimension |
| Split | batch.split(size=10) | Returns iterator |
| To PyTorch | batch.to_torch_() | In-place |
| To NumPy | batch.to_numpy_() | In-place |
| Transform | batch.apply_values_transform(fn) | Recursive |

Next Steps#

  • Collector Deep Dive: See how Batch flows through data collection

  • Buffer Deep Dive: Understand how Batch is stored and sampled

  • Policy Guide: Learn how policies work with BatchProtocol

  • API Reference: Full details at Batch API documentation

Appendix: Serialization & Advanced Topics#

Pickle Support#

# Batch objects are picklable
original = Batch(obs=Batch(a=0.0, c=torch.tensor([1.0, 2.0])), np=np.zeros([3, 4]))

# Serialize and deserialize
serialized = pickle.dumps(original)
restored = pickle.loads(serialized)

print("Original obs.a:", original.obs.a)
print("Restored obs.a:", restored.obs.a)
print("Equal:", original == restored)
Original obs.a: 0.0
Restored obs.a: 0.0
Equal: True

Advanced Indexing#

# Multi-dimensional data
batch = Batch(a=np.random.randn(5, 3, 2))
print("Original shape:", batch.a.shape)

# Various indexing operations
print("batch[0].a.shape:", batch[0].a.shape)
print("batch[:, 0].a.shape:", batch[:, 0].a.shape)
print("batch[[0, 2, 4]].a.shape:", batch[[0, 2, 4]].a.shape)
Original shape: (5, 3, 2)
batch[0].a.shape: (3, 2)
batch[:, 0].a.shape: (5, 2)
batch[[0, 2, 4]].a.shape: (3, 3, 2)