# Source code for tianshou.data.types
from typing import Protocol
import numpy as np
import torch
from tianshou.data import Batch
from tianshou.data.batch import BatchProtocol, TArr, TObsArr
# Observation payload type: either a plain observation array/tensor
# (``TObsArr``) or a batch object satisfying ``BatchProtocol`` (the latter is
# what dict-observation environments produce, per ObsBatchProtocol.obs below).
TObs = TObsArr | BatchProtocol
# Recursive alias: a value is either a numpy array or a str-keyed dict whose
# values are themselves TNestedDictValue. The string forward reference is what
# lets the alias refer to itself.
TNestedDictValue = np.ndarray | dict[str, "TNestedDictValue"]
[docs]
class ObsBatchProtocol(BatchProtocol, Protocol):
    """Batch of environment observations that a policy turns into actions.

    This is the kind of batch a policy typically receives in its ``forward``.
    """

    #: Observations produced by the environment's ``step``.
    #: For a dict env this is a ``Batch`` instance; otherwise it is an array
    #: (or a tensor, if the env itself returns tensors).
    obs: TObs

    #: Array of the info dicts produced by the environment's ``step``.
    info: TArr
[docs]
class RolloutBatchProtocol(ObsBatchProtocol, Protocol):
    """Observation batch extended with the rest of a transition.

    This is typically what sampling from a replay buffer yields.
    """

    #: Observations that follow ``obs``, as produced by the environment's ``step``.
    #: For a dict env this is a ``Batch`` instance; otherwise it is an array
    #: (or a tensor, if the env itself returns tensors).
    obs_next: TObs
    act: TArr
    rew: np.ndarray
    terminated: TArr
    truncated: TArr
[docs]
class BatchWithReturnsProtocol(RolloutBatchProtocol, Protocol):
    """Rollout batch that also carries ``returns`` (usually computed with GAE)."""

    returns: torch.Tensor
[docs]
class PrioBatchProtocol(RolloutBatchProtocol, Protocol):
    """Rollout batch that carries weights for prioritized replay."""

    #: Weights that can be used for prioritized replay.
    weight: np.ndarray | torch.Tensor
[docs]
class RecurrentStateBatch(BatchProtocol, Protocol):
    """Recurrent state used by RNN policies: a ``hidden`` and a ``cell`` tensor."""

    # NOTE(review): unlike its siblings this name has no ``Protocol`` suffix;
    # it is kept unchanged so existing imports keep working.
    hidden: torch.Tensor
    cell: torch.Tensor
[docs]
class ActBatchProtocol(BatchProtocol, Protocol):
    """Minimal batch holding only the action, e.g. the output of a random policy."""

    act: TArr
[docs]
class ActStateBatchProtocol(ActBatchProtocol, Protocol):
    """Action batch that also carries a (possibly ``None``) state.

    Useful for policies that can support RNNs.
    """

    #: RNN hidden state, or ``None`` when no RNN is used; consumed by recurrent
    #: policies. Support for recurrent policies is currently experimental.
    state: dict | BatchProtocol | np.ndarray | None
[docs]
class ModelOutputBatchProtocol(ActStateBatchProtocol, Protocol):
    """Action/state batch that additionally holds the model output (``logits``)."""

    logits: torch.Tensor
[docs]
class FQFBatchProtocol(ModelOutputBatchProtocol, Protocol):
    """Model-output batch plus the FQF-specific ``fractions`` and ``quantiles_tau``."""

    fractions: torch.Tensor
    quantiles_tau: torch.Tensor
[docs]
class BatchWithAdvantagesProtocol(BatchWithReturnsProtocol, Protocol):
    """Returns batch that also holds estimated advantages and values.

    ``adv`` holds the advantages and ``v_s`` the values; the returns are
    usually obtained from the GAE advantages by adding the value.
    """

    adv: torch.Tensor
    v_s: torch.Tensor
[docs]
class DistBatchProtocol(ModelOutputBatchProtocol, Protocol):
    """Model-output batch holding the action distribution created by ``dist_fn``.

    The distribution is usually categorical or normal.
    """

    dist: torch.distributions.Distribution
[docs]
class DistLogProbBatchProtocol(DistBatchProtocol, Protocol):
    """Distribution batch plus the ``log_prob`` of the action that was taken."""

    log_prob: torch.Tensor
[docs]
class LogpOldProtocol(BatchWithAdvantagesProtocol, Protocol):
    """Advantage/value batch extended with ``logp_old``.

    ``logp_old`` is often needed for importance weights, notably in PPO.
    """

    logp_old: torch.Tensor
[docs]
class QuantileRegressionBatchProtocol(ModelOutputBatchProtocol, Protocol):
    """Model-output batch carrying ``taus`` for quantile-regression algorithms.

    See, e.g., https://arxiv.org/abs/1806.06923
    """

    taus: torch.Tensor
[docs]
class ImitationBatchProtocol(ModelOutputBatchProtocol, Protocol):
    """Model-output batch with extra ``q_value`` and ``imitation_logits`` fields."""

    # NOTE(review): ``state`` is re-declared here with ``Batch`` whereas
    # ActStateBatchProtocol declares it with ``BatchProtocol``. Strict type
    # checkers (e.g. pyright's incompatible-variable-override check) may flag
    # this narrowing of a mutable attribute; kept unchanged — confirm intent.
    state: dict | Batch | np.ndarray | None
    q_value: torch.Tensor
    imitation_logits: torch.Tensor