trainer#
Source code: tianshou/trainer.py
This module contains Tianshou’s trainer classes, which orchestrate the training and call upon an RL algorithm’s specific network updating logic to perform the actual gradient updates.
Training is structured as follows (hierarchical glossary):
- epoch: the outermost iteration level of the training loop. Each epoch consists of a number of training steps and one test step (see TrainerParams.max_epochs for a detailed explanation).
    - training step: a training step performs the steps necessary in order to apply a single update of the neural network components as defined by the underlying RL algorithm (Algorithm). This involves the following sub-steps:
        - for online learning algorithms:
            - collection step: collecting environment steps/transitions to be used for training.
            - (potentially) a test step (see below) if the early stopping criterion is satisfied based on the data collected (see OnlineTrainerParams.test_in_training).
        - update step: applying the actual gradient updates using the RL algorithm. The update is based on either:
            - data from only the preceding collection step (on-policy learning),
            - data from the collection step and previously collected data (off-policy learning), or
            - data from the user-provided replay buffer (offline learning).
      For offline learning algorithms, a training step is thus equivalent to an update step.
    - test step: collects test episodes from dedicated test environments, which are used to evaluate the performance of the policy. Optionally, the performance result can be used to determine whether training shall stop early (see TrainerParams.stop_fn).
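The hierarchy above can be sketched as a toy loop. This is only an illustration of the nesting of epochs, training steps and test steps; the function `run_training_sketch` and its arguments are hypothetical and do not mirror the actual trainer implementation.

```python
from typing import Callable, List


def run_training_sketch(
    max_epochs: int,
    training_steps_per_epoch: int,
    stop_fn: Callable[[float], bool],
    test_step: Callable[[int], float],
) -> List[str]:
    """Return the sequence of step names executed by a toy trainer loop."""
    trace: List[str] = []
    for epoch in range(1, max_epochs + 1):
        for _ in range(training_steps_per_epoch):
            trace.append("collection")  # online algorithms only
            trace.append("update")  # gradient updates by the RL algorithm
        score = test_step(epoch)  # one test step closes every epoch
        trace.append("test")
        if stop_fn(score):  # optional early stopping
            break
    return trace


trace = run_training_sketch(
    max_epochs=3,
    training_steps_per_epoch=2,
    stop_fn=lambda score: score >= 2.0,
    test_step=lambda epoch: float(epoch),  # pretend the score grows with the epoch
)
print(trace)
```

In this run the stop criterion is met after the second test step, so the third epoch is never started.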
- class TrainerParams(*, max_epochs: int = 100, epoch_num_steps: int = 30000, test_collector: BaseCollector | None = None, test_step_num_episodes: int = 1, training_fn: collections.abc.Callable[[int, int], None] | None = None, test_fn: collections.abc.Callable[[int, int | None], None] | None = None, stop_fn: collections.abc.Callable[[float], bool] | None = None, compute_score_fn: collections.abc.Callable[[CollectStats], float] | None = None, save_best_fn: collections.abc.Callable[['Algorithm'], None] | None = None, save_checkpoint_fn: collections.abc.Callable[[int, int, int], str] | None = None, resume_from_log: bool = False, multi_agent_return_reduction: collections.abc.Callable[[numpy.ndarray], numpy.ndarray] | None = None, logger: BaseLogger | None = None, verbose: bool = True, show_progress: bool = True)[source]#
Bases: ToStringMixin
- max_epochs: int = 100#
the (maximum) number of epochs to run training for. An epoch is the outermost iteration level and each epoch consists of a number of training steps and one test step, where each training step
    - [for the online case] collects environment steps/transitions (collection step), adding them to the (replay) buffer (see collection_step_num_env_steps and collection_step_num_episodes),
    - performs an update step via the RL algorithm being used, which can involve one or more actual gradient updates, depending on the algorithm,
and the test step collects test_step_num_episodes test episodes in order to evaluate agent performance.
Training may be stopped early if the stop criterion is met (see stop_fn).
For online training, the number of training steps in each epoch is indirectly determined by epoch_num_steps: as many training steps will be performed as are required in order to reach epoch_num_steps total steps in the training environments. Specifically, if the number of transitions collected per step is c (see collection_step_num_env_steps) and epoch_num_steps is set to s, then the number of training steps per epoch is ceil(s / c). Therefore, if max_epochs = e, the total number of environment steps taken during training can be computed as e * ceil(s / c) * c.
For offline training, the number of training steps per epoch is equal to epoch_num_steps.
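As a worked example of the formulas above for the online case, the following sketch plugs in the default values (e = 100, s = 30000) and assumes c = 2048 transitions per collection step, i.e. that the number of training environments divides 2048. The helper names are illustrative and not part of the library.

```python
import math


def training_steps_per_epoch(epoch_num_steps: int, transitions_per_step: int) -> int:
    """Number of training steps per epoch: ceil(s / c)."""
    return math.ceil(epoch_num_steps / transitions_per_step)


def total_env_steps(max_epochs: int, epoch_num_steps: int, transitions_per_step: int) -> int:
    """Total environment steps over the whole training run: e * ceil(s / c) * c."""
    steps = training_steps_per_epoch(epoch_num_steps, transitions_per_step)
    return max_epochs * steps * transitions_per_step


steps = training_steps_per_epoch(30000, 2048)
total = total_env_steps(100, 30000, 2048)
print(steps, total)
```

Because 30000 is not a multiple of 2048, each epoch slightly overshoots epoch_num_steps (15 * 2048 = 30720 environment steps per epoch).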
- epoch_num_steps: int = 30000#
For an online algorithm, this is the total number of environment steps to be collected per epoch, and, for an offline algorithm, it is the total number of training steps to take per epoch. See max_epochs for an explanation of epoch semantics.
- test_collector: BaseCollector | None = None#
the collector to use for test episode collection (test steps); if None, perform no test steps.
- test_step_num_episodes: int = 1#
the number of episodes to collect in each test step.
- training_fn: Callable[[int, int], None] | None = None#
a callback function which is called at the beginning of each training step. It can be used to perform custom additional operations, with the signature
f(num_epoch: int, step_idx: int) -> None.
- test_fn: Callable[[int, int | None], None] | None = None#
a callback function to be called at the beginning of each test step. It can be used to perform custom additional operations, with the signature
f(num_epoch: int, step_idx: int | None) -> None.
- stop_fn: Callable[[float], bool] | None = None#
a callback function with signature f(score: float) -> bool, which is used to decide whether training shall be stopped early based on the score achieved in a test step. The score it receives is computed by the compute_score_fn callback (which defaults to the mean reward if the function is not provided).
Requires test steps to be activated and thus test_collector to be set.
Note: The function is also used when test_in_training is activated (see its docstring).
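A stop criterion is typically a simple threshold on the score. The sketch below builds such a callback; the factory `make_stop_fn` and the threshold value 195.0 are illustrative assumptions.

```python
from typing import Callable


def make_stop_fn(score_threshold: float) -> Callable[[float], bool]:
    """Create a stop_fn that signals early stopping once the score reaches the threshold."""

    def stop_fn(score: float) -> bool:
        return score >= score_threshold

    return stop_fn


stop_fn = make_stop_fn(195.0)
print(stop_fn(120.0), stop_fn(200.0))
```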
- compute_score_fn: Callable[[CollectStats], float] | None = None#
the callback function to use in order to compute the test batch performance score, which is used to determine what the best model is (score is maximized); if None, use the mean reward.
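As an illustration, a score function could penalize unstable policies by subtracting the standard deviation of the episode returns from their mean. The class `StandInCollectStats` and its `returns` field are hypothetical stand-ins used to keep the example self-contained; check the CollectStats class for the attributes it actually provides.

```python
import statistics
from dataclasses import dataclass
from typing import List


@dataclass
class StandInCollectStats:
    """Hypothetical stand-in for CollectStats, holding one return per test episode."""

    returns: List[float]


def compute_score_fn(stats: StandInCollectStats) -> float:
    """Mean return minus one (population) standard deviation; higher is better."""
    return statistics.fmean(stats.returns) - statistics.pstdev(stats.returns)


score = compute_score_fn(StandInCollectStats(returns=[10.0, 12.0, 14.0]))
print(score)
```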
- save_best_fn: Callable[[Algorithm], None] | None = None#
the callback function to call in order to save the best model whenever a new best score (see compute_score_fn) is achieved in a test step. It should have the signature f(algorithm: Algorithm) -> None.
- save_checkpoint_fn: Callable[[int, int, int], str] | None = None#
the callback function with which to save checkpoint data after each training step, which can save whatever data is desired to a file and returns the path of the file. Signature:
f(epoch: int, env_step: int, gradient_step: int) -> str.
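A minimal sketch of a checkpoint callback with this signature. To stay dependency-free, it only writes the three counters to a JSON file in a temporary directory; a real callback would typically also persist model and optimizer state. The directory and file name are arbitrary choices made for this example.

```python
import json
import os
import tempfile

# Arbitrary location chosen for the example.
checkpoint_dir = tempfile.mkdtemp(prefix="trainer_ckpt_")


def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str:
    """Write the training counters to a file and return the file's path."""
    path = os.path.join(checkpoint_dir, "checkpoint.json")
    state = {"epoch": epoch, "env_step": env_step, "gradient_step": gradient_step}
    with open(path, "w", encoding="utf-8") as f:
        json.dump(state, f)
    return path


saved_path = save_checkpoint_fn(epoch=3, env_step=90000, gradient_step=1200)
print(saved_path)
```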
- resume_from_log: bool = False#
whether to load env_step/gradient_step and other metadata from the existing log, which is given in
logger.
- multi_agent_return_reduction: Callable[[ndarray], ndarray] | None = None#
a function with signature
f(returns: np.ndarray with shape (num_episode, agent_num)) -> np.ndarray with shape (num_episode,), which is used in multi-agent RL. To monitor training in the multi-agent RL setting, each episode’s return must be reduced to a single scalar. This function specifies the desired metric, e.g., the return achieved by agent 1 or the average return over all agents.
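The two reductions mentioned as examples can be sketched with NumPy as follows. The function names are illustrative; either function could be passed as the reduction callback.

```python
import numpy as np


def first_agent_return(returns: np.ndarray) -> np.ndarray:
    """Keep only the return of the first agent (column 0) for each episode."""
    return returns[:, 0]


def mean_agent_return(returns: np.ndarray) -> np.ndarray:
    """Average the returns over all agents for each episode."""
    return returns.mean(axis=1)


# 3 episodes, 2 agents: shape (num_episode, agent_num) = (3, 2)
returns = np.array([[1.0, 3.0], [2.0, 6.0], [0.0, 4.0]])
print(first_agent_return(returns).shape, mean_agent_return(returns).shape)
```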
- logger: BaseLogger | None = None#
the logger with which to log statistics during training/testing/updating. To not log anything, use None.
Relevant step types for logger update intervals:
    - update_interval: update step
    - training_interval: env step
    - test_interval: env step
- verbose: bool = True#
whether to print status information to stdout. If set to False, status information will still be logged (provided that logging is enabled via the logging Python module).
- show_progress: bool = True#
whether to display a progress bar during training.
- class OnlineTrainerParams(*, max_epochs: int = 100, epoch_num_steps: int = 30000, test_collector: BaseCollector | None = None, test_step_num_episodes: int = 1, training_fn: collections.abc.Callable[[int, int], None] | None = None, test_fn: collections.abc.Callable[[int, int | None], None] | None = None, stop_fn: collections.abc.Callable[[float], bool] | None = None, compute_score_fn: collections.abc.Callable[[CollectStats], float] | None = None, save_best_fn: collections.abc.Callable[['Algorithm'], None] | None = None, save_checkpoint_fn: collections.abc.Callable[[int, int, int], str] | None = None, resume_from_log: bool = False, multi_agent_return_reduction: collections.abc.Callable[[numpy.ndarray], numpy.ndarray] | None = None, logger: BaseLogger | None = None, verbose: bool = True, show_progress: bool = True, training_collector: BaseCollector, collection_step_num_env_steps: int | None = 2048, collection_step_num_episodes: int | None = None, test_in_training: bool = False)[source]#
Bases: TrainerParams
- training_collector: BaseCollector#
the collector with which to gather new data for training in each training step.
- collection_step_num_env_steps: int | None = 2048#
the number of environment steps/transitions to collect in each collection step before the network update within each training step.
This is mutually exclusive with collection_step_num_episodes, and one of the two must be set.
Note that the exact number can be reached only if this is a multiple of the number of training environments being used, as each training environment will produce the same (non-zero) number of transitions. Specifically, if this is set to n and m training environments are used, then the total number of transitions collected per collection step is ceil(n / m) * m =: c.
See max_epochs for information on the total number of environment steps being collected during training.
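The rounding to a multiple of the number of training environments can be checked with a short calculation. The helper name is illustrative; the formula is the one stated above.

```python
import math


def transitions_per_collection_step(n: int, m: int) -> int:
    """Transitions actually collected per collection step: c = ceil(n / m) * m."""
    return math.ceil(n / m) * m


print(transitions_per_collection_step(2048, 8))   # 8 divides 2048, so n is reached exactly
print(transitions_per_collection_step(2048, 10))  # 10 does not divide 2048, so c exceeds n
```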
- collection_step_num_episodes: int | None = None#
the number of episodes to collect in each collection step before the network update within each training step. If this is set, the number of environment steps collected in each collection step is the sum of the lengths of the episodes collected.
This is mutually exclusive with
collection_step_num_env_steps, and one of the two must be set.
- test_in_training: bool = False#
whether to apply a test step within a training step depending on the early stopping criterion (given by stop_fn) being satisfied based on the data collected within the training step. Specifically, after each collection step, we check whether the early stopping criterion (stop_fn) would be satisfied by the data we collected (provided that at least one episode was indeed completed, such that we can evaluate returns, etc.). If the criterion is satisfied, we perform a full test step (collecting test_step_num_episodes episodes in order to evaluate performance), and if the early stopping criterion is also satisfied based on the test data, we stop training early.
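The two-stage check described above can be sketched as a small function. The name `check_early_stop` and its arguments are hypothetical; the score of the collected data is passed as None when no episode was completed.

```python
from typing import Callable, List, Optional


def check_early_stop(
    collected_score: Optional[float],
    stop_fn: Callable[[float], bool],
    run_test_step: Callable[[], float],
) -> bool:
    """Return True if training should stop early after a collection step."""
    if collected_score is None:  # no completed episode, nothing to evaluate
        return False
    if not stop_fn(collected_score):  # training data does not meet the criterion
        return False
    test_score = run_test_step()  # full test step, only run when warranted
    return stop_fn(test_score)  # stop only if the test data confirms it


test_calls: List[str] = []


def run_test_step() -> float:
    test_calls.append("test")
    return 200.0


def stop_fn(score: float) -> bool:
    return score >= 195.0


results = [
    check_early_stop(None, stop_fn, run_test_step),
    check_early_stop(150.0, stop_fn, run_test_step),
    check_early_stop(210.0, stop_fn, run_test_step),
]
print(results, len(test_calls))
```

Only the last call triggers a test step, which keeps the additional cost of the option low while the policy is still far from the target score.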
- class OnPolicyTrainerParams(*, max_epochs: int = 100, epoch_num_steps: int = 30000, test_collector: BaseCollector | None = None, test_step_num_episodes: int = 1, training_fn: collections.abc.Callable[[int, int], None] | None = None, test_fn: collections.abc.Callable[[int, int | None], None] | None = None, stop_fn: collections.abc.Callable[[float], bool] | None = None, compute_score_fn: collections.abc.Callable[[CollectStats], float] | None = None, save_best_fn: collections.abc.Callable[['Algorithm'], None] | None = None, save_checkpoint_fn: collections.abc.Callable[[int, int, int], str] | None = None, resume_from_log: bool = False, multi_agent_return_reduction: collections.abc.Callable[[numpy.ndarray], numpy.ndarray] | None = None, logger: BaseLogger | None = None, verbose: bool = True, show_progress: bool = True, training_collector: BaseCollector, collection_step_num_env_steps: int | None = 2048, collection_step_num_episodes: int | None = None, test_in_training: bool = False, batch_size: int | None = 64, update_step_num_repetitions: int = 1)[source]#
Bases: OnlineTrainerParams
- batch_size: int | None = 64#
Use mini-batches of this size for gradient updates (causing the gradient to be less accurate, a form of regularization). Set batch_size=None for the full buffer that was collected within the training step to be used for the gradient update (no mini-batching).
- update_step_num_repetitions: int = 1#
controls, within one update step of an on-policy algorithm, the number of times the full collected data is applied for gradient updates, i.e. if the parameter is 5, then the collected data shall be used five times to update the policy within the same update step.
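To illustrate how batch_size and update_step_num_repetitions interact, the following sketch enumerates the mini-batches of one update step. It assumes that each repetition splits the collected data into consecutive mini-batches without shuffling, which is a simplification; the actual batching is performed by the algorithm and may differ. The function `minibatch_slices` is hypothetical.

```python
from typing import Iterator, Optional


def minibatch_slices(
    num_transitions: int,
    batch_size: Optional[int],
    num_repetitions: int,
) -> Iterator[range]:
    """Yield index ranges of the mini-batches used within one update step."""
    if num_transitions <= 0:
        return
    # batch_size=None means the full collected data forms a single batch
    size = num_transitions if batch_size is None else batch_size
    for _ in range(num_repetitions):
        for start in range(0, num_transitions, size):
            yield range(start, min(start + size, num_transitions))


with_minibatches = list(minibatch_slices(2048, batch_size=64, num_repetitions=5))
full_batch = list(minibatch_slices(2048, batch_size=None, num_repetitions=5))
print(len(with_minibatches), len(full_batch))
```

Under these assumptions, 2048 collected transitions with a batch size of 64 and five repetitions yield 32 * 5 = 160 mini-batches, whereas batch_size=None yields one full batch per repetition.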
- class OffPolicyTrainerParams(*, max_epochs: int = 100, epoch_num_steps: int = 30000, test_collector: BaseCollector | None = None, test_step_num_episodes: int = 1, training_fn: collections.abc.Callable[[int, int], None] | None = None, test_fn: collections.abc.Callable[[int, int | None], None] | None = None, stop_fn: collections.abc.Callable[[float], bool] | None = None, compute_score_fn: collections.abc.Callable[[CollectStats], float] | None = None, save_best_fn: collections.abc.Callable[['Algorithm'], None] | None = None, save_checkpoint_fn: collections.abc.Callable[[int, int, int], str] | None = None, resume_from_log: bool = False, multi_agent_return_reduction: collections.abc.Callable[[numpy.ndarray], numpy.ndarray] | None = None, logger: BaseLogger | None = None, verbose: bool = True, show_progress: bool = True, training_collector: BaseCollector, collection_step_num_env_steps: int | None = 2048, collection_step_num_episodes: int | None = None, test_in_training: bool = False, batch_size: int = 64, update_step_num_gradient_steps_per_sample: float = 1.0)[source]#
Bases: OnlineTrainerParams
- batch_size: int = 64#
the number of environment steps/transitions to sample from the buffer for a gradient update.
- update_step_num_gradient_steps_per_sample: float = 1.0#
the number of gradient steps to perform per sample collected (see
collection_step_num_env_steps). Specifically, if this is set to u and the number of samples collected in the preceding collection step is n, then round(u * n) gradient steps will be performed.
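The number of gradient steps per update step follows directly from the formula round(u * n). The helper below is illustrative and assumes Python's built-in round, which rounds exact halves to the nearest even integer.

```python
def num_gradient_steps(gradient_steps_per_sample: float, num_samples_collected: int) -> int:
    """Gradient steps performed in one update step: round(u * n)."""
    return round(gradient_steps_per_sample * num_samples_collected)


print(num_gradient_steps(1.0, 2048))  # one gradient step per collected sample
print(num_gradient_steps(0.5, 2048))  # one gradient step per two collected samples
print(num_gradient_steps(0.25, 10))   # 2.5 is an exact half and rounds to the even value
```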
- class OfflineTrainerParams(*, max_epochs: int = 100, epoch_num_steps: int = 30000, test_collector: BaseCollector | None = None, test_step_num_episodes: int = 1, training_fn: collections.abc.Callable[[int, int], None] | None = None, test_fn: collections.abc.Callable[[int, int | None], None] | None = None, stop_fn: collections.abc.Callable[[float], bool] | None = None, compute_score_fn: collections.abc.Callable[[CollectStats], float] | None = None, save_best_fn: collections.abc.Callable[['Algorithm'], None] | None = None, save_checkpoint_fn: collections.abc.Callable[[int, int, int], str] | None = None, resume_from_log: bool = False, multi_agent_return_reduction: collections.abc.Callable[[numpy.ndarray], numpy.ndarray] | None = None, logger: BaseLogger | None = None, verbose: bool = True, show_progress: bool = True, buffer: ReplayBuffer, batch_size: int = 64)[source]#
Bases: TrainerParams
- buffer: ReplayBuffer#
the replay buffer with environment steps to use as training data for offline learning. This buffer will be pre-processed using the RL algorithm’s pre-processing function (if any) before training.
- batch_size: int = 64#
the number of environment steps/transitions to sample from the buffer for a gradient update.
- class Trainer(algorithm: TAlgorithm, params: TTrainerParams)[source]#
Bases: Generic[TAlgorithm, TTrainerParams], ABC
Base class for trainers in Tianshou, which orchestrate the training process and call upon an RL algorithm’s specific network updating logic to perform the actual gradient updates.
The base class already implements the fundamental epoch logic and fully implements the test step logic, which is common to all trainers. The training step logic is left to be implemented by subclasses.
- reset(reset_collectors: bool = True, reset_collector_buffers: bool = False) -> None[source]#
Initializes the training process.
- Parameters:
reset_collectors – whether to reset the collectors prior to starting the training process. Specifically, this will reset the environments in the collectors (starting new episodes), and the statistics stored in the collector. Whether the contained buffers will be reset/cleared is determined by the reset_collector_buffers parameter.
reset_collector_buffers – whether, for the case where the collectors are reset, to reset/clear the contained buffers as well. This has no effect if reset_collectors is False.
- execute_epoch() -> EpochStats[source]#
- run(reset_collectors: bool = True, reset_collector_buffers: bool = False) -> InfoStats[source]#
Runs the training process with the configuration given at construction.
- Parameters:
reset_collectors – whether to reset the collectors prior to starting the training process. Specifically, this will reset the environments in the collectors (starting new episodes), and the statistics stored in the collector. Whether the contained buffers will be reset/cleared is determined by the reset_collector_buffers parameter.
reset_collector_buffers – whether, for the case where the collectors are reset, to reset/clear the contained buffers as well. This has no effect if reset_collectors is False.
- class OfflineTrainer(algorithm: OfflineAlgorithm, params: OfflineTrainerParams)[source]#
Bases: Trainer[OfflineAlgorithm, OfflineTrainerParams]
An offline trainer, which samples mini-batches from a given buffer and passes them to the algorithm’s update function.
- class OnlineTrainer(algorithm: TAlgorithm, params: TOnlineTrainerParams)[source]#
Bases: Trainer[TAlgorithm, TOnlineTrainerParams], Generic[TAlgorithm, TOnlineTrainerParams], ABC
An online trainer, which collects data from the environment in each training step and uses the collected data to perform an update step, the nature of which is to be defined in subclasses.
- reset(reset_collectors: bool = True, reset_collector_buffers: bool = False) -> None[source]#
Initializes the training process.
- Parameters:
reset_collectors – whether to reset the collectors prior to starting the training process. Specifically, this will reset the environments in the collectors (starting new episodes), and the statistics stored in the collector. Whether the contained buffers will be reset/cleared is determined by the reset_collector_buffers parameter.
reset_collector_buffers – whether, for the case where the collectors are reset, to reset/clear the contained buffers as well. This has no effect if reset_collectors is False.
- class OffPolicyTrainer(algorithm: TAlgorithm, params: TOnlineTrainerParams)[source]#
Bases: OnlineTrainer[OffPolicyAlgorithm, OffPolicyTrainerParams]
An off-policy trainer, which samples mini-batches from the buffer of collected data and passes them to the algorithm’s update function.
The algorithm’s update method is expected not to perform additional mini-batching but just to update model parameters from the received mini-batch.
- class OnPolicyTrainer(algorithm: TAlgorithm, params: TOnlineTrainerParams)[source]#
Bases: OnlineTrainer[OnPolicyAlgorithm, OnPolicyTrainerParams]
An on-policy trainer, which passes the entire buffer to the algorithm’s update method and resets the buffer thereafter.
Note that it is expected that the update method of the algorithm will perform batching when using this trainer.
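The difference in data flow between the off-policy and on-policy trainers can be sketched with a toy buffer. Everything here is a simplification under stated assumptions: the buffer is a plain list, transitions are integers, mini-batches are sampled without replacement, and the function names are hypothetical.

```python
import random
from typing import Callable, List

Transition = int  # toy stand-in for a real transition record


def off_policy_update_step(
    buffer: List[Transition],
    batch_size: int,
    num_gradient_steps: int,
    update: Callable[[List[Transition]], None],
    rng: random.Random,
) -> None:
    """Sample one mini-batch per gradient step; the buffer is kept for later reuse."""
    for _ in range(num_gradient_steps):
        minibatch = rng.sample(buffer, k=min(batch_size, len(buffer)))
        update(minibatch)  # the algorithm updates from exactly this mini-batch


def on_policy_update_step(
    buffer: List[Transition],
    update: Callable[[List[Transition]], None],
) -> None:
    """Pass the entire buffer to the algorithm, then reset the buffer."""
    update(list(buffer))  # the algorithm is expected to do its own batching
    buffer.clear()


off_buffer = list(range(10))
off_batches: List[List[Transition]] = []
off_policy_update_step(off_buffer, 4, 3, off_batches.append, random.Random(0))

on_buffer = list(range(10))
on_batches: List[List[Transition]] = []
on_policy_update_step(on_buffer, on_batches.append)

print(len(off_batches), len(off_buffer), len(on_batches), len(on_buffer))
```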