"""This module implements :class:`Batch`, a flexible data structure for
handling heterogeneous data in reinforcement learning algorithms. Such a data structure
is needed since RL algorithms differ widely in the conceptual fields that they need.
`Batch` is the main data carrier in Tianshou. It bears some similarities to
`TensorDict <https://github.com/pytorch/tensordict>`_
that is used for a similar purpose in `pytorch-rl <https://github.com/pytorch/rl>`_.
The main differences between the two are that `Batch` can hold arbitrary objects (and not just torch tensors),
and that Tianshou implements `BatchProtocol` for enabling type checking and autocompletion (more on that below).
The `Batch` class is designed to store and manipulate collections of data with
varying types and structures. It strikes a balance between flexibility and type safety, the latter mainly
achieved through the use of protocols. One can thing of it as a mixture of a dictionary and an array,
as it has both key-value pairs and nesting, while also having a shape, being indexable and sliceable.
Key features of the `Batch` class include:
1. Flexible data storage: Can hold numpy arrays, torch tensors, scalars, and nested Batch objects.
2. Dynamic attribute access: Allows setting and accessing data using attribute notation (e.g., `batch.observation`).
This allows for type-safe and readable code and enables IDE autocompletion. See comments on `BatchProtocol` below.
3. Indexing and slicing: Supports numpy-like indexing and slicing operations. The slicing is extended to nested
Batch objects and torch Distributions.
4. Batch operations: Provides methods for splitting, shuffling, concatenating and stacking multiple Batch objects.
5. Data type conversion: Offers methods to convert data between numpy arrays and torch tensors.
6. Value transformations: Allows applying functions to all values in the Batch recursively.
7. Analysis utilities: Provides methods for checking for missing values, dropping entries with missing values,
and others.
Since we want to keep `Batch` flexible and not fix a specific set of fields or their types,
we don't have fixed interfaces for actual `Batch` objects that are used throughout
tianshou (such interfaces could be dataclasses, for example). However, we still want to enable
IDE autocompletion and type checking for `Batch` objects. To achieve this, we rely on dynamic duck typing
by using `Protocol`. The :class:`BatchProtocol` defines the interface that all `Batch` objects should adhere to,
and its various implementations (like :class:`~.types.ActBatchProtocol` or :class:`~.types.RolloutBatchProtocol`) define the specific
fields that are expected in the respective `Batch` objects. The protocols are then used as type hints
throughout the codebase. Protocols can't be instantiated, but we can cast to them.
For example, we "instantiate" an `ActBatchProtocol` with something like:
>>> act_batch = cast(ActBatchProtocol, Batch(act=my_action))
The users can decide for themselves how to structure their `Batch` objects, and can opt in to the
`BatchProtocol` style to enable type checking and autocompletion. Opting out will have no effect on
the functionality.
"""
import pprint
import warnings
from collections.abc import Callable, Collection, Iterable, Iterator, KeysView, Sequence
from copy import deepcopy
from numbers import Number
from types import EllipsisType
from typing import (
Any,
Literal,
Protocol,
Self,
TypeVar,
Union,
cast,
overload,
runtime_checkable,
)
import numpy as np
import pandas as pd
import torch
from deepdiff import DeepDiff
from sensai.util import logging
from torch.distributions import Categorical, Distribution, Independent, Normal
_SingleIndexType = slice | int | EllipsisType
IndexType = np.ndarray | _SingleIndexType | Sequence[_SingleIndexType]
TBatch = TypeVar("TBatch", bound="BatchProtocol")
TDistribution = TypeVar("TDistribution", bound=Distribution)
T = TypeVar("T")
TArr = torch.Tensor | np.ndarray
TObsArr = torch.Tensor | np.ndarray
log = logging.getLogger(__name__)
def _is_batch_set(obj: Any) -> bool:
# Batch set is a list/tuple of dict/Batch objects,
# or 1-D np.ndarray with object type,
# where each element is a dict/Batch object
if isinstance(obj, np.ndarray): # most often case
# "for element in obj" will just unpack the first dimension,
# but obj.tolist() will flatten ndarray of objects
# so do not use obj.tolist()
if obj.shape == ():
return False
return obj.dtype == object and all(isinstance(element, dict | Batch) for element in obj)
return (
isinstance(obj, list | tuple)
and len(obj) > 0
and all(isinstance(element, dict | Batch) for element in obj)
)
def _is_scalar(value: Any) -> bool:
# check if the value is a scalar
# 1. python bool object, number object: isinstance(value, Number)
# 2. numpy scalar: isinstance(value, np.generic)
# 3. python object rather than dict / Batch / tensor
# the check of dict / Batch is omitted because this only checks a value.
# a dict / Batch will eventually check their values
if isinstance(value, torch.Tensor):
return value.numel() == 1 and not value.shape
# np.asanyarray will cause dead loop in some cases
return np.isscalar(value)
def _is_number(value: Any) -> bool:
# isinstance(value, Number) checks 1, 1.0, np.int(1), np.float(1.0), etc.
# isinstance(value, np.nummber) checks np.int32(1), np.float64(1.0), etc.
# isinstance(value, np.bool_) checks np.bool_(True), etc.
# similar to np.isscalar but np.isscalar('st') returns True
return isinstance(value, Number | np.number | np.bool_)
def _to_array_with_correct_type(obj: Any) -> np.ndarray:
if isinstance(obj, np.ndarray) and issubclass(obj.dtype.type, np.bool_ | np.number):
return obj # most often case
# convert the value to np.ndarray
# convert to object obj type if neither bool nor number
# raises an exception if array's elements are tensors themselves
try:
obj_array = np.asanyarray(obj)
except ValueError:
obj_array = np.asanyarray(obj, dtype=object)
if not issubclass(obj_array.dtype.type, np.bool_ | np.number):
obj_array = obj_array.astype(object)
if obj_array.dtype == object:
# scalar ndarray with object obj type is very annoying
# a=np.array([np.array({}, dtype=object), np.array({}, dtype=object)])
# a is not array([{}, {}], dtype=object), and a[0]={} results in
# something very strange:
# array([{}, array({}, dtype=object)], dtype=object)
if not obj_array.shape:
obj_array = obj_array.item(0)
elif all(isinstance(arr, np.ndarray) for arr in obj_array.reshape(-1)):
return obj_array # various length, np.array([[1], [2, 3], [4, 5, 6]])
elif any(isinstance(arr, torch.Tensor) for arr in obj_array.reshape(-1)):
raise ValueError("Numpy arrays of tensors are not supported yet.")
return obj_array
[docs]
def create_value(
inst: Any,
size: int,
stack: bool = True,
) -> Union["Batch", np.ndarray, torch.Tensor]:
"""Create empty place-holders according to inst's shape.
:param stack: whether to stack or to concatenate. E.g. if inst has shape of
(3, 5), size = 10, stack=True returns an np.array with shape of (10, 3, 5),
otherwise (10, 5)
"""
has_shape = isinstance(inst, np.ndarray | torch.Tensor)
is_scalar = _is_scalar(inst)
if not stack and is_scalar:
# should never hit since it has already checked in Batch.cat_ , here we do not
# consider scalar types, following the behavior of numpy which does not support
# concatenation of zero-dimensional arrays (scalars)
raise TypeError(f"cannot concatenate with {inst} which is scalar")
if has_shape:
shape = (size, *inst.shape) if stack else (size, *inst.shape[1:])
if isinstance(inst, np.ndarray):
target_type = (
inst.dtype.type if issubclass(inst.dtype.type, np.bool_ | np.number) else object
)
return np.full(shape, fill_value=None if target_type is object else 0, dtype=target_type)
if isinstance(inst, torch.Tensor):
return torch.full(shape, fill_value=0, device=inst.device, dtype=inst.dtype)
if isinstance(inst, dict | Batch):
zero_batch = Batch()
for key, val in inst.items():
zero_batch.__dict__[key] = create_value(val, size, stack=stack)
return zero_batch
if is_scalar:
return create_value(np.asarray(inst), size, stack=stack)
# fall back to object
return np.array([None for _ in range(size)], object)
def _assert_type_keys(keys: Iterable[str]) -> None:
assert all(isinstance(key, str) for key in keys), f"keys should all be string, but got {keys}"
def _parse_value(obj: Any) -> Union["Batch", np.ndarray, torch.Tensor] | None:
if isinstance(obj, Batch): # most often case
return obj
if (
(isinstance(obj, np.ndarray) and issubclass(obj.dtype.type, np.bool_ | np.number))
or isinstance(obj, torch.Tensor)
or obj is None
): # third often case
return obj
if _is_number(obj): # second often case, but it is more time-consuming
return np.asanyarray(obj)
if isinstance(obj, dict):
return Batch(obj)
if (
not isinstance(obj, np.ndarray)
and isinstance(obj, Collection)
and len(obj) > 0
and all(isinstance(element, torch.Tensor) for element in obj)
):
try:
obj = cast(list[torch.Tensor], obj)
return torch.stack(obj)
except RuntimeError as exception:
raise TypeError(
"Batch does not support non-stackable iterable"
" of torch.Tensor as unique value yet.",
) from exception
if _is_batch_set(obj):
obj = Batch(obj) # list of dict / Batch
else:
# None, scalar, normal obj list (main case)
# or an actual list of objects
try:
obj = _to_array_with_correct_type(obj)
except ValueError as exception:
raise TypeError(
"Batch does not support heterogeneous list/tuple of tensors as unique value yet.",
) from exception
return obj
[docs]
def alloc_by_keys_diff(
meta: "BatchProtocol",
batch: "BatchProtocol",
size: int,
stack: bool = True,
) -> None:
"""Creates place-holders inside meta for keys that are in batch but not in meta.
This mainly is an internal method, use it only if you know what you are doing.
"""
for key in batch.get_keys():
if key in meta.get_keys():
if isinstance(meta[key], Batch) and isinstance(batch[key], Batch):
alloc_by_keys_diff(meta[key], batch[key], size, stack)
elif isinstance(meta[key], Batch) and len(meta[key].get_keys()) == 0:
meta[key] = create_value(batch[key], size, stack)
else:
meta[key] = create_value(batch[key], size, stack)
[docs]
class ProtocolCalledException(Exception):
"""The methods of a Protocol should never be called.
Currently, no static type checker actually verifies that a class that inherits
from a Protocol does in fact provide the correct interface. Thus, it may happen
that a method of the protocol is called accidentally (this is an
implementation error). The normal error for that is a somewhat cryptic
AttributeError, wherefore we instead raise this custom exception in the
BatchProtocol.
Finally and importantly: using this in BatchProtocol makes mypy verify the fields
in the various sub-protocols and thus renders is MUCH more useful!
"""
[docs]
def get_sliced_dist(dist: TDistribution, index: IndexType) -> TDistribution:
"""Slice a distribution object by the given index."""
if isinstance(dist, Categorical):
return Categorical(probs=dist.probs[index]) # type: ignore[return-value]
if isinstance(dist, Normal):
return Normal(loc=dist.loc[index], scale=dist.scale[index]) # type: ignore[return-value]
if isinstance(dist, Independent):
return Independent(
get_sliced_dist(dist.base_dist, index),
dist.reinterpreted_batch_ndims,
) # type: ignore[return-value]
else:
raise NotImplementedError(f"Unsupported distribution for slicing: {dist}")
[docs]
def get_len_of_dist(dist: Distribution) -> int:
"""Return the length (typically batch size) of a distribution object."""
if len(dist.batch_shape) == 0:
raise TypeError(f"scalar Distribution has no length: {dist=}")
return dist.batch_shape[0]
[docs]
def dist_to_atleast_2d(dist: TDistribution) -> TDistribution:
"""Convert a distribution to at least 2D, such that the `batch_shape` attribute has a len of at least 1."""
if len(dist.batch_shape) > 0:
return dist
if isinstance(dist, Categorical):
return Categorical(probs=dist.probs.unsqueeze(0)) # type: ignore[return-value]
elif isinstance(dist, Normal):
return Normal(loc=dist.loc.unsqueeze(0), scale=dist.scale.unsqueeze(0)) # type: ignore[return-value]
elif isinstance(dist, Independent):
return Independent(
dist_to_atleast_2d(dist.base_dist),
dist.reinterpreted_batch_ndims,
) # type: ignore[return-value]
else:
raise NotImplementedError(f"Unsupported distribution for conversion to 2D: {type(dist)}")
# Note: This is implemented as a protocol because the interface
# of Batch is always extended by adding new fields. Having a hierarchy of
# protocols building off this one allows for type safety and IDE support despite
# the dynamic nature of Batch
[docs]
@runtime_checkable
class BatchProtocol(Protocol):
"""The internal data structure in Tianshou.
Batch is a kind of supercharged array (of temporal data) stored individually in a
(recursive) dictionary of objects that can be either numpy arrays, torch tensors, or
batches themselves. It is designed to make it extremely easily to access, manipulate
and set partial view of the heterogeneous data conveniently.
"""
@property
def shape(self) -> list[int]:
raise ProtocolCalledException
# NOTE: even though setattr and getattr are defined for any object, we need
# to explicitly define them for the BatchProtocol, since otherwise mypy will
# complain about new fields being added dynamically. For example, things like
# `batch.new_field = ...` followed by using `batch.new_field` become type errors
# if getattr and setattr are missing in the BatchProtocol.
#
# For the moment, tianshou relies on this kind of dynamic-field-addition
# in many, many places. In principle, it would be better to construct new
# objects with new combinations of fields instead of mutating existing ones - the
# latter is error-prone and can't properly be expressed with types. May be in a
# future, rather different version of tianshou it would be feasible to have stricter
# typing. Then the need for Protocols would in fact disappear
def __setattr__(self, key: str, value: Any) -> None:
raise ProtocolCalledException
def __getattr__(self, key: str) -> Any:
raise ProtocolCalledException
def __iter__(self) -> Iterator[Self]:
raise ProtocolCalledException
@overload
def __getitem__(self, index: str) -> Any:
raise ProtocolCalledException
@overload
def __getitem__(self, index: IndexType) -> Self:
raise ProtocolCalledException
def __getitem__(self, index: str | IndexType) -> Any:
raise ProtocolCalledException
def __setitem__(self, index: str | IndexType, value: Any) -> None:
raise ProtocolCalledException
def __iadd__(self, other: Self | Number | np.number) -> Self:
raise ProtocolCalledException
def __add__(self, other: Self | Number | np.number) -> Self:
raise ProtocolCalledException
def __imul__(self, value: Number | np.number) -> Self:
raise ProtocolCalledException
def __mul__(self, value: Number | np.number) -> Self:
raise ProtocolCalledException
def __itruediv__(self, value: Number | np.number) -> Self:
raise ProtocolCalledException
def __truediv__(self, value: Number | np.number) -> Self:
raise ProtocolCalledException
def __repr__(self) -> str:
raise ProtocolCalledException
def __eq__(self, other: Any) -> bool:
raise ProtocolCalledException
[docs]
def to_numpy(self: Self) -> Self:
"""Change all torch.Tensor to numpy.ndarray and return a new Batch."""
raise ProtocolCalledException
[docs]
def to_numpy_(self) -> None:
"""Change all torch.Tensor to numpy.ndarray in-place."""
raise ProtocolCalledException
[docs]
def to_torch(
self: Self,
dtype: torch.dtype | None = None,
device: str | int | torch.device = "cpu",
) -> Self:
"""Change all numpy.ndarray to torch.Tensor and return a new Batch."""
raise ProtocolCalledException
[docs]
def to_torch_(
self,
dtype: torch.dtype | None = None,
device: str | int | torch.device = "cpu",
) -> None:
"""Change all numpy.ndarray to torch.Tensor in-place."""
raise ProtocolCalledException
[docs]
def cat_(self, batches: Self | Sequence[dict | Self]) -> None:
"""Concatenate a list of (or one) Batch objects into current batch."""
raise ProtocolCalledException
[docs]
@staticmethod
def cat(batches: Sequence[dict | TBatch]) -> TBatch:
"""Concatenate a list of Batch object into a single new batch.
For keys that are not shared across all batches, batches that do not
have these keys will be padded by zeros with appropriate shapes. E.g.
::
>>> a = Batch(a=np.zeros([3, 4]), common=Batch(c=np.zeros([3, 5])))
>>> b = Batch(b=np.zeros([4, 3]), common=Batch(c=np.zeros([4, 5])))
>>> c = Batch.cat([a, b])
>>> c.a.shape
(7, 4)
>>> c.b.shape
(7, 3)
>>> c.common.c.shape
(7, 5)
"""
raise ProtocolCalledException
[docs]
def stack_(self, batches: Sequence[dict | Self], axis: int = 0) -> None:
"""Stack a list of Batch object into current batch."""
raise ProtocolCalledException
[docs]
@staticmethod
def stack(batches: Sequence[dict | TBatch], axis: int = 0) -> TBatch:
"""Stack a list of Batch object into a single new batch.
For keys that are not shared across all batches, batches that do not
have these keys will be padded by zeros. E.g.
::
>>> a = Batch(a=np.zeros([4, 4]), common=Batch(c=np.zeros([4, 5])))
>>> b = Batch(b=np.zeros([4, 6]), common=Batch(c=np.zeros([4, 5])))
>>> c = Batch.stack([a, b])
>>> c.a.shape
(2, 4, 4)
>>> c.b.shape
(2, 4, 6)
>>> c.common.c.shape
(2, 4, 5)
.. note::
If there are keys that are not shared across all batches, ``stack``
with ``axis != 0`` is undefined, and will cause an exception.
"""
raise ProtocolCalledException
[docs]
def empty_(self, index: slice | IndexType | None = None) -> Self:
"""Return an empty Batch object with 0 or None filled.
If "index" is specified, it will only reset the specific indexed-data.
::
>>> data.empty_()
>>> print(data)
Batch(
a: array([[0., 0.],
[0., 0.]]),
b: array([None, None], dtype=object),
)
>>> b={'c': [2., 'st'], 'd': [1., 0.]}
>>> data = Batch(a=[False, True], b=b)
>>> data[0] = Batch.empty(data[1])
>>> data
Batch(
a: array([False, True]),
b: Batch(
c: array([None, 'st']),
d: array([0., 0.]),
),
)
"""
raise ProtocolCalledException
[docs]
@staticmethod
def empty(batch: TBatch, index: IndexType | None = None) -> TBatch:
"""Return an empty Batch object with 0 or None filled.
The shape is the same as the given Batch.
"""
raise ProtocolCalledException
[docs]
def update(self, batch: dict | Self | None = None, **kwargs: Any) -> None:
"""Update this batch from another dict/Batch."""
raise ProtocolCalledException
def __len__(self) -> int:
raise ProtocolCalledException
[docs]
def split(
self,
size: int,
shuffle: bool = True,
merge_last: bool = False,
) -> Iterator[Self]:
"""Split whole data into multiple small batches.
:param size: divide the data batch with the given size, but one
batch if the length of the batch is smaller than "size". Size of -1 means
the whole batch.
:param shuffle: randomly shuffle the entire data batch if it is
True, otherwise remain in the same. Default to True.
:param merge_last: merge the last batch into the previous one.
Default to False.
"""
raise ProtocolCalledException
[docs]
def to_dict(self, recurse: bool = True) -> dict[str, Any]:
raise ProtocolCalledException
[docs]
def to_list_of_dicts(self) -> list[dict[str, Any]]:
raise ProtocolCalledException
[docs]
def get_keys(self) -> KeysView:
raise ProtocolCalledException
[docs]
def set_array_at_key(
self,
seq: np.ndarray,
key: str,
index: IndexType | None = None,
default_value: float | None = None,
) -> None:
"""Set a sequence of values at a given key.
If `index` is not passed, the sequence must have the same length as the batch.
:param seq: the array of values to set.
:param key: the key to set the sequence at.
:param index: the indices to set the sequence at. If None, the sequence must have
the same length as the batch and will be set at all indices.
:param default_value: this only applies if `index` is passed and the key does not exist yet
in the batch. In that case, entries outside the passed index will be filled
with this default value.
Note that the array at the key will be of the same dtype as the passed sequence,
so `default_value` should be such that numpy can cast it to this dtype.
"""
raise ProtocolCalledException
[docs]
def isnull(self) -> Self:
"""Return a boolean mask of the same shape, indicating missing values."""
raise ProtocolCalledException
[docs]
def hasnull(self) -> bool:
"""Return whether the batch has missing values."""
raise ProtocolCalledException
[docs]
def dropnull(self) -> Self:
"""Return a batch where all items in which any value is null are dropped.
Note that it is not the same as just dropping the entries of the sequence.
For example, with
>>> b = Batch(a=[None, 2, 3, 4], b=[4, 5, None, 7])
>>> b.dropnull()
will result in
>>> Batch(a=[2, 4], b=[5, 7])
This logic is applied recursively to all nested batches. The result is
the same as if the batch was flattened, entries were dropped,
and then the batch was reshaped back to the original nested structure.
"""
...
@overload
def apply_values_transform(
self,
values_transform: Callable[[np.ndarray | torch.Tensor], Any],
) -> Self: ...
@overload
def apply_values_transform(
self,
values_transform: Callable,
inplace: Literal[True],
) -> None: ...
@overload
def apply_values_transform(
self,
values_transform: Callable[[np.ndarray | torch.Tensor], Any],
inplace: Literal[False],
) -> Self: ...
[docs]
def get(self, key: str, default: Any | None = None) -> Any:
raise ProtocolCalledException
[docs]
def pop(self, key: str, default: Any | None = None) -> Any:
raise ProtocolCalledException
[docs]
def to_at_least_2d(self) -> Self:
"""Ensures that all arrays and dists in the batch have at least 2 dimensions.
This is useful for ensuring that all arrays in the batch can be concatenated
along a new axis.
"""
raise ProtocolCalledException
[docs]
class Batch(BatchProtocol):
"""See :class:`~tianshou.data.batch.BatchProtocol`."""
__doc__ = BatchProtocol.__doc__
def __init__(
self,
batch_dict: dict
| BatchProtocol
| Sequence[dict | BatchProtocol]
| np.ndarray
| None = None,
copy: bool = False,
**kwargs: Any,
) -> None:
if copy:
batch_dict = deepcopy(batch_dict)
if batch_dict is not None:
if isinstance(batch_dict, dict | BatchProtocol):
_assert_type_keys(batch_dict.keys())
for batch_key, obj in batch_dict.items():
self.__dict__[batch_key] = _parse_value(obj)
elif _is_batch_set(batch_dict):
batch_dict = cast(Sequence[dict | BatchProtocol], batch_dict)
self.stack_(batch_dict)
if len(kwargs) > 0:
# TODO: that's a rather weird pattern, is it really needed?
# Feels like kwargs could be just merged into batch_dict in the beginning
self.__init__(kwargs, copy=copy) # type: ignore
[docs]
def to_dict(self, recursive: bool = True) -> dict[str, Any]:
result = {}
for k, v in self.__dict__.items():
if recursive and isinstance(v, Batch):
v = v.to_dict(recursive=recursive)
result[k] = v
return result
[docs]
def get_keys(self) -> KeysView:
return self.__dict__.keys()
[docs]
def get(self, key: str, default: Any | None = None) -> Any:
return self.__dict__.get(key, default)
[docs]
def pop(self, key: str, default: Any | None = None) -> Any:
return self.__dict__.pop(key, default)
[docs]
def to_list_of_dicts(self) -> list[dict[str, Any]]:
return [entry.to_dict() for entry in self]
def __setattr__(self, key: str, value: Any) -> None:
"""Set self.key = value."""
self.__dict__[key] = _parse_value(value)
def __getattr__(self, key: str) -> Any:
"""Return self.key. The "Any" return type is needed for mypy."""
return getattr(self.__dict__, key)
def __contains__(self, key: str) -> bool:
"""Return key in self."""
return key in self.__dict__
def __getstate__(self) -> dict[str, Any]:
"""Pickling interface.
Only the actual data are serialized for both efficiency and simplicity.
"""
state = {}
for batch_key, obj in self.items():
if isinstance(obj, Batch):
state[batch_key] = obj.__getstate__()
else:
state[batch_key] = obj
return state
def __setstate__(self, state: dict[str, Any]) -> None:
"""Unpickling interface.
At this point, self is an empty Batch instance that has not been
initialized, so it can safely be initialized by the pickle state.
"""
self.__init__(**state) # type: ignore
@overload
def __getitem__(self, index: str) -> Any: ...
@overload
def __getitem__(self, index: IndexType) -> Self: ...
def __getitem__(self, index: str | IndexType) -> Any:
"""Returns either the value of a key or a sliced Batch object."""
if isinstance(index, str):
return self.__dict__[index]
batch_items = self.items()
if len(batch_items) > 0:
new_batch = Batch()
sliced_obj: Any
for batch_key, obj in batch_items:
# None and empty Batches as values are added to any slice
if obj is None:
sliced_obj = None
elif isinstance(obj, Batch) and len(obj.get_keys()) == 0:
sliced_obj = Batch()
# We attempt slicing of a distribution. This is hacky, but presents an important special case
elif isinstance(obj, Distribution):
sliced_obj = get_sliced_dist(obj, index)
# All other objects are either array-like or Batch-like, so hopefully sliceable
# A batch should have no scalars
else:
sliced_obj = obj[index]
new_batch.__dict__[batch_key] = sliced_obj
return new_batch
raise IndexError("Cannot access item from empty Batch object.")
def __eq__(self, other: Any) -> bool:
if not isinstance(other, self.__class__):
return False
this_batch_no_torch_tensor = self.to_numpy()
other_batch_no_torch_tensor = other.to_numpy()
# DeepDiff 7.0.1 cannot compare 0-dimensional arrays
# so, we ensure with this transform that all array values have at least 1 dim
this_batch_no_torch_tensor.apply_values_transform(
values_transform=np.atleast_1d,
inplace=True,
)
other_batch_no_torch_tensor.apply_values_transform(
values_transform=np.atleast_1d,
inplace=True,
)
this_dict = this_batch_no_torch_tensor.to_dict(recursive=True)
other_dict = other_batch_no_torch_tensor.to_dict(recursive=True)
return not DeepDiff(this_dict, other_dict)
def __iter__(self) -> Iterator[Self]:
# TODO: empty batch raises an error on len and needs separate treatment, that's probably not a good idea
if len(self.__dict__) == 0:
yield from []
else:
for i in range(len(self)):
yield self[i]
def __setitem__(self, index: str | IndexType, value: Any) -> None:
"""Assign value to self[index]."""
value = _parse_value(value)
if isinstance(index, str):
self.__dict__[index] = value
return
if not isinstance(value, Batch):
raise ValueError(
"Batch does not supported tensor assignment. "
"Use a compatible Batch or dict instead.",
)
if not set(value.keys()).issubset(self.__dict__.keys()):
raise ValueError("Creating keys is not supported by item assignment.")
for key, val in self.items():
try:
self.__dict__[key][index] = value[key]
except KeyError:
if isinstance(val, Batch):
self.__dict__[key][index] = Batch()
elif isinstance(val, torch.Tensor) or (
isinstance(val, np.ndarray) and issubclass(val.dtype.type, np.bool_ | np.number)
):
self.__dict__[key][index] = 0
else:
self.__dict__[key][index] = None
def __iadd__(self, other: Self | Number | np.number) -> Self:
"""Algebraic addition with another Batch instance in-place."""
if isinstance(other, Batch):
for (batch_key, obj), value in zip(
self.__dict__.items(),
other.__dict__.values(),
strict=True,
): # TODO are keys consistent?
if isinstance(obj, Batch) and len(obj.get_keys()) == 0:
continue
self.__dict__[batch_key] += value
return self
if _is_number(other):
for batch_key, obj in self.items():
if isinstance(obj, Batch) and len(obj.get_keys()) == 0:
continue
self.__dict__[batch_key] += other
return self
raise TypeError("Only addition of Batch or number is supported.")
def __add__(self, other: Self | Number | np.number) -> Self:
"""Algebraic addition with another Batch instance out-of-place."""
return deepcopy(self).__iadd__(other)
def __imul__(self, value: Number | np.number) -> Self:
"""Algebraic multiplication with a scalar value in-place."""
assert _is_number(value), "Only multiplication by a number is supported."
for batch_key, obj in self.__dict__.items():
if isinstance(obj, Batch) and len(obj.get_keys()) == 0:
continue
self.__dict__[batch_key] *= value
return self
def __mul__(self, value: Number | np.number) -> Self:
"""Algebraic multiplication with a scalar value out-of-place."""
return deepcopy(self).__imul__(value)
def __itruediv__(self, value: Number | np.number) -> Self:
"""Algebraic division with a scalar value in-place."""
assert _is_number(value), "Only division by a number is supported."
for batch_key, obj in self.__dict__.items():
if isinstance(obj, Batch) and len(obj.get_keys()) == 0:
continue
self.__dict__[batch_key] /= value
return self
def __truediv__(self, value: Number | np.number) -> Self:
"""Algebraic division with a scalar value out-of-place."""
return deepcopy(self).__itruediv__(value)
def __repr__(self) -> str:
"""Return str(self)."""
self_str = self.__class__.__name__ + "(\n"
flag = False
for batch_key, obj in self.__dict__.items():
rpl = "\n" + " " * (6 + len(batch_key))
obj_name = pprint.pformat(obj).replace("\n", rpl)
self_str += f" {batch_key}: {obj_name},\n"
flag = True
if flag:
self_str += ")"
else:
self_str = self.__class__.__name__ + "()"
return self_str
[docs]
def to_numpy(self: Self) -> Self:
result = deepcopy(self)
result.to_numpy_()
return result
[docs]
def to_numpy_(self) -> None:
def arr_to_numpy(arr: TArr) -> TArr:
if isinstance(arr, torch.Tensor):
return arr.detach().cpu().numpy()
return arr
self.apply_values_transform(arr_to_numpy, inplace=True)
[docs]
def to_torch(
self: Self,
dtype: torch.dtype | None = None,
device: str | int | torch.device = "cpu",
) -> Self:
result = deepcopy(self)
result.to_torch_(dtype=dtype, device=device)
return result
[docs]
def to_torch_(
self,
dtype: torch.dtype | None = None,
device: str | int | torch.device = "cpu",
) -> None:
if not isinstance(device, torch.device):
device = torch.device(device)
def arr_to_torch(arr: TArr) -> TArr:
if isinstance(arr, np.ndarray):
return torch.from_numpy(arr).to(device)
# TODO: simplify
if (
(dtype is not None and arr.dtype != dtype)
or arr.device.type != device.type
or device.index != arr.device.index
):
if dtype is not None:
arr = arr.type(dtype)
return arr.to(device)
return arr
self.apply_values_transform(arr_to_torch, inplace=True)
def __cat(self, batches: Sequence[dict | Self], lens: list[int]) -> None:
"""Private method for Batch.cat_.
::
>>> a = Batch(a=np.random.randn(3, 4))
>>> x = Batch(a=a, b=np.random.randn(4, 4))
>>> y = Batch(a=Batch(a=Batch()), b=np.random.randn(4, 4))
If we want to concatenate x and y, we want to pad y.a.a with zeros.
Without ``lens`` as a hint, when we concatenate x.a and y.a, we would
not be able to know how to pad y.a. So ``Batch.cat_`` should compute
the ``lens`` to give ``Batch.__cat`` a hint.
::
>>> ans = Batch.cat([x, y])
>>> # this is equivalent to the following line
>>> ans = Batch(); ans.__cat([x, y], lens=[3, 4])
>>> # this lens is equal to [len(a), len(b)]
"""
# partial keys will be padded by zeros
# with the shape of [len, rest_shape]
sum_lens = [0]
for len_ in lens:
sum_lens.append(sum_lens[-1] + len_)
# collect non-empty keys
keys_map = [
{
batch_key
for batch_key, obj in batch.items()
if not (isinstance(obj, Batch) and len(obj.get_keys()) == 0)
}
for batch in batches
]
keys_shared = set.intersection(*keys_map)
values_shared = [[batch[key] for batch in batches] for key in keys_shared]
for key, shared_value in zip(keys_shared, values_shared, strict=True):
if all(isinstance(element, dict | Batch) for element in shared_value):
batch_holder = Batch()
batch_holder.__cat(shared_value, lens=lens)
self.__dict__[key] = batch_holder
elif all(isinstance(element, torch.Tensor) for element in shared_value):
self.__dict__[key] = torch.cat(shared_value)
else:
# cat Batch(a=np.zeros((3, 4))) and Batch(a=Batch(b=Batch()))
# will fail here
self.__dict__[key] = _to_array_with_correct_type(np.concatenate(shared_value))
keys_total = set.union(*[set(batch.keys()) for batch in batches])
keys_reserve_or_partial = set.difference(keys_total, keys_shared)
# keys that are reserved in all batches
keys_reserve = set.difference(keys_total, set.union(*keys_map))
# keys that occur only in some batches, but not all
keys_partial = keys_reserve_or_partial.difference(keys_reserve)
for key in keys_reserve:
# reserved keys
self.__dict__[key] = Batch()
for key in keys_partial:
for i, batch in enumerate(batches):
if key not in batch.__dict__:
continue
value = batch.get(key)
if isinstance(value, Batch) and len(value.get_keys()) == 0:
continue
try:
self.__dict__[key][sum_lens[i] : sum_lens[i + 1]] = value
except KeyError:
self.__dict__[key] = create_value(value, sum_lens[-1], stack=False)
self.__dict__[key][sum_lens[i] : sum_lens[i + 1]] = value
[docs]
def cat_(self, batches: BatchProtocol | Sequence[dict | BatchProtocol]) -> None:
if isinstance(batches, Batch | dict):
batches = [batches]
# check input format
batch_list = []
original_keys_only_batch = None
"""A batch with all values removed, just keys left. Can be considered a sort of schema.
Will be either the schema of self, or of the first non-empty batch in the sequence.
"""
if len(self) > 0:
original_keys_only_batch = self.apply_values_transform(lambda x: None)
original_keys_only_batch.replace_empty_batches_by_none()
for batch in batches:
if isinstance(batch, dict):
batch = Batch(batch)
if not isinstance(batch, Batch):
raise ValueError(f"Cannot concatenate {type(batch)} in Batch.cat_")
if len(batch.get_keys()) == 0:
continue
if original_keys_only_batch is None:
original_keys_only_batch = batch.apply_values_transform(lambda x: None)
original_keys_only_batch.replace_empty_batches_by_none()
batch_list.append(batch)
continue
cur_keys_only_batch = batch.apply_values_transform(lambda x: None)
cur_keys_only_batch.replace_empty_batches_by_none()
if original_keys_only_batch != cur_keys_only_batch:
raise ValueError(
f"Batch.cat_ only supports concatenation of batches with the same structure but got "
f"structures: \n{original_keys_only_batch}\n and\n{cur_keys_only_batch}.",
)
batch_list.append(batch)
if len(batch_list) == 0:
return
batches = batch_list
# TODO: lot's of the remaining logic is devoted to filling up remaining keys with zeros
# this should be removed, and also the check above should be extended to nested keys
try:
# len(batch) here means batch is a nested empty batch
# like Batch(a=Batch), and we have to treat it as length zero and
# keep it.
lens = [0 if len(batch) == 0 else len(batch) for batch in batches]
except TypeError as exception:
raise ValueError(
"Batch.cat_ meets an exception. Maybe because there is any "
f"scalar in {batches} but Batch.cat_ does not support the "
"concatenation of scalar.",
) from exception
if len(self.get_keys()) != 0:
batches = [self, *list(batches)]
# len of zero means that that item is Batch() and should be ignored
lens = [0 if len(self) == 0 else len(self), *lens]
self.__cat(batches, lens)
[docs]
@staticmethod
def cat(batches: Sequence[dict | TBatch]) -> TBatch:
batch = Batch()
batch.cat_(batches)
return batch # type: ignore
[docs]
def stack_(self, batches: Sequence[dict | BatchProtocol], axis: int = 0) -> None:
# check input format
batch_list = []
for batch in batches:
if isinstance(batch, dict):
if len(batch) > 0:
batch_list.append(Batch(batch))
elif isinstance(batch, Batch):
if len(batch.get_keys()) != 0:
batch_list.append(batch)
else:
raise ValueError(f"Cannot concatenate {type(batch)} in Batch.stack_")
if len(batch_list) == 0:
return
batches = batch_list
if len(self.get_keys()) != 0:
batches = [self, *batches]
# collect non-empty keys
keys_map = [
{
batch_key
for batch_key, obj in batch.items()
if not (isinstance(obj, Batch) and len(obj.get_keys()) == 0)
}
for batch in batches
]
keys_shared = set.intersection(*keys_map)
values_shared = [[batch[key] for batch in batches] for key in keys_shared]
for shared_key, value in zip(keys_shared, values_shared, strict=True):
# second often
if all(isinstance(element, torch.Tensor) for element in value):
self.__dict__[shared_key] = torch.stack(value, axis)
# third often
elif all(isinstance(element, Batch | dict) for element in value):
self.__dict__[shared_key] = Batch.stack(value, axis)
else: # most often case is np.ndarray
try:
self.__dict__[shared_key] = _to_array_with_correct_type(np.stack(value, axis))
except ValueError:
warnings.warn(
"You are using tensors with different shape,"
" fallback to dtype=object by default.",
)
self.__dict__[shared_key] = np.array(value, dtype=object)
# all the keys
keys_total = set.union(*[set(batch.keys()) for batch in batches])
# keys that are reserved in all batches
keys_reserve = set.difference(keys_total, set.union(*keys_map))
# keys that are either partial or reserved
keys_reserve_or_partial = set.difference(keys_total, keys_shared)
# keys that occur only in some batches, but not all
keys_partial = keys_reserve_or_partial.difference(keys_reserve)
if keys_partial and axis != 0:
raise ValueError(
f"Stack of Batch with non-shared keys {keys_partial} is only "
f"supported with axis=0, but got axis={axis}!",
)
for key in keys_reserve:
# reserved keys
self.__dict__[key] = Batch()
for key in keys_partial:
for i, batch in enumerate(batches):
if key not in batch.__dict__:
continue
value = batch.get(key)
# TODO: fix code/annotations s.t. the ignores can be removed
if (
isinstance(value, Batch) # type: ignore
and len(value.get_keys()) == 0 # type: ignore
):
continue # type: ignore
try:
self.__dict__[key][i] = value
except KeyError:
self.__dict__[key] = create_value(value, len(batches))
self.__dict__[key][i] = value
[docs]
@staticmethod
def stack(batches: Sequence[dict | TBatch], axis: int = 0) -> TBatch:
batch = Batch()
batch.stack_(batches, axis)
# can't cast to a generic type, so we have to ignore the type here
return batch # type: ignore
[docs]
def empty_(self, index: slice | IndexType | None = None) -> Self:
for batch_key, obj in self.items():
if isinstance(obj, torch.Tensor): # most often case
self.__dict__[batch_key][index] = 0
elif obj is None:
continue
elif isinstance(obj, np.ndarray):
if obj.dtype == object:
self.__dict__[batch_key][index] = None
else:
self.__dict__[batch_key][index] = 0
elif isinstance(obj, Batch):
self.__dict__[batch_key].empty_(index=index)
else: # scalar value
warnings.warn(
"You are calling Batch.empty on a NumPy scalar, "
"which may cause undefined behaviors.",
)
if _is_number(obj):
self.__dict__[batch_key] = obj.__class__(0)
else:
self.__dict__[batch_key] = None
return self
[docs]
@staticmethod
def empty(batch: TBatch, index: IndexType | None = None) -> TBatch:
return deepcopy(batch).empty_(index)
[docs]
def update(self, batch: dict | Self | None = None, **kwargs: Any) -> None:
if batch is None:
self.update(kwargs)
return
for batch_key, obj in batch.items():
self.__dict__[batch_key] = _parse_value(obj)
if kwargs:
self.update(kwargs)
def __len__(self) -> int:
"""Raises `TypeError` if any value in the batch has no len(), typically meaning it's a batch of scalars."""
lens = []
for key, obj in self.__dict__.items():
# TODO: causes inconsistent behavior to batch with empty batches
# and batch with empty sequences of other type. Remove, but only after
# Buffer and Collectors have been improved to no longer rely on this
if isinstance(obj, Batch) and len(obj) == 0:
continue
if obj is None:
continue
if hasattr(obj, "__len__") and (isinstance(obj, Batch) or obj.ndim > 0):
lens.append(len(obj))
continue
if isinstance(obj, Distribution):
lens.append(get_len_of_dist(obj))
continue
raise TypeError(f"Entry for {key} in {self} is {obj} has no len()")
if not lens:
return 0
return min(lens)
@property
def shape(self) -> list[int]:
"""Return self.shape."""
if len(self.get_keys()) == 0:
return []
data_shape = []
for obj in self.__dict__.values():
try:
data_shape.append(list(obj.shape))
except AttributeError:
data_shape.append([])
return (
list(map(min, zip(*data_shape, strict=False))) if len(data_shape) > 1 else data_shape[0]
)
[docs]
def split(
self,
size: int,
shuffle: bool = True,
merge_last: bool = False,
) -> Iterator[Self]:
length = len(self)
if size == -1:
size = length
assert size >= 1 # size can be greater than length, return whole batch
indices = np.random.permutation(length) if shuffle else np.arange(length)
merge_last = merge_last and length % size > 0
for idx in range(0, length, size):
if merge_last and idx + size + size >= length:
yield self[indices[idx:]]
break
yield self[indices[idx : idx + size]]
@overload
def apply_values_transform(
self,
values_transform: Callable,
) -> Self: ...
@overload
def apply_values_transform(
self,
values_transform: Callable,
inplace: Literal[True],
) -> None: ...
@overload
def apply_values_transform(
self,
values_transform: Callable,
inplace: Literal[False],
) -> Self: ...
[docs]
def set_array_at_key(
self,
arr: np.ndarray,
key: str,
index: IndexType | None = None,
default_value: float | None = None,
) -> None:
if index is not None:
if key not in self.get_keys():
log.info(
f"Key {key} not found in batch, "
f"creating a sequence of len {len(self)} with {default_value=} for it.",
)
try:
self[key] = np.array([default_value] * len(self), dtype=arr.dtype)
except TypeError as exception:
raise TypeError(
f"Cannot create a sequence of dtype {arr.dtype} with default value {default_value}. "
f"You can fix this either by passing an array with the correct dtype or by passing "
f"a different default value that can be cast to the array's dtype (or both).",
) from exception
else:
existing_entry = self[key]
if isinstance(existing_entry, Batch):
raise ValueError(
f"Cannot set sequence at key {key} because it is a nested batch, "
f"can only set a subsequence of an array.",
)
self[key][index] = arr
else:
if len(arr) != len(self):
raise ValueError(
f"Sequence length {len(arr)} does not match "
f"batch length {len(self)}. For setting a subsequence with missing "
f"entries filled up by default values, consider passing an index.",
)
self[key] = arr
[docs]
def isnull(self) -> Self:
return self.apply_values_transform(pd.isnull, inplace=False)
[docs]
def hasnull(self) -> bool:
isnan_batch = self.isnull()
is_any_null_batch = isnan_batch.apply_values_transform(np.any, inplace=False)
def is_any_true(boolean_batch: BatchProtocol) -> bool:
for val in boolean_batch.values():
if isinstance(val, Batch):
if is_any_true(val):
return True
else:
assert val.size == 1, "This shouldn't have happened, it's a bug!"
# an unsized array with a boolean, e.g. np.array(False). behaves like the boolean itself
if val:
return True
return False
return is_any_true(is_any_null_batch)
[docs]
def dropnull(self) -> Self:
# we need to use dicts since a batch retrieved for a single index has no length and cat fails
# TODO: make cat work with batches containing scalars?
sub_batches = []
for b in self:
if b.hasnull():
continue
# needed for cat to work
b = b.apply_values_transform(np.atleast_1d)
sub_batches.append(b)
return Batch.cat(sub_batches)
[docs]
def replace_empty_batches_by_none(self) -> None:
"""Goes through the batch-tree" recursively and replaces empty batches by None.
This is useful for extracting the structure of a batch without the actual data,
especially in combination with `apply_values_transform` with a
transform function a la `lambda x: None`.
"""
empty_batch = Batch()
for key, val in self.items():
if isinstance(val, Batch):
if val == empty_batch:
self[key] = None
else:
val.replace_empty_batches_by_none()
[docs]
def to_at_least_2d(self) -> Self:
"""Ensures that all arrays and dists in the batch have at least 2 dimensions.
This is useful for ensuring that all arrays in the batch can be concatenated
along a new axis.
"""
result = self.apply_values_transform(np.atleast_2d, inplace=False)
for key, val in self.items():
if isinstance(val, Distribution):
result[key] = dist_to_atleast_2d(val)
return result
def _apply_batch_values_func_recursively(
batch: TBatch,
values_transform: Callable,
inplace: bool = False,
) -> TBatch | None:
"""Applies the desired function on all values of the batch recursively.
See docstring of the corresponding method in the Batch class for more details.
"""
result = batch if inplace else deepcopy(batch)
for key, val in batch.__dict__.items():
if isinstance(val, Batch):
result[key] = _apply_batch_values_func_recursively(val, values_transform, inplace=False)
else:
result[key] = values_transform(val)
if not inplace:
return result
return None