algorithm
Source code: tianshou/highlevel/algorithm.py
- class AlgorithmFactory(training_config: TTrainingConfig, optim_factory: OptimizerFactoryFactory)[source]#
Bases:
ABC, ToStringMixin, Generic[TTrainingConfig]
Factory for the creation of an Algorithm instance, its policy and trainer, as well as collectors.
- set_collector_factory(collector_factory: CollectorFactory) -> None[source]#
- create_train_test_collectors(algorithm: Algorithm, envs: Environments, reset_collectors: bool = True) -> tuple[BaseCollector, BaseCollector][source]#
Creates the collectors for training and test environments.
- Parameters:
algorithm – the algorithm
envs – the environments wrapper
reset_collectors – Whether to reset the collectors before returning them. Setting to True means that the envs will be reset as well.
- Returns:
a tuple of (training_collector, test_collector)
- set_policy_wrapper_factory(policy_wrapper_factory: AlgorithmWrapperFactory | None) -> None[source]#
- set_trainer_callbacks(callbacks: TrainerCallbacks) -> None[source]#
- create_algorithm(envs: Environments, device: str | device) -> Algorithm[source]#
- abstract create_trainer(world: World, policy_persistence: PolicyPersistence) -> Trainer[source]#
- class OnPolicyAlgorithmFactory(training_config: TTrainingConfig, optim_factory: OptimizerFactoryFactory)[source]#
Bases:
AlgorithmFactory[OnPolicyTrainingConfig], ABC
- create_trainer(world: World, policy_persistence: PolicyPersistence) -> OnPolicyTrainer[source]#
- class OffPolicyAlgorithmFactory(training_config: TTrainingConfig, optim_factory: OptimizerFactoryFactory)[source]#
Bases:
AlgorithmFactory[OffPolicyTrainingConfig], ABC
- create_trainer(world: World, policy_persistence: PolicyPersistence) -> OffPolicyTrainer[source]#
- class ReinforceAlgorithmFactory(params: ReinforceParams, training_config: OnPolicyTrainingConfig, actor_factory: ActorFactory, optim_factory: OptimizerFactoryFactory)[source]#
Bases:
OnPolicyAlgorithmFactory
- class ActorCriticOnPolicyAlgorithmFactory(params: TActorCriticParams, training_config: OnPolicyTrainingConfig, actor_factory: ActorFactory, critic_factory: CriticFactory, optimizer_factory: OptimizerFactoryFactory)[source]#
Bases:
OnPolicyAlgorithmFactory, Generic[TActorCriticParams, TAlgorithm]
- class A2CAlgorithmFactory(params: TActorCriticParams, training_config: OnPolicyTrainingConfig, actor_factory: ActorFactory, critic_factory: CriticFactory, optimizer_factory: OptimizerFactoryFactory)[source]#
- class PPOAlgorithmFactory(params: TActorCriticParams, training_config: OnPolicyTrainingConfig, actor_factory: ActorFactory, critic_factory: CriticFactory, optimizer_factory: OptimizerFactoryFactory)[source]#
- class NPGAlgorithmFactory(params: TActorCriticParams, training_config: OnPolicyTrainingConfig, actor_factory: ActorFactory, critic_factory: CriticFactory, optimizer_factory: OptimizerFactoryFactory)[source]#
- class TRPOAlgorithmFactory(params: TActorCriticParams, training_config: OnPolicyTrainingConfig, actor_factory: ActorFactory, critic_factory: CriticFactory, optimizer_factory: OptimizerFactoryFactory)[source]#
Bases:
ActorCriticOnPolicyAlgorithmFactory[TRPOParams, TRPO]
- class DiscreteCriticOnlyOffPolicyAlgorithmFactory(params: TDiscreteCriticOnlyParams, training_config: OffPolicyTrainingConfig, model_factory: ModuleFactory, optim_factory: OptimizerFactoryFactory)[source]#
Bases:
OffPolicyAlgorithmFactory, Generic[TDiscreteCriticOnlyParams, TAlgorithm]
- class DQNAlgorithmFactory(params: TDiscreteCriticOnlyParams, training_config: OffPolicyTrainingConfig, model_factory: ModuleFactory, optim_factory: OptimizerFactoryFactory)[source]#
Bases:
DiscreteCriticOnlyOffPolicyAlgorithmFactory[DQNParams, DQN]
- class IQNAlgorithmFactory(params: TDiscreteCriticOnlyParams, training_config: OffPolicyTrainingConfig, model_factory: ModuleFactory, optim_factory: OptimizerFactoryFactory)[source]#
Bases:
DiscreteCriticOnlyOffPolicyAlgorithmFactory[IQNParams, IQN]
- class DDPGAlgorithmFactory(params: DDPGParams, training_config: OffPolicyTrainingConfig, actor_factory: ActorFactory, critic_factory: CriticFactory, optim_factory: OptimizerFactoryFactory)[source]#
Bases:
OffPolicyAlgorithmFactory
- class REDQAlgorithmFactory(params: REDQParams, training_config: OffPolicyTrainingConfig, actor_factory: ActorFactory, critic_ensemble_factory: CriticEnsembleFactory, optim_factory: OptimizerFactoryFactory)[source]#
Bases:
OffPolicyAlgorithmFactory
- class ActorDualCriticsOffPolicyAlgorithmFactory(params: TActorDualCriticsParams, training_config: OffPolicyTrainingConfig, actor_factory: ActorFactory, critic1_factory: CriticFactory, critic2_factory: CriticFactory, optim_factory: OptimizerFactoryFactory)[source]#
Bases:
OffPolicyAlgorithmFactory, Generic[TActorDualCriticsParams, TAlgorithm, TPolicy]
- class SACAlgorithmFactory(params: TActorDualCriticsParams, training_config: OffPolicyTrainingConfig, actor_factory: ActorFactory, critic1_factory: CriticFactory, critic2_factory: CriticFactory, optim_factory: OptimizerFactoryFactory)[source]#
Bases:
ActorDualCriticsOffPolicyAlgorithmFactory[SACParams, SAC, SACPolicy]
- class DiscreteSACAlgorithmFactory(params: TActorDualCriticsParams, training_config: OffPolicyTrainingConfig, actor_factory: ActorFactory, critic1_factory: CriticFactory, critic2_factory: CriticFactory, optim_factory: OptimizerFactoryFactory)[source]#
Bases:
ActorDualCriticsOffPolicyAlgorithmFactory[DiscreteSACParams, DiscreteSAC, DiscreteSACPolicy]
- class TD3AlgorithmFactory(params: TActorDualCriticsParams, training_config: OffPolicyTrainingConfig, actor_factory: ActorFactory, critic1_factory: CriticFactory, critic2_factory: CriticFactory, optim_factory: OptimizerFactoryFactory)[source]#
Bases:
ActorDualCriticsOffPolicyAlgorithmFactory[TD3Params, TD3, ContinuousDeterministicPolicy]