base#
Source code: tianshou/policy/imitation/base.py
- class ImitationTrainingStats(*, train_time: float = 0.0, smoothed_loss: dict = <factory>, loss: float = 0.0)[source]#
Bases: TrainingStats
- loss: float = 0.0#
- class ImitationPolicy(*, actor: Module, optim: Optimizer, action_space: Space, observation_space: Space | None = None, action_scaling: bool = False, action_bound_method: Literal['clip', 'tanh'] | None = 'clip', lr_scheduler: LRScheduler | MultipleLRSchedulers | None = None)[source]#
Bases: BasePolicy[TImitationTrainingStats], Generic[TImitationTrainingStats]
Implementation of vanilla imitation learning.
- Parameters:
actor – a model following the rules in BasePolicy. (s -> a)
optim – for optimizing the model.
action_space – Env’s action_space.
observation_space – Env’s observation space.
action_scaling – if True, scale the action from [-1, 1] to the range of action_space. Only used if the action_space is continuous.
action_bound_method – method to bound action to range [-1, 1]. Only used if the action_space is continuous.
lr_scheduler – if not None, will be called in policy.update().
See also
Please refer to BasePolicy for more detailed explanation.
Initializes internal Module state, shared by both nn.Module and ScriptModule.
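For orientation, here is a minimal construction sketch; the environment, network class, and hyperparameters below are illustrative assumptions, not part of this reference.
import gymnasium as gym
import torch

from tianshou.policy import ImitationPolicy
from tianshou.utils.net.common import Net

# Illustrative setup: a CartPole task with a small MLP actor (s -> logits).
env = gym.make("CartPole-v1")
net = Net(
    state_shape=env.observation_space.shape,
    action_shape=env.action_space.n,
    hidden_sizes=[64, 64],
)
optim = torch.optim.Adam(net.parameters(), lr=1e-3)

policy = ImitationPolicy(
    actor=net,
    optim=optim,
    action_space=env.action_space,
)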
- forward(batch: ObsBatchProtocol, state: dict | BatchProtocol | ndarray | None = None, **kwargs: Any) → ModelOutputBatchProtocol [source]#
Compute action over the given batch data.
- Returns:
A Batch which MUST have the following keys:
act: a numpy.ndarray or a torch.Tensor, the action over given batch data.
state: a dict, a numpy.ndarray or a torch.Tensor, the internal state of the policy, None as default.
Other keys are user-defined. It depends on the algorithm. For example,
# some code
return Batch(logits=..., act=..., state=None, dist=...)
The keyword policy is reserved and the corresponding data will be stored into the replay buffer. For instance,
# some code
return Batch(..., policy=Batch(log_prob=dist.log_prob(act)))
# and in the sampled data batch, you can directly use
# batch.policy.log_prob to get your data.
Note
In continuous action space, you should do another step “map_action” to get the real action:
act = policy(batch).act  # doesn't map to the target action range
act = policy.map_action(act, batch)
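Continuing the construction sketch above (an assumption, not part of this docstring), a forward call on a small batch of observations might look like:
import numpy as np
from tianshou.data import Batch

# Illustrative inference call: 5 CartPole observations (4 features each).
obs = np.zeros((5, 4), dtype=np.float32)
result = policy(Batch(obs=obs, info={}))  # invokes forward()
print(result.act)    # actions for the 5 observations
print(result.state)  # internal state; None for a feed-forward actor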
- learn(batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) → TImitationTrainingStats [source]#
Update policy with a given batch of data.
- Returns:
A dataclass object, including the data needed to be logged (e.g., loss).
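As a hedged sketch of how these stats are typically consumed (the buffer contents and sample size below are assumptions):
from tianshou.data import ReplayBuffer

# Illustrative update step: `buffer` would normally hold expert demonstrations.
buffer = ReplayBuffer(size=1000)
# ... fill `buffer` with expert transitions before updating ...
stats = policy.update(sample_size=64, buffer=buffer)  # calls learn() internally
print(stats.loss)  # the imitation loss, ready for logging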
Note
In order to distinguish the collecting state, updating state and testing state, you can check the policy state by self.training and self.updating. Please refer to States for policy for more detailed explanation.
Warning
If you use torch.distributions.Normal and torch.distributions.Categorical to calculate the log_prob, please be careful about the shape: the Categorical distribution gives a “[batch_size]” shape while the Normal distribution gives a “[batch_size, 1]” shape. The auto-broadcasting of numerical operations with torch tensors will amplify this error.
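A short illustration of the shape mismatch described in the warning (tensor values are placeholders only):
import torch
from torch.distributions import Categorical, Normal

batch_size = 4
cat = Categorical(logits=torch.zeros(batch_size, 3))
print(cat.log_prob(cat.sample()).shape)    # torch.Size([4])

norm = Normal(torch.zeros(batch_size, 1), torch.ones(batch_size, 1))
print(norm.log_prob(norm.sample()).shape)  # torch.Size([4, 1])

# Mixing a [4] tensor with a [4, 1] tensor broadcasts to [4, 4],
# silently amplifying the error mentioned above.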