base
Source code: tianshou/trainer/base.py
- class BaseTrainer(policy: BasePolicy, max_epoch: int, batch_size: int | None, train_collector: BaseCollector | None = None, test_collector: BaseCollector | None = None, buffer: ReplayBuffer | None = None, step_per_epoch: int | None = None, repeat_per_collect: int | None = None, episode_per_test: int | None = None, update_per_step: float = 1.0, step_per_collect: int | None = None, episode_per_collect: int | None = None, train_fn: Callable[[int, int], None] | None = None, test_fn: Callable[[int, int | None], None] | None = None, stop_fn: Callable[[float], bool] | None = None, compute_score_fn: Callable[[CollectStats], float] | None = None, save_best_fn: Callable[[BasePolicy], None] | None = None, save_checkpoint_fn: Callable[[int, int, int], str] | None = None, resume_from_log: bool = False, reward_metric: Callable[[ndarray], ndarray] | None = None, logger: BaseLogger | None = None, verbose: bool = True, show_progress: bool = True, test_in_train: bool = True)
Bases: ABC
An iterator base class for trainers.
Returns an iterator that yields a 3-tuple (epoch, stats, info) of train results on every epoch.
- Parameters:
policy – an instance of the BasePolicy class.
batch_size – the batch size of the sampled data that is fed to the policy network. If None, the whole buffer is used in each gradient step.
train_collector – the collector used for training.
test_collector – the collector used for testing. If None, no testing is performed.
buffer – the replay buffer used for off-policy algorithms or for pre-training. If a policy overrides the process_buffer method, the replay buffer will be pre-processed before training.
max_epoch – the maximum number of epochs for training. Training may finish before reaching max_epoch if stop_fn is set.
step_per_epoch – the number of transitions collected per epoch.
repeat_per_collect – the number of times the policy learns from each collected batch; for example, a value of 2 means the policy learns from each batch twice. Only used in on-policy algorithms.
episode_per_test – the number of episodes for one policy evaluation.
update_per_step – only used in off-policy algorithms. How many gradient steps to perform per step in the environment (i.e., per sample added to the buffer).
step_per_collect – the number of transitions the collector collects before each network update, i.e., in each epoch the trainer repeatedly collects step_per_collect transitions and then updates the policy network.
episode_per_collect – the number of episodes the collector collects before each network update, i.e., in each epoch the trainer repeatedly collects episode_per_collect episodes and then updates the policy network.
train_fn – a hook called at the beginning of training in each epoch. It can be used to perform custom additional operations, with the signature f(num_epoch: int, step_idx: int) -> None.
test_fn – a hook called at the beginning of testing in each epoch. It can be used to perform custom additional operations, with the signature f(num_epoch: int, step_idx: int) -> None.
compute_score_fn – computes the performance score of a test batch to determine whether the current model is the best so far. If not provided, the mean reward is used as the score.
save_best_fn – a hook called when the undiscounted average mean reward in the evaluation phase improves, with the signature f(policy: BasePolicy) -> None.
save_checkpoint_fn – a function that saves the training progress and returns the path of the saved checkpoint, with the signature f(epoch: int, env_step: int, gradient_step: int) -> str; you can save whatever you want.
resume_from_log – resume env_step/gradient_step and other metadata from an existing tensorboard log.
stop_fn – a function with signature f(mean_rewards: float) -> bool that receives the average undiscounted returns of the testing result and returns a boolean indicating whether the goal has been reached.
reward_metric – a function with signature f(rewards: np.ndarray with shape (num_episode, agent_num)) -> np.ndarray with shape (num_episode,), used in multi-agent RL. A single scalar per episode is needed to monitor training in the multi-agent setting; this function specifies the desired metric, e.g., the reward of agent 1 or the average reward over all agents.
logger – a logger that logs statistics during training/testing/updating. To not log anything, keep the default logger.
verbose – whether to print status information to stdout. If set to False, status information will still be logged (provided that logging is enabled via the logging module).
show_progress – whether to display a progress bar when training.
test_in_train – whether to test in the training phase.
- reset(reset_collectors: bool = True, reset_buffer: bool = False) -> None
Initialize or reset the instance to yield a new iterator from zero.
- test_step() -> tuple[CollectStats, bool]
Perform one testing step.
- training_step() -> tuple[CollectStatsBase, TrainingStats | None, bool]
Perform one training iteration.
A training iteration includes collecting data (for online RL), determining whether to stop training, and performing a policy update if the training iteration should continue.
- Returns:
the iteration’s collect stats, training stats, and a flag indicating whether to stop training. If training is to be stopped, no gradient steps will be performed and the training stats will be None.
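The control flow described for training_step can be sketched as follows. This is a simplified illustration rather than the library's implementation: `collect`, `should_stop` and `update` are hypothetical callables standing in for data collection, the stop check and the policy update, and plain dicts stand in for the stats objects.

```python
from __future__ import annotations

from typing import Callable, Optional


def training_step_sketch(
    collect: Callable[[], dict],
    should_stop: Callable[[dict], bool],
    update: Callable[[dict], dict],
) -> tuple[dict, Optional[dict], bool]:
    """Collect data, decide whether to stop, and update only if continuing."""
    collect_stats = collect()
    if should_stop(collect_stats):
        # No gradient steps are performed when training is to be stopped.
        return collect_stats, None, True
    training_stats = update(collect_stats)
    return collect_stats, training_stats, False


update_calls: list[dict] = []


def fake_update(stats: dict) -> dict:
    """Hypothetical update that records its input and returns a dummy loss."""
    update_calls.append(stats)
    return {"loss": 0.5}


def goal_reached(stats: dict) -> bool:
    """Hypothetical stop check with an assumed threshold of 195.0."""
    return stats["mean_return"] >= 195.0


result_continue = training_step_sketch(
    lambda: {"n_collected_steps": 10, "mean_return": 50.0}, goal_reached, fake_update
)
result_stop = training_step_sketch(
    lambda: {"n_collected_steps": 10, "mean_return": 200.0}, goal_reached, fake_update
)
print(result_continue)
print(result_stop)
```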
- abstract policy_update_fn(collect_stats: CollectStatsBase) -> TrainingStats
Policy update function, to be implemented by each concrete trainer.
- Parameters:
collect_stats – provides info about the most recent collection. In the offline case, this will contain stats of the whole dataset.
- run(reset_collectors: bool = True, reset_buffer: bool = False) -> InfoStats
Consume the iterator.
See the itertools recipes: the iterator is consumed at C speed by feeding it entirely into a zero-length deque.
- Parameters:
reset_collectors – whether to reset the collectors prior to starting the training process. Specifically, this will reset the environments in the collectors (starting new episodes), and the statistics stored in the collector. Whether the contained buffers will be reset/cleared is determined by the reset_buffer parameter.
reset_buffer – whether, for the case where the collectors are reset, to reset/clear the contained buffers as well. This has no effect if reset_collectors is False.
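The zero-length deque technique mentioned for run can be shown in isolation. In this sketch, `epochs` is a hypothetical generator standing in for a trainer, and `tracked` exists only to make the consumption visible.

```python
from __future__ import annotations

from collections import deque
from typing import Iterable, Iterator


def epochs(max_epoch: int) -> Iterator[tuple[int, dict, dict]]:
    """Hypothetical stand-in for a trainer: yields (epoch, stats, info)."""
    for epoch in range(1, max_epoch + 1):
        yield epoch, {"loss": 1.0 / epoch}, {"epoch_done": epoch}


seen: list[int] = []


def tracked(items: Iterable[tuple[int, dict, dict]]) -> Iterator[tuple[int, dict, dict]]:
    """Record each epoch number as it passes through."""
    for item in items:
        seen.append(item[0])
        yield item


# A deque with maxlen=0 discards every item it receives, so this exhausts
# the iterator without storing any of its results.
deque(tracked(epochs(3)), maxlen=0)
print(seen)  # [1, 2, 3]
```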
- class OfflineTrainer(policy: BasePolicy, max_epoch: int, batch_size: int | None, train_collector: BaseCollector | None = None, test_collector: BaseCollector | None = None, buffer: ReplayBuffer | None = None, step_per_epoch: int | None = None, repeat_per_collect: int | None = None, episode_per_test: int | None = None, update_per_step: float = 1.0, step_per_collect: int | None = None, episode_per_collect: int | None = None, train_fn: Callable[[int, int], None] | None = None, test_fn: Callable[[int, int | None], None] | None = None, stop_fn: Callable[[float], bool] | None = None, compute_score_fn: Callable[[CollectStats], float] | None = None, save_best_fn: Callable[[BasePolicy], None] | None = None, save_checkpoint_fn: Callable[[int, int, int], str] | None = None, resume_from_log: bool = False, reward_metric: Callable[[ndarray], ndarray] | None = None, logger: BaseLogger | None = None, verbose: bool = True, show_progress: bool = True, test_in_train: bool = True)
Bases: BaseTrainer
Offline trainer, samples mini-batches from buffer and passes them to update.
Uses a buffer directly and usually does not have a collector. An iterator class for offline trainer procedure.
Returns an iterator that yields a 3-tuple (epoch, stats, info) of train results on every epoch.
The “step” in offline trainer means a gradient step.
Example usage:
    trainer = OfflineTrainer(...)
    for epoch, epoch_stat, info in trainer:
        print("Epoch:", epoch)
        print(epoch_stat)
        print(info)
        do_something_with_policy()
        query_something_about_policy()
        make_a_plot_with(epoch_stat)
        display(info)

epoch (int): the epoch number
epoch_stat (dict): a large collection of metrics of the current epoch
info (dict): result returned from gather_info()
You can even iterate on several trainers at the same time:
    trainer1 = OfflineTrainer(...)
    trainer2 = OfflineTrainer(...)
    for result1, result2, ... in zip(trainer1, trainer2, ...):
        compare_results(result1, result2, ...)
- Parameters:
policy – an instance of the BasePolicy class.
batch_size – the batch size of the sampled data that is fed to the policy network. If None, the whole buffer is used in each gradient step.
train_collector – the collector used for training.
test_collector – the collector used for testing. If None, no testing is performed.
buffer – the replay buffer used for off-policy algorithms or for pre-training. If a policy overrides the process_buffer method, the replay buffer will be pre-processed before training.
max_epoch – the maximum number of epochs for training. Training may finish before reaching max_epoch if stop_fn is set.
step_per_epoch – the number of transitions collected per epoch.
repeat_per_collect – the number of times the policy learns from each collected batch; for example, a value of 2 means the policy learns from each batch twice. Only used in on-policy algorithms.
episode_per_test – the number of episodes for one policy evaluation.
update_per_step – only used in off-policy algorithms. How many gradient steps to perform per step in the environment (i.e., per sample added to the buffer).
step_per_collect – the number of transitions the collector collects before each network update, i.e., in each epoch the trainer repeatedly collects step_per_collect transitions and then updates the policy network.
episode_per_collect – the number of episodes the collector collects before each network update, i.e., in each epoch the trainer repeatedly collects episode_per_collect episodes and then updates the policy network.
train_fn – a hook called at the beginning of training in each epoch. It can be used to perform custom additional operations, with the signature f(num_epoch: int, step_idx: int) -> None.
test_fn – a hook called at the beginning of testing in each epoch. It can be used to perform custom additional operations, with the signature f(num_epoch: int, step_idx: int) -> None.
compute_score_fn – computes the performance score of a test batch to determine whether the current model is the best so far. If not provided, the mean reward is used as the score.
save_best_fn – a hook called when the undiscounted average mean reward in the evaluation phase improves, with the signature f(policy: BasePolicy) -> None.
save_checkpoint_fn – a function that saves the training progress and returns the path of the saved checkpoint, with the signature f(epoch: int, env_step: int, gradient_step: int) -> str; you can save whatever you want.
resume_from_log – resume env_step/gradient_step and other metadata from an existing tensorboard log.
stop_fn – a function with signature f(mean_rewards: float) -> bool that receives the average undiscounted returns of the testing result and returns a boolean indicating whether the goal has been reached.
reward_metric – a function with signature f(rewards: np.ndarray with shape (num_episode, agent_num)) -> np.ndarray with shape (num_episode,), used in multi-agent RL. A single scalar per episode is needed to monitor training in the multi-agent setting; this function specifies the desired metric, e.g., the reward of agent 1 or the average reward over all agents.
logger – a logger that logs statistics during training/testing/updating. To not log anything, keep the default logger.
verbose – whether to print status information to stdout. If set to False, status information will still be logged (provided that logging is enabled via the logging module).
show_progress – whether to display a progress bar when training.
test_in_train – whether to test in the training phase.
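A save_checkpoint_fn with the documented signature might look like the sketch below. Writing the counters to a JSON file in a temporary directory is an assumption made to keep the example self-contained; a real hook would typically also store model and optimizer state, and the file name is hypothetical.

```python
from __future__ import annotations

import json
import os
import tempfile

# Hypothetical checkpoint location; any writable directory would do.
checkpoint_dir = tempfile.mkdtemp()


def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str:
    """Save the training counters and return the path of the saved checkpoint."""
    path = os.path.join(checkpoint_dir, "checkpoint.json")
    state = {"epoch": epoch, "env_step": env_step, "gradient_step": gradient_step}
    with open(path, "w", encoding="utf-8") as f:
        json.dump(state, f)
    return path


saved_path = save_checkpoint_fn(3, 30000, 1500)
with open(saved_path, encoding="utf-8") as f:
    restored = json.load(f)
print(saved_path, restored)
```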
- policy_update_fn(collect_stats: CollectStatsBase | None = None) -> TrainingStats
Perform one off-line policy update.
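Drawing the mini-batch for one offline gradient step might look like the following sketch. The buffer is modeled as a plain Python list and `sample_batch` is a hypothetical helper, not part of the library; treating `batch_size=None` as "use the whole buffer" follows the batch_size description above. Sampling without replacement is an assumption of this sketch.

```python
from __future__ import annotations

import random
from typing import Optional, Sequence


def sample_batch(
    buffer: Sequence[int], batch_size: Optional[int], rng: random.Random
) -> list[int]:
    """Return a mini-batch of transitions, or the whole buffer if batch_size is None."""
    if batch_size is None:
        return list(buffer)
    # Assumption: sample without replacement from the fixed dataset.
    return rng.sample(list(buffer), batch_size)


rng = random.Random(0)
buffer = list(range(100))  # integers stand in for stored transitions
mini_batch = sample_batch(buffer, 8, rng)
full_batch = sample_batch(buffer, None, rng)
print(len(mini_batch), len(full_batch))
```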
- class OffpolicyTrainer(policy: BasePolicy, max_epoch: int, batch_size: int | None, train_collector: BaseCollector | None = None, test_collector: BaseCollector | None = None, buffer: ReplayBuffer | None = None, step_per_epoch: int | None = None, repeat_per_collect: int | None = None, episode_per_test: int | None = None, update_per_step: float = 1.0, step_per_collect: int | None = None, episode_per_collect: int | None = None, train_fn: Callable[[int, int], None] | None = None, test_fn: Callable[[int, int | None], None] | None = None, stop_fn: Callable[[float], bool] | None = None, compute_score_fn: Callable[[CollectStats], float] | None = None, save_best_fn: Callable[[BasePolicy], None] | None = None, save_checkpoint_fn: Callable[[int, int, int], str] | None = None, resume_from_log: bool = False, reward_metric: Callable[[ndarray], ndarray] | None = None, logger: BaseLogger | None = None, verbose: bool = True, show_progress: bool = True, test_in_train: bool = True)
Bases: BaseTrainer
Offpolicy trainer, samples mini-batches from buffer and passes them to update.
Note that with this trainer, it is expected that the policy’s learn method does not perform additional mini-batching but just updates params from the received mini-batch.
An iterator class for offpolicy trainer procedure.
Returns an iterator that yields a 3-tuple (epoch, stats, info) of train results on every epoch.
The “step” in offpolicy trainer means an environment step (a.k.a. transition).
Example usage:
    trainer = OffpolicyTrainer(...)
    for epoch, epoch_stat, info in trainer:
        print("Epoch:", epoch)
        print(epoch_stat)
        print(info)
        do_something_with_policy()
        query_something_about_policy()
        make_a_plot_with(epoch_stat)
        display(info)

epoch (int): the epoch number
epoch_stat (dict): a large collection of metrics of the current epoch
info (dict): result returned from gather_info()
You can even iterate on several trainers at the same time:
    trainer1 = OffpolicyTrainer(...)
    trainer2 = OffpolicyTrainer(...)
    for result1, result2, ... in zip(trainer1, trainer2, ...):
        compare_results(result1, result2, ...)
- Parameters:
policy – an instance of the BasePolicy class.
batch_size – the batch size of the sampled data that is fed to the policy network. If None, the whole buffer is used in each gradient step.
train_collector – the collector used for training.
test_collector – the collector used for testing. If None, no testing is performed.
buffer – the replay buffer used for off-policy algorithms or for pre-training. If a policy overrides the process_buffer method, the replay buffer will be pre-processed before training.
max_epoch – the maximum number of epochs for training. Training may finish before reaching max_epoch if stop_fn is set.
step_per_epoch – the number of transitions collected per epoch.
repeat_per_collect – the number of times the policy learns from each collected batch; for example, a value of 2 means the policy learns from each batch twice. Only used in on-policy algorithms.
episode_per_test – the number of episodes for one policy evaluation.
update_per_step – only used in off-policy algorithms. How many gradient steps to perform per step in the environment (i.e., per sample added to the buffer).
step_per_collect – the number of transitions the collector collects before each network update, i.e., in each epoch the trainer repeatedly collects step_per_collect transitions and then updates the policy network.
episode_per_collect – the number of episodes the collector collects before each network update, i.e., in each epoch the trainer repeatedly collects episode_per_collect episodes and then updates the policy network.
train_fn – a hook called at the beginning of training in each epoch. It can be used to perform custom additional operations, with the signature f(num_epoch: int, step_idx: int) -> None.
test_fn – a hook called at the beginning of testing in each epoch. It can be used to perform custom additional operations, with the signature f(num_epoch: int, step_idx: int) -> None.
compute_score_fn – computes the performance score of a test batch to determine whether the current model is the best so far. If not provided, the mean reward is used as the score.
save_best_fn – a hook called when the undiscounted average mean reward in the evaluation phase improves, with the signature f(policy: BasePolicy) -> None.
save_checkpoint_fn – a function that saves the training progress and returns the path of the saved checkpoint, with the signature f(epoch: int, env_step: int, gradient_step: int) -> str; you can save whatever you want.
resume_from_log – resume env_step/gradient_step and other metadata from an existing tensorboard log.
stop_fn – a function with signature f(mean_rewards: float) -> bool that receives the average undiscounted returns of the testing result and returns a boolean indicating whether the goal has been reached.
reward_metric – a function with signature f(rewards: np.ndarray with shape (num_episode, agent_num)) -> np.ndarray with shape (num_episode,), used in multi-agent RL. A single scalar per episode is needed to monitor training in the multi-agent setting; this function specifies the desired metric, e.g., the reward of agent 1 or the average reward over all agents.
logger – a logger that logs statistics during training/testing/updating. To not log anything, keep the default logger.
verbose – whether to print status information to stdout. If set to False, status information will still be logged (provided that logging is enabled via the logging module).
show_progress – whether to display a progress bar when training.
test_in_train – whether to test in the training phase.
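The reward_metric parameter described above reduces a (num_episode, agent_num) array to one scalar per episode. The sketch below shows the two metrics the description mentions, assuming NumPy is available; the function names and the sample rewards are hypothetical.

```python
import numpy as np


def mean_over_agents(rewards: np.ndarray) -> np.ndarray:
    """Average reward over all agents: (num_episode, agent_num) -> (num_episode,)."""
    return rewards.mean(axis=1)


def first_agent_only(rewards: np.ndarray) -> np.ndarray:
    """Reward of the first agent only: (num_episode, agent_num) -> (num_episode,)."""
    return rewards[:, 0]


# Three episodes, two agents (made-up values for illustration).
rewards = np.array([[1.0, 3.0], [2.0, 6.0], [0.0, 4.0]])
print(mean_over_agents(rewards))  # [2. 4. 2.]
print(first_agent_only(rewards))  # [1. 2. 0.]
```

Either function could be passed as reward_metric, depending on which quantity should drive monitoring.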
- policy_update_fn(collect_stats: CollectStatsBase) -> TrainingStats
Perform update_per_step * n_collected_steps gradient steps by sampling mini-batches from the buffer.
- Parameters:
collect_stats – the TrainingStats instance returned by the last gradient step. Some values in it will be replaced by their moving averages.
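The number of gradient steps per update follows from the product update_per_step * n_collected_steps. The sketch below works through a few values; rounding the product to the nearest integer is an assumption of this sketch, and `n_gradient_steps` is a hypothetical helper name.

```python
def n_gradient_steps(update_per_step: float, n_collected_steps: int) -> int:
    """Number of gradient steps for one update (rounding is an assumption)."""
    return round(update_per_step * n_collected_steps)


print(n_gradient_steps(1.0, 10))   # 10: one gradient step per environment step
print(n_gradient_steps(0.1, 10))   # 1: one gradient step per ten environment steps
print(n_gradient_steps(0.25, 16))  # 4
```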
- class OnpolicyTrainer(policy: BasePolicy, max_epoch: int, batch_size: int | None, train_collector: BaseCollector | None = None, test_collector: BaseCollector | None = None, buffer: ReplayBuffer | None = None, step_per_epoch: int | None = None, repeat_per_collect: int | None = None, episode_per_test: int | None = None, update_per_step: float = 1.0, step_per_collect: int | None = None, episode_per_collect: int | None = None, train_fn: Callable[[int, int], None] | None = None, test_fn: Callable[[int, int | None], None] | None = None, stop_fn: Callable[[float], bool] | None = None, compute_score_fn: Callable[[CollectStats], float] | None = None, save_best_fn: Callable[[BasePolicy], None] | None = None, save_checkpoint_fn: Callable[[int, int, int], str] | None = None, resume_from_log: bool = False, reward_metric: Callable[[ndarray], ndarray] | None = None, logger: BaseLogger | None = None, verbose: bool = True, show_progress: bool = True, test_in_train: bool = True)
Bases: BaseTrainer
An iterator class for onpolicy trainer procedure.
Returns an iterator that yields a 3-tuple (epoch, stats, info) of train results on every epoch.
The “step” in onpolicy trainer means an environment step (a.k.a. transition).
Example usage:
    trainer = OnpolicyTrainer(...)
    for epoch, epoch_stat, info in trainer:
        print("Epoch:", epoch)
        print(epoch_stat)
        print(info)
        do_something_with_policy()
        query_something_about_policy()
        make_a_plot_with(epoch_stat)
        display(info)

epoch (int): the epoch number
epoch_stat (dict): a large collection of metrics of the current epoch
info (dict): result returned from gather_info()
You can even iterate on several trainers at the same time:
    trainer1 = OnpolicyTrainer(...)
    trainer2 = OnpolicyTrainer(...)
    for result1, result2, ... in zip(trainer1, trainer2, ...):
        compare_results(result1, result2, ...)
- Parameters:
policy – an instance of the BasePolicy class.
batch_size – the batch size of the sampled data that is fed to the policy network. If None, the whole buffer is used in each gradient step.
train_collector – the collector used for training.
test_collector – the collector used for testing. If None, no testing is performed.
buffer – the replay buffer used for off-policy algorithms or for pre-training. If a policy overrides the process_buffer method, the replay buffer will be pre-processed before training.
max_epoch – the maximum number of epochs for training. Training may finish before reaching max_epoch if stop_fn is set.
step_per_epoch – the number of transitions collected per epoch.
repeat_per_collect – the number of times the policy learns from each collected batch; for example, a value of 2 means the policy learns from each batch twice. Only used in on-policy algorithms.
episode_per_test – the number of episodes for one policy evaluation.
update_per_step – only used in off-policy algorithms. How many gradient steps to perform per step in the environment (i.e., per sample added to the buffer).
step_per_collect – the number of transitions the collector collects before each network update, i.e., in each epoch the trainer repeatedly collects step_per_collect transitions and then updates the policy network.
episode_per_collect – the number of episodes the collector collects before each network update, i.e., in each epoch the trainer repeatedly collects episode_per_collect episodes and then updates the policy network.
train_fn – a hook called at the beginning of training in each epoch. It can be used to perform custom additional operations, with the signature f(num_epoch: int, step_idx: int) -> None.
test_fn – a hook called at the beginning of testing in each epoch. It can be used to perform custom additional operations, with the signature f(num_epoch: int, step_idx: int) -> None.
compute_score_fn – computes the performance score of a test batch to determine whether the current model is the best so far. If not provided, the mean reward is used as the score.
save_best_fn – a hook called when the undiscounted average mean reward in the evaluation phase improves, with the signature f(policy: BasePolicy) -> None.
save_checkpoint_fn – a function that saves the training progress and returns the path of the saved checkpoint, with the signature f(epoch: int, env_step: int, gradient_step: int) -> str; you can save whatever you want.
resume_from_log – resume env_step/gradient_step and other metadata from an existing tensorboard log.
stop_fn – a function with signature f(mean_rewards: float) -> bool that receives the average undiscounted returns of the testing result and returns a boolean indicating whether the goal has been reached.
reward_metric – a function with signature f(rewards: np.ndarray with shape (num_episode, agent_num)) -> np.ndarray with shape (num_episode,), used in multi-agent RL. A single scalar per episode is needed to monitor training in the multi-agent setting; this function specifies the desired metric, e.g., the reward of agent 1 or the average reward over all agents.
logger – a logger that logs statistics during training/testing/updating. To not log anything, keep the default logger.
verbose – whether to print status information to stdout. If set to False, status information will still be logged (provided that logging is enabled via the logging module).
show_progress – whether to display a progress bar when training.
test_in_train – whether to test in the training phase.
- policy_update_fn(result: CollectStatsBase | None = None) -> TrainingStats
Perform one on-policy update by passing the entire buffer to the policy’s update method.
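How repeat_per_collect and batch_size might combine in one on-policy update can be sketched with a small calculation. It assumes, purely for illustration, that the buffer is split into consecutive mini-batches and that each pass over the buffer visits every mini-batch once; `count_gradient_steps` is a hypothetical helper, not a library function.

```python
from __future__ import annotations

import math
from typing import Optional


def count_gradient_steps(
    buffer_size: int, batch_size: Optional[int], repeat_per_collect: int
) -> int:
    """Gradient steps in one update under the assumptions stated above."""
    if batch_size is None:
        # The whole buffer is used as a single batch.
        batches_per_pass = 1
    else:
        batches_per_pass = math.ceil(buffer_size / batch_size)
    return repeat_per_collect * batches_per_pass


print(count_gradient_steps(2000, 64, 2))    # 64: 32 mini-batches, learned twice
print(count_gradient_steps(2000, None, 2))  # 2: the full buffer, learned twice
```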