Source code for tianshou.utils.determinism

import difflib
import inspect
import os
import re
import time
from collections.abc import Callable, Sequence
from dataclasses import dataclass
from io import StringIO
from pathlib import Path
from typing import Self

import torch
from sensai.util import logging
from sensai.util.git import GitStatus, git_status
from sensai.util.pickle import dump_pickle, load_pickle


def format_log_message(
    logger: logging.Logger,
    level: int,
    msg: str,
    formatter: logging.Formatter,
    stacklevel: int = 1,
) -> str:
    """
    Renders `msg` the way `logger.log(level, msg)` would have rendered it under the given
    formatter. Only a log record is created and formatted; nothing is passed to the logger's handlers.

    :param logger: the logger on whose behalf the record is created
    :param level: the log level of the record
    :param msg: the message text
    :param formatter: the formatter used to render the record
    :param stacklevel: index into the call stack of the frame that is reported as the origin
        of the message (1 is the direct caller of this function)
    :return: the formatted log message (not including trailing newline)
    """
    # Entry 0 of the stack is this function itself, so `stacklevel` counts upwards from the caller.
    origin = inspect.stack()[stacklevel]
    record = logger.makeRecord(
        name=logger.name,
        level=level,
        fn=origin.filename,
        lno=origin.lineno,
        msg=msg,
        args=(),
        exc_info=None,
        func=origin.function,
        extra=None,
    )
    timestamp = time.time()
    record.created = timestamp
    record.asctime = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(timestamp))
    return formatter.format(record)
[docs] class TraceLogger: """Supports the collection of behavioural trace logs, which can, in particular, be used for determinism tests.""" is_enabled = False """ whether the trace logger is enabled. NOTE: The preferred way to enable this is via the context manager. """ verbose = False """ whether to print trace log messages to stdout. """ MESSAGE_TAG = "[TRACE]" """ a tag which is added at the beginning of log messages generated by this logger """ LOG_LEVEL = logging.DEBUG log_buffer: StringIO | None = None log_formatter: logging.Formatter | None = None
[docs] @classmethod def log(cls, logger: logging.Logger, message_generator: Callable[[], str]) -> None: """ Logs a message intended for tracing agent-env interaction, which is enabled via `TraceAgentEnvLoggerContext`. :param logger: the logger to use for the actual logging :param message_generator: function which generates the log message (which may be expensive); if logging is disabled, the function will not be called. """ if not cls.is_enabled: return msg = message_generator() msg = cls.MESSAGE_TAG + " " + msg # Log with caller's frame info logger.log(logging.DEBUG, msg, stacklevel=2) # If a dedicated memory buffer is configured, also store the message there if cls.log_buffer is not None: msg_formatted = format_log_message( logger, logging.DEBUG, msg, cls.log_formatter, stacklevel=2, ) cls.log_buffer.write(msg_formatted + "\n") if cls.verbose: print(msg_formatted)
@dataclass
class TraceLog:
    """A trace log, represented as a list of log lines."""

    log_lines: list[str]

    def save_log(self, path: str) -> None:
        """
        Writes the log to a text file, one log line per file line.

        :param path: the path of the file to write (overwritten if it exists)
        """
        with open(path, "w") as f:
            f.writelines(line + "\n" for line in self.log_lines)

    def print_log(self) -> None:
        """Prints all log lines to stdout."""
        for line in self.log_lines:
            print(line)

    def get_full_log(self) -> str:
        """:return: the log lines joined by newline characters"""
        return "\n".join(self.log_lines)

    def reduce_log_to_messages(self) -> "TraceLog":
        """
        Removes logger names and function names from the log entries, such that each log message
        contains only the main text message itself (starting with the content after the logger's tag).

        :return: the result with reduced log messages
        """
        # Greedy match: everything up to and including the last occurrence of the tag is removed
        tag_prefix = re.compile(r".*" + re.escape(TraceLogger.MESSAGE_TAG))
        return TraceLog([tag_prefix.sub("", line) for line in self.log_lines])

    def filter_messages(
        self,
        required_messages: Sequence[str] = (),
        optional_messages: Sequence[str] = (),
        ignored_messages: Sequence[str] = (),
    ) -> "TraceLog":
        """
        Applies inclusion and/or exclusion filtering to the log messages.
        If either `required_messages` or `optional_messages` is non-empty, inclusion filtering is applied,
        i.e. only lines containing at least one of the given fragments are retained.
        If `ignored_messages` is non-empty, exclusion filtering is applied.
        If both inclusion and exclusion filtering are applied, the exclusion filtering takes precedence.

        :param required_messages: required message substrings to filter for; each message is required to appear
            at least once in a line that is not ignored (triggering exception otherwise)
        :param optional_messages: additional messages fragments to filter for; these are not required
        :param ignored_messages: message fragments that result in exclusion; takes precedence over
            `required_messages` and `optional_messages`
        :return: the result with reduced log messages
        :raises AssertionError: if one of the required messages was not found
        """
        required_found = [False] * len(required_messages)
        apply_inclusion = bool(required_messages) or bool(optional_messages)

        def retain_line(line: str) -> bool:
            if any(ignored in line for ignored in ignored_messages):
                return False
            if not apply_inclusion:
                return True
            # Check every required fragment (no early exit), such that a line which contains
            # several required fragments counts as an occurrence of each of them.
            has_required = False
            for i, required in enumerate(required_messages):
                if required in line:
                    required_found[i] = True
                    has_required = True
            return has_required or any(optional in line for optional in optional_messages)

        lines = [line for line in self.log_lines if retain_line(line)]

        # Raised explicitly (rather than via `assert`) so the check is not stripped under `python -O`
        if not all(required_found):
            raise AssertionError(
                "Not all types of required messages were found in the trace. Were log messages changed?",
            )
        return TraceLog(lines)
class TraceLoggerContext:
    """
    A context manager which enables the trace logger.

    Apart from enabling the logging, it can optionally create a memory log buffer, such that
    getting the trace log is not strictly dependent on the logging system.

    On exit, the trace logger state that was in effect when the context was entered is restored,
    such that contexts can be nested.
    """

    def __init__(
        self,
        enable_log_buffer: bool = True,
        log_format: str = "%(name)s:%(funcName)s - %(message)s",
    ) -> None:
        """
        :param enable_log_buffer: whether to enable the dedicated log buffer for trace logs, whose contents can
            be accessed via method `get_log`.
        :param log_format: the logger format string to use for the dedicated log buffer
        """
        self._enable_log_buffer = enable_log_buffer
        self._log_format: str = log_format
        self._log_buffer: StringIO | None = None
        # TraceLogger states (is_enabled, log_buffer, log_formatter) saved on entry, most recent last
        self._saved_states: list[tuple[bool, StringIO | None, logging.Formatter | None]] = []

    def __enter__(self) -> Self:
        # Remember the current global state, such that it can be reinstated on exit
        self._saved_states.append(
            (TraceLogger.is_enabled, TraceLogger.log_buffer, TraceLogger.log_formatter),
        )
        TraceLogger.is_enabled = True
        if self._enable_log_buffer:
            TraceLogger.log_buffer = StringIO()
            TraceLogger.log_formatter = logging.Formatter(self._log_format)
            # Keep our own reference, such that the log remains accessible via `get_log`
            self._log_buffer = TraceLogger.log_buffer
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):  # type: ignore
        if self._saved_states:
            is_enabled, log_buffer, log_formatter = self._saved_states.pop()
        else:
            # exit without a matching enter: fall back to the disabled default state
            is_enabled, log_buffer, log_formatter = False, None, None
        TraceLogger.is_enabled = is_enabled
        TraceLogger.log_buffer = log_buffer
        TraceLogger.log_formatter = log_formatter

    def get_log(self) -> TraceLog:
        """
        :return: the full trace log that was captured if `enable_log_buffer` was enabled at construction.
            The buffer content is split at newline characters, so a buffer whose content ends with
            a newline yields an empty string as the final log line.
        :raises Exception: if no log buffer is available
        """
        if self._log_buffer is None:
            raise Exception(
                "This method is only supported if the log buffer is enabled at construction",
            )
        return TraceLog(log_lines=self._log_buffer.getvalue().split("\n"))
def torch_param_hash(module: torch.nn.Module) -> str:
    """
    Computes a SHA-1 digest over the raw bytes of the module's parameters, visiting them in the
    order given by `module.parameters()`; parameters that do not require gradients are skipped.

    :param module: a torch module
    :return: a hex digest of the parameters of the module
    """
    import hashlib

    digest = hashlib.sha1()
    trainable_params = (p for p in module.parameters() if p.requires_grad)
    for tensor in trainable_params:
        # detach from the autograd graph and move to host memory before converting to numpy
        raw_bytes = tensor.detach().cpu().numpy().tobytes()
        digest.update(raw_bytes)
    return digest.hexdigest()
class TraceDeterminismTest:
    """Checks trace logs against pickled reference results stored in a base directory."""

    def __init__(
        self,
        base_path: Path,
        core_messages: Sequence[str] = (),
        ignored_messages: Sequence[str] = (),
        log_filename: str | None = None,
    ) -> None:
        """
        :param base_path: the directory where the reference results are stored (will be created if necessary)
        :param core_messages: message fragments that make up the core of a trace; if empty, all messages are
            considered core
        :param ignored_messages: message fragments to ignore in the trace log (if any); takes precedence over
            `core_messages`
        :param log_filename: the name of the log file to which results are to be written (if any)
        """
        base_path.mkdir(parents=True, exist_ok=True)
        self.base_path = base_path
        self.core_messages = core_messages
        self.ignored_messages = ignored_messages
        self.log_filename = log_filename

    @dataclass(kw_only=True)
    class Result:
        git_status: GitStatus
        log: TraceLog

    def check(
        self,
        current_log: TraceLog,
        name: str,
        create_reference_result: bool = False,
        pass_if_core_messages_unchanged: bool = False,
    ) -> None:
        """
        Checks the given log against the reference result for the given name.
        On failure, the compared logs are written to text files in the current working directory
        and the test is failed via `pytest.fail`.

        :param current_log: the result to check
        :param name: the name of the reference result; must be unique among all tests!
        :param create_reference_result: whether update the reference result with the given result
        :param pass_if_core_messages_unchanged: whether the check shall pass if the logs differ but
            core messages are configured and the core messages are unchanged
        """
        import pytest

        reference_path = self.base_path / f"{name}.pkl.bz2"
        git_state_now = git_status()

        if create_reference_result:
            dump_pickle(self.Result(git_status=git_state_now, log=current_log), reference_path)

        reference_result: TraceDeterminismTest.Result = load_pickle(
            reference_path,
        )
        reference_log = reference_result.log

        def reduce(log: TraceLog) -> TraceLog:
            # strip logger/function prefixes, then drop the ignored messages
            return log.reduce_log_to_messages().filter_messages(
                ignored_messages=self.ignored_messages,
            )

        current_reduced = reduce(current_log)
        reference_reduced = reduce(reference_log)

        # logs which are written to files if the check fails, with their file name suffixes
        dumps: list[tuple[TraceLog, str]] = [
            (reference_reduced, "expected"),
            (current_reduced, "current"),
            (reference_log, "expected_full"),
            (current_log, "current_full"),
        ]

        # without configured core messages, the entire reduced log is considered core
        current_core = current_reduced
        reference_core = reference_reduced
        if self.core_messages:
            current_core = current_reduced.filter_messages(
                required_messages=self.core_messages,
            )
            reference_core = reference_reduced.filter_messages(
                required_messages=self.core_messages,
            )
            dumps.extend(
                [
                    (reference_core, "expected_core"),
                    (current_core, "current_core"),
                ],
            )

        if current_reduced.get_full_log() == reference_reduced.get_full_log():
            status_passed = True
            status_message = "OK"
        else:
            core_messages_unchanged = (
                len(self.core_messages) > 0
                and current_core.get_full_log() == reference_core.get_full_log()
            )
            status_passed = core_messages_unchanged and pass_if_core_messages_unchanged

            if status_passed:
                status_message = "OK (core messages unchanged)"
            else:
                # save files for comparison
                saved_paths = []
                for trace, suffix in dumps:
                    target = os.path.abspath(f"determinism_{name}_{suffix}.txt")
                    trace.save_log(target)
                    saved_paths.append(target)
                paths_str = "\n".join(saved_paths)

                # compute diff; only its first lines are included in the message
                num_diff_lines_to_show = 30
                diff = difflib.unified_diff(
                    reference_reduced.log_lines,
                    current_reduced.log_lines,
                    fromfile="expected.txt",
                    tofile="current.txt",
                    lineterm="",
                )
                diff_head = "".join(
                    diff_line + "\n" for _, diff_line in zip(range(num_diff_lines_to_show), diff)
                )

                main_message = (
                    f"Please inspect the changes by diffing the log files:\n{paths_str}\n"
                    f"If the changes are OK, enable the `create_reference_result` flag temporarily, "
                    "rerun the test and then commit the updated reference file.\n\nHere's the first part of the diff:\n"
                ) + diff_head

                if core_messages_unchanged:
                    status_message = (
                        "The behaviour log has changed, but the core messages are still the same (so this "
                        f"probably isn't an issue). {main_message}"
                    )
                else:
                    status_message = (
                        f"The behaviour log has changed; even the core messages are different. {main_message}"
                    )

        # write log message
        if self.log_filename:
            hr = "-" * 100
            with open(self.log_filename, "a") as f:
                f.write(f"\n\n{hr}\nName: {name}\n")
                f.write(f"Reference state: {reference_result.git_status}\n")
                f.write(f"Current state: {git_state_now}\n")
                f.write(f"Test result: {status_message}\n")

        if not status_passed:
            pytest.fail(status_message)