#pragma once
#include <c10/util/string_utils.h>
#include <torch/csrc/jit/frontend/error_report.h>
#include <torch/csrc/jit/frontend/strtod.h>
#include <torch/csrc/jit/frontend/tree.h>
#include <functional>
#include <iostream>
#include <string>
namespace torch {
namespace jit {
// clang-format off
// TreeView provides a statically-typed way to traverse the tree, which should
// be formed according to the grammar below.
//
// A few notes on types and their aliases:
// - List<T> is really a Tree with kind TK_LIST and elements as subtrees
// - Maybe<T> is really a Tree with kind TK_OPTION that has 0 or 1 subtree of type T
// - Builtin types are: Ident (TK_IDENT), String (TK_STRING)
//
// Param = Param(Maybe<Expr> type, Ident name) TK_PARAM
//
// Decl = Decl(List<Param> params, Maybe<Expr> return_type) TK_DECL
// Def = Def(Ident name, Decl decl, List<Stmt> body) TK_DEF
// ClassDef = ClassDef(Ident name, TK_CLASS_DEF
// Maybe<Expr> superclass,
// List<Stmt> body)
//
// Stmt = If(Expr cond, List<Stmt> true_body, List<Stmt> false_body) TK_IF
// | For(List<Expr> targets, List<Expr> iters, List<Stmt> body) TK_FOR
// | While(Expr cond, List<Stmt> body) TK_WHILE
// | Global(List<Ident> idents) TK_GLOBAL
// -- NB: the only kinds of Expr allowed on the lhs are Var,
// or a tuple containing Vars with an optional terminating Starred
// | Assign(Expr lhs, Maybe<Expr> rhs, Maybe<Expr> type) TK_ASSIGN
// | AugAssign(Expr lhs, AugAssignKind aug_op, Expr rhs) TK_AUG_ASSIGN
// | Return(List<Expr> values) TK_RETURN
// | ExprStmt(List<Expr> expr) TK_EXPR_STMT
// | Raise(Expr expr) TK_RAISE
// | Def TK_DEF
// | With(List<WithItem> targets, List<Stmt> body) TK_WITH
//
// Expr = TernaryIf(Expr cond, Expr true_expr, Expr false_expr) TK_IF_EXPR
// | BinOp(Expr lhs, Expr rhs)
// | And TK_AND
// | Or TK_OR
// | Lt '<'
// | Gt '>'
// | Eq TK_EQ
// | Le TK_LE
// | Ge TK_GE
// | Ne TK_NE
// | Is TK_IS
// | IsNot TK_ISNOT
// | Add '+'
// | Sub '-'
// | Mul '*'
// | Div '/'
// | Mod '%'
// | MatMult '@'
// | Pow TK_POW
// | UnaryOp(Expr expr)
// | Not TK_NOT
// | USub '-'
// | Const(String value) TK_CONST
// -- NB: x.name(y) is desugared into name(x, y)
// | Apply(Ident name, List<Expr> args, List<Attribute> kwargs) TK_APPLY
// | Select(Expr value, Ident selector) '.'
// | Subscript(Expr value, List<Expr> subscript_exprs) TK_SUBSCRIPT
// | SliceExpr(Maybe<Expr> start, Maybe<Expr> end) TK_SLICE_EXPR
// | Var(Ident name) TK_VAR
// | ListLiteral(List<Expr> inputs) TK_LIST_LITERAL
// | TupleLiteral(List<Expr> inputs) TK_TUPLE_LITERAL
// | Starred(Expr expr) TK_STARRED
// | WithItem(Expr target, Maybe<Var> var) TK_WITH_ITEM
// -- NB: only allowed expressions are Const or List(Const)
// (List as a value, not type constructor)
// Attribute = Attribute(Ident name, Expr value) TK_ATTRIBUTE
//
// AugAssignKind =
// | Add() TK_PLUS_EQ
// | Sub() TK_MINUS_EQ
// | Mul() TK_TIMES_EQ
// | Div() TK_DIV_EQ
// | Mod() TK_MOD_EQ
//
// Each subclass of TreeView should provide:
// 1. Constructor that takes a TreeRef, and checks that it's of the right type.
// 2. Accessors that get underlying information out of the object. If they
// return subtrees, they should wrap them in appropriate views too.
// 3. Static method 'create' that creates the underlying TreeRef object.
// For every TreeRef kind that has a TreeView, the parser always uses
// (e.g.) Ident::create rather than Compound::create. This means that
// changes to the structure of Ident are always made right here rather
// than both in the parser and in this code.
// XXX: these structs should have no fields to prevent slicing when passing by value
// clang-format on
// Common base of every typed view. A view holds nothing but a TreeRef; this
// base class performs no kind checking itself and only exposes generic
// accessors that forward to the wrapped tree.
struct TreeView {
  // Takes ownership of the given reference; no validation happens here.
  explicit TreeView(TreeRef tree) : tree_(std::move(tree)) {}

  // Returns a copy of the wrapped tree reference.
  TreeRef tree() const {
    return tree_;
  }
  // Returns the wrapped tree reference without copying it.
  const TreeRef& get() const {
    return tree_;
  }
  // Implicit conversion back to the untyped tree reference.
  operator TreeRef() const {
    return tree_;
  }

  // Source range reported by the wrapped tree.
  const SourceRange& range() const {
    return tree_->range();
  }
  // Kind reported by the wrapped tree.
  int kind() const {
    return tree_->kind();
  }
  // Streams the wrapped tree to std::cout.
  void dump() const {
    std::cout << tree_;
  }

 protected:
  // Returns the i-th child of the wrapped tree, looked up with at().
  const TreeRef& subtree(size_t i) const {
    const auto& children = tree_->trees();
    return children.at(i);
  }

  TreeRef tree_;
};
template <typename T>
struct ListIterator {
ListIterator(TreeList::const_iterator it) : it(it) {}
bool operator!=(const ListIterator& rhs) const {
return it != rhs.it;
}
bool operator==(const ListIterator& rhs) const {
return it == rhs.it;
}
T operator*() const {
return T(*it);
}
ListIterator& operator+=(std::ptrdiff_t n) {
it += n;
return *this;
}
ListIterator& operator++() {
++it;
return *this;
}
ListIterator& operator--() {
--it;
return *this;
}
private:
TreeList::const_iterator it;
};
template <typename T>
struct List : public TreeView {
using iterator = ListIterator<T>;
using const_iterator = ListIterator<T>;
List(const TreeRef& tree) : TreeView(tree) {
tree->match(TK_LIST);
// Iterate over list to temporarily instantiate Ts that will check the type
for (const T& elem : *this) {
(void)elem; // silence unused warning
}
}
iterator begin() const {
return iterator(tree_->trees().begin());
}
iterator end() const {
return iterator(tree_->trees().end());
}
bool empty() const {
return tree_->trees().begin() == tree_->trees().end();
}
T operator[](size_t i) const {
return T(subtree(i));
}
TreeRef map(const std::function<TreeRef(const T&)>& fn) {
return tree_->map([&](TreeRef v) { return fn(T(v)); });
}
static List create(const SourceRange& range, const std::vector<T>& subtrees) {
TreeList type_erased_sub{subtrees.begin(), subtrees.end()};
return List(Compound::create(TK_LIST, range, std::move(type_erased_sub)));
}
static List unsafeCreate(const SourceRange& range, TreeList&& subtrees) {
return List(Compound::create(TK_LIST, range, std::move(subtrees)));
}
size_t size() const {
return tree_->trees().size();
}
};
template <typename T>
struct Maybe : public TreeView {
explicit Maybe(const TreeRef& tree) : TreeView(tree) {
tree_->match(TK_OPTION);
if (tree_->trees().size() > 1)
throw ErrorReport(tree) << "Maybe trees can have at most one subtree";
}
/* implicit */ Maybe(const T& tree) : TreeView(tree) {}
bool present() const {
return tree_->trees().size() > 0;
}
T get() const {
return T(tree_->trees().at(0));
}
TreeRef map(const std::function<TreeRef(const T&)>& fn) {
return tree_->map([&](TreeRef v) { return fn(T(v)); });
}
static Maybe<T> create(const SourceRange& range) {
return Maybe<T>(Compound::create(TK_OPTION, range, {}));
}
static Maybe<T> create(const SourceRange& range, const T& value) {
return Maybe<T>(Compound::create(TK_OPTION, range, {value}));
}
};
struct Ident : public TreeView {
explicit Ident(const TreeRef& tree) : TreeView(tree) {
tree_->match(TK_IDENT);
}
const std::string& name() const {
return subtree(0)->stringValue();
}
static Ident create(const SourceRange& range, std::string name) {
return Ident(
Compound::create(TK_IDENT, range, {String::create(std::move(name))}));
}
};
////////////////////////////////////////////////////////////////////////////////
// Base types (production LHS)
////////////////////////////////////////////////////////////////////////////////
// View over any tree whose kind is one of the statement kinds listed below.
struct Stmt : public TreeView {
  // Throws an ErrorReport when the kind of `tree` is not a statement kind.
  explicit Stmt(const TreeRef& tree) : TreeView(tree) {
    bool is_stmt_kind = false;
    switch (tree->kind()) {
      case TK_IF:
      case TK_FOR:
      case TK_WHILE:
      case TK_WITH:
      case TK_DEF:
      case TK_GLOBAL:
      case TK_ASSIGN:
      case TK_AUG_ASSIGN:
      case TK_EXPR_STMT:
      case TK_DELETE:
      case TK_RETURN:
      case TK_RAISE:
      case TK_ASSERT:
      case TK_PASS:
      case TK_BREAK:
      case TK_CONTINUE:
        is_stmt_kind = true;
        break;
      default:
        break;
    }
    if (!is_stmt_kind) {
      throw ErrorReport(tree)
          << kindToString(tree->kind()) << " is not a valid Stmt";
    }
  }
};
// View over any tree whose kind is one of the expression kinds listed below.
struct Expr : public TreeView {
  // Throws an ErrorReport when the kind of `tree` is not an expression kind.
  explicit Expr(const TreeRef& tree) : TreeView(tree) {
    switch (tree->kind()) {
      // Kinds identified by a single character token.
      case '<':
      case '>':
      case '+':
      case '-':
      case '*':
      case '/':
      case '%':
      case '@':
      case '&':
      case '^':
      case '|':
      case '~':
      case '.':
      // Kinds identified by a named TK_* constant.
      case TK_IF_EXPR:
      case TK_AND:
      case TK_OR:
      case TK_NOT:
      case TK_IS:
      case TK_ISNOT:
      case TK_IN:
      case TK_EQ:
      case TK_NE:
      case TK_LE:
      case TK_GE:
      case TK_UNARY_MINUS:
      case TK_POW:
      case TK_FLOOR_DIV:
      case TK_LSHIFT:
      case TK_RSHIFT:
      case TK_CONST:
      case TK_STRINGLITERAL:
      case TK_TRUE:
      case TK_FALSE:
      case TK_NONE:
      case TK_DOTS:
      case TK_CAST:
      case TK_APPLY:
      case TK_SUBSCRIPT:
      case TK_SLICE_EXPR:
      case TK_VAR:
      case TK_STARRED:
      case TK_LIST_LITERAL:
      case TK_TUPLE_LITERAL:
      case TK_DICT_LITERAL:
      case TK_LIST_COMP:
      case TK_DICT_COMP:
      case TK_WITH_ITEM:
        break; // accepted: nothing more to check here
      default:
        throw ErrorReport(tree)
            << kindToString(tree->kind()) << " is not a valid Expr";
    }
  }
};
////////////////////////////////////////////////////////////////////////////////
// Helper nodes (mostly for function arguments)
////////////////////////////////////////////////////////////////////////////////
struct Attribute : public TreeView {
explicit Attribute(const TreeRef& tree) : TreeView(tree) {
tree_->match(TK_ATTRIBUTE);
}
Ident name() const {
return Ident(subtree(0));
}
Expr value() const {
return Expr(subtree(1));
}
static Attribute create(
const SourceRange& range,
const Ident& name,
const TreeRef& value) {
return Attribute(Compound::create(TK_ATTRIBUTE, range, {name, value}));
}
};
Loading ...