Repository URL to install this package:
|
Version:
3.0.0.dev3 ▾
|
"""Set of functions for extracting info from the parsed queries (AST)"""
import typing as t
from snsql._ast.ast import (
Aggregate,
AliasedRelation,
AliasedSubquery,
AllColumns,
Column,
Identifier,
Join,
Query,
Relation,
Sql,
Table,
TableColumn,
)
from sarus_sql.ast_transform import split_quote
# Deprecated: use get_toplevel_name instead.
def get_table_name(query: Query) -> str:
    """Return the single table name referenced by ``query``.

    Deprecated: prefer ``get_toplevel_name``.

    Double quotes are stripped from each table name before the distinct
    names are counted, so ``"t"`` and ``t`` count as the same table.

    Raises:
        ValueError: if the query references zero or several distinct table
            names, or if a two-part column reference ``prefix.column`` uses a
            prefix that is not one of the table aliases.
    """
    table_nodes = list(query.find_nodes(Table))
    distinct_names = {str(node.name.strip('"')) for node in table_nodes}
    if len(distinct_names) != 1:
        raise ValueError(
            "Subqueries and JOIN are now supported but for one table only. "
            "From clauses must be of the form 'FROM <table_name>' or 'FROM "
            "<table_name> AS <alias>' or 'FROM (FROM <table_name>)'"
        )
    (table_name,) = distinct_names

    known_aliases = [node.alias for node in table_nodes]
    for column in query.find_nodes(Column):
        prefix, *rest = column.name.split(".")
        # A qualified reference "x.col" must use an alias, not the raw name.
        if len(rest) == 1 and prefix not in known_aliases:
            raise ValueError(
                "Please alias the table and use the alias instead"
                " of the table name."
                f"\n{column.name} --> table_alias.{rest[0]}"
            )
    return table_name
def get_toplevel_name(query: Query) -> str:
    """Return the top-level name (slugname) of the dataset being queried.

    Examples::

        SELECT ... FROM D.S.T;  -> "D"
        SELECT ... FROM S.T;    -> "S"

    The result is the first component (as split by ``split_quote``) of the
    first table returned by ``query.find_nodes(Table)``. With JOINs this is
    intended to be the left side; that depends on the traversal order of
    ``find_nodes`` — NOTE(review): confirm. If a JOIN spans several datasets,
    only one of their names is returned. The result is deterministic: a
    previous version picked an arbitrary element of a ``set``, whose order
    changes between processes because string hashing is randomized.

    It is used in PLL to retrieve the name of the dataset queried by the
    user. It assumes the top-level name never contains "." characters.

    Raises:
        IndexError: if the query does not reference any table.
    """
    first_table = next(iter(query.find_nodes(Table)), None)
    if first_table is None:
        # IndexError kept for backward compatibility with the former
        # ``ds_list[0]`` failure on an empty list.
        raise IndexError("get_toplevel_name: the query does not reference any table")
    return str(split_quote(str(first_table.name), "")[0])
def only_grouping_cols(expression: Sql, agg: t.Optional[Aggregate]) -> bool:
    """Tell whether ``expression`` references only GROUP BY columns.

    Args:
        expression: SQL expression whose ``Column`` nodes are inspected.
        agg: the aggregation clause of the query, or None.

    Returns:
        True if ``expression`` contains at least one ``Column`` and every one
        of them is in ``agg.groupedColumns()``. False otherwise, including
        when ``agg`` is falsy or the expression contains no column.
    """
    if not agg:
        return False
    columns = list(expression.find_nodes(Column))
    if not columns:
        return False
    # groupedColumns() does not depend on the column being tested: compute
    # it once instead of once per column.
    grouped = agg.groupedColumns()
    return all(col in grouped for col in columns)
def has_joins(query: Query) -> bool:
    """Tell whether ``query`` contains at least one JOIN.

    Returns:
        True when ``query.find_nodes(Join)`` is non-empty (truthy), else False.
    """
    join_nodes = query.find_nodes(Join)
    return bool(join_nodes)
def find_key_col_from_relation(relation: Relation) -> t.List[TableColumn]:
    """Find the private key column(s) of a relation.

    The search recurses through the relation:

    * with JOINs, it is restricted to the single non-public table among the
      primary and the right-hand sides of the joins;
    * an aliased subquery delegates to ``find_key_col`` on its inner query;
    * an aliased (parenthesized) relation recurses into the wrapped relation;
    * a plain table yields its ``TableColumn`` symbols flagged ``is_key``.

    Returns:
        The key columns found. The list is empty when there are none, e.g.
        every joined table is public or the table has no key column.

    Raises:
        ValueError: if more than one joined table is protected, or if the
            primary is of an unsupported kind.
    """
    if relation.joins:
        candidates = [relation.primary] + [j.right for j in relation.joins]
        protected_tables = [tbl for tbl in candidates if not tbl.is_public()]
        if len(protected_tables) > 1:
            raise ValueError(
                "Cannot join more than one protected tables if "
                "row_privacy = False"
            )
        if not protected_tables:
            # Only public tables are joined: there is no private key. The
            # former code crashed with IndexError here.
            return []
        return find_key_col_from_relation(Relation(protected_tables[0], None))
    primary = relation.primary
    if isinstance(primary, AliasedSubquery):
        return [find_key_col(primary.query)]
    if isinstance(primary, AliasedRelation):
        # The wrapped relation is an attribute of the AliasedRelation itself,
        # not of the enclosing Relation (which has no ``.relation``).
        return find_key_col_from_relation(primary.relation)
    if isinstance(primary, Table):
        symbols = primary.all_symbols(AllColumns())
        return [
            sym.expression
            for sym in symbols
            if isinstance(sym.expression, TableColumn) and sym.expression.is_key
        ]
    raise ValueError(f"Cannot find private column in : {relation}")
def find_key_col(query: Query) -> TableColumn:
    """Return the single private key column of ``query``.

    The query must have its symbols (metadata) loaded. Only the first
    relation of its FROM clause is searched, via
    ``find_key_col_from_relation``.

    Raises:
        ValueError: if the metadata are not loaded, if the FROM clause has
            more than one relation, or if the search finds zero or several
            key columns.
    """
    if not query.has_symbols():
        raise ValueError(
            "Cannot find the private id of a column "
            "if the metadata are not loaded"
        )

    from_relations = query.source.relations
    if len(from_relations) > 1:
        raise ValueError(
            "row_privacy = False with multitable not implemented "
            "for the moment. Cannot load private column."
        )

    key_columns = find_key_col_from_relation(from_relations[0])
    key_count = len(key_columns)
    if key_count == 1:
        return key_columns[0]
    if key_count > 1:
        raise ValueError(
            f"We only know how to handle tables with one key: {key_columns}"
        )
    raise ValueError(
        "row_privacy = False with multitable of joins not implemented "
        "for the moment. Cannot load private column."
    )
def find_table(relation: Relation) -> Table:
    """Return the base ``Table`` reached by following a relation's primary.

    Aliased relations are unwrapped through their ``.relation``. Any other
    non-table primary (an aliased subquery) is followed into the first
    relation of its inner query's FROM clause. Descent stops once the primary
    is a ``Table``.
    """
    while not isinstance(relation.primary, Table):
        primary = relation.primary
        if isinstance(primary, AliasedRelation):
            # An AliasedRelation has no ``.query``; the previous code raised
            # AttributeError here.
            relation = primary.relation
        else:
            relation = primary.query.source.relations[0]
    return relation.primary
def find_subquery(
    query: Query, alias: Identifier
) -> t.Optional[AliasedSubquery]:
    """Find the outermost subquery whose alias equals ``alias``, if any.

    Starting from the primary of the first FROM relation, the search
    descends through nested aliased subqueries, following the first relation
    of each inner FROM clause. The first primary whose ``alias`` compares
    equal is returned. For backward compatibility this also applies to the
    outermost primary, even when it is not a subquery.

    Returns:
        The matching primary, or None when the chain of aliased subqueries
        ends without a match.
    """
    relation = query.source.relations[0].primary
    while True:
        if relation.alias == alias:
            return relation
        if not isinstance(relation, AliasedSubquery):
            # E.g. a plain Table: there is no inner query to descend into
            # (the previous code raised AttributeError on ``.query``).
            return None
        inner = relation.query.source.relations[0]
        if not (
            isinstance(inner, Relation)
            and isinstance(inner.primary, AliasedSubquery)
        ):
            return None
        relation = inner.primary