discrete_cql
Source code: tianshou/policy/imitation/discrete_cql.py
- class DiscreteCQLTrainingStats(*, train_time: float = 0.0, smoothed_loss: dict = <factory>, loss: float, cql_loss: float, qr_loss: float)[source]
Bases: QRDQNTrainingStats
- cql_loss: float
- qr_loss: float
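A minimal sketch of this dataclass, assuming only the import path given by the source file above; the numeric values are placeholders, and in practice an instance with these fields is produced by learn()/update() rather than constructed by hand:

```python
# Illustrative only: build the stats dataclass with placeholder values to show
# which fields are available for logging (loss, qr_loss, cql_loss).
from tianshou.policy.imitation.discrete_cql import DiscreteCQLTrainingStats

stats = DiscreteCQLTrainingStats(loss=0.5, qr_loss=0.3, cql_loss=0.2)
print(stats.loss, stats.qr_loss, stats.cql_loss)
```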
- class DiscreteCQLPolicy(*, model: Module, optim: Optimizer, action_space: Discrete, min_q_weight: float = 10.0, discount_factor: float = 0.99, num_quantiles: int = 200, estimation_step: int = 1, target_update_freq: int = 0, reward_normalization: bool = False, is_double: bool = True, clip_loss_grad: bool = False, observation_space: Space | None = None, lr_scheduler: LRScheduler | MultipleLRSchedulers | None = None)[source]
Bases: QRDQNPolicy[TDiscreteCQLTrainingStats]
Implementation of discrete Conservative Q-Learning algorithm. arXiv:2006.04779.
- Parameters:
model – a model following the rules (s_B -> action_values_BA)
optim – a torch.optim for optimizing the model.
action_space – Env’s action space.
min_q_weight – the weight for the CQL loss.
discount_factor – in [0, 1].
num_quantiles – the number of quantile midpoints in the inverse cumulative distribution function of the value.
estimation_step – the number of steps to look ahead.
target_update_freq – the target network update frequency (0 if you do not use the target network).
reward_normalization – normalize the returns to Normal(0, 1). TODO: rename to return_normalization?
is_double – use double DQN.
clip_loss_grad – clip the gradient of the loss in accordance with nature14236; this amounts to using the Huber loss instead of the MSE loss.
observation_space – Env’s observation space.
lr_scheduler – if not None, will be called in policy.update().
See also
Please refer to QRDQNPolicy for a more detailed explanation.
Initializes internal Module state, shared by both nn.Module and ScriptModule.
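A minimal construction sketch, not taken from the docs: the QuantileNet module and all hyperparameter values below are illustrative assumptions. Any nn.Module that maps observations to quantile values of shape (batch, n_actions, num_quantiles) and follows Tianshou's (obs, state, info) -> (logits, state) convention should be usable as the model.

```python
# Illustrative sketch: build a quantile network and wrap it in DiscreteCQLPolicy.
# QuantileNet and the hyperparameter values are assumptions for demonstration.
import gymnasium as gym
import torch
import torch.nn as nn
from tianshou.policy import DiscreteCQLPolicy


class QuantileNet(nn.Module):
    """Maps observations to (batch, n_actions, num_quantiles) quantile values."""

    def __init__(self, obs_dim: int, n_actions: int, num_quantiles: int) -> None:
        super().__init__()
        self.n_actions = n_actions
        self.num_quantiles = num_quantiles
        self.net = nn.Sequential(
            nn.Linear(obs_dim, 128),
            nn.ReLU(),
            nn.Linear(128, n_actions * num_quantiles),
        )

    def forward(self, obs, state=None, info=None):
        obs = torch.as_tensor(obs, dtype=torch.float32)
        logits = self.net(obs).view(-1, self.n_actions, self.num_quantiles)
        return logits, state


action_space = gym.spaces.Discrete(4)
net = QuantileNet(obs_dim=8, n_actions=int(action_space.n), num_quantiles=200)
optim = torch.optim.Adam(net.parameters(), lr=1e-3)

policy = DiscreteCQLPolicy(
    model=net,
    optim=optim,
    action_space=action_space,
    min_q_weight=10.0,       # weight of the conservative (CQL) regularizer
    num_quantiles=200,       # must match the network's quantile dimension
    target_update_freq=500,  # 0 would disable the target network
)
```

Since this is an offline algorithm, the policy is then typically trained by repeatedly sampling batches from a pre-collected replay buffer via policy.update() rather than by interacting with the environment.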
- learn(batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) → TDiscreteCQLTrainingStats [source]
Update policy with a given batch of data.
- Returns:
A dataclass object, including the data needed to be logged (e.g., loss).
Note
In order to distinguish the collecting state, updating state and testing state, you can check the policy state by self.training and self.updating. Please refer to States for policy for more detailed explanation.
Warning
If you use torch.distributions.Normal and torch.distributions.Categorical to calculate the log_prob, please be careful about the shape: the Categorical distribution gives a “[batch_size]” shape while the Normal distribution gives a “[batch_size, 1]” shape. The auto-broadcasting of numerical operations with torch tensors will amplify this error.
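To make the shape mismatch concrete, here is a standalone illustration using only torch; the batch size and distribution parameters are arbitrary:

```python
# Illustration of the log_prob shape difference described in the warning:
# Categorical yields [batch_size], Normal yields [batch_size, 1].
import torch
from torch.distributions import Categorical, Normal

batch_size = 4
cat = Categorical(logits=torch.zeros(batch_size, 3))
normal = Normal(torch.zeros(batch_size, 1), torch.ones(batch_size, 1))

cat_logp = cat.log_prob(cat.sample())           # shape: [4]
normal_logp = normal.log_prob(normal.sample())  # shape: [4, 1]
print(cat_logp.shape, normal_logp.shape)

# Broadcasting a [4] tensor against a [4, 1] tensor silently produces a
# [4, 4] tensor, which is how the mismatch gets amplified downstream.
print((cat_logp + normal_logp).shape)  # torch.Size([4, 4])
```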