# Repository URL to install this package:
# Version: 0.2.102
from lib2to3.pgen2 import token
from snsql._ast.expressions.date import *
from snsql._ast.expressions.string import *
from snsql._ast.expressions.numeric import *
from snsql._ast.expressions.types import CastFunction
from snsql.metadata import Metadata
from snsql.sql.parser.SqlSmallParser import SqlSmallParser
from .parser.SqlSmallLexer import SqlSmallLexer # type: ignore
from .parser.SqlSmallParser import SqlSmallParser # type: ignore
from .parser.SqlSmallVisitor import SqlSmallVisitor # type: ignore
from .parser.SqlSmallErrorListener import SyntaxErrorListener # type: ignore
from antlr4 import * # type: ignore
from snsql._ast.tokens import *
from snsql._ast.ast import *
class QueryParser:
    """Front end that turns SQL text into snsql AST nodes.

    Wraps the ANTLR-generated ``SqlSmallLexer``/``SqlSmallParser`` together
    with the visitor classes of this module.  Metadata supplied at
    construction time is used to resolve symbols whenever a call does not
    pass metadata of its own.
    """

    def __init__(self, metadata=None):
        """Create a parser.

        :param metadata: optional default metadata, anything accepted by
            ``Metadata.from_``.  Falsy values mean "no default metadata".
        """
        self.metadata = Metadata.from_(metadata) if metadata else None

    def start_parser(self, stream):
        """Build a ``SqlSmallParser`` over ``stream`` with error listeners attached.

        :param stream: an ANTLR character stream (``InputStream``/``FileStream``).
        :returns: the configured parser, ready to invoke a start rule on.
        """
        lexer = SqlSmallLexer(stream)
        tokens = CommonTokenStream(lexer)
        parser = SqlSmallParser(tokens)
        # Exact-ambiguity prediction mode is ANTLR's diagnostic mode; paired
        # with DiagnosticErrorListener it reports grammar ambiguities.
        parser._interp.predictionMode = PredictionMode.LL_EXACT_AMBIG_DETECTION
        lexer._listeners = [SyntaxErrorListener(), DiagnosticErrorListener()]
        parser._listeners = [SyntaxErrorListener(), DiagnosticErrorListener()]
        return parser

    def queries(self, query_string, metadata=None):
        """Parse a batch of one or more queries.

        :param query_string: SQL text, possibly containing several statements.
        :param metadata: optional metadata overriding the parser's default.
            It is normalised with ``Metadata.from_``.
        :returns: list of parsed query nodes.  Symbols are loaded on each one
            when metadata is available.
        """
        if metadata is None and self.metadata is not None:
            metadata = self.metadata
        elif metadata:
            metadata = Metadata.from_(metadata)
        parser = self.start_parser(InputStream(query_string))
        queries = list(BatchVisitor().visit(parser.batch()).queries)
        if metadata is not None:
            for q in queries:
                q.load_symbols(metadata)
        return queries

    def query(self, query_string, metadata=None):
        """Parse exactly one query.

        Symbol loading (with ``metadata`` normalised through
        ``Metadata.from_``) is already done by ``queries``.  It is
        deliberately not repeated here with the raw, un-normalised argument.

        :param query_string: SQL text containing a single statement.
        :param metadata: optional metadata overriding the parser's default.
        :returns: the single parsed query node.
        :raises ValueError: if the text yields more than one query, or none.
        """
        queries = self.queries(query_string, metadata)
        if len(queries) > 1:
            raise ValueError("Attempt to parse query resulted in a batch with more than one")
        if not queries:
            raise ValueError("Attempt to parse query resulted in an empty batch")
        return queries[0]

    def parse_only(self, query_string):
        """Run the parser over SQL without building an AST.

        Text that (after stripping) starts with SELECT, WITH or a ``--``
        comment is parsed directly.  Anything else is treated as a path to a
        file of SQL.  Syntax problems are reported through the listeners
        installed by ``start_parser``.

        :returns: None
        """
        head = query_string.strip().upper()
        if head.startswith(("SELECT", "WITH", "--")):
            istream = InputStream(query_string)
        else:
            istream = FileStream(query_string)
        parser = self.start_parser(istream)
        SqlSmallVisitor().visit(parser.batch())
        return None

    def parse_named_expressions(self, expression_string):
        """Parse a sequence of (optionally named) expressions.

        :returns: list of ``NamedExpression`` nodes.
        """
        parser = self.start_parser(InputStream(expression_string))
        return NamedExpressionVisitor().visitNamedExpressionSeq(parser.namedExpressionSeq())

    def parse_expression(self, expression_string):
        """Parse a single value expression into its AST node."""
        parser = self.start_parser(InputStream(expression_string))
        return ExpressionVisitor().visit(parser.expression())

    def parse_table_name(self, expression_string):
        """Parse a qualified table name and wrap its text in an ``Identifier``."""
        parser = self.start_parser(InputStream(expression_string))
        return Identifier(parser.qualifiedTableName().getText())
class BatchVisitor(SqlSmallVisitor):
    """Turn a ``batch`` parse node into a ``Batch`` of query AST nodes."""

    def visitBatch(self, ctx):
        query_visitor = QueryVisitor()
        collected = []
        for child in ctx.children:
            node = query_visitor.visit(child)
            # Children that produce no AST node (visit returns None) are skipped.
            if node is not None:
                collected.append(node)
        return Batch(collected)
class QueryVisitor(SqlSmallVisitor):
    """Build ``Query`` and ``UnifiedQuery`` nodes from query parse contexts."""

    @staticmethod
    def _visit_optional(visitor_cls, clause):
        """Visit ``clause`` with a fresh ``visitor_cls``; None when the clause is absent."""
        if clause is None:
            return None
        return visitor_cls().visit(clause)

    def visitBasicQuery(self, ctx: SqlSmallParser.BasicQueryContext):
        # The SELECT clause is mandatory; every other clause may be missing.
        select = SelectVisitor().visit(ctx.selectClause())
        source = self._visit_optional(FromVisitor, ctx.fromClause())
        where = self._visit_optional(WhereVisitor, ctx.whereClause())
        having = self._visit_optional(HavingVisitor, ctx.havingClause())
        agg = self._visit_optional(AggregateVisitor, ctx.aggregationClause())
        order = self._visit_optional(OrderVisitor, ctx.orderClause())
        # The context is probed for a limitClause accessor before use.
        limit = (
            self._visit_optional(LimitVisitor, ctx.limitClause())
            if hasattr(ctx, "limitClause")
            else None
        )
        return Query(select, source, where, agg, having, order, limit)

    def _query_operation(self, op_ctx):
        """Build one ``QueryOperation`` (UNION / INTERSECT / EXCEPT [ALL])."""
        if op_ctx.UNION():
            keyword = "UNION"
        elif op_ctx.INTERSECT():
            keyword = "INTERSECT"
        else:
            keyword = "EXCEPT"
        return QueryOperation(
            Token(keyword),
            op_ctx.ALL() is not None,
            self.visit(op_ctx.basicQuery()),
        )

    def visitUnifiedQuery(self, ctx: SqlSmallParser.UnifiedQueryContext):
        first = self.visit(ctx.basicQuery())
        operations = ctx.queryOperations()
        if not operations:
            # A plain query with no set operations is returned unwrapped.
            return first
        return UnifiedQuery(first, [self._query_operation(op) for op in operations])
class SelectVisitor(SqlSmallVisitor):
    """Build a ``Select`` node: quantifier, projected expressions and CTEs."""

    def visitSelectClause(self, ctx: SqlSmallParser.SelectClauseContext):
        ctes = None
        if ctx.WITH():
            ctes = list(NamedSubqueryVisitor().visit(ctx.namedSubquerySeq()))
        expressions = NamedExpressionVisitor().visit(ctx.namedExpressionSeq())

        set_quantifier = ctx.setQuantifier()
        if set_quantifier is None:
            quantifier = None
        else:
            top = set_quantifier.topClause()
            # A TOP clause becomes a Top node; otherwise the quantifier text
            # is kept as a plain Token.
            if top is None:
                quantifier = Token(set_quantifier.getText())
            else:
                quantifier = LimitVisitor().visit(top)

        return Select(quantifier, [e for e in expressions if e is not None], ctes)
class NamedSubqueryVisitor(SqlSmallVisitor):
    """Build ``NamedSubquery`` nodes for the entries of a WITH clause."""

    def visitNamedSubquerySeq(self, ctx):
        return list(map(self.visit, ctx.namedSubquery()))

    def visitNamedSubquery(self, ctx: SqlSmallParser.NamedSubqueryContext):
        body = QueryVisitor().visit(ctx.unifiedQuery())
        cte_name = Identifier(ctx.name.getText())
        # The first identifier is skipped; it is presumably the one bound to
        # ctx.name.  The remaining identifiers become column names.
        column_names = [Identifier(ident.getText()) for ident in ctx.identifier()[1:]]
        return NamedSubquery(cte_name, column_names, body)
class FromVisitor(SqlSmallVisitor):
    """Build a ``From`` node from the relations listed in a FROM clause."""

    def visitFromClause(self, ctx):
        visit_relation = RelationVisitor().visit
        return From(list(map(visit_relation, ctx.relation())))
class AggregateVisitor(SqlSmallVisitor):
    """Build an ``Aggregate`` node holding one ``GroupingExpression`` per group key."""

    def visitAggregationClause(self, ctx):
        grouping_contexts = ctx.groupingExpressions
        visit_expr = ExpressionVisitor().visit
        return Aggregate([GroupingExpression(visit_expr(g)) for g in grouping_contexts])
class WhereVisitor(SqlSmallVisitor):
    """Wrap the boolean condition of a WHERE clause in a ``Where`` node."""

    def visitWhereClause(self, ctx):
        condition = BooleanExpressionVisitor().visit(ctx.booleanExpression())
        return Where(condition)
class HavingVisitor(SqlSmallVisitor):
    """Wrap the boolean condition of a HAVING clause in a ``Having`` node."""

    def visitHavingClause(self, ctx):
        condition = BooleanExpressionVisitor().visit(ctx.booleanExpression())
        return Having(condition)
class NamedExpressionVisitor(SqlSmallVisitor):
    """Build ``NamedExpression`` nodes: an expression plus an optional name."""

    def visitNamedExpressionSeq(self, ctx):
        return list(map(self.visit, ctx.namedExpression()))

    def visitNamedExpression(self, ctx):
        value = ExpressionVisitor().visit(ctx.expression())
        alias = None if ctx.name is None else Identifier(ctx.name.getText())
        return NamedExpression(alias, value)
class OrderVisitor(SqlSmallVisitor):
    """Build ``Order`` / ``SortItem`` nodes for ORDER BY clauses."""

    def visitOrderClause(self, ctx):
        return Order(list(map(self.visit, ctx.sortItem())))

    def visitSortItem(self, ctx):
        sort_key = ExpressionVisitor().visit(ctx.expression())
        # Direction is None when neither DESC nor ASC was written.
        direction = None
        if ctx.DESC() is not None:
            direction = Token("DESC")
        elif ctx.ASC() is not None:
            direction = Token("ASC")
        return SortItem(sort_key, direction)
class PartitionByVisitor(SqlSmallVisitor):
    """Build a ``PartitionByClause`` from the children of a partition-by parse node."""

    def visitPartitionByClause(self, ctx):
        expression_visitor = ExpressionVisitor()
        # ctx.children may contain terminal tokens (keywords, commas).  ANTLR's
        # default visitTerminal returns None for them, so drop those results.
        # This matches the GREATEST/LEAST/CONCAT/COALESCE visitors.
        visited = [expression_visitor.visit(child) for child in ctx.children]
        return PartitionByClause([node for node in visited if node is not None])
class LimitVisitor(SqlSmallVisitor):
    """Build ``Limit`` (row count with optional offset) and ``Top`` nodes."""

    def visitLimitClause(self, ctx):
        offset = None
        if ctx.offset:
            offset = int(ctx.offset.getText())
        row_count = int(ctx.n.getText())
        return Limit(row_count, offset)

    def visitTopClause(self, ctx):
        row_count = int(ctx.n.getText())
        return Top(row_count)
class RelationVisitor(SqlSmallVisitor):
    """Build relation nodes for FROM items: tables, subqueries, joins and UNNEST."""

    @staticmethod
    def _alias(ctx):
        """Return ``ctx.alias`` as an ``Identifier``, or None when no alias was given."""
        return None if ctx.alias is None else Identifier(ctx.alias.getText())

    def visitRelation(self, ctx: SqlSmallParser.RelationContext):
        primary = self.visit(ctx.relationPrimary())
        join_contexts = ctx.joinRelation()
        joins = None if join_contexts is None else [self.visit(j) for j in join_contexts]
        return Relation(primary, joins)

    def visitUnnest(self, ctx: SqlSmallParser.UnnestContext):
        operator = ctx.unnestOperator()
        alias = self._alias(operator)
        target = Identifier(operator.qualifiedTableName().getText())
        return Unnest(target, alias)

    def visitTableSample(self, ctx):
        method = "SYSTEM" if ctx.SYSTEM() else "BERNOULLI"
        p_value = None if ctx.p is None else float(ctx.p.getText())
        n_value = None if ctx.n is None else int(ctx.n.getText())
        seed_value = None if ctx.seed is None else int(ctx.seed.getText())
        return Tablesample(method, p_value, n_value, seed_value)

    def visitTable(self, ctx):
        alias = self._alias(ctx)
        sample_ctx = ctx.tablesample()
        sample = self.visitTableSample(sample_ctx) if sample_ctx else None
        name = Identifier(ctx.qualifiedTableName().getText())
        return Table(name, alias, sample)

    def visitAliasedQuery(self, ctx):
        alias = self._alias(ctx)
        subquery = QueryVisitor().visit(ctx.unifiedQuery())
        return AliasedSubquery(subquery, alias)

    def visitAliasedRelation(self, ctx):
        alias = self._alias(ctx)
        inner = self.visit(ctx.relation())
        return AliasedRelation(inner, alias)

    def visitJoinRelation(self, ctx):
        join_type = Token(allText(ctx.joinType()))
        right = RelationVisitor().visit(ctx.right)
        criteria_ctx = ctx.joinCriteria()
        criteria_kind = type(criteria_ctx)
        if criteria_kind is SqlSmallParser.BooleanJoinContext:
            condition = BooleanExpressionVisitor().visit(criteria_ctx.booleanExpression())
            criteria = BooleanJoinCriteria(condition)
        elif criteria_kind is SqlSmallParser.UsingJoinContext:
            columns = [Column(ident.getText()) for ident in criteria_ctx.identifier()]
            criteria = UsingJoinCriteria(columns)
        else:
            # Neither a boolean (ON-style) nor a USING criterion: no criteria.
            criteria = None
        return Join(join_type, right, criteria)
class ExpressionVisitor(SqlSmallVisitor):
    """Build value-expression AST nodes from SqlSmall parse contexts.

    Each ``visitX`` method receives the parse context for grammar
    alternative ``X`` and returns the matching snsql AST node.  Boolean
    sub-expressions are delegated to ``BooleanExpressionVisitor`` and CASE
    expressions to ``CaseExpressionVisitor``.
    """

    # ------------------------------------------------------------------
    # SQL grammar: columns, literals, nesting, aggregates, windows
    # ------------------------------------------------------------------

    def visitColumnName(self, ctx):
        return Column(ctx.name.getText())

    def visitAllExpr(self, ctx):
        # The qualifier (if any) comes either from an identifier or from the
        # QN2 token; the identifier wins when both are present.
        all_ctx = ctx.allExpression()
        ident = all_ctx.identifier().getText() if all_ctx.identifier() else None
        ident_ident = all_ctx.QN2().getText() if all_ctx.QN2() else None
        return AllColumns(ident if ident else ident_ident)

    def visitDecimalLiteral(self, ctx):
        return Literal(float(allText(ctx)))

    def visitIntegerLiteral(self, ctx):
        return Literal(int(allText(ctx)))

    def visitStringLiteral(self, ctx):
        text = str(allText(ctx))
        value = text
        # Strip surrounding single quotes; anything else is kept verbatim.
        if len(text) > 1 and text[0] == "'" and text[-1] == "'":
            value = text[1:-1]
        return Literal(value, text)

    def visitRawStringLiteral(self, ctx: SqlSmallParser.RawStringLiteralContext):
        text = str(allText(ctx))
        if text.startswith(("r'", "R'")):
            # Drop the r/R prefix and the surrounding quotes; no escape
            # processing is applied to the contents.
            return Literal(value=text[2:-1], text=text)
        return Literal(text)

    def visitTrueLiteral(self, ctx):
        return Literal(True)

    def visitFalseLiteral(self, ctx):
        return Literal(False)

    def visitNullLiteral(self, ctx):
        return Literal(None)

    def visitNestedExpr(self, ctx):
        return NestedExpression(self.visit(ctx.expression()))

    def visitAggFunc(self, ctx):
        fname = FuncName(ctx.function.getText().upper())
        qt = ctx.setQuantifier()
        quantifier = Token(qt.getText().upper()) if qt is not None else None
        return AggFunction(fname, quantifier, self.visit(ctx.expression()))

    def visitSubqueryExpr(self, ctx):
        q = QueryVisitor().visit(ctx.unifiedQuery())
        return AliasedSubquery(q, None)

    def visitRankingFunction(self, ctx):
        fname = FuncName(ctx.function.getText().upper())
        over = self.visit(ctx.overClause())
        return RankingFunction(fname, over)

    def visitOverClause(self, ctx):
        pbc = ctx.partitionByClause()
        partition = PartitionByVisitor().visit(pbc) if pbc is not None else None
        oc = ctx.orderClause()
        order = OrderVisitor().visit(oc) if oc is not None else None
        return OverClause(partition, order)

    def visitLeadFunc(self, ctx: SqlSmallParser.LeadFuncContext):
        lfctx = ctx.leadFunction()
        source = ExpressionVisitor().visit(lfctx.sourceExpr)
        offset = ExpressionVisitor().visit(lfctx.offset) if lfctx.offset is not None else None
        default = ExpressionVisitor().visit(lfctx.default) if lfctx.default is not None else None
        over = self.visit(lfctx.overClause())
        return LeadFunction(source, offset, default, over)

    def visitGreatestFunc(self, ctx: SqlSmallParser.GreatestFuncContext):
        gstctx = ctx.greatestFunction()
        ev = ExpressionVisitor()
        # Terminal children (keyword, parentheses, commas) visit to None.
        expressions = [e for e in [ev.visit(c) for c in gstctx.children] if e is not None]
        return GreatestFunction(expressions)

    def visitLeastFunc(self, ctx: SqlSmallParser.LeastFuncContext):
        lstctx = ctx.leastFunction()
        ev = ExpressionVisitor()
        expressions = [e for e in [ev.visit(c) for c in lstctx.children] if e is not None]
        return LeastFunction(expressions)

    def visitNotFunc(self, ctx: SqlSmallParser.NotFuncContext):
        boolexpr = BooleanExpressionVisitor().visit(ctx.booleanExpression())
        # NOTE(review): "NotFuction" is the AST class name this module
        # star-imports; confirm the spelling against snsql._ast.
        return NotFuction(boolexpr)

    def visitUserDefFunc(self, ctx: SqlSmallParser.UserDefFuncContext):
        ev = UserDefinedArgumentVisitor()
        args = [ev.visit(arg) for arg in ctx.userDefinedArgument()]
        return UserDefinedFunction(ctx.function.getText(), args)

    # ------------------------------------------------------------------
    # Typing functions
    # ------------------------------------------------------------------

    def visitCastExpression(self, ctx: SqlSmallParser.CastExpressionContext):
        exp = ExpressionVisitor().visit(ctx.fromExpr)
        dbtype = ctx.dbType().getText().lower()
        return CastFunction(exp, dbtype)

    # ------------------------------------------------------------------
    # Logical functions
    # ------------------------------------------------------------------

    def visitCaseExpr(self, ctx):
        return CaseExpressionVisitor().visit(ctx)

    def visitIifFunc(self, ctx):
        test = BooleanExpressionVisitor().visit(ctx.test)
        yes = ExpressionVisitor().visit(ctx.yes)
        no = ExpressionVisitor().visit(ctx.no)
        return IIFFunction(test, yes, no)

    def visitChooseFunc(self, ctx):
        expression = ExpressionVisitor().visit(ctx.index)
        choices = Seq([ExpressionVisitor().visit(e) for e in ctx.literal()])
        return ChooseFunction(expression, choices)

    # ------------------------------------------------------------------
    # Numeric functions
    # ------------------------------------------------------------------

    def visitMultiply(self, ctx):
        return ArithmeticExpression(self.visit(ctx.left), Op("*"), self.visit(ctx.right))

    def visitDivide(self, ctx):
        return ArithmeticExpression(self.visit(ctx.left), Op("/"), self.visit(ctx.right))

    def visitModulo(self, ctx):
        return ArithmeticExpression(self.visit(ctx.left), Op("%"), self.visit(ctx.right))

    def visitAdd(self, ctx):
        return ArithmeticExpression(self.visit(ctx.left), Op("+"), self.visit(ctx.right))

    def visitSubtract(self, ctx):
        return ArithmeticExpression(self.visit(ctx.left), Op("-"), self.visit(ctx.right))

    def visitRoundFunction(self, ctx):
        expression = ExpressionVisitor().visit(ctx.expression())
        digits = ExpressionVisitor().visit(ctx.digits) if ctx.digits else None
        return RoundFunction(expression, digits)

    def visitTruncFunction(self, ctx):
        expression = ExpressionVisitor().visit(ctx.expression())
        digits = ExpressionVisitor().visit(ctx.digits) if ctx.digits else None
        return TruncFunction(expression, digits)

    def visitMathFunc(self, ctx):
        fname = FuncName(ctx.function.getText().upper())
        return MathFunction(fname, self.visit(ctx.expression()))

    def visitPowerFunction(self, ctx):
        return PowerFunction(
            ExpressionVisitor().visit(ctx.expExpr), ExpressionVisitor().visit(ctx.powerExpr)
        )

    def visitBareFunction(self, ctx):
        return BareFunction(FuncName(ctx.function.getText().upper()))

    # ------------------------------------------------------------------
    # String functions
    # ------------------------------------------------------------------

    def visitStringLower(self, ctx: SqlSmallParser.StringLowerContext):
        return LowerFunction(ExpressionVisitor().visit(ctx.expression()))

    def visitStringUpper(self, ctx: SqlSmallParser.StringUpperContext):
        return UpperFunction(ExpressionVisitor().visit(ctx.expression()))

    def visitTrimFunction(self, ctx: SqlSmallParser.TrimFunctionContext):
        return TrimFunction(ExpressionVisitor().visit(ctx.expression()))

    def visitCharLengthFunction(self, ctx: SqlSmallParser.CharLengthFunctionContext):
        argument = ExpressionVisitor().visit(ctx.expression())
        if ctx.LENGTH():
            # Remember that the LENGTH spelling was used.
            return CharLengthFunction(argument, Token("LENGTH"))
        return CharLengthFunction(argument)

    def visitPositionFunction(self, ctx: SqlSmallParser.PositionFunctionContext):
        return PositionFunction(
            ExpressionVisitor().visit(ctx.searchString),
            ExpressionVisitor().visit(ctx.sourceString),
        )

    def visitStringConcat(self, ctx: SqlSmallParser.StringConcatContext):
        ev = ExpressionVisitor()
        expressions = [e for e in [ev.visit(c) for c in ctx.children] if e is not None]
        return ConcatFunction(expressions)

    def visitCoalesceFunction(self, ctx: SqlSmallParser.CoalesceFunctionContext):
        ev = ExpressionVisitor()
        expressions = [e for e in [ev.visit(c) for c in ctx.children] if e is not None]
        return CoalesceFunction(expressions)

    def visitSubstringFunction(self, ctx: SqlSmallParser.SubstringFunctionContext):
        # Record whether FROM/FOR keywords or commas separated the arguments.
        tokens = {
            "FROM": Token(str(ctx.FROM())) if ctx.FROM() else Token(','),
            "FOR": Token(str(ctx.FOR())) if ctx.FOR() else Token(','),
        }
        return SubstringFunction(
            ExpressionVisitor().visit(ctx.sourceExpr),
            ExpressionVisitor().visit(ctx.startIdx),
            ExpressionVisitor().visit(ctx.length) if ctx.length else None,
            tokens,
        )

    def visitSubstringBigqueryFunction(self, ctx: SqlSmallParser.SubstringBigqueryFunctionContext):
        return SubstrBigQueryFunction(
            ExpressionVisitor().visit(ctx.sourceExpr),
            ExpressionVisitor().visit(ctx.position),
            ExpressionVisitor().visit(ctx.length) if ctx.length else None,
        )

    def visitRegexpContainsFunction(self, ctx: SqlSmallParser.RegexpContainsFunctionContext):
        return RegexpContainsFunction(
            ExpressionVisitor().visit(ctx.sourceExpr),
            ExpressionVisitor().visit(ctx.regExpr),
        )

    def visitRegexpExctractFunction(self, ctx: SqlSmallParser.RegexpExctractFunctionContext):
        # Method name (with "Exctract") must match the generated grammar label.
        return RegexpExtractFunction(
            ExpressionVisitor().visit(ctx.sourceExpr),
            ExpressionVisitor().visit(ctx.regExpr),
            ExpressionVisitor().visit(ctx.position) if ctx.position else None,
            ExpressionVisitor().visit(ctx.occurrence) if ctx.occurrence else None,
        )

    def visitRegexpReplaceFunction(self, ctx: SqlSmallParser.RegexpReplaceFunctionContext):
        return RegexpReplaceFunction(
            ExpressionVisitor().visit(ctx.sourceExpr),
            ExpressionVisitor().visit(ctx.regExpr),
            ExpressionVisitor().visit(ctx.replacementExpr),
        )

    def visitEncodeFunction(self, ctx: SqlSmallParser.EncodeFunctionContext):
        return EncodeFunction(
            ExpressionVisitor().visit(ctx.sourceString),
            ExpressionVisitor().visit(ctx.formatString),
        )

    def visitDecodeFunction(self, ctx: SqlSmallParser.DecodeFunctionContext):
        return DecodeFunction(
            ExpressionVisitor().visit(ctx.sourceString),
            ExpressionVisitor().visit(ctx.formatString),
        )

    def visitUnhexFunction(self, ctx: SqlSmallParser.UnhexFunctionContext):
        if ctx.FROM_HEX():
            # Remember that the FROM_HEX spelling was used.
            return UnhexFunction(
                ExpressionVisitor().visit(ctx.sourceString),
                token=Token("FROM_HEX"),
            )
        return UnhexFunction(ExpressionVisitor().visit(ctx.sourceString))

    def visitHexFunction(self, ctx: SqlSmallParser.HexFunctionContext):
        if ctx.TO_HEX():
            # Remember that the TO_HEX spelling was used.
            return HexFunction(
                ExpressionVisitor().visit(ctx.sourceString),
                token=Token("TO_HEX"),
            )
        return HexFunction(ExpressionVisitor().visit(ctx.sourceString))

    def visitRegexpSplitToTableFunction(self, ctx: SqlSmallParser.RegexpSplitToTableFunctionContext):
        return RegexpSplitToTableFunction(
            ExpressionVisitor().visit(ctx.sourceString),
            ExpressionVisitor().visit(ctx.pattern),
        )

    # ------------------------------------------------------------------
    # Date/time functions
    # ------------------------------------------------------------------

    def visitCurrentTimeFunc(self, ctx: SqlSmallParser.CurrentTimeFuncContext):
        """Map CURRENT_TIMESTAMP / CURRENT_TIME / CURRENT_DATE to AST nodes.

        ``getText()`` returns the keyword as written in the query.  It is
        upper-cased before comparison (as the other function visitors here
        do), so a case-insensitively lexed ``current_date`` is accepted too.

        :raises ValueError: for any other text.
        """
        raw = ctx.getText()
        func = raw.upper()
        if func == 'CURRENT_TIMESTAMP':
            return CurrentTimestampFunction()
        if func == 'CURRENT_TIME':
            return CurrentTimeFunction()
        if func == 'CURRENT_DATE':
            return CurrentDateFunction()
        raise ValueError(f"Unknown function to get current time: {raw}")

    def visitDayNameFunction(self, ctx: SqlSmallParser.DayNameFuncContext):
        return DayNameFunction(ExpressionVisitor().visit(ctx.expr))

    def visitExtractFunction(self, ctx: SqlSmallParser.ExtractFunctionContext):
        expr = ExpressionVisitor().visit(ctx.expression())
        date_part = ctx.datePart().getText()
        return ExtractFunction(date_part.lower(), expr)

    def visitFromUnixTimeFunction(self, ctx: SqlSmallParser.FromUnixTimeFunctionContext):
        unix_timestamp = ExpressionVisitor().visit(ctx.unix_timestamp)
        format_date = ExpressionVisitor().visit(ctx.formatDate) if ctx.formatDate else None
        return FromUnixTimeFunction(unix_timestamp, format_date)

    def visitUnixTimeStampFunction(self, ctx: SqlSmallParser.UnixTimeStampFunctionContext):
        expr = ExpressionVisitor().visit(ctx.expr) if ctx.expr else None
        return UnixTimeStampFunction(expr)

    def visitDateFormatFunction(self, ctx: SqlSmallParser.DateFormatFunctionContext):
        expr = ExpressionVisitor().visit(ctx.expr)
        fmt = ExpressionVisitor().visit(ctx.formatDate)
        return DateFormatFunction(expr, fmt)

    def visitFormatDateFunction(self, ctx: SqlSmallParser.FormatDateFunctionContext):
        fmt = ExpressionVisitor().visit(ctx.formatDate)
        expr = ExpressionVisitor().visit(ctx.expr)
        return FormatDateFunction(fmt, expr)

    def visitQuarterFunction(self, ctx: SqlSmallParser.QuarterFunctionContext):
        return QuarterFunction(ExpressionVisitor().visit(ctx.expr))

    def visitDowFunction(self, ctx: SqlSmallParser.DowFunctionContext):
        # Keep the spelling the query used (DOW vs DAYOFWEEK).
        token = Token("DOW") if ctx.DOW() else Token("DAYOFWEEK")
        expr = ExpressionVisitor().visit(ctx.expr)
        return DowFunction(token, expr)

    def visitDatetimeDiffFunction(self, ctx: SqlSmallParser.DatetimeDiffFunctionContext):
        order_date = ExpressionVisitor().visit(ctx.order_date)
        ship_date = ExpressionVisitor().visit(ctx.ship_date)
        date_part = ctx.datePart().getText()
        return DatetimeDiffFunction(order_date, ship_date, date_part)

    def visitDatePartFunction(self, ctx: SqlSmallParser.DatePartFunctionContext):
        date_part = ctx.datePart().getText()
        interval = ExpressionVisitor().visit(ctx.interval)
        return DatePartFunction(date_part, interval)

    def visitDateFunction(self, ctx: SqlSmallParser.DateFunctionContext):
        return DateFunction(ExpressionVisitor().visit(ctx.expr))
class CaseExpressionVisitor(SqlSmallVisitor):
    """Build ``CaseExpression`` / ``WhenExpression`` nodes.

    There are two forms:

    * The "base" form carries a base expression.  Each WHEN holds a plain
      expression (``whenBaseExpression``).
    * The "when" form has no base expression.  Each WHEN holds a boolean
      expression (``whenExpression``).
    """

    @staticmethod
    def _else_branch(ctx):
        """Visit the ELSE expression, or return None when it is absent."""
        if ctx.elseExpr is None:
            return None
        return ExpressionVisitor().visit(ctx.elseExpr)

    def visitCaseBaseExpr(self, ctx):
        branches = [self.visit(branch) for branch in ctx.whenBaseExpression()]
        base = ExpressionVisitor().visit(ctx.baseCaseExpr)
        return CaseExpression(base, branches, self._else_branch(ctx))

    def visitCaseWhenExpr(self, ctx):
        branches = [self.visit(branch) for branch in ctx.whenExpression()]
        return CaseExpression(None, branches, self._else_branch(ctx))

    def visitWhenExpression(self, ctx):
        condition = BooleanExpressionVisitor().visit(ctx.baseBoolExpr)
        result = ExpressionVisitor().visit(ctx.thenExpr)
        return WhenExpression(condition, result)

    def visitWhenBaseExpression(self, ctx):
        match_value = ExpressionVisitor().visit(ctx.baseWhenExpr)
        result = ExpressionVisitor().visit(ctx.thenExpr)
        return WhenExpression(match_value, result)
class BooleanExpressionVisitor(SqlSmallVisitor):
    """Build boolean-valued AST nodes: comparisons, AND/OR/NOT and predicates."""

    def _combine(self, ctx, keyword):
        """Join the visited left/right operands of ``ctx`` with operator ``keyword``."""
        return BooleanCompare(self.visit(ctx.left), Op(keyword), self.visit(ctx.right))

    def visitLogicalNot(self, ctx):
        return LogicalNot(self.visit(ctx.booleanExpression()))

    def visitComparison(self, ctx):
        visit_expr = ExpressionVisitor().visit
        left = visit_expr(ctx.left)
        operator = Op(ctx.op.getText())
        right = visit_expr(ctx.right)
        return BooleanCompare(left, operator, right)

    def visitConjunction(self, ctx):
        return self._combine(ctx, "AND")

    def visitDisjunction(self, ctx):
        return self._combine(ctx, "OR")

    def visitNestedBoolean(self, ctx):
        return NestedBoolean(self.visit(ctx.booleanExpression()))

    def visitPredicated(self, ctx):
        subject = ExpressionVisitor().visit(ctx.expression())
        condition = self.visit(ctx.predicate())
        return PredicatedExpression(subject, condition)

    def visitInCondition(self, ctx):
        negated = ctx.NOT() is not None
        visit_expr = ExpressionVisitor().visit
        members = Seq([visit_expr(item) for item in ctx.expression()])
        return InCondition(members, negated)

    def visitIsCondition(self, ctx):
        negated = ctx.NOT() is not None
        # Checked in order: TRUE, FALSE, NULL.
        for keyword, literal_value in ((ctx.TRUE, True), (ctx.FALSE, False), (ctx.NULL, None)):
            if keyword() is not None:
                return IsCondition(Literal(literal_value), negated)
        raise ValueError("Unknown condition in IS clause: " + allText(ctx))

    def visitBetweenCondition(self, ctx):
        negated = ctx.NOT() is not None
        low = ExpressionVisitor().visit(ctx.lower)
        high = ExpressionVisitor().visit(ctx.upper)
        return BetweenCondition(low, high, negated)

    def visitQualifiedColumnName(self, ctx):
        column = Column(ctx.getText())
        return ColumnBoolean(column)

    def visitLikeCondition(self, ctx: SqlSmallParser.LikeConditionContext):
        subject = ExpressionVisitor().visit(ctx.value)
        template = ExpressionVisitor().visit(ctx.pattern)
        return LikeExpression(subject, template)

    def visitILikeCondition(self, ctx: SqlSmallParser.ILikeConditionContext):
        subject = ExpressionVisitor().visit(ctx.value)
        template = ExpressionVisitor().visit(ctx.pattern)
        return ILikeExpression(subject, template)

    def visitBigqueryRegexpContainsCondition(self, ctx: SqlSmallParser.BigqueryRegexpContainsConditionContext):
        regexp_ctx = ctx.regexpContainsFunction()
        source = ExpressionVisitor().visit(regexp_ctx.sourceExpr)
        regex = ExpressionVisitor().visit(regexp_ctx.regExpr)
        return RegexpContainsFunction(source, regex)
class UserDefinedArgumentVisitor(SqlSmallVisitor):
    """Convert user-defined-function arguments into AST values.

    Expression arguments become expression nodes.  Date-part arguments
    become ``Token``s.  Type arguments are returned as their raw text.
    """

    def visitExprArg(self, ctx: SqlSmallParser.ExprArgContext):
        visitor = ExpressionVisitor()
        return visitor.visit(ctx.expression())

    def visitDatePartArg(self, ctx: SqlSmallParser.DatePartArgContext):
        part_text = ctx.datePart().getText()
        return Token(part_text)

    def visitTypeArg(self, ctx: SqlSmallParser.TypeArgContext):
        type_text = ctx.dbType().getText()
        return type_text
def allText(ctx):
    """Return the original source text spanned by a parse-tree context.

    ``getText()`` joins the text of child tokens.  This function instead
    slices the underlying character stream from the first character of
    ``ctx.start`` to the last character of ``ctx.stop``, so whitespace
    between tokens is preserved.

    Intended for rule contexts that have not been strongly typed yet.  It
    should not be used for lexer tokens and should be replaced over time.
    """
    first_char = ctx.start.start
    last_char = ctx.stop.stop
    source = ctx.start.getInputStream()
    return source.getText(first_char, last_char)