Why Gemfury? Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Debian packages RPM packages NuGet packages

Repository URL to install this package:

Details    
sarus_sql / sarus_sql / ast_utils / rewrite.py
Size: Mime:
import copy
import typing as t

from snsql._ast.ast import (
    AliasedRelation,
    AliasedSubquery,
    BooleanCompare,
    BooleanJoinCriteria,
    CaseExpression,
    Column,
    Identifier,
    IsCondition,
    Join,
    Limit,
    Literal,
    NamedExpression,
    NestedBoolean,
    Op,
    PredicatedExpression,
    Query,
    Relation,
    Seq,
    Sql,
    Table,
    TableColumn,
    Token,
    UsingJoinCriteria,
    WhenExpression,
)
from snsql._ast.expressions.numeric import ArithmeticExpression
from snsql._ast.expressions.string import CoalesceFunction
from snsql._ast.expressions.types import CastFunction
from snsql.metadata import Metadata

import sarus_sql.ast_utils as ast_utils

# Names of the administrative columns that the helpers of this module add to
# SELECT lists.
# Privacy-unit identifier ("pid") column, forced into private-private joins.
PU_COLUMN: str = "sarus_privacy_unit"
# Column telling whether a row is public (cast to boolean when combined).
PUBLIC: str = "sarus_is_public"
# Column holding row weights (multiplied together across joined tables).
WEIGHTS: str = "sarus_weights"


def _replace_in_ast_by_expr(
    expression: t.Optional[Sql],
    old_expr: Sql,
    new_expr: Sql,
) -> None:
    """Replace, in place, every descendant of ``expression`` equal to
    ``old_expr`` by a deep copy of ``new_expr``.

    The walk goes through the instance attributes (``__dict__``) of each
    node. Attributes holding an ``Sql`` node are visited recursively, and so
    are lists of nodes. A value matches when it is an instance of
    ``type(old_expr)`` and compares equal (``==``) to ``old_expr``.

    All matching attributes of a node are replaced, not only the first one.
    List elements that match are replaced too. Non-``Sql`` list items are
    kept unchanged. A freshly inserted replacement is not visited again, so a
    ``new_expr`` that contains ``old_expr`` cannot recurse forever.

    Args:
        expression: root of the tree to rewrite; ``None`` is a no-op. The
            root node itself is never replaced, only its descendants.
        old_expr: sub-expression to look for.
        new_expr: expression to insert. Each occurrence receives its own deep
            copy, so the caller's object is never shared with the rewritten
            tree (and cannot be mutated by later rewrites of that tree).
    """
    if expression is None:
        return None

    def _matches(node: t.Any) -> bool:
        return isinstance(node, type(old_expr)) and node == old_expr

    attributes = expression.__dict__
    for attr, value in list(attributes.items()):
        if isinstance(value, list):
            new_items = []
            for item in value:
                if _matches(item):
                    new_items.append(copy.deepcopy(new_expr))
                    continue
                if isinstance(item, Sql):
                    _replace_in_ast_by_expr(item, old_expr, new_expr)
                # Non-Sql items are preserved rather than dropped.
                new_items.append(item)
            attributes[attr] = new_items
        elif _matches(value):
            attributes[attr] = copy.deepcopy(new_expr)
        elif isinstance(value, Sql):
            _replace_in_ast_by_expr(value, old_expr, new_expr)


def replace_in_ast(expression: Sql, map_exprs: t.Dict[Sql, Sql]) -> Sql:
    """Return a rewritten deep copy of ``expression``.

    Every sub-expression of the copy matching a key of ``map_exprs`` is
    replaced by the associated value. Entries are applied one after another,
    in the dictionary's iteration order. Only descendants are substituted:
    the root node itself is never swapped. ``expression`` is left untouched.
    """
    rewritten = copy.deepcopy(expression)
    for target, replacement in map_exprs.items():
        _replace_in_ast_by_expr(
            expression=rewritten, old_expr=target, new_expr=replacement
        )
    return rewritten


def replace_in_expression(
    expression: Sql,
    named_expr: NamedExpression,
) -> None:
    """Replace, in place, every occurrence of ``named_expr.expression``
    inside ``expression`` by ``Column(named_expr.name.text)``.

    The walk goes through the instance attributes (``__dict__``) of each
    node. Attributes holding an ``Sql`` node are visited recursively, and so
    are lists of nodes. A value matches when it is an instance of
    ``type(named_expr.expression)`` and compares equal to it.

    All matching attributes of a node are replaced, not only the first one.
    List elements that match are replaced too. Non-``Sql`` list items are
    kept unchanged. Each occurrence gets its own new ``Column``. The root
    ``expression`` itself is never replaced.

    Args:
        expression: tree to rewrite in place.
        named_expr: the aliased expression to substitute. Its ``name`` is
            read only when a match is found (presumably it must then be set;
            ``None`` would raise AttributeError).
    """
    target = named_expr.expression

    def _matches(node: t.Any) -> bool:
        return isinstance(node, type(target)) and node == target

    attributes = expression.__dict__
    for attr, value in list(attributes.items()):
        if isinstance(value, list):
            new_items = []
            for item in value:
                if _matches(item):
                    new_items.append(Column(named_expr.name.text))
                    continue
                if isinstance(item, Sql):
                    replace_in_expression(item, named_expr)
                # Non-Sql items are preserved rather than dropped.
                new_items.append(item)
            attributes[attr] = new_items
        elif _matches(value):
            attributes[attr] = Column(named_expr.name.text)
        elif isinstance(value, Sql):
            replace_in_expression(value, named_expr)


def substitute_tables(
    query: Query, meta: Metadata, tables: t.Optional[t.Dict[str, str]]
) -> Query:
    """Replace tables unknown to ``meta`` by subqueries defined in ``tables``.

    Every ``Table`` reachable from the FROM clause is examined: primary
    relations, join right-hand sides, and those nested in aliased relations
    and aliased subqueries. Tables declared in ``meta`` are kept as they are.
    Any other table must be a key of ``tables``. It is replaced by
    ``AliasedSubquery(<parsed definition>, <table alias or name>)``.
    Definitions are themselves substituted recursively, so they may reference
    other entries of ``tables``.

    Args:
        query: query to rewrite. It is deep-copied and not mutated.
        meta: metadata listing the known tables.
        tables: mapping from a table name to its SQL definition. When empty
            or ``None`` the original ``query`` object is returned unchanged.

    Returns:
        The rewritten query, passed through
        ``ast_utils.load_symbols_in_query``.

    Raises:
        ValueError: if a table is neither in ``meta`` nor in ``tables``.

    NOTE(review): a definition that (directly or indirectly) references its
    own name is expanded again and again until RecursionError.
    """
    if not tables:
        return query
    # Non-optional alias so the nested helpers need no extra None check.
    definitions: t.Dict[str, str] = tables

    # Computed once instead of once per table reference. A list keeps the
    # original `==`-based membership test.
    known_tables = [meta.compare.strip_escapes(tab) for tab in meta.m_tables]

    def _replace_in_primary_relation(
        primary_relation: t.Union[Table, AliasedSubquery, AliasedRelation],
    ) -> t.Union[Table, AliasedSubquery, AliasedRelation]:
        if isinstance(primary_relation, Table):
            if primary_relation.name in known_tables:
                return primary_relation

            if primary_relation.name not in definitions:
                raise ValueError(
                    f"Table {primary_relation} unknown. Got: {query}"
                )

            subquery = substitute_tables(
                ast_utils.parse_query(definitions[primary_relation.name]),
                meta,
                definitions,
            )
            if primary_relation.alias:
                alias = copy.deepcopy(primary_relation.alias)
            else:
                alias = primary_relation.name
            return AliasedSubquery(subquery, alias)

        if isinstance(primary_relation, AliasedSubquery):
            primary_relation.query = substitute_tables(
                primary_relation.query, meta, definitions
            )

        elif isinstance(primary_relation, AliasedRelation):
            primary_relation.relation = _replace_in_relation(
                primary_relation.relation
            )

        return primary_relation

    def _replace_in_relation(relation: Relation) -> Relation:
        relation.primary = _replace_in_primary_relation(relation.primary)
        for join in relation.joins:
            join.right = _replace_in_primary_relation(join.right)
        return relation

    new_query = copy.deepcopy(query)
    new_query.source.relations.list = [
        _replace_in_relation(rel) for rel in new_query.source.relations
    ]
    return ast_utils.load_symbols_in_query(new_query, meta, True)


def add_cols_to_select(query: Query, list_of_admin_cols: t.List[str]) -> Query:
    """Append administrative columns to the SELECT list of ``query``.

    Warning!! When the query contains a JOIN, the work is delegated to
    ``add_cols_and_pid_on_private_joins_key``. For a JOIN between two
    private tables, that helper also forces the join on the pid.

    Otherwise a deep copy of ``query`` is returned with ``<col> AS <col>``
    appended for every admin column, in the order given.

    If list_of_admin_cols is empty, ``query`` is returned as is. If it is not
    empty, the query is assumed to be pup.
    TODO: We need to make sure the Query is pup if list_of_admin_cols is not
    empty.
    """
    if not list_of_admin_cols:
        return query

    if query.find_nodes(Join):
        return add_cols_and_pid_on_private_joins_key(query, list_of_admin_cols)

    admin_selections = [
        NamedExpression(name=Identifier(name), expression=Column(name))
        for name in list_of_admin_cols
    ]
    result = copy.deepcopy(query)
    result.select.namedExpressions = Seq(
        list(result.select.namedExpressions) + admin_selections
    )
    return result


def add_limit(query: Query, limit: int) -> Query:
    """Return a deep copy of ``query`` whose LIMIT clause is ``limit``.

    Any LIMIT already present is overwritten in the copy. ``query`` itself is
    not modified.
    """
    limited_query = copy.deepcopy(query)
    limit_clause = Limit(limit)
    limited_query.limit = limit_clause
    return limited_query


def add_cols_and_pid_on_private_joins_key(
    query: Query, list_of_admin_cols: t.List[str]
) -> Query:
    """Given a query with JOINs and a list of admin columns, add the admin
    columns to the select statement and force the JOIN criteria to contain
    (t1.pid=t2.pid) when both t1 and t2 are private.

    For each join between two private sides:
    - a ``CROSS JOIN`` is first rewritten as ``JOIN ... ON (TRUE = TRUE)``;
    - an ``ON`` criteria becomes ``(left.pid = right.pid) AND (<criteria>)``;
    - a ``USING`` criteria gets the pid column appended to its identifiers.

    Args:
        query (Query): Query with JOINs. If list_of_admin_cols is not empty
            we assume that the Query is pup. TODO: We should check that the
            query is pup. The query must have metadata already loaded:
            ``query.is_public()`` is called first, and any ValueError it
            raises propagates to the caller.
        list_of_admin_cols (t.List[str]): admin columns to add. If empty, the
            Query is returned without any modification. Otherwise they are
            appended to the SELECT statement in the order of the list.

    Returns:
        Query: a modified deep copy of ``query``.

    Raises:
        NotImplementedError: in any of these cases:
            - a joined relation is not a plain ``Table``;
            - an admin column other than PU_COLUMN, PUBLIC or WEIGHTS is
              requested (each of them needs a different select expression);
            - a private ``ON`` join follows a private ``FULL JOIN`` or
              ``JOIN ... USING``, because the left pid is then not attached
              to a single table.

    The function asserts that ``query`` contains at least one Join.
    """
    assert len(query.find_nodes(Join)) != 0
    if not list_of_admin_cols:
        return query

    # Make sure the query has its symbols loaded and that `is_public` is
    # available; a ValueError raised here propagates unchanged.
    query.is_public()

    new_query = copy.deepcopy(query)
    rels_with_joins = [
        rel for rel in new_query.find_nodes(Relation) if rel.joins
    ]

    # The pid added to the SELECT must come from a specific source table. The
    # choice depends on the join type, the join criteria and whether the
    # tables are public. `pid_source` carries that choice across successive
    # joins:
    # - pid_source[0] == "": select the unqualified pid
    #   (t1 JOIN t2 USING (pid));
    # - pid_source[0] is a list of names: select a CASE over their pids
    #   (t1 FULL JOIN t2 ON (t1.pid = t2.pid));
    # - pid_source[0] is a name: select that table's pid, chosen according
    #   to the join type (t1 LEFT JOIN t2 ON (t1.pid = t2.pid) -> t1.pid).
    # pid_source[1] tells whether that source is public, which matters for
    # chains of joins.
    # NOTE(review): when several relations have joins (e.g. nested
    # subqueries), the pid source of the last one visited is used.
    pid_source: t.Union[t.Tuple[t.List[str], bool], t.Tuple[str, bool]]

    # "RIGHT OUTER" / "FULL OUTER" are the same joins as "RIGHT" / "FULL".
    right_join_types = [Token("RIGHT"), Token("RIGHT OUTER")]
    full_join_types = [Token("FULL"), Token("FULL OUTER")]

    tables_names = []
    for rel in rels_with_joins:
        if not all(
            isinstance(tab, Table)
            for tab in [rel.primary] + [join.right for join in rel.joins]
        ):
            raise NotImplementedError()
        primary_name = (
            rel.primary.alias if rel.primary.alias else rel.primary.name
        )
        tables_names.append(primary_name)
        pid_source = (primary_name, rel.primary.is_public())

        for join in rel.joins:
            right_alias_or_name = (
                join.right.alias if join.right.alias else join.right.name
            )
            tables_names.append(right_alias_or_name)
            if not pid_source[-1] and not join.right.is_public():
                # Both the left side and the joined table are private.
                if join.criteria is None and join.joinType == Token("CROSS"):
                    # Turn the CROSS JOIN into `JOIN ... ON (TRUE = TRUE)` so
                    # that the pid equality can be AND-ed to it below.
                    join.criteria = BooleanJoinCriteria(
                        NestedBoolean(
                            BooleanCompare(
                                Literal(True), Op("="), Literal(True)
                            )
                        )
                    )
                    join.joinType = Token("")

                is_using = isinstance(join.criteria, UsingJoinCriteria)
                if isinstance(join.criteria, BooleanJoinCriteria):
                    left_alias_or_name = pid_source[0]
                    if isinstance(left_alias_or_name, list) or (
                        isinstance(left_alias_or_name, str)
                        and not left_alias_or_name
                    ):
                        # Building "<left>.pid" here would produce an invalid
                        # column name (from a list or from "").
                        raise NotImplementedError(
                            "Cannot force the privacy unit in an ON criteria "
                            "following a private FULL JOIN or JOIN ... USING."
                        )
                    left_col = Column(f"{left_alias_or_name}.{PU_COLUMN}")
                    right_col = Column(f"{right_alias_or_name}.{PU_COLUMN}")
                    pup_bool = NestedBoolean(
                        BooleanCompare(left_col, Op("="), right_col)
                    )
                    existing_bool = copy.deepcopy(join.criteria.expression)
                    outer_bool = NestedBoolean(
                        BooleanCompare(
                            left=pup_bool, op=Op("AND"), right=existing_bool
                        )
                    )
                    join.criteria = BooleanJoinCriteria(outer_bool)

                elif is_using:
                    idents_with_pid = list(join.criteria.identifiers) + [
                        Column(PU_COLUMN)
                    ]
                    join.criteria = UsingJoinCriteria(idents_with_pid)
                    pid_source = ("", False)

                # With USING, the unqualified pid is the merged join column
                # for every join type. In standard SQL / PostgreSQL, a FULL
                # JOIN merges it as COALESCE(left, right). So only ON-style
                # joins need a table-specific source.
                if not is_using:
                    if join.joinType in right_join_types:
                        pid_source = (right_alias_or_name, False)
                    elif join.joinType in full_join_types:
                        source_list = (
                            [pid_source[0]]
                            if isinstance(pid_source[0], str)
                            else pid_source[0]
                        )
                        source_list.append(right_alias_or_name)
                        pid_source = (source_list, False)
            else:
                # At least one side is public: take the pid from the private
                # side if there is one; otherwise the source stays public.
                if not all([pid_source[-1], join.right.is_public()]):
                    pid_source = (
                        right_alias_or_name
                        if not join.right.is_public()
                        else pid_source[0],
                        False,
                    )

    # The order in which admin columns are added matters.
    named_expressions = []
    for col in list_of_admin_cols:
        if col == PU_COLUMN:
            named_expressions.append(pid_named_named_expr(pid_source[0]))
        elif col == PUBLIC:
            named_expressions.append(public_column_named_expr(tables_names))
        elif col == WEIGHTS:
            named_expressions.append(weight_column_named_expr(tables_names))
        else:
            raise NotImplementedError(f"Unsupported admin column: {col}")

    existing_select = list(new_query.select.namedExpressions)
    existing_select.extend(named_expressions)
    new_query.select.namedExpressions = Seq(existing_select)
    return new_query


def public_column_named_expr(joined_tables: t.List[str]) -> NamedExpression:
    """Build the NamedExpression for the PUBLIC column of a query with joins.

    Args:
        joined_tables (t.List[str]): names or aliases of the joined tables
            from which the PUBLIC column is read. It must have at least 2
            items.

    Returns:
        NamedExpression: the CASE expression below, aliased as PUBLIC.

    We would like to select
    COALESCE(t1.PUBLIC, TRUE) AND COALESCE(t2.PUBLIC, TRUE) AND ...
    but the grammar doesn't allow a BooleanCompare in the select. The
    equivalent CASE expression built instead is:

    CASE WHEN
        COALESCE(CAST(t1.PUBLIC AS boolean), TRUE) = TRUE AND
        COALESCE(CAST(t2.PUBLIC AS boolean), TRUE) = TRUE AND ...
        THEN TRUE
        ELSE FALSE
    END

    A NULL PUBLIC value (e.g. from an outer join) counts as TRUE. Since the
    COALESCE never yields NULL, ``= TRUE`` behaves like ``IS TRUE`` here.
    """
    assert len(joined_tables) > 1
    coalesces = [
        CoalesceFunction(
            [
                CastFunction(
                    expression=TableColumn(
                        tablename=table_name, colname=PUBLIC
                    ),
                    dbtype="boolean",
                ),
                Literal(True),
            ]
        )
        for table_name in joined_tables
    ]

    predicates = [
        BooleanCompare(left=coalesce, op=Op("="), right=Literal(True))
        for coalesce in coalesces
    ]

    # Chain the predicates with AND (left-nested: ((p1 AND p2) AND p3) ...).
    bool_compares = predicates[0]
    for predicate in predicates[1:]:
        bool_compares = BooleanCompare(bool_compares, Op("AND"), predicate)

    when_then_expr = WhenExpression(bool_compares, Literal(True))
    case = CaseExpression(None, [when_then_expr], Literal(False))
    return NamedExpression(
        name=Identifier(PUBLIC),
        expression=case,
    )


def weight_column_named_expr(joined_tables: t.List[str]) -> NamedExpression:
    """Build the NamedExpression for the WEIGHTS column of a query with
    joins:

    COALESCE(t1.WEIGHTS, 1.0) * COALESCE(t2.WEIGHTS, 1.0) * ...

    A missing (NULL) weight therefore counts as 1.0. The products are
    left-nested: ((w1 * w2) * w3) ...

    NOTE(review): the default is built as ``Literal("1.0")`` (a string
    value), presumably rendered as the numeric literal 1.0. Confirm against
    snsql's Literal.

    Args:
        joined_tables (t.List[str]): names or aliases of the joined tables
            from which the WEIGHTS column is read. It must have at least 2
            items.

    Returns:
        NamedExpression: the product above, aliased as WEIGHTS.
    """
    assert len(joined_tables) > 1
    coalesces = [
        CoalesceFunction(
            [
                TableColumn(tablename=table_name, colname=WEIGHTS),
                Literal("1.0"),
            ]
        )
        for table_name in joined_tables
    ]
    arithmetic_expr = coalesces[0]
    for coalesce in coalesces[1:]:
        arithmetic_expr = ArithmeticExpression(
            left=arithmetic_expr, op=Op("*"), right=coalesce
        )

    return NamedExpression(
        name=Identifier(WEIGHTS),
        expression=arithmetic_expr,
    )


def pid_named_named_expr(
    source_table: t.Union[t.List[str], str],
) -> NamedExpression:
    """Build the NamedExpression selecting the pid column of a joined query.

    Args:
        source_table (t.Union[t.List[str], str]): where the pid is read from.
            - A list of at least two table names (or aliases) yields a CASE
              that takes the first non-NULL ``<table>.pid`` in list order and
              falls back to the last table's pid.
            - An empty string yields the unqualified pid column.
            - Any other name yields ``<name>.pid``.

    Returns:
        NamedExpression: the chosen expression, aliased as PU_COLUMN.
    """
    if isinstance(source_table, list):
        assert len(source_table) > 1
        *leading_sources, fallback_source = source_table
        pid_expression = case_for_not_null(
            [Column(f"{name}.{PU_COLUMN}") for name in leading_sources],
            Column(f"{fallback_source}.{PU_COLUMN}"),
        )
    elif source_table:
        pid_expression = TableColumn(tablename=source_table, colname=PU_COLUMN)
    else:
        pid_expression = Column(PU_COLUMN)

    return NamedExpression(
        name=Identifier(PU_COLUMN),
        expression=pid_expression,
    )


def case_for_not_null(
    when_columns: t.List[Column], else_column: Column
) -> CaseExpression:
    """Build a CASE returning the first non-NULL column of ``when_columns``,
    in order, or ``else_column`` when all of them are NULL.

    Args:
        when_columns (t.List[Column]): columns tested with ``IS NOT NULL``.
        else_column (Column): value of the ELSE branch.

    Returns:
        CaseExpression: e.g. for when_columns = ["t1"."pid"] and
        else_column = "t2"."pid":
        CASE WHEN "t1"."pid" IS NOT NULL THEN "t1"."pid" ELSE "t2"."pid" END
    """

    def _when_not_null(column: Column) -> WhenExpression:
        is_not_null = PredicatedExpression(
            column,
            IsCondition(Literal("NULL"), is_not=True),
        )
        return WhenExpression(is_not_null, column)

    when_branches = list(map(_when_not_null, when_columns))
    return CaseExpression(None, when_branches, else_column)