Source code for tianshou.utils.torch_utils
from collections.abc import Iterator
from contextlib import contextmanager
from typing import TYPE_CHECKING, overload
import torch
import torch.distributions as dist
from gymnasium import spaces
from torch import nn
if TYPE_CHECKING:
from tianshou.algorithm import algorithm_base
[docs]
@contextmanager
def torch_train_mode(module: nn.Module, enabled: bool = True) -> Iterator[None]:
    """Temporarily switch to `module.training=enabled`, affecting things like `BatchNormalization`.

    On exit, the training flag of `module` and of every submodule it contained
    on entry is restored to the value it had before the context was entered.
    Submodules that were deliberately kept in a different mode than the root
    (e.g. a frozen layer left in eval mode inside a model that is in train mode)
    therefore keep their individual mode.

    :param module: the module whose mode (including all submodules) is switched.
    :param enabled: True to switch to train mode, False to switch to eval mode.
    """
    # Snapshot the per-submodule flags up front: `nn.Module.train` is recursive,
    # so restoring only the root's mode would overwrite any submodule whose mode
    # differed from the root's.
    saved_modes = [(submodule, submodule.training) for submodule in module.modules()]
    original_mode = module.training
    try:
        module.train(enabled)
        yield
    finally:
        # Go through `train()` first so that any override of `train` still runs
        # on exit, then put back the exact flag of each submodule.
        module.train(original_mode)
        for submodule, was_training in saved_modes:
            submodule.training = was_training
[docs]
@contextmanager
def policy_within_training_step(
    policy: "algorithm_base.Policy", enabled: bool = True
) -> Iterator[None]:
    """Set `policy.is_within_training_step` to `enabled` for the duration of the context.

    The flag lets the policy tell training apart from inference/evaluation and
    adapt its behavior accordingly, e.g., by sampling actions rather than taking
    the most probable one (where applicable).

    Rollouts also take place within a training step, but for them one would
    usually want the wrapped torch module to be in evaluation mode; this can be
    achieved with `with torch_train_mode(policy, False)`. For the subsequent
    gradient updates, the policy should be both within a training step and in
    torch train mode.

    The previous value of the flag is put back when the context exits, also if
    the body raises.
    """
    previous_flag = policy.is_within_training_step
    try:
        policy.is_within_training_step = enabled
        yield
    finally:
        # Always restore the flag that was in effect on entry.
        policy.is_within_training_step = previous_flag
# Typing-only overload: for a `spaces.Box` action space the declared return type is `dist.Uniform`.
# NOTE(review): the implementation of `create_uniform_action_dist` is not visible in this
# excerpt; confirm that it follows these overload stubs in the actual module.
@overload
def create_uniform_action_dist(action_space: spaces.Box, batch_size: int = 1) -> dist.Uniform: ...
# Typing-only overload: for a `spaces.Discrete` action space the declared return type is
# `dist.Categorical`.
@overload
def create_uniform_action_dist(
action_space: spaces.Discrete,
batch_size: int = 1,
) -> dist.Categorical: ...
[docs]
def torch_device(module: torch.nn.Module) -> torch.device:
    """Return the device on which the given torch module resides.

    The device is taken from the first parameter yielded by `module.parameters()`.
    A module without any parameters is reported as residing on the CPU.
    """
    first_param = next(module.parameters(), None)
    if first_param is None:
        # No parameters to inspect: fall back to the CPU device.
        return torch.device("cpu")
    return first_param.device