Version: 1.8.0 

include/torch/csrc/jit/ir/ir.h

#pragma once

#include <torch/csrc/jit/ir/attributes.h>
#include <torch/csrc/jit/ir/graph_node_list.h>
#include <torch/csrc/jit/ir/named_value.h>
#include <torch/csrc/jit/ir/scope.h>
#include <torch/csrc/jit/runtime/operator.h>

#include <torch/csrc/WindowsTorchApiMacro.h>
#include <torch/csrc/utils/disallow_copy.h>
#include <torch/csrc/utils/python_stub.h>

#include <ATen/ATen.h>
#include <ATen/core/function_schema.h>
#include <ATen/core/functional.h>
#include <ATen/core/interned_strings.h>
#include <ATen/core/ivalue.h>
#include <ATen/core/jit_type.h>
#include <c10/util/ArrayRef.h>
#include <c10/util/Exception.h>

#include <functional>
#include <iostream>
#include <unordered_set>
#include <vector>

// Forward declare; the real meat is in python_ir.cpp
template <class T>
class THPPointer;
using THPObjectPtr = THPPointer<PyObject>;
using pyobj_list = std::vector<THPObjectPtr>;

namespace torch {
namespace jit {
class AliasDb;

using ::c10::Argument;
using ::c10::FunctionSchema;
using ::c10::Symbol;

using ::c10::ivalue::Shared;

using ::c10::IValue;
using ::c10::ivalue::Future;

using ::c10::ivalue::ConstantString;

#define C10_USING(T) using ::c10::T;
C10_FORALL_TYPES(C10_USING)
#undef C10_USING

#define C10_USING(T) using ::c10::T##Ptr;
C10_FORALL_TYPES(C10_USING)
#undef C10_USING
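// The two macro blocks above generate using-declarations: for each type kind
// T listed in C10_FORALL_TYPES, they pull ::c10::T and ::c10::T##Ptr into this
// namespace. As a rough illustration (the authoritative list of type kinds
// lives in c10), the expansion looks like:
//
//    using ::c10::TensorType;
//    using ::c10::TensorTypePtr;
//    // ...and likewise for every other entry in C10_FORALL_TYPES.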

using ::c10::Type;
using ::c10::TypeEnv;
using ::c10::TypePtr;

using ::c10::getTypePtr;
using ::c10::MatchTypeReturn;
using ::c10::TypeKind;

using ::c10::fmap;

namespace prim {
using namespace ::c10::prim;
}
namespace attr {
using namespace ::c10::attr;
}
namespace aten {
using namespace ::c10::aten;
}
namespace cuda {
#ifndef __HIP_PLATFORM_HCC__
using namespace ::c10::cuda;
#endif
} // namespace cuda

struct Function;
struct MatchedSchema;

// Graph represents one "function" of computation.
// It uses a simple ownership model where the graph owns all the nodes inside
// it. All references inside the graph are raw pointers. Destroying the Graph
// will invalidate any pointers to nodes in the graph.
struct Graph;
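// Illustration (a hedged sketch, not part of the original header): because the
// graph owns its nodes and only hands out raw pointers, a Node* or Value* must
// never outlive the Graph that created it.
//
//    {
//      auto graph = std::make_shared<Graph>();
//      Node* n = /* ...some node created inside *graph... */;
//      // n is valid here
//    }   // graph destroyed; any saved Node*/Value* is now dangling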

// Node is the base class of the IR graph. A Node represents one computation
// and its dependencies on a list of Values: the "prim-ops", so to speak.
struct Node;

// A Value represents an input or output to a node that is either a
// Tensor or an opaque Handle object, as determined by type().
struct Value;

TORCH_API std::ostream& operator<<(std::ostream& out, const Graph& g);
TORCH_API std::ostream& operator<<(std::ostream& out, const Node& n);

// A list of nodes, with inputs and outputs
struct Block;

// Each use is represented by this type; see Node::uses().
// 'user' is the consumer of the value, and 'offset' is the index into
// 'user's input list where the producer will be found.
struct Use {
  Use(Node* user, size_t offset) : user(user), offset(offset) {}
  Node* user;
  size_t offset;

  bool operator==(const Use& b) {
    return user == b.user && offset == b.offset;
  }
};
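// Usage sketch (not part of the original header): given an illustrative
// Value* v, its consumers can be inspected through uses(); each Use records
// which node reads the value and at which input slot (Node::inputs() is
// declared further down in this file).
//
//    for (const Use& u : v->uses()) {
//      Node* consumer = u.user;   // the node that reads v
//      size_t slot = u.offset;    // index into consumer's inputs
//      // consumer->inputs()[slot] == v
//    }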

// Note [User node does not uniquely identify use]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// A while back, we wrote some code manipulating uses that looked like this:
//
//    for (auto& use : used_val->uses_) {
//      if (use.user == this_node) {
//        use.offset += 1;
//        break;
//      }
//    }
//
// This code is trying to find a particular use (our node's use) to update it.
// However, it's wrong: there may be *multiple* uses of a value %x in a node,
// as might be the case in this IR:
//
//    %y = Add %x %x
//
// In this case, there are two uses of %x whose user is the node 'Add %x %x'.
// So, "use induced by this node" is not a well-formed concept.
//
// If you are looking for "use induced by an input", it's best to use
// findUseForInput() to get it.
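//
// A hedged sketch of the safer pattern (not part of the original note):
// instead of searching uses_ for a matching user node, walk the node's own
// input slots so that each occurrence of the value is handled exactly once
// (`node` and `used_val` are illustrative names).
//
//    for (size_t i = 0; i < node->inputs().size(); ++i) {
//      if (node->inputs()[i] == used_val) {
//        // handle the use of used_val at input slot i of node
//      }
//    }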

// The list types are intentionally simple, but we typedef them here
// so that refactoring will be easier if we need to change them.
using node_list = std::vector<Node*>;
using value_list = std::vector<Value*>;
using use_list = std::vector<Use>;
template <typename T>
using ArrayRef = at::ArrayRef<T>;
using NodeKind = Symbol;
using topo_position_t = int64_t;
using ValueSet = std::unordered_set<const Value*>;

struct OperatorSet;

// This is a wrapper to allow invalidating the Python object
// safely when the C++ object for a Node/Value/Block is deleted.
// Like much of the graph, it isn't safe for different threads to
// access the same graph concurrently.
template <typename T>
struct Wrap {
  explicit Wrap(T* p) : elem(p), clear_cb(nullptr) {}
  void clear() {
    if (clear_cb) {
      clear_cb(elem);
    }
    elem = nullptr;
  }
  T* elem;
  void (*clear_cb)(void*);
};
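// Usage sketch (not part of the original header): the Python binding layer can
// keep a Wrap and install clear_cb to learn when the underlying C++ object
// goes away; the callback body here is purely illustrative, and Value::wrap()
// is declared below.
//
//    std::shared_ptr<Wrap<Value>> handle = value->wrap();
//    handle->clear_cb = [](void* elem) {
//      // e.g. mark the corresponding Python object as invalid
//    };
//    // When the Value is destroyed, ~Value() calls handle->clear(), which
//    // invokes clear_cb(elem) and then nulls out handle->elem.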

struct Value {
  TH_DISALLOW_COPY_AND_ASSIGN(Value);
  Value(Node* node_, size_t offset_);

 private:
  friend struct Node;
  friend struct Graph;
  Node* node_;
  size_t offset_;
  size_t unique_ = 0; // unique id
  use_list uses_;
  std::string unique_name_;
  TypePtr type_;
  // a managing wrapper for Python to allow invalidation
  std::shared_ptr<Wrap<Value>> wrap_;

 public:
  Value* setType(TypePtr type);
  TORCH_API void inferTypeFrom(const at::Tensor& output);
  TORCH_API void inferTypeFrom(
      const c10::intrusive_ptr<c10::ivalue::Object>& output);
  const TypePtr& type() const {
    AT_ASSERT(type_ != nullptr);
    return type_;
  }
  bool requires_grad() const {
    return type()->requires_grad();
  }
  bool isCompleteTensor() const {
    if (auto pt = type()->cast<TensorType>()) {
      return pt->isComplete();
    }
    return false;
  }
  TORCH_API bool mustBeNone() const;
  TORCH_API bool mustNotBeNone() const;
  size_t unique() const {
    return unique_;
  }
  bool hasDebugName() const {
    return !unique_name_.empty();
  }
  static bool isValidName(const std::string& name);
  TORCH_API Value* setDebugName(const std::string& name);
  std::string debugName() const {
    if (hasDebugName()) {
      return unique_name_;
    }
    return c10::to_string(unique());
  }
  TORCH_API std::string debugNameBase() const;
  Node* node() {
    return node_;
  }
  size_t offset() const {
    return offset_;
  }
  void setOffset(size_t offset) {
    offset_ = offset;
  }
  const Node* node() const {
    return node_;
  }
  Graph* owningGraph();
  const Graph* owningGraph() const;
  // TODO: make this more const correct
  const use_list& uses() const {
    return uses_;
  }

  bool hasUses() const {
    return !uses().empty();
  }

  TORCH_API void replaceFirstUseWith(Value* newValue);

  // Replaces all uses of this value with 'newValue'.
  //
  // Given:   %3 = f(%1, %2)
  //          %4 = g(%3)
  //          %5 = h(%3, %3)
  // Execute: %3.replaceAllUsesWith(%6)
  // Result:  %3 = f(%1, %2)
  //          %4 = g(%6)
  //          %5 = h(%6, %6)
  TORCH_API void replaceAllUsesWith(Value* newValue);

  // Replaces all uses of this value with 'newValue' after 'node'.
  // Given:   %3 = f(%1, %2)
  //          %4 = g(%3)
  //          %5 = inplace_(%3)
  //          %6 = h(%3, %3)
  // Execute: %3.replaceAllUsesAfterNodeWith(%5.node(), %5)
  // Result:  %3 = f(%1, %2)
  //          %4 = g(%3)
  //          %5 = inplace_(%3)
  //          %6 = h(%5, %5)
  TORCH_API void replaceAllUsesAfterNodeWith(const Node* node, Value* newValue);

  TORCH_API Value* copyMetadata(Value* from);

  TORCH_API std::shared_ptr<Wrap<Value>> wrap() {
    if (!wrap_) {
      wrap_ = std::make_shared<Wrap<Value>>(this);
    }
    return wrap_;
  }

  virtual ~Value() {
    if (wrap_) {
      wrap_->clear();
    }
  }
};
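// A hedged usage sketch (not part of the original header) tying the Value API
// together; `v` and `replacement` are illustrative Value* belonging to the
// same graph:
//
//    std::cout << v->debugName() << " : " << v->type()->str() << "\n";
//    if (v->hasUses()) {
//      // rewire every consumer of v to read from replacement instead
//      v->replaceAllUsesWith(replacement);
//    }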

struct TORCH_API Node {
  TH_DISALLOW_COPY_AND_ASSIGN(Node);
  friend struct Graph;
  friend struct Block;
  friend struct Value;
  friend graph_node_list;
  friend const_graph_node_list;
  friend graph_node_list_iterator;
  friend const_graph_node_list_iterator;

 private:
  const NodeKind kind_;
  std::vector<Value*> inputs_;
  std::vector<Value*> outputs_;
  // subblocks
  std::vector<Block*> blocks_;
  Graph* graph_;
  Block* owning_block_;
  c10::optional<SourceRange> source_range_;
  ScopePtr scope_;
  c10::optional<InlinedCallStackPtr> callstack_;
  // Assumes FunctionSchemas are persistent, so we don't manage their lifetime.
  // This field is effectively a cache that's populated on attribute lookups
  // and invalidated every time we perform an operation that could potentially
  // change the schema. Note: mutable because op_ is effectively a cache.
  mutable const Operator* op_;
  topo_position_t topo_position_ = 0;
  // a managing wrapper for Python to allow invalidation
  std::shared_ptr<Wrap<Node>> wrap_;

 protected:
  Node(Graph* graph_, NodeKind kind_); // defined after graph
 public:
  // Each node except Return/Param is associated with exactly one place in
  // the node list of the graph_. This list is circular and doubly linked; the
  // Return node is used as the sentinel for the beginning and end of the
  // list, so the list never contains null pointers.
  //   next_in_graph[0] is the next pointer
  //   next_in_graph[1] is the prev pointer
  // Using an array allows the same iterator class to be used for forward and
  // reverse node lists. This list represents a topological sort.
  Node* next_in_graph[2] = {nullptr, nullptr};
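  // A hedged sketch (not part of the original header) of what the layout above
  // implies: repeatedly following next() walks the owning block's nodes in
  // topological order until the sentinel (the Return node) comes around again.
  //
  //    Node* cur = first_node;          // some node in the list
  //    while (cur != return_sentinel) { // illustrative name for the sentinel
  //      // visit cur
  //      cur = cur->next();
  //    }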

  std::shared_ptr<Wrap<Node>> wrap() {
    if (!wrap_) {
      wrap_ = std::make_shared<Wrap<Node>>(this);
    }
    return wrap_;
  }

  Node*& next() {
    return next_in_graph[kNextDirection];
  }
  Node*& prev() {
    return next_in_graph[kPrevDirection];
  }
  Node* const& next() const {
    return next_in_graph[kNextDirection];
  }
  Node* const& prev() const {
    return next_in_graph[kPrevDirection];