Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

neilisaac / torch   python

Repository URL to install this package:

Version: 1.8.0 

/ include / caffe2 / operators / prefetch_op.h

#ifndef CAFFE2_OPERATORS_PREFETCH_OP_H_
#define CAFFE2_OPERATORS_PREFETCH_OP_H_

#include <atomic>
#include <condition_variable>
#include <exception>
#include <memory>
#include <mutex>
#include <thread> // NOLINT

#include "caffe2/core/context.h"
#include "caffe2/core/operator.h"

namespace caffe2 {

// PrefetchOperator is an operator that prefetches the next batch in a
// background thread. It is almost always used to read data from disk, so it
// is designed to take zero input blobs.
//
// Any operator derived from PrefetchOperator must explicitly call Finalize()
// in its own destructor so that the prefetching thread is stopped and joined
// before the derived object goes away.

// Note: We inherit from OperatorBase because we control the
// synchronization properties of this operator ourselves (we notify
// the waiting producer after we synchronize). This is a special case;
// you should generally inherit from Operator<Context> directly.
template <class Context>
class PrefetchOperator : public OperatorBase {
 public:
  PrefetchOperator(const OperatorDef& operator_def, Workspace* ws)
      : OperatorBase(operator_def, ws),
        context_(operator_def.device_option()),
        prefetched_(false),
        prefetch_success_(true),
        finalize_(false),
        no_prefetch_(GetSingleArgument<bool>("no_prefetch", false)) {
    context_.SwitchToDevice();
  }

  virtual ~PrefetchOperator() noexcept {
    CHECK(finalize_ || !prefetch_thread_.get())
        << "YOU MADE A PROGRAMING ERROR: derived class of PrefetchOperator "
           "should call Finalize() in its destructor so the prefetching "
           "thread is joined. ";
  }

  void Finalize() {
    if (prefetch_thread_.get()) {
      {
        std::unique_lock<std::mutex> lock(prefetch_access_mutex_);
        while (!prefetched_)
          consumer_.wait(lock);
        finalize_ = true;
        prefetched_ = false;
      }
      producer_.notify_one();
      prefetch_thread_->join();
      prefetch_thread_.reset();
    } else {
      // If we never initialized the prefetch thread, just set
      // finalize anyway.
      finalize_ = true;
    }
  }

  bool Run(int /* unused */ /*stream_id*/) override {
    if (no_prefetch_) {
      context_.SwitchToDevice();
      bool result = Prefetch() && CopyPrefetched();
      context_.FinishDeviceComputation();
      return result;
    }
    // Note(jiayq): We only start the prefetch_thread at the Run() function
    // instead of in the constructor, because the prefetch_thread needs to start
    // after all derived classes' constructors finish.
    if (!prefetch_thread_) {
      prefetch_thread_.reset(
          new std::thread([this] { this->PrefetchWorker(); }));
    }
    context_.SwitchToDevice();
    std::unique_lock<std::mutex> lock(prefetch_access_mutex_);
    while (!prefetched_)
      consumer_.wait(lock);
    if (!prefetch_success_) {
      LOG(ERROR) << "Prefetching failed.";
      return false;
    }
    if (!CopyPrefetched()) {
      LOG(ERROR) << "Error when copying prefetched data.";
      return false;
    }
    prefetched_ = false;
    context_.FinishDeviceComputation();
    producer_.notify_one();
    return true;
  }

  void PrefetchWorker() {
    context_.SwitchToDevice();
    std::unique_lock<std::mutex> lock(prefetch_access_mutex_);
    while (prefetched_)
      producer_.wait(lock);
    while (!finalize_) {
      // We will need to run a FinishDeviceComputation() call because the
      // prefetcher thread and the main thread are potentially using different
      // streams (like on GPU).
      try {
        prefetch_success_ = Prefetch();
        context_.FinishDeviceComputation();
      } catch (const std::exception& e) {
        // TODO: propagate exception_ptr to the caller side
        LOG(ERROR) << "Prefetching error " << e.what();
        prefetch_success_ = false;
      }
      prefetched_ = true;
      consumer_.notify_one();
      while (prefetched_)
        producer_.wait(lock);
    }
  }

  // You will need to implement this instead of the Run function.
  virtual bool Prefetch() = 0;
  virtual bool CopyPrefetched() = 0;

 protected:
  Context context_;
  std::mutex prefetch_access_mutex_;
  std::condition_variable producer_, consumer_;
  // prefetched_ is used to tell the operator that it is done.
  std::atomic<bool> prefetched_;
  // prefetch_success_ is used to see if prefetching failed or not.
  std::atomic<bool> prefetch_success_;
  // finalize_ is used to tell the prefetcher to quit.
  std::atomic<bool> finalize_;
  unique_ptr<std::thread> prefetch_thread_;

  // Whether to do prefetching or run this as a normal operator
  const bool no_prefetch_;
};

} // namespace caffe2

#endif // CAFFE2_OPERATORS_PREFETCH_OP_H_