bcq#
Source code: tianshou/policy/imitation/bcq.py
- class BCQPolicy(*, actor_perturbation: Module, actor_perturbation_optim: Optimizer, critic: Module, critic_optim: Optimizer, action_space: Space, vae: VAE, vae_optim: Optimizer, critic2: Module | None = None, critic2_optim: Optimizer | None = None, device: str | device = 'cpu', gamma: float = 0.99, tau: float = 0.005, lmbda: float = 0.75, forward_sampled_times: int = 100, num_sampled_action: int = 10, observation_space: Space | None = None, action_scaling: bool = False, action_bound_method: Literal['clip', 'tanh'] | None = 'clip', lr_scheduler: LRScheduler | MultipleLRSchedulers | None = None)[source]#
Implementation of the BCQ algorithm. arXiv:1812.02900.
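The gamma, lmbda and num_sampled_action parameters all enter the critic target used by BCQ. As described in arXiv:1812.02900, the target is a soft Clipped Double Q-learning estimate taken over candidate actions a_i that are decoded by the VAE at the next state and then perturbed by the (target) perturbation actor:

```latex
% BCQ critic target: \lambda corresponds to lmbda, \gamma to gamma, and the a_i are the
% num_sampled_action perturbed candidate actions at the next state s'.
y = r + \gamma \max_{a_i} \Big[
      \lambda \min_{j=1,2} Q_{\theta_j'}(s', a_i)
      + (1 - \lambda) \max_{j=1,2} Q_{\theta_j'}(s', a_i)
    \Big]
```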
- Parameters:
actor_perturbation – the perturbation network, which takes a state-action pair (s, a) and returns a slightly perturbed action.
actor_perturbation_optim – the optimizer for the actor perturbation network.
critic – the first critic network.
critic_optim – the optimizer for the first critic network.
critic2 – the second critic network.
critic2_optim – the optimizer for the second critic network.
vae – the conditional VAE network, which generates actions similar to those in the batch.
vae_optim – the optimizer for the VAE network.
device – which device to create this model on.
gamma – discount factor, in [0, 1].
tau – the parameter for the soft update of the target networks.
lmbda – the weighting parameter for Clipped Double Q-learning, in [0, 1].
forward_sampled_times – the number of candidate actions sampled in the forward function. The policy samples this many actions per observation and returns the one with the highest Q-value.
num_sampled_action – the number of actions sampled when computing the target Q-value. The algorithm decodes this many actions with the VAE and perturbs each of them before taking the target Q.
observation_space – Env’s observation space.
action_scaling – if True, scale the action from [-1, 1] to the range of action_space. Only used if the action_space is continuous.
action_bound_method – method to bound action to range [-1, 1]. Only used if the action_space is continuous.
lr_scheduler – if not None, will be called in policy.update().
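A construction sketch, modeled on Tianshou's offline BCQ example. The helper networks (MLP, Net, Perturbation, VAE, Critic from tianshou.utils.net) and all hyperparameters below are illustrative; their exact constructor arguments may differ slightly between Tianshou versions:

```python
# Sketch only: helper-network constructor arguments may vary across Tianshou versions.
import gymnasium as gym
import torch

from tianshou.policy import BCQPolicy
from tianshou.utils.net.common import MLP, Net
from tianshou.utils.net.continuous import VAE, Critic, Perturbation

device = "cuda" if torch.cuda.is_available() else "cpu"
env = gym.make("Pendulum-v1")
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.shape[0]
max_action = float(env.action_space.high[0])

# perturbation actor: takes (s, a) and returns a slightly perturbed action
net_a = MLP(input_dim=state_dim + action_dim, output_dim=action_dim,
            hidden_sizes=(256, 256), device=device)
actor = Perturbation(net_a, max_action=max_action, device=device, phi=0.05).to(device)
actor_optim = torch.optim.Adam(actor.parameters(), lr=1e-3)

# twin critics for Clipped Double Q-learning
net_c1 = Net(state_dim, action_dim, hidden_sizes=(256, 256), concat=True, device=device)
net_c2 = Net(state_dim, action_dim, hidden_sizes=(256, 256), concat=True, device=device)
critic1 = Critic(net_c1, device=device).to(device)
critic2 = Critic(net_c2, device=device).to(device)
critic1_optim = torch.optim.Adam(critic1.parameters(), lr=1e-3)
critic2_optim = torch.optim.Adam(critic2.parameters(), lr=1e-3)

# conditional VAE that imitates the actions found in the offline batch
latent_dim = action_dim * 2
vae_encoder = MLP(input_dim=state_dim + action_dim, hidden_sizes=(512, 512), device=device)
vae_decoder = MLP(input_dim=state_dim + latent_dim, output_dim=action_dim,
                  hidden_sizes=(512, 512), device=device)
vae = VAE(vae_encoder, vae_decoder, hidden_dim=512, latent_dim=latent_dim,
          max_action=max_action, device=device).to(device)
vae_optim = torch.optim.Adam(vae.parameters(), lr=1e-3)

policy = BCQPolicy(
    actor_perturbation=actor,
    actor_perturbation_optim=actor_optim,
    critic=critic1,
    critic_optim=critic1_optim,
    critic2=critic2,
    critic2_optim=critic2_optim,
    vae=vae,
    vae_optim=vae_optim,
    action_space=env.action_space,
    device=device,
    gamma=0.99,
    tau=0.005,
    lmbda=0.75,
)
```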
See also
Please refer to BasePolicy for a more detailed explanation.
- forward(batch: ObsBatchProtocol, state: dict | BatchProtocol | ndarray | None = None, **kwargs: Any) → ActBatchProtocol [source]#
Compute action over the given batch data.
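Conceptually, forward samples forward_sampled_times candidate actions from the VAE for each observation, perturbs them with the perturbation actor, and returns the candidate with the highest Q-value under the first critic. The standalone sketch below illustrates only that selection rule; decode, perturb and q1 are hypothetical placeholders for the policy's VAE decoder, perturbation actor and critic:

```python
import torch


def select_action(obs, decode, perturb, q1, forward_sampled_times: int = 100):
    """Sketch of BCQ-style action selection (not the actual BCQPolicy.forward)."""
    batch_size = obs.shape[0]
    # give every observation `forward_sampled_times` candidate actions
    obs_rep = obs.repeat_interleave(forward_sampled_times, dim=0)
    candidates = perturb(obs_rep, decode(obs_rep))  # VAE samples, then perturbation
    q = q1(obs_rep, candidates).reshape(batch_size, forward_sampled_times)
    best = q.argmax(dim=1)  # index of the best candidate per observation
    candidates = candidates.reshape(batch_size, forward_sampled_times, -1)
    return candidates[torch.arange(batch_size), best]


# toy stand-ins (action_dim = 2), purely to show the shapes involved
decode = lambda o: torch.tanh(torch.randn(o.shape[0], 2))
perturb = lambda o, a: (a + 0.05 * torch.randn_like(a)).clamp(-1.0, 1.0)
q1 = lambda o, a: o.sum(dim=1, keepdim=True) - a.pow(2).sum(dim=1, keepdim=True)
print(select_action(torch.randn(4, 3), decode, perturb, q1).shape)  # torch.Size([4, 2])
```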
- learn(batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) → TBCQTrainingStats [source]#
Update policy with a given batch of data.
- Returns:
A dataclass object, including the data needed to be logged (e.g., loss).
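In offline RL, learn is typically driven by repeatedly calling policy.update() on a replay buffer holding the pre-collected dataset (or by Tianshou's offline trainer, which wraps the same loop with evaluation and logging). A minimal sketch, assuming policy is the BCQPolicy built above and buffer is a tianshou.data.ReplayBuffer already filled with offline transitions (e.g. loaded via ReplayBuffer.load_hdf5):

```python
# Offline update loop sketch; `buffer` is assumed to hold the offline dataset.
policy.train()  # gradient updates are meant to run in training mode
for step in range(10_000):
    # update() samples a mini-batch from the buffer, calls learn(), and
    # returns the training-stats dataclass (losses etc.) for logging
    stats = policy.update(256, buffer)
```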
Note
In order to distinguish the collecting state, updating state and testing state, you can check the policy state by self.training and self.updating. Please refer to States for policy for a more detailed explanation.
Warning
If you use torch.distributions.Normal and torch.distributions.Categorical to calculate the log_prob, please be careful about the shape: the Categorical distribution gives a “[batch_size]” shape while the Normal distribution gives a “[batch_size, 1]” shape. The auto-broadcasting of numerical operations with torch tensors will amplify this error.
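A quick illustration of the shape mismatch described above (plain torch, unrelated to BCQPolicy itself):

```python
import torch
from torch.distributions import Categorical, Normal

cat = Categorical(probs=torch.ones(4, 3) / 3)       # batch of 4 categorical distributions
norm = Normal(torch.zeros(4, 1), torch.ones(4, 1))  # batch of 4 one-dimensional Gaussians

print(cat.log_prob(torch.zeros(4, dtype=torch.long)).shape)  # torch.Size([4])
print(norm.log_prob(torch.zeros(4, 1)).shape)                # torch.Size([4, 1])

# Mixing the two silently broadcasts to a (4, 4) matrix instead of raising an error:
diff = norm.log_prob(torch.zeros(4, 1)) - cat.log_prob(torch.zeros(4, dtype=torch.long))
print(diff.shape)  # torch.Size([4, 4])
```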