Repository URL to install this package:
|
Version:
4.5.4.dev1 ▾
|
"""Local implementation of range detection"""
from __future__ import annotations
from typing import Optional, Tuple
from sarus_statistics.ops.utils import laplace_mechanism, rescale_weights_sql
import numpy as np
import sqlalchemy as sa
import typing as t
from sarus_data_spec.typing import Type
# pylint: disable=too-many-arguments, too-many-locals, too-many-branches
def automatic_column_range(
    session: sa.orm.Session,
    table: sa.sql.FromClause,
    data_col: str,
    private_col: str,
    user_col: str,
    weight_col: str,
    data_type: Type,
    noise: float,
    prob_no_false_positive: float = 1 - 1e-9,
    max_multiplicity: float = 1,
    estimate: Tuple[Optional[float], Optional[float]] = (None, None),
    random_generator: Optional[np.random.Generator] = None,
    result_table_prefixes: Optional[Tuple[sa.sql.ColumnElement]] = None,
    return_insert: bool = False,
) -> Tuple[sa.sql.Selectable, sa.sql.Selectable]:
    """Build SQL statements estimating a column's range with Laplace noise.

    Automatic bounding function from
    https://arxiv.org/pdf/1909.01917.pdf paragraph 5.1.1

    The column is binned on power-of-two edges, every bin count receives
    Laplace noise, and the smallest / largest bin whose noisy count reaches
    a threshold gives the minimum / maximum.

    Parameters
    -----------
    session: sa.orm.Session
        forwarded to ``rescale_weights_sql``
    table: sa.sql.FromClause
        selectable holding the data; columns are read through ``table.c``
    data_col: str
        name of the column with values to be evaluated
    private_col: str
        name of the column indicating the privacy status
    user_col: str
        name of the user column
    weight_col: str
        name of the weight_col
    data_type:
        type of the data (from sarus_dataset.type); not read by this
        function, the SQL type of the column is used instead
    noise: float
        scale of the laplace noise
    prob_no_false_positive: float
        probability of not having
        a false positive, should be very close to 1.
    max_multiplicity: float
        maximum weight per user considered; it multiplies the noise scale
        and the threshold.
    estimate: Tuple[Optional[float], Optional[float]]
        optional known (minimum, maximum). When both are set and equal they
        are returned as is; otherwise each value that is set replaces the
        bin edges beyond it and clamps the corresponding result.
    random_generator: Optional[np.random.Generator]
        not read by this function
    result_table_prefixes: Optional[Tuple[sa.sql.ColumnElement]]
        its first element becomes the first selected column of both
        statements. Required in practice despite the ``None`` default.
    return_insert: bool
        not read by this function

    Returns
    -------
    Tuple[sa.sql.Selectable, sa.sql.Selectable]
        the statement selecting ``(prefix, 'min', value)`` and the
        statement selecting ``(prefix, 'max', value)``

    Raises
    -------
    AssertionError
        if result_table_prefixes is empty or None
    ValueError
        if prob_no_false_positive not in [0,1], or if ``estimate`` leaves
        fewer than two bin edges (e.g. estimate[0] > estimate[1])
    TypeError
        if the SQL type of the column is not exactly Float, Integer or
        BigInteger (Boolean is handled separately)

    As this is a histogram query, the total privacy spent is only one
    LaplaceQuery(noise) and not len(categories) of it.
    This is because adding or removing one element can only change the value
    of one bin.
    See https://www.tau.ac.il/~saharon/BigData2018/privacybook.pdf#page=37
    for more information.
    Note: this holds for the definition of DP where two datasets are adjacent
    if they differ by the addition or removal of one sample. If substitution
    is allowed, the corresponding privacy consumption would double.
    """
    assert result_table_prefixes
    if prob_no_false_positive < 0 or prob_no_false_positive > 1:
        raise ValueError("prob_no_false_positive should be in [0,1]")

    prefix = result_table_prefixes[0]
    lower_estimate = estimate[0]
    upper_estimate = estimate[1]

    # A degenerate, fully known range needs no noisy histogram at all.
    if lower_estimate is not None and lower_estimate == upper_estimate:
        return _constant_bound_statements(
            prefix, table, lower_estimate, upper_estimate
        )

    private_data = getattr(table.c, data_col)
    # Exact type comparison: subclasses of the SQL types below are rejected.
    dtype = type(private_data.type)

    if dtype == sa.sql.sqltypes.Boolean:
        # Booleans always span 0..1.
        return _constant_bound_statements(prefix, table, 0, 1)

    if dtype == sa.sql.sqltypes.Float:
        # TODO: doubles
        bitsize = 1023 * 2
        # Edges -(2**1023) .. -(2**-1022), 0, 2**-1023 .. 2**1022.
        bins = (
            [-(2 ** (bitsize // 2 - b)) for b in range(bitsize)]
            + [0]
            + [2 ** (b - (bitsize // 2)) for b in range(bitsize)]
        )
    elif dtype in (
        sa.sql.sqltypes.Integer,
        sa.sql.sqltypes.BigInteger,
    ):
        private_data = sa.cast(private_data, sa.Float)
        # TODO: other int types
        bitsize = 63
        # Edges -(2**62) .. -1, 0, 1 .. 2**62.
        bins = (
            [-(1 << b) for b in reversed(range(bitsize))]
            + [0]
            + [1 << b for b in range(bitsize)]
        )
    else:
        raise TypeError(
            f"Dtype {dtype} not implemented, please convert to float64,"
            " float32 or int64 to int8"
        )

    # A known bound replaces every edge lying beyond it.
    if lower_estimate is not None:
        bins = [lower_estimate] + [b for b in bins if b > lower_estimate]
    if upper_estimate is not None:
        bins = [b for b in bins if b < upper_estimate] + [upper_estimate]

    # With a single edge left there is no bin to count and the threshold
    # below would divide by zero; fail early with an explicit message.
    if len(bins) < 2:
        raise ValueError(
            "estimate leaves fewer than two bin edges; expected"
            " estimate[0] < estimate[1] within the supported value range"
        )

    # rescale weights for private rows
    # NOTE(review): the rescaled table is not referenced afterwards;
    # private_data still points at the column of the original table.
    # Confirm whether the histogram is meant to read the rescaled table.
    table = rescale_weights_sql(
        session,
        table=table,
        data_col=data_col,
        user_col=user_col,
        private_col=private_col,
        weight_col=weight_col,
        max_multiplicity=max_multiplicity,
    )

    # Label every row with the lower edge of its bin.
    binned = sa.select(
        sa.select(case_from_bins(private_data, bins).label("binned")).cte()
    ).cte()

    # Noisy count per bin label.
    counts = (
        sa.select(
            binned.c.binned,
            laplace_mechanism(
                sa.func.count(binned.c.binned),
                noise * max_multiplicity,
            ).label("count_"),
        )
        .group_by(binned.c.binned)
        .cte()
    )
    # TODO add weights

    # Bins whose noisy count is below this value are ignored.
    threshold = (
        -noise
        * max_multiplicity
        * np.log(1 - (prob_no_false_positive ** (1 / (len(bins) - 1))))
    )

    minimum = (
        sa.select((sa.func.min(counts.c.binned)).label("min_value"))
        .filter(counts.c.count_ >= threshold)
        .cte()
    )
    # NOTE(review): doubling the label yields the upper edge of the bin only
    # for positive power-of-two labels; for zero, negative or estimate
    # labels the upper edge differs. Confirm the intended behaviour.
    maximum = (
        sa.select((2 * sa.func.max(counts.c.binned)).label("max_value"))
        .filter(counts.c.count_ >= threshold)
        .cte()
    )

    # Never report a range wider than the provided estimates.
    if lower_estimate is not None:
        minimum = sa.select(
            (sa.func.greatest(minimum.c.min_value, lower_estimate)).label(
                "min_value"
            )
        ).cte()
    if upper_estimate is not None:
        maximum = sa.select(
            sa.func.least(maximum.c.max_value, upper_estimate).label(
                "max_value"
            )
        ).cte()

    min_statement = sa.select(prefix, sa.text("'min'"), minimum.c.min_value)
    max_statement = sa.select(prefix, sa.text("'max'"), maximum.c.max_value)
    # TODO if no count is above threshold take min / max bin
    return min_statement, max_statement


def _constant_bound_statements(
    prefix: sa.sql.ColumnElement,
    table: sa.sql.FromClause,
    lower: object,
    upper: object,
) -> Tuple[sa.sql.Selectable, sa.sql.Selectable]:
    """Build the 'min' and 'max' statements for bounds known in advance.

    Each statement selects ``prefix``, the literal label and the bound
    rendered as raw SQL text, limited to one row of ``table``.
    """
    min_statement = (
        sa.select(prefix, sa.text("'min'"), sa.text(f"{lower}"))
        .limit(1)
        .select_from(table)
    )
    max_statement = (
        sa.select(prefix, sa.text("'max'"), sa.text(f"{upper}"))
        .limit(1)
        .select_from(table)
    )
    return min_statement, max_statement
def public_bounds(
    session: sa.orm.Session,
    table: sa.sql.FromClause,
    data_col: str,
    result_table_prefixes: Optional[Tuple[sa.sql.ColumnElement]] = None,
    return_insert: bool = False,
) -> Tuple[sa.sql.Selectable, sa.sql.Selectable]:
    """Build SQL statements returning the exact minimum and maximum.

    No noise is added: the statements aggregate ``data_col`` directly with
    SQL ``min`` and ``max``.

    Parameters
    -----------
    session: sa.orm.Session
        not read by this function
    table: sa.sql.FromClause
        selectable holding the data; the column is read through ``table.c``
    data_col: str
        name of the column with values to be evaluated
    result_table_prefixes: Optional[Tuple[sa.sql.ColumnElement]]
        its first element becomes the first selected column of both
        statements. Required in practice despite the ``None`` default.
    return_insert: bool
        not read by this function

    Returns
    -------
    Tuple[sa.sql.Selectable, sa.sql.Selectable]
        the statement selecting ``(prefix, 'min', min(data_col))`` and the
        statement selecting ``(prefix, 'max', max(data_col))``

    Raises
    -------
    AssertionError
        if result_table_prefixes is empty or None
    """
    assert result_table_prefixes
    prefix = result_table_prefixes[0]
    values = getattr(table.c, data_col)

    # Each aggregate lives in its own CTE selected next to the prefix.
    lowest = sa.select(sa.func.min(values)).cte()
    highest = sa.select(sa.func.max(values)).cte()

    return (
        sa.select(prefix, sa.text("'min'"), lowest),
        sa.select(prefix, sa.text("'max'"), highest),
    )
def construct_inner_case(column: sa.Column, bins: t.List[float]) -> sa.Case:
    """Build a flat CASE mapping ``column`` to the lower edge of its bin.

    One ``WHEN column BETWEEN low AND high THEN low`` branch is produced per
    pair of consecutive edges; the ELSE branch yields the second-to-last
    edge. ``bins`` must hold at least two edges, otherwise ``IndexError`` is
    raised.
    """
    whens = [
        (column.between(low, high), low) for low, high in zip(bins, bins[1:])
    ]
    return sa.case(*whens, else_=bins[-2])
def construct_outer_case(
    column: sa.Column, list_of_bins: t.List[t.List[float]]
) -> sa.Case:
    """Build a two-level CASE from several groups of bin edges.

    For every group, the outer branch tests whether ``column`` lies between
    the first and last edge of the group and, if so, delegates to the flat
    CASE built by ``construct_inner_case`` for that group. Values matching
    no group yield NULL.
    """
    branches = [
        (
            column.between(chunk[0], chunk[-1]),
            construct_inner_case(column, chunk),
        )
        for chunk in list_of_bins
    ]
    return sa.case(*branches, else_=None)
def case_from_bins(column: sa.Column, bins: t.List[float]) -> sa.Case:
    """Build a CASE expression labelling ``column`` with its bin's lower edge.

    Given a column and bins, if the number of bins is < 4 it constructs
    the following case:
    CASE
    WHEN column BETWEEN bins[0] AND bins[1] THEN bins[0]
    WHEN column BETWEEN bins[1] AND bins[2] THEN bins[1]
    ...
    END
    Else it constructs the following:
    CASE
    WHEN column BETWEEN bins[0] AND bins[c1] THEN
        CASE
        WHEN column BETWEEN bins[0] AND bins[1] THEN bins[0]
        WHEN column BETWEEN bins[1] AND bins[2] THEN bins[1]
        ...
        WHEN column BETWEEN bins[c1-1] AND bins[c1] THEN bins[c1-1]
        ELSE bins[c1-1]
        END
    WHEN column BETWEEN bins[c1] AND bins[c2] THEN
        CASE
        WHEN column BETWEEN bins[c1] AND bins[c1+1] THEN bins[c1]
        ...
        ELSE bins[c2-1]
        END
    ...
    END
    This is equivalent to the former but prevents a stackoverflow error from
    qrlew when having a very long list of bins.
    The bin list is divided into chunks such that the number of WHEN conditions
    of the outer CASE is similar to the one in the inner CASE.

    Consecutive chunks share one edge. Without that shared edge, values
    lying strictly between the last edge of a chunk and the first edge of
    the next one would match no outer WHEN and be labelled NULL.
    """
    chunk_size = int(np.sqrt(len(bins)))
    if chunk_size > 1:
        # construct a nested case
        # Each chunk spans chunk_size intervals, i.e. chunk_size + 1 edges,
        # and starts on the last edge of the previous chunk. Stopping the
        # range at len(bins) - 1 guarantees every chunk has >= 2 edges.
        chunks = [
            bins[start : start + chunk_size + 1]
            for start in range(0, len(bins) - 1, chunk_size)
        ]
        return construct_outer_case(column, chunks)

    # construct a standard case when the number of bins is very small.
    conditions = [
        (column.between(low, high), low) for low, high in zip(bins, bins[1:])
    ]
    return sa.case(*conditions)