Why Gemfury? Push, build, and install RubyGems, npm packages, Python packages, Maven artifacts, PHP packages, Go Modules, Debian packages, RPM packages, and NuGet packages.

Repository URL to install this package:

Details    
sarus_query_builder / sarus_query_builder / builders / synthetic_builder.py
Size: Mime:
from __future__ import annotations

import copy
import os
import typing

import numpy as np
from sarus_data_spec.typing import Dataset
from sarus_differential_privacy.query import PrivateQuery, SampledQuery
from sarus_differential_privacy.sample import SampleWithoutReplacement
from sarus_synthetic_data.columns_generator.default_params import (
    synthesizer_params as synthesizer_column_params,
)
from sarus_synthetic_data.columns_generator.queries import (
    private_queries_no_correlation,
    set_noise_no_correlation,
)
from sarus_synthetic_data.correlations_generator.default_params_correlations import (  # noqa: E501
    TrainingHyperparameters,
    set_batch_size_params,
    synthesizer_params,
)
from sarus_synthetic_data.correlations_generator.queries import (
    private_queries_jax,
    set_noise_jax,
)

from sarus_query_builder.core.core import OptimizableQueryBuilder, QueryBuilder
from sarus_query_builder.core.typing import Task
from sarus_query_builder.protobuf.query_pb2 import Query, SyntheticData

# Fallback sampling ratio used when a request leaves ``sampling_ratio`` unset
# (zero). A ratio of 1.0 means no subsampling: private queries are only
# wrapped in a ``SampledQuery`` when the ratio is strictly below 1.
DEFAULT_SAMPLING: float = 1.0


class SyntheticDPSGDBuilder(QueryBuilder):
    """Synthetic data builder based on the ``columns_generator`` helpers.

    Turns a ``Query.SyntheticDPSGD`` request into a ``SyntheticData`` task
    (``build_query``) and derives the private query used for accounting
    (``private_query``).
    """

    def __init__(self, dataset: Dataset):
        self._dataset = dataset

    def build_query(self, input_parameter: Query.SyntheticDPSGD) -> Task:
        """Return a ``SyntheticData`` task for ``input_parameter``.

        An unset (zero) ``sampling_ratio`` falls back to
        ``DEFAULT_SAMPLING``. Generator parameters are recomputed from the
        dataset on every call and then given the requested
        ``noise_multiplier``.
        """
        ratio = input_parameter.sampling_ratio or DEFAULT_SAMPLING
        column_params = synthesizer_column_params(dataset=self.dataset)
        noisy_generator = set_noise_no_correlation(
            generator_params=column_params,
            noise=input_parameter.noise_multiplier,
        )
        return SyntheticData(sampling_ratio=ratio, generator=noisy_generator)

    def private_query(self, out: Task) -> PrivateQuery:
        """Return the private query describing task ``out``.

        When ``out.sampling_ratio`` is below 1 the query is wrapped in a
        ``SampledQuery`` using sampling without replacement; otherwise the
        un-sampled query is returned as is.
        """
        # NOTE: ``out`` is expected to be a SyntheticData task (as produced by
        # build_query); no runtime type check is performed here.
        base_query = private_queries_no_correlation(
            generator_params=out.generator
        )
        if out.sampling_ratio < 1:
            sampler = SampleWithoutReplacement(out.sampling_ratio)
            return SampledQuery(base_query, sampler)
        return base_query


class OptimizableSyntheticDPSGDBuilder(OptimizableQueryBuilder):
    """Optimizable wrapper around ``SyntheticDPSGDBuilder``.

    The scalar being optimized is the inverse of the DP-SGD noise
    multiplier.
    """

    def __init__(self, dataset: Dataset, query: Query):
        self._dataset = dataset
        self.query = query
        self._builders = [SyntheticDPSGDBuilder(dataset)]

    def build_query(self, input_parameter: float) -> Task:
        """Build a task whose noise multiplier is ``1 / input_parameter``.

        A falsy ``input_parameter`` (e.g. 0) maps to an infinite noise
        multiplier.

        NOTE: the noise multiplier is written into ``self.query`` in place,
        so the ``Query`` passed to ``__init__`` is modified by this call.
        """
        noise = 1 / input_parameter if input_parameter else np.inf
        dpsgd_params = self.query.synthetic_dpsgd
        dpsgd_params.noise_multiplier = noise
        return self.builders[0].build_query(dpsgd_params)


class SyntheticJaxBuilder(QueryBuilder):
    """Synthetic data builder based on the ``correlations_generator`` helpers.

    The default synthesizer parameters are computed once in ``__init__`` so
    the dataset is not iterated again on every ``build_query`` call. The
    number of marginal and DP-SGD training steps can be overridden with the
    ``N_MARG_STEPS`` (default 20) and ``N_GRAD_STEPS`` (default 100)
    environment variables.
    """

    def __init__(self, dataset: Dataset, query: Query):
        self._dataset = dataset
        self._query = query

        marg_steps = int(os.environ.get('N_MARG_STEPS', default=20))
        grad_steps = int(os.environ.get('N_GRAD_STEPS', default=100))
        generator_spec = self._query.synthetic_dpsgd.generator
        jax_text_enabled = generator_spec.table_generator.use_jax_text

        hyperparams = TrainingHyperparameters(
            default_marg_steps=marg_steps,
            default_dpsgd_steps=grad_steps,
            use_jax_text=jax_text_enabled,
        )
        self.default_params = synthesizer_params(self.dataset, hyperparams)

    def build_query(self, input_parameter: Query.SyntheticDPSGD) -> Task:
        """Return a ``SyntheticData`` task for ``input_parameter``.

        Starts from the cached default parameters, applies
        ``batch_size`` when it is set (non-zero), then the requested
        ``noise_multiplier``. An unset (zero) ``sampling_ratio`` falls back
        to ``DEFAULT_SAMPLING``.
        """
        if input_parameter.batch_size:
            # Deep-copy first so the cached defaults are left untouched
            # (set_batch_size_params presumably updates ``params`` in place).
            params = set_batch_size_params(
                params=copy.deepcopy(self.default_params),
                batch_size=input_parameter.batch_size,
            )
        else:
            params = self.default_params

        ratio = input_parameter.sampling_ratio or DEFAULT_SAMPLING
        noisy_generator = set_noise_jax(
            synth_params=params,
            noise=input_parameter.noise_multiplier,
            dataset=self.dataset,
        )
        return SyntheticData(sampling_ratio=ratio, generator=noisy_generator)

    def private_query(self, out: Task) -> PrivateQuery:
        """Return the private query describing task ``out``.

        When ``out.sampling_ratio`` is below 1 the query is wrapped in a
        ``SampledQuery`` using sampling without replacement; otherwise the
        un-sampled query is returned as is.
        """
        # NOTE: ``out`` is expected to be a SyntheticData task (as produced by
        # build_query); no runtime type check is performed here.
        base_query = private_queries_jax(
            synthetic_params=out.generator, dataset=self.dataset
        )
        if out.sampling_ratio < 1:
            sampler = SampleWithoutReplacement(out.sampling_ratio)
            return SampledQuery(base_query, sampler)
        return base_query


class OptimizableSyntheticJaxBuilder(OptimizableQueryBuilder):
    """Optimizable wrapper around ``SyntheticJaxBuilder``.

    The scalar being optimized is the inverse of the DP-SGD noise
    multiplier.
    """

    def __init__(self, dataset: Dataset, query: Query):
        self._dataset = dataset
        self.query = query
        self._builders = [SyntheticJaxBuilder(dataset, query)]

    def build_query(self, input_parameter: float) -> Task:
        """Build a task whose noise multiplier is ``1 / input_parameter``.

        A falsy ``input_parameter`` (e.g. 0) maps to an infinite noise
        multiplier.

        NOTE: the noise multiplier is written into ``self.query`` in place,
        so the ``Query`` passed to ``__init__`` is modified by this call.
        """
        noise = 1 / input_parameter if input_parameter else np.inf
        dpsgd_params = self.query.synthetic_dpsgd
        dpsgd_params.noise_multiplier = noise
        return self.builders[0].build_query(dpsgd_params)


def synthetic_dpsgd_builder(
    dataset: Dataset, query: Query
) -> typing.Union[
    OptimizableSyntheticJaxBuilder, OptimizableSyntheticDPSGDBuilder
]:
    """Return the optimizable synthetic-data builder matching ``query``.

    If the ``generator`` oneof of ``query.synthetic_dpsgd.generator`` is set
    to ``column_generator``, the column-based builder is used. Every other
    case, including an unset oneof, falls back to the JAX-based builder.
    """
    generator_kind = query.synthetic_dpsgd.generator.WhichOneof('generator')
    builder_cls = (
        OptimizableSyntheticDPSGDBuilder
        if generator_kind == 'column_generator'
        else OptimizableSyntheticJaxBuilder
    )
    return builder_cls(dataset, query)