Why Gemfury? Push, build, and install RubyGems, npm packages, Python packages, Maven artifacts, PHP packages, Go Modules, Debian packages, RPM packages, and NuGet packages.

Repository URL to install this package:

Details    
sarus_sql / sarus_sql / process_results.py
Size: Mime:
# pylint: disable=consider-using-generator
"""Set of function used to filter the results"""

import itertools
import typing as t

from snsql._ast.ast import AllColumns, Column, Query, Top

from sarus_sql.ast_transform import escape_quotes
from sarus_sql.constants import KEYCOUNT_COLNAME


def apply_thresholding(
    results: t.List[t.Dict[str, t.Any]],
    threshold: float,
) -> t.List[t.Dict[str, t.Any]]:
    """Apply tau-thresholding: keep rows whose key count exceeds ``threshold``.

    Args:
        results: rows as dicts mapping column name to value.
        threshold: strict lower bound on the ``KEYCOUNT_COLNAME`` value.

    Returns:
        A new list with the surviving rows. An empty ``results`` is
        returned unchanged (the same list object).

    Raises:
        ValueError: if the first row has no ``KEYCOUNT_COLNAME`` column
            (only the first row is inspected).
    """
    if not results:
        return results

    first_row = results[0]
    if KEYCOUNT_COLNAME not in first_row:
        raise ValueError(
            "Cannot apply thresholding if there is not "
            f"{KEYCOUNT_COLNAME} column."
        )

    return [
        record
        for record in results
        if record[KEYCOUNT_COLNAME] > threshold
    ]


def post_process_results(
    results: t.List[t.Dict[str, t.Any]],
    query: Query,
    cols_to_release: t.List[str],
) -> t.Tuple[t.List[t.Dict[str, t.Any]], t.List[t.Dict[str, t.Any]]]:
    """Filter, sort and truncate ``results``, then split released columns.

    The steps run in SQL order: WHERE first, then ORDER BY, then LIMIT
    (or TOP). Finally, each row is split in two. The columns listed in
    ``cols_to_release`` go to the released rows. Every other column goes
    to the hidden rows, e.g. GROUP BY columns the user did not request.

    Args:
        results: rows as dicts mapping column name to value.
        query: the query whose WHERE / ORDER BY / LIMIT clauses are applied;
            its symbols must already be resolved.
        cols_to_release: names of the columns the user asked for.

    Returns:
        ``(released_results, hidden_results)``, both empty when ``results``
        is empty.
    """
    # The ORDER BY step calls query.all_symbols(), so symbols are presumably
    # required to be resolved before post-processing.
    assert query.has_symbols()
    if not results:
        return [], []

    filtered = apply_where_clause(results, query)
    ordered = apply_order_by_clause(filtered, query)
    truncated = apply_limit_clause(ordered, query)
    return censor_columns(truncated, cols_to_release)


def censor_columns(
    results: t.List[t.Dict[str, t.Any]],
    cols: t.List[str],
) -> t.Tuple[t.List[t.Dict[str, t.Any]], t.List[t.Dict[str, t.Any]]]:
    """Split every row into the columns released to the user and the rest.

    Args:
        results: rows as dicts mapping column name to value.
        cols: names of the columns the user asked for. Each name is passed
            through ``escape_quotes`` before being matched against row keys.

    Returns:
        ``(released_results, hidden_results)``: two lists aligned with
        ``results``. Entry ``i`` of each holds the released / remaining
        columns of ``results[i]``, in the row's original key order.
    """
    # Build the lookup once as a set: membership is tested for every cell of
    # every row, and a list would make each test O(len(cols)).
    cols_to_release = {escape_quotes(col) for col in cols}

    released_results: t.List[t.Dict[str, t.Any]] = []
    hidden_results: t.List[t.Dict[str, t.Any]] = []
    for row in results:
        row_released: t.Dict[str, t.Any] = {}
        row_hidden: t.Dict[str, t.Any] = {}
        for name, value in row.items():
            target = row_released if name in cols_to_release else row_hidden
            target[name] = value
        released_results.append(row_released)
        hidden_results.append(row_hidden)

    return released_results, hidden_results


def apply_where_clause(
    results: t.List[t.Dict[str, t.Any]],
    query: Query,
) -> t.List[t.Dict[str, t.Any]]:
    """Keep only the rows that satisfy the query's WHERE condition.

    In the user's original query this condition was a HAVING clause. It is
    evaluated row by row via ``query.where.condition.evaluate``.

    Returns:
        A new list of matching rows. Without a WHERE clause, ``results``
        itself is returned unchanged.
    """
    if query.where:
        return [
            record
            for record in results
            if query.where.condition.evaluate(record)
        ]
    return results


def _nulls_last_key(
    colname: str,
) -> t.Callable[[t.Dict[str, t.Any]], t.Tuple[bool, t.Any]]:
    """Build a sort key on ``colname`` that ranks None above every value."""

    def key(row: t.Dict[str, t.Any]) -> t.Tuple[bool, t.Any]:
        value = row[colname]
        # The flag settles None vs non-None first, so None is never compared
        # against an actual value (which would raise TypeError).
        return (value is None, value)

    return key


def apply_order_by_clause(
    results: t.List[t.Dict[str, t.Any]],
    query: Query,
) -> t.List[t.Dict[str, t.Any]]:
    """Applies ORDER BY clause to results.

    Only plain column references are supported as sort keys. Ascending and
    descending orders work for any type whose values are mutually
    comparable. NULLs (``None``) are treated as larger than any value:
    they sort last for ASC keys and first for DESC keys. Rows equal on every
    sort key keep their input order.

    Returns:
        A new sorted list, or ``results`` unchanged when there is no ORDER BY
        clause or no rows.

    Raises:
        NotImplementedError: if a sort item is not a plain column.
        ValueError: if a sort column is not among the query's output columns.
    """
    if not query.order or len(results) == 0:
        return results

    out_types = {
        escape_quotes(symb.name): symb.expression.type()
        for symb in query.all_symbols(AllColumns())
    }

    sort_fields: t.List[t.Tuple[bool, str]] = []
    for sort_item in query.order.sortItems:
        if not isinstance(sort_item.expression, Column):
            raise NotImplementedError(
                "We only know how to sort by column names."
            )
        colname = escape_quotes(sort_item.expression.name)
        if colname not in out_types:
            raise ValueError(
                f"Can't sort by '{colname}', because it's not in "
                f"output columns: {out_types}"
            )
        desc = (
            sort_item.order is not None and sort_item.order.lower() == "desc"
        )
        sort_fields.append((desc, colname))

    # Multi-pass stable sort: sorting by the least significant key first and
    # the most significant key last yields the lexicographic order. Using
    # `reverse=` instead of negating values lets DESC work for any comparable
    # type (strings, dates, ...). `reverse=True` still preserves the
    # relative order of equal rows.
    sorted_results = list(results)
    for desc, colname in reversed(sort_fields):
        sorted_results.sort(key=_nulls_last_key(colname), reverse=desc)
    return sorted_results


def apply_limit_clause(
    results: t.List[t.Dict[str, t.Any]], query: Query
) -> t.List[t.Dict[str, t.Any]]:
    """Truncate ``results`` to the query's LIMIT, falling back to SELECT TOP.

    A LIMIT clause takes precedence over a ``Top`` select quantifier.

    Returns:
        A new list with at most ``n`` leading rows. ``results`` itself is
        returned when it is empty or when the query has neither LIMIT nor TOP.
    """
    if not results:
        return results

    if query.limit:
        max_rows = query.limit.n
    else:
        quantifier = query.select.quantifier
        # isinstance() is False for None, so this also covers "no quantifier".
        if not isinstance(quantifier, Top):
            return results
        max_rows = quantifier.n

    return list(itertools.islice(results, max_rows))