#pragma once
#include <c10/macros/Macros.h>
#include <c10/util/C++17.h>
#include <c10/util/Exception.h>
#include <torch/csrc/Export.h>
#include <torch/csrc/jit/frontend/parser_constants.h>
#include <torch/csrc/jit/frontend/source_range.h>
#include <torch/csrc/jit/frontend/strtod.h>
#include <algorithm>
#include <clocale>
#include <cstdlib>
#include <memory>
#include <sstream>
#include <string>
#include <vector>
C10_CLANG_DIAGNOSTIC_PUSH()
#if C10_CLANG_HAS_WARNING("-Wshorten-64-to-32")
C10_CLANG_DIAGNOSTIC_IGNORE("-Wshorten-64-to-32")
#endif
namespace torch {
namespace jit {
// Single-character tokens are represented by the character itself (e.g. '+');
// multi-character tokens need an entry in this list. If the third entry is not
// the empty string, it is the exact spelling the lexer uses to match the
// token. These kinds are also used in Tree.h as the kind of the AST node.
// Some kinds (e.g. TK_APPLY, TK_LIST) appear only in the AST and are never
// produced by the lexer.
#define TC_FORALL_TOKEN_KINDS(_) \
_(TK_EOF, "eof", "") \
_(TK_WHITESPACE, "whitespace", "") \
_(TK_WHITESPACE_EOF, "whitespace_eof", "") \
_(TK_NUMBER, "number", "") \
_(TK_NEWLINE, "newline", "") \
_(TK_INDENT, "indent", "") \
_(TK_DEDENT, "dedent", "") \
_(TK_DEF, "def", "def") \
_(TK_EQUIVALENT, "equivalent", "<=>") \
_(TK_IDENT, "ident", "") \
_(TK_STRING, "string", "") \
_(TK_STRINGLITERAL, "string_literal", "") \
_(TK_CONST, "const", "") \
_(TK_LIST, "list", "") \
_(TK_DICT, "dict", "") \
_(TK_OPTION, "option", "") \
_(TK_APPLY, "apply", "") \
_(TK_COMPREHENSION, "comprehension", "") \
_(TK_RANGE_CONSTRAINT, "range_constraint", "") \
_(TK_PARAM, "param", "") \
_(TK_INFERRED, "inferred", "") \
_(TK_ACCESS, "access", "") \
_(TK_ASSIGN, "assign", "") \
_(TK_AUG_ASSIGN, "aug_assign", "") \
_(TK_ATTRIBUTE, "attribute", "") \
_(TK_IF, "if", "if") \
_(TK_ELSE, "else", "else") \
_(TK_ELIF, "elif", "elif") \
_(TK_WHILE, "while", "while") \
_(TK_EXPR_STMT, "expression statement", "") \
_(TK_RETURN, "return", "return") \
_(TK_IS, "is", "is") \
_(TK_ISNOT, "is not", "is not") \
_(TK_NE, "ne", "!=") \
_(TK_EQ, "eq", "==") \
_(TK_LE, "le", "<=") \
_(TK_GE, "ge", ">=") \
_(TK_FLOOR_DIV, "floordiv", "//") \
_(TK_IF_EXPR, "if", "") \
_(TK_TRUE, "True", "True") \
_(TK_FALSE, "False", "False") \
_(TK_NONE, "None", "None") \
_(TK_AND, "and", "and") \
_(TK_OR, "or", "or") \
_(TK_NOT, "not", "not") \
_(TK_LSHIFT, "<<", "<<") \
_(TK_RSHIFT, ">>", ">>") \
_(TK_CAST, "cast", "") \
_(TK_PLUS_EQ, "+=", "+=") \
_(TK_MINUS_EQ, "-=", "-=") \
_(TK_TIMES_EQ, "*=", "*=") \
_(TK_DIV_EQ, "/=", "/=") \
_(TK_MOD_EQ, "%=", "%=") \
_(TK_BIT_OR_EQ, "|=", "|=") \
_(TK_BIT_AND_EQ, "&=", "&=") \
_(TK_BIT_XOR_EQ, "^=", "^=") \
_(TK_LSHIFT_EQ, "<<=", "<<=") \
_(TK_RSHIFT_EQ, ">>=", ">>=") \
_(TK_POW_EQ, "**=", "**=") \
_(TK_GLOBAL, "global", "global") \
_(TK_BUILT_IN, "built-in", "") \
_(TK_SUBSCRIPT, "subscript", "") \
_(TK_VAR, "variable", "") \
_(TK_NOTHING, "nothing", "") \
_(TK_DICT_LITERAL, "dict-literal", "") \
_(TK_LIST_LITERAL, "list-literal", "") \
_(TK_TUPLE_LITERAL, "tuple-literal", "") \
_(TK_FOR, "for", "for") \
_(TK_IN, "in", "in") \
_(TK_NOTIN, "not in", "not in") \
_(TK_STARRED, "starred", "") \
_(TK_UNARY_MINUS, "unary minus", "") \
_(TK_POW, "pow operator", "**") \
_(TK_ARROW, "arrow", "->") \
_(TK_DECL, "decl", "") \
_(TK_SLICE_EXPR, "slice expr", "") \
_(TK_TYPE_COMMENT, "type comment", "# type:") \
_(TK_RAISE, "raise", "raise") \
_(TK_ASSERT, "assert", "assert") \
_(TK_DOTS, "dots", "...") \
_(TK_LIST_COMP, "list comprehension", "") \
_(TK_DICT_COMP, "dict comprehension", "") \
_(TK_BREAK, "break", "break") \
_(TK_CONTINUE, "continue", "continue") \
_(TK_DELETE, "del", "del") \
_(TK_PASS, "pass", "pass") \
_(TK_CLASS_DEF, "class", "class") \
_(TK_IMPORT, "import", "import") \
_(TK_WITH, "with", "with") \
_(TK_WITH_ITEM, "withitem", "") \
_(TK_AS, "as", "as") \
_(TK_PROP, "property", "") \
_(TK_ELLIPSIS, "Ellipsis", "Ellipsis") \
_(TK_NONE_TYPE, "NoneType", "NoneType")
enum TokenKind {
  // Single-character tokens use their own char value as the kind, so the
  // multi-character token kinds must start past the last valid char (255).
  TK_DUMMY_START = 256,
#define DEFINE_TOKEN(tok, _, _2) tok,
  TC_FORALL_TOKEN_KINDS(DEFINE_TOKEN)
#undef DEFINE_TOKEN
};
TORCH_API std::string kindToString(int kind);
TORCH_API int stringToKind(const std::string& str);
// A trie that records, character by character, which strings form valid
// tokens.
struct TokenTrie;
using TokenTrieRef = std::unique_ptr<TokenTrie>;
struct TokenTrie {
TokenTrie() : kind(0) {}
void insert(const char* str, int tok) {
if (*str == '\0') {
AT_ASSERT(kind == 0);
kind = tok;
return;
}
for (size_t i = 0, e = child_chars.size(); i < e; ++i) {
if (child_chars[i] == *str) {
child_tries[i]->insert(str + 1, tok);
return;
}
}
child_chars.emplace_back(*str);
child_tries.emplace_back(std::make_unique<TokenTrie>());
child_tries.back()->insert(str + 1, tok);
}
int kind; // 0 == invalid token
std::vector<char> child_chars;
std::vector<TokenTrieRef> child_tries;
};
// Data shared across all TC lexers/parsers; initialized only once.
struct TORCH_API SharedParserData {
  // Builds the token trie: every valid single-character token maps to its
  // own character value, and every multi-character token whose third
  // TC_FORALL_TOKEN_KINDS entry is non-empty maps to its TokenKind.
  SharedParserData() : head(std::make_unique<TokenTrie>()) {
    for (const char* c = valid_single_char_tokens; *c; c++) {
      std::string str(1, *c);
      head->insert(str.c_str(), *c);
    }

#define ADD_CASE(tok, _, tokstring)   \
  if (*(tokstring) != '\0') {         \
    head->insert((tokstring), (tok)); \
  }
    TC_FORALL_TOKEN_KINDS(ADD_CASE)
#undef ADD_CASE
  }

  // Scan for a single token starting at `pos`.
  //
  // On success returns true and sets:
  //   *kind  - the TokenKind (or plain char value) of the matched token
  //   *start - iterator at the first character of the token
  //   *end   - iterator one past the last character of the token
  //
  // Leading blanks, comments (except "# type:" comments) and escaped
  // newlines are skipped via tail-recursive calls.
  bool match(
      StringCordView::Iterator pos,
      bool continuation, // are we inside a scope where newlines don't count
                         // (e.g. inside parens)
      bool whitespace_token, // should we treat whitespace as a token
      int* kind,
      StringCordView::Iterator* start,
      StringCordView::Iterator* end) {
    *start = pos;
    // skip whitespace
    while (pos.has_next() && isblank(*pos)) {
      ++pos;
    }

    // special handling
    if (pos.has_next()) {
      if (*pos == '#' && !isTypeComment(pos)) {
        // skip comments ("# type:" comments are real tokens, handled below)
        while (pos.has_next() && *pos != '\n')
          ++pos;
        // tail call, handle whitespace and more comments
        return match(pos, continuation, whitespace_token, kind, start, end);
      }
      if (*pos == '\\') {
        // backslash immediately before a newline is an explicit line
        // continuation: skip both and keep scanning on the next line
        auto newiter = pos;
        ++newiter;
        if (newiter.has_next() && *newiter == '\n' && !whitespace_token) {
          ++newiter;
          return match(newiter, continuation, false, kind, start, end);
        }
      }
      if (*pos == '\n') {
        // outside of brackets a newline turns the following whitespace run
        // into a TK_WHITESPACE token (used for indent tracking); inside
        // brackets (continuation) it is plain whitespace
        return match(++pos, continuation, !continuation, kind, start, end);
      }
    }
    // we handle whitespace before EOF because we need to generate dedent
    // tokens when code such as
    //   if foo:
    //     ...
    //   else:
    //     pass
    // ends the source
    if (whitespace_token) {
      *kind = !pos.has_next() ? TK_WHITESPACE_EOF : TK_WHITESPACE;
      *end = pos;
      return true;
    }
    if (!pos.has_next()) {
      *kind = TK_EOF;
      *start = pos;
      *end = *start;
      return true;
    }
    // invariant: the next token is not whitespace or newline
    *start = pos;
    // check for a valid number
    size_t len;
    if (isNumber(pos.rest_line(), 0, &len)) {
      *end = *start;
      *end += len;
      *kind = TK_NUMBER;
      return true;
    }
    // check for string
    if (isString(pos.rest_line(), 0, &len)) {
      *kind = TK_STRINGLITERAL;
      *end = *start;
      *end += len;
      return true;
    }

    // check for either an ident or a token
    // ident tracks whether what we have scanned so far could be an identifier
    // matched indicates if we have found any match.
    bool matched = false;
    bool ident = true;
    TokenTrie* cur = head.get();
    for (size_t i = 0; pos.has_next() && (ident || cur != nullptr);
         ++pos, ++i) {
      ident = ident && validIdent(i, *pos);
      if (ident) {
        matched = true;
        *end = pos.next_iter();
        *kind = TK_IDENT;
      }
      // check for token second, so that e.g. 'max' matches the token TK_MAX
      // rather than the identifier 'max'
      if (cur) {
        const auto begin_it = cur->child_chars.begin();
        const auto end_it = cur->child_chars.end();
        const auto ch_it = std::find(begin_it, end_it, *pos);

        cur = (ch_it == end_it) ? nullptr
                                : cur->child_tries[ch_it - begin_it].get();

        if (cur && cur->kind != 0) {
          matched = true;
          *end = pos.next_iter();
          *kind = cur->kind;
        }
      }
    }
    return matched;
  }

  bool isUnary(int kind, int* prec);
  bool isBinary(int kind, int* prec);

  // '?', '**' and conditional (if) expressions group right-to-left.
  bool isRightAssociative(int kind) {
    switch (kind) {
      case '?':
      case TK_POW:
      case TK_IF:
        return true;
      default:
        return false;
    }
  }

 private:
  // An identifier is [a-zA-Z_][a-zA-Z0-9_]*.
  // NOTE(review): `n` is a plain char, so bytes >= 0x80 would reach
  // isalpha/isdigit as negative values; source is assumed ASCII — confirm.
  bool validIdent(size_t i, char n) {
    return isalpha(n) || n == '_' || (i > 0 && isdigit(n));
  }

  // Matches a numeric literal beginning at str[start]; on success stores its
  // length (including a trailing 'j' for complex literals) in *len.
  bool isNumber(c10::string_view str, size_t start, size_t* len) {
    char first = str[start];
    // strtod allows numbers to start with + or - or nan or inf
    // http://en.cppreference.com/w/cpp/string/byte/strtof
    // but we want only the number part, otherwise 1+3 will turn into two
    // adjacent numbers in the lexer
    if (first == '-' || first == '+' || isalpha(first))
      return false;
    const char* startptr = str.data() + start;
    // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
    char* endptr;
    torch::jit::strtod_c(startptr, &endptr);
    *len = endptr - startptr;
    // check if the number is complex valued
    // access is safe because string is assumed to be null terminated
    if (endptr != nullptr && *endptr == 'j') {
      *len += 1;
    }
    return *len > 0;
  }

  // True if str[start, start + len) is in range and consists solely of `c`.
  bool isCharCount(char c, c10::string_view str, size_t start, int len) {
    // count checks from [start, start + len)
    return start + len <= str.size() &&
        std::count(str.begin() + start, str.begin() + start + len, c) == len;
  }

  // python concatenates all adjacent strings "a" "b" == "ab"
  // strings can be enclosed with 1 or 3 single or double quotes
  // if enclosed with 3 quotes newlines are valid
  // as elsewhere, backslash and new line should be ignored
  bool isString(c10::string_view str, size_t start, size_t* len) {
    char quote = str[start];
    if (quote != '\"' && quote != '\'')
      return false;
    int quote_len = isCharCount(quote, str, start, 3) ? 3 : 1;

    // end is now set past the opening quotation marks
    size_t end = start + quote_len;
    while (end < str.size() && !isCharCount(quote, str, end, quote_len)) {
      if (str[end] == '\n' && quote_len != 3) {
        return false;
      }
      // handle escaped characters. advances past escaped quotation marks,
      // escaped newlines and escaped backslashes
      // multi-char escapes like \x1A are handled fine here because the
      // remainder of the escape are valid string characters anyway
      if (str[end] == '\\') {
        end++;
      }
      end++;
    }
    // set length equal to the complete string including quotations
    *len = end - start + quote_len;
    // if end finished without going past the last character of the string
    // then there is a match (the closing quote was found)
    return end < str.size();
  }

  // Any whitespace except newline; '\n' is significant (it may terminate a
  // statement) so it is handled separately in match().
  bool isblank(int n) {
    return isspace(n) && n != '\n';
  }

  // True if the rest of the line at str_iter begins with "# type:", i.e. a
  // type-annotation comment that must become a TK_TYPE_COMMENT token rather
  // than being skipped.
  bool isTypeComment(StringCordView::Iterator str_iter) {
    c10::string_view rest_line = str_iter.rest_line();
    const std::string type_string = "# type:";
    if (rest_line.size() < type_string.length()) {
      return false;
    }
    auto match_string = rest_line.substr(0, type_string.size());
    return match_string == type_string;
  }

  // Make an exception ignoring comments for type annotation comments
  bool isTypeComment(StringCordView str, size_t pos) {
    const std::string type_string = "# type:";
    if (str.size() < pos + type_string.length()) {
      return false;
    }
    auto match_string = str.substr(pos, type_string.size());
    return match_string == type_string;
  }

  // Root of the token trie built by the constructor.
  TokenTrieRef head;
};
TORCH_API SharedParserData& sharedParserData();
// A single lexed token: its kind (a TokenKind, or the char value itself for
// single-character tokens) plus the source range it was taken from.
struct Token {
  int kind;
  SourceRange range;
  Token(int kind, SourceRange range) : kind(kind), range(std::move(range)) {}
  // Text of the token exactly as it appears in the source.
  // (const added: only reads `range`, matching kindString() below)
  std::string text() const {
    return std::string(range.token_text());
  }
  // Human-readable name of `kind` (see kindToString).
  std::string kindString() const {
    return kindToString(kind);
  }
};
// Tokenizes a Source into a stream of Tokens, synthesizing
// TK_INDENT / TK_DEDENT / TK_NEWLINE tokens from whitespace and suppressing
// newlines while inside (...) / [...] / {...} nesting.
struct Lexer {
  explicit Lexer(std::shared_ptr<Source> source)
      : source(std::move(source)),
        pos(0),
        nesting(0),
        indent_stack(),
        next_tokens(),
        shared(sharedParserData()) {
    // The first whitespace run establishes the initial indentation level;
    // the whole program is allowed to be uniformly indented.
    auto first_indent = lexRaw(true);
    indent_stack.push_back(first_indent.range.size());
    lex();
  }
  // Return the current token, and then move to the next one
  Token next() {
    if (next_tokens.empty())
      reportError("Lexer invariant violated: empty token queue");
    Token r = std::move(next_tokens.front());
    next_tokens.erase(next_tokens.begin());
    if (next_tokens.empty()) {
      lex();
    }
    return r;
  }
  // Skip the current token if it matches the given kind
  bool nextIf(int kind) {
    if (cur().kind != kind)
      return false;
    next();
    return true;
  }

  [[noreturn]] void reportError(const std::string& what) {
    reportError(what, cur());
  }
  [[noreturn]] void reportError(const std::string& what, const Token& t) {
    std::stringstream ss;
    ss << what << ":\n";
    t.range.highlight(ss);
    throw std::runtime_error(ss.str());
  }
  [[noreturn]] void expected(const std::string& what, const Token& t) {
    std::stringstream ss;
    ss << "expected " << what << " but found '" << t.kindString()
       << "' here:\n";
    t.range.highlight(ss);
    throw std::runtime_error(ss.str());
  }
  [[noreturn]] void expected(const std::string& what) {
    expected(what, cur());
  }
  // Check that the current token has a given kind, return the current token,
  // and advance to the next one.
  Token expect(int kind) {
    if (cur().kind != kind) {
      expected(kindToString(kind));
    }
    return next();
  }
  // Peek one token past the current one without consuming anything.
  Token& lookahead() {
    if (next_tokens.size() < 2) {
      lex();
    }
    return next_tokens[1];
  }
  // The current (not yet consumed) token.
  Token& cur() {
    return next_tokens.front();
  }

 private:
  // Lex one raw token and translate whitespace bookkeeping into
  // TK_INDENT / TK_DEDENT / TK_NEWLINE tokens on the queue.
  void lex() {
    auto r = lexRaw();
    switch (r.kind) {
      case '(':
      case '[':
      case '{':
        nesting++;
        break;
      case ')':
      case ']':
      case '}':
        nesting--;
        break;
      case TK_WHITESPACE:
      case TK_WHITESPACE_EOF: {
        const auto depth = static_cast<int64_t>(
            r.kind == TK_WHITESPACE_EOF ? indent_stack.front()
                                        : r.range.size());
        // note: TK_WHITESPACE_EOF is whitespace right before the EOF token
        // just like we allow the code to be indented to a particular initial
        // indent level, we allow the final indent to be anything and set
        // it back to the initial indent level. This allows the code to be
        // put into string literals inside code without worrying about final
        // whitespace
        if (depth > indent_stack.back()) {
          indent_stack.push_back(depth);
          r.kind = TK_INDENT;
        } else if (depth == indent_stack.back()) {
          r.kind = TK_NEWLINE;
        } else {
          // Dedenting: emit one TK_DEDENT per level popped; the new depth
          // must exactly match some level already on the stack.
          next_tokens.emplace_back(TK_NEWLINE, r.range);
          while (indent_stack.back() != depth) {
            indent_stack.pop_back();
            next_tokens.emplace_back(TK_DEDENT, r.range);
            if (indent_stack.empty()) {
              // std::to_string replaces the deprecated c10::guts::to_string
              reportError("invalid indent level " + std::to_string(depth), r);
            }
          }
          return; // We've already queued the tokens
        }
      } break;
      default:
        break;
    }
    next_tokens.push_back(std::move(r));
  }
  // Lex a single token straight from the source, with no indent handling.
  // When whitespace_token is true, leading whitespace is reported as a
  // TK_WHITESPACE / TK_WHITESPACE_EOF token instead of being skipped.
  Token lexRaw(bool whitespace_token = false) {
    // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
    int kind;
    AT_ASSERT(source);
    if (current == nullptr) {
      // Lazily created on first use; pos == 0 guarantees we have not lexed
      // anything yet.
      AT_ASSERT(pos == 0);
      current = std::make_unique<StringCordView::Iterator>(
          source->text_str().begin());
    }

    StringCordView::Iterator start_iter = *current;
    StringCordView::Iterator end_iter = *current;

    if (!shared.match(
            *current,
            nesting > 0,
            whitespace_token,
            &kind,
            &start_iter,
            &end_iter)) {
      expected(
          "a valid token",
          Token(
              **current,
              SourceRange(source, start_iter, start_iter.pos() + 1)));
    }

    auto t = Token(kind, SourceRange(source, start_iter, end_iter.pos()));
    pos = end_iter.pos();
    *current = end_iter;
    return t;
  }

  std::shared_ptr<Source> source;
  // Iterator into source->text_str(); created lazily on the first lexRaw().
  std::unique_ptr<StringCordView::Iterator> current;
  size_t pos; // absolute offset of the next character to lex
  size_t nesting; // depth of ( [ { nesting...
  std::vector<int> indent_stack; // stack of indentation level of blocks
  // Invariant: this should always contain at least a single element
  std::vector<Token> next_tokens;
  SharedParserData& shared;
};
} // namespace jit
} // namespace torch
C10_CLANG_DIAGNOSTIC_POP()