Source code for tianshou.policy.modelfree.ppo

from dataclasses import dataclass
from typing import Any, Generic, Literal, TypeVar

import gymnasium as gym
import numpy as np
import torch
from torch import nn

from tianshou.data import ReplayBuffer, SequenceSummaryStats, to_torch_as
from tianshou.data.types import LogpOldProtocol, RolloutBatchProtocol
from tianshou.policy import A2CPolicy
from tianshou.policy.base import TLearningRateScheduler, TrainingStats
from tianshou.policy.modelfree.pg import TDistributionFunction
from tianshou.utils.net.common import ActorCritic


@dataclass(kw_only=True)
class PPOTrainingStats(TrainingStats):
    loss: SequenceSummaryStats
    clip_loss: SequenceSummaryStats
    vf_loss: SequenceSummaryStats
    ent_loss: SequenceSummaryStats
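
# Added note: each PPOTrainingStats field above is populated in `PPOPolicy.learn`
# via `SequenceSummaryStats.from_sequence` over the per-minibatch loss values
# collected during one update.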

TPPOTrainingStats = TypeVar("TPPOTrainingStats", bound=PPOTrainingStats)
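
# Reference (added for clarity): `learn` below minimizes the negated clipped
# surrogate objective from arXiv:1707.06347, Eq. 7,
#
#     L^{CLIP}(theta) = E_t[ min( r_t(theta) * A_t,
#                                 clip(r_t(theta), 1 - eps, 1 + eps) * A_t ) ],
#     r_t(theta) = pi_theta(a_t | s_t) / pi_theta_old(a_t | s_t),
#
# plus a value loss weighted by `vf_coef` and an entropy bonus weighted by
# `ent_coef`. With `dual_clip=c`, the lower bound max(L^{CLIP}, c * A_t) from
# arXiv:1912.09729, Eq. 5, is applied where A_t < 0.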

# TODO: the type ignore here is needed b/c the hierarchy is actually broken!
#  Should reconsider the inheritance structure.
class PPOPolicy(A2CPolicy[TPPOTrainingStats], Generic[TPPOTrainingStats]):  # type: ignore[type-var]
    r"""Implementation of Proximal Policy Optimization. arXiv:1707.06347.

    :param actor: the actor network following the rules in
        :class:`~tianshou.policy.BasePolicy`. (s -> logits)
    :param critic: the critic network. (s -> V(s))
    :param optim: the optimizer for the actor and critic networks.
    :param dist_fn: distribution class for computing the action.
    :param action_space: env's action space.
    :param eps_clip: :math:`\epsilon` in :math:`L_{CLIP}` in the original paper.
    :param dual_clip: a parameter c mentioned in arXiv:1912.09729 Equ. 5,
        where c > 1 is a constant indicating the lower bound. Set to None
        to disable dual-clip PPO.
    :param value_clip: a parameter mentioned in arXiv:1811.02553v3 Sec. 4.1.
    :param advantage_normalization: whether to do per mini-batch advantage
        normalization.
    :param recompute_advantage: whether to recompute advantage every update
        repeat according to https://arxiv.org/pdf/2006.05990.pdf Sec. 3.5.
    :param vf_coef: weight for value loss.
    :param ent_coef: weight for entropy loss.
    :param max_grad_norm: clipping gradients in back propagation.
    :param gae_lambda: in [0, 1], param for Generalized Advantage Estimation.
    :param max_batchsize: the maximum size of the batch when computing GAE.
    :param discount_factor: in [0, 1].
    :param reward_normalization: normalize estimated values to have std close to 1.
    :param deterministic_eval: if True, use deterministic evaluation.
    :param observation_space: the space of the observation.
    :param action_scaling: if True, scale the action from [-1, 1] to the range
        of action_space. Only used if the action_space is continuous.
    :param action_bound_method: method to bound action to range [-1, 1].
    :param lr_scheduler: if not None, will be called in `policy.update()`.

    .. seealso::

        Please refer to :class:`~tianshou.policy.BasePolicy` for a more detailed
        explanation.
    """

    def __init__(
        self,
        *,
        actor: torch.nn.Module,
        critic: torch.nn.Module,
        optim: torch.optim.Optimizer,
        dist_fn: TDistributionFunction,
        action_space: gym.Space,
        eps_clip: float = 0.2,
        dual_clip: float | None = None,
        value_clip: bool = False,
        advantage_normalization: bool = True,
        recompute_advantage: bool = False,
        vf_coef: float = 0.5,
        ent_coef: float = 0.01,
        max_grad_norm: float | None = None,
        gae_lambda: float = 0.95,
        max_batchsize: int = 256,
        discount_factor: float = 0.99,
        # TODO: rename to return_normalization?
        reward_normalization: bool = False,
        deterministic_eval: bool = False,
        observation_space: gym.Space | None = None,
        action_scaling: bool = True,
        action_bound_method: Literal["clip", "tanh"] | None = "clip",
        lr_scheduler: TLearningRateScheduler | None = None,
    ) -> None:
        assert (
            dual_clip is None or dual_clip > 1.0
        ), f"Dual-clip PPO parameter should be greater than 1.0 but got {dual_clip}"
        super().__init__(
            actor=actor,
            critic=critic,
            optim=optim,
            dist_fn=dist_fn,
            action_space=action_space,
            vf_coef=vf_coef,
            ent_coef=ent_coef,
            max_grad_norm=max_grad_norm,
            gae_lambda=gae_lambda,
            max_batchsize=max_batchsize,
            discount_factor=discount_factor,
            reward_normalization=reward_normalization,
            deterministic_eval=deterministic_eval,
            observation_space=observation_space,
            action_scaling=action_scaling,
            action_bound_method=action_bound_method,
            lr_scheduler=lr_scheduler,
        )
        self.eps_clip = eps_clip
        self.dual_clip = dual_clip
        self.value_clip = value_clip
        self.norm_adv = advantage_normalization
        self.recompute_adv = recompute_advantage
        self._actor_critic: ActorCritic
    def process_fn(
        self,
        batch: RolloutBatchProtocol,
        buffer: ReplayBuffer,
        indices: np.ndarray,
    ) -> LogpOldProtocol:
        if self.recompute_adv:
            # store the input `buffer` and `indices` so `learn()` can
            # recompute advantages every update repeat
            self._buffer, self._indices = buffer, indices
        batch = self._compute_returns(batch, buffer, indices)
        batch.act = to_torch_as(batch.act, batch.v_s)
        with torch.no_grad():
            batch.logp_old = self(batch).dist.log_prob(batch.act)
        batch: LogpOldProtocol
        return batch
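
    # Added note: `logp_old` computed above is the fixed denominator of the
    # importance ratio used in `learn` (ratio = exp(logp_new - logp_old)).
    # It is evaluated under `torch.no_grad()` so that gradients flow only
    # through the current policy's log-probabilities during the update epochs.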
    # TODO: why does mypy complain?
    def learn(  # type: ignore
        self,
        batch: RolloutBatchProtocol,
        batch_size: int | None,
        repeat: int,
        *args: Any,
        **kwargs: Any,
    ) -> TPPOTrainingStats:
        losses, clip_losses, vf_losses, ent_losses = [], [], [], []
        split_batch_size = batch_size or -1
        for step in range(repeat):
            if self.recompute_adv and step > 0:
                batch = self._compute_returns(batch, self._buffer, self._indices)
            for minibatch in batch.split(split_batch_size, merge_last=True):
                # calculate loss for actor
                dist = self(minibatch).dist
                if self.norm_adv:
                    mean, std = minibatch.adv.mean(), minibatch.adv.std()
                    minibatch.adv = (minibatch.adv - mean) / (std + self._eps)  # per-batch norm
                ratio = (dist.log_prob(minibatch.act) - minibatch.logp_old).exp().float()
                ratio = ratio.reshape(ratio.size(0), -1).transpose(0, 1)
                surr1 = ratio * minibatch.adv
                surr2 = ratio.clamp(1.0 - self.eps_clip, 1.0 + self.eps_clip) * minibatch.adv
                if self.dual_clip:
                    clip1 = torch.min(surr1, surr2)
                    clip2 = torch.max(clip1, self.dual_clip * minibatch.adv)
                    clip_loss = -torch.where(minibatch.adv < 0, clip2, clip1).mean()
                else:
                    clip_loss = -torch.min(surr1, surr2).mean()
                # calculate loss for critic
                value = self.critic(minibatch.obs).flatten()
                if self.value_clip:
                    v_clip = minibatch.v_s + (value - minibatch.v_s).clamp(
                        -self.eps_clip,
                        self.eps_clip,
                    )
                    vf1 = (minibatch.returns - value).pow(2)
                    vf2 = (minibatch.returns - v_clip).pow(2)
                    vf_loss = torch.max(vf1, vf2).mean()
                else:
                    vf_loss = (minibatch.returns - value).pow(2).mean()
                # calculate regularization and overall loss
                ent_loss = dist.entropy().mean()
                loss = clip_loss + self.vf_coef * vf_loss - self.ent_coef * ent_loss
                self.optim.zero_grad()
                loss.backward()
                if self.max_grad_norm:  # clip large gradient
                    nn.utils.clip_grad_norm_(
                        self._actor_critic.parameters(),
                        max_norm=self.max_grad_norm,
                    )
                self.optim.step()
                clip_losses.append(clip_loss.item())
                vf_losses.append(vf_loss.item())
                ent_losses.append(ent_loss.item())
                losses.append(loss.item())

        losses_summary = SequenceSummaryStats.from_sequence(losses)
        clip_losses_summary = SequenceSummaryStats.from_sequence(clip_losses)
        vf_losses_summary = SequenceSummaryStats.from_sequence(vf_losses)
        ent_losses_summary = SequenceSummaryStats.from_sequence(ent_losses)

        return PPOTrainingStats(  # type: ignore[return-value]
            loss=losses_summary,
            clip_loss=clip_losses_summary,
            vf_loss=vf_losses_summary,
            ent_loss=ent_losses_summary,
        )
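
# ---------------------------------------------------------------------------
# Usage sketch (added for illustration; not part of the original module).
# It assumes tianshou's `Net`/`Actor`/`Critic` helpers and a discrete-action
# Gymnasium env; the hyperparameters are illustrative, not tuned.
if __name__ == "__main__":
    from torch.distributions import Categorical

    from tianshou.utils.net.common import Net
    from tianshou.utils.net.discrete import Actor, Critic

    env = gym.make("CartPole-v1")
    # shared feature extractor feeding both the policy and value heads
    net = Net(state_shape=env.observation_space.shape, hidden_sizes=[64, 64])
    actor = Actor(net, action_shape=env.action_space.n)
    critic = Critic(net)
    # one optimizer over both networks, matching the single `optim` parameter
    optim = torch.optim.Adam(ActorCritic(actor, critic).parameters(), lr=3e-4)
    policy = PPOPolicy(
        actor=actor,
        critic=critic,
        optim=optim,
        dist_fn=lambda logits: Categorical(logits=logits),
        action_space=env.action_space,
        eps_clip=0.2,
    )
    print(policy)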