Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

neilisaac / torch   python

Repository URL to install this package:

Version: 1.8.0 

/ include / caffe2 / core / event.h

#ifndef CAFFE2_CORE_EVENT_H_
#define CAFFE2_CORE_EVENT_H_

#include <chrono>

#include <c10/core/DeviceType.h>
#include "caffe2/core/common.h"
#include "caffe2/core/logging.h"
#include "caffe2/proto/caffe2_pb.h"

namespace caffe2 {

// Number of device-type slots. It sizes Event's static per-device function
// tables and is the upper bound checked against an event's device type.
constexpr int MaxDeviceTypes =
    static_cast<int>(DeviceTypeProto::PROTO_COMPILE_TIME_MAX_DEVICE_TYPES);
class Event;

// Status reported by Event::Query(). The numeric values are 0..3 in
// declaration order.
enum EventStatus {
  // 0
  EVENT_INITIALIZED,
  // 1: the state Event::IsScheduled() tests for.
  EVENT_SCHEDULED,
  // 2: counted as finished by Event::IsFinished().
  EVENT_SUCCESS,
  // 3: also counted as finished by Event::IsFinished().
  EVENT_FAILED,
};

// For the following functions, void* shall be interpreted as the corresponding
// context object corresponding to the device type associated with the
// functions.

// Initializes event
typedef void (*EventCreateFunction)(const DeviceOption& option, Event*);

// Called on event to signal that CPU part of operation is finished,
// Optionally accepts error message from CPU part.
// Should be called no more than once per event
typedef void (*EventRecordFunction)(Event*, const void*, const char*);

// Waits and returns as soon as possible in order schedule next operation,
// e.g. for CUDA->CUDA waits only for CPU part of CUDA op,
// for CUDA->CPU waits till the CUDA op is fully completed.
// Prepares context to synchronize device part of operation.
// Can be called concurrently from multiple threads
typedef void (*EventWaitFunction)(const Event*, void*);

// Waits till operation is fully finished,
// can be called concurrently from multiple threads
typedef void (*EventFinishFunction)(const Event*);

// Queries current status of operation,
// can be called concurrently from multiple threads
typedef EventStatus (*EventQueryFunction)(const Event*);
typedef const std::string& (*EventErrorMessageFunction)(const Event*);
typedef void (*EventSetFinishedFunction)(const Event*, const char*);
typedef void (*EventResetFunction)(Event*);

// Sets callback that is called when event is finished
typedef std::function<void()> EventCallbackFunction;
typedef void (*EventSetCallbackFunction)(Event*, EventCallbackFunction);

class TORCH_API Event {
 public:
  explicit Event(const DeviceOption& option)
      : event_(), type_(option.device_type()), option_(option) {
    CAFFE_ENFORCE_LT(type_, MaxDeviceTypes);
    CAFFE_ENFORCE(event_creator_[type_]);
    event_creator_[type_](option, this);
  }

  // Nothing needs to be done in the destructor, as the event creator should
  // set the proper destruction process for the unique_ptr.
  ~Event() {}

  void Record(
      DeviceType recorder_type,
      const void* context,
      const char* err_msg = nullptr) {
    auto recorder_index = TypeToProto(recorder_type);
    CAFFE_ENFORCE_EQ(
        recorder_index,
        type_,
        "You are trying to record with a wrong device type.");
    CAFFE_ENFORCE(event_recorder_[recorder_index]);
    event_recorder_[recorder_index](this, context, err_msg);
  }

  void Wait(DeviceType waiter_type, void* context) const {
    auto waiter_index = TypeToProto(waiter_type);
    CAFFE_ENFORCE(event_waiter_[waiter_index][type_]);
    event_waiter_[waiter_index][type_](this, context);
  }

  void Finish() const {
    CAFFE_ENFORCE(event_finisher_[type_]);
    event_finisher_[type_](this);
  }

  EventStatus Query() const {
    CAFFE_ENFORCE(event_querier_[type_]);
    return event_querier_[type_](this);
  }

  const std::string& ErrorMessage() const {
    CAFFE_ENFORCE(event_err_msg_getter_[type_]);
    return event_err_msg_getter_[type_](this);
  }

  void Reset() {
    CAFFE_ENFORCE(event_resetter_[type_]);
    event_resetter_[type_](this);
#ifdef CAFFE2_USE_EXCEPTION_PTR
    caught_exception_ = nullptr;
#endif // CAFFE2_USE_EXCEPTION_PTR
    error_timestamp_ = 0;
  }

  const DeviceOption& GetDeviceOption() const {
    return option_;
  }

  bool IsScheduled() const {
    return Query() == EventStatus::EVENT_SCHEDULED;
  }

  bool IsFinished() const {
    auto status = Query();
    return status == EventStatus::EVENT_SUCCESS ||
        status == EventStatus::EVENT_FAILED;
  }

  void SetFinished(const char* err_msg = nullptr) {
    typedef std::chrono::high_resolution_clock clock;
    error_timestamp_ = std::chrono::duration_cast<std::chrono::nanoseconds>(
                           clock::now().time_since_epoch())
                           .count();

    CAFFE_ENFORCE(event_finished_setter_[type_]);
    return event_finished_setter_[type_](this, err_msg);
  }

  bool SupportsCallback() const {
    return event_callback_setter_[type_] != nullptr;
  }

  void SetCallback(EventCallbackFunction callback) {
    CAFFE_ENFORCE(
        event_callback_setter_[type_], "Event does not support callbacks");
    event_callback_setter_[type_](this, callback);
  }

  // If parent op has succeeded, then we can run any child op;
  // If parent op is in scheduled state, we need to check that:
  //  - child op supports async scheduling
  //  - there's a way to setup synchronization between async parent and
  //    child - both child and parent should use the same type of device,
  //    non-blocking synchronization between different device types is not
  //    supported
  // If parent op is in another state (initialized or failed) then scheduling
  // is not possible
  bool CanSchedule(const Event& child_event, bool supports_async) const {
    return CanSchedule(type_, Query(), child_event.GetType(), supports_async);
  }

  static bool CanSchedule(
      int parent_type,
      EventStatus parent_status,
      int child_type,
      bool child_supports_async) {
    if (parent_status == EventStatus::EVENT_SUCCESS) {
      return true;
    }
    if (parent_status == EventStatus::EVENT_SCHEDULED) {
      return (parent_type == child_type) && child_supports_async;
    }
    return false;
  }

  int GetType() const {
    return type_;
  }

  void SetFinishedWithException(const char* err_msg = nullptr) {
#ifdef CAFFE2_USE_EXCEPTION_PTR
    if (!caught_exception_) {
      caught_exception_ = std::current_exception();
    }
    CAFFE_ENFORCE(caught_exception_, "No exception found");
#else
    VLOG(1) << "No support for exceptions in Event";
#endif // CAFFE2_USE_EXCEPTION_PTR
    if (err_msg) {
      SetFinished(err_msg);
    } else {
      SetFinished("Error happened during an operator run");
    }
  }

  bool HasException() const {
#ifdef CAFFE2_USE_EXCEPTION_PTR
    return (bool)caught_exception_;
#else
    VLOG(1) << "No support for exceptions in Event";
    return false;
#endif // CAFFE2_USE_EXCEPTION_PTR
  }

  int64_t ErrorTimestamp() const {
    return error_timestamp_;
  }

  void RethrowException() const {
#ifdef CAFFE2_USE_EXCEPTION_PTR
    if (caught_exception_) {
      std::rethrow_exception(caught_exception_);
    }
#else
    VLOG(1) << "No support for exceptions in Event";
#endif // CAFFE2_USE_EXCEPTION_PTR
  }

  // event_ is going to be accessed by the EventCreate/Record/Wait/Finish
  // functions, but one should not use it outside the own Event functionalities.
  // In the future we may move it to a private member.
  std::shared_ptr<void> event_;

 private:
  int type_;
  DeviceOption option_;

#ifdef CAFFE2_USE_EXCEPTION_PTR
  std::exception_ptr caught_exception_;
#endif // CAFFE2_USE_EXCEPTION_PTR
  int64_t error_timestamp_{};

  static EventCreateFunction event_creator_[MaxDeviceTypes];
  static EventRecordFunction event_recorder_[MaxDeviceTypes];
  static EventWaitFunction event_waiter_[MaxDeviceTypes][MaxDeviceTypes];
  static EventFinishFunction event_finisher_[MaxDeviceTypes];

  static EventQueryFunction event_querier_[MaxDeviceTypes];
  static EventErrorMessageFunction event_err_msg_getter_[MaxDeviceTypes];
  static EventSetFinishedFunction event_finished_setter_[MaxDeviceTypes];
  static EventResetFunction event_resetter_[MaxDeviceTypes];

  static EventSetCallbackFunction event_callback_setter_[MaxDeviceTypes];

  template <DeviceType t>
  friend struct EventCreateFunctionRegisterer;
  template <DeviceType t>
  friend struct EventRecordFunctionRegisterer;
  template <DeviceType w, DeviceType d>
  friend struct EventWaitFunctionRegisterer;
  template <DeviceType t>
  friend struct EventFinishFunctionRegisterer;

  template <DeviceType t>
  friend struct EventQueryFunctionRegisterer;
  template <DeviceType t>
  friend struct EventErrorMessageFunctionRegisterer;
  template <DeviceType t>
  friend struct EventSetFinishedFunctionRegisterer;
  template <DeviceType t>
  friend struct EventSetCallbackFunctionRegisterer;
  template <DeviceType t>
  friend struct EventResetFunctionRegisterer;
};

template <DeviceType t>
struct EventCreateFunctionRegisterer {
  explicit EventCreateFunctionRegisterer(EventCreateFunction f) {
    auto d = TypeToProto(t);
    Event::event_creator_[d] = f;
  }
};
// Registers `f` as the event-create function for device type `t` by defining
// an EventCreateFunctionRegisterer<t> object in an unnamed namespace.
// `t` is pasted into the object's name so that registrations for different
// device types in one translation unit get distinct names; it therefore has to
// be a single identifier token, as REGISTER_EVENT_WAIT_FUNCTION already
// requires of its device-type arguments.
#define REGISTER_EVENT_CREATE_FUNCTION(t, f)                     \
  namespace {                                                    \
  static EventCreateFunctionRegisterer<t> g_event_create_##t(f); \
  }

template <DeviceType t>
struct EventRecordFunctionRegisterer {
  explicit EventRecordFunctionRegisterer(EventRecordFunction f) {
    auto d = TypeToProto(t);
    Event::event_recorder_[d] = f;
  }
};
// Registers `f` as the event-record function for device type `t` by defining
// an EventRecordFunctionRegisterer<t> object in an unnamed namespace.
// `t` is pasted into the object's name so that registrations for different
// device types in one translation unit get distinct names; it therefore has to
// be a single identifier token, as REGISTER_EVENT_WAIT_FUNCTION already
// requires of its device-type arguments.
#define REGISTER_EVENT_RECORD_FUNCTION(t, f)                     \
  namespace {                                                    \
  static EventRecordFunctionRegisterer<t> g_event_record_##t(f); \
  }

template <DeviceType waiter_type, DeviceType event_type>
struct EventWaitFunctionRegisterer {
  explicit EventWaitFunctionRegisterer(EventWaitFunction f) {
    auto waiter_index = TypeToProto(waiter_type);
    auto event_index = TypeToProto(event_type);
    Event::event_waiter_[waiter_index][event_index] = f;
  }
};
// Registers `f` as the function used when device type `w` waits on an event of
// device type `d`. Both `w` and `d` are pasted into the name of the registerer
// object, so each must be a single identifier token. The object lives in an
// unnamed namespace and is therefore local to the translation unit.
#define REGISTER_EVENT_WAIT_FUNCTION(w, d, f)                  \
  namespace {                                                  \
  EventWaitFunctionRegisterer<w, d> g_event_wait_##w##_##d(f); \
  }

template <DeviceType t>
struct EventQueryFunctionRegisterer {
  explicit EventQueryFunctionRegisterer(EventQueryFunction f) {
    auto d = TypeToProto(t);
    Event::event_querier_[d] = f;
  }
};
// Registers `f` as the event-query function for device type `t` by defining
// an EventQueryFunctionRegisterer<t> object in an unnamed namespace.
// `t` is pasted into the object's name so that registrations for different
// device types in one translation unit get distinct names; it therefore has to
// be a single identifier token, as REGISTER_EVENT_WAIT_FUNCTION already
// requires of its device-type arguments.
#define REGISTER_EVENT_QUERY_FUNCTION(t, f)                    \
  namespace {                                                  \
  static EventQueryFunctionRegisterer<t> g_event_query_##t(f); \
  }

template <DeviceType t>
struct EventErrorMessageFunctionRegisterer {
  explicit EventErrorMessageFunctionRegisterer(EventErrorMessageFunction f) {
    auto d = TypeToProto(t);
    Event::event_err_msg_getter_[d] = f;
  }
};
// Registers `f` as the error-message getter for device type `t` by defining
// an EventErrorMessageFunctionRegisterer<t> object in an unnamed namespace.
// `t` is pasted into the object's name so that registrations for different
// device types in one translation unit get distinct names; it therefore has to
// be a single identifier token, as REGISTER_EVENT_WAIT_FUNCTION already
// requires of its device-type arguments.
#define REGISTER_EVENT_ERROR_MESSAGE_FUNCTION(t, f)                     \
  namespace {                                                           \
  static EventErrorMessageFunctionRegisterer<t> g_event_err_msg_##t(f); \
  }

template <DeviceType t>
struct EventSetFinishedFunctionRegisterer {
  explicit EventSetFinishedFunctionRegisterer(EventSetFinishedFunction f) {
    auto d = TypeToProto(t);
    Event::event_finished_setter_[d] = f;
  }
};
// Registers `f` as the set-finished function for device type `t` by defining
// an EventSetFinishedFunctionRegisterer<t> object in an unnamed namespace.
// `t` is pasted into the object's name so that registrations for different
// device types in one translation unit get distinct names; it therefore has to
// be a single identifier token, as REGISTER_EVENT_WAIT_FUNCTION already
// requires of its device-type arguments.
#define REGISTER_EVENT_SET_FINISHED_FUNCTION(t, f)                          \
  namespace {                                                               \
  static EventSetFinishedFunctionRegisterer<t> g_event_set_finished_##t(f); \
  }

template <DeviceType t>
struct EventSetCallbackFunctionRegisterer {
  explicit EventSetCallbackFunctionRegisterer(EventSetCallbackFunction f) {
    auto d = TypeToProto(t);
    Event::event_callback_setter_[d] = f;
  }
};
#define REGISTER_EVENT_SET_CALLBACK_FUNCTION(t, f)                          \
Loading ...