Repository URL to install this package:
|
Version:
3.0.0.dev3 ▾
|
# pylint: disable=consider-using-generator
"""Set of functions used to filter the results"""
import itertools
import typing as t
from snsql._ast.ast import AllColumns, Column, Query, Top
from sarus_sql.ast_transform import escape_quotes
from sarus_sql.constants import KEYCOUNT_COLNAME
def apply_thresholding(
    results: t.List[t.Dict[str, t.Any]],
    threshold: float,
) -> t.List[t.Dict[str, t.Any]]:
    """Apply tau-thresholding to the result rows.

    Keeps only the rows whose ``KEYCOUNT_COLNAME`` value is strictly
    greater than ``threshold``.

    Args:
        results: rows of the query result, one dict per row.
        threshold: exclusive lower bound on the key count of a row.

    Returns:
        The rows that pass the threshold, in their original order. An
        empty input is returned as is.

    Raises:
        ValueError: if the first row has no ``KEYCOUNT_COLNAME`` entry.
    """
    if not results:
        return results

    # Only the first row is inspected to decide whether the key-count
    # column is available.
    first_row = results[0]
    if KEYCOUNT_COLNAME not in first_row:
        raise ValueError(
            "Cannot apply thresholding if there is not "
            f"{KEYCOUNT_COLNAME} column."
        )

    kept_rows = []
    for row in results:
        if row[KEYCOUNT_COLNAME] > threshold:
            kept_rows.append(row)
    return kept_rows
def post_process_results(
    results: t.List[t.Dict[str, t.Any]],
    query: Query,
    cols_to_release: t.List[str],
) -> t.Tuple[t.List[t.Dict[str, t.Any]], t.List[t.Dict[str, t.Any]]]:
    """Apply the WHERE, ORDER BY and LIMIT clauses, then split the columns.

    The rows are filtered, sorted and truncated (in that order) and each
    remaining row is split into the columns to be released and the hidden
    ones (e.g. group by columns that have not been requested by the user).

    Args:
        results: rows of the query result, one dict per row.
        query: query carrying the clauses to apply; it must have symbols.
        cols_to_release: names of the columns the user asked for.

    Returns:
        A ``(released, hidden)`` pair of row lists. Both are empty when
        ``results`` is empty.
    """
    assert query.has_symbols()
    if not results:
        return [], []

    filtered_rows = apply_where_clause(results, query)
    ordered_rows = apply_order_by_clause(filtered_rows, query)
    limited_rows = apply_limit_clause(ordered_rows, query)

    released_rows, hidden_rows = censor_columns(limited_rows, cols_to_release)
    return released_rows, hidden_rows
def censor_columns(
    results: t.List[t.Dict[str, t.Any]],
    cols: t.List[str],
) -> t.Tuple[t.List[t.Dict[str, t.Any]], t.List[t.Dict[str, t.Any]]]:
    """Split every row into released and hidden columns.

    Args:
        results: rows of the query result, one dict per row.
        cols: names of the columns the user has asked for. Each name goes
            through ``escape_quotes`` before being compared with row keys.

    Returns:
        A ``(released, hidden)`` pair of lists, each with one dict per
        input row: the first holds the requested columns, the second all
        the other ones.
    """
    allowed_names = [escape_quotes(col) for col in cols]

    released_rows = []
    hidden_rows = []
    for row in results:
        visible_part = {
            name: value
            for name, value in row.items()
            if name in allowed_names
        }
        concealed_part = {
            name: value
            for name, value in row.items()
            if name not in allowed_names
        }
        released_rows.append(visible_part)
        hidden_rows.append(concealed_part)
    return released_rows, hidden_rows
def apply_where_clause(
    results: t.List[t.Dict[str, t.Any]],
    query: Query,
) -> t.List[t.Dict[str, t.Any]]:
    """Filter the rows with the WHERE condition of ``query``.

    (In the original query, this condition was a HAVING clause.)

    Args:
        results: rows of the query result, one dict per row.
        query: query whose ``where`` clause, if any, is evaluated on
            each row.

    Returns:
        ``results`` itself when the query has no WHERE clause, otherwise
        a new list with the rows for which the condition evaluates to a
        truthy value.
    """
    where_clause = query.where
    if not where_clause:
        return results

    matching_rows = []
    for row in results:
        if where_clause.condition.evaluate(row):
            matching_rows.append(row)
    return matching_rows
def apply_order_by_clause(
    results: t.List[t.Dict[str, t.Any]],
    query: Query,
) -> t.List[t.Dict[str, t.Any]]:
    """Sort the rows according to the ORDER BY clause of ``query``.

    Only plain column names are supported as sort expressions. In
    ascending order ``None`` values come after all the other values; in
    descending order they come first. Descending order is obtained by
    negating the value, hence it is only available for "int", "float"
    and "boolean" columns.

    Args:
        results: rows of the query result, one dict per row.
        query: query whose ``order`` clause, if any, drives the sort.

    Returns:
        ``results`` itself when there is no ORDER BY clause or no row,
        otherwise a new sorted list (the sort is stable).

    Raises:
        NotImplementedError: if a sort expression is not a column.
        ValueError: if a sort column is not an output column, or if a
            descending sort is requested on an unsupported column type.
    """
    if not query.order or not results:
        return results

    # Type of every output column, keyed by its escaped name.
    out_types = {
        escape_quotes(symb.name): symb.expression.type()
        for symb in query.all_symbols(AllColumns())
    }

    # One (descending, column name, column type) triple per sort item.
    # The type is resolved here, once, instead of once per sorted row.
    sort_fields: t.List[t.Tuple[bool, str, t.Any]] = []
    for sort_item in query.order.sortItems:
        if not isinstance(sort_item.expression, Column):
            raise NotImplementedError(
                "We only know how to sort by column names."
            )
        colname = escape_quotes(sort_item.expression.name)
        if colname not in out_types:
            raise ValueError(
                f"Can't sort by '{colname}', because it's not in "
                f"output columns: {out_types}"
            )
        column_type = out_types[colname]
        desc = (
            sort_item.order is not None
            and sort_item.order.lower() == "desc"
        )
        if desc and column_type not in ["int", "float", "boolean"]:
            raise ValueError(
                f"We don't know how to sort descending by {column_type}"
            )
        sort_fields.append((desc, colname, column_type))

    def sort_func(row: t.Dict[str, t.Any]) -> t.Tuple[t.Any, ...]:
        """Build the sort key of a row: a (null flag, value) pair per field."""
        key: t.List[t.Any] = []
        for desc, colname, column_type in sort_fields:
            value = row[colname]
            if not desc:
                # Ascending: the flag is True for None, so None sorts last.
                # Two None values are equal, so they are never ordered
                # against a non-None value.
                key.append(value is None)
                key.append(value)
            else:
                # Descending: the flag is False for None, so None sorts
                # first; the other values are reversed by negation.
                key.append(value is not None)
                if value is None:
                    key.append(value)
                elif column_type == "boolean":
                    key.append(not value)
                else:
                    key.append(-value)
        return tuple(key)

    return sorted(results, key=sort_func)
def apply_limit_clause(
    results: t.List[t.Dict[str, t.Any]], query: Query
) -> t.List[t.Dict[str, t.Any]]:
    """Truncate the rows according to the LIMIT clause of ``query``.

    The number of rows comes from ``query.limit`` when it is set,
    otherwise from a ``Top`` quantifier of the select clause.

    Args:
        results: rows of the query result, one dict per row.
        query: query carrying the limit, if any.

    Returns:
        ``results`` itself when it is empty or when the query has
        neither a limit nor a ``Top`` quantifier, otherwise a new list
        with at most ``n`` leading rows.
    """
    if not results:
        return results

    limit_clause = query.limit
    if limit_clause:
        max_rows = limit_clause.n
    else:
        quantifier = query.select.quantifier
        if quantifier is None or not isinstance(quantifier, Top):
            # Nothing restricts the number of rows.
            return results
        max_rows = quantifier.n

    return list(itertools.islice(results, max_rows))