batch
Source code: tianshou/data/batch.py
This module implements Batch, a flexible data structure for
handling heterogeneous data in reinforcement learning algorithms. Such a data structure
is needed since RL algorithms differ widely in the conceptual fields that they need.
Batch is the main data carrier in Tianshou. It bears some similarities to TensorDict, which is used for a similar purpose in pytorch-rl.
The main differences between the two are that Batch can hold arbitrary objects (and not just torch tensors),
and that Tianshou implements BatchProtocol for enabling type checking and autocompletion (more on that below).
The Batch class is designed to store and manipulate collections of data with varying types and structures. It strikes a balance between flexibility and type safety, the latter mainly achieved through the use of protocols. One can think of it as a mixture of a dictionary and an array: it has key-value pairs and nesting, while also having a shape and being indexable and sliceable.
Key features of the Batch class include:
- Flexible data storage: can hold numpy arrays, torch tensors, scalars, and nested Batch objects.
- Dynamic attribute access: allows setting and accessing data using attribute notation (e.g., batch.observation). This allows for type-safe and readable code and enables IDE autocompletion. See the comments on BatchProtocol below.
- Indexing and slicing: supports numpy-like indexing and slicing operations. Slicing is extended to nested Batch objects and torch Distributions.
- Batch operations: provides methods for splitting, shuffling, concatenating, and stacking multiple Batch objects.
- Data type conversion: offers methods to convert data between numpy arrays and torch tensors.
- Value transformations: allows applying functions to all values in the Batch recursively.
- Analysis utilities: provides methods for checking for missing values, dropping entries with missing values, and more.
Since we want to keep Batch flexible and not fix a specific set of fields or their types,
we don’t have fixed interfaces for actual Batch objects that are used throughout
tianshou (such interfaces could be dataclasses, for example). However, we still want to enable
IDE autocompletion and type checking for Batch objects. To achieve this, we rely on static duck typing (structural subtyping)
by using Protocol. The BatchProtocol defines the interface that all Batch objects should adhere to,
and its various implementations (like ActBatchProtocol or RolloutBatchProtocol) define the specific
fields that are expected in the respective Batch objects. The protocols are then used as type hints
throughout the codebase. Protocols can’t be instantiated, but we can cast to them.
For example, we “instantiate” an ActBatchProtocol with something like:
>>> act_batch = cast(ActBatchProtocol, Batch(act=my_action))
Users can decide for themselves how to structure their Batch objects, and can opt in to the BatchProtocol style to enable type checking and autocompletion. Opting out has no effect on functionality.
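The cast-to-protocol pattern can be illustrated with a self-contained sketch that uses only the standard `typing` module. It assumes a hypothetical stand-in class `FakeBatch` instead of the real `Batch`, and a made-up protocol `ActLikeProtocol`; neither is part of tianshou.

```python
from typing import Any, Protocol, cast


class FakeBatch:
    """Hypothetical stand-in for Batch: stores arbitrary keyword fields as attributes."""

    def __init__(self, **kwargs: Any) -> None:
        self.__dict__.update(kwargs)


class ActLikeProtocol(Protocol):
    """Hypothetical protocol declaring that an `act` field is expected."""

    act: Any


def read_action(batch: ActLikeProtocol) -> Any:
    # A type checker knows that `batch.act` exists, so this access is checked
    # and can be autocompleted by an IDE.
    return batch.act


# cast() does nothing at runtime; it only tells the type checker which fields to expect.
act_batch = cast(ActLikeProtocol, FakeBatch(act=[0, 1, 1]))
action = read_action(act_batch)
```

Because `cast` is a no-op at runtime, leaving it out changes only what the type checker sees, not how the object behaves.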
- create_value(inst: Any, size: int, stack: bool = True) → Batch | ndarray | Tensor
Create empty placeholders according to inst's shape.
- Parameters:
stack – whether to stack or to concatenate. E.g., if inst has shape (3, 5) and size = 10, stack=True returns an np.array with shape (10, 3, 5); otherwise the shape is (10, 5).
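The shape rule behind the stack flag can be written as a tiny helper. `placeholder_shape` is hypothetical and only mirrors the example above; it computes the shape, it does not allocate anything and is not part of tianshou.

```python
def placeholder_shape(inst_shape: tuple[int, ...], size: int, stack: bool = True) -> tuple[int, ...]:
    """Hypothetical helper: shape of a placeholder for `size` entries shaped like `inst_shape`."""
    if stack:
        # Stacking adds a new leading axis of length `size`.
        return (size, *inst_shape)
    # Concatenating replaces the existing leading axis with `size`.
    return (size, *inst_shape[1:])


stacked = placeholder_shape((3, 5), 10, stack=True)
concatenated = placeholder_shape((3, 5), 10, stack=False)
```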
- alloc_by_keys_diff(meta: BatchProtocol, batch: BatchProtocol, size: int, stack: bool = True) → None
Create placeholders inside meta for keys that are in batch but not in meta.
This is mainly an internal method; use it only if you know what you are doing.
- exception ProtocolCalledException
Bases: Exception
The methods of a Protocol should never be called.
Currently, no static type checker actually verifies that a class that inherits from a Protocol does in fact provide the correct interface. Thus, a method of the protocol may be called accidentally (this is an implementation error). The usual error in that case is a somewhat cryptic AttributeError, which is why we instead raise this custom exception in the BatchProtocol.
Finally and importantly: using this in BatchProtocol makes mypy verify the fields in the various sub-protocols and thus renders it much more useful!
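A minimal sketch of the pattern, assuming the protocol methods simply raise the exception as their body. Only the role of `ProtocolCalledException` comes from the text above; `SizedProtocol`, `Good`, and `Bad` are hypothetical names.

```python
from typing import Protocol


class ProtocolCalledException(Exception):
    """Raised when a Protocol method body runs instead of a concrete implementation."""


class SizedProtocol(Protocol):
    def size(self) -> int:
        # A protocol only declares the interface; reaching this body is an implementation error.
        raise ProtocolCalledException


class Good(SizedProtocol):
    def size(self) -> int:
        return 3


class Bad(SizedProtocol):
    # size() was forgotten, so the inherited protocol body is used instead.
    pass


ok = Good().size()
try:
    Bad().size()
    error_name = None
except ProtocolCalledException as exc:
    error_name = type(exc).__name__
```

The accidental call on `Bad` fails with a clearly named exception instead of silently returning or raising an unrelated error.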
- get_sliced_dist(dist: TDistribution, index: ndarray | slice | int | ellipsis | Sequence[slice | int | ellipsis]) → TDistribution
Slice a distribution object by the given index.
- get_len_of_dist(dist: Distribution) → int
Return the length (typically batch size) of a distribution object.
- dist_to_atleast_2d(dist: TDistribution) → TDistribution
Convert a distribution to at least 2D, such that its batch_shape attribute has a length of at least 1.
- class BatchProtocol(*args, **kwargs)
Bases: Protocol
The internal data structure in Tianshou.
Batch is a kind of supercharged array (of temporal data) stored individually in a (recursive) dictionary of objects that can be either numpy arrays, torch tensors, or batches themselves. It is designed to make it extremely easy to access, manipulate, and set partial views of heterogeneous data.
- property shape: list[int]
- to_torch(dtype: dtype | None = None, device: str | int | device = 'cpu') → Self
Convert all numpy.ndarray values to torch.Tensor and return a new Batch.
- to_torch_(dtype: dtype | None = None, device: str | int | device = 'cpu') → None
Convert all numpy.ndarray values to torch.Tensor in place.
- cat_(batches: Self | Sequence[dict | Self]) → None
Concatenate a list of (or one) Batch objects into the current batch.
- static cat(batches: Sequence[dict | TBatch]) → TBatch
Concatenate a list of Batch objects into a single new batch.
For keys that are not shared across all batches, batches that do not have these keys will be padded with zeros of appropriate shapes. E.g.
>>> a = Batch(a=np.zeros([3, 4]), common=Batch(c=np.zeros([3, 5])))
>>> b = Batch(b=np.zeros([4, 3]), common=Batch(c=np.zeros([4, 5])))
>>> c = Batch.cat([a, b])
>>> c.a.shape
(7, 4)
>>> c.b.shape
(7, 3)
>>> c.common.c.shape
(7, 5)
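The zero-padding rule can be illustrated with a simplified pure-Python sketch. It assumes flat dicts whose values are lists of equal-width rows (standing in for 2D arrays) and ignores nesting; `cat_with_padding` is hypothetical and is not tianshou's implementation.

```python
Rows = list[list[float]]


def cat_with_padding(batches: list[dict[str, Rows]]) -> dict[str, Rows]:
    """Hypothetical helper: concatenate along axis 0, filling missing keys with zero rows."""
    # Row width per key, taken from the first batch that contains the key.
    widths: dict[str, int] = {}
    for batch in batches:
        for key, rows in batch.items():
            widths.setdefault(key, len(rows[0]))
    result: dict[str, Rows] = {key: [] for key in widths}
    for batch in batches:
        length = len(next(iter(batch.values())))  # all keys of one batch share a length
        for key, width in widths.items():
            if key in batch:
                result[key].extend(row[:] for row in batch[key])
            else:
                # This batch lacks the key: pad with zero rows of the right width.
                result[key].extend([0.0] * width for _ in range(length))
    return result


a = {"a": [[1.0] * 4 for _ in range(3)]}
b = {"b": [[1.0] * 3 for _ in range(4)]}
c = cat_with_padding([a, b])
```

The results have 7 rows of width 4 for "a" and 7 rows of width 3 for "b", matching the shapes (7, 4) and (7, 3) in the example above.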
- stack_(batches: Sequence[dict | Self], axis: int = 0) → None
Stack a list of Batch objects into the current batch.
- static stack(batches: Sequence[dict | TBatch], axis: int = 0) → TBatch
Stack a list of Batch objects into a single new batch.
For keys that are not shared across all batches, batches that do not have these keys will be padded with zeros. E.g.
>>> a = Batch(a=np.zeros([4, 4]), common=Batch(c=np.zeros([4, 5])))
>>> b = Batch(b=np.zeros([4, 6]), common=Batch(c=np.zeros([4, 5])))
>>> c = Batch.stack([a, b])
>>> c.a.shape
(2, 4, 4)
>>> c.b.shape
(2, 4, 6)
>>> c.common.c.shape
(2, 4, 5)
Note
If there are keys that are not shared across all batches, stack with axis != 0 is undefined and will cause an exception.
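A similar simplified sketch for stacking along a new axis 0. It assumes flat dicts of (nested) lists and uses a hypothetical `stack_with_padding`; a missing key gets zeros shaped like that key's value in another batch. It deliberately supports only axis 0, in line with the note above.

```python
from typing import Any


def zeros_like(value: Any) -> Any:
    """Recursively build a zero-filled nested list with the same shape as `value`."""
    if isinstance(value, list):
        return [zeros_like(v) for v in value]
    return 0.0


def stack_with_padding(batches: list[dict[str, Any]]) -> dict[str, list[Any]]:
    """Hypothetical helper: stack along a new leading axis, zero-filling missing keys."""
    # The first occurrence of each key serves as the shape template.
    templates: dict[str, Any] = {}
    for batch in batches:
        for key, value in batch.items():
            templates.setdefault(key, value)
    return {
        key: [batch[key] if key in batch else zeros_like(template) for batch in batches]
        for key, template in templates.items()
    }


def shape(value: Any) -> tuple[int, ...]:
    """Shape of a regular (non-ragged) nested list."""
    dims = []
    while isinstance(value, list):
        dims.append(len(value))
        value = value[0]
    return tuple(dims)


a = {"a": [[1.0] * 4 for _ in range(4)]}
b = {"b": [[1.0] * 6 for _ in range(4)]}
s = stack_with_padding([a, b])
```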
- empty_(index: slice | ndarray | int | ellipsis | Sequence[slice | int | ellipsis] | None = None) → Self
Fill this Batch with 0 or None in place and return it.
If "index" is specified, only the indexed data is reset.
>>> data.empty_()
>>> print(data)
Batch(
    a: array([[0., 0.], [0., 0.]]),
    b: array([None, None], dtype=object),
)
>>> b={'c': [2., 'st'], 'd': [1., 0.]}
>>> data = Batch(a=[False, True], b=b)
>>> data[0] = Batch.empty(data[1])
>>> data
Batch(
    a: array([False, True]),
    b: Batch(
        c: array([None, 'st']),
        d: array([0., 0.]),
    ),
)
- static empty(batch: TBatch, index: ndarray | slice | int | ellipsis | Sequence[slice | int | ellipsis] | None = None) → TBatch
Return an empty Batch object filled with 0 or None.
The shape is the same as that of the given Batch.
- update(batch: dict | Self | None = None, **kwargs: Any) → None
Update this batch from another dict/Batch.
- split(size: int, shuffle: bool = True, merge_last: bool = False) → Iterator[Self]
Split the whole data into multiple smaller batches.
- Parameters:
size – the size of each resulting batch. If the length of the batch is smaller than size, a single batch is returned. A size of -1 means the whole batch.
shuffle – whether to randomly shuffle the entire data batch before splitting; if False, the original order is kept. Defaults to True.
merge_last – whether to merge the last batch into the previous one. Defaults to False.
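The interplay of size, shuffle, and merge_last can be sketched by computing only the index chunks. `split_indices` is hypothetical, and its exact edge-case behavior (in particular, that merge_last applies only when the last chunk would be incomplete) is an assumption rather than a copy of the library code.

```python
import random


def split_indices(length: int, size: int, shuffle: bool = True, merge_last: bool = False) -> list[list[int]]:
    """Hypothetical helper: index chunks for splitting `length` entries into batches of `size`."""
    if size == -1:
        size = length  # -1 means a single chunk containing everything
    assert size >= 1
    indices = list(range(length))
    if shuffle:
        random.shuffle(indices)
    # Assumption: merging only matters if the last chunk would be incomplete.
    merge_last = merge_last and length % size > 0
    chunks = []
    for start in range(0, length, size):
        if merge_last and start + 2 * size >= length:
            # Absorb the trailing partial chunk into the current one.
            chunks.append(indices[start:])
            break
        chunks.append(indices[start:start + size])
    return chunks
```

For example, splitting 10 entries with size 3 yields chunk lengths 3, 3, 3, 1, or 3, 3, 4 with merge_last=True.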
- set_array_at_key(seq: ndarray, key: str, index: ndarray | slice | int | ellipsis | Sequence[slice | int | ellipsis] | None = None, default_value: float | None = None) → None
Set a sequence of values at a given key.
If index is not passed, the sequence must have the same length as the batch.
- Parameters:
seq – the array of values to set.
key – the key to set the sequence at.
index – the indices to set the sequence at. If None, the sequence must have the same length as the batch and will be set at all indices.
default_value – this only applies if index is passed and the key does not exist yet in the batch. In that case, entries outside the passed index will be filled with this default value. Note that the array at the key will be of the same dtype as the passed sequence, so default_value should be such that numpy can cast it to this dtype.
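The effect of index and default_value on a key that does not exist yet can be sketched with plain lists. `set_new_key` is a hypothetical helper operating on a dict of lists; it ignores the dtype casting mentioned above.

```python
from typing import Any, Optional, Sequence


def set_new_key(
    data: dict[str, list[Any]],
    length: int,
    key: str,
    seq: Sequence[Any],
    index: Optional[Sequence[int]] = None,
    default_value: Any = None,
) -> None:
    """Hypothetical helper: store `seq` under a new `key`; other positions get `default_value`."""
    if index is None:
        # Without an index, the sequence must cover the whole batch.
        if len(seq) != length:
            raise ValueError("seq must have the same length as the batch")
        data[key] = list(seq)
        return
    # With an index, start from the default and overwrite only the indexed positions.
    column = [default_value] * length
    for i, value in zip(index, seq):
        column[i] = value
    data[key] = column


batch: dict[str, list[Any]] = {"obs": [0, 1, 2, 3]}
set_new_key(batch, 4, "rew", [1.0, 2.0], index=[1, 3], default_value=0.0)
```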
- dropnull() → Self
Return a batch in which every item for which any value is null has been dropped.
Note that this is not the same as dropping the null entries of each sequence independently. For example,
>>> b = Batch(a=[None, 2, 3, 4], b=[4, 5, None, 7])
>>> b.dropnull()
will result in the equivalent of
>>> Batch(a=[2, 4], b=[5, 7])
This logic is applied recursively to all nested batches. The result is the same as if the batch was flattened, entries were dropped, and then the batch was reshaped back to the original nested structure.
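The row-wise semantics can be checked with a flat pure-Python sketch. `dropnull_rows` is hypothetical, treats only `None` as null for simplicity, and does not handle nesting.

```python
from typing import Any


def dropnull_rows(columns: dict[str, list[Any]]) -> dict[str, list[Any]]:
    """Hypothetical helper: keep only positions at which no column holds None."""
    length = len(next(iter(columns.values())))
    # A position survives only if it is non-null in every column.
    keep = [i for i in range(length) if all(col[i] is not None for col in columns.values())]
    return {key: [col[i] for i in keep] for key, col in columns.items()}


result = dropnull_rows({"a": [None, 2, 3, 4], "b": [4, 5, None, 7]})
```

Positions 0 and 2 each contain a None in one of the columns, so they are removed from both, reproducing the example above.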
- apply_values_transform(values_transform: Callable[[ndarray | Tensor], Any]) → Self
- apply_values_transform(values_transform: Callable, inplace: Literal[True]) → None
- apply_values_transform(values_transform: Callable[[ndarray | Tensor], Any], inplace: Literal[False]) → Self
Apply a function to all arrays in the batch, including nested ones.
- Parameters:
values_transform – the function to apply to the arrays.
inplace – whether to apply the function in-place. If False, a new batch is returned, otherwise the batch is modified in-place and None is returned.
- class Batch(batch_dict: dict | BatchProtocol | Sequence[dict | BatchProtocol] | ndarray | None = None, copy: bool = False, **kwargs: Any)
Bases: BatchProtocol
The internal data structure in Tianshou.
Batch is a kind of supercharged array (of temporal data) stored individually in a (recursive) dictionary of objects that can be either numpy arrays, torch tensors, or batches themselves. It is designed to make it extremely easy to access, manipulate, and set partial views of heterogeneous data.
- to_torch(dtype: dtype | None = None, device: str | int | device = 'cpu') → Self
Convert all numpy.ndarray values to torch.Tensor and return a new Batch.
- to_torch_(dtype: dtype | None = None, device: str | int | device = 'cpu') → None
Convert all numpy.ndarray values to torch.Tensor in place.
- cat_(batches: BatchProtocol | Sequence[dict | BatchProtocol]) → None
Concatenate a list of (or one) Batch objects into the current batch.
- static cat(batches: Sequence[dict | TBatch]) → TBatch
Concatenate a list of Batch objects into a single new batch.
For keys that are not shared across all batches, batches that do not have these keys will be padded with zeros of appropriate shapes. E.g.
>>> a = Batch(a=np.zeros([3, 4]), common=Batch(c=np.zeros([3, 5])))
>>> b = Batch(b=np.zeros([4, 3]), common=Batch(c=np.zeros([4, 5])))
>>> c = Batch.cat([a, b])
>>> c.a.shape
(7, 4)
>>> c.b.shape
(7, 3)
>>> c.common.c.shape
(7, 5)
- stack_(batches: Sequence[dict | BatchProtocol], axis: int = 0) → None
Stack a list of Batch objects into the current batch.
- static stack(batches: Sequence[dict | TBatch], axis: int = 0) → TBatch
Stack a list of Batch objects into a single new batch.
For keys that are not shared across all batches, batches that do not have these keys will be padded with zeros. E.g.
>>> a = Batch(a=np.zeros([4, 4]), common=Batch(c=np.zeros([4, 5])))
>>> b = Batch(b=np.zeros([4, 6]), common=Batch(c=np.zeros([4, 5])))
>>> c = Batch.stack([a, b])
>>> c.a.shape
(2, 4, 4)
>>> c.b.shape
(2, 4, 6)
>>> c.common.c.shape
(2, 4, 5)
Note
If there are keys that are not shared across all batches, stack with axis != 0 is undefined and will cause an exception.
- empty_(index: slice | ndarray | int | ellipsis | Sequence[slice | int | ellipsis] | None = None) → Self
Fill this Batch with 0 or None in place and return it.
If "index" is specified, only the indexed data is reset.
>>> data.empty_()
>>> print(data)
Batch(
    a: array([[0., 0.], [0., 0.]]),
    b: array([None, None], dtype=object),
)
>>> b={'c': [2., 'st'], 'd': [1., 0.]}
>>> data = Batch(a=[False, True], b=b)
>>> data[0] = Batch.empty(data[1])
>>> data
Batch(
    a: array([False, True]),
    b: Batch(
        c: array([None, 'st']),
        d: array([0., 0.]),
    ),
)
- static empty(batch: TBatch, index: ndarray | slice | int | ellipsis | Sequence[slice | int | ellipsis] | None = None) → TBatch
Return an empty Batch object filled with 0 or None.
The shape is the same as that of the given Batch.
- update(batch: dict | Self | None = None, **kwargs: Any) → None
Update this batch from another dict/Batch.
- property shape: list[int]
Return the shape of the batch.
- split(size: int, shuffle: bool = True, merge_last: bool = False) → Iterator[Self]
Split the whole data into multiple smaller batches.
- Parameters:
size – the size of each resulting batch. If the length of the batch is smaller than size, a single batch is returned. A size of -1 means the whole batch.
shuffle – whether to randomly shuffle the entire data batch before splitting; if False, the original order is kept. Defaults to True.
merge_last – whether to merge the last batch into the previous one. Defaults to False.
- apply_values_transform(values_transform: Callable) → Self
- apply_values_transform(values_transform: Callable, inplace: Literal[True]) → None
- apply_values_transform(values_transform: Callable, inplace: Literal[False]) → Self
Applies a function to all non-batch values in the batch, including values in nested batches.
A batch with keys pointing to either batches or non-batch values can be thought of as a tree of Batch nodes. This function traverses the tree and applies the function to all leaf nodes (i.e., values that are not batches themselves).
The values are usually arrays, but can also be scalar values of an arbitrary type, since retrieving a single entry from a Batch à la batch[0] returns a batch with scalar values.
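The tree traversal described above can be sketched on nested dicts standing in for nested batches. `map_leaves` is hypothetical and always returns a new tree, corresponding to the non-in-place variant.

```python
from typing import Any, Callable


def map_leaves(tree: dict[str, Any], fn: Callable[[Any], Any]) -> dict[str, Any]:
    """Hypothetical helper: return a new tree with `fn` applied to every non-dict (leaf) value."""
    return {
        # Recurse into inner nodes; transform leaves.
        key: map_leaves(value, fn) if isinstance(value, dict) else fn(value)
        for key, value in tree.items()
    }


data = {"obs": [1, 2], "info": {"reward": [0.5, 1.5], "done": [0, 1]}}
doubled = map_leaves(data, lambda v: [x * 2 for x in v])
# A transform like `lambda v: None` keeps only the key structure.
structure = map_leaves(data, lambda v: None)
```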
- set_array_at_key(arr: ndarray, key: str, index: ndarray | slice | int | ellipsis | Sequence[slice | int | ellipsis] | None = None, default_value: float | None = None) → None
Set a sequence of values at a given key.
If index is not passed, the sequence must have the same length as the batch.
- Parameters:
arr – the array of values to set.
key – the key to set the sequence at.
index – the indices to set the sequence at. If None, the sequence must have the same length as the batch and will be set at all indices.
default_value – this only applies if index is passed and the key does not exist yet in the batch. In that case, entries outside the passed index will be filled with this default value. Note that the array at the key will be of the same dtype as the passed sequence, so default_value should be such that numpy can cast it to this dtype.
- dropnull() → Self
Return a batch in which every item for which any value is null has been dropped.
Note that this is not the same as dropping the null entries of each sequence independently. For example,
>>> b = Batch(a=[None, 2, 3, 4], b=[4, 5, None, 7])
>>> b.dropnull()
will result in the equivalent of
>>> Batch(a=[2, 4], b=[5, 7])
This logic is applied recursively to all nested batches. The result is the same as if the batch was flattened, entries were dropped, and then the batch was reshaped back to the original nested structure.
- replace_empty_batches_by_none() → None
Goes through the "batch tree" recursively and replaces empty batches with None.
This is useful for extracting the structure of a batch without the actual data, especially in combination with apply_values_transform and a transform function à la lambda x: None.
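The recursive replacement can be sketched on nested dicts, where an empty dict plays the role of an empty batch. `replace_empty_by_none` is a hypothetical helper, not tianshou's implementation.

```python
from typing import Any


def replace_empty_by_none(tree: dict[str, Any]) -> None:
    """Hypothetical helper: recursively replace empty nested dicts by None, in place."""
    for key, value in tree.items():
        if isinstance(value, dict):
            if value:
                replace_empty_by_none(value)
            else:
                # Reassigning an existing key does not change the dict's size,
                # so this is safe while iterating.
                tree[key] = None


tree = {"obs": {"image": [1, 2], "mask": {}}, "info": {}}
replace_empty_by_none(tree)
```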