trainer
Source code: tianshou/highlevel/trainer.py
- class TrainingContext(policy: TPolicy, envs: Environments, logger: BaseLogger)
Bases: object
- class EpochTrainCallback
Bases: ToStringMixin, ABC
Callback which is called at the beginning of each epoch, i.e. prior to the data collection phase of each epoch.
- abstract callback(epoch: int, env_step: int, context: TrainingContext) → None
- get_trainer_fn(context: TrainingContext) → Callable[[int, int], None]
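The signatures above suggest that get_trainer_fn binds the training context so that the trainer only has to supply (epoch, env_step). The following is a minimal sketch of that pattern, not the library's implementation: StubContext, StubEpochTrainCallback and RecordEpochStart are hypothetical stand-ins, and the closure-based binding is an assumption inferred from the return type.

```python
from abc import ABC, abstractmethod
from collections.abc import Callable
from dataclasses import dataclass, field


@dataclass
class StubContext:
    """Hypothetical stand-in for TrainingContext (policy, envs, logger omitted)."""

    events: list[str] = field(default_factory=list)


class StubEpochTrainCallback(ABC):
    """Stand-in that mirrors the documented EpochTrainCallback interface."""

    @abstractmethod
    def callback(self, epoch: int, env_step: int, context: StubContext) -> None:
        ...

    def get_trainer_fn(self, context: StubContext) -> Callable[[int, int], None]:
        # Assumed behaviour: capture the context in a closure so that the
        # returned function matches Callable[[int, int], None].
        def fn(epoch: int, env_step: int) -> None:
            self.callback(epoch, env_step, context)

        return fn


class RecordEpochStart(StubEpochTrainCallback):
    """Example subclass: records when each epoch starts."""

    def callback(self, epoch: int, env_step: int, context: StubContext) -> None:
        context.events.append(f"epoch={epoch} env_step={env_step}")


ctx = StubContext()
train_fn = RecordEpochStart().get_trainer_fn(ctx)
train_fn(1, 0)
train_fn(2, 5000)
```

Subclasses only implement callback; the two-argument function handed to the trainer is derived from it.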
- class EpochTestCallback
Bases: ToStringMixin, ABC
Callback which is called at the beginning of the test phase of each epoch.
- abstract callback(epoch: int, env_step: int | None, context: TrainingContext) → None
- get_trainer_fn(context: TrainingContext) → Callable[[int, int | None], None]
- class EpochStopCallback
Bases: ToStringMixin, ABC
Callback which is called after the test phase of each epoch in order to determine whether training should stop early.
- abstract should_stop(mean_rewards: float, context: TrainingContext) → bool
Determines whether training should stop.
- Parameters:
mean_rewards – the average undiscounted returns of the testing result
context – the training context
- Returns:
True if the goal has been reached and training should stop, False otherwise
- get_trainer_fn(context: TrainingContext) → Callable[[float], bool]
- class TrainerCallbacks(epoch_train_callback: EpochTrainCallback | None = None, epoch_test_callback: EpochTestCallback | None = None, epoch_stop_callback: EpochStopCallback | None = None)
Bases: object
Container for callbacks used during training.
- epoch_train_callback: EpochTrainCallback | None = None
- epoch_test_callback: EpochTestCallback | None = None
- epoch_stop_callback: EpochStopCallback | None = None
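Since all three slots default to None, a consumer of the container has to handle unset callbacks. The sketch below illustrates one way this could work; StubStopCallback, StubTrainerCallbacks and resolve are hypothetical names, and the rule "an unset slot yields no trainer function" is an assumption rather than documented behaviour.

```python
from dataclasses import dataclass
from typing import Callable, Optional


class StubStopCallback:
    """Hypothetical stand-in for an EpochStopCallback implementation."""

    def should_stop(self, mean_rewards: float, context: object) -> bool:
        return mean_rewards >= 100.0

    def get_trainer_fn(self, context: object) -> Callable[[float], bool]:
        return lambda mean_rewards: self.should_stop(mean_rewards, context)


@dataclass
class StubTrainerCallbacks:
    """Mirrors the three optional slots documented for TrainerCallbacks."""

    epoch_train_callback: Optional[object] = None
    epoch_test_callback: Optional[object] = None
    epoch_stop_callback: Optional[StubStopCallback] = None


def resolve(callback, context):
    # Assumed consumer logic: an unset slot produces no trainer function.
    return callback.get_trainer_fn(context) if callback is not None else None


callbacks = StubTrainerCallbacks(epoch_stop_callback=StubStopCallback())
context = object()  # placeholder for a TrainingContext
train_fn = resolve(callbacks.epoch_train_callback, context)
stop_fn = resolve(callbacks.epoch_stop_callback, context)
```

Here train_fn stays None because no train callback was configured, while stop_fn is a one-argument function of the mean rewards.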
- class EpochTrainCallbackDQNSetEps(eps_test: float)
Bases: EpochTrainCallback
Sets the epsilon value for DQN-based policies at the beginning of the training stage in each epoch.
- callback(epoch: int, env_step: int, context: TrainingContext) → None
- class EpochTrainCallbackDQNEpsLinearDecay(eps_train: float, eps_train_final: float, decay_steps: int = 1000000)
Bases: EpochTrainCallback
Sets the epsilon value for DQN-based policies at the beginning of the training stage in each epoch, decaying it linearly from eps_train to eps_train_final over the first decay_steps environment steps.
- callback(epoch: int, env_step: int, context: TrainingContext) → None
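The linear decay described above can be sketched as a pure function of the environment step. This is an illustration under stated assumptions, not the library's code: linear_decay_eps is a hypothetical helper, and holding epsilon at eps_train_final once decay_steps has passed is assumed from the parameter names.

```python
def linear_decay_eps(
    env_step: int,
    eps_train: float,
    eps_train_final: float,
    decay_steps: int = 1_000_000,
) -> float:
    """Return the epsilon to use at env_step under a linear decay schedule."""
    if env_step <= decay_steps:
        # Fraction of the decay window that has elapsed, in [0, 1].
        fraction = env_step / decay_steps
        return eps_train - fraction * (eps_train - eps_train_final)
    # Assumption: after the decay window, epsilon stays at its final value.
    return eps_train_final


# Epsilon at the start, midpoint, end, and beyond the decay window.
schedule = [
    linear_decay_eps(step, eps_train=1.0, eps_train_final=0.05)
    for step in (0, 500_000, 1_000_000, 2_000_000)
]
```

With eps_train=1.0 and eps_train_final=0.05, the schedule starts at 1.0, passes roughly 0.525 at the midpoint, and settles at 0.05.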
- class EpochTestCallbackDQNSetEps(eps_test: float)
Bases: EpochTestCallback
Sets the epsilon value for DQN-based policies at the beginning of the test stage in each epoch.
- callback(epoch: int, env_step: int | None, context: TrainingContext) → None
- class EpochStopCallbackRewardThreshold(threshold: float | None = None)
Bases: EpochStopCallback
Stops training once the mean rewards exceed the given reward threshold or, if none is given, the threshold specified by the gymnasium environment (i.e. env.spec.reward_threshold).
- Parameters:
threshold – the reward threshold beyond which to stop training. If None, the threshold specified by the environment (i.e. env.spec.reward_threshold) is used.
- should_stop(mean_rewards: float, context: TrainingContext) → bool
Determines whether training should stop.
- Parameters:
mean_rewards – the average undiscounted returns of the testing result
context – the training context
- Returns:
True if the goal has been reached and training should stop, False otherwise
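The threshold fallback described for EpochStopCallbackRewardThreshold can be sketched as follows. StubSpec and this free-standing should_stop are hypothetical stand-ins; the >= comparison and the error raised when no threshold is available are assumptions of the sketch, not confirmed library behaviour.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class StubSpec:
    """Hypothetical stand-in for env.spec; only reward_threshold is modelled."""

    reward_threshold: Optional[float] = None


def should_stop(mean_rewards: float, threshold: Optional[float], spec: StubSpec) -> bool:
    # Prefer the explicitly configured threshold; otherwise use the environment's.
    effective = threshold if threshold is not None else spec.reward_threshold
    if effective is None:
        # Choice made for this sketch only: fail loudly if neither source
        # defines a threshold.
        raise ValueError("no reward threshold configured or provided by the environment")
    # Assumption: reaching the threshold already counts as success.
    return mean_rewards >= effective


spec = StubSpec(reward_threshold=195.0)
stop_with_env_threshold = should_stop(200.0, None, spec)
stop_with_explicit_threshold = should_stop(200.0, 250.0, spec)
```

In this example the environment threshold of 195.0 is met by a mean reward of 200.0, whereas the explicit threshold of 250.0 takes precedence and is not.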