Why Gemfury? Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

neilisaac / torch   python

Repository URL to install this package:

Version: 1.8.0 

/ include / caffe2 / predictor / ThreadLocalPtr.h

#pragma once

#include <memory>
#include <mutex>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "caffe2/core/logging.h"

namespace caffe2 {

/**
 * thread_local pointer in C++ is a per thread pointer. However, sometimes
 * we want to have a thread local state that is per thread and also per
 * instance. e.g. we have the following class:
 * class A {
 *   ThreadLocalPtr<int> x;
 * }
 * We would like to have a copy of x per thread and also per instance of class A.
 * This can be applied to storing per instance thread local state of some class,
 * when we could have multiple instances of the class in the same thread.
 * We implemented a subset of functions in folly::ThreadLocalPtr that's enough
 * to support BlackBoxPredictor.
 */

class ThreadLocalPtrImpl;
class ThreadLocalHelper;

/**
 * Object pointer -> stored instance, one such map per thread.
 * Combining a per-thread map (thread_local) with a per-object key (the
 * ThreadLocalPtrImpl address) yields a pointer that is both per thread and
 * per object. The map itself performs no locking, hence "Unsafe".
 */
using UnsafeThreadLocalMap =
    std::unordered_map<ThreadLocalPtrImpl*, std::shared_ptr<void>>;

// Accessor for the ThreadLocalHelper that ThreadLocalPtrImpl::get()/reset()
// use to look up and store per-object state. Defined out of line.
// NOTE(review): presumably returns the calling thread's helper, since
// ThreadLocalHelper is documented as per thread - confirm in the definition.
ThreadLocalHelper* getThreadLocalHelper();

// Plain (unsynchronized) list of non-owning ThreadLocalHelper pointers.
using UnsafeAllThreadLocalHelperVector = std::vector<ThreadLocalHelper*>;

/**
 * A thread safe vector of all ThreadLocalHelper instances. It is used
 * to encapsulate the locking in the APIs for changes to the global
 * AllThreadLocalHelperVector instance.
 *
 * The vector stores raw ThreadLocalHelper pointers; all member functions
 * are declared here and defined out of line.
 */
class AllThreadLocalHelperVector {
 public:
  AllThreadLocalHelperVector() {}

  // Add a new ThreadLocalHelper to the vector
  void push_back(ThreadLocalHelper* helper);

  // Erase a ThreadLocalHelper from the vector
  void erase(ThreadLocalHelper* helper);

  // Erase the given object (ptr) from all the helpers stored in the vector.
  // Called during the destructor of a ThreadLocalPtrImpl.
  void erase_tlp(ThreadLocalPtrImpl* ptr);

 private:
  // The registered helpers (raw pointers).
  UnsafeAllThreadLocalHelperVector vector_;
  // NOTE(review): presumably guards vector_ in the methods above, per the
  // class comment about encapsulated locking - confirm in the definitions.
  std::mutex mutex_;
};

/**
 * ThreadLocalHelper is per thread.
 *
 * It holds the mapping from object (ThreadLocalPtrImpl*) to the stored,
 * type-erased pointer (std::shared_ptr<void>). All member functions are
 * declared here and defined out of line.
 */
class ThreadLocalHelper {
 public:
  ThreadLocalHelper();

  // When the thread dies, we want to clean up *this*
  // in AllThreadLocalHelperVector
  ~ThreadLocalHelper();

  // Insert an (object, ptr) pair into the thread local map.
  // ptr is taken by value as a shared_ptr<void>.
  void insert(ThreadLocalPtrImpl* tl_ptr, std::shared_ptr<void> ptr);
  // Get the raw ptr associated with the object `key`.
  // NOTE(review): the result for a key that was never inserted is not
  // visible from this header (presumably nullptr) - confirm.
  void* get(ThreadLocalPtrImpl* key);
  // Erase the ptr associated with the object in the map
  void erase(ThreadLocalPtrImpl* key);

 private:
  // mapping of object -> ptr in each thread
  UnsafeThreadLocalMap mapping_;
  // NOTE(review): presumably guards mapping_, which can also be touched from
  // another thread via AllThreadLocalHelperVector::erase_tlp - confirm.
  std::mutex mutex_;
}; // ThreadLocalHelper

/**
 * ThreadLocalPtrImpl is per object.
 *
 * The address of the instance (`this`) is the key under which values are
 * stored in the ThreadLocalHelper obtained from getThreadLocalHelper().
 * The class itself holds no data members.
 */
class ThreadLocalPtrImpl {
 public:
  ThreadLocalPtrImpl() {}
  // Non-copyable and non-movable: `this` is used as the lookup key, so a
  // copied or moved instance would not refer to the same entries.
  ThreadLocalPtrImpl(const ThreadLocalPtrImpl&) = delete;
  ThreadLocalPtrImpl(ThreadLocalPtrImpl&&) = delete;
  ThreadLocalPtrImpl& operator=(const ThreadLocalPtrImpl&) = delete;
  ThreadLocalPtrImpl& operator=(ThreadLocalPtrImpl&&) = delete;

  // In the case when object dies first, we want to
  // clean up the states in all child threads
  ~ThreadLocalPtrImpl();

  /**
   * Look up the value stored for this object in the helper returned by
   * getThreadLocalHelper() and cast it to T*.
   *
   * The stored pointer is type-erased (void*), so T must be the same type
   * that was passed to reset(); this is not checked.
   */
  template <typename T>
  T* get() {
    return static_cast<T*>(getThreadLocalHelper()->get(this));
  }

  /**
   * Replace the value stored for this object in the helper returned by
   * getThreadLocalHelper().
   *
   * Any existing entry for this object is erased first. If newPtr is
   * non-null, ownership of it is taken: it is wrapped in a
   * std::shared_ptr<void> and inserted. Calling reset() with no argument
   * (or nullptr) only erases.
   */
  template <typename T>
  void reset(T* newPtr = nullptr) {
    VLOG(2) << "In Reset(" << newPtr << ")";
    auto* wrapper = getThreadLocalHelper();
    // Cleaning up the objects(T) stored in the ThreadLocalPtrImpl in the thread
    wrapper->erase(this);
    if (newPtr != nullptr) {
      // Constructed from a T*, so the shared_ptr's type-erased deleter
      // deletes the object as a T even though it is held as void.
      std::shared_ptr<void> sharedPtr(newPtr);
      wrapper->insert(this, std::move(sharedPtr));
    }
  }

}; // ThreadLocalPtrImpl

/**
 * Typed front end over ThreadLocalPtrImpl: a pointer to T whose value is
 * looked up through the wrapped ThreadLocalPtrImpl, which keys the lookup
 * on its own address. Implements the subset of a smart-pointer interface
 * declared below (get, operator->, operator*, reset).
 *
 * Not copyable or movable, because ThreadLocalPtrImpl deletes those
 * operations.
 */
template <typename T>
class ThreadLocalPtr {
 public:
  // Member access on the stored T. No null check is performed.
  auto* operator->() {
    return get();
  }

  // Dereference the stored T. No null check is performed.
  auto& operator*() {
    return *get();
  }

  // Returns the T* obtained from the underlying ThreadLocalPtrImpl.
  // NOTE(review): presumably nullptr when nothing has been stored for the
  // current thread - depends on ThreadLocalHelper::get; confirm.
  auto* get() {
    return impl_.get<T>();
  }

  // Const overloads: constness of the ThreadLocalPtr does not propagate to
  // the pointee, so these return the same T* / T& as the non-const ones.
  auto* operator->() const {
    return get();
  }

  auto& operator*() const {
    return *get();
  }

  auto* get() const {
    return impl_.get<T>();
  }

  // Takes ownership of ptr and hands it to the underlying
  // ThreadLocalPtrImpl::reset. Calling reset() with no argument passes
  // nullptr.
  void reset(std::unique_ptr<T> ptr = nullptr) {
    impl_.reset<T>(ptr.release());
  }

 private:
  // mutable: ThreadLocalPtrImpl::get<T>() is a non-const member function,
  // and the const accessors above must be able to call it. Without this,
  // they fail to compile as soon as they are instantiated on a const
  // ThreadLocalPtr.
  mutable ThreadLocalPtrImpl impl_;
};

} // namespace caffe2