Source code for tianshou.policy.random
from typing import Any, TypeVar, cast
import numpy as np
from tianshou.data import Batch
from tianshou.data.batch import BatchProtocol
from tianshou.data.types import ActBatchProtocol, ObsBatchProtocol, RolloutBatchProtocol
from tianshou.policy import BasePolicy
from tianshou.policy.base import TrainingStats
[docs]
class RandomTrainingStats(TrainingStats):
    """Training statistics type associated with :class:`RandomPolicy`.

    Declares no fields or methods of its own; everything is inherited from
    :class:`TrainingStats`.
    """
# Type variable for the stats type a RandomPolicy is parameterized with;
# bounded so that only RandomTrainingStats or its subclasses are accepted.
TRandomTrainingStats = TypeVar("TRandomTrainingStats", bound=RandomTrainingStats)
[docs]
class RandomPolicy(BasePolicy[TRandomTrainingStats]):
    """A random agent used in multi-agent learning.

    It randomly chooses an action from the legal actions.
    """

    def forward(
        self,
        batch: ObsBatchProtocol,
        state: dict | BatchProtocol | np.ndarray | None = None,
        **kwargs: Any,
    ) -> ActBatchProtocol:
        """Compute the random action over the given batch data.

        The input should contain a mask in batch.obs, with "True" to be
        available and "False" to be unavailable. For example,
        ``batch.obs.mask == np.array([[False, True, False]])`` means with batch
        size 1, action "1" is available but action "0" and "2" are unavailable.

        Masks given as 0/1 integers are accepted as well; they are converted to
        booleans before use.

        :param batch: observation batch; only ``batch.obs.mask`` is read.
        :param state: unused by this policy.
        :param kwargs: unused by this policy.
        :return: A :class:`~tianshou.data.Batch` with "act" key, containing
            the random action.

        .. note::
            If no action is marked available in a row, every score in that row
            is ``-inf`` and ``argmax`` falls back to index 0.

        .. seealso::

            Please refer to :meth:`~tianshou.policy.BasePolicy.forward` for
            more detailed explanation.
        """
        # Coerce to a boolean array: on an integer mask ``~`` is a bitwise NOT
        # (0 -> -1, 1 -> -2), which would be used as fancy indices below rather
        # than as a boolean selection. For a bool ndarray this is a no-op.
        mask = np.asarray(batch.obs.mask, dtype=bool)  # type: ignore
        # Draw a uniform score per action, then exclude the unavailable ones so
        # that argmax can only pick among the available actions.
        logits = np.random.rand(*mask.shape)
        logits[~mask] = -np.inf
        result = Batch(act=logits.argmax(axis=-1))
        return cast(ActBatchProtocol, result)

    def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TRandomTrainingStats:  # type: ignore
        """Return empty training statistics, since a random agent learns nothing.

        :param batch: unused.
        :return: a default-constructed :class:`RandomTrainingStats`.
        """
        return RandomTrainingStats()  # type: ignore[return-value]