base#


class BaseTrainer(policy: BasePolicy, max_epoch: int, batch_size: int | None, train_collector: BaseCollector | None = None, test_collector: BaseCollector | None = None, buffer: ReplayBuffer | None = None, step_per_epoch: int | None = None, repeat_per_collect: int | None = None, episode_per_test: int | None = None, update_per_step: float = 1.0, step_per_collect: int | None = None, episode_per_collect: int | None = None, train_fn: Callable[[int, int], None] | None = None, test_fn: Callable[[int, int | None], None] | None = None, stop_fn: Callable[[float], bool] | None = None, compute_score_fn: Callable[[CollectStats], float] | None = None, save_best_fn: Callable[[BasePolicy], None] | None = None, save_checkpoint_fn: Callable[[int, int, int], str] | None = None, resume_from_log: bool = False, reward_metric: Callable[[ndarray], ndarray] | None = None, logger: BaseLogger | None = None, verbose: bool = True, show_progress: bool = True, test_in_train: bool = True)[source]#

Bases: ABC

An iterator base class for trainers.

Returns an iterator that yields a 3-tuple (epoch, stats, info) of train results on every epoch.

Parameters:
  • policy – an instance of the BasePolicy class.

  • batch_size – the batch size of the sample data that is fed into the policy network. If None, the whole buffer is used in each gradient step.

  • train_collector – the collector used for training.

  • test_collector – the collector used for testing. If it’s None, then no testing will be performed.

  • buffer – the replay buffer used for off-policy algorithms or for pre-training. If a policy overrides the process_buffer method, the replay buffer will be pre-processed before training.

  • max_epoch – the maximum number of epochs for training. The training process might be finished before reaching max_epoch if stop_fn is set.

  • step_per_epoch – the number of transitions collected per epoch.

  • repeat_per_collect – the number of times the policy learns from each collected batch of data; for example, setting it to 2 means the policy performs two learning passes over each batch. Only used in on-policy algorithms.

  • episode_per_test – the number of episodes for one policy evaluation.

  • update_per_step – only used in off-policy algorithms. How many gradient steps to perform per step in the environment (i.e., per sample added to the buffer).

  • step_per_collect – the number of transitions the collector gathers before each network update, i.e., in each epoch the trainer repeatedly collects step_per_collect transitions and then performs a policy network update.

  • episode_per_collect – the number of episodes the collector gathers before each network update, i.e., in each epoch the trainer repeatedly collects episode_per_collect episodes and then performs a policy network update.

  • train_fn – a hook called at the beginning of training in each epoch. It can be used to perform custom additional operations, with the signature f(num_epoch: int, step_idx: int) -> None.

  • test_fn – a hook called at the beginning of testing in each epoch. It can be used to perform custom additional operations, with the signature f(num_epoch: int, step_idx: int | None) -> None.

  • compute_score_fn – a function that computes the performance score of a test collect, used to determine whether the current model is the best one so far. If not provided, the mean reward is used as the score.

  • save_best_fn – a hook called when the undiscounted average reward in the evaluation phase improves, with the signature f(policy: BasePolicy) -> None.

  • save_checkpoint_fn – a function to save the training process and return the saved checkpoint path, with the signature f(epoch: int, env_step: int, gradient_step: int) -> str; you can save whatever you want.

  • resume_from_log – resume env_step/gradient_step and other metadata from an existing tensorboard log.

  • stop_fn – a function with signature f(mean_rewards: float) -> bool that receives the average undiscounted returns of the testing result and returns a boolean indicating whether the goal has been reached (see the hook sketch after this list).

  • reward_metric – a function with signature f(rewards: np.ndarray with shape (num_episode, agent_num)) -> np.ndarray with shape (num_episode,), used in multi-agent RL. It must return a single scalar per episode so that training can be monitored in the multi-agent setting; this function specifies the desired metric, e.g., the reward of agent 1 or the average reward over all agents.

  • logger – A logger that logs statistics during training/testing/updating. To not log anything, keep the default logger.

  • verbose – whether to print status information to stdout. If set to False, status information will still be logged (provided that logging is enabled via the logging module).

  • show_progress – whether to display a progress bar when training.

  • test_in_train – whether to test in the training phase.
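
The hook parameters above are plain Python callables. A minimal sketch of what such hooks might look like, following the signatures documented in this list (the reward threshold and save path are illustrative assumptions, not library defaults; import paths are the usual torch and tianshou ones):

import torch
from tianshou.policy import BasePolicy

def train_fn(num_epoch: int, step_idx: int) -> None:
    # called at the start of the training phase in each epoch,
    # e.g. to anneal exploration noise or adjust the learning rate
    print(f"starting epoch {num_epoch} at env step {step_idx}")

def stop_fn(mean_rewards: float) -> bool:
    # stop training once the average undiscounted test return
    # exceeds an (assumed) task-specific threshold
    return mean_rewards >= 195.0

def save_best_fn(policy: BasePolicy) -> None:
    # persist the best policy seen so far; the path is an illustrative choice
    torch.save(policy.state_dict(), "best_policy.pth")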

static gen_doc(learning_type: str) → str[source]#

Generate the docstring for a subclass trainer.

reset(reset_collectors: bool = True, reset_buffer: bool = False) → None[source]#

Initialize or reset the instance to yield a new iterator from zero.

test_step() → tuple[CollectStats, bool][source]#

Perform one testing step.

training_step() → tuple[CollectStatsBase, TrainingStats | None, bool][source]#

Perform one training iteration.

A training iteration includes collecting data (for online RL), determining whether to stop training, and performing a policy update if the training iteration should continue.

Returns:

the iteration’s collect stats, training stats, and a flag indicating whether to stop training. If training is to be stopped, no gradient steps will be performed and the training stats will be None.

abstract policy_update_fn(collect_stats: CollectStatsBase) → TrainingStats[source]#

Policy update function, implemented differently by each trainer subclass.

Parameters:

collect_stats – provides info about the most recent collection. In the offline case, this will contain stats of the whole dataset.

run(reset_prior_to_run: bool = True, reset_buffer: bool = False) → InfoStats[source]#

Consume the iterator.

See the itertools recipes: functions that consume iterators at C speed feed the entire iterator into a zero-length deque.

Parameters:
  • reset_prior_to_run – whether to reset collectors prior to run

  • reset_buffer – only has effect if reset_prior_to_run is True. Then it will also reset the buffer. This is usually not necessary, use with caution.
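
As a usage sketch (assuming trainer is an already constructed concrete trainer such as an OffpolicyTrainer), run() is the non-iterator entry point; exhausting the iterator manually is equivalent in spirit:

from collections import deque

# typical usage: consume the whole training loop and get the aggregated InfoStats
info_stats = trainer.run()

# roughly equivalent: reset manually, then exhaust the iterator at C speed
# by feeding it into a zero-length deque (the itertools recipe mentioned above)
# trainer.reset()
# deque(trainer, maxlen=0)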

class OfflineTrainer(policy: BasePolicy, max_epoch: int, batch_size: int | None, train_collector: BaseCollector | None = None, test_collector: BaseCollector | None = None, buffer: ReplayBuffer | None = None, step_per_epoch: int | None = None, repeat_per_collect: int | None = None, episode_per_test: int | None = None, update_per_step: float = 1.0, step_per_collect: int | None = None, episode_per_collect: int | None = None, train_fn: Callable[[int, int], None] | None = None, test_fn: Callable[[int, int | None], None] | None = None, stop_fn: Callable[[float], bool] | None = None, compute_score_fn: Callable[[CollectStats], float] | None = None, save_best_fn: Callable[[BasePolicy], None] | None = None, save_checkpoint_fn: Callable[[int, int, int], str] | None = None, resume_from_log: bool = False, reward_metric: Callable[[ndarray], ndarray] | None = None, logger: BaseLogger | None = None, verbose: bool = True, show_progress: bool = True, test_in_train: bool = True)[source]#

Bases: BaseTrainer

Offline trainer: samples mini-batches from the buffer and passes them to update.

It uses the buffer directly and usually does not have a collector. This is an iterator class for the offline training procedure.

Returns an iterator that yields a 3-tuple (epoch, stats, info) of train results on every epoch.

The “step” in offline trainer means a gradient step.

Example usage:

trainer = OfflineTrainer(...)
for epoch, epoch_stat, info in trainer:
    print("Epoch:", epoch)
    print(epoch_stat)
    print(info)
    do_something_with_policy()
    query_something_about_policy()
    make_a_plot_with(epoch_stat)
    display(info)
  • epoch int: the epoch number

  • epoch_stat dict: a large collection of metrics of the current epoch

  • info dict: result returned from gather_info()

You can even iterate on several trainers at the same time:

trainer1 = OfflineTrainer(...)
trainer2 = OfflineTrainer(...)
for result1, result2, ... in zip(trainer1, trainer2, ...):
    compare_results(result1, result2, ...)

Parameters:
  • policy – an instance of the BasePolicy class.

  • batch_size – the batch size of the sample data that is fed into the policy network. If None, the whole buffer is used in each gradient step.

  • train_collector – the collector used for training.

  • test_collector – the collector used for testing. If it’s None, then no testing will be performed.

  • buffer – the replay buffer used for off-policy algorithms or for pre-training. If a policy overrides the process_buffer method, the replay buffer will be pre-processed before training.

  • max_epoch – the maximum number of epochs for training. The training process might be finished before reaching max_epoch if stop_fn is set.

  • step_per_epoch – the number of transitions collected per epoch.

  • repeat_per_collect – the number of times the policy learns from each collected batch of data; for example, setting it to 2 means the policy performs two learning passes over each batch. Only used in on-policy algorithms.

  • episode_per_test – the number of episodes for one policy evaluation.

  • update_per_step – only used in off-policy algorithms. How many gradient steps to perform per step in the environment (i.e., per sample added to the buffer).

  • step_per_collect – the number of transitions the collector gathers before each network update, i.e., in each epoch the trainer repeatedly collects step_per_collect transitions and then performs a policy network update.

  • episode_per_collect – the number of episodes the collector gathers before each network update, i.e., in each epoch the trainer repeatedly collects episode_per_collect episodes and then performs a policy network update.

  • train_fn – a hook called at the beginning of training in each epoch. It can be used to perform custom additional operations, with the signature f(num_epoch: int, step_idx: int) -> None.

  • test_fn – a hook called at the beginning of testing in each epoch. It can be used to perform custom additional operations, with the signature f(num_epoch: int, step_idx: int | None) -> None.

  • compute_score_fn – a function that computes the performance score of a test collect, used to determine whether the current model is the best one so far. If not provided, the mean reward is used as the score.

  • save_best_fn – a hook called when the undiscounted average reward in the evaluation phase improves, with the signature f(policy: BasePolicy) -> None.

  • save_checkpoint_fn – a function to save the training process and return the saved checkpoint path, with the signature f(epoch: int, env_step: int, gradient_step: int) -> str; you can save whatever you want.

  • resume_from_log – resume env_step/gradient_step and other metadata from an existing tensorboard log.

  • stop_fn – a function with signature f(mean_rewards: float) -> bool that receives the average undiscounted returns of the testing result and returns a boolean indicating whether the goal has been reached.

  • reward_metric – a function with signature f(rewards: np.ndarray with shape (num_episode, agent_num)) -> np.ndarray with shape (num_episode,), used in multi-agent RL. It must return a single scalar per episode so that training can be monitored in the multi-agent setting; this function specifies the desired metric, e.g., the reward of agent 1 or the average reward over all agents.

  • logger – A logger that logs statistics during training/testing/updating. To not log anything, keep the default logger.

  • verbose – whether to print status information to stdout. If set to False, status information will still be logged (provided that logging is enabled via the logging module).

  • show_progress – whether to display a progress bar when training.

  • test_in_train – whether to test in the training phase.
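
Putting the parameters above together, a minimal construction sketch (assuming the usual tianshou.trainer import path, and that a policy, a pre-filled buffer, and an optional test_collector have already been created; all numbers are illustrative, not recommended defaults):

from tianshou.trainer import OfflineTrainer

trainer = OfflineTrainer(
    policy=policy,                  # a BasePolicy instance
    buffer=buffer,                  # a pre-filled ReplayBuffer; no train_collector is needed
    test_collector=test_collector,  # optional; pass None to skip evaluation
    max_epoch=10,
    step_per_epoch=1000,            # in the offline trainer a "step" is a gradient step
    episode_per_test=10,
    batch_size=64,
)
result = trainer.run()              # returns an InfoStats object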

policy_update_fn(collect_stats: CollectStatsBase | None = None) → TrainingStats[source]#

Perform one off-line policy update.

class OffpolicyTrainer(policy: BasePolicy, max_epoch: int, batch_size: int | None, train_collector: BaseCollector | None = None, test_collector: BaseCollector | None = None, buffer: ReplayBuffer | None = None, step_per_epoch: int | None = None, repeat_per_collect: int | None = None, episode_per_test: int | None = None, update_per_step: float = 1.0, step_per_collect: int | None = None, episode_per_collect: int | None = None, train_fn: Callable[[int, int], None] | None = None, test_fn: Callable[[int, int | None], None] | None = None, stop_fn: Callable[[float], bool] | None = None, compute_score_fn: Callable[[CollectStats], float] | None = None, save_best_fn: Callable[[BasePolicy], None] | None = None, save_checkpoint_fn: Callable[[int, int, int], str] | None = None, resume_from_log: bool = False, reward_metric: Callable[[ndarray], ndarray] | None = None, logger: BaseLogger | None = None, verbose: bool = True, show_progress: bool = True, test_in_train: bool = True)[source]#

Bases: BaseTrainer

Off-policy trainer: samples mini-batches from the buffer and passes them to update.

Note that with this trainer, the policy’s learn method is expected not to perform additional mini-batching but simply to update its parameters from the received mini-batch. This is an iterator class for the off-policy training procedure.

Returns an iterator that yields a 3-tuple (epoch, stats, info) of train results on every epoch.

The “step” in offpolicy trainer means an environment step (a.k.a. transition).

Example usage:

trainer = OffpolicyTrainer(...)
for epoch, epoch_stat, info in trainer:
    print("Epoch:", epoch)
    print(epoch_stat)
    print(info)
    do_something_with_policy()
    query_something_about_policy()
    make_a_plot_with(epoch_stat)
    display(info)
  • epoch int: the epoch number

  • epoch_stat dict: a large collection of metrics of the current epoch

  • info dict: result returned from gather_info()

You can even iterate on several trainers at the same time:

trainer1 = OffpolicyTrainer(...)
trainer2 = OffpolicyTrainer(...)
for result1, result2, ... in zip(trainer1, trainer2, ...):
    compare_results(result1, result2, ...)

Parameters:
  • policy – an instance of the BasePolicy class.

  • batch_size – the batch size of the sample data that is fed into the policy network. If None, the whole buffer is used in each gradient step.

  • train_collector – the collector used for training.

  • test_collector – the collector used for testing. If it’s None, then no testing will be performed.

  • buffer – the replay buffer used for off-policy algorithms or for pre-training. If a policy overrides the process_buffer method, the replay buffer will be pre-processed before training.

  • max_epoch – the maximum number of epochs for training. The training process might be finished before reaching max_epoch if stop_fn is set.

  • step_per_epoch – the number of transitions collected per epoch.

  • repeat_per_collect – the number of times the policy learns from each collected batch of data; for example, setting it to 2 means the policy performs two learning passes over each batch. Only used in on-policy algorithms.

  • episode_per_test – the number of episodes for one policy evaluation.

  • update_per_step – only used in off-policy algorithms. How many gradient steps to perform per step in the environment (i.e., per sample added to the buffer).

  • step_per_collect – the number of transitions the collector gathers before each network update, i.e., in each epoch the trainer repeatedly collects step_per_collect transitions and then performs a policy network update.

  • episode_per_collect – the number of episodes the collector gathers before each network update, i.e., in each epoch the trainer repeatedly collects episode_per_collect episodes and then performs a policy network update.

  • train_fn – a hook called at the beginning of training in each epoch. It can be used to perform custom additional operations, with the signature f(num_epoch: int, step_idx: int) -> None.

  • test_fn – a hook called at the beginning of testing in each epoch. It can be used to perform custom additional operations, with the signature f(num_epoch: int, step_idx: int | None) -> None.

  • compute_score_fn – a function that computes the performance score of a test collect, used to determine whether the current model is the best one so far. If not provided, the mean reward is used as the score.

  • save_best_fn – a hook called when the undiscounted average reward in the evaluation phase improves, with the signature f(policy: BasePolicy) -> None.

  • save_checkpoint_fn – a function to save the training process and return the saved checkpoint path, with the signature f(epoch: int, env_step: int, gradient_step: int) -> str; you can save whatever you want.

  • resume_from_log – resume env_step/gradient_step and other metadata from an existing tensorboard log.

  • stop_fn – a function with signature f(mean_rewards: float) -> bool that receives the average undiscounted returns of the testing result and returns a boolean indicating whether the goal has been reached.

  • reward_metric – a function with signature f(rewards: np.ndarray with shape (num_episode, agent_num)) -> np.ndarray with shape (num_episode,), used in multi-agent RL. It must return a single scalar per episode so that training can be monitored in the multi-agent setting; this function specifies the desired metric, e.g., the reward of agent 1 or the average reward over all agents.

  • logger – A logger that logs statistics during training/testing/updating. To not log anything, keep the default logger.

  • verbose – whether to print status information to stdout. If set to False, status information will still be logged (provided that logging is enabled via the logging module).

  • show_progress – whether to display a progress bar when training.

  • test_in_train – whether to test in the training phase.
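
A minimal construction sketch combining the parameters above (assuming the same tianshou.trainer import path, and that policy, train_collector, and test_collector are already set up; all numbers are illustrative, not recommended defaults):

from tianshou.trainer import OffpolicyTrainer

trainer = OffpolicyTrainer(
    policy=policy,
    train_collector=train_collector,
    test_collector=test_collector,
    max_epoch=10,
    step_per_epoch=10000,   # environment transitions collected per epoch
    step_per_collect=10,    # collect 10 transitions, then update
    episode_per_test=10,
    batch_size=64,
    update_per_step=0.1,    # roughly 1 gradient step per 10 collected transitions
    stop_fn=lambda mean_rewards: mean_rewards >= 195.0,  # illustrative threshold
)
result = trainer.run()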

policy_update_fn(collect_stats: CollectStatsBase) → TrainingStats[source]#

Perform update_per_step * n_collected_steps gradient steps by sampling mini-batches from the buffer.

Parameters:

collect_stats – provides info about the most recent collection; in particular, the number of collected steps determines how many gradient steps are performed. Some values in the returned TrainingStats are replaced by their moving averages.
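
For example (illustrative numbers): with step_per_collect=10 and update_per_step=0.1, each collect adds 10 transitions to the buffer and triggers roughly 0.1 × 10 = 1 gradient step, i.e. about one update per 10 environment steps; with update_per_step=1.0 there would be roughly 10 gradient steps per collect (the exact rounding is an implementation detail).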

class OnpolicyTrainer(policy: BasePolicy, max_epoch: int, batch_size: int | None, train_collector: BaseCollector | None = None, test_collector: BaseCollector | None = None, buffer: ReplayBuffer | None = None, step_per_epoch: int | None = None, repeat_per_collect: int | None = None, episode_per_test: int | None = None, update_per_step: float = 1.0, step_per_collect: int | None = None, episode_per_collect: int | None = None, train_fn: Callable[[int, int], None] | None = None, test_fn: Callable[[int, int | None], None] | None = None, stop_fn: Callable[[float], bool] | None = None, compute_score_fn: Callable[[CollectStats], float] | None = None, save_best_fn: Callable[[BasePolicy], None] | None = None, save_checkpoint_fn: Callable[[int, int, int], str] | None = None, resume_from_log: bool = False, reward_metric: Callable[[ndarray], ndarray] | None = None, logger: BaseLogger | None = None, verbose: bool = True, show_progress: bool = True, test_in_train: bool = True)[source]#

Bases: BaseTrainer

An iterator class for the on-policy training procedure.

Returns an iterator that yields a 3-tuple (epoch, stats, info) of train results on every epoch.

The “step” in onpolicy trainer means an environment step (a.k.a. transition).

Example usage:

trainer = OnpolicyTrainer(...)
for epoch, epoch_stat, info in trainer:
    print("Epoch:", epoch)
    print(epoch_stat)
    print(info)
    do_something_with_policy()
    query_something_about_policy()
    make_a_plot_with(epoch_stat)
    display(info)
  • epoch int: the epoch number

  • epoch_stat dict: a large collection of metrics of the current epoch

  • info dict: result returned from gather_info()

You can even iterate on several trainers at the same time:

trainer1 = OnpolicyTrainer(...)
trainer2 = OnpolicyTrainer(...)
for result1, result2, ... in zip(trainer1, trainer2, ...):
    compare_results(result1, result2, ...)

Parameters:
  • policy – an instance of the BasePolicy class.

  • batch_size – the batch size of the sample data that is fed into the policy network. If None, the whole buffer is used in each gradient step.

  • train_collector – the collector used for training.

  • test_collector – the collector used for testing. If it’s None, then no testing will be performed.

  • buffer – the replay buffer used for off-policy algorithms or for pre-training. If a policy overrides the process_buffer method, the replay buffer will be pre-processed before training.

  • max_epoch – the maximum number of epochs for training. The training process might be finished before reaching max_epoch if stop_fn is set.

  • step_per_epoch – the number of transitions collected per epoch.

  • repeat_per_collect – the number of times the policy learns from each collected batch of data; for example, setting it to 2 means the policy performs two learning passes over each batch. Only used in on-policy algorithms.

  • episode_per_test – the number of episodes for one policy evaluation.

  • update_per_step – only used in off-policy algorithms. How many gradient steps to perform per step in the environment (i.e., per sample added to the buffer).

  • step_per_collect – the number of transitions the collector gathers before each network update, i.e., in each epoch the trainer repeatedly collects step_per_collect transitions and then performs a policy network update.

  • episode_per_collect – the number of episodes the collector gathers before each network update, i.e., in each epoch the trainer repeatedly collects episode_per_collect episodes and then performs a policy network update.

  • train_fn – a hook called at the beginning of training in each epoch. It can be used to perform custom additional operations, with the signature f(num_epoch: int, step_idx: int) -> None.

  • test_fn – a hook called at the beginning of testing in each epoch. It can be used to perform custom additional operations, with the signature f(num_epoch: int, step_idx: int | None) -> None.

  • compute_score_fn – a function that computes the performance score of a test collect, used to determine whether the current model is the best one so far. If not provided, the mean reward is used as the score.

  • save_best_fn – a hook called when the undiscounted average reward in the evaluation phase improves, with the signature f(policy: BasePolicy) -> None.

  • save_checkpoint_fn – a function to save the training process and return the saved checkpoint path, with the signature f(epoch: int, env_step: int, gradient_step: int) -> str; you can save whatever you want.

  • resume_from_log – resume env_step/gradient_step and other metadata from an existing tensorboard log.

  • stop_fn – a function with signature f(mean_rewards: float) -> bool that receives the average undiscounted returns of the testing result and returns a boolean indicating whether the goal has been reached.

  • reward_metric – a function with signature f(rewards: np.ndarray with shape (num_episode, agent_num)) -> np.ndarray with shape (num_episode,), used in multi-agent RL. It must return a single scalar per episode so that training can be monitored in the multi-agent setting; this function specifies the desired metric, e.g., the reward of agent 1 or the average reward over all agents.

  • logger – A logger that logs statistics during training/testing/updating. To not log anything, keep the default logger.

  • verbose – whether to print status information to stdout. If set to False, status information will still be logged (provided that logging is enabled via the logging module).

  • show_progress – whether to display a progress bar when training.

  • test_in_train – whether to test in the training phase.
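
A minimal construction sketch using the parameters above (assuming the same tianshou.trainer import path, and that policy, train_collector, and test_collector already exist; all numbers are illustrative, not recommended defaults):

from tianshou.trainer import OnpolicyTrainer

trainer = OnpolicyTrainer(
    policy=policy,
    train_collector=train_collector,
    test_collector=test_collector,
    max_epoch=10,
    step_per_epoch=50000,       # environment transitions collected per epoch
    repeat_per_collect=4,       # learn from each collected batch 4 times
    episode_per_test=10,
    batch_size=256,
    episode_per_collect=16,     # collect 16 episodes before each policy update
    # alternatively, pass step_per_collect instead of episode_per_collect
)
result = trainer.run()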

policy_update_fn(result: CollectStatsBase | None = None) → TrainingStats[source]#

Perform one on-policy update by passing the entire buffer to the policy’s update method.