Why Gemfury? Push, build, and install RubyGems, npm packages, Python packages, Maven artifacts, PHP packages, Go Modules, Debian packages, RPM packages, and NuGet packages.

Repository URL to install this package:

Details    
sarus_statistics / tests / unit / test_ops / test_quantiles_local.py
Size: Mime:
import numpy as np
import pandas as pd
import pytest

from sarus_statistics.ops.adaptive_quantiles.local import feature_quantiles

# Seed NumPy's global RNG once, at import time (module-level side effect).
np.random.seed(0)

# Column names used by the test DataFrame.
PU_COLUMN = "sarus_protected_entity"
# NOTE(review): this constant is named after the `private_col` argument of
# feature_quantiles, but the column itself is called "sarus_is_public";
# presumably True marks public rows — confirm against the op.
PRIVATE_COLUMN = "sarus_is_public"
WEIGHTS = "sarus_weights"
COL_TO_EVAL = "age"

# Arguments forwarded to feature_quantiles.
NOISE = 1e-5  # tiny noise so results stay close to the exact values
BOUNDS = (0.0, 10.0)
N_QUANTILES = 1


@pytest.fixture()
def data():
    """Return a three-row frame: one row per user, all flagged False.

    Users "A", "B", "C" carry weights 4.0, 2.0, 1.0 and ages 0.0, 3.0,
    10.0 respectively; the PRIVATE_COLUMN flag is False on every row.
    """
    columns = {
        PU_COLUMN: ["A", "B", "C"],
        WEIGHTS: [4.0, 2.0, 1.0],
        COL_TO_EVAL: [0.0, 3.0, 10.0],
        PRIVATE_COLUMN: [False] * 3,
    }
    return pd.DataFrame(columns)


@pytest.mark.parametrize(
    "max_multiplicity,expected_proba",
    [
        # The middle expected value equals the share of total weight held by
        # the rows with age <= 3.0 (users A and B in the `data` fixture) once
        # every weight is capped at max_multiplicity:
        #   m=1 -> weights (1, 1, 1) -> 2/3
        #   m=2 -> weights (2, 2, 1) -> 4/5
        #   m=4 -> weights (4, 2, 1) -> 6/7
        # (presumably feature_quantiles caps weights this way — confirm).
        (1, [0.0, 2 / 3, 1.0]),
        (2, [0.0, 0.8, 1.0]),
        (4, [0.0, 6 / 7, 1.0]),
    ],
)
def test_quantiles(data, expected_proba, max_multiplicity):
    """Check feature_quantiles against noise-free expectations.

    With NOISE this small, the probabilities returned as the keys of the
    result should match the exact values within a 1% relative tolerance.
    """
    # The module-level seed runs only once, at import, so without this each
    # parametrized case would see RNG state left over from earlier tests.
    # Reseeding keeps every case reproducible on its own (only relevant if
    # feature_quantiles draws from NumPy's global RNG).
    np.random.seed(0)

    quantiles = feature_quantiles(
        data=data,
        noise=NOISE,
        sampling_ratio=None,
        max_multiplicity=max_multiplicity,
        user_col=PU_COLUMN,
        private_col=PRIVATE_COLUMN,
        data_col=COL_TO_EVAL,
        nb_quantiles=N_QUANTILES,
        bounds=BOUNDS,
        weight_col=WEIGHTS,
    )
    # Unlike `assert np.allclose(...)`, assert_allclose rejects a length
    # mismatch instead of broadcasting, and reports the offending values on
    # failure. Tolerances match np.allclose(rtol=0.01) with its default atol.
    np.testing.assert_allclose(
        list(quantiles.keys()), expected_proba, rtol=0.01, atol=1e-8
    )