Source code: tianshou/policy/imitation/discrete_cql.py
- class DiscreteCQLTrainingStats(*, train_time: float = 0.0, smoothed_loss: dict = <factory>, loss: float, cql_loss: float, qr_loss: float)
- cql_loss: float
- qr_loss: float
- class DiscreteCQLPolicy(*, model: Module, optim: Optimizer, action_space: Discrete, min_q_weight: float = 10.0, discount_factor: float = 0.99, num_quantiles: int = 200, estimation_step: int = 1, target_update_freq: int = 0, reward_normalization: bool = False, is_double: bool = True, clip_loss_grad: bool = False, observation_space: Space | None = None, lr_scheduler: LRScheduler | MultipleLRSchedulers | None = None)
Implementation of the discrete Conservative Q-Learning (CQL) algorithm. arXiv:2006.04779.
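The conservative part of the objective can be illustrated with a minimal pure-Python sketch of the discrete CQL penalty from the cited paper: a soft maximum (logsumexp) over the Q-values of all actions is pushed down, while the Q-value of the action actually present in the dataset is pushed up. The helpers `logsumexp` and `cql_regularizer` are hypothetical, not part of Tianshou, and the sketch assumes each action's Q-value is already a single float (for a quantile-based model this would presumably be the mean over its quantiles).

```python
import math
from typing import Sequence


def logsumexp(values: Sequence[float]) -> float:
    """Numerically stable log(sum(exp(v))) over the Q-values of one state."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def cql_regularizer(q_values: Sequence[Sequence[float]], actions: Sequence[int]) -> float:
    """Batch-averaged discrete CQL penalty (hypothetical helper).

    q_values[b][a] is Q(s_b, a) and actions[b] is the action stored in the
    offline dataset for s_b. The penalty is
        mean_b[logsumexp_a Q(s_b, a)] - mean_b[Q(s_b, actions[b])],
    which is never negative and shrinks toward 0 as Q(s_b, actions[b])
    dominates the other actions of the same state.
    """
    batch_size = len(actions)
    soft_max_term = sum(logsumexp(row) for row in q_values) / batch_size
    data_term = sum(row[a] for row, a in zip(q_values, actions)) / batch_size
    return soft_max_term - data_term


# Two states with two actions each; the dataset took action 1, then action 0.
penalty = cql_regularizer([[1.0, 2.0], [0.5, 0.5]], [1, 0])
print(f"{penalty:.4f}")  # 0.5032
```

During training this penalty would be differentiated with respect to the network's Q-values; plain floats are used here only to make the arithmetic visible.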
- Parameters:
model – a model following the rules (s_B -> action_values_BA)
optim – a torch.optim for optimizing the model.
action_space – Env’s action space.
min_q_weight – the weight of the CQL loss term.
discount_factor – in [0, 1].
num_quantiles – the number of quantile midpoints in the inverse cumulative distribution function of the value.
estimation_step – the number of steps to look ahead.
target_update_freq – the target network update frequency (0 if you do not use the target network).
reward_normalization – normalize the returns to Normal(0, 1). TODO: rename to return_normalization?
is_double – whether to use double DQN.
clip_loss_grad – clip the gradient of the loss in accordance with nature14236; this amounts to using the Huber loss instead of the MSE loss.
observation_space – Env’s observation space.
lr_scheduler – if not None, will be called in policy.update().
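`num_quantiles` refers to quantile midpoints as used in QR-DQN. Assuming this policy follows the standard QR-DQN formulation, the midpoints are τ̂_i = (2i + 1) / (2N) for i = 0, 1, ..., N - 1, and each predicted quantile is trained with a quantile Huber loss with κ = 1. Implementations differ in how they average over the predicted and target quantile axes; this sketch follows the QR-DQN paper (sum over predicted quantiles, mean over target samples). All function names are hypothetical.

```python
from typing import List, Sequence


def quantile_midpoints(num_quantiles: int) -> List[float]:
    """tau_hat_i = (tau_i + tau_{i+1}) / 2 with tau_i = i / N, i.e. (2i + 1) / (2N)."""
    return [(2 * i + 1) / (2 * num_quantiles) for i in range(num_quantiles)]


def huber(u: float, kappa: float = 1.0) -> float:
    """Quadratic for |u| <= kappa and linear beyond, so the gradient is bounded by kappa."""
    if abs(u) <= kappa:
        return 0.5 * u * u
    return kappa * (abs(u) - 0.5 * kappa)


def quantile_huber_loss(current: Sequence[float], target: Sequence[float], kappa: float = 1.0) -> float:
    """QR-DQN loss for one (state, action) pair (hypothetical helper).

    current[i] is the predicted quantile at midpoint tau_hat_i; target[j] is a
    sample of the bootstrapped target distribution. Over-estimates (u < 0) are
    weighted by 1 - tau_hat_i and under-estimates by tau_hat_i, which pulls each
    current[i] toward the tau_hat_i-quantile of the target.
    """
    taus = quantile_midpoints(len(current))
    total = 0.0
    for tau, theta in zip(taus, current):
        for sample in target:
            u = sample - theta  # positive when the prediction is too low
            weight = abs(tau - (1.0 if u < 0 else 0.0))
            total += weight * huber(u, kappa)
    # Sum over predicted quantiles, mean over target samples.
    return total / len(target)


print(quantile_midpoints(4))  # [0.125, 0.375, 0.625, 0.875]
print(quantile_huber_loss([0.0, 1.0], [0.5, 1.5]))  # 0.203125
```

With the default num_quantiles=200, the midpoints run from 0.0025 to 0.9975 in steps of 0.005.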
See also
Please refer to
for more detailed explanation.
- learn(batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) → TDiscreteCQLTrainingStats
Update policy with a given batch of data.
- Returns:
A dataclass object, including the data needed to be logged (e.g., loss).
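The statistics returned by `learn` expose `loss`, `cql_loss` and `qr_loss` side by side. A plausible reading, consistent with CQL adding a weighted conservative penalty to an ordinary distributional TD loss, is that `loss` equals `qr_loss + min_q_weight * cql_loss`; the sketch below assumes exactly that. `CQLLossStats` and `combine_losses` are hypothetical names, and the default of 10.0 mirrors the `min_q_weight` default in the class signature.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class CQLLossStats:
    """Hypothetical stand-in for the loss fields of DiscreteCQLTrainingStats.

    The real class also carries train_time and smoothed_loss, omitted here.
    """

    loss: float
    cql_loss: float
    qr_loss: float


def combine_losses(qr_loss: float, cql_loss: float, min_q_weight: float = 10.0) -> CQLLossStats:
    """Assumed total objective: QR loss + min_q_weight * CQL penalty."""
    return CQLLossStats(
        loss=qr_loss + min_q_weight * cql_loss,
        cql_loss=cql_loss,
        qr_loss=qr_loss,
    )


stats = combine_losses(qr_loss=0.8, cql_loss=0.05)  # default min_q_weight = 10.0
print(stats)
```

Logging the two components separately makes it easy to see whether a large `min_q_weight` lets the conservative penalty dominate the QR loss.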
To distinguish the collecting, updating, and testing states, you can check the policy state by
. Please refer to States for policy for more detailed explanation.
Warning
If you use
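to calculate the log_prob, be careful about the shape: a Categorical distribution gives a “[batch_size]” shape, while a Normal distribution gives a “[batch_size, 1]” shape. Automatic broadcasting in numerical operations between torch tensors amplifies this mismatch instead of raising an error: combining a [batch_size] tensor with a [batch_size, 1] tensor silently yields a [batch_size, batch_size] result.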
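The shape pitfall described in this warning can be reproduced without PyTorch by applying the broadcasting rules that PyTorch shares with NumPy: align the shapes from the right, and require each pair of dimensions to be equal or 1. `broadcast_shape` is a hypothetical helper written for illustration, not a Tianshou or PyTorch function.

```python
from itertools import zip_longest
from typing import Tuple


def broadcast_shape(a: Tuple[int, ...], b: Tuple[int, ...]) -> Tuple[int, ...]:
    """Result shape of an element-wise op under NumPy/PyTorch broadcasting rules."""
    result = []
    # Compare trailing dimensions first; a missing leading dimension acts like 1.
    for x, y in zip_longest(reversed(a), reversed(b), fillvalue=1):
        if x == y or y == 1:
            result.append(x)
        elif x == 1:
            result.append(y)
        else:
            raise ValueError(f"shapes {a} and {b} are not broadcastable")
    return tuple(reversed(result))


batch_size = 4
categorical_log_prob = (batch_size,)  # shape of a Categorical log_prob
per_sample_column = (batch_size, 1)   # any per-sample quantity stored as a column
print(broadcast_shape(categorical_log_prob, per_sample_column))  # (4, 4)
```

Instead of the intended element-wise result of shape (4,), the product silently has shape (4, 4), so a following `.mean()` averages 16 cross terms. Reshaping one side first (for example, calling `Tensor.squeeze(-1)` on the `[batch_size, 1]` tensor) avoids this.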