#pragma once
#include <functional>
#include <memory>
#include <string>
#include <utility>
#include <ATen/core/interned_strings.h>
#include <torch/csrc/jit/api/module.h>
#include <torch/csrc/jit/frontend/error_report.h>
#include <torch/csrc/jit/frontend/schema_matching.h>
#include <torch/csrc/jit/frontend/versioned_symbols.h>
#include <torch/csrc/jit/ir/ir.h>
namespace torch {
namespace jit {
// SugaredValue is defined further down in this header; forward-declare it so
// the alias below does not depend on some transitively included header having
// already declared it.
struct SugaredValue;
// Shared handle to a SugaredValue. SugaredValue derives from
// std::enable_shared_from_this, so instances are meant to be owned through
// std::shared_ptr.
using SugaredValuePtr = std::shared_ptr<SugaredValue>;
// Some AST nodes -- `self`, `self.b`, `python_fn` -- have no first-class
// counterpart in the graph IR; how they get desugared depends on how the AST
// uses them. SugaredValue stands in for such values during AST -> IR
// conversion, so that their behavior lives here rather than inside the
// converter itself. This also keeps dependencies on Python to a minimum.
//
// Every hook below defaults to throwing an ErrorReport that mentions kind();
// subclasses override only the operations they actually support.
struct TORCH_API SugaredValue
    : public std::enable_shared_from_this<SugaredValue> {
  // Human-readable description of this value for error messages
  // (e.g. Module, python function).
  virtual std::string kind() const = 0;

  // ---- What can be done with this value? ----

  // Use it as a plain IR value, e.g. `this + 4`.
  virtual Value* asValue(const SourceRange& loc, Function& m) {
    throw ErrorReport(loc) << kind() << " cannot be used as a value";
  }

  // Select an attribute on it, e.g. `this.field`.
  virtual std::shared_ptr<SugaredValue> attr(
      const SourceRange& loc,
      Function& m,
      const std::string& field) {
    throw ErrorReport(loc) << "attribute lookup is not defined on " << kind();
  }

  // Ask whether `field` exists on it. Note that the default throws rather
  // than answering false.
  virtual bool hasAttr(
      const SourceRange& loc,
      Function& m,
      const std::string& field) {
    throw ErrorReport(loc) << "attribute lookup is not defined on " << kind();
  }

  // Assign an attribute on it, e.g. `this.field = newValue`.
  virtual void setAttr(
      const SourceRange& loc,
      Function& m,
      const std::string& field,
      Value* newValue) {
    throw ErrorReport(loc) << "attribute assignment is not defined on "
                           << kind();
  }

  // Use it as a sequence of values, e.g. a tuple of values returned from a
  // method invocation.
  virtual std::vector<std::shared_ptr<SugaredValue>> asTuple(
      const SourceRange& loc,
      Function& m,
      const c10::optional<size_t>& size_hint = {}) {
    throw ErrorReport(loc) << kind() << " cannot be used as a tuple";
  }

  // TODO @wconstab refactor to use ModuleValue::asTuple instead of new API
  virtual SugaredValuePtr asTupleValue(const SourceRange& loc, Function& m) {
    throw ErrorReport(loc) << kind() << " cannot be used as a tuplevalue";
  }

  // Use it as a type.
  // NOTE(review): this hook takes Method& while every other hook takes
  // Function&; confirm whether that asymmetry is intentional.
  virtual std::vector<std::shared_ptr<SugaredValue>> asType(
      const SourceRange& loc,
      Method& m) {
    throw ErrorReport(loc) << kind() << " cannot be used as a type";
  }

  // Call it like a function, e.g. `outputs = this(inputs)`.
  // Positional args are named 'argument 0', 'argument 1', etc.
  //
  // n_binders is the number of variables the call expression is
  // syntactically bound to:
  //   a = foo()       # 1 binder (the single binder might itself be a tuple)
  //   a, *b = foo()   # 1 binder
  //   a, b = foo()    # 2 binders
  //   foo()           # 0 binders
  // In subexpressions, like bar() in foo(bar()), n_binders is always 1.
  // It is only a hint telling subexpressions how many values to return when
  // that number is statically ambiguous (currently used to decide how many
  // tensors a call to a python function returns). Implementations need not
  // check it against what they return; the assignment logic does that.
  virtual std::shared_ptr<SugaredValue> call(
      const SourceRange& loc,
      Function& m,
      // note: names for args will be 'argument 0', 'argument 1', etc..
      at::ArrayRef<NamedValue> args,
      at::ArrayRef<NamedValue> kwargs,
      size_t n_binders) {
    throw ErrorReport(loc) << "cannot call a " << kind();
  }

  // Convert it into the thing that is actually iterated over; e.g. iterating
  // through a Dict iterates over its keys.
  virtual std::shared_ptr<SugaredValue> iter(
      const SourceRange& loc,
      Function& m) {
    throw ErrorReport(loc) << kind() << " cannot be used as an iterable";
  }

  // A statically known length. When a value being iterated returns one, the
  // for-loop is emitted as an unrolled loop, which is what allows iterating
  // containers of heterogeneous types such as module containers and tuples.
  // By default no static length is known.
  virtual c10::optional<int64_t> staticLen() {
    return c10::nullopt;
  }

  // Whether a for-loop over this value is emitted unrolled, i.e. whether
  // staticLen() reports a length.
  bool shouldEmitUnrolled() {
    return staticLen().has_value();
  }

  // Length of this value as an IR Value. Without a statically determinable
  // length it cannot be iterated over with a modulelist; with one, the
  // returned Value* must be a constant.
  virtual Value* len(const SourceRange& loc, Function& m) {
    throw ErrorReport(loc) << "'" << kind() << "'"
                           << " object is not iterable";
  }

  // Expression for the element at `idx` of an iterable value, e.g.
  // `this[idx]`. type_hint is optional and defaults to nullptr.
  virtual std::shared_ptr<SugaredValue> getitem(
      const SourceRange& loc,
      Function& m,
      Value* idx,
      TypePtr type_hint = nullptr) {
    throw ErrorReport(loc) << "'" << kind() << "'"
                           << " object is not subscriptable";
  }

  virtual ~SugaredValue() = default;
};
// Most things in the environment are ordinary IR values rather than special
// Python syntax-sugar types; SimpleValue wraps such a Value*.
struct TORCH_API SimpleValue : public SugaredValue {
  SimpleValue(Value* value) : value_(value) {}

  // Describes the wrapped value by its type annotation:
  // "value of type '<annotation>'".
  std::string kind() const override {
    // Built by direct concatenation instead of a std::stringstream: this
    // header never includes <sstream>, so the stream version only compiled
    // thanks to transitive includes. The produced text is unchanged.
    return "value of type '" + value_->type()->annotation_str() + "'";
  }

  // The wrapped value is already a first-class IR value; return it as is.
  Value* asValue(const SourceRange& range, Function& m) override {
    return value_;
  }

  // The overrides below are declared here and defined out of line.
  std::vector<std::shared_ptr<SugaredValue>> asTuple(
      const SourceRange& loc,
      Function& m,
      const c10::optional<size_t>& size_hint = {}) override;
  std::shared_ptr<SugaredValue> attr(
      const SourceRange& loc,
      Function& m,
      const std::string& field) override;
  bool hasAttr(const SourceRange& loc, Function& m, const std::string& field)
      override;
  void setAttr(
      const SourceRange& loc,
      Function& m,
      const std::string& field,
      Value* newValue) override;
  std::shared_ptr<SugaredValue> call(
      const SourceRange& loc,
      Function& m,
      // note: names for args will be 'argument 0', 'argument 1', etc..
      at::ArrayRef<NamedValue> args,
      at::ArrayRef<NamedValue> kwargs,
      size_t n_binders) override;
  std::shared_ptr<SugaredValue> iter(const SourceRange& loc, Function& m)
      override;

  // Accessor for the wrapped IR value.
  Value* getValue() const {
    return value_;
  }

  Value* len(const SourceRange& loc, Function& m) override;
  SugaredValuePtr getitem(
      const SourceRange& loc,
      Function& m,
      Value* idx,
      TypePtr type_hint = nullptr) override;

 private:
  Value* value_;
};
// A call target resolving to a builtin operator symbol, optionally bound to a
// `self` argument when it was reached as a method.
struct TORCH_API BuiltinFunction : public SugaredValue {
  BuiltinFunction(Symbol sym, c10::optional<NamedValue> bound_self)
      : symbol(sym), self(std::move(bound_self)) {}

  // The operator symbol this function refers to, e.g. `aten::relu`.
  Symbol symbol;
  // The receiver when this builtin is used as a method; empty otherwise.
  c10::optional<NamedValue> self;

  std::string kind() const override {
    return "builtin";
  }

  // Defined out of line.
  std::shared_ptr<SugaredValue> call(
      const SourceRange& loc,
      Function& m,
      at::ArrayRef<NamedValue> args,
      at::ArrayRef<NamedValue> kwargs,
      size_t n_binders) override;

  // Attempts to create this builtin, returning nullptr when it does not exist
  // or when `self` cannot possibly match it. Meant for call sites where it is
  // unclear whether `symbol` names a valid builtin.
  static std::shared_ptr<BuiltinFunction> tryCreate(
      Symbol symbol,
      c10::optional<NamedValue> self);
};
// A tuple of SugaredValues kept at compile time rather than as a single IR
// value. Its elements may be of heterogeneous kinds (e.g. module container
// entries), so it is iterated through an unrolled loop (see staticLen()).
struct TORCH_API SugaredTupleValue : public SugaredValue {
  explicit SugaredTupleValue(std::vector<std::shared_ptr<SugaredValue>> tup)
      : tup_(std::move(tup)) {}

  // Returns the stored elements as-is; size_hint is not consulted.
  std::vector<std::shared_ptr<SugaredValue>> asTuple(
      const SourceRange& loc,
      Function& m,
      const c10::optional<size_t>& size_hint = {}) override {
    return tup_;
  }

  // Materializes every element via asValue() and packs the results into a
  // tuple node (Graph::createTuple) inserted into m's graph.
  Value* asValue(const SourceRange& loc, Function& m) override {
    std::vector<Value*> vec;
    vec.reserve(tup_.size());
    for (const auto& sv : tup_) {
      vec.push_back(sv->asValue(loc, m));
    }
    Graph& g = *m.graph();
    return g.insertNode(g.createTuple(vec))->output();
  }

  std::string kind() const override {
    return "Tuple";
  }

  // Indexes the tuple with a constant integer; negative indices count from
  // the end. Throws ErrorReport if idx is not an int constant or if the
  // index is out of range.
  SugaredValuePtr getitem(
      const SourceRange& loc,
      Function& m,
      Value* idx,
      TypePtr type_hint = nullptr) override {
    // Only look for a constant when the index is int-typed (preserving the
    // original short-circuit), and evaluate toIValue just once.
    const auto const_idx =
        idx->type()->cast<IntType>() ? toIValue(idx) : c10::nullopt;
    if (!const_idx) {
      throw ErrorReport(loc)
          << "Expected integer literal for index. "
          << "ModuleList/Sequential indexing is only supported with integer literals. "
          << "Enumeration is supported, e.g. 'for index, v in enumerate(self): ...'";
    }
    const int64_t index = const_idx->toInt();
    const auto len = static_cast<int64_t>(tup_.size());
    const int64_t adj_index = (index < 0) ? index + len : index;
    if (adj_index < 0 || adj_index >= len) {
      throw ErrorReport(loc)
          << "Index " << index << " out of range of length " << tup_.size();
    }
    return tup_.at(adj_index);
  }

  // Iterating a tuple iterates its elements directly, so the tuple is its own
  // iterable; staticLen() makes the loop over it unrolled.
  std::shared_ptr<SugaredValue> iter(const SourceRange& loc, Function& m)
      override {
    return shared_from_this();
  }

  // Because this holds SugaredValues of heterogeneous types, it reports a
  // static length so that iterating over it is emitted as an unrolled loop.
  c10::optional<int64_t> staticLen() override {
    return static_cast<int64_t>(tup_.size());
  }

  std::vector<std::shared_ptr<SugaredValue>> tup_;
};
// A namespace of builtin operators named `name`: looking up `field` on it
// yields a BuiltinFunction for the qualified symbol `name::field`.
struct TORCH_API BuiltinModule : public SugaredValue {
  BuiltinModule(
      std::string module_name,
      c10::optional<int64_t> version = c10::nullopt)
      : name(std::move(module_name)), version(version) {}

  std::string kind() const override {
    return "builtin module";
  }

  // Resolves `this.field`. `autograd` is special-cased to an `aten` module
  // with the same version; any other field becomes a BuiltinFunction for
  // `name::field`, remapped to its historic variant when a version is pinned.
  std::shared_ptr<SugaredValue> attr(
      const SourceRange& loc,
      Function& m,
      const std::string& field) override {
    if (field == "autograd") {
      // torch.autograd is also treated as a BuiltinModule; methods under it
      // dispatch to the aten operators.
      return std::make_shared<BuiltinModule>("aten", version);
    }
    const Symbol qualified = Symbol::fromQualString(name + "::" + field);
    if (!version.has_value()) {
      return std::make_shared<BuiltinFunction>(qualified, c10::nullopt);
    }
    // With a pinned version, the symbol may be replaced by one implementing
    // its historic behavior. See note [Versioned Symbols]
    return std::make_shared<BuiltinFunction>(
        get_symbol_for_version(qualified, *version), c10::nullopt);
  }

 private:
  // Namespace prefix used to qualify looked-up fields.
  std::string name;
  // When set, ops are emitted as they existed at this version (via
  // get_symbol_for_version); when unset, the latest version is used.
  c10::optional<int64_t> version;
};
// Represents a class, analogous to `int` or `dict`. Instances of classes,
// like `1` or `{"foo": 5}`, are represented as SimpleValues.
struct TORCH_API ClassValue : public SugaredValue {
explicit ClassValue(ClassTypePtr type) : type_(std::move(type)) {}
// Call the type's constructor, as in:
// n = Foo(constructor_arg)
std::shared_ptr<SugaredValue> call(
const SourceRange& loc,
Function& m,
at::ArrayRef<NamedValue> args,
at::ArrayRef<NamedValue> kwargs,
size_t n_binders) override;
Loading ...