Source code for tianshou.highlevel.params.lr_scheduler
from abc import ABC, abstractmethod
from sensai.util.string import ToStringMixin
from tianshou.algorithm.optim import LRSchedulerFactory, LRSchedulerFactoryLinear
from tianshou.highlevel.config import TrainingConfig
[docs]
class LRSchedulerFactoryFactory(ToStringMixin, ABC):
    """Factory for the creation of a learning rate scheduler factory.

    Concrete subclasses decide how the :class:`LRSchedulerFactory` is built by
    implementing :meth:`create_lr_scheduler_factory`.
    """

    @abstractmethod
    def create_lr_scheduler_factory(self) -> LRSchedulerFactory:
        """Build the learning rate scheduler factory.

        :return: the learning rate scheduler factory to be used
        """
        ...
[docs]
class LRSchedulerFactoryFactoryLinear(LRSchedulerFactoryFactory):
    """Creates an :class:`LRSchedulerFactoryLinear` from the settings of a training configuration.

    The configuration's ``max_epochs``, ``epoch_num_steps`` and
    ``collection_step_num_env_steps`` are forwarded to the linear scheduler factory.
    """

    def __init__(self, training_config: TrainingConfig):
        """:param training_config: the training configuration from which the scheduling parameters are read"""
        self.training_config = training_config

    def create_lr_scheduler_factory(self) -> LRSchedulerFactory:
        """Build a linear learning rate scheduler factory from the stored training configuration.

        :return: an :class:`LRSchedulerFactoryLinear` parametrized by the training configuration
        :raises ValueError: if ``epoch_num_steps`` or ``collection_step_num_env_steps`` is ``None``
        """
        cfg = self.training_config
        # Both step counts are required; without them the schedule is not well-defined.
        if cfg.epoch_num_steps is None or cfg.collection_step_num_env_steps is None:
            message = (
                f"{self.__class__.__name__} requires epoch_num_steps and collection_step_num_env_steps to be set "
                "in order for the scheduling to be well-defined."
            )
            raise ValueError(message)
        return LRSchedulerFactoryLinear(
            max_epochs=cfg.max_epochs,
            epoch_num_steps=cfg.epoch_num_steps,
            collection_step_num_env_steps=cfg.collection_step_num_env_steps,
        )