batch#
Source code: tianshou/data/batch.py
- class Batch(batch_dict: dict | BatchProtocol | Sequence[dict | BatchProtocol] | ndarray | None = None, copy: bool = False, **kwargs: Any)[source]#
The internal data structure in Tianshou.
Batch is a kind of supercharged array (of temporal data) stored individually in a (recursive) dictionary of objects that can be either numpy arrays, torch tensors, or batches themselves. It is designed to make it extremely easy to access, manipulate, and set partial views of the heterogeneous data.
For a detailed description, please refer to Understand Batch.
- static cat(batches: Sequence[dict | TBatch]) TBatch [source]#
Concatenate a list of Batch objects into a single new batch.
For keys that are not shared across all batches, batches that do not have these keys will be padded by zeros with appropriate shapes. E.g.
>>> a = Batch(a=np.zeros([3, 4]), common=Batch(c=np.zeros([3, 5])))
>>> b = Batch(b=np.zeros([4, 3]), common=Batch(c=np.zeros([4, 5])))
>>> c = Batch.cat([a, b])
>>> c.a.shape
(7, 4)
>>> c.b.shape
(7, 3)
>>> c.common.c.shape
(7, 5)
- cat_(batches: BatchProtocol | Sequence[dict | BatchProtocol]) None [source]#
Concatenate a list of (or one) Batch objects into the current batch.
- static empty(batch: TBatch, index: slice | int | ndarray | list[int] | None = None) TBatch [source]#
Return an empty Batch object with 0 or None filled.
The shape is the same as the given Batch.
- empty_(index: slice | int | ndarray | list[int] | None = None) Self [source]#
Return an empty Batch object with 0 or None filled.
If “index” is specified, only the data at the given index is reset.
>>> data.empty_()
>>> print(data)
Batch(
    a: array([[0., 0.],
              [0., 0.]]),
    b: array([None, None], dtype=object),
)
>>> b = {'c': [2., 'st'], 'd': [1., 0.]}
>>> data = Batch(a=[False, True], b=b)
>>> data[0] = Batch.empty(data[1])
>>> data
Batch(
    a: array([False, True]),
    b: Batch(
        c: array([None, 'st']),
        d: array([0., 0.]),
    ),
)
- is_empty(recurse: bool = False) bool [source]#
Test if a Batch is empty.
If recurse=True, it further tests the values of the object; else it only tests the existence of any key.

b.is_empty(recurse=True) is mainly used to distinguish Batch(a=Batch(a=Batch())) and Batch(a=1). They both raise exceptions when applied to len(), but the former can be used in cat, while the latter is a scalar and cannot be used in cat.

Another usage is in __len__, where we have to skip checking the length of a recursively empty Batch.

>>> Batch().is_empty()
True
>>> Batch(a=Batch(), b=Batch(c=Batch())).is_empty()
False
>>> Batch(a=Batch(), b=Batch(c=Batch())).is_empty(recurse=True)
True
>>> Batch(d=1).is_empty()
False
>>> Batch(a=np.float64(1.0)).is_empty()
False
- property shape: list[int]#
Return self.shape.
- split(size: int, shuffle: bool = True, merge_last: bool = False) Iterator[Self] [source]#
Split whole data into multiple small batches.
- Parameters:
size – the size of each split batch; if the length of the batch is smaller than “size”, a single batch is returned. A size of -1 means the whole batch.
shuffle – randomly shuffle the entire data batch if True, otherwise keep the original order. Default to True.
merge_last – merge the last batch into the previous one. Default to False.
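The chunking behavior of size, shuffle, and merge_last described above can be sketched with plain numpy. The split_indices helper below is hypothetical (not part of the Tianshou API) and only illustrates the documented index-chunking semantics:

```python
import numpy as np

def split_indices(length, size, shuffle=True, merge_last=False, seed=0):
    """Sketch of the index chunking that split() performs.

    Yields index arrays of (at most) `size` elements; with merge_last=True
    the final partial chunk is merged into the previous full one.
    """
    if size == -1:
        size = length  # -1 means the whole batch in one piece
    indices = np.arange(length)
    if shuffle:
        np.random.default_rng(seed).shuffle(indices)
    chunks = [indices[i:i + size] for i in range(0, length, size)]
    if merge_last and len(chunks) > 1 and len(chunks[-1]) < size:
        chunks[-2] = np.concatenate([chunks[-2], chunks[-1]])
        chunks.pop()
    return chunks

print([len(c) for c in split_indices(10, 3, shuffle=False)])
# [3, 3, 3, 1]
print([len(c) for c in split_indices(10, 3, shuffle=False, merge_last=True)])
# [3, 3, 4]
```

Each index chunk would then select a sub-Batch, which is why a batch shorter than “size” yields exactly one chunk.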
- static stack(batches: Sequence[dict | TBatch], axis: int = 0) TBatch [source]#
Stack a list of Batch objects into a single new batch.
For keys that are not shared across all batches, batches that do not have these keys will be padded by zeros. E.g.
>>> a = Batch(a=np.zeros([4, 4]), common=Batch(c=np.zeros([4, 5])))
>>> b = Batch(b=np.zeros([4, 6]), common=Batch(c=np.zeros([4, 5])))
>>> c = Batch.stack([a, b])
>>> c.a.shape
(2, 4, 4)
>>> c.b.shape
(2, 4, 6)
>>> c.common.c.shape
(2, 4, 5)
Note
If there are keys that are not shared across all batches, stack with axis != 0 is undefined and will raise an exception.
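The zero padding described above can be mimicked with plain numpy. This is a sketch of the documented behavior, not the actual implementation:

```python
import numpy as np

# Key 'a' exists only in the first batch of the example above; for the
# second batch, a zero placeholder with the same per-batch shape is
# created, and both arrays are then stacked along axis 0.
a_a = np.zeros([4, 4])
pad_for_b = np.zeros_like(a_a)  # zero padding for the missing key 'a'
stacked = np.stack([a_a, pad_for_b], axis=0)
print(stacked.shape)  # (2, 4, 4), matching c.a.shape in the example above
```

With axis != 0 there is no single well-defined place to insert the padded entries, which is why the note above declares that case undefined.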
- stack_(batches: Sequence[dict | BatchProtocol], axis: int = 0) None [source]#
Stack a list of Batch objects into the current batch.
- class BatchProtocol(*args, **kwargs)[source]#
The internal data structure in Tianshou.
Batch is a kind of supercharged array (of temporal data) stored individually in a (recursive) dictionary of objects that can be either numpy arrays, torch tensors, or batches themselves. It is designed to make it extremely easy to access, manipulate, and set partial views of the heterogeneous data.
For a detailed description, please refer to Understand Batch.
- static cat(batches: Sequence[dict | TBatch]) TBatch [source]#
Concatenate a list of Batch objects into a single new batch.
For keys that are not shared across all batches, batches that do not have these keys will be padded by zeros with appropriate shapes. E.g.
>>> a = Batch(a=np.zeros([3, 4]), common=Batch(c=np.zeros([3, 5])))
>>> b = Batch(b=np.zeros([4, 3]), common=Batch(c=np.zeros([4, 5])))
>>> c = Batch.cat([a, b])
>>> c.a.shape
(7, 4)
>>> c.b.shape
(7, 3)
>>> c.common.c.shape
(7, 5)
- cat_(batches: Self | Sequence[dict | Self]) None [source]#
Concatenate a list of (or one) Batch objects into the current batch.
- static empty(batch: TBatch, index: slice | int | ndarray | list[int] | None = None) TBatch [source]#
Return an empty Batch object with 0 or None filled.
The shape is the same as the given Batch.
- empty_(index: slice | int | ndarray | list[int] | None = None) Self [source]#
Return an empty Batch object with 0 or None filled.
If “index” is specified, only the data at the given index is reset.
>>> data.empty_()
>>> print(data)
Batch(
    a: array([[0., 0.],
              [0., 0.]]),
    b: array([None, None], dtype=object),
)
>>> b = {'c': [2., 'st'], 'd': [1., 0.]}
>>> data = Batch(a=[False, True], b=b)
>>> data[0] = Batch.empty(data[1])
>>> data
Batch(
    a: array([False, True]),
    b: Batch(
        c: array([None, 'st']),
        d: array([0., 0.]),
    ),
)
- property shape: list[int]#
- split(size: int, shuffle: bool = True, merge_last: bool = False) Iterator[Self] [source]#
Split whole data into multiple small batches.
- Parameters:
size – the size of each split batch; if the length of the batch is smaller than “size”, a single batch is returned. A size of -1 means the whole batch.
shuffle – randomly shuffle the entire data batch if True, otherwise keep the original order. Default to True.
merge_last – merge the last batch into the previous one. Default to False.
- static stack(batches: Sequence[dict | TBatch], axis: int = 0) TBatch [source]#
Stack a list of Batch objects into a single new batch.
For keys that are not shared across all batches, batches that do not have these keys will be padded by zeros. E.g.
>>> a = Batch(a=np.zeros([4, 4]), common=Batch(c=np.zeros([4, 5])))
>>> b = Batch(b=np.zeros([4, 6]), common=Batch(c=np.zeros([4, 5])))
>>> c = Batch.stack([a, b])
>>> c.a.shape
(2, 4, 4)
>>> c.b.shape
(2, 4, 6)
>>> c.common.c.shape
(2, 4, 5)
Note
If there are keys that are not shared across all batches, stack with axis != 0 is undefined and will raise an exception.
- stack_(batches: Sequence[dict | Self], axis: int = 0) None [source]#
Stack a list of Batch objects into the current batch.
- alloc_by_keys_diff(meta: BatchProtocol, batch: BatchProtocol, size: int, stack: bool = True) None [source]#
Creates place-holders inside meta for keys that are in batch but not in meta.
This is mainly an internal method; use it only if you know what you are doing.
- create_value(inst: Any, size: int, stack: bool = True) Batch | ndarray | Tensor [source]#
Create empty place-holders according to inst’s shape.
- Parameters:
stack – whether to stack or to concatenate. E.g. if inst has shape (3, 5) and size = 10, stack=True returns an np.array with shape (10, 3, 5), otherwise (10, 5).
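The shape semantics of the stack flag can be illustrated without Tianshou. The placeholder_shape helper below is hypothetical, an assumption introduced only to show how the output shape is derived from inst's shape:

```python
def placeholder_shape(inst_shape, size, stack=True):
    """Shape of the zero placeholder create_value would allocate (sketch).

    stack=True prepends a new axis of length `size` (as in stacking);
    stack=False replaces the leading batch axis with `size` (as in
    concatenating along the first axis).
    """
    if stack:
        return (size, *inst_shape)
    return (size, *inst_shape[1:])

print(placeholder_shape((3, 5), 10, stack=True))   # (10, 3, 5)
print(placeholder_shape((3, 5), 10, stack=False))  # (10, 5)
```

This mirrors the (3, 5) example in the parameter description: stacking adds a new leading axis, while concatenation reuses the existing one.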