Source code for tianshou.data.stats
import logging
from collections.abc import Sequence
from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional
import numpy as np
from tianshou.utils.print import DataclassPPrintMixin
if TYPE_CHECKING:
from tianshou.algorithm.algorithm_base import TrainingStats
from tianshou.data import CollectStats, CollectStatsBase
log = logging.getLogger(__name__)
[docs]
@dataclass(kw_only=True)
class SequenceSummaryStats(DataclassPPrintMixin):
"""A data structure for storing the statistics of a sequence."""
mean: float
std: float
max: float
min: float
[docs]
@classmethod
def from_sequence(cls, sequence: Sequence[float | int] | np.ndarray) -> "SequenceSummaryStats":
if len(sequence) == 0:
return cls(mean=0.0, std=0.0, max=0.0, min=0.0)
if hasattr(sequence, "shape") and len(sequence.shape) > 1:
log.warning(
f"Sequence has shape {sequence.shape}, but only 1D sequences are supported. "
"Stats will be computed from the flattened sequence. For computing stats "
"for each dimension consider using the function `compute_dim_to_summary_stats`.",
)
return cls(
mean=float(np.mean(sequence)),
std=float(np.std(sequence)),
max=float(np.max(sequence)),
min=float(np.min(sequence)),
)
[docs]
@classmethod
def from_single_value(cls, value: float | int) -> "SequenceSummaryStats":
return cls(mean=value, std=0.0, max=value, min=value)
[docs]
def compute_dim_to_summary_stats(
arr: Sequence[Sequence[float]] | np.ndarray,
) -> dict[int, SequenceSummaryStats]:
"""Compute summary statistics for each dimension of a sequence.
:param arr: a 2-dim arr (or sequence of sequences) from which to compute the statistics.
:return: A dictionary of summary statistics for each dimension.
"""
stats = {}
for dim, seq in enumerate(arr):
stats[dim] = SequenceSummaryStats.from_sequence(seq)
return stats
[docs]
@dataclass(kw_only=True)
class TimingStats(DataclassPPrintMixin):
"""A data structure for storing timing statistics."""
total_time: float = 0.0
"""The total time elapsed."""
train_time: float = 0.0
"""The total time elapsed for training (collecting samples plus model update)."""
train_time_collect: float = 0.0
"""The total time elapsed for collecting training transitions."""
train_time_update: float = 0.0
"""The total time elapsed for updating models."""
test_time: float = 0.0
"""The total time elapsed for testing models."""
update_speed: float = 0.0
"""The speed of updating (env_step per second)."""
[docs]
@dataclass(kw_only=True)
class InfoStats(DataclassPPrintMixin):
"""A data structure for storing information about the learning process."""
update_step: int
"""The total number of update steps that have been taken."""
best_score: float
"""The best score over the test results. The one with the highest score will be considered the best model."""
best_reward: float
"""The best reward over the test results."""
best_reward_std: float
"""Standard deviation of the best reward over the test results."""
train_step: int
"""The total collected step of training collector."""
train_episode: int
"""The total collected episode of training collector."""
test_step: int
"""The total collected step of test collector."""
test_episode: int
"""The total collected episode of test collector."""
timing: TimingStats
"""The timing statistics."""
[docs]
@dataclass(kw_only=True)
class EpochStats(DataclassPPrintMixin):
"""A data structure for storing epoch statistics."""
epoch: int
"""The current epoch."""
train_collect_stat: Optional["CollectStatsBase"]
"""The statistics of the last call to the training collector."""
test_collect_stat: Optional["CollectStats"]
"""The statistics of the last call to the test collector."""
training_stat: Optional["TrainingStats"]
"""The statistics of the last model update step.
Can be None if no model update is performed, typically in the last training iteration."""
info_stat: InfoStats
"""The information of the collector."""