Why Gemfury? Push, build, and install RubyGems, npm packages, Python packages, Maven artifacts, PHP packages, Go modules, Debian packages, RPM packages, and NuGet packages.

Repository URL to install this package:

Details    
sarus_query_builder / sarus_query_builder / builders / tau_thresholding_builder.py
Size: Mime:
from __future__ import annotations

from typing import TYPE_CHECKING, Any, Dict, Union

import numpy as np
from sarus_statistics.ops.tau_thresholding.op import TauThresholdingOp

from sarus_query_builder.core.core import OptimizableQueryBuilder, QueryBuilder
from sarus_query_builder.core.typing import Task
from sarus_query_builder.protobuf.query_pb2 import GenericTask, Query

if TYPE_CHECKING:
    from sarus_data_spec.typing import Dataset

from sarus_differential_privacy.query import (  # type: ignore
    ComposedQuery,
    EpsilonDeltaQuery,
    EpsilonQuery,
    LaplaceQuery,
    PrivateQuery,
)


class TauThresholdingBuilder(QueryBuilder):
    """Query builder for the tau-thresholding step.

    ``build_query`` copies the epsilon/delta values of a
    ``Query.TauThresholding`` message into a ``GenericTask``;
    ``private_query`` turns such a task back into the private query of a
    ``TauThresholdingOp`` running on the builder's dataset.
    """

    def __init__(self, dataset: Dataset):
        # Read back in ``private_query`` through ``self.dataset``
        # (NOTE(review): presumably a property provided by ``QueryBuilder``
        # — confirm).
        self._dataset = dataset

    def build_query(self, input_parameter: Query.TauThresholding) -> Task:
        """Wrap the tau-thresholding epsilon and delta into a ``GenericTask``."""
        epsilon = input_parameter.epsilon_tau_thresholding
        delta = input_parameter.delta_tau_thresholding
        task_parameters = {
            'epsilon_tau_thresholding': epsilon,
            'delta_tau_thresholding': delta,
        }
        return GenericTask(parameters=task_parameters)

    def private_query(self, out: Task) -> PrivateQuery:
        """Return the private query of a ``TauThresholdingOp`` built from *out*.

        *out* is expected to carry the two parameters written by
        ``build_query``.
        """
        dataset = self.dataset
        params = out.parameters
        op = TauThresholdingOp(
            dataset,
            params['epsilon_tau_thresholding'],
            params['delta_tau_thresholding'],
        )
        return op.private_query()


class OptimizableTauThresholdingBuilder(OptimizableQueryBuilder):
    """Optimizable builder whose free parameter is the tau-thresholding epsilon.

    The other tau-thresholding settings (such as delta) come from the
    ``Query`` given at construction time. Building the task itself is
    delegated to a single wrapped ``TauThresholdingBuilder``.
    """

    def __init__(self, dataset: Dataset, query: Query):
        self._dataset = dataset
        self.query = query
        # ``build_query`` reads this list through ``self.builders``
        # (NOTE(review): presumably a property of ``OptimizableQueryBuilder``
        # — confirm).
        self._builders = [TauThresholdingBuilder(dataset)]

    def build_query(self, input_parameter: float) -> Task:
        """Build the tau-thresholding task for a candidate epsilon.

        Args:
            input_parameter: epsilon to use for tau thresholding.

        Returns:
            The task produced by the wrapped ``TauThresholdingBuilder``, with
            delta taken from ``self.query`` and epsilon set to
            *input_parameter*.
        """
        # Work on a copy of the sub-message. ``self.query`` is the object the
        # caller passed in, so writing the candidate epsilon into it directly
        # would overwrite the caller's query on every call.
        tau_thresholding = Query.TauThresholding()
        tau_thresholding.CopyFrom(self.query.tau_thresholding)
        tau_thresholding.epsilon_tau_thresholding = input_parameter
        return self.builders[0].build_query(tau_thresholding)


def tau_threshold_builder(
    dataset: Dataset, query: Query
) -> OptimizableTauThresholdingBuilder:
    """Create an ``OptimizableTauThresholdingBuilder`` for *query* on *dataset*."""
    builder = OptimizableTauThresholdingBuilder(dataset=dataset, query=query)
    return builder


def tau_threshold_builder_delta(
    dataset: Dataset, delta: float, *, delta_divisor: float = 1e3
) -> OptimizableTauThresholdingBuilder:
    """Create a tau-thresholding builder from an overall delta.

    Only ``delta / delta_divisor`` is assigned to tau thresholding. Epsilon is
    left unset in the query; it is supplied later as the argument of
    ``OptimizableTauThresholdingBuilder.build_query``.

    Args:
        dataset: dataset the tau-thresholding op will run on.
        delta: delta value to split (NOTE(review): presumably the overall
            delta budget of the query — confirm with callers).
        delta_divisor: factor by which *delta* is divided to obtain the
            tau-thresholding delta. Defaults to ``1e3``.

    Returns:
        An ``OptimizableTauThresholdingBuilder`` over *dataset*.

    Raises:
        ValueError: if *delta_divisor* is not strictly positive.
    """
    if delta_divisor <= 0:
        raise ValueError(
            f"delta_divisor must be strictly positive, got {delta_divisor!r}"
        )
    query = Query(
        tau_thresholding=Query.TauThresholding(
            delta_tau_thresholding=delta / delta_divisor
        )
    )
    return OptimizableTauThresholdingBuilder(dataset, query)