Why Gemfury? Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Debian packages RPM packages NuGet packages

Repository URL to install this package:

Details    
sarus_sql / sarus_sql / ast_utils / query_analysis.py
Size: Mime:
"""Set of functions for extracting info from the parsed queries (AST)"""

import typing as t

from snsql._ast.ast import (
    Aggregate,
    AliasedRelation,
    AliasedSubquery,
    AllColumns,
    Column,
    Identifier,
    Join,
    Query,
    Relation,
    Sql,
    Table,
    TableColumn,
)

from sarus_sql.ast_transform import split_quote


# Deprecated: use get_toplevel_name instead.
def get_table_name(query: Query) -> str:
    """Return the name of the single table referenced by the query.

    Surrounding double quotes are stripped from the table names before they
    are compared. Subqueries and joins are accepted as long as every
    ``Table`` node carries the same name.

    Raises:
        ValueError: if the query does not reference exactly one distinct
            table name, or if a two-part column name (``prefix.column``)
            uses a prefix that is not one of the table aliases.
    """
    error_msg = (
        "Subqueries and JOIN are now supported but for one table only. "
        "From clauses must be of the form 'FROM <table_name>' or 'FROM "
        "<table_name> AS <alias>' or 'FROM (FROM <table_name>)'"
    )
    distinct_names = {
        str(node.name.strip('"')) for node in query.find_nodes(Table)
    }
    if len(distinct_names) != 1:
        raise ValueError(error_msg)

    known_aliases = [node.alias for node in query.find_nodes(Table)]
    for column in query.find_nodes(Column):
        parts = column.name.split(".")
        # Only two-part names whose prefix is not a known alias are rejected.
        if len(parts) != 2 or parts[0] in known_aliases:
            continue
        raise ValueError(
            "Please alias the table and use the alias instead"
            " of the table name."
            f"\n{column.name} --> table_alias.{parts[-1]}"
        )

    # Exactly one name is left at this point.
    (table_name,) = distinct_names
    return table_name


def get_toplevel_name(query: Query) -> str:
    """Return the top-level name (slug name) referenced by the query.

    Examples:
        ``SELECT ... FROM D.S.T`` gives ``D``
        ``SELECT ... FROM S.T`` gives ``S``

    The name is taken from the first ``Table`` node returned by
    ``query.find_nodes``. With JOINs this is presumably the left-most table
    (NOTE(review): confirm the traversal order of ``find_nodes``). A JOIN
    between several different datasets is not supported: only one name is
    returned.

    It is used in PLL to retrieve the name of the dataset queried by the
    user. Today the top-level name (the slug name) cannot contain "."
    characters; otherwise this would not work.

    Raises:
        IndexError: if the query references no table.
    """
    tables = query.find_nodes(Table)
    # Use the first table instead of an arbitrary element of a set: the
    # iteration order of a set of strings can change between interpreter
    # runs, which made the result unpredictable when several top-level
    # names were present.
    return str(split_quote(str(tables[0].name), "")[0])


def only_grouping_cols(expression: Sql, agg: t.Optional[Aggregate]) -> bool:
    """Tell whether every column of the expression is a GROUP BY column.

    Returns False when ``agg`` is falsy or when the expression contains no
    column at all; otherwise returns True only if each ``Column`` node found
    in the expression belongs to ``agg.groupedColumns()``.
    """
    if not agg:
        return False

    found_column = False
    all_grouped = True
    for column in expression.find_nodes(Column):
        found_column = True
        if column not in agg.groupedColumns():
            all_grouped = False
    return found_column and all_grouped


def has_joins(query: Query) -> bool:
    """Return True when the query contains at least one ``Join`` node,
    False otherwise."""
    return bool(query.find_nodes(Join))


def find_key_col_from_relation(relation: Relation) -> t.List[TableColumn]:
    """Find the private key column(s) of a relation.

    The relation is resolved according to its shape:

    * with joins: the primary and the right side of every join are
      collected, and the search continues on the only one for which
      ``is_public()`` is false;
    * aliased subquery: the key column of the subquery is returned;
    * aliased relation: the search continues on the wrapped relation;
    * table: the ``TableColumn`` symbols flagged ``is_key`` are returned.

    Returns:
        The list of key columns found (it may be empty for a table).

    Raises:
        ValueError: if more than one joined table is protected, or if the
            primary of the relation is of an unsupported kind.
        IndexError: if the relation has joins but no protected table.
    """
    if relation.joins:
        tables = [relation.primary] + [j.right for j in relation.joins]
        protected_tables = [table for table in tables if not table.is_public()]
        if len(protected_tables) > 1:
            raise ValueError(
                "Cannot join more than one protected tables if "
                "row_privacy = False"
            )
        # Continue on the single protected table, wrapped as a relation
        # without joins.
        return find_key_col_from_relation(Relation(protected_tables[0], None))

    primary = relation.primary

    if isinstance(primary, AliasedSubquery):
        return [find_key_col(primary.query)]

    if isinstance(primary, AliasedRelation):
        # The wrapped relation is carried by the AliasedRelation node (the
        # primary), not by the enclosing Relation, which only exposes
        # ``primary`` and ``joins``.
        return find_key_col_from_relation(primary.relation)

    if isinstance(primary, Table):
        rsyms = primary.all_symbols(AllColumns())
        tcsyms = [
            r.expression
            for r in rsyms
            if isinstance(r.expression, TableColumn)
        ]
        return [tc for tc in tcsyms if tc.is_key]

    raise ValueError(f"Cannot find private column in : {relation}")


def find_key_col(query: Query) -> TableColumn:
    """Return the single private key column of the query.

    Only the first relation of ``query.source`` is inspected, and it must be
    the only one.

    Raises:
        ValueError: if the query has no symbols loaded, if its source holds
            more than one relation, or if the number of key columns found
            is not exactly one.
    """
    if not query.has_symbols():
        raise ValueError(
            "Cannot find the private id of a column "
            "if the metadata are not loaded"
        )

    source_relations = query.source.relations
    if len(source_relations) > 1:
        raise ValueError(
            "row_privacy = False with multitable not implemented "
            "for the moment. Cannot load private column."
        )

    candidates = find_key_col_from_relation(source_relations[0])
    key_count = len(candidates)
    if key_count == 1:
        return candidates[0]
    if key_count > 1:
        raise ValueError(
            f"We only know how to handle tables with one key: {candidates}"
        )
    # No key column at all.
    raise ValueError(
        "row_privacy = False with multitable of joins not implemented "
        "for the moment. Cannot load private column."
    )


def find_table(relation: Relation) -> Table:
    """Walk down the first relation of nested queries until a table is
    reached, and return that table."""
    node = relation.primary
    while not isinstance(node, Table):
        # Step into the query held by the node and follow its first relation.
        node = node.query.source.relations[0].primary
    return node


def find_subquery(
    query: Query, alias: Identifier
) -> t.Optional[AliasedSubquery]:
    """Find the outermost subquery whose alias equals ``alias``.

    The search starts at the primary of the first relation of the query and
    follows the chain of first relations through nested aliased subqueries.

    Args:
        query: the query to search in.
        alias: the alias to look for.

    Returns:
        The first matching node met along that chain, or None when the chain
        ends without a match. The top-level primary is returned on an alias
        match whatever its kind.
    """
    candidate = query.source.relations[0].primary
    if candidate.alias == alias:
        return candidate

    # Only aliased subqueries hold a nested query to descend into; anything
    # else (for instance a plain table) ends the search instead of failing
    # on a missing ``query`` attribute.
    while isinstance(candidate, AliasedSubquery):
        inner = candidate.query.source.relations[0]
        if not isinstance(inner, Relation):
            return None
        candidate = inner.primary
        if isinstance(candidate, AliasedSubquery) and candidate.alias == alias:
            return candidate
    return None