Why Gemfury? Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Debian packages RPM packages NuGet packages

Repository URL to install this package:

Details    
Size: Mime:
import logging
import typing as t

import numpy as np
import pandas as pd
import pyarrow as pa
import pyarrow.compute as pc
import sarus_synthetic_data.configs.typing as st
from sarus_synthetic_data.configs.global_config import (
    CorrelationColumn,
    NonOptionalCorrelationColumn,
    OptionalCorrelationColumn,
)
from sarus_synthetic_data.constants import (
    IS_NOT_NULL,
    OPTIONAL_VALUE,
)
from sarus_synthetic_data.configs.typing import (
    DistributionKind,
    TypeKind,
)
from sarus_synthetic_data.data_processing.preprocessor import TOKENIZER
from sarus_synthetic_data.data_processing.typing import OptionalOutput

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


class Postprocessor:
    """Rebuild a pyarrow Table from the output of the correlation generator.

    This undoes the correlation pre-processing: every column described in
    ``column_config`` is decoded back into a pyarrow array typed like the
    column's ``example``.
    """

    def __init__(
        self, column_config: t.Mapping[str, CorrelationColumn]
    ) -> None:
        # Column name -> configuration. The iteration order of this mapping
        # decides how the flat generator output is split into columns.
        self.column_config = column_config

    @staticmethod
    def _n_encoded_columns(leaf_config: NonOptionalCorrelationColumn) -> int:
        """Return how many consecutive generator entries encode one column."""
        if leaf_config.distribution_kind == DistributionKind.quantiles:
            # Quantile-encoded temporal columns are split into components:
            # Date and Time into 3 entries, Datetime into 6.
            if leaf_config.col_type in (TypeKind.Date, TypeKind.Time):
                return 3
            if leaf_config.col_type == TypeKind.Datetime:
                return 6
        return 1

    def post_process_table(self, table_data: t.List[t.Any]) -> pa.Table:
        """Decode the generator output into a pyarrow Table.

        Args:
            table_data: flat list of encoded sub-columns, consumed in the
                iteration order of ``column_config``. A quantile-encoded
                Date or Time column spans 3 consecutive entries, a
                quantile-encoded Datetime column spans 6 and every other
                column spans 1. Optional columns follow the same rule based
                on their ``child_col`` configuration.

        Returns:
            A table with one column per entry of ``column_config``, typed
            with the (child) configuration's ``example.type``. Optional
            columns are nullable, the others are not.
        """
        arrays = []
        fields = []
        col_index = 0
        for col_name, col_config in self.column_config.items():
            is_optional = isinstance(col_config, OptionalCorrelationColumn)
            leaf_config = col_config.child_col if is_optional else col_config
            width = self._n_encoded_columns(leaf_config)

            # Multi-entry encodings are handed over as a list of components,
            # single-entry encodings as the bare entry.
            if width > 1:
                col_data = table_data[col_index : col_index + width]
            else:
                col_data = table_data[col_index]

            if is_optional:
                decoded = inverse_transcode_optional(
                    col_data, col_config=col_config
                )
            else:
                decoded = inverse_transcode_type(
                    col_data, col_config=col_config
                )
            arrays.append(decoded)
            fields.append(
                pa.field(
                    name=col_name,
                    type=leaf_config.example.type,
                    nullable=is_optional,
                )
            )
            col_index += width

        return pa.Table.from_arrays(arrays=arrays, schema=pa.schema(fields))


def inverse_transcode_type(
    data_input: pa.Array, col_config: NonOptionalCorrelationColumn
) -> pa.Array:
    """Decode the generator output of one non-optional column.

    The post-processor is selected from the column's distribution kind
    (histogram or quantiles) and type kind. It is built from the column's
    ``distribution.values``.

    Args:
        data_input: encoded column data. For quantile-encoded Date, Time and
            Datetime columns this is a list of component arrays; otherwise a
            single array.
        col_config: configuration of the column to decode.

    Returns:
        The decoded pyarrow array.

    Raises:
        KeyError: no post-processor is registered for the column's
            (distribution kind, type kind) pair, e.g. a quantile-encoded
            Boolean column.
    """
    kind = col_config.col_type
    distribution_kind = col_config.distribution_kind
    distrib_values = col_config.distribution.values

    post_processor_classes = {
        st.DistributionKind.histogram: {
            st.TypeKind.Text: StrHistogramPostProcessor,
            st.TypeKind.Float: HistogramPostProcessor[float],
            st.TypeKind.Time: TimeHistogramPostProcessor,
            st.TypeKind.Integer: HistogramPostProcessor[int],
            st.TypeKind.Date: DateHistogramPostProcessor,
            st.TypeKind.Datetime: DatetimeHistogramPostProcessor,
            st.TypeKind.Duration: DurationHistogramPostProcessor,
            st.TypeKind.Boolean: BooleanHistogramPostProcessor,
        },
        st.DistributionKind.quantiles: {
            st.TypeKind.Integer: QuantilePostProcessor[int],
            st.TypeKind.Float: QuantilePostProcessor[float],
            st.TypeKind.Date: DateQuantilePostProcessor,
            st.TypeKind.Datetime: DatetimeQuantilePostProcessor,
            st.TypeKind.Time: TimeQuantilePostProcessor,
            st.TypeKind.Duration: DurationQuantilePostProcessor,
            st.TypeKind.Text: TextQuantilePostProcessor,
        },
    }

    try:
        post_processor_class = post_processor_classes[distribution_kind][kind]
    except KeyError:
        # Keep the KeyError type callers may rely on, but say what is missing.
        raise KeyError(
            f"No post-processor for type {kind!r} "
            f"with distribution {distribution_kind!r}"
        ) from None

    postprocessor = post_processor_class(distrib_values=distrib_values)
    return postprocessor.inverse_transcode(data_input)


def inverse_transcode_optional(
    data_input: t.Union[OptionalOutput, t.List[OptionalOutput]],
    col_config: OptionalCorrelationColumn,
) -> pa.Array:
    """Decode an optional column and blank out its null rows.

    The encoded values under ``OPTIONAL_VALUE`` are decoded with the child
    configuration. A row is kept only where the ``IS_NOT_NULL`` indicator is
    truthy; when several components are given, it must be truthy in all of
    them. Every other row becomes null.
    """
    if not isinstance(data_input, list):
        valid_mask = data_input[IS_NOT_NULL].astype(bool)  # type: ignore
        payload = data_input[OPTIONAL_VALUE]  # type: ignore
    else:
        masks = [part[IS_NOT_NULL].astype(bool) for part in data_input]  # type: ignore
        valid_mask = masks[0]
        for other in masks[1:]:
            valid_mask = valid_mask & other
        payload = [part[OPTIONAL_VALUE] for part in data_input]  # type: ignore

    child_array = inverse_transcode_type(
        payload, col_config=col_config.child_col
    )
    null_fill = pa.nulls(len(child_array), type=child_array.type)
    return pc.if_else(valid_mask, child_array, null_fill)


QuantType = t.TypeVar("QuantType", int, float)


class QuantilePostProcessor(t.Generic[QuantType]):
    """Decode quantile-bucket indices back into numeric values.

    ``distrib_values`` holds the quantile boundaries of the column. Index 0
    maps to the first boundary. An index ``k > 0`` maps to a uniform sample
    in ``[values[k - 1], values[k])``, cast to the type of the first
    boundary.
    """

    def __init__(self, distrib_values: t.List[QuantType]) -> None:
        self.values: t.List[QuantType] = distrib_values

    def inverse_transcode(self, data_input: pa.Array) -> pa.Array:
        """Sample one value per bucket index.

        Args:
            data_input: integer bucket indices. A numpy array or anything
                ``np.asarray`` accepts, including a pyarrow array.

        Returns:
            A pyarrow array of decoded values.
        """
        values = self.values
        # A repeated first boundary means a degenerate lowest bucket: drop it
        # so the indices line up with distinct boundaries.
        # NOTE(review): assumes the preprocessor applies the same shift.
        if len(values) > 1 and values[0] == values[1]:
            values = values[1:]

        boundaries = np.asarray(values)
        indices = np.asarray(data_input)
        out = np.full(fill_value=values[0], shape=len(indices))

        # Rows with index 0 keep the first boundary; the others are sampled
        # between the two boundaries enclosing their bucket.
        above_first = indices > 0
        buckets = indices[above_first]
        out[above_first] = np.random.uniform(
            low=boundaries[buckets - 1], high=boundaries[buckets]
        )
        return pa.array(out.astype(type(values[0])))


class DurationQuantilePostProcessor(QuantilePostProcessor[int]):
    """Quantile post-processor whose decoded integers are microsecond
    durations."""

    def inverse_transcode(self, data_input: pa.Array) -> pa.Array:
        """Decode bucket indices, then reinterpret them as ``duration[us]``."""
        microseconds = super().inverse_transcode(data_input).to_numpy()
        deltas = pd.to_timedelta(microseconds, "us")
        return pa.array(deltas, type=pa.duration("us"))


class DateQuantilePostProcessor:
    """Decode quantile-encoded date columns from (year, month, day) parts.

    ``distrib_values`` are date32 values (days since the Unix epoch). The
    first one fixes the year offset, and the first and last ones bound the
    decoded dates.
    """

    def __init__(self, distrib_values: t.List[int]) -> None:
        self.values = distrib_values

    def inverse_transcode(self, data_input: pa.Array) -> pa.Array:
        """Assemble dates from three component arrays.

        Args:
            data_input: sequence of three arrays holding, in order, the year
                offset from the minimum year, the 0-based month and the
                0-based day of month.

        Returns:
            A ``date32`` array clipped to ``[values[0], values[-1]]``.
        """
        min_year = pa.compute.year(
            pa.scalar(np.int32(self.values[0]), pa.date32())
        ).as_py()
        year = data_input[0] + min_year
        month = data_input[1] + 1
        day = data_input[2] + 1

        # The generated day may not exist in the generated month (31 April,
        # 30 February, 29 February of a common year). Clamp it to the real
        # month length, which still allows 29 February in leap years.
        days_in_month = pd.to_datetime(
            {"year": year, "month": month, "day": np.ones_like(month)}
        ).dt.days_in_month.to_numpy()
        day = np.minimum(day, days_in_month)

        np_values = pd.to_datetime(
            {
                "year": year,
                "month": month,
                "day": day,
            }
        ).values.astype("datetime64[D]")

        # clip to avoid going further from min/max,
        # can happen due to dpsgd
        np_values = np.clip(
            np_values.astype(np.int32),
            a_min=self.values[0],
            a_max=self.values[-1],
        )
        return pa.array(
            np_values,
            pa.date32(),
        )


class DatetimeQuantilePostProcessor:
    """Decode quantile-encoded datetime columns from their components.

    ``distrib_values`` are nanosecond timestamps (int64 since the Unix
    epoch). The first one fixes the year offset, and the first and last ones
    bound the decoded values.
    """

    def __init__(self, distrib_values: t.List[int]) -> None:
        self.values = distrib_values

    def inverse_transcode(self, data_input: pa.Array) -> pa.Array:
        """Assemble timestamps from six component arrays.

        Args:
            data_input: sequence of six arrays holding, in order, the year
                offset from the minimum year, the 0-based month, the 0-based
                day of month, the hour, the minutes and the seconds.

        Returns:
            A ``timestamp[ns]`` array clipped to ``[values[0], values[-1]]``.
        """
        min_year = pa.compute.year(
            pa.scalar(self.values[0], pa.timestamp("ns"))
        ).as_py()
        year = data_input[0] + min_year
        month = data_input[1] + 1
        day = data_input[2] + 1
        hour = data_input[3]
        minutes = data_input[4]
        seconds = data_input[5]

        # The generated day may not exist in the generated month (31 April,
        # 30 February, 29 February of a common year). Clamp it to the real
        # month length, which still allows 29 February in leap years.
        days_in_month = pd.to_datetime(
            {"year": year, "month": month, "day": np.ones_like(month)}
        ).dt.days_in_month.to_numpy()
        day = np.minimum(day, days_in_month)

        np_values = pd.to_datetime(
            {
                "year": year,
                "month": month,
                "day": day,
                "hour": hour,
                "minutes": minutes,
                "seconds": seconds,
            }
        ).values
        # clip to avoid going further from min/max,
        # can happen due to dpsgd
        np_values = np.clip(
            np_values.astype(np.int64),
            a_min=self.values[0],
            a_max=self.values[-1],
        )
        return pa.array(
            np_values,
            pa.timestamp(unit="ns"),
        )


class TimeQuantilePostProcessor:
    """Decode quantile-encoded time-of-day columns from (hour, minutes,
    seconds) parts.

    ``distrib_values`` bound the decoded times. They are scaled by 1000
    before being compared with nanoseconds since midnight, so they appear to
    be microseconds since midnight.
    """

    def __init__(self, distrib_values: t.List[int]) -> None:
        self.values = distrib_values

    def inverse_transcode(self, data_input: pa.Array) -> pa.Array:
        """Assemble times of day and return them as ``time64[us]``."""
        hour = data_input[0]
        minutes = data_input[1]
        seconds = data_input[2]

        # pandas can only assemble full datetimes, so anchor every row on
        # 1970-01-01; nanoseconds since the epoch then equal nanoseconds
        # since midnight.
        epoch_day = {
            "year": np.zeros_like(hour) + 1970,
            "month": np.ones_like(hour),
            "day": np.ones_like(hour),
        }
        components = {
            **epoch_day,
            "hour": hour,
            "minutes": minutes,
            "seconds": seconds,
        }
        nanos = pd.to_datetime(components).values.astype(np.int64)

        # Keep decoded values inside the observed range (DP-SGD noise can
        # push them outside it).
        lowest_ns = self.values[0] * 1000
        highest_ns = self.values[-1] * 1000
        clipped = np.clip(nanos, a_min=lowest_ns, a_max=highest_ns)

        as_nanos = pa.array(clipped, pa.time64("ns"))
        return pa.compute.cast(as_nanos, pa.time64("us"))


class TextQuantilePostProcessor:
    """Decode token-id sequences back into text with the shared tokenizer."""

    def __init__(self, distrib_values: t.List[int]):
        # Stored for interface parity with the other post-processors; not
        # used when decoding.
        self.values = distrib_values

    def inverse_transcode(self, data_input: pa.Array) -> pa.Array:
        """Decode each token sequence, dropping special tokens."""
        decoded = TOKENIZER.batch_decode(data_input, skip_special_tokens=True)
        as_utf8 = pa.array(decoded)
        return as_utf8.cast(pa.large_string())


HistType = t.TypeVar("HistType", int, float, str)


class HistogramPostProcessor(t.Generic[HistType]):
    """Decode histogram category indices by looking them up in the list of
    category values."""

    def __init__(
        self,
        distrib_values: t.List[HistType],
    ) -> None:
        # Category values; position i is the value encoded by index i.
        self.values: t.List[HistType] = distrib_values

    def inverse_transcode(self, data_input: np.ndarray) -> pa.Array:
        """Map every category index in ``data_input`` to its value."""
        lookup_table = np.array(self.values)
        decoded = np.take_along_axis(lookup_table, data_input, axis=0)
        return pa.array(decoded)


class BooleanHistogramPostProcessor(HistogramPostProcessor[int]):
    """Histogram post-processor whose decoded values are cast to booleans."""

    def inverse_transcode(self, data_input: np.ndarray) -> pa.Array:
        """Look up the category values, then cast them to ``bool``."""
        return super().inverse_transcode(data_input).cast(pa.bool_())


class StrHistogramPostProcessor(HistogramPostProcessor[str]):
    """Histogram post-processor for text categories."""

    def inverse_transcode(self, data_input: np.ndarray) -> pa.Array:
        """Look up the category strings, then cast them to ``large_string``."""
        return super().inverse_transcode(data_input).cast(pa.large_string())


class DateHistogramPostProcessor(HistogramPostProcessor[int]):
    """Histogram post-processor whose integer categories are date32 days."""

    def inverse_transcode(self, data_input: np.ndarray) -> pa.Array:
        """Look up the categories, narrow them to int32, then cast to
        ``date32``."""
        as_int32 = super().inverse_transcode(data_input).cast(pa.int32())
        return as_int32.cast(pa.date32())


class DatetimeHistogramPostProcessor(HistogramPostProcessor[int]):
    """Histogram post-processor whose integer categories are nanosecond
    timestamps."""

    def inverse_transcode(self, data_input: np.ndarray) -> pa.Array:
        """Look up the categories, then cast them to ``timestamp[ns]``."""
        return super().inverse_transcode(data_input).cast(pa.timestamp("ns"))


class TimeHistogramPostProcessor(HistogramPostProcessor[int]):
    """Histogram post-processor whose integer categories are microsecond
    times of day."""

    def inverse_transcode(self, data_input: np.ndarray) -> pa.Array:
        """Look up the categories, then cast them to ``time64[us]``."""
        return super().inverse_transcode(data_input).cast(pa.time64("us"))


class DurationHistogramPostProcessor(HistogramPostProcessor[int]):
    """Histogram post-processor whose integer categories are microsecond
    durations."""

    def inverse_transcode(self, data_input: np.ndarray) -> pa.Array:
        """Look up the categories, then cast them to ``duration[us]``."""
        return super().inverse_transcode(data_input).cast(pa.duration("us"))