Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

edgify / torch   python

Repository URL to install this package:

/ include / c10 / util / sparse_bitset.h

//===- llvm/ADT/SparseBitVector.h - Efficient Sparse BitVector --*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines the SparseBitVector class.  See the doxygen comment for
// SparseBitVector for more details on the algorithm used.
//
//===----------------------------------------------------------------------===//

#pragma once
#include <c10/macros/Macros.h>
#include <c10/util/llvmMathExtras.h>
#include <cassert>
#include <climits>
#include <cstring>
#include <iterator>
#include <list>
#include <stdexcept>

C10_CLANG_DIAGNOSTIC_PUSH()
#if C10_CLANG_HAS_WARNING("-Wshorten-64-to-32")
C10_CLANG_DIAGNOSTIC_IGNORE("-Wshorten-64-to-32")
#endif

namespace c10 {

/// SparseBitVector is an implementation of a bitvector that is sparse by only
/// storing the elements that have non-zero bits set.  In order to make this
/// fast for the most common cases, SparseBitVector is implemented as a linked
/// list of SparseBitVectorElements.  We maintain a pointer to the last
/// SparseBitVectorElement accessed (in the form of a list iterator), in order
/// to make multiple in-order test/set constant time after the first one is
/// executed.  Note that using vectors to store SparseBitVectorElement's does
/// not work out very well because it causes insertion in the middle to take
/// enormous amounts of time with a large amount of bits.  Other structures that
/// have better worst cases for insertion in the middle (various balanced trees,
/// etc) do not perform as well in practice as a linked list with this iterator
/// kept up to date.  They are also significantly more memory intensive.

/// One fixed-width chunk of a sparse bit vector: ElementSize bits held in an
/// array of BitWord, together with the chunk's index.  Bit positions passed to
/// the member functions are relative to the start of this element.
template <unsigned ElementSize = 128>
struct SparseBitVectorElement {
 public:
  using BitWord = unsigned long;
  using size_type = unsigned;
  enum {
    BITWORD_SIZE = sizeof(BitWord) * CHAR_BIT,
    BITWORDS_PER_ELEMENT = (ElementSize + BITWORD_SIZE - 1) / BITWORD_SIZE,
    BITS_PER_ELEMENT = ElementSize
  };

 private:
  // Index of Element in terms of where first bit starts.
  unsigned ElementIndex;
  BitWord Bits[BITWORDS_PER_ELEMENT];

  // Builds an all-zero element whose index is the all-ones value (~0U).
  SparseBitVectorElement() : ElementIndex(~0U), Bits{} {}

  // Single-bit mask selecting bit Idx inside its word.  The literal is cast to
  // BitWord so the shift is done on the unsigned word type (never on a signed
  // long) and stays correct if BitWord is ever changed.
  static BitWord bitMask(unsigned Idx) {
    return static_cast<BitWord>(1) << (Idx % BITWORD_SIZE);
  }

 public:
  /// Creates an element with index Idx and every bit cleared.
  explicit SparseBitVectorElement(unsigned Idx) : ElementIndex(Idx), Bits{} {}

  // Comparison: equal when both the index and every word match.
  bool operator==(const SparseBitVectorElement& RHS) const {
    if (ElementIndex != RHS.ElementIndex)
      return false;
    for (unsigned i = 0; i < BITWORDS_PER_ELEMENT; ++i)
      if (Bits[i] != RHS.Bits[i])
        return false;
    return true;
  }

  bool operator!=(const SparseBitVectorElement& RHS) const {
    return !(*this == RHS);
  }

  // Return the bits that make up word Idx in our element.
  BitWord word(unsigned Idx) const {
    assert(Idx < BITWORDS_PER_ELEMENT);
    return Bits[Idx];
  }

  /// Returns this element's index.
  unsigned index() const {
    return ElementIndex;
  }

  /// Returns true if no bit in this element is set.
  bool empty() const {
    for (unsigned i = 0; i < BITWORDS_PER_ELEMENT; ++i)
      if (Bits[i] != 0)
        return false;
    return true;
  }

  /// Sets bit Idx.
  void set(unsigned Idx) {
    Bits[Idx / BITWORD_SIZE] |= bitMask(Idx);
  }

  /// Sets bit Idx and returns true if it was previously clear; returns false
  /// (leaving the element untouched) if it was already set.
  bool test_and_set(unsigned Idx) {
    if (test(Idx))
      return false;
    set(Idx);
    return true;
  }

  /// Clears bit Idx.
  void reset(unsigned Idx) {
    Bits[Idx / BITWORD_SIZE] &= ~bitMask(Idx);
  }

  /// Returns true if bit Idx is set.
  bool test(unsigned Idx) const {
    return (Bits[Idx / BITWORD_SIZE] & bitMask(Idx)) != 0;
  }

  /// Returns the number of set bits in this element.
  size_type count() const {
    unsigned NumBits = 0;
    for (unsigned i = 0; i < BITWORDS_PER_ELEMENT; ++i)
      NumBits += llvm::countPopulation(Bits[i]);
    return NumBits;
  }

  /// find_first - Returns the index of the first set bit.
  /// Throws std::runtime_error if the element has no set bit.
  int find_first() const {
    for (unsigned i = 0; i < BITWORDS_PER_ELEMENT; ++i)
      if (Bits[i] != 0)
        return i * BITWORD_SIZE + llvm::countTrailingZeros(Bits[i]);
    throw std::runtime_error("Illegal empty element");
  }

  /// find_last - Returns the index of the last set bit.
  /// Throws std::runtime_error if the element has no set bit.
  int find_last() const {
    for (unsigned I = 0; I < BITWORDS_PER_ELEMENT; ++I) {
      unsigned Idx = BITWORDS_PER_ELEMENT - I - 1;
      if (Bits[Idx] != 0) {
        // BITWORD_SIZE - clz is the bit *width* of the word, i.e. one past
        // the highest set bit; subtract 1 to get that bit's index.
        return Idx * BITWORD_SIZE + BITWORD_SIZE -
            llvm::countLeadingZeros(Bits[Idx]) - 1;
      }
    }
    throw std::runtime_error("Illegal empty element");
  }

  /// find_next - Returns the index of the next set bit starting from the
  /// "Curr" bit (inclusive). Returns -1 if the next set bit is not found.
  int find_next(unsigned Curr) const {
    if (Curr >= BITS_PER_ELEMENT)
      return -1;

    unsigned WordPos = Curr / BITWORD_SIZE;
    unsigned BitPos = Curr % BITWORD_SIZE;
    // Validate the word position before it is used to index Bits.
    assert(
        WordPos < BITWORDS_PER_ELEMENT && "Word Position outside of element");
    BitWord Copy = Bits[WordPos];

    // Mask off previous bits.
    Copy &= ~static_cast<BitWord>(0) << BitPos;

    if (Copy != 0)
      return WordPos * BITWORD_SIZE + llvm::countTrailingZeros(Copy);

    // Check subsequent words.
    for (unsigned i = WordPos + 1; i < BITWORDS_PER_ELEMENT; ++i)
      if (Bits[i] != 0)
        return i * BITWORD_SIZE + llvm::countTrailingZeros(Bits[i]);
    return -1;
  }

  // Union this element with RHS and return true if this one changed.
  bool unionWith(const SparseBitVectorElement& RHS) {
    bool changed = false;
    for (unsigned i = 0; i < BITWORDS_PER_ELEMENT; ++i) {
      const BitWord old = Bits[i];
      Bits[i] |= RHS.Bits[i];
      if (old != Bits[i])
        changed = true;
    }
    return changed;
  }

  // Return true if we have any bits in common with RHS
  bool intersects(const SparseBitVectorElement& RHS) const {
    for (unsigned i = 0; i < BITWORDS_PER_ELEMENT; ++i) {
      if ((RHS.Bits[i] & Bits[i]) != 0)
        return true;
    }
    return false;
  }

  // Intersect this Element with RHS and return true if this one changed.
  // BecameZero is set to true if this element became all-zero bits.
  bool intersectWith(const SparseBitVectorElement& RHS, bool& BecameZero) {
    bool changed = false;
    bool allzero = true;

    for (unsigned i = 0; i < BITWORDS_PER_ELEMENT; ++i) {
      const BitWord old = Bits[i];
      Bits[i] &= RHS.Bits[i];
      if (Bits[i] != 0)
        allzero = false;
      if (old != Bits[i])
        changed = true;
    }
    BecameZero = allzero;
    return changed;
  }

  // Intersect this Element with the complement of RHS and return true if this
  // one changed.  BecameZero is set to true if this element became all-zero
  // bits.
  bool intersectWithComplement(
      const SparseBitVectorElement& RHS,
      bool& BecameZero) {
    bool changed = false;
    bool allzero = true;

    for (unsigned i = 0; i < BITWORDS_PER_ELEMENT; ++i) {
      const BitWord old = Bits[i];
      Bits[i] &= ~RHS.Bits[i];
      if (Bits[i] != 0)
        allzero = false;
      if (old != Bits[i])
        changed = true;
    }
    BecameZero = allzero;
    return changed;
  }

  // Three argument version of intersectWithComplement that intersects
  // RHS1 & ~RHS2 into this element.  The previous contents of this element
  // are overwritten; BecameZero is set to true if the result is all-zero.
  void intersectWithComplement(
      const SparseBitVectorElement& RHS1,
      const SparseBitVectorElement& RHS2,
      bool& BecameZero) {
    bool allzero = true;

    for (unsigned i = 0; i < BITWORDS_PER_ELEMENT; ++i) {
      Bits[i] = RHS1.Bits[i] & ~RHS2.Bits[i];
      if (Bits[i] != 0)
        allzero = false;
    }
    BecameZero = allzero;
  }
};

template <unsigned ElementSize = 128>
class SparseBitVector {
  using ElementList = std::list<SparseBitVectorElement<ElementSize>>;
  using ElementListIter = typename ElementList::iterator;
  using ElementListConstIter = typename ElementList::const_iterator;
  enum { BITWORD_SIZE = SparseBitVectorElement<ElementSize>::BITWORD_SIZE };

  ElementList Elements;
  // Pointer to our current Element. This has no visible effect on the external
  // state of a SparseBitVector, it's just used to improve performance in the
  // common case of testing/modifying bits with similar indices.
  mutable ElementListIter CurrElementIter;

  // This is like std::lower_bound, except we do linear searching from the
  // current position.
  // Linear search for ElementIndex that starts from the cached iterator and
  // walks towards the target.  The cache (CurrElementIter) is refreshed with
  // the position the search ends on.
  //
  // Result:
  //  - empty list: begin() (== end()).
  //  - cached element already has ElementIndex: that element.
  //  - target is ahead of the cache: the first element whose index is
  //    >= ElementIndex, or end().
  //  - target is behind the cache: the walk goes backwards and stops at the
  //    first element whose index is <= ElementIndex, or at begin(), so the
  //    returned element's index is not guaranteed to be >= ElementIndex.
  //
  // The function is const so both the const and non-const wrappers can share
  // it; the list's mutable iterators are obtained through one const_cast.
  ElementListIter FindLowerBoundImpl(unsigned ElementIndex) const {
    auto* Self = const_cast<SparseBitVector<ElementSize>*>(this);
    const ElementListIter First = Self->Elements.begin();
    const ElementListIter PastLast = Self->Elements.end();

    // Nothing to search; park the cache on begin().
    if (Elements.empty()) {
      CurrElementIter = First;
      return CurrElementIter;
    }

    // The cache may be sitting on end(); step back onto a real element.
    if (CurrElementIter == PastLast)
      --CurrElementIter;

    const unsigned CachedIndex = CurrElementIter->index();
    if (CachedIndex == ElementIndex)
      return CurrElementIter;

    ElementListIter Cursor = CurrElementIter;
    if (CachedIndex > ElementIndex) {
      // Target lies before the cached element: walk backwards.
      for (; Cursor != First && Cursor->index() > ElementIndex; --Cursor) {
      }
    } else {
      // Target lies after the cached element: walk forwards.
      for (; Cursor != PastLast && Cursor->index() < ElementIndex; ++Cursor) {
      }
    }

    CurrElementIter = Cursor;
    return Cursor;
  }
  // Const lookup entry point: forwards to FindLowerBoundImpl and returns the
  // result as a const_iterator (the mutable iterator converts on return).
  // The cached CurrElementIter may be updated by the call.
  ElementListConstIter FindLowerBoundConst(unsigned ElementIndex) const {
    return FindLowerBoundImpl(ElementIndex);
  }
  // Non-const lookup entry point: forwards to FindLowerBoundImpl and returns
  // the mutable iterator it produces.  The cached CurrElementIter may be
  // updated by the call.
  ElementListIter FindLowerBound(unsigned ElementIndex) {
    return FindLowerBoundImpl(ElementIndex);
  }

  // Iterator to walk set bits in the bitmap.  This iterator is a lot uglier
  // than it would be, in order to be efficient.
  class SparseBitVectorIterator {
   private:
    bool AtEnd;

    const SparseBitVector<ElementSize>* BitVector = nullptr;

    // Current element inside of bitmap.
    ElementListConstIter Iter;

    // Current bit number inside of our bitmap.
    unsigned BitNumber;

    // Current word number inside of our element.
    unsigned WordNumber;

    // Current bits from the element.
    typename SparseBitVectorElement<ElementSize>::BitWord Bits;

    // Move our iterator to the first non-zero bit in the bitmap.
    void AdvanceToFirstNonZero() {
      if (AtEnd)
        return;
      if (BitVector->Elements.empty()) {
        AtEnd = true;
        return;
      }
      Iter = BitVector->Elements.begin();
      BitNumber = Iter->index() * ElementSize;
      unsigned BitPos = Iter->find_first();
      BitNumber += BitPos;
      WordNumber = (BitNumber % ElementSize) / BITWORD_SIZE;
      Bits = Iter->word(WordNumber);
Loading ...