Repository URL to install this package:
|
Version:
2.7.2 ▾
|
from __future__ import annotations
import copy
import os
import typing
import numpy as np
from sarus_data_spec.typing import Dataset
from sarus_differential_privacy.query import PrivateQuery, SampledQuery
from sarus_differential_privacy.sample import SampleWithoutReplacement
from sarus_synthetic_data.columns_generator.default_params import (
synthesizer_params as synthesizer_column_params,
)
from sarus_synthetic_data.columns_generator.queries import (
private_queries_no_correlation,
set_noise_no_correlation,
)
from sarus_synthetic_data.correlations_generator.default_params_correlations import ( # noqa: E501
TrainingHyperparameters,
set_batch_size_params,
synthesizer_params,
)
from sarus_synthetic_data.correlations_generator.queries import (
private_queries_jax,
set_noise_jax,
)
from sarus_query_builder.core.core import OptimizableQueryBuilder, QueryBuilder
from sarus_query_builder.core.typing import Task
from sarus_query_builder.protobuf.query_pb2 import Query, SyntheticData
# Sampling ratio substituted by the builders' `build_query` methods when the
# incoming message's `sampling_ratio` is falsy (`sampling_ratio or
# DEFAULT_SAMPLING`).
DEFAULT_SAMPLING = 1.0
class SyntheticDPSGDBuilder(QueryBuilder):
    """Query builder for the column-generator flavour of synthetic data.

    Converts a ``Query.SyntheticDPSGD`` message into a ``SyntheticData``
    task and derives the private query that corresponds to such a task.
    """

    def __init__(self, dataset: Dataset):
        """Keep a handle on the dataset used to derive generator parameters.

        Args:
            dataset: Stored on ``_dataset``. The other methods read it back
                through ``self.dataset``, which is not defined in this
                class (presumably exposed by ``QueryBuilder`` -- confirm).
        """
        self._dataset = dataset

    def build_query(self, input_parameter: Query.SyntheticDPSGD) -> Task:
        """Assemble the ``SyntheticData`` task described by the message.

        Args:
            input_parameter: Message supplying ``sampling_ratio`` and
                ``noise_multiplier``.

        Returns:
            A ``SyntheticData`` built from the sampling ratio and from the
            output of ``set_noise_no_correlation`` applied to the dataset's
            column synthesizer parameters.
        """
        # A falsy sampling ratio (e.g. 0) falls back to the default.
        ratio = input_parameter.sampling_ratio or DEFAULT_SAMPLING
        column_params = synthesizer_column_params(dataset=self.dataset)
        noisy_generator = set_noise_no_correlation(
            generator_params=column_params,
            noise=input_parameter.noise_multiplier,
        )
        return SyntheticData(sampling_ratio=ratio, generator=noisy_generator)

    def private_query(self, out: Task) -> PrivateQuery:
        """Return the private query matching a task built by this builder.

        Args:
            out: Task exposing ``generator`` and ``sampling_ratio``. Its
                type is not checked here.

        Returns:
            The query from ``private_queries_no_correlation``; it is wrapped
            in a ``SampledQuery`` (sampling without replacement) when the
            task's sampling ratio is strictly below 1.
        """
        unsampled = private_queries_no_correlation(
            generator_params=out.generator
        )
        ratio = out.sampling_ratio
        if not ratio < 1:
            return unsampled
        return SampledQuery(unsampled, SampleWithoutReplacement(ratio))
class OptimizableSyntheticDPSGDBuilder(OptimizableQueryBuilder):
    """Optimizable wrapper around a single ``SyntheticDPSGDBuilder``.

    The scalar handed to ``build_query`` is inverted to obtain the noise
    multiplier written into the stored query.
    """

    def __init__(self, dataset: Dataset, query: Query):
        """Store the inputs and create the wrapped builder.

        Args:
            dataset: Dataset forwarded to ``SyntheticDPSGDBuilder``.
            query: Query whose ``synthetic_dpsgd`` part is updated and
                forwarded by ``build_query``.
        """
        self._dataset = dataset
        self.query = query
        # `build_query` reads this back through `self.builders`, which is
        # not defined in this class (presumably provided by
        # `OptimizableQueryBuilder` -- confirm).
        self._builders = [SyntheticDPSGDBuilder(dataset)]

    def build_query(self, input_parameter: float) -> Task:
        """Build the task for a given optimisation parameter.

        NOTE: the noise multiplier is written into ``self.query`` itself,
        so the stored query is modified by every call.

        Args:
            input_parameter: Value whose inverse becomes the noise
                multiplier. A falsy value (e.g. 0) yields ``np.inf``.

        Returns:
            The task produced by the first wrapped builder.
        """
        dpsgd_spec = self.query.synthetic_dpsgd
        dpsgd_spec.noise_multiplier = (
            1 / input_parameter if input_parameter else np.inf
        )
        return self.builders[0].build_query(dpsgd_spec)
class SyntheticJaxBuilder(QueryBuilder):
    """Query builder for the JAX-based synthetic data generator.

    The default synthesizer parameters are computed once, in the
    constructor, to reduce iterations on the dataset.
    """

    def __init__(self, dataset: Dataset, query: Query):
        """Store the inputs and precompute the default parameters.

        The numbers of marginal and gradient steps come from the
        ``N_MARG_STEPS`` and ``N_GRAD_STEPS`` environment variables
        (defaults: 20 and 100).

        Args:
            dataset: Stored on ``_dataset``. It is read back through
                ``self.dataset``, which is not defined in this class
                (presumably exposed by ``QueryBuilder`` -- confirm).
            query: Query whose table generator provides ``use_jax_text``.
        """
        self._dataset = dataset
        self._query = query
        marg_steps = int(os.environ.get('N_MARG_STEPS', 20))
        grad_steps = int(os.environ.get('N_GRAD_STEPS', 100))
        table_generator = (
            self._query.synthetic_dpsgd.generator.table_generator
        )
        hyperparameters = TrainingHyperparameters(
            default_marg_steps=marg_steps,
            default_dpsgd_steps=grad_steps,
            use_jax_text=table_generator.use_jax_text,
        )
        self.default_params = synthesizer_params(
            self.dataset, hyperparameters
        )

    def build_query(self, input_parameter: Query.SyntheticDPSGD) -> Task:
        """Assemble the ``SyntheticData`` task described by the message.

        Args:
            input_parameter: Message supplying ``batch_size``,
                ``sampling_ratio`` and ``noise_multiplier``.

        Returns:
            A ``SyntheticData`` built from the sampling ratio and from the
            output of ``set_noise_jax``.
        """
        requested_batch_size = input_parameter.batch_size
        if requested_batch_size:
            # Work on a deep copy so the cached defaults are never handed
            # to `set_batch_size_params`.
            synth_params = set_batch_size_params(
                params=copy.deepcopy(self.default_params),
                batch_size=requested_batch_size,
            )
        else:
            synth_params = self.default_params

        # A falsy sampling ratio (e.g. 0) falls back to the default.
        ratio = input_parameter.sampling_ratio or DEFAULT_SAMPLING
        noisy_generator = set_noise_jax(
            synth_params=synth_params,
            noise=input_parameter.noise_multiplier,
            dataset=self.dataset,
        )
        return SyntheticData(sampling_ratio=ratio, generator=noisy_generator)

    def private_query(self, out: Task) -> PrivateQuery:
        """Return the private query matching a task built by this builder.

        Args:
            out: Task exposing ``generator`` and ``sampling_ratio``. Its
                type is not checked here.

        Returns:
            The query from ``private_queries_jax``; it is wrapped in a
            ``SampledQuery`` (sampling without replacement) when the task's
            sampling ratio is strictly below 1.
        """
        unsampled = private_queries_jax(
            synthetic_params=out.generator, dataset=self.dataset
        )
        ratio = out.sampling_ratio
        if not ratio < 1:
            return unsampled
        return SampledQuery(unsampled, SampleWithoutReplacement(ratio))
class OptimizableSyntheticJaxBuilder(OptimizableQueryBuilder):
    """Optimizable wrapper around a single ``SyntheticJaxBuilder``.

    The scalar handed to ``build_query`` is inverted to obtain the noise
    multiplier written into the stored query.
    """

    def __init__(self, dataset: Dataset, query: Query):
        """Store the inputs and create the wrapped builder.

        Args:
            dataset: Dataset forwarded to ``SyntheticJaxBuilder``.
            query: Query forwarded to ``SyntheticJaxBuilder``; its
                ``synthetic_dpsgd`` part is updated by ``build_query``.
        """
        self._dataset = dataset
        self.query = query
        # `build_query` reads this back through `self.builders`, which is
        # not defined in this class (presumably provided by
        # `OptimizableQueryBuilder` -- confirm).
        self._builders = [SyntheticJaxBuilder(dataset, query)]

    def build_query(self, input_parameter: float) -> Task:
        """Build the task for a given optimisation parameter.

        NOTE: the noise multiplier is written into ``self.query`` itself,
        so the stored query is modified by every call.

        Args:
            input_parameter: Value whose inverse becomes the noise
                multiplier. A falsy value (e.g. 0) yields ``np.inf``.

        Returns:
            The task produced by the first wrapped builder.
        """
        dpsgd_spec = self.query.synthetic_dpsgd
        dpsgd_spec.noise_multiplier = (
            1 / input_parameter if input_parameter else np.inf
        )
        return self.builders[0].build_query(dpsgd_spec)
def synthetic_dpsgd_builder(
    dataset: Dataset, query: Query
) -> typing.Union[
    OptimizableSyntheticJaxBuilder, OptimizableSyntheticDPSGDBuilder
]:
    """Pick the optimizable builder matching the query's generator.

    Args:
        dataset: Dataset handed to the selected builder.
        query: Query whose ``synthetic_dpsgd.generator`` oneof named
            ``generator`` decides which builder is used.

    Returns:
        An ``OptimizableSyntheticDPSGDBuilder`` when the oneof is set to
        ``column_generator``; an ``OptimizableSyntheticJaxBuilder`` in every
        other case.
    """
    selected = query.synthetic_dpsgd.generator.WhichOneof('generator')
    builder_cls = (
        OptimizableSyntheticDPSGDBuilder
        if selected == 'column_generator'
        else OptimizableSyntheticJaxBuilder
    )
    return builder_cls(dataset, query)