import difflib
import inspect
import os
import re
import time
from collections.abc import Callable, Sequence
from dataclasses import dataclass
from io import StringIO
from pathlib import Path
from typing import Self
import torch
from sensai.util import logging
from sensai.util.git import GitStatus, git_status
from sensai.util.pickle import dump_pickle, load_pickle
[docs]
class TraceLogger:
"""Supports the collection of behavioural trace logs, which can, in particular, be used for determinism tests."""
is_enabled = False
"""
whether the trace logger is enabled.
NOTE: The preferred way to enable this is via the context manager.
"""
verbose = False
"""
whether to print trace log messages to stdout.
"""
MESSAGE_TAG = "[TRACE]"
"""
a tag which is added at the beginning of log messages generated by this logger
"""
LOG_LEVEL = logging.DEBUG
log_buffer: StringIO | None = None
log_formatter: logging.Formatter | None = None
[docs]
@classmethod
def log(cls, logger: logging.Logger, message_generator: Callable[[], str]) -> None:
"""
Logs a message intended for tracing agent-env interaction, which is enabled via
`TraceAgentEnvLoggerContext`.
:param logger: the logger to use for the actual logging
:param message_generator: function which generates the log message (which may be expensive);
if logging is disabled, the function will not be called.
"""
if not cls.is_enabled:
return
msg = message_generator()
msg = cls.MESSAGE_TAG + " " + msg
# Log with caller's frame info
logger.log(logging.DEBUG, msg, stacklevel=2)
# If a dedicated memory buffer is configured, also store the message there
if cls.log_buffer is not None:
msg_formatted = format_log_message(
logger,
logging.DEBUG,
msg,
cls.log_formatter,
stacklevel=2,
)
cls.log_buffer.write(msg_formatted + "\n")
if cls.verbose:
print(msg_formatted)
[docs]
@dataclass
class TraceLog:
log_lines: list[str]
[docs]
def save_log(self, path: str) -> None:
with open(path, "w") as f:
for line in self.log_lines:
f.write(line + "\n")
[docs]
def print_log(self) -> None:
for line in self.log_lines:
print(line)
[docs]
def get_full_log(self) -> str:
return "\n".join(self.log_lines)
[docs]
def reduce_log_to_messages(self) -> "TraceLog":
"""
Removes logger names and function names from the log entries, such that each log message
contains only the main text message itself (starting with the content after the logger's tag).
:return: the result with reduced log messages
"""
lines = []
tag = re.escape(TraceLogger.MESSAGE_TAG)
for line in self.log_lines:
lines.append(re.sub(r".*" + tag, "", line))
return TraceLog(lines)
[docs]
def filter_messages(
self,
required_messages: Sequence[str] = (),
optional_messages: Sequence[str] = (),
ignored_messages: Sequence[str] = (),
) -> "TraceLog":
"""
Applies inclusion and or exclusion filtering to the log messages.
If either `required_messages` or `optional_messages` is empty, inclusion filtering is applied.
If `ignored_messages` is empty, exclusion filtering is applied.
If both inclusion and exclusion filtering are applied, the exclusion filtering takes precedence.
:param required_messages: required message substrings to filter for; each message is required to appear at least once
(triggering exception otherwise)
:param optional_messages: additional messages fragments to filter for; these are not required
:param ignored_messages: message fragments that result in exclusion; takes precedence over
`required_messages` and `optional_messages`
:return: the result with reduced log messages
"""
import numpy as np
required_message_counters = np.zeros(len(required_messages))
def retain_line(line: str) -> bool:
for ignored_message in ignored_messages:
if ignored_message in line:
return False
if required_messages or optional_messages:
for i, main_message in enumerate(required_messages):
if main_message in line:
required_message_counters[i] += 1
return True
return any(add_message in line for add_message in optional_messages)
else:
return True
lines = []
for line in self.log_lines:
if retain_line(line):
lines.append(line)
assert np.all(
required_message_counters > 0,
), "Not all types of required messages were found in the trace. Were log messages changed?"
return TraceLog(lines)
class TraceLoggerContext:
    """
    A context manager which enables the trace logger.
    Apart from enabling the logging, it can optionally create a memory log buffer, such that
    getting the trace log is not strictly dependent on the logging system.
    On exit, the `TraceLogger` settings that were in place before entering are restored, such that
    contexts can be nested without an inner context disabling an outer one.
    """

    def __init__(
        self,
        enable_log_buffer: bool = True,
        log_format: str = "%(name)s:%(funcName)s - %(message)s",
    ) -> None:
        """
        :param enable_log_buffer: whether to enable the dedicated log buffer for trace logs, whose contents
            can, within the context of this manager, be accessed via method `get_log`.
        :param log_format: the logger format string to use for the dedicated log buffer
        """
        self._enable_log_buffer = enable_log_buffer
        self._log_format: str = log_format
        self._log_buffer: StringIO | None = None
        # stack of (is_enabled, log_buffer, log_formatter) states saved on entry and restored on exit
        self._saved_states: list[tuple[bool, StringIO | None, logging.Formatter | None]] = []

    def __enter__(self) -> Self:
        self._saved_states.append(
            (TraceLogger.is_enabled, TraceLogger.log_buffer, TraceLogger.log_formatter),
        )
        TraceLogger.is_enabled = True
        if self._enable_log_buffer:
            TraceLogger.log_buffer = StringIO()
            TraceLogger.log_formatter = logging.Formatter(self._log_format)
            self._log_buffer = TraceLogger.log_buffer
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: object,
    ) -> None:
        # restore the previous state instead of unconditionally disabling, such that an enclosing
        # context (or its buffer) keeps working after a nested context exits
        TraceLogger.is_enabled, TraceLogger.log_buffer, TraceLogger.log_formatter = self._saved_states.pop()

    def get_log(self) -> TraceLog:
        """
        :return: the full trace log that was captured if `enable_log_buffer` was enabled at construction.
            Since every buffered message is terminated by a newline, the last entry is an empty string.
        :raises Exception: if the log buffer was not enabled at construction or the context was never entered
        """
        if self._log_buffer is None:
            raise Exception(
                "This method is only supported if the log buffer is enabled at construction",
            )
        # NOTE: the trailing empty entry produced by split is deliberately kept; stored reference logs
        # were presumably created the same way, so dropping it would make comparisons fail
        return TraceLog(log_lines=self._log_buffer.getvalue().split("\n"))
def torch_param_hash(module: torch.nn.Module) -> str:
    """
    Computes a hash of the parameters of the given module; parameters not requiring gradients are ignored.

    Parameters are hashed in the order given by `module.parameters()`. Parameters whose dtype has no
    numpy counterpart (e.g. bfloat16) are hashed via their raw bytes.

    :param module: a torch module
    :return: a hex digest of the parameters of the module
    """
    import hashlib

    hasher = hashlib.sha1()
    for param in module.parameters():
        if not param.requires_grad:
            continue
        tensor = param.detach().cpu()
        try:
            data = tensor.numpy().tobytes()
        except TypeError:
            # numpy has no equivalent dtype: reinterpret the (contiguous, flattened) data as bytes instead;
            # supported dtypes keep using the numpy path so that previously computed hashes stay valid
            data = tensor.contiguous().reshape(-1).view(torch.uint8).numpy().tobytes()
        hasher.update(data)
    return hasher.hexdigest()
class TraceDeterminismTest:
    """
    Checks trace logs against reference results stored on disk, such that behavioural changes are detected;
    a failed check fails the current pytest test.
    """

    def __init__(
        self,
        base_path: Path | str,
        core_messages: Sequence[str] = (),
        ignored_messages: Sequence[str] = (),
        log_filename: str | None = None,
    ) -> None:
        """
        :param base_path: the directory where the reference results are stored (will be created if necessary)
        :param core_messages: message fragments that make up the core of a trace; if empty, all messages are considered core
        :param ignored_messages: message fragments to ignore in the trace log (if any); takes precedence over
            `core_messages`
        :param log_filename: the name of the log file to which results are to be appended (if any)
        """
        base_path = Path(base_path)
        base_path.mkdir(parents=True, exist_ok=True)
        self.base_path = base_path
        self.core_messages = core_messages
        self.ignored_messages = ignored_messages
        self.log_filename = log_filename

    @dataclass(kw_only=True)
    class Result:
        # NOTE: instances are pickled into the reference files, so the (nested) class name and
        # its fields must remain stable for existing reference files to remain loadable
        git_status: GitStatus
        log: TraceLog

    def check(
        self,
        current_log: TraceLog,
        name: str,
        create_reference_result: bool = False,
        pass_if_core_messages_unchanged: bool = False,
    ) -> None:
        """
        Checks the given log against the reference result for the given name.

        Both logs are reduced to their messages and filtered via `ignored_messages` before being compared.
        If they differ, the logs are saved to text files in the current working directory for inspection.

        :param current_log: the result to check
        :param name: the name of the reference result; must be unique among all tests!
        :param create_reference_result: whether to update the reference result with the given result
            (before checking against it)
        :param pass_if_core_messages_unchanged: whether to let the check pass if the logs differ but the
            core messages are unchanged; has no effect if no `core_messages` were configured
        :raises pytest.fail.Exception: (via `pytest.fail`) if the check fails or no reference result exists
        """
        import pytest

        reference_result_path = self.base_path / f"{name}.pkl.bz2"
        current_git_status = git_status()
        if create_reference_result:
            current_result = self.Result(git_status=current_git_status, log=current_log)
            dump_pickle(current_result, reference_result_path)
        try:
            reference_result: TraceDeterminismTest.Result = load_pickle(reference_result_path)
        except FileNotFoundError:
            pytest.fail(
                f"No reference result found at {reference_result_path}. Enable the `create_reference_result` "
                "flag temporarily, rerun the test and then commit the created reference file.",
            )
        reference_log = reference_result.log

        current_log_reduced = current_log.reduce_log_to_messages().filter_messages(
            ignored_messages=self.ignored_messages,
        )
        reference_log_reduced = reference_log.reduce_log_to_messages().filter_messages(
            ignored_messages=self.ignored_messages,
        )
        # logs to be saved for inspection (with file name suffixes) in case of a failure
        results: list[tuple[TraceLog, str]] = [
            (reference_log_reduced, "expected"),
            (current_log_reduced, "current"),
            (reference_log, "expected_full"),
            (current_log, "current_full"),
        ]
        if self.core_messages:
            current_core_log = current_log_reduced.filter_messages(required_messages=self.core_messages)
            reference_core_log = reference_log_reduced.filter_messages(required_messages=self.core_messages)
            results.extend([(reference_core_log, "expected_core"), (current_core_log, "current_core")])
        else:
            current_core_log = current_log_reduced
            reference_core_log = reference_log_reduced

        if current_log_reduced.get_full_log() == reference_log_reduced.get_full_log():
            status_passed = True
            status_message = "OK"
        else:
            core_messages_unchanged = (
                len(self.core_messages) > 0
                and current_core_log.get_full_log() == reference_core_log.get_full_log()
            )
            status_passed = core_messages_unchanged and pass_if_core_messages_unchanged
            if status_passed:
                status_message = "OK (core messages unchanged)"
            else:
                paths_str = "\n".join(self._save_logs_for_comparison(name, results))
                main_message = (
                    f"Please inspect the changes by diffing the log files:\n{paths_str}\n"
                    f"If the changes are OK, enable the `create_reference_result` flag temporarily, "
                    "rerun the test and then commit the updated reference file.\n\nHere's the first part of the diff:\n"
                ) + self._diff_excerpt(reference_log_reduced, current_log_reduced)
                if core_messages_unchanged:
                    status_message = (
                        "The behaviour log has changed, but the core messages are still the same (so this "
                        f"probably isn't an issue). {main_message}"
                    )
                else:
                    status_message = f"The behaviour log has changed; even the core messages are different. {main_message}"

        if self.log_filename:
            self._append_to_log_file(name, reference_result.git_status, current_git_status, status_message)
        if not status_passed:
            pytest.fail(status_message)

    @staticmethod
    def _save_logs_for_comparison(name: str, results: list[tuple[TraceLog, str]]) -> list[str]:
        """Saves each log to `determinism_{name}_{suffix}.txt` in the working directory; returns the absolute paths."""
        paths = []
        for log, suffix in results:
            path = os.path.abspath(f"determinism_{name}_{suffix}.txt")
            log.save_log(path)
            paths.append(path)
        return paths

    @staticmethod
    def _diff_excerpt(expected_log: TraceLog, current_log: TraceLog, max_lines: int = 30) -> str:
        """Returns the first `max_lines` lines of the unified diff between the two logs, each terminated by a newline."""
        from itertools import islice

        diff_lines = difflib.unified_diff(
            expected_log.log_lines,
            current_log.log_lines,
            fromfile="expected.txt",
            tofile="current.txt",
            lineterm="",
        )
        return "".join(line + "\n" for line in islice(diff_lines, max_lines))

    def _append_to_log_file(
        self,
        name: str,
        reference_git_status: GitStatus,
        current_git_status: GitStatus,
        status_message: str,
    ) -> None:
        """Appends a record of the check's outcome to the configured log file."""
        with open(self.log_filename, "a", encoding="utf-8") as f:
            hr = "-" * 100
            f.write(f"\n\n{hr}\nName: {name}\n")
            f.write(f"Reference state: {reference_git_status}\n")
            f.write(f"Current state: {current_git_status}\n")
            f.write(f"Test result: {status_message}\n")