class CQLPolicy(*, actor: ActorProb, actor_optim: Optimizer, critic: Module, critic_optim: Optimizer, action_space: Box, critic2: Module | None = None, critic2_optim: Optimizer | None = None, cql_alpha_lr: float = 0.0001, cql_weight: float = 1.0, tau: float = 0.005, gamma: float = 0.99, alpha: float | tuple[float, Tensor, Optimizer] = 0.2, temperature: float = 1.0, with_lagrange: bool = True, lagrange_threshold: float = 10.0, min_action: float = -1.0, max_action: float = 1.0, num_repeat_actions: int = 10, alpha_min: float = 0.0, alpha_max: float = 1000000.0, clip_grad: float = 1.0, calibrated: bool = True, device: str | device = 'cpu', estimation_step: int = 1, exploration_noise: BaseNoise | Literal['default'] | None = None, deterministic_eval: bool = True, action_scaling: bool = True, action_bound_method: Literal['clip'] | None = 'clip', observation_space: Space | None = None, lr_scheduler: LRScheduler | MultipleLRSchedulers | None = None)

Implementation of the Conservative Q-Learning (CQL) algorithm. arXiv:2006.04779.

Parameters:

  • actor – the actor network following the rules in BasePolicy. (s -> a)

  • actor_optim – The optimizer for actor network.

  • critic – The first critic network. (s, a -> Q(s, a))

  • critic_optim – The optimizer for the first critic network.

  • action_space – Env’s action space.

  • critic2 – The second critic network. (s, a -> Q(s, a)). If None, a deep copy of critic is used.

  • critic2_optim – The optimizer for the second critic network. If None, critic_optim is cloned and applied to critic2.parameters().

  • cql_alpha_lr – The learning rate of cql_log_alpha.

  • cql_weight – Weight of the conservative (CQL) regularization term in the critic loss.

  • tau – Parameter for soft update of the target network.

  • gamma – Discount factor, in [0, 1].

  • alpha – Entropy regularization coefficient or a tuple (target_entropy, log_alpha, alpha_optim) for automatic tuning.

  • temperature – Temperature applied to the Q-values inside the log-sum-exp of the CQL term.

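  • with_lagrange – Whether to use the Lagrange variant of CQL, in which the penalty multiplier cql_alpha is learned (with learning rate cql_alpha_lr) so that the conservative penalty is driven toward lagrange_threshold.

  • lagrange_threshold – The value of tau in CQL(Lagrange), i.e., the threshold on the conservative penalty. Not to be confused with the soft-update parameter tau.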

  • min_action – The minimum value of each dimension of action.

  • max_action – The maximum value of each dimension of action.

  • num_repeat_actions – The number of times the action is repeated when calculating log-sum-exp.

  • alpha_min – Lower bound for clipping cql_alpha.

  • alpha_max – Upper bound for clipping cql_alpha.

  • clip_grad – Maximum gradient norm used to clip gradients when updating the critic networks.

  • calibrated – Whether to calibrate Q-values as in the Cal-QL paper (arXiv:2303.05479). Useful for offline pre-training followed by online training, and has also been observed to achieve better results than vanilla CQL.

  • device – Which device to create this model on.

  • estimation_step – The number of steps to look ahead when computing n-step returns.

  • exploration_noise – Type of exploration noise.

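  • deterministic_eval – Whether to use the deterministic action (the mean of the Gaussian policy) instead of a sampled action during evaluation.

  • action_scaling – Whether to map actions from the range [-1, 1] to the range [action_space.low, action_space.high].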

  • action_bound_method – Method for action bounding. Only used if the action_space is continuous.

  • observation_space – Env’s Observation space.

  • lr_scheduler – A learning rate scheduler that adjusts the learning rate of the optimizer on each call to policy.update().

See also

Please refer to BasePolicy for a more detailed explanation.
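As a rough illustration of how cql_weight, temperature, lagrange_threshold, alpha_min and alpha_max interact, the conservative penalty can be sketched in NumPy as below. This is a simplified sketch, not this class's implementation: the helper names (logsumexp, cql_penalty, lagrange_scaled) and the input arrays are hypothetical, and q_sampled is assumed to hold Q-values of several sampled actions per state.

```python
import numpy as np


def logsumexp(x, axis):
    """Numerically stable log-sum-exp along one axis."""
    m = np.max(x, axis=axis, keepdims=True)
    return np.squeeze(m, axis=axis) + np.log(np.sum(np.exp(x - m), axis=axis))


def cql_penalty(q_sampled, q_data, cql_weight=1.0, temperature=1.0):
    """Conservative penalty: push down Q on sampled actions, push up Q on dataset actions.

    q_sampled: shape (batch, n_sampled_actions), Q-values of sampled actions (assumed input).
    q_data:    shape (batch,), Q-values of the actions stored in the dataset (assumed input).
    """
    lse = logsumexp(q_sampled / temperature, axis=1).mean() * temperature
    return cql_weight * (lse - q_data.mean())


def lagrange_scaled(penalty, cql_log_alpha, lagrange_threshold=10.0,
                    alpha_min=0.0, alpha_max=1e6):
    """Scale the gap between penalty and threshold by the clipped multiplier cql_alpha."""
    cql_alpha = np.clip(np.exp(cql_log_alpha), alpha_min, alpha_max)
    return cql_alpha * (penalty - lagrange_threshold)


q_sampled = np.zeros((4, 10))  # 4 states, 10 sampled actions each
q_data = np.zeros(4)
penalty = cql_penalty(q_sampled, q_data)
scaled = lagrange_scaled(penalty, cql_log_alpha=0.0)
```

With all Q-values equal, the penalty reduces to log(number of sampled actions), which makes the effect of the log-sum-exp easy to check by hand.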

actor_pred(obs: Tensor) -> tuple[Tensor, Tensor]
calc_actor_loss(obs: Tensor) -> tuple[Tensor, Tensor]
calc_pi_values(obs_pi: Tensor, obs_to_pred: Tensor) -> tuple[Tensor, Tensor]
calc_random_values(obs: Tensor, act: Tensor) -> tuple[Tensor, Tensor]
learn(batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TCQLTrainingStats

Update the policy with a given batch of data.

Returns:

A dataclass object, including the data needed to be logged (e.g., loss).

Note

To distinguish the collecting, updating, and testing states, check the policy state via self.training and self.updating. Please refer to States for policy for a more detailed explanation.

Warning

If you use torch.distributions.Normal and torch.distributions.Categorical to calculate log_prob, be careful about the shape: a Categorical distribution gives shape “[batch_size]”, while a Normal distribution gives shape “[batch_size, 1]”. Auto-broadcasting in numerical operations on torch tensors will silently amplify such a mismatch.
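The shape hazard can be reproduced without any distribution objects. The sketch below uses NumPy arrays as stand-ins for the tensors (an assumption made to keep the example dependency-light; PyTorch supports the same broadcasting semantics), and the variable names are illustrative only.

```python
import numpy as np

batch_size = 4

# Stand-in for a log_prob of shape [batch_size].
log_prob_flat = np.zeros(batch_size)
# Stand-in for a per-sample quantity of shape [batch_size, 1].
weight_col = np.ones((batch_size, 1))

# Mixing the two shapes broadcasts to [batch_size, batch_size] without any error.
wrong = log_prob_flat * weight_col

# Aligning the shapes first keeps one value per sample.
right = log_prob_flat.reshape(-1, 1) * weight_col

print(wrong.shape, right.shape)
```

A mean over the wrongly shaped result still yields a scalar, which is why the mistake tends to go unnoticed in a loss computation.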

process_buffer(buffer: TBuffer) -> TBuffer

If self.calibrated is True, adds calibration_returns to buffer._meta.
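As a hedged illustration, the sketch below assumes that calibration_returns holds the discounted return-to-go of each stored transition, in the spirit of the reference values used by Cal-QL. The function name discounted_returns and the flat reward/done arrays are hypothetical and do not reflect the buffer's actual layout.

```python
import numpy as np


def discounted_returns(rewards, dones, gamma=0.99):
    """Discounted return-to-go per transition; dones[i] marks the last step of an episode."""
    returns = np.zeros(len(rewards), dtype=float)
    running = 0.0
    # Walk backwards so each step can reuse the return of its successor.
    for i in reversed(range(len(rewards))):
        if dones[i]:
            running = 0.0  # do not leak returns across episode boundaries
        running = rewards[i] + gamma * running
        returns[i] = running
    return returns


# Two episodes stored back to back: steps 0-2 and steps 3-4.
rewards = np.array([1.0, 1.0, 1.0, 2.0, 2.0])
dones = np.array([False, False, True, False, True])
calibration_returns = discounted_returns(rewards, dones, gamma=0.5)
```

Computing these values once up front fits the offline setting, where the dataset does not change during training.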




process_fn(batch: RolloutBatchProtocol, buffer: ReplayBuffer, indices: ndarray) -> RolloutBatchProtocol

Pre-process the data from the provided replay buffer.

Meant to be overridden by subclasses. Typical usage is to add new keys to the batch, e.g., to add the value function of the next state. Used in update(), which is usually called repeatedly during training.

For modifying the replay buffer only once at the beginning (e.g., for offline learning) see process_buffer().

sync_weight() -> None

Soft-update the weight for the target network.
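A soft update moves each target parameter a fraction tau toward the corresponding online parameter. The sketch below shows the rule on plain NumPy arrays; the function name soft_update and the dict-of-arrays representation are assumptions for illustration, not this class's internals.

```python
import numpy as np


def soft_update(target, source, tau=0.005):
    """In place: target <- tau * source + (1 - tau) * target, for every parameter."""
    for name in target:
        target[name][...] = tau * source[name] + (1.0 - tau) * target[name]


source = {"w": np.ones(3)}   # stand-in for the online network's parameters
target = {"w": np.zeros(3)}  # stand-in for the target network's parameters
soft_update(target, source, tau=0.005)
```

With a small tau such as the default 0.005, the target network trails the online network slowly, which stabilizes the bootstrapped Q-targets.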

train(mode: bool = True) -> Self

Set the module in training mode, except for the target network.

class CQLTrainingStats(*, train_time: float = 0.0, smoothed_loss: dict = <factory>, actor_loss: float, critic1_loss: float, critic2_loss: float, alpha: float | None = None, alpha_loss: float | None = None, cql_alpha: float | None = None, cql_alpha_loss: float | None = None)

A data structure for storing loss statistics of the CQL learn step.

cql_alpha: float | None = None
cql_alpha_loss: float | None = None