Source code for tianshou.highlevel.logger

import os
from abc import ABC, abstractmethod
from typing import Literal, TypeAlias

from sensai.util.string import ToStringMixin
from torch.utils.tensorboard import SummaryWriter

from tianshou.utils import BaseLogger, TensorboardLogger, WandbLogger

TLogger: TypeAlias = BaseLogger


[docs] class LoggerFactory(ToStringMixin, ABC):
[docs] @abstractmethod def create_logger( self, log_dir: str, experiment_name: str, run_id: str | None, config_dict: dict | None = None, ) -> TLogger: """Creates the logger. :param log_dir: path to the directory in which log data is to be stored :param experiment_name: the name of the job, which may contain `os.path.delimiter` :param run_id: a unique name, which, depending on the logging framework, may be used to identify the logger :param config_dict: a dictionary with data that is to be logged :return: the logger """
[docs] @abstractmethod def get_logger_class(self) -> type[TLogger]: """Returns the class of the logger that is to be created."""
[docs] class LoggerFactoryDefault(LoggerFactory): def __init__( self, logger_type: Literal["tensorboard", "wandb", "pandas"] = "tensorboard", wand_entity: str | None = None, wandb_project: str | None = None, group: str | None = None, job_type: str | None = None, save_interval: int = 1, ): if logger_type == "wandb" and wandb_project is None: raise ValueError("Must provide 'wandb_project'") self.logger_type = logger_type self.wandb_entity = wand_entity self.wandb_project = wandb_project self.group = group self.job_type = job_type self.save_interval = save_interval
[docs] def create_logger( self, log_dir: str, experiment_name: str, run_id: str | None, config_dict: dict | None = None, ) -> TLogger: match self.logger_type: case "wandb": logger = WandbLogger( save_interval=self.save_interval, name=experiment_name.replace(os.path.sep, "__"), run_id=run_id, config=config_dict, entity=self.wandb_entity, project=self.wandb_project, group=self.group, job_type=self.job_type, log_dir=log_dir, ) writer = self._create_writer(log_dir) # writer has to be created after wandb.init! logger.load(writer) return logger case "tensorboard": writer = self._create_writer(log_dir) return TensorboardLogger(writer) case _: raise ValueError(f"Unknown logger type '{self.logger_type}'")
def _create_writer(self, log_dir: str) -> SummaryWriter: """Creates a tensorboard writer and adds a text artifact.""" writer = SummaryWriter(log_dir) writer.add_text( "args", str( dict( log_dir=log_dir, logger_type=self.logger_type, wandb_project=self.wandb_project, ), ), ) return writer
[docs] def get_logger_class(self) -> type[TLogger]: match self.logger_type: case "wandb": return WandbLogger case "tensorboard": return TensorboardLogger case _: raise ValueError(f"Unknown logger type '{self.logger_type}'")