# Repository URL to install this package:
# Version: 0.2.102
from .tokens import *
from .expression import *
import copy
import warnings
"""
AST for parsed Python Query Batch. Allows validation, normalization,
rewriting, and serialization. Grammar is a strict subset of SQL-92.
Lexer and parser token names borrowed from SparkSQL Grammar.
"""
class Batch(Sql):
    """A batch of queries.

    The children of a batch are exactly its queries, in their original order.
    """

    def __init__(self, queries: List["Query"]) -> None:
        # stored as given; no copy is made
        self.queries = queries

    def children(self):
        """Return the queries held by this batch."""
        batch_queries = self.queries
        return batch_queries
class Query(SqlRel):
    """A single query.

    Holds the clauses of one SELECT statement (SELECT, FROM, WHERE, GROUP BY,
    HAVING, ORDER BY, LIMIT). When ``metadata`` is supplied, symbols are
    resolved immediately through ``load_symbols``.
    """

    def __init__(self, select, source, where, agg, having, order, limit, metadata=None, privacy=None) -> None:
        self.select = select
        self.source = source
        self.where = where
        self.agg = agg
        self.having = having
        self.order = order
        self.limit = limit
        # Table-level privacy settings, aggregated by load_symbols from the
        # TableColumn metadata. None means "not loaded yet" (key_column relies
        # on row_privacy being None in that case).
        self.max_ids = None
        self.sample_max_ids = None
        self.row_privacy = None
        # Initialized with the other table settings so load_symbols can read it
        # even when no Table metadata was collected (previously AttributeError).
        self.censor_dims = None
        self._named_symbols = None
        self._select_symbols = None
        if metadata:
            self.load_symbols(metadata, privacy=privacy)

    def load_symbols(self, metadata, privacy=None, ctes: "Optional[Dict[str, Query]]" = None):
        """Resolve symbols for this query and everything it references.

        :param metadata: metadata used to resolve tables; ``metadata.compare``
            is used to match GROUP BY names against SELECT aliases.
        :param privacy: optional privacy settings exposing ``mechanisms``,
            ``epsilon`` and ``delta``; when set, a noise mechanism is attached
            to each eligible output symbol.
        :param ctes: common table expressions visible to this query, keyed by
            name. Not modified: a deep copy is extended with this query's own
            WITH clause.
        :raises ValueError: if a WITH query name is specified more than once.
            Errors raised during symbol resolution propagate.
        """
        self.privacy = privacy
        # Work on a copy so CTEs defined here do not leak into the caller's mapping.
        new_ctes = copy.deepcopy(ctes) if ctes is not None else {}

        # Handle common table expressions (CTEs); each CTE is loaded with the
        # CTEs defined before it.
        if self.select.ctes is not None:
            for cte in self.select.ctes:
                if str(cte.name) in new_ctes:
                    raise ValueError(
                        f'WITH query name "{str(cte.name)}" specified more than once'
                    )
                cte_query = cte.subquery
                cte_query.load_symbols(metadata, privacy=privacy, ctes=new_ctes)
                new_ctes[str(cte.name)] = cte_query

        # recursively load symbols for all relations
        if self.source is None:
            return
        relations = self.source.relations
        for r in relations:
            r.load_symbols(metadata, privacy=privacy, ctes=new_ctes)
        if not all([r.has_symbols() for r in relations]):
            return  # unable to load symbols

        # load symbols in subqueries in the SELECT clause
        for ne in self.select.namedExpressions:
            if isinstance(ne.expression, (AliasedSubquery, Query)):
                ne.expression.load_symbols(metadata, privacy=privacy, ctes=new_ctes)

        # Grab the first column of every Table node found by find_nodes; its
        # TableColumn carries the table-level privacy settings.
        tables = []
        for t in self.find_nodes(Table):
            if t.has_symbols():
                tables.append(t._select_symbols[0].expression)
        if len(tables) > 0:
            self.max_ids = max(
                (tc.max_ids if isinstance(tc, TableColumn) else 1)
                for tc in tables
            )
            self.sample_max_ids = any(
                (tc.sample_max_ids if isinstance(tc, TableColumn) else False)
                for tc in tables
            )
            self.row_privacy = any(
                (tc.row_privacy if isinstance(tc, TableColumn) else False)
                for tc in tables
            )
            self.censor_dims = any(
                (tc.censor_dims if isinstance(tc, TableColumn) else False)
                for tc in tables
            )

        nexprs = self.select.namedExpressions

        # get grouping expression symbols
        self._grouping_symbols = []
        if self.agg:
            for ge in self.agg.groupingExpressions:
                try:
                    symb = ge.expression.symbol(relations)
                except ValueError as err:
                    # GROUP BY may name a column by its SELECT alias; fall back
                    # to that expression when exactly one alias matches.
                    if not isinstance(ge.expression, Column):
                        raise err
                    expr = [
                        ne.expression for ne in nexprs
                        if ne.name and metadata.compare.identifier_match(ge.expression.name, ne.name)
                    ]
                    if len(expr) != 1:
                        raise err
                    symb = expr[0].symbol(relations)
                self._grouping_symbols.append(Symbol(symb))

        # get namedExpression symbols
        _symbols = []
        for ne in nexprs:
            if not isinstance(ne.expression, AllColumns):
                name = ne.column_name()
                _symbol_expr = ne.expression.symbol(relations)
                _symbol = Symbol(_symbol_expr, name)

                # annotate selects that reference GROUP BY column
                _symbol.is_grouping_column = False
                for ge in self._grouping_symbols:
                    if _symbol_expr == ge.expression:
                        _symbol.is_grouping_column = True

                # annotate key_counts
                _symbol.is_key_count = False
                if _symbol.is_count:
                    col = _symbol.expression.xpath_first("//AggFunction[@name='COUNT']")
                    if col:
                        if self.row_privacy:
                            # with row privacy, only COUNT(*) counts keys
                            _symbol.is_key_count = isinstance(col.expression, AllColumns)
                        else:
                            _symbol.is_key_count = col.is_key_count

                if self.privacy:
                    # add mechanism
                    _symbol.mechanism = None
                    mechanisms = self.privacy.mechanisms
                    epsilon = self.privacy.epsilon
                    delta = self.privacy.delta
                    if not _symbol.is_grouping_column:
                        sensitivity = _symbol.expression.sensitivity()
                        t = _symbol.expression.type()
                        if t in ['int', 'float'] and sensitivity is not None:
                            stat = 'count' if _symbol.is_count else 'sum'
                            if _symbol.is_key_count and self.censor_dims:
                                stat = 'threshold'
                            mech_class = mechanisms.get_mechanism(sensitivity, stat, t)
                            mech = mech_class(epsilon, delta=delta, sensitivity=sensitivity, max_contrib=self.max_ids)
                            if _symbol.is_key_count:
                                # NOTE(review): delta was already passed to the
                                # constructor; this presumably restores a delta the
                                # mechanism class may alter — confirm before removing.
                                mech.delta = delta
                            _symbol.mechanism = mech

                _symbols.append(_symbol)
                # attach to named expression for xpath in accuracy.py
                ne.m_symbol = _symbol
            else:
                # It's SELECT *, expand out to all columns
                syms = ne.expression.all_symbols(relations)
                for sym in syms:
                    _symbols.append(Symbol(sym.expression, sym.name))

        self._select_symbols = _symbols
        # "???" marks an output column without a usable name; it is not indexed
        self._named_symbols = {
            sym.name: sym
            for sym in _symbols
            if not sym.name == "???"
        }

    def is_public(self):
        """Return True when every source relation and SELECT subquery is public."""
        # A query without a FROM clause reads no tables, so only the SELECT
        # list (which may contain subqueries) decides.
        source_public = self.source is None or self.source.is_public()
        return source_public and self.select.is_public()

    def symbol(self, expression):
        """
        returns the expression for an output column in the SELECT statement.
        Query objects do not have aliases, so caller must strip alias first.
        """
        if not self.has_symbols():
            raise ValueError("Attempted to get symbol from query with no symbols loaded.")
        if type(expression) is not Column:
            raise ValueError(
                "Can only request output columns from a query: " + str(type(expression))
            )
        return self[expression.name]

    @property
    def m_symbols(self):
        """Deprecated alias for ``_select_symbols``."""
        # stacklevel=2 points the warning at the caller rather than this property
        warnings.warn("m_symbols has been renamed to _select_symbols", stacklevel=2)
        return self._select_symbols

    def numeric_symbols(self):
        """Return the output symbols whose type is ``int`` or ``float``."""
        return [s for s in self._select_symbols if s.expression.type() in ["int", "float"]]

    @property
    def key_column(self):
        """Return the TableColumn used as the private key, or None.

        :raises ValueError: if more than one key exists, if row privacy is set
            but a key exists, or if row privacy is False and no key exists.
        """
        rsyms = self.source.relations[0].all_symbols(AllColumns())
        tcsyms = [r.expression for r in rsyms if type(r.expression) is TableColumn]
        keys = [tc for tc in tcsyms if tc.is_key]
        if len(keys) > 1:
            raise ValueError("We only know how to handle tables with one key: " + str(keys))
        if self.row_privacy:
            if len(keys) > 0:
                raise ValueError("Row privacy is set, but metadata specifies a private_id")
            return None
        elif self.row_privacy is False:
            if len(keys) < 1:
                raise ValueError("No private_id column specified, and row_privacy is False")
            return keys[0]
        else:
            # symbols haven't been loaded yet (row_privacy is still None)
            if len(keys) < 1:
                return None
            return keys[0]

    def children(self) -> List[Any]:
        return [self.select, self.source, self.where, self.agg, self.having, self.order, self.limit]

    def evaluate(self, bindings):
        """Evaluate each SELECT expression against ``bindings``; returns (name, value) pairs."""
        return [(ne.name, ne.expression.evaluate(bindings)) for ne in self.select.namedExpressions]
class UnifiedQuery(Query):
    """UNION or INTERSECT Query.

    A base query whose clauses are held directly on this object, followed by a
    sequence of ``QueryOperation`` objects, each combining it with another query.
    """

    def __init__(self, query, query_ops, metadata=None, privacy=None):
        # query_ops must exist before super().__init__, which calls the
        # overridden load_symbols below when metadata is supplied.
        self.query_ops = Seq(query_ops)
        super().__init__(
            query.select,
            query.source,
            query.where,
            query.agg,
            query.having,
            query.order,
            query.limit,
            metadata=metadata,
            privacy=privacy
        )

    def query(self, metadata=None, privacy=None):
        """Return a plain Query built from the base clauses, without the set operations."""
        return Query(
            self.select,
            self.source,
            self.where,
            self.agg,
            self.having,
            self.order,
            self.limit,
            metadata=metadata,
            privacy=privacy
        )

    def __str__(self):
        return str(self.query()) + ' ' + ' '.join([str(qry_op) for qry_op in self.query_ops])

    def children(self):
        return [self.select, self.source, self.where, self.agg, self.having, self.order, self.limit, self.query_ops]

    def load_symbols(self, metadata, privacy=None, ctes: "Optional[Dict[str, Query]]" = None):
        """Load symbols for the base query and every set-operation query.

        :raises ValueError: if an operand query has a different number of
            columns, or different column expressions, than the base query.
        """
        if ctes is None:
            ctes = {}
        super().load_symbols(metadata=metadata, privacy=privacy, ctes=ctes)
        for qry_op in self.query_ops:
            qry_op.load_symbols(metadata=metadata, privacy=privacy, ctes=ctes)
            if self._select_symbols is None:
                # The base query could not resolve its symbols (Query.load_symbols
                # returned early), so there is nothing to compare against;
                # previously this crashed with TypeError on len(None).
                continue
            op_symbols = qry_op.query.all_symbols()
            if len(op_symbols) != len(self._select_symbols):
                raise ValueError(
                    f"each {qry_op.operator} query must have the same number of columns"
                )
            for symb1, symb2 in zip(self._select_symbols, op_symbols):
                if symb1.expression != symb2.expression:
                    raise ValueError(
                        "Cannot load symbols if the columns of all the queries are not the same"
                    )

    def is_public(self):
        return all([self.query().is_public()] + [qops.is_public() for qops in self.query_ops])
class QueryOperation(Sql):
    """A set operation (such as UNION or INTERSECT) applied to a following query.

    Serializes as ``<operator> [ALL] <query>``.
    """

    def __init__(self, operator, all: bool, query):
        self.operator = operator
        # NOTE: the parameter name shadows the builtin ``all`` only inside
        # __init__; kept for interface compatibility.
        self.all = all
        self.query = query

    def is_public(self):
        # (previously defined twice with identical bodies)
        return self.query.is_public()

    def children(self):
        tokens = [self.operator]
        if self.all:
            tokens.append(Token("ALL"))
        return tokens + [self.query]

    def load_symbols(self, metadata, privacy=None, ctes: "Optional[Dict[str, Query]]" = None):
        """Load symbols for the operand query using the given CTE mapping."""
        self.query.load_symbols(metadata, privacy=privacy, ctes={} if ctes is None else ctes)
class Select(Sql):
    """Result Columns: an optional WITH clause, a quantifier and the named output expressions."""

    def __init__(self, quantifier, namedExpressions, ctes=None):
        self.quantifier = quantifier
        self.namedExpressions = Seq(namedExpressions)
        # ctes stays None (not an empty Seq) when there is no WITH clause
        if ctes is None:
            self.ctes = None
        else:
            self.ctes = Seq(ctes)

    def functions(self):
        """Named expressions whose expression is exactly an AggFunction."""
        found = []
        for named in self.namedExpressions:
            if type(named.expression) is AggFunction:
                found.append(named)
        return found

    def aggregates(self):
        """Subset of functions() for which is_aggregate() is truthy."""
        return list(filter(lambda fn: fn.is_aggregate(), self.functions()))

    def children(self):
        prefix = [] if self.ctes is None else [Token("WITH"), self.ctes]
        return [*prefix, Token("SELECT"), self.quantifier, self.namedExpressions]

    def is_public(self):
        # only subqueries appearing in the select list are consulted
        subqueries = [
            named.expression
            for named in self.namedExpressions
            if isinstance(named.expression, (Query, AliasedSubquery))
        ]
        return all([sub.is_public() for sub in subqueries])
class From(Sql):
    """From: the FROM clause, a sequence of relations."""

    def __init__(self, relations):
        self.relations = Seq(relations)

    def is_public(self):
        """True when every relation in the clause is public."""
        flags = [relation.is_public() for relation in self.relations]
        return all(flags)

    def children(self):
        keyword = Token("FROM")
        return [keyword, self.relations]
class Where(Sql):
    """Predicates: the WHERE clause wrapping a single condition expression."""

    def __init__(self, condition):
        self.condition = condition

    def children(self):
        keyword = Token("WHERE")
        return [keyword, self.condition]
class Aggregate(Sql):
    """Group By: the GROUP BY clause and its grouping expressions."""

    def __init__(self, groupingExpressions):
        self.groupingExpressions = Seq(groupingExpressions)

    def groupedColumns(self):
        """Return the grouping expressions that are plain Column references."""
        expressions = (ge.expression for ge in self.groupingExpressions)
        return [expr for expr in expressions if type(expr) is Column]

    def children(self):
        keywords = [Token("GROUP"), Token("BY")]
        return keywords + [self.groupingExpressions]
class Having(Sql):
    """Having clause: a condition applied after grouping."""

    def __init__(self, condition):
        self.condition = condition

    def children(self):
        keyword = Token("HAVING")
        return [keyword, self.condition]
class Order(Sql):
    """Order By: the ORDER BY clause and its sort items."""

    def __init__(self, sortItems):
        self.sortItems = Seq(sortItems)

    def children(self):
        keywords = [Token("ORDER"), Token("BY")]
        return keywords + [self.sortItems]

    def symbol(self, relations):
        """Return a new Order whose sort items are resolved against ``relations``."""
        resolved_items = self.sortItems.symbol(relations)
        return Order(resolved_items)
class Limit(Sql):
    """Limit: ``LIMIT n`` with an optional ``OFFSET``."""

    def __init__(self, n, offset=None):
        self.n = n
        self.offset = offset

    def children(self):
        limit_part = [Token("LIMIT"), Literal(self.n, str(self.n))]
        if self.offset is None:
            return limit_part
        offset_part = [Token("OFFSET"), Literal(self.offset, str(self.offset))]
        return limit_part + offset_part

    def symbol(self, relations):
        # a limit holds no column references, so it resolves to itself
        return self
class Top(Sql):
    """Top: a ``TOP n`` row limit."""

    def __init__(self, n):
        self.n = n

    def children(self):
        count_literal = Literal(self.n, str(self.n))
        return [Token("TOP"), count_literal]

    def symbol(self, relations):
        # no column references to resolve
        return self
"""
RELATIONS
"""
class Relation(SqlRel):
    """A relation such as table, join, or subquery.

    Consists of a primary relation followed by zero or more joins.
    """

    def __init__(self, primary, joins):
        self.primary = primary
        self.joins = joins if joins is not None else []

    def is_public(self):
        return self.primary.is_public() and all([j.is_public() for j in self.joins])

    def load_symbols(self, metadata, privacy=None, ctes: "Optional[Dict[str, Query]]" = None):
        """Load symbols for the primary relation and every join, then validate USING columns.

        :raises ValueError: if a ``JOIN ... USING`` column is missing from the
            relations to its left or from the joined relation.
        """
        if ctes is None:
            ctes = {}
        strip_escape = metadata.compare.strip_escapes
        self.privacy = privacy
        for r in [self.primary, *self.joins]:
            r.load_symbols(metadata, privacy=privacy, ctes=ctes)
        # check the join keys
        if len(self.joins) > 0:
            # accumulates the (lowercased) column names of every relation to the
            # left of the join being checked
            primary_symbols = [s.name.lower() for s in self.primary.all_symbols(AllColumns())]
            for j in self.joins:
                join_symbols = [s.name.lower() for s in j.right.all_symbols(AllColumns())]
                if type(j.criteria) is UsingJoinCriteria:
                    # build the escape-stripped lookups once per join, not once per identifier
                    left_names = {strip_escape(n) for n in primary_symbols}
                    right_names = {strip_escape(n) for n in join_symbols}
                    for ident in j.criteria.identifiers:
                        key = strip_escape(ident.name.lower())
                        if key not in left_names:
                            raise ValueError(
                                "Join clause uses a join column that doesn't exist in the primary relation: "
                                + str(ident)
                            )
                        if key not in right_names:
                            raise ValueError(
                                "Join clause uses a join column that doesn't exist in the joined relation: "
                                + str(ident)
                            )
                primary_symbols.extend(join_symbols)

    def symbol(self, expression):
        """Resolve a Column against all relations.

        :raises ValueError: if ``expression`` is not a Column or no relation has it.
        :raises NameError: if more than one relation matches (ambiguous).
        """
        if type(expression) is not Column:
            raise ValueError("Tables can only have column symbols: " + str(type(expression)))
        alias, colname = self.split_alias(expression.name)
        alias = alias if alias != "" else None
        syms_a = self.all_symbols(AllColumns(alias))
        syms_b = [s for s in syms_a if s is not None]
        syms_c = [
            s.expression
            for s in syms_b
            if (type(s.expression) is TableColumn and s.expression.compare.identifier_match(colname, s.name))
            or s.name == colname
        ]
        if len(syms_c) == 1:
            return syms_c[0]
        elif len(syms_c) > 1:
            raise NameError("Too many relations matched column, ambiguous: " + str(expression))
        else:
            raise ValueError("Symbol could not be found in any relations: " + str(expression))

    def all_symbols(self, expression=None):
        """Return symbols for ``*`` (or ``alias.*``) across the primary relation and joins.

        For an unqualified ``*``, USING join columns from the joined side are
        dropped so each appears once.

        :raises ValueError: if ``expression`` is not AllColumns or nothing matches.
        """
        strip_escapes = NameCompare().strip_escapes
        if expression is None:
            expression = AllColumns()
        if type(expression) is not AllColumns:
            raise ValueError("Can only request all columns with * : " + str(type(expression)))
        syms = (
            self.primary.all_symbols(expression)
            if self.primary.alias_match(str(expression))
            else []
        )
        for j in self.joins:
            if not j.alias_match(str(expression)):
                continue
            drop_cols = []
            alias, _ = self.split_alias(str(expression))
            # if alias.* specified, don't drop join column
            if type(j.criteria) is UsingJoinCriteria and alias == "":
                drop_cols = [
                    strip_escapes(str(i).lower())
                    for i in j.criteria.identifiers
                ]
            # concatenate rather than extend: syms may be the primary
            # relation's own symbol list (Table.all_symbols returns it directly)
            syms = syms + [
                Symbol(sym.expression, sym.name)
                for sym in j.all_symbols(expression)
                if strip_escapes(sym.name.lower()) not in drop_cols
            ]
        if len(syms) == 0:
            raise ValueError("Symbol could not be found in any relations: " + str(expression))
        return syms

    def children(self):
        return [self.primary] + self.joins
class Tablesample(SqlRel):
    """TABLESAMPLE clause attached to a table reference.

    Serializes as ``TABLESAMPLE method (percentage)`` or
    ``TABLESAMPLE method (n ROWS)``, optionally followed by
    ``REPEATABLE (seed)``.
    """

    def __init__(self, sampling_method, percentage, n_rows, seed):
        self.sampling_method = sampling_method
        self.percentage = percentage
        self.n_rows = n_rows
        self.seed = seed

    def children(self):
        start = [
            Token("TABLESAMPLE"),
            Identifier(self.sampling_method),
            Token('('),
        ]
        # Compare against None rather than truthiness: a percentage or seed of 0
        # is a legitimate value and must not be silently dropped.
        if self.percentage is not None:
            middle = [Literal(self.percentage), Token(')')]
        else:
            middle = [Literal(self.n_rows), Token('ROWS )')]
        if self.seed is not None:
            end = [
                Token("REPEATABLE ("),
                Literal(self.seed),
                Token(')'),
            ]
        else:
            end = []
        return start + middle + end

    def symbol(self, relations):
        return self
class Unnest(SqlRel):
    """The BigQuery UNNEST operator, which produces a table from an array.

    It is needed because the SQL implementation of range inference uses
    queries with UNNEST.
    """

    def __init__(self, name, alias):
        self.name = name
        self.alias = alias

    def children(self):
        if self.alias:
            alias_part = [Token("AS"), self.alias]
        else:
            alias_part = []
        return [Token("UNNEST"), Token("("), self.name, Token(")")] + alias_part

    def symbol(self, expression):
        # the alias, when present, names the result; otherwise the array name does
        return self.alias or self.name
class Table(SqlRel):
    """A fully qualified table name with optional alias"""

    def __init__(self, name, alias, tablesample):
        self.name = name
        self.alias = alias
        self.tablesample = tablesample
        self._select_symbols = None
        self._named_symbols = None
        self._is_public = None
        # CTE query this name resolved to in load_symbols, if any; its
        # is_public() is consulted lazily.
        self._cte = None

    def symbol(self, expression):
        """Return the symbol for a Column in this table, or None if the name is absent.

        :raises ValueError: if ``expression`` is not a Column, its alias does
            not match this table, or symbols have not been loaded.
        """
        if type(expression) is not Column:
            raise ValueError("Tables can only have column symbols: " + str(type(expression)))
        if not self.alias_match(expression.name):
            raise ValueError(
                "Attempt to look up symbol with different alias. Use alias_match() first."
                + expression.name
                + " -- "
                + str(self.name)
            )
        _, name = self.split_alias(expression.name)
        if self._select_symbols is None:
            raise ValueError("Please load symbols with metadata first: " + str(self))
        if name in self:
            return self[name]
        return None

    def is_public(self):
        """Return whether the underlying table (or CTE) is public.

        :raises ValueError: if metadata has not been loaded.
        """
        cte = getattr(self, "_cte", None)
        if cte is not None:
            return cte.is_public()
        if self._is_public is None:
            raise ValueError("Cannot get `is_public` if metadata have not been loaded")
        return self._is_public

    def load_symbols(self, metadata, privacy=None, ctes: "Optional[Dict[str, Query]]" = None):
        """Load column symbols, resolving the name against ``ctes`` first and then ``metadata``.

        Does nothing beyond resetting state when ``metadata`` is None.

        :raises ValueError: if the name is not a CTE and metadata has no entry for it.
        """
        if ctes is None:
            ctes = {}
        self.privacy = privacy
        self._named_symbols = None
        if metadata is None:
            return
        self._cte = None
        table_name = str(self.name)
        # First we check in the common table exprs then in the tables
        if table_name in ctes:
            cte = ctes[table_name]
            self._select_symbols = cte.all_symbols()
            # Previously `self.is_public = cte.is_public`, which replaced this
            # method on the instance; delegate through _cte instead.
            self._cte = cte
            return
        table = metadata[table_name]
        if table is None:
            raise ValueError(
                "No metadata available for " + str(self.name)
            )
        self._is_public = table.is_public
        tc = table.m_columns

        def get_table_expr(name):
            # optional column attributes fall back to the same defaults as before
            col = tc[name]
            return TableColumn(
                tablename=self.name,
                colname=name,
                valtype=col.typename(),
                is_key=col.is_key,
                lower=getattr(col, "lower", None),
                upper=getattr(col, "upper", None),
                nullable=getattr(col, "nullable", True),
                missing_value=getattr(col, "missing_value", None),
                sensitivity=getattr(col, "sensitivity", None),
                possible_values=getattr(col, "possible_values", []),
                max_ids=table.max_ids,
                sample_max_ids=table.sample_max_ids,
                row_privacy=table.row_privacy,
                censor_dims=table.censor_dims,
                compare=metadata.compare,
            )

        self._select_symbols = [Symbol(get_table_expr(name), name) for name in tc.keys()]

    def escaped(self):
        """Return True if any dot-separated part of the name starts with '"' or '['."""
        parts = str(self).split(".")
        return any(p.startswith(('"', "[")) for p in parts)

    def children(self):
        table = [self.name]
        alias = [Token("AS"), self.alias] if self.alias is not None else []
        tablesample = [self.tablesample] if self.tablesample is not None else []
        return table + alias + tablesample

    def all_symbols(self, expression=None):
        """Return all table symbols.

        ``expression`` may only be an AllColumns. If it names a table/alias,
        symbols are returned only when that name matches; otherwise [] is returned.
        """
        if expression is None:
            expression = AllColumns()
        if type(expression) is not AllColumns:
            raise ValueError("Can only request all columns with * : " + str(type(expression)))
        syms = (
            self._select_symbols
            if self.alias_match(str(expression))
            else []
        )
        return syms
class AliasedSubquery(SqlRel, SqlExpr):
    """A parenthesized subquery with an optional alias.

    It can act both as a relation and as an expression.
    """

    def __init__(self, query, alias):
        self.query = query
        self.alias = alias

    def is_public(self):
        return self.query.is_public()

    def symbol(self, expression):
        if not hasattr(expression, "name"):
            return self
        # the alias prefix is irrelevant inside the subquery; look up the bare name
        _, column_name = self.split_alias(expression.name)
        return self.query.symbol(Column(column_name))

    def all_symbols(self, expression):
        if type(expression) is not AllColumns:
            raise ValueError("Need to pass in a * or alias.* to get all columns")
        if not self.alias_match(str(expression)):
            raise ValueError("Requesting all coluns with mismatched alias")
        return self.query.all_symbols(AllColumns())

    def children(self):
        alias_part = [] if self.alias is None else [Token("AS"), self.alias]
        return [Token("("), self.query, Token(")"), *alias_part]
class NamedSubquery(SqlRel):
    """A named subquery as written in a WITH clause: ``name [(cols)] AS (subquery)``."""

    def __init__(self, name, cols, subquery):
        self.subquery = subquery
        self.cols = Seq(cols)
        self.name = name

    def children(self):
        cte_columns = [Token("("), self.cols, Token(")")] if self.cols else []
        return [self.name, *cte_columns, Token("AS"), Token("("), self.subquery, Token(")")]

    def symbol(self, expression):
        if hasattr(expression, "name"):
            _, name = self.split_alias(expression.name)
            return self.subquery.symbol(Column(name))
        return self

    def all_symbols(self, expression):
        if type(expression) is not AllColumns:
            raise ValueError("Need to pass in a * or alias.* to get all columns")
        if not self.alias_match(str(expression)):
            raise ValueError("Requesting all coluns with mismatched alias")
        # the wrapped query is stored as `subquery`; `self.query` never existed
        # and raised AttributeError here
        return self.subquery.all_symbols(AllColumns())

    def is_public(self):
        return self.subquery.is_public()
class AliasedRelation(SqlRel):
    """A parenthesized subrelation (table, join, or subquery) with an optional alias."""

    def __init__(self, relation, alias):
        self.relation = relation
        self.alias = alias

    def symbol(self, expression):
        # strip any alias prefix before delegating to the wrapped relation
        _, column_name = self.split_alias(expression.name)
        return self.relation.symbol(Column(column_name))

    def all_symbols(self, expression):
        if type(expression) is not AllColumns:
            raise ValueError("Need to pass in a * or alias.* to get all columns")
        if not self.alias_match(str(expression)):
            raise ValueError("Requesting all coluns with mismatched alias")
        return self.relation.all_symbols(AllColumns())

    def children(self):
        alias_part = [] if self.alias is None else [Token("AS"), self.alias]
        return [Token("("), self.relation, Token(")"), *alias_part]
class Join(SqlRel):
    """A join expression attached to a primary relation.

    Symbol lookups and publicity are delegated to the joined (right) relation.
    """

    def __init__(self, joinType, right, criteria):
        self.joinType = joinType
        self.right = right
        self.criteria = criteria

    def is_public(self):
        joined = self.right
        return joined.is_public()

    def symbol(self, expression):
        joined = self.right
        return joined.symbol(expression)

    def all_symbols(self, expression):
        joined = self.right
        return joined.all_symbols(expression)

    def children(self):
        return [self.joinType, Token("JOIN"), self.right, self.criteria]
#
# METADATA
#
class TableColumn(SqlExpr):
    """ A column attached to a fully qualified table.

    Carries the column's declared type and bounds plus the table-level privacy
    settings (max_ids, sample_max_ids, row_privacy, censor_dims).
    """

    def __init__(
        self,
        tablename,
        colname,
        valtype="unknown",
        *ignore,
        is_key=False,
        lower=None,
        upper=None,
        max_ids=1,
        sample_max_ids=True,
        row_privacy=False,
        censor_dims=False,
        compare=None,
        nullable=True,
        missing_value=None,
        sensitivity=None,
        possible_values=None,
    ):
        self.tablename = tablename
        self.colname = colname
        self.valtype = valtype
        self.is_key = is_key
        self.lower = lower
        self.upper = upper
        self.max_ids = max_ids
        self.sample_max_ids = sample_max_ids
        self.row_privacy = row_privacy
        self.censor_dims = censor_dims
        self.unbounded = lower is None or upper is None
        self.nullable = nullable
        self.missing_value = missing_value
        self._sensitivity = sensitivity
        self.compare = compare
        # Fresh list per instance: the former `possible_values=[]` default was
        # a single list shared by every TableColumn built without the argument.
        self.possible_values = [] if possible_values is None else possible_values

    def __str__(self):
        return self.tablename + "." + self.colname

    def __eq__(self, other):
        return isinstance(self, type(other)) and self.tablename == other.tablename and self.colname == other.colname

    def __hash__(self):
        return hash((self.tablename, self.colname))

    def type(self):
        return self.valtype

    def sensitivity(self):
        """Return the column's sensitivity.

        Numeric columns use an explicitly supplied sensitivity if present,
        otherwise max(|lower|, |upper|) when both bounds are known, otherwise
        infinity. Boolean columns return 1; other types return None.
        """
        if self.valtype in ["int", "float"]:
            # an explicit sensitivity overrides the bounds, so the bounds are
            # only evaluated when they are actually needed
            if self._sensitivity is not None:
                return self._sensitivity
            if self.lower is not None and self.upper is not None:
                return max(abs(self.upper), abs(self.lower))
            return np.inf  # unbounded
        if self.valtype == "boolean":
            return 1
        return None

    def evaluate(self, bindings):
        """Look up this column in ``bindings`` by lowercased "table.column"; None if absent."""
        if str(self).lower() in bindings:
            return bindings[str(self).lower()]
        else:
            return None

    @property
    def is_key_count(self):
        return self.is_key

    @property
    def is_count(self):
        return False