Overview#
To begin, make sure Tianshou and the Gym environment are installed. These tutorials always track the latest version of Tianshou, since they also serve as tests for it. If you are on an older version of Tianshou, please consult the documentation for that version.
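Once installation is done, a quick sanity check is to confirm that the packages can be found by Python. The sketch below is not part of Tianshou: check_installed is a hypothetical helper name, and the package list (tianshou, gymnasium, torch) is simply taken from the imports of the script in the next section.

```python
from importlib import metadata, util


def check_installed(packages):
    """Map each package name to its installed version, or None if it cannot be imported."""
    versions = {}
    for name in packages:
        if util.find_spec(name) is None:
            # Python cannot locate the module at all.
            versions[name] = None
            continue
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            # Importable, but no distribution metadata under this name.
            versions[name] = "unknown"
    return versions


# Assumed package list, based on the imports used by the tutorial script.
for name, version in check_installed(["tianshou", "gymnasium", "torch"]).items():
    print(f"{name}: {version or 'NOT INSTALLED'}")
```

Any package reported as NOT INSTALLED needs to be installed before the script below can run.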
Run the code#
Below is a short script that uses a DRL algorithm (PPO) to solve the classic CartPole-v1 problem in Gym. Simply run it, and don’t worry if you can’t understand the code very well yet. That is exactly what this tutorial is for.
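If the script finishes normally, it prints the evaluation result at the end. Training stops early once the mean test reward reaches the stop_fn threshold of 195; in the run shown below, that happens after the third of at most ten epochs.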
%%capture
import gymnasium as gym
import torch
from tianshou.data import Collector, CollectStats, VectorReplayBuffer
from tianshou.env import DummyVectorEnv
from tianshou.policy import PPOPolicy
from tianshou.trainer import OnpolicyTrainer
from tianshou.utils.net.common import ActorCritic, Net
from tianshou.utils.net.discrete import Actor, Critic
device = "cuda" if torch.cuda.is_available() else "cpu"
# environments
env = gym.make("CartPole-v1")
train_envs = DummyVectorEnv([lambda: gym.make("CartPole-v1") for _ in range(20)])
test_envs = DummyVectorEnv([lambda: gym.make("CartPole-v1") for _ in range(10)])
# model & optimizer
assert env.observation_space.shape is not None # for mypy
net = Net(state_shape=env.observation_space.shape, hidden_sizes=[64, 64], device=device)
assert isinstance(env.action_space, gym.spaces.Discrete) # for mypy
actor = Actor(preprocess_net=net, action_shape=env.action_space.n, device=device).to(device)
critic = Critic(preprocess_net=net, device=device).to(device)
actor_critic = ActorCritic(actor, critic)
optim = torch.optim.Adam(actor_critic.parameters(), lr=0.0003)
# PPO policy
dist = torch.distributions.Categorical
policy: PPOPolicy = PPOPolicy(
    actor=actor,
    critic=critic,
    optim=optim,
    dist_fn=dist,
    action_space=env.action_space,
    action_scaling=False,
)
# collector
train_collector = Collector[CollectStats](
    policy,
    train_envs,
    VectorReplayBuffer(20000, len(train_envs)),
)
test_collector = Collector[CollectStats](policy, test_envs)
# trainer
train_result = OnpolicyTrainer(
    policy=policy,
    batch_size=256,
    train_collector=train_collector,
    test_collector=test_collector,
    max_epoch=10,
    step_per_epoch=50000,
    repeat_per_collect=10,
    episode_per_test=10,
    step_per_collect=2000,
    stop_fn=lambda mean_reward: mean_reward >= 195,
).run()
Epoch #1: 0%| | 0/50000 [00:00<?, ?it/s]
Epoch #1: 4%|4 | 2000/50000 [00:00<00:06, 6963.26it/s]
Epoch #1: 4%|4 | 2000/50000 [00:00<00:06, 6963.26it/s, env_episode=87, env_step=2000, gradient_step=8, len=20, n/ep=87, n/st=2000, rew=20.01]
Epoch #1: 8%|8 | 4000/50000 [00:00<00:06, 7219.01it/s, env_episode=87, env_step=2000, gradient_step=8, len=20, n/ep=87, n/st=2000, rew=20.01]
Epoch #1: 8%|8 | 4000/50000 [00:00<00:06, 7219.01it/s, env_episode=174, env_step=4000, gradient_step=16, len=19, n/ep=87, n/st=2000, rew=22.39]
Epoch #1: 12%|#2 | 6000/50000 [00:00<00:06, 7289.91it/s, env_episode=174, env_step=4000, gradient_step=16, len=19, n/ep=87, n/st=2000, rew=22.39]
Epoch #1: 12%|#2 | 6000/50000 [00:00<00:06, 7289.91it/s, env_episode=248, env_step=6000, gradient_step=24, len=21, n/ep=74, n/st=2000, rew=25.78]
Epoch #1: 16%|#6 | 8000/50000 [00:01<00:05, 7349.50it/s, env_episode=248, env_step=6000, gradient_step=24, len=21, n/ep=74, n/st=2000, rew=25.78]
Epoch #1: 16%|#6 | 8000/50000 [00:01<00:05, 7349.50it/s, env_episode=325, env_step=8000, gradient_step=32, len=20, n/ep=77, n/st=2000, rew=25.49]
Epoch #1: 20%|## | 10000/50000 [00:01<00:05, 7375.64it/s, env_episode=325, env_step=8000, gradient_step=32, len=20, n/ep=77, n/st=2000, rew=25.49]
Epoch #1: 20%|## | 10000/50000 [00:01<00:05, 7375.64it/s, env_episode=403, env_step=10000, gradient_step=40, len=21, n/ep=78, n/st=2000, rew=26.86]
Epoch #1: 24%|##4 | 12000/50000 [00:01<00:05, 7389.62it/s, env_episode=403, env_step=10000, gradient_step=40, len=21, n/ep=78, n/st=2000, rew=26.86]
Epoch #1: 24%|##4 | 12000/50000 [00:01<00:05, 7389.62it/s, env_episode=474, env_step=12000, gradient_step=48, len=22, n/ep=71, n/st=2000, rew=27.14]
Epoch #1: 28%|##8 | 14000/50000 [00:01<00:04, 7401.46it/s, env_episode=474, env_step=12000, gradient_step=48, len=22, n/ep=71, n/st=2000, rew=27.14]
Epoch #1: 28%|##8 | 14000/50000 [00:01<00:04, 7401.46it/s, env_episode=540, env_step=14000, gradient_step=56, len=24, n/ep=66, n/st=2000, rew=30.73]
Epoch #1: 32%|###2 | 16000/50000 [00:02<00:04, 7433.86it/s, env_episode=540, env_step=14000, gradient_step=56, len=24, n/ep=66, n/st=2000, rew=30.73]
Epoch #1: 32%|###2 | 16000/50000 [00:02<00:04, 7433.86it/s, env_episode=600, env_step=16000, gradient_step=64, len=27, n/ep=60, n/st=2000, rew=33.57]
Epoch #1: 36%|###6 | 18000/50000 [00:02<00:04, 7457.15it/s, env_episode=600, env_step=16000, gradient_step=64, len=27, n/ep=60, n/st=2000, rew=33.57]
Epoch #1: 36%|###6 | 18000/50000 [00:02<00:04, 7457.15it/s, env_episode=655, env_step=18000, gradient_step=72, len=27, n/ep=55, n/st=2000, rew=33.87]
Epoch #1: 40%|#### | 20000/50000 [00:02<00:04, 7473.92it/s, env_episode=655, env_step=18000, gradient_step=72, len=27, n/ep=55, n/st=2000, rew=33.87]
Epoch #1: 40%|#### | 20000/50000 [00:02<00:04, 7473.92it/s, env_episode=696, env_step=20000, gradient_step=80, len=35, n/ep=41, n/st=2000, rew=46.76]
Epoch #1: 44%|####4 | 22000/50000 [00:02<00:03, 7500.97it/s, env_episode=696, env_step=20000, gradient_step=80, len=35, n/ep=41, n/st=2000, rew=46.76]
Epoch #1: 44%|####4 | 22000/50000 [00:02<00:03, 7500.97it/s, env_episode=732, env_step=22000, gradient_step=88, len=36, n/ep=36, n/st=2000, rew=48.50]
Epoch #1: 48%|####8 | 24000/50000 [00:03<00:03, 7526.55it/s, env_episode=732, env_step=22000, gradient_step=88, len=36, n/ep=36, n/st=2000, rew=48.50]
Epoch #1: 48%|####8 | 24000/50000 [00:03<00:03, 7526.55it/s, env_episode=756, env_step=24000, gradient_step=96, len=37, n/ep=24, n/st=2000, rew=59.42]
Epoch #1: 52%|#####2 | 26000/50000 [00:03<00:03, 7541.94it/s, env_episode=756, env_step=24000, gradient_step=96, len=37, n/ep=24, n/st=2000, rew=59.42]
Epoch #1: 52%|#####2 | 26000/50000 [00:03<00:03, 7541.94it/s, env_episode=777, env_step=26000, gradient_step=104, len=42, n/ep=21, n/st=2000, rew=98.76]
Epoch #1: 56%|#####6 | 28000/50000 [00:03<00:02, 7537.38it/s, env_episode=777, env_step=26000, gradient_step=104, len=42, n/ep=21, n/st=2000, rew=98.76]
Epoch #1: 56%|#####6 | 28000/50000 [00:03<00:02, 7537.38it/s, env_episode=804, env_step=28000, gradient_step=112, len=42, n/ep=27, n/st=2000, rew=88.78]
Epoch #1: 60%|###### | 30000/50000 [00:04<00:02, 7554.89it/s, env_episode=804, env_step=28000, gradient_step=112, len=42, n/ep=27, n/st=2000, rew=88.78]
Epoch #1: 60%|###### | 30000/50000 [00:04<00:02, 7554.89it/s, env_episode=823, env_step=30000, gradient_step=120, len=48, n/ep=19, n/st=2000, rew=81.63]
Epoch #1: 64%|######4 | 32000/50000 [00:04<00:02, 7576.57it/s, env_episode=823, env_step=30000, gradient_step=120, len=48, n/ep=19, n/st=2000, rew=81.63]
Epoch #1: 64%|######4 | 32000/50000 [00:04<00:02, 7576.57it/s, env_episode=838, env_step=32000, gradient_step=128, len=46, n/ep=15, n/st=2000, rew=99.80]
Epoch #1: 68%|######8 | 34000/50000 [00:04<00:02, 7594.29it/s, env_episode=838, env_step=32000, gradient_step=128, len=46, n/ep=15, n/st=2000, rew=99.80]
Epoch #1: 68%|######8 | 34000/50000 [00:04<00:02, 7594.29it/s, env_episode=854, env_step=34000, gradient_step=136, len=44, n/ep=16, n/st=2000, rew=147.75]
Epoch #1: 72%|#######2 | 36000/50000 [00:04<00:01, 7588.55it/s, env_episode=854, env_step=34000, gradient_step=136, len=44, n/ep=16, n/st=2000, rew=147.75]
Epoch #1: 72%|#######2 | 36000/50000 [00:04<00:01, 7588.55it/s, env_episode=874, env_step=36000, gradient_step=144, len=45, n/ep=20, n/st=2000, rew=114.80]
Epoch #1: 76%|#######6 | 38000/50000 [00:05<00:01, 7605.80it/s, env_episode=874, env_step=36000, gradient_step=144, len=45, n/ep=20, n/st=2000, rew=114.80]
Epoch #1: 76%|#######6 | 38000/50000 [00:05<00:01, 7605.80it/s, env_episode=889, env_step=38000, gradient_step=152, len=65, n/ep=15, n/st=2000, rew=133.53]
Epoch #1: 80%|######## | 40000/50000 [00:05<00:01, 7580.89it/s, env_episode=889, env_step=38000, gradient_step=152, len=65, n/ep=15, n/st=2000, rew=133.53]
Epoch #1: 80%|######## | 40000/50000 [00:05<00:01, 7580.89it/s, env_episode=912, env_step=40000, gradient_step=160, len=44, n/ep=23, n/st=2000, rew=94.43]
Epoch #1: 84%|########4 | 42000/50000 [00:05<00:01, 7557.17it/s, env_episode=912, env_step=40000, gradient_step=160, len=44, n/ep=23, n/st=2000, rew=94.43]
Epoch #1: 84%|########4 | 42000/50000 [00:05<00:01, 7557.17it/s, env_episode=925, env_step=42000, gradient_step=168, len=50, n/ep=13, n/st=2000, rew=110.08]
Epoch #1: 88%|########8 | 44000/50000 [00:05<00:00, 7550.64it/s, env_episode=925, env_step=42000, gradient_step=168, len=50, n/ep=13, n/st=2000, rew=110.08]
Epoch #1: 88%|########8 | 44000/50000 [00:05<00:00, 7550.64it/s, env_episode=941, env_step=44000, gradient_step=176, len=52, n/ep=16, n/st=2000, rew=136.00]
Epoch #1: 92%|#########2| 46000/50000 [00:06<00:00, 7554.36it/s, env_episode=941, env_step=44000, gradient_step=176, len=52, n/ep=16, n/st=2000, rew=136.00]
Epoch #1: 92%|#########2| 46000/50000 [00:06<00:00, 7554.36it/s, env_episode=952, env_step=46000, gradient_step=184, len=29, n/ep=11, n/st=2000, rew=126.27]
Epoch #1: 96%|#########6| 48000/50000 [00:06<00:00, 7550.06it/s, env_episode=952, env_step=46000, gradient_step=184, len=29, n/ep=11, n/st=2000, rew=126.27]
Epoch #1: 96%|#########6| 48000/50000 [00:06<00:00, 7550.06it/s, env_episode=963, env_step=48000, gradient_step=192, len=62, n/ep=11, n/st=2000, rew=182.36]
Epoch #1: 100%|##########| 50000/50000 [00:06<00:00, 7550.26it/s, env_episode=963, env_step=48000, gradient_step=192, len=62, n/ep=11, n/st=2000, rew=182.36]
Epoch #1: 100%|##########| 50000/50000 [00:06<00:00, 7550.26it/s, env_episode=976, env_step=50000, gradient_step=200, len=49, n/ep=13, n/st=2000, rew=166.38]
Epoch #1: 50001it [00:06, 7496.31it/s, env_episode=976, env_step=50000, gradient_step=200, len=49, n/ep=13, n/st=2000, rew=166.38]
Epoch #1: test_reward: 118.100000 ± 55.646114, best_reward: 118.100000 ± 55.646114 in #1
Epoch #2: 0%| | 0/50000 [00:00<?, ?it/s]
Epoch #2: 4%|4 | 2000/50000 [00:00<00:06, 7673.65it/s]
Epoch #2: 4%|4 | 2000/50000 [00:00<00:06, 7673.65it/s, env_episode=987, env_step=52000, gradient_step=208, len=52, n/ep=11, n/st=2000, rew=149.73]
Epoch #2: 8%|8 | 4000/50000 [00:00<00:06, 7568.39it/s, env_episode=987, env_step=52000, gradient_step=208, len=52, n/ep=11, n/st=2000, rew=149.73]
Epoch #2: 8%|8 | 4000/50000 [00:00<00:06, 7568.39it/s, env_episode=1001, env_step=54000, gradient_step=216, len=62, n/ep=14, n/st=2000, rew=187.43]
Epoch #2: 12%|#2 | 6000/50000 [00:00<00:05, 7553.12it/s, env_episode=1001, env_step=54000, gradient_step=216, len=62, n/ep=14, n/st=2000, rew=187.43]
Epoch #2: 12%|#2 | 6000/50000 [00:00<00:05, 7553.12it/s, env_episode=1011, env_step=56000, gradient_step=224, len=39, n/ep=10, n/st=2000, rew=160.50]
Epoch #2: 16%|#6 | 8000/50000 [00:01<00:05, 7550.92it/s, env_episode=1011, env_step=56000, gradient_step=224, len=39, n/ep=10, n/st=2000, rew=160.50]
Epoch #2: 16%|#6 | 8000/50000 [00:01<00:05, 7550.92it/s, env_episode=1022, env_step=58000, gradient_step=232, len=32, n/ep=11, n/st=2000, rew=157.73]
Epoch #2: 20%|## | 10000/50000 [00:01<00:05, 7578.38it/s, env_episode=1022, env_step=58000, gradient_step=232, len=32, n/ep=11, n/st=2000, rew=157.73]
Epoch #2: 20%|## | 10000/50000 [00:01<00:05, 7578.38it/s, env_episode=1034, env_step=60000, gradient_step=240, len=38, n/ep=12, n/st=2000, rew=167.25]
Epoch #2: 24%|##4 | 12000/50000 [00:01<00:05, 7557.25it/s, env_episode=1034, env_step=60000, gradient_step=240, len=38, n/ep=12, n/st=2000, rew=167.25]
Epoch #2: 24%|##4 | 12000/50000 [00:01<00:05, 7557.25it/s, env_episode=1051, env_step=62000, gradient_step=248, len=54, n/ep=17, n/st=2000, rew=158.35]
Epoch #2: 28%|##8 | 14000/50000 [00:01<00:04, 7551.68it/s, env_episode=1051, env_step=62000, gradient_step=248, len=54, n/ep=17, n/st=2000, rew=158.35]
Epoch #2: 28%|##8 | 14000/50000 [00:01<00:04, 7551.68it/s, env_episode=1068, env_step=64000, gradient_step=256, len=48, n/ep=17, n/st=2000, rew=125.24]
Epoch #2: 32%|###2 | 16000/50000 [00:02<00:04, 7545.95it/s, env_episode=1068, env_step=64000, gradient_step=256, len=48, n/ep=17, n/st=2000, rew=125.24]
Epoch #2: 32%|###2 | 16000/50000 [00:02<00:04, 7545.95it/s, env_episode=1091, env_step=66000, gradient_step=264, len=47, n/ep=23, n/st=2000, rew=100.83]
Epoch #2: 36%|###6 | 18000/50000 [00:02<00:04, 7543.68it/s, env_episode=1091, env_step=66000, gradient_step=264, len=47, n/ep=23, n/st=2000, rew=100.83]
Epoch #2: 36%|###6 | 18000/50000 [00:02<00:04, 7543.68it/s, env_episode=1119, env_step=68000, gradient_step=272, len=41, n/ep=28, n/st=2000, rew=77.93]
Epoch #2: 40%|#### | 20000/50000 [00:02<00:03, 7540.85it/s, env_episode=1119, env_step=68000, gradient_step=272, len=41, n/ep=28, n/st=2000, rew=77.93]
Epoch #2: 40%|#### | 20000/50000 [00:02<00:03, 7540.85it/s, env_episode=1124, env_step=70000, gradient_step=280, len=28, n/ep=5, n/st=2000, rew=93.40]
Epoch #2: 44%|####4 | 22000/50000 [00:02<00:03, 7528.21it/s, env_episode=1124, env_step=70000, gradient_step=280, len=28, n/ep=5, n/st=2000, rew=93.40]
Epoch #2: 44%|####4 | 22000/50000 [00:02<00:03, 7528.21it/s, env_episode=1137, env_step=72000, gradient_step=288, len=52, n/ep=13, n/st=2000, rew=169.38]
Epoch #2: 48%|####8 | 24000/50000 [00:03<00:04, 5817.28it/s, env_episode=1137, env_step=72000, gradient_step=288, len=52, n/ep=13, n/st=2000, rew=169.38]
Epoch #2: 48%|####8 | 24000/50000 [00:03<00:04, 5817.28it/s, env_episode=1147, env_step=74000, gradient_step=296, len=39, n/ep=10, n/st=2000, rew=208.60]
Epoch #2: 52%|#####2 | 26000/50000 [00:03<00:03, 6268.76it/s, env_episode=1147, env_step=74000, gradient_step=296, len=39, n/ep=10, n/st=2000, rew=208.60]
Epoch #2: 52%|#####2 | 26000/50000 [00:03<00:03, 6268.76it/s, env_episode=1158, env_step=76000, gradient_step=304, len=59, n/ep=11, n/st=2000, rew=165.00]
Epoch #2: 56%|#####6 | 28000/50000 [00:03<00:03, 6598.77it/s, env_episode=1158, env_step=76000, gradient_step=304, len=59, n/ep=11, n/st=2000, rew=165.00]
Epoch #2: 56%|#####6 | 28000/50000 [00:03<00:03, 6598.77it/s, env_episode=1172, env_step=78000, gradient_step=312, len=41, n/ep=14, n/st=2000, rew=183.43]
Epoch #2: 60%|###### | 30000/50000 [00:04<00:02, 6865.13it/s, env_episode=1172, env_step=78000, gradient_step=312, len=41, n/ep=14, n/st=2000, rew=183.43]
Epoch #2: 60%|###### | 30000/50000 [00:04<00:02, 6865.13it/s, env_episode=1183, env_step=80000, gradient_step=320, len=43, n/ep=11, n/st=2000, rew=140.73]
Epoch #2: 64%|######4 | 32000/50000 [00:04<00:03, 5711.58it/s, env_episode=1183, env_step=80000, gradient_step=320, len=43, n/ep=11, n/st=2000, rew=140.73]
Epoch #2: 64%|######4 | 32000/50000 [00:04<00:03, 5711.58it/s, env_episode=1196, env_step=82000, gradient_step=328, len=51, n/ep=13, n/st=2000, rew=196.46]
Epoch #2: 68%|######8 | 34000/50000 [00:04<00:02, 6163.77it/s, env_episode=1196, env_step=82000, gradient_step=328, len=51, n/ep=13, n/st=2000, rew=196.46]
Epoch #2: 68%|######8 | 34000/50000 [00:04<00:02, 6163.77it/s, env_episode=1206, env_step=84000, gradient_step=336, len=48, n/ep=10, n/st=2000, rew=165.60]
Epoch #2: 72%|#######2 | 36000/50000 [00:05<00:02, 6531.70it/s, env_episode=1206, env_step=84000, gradient_step=336, len=48, n/ep=10, n/st=2000, rew=165.60]
Epoch #2: 72%|#######2 | 36000/50000 [00:05<00:02, 6531.70it/s, env_episode=1220, env_step=86000, gradient_step=344, len=52, n/ep=14, n/st=2000, rew=159.71]
Epoch #2: 76%|#######6 | 38000/50000 [00:05<00:01, 6794.90it/s, env_episode=1220, env_step=86000, gradient_step=344, len=52, n/ep=14, n/st=2000, rew=159.71]
Epoch #2: 76%|#######6 | 38000/50000 [00:05<00:01, 6794.90it/s, env_episode=1231, env_step=88000, gradient_step=352, len=47, n/ep=11, n/st=2000, rew=175.36]
Epoch #2: 80%|######## | 40000/50000 [00:05<00:01, 7016.90it/s, env_episode=1231, env_step=88000, gradient_step=352, len=47, n/ep=11, n/st=2000, rew=175.36]
Epoch #2: 80%|######## | 40000/50000 [00:05<00:01, 7016.90it/s, env_episode=1241, env_step=90000, gradient_step=360, len=57, n/ep=10, n/st=2000, rew=164.70]
Epoch #2: 84%|########4 | 42000/50000 [00:06<00:01, 5778.82it/s, env_episode=1241, env_step=90000, gradient_step=360, len=57, n/ep=10, n/st=2000, rew=164.70]
Epoch #2: 84%|########4 | 42000/50000 [00:06<00:01, 5778.82it/s, env_episode=1254, env_step=92000, gradient_step=368, len=57, n/ep=13, n/st=2000, rew=198.38]
Epoch #2: 88%|########8 | 44000/50000 [00:06<00:00, 6209.59it/s, env_episode=1254, env_step=92000, gradient_step=368, len=57, n/ep=13, n/st=2000, rew=198.38]
Epoch #2: 88%|########8 | 44000/50000 [00:06<00:00, 6209.59it/s, env_episode=1268, env_step=94000, gradient_step=376, len=50, n/ep=14, n/st=2000, rew=140.07]
Epoch #2: 92%|#########2| 46000/50000 [00:06<00:00, 6543.06it/s, env_episode=1268, env_step=94000, gradient_step=376, len=50, n/ep=14, n/st=2000, rew=140.07]
Epoch #2: 92%|#########2| 46000/50000 [00:06<00:00, 6543.06it/s, env_episode=1289, env_step=96000, gradient_step=384, len=49, n/ep=21, n/st=2000, rew=113.57]
Epoch #2: 96%|#########6| 48000/50000 [00:07<00:00, 6827.17it/s, env_episode=1289, env_step=96000, gradient_step=384, len=49, n/ep=21, n/st=2000, rew=113.57]
Epoch #2: 96%|#########6| 48000/50000 [00:07<00:00, 6827.17it/s, env_episode=1305, env_step=98000, gradient_step=392, len=36, n/ep=16, n/st=2000, rew=97.38]
Epoch #2: 100%|##########| 50000/50000 [00:07<00:00, 7052.33it/s, env_episode=1305, env_step=98000, gradient_step=392, len=36, n/ep=16, n/st=2000, rew=97.38]
Epoch #2: 100%|##########| 50000/50000 [00:07<00:00, 7052.33it/s, env_episode=1317, env_step=100000, gradient_step=400, len=43, n/ep=12, n/st=2000, rew=146.17]
Epoch #2: 50001it [00:07, 6827.29it/s, env_episode=1317, env_step=100000, gradient_step=400, len=43, n/ep=12, n/st=2000, rew=146.17]
Epoch #2: test_reward: 115.000000 ± 70.366185, best_reward: 118.100000 ± 55.646114 in #1
Epoch #3: 0%| | 0/50000 [00:00<?, ?it/s]
Epoch #3: 4%|4 | 2000/50000 [00:00<00:06, 7679.55it/s]
Epoch #3: 4%|4 | 2000/50000 [00:00<00:06, 7679.55it/s, env_episode=1327, env_step=102000, gradient_step=408, len=41, n/ep=10, n/st=2000, rew=118.00]
Epoch #3: 8%|8 | 4000/50000 [00:00<00:06, 7596.56it/s, env_episode=1327, env_step=102000, gradient_step=408, len=41, n/ep=10, n/st=2000, rew=118.00]
Epoch #3: 8%|8 | 4000/50000 [00:00<00:06, 7596.56it/s, env_episode=1356, env_step=104000, gradient_step=416, len=34, n/ep=29, n/st=2000, rew=104.62]
Epoch #3: 12%|#2 | 6000/50000 [00:00<00:05, 7592.44it/s, env_episode=1356, env_step=104000, gradient_step=416, len=34, n/ep=29, n/st=2000, rew=104.62]
Epoch #3: 12%|#2 | 6000/50000 [00:00<00:05, 7592.44it/s, env_episode=1376, env_step=106000, gradient_step=424, len=33, n/ep=20, n/st=2000, rew=106.55]
Epoch #3: 16%|#6 | 8000/50000 [00:01<00:05, 7573.87it/s, env_episode=1376, env_step=106000, gradient_step=424, len=33, n/ep=20, n/st=2000, rew=106.55]
Epoch #3: 16%|#6 | 8000/50000 [00:01<00:05, 7573.87it/s, env_episode=1398, env_step=108000, gradient_step=432, len=38, n/ep=22, n/st=2000, rew=97.86]
Epoch #3: 20%|## | 10000/50000 [00:01<00:05, 7457.01it/s, env_episode=1398, env_step=108000, gradient_step=432, len=38, n/ep=22, n/st=2000, rew=97.86]
Epoch #3: 20%|## | 10000/50000 [00:01<00:05, 7457.01it/s, env_episode=1426, env_step=110000, gradient_step=440, len=33, n/ep=28, n/st=2000, rew=55.11]
Epoch #3: 24%|##4 | 12000/50000 [00:01<00:05, 7476.53it/s, env_episode=1426, env_step=110000, gradient_step=440, len=33, n/ep=28, n/st=2000, rew=55.11]
Epoch #3: 24%|##4 | 12000/50000 [00:01<00:05, 7476.53it/s, env_episode=1436, env_step=112000, gradient_step=448, len=46, n/ep=10, n/st=2000, rew=147.30]
Epoch #3: 28%|##8 | 14000/50000 [00:01<00:04, 7485.91it/s, env_episode=1436, env_step=112000, gradient_step=448, len=46, n/ep=10, n/st=2000, rew=147.30]
Epoch #3: 28%|##8 | 14000/50000 [00:01<00:04, 7485.91it/s, env_episode=1451, env_step=114000, gradient_step=456, len=57, n/ep=15, n/st=2000, rew=175.93]
Epoch #3: 32%|###2 | 16000/50000 [00:02<00:04, 7487.74it/s, env_episode=1451, env_step=114000, gradient_step=456, len=57, n/ep=15, n/st=2000, rew=175.93]
Epoch #3: 32%|###2 | 16000/50000 [00:02<00:04, 7487.74it/s, env_episode=1473, env_step=116000, gradient_step=464, len=37, n/ep=22, n/st=2000, rew=107.64]
Epoch #3: 36%|###6 | 18000/50000 [00:02<00:04, 7507.15it/s, env_episode=1473, env_step=116000, gradient_step=464, len=37, n/ep=22, n/st=2000, rew=107.64]
Epoch #3: 36%|###6 | 18000/50000 [00:02<00:04, 7507.15it/s, env_episode=1483, env_step=118000, gradient_step=472, len=52, n/ep=10, n/st=2000, rew=171.40]
Epoch #3: 40%|#### | 20000/50000 [00:02<00:03, 7507.24it/s, env_episode=1483, env_step=118000, gradient_step=472, len=52, n/ep=10, n/st=2000, rew=171.40]
Epoch #3: 40%|#### | 20000/50000 [00:02<00:03, 7507.24it/s, env_episode=1504, env_step=120000, gradient_step=480, len=47, n/ep=21, n/st=2000, rew=107.19]
Epoch #3: 44%|####4 | 22000/50000 [00:02<00:03, 7519.32it/s, env_episode=1504, env_step=120000, gradient_step=480, len=47, n/ep=21, n/st=2000, rew=107.19]
Epoch #3: 44%|####4 | 22000/50000 [00:02<00:03, 7519.32it/s, env_episode=1530, env_step=122000, gradient_step=488, len=36, n/ep=26, n/st=2000, rew=94.08]
Epoch #3: 48%|####8 | 24000/50000 [00:03<00:03, 7498.73it/s, env_episode=1530, env_step=122000, gradient_step=488, len=36, n/ep=26, n/st=2000, rew=94.08]
Epoch #3: 48%|####8 | 24000/50000 [00:03<00:03, 7498.73it/s, env_episode=1551, env_step=124000, gradient_step=496, len=43, n/ep=21, n/st=2000, rew=81.62]
Epoch #3: 52%|#####2 | 26000/50000 [00:03<00:03, 7471.58it/s, env_episode=1551, env_step=124000, gradient_step=496, len=43, n/ep=21, n/st=2000, rew=81.62]
Epoch #3: 52%|#####2 | 26000/50000 [00:03<00:03, 7471.58it/s, env_episode=1561, env_step=126000, gradient_step=504, len=54, n/ep=10, n/st=2000, rew=132.20]
Epoch #3: 56%|#####6 | 28000/50000 [00:03<00:02, 7496.64it/s, env_episode=1561, env_step=126000, gradient_step=504, len=54, n/ep=10, n/st=2000, rew=132.20]
Epoch #3: 56%|#####6 | 28000/50000 [00:03<00:02, 7496.64it/s, env_episode=1583, env_step=128000, gradient_step=512, len=37, n/ep=22, n/st=2000, rew=94.50]
Epoch #3: 60%|###### | 30000/50000 [00:03<00:02, 7495.92it/s, env_episode=1583, env_step=128000, gradient_step=512, len=37, n/ep=22, n/st=2000, rew=94.50]
Epoch #3: 60%|###### | 30000/50000 [00:03<00:02, 7495.92it/s, env_episode=1600, env_step=130000, gradient_step=520, len=49, n/ep=17, n/st=2000, rew=131.24]
Epoch #3: 64%|######4 | 32000/50000 [00:04<00:02, 7512.02it/s, env_episode=1600, env_step=130000, gradient_step=520, len=49, n/ep=17, n/st=2000, rew=131.24]
Epoch #3: 64%|######4 | 32000/50000 [00:04<00:02, 7512.02it/s, env_episode=1624, env_step=132000, gradient_step=528, len=37, n/ep=24, n/st=2000, rew=99.21]
Epoch #3: 68%|######8 | 34000/50000 [00:04<00:02, 7516.85it/s, env_episode=1624, env_step=132000, gradient_step=528, len=37, n/ep=24, n/st=2000, rew=99.21]
Epoch #3: 68%|######8 | 34000/50000 [00:04<00:02, 7516.85it/s, env_episode=1644, env_step=134000, gradient_step=536, len=36, n/ep=20, n/st=2000, rew=90.25]
Epoch #3: 72%|#######2 | 36000/50000 [00:04<00:01, 7516.07it/s, env_episode=1644, env_step=134000, gradient_step=536, len=36, n/ep=20, n/st=2000, rew=90.25]
Epoch #3: 72%|#######2 | 36000/50000 [00:04<00:01, 7516.07it/s, env_episode=1650, env_step=136000, gradient_step=544, len=63, n/ep=6, n/st=2000, rew=168.83]
Epoch #3: 76%|#######6 | 38000/50000 [00:05<00:01, 7503.63it/s, env_episode=1650, env_step=136000, gradient_step=544, len=63, n/ep=6, n/st=2000, rew=168.83]
Epoch #3: 76%|#######6 | 38000/50000 [00:05<00:01, 7503.63it/s, env_episode=1665, env_step=138000, gradient_step=552, len=39, n/ep=15, n/st=2000, rew=154.73]
Epoch #3: 80%|######## | 40000/50000 [00:05<00:01, 7510.07it/s, env_episode=1665, env_step=138000, gradient_step=552, len=39, n/ep=15, n/st=2000, rew=154.73]
Epoch #3: 80%|######## | 40000/50000 [00:05<00:01, 7510.07it/s, env_episode=1676, env_step=140000, gradient_step=560, len=53, n/ep=11, n/st=2000, rew=106.91]
Epoch #3: 84%|########4 | 42000/50000 [00:05<00:01, 7512.44it/s, env_episode=1676, env_step=140000, gradient_step=560, len=53, n/ep=11, n/st=2000, rew=106.91]
Epoch #3: 84%|########4 | 42000/50000 [00:05<00:01, 7512.44it/s, env_episode=1690, env_step=142000, gradient_step=568, len=38, n/ep=14, n/st=2000, rew=174.07]
Epoch #3: 88%|########8 | 44000/50000 [00:05<00:00, 7529.88it/s, env_episode=1690, env_step=142000, gradient_step=568, len=38, n/ep=14, n/st=2000, rew=174.07]
Epoch #3: 88%|########8 | 44000/50000 [00:05<00:00, 7529.88it/s, env_episode=1709, env_step=144000, gradient_step=576, len=30, n/ep=19, n/st=2000, rew=145.26]
Epoch #3: 92%|#########2| 46000/50000 [00:06<00:00, 7528.30it/s, env_episode=1709, env_step=144000, gradient_step=576, len=30, n/ep=19, n/st=2000, rew=145.26]
Epoch #3: 92%|#########2| 46000/50000 [00:06<00:00, 7528.30it/s, env_episode=1721, env_step=146000, gradient_step=584, len=58, n/ep=12, n/st=2000, rew=141.67]
Epoch #3: 96%|#########6| 48000/50000 [00:06<00:00, 7481.39it/s, env_episode=1721, env_step=146000, gradient_step=584, len=58, n/ep=12, n/st=2000, rew=141.67]
Epoch #3: 96%|#########6| 48000/50000 [00:06<00:00, 7481.39it/s, env_episode=1760, env_step=148000, gradient_step=592, len=36, n/ep=39, n/st=2000, rew=89.85]
Epoch #3: 100%|##########| 50000/50000 [00:06<00:00, 7499.57it/s, env_episode=1760, env_step=148000, gradient_step=592, len=36, n/ep=39, n/st=2000, rew=89.85]
Epoch #3: 100%|##########| 50000/50000 [00:06<00:00, 7499.57it/s, env_episode=1772, env_step=150000, gradient_step=600, len=29, n/ep=12, n/st=2000, rew=58.17]
Epoch #3: 50001it [00:06, 7507.21it/s, env_episode=1772, env_step=150000, gradient_step=600, len=29, n/ep=12, n/st=2000, rew=58.17]
Epoch #3: test_reward: 252.800000 ± 172.729731, best_reward: 252.800000 ± 172.729731 in #3
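To see why training ended after the third epoch, the stopping criterion can be replayed on the test rewards from the log. This is only a sketch: the reward values are copied from the test_reward lines above, first_stopping_epoch is a hypothetical helper, and how the trainer evaluates stop_fn internally is not shown here.

```python
# Mean test rewards copied from the "test_reward" lines of the log.
test_rewards = {1: 118.1, 2: 115.0, 3: 252.8}

# Same criterion as the stop_fn passed to OnpolicyTrainer in the script.
stop_fn = lambda mean_reward: mean_reward >= 195


def first_stopping_epoch(rewards_by_epoch, criterion):
    """Return the first epoch whose mean test reward satisfies the criterion, else None."""
    for epoch in sorted(rewards_by_epoch):
        if criterion(rewards_by_epoch[epoch]):
            return epoch
    return None


print(first_stopping_epoch(test_rewards, stop_fn))  # prints 3
```

Epochs 1 and 2 stay below 195, so the criterion is first met in epoch 3, which matches the point where the log ends.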
train_result.pprint_asdict()
InfoStats
----------------------------------------
{ 'best_reward': 252.8,
  'best_reward_std': 172.7297310829841,
  'best_score': 252.8,
  'gradient_step': 600,
  'test_episode': 70,
  'test_step': 9460,
  'timing': { 'test_time': 0.0,
              'total_time': 21.52031397819519,
              'train_time': 21.52031397819519,
              'train_time_collect': 0.0,
              'train_time_update': 9.637651681900024,
              'update_speed': 6970.158527983513},
  'train_episode': 1772,
  'train_step': 150000}
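The train_step and gradient_step counters in this summary can be cross-checked against the trainer arguments. The arithmetic below is a sketch under one stated assumption: that the gradient_step counter advances by one per minibatch of batch_size drawn from each collect. That rule is inferred from the log (the counter rises by 8 for every 2000 environment steps), not taken from Tianshou's documentation, and it does not involve repeat_per_collect.

```python
import math

# Values taken from the OnpolicyTrainer call in the script.
step_per_epoch = 50_000
step_per_collect = 2_000
batch_size = 256
epochs_run = 3  # training stopped after epoch 3 in the log

# Number of collect/update rounds in one epoch.
collects_per_epoch = step_per_epoch // step_per_collect

# Assumption inferred from the log: one counted step per minibatch per collect.
grad_steps_per_collect = math.ceil(step_per_collect / batch_size)

train_step = epochs_run * step_per_epoch
gradient_step = epochs_run * collects_per_epoch * grad_steps_per_collect

print(train_step, gradient_step)  # prints 150000 600
```

Both numbers agree with the train_step and gradient_step entries reported by pprint_asdict().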
# Let's watch its performance!
policy.eval()
eval_result = test_collector.collect(n_episode=3, render=False)
print(f"Final reward: {eval_result.returns.mean()}, length: {eval_result.lens.mean()}")
/home/docs/checkouts/readthedocs.org/user_builds/tianshou/checkouts/v1.2.0/tianshou/data/collector.py:542: UserWarning: n_episode=3 should be larger than self.env_num=10 to collect at least one trajectory in each environment.
warnings.warn(
Number of episodes (3) is smaller than the number of environments (10). This means that 7 environments (or, equivalently, parallel workers) will not be used!
Final reward: 500.0, length: 500.0
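The warning printed with this result is simple arithmetic on the number of requested episodes and the number of test environments. The helper below, unused_envs, is a hypothetical name used only to illustrate the count reported in the message; it is not a Tianshou function.

```python
def unused_envs(n_episode, env_num):
    """Number of parallel environments left idle when fewer episodes than environments are requested."""
    return max(env_num - n_episode, 0)


# The final evaluation asks for 3 episodes from a collector with 10 test environments.
print(unused_envs(n_episode=3, env_num=10))  # prints 7
```

Requesting at least as many episodes as there are test environments (10 in this script) would bring the count to zero.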
Tutorial Introduction#
A typical DRL experiment like the one shown above requires many components to work together: the agent, the environment (possibly several running in parallel), the replay buffer and the trainer all cooperate to complete a training task.
In Tianshou, each of these main components is factored out as a separate building block, which you can use to create your own algorithms and run your own experiments.
Building blocks may include:
Batch
Replay Buffer
Vectorized Environment Wrapper
Policy (the agent and the training algorithm)
Data Collector
Trainer
Logger
These notebook tutorials will guide you through all the modules one by one.
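To make the division of labour concrete, here is a library-free sketch of how an environment, a policy, a buffer, a collector and a trainer can fit together. Every name in it (ToyEnv, BanditPolicy, collect, train) is hypothetical and deliberately simplified; none of it is Tianshou's API, and the toy policy is a running-average bandit rather than PPO.

```python
import random


class ToyEnv:
    """Hypothetical environment: action 1 earns reward 1, episodes last `horizon` steps."""

    def __init__(self, horizon=5):
        self.horizon = horizon
        self.t = 0

    def reset(self):
        self.t = 0
        return 0.0

    def step(self, action):
        self.t += 1
        reward = 1.0 if action == 1 else 0.0
        done = self.t >= self.horizon
        return float(self.t), reward, done


class BanditPolicy:
    """Hypothetical policy: running mean reward per action, epsilon-greedy action choice."""

    def __init__(self, n_actions=2, epsilon=0.1, seed=0):
        self.values = [0.0] * n_actions
        self.counts = [0] * n_actions
        self.epsilon = epsilon
        self.rng = random.Random(seed)

    def act(self, obs):
        if self.rng.random() < self.epsilon:
            return self.rng.randrange(len(self.values))
        return max(range(len(self.values)), key=self.values.__getitem__)

    def update(self, transitions):
        # Incremental mean of the rewards observed for each action.
        for _, action, reward in transitions:
            self.counts[action] += 1
            self.values[action] += (reward - self.values[action]) / self.counts[action]


def collect(policy, env, buffer, n_episode):
    """Collector role: run the policy in the environment and store transitions in the buffer."""
    total = 0.0
    for _ in range(n_episode):
        obs, done = env.reset(), False
        while not done:
            action = policy.act(obs)
            next_obs, reward, done = env.step(action)
            buffer.append((obs, action, reward))
            total += reward
            obs = next_obs
    return total / n_episode


def train(policy, env, epochs=20, episodes_per_collect=4):
    """Trainer role: alternate data collection and policy updates."""
    mean_return = 0.0
    for _ in range(epochs):
        buffer = []  # fresh buffer each round, in the spirit of on-policy training
        mean_return = collect(policy, env, buffer, episodes_per_collect)
        policy.update(buffer)
    return mean_return


policy = BanditPolicy()
final_return = train(policy, ToyEnv())
print(f"mean return in the last round: {final_return}")
```

Each function or class here stands in for one of the building blocks listed above; the real modules add vectorized environments, batching and logging on top of the same basic loop.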
Further reading#
What if I am not familiar with the PPO algorithm itself?#
For the DRL algorithms themselves, we refer you to the Spinning Up documentation, which provides plenty of resources and guides for studying them. Tianshou’s tutorials focus on how to use the different modules, not on the algorithms themselves.