trainer#


class TrainingContext(policy: TPolicy, envs: Environments, logger: BaseLogger)[source]#

Bases: object

class EpochTrainCallback[source]#

Bases: ToStringMixin, ABC

Callback which is called at the beginning of each epoch, i.e. prior to the data collection phase of each epoch.

abstract callback(epoch: int, env_step: int, context: TrainingContext) → None[source]#
get_trainer_fn(context: TrainingContext) → Callable[[int, int], None][source]#
class EpochTestCallback[source]#

Bases: ToStringMixin, ABC

Callback which is called at the beginning of the test phase of each epoch.

abstract callback(epoch: int, env_step: int | None, context: TrainingContext) → None[source]#
get_trainer_fn(context: TrainingContext) → Callable[[int, int | None], None][source]#
class EpochStopCallback[source]#

Bases: ToStringMixin, ABC

Callback which is called after the test phase of each epoch in order to determine whether training should stop early.

abstract should_stop(mean_rewards: float, context: TrainingContext) → bool[source]#

Determines whether training should stop.

Parameters:
  • mean_rewards – the average undiscounted returns of the testing result

  • context – the training context

Returns:

True if the goal has been reached and training should stop, False otherwise

get_trainer_fn(context: TrainingContext) → Callable[[float], bool][source]#
class TrainerCallbacks(epoch_train_callback: EpochTrainCallback | None = None, epoch_test_callback: EpochTestCallback | None = None, epoch_stop_callback: EpochStopCallback | None = None)[source]#

Bases: object

Container for callbacks used during training.

epoch_train_callback: EpochTrainCallback | None = None#
epoch_test_callback: EpochTestCallback | None = None#
epoch_stop_callback: EpochStopCallback | None = None#
class EpochTrainCallbackDQNSetEps(eps_test: float)[source]#

Bases: EpochTrainCallback

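Sets the epsilon value for DQN-based policies at the beginning of the training stage in each epoch. Note that the constructor parameter is named eps_test even though this callback applies at the training stage.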

callback(epoch: int, env_step: int, context: TrainingContext) → None[source]#
class EpochTrainCallbackDQNEpsLinearDecay(eps_train: float, eps_train_final: float, decay_steps: int = 1000000)[source]#

Bases: EpochTrainCallback

Sets the epsilon value for DQN-based policies at the beginning of the training stage in each epoch, decaying it linearly from eps_train to eps_train_final over the first decay_steps steps.

callback(epoch: int, env_step: int, context: TrainingContext) → None[source]#
class EpochTestCallbackDQNSetEps(eps_test: float)[source]#

Bases: EpochTestCallback

Sets the epsilon value for DQN-based policies at the beginning of the test stage in each epoch.

callback(epoch: int, env_step: int | None, context: TrainingContext) → None[source]#
class EpochStopCallbackRewardThreshold(threshold: float | None = None)[source]#

Bases: EpochStopCallback

Stops training once the mean rewards exceed the given reward threshold or the threshold that is specified in the gymnasium environment (i.e. env.spec.reward_threshold).

Parameters:

threshold – the reward threshold beyond which to stop training. If None, the threshold specified by the environment (i.e. env.spec.reward_threshold) is used.

should_stop(mean_rewards: float, context: TrainingContext) → bool[source]#

Determines whether training should stop.

Parameters:
  • mean_rewards – the average undiscounted returns of the testing result

  • context – the training context

Returns:

True if the goal has been reached and training should stop, False otherwise