Why Gemfury? Push, build, and install: RubyGems, npm packages, Python packages, Maven artifacts, PHP packages, Go Modules, Debian packages, RPM packages, NuGet packages.

Repository URL to install this package:

Details    
sarus_query_builder / sarus_query_builder / builders / marginals_builder.py
Size: Mime:
from __future__ import annotations

import numpy as np
import sarus_data_spec.status as sds
from sarus_data_spec.constants import BIG_DATA_TASK, BIG_DATA_THRESHOLD, DATA

try:
    from sarus_data_spec.manager.async_utils import sync
    from sarus_data_spec.manager.ops.processor.standard.sampling.differentiated_sampling_sizes import (
        differentiated_sampling_sizes_bisection,
    )
    from sarus_data_spec.manager.ops.processor.standard.sampling.size_utils import (
        differentiated_sampled_size,
    )
except ImportError:
    pass

from sarus_data_spec.typing import Dataset
from sarus_differential_privacy.query import PrivateQuery, SampledQuery
from sarus_differential_privacy.sample import SampleWithoutReplacement
from sarus_statistics.tasks.marginals.base import MarginalsParameters
from sarus_statistics.tasks.marginals.visitor import default_marginal

from sarus_query_builder.core.core import OptimizableQueryBuilder, QueryBuilder
from sarus_query_builder.core.typing import Task
from sarus_query_builder.protobuf.query_pb2 import Query


class MarginalsBuilder(QueryBuilder):
    """Build the marginals task for a dataset.

    The task is a ``MarginalsParameters`` tree derived from the dataset
    schema, populated with sampling ratios and a noise level, and returned
    in its protobuf form.
    """

    def __init__(self, dataset: Dataset):
        """Keep the dataset and read its schema and size once, up front."""
        self._dataset = dataset
        self._schema = dataset.schema()
        self._size = dataset.size()

    def build_query(self, input_parameter: Query.Marginals) -> Task:
        """Return the marginals task configured from ``input_parameter``.

        Only ``input_parameter.noise`` is read. The sampling ratios depend
        on whether the dataset's manager reports the dataset as big data.
        """
        ds = self.dataset
        params = MarginalsParameters(
            default_marginal(self._schema.data_type())
        )
        # Start with no global sampling; only the big-data path lowers it.
        params._protobuf.sampling_ratio = 1

        if not ds.manager().is_big_data(ds):

            def ratio_for(size: float) -> float:
                """Ratio keeping at most 10000 out of ``size``, capped at 1."""
                return min(1, 10000 / size)

            size_proto = self._size.protobuf().statistics
        else:
            # The threshold is stored in the properties of the big-data
            # stage of the dataset's last status.
            status = sds.last_status(ds, task=BIG_DATA_TASK)
            big_data_stage = status.task(task=BIG_DATA_TASK)
            threshold = int(
                big_data_stage.properties().get(BIG_DATA_THRESHOLD)
            )

            # set global sampling
            assert self._size
            sampled_sizes = sync(
                differentiated_sampling_sizes_bisection(
                    self.dataset, threshold
                )
            )

            def ratio_for(size: float) -> float:
                """Ratio of the largest sampled size to ``size``, capped at 1."""
                return min(1, max(sampled_sizes.values()) / size)

            params._protobuf.sampling_ratio = (
                threshold / self._size.statistics().size()
            )
            size_proto = differentiated_sampled_size(
                self._size.statistics(), sampled_sizes, curr_path=[DATA]
            ).protobuf()

        params.set_sampling_ratio(ratio_for, size_proto)
        params.set_noise(input_parameter.noise)
        return params.protobuf()

    def private_query(self, out: Task) -> PrivateQuery:
        """Return the private query matching the task ``out``.

        When the task's global sampling ratio is below 1, the marginals
        query is wrapped in a ``SampledQuery`` using sampling without
        replacement at that ratio; otherwise it is returned unwrapped.
        """
        marginal_queries = MarginalsParameters(out).private_query()
        if out.sampling_ratio < 1:
            return SampledQuery(
                marginal_queries,
                SampleWithoutReplacement(out.sampling_ratio),
            )
        return marginal_queries


class OptimizableMarginalsBuilder(OptimizableQueryBuilder):
    """Optimizable wrapper driving a single ``MarginalsBuilder``.

    The float parameter given to ``build_query`` is converted into a noise
    level that is written onto the stored query's ``marginals`` field.
    """

    def __init__(self, dataset: Dataset, query: Query):
        """Store the dataset and query and create the wrapped builder."""
        self._dataset = dataset
        self.query = query
        self._builders = [MarginalsBuilder(dataset)]

    def build_query(self, input_parameter: float) -> Task:
        """Build the marginals task for ``input_parameter``.

        The noise is ``1 / input_parameter``; a falsy parameter (such as
        0) gives infinite noise. The stored query is modified in place.
        """
        stored_query = self.query
        stored_query.marginals.noise = (
            1 / input_parameter if input_parameter else np.inf
        )
        marginals_builder_ = self.builders[0]
        return marginals_builder_.build_query(stored_query.marginals)


def marginals_builder(
    dataset: Dataset, query: Query
) -> OptimizableMarginalsBuilder:
    """Create an ``OptimizableMarginalsBuilder`` for ``dataset`` and ``query``."""
    builder = OptimizableMarginalsBuilder(dataset, query)
    return builder