Source code for tianshou.highlevel.params.optim

from abc import ABC, abstractmethod
from collections.abc import Iterable
from typing import Any, Protocol, TypeAlias

import torch
from sensai.util.string import ToStringMixin

from tianshou.algorithm.optim import (
    AdamOptimizerFactory,
    OptimizerFactory,
    RMSpropOptimizerFactory,
    TorchOptimizerFactory,
)

# Type of the parameter collection handed to an optimizer: either an iterable of
# tensors or an iterable of string-keyed dicts (presumably torch-style parameter
# groups -- confirm against the callers that use this alias).
TParams: TypeAlias = Iterable[torch.Tensor] | Iterable[dict[str, Any]]


class OptimizerWithLearningRateProtocol(Protocol):
    """Structural type of a callable that constructs a torch optimizer.

    Any object callable as ``obj(parameters, lr=..., **kwargs)`` and returning a
    :class:`torch.optim.Optimizer` conforms to this protocol.
    """

    def __call__(self, parameters: Any, lr: float, **kwargs: Any) -> torch.optim.Optimizer:
        """Construct the optimizer.

        :param parameters: the parameters the optimizer shall operate on
        :param lr: the learning rate
        :param kwargs: any further keyword arguments for optimizer construction
        :return: the optimizer
        """
        ...
class OptimizerFactoryFactory(ABC, ToStringMixin):
    """Abstract base class for objects which create an :class:`OptimizerFactory`
    once the learning rate is known.
    """

    @staticmethod
    def default() -> "OptimizerFactoryFactory":
        """Return the default implementation.

        :return: an :class:`OptimizerFactoryFactoryAdam` constructed with its default arguments
        """
        default_instance = OptimizerFactoryFactoryAdam()
        return default_instance

    @abstractmethod
    def create_optimizer_factory(self, lr: float) -> OptimizerFactory:
        """Create the optimizer factory for the given learning rate.

        :param lr: the learning rate
        :return: the optimizer factory
        """
class OptimizerFactoryFactoryTorch(OptimizerFactoryFactory):
    """Creates optimizer factories for an arbitrary torch optimizer class."""

    def __init__(self, optim_class: OptimizerWithLearningRateProtocol, **kwargs: Any):
        """Factory for torch optimizers.

        :param optim_class: the optimizer class (e.g. subclass of `torch.optim.Optimizer`),
            which will be passed the module parameters, the learning rate as `lr` and the
            kwargs provided.
        :param kwargs: keyword arguments to provide at optimizer construction
        """
        self.optim_class = optim_class
        self.kwargs = kwargs

    def create_optimizer_factory(self, lr: float) -> OptimizerFactory:
        """Create a factory for the configured optimizer class.

        The keyword arguments given at construction are forwarded together with the
        learning rate; previously they were stored but never passed on.

        :param lr: the learning rate; takes precedence over any `lr` entry in the stored kwargs
        :return: the optimizer factory
        """
        # Merge so that the explicit learning rate always wins and a stray "lr" in the
        # stored kwargs cannot cause a duplicate-keyword error.
        # NOTE(review): relies on TorchOptimizerFactory accepting arbitrary optimizer
        # keyword arguments next to `optim_class` -- confirm against its signature.
        optim_kwargs = {**self.kwargs, "lr": lr}
        return TorchOptimizerFactory(optim_class=self.optim_class, **optim_kwargs)
class OptimizerFactoryFactoryAdam(OptimizerFactoryFactory):
    """Creates :class:`AdamOptimizerFactory` instances from stored hyperparameters."""

    def __init__(
        self,
        betas: tuple[float, float] = (0.9, 0.999),
        eps: float = 1e-08,
        weight_decay: float = 0,
    ):
        """Store the hyperparameters to apply to every factory created.

        :param betas: forwarded unchanged as `betas` to :class:`AdamOptimizerFactory`
        :param eps: forwarded unchanged as `eps` to :class:`AdamOptimizerFactory`
        :param weight_decay: forwarded unchanged as `weight_decay` to :class:`AdamOptimizerFactory`
        """
        self.weight_decay = weight_decay
        self.eps = eps
        self.betas = betas

    def create_optimizer_factory(self, lr: float) -> AdamOptimizerFactory:
        """Create an Adam optimizer factory for the given learning rate.

        :param lr: the learning rate
        :return: a factory configured with `lr` and the stored hyperparameters
        """
        adam_factory = AdamOptimizerFactory(
            lr=lr,
            betas=self.betas,
            eps=self.eps,
            weight_decay=self.weight_decay,
        )
        return adam_factory
class OptimizerFactoryFactoryRMSprop(OptimizerFactoryFactory):
    """Creates :class:`RMSpropOptimizerFactory` instances from stored hyperparameters."""

    def __init__(
        self,
        alpha: float = 0.99,
        eps: float = 1e-08,
        weight_decay: float = 0,
        momentum: float = 0,
        centered: bool = False,
    ):
        """Store the hyperparameters to apply to every factory created.

        :param alpha: forwarded unchanged as `alpha` to :class:`RMSpropOptimizerFactory`
        :param eps: forwarded unchanged as `eps` to :class:`RMSpropOptimizerFactory`
        :param weight_decay: forwarded unchanged as `weight_decay` to :class:`RMSpropOptimizerFactory`
        :param momentum: forwarded unchanged as `momentum` to :class:`RMSpropOptimizerFactory`
        :param centered: forwarded unchanged as `centered` to :class:`RMSpropOptimizerFactory`
        """
        self.alpha = alpha
        self.momentum = momentum
        self.centered = centered
        self.weight_decay = weight_decay
        self.eps = eps

    def create_optimizer_factory(self, lr: float) -> RMSpropOptimizerFactory:
        """Create an RMSprop optimizer factory for the given learning rate.

        :param lr: the learning rate
        :return: a factory configured with `lr` and the stored hyperparameters
        """
        rmsprop_factory = RMSpropOptimizerFactory(
            lr=lr,
            alpha=self.alpha,
            eps=self.eps,
            weight_decay=self.weight_decay,
            momentum=self.momentum,
            centered=self.centered,
        )
        return rmsprop_factory