tianshou.trainer
-
tianshou.trainer.gather_info(start_time, train_c, test_c, best_reward)
A simple wrapper that gathers information from the collectors.
- Returns
A dictionary with the following keys:
train_step – the total number of steps collected by the training collector;
train_episode – the total number of episodes collected by the training collector;
train_time/collector – the time spent collecting frames in the training collector;
train_time/model – the time spent training models;
train_speed – the training speed (frames per second);
test_step – the total number of steps collected by the test collector;
test_episode – the total number of episodes collected by the test collector;
test_time – the time spent testing;
test_speed – the testing speed (frames per second);
best_reward – the best reward over the test results;
duration – the total elapsed time.
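To make the relationship between these keys concrete, here is a minimal sketch of how such a summary could be assembled. It does not reproduce the library's implementation: `CollectorStats`, `summarize`, the `now` argument, and the attribute names `collect_step`, `collect_episode`, and `collect_time` are hypothetical stand-ins, and the way time is split between collecting and model updates is an assumption. The real function may also format values differently (for example as strings).

```python
import time
from dataclasses import dataclass


@dataclass
class CollectorStats:
    """Hypothetical stand-in for the counters a collector keeps."""
    collect_step: int      # total frames collected so far
    collect_episode: int   # total episodes collected so far
    collect_time: float    # seconds spent collecting


def _rate(count, seconds):
    """Frames per second, guarding against a zero or negative time span."""
    return count / seconds if seconds > 0 else 0.0


def summarize(start_time, train_c, test_c, best_reward, now=None):
    """Build a dictionary with the keys listed above (illustrative only)."""
    now = time.time() if now is None else now
    duration = now - start_time
    # Assumption: time not spent in either collector is attributed to model updates.
    model_time = duration - train_c.collect_time - test_c.collect_time
    return {
        "train_step": train_c.collect_step,
        "train_episode": train_c.collect_episode,
        "train_time/collector": train_c.collect_time,
        "train_time/model": model_time,
        # Assumption: training speed is measured over all non-testing time.
        "train_speed": _rate(train_c.collect_step, duration - test_c.collect_time),
        "test_step": test_c.collect_step,
        "test_episode": test_c.collect_episode,
        "test_time": test_c.collect_time,
        "test_speed": _rate(test_c.collect_step, test_c.collect_time),
        "best_reward": best_reward,
        "duration": duration,
    }
```

Passing `now` explicitly keeps the sketch deterministic for testing; in normal use it defaults to the current wall-clock time.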
-
tianshou.trainer.test_episode(policy, collector, test_fn, epoch, n_episode)
A simple wrapper for testing the policy in a collector.
-
tianshou.trainer.onpolicy_trainer(policy, train_collector, test_collector, max_epoch, step_per_epoch, collect_per_step, repeat_per_collect, episode_per_test, batch_size, train_fn=None, test_fn=None, stop_fn=None, writer=None, log_interval=1, verbose=True, task='', **kwargs)
A wrapper for the on-policy trainer procedure.
- Parameters
policy – an instance of the BasePolicy class.
train_collector (Collector) – the collector used for training.
test_collector (Collector) – the collector used for testing.
max_epoch (int) – the maximum number of epochs for training. The training process might finish before reaching max_epoch.
step_per_epoch (int) – the number of policy network update steps in one epoch.
collect_per_step (int) – the number of frames the collector collects before each network update. In other words, collect some frames, then do one policy network update.
repeat_per_collect (int) – the number of times the policy learns from each collected batch; for example, setting it to 2 means the policy learns from each given batch of data twice.
episode_per_test (int or list of ints) – the number of episodes for one policy evaluation.
batch_size (int) – the batch size of the sampled data that is fed into the policy network.
train_fn (function) – a function that receives the current epoch index and performs some operations at the beginning of training in this epoch.
test_fn (function) – a function that receives the current epoch index and performs some operations at the beginning of testing in this epoch.
stop_fn (function) – a function that receives the average undiscounted return of the test result and returns a boolean indicating whether the goal has been reached.
writer (torch.utils.tensorboard.SummaryWriter) – a TensorBoard SummaryWriter.
log_interval (int) – the log interval of the writer.
verbose (bool) – whether to print the information.
- Returns
See gather_info().
-
tianshou.trainer.offpolicy_trainer(policy, train_collector, test_collector, max_epoch, step_per_epoch, collect_per_step, episode_per_test, batch_size, train_fn=None, test_fn=None, stop_fn=None, writer=None, log_interval=1, verbose=True, task='', **kwargs)
A wrapper for the off-policy trainer procedure.
- Parameters
policy – an instance of the BasePolicy class.
train_collector (Collector) – the collector used for training.
test_collector (Collector) – the collector used for testing.
max_epoch (int) – the maximum number of epochs for training. The training process might finish before reaching max_epoch.
step_per_epoch (int) – the number of policy network update steps in one epoch.
collect_per_step (int) – the number of frames the collector collects before each network update. In other words, collect some frames, then do one policy network update.
episode_per_test – the number of episodes for one policy evaluation.
batch_size (int) – the batch size of the sampled data that is fed into the policy network.
train_fn (function) – a function that receives the current epoch index and performs some operations at the beginning of training in this epoch.
test_fn (function) – a function that receives the current epoch index and performs some operations at the beginning of testing in this epoch.
stop_fn (function) – a function that receives the average undiscounted return of the test result and returns a boolean indicating whether the goal has been reached.
writer (torch.utils.tensorboard.SummaryWriter) – a TensorBoard SummaryWriter.
log_interval (int) – the log interval of the writer.
verbose (bool) – whether to print the information.
- Returns
See gather_info().