Why Gemfury? Push, build, and install RubyGems, npm packages, Python packages, Maven artifacts, PHP packages, Go Modules, Debian packages, RPM packages, and NuGet packages.

Repository URL to install this package:

Details    
smartnoise-sql / snsql / sql / parse.py
Size: Mime:
from lib2to3.pgen2 import token
from snsql._ast.expressions.date import *
from snsql._ast.expressions.string import *
from snsql._ast.expressions.numeric import *
from snsql._ast.expressions.types import CastFunction
from snsql.metadata import Metadata
from snsql.sql.parser.SqlSmallParser import SqlSmallParser

from .parser.SqlSmallLexer import SqlSmallLexer  # type: ignore
from .parser.SqlSmallParser import SqlSmallParser  # type: ignore
from .parser.SqlSmallVisitor import SqlSmallVisitor  # type: ignore
from .parser.SqlSmallErrorListener import SyntaxErrorListener  # type: ignore

from antlr4 import *  # type: ignore
from snsql._ast.tokens import *
from snsql._ast.ast import *


class QueryParser:
    """
    Entry point for turning SQL text into snsql AST objects.

    Wraps the generated SqlSmall ANTLR lexer/parser and the visitor classes
    in this module.  When metadata is available (given to the constructor or
    passed per call), parsed queries have their symbols loaded against it.
    """
    def __init__(self, metadata=None):
        # Falsy metadata means "no metadata"; anything else is normalized
        # through Metadata.from_.
        self.metadata = Metadata.from_(metadata) if metadata else None

    def start_parser(self, stream):
        """
        Build a SqlSmallParser reading from ``stream``.

        Prediction mode is set to LL_EXACT_AMBIG_DETECTION, and the lexer's and
        parser's listeners are replaced with SyntaxErrorListener and
        DiagnosticErrorListener.
        """
        lexer = SqlSmallLexer(stream)
        parser = SqlSmallParser(CommonTokenStream(lexer))
        parser._interp.predictionMode = PredictionMode.LL_EXACT_AMBIG_DETECTION
        lexer._listeners = [SyntaxErrorListener(), DiagnosticErrorListener()]
        parser._listeners = [SyntaxErrorListener(), DiagnosticErrorListener()]
        return parser

    def queries(self, query_string, metadata=None):
        """
        Parse ``query_string`` as a batch and return its queries as a list.

        ``metadata`` overrides the parser's own metadata.  If it is truthy, it
        is normalized with Metadata.from_.  Whenever the resulting metadata is
        not None, every query gets ``load_symbols(metadata)``.
        """
        if metadata is None and self.metadata is not None:
            metadata = self.metadata
        elif metadata:
            metadata = Metadata.from_(metadata)

        parser = self.start_parser(InputStream(query_string))
        batch = BatchVisitor().visit(parser.batch())
        queries = list(batch.queries)
        if metadata is not None:
            for q in queries:
                q.load_symbols(metadata)
        return queries

    def query(self, query_string, metadata=None):
        """
        Parse ``query_string``, which must contain exactly one query.

        Symbol loading is left to ``queries``.  That method already loads
        symbols with the normalized metadata whenever ``metadata`` (or the
        parser's own metadata) is set.  Loading again here would repeat the
        work using the caller's raw, un-normalized metadata.

        Raises:
            ValueError: if the text yields more than one query, or none.
        """
        queries = self.queries(query_string, metadata)
        if len(queries) > 1:
            raise ValueError("Attempt to parse query resulted in a batch with more than one")
        if not queries:
            raise ValueError("Attempt to parse query resulted in an empty batch")
        return queries[0]

    def parse_only(self, query_string):
        """
        Run the parser over SQL text for syntax checking only; returns None.

        After stripping whitespace, text that starts with SELECT or WITH
        (case-insensitive) or with a ``--`` comment is parsed directly.
        Anything else is treated as a file path and read via FileStream.
        Problems are reported through the listeners installed by
        ``start_parser``.
        """
        stripped = query_string.strip()
        if stripped.upper().startswith(("SELECT", "WITH")) or stripped.startswith("--"):
            istream = InputStream(query_string)
        else:
            istream = FileStream(query_string)
        parser = self.start_parser(istream)
        SqlSmallVisitor().visit(parser.batch())
        return None

    def parse_named_expressions(self, expression_string):
        """Parse a comma-separated named-expression sequence into a list of NamedExpression."""
        istream = InputStream(expression_string)
        parser = self.start_parser(istream)
        nev = NamedExpressionVisitor()
        return nev.visitNamedExpressionSeq(parser.namedExpressionSeq())

    def parse_expression(self, expression_string):
        """Parse a single scalar expression into its AST node."""
        istream = InputStream(expression_string)
        parser = self.start_parser(istream)
        ev = ExpressionVisitor()
        return ev.visit(parser.expression())

    def parse_table_name(self, expression_string):
        """Parse a (possibly qualified) table name and return it as an Identifier."""
        istream = InputStream(expression_string)
        parser = self.start_parser(istream)
        return Identifier(parser.qualifiedTableName().getText())


class BatchVisitor(SqlSmallVisitor):
    """Convert a parsed ``batch`` rule into a Batch of query AST nodes."""
    def visitBatch(self, ctx):
        # Every child is handed to a QueryVisitor.  Children that produce no
        # node (presumably terminals such as statement separators) come back
        # as None and are skipped.
        query_visitor = QueryVisitor()
        visited = (query_visitor.visit(child) for child in ctx.children)
        return Batch([node for node in visited if node is not None])


class QueryVisitor(SqlSmallVisitor):
    """Build Query and UnifiedQuery AST nodes from query parse contexts."""

    def visitBasicQuery(self, ctx: SqlSmallParser.BasicQueryContext):
        """Assemble a Query from the clauses of one SELECT statement."""

        def optional(clause_ctx, visitor_cls):
            # Every clause except SELECT may be missing from the parse tree.
            return None if clause_ctx is None else visitor_cls().visit(clause_ctx)

        # SELECT is required
        select = SelectVisitor().visit(ctx.selectClause())
        source = optional(ctx.fromClause(), FromVisitor)
        where = optional(ctx.whereClause(), WhereVisitor)
        having = optional(ctx.havingClause(), HavingVisitor)
        agg = optional(ctx.aggregationClause(), AggregateVisitor)
        order = optional(ctx.orderClause(), OrderVisitor)
        # The context is probed with hasattr before limitClause() is called.
        limit = optional(ctx.limitClause(), LimitVisitor) if hasattr(ctx, "limitClause") else None
        return Query(select, source, where, agg, having, order, limit)

    def visitUnifiedQuery(self, ctx: SqlSmallParser.UnifiedQueryContext):
        """Return the first query alone, or a UnifiedQuery when set operations follow it."""
        head = self.visit(ctx.basicQuery())
        operations = ctx.queryOperations()
        if not operations:
            return head

        def operator_token(op_ctx):
            # UNION and INTERSECT are checked explicitly; anything else is EXCEPT.
            if op_ctx.UNION():
                return Token("UNION")
            if op_ctx.INTERSECT():
                return Token("INTERSECT")
            return Token("EXCEPT")

        query_ops = [
            QueryOperation(
                operator_token(op_ctx),
                op_ctx.ALL() is not None,
                self.visit(op_ctx.basicQuery()),
            )
            for op_ctx in operations
        ]
        return UnifiedQuery(head, query_ops)

class SelectVisitor(SqlSmallVisitor):
    """Build the Select node: quantifier, projected expressions, and optional CTEs."""
    def visitSelectClause(self, ctx: SqlSmallParser.SelectClauseContext):
        # Named subqueries introduced by WITH, if present.
        ctes = None
        if ctx.WITH():
            ctes = list(NamedSubqueryVisitor().visit(ctx.namedSubquerySeq()))
        named_expressions = NamedExpressionVisitor().visit(ctx.namedExpressionSeq())

        # A set quantifier is either a TOP clause (converted by LimitVisitor)
        # or any other quantifier, which is kept as a Token of its text.
        set_quantifier = ctx.setQuantifier()
        top_clause = None if set_quantifier is None else set_quantifier.topClause()
        if top_clause is not None:
            quantifier = LimitVisitor().visit(top_clause)
        elif set_quantifier is not None:
            quantifier = Token(set_quantifier.getText())
        else:
            quantifier = None

        columns = [expr for expr in named_expressions if expr is not None]
        return Select(quantifier, columns, ctes)


class NamedSubqueryVisitor(SqlSmallVisitor):
    """Build NamedSubquery nodes for the subqueries listed after WITH."""
    def visitNamedSubquerySeq(self, ctx):
        return list(map(self.visit, ctx.namedSubquery()))

    def visitNamedSubquery(self, ctx: SqlSmallParser.NamedSubqueryContext):
        body = QueryVisitor().visit(ctx.unifiedQuery())
        subquery_name = Identifier(ctx.name.getText())
        # Identifiers after the first are used as column names.  The first is
        # presumably the one labelled `name` — TODO confirm against the grammar.
        column_names = [Identifier(ident.getText()) for ident in ctx.identifier()[1:]]
        return NamedSubquery(subquery_name, column_names, body)


class FromVisitor(SqlSmallVisitor):
    """Build the From node from the relations listed in a FROM clause."""
    def visitFromClause(self, ctx):
        relation_visitor = RelationVisitor()
        return From(list(map(relation_visitor.visit, ctx.relation())))


class AggregateVisitor(SqlSmallVisitor):
    """Build the Aggregate node, wrapping each grouping expression."""
    def visitAggregationClause(self, ctx):
        expression_visitor = ExpressionVisitor()
        return Aggregate([
            GroupingExpression(expression_visitor.visit(grouping))
            for grouping in ctx.groupingExpressions
        ])


class WhereVisitor(SqlSmallVisitor):
    """Wrap the WHERE clause's boolean condition in a Where node."""
    def visitWhereClause(self, ctx):
        condition = BooleanExpressionVisitor().visit(ctx.booleanExpression())
        return Where(condition)


class HavingVisitor(SqlSmallVisitor):
    """Wrap the HAVING clause's boolean condition in a Having node."""
    def visitHavingClause(self, ctx):
        condition = BooleanExpressionVisitor().visit(ctx.booleanExpression())
        return Having(condition)


class NamedExpressionVisitor(SqlSmallVisitor):
    """Build NamedExpression nodes: an expression plus an optional name."""
    def visitNamedExpressionSeq(self, ctx):
        return list(map(self.visit, ctx.namedExpression()))

    def visitNamedExpression(self, ctx):
        expression = ExpressionVisitor().visit(ctx.expression())
        name_token = ctx.name
        name = None if name_token is None else Identifier(name_token.getText())
        return NamedExpression(name, expression)


class OrderVisitor(SqlSmallVisitor):
    """Build Order and SortItem nodes for ORDER BY clauses."""
    def visitOrderClause(self, ctx):
        return Order(list(map(self.visit, ctx.sortItem())))

    def visitSortItem(self, ctx):
        expr = ExpressionVisitor().visit(ctx.expression())
        # DESC is checked before ASC; when neither keyword is present the
        # direction is None.
        direction = next(
            (Token(keyword) for keyword in ("DESC", "ASC")
             if getattr(ctx, keyword)() is not None),
            None,
        )
        return SortItem(expr, direction)


class PartitionByVisitor(SqlSmallVisitor):
    """Build the PartitionByClause of a window OVER (...) specification."""
    def visitPartitionByClause(self, ctx):
        # Every child of the rule is visited.  Terminal children (keywords or
        # ',' separators, if the rule has any) visit to None under ANTLR's
        # default visitTerminal.  Drop those so PartitionByClause receives only
        # expression nodes, matching the GREATEST/LEAST/CONCAT/COALESCE visitors.
        ev = ExpressionVisitor()
        visited = [ev.visit(c) for c in ctx.children]
        return PartitionByClause([e for e in visited if e is not None])

class LimitVisitor(SqlSmallVisitor):
    """Build row-count nodes: Limit (count ``n`` with optional ``offset``) and Top."""
    def visitLimitClause(self, ctx):
        # The offset is converted first, then the row count, as before.
        offset = None
        if ctx.offset:
            offset = int(ctx.offset.getText())
        row_count = int(ctx.n.getText())
        return Limit(row_count, offset)

    def visitTopClause(self, ctx):
        row_count = int(ctx.n.getText())
        return Top(row_count)


class RelationVisitor(SqlSmallVisitor):
    """Build FROM-clause relation nodes: tables, subqueries, joins and UNNEST."""

    @staticmethod
    def _alias_of(ctx):
        # Optional `alias` label -> Identifier, or None when it is absent.
        return None if ctx.alias is None else Identifier(ctx.alias.getText())

    def visitRelation(self, ctx: SqlSmallParser.RelationContext):
        primary = self.visit(ctx.relationPrimary())
        join_ctxs = ctx.joinRelation()
        joins = None if join_ctxs is None else [self.visit(join_ctx) for join_ctx in join_ctxs]
        return Relation(primary, joins)

    def visitUnnest(self, ctx: SqlSmallParser.UnnestContext):
        unnest_ctx = ctx.unnestOperator()
        alias = self._alias_of(unnest_ctx)
        target = Identifier(unnest_ctx.qualifiedTableName().getText())
        return Unnest(target, alias)


    def visitTableSample(self, ctx):
        def converted(token, convert):
            # Optional numeric argument: parse it when present.
            return None if token is None else convert(token.getText())

        # SYSTEM when that keyword is present, otherwise BERNOULLI.
        method = 'BERNOULLI'
        if ctx.SYSTEM():
            method = 'SYSTEM'
        return Tablesample(
            method,
            converted(ctx.p, float),
            converted(ctx.n, int),
            converted(ctx.seed, int),
        )

    def visitTable(self, ctx):
        alias = self._alias_of(ctx)
        sample_ctx = ctx.tablesample()
        tablesample = self.visitTableSample(sample_ctx) if sample_ctx else None
        table_name = Identifier(ctx.qualifiedTableName().getText())
        return Table(table_name, alias, tablesample)

    def visitAliasedQuery(self, ctx):
        alias = self._alias_of(ctx)
        subquery = QueryVisitor().visit(ctx.unifiedQuery())
        return AliasedSubquery(subquery, alias)

    def visitAliasedRelation(self, ctx):
        alias = self._alias_of(ctx)
        inner = self.visit(ctx.relation())
        return AliasedRelation(inner, alias)

    def visitJoinRelation(self, ctx):
        # The join type keeps its source text (allText preserves whitespace).
        join_type = Token(allText(ctx.joinType()))
        right = RelationVisitor().visit(ctx.right)
        criteria_ctx = ctx.joinCriteria()
        criteria_kind = type(criteria_ctx)
        if criteria_kind is SqlSmallParser.BooleanJoinContext:
            # Boolean join condition (presumably ON ...).
            condition = BooleanExpressionVisitor().visit(criteria_ctx.booleanExpression())
            criteria = BooleanJoinCriteria(condition)
        elif criteria_kind is SqlSmallParser.UsingJoinContext:
            # Column-list join condition (presumably USING (...)).
            columns = [Column(ident.getText()) for ident in criteria_ctx.identifier()]
            criteria = UsingJoinCriteria(columns)
        else:
            criteria = None
        return Join(join_type, right, criteria)


class ExpressionVisitor(SqlSmallVisitor):
    """
    SQL Grammar

    Turns scalar-expression parse contexts into snsql AST nodes:
    - literals and column references;
    - arithmetic;
    - aggregate, ranking and window functions;
    - the typing, logical, numeric, string and date/time function families.

    Sub-expressions are usually visited with a fresh ExpressionVisitor.
    Boolean operands go through BooleanExpressionVisitor.
    """

    # --- Column references and literals -----------------------------------

    def visitColumnName(self, ctx):
        return Column(ctx.name.getText())

    def visitAllExpr(self, ctx):
        # The qualifier of an all-columns expression comes from either an
        # identifier child or a QN2 token; the identifier wins when it is non-empty.
        all_ctx = ctx.allExpression()
        ident = all_ctx.identifier().getText() if all_ctx.identifier() else None
        ident_ident = all_ctx.QN2().getText() if all_ctx.QN2() else None
        return AllColumns(ident if ident else ident_ident)

    def visitDecimalLiteral(self, ctx):
        return Literal(float(allText(ctx)))

    def visitIntegerLiteral(self, ctx):
        return Literal(int(allText(ctx)))

    def visitStringLiteral(self, ctx):
        # Surrounding single quotes are stripped from the value.  The raw
        # source text (quotes included) is kept as the literal's text.
        text = str(allText(ctx))
        value = text
        if len(text) > 1 and text[0] == "'" and text[-1] == "'":
            # this is the expected case for all stringLiteral
            value = text[1:-1]
        return Literal(value, text)

    def visitRawStringLiteral(self, ctx: SqlSmallParser.RawStringLiteralContext):
        text = str(allText(ctx))
        if text.startswith(("r'", "R'")):
            # Drop the r'/R' prefix and the closing quote.  No escape
            # processing is applied to the remaining characters.
            return Literal(value=text[2:-1], text=text)
        return Literal(text)

    def visitTrueLiteral(self, ctx):
        return Literal(True)

    def visitFalseLiteral(self, ctx):
        return Literal(False)

    def visitNullLiteral(self, ctx):
        return Literal(None)

    # --- Composite expressions, aggregates and window functions -----------

    def visitNestedExpr(self, ctx):
        return NestedExpression(self.visit(ctx.expression()))

    def visitAggFunc(self, ctx):
        # Function name and optional quantifier are upper-cased for a uniform AST.
        fname = FuncName(ctx.function.getText().upper())
        qt = ctx.setQuantifier()
        quantifier = Token(qt.getText().upper()) if qt is not None else None
        return AggFunction(fname, quantifier, self.visit(ctx.expression()))

    def visitSubqueryExpr(self, ctx):
        # A scalar subquery becomes an AliasedSubquery without an alias.
        qv = QueryVisitor()
        q = qv.visit(ctx.unifiedQuery())
        return AliasedSubquery(q, None)

    def visitRankingFunction(self, ctx):
        fname = FuncName(ctx.function.getText().upper())
        over = self.visit(ctx.overClause())
        return RankingFunction(fname, over)

    def visitOverClause(self, ctx):
        # Both the partitioning and the ordering parts are optional.
        pbc = ctx.partitionByClause()
        partition = PartitionByVisitor().visit(pbc) if pbc is not None else None
        oc = ctx.orderClause()
        order = OrderVisitor().visit(oc) if oc is not None else None
        return OverClause(partition, order)

    def visitLeadFunc(self, ctx: SqlSmallParser.LeadFuncContext):
        # offset and default are optional arguments.
        lfctx = ctx.leadFunction()
        source = ExpressionVisitor().visit(lfctx.sourceExpr)
        offset = ExpressionVisitor().visit(lfctx.offset) if lfctx.offset is not None else None
        default = ExpressionVisitor().visit(lfctx.default) if lfctx.default is not None else None
        over = self.visit(lfctx.overClause())
        return LeadFunction(source, offset, default, over)

    def visitGreatestFunc(self, ctx: SqlSmallParser.GreatestFuncContext):
        # Terminal children (keyword, parentheses, commas) visit to None and are dropped.
        gstctx = ctx.greatestFunction()
        ev = ExpressionVisitor()
        expressions = [e for e in [ev.visit(c) for c in gstctx.children] if e is not None]
        return GreatestFunction(expressions)

    def visitLeastFunc(self, ctx: SqlSmallParser.LeastFuncContext):
        # Terminal children visit to None and are dropped.
        lstctx = ctx.leastFunction()
        ev = ExpressionVisitor()
        expressions = [e for e in [ev.visit(c) for c in lstctx.children] if e is not None]
        return LeastFunction(expressions)

    def visitNotFunc(self, ctx: SqlSmallParser.NotFuncContext):
        boolexpr = BooleanExpressionVisitor().visit(ctx.booleanExpression())
        # NOTE(review): `NotFuction` must match the name exported by the AST
        # modules star-imported above; verify the spelling there before renaming.
        return NotFuction(boolexpr)

    def visitUserDefFunc(self, ctx: SqlSmallParser.UserDefFuncContext):
        # The function name keeps its original case (unlike built-ins above).
        ev = UserDefinedArgumentVisitor()
        args = [ev.visit(arg) for arg in ctx.userDefinedArgument()]
        return UserDefinedFunction(ctx.function.getText(), args)

    # --- Typing functions -------------------------------------------------

    def visitCastExpression(self, ctx: SqlSmallParser.CastExpressionContext):
        exp = ExpressionVisitor().visit(ctx.fromExpr)
        dbtype = ctx.dbType().getText().lower()
        return CastFunction(exp, dbtype)

    # --- Logical functions ------------------------------------------------

    def visitCaseExpr(self, ctx):
        return CaseExpressionVisitor().visit(ctx)

    def visitIifFunc(self, ctx):
        test = BooleanExpressionVisitor().visit(ctx.test)
        yes = ExpressionVisitor().visit(ctx.yes)
        no = ExpressionVisitor().visit(ctx.no)
        return IIFFunction(test, yes, no)

    def visitChooseFunc(self, ctx):
        expression = ExpressionVisitor().visit(ctx.index)
        choices = Seq([ExpressionVisitor().visit(e) for e in ctx.literal()])
        return ChooseFunction(expression, choices)

    # --- Numeric functions ------------------------------------------------

    def visitMultiply(self, ctx):
        return ArithmeticExpression(self.visit(ctx.left), Op("*"), self.visit(ctx.right))

    def visitDivide(self, ctx):
        return ArithmeticExpression(self.visit(ctx.left), Op("/"), self.visit(ctx.right))

    def visitModulo(self, ctx):
        return ArithmeticExpression(self.visit(ctx.left), Op("%"), self.visit(ctx.right))

    def visitAdd(self, ctx):
        return ArithmeticExpression(self.visit(ctx.left), Op("+"), self.visit(ctx.right))

    def visitSubtract(self, ctx):
        return ArithmeticExpression(self.visit(ctx.left), Op("-"), self.visit(ctx.right))

    def visitRoundFunction(self, ctx):
        # digits is optional.
        expression = ExpressionVisitor().visit(ctx.expression())
        if ctx.digits:
            digits = ExpressionVisitor().visit(ctx.digits)
        else:
            digits = None
        return RoundFunction(expression, digits)

    def visitTruncFunction(self, ctx):
        # digits is optional.
        expression = ExpressionVisitor().visit(ctx.expression())
        if ctx.digits:
            digits = ExpressionVisitor().visit(ctx.digits)
        else:
            digits = None
        return TruncFunction(expression, digits)

    def visitMathFunc(self, ctx):
        fname = FuncName(ctx.function.getText().upper())
        return MathFunction(fname, self.visit(ctx.expression()))

    def visitPowerFunction(self, ctx):
        return PowerFunction(
            ExpressionVisitor().visit(ctx.expExpr), ExpressionVisitor().visit(ctx.powerExpr)
        )

    def visitBareFunction(self, ctx):
        return BareFunction(FuncName(ctx.function.getText().upper()))

    # --- String functions -------------------------------------------------

    def visitStringLower(self, ctx: SqlSmallParser.StringLowerContext):
        return LowerFunction(
            ExpressionVisitor().visit(ctx.expression())
        )

    def visitStringUpper(self, ctx: SqlSmallParser.StringUpperContext):
        return UpperFunction(
            ExpressionVisitor().visit(ctx.expression())
        )

    def visitTrimFunction(self, ctx: SqlSmallParser.TrimFunctionContext):
        return TrimFunction(
            ExpressionVisitor().visit(ctx.expression())
        )

    def visitCharLengthFunction(self, ctx: SqlSmallParser.CharLengthFunctionContext):
        # The LENGTH spelling is recorded with an explicit token; otherwise
        # CharLengthFunction's default applies.
        if ctx.LENGTH():
            token = Token("LENGTH")
            return CharLengthFunction(
                ExpressionVisitor().visit(ctx.expression()), token
            )
        return CharLengthFunction(
            ExpressionVisitor().visit(ctx.expression())
        )

    def visitPositionFunction(self, ctx: SqlSmallParser.PositionFunctionContext):
        return PositionFunction(
            ExpressionVisitor().visit(ctx.searchString),
            ExpressionVisitor().visit(ctx.sourceString)
        )

    def visitStringConcat(self, ctx: SqlSmallParser.StringConcatContext):
        # Terminal children visit to None and are dropped.
        ev = ExpressionVisitor()
        expressions = [e for e in [ev.visit(c) for c in ctx.children] if e is not None]
        return ConcatFunction(expressions)

    def visitCoalesceFunction(self, ctx: SqlSmallParser.CoalesceFunctionContext):
        # Terminal children visit to None and are dropped.
        ev = ExpressionVisitor()
        expressions = [e for e in [ev.visit(c) for c in ctx.children] if e is not None]
        return CoalesceFunction(expressions)

    def visitSubstringFunction(self, ctx: SqlSmallParser.SubstringFunctionContext):
        # Remember whether FROM/FOR keywords or commas separated the arguments.
        tokens = {
            "FROM": Token(str(ctx.FROM())) if ctx.FROM() else Token(','),
            "FOR": Token(str(ctx.FOR())) if ctx.FOR() else Token(',')
        }
        return SubstringFunction(
            ExpressionVisitor().visit(ctx.sourceExpr),
            ExpressionVisitor().visit(ctx.startIdx),
            ExpressionVisitor().visit(ctx.length) if ctx.length else None,
            tokens
        )

    def visitSubstringBigqueryFunction(self, ctx: SqlSmallParser.SubstringBigqueryFunctionContext):
        return SubstrBigQueryFunction(
            ExpressionVisitor().visit(ctx.sourceExpr),
            ExpressionVisitor().visit(ctx.position),
            ExpressionVisitor().visit(ctx.length) if ctx.length else None
        )

    def visitRegexpContainsFunction(self, ctx:SqlSmallParser.RegexpContainsFunctionContext):
        return RegexpContainsFunction(
            ExpressionVisitor().visit(ctx.sourceExpr),
            ExpressionVisitor().visit(ctx.regExpr),
        )

    def visitRegexpExctractFunction(self, ctx: SqlSmallParser.RegexpExctractFunctionContext):
        # Method/context name spelling ("Exctract") must match the generated parser.
        return RegexpExtractFunction(
            ExpressionVisitor().visit(ctx.sourceExpr),
            ExpressionVisitor().visit(ctx.regExpr),
            ExpressionVisitor().visit(ctx.position) if ctx.position else None,
            ExpressionVisitor().visit(ctx.occurrence) if ctx.occurrence else None,
        )

    def visitRegexpReplaceFunction(self, ctx: SqlSmallParser.RegexpReplaceFunctionContext):
        return RegexpReplaceFunction(
            ExpressionVisitor().visit(ctx.sourceExpr),
            ExpressionVisitor().visit(ctx.regExpr),
            ExpressionVisitor().visit(ctx.replacementExpr)
        )

    def visitEncodeFunction(self, ctx: SqlSmallParser.EncodeFunctionContext):
        return EncodeFunction(
            ExpressionVisitor().visit(ctx.sourceString),
            ExpressionVisitor().visit(ctx.formatString),
        )

    def visitDecodeFunction(self, ctx: SqlSmallParser.DecodeFunctionContext):
        return DecodeFunction(
            ExpressionVisitor().visit(ctx.sourceString),
            ExpressionVisitor().visit(ctx.formatString),
        )

    def visitUnhexFunction(self, ctx: SqlSmallParser.UnhexFunctionContext):
        # The FROM_HEX spelling is recorded via an explicit token.
        if ctx.FROM_HEX():
            token = Token("FROM_HEX")
            return UnhexFunction(
                ExpressionVisitor().visit(ctx.sourceString),
                token = token,
            )
        return UnhexFunction(
                ExpressionVisitor().visit(ctx.sourceString),
            )

    def visitHexFunction(self, ctx: SqlSmallParser.HexFunctionContext):
        # The TO_HEX spelling is recorded via an explicit token.
        if ctx.TO_HEX():
            token = Token("TO_HEX")
            return HexFunction(
                ExpressionVisitor().visit(ctx.sourceString),
                token = token,
            )
        return HexFunction(
                ExpressionVisitor().visit(ctx.sourceString),
            )

    def visitRegexpSplitToTableFunction(self, ctx: SqlSmallParser.RegexpSplitToTableFunctionContext):
        return RegexpSplitToTableFunction(
            ExpressionVisitor().visit(ctx.sourceString),
            ExpressionVisitor().visit(ctx.pattern),
        )

    # --- Date/time functions ----------------------------------------------

    def visitCurrentTimeFunc(self, ctx: SqlSmallParser.CurrentTimeFuncContext):
        func = ctx.getText()
        # getText() returns the keyword as written in the query.  Other
        # visitors here upper-case keyword text (function names, SELECT/WITH
        # checks), so keywords are case-insensitive.  Compare case-insensitively
        # so `current_date` is accepted rather than rejected.
        name = func.upper()
        if name == 'CURRENT_TIMESTAMP':
            return CurrentTimestampFunction()
        elif name == 'CURRENT_TIME':
            return CurrentTimeFunction()
        elif name == 'CURRENT_DATE':
            return CurrentDateFunction()
        else:
            raise ValueError(f"Unknown function to get current time: {func}")

    def visitDayNameFunction(self, ctx: SqlSmallParser.DayNameFuncContext):
        # NOTE(review): the annotation names DayNameFuncContext, whose visitor
        # hook would normally be visitDayNameFunc rather than this method.
        # Confirm this method is actually dispatched to.
        expr = ExpressionVisitor().visit(ctx.expr)
        return DayNameFunction(expr)

    def visitExtractFunction(self, ctx: SqlSmallParser.ExtractFunctionContext):
        expr = ExpressionVisitor().visit(ctx.expression())
        date_part = ctx.datePart().getText()
        return ExtractFunction(date_part.lower(), expr)

    def visitFromUnixTimeFunction(self, ctx: SqlSmallParser.FromUnixTimeFunctionContext):
        # formatDate is optional.
        unix_timestamp = ExpressionVisitor().visit(ctx.unix_timestamp)
        if ctx.formatDate:
            formatDate = ExpressionVisitor().visit(ctx.formatDate)
        else:
            formatDate = None
        return FromUnixTimeFunction(unix_timestamp, formatDate)

    def visitUnixTimeStampFunction(self, ctx: SqlSmallParser.UnixTimeStampFunctionContext):
        # The argument is optional.
        if ctx.expr:
            expr = ExpressionVisitor().visit(ctx.expr)
        else:
            expr = None
        return UnixTimeStampFunction(expr)

    def visitDateFormatFunction(self, ctx: SqlSmallParser.DateFormatFunctionContext):
        expr = ExpressionVisitor().visit(ctx.expr)
        format = ExpressionVisitor().visit(ctx.formatDate)
        return DateFormatFunction(expr, format)

    def visitFormatDateFunction(self, ctx: SqlSmallParser.FormatDateFunctionContext):
        # Argument order is format first, then the date expression.
        format = ExpressionVisitor().visit(ctx.formatDate)
        expr = ExpressionVisitor().visit(ctx.expr)
        return FormatDateFunction(format, expr)

    def visitQuarterFunction(self, ctx: SqlSmallParser.QuarterFunctionContext):
        expr = ExpressionVisitor().visit(ctx.expr)
        return QuarterFunction(expr)

    def visitDowFunction(self, ctx: SqlSmallParser.DowFunctionContext):
        # Record which spelling (DOW or DAYOFWEEK) was used.
        if ctx.DOW():
            token = Token("DOW")
        else:
            token = Token("DAYOFWEEK")
        expr = ExpressionVisitor().visit(ctx.expr)
        return DowFunction(token, expr)

    def visitDatetimeDiffFunction(self, ctx: SqlSmallParser.DatetimeDiffFunctionContext):
        order_date = ExpressionVisitor().visit(ctx.order_date)
        ship_date = ExpressionVisitor().visit(ctx.ship_date)
        date_part = ctx.datePart().getText()
        return DatetimeDiffFunction(order_date, ship_date, date_part)

    def visitDatePartFunction(self, ctx: SqlSmallParser.DatePartFunctionContext):
        date_part = ctx.datePart().getText()
        interval = ExpressionVisitor().visit(ctx.interval)
        return DatePartFunction(date_part, interval)

    def visitDateFunction(self, ctx: SqlSmallParser.DateFunctionContext):
        expr = ExpressionVisitor().visit(ctx.expr)
        return DateFunction(expr)

class CaseExpressionVisitor(SqlSmallVisitor):
    """Build CaseExpression and WhenExpression nodes for CASE ... END."""

    @staticmethod
    def _else_of(ctx):
        # The ELSE branch is optional.
        return None if ctx.elseExpr is None else ExpressionVisitor().visit(ctx.elseExpr)

    def visitCaseBaseExpr(self, ctx):
        # CASE form with a base expression compared against each WHEN value.
        whens = list(map(self.visit, ctx.whenBaseExpression()))
        base = ExpressionVisitor().visit(ctx.baseCaseExpr)
        return CaseExpression(base, whens, self._else_of(ctx))

    def visitCaseWhenExpr(self, ctx):
        # CASE form with boolean WHEN conditions and no base expression.
        whens = list(map(self.visit, ctx.whenExpression()))
        return CaseExpression(None, whens, self._else_of(ctx))

    def visitWhenExpression(self, ctx):
        condition = BooleanExpressionVisitor().visit(ctx.baseBoolExpr)
        result = ExpressionVisitor().visit(ctx.thenExpr)
        return WhenExpression(condition, result)

    def visitWhenBaseExpression(self, ctx):
        match_value = ExpressionVisitor().visit(ctx.baseWhenExpr)
        result = ExpressionVisitor().visit(ctx.thenExpr)
        return WhenExpression(match_value, result)


class BooleanExpressionVisitor(SqlSmallVisitor):
    """Build boolean AST nodes: comparisons, AND/OR/NOT, and predicate conditions."""

    def visitLogicalNot(self, ctx):
        inner = self.visit(ctx.booleanExpression())
        return LogicalNot(inner)

    def visitComparison(self, ctx):
        ev = ExpressionVisitor()
        left = ev.visit(ctx.left)
        operator = Op(ctx.op.getText())
        right = ev.visit(ctx.right)
        return BooleanCompare(left, operator, right)

    def _combine(self, ctx, keyword):
        # Shared shape of AND / OR: both operands are boolean expressions.
        return BooleanCompare(self.visit(ctx.left), Op(keyword), self.visit(ctx.right))

    def visitConjunction(self, ctx):
        return self._combine(ctx, "AND")

    def visitDisjunction(self, ctx):
        return self._combine(ctx, "OR")

    def visitNestedBoolean(self, ctx):
        inner = self.visit(ctx.booleanExpression())
        return NestedBoolean(inner)

    def visitPredicated(self, ctx):
        subject = ExpressionVisitor().visit(ctx.expression())
        condition = self.visit(ctx.predicate())
        return PredicatedExpression(subject, condition)

    def visitInCondition(self, ctx):
        negated = ctx.NOT() is not None
        ev = ExpressionVisitor()
        candidates = Seq([ev.visit(item) for item in ctx.expression()])
        return InCondition(candidates, negated)

    def visitIsCondition(self, ctx):
        negated = ctx.NOT() is not None
        # TRUE, then FALSE, then NULL; the first keyword present decides the literal.
        for keyword_getter, literal_value in ((ctx.TRUE, True), (ctx.FALSE, False), (ctx.NULL, None)):
            if keyword_getter() is not None:
                return IsCondition(Literal(literal_value), negated)
        raise ValueError("Unknown condition in IS clause: " + allText(ctx))

    def visitBetweenCondition(self, ctx):
        negated = ctx.NOT() is not None
        ev = ExpressionVisitor()
        low = ev.visit(ctx.lower)
        high = ev.visit(ctx.upper)
        return BetweenCondition(low, high, negated)

    def visitQualifiedColumnName(self, ctx):
        # A bare column used where a boolean is expected.
        return ColumnBoolean(Column(ctx.getText()))

    def visitLikeCondition(self, ctx: SqlSmallParser.LikeConditionContext):
        ev = ExpressionVisitor()
        value = ev.visit(ctx.value)
        pattern = ev.visit(ctx.pattern)
        return LikeExpression(value, pattern)

    def visitILikeCondition(self, ctx: SqlSmallParser.ILikeConditionContext):
        ev = ExpressionVisitor()
        value = ev.visit(ctx.value)
        pattern = ev.visit(ctx.pattern)
        return ILikeExpression(value, pattern)

    def visitBigqueryRegexpContainsCondition(self, ctx: SqlSmallParser.BigqueryRegexpContainsConditionContext):
        # REGEXP_CONTAINS used directly as a boolean condition.
        regexp_ctx = ctx.regexpContainsFunction()
        ev = ExpressionVisitor()
        source = ev.visit(regexp_ctx.sourceExpr)
        pattern = ev.visit(regexp_ctx.regExpr)
        return RegexpContainsFunction(source, pattern)
    
class UserDefinedArgumentVisitor(SqlSmallVisitor):
    """Convert each argument of a user-defined function call."""
    def visitExprArg(self, ctx: SqlSmallParser.ExprArgContext):
        # Ordinary expression argument -> expression AST node.
        expression_ctx = ctx.expression()
        return ExpressionVisitor().visit(expression_ctx)

    def visitDatePartArg(self, ctx: SqlSmallParser.DatePartArgContext):
        # Date-part argument kept verbatim as a Token.
        part_text = ctx.datePart().getText()
        return Token(part_text)

    def visitTypeArg(self, ctx: SqlSmallParser.TypeArgContext):
        # Type-name argument returned as its raw text (a plain str, not a node).
        type_text = ctx.dbType().getText()
        return type_text


def allText(ctx):
    """
        Return the original source text covered by a parse-tree node,
        whitespace included, for nodes that do not yet have a
        strongly-typed visitor.  The span runs from the first character of
        ctx.start through the last character of ctx.stop (both inclusive),
        read back from the start token's input stream.  Not intended for
        lexer tokens; callers should be migrated away from it over time.
    """
    first_token, last_token = ctx.start, ctx.stop
    source = first_token.getInputStream()
    return source.getText(first_token.start, last_token.stop)