optim


class LRSchedulerFactory [source]

Bases: ToStringMixin, ABC

Factory for the creation of a learning rate scheduler.

abstract create_scheduler(optim: Optimizer) → LRScheduler [source]
class LRSchedulerFactoryLinear(max_epochs: int, epoch_num_steps: int, collection_step_num_env_steps: int) [source]

Bases: LRSchedulerFactory

Factory for a learning rate scheduler under which the learning rate decays linearly towards zero, based on the given trainer parameters (max_epochs, epoch_num_steps and collection_step_num_env_steps).

create_scheduler(optim: Optimizer) → LRScheduler [source]
class OptimizerFactory [source]

Bases: ABC, ToStringMixin

Factory for the creation of an optimizer and, optionally, a learning rate scheduler for it.

with_lr_scheduler_factory(lr_scheduler_factory: LRSchedulerFactory) → Self [source]
create_instances(module: Module) → tuple[Optimizer, LRScheduler | None] [source]
class TorchOptimizerFactory(optim_class: Callable[..., Optimizer], **kwargs: Any) [source]

Bases: OptimizerFactory

General factory for arbitrary torch optimizers.

Parameters:
  • optim_class – the optimizer class (e.g. a subclass of torch.optim.Optimizer), which will be instantiated with the module parameters and the provided kwargs (e.g. the learning rate as lr).

  • kwargs – keyword arguments to pass to the optimizer constructor

class AdamOptimizerFactory(lr: float = 0.001, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-08, weight_decay: float = 0) [source]

Bases: OptimizerFactory

class RMSpropOptimizerFactory(lr: float = 0.01, alpha: float = 0.99, eps: float = 1e-08, weight_decay: float = 0, momentum: float = 0, centered: bool = False) [source]

Bases: OptimizerFactory