Why Gemfury? Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Debian packages RPM packages NuGet packages

Repository URL to install this package:

Details    
Size: Mime:
import os
import typing as t

import jax
import jax.numpy as jnp
import numpy as np
import pandas as pd
from sarus_statistics.ops.histograms import NoiseKind, private_histogram

import sarus_synthetic_data.independent_generator.jax_dataclasses as jax_dataclasses  # noqa : E501
from sarus_synthetic_data.shared.memory_utils import (
    jax_cleanup,
    jaxjit_cleanup,
)

# Alphabet size needed to code every unicode character (the upper end of
# the interpolation in ``best_alphabet_length``).
ALPHABET_LENGTH = 387

# Environment variable capping how many characters of each text are used
# when counting trigrams in ``train_trigram_text`` (default: 1000).
MAX_TRIGRAM_TRAINING_LENGTH = "max_training_trigram_length"
max_training_trigram_length = int(
    os.getenv(MAX_TRIGRAM_TRAINING_LENGTH, "1000")
)


class TextGenerator:
    """Text generator based on trigrams.

    Wraps a trigram model fitted by ``train_trigram_text`` (sorted code
    points, unigram/bigram/trigram transition tables, alphabet length and
    split) plus a discrete distribution of text lengths, and samples new
    texts from them.
    """

    def __init__(
        self,
        sample_batch_size: int,
        lengths: t.List[int],
        probabilities: t.List[float],
        chars_sorted: jnp.ndarray,
        unigram_transitions: jnp.ndarray,
        bigram_transitions: jnp.ndarray,
        trigram_transitions: jnp.ndarray,
        alphabet_length: int,
        split: int,
        random_gen: np.random.Generator,
    ) -> None:
        """
        Args:
            sample_batch_size: number of texts generated per JAX batch.
            lengths: text lengths on which the length distribution is
                defined.
            probabilities: one value per entry of ``lengths``;
                ``sample_lengths`` treats it as a non-decreasing, CDF-like
                sequence (assumed - TODO confirm with the caller).
            chars_sorted: code points ordered by character index (index 0
                is the start token).
            unigram_transitions, bigram_transitions, trigram_transitions:
                transition tables returned by ``train_trigram_text``.
            alphabet_length: number of values of the coding alphabet.
            split: first alphabet value that starts a two-value code.
            random_gen: numpy generator used for lengths and JAX seeds.
        """
        self.chars_sorted = chars_sorted
        self.unigram_transitions = unigram_transitions
        self.bigram_transitions = bigram_transitions
        self.trigram_transitions = trigram_transitions
        self.lengths = np.array(lengths)
        self.probabilities = np.array(probabilities)
        self.random_gen = random_gen
        self.alphabet_length = alphabet_length
        self.split = split
        self.sample_batch_size = sample_batch_size

    def sample(self, size: int) -> np.ndarray:
        """Generate ``size`` texts.

        Lengths are drawn for a whole number of batches, then
        ``generate_jax`` samples and decodes the texts batch by batch. The
        JAX seed is drawn from ``self.random_gen``: with a constant seed,
        every call would regenerate the same character sequences and only
        their truncation lengths would differ.

        Returns:
            1-D array of ``size`` strings (empty when ``size <= 0``).
        """
        if size <= 0:
            return np.array([], dtype=str)
        n_batches = int(np.ceil(size / self.sample_batch_size))
        lengths = self.sample_lengths(size=self.sample_batch_size * n_batches)
        # A character may be coded by two alphabet values, hence the factor 2.
        max_length = int(2 * lengths.max())
        seed = int(self.random_gen.integers(0, 2**31 - 1))
        return generate_jax(
            n_examples=size,
            max_length=max_length,
            samples_per_batch=self.sample_batch_size,
            seed=seed,
            alphabet_length=self.alphabet_length,
            chars_sorted=self.chars_sorted,
            trigram_transitions=self.trigram_transitions,
            bigram_transitions=self.bigram_transitions,
            unigram_transitions=self.unigram_transitions,
            optimal_split=self.split,
            lengths=jnp.array(lengths),
        )

    def sample_lengths(self, size: int) -> np.ndarray:
        """Draw ``size`` text lengths.

        For each draw ``u`` uniform in [0, 1), take the last index ``i`` with
        ``probabilities[i] <= u`` (index 0 if there is none) and return an
        integer uniformly in ``[lengths[i], lengths[i + 1]]`` (inclusive).
        When ``i`` is the last index, ``lengths[i]`` is returned instead of
        reading past the end of ``lengths``.

        Returns:
            1-D integer array of shape ``(size,)``.
        """
        draws = self.random_gen.uniform(size=(size, 1))
        gaps = self.probabilities[None, :] - draws
        # Mask probabilities strictly above the draw; argmax then selects the
        # largest remaining one (all -inf -> index 0).
        lower = np.argmax(np.where(gaps > 0, -np.inf, gaps), axis=1)
        upper = np.minimum(lower + 1, len(self.lengths) - 1)
        return self.random_gen.integers(
            low=self.lengths[lower], high=self.lengths[upper] + 1
        )


def train_trigram_text(
    noise: float,
    data: pd.DataFrame,
    col: str,
    privacy_unit_col: str,
    weights_col: str,
    private_col: str,
    random_gen: np.random.Generator,
    max_multiplicity: float = 1.0,
    max_length: float = 1.0,
    char_list: t.Optional[t.List[int]] = None,
) -> t.Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, int, int]:
    """Fit a trigram model on the text column ``col`` of ``data``.

    Characters are ranked by frequency and coded on an alphabet of
    ``alphabet_length`` values (one value for frequent characters, two for
    the others, see ``encode_unicode_char``). Trigram counts over that
    alphabet come from ``private_histogram`` with Gaussian noise. Counts
    below ``5 * noise * sqrt(max_multiplicity * max_length)`` are zeroed,
    then bigram and unigram counts are derived by summing.

    Args:
        noise: noise level given to ``private_histogram``; also scales the
            count threshold.
        data: input frame (with the privacy unit / weight / private columns
            used by ``private_histogram``). It is not modified.
        col: name of the text column.
        privacy_unit_col: forwarded to ``private_histogram`` as ``user_col``.
        weights_col: forwarded to ``private_histogram`` as ``weight_col``.
        private_col: forwarded to ``private_histogram`` as ``private_col``.
        random_gen: generator forwarded to ``private_histogram``.
        max_multiplicity, max_length: ``sqrt(max_multiplicity * max_length)``
            is passed as ``max_multiplicity`` to ``private_histogram`` and
            scales the threshold.
        char_list: code points to include in the alphabet even when they do
            not appear in ``col``.

    Returns:
        ``(chars_sorted, trigram_transitions, bigram_transitions,
        unigram_transitions, split, alphabet_length)``. Transition rows are
        normalized to sum to 1; rows without remaining counts contain NaN,
        which ``generate_array`` uses to fall back to a lower-order model.
    """
    all_chars, frequency = np.unique(
        [
            ord(char)
            for element in data[col].values.astype(str)
            for char in element
        ],
        return_counts=True,
    )
    # Most frequent first, so that frequent characters get one-value codes.
    chars_sorted = all_chars[np.argsort(frequency)[::-1]]

    # Add characters of char_list that do not appear in the text. A set keeps
    # this linear even when char_list covers the whole unicode range.
    if char_list is not None:
        known = set(chars_sorted.tolist())
        missing = []
        for char in char_list:
            if char not in known:
                known.add(char)
                missing.append(char)
        if missing:
            chars_sorted = np.append(chars_sorted, missing)

    # Code point 0 is the empty start token of the trigrams; it does not
    # code any unicode character, so it cannot clash with real text.
    chars_sorted = np.insert(chars_sorted, 0, 0)

    alphabet_length = best_alphabet_length(len(chars_sorted))
    split = optimal_alphabet_split(
        n_total_chars=len(chars_sorted), alphabet_length=alphabet_length
    )
    # code point -> alphabet value (int) or pair of values (tuple)
    vocab = {
        value: encode_unicode_char(
            character_index=i,
            optimal_split=split,
            alphabet_length=alphabet_length,
        )
        for i, value in enumerate(chars_sorted)
    }

    def _get_index_trigram(
        int1: int,
        int2: int,
        int3: int,
    ) -> int:
        """Return the histogram category of the trigram (int1, int2, int3)."""
        return (alphabet_length**2) * int1 + alphabet_length * int2 + int3

    def _apply_trigram(sentence: str) -> t.List[int]:
        """Transform a sentence into the list of trigrams it contains.

        The (truncated) sentence is first coded with ``vocab``, prefixed with
        two start tokens, then every window of three values is mapped to its
        trigram category.
        """
        chars: t.List[int] = [0, 0]  # two start tokens
        for el in sentence[:max_training_trigram_length]:
            val = vocab[ord(el)]
            if isinstance(val, int):
                chars.append(val)
            else:
                chars.extend(val)
        return [
            _get_index_trigram(el1, el2, el3)
            for el1, el2, el3 in zip(chars[:-2], chars[1:-1], chars[2:])
        ]

    # ``assign`` returns a new frame: writing into ``data[col]`` would
    # overwrite the caller's text column with lists of trigram indices.
    data = data.assign(**{col: data[col].astype(str).apply(_apply_trigram)})
    data = data.explode(col)
    contribution_bound = np.sqrt(max_multiplicity * max_length)
    histograms = private_histogram(
        data=data,
        data_col=col,
        private_col=private_col,
        weight_col=weights_col,
        user_col=privacy_unit_col,
        noise=noise,
        max_multiplicity=contribution_bound,
        categories=list(range(alphabet_length**3)),
        noise_kind=NoiseKind.GAUSSIAN,
        random_generator=random_gen,
    )
    threshold = 5 * noise * contribution_bound
    counts = np.array(list(histograms.values()))
    # Row ``alphabet_length * a + b`` holds the counts of the trigrams (a, b, *).
    transitions = counts.reshape(alphabet_length**2, alphabet_length)
    transitions = np.where(transitions < threshold, 0, transitions)
    transitions[:, 0] = 0  # forbid to restart
    # now computes bigrams from counts
    bigram_transitions = transitions.sum(axis=1).reshape(
        alphabet_length, alphabet_length
    )
    bigram_transitions[:, 0] = 0

    # compute unigram from bigram (start token excluded)
    unigram_count = bigram_transitions.sum(axis=1)
    unigram_transitions = unigram_count[1:] / unigram_count[1:].sum()
    if np.any(np.isnan(unigram_transitions)):
        unigram_transitions = np.full(
            fill_value=1 / len(unigram_transitions),
            shape=len(unigram_transitions),
        )
    bigram_transitions = bigram_transitions / bigram_transitions.sum(
        axis=1, keepdims=True
    )
    trigram_transitions = transitions / transitions.sum(axis=1, keepdims=True)

    return (
        jnp.array(chars_sorted),
        jnp.array(trigram_transitions),
        jnp.array(bigram_transitions),
        jnp.array(unigram_transitions),
        split,
        alphabet_length,
    )


def best_alphabet_length(n_chars: int) -> int:
    """Choose the alphabet size M used to code ``n_chars`` distinct characters.

    With M alphabet values, k of them code one character each and the other
    M - k start a two-value code. The best case is M = N (every character has
    its own value); the worst case is every value except the start token
    coding a half character, i.e. N = (M - 1) ** 2 + 1.

    Up to 256 characters, M = N. Beyond that, M follows the affine function
    through (256, 256) and (149186, 387), where 149186 stands for the whole
    unicode range, rounded up to an integer.
    """
    if n_chars > 256:
        slope = (387 - 256) / (149186 - 256)
        intercept = 256 * (1 - slope)
        return int(np.ceil(slope * n_chars + intercept))
    return n_chars


def optimal_alphabet_split(
    n_total_chars: int, alphabet_length: int = 100
) -> int:
    """Return how many alphabet values code a single character.

    Given N different characters and an alphabet of length M, we want k such
    that k values of the alphabet code one character and the M - k others
    code a half character. So we solve k + (M - k) * (M - 1) = N. M - 1 is
    used because the value 0 of the alphabet is reserved for the beginning
    of the trigram. As k is an integer, the floor of the result is taken.

    The closed form divides by M - 2. For M = 2, the only non-start value is
    1, so there are no two-value codes: every character gets one value and
    at most 2 characters fit.

    Raises:
        ValueError: if ``alphabet_length == 2`` and ``n_total_chars > 2``.
    """
    if alphabet_length == 2:
        if n_total_chars > alphabet_length:
            raise ValueError(
                "An alphabet of length 2 can code at most 2 characters, "
                f"got {n_total_chars}."
            )
        return alphabet_length
    return int(
        np.floor(
            (alphabet_length * (alphabet_length - 1) - n_total_chars)
            / (alphabet_length - 2)
        )
    )


def encode_unicode_char(
    character_index: int, optimal_split: int, alphabet_length: int
) -> t.Union[int, t.Tuple[int, int]]:
    """Code a character index on the alphabet.

    Indices below ``optimal_split`` are coded by themselves. Any other index
    is coded by a pair of values:
    - a first value in ``[optimal_split, alphabet_length)``,
    - a second value in ``[1, alphabet_length)``; 0 is skipped because it is
      the start token of the trigrams.
    """
    offset = character_index - optimal_split
    if offset >= 0:
        block, position = divmod(offset, alphabet_length - 1)
        return (block + optimal_split, position + 1)
    return character_index


# -------JAX METHODS FOR TRIGRAM GENERATION---------------


@jax_cleanup
def generate_jax(
    n_examples: int,
    max_length: int,
    seed: int,
    alphabet_length: int,
    trigram_transitions: jnp.ndarray,
    bigram_transitions: jnp.ndarray,
    unigram_transitions: jnp.ndarray,
    optimal_split: int,
    chars_sorted: jnp.ndarray,
    lengths: jnp.ndarray,
    samples_per_batch: int,
) -> np.ndarray:
    """Generate ``n_examples`` texts, ``samples_per_batch`` at a time.

    For each batch, ``generate_array`` samples ``max_length`` alphabet values
    per example, ``decode_array`` turns them into code points (limited by the
    batch's slice of ``lengths``), and each row of code points is
    reinterpreted as one numpy unicode string.

    Args:
        lengths: target number of characters per example; one slice of
            ``samples_per_batch`` entries is consumed per batch.

    Returns:
        1-D array of ``n_examples`` strings (empty if ``n_examples <= 0``).
    """
    sampled = []
    key = jax.random.key(seed)
    for start in range(0, n_examples, samples_per_batch):
        generated_array = generate_array(
            n_examples=samples_per_batch,
            max_length=max_length,
            key=key,
            alphabet_length=alphabet_length,
            trigram_transitions=trigram_transitions,
            bigram_transitions=bigram_transitions,
            unigram_transitions=unigram_transitions,
        )

        decoded_array = decode_array(
            array=generated_array,
            lengths=lengths[start : start + samples_per_batch],
            optimal_split=optimal_split,
            chars_sorted=chars_sorted,
            alphabet_length=alphabet_length,
        )
        # Each row of ``max_length`` uint32 code points is viewed as a single
        # ``U{max_length}`` string; numpy treats trailing zero code points as
        # padding, which truncates the texts. ``reshape(-1)`` turns the
        # (batch, 1) result into (batch,), unlike ``squeeze`` which returns
        # a 0-d array (rejected by ``np.concatenate``) when the batch is 1.
        strings = (
            np.asarray(decoded_array)
            .astype(np.uint32)
            .view("U" + str(max_length))
            .astype(str)
            .reshape(-1)
        )
        sampled.append(strings)
        key, _ = jax.random.split(key)
    if not sampled:
        return np.array([], dtype=str)
    return np.concatenate(sampled)[:n_examples]


@jax_dataclasses.dataclass
class GenerationState:
    """Loop state of the per-example ``jax.lax.while_loop`` in
    ``generate_array``."""

    # Sequence being generated; ``generate_array`` starts it with two zero
    # start tokens.
    gen_arr: jnp.ndarray
    # PRNG key for the next draw; ``generate_array`` splits it after each step.
    rng_key: jax.Array
    # Position in ``gen_arr`` of the next value to write (starts at 2).
    step: int


def generate_array(
    n_examples: int,
    max_length: int,
    key: jax.Array,
    alphabet_length: int,
    trigram_transitions: jnp.ndarray,
    bigram_transitions: jnp.ndarray,
    unigram_transitions: jnp.ndarray,
) -> jnp.ndarray:
    """Sample ``n_examples`` sequences of ``max_length`` alphabet values.

    Each sequence starts from two start tokens (value 0). The next value is
    drawn from the trigram row of the two previous values. When that row
    contains NaN (e.g. an all-zero row normalized in ``train_trigram_text``),
    the bigram row of the previous value is used instead. When that row also
    contains NaN, the unigram distribution over the non-start values is used.

    Returns:
        int32 array of shape ``(n_examples, max_length)`` without the start
        tokens.
    """
    # int32 rather than uint8: ``best_alphabet_length`` can return more than
    # 256 values, and trigram rows go up to ``alphabet_length**2 - 1``. With
    # uint8 values both wrap around (JAX keeps uint8 when multiplying by a
    # Python int), which selects wrong transition rows.
    vocab = jnp.arange(alphabet_length, dtype=jnp.int32)

    def _get_index_bigram(int1: jax.Array, int2: jax.Array) -> jax.Array:
        """Row of ``trigram_transitions`` for the previous pair (int1, int2)."""
        return alphabet_length * int1 + int2

    def generate_next_value(state: GenerationState) -> GenerationState:
        """``while_loop`` body: draw the value at ``state.step``, write it,
        advance the step and the PRNG key."""
        trigram_row = _get_index_bigram(
            state.gen_arr[state.step - 2], state.gen_arr[state.step - 1]
        )

        def value_from_bigram(state: GenerationState) -> jnp.ndarray:
            previous = state.gen_arr[state.step - 1]
            return jax.lax.cond(  # type:ignore[no-any-return]
                jnp.isnan(bigram_transitions[previous]).any(),
                lambda x: jax.random.choice(
                    x, a=vocab[1:], p=unigram_transitions
                ),
                lambda x: jax.random.choice(
                    x,
                    a=vocab,
                    p=bigram_transitions[previous],
                ),
                state.rng_key,
            )

        def value_from_trigram(state: GenerationState) -> jnp.ndarray:
            return jax.random.choice(
                key=state.rng_key, a=vocab, p=trigram_transitions[trigram_row]
            )

        new_value = jax.lax.cond(
            jnp.isnan(trigram_transitions[trigram_row]).any(),
            value_from_bigram,
            value_from_trigram,
            state,
        )
        next_key, _ = jax.random.split(state.rng_key)
        return GenerationState(
            gen_arr=state.gen_arr.at[state.step].set(new_value),
            rng_key=next_key,
            step=state.step + 1,
        )

    def generate_example(max_length: int, key: jax.Array) -> jnp.ndarray:
        """Generate one sequence by running ``generate_next_value`` in a
        ``jax.lax.while_loop`` until ``max_length`` values are written."""
        init_state = GenerationState(
            gen_arr=jnp.zeros(shape=(max_length + 2,), dtype=jnp.int32),
            step=2,
            rng_key=key,
        )

        def cond_fun(state: GenerationState) -> bool:
            return state.step < max_length + 2

        state = jax.lax.while_loop(
            body_fun=generate_next_value,
            cond_fun=cond_fun,
            init_val=init_state,
        )
        return state.gen_arr[2:]

    def generate_scan(
        n_examples: int, max_length: int, key: jax.Array
    ) -> jnp.ndarray:
        """Generate many sequences with ``jax.lax.scan`` over split keys.

        Per the original authors, this jit-compiles much faster than a
        ``jax.vmap`` of ``generate_example``.
        """

        def single_pass(
            carry: None, x: jax.Array
        ) -> t.Tuple[None, jnp.ndarray]:
            return carry, generate_example(max_length=max_length, key=x)

        keys = jax.random.split(key, num=n_examples)
        _, out = jax.lax.scan(  # type:ignore
            single_pass, init=None, xs=keys, length=n_examples
        )
        return t.cast(jax.Array, out)

    generate_scan_jit = jax.jit(generate_scan, static_argnums=(0, 1))
    value: jax.Array = generate_scan_jit(
        n_examples=n_examples, max_length=max_length, key=key
    )
    jaxjit_cleanup(generate_scan_jit)
    return value


@jax_dataclasses.dataclass
class DecodeState:
    """Loop state of the per-row ``jax.lax.while_loop`` in ``decode_array``."""

    # int: number of generated values to scan; ``decode_array`` passes twice
    # the target length, and ``max_length // 2`` characters are decoded.
    max_length: jax.Array  # int
    # int: position of the next alphabet value to read in ``generated``.
    index_generated: jax.Array  # int
    # int: position of the next code point to write in ``decoded``.
    index_decoded: jax.Array  # int
    # Output buffer of code points (zeros past the decoded characters).
    decoded: jnp.ndarray
    # Input row of generated alphabet values.
    generated: jnp.ndarray


def decode_array(
    array: jnp.ndarray,
    lengths: jnp.ndarray,
    chars_sorted: jnp.ndarray,
    optimal_split: int,
    alphabet_length: int,
) -> jnp.ndarray:
    """Decode rows of generated alphabet values into unicode code points.

    This is the inverse of ``encode_unicode_char``, applied row by row. A
    value below ``optimal_split`` is directly a character index into
    ``chars_sorted``. A value at or above it is the first half of a
    two-value code and is combined with the next value.

    Args:
        array: generated alphabet values, one sequence per row.
        lengths: number of characters to decode for each row.
        chars_sorted: code point of each character index.
        optimal_split: first alphabet value that starts a two-value code.
        alphabet_length: number of values of the coding alphabet.

    Returns:
        int32 array with the same shape as ``array``; positions past the
        decoded characters of a row are 0.
    """

    def replace_unicode_values(
        array: jnp.ndarray,
        max_length: jnp.ndarray,  # single value
    ) -> jnp.ndarray:
        """Scan one row and replace each value / pair of values by a code
        point. Stops once ``max_length`` values have been read or
        ``max_length // 2`` characters have been written."""

        def stop_condition(state: DecodeState) -> jax.Array:
            return jnp.logical_and(
                state.index_generated < state.max_length,
                state.index_decoded < state.max_length // 2,
            )

        def body_condition(state: DecodeState) -> DecodeState:
            # Cast before comparing: the generated values may be uint8, and
            # optimal_split can exceed 255.
            current_val = state.generated[state.index_generated].astype(
                jnp.int32
            )

            def smaller_optimal_split(
                state: DecodeState,
            ) -> t.Tuple[jax.Array, jax.Array, jax.Array]:
                """One-value code: the value is the character index."""
                return (
                    state.index_generated + 1,
                    state.index_decoded + 1,
                    state.decoded.at[state.index_decoded].set(
                        chars_sorted[state.generated[state.index_generated]]
                    ),
                )

            def larger_optimal_split(
                state: DecodeState,
            ) -> t.Tuple[jax.Array, jax.Array, jax.Array]:
                """Two-value code: invert ``encode_unicode_char``, i.e.
                index = split + (first - split) * (M - 1) + (second - 1)."""
                first = state.generated[state.index_generated].astype(
                    jnp.int32
                )
                second = state.generated[state.index_generated + 1].astype(
                    jnp.int32
                )
                char_index = (
                    (first - optimal_split) * (alphabet_length - 1)
                    + optimal_split
                    + (second - 1)
                )
                return (
                    state.index_generated + 2,
                    state.index_decoded + 1,
                    state.decoded.at[state.index_decoded].set(
                        chars_sorted[char_index]
                    ),
                )

            (
                index_generated,
                index_decoded,
                decoded,
            ) = jax.lax.cond(
                current_val < optimal_split,
                smaller_optimal_split,
                larger_optimal_split,
                state,
            )
            return DecodeState(
                max_length=state.max_length,
                index_generated=index_generated,
                index_decoded=index_decoded,
                decoded=decoded,
                generated=state.generated,
            )

        init_state = DecodeState(
            max_length=max_length,
            index_decoded=jnp.array(0),
            index_generated=jnp.array(0),
            # int32 so that code points above 255 are stored intact whatever
            # the dtype of the generated values.
            decoded=jnp.zeros(array.shape, dtype=jnp.int32),
            generated=array,
        )
        final_state = jax.lax.while_loop(
            stop_condition,
            body_fun=body_condition,
            init_val=init_state,
        )
        return final_state.decoded

    return jax.vmap(replace_unicode_values, in_axes=(0, 0), out_axes=0)(
        array, 2 * lengths
    )