Why Gemfury? Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Debian packages RPM packages NuGet packages

Repository URL to install this package:

Details    
Size: Mime:
"""Local implementation of range detection"""

from __future__ import annotations

from typing import Optional, Tuple

from sarus_statistics.ops.utils import laplace_mechanism, rescale_weights_sql
import numpy as np
import sqlalchemy as sa
import typing as t
from sarus_data_spec.typing import Type


# pylint: disable=too-many-arguments, too-many-locals, too-many-branches
def automatic_column_range(
    session: sa.orm.Session,
    table: sa.sql.FromClause,
    data_col: str,
    private_col: str,
    user_col: str,
    weight_col: str,
    data_type: Type,
    noise: float,
    prob_no_false_positive: float = 1 - 1e-9,
    max_multiplicity: float = 1,
    estimate: Tuple[Optional[float], Optional[float]] = (None, None),
    random_generator: Optional[np.random.Generator] = None,
    result_table_prefixes: Optional[Tuple[sa.sql.ColumnElement]] = None,
    return_insert: bool = False,
) -> Tuple[sa.sql.Selectable, sa.sql.Selectable]:
    """Build the SQL statements selecting a lower and an upper bound for a
    column.

    Automatic bounding function from
    https://arxiv.org/pdf/1909.01917.pdf paragraph 5.1.1

    The column is bucketed with ``case_from_bins`` into bins whose edges
    are 0 and signed powers of 2 (trimmed by ``estimate`` when given).
    Each per-bin count is passed through ``laplace_mechanism`` together
    with ``noise * max_multiplicity``, and only bins whose resulting count
    reaches a threshold derived from ``prob_no_false_positive`` are kept.
    The lower bound is the smallest kept bin label; the upper bound is
    twice the largest kept bin label. Both are clipped to ``estimate``
    when the corresponding entry is not None.

    Two shortcuts return constant statements without any histogram:
    when both entries of ``estimate`` are equal and not None, and when the
    column type is exactly ``Boolean`` (bounds 0 and 1).

    Parameters
    -----------
    session: sa.orm.Session
        only forwarded to ``rescale_weights_sql``
    table: sa.sql.FromClause
        table holding the column to evaluate
    data_col: str
        name of the column with values to be evaluated
    private_col: str
        name of the column indicating the privacy status
        (forwarded to ``rescale_weights_sql``)
    user_col: str
        name of the user column (forwarded to ``rescale_weights_sql``)
    weight_col: str
        name of the weight_col (forwarded to ``rescale_weights_sql``)
    data_type:
        type of the data (from sarus_dataset.type). Not referenced in
        the body: the bins are chosen from the SQLAlchemy type of the
        column instead.
    noise: float
        scale of the laplace noise
    prob_no_false_positive: float
        probability of not having
        a false positive, should be very close to 1.
    max_multiplicity: float
        maximum weight per user considered.
        values are sampled if it overflows.
    estimate: Tuple[Optional[float], Optional[float]]
        optional (lower, upper) values used to trim the bins and to clip
        the resulting bounds
    random_generator: Optional[np.random.Generator]
        not referenced in the body
    result_table_prefixes: Optional[Tuple[sa.sql.ColumnElement]]
        must be non-empty; its first element is selected as the first
        column of both returned statements
    return_insert: bool
        not referenced in the body

    Returns
    -------
    Tuple[sa.sql.Selectable, sa.sql.Selectable]
        ``(min_statement, max_statement)``. Each statement selects
        ``result_table_prefixes[0]``, the literal ``'min'`` or ``'max'``
        and the bound value.

    Raises
    -------
    AssertionError
        if result_table_prefixes is None or empty
    ValueError
        if prob_no_false_positive not in [0,1]
    TypeError
        if the SQLAlchemy type of the column is not exactly Boolean,
        Float, Integer or BigInteger

    As this is a histogram query, the total privacy spent is only one
    LaplaceQuery(noise) and not len(categories) of it.
    This is because adding or removing one element can only change the value
    of one bin.
    See https://www.tau.ac.il/~saharon/BigData2018/privacybook.pdf#page=37
    for more information.

    Note: this holds for the definition of DP where two datasets are adjacent
    if they differ by the addition or removal of one sample. If substitution
    is allowed, the corresponding privacy consumption would double.
    """
    # NOTE: an assert is skipped when Python runs with -O.
    assert result_table_prefixes

    if prob_no_false_positive < 0 or prob_no_false_positive > 1:
        raise ValueError("prob_no_false_positive should be in [0,1]")

    # Degenerate estimate: both entries are equal and not None, so the
    # range is already known and is returned as constants.
    if estimate[0] == estimate[1] and estimate[0] is not None:
        min_statement = (
            sa.select(
                result_table_prefixes[0],
                sa.text("'min'"),
                sa.text(f"{estimate[0]}"),
            )
            .limit(1)
            .select_from(table)
        )

        max_statement = (
            sa.select(
                result_table_prefixes[0],
                sa.text("'max'"),
                sa.text(f"{estimate[1]}"),
            )
            .limit(1)
            .select_from(table)
        )
        return min_statement, max_statement
    private_data = getattr(table.c, data_col)
    # Exact class comparison: a column whose type is a subclass of the
    # types tested below does not match and ends in the TypeError branch.
    dtype = type(private_data.type)
    if dtype == sa.sql.sqltypes.Boolean:
        # Booleans get the constant bounds 0 and 1.
        min_statement = (
            sa.select(result_table_prefixes[0], sa.text("'min'"), sa.text("0"))
            .limit(1)
            .select_from(table)
        )
        max_statement = (
            sa.select(result_table_prefixes[0], sa.text("'max'"), sa.text("1"))
            .limit(1)
            .select_from(table)
        )
        return min_statement, max_statement

    if dtype == sa.sql.sqltypes.Float:
        # TODO: doubles
        # Edges run from -2**1023 up to -2**-1022, then 0, then from
        # 2**-1023 up to 2**1022 (the two sides are not symmetric).
        bitsize = 1023 * 2
        bins = (
            [-(2 ** (bitsize // 2 - b)) for b in range(bitsize)]
            + [0]
            + [2 ** (b - (bitsize // 2)) for b in range(bitsize)]
        )
    elif dtype in (
        sa.sql.sqltypes.Integer,
        sa.sql.sqltypes.BigInteger,
    ):
        # Integer columns are compared against the bins as floats.
        private_data = sa.cast(private_data, sa.Float)
        # TODO: other int types
        # Edges run from -(2**62) up to -1, then 0, then 1 up to 2**62.
        bitsize = 63
        bins = (
            [-(1 << b) for b in reversed(range(bitsize))]
            + [0]
            + [1 << b for b in range(bitsize)]
        )
    else:
        raise TypeError(
            f"Dtype {dtype} not implemented, please convert to float64,"
            " float32 or int64 to int8"
        )

    # Trim the bins to the estimate: the estimate becomes the outermost
    # edge and every edge beyond it is dropped.
    if estimate[0] is not None:
        bins = [estimate[0]] + [b for b in bins if b > estimate[0]]
    if estimate[1] is not None:
        bins = [b for b in bins if b < estimate[1]] + [estimate[1]]

    # rescale weights for private rows
    # NOTE(review): the value assigned to `table` here is not referenced
    # again in this function; `private_data` was taken from the original
    # table above. Verify whether the histogram is meant to read from the
    # rescaled table.
    table = rescale_weights_sql(
        session,
        table=table,
        data_col=data_col,
        user_col=user_col,
        private_col=private_col,
        weight_col=weight_col,
        max_multiplicity=max_multiplicity,
    )

    # Label every row with the lower edge of the bin it falls in.
    binned = sa.select(
        sa.select(case_from_bins(private_data, bins).label("binned")).cte()
    ).cte()

    # One row per bin label with its count passed through
    # laplace_mechanism.
    counts = (
        sa.select(
            binned.c.binned,
            laplace_mechanism(
                sa.func.count(binned.c.binned),
                noise * max_multiplicity,
            ).label("count_"),
        )
        .group_by(binned.c.binned)
        .cte()
    )

    # TODO add weights

    # Bins whose count is below this threshold are ignored. The exponent
    # uses len(bins) - 1, the number of intervals between edges.
    # NOTE(review): prob_no_false_positive == 1 evaluates np.log(0), and a
    # single-edge `bins` divides by zero; confirm callers cannot reach
    # either case.
    threshold = (
        -noise
        * max_multiplicity
        * np.log(1 - (prob_no_false_positive ** (1 / (len(bins) - 1))))
    )
    minimum = (
        sa.select((sa.func.min(counts.c.binned)).label("min_value"))
        .filter(counts.c.count_ >= threshold)
        .cte()
    )
    # The bin label is the lower edge of the bin, so the largest kept label
    # is doubled to obtain an upper bound.
    # NOTE(review): doubling a negative label moves it away from zero and
    # doubling the label 0 gives 0; confirm this is the intended upper
    # bound for those bins.
    maximum = (
        sa.select((2 * sa.func.max(counts.c.binned)).label("max_value"))
        .filter(counts.c.count_ >= threshold)
        .cte()
    )

    # Clip the detected bounds to the estimate when one was provided.
    if estimate[0] is not None:
        minimum = sa.select(
            (sa.func.greatest(minimum.c.min_value, estimate[0])).label(
                "min_value"
            )
        ).cte()
    if estimate[1] is not None:
        maximum = sa.select(
            sa.func.least(maximum.c.max_value, estimate[1]).label("max_value")
        ).cte()

    min_statement = sa.select(
        result_table_prefixes[0], sa.text("'min'"), minimum.c.min_value
    )
    max_statement = sa.select(
        result_table_prefixes[0], sa.text("'max'"), maximum.c.max_value
    )
    # TODO if no count is above threshold take min / max bin

    return min_statement, max_statement


def public_bounds(
    session: sa.orm.Session,
    table: sa.sql.FromClause,
    data_col: str,
    result_table_prefixes: Optional[Tuple[sa.sql.ColumnElement]] = None,
    return_insert: bool = False,
) -> Tuple[sa.sql.Selectable, sa.sql.Selectable]:
    """Build the statements selecting the exact minimum and maximum of a
    column.

    No noise and no binning are involved: the bounds are the plain
    ``MIN`` and ``MAX`` aggregates of the column.

    Parameters
    -----------
    session: sa.orm.Session
        not referenced in the body
    table: sa.sql.FromClause
        table holding the column to evaluate
    data_col: str
        name of the column with values to be evaluated
    result_table_prefixes: Optional[Tuple[sa.sql.ColumnElement]]
        must be non-empty; its first element is selected as the first
        column of both returned statements
    return_insert: bool
        not referenced in the body

    Returns
    -------
    Tuple[sa.sql.Selectable, sa.sql.Selectable]
        ``(min_statement, max_statement)``. Each statement selects
        ``result_table_prefixes[0]``, the literal ``'min'`` or ``'max'``
        and a CTE holding the aggregate value.

    Raises
    -------
    AssertionError
        if result_table_prefixes is None or empty
    """
    assert result_table_prefixes
    prefix = result_table_prefixes[0]

    def _bound_statement(
        tag: sa.sql.ColumnElement, aggregate: t.Callable
    ) -> sa.sql.Selectable:
        # Aggregate the column on its own, then expose it through a CTE
        # next to the prefix and the tag.
        bound = sa.select(aggregate(getattr(table.c, data_col)))
        return sa.select(prefix, tag, bound.cte())

    return (
        _bound_statement(sa.text("'min'"), sa.func.min),
        _bound_statement(sa.text("'max'"), sa.func.max),
    )


def construct_inner_case(column: sa.Column, bins: t.List[float]) -> sa.Case:
    """Build a CASE mapping ``column`` to the lower edge of its interval.

    One ``WHEN column BETWEEN low AND high THEN low`` branch is emitted
    for every pair of consecutive edges in ``bins``. The ``ELSE`` value is
    ``bins[-2]``, the lower edge of the last interval.

    ``bins`` needs at least two edges, otherwise ``bins[-2]`` raises
    IndexError.
    """
    whens = [
        (column.between(low, high), low) for low, high in zip(bins, bins[1:])
    ]
    fallback = bins[-2]
    return sa.case(*whens, else_=fallback)


def construct_outer_case(
    column: sa.Column, list_of_bins: t.List[t.List[float]]
) -> sa.Case:
    """Build a two-level CASE from a list of bin chunks.

    For every chunk, the outer CASE tests whether ``column`` lies between
    the first and the last edge of the chunk and, if so, delegates to the
    CASE built by ``construct_inner_case`` for that chunk. A value that
    matches no chunk yields NULL.
    """
    branches = [
        (
            column.between(chunk[0], chunk[-1]),
            construct_inner_case(column, chunk),
        )
        for chunk in list_of_bins
    ]
    return sa.case(*branches, else_=None)


def case_from_bins(column: sa.Column, bins: t.List[float]) -> sa.Case:
    """Build a CASE mapping ``column`` to the lower edge of its bin.

    Given a column and bins, if the number of bins is < 4 it constructs
    the following case:

    CASE
        WHEN column BETWEEN bins[0] AND bins[1] THEN bins[0]
        WHEN column BETWEEN bins[1] AND bins[2] THEN bins[1]
        ...
    END

    Else it constructs the following:
    CASE
        WHEN column BETWEEN bins[0] AND bins[c1] THEN
            CASE
                WHEN column BETWEEN bins[0] AND bins[1] THEN bins[0]
                WHEN column BETWEEN bins[1] AND bins[2] THEN bins[1]
                ...
                WHEN column BETWEEN bins[c1-1] AND bins[c1] THEN bins[c1-1]
                ELSE bins[c1-1]
            END
        WHEN column BETWEEN bins[c1] AND bins[c2] THEN
            CASE
                WHEN column BETWEEN bins[c1] AND bins[c1+1] THEN bins[c1]
                ...
                ELSE bins[c2-1]
            END
        ...
    END
    This is equivalent to the former but prevents a stackoverflow error from
    qrlew when having a very long list of bins.

    The bin list is divided into chunks such that the number of WHEN conditions
    of the outer CASE is similar to the one in the inner CASE. Consecutive
    chunks share their boundary edge (the last edge of a chunk is the first
    edge of the next one) so that no interval between two edges is left
    uncovered.

    Parameters
    -----------
    column: sa.Column
        column (or column expression) to bucket
    bins: List[float]
        bin edges; assumed to be sorted in increasing order

    Returns
    -------
    sa.Case
        expression evaluating to the lower edge of the first interval
        containing the value, or NULL when the value is outside
        ``[bins[0], bins[-1]]``
    """
    chunk_size = int(np.sqrt(len(bins)))
    if chunk_size > 1:
        # Nested case. Each chunk takes chunk_size intervals, i.e.
        # chunk_size + 1 edges, and the next chunk starts on the last edge
        # of the previous one. Disjoint chunks would leave the interval
        # between two chunks unmatched, and values falling there would
        # silently become NULL.
        # Stopping the starts at len(bins) - 2 guarantees every chunk holds
        # at least two edges and the last chunk ends on bins[-1].
        chunks = [
            bins[start : start + chunk_size + 1]
            for start in range(0, len(bins) - 1, chunk_size)
        ]
        return construct_outer_case(column, chunks)

    # Standard flat case when the number of bins is very small.
    conditions = [
        (column.between(lower, upper), lower)
        for lower, upper in zip(bins, bins[1:])
    ]
    return sa.case(*conditions)