Repository URL to install this package:
|
Version:
3.0.0.dev3
|
import typing as t
from snsql._ast.ast import (
AggFunction,
AliasedRelation,
AliasedSubquery,
ArithmeticExpression,
BetweenCondition,
BooleanCompare,
BooleanJoinCriteria,
CaseExpression,
ChooseFunction,
Column,
ColumnBoolean,
GreatestFunction,
GroupingExpression,
IIFFunction,
InCondition,
LeastFunction,
LikeExpression,
LogicalNot,
MathFunction,
NamedExpression,
NamedSubquery,
NestedBoolean,
NestedExpression,
Order,
OverClause,
PartitionByClause,
PowerFunction,
PredicatedExpression,
Query,
QueryOperation,
RankingFunction,
Relation,
RoundFunction,
Seq,
SortItem,
Sql,
TruncFunction,
UnifiedQuery,
UserDefinedFunction,
WhenExpression,
Where,
)
from snsql._ast.expressions.string import (
CharLengthFunction,
CoalesceFunction,
ConcatFunction,
LowerFunction,
TrimFunction,
UpperFunction,
)
from snsql._ast.expressions.types import CastFunction
class ASTWalker:
    """Base class for walking and modifying the query AST.

    Subclasses implement :meth:`transform`, which receives one AST node and
    returns the node to use in its place. :meth:`walk_tree` supplies the
    traversal shared by all translators: it passes each child of a node
    through ``self.transform`` and stores the result back on the node.
    """

    # Node types for which walk_tree transforms only ``.expression``.
    _SINGLE_EXPRESSION_NODES = (
        NamedExpression,
        BooleanJoinCriteria,
        RoundFunction,
        PowerFunction,
        AggFunction,
        MathFunction,
        ChooseFunction,
        NestedExpression,
        ColumnBoolean,
        NestedBoolean,
        LogicalNot,
        TruncFunction,
        SortItem,
        GroupingExpression,
        CastFunction,
        LowerFunction,
        UpperFunction,
        TrimFunction,
        CharLengthFunction,
    )

    # Node types for which walk_tree transforms every item of
    # ``.expressions`` and stores the results back as a new Seq.
    _EXPRESSION_LIST_NODES = (
        InCondition,
        ConcatFunction,
        CoalesceFunction,
        GreatestFunction,
        LeastFunction,
        UserDefinedFunction,
    )

    def __init__(
        self, query: t.Union[Query, UnifiedQuery], **_kwargs: t.Any
    ) -> None:
        # Additional keyword arguments are accepted and ignored.
        self.query = query

    def transform(self, expr: Sql) -> Sql:
        """Return the replacement for a single AST node.

        Must be overridden by subclasses; the base implementation raises
        ``NotImplementedError``.
        """
        raise NotImplementedError()

    def run(self) -> None:
        """Transform the whole query and store the result in ``self.query``."""
        self.query = self.transform(self.query)

    def walk_tree(self, ast_expr: Sql) -> Sql:
        """Walk one level of the AST, transforming the children of a node.

        This traversal is common to all translators. Translator-specific
        rewrites happen inside ``self.transform``, which is called on each
        child. Children are replaced in place on ``ast_expr`` (a few are
        rebuilt as new nodes). Node types not handled here, including
        ``None``, are returned unchanged.

        Args:
            ast_expr (Sql): the node whose children should be transformed.
        Returns:
            Sql: ``ast_expr`` with its children transformed. For a
            ``Column``, the result is whatever ``self.transform`` returns
            for it.
        """
        if isinstance(ast_expr, (UnifiedQuery, Query)):
            self._walk_query(ast_expr)
        elif isinstance(ast_expr, AliasedSubquery):
            ast_expr.query = self.transform(ast_expr.query)
        elif isinstance(ast_expr, NamedSubquery):
            ast_expr.subquery = self.transform(ast_expr.subquery)
        elif isinstance(ast_expr, Seq):
            ast_expr.seq = [self.transform(item) for item in ast_expr.seq]
        elif isinstance(ast_expr, Column):
            # No children are walked for a Column; transform's result
            # replaces it. NOTE(review): if a subclass's transform hands a
            # Column back to walk_tree, these two call each other without
            # bound -- confirm subclasses always handle Column directly.
            ast_expr = self.transform(ast_expr)
        elif isinstance(ast_expr, AliasedRelation):
            ast_expr.relation = self.transform(ast_expr.relation)
        elif isinstance(ast_expr, Relation):
            ast_expr.primary = self.transform(ast_expr.primary)
            for join in ast_expr.joins:
                join.right = self.transform(join.right)
                join.criteria = self.transform(join.criteria)
        elif isinstance(ast_expr, (ArithmeticExpression, BooleanCompare)):
            ast_expr.left = self.transform(ast_expr.left)
            ast_expr.right = self.transform(ast_expr.right)
        elif isinstance(ast_expr, CaseExpression):
            ast_expr.expression = self.transform(ast_expr.expression)
            ast_expr.when_exprs = [
                WhenExpression(
                    self.transform(when_expr.expression),
                    self.transform(when_expr.then),
                )
                for when_expr in ast_expr.when_exprs
            ]
            ast_expr.else_expr = self.transform(ast_expr.else_expr)
        elif isinstance(ast_expr, RankingFunction):
            ast_expr.over = self._walk_over_clause(ast_expr.over)
        elif isinstance(ast_expr, self._SINGLE_EXPRESSION_NODES):
            ast_expr.expression = self.transform(ast_expr.expression)
        elif isinstance(ast_expr, IIFFunction):
            ast_expr.test = self.transform(ast_expr.test)
            ast_expr.yes = self.transform(ast_expr.yes)
            ast_expr.no = self.transform(ast_expr.no)
        elif isinstance(ast_expr, PredicatedExpression):
            ast_expr.expression = self.transform(ast_expr.expression)
            ast_expr.predicate = self.transform(ast_expr.predicate)
        elif isinstance(ast_expr, self._EXPRESSION_LIST_NODES):
            ast_expr.expressions = Seq(
                [self.transform(expr) for expr in ast_expr.expressions]
            )
        elif isinstance(ast_expr, BetweenCondition):
            ast_expr.upper = self.transform(ast_expr.upper)
            ast_expr.lower = self.transform(ast_expr.lower)
        elif isinstance(ast_expr, Where):
            ast_expr.condition = self.transform(ast_expr.condition)
        elif isinstance(ast_expr, LikeExpression):
            ast_expr.pattern = self.transform(ast_expr.pattern)
            ast_expr.value = self.transform(ast_expr.value)
        return ast_expr

    def _walk_query(self, query: t.Union[Query, UnifiedQuery]) -> None:
        """Transform every clause of a query (and its set-op branches) in place."""
        if isinstance(query, UnifiedQuery):
            # Rebuild each set-operation branch around its walked query;
            # the operator and ALL flag are carried over unchanged.
            # NOTE(review): branches (and CTEs below) go through walk_tree,
            # not transform, so transform never sees those nodes
            # themselves -- confirm this is intentional.
            query.query_ops = [
                QueryOperation(
                    qry_op.operator,
                    qry_op.all,
                    self.walk_tree(qry_op.query),
                )
                for qry_op in query.query_ops
            ]
        # select
        if query.select.ctes is not None:
            query.select.ctes = self.walk_tree(query.select.ctes)
        query.select.namedExpressions = self.transform(
            query.select.namedExpressions
        )
        # from
        if query.source:
            query.source.relations = self.transform(query.source.relations)
        # where
        if query.where:
            query.where = self.transform(query.where)
        # group by
        if query.agg:
            query.agg.groupingExpressions = self.transform(
                query.agg.groupingExpressions
            )
        # having
        if query.having:
            query.having.condition = self.transform(query.having.condition)
        # order by
        if query.order:
            query.order.sortItems = Seq(
                [self.transform(item) for item in query.order.sortItems]
            )

    def _walk_over_clause(self, over: OverClause) -> OverClause:
        """Return a new OVER clause with transformed partition and order items."""
        partition: t.Optional[PartitionByClause] = None
        if over.partition is not None:
            exprs = [self.transform(e) for e in over.partition.expressions]
            partition = PartitionByClause(Seq(exprs))
        # Guard ORDER BY the same way as PARTITION BY. Without this, an
        # OVER clause whose ``order`` is None fails with AttributeError on
        # ``None.sortItems``.
        order: t.Optional[Order] = None
        if over.order is not None:
            order = Order(
                [self.transform(item) for item in over.order.sortItems]
            )
        return OverClause(partition, order)