collector#


class CollectActionBatchProtocol(*args, **kwargs)[source]#

Bases: Protocol

A protocol for results of computing actions from a batch of observations within a single collect step.

All fields have length R (the dist is a Distribution of batch size R), where R is the number of ready envs.

act: ndarray | Tensor#
act_normalized: ndarray | Tensor#
policy_entry: Batch#
dist: Distribution | None#
hidden_state: ndarray | Tensor | Batch | None#
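
For illustration, a conforming action batch for R = 4 ready envs might look roughly as follows. This is a hedged sketch with made-up values (not library output) and assumes that a Batch can hold a torch Distribution:

import numpy as np
import torch
from torch.distributions import Categorical

from tianshou.data import Batch

R = 4  # number of ready envs
action_batch = Batch(
    act=np.array([0, 1, 1, 0]),                       # raw actions, length R
    act_normalized=np.array([0, 1, 1, 0]),            # actions mapped into the env's action space
    policy_entry=Batch(),                             # extra per-step info stored under "policy"
    dist=Categorical(probs=torch.full((R, 2), 0.5)),  # batched distribution of batch size R
    hidden_state=None,                                 # e.g. RNN state for recurrent policies
)
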
class CollectStepBatchProtocol(*args, **kwargs)[source]#

Bases: RolloutBatchProtocol

A batch of steps collected from a single collect step from multiple envs in parallel.

All fields have length R (the dist is a Distribution of batch size R), where R is the number of ready envs. This is essentially the response of the vectorized environment to making a step with CollectActionBatchProtocol.

dist: Distribution | None#
class EpisodeBatchProtocol(*args, **kwargs)[source]#

Bases: RolloutBatchProtocol

Marker interface for a batch containing a single episode.

Instances are created by retrieving an episode from the buffer when the Collector encounters done=True.

get_stddev_from_dist(dist: Distribution) Tensor[source]#

Return the standard deviation of the given distribution.

Same as dist.stddev for all distributions except Categorical, where it is computed by assuming that the output values 0, …, K have the corresponding numerical meaning. See here for a discussion on stddev and mean of Categorical.
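
As a hedged illustration of the Categorical special case, the helper below (not the library's implementation) treats the categories 0, …, K-1 as numerical values and computes the standard deviation of that discrete random variable:

import torch
from torch.distributions import Categorical, Normal

def categorical_stddev(dist: Categorical) -> torch.Tensor:
    # interpret category indices 0, ..., K-1 as numerical outcomes
    values = torch.arange(dist.probs.shape[-1], dtype=dist.probs.dtype)
    mean = (dist.probs * values).sum(-1)
    var = (dist.probs * (values - mean.unsqueeze(-1)) ** 2).sum(-1)
    return var.sqrt()

print(Normal(0.0, 1.0).stddev)  # non-categorical distributions expose stddev directly
print(categorical_stddev(Categorical(probs=torch.tensor([0.25, 0.5, 0.25]))))  # ~0.7071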

class CollectStatsBase(*, n_collected_episodes: int = 0, n_collected_steps: int = 0)[source]#

Bases: DataclassPPrintMixin

The most basic stats, often used for offline learning.

n_collected_episodes: int = 0#

The number of collected episodes.

n_collected_steps: int = 0#

The number of collected steps.

class CollectStats(*, n_collected_episodes: int = 0, n_collected_steps: int = 0, collect_time: float = 0.0, collect_speed: float = 0.0, returns: ~numpy.ndarray = <factory>, returns_stat: ~tianshou.data.stats.SequenceSummaryStats | None = None, lens: ~numpy.ndarray = <factory>, lens_stat: ~tianshou.data.stats.SequenceSummaryStats | None = None, pred_dist_std_array: ~numpy.ndarray | None = None, pred_dist_std_array_stat: dict[int, ~tianshou.data.stats.SequenceSummaryStats] | None = None)[source]#

Bases: CollectStatsBase

A data structure for storing the statistics of rollouts.

Custom stats collection logic can be implemented by subclassing this class and overriding the update_ methods.

Ideally, it is instantiated once with correct values and then never modified. However, during the collection process, instances are modified via the update_ methods. The arrays and their corresponding _stat fields may then become out of sync (for performance reasons, the stats are not refreshed after each update, only at the end of the collection). The same holds for collect_time and collect_speed. In the Collector, refresh_sequence_stats() and set_collect_time() are called at the end of the collection to update the stats. For other use cases, users should remember to call these methods manually after using the update_ methods.

collect_time: float = 0.0#

The time for collecting transitions.

collect_speed: float = 0.0#

The speed of collecting (env_step per second).

returns: ndarray#

The collected episode returns.

returns_stat: SequenceSummaryStats | None = None#

Stats of the collected returns.

lens: ndarray#

The collected episode lengths.

lens_stat: SequenceSummaryStats | None = None#

Stats of the collected episode lengths.

pred_dist_std_array: ndarray | None = None#

The standard deviations of the predicted distributions.

pred_dist_std_array_stat: dict[int, SequenceSummaryStats] | None = None#

Stats of the standard deviations of the predicted distributions (maps action dim to stats)

classmethod with_autogenerated_stats(returns: ndarray, lens: ndarray, n_collected_episodes: int = 0, n_collected_steps: int = 0, collect_time: float = 0.0, collect_speed: float = 0.0) Self[source]#

Return a new instance with the stats autogenerated from the given lists.

update_at_step_batch(step_batch: CollectStepBatchProtocol, refresh_sequence_stats: bool = False) None[source]#
update_at_episode_done(episode_batch: EpisodeBatchProtocol, episode_return: float, refresh_sequence_stats: bool = False) None[source]#
set_collect_time(collect_time: float, update_collect_speed: bool = True) None[source]#
refresh_return_stats() None[source]#
refresh_len_stats() None[source]#
refresh_std_array_stats() None[source]#
refresh_all_sequence_stats() None[source]#
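
A minimal sketch of building CollectStats outside the Collector (e.g. in tests or custom evaluation code); the returns, lengths, and timings below are made up, and the example assumes SequenceSummaryStats exposes a mean field:

import numpy as np
from tianshou.data import CollectStats

stats = CollectStats.with_autogenerated_stats(
    returns=np.array([100.0, 95.0, 120.0]),
    lens=np.array([200, 180, 230]),
    n_collected_episodes=3,
    n_collected_steps=610,
    collect_time=1.5,
    collect_speed=610 / 1.5,
)
print(stats.returns_stat.mean if stats.returns_stat is not None else None)

# When updating incrementally via the update_* methods, the *_stat fields lag
# behind until the refresh methods are called:
stats.refresh_all_sequence_stats()
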
class BaseCollector(policy: ~tianshou.policy.base.BasePolicy, env: ~tianshou.env.venvs.BaseVectorEnv | ~gymnasium.core.Env, buffer: ~tianshou.data.buffer.base.ReplayBuffer | None = None, exploration_noise: bool = False, collect_stats_class: type[~tianshou.data.collector.TCollectStats] = <class 'tianshou.data.collector.CollectStats'>, raise_on_nan_in_buffer: bool = True)[source]#

Bases: Generic[TCollectStats], ABC

Used to collect data from a vector environment into a buffer using a given policy.

Note

Please make sure the given environment has a time limit if using the n_episode collect option.

Note

In past versions of Tianshou, the replay buffer passed to __init__ was automatically reset. This is not done in the current implementation.

property env_num: int#
property action_space: Space#
close() None[source]#

Close the collector and the environment.

reset(reset_buffer: bool = True, reset_stats: bool = True, gym_reset_kwargs: dict[str, Any] | None = None) tuple[ndarray, ndarray][source]#

Reset the environment, statistics, and data needed to start the collection.

Parameters:
  • reset_buffer – if true, reset the replay buffer attached to the collector.

  • reset_stats – if true, reset the statistics attached to the collector.

  • gym_reset_kwargs – extra keyword arguments to pass into the environment’s reset function. Defaults to None (no extra keyword arguments are passed).

Returns:

The initial observation and info from the environment.

reset_stat() None[source]#

Reset the statistic variables.

reset_buffer(keep_statistics: bool = False) None[source]#

Reset the data buffer.

reset_env(gym_reset_kwargs: dict[str, Any] | None = None) tuple[ndarray, ndarray][source]#

Reset the environments and the initial obs, info, and hidden state of the collector.

collect(n_step: int | None = None, n_episode: int | None = None, random: bool = False, render: float | None = None, reset_before_collect: bool = False, gym_reset_kwargs: dict[str, Any] | None = None) TCollectStats[source]#

Collect the specified number of steps or episodes to the buffer.

Note

One and only one collection specification is permitted, either n_step or n_episode.

To ensure an unbiased sampling result with the n_episode option, this function first collects n_episode - env_num episodes; the remaining env_num episodes are then collected evenly, one from each env.

Parameters:
  • n_step – how many steps to collect.

  • n_episode – how many episodes to collect.

  • random – whether to sample randomly from the action space instead of using the policy for collecting data.

  • render – the sleep time between rendering consecutive frames.

  • reset_before_collect – whether to reset the environment before collecting data. (The collector needs the initial obs and info to function properly.)

  • gym_reset_kwargs – extra keyword arguments to pass into the environment’s reset function. Only used if reset_before_collect is True.

Returns:

The collected stats
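
A hedged usage sketch, assuming `collector` is an already constructed collector over time-limited envs; the step and episode counts are arbitrary:

# exactly one of n_step / n_episode may be given per call
stats = collector.collect(n_step=1024, reset_before_collect=True)
print(stats.n_collected_steps, stats.n_collected_episodes)

# episode-based collection, balanced across envs as described above
stats = collector.collect(n_episode=10)
print(stats.returns_stat.mean if stats.returns_stat is not None else None)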

class Collector(policy: ~tianshou.policy.base.BasePolicy, env: ~gymnasium.core.Env | ~tianshou.env.venvs.BaseVectorEnv, buffer: ~tianshou.data.buffer.base.ReplayBuffer | None = None, exploration_noise: bool = False, on_episode_done_hook: ~tianshou.data.collector.EpisodeRolloutHookProtocol | None = None, on_step_hook: ~tianshou.data.collector.StepHookProtocol | None = None, raise_on_nan_in_buffer: bool = True, collect_stats_class: type[~tianshou.data.collector.TCollectStats] = <class 'tianshou.data.collector.CollectStats'>)[source]#

Bases: BaseCollector[TCollectStats], Generic[TCollectStats]

Collects transitions from a vectorized env by computing and applying actions batch-wise.

Parameters:
  • policy – a tianshou policy, each BasePolicy is capable of computing a batch of actions from a batch of observations.

  • env – a gymnasium.Env environment or a vectorized instance of the BaseVectorEnv class. The latter is strongly recommended, as with a gymnasium env the collection will not happen in parallel (a DummyVectorEnv will be constructed internally from the passed env)

  • buffer – an instance of the ReplayBuffer class. If set to None, will instantiate a VectorReplayBuffer of size DEFAULT_BUFFER_MAXSIZE * (number of envs) as the default buffer.

  • exploration_noise – whether the action should be modified with the corresponding policy’s exploration noise. If True, policy.exploration_noise(act, batch) will be called automatically to add the exploration noise to the action.

  • on_episode_done_hook

    if passed, will be executed when an episode is done. The input to the hook will be a RolloutBatch that contains the entire episode (and nothing else). If a dict is returned by the hook, it will be used to add new entries to the buffer for the episode that just ended. The values of the dict should be arrays of floats with the same length as the input rollout batch. Note that multiple hooks can be combined using EpisodeRolloutHookMerged. A typical example of a hook is EpisodeRolloutHookMCReturn, which adds the Monte Carlo return as a field to the buffer.

    Care must be taken when using such a hook, as for unfinished episodes one can easily end up with NaNs in the buffer. It is recommended to use such hooks only with the n_episode option in collect, or to strip the buffer of NaNs after the collection.

  • on_step_hook – if passed, will be executed after each step of the collection but before the resulting rollout batch is added to the buffer. The inputs to the hook will be the action batch computed from the previous observations using the policy (following the CollectActionBatchProtocol) and the resulting rollout batch (following the RolloutBatchProtocol). Note that modifying the rollout batch with this hook also modifies the data that is collected to the buffer!

  • raise_on_nan_in_buffer – whether to raise a RuntimeError if NaNs are found in the buffer after a collection step. Especially useful when episode-level hooks are passed for making sure that nothing is broken during the collection. Consider setting to False if the NaN-check becomes a bottleneck.

  • collect_stats_class – the class to use for collecting statistics. Allows customizing the stats collection logic by passing a subclass of CollectStats. Changing this is rarely necessary and is mainly done by “power users”.
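
A minimal construction sketch, assuming `policy` is an existing BasePolicy instance (e.g. from one of the Tianshou examples); the env id, buffer size, and hyperparameters are arbitrary:

import gymnasium as gym
from tianshou.data import Collector, VectorReplayBuffer
from tianshou.data.collector import EpisodeRolloutHookMCReturn
from tianshou.env import DummyVectorEnv

num_envs = 4
envs = DummyVectorEnv([lambda: gym.make("CartPole-v1") for _ in range(num_envs)])
buffer = VectorReplayBuffer(total_size=20_000, buffer_num=num_envs)

collector = Collector(
    policy,
    envs,
    buffer=buffer,
    exploration_noise=True,
    on_episode_done_hook=EpisodeRolloutHookMCReturn(gamma=0.99),  # adds MC returns to the buffer
)
collector.reset()  # the buffer is no longer reset automatically on init
stats = collector.collect(n_episode=num_envs)
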

set_on_episode_done_hook(hook: EpisodeRolloutHookProtocol | None) None[source]#
set_on_step_hook(hook: StepHookProtocol | None) None[source]#
get_on_episode_done_hook() EpisodeRolloutHookProtocol | None[source]#
get_on_step_hook() StepHookProtocol | None[source]#
close() None[source]#

Close the collector and the environment.

run_on_episode_done(episode_batch: EpisodeBatchProtocol) dict[str, ndarray] | None[source]#

Executes the on_episode_done_hook that was passed on init.

One of the main uses of this public method is to allow users to override it in custom subclasses of Collector. This way, they can override the init to no longer accept the on_episode_done_hook provider.

run_on_step_hook(action_batch: CollectActionBatchProtocol, rollout_batch: RolloutBatchProtocol) None[source]#

Executes the instance’s on_step_hook.

One of the main uses of this public method is to allow users to override it in custom subclasses of the Collector. This way, they can override the init to no longer accept the on_step_hook provider.

reset_env(gym_reset_kwargs: dict[str, Any] | None = None) tuple[ndarray, ndarray][source]#

Reset the environments and the initial obs, info, and hidden state of the collector.

class AsyncCollector(policy: BasePolicy, env: BaseVectorEnv, buffer: ReplayBuffer | None = None, exploration_noise: bool = False, raise_on_nan_in_buffer: bool = True)[source]#

Bases: Collector[CollectStats]

An asynchronous collector that handles async vector environments.

Please refer to Collector for a more detailed explanation.

Parameters:
  • policy – a tianshou policy, each BasePolicy is capable of computing a batch of actions from a batch of observations.

  • env – a gymnasium.Env environment or a vectorized instance of the BaseVectorEnv class. The latter is strongly recommended, as with a gymnasium env the collection will not happen in parallel (a DummyVectorEnv will be constructed internally from the passed env)

  • buffer – an instance of the ReplayBuffer class. If set to None, will instantiate a VectorReplayBuffer of size DEFAULT_BUFFER_MAXSIZE * (number of envs) as the default buffer.

  • exploration_noise – whether the action should be modified with the corresponding policy’s exploration noise. If True, policy.exploration_noise(act, batch) will be called automatically to add the exploration noise to the action.

  • on_episode_done_hook

    if passed, will be executed when an episode is done. The input to the hook will be a RolloutBatch that contains the entire episode (and nothing else). If a dict is returned by the hook, it will be used to add new entries to the buffer for the episode that just ended. The values of the dict should be arrays of floats with the same length as the input rollout batch. Note that multiple hooks can be combined using EpisodeRolloutHookMerged. A typical example of a hook is EpisodeRolloutHookMCReturn, which adds the Monte Carlo return as a field to the buffer.

    Care must be taken when using such a hook, as for unfinished episodes one can easily end up with NaNs in the buffer. It is recommended to use such hooks only with the n_episode option in collect, or to strip the buffer of NaNs after the collection.

  • on_step_hook – if passed, will be executed after each step of the collection but before the resulting rollout batch is added to the buffer. The inputs to the hook will be the action batch computed from the previous observations using the policy (following the CollectActionBatchProtocol) and the resulting rollout batch (following the RolloutBatchProtocol). Note that modifying the rollout batch with this hook also modifies the data that is collected to the buffer!

  • raise_on_nan_in_buffer – whether to raise a RuntimeError if NaNs are found in the buffer after a collection step. Especially useful when episode-level hooks are passed for making sure that nothing is broken during the collection. Consider setting to False if the NaN-check becomes a bottleneck.

  • collect_stats_class – the class to use for collecting statistics. Allows customizing the stats collection logic by passing a subclass of CollectStats. Changing this is rarely necessary and is mainly done by “power users”.
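
A hedged sketch of asynchronous collection, again assuming an existing `policy`; the vectorized env is created with wait_num/timeout so that a step returns as soon as a subset of envs is ready (the values below are illustrative only):

import gymnasium as gym
from tianshou.data import AsyncCollector, VectorReplayBuffer
from tianshou.env import SubprocVectorEnv

num_envs = 8
envs = SubprocVectorEnv(
    [lambda: gym.make("CartPole-v1") for _ in range(num_envs)],
    wait_num=4,   # return as soon as 4 envs have finished their step
    timeout=0.2,  # ...or after 0.2 s, whichever comes first
)
buffer = VectorReplayBuffer(total_size=40_000, buffer_num=num_envs)

collector = AsyncCollector(policy, envs, buffer=buffer)
collector.reset()
stats = collector.collect(n_step=2048)
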

reset(reset_buffer: bool = True, reset_stats: bool = True, gym_reset_kwargs: dict[str, Any] | None = None) tuple[ndarray, ndarray][source]#

Reset the environment, statistics, and data needed to start the collection.

Parameters:
  • reset_buffer – if true, reset the replay buffer attached to the collector.

  • reset_stats – if true, reset the statistics attached to the collector.

  • gym_reset_kwargs – extra keyword arguments to pass into the environment’s reset function. Defaults to None (no extra keyword arguments are passed).

Returns:

The initial observation and info from the environment.

reset_env(gym_reset_kwargs: dict[str, Any] | None = None) tuple[ndarray, ndarray][source]#

Reset the environments and the initial obs, info, and hidden state of the collector.

class StepHookProtocol(*args, **kwargs)[source]#

Bases: Protocol

A protocol for step hooks.

class StepHook(*args, **kwargs)[source]#

Bases: StepHookProtocol, ABC

Marker interface for step hooks.

All step hooks in Tianshou will inherit from it, but only the corresponding protocol will be used in type hints. This makes it possible to discover all hooks in the codebase by looking up the hierarchy of this class (or the protocol itself) while still allowing the user to pass something like a lambda function as a hook.

class StepHookAddActionDistribution(*args, **kwargs)[source]#

Bases: StepHook

Adds the action distribution to the collected rollout batch under the field “action_dist”.

The field is also accessible via the class variable ACTION_DIST_KEY. This hook can be useful for algorithms that need the previously taken actions for training, like variants of imitation learning or DAGGER.

ACTION_DIST_KEY = 'action_dist'#
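
A sketch of a custom step hook satisfying StepHookProtocol: it receives the action batch produced by the policy and the rollout batch returned by the envs, and may modify the rollout batch in place (which also changes what ends up in the buffer). The field name "act_std" is invented for this illustration:

from tianshou.data.collector import StepHookAddActionDistribution, get_stddev_from_dist

def log_action_std_hook(action_batch, rollout_batch) -> None:
    # record the per-action stddev of the policy's predicted distribution
    if action_batch.dist is not None:
        rollout_batch.act_std = get_stddev_from_dist(action_batch.dist).cpu().numpy()

# Either a plain callable or a provided hook class can be passed to the Collector:
# Collector(policy, envs, buffer, on_step_hook=log_action_std_hook)
# Collector(policy, envs, buffer, on_step_hook=StepHookAddActionDistribution())
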
class EpisodeRolloutHookProtocol(*args, **kwargs)[source]#

Bases: Protocol

A protocol for hooks (functions) that act on an entire collected episode.

Can be used to add values to the buffer that are only known after the episode is finished. A prime example is something like the MC return to go.

class EpisodeRolloutHook(*args, **kwargs)[source]#

Bases: EpisodeRolloutHookProtocol, ABC

Marker interface for episode hooks.

All episode hooks in Tianshou will inherit from it, but only the corresponding protocol will be used in type hints. This makes it possible to discover all hooks in the codebase by looking up the hierarchy of this class (or the protocol itself) while still allowing the user to pass something like a lambda function as a hook.

class EpisodeRolloutHookMCReturn(gamma: float = 0.99)[source]#

Bases: EpisodeRolloutHook

Adds the MC return to go as well as the full episode MC return to the transitions in the buffer.

The latter will be constant for all transitions in the same episode and simply corresponds to the initial MC return to go. Useful for algorithms that rely on Monte Carlo returns during training.

MC_RETURN_TO_GO_KEY = 'mc_return_to_go'#
FULL_EPISODE_MC_RETURN_KEY = 'full_episode_mc_return'#
class OutputDict[source]#

Bases: TypedDict

mc_return_to_go: ndarray#
full_episode_mc_return: ndarray#
class EpisodeRolloutHookMerged(*episode_rollout_hooks: EpisodeRolloutHookProtocol, check_overlapping_keys: bool = True)[source]#

Bases: EpisodeRolloutHook

Combines multiple episode hooks into a single one.

If all hooks return None, this hook will also return None.

Parameters:
  • episode_rollout_hooks – the hooks to combine

  • check_overlapping_keys – whether to check for overlapping keys in the output of the hooks and raise a KeyError if any are found. Set to False to disable this check (can be useful if this becomes a performance bottleneck).
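
A sketch of combining episode hooks: the built-in MC-return hook plus a custom hook that stores the episode length for every transition. The key "episode_len" is invented for this illustration, and a plain function is used as one of the hooks (allowed because only the protocol is required):

import numpy as np
from tianshou.data.collector import (
    EpisodeBatchProtocol,
    EpisodeRolloutHookMCReturn,
    EpisodeRolloutHookMerged,
)

def episode_len_hook(episode_batch: EpisodeBatchProtocol) -> dict[str, np.ndarray]:
    # every transition of the finished episode gets the same length value
    n = len(episode_batch)
    return {"episode_len": np.full(n, n, dtype=np.float32)}

merged_hook = EpisodeRolloutHookMerged(
    EpisodeRolloutHookMCReturn(gamma=0.99),
    episode_len_hook,
    check_overlapping_keys=True,  # raise if two hooks produce the same key
)
# Collector(policy, envs, buffer, on_episode_done_hook=merged_hook)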