agent
Source code: tianshou/highlevel/agent.py
- class AgentFactory(sampling_config: SamplingConfig, optim_factory: OptimizerFactory)
Bases: ABC, ToStringMixin
Factory for the creation of an agent’s policy, its trainer as well as collectors.
- create_train_test_collector(policy: BasePolicy, envs: Environments, reset_collectors: bool = True) -> tuple[BaseCollector, BaseCollector]
- Parameters:
policy – the policy for which the collectors are created.
envs – the environments the collectors operate on.
reset_collectors – whether to reset the collectors before returning them. Setting this to True also resets the environments.
- Returns:
a tuple of two collectors: the training collector and the test collector.
- set_policy_wrapper_factory(policy_wrapper_factory: PolicyWrapperFactory | None) -> None
- set_trainer_callbacks(callbacks: TrainerCallbacks) -> None
- create_policy(envs: Environments, device: str | device) -> BasePolicy
- abstract create_trainer(world: World, policy_persistence: PolicyPersistence) -> BaseTrainer
- class OnPolicyAgentFactory(sampling_config: SamplingConfig, optim_factory: OptimizerFactory)
Bases: AgentFactory, ABC
- create_trainer(world: World, policy_persistence: PolicyPersistence) -> OnpolicyTrainer
- class OffPolicyAgentFactory(sampling_config: SamplingConfig, optim_factory: OptimizerFactory)
Bases: AgentFactory, ABC
- create_trainer(world: World, policy_persistence: PolicyPersistence) -> OffpolicyTrainer
- class RandomActionAgentFactory(sampling_config: SamplingConfig, optim_factory: OptimizerFactory)
Bases: OnPolicyAgentFactory
- class PGAgentFactory(params: PGParams, sampling_config: SamplingConfig, actor_factory: ActorFactory, optim_factory: OptimizerFactory)
Bases: OnPolicyAgentFactory
- class ActorCriticAgentFactory(params: TActorCriticParams, sampling_config: SamplingConfig, actor_factory: ActorFactory, critic_factory: CriticFactory, optimizer_factory: OptimizerFactory)
Bases: Generic[TActorCriticParams, TPolicy], OnPolicyAgentFactory, ABC
- create_actor_critic_module_opt(envs: Environments, device: str | device, lr: float) -> ActorCriticOpt
- class A2CAgentFactory(params: TActorCriticParams, sampling_config: SamplingConfig, actor_factory: ActorFactory, critic_factory: CriticFactory, optimizer_factory: OptimizerFactory)
Bases: ActorCriticAgentFactory[A2CParams, A2CPolicy]
- class PPOAgentFactory(params: TActorCriticParams, sampling_config: SamplingConfig, actor_factory: ActorFactory, critic_factory: CriticFactory, optimizer_factory: OptimizerFactory)
Bases: ActorCriticAgentFactory[PPOParams, PPOPolicy]
- class NPGAgentFactory(params: TActorCriticParams, sampling_config: SamplingConfig, actor_factory: ActorFactory, critic_factory: CriticFactory, optimizer_factory: OptimizerFactory)
Bases: ActorCriticAgentFactory[NPGParams, NPGPolicy]
- class TRPOAgentFactory(params: TActorCriticParams, sampling_config: SamplingConfig, actor_factory: ActorFactory, critic_factory: CriticFactory, optimizer_factory: OptimizerFactory)
Bases: ActorCriticAgentFactory[TRPOParams, TRPOPolicy]
- class DiscreteCriticOnlyAgentFactory(params: TDiscreteCriticOnlyParams, sampling_config: SamplingConfig, model_factory: ModuleFactory, optim_factory: OptimizerFactory)
Bases: OffPolicyAgentFactory, Generic[TDiscreteCriticOnlyParams, TPolicy]
- class DQNAgentFactory(params: TDiscreteCriticOnlyParams, sampling_config: SamplingConfig, model_factory: ModuleFactory, optim_factory: OptimizerFactory)
- class IQNAgentFactory(params: TDiscreteCriticOnlyParams, sampling_config: SamplingConfig, model_factory: ModuleFactory, optim_factory: OptimizerFactory)
- class DDPGAgentFactory(params: DDPGParams, sampling_config: SamplingConfig, actor_factory: ActorFactory, critic_factory: CriticFactory, optim_factory: OptimizerFactory)
Bases: OffPolicyAgentFactory
- class REDQAgentFactory(params: REDQParams, sampling_config: SamplingConfig, actor_factory: ActorFactory, critic_ensemble_factory: CriticEnsembleFactory, optim_factory: OptimizerFactory)
Bases: OffPolicyAgentFactory
- class ActorDualCriticsAgentFactory(params: TActorDualCriticsParams, sampling_config: SamplingConfig, actor_factory: ActorFactory, critic1_factory: CriticFactory, critic2_factory: CriticFactory, optim_factory: OptimizerFactory)
Bases: OffPolicyAgentFactory, Generic[TActorDualCriticsParams, TPolicy], ABC
- class SACAgentFactory(params: TActorDualCriticsParams, sampling_config: SamplingConfig, actor_factory: ActorFactory, critic1_factory: CriticFactory, critic2_factory: CriticFactory, optim_factory: OptimizerFactory)
- class DiscreteSACAgentFactory(params: TActorDualCriticsParams, sampling_config: SamplingConfig, actor_factory: ActorFactory, critic1_factory: CriticFactory, critic2_factory: CriticFactory, optim_factory: OptimizerFactory)
Bases: ActorDualCriticsAgentFactory[DiscreteSACParams, DiscreteSACPolicy]
- class TD3AgentFactory(params: TActorDualCriticsParams, sampling_config: SamplingConfig, actor_factory: ActorFactory, critic1_factory: CriticFactory, critic2_factory: CriticFactory, optim_factory: OptimizerFactory)