#pragma once
#include <ATen/core/ivalue.h>
#include <ATen/core/operator_name.h>
#include <c10/macros/Export.h>
#include <c10/util/Optional.h>
#include <c10/util/SmallVector.h>
#include <cstring>
#include <functional>
#include <memory>
namespace c10 {
// Forward declaration: RecordFunction::before(c10::OperatorHandle const&, ...)
// below only takes the handle by reference, which does not require the full
// class definition.
class TORCH_API OperatorHandle;
}
namespace at {
// Identifies the kind of scope that a RecordFunction observes.
enum class C10_API_ENUM RecordScope : uint8_t {
  // c10/ATen ops, autograd nodes
  FUNCTION = 0,
  // Functions/nodes called from the autograd
  BACKWARD_FUNCTION = 1,
  // TorchScript functions, methods
  TORCHSCRIPT_FUNCTION = 2,
  // Kernel Function dtype Tag
  KERNEL_FUNCTION_DTYPE = 3,
  // User defined scope (e.g. with record_function())
  USER_SCOPE = 4,
  // Count of the scopes listed above, not a scope itself;
  // must be the last in the list
  NUM_SCOPES,
};
} // namespace at
namespace std {
// Hash support for at::RecordScope so it can be used as a key in unordered
// containers; the hash value is the enumerator's numeric value.
template <>
struct hash<at::RecordScope> {
  size_t operator()(const at::RecordScope& scope) const {
    const auto numeric_value = static_cast<std::size_t>(scope);
    return numeric_value;
  }
};
} // namespace std
namespace at {
// A string handle that either borrows a C string it does not own, or owns a
// std::string through a shared_ptr (so copies of the view share one string).
struct TORCH_API StringView {
  // Creates a null view: str() returns nullptr.
  StringView() : StringView(nullptr) {}

  // Borrows str_ptr without copying it; only the pointer is stored, so the
  // characters must stay valid for as long as the view is used.
  explicit StringView(const char* str_ptr)
      : owned_str_ptr_(nullptr), str_ptr_(str_ptr) {}

  // Takes ownership of str. str_ptr_ is initialized from owned_str_ptr_,
  // which relies on owned_str_ptr_ being declared before str_ptr_ below.
  explicit StringView(std::string str)
      : owned_str_ptr_(std::make_shared<std::string>(std::move(str))),
        str_ptr_(owned_str_ptr_->c_str()) {}

  // Returns the viewed C string; nullptr for a default-constructed view or a
  // view built from a null pointer.
  const char* str() const {
    return str_ptr_;
  }

  // Streams the viewed characters. A null view writes nothing: inserting a
  // null const char* into a stream is undefined behavior.
  friend std::ostream& operator<<(std::ostream& os, const StringView& dt) {
    if (const char* chars = dt.str()) {
      os << chars;
    }
    return os;
  }

  // Compares the viewed characters. A null view is equal only to another
  // null view; strcmp is never handed a null pointer.
  friend bool operator==(const StringView& lhs, const StringView& rhs) {
    const char* left = lhs.str();
    const char* right = rhs.str();
    if (left == right) {
      // Same pointer (including both null): equal without reading memory.
      return true;
    }
    if (left == nullptr || right == nullptr) {
      return false;
    }
    return std::strcmp(left, right) == 0;
  }

  friend bool operator!=(const StringView& lhs, const StringView& rhs) {
    return !(lhs == rhs);
  }

 private:
  // Set only by the std::string constructor; keeps the owned string alive.
  std::shared_ptr<std::string> owned_str_ptr_;
  // Either the borrowed pointer or owned_str_ptr_->c_str(); may be nullptr.
  const char* str_ptr_;
};
// Soft limit on the number of callbacks to use; also passed as the size
// parameter of the CallbackHandles SmallVector.
constexpr std::size_t kSoftLimitCallbacks{4};
// Polymorphic base for the observer contexts that can be attached to a
// RecordFunction. The constructor is protected, so only derived types can be
// instantiated; the virtual destructor allows deletion through a base pointer.
struct ObserverContext {
  virtual ~ObserverContext() = default;

 protected:
  ObserverContext() = default;
};
// Handles of the callbacks picked to run for a RecordFunction.
using CallbackHandles = c10::SmallVector<uint64_t, kSoftLimitCallbacks>;
// Owning list of ObserverContext objects.
using ObserverContextList = std::vector<std::unique_ptr<ObserverContext>>;
// Integer id type used for RecordFunction handles.
using RecordFunctionHandle = uint64_t;
struct TORCH_API RecordFunction {
// Default constructor is used with before function called afterwards:
// scope - record scope that this function tracks
// pre_sampled - whether this RecordFunction was already pre-sampled with
// kLowProb probability
RecordFunction(
RecordScope scope = RecordScope::FUNCTION,
bool pre_sampled = false);
template <typename F>
void before(
F fn,
const std::vector<c10::IValue>* args,
int64_t current_sequence_nr = -1) {
if (!isActive()) {
return;
}
state_->inputs_ = *args;
before(fn, current_sequence_nr);
}
// Destructor calls end callbacks
virtual ~RecordFunction();
RecordFunction(const RecordFunction&) = delete;
RecordFunction& operator=(const RecordFunction&) = delete;
inline const StringView& name() const {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(state_, "Called name() on inactive RecordFunction");
return state_->name_;
}
inline int64_t seqNr() const {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(state_, "Called seqNr() on inactive RecordFunction");
return state_->sequence_nr_;
}
const std::vector<c10::IValue>& inputs() const {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(state_, "Called inputs() on inactive RecordFunction");
return state_->inputs_;
}
// Retrieves the thread_id that this RecordFunction ran start callbacks with.
// Useful for writing thread safe end callbacks that may be potentially
// executed in a different thread (async ops)
inline uint64_t threadId() const {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(state_, "Called threadId() on inactive RecordFunction");
return state_->thread_id_;
}
// For backward functions - thread id of the corresponding forward function,
// or zero otherwise;
// used alongside with sequence number to correlate backward functions with
// the forward ones
inline uint64_t forwardThreadId() const {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(state_, "Called forwardThreadId() on inactive RecordFunction");
return state_->fwd_thread_id_;
}
inline void setForwardThreadId(uint64_t thread_id) {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(state_, "Called setForwardThreadId() on inactive RecordFunction");
state_->fwd_thread_id_ = thread_id;
}
inline RecordScope scope() const {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(state_, "Called scope() on inactive RecordFunction");
return state_->scope_;
}
// Returns logical thread_id for the current thread
static uint64_t currentThreadId();
// Internal functions, do not use directly;
// used in python's context manager
// before functions initialize RecordFunction members and call
// start callbacks
void before(const char* name, int64_t sequence_nr = -1);
void before(std::string name, int64_t sequence_nr = -1);
void before(c10::OperatorHandle const& op, int64_t sequence_nr = -1);
// Sets node ID for distributed profiling
static void setDefaultNodeId(int64_t defaultNodeId);
// Gets node ID for distributed profiling
static int64_t getDefaultNodeId();
template<typename F>
void before(
F fn,
c10::ArrayRef<c10::IValue> args,
int64_t current_sequence_nr = -1) {
if (!isActive()) {
return;
}
state_->inputs_ = args.vec();
before(fn, current_sequence_nr);
}
template<typename F>
void before(
F fn,
std::vector<c10::IValue>&& args,
int64_t current_sequence_nr = -1) {
if (!isActive()) {
return;
}
state_->inputs_ = std::move(args);
before(fn, current_sequence_nr);
}
// Calls end callbacks. After end(), accessors will no longer provide useful results.
void end();
inline RecordFunctionHandle handle() const {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(state_, "Called handle() on inactive RecordFunction");
return state_->handle_;
}
inline c10::optional<OperatorName> operator_name() const {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(state_, "Called operator_name() on inactive RecordFunction");
return state_->operator_name_;
}
inline void setHandle(RecordFunctionHandle handle) {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(state_, "Called setHandle() on inactive RecordFunction");
state_->handle_ = handle;
}
// Whether this RecordFunction runs any callbacks.
bool isActive() const {
return state_ != nullptr;
}
bool needsInputs() const {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(state_, "Called needsInputs() on inactive RecordFunction");
return state_->needs_inputs;
}
private:
// Allows the modification of some internal states for callbacks.
friend class CallbackManager;
struct State {
explicit State(RecordScope scope) : scope_(scope) {}
// Whether any of the picked callbacks require inputs
bool needs_inputs = false;
// In cases when RecordFunction might be active but we chose not to
// use the observers (e.g. operator is not observed), this boolean
// flag is used to check whether the start callbacks were called
bool called_start_callbacks_ = false;
// Whether the RecordFunction is pre-sampled
bool pre_sampled_ = false;
// Used internally to keep track of thread local and global callbacks
// that were picked to run; must be sorted;
CallbackHandles sorted_active_tls_handles_;
CallbackHandles sorted_active_global_handles_;
// Stores various ObserverContext objects with event metadata for thread local
// callbacks.
ObserverContextList tls_ctx_;
// Stores various ObserverContext objects with event metadata for global
// callbacks.
ObserverContextList global_ctx_;
StringView name_;
int64_t sequence_nr_ = -1;
std::vector<c10::IValue> inputs_;
c10::optional<c10::OperatorName> operator_name_;
// Kind of scope this RecordFunction is observing
const RecordScope scope_;
// The logical thread_id that this RecordFunction was created with
uint64_t thread_id_ = 0;
// For backward functions - thread id of the the forward function
uint64_t fwd_thread_id_ = 0;
// Unique id for this RecordFunction, used in callbacks to track start
// and end of ranges
RecordFunctionHandle handle_ {0};
};
std::unique_ptr<State> state_;
};
//
// PyTorch callbacks/observers API:
//
/**
* RecordFunctionCallback represents a pair of callbacks to be used with
* RecordFunction, members:
* start, end - the callbacks to run when entering and exiting the scope;
* optionally, the start callback may return an ObserverContext which will
* be passed to the end callback, use appropriate constructor accordingly.
* needs_inputs - whether the callbacks need the inputs passed from the observed
* function/range; NOTE: passing the inputs incurs an additional overhead;
* sampling_probability - if not 1.0, then the callback is probabilistically sampled
* to run; NOTE: start and end callbacks always run as a pair and are sampled
* together;
* scopes - types of scopes to execute the callbacks on (see RecordScope);
* passing empty set means the callbacks will be executed for all possible
* scope types
* should_run - optional function that returns whether this callback should run;
* overrides the effect of setting sampling_probability
*/
class TORCH_API RecordFunctionCallback {
public:
using StartCallback = std::unique_ptr<ObserverContext>(*)(const RecordFunction&);
using EndCallback = void (*)(const RecordFunction&, ObserverContext*);
// This interface supports observers that require passing an ObserverContext
// between start and end callbacks.
explicit RecordFunctionCallback(
StartCallback start,
EndCallback end = nullptr) :
start_(start),
end_(end) {
scopes_.fill(true);
}
RecordFunctionCallback& needsInputs(bool needs_inputs) {
needs_inputs_ = needs_inputs;
return *this;
}
RecordFunctionCallback& needsIds(bool needs_ids) {
needs_ids_ = needs_ids;
return *this;
}
RecordFunctionCallback& samplingProb(double sampling_prob) {
TORCH_CHECK(sampling_prob >= 0.0 && sampling_prob <= 1.0,
"Invalid sampling probability");
sampling_prob_ = sampling_prob;
return *this;
}
RecordFunctionCallback& scopes(
const std::unordered_set<RecordScope, std::hash<RecordScope>>& scopes) {
if (!scopes.empty()) {
scopes_.fill(false);
for (auto sc : scopes) {
scopes_[static_cast<size_t>(sc)] = true;
}
} else {
Loading ...