Why Gemfury? Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Debian packages RPM packages NuGet packages

Repository URL to install this package:

Details    
sarus_sql / sarus_sql / merge_results.py
Size: Mime:
"""Module containing the merging function used to merge DP with synthetic
results."""

from enum import Enum
from typing import Any, Dict, List, Tuple
import warnings

from snsql._ast.ast import Query

from sarus_sql.merge_strategies.base import Merger

from . import ast_utils
from .merge_strategies import (
    bayesian,
    logistic,
    neural_network,
    no_synthetic,
    simple,
    synthetic_as_fallback,
)
from .process_results import apply_limit_clause, apply_order_by_clause


class MergeStrategy(Enum):
    """Enum class with possible merging strategies implemented so far.

    Each member wraps one of the names imported from ``.merge_strategies``.
    Calling a member, e.g. ``MergeStrategy.SIMPLE(**kwargs)``, forwards the
    keyword arguments to the wrapped strategy and returns what it builds.
    """

    SIMPLE = simple
    BAYESIAN = bayesian
    SYNTHETIC_AS_FALLBACK = synthetic_as_fallback
    NO_SYNTHETIC = no_synthetic
    LOGISTIC = logistic
    NN = neural_network

    def __call__(self, **kwargs: Any) -> Merger:
        """Build the merger associated with this strategy.

        Args:
            **kwargs (Any): parameters forwarded to the wrapped strategy.

        Returns:
            Merger: object whose ``merge`` method combines the results.
        """
        # Delegate to the wrapped strategy. Calling ``self`` again here
        # would recurse until a RecursionError is raised.
        return self.value(**kwargs)


def groupby_cols(query: Query) -> List[str]:
    """Collect the names of the group by columns of a query.

    Args:
        query (Query): a valid query; ``query.agg`` must be set.

    Returns:
        List[str]: names of the grouped columns, followed by the names of
            the select expressions for which
            ``ast_utils.only_grouping_cols`` holds. Every name is passed
            through ``query.name_compare``.
    """
    names = []

    # Columns listed by the aggregation itself.
    for column in query.agg.groupedColumns():
        names.append(query.name_compare(column.name))

    # Select expressions accepted by ``only_grouping_cols``.
    for named_expr in query.select.namedExpressions:
        if ast_utils.only_grouping_cols(named_expr.expression, query.agg):
            names.append(query.name_compare(named_expr.name))

    return names


def extract_cols(row: List[Any], indices: List[int]) -> List[Any]:
    """Pick the values of a row located at the given positions.

    The values are returned in the order they appear in ``row``, whatever
    the order of ``indices``. Positions outside the row are ignored.

    Args:
        row (List[Any]): line with dp or synthetic results
        indices (List[int]): positions of the columns we want to extract

    Returns:
        List[Any]: values found at the requested positions.
    """
    selected = []
    for position, value in enumerate(row):
        if position in indices:
            selected.append(value)
    return selected


def extract_values(
    result_list: List[Dict[str, Any]], key_idx: List[int]
) -> Tuple[List[Tuple[Any, ...]], List[List[Any]]]:
    """Split a list of result dicts into group by keys and row values.

    e.g. given result_list = [{"group_1":"A","count_all":222}] and key_idx=[0]
    it will provide ([("A", )], [["A", 222]])

    Args:
        result_list (List[Dict[str, Any]]): one dict per result row.
        key_idx (List[int]): positions of the group by columns in each row.

    Returns:
        Tuple[List[Tuple[Any]], List[List[Any]]]: the key tuple of each row
            and the full list of values of each row, in the same order.
    """
    keys = []
    values = []
    for record in result_list:
        row_values = list(record.values())
        values.append(row_values)
        keys.append(tuple(extract_cols(row_values, key_idx)))
    return keys, values


def full_dp_results(
    dp_results: List[Dict[str, Any]],
    grouping_keys: List[Dict[str, Any]],
    query: Query,
) -> List[Dict[str, Any]]:
    """Complete each dp result row with values taken from grouping_keys.

    Rows of ``dp_results`` and ``grouping_keys`` are paired by position.
    For every select expression of the query, the value comes from the dp
    row when it holds that name and from the grouping row otherwise.

    Args:
        dp_results (List[Dict[str, Any]]): as coming from merge_result
        grouping_keys (List[Dict[str, Any]]): as coming from merge_result
        query (Query): valid query

    Returns:
        List[Dict[str, Any]]: one dict per paired row, keyed by the names
            of the select expressions.
    """
    selected_names = [
        query.name_compare(named_expr.name)
        for named_expr in query.select.namedExpressions
    ]

    completed_rows = []
    for dp_row, group_row in zip(dp_results, grouping_keys):
        completed = {}
        for name in selected_names:
            if name in dp_row:
                completed[name] = dp_row[name]
            else:
                completed[name] = group_row[name]
        completed_rows.append(completed)
    return completed_rows


def _rows_have_allowed_types(
    keys: List[Tuple[Any, ...]],
    values: List[List[Any]],
    allowed_types: Tuple[type, ...],
) -> bool:
    """Tell whether every non-key value of every row has an allowed type."""
    for key, row in zip(keys, values):
        for item in row:
            # Values equal to one of the row's key entries are skipped:
            # group by values are not restricted to the allowed types.
            # NOTE: a non-key value that happens to equal a key entry is
            # skipped as well.
            if item not in key and not isinstance(item, allowed_types):
                return False
    return True


def is_mergeable(
    dp_keys: List[Tuple[Any, ...]],
    dp_values: List[List[Any]],
    sy_keys: List[Tuple[Any, ...]],
    sy_values: List[List[Any]],
) -> bool:
    """Checks whether dp and synthetic values are among allowed types.
    If yes returns True meaning that any merge strategy can be applied.

    The allowed types are int, float, bool and None. Values that are part
    of a row's key are not checked.

    Args:
        dp_keys (List[Tuple[Any, ...]]):
            list containing tuples with group by values for each line.
            They are used as keys. e.g. given a query with 'GROUP BY sex,
            education_num', dp_keys would be :
            [("Male", 1), ("Female", 1)]
        dp_values (List[List[Any]]):
            list with row values from dp results. e.g. given a query like
            'SELECT sex, education_num, AVG(age), COUNT(*) FROM census
            GROUP BY sex, education_num' dp_values would be:
            [["Male", 1, 32.3, 444], ["Female", 1, 44.4, 654]]
        sy_keys (List[Tuple[Any, ...]]): similarly as dp_keys, list
            containing tuples with group by values for each line.
        sy_values (List[List[Any]]): similarly as dp_values,
            list with row values from synthetic results.

    Returns:
        bool: True when every checked value has an allowed type.
    """
    allowed_types = (int, float, bool, type(None))

    return _rows_have_allowed_types(
        dp_keys, dp_values, allowed_types
    ) and _rows_have_allowed_types(sy_keys, sy_values, allowed_types)


def merge_results(  # pylint: disable=too-many-locals
    dp_results: List[Dict[str, Any]],
    grouping_keys: List[Dict[str, Any]],
    synth_results: List[Dict[str, Any]],
    query: Query,
    merge_strategy: MergeStrategy = MergeStrategy.SIMPLE,
    **kwargs: Any,
) -> List[Dict[str, Any]]:
    """main function used in Private Learning lab

    It merges the DP results with the synthetic results using the given
    strategy, applies the ORDER BY and LIMIT clauses of the query and
    releases only the columns found in the input results.

    Args:
        dp_results (List[Dict[str, Any]]): DP results, one dict per row
            mapping a column name to its value.
        grouping_keys (List[Dict[str, Any]]): one dict per row of
            dp_results, paired by position, holding the values of the
            group by columns.
        synth_results (List[Dict[str, Any]]): synthetic results, one dict
            per row. Its first row provides the columns to release when
            dp_results is empty.
        query (Query): valid query the results come from.
        merge_strategy (MergeStrategy, optional): strategy used to merge.
            Defaults to MergeStrategy.SIMPLE.
        **kwargs: (Any): any additional parameters to be passed
            to merging strategies.

    Returns:
        List[Dict[str, Any]]: merged rows restricted to the released
            columns. Empty when both dp_results and synth_results are
            empty.
    """

    strategy = merge_strategy(**kwargs)

    if not dp_results and not synth_results:
        # Nothing to merge. Without this guard, reading the first row of
        # synth_results below would raise an IndexError.
        return []

    if dp_results:
        cols_to_release = list(dp_results[0].keys())
    else:
        cols_to_release = list(synth_results[0].keys())

    group_by_cols = groupby_cols(query) if query.agg else []

    dp_full_results = full_dp_results(dp_results, grouping_keys, query)

    if dp_full_results:
        all_columns = list(dp_full_results[0].keys())
    else:
        all_columns = cols_to_release + group_by_cols

    group_by_idx = [all_columns.index(col) for col in group_by_cols]
    dp_keys, dp_values = extract_values(dp_full_results, group_by_idx)
    sy_keys, sy_values = extract_values(synth_results, group_by_idx)

    if is_mergeable(dp_keys, dp_values, sy_keys, sy_values):
        results = strategy.merge(dp_keys, dp_values, sy_keys, sy_values)
    else:
        # if results are not mergeable fall back to the NO_SYNTHETIC
        # strategy.
        warnings.warn(
            "Merge can't be performed because of datatype incompatibility. "
            "Outputing NO_SYNTHETIC merge strategy."
        )
        results = no_synthetic().merge(dp_keys, dp_values, sy_keys, sy_values)

    results_dict = [dict(zip(all_columns, row)) for row in results]

    results_dict = apply_order_by_clause(results_dict, query)
    results_dict = apply_limit_clause(results_dict, query)

    # release only the columns the user has asked for
    released = set(cols_to_release)
    return [
        {name: res for name, res in row.items() if name in released}
        for row in results_dict
    ]