Repository URL to install this package:
|
Version:
3.0.0.dev3 ▾
|
import copy
import typing as t
from snsql._ast.ast import (
AliasedRelation,
AliasedSubquery,
BooleanCompare,
BooleanJoinCriteria,
CaseExpression,
Column,
Identifier,
IsCondition,
Join,
Limit,
Literal,
NamedExpression,
NestedBoolean,
Op,
PredicatedExpression,
Query,
Relation,
Seq,
Sql,
Table,
TableColumn,
Token,
UsingJoinCriteria,
WhenExpression,
)
from snsql._ast.expressions.numeric import ArithmeticExpression
from snsql._ast.expressions.string import CoalesceFunction
from snsql._ast.expressions.types import CastFunction
from snsql.metadata import Metadata
import sarus_sql.ast_utils as ast_utils
# Administrative column holding the privacy unit (pid) of each row.
PU_COLUMN: str = "sarus_privacy_unit"
# Administrative column flagging whether a row is public.
PUBLIC: str = "sarus_is_public"
# Administrative column holding the weight of each row.
WEIGHTS: str = "sarus_weights"
def _replace_in_ast_by_expr(
    expression: t.Optional[Sql],
    old_expr: Sql,
    new_expr: Sql,
) -> None:
    """Replace, in place, every sub-expression equal to ``old_expr``.

    The walk goes through the attributes (``__dict__``) of each node. Both
    attributes holding a node and attributes holding a list of nodes are
    inspected. A node matches when it is an instance of ``type(old_expr)``
    and compares equal (``==``) to ``old_expr``.

    Each match is replaced by a deep copy of ``new_expr``. The caller's
    ``new_expr`` is therefore never shared by the tree, nor shared between
    several positions. Replaced nodes are not walked into, so ``new_expr``
    may itself contain ``old_expr``.

    ``expression`` itself (the root) is never replaced, since this function
    cannot rebind the caller's reference.

    Args:
        expression: AST to modify in place. ``None`` is a no-op.
        old_expr: expression to look for.
        new_expr: expression inserted (as a copy) at every match.
    """
    if expression is None:
        return None

    def _visit(node):
        # Return what must sit at this position, recursing into kept nodes.
        if isinstance(node, type(old_expr)) and node == old_expr:
            return copy.deepcopy(new_expr)
        if isinstance(node, Sql):
            _replace_in_ast_by_expr(node, old_expr, new_expr)
        return node

    attributes = expression.__dict__
    # Iterate over a snapshot: values are reassigned while looping.
    for attr, value in list(attributes.items()):
        if isinstance(value, list):
            attributes[attr] = [_visit(item) for item in value]
        else:
            # No early return: every attribute must be visited, not just the
            # ones before the first match.
            attributes[attr] = _visit(value)
def replace_in_ast(expression: Sql, map_exprs: t.Dict[Sql, Sql]) -> Sql:
    """Return a copy of ``expression`` where each key of ``map_exprs`` is
    replaced by the corresponding value.

    The mappings are applied one after the other, in ``map_exprs`` iteration
    order, each on the result of the previous ones. The root expression is
    replaced too when it matches a key. A node matches a key when it is an
    instance of the key's type and compares equal to it.

    Args:
        expression: expression to rewrite. It is deep-copied, never mutated.
        map_exprs: mapping from expressions to find to their replacements.

    Returns:
        The rewritten copy.
    """
    expr = copy.deepcopy(expression)
    for old, new in map_exprs.items():
        if isinstance(expr, type(old)) and expr == old:
            # The root itself matches: substitute it (copied so the caller's
            # mapping values are never shared with the returned tree).
            expr = copy.deepcopy(new)
        else:
            _replace_in_ast_by_expr(expr, old, new)
    return expr
def replace_in_expression(
    expression: Sql,
    named_expr: NamedExpression,
) -> None:
    """Replace, in place, ``named_expr.expression`` by a column named after it.

    Every sub-expression of ``expression`` that is an instance of
    ``type(named_expr.expression)`` and compares equal to it is replaced by
    ``Column(named_expr.name.text)``. Both attributes holding a node and
    attributes holding a list of nodes are inspected. Replaced nodes are not
    walked into. ``expression`` itself (the root) is never replaced.

    Args:
        expression: AST to modify in place.
        named_expr: named expression whose expression is replaced by a
            reference to its name.
    """
    target = named_expr.expression

    def _visit(node):
        # Return what must sit at this position, recursing into kept nodes.
        if isinstance(node, type(target)) and node == target:
            return Column(named_expr.name.text)
        if isinstance(node, Sql):
            replace_in_expression(node, named_expr)
        return node

    attributes = expression.__dict__
    # Iterate over a snapshot: values are reassigned while looping.
    for attr, value in list(attributes.items()):
        if isinstance(value, list):
            attributes[attr] = [_visit(item) for item in value]
        else:
            # No early return: every attribute must be visited, not just the
            # ones before the first match.
            attributes[attr] = _visit(value)
def substitute_tables(
    query: Query, meta: Metadata, tables: t.Optional[t.Dict[str, str]]
) -> Query:
    """Replace tables unknown to the metadata by their SQL definition.

    The rewrite applies to every ``Table`` of the FROM clause: the primary
    relation and joined tables, and recursively the content of aliased
    subqueries and aliased relations. A table whose name is not a table of
    ``meta`` is looked up in ``tables``. It is then replaced by an
    ``AliasedSubquery`` holding the parsed definition, aliased with the
    table's alias, or with its name when it has no alias. Definitions are
    expanded recursively, so a definition may use other entries of
    ``tables``.

    Args:
        query: query to rewrite. It is deep-copied, not mutated.
        meta: metadata. Its tables (compared with escapes stripped) are kept
            as they are.
        tables: mapping from table name to the SQL text defining it.

    Returns:
        ``query`` itself when ``tables`` is empty or None. Otherwise, a
        rewritten copy passed through ``ast_utils.load_symbols_in_query``.

    Raises:
        ValueError: if a table is neither in ``meta`` nor in ``tables``, or
            if the definitions in ``tables`` reference each other in a cycle.
    """
    if not tables:
        return query
    # Computed once for the whole rewrite instead of once per table.
    meta_table_names = {
        meta.compare.strip_escapes(tab) for tab in meta.m_tables
    }

    def _substitute(current: Query, expanding: t.Tuple[str, ...]) -> Query:
        # Rewrite a copy of `current`. `expanding` holds the names of the
        # definitions currently being expanded (outermost first), so that a
        # self-referencing chain is reported instead of recursing forever.

        def _replace_in_primary_relation(
            primary_relation: t.Union[Table, AliasedSubquery, AliasedRelation],
        ) -> t.Union[Table, AliasedSubquery, AliasedRelation]:
            if isinstance(primary_relation, Table):
                name = primary_relation.name
                if name in meta_table_names:
                    return primary_relation
                if name not in tables:
                    raise ValueError(
                        f"Table {primary_relation} unknown. Got: {current}"
                    )
                if name in expanding:
                    cycle = " -> ".join([*expanding, name])
                    raise ValueError(f"Cyclic table definitions: {cycle}")
                subquery = _substitute(
                    ast_utils.parse_query(tables[name]), (*expanding, name)
                )
                if primary_relation.alias:
                    alias = copy.deepcopy(primary_relation.alias)
                else:
                    alias = name
                return AliasedSubquery(subquery, alias)
            if isinstance(primary_relation, AliasedSubquery):
                primary_relation.query = _substitute(
                    primary_relation.query, expanding
                )
            elif isinstance(primary_relation, AliasedRelation):
                primary_relation.relation = _replace_in_relation(
                    primary_relation.relation
                )
            return primary_relation

        def _replace_in_relation(relation: Relation) -> Relation:
            # Mutates `relation` in place and returns it.
            relation.primary = _replace_in_primary_relation(relation.primary)
            for join in relation.joins:
                join.right = _replace_in_primary_relation(join.right)
            return relation

        new_query = copy.deepcopy(current)
        new_query.source.relations.list = [
            _replace_in_relation(rel) for rel in new_query.source.relations
        ]
        return ast_utils.load_symbols_in_query(new_query, meta, True)

    return _substitute(query, ())
def add_cols_to_select(query: Query, list_of_admin_cols: t.List[str]) -> Query:
    """Append administrative columns to the SELECT list of ``query``.

    Without JOINs, each admin column ``c`` is appended as ``c AS c``, on a
    deep copy of the query.

    Warning!! When the query contains a JOIN, the work is delegated to
    ``add_cols_and_pid_on_private_joins_key``. That function also rewrites
    joins between two protected tables so that they match on the pid.

    If list_of_admin_cols is not empty, the query is assumed to be pup.
    TODO: We need to make sure the Query is pup if list_of_admin_cols is not
    empty.

    Args:
        query: query to extend.
        list_of_admin_cols: admin column names, appended in this order.

    Returns:
        ``query`` itself when ``list_of_admin_cols`` is empty, otherwise a
        modified copy.
    """
    if not list_of_admin_cols:
        return query
    if query.find_nodes(Join):
        return add_cols_and_pid_on_private_joins_key(query, list_of_admin_cols)

    admin_expressions = [
        NamedExpression(name=Identifier(name), expression=Column(name))
        for name in list_of_admin_cols
    ]
    extended = copy.deepcopy(query)
    extended.select.namedExpressions = Seq(
        [*extended.select.namedExpressions, *admin_expressions]
    )
    return extended
def add_limit(query: Query, limit: int) -> Query:
    """Return a copy of ``query`` with a ``LIMIT limit`` clause.

    Any LIMIT already present is overridden. ``query`` is not modified.
    """
    limited_query = copy.deepcopy(query)
    limited_query.limit = Limit(limit)
    return limited_query
def _pid_column_for_source(source: t.Union[t.List[str], str]) -> Sql:
    """Return the expression reading the pid of one side of a join.

    ``source`` follows the ``pid_source[0]`` convention of
    ``add_cols_and_pid_on_private_joins_key``:

    * an empty string is the unqualified pid column (pid columns merged by a
      ``JOIN ... USING``);
    * a non-empty string is a table name or alias;
    * a list holds several such sources, of which the first non-null pid is
      taken (``COALESCE``), as produced by a ``FULL JOIN``.
    """
    if isinstance(source, list):
        return CoalesceFunction([_pid_column_for_source(src) for src in source])
    if not source:
        return Column(PU_COLUMN)
    return Column(f"{source}.{PU_COLUMN}")


def add_cols_and_pid_on_private_joins_key(
    query: Query, list_of_admin_cols: t.List[str]
) -> Query:
    """Add admin columns to a query with JOINs, forcing private joins on pid.

    Every join between two private tables (t1, t2) is rewritten so that it
    also matches on the privacy unit:

    * ``ON <cond>`` becomes ``ON ((t1.pid = t2.pid) AND <cond>)``;
    * ``USING (cols)`` becomes ``USING (cols, pid)``;
    * a ``CROSS JOIN`` (no criteria) first becomes a join on ``TRUE = TRUE``,
      which is then extended like an ``ON`` criterion.

    The admin columns are then appended to the SELECT list, in the order
    given by ``list_of_admin_cols``.

    Args:
        query (Query): query with JOINs. If list_of_admin_cols is not empty,
            we assume that the Query is pup. TODO: We should check that the
            query is pup. The query must have metadata already loaded:
            ``query.is_public()`` is called, and any ValueError it raises
            propagates.
        list_of_admin_cols (t.List[str]): admin columns to add. If empty,
            the query is returned without any modification. Only PU_COLUMN,
            PUBLIC and WEIGHTS are supported, since each needs a different
            expression.

    Returns:
        A modified deep copy of ``query``.

    Raises:
        AssertionError: if ``query`` contains no JOIN.
        NotImplementedError: if a relation taking part in a join is not a
            plain ``Table``, or if an unsupported admin column is requested.
    """
    assert len(query.find_nodes(Join)) != 0
    if not list_of_admin_cols:
        return query
    # Make sure the query has symbols properly loaded and that `is_public`
    # is available (raises otherwise).
    query.is_public()

    new_query = copy.deepcopy(query)
    rels_with_joins = [
        rel for rel in new_query.find_nodes(Relation) if rel.joins
    ]
    # Same spelling convention as the original FULL / FULL OUTER check.
    right_join_types = [Token("RIGHT"), Token("RIGHT OUTER")]
    full_join_types = [Token("FULL"), Token("FULL OUTER")]

    # pid_source tracks, across successive joins, where the pid of the
    # joined rows comes from. It depends on the join type, on the criteria,
    # and on whether the tables are public.
    # pid_source[0]:
    #   - "" : unqualified pid (pid columns merged by JOIN ... USING(pid));
    #   - str: pid of that table/alias (e.g. t1 LEFT JOIN t2 ON ...);
    #   - list of str: first non-null pid among these tables
    #     (t1 FULL JOIN t2 ON (t1.pid = t2.pid)).
    # pid_source[1]: True when the pid source is a public table, which
    # matters for multiple nested joins.
    pid_source: t.Union[t.Tuple[t.List[str], bool], t.Tuple[str, bool]]
    tables_names = []
    for rel in rels_with_joins:
        joined_relations = [rel.primary] + [join.right for join in rel.joins]
        if not all(isinstance(tab, Table) for tab in joined_relations):
            raise NotImplementedError()
        primary_name = (
            rel.primary.alias if rel.primary.alias else rel.primary.name
        )
        tables_names.append(primary_name)
        pid_source = (primary_name, rel.primary.is_public())
        for join in rel.joins:
            right_alias_or_name = (
                join.right.alias if join.right.alias else join.right.name
            )
            tables_names.append(right_alias_or_name)
            right_is_public = join.right.is_public()
            if not pid_source[-1] and not right_is_public:
                # Both sides are private: the join must also match on pid.
                if join.criteria is None and join.joinType == Token("CROSS"):
                    join.criteria = BooleanJoinCriteria(
                        NestedBoolean(
                            BooleanCompare(
                                Literal(True), Op("="), Literal(True)
                            )
                        )
                    )
                    join.joinType = Token("")
                if isinstance(join.criteria, BooleanJoinCriteria):
                    # The left side may be "" (after a USING join) or a list
                    # (after a FULL join): build a valid column expression
                    # for it instead of interpolating it into a name.
                    left_col = _pid_column_for_source(pid_source[0])
                    right_col = Column(f"{right_alias_or_name}.{PU_COLUMN}")
                    pup_bool = NestedBoolean(
                        BooleanCompare(left_col, Op("="), right_col)
                    )
                    existing_bool = copy.deepcopy(join.criteria.expression)
                    outer_bool = NestedBoolean(
                        BooleanCompare(
                            left=pup_bool, op=Op("AND"), right=existing_bool
                        )
                    )
                    join.criteria = BooleanJoinCriteria(outer_bool)
                elif isinstance(join.criteria, UsingJoinCriteria):
                    join.criteria = UsingJoinCriteria(
                        list(join.criteria.identifiers) + [Column(PU_COLUMN)]
                    )
                    pid_source = ("", False)
                # With USING, both pid columns are merged into the single
                # unqualified one whatever the join type (RIGHT or FULL
                # included), so pid_source stays "".
                if not isinstance(join.criteria, UsingJoinCriteria):
                    if join.joinType in right_join_types:
                        pid_source = (right_alias_or_name, False)
                    elif join.joinType in full_join_types:
                        previous = pid_source[0]
                        sources = (
                            [previous]
                            if isinstance(previous, str)
                            else list(previous)
                        )
                        pid_source = (sources + [right_alias_or_name], False)
            elif not (pid_source[-1] and right_is_public):
                # Exactly one side is private: keep the pid of that side.
                pid_source = (
                    pid_source[0] if right_is_public else right_alias_or_name,
                    False,
                )
            # Both sides public: pid_source is left unchanged.

    # The order in which admin columns are added matters.
    named_expressions = []
    for col in list_of_admin_cols:
        if col == PU_COLUMN:
            named_expressions.append(pid_named_named_expr(pid_source[0]))
        elif col == PUBLIC:
            named_expressions.append(public_column_named_expr(tables_names))
        elif col == WEIGHTS:
            named_expressions.append(weight_column_named_expr(tables_names))
        else:
            raise NotImplementedError()
    existing_select = list(new_query.select.namedExpressions)
    existing_select.extend(named_expressions)
    new_query.select.namedExpressions = Seq(existing_select)
    return new_query
def public_column_named_expr(joined_tables: t.List[str]) -> NamedExpression:
    """Create the NamedExpression for the PUBLIC column of a join.

    A joined row is public only if it is public in every joined table. A
    NULL value of the PUBLIC column counts as public. Conceptually this is:

        COALESCE(t1.PUBLIC, TRUE) AND COALESCE(t2.PUBLIC, TRUE) AND ...

    The grammar does not allow a boolean expression as a select item, so the
    following equivalent CASE expression is built instead:

        CASE WHEN
            COALESCE(CAST(t1.PUBLIC AS boolean), TRUE) = TRUE AND
            COALESCE(CAST(t2.PUBLIC AS boolean), TRUE) = TRUE AND ...
        THEN TRUE
        ELSE FALSE
        END

    Args:
        joined_tables: names or aliases of the joined tables from which the
            PUBLIC column is read. Must not be empty.

    Returns:
        NamedExpression aliased PUBLIC.

    Raises:
        ValueError: if ``joined_tables`` is empty.
    """
    # Explicit check rather than an assert, which is stripped under -O.
    if not joined_tables:
        raise ValueError("joined_tables must contain at least one table.")
    predicates = [
        BooleanCompare(
            left=CoalesceFunction(
                [
                    CastFunction(
                        expression=TableColumn(
                            tablename=table_name, colname=PUBLIC
                        ),
                        dbtype="boolean",
                    ),
                    Literal(True),
                ]
            ),
            op=Op("="),
            right=Literal(True),
        )
        for table_name in joined_tables
    ]
    # Chain the predicates with AND, left-associatively.
    condition = predicates[0]
    for predicate in predicates[1:]:
        condition = BooleanCompare(condition, Op("AND"), predicate)
    case = CaseExpression(
        None, [WhenExpression(condition, Literal(True))], Literal(False)
    )
    return NamedExpression(
        name=Identifier(PUBLIC),
        expression=case,
    )
def weight_column_named_expr(joined_tables: t.List[str]) -> NamedExpression:
    """Create the NamedExpression for the WEIGHTS column of a join.

    The weight of a joined row is the product of the weights read from each
    joined table. A NULL weight counts as 1.0:

        COALESCE(t1.WEIGHTS, 1.0) * COALESCE(t2.WEIGHTS, 1.0) * ...

    Args:
        joined_tables: names or aliases of the joined tables from which the
            WEIGHTS column is read. Must not be empty.

    Returns:
        NamedExpression aliased WEIGHTS.

    Raises:
        ValueError: if ``joined_tables`` is empty.
    """
    # Explicit check rather than an assert, which is stripped under -O.
    if not joined_tables:
        raise ValueError("joined_tables must contain at least one table.")
    factors = [
        CoalesceFunction(
            [
                TableColumn(tablename=table_name, colname=WEIGHTS),
                Literal("1.0"),
            ]
        )
        for table_name in joined_tables
    ]
    # Multiply the factors left-associatively.
    product = factors[0]
    for factor in factors[1:]:
        product = ArithmeticExpression(left=product, op=Op("*"), right=factor)
    return NamedExpression(
        name=Identifier(WEIGHTS),
        expression=product,
    )
def pid_named_named_expr(
    source_table: t.Union[t.List[str], str],
) -> NamedExpression:
    """Create the NamedExpression for the pid column when joins are involved.

    Args:
        source_table (t.Union[t.List[str], str]): where the pid is taken
            from.

            * A non-empty string is a table name (or alias): the pid is
              read from that table.
            * An empty string means the unqualified pid column.
            * A list means several tables: a CASE expression picks the first
              non-null pid in list order, with the last table as fallback.
              Empty strings inside the list also stand for the unqualified
              column. A single-element list is treated like its element.

    Returns:
        NamedExpression aliased PU_COLUMN.

    Raises:
        ValueError: if ``source_table`` is an empty list.
    """
    if isinstance(source_table, list):
        if not source_table:
            raise ValueError("source_table must contain at least one table.")
        if len(source_table) == 1:
            return pid_named_named_expr(source_table[0])
        # An empty source (pid merged by JOIN ... USING) must give the
        # unqualified column, not ".<PU_COLUMN>".
        columns = [
            Column(f"{source}.{PU_COLUMN}") if source else Column(PU_COLUMN)
            for source in source_table
        ]
        return NamedExpression(
            name=Identifier(PU_COLUMN),
            expression=case_for_not_null(columns[:-1], columns[-1]),
        )
    if not source_table:
        return NamedExpression(
            name=Identifier(PU_COLUMN),
            expression=Column(PU_COLUMN),
        )
    return NamedExpression(
        name=Identifier(PU_COLUMN),
        expression=TableColumn(tablename=source_table, colname=PU_COLUMN),
    )
def case_for_not_null(
    when_columns: t.List[Column], else_column: Column
) -> CaseExpression:
    """Build a CASE returning the first of ``when_columns`` that is not NULL.

    Each column yields one ``WHEN <col> IS NOT NULL THEN <col>`` branch, in
    list order. ``else_column`` is returned when every branch fails.

    Args:
        when_columns (t.List[Column]): columns tested, in order.
        else_column (Column): fallback column.

    Returns:
        CaseExpression: e.g. with when_columns = ["t1"."pid"] and
        else_column = "t2"."pid":

            CASE WHEN "t1"."pid" IS NOT NULL THEN "t1"."pid" ELSE "t2"."pid" END
    """
    branches = []
    for column in when_columns:
        not_null = PredicatedExpression(
            column,
            IsCondition(Literal("NULL"), is_not=True),
        )
        branches.append(WhenExpression(not_null, column))
    return CaseExpression(None, branches, else_column)