agent#


class AgentFactory(sampling_config: SamplingConfig, optim_factory: OptimizerFactory)[source]#

Bases: ABC, ToStringMixin

Factory for the creation of an agent’s policy, its trainer as well as collectors.

create_train_test_collector(policy: BasePolicy, envs: Environments, reset_collectors: bool = True) → tuple[BaseCollector, BaseCollector][source]#
Parameters:
  • policy – the policy to be used by the created collectors.

  • envs – the environments in which the created collectors operate.

  • reset_collectors – Whether to reset the collectors before returning them. Setting this to True means that the envs will be reset as well.

Returns: a tuple of two collectors: the train collector and the test collector.

set_policy_wrapper_factory(policy_wrapper_factory: PolicyWrapperFactory | None) → None[source]#
set_trainer_callbacks(callbacks: TrainerCallbacks) → None[source]#
create_policy(envs: Environments, device: str | torch.device) → BasePolicy[source]#
abstract create_trainer(world: World, policy_persistence: PolicyPersistence) → BaseTrainer[source]#
class OnPolicyAgentFactory(sampling_config: SamplingConfig, optim_factory: OptimizerFactory)[source]#

Bases: AgentFactory, ABC

create_trainer(world: World, policy_persistence: PolicyPersistence) → OnpolicyTrainer[source]#
class OffPolicyAgentFactory(sampling_config: SamplingConfig, optim_factory: OptimizerFactory)[source]#

Bases: AgentFactory, ABC

create_trainer(world: World, policy_persistence: PolicyPersistence) → OffpolicyTrainer[source]#
class RandomActionAgentFactory(sampling_config: SamplingConfig, optim_factory: OptimizerFactory)[source]#

Bases: OnPolicyAgentFactory

class PGAgentFactory(params: PGParams, sampling_config: SamplingConfig, actor_factory: ActorFactory, optim_factory: OptimizerFactory)[source]#

Bases: OnPolicyAgentFactory

class ActorCriticAgentFactory(params: TActorCriticParams, sampling_config: SamplingConfig, actor_factory: ActorFactory, critic_factory: CriticFactory, optimizer_factory: OptimizerFactory)[source]#

Bases: Generic[TActorCriticParams, TPolicy], OnPolicyAgentFactory, ABC

create_actor_critic_module_opt(envs: Environments, device: str | torch.device, lr: float) → ActorCriticOpt[source]#
class A2CAgentFactory(params: TActorCriticParams, sampling_config: SamplingConfig, actor_factory: ActorFactory, critic_factory: CriticFactory, optimizer_factory: OptimizerFactory)[source]#

Bases: ActorCriticAgentFactory[A2CParams, A2CPolicy]

class PPOAgentFactory(params: TActorCriticParams, sampling_config: SamplingConfig, actor_factory: ActorFactory, critic_factory: CriticFactory, optimizer_factory: OptimizerFactory)[source]#

Bases: ActorCriticAgentFactory[PPOParams, PPOPolicy]

class NPGAgentFactory(params: TActorCriticParams, sampling_config: SamplingConfig, actor_factory: ActorFactory, critic_factory: CriticFactory, optimizer_factory: OptimizerFactory)[source]#

Bases: ActorCriticAgentFactory[NPGParams, NPGPolicy]

class TRPOAgentFactory(params: TActorCriticParams, sampling_config: SamplingConfig, actor_factory: ActorFactory, critic_factory: CriticFactory, optimizer_factory: OptimizerFactory)[source]#

Bases: ActorCriticAgentFactory[TRPOParams, TRPOPolicy]

class DiscreteCriticOnlyAgentFactory(params: TDiscreteCriticOnlyParams, sampling_config: SamplingConfig, model_factory: ModuleFactory, optim_factory: OptimizerFactory)[source]#

Bases: OffPolicyAgentFactory, Generic[TDiscreteCriticOnlyParams, TPolicy]

class DQNAgentFactory(params: TDiscreteCriticOnlyParams, sampling_config: SamplingConfig, model_factory: ModuleFactory, optim_factory: OptimizerFactory)[source]#

Bases: DiscreteCriticOnlyAgentFactory[DQNParams, DQNPolicy]

class IQNAgentFactory(params: TDiscreteCriticOnlyParams, sampling_config: SamplingConfig, model_factory: ModuleFactory, optim_factory: OptimizerFactory)[source]#

Bases: DiscreteCriticOnlyAgentFactory[IQNParams, IQNPolicy]

class DDPGAgentFactory(params: DDPGParams, sampling_config: SamplingConfig, actor_factory: ActorFactory, critic_factory: CriticFactory, optim_factory: OptimizerFactory)[source]#

Bases: OffPolicyAgentFactory

class REDQAgentFactory(params: REDQParams, sampling_config: SamplingConfig, actor_factory: ActorFactory, critic_ensemble_factory: CriticEnsembleFactory, optim_factory: OptimizerFactory)[source]#

Bases: OffPolicyAgentFactory

class ActorDualCriticsAgentFactory(params: TActorDualCriticsParams, sampling_config: SamplingConfig, actor_factory: ActorFactory, critic1_factory: CriticFactory, critic2_factory: CriticFactory, optim_factory: OptimizerFactory)[source]#

Bases: OffPolicyAgentFactory, Generic[TActorDualCriticsParams, TPolicy], ABC

class SACAgentFactory(params: TActorDualCriticsParams, sampling_config: SamplingConfig, actor_factory: ActorFactory, critic1_factory: CriticFactory, critic2_factory: CriticFactory, optim_factory: OptimizerFactory)[source]#

Bases: ActorDualCriticsAgentFactory[SACParams, SACPolicy]

class DiscreteSACAgentFactory(params: TActorDualCriticsParams, sampling_config: SamplingConfig, actor_factory: ActorFactory, critic1_factory: CriticFactory, critic2_factory: CriticFactory, optim_factory: OptimizerFactory)[source]#

Bases: ActorDualCriticsAgentFactory[DiscreteSACParams, DiscreteSACPolicy]

class TD3AgentFactory(params: TActorDualCriticsParams, sampling_config: SamplingConfig, actor_factory: ActorFactory, critic1_factory: CriticFactory, critic2_factory: CriticFactory, optim_factory: OptimizerFactory)[source]#

Bases: ActorDualCriticsAgentFactory[TD3Params, TD3Policy]