Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

neilisaac / torch   python

Repository URL to install this package:

Version: 1.8.0 

/ include / torch / csrc / jit / frontend / tree_views.h

#pragma once
#include <c10/util/string_utils.h>
#include <torch/csrc/jit/frontend/error_report.h>
#include <torch/csrc/jit/frontend/strtod.h>
#include <torch/csrc/jit/frontend/tree.h>

#include <functional>
#include <iostream>
#include <string>

namespace torch {
namespace jit {

// clang-format off
// TreeView provides a statically-typed way to traverse the tree, which should
// be formed according to the grammar below.
//
// A few notes on types and their aliases:
// - List<T> is really a Tree with kind TK_LIST and elements as subtrees
// - Maybe<T> is really a Tree with kind TK_OPTION that has 0 or 1 subtree of type T
// - Builtin types are: Ident (TK_IDENT), String (TK_STRING)
//
// Param = Param(Maybe<Expr> type, Ident name)                          TK_PARAM
//
// Decl  = Decl(List<Param> params, Maybe<Expr> return_type)            TK_DECL
// Def   = Def(Ident name, Decl decl, List<Stmt> body)                  TK_DEF
// ClassDef = ClassDef(Ident name,                                      TK_CLASS_DEF
//                     Maybe<Expr> superclass,
//                     List<Stmt> body)
//
// Stmt  = If(Expr cond, List<Stmt> true_body, List<Stmt> false_body)   TK_IF
//       | For(List<Expr> targets, List<Expr> iters, List<Stmt> body)   TK_FOR
//       | While(Expr cond, List<Stmt> body)                            TK_WHILE
//       | Global(List<Ident> idents)                                   TK_GLOBAL
//       -- NB: the only kinds of Expr allowed on the lhs are Var,
//          or a tuple containing Vars with an optional terminating Starred
//       | Assign(Expr lhs, Maybe<Expr> rhs, Maybe<Expr> type)          TK_ASSIGN
//       | AugAssign(Expr lhs, AugAssignKind aug_op, Expr rhs)          TK_AUG_ASSIGN
//       | Return(List<Expr> values)                                    TK_RETURN
//       | ExprStmt(List<Expr> expr)                                    TK_EXPR_STMT
//       | Raise(Expr expr)                                             TK_RAISE
//       | Def                                                          TK_DEF
//       | With(List<WithItem> targets, List<Stmt> body)                TK_WITH
//
// Expr  = TernaryIf(Expr cond, Expr true_expr, Expr false_expr)        TK_IF_EXPR
//       | BinOp(Expr lhs, Expr rhs)
//       |     And                                                      TK_AND
//       |     Or                                                       TK_OR
//       |     Lt                                                       '<'
//       |     Gt                                                       '>'
//       |     Eq                                                       TK_EQ
//       |     Le                                                       TK_LE
//       |     Ge                                                       TK_GE
//       |     Ne                                                       TK_NE
//       |     Is                                                       TK_IS
//       |     IsNot                                                    TK_ISNOT
//       |     Add                                                      '+'
//       |     Sub                                                      '-'
//       |     Mul                                                      '*'
//       |     Div                                                      '/'
//       |     Mod                                                      '%'
//       |     MatMult                                                  '@'
//       |     Pow                                                      TK_POW
//       | UnaryOp(Expr expr)
//       |     Not                                                      TK_NOT
//       |     USub                                                     '-'
//       | Const(String value)                                          TK_CONST
//       -- NB: x.name(y) is desugared into name(x, y)
//       | Apply(Ident name, List<Expr> args, List<Attribute> kwargs)   TK_APPLY
//       | Select(Expr value, Ident selector)                           '.'
//       | Subscript(Expr value, List<Expr> subscript_exprs)            TK_SUBSCRIPT
//       | SliceExpr(Maybe<Expr> start, Maybe<Expr> end)                TK_SLICE_EXPR
//       | Var(Ident name)                                              TK_VAR
//       | ListLiteral(List<Expr> inputs)                               TK_LIST_LITERAL
//       | TupleLiteral(List<Expr> inputs)                              TK_TUPLE_LITERAL
//       | Starred(Expr expr)                                           TK_STARRED
//       | WithItem(Expr target, Maybe<Var> var)                        TK_WITH_ITEM
// -- NB: only allowed expressions are Const or List(Const)
//        (List as a value, not type constructor)
// Attribute = Attribute(Ident name, Expr value)                        TK_ATTRIBUTE
//
// AugAssignKind =
//            | Add()                                                   TK_PLUS_EQ
//            | Sub()                                                   TK_MINUS_EQ
//            | Mul()                                                   TK_TIMES_EQ
//            | Div()                                                   TK_DIV_EQ
//            | Mod()                                                   TK_MOD_EQ
//

// Each subclass of TreeView should provide:
// 1. Constructor that takes a TreeRef, and checks that it's of the right type.
// 2. Accessors that get underlying information out of the object. If they
//    return subtrees, they should wrap them in appropriate views too.
// 3. Static method 'create' that creates the underlying TreeRef object.
//    For every TreeRef kind that has a TreeView, the parser always uses
//    (e.g.) Ident::create rather than Compound::create. This means that
//    changes to the structure of Ident are always made right here rather
//    than both in the parser and in this code.
// XXX: these structs should have no fields to prevent slicing when passing by value
// clang-format on
struct TreeView {
  explicit TreeView(TreeRef tree) : tree_(std::move(tree)) {}
  TreeRef tree() const {
    return tree_;
  }
  const SourceRange& range() const {
    return tree_->range();
  }
  operator TreeRef() const {
    return tree_;
  }
  const TreeRef& get() const {
    return tree_;
  }
  int kind() const {
    return tree_->kind();
  }
  void dump() const {
    std::cout << tree_;
  }

 protected:
  const TreeRef& subtree(size_t i) const {
    return tree_->trees().at(i);
  }
  TreeRef tree_;
};

template <typename T>
struct ListIterator {
  ListIterator(TreeList::const_iterator it) : it(it) {}
  bool operator!=(const ListIterator& rhs) const {
    return it != rhs.it;
  }
  bool operator==(const ListIterator& rhs) const {
    return it == rhs.it;
  }
  T operator*() const {
    return T(*it);
  }
  ListIterator& operator+=(std::ptrdiff_t n) {
    it += n;
    return *this;
  }
  ListIterator& operator++() {
    ++it;
    return *this;
  }
  ListIterator& operator--() {
    --it;
    return *this;
  }

 private:
  TreeList::const_iterator it;
};

template <typename T>
struct List : public TreeView {
  using iterator = ListIterator<T>;
  using const_iterator = ListIterator<T>;

  List(const TreeRef& tree) : TreeView(tree) {
    tree->match(TK_LIST);
    // Iterate over list to temporarily instantiate Ts that will check the type
    for (const T& elem : *this) {
      (void)elem; // silence unused warning
    }
  }
  iterator begin() const {
    return iterator(tree_->trees().begin());
  }
  iterator end() const {
    return iterator(tree_->trees().end());
  }
  bool empty() const {
    return tree_->trees().begin() == tree_->trees().end();
  }
  T operator[](size_t i) const {
    return T(subtree(i));
  }
  TreeRef map(const std::function<TreeRef(const T&)>& fn) {
    return tree_->map([&](TreeRef v) { return fn(T(v)); });
  }
  static List create(const SourceRange& range, const std::vector<T>& subtrees) {
    TreeList type_erased_sub{subtrees.begin(), subtrees.end()};
    return List(Compound::create(TK_LIST, range, std::move(type_erased_sub)));
  }
  static List unsafeCreate(const SourceRange& range, TreeList&& subtrees) {
    return List(Compound::create(TK_LIST, range, std::move(subtrees)));
  }
  size_t size() const {
    return tree_->trees().size();
  }
};

template <typename T>
struct Maybe : public TreeView {
  explicit Maybe(const TreeRef& tree) : TreeView(tree) {
    tree_->match(TK_OPTION);
    if (tree_->trees().size() > 1)
      throw ErrorReport(tree) << "Maybe trees can have at most one subtree";
  }
  /* implicit */ Maybe(const T& tree) : TreeView(tree) {}
  bool present() const {
    return tree_->trees().size() > 0;
  }
  T get() const {
    return T(tree_->trees().at(0));
  }
  TreeRef map(const std::function<TreeRef(const T&)>& fn) {
    return tree_->map([&](TreeRef v) { return fn(T(v)); });
  }
  static Maybe<T> create(const SourceRange& range) {
    return Maybe<T>(Compound::create(TK_OPTION, range, {}));
  }
  static Maybe<T> create(const SourceRange& range, const T& value) {
    return Maybe<T>(Compound::create(TK_OPTION, range, {value}));
  }
};

struct Ident : public TreeView {
  explicit Ident(const TreeRef& tree) : TreeView(tree) {
    tree_->match(TK_IDENT);
  }
  const std::string& name() const {
    return subtree(0)->stringValue();
  }
  static Ident create(const SourceRange& range, std::string name) {
    return Ident(
        Compound::create(TK_IDENT, range, {String::create(std::move(name))}));
  }
};

////////////////////////////////////////////////////////////////////////////////
// Base types (production LHS)
////////////////////////////////////////////////////////////////////////////////

// View over any tree whose kind is one of the statement kinds listed below.
// Construction throws an ErrorReport for every other kind.
struct Stmt : public TreeView {
  explicit Stmt(const TreeRef& tree) : TreeView(tree) {
    const auto node_kind = tree->kind();
    switch (node_kind) {
      // Branching, looping and other flow-related statements.
      case TK_IF:
      case TK_FOR:
      case TK_WHILE:
      case TK_WITH:
      case TK_RETURN:
      case TK_RAISE:
      case TK_ASSERT:
      case TK_PASS:
      case TK_BREAK:
      case TK_CONTINUE:
      // Statements that introduce, update or remove names.
      case TK_GLOBAL:
      case TK_ASSIGN:
      case TK_AUG_ASSIGN:
      case TK_DELETE:
      // Expression statements and nested definitions.
      case TK_EXPR_STMT:
      case TK_DEF:
        break; // accepted: nothing more to validate at this level
      default:
        throw ErrorReport(tree)
            << kindToString(node_kind) << " is not a valid Stmt";
    }
  }
};

// View over any tree whose kind is one of the expression kinds listed below.
// Construction throws an ErrorReport for every other kind.
struct Expr : public TreeView {
  explicit Expr(const TreeRef& tree) : TreeView(tree) {
    const auto node_kind = tree->kind();
    switch (node_kind) {
      // Conditional and boolean forms.
      case TK_IF_EXPR:
      case TK_AND:
      case TK_OR:
      case TK_NOT:
      // Comparison-style binary forms.
      case '<':
      case '>':
      case TK_EQ:
      case TK_NE:
      case TK_LE:
      case TK_GE:
      case TK_IS:
      case TK_ISNOT:
      case TK_IN:
      // Arithmetic forms.
      case '+':
      case '-':
      case TK_UNARY_MINUS:
      case '*':
      case '/':
      case TK_FLOOR_DIV:
      case '%':
      case '@':
      case TK_POW:
      // Bit-level operator forms.
      case '~':
      case '&':
      case '^':
      case '|':
      case TK_LSHIFT:
      case TK_RSHIFT:
      // Constants, literals and comprehensions.
      case TK_CONST:
      case TK_STRINGLITERAL:
      case TK_TRUE:
      case TK_FALSE:
      case TK_NONE:
      case TK_LIST_LITERAL:
      case TK_TUPLE_LITERAL:
      case TK_DICT_LITERAL:
      case TK_LIST_COMP:
      case TK_DICT_COMP:
      // Remaining accepted kinds.
      case TK_DOTS:
      case TK_CAST:
      case TK_APPLY:
      case '.':
      case TK_SUBSCRIPT:
      case TK_SLICE_EXPR:
      case TK_VAR:
      case TK_STARRED:
      case TK_WITH_ITEM:
        break; // accepted: nothing more to validate at this level
      default:
        throw ErrorReport(tree)
            << kindToString(node_kind) << " is not a valid Expr";
    }
  }
};

////////////////////////////////////////////////////////////////////////////////
// Helper nodes (mostly for function arguments)
////////////////////////////////////////////////////////////////////////////////

struct Attribute : public TreeView {
  explicit Attribute(const TreeRef& tree) : TreeView(tree) {
    tree_->match(TK_ATTRIBUTE);
  }
  Ident name() const {
    return Ident(subtree(0));
  }
  Expr value() const {
    return Expr(subtree(1));
  }
  static Attribute create(
      const SourceRange& range,
      const Ident& name,
      const TreeRef& value) {
    return Attribute(Compound::create(TK_ATTRIBUTE, range, {name, value}));
  }
};
Loading ...