Collector
The Collector serves as the orchestration layer between the policy (agent) and the environment in Tianshou’s architecture. It manages the interaction loop, persists collected experiences to a replay buffer, and computes episode-level statistics. This module is fundamental to both training data collection and policy evaluation workflows.
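The sketch below makes this division of labour concrete. It shows the bookkeeping the Collector automates, reduced to a single environment: step the policy, store each transition, and record per-episode return and length. It does not use Tianshou. `ToyEnv`, `constant_policy`, `CollectResult` and `collect_episodes` are hypothetical names invented for illustration. The toy environment pays +1 per step and ends an episode after a fixed number of steps, which is loosely similar to CartPole's reward scheme.

```python
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Callable


@dataclass
class ToyEnv:
    """Hypothetical stand-in environment: +1 reward per step, episode ends after `horizon` steps."""

    horizon: int = 5
    t: int = 0

    def reset(self) -> int:
        self.t = 0
        return self.t  # the observation is just the step counter

    def step(self, action: int) -> tuple[int, float, bool]:
        self.t += 1
        done = self.t >= self.horizon
        return self.t, 1.0, done


@dataclass
class CollectResult:
    """Hypothetical container for episode-level statistics."""

    returns: list[float] = field(default_factory=list)
    lens: list[int] = field(default_factory=list)


def constant_policy(obs: int) -> int:
    """Hypothetical policy that always picks action 0."""
    return 0


def collect_episodes(
    policy: Callable[[int], int], env: ToyEnv, buffer: list, n_episode: int
) -> CollectResult:
    """Run `n_episode` episodes, append (obs, act, rew, obs_next, done) tuples to `buffer`,
    and record each episode's return and length."""
    result = CollectResult()
    for _ in range(n_episode):
        obs, ep_return, ep_len, done = env.reset(), 0.0, 0, False
        while not done:
            act = policy(obs)
            obs_next, rew, done = env.step(act)
            buffer.append((obs, act, rew, obs_next, done))  # persist the transition
            ep_return += rew
            ep_len += 1
            obs = obs_next
        result.returns.append(ep_return)
        result.lens.append(ep_len)
    return result


buffer: list = []
stats = collect_episodes(constant_policy, ToyEnv(horizon=5), buffer, n_episode=3)
print(stats.returns, stats.lens, len(buffer))  # [5.0, 5.0, 5.0] [5, 5, 5] 15
```

Tianshou's Collector performs the same three jobs. It also handles vectorized environments, a real replay buffer, and a richer statistics object (`CollectStats`).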
Core Applications
The Collector supports two primary use cases in reinforcement learning experiments:
Training: Collecting interaction data for policy optimization
Evaluation: Assessing policy performance without learning
Policy Evaluation
Periodic policy evaluation is essential in deep reinforcement learning (DRL) experiments to monitor training progress and assess generalization. The Collector provides a standardized interface for this purpose.
Setup: A Collector requires two components:
An environment (or vectorized environment for parallelization)
A policy instance to evaluate
import gymnasium as gym
import torch
from tianshou.algorithm.modelfree.reinforce import ProbabilisticActorPolicy
from tianshou.data import Collector, CollectStats, VectorReplayBuffer
from tianshou.env import DummyVectorEnv
from tianshou.utils.net.common import Net
from tianshou.utils.net.discrete import DiscreteActor
# Initialize single environment for configuration
env = gym.make("CartPole-v1")
# Create vectorized test environments (2 environments; DummyVectorEnv steps them sequentially in one process)
test_envs = DummyVectorEnv([lambda: gym.make("CartPole-v1") for _ in range(2)])
# Configure neural network architecture
assert env.observation_space.shape is not None # for mypy
preprocess_net = Net(
    state_shape=env.observation_space.shape,
    hidden_sizes=[16],
)
# Initialize discrete action actor network
assert isinstance(env.action_space, gym.spaces.Discrete) # for mypy
actor = DiscreteActor(preprocess_net=preprocess_net, action_shape=env.action_space.n)
# Create policy with categorical action distribution
policy = ProbabilisticActorPolicy(
    actor=actor,
    dist_fn=torch.distributions.Categorical,
    action_space=env.action_space,
    action_scaling=False,
)
# Initialize collector for evaluation
test_collector = Collector[CollectStats](policy, test_envs)
Evaluating Untrained Policy Performance
We now evaluate the randomly initialized policy across 9 episodes to establish a baseline performance metric:
# Collect 9 complete episodes with environment reset
collect_result = test_collector.collect(reset_before_collect=True, n_episode=9)
collect_result.pprint_asdict()
CollectStats
----------------------------------------
{ 'collect_speed': 2250.5430326407936,
'collect_time': 0.06354022026062012,
'lens': array([10, 12, 14, 20, 9, 20, 13, 26, 19]),
'lens_stat': { 'max': 26.0,
'mean': 15.88888888888889,
'min': 9.0,
'std': 5.32174667325691},
'n_collected_episodes': 9,
'n_collected_steps': 143,
'pred_dist_std_array': array([[0.4880399 ],
[0.48841825],
[0.48649663],
[0.48687008],
[0.4848843 ],
[0.48525038],
[0.4832023 ],
[0.48355895],
[0.48156515],
[0.4853428 ],
[0.48016095],
[0.48363343],
[0.47871548],
[0.48185053],
[0.47650644],
[0.48036504],
[0.4736009 ],
[0.48193374],
[0.47040978],
[0.48008174],
[0.4880156 ],
[0.47717053],
[0.48936296],
[0.4739934 ],
[0.4880218 ],
[0.48857704],
[0.48933557],
[0.4869284 ],
[0.48803595],
[0.48857388],
[0.48930436],
[0.48692062],
[0.48805818],
[0.4851996 ],
[0.48931113],
[0.48346046],
[0.48849547],
[0.4853013 ],
[0.48935422],
[0.48703316],
[0.48855874],
[0.4852706 ],
[0.48786786],
[0.48345315],
[0.48864773],
[0.48181203],
[0.4873265 ],
[0.48011354],
[0.48846525],
[0.47722822],
[0.48670128],
[0.4738468 ],
[0.48796389],
[0.4882833 ],
[0.48614302],
[0.4866622 ],
[0.48743242],
[0.48497903],
[0.48893616],
[0.48681018],
[0.48841935],
[0.4884912 ],
[0.48673964],
[0.4868552 ],
[0.4850666 ],
[0.4885452 ],
[0.4833312 ],
[0.48689267],
[0.48509154],
[0.4885872 ],
[0.48333645],
[0.4895367 ],
[0.4816299 ],
[0.4885957 ],
[0.4793524 ],
[0.48689628],
[0.4763822 ],
[0.48520607],
[0.4881911 ],
[0.48345265],
[0.48881495],
[0.48191962],
[0.48822734],
[0.4804703 ],
[0.48669112],
[0.4789856 ],
[0.48508543],
[0.47746256],
[0.4834098 ],
[0.4747317 ],
[0.48176047],
[0.4715309 ],
[0.48356998],
[0.48812628],
[0.48535717],
[0.48650604],
[0.48362663],
[0.4882281 ],
[0.48543057],
[0.4866064 ],
[0.4871421 ],
[0.48492277],
[0.4887113 ],
[0.483177 ],
[0.48941064],
[0.4816125 ],
[0.4888766 ],
[0.48020184],
[0.48937303],
[0.4787505 ],
[0.48829105],
[0.47725698],
[0.48666742],
[0.47901878],
[0.4849916 ],
[0.47673035],
[0.4866427 ],
[0.47351408],
[0.4849519 ],
[0.48838204],
[0.48658875],
[0.48957685],
[0.48488182],
[0.4883188 ],
[0.4824834 ],
[0.4866633 ],
[0.48458612],
[0.48833394],
[0.48144317],
[0.48667425],
[0.48494807],
[0.48674086],
[0.48499957],
[0.4831906 ],
[0.48508227],
[0.4832453 ],
[0.48509973],
[0.4832335 ],
[0.48505294],
[0.4831555 ],
[0.48494118],
[0.48301002],
[0.48060977]], dtype=float32),
'pred_dist_std_array_stat': { 0: { 'max': 0.4895768463611603,
'mean': 0.4845207631587982,
'min': 0.4704097807407379,
'std': 0.004213924985378981}},
'returns': array([10., 12., 14., 20., 9., 20., 13., 26., 19.]),
'returns_stat': { 'max': 26.0,
'mean': 15.88888888888889,
'min': 9.0,
'std': 5.32174667325691}}
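The summary entries are plain statistics over the per-episode arrays. As a sanity check, the sketch below recomputes `returns_stat` from the `returns` array printed above, using only the standard library. The reported `std` matches the population standard deviation (`statistics.pstdev`, dividing by n) rather than the sample standard deviation. This is inferred from the printed numbers, not taken from the Tianshou source. CartPole-v1 pays +1 per step, which is why `returns` and `lens` coincide here.

```python
import statistics

# Per-episode returns copied from the CollectStats output above
returns = [10.0, 12.0, 14.0, 20.0, 9.0, 20.0, 13.0, 26.0, 19.0]

returns_stat = {
    "max": max(returns),
    "mean": statistics.fmean(returns),
    "min": min(returns),
    # population standard deviation (divide by n), which matches the printed value
    "std": statistics.pstdev(returns),
}
print(returns_stat)

# For comparison: the sample standard deviation (divide by n - 1) is larger
print(statistics.stdev(returns))
```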
Baseline Comparison: Random Policy
To contextualize the initialized policy’s performance, we establish a random action baseline:
# Evaluate random policy by sampling actions uniformly from action space
collect_result = test_collector.collect(reset_before_collect=True, n_episode=9, random=True)
collect_result.pprint_asdict()
CollectStats
----------------------------------------
{ 'collect_speed': 7591.167553000646,
'collect_time': 0.029903173446655273,
'lens': array([14, 24, 39, 18, 14, 21, 70, 17, 10]),
'lens_stat': { 'max': 70.0,
'mean': 25.22222222222222,
'min': 10.0,
'std': 17.69355047453907},
'n_collected_episodes': 9,
'n_collected_steps': 227,
'pred_dist_std_array': None,
'pred_dist_std_array_stat': None,
'returns': array([14., 24., 39., 18., 14., 21., 70., 17., 10.]),
'returns_stat': { 'max': 70.0,
'mean': 25.22222222222222,
'min': 10.0,
'std': 17.69355047453907}}
Observation: Before training, the randomly initialized policy performs comparably to, or worse than, uniform random actions. In this run its mean return was about 15.9, versus about 25.2 for random actions. This is expected, since the network weights have not yet been optimized for the task.
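With only 9 episodes per run, a gap between two mean returns can easily be noise. One rough check, using only the two `returns` arrays printed above, compares the difference of means with its standard error (sample standard deviation divided by sqrt(n)). This is an informal two-sample comparison written for this page, not a Tianshou utility, and the numbers will differ on every run.

```python
from __future__ import annotations

import math
import statistics

# Returns copied from the two CollectStats outputs above
untrained = [10.0, 12.0, 14.0, 20.0, 9.0, 20.0, 13.0, 26.0, 19.0]
uniform_random = [14.0, 24.0, 39.0, 18.0, 14.0, 21.0, 70.0, 17.0, 10.0]


def standard_error(xs: list[float]) -> float:
    """Standard error of the mean, based on the sample standard deviation."""
    return statistics.stdev(xs) / math.sqrt(len(xs))


diff = statistics.fmean(uniform_random) - statistics.fmean(untrained)
# Standard error of the difference between two independent means
se_diff = math.hypot(standard_error(untrained), standard_error(uniform_random))
print(f"difference of means: {diff:.2f}, standard error: {se_diff:.2f}, ratio: {diff / se_diff:.2f}")
```

A ratio under roughly 2 suggests the gap could plausibly be noise at this sample size. Evaluating more episodes (a larger `n_episode`) shrinks the standard error.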
Training Data Collection
During training, the Collector gathers experience and automatically stores it in the replay buffer passed to it, from which the algorithm can later sample mini-batches for updates. This is the experience replay mechanism that off-policy algorithms rely on.
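Conceptually, a replay buffer is a fixed-capacity store that overwrites its oldest entries once full and serves uniformly sampled mini-batches. The sketch below illustrates that idea with a hypothetical `RingReplayBuffer` class built on the standard library. It is not Tianshou's `ReplayBuffer`, which stores batched arrays and supports many more features.

```python
from __future__ import annotations

import random
from typing import Any


class RingReplayBuffer:
    """Hypothetical fixed-size FIFO buffer with uniform sampling (illustration only)."""

    def __init__(self, size: int) -> None:
        self.size = size
        self.data: list[Any] = [None] * size
        self._next = 0  # slot the next transition will be written to
        self._len = 0  # number of valid transitions currently stored

    def __len__(self) -> int:
        return self._len

    def add(self, transition: Any) -> int:
        """Store a transition, overwriting the oldest one when full; return its slot index."""
        idx = self._next
        self.data[idx] = transition
        self._next = (self._next + 1) % self.size
        self._len = min(self._len + 1, self.size)
        return idx

    def sample(self, batch_size: int, rng: random.Random) -> tuple[list[Any], list[int]]:
        """Sample slot indices uniformly, with replacement, from the filled slots."""
        if self._len == 0:
            raise ValueError("cannot sample from an empty buffer")
        indices = [rng.randrange(self._len) for _ in range(batch_size)]
        return [self.data[i] for i in indices], indices


buf = RingReplayBuffer(size=4)
for step in range(6):  # write 6 transitions into 4 slots; the two oldest get overwritten
    buf.add(step)
print(len(buf), buf.data)  # 4 [4, 5, 2, 3]
batch, indices = buf.sample(3, random.Random(0))
```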
# Configuration for parallel training data collection
train_env_num = 4
buffer_size = 100
# Initialize vectorized training environments
train_envs = DummyVectorEnv([lambda: gym.make("CartPole-v1") for _ in range(train_env_num)])
# Create replay buffer compatible with vectorized environments
replayBuffer = VectorReplayBuffer(buffer_size, train_env_num)
# Initialize training collector with buffer integration
training_collector = Collector[CollectStats](policy, train_envs, replayBuffer)
Step-Based Collection
The Collector supports both step-based and episode-based collection modes. Here we demonstrate step-based collection, which is commonly used in training loops with fixed update frequencies.
Note: When using vectorized environments, the actual number of collected steps may exceed the requested amount to maintain synchronization across parallel environments.
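The overshoot follows from stepping every environment in the vector on each iteration, so the step counter grows in increments of the number of environments. The sketch below shows that counting. It assumes that all environments stay active for the whole collection, and the `lockstep_steps` helper is hypothetical, not part of Tianshou.

```python
import math


def lockstep_steps(n_step: int, env_num: int) -> int:
    """Count transitions gathered when all `env_num` environments advance together
    and collection stops at the first iteration where the total reaches `n_step`."""
    collected = 0
    while collected < n_step:
        collected += env_num  # one synchronized step of every environment
    return collected


# The loop is equivalent to rounding n_step up to the next multiple of env_num
assert lockstep_steps(50, 4) == math.ceil(50 / 4) * 4
print(lockstep_steps(50, 4))  # 52, matching the buffer length reported below
```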
# Reset collector and buffer to clean state
training_collector.reset()
replayBuffer.reset()
print(f"Replay buffer before collecting is empty, and has length={len(replayBuffer)} \n")
# Collect 50 environment steps
n_step = 50
collect_result = training_collector.collect(n_step=n_step)
print(
    f"Replay buffer after collecting {n_step} steps has length={len(replayBuffer)}.\n"
    "The actual count may exceed n_step when it is not a multiple of train_env_num \n"
    "due to vectorization synchronization requirements.\n",
)
collect_result.pprint_asdict()
Replay buffer before collecting is empty, and has length=0
Replay buffer after collecting 50 steps has length=52.
The actual count may exceed n_step when it is not a multiple of train_env_num
due to vectorization synchronization requirements.
CollectStats
----------------------------------------
{ 'collect_speed': 4338.561158520817,
'collect_time': 0.011985540390014648,
'lens': array([10]),
'lens_stat': {'max': 10.0, 'mean': 10.0, 'min': 10.0, 'std': 0.0},
'n_collected_episodes': 1,
'n_collected_steps': 52,
'pred_dist_std_array': array([[0.48837554],
[0.48864424],
[0.48843145],
[0.48820987],
[0.48683757],
[0.4869983 ],
[0.48685384],
[0.48659655],
[0.48848015],
[0.4853471 ],
[0.48840535],
[0.4849213 ],
[0.4869011 ],
[0.48364106],
[0.48949882],
[0.48676533],
[0.48529765],
[0.4818775 ],
[0.4883997 ],
[0.48845357],
[0.4870329 ],
[0.48044658],
[0.4881135 ],
[0.48682684],
[0.4853813 ],
[0.47867474],
[0.488251 ],
[0.48513603],
[0.48369956],
[0.4805535 ],
[0.48945394],
[0.48694295],
[0.4855337 ],
[0.47792098],
[0.4882023 ],
[0.48523945],
[0.4837826 ],
[0.47479466],
[0.48944265],
[0.48702294],
[0.4856499 ],
[0.48841545],
[0.4881293 ],
[0.48865703],
[0.4873724 ],
[0.4868799 ],
[0.48945132],
[0.48914137],
[0.4889514 ],
[0.4852741 ],
[0.4881997 ],
[0.48857498]], dtype=float32),
'pred_dist_std_array_stat': { 0: { 'max': 0.4894988238811493,
'mean': 0.48627084493637085,
'min': 0.4747946560382843,
'std': 0.0031196388881653547}},
'returns': array([10.]),
'returns_stat': {'max': 10.0, 'mean': 10.0, 'min': 10.0, 'std': 0.0}}
/home/docs/checkouts/readthedocs.org/user_builds/tianshou/checkouts/stable/tianshou/data/collector.py:539: UserWarning: n_step=50 is not a multiple of (self.env_num=4), which may cause extra transitions being collected into the buffer.
warnings.warn(
Buffer Sampling Verification
Verify that collected experiences are properly stored and can be sampled for training:
# Sample mini-batch of 10 transitions from buffer
replayBuffer.sample(10)
(Batch(
obs: array([[-1.4802839e-02, -3.5696667e-01, -1.6063595e-02, 4.9230999e-01],
[-1.1451793e-03, -1.6306867e-01, -3.4881335e-02, 2.2667289e-01],
[-6.1540153e-02, 3.2022689e-02, 4.8842553e-02, -6.5576687e-02],
[-5.5478465e-02, -1.1462041e+00, 9.7980194e-02, 1.7694761e+00],
[ 6.4153850e-02, 4.0744472e-01, 1.6117804e-02, -4.6681765e-01],
[ 3.9644178e-02, 2.1402776e-01, 4.5537446e-02, -2.1188019e-01],
[ 9.3507856e-02, 4.0738186e-01, -1.0022986e-02, -4.6533665e-01],
[ 8.8926248e-02, 1.6914638e-02, -9.1108372e-03, 1.2496795e-01],
[-7.0721351e-02, -4.0544215e-01, -9.5220339e-03, 5.3128541e-01],
[-6.1979044e-02, -2.1098430e-01, -1.8546028e-02, 2.5324982e-01]],
dtype=float32),
act: array([0, 0, 0, 1, 0, 1, 0, 1, 1, 1]),
rew: array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]),
terminated: array([False, False, False, False, False, False, False, False, False,
False]),
truncated: array([False, False, False, False, False, False, False, False, False,
False]),
done: array([False, False, False, False, False, False, False, False, False,
False]),
obs_next: array([[-2.1942172e-02, -5.5185843e-01, -6.2173950e-03, 7.7988738e-01],
[-4.4065523e-03, -3.5767519e-01, -3.0347878e-02, 5.0815207e-01],
[-6.0899697e-02, -1.6376427e-01, 4.7531020e-02, 2.4210753e-01],
[-7.8402549e-02, -9.5231527e-01, 1.3336971e-01, 1.5087979e+00],
[ 7.2302744e-02, 2.1209879e-01, 6.7814514e-03, -1.6909839e-01],
[ 4.3924734e-02, 4.0847006e-01, 4.1299842e-02, -4.8985788e-01],
[ 1.0165550e-01, 2.1240297e-01, -1.9329719e-02, -1.7582971e-01],
[ 8.9264542e-02, 2.1216592e-01, -6.6114785e-03, -1.7057537e-01],
[-7.8830197e-02, -2.1018755e-01, 1.1036744e-03, 2.3561738e-01],
[-6.6198729e-02, -1.5602514e-02, -1.3481031e-02, -4.5224685e-02]],
dtype=float32),
info: Batch(
env_id: array([0, 0, 0, 1, 2, 2, 2, 2, 3, 3]),
),
policy: Batch(),
),
array([ 6, 3, 12, 31, 56, 53, 62, 60, 81, 78]))
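The returned indices are consistent with `VectorReplayBuffer(100, 4)` splitting its capacity into four contiguous sub-buffers of 25 slots, one per environment: indices 0 to 24 for env 0, 25 to 49 for env 1, and so on. The short check below was written for this page rather than taken from Tianshou. It maps each sampled index back to its sub-buffer and compares the result with the `env_id` field of the sampled batch. It also shows why nothing was overwritten: all four environments are stepped together, so each contributed 52 / 4 = 13 transitions, which fits within its 25 slots.

```python
# Values copied from the sample output above
total_size, buffer_num = 100, 4
sampled_indices = [6, 3, 12, 31, 56, 53, 62, 60, 81, 78]
sampled_env_ids = [0, 0, 0, 1, 2, 2, 2, 2, 3, 3]

# Capacity of each per-environment sub-buffer (total_size split evenly, rounded up)
sub_size = -(-total_size // buffer_num)  # ceiling division -> 25


def owning_sub_buffer(global_index: int) -> int:
    """Return which contiguous sub-buffer a global buffer index falls into."""
    return global_index // sub_size


inferred = [owning_sub_buffer(i) for i in sampled_indices]
print(inferred == sampled_env_ids)  # True

# Transitions written per environment during the 52-step collection above
per_env = 52 // buffer_num
print(per_env, per_env <= sub_size)  # 13 True
```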
Advanced Topics
Asynchronous Collection
The standard Collector implementation may collect more steps than requested when using vectorized environments. In the example above, requesting 50 steps resulted in 52 steps (the smallest multiple of 4 that is ≥50).
For scenarios requiring precise step counts, Tianshou provides the AsyncCollector, which enables exact step collection at the cost of additional implementation complexity. This is particularly relevant for:
Strict reproducibility requirements
Algorithms sensitive to exact batch sizes
Fine-grained control over data collection
Consult the AsyncCollector documentation for implementation details and usage patterns.