base


class TrainingStats(*, train_time: float = 0.0, smoothed_loss: dict = <factory>)

Bases: DataclassPPrintMixin

train_time: float = 0.0

The time for learning models.

smoothed_loss: dict

The smoothed loss statistics of the policy learn step.

get_loss_stats_dict() → dict[str, float]

Return loss statistics as a dict for logging.

Returns a dict with all fields except train_time and smoothed_loss. Moreover, fields with value None are excluded, and instances of SequenceSummaryStats are replaced by their mean.
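A minimal sketch of this filtering logic, written against a hypothetical dataclass (MyTrainingStats, with made-up actor_loss and critic_loss fields) and a hypothetical SummaryStatsStub standing in for SequenceSummaryStats. It only illustrates the documented behavior and is not the library's code:

```python
from __future__ import annotations

from dataclasses import dataclass, field, fields


@dataclass
class SummaryStatsStub:
    """Hypothetical stand-in for SequenceSummaryStats; only the mean is used here."""

    mean: float


@dataclass
class MyTrainingStats:
    """Hypothetical stats dataclass: TrainingStats' two fields plus two made-up losses."""

    train_time: float = 0.0
    smoothed_loss: dict = field(default_factory=dict)
    actor_loss: SummaryStatsStub | None = None
    critic_loss: float | None = None

    def get_loss_stats_dict(self) -> dict[str, float]:
        result: dict[str, float] = {}
        for f in fields(self):
            # Bookkeeping fields are not loss statistics.
            if f.name in ("train_time", "smoothed_loss"):
                continue
            value = getattr(self, f.name)
            # Fields with value None are excluded.
            if value is None:
                continue
            # Summary statistics are replaced by their mean.
            if isinstance(value, SummaryStatsStub):
                value = value.mean
            result[f.name] = value
        return result
```

For example, MyTrainingStats(train_time=1.5, actor_loss=SummaryStatsStub(mean=0.25)).get_loss_stats_dict() yields {"actor_loss": 0.25}: train_time and smoothed_loss are skipped, the None-valued critic_loss is dropped, and the summary stats collapse to their mean.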

class TrainingStatsWrapper(wrapped_stats: TrainingStats)

Bases: TrainingStats

In this particular case, super().__init__() should be called LAST in the subclass’s __init__().

property wrapped_stats: TrainingStats

class BasePolicy(*, action_space: Space, observation_space: Space | None = None, action_scaling: bool = False, action_bound_method: Literal['clip', 'tanh'] | None = 'clip', lr_scheduler: LRScheduler | MultipleLRSchedulers | None = None)

Bases: Module, Generic[TTrainingStats], ABC

The base class for any RL policy.

Tianshou aims to modularize RL algorithms. It comes with several classes of policies, all of which must inherit from BasePolicy.

A policy class typically has the following parts:

  • __init__(): initialize the policy, including copying the target network and so on;

  • forward(): compute action with given observation;

  • process_fn(): pre-process data from the replay buffer (this function can interact with replay buffer);

  • learn(): update the policy with a given batch of data;

  • post_process_fn(): update the replay buffer from the learning process (e.g., a prioritized replay buffer needs to update the weights);

  • update(): the main interface for training, i.e., process_fn -> learn -> post_process_fn.
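The order in which update() drives these hooks can be sketched with a hypothetical, framework-free ToyPolicy that uses plain dicts instead of Tianshou batches and buffers. Only the call order process_fn -> learn -> post_process_fn is taken from the list above; the hook bodies are placeholders:

```python
import numpy as np


class ToyPolicy:
    """Hypothetical, framework-free skeleton showing the order of the update hooks."""

    def __init__(self) -> None:
        self.calls: list[str] = []

    def process_fn(self, batch: dict, buffer: dict, indices: np.ndarray) -> dict:
        # Pre-process: e.g. attach extra keys read from the buffer.
        self.calls.append("process_fn")
        batch["returns"] = buffer["rew"][indices]
        return batch

    def learn(self, batch: dict) -> dict:
        # Update step: here just a dummy "loss" derived from the batch.
        self.calls.append("learn")
        return {"loss": float(np.mean(batch["returns"]))}

    def post_process_fn(self, batch: dict, buffer: dict, indices: np.ndarray) -> None:
        # Post-process: e.g. write new priorities back to the buffer.
        self.calls.append("post_process_fn")

    def update(self, sample_size: int, buffer: dict, seed: int = 0) -> dict:
        # Sample a batch, then run process_fn -> learn -> post_process_fn.
        rng = np.random.default_rng(seed)
        indices = rng.choice(len(buffer["rew"]), size=sample_size, replace=False)
        batch = {"obs": buffer["obs"][indices]}
        batch = self.process_fn(batch, buffer, indices)
        stats = self.learn(batch)
        self.post_process_fn(batch, buffer, indices)
        return stats
```

Calling ToyPolicy().update(4, {"obs": np.arange(10), "rew": np.ones(10)}) returns {"loss": 1.0} and records the three hook calls in exactly that order.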

Most policies need a neural network to predict the action and an optimizer to optimize the policy. The rules for self-defined networks are:

  1. Input: observation “obs” (may be a numpy.ndarray, a torch.Tensor, a dict or anything else), hidden state “state” (for RNN usage), and other information “info” provided by the environment.

  2. Output: some “logits”, the next hidden state “state”, and the intermediate result of the policy forwarding procedure, “policy”. The “logits” could be a tuple instead of a torch.Tensor; this depends on how the policy processes the network output. For example, in PPO, the network might return ((mu, sigma), state) for a Gaussian policy. The “policy” can be a Batch of torch.Tensor or other things, which will be stored in the replay buffer and can be accessed in the policy update process (e.g., in “policy.learn()”, “batch.policy” is what you need).
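A hypothetical feed-forward network following the input/output rules above might look like the sketch below. ToyNet and its layer sizes are made up, and a real Tianshou network may differ in details. It accepts numpy or tensor observations plus optional state and info, and returns (logits, state):

```python
from __future__ import annotations

from typing import Any

import numpy as np
import torch
from torch import nn


class ToyNet(nn.Module):
    """Hypothetical MLP following the (obs, state, info) -> (logits, state) convention."""

    def __init__(self, obs_dim: int, n_actions: int) -> None:
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(obs_dim, 64), nn.ReLU(), nn.Linear(64, n_actions)
        )

    def forward(
        self,
        obs: np.ndarray | torch.Tensor,
        state: Any = None,
        info: dict[str, Any] | None = None,
    ) -> tuple[torch.Tensor, Any]:
        # Observations may arrive as numpy arrays; convert them to a float tensor.
        obs_t = torch.as_tensor(obs, dtype=torch.float32)
        logits = self.model(obs_t)
        # A feed-forward net keeps no hidden state, so "state" is passed through.
        return logits, state
```

For instance, ToyNet(4, 2)(np.zeros((3, 4))) returns logits of shape (3, 2) and a state of None, since a non-recurrent network has no hidden state to carry over.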

Since BasePolicy inherits from torch.nn.Module, you can use BasePolicy almost the same way as a torch.nn.Module, for instance, for saving and loading the model:

import torch
torch.save(policy.state_dict(), "policy.pth")
policy.load_state_dict(torch.load("policy.pth"))

Parameters:
  • action_space – Env’s action_space.

  • observation_space – Env’s observation space. TODO: appears unused…

  • action_scaling – if True, scale the action from [-1, 1] to the range of action_space. Only used if the action_space is continuous.

  • action_bound_method – method to bound action to range [-1, 1]. Only used if the action_space is continuous.

  • lr_scheduler – if not None, will be called in policy.update().

Initializes internal Module state, shared by both nn.Module and ScriptModule.

is_within_training_step

Flag indicating whether we are currently within a training step, which encompasses data collection for training (in online RL algorithms) and the policy update (gradient steps).

It can be used, for example, to control whether a flag requesting deterministic evaluation should actually be applied: within a training step, we typically always want stochastic evaluation (even if such a flag is enabled), as well as stochastic action computation for Q-targets (e.g., in SAC-based algorithms).

This flag should normally remain False and should be set to True only by the algorithm which performs training steps. This is done automatically by the Trainer classes. If a policy is used outside of a Trainer, the user should ensure that this flag is set correctly before calling update or learn.

property action_type: Literal['discrete', 'continuous']

set_agent_id(agent_id: int) → None

Set self.agent_id = agent_id, for MARL.

exploration_noise(act: _TArrOrActBatch, batch: ObsBatchProtocol) → _TArrOrActBatch

Modify the action from policy.forward with exploration noise.

NOTE: currently does not add any noise! Needs to be overridden by subclasses to actually do something.

Parameters:
  • act – a data batch or numpy.ndarray which is the action taken by policy.forward.

  • batch – the input batch for policy.forward, kept for advanced usage.

Returns:

action in the same form as the input “act”, but with added exploration noise.
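Since the base implementation returns the action unchanged, a subclass for a continuous action space might add Gaussian noise instead. A minimal sketch as a standalone helper; the name gaussian_exploration_noise, the sigma default and the [-1, 1] bounds are assumptions, the bounds being chosen to match the range before map_action is applied:

```python
from __future__ import annotations

import numpy as np


def gaussian_exploration_noise(
    act: np.ndarray,
    sigma: float = 0.1,
    low: float = -1.0,
    high: float = 1.0,
    seed: int | None = None,
) -> np.ndarray:
    """Hypothetical helper: perturb a continuous action with N(0, sigma^2) noise."""
    rng = np.random.default_rng(seed)
    noisy = act + rng.normal(0.0, sigma, size=act.shape)
    # Keep the perturbed action inside the assumed valid range.
    return np.clip(noisy, low, high)
```

A subclass's exploration_noise could call such a helper on the numpy action and return the result. With sigma=0.0, actions already inside the bounds pass through unchanged.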

soft_update(tgt: Module, src: Module, tau: float) → None

Softly update the parameters of the target module towards the parameters of the source module.
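The description does not spell out the formula. Assuming the common Polyak-averaging rule \(\theta_{\mathrm{tgt}} \leftarrow \tau\,\theta_{\mathrm{src}} + (1 - \tau)\,\theta_{\mathrm{tgt}}\), a minimal PyTorch sketch is:

```python
import torch
from torch import nn


@torch.no_grad()
def soft_update(tgt: nn.Module, src: nn.Module, tau: float) -> None:
    """Polyak averaging: tgt <- tau * src + (1 - tau) * tgt, parameter by parameter."""
    for tgt_param, src_param in zip(tgt.parameters(), src.parameters()):
        tgt_param.mul_(1.0 - tau).add_(src_param, alpha=tau)
```

With tau=1.0 this reduces to a hard copy of the source parameters, while small values of tau move the target slowly. The sketch assumes both modules share the same architecture, so that parameters() yields matching tensors in the same order.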

compute_action(obs: _SupportsArray[dtype[Any]] | _NestedSequence[_SupportsArray[dtype[Any]]] | bool | int | float | complex | str | bytes | _NestedSequence[bool | int | float | complex | str | bytes], info: dict[str, Any] | None = None, state: dict | BatchProtocol | ndarray | None = None) → ndarray | int

Get the action as an int (for discrete envs) or an array (for continuous ones) from an env’s observation and info.

Parameters:
  • obs – observation from the gym env.

  • info – information given by the gym env.

  • state – the hidden state of RNN policy, used for recurrent policy.

Returns:

action as an int (for discrete envs) or an array (for continuous ones).

abstract forward(batch: ObsBatchProtocol, state: dict | BatchProtocol | ndarray | None = None, **kwargs: Any) → ActBatchProtocol | ActStateBatchProtocol

Compute action over the given batch data.

Returns:

A Batch which MUST have the following keys:

  • act – a numpy.ndarray or a torch.Tensor, the action over the given batch data.

  • state – a dict, a numpy.ndarray or a torch.Tensor, the internal state of the policy, None by default.

Other keys are user-defined and depend on the algorithm. For example:

# some code
return Batch(logits=..., act=..., state=None, dist=...)

The keyword policy is reserved, and the corresponding data will be stored in the replay buffer. For instance:

# some code
return Batch(..., policy=Batch(log_prob=dist.log_prob(act)))
# and in the sampled data batch, you can directly use
# batch.policy.log_prob to get your data.

Note

In a continuous action space, you should perform an additional “map_action” step to get the real action:

act = policy(batch).act  # doesn't map to the target action range
act = policy.map_action(act)

map_action(act: Tensor | ndarray) → ndarray

Map raw network output to action range in gym’s env.action_space.

This function is called in collect() and only affects the action sent to the env. The remapped action will not be stored in the buffer and can thus be viewed as part of the env (a black-box action transformation).

Action mapping includes two standard procedures: bounding and scaling. The bounding procedure expects the original action range to be (-inf, inf) and maps it to [-1, 1], while the scaling procedure expects the original action range to be (-1, 1) and maps it to [action_space.low, action_space.high]. Bounding is applied first.

Parameters:

act – a data batch or numpy.ndarray which is the action taken by policy.forward.

Returns:

action in the same form as the input “act”, but remapped to the target action space.
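The two procedures can be sketched with numpy as a standalone function. This is not the library's method: low and high are assumed to mirror action_space.low and action_space.high, and the keyword defaults mirror the constructor's action_bound_method and action_scaling arguments:

```python
from __future__ import annotations

from typing import Literal

import numpy as np


def map_action(
    act: np.ndarray,
    low: np.ndarray,
    high: np.ndarray,
    action_bound_method: Literal["clip", "tanh"] | None = "clip",
    action_scaling: bool = False,
) -> np.ndarray:
    """Sketch: bound the raw output to [-1, 1], then scale it to [low, high]."""
    act = np.asarray(act, dtype=np.float64)
    # 1) Bounding, applied first: (-inf, inf) -> [-1, 1].
    if action_bound_method == "clip":
        act = np.clip(act, -1.0, 1.0)
    elif action_bound_method == "tanh":
        act = np.tanh(act)
    # 2) Scaling: [-1, 1] -> [low, high], linear in each dimension.
    if action_scaling:
        act = low + (high - low) * (act + 1.0) / 2.0
    return act
```

For example, with low = [0, -2], high = [10, 2] and action_scaling=True, the raw output [5.0, 0.0] is first clipped to [1.0, 0.0] and then scaled to [10.0, 0.0].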

map_action_inverse(act: Tensor | ndarray) → ndarray

Inverse operation to map_action().

This function is called in collect() for random initial steps. It maps [action_space.low, action_space.high] back to the value range of policy.forward.

Parameters:

act – a data batch, list or numpy.ndarray which is the action taken by gym.spaces.Box.sample().

Returns:

the remapped action.
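A numpy sketch of the inverse, written as a standalone function rather than the library's method. It assumes low and high mirror action_space.low and action_space.high. The sketch undoes the scaling first and then, for "tanh" bounding, applies arctanh:

```python
from __future__ import annotations

from typing import Literal

import numpy as np


def map_action_inverse(
    act: np.ndarray,
    low: np.ndarray,
    high: np.ndarray,
    action_bound_method: Literal["clip", "tanh"] | None = "clip",
    action_scaling: bool = False,
) -> np.ndarray:
    """Sketch of the inverse: [low, high] -> [-1, 1], then undo tanh if it was used."""
    act = np.asarray(act, dtype=np.float64)
    if action_scaling:
        low = np.asarray(low, dtype=np.float64)
        high = np.asarray(high, dtype=np.float64)
        # Guard against zero-width dimensions before dividing.
        scale = np.maximum(high - low, np.finfo(np.float64).eps)
        act = (act - low) * 2.0 / scale - 1.0
    if action_bound_method == "tanh":
        # arctanh is the inverse of tanh; inputs of exactly +-1 give +-inf.
        act = np.arctanh(act)
    # "clip" cannot be undone, so nothing more happens in that case.
    return act
```

Clip bounding is not invertible, because values outside [-1, 1] were lost, so only the scaling is undone in that case. With "tanh", actions sampled exactly at the bounds map to ±inf, which a real implementation would need to guard against.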

process_buffer(buffer: TBuffer) → TBuffer

Pre-process the replay buffer, e.g., to add new keys.

Used in the BaseTrainer initialization method; usually used by offline trainers.

Note: this will only be called once, when the trainer is initialized!

If the buffer is empty by then, there will be nothing to process. This method is meant to be overridden by policies which will be trained offline at some stage, e.g., in a pre-training step.

process_fn(batch: RolloutBatchProtocol, buffer: ReplayBuffer, indices: ndarray) → RolloutBatchProtocol

Pre-process the data from the provided replay buffer.

Meant to be overridden by subclasses. Typical usage is to add new keys to the batch, e.g., to add the value function of the next state. Used in update(), which is usually called repeatedly during training.

For modifying the replay buffer only once at the beginning (e.g., for offline learning) see process_buffer().

abstract learn(batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) → TTrainingStats

Update policy with a given batch of data.

Returns:

A dataclass object containing the data to be logged (e.g., loss).

Note

In order to distinguish the collecting state, updating state and testing state, you can check the policy state via self.training and self.updating. Please refer to States for policy for a more detailed explanation.

Warning

If you use torch.distributions.Normal and torch.distributions.Categorical to calculate the log_prob, please be careful about the shape: the Categorical distribution gives shape “[batch_size]”, while the Normal distribution gives shape “[batch_size, 1]”. The automatic broadcasting of numerical operations on torch tensors will amplify this error.
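The pitfall can be reproduced in a few lines. The batch size, logits and advantage values below are arbitrary, and multiplying by a per-sample advantage is just one typical place where the mismatch bites:

```python
import torch

batch_size = 4
advantages = torch.arange(batch_size, dtype=torch.float32)  # shape [batch_size]

# Categorical: log_prob has shape [batch_size], matching the advantages.
cat = torch.distributions.Categorical(logits=torch.zeros(batch_size, 3))
cat_log_prob = cat.log_prob(torch.zeros(batch_size, dtype=torch.long))

# Normal over a 1-dim action: log_prob has shape [batch_size, 1].
normal = torch.distributions.Normal(torch.zeros(batch_size, 1), torch.ones(batch_size, 1))
normal_log_prob = normal.log_prob(torch.zeros(batch_size, 1))

ok = cat_log_prob * advantages  # [4] * [4] -> [4]
wrong = normal_log_prob * advantages  # [4, 1] * [4] broadcasts silently to [4, 4]
# Fix: reduce over the action dimension first (joint log-prob of the action vector).
fixed = normal_log_prob.sum(dim=-1) * advantages  # -> [4]
```

No error is raised for wrong; a loss averaged over it would simply be computed from the wrong 16 products. Reducing (or squeezing) the action dimension before combining log-probabilities with per-sample quantities keeps the shapes aligned.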

post_process_fn(batch: BatchProtocol, buffer: ReplayBuffer, indices: ndarray) → None

Post-process the data from the provided replay buffer.

This will only have an effect if the buffer has the method update_weight and the batch has the attribute weight.

Typical usage is to update the sampling weight in prioritized experience replay. Used in update().
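A minimal sketch of this conditional behavior, using a hypothetical ToyPrioritizedBuffer and a SimpleNamespace as a stand-in for a batch. Taking the absolute value of the new weights as priorities is also an assumption:

```python
from types import SimpleNamespace

import numpy as np


class ToyPrioritizedBuffer:
    """Hypothetical buffer exposing update_weight, as a prioritized buffer would."""

    def __init__(self, size: int) -> None:
        self.weight = np.ones(size)

    def update_weight(self, indices: np.ndarray, new_weight: np.ndarray) -> None:
        # Assumption: priorities are the absolute values of the new weights.
        self.weight[indices] = np.abs(new_weight)


def post_process_fn(batch: SimpleNamespace, buffer: object, indices: np.ndarray) -> None:
    """Sketch: act only if the buffer can update weights AND the batch carries them."""
    if hasattr(buffer, "update_weight") and hasattr(batch, "weight"):
        buffer.update_weight(indices, batch.weight)
```

If either the update_weight method or the weight attribute is missing, the call is a no-op, which is why ordinary replay buffers are unaffected.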

update(sample_size: int | None, buffer: ReplayBuffer | None, **kwargs: Any) → TTrainingStats

Update the policy network and replay buffer.

It includes three steps: process_fn, learn, and post_process_fn. In addition, this function changes the value of self.updating: it is False before the call and True while update() is executing. Please refer to States for policy for a more detailed explanation. The return value of learn is augmented with the training time within update, while smoothed loss values are computed in the trainer.

Parameters:
  • sample_size – 0 means all the data in the buffer will be used; any other int samples a batch of the given sample_size. None also means all the data in the buffer will be used, but it will be shuffled first. TODO: remove the option for 0?

  • buffer – the corresponding replay buffer.

Returns:

A dataclass object containing the data needed to be logged (e.g., loss) from policy.learn().
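The documented sample_size semantics can be sketched with a hypothetical helper. select_indices is not part of the library, and sampling with replacement for the integer case is an assumption:

```python
from __future__ import annotations

import numpy as np


def select_indices(buffer_len: int, sample_size: int | None, rng: np.random.Generator) -> np.ndarray:
    """Hypothetical helper mirroring the documented sample_size semantics."""
    if sample_size == 0:
        # 0: all data, in buffer order.
        return np.arange(buffer_len)
    if sample_size is None:
        # None: all data, shuffled first.
        return rng.permutation(buffer_len)
    # Otherwise: a random batch of the given size (with replacement, an assumption).
    return rng.choice(buffer_len, size=sample_size)
```

Both 0 and None cover the whole buffer; they differ only in whether the original order is preserved.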

static value_mask(buffer: ReplayBuffer, indices: ndarray) → ndarray

Value mask determines whether the obs_next of buffer[indices] is valid.

For instance, “obs_next” after a “done” flag is usually considered invalid: its q/advantage value can provide meaningless (even misleading) information and should be set to 0 by hand. But if the “done” flag is generated because of a time limit on the game length (info[“TimeLimit.truncated”] is set to True in gym’s settings), “obs_next” is still valid. The value mask is typically used to help calculate the correct q/advantage value.

Parameters:
  • buffer – the corresponding replay buffer.

  • indices (numpy.ndarray) – indices of replay buffer whose “obs_next” will be judged.

Returns:

A boolean numpy.ndarray with the same shape as indices. “True” means “obs_next” of that buffer[indices] is valid.
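A minimal sketch of the idea. It assumes the buffer records true terminations separately from time-limit truncations, as in Gymnasium's terminated/truncated split; terminated is a hypothetical boolean array standing in for the buffer's storage:

```python
import numpy as np


def value_mask(terminated: np.ndarray, indices: np.ndarray) -> np.ndarray:
    """Sketch: obs_next is valid unless the episode truly terminated at that step.

    Time-limit truncation is not termination, so truncated steps stay valid.
    """
    return ~terminated[indices]
```

For terminated = [False, False, True, False], the mask for indices [0, 2, 3] is [True, False, True]. Only the genuinely terminal step is masked, while a step that merely hit a time limit keeps its value.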

static compute_episodic_return(batch: RolloutBatchProtocol, buffer: ReplayBuffer, indices: ndarray, v_s_: ndarray | Tensor | None = None, v_s: ndarray | Tensor | None = None, gamma: float = 0.99, gae_lambda: float = 0.95) → tuple[ndarray, ndarray]

Compute returns over given batch.

Uses the Generalized Advantage Estimator (arXiv:1506.02438) to calculate the q/advantage value of the given batch. Returns are calculated as advantage + value, which is exactly equivalent to using \(TD(\lambda)\) for estimating returns.

Setting v_s_ and v_s to None (or all zeros) and gae_lambda to 1.0 calculates the discounted return-to-go (Monte-Carlo return).

Parameters:
  • batch – a data batch which contains several episodes of data in sequential order. Note that the end of each finished episode in the batch should be marked by the done flag; unfinished (or still-collecting) episodes will be recognized by buffer.unfinished_index().

  • buffer – the corresponding replay buffer.

  • indices – tells the batch’s location in buffer, batch is equal to buffer[indices].

  • v_s_ – the value function of all next states \(V(s')\). If None, it will be set to an array of 0.

  • v_s – the value function of all current states \(V(s)\). If None, it is set based upon v_s_ rolled by 1.

  • gamma – the discount factor, should be in [0, 1].

  • gae_lambda – the parameter for Generalized Advantage Estimation, should be in [0, 1].

Returns:

two numpy arrays (returns, advantage), each of shape (bsz,).
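A single-trajectory numpy sketch of this computation. gae_single_episode is hypothetical. It assumes one contiguous episode whose terminal steps are flagged in terminated, and it ignores the buffer bookkeeping for multiple or unfinished episodes:

```python
import numpy as np


def gae_single_episode(
    rew: np.ndarray,
    v_s: np.ndarray,
    v_s_: np.ndarray,
    terminated: np.ndarray,
    gamma: float = 0.99,
    gae_lambda: float = 0.95,
) -> tuple[np.ndarray, np.ndarray]:
    """Sketch of GAE over one contiguous trajectory; returns (returns, advantages)."""
    rew = np.asarray(rew, dtype=np.float64)
    v_s = np.asarray(v_s, dtype=np.float64)
    v_s_ = np.asarray(v_s_, dtype=np.float64)
    not_term = 1.0 - np.asarray(terminated, dtype=np.float64)
    # One-step TD errors; V(s') only contributes where s' is not terminal.
    delta = rew + gamma * v_s_ * not_term - v_s
    adv = np.zeros_like(rew)
    gae = 0.0
    for t in reversed(range(len(rew))):
        # Exponentially weighted sum of TD errors, reset at terminal steps.
        gae = delta[t] + gamma * gae_lambda * not_term[t] * gae
        adv[t] = gae
    # Returns are advantage + value.
    return adv + v_s, adv
```

With v_s and v_s_ set to zeros and gae_lambda=1.0, the advantages reduce to discounted returns-to-go. For example, rewards [1, 1, 1] with gamma=0.5 and a terminal last step give returns [1.75, 1.5, 1.0].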

static compute_nstep_return(batch: RolloutBatchProtocol, buffer: ReplayBuffer, indices: ndarray, target_q_fn: Callable[[ReplayBuffer, ndarray], Tensor], gamma: float = 0.99, n_step: int = 1, rew_norm: bool = False) → BatchWithReturnsProtocol

Compute n-step return for Q-learning targets.

\[G_t = \sum_{i = t}^{t + n - 1} \gamma^{i - t}(1 - d_i)r_i + \gamma^n (1 - d_{t + n}) Q_{\mathrm{target}}(s_{t + n})\]

where \(\gamma\) is the discount factor, \(\gamma \in [0, 1]\), and \(d_t\) is the done flag of step \(t\).

Parameters:
  • batch – a data batch, which is equal to buffer[indices].

  • buffer – the data buffer.

  • indices – tells the batch’s location in the buffer.

  • target_q_fn (function) – a function which computes the target Q value of “obs_next” given the data buffer and the wanted indices.

  • gamma – the discount factor, should be in [0, 1].

  • n_step – the number of estimation steps, should be an int greater than 0.

  • rew_norm – normalize the reward to Normal(0, 1). TODO: passing True is not supported and will cause an error!

Returns:

a Batch. The result will be stored in batch.returns as a torch.Tensor with the same shape as target_q_fn’s return tensor.
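A per-index numpy sketch of an n-step target. nstep_target and its q_next argument are hypothetical: q_next[i] stands for a precomputed \(Q_{\mathrm{target}}(s_{i+1})\), and the caller must ensure that t + n_step - 1 stays inside the trajectory. The sketch follows the usual convention that the reward of a terminal step is still counted, while later rewards and the bootstrap term are dropped:

```python
import numpy as np


def nstep_target(
    rew: np.ndarray,
    done: np.ndarray,
    q_next: np.ndarray,
    t: int,
    gamma: float = 0.99,
    n_step: int = 1,
) -> float:
    """Sketch of the n-step Q-learning target for start index t of one trajectory.

    q_next[i] is a hypothetical precomputed Q_target(s_{i+1}) for each step i.
    """
    g = 0.0
    for k in range(n_step):
        i = t + k
        g += gamma**k * rew[i]
        if done[i]:
            # The episode ended at step i: no later rewards and no bootstrap term.
            return float(g)
    # No terminal step within the window: bootstrap from Q_target(s_{t+n}).
    return float(g + gamma**n_step * q_next[t + n_step - 1])
```

For rewards [1, 1, 1, 1], no terminal flags, q_next all 10, gamma=0.5 and n_step=2, the target at t=0 is 1 + 0.5*1 + 0.25*10 = 4.0.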

class RandomActionPolicy(action_space: Space)

Bases: BasePolicy

Initializes internal Module state, shared by both nn.Module and ScriptModule.

forward(batch: ObsBatchProtocol, state: dict | BatchProtocol | ndarray | None = None, **kwargs: Any) → ActStateBatchProtocol

Compute action over the given batch data.

Returns:

A Batch which MUST have the following keys:

  • act – a numpy.ndarray or a torch.Tensor, the action over the given batch data.

  • state – a dict, a numpy.ndarray or a torch.Tensor, the internal state of the policy, None by default.

Other keys are user-defined and depend on the algorithm. For example:

# some code
return Batch(logits=..., act=..., state=None, dist=...)

The keyword policy is reserved, and the corresponding data will be stored in the replay buffer. For instance:

# some code
return Batch(..., policy=Batch(log_prob=dist.log_prob(act)))
# and in the sampled data batch, you can directly use
# batch.policy.log_prob to get your data.

Note

In a continuous action space, you should perform an additional “map_action” step to get the real action:

act = policy(batch).act  # doesn't map to the target action range
act = policy.map_action(act)

learn(batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) → TrainingStats

Update policy with a given batch of data.

Returns:

A dataclass object containing the data to be logged (e.g., loss).

Note

In order to distinguish the collecting state, updating state and testing state, you can check the policy state via self.training and self.updating. Please refer to States for policy for a more detailed explanation.

Warning

If you use torch.distributions.Normal and torch.distributions.Categorical to calculate the log_prob, please be careful about the shape: the Categorical distribution gives shape “[batch_size]”, while the Normal distribution gives shape “[batch_size, 1]”. The automatic broadcasting of numerical operations on torch tensors will amplify this error.

episode_mc_return_to_go(rewards: ndarray, gamma: float = 0.99) → ndarray

Calculates the discounted Monte-Carlo returns-to-go from the rewards of a single episode.

Parameters:
  • rewards – rewards of a single episode. Assumed to be a 1-dim array from reset till the end of the episode.

  • gamma – the discount factor.

Returns:

a numpy array of shape (len(rewards), ).
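A minimal numpy sketch of this computation. It uses the backward recursion \(G_t = r_t + \gamma G_{t+1}\), taking the return beyond the last step to be zero. The function name simply mirrors the documented one; this is an illustration, not the library's implementation:

```python
import numpy as np


def episode_mc_return_to_go(rewards: np.ndarray, gamma: float = 0.99) -> np.ndarray:
    """Sketch: G_t = r_t + gamma * G_{t+1}, computed backwards over one episode."""
    returns = np.zeros(len(rewards), dtype=np.float64)
    running = 0.0  # return from beyond the last step is zero
    for t in reversed(range(len(rewards))):
        running = rewards[t] + gamma * running
        returns[t] = running
    return returns
```

For rewards [1, 1, 1] and gamma=0.5, it returns [1.75, 1.5, 1.0].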