Dual APIs
Tianshou provides two distinct APIs to serve different use cases and user preferences:
high-level API: a declarative, configuration-based interface designed for ease of use
procedural API: a flexible, imperative interface providing maximum control
Both APIs access the same underlying algorithm implementations, allowing you to choose the level of abstraction that best fits your needs without sacrificing functionality.
Overview
High-Level API
The high-level API is built around the builder pattern and declarative semantics. Instead of writing procedural code that sequentially constructs and connects components, you declare what you want through configuration objects and let Tianshou handle how to build and execute the experiment.
Key characteristics:
centered around ExperimentBuilder classes (e.g., DQNExperimentBuilder, PPOExperimentBuilder, etc.)
uses configuration dataclasses and factories for all relevant parameters
automatically handles component creation and “wiring”
provides sensible defaults that adapt to the nature of your environment
includes built-in persistence, logging, and experiment management
offers full type hints (the object structure is nested rather than flat, so an IDE with auto-completion is recommended for a seamless experience)
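To illustrate the builder idea in isolation, here is a minimal plain-Python sketch. It is not Tianshou code: ToyExperimentBuilder, TrainingConfig, AlgorithmParams and Experiment are hypothetical stand-ins. Settings are collected declaratively through chained with_*() calls, defaults apply to anything left unspecified, and all construction is deferred to build().

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass(frozen=True)
class TrainingConfig:
    """Hypothetical stand-in for a training configuration dataclass."""

    max_epochs: int = 10
    batch_size: int = 64


@dataclass(frozen=True)
class AlgorithmParams:
    """Hypothetical stand-in for algorithm parameters (cf. DQNParams)."""

    lr: float = 1e-3
    gamma: float = 0.99


@dataclass
class Experiment:
    training: TrainingConfig
    params: AlgorithmParams
    hidden_sizes: tuple[int, ...]

    def run(self) -> str:
        # A real experiment would create envs, networks, collectors, etc. here.
        return f"training for {self.training.max_epochs} epochs with lr={self.params.lr}"


class ToyExperimentBuilder:
    """Collects declarative settings; nothing is constructed before build()."""

    def __init__(self, training: TrainingConfig) -> None:
        self._training = training
        self._params = AlgorithmParams()  # default used unless overridden
        self._hidden_sizes: tuple[int, ...] = (64, 64)

    def with_params(self, params: AlgorithmParams) -> ToyExperimentBuilder:
        self._params = params
        return self  # returning self enables method chaining

    def with_hidden_sizes(self, *sizes: int) -> ToyExperimentBuilder:
        self._hidden_sizes = tuple(sizes)
        return self

    def build(self) -> Experiment:
        # All "wiring" happens in one place, based on the collected settings.
        return Experiment(self._training, self._params, self._hidden_sizes)


experiment = (
    ToyExperimentBuilder(TrainingConfig(max_epochs=5))
    .with_params(AlgorithmParams(lr=3e-4))
    .with_hidden_sizes(128, 128)
    .build()
)
print(experiment.run())  # training for 5 epochs with lr=0.0003
```

Returning self from every with_*() method is what makes the chained style possible; the real experiment builders are used in the same chained fashion, but their build() step creates many more components.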
Procedural API
The procedural API provides explicit control over every component in the RL pipeline. You manually create environments, networks, policies, algorithms, collectors, and trainers, then wire them together.
Key characteristics:
direct instantiation of all components
explicit control over the training loop
lower-level access to internal mechanisms
minimal abstraction (closer to the implementation)
ideal for algorithm development and research
When to Use Which API
Use the high-level API when …
you’re applying existing algorithms to new problems
you want to get started quickly with minimal boilerplate
you need experiment management with persistence, logging, and reproducibility
you prefer declarative code that focuses on configuration
you’re building applications rather than developing new algorithms
Use the procedural API when …
you’re developing new algorithms or modifying existing ones
you need fine-grained control over the training process
you want to understand the internal workings of Tianshou
you’re implementing custom components not supported by the high-level API
you prefer imperative programming where each step is explicit
you need maximum flexibility for experimental research
Comparison by Example
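Let’s compare both APIs by implementing the same DQN learning task on the CartPole environment. Note that the two scripts are not configured identically (for instance, the network architecture, the exploration rates, and the reward threshold used for early stopping differ), so they illustrate the structure of each API rather than a strict one-to-one translation.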
High-Level API Example
from tianshou.highlevel.config import OffPolicyTrainingConfig
from tianshou.highlevel.env import EnvFactoryRegistered, VectorEnvType
from tianshou.highlevel.experiment import DQNExperimentBuilder, ExperimentConfig
from tianshou.highlevel.params.algorithm_params import DQNParams
from tianshou.highlevel.trainer import EpochStopCallbackRewardThreshold
# Build the experiment through configuration
experiment = (
    DQNExperimentBuilder(
        # Environment configuration
        EnvFactoryRegistered(
            task="CartPole-v1",
            venv_type=VectorEnvType.DUMMY,
            training_seed=0,
            test_seed=10,
        ),
        # Experiment settings
        ExperimentConfig(
            persistence_enabled=False,
            watch=True,
            watch_render=1 / 35,
            watch_num_episodes=100,
        ),
        # Training configuration
        OffPolicyTrainingConfig(
            max_epochs=10,
            epoch_num_steps=10000,
            batch_size=64,
            num_training_envs=10,
            num_test_envs=100,
            buffer_size=20000,
            collection_step_num_env_steps=10,
            update_step_num_gradient_steps_per_sample=1 / 10,
        ),
    )
    # Algorithm-specific parameters
    .with_dqn_params(
        DQNParams(
            lr=1e-3,
            gamma=0.9,
            n_step_return_horizon=3,
            target_update_freq=320,
            eps_training=0.3,
            eps_inference=0.0,
        ),
    )
    # Network architecture
    .with_model_factory_default(hidden_sizes=(64, 64))
    # Stop condition
    .with_epoch_stop_callback(EpochStopCallbackRewardThreshold(195))
    .build()
)
# Run the experiment
experiment.run()
What’s happening here:
We create an ExperimentBuilder with three main configuration objects
We chain builder methods to specify algorithm parameters, model architecture, and callbacks
We call .build() to construct the experiment
We call .run() to execute the entire training pipeline
The high-level API handles …
creating and configuring environments
building the neural network
instantiating the policy and algorithm
setting up collectors and replay buffer
managing the training loop
watching the trained agent
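The training configuration in the example can be read as a compute budget. The following back-of-the-envelope sketch assumes (our reading of the trainer semantics, not a quote of its code) that every collection step gathers collection_step_num_env_steps transitions and is followed by round(transitions * update_step_num_gradient_steps_per_sample) gradient steps; it ignores early stopping via the stop callback. The helper off_policy_budget is hypothetical.

```python
from __future__ import annotations


def off_policy_budget(
    max_epochs: int,
    epoch_num_steps: int,
    collection_step_num_env_steps: int,
    update_step_num_gradient_steps_per_sample: float,
) -> dict[str, int]:
    """Upper bounds on environment and gradient steps (no early stopping)."""
    # Collection steps needed to gather epoch_num_steps transitions per epoch
    # (assumes epoch_num_steps is a multiple of the collection step size).
    collection_steps_per_epoch = epoch_num_steps // collection_step_num_env_steps
    # Assumption: gradient steps per collection step = round(samples * ratio).
    gradient_steps_per_collection = round(
        collection_step_num_env_steps * update_step_num_gradient_steps_per_sample
    )
    gradient_steps_per_epoch = collection_steps_per_epoch * gradient_steps_per_collection
    return {
        "collection_steps_per_epoch": collection_steps_per_epoch,
        "gradient_steps_per_epoch": gradient_steps_per_epoch,
        "total_env_steps": max_epochs * epoch_num_steps,
        "total_gradient_steps": max_epochs * gradient_steps_per_epoch,
    }


# Values taken from the OffPolicyTrainingConfig in the example above.
budget = off_policy_budget(
    max_epochs=10,
    epoch_num_steps=10000,
    collection_step_num_env_steps=10,
    update_step_num_gradient_steps_per_sample=1 / 10,
)
print(budget["gradient_steps_per_epoch"], budget["total_gradient_steps"])  # 1000 10000
```

Under these assumptions, with 10 training environments and collection_step_num_env_steps=10, each collection step amounts to roughly one step per environment followed by a single gradient step.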
Procedural API Example
import gymnasium as gym
import tianshou as ts
from tianshou.algorithm.modelfree.dqn import DiscreteQLearningPolicy
from tianshou.algorithm.optim import AdamOptimizerFactory
from tianshou.data import CollectStats
from tianshou.trainer import OffPolicyTrainerParams
from tianshou.utils.net.common import Net
from tianshou.utils.space_info import SpaceInfo
from torch.utils.tensorboard import SummaryWriter
# Define hyperparameters
task = "CartPole-v1"
lr, epoch, batch_size = 1e-3, 10, 64
num_training_envs, num_test_envs = 10, 100
gamma, n_step, target_freq = 0.9, 3, 320
buffer_size = 20000
eps_train, eps_test = 0.1, 0.05
epoch_num_steps, collection_step_num_env_steps = 10000, 10
# Set up logging
logger = ts.utils.TensorboardLogger(SummaryWriter("log/dqn"))
# Create environments
training_envs = ts.env.DummyVectorEnv([lambda: gym.make(task) for _ in range(num_training_envs)])
test_envs = ts.env.DummyVectorEnv([lambda: gym.make(task) for _ in range(num_test_envs)])
# Build the network
env = gym.make(task, render_mode="human")
space_info = SpaceInfo.from_env(env)
state_shape = space_info.observation_info.obs_shape
action_shape = space_info.action_info.action_shape
net = Net(state_shape=state_shape, action_shape=action_shape, hidden_sizes=[128, 128, 128])
# Create policy and algorithm
policy = DiscreteQLearningPolicy(
    model=net,
    action_space=env.action_space,
    eps_training=eps_train,
    eps_inference=eps_test,
)
algorithm = ts.algorithm.DQN(
    policy=policy,
    optim=AdamOptimizerFactory(lr=lr),
    gamma=gamma,
    n_step_return_horizon=n_step,
    target_update_freq=target_freq,
)
# Set up collectors
training_collector = ts.data.Collector[CollectStats](
    algorithm,
    training_envs,
    ts.data.VectorReplayBuffer(buffer_size, num_training_envs),
    exploration_noise=True,
)
test_collector = ts.data.Collector[CollectStats](
    algorithm,
    test_envs,
    exploration_noise=True,
)
# Define stop condition: stop once the mean test reward reaches the environment's reward threshold
def stop_fn(mean_rewards: float) -> bool:
    if env.spec and env.spec.reward_threshold:
        return mean_rewards >= env.spec.reward_threshold
    return False
# Train the algorithm
result = algorithm.run_training(
    OffPolicyTrainerParams(
        training_collector=training_collector,
        test_collector=test_collector,
        max_epochs=epoch,
        epoch_num_steps=epoch_num_steps,
        collection_step_num_env_steps=collection_step_num_env_steps,
        test_step_num_episodes=num_test_envs,
        batch_size=batch_size,
        update_step_num_gradient_steps_per_sample=1 / collection_step_num_env_steps,
        stop_fn=stop_fn,
        logger=logger,
        test_in_training=True,
    )
)
print(f"Finished training in {result.timing.total_time} seconds")
# Watch the trained agent
collector = ts.data.Collector[CollectStats](algorithm, env, exploration_noise=True)
collector.collect(n_episode=100, render=1 / 35)
What’s happening here:
We explicitly define all hyperparameters as variables
We manually create the logger
We construct training and test environments
We build the neural network by extracting space information from the environment
We create the policy and algorithm objects
We set up collectors with a replay buffer
We define a stop function (stop_fn), which is passed to the trainer as a callback
We call algorithm.run_training() with explicit parameters
We manually set up and run the evaluation collector
The procedural API requires …
explicit creation of every component
manual extraction of environment properties
direct specification of all connections
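Both examples configure two exploration rates for the Q-learning policy, eps_training and eps_inference. Conceptually, these parametrise epsilon-greedy action selection: with probability epsilon a uniformly random action is taken, otherwise the action with the highest Q-value. The function below is a hypothetical plain-Python sketch of that rule for a single observation; Tianshou's policy operates on batches internally, so this is not its implementation.

```python
import random
from collections.abc import Sequence


def epsilon_greedy(q_values: Sequence[float], eps: float, rng: random.Random) -> int:
    """Return a random action index with probability eps, else the greedy one."""
    if rng.random() < eps:
        return rng.randrange(len(q_values))  # explore: uniform over all actions
    # exploit: index of the largest Q-value (first one on ties)
    return max(range(len(q_values)), key=lambda a: q_values[a])


rng = random.Random(0)
q = [0.2, 0.7, 0.1]
# eps=0.0 (as eps_inference in the high-level example): always greedy
print([epsilon_greedy(q, 0.0, rng) for _ in range(5)])  # [1, 1, 1, 1, 1]
# eps=0.3 (as eps_training in the high-level example): mostly greedy
picks = [epsilon_greedy(q, 0.3, rng) for _ in range(1000)]
print(picks.count(1) / len(picks))  # roughly 0.8 (= 0.7 + 0.3 / 3)
```

Note that a random pick can coincide with the greedy action, which is why the greedy action is chosen with probability 1 - eps + eps / num_actions rather than 1 - eps.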
Key Concepts in the High-Level API
ExperimentBuilder
The ExperimentBuilder is the core abstraction of the high-level API.
Each algorithm has its own builder (e.g., DQNExperimentBuilder, PPOExperimentBuilder, SACExperimentBuilder).
Some methods you will find in experiment builders:
.with_<algorithm>_params() - Set algorithm-specific parameters
.with_model_factory(), .with_model_factory_default() - Configure network architecture
.with_critic_factory() - Configure critic network (for actor-critic methods)
.with_epoch_train_callback() - Add function to be called at the beginning of the training step in each epoch
.with_epoch_test_callback() - Add function to be called at the beginning of the test step in each epoch
.with_epoch_stop_callback() - Define stopping conditions
.with_algorithm_wrapper_factory() - Add algorithm wrappers (e.g., ICM)
Configuration Objects
Three main configuration objects are required when constructing an experiment builder:
Environment Configuration (EnvFactory subclasses): Defines how to create and configure environments
  Existing factories:
    EnvFactoryRegistered - For the creation of environments registered in Gymnasium
    AtariEnvFactory - For Atari environments with preprocessing
  Custom factories for your own environments can be created by subclassing EnvFactory
Experiment Configuration (ExperimentConfig): General settings for the experiment, particularly related to
  logging
  randomization
  persistence
  watching the trained agent’s performance after training
Training Configuration (OffPolicyTrainingConfig, OnPolicyTrainingConfig): Defines all parameters related to the training process
Parameter Classes
Algorithm parameters are defined in dataclasses specific to each algorithm (e.g., DQNParams, PPOParams).
The parameters are extensively documented.
Note
Make sure to use a modern IDE to take advantage of auto-completion and inline documentation!
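As an example of how two DQNParams fields interact, gamma and n_step_return_horizon (0.9 and 3 in the high-level example) together define the n-step bootstrapped target: the discounted sum of the next n rewards plus gamma^n times the target network's value estimate for the state reached after n steps, with no bootstrap term if the episode terminated within those steps. Below is a minimal textbook sketch with a hypothetical helper name, not Tianshou's vectorised implementation.

```python
from collections.abc import Sequence


def n_step_target(
    rewards: Sequence[float],
    gamma: float,
    bootstrap_value: float,
    terminated: bool,
) -> float:
    """Discounted n-step target.

    rewards: r_t, ..., r_{t+k-1}, where k is the horizon n, or fewer steps if
        the episode terminated inside the window.
    bootstrap_value: target-network estimate max_a Q(s_{t+k}, a).
    terminated: if True, s_{t+k} is terminal and no bootstrap term is added.
    """
    target = 0.0
    for k, r in enumerate(rewards):
        target += (gamma**k) * r  # discount the k-th reward by gamma^k
    if not terminated:
        target += (gamma ** len(rewards)) * bootstrap_value
    return target


# CartPole gives a reward of 1 per step; gamma=0.9 and horizon 3 as in the example.
target = n_step_target([1.0, 1.0, 1.0], gamma=0.9, bootstrap_value=5.0, terminated=False)
print(round(target, 3))  # 6.355  (1 + 0.9 + 0.81 + 0.729 * 5)
```

Longer horizons propagate reward information faster, at the cost of targets that depend on more rewards collected by an older, exploratory policy.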
Factories
The high-level API uses factories extensively:
Model Factories: Create neural networks (e.g., IntermediateModuleFactoryAtariDQN)
Environment Factories: Create and configure environments
Optimizer Factories: Create optimizers with specific configurations
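The point of a factory is deferred construction: the user fixes the choices that do not depend on the environment (for example, hidden_sizes in with_model_factory_default), and the builder later supplies what is only known once the environments exist, such as observation and action dimensions. A minimal plain-Python sketch of that division of labour follows; EnvInfo, MLPSpec, DefaultMLPFactory and build_model are all hypothetical names, and the "model" is only described by its layer sizes rather than built with PyTorch.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Protocol


@dataclass(frozen=True)
class EnvInfo:
    """What the builder knows once the environments have been created."""

    obs_dim: int
    num_actions: int


@dataclass(frozen=True)
class MLPSpec:
    """Stand-in for a network: just its layer sizes."""

    layer_sizes: tuple[int, ...]


class ModelFactory(Protocol):
    def create(self, env_info: EnvInfo) -> MLPSpec: ...


@dataclass(frozen=True)
class DefaultMLPFactory:
    # Environment-independent choice, fixed by the user at configuration time.
    hidden_sizes: tuple[int, ...] = (64, 64)

    def create(self, env_info: EnvInfo) -> MLPSpec:
        # Input and output sizes are only known once the environment exists.
        return MLPSpec((env_info.obs_dim, *self.hidden_sizes, env_info.num_actions))


def build_model(factory: ModelFactory, env_info: EnvInfo) -> MLPSpec:
    # The builder calls the factory late, after creating the environments.
    return factory.create(env_info)


# CartPole: 4-dimensional observations, 2 discrete actions.
spec = build_model(DefaultMLPFactory(hidden_sizes=(128, 128)), EnvInfo(obs_dim=4, num_actions=2))
print(spec.layer_sizes)  # (4, 128, 128, 2)
```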
Extensibility
The high-level API is designed to be extensible. You can create custom factories (e.g. for your own models or your own environments) by subclassing the appropriate base classes and then use them in the experiment builder.
Suppose we have implemented a torch module, CustomNetwork, that we want to use within our policy.
To use it in the high-level API, we simply need to define a factory for it:
from tianshou.highlevel.env import Environments
from tianshou.highlevel.experiment import DQNExperimentBuilder
from tianshou.highlevel.module.core import TDevice
from tianshou.highlevel.module.intermediate import IntermediateModule, IntermediateModuleFactory
# CustomNetwork is your own torch module; it must expose an output_dim attribute
class CustomNetFactory(IntermediateModuleFactory):
    def __init__(self, hidden_sizes: tuple[int, ...] = (128, 128)):
        self.hidden_sizes = hidden_sizes
    def create_intermediate_module(self, envs: Environments, device: TDevice) -> IntermediateModule:
        obs_shape = envs.get_observation_shape()
        action_shape = envs.get_action_shape()
        # Your custom network creation logic
        net = CustomNetwork(
            obs_shape=obs_shape,
            action_shape=action_shape,
            hidden_sizes=self.hidden_sizes,
        ).to(device)
        return IntermediateModule(net, net.output_dim)
experiment = (
    DQNExperimentBuilder(...)
    .with_model_factory(CustomNetFactory(hidden_sizes=(256, 256)))
    .build()
)
Key Concepts in the Procedural API
Core Components
You manually create and connect …
environments: e.g. using gym.make() and vectorization (DummyVectorEnv, SubprocVectorEnv)
networks: using Net or other PyTorch modules
policies: using algorithm-specific policy classes (e.g., DiscreteQLearningPolicy)
algorithms: using the algorithm classes (e.g., DQN), which take the policy and an optimizer factory
collectors: using Collector to gather experience
buffers: using VectorReplayBuffer or ReplayBuffer
trainers: using the respective trainer class and corresponding parameter class (e.g., OffPolicyTrainer and OffPolicyTrainerParams)
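Regarding buffers: the procedural example creates VectorReplayBuffer(buffer_size, num_training_envs), i.e. one sub-buffer per training environment, so that transitions from different environments are not interleaved within a trajectory. To our understanding, the total capacity is split evenly (rounded up) across the sub-buffers, each of which overwrites its oldest entries once full. The class below is a hypothetical, heavily simplified pure-Python sketch of that idea; the real buffer stores batched arrays and tracks episode boundaries.

```python
from __future__ import annotations

import math
import random


class ToyVectorReplayBuffer:
    """One circular sub-buffer per environment (hypothetical, simplified)."""

    def __init__(self, total_size: int, buffer_num: int) -> None:
        # Split the total capacity evenly, rounding up.
        self.sub_size = math.ceil(total_size / buffer_num)
        self._data: list[list[object]] = [[] for _ in range(buffer_num)]
        self._next = [0] * buffer_num  # next write position per sub-buffer

    def add(self, env_id: int, transition: object) -> None:
        buf = self._data[env_id]
        if len(buf) < self.sub_size:
            buf.append(transition)
        else:
            buf[self._next[env_id]] = transition  # overwrite the oldest entry
        self._next[env_id] = (self._next[env_id] + 1) % self.sub_size

    def sub_buffer(self, env_id: int) -> list[object]:
        return list(self._data[env_id])

    def __len__(self) -> int:
        return sum(len(buf) for buf in self._data)

    def sample(self, batch_size: int, rng: random.Random) -> list[object]:
        # Uniform sampling (with replacement) over all stored transitions.
        stored = [t for buf in self._data for t in buf]
        return [rng.choice(stored) for _ in range(batch_size)]


print(ToyVectorReplayBuffer(20000, 10).sub_size)  # 2000, as in the example
buf = ToyVectorReplayBuffer(total_size=5, buffer_num=2)  # sub-buffers of size 3
for step in range(4):
    buf.add(0, step)  # the fourth transition overwrites the first
print(buf.sub_buffer(0), len(buf))  # [3, 1, 2] 3
```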
Training Loop
The training is executed via algorithm.run_training(), which takes a trainer parameter object.
You can alternatively implement custom training loops (or even your own trainer class) for maximum flexibility.
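A hand-written loop gives you full control over when data is collected, when gradient steps happen, and when to stop. The skeleton below is a hypothetical sketch under simplifying assumptions, not Tianshou's trainer: collect, update, evaluate and stop_fn are placeholder callables that you would back with a training collector, algorithm updates on sampled batches, a test collector, and a stop condition such as the stop_fn of the procedural example.

```python
from __future__ import annotations

from collections.abc import Callable


def run_custom_off_policy_loop(
    collect: Callable[[int], int],
    update: Callable[[], None],
    evaluate: Callable[[], float],
    stop_fn: Callable[[float], bool],
    max_epochs: int,
    epoch_num_steps: int,
    collection_step_num_env_steps: int,
    gradient_steps_per_sample: float,
) -> int:
    """Simplified off-policy loop; returns the number of epochs actually run."""
    for epoch in range(1, max_epochs + 1):
        env_steps = 0
        while env_steps < epoch_num_steps:
            num_collected = collect(collection_step_num_env_steps)  # gather transitions
            env_steps += num_collected
            for _ in range(round(num_collected * gradient_steps_per_sample)):
                update()  # one gradient step on a sampled batch
        # Evaluate once per epoch and stop early if the condition is met.
        if stop_fn(evaluate()):
            return epoch
    return max_epochs


# Dummy components, just to show the control flow.
stats = {"updates": 0}
test_rewards = iter([120.0, 300.0, 480.0, 500.0])


def collect(n: int) -> int:
    return n  # pretend exactly n transitions were collected


def update() -> None:
    stats["updates"] += 1


epochs_run = run_custom_off_policy_loop(
    collect=collect,
    update=update,
    evaluate=lambda: next(test_rewards),
    stop_fn=lambda mean_reward: mean_reward >= 475,  # e.g. CartPole-v1's reward threshold
    max_epochs=10,
    epoch_num_steps=100,
    collection_step_num_env_steps=10,
    gradient_steps_per_sample=0.1,
)
print(epochs_run, stats["updates"])  # 3 30
```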
Additional Resources
high-level API examples: see the examples/ directory (scripts ending in _hl.py)
procedural API examples: see the examples/ directory (scripts without this suffix)