import time
import warnings
from collections.abc import Callable
from dataclasses import dataclass
from typing import Any, cast
import gymnasium as gym
import numpy as np
import torch
from tianshou.data import (
Batch,
CachedReplayBuffer,
PrioritizedReplayBuffer,
ReplayBuffer,
ReplayBufferManager,
SequenceSummaryStats,
VectorReplayBuffer,
to_numpy,
)
from tianshou.data.batch import alloc_by_keys_diff
from tianshou.data.types import RolloutBatchProtocol
from tianshou.env import BaseVectorEnv, DummyVectorEnv
from tianshou.policy import BasePolicy
from tianshou.utils.print import DataclassPPrintMixin
@dataclass(kw_only=True)
class CollectStatsBase(DataclassPPrintMixin):
"""The most basic stats, often used for offline learning."""
n_collected_episodes: int = 0
"""The number of collected episodes."""
n_collected_steps: int = 0
"""The number of collected steps."""
@dataclass(kw_only=True)
class CollectStats(CollectStatsBase):
"""A data structure for storing the statistics of rollouts."""
collect_time: float = 0.0
"""The time for collecting transitions."""
collect_speed: float = 0.0
"""The speed of collecting (env_step per second)."""
returns: np.ndarray
"""The collected episode returns."""
returns_stat: SequenceSummaryStats | None # can be None if no episode ends during collect step
"""Stats of the collected returns."""
lens: np.ndarray
"""The collected episode lengths."""
lens_stat: SequenceSummaryStats | None # can be None if no episode ends during collect step
"""Stats of the collected episode lengths."""
class Collector:
"""Collector enables the policy to interact with different types of envs with exact number of steps or episodes.
:param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class.
:param env: a ``gym.Env`` environment or an instance of the
:class:`~tianshou.env.BaseVectorEnv` class.
:param buffer: an instance of the :class:`~tianshou.data.ReplayBuffer` class.
If set to None, it will not store the data. Default to None.
:param function preprocess_fn: a function called before the data has been added to
the buffer, see issue #42 and :ref:`preprocess_fn`. Default to None.
:param exploration_noise: determine whether the action needs to be modified
with the corresponding policy's exploration noise. If so,
``policy.exploration_noise(act, batch)`` will be called automatically to add
exploration noise to the action. Default to False.
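A minimal usage sketch (the environment, buffer sizes and the ``policy`` variable are
illustrative assumptions, not requirements of this class)::

    import gymnasium as gym
    from tianshou.data import Collector, VectorReplayBuffer
    from tianshou.env import DummyVectorEnv

    envs = DummyVectorEnv([lambda: gym.make("CartPole-v1") for _ in range(4)])
    buffer = VectorReplayBuffer(total_size=20000, buffer_num=4)
    collector = Collector(policy, envs, buffer=buffer)  # `policy` is any BasePolicy
    collector.collect(n_step=1000)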
The "preprocess_fn" is a function called before the data has been added to the
buffer with batch format. It will receive only "obs" and "env_id" when the
collector resets the environment, and will receive the keys "obs_next", "rew",
"terminated", "truncated, "info", "policy" and "env_id" in a normal env step.
Alternatively, it may also accept the keys "obs_next", "rew", "done", "info",
"policy" and "env_id".
It returns either a dict or a :class:`~tianshou.data.Batch` with the modified
keys and values. Examples are in "test/base/test_collector.py".
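A sketch of a conforming "preprocess_fn" (the reward rescaling is purely illustrative)::

    def preprocess_fn(**kwargs):
        # on environment reset only "obs" and "env_id" are passed, so guard the access
        if "rew" in kwargs:
            return Batch(rew=kwargs["rew"] / 10.0)
        return Batch()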
.. note::
Please make sure the given environment has a time limit when using the n_episode
collect option.
.. note::
In past versions of Tianshou, the replay buffer that was passed to `__init__`
was automatically reset. This is not done in the current implementation.
"""
def __init__(
self,
policy: BasePolicy,
env: gym.Env | BaseVectorEnv,
buffer: ReplayBuffer | None = None,
preprocess_fn: Callable[..., RolloutBatchProtocol] | None = None,
exploration_noise: bool = False,
) -> None:
super().__init__()
if isinstance(env, gym.Env) and not hasattr(env, "__len__"):
warnings.warn("Single environment detected, wrap to DummyVectorEnv.")
# Unfortunately, mypy seems to ignore the isinstance in lambda, maybe a bug in mypy
self.env = DummyVectorEnv([lambda: env])
else:
self.env = env # type: ignore
self.env_num = len(self.env)
self.exploration_noise = exploration_noise
self.buffer: ReplayBuffer
self._assign_buffer(buffer)
self.policy = policy
self.preprocess_fn = preprocess_fn
self._action_space = self.env.action_space
self.data: RolloutBatchProtocol
# avoid creating attribute outside __init__
self.reset(False)
def _assign_buffer(self, buffer: ReplayBuffer | None) -> None:
"""Check if the buffer matches the constraint."""
if buffer is None:
buffer = VectorReplayBuffer(self.env_num, self.env_num)
elif isinstance(buffer, ReplayBufferManager):
assert buffer.buffer_num >= self.env_num
if isinstance(buffer, CachedReplayBuffer):
assert buffer.cached_buffer_num >= self.env_num
else: # ReplayBuffer or PrioritizedReplayBuffer
assert buffer.maxsize > 0
if self.env_num > 1:
if isinstance(buffer, ReplayBuffer):
buffer_type = "ReplayBuffer"
vector_type = "VectorReplayBuffer"
if isinstance(buffer, PrioritizedReplayBuffer):
buffer_type = "PrioritizedReplayBuffer"
vector_type = "PrioritizedVectorReplayBuffer"
raise TypeError(
f"Cannot use {buffer_type}(size={buffer.maxsize}, ...) to collect "
f"{self.env_num} envs,\n\tplease use {vector_type}(total_size="
f"{buffer.maxsize}, buffer_num={self.env_num}, ...) instead.",
)
self.buffer = buffer
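# Illustrative only: the prioritized counterpart suggested by the error message above
# would be built roughly like this (alpha/beta values are placeholders):
#
#     from tianshou.data import PrioritizedVectorReplayBuffer
#     buffer = PrioritizedVectorReplayBuffer(
#         total_size=20000, buffer_num=env_num, alpha=0.6, beta=0.4
#     )
#     collector = Collector(policy, envs, buffer=buffer)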
def reset(
self,
reset_buffer: bool = True,
gym_reset_kwargs: dict[str, Any] | None = None,
) -> None:
"""Reset the environment, statistics, current data and possibly replay memory.
:param reset_buffer: if true, reset the replay buffer that is attached
to the collector.
:param gym_reset_kwargs: extra keyword arguments to pass into the environment's
reset function. Defaults to None (no extra keyword arguments).
"""
# use empty Batch for "state" so that self.data supports slicing
# convert empty Batch to None when passing data to policy
data = Batch(
obs={},
act={},
rew={},
terminated={},
truncated={},
done={},
obs_next={},
info={},
policy={},
)
self.data = cast(RolloutBatchProtocol, data)
self.reset_env(gym_reset_kwargs)
if reset_buffer:
self.reset_buffer()
self.reset_stat()
def reset_stat(self) -> None:
"""Reset the statistic variables."""
self.collect_step, self.collect_episode, self.collect_time = 0, 0, 0.0
def reset_buffer(self, keep_statistics: bool = False) -> None:
"""Reset the data buffer."""
self.buffer.reset(keep_statistics=keep_statistics)
def reset_env(self, gym_reset_kwargs: dict[str, Any] | None = None) -> None:
"""Reset all of the environments."""
gym_reset_kwargs = gym_reset_kwargs if gym_reset_kwargs else {}
obs, info = self.env.reset(**gym_reset_kwargs)
if self.preprocess_fn:
processed_data = self.preprocess_fn(obs=obs, info=info, env_id=np.arange(self.env_num))
obs = processed_data.get("obs", obs)
info = processed_data.get("info", info)
self.data.info = info # type: ignore
self.data.obs = obs
def _reset_state(self, id: int | list[int]) -> None:
"""Reset the hidden state: self.data.state[id]."""
if hasattr(self.data.policy, "hidden_state"):
state = self.data.policy.hidden_state # it is a reference
if isinstance(state, torch.Tensor):
state[id].zero_()
elif isinstance(state, np.ndarray):
state[id] = None if state.dtype == object else 0
elif isinstance(state, Batch):
state.empty_(id)
def _reset_env_with_ids(
self,
local_ids: list[int] | np.ndarray,
global_ids: list[int] | np.ndarray,
gym_reset_kwargs: dict[str, Any] | None = None,
) -> None:
gym_reset_kwargs = gym_reset_kwargs if gym_reset_kwargs else {}
obs_reset, info = self.env.reset(global_ids, **gym_reset_kwargs)
if self.preprocess_fn:
processed_data = self.preprocess_fn(obs=obs_reset, info=info, env_id=global_ids)
obs_reset = processed_data.get("obs", obs_reset)
info = processed_data.get("info", info)
self.data.info[local_ids] = info # type: ignore
self.data.obs_next[local_ids] = obs_reset # type: ignore
def collect(
self,
n_step: int | None = None,
n_episode: int | None = None,
random: bool = False,
render: float | None = None,
no_grad: bool = True,
gym_reset_kwargs: dict[str, Any] | None = None,
) -> CollectStats:
"""Collect a specified number of step or episode.
To ensure unbiased sampling result with n_episode option, this function will
first collect ``n_episode - env_num`` episodes, then for the last ``env_num``
episodes, they will be collected evenly from each env.
:param n_step: how many steps you want to collect.
:param n_episode: how many episodes you want to collect.
:param random: whether to use a random policy to collect data. Default
to False.
:param render: the sleep time between rendering consecutive frames.
Default to None (no rendering).
:param no_grad: whether to run policy.forward() without gradient computation
(under ``torch.no_grad()``). Default to True (no gradient retained).
:param gym_reset_kwargs: extra keyword arguments to pass into the environment's
reset function. Defaults to None (no extra keyword arguments).
.. note::
One and only one collection number specification is permitted, either
``n_step`` or ``n_episode``.
:return: A dataclass object containing the collection statistics.
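For illustration (numbers are arbitrary), with a ``collector`` built on 4 envs::

    collector.collect(n_step=1200)   # at least 1200 transitions (exactly 1200 here, a multiple of 4)
    collector.collect(n_episode=10)  # exactly 10 episodes, sampled without env bias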
"""
assert not self.env.is_async, "Please use AsyncCollector if using async venv."
if n_step is not None:
assert n_episode is None, (
f"Only one of n_step or n_episode is allowed in Collector."
f"collect, got n_step={n_step}, n_episode={n_episode}."
)
assert n_step > 0
if n_step % self.env_num != 0:
warnings.warn(
f"n_step={n_step} is not a multiple of #env ({self.env_num}), "
"which may cause extra transitions collected into the buffer.",
)
ready_env_ids = np.arange(self.env_num)
elif n_episode is not None:
assert n_episode > 0
ready_env_ids = np.arange(min(self.env_num, n_episode))
self.data = self.data[: min(self.env_num, n_episode)]
else:
raise TypeError(
"Please specify at least one (either n_step or n_episode) "
"in AsyncCollector.collect().",
)
start_time = time.time()
step_count = 0
episode_count = 0
episode_returns: list[float] = []
episode_lens: list[int] = []
episode_start_indices: list[int] = []
while True:
assert len(self.data) == len(ready_env_ids)
# restore the state: if the last state is None, it won't be stored in the buffer
last_state = self.data.policy.pop("hidden_state", None)
# get the next action
if random:
try:
act_sample = [self._action_space[i].sample() for i in ready_env_ids]
except TypeError: # envpool's action space is not a per-env list
act_sample = [self._action_space.sample() for _ in ready_env_ids]
act_sample = self.policy.map_action_inverse(act_sample) # type: ignore
self.data.update(act=act_sample)
else:
if no_grad:
with torch.no_grad(): # faster than retain_grad version
# self.data.obs will be used by agent to get result
result = self.policy(self.data, last_state)
else:
result = self.policy(self.data, last_state)
# update state / act / policy into self.data
policy = result.get("policy", Batch())
assert isinstance(policy, Batch)
state = result.get("state", None)
if state is not None:
policy.hidden_state = state # save state into buffer
act = to_numpy(result.act)
if self.exploration_noise:
act = self.policy.exploration_noise(act, self.data)
self.data.update(policy=policy, act=act)
# get bounded and remapped actions first (not saved into buffer)
action_remap = self.policy.map_action(self.data.act)
# step in env
obs_next, rew, terminated, truncated, info = self.env.step(
action_remap,
ready_env_ids,
)
done = np.logical_or(terminated, truncated)
self.data.update(
obs_next=obs_next,
rew=rew,
terminated=terminated,
truncated=truncated,
done=done,
info=info,
)
if self.preprocess_fn:
self.data.update(
self.preprocess_fn(
obs_next=self.data.obs_next,
rew=self.data.rew,
done=self.data.done,
info=self.data.info,
policy=self.data.policy,
env_id=ready_env_ids,
act=self.data.act,
),
)
if render:
self.env.render()
if render > 0 and not np.isclose(render, 0):
time.sleep(render)
# add data into the buffer
ptr, ep_rew, ep_len, ep_idx = self.buffer.add(self.data, buffer_ids=ready_env_ids)
# collect statistics
step_count += len(ready_env_ids)
if np.any(done):
env_ind_local = np.where(done)[0]
env_ind_global = ready_env_ids[env_ind_local]
episode_count += len(env_ind_local)
episode_lens.extend(ep_len[env_ind_local])
episode_returns.extend(ep_rew[env_ind_local])
episode_start_indices.extend(ep_idx[env_ind_local])
# now we copy obs_next to obs, but since there might be
# finished episodes, we have to reset finished envs first.
self._reset_env_with_ids(env_ind_local, env_ind_global, gym_reset_kwargs)
for i in env_ind_local:
self._reset_state(i)
# remove surplus env id from ready_env_ids
# to avoid bias in selecting environments
if n_episode:
surplus_env_num = len(ready_env_ids) - (n_episode - episode_count)
if surplus_env_num > 0:
mask = np.ones_like(ready_env_ids, dtype=bool)
mask[env_ind_local[:surplus_env_num]] = False
ready_env_ids = ready_env_ids[mask]
self.data = self.data[mask]
self.data.obs = self.data.obs_next
if (n_step and step_count >= n_step) or (n_episode and episode_count >= n_episode):
break
# generate statistics
self.collect_step += step_count
self.collect_episode += episode_count
collect_time = max(time.time() - start_time, 1e-9)
self.collect_time += collect_time
if n_episode:
data = Batch(
obs={},
act={},
rew={},
terminated={},
truncated={},
done={},
obs_next={},
info={},
policy={},
)
self.data = cast(RolloutBatchProtocol, data)
self.reset_env()
return CollectStats(
n_collected_episodes=episode_count,
n_collected_steps=step_count,
collect_time=collect_time,
collect_speed=step_count / collect_time,
returns=np.array(episode_returns),
returns_stat=SequenceSummaryStats.from_sequence(episode_returns)
if len(episode_returns) > 0
else None,
lens=np.array(episode_lens, int),
lens_stat=SequenceSummaryStats.from_sequence(episode_lens)
if len(episode_lens) > 0
else None,
)
class AsyncCollector(Collector):
"""Async Collector handles async vector environment.
The arguments are exactly the same as :class:`~tianshou.data.Collector`, please
refer to :class:`~tianshou.data.Collector` for more detailed explanation.
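A minimal sketch (the ``make_env`` factory and the step/env counts are illustrative
assumptions)::

    from tianshou.env import SubprocVectorEnv

    envs = SubprocVectorEnv([make_env for _ in range(8)], wait_num=4)
    collector = AsyncCollector(policy, envs, VectorReplayBuffer(40000, 8))
    collector.collect(n_step=1000)  # may collect somewhat more than 1000 steps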
"""
def __init__(
self,
policy: BasePolicy,
env: BaseVectorEnv,
buffer: ReplayBuffer | None = None,
preprocess_fn: Callable[..., RolloutBatchProtocol] | None = None,
exploration_noise: bool = False,
) -> None:
# assert env.is_async
warnings.warn("Using async setting may collect extra transitions into buffer.")
super().__init__(
policy,
env,
buffer,
preprocess_fn,
exploration_noise,
)
def reset_env(self, gym_reset_kwargs: dict[str, Any] | None = None) -> None:
super().reset_env(gym_reset_kwargs)
self._ready_env_ids = np.arange(self.env_num)
def collect(
self,
n_step: int | None = None,
n_episode: int | None = None,
random: bool = False,
render: float | None = None,
no_grad: bool = True,
gym_reset_kwargs: dict[str, Any] | None = None,
) -> CollectStats:
"""Collect a specified number of step or episode with async env setting.
This function doesn't collect exactly n_step or n_episode number of
transitions. Instead, in order to support async setting, it may collect more
than given n_step or n_episode transitions and save into buffer.
:param n_step: how many steps you want to collect.
:param n_episode: how many episodes you want to collect.
:param random: whether to use a random policy to collect data. Default
to False.
:param render: the sleep time between rendering consecutive frames.
Default to None (no rendering).
:param no_grad: whether to run policy.forward() without gradient computation
(under ``torch.no_grad()``). Default to True (no gradient retained).
:param gym_reset_kwargs: extra keyword arguments to pass into the environment's
reset function. Defaults to None (no extra keyword arguments).
.. note::
One and only one collection number specification is permitted, either
``n_step`` or ``n_episode``.
:return: A dataclass object containing the collection statistics.
"""
# collect at least n_step or n_episode
if n_step is not None:
assert n_episode is None, (
"Only one of n_step or n_episode is allowed in Collector."
f"collect, got n_step={n_step}, n_episode={n_episode}."
)
assert n_step > 0
elif n_episode is not None:
assert n_episode > 0
else:
raise TypeError(
"Please specify at least one (either n_step or n_episode) "
"in AsyncCollector.collect().",
)
ready_env_ids = self._ready_env_ids
start_time = time.time()
step_count = 0
episode_count = 0
episode_returns: list[float] = []
episode_lens: list[int] = []
episode_start_indices: list[int] = []
while True:
whole_data = self.data
self.data = self.data[ready_env_ids]
assert len(whole_data) == self.env_num # major difference
# restore the state: if the last state is None, it won't be stored in the buffer
last_state = self.data.policy.pop("hidden_state", None)
# get the next action
if random:
try:
act_sample = [self._action_space[i].sample() for i in ready_env_ids]
except TypeError: # envpool's action space is not a per-env list
act_sample = [self._action_space.sample() for _ in ready_env_ids]
act_sample = self.policy.map_action_inverse(act_sample) # type: ignore
self.data.update(act=act_sample)
else:
if no_grad:
with torch.no_grad(): # faster than retain_grad version
# self.data.obs will be used by agent to get result
result = self.policy(self.data, last_state)
else:
result = self.policy(self.data, last_state)
# update state / act / policy into self.data
policy = result.get("policy", Batch())
assert isinstance(policy, Batch)
state = result.get("state", None)
if state is not None:
policy.hidden_state = state # save state into buffer
act = to_numpy(result.act)
if self.exploration_noise:
act = self.policy.exploration_noise(act, self.data)
self.data.update(policy=policy, act=act)
# save act/policy before env.step
try:
whole_data.act[ready_env_ids] = self.data.act # type: ignore
whole_data.policy[ready_env_ids] = self.data.policy
except ValueError:
alloc_by_keys_diff(whole_data, self.data, self.env_num, False)
whole_data[ready_env_ids] = self.data # lots of overhead
# get bounded and remapped actions first (not saved into buffer)
action_remap = self.policy.map_action(self.data.act)
# step in env
obs_next, rew, terminated, truncated, info = self.env.step(
action_remap,
ready_env_ids,
)
done = np.logical_or(terminated, truncated)
# change self.data here because ready_env_ids has changed
try:
ready_env_ids = info["env_id"]
except Exception:
ready_env_ids = np.array([i["env_id"] for i in info])
self.data = whole_data[ready_env_ids]
self.data.update(
obs_next=obs_next,
rew=rew,
terminated=terminated,
truncated=truncated,
info=info,
)
if self.preprocess_fn:
try:
self.data.update(
self.preprocess_fn(
obs_next=self.data.obs_next,
rew=self.data.rew,
terminated=self.data.terminated,
truncated=self.data.truncated,
info=self.data.info,
env_id=ready_env_ids,
act=self.data.act,
),
)
except TypeError:
self.data.update(
self.preprocess_fn(
obs_next=self.data.obs_next,
rew=self.data.rew,
done=self.data.done,
info=self.data.info,
env_id=ready_env_ids,
act=self.data.act,
),
)
if render:
self.env.render()
if render > 0 and not np.isclose(render, 0):
time.sleep(render)
# add data into the buffer
ptr, ep_rew, ep_len, ep_idx = self.buffer.add(self.data, buffer_ids=ready_env_ids)
# collect statistics
step_count += len(ready_env_ids)
if np.any(done):
env_ind_local = np.where(done)[0]
env_ind_global = ready_env_ids[env_ind_local]
episode_count += len(env_ind_local)
episode_lens.extend(ep_len[env_ind_local])
episode_returns.extend(ep_rew[env_ind_local])
episode_start_indices.extend(ep_idx[env_ind_local])
# now we copy obs_next to obs, but since there might be
# finished episodes, we have to reset finished envs first.
self._reset_env_with_ids(env_ind_local, env_ind_global, gym_reset_kwargs)
for i in env_ind_local:
self._reset_state(i)
try:
# Need to ignore types b/c according to mypy Tensors cannot be indexed
# by arrays (which they can...)
whole_data.obs[ready_env_ids] = self.data.obs_next # type: ignore
whole_data.rew[ready_env_ids] = self.data.rew
whole_data.done[ready_env_ids] = self.data.done
whole_data.info[ready_env_ids] = self.data.info # type: ignore
except ValueError:
alloc_by_keys_diff(whole_data, self.data, self.env_num, False)
self.data.obs = self.data.obs_next
# lots of overhead
whole_data[ready_env_ids] = self.data
self.data = whole_data
if (n_step and step_count >= n_step) or (n_episode and episode_count >= n_episode):
break
self._ready_env_ids = ready_env_ids
# generate statistics
self.collect_step += step_count
self.collect_episode += episode_count
collect_time = max(time.time() - start_time, 1e-9)
self.collect_time += collect_time
return CollectStats(
n_collected_episodes=episode_count,
n_collected_steps=step_count,
collect_time=collect_time,
collect_speed=step_count / collect_time,
returns=np.array(episode_returns),
returns_stat=SequenceSummaryStats.from_sequence(episode_returns)
if len(episode_returns) > 0
else None,
lens=np.array(episode_lens, int),
lens_stat=SequenceSummaryStats.from_sequence(episode_lens)
if len(episode_lens) > 0
else None,
)