from collections.abc import Callable
from typing import Any, Generic, Literal, Protocol, Self, TypeVar, cast, overload
import numpy as np
from overrides import override
from sensai.util.helper import mark_used
from torch.nn import ModuleList
from tianshou.algorithm import Algorithm
from tianshou.algorithm.algorithm_base import (
OffPolicyAlgorithm,
OnPolicyAlgorithm,
Policy,
TrainingStats,
)
from tianshou.data import Batch, ReplayBuffer
from tianshou.data.batch import BatchProtocol, IndexType
from tianshou.data.types import ActBatchProtocol, ObsBatchProtocol, RolloutBatchProtocol
try:
from tianshou.env.pettingzoo_env import PettingZooEnv
except ImportError:
PettingZooEnv = None # type: ignore
mark_used(ActBatchProtocol)
[docs]
class MapTrainingStats(TrainingStats):
def __init__(
self,
agent_id_to_stats: dict[str | int, TrainingStats],
train_time_aggregator: Literal["min", "max", "mean"] = "max",
) -> None:
self._agent_id_to_stats = agent_id_to_stats
train_times = [agent_stats.train_time for agent_stats in agent_id_to_stats.values()]
match train_time_aggregator:
case "max":
aggr_function = max
case "min":
aggr_function = min
case "mean":
aggr_function = np.mean # type: ignore
case _:
raise ValueError(
f"Unknown {train_time_aggregator=}",
)
self.train_time = aggr_function(train_times)
self.smoothed_loss = {}
[docs]
@override
def get_loss_stats_dict(self) -> dict[str, float]:
"""Collects loss_stats_dicts from all agents, prepends agent_id to all keys, and joins results."""
result_dict = {}
for agent_id, stats in self._agent_id_to_stats.items():
agent_loss_stats_dict = stats.get_loss_stats_dict()
for k, v in agent_loss_stats_dict.items():
result_dict[f"{agent_id}/" + k] = v
return result_dict
[docs]
class MAPRolloutBatchProtocol(RolloutBatchProtocol, Protocol):
# TODO: this might not be entirely correct.
# The whole MAP data processing pipeline needs more documentation and possibly some refactoring
@overload
def __getitem__(self, index: str) -> RolloutBatchProtocol: ...
@overload
def __getitem__(self, index: IndexType) -> Self: ...
def __getitem__(self, index: str | IndexType) -> Any: ...
[docs]
class MultiAgentPolicy(Policy):
def __init__(self, policies: dict[str | int, Policy]):
p0 = next(iter(policies.values()))
super().__init__(
action_space=p0.action_space,
observation_space=p0.observation_space,
action_scaling=False,
action_bound_method=None,
)
self.policies = policies
self._submodules = ModuleList(policies.values())
_TArrOrActBatch = TypeVar("_TArrOrActBatch", bound="np.ndarray | ActBatchProtocol")
[docs]
def add_exploration_noise(
self,
act: _TArrOrActBatch,
batch: ObsBatchProtocol,
) -> _TArrOrActBatch:
"""Add exploration noise from sub-policy onto act."""
if not isinstance(batch.obs, Batch):
raise TypeError(
f"here only observations of type Batch are permitted, but got {type(batch.obs)}",
)
for agent_id, policy in self.policies.items():
agent_index = np.nonzero(batch.obs.agent_id == agent_id)[0]
if len(agent_index) == 0:
continue
act[agent_index] = policy.add_exploration_noise(act[agent_index], batch[agent_index])
return act
[docs]
def forward( # type: ignore
self,
batch: Batch,
state: dict | Batch | None = None,
**kwargs: Any,
) -> Batch:
"""Dispatch batch data from obs.agent_id to every policy's forward.
:param batch: TODO: document what is expected at input and make a BatchProtocol for it
:param state: if None, it means all agents have no state. If not
None, it should contain keys of "agent_1", "agent_2", ...
:return: a Batch with the following contents:
TODO: establish a BatcProtocol for this
::
{
"act": actions corresponding to the input
"state": {
"agent_1": output state of agent_1's policy for the state
"agent_2": xxx
...
"agent_n": xxx}
"out": {
"agent_1": output of agent_1's policy for the input
"agent_2": xxx
...
"agent_n": xxx}
}
"""
results: list[tuple[bool, np.ndarray, Batch, np.ndarray | Batch, Batch]] = []
for agent_id, policy in self.policies.items():
# This part of code is difficult to understand.
# Let's follow an example with two agents
# batch.obs.agent_id is [1, 2, 1, 2, 1, 2] (with batch_size == 6)
# each agent plays for three transitions
# agent_index for agent 1 is [0, 2, 4]
# agent_index for agent 2 is [1, 3, 5]
# we separate the transition of each agent according to agent_id
agent_index = np.nonzero(batch.obs.agent_id == agent_id)[0]
if len(agent_index) == 0:
# (has_data, agent_index, out, act, state)
results.append((False, np.array([-1]), Batch(), Batch(), Batch()))
continue
tmp_batch = batch[agent_index]
if "rew" in tmp_batch.get_keys() and isinstance(tmp_batch.rew, np.ndarray):
# reward can be empty Batch (after initial reset) or nparray.
tmp_batch.rew = tmp_batch.rew[:, self.agent_idx[agent_id]]
if not hasattr(tmp_batch.obs, "mask"):
if hasattr(tmp_batch.obs, "obs"):
tmp_batch.obs = tmp_batch.obs.obs
if hasattr(tmp_batch.obs_next, "obs"):
tmp_batch.obs_next = tmp_batch.obs_next.obs
out = policy(
batch=tmp_batch,
state=None if state is None else state[agent_id],
**kwargs,
)
act = out.act
each_state = out.state if (hasattr(out, "state") and out.state is not None) else Batch()
results.append((True, agent_index, out, act, each_state))
holder: Batch = Batch.cat(
[{"act": act} for (has_data, agent_index, out, act, each_state) in results if has_data],
)
state_dict, out_dict = {}, {}
for (agent_id, _), (has_data, agent_index, out, act, state) in zip(
self.policies.items(),
results,
strict=True,
):
if has_data:
holder.act[agent_index] = act
state_dict[agent_id] = state
out_dict[agent_id] = out
holder["out"] = out_dict
holder["state"] = state_dict
return holder
TAlgorithm = TypeVar("TAlgorithm", bound=Algorithm)
[docs]
class MARLDispatcher(Generic[TAlgorithm]):
"""
Supports multi-agent learning by dispatching calls to the corresponding
algorithm for each agent.
"""
def __init__(self, algorithms: list[TAlgorithm], env: PettingZooEnv):
agent_ids = env.agents
assert len(algorithms) == len(agent_ids), "One policy must be assigned for each agent."
self.algorithms: dict[str | int, TAlgorithm] = dict(zip(agent_ids, algorithms, strict=True))
"""maps agent_id to the corresponding algorithm."""
self.agent_idx = env.agent_idx
"""maps agent_id to 0-based index."""
[docs]
def create_policy(self) -> MultiAgentPolicy:
return MultiAgentPolicy({agent_id: a.policy for agent_id, a in self.algorithms.items()})
[docs]
def dispatch_process_fn(
self,
batch: MAPRolloutBatchProtocol,
buffer: ReplayBuffer,
indices: np.ndarray,
) -> MAPRolloutBatchProtocol:
"""Dispatch batch data from `obs.agent_id` to every algorithm's processing function.
Save original multi-dimensional rew in "save_rew", set rew to the
reward of each agent during their "process_fn", and restore the
original reward afterwards.
"""
# TODO: maybe only str is actually allowed as agent_id? See MAPRolloutBatchProtocol
results: dict[str | int, RolloutBatchProtocol] = {}
assert isinstance(
batch.obs,
BatchProtocol,
), f"here only observations of type Batch are permitted, but got {type(batch.obs)}"
# reward can be empty Batch (after initial reset) or nparray.
has_rew = isinstance(buffer.rew, np.ndarray)
if has_rew: # save the original reward in save_rew
# Since we do not override buffer.__setattr__, here we use _meta to
# change buffer.rew, otherwise buffer.rew = Batch() has no effect.
save_rew, buffer._meta.rew = buffer.rew, Batch() # type: ignore
for agent, algorithm in self.algorithms.items():
agent_index = np.nonzero(batch.obs.agent_id == agent)[0]
if len(agent_index) == 0:
results[agent] = cast(RolloutBatchProtocol, Batch())
continue
tmp_batch, tmp_indice = batch[agent_index], indices[agent_index]
if has_rew:
tmp_batch.rew = tmp_batch.rew[:, self.agent_idx[agent]]
buffer._meta.rew = save_rew[:, self.agent_idx[agent]]
if not hasattr(tmp_batch.obs, "mask"):
if hasattr(tmp_batch.obs, "obs"):
tmp_batch.obs = tmp_batch.obs.obs
if hasattr(tmp_batch.obs_next, "obs"):
tmp_batch.obs_next = tmp_batch.obs_next.obs
results[agent] = algorithm._preprocess_batch(tmp_batch, buffer, tmp_indice)
if has_rew: # restore from save_rew
buffer._meta.rew = save_rew
return cast(MAPRolloutBatchProtocol, Batch(results))
[docs]
def dispatch_update_with_batch(
self,
batch: MAPRolloutBatchProtocol,
algorithm_update_with_batch_fn: Callable[[TAlgorithm, RolloutBatchProtocol], TrainingStats],
) -> MapTrainingStats:
"""Dispatch the respective subset of the batch data to each algorithm.
:param batch: must map agent_ids to rollout batches
:param algorithm_update_with_batch_fn: a function that performs the algorithm-specific
update with the given agent-specific batch data
"""
agent_id_to_stats = {}
for agent_id, algorithm in self.algorithms.items():
data = batch[agent_id]
if len(data.get_keys()) != 0:
train_stats = algorithm_update_with_batch_fn(algorithm, data)
agent_id_to_stats[agent_id] = train_stats
return MapTrainingStats(agent_id_to_stats)
[docs]
class MultiAgentOffPolicyAlgorithm(OffPolicyAlgorithm[MultiAgentPolicy]):
"""Multi-agent reinforcement learning where each agent uses off-policy learning."""
def __init__(
self,
*,
algorithms: list[OffPolicyAlgorithm],
env: PettingZooEnv,
) -> None:
"""
:param algorithms: a list of off-policy algorithms.
:param env: the multi-agent RL environment
"""
self._dispatcher: MARLDispatcher[OffPolicyAlgorithm] = MARLDispatcher(algorithms, env)
super().__init__(
policy=self._dispatcher.create_policy(),
)
self._submodules = ModuleList(algorithms)
[docs]
def get_algorithm(self, agent_id: str | int) -> OffPolicyAlgorithm:
return self._dispatcher.algorithms[agent_id]
def _preprocess_batch(
self,
batch: RolloutBatchProtocol,
buffer: ReplayBuffer,
indices: np.ndarray,
) -> RolloutBatchProtocol:
batch = cast(MAPRolloutBatchProtocol, batch)
return self._dispatcher.dispatch_process_fn(batch, buffer, indices)
def _update_with_batch(
self,
batch: RolloutBatchProtocol,
) -> MapTrainingStats:
batch = cast(MAPRolloutBatchProtocol, batch)
def update(algorithm: OffPolicyAlgorithm, data: RolloutBatchProtocol) -> TrainingStats:
return algorithm._update_with_batch(data)
return self._dispatcher.dispatch_update_with_batch(batch, update)
[docs]
class MultiAgentOnPolicyAlgorithm(OnPolicyAlgorithm[MultiAgentPolicy]):
"""Multi-agent reinforcement learning where each agent uses on-policy learning."""
def __init__(
self,
*,
algorithms: list[OnPolicyAlgorithm],
env: PettingZooEnv,
) -> None:
"""
:param algorithms: a list of off-policy algorithms.
:param env: the multi-agent RL environment
"""
self._dispatcher: MARLDispatcher[OnPolicyAlgorithm] = MARLDispatcher(algorithms, env)
super().__init__(
policy=self._dispatcher.create_policy(),
)
self._submodules = ModuleList(algorithms)
[docs]
def get_algorithm(self, agent_id: str | int) -> OnPolicyAlgorithm:
return self._dispatcher.algorithms[agent_id]
def _preprocess_batch(
self,
batch: RolloutBatchProtocol,
buffer: ReplayBuffer,
indices: np.ndarray,
) -> RolloutBatchProtocol:
batch = cast(MAPRolloutBatchProtocol, batch)
return self._dispatcher.dispatch_process_fn(batch, buffer, indices)
def _update_with_batch(
self, batch: RolloutBatchProtocol, batch_size: int | None, repeat: int
) -> MapTrainingStats:
batch = cast(MAPRolloutBatchProtocol, batch)
def update(algorithm: OnPolicyAlgorithm, data: RolloutBatchProtocol) -> TrainingStats:
return algorithm._update_with_batch(data, batch_size, repeat)
return self._dispatcher.dispatch_update_with_batch(batch, update)