# Source code for tianshou.utils.lagged_network
from copy import deepcopy
from dataclasses import dataclass
from typing import Self
import torch
[docs]
def polyak_parameter_update(tgt: torch.nn.Module, src: torch.nn.Module, tau: float) -> None:
"""Softly updates the parameters of a target network `tgt` with the parameters of a source network `src`
using Polyak averaging: `tau * src + (1 - tau) * tgt`.
:param tgt: the target network that receives the parameter update
:param src: the source network whose parameters are used for the update
:param tau: the fraction with which to use the source network's parameters, the inverse `1-tau` being
the fraction with which to retain the target network's parameters.
"""
for tgt_param, src_param in zip(tgt.parameters(), src.parameters(), strict=True):
tgt_param.data.copy_(tau * src_param.data + (1 - tau) * tgt_param.data)
[docs]
class EvalModeModuleWrapper(torch.nn.Module):
"""
A wrapper around a torch.nn.Module that forces the module to eval mode.
The wrapped module supports only the forward method, attribute access is not supported.
**NOTE**: It is *not* recommended to support attribute/method access beyond this via `__getattr__`,
because torch.nn.Module already heavily relies on `__getattr__` to provides its own attribute access.
Overriding it naively will cause problems!
But it's also not necessary for our use cases; forward is enough.
"""
def __init__(self, m: torch.nn.Module):
super().__init__()
m.eval()
self.module = m
[docs]
def forward(self, *args, **kwargs): # type: ignore
self.module.eval()
return self.module(*args, **kwargs)
[docs]
def train(self, mode: bool = True) -> Self:
super().train(mode=mode)
self.module.eval() # force eval mode
return self
[docs]
@dataclass
class LaggedNetworkPair:
    """Associates a lagged target network with the source network it is updated from."""

    # the network that receives the parameter updates
    target: torch.nn.Module
    # the network whose parameters are used for the updates
    source: torch.nn.Module
[docs]
class LaggedNetworkCollection:
def __init__(self) -> None:
self._lagged_network_pairs: list[LaggedNetworkPair] = []
[docs]
def add_lagged_network(self, source: torch.nn.Module) -> EvalModeModuleWrapper:
"""
Adds a lagged network to the collection, returning the target network, which
is forced to eval mode. The target network is a copy of the source network,
which, however, supports only the forward method (hence the type torch.nn.Module);
attribute access is not supported.
:param source: the source network whose parameters are to be copied to the target network
:return: the target network, which supports only the forward method and is forced to eval mode
"""
target = deepcopy(source)
self._lagged_network_pairs.append(LaggedNetworkPair(target, source))
return EvalModeModuleWrapper(target)
[docs]
def polyak_parameter_update(self, tau: float) -> None:
"""Softly updates the parameters of each target network `tgt` with the parameters of a source network `src`
using Polyak averaging: `tau * src + (1 - tau) * tgt`.
:param tau: the fraction with which to use the source network's parameters, the inverse `1-tau` being
the fraction with which to retain the target network's parameters.
"""
for pair in self._lagged_network_pairs:
polyak_parameter_update(pair.target, pair.source, tau)
[docs]
def full_parameter_update(self) -> None:
"""Fully updates the target networks with the source networks' parameters (exact copy)."""
for pair in self._lagged_network_pairs:
for tgt_param, src_param in zip(
pair.target.parameters(), pair.source.parameters(), strict=True
):
tgt_param.data.copy_(src_param.data)