# Scraped package-index header preserved as comments (was not valid Python):
# Repository URL to install this package:
# Version: 4.0.7
import functools
import logging
import os
import pickle
import typing as t
import numpy as np
import pyarrow as pa
import pyarrow.compute as pc
import pyarrow.parquet as pq
from pandas.api.types import CategoricalDtype
import sarus_synthetic_data.configs.typing as st
from sarus_synthetic_data.configs.global_config import (
NonOptionalCorrelationColumn,
OptionalCorrelationColumn,
TextCorrelationCol,
)
from sarus_synthetic_data.constants import (
IS_NOT_NULL,
OPTIONAL_VALUE,
)
from sarus_synthetic_data.data_processing.typing import (
PreprocessingConfig,
TableConfig,
)
from sarus_synthetic_data.configs.typing import (
DistributionKind,
TypeKind,
)
# Module-level logger named after this module, per standard logging convention.
logger = logging.getLogger(__name__)
class Preprocessor:
    """Class responsible for preprocessing input data so it can be fed to
    either the Independent or the Correlation generator. In practice,
    preprocessing only happens for correlation data."""

    def __init__(self, config: PreprocessingConfig) -> None:
        self.config = config

    def preprocess_tables(self) -> None:
        """Preprocess every configured table, writing artifacts under
        ``config.saving_directory`` joined with the table name parts."""
        for table_name, table_config in self.config.tables.items():
            current_saving_dir = os.path.join(
                self.config.saving_directory, *table_name
            )
            # NOTE: a space is needed between the two f-string fragments,
            # otherwise the message reads "...indirectory...".
            logger.info(
                f"Preprocessing table {table_name} in "
                f"directory {current_saving_dir}"
            )
            os.makedirs(current_saving_dir, exist_ok=True)
            self.preprocess_table(
                saving_dir=current_saving_dir,
                table_config=table_config,
                privacy_unit_col=self.config.privacy_unit_col,
            )
            logger.info(f"Done Preprocessing table {table_name}")

    def preprocess_table(
        self, saving_dir: str, table_config: TableConfig, privacy_unit_col: str
    ) -> None:
        """This method pre-processes the data in a table if some columns are
        generated with correlation. It builds:
            - a PyTree of data where leaves are numpy arrays
            - a list of group indices that belong to the same protected entity
        These elements are stored on disk as
        ``<saving_dir>/correlation_data.pkl``.
        """
        correlation_config = table_config.correlation_generation
        if correlation_config is None:
            # Nothing to preprocess for fully independent tables.
            return
        data = pq.read_table(table_config.data_uri)
        # Preprocess correlation columns only.
        col_correlation = list(correlation_config.columns.keys())
        data_correlation = data.select(col_correlation)
        # Step 1: transcode each correlation column into numpy leaves.
        struct: t.List[t.Any] = []
        for col_name, col_config in correlation_config.columns.items():
            array = data_correlation.column(col_name).combine_chunks()
            struct.extend(self._transcode_column(array, col_config))
        # Step 2: compute groups of row indices sharing a privacy unit.
        data = data.append_column(
            "sarus_index", pa.array(np.arange(len(data)))
        )
        groups = (
            data.group_by(privacy_unit_col)
            .aggregate([("sarus_index", "list")])["sarus_index_list"]
            .to_pylist()
        )
        # Save
        file_dir = os.path.join(saving_dir, "correlation_data.pkl")
        with open(file_dir, "wb") as file:
            pickle.dump((groups, struct), file)

    @staticmethod
    def _transcode_column(
        array: pa.Array,
        col_config: t.Union[
            NonOptionalCorrelationColumn, OptionalCorrelationColumn
        ],
    ) -> t.List[t.Any]:
        """Transcode one column and return its list of PyTree leaves.

        Optional columns go through :func:`transcode_optional`, plain ones
        through :func:`transcode_type`. Temporal quantile columns
        (Date/Datetime/Time) are decomposed into one leaf per calendar/clock
        component, so their result is flattened; every other column yields a
        single leaf.
        """
        if isinstance(col_config, OptionalCorrelationColumn):
            inner = col_config.child_col
            result = transcode_optional(array, col_config=col_config)
        else:
            inner = col_config
            result = transcode_type(array, col_config=col_config)
        is_multi_component = (
            inner.distribution_kind == DistributionKind.quantiles
            and inner.col_type
            in (TypeKind.Date, TypeKind.Datetime, TypeKind.Time)
        )
        if is_multi_component:
            # year/month/day (and hour/minute/second) leaves, in order.
            return list(result)
        return [result]
def transcode_type(
    initial_array: pa.Array, col_config: NonOptionalCorrelationColumn
) -> t.Any:
    """Transcode a non-optional PyArrow array into numpy integer data.

    The concrete preprocessor is picked from a (distribution kind, column
    type) dispatch table. The return container depends on that choice:
    most preprocessors return a single numpy index array, the temporal
    quantile preprocessors (Date/Datetime/Time) return a tuple with one
    array per calendar/clock component, and the text preprocessor returns
    a dict of token arrays.

    Raises ``KeyError`` if the (distribution kind, type) pair is not
    supported by any preprocessor.
    """
    kind = col_config.col_type
    distribution_kind = col_config.distribution_kind
    distrib_values = col_config.distribution.values
    # Dispatch table: (distribution kind, column type) -> preprocessor class.
    preprocessor_classes = {
        st.DistributionKind.histogram: {
            st.TypeKind.Text: StrHistogramPreprocessor,
            st.TypeKind.Float: Float64HistogramPreprocessor,
            st.TypeKind.Time: Int64HistogramPreprocessor,
            st.TypeKind.Integer: Int64HistogramPreprocessor,
            st.TypeKind.Date: Int32HistogramPreprocessor,
            st.TypeKind.Datetime: Int64HistogramPreprocessor,
            st.TypeKind.Duration: Int64HistogramPreprocessor,
            st.TypeKind.Boolean: BooleanHistogramPreprocessor,
        },
        st.DistributionKind.quantiles: {
            st.TypeKind.Integer: QuantilePreprocessor[int],
            st.TypeKind.Float: QuantilePreprocessor[float],
            st.TypeKind.Date: DateQuantilePreprocessor,
            st.TypeKind.Datetime: DatetimeQuantilePreprocessor,
            st.TypeKind.Time: TimeQuantilePreprocessor,
            st.TypeKind.Duration: DurationQuantilePreprocessor,
            st.TypeKind.Text: TextQuantilePreprocessor,
        },
    }
    preprocessor_class = preprocessor_classes[distribution_kind][kind]
    if isinstance(col_config, TextCorrelationCol):
        # Text preprocessing also needs the tokenizer's max sequence length.
        preprocessor = preprocessor_class(
            distrib_values=distrib_values,
            tokenizer_max_length=col_config.tokenizer_max_length,
        )
    else:
        preprocessor = preprocessor_class(distrib_values=distrib_values)
    return preprocessor.transcode(initial_array)
def transcode_optional(
    initial_array: pa.Array, col_config: OptionalCorrelationColumn
) -> t.Union[t.Any, t.List[t.Any]]:
    """Transcode an optional (nullable) PyArrow array.

    Null slots are first replaced by copies of the child column's example
    value so the underlying transcoder never sees a null, then the padded
    array is transcoded with :func:`transcode_type`. Each resulting leaf is
    wrapped in a dict carrying a 0/1 validity mask under ``IS_NOT_NULL`` and
    the transcoded data under ``OPTIONAL_VALUE``. Temporal quantile columns
    yield a list with one such dict per calendar/clock component.
    """
    nulls = initial_array.is_null(nan_is_null=True)
    valid_mask = pc.invert(nulls).cast(pa.int64()).to_numpy()
    n_nulls = nulls.sum().as_py()
    if n_nulls:
        padded_array = pc.replace_with_mask(
            initial_array,
            nulls,
            pa.concat_arrays([col_config.child_col.example] * n_nulls),
        )
    else:
        # pa.concat_arrays requires at least one array, so skip the
        # padding step entirely when the column contains no nulls.
        padded_array = initial_array
    child = col_config.child_col
    transcoded = transcode_type(initial_array=padded_array, col_config=child)
    if child.distribution_kind == DistributionKind.quantiles and (
        child.col_type
        in (TypeKind.Date, TypeKind.Datetime, TypeKind.Time)
    ):
        # One component array per calendar/clock field (e.g. year, month,
        # day, ...), in the order produced by transcode_type; attach the
        # same validity mask to each component.
        return [
            {IS_NOT_NULL: valid_mask, OPTIONAL_VALUE: component}
            for component in transcoded
        ]
    return {IS_NOT_NULL: valid_mask, OPTIONAL_VALUE: transcoded}
class HasCast(t.Protocol):
    """Structural type for preprocessors that expose distribution values
    together with an array-casting hook used by ``HistogramPreprocessor``."""

    # Distribution values backing the preprocessor.
    values: t.List

    def cast_array(self, initial_array: "pa.Array") -> "pa.Array":
        """Cast ``initial_array`` to the dtype this preprocessor expects."""
        ...
# Constrained type variable: quantile distributions are defined over either
# ints or floats.
QuantType = t.TypeVar("QuantType", int, float)


class QuantilePreprocessor(t.Generic[QuantType]):
    """Maps each value of an array to the index of its quantile bin."""

    def __init__(self, distrib_values: t.List[QuantType]) -> None:
        # Quantile boundaries delimiting the bins (assumed sorted —
        # np.searchsorted below relies on it).
        self.values: t.List[QuantType] = distrib_values

    def transcode(
        self, initial_array: "pa.Array"
    ) -> "np.ndarray[t.Any, np.dtype[np.int_]]":
        """Return the quantile-bin index of every element.

        Values equal to a "dirac" boundary (a boundary repeated in the
        distribution, i.e. an atom) are shifted one bin to the right so
        they land inside the bin dedicated to that repeated boundary.
        """
        # Work on a numpy view: a pyarrow Array does not implement
        # element-wise ``==`` against a Python scalar, which would make
        # the dirac correction below a silent no-op on pyarrow inputs.
        array = np.asarray(initial_array)
        if len(self.values) > 1 and self.values[0] == self.values[1]:
            # Drop the duplicated first boundary to avoid an empty
            # leading bin (length guard protects single-value lists).
            distrib_values = self.values[1:]
        else:
            distrib_values = self.values
        indices = np.searchsorted(distrib_values, array, side="left")
        # Boundaries occurring more than once mark dirac atoms.
        unique_val, freq = np.unique(distrib_values, return_counts=True)
        for element in unique_val[freq > 1]:
            indices = np.where(array == element, indices + 1, indices)
        return t.cast(np.ndarray[t.Any, np.dtype[np.int_]], indices)
class DurationQuantilePreprocessor(QuantilePreprocessor[int]):
    """Quantile preprocessor for durations: casts the array to int64 and
    defers to the integer quantile logic."""

    def transcode(
        self, initial_array: pa.Array
    ) -> np.ndarray[t.Any, np.dtype[np.int_]]:
        """Return quantile-bin indices for the durations viewed as int64."""
        as_integers = initial_array.cast(pa.int64())
        return super().transcode(as_integers)
class DateQuantilePreprocessor:
    """Decomposes a date array into (year, month, day) integer components,
    the year being expressed relative to the distribution's first value."""

    def __init__(self, distrib_values: t.List[int]) -> None:
        # Distribution boundaries; values[0] is interpreted as a date32
        # ordinal anchoring the year offset.
        self.values = distrib_values

    def transcode(
        self, initial_array: pa.Array
    ) -> t.Tuple[np.ndarray[t.Any, np.dtype[np.int_]], ...]:
        """Return (year offset, zero-based month, zero-based day) arrays."""
        origin = pa.scalar(np.int32(self.values[0]), pa.date32())
        base_year = pc.year(origin)
        years = pc.subtract(pc.year(initial_array), base_year)
        months = pc.subtract(pc.month(initial_array), 1)
        days = pc.subtract(pc.day(initial_array), 1)
        return (years.to_numpy(), months.to_numpy(), days.to_numpy())
class DatetimeQuantilePreprocessor:
    """Decomposes a timestamp array into six integer component arrays:
    (year offset, month, day, hour, minute, second)."""

    def __init__(self, distrib_values: t.List[int]) -> None:
        # Distribution boundaries; values[0] is interpreted as a
        # nanosecond timestamp anchoring the year offset.
        self.values = distrib_values

    def transcode(
        self, initial_array: pa.Array
    ) -> t.Tuple[np.ndarray[t.Any, np.dtype[np.int_]], ...]:
        """Return zero-based calendar components plus raw clock components."""
        origin = pa.scalar(self.values[0], pa.timestamp("ns"))
        base_year = pc.year(origin)
        years = pc.subtract(pc.year(initial_array), base_year)
        months = pc.subtract(pc.month(initial_array), 1)
        days = pc.subtract(pc.day(initial_array), 1)
        hours = pc.hour(initial_array)
        minutes = pc.minute(initial_array)
        seconds = pc.second(initial_array)
        components = (years, months, days, hours, minutes, seconds)
        return tuple(component.to_numpy() for component in components)
class TimeQuantilePreprocessor:
    """Decomposes a time-of-day array into hour/minute/second components."""

    def __init__(self, distrib_values: t.List[int]) -> None:
        # Stored for interface parity with the other quantile
        # preprocessors; the time decomposition itself never reads it.
        self.values = distrib_values

    def transcode(
        self, initial_array: pa.Array
    ) -> t.Tuple[np.ndarray[t.Any, np.dtype[np.int_]], ...]:
        """Return (hour, minute, second) numpy arrays for ``initial_array``."""
        components = (
            pc.hour(initial_array),
            pc.minute(initial_array),
            pc.second(initial_array),
        )
        return tuple(component.to_numpy() for component in components)
class TextQuantilePreprocessor:
    """Tokenizes a text column into fixed-length arrays for generation."""

    def __init__(self, distrib_values: t.List[int], tokenizer_max_length: int):
        # Distribution values; their maximum bounds the token sequence length.
        self.values = distrib_values
        # Hard cap imposed by the tokenizer/model.
        self.tokenizer_max_length = tokenizer_max_length

    def transcode(
        self, initial_array: pa.Array
    ) -> t.Dict[str, np.ndarray[t.Any, np.dtype[np.int_]]]:
        """Tokenize the texts and return ``input_ids``, ``position_ids`` and
        ``attention_mask`` arrays, each of shape (n_rows, max_length)."""
        data = initial_array.to_numpy(zero_copy_only=False)
        # Sequence length is capped by both the tokenizer's limit and the
        # largest value observed in the distribution.
        max_length = min(self.tokenizer_max_length, int(max(self.values)))
        tokenized_text = TOKENIZER(
            data.tolist(),
            padding="max_length",
            truncation=True,
            max_length=max_length,
            return_tensors="np",
        )
        input_ids = tokenized_text.input_ids
        position_ids = np.broadcast_to(
            np.arange(max_length)[None, :],
            (len(input_ids), t.cast(int, max_length)),
        )
        # Shift the attention mask one step right: prepend an always-visible
        # first column and drop the last one.
        # NOTE(review): presumably this aligns the mask for next-token
        # prediction in the generator — confirm against the consumer.
        mask = np.concatenate(
            [
                np.ones(
                    shape=(tokenized_text.attention_mask.shape[0], 1),
                    dtype=np.int64,
                ),
                tokenized_text.attention_mask,
            ],
            axis=1,
        )[:, :-1]
        return {
            "input_ids": input_ids,
            "position_ids": position_ids,
            "attention_mask": mask,
        }
# Histogram categories may be ints, floats or strings.
HistType = t.TypeVar("HistType", int, float, str)


class HistogramPreprocessor(t.Generic[HistType]):
    """Maps array values to their index within an ordered category list."""

    def __init__(
        self,
        distrib_values: t.List[HistType],
    ) -> None:
        # Ordered list of histogram categories.
        self.values: t.List[HistType] = distrib_values

    def transcode(
        self: HasCast, initial_array: pa.Array
    ) -> np.ndarray[t.Any, np.dtype[np.int_]]:
        """Return, for every element, its position in ``self.values``."""
        cast_array = self.cast_array(initial_array)
        # Ordered pandas categories pin the value -> index mapping.
        cat_type = CategoricalDtype(categories=self.values, ordered=True)
        categorical = cast_array.to_pandas(
            self_destruct=True, split_blocks=False
        ).astype(cat_type)
        encoded = pa.DictionaryArray.from_pandas(
            categorical,
            type=pa.dictionary(
                index_type=pa.int64(),
                value_type=cast_array.type,
                ordered=True,
            ),
        )
        indices = encoded.indices.to_numpy(zero_copy_only=False)
        return t.cast(np.ndarray[t.Any, np.dtype[np.int_]], indices)
class Int64CasterMixin:
    """Mixin supplying an int64 ``cast_array`` for histogram preprocessors."""

    def cast_array(self, initial_array: "pa.Array") -> "pa.Array":
        """Cast ``initial_array`` to 64-bit signed integers."""
        target_type = pa.int64()
        return initial_array.cast(target_type)
class Int32CasterMixin:
    """Mixin supplying an int32 ``cast_array`` for histogram preprocessors."""

    def cast_array(self, initial_array: "pa.Array") -> "pa.Array":
        """Cast ``initial_array`` to 32-bit signed integers."""
        target_type = pa.int32()
        return initial_array.cast(target_type)
class Int64HistogramPreprocessor(Int64CasterMixin, HistogramPreprocessor[int]):
    """Histogram preprocessor for columns cast to 64-bit integers."""
class Int32HistogramPreprocessor(Int32CasterMixin, HistogramPreprocessor[int]):
    """Histogram preprocessor for columns cast to 32-bit integers."""
class BooleanHistogramPreprocessor(Int64CasterMixin):
    """Histogram preprocessor for booleans: the int64 cast (0/1) is already
    the category index, so no dictionary encoding is required."""

    def __init__(self, distrib_values: t.List[int]) -> None:
        # Stored for interface parity with the other preprocessors.
        self.values = distrib_values

    def transcode(
        self, initial_array: pa.Array
    ) -> np.ndarray[t.Any, np.dtype[np.int_]]:
        """Return the booleans as an int64 numpy array."""
        as_ints = self.cast_array(initial_array)
        return t.cast(
            np.ndarray[t.Any, np.dtype[np.int_]], as_ints.to_numpy()
        )
class Float64HistogramPreprocessor(HistogramPreprocessor[float]):
    """Histogram preprocessor for float columns; no cast is required."""

    def cast_array(self, initial_array: pa.Array) -> pa.Array:
        """Return the array unchanged: floats are encoded as-is."""
        return initial_array
class StrHistogramPreprocessor(HistogramPreprocessor[str]):
    """Histogram preprocessor for string columns; no cast is required."""

    def cast_array(self, initial_array: pa.Array) -> pa.Array:
        """Return the array unchanged: strings are encoded as-is."""
        return initial_array
class _LazyTokenizer(object):
"""A lazily loaded Tokenizer.
Does not load any data if never called.
"""
def __init__(self, path: str):
super(_LazyTokenizer, self).__init__()
self.path = path
def __call__(self, *args: t.Any, **kwargs: t.Any) -> t.Any:
return self._tokenizer(*args, **kwargs)
def __getattr__(self, name: str) -> t.Any:
return getattr(self._tokenizer, name)
@functools.cached_property
def _tokenizer(self) -> t.Any:
"""Loads and cache the tokenizer on its first call"""
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(self.path)
tokenizer.pad_token = tokenizer.eos_token
return tokenizer
# Shared module-level tokenizer; the underlying model tokenizer is only
# loaded on first use thanks to _LazyTokenizer.
TOKENIZER = _LazyTokenizer("EleutherAI/gpt-neo-125M")