Why Gemfury? Push, build, and install RubyGems, npm packages, Python packages, Maven artifacts, PHP packages, Go Modules, Debian packages, RPM packages, and NuGet packages.

Repository URL to install this package:

Details    
Size: Mime:
import typing as t
from warnings import warn

import numpy as np
import sqlalchemy as sa
from sqlalchemy.sql.expression import cast

from sarus_statistics.ops.histograms.local import CatType, NoiseKind
from sarus_statistics.ops.utils import (
    VIEW_ENABLED,
    laplace_mechanism,
    rescale_weights_sql,
)

# A category specification: either an explicit collection of category values
# (CatType) or an executable SQL statement that yields them (e.g. the
# DISTINCT query built when no categories are given).
SQLCatType = t.Union[CatType, sa.sql.Executable]


# pylint: disable=too-many-arguments, too-many-locals
def private_histogram(
    session: sa.orm.Session,
    table: sa.sql.FromClause,
    data_col: str,
    user_col: str,
    private_col: str,
    weight_col: str,
    noise: float,
    max_multiplicity: float = 1,
    categories: t.Optional[SQLCatType] = None,
    public_categories: t.Optional[t.List[str]] = None,
    clip_below_zero: bool = True,
    noise_kind: NoiseKind = NoiseKind.LAPLACE,
    random_generator: t.Optional[np.random.Generator] = None,
    result_table_prefixes: t.Optional[
        t.Tuple[sa.sql.ColumnElement, sa.sql.ColumnElement]
    ] = None,
    return_insert: bool = True,
) -> sa.sql.ClauseElement:

    """Build a SQL expression for a differentially private histogram of a
    categorical column.

    Rows whose category is not in ``public_categories`` are treated as
    private: their weights are rescaled per user via ``rescale_weights_sql``
    and their per-category counts are noised with ``laplace_mechanism``.
    Counts for ``public_categories`` are computed exactly, and both parts
    are unioned into a single (category, count_) result.

    Parameters
    -----------
    session: sa.orm.Session
        SQLAlchemy session used to build the queries
    table: sa.sql.FromClause
        table-like clause holding the data
    data_col: str
        name of the value's column (the categorical variable)
    user_col: str
        name of the users' column. If None, the row is public.
    private_col: str
        name of the column indicating the privacy status
    weight_col: str
        name of the weight's column
    noise: float
        scale of the Laplace noise added to each private category count
    max_multiplicity: float
        max number of identical users (used when rescaling weights)
    categories: Optional[SQLCatType]
        categories for the histogram; if None they are inferred from the
        data, which leaks privacy (a warning is emitted)
    public_categories: Optional[t.List[str]]
        categories that are publicly known; their counts are not noised
    clip_below_zero: bool
        currently unused in this SQL version (see the commented-out
        clipping near the end)
    noise_kind: NoiseKind
        currently unused here; Laplace noise is always applied
    random_generator: Optional[np.random.Generator]
        currently unused here; the noise is generated on the SQL side
    result_table_prefixes: Optional[Tuple[...]]
        pair of column elements prepended to each result row; required
        despite the Optional annotation (asserted below)
    return_insert: bool
        if False, return the bare histogram selectable instead of the
        SELECT including ``result_table_prefixes``
    Returns
    -------
    sa.sql.ClauseElement
        query producing (category, count_) rows, optionally prefixed
    As this is a histogram query, the total privacy spent is only one
    LaplaceQuery(noise) and not len(categories) of it.
    This is because adding or removing one element can only change the value
    of one bin.
    See https://www.tau.ac.il/~saharon/BigData2018/privacybook.pdf#page=37
    for more information.
    Note: this holds for the definition of DP where two datasets are adjacent
    if they differ by the addition or removal of one sample. If substitution
    is allowed, the corresponding privacy consumption would double.
    """
    # check_is_private(data, user_col, private_col)  # TODO: sql version
    assert result_table_prefixes
    if categories is None:
        warn("Privacy leak: no categories given, inferring from data")
        # NOTE(review): `categories` is never read after this assignment —
        # confirm whether the inferred categories were meant to drive the
        # GROUP BY / output bins.
        categories = session.query(getattr(table.c, data_col)).distinct()
    if public_categories is None:
        public_categories = []
    private_data = session.query(table)
    if len(public_categories) > 0:
        # Public categories are counted exactly below; exclude them from
        # the private (noised) part so each category appears in one branch.
        private_data = private_data.where(
            getattr(table.c, data_col).not_in(public_categories)
        )
    # rescale weights for private rows
    scaled_data = rescale_weights_sql(
        session,
        private_data.subquery(),
        data_col,
        user_col=user_col,
        private_col=private_col,
        weight_col=weight_col,
        max_multiplicity=max_multiplicity,
    )
    # TODO: rescale l2 ?
    scaled_col = getattr(scaled_data.c, data_col)
    # Noised per-category counts for the private rows.
    private_hist = session.query(
        scaled_col.label('data'),
        laplace_mechanism(
            sa.func.count(),
            noise,
        ).label('count_'),
    ).group_by(scaled_col)
    # here we do not need to clip, just consider sum of weights
    if len(public_categories) > 0:
        # Exact (un-noised) counts for the publicly known categories.
        public_data = (
            session.query(getattr(table.c, data_col))
            .where(getattr(table.c, data_col).in_(public_categories))
            .subquery()
        )
        public_hist = session.query(
            public_data, sa.func.count().label('count_')
        ).group_by(public_data)
        union_hists = private_hist.union_all(public_hist).subquery()
    else:
        union_hists = private_hist.subquery()
    # Final aggregation; categories are cast to VARCHAR for the output.
    # NOTE(review): grouping by the whole subquery groups by ALL of its
    # columns (data AND count_); grouping by union_hists.c.data alone may
    # be the intent — confirm against the generated SQL.
    tot_hist = session.query(
        sa.cast(union_hists.c.data, sa.VARCHAR).label('category'),
        sa.func.sum(union_hists.c.count_).label('count_'),
    ).group_by(union_hists)
    # NOTE(review): clipping negative noised counts at zero is not
    # implemented in this SQL version yet (see `clip_below_zero`).
    # if clip_below_zero:
    #     return (tot_hist.clip(lower=0).to_dict(), query)
    if not return_insert:
        return tot_hist.selectable
    return sa.select(
        result_table_prefixes[0],
        result_table_prefixes[1],
        tot_hist.subquery(),
    )


# pylint: disable=too-many-arguments, too-many-locals
def public_histogram(
    session: sa.orm.Session,
    table: sa.sql.FromClause,
    data_col: str,
    clip_below_zero: bool = True,
    random_generator: t.Optional[np.random.Generator] = None,
    result_table_prefixes: t.Optional[
        t.Tuple[sa.sql.ColumnElement, sa.sql.ColumnElement]
    ] = None,
    return_insert: bool = True,
) -> sa.sql.ClauseElement:
    """Build a SQL expression counting occurrences of each category,
    without any differential-privacy noise.

    Parameters
    -----------
    session: sa.orm.Session
        SQLAlchemy session used to build the queries
    table: sa.sql.FromClause
        table-like clause holding the data
    data_col: str
        name of the value's column (the categorical variable)
    clip_below_zero: bool
        currently unused here (no noise is added, counts are never negative)
    random_generator: Optional[np.random.Generator]
        currently unused here
    result_table_prefixes: Optional[Tuple[...]]
        pair of column elements prepended to each result row; required
        despite the Optional annotation (asserted below)
    return_insert: bool
        if False, return the bare histogram selectable instead of the
        SELECT including ``result_table_prefixes``
    Returns
    -------
    sa.sql.ClauseElement
        query producing (category, count_) rows, optionally prefixed
    """
    assert result_table_prefixes
    # TODO: rescale l2 ?
    value_column = getattr(table.c, data_col)
    # Exact per-category counts.
    grouped_counts = (
        session.query(value_column, sa.func.count().label('count_'))
        .group_by(value_column)
        .subquery()
    )
    # Cast categories to VARCHAR and aggregate into the final shape.
    final_hist = session.query(
        sa.cast(
            getattr(grouped_counts.c, data_col), sa.VARCHAR
        ).label('category'),
        sa.func.sum(grouped_counts.c.count_).label('count_'),
    ).group_by(grouped_counts)
    if not return_insert:
        return final_hist.selectable
    return sa.select(
        result_table_prefixes[0],
        result_table_prefixes[1],
        final_hist.subquery(),
    )


def dataset_length(
    session: sa.orm.Session,
    table: sa.schema.Table,
    max_multiplicity: float,
    noise: float,
    user_col: str,
    private_col: str,
    weight_col: str,
    result_table_prefixes: t.Optional[
        t.Tuple[sa.sql.ColumnElement, sa.sql.ColumnElement]
    ] = None,
) -> sa.sql.ClauseElement:
    """Compute DP length of dataset according to max_multiplicity of user

    Builds a SQL expression whose result is a Laplace-noised total count.
    When VIEW_ENABLED, each user's contribution (sum of its weights) is
    capped at ``max_multiplicity`` before summing; otherwise a plain
    COUNT(*) over ``table`` is noised. The noise scale is
    ``noise * max_multiplicity`` — presumably scaling by the per-user
    sensitivity of the count; confirm with the privacy accounting.

    ``private_col`` is currently unused here (see the commented-out
    check_is_private call).
    """
    assert result_table_prefixes
    # check_is_private(data, user_col, private_col)
    if VIEW_ENABLED:
        # Per-user contribution, clipped at max_multiplicity.
        # NOTE(review): the raw string `user_col` is passed to group_by;
        # recent SQLAlchemy versions restrict string resolution — confirm
        # it resolves against `table`'s columns.
        count_per_user = session.query(
            sa.func.least(
                max_multiplicity, sa.func.sum(getattr(table.c, weight_col))
            ).label('count_')
        ).group_by(user_col)
        total_count: t.Union[sa.orm.Query, sa.sql.Select] = session.query(
            cast(
                sa.func.count(table.table_valued()) * 0, sa.TEXT
            ),  # Hacky way to preserve the number of output
            laplace_mechanism(
                sa.func.sum(count_per_user.subquery().c.count_),
                noise * max_multiplicity,
            ),
        )
    else:
        # No view support: noise a plain row count.
        total_count = sa.select(
            cast(
                sa.func.count() * 0, sa.TEXT
            ),  # Hacky way to preserve the number of output
            laplace_mechanism(
                sa.func.count(),
                noise * max_multiplicity,
            ),
        ).select_from(table)
    return sa.select(
        result_table_prefixes[0],
        result_table_prefixes[1],
        total_count.subquery(),
    )


def non_private_dataset_length(
    session: sa.orm.Session,
    table: sa.schema.Table,
    user_col: str,
    private_col: str,
    weight_col: str,
    result_table_prefixes: t.Optional[
        t.Tuple[sa.sql.ColumnElement, sa.sql.ColumnElement]
    ] = None,
) -> sa.sql.ClauseElement:
    """Build a SQL expression for the exact (non-DP) length of the dataset.

    Differential privacy is expected to be added in post-processing.
    When VIEW_ENABLED, the length is the sum over users of their summed
    weights; otherwise it is a plain COUNT(*) over ``table``.
    """
    assert result_table_prefixes
    # check_is_private(data, user_col, private_col)
    if VIEW_ENABLED:
        # Total weight contributed by each user.
        per_user_weight = session.query(
            sa.func.sum(getattr(table.c, weight_col)).label('count_')
        ).group_by(user_col)
        length_query: t.Union[sa.orm.Query, sa.sql.Select] = session.query(
            cast(
                sa.func.count(table.table_valued()) * 0, sa.TEXT
            ),  # Hacky way to preserve the number of output
            sa.func.sum(per_user_weight.subquery().c.count_),
        )
    else:
        # No view support: exact row count.
        length_query = sa.select(
            cast(
                sa.func.count() * 0, sa.TEXT
            ),  # Hacky way to preserve the number of output
            sa.func.count(),
        ).select_from(table)
    prefix_left, prefix_right = result_table_prefixes
    return sa.select(
        prefix_left,
        prefix_right,
        length_query.subquery(),
    )