from __future__ import division, absolute_import, print_function
import warnings
import sys
import os
import itertools
import textwrap
import pytest
import weakref
import numpy as np
from numpy.testing import (
assert_equal, assert_array_equal, assert_almost_equal,
assert_array_almost_equal, assert_array_less, build_err_msg, raises,
assert_raises, assert_warns, assert_no_warnings, assert_allclose,
assert_approx_equal, assert_array_almost_equal_nulp, assert_array_max_ulp,
clear_and_catch_warnings, suppress_warnings, assert_string_equal, assert_,
tempdir, temppath, assert_no_gc_cycles, HAS_REFCOUNT
)
from numpy.core.overrides import ARRAY_FUNCTION_ENABLED
class _GenericTest(object):
    """Equality checks shared by the assertion-function test classes.

    Every check goes through ``self._assert_func``.  This class does not
    define that attribute, so it has to be bound on the instance before
    any of the checks is run.
    """

    def _test_equal(self, a, b):
        """Call the assertion on ``a`` and ``b``; it is expected to pass."""
        self._assert_func(a, b)

    def _test_not_equal(self, a, b):
        """Call the assertion on ``a`` and ``b``; it must raise AssertionError."""
        expect_failure = assert_raises(AssertionError)
        with expect_failure:
            self._assert_func(a, b)

    def test_array_rank1_eq(self):
        """Two equal rank-1 arrays compare as equal."""
        self._test_equal(np.array([1, 2]), np.array([1, 2]))

    def test_array_rank1_noteq(self):
        """Two different rank-1 arrays compare as not equal."""
        self._test_not_equal(np.array([1, 2]), np.array([2, 2]))

    def test_array_rank2_eq(self):
        """Two equal rank-2 arrays compare as equal."""
        rows = [[1, 2], [3, 4]]
        self._test_equal(np.array(rows), np.array(rows))

    def test_array_diffshape(self):
        """Arrays of different shapes compare as not equal."""
        flat = np.array([1, 2])
        stacked = np.array([[1, 2], [1, 2]])
        self._test_not_equal(flat, stacked)

    def test_objarray(self):
        """An object array of ones compares equal to the scalar 1."""
        ones = np.array([1, 1], dtype=object)
        self._test_equal(ones, 1)

    def test_array_likes(self):
        """A list and a tuple holding the same items compare as equal."""
        as_list, as_tuple = [1, 2, 3], (1, 2, 3)
        self._test_equal(as_list, as_tuple)
class TestArrayEqual(_GenericTest):
    """Run the shared equality checks against ``assert_array_equal``."""

    def setup(self):
        """Bind the assertion function that the checks exercise."""
        self._assert_func = assert_array_equal

    def setup_method(self, method=None):
        """Per-test pytest hook; forwards to ``setup``.

        pytest 8 no longer calls the nose-style ``setup`` method, so
        without this hook ``self._assert_func`` would never be bound.
        Delegating (rather than renaming ``setup``) keeps subclasses that
        override ``setup`` working unchanged.
        """
        self.setup()

    def _check_all_dtypes(self, shape):
        """Check equal / not-equal detection for arrays of ``shape``.

        For each type code an array filled with 1 must equal its copy and
        must differ from a copy filled with 0.
        """
        # Boolean, integer, float and complex type codes, then strings.
        type_codes = list('?bhilqpBHILQPfdgFDG') + ['S1', 'U1']
        for t in type_codes:
            a = np.empty(shape, t)
            a.fill(1)
            b = a.copy()
            c = a.copy()
            c.fill(0)
            self._test_equal(a, b)
            self._test_not_equal(c, b)

    def test_generic_rank1(self):
        """Test rank 1 array for all dtypes."""
        self._check_all_dtypes(2)

    def test_generic_rank3(self):
        """Test rank 3 array for all dtypes."""
        self._check_all_dtypes((4, 2, 3))

    def test_nan_array(self):
        """Test arrays with nan values in them."""
        a = np.array([1, 2, np.nan])
        b = np.array([1, 2, np.nan])
        self._test_equal(a, b)

        c = np.array([1, 2, 3])
        self._test_not_equal(c, b)

    def test_string_arrays(self):
        """Test string arrays: equal contents pass, differing ones fail."""
        a = np.array(['floupi', 'floupa'])
        b = np.array(['floupi', 'floupa'])
        self._test_equal(a, b)

        c = np.array(['floupipi', 'floupa'])
        self._test_not_equal(c, b)

    def test_recarrays(self):
        """Test record arrays."""
        a = np.empty(2, [('floupi', float), ('floupa', float)])
        a['floupi'] = [1, 2]
        a['floupa'] = [1, 2]
        b = a.copy()
        self._test_equal(a, b)

        # Same data as ``a`` but with a different first field name.
        c = np.empty(2, [('floupipi', float), ('floupa', float)])
        c['floupipi'] = a['floupi'].copy()
        c['floupa'] = a['floupa'].copy()

        with suppress_warnings() as sup:
            recorded = sup.record(FutureWarning, message="elementwise == ")
            self._test_not_equal(c, b)
            assert_equal(len(recorded), 1)

    def test_masked_nan_inf(self):
        """Masked entries are ignored when the other side has nan or inf."""
        # Regression test for gh-11121
        a = np.ma.MaskedArray([3., 4., 6.5], mask=[False, True, False])
        b = np.array([3., np.nan, 6.5])
        self._test_equal(a, b)
        self._test_equal(b, a)
        a = np.ma.MaskedArray([3., 4., 6.5], mask=[True, False, False])
        b = np.array([np.inf, 4., 6.5])
        self._test_equal(a, b)
        self._test_equal(b, a)

    def test_subclass_that_overrides_eq(self):
        """Subclasses whose ``==`` returns a plain bool are supported."""
        # While we cannot guarantee testing functions will always work for
        # subclasses, the tests should ideally rely only on subclasses having
        # comparison operators, not on them being able to store booleans
        # (which, e.g., astropy Quantity cannot usefully do). See gh-8452.
        class MyArray(np.ndarray):
            def __eq__(self, other):
                return bool(np.equal(self, other).all())

            def __ne__(self, other):
                return not self == other

        a = np.array([1., 2.]).view(MyArray)
        b = np.array([2., 3.]).view(MyArray)
        # The second positional argument of assert_ is the failure message,
        # so the type check has to be a single boolean expression.
        assert_(isinstance(a == a, bool))
        assert_(a == a)
        assert_(a != b)
        self._test_equal(a, a)
        self._test_not_equal(a, b)
        self._test_not_equal(b, a)

    @pytest.mark.skipif(
        not ARRAY_FUNCTION_ENABLED, reason='requires __array_function__')
    def test_subclass_that_does_not_implement_npall(self):
        """Subclasses that reject ``np.all`` can still be compared."""
        class MyArray(np.ndarray):
            def __array_function__(self, *args, **kwargs):
                return NotImplemented

        a = np.array([1., 2.]).view(MyArray)
        b = np.array([2., 3.]).view(MyArray)
        # Confirm the precondition: np.all is unusable on this subclass.
        with assert_raises(TypeError):
            np.all(a)
        self._test_equal(a, a)
        self._test_not_equal(a, b)
        self._test_not_equal(b, a)
class TestBuildErrorMessage(object):
    """Check the exact text returned by ``build_err_msg``."""

    def test_build_err_msg_defaults(self):
        """Default call: header line plus ACTUAL and DESIRED array reprs."""
        x = np.array([1.00001, 2.00002, 3.00003])
        y = np.array([1.00002, 2.00003, 3.00004])
        err_msg = 'There is a mismatch'

        a = build_err_msg([x, y], err_msg)
        b = ('\nItems are not equal: There is a mismatch\n ACTUAL: array(['
             '1.00001, 2.00002, 3.00003])\n DESIRED: array([1.00002, '
             '2.00003, 3.00004])')
        assert_equal(a, b)

    def test_build_err_msg_no_verbose(self):
        """``verbose=False`` leaves only the header line."""
        x = np.array([1.00001, 2.00002, 3.00003])
        y = np.array([1.00002, 2.00003, 3.00004])
        err_msg = 'There is a mismatch'

        a = build_err_msg([x, y], err_msg, verbose=False)
        b = '\nItems are not equal: There is a mismatch'
        assert_equal(a, b)

    def test_build_err_msg_custom_names(self):
        """``names`` replaces the ACTUAL / DESIRED labels."""
        x = np.array([1.00001, 2.00002, 3.00003])
        y = np.array([1.00002, 2.00003, 3.00004])
        err_msg = 'There is a mismatch'

        a = build_err_msg([x, y], err_msg, names=('FOO', 'BAR'))
        b = ('\nItems are not equal: There is a mismatch\n FOO: array(['
             '1.00001, 2.00002, 3.00003])\n BAR: array([1.00002, 2.00003, '
             '3.00004])')
        assert_equal(a, b)

    def test_build_err_msg_custom_precision(self):
        """``precision`` controls how many digits the array reprs show."""
        x = np.array([1.000000001, 2.00002, 3.00003])
        y = np.array([1.000000002, 2.00003, 3.00004])
        err_msg = 'There is a mismatch'

        a = build_err_msg([x, y], err_msg, precision=10)
        # The array repr right-pads shorter elements with spaces so every
        # element is as wide as the longest one; the padding inside these
        # literals is significant.
        b = ('\nItems are not equal: There is a mismatch\n ACTUAL: array(['
             '1.000000001, 2.00002    , 3.00003    ])\n DESIRED: array(['
             '1.000000002, 2.00003    , 3.00004    ])')
        assert_equal(a, b)
class TestEqual(TestArrayEqual):
def setup(self):
self._assert_func = assert_equal
def test_nan_items(self):
self._assert_func(np.nan, np.nan)
self._assert_func([np.nan], [np.nan])
self._test_not_equal(np.nan, [np.nan])
self._test_not_equal(np.nan, 1)
def test_inf_items(self):
self._assert_func(np.inf, np.inf)
self._assert_func([np.inf], [np.inf])
self._test_not_equal(np.inf, [np.inf])
def test_datetime(self):
self._test_equal(
np.datetime64("2017-01-01", "s"),
np.datetime64("2017-01-01", "s")
)
self._test_equal(
np.datetime64("2017-01-01", "s"),
np.datetime64("2017-01-01", "m")
)
# gh-10081
self._test_not_equal(
np.datetime64("2017-01-01", "s"),
np.datetime64("2017-01-02", "s")
)
self._test_not_equal(
np.datetime64("2017-01-01", "s"),
np.datetime64("2017-01-02", "m")
)
def test_nat_items(self):
# not a datetime
nadt_no_unit = np.datetime64("NaT")
nadt_s = np.datetime64("NaT", "s")
nadt_d = np.datetime64("NaT", "ns")
# not a timedelta
natd_no_unit = np.timedelta64("NaT")
natd_s = np.timedelta64("NaT", "s")
natd_d = np.timedelta64("NaT", "ns")
dts = [nadt_no_unit, nadt_s, nadt_d]
tds = [natd_no_unit, natd_s, natd_d]
for a, b in itertools.product(dts, dts):
self._assert_func(a, b)
self._assert_func([a], [b])
self._test_not_equal([a], b)
for a, b in itertools.product(tds, tds):
self._assert_func(a, b)
self._assert_func([a], [b])
self._test_not_equal([a], b)
for a, b in itertools.product(tds, dts):
self._test_not_equal(a, b)
self._test_not_equal(a, [b])
self._test_not_equal([a], [b])
self._test_not_equal([a], np.datetime64("2017-01-01", "s"))
self._test_not_equal([b], np.datetime64("2017-01-01", "s"))
self._test_not_equal([a], np.timedelta64(123, "s"))
self._test_not_equal([b], np.timedelta64(123, "s"))
def test_non_numeric(self):
self._assert_func('ab', 'ab')
self._test_not_equal('ab', 'abb')
def test_complex_item(self):
self._assert_func(complex(1, 2), complex(1, 2))
self._assert_func(complex(1, np.nan), complex(1, np.nan))
self._test_not_equal(complex(1, np.nan), complex(1, 2))
self._test_not_equal(complex(np.nan, 1), complex(1, np.nan))
self._test_not_equal(complex(np.nan, np.inf), complex(np.nan, 2))
def test_negative_zero(self):
self._test_not_equal(np.PZERO, np.NZERO)
def test_complex(self):
x = np.array([complex(1, 2), complex(1, np.nan)])
y = np.array([complex(1, 2), complex(1, 2)])
self._assert_func(x, x)
self._test_not_equal(x, y)
def test_error_message(self):
with pytest.raises(AssertionError) as exc_info:
self._assert_func(np.array([1, 2]), np.array([[1, 2]]))
msg = str(exc_info.value)
msg2 = msg.replace("shapes (2L,), (1L, 2L)", "shapes (2,), (1, 2)")
msg_reference = textwrap.dedent("""\
Arrays are not equal
(shapes (2,), (1, 2) mismatch)
x: array([1, 2])
y: array([[1, 2]])""")
try:
assert_equal(msg, msg_reference)
Loading ...