import argparse
import logging
import os
from collections.abc import Callable
from torch.utils.tensorboard import SummaryWriter
from tianshou.utils import BaseLogger, TensorboardLogger
from tianshou.utils.logger.logger_base import VALID_LOG_VALS_TYPE, TRestoredData
log = logging.getLogger(__name__)
class WandbLogger(BaseLogger):
    """Weights and Biases logger that sends data to https://wandb.ai/.

    This logger creates three panels with plots: train, test, and update.
    Make sure to select the correct x-axis (step metric) for each panel in
    Weights and Biases, e.g. the environment step for train plots and the
    gradient step for update plots.

    All data is written through a :class:`TensorboardLogger`, which has to be
    attached with :meth:`load` before anything is logged. The W&B run is created
    with ``sync_tensorboard=True`` so that the tensorboard output reaches W&B.

    Example of usage:
    ::

        logger = WandbLogger()
        logger.load(SummaryWriter(log_path))

    If a W&B run is already active (``wandb.run`` is set), that run is reused and
    the W&B-specific arguments (``project``, ``name``, ``entity``, ``run_id``,
    ``group``, ``job_type``, ``config``, ``disable_stats``, ``log_dir``) have no
    effect.

    :param training_interval: the log interval in log_training_data().
    :param test_interval: the log interval in log_test_data().
    :param update_interval: the log interval in log_update_data().
    :param info_interval: the log interval in log_info_data().
    :param save_interval: the minimum number of epochs between two checkpoint
        uploads in save_data(). Default to None, which disables checkpoint uploads.
    :param write_flush: whether to flush tensorboard result after each
        add_scalar operation. Default to True.
    :param project: W&B project name. Default to None, in which case the
        ``WANDB_PROJECT`` environment variable is used, falling back to "tianshou".
    :param name: W&B run name. Default to None. If None, random name is assigned.
    :param entity: W&B team/organization name. Default to None.
    :param run_id: id of the W&B run to be resumed (the run is initialized with
        ``resume="allow"``). Default to None.
    :param group: W&B group of the run. Default to None.
    :param job_type: W&B job type of the run. Default to None.
    :param config: experiment configurations forwarded to ``wandb.init``.
        Default to None.
    :param monitor_gym: currently ignored; it is not forwarded to ``wandb.init``
        until gymnasium is bumped to >1.0.0
        (see https://github.com/wandb/wandb/issues/7047).
    :param disable_stats: forwarded to W&B as ``wandb.Settings(x_disable_stats=...)``.
        Default to False.
    :param log_dir: directory forwarded to ``wandb.init`` as ``dir``. Default to None.
    """

    def __init__(
        self,
        training_interval: int = 1000,
        test_interval: int = 1,
        update_interval: int = 1000,
        info_interval: int = 1,
        save_interval: int | None = None,
        write_flush: bool = True,
        project: str | None = None,
        name: str | None = None,
        entity: str | None = None,
        run_id: str | None = None,
        group: str | None = None,
        job_type: str | None = None,
        config: argparse.Namespace | dict | None = None,
        monitor_gym: bool = True,
        disable_stats: bool = False,
        log_dir: str | None = None,
    ) -> None:
        # wandb is imported here rather than at module level (presumably to keep it
        # an optional dependency that is only needed when this logger is used).
        import wandb

        super().__init__(
            training_interval, test_interval, update_interval, info_interval, save_interval
        )
        # Epoch of the last checkpoint upload in save_data(); -1 before the first one.
        self.last_save_step = -1
        self.write_flush = write_flush
        self.restored = False
        if project is None:
            project = os.getenv("WANDB_PROJECT", "tianshou")

        # Reuse an already active W&B run instead of starting a second one.
        wandb_run = (
            wandb.init(
                project=project,
                group=group,
                job_type=job_type,
                name=name,
                id=run_id,
                resume="allow",
                entity=entity,
                sync_tensorboard=True,
                # NOTE: monitor_gym is intentionally not forwarded until gymnasium is
                # bumped to >1.0.0, see https://github.com/wandb/wandb/issues/7047
                dir=log_dir,
                config=config,  # type: ignore
                settings=wandb.Settings(x_disable_stats=disable_stats),
            )
            if not wandb.run
            else wandb.run
        )
        assert wandb_run is not None
        self.wandb_run = wandb_run
        self.wandb_run._label(repo="tianshou")
        # Both are populated by load().
        self.tensorboard_logger: TensorboardLogger | None = None
        self.writer: SummaryWriter | None = None

    def prepare_dict_for_logging(self, log_data: dict) -> dict[str, VALID_LOG_VALS_TYPE]:
        """Delegate the preparation of ``log_data`` to the underlying TensorboardLogger.

        :param log_data: the data to be prepared for logging.
        :return: the result of ``TensorboardLogger.prepare_dict_for_logging``.
        :raises RuntimeError: if :meth:`load` has not been called yet.
        """
        if self.tensorboard_logger is None:
            raise RuntimeError(
                "`logger` needs to load the Tensorboard Writer before "
                "preparing data for logging. Try `logger.load(SummaryWriter(log_path))`",
            )
        return self.tensorboard_logger.prepare_dict_for_logging(log_data)

    def load(self, writer: SummaryWriter) -> None:
        """Attach ``writer`` by wrapping it in a TensorboardLogger.

        The TensorboardLogger is configured with this logger's intervals and flush
        setting. This must be called before any data is prepared or written.

        :param writer: the tensorboard SummaryWriter to log to.
        """
        self.writer = writer
        self.tensorboard_logger = TensorboardLogger(
            writer,
            self.training_interval,
            self.test_interval,
            self.update_interval,
            self.info_interval,
            self.save_interval,
            self.write_flush,
        )

    def write(self, step_type: str, step: int, data: dict[str, VALID_LOG_VALS_TYPE]) -> None:
        """Write ``data`` at ``step`` through the underlying TensorboardLogger.

        :param step_type: the step type forwarded to ``TensorboardLogger.write``.
        :param step: the step at which the data is recorded.
        :param data: the data to write.
        :raises RuntimeError: if :meth:`load` has not been called yet.
        """
        if self.tensorboard_logger is None:
            raise RuntimeError(
                "`logger` needs to load the Tensorboard Writer before "
                "writing data. Try `logger.load(SummaryWriter(log_path))`",
            )
        self.tensorboard_logger.write(step_type, step, data)

    def finalize(self) -> None:
        """Finish the W&B run and finalize the TensorboardLogger, if one was loaded."""
        if self.wandb_run is not None:
            self.wandb_run.finish()
        if self.tensorboard_logger is not None:
            self.tensorboard_logger.finalize()

    def save_data(
        self,
        epoch: int,
        env_step: int,
        update_step: int,
        save_checkpoint_fn: Callable[[int, int, int], str] | None = None,
    ) -> None:
        """Save a checkpoint via ``save_checkpoint_fn`` and upload it as a W&B artifact.

        Nothing happens unless ``save_interval`` is set, a ``save_checkpoint_fn`` is
        given, and at least ``save_interval`` epochs have passed since the last
        upload. The saved file is logged as a ``"model"`` artifact named
        ``run_<run id>_checkpoint``, with the epoch, steps and checkpoint path stored
        in its metadata (read back by :meth:`restore_data`).

        :param epoch: the epoch in trainer.
        :param env_step: the env_step in trainer.
        :param update_step: the gradient_step in trainer.
        :param save_checkpoint_fn: a hook defined by user, see trainer documentation
            for detail. It is called as ``save_checkpoint_fn(epoch, env_step,
            update_step)`` and must return the path of the saved checkpoint file.
        """
        import wandb

        if (
            self.save_interval is not None
            and save_checkpoint_fn
            and epoch - self.last_save_step >= self.save_interval
        ):
            self.last_save_step = epoch
            checkpoint_path = save_checkpoint_fn(epoch, env_step, update_step)
            # Same name as the one looked up in restore_data().
            checkpoint_artifact = wandb.Artifact(
                f"run_{self.wandb_run.id}_checkpoint",
                type="model",
                metadata={
                    "save/epoch": epoch,
                    "save/env_step": env_step,
                    "save/gradient_step": update_step,
                    "checkpoint_path": str(checkpoint_path),
                },
            )
            checkpoint_artifact.add_file(str(checkpoint_path))
            self.wandb_run.log_artifact(checkpoint_artifact)

    def restore_data(self) -> tuple[int, int, int]:
        """Download the latest checkpoint artifact of this run and restore the counters.

        The artifact is downloaded into the directory of the ``checkpoint_path``
        recorded in its metadata by :meth:`save_data`.

        :return: ``(epoch, env_step, gradient_step)`` read from the artifact metadata.
            If epoch or gradient_step is missing, both are reported as 0; a missing
            env_step is reported as 0.
        :raises RuntimeError: if the checkpoint artifact could not be retrieved.
        """
        checkpoint_artifact = self.wandb_run.use_artifact(
            f"run_{self.wandb_run.id}_checkpoint:latest",
        )
        if checkpoint_artifact is None:
            raise RuntimeError("W&B checkpoint artifact doesn't exist")
        metadata = checkpoint_artifact.metadata
        checkpoint_artifact.download(os.path.dirname(metadata["checkpoint_path"]))

        # epoch / gradient_step: read both before touching any counter, so the
        # restored counters always agree with the returned values.
        try:
            epoch = metadata["save/epoch"]
            gradient_step = metadata["save/gradient_step"]
        except KeyError:
            epoch, gradient_step = 0, 0
        else:
            self.last_save_step = self.last_log_test_step = epoch
            self.last_log_update_step = gradient_step

        # offline trainer doesn't have env_step
        try:
            env_step = metadata["save/env_step"]
        except KeyError:
            env_step = 0
        else:
            self.last_log_train_step = env_step
        return epoch, env_step, gradient_step

    @staticmethod
    def restore_logged_data(log_path: str) -> TRestoredData:
        """Restore previously logged data from the tensorboard files under ``log_path``.

        Reading the data back from W&B is not implemented, so this delegates to
        ``TensorboardLogger.restore_logged_data`` and emits a warning.

        :param log_path: the directory holding the tensorboard event files.
        """
        log.warning(
            "Logging data directly from W&B is not yet implemented, will use the "
            "TensorboardLogger to restore it from disc instead.",
        )
        return TensorboardLogger.restore_logged_data(log_path)