from abc import ABC, abstractmethod
from collections.abc import Callable, Iterable
from typing import Any, Self, TypeAlias
import numpy as np
import torch
from sensai.util.string import ToStringMixin
from torch.optim import Adam, RMSprop
from torch.optim.lr_scheduler import LambdaLR, LRScheduler
# Type of the `params` argument handed to optimizer constructors in this module:
# either an iterable of tensors or an iterable of dicts with string keys
# (presumably torch-style parameter groups -- confirm against callers).
ParamsType: TypeAlias = Iterable[torch.Tensor] | Iterable[dict[str, Any]]
[docs]
class LRSchedulerFactory(ToStringMixin, ABC):
    """Abstract factory which builds a learning rate scheduler for a given optimizer."""

    @abstractmethod
    def create_scheduler(self, optim: torch.optim.Optimizer) -> LRScheduler:
        """
        Create the learning rate scheduler for the given optimizer.

        :param optim: the optimizer for which the scheduler is to be created
        :return: the scheduler
        """
[docs]
class LRSchedulerFactoryLinear(LRSchedulerFactory):
    """
    Factory for a learning rate scheduler where the learning rate linearly decays towards
    zero for the given trainer parameters.

    The decay horizon is ``ceil(epoch_num_steps / collection_step_num_env_steps) * max_epochs``
    scheduler steps. If the scheduler is stepped beyond that horizon, the multiplicative
    factor remains at zero (it never becomes negative).
    """

    def __init__(self, max_epochs: int, epoch_num_steps: int, collection_step_num_env_steps: int):
        """
        :param max_epochs: the number of epochs; scales the per-epoch update count
        :param epoch_num_steps: numerator of the per-epoch update count
        :param collection_step_num_env_steps: denominator of the per-epoch update count
        """
        self.num_epochs = max_epochs
        self.epoch_num_steps = epoch_num_steps
        self.collection_step_num_env_steps = collection_step_num_env_steps

    def create_scheduler(self, optim: torch.optim.Optimizer) -> LRScheduler:
        """
        Create a `LambdaLR` scheduler whose factor decays linearly from 1 to 0.

        :param optim: the optimizer whose learning rate shall be scheduled
        :return: the scheduler
        """
        return LambdaLR(optim, lr_lambda=self._LRLambda(self).compute)

    class _LRLambda:
        """
        Holder of the decay horizon which provides the function passed to `LambdaLR`.

        NOTE(review): presumably a named class rather than a lambda/closure so that the
        resulting scheduler can be pickled -- confirm.
        """

        def __init__(self, parent: "LRSchedulerFactoryLinear"):
            """
            :param parent: the factory from which the decay horizon is derived
            """
            updates_per_epoch = np.ceil(
                parent.epoch_num_steps / parent.collection_step_num_env_steps,
            )
            self.max_update_num = updates_per_epoch * parent.num_epochs

        def compute(self, epoch: int) -> float:
            """
            Compute the multiplicative learning rate factor.

            :param epoch: the step counter provided by `LambdaLR`
            :return: ``1 - epoch / max_update_num``, clamped from below at 0
            """
            factor = 1.0 - epoch / self.max_update_num
            # Clamp at zero: stepping past the horizon must not yield a negative learning
            # rate. The float() conversion turns the numpy scalar into a built-in float,
            # as the return annotation promises.
            return max(0.0, float(factor))
[docs]
class OptimizerFactory(ABC, ToStringMixin):
    """
    Base class for factories which create an optimizer for a module's parameters and,
    if a scheduler factory has been configured, a learning rate scheduler for it.
    """

    def __init__(self) -> None:
        # No scheduler is created unless a factory is set via `with_lr_scheduler_factory`
        self.lr_scheduler_factory: LRSchedulerFactory | None = None

    def with_lr_scheduler_factory(self, lr_scheduler_factory: LRSchedulerFactory) -> Self:
        """
        Configure the factory used to create the learning rate scheduler.

        :param lr_scheduler_factory: the scheduler factory to use
        :return: this instance, allowing calls to be chained
        """
        self.lr_scheduler_factory = lr_scheduler_factory
        return self

    def create_instances(
        self,
        module: torch.nn.Module,
    ) -> tuple[torch.optim.Optimizer, LRScheduler | None]:
        """
        Create the optimizer for the given module along with the optional scheduler.

        :param module: the module whose parameters are to be optimized
        :return: a pair (optimizer, scheduler), where the scheduler is None if no
            scheduler factory has been configured
        """
        optim = self._create_optimizer_for_params(module.parameters())
        scheduler_factory = self.lr_scheduler_factory
        if scheduler_factory is None:
            return optim, None
        return optim, scheduler_factory.create_scheduler(optim)

    @abstractmethod
    def _create_optimizer_for_params(self, params: ParamsType) -> torch.optim.Optimizer:
        """
        Create the optimizer for the given parameters.

        :param params: the parameters to be optimized
        :return: the optimizer
        """
[docs]
class TorchOptimizerFactory(OptimizerFactory):
    """Generic factory which can be used with arbitrary torch optimizer classes."""

    def __init__(self, optim_class: Callable[..., torch.optim.Optimizer], **kwargs: Any):
        """
        :param optim_class: the optimizer class (e.g. subclass of `torch.optim.Optimizer`)
            or another callable returning an optimizer; it is called with the parameters
            as its only positional argument and with `kwargs` as keyword arguments.
        :param kwargs: keyword arguments to forward at optimizer construction; any
            hyperparameter, including a learning rate `lr`, must be given here.
        """
        super().__init__()
        self.optim_class = optim_class
        self.kwargs = kwargs

    def _create_optimizer_for_params(self, params: ParamsType) -> torch.optim.Optimizer:
        """
        Instantiate the configured optimizer class for the given parameters.

        :param params: the parameters to be optimized
        :return: the optimizer
        """
        construct, construct_kwargs = self.optim_class, self.kwargs
        return construct(params, **construct_kwargs)
[docs]
class AdamOptimizerFactory(OptimizerFactory):
    """Factory for `torch.optim.Adam` optimizers with the configured hyperparameters."""

    def __init__(
        self,
        lr: float = 1e-3,
        betas: tuple[float, float] = (0.9, 0.999),
        eps: float = 1e-08,
        weight_decay: float = 0,
    ):
        """
        :param lr: forwarded to `Adam` as `lr`
        :param betas: forwarded to `Adam` as `betas`
        :param eps: forwarded to `Adam` as `eps`
        :param weight_decay: forwarded to `Adam` as `weight_decay`
        """
        super().__init__()
        self.lr = lr
        self.weight_decay = weight_decay
        self.eps = eps
        self.betas = betas

    def _create_optimizer_for_params(self, params: ParamsType) -> torch.optim.Optimizer:
        """
        Create the Adam optimizer for the given parameters.

        :param params: the parameters to be optimized
        :return: the optimizer
        """
        hyperparams = {
            "lr": self.lr,
            "betas": self.betas,
            "eps": self.eps,
            "weight_decay": self.weight_decay,
        }
        return Adam(params, **hyperparams)
[docs]
class RMSpropOptimizerFactory(OptimizerFactory):
    """Factory for `torch.optim.RMSprop` optimizers with the configured hyperparameters."""

    def __init__(
        self,
        lr: float = 1e-2,
        alpha: float = 0.99,
        eps: float = 1e-08,
        weight_decay: float = 0,
        momentum: float = 0,
        centered: bool = False,
    ):
        """
        :param lr: forwarded to `RMSprop` as `lr`
        :param alpha: forwarded to `RMSprop` as `alpha`
        :param eps: forwarded to `RMSprop` as `eps`
        :param weight_decay: forwarded to `RMSprop` as `weight_decay`
        :param momentum: forwarded to `RMSprop` as `momentum`
        :param centered: forwarded to `RMSprop` as `centered`
        """
        super().__init__()
        self.lr = lr
        self.alpha = alpha
        self.momentum = momentum
        self.centered = centered
        self.weight_decay = weight_decay
        self.eps = eps

    def _create_optimizer_for_params(self, params: ParamsType) -> torch.optim.Optimizer:
        """
        Create the RMSprop optimizer for the given parameters.

        :param params: the parameters to be optimized
        :return: the optimizer
        """
        hyperparams = {
            "lr": self.lr,
            "alpha": self.alpha,
            "eps": self.eps,
            "weight_decay": self.weight_decay,
            "momentum": self.momentum,
            "centered": self.centered,
        }
        return RMSprop(params, **hyperparams)