Why Gemfury? Push, build, and install RubyGems, npm packages, Python packages, Maven artifacts, PHP packages, Go Modules, Debian packages, RPM packages, and NuGet packages.

Repository URL to install this package:

Details    
ray / purelib / ray / rllib / execution / segment_tree.py
Size: Mime:
import operator
from typing import Any, Optional


class SegmentTree:
    """A Segment Tree data structure.

    https://en.wikipedia.org/wiki/Segment_tree

    Can be used as regular array, but with two important differences:

      a) Setting an item's value is slightly slower. It is O(lg capacity),
         instead of O(1).
      b) Offers efficient `reduce` operation which reduces the tree's values
         over some specified contiguous subsequence of items in the array.
         Operation could be e.g. min/max/sum.

    The data is stored in a list, where the length is 2 * capacity.
    The second half of the list stores the actual values for each index, so if
    capacity=8, values are stored at indices 8 to 15. The first half of the
    array contains the reduced-values of the different (binary divided)
    segments, e.g. (capacity=4):
    0=not used
    1=reduced-value over all elements (array indices 4 to 7).
    2=reduced-value over array indices (4 and 5).
    3=reduced-value over array indices (6 and 7).
    4-7: values of the tree.
    NOTE that the values of the tree are accessed by indices starting at 0, so
    `tree[0]` accesses `internal_array[4]` in the above example.
    """

    def __init__(
        self, capacity: int, operation: Any, neutral_element: Optional[Any] = None
    ):
        """Initializes a Segment Tree object.

        Args:
            capacity: Total size of the array - must be a power of two.
            operation: Lambda obj, obj -> obj
                The operation for combining elements (eg. sum, max).
                Must be associative and commutative, with
                `neutral_element` as its identity. `reduce` folds values
                from both ends of the range into one accumulator, so the
                order of combination is not strictly left-to-right. No
                inverse is required, so min/max qualify.
            neutral_element (Optional[obj]): The neutral element for
                `operation`. Use None for automatically finding a value:
                operator.add: 0.0, max: float("-inf"); any other operation
                falls back to float("inf") (the identity of min).
        """
        assert (
            capacity > 0 and capacity & (capacity - 1) == 0
        ), "Capacity must be positive and a power of 2!"
        self.capacity = capacity
        if neutral_element is None:
            if operation is operator.add:
                neutral_element = 0.0
            elif operation is max:
                neutral_element = float("-inf")
            else:
                # Identity of `min`. Custom operations should pass an
                # explicit `neutral_element` instead of relying on this.
                neutral_element = float("inf")
        self.neutral_element = neutral_element
        # Index 0 is unused; [1, capacity) holds the reduced values of the
        # internal nodes; [capacity, 2 * capacity) holds the actual values.
        self.value = [self.neutral_element] * (2 * capacity)
        self.operation = operation

    def reduce(self, start: int = 0, end: Optional[int] = None) -> Any:
        """Applies `self.operation` to a contiguous subsequence of our values.

        The subsequence includes `start` and excludes `end` (`arr[start:end]`):

          self.operation(
              arr[start], operation(arr[start+1], operation(... arr[end-1])))

        Negative `start`/`end` are interpreted relative to `self.capacity`
        (e.g. -1 means `capacity - 1`). An empty range (start >= end) yields
        `self.neutral_element`.

        Args:
            start: Start index to apply reduction to (included).
            end (Optional[int]): End index to apply reduction to (excluded).
                None means `self.capacity`.

        Returns:
            any: The result of reducing self.operation over the specified
                range of `self.value` leaf elements.
        """
        if end is None:
            end = self.capacity
        elif end < 0:
            end += self.capacity
        # Without this, a negative start would land in the internal-node
        # half of the array and silently reduce over the wrong values.
        if start < 0:
            start += self.capacity

        # Init result with neutral element.
        result = self.neutral_element
        # Map start/end to our actual index space (second half of array).
        start += self.capacity
        end += self.capacity

        # Example:
        # internal-array (first half=sums, second half=actual values):
        # 0 1 2 3 | 4 5 6 7
        # - 6 1 5 | 1 0 2 3

        # tree.sum(0, 3) = 3
        # internally: start=4, end=7 -> sum values 1 0 2 = 3.

        # Iterate over tree starting in the actual-values (second half)
        # section.
        # 1) start=4 is even -> do nothing.
        # 2) end=7 is odd -> end-- -> end=6 -> add value to result: result=2
        # 3) int-divide start and end by 2: start=2, end=3
        # 4) start still smaller end -> iterate once more.
        # 5) start=2 is even -> do nothing.
        # 6) end=3 is odd -> end-- -> end=2 -> add value to result: result=3
        #    NOTE: This adds the sum of indices 4 and 5 (=1) to the result.
        # 7) start=1, end=1 -> loop ends, return 3.

        # Iterate as long as start < end.
        while start < end:

            # If start is odd: Add its value to result and move start to
            # next even value.
            if start & 1:
                result = self.operation(result, self.value[start])
                start += 1

            # If end is odd: Move end to previous even value, then add its
            # value to result. NOTE: This takes care of excluding `end` in any
            # situation.
            if end & 1:
                end -= 1
                result = self.operation(result, self.value[end])

            # Divide both start and end by 2 to make them "jump" into the
            # next upper level reduce-index space.
            start //= 2
            end //= 2

        return result

    def __setitem__(self, idx: int, val: Any) -> None:
        """
        Inserts/overwrites a value in/into the tree.

        Args:
            idx: The index to insert to. Must be in [0, `self.capacity`).
            val: The value to insert.
        """
        assert 0 <= idx < self.capacity, f"idx={idx} capacity={self.capacity}"

        # Index of the leaf to insert into (always insert in "second half"
        # of the tree, the first half is reserved for already calculated
        # reduction-values).
        idx += self.capacity
        self.value[idx] = val

        # Recalculate all affected reduction values (in "first half" of tree),
        # walking from the leaf's parent up to the root (index 1).
        idx >>= 1
        while idx >= 1:
            left = 2 * idx
            self.value[idx] = self.operation(self.value[left], self.value[left + 1])
            idx >>= 1

    def __getitem__(self, idx: int) -> Any:
        """Returns the value stored at index `idx` in [0, `self.capacity`)."""
        assert 0 <= idx < self.capacity
        return self.value[idx + self.capacity]

    def get_state(self) -> list:
        """Returns a copy of the full internal array (length 2 * capacity).

        A copy is returned so later updates to this tree don't leak into a
        previously taken snapshot.
        """
        return list(self.value)

    def set_state(self, state) -> None:
        """Restores the internal array from a `get_state()` result.

        `state` is copied, so the caller (or another tree sharing that list)
        can't mutate this tree's storage afterwards.
        """
        assert len(state) == self.capacity * 2
        self.value = list(state)


class SumSegmentTree(SegmentTree):
    """A SegmentTree with the reduction `operation`=operator.add."""

    def __init__(self, capacity: int):
        """Initializes a sum tree (neutral element resolves to 0.0).

        Args:
            capacity: Total size of the array - must be a power of two.
        """
        super().__init__(capacity=capacity, operation=operator.add)

    def sum(self, start: int = 0, end: Optional[int] = None) -> Any:
        """Returns arr[start] + ... + arr[end - 1] (`end` is excluded)."""
        return self.reduce(start, end)

    def find_prefixsum_idx(self, prefixsum: float) -> int:
        """Finds the index at which the running sum first exceeds `prefixsum`.

        Assuming non-negative values, this is the smallest i for which
        arr[0] + ... + arr[i] > prefixsum. If no such i exists (prefixsum is
        >= the total sum), the last index (`capacity - 1`) is returned.

        Args:
            prefixsum: The mass to locate; must be in [0, total sum + 1e-5].

        Returns:
            int: The leaf index (in [0, capacity)) satisfying above constraint.
        """
        assert 0 <= prefixsum <= self.sum() + 1e-5
        values = self.value
        # Start at the root node, which holds the sum over all leaves.
        idx = 1

        # Descend while on a non-leaf (first half of tree).
        while idx < self.capacity:
            left = 2 * idx
            if values[left] > prefixsum:
                # Target mass lies inside the left subtree.
                idx = left
            else:
                # Skip the left subtree's whole mass and continue right.
                prefixsum -= values[left]
                idx = left + 1
        return idx - self.capacity


class MinSegmentTree(SegmentTree):
    """A SegmentTree with the reduction `operation`=min."""

    def __init__(self, capacity: int):
        """Initializes a min tree (neutral element resolves to float("inf")).

        Args:
            capacity: Total size of the array - must be a power of two.
        """
        super().__init__(capacity=capacity, operation=min)

    def min(self, start: int = 0, end: Optional[int] = None) -> Any:
        """Returns min(arr[start], ..., arr[end - 1]) (`end` is excluded)."""
        return self.reduce(start, end)