types


class ObsBatchProtocol(*args, **kwargs)

Bases: BatchProtocol, Protocol

Observations of an environment that a policy can turn into actions.

Typically used inside a policy's forward method.

obs: Tensor | ndarray | BatchProtocol
info: Tensor | ndarray | BatchProtocol
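A minimal sketch of how a protocol like this works for static typing. ObsBatchLike and SimpleBatch are hypothetical stand-ins (not Tianshou classes), and lists stand in for arrays. typing.cast only tells the type checker which fields the batch is expected to carry; it performs no runtime check.

```python
from typing import Any, Protocol, cast


class ObsBatchLike(Protocol):
    """Hypothetical stand-in mirroring the fields of ObsBatchProtocol."""

    obs: Any
    info: Any


class SimpleBatch:
    """Hypothetical minimal batch: stores keyword arguments as attributes."""

    def __init__(self, **kwargs: Any) -> None:
        self.__dict__.update(kwargs)


def forward(batch: ObsBatchLike) -> list[int]:
    # A toy "policy": choose action 1 whenever the observation is positive.
    return [1 if o > 0 else 0 for o in batch.obs]


# cast() only informs the type checker; at runtime it returns its argument unchanged.
obs_batch = cast(ObsBatchLike, SimpleBatch(obs=[0.5, -1.0, 2.0], info={}))
print(forward(obs_batch))  # [1, 0, 1]
```

Because protocols are structural, any object with obs and info attributes satisfies the annotation; no inheritance is required.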
class RolloutBatchProtocol(*args, **kwargs)

Bases: ObsBatchProtocol, Protocol

Typically, the outcome of sampling from a replay buffer.

obs_next: Tensor | ndarray | BatchProtocol
act: Tensor | ndarray
rew: ndarray
terminated: Tensor | ndarray
truncated: Tensor | ndarray
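A minimal sketch of where such a batch comes from, assuming a hypothetical TinyReplayBuffer (not Tianshou's ReplayBuffer) that keeps the most recent transitions and samples indices uniformly with replacement. The sampled dictionary holds one list per field named in RolloutBatchProtocol and ObsBatchProtocol.

```python
import random
from dataclasses import dataclass
from typing import Any


@dataclass
class Transition:
    # One stored step; field names follow RolloutBatchProtocol.
    obs: Any
    act: int
    rew: float
    terminated: bool
    truncated: bool
    obs_next: Any
    info: dict


class TinyReplayBuffer:
    """Hypothetical fixed-capacity FIFO buffer, for illustration only."""

    def __init__(self, capacity: int) -> None:
        self.capacity = capacity
        self.data: list[Transition] = []

    def add(self, t: Transition) -> None:
        if len(self.data) >= self.capacity:
            self.data.pop(0)  # drop the oldest transition
        self.data.append(t)

    def sample(self, batch_size: int, rng: random.Random) -> dict[str, list]:
        # Sample indices uniformly with replacement, then gather each field.
        idx = [rng.randrange(len(self.data)) for _ in range(batch_size)]
        fields = ["obs", "act", "rew", "terminated", "truncated", "obs_next", "info"]
        return {f: [getattr(self.data[i], f) for i in idx] for f in fields}


buf = TinyReplayBuffer(capacity=3)
for step in range(5):
    buf.add(Transition(obs=step, act=step % 2, rew=1.0, terminated=step == 4,
                       truncated=False, obs_next=step + 1, info={}))
batch = buf.sample(4, random.Random(0))
print(batch["obs"])  # values drawn from steps 2, 3 and 4, the only ones still stored
```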
class BatchWithReturnsProtocol(*args, **kwargs)

Bases: RolloutBatchProtocol, Protocol

Adds returns, usually computed with GAE (generalized advantage estimation).

returns: Tensor | ndarray
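As a simpler point of comparison to GAE, the sketch below computes plain discounted returns, G_t = r_t + gamma * G_{t+1}, restarting at each episode end. The function name discounted_returns is hypothetical. For a complete episode, GAE with lambda = 1 yields returns (advantage plus value) equal to this discounted sum.

```python
def discounted_returns(rews: list[float], dones: list[bool],
                       gamma: float = 0.99) -> list[float]:
    """Return-to-go G_t = r_t + gamma * G_{t+1}, restarted at each episode end."""
    returns = [0.0] * len(rews)
    running = 0.0
    for t in reversed(range(len(rews))):
        if dones[t]:
            running = 0.0  # step t ends an episode: nothing after it counts
        running = rews[t] + gamma * running
        returns[t] = running
    return returns


print(discounted_returns([1.0, 1.0, 1.0], [False, False, True], gamma=0.5))
# [1.75, 1.5, 1.0]
```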
class PrioBatchProtocol(*args, **kwargs)

Bases: RolloutBatchProtocol, Protocol

Contains weights that can be used for prioritized replay.

weight: ndarray | Tensor
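A sketch of how such weights are commonly derived in prioritized experience replay (Schaul et al.): sampling probabilities P(i) = p_i^alpha / sum_k p_k^alpha and importance-sampling weights w_i = (N * P(i))^(-beta), scaled so the largest weight is 1. The function names and the default alpha and beta are assumptions, and normalizing over the sampled batch rather than the whole buffer is a simplification chosen here.

```python
def per_weights(priorities: list[float], idx: list[int],
                alpha: float = 0.6, beta: float = 0.4) -> list[float]:
    # Sampling probabilities: P(i) = p_i^alpha / sum_k p_k^alpha.
    scaled = [p ** alpha for p in priorities]
    total = sum(scaled)
    probs = [s / total for s in scaled]
    n = len(priorities)
    # Importance-sampling weights w_i = (N * P(i))^(-beta) for the sampled indices.
    raw = [(n * probs[i]) ** (-beta) for i in idx]
    # Normalize by the largest weight in the batch so all weights are at most 1.
    max_w = max(raw)
    return [w / max_w for w in raw]


def weighted_mse(td_errors: list[float], weights: list[float]) -> float:
    # Each squared TD error is scaled by its importance-sampling weight.
    return sum(w * e * e for w, e in zip(weights, td_errors)) / len(td_errors)


weights = per_weights([1.0, 4.0], idx=[0, 1], alpha=1.0, beta=1.0)
print(weights)  # approximately [1.0, 0.25]
```

Transitions with higher priority are sampled more often, so they receive smaller weights, which offsets the bias introduced by non-uniform sampling.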
class RecurrentStateBatch(*args, **kwargs)

Bases: BatchProtocol, Protocol

Used by recurrent (RNN-based) policies; contains the hidden and cell states (as in an LSTM).

hidden: Tensor
cell: Tensor
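A common pattern for LSTM-style policies, assumed here rather than taken from this module, keeps one row of hidden and cell state per environment and zeroes a row when that environment's episode ends, so the next episode does not inherit memory from the previous one. LSTMState, zero_state and reset_done are hypothetical names; plain lists stand in for tensors.

```python
from dataclasses import dataclass


@dataclass
class LSTMState:
    """Hypothetical stand-in for RecurrentStateBatch: one row per environment."""

    hidden: list[list[float]]  # shape (num_envs, hidden_dim)
    cell: list[list[float]]    # shape (num_envs, hidden_dim)


def zero_state(num_envs: int, hidden_dim: int) -> LSTMState:
    # Fresh all-zero state at the start of training or evaluation.
    return LSTMState(
        hidden=[[0.0] * hidden_dim for _ in range(num_envs)],
        cell=[[0.0] * hidden_dim for _ in range(num_envs)],
    )


def reset_done(state: LSTMState, done: list[bool]) -> LSTMState:
    # Zero the rows of environments whose episode just ended; copy the rest.
    hidden_dim = len(state.hidden[0])
    return LSTMState(
        hidden=[[0.0] * hidden_dim if d else row[:] for row, d in zip(state.hidden, done)],
        cell=[[0.0] * hidden_dim if d else row[:] for row, d in zip(state.cell, done)],
    )


s = LSTMState(hidden=[[1.0, 2.0], [3.0, 4.0]], cell=[[5.0, 6.0], [7.0, 8.0]])
print(reset_done(s, [False, True]).hidden)  # [[1.0, 2.0], [0.0, 0.0]]
```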
class ActBatchProtocol(*args, **kwargs)

Bases: BatchProtocol, Protocol

The simplest batch, containing only the action. Useful, e.g., for a random policy.

act: Tensor | ndarray
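A random policy needs to emit nothing but actions. The sketch below uses a hypothetical ActBatch dataclass in place of ActBatchProtocol and assumes a discrete action space with num_actions actions.

```python
import random
from dataclasses import dataclass


@dataclass
class ActBatch:
    """Hypothetical stand-in for ActBatchProtocol: only an action field."""

    act: list[int]


class RandomDiscretePolicy:
    """Hypothetical random policy: ignores observations, samples uniform actions."""

    def __init__(self, num_actions: int, seed: int = 0) -> None:
        self.num_actions = num_actions
        self.rng = random.Random(seed)

    def forward(self, obs: list) -> ActBatch:
        # One action per observation in the batch; no logits or state needed.
        return ActBatch(act=[self.rng.randrange(self.num_actions) for _ in obs])


policy = RandomDiscretePolicy(num_actions=3)
print(policy.forward([None] * 5).act)  # five actions, each in {0, 1, 2}
```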
class ActStateBatchProtocol(*args, **kwargs)

Bases: ActBatchProtocol, Protocol

Contains the action and the state (which can be None); useful for policies that support RNNs.

state: dict | BatchProtocol | ndarray | None

Hidden state of the RNN, or None if no RNN is used. Used by recurrent policies. Support for recurrent policies is currently experimental.
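The sketch below contrasts the two cases the state field covers: a feed-forward policy that returns state=None, and a toy policy that threads a state through successive calls the way a recurrent policy carries its hidden state. ActStateBatch, stateless_forward and counting_forward are hypothetical names.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class ActStateBatch:
    """Hypothetical stand-in for ActStateBatchProtocol."""

    act: list[int]
    state: Optional[dict]  # None for feed-forward policies


def stateless_forward(obs: list[float], state: Optional[dict] = None) -> ActStateBatch:
    # A feed-forward policy has no memory, so it always returns state=None.
    return ActStateBatch(act=[int(o > 0) for o in obs], state=None)


def counting_forward(obs: list[float], state: Optional[dict] = None) -> ActStateBatch:
    # A toy "recurrent" policy: its state counts the calls made so far,
    # and the chosen action alternates with that count.
    steps = 0 if state is None else state["steps"]
    return ActStateBatch(act=[steps % 2 for _ in obs], state={"steps": steps + 1})


out = counting_forward([0.1, 0.2])
out = counting_forward([0.1, 0.2], state=out.state)  # feed the state back in
print(out.act, out.state)  # [1, 1] {'steps': 2}
```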

class ModelOutputBatchProtocol(*args, **kwargs)

Bases: ActStateBatchProtocol, Protocol

In addition to the state and action, contains the model output (logits).

logits: Tensor
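For a value-based policy whose logits hold one score per action (an assumption; what the field means depends on the policy), the action can be read off the logits greedily. ModelOutputBatch and greedy_forward are hypothetical names.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class ModelOutputBatch:
    """Hypothetical stand-in for ModelOutputBatchProtocol."""

    logits: list[list[float]]  # one row of per-action scores per batch element
    act: list[int]
    state: Optional[dict]


def greedy_forward(logits: list[list[float]]) -> ModelOutputBatch:
    # Pick the index of the largest score in each row (the first one on ties).
    act = [max(range(len(row)), key=row.__getitem__) for row in logits]
    return ModelOutputBatch(logits=logits, act=act, state=None)


out = greedy_forward([[0.1, 2.0, -1.0], [3.0, 0.0, 0.0]])
print(out.act)  # [1, 0]
```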
class FQFBatchProtocol(*args, **kwargs)

Bases: ModelOutputBatchProtocol, Protocol

Model outputs plus fractions and quantiles_tau, specific to the FQF (fully parameterized quantile function) model.

fractions: Tensor
quantiles_tau: Tensor
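In the FQF paper (Yang et al., 2019), the proposed fractions satisfy 0 = tau_0 < tau_1 < ... < tau_N = 1, the quantile function is evaluated at the midpoints tau_hat_i = (tau_i + tau_{i+1}) / 2, and Q(s, a) is estimated as sum_i (tau_{i+1} - tau_i) * F^{-1}(tau_hat_i). The sketch below reproduces only that arithmetic on plain lists; it assumes nothing about how the fractions and quantiles_tau fields are laid out in this library, and the quantile values are made up.

```python
def fqf_midpoints(fractions: list[float]) -> list[float]:
    # tau_hat_i = (tau_i + tau_{i+1}) / 2 for each pair of adjacent fractions.
    return [(a + b) / 2 for a, b in zip(fractions[:-1], fractions[1:])]


def fqf_q_value(fractions: list[float], quantile_values: list[float]) -> float:
    # Q = sum_i (tau_{i+1} - tau_i) * F^{-1}(tau_hat_i), where quantile_values[i]
    # is the return quantile evaluated at the midpoint tau_hat_i.
    widths = [b - a for a, b in zip(fractions[:-1], fractions[1:])]
    if len(widths) != len(quantile_values):
        raise ValueError("need one quantile value per fraction interval")
    return sum(w * q for w, q in zip(widths, quantile_values))


fractions = [0.0, 0.25, 1.0]                # tau_0 = 0 < tau_1 < tau_2 = 1
print(fqf_midpoints(fractions))             # [0.125, 0.625]
print(fqf_q_value(fractions, [-1.0, 2.0]))  # 0.25 * -1.0 + 0.75 * 2.0 = 1.25
```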
class BatchWithAdvantagesProtocol(*args, **kwargs)

Bases: BatchWithReturnsProtocol, Protocol

Contains estimated advantages and values.

Returns are usually computed by adding the value estimates to the GAE advantages.

adv: Tensor
v_s: Tensor
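A minimal sketch of generalized advantage estimation over one flat segment of transitions, followed by returns = advantages + values as described above. The function name gae and its argument layout are assumptions, not Tianshou's API; every done flag is treated as a true termination, so time-limit truncation (where one would still bootstrap from V(s_{t+1})) would need a separate flag.

```python
def gae(rews: list[float], values: list[float], next_values: list[float],
        dones: list[bool], gamma: float = 0.99, lam: float = 0.95
        ) -> tuple[list[float], list[float]]:
    """GAE over one flat segment; returns (advantages, returns).

    values[t] = V(s_t), next_values[t] = V(s_{t+1}); dones[t] marks a terminal step.
    """
    adv = [0.0] * len(rews)
    running = 0.0
    for t in reversed(range(len(rews))):
        mask = 0.0 if dones[t] else 1.0
        # One-step TD error: delta_t = r_t + gamma * V(s_{t+1}) - V(s_t).
        delta = rews[t] + gamma * next_values[t] * mask - values[t]
        # A_t = delta_t + gamma * lambda * A_{t+1}, cut at episode ends.
        running = delta + gamma * lam * mask * running
        adv[t] = running
    # Returns are the advantages plus the value estimates.
    returns = [a + v for a, v in zip(adv, values)]
    return adv, returns


adv, ret = gae([1.0, 1.0, 1.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0],
               [False, False, True], gamma=0.5, lam=1.0)
print(adv, ret)  # [1.75, 1.5, 1.0] [1.75, 1.5, 1.0]
```

With zero value estimates and lambda = 1, the advantages reduce to discounted returns, which makes the recursion easy to check by hand.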
class DistBatchProtocol(*args, **kwargs)

Bases: ModelOutputBatchProtocol, Protocol

Contains distribution instances over actions (created by dist_fn).

Usually categorical or normal.

dist: Distribution
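A minimal sketch of what a dist_fn does for discrete actions: it maps the model output (logits) to a distribution that can be sampled. In PyTorch-based code this role is typically played by a class such as torch.distributions.Categorical; the Categorical class and dist_fn below are hypothetical, stdlib-only stand-ins.

```python
import math
import random


class Categorical:
    """Hypothetical minimal categorical distribution parameterized by logits."""

    def __init__(self, logits: list[float]) -> None:
        m = max(logits)  # subtract the max logit so exp() cannot overflow
        exps = [math.exp(x - m) for x in logits]
        total = sum(exps)
        self.probs = [e / total for e in exps]  # softmax of the logits

    def sample(self, rng: random.Random) -> int:
        # Draw one action index with probability proportional to self.probs.
        return rng.choices(range(len(self.probs)), weights=self.probs, k=1)[0]


def dist_fn(logits: list[float]) -> Categorical:
    # Stand-in for a dist_fn: turns model output into an action distribution.
    return Categorical(logits)


dist = dist_fn([2.0, 0.0, -1.0])
action = dist.sample(random.Random(0))
print(dist.probs, action)
```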
class DistLogProbBatchProtocol(*args, **kwargs)

Bases: DistBatchProtocol, Protocol

Contains distribution objects that can be sampled from, and the log-probability (log_prob) of the action taken.

log_prob: Tensor
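For a categorical distribution, the log-probability of a taken action is the log-softmax of the logits at that index, log_prob(a) = logits[a] - logsumexp(logits). The function categorical_log_prob below is a hypothetical stdlib-only sketch of that formula; a PyTorch distribution would provide an equivalent log_prob method.

```python
import math


def categorical_log_prob(logits: list[float], action: int) -> float:
    # log softmax(logits)[action] = logits[action] - logsumexp(logits);
    # subtracting the max logit keeps exp() from overflowing.
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return logits[action] - log_z


logits = [2.0, 0.5, -1.0]
taken = [0, 2, 1]  # actions actually taken, one per batch element
log_probs = [categorical_log_prob(logits, a) for a in taken]
print(log_probs)  # all values are <= 0
```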
class LogpOldProtocol(*args, **kwargs)

Bases: BatchWithAdvantagesProtocol, Protocol

Contains logp_old, the log-probability of the taken action under the old (pre-update) policy, which is often needed for importance weights, in particular in PPO.

Builds on batches that contain advantages and values.

logp_old: Tensor
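A minimal sketch of how logp_old enters PPO's clipped surrogate objective: the importance weight is ratio = exp(log_prob - logp_old), and the objective averages min(ratio * adv, clip(ratio, 1 - eps, 1 + eps) * adv). The function name ppo_clip_loss and the plain-list inputs are assumptions for illustration.

```python
import math


def ppo_clip_loss(log_prob: list[float], logp_old: list[float], adv: list[float],
                  eps: float = 0.2) -> float:
    """PPO clipped surrogate objective, negated so that it can be minimized."""
    total = 0.0
    for lp, lp_old, a in zip(log_prob, logp_old, adv):
        # Importance weight pi_new(a|s) / pi_old(a|s), computed in log space.
        ratio = math.exp(lp - lp_old)
        clipped = min(max(ratio, 1.0 - eps), 1.0 + eps)
        # Pessimistic (lower) of the unclipped and clipped objectives.
        total += min(ratio * a, clipped * a)
    return -total / len(adv)


# Without any policy change the ratio is 1, so the loss is minus the mean advantage.
print(ppo_clip_loss([0.0, 0.0], [0.0, 0.0], [1.0, -3.0]))  # 1.0
```

Clipping only limits how much a positive advantage can be exploited by a large ratio; with a negative advantage, the unclipped term is kept whenever it is worse.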
class QuantileRegressionBatchProtocol(*args, **kwargs)

Bases: ModelOutputBatchProtocol, Protocol

Contains taus for algorithms using quantile regression.

See, e.g., https://arxiv.org/abs/1806.06923

taus: Tensor
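The role of each tau is clearest in the quantile regression ("pinball") loss, rho_tau(u) = u * (tau - 1{u < 0}) with u = target - prediction: minimizing it over a prediction yields the tau-quantile of the targets. The sketch below, with the hypothetical function pinball_loss, uses the plain form; distributional RL methods typically use a Huber-smoothed variant.

```python
def pinball_loss(pred: float, targets: list[float], tau: float) -> float:
    # rho_tau(u) = u * (tau - 1{u < 0}), with u = target - prediction,
    # averaged over the target samples.
    total = 0.0
    for t in targets:
        u = t - pred
        total += u * (tau - (1.0 if u < 0 else 0.0))
    return total / len(targets)


samples = [1.0, 2.0, 3.0, 4.0, 10.0]
grid = [x / 10 for x in range(0, 101)]  # candidate predictions 0.0, 0.1, ..., 10.0
best = min(grid, key=lambda p: pinball_loss(p, samples, tau=0.5))
print(best)  # 3.0, the median of the samples
```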
class ImitationBatchProtocol(*args, **kwargs)

Bases: ActBatchProtocol, Protocol

Similar to other batches, but contains imitation_logits and q_value fields.

state: dict | Batch | ndarray | None
q_value: Tensor
imitation_logits: Tensor
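One way both fields can be combined is the action rule of discrete BCQ (batch-constrained Q-learning): actions whose imitation probability, relative to the most likely action, falls below a threshold are excluded, and the greedy action is taken among the rest. This is an assumed use of the fields, shown for a single state; bcq_select and the threshold value are hypothetical.

```python
import math


def bcq_select(q_value: list[float], imitation_logits: list[float],
               threshold: float = 0.3) -> int:
    """Discrete-BCQ-style action choice for one state (illustrative sketch)."""
    m = max(imitation_logits)
    # exp(logit - max logit) equals prob(a) / max_a' prob(a') under a softmax.
    ratios = [math.exp(x - m) for x in imitation_logits]
    # The most likely action has ratio 1, so this list is never empty
    # as long as threshold <= 1.
    allowed = [a for a, r in enumerate(ratios) if r >= threshold]
    return max(allowed, key=lambda a: q_value[a])


# Action 0 has the highest Q-value but is rarely taken in the data, so it is filtered out.
print(bcq_select([10.0, 1.0, 0.0], [-5.0, 0.0, 0.0]))  # 1
```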