optim


class OptimizerWithLearningRateProtocol(*args, **kwargs) [source]

Bases: Protocol

class OptimizerFactoryFactory [source]

Bases: ABC, ToStringMixin

static default() → OptimizerFactoryFactory [source]
abstract create_optimizer_factory(lr: float) → OptimizerFactory [source]
class OptimizerFactoryFactoryTorch(optim_class: OptimizerWithLearningRateProtocol, **kwargs: Any) [source]

Bases: OptimizerFactoryFactory

Factory for torch optimizers: given a learning rate, it creates an OptimizerFactory for the configured optimizer class.

Parameters:
  • optim_class – the optimizer class (e.g. a subclass of torch.optim.Optimizer). Its constructor receives the module parameters, the learning rate (as the keyword argument lr), and the provided kwargs.

  • kwargs – additional keyword arguments passed to the optimizer constructor

create_optimizer_factory(lr: float) → OptimizerFactory [source]
class OptimizerFactoryFactoryAdam(betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-08, weight_decay: float = 0) [source]

Bases: OptimizerFactoryFactory

create_optimizer_factory(lr: float) → AdamOptimizerFactory [source]
class OptimizerFactoryFactoryRMSprop(alpha: float = 0.99, eps: float = 1e-08, weight_decay: float = 0, momentum: float = 0, centered: bool = False) [source]

Bases: OptimizerFactoryFactory

create_optimizer_factory(lr: float) → RMSpropOptimizerFactory [source]