Source code for tianshou.policy.modelfree.rainbow
from dataclasses import dataclass
from typing import Any, TypeVar
from torch import nn
from tianshou.data.types import RolloutBatchProtocol
from tianshou.policy import C51Policy
from tianshou.policy.modelfree.c51 import C51TrainingStats
from tianshou.utils.net.discrete import NoisyLinear
# TODO: this is a hacky function that interweaves side effects with a return value. Should improve.
def _sample_noise(model: nn.Module) -> bool:
    """Resample the noise of every NoisyLinear submodule of ``model``.

    Every NoisyLinear module yielded by ``model.modules()`` has its
    ``sample()`` method called; the return value reports whether any
    such module was encountered.

    :param model: a PyTorch module which may have NoisyLinear submodules.
    :returns: True if model has at least one NoisyLinear submodule;
        otherwise, False.
    """
    noisy_layers = [
        layer for layer in model.modules() if isinstance(layer, NoisyLinear)
    ]
    # Every layer must be resampled, so no short-circuiting here.
    for layer in noisy_layers:
        layer.sample()
    return len(noisy_layers) > 0
[docs]
@dataclass(kw_only=True)
class RainbowTrainingStats(C51TrainingStats):
    """Training statistics produced by :class:`RainbowPolicy`'s ``learn`` step."""

    # Loss reported by the training step.
    # NOTE(review): this may redeclare a field already defined on
    # C51TrainingStats — confirm whether the redeclaration is needed.
    loss: float


# Type variable bound to RainbowTrainingStats; RainbowPolicy is generic over it
# so subclasses can return a more specific training-stats type.
TRainbowTrainingStats = TypeVar("TRainbowTrainingStats", bound=RainbowTrainingStats)
# TODO: is this class worth keeping? It barely does anything
[docs]
class RainbowPolicy(C51Policy[TRainbowTrainingStats]):
    """Implementation of Rainbow DQN. arXiv:1710.02298.

    Same parameters as :class:`~tianshou.policy.C51Policy`.

    .. seealso::

        Please refer to :class:`~tianshou.policy.C51Policy` for more detailed
        explanation.
    """

    def learn(
        self,
        batch: RolloutBatchProtocol,
        *args: Any,
        **kwargs: Any,
    ) -> TRainbowTrainingStats:
        """Resample NoisyLinear noise, then perform the C51 update on ``batch``.

        Fresh noise is sampled for every NoisyLinear layer of ``self.model``.
        If a target network is in use (``self._target`` is truthy) and
        ``self.model_old`` contains NoisyLinear layers, their noise is
        resampled as well and ``self.model_old`` is put in training mode so
        that the NoisyLinear noise takes effect.

        :param batch: the rollout batch to learn from.
        :param args: extra positional arguments, forwarded unchanged to
            :meth:`C51Policy.learn`.
        :param kwargs: extra keyword arguments, forwarded unchanged to
            :meth:`C51Policy.learn`.
        :returns: the training statistics returned by :meth:`C51Policy.learn`.
        """
        _sample_noise(self.model)
        # Short-circuit: the target network is only touched when it exists.
        if self._target and _sample_noise(self.model_old):
            self.model_old.train()  # so that NoisyLinear takes effect
        # Forward *args too; they were previously accepted but silently dropped.
        return super().learn(batch, *args, **kwargs)