Why Gemfury? Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Debian packages RPM packages NuGet packages

Repository URL to install this package:

Details    
Size: Mime:
import functools
import logging
import os
import pickle
import typing as t

import numpy as np
import pyarrow as pa
import pyarrow.compute as pc
import pyarrow.parquet as pq
from pandas.api.types import CategoricalDtype

import sarus_synthetic_data.configs.typing as st
from sarus_synthetic_data.configs.global_config import (
    NonOptionalCorrelationColumn,
    OptionalCorrelationColumn,
    TextCorrelationCol,
)
from sarus_synthetic_data.constants import (
    IS_NOT_NULL,
    OPTIONAL_VALUE,
)
from sarus_synthetic_data.data_processing.typing import (
    PreprocessingConfig,
    TableConfig,
)

from sarus_synthetic_data.configs.typing import (
    DistributionKind,
    TypeKind,
)

logger = logging.getLogger(__name__)


class Preprocessor:
    """Preprocess input data so that it can be fed to either the Independent
    or the Correlation generator. In practice, pre-processing only happens
    for the columns that are generated with correlation."""

    def __init__(self, config: PreprocessingConfig) -> None:
        self.config = config

    def preprocess_tables(self) -> None:
        """Preprocess every table of the configuration.

        Each table gets its own sub-directory of
        ``config.saving_directory``, built from the components of the
        table name; the directory is created when missing.
        """
        for table_name, table_config in self.config.tables.items():
            current_saving_dir = os.path.join(
                self.config.saving_directory, *table_name
            )
            # The trailing space matters: both literals are concatenated.
            logger.info(
                f"Preprocessing table {table_name} in "
                f"directory {current_saving_dir}"
            )
            os.makedirs(current_saving_dir, exist_ok=True)
            self.preprocess_table(
                saving_dir=current_saving_dir,
                table_config=table_config,
                privacy_unit_col=self.config.privacy_unit_col,
            )
            logger.info(f"Done Preprocessing table {table_name}")

    def preprocess_table(
        self, saving_dir: str, table_config: TableConfig, privacy_unit_col: str
    ) -> None:
        """Pre-process the data of a table if some of its columns are
        generated with correlation. It builds:
        - a PyTree of data where leaves are numpy arrays
        - a list of group indices that belong to the same protected entity
        Both are pickled together, as ``(groups, struct)``, into
        ``correlation_data.pkl`` inside ``saving_dir``.

        Nothing is read or written when the table has no correlation
        configuration.
        """
        correlation_config = table_config.correlation_generation
        if correlation_config is None:
            return

        data = pq.read_table(table_config.data_uri)
        # Preprocess correlation columns only
        data_correlation = data.select(list(correlation_config.columns.keys()))

        # Step 1: transform correlation_data in numpy list
        struct: t.List[t.Any] = []
        for col_name, col_config in correlation_config.columns.items():
            column = data_correlation.column(col_name).combine_chunks()
            if isinstance(col_config, OptionalCorrelationColumn):
                leaf_config = col_config.child_col
                transcoded = transcode_optional(column, col_config=col_config)
            else:
                leaf_config = col_config
                transcoded = transcode_type(column, col_config=col_config)

            if self._is_split_into_components(leaf_config):
                # One entry per component (year, month, ...), kept in the
                # order produced by the transcoder.
                struct.extend(transcoded)
            else:
                struct.append(transcoded)

        # Compute groups: row positions gathered per privacy unit value
        data = data.append_column(
            "sarus_index", pa.array(np.arange(len(data)))
        )
        groups = (
            data.group_by(privacy_unit_col)
            .aggregate([("sarus_index", "list")])["sarus_index_list"]
            .to_pylist()
        )
        # Save
        file_path = os.path.join(saving_dir, "correlation_data.pkl")
        with open(file_path, "wb") as file:
            pickle.dump((groups, struct), file)

    @staticmethod
    def _is_split_into_components(leaf_config: t.Any) -> bool:
        """Tell whether transcoding this column yields several arrays.

        This is the case for Date, Datetime and Time columns described by
        quantiles, which are transcoded component by component.
        """
        return (
            leaf_config.col_type
            in (TypeKind.Date, TypeKind.Datetime, TypeKind.Time)
            and leaf_config.distribution_kind == DistributionKind.quantiles
        )


def transcode_type(
    initial_array: pa.Array, col_config: NonOptionalCorrelationColumn
) -> t.Any:
    """Transcode a PyArrow array with the preprocessor matching its column.

    The preprocessor class is picked from the column's distribution kind
    and type kind, instantiated with the distribution values (plus the
    tokenizer max length for text correlation columns) and applied to
    ``initial_array``. The shape of the result depends on the selected
    preprocessor.

    A ``KeyError`` is raised when no preprocessor is registered for the
    (distribution kind, type kind) pair.
    """
    kind = col_config.col_type
    distribution_kind = col_config.distribution_kind
    init_kwargs: t.Dict[str, t.Any] = {
        "distrib_values": col_config.distribution.values
    }

    type_kinds = st.TypeKind
    histogram_classes = {
        type_kinds.Text: StrHistogramPreprocessor,
        type_kinds.Float: Float64HistogramPreprocessor,
        type_kinds.Time: Int64HistogramPreprocessor,
        type_kinds.Integer: Int64HistogramPreprocessor,
        type_kinds.Date: Int32HistogramPreprocessor,
        type_kinds.Datetime: Int64HistogramPreprocessor,
        type_kinds.Duration: Int64HistogramPreprocessor,
        type_kinds.Boolean: BooleanHistogramPreprocessor,
    }
    quantile_classes = {
        type_kinds.Integer: QuantilePreprocessor[int],
        type_kinds.Float: QuantilePreprocessor[float],
        type_kinds.Date: DateQuantilePreprocessor,
        type_kinds.Datetime: DatetimeQuantilePreprocessor,
        type_kinds.Time: TimeQuantilePreprocessor,
        type_kinds.Duration: DurationQuantilePreprocessor,
        type_kinds.Text: TextQuantilePreprocessor,
    }
    classes_by_distribution = {
        st.DistributionKind.histogram: histogram_classes,
        st.DistributionKind.quantiles: quantile_classes,
    }
    preprocessor_class = classes_by_distribution[distribution_kind][kind]

    # Text correlation columns carry one extra constructor argument.
    if isinstance(col_config, TextCorrelationCol):
        init_kwargs["tokenizer_max_length"] = col_config.tokenizer_max_length

    return preprocessor_class(**init_kwargs).transcode(initial_array)


def transcode_optional(
    initial_array: pa.Array, col_config: OptionalCorrelationColumn
) -> t.Union[t.Any, t.List[t.Any]]:
    """Transcode an optional column into validity mask(s) and value(s).

    Null entries (NaN counts as null) are replaced by the child column's
    ``example`` so that the child transcoder only sees valid values; the
    positions that were really present are reported through an integer
    mask (1 for a value, 0 for a null).

    Returns a dict ``{IS_NOT_NULL: mask, OPTIONAL_VALUE: values}``. For
    Date, Datetime and Time columns described by quantiles, the child
    transcoder yields one array per component, and a list with one such
    dict per component (all sharing the same mask) is returned instead.
    """
    child_config = col_config.child_col
    nulls = initial_array.is_null(nan_is_null=True)
    valid_mask = pc.invert(nulls).cast(pa.int64()).to_numpy()

    # Summing an empty boolean array gives a null scalar, hence the `or 0`.
    null_count = nulls.sum().as_py() or 0
    if null_count:
        padded_array = pc.replace_with_mask(
            initial_array,
            nulls,
            pa.concat_arrays([child_config.example] * null_count),
        )
    else:
        # Nothing to pad, and pa.concat_arrays needs at least one array,
        # so the replacement step is skipped altogether.
        padded_array = initial_array

    transcoded = transcode_type(
        initial_array=padded_array, col_config=child_config
    )

    has_components = (
        child_config.col_type
        in (TypeKind.Datetime, TypeKind.Date, TypeKind.Time)
        and child_config.distribution_kind == DistributionKind.quantiles
    )
    if has_components:
        return [
            {IS_NOT_NULL: valid_mask, OPTIONAL_VALUE: component}
            for component in transcoded
        ]
    return {IS_NOT_NULL: valid_mask, OPTIONAL_VALUE: transcoded}


class HasCast(t.Protocol):
    """Structural type of objects exposing ``values`` and ``cast_array``."""

    values: t.List

    def cast_array(self, initial_array: pa.Array) -> pa.Array:
        """Return the Arrow array obtained by casting ``initial_array``."""


QuantType = t.TypeVar("QuantType", int, float)


class QuantilePreprocessor(t.Generic[QuantType]):
    """Encode numeric values as their position among a list of quantiles."""

    def __init__(self, distrib_values: t.List[QuantType]) -> None:
        self.values: t.List[QuantType] = distrib_values

    def transcode(
        self, initial_array: pa.Array
    ) -> np.ndarray[t.Any, np.dtype[np.int_]]:
        """Return, for each value, its left insertion index in the quantiles.

        When the two first quantiles are equal, the first one is dropped
        before searching. Values equal to a quantile that is still repeated
        afterwards get their index shifted by one.

        Quantile lists with fewer than two entries are searched as they are.
        """
        quantiles = self.values
        if len(quantiles) > 1 and quantiles[0] == quantiles[1]:
            quantiles = quantiles[1:]

        # Convert once, so that the comparisons below are plain elementwise
        # NumPy operations rather than depending on how the Arrow array
        # reacts to ``==``.
        data = np.asarray(initial_array)
        indices = np.searchsorted(quantiles, data, side="left")

        unique_val, freq = np.unique(quantiles, return_counts=True)
        dirac_values = unique_val[freq > 1]
        if dirac_values.size:
            # The repeated values are distinct from one another, so each
            # element is shifted at most once.
            indices = np.where(
                np.isin(data, dirac_values), indices + 1, indices
            )
        return t.cast(np.ndarray[t.Any, np.dtype[np.int_]], indices)


class DurationQuantilePreprocessor(QuantilePreprocessor[int]):
    """Quantile preprocessor for durations, handled as 64-bit integers."""

    def transcode(
        self, initial_array: pa.Array
    ) -> np.ndarray[t.Any, np.dtype[np.int_]]:
        """Cast the array to int64, then apply the parent transcoding."""
        as_integers = initial_array.cast(pa.int64())
        return super().transcode(as_integers)


class DateQuantilePreprocessor:
    """Split dates into year, month and day components."""

    def __init__(self, distrib_values: t.List[int]) -> None:
        self.values = distrib_values

    def transcode(
        self, initial_array: pa.Array
    ) -> t.Tuple[np.ndarray[t.Any, np.dtype[np.int_]], ...]:
        """Return ``(year, month, day)`` as numpy arrays.

        Years are counted from the year of the first distribution value,
        read as a date32; months and days are shifted to start at zero.
        """
        compute = pa.compute
        first_date = pa.scalar(np.int32(self.values[0]), pa.date32())
        min_year = compute.year(first_date)
        components = (
            compute.subtract(compute.year(initial_array), min_year),
            compute.subtract(compute.month(initial_array), 1),
            compute.subtract(compute.day(initial_array), 1),
        )
        return tuple(component.to_numpy() for component in components)


class DatetimeQuantilePreprocessor:
    """Split datetimes into year, month, day, hour, minute and second."""

    def __init__(self, distrib_values: t.List[int]) -> None:
        self.values = distrib_values

    def transcode(
        self, initial_array: pa.Array
    ) -> t.Tuple[np.ndarray[t.Any, np.dtype[np.int_]], ...]:
        """Return ``(year, month, day, hour, minute, second)`` numpy arrays.

        Years are counted from the year of the first distribution value,
        read as a nanosecond timestamp; months and days are shifted to
        start at zero, the time components are left as extracted.
        """
        compute = pa.compute
        first_timestamp = pa.scalar(self.values[0], pa.timestamp("ns"))
        min_year = compute.year(first_timestamp)
        components = (
            compute.subtract(compute.year(initial_array), min_year),
            compute.subtract(compute.month(initial_array), 1),
            compute.subtract(compute.day(initial_array), 1),
            compute.hour(initial_array),
            compute.minute(initial_array),
            compute.second(initial_array),
        )
        return tuple(component.to_numpy() for component in components)


class TimeQuantilePreprocessor:
    """Split times into hour, minute and second components."""

    def __init__(self, distrib_values: t.List[int]) -> None:
        self.values = distrib_values

    def transcode(
        self, initial_array: pa.Array
    ) -> t.Tuple[np.ndarray[t.Any, np.dtype[np.int_]], ...]:
        """Return ``(hour, minute, second)`` as numpy arrays."""
        compute = pa.compute
        components = (
            compute.hour(initial_array),
            compute.minute(initial_array),
            compute.second(initial_array),
        )
        return tuple(component.to_numpy() for component in components)


class TextQuantilePreprocessor:
    """Tokenize text into fixed-length id, position and mask arrays."""

    def __init__(self, distrib_values: t.List[int], tokenizer_max_length: int):
        self.values = distrib_values
        self.tokenizer_max_length = tokenizer_max_length

    def transcode(
        self, initial_array: pa.Array
    ) -> t.Dict[str, np.ndarray[t.Any, np.dtype[np.int_]]]:
        """Tokenize the texts and return the arrays keyed by name.

        The returned dict holds ``input_ids``, ``position_ids`` and
        ``attention_mask``. Every row is padded or truncated to the same
        length.
        """
        texts = initial_array.to_numpy(zero_copy_only=False).tolist()
        # The sequence length is capped both by the configured tokenizer
        # limit and by the largest distribution value.
        max_length = min(self.tokenizer_max_length, int(max(self.values)))
        encoded = TOKENIZER(
            texts,
            padding="max_length",
            truncation=True,
            max_length=max_length,
            return_tensors="np",
        )
        token_ids = encoded.input_ids
        attention = encoded.attention_mask

        # Same 0..max_length-1 positions repeated for every row.
        position_ids = np.broadcast_to(
            np.arange(max_length)[None, :],
            (len(token_ids), max_length),
        )

        # Shift the tokenizer mask one step to the right: a column of ones
        # is prepended and the last column is dropped.
        leading_ones = np.ones(shape=(attention.shape[0], 1), dtype=np.int64)
        shifted_mask = np.concatenate([leading_ones, attention], axis=1)[
            :, :-1
        ]
        return {
            "input_ids": token_ids,
            "position_ids": position_ids,
            "attention_mask": shifted_mask,
        }


HistType = t.TypeVar("HistType", int, float, str)


class HistogramPreprocessor(t.Generic[HistType]):
    """Encode values as their index among the histogram categories.

    Concrete classes must provide ``cast_array``.
    """

    def __init__(
        self,
        distrib_values: t.List[HistType],
    ) -> None:
        self.values: t.List[HistType] = distrib_values

    def transcode(
        self: HasCast, initial_array: pa.Array
    ) -> np.ndarray[t.Any, np.dtype[np.int_]]:
        """Return the category index of each element of ``initial_array``.

        The array is cast with ``cast_array``, turned into an ordered
        pandas categorical whose categories are ``self.values``, and read
        back as an Arrow dictionary array whose indices are returned.
        """
        typed_array = self.cast_array(initial_array)
        category_dtype = CategoricalDtype(categories=self.values, ordered=True)

        as_categorical = typed_array.to_pandas(
            self_destruct=True, split_blocks=False
        ).astype(category_dtype)
        dictionary_type = pa.dictionary(
            index_type=pa.int64(),
            value_type=typed_array.type,
            ordered=True,
        )
        encoded = pa.DictionaryArray.from_pandas(
            as_categorical, type=dictionary_type
        )
        codes = encoded.indices.to_numpy(zero_copy_only=False)
        return t.cast(np.ndarray[t.Any, np.dtype[np.int_]], codes)


class Int64CasterMixin:
    """Mixin whose ``cast_array`` casts to 64-bit integers."""

    def cast_array(self, initial_array: pa.Array) -> pa.Array:
        """Return ``initial_array`` cast to int64."""
        target_type = pa.int64()
        return initial_array.cast(target_type)


class Int32CasterMixin:
    """Mixin whose ``cast_array`` casts to 32-bit integers."""

    def cast_array(self, initial_array: pa.Array) -> pa.Array:
        """Return ``initial_array`` cast to int32."""
        target_type = pa.int32()
        return initial_array.cast(target_type)


class Int64HistogramPreprocessor(Int64CasterMixin, HistogramPreprocessor[int]):
    """Histogram preprocessor working on values cast to int64."""


class Int32HistogramPreprocessor(Int32CasterMixin, HistogramPreprocessor[int]):
    """Histogram preprocessor working on values cast to int32."""


class BooleanHistogramPreprocessor(Int64CasterMixin):
    """Encode booleans directly as int64 values, without a category lookup."""

    def __init__(self, distrib_values: t.List[int]) -> None:
        self.values = distrib_values

    def transcode(
        self, initial_array: pa.Array
    ) -> np.ndarray[t.Any, np.dtype[np.int_]]:
        """Return ``initial_array`` cast to int64, as a numpy array."""
        as_integers = self.cast_array(initial_array)
        encoded = as_integers.to_numpy()
        return t.cast(np.ndarray[t.Any, np.dtype[np.int_]], encoded)


class Float64HistogramPreprocessor(HistogramPreprocessor[float]):
    """Histogram preprocessor for floats; values are used without a cast."""

    def cast_array(self, initial_array: pa.Array) -> pa.Array:
        """Return ``initial_array`` unchanged."""
        unchanged = initial_array
        return unchanged


class StrHistogramPreprocessor(HistogramPreprocessor[str]):
    """Histogram preprocessor for strings; values are used without a cast."""

    def cast_array(self, initial_array: pa.Array) -> pa.Array:
        """Return ``initial_array`` unchanged."""
        unchanged = initial_array
        return unchanged


class _LazyTokenizer:
    """A lazily loaded Tokenizer.
    Does not load any data if never called.

    Calls and attribute accesses are forwarded to the underlying
    tokenizer, which is loaded from ``path`` the first time it is needed.
    """

    def __init__(self, path: str):
        super().__init__()
        self.path = path

    def __call__(self, *args: t.Any, **kwargs: t.Any) -> t.Any:
        """Call the underlying tokenizer, loading it if necessary."""
        return self._tokenizer(*args, **kwargs)

    def __getattr__(self, name: str) -> t.Any:
        """Forward the lookup of a missing attribute to the tokenizer.

        Only invoked when normal lookup fails. Dunder names and the
        proxy's own attributes are never forwarded: they are probed by
        ``copy``, ``pickle`` or ``hasattr`` checks, and forwarding them
        would load the tokenizer as a side effect, or recurse endlessly on
        an instance whose ``path`` is not set yet.
        """
        is_dunder = name.startswith("__") and name.endswith("__")
        if is_dunder or name in ("path", "_tokenizer"):
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}"
            )
        return getattr(self._tokenizer, name)

    @functools.cached_property
    def _tokenizer(self) -> t.Any:
        """Loads and cache the tokenizer on its first call"""
        # Imported here so that the dependency is only needed when a
        # tokenizer is actually used.
        from transformers import AutoTokenizer

        tokenizer = AutoTokenizer.from_pretrained(self.path)
        # Padding reuses the end-of-sequence token.
        tokenizer.pad_token = tokenizer.eos_token
        return tokenizer


# Shared tokenizer of the module; nothing is loaded until it is first used.
TOKENIZER = _LazyTokenizer("EleutherAI/gpt-neo-125M")