Why Gemfury? Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Debian packages RPM packages NuGet packages

Repository URL to install this package:

Details    
torch / include / torch / csrc / distributed / c10d / RankLocal.hpp
Size: Mime:

#pragma once

#include <shared_mutex>

#include <torch/csrc/autograd/function.h>

namespace c10d {

// `RankLocal` maintains a unique instance of T for each non-autograd thread.
// For non-autograd threads, `RankLocal<T>::get()` functions similar to
// thread_local. For autograd threads, `RankLocal<T>::get()` returns the
// instance of T corresponding to the enqueuing non-autograd thread. The
// mechanism allows for rank-specific context shared between forward and
// backward. It works for both the one-rank-per-process and one-rank-per-thread
// scenarios.
//
// NOTE: RankLocal doesn't make the underlying objects thread-safe.
template <typename T>
class RankLocal {
 public:
  RankLocal(const RankLocal&) = delete;
  RankLocal& operator=(const RankLocal&) = delete;

  static T& get() {
    // Fast path: non-autograd threads can simply return
    // the object reference cached in TLS.
    if (cached_ != nullptr) {
      return *cached_;
    }
    const auto node = torch::autograd::get_current_node();
    auto fwd_thread_id = node == nullptr ? at::RecordFunction::currentThreadId()
                                         : node->thread_id();
    // Optimistically acquire the read lock first, since most likely we are in
    // an autograd thread and the object has already been constructed.
    {
      std::shared_lock read_lock(lock_);
      auto it = thread_id_to_rank_local_.find(fwd_thread_id);
      if (it != thread_id_to_rank_local_.end()) {
        // Cache for non-autograd threads
        if (node == nullptr) {
          cached_ = &it->second;
        }
        return it->second;
      }
    }

    std::unique_lock write_lock(lock_);
    auto [it, _] = thread_id_to_rank_local_.try_emplace(fwd_thread_id);
    // Cache for non-autograd threads
    if (node == nullptr) {
      cached_ = &it->second;
    }
    return it->second;
  }

 private:
  RankLocal(){};
  thread_local static T* cached_;
  static std::unordered_map<uint64_t, T> thread_id_to_rank_local_;
  static std::shared_mutex lock_;
};

template <typename T>
thread_local T* RankLocal<T>::cached_ = nullptr;

template <typename T>
std::unordered_map<uint64_t, T> RankLocal<T>::thread_id_to_rank_local_;

template <typename T>
std::shared_mutex RankLocal<T>::lock_;

} // namespace c10d