algorithm#


class AlgorithmFactory(training_config: TTrainingConfig, optim_factory: OptimizerFactoryFactory)[source]#

Bases: ABC, ToStringMixin, Generic[TTrainingConfig]

Factory for creating an Algorithm instance together with its policy, trainer, and collectors.

set_collector_factory(collector_factory: CollectorFactory) → None[source]#
create_train_test_collectors(algorithm: Algorithm, envs: Environments, reset_collectors: bool = True) → tuple[BaseCollector, BaseCollector][source]#

Creates the collectors for training and test environments.

Parameters:
  • algorithm – the algorithm

  • envs – the environments wrapper

  • reset_collectors – whether to reset the collectors before returning them. Setting this to True means that the envs will be reset as well.

Returns:

a tuple of (training_collector, test_collector)

set_policy_wrapper_factory(policy_wrapper_factory: AlgorithmWrapperFactory | None) → None[source]#
set_trainer_callbacks(callbacks: TrainerCallbacks) → None[source]#
create_algorithm(envs: Environments, device: str | device) → Algorithm[source]#
abstract create_trainer(world: World, policy_persistence: PolicyPersistence) → Trainer[source]#
class OnPolicyAlgorithmFactory(training_config: TTrainingConfig, optim_factory: OptimizerFactoryFactory)[source]#

Bases: AlgorithmFactory[OnPolicyTrainingConfig], ABC

create_trainer(world: World, policy_persistence: PolicyPersistence) → OnPolicyTrainer[source]#
class OffPolicyAlgorithmFactory(training_config: TTrainingConfig, optim_factory: OptimizerFactoryFactory)[source]#

Bases: AlgorithmFactory[OffPolicyTrainingConfig], ABC

create_trainer(world: World, policy_persistence: PolicyPersistence) → OffPolicyTrainer[source]#
class ReinforceAlgorithmFactory(params: ReinforceParams, training_config: OnPolicyTrainingConfig, actor_factory: ActorFactory, optim_factory: OptimizerFactoryFactory)[source]#

Bases: OnPolicyAlgorithmFactory

class ActorCriticOnPolicyAlgorithmFactory(params: TActorCriticParams, training_config: OnPolicyTrainingConfig, actor_factory: ActorFactory, critic_factory: CriticFactory, optimizer_factory: OptimizerFactoryFactory)[source]#

Bases: OnPolicyAlgorithmFactory, Generic[TActorCriticParams, TAlgorithm]

class A2CAlgorithmFactory(params: TActorCriticParams, training_config: OnPolicyTrainingConfig, actor_factory: ActorFactory, critic_factory: CriticFactory, optimizer_factory: OptimizerFactoryFactory)[source]#

Bases: ActorCriticOnPolicyAlgorithmFactory[A2CParams, A2C]

class PPOAlgorithmFactory(params: TActorCriticParams, training_config: OnPolicyTrainingConfig, actor_factory: ActorFactory, critic_factory: CriticFactory, optimizer_factory: OptimizerFactoryFactory)[source]#

Bases: ActorCriticOnPolicyAlgorithmFactory[PPOParams, PPO]

class NPGAlgorithmFactory(params: TActorCriticParams, training_config: OnPolicyTrainingConfig, actor_factory: ActorFactory, critic_factory: CriticFactory, optimizer_factory: OptimizerFactoryFactory)[source]#

Bases: ActorCriticOnPolicyAlgorithmFactory[NPGParams, NPG]

class TRPOAlgorithmFactory(params: TActorCriticParams, training_config: OnPolicyTrainingConfig, actor_factory: ActorFactory, critic_factory: CriticFactory, optimizer_factory: OptimizerFactoryFactory)[source]#

Bases: ActorCriticOnPolicyAlgorithmFactory[TRPOParams, TRPO]

class DiscreteCriticOnlyOffPolicyAlgorithmFactory(params: TDiscreteCriticOnlyParams, training_config: OffPolicyTrainingConfig, model_factory: ModuleFactory, optim_factory: OptimizerFactoryFactory)[source]#

Bases: OffPolicyAlgorithmFactory, Generic[TDiscreteCriticOnlyParams, TAlgorithm]

class DQNAlgorithmFactory(params: TDiscreteCriticOnlyParams, training_config: OffPolicyTrainingConfig, model_factory: ModuleFactory, optim_factory: OptimizerFactoryFactory)[source]#

Bases: DiscreteCriticOnlyOffPolicyAlgorithmFactory[DQNParams, DQN]

class IQNAlgorithmFactory(params: TDiscreteCriticOnlyParams, training_config: OffPolicyTrainingConfig, model_factory: ModuleFactory, optim_factory: OptimizerFactoryFactory)[source]#

Bases: DiscreteCriticOnlyOffPolicyAlgorithmFactory[IQNParams, IQN]

class DDPGAlgorithmFactory(params: DDPGParams, training_config: OffPolicyTrainingConfig, actor_factory: ActorFactory, critic_factory: CriticFactory, optim_factory: OptimizerFactoryFactory)[source]#

Bases: OffPolicyAlgorithmFactory

class REDQAlgorithmFactory(params: REDQParams, training_config: OffPolicyTrainingConfig, actor_factory: ActorFactory, critic_ensemble_factory: CriticEnsembleFactory, optim_factory: OptimizerFactoryFactory)[source]#

Bases: OffPolicyAlgorithmFactory

class ActorDualCriticsOffPolicyAlgorithmFactory(params: TActorDualCriticsParams, training_config: OffPolicyTrainingConfig, actor_factory: ActorFactory, critic1_factory: CriticFactory, critic2_factory: CriticFactory, optim_factory: OptimizerFactoryFactory)[source]#

Bases: OffPolicyAlgorithmFactory, Generic[TActorDualCriticsParams, TAlgorithm, TPolicy]

class SACAlgorithmFactory(params: TActorDualCriticsParams, training_config: OffPolicyTrainingConfig, actor_factory: ActorFactory, critic1_factory: CriticFactory, critic2_factory: CriticFactory, optim_factory: OptimizerFactoryFactory)[source]#

Bases: ActorDualCriticsOffPolicyAlgorithmFactory[SACParams, SAC, SACPolicy]

class DiscreteSACAlgorithmFactory(params: TActorDualCriticsParams, training_config: OffPolicyTrainingConfig, actor_factory: ActorFactory, critic1_factory: CriticFactory, critic2_factory: CriticFactory, optim_factory: OptimizerFactoryFactory)[source]#

Bases: ActorDualCriticsOffPolicyAlgorithmFactory[DiscreteSACParams, DiscreteSAC, DiscreteSACPolicy]

class TD3AlgorithmFactory(params: TActorDualCriticsParams, training_config: OffPolicyTrainingConfig, actor_factory: ActorFactory, critic1_factory: CriticFactory, critic2_factory: CriticFactory, optim_factory: OptimizerFactoryFactory)[source]#

Bases: ActorDualCriticsOffPolicyAlgorithmFactory[TD3Params, TD3, ContinuousDeterministicPolicy]