class MAPRolloutBatchProtocol(*args, **kwargs)
class MapTrainingStats(agent_id_to_stats: dict[str | int, TrainingStats], train_time_aggregator: Literal['min', 'max', 'mean'] = 'max')
get_loss_stats_dict() → dict[str, float]

Collects loss_stats_dicts from all agents, prepends agent_id to all keys, and joins results.
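The key-prefixing step described above can be sketched in plain Python. This is only an illustration, not the library's implementation: the function name `join_loss_stats` and the `/` separator between agent id and key are assumptions made for the example.

```python
from __future__ import annotations


def join_loss_stats(
    agent_id_to_stats: dict[str | int, dict[str, float]],
) -> dict[str, float]:
    """Flatten per-agent loss dicts into one dict with agent-prefixed keys."""
    joined: dict[str, float] = {}
    for agent_id, stats in agent_id_to_stats.items():
        for key, value in stats.items():
            # The "/" separator is an assumption made for this sketch.
            joined[f"{agent_id}/{key}"] = value
    return joined


example = join_loss_stats({"agent_1": {"loss": 0.5}, "agent_2": {"loss": 0.25}})
```

Prefixing keeps identically named statistics from different agents from overwriting each other once they share one dictionary.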

class MultiAgentPolicyManager(*, policies: list[BasePolicy], env: PettingZooEnv, action_scaling: bool = False, action_bound_method: Literal['clip', 'tanh'] | None = 'clip', lr_scheduler: LRScheduler | MultipleLRSchedulers | None = None)

Multi-agent policy manager for MARL.

This multi-agent policy manager accepts a list of BasePolicy instances. When “forward” is called, it dispatches the batch data to each of these policies. “process_fn” and “learn” work the same way: the manager splits the data and feeds each part to the corresponding policy. A figure in Multi-Agent Reinforcement Learning illustrates this procedure.

  • policies – a list of policies.

  • env – a PettingZooEnv.

  • action_scaling – if True, scale the action from [-1, 1] to the range of action_space. Only used if the action_space is continuous.

  • action_bound_method – method to bound action to range [-1, 1]. Only used if the action_space is continuous.

  • lr_scheduler – if not None, will be called in policy.update().
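To make the roles of action_bound_method and action_scaling concrete, here is a minimal sketch of bounding an action to [-1, 1] and then mapping it linearly onto an action range. It assumes a scalar `low` and `high` for every component and uses plain lists; the helper name `bound_and_scale` is hypothetical and the code is not taken from the library.

```python
import math


def bound_and_scale(action, low, high, bound_method="clip", scaling=True):
    """Bound each component to [-1, 1], then optionally rescale to [low, high]."""
    if bound_method == "clip":
        bounded = [max(-1.0, min(1.0, a)) for a in action]
    elif bound_method == "tanh":
        bounded = [math.tanh(a) for a in action]
    else:
        # bound_method is None: leave the raw action untouched.
        bounded = list(action)
    if not scaling:
        return bounded
    # Linear map from [-1, 1] to [low, high].
    return [low + (high - low) * (a + 1.0) / 2.0 for a in bounded]


scaled = bound_and_scale([0.0, 2.0, -1.0], low=0.0, high=10.0)
```

In this sketch the out-of-range value 2.0 is first clipped to 1.0 and then mapped to the upper end of the range.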

exploration_noise(act: _TArrOrActBatch, batch: ObsBatchProtocol) → _TArrOrActBatch

Add exploration noise from sub-policy onto act.

forward(batch: Batch, state: dict | Batch | None = None, **kwargs: Any) → Batch

Dispatch batch data to every policy’s forward, according to obs.agent_id.

  • batch – TODO: document what is expected at input and make a BatchProtocol for it

  • state – if None, no agent has a state. Otherwise, it should contain the keys “agent_1”, “agent_2”, …


Returns: a Batch with the following contents (TODO: establish a BatchProtocol for this):

    "act": actions corresponding to the input
    "state": {
        "agent_1": output state of agent_1's policy for the state
        "agent_2": xxx
        "agent_n": xxx}
    "out": {
        "agent_1": output of agent_1's policy for the input
        "agent_2": xxx
        "agent_n": xxx}
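The dispatch-and-merge pattern behind forward can be sketched as follows. This is a simplified illustration under stated assumptions: plain lists stand in for Batch, plain callables stand in for BasePolicy, recurrent state is omitted, and `dispatch_forward` is a hypothetical name rather than part of the library.

```python
from collections import defaultdict


def dispatch_forward(agent_ids, observations, policies):
    """Route each observation to its agent's policy and merge the actions."""
    # Group row indices by the agent that owns each observation.
    indices_by_agent = defaultdict(list)
    for index, agent_id in enumerate(agent_ids):
        indices_by_agent[agent_id].append(index)

    actions = [None] * len(observations)
    outputs = {}
    for agent_id, policy in policies.items():
        indices = indices_by_agent.get(agent_id, [])
        if not indices:
            continue  # this agent has no rows in the batch
        agent_actions = policy([observations[i] for i in indices])
        outputs[agent_id] = agent_actions
        # Write the agent's actions back to their original positions.
        for i, action in zip(indices, agent_actions):
            actions[i] = action
    return {"act": actions, "out": outputs}


# Hypothetical stand-in policies for the example.
policies = {
    "agent_1": lambda obs: [o + 1 for o in obs],
    "agent_2": lambda obs: [o * 10 for o in obs],
}
result = dispatch_forward(["agent_1", "agent_2", "agent_1"], [1, 2, 3], policies)
```

The merged "act" list keeps the original row order, while "out" stays keyed by agent, mirroring the structure listed above.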
learn(batch: MAPRolloutBatchProtocol, *args: Any, **kwargs: Any) → MapTrainingStats

Dispatch the data to all policies for learning.


batch – must map agent_ids to rollout batches

policies: dict[str | int, BasePolicy]

Maps agent_id to policy.

process_fn(batch: MAPRolloutBatchProtocol, buffer: ReplayBuffer, indice: ndarray) → MAPRolloutBatchProtocol

Dispatch batch data from obs.agent_id to every policy’s process_fn.

Save original multi-dimensional rew in “save_rew”, set rew to the reward of each agent during their “process_fn”, and restore the original reward afterwards.
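The save, swap, and restore sequence for the reward can be sketched as below. The sketch assumes a dict-based batch whose "rew" entry holds one column per agent, in the same order as the per-agent functions; `process_per_agent` and its arguments are hypothetical, and the real method also works with the replay buffer and indices.

```python
def process_per_agent(batch, agent_process_fns):
    """Give each agent's process function a view with only its own reward."""
    # Keep the original multi-dimensional reward so it can be restored.
    batch["save_rew"] = batch["rew"]
    results = {}
    try:
        for column, (agent_id, process) in enumerate(agent_process_fns.items()):
            # Assumption: reward column i belongs to the i-th agent.
            batch["rew"] = [row[column] for row in batch["save_rew"]]
            results[agent_id] = process(batch)
    finally:
        # Restore the original reward even if a process function fails.
        batch["rew"] = batch.pop("save_rew")
    return results


batch = {"rew": [[1.0, -1.0], [0.5, -0.5]]}
fns = {
    "agent_1": lambda b: sum(b["rew"]),
    "agent_2": lambda b: sum(b["rew"]),
}
returns = process_per_agent(batch, fns)
```

Restoring inside `finally` is a design choice of this sketch: it ensures the batch is left unchanged for later consumers.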

replace_policy(policy: BasePolicy, agent_id: int) → None

Replace the “agent_id”th policy in this manager.

train(mode: bool = True) → Self

Set each internal policy in training mode.