Source code: tianshou/policy/modelfree/
- class TRPOTrainingStats(*, train_time: float = 0.0, smoothed_loss: dict = <factory>, actor_loss:, vf_loss:, kl:, step_size: SequenceSummaryStats)
- step_size: SequenceSummaryStats
- class TRPOPolicy(*, actor: Module | ActorProb | Actor, critic: Module | Critic | Critic, optim: Optimizer, dist_fn: Callable[[tuple[Tensor, Tensor]], Distribution] | Callable[[Tensor], Categorical], action_space: Space, max_kl: float = 0.01, backtrack_coeff: float = 0.8, max_backtracks: int = 10, optim_critic_iters: int = 5, actor_step_size: float = 0.5, advantage_normalization: bool = True, gae_lambda: float = 0.95, max_batchsize: int = 256, discount_factor: float = 0.99, reward_normalization: bool = False, deterministic_eval: bool = False, observation_space: Space | None = None, action_scaling: bool = True, action_bound_method: Literal['clip', 'tanh'] | None = 'clip', lr_scheduler: LRScheduler | MultipleLRSchedulers | None = None)
Implementation of Trust Region Policy Optimization. arXiv:1502.05477.
- Parameters:
actor – the actor network, which must follow these rules: if self.action_type == "discrete": (s_B -> action_values_BA); if self.action_type == "continuous": (s_B -> dist_input_BD).
critic – the critic network. (s -> V(s))
optim – the optimizer for actor and critic network.
dist_fn – distribution class for computing the action.
action_space – env’s action space
max_kl – maximum KL divergence used to constrain each actor network update.
backtrack_coeff – Coefficient to be multiplied by step size when constraints are not met.
max_backtracks – Max number of backtracking times in linesearch.
optim_critic_iters – Number of times to optimize critic network per update.
actor_step_size – step size for actor update in natural gradient direction.
advantage_normalization – whether to do per mini-batch advantage normalization.
gae_lambda – in [0, 1], param for Generalized Advantage Estimation.
max_batchsize – the maximum size of the batch when computing GAE.
discount_factor – the discount factor, in [0, 1].
reward_normalization – normalize estimated values to have std close to 1.
deterministic_eval – if True, use deterministic evaluation.
observation_space – the space of the observation.
action_scaling – if True, scale the action from [-1, 1] to the range of action_space. Only used if the action_space is continuous.
action_bound_method – method to bound action to range [-1, 1].
lr_scheduler – if not None, will be called in policy.update().
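The interaction between max_kl, backtrack_coeff, and max_backtracks can be illustrated with a minimal sketch of a backtracking line search. This is a toy one-dimensional example under stated assumptions, not Tianshou's implementation: the names backtracking_line_search, toy_loss, and toy_kl are hypothetical, and the "policy" is a single scalar mean of a unit-variance Gaussian.

```python
from typing import Callable, Tuple


def backtracking_line_search(
    theta: float,
    direction: float,
    loss_fn: Callable[[float], float],
    kl_fn: Callable[[float, float], float],
    step_size: float = 1.0,
    max_kl: float = 0.01,
    backtrack_coeff: float = 0.8,
    max_backtracks: int = 10,
) -> Tuple[float, float]:
    """Return (new_theta, accepted_step_size); a step size of 0.0 means no update."""
    old_loss = loss_fn(theta)
    for _ in range(max_backtracks):
        candidate = theta + step_size * direction
        # Accept only if the KL constraint holds and the loss improves.
        if kl_fn(theta, candidate) < max_kl and loss_fn(candidate) < old_loss:
            return candidate, step_size
        # Otherwise shrink the step and try again.
        step_size *= backtrack_coeff
    # Every attempt violated a constraint: keep the old parameters.
    return theta, 0.0


def toy_loss(x: float) -> float:
    # Hypothetical surrogate loss with its minimum at x = 1.
    return (x - 1.0) ** 2


def toy_kl(old: float, new: float) -> float:
    # KL divergence between two unit-variance Gaussians with means old and new.
    return 0.5 * (new - old) ** 2


new_theta, accepted_step = backtracking_line_search(0.0, 1.0, toy_loss, toy_kl)
print(new_theta, accepted_step)
```

With the defaults shown, the initial step is far too large for the KL bound, so the step is shrunk repeatedly by backtrack_coeff until the candidate falls inside the trust region. If max_backtracks is too small for that to happen, the sketch leaves the parameters unchanged.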
Initializes internal Module state, shared by both nn.Module and ScriptModule.
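To make the roles of gae_lambda and discount_factor concrete, the following is a minimal pure-Python sketch of Generalized Advantage Estimation over a single trajectory. It is an illustration under stated assumptions rather than the library's routine: the function name gae_advantages and its list-based arguments are hypothetical.

```python
from typing import List, Sequence


def gae_advantages(
    rewards: Sequence[float],
    values: Sequence[float],
    next_values: Sequence[float],
    dones: Sequence[bool],
    discount_factor: float = 0.99,
    gae_lambda: float = 0.95,
) -> List[float]:
    """Compute GAE advantages by scanning the trajectory backwards."""
    advantages = [0.0] * len(rewards)
    running = 0.0
    for t in reversed(range(len(rewards))):
        not_done = 0.0 if dones[t] else 1.0
        # One-step TD error; bootstrapping is cut at episode ends.
        delta = rewards[t] + discount_factor * next_values[t] * not_done - values[t]
        # Exponentially weighted sum of future TD errors.
        running = delta + discount_factor * gae_lambda * not_done * running
        advantages[t] = running
    return advantages


rewards = [1.0, 1.0, 1.0]
zeros = [0.0, 0.0, 0.0]
dones = [False, False, True]

# gae_lambda = 1 with a zero critic reduces to discounted returns.
monte_carlo = gae_advantages(rewards, zeros, zeros, dones, 0.9, 1.0)
# gae_lambda = 0 keeps only the one-step TD error.
one_step = gae_advantages(rewards, zeros, zeros, dones, 0.9, 0.0)
print(monte_carlo, one_step)
```

Values of gae_lambda between 0 and 1 interpolate between these two extremes, trading variance against bias in the advantage estimate.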
- learn(batch: Batch, batch_size: int | None, repeat: int, **kwargs: Any) -> TTRPOTrainingStats
Update policy with a given batch of data.
- Returns:
A dataclass object, including the data needed to be logged (e.g., loss).
To distinguish the collecting, updating, and testing states, check the policy state. Please refer to "States for policy" for a more detailed explanation.
Warning
If you use a distribution to calculate the log_prob, be careful about the shape: a Categorical distribution gives a "[batch_size]" shape, while a Normal distribution gives a "[batch_size, 1]" shape. Automatic broadcasting in numerical operations on torch tensors will amplify this error.
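The shape pitfall described in the warning can be reproduced with a short sketch using torch.distributions. The batch size, the number of discrete actions, and the all-ones advantages tensor are assumptions chosen only for illustration.

```python
import torch
from torch.distributions import Categorical, Normal

batch_size = 4
advantages = torch.ones(batch_size)  # shape [4]

# Discrete case: one action index per sample, so log_prob has shape [4].
categorical = Categorical(logits=torch.zeros(batch_size, 3))
cat_log_prob = categorical.log_prob(torch.zeros(batch_size, dtype=torch.long))

# Continuous case with a 1-D action: log_prob keeps the action axis, shape [4, 1].
normal = Normal(torch.zeros(batch_size, 1), torch.ones(batch_size, 1))
normal_log_prob = normal.log_prob(torch.zeros(batch_size, 1))

# Silent broadcasting: [4, 1] * [4] yields a [4, 4] matrix, not a per-sample product.
broadcast_product = normal_log_prob * advantages

# Summing over the action axis restores the expected [4] shape.
per_sample_product = normal_log_prob.sum(-1) * advantages

print(cat_log_prob.shape, normal_log_prob.shape)
print(broadcast_product.shape, per_sample_product.shape)
```

No error is raised in the broadcast case, which is why the mistake is easy to miss: a later mean over the 4 x 4 result still returns a scalar loss.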