#pragma once
#include <torch/csrc/jit/ir/attributes.h>
#include <torch/csrc/jit/ir/graph_node_list.h>
#include <torch/csrc/jit/ir/named_value.h>
#include <torch/csrc/jit/ir/scope.h>
#include <torch/csrc/jit/runtime/operator.h>
#include <torch/csrc/WindowsTorchApiMacro.h>
#include <torch/csrc/utils/disallow_copy.h>
#include <torch/csrc/utils/python_stub.h>
#include <ATen/ATen.h>
#include <ATen/core/function_schema.h>
#include <ATen/core/functional.h>
#include <ATen/core/interned_strings.h>
#include <ATen/core/ivalue.h>
#include <ATen/core/jit_type.h>
#include <c10/util/ArrayRef.h>
#include <c10/util/Exception.h>
#include <functional>
#include <iostream>
#include <unordered_set>
#include <vector>
// Forward declare, the real meat is in python_ir.cpp
// THPPointer is only forward-declared here, so this header can name the
// aliases below without needing the template's definition.
template <class T>
class THPPointer;
// THPPointer specialised for PyObject.
// NOTE(review): PyObject is presumably declared by python_stub.h — confirm.
using THPObjectPtr = THPPointer<PyObject>;
// A vector of THPObjectPtr.
using pyobj_list = std::vector<THPObjectPtr>;
namespace torch {
namespace jit {
class AliasDb;
using ::c10::Argument;
using ::c10::FunctionSchema;
using ::c10::Symbol;
using ::c10::ivalue::Shared;
using ::c10::IValue;
using ::c10::ivalue::Future;
using ::c10::ivalue::ConstantString;
// Import c10 type names into torch::jit in two passes over C10_FORALL_TYPES.
// NOTE(review): C10_FORALL_TYPES is defined elsewhere; presumably it invokes
// its argument once per c10 type name — confirm.
// Pass 1: for each name T, emit `using ::c10::T;`.
#define C10_USING(T) using ::c10::T;
C10_FORALL_TYPES(C10_USING)
#undef C10_USING
// Pass 2: for each name T, emit `using ::c10::TPtr;` (T with a Ptr suffix).
#define C10_USING(T) using ::c10::T##Ptr;
C10_FORALL_TYPES(C10_USING)
#undef C10_USING
using ::c10::Type;
using ::c10::TypeEnv;
using ::c10::TypePtr;
using ::c10::getTypePtr;
using ::c10::MatchTypeReturn;
using ::c10::TypeKind;
using ::c10::fmap;
// Mirror the ::c10 namespaces prim, attr, aten and cuda inside torch::jit, so
// that e.g. torch::jit::prim::X names the entity ::c10::prim::X.
namespace prim {
using namespace ::c10::prim;
}
namespace attr {
using namespace ::c10::attr;
}
namespace aten {
using namespace ::c10::aten;
}
namespace cuda {
// The ::c10::cuda names are only pulled in when __HIP_PLATFORM_HCC__ is not
// defined; otherwise torch::jit::cuda is left empty by this header.
#ifndef __HIP_PLATFORM_HCC__
using namespace ::c10::cuda;
#endif
} // namespace cuda
struct Function;
struct MatchedSchema;
// Graph represents one "function" of computation.
// It uses a simple ownership model where the graph owns all the nodes inside
// it. All references inside the graph are raw pointers. Destroying the Graph
// will invalidate any pointers to nodes in the graph.
struct Graph;
// Node is the base class of the IR graph. It represents one computation
// and dependencies on a list of Values. The "prim-ops", so to speak.
struct Node;
// A Value represents an input or output to a node that is either a
// Tensor or an opaque Handle object, as determined by type().
struct Value;
TORCH_API std::ostream& operator<<(std::ostream& out, const Graph& g);
TORCH_API std::ostream& operator<<(std::ostream& out, const Node& n);
// A list of nodes, with inputs and outputs
struct Block;
// Each use of a Value is represented by this type; see Value::uses().
// 'user' is the node that consumes the value, and 'offset' is the index into
// that node's inputs at which the value is found. A single node may consume
// the same value at several offsets, so it takes both fields together to
// identify one use.
struct Use {
  Use(Node* user, size_t offset) : user(user), offset(offset) {}
  Node* user;
  size_t offset;
  // Two uses are equal when they refer to the same user node and the same
  // input offset. The operator is const-qualified so that const Use objects
  // (for example the elements of a `const use_list&`) can be compared.
  bool operator==(const Use& b) const {
    return user == b.user && offset == b.offset;
  }
};
// Note [User node does not uniquely identify use]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// A while back, we wrote some code manipulating uses that looked like this:
//
// for (auto& use : used_val->uses_) {
// if (use.user == this_node) {
// use.offset += 1;
// break;
// }
// }
//
// This code is trying to find a particular use (our node's use) to update it.
// However, it's wrong: there may be *multiple* uses of a value %x in a node,
// as might be the case in this IR:
//
// %y = Add %x %x
//
// In this case, there are two uses of %x whose user is the node 'Add %x %x'.
// So, "use induced by this node" is not a well-formed concept.
//
// If you are looking for "use induced by an input", it's best to use
// findUseForInput() to get it.
// the list types are intentionally simple, but we type-def
// them here so if we need to change them, refactoring will be easier
// Sequences of non-owning Node pointers, non-owning Value pointers, and Use
// records respectively.
using node_list = std::vector<Node*>;
using value_list = std::vector<Value*>;
using use_list = std::vector<Use>;
// Makes at::ArrayRef<T> available as plain ArrayRef<T> in this namespace.
template <typename T>
using ArrayRef = at::ArrayRef<T>;
// The kind of a node is represented by a Symbol.
using NodeKind = Symbol;
// Integer type used for a node's topological position (Node::topo_position_).
using topo_position_t = int64_t;
// Unordered set of pointers to const Values.
using ValueSet = std::unordered_set<const Value*>;
// Forward declaration only; defined elsewhere.
struct OperatorSet;
// Wrap<T> exists so that the Python object referring to a Node/Value/Block
// can be invalidated safely when the corresponding C++ object is deleted.
// Like much of the graph, it is not safe for different threads to access the
// same graph.
template <typename T>
struct Wrap {
  // Wraps `p`; no invalidation callback is installed initially.
  explicit Wrap(T* p) : elem(p), clear_cb(nullptr) {}

  // Invalidates the wrapper: if a callback has been installed it is invoked
  // with the currently wrapped pointer, and afterwards the pointer is reset
  // to null. The callback itself is left installed.
  void clear() {
    if (clear_cb != nullptr) {
      clear_cb(elem);
    }
    elem = nullptr;
  }

  // The wrapped object, or nullptr once clear() has run.
  T* elem;
  // Optional hook called by clear(); receives `elem` as a void*.
  void (*clear_cb)(void*);
};
struct Value {
TH_DISALLOW_COPY_AND_ASSIGN(Value);
Value(Node* node_, size_t offset_);
private:
friend struct Node;
friend struct Graph;
Node* node_;
size_t offset_;
size_t unique_ = 0; // unique id
use_list uses_;
std::string unique_name_;
TypePtr type_;
// a managing wrapper for Python to allow invalidation
std::shared_ptr<Wrap<Value>> wrap_;
public:
Value* setType(TypePtr type);
TORCH_API void inferTypeFrom(const at::Tensor& output);
TORCH_API void inferTypeFrom(
const c10::intrusive_ptr<c10::ivalue::Object>& output);
const TypePtr& type() const {
AT_ASSERT(type_ != nullptr);
return type_;
}
bool requires_grad() const {
return type()->requires_grad();
}
bool isCompleteTensor() const {
if (auto pt = type()->cast<TensorType>()) {
return pt->isComplete();
}
return false;
}
TORCH_API bool mustBeNone() const;
TORCH_API bool mustNotBeNone() const;
size_t unique() const {
return unique_;
}
bool hasDebugName() const {
return !unique_name_.empty();
}
static bool isValidName(const std::string& name);
TORCH_API Value* setDebugName(const std::string& name);
std::string debugName() const {
if (hasDebugName()) {
return unique_name_;
}
return c10::to_string(unique());
}
TORCH_API std::string debugNameBase() const;
Node* node() {
return node_;
}
size_t offset() const {
return offset_;
}
void setOffset(size_t offset) {
offset_ = offset;
}
const Node* node() const {
return node_;
}
Graph* owningGraph();
const Graph* owningGraph() const;
// TODO: make this more const correct
const use_list& uses() const {
return uses_;
}
bool hasUses() const {
return !uses().empty();
}
TORCH_API void replaceFirstUseWith(Value* newValue);
// Replaces all uses of this value with 'newValue'.
//
// Given: %3 = f(%1, %2)
// %4 = g(%3)
// %5 = h(%3, %3)
// Execute: %3.replaceAllUsesWith(%6)
// Result: %3 = f(%1, %2)
// %4 = g(%6)
// %5 = h(%6, %6)
TORCH_API void replaceAllUsesWith(Value* newValue);
// Replaces all uses of this value with 'newValue' after 'node'.
// Given: %3 = f(%1, %2)
// %4 = g(%3)
// %5 = inplace_(%3)
// %6 = h(%3, %3)
// Execute: %3.replaceAllUsesAfterNodeWith(%5.node(), %5)
// Result: %3 = f(%1, %2)
// %4 = g(%3)
// %5 = inplace_(%3)
// %6 = h(%5, %5)
TORCH_API void replaceAllUsesAfterNodeWith(const Node* node, Value* newValue);
TORCH_API Value* copyMetadata(Value* from);
TORCH_API std::shared_ptr<Wrap<Value>> wrap() {
if (!wrap_) {
wrap_ = std::make_shared<Wrap<Value>>(this);
}
return wrap_;
}
virtual ~Value() {
if (wrap_) {
wrap_->clear();
}
}
};
struct TORCH_API Node {
TH_DISALLOW_COPY_AND_ASSIGN(Node);
friend struct Graph;
friend struct Block;
friend struct Value;
friend graph_node_list;
friend const_graph_node_list;
friend graph_node_list_iterator;
friend const_graph_node_list_iterator;
private:
const NodeKind kind_;
std::vector<Value*> inputs_;
std::vector<Value*> outputs_;
// subblocks
std::vector<Block*> blocks_;
Graph* graph_;
Block* owning_block_;
c10::optional<SourceRange> source_range_;
ScopePtr scope_;
c10::optional<InlinedCallStackPtr> callstack_;
// Assumes FunctionSchemas are persistent, so we don't manage their lifetime.
// This field is effectively a cache that's populated on attribute lookups and
// invalidated every time we perform an operation that could potentially
// change the schema. Note: mutable because op_ is effectively a cache.
mutable const Operator* op_;
topo_position_t topo_position_ = 0;
// a managing wrapper for Python to allow invalidation
std::shared_ptr<Wrap<Node>> wrap_;
protected:
Node(Graph* graph_, NodeKind kind_); // defined after graph
public:
// Each node but Return/Param is associated with exactly one place in the
// node list of the graph_. This is a circular doubly-linked list; the Return
// node is used as the sentinel for the beginning and end of the list, such
// that the list never has null pointers. next_in_graph[0] is the next pointer
// and next_in_graph[1] is the prev pointer; an array is used to allow the
// same iterator class for forward and reverse node lists. This list
// represents a topological sort.
Node* next_in_graph[2] = {nullptr, nullptr};
std::shared_ptr<Wrap<Node>> wrap() {
if (!wrap_) {
wrap_ = std::make_shared<Wrap<Node>>(this);
}
return wrap_;
}
Node*& next() {
return next_in_graph[kNextDirection];
}
Node*& prev() {
return next_in_graph[kPrevDirection];
}
Node* const& next() const {
return next_in_graph[kNextDirection];
}
Node* const& prev() const {
return next_in_graph[kPrevDirection];
Loading ...