Repository URL to install this package:
|
Version:
3.0.0.dev3 ▾
|
"""Module containing the merging function used to merge DP with synthetic
results."""
from enum import Enum
from typing import Any, Dict, List, Tuple
import warnings
from snsql._ast.ast import Query
from sarus_sql.merge_strategies.base import Merger
from . import ast_utils
from .merge_strategies import (
bayesian,
logistic,
neural_network,
no_synthetic,
simple,
synthetic_as_fallback,
)
from .process_results import apply_limit_clause, apply_order_by_clause
class MergeStrategy(Enum):
    """Enum class with possible merging strategies implemented so far.

    Each member's value is a strategy factory imported from
    ``merge_strategies``; calling a member builds the corresponding
    ``Merger`` (see ``__call__``).
    """

    SIMPLE = simple
    BAYESIAN = bayesian
    SYNTHETIC_AS_FALLBACK = synthetic_as_fallback
    NO_SYNTHETIC = no_synthetic
    LOGISTIC = logistic
    NN = neural_network

    def __call__(self, **kwargs: Any) -> Merger:
        """Instantiate the merger wrapped by this enum member.

        Args:
            **kwargs: forwarded unchanged to the strategy factory.

        Returns:
            Merger: the strategy object whose ``merge`` method combines
            DP and synthetic results.
        """
        # Delegate to the wrapped factory (same usage as ``no_synthetic()``
        # in ``merge_results``). The previous ``self.__call__(**kwargs)``
        # called this very method again and recursed until RecursionError.
        return self.value(**kwargs)
def groupby_cols(query: Query) -> List[str]:
    """Return the normalised names of the columns acting as group-by keys.

    The result lists, in order, the columns of the GROUP BY clause followed
    by every SELECT named expression that ``ast_utils.only_grouping_cols``
    reports as built only from grouping columns. Every name is passed
    through ``query.name_compare``. Duplicates are not removed.

    Args:
        query (Query): it should be a valid query with an aggregation
            (``query.agg``) section.

    Returns:
        List[str]: group-by column names as strings.
    """
    normalise = query.name_compare
    aggregation = query.agg

    grouped = [normalise(column.name) for column in aggregation.groupedColumns()]
    selected = [
        normalise(named.name)
        for named in query.select.namedExpressions
        if ast_utils.only_grouping_cols(named.expression, aggregation)
    ]
    return grouped + selected
def extract_cols(row: List[Any], indices: List[int]) -> List[Any]:
    """Extract one or more column values from a row given their positions.

    It avoids a double for loop when extracting values of group by columns.
    Values are returned in the order they appear in ``row`` (not in the
    order of ``indices``); repeated indices select a value only once.

    Args:
        row (List[Any]): line with dp or synthetic results.
        indices (List[int]): positions of the columns to extract.

    Returns:
        List[Any]: the selected values.
    """
    # Build the lookup set once so each membership test is O(1); use a
    # distinct comprehension variable instead of shadowing ``row``.
    wanted = set(indices)
    return [value for position, value in enumerate(row) if position in wanted]
def extract_values(
    result_list: List[Dict[str, Any]], key_idx: List[int]
) -> Tuple[List[Tuple[Any, ...]], List[List[Any]]]:
    """Split result rows into group-by keys and full value lists.

    Each dict in ``result_list`` is turned into the list of its values (in
    dict order); its key is the tuple of values whose position is listed in
    ``key_idx``, kept in row order. For example, given
    ``result_list = [{"group_1": "A", "count_all": 222}]`` and
    ``key_idx = [0]`` it will provide ``([("A",)], [["A", 222]])``.

    Args:
        result_list (List[Dict[str, Any]]): rows as column-name -> value dicts.
        key_idx (List[int]): positions of the group-by columns.

    Returns:
        Tuple[List[Tuple[Any, ...]], List[List[Any]]]: (keys, values), one
        entry per input row.
    """
    keys: List[Tuple[Any, ...]] = []
    values: List[List[Any]] = []
    for record in result_list:
        row_values = list(record.values())
        values.append(row_values)
        keys.append(
            tuple(v for pos, v in enumerate(row_values) if pos in key_idx)
        )
    return keys, values
def full_dp_results(
    dp_results: List[Dict[str, Any]],
    grouping_keys: List[Dict[str, Any]],
    query: Query,
) -> List[Dict[str, Any]]:
    """Join each DP result row with its group-by values.

    If group by columns are not included in the DP rows they are taken from
    the matching ``grouping_keys`` row instead. The output rows contain
    exactly the SELECT columns of ``query`` (names normalised with
    ``query.name_compare``), in select order.

    Rows are paired with ``zip``, so the output is only as long as the
    shorter of ``dp_results`` and ``grouping_keys``. NOTE(review): confirm
    callers always pass lists of equal length.

    Args:
        dp_results (List[Dict[str, Any]]): as coming from merge_results.
        grouping_keys (List[Dict[str, Any]]): as coming from merge_results.
        query (Query): valid query.

    Returns:
        List[Dict[str, Any]]: dp_results with group by column values.

    Raises:
        KeyError: if a SELECT column is in neither paired row.
    """
    select_names = [
        query.name_compare(named.name)
        for named in query.select.namedExpressions
    ]
    merged_rows = []
    for dp_row, group_row in zip(dp_results, grouping_keys):
        # Prefer the DP value; fall back to the grouping-key value.
        merged_rows.append(
            {
                name: (dp_row if name in dp_row else group_row)[name]
                for name in select_names
            }
        )
    return merged_rows
def is_mergeable(
    dp_keys: List[Tuple[Any, ...]],
    dp_values: List[List[Any]],
    sy_keys: List[Tuple[Any, ...]],
    sy_values: List[List[Any]],
) -> bool:
    """Check whether DP and synthetic values are among the allowed types.

    If every non-key value is an ``int``, ``float``, ``bool`` or ``None``,
    returns True, meaning that any merge strategy can be applied.

    Args:
        dp_keys (List[Tuple[Any, ...]]):
            list containing tuples with group by values for each line.
            They are used as keys. e.g. given a query with 'GROUP BY sex,
            education_num', dp_keys would be:
            [("Male", 1), ("Female", 1)] where "Male"/"Female" are the sex
            values and 1 is the education_num value.
        dp_values (List[List[Any]]):
            list with row values from dp results. e.g. given a query like
            'SELECT sex, education_num, AVG(age), COUNT(*) FROM census
            GROUP BY sex, education_num' dp_values would be:
            [["Male", 1, 32.3, 444], ["Female", 1, 44.4, 654]]
        sy_keys (List[Tuple[Any, ...]]): similarly as dp_keys, list
            containing tuples with group by values for each line.
        sy_values (List[List[Any]]): similarly as dp_values,
            list with row values from synthetic results.

    Returns:
        bool: True if both result sets contain only allowed value types
        outside their group-by keys.
    """
    allowed_types = (int, float, bool, type(None))
    for keys, values in ((dp_keys, dp_values), (sy_keys, sy_values)):
        for key, row in zip(keys, values):
            # NOTE(review): key columns are skipped by *value*, not by
            # position, so a non-key value equal to a key value is also
            # skipped. A generator lets all() stop at the first bad value.
            if not all(
                isinstance(item, allowed_types) for item in row if item not in key
            ):
                return False
    return True
def merge_results(  # pylint: disable=too-many-locals
    dp_results: List[Dict[str, Any]],
    grouping_keys: List[Dict[str, Any]],
    synth_results: List[Dict[str, Any]],
    query: Query,
    merge_strategy: MergeStrategy = MergeStrategy.SIMPLE,
    **kwargs: Any,
) -> List[Dict[str, Any]]:
    """Merge DP query results with synthetic-data results.

    Main function used in Private Learning lab.

    Args:
        dp_results (List[Dict[str, Any]]): rows of the DP result, one
            column-name -> value dict per row.
        grouping_keys (List[Dict[str, Any]]): one dict per DP row holding
            group-by values; used to fill SELECT columns missing from
            ``dp_results`` (see ``full_dp_results``).
        synth_results (List[Dict[str, Any]]): rows of the synthetic result.
            Their values are read positionally against the group-by indices
            computed from the DP columns, so they are presumably in the same
            column order as the SELECT list -- TODO confirm with callers.
        query (Query): parsed query; used to find group-by columns and to
            apply ORDER BY and LIMIT.
        merge_strategy (MergeStrategy, optional): strategy used to combine
            both result sets. Defaults to MergeStrategy.SIMPLE.
        **kwargs (Any): any additional parameters to be passed
            to merging strategies.

    Returns:
        List[Dict[str, Any]]: merged rows after ORDER BY and LIMIT,
        restricted to the columns of the first DP row (or of the first
        synthetic row when there are no DP rows). An empty list when both
        ``dp_results`` and ``synth_results`` are empty.
    """
    strategy = merge_strategy(**kwargs)
    if not dp_results and not synth_results:
        # Nothing to merge; previously this indexed synth_results[0] and
        # raised IndexError.
        return []

    reference_rows = dp_results or synth_results
    cols_to_release = list(reference_rows[0].keys())
    group_by_cols = groupby_cols(query) if query.agg else []

    dp_full_results = full_dp_results(dp_results, grouping_keys, query)
    if dp_full_results:
        all_columns = list(dp_full_results[0].keys())
    else:
        all_columns = cols_to_release + group_by_cols
    group_by_idx = [all_columns.index(col) for col in group_by_cols]

    dp_keys, dp_values = extract_values(dp_full_results, group_by_idx)
    sy_keys, sy_values = extract_values(synth_results, group_by_idx)

    if is_mergeable(dp_keys, dp_values, sy_keys, sy_values):
        results = strategy.merge(dp_keys, dp_values, sy_keys, sy_values)
    else:
        # If results are not mergeable, return the DP results instead.
        warnings.warn(
            "Merge can't be performed because of datatype incompatibility. "
            "Outputing NO_SYNTHETIC merge strategy.",
            stacklevel=2,  # attribute the warning to the caller
        )
        results = no_synthetic().merge(dp_keys, dp_values, sy_keys, sy_values)

    results_dict = [dict(zip(all_columns, row)) for row in results]
    results_dict = apply_order_by_clause(results_dict, query)
    results_dict = apply_limit_clause(results_dict, query)

    # Release only the columns the user has asked for.
    released = set(cols_to_release)
    return [
        {name: res for name, res in row.items() if name in released}
        for row in results_dict
    ]