Repository URL to install this package:
|
Version:
4.0.7 ▾
|
import os
import typing as t
import jax
import jax.numpy as jnp
import numpy as np
import pandas as pd
from sarus_statistics.ops.histograms import NoiseKind, private_histogram
import sarus_synthetic_data.independent_generator.jax_dataclasses as jax_dataclasses # noqa : E501
from sarus_synthetic_data.shared.memory_utils import (
jax_cleanup,
jaxjit_cleanup,
)
# Number of alphabet values needed to code every unicode character.
ALPHABET_LENGTH = 387
# Name of the environment variable that caps how many characters of each
# training text are turned into trigrams (read once, at import time).
MAX_TRIGRAM_TRAINING_LENGTH = "max_training_trigram_length"
max_training_trigram_length = int(
    os.environ.get(MAX_TRIGRAM_TRAINING_LENGTH, "1000")
)
class TextGenerator:
    """Text generator based on trigrams.

    Holds the character table and the transition matrices of a trigram
    model together with an empirical distribution of text lengths, and
    samples synthetic strings with ``generate_jax``.
    """

    def __init__(
        self,
        sample_batch_size: int,
        lengths: t.List[int],
        probabilities: t.List[float],
        chars_sorted: jnp.ndarray,
        unigram_transitions: jnp.ndarray,
        bigram_transitions: jnp.ndarray,
        trigram_transitions: jnp.ndarray,
        alphabet_length: int,
        split: int,
        random_gen: np.random.Generator,
    ) -> None:
        """Store the model.

        Args:
            sample_batch_size: rows generated per JAX batch.
            lengths: length values of the empirical length distribution.
            probabilities: one value per entry of ``lengths``, read by
                ``sample_lengths`` as cumulative probabilities.
            chars_sorted: code point of each character index.
            unigram_transitions, bigram_transitions, trigram_transitions:
                transition probabilities of the model.
            alphabet_length: number of alphabet token values.
            split: number of token values that code a character alone.
            random_gen: numpy generator driving every random draw.
        """
        self.chars_sorted = chars_sorted
        self.unigram_transitions = unigram_transitions
        self.bigram_transitions = bigram_transitions
        self.trigram_transitions = trigram_transitions
        self.lengths = np.array(lengths)
        self.probabilities = np.array(probabilities)
        self.random_gen = random_gen
        self.alphabet_length = alphabet_length
        self.split = split
        self.sample_batch_size = sample_batch_size

    def sample(self, size: int) -> np.ndarray:
        """Generate ``size`` synthetic strings.

        Lengths are drawn for a whole number of batches (``size`` rounded
        up to a multiple of ``sample_batch_size``). The token buffer of a
        row is twice the longest drawn length, because a character may
        need two alphabet tokens. The JAX seed is drawn from
        ``self.random_gen``, so successive calls produce different texts
        while remaining reproducible for a seeded generator.

        Returns:
            1-D array of ``size`` strings (empty when ``size <= 0``).
        """
        if size <= 0:
            return np.array([], dtype=str)
        n_batches = int(np.ceil(size / self.sample_batch_size))
        lengths = self.sample_lengths(size=self.sample_batch_size * n_batches)
        max_length = int(2 * lengths.max())
        # Drawn after the lengths so that the length draws are unchanged.
        seed = int(self.random_gen.integers(np.iinfo(np.int32).max))
        return generate_jax(
            n_examples=size,
            max_length=max_length,
            samples_per_batch=self.sample_batch_size,
            seed=seed,
            alphabet_length=self.alphabet_length,
            chars_sorted=self.chars_sorted,
            trigram_transitions=self.trigram_transitions,
            bigram_transitions=self.bigram_transitions,
            unigram_transitions=self.unigram_transitions,
            optimal_split=self.split,
            lengths=jnp.array(lengths),
        )

    def sample_lengths(self, size: int) -> np.ndarray:
        """Draw ``size`` text lengths from the empirical length distribution.

        ``self.probabilities`` is read as the cumulative probability of
        each entry of ``self.lengths``. This presumes it is increasing and
        ends at 1, and that ``self.lengths`` is increasing — TODO confirm
        with the caller. For a uniform draw ``u``, the entry ``i`` with
        the largest cumulative probability not exceeding ``u`` is selected
        (entry 0 when none qualifies). A length is then drawn uniformly
        from ``[lengths[i], lengths[i + 1]]``, inclusive. The last entry
        has no successor, so its interval collapses to ``lengths[i]``.

        Returns:
            1-D integer array of shape ``(size,)``.
        """
        uniforms = self.random_gen.uniform(size=(size, 1))
        gap = self.probabilities[None, :] - uniforms
        # Entries above the draw are excluded (-inf); argmax then picks the
        # entry closest to the draw from below (index 0 if all excluded).
        lower_indices = np.argmax(np.where(gap > 0, -np.inf, gap), axis=1)
        # Clamp so the last entry (or a one-entry table) never reads past
        # the end of ``self.lengths``.
        upper_indices = np.minimum(lower_indices + 1, len(self.lengths) - 1)
        return self.random_gen.integers(
            low=self.lengths[lower_indices],
            high=self.lengths[upper_indices] + 1,
        )
def train_trigram_text(
    noise: float,
    data: pd.DataFrame,
    col: str,
    privacy_unit_col: str,
    weights_col: str,
    private_col: str,
    random_gen: np.random.Generator,
    max_multiplicity: float = 1.0,
    max_length: float = 1.0,
    char_list: t.Optional[t.List[int]] = None,
) -> t.Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, int, int]:
    """Fit a trigram model on the text column ``col`` of ``data``.

    Texts are encoded into an alphabet of ``alphabet_length`` token values
    (see ``encode_unicode_char``), prefixed with two start tokens (0), and
    split into trigrams. The trigram counts come from ``private_histogram``
    (Gaussian noise); counts below
    ``5 * noise * sqrt(max_multiplicity * max_length)`` are zeroed.

    Args:
        noise: noise level forwarded to ``private_histogram``.
        data: input frame with the Sarus columns. It is not modified.
        col: text column; values are converted with ``str``.
        privacy_unit_col: forwarded as ``user_col``.
        weights_col: forwarded as ``weight_col``.
        private_col: forwarded as ``private_col``.
        random_gen: forwarded as ``random_generator``.
        max_multiplicity, max_length: the square root of their product is
            the histogram's ``max_multiplicity`` and scales the threshold.
        char_list: optional code points added to the vocabulary even when
            absent from the data.

    Returns:
        ``(chars_sorted, trigram_transitions, bigram_transitions,
        unigram_transitions, split, alphabet_length)``:

        - ``chars_sorted``: code point per character index; index 0 is the
          start token.
        - ``trigram_transitions``: ``(A * A, A)``; row ``a * A + b`` is
          ``P(c | a, b)``.
        - ``bigram_transitions``: ``(A, A)``.
        - ``unigram_transitions``: ``(A - 1,)`` over tokens ``1..A-1``.

        Rows without any count are NaN, and the unigram distribution falls
        back to uniform.
    """
    # Character vocabulary, most frequent first: low indices get a single
    # alphabet token (see encode_unicode_char).
    all_chars, frequency = np.unique(
        [
            ord(char)
            for element in data[col].values.astype(str)
            for char in element
        ],
        return_counts=True,
    )
    chars_sorted = all_chars[np.argsort(frequency)[::-1]]
    # Characters that must be generatable even if the data lacks them.
    if char_list is not None:
        for char in char_list:
            if char not in chars_sorted:
                chars_sorted = np.append(chars_sorted, char)
    # Index 0 is reserved for the start-of-text token; code point 0 (NUL)
    # is assumed absent from real text.
    chars_sorted = np.insert(chars_sorted, 0, 0)

    n_chars = len(chars_sorted)
    alphabet_length = best_alphabet_length(n_chars)
    split = optimal_alphabet_split(
        n_total_chars=n_chars, alphabet_length=alphabet_length
    )
    # code point -> one token, or a (first, second) pair of tokens
    vocab = {
        value: encode_unicode_char(
            character_index=i,
            optimal_split=split,
            alphabet_length=alphabet_length,
        )
        for i, value in enumerate(chars_sorted)
    }

    def _get_index_trigram(int1: int, int2: int, int3: int) -> int:
        """Return the flat histogram index of trigram (int1, int2, int3)."""
        return (alphabet_length**2) * int1 + alphabet_length * int2 + int3

    def _apply_trigram(sentence: str) -> t.List[int]:
        """Encode ``sentence`` (truncated to max_training_trigram_length
        characters) into tokens, prefix two start tokens, and return the
        flat index of every trigram of consecutive tokens."""
        tokens: t.List[int] = [0, 0]
        for char in sentence[:max_training_trigram_length]:
            code = vocab[ord(char)]
            if isinstance(code, int):
                tokens.append(code)
            else:
                tokens.extend(code)
        return [
            _get_index_trigram(first, second, third)
            for first, second, third in zip(tokens, tokens[1:], tokens[2:])
        ]

    # Work on a copy: the caller's frame must keep its text column.
    data = data.copy()
    data[col] = data[col].astype(str).apply(_apply_trigram)
    # One row per trigram occurrence.
    # NOTE(review): an empty text yields an empty list, which `explode`
    # turns into a single NaN row — confirm private_histogram ignores it.
    data = data.explode(col)
    histograms = private_histogram(
        data=data,
        data_col=col,
        private_col=private_col,
        weight_col=weights_col,
        user_col=privacy_unit_col,
        noise=noise,
        max_multiplicity=np.sqrt(max_multiplicity * max_length),
        categories=list(range(alphabet_length**3)),
        noise_kind=NoiseKind.GAUSSIAN,
        random_generator=random_gen,
    )
    # Noisy counts below this threshold are treated as absent.
    threshold = 5 * noise * np.sqrt(max_multiplicity * max_length)
    # NOTE(review): assumes `histograms` preserves the order of categories.
    counts = np.array(list(histograms.values()))
    # Row a * A + b holds the counts of trigrams (a, b, c) for every c.
    transitions = counts.reshape(alphabet_length**2, alphabet_length)
    transitions = np.where(transitions < threshold, 0, transitions)
    transitions[:, 0] = 0  # forbid generating the start token again
    # Bigram (a, b) counts: marginalise trigram counts over c.
    bigram_counts = transitions.sum(axis=1).reshape(
        alphabet_length, alphabet_length
    )
    bigram_counts[:, 0] = 0
    # Unigram counts of a: marginalise bigram counts over b.
    unigram_count = bigram_counts.sum(axis=1)
    # Unseen contexts produce 0/0 = NaN rows on purpose: the generator
    # falls back to a lower-order model on NaN rows.
    with np.errstate(divide="ignore", invalid="ignore"):
        unigram_transitions = unigram_count[1:] / unigram_count[1:].sum()
        bigram_transitions = bigram_counts / bigram_counts.sum(
            axis=1, keepdims=True
        )
        trigram_transitions = transitions / transitions.sum(
            axis=1, keepdims=True
        )
    if np.any(np.isnan(unigram_transitions)):
        unigram_transitions = np.full(
            fill_value=1 / len(unigram_transitions),
            shape=len(unigram_transitions),
        )
    return (
        jnp.array(chars_sorted),
        jnp.array(trigram_transitions),
        jnp.array(bigram_transitions),
        jnp.array(unigram_transitions),
        split,
        alphabet_length,
    )
def best_alphabet_length(n_chars: int) -> int:
    """Choose the alphabet size M used to code ``n_chars`` characters.

    With N characters and an alphabet of M values, some values code a
    character on their own and the rest code half of a character (a pair of
    values codes one character). The worst case, where every value except
    the trigram start token codes a half character, needs
    ``N = (M - 1) ** 2 + 1``. The best case is ``M = N``. We want as many
    single-value characters as possible within a reasonable limit:

    - up to 256 characters, ``M = N`` (every character is a single value);
    - above that, ``M`` grows affinely so that ``f(256) = 256`` and
      ``f(149186) = 387`` (all unicode characters, near the worst case),
      rounded up.
    """
    if n_chars > 256:
        slope = (387 - 256) / (149186 - 256)
        return int(np.ceil(slope * n_chars + 256 * (1 - slope)))
    return n_chars
def optimal_alphabet_split(
    n_total_chars: int, alphabet_length: int = 100
) -> int:
    """Return ``k``, the number of alphabet values that code one character
    on their own.

    With N characters and an alphabet of M values, the values ``0..k-1``
    each code one character. Each of the ``M - k`` remaining values,
    followed by a second value in ``1..M-1`` (the second value never uses
    the trigram start token 0, hence ``M - 1``), codes ``M - 1``
    characters. Solving ``k + (M - k) * (M - 1) = N`` gives
    ``k = (M * (M - 1) - N) / (M - 2)``. Since k is an integer, the result
    is floored.

    For ``M == 2`` the formula divides by zero. Every character must then
    be coded directly, which is only possible when ``N <= 2``.

    Raises:
        ValueError: if ``alphabet_length == 2`` and ``n_total_chars > 2``.
    """
    if alphabet_length == 2:
        # Happens when the data holds a single distinct character (plus
        # the start token).
        if n_total_chars > alphabet_length:
            raise ValueError(
                "an alphabet of 2 values cannot code more than 2 characters"
            )
        return n_total_chars
    return int(
        np.floor(
            (alphabet_length * (alphabet_length - 1) - n_total_chars)
            / (alphabet_length - 2)
        )
    )
def encode_unicode_char(
    character_index: int, optimal_split: int, alphabet_length: int
) -> t.Union[int, t.Tuple[int, int]]:
    """Encode a character index as one alphabet value or a pair of values.

    Indices below ``optimal_split`` are coded by themselves. Any other index
    is coded by a pair ``(first, second)``:

    - ``first`` in ``optimal_split..alphabet_length - 1`` is the quotient
      of ``character_index - optimal_split`` by ``alphabet_length - 1``,
      shifted by ``optimal_split``;
    - ``second`` is the remainder plus one, so that it never takes the
      value 0, which is reserved for the start token of the trigrams.
    """
    offset = character_index - optimal_split
    if offset < 0:
        return character_index
    quotient, remainder = divmod(offset, alphabet_length - 1)
    return (quotient + optimal_split, remainder + 1)
# -------JAX METHODS FOR TRIGRAM GENERATION---------------
@jax_cleanup
def generate_jax(
    n_examples: int,
    max_length: int,
    seed: int,
    alphabet_length: int,
    trigram_transitions: jnp.ndarray,
    bigram_transitions: jnp.ndarray,
    unigram_transitions: jnp.ndarray,
    optimal_split: int,
    chars_sorted: jnp.ndarray,
    lengths: jnp.ndarray,
    samples_per_batch: int,
) -> np.ndarray:
    """Generate ``n_examples`` strings with the trigram model.

    Rows are generated ``samples_per_batch`` at a time with
    ``generate_array``, decoded into code points with ``decode_array``,
    and turned into numpy unicode strings.

    Args:
        n_examples: number of strings to return.
        max_length: alphabet tokens generated per row; also the width of
            the numpy unicode dtype used for the result.
        seed: seed of the JAX PRNG key; the key is split after each batch.
        alphabet_length, trigram_transitions, bigram_transitions,
            unigram_transitions: model forwarded to ``generate_array``.
        optimal_split, chars_sorted: decoding tables for ``decode_array``.
        lengths: target character count per row; must hold at least
            ``ceil(n_examples / samples_per_batch) * samples_per_batch``
            entries.
        samples_per_batch: rows per generation batch (must be positive).

    Returns:
        1-D array of ``n_examples`` strings (empty if ``n_examples <= 0``).
    """
    batches = []
    key = jax.random.key(seed)
    # range() raises on a zero step instead of looping forever.
    for start in range(0, n_examples, samples_per_batch):
        tokens = generate_array(
            n_examples=samples_per_batch,
            max_length=max_length,
            key=key,
            alphabet_length=alphabet_length,
            trigram_transitions=trigram_transitions,
            bigram_transitions=bigram_transitions,
            unigram_transitions=unigram_transitions,
        )
        code_points = decode_array(
            array=tokens,
            lengths=lengths[start : start + samples_per_batch],
            optimal_split=optimal_split,
            chars_sorted=chars_sorted,
            alphabet_length=alphabet_length,
        )
        # Reinterpret each row of ``max_length`` code points as one
        # fixed-width unicode string (numpy drops trailing zeros), then
        # flatten (batch, 1) -> (batch,), also when batch == 1.
        strings = (
            np.asarray(code_points)
            .astype(np.uint32)
            .view("U" + str(max_length))
            .astype(str)
            .reshape(-1)
        )
        batches.append(strings)
        key, _ = jax.random.split(key)
    if not batches:
        return np.array([], dtype=str)
    return np.concatenate(batches)[:n_examples]
@jax_dataclasses.dataclass
class GenerationState:
    """Loop state for generating one row of alphabet tokens.

    Registered with the project's ``jax_dataclasses`` helper, presumably so
    that instances can be carried through ``jax.lax.while_loop`` — confirm.
    """
    # Token buffer of the row being generated.
    gen_arr: jnp.ndarray
    # PRNG key used for the next draw.
    rng_key: jax.Array
    # Index of the next slot of ``gen_arr`` to fill.
    step: int
def generate_array(
    n_examples: int,
    max_length: int,
    key: jax.Array,
    alphabet_length: int,
    trigram_transitions: jnp.ndarray,
    bigram_transitions: jnp.ndarray,
    unigram_transitions: jnp.ndarray,
) -> jnp.ndarray:
    """Sample ``n_examples`` rows of ``max_length`` alphabet tokens.

    Every row starts from the context ``(0, 0)`` (two start tokens). Each
    next token is drawn from the trigram row of the two previous tokens.
    If that row contains NaN, it is drawn from the bigram row of the
    previous token instead. If that row contains NaN too, it is drawn from
    the unigram distribution over tokens ``1..alphabet_length - 1``.

    Args:
        n_examples: number of rows.
        max_length: tokens generated per row.
        key: JAX PRNG key, split into one key per row.
        alphabet_length: number of token values ``A``.
        trigram_transitions: ``(A * A, A)`` probabilities; row ``a * A + b``
            is used for the context ``(a, b)``.
        bigram_transitions: ``(A, A)`` probabilities; row ``a`` is used for
            the context ``a``.
        unigram_transitions: ``(A - 1,)`` probabilities of tokens
            ``1..A-1``.

    Returns:
        ``(n_examples, max_length)`` token array; ``uint8`` when the
        alphabet fits in a byte, ``int32`` otherwise.
    """
    # Tokens double as row indices. Use int32 arithmetic: with uint8 tokens,
    # ``A * previous + current`` silently wrapped modulo 256.
    vocab = jnp.arange(alphabet_length, dtype=jnp.int32)
    # Storage dtype of generated rows: one byte per token when possible.
    token_dtype = jnp.uint8 if alphabet_length <= 256 else jnp.int32

    def generate_next_value(state: GenerationState) -> GenerationState:
        """Draw the token at ``state.step`` and advance the state by one."""
        previous = state.gen_arr[state.step - 1].astype(jnp.int32)
        context = (
            alphabet_length * state.gen_arr[state.step - 2].astype(jnp.int32)
            + previous
        )

        def value_from_unigram(rng_key: jax.Array) -> jax.Array:
            return jax.random.choice(
                rng_key, a=vocab[1:], p=unigram_transitions
            )

        def value_from_bigram_row(rng_key: jax.Array) -> jax.Array:
            return jax.random.choice(
                rng_key, a=vocab, p=bigram_transitions[previous]
            )

        def value_from_bigram(rng_key: jax.Array) -> jax.Array:
            return jax.lax.cond(  # type:ignore[no-any-return]
                jnp.isnan(bigram_transitions[previous]).any(),
                value_from_unigram,
                value_from_bigram_row,
                rng_key,
            )

        def value_from_trigram(rng_key: jax.Array) -> jax.Array:
            return jax.random.choice(
                rng_key, a=vocab, p=trigram_transitions[context]
            )

        new_value = jax.lax.cond(
            jnp.isnan(trigram_transitions[context]).any(),
            value_from_bigram,
            value_from_trigram,
            state.rng_key,
        )
        next_key, _ = jax.random.split(state.rng_key)
        return GenerationState(
            gen_arr=state.gen_arr.at[state.step].set(
                new_value.astype(state.gen_arr.dtype)
            ),
            rng_key=next_key,
            step=state.step + 1,
        )

    def generate_example(max_length: int, key: jax.Array) -> jnp.ndarray:
        """Generate one row by repeating generate_next_value until the
        buffer (two start tokens + ``max_length`` tokens) is full."""
        init_state = GenerationState(
            gen_arr=jnp.zeros(shape=(max_length + 2,), dtype=token_dtype),
            step=2,
            rng_key=key,
        )

        def cond_fun(state: GenerationState) -> bool:
            return state.step < max_length + 2

        state = jax.lax.while_loop(
            body_fun=generate_next_value,
            cond_fun=cond_fun,
            init_val=init_state,
        )
        # Drop the two start tokens.
        return state.gen_arr[2:]

    def generate_scan(
        n_examples: int, max_length: int, key: jax.Array
    ) -> jnp.ndarray:
        """Generate ``n_examples`` rows by scanning generate_example over
        one key per row; scan was found much faster to jit than vmap."""

        def single_pass(
            carry: None, row_key: jax.Array
        ) -> t.Tuple[None, jnp.ndarray]:
            return carry, generate_example(max_length=max_length, key=row_key)

        keys = jax.random.split(key, num=n_examples)
        _, rows = jax.lax.scan(  # type:ignore
            single_pass, init=None, xs=keys, length=n_examples
        )
        return t.cast(jax.Array, rows)

    generate_scan_jit = jax.jit(generate_scan, static_argnums=(0, 1))
    value: jax.Array = generate_scan_jit(
        n_examples=n_examples, max_length=max_length, key=key
    )
    # Release the compiled function (project memory helper).
    jaxjit_cleanup(generate_scan_jit)
    return value
@jax_dataclasses.dataclass
class DecodeState:
    """Loop state for decoding one row of generated alphabet tokens into
    character code points.

    Registered with the project's ``jax_dataclasses`` helper, presumably so
    that instances can be carried through ``jax.lax.while_loop`` — confirm.
    """
    max_length: jax.Array  # int scalar: token budget of the row
    index_generated: jax.Array  # int scalar: next read position in generated
    index_decoded: jax.Array  # int scalar: next write position in decoded
    decoded: jnp.ndarray  # decoded code points, zero-padded
    generated: jnp.ndarray  # token row being decoded
def decode_array(
    array: jnp.ndarray,
    lengths: jnp.ndarray,
    chars_sorted: jnp.ndarray,
    optimal_split: int,
    alphabet_length: int,
) -> jnp.ndarray:
    """Decode generated token rows into unicode code points.

    This is the inverse of ``encode_unicode_char``. A token below
    ``optimal_split`` is directly a character index. A token ``first`` at
    or above it is followed by a token ``second``, and the pair codes the
    index
    ``optimal_split + (first - optimal_split) * (alphabet_length - 1)
    + second - 1``. The index is then mapped to its code point with
    ``chars_sorted``.

    Args:
        array: 2-D token array, one generated row per example.
        lengths: target character count for each row of ``array``.
        chars_sorted: code point of each character index.
        optimal_split: number of token values coding a character alone.
        alphabet_length: number of token values.

    Returns:
        Array shaped like ``array`` with the dtype of ``chars_sorted``.
        Row ``i`` holds at most ``lengths[i]`` code points followed by
        zeros. Decoding also stops after ``2 * lengths[i]`` tokens.
    """
    chars_sorted = jnp.asarray(chars_sorted)
    # Carry: (token_budget, read_pos, write_pos, decoded, tokens)
    Carry = t.Tuple[jax.Array, jax.Array, jax.Array, jax.Array, jax.Array]

    def decode_row(row: jnp.ndarray, token_budget: jnp.ndarray) -> jnp.ndarray:
        """Decode one row, reading at most ``token_budget`` tokens and
        writing at most ``token_budget // 2`` characters."""

        def keep_going(carry: Carry) -> jax.Array:
            budget, read_pos, write_pos, _, _ = carry
            return jnp.logical_and(budget > read_pos, budget // 2 > write_pos)

        def single_token(
            carry: Carry,
        ) -> t.Tuple[jax.Array, jax.Array, jax.Array]:
            _, read_pos, write_pos, decoded, tokens = carry
            char_index = tokens[read_pos].astype(jnp.int32)
            return (
                read_pos + 1,
                write_pos + 1,
                decoded.at[write_pos].set(chars_sorted[char_index]),
            )

        def token_pair(
            carry: Carry,
        ) -> t.Tuple[jax.Array, jax.Array, jax.Array]:
            _, read_pos, write_pos, decoded, tokens = carry
            # int32 so that the products below cannot wrap around in uint8.
            first = tokens[read_pos].astype(jnp.int32)
            second = tokens[read_pos + 1].astype(jnp.int32)
            char_index = (
                optimal_split
                + (first - optimal_split) * (alphabet_length - 1)
                + second
                - 1
            )
            return (
                read_pos + 2,
                write_pos + 1,
                decoded.at[write_pos].set(chars_sorted[char_index]),
            )

        def step(carry: Carry) -> Carry:
            budget, read_pos, _, _, tokens = carry
            # Compare in int32: optimal_split may not fit in the token dtype.
            read_pos, write_pos, decoded = jax.lax.cond(
                tokens[read_pos].astype(jnp.int32) < optimal_split,
                single_token,
                token_pair,
                carry,
            )
            return budget, read_pos, write_pos, decoded, tokens

        init: Carry = (
            token_budget,
            jnp.array(0, dtype=jnp.int32),
            jnp.array(0, dtype=jnp.int32),
            # Code points need the width of chars_sorted, not of the tokens.
            jnp.zeros(row.shape, dtype=chars_sorted.dtype),
            row,
        )
        return jax.lax.while_loop(keep_going, step, init)[3]

    return jax.vmap(decode_row, in_axes=(0, 0), out_axes=0)(
        array, 2 * lengths
    )