import numpy as np
from numba import njit
[docs]
class SegmentTree:
"""Implementation of Segment Tree.
The segment tree stores an array ``arr`` with size ``n``. It supports value
update and fast query of the sum for the interval ``[left, right)`` in
O(log n) time. The detailed procedure is as follows:
1. Pad the array to have length of power of 2, so that leaf nodes in the \
segment tree have the same depth.
2. Store the segment tree in a binary heap.
:param size: the size of segment tree.
"""
def __init__(self, size: int) -> None:
bound = 1
while bound < size:
bound *= 2
self._size = size
self._bound = bound
self._value = np.zeros([bound * 2])
self._compile()
def __len__(self) -> int:
return self._size
def __getitem__(self, index: int | np.ndarray) -> float | np.ndarray:
"""Return self[index]."""
return self._value[index + self._bound]
def __setitem__(self, index: int | np.ndarray, value: float | np.ndarray) -> None:
"""Update values in segment tree.
Duplicate values in ``index`` are handled by numpy: later index
overwrites previous ones.
::
>>> a = np.array([1, 2, 3, 4])
>>> a[[0, 1, 0, 1]] = [4, 5, 6, 7]
>>> print(a)
[6 7 3 4]
"""
if isinstance(index, int):
index, value = np.array([index]), np.array([value])
assert np.all(index >= 0)
assert np.all(index < self._size)
_setitem(self._value, index + self._bound, value)
[docs]
def reduce(self, start: int = 0, end: int | None = None) -> float:
"""Return operation(value[start:end])."""
if start == 0 and end is None:
return self._value[1]
if end is None:
end = self._size
if end < 0:
end += self._size
return _reduce(self._value, start + self._bound - 1, end + self._bound)
[docs]
def get_prefix_sum_idx(self, value: float | np.ndarray) -> int | np.ndarray:
r"""Find the index with given value.
Return the minimum index for each ``v`` in ``value`` so that
:math:`v \le \mathrm{sums}_i`, where
:math:`\mathrm{sums}_i = \sum_{j = 0}^{i} \mathrm{arr}_j`.
.. warning::
Please make sure all of the values inside the segment tree are
non-negative when using this function.
"""
assert np.all(value >= 0.0)
assert np.all(value < self._value[1])
single = False
if not isinstance(value, np.ndarray):
value = np.array([value])
single = True
index = _get_prefix_sum_idx(value, self._bound, self._value)
return index.item() if single else index
def _compile(self) -> None:
f64 = np.array([0, 1], dtype=np.float64)
f32 = np.array([0, 1], dtype=np.float32)
i64 = np.array([0, 1], dtype=np.int64)
_setitem(f64, i64, f64)
_setitem(f64, i64, f32)
_reduce(f64, 0, 1)
_get_prefix_sum_idx(f64, 1, f64)
_get_prefix_sum_idx(f32, 1, f64)
@njit
def _setitem(tree: np.ndarray, index: np.ndarray, value: np.ndarray) -> None:
    """Write leaf values and refresh the sums stored in their ancestors.

    Original benchmark note: numba version, 4x faster: 0.1 -> 0.024.

    :param tree: heap-ordered node array where node ``i`` has children
        ``2 * i`` and ``2 * i + 1``; modified in place.
    :param index: positions inside ``tree`` to overwrite. Only ``index[0]``
        decides when the root has been reached, so all entries are expected
        to sit on the same tree level. The array is halved in place.
    :param value: the new values, assigned to ``tree[index]``.
    """
    # Nothing to update. Returning here also avoids reading ``index[0]`` from
    # an empty array, which compiled code does not bounds-check by default.
    if index.size == 0:
        return
    tree[index] = value
    while index[0] > 1:
        # Step up to the parents and recompute each as the sum of its two
        # children.
        index //= 2
        tree[index] = tree[index * 2] + tree[index * 2 + 1]
@njit
def _reduce(tree: np.ndarray, start: int, end: int) -> float:
    """Sum the nodes lying strictly between two positions of one tree level.

    Original benchmark note: numba version, 2x faster: 0.009 -> 0.005.

    :param tree: heap-ordered node array (children of ``i`` are ``2 * i`` and
        ``2 * i + 1``).
    :param start: exclusive left boundary position.
    :param end: exclusive right boundary position.
    :return: the accumulated sum, ``0.0`` when the open interval is empty.
    """
    total = 0.0
    lo, hi = start, end
    # Walk both boundaries towards the root until no node is left between them.
    while hi - lo > 1:
        if lo % 2 == 0:
            # ``lo`` is a left child: its right sibling lies inside the range.
            total += tree[lo + 1]
        if hi % 2 == 1:
            # ``hi`` is a right child: its left sibling lies inside the range.
            total += tree[hi - 1]
        lo //= 2
        hi //= 2
    return total
@njit
def _get_prefix_sum_idx(value: np.ndarray, bound: int, sums: np.ndarray) -> np.ndarray:
    """Descend from the root to the leaf selected by each prefix-sum query.

    Original benchmark note: numba version (v0.51), 5x speed up with
    size=100000 and bsz=64.
    vectorized np: 0.0923 (numpy best) -> 0.024 (now)
    for-loop: 0.2914 -> 0.019 (but not so stable)

    :param value: query values; left unmodified (a private copy is consumed).
    :param bound: position of the first leaf inside ``sums``.
    :param sums: heap-ordered node array holding subtree sums.
    :return: an ``int64`` array with one leaf offset (position minus
        ``bound``) per query.
    """
    index = np.ones(value.shape, dtype=np.int64)
    # Work on a copy: the descent subtracts from the running remainder, and
    # doing that on ``value`` itself would overwrite the caller's array.
    residual = value.copy()
    # The size check keeps an empty query from reading ``index[0]``, which
    # compiled code does not bounds-check by default.
    while index.size > 0 and index[0] < bound:
        index *= 2  # move to the left child
        lsons = sums[index]
        # Go right wherever the left subtree sum is smaller than the remainder,
        # and deduct that subtree's sum from the remainder.
        direct = lsons < residual
        residual -= lsons * direct
        index += direct
    index -= bound
    return index