Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

neilisaac / torch   python

Repository URL to install this package:

/ include / ATen / record_function.h

#pragma once

#include <ATen/core/ivalue.h>
#include <ATen/core/operator_name.h>
#include <c10/macros/Export.h>
#include <c10/util/Optional.h>
#include <c10/util/SmallVector.h>
#include <memory>

#include <functional>

// Forward declaration only: this header uses OperatorHandle by reference
// (e.g. RecordFunction::before(c10::OperatorHandle const&, ...)), where a
// complete type is not required.
namespace c10 {
class TORCH_API OperatorHandle;
} // namespace c10

namespace at {

// Category of code region a RecordFunction observes. Enumerators are cast to
// size_t and used as indices (see RecordFunctionCallback::scopes), so keep
// them contiguous and starting at zero.
enum class C10_API_ENUM RecordScope : uint8_t {
  FUNCTION = 0,          // c10/ATen ops, autograd nodes
  BACKWARD_FUNCTION,     // functions/nodes called from the autograd
  TORCHSCRIPT_FUNCTION,  // TorchScript functions, methods
  KERNEL_FUNCTION_DTYPE, // kernel function dtype tag
  USER_SCOPE,            // user defined scope (e.g. with record_function())
  NUM_SCOPES,            // number of scopes; must be the last in the list
};

} // namespace at

// Lets at::RecordScope key unordered containers, such as the
// std::unordered_set<RecordScope, std::hash<RecordScope>> accepted by
// RecordFunctionCallback::scopes(). The hash is the enumerator's value.
namespace std {
template <>
struct hash<at::RecordScope> {
  size_t operator()(const at::RecordScope& scope) const {
    return static_cast<std::size_t>(scope);
  }
};
} // namespace std

namespace at {

struct TORCH_API StringView {
  StringView() : StringView(nullptr) {}
  explicit StringView(const char* str_ptr)
    : owned_str_ptr_(nullptr), str_ptr_(str_ptr) {}
  explicit StringView(std::string str)
    : owned_str_ptr_(std::make_shared<std::string>(std::move(str))),
      str_ptr_(owned_str_ptr_->c_str()) {}

  inline const char* str() const {
    return str_ptr_;
  }

  friend std::ostream& operator<<(std::ostream& os, const StringView& dt) {
    os << dt.str();
    return os;
  }

  friend bool operator==(const StringView& lhs, const StringView& rhs) {
    return strcmp(lhs.str(), rhs.str()) == 0;
  }

  friend bool operator!=(const StringView& lhs, const StringView& rhs) {
    return !(lhs == rhs);
  }

 private:
  std::shared_ptr<std::string> owned_str_ptr_;
  const char* str_ptr_;
};

// Soft (not enforced here) limit on the number of callbacks in use. It is
// also the inline-capacity parameter of the CallbackHandles SmallVector.
constexpr std::size_t kSoftLimitCallbacks{4};

// Base class for per-callback observer state attached to a RecordFunction:
// a start callback may return a subclass instance carrying event metadata,
// and it is handed back to the matching end callback.
// The constructor is protected, so instances are created only through
// derived classes. The virtual destructor makes deletion through a base
// pointer (e.g. std::unique_ptr<ObserverContext>) safe.
struct ObserverContext {
  virtual ~ObserverContext() = default;

 protected:
  ObserverContext() = default;
};

// Handles of the callbacks picked to run for one RecordFunction; the inline
// capacity is kSoftLimitCallbacks.
using CallbackHandles = c10::SmallVector<uint64_t, kSoftLimitCallbacks>;
// Observer contexts returned by start callbacks, held by RecordFunction::State.
using ObserverContextList = std::vector<std::unique_ptr<ObserverContext>>;
// Identifier type for a RecordFunction (see RecordFunction::handle()).
using RecordFunctionHandle = uint64_t;

// Tracks one observed range (e.g. an op, autograd node, TorchScript function
// or user scope; see RecordScope). It is the object passed to the
// RecordFunctionCallback start/end callbacks.
//
// All per-invocation data lives in `state_`. A RecordFunction is "active"
// (isActive()) only while `state_` is non-null. Most accessors below check
// this only with TORCH_INTERNAL_ASSERT_DEBUG_ONLY and then dereference
// `state_`, so call isActive() before using them.
//
// Neither copyable nor movable: the copy operations are deleted and no move
// operations are declared.
struct TORCH_API RecordFunction {
  // Default constructor is used with a before() function called afterwards:
  //  scope - record scope that this function tracks
  //  pre_sampled - whether this RecordFunction was already pre-sampled with
  //    kLowProb probability
  RecordFunction(
      RecordScope scope = RecordScope::FUNCTION,
      bool pre_sampled = false);

  // Copies the inputs pointed to by `args` into the state, then calls
  // before(fn, current_sequence_nr). Does nothing when inactive.
  // NOTE: `args` is dereferenced without a null check, so it must be
  // non-null whenever this RecordFunction is active.
  template <typename F>
  void before(
      F fn,
      const std::vector<c10::IValue>* args,
      int64_t current_sequence_nr = -1) {
    if (!isActive()) {
      return;
    }
    state_->inputs_ = *args;
    before(fn, current_sequence_nr);
  }

  // Destructor calls end callbacks
  virtual ~RecordFunction();

  RecordFunction(const RecordFunction&) = delete;
  RecordFunction& operator=(const RecordFunction&) = delete;

  // Name of the observed range. Requires an active RecordFunction.
  inline const StringView& name() const {
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(state_, "Called name() on inactive RecordFunction");
    return state_->name_;
  }

  // Sequence number stored for this range (State initializes it to -1).
  inline int64_t seqNr() const {
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(state_, "Called seqNr() on inactive RecordFunction");
    return state_->sequence_nr_;
  }

  // Inputs recorded through the before(fn, args, ...) overloads; empty by
  // default.
  const std::vector<c10::IValue>& inputs() const {
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(state_, "Called inputs() on inactive RecordFunction");
    return state_->inputs_;
  }

  // Retrieves the thread_id that this RecordFunction ran start callbacks with.
  // Useful for writing thread safe end callbacks that may be potentially
  // executed in a different thread (async ops)
  inline uint64_t threadId() const {
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(state_, "Called threadId() on inactive RecordFunction");
    return state_->thread_id_;
  }

  // For backward functions - thread id of the corresponding forward function,
  // or zero otherwise;
  // used together with the sequence number to correlate backward functions
  // with the forward ones
  inline uint64_t forwardThreadId() const {
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(state_, "Called forwardThreadId() on inactive RecordFunction");
    return state_->fwd_thread_id_;
  }

  // Sets the value later returned by forwardThreadId().
  inline void setForwardThreadId(uint64_t thread_id) {
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(state_, "Called setForwardThreadId() on inactive RecordFunction");
    state_->fwd_thread_id_ = thread_id;
  }

  // Kind of scope this RecordFunction observes (fixed when State is built).
  inline RecordScope scope() const {
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(state_, "Called scope() on inactive RecordFunction");
    return state_->scope_;
  }

  // Returns logical thread_id for the current thread
  static uint64_t currentThreadId();

  // Internal functions, do not use directly;
  // used in python's context manager

  // before functions initialize RecordFunction members and call
  // start callbacks
  void before(const char* name, int64_t sequence_nr = -1);
  void before(std::string name, int64_t sequence_nr = -1);
  void before(c10::OperatorHandle const& op, int64_t sequence_nr = -1);

  // Sets node ID for distributed profiling
  static void setDefaultNodeId(int64_t defaultNodeId);
  // Gets node ID for distributed profiling
  static int64_t getDefaultNodeId();

  // Same as the pointer overload, but copies the inputs out of an ArrayRef.
  // Does nothing when inactive.
  template<typename F>
  void before(
      F fn,
      c10::ArrayRef<c10::IValue> args,
      int64_t current_sequence_nr = -1) {
    if (!isActive()) {
      return;
    }
    state_->inputs_ = args.vec();
    before(fn, current_sequence_nr);
  }

  // Same as above, but moves the inputs vector into the state instead of
  // copying it. Does nothing (and leaves `args` untouched) when inactive.
  template<typename F>
  void before(
      F fn,
      std::vector<c10::IValue>&& args,
      int64_t current_sequence_nr = -1) {
    if (!isActive()) {
      return;
    }
    state_->inputs_ = std::move(args);
    before(fn, current_sequence_nr);
  }

  // Calls end callbacks. After end(), accessors will no longer provide useful results.
  void end();

  // Unique id of this range, used by callbacks to match start and end.
  inline RecordFunctionHandle handle() const {
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(state_, "Called handle() on inactive RecordFunction");
    return state_->handle_;
  }

  // Operator name if one was recorded, otherwise empty. Returned by value.
  inline c10::optional<OperatorName> operator_name() const {
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(state_, "Called operator_name() on inactive RecordFunction");
    return state_->operator_name_;
  }

  // Overwrites the id returned by handle().
  inline void setHandle(RecordFunctionHandle handle) {
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(state_, "Called setHandle() on inactive RecordFunction");
    state_->handle_ = handle;
  }

  // Whether this RecordFunction runs any callbacks.
  bool isActive() const {
    return state_ != nullptr;
  }

  // Whether any of the callbacks picked for this range require inputs.
  bool needsInputs() const {
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(state_, "Called needsInputs() on inactive RecordFunction");
    return state_->needs_inputs;
  }

 private:

  // Allows the modification of some internal states for callbacks.
  friend class CallbackManager;

  // Per-invocation data; allocated only for active RecordFunctions.
  struct State {
    explicit State(RecordScope scope) : scope_(scope) {}

    // Whether any of the picked callbacks require inputs
    bool needs_inputs = false;

    // In cases when RecordFunction might be active but we chose not to
    // use the observers (e.g. operator is not observed), this boolean
    // flag is used to check whether the start callbacks were called
    bool called_start_callbacks_ = false;

    // Whether the RecordFunction is pre-sampled
    bool pre_sampled_ = false;

    // Used internally to keep track of thread local and global callbacks
    // that were picked to run; must be sorted;
    CallbackHandles sorted_active_tls_handles_;
    CallbackHandles sorted_active_global_handles_;

    // Stores various ObserverContext objects with event metadata for thread local
    // callbacks.
    ObserverContextList tls_ctx_;

    // Stores various ObserverContext objects with event metadata for global
    // callbacks.
    ObserverContextList global_ctx_;

    StringView name_;
    int64_t sequence_nr_ = -1;
    std::vector<c10::IValue> inputs_;

    c10::optional<c10::OperatorName> operator_name_;

    // Kind of scope this RecordFunction is observing
    const RecordScope scope_;

    // The logical thread_id that this RecordFunction was created with
    uint64_t thread_id_ = 0;

    // For backward functions - thread id of the forward function
    uint64_t fwd_thread_id_ = 0;

    // Unique id for this RecordFunction, used in callbacks to track start
    // and end of ranges
    RecordFunctionHandle handle_ {0};
  };

  // Null when this RecordFunction is inactive (see isActive()).
  std::unique_ptr<State> state_;
};

//
// PyTorch callbacks/observers API:
//

/**
 * RecordFunctionCallback represents a pair of callbacks to be used with
 * RecordFunction, members:
 *   start, end - the callbacks to run when entering and exiting the scope;
 *     optionally, the start callback may return an ObserverContext which will
 *     be passed to the end callback, use appropriate constructor accordingly.
 *   needs_inputs - whether the callbacks need the inputs passed from the observed
 *     function/range; NOTE: passing the inputs incurs an additional overhead;
 *   sampling_probability - if not 1.0, then the callback is probabilistically sampled
 *     to run; NOTE: start and end callbacks always run as a pair and are sampled
 *     together;
 *   scopes - types of scopes to execute the callbacks on (see RecordScope);
 *     passing an empty set means the callbacks will be executed for all
 *     possible scope types
 *   should_run - optional function that returns whether this callback should run;
 *     when provided, it overrides the effect of sampling_probability
 */
class TORCH_API RecordFunctionCallback {
 public:
  using StartCallback = std::unique_ptr<ObserverContext>(*)(const RecordFunction&);
  using EndCallback = void (*)(const RecordFunction&, ObserverContext*);

  // This interface supports observers that require passing an ObserverContext
  // between start and end callbacks.
  explicit RecordFunctionCallback(
      StartCallback start,
      EndCallback end = nullptr) :
      start_(start),
      end_(end) {
    scopes_.fill(true);
  }

  RecordFunctionCallback& needsInputs(bool needs_inputs) {
    needs_inputs_ = needs_inputs;
    return *this;
  }

  RecordFunctionCallback& needsIds(bool needs_ids) {
    needs_ids_ = needs_ids;
    return *this;
  }

  RecordFunctionCallback& samplingProb(double sampling_prob) {
    TORCH_CHECK(sampling_prob >= 0.0 && sampling_prob <= 1.0,
        "Invalid sampling probability");
    sampling_prob_ = sampling_prob;
    return *this;
  }

  RecordFunctionCallback& scopes(
      const std::unordered_set<RecordScope, std::hash<RecordScope>>& scopes) {
    if (!scopes.empty()) {
      scopes_.fill(false);
      for (auto sc : scopes) {
        scopes_[static_cast<size_t>(sc)] = true;
      }
    } else {
Loading ...