Why Gemfury? Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

alkaline-ml / pandas   python

Repository URL to install this package:

Version: 1.1.1 

/ tests / window / test_dtypes.py

from itertools import product

import numpy as np
import pytest

from pandas import DataFrame, Series
import pandas._testing as tm
from pandas.core.base import DataError

# gh-12373 : rolling functions error on float32 data
# make sure rolling functions work for different dtypes
#
# NOTE that these are yielded tests and so _create_data
# is explicitly called.
#
# further note that we are only checking rolling for full dtype
# compliance (though both expanding and ewm inherit)


class Dtype:
    """Base class for the rolling-window dtype checks.

    Subclasses provide a ``dtype`` class attribute (this base class defines
    none, and ``_create_data`` reads ``self.dtype``).  ``test_dtypes`` builds
    input data of that dtype and compares every aggregation in ``funcs``
    against the float64 expectations returned by ``get_expects``.
    """

    # window size passed to ``rolling`` in every check
    window = 2

    # aggregation name -> callable applied to the rolling object
    funcs = {
        "count": lambda v: v.count(),
        "max": lambda v: v.max(),
        "min": lambda v: v.min(),
        "sum": lambda v: v.sum(),
        "mean": lambda v: v.mean(),
        "std": lambda v: v.std(),
        "var": lambda v: v.var(),
        "median": lambda v: v.median(),
    }

    def get_expects(self):
        """Return the expected results.

        The mapping is keyed first by data name (``"sr1"``, ``"sr2"``,
        ``"sr3"``, ``"df"``) and then by aggregation name (the keys of
        ``funcs``).  Every expected object is built with ``dtype="float64"``.
        """
        expects = {
            "sr1": {
                "count": Series([1, 2, 2, 2, 2], dtype="float64"),
                "max": Series([np.nan, 1, 2, 3, 4], dtype="float64"),
                "min": Series([np.nan, 0, 1, 2, 3], dtype="float64"),
                "sum": Series([np.nan, 1, 3, 5, 7], dtype="float64"),
                "mean": Series([np.nan, 0.5, 1.5, 2.5, 3.5], dtype="float64"),
                "std": Series([np.nan] + [np.sqrt(0.5)] * 4, dtype="float64"),
                "var": Series([np.nan, 0.5, 0.5, 0.5, 0.5], dtype="float64"),
                "median": Series([np.nan, 0.5, 1.5, 2.5, 3.5], dtype="float64"),
            },
            "sr2": {
                "count": Series([1, 2, 2, 2, 2], dtype="float64"),
                "max": Series([np.nan, 10, 8, 6, 4], dtype="float64"),
                "min": Series([np.nan, 8, 6, 4, 2], dtype="float64"),
                "sum": Series([np.nan, 18, 14, 10, 6], dtype="float64"),
                "mean": Series([np.nan, 9, 7, 5, 3], dtype="float64"),
                "std": Series([np.nan] + [np.sqrt(2)] * 4, dtype="float64"),
                "var": Series([np.nan, 2, 2, 2, 2], dtype="float64"),
                "median": Series([np.nan, 9, 7, 5, 3], dtype="float64"),
            },
            "sr3": {
                "count": Series([1, 2, 2, 1, 1], dtype="float64"),
                "max": Series([np.nan, 1, 2, np.nan, np.nan], dtype="float64"),
                "min": Series([np.nan, 0, 1, np.nan, np.nan], dtype="float64"),
                "sum": Series([np.nan, 1, 3, np.nan, np.nan], dtype="float64"),
                "mean": Series([np.nan, 0.5, 1.5, np.nan, np.nan], dtype="float64"),
                "std": Series(
                    [np.nan] + [np.sqrt(0.5)] * 2 + [np.nan] * 2, dtype="float64"
                ),
                "var": Series([np.nan, 0.5, 0.5, np.nan, np.nan], dtype="float64"),
                "median": Series([np.nan, 0.5, 1.5, np.nan, np.nan], dtype="float64"),
            },
            "df": {
                "count": DataFrame(
                    {0: Series([1, 2, 2, 2, 2]), 1: Series([1, 2, 2, 2, 2])},
                    dtype="float64",
                ),
                "max": DataFrame(
                    {0: Series([np.nan, 2, 4, 6, 8]), 1: Series([np.nan, 3, 5, 7, 9])},
                    dtype="float64",
                ),
                "min": DataFrame(
                    {0: Series([np.nan, 0, 2, 4, 6]), 1: Series([np.nan, 1, 3, 5, 7])},
                    dtype="float64",
                ),
                "sum": DataFrame(
                    {
                        0: Series([np.nan, 2, 6, 10, 14]),
                        1: Series([np.nan, 4, 8, 12, 16]),
                    },
                    dtype="float64",
                ),
                "mean": DataFrame(
                    {0: Series([np.nan, 1, 3, 5, 7]), 1: Series([np.nan, 2, 4, 6, 8])},
                    dtype="float64",
                ),
                "std": DataFrame(
                    {
                        0: Series([np.nan] + [np.sqrt(2)] * 4),
                        1: Series([np.nan] + [np.sqrt(2)] * 4),
                    },
                    dtype="float64",
                ),
                "var": DataFrame(
                    {0: Series([np.nan, 2, 2, 2, 2]), 1: Series([np.nan, 2, 2, 2, 2])},
                    dtype="float64",
                ),
                "median": DataFrame(
                    {0: Series([np.nan, 1, 3, 5, 7]), 1: Series([np.nan, 2, 4, 6, 8])},
                    dtype="float64",
                ),
            },
        }
        return expects

    def _create_dtype_data(self, dtype):
        """Build the input objects for ``dtype``.

        Returns a dict with three Series (``"sr1"``, ``"sr2"``, ``"sr3"``)
        and one 5x2 DataFrame (``"df"``), all created with the given dtype.
        """
        sr1 = Series(np.arange(5), dtype=dtype)
        sr2 = Series(np.arange(10, 0, -2), dtype=dtype)
        # sr3 is a copy of sr1 whose element labelled 3 is replaced by NaN.
        # Spelled ``np.nan`` like the rest of the file; the ``np.NaN`` alias
        # was removed in NumPy 2.0.
        sr3 = sr1.copy()
        sr3[3] = np.nan
        df = DataFrame(np.arange(10).reshape((5, 2)), dtype=dtype)

        data = {"sr1": sr1, "sr2": sr2, "sr3": sr3, "df": df}

        return data

    def _create_data(self):
        """Populate ``self.data`` and ``self.expects`` for ``self.dtype``."""
        self.data = self._create_dtype_data(self.dtype)
        self.expects = self.get_expects()

    def test_dtypes(self):
        """Check every (aggregation, data object) combination."""
        self._create_data()
        # Iterating a dict yields its keys, so no ``.keys()`` call is needed.
        for f_name, d_name in product(self.funcs, self.data):
            f = self.funcs[f_name]
            d = self.data[d_name]
            exp = self.expects[d_name][f_name]
            self.check_dtypes(f, f_name, d, d_name, exp)

    def check_dtypes(self, f, f_name, d, d_name, exp):
        """Apply ``f`` to ``d.rolling(window=self.window)`` and compare to ``exp``.

        ``f_name`` and ``d_name`` are not used here; subclasses that override
        this method may branch on them.
        """
        roll = d.rolling(window=self.window)
        result = f(roll)
        tm.assert_almost_equal(result, exp)


class TestDtype_object(Dtype):
    """Run the inherited rolling dtype checks on ``object`` input data."""
    dtype = object


class Dtype_integer(Dtype):
    """Common base for the signed-integer dtype test classes.

    Adds nothing to ``Dtype``; it only groups the integer subclasses.
    """


class TestDtype_int8(Dtype_integer):
    """Run the inherited rolling dtype checks on ``np.int8`` input data."""
    dtype = np.int8


class TestDtype_int16(Dtype_integer):
    """Run the inherited rolling dtype checks on ``np.int16`` input data."""
    dtype = np.int16


class TestDtype_int32(Dtype_integer):
    """Run the inherited rolling dtype checks on ``np.int32`` input data."""
    dtype = np.int32


class TestDtype_int64(Dtype_integer):
    """Run the inherited rolling dtype checks on ``np.int64`` input data."""
    dtype = np.int64


class Dtype_uinteger(Dtype):
    """Common base for the unsigned-integer dtype test classes.

    Adds nothing to ``Dtype``; it only groups the unsigned subclasses.
    """


class TestDtype_uint8(Dtype_uinteger):
    """Run the inherited rolling dtype checks on ``np.uint8`` input data."""
    dtype = np.uint8


class TestDtype_uint16(Dtype_uinteger):
    """Run the inherited rolling dtype checks on ``np.uint16`` input data."""
    dtype = np.uint16


class TestDtype_uint32(Dtype_uinteger):
    """Run the inherited rolling dtype checks on ``np.uint32`` input data."""
    dtype = np.uint32


class TestDtype_uint64(Dtype_uinteger):
    """Run the inherited rolling dtype checks on ``np.uint64`` input data."""
    dtype = np.uint64


class Dtype_float(Dtype):
    """Common base for the floating-point dtype test classes.

    Adds nothing to ``Dtype``; it only groups the float subclasses.
    """


class TestDtype_float16(Dtype_float):
    """Run the inherited rolling dtype checks on ``np.float16`` input data."""
    dtype = np.float16


class TestDtype_float32(Dtype_float):
    """Run the inherited rolling dtype checks on ``np.float32`` input data."""
    dtype = np.float32


class TestDtype_float64(Dtype_float):
    """Run the inherited rolling dtype checks on ``np.float64`` input data."""
    dtype = np.float64


class TestDtype_category(Dtype):
    """Run the inherited rolling dtype checks on categorical Series.

    Only ``"sr1"`` and ``"sr2"`` are produced, so the NaN-bearing Series and
    the DataFrame cases from ``Dtype`` are not exercised for this dtype.
    """

    dtype = "category"
    # NOTE(review): not read anywhere in the visible code -- confirm whether
    # anything outside this file still relies on it.
    include_df = False

    def _create_dtype_data(self, dtype):
        """Return the two Series inputs, each built from a ``range``."""
        return {
            "sr1": Series(range(5), dtype=dtype),
            "sr2": Series(range(10, 0, -2), dtype=dtype),
        }


class DatetimeLike(Dtype):
    """Base for datetime-like dtypes, where only ``count`` must succeed."""

    def check_dtypes(self, f, f_name, d, d_name, exp):
        """Check ``f`` on the rolling view of ``d``.

        Every aggregation other than ``"count"`` must raise ``DataError``
        with the message "No numeric types to aggregate"; ``"count"`` is
        compared against ``exp``.
        """
        rolling_obj = d.rolling(window=self.window)

        if f_name != "count":
            # non-count aggregations are expected to be rejected
            with pytest.raises(DataError, match="No numeric types to aggregate"):
                f(rolling_obj)
            return

        tm.assert_almost_equal(f(rolling_obj), exp)


class TestDtype_timedelta(DatetimeLike):
    """Run the ``DatetimeLike`` checks on ``m8[ns]`` (timedelta64) data."""
    dtype = np.dtype("m8[ns]")


class TestDtype_datetime(DatetimeLike):
    """Run the ``DatetimeLike`` checks on ``M8[ns]`` (datetime64) data."""
    dtype = np.dtype("M8[ns]")


class TestDtype_datetime64UTC(DatetimeLike):
    """Placeholder for tz-aware datetime data; currently always skipped."""

    dtype = "datetime64[ns, UTC]"

    def _create_data(self):
        """Skip the test before any data is built."""
        reason = (
            "direct creation of extension dtype "
            "datetime64[ns, UTC] is not supported ATM"
        )
        pytest.skip(reason)