#ifndef CAFFE2_OPERATORS_PREFETCH_OP_H_
#define CAFFE2_OPERATORS_PREFETCH_OP_H_
#include <atomic>
#include <condition_variable>
#include <exception>
#include <memory>
#include <mutex>
#include <thread> // NOLINT

#include "caffe2/core/context.h"
#include "caffe2/core/operator.h"
namespace caffe2 {
// PrefetchOperator is an operator that prefetches the next batch. It should
// almost always be used to read things from disk, so I am setting the input to
// zero blobs.
//
// For any operator that is derived from PrefetchOperator, it should
// explicitly call the Finalize() function in its destructor, so that the
// prefetching thread is properly destructed.
// Note: We inherit from OperatorBase since we control the
// synchronization properties of this operator ourselves (we inform
// the waiting producer after we synchronize). This is a special-case
// - you should generally inherit from Operator<Context> directly.
template <class Context>
class PrefetchOperator : public OperatorBase {
public:
PrefetchOperator(const OperatorDef& operator_def, Workspace* ws)
: OperatorBase(operator_def, ws),
context_(operator_def.device_option()),
prefetched_(false),
prefetch_success_(true),
finalize_(false),
no_prefetch_(GetSingleArgument<bool>("no_prefetch", false)) {
context_.SwitchToDevice();
}
virtual ~PrefetchOperator() noexcept {
CHECK(finalize_ || !prefetch_thread_.get())
<< "YOU MADE A PROGRAMING ERROR: derived class of PrefetchOperator "
"should call Finalize() in its destructor so the prefetching "
"thread is joined. ";
}
void Finalize() {
if (prefetch_thread_.get()) {
{
std::unique_lock<std::mutex> lock(prefetch_access_mutex_);
while (!prefetched_)
consumer_.wait(lock);
finalize_ = true;
prefetched_ = false;
}
producer_.notify_one();
prefetch_thread_->join();
prefetch_thread_.reset();
} else {
// If we never initialized the prefetch thread, just set
// finalize anyway.
finalize_ = true;
}
}
bool Run(int /* unused */ /*stream_id*/) override {
if (no_prefetch_) {
context_.SwitchToDevice();
bool result = Prefetch() && CopyPrefetched();
context_.FinishDeviceComputation();
return result;
}
// Note(jiayq): We only start the prefetch_thread at the Run() function
// instead of in the constructor, because the prefetch_thread needs to start
// after all derived classes' constructors finish.
if (!prefetch_thread_) {
prefetch_thread_.reset(
new std::thread([this] { this->PrefetchWorker(); }));
}
context_.SwitchToDevice();
std::unique_lock<std::mutex> lock(prefetch_access_mutex_);
while (!prefetched_)
consumer_.wait(lock);
if (!prefetch_success_) {
LOG(ERROR) << "Prefetching failed.";
return false;
}
if (!CopyPrefetched()) {
LOG(ERROR) << "Error when copying prefetched data.";
return false;
}
prefetched_ = false;
context_.FinishDeviceComputation();
producer_.notify_one();
return true;
}
void PrefetchWorker() {
context_.SwitchToDevice();
std::unique_lock<std::mutex> lock(prefetch_access_mutex_);
while (prefetched_)
producer_.wait(lock);
while (!finalize_) {
// We will need to run a FinishDeviceComputation() call because the
// prefetcher thread and the main thread are potentially using different
// streams (like on GPU).
try {
prefetch_success_ = Prefetch();
context_.FinishDeviceComputation();
} catch (const std::exception& e) {
// TODO: propagate exception_ptr to the caller side
LOG(ERROR) << "Prefetching error " << e.what();
prefetch_success_ = false;
}
prefetched_ = true;
consumer_.notify_one();
while (prefetched_)
producer_.wait(lock);
}
}
// You will need to implement this instead of the Run function.
virtual bool Prefetch() = 0;
virtual bool CopyPrefetched() = 0;
protected:
Context context_;
std::mutex prefetch_access_mutex_;
std::condition_variable producer_, consumer_;
// prefetched_ is used to tell the operator that it is done.
std::atomic<bool> prefetched_;
// prefetch_success_ is used to see if prefetching failed or not.
std::atomic<bool> prefetch_success_;
// finalize_ is used to tell the prefetcher to quit.
std::atomic<bool> finalize_;
unique_ptr<std::thread> prefetch_thread_;
// Whether to do prefetching or run this as a normal operator
const bool no_prefetch_;
};
} // namespace caffe2
#endif // CAFFE2_OPERATORS_PREFETCH_OP_H_