Why Gemfury? Push, build, and install RubyGems, npm packages, Python packages, Maven artifacts, PHP packages, Go modules, Debian packages, RPM packages, and NuGet packages.

Repository URL to install this package:

Details    
sarus_sql / sarus_sql / ast_transform / ast_walker.py
Size: Mime:
import typing as t

from snsql._ast.ast import (
    AggFunction,
    AliasedRelation,
    AliasedSubquery,
    ArithmeticExpression,
    BetweenCondition,
    BooleanCompare,
    BooleanJoinCriteria,
    CaseExpression,
    ChooseFunction,
    Column,
    ColumnBoolean,
    GreatestFunction,
    GroupingExpression,
    IIFFunction,
    InCondition,
    LeastFunction,
    LikeExpression,
    LogicalNot,
    MathFunction,
    NamedExpression,
    NamedSubquery,
    NestedBoolean,
    NestedExpression,
    Order,
    OverClause,
    PartitionByClause,
    PowerFunction,
    PredicatedExpression,
    Query,
    QueryOperation,
    RankingFunction,
    Relation,
    RoundFunction,
    Seq,
    SortItem,
    Sql,
    TruncFunction,
    UnifiedQuery,
    UserDefinedFunction,
    WhenExpression,
    Where,
)
from snsql._ast.expressions.string import (
    CharLengthFunction,
    CoalesceFunction,
    ConcatFunction,
    LowerFunction,
    TrimFunction,
    UpperFunction,
)
from snsql._ast.expressions.types import CastFunction


class ASTWalker:
    """Base class for walking and modifying the query AST.

    Subclasses implement :meth:`transform`, which receives one AST node and
    returns the node that should take its place. :meth:`walk_tree` holds the
    generic, translator-independent recursion: for each supported node type
    it passes the node's children through ``self.transform`` and writes the
    results back onto the node. A translator's ``transform`` is meant to do
    the translator-specific rewriting (per ``walk_tree``'s docstring).
    """

    def __init__(
        self, query: t.Union[Query, UnifiedQuery], **_kwargs: t.Any
    ) -> None:
        """Store the query to be transformed.

        Args:
            query: Root of the AST that :meth:`run` will transform.
            **_kwargs: Accepted and ignored here, presumably so subclasses
                can share a constructor call signature — confirm.
        """
        self.query = query

    def transform(self, expr: Sql) -> Sql:
        """Return the transformed version of ``expr``.

        Must be overridden by subclasses.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError()

    def run(self) -> None:
        """Transform the stored query and replace ``self.query`` with the result."""
        self.query = self.transform(self.query)

    def walk_tree(self, ast_expr: Sql) -> Sql:
        """Walk one level of the AST, transforming each child of ``ast_expr``.

        This is common for all translators. Translator-specific
        transformations are done inside ``self.transform``.

        Most branches mutate ``ast_expr`` in place (reassigning its child
        attributes) and return the same object. Nodes of a type not listed
        below are returned unchanged.

        Args:
            ast_expr (Sql): Node whose children should be transformed.

        Returns:
            Sql: The (usually same, mutated) node. For a ``Column`` it is
            whatever ``self.transform`` returns for that column.
        """
        if isinstance(ast_expr, (UnifiedQuery, Query)):
            if isinstance(ast_expr, UnifiedQuery):
                # NOTE(review): union branches go through walk_tree, not
                # transform, so self.transform never sees these Query nodes
                # themselves — only their children. Confirm this is intended.
                ast_expr.query_ops = [
                    QueryOperation(
                        qry_op.operator,
                        qry_op.all,
                        self.walk_tree(qry_op.query),
                    )
                    for qry_op in ast_expr.query_ops
                ]

            # NOTE(review): like the union branches above, CTEs are routed
            # through walk_tree rather than transform.
            if ast_expr.select.ctes is not None:
                ast_expr.select.ctes = self.walk_tree(ast_expr.select.ctes)
            ast_expr.select.namedExpressions = self.transform(
                ast_expr.select.namedExpressions
            )

            # from
            if ast_expr.source:
                ast_expr.source.relations = self.transform(
                    ast_expr.source.relations
                )

            # where
            if ast_expr.where:
                ast_expr.where = self.transform(ast_expr.where)

            # group by
            if ast_expr.agg:
                ast_expr.agg.groupingExpressions = self.transform(
                    ast_expr.agg.groupingExpressions
                )

            # having
            if ast_expr.having:
                ast_expr.having.condition = self.transform(
                    ast_expr.having.condition
                )

            # order by
            if ast_expr.order:
                ast_expr.order.sortItems = Seq(
                    [
                        self.transform(order_item)
                        for order_item in ast_expr.order.sortItems
                    ]
                )

        elif isinstance(ast_expr, AliasedSubquery):
            ast_expr.query = self.transform(ast_expr.query)

        elif isinstance(ast_expr, NamedSubquery):
            ast_expr.subquery = self.transform(ast_expr.subquery)

        elif isinstance(ast_expr, Seq):
            ast_expr.seq = [self.transform(seq) for seq in ast_expr.seq]

        elif isinstance(ast_expr, Column):
            # NOTE(review): this hands the column itself back to transform.
            # A transform that delegates Column nodes to walk_tree would
            # therefore recurse without end; subclasses presumably handle
            # Column directly — confirm.
            ast_expr = self.transform(ast_expr)

        elif isinstance(ast_expr, AliasedRelation):
            ast_expr.relation = self.transform(ast_expr.relation)

        elif isinstance(ast_expr, Relation):
            ast_expr.primary = self.transform(ast_expr.primary)
            for join in ast_expr.joins:
                join.right = self.transform(join.right)
                join.criteria = self.transform(join.criteria)

        elif isinstance(ast_expr, (ArithmeticExpression, BooleanCompare)):
            ast_expr.left = self.transform(ast_expr.left)
            ast_expr.right = self.transform(ast_expr.right)

        elif isinstance(ast_expr, CaseExpression):
            ast_expr.expression = self.transform(ast_expr.expression)
            ast_expr.when_exprs = [
                WhenExpression(
                    self.transform(when_expr.expression),
                    self.transform(when_expr.then),
                )
                for when_expr in ast_expr.when_exprs
            ]
            ast_expr.else_expr = self.transform(ast_expr.else_expr)

        elif isinstance(ast_expr, RankingFunction):
            # Rebuild the OVER clause from transformed parts. Both PARTITION BY
            # and ORDER BY may be absent; previously a missing ORDER BY raised
            # AttributeError on `None.sortItems`.
            over = ast_expr.over
            partition: t.Optional[PartitionByClause] = None
            if over.partition is not None:
                partition = PartitionByClause(
                    Seq(
                        [
                            self.transform(expr)
                            for expr in over.partition.expressions
                        ]
                    )
                )
            order: t.Optional[Order] = None
            if over.order is not None:
                order = Order(
                    [
                        self.transform(order_item)
                        for order_item in over.order.sortItems
                    ]
                )
            ast_expr.over = OverClause(partition, order)

        elif isinstance(
            ast_expr,
            (
                NamedExpression,
                BooleanJoinCriteria,
                RoundFunction,
                PowerFunction,
                AggFunction,
                MathFunction,
                ChooseFunction,
                NestedExpression,
                ColumnBoolean,
                NestedBoolean,
                LogicalNot,
                TruncFunction,
                SortItem,
                GroupingExpression,
                CastFunction,
                LowerFunction,
                UpperFunction,
                TrimFunction,
                CharLengthFunction,
            ),
        ):
            # Node types whose only walked child is `.expression`.
            ast_expr.expression = self.transform(ast_expr.expression)

        elif isinstance(ast_expr, IIFFunction):
            ast_expr.test = self.transform(ast_expr.test)
            ast_expr.yes = self.transform(ast_expr.yes)
            ast_expr.no = self.transform(ast_expr.no)

        elif isinstance(ast_expr, PredicatedExpression):
            ast_expr.expression = self.transform(ast_expr.expression)
            ast_expr.predicate = self.transform(ast_expr.predicate)

        elif isinstance(
            ast_expr,
            (
                InCondition,
                ConcatFunction,
                CoalesceFunction,
                GreatestFunction,
                LeastFunction,
                UserDefinedFunction,
            ),
        ):
            # Node types whose walked children live in `.expressions`.
            ast_expr.expressions = Seq(
                [self.transform(expr) for expr in ast_expr.expressions]
            )

        elif isinstance(ast_expr, BetweenCondition):
            ast_expr.upper = self.transform(ast_expr.upper)
            ast_expr.lower = self.transform(ast_expr.lower)

        elif isinstance(ast_expr, Where):
            ast_expr.condition = self.transform(ast_expr.condition)

        elif isinstance(ast_expr, LikeExpression):
            ast_expr.pattern = self.transform(ast_expr.pattern)
            ast_expr.value = self.transform(ast_expr.value)

        return ast_expr