#ifndef CAFFE2_CORE_EVENT_H_
#define CAFFE2_CORE_EVENT_H_
#include <chrono>
#include <c10/core/DeviceType.h>
#include "caffe2/core/common.h"
#include "caffe2/core/logging.h"
#include "caffe2/proto/caffe2_pb.h"
namespace caffe2 {
constexpr int MaxDeviceTypes =
DeviceTypeProto::PROTO_COMPILE_TIME_MAX_DEVICE_TYPES;
class Event;
enum EventStatus {
EVENT_INITIALIZED = 0,
EVENT_SCHEDULED = 1,
EVENT_SUCCESS = 2,
EVENT_FAILED = 3,
};
// For the following functions, void* shall be interpreted as the corresponding
// context object corresponding to the device type associated with the
// functions.
// Initializes event
typedef void (*EventCreateFunction)(const DeviceOption& option, Event*);
// Called on event to signal that CPU part of operation is finished,
// Optionally accepts error message from CPU part.
// Should be called no more than once per event
typedef void (*EventRecordFunction)(Event*, const void*, const char*);
// Waits and returns as soon as possible in order schedule next operation,
// e.g. for CUDA->CUDA waits only for CPU part of CUDA op,
// for CUDA->CPU waits till the CUDA op is fully completed.
// Prepares context to synchronize device part of operation.
// Can be called concurrently from multiple threads
typedef void (*EventWaitFunction)(const Event*, void*);
// Waits till operation is fully finished,
// can be called concurrently from multiple threads
typedef void (*EventFinishFunction)(const Event*);
// Queries current status of operation,
// can be called concurrently from multiple threads
typedef EventStatus (*EventQueryFunction)(const Event*);
typedef const std::string& (*EventErrorMessageFunction)(const Event*);
typedef void (*EventSetFinishedFunction)(const Event*, const char*);
typedef void (*EventResetFunction)(Event*);
// Sets callback that is called when event is finished
typedef std::function<void()> EventCallbackFunction;
typedef void (*EventSetCallbackFunction)(Event*, EventCallbackFunction);
class TORCH_API Event {
public:
explicit Event(const DeviceOption& option)
: event_(), type_(option.device_type()), option_(option) {
CAFFE_ENFORCE_LT(type_, MaxDeviceTypes);
CAFFE_ENFORCE(event_creator_[type_]);
event_creator_[type_](option, this);
}
// Nothing needs to be done in the destructor, as the event creator should
// set the proper destruction process for the unique_ptr.
~Event() {}
void Record(
DeviceType recorder_type,
const void* context,
const char* err_msg = nullptr) {
auto recorder_index = TypeToProto(recorder_type);
CAFFE_ENFORCE_EQ(
recorder_index,
type_,
"You are trying to record with a wrong device type.");
CAFFE_ENFORCE(event_recorder_[recorder_index]);
event_recorder_[recorder_index](this, context, err_msg);
}
void Wait(DeviceType waiter_type, void* context) const {
auto waiter_index = TypeToProto(waiter_type);
CAFFE_ENFORCE(event_waiter_[waiter_index][type_]);
event_waiter_[waiter_index][type_](this, context);
}
void Finish() const {
CAFFE_ENFORCE(event_finisher_[type_]);
event_finisher_[type_](this);
}
EventStatus Query() const {
CAFFE_ENFORCE(event_querier_[type_]);
return event_querier_[type_](this);
}
const std::string& ErrorMessage() const {
CAFFE_ENFORCE(event_err_msg_getter_[type_]);
return event_err_msg_getter_[type_](this);
}
void Reset() {
CAFFE_ENFORCE(event_resetter_[type_]);
event_resetter_[type_](this);
#ifdef CAFFE2_USE_EXCEPTION_PTR
caught_exception_ = nullptr;
#endif // CAFFE2_USE_EXCEPTION_PTR
error_timestamp_ = 0;
}
const DeviceOption& GetDeviceOption() const {
return option_;
}
bool IsScheduled() const {
return Query() == EventStatus::EVENT_SCHEDULED;
}
bool IsFinished() const {
auto status = Query();
return status == EventStatus::EVENT_SUCCESS ||
status == EventStatus::EVENT_FAILED;
}
void SetFinished(const char* err_msg = nullptr) {
typedef std::chrono::high_resolution_clock clock;
error_timestamp_ = std::chrono::duration_cast<std::chrono::nanoseconds>(
clock::now().time_since_epoch())
.count();
CAFFE_ENFORCE(event_finished_setter_[type_]);
return event_finished_setter_[type_](this, err_msg);
}
bool SupportsCallback() const {
return event_callback_setter_[type_] != nullptr;
}
void SetCallback(EventCallbackFunction callback) {
CAFFE_ENFORCE(
event_callback_setter_[type_], "Event does not support callbacks");
event_callback_setter_[type_](this, callback);
}
// If parent op has succeeded, then we can run any child op;
// If parent op is in scheduled state, we need to check that:
// - child op supports async scheduling
// - there's a way to setup synchronization between async parent and
// child - both child and parent should use the same type of device,
// non-blocking synchronization between different device types is not
// supported
// If parent op is in another state (initialized or failed) then scheduling
// is not possible
bool CanSchedule(const Event& child_event, bool supports_async) const {
return CanSchedule(type_, Query(), child_event.GetType(), supports_async);
}
static bool CanSchedule(
int parent_type,
EventStatus parent_status,
int child_type,
bool child_supports_async) {
if (parent_status == EventStatus::EVENT_SUCCESS) {
return true;
}
if (parent_status == EventStatus::EVENT_SCHEDULED) {
return (parent_type == child_type) && child_supports_async;
}
return false;
}
int GetType() const {
return type_;
}
void SetFinishedWithException(const char* err_msg = nullptr) {
#ifdef CAFFE2_USE_EXCEPTION_PTR
if (!caught_exception_) {
caught_exception_ = std::current_exception();
}
CAFFE_ENFORCE(caught_exception_, "No exception found");
#else
VLOG(1) << "No support for exceptions in Event";
#endif // CAFFE2_USE_EXCEPTION_PTR
if (err_msg) {
SetFinished(err_msg);
} else {
SetFinished("Error happened during an operator run");
}
}
bool HasException() const {
#ifdef CAFFE2_USE_EXCEPTION_PTR
return (bool)caught_exception_;
#else
VLOG(1) << "No support for exceptions in Event";
return false;
#endif // CAFFE2_USE_EXCEPTION_PTR
}
int64_t ErrorTimestamp() const {
return error_timestamp_;
}
void RethrowException() const {
#ifdef CAFFE2_USE_EXCEPTION_PTR
if (caught_exception_) {
std::rethrow_exception(caught_exception_);
}
#else
VLOG(1) << "No support for exceptions in Event";
#endif // CAFFE2_USE_EXCEPTION_PTR
}
// event_ is going to be accessed by the EventCreate/Record/Wait/Finish
// functions, but one should not use it outside the own Event functionalities.
// In the future we may move it to a private member.
std::shared_ptr<void> event_;
private:
int type_;
DeviceOption option_;
#ifdef CAFFE2_USE_EXCEPTION_PTR
std::exception_ptr caught_exception_;
#endif // CAFFE2_USE_EXCEPTION_PTR
int64_t error_timestamp_{};
static EventCreateFunction event_creator_[MaxDeviceTypes];
static EventRecordFunction event_recorder_[MaxDeviceTypes];
static EventWaitFunction event_waiter_[MaxDeviceTypes][MaxDeviceTypes];
static EventFinishFunction event_finisher_[MaxDeviceTypes];
static EventQueryFunction event_querier_[MaxDeviceTypes];
static EventErrorMessageFunction event_err_msg_getter_[MaxDeviceTypes];
static EventSetFinishedFunction event_finished_setter_[MaxDeviceTypes];
static EventResetFunction event_resetter_[MaxDeviceTypes];
static EventSetCallbackFunction event_callback_setter_[MaxDeviceTypes];
template <DeviceType t>
friend struct EventCreateFunctionRegisterer;
template <DeviceType t>
friend struct EventRecordFunctionRegisterer;
template <DeviceType w, DeviceType d>
friend struct EventWaitFunctionRegisterer;
template <DeviceType t>
friend struct EventFinishFunctionRegisterer;
template <DeviceType t>
friend struct EventQueryFunctionRegisterer;
template <DeviceType t>
friend struct EventErrorMessageFunctionRegisterer;
template <DeviceType t>
friend struct EventSetFinishedFunctionRegisterer;
template <DeviceType t>
friend struct EventSetCallbackFunctionRegisterer;
template <DeviceType t>
friend struct EventResetFunctionRegisterer;
};
template <DeviceType t>
struct EventCreateFunctionRegisterer {
explicit EventCreateFunctionRegisterer(EventCreateFunction f) {
auto d = TypeToProto(t);
Event::event_creator_[d] = f;
}
};
#define REGISTER_EVENT_CREATE_FUNCTION(t, f) \
namespace { \
static EventCreateFunctionRegisterer<t> g_event_create_##d(f); \
}
template <DeviceType t>
struct EventRecordFunctionRegisterer {
explicit EventRecordFunctionRegisterer(EventRecordFunction f) {
auto d = TypeToProto(t);
Event::event_recorder_[d] = f;
}
};
#define REGISTER_EVENT_RECORD_FUNCTION(t, f) \
namespace { \
static EventRecordFunctionRegisterer<t> g_event_record_##d(f); \
}
template <DeviceType waiter_type, DeviceType event_type>
struct EventWaitFunctionRegisterer {
explicit EventWaitFunctionRegisterer(EventWaitFunction f) {
auto waiter_index = TypeToProto(waiter_type);
auto event_index = TypeToProto(event_type);
Event::event_waiter_[waiter_index][event_index] = f;
}
};
#define REGISTER_EVENT_WAIT_FUNCTION(w, d, f) \
namespace { \
static EventWaitFunctionRegisterer<w, d> g_event_wait_##w##_##d(f); \
}
template <DeviceType t>
struct EventQueryFunctionRegisterer {
explicit EventQueryFunctionRegisterer(EventQueryFunction f) {
auto d = TypeToProto(t);
Event::event_querier_[d] = f;
}
};
#define REGISTER_EVENT_QUERY_FUNCTION(t, f) \
namespace { \
static EventQueryFunctionRegisterer<t> g_event_query_##d(f); \
}
template <DeviceType t>
struct EventErrorMessageFunctionRegisterer {
explicit EventErrorMessageFunctionRegisterer(EventErrorMessageFunction f) {
auto d = TypeToProto(t);
Event::event_err_msg_getter_[d] = f;
}
};
#define REGISTER_EVENT_ERROR_MESSAGE_FUNCTION(t, f) \
namespace { \
static EventErrorMessageFunctionRegisterer<t> g_event_err_msg_##d(f); \
}
template <DeviceType t>
struct EventSetFinishedFunctionRegisterer {
explicit EventSetFinishedFunctionRegisterer(EventSetFinishedFunction f) {
auto d = TypeToProto(t);
Event::event_finished_setter_[d] = f;
}
};
#define REGISTER_EVENT_SET_FINISHED_FUNCTION(t, f) \
namespace { \
static EventSetFinishedFunctionRegisterer<t> g_event_set_finished_##d(f); \
}
template <DeviceType t>
struct EventSetCallbackFunctionRegisterer {
explicit EventSetCallbackFunctionRegisterer(EventSetCallbackFunction f) {
auto d = TypeToProto(t);
Event::event_callback_setter_[d] = f;
}
};
#define REGISTER_EVENT_SET_CALLBACK_FUNCTION(t, f) \
Loading ...