Repository URL to install this package:
|
Version:
2.7.2
|
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Dict, Union
import numpy as np
from sarus_statistics.ops.tau_thresholding.op import TauThresholdingOp
from sarus_query_builder.core.core import OptimizableQueryBuilder, QueryBuilder
from sarus_query_builder.core.typing import Task
from sarus_query_builder.protobuf.query_pb2 import GenericTask, Query
if TYPE_CHECKING:
from sarus_data_spec.typing import Dataset
from sarus_differential_privacy.query import ( # type: ignore
ComposedQuery,
EpsilonDeltaQuery,
EpsilonQuery,
LaplaceQuery,
PrivateQuery,
)
class TauThresholdingBuilder(QueryBuilder):
    """Generate Tau thresholding hyperparameters.

    Converts a ``Query.TauThresholding`` message into a ``GenericTask``
    carrying its epsilon/delta, and converts such a task back into the
    private query of a ``TauThresholdingOp`` over the wrapped dataset.
    """

    def __init__(self, dataset: Dataset):
        self._dataset = dataset

    def build_query(self, input_parameter: Query.TauThresholding) -> Task:
        """Pack the tau-thresholding epsilon and delta into a GenericTask.

        Args:
            input_parameter: message providing ``epsilon_tau_thresholding``
                and ``delta_tau_thresholding``.

        Returns:
            A ``GenericTask`` whose ``parameters`` hold both values under
            the same field names.
        """
        epsilon = input_parameter.epsilon_tau_thresholding
        delta = input_parameter.delta_tau_thresholding
        task_params = {
            'epsilon_tau_thresholding': epsilon,
            'delta_tau_thresholding': delta,
        }
        return GenericTask(parameters=task_params)

    def private_query(self, out: Task) -> PrivateQuery:
        """Return the private query of a TauThresholdingOp configured by ``out``.

        Args:
            out: task produced by :meth:`build_query`; its ``parameters``
                must contain the epsilon and delta keys.

        Returns:
            The result of ``TauThresholdingOp(...).private_query()``.
        """
        # NOTE(review): ``dataset`` is presumably a property provided by
        # QueryBuilder that exposes ``self._dataset`` — confirm in the base.
        dataset = self.dataset
        task_params = out.parameters
        epsilon = task_params['epsilon_tau_thresholding']
        delta = task_params['delta_tau_thresholding']
        tau_op = TauThresholdingOp(dataset, epsilon, delta)
        return tau_op.private_query()
class OptimizableTauThresholdingBuilder(OptimizableQueryBuilder):
    """Optimizable builder whose free parameter is the tau-thresholding epsilon.

    Wraps a single :class:`TauThresholdingBuilder`. The value passed to
    :meth:`build_query` becomes ``epsilon_tau_thresholding``; every other
    field (notably ``delta_tau_thresholding``) is taken from
    ``query.tau_thresholding``.
    """

    def __init__(self, dataset: Dataset, query: Query):
        self._dataset = dataset
        self.query = query
        self._builders = [TauThresholdingBuilder(dataset)]

    def build_query(self, input_parameter: float) -> Task:
        """Build the task for a candidate tau-thresholding epsilon.

        Args:
            input_parameter: value to use as ``epsilon_tau_thresholding``.

        Returns:
            The task produced by the wrapped ``TauThresholdingBuilder``.
        """
        # Work on a copy of the sub-message. ``self.query`` may be owned by
        # the caller (it is passed straight through by tau_threshold_builder),
        # and repeated calls must not keep overwriting its epsilon.
        tau_params = Query.TauThresholding()
        tau_params.CopyFrom(self.query.tau_thresholding)
        tau_params.epsilon_tau_thresholding = input_parameter
        # NOTE(review): ``builders`` is presumably exposed by
        # OptimizableQueryBuilder from ``self._builders`` — confirm in the base.
        return self.builders[0].build_query(tau_params)
def tau_threshold_builder(
    dataset: Dataset, query: Query
) -> OptimizableTauThresholdingBuilder:
    """Create an optimizable tau-thresholding builder for ``dataset``.

    Args:
        dataset: dataset the tau-thresholding op will run on.
        query: query whose ``tau_thresholding`` field supplies the
            non-optimized parameters (e.g. the delta).

    Returns:
        A new :class:`OptimizableTauThresholdingBuilder`.
    """
    builder = OptimizableTauThresholdingBuilder(dataset=dataset, query=query)
    return builder
def tau_threshold_builder_delta(
    dataset: Dataset, delta: float
) -> OptimizableTauThresholdingBuilder:
    """Create an optimizable tau-thresholding builder from a delta value.

    Only ``delta / 1e3`` is used as ``delta_tau_thresholding``. The epsilon
    is not set here; it is supplied on each call to ``build_query`` of the
    returned builder.

    Args:
        dataset: dataset the tau-thresholding op will run on.
        delta: delta value from which the tau-thresholding delta is derived.

    Returns:
        A new :class:`OptimizableTauThresholdingBuilder`.
    """
    # NOTE(review): the 1/1000 share looks like a privacy-budget split that
    # reserves most of ``delta`` for other steps — confirm intent.
    tau_delta = delta / 1e3
    tau_params = Query.TauThresholding(delta_tau_thresholding=tau_delta)
    base_query = Query(tau_thresholding=tau_params)
    return OptimizableTauThresholdingBuilder(dataset, base_query)