lagged_network
Source code: tianshou/utils/lagged_network.py
- polyak_parameter_update(tgt: Module, src: Module, tau: float) → None
Softly updates, in place, the parameters of a target network tgt with the parameters of a source network src using Polyak averaging: each target parameter is set to tau * src + (1 - tau) * tgt.
- Parameters:
tgt – the target network that receives the parameter update
src – the source network whose parameters are used for the update
tau – the fraction with which to use the source network’s parameters; its complement 1 - tau is the fraction with which to retain the target network’s parameters.
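A minimal sketch of the soft update described above, assuming both networks have identically shaped parameters that `parameters()` enumerates in the same order. The name `soft_update` is hypothetical; this illustrates the technique and is not necessarily how tianshou implements `polyak_parameter_update`.

```python
import torch
from torch import nn


def soft_update(tgt: nn.Module, src: nn.Module, tau: float) -> None:
    """Hypothetical helper: in-place Polyak averaging tgt <- tau * src + (1 - tau) * tgt."""
    with torch.no_grad():  # the update itself must not be tracked by autograd
        # Assumes both modules yield matching parameters in the same order.
        for tgt_param, src_param in zip(tgt.parameters(), src.parameters()):
            # (1 - tau) * tgt_param + tau * src_param, computed in place
            tgt_param.mul_(1.0 - tau).add_(src_param, alpha=tau)


# Usage: move a target network a quarter of the way towards its source.
torch.manual_seed(0)
src = nn.Linear(2, 2)
tgt = nn.Linear(2, 2)
old_weight = tgt.weight.detach().clone()
soft_update(tgt, src, tau=0.25)
expected = 0.25 * src.weight.detach() + 0.75 * old_weight
print(torch.allclose(tgt.weight, expected))  # True
```

With tau = 1 the target becomes an exact copy of the source (a hard update); with tau = 0 it is left unchanged.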
- class EvalModeModuleWrapper(m: Module)
Bases: Module
A wrapper around a torch.nn.Module that forces the module to eval mode.
The wrapped module supports only the forward method; attribute access is not supported. NOTE: It is not recommended to support attribute/method access beyond this via __getattr__, because torch.nn.Module already relies heavily on __getattr__ to provide its own attribute access. Overriding it naively will cause problems! It is also not necessary for our use cases; forward is enough.
Initializes internal Module state, shared by both nn.Module and ScriptModule.
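The wrapper idea can be sketched as follows. `EvalOnlyWrapper` is a hypothetical name, and the sketch assumes the wrapper should keep the inner module in eval mode even when `train()` is called on the wrapper or on a parent module. In line with the note above, it exposes only `forward`. It is not tianshou's own code.

```python
import torch
from torch import nn


class EvalOnlyWrapper(nn.Module):
    """Hypothetical wrapper that keeps the wrapped module permanently in eval mode."""

    def __init__(self, m: nn.Module) -> None:
        super().__init__()  # initializes the internal nn.Module state
        self.module = m  # registered as a submodule, so .to() and .parameters() see it
        self.module.eval()

    def forward(self, *args, **kwargs):
        self.module.eval()  # defensive: reset the mode even if it was changed directly
        return self.module(*args, **kwargs)

    def train(self, mode: bool = True) -> "EvalOnlyWrapper":
        super().train(mode)  # updates self.training and all submodules as usual
        self.module.eval()  # then forces the wrapped module back to eval mode
        return self


# Usage: dropout stays inactive even after asking for training mode.
wrapped = EvalOnlyWrapper(nn.Dropout(p=0.5))
wrapped.train()
x = torch.ones(4, 8)
print(torch.equal(wrapped(x), x))  # True: dropout is the identity in eval mode
```

Overriding `train` rather than `__getattr__` keeps the mode-forcing logic inside the standard `nn.Module` machinery: `eval()` calls `train(False)`, and a parent's `train()` recurses into children via `train` as well.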
- forward(*args, **kwargs)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
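The difference between calling the module instance and calling `forward` directly can be observed with a forward hook. A small sketch using the standard `register_forward_hook` API; the `calls` list is purely illustrative.

```python
import torch
from torch import nn

calls = []  # records every time the hook fires

layer = nn.Linear(3, 1)
# The hook receives (module, inputs, output); returning None leaves the output unchanged.
layer.register_forward_hook(lambda module, inputs, output: calls.append(output.shape))

x = torch.zeros(2, 3)
layer(x)          # calling the instance runs forward and the registered hooks
layer.forward(x)  # calling forward directly bypasses the hooks
print(len(calls))  # 1
```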
- train(mode: bool = True) → Self
Sets the module in training mode.
This has an effect only on certain modules. See the documentation of particular modules for details of their behavior in training/evaluation mode, if they are affected, e.g. Dropout, BatchNorm, etc.
- Args:
- mode (bool): whether to set training mode (True) or evaluation mode (False). Default: True.
- Returns:
Module: self
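A short illustration of `train(mode)` on a module whose behavior depends on the mode, using `nn.Dropout` as mentioned above. `eval()` is equivalent to `train(False)`, the mode propagates to all submodules, and both calls return the module itself.

```python
import torch
from torch import nn

torch.manual_seed(0)
net = nn.Sequential(nn.Linear(4, 4), nn.Dropout(p=0.5))

assert net.train() is net  # train() returns self, so calls can be chained
print(net[1].training)  # True: the mode propagates to every submodule

net.train(False)  # same as net.eval()
x = torch.randn(2, 4)
# In evaluation mode dropout is the identity, so repeated calls agree exactly.
print(torch.equal(net(x), net(x)))  # True
```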
- class LaggedNetworkPair(target: torch.nn.modules.module.Module, source: torch.nn.modules.module.Module)
Bases: object
- target: Module
- source: Module
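`LaggedNetworkPair` is a plain container that ties a target network to its source. Assuming it behaves like a simple record (the documentation only lists the two attributes), an equivalent dataclass could look like this. `PairSketch` is a hypothetical name, and creating the target via `copy.deepcopy` is an assumption made for illustration.

```python
import copy
from dataclasses import dataclass

from torch import nn


@dataclass
class PairSketch:
    """Hypothetical stand-in for a (target, source) record."""

    target: nn.Module
    source: nn.Module


source = nn.Linear(3, 3)
pair = PairSketch(target=copy.deepcopy(source), source=source)
print(pair.target is pair.source)  # False: the target is an independent copy
```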
- class LaggedNetworkCollection
Bases: object
- add_lagged_network(source: Module) → EvalModeModuleWrapper
Adds a lagged network to the collection and returns the target network, which is forced to eval mode. The target network is a copy of the source network; however, it is returned wrapped in an EvalModeModuleWrapper and therefore supports only the forward method; attribute access is not supported.
- Parameters:
source – the source network whose parameters are to be copied to the target network
- Returns:
the target network, which supports only the forward method and is forced to eval mode
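A minimal sketch of what adding a lagged network involves, assuming the target is created with `copy.deepcopy` and wrapped so that it stays in eval mode. `LaggedCollectionSketch` and `_EvalOnly` are hypothetical names, not tianshou's classes.

```python
import copy

import torch
from torch import nn


class _EvalOnly(nn.Module):
    """Hypothetical minimal wrapper: keeps the wrapped module in eval mode."""

    def __init__(self, m: nn.Module) -> None:
        super().__init__()
        self.module = m.eval()  # eval() returns the module itself

    def forward(self, *args, **kwargs):
        return self.module.eval()(*args, **kwargs)

    def train(self, mode: bool = True) -> "_EvalOnly":
        super().train(mode)
        self.module.eval()  # the wrapped copy never leaves eval mode
        return self


class LaggedCollectionSketch:
    def __init__(self) -> None:
        self.pairs: list[tuple[nn.Module, nn.Module]] = []  # (target, source)

    def add_lagged_network(self, source: nn.Module) -> _EvalOnly:
        target = copy.deepcopy(source)  # independent parameters, same initial values
        self.pairs.append((target, source))
        return _EvalOnly(target)  # callers only get the forward-only, eval-mode view


collection = LaggedCollectionSketch()
actor = nn.Sequential(nn.Linear(4, 4), nn.Dropout(p=0.5))
actor_target = collection.add_lagged_network(actor)
x = torch.randn(2, 4)
# Compare against the source in eval mode: right after copying, outputs match exactly.
print(torch.equal(actor_target(x), actor.eval()(x)))  # True
```

Keeping the unwrapped target in the collection lets the collection update its parameters later, while callers only see the restricted wrapper.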
- polyak_parameter_update(tau: float) → None
Softly updates, in place, the parameters of each target network tgt in the collection with the parameters of its corresponding source network src using Polyak averaging: tau * src + (1 - tau) * tgt.
- Parameters:
tau – the fraction with which to use the source network’s parameters; its complement 1 - tau is the fraction with which to retain the target network’s parameters.
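Updating a whole collection amounts to applying the per-pair Polyak rule to every stored (target, source) pair. A self-contained sketch with hypothetical names (`SoftUpdateCollectionSketch`, `add`), assuming the pairs are stored as plain tuples and the targets are deep copies of their sources.

```python
import copy

import torch
from torch import nn


class SoftUpdateCollectionSketch:
    """Hypothetical collection of (target, source) pairs with a shared soft update."""

    def __init__(self) -> None:
        self.pairs: list[tuple[nn.Module, nn.Module]] = []

    def add(self, source: nn.Module) -> nn.Module:
        target = copy.deepcopy(source)
        self.pairs.append((target, source))
        return target

    def polyak_parameter_update(self, tau: float) -> None:
        with torch.no_grad():  # parameter averaging is not part of the autograd graph
            for target, source in self.pairs:
                for t, s in zip(target.parameters(), source.parameters()):
                    t.mul_(1.0 - tau).add_(s, alpha=tau)  # t <- tau * s + (1 - tau) * t


collection = SoftUpdateCollectionSketch()
critic = nn.Linear(2, 1)
critic_target = collection.add(critic)

with torch.no_grad():
    critic.weight.add_(1.0)  # pretend an optimizer step moved the source weights by +1

old = critic_target.weight.detach().clone()
collection.polyak_parameter_update(tau=0.1)
# The target moved 10% of the way towards the source, i.e. by 0.1 per weight.
print(torch.allclose(critic_target.weight - old, torch.full_like(old, 0.1), atol=1e-6))  # True
```

A small tau makes each target trail its source slowly, which is what gives the "lagged" networks their name.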