Why Gemfury? Push, build, and install: RubyGems, npm packages, Python packages, Maven artifacts, PHP packages, Go Modules, Bower components, Debian packages, RPM packages, NuGet packages

alkaline-ml / pandas   python

Repository URL to install this package:

Version: 1.1.1 

/ tests / groupby / test_sample.py

import pytest

from pandas import DataFrame, Index, Series
import pandas._testing as tm


@pytest.mark.parametrize("n, frac", [(2, None), (None, 0.2)])
def test_groupby_sample_balanced_groups_shape(n, frac):
    """Check the per-group sample size on two equally sized groups.

    Both groups hold 10 rows, and either parametrization (``n=2`` or
    ``frac=0.2``) is expected to return two rows per group. The expected
    objects borrow the index of the result, so only the values (and thus
    the shape) are compared.
    """
    population = [1] * 10 + [2] * 10
    frame = DataFrame({"a": population, "b": population})
    sampled = [1] * 2 + [2] * 2

    # Sampling through the DataFrame groupby.
    frame_result = frame.groupby("a").sample(n=n, frac=frac)
    frame_expected = DataFrame({"a": sampled, "b": sampled}, index=frame_result.index)
    tm.assert_frame_equal(frame_result, frame_expected)

    # Same check after selecting column "b" from the groupby.
    series_result = frame.groupby("a")["b"].sample(n=n, frac=frac)
    series_expected = Series(sampled, name="b", index=series_result.index)
    tm.assert_series_equal(series_result, series_expected)


def test_groupby_sample_unbalanced_groups_shape():
    """Check that ``n`` is applied per group when group sizes differ.

    The groups hold 10 and 20 rows; with ``n=5`` each group is expected to
    contribute exactly five rows. The expected objects borrow the index of
    the result, so only the values are compared.
    """
    population = [1] * 10 + [2] * 20
    frame = DataFrame({"a": population, "b": population})
    sampled = [1] * 5 + [2] * 5

    # Sampling through the DataFrame groupby.
    frame_result = frame.groupby("a").sample(n=5)
    frame_expected = DataFrame({"a": sampled, "b": sampled}, index=frame_result.index)
    tm.assert_frame_equal(frame_result, frame_expected)

    # Same check after selecting column "b" from the groupby.
    series_result = frame.groupby("a")["b"].sample(n=5)
    series_expected = Series(sampled, name="b", index=series_result.index)
    tm.assert_series_equal(series_result, series_expected)


def test_groupby_sample_index_value_spans_groups():
    """Check sampling when the same index label appears in both groups.

    The index label ``2`` is shared by rows of group 1 and group 2. With
    ``n=2`` each group is still expected to contribute two rows of its own
    values. The expected objects borrow the index of the result, so only
    the values are compared.
    """
    population = [1] * 3 + [2] * 3
    frame = DataFrame({"a": population, "b": population}, index=[1, 2, 2, 2, 2, 2])
    sampled = [1] * 2 + [2] * 2

    # Sampling through the DataFrame groupby.
    frame_result = frame.groupby("a").sample(n=2)
    frame_expected = DataFrame({"a": sampled, "b": sampled}, index=frame_result.index)
    tm.assert_frame_equal(frame_result, frame_expected)

    # Same check after selecting column "b" from the groupby.
    series_result = frame.groupby("a")["b"].sample(n=2)
    series_expected = Series(sampled, name="b", index=series_result.index)
    tm.assert_series_equal(series_result, series_expected)


def test_groupby_sample_n_and_frac_raises():
    """Check that passing both ``n`` and ``frac`` raises ``ValueError``.

    The same error message is expected from the DataFrame groupby and from
    the groupby restricted to column "b".
    """
    frame = DataFrame({"a": [1, 2], "b": [1, 2]})
    expected_message = "Please enter a value for `frac` OR `n`, not both"

    # First the frame-level groupby, then the column "b" selection.
    for grouped in (frame.groupby("a"), frame.groupby("a")["b"]):
        with pytest.raises(ValueError, match=expected_message):
            grouped.sample(n=1, frac=1.0)


def test_groupby_sample_frac_gt_one_without_replacement_raises():
    """Check that ``frac > 1`` with ``replace=False`` raises ``ValueError``.

    The same error message is expected from the DataFrame groupby and from
    the groupby restricted to column "b".
    """
    frame = DataFrame({"a": [1, 2], "b": [1, 2]})
    expected_message = (
        "Replace has to be set to `True` when upsampling the population `frac` > 1."
    )

    # First the frame-level groupby, then the column "b" selection.
    for grouped in (frame.groupby("a"), frame.groupby("a")["b"]):
        with pytest.raises(ValueError, match=expected_message):
            grouped.sample(frac=1.5, replace=False)


@pytest.mark.parametrize("n", [-1, 1.5])
def test_groupby_sample_invalid_n_raises(n):
    df = DataFrame({"a": [1, 2], "b": [1, 2]})

    if n < 0:
        msg = "Please provide positive value"
    else:
        msg = "Only integers accepted as `n` values"

    with pytest.raises(ValueError, match=msg):
        df.groupby("a").sample(n=n)

    with pytest.raises(ValueError, match=msg):
        df.groupby("a")["b"].sample(n=n)


def test_groupby_sample_oversample():
    """Check upsampling with ``frac=2.0`` and ``replace=True``.

    Each group holds 10 rows and is expected to yield 20 sampled rows. The
    expected objects borrow the index of the result, so only the values are
    compared.
    """
    population = [1] * 10 + [2] * 10
    frame = DataFrame({"a": population, "b": population})
    sampled = [1] * 20 + [2] * 20

    # Sampling through the DataFrame groupby.
    frame_result = frame.groupby("a").sample(frac=2.0, replace=True)
    frame_expected = DataFrame({"a": sampled, "b": sampled}, index=frame_result.index)
    tm.assert_frame_equal(frame_result, frame_expected)

    # Same check after selecting column "b" from the groupby.
    series_result = frame.groupby("a")["b"].sample(frac=2.0, replace=True)
    series_expected = Series(sampled, name="b", index=series_result.index)
    tm.assert_series_equal(series_result, series_expected)


def test_groupby_sample_without_n_or_frac():
    """Check sampling when both ``n`` and ``frac`` are ``None``.

    With neither argument set, one row per group is expected. The expected
    objects borrow the index of the result, so only the values are
    compared.
    """
    population = [1] * 10 + [2] * 10
    frame = DataFrame({"a": population, "b": population})
    one_per_group = [1, 2]

    # Sampling through the DataFrame groupby.
    frame_result = frame.groupby("a").sample(n=None, frac=None)
    frame_expected = DataFrame(
        {"a": one_per_group, "b": one_per_group}, index=frame_result.index
    )
    tm.assert_frame_equal(frame_result, frame_expected)

    # Same check after selecting column "b" from the groupby.
    series_result = frame.groupby("a")["b"].sample(n=None, frac=None)
    series_expected = Series(one_per_group, name="b", index=series_result.index)
    tm.assert_series_equal(series_result, series_expected)


def test_groupby_sample_with_weights():
    """Check that ``weights`` steer which rows are drawn in each group.

    Rows "x" and "z" carry weight 0, so with ``n=2`` and ``replace=True``
    each group is expected to return its only positively weighted row
    twice: "w" for group 1 and "y" for group 2.
    """
    group_keys = [1] * 2 + [2] * 2
    frame = DataFrame(
        {"a": group_keys, "b": group_keys}, index=Index(["w", "x", "y", "z"])
    )
    row_weights = [1, 0, 1, 0]

    # Sampling through the DataFrame groupby.
    frame_result = frame.groupby("a").sample(n=2, replace=True, weights=row_weights)
    frame_expected = DataFrame(
        {"a": group_keys, "b": group_keys}, index=Index(["w", "w", "y", "y"])
    )
    tm.assert_frame_equal(frame_result, frame_expected)

    # Same check after selecting column "b" from the groupby.
    series_result = frame.groupby("a")["b"].sample(
        n=2, replace=True, weights=row_weights
    )
    series_expected = Series(group_keys, name="b", index=Index(["w", "w", "y", "y"]))
    tm.assert_series_equal(series_result, series_expected)