Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

aaronreidsmith / numpy   python

Repository URL to install this package:

Version: 1.17.4 

/ testing / tests / test_utils.py

from __future__ import division, absolute_import, print_function

import warnings
import sys
import os
import itertools
import textwrap
import pytest
import weakref

import numpy as np
from numpy.testing import (
    assert_equal, assert_array_equal, assert_almost_equal,
    assert_array_almost_equal, assert_array_less, build_err_msg, raises,
    assert_raises, assert_warns, assert_no_warnings, assert_allclose,
    assert_approx_equal, assert_array_almost_equal_nulp, assert_array_max_ulp,
    clear_and_catch_warnings, suppress_warnings, assert_string_equal, assert_,
    tempdir, temppath, assert_no_gc_cycles, HAS_REFCOUNT
    )
from numpy.core.overrides import ARRAY_FUNCTION_ENABLED


class _GenericTest(object):
    """Equality checks shared by several assertion functions.

    A subclass must bind the assertion being exercised to
    ``self._assert_func`` before any test method runs.
    """

    def _test_equal(self, a, b):
        """Expect the assertion under test to accept ``a`` and ``b``."""
        check = self._assert_func
        check(a, b)

    def _test_not_equal(self, a, b):
        """Expect the assertion under test to reject ``a`` and ``b``."""
        assert_raises(AssertionError, self._assert_func, a, b)

    def test_array_rank1_eq(self):
        """Identical rank-1 arrays must be reported as equal."""
        left = np.array([1, 2])
        right = left.copy()
        self._test_equal(left, right)

    def test_array_rank1_noteq(self):
        """Rank-1 arrays differing in one element must be rejected."""
        self._test_not_equal(np.array([1, 2]), np.array([2, 2]))

    def test_array_rank2_eq(self):
        """Identical rank-2 arrays must be reported as equal."""
        grid = np.array([[1, 2], [3, 4]])
        self._test_equal(grid, grid.copy())

    def test_array_diffshape(self):
        """Arrays whose shapes differ must be rejected."""
        row = np.array([1, 2])
        stacked = np.array([[1, 2], [1, 2]])
        self._test_not_equal(row, stacked)

    def test_objarray(self):
        """An object array of ones compares equal to the scalar 1."""
        ones = np.array([1, 1], dtype=object)
        self._test_equal(ones, 1)

    def test_array_likes(self):
        """A list and a tuple with the same items compare equal."""
        as_list = [1, 2, 3]
        self._test_equal(as_list, tuple(as_list))


class TestArrayEqual(_GenericTest):
    """Run the shared ``_GenericTest`` checks, plus array-specific cases,
    against ``assert_array_equal``.

    The assertion under test is bound to ``self._assert_func`` by ``setup``;
    a subclass can retarget every test here by overriding ``setup``.
    """

    # Type codes handed to ``np.empty`` by the generic rank tests.  The
    # string holds no ``'O'`` code, so object arrays are not covered by it.
    _NUMERIC_TYPECODES = '?bhilqpBHILQPfdgFDG'
    _STRING_DTYPES = ('S1', 'U1')

    def setup(self):
        """Bind the assertion under test (nose-style per-test hook)."""
        self._assert_func = assert_array_equal

    def setup_method(self, method=None):
        """pytest-native per-test hook; delegates to ``setup``.

        pytest 8 no longer calls nose-style ``setup``, so without this hook
        ``self._assert_func`` would never be bound there.  Delegating to
        ``self.setup()`` keeps subclasses that only override ``setup``
        working unchanged.
        """
        self.setup()

    def _check_filled_arrays(self, shape, dtype):
        """Compare constant-filled arrays of the given shape and dtype.

        An array filled with 1 must equal its copy and must differ from a
        copy refilled with 0.
        """
        ones = np.empty(shape, dtype)
        ones.fill(1)
        same = ones.copy()
        zeros = ones.copy()
        zeros.fill(0)
        self._test_equal(ones, same)
        self._test_not_equal(zeros, same)

    def test_generic_rank1(self):
        """Test rank 1 arrays for boolean, numeric and string dtypes."""
        # Boolean, integer, floating and complex type codes
        for t in self._NUMERIC_TYPECODES:
            self._check_filled_arrays(2, t)

        # Byte-string and unicode dtypes
        for t in self._STRING_DTYPES:
            self._check_filled_arrays(2, t)

    def test_generic_rank3(self):
        """Test rank 3 arrays for boolean, numeric and string dtypes."""
        # Boolean, integer, floating and complex type codes
        for t in self._NUMERIC_TYPECODES:
            self._check_filled_arrays((4, 2, 3), t)

        # Byte-string and unicode dtypes
        for t in self._STRING_DTYPES:
            self._check_filled_arrays((4, 2, 3), t)

    def test_nan_array(self):
        """Test arrays with nan values in them."""
        a = np.array([1, 2, np.nan])
        b = np.array([1, 2, np.nan])

        # nan in matching positions counts as equal
        self._test_equal(a, b)

        c = np.array([1, 2, 3])
        self._test_not_equal(c, b)

    def test_string_arrays(self):
        """Test arrays of strings, equal and differing in one item."""
        a = np.array(['floupi', 'floupa'])
        b = np.array(['floupi', 'floupa'])

        self._test_equal(a, b)

        c = np.array(['floupipi', 'floupa'])

        self._test_not_equal(c, b)

    def test_recarrays(self):
        """Test record arrays."""
        a = np.empty(2, [('floupi', float), ('floupa', float)])
        a['floupi'] = [1, 2]
        a['floupa'] = [1, 2]
        b = a.copy()

        self._test_equal(a, b)

        # Same data, but the first field carries a different name
        c = np.empty(2, [('floupipi', float), ('floupa', float)])
        c['floupipi'] = a['floupi'].copy()
        c['floupa'] = a['floupa'].copy()

        with suppress_warnings() as sup:
            # Exactly one FutureWarning matching "elementwise == " is
            # expected while the mismatching records are compared.
            recorded = sup.record(FutureWarning, message="elementwise == ")
            self._test_not_equal(c, b)
            assert_equal(len(recorded), 1)

    def test_masked_nan_inf(self):
        """Masked entries are ignored even when the other side is nan/inf."""
        # Regression test for gh-11121
        a = np.ma.MaskedArray([3., 4., 6.5], mask=[False, True, False])
        b = np.array([3., np.nan, 6.5])
        self._test_equal(a, b)
        self._test_equal(b, a)
        a = np.ma.MaskedArray([3., 4., 6.5], mask=[True, False, False])
        b = np.array([np.inf, 4., 6.5])
        self._test_equal(a, b)
        self._test_equal(b, a)

    def test_subclass_that_overrides_eq(self):
        """Subclasses whose ``==`` returns a plain bool are still handled."""
        # While we cannot guarantee testing functions will always work for
        # subclasses, the tests should ideally rely only on subclasses having
        # comparison operators, not on them being able to store booleans
        # (which, e.g., astropy Quantity cannot usefully do). See gh-8452.
        class MyArray(np.ndarray):
            def __eq__(self, other):
                return bool(np.equal(self, other).all())

            def __ne__(self, other):
                return not self == other

        a = np.array([1., 2.]).view(MyArray)
        b = np.array([2., 3.]).view(MyArray)
        # Compare the type explicitly: ``assert_(type(...), bool)`` would take
        # ``bool`` as the failure message and only test that the type object
        # is truthy, which always holds.
        assert_(type(a == a) is bool)
        assert_(a == a)
        assert_(a != b)
        self._test_equal(a, a)
        self._test_not_equal(a, b)
        self._test_not_equal(b, a)

    @pytest.mark.skipif(
        not ARRAY_FUNCTION_ENABLED, reason='requires __array_function__')
    def test_subclass_that_does_not_implement_npall(self):
        """Subclasses that refuse ``__array_function__`` are still handled."""
        class MyArray(np.ndarray):
            def __array_function__(self, *args, **kwargs):
                return NotImplemented

        a = np.array([1., 2.]).view(MyArray)
        b = np.array([2., 3.]).view(MyArray)
        # Confirm the premise: np.all is unusable on this subclass
        with assert_raises(TypeError):
            np.all(a)
        self._test_equal(a, a)
        self._test_not_equal(a, b)
        self._test_not_equal(b, a)


class TestBuildErrorMessage(object):
    """Pin the exact text that ``build_err_msg`` generates."""

    def test_build_err_msg_defaults(self):
        """Default options: header line, then ACTUAL and DESIRED reprs."""
        actual = np.array([1.00001, 2.00002, 3.00003])
        desired = np.array([1.00002, 2.00003, 3.00004])

        expected = '\n'.join([
            '',
            'Items are not equal: There is a mismatch',
            ' ACTUAL: array([1.00001, 2.00002, 3.00003])',
            ' DESIRED: array([1.00002, 2.00003, 3.00004])',
        ])
        produced = build_err_msg([actual, desired], 'There is a mismatch')
        assert_equal(produced, expected)

    def test_build_err_msg_no_verbose(self):
        """``verbose=False`` leaves only the header line."""
        actual = np.array([1.00001, 2.00002, 3.00003])
        desired = np.array([1.00002, 2.00003, 3.00004])

        produced = build_err_msg(
            [actual, desired], 'There is a mismatch', verbose=False)
        assert_equal(produced, '\nItems are not equal: There is a mismatch')

    def test_build_err_msg_custom_names(self):
        """``names`` replaces the ACTUAL/DESIRED labels."""
        actual = np.array([1.00001, 2.00002, 3.00003])
        desired = np.array([1.00002, 2.00003, 3.00004])

        expected = '\n'.join([
            '',
            'Items are not equal: There is a mismatch',
            ' FOO: array([1.00001, 2.00002, 3.00003])',
            ' BAR: array([1.00002, 2.00003, 3.00004])',
        ])
        produced = build_err_msg(
            [actual, desired], 'There is a mismatch', names=('FOO', 'BAR'))
        assert_equal(produced, expected)

    def test_build_err_msg_custom_precision(self):
        """``precision`` controls how many digits the array reprs show."""
        actual = np.array([1.000000001, 2.00002, 3.00003])
        desired = np.array([1.000000002, 2.00003, 3.00004])

        expected = '\n'.join([
            '',
            'Items are not equal: There is a mismatch',
            ' ACTUAL: array([1.000000001, 2.00002    , 3.00003    ])',
            ' DESIRED: array([1.000000002, 2.00003    , 3.00004    ])',
        ])
        produced = build_err_msg(
            [actual, desired], 'There is a mismatch', precision=10)
        assert_equal(produced, expected)


class TestEqual(TestArrayEqual):

    def setup(self):
        """Bind ``assert_equal`` as the assertion under test."""
        self._assert_func = assert_equal

    def setup_method(self, method=None):
        """pytest-native per-test hook; delegates to ``setup``.

        pytest 8 no longer calls nose-style ``setup``, so without this hook
        ``self._assert_func`` would not be bound to ``assert_equal`` there.
        """
        self.setup()

    def test_nan_items(self):
        """nan matches nan (bare or in a list), but not ``[nan]`` or 1."""
        nan = np.nan
        for matching in ((nan, nan), ([nan], [nan])):
            self._assert_func(*matching)
        for differing in ((nan, [nan]), (nan, 1)):
            self._test_not_equal(*differing)

    def test_inf_items(self):
        """inf matches inf (bare or in a list), but not ``[inf]``."""
        inf = np.inf
        for matching in ((inf, inf), ([inf], [inf])):
            self._assert_func(*matching)
        self._test_not_equal(inf, [inf])

    def test_datetime(self):
        """datetime64 values compare by instant, regardless of unit."""
        stamp = np.datetime64

        same_instant = [
            (stamp("2017-01-01", "s"), stamp("2017-01-01", "s")),
            (stamp("2017-01-01", "s"), stamp("2017-01-01", "m")),
        ]
        for first, second in same_instant:
            self._test_equal(first, second)

        # gh-10081
        different_day = [
            (stamp("2017-01-01", "s"), stamp("2017-01-02", "s")),
            (stamp("2017-01-01", "s"), stamp("2017-01-02", "m")),
        ]
        for first, second in different_day:
            self._test_not_equal(first, second)

    def test_nat_items(self):
        """NaT matches NaT of the same kind only, whatever the unit."""
        # "not a time" datetimes: unitless, seconds, nanoseconds
        nat_datetimes = [
            np.datetime64("NaT"),
            np.datetime64("NaT", "s"),
            np.datetime64("NaT", "ns"),
        ]
        # "not a time" timedeltas: unitless, seconds, nanoseconds
        nat_timedeltas = [
            np.timedelta64("NaT"),
            np.timedelta64("NaT", "s"),
            np.timedelta64("NaT", "ns"),
        ]

        # Within one kind every unit combination matches, but a one-item
        # list never matches a bare scalar.
        for same_kind in (nat_datetimes, nat_timedeltas):
            for a, b in itertools.product(same_kind, repeat=2):
                self._assert_func(a, b)
                self._assert_func([a], [b])
                self._test_not_equal([a], b)

        # Across kinds nothing matches, nor does NaT match a real value.
        for td, dt in itertools.product(nat_timedeltas, nat_datetimes):
            self._test_not_equal(td, dt)
            self._test_not_equal(td, [dt])
            self._test_not_equal([td], [dt])
            real_values = (
                np.datetime64("2017-01-01", "s"),
                np.timedelta64(123, "s"),
            )
            for real in real_values:
                self._test_not_equal([td], real)
                self._test_not_equal([dt], real)

    def test_non_numeric(self):
        """Strings compare by value."""
        word = 'ab'
        self._assert_func(word, 'ab')
        self._test_not_equal(word, 'abb')

    def test_complex_item(self):
        """Complex scalars match only when both components match."""
        nan, inf = np.nan, np.inf

        matching = [
            (complex(1, 2), complex(1, 2)),
            (complex(1, nan), complex(1, nan)),
        ]
        for first, second in matching:
            self._assert_func(first, second)

        differing = [
            (complex(1, nan), complex(1, 2)),
            (complex(nan, 1), complex(1, nan)),
            (complex(nan, inf), complex(nan, 2)),
        ]
        for first, second in differing:
            self._test_not_equal(first, second)

    def test_negative_zero(self):
        """Positive and negative zero are told apart."""
        positive_zero, negative_zero = np.PZERO, np.NZERO
        self._test_not_equal(positive_zero, negative_zero)

    def test_complex(self):
        """A complex array with a nan part matches itself but not a nan-free one."""
        with_nan = np.array([complex(1, 2), complex(1, np.nan)])
        without_nan = np.array([complex(1, 2), complex(1, 2)])

        self._assert_func(with_nan, with_nan)
        self._test_not_equal(with_nan, without_nan)

    def test_error_message(self):
        with pytest.raises(AssertionError) as exc_info:
            self._assert_func(np.array([1, 2]), np.array([[1, 2]]))
        msg = str(exc_info.value)
        msg2 = msg.replace("shapes (2L,), (1L, 2L)", "shapes (2,), (1, 2)")
        msg_reference = textwrap.dedent("""\

        Arrays are not equal

        (shapes (2,), (1, 2) mismatch)
         x: array([1, 2])
         y: array([[1, 2]])""")

        try:
            assert_equal(msg, msg_reference)
Loading ...