Source code for tianshou.highlevel.params.dist_fn

from abc import ABC, abstractmethod
from collections.abc import Callable
from typing import Any

import torch
from sensai.util.string import ToStringMixin

from tianshou.algorithm.modelfree.reinforce import TDistFnDiscrete, TDistFnDiscrOrCont
from tianshou.highlevel.env import Environments


[docs] class DistributionFunctionFactory(ToStringMixin, ABC): # True return type defined in subclasses
[docs] @abstractmethod def create_dist_fn( self, envs: Environments, ) -> Callable[[Any], torch.distributions.Distribution]: pass
[docs] class DistributionFunctionFactoryCategorical(DistributionFunctionFactory): def __init__(self, is_probs_input: bool = True): """ :param is_probs_input: If True, the distribution function shall create a categorical distribution from a tensor containing probabilities; otherwise the tensor is assumed to contain logits. """ self.is_probs_input = is_probs_input
[docs] def create_dist_fn(self, envs: Environments) -> TDistFnDiscrete: envs.get_type().assert_discrete(self) if self.is_probs_input: return self._dist_fn_probs else: return self._dist_fn
# NOTE: Do not move/rename because a reference to the function can appear in persisted policies @staticmethod def _dist_fn(logits: torch.Tensor) -> torch.distributions.Categorical: return torch.distributions.Categorical(logits=logits) # NOTE: Do not move/rename because a reference to the function can appear in persisted policies @staticmethod def _dist_fn_probs(probs: torch.Tensor) -> torch.distributions.Categorical: return torch.distributions.Categorical(probs=probs)
[docs] class DistributionFunctionFactoryIndependentGaussians(DistributionFunctionFactory):
[docs] def create_dist_fn(self, envs: Environments) -> TDistFnDiscrOrCont: envs.get_type().assert_continuous(self) return self._dist_fn
# NOTE: Do not move/rename because a reference to the function can appear in persisted policies @staticmethod def _dist_fn( loc_scale: tuple[torch.Tensor, torch.Tensor], ) -> torch.distributions.Distribution: loc, scale = loc_scale return torch.distributions.Independent(torch.distributions.Normal(loc, scale), 1)