icm
Source code: tianshou/policy/modelbased/icm.py
- class ICMTrainingStats(wrapped_stats: TrainingStats, *, icm_loss: float, icm_forward_loss: float, icm_inverse_loss: float)
Bases: TrainingStatsWrapper
In this particular case, super().__init__() should be called LAST in the subclass's __init__(), after the subclass has set its own fields.
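The ordering rule above makes sense if the wrapper locks attribute assignment at the end of its own __init__. The following is a hypothetical, simplified illustration of that situation, not tianshou's TrainingStatsWrapper: Stats, StatsWrapper, and CuriosityStats are made-up names, and the "freeze at the end of __init__" behavior is an assumption chosen to show why a subclass must assign its own fields first.

```python
from typing import Any


class Stats:
    """Hypothetical stand-in for a base stats object with one public field."""

    def __init__(self, loss: float) -> None:
        self.loss = loss


class StatsWrapper(Stats):
    """Hypothetical, simplified wrapper (assumed behavior, not tianshou's code).

    Its __init__ copies the wrapped object's public field and then freezes
    attribute assignment: afterwards, writes to `loss` are mirrored into the
    wrapped object and any other attribute write raises AttributeError.
    """

    def __init__(self, wrapped: Stats) -> None:
        object.__setattr__(self, "_wrapped", wrapped)
        object.__setattr__(self, "loss", wrapped.loss)
        object.__setattr__(self, "_frozen", True)  # freeze as the very last step

    def __setattr__(self, name: str, value: Any) -> None:
        if not getattr(self, "_frozen", False):
            object.__setattr__(self, name, value)
        elif name == "loss":
            # keep the wrapper and the wrapped object in sync
            setattr(self._wrapped, name, value)
            object.__setattr__(self, name, value)
        else:
            raise AttributeError(f"cannot set {name!r} after __init__ has finished")


class CuriosityStats(StatsWrapper):
    """Hypothetical subclass that adds its own field."""

    def __init__(self, wrapped: Stats, *, icm_loss: float) -> None:
        self.icm_loss = icm_loss  # set own fields while the object is still unfrozen
        super().__init__(wrapped)  # freezes the object, so it must come last


stats = CuriosityStats(Stats(loss=0.5), icm_loss=0.1)
print(stats.loss, stats.icm_loss)
```

Swapping the two lines in CuriosityStats.__init__ would make the `self.icm_loss = ...` assignment hit the frozen object and raise AttributeError, which is the failure the note guards against.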
- class ICMPolicy(*, policy: BasePolicy[TTrainingStats], model: IntrinsicCuriosityModule, optim: Optimizer, lr_scale: float, reward_scale: float, forward_loss_weight: float, action_space: Space, observation_space: Space | None = None, action_scaling: bool = False, action_bound_method: Literal['clip', 'tanh'] | None = 'clip', lr_scheduler: LRScheduler | MultipleLRSchedulers | None = None)
Bases: BasePolicy[ICMTrainingStats]
Implementation of the Intrinsic Curiosity Module (ICM). arXiv:1705.05363.
- Parameters:
policy – a base policy to add ICM to.
model – the ICM model.
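optim – a torch.optim optimizer for the ICM model.
lr_scale – the scaling factor applied to the ICM loss during learning.
reward_scale – the scaling factor for the intrinsic curiosity reward.
forward_loss_weight – the weight of the forward model loss.
action_space – Env’s action space.
observation_space – Env’s observation space.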
action_scaling – if True, scale the action from [-1, 1] to the range of action_space. Only used if the action_space is continuous.
action_bound_method – method to bound action to range [-1, 1]. Only used if the action_space is continuous.
lr_scheduler – if not None, will be called in policy.update().
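In the ICM paper (arXiv:1705.05363), the curiosity module minimizes a weighted sum of the inverse model loss L_I and the forward model loss L_F, namely (1 - β)·L_I + β·L_F. The following is a minimal sketch of how forward_loss_weight and lr_scale could enter that objective, assuming forward_loss_weight plays the role of β and lr_scale is a plain multiplier on the result. The function name is hypothetical, and scalar floats stand in for torch loss tensors.

```python
def combined_icm_loss(
    inverse_loss: float,
    forward_loss: float,
    forward_loss_weight: float,
    lr_scale: float,
) -> float:
    """(1 - beta) * L_I + beta * L_F with beta = forward_loss_weight, times lr_scale."""
    if not 0.0 <= forward_loss_weight <= 1.0:
        raise ValueError("forward_loss_weight is expected to lie in [0, 1]")
    weighted = (1.0 - forward_loss_weight) * inverse_loss + forward_loss_weight * forward_loss
    # assumption: lr_scale acts as a plain multiplier on the ICM loss
    return lr_scale * weighted


# beta = 0.2 puts most of the weight on the inverse (action-prediction) loss
print(combined_icm_loss(inverse_loss=2.0, forward_loss=0.5, forward_loss_weight=0.2, lr_scale=0.1))
```

A small β keeps the features focused on what the agent's actions can influence (the inverse task), while the forward loss still supplies the prediction error used as curiosity.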
See also
Please refer to BasePolicy for a more detailed explanation.
Initializes internal Module state, shared by both nn.Module and ScriptModule.
- forward(batch: ObsBatchProtocol, state: dict | BatchProtocol | ndarray | None = None, **kwargs: Any) → ActBatchProtocol
Compute the action over the given batch data using the inner policy.
See also
Please refer to forward() for a more detailed explanation.
- exploration_noise(act: _TArrOrActBatch, batch: ObsBatchProtocol) → _TArrOrActBatch
Modify the action from policy.forward with exploration noise.
NOTE: this currently does not add any noise. Subclasses need to override it for it to have any effect.
- Parameters:
act – a data batch or numpy.ndarray which is the action taken by policy.forward.
batch – the input batch for policy.forward, kept for advanced usage.
- Returns:
the action, in the same form as the input act, with exploration noise added.
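Since the note says subclasses have to override this method, here is a minimal sketch of one possible override for a continuous action space: zero-mean Gaussian noise followed by clipping to [-1, 1]. The class name and the sigma and seed parameters are hypothetical, and the sketch only handles act as a numpy.ndarray, not as a data batch.

```python
import numpy as np


class GaussianExplorationPolicy:
    """Hypothetical policy showing one way to override exploration_noise."""

    def __init__(self, sigma: float = 0.1, seed: int = 0) -> None:
        self.sigma = sigma
        self.rng = np.random.default_rng(seed)

    def exploration_noise(self, act: np.ndarray, batch: object) -> np.ndarray:
        # `batch` is accepted to match the documented signature but not used here
        noisy = act + self.rng.normal(0.0, self.sigma, size=act.shape)
        # keep actions inside the normalized [-1, 1] range (before any action scaling)
        return np.clip(noisy, -1.0, 1.0)


policy = GaussianExplorationPolicy(sigma=0.2)
print(policy.exploration_noise(np.zeros((3, 2)), batch=None))
```

Clipping to [-1, 1] mirrors the normalized range that action_scaling maps from, so the noisy action stays valid for the later rescaling step.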
- process_fn(batch: RolloutBatchProtocol, buffer: ReplayBuffer, indices: ndarray) → RolloutBatchProtocol
Pre-process the data from the provided replay buffer.
Used in update(). Check out policy.process_fn for more information.
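In ICM, the intrinsic reward is the forward model's prediction error in feature space, r^i_t = (η/2)·||φ̂(s_{t+1}) - φ(s_{t+1})||² (arXiv:1705.05363). The following is a minimal sketch of adding such a bonus to the extrinsic reward during pre-processing. It assumes reward_scale plays the role of η and that features are NumPy arrays. The function name is hypothetical, and this is not tianshou's process_fn.

```python
import numpy as np


def add_curiosity_bonus(
    rew: np.ndarray,
    phi_next: np.ndarray,
    phi_next_pred: np.ndarray,
    reward_scale: float,
) -> np.ndarray:
    """Return extrinsic reward plus a scaled forward-model prediction error.

    rew has shape [batch_size]; phi_next and phi_next_pred have shape
    [batch_size, feature_dim] (true and predicted next-state features).
    """
    # per-sample forward error: 0.5 * ||phi_hat(s_{t+1}) - phi(s_{t+1})||^2
    forward_error = 0.5 * np.sum((phi_next_pred - phi_next) ** 2, axis=1)
    # assumption: reward_scale multiplies the bonus (the paper's eta)
    return rew + reward_scale * forward_error


rew = np.array([1.0, 0.0])
phi_next = np.array([[0.0, 0.0], [1.0, 1.0]])
phi_next_pred = np.array([[1.0, 0.0], [1.0, 3.0]])
print(add_curiosity_bonus(rew, phi_next, phi_next_pred, reward_scale=0.1))
```

Transitions the forward model predicts poorly get a larger bonus, which is what pushes the agent toward novel states.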
- post_process_fn(batch: BatchProtocol, buffer: ReplayBuffer, indices: ndarray) → None
Post-process the data from the provided replay buffer.
Typical usage is to update the sampling weight in prioritized experience replay. Used in update().
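Both hooks are documented as being used in update(). The following is a simplified sketch of the call order, assuming update() samples a batch, runs process_fn, then learn, then post_process_fn. ToyBuffer, ToyPolicy, and this update function are hypothetical stand-ins. The sketch ignores learning-rate schedulers, the training and updating flags, and tianshou's batch types.

```python
import random
from typing import Dict, List, Tuple


class ToyBuffer:
    """Hypothetical replay buffer holding a list of rewards."""

    def __init__(self, rewards: List[float]) -> None:
        self.rewards = rewards

    def sample(self, size: int, seed: int = 0) -> Tuple[Dict[str, List[float]], List[int]]:
        indices = random.Random(seed).sample(range(len(self.rewards)), size)
        return {"rew": [self.rewards[i] for i in indices]}, indices


class ToyPolicy:
    """Hypothetical policy that records the order in which the hooks run."""

    def __init__(self) -> None:
        self.calls: List[str] = []

    def process_fn(self, batch, buffer, indices):
        self.calls.append("process_fn")  # e.g. compute returns or add bonus rewards
        return batch

    def learn(self, batch):
        self.calls.append("learn")  # e.g. a gradient step on the processed batch
        return {"loss": sum(batch["rew"]) / len(batch["rew"])}

    def post_process_fn(self, batch, buffer, indices):
        self.calls.append("post_process_fn")  # e.g. update PER sampling weights


def update(policy, buffer, sample_size):
    """Sketch of the hook order: pre-process, learn, then post-process."""
    batch, indices = buffer.sample(sample_size)
    batch = policy.process_fn(batch, buffer, indices)
    stats = policy.learn(batch)
    policy.post_process_fn(batch, buffer, indices)
    return stats


policy = ToyPolicy()
print(update(policy, ToyBuffer([1.0, 2.0, 3.0]), sample_size=3), policy.calls)
```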
- learn(batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) → ICMTrainingStats
Update policy with a given batch of data.
- Returns:
A dataclass object containing the data to be logged (e.g., loss).
Note
To distinguish the collecting, updating, and testing states, you can check the policy state via self.training and self.updating. Please refer to States for policy for a more detailed explanation.
Warning
If you use torch.distributions.Normal and torch.distributions.Categorical to calculate the log_prob, be careful about the shape: a Categorical distribution gives a “[batch_size]” shape, while a Normal distribution gives a “[batch_size, 1]” shape. The auto-broadcasting of numerical operations with torch tensors will amplify this error; for example, multiplying a [batch_size, 1] tensor by a [batch_size] tensor silently yields a [batch_size, batch_size] tensor.
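The shape pitfall in the warning can be reproduced without torch. NumPy follows the same broadcasting rules as PyTorch tensor arithmetic, so this sketch uses NumPy arrays as stand-ins for the two log_prob outputs. The zeros and the advantage values are placeholders.

```python
import numpy as np

batch_size = 4
advantage = np.arange(batch_size, dtype=float)  # shape [batch_size]

log_prob_cat = np.zeros(batch_size)  # like Categorical: [batch_size]
log_prob_normal = np.zeros((batch_size, 1))  # like Normal with 1-D actions: [batch_size, 1]

ok = log_prob_cat * advantage  # [batch_size], as intended
wrong = log_prob_normal * advantage  # silently broadcasts to [batch_size, batch_size]
fixed = log_prob_normal.sum(axis=-1) * advantage  # sum over the action dim: [batch_size]

print(ok.shape, wrong.shape, fixed.shape)  # prints (4,) (4, 4) (4,)
```

Summing (or squeezing) the trailing action dimension before combining with per-sample quantities keeps both distribution types on the same [batch_size] shape.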