# sarus_synthetic_data — package version 4.0.7
from __future__ import annotations
import logging
import os
import typing as t
import pyarrow as pa
import pyarrow.parquet as pq
import numpy as np
from sarus_synthetic_data.configs.global_config import (
SyntheticConfig,
TableConfig,
)
from sarus_synthetic_data.correlations_generator.generator import (
CorrelationGenerator,
)
from sarus_synthetic_data.data_processing.postprocessor import Postprocessor
from sarus_synthetic_data.data_processing.preprocessor import Preprocessor
from sarus_synthetic_data.independent_generator.generator import (
IndependentGenerator,
)
from sarus_synthetic_data.shared.generation_utils import gen_from_cumulative
logger = logging.getLogger(__name__)
class SyntheticDatasetGenerator:
    """Orchestrates synthetic data generation for a multi-table dataset.

    Drives the full pipeline described by a ``SyntheticConfig``:
    preprocessing, per-table training (independent and/or correlation
    generators), sampling, and finally rewriting foreign-key columns so
    sampled tables remain linked.
    """

    def __init__(self, config: SyntheticConfig):
        # Full multi-table configuration: tables, links, special column
        # names (privacy unit, weights, is_private) and saving directory.
        self.config = config

    def train(self) -> None:
        """Preprocess every table, then train its configured generators.

        Training artifacts for each table are written under a per-table
        subdirectory of ``config.saving_directory`` (table names are
        tuples, hence the ``*table_name`` unpacking into path parts).
        """
        # first preprocess
        Preprocessor(config=self.config).preprocess_tables()
        # then train
        for table_name, table_config in self.config.tables.items():
            current_saving_dir = os.path.join(
                self.config.saving_directory, *table_name
            )
            logger.info(f"Starting Training for {table_name}")
            self.train_table(table_config, current_saving_dir)
            logger.info(f"Finished Training for {table_name}")

    def train_table(
        self, table_config: TableConfig, saving_directory: str
    ) -> None:
        """Train the generators enabled for a single table.

        Either generator (independent, correlation) is optional; a table
        may configure one, both, or neither.
        """
        ind_config = table_config.independent_generation
        if ind_config is not None:
            logger.info("Starting Independent Training")
            independent_generator = IndependentGenerator(
                generation_config=ind_config,
                privacy_unit_col=self.config.privacy_unit_col,
                weights_col=self.config.weights_col,
                is_private_col=self.config.is_private_col,
                data_uri=table_config.data_uri,
                saving_dir=saving_directory,
            )
            independent_generator.train()
            logger.info("Finished Independent Training")
        corr_config = table_config.correlation_generation
        if corr_config is not None:
            logger.info("Starting Correlation Training")
            # The correlation generator reads its training data from a
            # pickle produced during preprocessing in the saving dir.
            corr_gen = CorrelationGenerator(
                generation_config=corr_config,
                data_uri=os.path.join(
                    saving_directory, "correlation_data.pkl"
                ),
            )
            corr_gen.train()
            logger.info("Finished Correlation Training")

    def sample(self) -> t.Dict[t.Tuple[str, ...], pa.Table]:
        """Sample every configured table and re-link foreign keys.

        Returns:
            Mapping from table name (tuple of path parts) to the sampled
            ``pyarrow.Table``, with FK columns rewritten by ``add_links``.
        """
        samples = {}
        for table_name, table_config in self.config.tables.items():
            logger.info(f"Starting Sampling for {table_name}")
            current_saving_dir = os.path.join(
                self.config.saving_directory, *table_name
            )
            curr_sample = self.sample_table(
                table_config, saving_directory=current_saving_dir
            )
            samples[table_name] = curr_sample
            logger.info(f"Finished Sampling for {table_name}")
        return self.add_links(samples=samples)

    def sample_table(
        self, table_config: TableConfig, saving_directory: str
    ) -> pa.Table:
        """Sample one table, concatenating columns from each generator.

        Public tables are returned as-is (minus the bookkeeping columns);
        otherwise columns sampled by the independent and correlation
        generators are assembled into a single ``pa.Table``.
        """
        if table_config.is_public:
            logger.info("Returning Public Table")
            table = pq.read_table(table_config.data_uri)
            # Strip the internal bookkeeping columns before returning.
            return table.drop(
                columns=[
                    self.config.privacy_unit_col,
                    self.config.weights_col,
                    self.config.is_private_col,
                ]
            )
        samples = []
        fields = []
        ind_config = table_config.independent_generation
        if ind_config is not None:
            logger.info("Starting Independent Sampling")
            ind_gen = IndependentGenerator(
                generation_config=ind_config,
                privacy_unit_col=self.config.privacy_unit_col,
                weights_col=self.config.weights_col,
                is_private_col=self.config.is_private_col,
                data_uri=table_config.data_uri,
                saving_dir=saving_directory,
            )
            sample = ind_gen.sample()
            samples.extend(sample.flatten())
            fields.extend([sample.field(name) for name in sample.column_names])
            logger.info("Finished Independent Sampling")
        corr_config = table_config.correlation_generation
        if corr_config is not None:
            logger.info("Starting Correlation Sampling")
            corr_gen = CorrelationGenerator(
                generation_config=corr_config,
                data_uri=os.path.join(
                    saving_directory, "correlation_data.pkl"
                ),
            )
            np_samples = corr_gen.sample()
            # Correlation samples come back as numpy; convert to arrow.
            arrow_sample = Postprocessor(
                corr_config.columns
            ).post_process_table(np_samples)
            samples.extend(arrow_sample.flatten())
            fields.extend(
                [
                    arrow_sample.field(name)
                    for name in arrow_sample.column_names
                ]
            )
            logger.info("Finished correlation Sampling")
        return pa.Table.from_arrays(samples, schema=pa.schema(fields))

    def add_links(
        self, samples: t.Dict[t.Tuple[str, ...], pa.Table]
    ) -> t.Dict[t.Tuple[str, ...], pa.Table]:
        """Rewrite foreign-key columns to reference sampled primary keys.

        For each configured link, draws per-primary-key repetition counts
        from the fitted count distribution, normalizes them so their sum
        matches the number of non-null FK values, then replaces the FK
        column with the repeated primary keys. Mutates and returns
        ``samples``.
        """
        if self.config.links is None:
            return samples
        else:
            random_gen = np.random.default_rng(self.config.links.seed)
            for link_info in self.config.links.links_info_list:
                # Keys are stored as paths: the prefix identifies the
                # table, the last element is the column name.
                primary_key_table = samples[link_info.primary_key[:-1]]
                primary_key_col = link_info.primary_key[-1]
                foreign_key_table = samples[link_info.foreign_key[:-1]]
                foreign_key_col = link_info.foreign_key[-1]
                count_distribution = link_info.count_distribution
                primary_key = primary_key_table.column(
                    primary_key_col
                ).combine_chunks()
                # Number of non-null FK values that must be generated.
                length = (
                    foreign_key_table.column(foreign_key_col)
                    .combine_chunks()
                    .is_valid()
                    .sum()
                    .as_py()
                )
                number_fk = count_distribution.values
                probabilities = count_distribution.probabilities
                # create array of counts of size primary_key,
                # 0 repetitions are considered in the distribution
                # computed by statistics
                counts = gen_from_cumulative(
                    probabilities=probabilities,
                    quantile_values=number_fk,
                    size=len(primary_key),
                    random_gen=random_gen,
                ).squeeze()
                # normalize so counts approximately sum to `length`
                counts = np.around((counts / counts.sum() * length)).astype(
                    int
                )
                # now excess should be very small, so
                # what comes next should be very quick
                excess = np.sum(counts) - length
                if excess > 0:
                    # Too many repetitions: decrement random counts.
                    # (Message fixed: this branch removes counts; the
                    # original printed the "Adding" message here.)
                    logger.info("Removing excess counts in FK")
                    while excess > 0:
                        counts, excess = remove_counts(
                            counts, excess, number_fk[0], random_gen
                        )
                if excess < 0:
                    # Too few repetitions: increment random counts.
                    logger.info("Adding missing counts in FK")
                    while excess < 0:
                        counts, excess = add_counts(
                            counts,
                            -excess,
                            number_fk[-1],
                            random_gen=random_gen,
                        )
                new_fks = pa.array(
                    np.repeat(
                        primary_key,
                        repeats=counts,
                    )
                ).cast(primary_key.type)
                samples[link_info.foreign_key[:-1]] = (
                    foreign_key_table.set_column(
                        foreign_key_table.schema.get_field_index(
                            foreign_key_col
                        ),
                        foreign_key_col,
                        new_fks,
                    )
                )
            return samples
def remove_counts(
    counts: np.ndarray,
    excess: int,
    min_val: int,
    random_gen: np.random.Generator,
) -> t.Tuple[np.ndarray, int]:
    """Remove 1 from up to ``excess`` random counts strictly above min_val.

    Args:
        counts: integer repetition counts, modified in place.
        excess: how many unit decrements are still needed (positive).
        min_val: counts must stay >= min_val, so only larger ones shrink.
        random_gen: generator used to pick which candidates shrink.

    Returns:
        The updated counts and the remaining excess (0 when fully
        satisfied). The caller loops while the remainder is positive, so
        a pass with fewer candidates than ``excess`` is simply retried.
    """
    # flatnonzero always yields a 1-D index array. The previous
    # argwhere(...).squeeze() collapsed to a 0-D array when exactly one
    # index matched, making len(idx) raise TypeError.
    idx = np.flatnonzero(counts > min_val)
    n_remove = min(len(idx), excess)
    # One decrement for n_remove randomly-chosen candidates, zero for the
    # rest; shuffling spreads the decrements uniformly over candidates.
    to_remove = np.concatenate(
        [np.ones(n_remove, dtype=int), np.zeros(len(idx) - n_remove, dtype=int)]
    )
    random_gen.shuffle(to_remove)
    counts[idx] = counts[idx] - to_remove
    return counts, int(excess - to_remove.sum())
def add_counts(
    counts: np.ndarray,
    missing: int,
    max_val: int,
    random_gen: np.random.Generator,
) -> t.Tuple[np.ndarray, int]:
    """Add 1 to up to ``missing`` random counts strictly below max_val.

    Args:
        counts: integer repetition counts, modified in place.
        missing: how many unit increments are still needed (positive).
        max_val: counts must stay <= max_val, so only smaller ones grow.
        random_gen: generator used to pick which candidates grow.

    Returns:
        The updated counts and the signed remainder
        (``added - missing``, i.e. 0 when satisfied, negative when more
        increments are still needed). The caller loops while the
        remainder is negative, retrying on the new counts.
    """
    # missing is positive.
    # flatnonzero always yields a 1-D index array; the previous
    # argwhere(...).squeeze() broke len() when exactly one index matched.
    idx = np.flatnonzero(counts < max_val)
    n_add = min(len(idx), missing)
    # One increment for n_add randomly-chosen candidates, zero for the
    # rest; shuffling spreads the increments uniformly over candidates.
    to_add = np.concatenate(
        [np.ones(n_add, dtype=int), np.zeros(len(idx) - n_add, dtype=int)]
    )
    random_gen.shuffle(to_add)
    counts[idx] = counts[idx] + to_add
    return counts, int(to_add.sum() - missing)