lagged_network


polyak_parameter_update(tgt: Module, src: Module, tau: float) -> None

Softly updates the parameters of a target network tgt with the parameters of a source network src using Polyak averaging: tau * src + (1 - tau) * tgt.

Parameters:
  • tgt – the target network that receives the parameter update

  • src – the source network whose parameters are used for the update

  • tau – the fraction with which to use the source network’s parameters; the complement 1 - tau is the fraction with which to retain the target network’s parameters.
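As a rough illustration of the formula above, the soft update can be sketched in plain PyTorch. The function name `polyak_update_sketch` is hypothetical, and the sketch assumes that `tgt` and `src` have identical architectures, so that `parameters()` yields matching tensors in the same order. It is not the module's actual implementation.

```python
import torch
from torch import nn


def polyak_update_sketch(tgt: nn.Module, src: nn.Module, tau: float) -> None:
    """Hypothetical soft update: tgt <- tau * src + (1 - tau) * tgt, parameter by parameter."""
    with torch.no_grad():  # the update itself must not be tracked by autograd
        for tgt_param, src_param in zip(tgt.parameters(), src.parameters()):
            # In-place blend: scale the target by (1 - tau), then add tau * source
            tgt_param.mul_(1.0 - tau).add_(src_param, alpha=tau)


src = nn.Linear(2, 1)
tgt = nn.Linear(2, 1)
with torch.no_grad():
    src.weight.fill_(1.0)
    src.bias.fill_(1.0)
    tgt.weight.fill_(0.0)
    tgt.bias.fill_(0.0)

polyak_update_sketch(tgt, src, tau=0.1)
print(tgt.weight)  # every entry is now 0.1
```

With `tau = 1` the target becomes an exact copy of the source; small values of `tau` make the target follow the source slowly.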

class EvalModeModuleWrapper(m: Module)

Bases: Module

A wrapper around a torch.nn.Module that forces the module to eval mode.

The wrapped module supports only the forward method; attribute access is not supported. NOTE: It is not recommended to support attribute/method access beyond this via __getattr__, because torch.nn.Module already relies heavily on __getattr__ to provide its own attribute access. Overriding it naively will cause problems! It is also not necessary for our use cases; forward is enough.

Initializes internal Module state, shared by both nn.Module and ScriptModule.
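A minimal sketch of how such a wrapper might be written, assuming it calls `eval()` on the wrapped module at construction, on every forward call, and after every `train()` call. The class name `EvalModeWrapperSketch` and the attribute name `module` are hypothetical.

```python
import torch
from torch import nn


class EvalModeWrapperSketch(nn.Module):
    """Hypothetical stand-in for EvalModeModuleWrapper: keeps the wrapped module in eval mode."""

    def __init__(self, m: nn.Module) -> None:
        super().__init__()
        m.eval()
        self.module = m  # registered as a submodule, so its parameters stay visible

    def forward(self, *args, **kwargs):
        self.module.eval()  # re-assert eval mode on every call
        return self.module(*args, **kwargs)

    def train(self, mode: bool = True) -> "EvalModeWrapperSketch":
        super().train(mode)  # sets self.training and recurses into children...
        self.module.eval()  # ...then forces the wrapped module back to eval mode
        return self


wrapped = EvalModeWrapperSketch(nn.Sequential(nn.Linear(4, 4), nn.Dropout(p=0.5)))
wrapped.train()  # e.g. a parent module switching everything to training mode
x = torch.ones(2, 4)
print(wrapped.module.training)  # False: dropout stays disabled
print(torch.equal(wrapped(x), wrapped(x)))  # True: eval-mode output is deterministic
```

Overriding `train` matters because `nn.Module.train()` recurses into all child modules; without the override, calling `train()` on a parent model would switch an embedded target network back to training mode.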

forward(*args, **kwargs)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
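This is standard `torch.nn.Module` behavior and can be checked directly with a forward hook. `register_forward_hook` is part of the PyTorch API; the `calls` list is only an illustrative recorder.

```python
import torch
from torch import nn

calls = []  # records every time the forward hook fires

layer = nn.Linear(3, 2)
layer.register_forward_hook(lambda module, inputs, output: calls.append("hook"))

x = torch.zeros(1, 3)
layer(x)          # __call__ runs forward *and* the registered hooks
layer.forward(x)  # calling forward directly bypasses the hooks
print(calls)      # ['hook']
```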

train(mode: bool = True) -> Self

Sets the module in training mode. For this wrapper, the wrapped module itself is nevertheless kept in eval mode (see the class description).

This has an effect only on certain modules. See the documentation of particular modules for details of their behavior in training/evaluation mode, if they are affected, e.g. Dropout, BatchNorm, etc.

Args:
  mode (bool): whether to set training mode (True) or evaluation mode (False). Default: True.

Returns:

Module: self
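To illustrate the kind of module this affects, the sketch below uses `torch.nn.Dropout`, whose output differs between the two modes; it also shows that `train()` returns the module itself.

```python
import torch
from torch import nn

drop = nn.Dropout(p=0.5)
x = torch.ones(1000)

print(drop.train() is drop)  # True: train() returns self, so calls can be chained
y_train = drop(x)            # about half the entries zeroed, the rest scaled by 1 / (1 - p) = 2

drop.train(False)            # same as drop.eval()
print(drop.training)         # False
y_eval = drop(x)             # dropout is the identity in evaluation mode
print(torch.equal(y_eval, x))  # True
```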

class LaggedNetworkPair(target: torch.nn.modules.module.Module, source: torch.nn.modules.module.Module)

Bases: object

target: Module
source: Module
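Judging from its constructor signature and its two `Module` attributes, `LaggedNetworkPair` behaves like a simple record type. A sketch, assuming a plain `dataclass` (the name `LaggedNetworkPairSketch` is hypothetical):

```python
import copy
from dataclasses import dataclass

from torch import nn


@dataclass
class LaggedNetworkPairSketch:
    """Hypothetical stand-in: couples a lagged target network with the source it tracks."""

    target: nn.Module
    source: nn.Module


source = nn.Linear(4, 2)
pair = LaggedNetworkPairSketch(target=copy.deepcopy(source), source=source)
print(pair.target is pair.source)  # False: the target is an independent copy
```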

class LaggedNetworkCollection

Bases: object

add_lagged_network(source: Module) -> EvalModeModuleWrapper

Adds a lagged network to the collection and returns the target network, which is forced to eval mode. The target network is a copy of the source network; however, it supports only the forward method (hence it is returned as a generic torch.nn.Module wrapper, EvalModeModuleWrapper, rather than as the source network's own class), and attribute access is not supported.

Parameters:

source – the source network whose parameters are to be copied to the target network

Returns:

the target network, which supports only the forward method and is forced to eval mode
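The copy-and-freeze step described above can be sketched with `copy.deepcopy`, assuming the target is created that way. The helper `make_lagged_copy` is hypothetical and omits both the `EvalModeModuleWrapper` wrapping and the bookkeeping inside the collection.

```python
import copy

import torch
from torch import nn


def make_lagged_copy(source: nn.Module) -> nn.Module:
    """Hypothetical helper: an independent copy of `source`, switched to eval mode."""
    target = copy.deepcopy(source)  # copies parameters and buffers; no shared storage
    target.eval()
    return target


source = nn.Linear(3, 1)
target = make_lagged_copy(source)

with torch.no_grad():
    source.weight.add_(1.0)  # simulate an optimizer step on the source network

print(target.training)                            # False
print(torch.equal(target.weight, source.weight))  # False: the target lags behind the source
```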

polyak_parameter_update(tau: float) -> None

Softly updates the parameters of each target network tgt with the parameters of its corresponding source network src using Polyak averaging: tau * src + (1 - tau) * tgt.

Parameters:

tau – the fraction with which to use the source network’s parameters; the complement 1 - tau is the fraction with which to retain the target network’s parameters.
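Over a whole collection, the soft update is the same per-parameter blend applied to every (target, source) pair. The sketch below is hypothetical (`polyak_update_all` and the plain list of tuples stand in for the collection's internal state) and uses `Tensor.lerp_`, which computes `t + tau * (s - t)`, i.e. exactly `tau * s + (1 - tau) * t`.

```python
import copy

import torch
from torch import nn


def polyak_update_all(pairs: list[tuple[nn.Module, nn.Module]], tau: float) -> None:
    """Hypothetical collection-wide soft update over (target, source) pairs."""
    with torch.no_grad():
        for target, source in pairs:
            for t, s in zip(target.parameters(), source.parameters()):
                t.lerp_(s, tau)  # t <- t + tau * (s - t) = tau * s + (1 - tau) * t


actor, critic = nn.Linear(2, 2), nn.Linear(2, 1)
pairs = [(copy.deepcopy(actor), actor), (copy.deepcopy(critic), critic)]

with torch.no_grad():
    for _, source in pairs:
        for p in source.parameters():
            p.add_(1.0)  # pretend each source network was trained a bit

polyak_update_all(pairs, tau=0.25)
```

After the call, each target has moved a quarter of the way from its old values towards its source's current values.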

full_parameter_update() -> None

Fully updates the target networks with the source networks’ parameters (exact copy).
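For parameters, an exact copy is the special case tau = 1 of the Polyak update. One common way to do it in PyTorch is `target.load_state_dict(source.state_dict())`; note that this also copies buffers (e.g. batch-norm running statistics), and whether this method does so is not stated here. A sketch on a bare `nn.Linear`:

```python
import copy

import torch
from torch import nn

source = nn.Linear(4, 2)
target = copy.deepcopy(source)

with torch.no_grad():
    source.weight.mul_(2.0)  # pretend the source network was trained
    source.bias.add_(0.5)

# Exact copy: every tensor in the target's state dict is overwritten in place
target.load_state_dict(source.state_dict())

src_state = source.state_dict()
print(all(torch.equal(v, src_state[k]) for k, v in target.state_dict().items()))  # True
```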