Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

agriconnect / pandas   python

Repository URL to install this package:

Version: 0.24.2 

/ tests / extension / decimal / test_decimal.py

import decimal
import math
import operator

import numpy as np
import pytest

import pandas as pd
from pandas import compat
from pandas.tests.extension import base
import pandas.util.testing as tm

from .array import DecimalArray, DecimalDtype, make_data, to_decimal


@pytest.fixture
def dtype():
    """Provide a freshly constructed ``DecimalDtype`` with default arguments."""
    decimal_dtype = DecimalDtype()
    return decimal_dtype


@pytest.fixture
def data():
    """Provide a ``DecimalArray`` wrapping whatever ``make_data()`` returns."""
    values = make_data()
    return DecimalArray(values)


@pytest.fixture
def data_missing():
    """Provide a two-element ``DecimalArray``: a NaN followed by 1."""
    missing = decimal.Decimal('NaN')
    valid = decimal.Decimal(1)
    return DecimalArray([missing, valid])


@pytest.fixture
def data_for_sorting():
    """Provide a ``DecimalArray`` holding 1, 2, 0 in that order."""
    values = [decimal.Decimal(text) for text in ('1', '2', '0')]
    return DecimalArray(values)


@pytest.fixture
def data_missing_for_sorting():
    """Provide a ``DecimalArray`` holding 1, NaN, 0 in that order."""
    values = [decimal.Decimal(text) for text in ('1', 'NaN', '0')]
    return DecimalArray(values)


@pytest.fixture
def na_cmp():
    """Provide a two-argument predicate that is true when both are NaN.

    The predicate calls ``is_nan()`` on each argument, short-circuiting on
    the first one.
    """
    def both_nan(x, y):
        return x.is_nan() and y.is_nan()

    return both_nan


@pytest.fixture
def na_value():
    """Provide the missing-value scalar used by these tests: Decimal NaN."""
    missing = decimal.Decimal("NaN")
    return missing


@pytest.fixture
def data_for_grouping():
    """Provide an eight-element ``DecimalArray`` laid out as
    ``[b, b, na, na, a, a, b, c]`` with a=0.0, b=1.0, c=2.0 and na=NaN.
    """
    a, b, c, na = (decimal.Decimal(text)
                   for text in ('0.0', '1.0', '2.0', 'NaN'))
    layout = [b, b, na, na, a, a, b, c]
    return DecimalArray(layout)


class BaseDecimal(object):
    """Mixin providing Series/DataFrame equality assertions that first
    compare the missing-value masks and then the non-missing values.
    """

    def assert_series_equal(self, left, right, *args, **kwargs):
        """Assert that two Series have equal NA masks and equal non-NA values.

        Extra positional and keyword arguments are forwarded to
        ``tm.assert_series_equal`` for the comparison of the non-NA values.
        """
        def is_nan_scalar(value):
            # Object-dtype Series do not recognise Decimal('NaN') as NA, so
            # each element is probed with math.isnan; anything math.isnan
            # rejects with a TypeError is treated as "not missing".
            try:
                return math.isnan(value)
            except TypeError:
                return False

        def missing_mask(series):
            if series.dtype == 'object':
                return series.apply(is_nan_scalar)
            return series.isna()

        left_na = missing_mask(left)
        right_na = missing_mask(right)

        # The masks must agree before the remaining values are compared.
        tm.assert_series_equal(left_na, right_na)

        left_present = left[~left_na]
        right_present = right[~right_na]
        return tm.assert_series_equal(left_present, right_present,
                                      *args, **kwargs)

    def assert_frame_equal(self, left, right, *args, **kwargs):
        """Assert that two DataFrames are equal, comparing columns one by
        one through ``self.assert_series_equal``.
        """
        # TODO(EA): select_dtypes
        column_options = dict(
            exact=kwargs.get('check_column_type', 'equiv'),
            check_names=kwargs.get('check_names', True),
            check_exact=kwargs.get('check_exact', False),
            check_categorical=kwargs.get('check_categorical', True),
            obj='{obj}.columns'.format(obj=kwargs.get('obj', 'DataFrame')),
        )
        tm.assert_index_equal(left.columns, right.columns, **column_options)

        # NOTE(review): ``.index`` of the boolean result is the full set of
        # column labels, not only those where the comparison is True, so
        # every column takes the per-series path below and the final frame
        # comparison sees frames with no columns left. Confirm whether that
        # is intended before narrowing the selection.
        is_decimal = left.dtypes == 'decimal'
        decimals = is_decimal.index

        for label in decimals:
            self.assert_series_equal(left[label], right[label],
                                     *args, **kwargs)

        remaining_left = left.drop(columns=decimals)
        remaining_right = right.drop(columns=decimals)
        tm.assert_frame_equal(remaining_left, remaining_right,
                              *args, **kwargs)


class TestDtype(BaseDecimal, base.BaseDtypeTests):
    """Dtype test suite run against the decimal extension dtype."""

    @pytest.mark.skipif(compat.PY2, reason="Context not hashable.")
    def test_hashable(self, dtype):
        """Skipped on Python 2; the body itself performs no check.

        NOTE(review): if ``base.BaseDtypeTests`` defines ``test_hashable``,
        this empty override also disables it on Python 3 - confirm that is
        intended.
        """


class TestInterface(BaseDecimal, base.BaseInterfaceTests):
    """Interface test suite for the decimal array; skipped on Python 2."""

    pytestmark = pytest.mark.skipif(
        compat.PY2, reason="Unhashble dtype in Py2.")


class TestConstructors(BaseDecimal, base.BaseConstructorsTests):
    """Constructor test suite for the decimal array."""

    @pytest.mark.skip(reason="not implemented constructor from dtype")
    def test_from_dtype(self, data):
        """Always skipped: would construct from our dtype & string dtype."""


class TestReshaping(BaseDecimal, base.BaseReshapingTests):
    """Reshaping test suite for the decimal array; skipped on Python 2."""

    pytestmark = pytest.mark.skipif(
        compat.PY2, reason="Unhashble dtype in Py2.")


class TestGetitem(BaseDecimal, base.BaseGetitemTests):
    """Getitem test suite for the decimal array, plus a decimal-specific
    ``take`` check.
    """

    def test_take_na_value_other_decimal(self):
        """``take`` with ``allow_fill=True`` puts the supplied decimal
        ``fill_value`` at the position indexed by -1.
        """
        one = decimal.Decimal('1.0')
        two = decimal.Decimal('2.0')
        filler = decimal.Decimal('-1.0')

        source = DecimalArray([one, two])
        taken = source.take([0, -1], allow_fill=True, fill_value=filler)

        wanted = DecimalArray([decimal.Decimal('1.0'),
                               decimal.Decimal('-1.0')])
        self.assert_extension_array_equal(taken, wanted)


class TestMissing(BaseDecimal, base.BaseMissingTests):
    """Missing-data test suite for the decimal array; adds nothing of its
    own to the two base classes.
    """


class Reduce(object):
    """Mixin supplying ``check_reduce`` for the reduction test classes."""

    def check_reduce(self, s, op_name, skipna):
        """Check the reduction ``op_name`` on ``s``.

        With a truthy ``skipna``, or for 'median', 'skew' and 'kurt', the
        call must raise ``NotImplementedError``. Otherwise the result is
        compared with the same-named method of ``np.asarray(s)``.
        """
        supported = not (skipna or op_name in ['median', 'skew', 'kurt'])

        if supported:
            actual = getattr(s, op_name)(skipna=skipna)
            reference = getattr(np.asarray(s), op_name)()
            tm.assert_almost_equal(actual, reference)
            return

        with pytest.raises(NotImplementedError):
            getattr(s, op_name)(skipna=skipna)


class TestNumericReduce(Reduce, base.BaseNumericReduceTests):
    """Numeric reduction test suite using ``Reduce.check_reduce``."""


class TestBooleanReduce(Reduce, base.BaseBooleanReduceTests):
    """Boolean reduction test suite using ``Reduce.check_reduce``."""


class TestMethods(BaseDecimal, base.BaseMethodsTests):
    """Methods test suite for the decimal array."""

    @pytest.mark.parametrize('dropna', [True, False])
    @pytest.mark.xfail(reason="value_counts not implemented yet.")
    def test_value_counts(self, all_data, dropna):
        """Compare ``value_counts`` of the first ten elements against the
        counts of a reference built from the same elements (expected to
        fail, see the xfail marker).
        """
        head = all_data[:10]
        # With dropna the reference is an ndarray of the non-missing values;
        # otherwise it is the sliced data itself.
        reference = np.array(head[~head.isna()]) if dropna else head

        counted = pd.Series(head).value_counts(dropna=dropna)
        result = counted.sort_index()

        reference_counted = pd.Series(reference).value_counts(dropna=dropna)
        expected = reference_counted.sort_index()

        tm.assert_series_equal(result, expected)


class TestCasting(BaseDecimal, base.BaseCastingTests):
    """Casting test suite for the decimal array; skipped on Python 2."""

    pytestmark = pytest.mark.skipif(
        compat.PY2, reason="Unhashble dtype in Py2.")


class TestGroupby(BaseDecimal, base.BaseGroupbyTests):
    """Groupby test suite for the decimal array; skipped on Python 2."""

    pytestmark = pytest.mark.skipif(
        compat.PY2, reason="Unhashble dtype in Py2.")


class TestSetitem(BaseDecimal, base.BaseSetitemTests):
    """Setitem test suite for the decimal array; adds nothing of its own
    to the two base classes.
    """


class TestPrinting(BaseDecimal, base.BasePrintingTests):
    """Printing test suite for the decimal array; skipped on Python 2."""

    pytestmark = pytest.mark.skipif(
        compat.PY2, reason="Unhashble dtype in Py2.")


# TODO(extension)
@pytest.mark.xfail(
    reason=("raising AssertionError as this is not implemented, "
            "though easy enough to do"))
def test_series_constructor_coerce_data_to_extension_dtype_raises():
    """Building a Series from plain ints with the decimal dtype should raise
    a ``ValueError`` carrying the message below (currently expected to fail).
    """
    expected_message = (
        "Cannot cast data to extension dtype 'decimal'. "
        "Pass the extension array directly."
    )
    plain_ints = [0, 1, 2]
    with pytest.raises(ValueError, match=expected_message):
        pd.Series(plain_ints, dtype=DecimalDtype())


def test_series_constructor_with_dtype():
    """The Series constructor honours ``dtype`` when given a DecimalArray:
    the decimal dtype leaves the data as is, 'int64' converts it to ints.
    """
    arr = DecimalArray([decimal.Decimal('10.0')])

    # Explicit decimal dtype: same as passing the array alone.
    with_decimal_dtype = pd.Series(arr, dtype=DecimalDtype())
    tm.assert_series_equal(with_decimal_dtype, pd.Series(arr))

    # Explicit int64 dtype: the single value becomes the integer 10.
    with_int_dtype = pd.Series(arr, dtype='int64')
    tm.assert_series_equal(with_int_dtype, pd.Series([10]))


def test_dataframe_constructor_with_dtype():
    """The DataFrame constructor honours ``dtype`` when a column is a
    DecimalArray: the decimal dtype leaves the data as is, 'int64' converts
    it to ints.
    """
    decimal_column = DecimalArray([decimal.Decimal('10.0')])
    with_decimal_dtype = pd.DataFrame({"A": decimal_column},
                                      dtype=DecimalDtype())
    tm.assert_frame_equal(with_decimal_dtype,
                          pd.DataFrame({"A": decimal_column}))

    # A fresh array is built for the second half, as in the first.
    fresh_column = DecimalArray([decimal.Decimal('10.0')])
    with_int_dtype = pd.DataFrame({"A": fresh_column}, dtype='int64')
    tm.assert_frame_equal(with_int_dtype, pd.DataFrame({"A": [10]}))


@pytest.mark.parametrize("frame", [True, False])
def test_astype_dispatches(frame):
    """``astype`` with a decimal dtype carrying a custom context yields a
    result whose dtype context has that precision.

    This is a dtype-specific test that ensures Series[decimal].astype gets
    all the way through to ExtensionArray.astype; a reliable smoke test
    that works for arbitrary data types is difficult to design. Runs for
    both a Series and a single-column DataFrame.
    """
    custom_context = decimal.Context(prec=5)
    target_dtype = DecimalDtype(custom_context)

    series = pd.Series(DecimalArray([decimal.Decimal(2)]), name='a')
    subject = series.to_frame() if frame else series

    converted = subject.astype(target_dtype)
    result = converted['a'] if frame else converted

    assert result.dtype.context.prec == custom_context.prec


class TestArithmeticOps(BaseDecimal, base.BaseArithmeticOpsTests):
    """Arithmetic-operator test suite for the decimal array."""

    def check_opname(self, s, op_name, other, exc=None):
        """Delegate to the base check, always with ``exc=None``.

        The ``exc`` argument is accepted for signature compatibility but
        deliberately ignored.
        """
        super(TestArithmeticOps, self).check_opname(s, op_name,
                                                    other, exc=None)

    def test_arith_series_with_array(self, data, all_arithmetic_operators):
        """Check each arithmetic operator between a decimal Series and an
        int Series, twice the Series (except for mod ops), and the scalars
        0 and 5.

        The ``DivisionByZero`` and ``InvalidOperation`` traps of the current
        decimal context are switched off for the duration of the checks and
        restored afterwards, even when a check fails, so that the modified
        traps never leak into other tests.
        """
        op_name = all_arithmetic_operators
        s = pd.Series(data)

        context = decimal.getcontext()
        divbyzerotrap = context.traps[decimal.DivisionByZero]
        invalidoptrap = context.traps[decimal.InvalidOperation]
        context.traps[decimal.DivisionByZero] = 0
        context.traps[decimal.InvalidOperation] = 0

        try:
            # Decimal supports ops with int, but not float
            other = pd.Series([int(d * 100) for d in data])
            self.check_opname(s, op_name, other)

            if "mod" not in op_name:
                self.check_opname(s, op_name, s * 2)

            self.check_opname(s, op_name, 0)
            self.check_opname(s, op_name, 5)
        finally:
            # Always put the global context back the way it was found.
            context.traps[decimal.DivisionByZero] = divbyzerotrap
            context.traps[decimal.InvalidOperation] = invalidoptrap

    def _check_divmod_op(self, s, op, other, exc=NotImplementedError):
        """Delegate to the base divmod check with ``exc=None``; the ``exc``
        argument is ignored because we implement divmod.
        """
        super(TestArithmeticOps, self)._check_divmod_op(
            s, op, other, exc=None
        )

    def test_error(self):
        """Intentionally empty.

        NOTE(review): presumably overrides a ``test_error`` inherited from
        ``base.BaseArithmeticOpsTests`` to disable it - confirm in the base
        class.
        """
        pass


class TestComparisonOps(BaseDecimal, base.BaseComparisonOpsTests):
    """Comparison-operator test suite for the decimal array."""

    def check_opname(self, s, op_name, other, exc=None):
        """Delegate to the base check, always with ``exc=None``; the ``exc``
        argument is ignored.
        """
        parent = super(TestComparisonOps, self)
        parent.check_opname(s, op_name, other, exc=None)

    def _compare_other(self, s, data, op_name, other):
        """Run ``check_opname`` for ``s`` against ``other``; ``data`` is
        unused here.
        """
        self.check_opname(s, op_name, other)

    def test_compare_scalar(self, data, all_compare_operators):
        """Compare the decimal Series with the float scalar 0.5."""
        series = pd.Series(data)
        self._compare_other(series, data, all_compare_operators, 0.5)

    def test_compare_array(self, data, all_compare_operators):
        """Compare the decimal Series with a copy whose elements were each
        randomly doubled, halved or kept the same.
        """
        series = pd.Series(data)

        # One exponent per element: -1 halves, 0 keeps, 1 doubles.
        exponents = np.random.choice([-1, 0, 1], len(data))
        factors = [decimal.Decimal(2.0 ** exponent)
                   for exponent in exponents]
        scaled = pd.Series(data) * factors

        self._compare_other(series, data, all_compare_operators, scaled)


class DecimalArrayWithoutFromSequence(DecimalArray):
    """Helper class for testing error handling in _from_sequence."""

    @classmethod
    def _from_sequence(cls, scalars, dtype=None, copy=False):
        """Always raise ``KeyError`` so callers' error handling is exercised.

        Declared as a classmethod (its first parameter is ``cls``) so the
        intended ``KeyError`` is raised whether it is called on the class or
        on an instance. As a plain function, a class-level call would fail
        with a ``TypeError`` about a missing argument instead.

        Raises
        ------
        KeyError
            Unconditionally.
        """
        raise KeyError("For the test")


class DecimalArrayWithoutCoercion(DecimalArrayWithoutFromSequence):
    """Variant whose arithmetic methods are created with
    ``coerce_to_dtype=False``.
    """

    @classmethod
    def _create_arithmetic_method(cls, op):
        """Build the method for ``op`` via ``cls._create_method`` with
        coercion to the dtype turned off.
        """
        method = cls._create_method(op, coerce_to_dtype=False)
        return method


# NOTE(review): presumably installs the arithmetic operator methods on the
# class through its _create_arithmetic_method (i.e. without coercion);
# _add_arithmetic_ops is defined outside this file - confirm there.
DecimalArrayWithoutCoercion._add_arithmetic_ops()


def test_combine_from_sequence_raises():
    # https://github.com/pandas-dev/pandas/issues/22850
    ser = pd.Series(DecimalArrayWithoutFromSequence([
Loading ...