Why Gemfury? Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Debian packages RPM packages NuGet packages

Repository URL to install this package:

Details    
Size: Mime:
from __future__ import annotations
import logging
import os
import typing as t
import pyarrow as pa
import pyarrow.parquet as pq
import numpy as np
from sarus_synthetic_data.configs.global_config import (
    SyntheticConfig,
    TableConfig,
)
from sarus_synthetic_data.correlations_generator.generator import (
    CorrelationGenerator,
)
from sarus_synthetic_data.data_processing.postprocessor import Postprocessor
from sarus_synthetic_data.data_processing.preprocessor import Preprocessor
from sarus_synthetic_data.independent_generator.generator import (
    IndependentGenerator,
)
from sarus_synthetic_data.shared.generation_utils import gen_from_cumulative

logger = logging.getLogger(__name__)


class SyntheticDatasetGenerator:
    """End-to-end driver for synthetic data generation.

    For every table declared in ``config.tables`` this class can:

    * ``train``: preprocess the source data, then fit the configured
      independent and/or correlation generators;
    * ``sample``: draw synthetic rows from the trained generators and
      rewrite foreign-key columns so sampled tables stay linked.
    """

    def __init__(self, config: SyntheticConfig):
        # Global configuration: per-table configs, saving directory,
        # privacy-related column names and optional link information.
        self.config = config

    def train(self) -> None:
        """Preprocess every table, then train its generators."""
        # Preprocessing writes the intermediate artifacts (e.g. the
        # correlation pickle) that the generators read, so it runs first.
        Preprocessor(config=self.config).preprocess_tables()

        for table_name, table_config in self.config.tables.items():
            # table_name is a tuple of path components; each table gets
            # its own sub-directory under the global saving directory.
            current_saving_dir = os.path.join(
                self.config.saving_directory, *table_name
            )
            logger.info(f"Starting Training for {table_name}")
            self.train_table(table_config, current_saving_dir)
            logger.info(f"Finished Training for {table_name}")

    def train_table(
        self, table_config: TableConfig, saving_directory: str
    ) -> None:
        """Train the generators configured for a single table.

        Either generator is skipped when its config is ``None``.

        Args:
            table_config: configuration of the table to train on.
            saving_directory: directory holding this table's artifacts.
        """
        ind_config = table_config.independent_generation
        if ind_config is not None:
            logger.info("Starting Independent Training")
            independent_generator = IndependentGenerator(
                generation_config=ind_config,
                privacy_unit_col=self.config.privacy_unit_col,
                weights_col=self.config.weights_col,
                is_private_col=self.config.is_private_col,
                data_uri=table_config.data_uri,
                saving_dir=saving_directory,
            )
            independent_generator.train()
            logger.info("Finished Independent Training")

        corr_config = table_config.correlation_generation
        if corr_config is not None:
            logger.info("Starting Correlation Training")
            # The correlation generator reads the pickle produced by the
            # preprocessing step inside this table's saving directory.
            corr_gen = CorrelationGenerator(
                generation_config=corr_config,
                data_uri=os.path.join(
                    saving_directory, "correlation_data.pkl"
                ),
            )
            corr_gen.train()
            logger.info("Finished Correlation Training")

    def sample(self) -> t.Dict[t.Tuple[str, ...], pa.Table]:
        """Sample every table, then rewrite foreign keys via ``add_links``.

        Returns:
            Mapping from table-name tuple to its sampled Arrow table.
        """
        samples = {}
        for table_name, table_config in self.config.tables.items():
            logger.info(f"Starting Sampling for {table_name}")
            current_saving_dir = os.path.join(
                self.config.saving_directory, *table_name
            )
            curr_sample = self.sample_table(
                table_config, saving_directory=current_saving_dir
            )
            samples[table_name] = curr_sample
            logger.info(f"Finished Sampling for {table_name}")

        return self.add_links(samples=samples)

    def sample_table(
        self, table_config: TableConfig, saving_directory: str
    ) -> pa.Table:
        """Sample one table from its trained generators.

        Public tables are returned verbatim (minus the bookkeeping
        columns). Otherwise, the columns produced by the independent and
        correlation generators are concatenated into a single table.
        """
        if table_config.is_public:
            logger.info("Returning Public Table")
            table = pq.read_table(table_config.data_uri)
            # Bookkeeping columns must never leak into the output.
            return table.drop(
                columns=[
                    self.config.privacy_unit_col,
                    self.config.weights_col,
                    self.config.is_private_col,
                ]
            )

        samples = []
        fields = []
        ind_config = table_config.independent_generation
        if ind_config is not None:
            logger.info("Starting Independent Sampling")
            ind_gen = IndependentGenerator(
                generation_config=ind_config,
                privacy_unit_col=self.config.privacy_unit_col,
                weights_col=self.config.weights_col,
                is_private_col=self.config.is_private_col,
                data_uri=table_config.data_uri,
                saving_dir=saving_directory,
            )
            sample = ind_gen.sample()
            samples.extend(sample.flatten())
            fields.extend([sample.field(name) for name in sample.column_names])
            logger.info("Finished Independent Sampling")

        corr_config = table_config.correlation_generation
        if corr_config is not None:
            logger.info("Starting Correlation Sampling")
            corr_gen = CorrelationGenerator(
                generation_config=corr_config,
                data_uri=os.path.join(
                    saving_directory, "correlation_data.pkl"
                ),
            )
            # The correlation generator yields numpy arrays; the
            # postprocessor converts them back to Arrow columns.
            np_samples = corr_gen.sample()
            arrow_sample = Postprocessor(
                corr_config.columns
            ).post_process_table(np_samples)
            samples.extend(arrow_sample.flatten())
            fields.extend(
                [
                    arrow_sample.field(name)
                    for name in arrow_sample.column_names
                ]
            )
            logger.info("Finished correlation Sampling")

        return pa.Table.from_arrays(samples, schema=pa.schema(fields))

    def add_links(
        self, samples: t.Dict[t.Tuple[str, ...], pa.Table]
    ) -> t.Dict[t.Tuple[str, ...], pa.Table]:
        """Rewrite foreign-key columns to reference sampled primary keys.

        For every configured link, the number of foreign-key rows
        pointing at each primary key is drawn from the link's count
        distribution, then adjusted one unit at a time so the total
        matches the number of non-null foreign keys that were sampled.

        Args:
            samples: sampled tables keyed by table-name tuple; entries
                holding a foreign key are replaced in place.

        Returns:
            The same mapping, with foreign-key columns rewritten.
        """
        if self.config.links is None:
            return samples

        random_gen = np.random.default_rng(self.config.links.seed)
        for link_info in self.config.links.links_info_list:
            # The last element of a key path is the column name; the
            # prefix identifies the table inside ``samples``.
            primary_key_table = samples[link_info.primary_key[:-1]]
            primary_key_col = link_info.primary_key[-1]
            foreign_key_table = samples[link_info.foreign_key[:-1]]
            foreign_key_col = link_info.foreign_key[-1]
            count_distribution = link_info.count_distribution

            primary_key = primary_key_table.column(
                primary_key_col
            ).combine_chunks()

            # Number of non-null foreign keys to distribute among the
            # primary keys.
            length = (
                foreign_key_table.column(foreign_key_col)
                .combine_chunks()
                .is_valid()
                .sum()
                .as_py()
            )
            number_fk = count_distribution.values
            probabilities = count_distribution.probabilities

            # create array of counts of size primary_key,
            # 0 repetitions are considered in the distribution
            # computed by statistics
            counts = gen_from_cumulative(
                probabilities=probabilities,
                quantile_values=number_fk,
                size=len(primary_key),
                random_gen=random_gen,
            ).squeeze()

            # Rescale so the counts sum approximately to `length`.
            counts = np.around((counts / counts.sum() * length)).astype(
                int
            )

            # Rounding leaves a small residual, so the loops below
            # should terminate quickly.
            excess = np.sum(counts) - length
            if excess > 0:
                # Too many foreign keys: remove units. (Fix: this
                # message and the one below were previously swapped,
                # and were emitted via print() instead of the logger.)
                logger.info("Removing excess counts in FK")
                while excess > 0:
                    counts, excess = remove_counts(
                        counts, excess, number_fk[0], random_gen
                    )

            if excess < 0:
                # Too few foreign keys: add units.
                logger.info("Adding missing counts in FK")
                while excess < 0:
                    counts, excess = add_counts(
                        counts,
                        -excess,
                        number_fk[-1],
                        random_gen=random_gen,
                    )

            # Repeat each primary key by its count to build the new
            # foreign-key column, preserving the original key type.
            new_fks = pa.array(
                np.repeat(
                    primary_key,
                    repeats=counts,
                )
            ).cast(primary_key.type)
            samples[link_info.foreign_key[:-1]] = (
                foreign_key_table.set_column(
                    foreign_key_table.schema.get_field_index(
                        foreign_key_col
                    ),
                    foreign_key_col,
                    new_fks,
                )
            )
        return samples


def remove_counts(
    counts: np.ndarray,
    excess: int,
    min_val: int,
    random_gen: np.random.Generator,
) -> t.Tuple[np.ndarray, int]:
    """Remove 1 from up to ``excess`` randomly chosen counts above ``min_val``.

    Candidates are the entries strictly greater than ``min_val``; at most
    one unit is removed from each candidate per call, so the caller loops
    on the returned remainder until it reaches 0.

    Args:
        counts: integer array of repetition counts; mutated in place.
        excess: positive number of units still to remove.
        min_val: counts are never pushed below this value.
        random_gen: randomness source used to pick which candidates lose
            a unit.

    Returns:
        The (mutated) ``counts`` array and the remaining excess.
    """
    # Fix: np.flatnonzero always yields a 1-D index array, whereas the
    # previous argwhere(...).squeeze() collapsed to a 0-d scalar when
    # exactly one candidate matched, making len(idx) raise TypeError.
    idx = np.flatnonzero(counts > min_val)
    # One "1" per unit to remove, zero-padded up to the number of
    # candidates, then shuffled so removals are spread uniformly.
    # NOTE(review): if no candidate exists, excess is returned unchanged
    # and the caller's loop would not progress — presumably prevented by
    # the count distribution; confirm upstream.
    to_remove = np.concatenate(
        [np.ones(min(len(idx), excess)), np.zeros(max(len(idx) - excess, 0))]
    )
    random_gen.shuffle(to_remove)
    counts[idx] = counts[idx] - to_remove
    return counts, int(excess - to_remove.sum())


def add_counts(
    counts: np.ndarray,
    missing: int,
    max_val: int,
    random_gen: np.random.Generator,
) -> t.Tuple[np.ndarray, int]:
    """Add 1 to up to ``missing`` randomly chosen counts below ``max_val``.

    Candidates are the entries strictly smaller than ``max_val``; at most
    one unit is added to each candidate per call, so the caller loops on
    the returned value until it reaches 0.

    Args:
        counts: integer array of repetition counts; mutated in place.
        missing: positive number of units still to add.
        max_val: counts are never pushed above this value.
        random_gen: randomness source used to pick which candidates gain
            a unit.

    Returns:
        The (mutated) ``counts`` array and the new excess,
        ``added - missing`` (<= 0; the caller's ``while excess < 0`` loop
        keeps running while units remain to add).
    """
    # missing is positive.
    # Fix: np.flatnonzero always yields a 1-D index array, whereas the
    # previous argwhere(...).squeeze() collapsed to a 0-d scalar when
    # exactly one candidate matched, making len(idx) raise TypeError.
    idx = np.flatnonzero(counts < max_val)
    # One "1" per unit to add, zero-padded up to the number of
    # candidates, then shuffled so additions are spread uniformly.
    to_add = np.concatenate(
        [np.ones(min(len(idx), missing)), np.zeros(max(len(idx) - missing, 0))]
    )
    random_gen.shuffle(to_add)

    counts[idx] = counts[idx] + to_add
    return counts, int(to_add.sum() - missing)