Why Gemfury? Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Debian packages RPM packages NuGet packages

Repository URL to install this package:

Details    
sarus_statistics / sarus_statistics / ops / text_marginal.py
Size: Mime:
import warnings
from typing import Optional, cast

import numpy as np
import pandas as pd

from sarus_statistics.ops.adaptive_quantiles.local import feature_quantiles

try:
    from transformers import GPT2TokenizerFast
except ModuleNotFoundError:
    warnings.warn(
        "transformers Module not found, max_length operations not available "
    )

# Token budget for ``max_length``: texts are truncated to this many GPT-2
# tokens, it is the upper bound given to ``feature_quantiles`` and the
# returned length never exceeds it.
MAX_LENGTH: int = 300
# ``max_length`` reports the value of the smallest computed quantile level
# that is strictly greater than this probability.
MIN_PROB: float = 0.9


# pylint: disable=too-many-arguments
def max_length(
    data: pd.DataFrame,
    noise: float,
    col_to_eval: str,
    user_col: str,
    private_col: str,
    weight_col: str,
    max_multiplicity: float,
    sampling_ratio: float,
    nb_quantiles: int = 20,
    random_generator: Optional[np.random.Generator] = None,
) -> int:
    """Compute the max length for sequences to bound sensitivity.

    A fraction ``sampling_ratio`` of the rows of ``data`` is sampled, each
    text of ``col_to_eval`` is tokenized with the GPT-2 tokenizer (truncated
    to ``MAX_LENGTH`` tokens), and quantiles of the resulting token counts
    are estimated with ``feature_quantiles`` (which receives ``noise``).
    The value at the smallest quantile level strictly above ``MIN_PROB`` is
    returned.

    Args:
        data: input table. ``col_to_eval`` is read here; ``user_col``,
            ``private_col`` and ``weight_col`` are forwarded to
            ``feature_quantiles``.
        noise: forwarded to ``feature_quantiles``.
        col_to_eval: name of the text column whose token lengths are measured.
        user_col: forwarded to ``feature_quantiles``.
        private_col: forwarded to ``feature_quantiles``.
        weight_col: forwarded to ``feature_quantiles``.
        max_multiplicity: forwarded to ``feature_quantiles``.
        sampling_ratio: fraction of rows sampled before tokenization.
        nb_quantiles: number of quantiles requested from ``feature_quantiles``.
        random_generator: randomness source used both for the row sampling
            and by ``feature_quantiles``; a fresh default generator is
            created when ``None``.

    Returns:
        A token length clamped to ``[1, MAX_LENGTH]``.

    Raises:
        ModuleNotFoundError: if the optional ``transformers`` dependency is
            not installed.
        ValueError: if ``feature_quantiles`` returns no quantile level
            strictly above ``MIN_PROB``.
    """
    # The module-level import only warns when transformers is missing and
    # leaves the name undefined; surface that as an explicit error instead
    # of a NameError.
    try:
        tokenizer_cls = GPT2TokenizerFast
    except NameError as err:
        raise ModuleNotFoundError(
            "transformers module not found: max_length is unavailable"
        ) from err

    if random_generator is None:
        random_generator = np.random.default_rng()

    tokenizer = tokenizer_cls.from_pretrained("gpt2")

    # Sample with the same generator that feature_quantiles uses, so a
    # seeded generator makes the whole computation reproducible.
    sampled = cast(
        pd.DataFrame,
        data.sample(frac=sampling_ratio, random_state=random_generator),
    )

    # Number of tokens per text after truncation to MAX_LENGTH. No padding is
    # needed: the length of each truncated ``input_ids`` list is exactly the
    # number of positions a padded attention mask would mark as 1.
    encoded = tokenizer(
        sampled[col_to_eval].tolist(),
        truncation=True,
        max_length=MAX_LENGTH,
    )
    counts = np.array([len(ids) for ids in encoded["input_ids"]], dtype=np.int64)
    sampled = sampled.assign(counts=counts)

    quantiles = feature_quantiles(
        sampled,
        data_col="counts",
        user_col=user_col,
        private_col=private_col,
        weight_col=weight_col,
        noise=noise,
        # Subsampling was already applied above; presumably this avoids a
        # second subsampling inside feature_quantiles — confirm.
        sampling_ratio=None,
        nb_quantiles=nb_quantiles,
        bounds=(1, MAX_LENGTH),
        max_multiplicity=max_multiplicity,
        random_generator=random_generator,
    )

    # ``quantiles`` maps quantile level -> value; pick the smallest level
    # strictly above MIN_PROB.
    levels_above = [level for level in quantiles.keys() if level > MIN_PROB]
    if not levels_above:
        raise ValueError(
            f"feature_quantiles returned no quantile level above {MIN_PROB} "
            f"(nb_quantiles={nb_quantiles})"
        )
    prob_value = min(levels_above)

    # Keep the result inside the bounds handed to feature_quantiles.
    return max(1, min(MAX_LENGTH, int(quantiles[prob_value])))