class C51Policy(*, model: Module, optim: Optimizer, action_space: Discrete, discount_factor: float = 0.99, num_atoms: int = 51, v_min: float = -10.0, v_max: float = 10.0, estimation_step: int = 1, target_update_freq: int = 0, reward_normalization: bool = False, is_double: bool = True, clip_loss_grad: bool = False, observation_space: Space | None = None, lr_scheduler: LRScheduler | MultipleLRSchedulers | None = None)

Implementation of the Categorical Deep Q-Network (C51). arXiv:1707.06887.

  • model – a model following the rules (s_B -> action_values_BA)

  • optim – a torch.optim optimizer used to optimize the model.

  • discount_factor – the discount factor, in [0, 1].

  • num_atoms – the number of atoms in the support set of the value distribution. Defaults to 51.

  • v_min – the value of the smallest atom in the support set. Defaults to -10.0.

  • v_max – the value of the largest atom in the support set. Defaults to 10.0.

  • estimation_step – the number of steps to look ahead.

  • target_update_freq – the target network update frequency (0 if you do not use the target network).

  • reward_normalization – normalize the returns to Normal(0, 1). TODO: rename to return_normalization?

  • is_double – whether to use double DQN.

  • clip_loss_grad – clip the gradient of the loss in accordance with nature14236; this amounts to using the Huber loss instead of the MSE loss.

  • observation_space – the environment's observation space.

  • lr_scheduler – if not None, will be called in policy.update().
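
To illustrate how num_atoms, v_min, and v_max relate, the sketch below builds an evenly spaced support with torch.linspace. It assumes the atom layout described in the C51 paper; the names support and delta_z are illustrative and are not attributes confirmed by this page.

```python
import torch

# Default values taken from the class signature.
num_atoms, v_min, v_max = 51, -10.0, 10.0

# Assumed layout: evenly spaced atoms from v_min to v_max (inclusive).
support = torch.linspace(v_min, v_max, num_atoms)

# Distance between two neighbouring atoms.
delta_z = (v_max - v_min) / (num_atoms - 1)

print(support.shape, delta_z)
```

A wider [v_min, v_max] range with the same num_atoms gives a coarser resolution (larger delta_z), so the range is usually chosen to match the scale of the expected returns.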

See also

Please refer to DQNPolicy for a more detailed explanation.

compute_q_value(logits: Tensor, mask: ndarray | None) -> Tensor

Compute the Q-value based on the network’s raw output and action mask.
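
One plausible reading of this step, sketched under assumptions: if the network output is a per-action probability distribution over the atoms, the Q-value of an action is the expectation of that distribution over the support, and masked-out actions are pushed to negative infinity so they are never selected. The function name expected_q_value, the (batch, action, atom) layout, and the -inf masking are illustrative choices, not the library's confirmed implementation.

```python
import numpy as np
import torch


def expected_q_value(probs, support, mask=None):
    """Hypothetical helper: expected return per action, with optional action mask.

    probs:   (batch, actions, atoms) probabilities over the atoms
    support: (atoms,) atom values
    mask:    (batch, actions) boolean array, True where the action is allowed
    """
    q = (probs * support).sum(dim=-1)  # expectation over atoms -> (batch, actions)
    if mask is not None:
        allowed = torch.as_tensor(mask, dtype=torch.bool)
        q = q.masked_fill(~allowed, float("-inf"))  # forbidden actions can never win argmax
    return q


support = torch.linspace(-10.0, 10.0, 51)
probs = torch.softmax(torch.zeros(2, 3, 51), dim=-1)  # uniform distribution per action
mask = np.array([[True, False, True], [True, True, True]])
q = expected_q_value(probs, support, mask)
print(q.shape)
```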

learn(batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TC51TrainingStats

Update policy with a given batch of data.


Returns: a dataclass object containing the data to be logged (e.g., loss).
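
For context, the core of a C51 update in the cited paper is projecting the Bellman-shifted target distribution back onto the fixed support. The sketch below shows a one-step version of that projection; project_distribution and its arguments are hypothetical names, and this is not claimed to be what learn() does internally.

```python
import torch


def project_distribution(next_dist, rewards, dones, support, gamma):
    """Hypothetical helper: project r + gamma * z onto the fixed atoms.

    next_dist: (batch, atoms) probabilities of the next-state distribution
    rewards:   (batch,) rewards
    dones:     (batch,) 1.0 where the episode ended, else 0.0
    support:   (atoms,) evenly spaced atom values
    """
    v_min, v_max = support[0].item(), support[-1].item()
    delta_z = (v_max - v_min) / (support.numel() - 1)
    # Shift and shrink every atom, then clip to the support range: (batch, atoms)
    tz = rewards.unsqueeze(1) + gamma * (1.0 - dones.unsqueeze(1)) * support.unsqueeze(0)
    tz = tz.clamp(v_min, v_max)
    # Triangular weights between fixed atom i and shifted atom j: (batch, atoms, atoms)
    weights = (1 - (tz.unsqueeze(1) - support.view(1, -1, 1)).abs() / delta_z).clamp(0, 1)
    # Distribute the probability mass of each shifted atom onto its neighbours.
    return (weights * next_dist.unsqueeze(1)).sum(-1)


support = torch.linspace(-10.0, 10.0, 51)
next_dist = torch.softmax(torch.randn(4, 51), dim=-1)
rewards = torch.tensor([1.0, -0.5, 0.0, 2.0])
dones = torch.tensor([0.0, 0.0, 1.0, 0.0])
target = project_distribution(next_dist, rewards, dones, support, gamma=0.99)
print(target.sum(-1))  # each row should sum to approximately 1
```

A typical loss would then be the cross-entropy between this target and the predicted distribution for the action taken.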


Note

To distinguish the collecting, updating, and testing states, check the policy state via self.training and self.updating. Please refer to States for policy for a more detailed explanation.


Warning

If you use torch.distributions.Normal and torch.distributions.Categorical to calculate the log_prob, be careful about the shape: a Categorical distribution gives a “[batch_size]” shape, while a Normal distribution gives a “[batch_size, 1]” shape. PyTorch's automatic broadcasting in tensor arithmetic can silently amplify this mismatch.
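
The shape mismatch can be reproduced in a few lines. This is a minimal sketch with made-up parameters (a batch of 4, three discrete actions, a one-dimensional Gaussian); it only illustrates the broadcasting behaviour and is not taken from the library.

```python
import torch
from torch.distributions import Categorical, Normal

batch_size = 4

# Categorical over 3 actions: log_prob has shape (batch_size,)
cat = Categorical(logits=torch.zeros(batch_size, 3))
cat_lp = cat.log_prob(torch.zeros(batch_size, dtype=torch.long))

# Normal with parameters of shape (batch_size, 1): log_prob has shape (batch_size, 1)
normal = Normal(torch.zeros(batch_size, 1), torch.ones(batch_size, 1))
norm_lp = normal.log_prob(torch.zeros(batch_size, 1))

# Broadcasting (4,) against (4, 1) silently produces a (4, 4) matrix.
bad = cat_lp * norm_lp

# Aligning the shapes first keeps one value per sample.
good = cat_lp * norm_lp.squeeze(-1)

print(cat_lp.shape, norm_lp.shape, bad.shape, good.shape)
```

Because no error is raised, a later mean() or sum() over the (4, 4) result still returns a scalar, which is why checking shapes explicitly before combining log-probabilities is worthwhile.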

class C51TrainingStats(*, train_time: float = 0.0, smoothed_loss: dict = <factory>, loss: float)