Why Gemfury? Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

edgify / torch   python

Repository URL to install this package:

Version: 2.0.1+cpu 

/ include / torch / csrc / jit / frontend / lexer.h

#pragma once
#include <c10/macros/Macros.h>
#include <c10/util/C++17.h>
#include <c10/util/Exception.h>
#include <torch/csrc/Export.h>
#include <torch/csrc/jit/frontend/parser_constants.h>
#include <torch/csrc/jit/frontend/source_range.h>
#include <torch/csrc/jit/frontend/strtod.h>
#include <algorithm>
#include <cctype>
#include <clocale>
#include <cstdlib>
#include <memory>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>

C10_CLANG_DIAGNOSTIC_PUSH()
#if C10_CLANG_HAS_WARNING("-Wshorten-64-to-32")
C10_CLANG_DIAGNOSTIC_IGNORE("-Wshorten-64-to-32")
#endif

namespace torch {
namespace jit {

// Single-character tokens use the character itself as their kind (e.g. '+').
// Multi-character tokens need an entry here of the form
// _(enum name, name string, lexer spelling). If the third entry is not the
// empty string, the lexer uses it to match this token.

// These kinds are also used in Tree.h as the kind of the AST node.
// Some kinds, such as TK_APPLY and TK_LIST, are only used in the AST and are
// never produced by the lexer.

// X-macro table of every multi-character token / AST node kind.
// Each entry is _(ENUM_NAME, "name string", "lexer spelling"); a caller defines
// a macro, passes it as `_`, and it is expanded once per entry. Entries with an
// empty lexer spelling are never matched directly by the lexer.
// NOTE(review): TK_IF and TK_IF_EXPR share the name string "if", so any
// name -> kind lookup cannot tell them apart — confirm this is intended.
// (Comments cannot be placed inside the backslash-continued lines below.)
#define TC_FORALL_TOKEN_KINDS(_)                 \
  _(TK_EOF, "eof", "")                           \
  _(TK_WHITESPACE, "whitespace", "")             \
  _(TK_WHITESPACE_EOF, "whitespace_eof", "")     \
  _(TK_NUMBER, "number", "")                     \
  _(TK_NEWLINE, "newline", "")                   \
  _(TK_INDENT, "indent", "")                     \
  _(TK_DEDENT, "dedent", "")                     \
  _(TK_DEF, "def", "def")                        \
  _(TK_EQUIVALENT, "equivalent", "<=>")          \
  _(TK_IDENT, "ident", "")                       \
  _(TK_STRING, "string", "")                     \
  _(TK_STRINGLITERAL, "string_literal", "")      \
  _(TK_CONST, "const", "")                       \
  _(TK_LIST, "list", "")                         \
  _(TK_DICT, "dict", "")                         \
  _(TK_OPTION, "option", "")                     \
  _(TK_APPLY, "apply", "")                       \
  _(TK_COMPREHENSION, "comprehension", "")       \
  _(TK_RANGE_CONSTRAINT, "range_constraint", "") \
  _(TK_PARAM, "param", "")                       \
  _(TK_INFERRED, "inferred", "")                 \
  _(TK_ACCESS, "access", "")                     \
  _(TK_ASSIGN, "assign", "")                     \
  _(TK_AUG_ASSIGN, "aug_assign", "")             \
  _(TK_ATTRIBUTE, "attribute", "")               \
  _(TK_IF, "if", "if")                           \
  _(TK_ELSE, "else", "else")                     \
  _(TK_ELIF, "elif", "elif")                     \
  _(TK_WHILE, "while", "while")                  \
  _(TK_EXPR_STMT, "expression statement", "")    \
  _(TK_RETURN, "return", "return")               \
  _(TK_IS, "is", "is")                           \
  _(TK_ISNOT, "is not", "is not")                \
  _(TK_NE, "ne", "!=")                           \
  _(TK_EQ, "eq", "==")                           \
  _(TK_LE, "le", "<=")                           \
  _(TK_GE, "ge", ">=")                           \
  _(TK_FLOOR_DIV, "floordiv", "//")              \
  _(TK_IF_EXPR, "if", "")                        \
  _(TK_TRUE, "True", "True")                     \
  _(TK_FALSE, "False", "False")                  \
  _(TK_NONE, "None", "None")                     \
  _(TK_AND, "and", "and")                        \
  _(TK_OR, "or", "or")                           \
  _(TK_NOT, "not", "not")                        \
  _(TK_LSHIFT, "<<", "<<")                       \
  _(TK_RSHIFT, ">>", ">>")                       \
  _(TK_CAST, "cast", "")                         \
  _(TK_PLUS_EQ, "+=", "+=")                      \
  _(TK_MINUS_EQ, "-=", "-=")                     \
  _(TK_TIMES_EQ, "*=", "*=")                     \
  _(TK_DIV_EQ, "/=", "/=")                       \
  _(TK_MOD_EQ, "%=", "%=")                       \
  _(TK_BIT_OR_EQ, "|=", "|=")                    \
  _(TK_BIT_AND_EQ, "&=", "&=")                   \
  _(TK_BIT_XOR_EQ, "^=", "^=")                   \
  _(TK_LSHIFT_EQ, "<<=", "<<=")                  \
  _(TK_RSHIFT_EQ, ">>=", ">>=")                  \
  _(TK_POW_EQ, "**=", "**=")                     \
  _(TK_GLOBAL, "global", "global")               \
  _(TK_BUILT_IN, "built-in", "")                 \
  _(TK_SUBSCRIPT, "subscript", "")               \
  _(TK_VAR, "variable", "")                      \
  _(TK_NOTHING, "nothing", "")                   \
  _(TK_DICT_LITERAL, "dict-literal", "")         \
  _(TK_LIST_LITERAL, "list-literal", "")         \
  _(TK_TUPLE_LITERAL, "tuple-literal", "")       \
  _(TK_FOR, "for", "for")                        \
  _(TK_IN, "in", "in")                           \
  _(TK_NOTIN, "not in", "not in")                \
  _(TK_STARRED, "starred", "")                   \
  _(TK_UNARY_MINUS, "unary minus", "")           \
  _(TK_POW, "pow operator", "**")                \
  _(TK_ARROW, "arrow", "->")                     \
  _(TK_DECL, "decl", "")                         \
  _(TK_SLICE_EXPR, "slice expr", "")             \
  _(TK_TYPE_COMMENT, "type comment", "# type:")  \
  _(TK_RAISE, "raise", "raise")                  \
  _(TK_ASSERT, "assert", "assert")               \
  _(TK_DOTS, "dots", "...")                      \
  _(TK_LIST_COMP, "list comprehension", "")      \
  _(TK_DICT_COMP, "dict comprehension", "")      \
  _(TK_BREAK, "break", "break")                  \
  _(TK_CONTINUE, "continue", "continue")         \
  _(TK_DELETE, "del", "del")                     \
  _(TK_PASS, "pass", "pass")                     \
  _(TK_CLASS_DEF, "class", "class")              \
  _(TK_IMPORT, "import", "import")               \
  _(TK_WITH, "with", "with")                     \
  _(TK_WITH_ITEM, "withitem", "")                \
  _(TK_AS, "as", "as")                           \
  _(TK_PROP, "property", "")                     \
  _(TK_ELLIPSIS, "Ellipsis", "Ellipsis")         \
  _(TK_NONE_TYPE, "NoneType", "NoneType")

// Kinds of lexed tokens and AST nodes.
// Codes 0-255 are left free so that a single-character token can use its own
// character code as its kind; the multi-character kinds from
// TC_FORALL_TOKEN_KINDS are numbered consecutively after TK_DUMMY_START (256).
enum TokenKind {
  TK_DUMMY_START = 256,
#define TC_TOKEN_ENUMERATOR(enumerator_, name_, spelling_) enumerator_,
  TC_FORALL_TOKEN_KINDS(TC_TOKEN_ENUMERATOR)
#undef TC_TOKEN_ENUMERATOR
};

// Conversions between token kinds and their string names; both are defined
// out of line (not visible in this header).
// NOTE(review): presumably these map to/from the second (name) entry of
// TC_FORALL_TOKEN_KINDS, with single-character kinds mapping to the character
// itself — confirm against the implementation.
TORCH_API std::string kindToString(int kind);
TORCH_API int stringToKind(const std::string& str);

// Character trie used to recognise fixed-spelling tokens one character at a
// time. Each node stores the kind of the token whose spelling ends exactly at
// that node (0 when none does) and its children, keyed by the next character
// (child_chars[i] leads to child_tries[i]).
struct TokenTrie;
using TokenTrieRef = std::unique_ptr<TokenTrie>;
struct TokenTrie {
  TokenTrie() : kind(0) {}

  // Registers the NUL-terminated spelling `str` as token kind `tok`, creating
  // any missing nodes along the way. Asserts that no token was already
  // registered with exactly this spelling.
  void insert(const char* str, int tok) {
    TokenTrie* node = this;
    for (const char* p = str; *p != '\0'; ++p) {
      const char ch = *p;
      auto& chars = node->child_chars;
      const auto found = std::find(chars.begin(), chars.end(), ch);
      if (found != chars.end()) {
        node = node->child_tries[found - chars.begin()].get();
        continue;
      }
      chars.push_back(ch);
      node->child_tries.push_back(std::make_unique<TokenTrie>());
      node = node->child_tries.back().get();
    }
    AT_ASSERT(node->kind == 0);
    node->kind = tok;
  }
  int kind; // 0 == invalid token

  std::vector<char> child_chars;
  std::vector<TokenTrieRef> child_tries;
};

// Lexing tables shared by all TorchScript lexers/parsers; initialized only
// once.
//
// `head` is a TokenTrie holding every character of valid_single_char_tokens
// (with the character itself as its kind) plus every TC_FORALL_TOKEN_KINDS
// entry that has a non-empty lexer spelling.
struct TORCH_API SharedParserData {
  SharedParserData() : head(std::make_unique<TokenTrie>()) {
    for (const char* c = valid_single_char_tokens; *c; c++) {
      std::string str(1, *c);
      head->insert(str.c_str(), *c);
    }

#define ADD_CASE(tok, _, tokstring)   \
  if (*(tokstring) != '\0') {         \
    head->insert((tokstring), (tok)); \
  }
    TC_FORALL_TOKEN_KINDS(ADD_CASE)
#undef ADD_CASE
  }

  // Finds the token that starts at or after `pos`.
  //
  // Blanks, comments other than "# type:" comments, and backslash-newline
  // continuations are skipped. Each '\n' is consumed; outside a continuation
  // scope it also turns on `whitespace_token`, so that the following line's
  // indentation is returned as a TK_WHITESPACE (or TK_WHITESPACE_EOF) token.
  //
  // On success returns true and sets *kind and the token extent
  // [*start, *end). Returns false when nothing valid starts at the first
  // character that was not skipped.
  bool match(
      StringCordView::Iterator pos,
      bool continuation, // are we inside a scope where newlines don't count
                         // (e.g. inside parens)
      bool whitespace_token, // should we treat whitespace as a token
      int* kind,
      StringCordView::Iterator* start,
      StringCordView::Iterator* end) {
    // Skip blanks, comments, escaped newlines and newlines. This is a loop
    // rather than self-recursion so that a long run of blank or comment lines
    // cannot exhaust the stack.
    while (true) {
      *start = pos;
      while (pos.has_next() && isblank(*pos)) {
        ++pos;
      }
      if (!pos.has_next()) {
        break;
      }
      if (*pos == '#' && !isTypeComment(pos)) {
        // Comment: skip up to (but not past) the end of the line.
        while (pos.has_next() && *pos != '\n') {
          ++pos;
        }
        continue;
      }
      if (*pos == '\\') {
        auto newiter = pos;
        ++newiter;
        if (newiter.has_next() && *newiter == '\n' && !whitespace_token) {
          // Explicit line continuation: the newline is not significant.
          ++newiter;
          pos = newiter;
          whitespace_token = false;
          continue;
        }
      }
      if (*pos == '\n') {
        ++pos;
        whitespace_token = !continuation;
        continue;
      }
      break;
    }

    // Whitespace is reported even right before EOF: the lexer needs it to emit
    // the dedent tokens that close a still-open indented block, e.g.
    //   if foo:
    //     ...
    //   else:
    //     pass
    if (whitespace_token) {
      *kind = !pos.has_next() ? TK_WHITESPACE_EOF : TK_WHITESPACE;
      *end = pos;
      return true;
    }
    if (!pos.has_next()) {
      *kind = TK_EOF;
      *start = pos;
      *end = *start;
      return true;
    }
    // invariant: the next token is not whitespace or newline
    *start = pos;
    size_t len = 0;
    // check for a valid number
    if (isNumber(pos.rest_line(), 0, &len)) {
      *end = *start;
      *end += len;
      *kind = TK_NUMBER;
      return true;
    }
    // check for string
    if (isString(pos.rest_line(), 0, &len)) {
      *kind = TK_STRINGLITERAL;
      *end = *start;
      *end += len;
      return true;
    }

    // Check for either an identifier or a fixed-spelling token; every later
    // match overrides the earlier one, so the longest match wins.
    // `ident` tracks whether what we have scanned so far could be an
    // identifier, `cur` is the trie node reached so far (nullptr once no token
    // can match any more), and `matched` records whether anything matched.
    bool matched = false;
    bool ident = true;
    TokenTrie* cur = head.get();
    for (size_t i = 0; pos.has_next() && (ident || cur != nullptr);
         ++pos, ++i) {
      ident = ident && validIdent(i, *pos);
      if (ident) {
        matched = true;
        *end = pos.next_iter();
        *kind = TK_IDENT;
      }
      // Check for a token second, so that at equal length a keyword wins over
      // an identifier: e.g. "if" lexes as TK_IF, not as the identifier "if".
      if (cur) {
        const auto begin_it = cur->child_chars.begin();
        const auto end_it = cur->child_chars.end();
        const auto ch_it = std::find(begin_it, end_it, *pos);

        cur = (ch_it == end_it) ? nullptr
                                : cur->child_tries[ch_it - begin_it].get();

        if (cur && cur->kind != 0 && endsAtTokenBoundary(pos)) {
          matched = true;
          *end = pos.next_iter();
          *kind = cur->kind;
        }
      }
    }
    return matched;
  }

  bool isUnary(int kind, int* prec);
  bool isBinary(int kind, int* prec);
  bool isRightAssociative(int kind) {
    switch (kind) {
      case '?':
      case TK_POW:
      case TK_IF:
        return true;
      default:
        return false;
    }
  }

 private:
  // True if `n` may appear at position `i` of an identifier: a letter or '_',
  // or (after the first character) a digit. The char is converted to
  // unsigned char first because passing a negative value (any non-ASCII byte
  // when char is signed) to <cctype> functions is undefined behavior.
  bool validIdent(size_t i, char n) {
    const auto c = static_cast<unsigned char>(n);
    return std::isalpha(c) || n == '_' || (i > 0 && std::isdigit(c));
  }

  // A fixed-spelling token whose last character (at `last`) could be part of
  // an identifier only ends here if the following character cannot continue an
  // identifier. Without this, multi-word tokens would swallow the start of the
  // next identifier: "not inputs" would lex as "not in" + "puts" and
  // "is nothing" as "is not" + "hing".
  bool endsAtTokenBoundary(StringCordView::Iterator last) {
    if (!validIdent(1, *last)) {
      return true;
    }
    auto following = last.next_iter();
    return !following.has_next() || !validIdent(1, *following);
  }

  // Returns true if a number, as parsed by strtod_c and optionally followed by
  // a 'j' (complex) suffix, starts at str[start]; its length goes to *len.
  bool isNumber(c10::string_view str, size_t start, size_t* len) {
    if (start >= str.size()) {
      return false;
    }
    const char first = str[start];
    // strtod allows numbers to start with + or - or nan or inf
    // http://en.cppreference.com/w/cpp/string/byte/strtof
    // but we want only the number part, otherwise 1+3 will turn into two
    // adjacent numbers in the lexer
    if (first == '-' || first == '+' ||
        std::isalpha(static_cast<unsigned char>(first))) {
      return false;
    }
    const char* startptr = str.data() + start;
    char* endptr = nullptr;
    torch::jit::strtod_c(startptr, &endptr);
    if (endptr == nullptr) {
      return false;
    }
    *len = endptr - startptr;
    // check if the number is complex valued
    // access is safe because string is assumed to be null terminated
    if (*endptr == 'j') {
      *len += 1;
    }
    return *len > 0;
  }

  // True if the `len` characters in [start, start + len) all exist and all
  // equal `c`.
  bool isCharCount(char c, c10::string_view str, size_t start, int len) {
    return start + len <= str.size() &&
        std::count(str.begin() + start, str.begin() + start + len, c) == len;
  }

  // python concatenates all adjacent strings "a" "b" == "ab"
  // strings can be enclosed with 1 or 3 single or double quotes
  // if enclosed with 3 quotes newlines are valid
  // as elsewhere, backslash and new line should be ignored
  // On a match, *len is the length of the literal including its quotes.
  bool isString(c10::string_view str, size_t start, size_t* len) {
    if (start >= str.size()) {
      return false;
    }
    char quote = str[start];
    if (quote != '\"' && quote != '\'')
      return false;
    int quote_len = isCharCount(quote, str, start, 3) ? 3 : 1;

    // end is now set past the opening quotation marks
    size_t end = start + quote_len;
    while (end < str.size() && !isCharCount(quote, str, end, quote_len)) {
      if (str[end] == '\n' && quote_len != 3) {
        return false;
      }
      // handle escaped characters. advances past escaped quotation marks,
      // escaped newlines and escaped backslashes
      // multi-char escapes like \x1A are handled fine here because the
      // remainder of the escape are valid string characters anyway
      if (str[end] == '\\') {
        end++;
      }
      end++;
    }
    // set length equal to the complete string including quotations
    *len = end - start + quote_len;
    // if end finished without going past the last character of the string than
    // there is a match
    return end < str.size();
  }

  // Any whitespace except '\n', which is significant to the lexer.
  bool isblank(int n) {
    return std::isspace(static_cast<unsigned char>(n)) && n != '\n';
  }

  // Make an exception ignoring comments for type annotation comments:
  // true if the text at `str_iter` starts with "# type:".
  bool isTypeComment(StringCordView::Iterator str_iter) {
    c10::string_view rest_line = str_iter.rest_line();
    const std::string type_string = "# type:";
    if (rest_line.size() < type_string.length()) {
      return false;
    }
    auto match_string = rest_line.substr(0, type_string.size());
    return match_string == type_string;
  }

  // Make an exception ignoring comments for type annotation comments:
  // true if `str` contains "# type:" starting at offset `pos`.
  bool isTypeComment(StringCordView str, size_t pos) {
    const std::string type_string = "# type:";
    if (str.size() < pos + type_string.length()) {
      return false;
    }
    auto match_string = str.substr(pos, type_string.size());
    return match_string == type_string;
  }

  TokenTrieRef head;
};

// Returns the SharedParserData instance that every Lexer binds to in its
// constructor. Defined out of line (not visible in this header).
TORCH_API SharedParserData& sharedParserData();

// A lexed token: its kind (a TokenKind value, or the character code of a
// single-character token) together with the source range it spans.
struct Token {
  int kind;
  SourceRange range;
  Token(int kind, SourceRange range) : kind(kind), range(std::move(range)) {}

  // Copy of the source text covered by this token.
  std::string text() {
    std::string spelled(range.token_text());
    return spelled;
  }

  // Name of this token's kind, as produced by kindToString().
  std::string kindString() const {
    std::string name = kindToString(kind);
    return name;
  }
};

// Turns a Source into a stream of Tokens for the parser.
//
// On top of the raw tokens found by SharedParserData::match, the lexer tracks
// bracket nesting (newlines inside brackets are not significant) and
// indentation: the whitespace after a significant newline becomes TK_INDENT,
// TK_NEWLINE, or TK_NEWLINE followed by one or more TK_DEDENT tokens,
// depending on how its width compares with the enclosing blocks.
struct Lexer {
  explicit Lexer(std::shared_ptr<Source> source)
      : source(std::move(source)),
        pos(0),
        nesting(0),
        indent_stack(),
        next_tokens(),
        shared(sharedParserData()) {
    // The width of the input's leading whitespace is the base indent level.
    auto first_indent = lexRaw(true);
    indent_stack.push_back(first_indent.range.size());
    lex();
  }
  // Return the current token, and then move to the next one
  Token next() {
    if (next_tokens.empty()) {
      // Do not go through reportError(what): it calls cur(), i.e. front() on
      // the empty queue, which is undefined behavior.
      throw std::runtime_error("Lexer invariant violated: empty token queue");
    }
    Token r = std::move(next_tokens.front());
    next_tokens.erase(next_tokens.begin());
    if (next_tokens.empty()) {
      lex();
    }
    return r;
  }
  // Skip the current token if it matches the given kind
  bool nextIf(int kind) {
    if (cur().kind != kind)
      return false;
    next();
    return true;
  }

  [[noreturn]] void reportError(const std::string& what) {
    reportError(what, cur());
  }
  [[noreturn]] void reportError(const std::string& what, const Token& t) {
    std::stringstream ss;
    ss << what << ":\n";
    t.range.highlight(ss);
    throw std::runtime_error(ss.str());
  }
  [[noreturn]] void expected(const std::string& what, const Token& t) {
    std::stringstream ss;
    ss << "expected " << what << " but found '" << t.kindString()
       << "' here:\n";
    t.range.highlight(ss);
    throw std::runtime_error(ss.str());
  }
  [[noreturn]] void expected(const std::string& what) {
    expected(what, cur());
  }
  // Check that the current token has a given kind, return the current token,
  // and advance to the next one.
  Token expect(int kind) {
    if (cur().kind != kind) {
      expected(kindToString(kind));
    }
    return next();
  }
  // The token after the current one (lexing it if needed).
  Token& lookahead() {
    if (next_tokens.size() < 2) {
      lex();
    }
    return next_tokens[1];
  }
  Token& cur() {
    return next_tokens.front();
  }

 private:
  // Lexes one raw token and appends it to next_tokens, updating bracket
  // nesting and converting whitespace tokens into indentation tokens.
  void lex() {
    auto r = lexRaw();
    switch (r.kind) {
      case '(':
      case '[':
      case '{':
        nesting++;
        break;
      case ')':
      case ']':
      case '}':
        // Never let the unsigned counter wrap around on an unbalanced closer:
        // a wrapped (huge) value would make every later newline look like it
        // is inside brackets.
        if (nesting > 0) {
          nesting--;
        }
        break;
      case TK_WHITESPACE:
      case TK_WHITESPACE_EOF: {
        const auto depth = static_cast<int64_t>(
            r.kind == TK_WHITESPACE_EOF ? indent_stack.front()
                                        : r.range.size());
        // note: TK_WHITESPACE_EOF is whitespace right before the EOF token
        // just like we allow the code to be indented to a particular initial
        // indent level, we allow the final indent to be anything and set
        // it back to the initial indent level. This allows the code to be
        // put into string literals inside code without worrying about final
        // whitespace
        if (depth > indent_stack.back()) {
          indent_stack.push_back(depth);
          r.kind = TK_INDENT;
        } else if (depth == indent_stack.back()) {
          r.kind = TK_NEWLINE;
        } else {
          next_tokens.emplace_back(TK_NEWLINE, r.range);
          while (indent_stack.back() != depth) {
            indent_stack.pop_back();
            next_tokens.emplace_back(TK_DEDENT, r.range);
            if (indent_stack.empty()) {
              reportError(
                  "invalid indent level " + c10::guts::to_string(depth), r);
            }
          }
          return; // We've already queued the tokens
        }
      } break;
      default:
        break;
    }
    next_tokens.push_back(std::move(r));
  }
  // Matches the next raw token at the current position and advances past it.
  // Reports "expected a valid token" if nothing matches.
  Token lexRaw(bool whitespace_token = false) {
    // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
    int kind;
    AT_ASSERT(source);
    if (current == nullptr) {
      AT_ASSERT(pos == 0);
      current = std::make_unique<StringCordView::Iterator>(
          source->text_str().begin());
    }

    StringCordView::Iterator start_iter = *current;
    StringCordView::Iterator end_iter = *current;
    if (!shared.match(
            *current,
            nesting > 0,
            whitespace_token,
            &kind,
            &start_iter,
            &end_iter)) {
      expected(
          "a valid token",
          Token(
              **current,
              SourceRange(source, start_iter, start_iter.pos() + 1)));
    }

    auto t = Token(kind, SourceRange(source, start_iter, end_iter.pos()));
    pos = end_iter.pos();
    *current = end_iter;
    return t;
  }

  std::shared_ptr<Source> source;
  std::unique_ptr<StringCordView::Iterator> current;
  size_t pos;
  size_t nesting; // depth of ( [ { nesting...
  std::vector<int> indent_stack; // stack of indentation level of blocks
  // Invariant: this should always contain at least a single element
  std::vector<Token> next_tokens;
  SharedParserData& shared;
};
} // namespace jit
} // namespace torch

C10_CLANG_DIAGNOSTIC_POP()