class RainbowPolicy(*, model: Module, optim: Optimizer, action_space: Discrete, discount_factor: float = 0.99, num_atoms: int = 51, v_min: float = -10.0, v_max: float = 10.0, estimation_step: int = 1, target_update_freq: int = 0, reward_normalization: bool = False, is_double: bool = True, clip_loss_grad: bool = False, observation_space: Space | None = None, lr_scheduler: LRScheduler | MultipleLRSchedulers | None = None)

Implementation of Rainbow DQN. arXiv:1710.02298.

Same parameters as C51Policy.

See also

Please refer to C51Policy for more detailed explanation.
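The distributional parameters inherited from C51Policy (num_atoms, v_min, v_max) define a fixed, evenly spaced support of return values, and a Q-value is the expectation of the predicted per-atom probabilities over that support. The NumPy sketch below illustrates that standard C51 relationship under stated assumptions: the network output is replaced by random logits, and q_values is a hypothetical helper, not part of the library's API.

```python
import numpy as np

# Names mirror the RainbowPolicy / C51Policy parameters; values are the signature defaults.
num_atoms, v_min, v_max = 51, -10.0, 10.0

# Fixed, evenly spaced support z_0 < ... < z_{num_atoms-1} on [v_min, v_max].
support = np.linspace(v_min, v_max, num_atoms)
delta_z = (v_max - v_min) / (num_atoms - 1)


def q_values(probs: np.ndarray) -> np.ndarray:
    """Hypothetical helper: expected return per action.

    probs has shape [batch_size, num_actions, num_atoms], and each
    probs[b, a, :] is a probability distribution over the support.
    """
    return (probs * support).sum(axis=-1)


# Stand-in for a network output: random logits turned into a softmax over atoms.
rng = np.random.default_rng(0)
logits = rng.normal(size=(2, 3, num_atoms))
probs = np.exp(logits) / np.exp(logits).sum(axis=-1, keepdims=True)

q = q_values(probs)         # shape [2, 3]: one expected return per action
greedy = q.argmax(axis=-1)  # greedy action for each batch element
```

Because every Q-value is a convex combination of the atoms, it always lies in [v_min, v_max]; returns outside that range cannot be represented, which is why v_min and v_max should bracket the task's plausible returns.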

learn(batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) → TRainbowTrainingStats

Update the policy with a given batch of data.

Returns

A dataclass object containing the data to be logged (e.g., loss).
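The core of a C51-style update, which Rainbow inherits, is projecting the Bellman-shifted target distribution back onto the fixed support and minimizing the cross-entropy to the predicted distribution. The following is a minimal NumPy sketch of that standard projection, assuming one-step returns (estimation_step=1) and omitting prioritized-replay weights, noisy networks, and target-network bookkeeping. project_target and cross_entropy_loss are hypothetical names; this is not the library's implementation.

```python
import numpy as np


def project_target(next_dist, rewards, dones, support, gamma=0.99):
    """Project r + gamma * z back onto the fixed support (one-step C51 target).

    next_dist: [batch_size, num_atoms] probabilities of the next-state action
        used for bootstrapping (with Double DQN, that action is chosen by the
        online network and its distribution is taken from the target network).
    rewards, dones: [batch_size]; dones is 1.0 where the episode ended.
    support: [num_atoms] evenly spaced atoms from v_min to v_max.
    """
    v_min, v_max = support[0], support[-1]
    delta_z = support[1] - support[0]
    # Shift every atom by the Bellman update; no bootstrapping after terminal steps.
    shifted = rewards[:, None] + gamma * (1.0 - dones[:, None]) * support[None, :]
    shifted = np.clip(shifted, v_min, v_max)
    # weight[b, i, j]: share of shifted atom j given to fixed atom i, i.e. linear
    # interpolation between the two fixed atoms that bracket it.
    weight = np.clip(
        1.0 - np.abs(shifted[:, None, :] - support[None, :, None]) / delta_z, 0.0, 1.0
    )
    return (weight * next_dist[:, None, :]).sum(axis=-1)


def cross_entropy_loss(target_dist, pred_dist, eps=1e-8):
    """Per-sample cross-entropy between the projected target and the prediction."""
    return -(target_dist * np.log(pred_dist + eps)).sum(axis=-1)


# Tiny usage example on the support [-1, 0, 1].
support = np.linspace(-1.0, 1.0, 3)
next_dist = np.array([[0.2, 0.5, 0.3], [0.2, 0.5, 0.3]])
rewards = np.array([0.5, 0.0])
dones = np.array([1.0, 0.0])
target = project_target(next_dist, rewards, dones, support, gamma=0.9)
# target[0] ≈ [0.0, 0.5, 0.5]: terminal step, all mass at 0.5, split between atoms 0 and 1
# target[1] ≈ [0.18, 0.55, 0.27]: shifted atoms [-0.9, 0.0, 0.9] interpolated onto the support
pred = np.full((2, 3), 1.0 / 3.0)  # hypothetical uniform prediction
loss = cross_entropy_loss(target, pred)
```

Splitting each shifted atom linearly between its two neighbours keeps every projected row a valid probability distribution, so the cross-entropy against the predicted distribution is well defined.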


Note

To distinguish between the collecting, updating, and testing states, check the policy's self.training and self.updating attributes. Please refer to States for policy for a more detailed explanation.
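As a rough illustration of the three states, the hypothetical helper below maps the two flags to a state name. It assumes the convention collecting = (training=True, updating=False), updating = (True, True), and testing = (False, False); States for policy remains the authoritative description.

```python
def policy_state(training: bool, updating: bool) -> str:
    """Hypothetical helper: map (self.training, self.updating) to a state name.

    Assumed convention: collecting = (True, False), updating = (True, True),
    testing = (False, False).
    """
    if not training:
        # Under the assumed convention, updating only happens during training.
        if updating:
            raise ValueError("updating=True is not expected while training=False")
        return "testing"
    return "updating" if updating else "collecting"
```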


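Warning

If you use torch.distributions.Normal or torch.distributions.Categorical to compute log_prob, be careful about the shape: a Categorical distribution gives a “[batch_size]” shape, while a Normal distribution over a one-dimensional action gives a “[batch_size, 1]” shape. Automatic broadcasting in torch tensor arithmetic does not flag the mismatch: an elementwise operation between a [batch_size] tensor and a [batch_size, 1] tensor silently produces a [batch_size, batch_size] result instead of raising an error.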
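The shape pitfall can be reproduced without torch, because NumPy follows the same broadcasting rules. In this sketch the arrays are hypothetical stand-ins for log-probabilities and a per-sample weight such as an advantage; the point is that the buggy product is a [batch_size, batch_size] matrix whose mean is still a plausible-looking scalar.

```python
import numpy as np

batch_size = 4
# Categorical-style log-probabilities: shape [batch_size].
cat_log_prob = np.array([-1.0, -2.0, -3.0, -4.0])
# Normal-style log-probabilities for a 1-D action: shape [batch_size, 1].
normal_log_prob = cat_log_prob.reshape(batch_size, 1)
# Hypothetical per-sample weight (e.g. an advantage): shape [batch_size].
weight = np.array([1.0, 0.0, 0.0, 0.0])

ok = cat_log_prob * weight  # [batch_size], as intended

# Bug: [batch_size, 1] * [batch_size] broadcasts to [batch_size, batch_size].
wrong = normal_log_prob * weight
# Fix: sum log-probabilities over the action axis first (for a 1-D action this
# just drops the trailing axis), so both operands have shape [batch_size].
right = normal_log_prob.sum(axis=-1) * weight

# right.mean() == -0.25, but wrong.mean() == -0.625: no error is raised.
```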

class RainbowTrainingStats(*, train_time: float = 0.0, smoothed_loss: dict = <factory>, loss: float)
loss: float