Why Gemfury? Push, build, and install RubyGems, npm packages, Python packages, Maven artifacts, PHP packages, Go Modules, Debian packages, RPM packages, and NuGet packages.

Repository URL to install this package:

Details    
Size: Mime:
import typing as t

import numpy as np


def gen_from_cumulative(
    probabilities: t.List[float],
    quantile_values: t.Union[t.List[float], t.List[int]],
    size: int,
    random_gen: np.random.Generator,
) -> np.ndarray:
    """Draw ``size`` samples from a piecewise-linear cumulative distribution.

    The distribution is given by matching knots
    ``(probabilities[i], quantile_values[i])``: ``probabilities`` are CDF
    values and ``quantile_values`` the corresponding quantiles. Samples are
    produced by inverse-transform sampling: draw ``u ~ U[0, 1)`` and linearly
    interpolate the quantile between the two knots that bracket ``u``.
    Draws below the first knot, or at/above the last knot, are linearly
    extrapolated from the first or last segment respectively.

    Args:
        probabilities: Non-decreasing CDF values, one per knot.
        quantile_values: Quantile value at each knot; same length as
            ``probabilities``.
        size: Number of samples to draw.
        random_gen: Source of randomness; exactly ``size`` uniform variates
            are consumed (as one ``uniform(size=(size, 1))`` call).

    Returns:
        Float array of shape ``(size, 1)`` holding the sampled values.

    Raises:
        ValueError: If fewer than two knots are given, the two sequences
            have different lengths, or ``probabilities`` is not
            non-decreasing.
    """
    prob = np.asarray(probabilities, dtype=float)
    quant = np.asarray(quantile_values)
    if prob.ndim != 1 or quant.shape != prob.shape:
        raise ValueError(
            "probabilities and quantile_values must be 1-D sequences "
            "of the same length"
        )
    if prob.size < 2:
        raise ValueError("at least two (probability, quantile) knots are required")
    if np.any(np.diff(prob) < 0):
        raise ValueError("probabilities must be non-decreasing")

    random_cumulative = random_gen.uniform(size=(size, 1))

    # Index of the LAST knot whose probability is <= u. side="right" skips
    # past runs of equal probabilities, so the bracketing segment has a
    # strictly positive width; picking the first knot of such a run would
    # give dp == 0 and a NaN sample.
    lower_idx = np.searchsorted(prob, random_cumulative[:, 0], side="right") - 1
    # Keep [lower, lower + 1] a valid segment: draws below the first knot
    # (index -1) or at/above the last knot (index K-1) are extrapolated
    # from the edge segments instead of indexing out of bounds.
    lower_idx = np.clip(lower_idx, 0, prob.size - 2)
    upper_idx = lower_idx + 1

    # Gather the bracketing knots directly; shape (size, 1) to match u.
    lower_prob = prob[lower_idx][:, None]
    upper_prob = prob[upper_idx][:, None]
    lower_quant = quant[lower_idx][:, None]
    upper_quant = quant[upper_idx][:, None]

    dp = upper_prob - lower_prob
    dq = upper_quant - lower_quant
    return dq / dp * (random_cumulative - lower_prob) + lower_quant  # type:ignore[no-any-return]