class GAILPolicy(*, actor: Module | ActorProb | Actor, critic: Module | Critic | Critic, optim: Optimizer, dist_fn: Callable[[tuple[Tensor, Tensor]], Distribution] | Callable[[Tensor], Categorical], action_space: Space, expert_buffer: ReplayBuffer, disc_net: Module, disc_optim: Optimizer, disc_update_num: int = 4, eps_clip: float = 0.2, dual_clip: float | None = None, value_clip: bool = False, advantage_normalization: bool = True, recompute_advantage: bool = False, vf_coef: float = 0.5, ent_coef: float = 0.01, max_grad_norm: float | None = None, gae_lambda: float = 0.95, max_batchsize: int = 256, discount_factor: float = 0.99, reward_normalization: bool = False, deterministic_eval: bool = False, observation_space: Space | None = None, action_scaling: bool = True, action_bound_method: Literal['clip', 'tanh'] | None = 'clip', lr_scheduler: LRScheduler | MultipleLRSchedulers | None = None)

Implementation of Generative Adversarial Imitation Learning. arXiv:1606.03476.

Parameters:

  • actor – the actor network, following these rules: if self.action_type == “discrete”: (s_B -> action_values_BA); if self.action_type == “continuous”: (s_B -> dist_input_BD).

  • critic – the critic network. (s -> V(s))

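  • optim – the optimizer for the actor and critic networks.

  • dist_fn – the distribution class (or any callable returning a Distribution) used to build the action distribution from the actor's output.

  • action_space – the environment's action space.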

  • expert_buffer – the replay buffer containing expert experience.

  • disc_net – the discriminator network, whose input dimension equals the state dimension plus the action dimension and whose output dimension is 1.

  • disc_optim – the optimizer for the discriminator network.

  • disc_update_num – the number of discriminator grad steps per model grad step.

  • eps_clip – \(\epsilon\) in \(L_{CLIP}\) in the original PPO paper.

  • dual_clip – the parameter c from arXiv:1912.09729 Eq. 5, a constant c > 1 that sets the lower bound of the clipped objective for negative advantages. Set to None to disable dual-clip PPO.

  • value_clip – whether to clip the value loss, as described in arXiv:1811.02553v3 Sec. 4.1.

  • advantage_normalization – whether to do per mini-batch advantage normalization.

  • recompute_advantage – whether to recompute the advantage on every update repeat, according to Sec. 3.5.

  • vf_coef – weight of the value loss.

  • ent_coef – weight of the entropy loss.

  • max_grad_norm – if not None, the maximum norm to which gradients are clipped during backpropagation.

  • gae_lambda – the λ parameter of Generalized Advantage Estimation, in [0, 1].

  • max_batchsize – the maximum size of the batch when computing GAE.

  • discount_factor – the discount factor γ, in [0, 1].

  • reward_normalization – normalize estimated values to have std close to 1.

  • deterministic_eval – if True, use deterministic evaluation.

  • observation_space – the space of the observation.

  • action_scaling – if True, scale the action from [-1, 1] to the range of action_space. Only used if the action_space is continuous.

  • action_bound_method – method used to bound the action to the range [-1, 1]: 'clip' clips it, 'tanh' squashes it with tanh, and None applies no bounding.

  • lr_scheduler – if not None, will be called in policy.update().
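The action_scaling and action_bound_method options describe a two-stage post-processing of the actor's output: first bound it to [-1, 1], then map it onto the action space's range. A minimal sketch in plain Python, assuming a scalar action and an action space bounded by low and high; bound_action and scale_action are hypothetical helpers, not Tianshou functions.

```python
import math
from typing import Optional


def bound_action(act: float, method: Optional[str] = "clip") -> float:
    """Bound a raw actor output to [-1, 1] (hypothetical helper mirroring action_bound_method)."""
    if method == "clip":
        # Hard-clip values that fall outside the interval.
        return max(-1.0, min(1.0, act))
    if method == "tanh":
        # Smoothly squash any real number into (-1, 1).
        return math.tanh(act)
    if method is None:
        # No bounding: the raw value passes through unchanged.
        return act
    raise ValueError(f"unknown bound method: {method!r}")


def scale_action(act: float, low: float, high: float) -> float:
    """Affinely map an action from [-1, 1] to [low, high] (hypothetical helper for action_scaling)."""
    return low + (act + 1.0) * (high - low) / 2.0


# Example: a raw output of 3.0 on an action space bounded by [-2, 2].
raw = 3.0
print(scale_action(bound_action(raw, "clip"), low=-2.0, high=2.0))  # 2.0
print(scale_action(bound_action(raw, "tanh"), low=-2.0, high=2.0))  # slightly below 2.0
```

Clipping maps every out-of-range output to the boundary, while tanh keeps outputs distinct but never exactly reaches it.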

See also

Please refer to PPOPolicy for a more detailed explanation.
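GAILPolicy exposes PPO's clipping parameters eps_clip and dual_clip. The following per-sample sketch shows what they control: the clipped surrogate objective, plus the dual-clip lower bound c·A for negative advantages from arXiv:1912.09729. It is written in plain Python over lists for readability and is an illustration, not Tianshou's tensorized implementation; the function name ppo_clip_loss is hypothetical.

```python
from typing import Optional, Sequence


def ppo_clip_loss(
    ratios: Sequence[float],
    advantages: Sequence[float],
    eps_clip: float = 0.2,
    dual_clip: Optional[float] = None,
) -> float:
    """Negative mean of the (optionally dual-)clipped PPO surrogate objective.

    ratios[i] is pi_new(a_i | s_i) / pi_old(a_i | s_i); advantages[i] is the estimated advantage.
    """
    if len(ratios) != len(advantages) or not ratios:
        raise ValueError("ratios and advantages must be non-empty and equally long")
    total = 0.0
    for ratio, adv in zip(ratios, advantages):
        surr1 = ratio * adv
        # Clip the probability ratio to [1 - eps_clip, 1 + eps_clip].
        surr2 = min(max(ratio, 1.0 - eps_clip), 1.0 + eps_clip) * adv
        objective = min(surr1, surr2)
        if dual_clip is not None and adv < 0:
            # Dual clip: for negative advantages, bound the objective below by c * A.
            objective = max(objective, dual_clip * adv)
        total += objective
    # The loss is the negated mean objective, since optimizers minimize.
    return -total / len(ratios)


print(ppo_clip_loss([5.0], [-1.0]))                 # 5.0: a large ratio is not clipped when A < 0
print(ppo_clip_loss([5.0], [-1.0], dual_clip=3.0))  # 3.0: the dual clip caps the penalty at c * |A|
```

Without the dual clip, a sample whose ratio has grown far above 1 with a negative advantage contributes an unbounded penalty; dual_clip caps it at c·|A|.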

disc(batch: RolloutBatchProtocol) → Tensor
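disc() is documented here only by its signature. Given the disc_net contract stated in the parameter list (input dimension = state dimension + action dimension, output dimension 1), the sketch below shows what such a forward pass could look like for a single transition. The one-layer linear discriminator and all helper names are hypothetical, and reading the sigmoid of the logit as the probability that the pair came from the expert is an assumed convention.

```python
import math
from typing import List, Sequence


def disc_input(obs: Sequence[float], act: Sequence[float]) -> List[float]:
    """Concatenate a state and an action: the input dim is obs_dim + act_dim."""
    return list(obs) + list(act)


def linear_disc(x: Sequence[float], weights: Sequence[float], bias: float) -> float:
    """Hypothetical one-layer discriminator returning a single logit (output dim 1)."""
    if len(x) != len(weights):
        raise ValueError("input dim must equal the number of weights")
    return sum(w * xi for w, xi in zip(weights, x)) + bias


def sigmoid(z: float) -> float:
    """Numerically stable logistic function."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


x = disc_input(obs=[1.0, 2.0], act=[0.5])  # state dim 2 + action dim 1 -> 3 features
logit = linear_disc(x, weights=[1.0, 1.0, 2.0], bias=0.0)
print(len(x), logit)  # 3 4.0
print(sigmoid(logit))  # a score in (0, 1)
```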
learn(batch: RolloutBatchProtocol, batch_size: int | None, repeat: int, **kwargs: Any) → TGailTrainingStats

Update the policy with the given batch of data.


Returns: a dataclass object containing the data to be logged (e.g., losses).


To distinguish between the collecting, updating, and testing states, you can inspect the policy's state flags, such as self.updating. Please refer to States for policy for a more detailed explanation.


If you use torch.distributions.Normal and torch.distributions.Categorical to calculate log_prob, be careful about the shape: the Categorical distribution returns a tensor of shape [batch_size], whereas the Normal distribution returns one of shape [batch_size, 1]. PyTorch's automatic broadcasting in subsequent arithmetic will silently amplify this mismatch instead of raising an error.
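The disc_loss, acc_pi and acc_exp statistics in GailTrainingStats indicate that learn() trains the discriminator as a binary classifier alongside the PPO update, with disc_update_num setting how many discriminator gradient steps are taken. The sketch below computes such a loss and the two accuracies from raw logits. It assumes expert pairs are labeled 1 and policy pairs 0; that labeling convention and the function name disc_loss_and_acc are assumptions, not taken from the Tianshou source.

```python
import math
from typing import Sequence, Tuple


def softplus(x: float) -> float:
    """Numerically stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def disc_loss_and_acc(
    logits_pi: Sequence[float], logits_exp: Sequence[float]
) -> Tuple[float, float, float]:
    """Binary cross-entropy for a discriminator labeling expert pairs 1 and policy pairs 0.

    Returns (loss, acc_pi, acc_exp).
    """
    # -log(1 - sigmoid(z)) == softplus(z): penalizes high logits on policy samples.
    loss_pi = sum(softplus(z) for z in logits_pi) / len(logits_pi)
    # -log(sigmoid(z)) == softplus(-z): penalizes low logits on expert samples.
    loss_exp = sum(softplus(-z) for z in logits_exp) / len(logits_exp)
    # A policy pair is classified correctly when its logit is negative ("not expert").
    acc_pi = sum(z < 0 for z in logits_pi) / len(logits_pi)
    # An expert pair is classified correctly when its logit is positive.
    acc_exp = sum(z > 0 for z in logits_exp) / len(logits_exp)
    return loss_pi + loss_exp, acc_pi, acc_exp


loss, acc_pi, acc_exp = disc_loss_and_acc(logits_pi=[-2.0, 1.0], logits_exp=[3.0, -1.0])
print(acc_pi, acc_exp)  # 0.5 0.5
```

Working from logits through softplus avoids computing log(sigmoid(z)) directly, which loses precision for large |z|.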

process_fn(batch: RolloutBatchProtocol, buffer: ReplayBuffer, indices: ndarray) → LogpOldProtocol

Pre-process the data from the provided replay buffer.

Used in update(). Check out policy.process_fn for more information.
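In GAIL, pre-processing is the natural place to replace the environment's rewards with a reward derived from the discriminator before PPO's advantage estimation runs. The sketch below assumes the reward -log(1 - D(s, a)) with D = sigmoid(logit), one common GAIL choice (whether GAILPolicy uses exactly this form is not stated here), followed by a textbook GAE recursion driven by discount_factor and gae_lambda. It does not distinguish truncation from termination and is not Tianshou's implementation; all function names are hypothetical.

```python
import math
from typing import List, Sequence, Tuple


def softplus(x: float) -> float:
    """Numerically stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def gail_rewards(disc_logits: Sequence[float]) -> List[float]:
    """Surrogate reward -log(1 - sigmoid(logit)), which equals softplus(logit)."""
    return [softplus(z) for z in disc_logits]


def gae(
    rewards: Sequence[float],
    values: Sequence[float],
    next_values: Sequence[float],
    dones: Sequence[bool],
    gamma: float = 0.99,
    lam: float = 0.95,
) -> Tuple[List[float], List[float]]:
    """Generalized Advantage Estimation over one flat segment of transitions.

    Returns (advantages, returns), where returns = advantages + values.
    """
    advantages = [0.0] * len(rewards)
    running = 0.0
    for t in reversed(range(len(rewards))):
        not_done = 0.0 if dones[t] else 1.0
        # TD error; bootstrapping from V(s_{t+1}) is cut at episode ends.
        delta = rewards[t] + gamma * next_values[t] * not_done - values[t]
        # The running sum is also reset at episode ends.
        running = delta + gamma * lam * not_done * running
        advantages[t] = running
    returns = [a + v for a, v in zip(advantages, values)]
    return advantages, returns


# In this sketch the environment reward is ignored; discriminator logits supply it.
rew = gail_rewards([0.0, 2.0])
adv, ret = gae(rew, values=[0.5, 0.5], next_values=[0.5, 0.0], dones=[False, True])
print([round(r, 3) for r in rew])  # [0.693, 2.127]
```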

class GailTrainingStats(*, train_time: float = 0.0, smoothed_loss: dict = <factory>, loss, clip_loss, vf_loss, ent_loss, disc_loss: SequenceSummaryStats, acc_pi: SequenceSummaryStats, acc_exp: SequenceSummaryStats)
acc_exp: SequenceSummaryStats
acc_pi: SequenceSummaryStats
disc_loss: SequenceSummaryStats
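GailTrainingStats stores disc_loss, acc_pi and acc_exp as SequenceSummaryStats rather than as single numbers, presumably because one learn() call produces one value per gradient step. A minimal sketch of such a summary, assuming it keeps the mean, standard deviation, maximum and minimum; SummaryStatsSketch is a hypothetical name, not the Tianshou class, and the use of the population standard deviation is also an assumption.

```python
import statistics
from dataclasses import dataclass
from typing import Sequence


@dataclass
class SummaryStatsSketch:
    """Hypothetical stand-in for SequenceSummaryStats: condenses per-step values."""

    mean: float
    std: float
    max: float
    min: float

    @classmethod
    def from_sequence(cls, values: Sequence[float]) -> "SummaryStatsSketch":
        if not values:
            raise ValueError("cannot summarize an empty sequence")
        return cls(
            mean=statistics.fmean(values),
            std=statistics.pstdev(values),  # population standard deviation
            max=max(values),  # the builtin max: class fields are not in method scope
            min=min(values),
        )


# e.g. discriminator losses recorded over several gradient steps of one learn() call
disc_losses = [1.2, 1.0, 0.8]
print(SummaryStatsSketch.from_sequence(disc_losses))
```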