Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

edgify / torch   python

Repository URL to install this package:

/ include / c10 / util / LeftRight.h

#include <c10/macros/Macros.h>
#include <c10/util/C++17.h>
#include <c10/util/Synchronized.h>
#include <array>
#include <atomic>
#include <functional>
#include <mutex>
#include <shared_mutex>
#include <thread>

namespace c10 {

namespace detail {

// Scope guard that keeps a +1 on an atomic counter for as long as the guard
// object is alive: the constructor adds one, the destructor subtracts one.
struct IncrementRAII final {
 public:
  // Stores the counter pointer and immediately bumps the pointed-to value.
  // The pointer is dereferenced unconditionally, so it must not be null.
  explicit IncrementRAII(std::atomic<int32_t>* counter) : _tracked(counter) {
    ++(*_tracked);
  }

  // Gives back the +1 taken in the constructor.
  ~IncrementRAII() {
    --(*_tracked);
  }

 private:
  // Non-owning pointer to the counter being held up.
  std::atomic<int32_t>* _tracked;

  // NOTE(review): project macro from c10/macros/Macros.h; presumably deletes
  // the copy constructor and copy assignment -- confirm against its definition.
  C10_DISABLE_COPY_AND_ASSIGN(IncrementRAII);
};

} // namespace detail

// LeftRight wait-free readers synchronization primitive
// https://hal.archives-ouvertes.fr/hal-01207881/document
//
// LeftRight is quite easy to use (it can make an arbitrary
// data structure permit wait-free reads), but it has some
// particular performance characteristics you should be aware
// of if you're deciding to use it:
//
//  - Reads still incur an atomic write (this is how LeftRight
//    keeps track of how long it needs to keep around the old
//    data structure)
//
//  - Writes get executed twice, to keep both the left and right
//    versions up to date.  So if your write is expensive or
//    nondeterministic, LeftRight is an inappropriate structure
//
// LeftRight is used fairly rarely in PyTorch's codebase.  If you
// are still not sure if you need it or not, consult your local
// C++ expert.
//
template <class T>
class LeftRight final {
 public:
  // Builds both internal instances of T from the same constructor arguments,
  // so the two copies start out identical. Both reader counters start at zero
  // and index 0 is the initial foreground for both data and counters.
  template <class... Args>
  explicit LeftRight(const Args&... args)
      : _counters{{{0}, {0}}},
        _foregroundCounterIndex(0),
        _foregroundDataIndex(0),
        _data{{T{args...}, T{args...}}},
        _writeMutex() {}

  // Copying and moving would not be threadsafe.
  // Needs more thought and careful design to make that work.
  LeftRight(const LeftRight&) = delete;
  LeftRight(LeftRight&&) noexcept = delete;
  LeftRight& operator=(const LeftRight&) = delete;
  LeftRight& operator=(LeftRight&&) noexcept = delete;

  // Blocks until no writer holds _writeMutex and both reader counters have
  // dropped to zero. The reader wait is a yield-based spin loop.
  ~LeftRight() {
    // wait until any potentially running writers are finished
    // (the lock is acquired and released right away; it is only a barrier)
    { std::unique_lock<std::mutex> lock(_writeMutex); }

    // wait until any potentially running readers are finished
    while (_counters[0].load() != 0 || _counters[1].load() != 0) {
      std::this_thread::yield();
    }
  }

  // Invokes readFunc with a const reference to the current foreground
  // instance and returns whatever readFunc returns.
  //
  // The reader does not take _writeMutex. Instead it increments the counter
  // selected by _foregroundCounterIndex for the duration of the call; the
  // IncrementRAII guard decrements it again when this function exits,
  // including when readFunc throws.
  template <typename F>
  auto read(F&& readFunc) const -> typename c10::invoke_result_t<F, const T&> {
    // Register as a reader first, then pick the data instance. The counter
    // index and the data index are loaded separately, so they may briefly
    // disagree; _write() accounts for that window (see steps 3-5 there).
    detail::IncrementRAII _increment_counter(
        &_counters[_foregroundCounterIndex.load()]);

    return readFunc(_data[_foregroundDataIndex.load()]);
  }

  // Applies writeFunc to both internal instances (one call per instance) and
  // returns the result of the second call. Writers are serialized through
  // _writeMutex, and a write waits (spinning with yield) for readers of the
  // instance it is about to modify.
  //
  // Throwing an exception in writeFunc is ok but causes the state to be either
  // the old or the new state, depending on if the first or the second call to
  // writeFunc threw.
  template <typename F>
  auto write(F&& writeFunc) -> typename c10::invoke_result_t<F, T&> {
    std::unique_lock<std::mutex> lock(_writeMutex);

    return _write(writeFunc);
  }

 private:
  // Core of write(); the caller must already hold _writeMutex.
  // The return value of the first writeFunc call (step 1) is discarded; the
  // value returned to the caller comes from the second call (step 6).
  template <class F>
  auto _write(const F& writeFunc) -> typename c10::invoke_result_t<F, T&> {
    /*
     * Assume, A is in background and B in foreground. In simplified terms, we
     * want to do the following:
     * 1. Write to A (old background)
     * 2. Switch A/B
     * 3. Write to B (new background)
     *
     * More detailed algorithm (explanations on why this is important are below
     * in code):
     * 1. Write to A
     * 2. Switch A/B data pointers
     * 3. Wait until A counter is zero
     * 4. Switch A/B counters
     * 5. Wait until B counter is zero
     * 6. Write to B
     */

    auto localDataIndex = _foregroundDataIndex.load();

    // 1. Write to A
    // (the helper writes to index localDataIndex ^ 1, i.e. the background)
    _callWriteFuncOnBackgroundInstance(writeFunc, localDataIndex);

    // 2. Switch A/B data pointers
    // From here on, new readers see the instance that was just written.
    localDataIndex = localDataIndex ^ 1;
    _foregroundDataIndex = localDataIndex;

    /*
     * 3. Wait until A counter is zero
     *
     * In the previous write run, A was foreground and B was background.
     * There was a time after switching _foregroundDataIndex (B to foreground)
     * and before switching _foregroundCounterIndex, in which new readers could
     * have read B but incremented A's counter.
     *
     * In this current run, we just switched _foregroundDataIndex (A back to
     * foreground), but before writing to the new background B, we have to make
     * sure A's counter was zero briefly, so all these old readers are gone.
     */
    auto localCounterIndex = _foregroundCounterIndex.load();
    _waitForBackgroundCounterToBeZero(localCounterIndex);

    /*
     * 4. Switch A/B counters
     *
     * Now that we know all readers on B are really gone, we can switch the
     * counters and have new readers increment A's counter again, which is the
     * correct counter since they're reading A.
     */
    localCounterIndex = localCounterIndex ^ 1;
    _foregroundCounterIndex = localCounterIndex;

    /*
     * 5. Wait until B counter is zero
     *
     * This waits for all the readers on B that came in while both data and
     * counter for B was in foreground, i.e. normal readers that happened
     * outside of that brief gap between switching data and counter.
     */
    _waitForBackgroundCounterToBeZero(localCounterIndex);

    // 6. Write to B
    // localDataIndex now names the new foreground, so the helper targets the
    // old foreground, which no reader is using any more.
    return _callWriteFuncOnBackgroundInstance(writeFunc, localDataIndex);
  }

  // Calls writeFunc on the instance that is NOT at localDataIndex (index
  // localDataIndex ^ 1) and returns its result. If writeFunc throws, that
  // instance is overwritten by copy-assigning from _data[localDataIndex] and
  // the exception is rethrown. This recovery path requires T to be
  // copy-assignable.
  template <class F>
  auto _callWriteFuncOnBackgroundInstance(
      const F& writeFunc,
      uint8_t localDataIndex) -> typename c10::invoke_result_t<F, T&> {
    try {
      return writeFunc(_data[localDataIndex ^ 1]);
    } catch (...) {
      // recover invariant by copying from the foreground instance
      _data[localDataIndex ^ 1] = _data[localDataIndex];
      // rethrow
      throw;
    }
  }

  // Spins (yielding the thread each iteration) until the counter that is NOT
  // at counterIndex (index counterIndex ^ 1) reads zero.
  void _waitForBackgroundCounterToBeZero(uint8_t counterIndex) {
    while (_counters[counterIndex ^ 1].load() != 0) {
      std::this_thread::yield();
    }
  }

  // Per-slot count of in-flight readers; mutable because the const read()
  // path increments and decrements it.
  mutable std::array<std::atomic<int32_t>, 2> _counters;
  // Index (0 or 1) of the counter that new readers increment.
  std::atomic<uint8_t> _foregroundCounterIndex;
  // Index (0 or 1) of the data instance that new readers are handed.
  std::atomic<uint8_t> _foregroundDataIndex;
  // The two instances of T; writes are applied to both, one after the other.
  std::array<T, 2> _data;
  // Serializes writers; also taken briefly by the destructor.
  std::mutex _writeMutex;
};

// RWSafeLeftRightWrapper is API compatible with LeftRight and uses a
// read-write lock to protect T (data).
template <class T>
class RWSafeLeftRightWrapper final {
 public:
  // Forwards the constructor arguments to the single wrapped instance.
  template <class... Args>
  explicit RWSafeLeftRightWrapper(const Args&... args) : data_{args...} {}

  // RWSafeLeftRightWrapper is not copyable or moveable since LeftRight
  // is not copyable or moveable.
  RWSafeLeftRightWrapper(const RWSafeLeftRightWrapper&) = delete;
  RWSafeLeftRightWrapper(RWSafeLeftRightWrapper&&) noexcept = delete;
  RWSafeLeftRightWrapper& operator=(const RWSafeLeftRightWrapper&) = delete;
  RWSafeLeftRightWrapper& operator=(RWSafeLeftRightWrapper&&) noexcept = delete;

  // Runs readFunc on a const reference to the wrapped value while
  // data_.withLock() is active, and hands back what withLock returns.
  // NOTE(review): the locking behaviour is whatever c10::Synchronized
  // provides; it is not visible from this file.
  template <typename F>
  auto read(F&& readFunc) const -> typename c10::invoke_result_t<F, const T&> {
    return data_.withLock(
        [&readFunc](const T& guarded) { return readFunc(guarded); });
  }

  // Runs writeFunc on a mutable reference to the wrapped value while
  // data_.withLock() is active, and hands back what withLock returns.
  // writeFunc is invoked once per write() call.
  template <typename F>
  auto write(F&& writeFunc) -> typename c10::invoke_result_t<F, T&> {
    return data_.withLock(
        [&writeFunc](T& guarded) { return writeFunc(guarded); });
  }

 private:
  // The single protected instance of T.
  c10::Synchronized<T> data_;
};

} // namespace c10