Source code for tianshou.highlevel.algorithm

import logging
import typing
from abc import ABC, abstractmethod
from typing import Any, Generic, TypeVar, cast

import gymnasium
import torch
from sensai.util.string import ToStringMixin

from tianshou.algorithm import (
    A2C,
    DDPG,
    DQN,
    IQN,
    NPG,
    PPO,
    REDQ,
    SAC,
    TD3,
    TRPO,
    Algorithm,
    DiscreteSAC,
    Reinforce,
)
from tianshou.algorithm.algorithm_base import (
    OffPolicyAlgorithm,
    OnPolicyAlgorithm,
    Policy,
)
from tianshou.algorithm.modelfree.ddpg import ContinuousDeterministicPolicy
from tianshou.algorithm.modelfree.discrete_sac import DiscreteSACPolicy
from tianshou.algorithm.modelfree.dqn import DiscreteQLearningPolicy
from tianshou.algorithm.modelfree.iqn import IQNPolicy
from tianshou.algorithm.modelfree.redq import REDQPolicy
from tianshou.algorithm.modelfree.reinforce import ProbabilisticActorPolicy
from tianshou.algorithm.modelfree.sac import SACPolicy
from tianshou.data import ReplayBuffer, VectorReplayBuffer
from tianshou.data.collector import BaseCollector
from tianshou.highlevel.config import (
    OffPolicyTrainingConfig,
    OnPolicyTrainingConfig,
    TrainingConfig,
)
from tianshou.highlevel.env import Environments
from tianshou.highlevel.module.actor import (
    ActorFactory,
)
from tianshou.highlevel.module.core import (
    ModuleFactory,
    TDevice,
)
from tianshou.highlevel.module.critic import CriticEnsembleFactory, CriticFactory
from tianshou.highlevel.params.algorithm_params import (
    A2CParams,
    DDPGParams,
    DiscreteSACParams,
    DQNParams,
    IQNParams,
    NPGParams,
    Params,
    ParamsMixinActorAndDualCritics,
    ParamsMixinSingleModel,
    ParamTransformerData,
    PPOParams,
    REDQParams,
    ReinforceParams,
    SACParams,
    TD3Params,
    TRPOParams,
)
from tianshou.highlevel.params.algorithm_wrapper import AlgorithmWrapperFactory
from tianshou.highlevel.params.collector import (
    CollectorFactory,
    CollectorFactoryDefault,
)
from tianshou.highlevel.params.optim import OptimizerFactoryFactory
from tianshou.highlevel.persistence import PolicyPersistence
from tianshou.highlevel.trainer import TrainerCallbacks, TrainingContext
from tianshou.highlevel.world import World
from tianshou.trainer import (
    OffPolicyTrainer,
    OffPolicyTrainerParams,
    OnPolicyTrainer,
    OnPolicyTrainerParams,
    Trainer,
)
from tianshou.utils.net.discrete import DiscreteActor

# String keys; presumably used to index persisted checkpoint dictionaries
# (not referenced in the visible part of this module - TODO confirm usage).
CHECKPOINT_DICT_KEY_MODEL = "model"
CHECKPOINT_DICT_KEY_OBS_RMS = "obs_rms"

# Generic parameter object of an algorithm factory.
TParams = TypeVar("TParams", bound=Params)
# Parameter type of the actor-critic on-policy factories.
# NOTE: the bound is a union (``Params | ParamsMixinSingleModel``), i.e. either type satisfies it.
TActorCriticParams = TypeVar(
    "TActorCriticParams",
    bound=Params | ParamsMixinSingleModel,
)
# Parameter type of the factories using one actor and two critics (union bound, see above).
TActorDualCriticsParams = TypeVar(
    "TActorDualCriticsParams",
    bound=Params | ParamsMixinActorAndDualCritics,
)
# Parameter type of the discrete critic-only factories; same bound as TActorCriticParams.
TDiscreteCriticOnlyParams = TypeVar(
    "TDiscreteCriticOnlyParams",
    bound=Params | ParamsMixinSingleModel,
)
# Concrete algorithm class produced by a generic factory.
TAlgorithm = TypeVar("TAlgorithm", bound=Algorithm)
# Concrete policy class produced by a generic factory.
TPolicy = TypeVar("TPolicy", bound=Policy)
# Training configuration type an AlgorithmFactory is parametrised with.
TTrainingConfig = TypeVar("TTrainingConfig", bound=TrainingConfig)
# Module-level logger.
log = logging.getLogger(__name__)


class AlgorithmFactory(ABC, ToStringMixin, Generic[TTrainingConfig]):
    """Factory for the creation of an :class:`Algorithm` instance, its policy, trainer as well as collectors."""

    def __init__(self, training_config: TTrainingConfig, optim_factory: OptimizerFactoryFactory):
        """
        :param training_config: the training configuration; its replay buffer settings are
            read when the collectors are created
        :param optim_factory: the optimizer factory handed to the algorithm wrapper factory
            (if one is set) when the algorithm is created
        """
        self.training_config = training_config
        self.optim_factory = optim_factory
        # No wrapping, default callbacks and the default collector factory until the setters are used.
        self.algorithm_wrapper_factory: AlgorithmWrapperFactory | None = None
        self.trainer_callbacks: TrainerCallbacks = TrainerCallbacks()
        self.collector_factory: CollectorFactory = CollectorFactoryDefault()

    def set_collector_factory(self, collector_factory: CollectorFactory) -> None:
        """Replaces the factory used by :meth:`create_train_test_collectors`.

        :param collector_factory: the collector factory to use from now on
        """
        self.collector_factory = collector_factory

    def create_train_test_collectors(
        self,
        algorithm: Algorithm,
        envs: Environments,
        reset_collectors: bool = True,
    ) -> tuple[BaseCollector, BaseCollector]:
        """
        Creates the collectors for training and test environments.

        :param algorithm: the algorithm
        :param envs: the environments wrapper
        :param reset_collectors: Whether to reset the collectors before returning them.
            Setting to True means that the envs will be reset as well.
        :return: a tuple of (training_collector, test_collector)
        """
        config = self.training_config
        training_envs = envs.training_envs
        num_training_envs = len(training_envs)

        # Settings shared by both buffer flavours.
        common_buffer_kwargs: dict[str, Any] = {
            "stack_num": config.replay_buffer_stack_num,
            "save_only_last_obs": config.replay_buffer_save_only_last_obs,
            "ignore_obs_next": config.replay_buffer_ignore_obs_next,
        }

        # A vectorised buffer is used only if there is more than one training environment.
        buffer: ReplayBuffer
        if num_training_envs > 1:
            buffer = VectorReplayBuffer(
                config.buffer_size,
                num_training_envs,
                **common_buffer_kwargs,
            )
        else:
            buffer = ReplayBuffer(config.buffer_size, **common_buffer_kwargs)

        make_collector = self.collector_factory.create_collector
        training_collector = make_collector(
            algorithm,
            training_envs,
            buffer,
            exploration_noise=True,
        )
        # The test collector is created without a buffer argument and without exploration noise.
        test_collector = make_collector(algorithm, envs.test_envs)

        if reset_collectors:
            for collector in (training_collector, test_collector):
                collector.reset()
        return training_collector, test_collector

    def set_policy_wrapper_factory(
        self,
        policy_wrapper_factory: AlgorithmWrapperFactory | None,
    ) -> None:
        """Sets the factory that :meth:`create_algorithm` uses to wrap the created algorithm.

        :param policy_wrapper_factory: the wrapper factory, or None to disable wrapping
        """
        self.algorithm_wrapper_factory = policy_wrapper_factory

    def set_trainer_callbacks(self, callbacks: TrainerCallbacks) -> None:
        """Sets the callbacks that trainer-creating subclasses read.

        :param callbacks: the trainer callbacks
        """
        self.trainer_callbacks = callbacks

    @staticmethod
    def _create_policy_from_args(
        constructor: type[TPolicy],
        params_dict: dict,
        policy_params: list[str],
        **kwargs: Any,
    ) -> TPolicy:
        """Builds a policy from a subset of the entries of a parameter dictionary.

        :param constructor: the policy class to instantiate
        :param params_dict: the parameter dictionary; the entries named in `policy_params`
            are removed from it (in place), so that the remainder can be passed on elsewhere
        :param policy_params: the keys of `params_dict` to pass to the policy constructor;
            a missing key raises a KeyError
        :param kwargs: further keyword arguments for the policy constructor
        :return: the policy instance
        """
        policy_kwargs: dict[str, Any] = {}
        for name in policy_params:
            policy_kwargs[name] = params_dict.pop(name)
        return constructor(**policy_kwargs, **kwargs)

    @abstractmethod
    def _create_algorithm(self, envs: Environments, device: TDevice) -> Algorithm:
        """Creates the (unwrapped) algorithm; to be implemented by subclasses.

        :param envs: the environments wrapper
        :param device: the torch device
        :return: the algorithm
        """

    def create_algorithm(self, envs: Environments, device: TDevice) -> Algorithm:
        """Creates the algorithm and, if a wrapper factory was set, wraps it.

        :param envs: the environments wrapper
        :param device: the torch device
        :return: the (possibly wrapped) algorithm
        """
        algorithm = self._create_algorithm(envs, device)
        wrapper_factory = self.algorithm_wrapper_factory
        if wrapper_factory is None:
            return algorithm
        return wrapper_factory.create_wrapped_algorithm(
            algorithm,
            envs,
            self.optim_factory,
            device,
        )

    @abstractmethod
    def create_trainer(self, world: World, policy_persistence: PolicyPersistence) -> Trainer:
        """Creates the trainer; to be implemented by subclasses.

        :param world: the world object holding algorithm, environments, collectors and logger
        :param policy_persistence: provides the persistence callbacks for the trainer
        :return: the trainer
        """
class OnPolicyAlgorithmFactory(AlgorithmFactory[OnPolicyTrainingConfig], ABC):
    """Base class for factories of on-policy algorithms; provides the on-policy trainer creation."""

    def create_trainer(
        self,
        world: World,
        policy_persistence: PolicyPersistence,
    ) -> OnPolicyTrainer:
        """Creates the on-policy trainer for the algorithm contained in the given world.

        :param world: the world; its algorithm, environments, logger and collectors are used.
            The training collector must not be None.
        :param policy_persistence: provides the save-best and save-checkpoint callbacks
        :return: the trainer
        """
        config = self.training_config
        callbacks = self.trainer_callbacks
        context = TrainingContext(world.algorithm, world.envs, world.logger)

        # Each epoch callback is optional; an unset callback results in no trainer function.
        train_callback = callbacks.epoch_train_callback
        train_fn = train_callback.get_trainer_fn(context) if train_callback else None
        test_callback = callbacks.epoch_test_callback
        test_fn = test_callback.get_trainer_fn(context) if test_callback else None
        stop_callback = callbacks.epoch_stop_callback
        stop_fn = stop_callback.get_trainer_fn(context) if stop_callback else None

        algorithm = cast(OnPolicyAlgorithm, world.algorithm)
        assert world.training_collector is not None

        trainer_params = OnPolicyTrainerParams(
            training_collector=world.training_collector,
            test_collector=world.test_collector,
            max_epochs=config.max_epochs,
            epoch_num_steps=config.epoch_num_steps,
            update_step_num_repetitions=config.update_step_num_repetitions,
            test_step_num_episodes=config.test_step_num_episodes,
            batch_size=config.batch_size,
            collection_step_num_env_steps=config.collection_step_num_env_steps,
            save_best_fn=policy_persistence.get_save_best_fn(world),
            save_checkpoint_fn=policy_persistence.get_save_checkpoint_fn(world),
            logger=world.logger,
            test_in_training=config.test_in_training,
            training_fn=train_fn,
            test_fn=test_fn,
            stop_fn=stop_fn,
            verbose=False,
        )
        return algorithm.create_trainer(trainer_params)
class OffPolicyAlgorithmFactory(AlgorithmFactory[OffPolicyTrainingConfig], ABC):
    """Base class for factories of off-policy algorithms; provides the off-policy trainer creation."""

    def create_trainer(
        self,
        world: World,
        policy_persistence: PolicyPersistence,
    ) -> OffPolicyTrainer:
        """Creates the off-policy trainer for the algorithm contained in the given world.

        :param world: the world; its algorithm, environments, logger and collectors are used.
            The training collector must not be None.
        :param policy_persistence: provides the save-best and save-checkpoint callbacks
        :return: the trainer
        """
        training_config = self.training_config
        callbacks = self.trainer_callbacks
        context = TrainingContext(world.algorithm, world.envs, world.logger)

        # Each epoch callback is optional; an unset callback results in no trainer function.
        train_fn = (
            callbacks.epoch_train_callback.get_trainer_fn(context)
            if callbacks.epoch_train_callback
            else None
        )
        test_fn = (
            callbacks.epoch_test_callback.get_trainer_fn(context)
            if callbacks.epoch_test_callback
            else None
        )
        stop_fn = (
            callbacks.epoch_stop_callback.get_trainer_fn(context)
            if callbacks.epoch_stop_callback
            else None
        )

        algorithm = cast(OffPolicyAlgorithm, world.algorithm)
        assert world.training_collector is not None
        return algorithm.create_trainer(
            OffPolicyTrainerParams(
                training_collector=world.training_collector,
                test_collector=world.test_collector,
                max_epochs=training_config.max_epochs,
                epoch_num_steps=training_config.epoch_num_steps,
                collection_step_num_env_steps=training_config.collection_step_num_env_steps,
                test_step_num_episodes=training_config.test_step_num_episodes,
                batch_size=training_config.batch_size,
                save_best_fn=policy_persistence.get_save_best_fn(world),
                # Pass the checkpoint callback as the on-policy factory does; without it,
                # the persistence's checkpoint function is never handed to off-policy trainers.
                save_checkpoint_fn=policy_persistence.get_save_checkpoint_fn(world),
                logger=world.logger,
                update_step_num_gradient_steps_per_sample=training_config.update_step_num_gradient_steps_per_sample,
                test_in_training=training_config.test_in_training,
                training_fn=train_fn,
                test_fn=test_fn,
                stop_fn=stop_fn,
                verbose=False,
            )
        )
class ReinforceAlgorithmFactory(OnPolicyAlgorithmFactory):
    """Factory for the :class:`Reinforce` algorithm with a :class:`ProbabilisticActorPolicy`."""

    def __init__(
        self,
        params: ReinforceParams,
        training_config: OnPolicyTrainingConfig,
        actor_factory: ActorFactory,
        optim_factory: OptimizerFactoryFactory,
    ):
        """
        :param params: the algorithm parameters, from which the constructor kwargs are derived
        :param training_config: the on-policy training configuration
        :param actor_factory: creates the actor module and the distribution function
        :param optim_factory: the default optimizer factory
        """
        # The base class constructor stores training_config and optim_factory.
        super().__init__(training_config, optim_factory)
        self.params = params
        self.actor_factory = actor_factory

    def _create_algorithm(self, envs: Environments, device: TDevice) -> Reinforce:
        """Creates the actor, the policy and the :class:`Reinforce` algorithm.

        :param envs: the environments wrapper
        :param device: the torch device
        :return: the algorithm
        """
        actor = self.actor_factory.create_module(envs, device)
        transformer_data = ParamTransformerData(
            envs=envs,
            device=device,
            optim_factory_default=self.optim_factory,
        )
        kwargs = self.params.create_kwargs(transformer_data)
        dist_fn = self.actor_factory.create_dist_fn(envs)
        assert dist_fn is not None

        # These entries are moved out of kwargs into the policy; the rest goes to the algorithm.
        policy_param_names = ["action_scaling", "action_bound_method", "deterministic_eval"]
        policy = self._create_policy_from_args(
            ProbabilisticActorPolicy,
            kwargs,
            policy_param_names,
            actor=actor,
            dist_fn=dist_fn,
            action_space=envs.get_action_space(),
            observation_space=envs.get_observation_space(),
        )
        return Reinforce(policy=policy, **kwargs)
class ActorCriticOnPolicyAlgorithmFactory(
    OnPolicyAlgorithmFactory,
    Generic[TActorCriticParams, TAlgorithm],
):
    """Base factory for on-policy algorithms which use an actor, a critic and
    a :class:`ProbabilisticActorPolicy`; subclasses specify the algorithm class.
    """

    def __init__(
        self,
        params: TActorCriticParams,
        training_config: OnPolicyTrainingConfig,
        actor_factory: ActorFactory,
        critic_factory: CriticFactory,
        optimizer_factory: OptimizerFactoryFactory,
    ):
        """
        :param params: the algorithm parameters, from which the constructor kwargs are derived
        :param training_config: the on-policy training configuration
        :param actor_factory: creates the actor module and the distribution function
        :param critic_factory: creates the critic module
        :param optimizer_factory: the default optimizer factory
        """
        # The base class constructor stores training_config and the optimizer factory
        # (as attribute `optim_factory`).
        super().__init__(training_config, optim_factory=optimizer_factory)
        self.params = params
        self.actor_factory = actor_factory
        self.critic_factory = critic_factory
        # Passed as `use_action` when the critic module is created.
        self.critic_use_action = False

    @abstractmethod
    def _get_algorithm_class(self) -> type[TAlgorithm]:
        """:return: the algorithm class to instantiate"""

    @typing.no_type_check
    def _create_kwargs(self, envs: Environments, device: TDevice) -> dict[str, Any]:
        """Creates the modules and assembles all keyword arguments for policy and algorithm.

        :param envs: the environments wrapper
        :param device: the torch device
        :return: the kwargs derived from the parameters, extended by the entries
            "actor", "critic", "action_space", "observation_space" and "dist_fn"
        """
        actor = self.actor_factory.create_module(envs, device)
        critic = self.critic_factory.create_module(
            envs,
            device,
            use_action=self.critic_use_action,
        )
        transformer_data = ParamTransformerData(
            envs=envs,
            device=device,
            optim_factory_default=self.optim_factory,
        )
        kwargs = self.params.create_kwargs(transformer_data)
        kwargs.update(
            actor=actor,
            critic=critic,
            action_space=envs.get_action_space(),
            observation_space=envs.get_observation_space(),
            dist_fn=self.actor_factory.create_dist_fn(envs),
        )
        return kwargs

    def _create_algorithm(self, envs: Environments, device: TDevice) -> TAlgorithm:
        """Creates the policy and the algorithm of the class given by the subclass.

        :param envs: the environments wrapper
        :param device: the torch device
        :return: the algorithm
        """
        algorithm_kwargs = self._create_kwargs(envs, device)
        # These entries are moved out of the kwargs into the policy;
        # whatever remains (including the critic) is passed to the algorithm.
        policy_param_names = [
            "actor",
            "dist_fn",
            "action_space",
            "deterministic_eval",
            "observation_space",
            "action_scaling",
            "action_bound_method",
        ]
        policy = self._create_policy_from_args(
            ProbabilisticActorPolicy,
            algorithm_kwargs,
            policy_param_names,
        )
        algorithm_class = self._get_algorithm_class()
        return algorithm_class(policy=policy, **algorithm_kwargs)
class A2CAlgorithmFactory(ActorCriticOnPolicyAlgorithmFactory[A2CParams, A2C]):
    """Actor-critic on-policy factory which creates :class:`A2C` instances."""

    def _get_algorithm_class(self) -> type[A2C]:
        """:return: the :class:`A2C` class"""
        return A2C
class PPOAlgorithmFactory(ActorCriticOnPolicyAlgorithmFactory[PPOParams, PPO]):
    """Actor-critic on-policy factory which creates :class:`PPO` instances."""

    def _get_algorithm_class(self) -> type[PPO]:
        """:return: the :class:`PPO` class"""
        return PPO
class NPGAlgorithmFactory(ActorCriticOnPolicyAlgorithmFactory[NPGParams, NPG]):
    """Actor-critic on-policy factory which creates :class:`NPG` instances."""

    def _get_algorithm_class(self) -> type[NPG]:
        """:return: the :class:`NPG` class"""
        return NPG
class TRPOAlgorithmFactory(ActorCriticOnPolicyAlgorithmFactory[TRPOParams, TRPO]):
    """Actor-critic on-policy factory which creates :class:`TRPO` instances."""

    def _get_algorithm_class(self) -> type[TRPO]:
        """:return: the :class:`TRPO` class"""
        return TRPO
class DiscreteCriticOnlyOffPolicyAlgorithmFactory(
    OffPolicyAlgorithmFactory,
    Generic[TDiscreteCriticOnlyParams, TAlgorithm],
):
    """Base factory for off-policy algorithms on discrete action spaces which use a single model;
    subclasses specify the algorithm class and how the policy is created.
    """

    def __init__(
        self,
        params: TDiscreteCriticOnlyParams,
        training_config: OffPolicyTrainingConfig,
        model_factory: ModuleFactory,
        optim_factory: OptimizerFactoryFactory,
    ):
        """
        :param params: the algorithm parameters, from which the constructor kwargs are derived
        :param training_config: the off-policy training configuration
        :param model_factory: creates the model passed to the policy
        :param optim_factory: the default optimizer factory
        """
        # The base class constructor stores training_config and optim_factory.
        super().__init__(training_config, optim_factory)
        self.params = params
        self.model_factory = model_factory

    @abstractmethod
    def _get_algorithm_class(self) -> type[TAlgorithm]:
        """:return: the algorithm class to instantiate"""

    @abstractmethod
    def _create_policy(
        self,
        model: torch.nn.Module,
        params: dict,
        action_space: gymnasium.spaces.Discrete,
        observation_space: gymnasium.spaces.Space,
    ) -> Policy:
        """Creates the policy; to be implemented by subclasses.

        :param model: the model created by the model factory
        :param params: the kwargs derived from the algorithm parameters
        :param action_space: the (discrete) action space
        :param observation_space: the observation space
        :return: the policy
        """

    @typing.no_type_check
    def _create_algorithm(self, envs: Environments, device: TDevice) -> TAlgorithm:
        """Creates the model, the policy and the algorithm of the class given by the subclass.

        :param envs: the environments wrapper; its type is asserted to be discrete
        :param device: the torch device
        :return: the algorithm
        """
        model = self.model_factory.create_module(envs, device)
        transformer_data = ParamTransformerData(
            envs=envs,
            device=device,
            optim_factory_default=self.optim_factory,
        )
        params_dict = self.params.create_kwargs(transformer_data)

        # The cast below is backed by the assertion on the environment type.
        envs.get_type().assert_discrete(self)
        action_space = cast(gymnasium.spaces.Discrete, envs.get_action_space())

        policy = self._create_policy(
            model,
            params_dict,
            action_space,
            envs.get_observation_space(),
        )
        algorithm_class = self._get_algorithm_class()
        return algorithm_class(policy=policy, **params_dict)
class DQNAlgorithmFactory(DiscreteCriticOnlyOffPolicyAlgorithmFactory[DQNParams, DQN]):
    """Factory for :class:`DQN` with a :class:`DiscreteQLearningPolicy`."""

    def _create_policy(
        self,
        model: torch.nn.Module,
        params: dict,
        action_space: gymnasium.spaces.Discrete,
        observation_space: gymnasium.spaces.Space,
    ) -> Policy:
        """Creates the Q-learning policy.

        :param model: the model for the policy
        :param params: the kwargs dictionary, from which the epsilon entries are removed
        :param action_space: the (discrete) action space
        :param observation_space: the observation space
        :return: the policy
        """
        epsilon_param_names = ["eps_training", "eps_inference"]
        return self._create_policy_from_args(
            DiscreteQLearningPolicy,
            params,
            epsilon_param_names,
            model=model,
            action_space=action_space,
            observation_space=observation_space,
        )

    def _get_algorithm_class(self) -> type[DQN]:
        """:return: the :class:`DQN` class"""
        return DQN
class IQNAlgorithmFactory(DiscreteCriticOnlyOffPolicyAlgorithmFactory[IQNParams, IQN]):
    """Factory for :class:`IQN` with an :class:`IQNPolicy`."""

    def _create_policy(
        self,
        model: torch.nn.Module,
        params: dict,
        action_space: gymnasium.spaces.Discrete,
        observation_space: gymnasium.spaces.Space,
    ) -> Policy:
        """Creates the IQN policy.

        :param model: the model for the policy
        :param params: the kwargs dictionary, from which the sample size and epsilon
            entries are removed
        :param action_space: the (discrete) action space
        :param observation_space: the observation space
        :return: the policy
        """
        return self._create_policy_from_args(
            constructor=IQNPolicy,
            params_dict=params,
            policy_params=[
                "sample_size",
                "online_sample_size",
                "target_sample_size",
                "eps_training",
                "eps_inference",
            ],
            model=model,
            action_space=action_space,
            observation_space=observation_space,
        )

    def _get_algorithm_class(self) -> type[IQN]:
        """:return: the :class:`IQN` class"""
        return IQN
class DDPGAlgorithmFactory(OffPolicyAlgorithmFactory):
    """Factory for :class:`DDPG` with a :class:`ContinuousDeterministicPolicy`."""

    def __init__(
        self,
        params: DDPGParams,
        training_config: OffPolicyTrainingConfig,
        actor_factory: ActorFactory,
        critic_factory: CriticFactory,
        optim_factory: OptimizerFactoryFactory,
    ):
        """
        :param params: the algorithm parameters, from which the constructor kwargs are derived
        :param training_config: the off-policy training configuration
        :param actor_factory: creates the actor module
        :param critic_factory: creates the critic module
        :param optim_factory: the default optimizer factory
        """
        # The base class constructor stores training_config and optim_factory.
        super().__init__(training_config, optim_factory)
        self.critic_factory = critic_factory
        self.actor_factory = actor_factory
        self.params = params

    def _create_algorithm(self, envs: Environments, device: TDevice) -> Algorithm:
        """Creates actor, critic, policy and the :class:`DDPG` algorithm.

        :param envs: the environments wrapper
        :param device: the torch device
        :return: the algorithm
        """
        actor = self.actor_factory.create_module(envs, device)
        critic = self.critic_factory.create_module(envs, device, True)
        transformer_data = ParamTransformerData(
            envs=envs,
            device=device,
            optim_factory_default=self.optim_factory,
        )
        kwargs = self.params.create_kwargs(transformer_data)

        # These entries are moved out of kwargs into the policy; the rest goes to the algorithm.
        policy_param_names = ["exploration_noise", "action_scaling", "action_bound_method"]
        policy = self._create_policy_from_args(
            ContinuousDeterministicPolicy,
            kwargs,
            policy_param_names,
            actor=actor,
            action_space=envs.get_action_space(),
            observation_space=envs.get_observation_space(),
        )
        return DDPG(policy=policy, critic=critic, **kwargs)
class REDQAlgorithmFactory(OffPolicyAlgorithmFactory):
    """Factory for :class:`REDQ` with a :class:`REDQPolicy` and a critic ensemble."""

    def __init__(
        self,
        params: REDQParams,
        training_config: OffPolicyTrainingConfig,
        actor_factory: ActorFactory,
        critic_ensemble_factory: CriticEnsembleFactory,
        optim_factory: OptimizerFactoryFactory,
    ):
        """
        :param params: the algorithm parameters; `ensemble_size` is read when the
            critic ensemble is created
        :param training_config: the off-policy training configuration
        :param actor_factory: creates the actor module
        :param critic_ensemble_factory: creates the critic ensemble module
        :param optim_factory: the default optimizer factory
        """
        # The base class constructor stores training_config and optim_factory.
        super().__init__(training_config, optim_factory)
        self.critic_ensemble_factory = critic_ensemble_factory
        self.actor_factory = actor_factory
        self.params = params

    def _create_algorithm(self, envs: Environments, device: TDevice) -> Algorithm:
        """Creates actor, critic ensemble, policy and the :class:`REDQ` algorithm.

        :param envs: the environments wrapper; its type is asserted to be continuous
        :param device: the torch device
        :return: the algorithm
        """
        envs.get_type().assert_continuous(self)
        actor = self.actor_factory.create_module(envs, device)
        critic_ensemble = self.critic_ensemble_factory.create_module(
            envs,
            device,
            self.params.ensemble_size,
            True,
        )
        transformer_data = ParamTransformerData(
            envs=envs,
            device=device,
            optim_factory_default=self.optim_factory,
        )
        kwargs = self.params.create_kwargs(transformer_data)
        action_space = cast(gymnasium.spaces.Box, envs.get_action_space())

        # These entries are moved out of kwargs into the policy; the rest goes to the algorithm.
        policy_param_names = [
            "exploration_noise",
            "deterministic_eval",
            "action_scaling",
            "action_bound_method",
        ]
        policy = self._create_policy_from_args(
            REDQPolicy,
            kwargs,
            policy_param_names,
            actor=actor,
            action_space=action_space,
            observation_space=envs.get_observation_space(),
        )
        return REDQ(policy=policy, critic=critic_ensemble, **kwargs)
class ActorDualCriticsOffPolicyAlgorithmFactory(
    OffPolicyAlgorithmFactory,
    Generic[TActorDualCriticsParams, TAlgorithm, TPolicy],
):
    """Base factory for off-policy algorithms which use an actor and two critics;
    subclasses specify the algorithm class and how the policy is created.
    """

    def __init__(
        self,
        params: TActorDualCriticsParams,
        training_config: OffPolicyTrainingConfig,
        actor_factory: ActorFactory,
        critic1_factory: CriticFactory,
        critic2_factory: CriticFactory,
        optim_factory: OptimizerFactoryFactory,
    ):
        """
        :param params: the algorithm parameters, from which the constructor kwargs are derived
        :param training_config: the off-policy training configuration
        :param actor_factory: creates the actor module
        :param critic1_factory: creates the first critic module
        :param critic2_factory: creates the second critic module
        :param optim_factory: the default optimizer factory
        """
        # The base class constructor stores training_config and optim_factory.
        super().__init__(training_config, optim_factory)
        self.params = params
        self.actor_factory = actor_factory
        self.critic1_factory = critic1_factory
        self.critic2_factory = critic2_factory

    @abstractmethod
    def _get_algorithm_class(self) -> type[TAlgorithm]:
        """:return: the algorithm class to instantiate"""

    def _get_discrete_last_size_use_action_shape(self) -> bool:
        """:return: the value passed as `discrete_last_size_use_action_shape` to both critic factories"""
        return True

    @staticmethod
    def _get_critic_use_action(envs: Environments) -> bool:
        """
        :param envs: the environments wrapper
        :return: the value passed as `use_action` to both critic factories, which is True
            exactly if the environment type is continuous
        """
        return envs.get_type().is_continuous()

    @abstractmethod
    def _create_policy(
        self, actor: torch.nn.Module | DiscreteActor, envs: Environments, params: dict
    ) -> TPolicy:
        """Creates the policy; to be implemented by subclasses.

        :param actor: the actor module
        :param envs: the environments wrapper
        :param params: the kwargs derived from the algorithm parameters
        :return: the policy
        """

    @typing.no_type_check
    def _create_algorithm(self, envs: Environments, device: TDevice) -> TAlgorithm:
        """Creates actor, both critics, the policy and the algorithm of the class given by the subclass.

        :param envs: the environments wrapper
        :param device: the torch device
        :return: the algorithm
        """
        actor = self.actor_factory.create_module(envs, device)

        # Both critics are created with the same settings, each by its own factory.
        use_action_shape = self._get_discrete_last_size_use_action_shape()
        critic_use_action = self._get_critic_use_action(envs)
        critic_kwargs = {
            "use_action": critic_use_action,
            "discrete_last_size_use_action_shape": use_action_shape,
        }
        critic1 = self.critic1_factory.create_module(envs, device, **critic_kwargs)
        critic2 = self.critic2_factory.create_module(envs, device, **critic_kwargs)

        transformer_data = ParamTransformerData(
            envs=envs,
            device=device,
            optim_factory_default=self.optim_factory,
        )
        kwargs = self.params.create_kwargs(transformer_data)
        policy = self._create_policy(actor, envs, kwargs)
        algorithm_class = self._get_algorithm_class()
        return algorithm_class(policy=policy, critic=critic1, critic2=critic2, **kwargs)
class SACAlgorithmFactory(ActorDualCriticsOffPolicyAlgorithmFactory[SACParams, SAC, SACPolicy]):
    """Factory for :class:`SAC` with a :class:`SACPolicy`."""

    def _create_policy(
        self, actor: torch.nn.Module | DiscreteActor, envs: Environments, params: dict
    ) -> SACPolicy:
        """Creates the SAC policy.

        :param actor: the actor module
        :param envs: the environments wrapper, which provides action and observation space
        :param params: the kwargs dictionary, from which the policy entries are removed
        :return: the policy
        """
        policy_param_names = ["exploration_noise", "deterministic_eval", "action_scaling"]
        return self._create_policy_from_args(
            constructor=SACPolicy,
            params_dict=params,
            policy_params=policy_param_names,
            actor=actor,
            action_space=envs.get_action_space(),
            observation_space=envs.get_observation_space(),
        )

    def _get_algorithm_class(self) -> type[SAC]:
        """:return: the :class:`SAC` class"""
        return SAC
class DiscreteSACAlgorithmFactory(
    ActorDualCriticsOffPolicyAlgorithmFactory[DiscreteSACParams, DiscreteSAC, DiscreteSACPolicy]
):
    """Factory for :class:`DiscreteSAC` with a :class:`DiscreteSACPolicy`."""

    def _create_policy(
        self, actor: torch.nn.Module | DiscreteActor, envs: Environments, params: dict
    ) -> DiscreteSACPolicy:
        """Creates the discrete SAC policy.

        :param actor: the actor module
        :param envs: the environments wrapper, which provides action and observation space
        :param params: the kwargs dictionary, from which "deterministic_eval" is removed
        :return: the policy
        """
        policy_param_names = ["deterministic_eval"]
        return self._create_policy_from_args(
            constructor=DiscreteSACPolicy,
            params_dict=params,
            policy_params=policy_param_names,
            actor=actor,
            action_space=envs.get_action_space(),
            observation_space=envs.get_observation_space(),
        )

    def _get_algorithm_class(self) -> type[DiscreteSAC]:
        """:return: the :class:`DiscreteSAC` class"""
        return DiscreteSAC
class TD3AlgorithmFactory(
    ActorDualCriticsOffPolicyAlgorithmFactory[TD3Params, TD3, ContinuousDeterministicPolicy]
):
    """Factory for :class:`TD3` with a :class:`ContinuousDeterministicPolicy`."""

    def _create_policy(
        self, actor: torch.nn.Module | DiscreteActor, envs: Environments, params: dict
    ) -> ContinuousDeterministicPolicy:
        """Creates the deterministic policy.

        :param actor: the actor module
        :param envs: the environments wrapper, which provides action and observation space
        :param params: the kwargs dictionary, from which the policy entries are removed
        :return: the policy
        """
        policy_param_names = ["exploration_noise", "action_scaling", "action_bound_method"]
        return self._create_policy_from_args(
            constructor=ContinuousDeterministicPolicy,
            params_dict=params,
            policy_params=policy_param_names,
            actor=actor,
            action_space=envs.get_action_space(),
            observation_space=envs.get_observation_space(),
        )

    def _get_algorithm_class(self) -> type[TD3]:
        """:return: the :class:`TD3` class"""
        return TD3