Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

neilisaac / torch   python

Repository URL to install this package:

Version: 1.8.0 

/ include / c10 / util / sparse_bitset.h

//===- llvm/ADT/SparseBitVector.h - Efficient Sparse BitVector --*- C++ -*-===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 //
 //===----------------------------------------------------------------------===//
 //
 // This file defines the SparseBitVector class.  See the doxygen comment for
 // SparseBitVector for more details on the algorithm used.
 //
 //===----------------------------------------------------------------------===//

#pragma once
 #include <cassert>
 #include <climits>
 #include <cstring>
 #include <iterator>
 #include <list>
 #include <stdexcept>
 #include <c10/util/llvmMathExtras.h>

 namespace c10 {

 /// SparseBitVector is an implementation of a bitvector that is sparse by only
 /// storing the elements that have non-zero bits set.  In order to make this
 /// fast for the most common cases, SparseBitVector is implemented as a linked
 /// list of SparseBitVectorElements.  We maintain a pointer to the last
 /// SparseBitVectorElement accessed (in the form of a list iterator), in order
 /// to make multiple in-order test/set constant time after the first one is
 /// executed.  Note that using vectors to store SparseBitVectorElements does
 /// not work out very well because it causes insertion in the middle to take
 /// enormous amounts of time with a large number of bits.  Other structures that
 /// have better worst cases for insertion in the middle (various balanced trees,
 /// etc) do not perform as well in practice as a linked list with this iterator
 /// kept up to date.  They are also significantly more memory intensive.

 /// One fixed-size chunk of a SparseBitVector: ElementSize bits stored in an
 /// array of BitWords, tagged with the chunk's position (ElementIndex) inside
 /// the whole vector.  Bit Idx of the chunk lives in word Idx / BITWORD_SIZE
 /// at bit position Idx % BITWORD_SIZE.
 template <unsigned ElementSize = 128> struct SparseBitVectorElement {
 public:
   using BitWord = unsigned long;
   using size_type = unsigned;
   enum {
     BITWORD_SIZE = sizeof(BitWord) * CHAR_BIT,
     BITWORDS_PER_ELEMENT = (ElementSize + BITWORD_SIZE - 1) / BITWORD_SIZE,
     BITS_PER_ELEMENT = ElementSize
   };

 private:
   // Which ElementSize-wide chunk of the full vector this element holds; its
   // bit 0 corresponds to bit ElementIndex * ElementSize of the whole vector.
   unsigned ElementIndex;
   BitWord Bits[BITWORDS_PER_ELEMENT];

   SparseBitVectorElement() {
     ElementIndex = ~0U;
     memset(&Bits[0], 0, sizeof (BitWord) * BITWORDS_PER_ELEMENT);
   }

   // Single-bit mask selecting bit Idx within its word.  The shifted operand
   // is a BitWord (unsigned): shifting the signed literal `1L` into the sign
   // bit is undefined in C++11 and relies on an implementation-defined
   // conversion until C++20.
   static BitWord bitMask(unsigned Idx) {
     return BitWord(1) << (Idx % BITWORD_SIZE);
   }

 public:
   // Create an all-zero element covering chunk number Idx.
   explicit SparseBitVectorElement(unsigned Idx) {
     ElementIndex = Idx;
     memset(&Bits[0], 0, sizeof (BitWord) * BITWORDS_PER_ELEMENT);
   }

   // Comparison: equal when both cover the same chunk and hold the same bits.
   bool operator==(const SparseBitVectorElement &RHS) const {
     if (ElementIndex != RHS.ElementIndex)
       return false;
     for (unsigned i = 0; i < BITWORDS_PER_ELEMENT; ++i)
       if (Bits[i] != RHS.Bits[i])
         return false;
     return true;
   }

   bool operator!=(const SparseBitVectorElement &RHS) const {
     return !(*this == RHS);
   }

   // Return the bits that make up word Idx in our element.
   BitWord word(unsigned Idx) const {
     assert(Idx < BITWORDS_PER_ELEMENT);
     return Bits[Idx];
   }

   unsigned index() const {
     return ElementIndex;
   }

   // True when no bit of this element is set.
   bool empty() const {
     for (unsigned i = 0; i < BITWORDS_PER_ELEMENT; ++i)
       if (Bits[i])
         return false;
     return true;
   }

   // Set bit Idx (relative to the start of this element).
   void set(unsigned Idx) {
     Bits[Idx / BITWORD_SIZE] |= bitMask(Idx);
   }

   // Set bit Idx; return true if it was previously clear.
   bool test_and_set(unsigned Idx) {
     bool old = test(Idx);
     if (!old) {
       set(Idx);
       return true;
     }
     return false;
   }

   // Clear bit Idx (relative to the start of this element).
   void reset(unsigned Idx) {
     Bits[Idx / BITWORD_SIZE] &= ~bitMask(Idx);
   }

   // Return whether bit Idx (relative to the start of this element) is set.
   bool test(unsigned Idx) const {
     return (Bits[Idx / BITWORD_SIZE] & bitMask(Idx)) != 0;
   }

   // Number of set bits in this element.
   size_type count() const {
     size_type NumBits = 0;
     for (unsigned i = 0; i < BITWORDS_PER_ELEMENT; ++i)
       NumBits += llvm::countPopulation(Bits[i]);
     return NumBits;
   }

   /// find_first - Returns the index of the first set bit.
   /// Throws std::runtime_error if the element has no set bits.
   int find_first() const {
     for (unsigned i = 0; i < BITWORDS_PER_ELEMENT; ++i)
       if (Bits[i] != 0)
         return i * BITWORD_SIZE + llvm::countTrailingZeros(Bits[i]);
     throw std::runtime_error("Illegal empty element");
   }

   /// find_last - Returns the index of the last set bit.
   /// Throws std::runtime_error if the element has no set bits.
   int find_last() const {
     for (unsigned I = 0; I < BITWORDS_PER_ELEMENT; ++I) {
       unsigned Idx = BITWORDS_PER_ELEMENT - I - 1;
       if (Bits[Idx] != 0) {
         // The highest set bit of a non-zero word sits at position
         // BITWORD_SIZE - 1 - countLeadingZeros(word); without the "- 1"
         // the result would be one past the last set bit.
         return Idx * BITWORD_SIZE + BITWORD_SIZE - 1 -
                llvm::countLeadingZeros(Bits[Idx]);
       }
     }
     throw std::runtime_error("Illegal empty element");
   }

   /// find_next - Returns the index of the next set bit starting from the
   /// "Curr" bit (inclusive). Returns -1 if the next set bit is not found.
   int find_next(unsigned Curr) const {
     if (Curr >= BITS_PER_ELEMENT)
       return -1;

     unsigned WordPos = Curr / BITWORD_SIZE;
     unsigned BitPos = Curr % BITWORD_SIZE;
     // Validate the word index before it is used to read Bits.
     assert(WordPos < BITWORDS_PER_ELEMENT
            && "Word Position outside of element");

     // Mask off the bits below Curr within its own word.
     BitWord Copy = Bits[WordPos] & (~BitWord(0) << BitPos);

     if (Copy != 0)
       return WordPos * BITWORD_SIZE + llvm::countTrailingZeros(Copy);

     // Check subsequent words.
     for (unsigned i = WordPos+1; i < BITWORDS_PER_ELEMENT; ++i)
       if (Bits[i] != 0)
         return i * BITWORD_SIZE + llvm::countTrailingZeros(Bits[i]);
     return -1;
   }

   // Union this element with RHS and return true if this one changed.
   bool unionWith(const SparseBitVectorElement &RHS) {
     bool changed = false;
     for (unsigned i = 0; i < BITWORDS_PER_ELEMENT; ++i) {
       // Once a change is seen there is no need to keep comparing.
       BitWord old = changed ? 0 : Bits[i];

       Bits[i] |= RHS.Bits[i];
       if (!changed && old != Bits[i])
         changed = true;
     }
     return changed;
   }

   // Return true if we have any bits in common with RHS
   bool intersects(const SparseBitVectorElement &RHS) const {
     for (unsigned i = 0; i < BITWORDS_PER_ELEMENT; ++i) {
       if (RHS.Bits[i] & Bits[i])
         return true;
     }
     return false;
   }

   // Intersect this Element with RHS and return true if this one changed.
   // BecameZero is set to true if this element became all-zero bits.
   bool intersectWith(const SparseBitVectorElement &RHS,
                      bool &BecameZero) {
     bool changed = false;
     bool allzero = true;

     for (unsigned i = 0; i < BITWORDS_PER_ELEMENT; ++i) {
       BitWord old = changed ? 0 : Bits[i];

       Bits[i] &= RHS.Bits[i];
       if (Bits[i] != 0)
         allzero = false;

       if (!changed && old != Bits[i])
         changed = true;
     }
     BecameZero = allzero;
     return changed;
   }

   // Intersect this Element with the complement of RHS and return true if this
   // one changed.  BecameZero is set to true if this element became all-zero
   // bits.
   bool intersectWithComplement(const SparseBitVectorElement &RHS,
                                bool &BecameZero) {
     bool changed = false;
     bool allzero = true;

     for (unsigned i = 0; i < BITWORDS_PER_ELEMENT; ++i) {
       BitWord old = changed ? 0 : Bits[i];

       Bits[i] &= ~RHS.Bits[i];
       if (Bits[i] != 0)
         allzero = false;

       if (!changed && old != Bits[i])
         changed = true;
     }
     BecameZero = allzero;
     return changed;
   }

   // Three argument version of intersectWithComplement that intersects
   // RHS1 & ~RHS2 into this element.  The previous contents are overwritten;
   // BecameZero is set to true if the result has no set bits.
   void intersectWithComplement(const SparseBitVectorElement &RHS1,
                                const SparseBitVectorElement &RHS2,
                                bool &BecameZero) {
     bool allzero = true;

     for (unsigned i = 0; i < BITWORDS_PER_ELEMENT; ++i) {
       Bits[i] = RHS1.Bits[i] & ~RHS2.Bits[i];
       if (Bits[i] != 0)
         allzero = false;
     }
     BecameZero = allzero;
   }
 };

 template <unsigned ElementSize = 128>
 class SparseBitVector {
   // Backing storage type: one std::list node per SparseBitVectorElement.
   using ElementList = std::list<SparseBitVectorElement<ElementSize>>;
   using ElementListIter = typename ElementList::iterator;
   using ElementListConstIter = typename ElementList::const_iterator;
   // Width in bits of one storage word, mirrored from the element type.
   enum {
     BITWORD_SIZE = SparseBitVectorElement<ElementSize>::BITWORD_SIZE
   };

   // The stored elements.  NOTE(review): FindLowerBoundImpl's forward/backward
   // search only makes sense if these are kept in ascending index() order;
   // presumably the mutating methods maintain that — confirm.
   ElementList Elements;
   // Pointer to our current Element. This has no visible effect on the external
   // state of a SparseBitVector, it's just used to improve performance in the
   // common case of testing/modifying bits with similar indices.  It may equal
   // Elements.end(); FindLowerBoundImpl steps it back before dereferencing.
   mutable ElementListIter CurrElementIter;

   // This is like std::lower_bound, except we do linear searching from the
   // current position.
   // Linear search for ElementIndex that starts at the cached position
   // CurrElementIter rather than at the front of the list, so consecutive
   // lookups of nearby indices are cheap.  Outcome:
   //  * no elements: Elements.end() (which equals begin()).
   //  * cached element has index ElementIndex: that element.
   //  * cached index below ElementIndex: walking forward, the first element
   //    whose index is >= ElementIndex, or Elements.end().
   //  * cached index above ElementIndex: walking backward, the first element
   //    whose index is <= ElementIndex, or Elements.begin() if none qualifies.
   // The result's index() can therefore differ from ElementIndex.  The cache
   // is left pointing at the returned position.
   //
   // The cache is a mutable (non-const) iterator, so the list endpoints are
   // obtained through a const_cast view of *this.  Keeping this function const
   // lets FindLowerBound and FindLowerBoundConst share one implementation.
   ElementListIter FindLowerBoundImpl(unsigned ElementIndex) const {
     auto *MutableThis = const_cast<SparseBitVector<ElementSize> *>(this);
     const ElementListIter ListBegin = MutableThis->Elements.begin();
     const ElementListIter ListEnd = MutableThis->Elements.end();

     // Nothing to search; reset the cache to the (empty) list's begin().
     if (Elements.empty()) {
       CurrElementIter = ListBegin;
       return CurrElementIter;
     }

     // The cache may sit at end(); move it onto the last real element.
     if (CurrElementIter == ListEnd)
       --CurrElementIter;

     const unsigned CachedIndex = CurrElementIter->index();
     if (CachedIndex == ElementIndex)
       return CurrElementIter;

     ElementListIter Cursor = CurrElementIter;
     if (ElementIndex < CachedIndex) {
       // Target lies before the cache: walk toward the front.
       while (Cursor != ListBegin && Cursor->index() > ElementIndex)
         --Cursor;
     } else {
       // Target lies after the cache: walk toward the back.
       while (Cursor != ListEnd && Cursor->index() < ElementIndex)
         ++Cursor;
     }

     CurrElementIter = Cursor;
     return Cursor;
   }
   // Read-only lookup: runs the shared cached search and hands the position
   // back as a const_iterator.
   ElementListConstIter FindLowerBoundConst(unsigned ElementIndex) const {
     return ElementListConstIter(FindLowerBoundImpl(ElementIndex));
   }
   // Mutable lookup: the shared implementation is a const member that already
   // yields a mutable iterator, so simply forward to it via a const view.
   ElementListIter FindLowerBound(unsigned ElementIndex) {
     const SparseBitVector<ElementSize> &Self = *this;
     return Self.FindLowerBoundImpl(ElementIndex);
   }

   // Iterator to walk set bits in the bitmap.  This iterator is a lot uglier
   // than it would be, in order to be efficient.
   class SparseBitVectorIterator {
   private:
     // True when the iterator is in its end state.  AdvanceToFirstNonZero sets
     // it when the vector holds no elements; other transitions are defined
     // outside this view.
     bool AtEnd;

     // The vector whose set bits are being walked.
     const SparseBitVector<ElementSize> *BitVector = nullptr;

     // Current element inside of bitmap.
     ElementListConstIter Iter;

     // Current bit number inside of our bitmap, as an absolute position
     // (Iter->index() * ElementSize plus the offset inside the element).
     unsigned BitNumber;

     // Current word number inside of our element.
     unsigned WordNumber;

     // Current bits from the element.  AdvanceToFirstNonZero shifts them right
     // so that bit 0 corresponds to BitNumber.
     typename SparseBitVectorElement<ElementSize>::BitWord Bits;

     // Move our iterator to the first non-zero bit in the bitmap.
     void AdvanceToFirstNonZero() {
       if (AtEnd)
         return;
       if (BitVector->Elements.empty()) {
         AtEnd = true;
         return;
       }
       Iter = BitVector->Elements.begin();
       BitNumber = Iter->index() * ElementSize;
       unsigned BitPos = Iter->find_first();
       BitNumber += BitPos;
       WordNumber = (BitNumber % ElementSize) / BITWORD_SIZE;
       Bits = Iter->word(WordNumber);
       Bits >>= BitPos % BITWORD_SIZE;
     }
Loading ...