Multi-Agent Reinforcement Learning (MARL)

This tutorial demonstrates how to use Tianshou for multi-agent reinforcement learning scenarios. We’ll explore different MARL paradigms and implement a practical example using the Tic-Tac-Toe game.

MARL Paradigms

Tianshou supports three fundamental multi-agent reinforcement learning paradigms:

  1. Simultaneous move: All agents take their actions at each timestep simultaneously (e.g., MOBA games)

  2. Cyclic move: Agents take actions sequentially in turns (e.g., Go)

  3. Conditional move: The environment conditionally selects which agent acts at each timestep (e.g., Pig Game)

Our approach addresses these multi-agent RL problems by converting them into traditional single-agent RL formulations.

Converting MARL to Single-Agent RL

Simultaneous Move

For simultaneous-move scenarios, the solution is straightforward: we add an extra num_agents dimension to the state, action, and reward tensors. No other modifications are necessary.
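For instance, with two agents and 4-dimensional observations (both sizes are made up for this sketch), the extra dimension amounts to stacking the per-agent arrays along a new axis. This is a plain NumPy illustration of the tensor shapes, not Tianshou code:

```python
import numpy as np

num_agents = 2  # hypothetical number of agents acting simultaneously
obs_dim = 4     # hypothetical size of each agent's observation vector

# Per-agent data for one timestep
per_agent_obs = [np.zeros(obs_dim), np.ones(obs_dim)]
per_agent_act = [1, 3]
per_agent_rew = [0.5, -0.5]

# Stack along a new num_agents axis; nothing else about the data changes
joint_state = np.stack(per_agent_obs)      # shape (num_agents, obs_dim)
joint_action = np.asarray(per_agent_act)   # shape (num_agents,)
joint_reward = np.asarray(per_agent_rew)   # shape (num_agents,)

# A batch of 8 such timesteps simply gains a leading batch axis
batch_state = np.stack([joint_state] * 8)  # shape (8, num_agents, obs_dim)

print(joint_state.shape, joint_action.shape, joint_reward.shape, batch_state.shape)
# (2, 4) (2,) (2,) (8, 2, 4)
```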

Cyclic and Conditional Move

Both cyclic and conditional move scenarios can be unified into a single framework. At each timestep, the environment selects an agent identified by agent_id to act. Since multiple agents are typically wrapped into a single object (the “abstract agent”), we pass the agent_id to this abstract agent, which then delegates the action to the appropriate specific agent.
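The delegation can be sketched as a small wrapper class. AbstractAgent and the two toy policies below are hypothetical names chosen for illustration only; Tianshou's own multi-agent classes appear later in this tutorial:

```python
from collections.abc import Callable, Sequence

# A sub-agent maps (state, mask) to an action index
SubAgent = Callable[[object, Sequence[bool]], int]


class AbstractAgent:
    """Hypothetical wrapper that holds several agents and delegates by agent_id."""

    def __init__(self, agents: dict[str, SubAgent]) -> None:
        self.agents = agents

    def act(self, state: object, agent_id: str, mask: Sequence[bool]) -> int:
        # Forward the decision to whichever agent the environment selected
        return self.agents[agent_id](state, mask)


def first_legal(state: object, mask: Sequence[bool]) -> int:
    # Toy policy: choose the lowest-index legal action
    return next(i for i, legal in enumerate(mask) if legal)


def last_legal(state: object, mask: Sequence[bool]) -> int:
    # Toy policy: choose the highest-index legal action
    return max(i for i, legal in enumerate(mask) if legal)


abstract_agent = AbstractAgent({"player_1": first_legal, "player_2": last_legal})
example_mask = [False, True, True, False]
print(abstract_agent.act(None, "player_1", example_mask))  # 1
print(abstract_agent.act(None, "player_2", example_mask))  # 2
```

The caller never needs to know which concrete agent is acting; it only passes along the agent_id it received from the environment.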

Additionally, in multi-agent RL, the set of legal actions often varies across timesteps (as in Go). Therefore, the environment must also provide a legal action mask to the abstract agent. This mask is a boolean array where True indicates available actions and False indicates illegal actions at the current timestep.
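A value-based agent typically honors the mask by excluding illegal actions before choosing greedily. The helper below is a hypothetical illustration of one common approach (setting the Q-values of illegal actions to negative infinity); it is not taken from Tianshou's source:

```python
import numpy as np


def masked_greedy_action(q_values: np.ndarray, mask: np.ndarray) -> int:
    """Return the index of the highest Q-value among the legal actions."""
    # Illegal actions are given -inf so that argmax can never pick them
    masked_q = np.where(mask, q_values, -np.inf)
    return int(np.argmax(masked_q))


q_values = np.array([0.9, 0.1, 0.5, 0.7])
legal_mask = np.array([False, True, True, False])
print(masked_greedy_action(q_values, legal_mask))  # 2 (action 0 scores higher but is illegal)
```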


[Figure: The abstract agent framework for multi-agent RL]

Unified Formulation

This architecture leads to the following formulation of multi-agent RL:

act = policy(state, agent_id, mask)
(next_state, next_agent_id, next_mask), reward = env.step(act)

By constructing an augmented state state_ = (state, agent_id, mask), we can reduce this to the standard single-agent RL formulation:

act = policy(state_)
next_state_, reward = env.step(act)
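To make the reduction concrete, here is a toy sketch built around a hypothetical two-player game (ToyTurnEnv, invented for this example) in which players alternately claim one of four empty cells. The environment returns the augmented state state_ = (state, agent_id, mask), so the interaction loop looks exactly like a single-agent loop, while the policy reads agent_id and mask out of state_:

```python
import numpy as np


class ToyTurnEnv:
    """Hypothetical two-player game: players alternately claim one of n empty cells."""

    def __init__(self, n_cells: int = 4) -> None:
        self.n_cells = n_cells
        self.reset()

    def reset(self):
        self.board = np.zeros(self.n_cells, dtype=int)  # 0 = empty, 1 or 2 = owner
        self.agent_id = 1
        return self._augmented_state()

    def _augmented_state(self):
        # state_ = (state, agent_id, mask)
        mask = self.board == 0
        return self.board.copy(), self.agent_id, mask

    def step(self, act: int):
        self.board[act] = self.agent_id
        reward = np.zeros(2)  # one entry per agent; always 0 in this toy game
        done = not (self.board == 0).any()
        self.agent_id = 2 if self.agent_id == 1 else 1  # cyclic move order
        return self._augmented_state(), reward, done


def toy_policy(state_) -> int:
    # Single-agent signature: everything it needs is inside the augmented state
    state, agent_id, mask = state_
    legal = np.flatnonzero(mask)
    return int(legal[0]) if agent_id == 1 else int(legal[-1])


toy_env = ToyTurnEnv()
toy_state_ = toy_env.reset()
toy_done = False
while not toy_done:
    act = toy_policy(toy_state_)
    toy_state_, toy_reward, toy_done = toy_env.step(act)

print(toy_state_[0].tolist())  # [1, 1, 2, 2]
```

Because toy_policy only ever sees state_, any single-agent interaction loop can drive it; the multi-agent structure lives entirely inside the augmented state.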

Following this principle, we’ll implement a Q-learning algorithm to play Tic-Tac-Toe against a random opponent.

PettingZoo Integration

Tianshou is compatible with PettingZoo environments for multi-agent RL. While it doesn’t provide specialized MARL algorithms, it offers a flexible framework that can be adapted to various MARL scenarios.

For comprehensive tutorials on using Tianshou with PettingZoo, refer to:

In this tutorial, we’ll demonstrate how to use Tianshou in a multi-agent setting where only one agent is trained while the other uses a fixed random policy. You can then use this as a blueprint to replace the random policy with another trainable agent.

Specifically, we’ll train an agent to play Tic-Tac-Toe against a random opponent:


[Figure: Tic-Tac-Toe game board]

Exploring the Tic-Tac-Toe Environment

The complete scripts are located in test/pettingzoo/. Tianshou provides the PettingZooEnv wrapper class that can wrap any PettingZoo environment. Let’s explore the 3×3 Tic-Tac-Toe environment provided by PettingZoo.

from pettingzoo.classic import tictactoe_v3  # the Tic-Tac-Toe environment

from tianshou.env import PettingZooEnv  # wrapper for PettingZoo environments

# Initialize the environment:
# the board has 3 rows and 3 columns (9 positions in total),
# players place 'X' and 'O' alternately, and the first player to get
# three marks in a row (horizontally, vertically, or diagonally) wins
env = PettingZooEnv(tictactoe_v3.env(render_mode="human"))
obs = env.reset()  # Gymnasium-style reset: returns an (observation, info) tuple
env.render()  # render the empty board

The output shows an empty 3×3 board:

board (step 0):
     |     |
  -  |  -  |  -
_____|_____|_____
     |     |
  -  |  -  |  -
_____|_____|_____
     |     |
  -  |  -  |  -
     |     |
# Examine the observation structure
print(obs)
({'agent_id': 'player_1', 'obs': array([[[0, 0],
        [0, 0],
        [0, 0]],

       [[0, 0],
        [0, 0],
        [0, 0]],

       [[0, 0],
        [0, 0],
        [0, 0]]], dtype=int8), 'mask': [True, True, True, True, True, True, True, True, True]}, {})

Understanding the Observation Space

The observation returned by the environment is a dictionary with three keys:

  • agent_id: The identifier of the currently acting agent (e.g., 'player_1' or 'player_2')

  • obs: The actual environment observation. For Tic-Tac-Toe, this is a numpy array with shape (3, 3, 2):

    • For player_1: The first 3×3 plane represents X placements, the second plane represents O placements

    • For player_2: The planes are swapped (O in first plane, X in second)

    • Each cell contains either 0 (empty/not placed) or 1 (mark placed)

  • mask: A boolean array indicating legal actions at the current timestep. For Tic-Tac-Toe, index i corresponds to position (i // 3, i % 3) on the board. If mask[i] == True, the player can place their mark at that position. Initially, all positions are available, so all mask values are True.
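The following sketch, written for this tutorial, decodes an observation back into board symbols and maps a mask index to its board position. It assumes, as described above, that the first plane always holds the acting player's own marks (X for player_1, O for player_2); the observation array is constructed by hand:

```python
import numpy as np


def render_from_obs(obs: np.ndarray, agent_id: str) -> list[list[str]]:
    """Turn a (3, 3, 2) Tic-Tac-Toe observation back into 'X' / 'O' / '-' cells."""
    # Plane 0 holds the acting player's own marks, plane 1 the opponent's
    own, other = ("X", "O") if agent_id == "player_1" else ("O", "X")
    board = [["-"] * 3 for _ in range(3)]
    for row in range(3):
        for col in range(3):
            if obs[row, col, 0]:
                board[row][col] = own
            elif obs[row, col, 1]:
                board[row][col] = other
    return board


def index_to_cell(i: int) -> tuple[int, int]:
    # Mask (and action) index i refers to board position (i // 3, i % 3)
    return i // 3, i % 3


# Hand-built observation for player_2 after player_1 has placed an X at index 0
example_obs = np.zeros((3, 3, 2), dtype=np.int8)
example_obs[0, 0, 1] = 1
print(render_from_obs(example_obs, "player_2")[0])  # ['X', '-', '-']
print(index_to_cell(7))  # (2, 1)
```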

Note: The mask representation is flexible and works for both discrete and continuous action spaces. While we use a boolean array here, you could also use action spaces like gymnasium.spaces.Discrete or gymnasium.spaces.Box to represent available actions.

Playing a Few Steps

Let’s play a couple of moves to understand the environment dynamics better.

import numpy as np

# Take an action (place mark at position 0 - top-left corner)
action = 0  # action can be an integer or a numpy array with one element
obs, reward, done, truncated, info = env.step(action)  # follows the Gymnasium API

print("Observation after first move:")
print(obs)

# Examine the reward structure
# Reward has two items (one for each player): 1 for win, -1 for loss, 0 otherwise
print(f"\nReward: {reward}")

# Check if the game is over
print(f"Done: {done}")

# Info is typically an empty dict in Tic-Tac-Toe but may contain useful information in other environments
print(f"Info: {info}")
Observation after first move:
{'agent_id': 'player_2', 'obs': array([[[0, 1],
        [0, 0],
        [0, 0]],

       [[0, 0],
        [0, 0],
        [0, 0]],

       [[0, 0],
        [0, 0],
        [0, 0]]], dtype=int8), 'mask': [False, True, True, True, True, True, True, True, True]}

Reward: [0, 0]
Done: False
Info: {}

Notice that after the first move:

  • The agent_id switches to 'player_2'

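  • The observation is now given from the perspective of player_2: the X in the top-left cell appears in the second plane (obs[0, 0] == [0, 1]) because the planes are swapped for player_2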

  • The mask now has False at index 0 (that position is occupied)

  • The reward is [0, 0] (no winner yet)

  • The game continues (done = False)
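Since a position is legal exactly when it holds no mark, the mask can be recomputed from the two observation planes. The helper below is our own consistency check, not part of the PettingZoo or Tianshou API; applied to a hand-built copy of the observation printed above, it reproduces the environment's mask:

```python
import numpy as np


def mask_from_obs(obs: np.ndarray) -> list[bool]:
    """Recompute the legal-action mask from a (3, 3, 2) observation."""
    # A cell is legal when neither plane has a mark there; row-major flattening
    # puts board position (row, col) at index 3 * row + col
    empty = obs.sum(axis=2) == 0
    return empty.flatten().tolist()


# Rebuild the observation printed above after the first move
obs_after_first_move = np.zeros((3, 3, 2), dtype=np.int8)
obs_after_first_move[0, 0, 1] = 1
print(mask_from_obs(obs_after_first_move))
# [False, True, True, True, True, True, True, True, True]
```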

Note: If we continue playing, the game terminates when only one empty position remains, rather than when the board is completely full. This is because a player with only one available position has no meaningful choice.

Random Agents

Now that we understand the environment, let’s start by watching two random agents play against each other.

Tianshou provides built-in classes for multi-agent learning. The key components are:

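  • MARLRandomDiscreteMaskedOffPolicyAlgorithm: an algorithm whose policy randomly chooses among the legal actions given by the mask

  • MultiAgentOffPolicyAlgorithm (labeled MultiAgentPolicyManager in the figure below): manages multiple agent algorithms and delegates each action to the appropriate agent based on agent_id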


[Figure: The relationship between MultiAgentPolicyManager and individual agent policies]

from tianshou.algorithm.multiagent.marl import MultiAgentOffPolicyAlgorithm
from tianshou.algorithm.random import MARLRandomDiscreteMaskedOffPolicyAlgorithm
from tianshou.data import Collector
from tianshou.env import DummyVectorEnv

# Create a multi-agent algorithm with two random agents
policy = MultiAgentOffPolicyAlgorithm(
    algorithms=[
        MARLRandomDiscreteMaskedOffPolicyAlgorithm(action_space=env.action_space),
        MARLRandomDiscreteMaskedOffPolicyAlgorithm(action_space=env.action_space),
    ],
    env=env,
)

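# Vectorize the environment for the collector; binding the result to a new name
# keeps `env` referring to the PettingZooEnv instance that the lambda returns
vector_env = DummyVectorEnv([lambda: env])

# Create a collector to gather trajectories
collector = Collector(policy, vector_env)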

# Collect and visualize one episode
result = collector.collect(n_episode=1, render=0.1, reset_before_collect=True)

You’ll see the game progress step by step. Here’s an example of the final moves:

     |     |
  X  |  X  |  -
_____|_____|_____
     |     |
  X  |  O  |  -
_____|_____|_____
     |     |
  O  |  -  |  -
     |     |

     |     |
  X  |  X  |  -
_____|_____|_____
     |     |
  X  |  O  |  -
_____|_____|_____
     |     |
  O  |  -  |  O
     |     |

     |     |
  X  |  X  |  X
_____|_____|_____
     |     |
  X  |  O  |  -
_____|_____|_____
     |     |
  O  |  -  |  O
     |     |

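Random agents perform poorly. In the game above, agent 1 (X) eventually wins, but only because agent 2 (O) plays randomly: on its last move, a smart agent 2 would have won immediately by placing an O at position (0, 2) (the top-right corner), completing the diagonal through (1, 1) and (2, 0) and blocking X's top row at the same time.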
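To check which moves would have won immediately in the game above, here is a small helper written for this tutorial (independent of Tianshou and PettingZoo). It lists the empty cells where placing a given mark completes a line of three, evaluated on the first board shown above, where O is to move:

```python
# All eight winning lines on a 3x3 board, as lists of (row, col) cells
LINES = (
    [[(r, c) for c in range(3)] for r in range(3)]  # rows
    + [[(r, c) for r in range(3)] for c in range(3)]  # columns
    + [[(i, i) for i in range(3)], [(i, 2 - i) for i in range(3)]]  # diagonals
)


def winning_moves(board: list[str], mark: str) -> list[tuple[int, int]]:
    """Empty cells where placing `mark` completes a line of three."""
    moves = []
    for r in range(3):
        for c in range(3):
            if board[r][c] != "-":
                continue  # occupied cells are not legal moves
            for line in LINES:
                # The line wins if every other cell in it already holds `mark`
                if (r, c) in line and all(
                    board[rr][cc] == mark or (rr, cc) == (r, c) for rr, cc in line
                ):
                    moves.append((r, c))
                    break
    return moves


# First board of the game shown above (O to move)
board_before_o_moves = ["XX-", "XO-", "O--"]
print(winning_moves(board_before_o_moves, "O"))  # [(0, 2)]
print(winning_moves(board_before_o_moves, "X"))  # [(0, 2)]
```

Both players share the same critical cell, so the random O move elsewhere handed the game to X.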

Training an Agent Against a Random Opponent

Now let’s train an intelligent agent! We’ll use a Deep Q-Network (DQN) to learn strong play against a random opponent.

Imports and Setup

First, let’s import all necessary modules:

import os
from copy import deepcopy
from functools import partial

import gymnasium
import numpy as np
import torch
from pettingzoo.classic import tictactoe_v3
from torch.utils.tensorboard import SummaryWriter

from tianshou.algorithm import (
    DQN,
    Algorithm,
    MARLRandomDiscreteMaskedOffPolicyAlgorithm,
    MultiAgentOffPolicyAlgorithm,
)
from tianshou.algorithm.algorithm_base import OffPolicyAlgorithm
from tianshou.algorithm.modelfree.dqn import DiscreteQLearningPolicy
from tianshou.algorithm.optim import AdamOptimizerFactory, OptimizerFactory
from tianshou.data import Collector, CollectStats, VectorReplayBuffer
from tianshou.data.stats import InfoStats
from tianshou.env import DummyVectorEnv
from tianshou.env.pettingzoo_env import PettingZooEnv
from tianshou.trainer import OffPolicyTrainerParams
from tianshou.utils import TensorboardLogger
from tianshou.utils.net.common import Net

Hyperparameters

Let’s define the hyperparameters for our training experiment directly (no argparse needed in notebooks!):

# Define hyperparameters
class Args:
    seed = 1626
    eps_test = 0.05
    eps_train = 0.1
    buffer_size = 20000
    lr = 1e-4
    gamma = 0.9  # A smaller gamma favors earlier wins
    n_step = 3
    target_update_freq = 320
    epoch = 50
    epoch_num_steps = 1000
    collection_step_num_env_steps = 10
    update_per_step = 0.1
    batch_size = 64
    hidden_sizes = [128, 128, 128, 128]  # noqa: RUF012
    num_train_envs = 10
    num_test_envs = 10
    logdir = "log"
    render = 0.1
    win_rate = 0.6  # Target mean test reward of the learner (+1 win, -1 loss, 0 draw); optimal policy can get ~0.7
    watch = False  # Set to True to skip training and watch pre-trained models
    agent_id = 2  # The learned agent plays as player 2
    resume_path = ""  # Path to pre-trained agent .pth file
    opponent_path = ""  # Path to pre-trained opponent .pth file
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model_save_path = None  # Optional override; otherwise save_best_fn writes to <logdir>/tic_tac_toe/dqn/policy.pth


args = Args()
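To see why a smaller gamma favors earlier wins, consider the discounted value of a +1 win that arrives on the learning agent's k-th move from now. The sketch below treats each of the learning agent's own moves as one discount step (a simplifying assumption) and ignores the n-step return used by the actual training:

```python
def discounted_win_value(gamma: float, moves_until_win: int) -> float:
    """Discounted value of a +1 reward received on the agent's k-th move from now."""
    # A reward k agent steps ahead is scaled by gamma ** (k - 1)
    return gamma ** (moves_until_win - 1)


for g in (0.9, 0.99):
    values = [round(discounted_win_value(g, k), 4) for k in (1, 2, 3)]
    print(g, values)
# 0.9 [1.0, 0.9, 0.81]
# 0.99 [1.0, 0.99, 0.9801]
```

With gamma = 0.9, postponing a win by one move costs 10% of its value; with 0.99 it costs only 1%, so the agent has far less incentive to win quickly.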

Agent Setup

The get_agents function creates and configures our agents:

  • Neural Network: We use Net, a multi-layer perceptron with ReLU activations

  • Learning Algorithm: A DiscreteQLearningPolicy combined with DQN for Q-learning updates

  • Opponent: Either a MARLRandomDiscreteMaskedOffPolicyAlgorithm that randomly chooses legal actions, or a pre-trained agent for self-play

Both agents are managed by MultiAgentOffPolicyAlgorithm, which:

  • Calls the correct agent based on agent_id in the observation

  • Dispatches data to each agent according to their agent_id

  • Makes each agent perceive the environment as a single-agent problem
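The dispatching step can be pictured as splitting a flat batch of transitions by agent_id. The sketch below is a plain NumPy illustration of that idea, not MultiAgentOffPolicyAlgorithm's actual code, and the batch (a game in which X wins along the top row) is made up. Note that it deliberately ignores one subtlety: player_2's -1 arrives on player_1's final step, so a naive split never shows player_2 its loss, and a real implementation needs some way to credit such rewards:

```python
import numpy as np

# Hypothetical flat batch of five transitions from one Tic-Tac-Toe game
agent_ids = np.array(["player_1", "player_2", "player_1", "player_2", "player_1"])
actions = np.array([0, 4, 1, 3, 2])  # X takes 0, 1, 2: the top row
rewards = np.array([[0, 0], [0, 0], [0, 0], [0, 0], [1, -1]])  # one column per player

players = ["player_1", "player_2"]
per_agent = {}
for col, player in enumerate(players):
    acted = agent_ids == player  # boolean mask of the steps where this player moved
    # Each agent receives only its own steps and its own reward column
    per_agent[player] = {"act": actions[acted], "rew": rewards[acted, col]}

print(per_agent["player_1"]["act"].tolist())  # [0, 1, 2]
print(per_agent["player_1"]["rew"].tolist())  # [0, 0, 1]
print(per_agent["player_2"]["rew"].tolist())  # [0, 0]
```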


[Figure: How MultiAgentOffPolicyAlgorithm coordinates agent algorithms]

def get_env(render_mode: str | None = None) -> PettingZooEnv:
    return PettingZooEnv(tictactoe_v3.env(render_mode=render_mode))


def get_agents(
    args,
    agent_learn: OffPolicyAlgorithm | None = None,
    agent_opponent: OffPolicyAlgorithm | None = None,
    optim: OptimizerFactory | None = None,
) -> tuple[MultiAgentOffPolicyAlgorithm, OptimizerFactory | None, list[str]]:
    """Create or load agents for training."""
    env = get_env()
    observation_space = (
        env.observation_space.spaces["observation"]
        if isinstance(env.observation_space, gymnasium.spaces.Dict)
        else env.observation_space
    )
    args.state_shape = observation_space.shape or int(observation_space.n)
    args.action_shape = env.action_space.shape or int(env.action_space.n)

    if agent_learn is None:
        # Create the neural network model
        net = Net(
            state_shape=args.state_shape,
            action_shape=args.action_shape,
            hidden_sizes=args.hidden_sizes,
        ).to(args.device)

        if optim is None:
            optim = AdamOptimizerFactory(lr=args.lr)

        # Create the epsilon-greedy Q-learning policy for the learning agent
        policy = DiscreteQLearningPolicy(
            model=net,
            action_space=env.action_space,
            eps_training=args.eps_train,
            eps_inference=args.eps_test,
        )

        # Wrap the policy in the DQN algorithm, which performs the Q-learning updates
        agent_learn = DQN(
            policy=policy,
            optim=optim,
            n_step_return_horizon=args.n_step,
            gamma=args.gamma,
            target_update_freq=args.target_update_freq,
        )

        if args.resume_path:
            agent_learn.load_state_dict(torch.load(args.resume_path, map_location=args.device))

    if agent_opponent is None:
        if args.opponent_path:
            # Load a pre-trained opponent for self-play
            agent_opponent = deepcopy(agent_learn)
            agent_opponent.load_state_dict(torch.load(args.opponent_path, map_location=args.device))
        else:
            # Use a random opponent
            agent_opponent = MARLRandomDiscreteMaskedOffPolicyAlgorithm(
                action_space=env.action_space
            )

    # Arrange agents based on which player position the learning agent takes
    if args.agent_id == 1:
        agents = [agent_learn, agent_opponent]
    else:
        agents = [agent_opponent, agent_learn]

    ma_algorithm = MultiAgentOffPolicyAlgorithm(algorithms=agents, env=env)
    return ma_algorithm, optim, env.agents

Training Loop

The training procedure follows the standard Tianshou workflow, similar to single-agent DQN training:

def train_agent(
    args,
    agent_learn: OffPolicyAlgorithm | None = None,
    agent_opponent: OffPolicyAlgorithm | None = None,
    optim: OptimizerFactory | None = None,
) -> tuple[InfoStats, OffPolicyAlgorithm]:
    """Train the agent using DQN."""
    # ======== Environment Setup =========
    train_envs = DummyVectorEnv([get_env for _ in range(args.num_train_envs)])
    test_envs = DummyVectorEnv([get_env for _ in range(args.num_test_envs)])

    # Set random seeds for reproducibility
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    train_envs.seed(args.seed)
    test_envs.seed(args.seed)

    # ======== Agent Setup =========
    marl_algorithm, optim, agents = get_agents(
        args,
        agent_learn=agent_learn,
        agent_opponent=agent_opponent,
        optim=optim,
    )

    # ======== Collector Setup =========
    training_collector = Collector[CollectStats](
        marl_algorithm,
        train_envs,
        VectorReplayBuffer(args.buffer_size, len(train_envs)),
        exploration_noise=True,
    )
    test_collector = Collector[CollectStats](marl_algorithm, test_envs, exploration_noise=True)

    # Collect initial random samples
    training_collector.reset()
    training_collector.collect(n_step=args.batch_size * args.num_train_envs)

    # ======== Logging Setup =========
    log_path = os.path.join(args.logdir, "tic_tac_toe", "dqn")
    writer = SummaryWriter(log_path)
    writer.add_text("args", str(args))
    logger = TensorboardLogger(writer)

    player_agent_id = agents[args.agent_id - 1]

    # ======== Callback Functions =========
    def save_best_fn(policy: Algorithm) -> None:
        """Save the best performing policy."""
        if hasattr(args, "model_save_path") and args.model_save_path:
            model_save_path = args.model_save_path
        else:
            model_save_path = os.path.join(args.logdir, "tic_tac_toe", "dqn", "policy.pth")
        torch.save(policy.get_algorithm(player_agent_id).state_dict(), model_save_path)

    def stop_fn(mean_rewards: float) -> bool:
        """Stop training when target win rate is achieved."""
        return mean_rewards >= args.win_rate

    def reward_metric(rews: np.ndarray) -> np.ndarray:
        """Extract the reward for our learning agent."""
        return rews[:, args.agent_id - 1]

    # ======== Trainer =========
    result = marl_algorithm.run_training(
        OffPolicyTrainerParams(
            training_collector=training_collector,
            test_collector=test_collector,
            max_epochs=args.epoch,
            epoch_num_steps=args.epoch_num_steps,
            collection_step_num_env_steps=args.collection_step_num_env_steps,
            test_step_num_episodes=args.num_test_envs,
            batch_size=args.batch_size,
            stop_fn=stop_fn,
            save_best_fn=save_best_fn,
            update_step_num_gradient_steps_per_sample=args.update_per_step,
            logger=logger,
            test_in_training=False,
            multi_agent_return_reduction=reward_metric,
            show_progress=False,
        )
    )

    return result, marl_algorithm.get_algorithm(player_agent_id)
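The reward_metric callback passed as multi_agent_return_reduction keeps one column of the returns. Assuming, as its indexing implies, that returns arrive as an array of shape (num_episodes, num_agents), the hypothetical numbers below show what the reduction produces, and why win_rate behaves as a threshold on mean reward rather than on the fraction of games won:

```python
import numpy as np

learning_agent_id = 2  # plays the role of args.agent_id (the learner is player 2)

# Hypothetical returns of 4 test episodes; column 0 = player_1, column 1 = player_2
episode_returns = np.array([[1, -1], [-1, 1], [0, 0], [-1, 1]])


def reward_metric(rews: np.ndarray) -> np.ndarray:
    # Same reduction as in train_agent: keep only the learning agent's column
    return rews[:, learning_agent_id - 1]


learner_returns = reward_metric(episode_returns)
print(learner_returns.tolist())  # [-1, 1, 0, 1]
print(learner_returns.mean())    # 0.25
```

Here the learner wins 2 of 4 games (a 50% win rate), yet its mean return is only 0.25, because the loss and the draw count as -1 and 0.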

Evaluation Function

This function allows us to watch a trained agent play:

def watch(
    args,
    agent_learn: OffPolicyAlgorithm | None = None,
    agent_opponent: OffPolicyAlgorithm | None = None,
) -> None:
    """Watch a pre-trained agent play."""
    env = DummyVectorEnv([partial(get_env, render_mode="human")])
    policy, optim, agents = get_agents(args, agent_learn=agent_learn, agent_opponent=agent_opponent)
    collector = Collector[CollectStats](policy, env, exploration_noise=True)
    result = collector.collect(n_episode=1, render=args.render, reset_before_collect=True)
    result.pprint_asdict()

Running the Training

Now let’s train the agent and watch it play!

# Train the agent
result, agent = train_agent(args)

# Watch the trained agent play
watch(args, agent)
Initial test step: test_reward: -0.800000 ± 0.600000, best_reward: -0.800000 ± 0.600000 in #0
Sequence has shape (1, 2), but only 1D sequences are supported. Stats will be computed from the flattened sequence. For computing stats for each dimension consider using the function `compute_dim_to_summary_stats`.
Sequence has shape (1, 2), but only 1D sequences are supported. Stats will be computed from the flattened sequence. For computing stats for each dimension consider using the function `compute_dim_to_summary_stats`.
Sequence has shape (10, 2), but only 1D sequences are supported. Stats will be computed from the flattened sequence. For computing stats for each dimension consider using the function `compute_dim_to_summary_stats`.
Epoch #1: test_reward: 0.900000 ± 0.300000, best_reward: 0.900000 ± 0.300000 in #1
CollectStats
----------------------------------------
{   'collect_speed': 0.49564255535677326,
    'collect_time': 12.10549807548523,
    'lens': array([6]),
    'lens_stat': {'max': 6.0, 'mean': 6.0, 'min': 6.0, 'std': 0.0},
    'n_collected_episodes': 1,
    'n_collected_steps': 6,
    'pred_dist_std_array': None,
    'pred_dist_std_array_stat': None,
    'returns': array([[-1.,  1.]]),
    'returns_stat': {'max': 1.0, 'mean': 0.0, 'min': -1.0, 'std': 1.0}}
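In the CollectStats output above, `returns` has one row per episode and one column per agent. `returns_stat` summarizes the flattened array, so its mean of 0.0 averages the first agent's -1 with the second agent's +1, even though one agent clearly won. The following is a minimal sketch in plain NumPy (not a Tianshou helper) that contrasts flattened and per-agent statistics for the values reported above:

```python
import numpy as np

# Per-episode returns as reported by CollectStats: shape (n_episodes, n_agents).
returns = np.array([[-1.0, 1.0]])

# Flattened statistics mix both agents' rewards (this reproduces returns_stat above).
flat = {
    "max": float(returns.max()),
    "mean": float(returns.mean()),
    "min": float(returns.min()),
    "std": float(returns.std()),
}

# Column-wise statistics keep each agent's returns separate.
per_agent = {
    "mean": returns.mean(axis=0),
    "std": returns.std(axis=0),
}

print(flat)               # {'max': 1.0, 'mean': 0.0, 'min': -1.0, 'std': 1.0}
print(per_agent["mean"])  # [-1.  1.]
```

When judging the learner alone, it is usually more informative to look at its own column than at the flattened summary.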

Training Results#

After training for less than a minute, you’ll see the agent play against the random opponent. Here’s an example game:

Example: Trained Agent vs Random Opponent
     |     |
  -  |  -  |  -
_____|_____|_____
     |     |
  -  |  -  |  X
_____|_____|_____
     |     |
  -  |  -  |  -
     |     |

     |     |
  -  |  -  |  -
_____|_____|_____
     |     |
  -  |  O  |  X
_____|_____|_____
     |     |
  -  |  -  |  -
     |     |

     |     |
  -  |  -  |  -
_____|_____|_____
     |     |
  X  |  O  |  X
_____|_____|_____
     |     |
  -  |  -  |  -
     |     |

     |     |
  -  |  O  |  -
_____|_____|_____
     |     |
  X  |  O  |  X
_____|_____|_____
     |     |
  -  |  -  |  -
     |     |

     |     |
  -  |  O  |  -
_____|_____|_____
     |     |
  X  |  O  |  X
_____|_____|_____
     |     |
  -  |  X  |  -
     |     |

     |     |
  O  |  O  |  -
_____|_____|_____
     |     |
  X  |  O  |  X
_____|_____|_____
     |     |
  -  |  X  |  -
     |     |

     |     |
  O  |  O  |  X
_____|_____|_____
     |     |
  X  |  O  |  X
_____|_____|_____
     |     |
  -  |  X  |  -
     |     |

     |     |
  O  |  O  |  X
_____|_____|_____
     |     |
  X  |  O  |  X
_____|_____|_____
     |     |
  -  |  X  |  O
     |     |

Final reward: 1.0, length: 8.0

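Notice that our trained agent plays as player 2 (O) and wins by completing the diagonal from the top-left to the bottom-right corner. Because the environment supplies a legal action mask, the agent never has to learn which moves are allowed. What it has learned through trial and error is how to win, namely that placing three O marks in a line ends the game in its favor.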
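The winning condition can be checked mechanically. The sketch below is plain Python and is not part of Tianshou or PettingZoo. It indexes the cells 0 to 8 row by row as they are displayed, which is an indexing convention chosen only for illustration, and applies the check to the final board of the example game:

```python
from __future__ import annotations

# Cells of a 3x3 board, indexed 0-8 row by row as the board is displayed.
LINES = [
    (0, 1, 2), (3, 4, 5), (6, 7, 8),  # rows
    (0, 3, 6), (1, 4, 7), (2, 5, 8),  # columns
    (0, 4, 8), (2, 4, 6),             # diagonals
]


def find_win(board: list[str]) -> tuple[str, tuple[int, int, int]] | None:
    """Return (mark, line) for the first line filled by a single mark, else None."""
    for a, b, c in LINES:
        if board[a] != "-" and board[a] == board[b] == board[c]:
            return board[a], (a, b, c)
    return None


# Final position of the example game, transcribed row by row ('-' = empty).
final_board = list("OOX" "XOX" "-XO")
print(find_win(final_board))  # ('O', (0, 4, 8)): the top-left to bottom-right diagonal
```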

Making the trained agent play against itself is straightforward. Try it as an exercise!
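In terms of the abstract agent described earlier, which looks up a sub-policy by `agent_id`, self-play amounts to registering the same policy under both IDs. The following framework-agnostic sketch illustrates only that idea. `AbstractAgent`, the policy signature, and the `player_1`/`player_2` IDs are assumptions for illustration and are not Tianshou's API. `trained_policy` is a placeholder that simply picks the first legal cell.

```python
from __future__ import annotations

import random
from typing import Callable, Sequence

# Hypothetical policy signature: (observation, legal-action mask) -> action index.
Policy = Callable[[Sequence[str], Sequence[bool]], int]


def random_policy(obs: Sequence[str], mask: Sequence[bool]) -> int:
    """Pick uniformly among the legal actions."""
    return random.choice([a for a, legal in enumerate(mask) if legal])


def trained_policy(obs: Sequence[str], mask: Sequence[bool]) -> int:
    """Placeholder for a learned policy: picks the first legal action."""
    return next(a for a, legal in enumerate(mask) if legal)


class AbstractAgent:
    """Hypothetical dispatcher: routes each step to the sub-policy for agent_id."""

    def __init__(self, policies: dict[str, Policy]) -> None:
        self.policies = policies

    def act(self, obs: Sequence[str], agent_id: str, mask: Sequence[bool]) -> int:
        return self.policies[agent_id](obs, mask)


# The setup in this tutorial: a random opponent as player 1, the learner as player 2.
vs_random = AbstractAgent({"player_1": random_policy, "player_2": trained_policy})

# Self-play: the very same policy object fills both seats.
self_play = AbstractAgent({"player_1": trained_policy, "player_2": trained_policy})

# Board before O's last move in the example game (row by row, '-' = empty).
obs = list("OOX" "XOX" "-X-")
mask = [cell == "-" for cell in obs]
print(self_play.act(obs, "player_2", mask))  # 6
```

Because both seats share one policy object, any update made while training one seat immediately affects the other. Keeping separate copies would instead require an explicit synchronization step.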

While the trained agent plays well against a random opponent, it’s still far from perfect play. The next step would be to implement self-play training, similar to AlphaZero, where the agent continuously improves by playing against increasingly stronger versions of itself.
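One simple way to approximate playing against "increasingly stronger versions" of the agent is to freeze periodic snapshots of the learner's parameters and sample opponents from them. The sketch below covers only this bookkeeping. `OpponentPool`, the snapshot interval, and the dict that stands in for network parameters are illustrative assumptions. The scheme is a generic opponent pool, not AlphaZero's exact procedure.

```python
from __future__ import annotations

import copy
import random
from collections import deque


class OpponentPool:
    """Hypothetical helper: keeps the most recent frozen snapshots of a learner."""

    def __init__(self, max_size: int = 5, seed: int = 0) -> None:
        # A bounded deque evicts the oldest snapshot automatically when full.
        self.snapshots: deque[dict] = deque(maxlen=max_size)
        self.rng = random.Random(seed)

    def add(self, params: dict) -> None:
        # Deep-copy so that later learner updates cannot alter a stored snapshot.
        self.snapshots.append(copy.deepcopy(params))

    def sample(self) -> dict:
        # Uniform sampling; weighting recent snapshots more heavily is a common alternative.
        return self.rng.choice(self.snapshots)


learner = {"version": 0}  # stand-in for the learner's parameters
pool = OpponentPool(max_size=2)
for iteration in range(1, 8):
    learner["version"] = iteration  # stand-in for one round of training
    if iteration % 2 == 0:          # freeze a snapshot every 2 iterations
        pool.add(learner)

print([s["version"] for s in pool.snapshots])  # [4, 6]
opponent = pool.sample()  # the next training games would be played against this snapshot
```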

Summary#

In this tutorial, we demonstrated how to use Tianshou for training a single agent in a multi-agent reinforcement learning setting. Key takeaways:

  1. MARL Paradigms: Tianshou supports simultaneous, cyclic, and conditional move scenarios

  2. Abstraction: Multi-agent problems can be converted to single-agent RL by augmenting the state with the acting agent's agent_id and the legal action mask

  3. PettingZoo Integration: PettingZoo environments can be used directly through the PettingZooEnv wrapper

  4. Algorithm Management: MultiAgentOffPolicyAlgorithm handles agent coordination and data distribution

  5. Flexible Framework: Easy to extend from single-agent training to more complex multi-agent scenarios
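Takeaway 2 can be made concrete. Once the state is augmented to (state, agent_id, mask), a value-based agent can rule out illegal moves by masking them before taking the argmax. The following is a minimal NumPy sketch with made-up Q-values. `masked_greedy_action`, the row-by-row cell indexing, and the `player_2` ID are illustrative assumptions and are not Tianshou's implementation.

```python
import numpy as np


def masked_greedy_action(q_values: np.ndarray, mask: np.ndarray) -> int:
    """Return the highest-valued action among those where mask is True."""
    # Illegal actions get -inf, so argmax can never pick them
    # (assumes at least one action is legal).
    return int(np.argmax(np.where(mask, q_values, -np.inf)))


# Board before O's last move in the example game, flattened row by row.
board = np.array(list("OOX" "XOX" "-X-"))
mask = board == "-"                          # legal action mask: empty cells only
augmented_state = (board, "player_2", mask)  # state_ = (state, agent_id, mask)

# Made-up Q-values for illustration; cell 4 scores highest but is occupied.
q_values = np.array([0.9, 0.8, 0.1, 0.2, 0.95, 0.3, 0.4, 0.0, 0.7])

state, agent_id, legal = augmented_state
print(int(np.argmax(q_values)))               # 4 (occupied, hence illegal)
print(masked_greedy_action(q_values, legal))  # 8
```

Setting masked entries to -inf rather than, say, zero guarantees that an illegal action is never chosen, even when every legal Q-value is negative.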

Tianshou provides a flexible and intuitive framework for reinforcement learning. Experiment with different architectures, training regimes, and opponent strategies to build even more capable agents!