import logging
import typing as t
import numpy as np
import pandas as pd
import pyarrow as pa
import pyarrow.compute as pc
import sarus_synthetic_data.configs.typing as st
from sarus_synthetic_data.configs.global_config import (
CorrelationColumn,
NonOptionalCorrelationColumn,
OptionalCorrelationColumn,
)
from sarus_synthetic_data.constants import (
IS_NOT_NULL,
OPTIONAL_VALUE,
)
from sarus_synthetic_data.configs.typing import (
DistributionKind,
TypeKind,
)
from sarus_synthetic_data.data_processing.preprocessor import TOKENIZER
from sarus_synthetic_data.data_processing.typing import OptionalOutput
logger = logging.getLogger(__name__)
class Postprocessor:
"""Class Responsible to preprocess input data in order to be able to feed it
either to the Independent of Correlation generator. In fact, pre-processing
happens only for correlation data."""
def __init__(
self, column_config: t.Mapping[str, CorrelationColumn]
) -> None:
self.column_config = column_config
    def post_process_table(self, table_data: t.List[t.Any]) -> pa.Table:
        """Post-processes generated column data into a PyArrow table.
        Temporal columns generated with quantile distributions span
        several consecutive entries of ``table_data`` (3 sub-columns
        for Date and Time, 6 for Datetime), so they are consumed as
        slices before being inverse-transcoded.
        """
        def sub_column_width(
            col_type: TypeKind, distribution_kind: DistributionKind
        ) -> int:
            # Temporal quantile columns are generated as several
            # sub-columns: year/month/day for dates, hour/minute/second
            # for times, and all six components for datetimes.
            if distribution_kind != DistributionKind.quantiles:
                return 1
            if col_type in (TypeKind.Date, TypeKind.Time):
                return 3
            if col_type == TypeKind.Datetime:
                return 6
            return 1

        struct = []
        fields = []
        col_index = 0
        for col_name, col_config in self.column_config.items():
            if isinstance(col_config, OptionalCorrelationColumn):
                child = col_config.child_col
                width = sub_column_width(
                    child.col_type, child.distribution_kind
                )
                col_data = (
                    table_data[col_index : col_index + width]
                    if width > 1
                    else table_data[col_index]
                )
                struct.append(
                    inverse_transcode_optional(
                        col_data, col_config=col_config
                    )
                )
                fields.append(
                    pa.field(
                        name=col_name,
                        type=child.example.type,
                        nullable=True,
                    )
                )
            else:
                width = sub_column_width(
                    col_config.col_type, col_config.distribution_kind
                )
                col_data = (
                    table_data[col_index : col_index + width]
                    if width > 1
                    else table_data[col_index]
                )
                struct.append(
                    inverse_transcode_type(
                        col_data, col_config=col_config
                    )
                )
                fields.append(
                    pa.field(
                        name=col_name,
                        type=col_config.example.type,
                        nullable=False,
                    )
                )
            col_index += width
        return pa.Table.from_arrays(arrays=struct, schema=pa.schema(fields))
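# Illustrative sketch (not part of the pipeline): how generated
# sub-columns are consumed from ``table_data``. A Datetime quantile
# column occupies 6 consecutive entries, Date and Time quantile columns
# occupy 3, and everything else occupies 1. The widths below are
# hypothetical demo values following that rule.
def _demo_column_layout() -> t.List[t.Tuple[int, int]]:
    widths = [1, 3, 6]  # e.g. a Float, a Date-quantile, a Datetime-quantile
    slices, start = [], 0
    for width in widths:
        slices.append((start, start + width))
        start += width
    return slices  # [(0, 1), (1, 4), (4, 10)]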
def inverse_transcode_type(
    data_input: t.Union[np.ndarray, t.List[np.ndarray]],
    col_config: NonOptionalCorrelationColumn,
) -> pa.Array:
    """Decodes generated integer data back into a PyArrow array of the
    column's original type, dispatching to the post-processor that
    matches the column's distribution kind and type kind."""
kind = col_config.col_type
distribution_kind = col_config.distribution_kind
distrib_values = col_config.distribution.values
    post_processor_classes = {
st.DistributionKind.histogram: {
st.TypeKind.Text: StrHistogramPostProcessor,
st.TypeKind.Float: HistogramPostProcessor[float],
st.TypeKind.Time: TimeHistogramPostProcessor,
st.TypeKind.Integer: HistogramPostProcessor[int],
st.TypeKind.Date: DateHistogramPostProcessor,
st.TypeKind.Datetime: DatetimeHistogramPostProcessor,
st.TypeKind.Duration: DurationHistogramPostProcessor,
st.TypeKind.Boolean: BooleanHistogramPostProcessor,
},
st.DistributionKind.quantiles: {
st.TypeKind.Integer: QuantilePostProcessor[int],
st.TypeKind.Float: QuantilePostProcessor[float],
st.TypeKind.Date: DateQuantilePostProcessor,
st.TypeKind.Datetime: DatetimeQuantilePostProcessor,
st.TypeKind.Time: TimeQuantilePostProcessor,
st.TypeKind.Duration: DurationQuantilePostProcessor,
st.TypeKind.Text: TextQuantilePostProcessor,
},
}
    post_processor_class = post_processor_classes[distribution_kind][kind]
postprocessor = post_processor_class(distrib_values=distrib_values)
out = postprocessor.inverse_transcode(data_input)
return out
def inverse_transcode_optional(
    data_input: t.Union[OptionalOutput, t.List[OptionalOutput]],
    col_config: OptionalCorrelationColumn,
) -> pa.Array:
    """Decodes an optional column: rows whose validity mask is false
    become nulls, the remaining rows take the decoded child values."""
    if isinstance(data_input, list):
valid_masks = [d[IS_NOT_NULL].astype(bool) for d in data_input] # type: ignore
valid_mask = valid_masks[0]
for mask in valid_masks[1:]:
valid_mask = valid_mask & mask
child_array = inverse_transcode_type(
[d[OPTIONAL_VALUE] for d in data_input], # type: ignore
col_config=col_config.child_col,
)
else:
valid_mask = data_input[IS_NOT_NULL].astype(bool) # type: ignore
child_array = inverse_transcode_type(
data_input[OPTIONAL_VALUE], # type: ignore
col_config=col_config.child_col,
)
return pc.if_else(
valid_mask,
child_array,
pa.nulls(len(child_array), type=child_array.type),
)
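# Illustrative sketch (not part of the pipeline): the nullable-column
# reconstruction above reduces to ``pc.if_else`` over a validity mask.
# The mask and values below are hypothetical demo data.
def _demo_optional_decoding() -> pa.Array:
    valid_mask = pa.array(np.array([True, False, True]))
    child_array = pa.array([1, 2, 3])
    # rows where the mask is false become nulls in the output
    return pc.if_else(
        valid_mask,
        child_array,
        pa.nulls(len(child_array), type=child_array.type),
    )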
QuantType = t.TypeVar("QuantType", int, float)
class QuantilePostProcessor(t.Generic[QuantType]):
def __init__(self, distrib_values: t.List[QuantType]) -> None:
self.values: t.List[QuantType] = distrib_values
    def inverse_transcode(self, data_input: np.ndarray) -> pa.Array:
        """Maps quantile bin indices back to values: index 0 maps to
        the first quantile, index k > 0 samples uniformly between
        quantiles k - 1 and k."""
        # drop a duplicated first quantile to avoid a zero-width first bin
        if self.values[0] == self.values[1]:
            values = self.values[1:]
        else:
            values = self.values
out = np.full(
fill_value=values[0],
shape=len(data_input),
)
# now sample between each quantile
other_values_mask = data_input > 0
filtered_values = data_input[other_values_mask]
upper_bound = np.take_along_axis(
np.array(values), filtered_values, axis=0
)
lower_bound = np.take_along_axis(
np.array(values), filtered_values - 1, axis=0
)
samples = np.random.uniform(low=lower_bound, high=upper_bound)
out[other_values_mask] = samples
return pa.array(out.astype(type(values[0])))
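# Illustrative sketch (not part of the pipeline): decoding quantile bin
# indices with hypothetical boundaries [0.0, 1.0, 10.0]. Index 0 yields
# the first boundary; indices 1 and 2 sample uniformly within their bin.
def _demo_quantile_decoding() -> pa.Array:
    processor = QuantilePostProcessor[float](distrib_values=[0.0, 1.0, 10.0])
    bin_indices = np.array([0, 1, 2, 1])
    return processor.inverse_transcode(bin_indices)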
class DurationQuantilePostProcessor(QuantilePostProcessor[int]):
    def inverse_transcode(self, data_input: np.ndarray) -> pa.Array:
out = super().inverse_transcode(data_input)
return pa.array(
pd.to_timedelta(out.to_numpy(), "us"), type=pa.duration("us")
)
class DateQuantilePostProcessor:
def __init__(self, distrib_values: t.List[int]) -> None:
self.values = distrib_values
    def inverse_transcode(self, data_input: t.List[np.ndarray]) -> pa.Array:
        min_year = pc.year(
            pa.scalar(np.int32(self.values[0]), pa.date32())
        ).as_py()
        year = data_input[0] + min_year
        month = data_input[1] + 1
        day = data_input[2] + 1
# correct for months with 30 days
day = np.where(
np.logical_and(np.isin(month, [4, 6, 9, 11]), day == 31),
30,
day,
)
        # correct for February (day 29 is clamped even in leap years)
day = np.where(np.logical_and(month == 2, day > 28), 28, day)
np_values = pd.to_datetime(
{
"year": year,
"month": month,
"day": day,
}
).values.astype("datetime64[D]")
        # clip to the stored min/max bounds; values can drift past
        # them due to DP-SGD
np_values = np.clip(
np_values.astype(np.int32),
a_min=self.values[0],
a_max=self.values[-1],
)
return pa.array(
np_values,
pa.date32(),
)
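# Illustrative sketch (not part of the pipeline): the day-of-month
# correction above, on hypothetical values. April 31 is clamped to 30
# and February 30 to 28 (leap days are clamped too).
def _demo_day_clamping() -> np.ndarray:
    month = np.array([4, 2, 1])
    day = np.array([31, 30, 31])
    day = np.where(
        np.logical_and(np.isin(month, [4, 6, 9, 11]), day == 31), 30, day
    )
    return np.where(np.logical_and(month == 2, day > 28), 28, day)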
class DatetimeQuantilePostProcessor:
def __init__(self, distrib_values: t.List[int]) -> None:
self.values = distrib_values
    def inverse_transcode(self, data_input: t.List[np.ndarray]) -> pa.Array:
        min_year = pc.year(
            pa.scalar(self.values[0], pa.timestamp("ns"))
        ).as_py()
year = data_input[0] + min_year
month = data_input[1] + 1
day = data_input[2] + 1
hour = data_input[3]
minutes = data_input[4]
seconds = data_input[5]
# correct for months with 30 days
day = np.where(
np.logical_and(np.isin(month, [4, 6, 9, 11]), day == 31),
30,
day,
)
        # correct for February (day 29 is clamped even in leap years)
day = np.where(np.logical_and(month == 2, day > 28), 28, day)
np_values = pd.to_datetime(
{
"year": year,
"month": month,
"day": day,
"hour": hour,
"minutes": minutes,
"seconds": seconds,
}
).values
        # clip to the stored min/max bounds; values can drift past
        # them due to DP-SGD
np_values = np.clip(
np_values.astype(np.int64),
a_min=self.values[0],
a_max=self.values[-1],
)
return pa.array(
np_values,
pa.timestamp(unit="ns"),
)
class TimeQuantilePostProcessor:
def __init__(self, distrib_values: t.List[int]) -> None:
self.values = distrib_values
    def inverse_transcode(self, data_input: t.List[np.ndarray]) -> pa.Array:
hour = data_input[0]
minutes = data_input[1]
seconds = data_input[2]
        # pandas to_datetime needs year, month and day
        # to compose a datetime
np_values = pd.to_datetime(
{
"year": np.zeros_like(hour) + 1970,
"month": np.ones_like(hour),
"day": np.ones_like(hour),
"hour": hour,
"minutes": minutes,
"seconds": seconds,
}
).values
        # clip to the stored min/max bounds (kept in µs, hence * 1000
        # to compare in ns); values can drift past them due to DP-SGD
np_values = np.clip(
np_values.astype(np.int64),
a_min=self.values[0] * 1000,
a_max=self.values[-1] * 1000,
)
return pa.compute.cast(
pa.array(np_values, pa.time64("ns")), pa.time64("us")
)
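# Illustrative sketch (not part of the pipeline): the unit bookkeeping
# above on hypothetical values. Bounds kept in microseconds are scaled
# by 1000 to clip nanosecond values, then cast back to time64("us").
def _demo_time_unit_clip() -> pa.Array:
    ns_values = np.array([0, 2_000_000_000], dtype=np.int64)
    bounds_us = [1_000, 1_000_000]  # hypothetical min/max in µs
    clipped = np.clip(ns_values, bounds_us[0] * 1000, bounds_us[-1] * 1000)
    return pc.cast(pa.array(clipped, pa.time64("ns")), pa.time64("us"))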
class TextQuantilePostProcessor:
    def __init__(self, distrib_values: t.List[int]) -> None:
self.values = distrib_values
    def inverse_transcode(self, data_input: np.ndarray) -> pa.Array:
text = TOKENIZER.batch_decode(data_input, skip_special_tokens=True)
return pa.array(text).cast(pa.large_string())
HistType = t.TypeVar("HistType", int, float, str)
class HistogramPostProcessor(t.Generic[HistType]):
def __init__(
self,
distrib_values: t.List[HistType],
) -> None:
self.values: t.List[HistType] = distrib_values
def inverse_transcode(self, data_input: np.ndarray) -> pa.Array:
return pa.array(
np.take_along_axis(np.array(self.values), data_input, axis=0)
)
class BooleanHistogramPostProcessor(HistogramPostProcessor[int]):
def inverse_transcode(self, data_input: np.ndarray) -> pa.Array:
out = super().inverse_transcode(data_input)
return out.cast(pa.bool_())
class StrHistogramPostProcessor(HistogramPostProcessor[str]):
def inverse_transcode(self, data_input: np.ndarray) -> pa.Array:
out = super().inverse_transcode(data_input)
return out.cast(pa.large_string())
class DateHistogramPostProcessor(HistogramPostProcessor[int]):
def inverse_transcode(self, data_input: np.ndarray) -> pa.Array:
out = super().inverse_transcode(data_input)
return out.cast(pa.int32()).cast(pa.date32())
class DatetimeHistogramPostProcessor(HistogramPostProcessor[int]):
def inverse_transcode(self, data_input: np.ndarray) -> pa.Array:
out = super().inverse_transcode(data_input)
return out.cast(pa.timestamp("ns"))
class TimeHistogramPostProcessor(HistogramPostProcessor[int]):
def inverse_transcode(self, data_input: np.ndarray) -> pa.Array:
out = super().inverse_transcode(data_input)
return out.cast(pa.time64("us"))
class DurationHistogramPostProcessor(HistogramPostProcessor[int]):
def inverse_transcode(self, data_input: np.ndarray) -> pa.Array:
out = super().inverse_transcode(data_input)
return out.cast(pa.duration("us"))
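# Illustrative sketch (not part of the pipeline): histogram decoding
# maps generated bin indices straight back to the stored values. The
# categories and indices below are hypothetical demo data.
def _demo_histogram_decoding() -> pa.Array:
    processor = StrHistogramPostProcessor(distrib_values=["a", "b", "c"])
    return processor.inverse_transcode(np.array([2, 0, 1]))  # -> c, a, b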