from functools import reduce

import numpy as np
import numpy.core.umath as umath
import numpy.core.fromnumeric as fromnumeric
from numpy.testing import (
    assert_, assert_raises, assert_equal,
    )
from numpy.ma import (
    MaskType, MaskedArray, absolute, add, all, allclose, allequal, alltrue,
    arange, arccos, arcsin, arctan, arctan2, array, average, choose,
    concatenate, conjugate, cos, cosh, count, divide, equal, exp, filled,
    getmask, greater, greater_equal, inner, isMaskedArray, less,
    less_equal, log, log10, make_mask, masked, masked_array, masked_equal,
    masked_greater, masked_greater_equal, masked_inside, masked_less,
    masked_less_equal, masked_not_equal, masked_outside,
    masked_print_option, masked_values, masked_where, maximum, minimum,
    multiply, nomask, nonzero, not_equal, ones, outer, product, put, ravel,
    repeat, resize, shape, sin, sinh, sometrue, sort, sqrt, subtract, sum,
    take, tan, tanh, transpose, where, zeros,
    )
from numpy.compat import pickle

# Short module-level alias for NumPy's pi, used when building test fixtures.
pi = np.pi

def eq(v, w, msg=''):
    """Return whether ``v`` and ``w`` are elementwise close.

    The comparison is delegated to ``numpy.ma.allclose`` with its default
    arguments, so masked entries do not cause a mismatch. On a mismatch a
    short report containing ``msg`` and both operands is printed before the
    (falsy) result is returned, so a failing ``assert_(eq(...))`` leaves a
    trace of the values involved.
    """
    close = allclose(v, w)
    if close:
        return close
    print(f"Not eq:{msg!s}\n{str(v)}\n----{str(w)}")
    return close

class TestMa:

    def setup_method(self):
        """Build the shared 1-d fixtures and store them on ``self.d``.

        ``self.d`` is the tuple ``(x, y, a10, m1, m2, xm, ym, z, zm, xf, s)``:

        * ``x``, ``y``: plain float ndarrays of length 12;
        * ``a10``: the scalar 10.0;
        * ``m1``, ``m2``: 0/1 mask lists for ``x`` and ``y``;
        * ``xm``, ``ym``: masked arrays built from ``x``/``m1`` and ``y``/``m2``;
        * ``z``, ``zm``: a short array with values in [-1, 1] and a masked
          version of it;
        * ``xf``: ``x`` with the ``m1`` positions replaced by 1e+20;
        * ``s``: the shape of ``x``.
        """
        x = np.array([1., 1., 1., -2., pi/2.0, 4., 5., -10., 10., 1., 2., 3.])
        y = np.array([5., 0., 3., 2., -1., -4., 0., -10., 10., 1., 0., 3.])
        a10 = 10.
        m1 = [1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0]
        m2 = [0, 0, 1, 0, 0, 1, 1, 0, 0, 0, 0, 1]
        xm = array(x, mask=m1)
        ym = array(y, mask=m2)
        z = np.array([-.5, 0., .5, .8])
        zm = array(z, mask=[0, 1, 0, 0])
        xf = np.where(m1, 1e+20, x)
        s = x.shape
        self.d = (x, y, a10, m1, m2, xm, ym, z, zm, xf, s)

    # pytest runs the xunit-style ``setup_method`` before every test; the
    # nose-style ``setup`` hook was deprecated in pytest 7.2 and removed in
    # pytest 8. Keep the old name as an alias for any direct callers.
    setup = setup_method

    def test_testBasic1d(self):
        # Basic 1-d properties of a masked array built from a plain ndarray:
        # shape, dtype, size, unmasked count, and filling.
        x, _, _, m1, _, xm, _, _, _, xf, s = self.d
        n_masked = reduce(lambda acc, flag: acc + flag, m1)
        n_total = reduce(lambda acc, dim: acc * dim, s)

        assert_(not isMaskedArray(x))
        for observed_shape in (shape(xm), xm.shape):
            assert_equal(observed_shape, s)
        assert_equal(xm.dtype, x.dtype)
        assert_equal(xm.size, n_total)
        assert_equal(count(xm), len(m1) - n_masked)
        assert_(eq(xm, xf))
        assert_(eq(filled(xm, 1.e20), xf))
        assert_(eq(x, xm))

    def test_testBasic2d(self):
        # Test of basic array creation and properties in 2 dimensions.
        # Unpack the fixtures once, outside the loop. Their last element is
        # the 1-d shape (12,); unpacking it into the loop variable ``s``
        # would replace each 2-d target shape, and the 2-d cases would never
        # be exercised. The arrays are reshaped in place; setup builds them
        # from scratch.
        (x, y, a10, m1, m2, xm, ym, z, zm, xf, _) = self.d
        for s in [(4, 3), (6, 2)]:
            x.shape = s
            y.shape = s
            xm.shape = s
            ym.shape = s
            xf.shape = s

            assert_(not isMaskedArray(x))
            assert_equal(shape(xm), s)
            assert_equal(xm.shape, s)
            assert_equal(xm.size, reduce(lambda x, y:x * y, s))
            # The number of unmasked entries does not depend on the shape.
            assert_equal(count(xm),
                         len(m1) - reduce(lambda x, y:x + y, m1))
            assert_(eq(xm, xf))
            assert_(eq(filled(xm, 1.e20), xf))
            assert_(eq(x, xm))

    def test_testArithmetic(self):
        # Arithmetic on masked arrays must match plain ndarray arithmetic on
        # the unmasked entries, for 1-d and 2-d layouts of the fixtures.
        x, y, a10, m1, m2, xm, ym, z, zm, xf, s = self.d

        plain = array([[1, 2], [0, 4]])
        plain_m = masked_array(plain, [[0, 0], [1, 0]])
        assert_(eq(plain * plain, plain * plain_m))
        assert_(eq(plain + plain, plain + plain_m))
        assert_(eq(plain - plain, plain - plain_m))

        for s in [(12,), (4, 3), (2, 6)]:
            x, y, xm, ym, xf = (arr.reshape(s) for arr in (x, y, xm, ym, xf))

            # Operations that never divide by ``y``.
            assert_(eq(-x, -xm))
            assert_(eq(x + y, xm + ym))
            assert_(eq(x - y, xm - ym))
            assert_(eq(x * y, xm * ym))
            assert_(eq(a10 + y, a10 + ym))
            assert_(eq(a10 - y, a10 - ym))
            assert_(eq(a10 * y, a10 * ym))
            assert_(eq(x + a10, xm + a10))
            assert_(eq(x - a10, xm - a10))
            assert_(eq(x * a10, xm * a10))
            assert_(eq(x / a10, xm / a10))
            assert_(eq(x ** 2, xm ** 2))
            assert_(eq(abs(x) ** 2.5, abs(xm) ** 2.5))
            assert_(eq(x ** y, xm ** ym))
            assert_(eq(np.add(x, y), add(xm, ym)))
            assert_(eq(np.subtract(x, y), subtract(xm, ym)))
            assert_(eq(np.multiply(x, y), multiply(xm, ym)))

            # The fixture ``y`` contains zeros, so divisions by it run with
            # divide/invalid floating-point warnings silenced.
            with np.errstate(divide='ignore', invalid='ignore'):
                assert_(eq(x / y, xm / ym))
                assert_(eq(a10 / y, a10 / ym))
                assert_(eq(np.divide(x, y), divide(xm, ym)))

    def test_testMixedArithmetic(self):
        na = np.array([1])
        ma = array([1])
        assert_(isinstance(na + ma, MaskedArray))
        assert_(isinstance(ma + na, MaskedArray))

    def test_testUfuncs1(self):
        # Each numpy.ma function must agree with the matching numpy ufunc on
        # the unmasked entries of the fixtures.
        (x, y, a10, m1, m2, xm, ym, z, zm, xf, s) = self.d

        for np_func, ma_func in ((np.cos, cos), (np.cosh, cosh),
                                 (np.sin, sin), (np.sinh, sinh),
                                 (np.tan, tan), (np.tanh, tanh)):
            assert_(eq(np_func(x), ma_func(xm)))

        # The plain references are evaluated on abs(x), while the masked
        # versions receive xm unchanged; divide/invalid warnings are silenced.
        with np.errstate(divide='ignore', invalid='ignore'):
            for np_func, ma_func in ((np.sqrt, sqrt), (np.log, log),
                                     (np.log10, log10)):
                assert_(eq(np_func(abs(x)), ma_func(xm)))
        assert_(eq(np.exp(x), exp(xm)))

        # The z fixture holds values within [-1, 1].
        for np_func, ma_func in ((np.arcsin, arcsin), (np.arccos, arccos),
                                 (np.arctan, arctan)):
            assert_(eq(np_func(z), ma_func(zm)))

        assert_(eq(np.arctan2(x, y), arctan2(xm, ym)))
        assert_(eq(np.absolute(x), absolute(xm)))
        for np_func, ma_func in ((np.equal, equal),
                                 (np.not_equal, not_equal),
                                 (np.less, less),
                                 (np.greater, greater),
                                 (np.less_equal, less_equal),
                                 (np.greater_equal, greater_equal)):
            assert_(eq(np_func(x, y), ma_func(xm, ym)))
        assert_(eq(np.conjugate(x), conjugate(xm)))

        # concatenate must accept any mix of masked and plain inputs.
        joined = np.concatenate((x, y))
        for parts in ((xm, ym), (x, y), (xm, y)):
            assert_(eq(joined, concatenate(parts)))
        assert_(eq(np.concatenate((x, y, x)), concatenate((x, ym, x))))

    def test_xtestCount(self):
        # Test count: the number of unmasked entries, in total and per axis.
        ott = array([0., 1., 2., 3.], mask=[1, 0, 0, 0])
        assert_(count(ott).dtype.type is np.intp)
        assert_equal(3, count(ott))
        # A plain scalar counts as one unmasked entry.
        assert_equal(1, count(1))
        assert_(eq(0, array(1, mask=[1])))
        ott = ott.reshape((2, 2))
        assert_(count(ott).dtype.type is np.intp)
        assert_(isinstance(count(ott, 0), np.ndarray))
        # The per-axis counts use the same intp result type as the total.
        assert_(count(ott, 0).dtype.type is np.intp)
        assert_(eq(3, count(ott)))
        assert_(getmask(count(ott, 0)) is nomask)
        assert_(eq([1, 2], count(ott, 0)))

    def test_testMinMax(self):
        # maximum.reduce / minimum.reduce on the masked data must match the
        # builtin max/min of the plain data.
        (x, y, a10, m1, m2, xm, ym, z, zm, xf, s) = self.d
        # Flatten first: the builtin max/min iterate over the first axis.
        flat_x = np.ravel(x)
        flat_xm = ravel(xm)

        # This only holds because the fixture data was chosen so that the
        # masked entries are not the extremes.
        for builtin_func, masked_reduce in ((max, maximum.reduce),
                                            (min, minimum.reduce)):
            assert_(eq(builtin_func(flat_x), masked_reduce(flat_xm)))

    def test_testAddSumProd(self):
        # Test add, sum, product.
        (x, y, a10, m1, m2, xm, ym, z, zm, xf, s) = self.d
        assert_(eq(np.add.reduce(x), add.reduce(x)))
        assert_(eq(np.add.accumulate(x), add.accumulate(x)))
        assert_(eq(4, sum(array(4), axis=0)))
        assert_(eq(np.sum(x, axis=0), sum(x, axis=0)))
        # Masked entries behave like the identity: 0 for sums, 1 for products.
        assert_(eq(np.sum(filled(xm, 0), axis=0), sum(xm, axis=0)))
        assert_(eq(np.sum(x, 0), sum(x, 0)))
        # np.prod is the canonical name; np.product is a deprecated alias.
        assert_(eq(np.prod(x, axis=0), product(x, axis=0)))
        assert_(eq(np.prod(x, 0), product(x, 0)))
        assert_(eq(np.prod(filled(xm, 1), axis=0),
                   product(xm, axis=0)))

        # Axis-1 operations need 2-d operands. The fixtures are 1-d, so a
        # guard on their shape would never be true; reshape views instead.
        s2 = (3, 4)
        x2, y2 = x.reshape(s2), y.reshape(s2)
        xm2, ym2 = xm.reshape(s2), ym.reshape(s2)
        assert_(eq(np.concatenate((x2, y2), 1),
                   concatenate((xm2, ym2), 1)))
        assert_(eq(np.add.reduce(x2, 1), add.reduce(x2, 1)))
        assert_(eq(np.sum(x2, 1), sum(x2, 1)))
        assert_(eq(np.prod(x2, 1), product(x2, 1)))

    def test_testCI(self):
        # Test of conversions and indexing.
        # The statements below mutate x1..x4 step by step and each check
        # depends on the writes before it, so their order is significant.
        x1 = np.array([1, 2, 4, 3])
        x2 = array(x1, mask=[1, 0, 0, 0])
        x3 = array(x1, mask=[0, 1, 0, 1])
        x4 = array(x1)
        # test conversion to strings
        str(x2)  # must not raise; the text itself is not checked
        repr(x2)  # must not raise; the text itself is not checked
        # fill_value=0 decides where the masked entry lands; eq then skips
        # the masked position when comparing.
        assert_(eq(np.sort(x1), sort(x2, fill_value=0)))
        # tests of indexing
        # An unmasked element has the same scalar type as the ndarray's; a
        # masked element is returned as the ``masked`` singleton.
        assert_(type(x2[1]) is type(x1[1]))
        assert_(x1[1] == x2[1])
        assert_(x2[0] is masked)
        assert_(eq(x1[2], x2[2]))
        assert_(eq(x1[2:5], x2[2:5]))
        assert_(eq(x1[:], x2[:]))
        assert_(eq(x1[1:], x3[1:]))
        # Identical writes to x1 and x2 keep them equal.
        # NOTE(review): x2 is built without copy=True, so it may share x1's
        # buffer; the paired writes keep them equal either way.
        x1[2] = 9
        x2[2] = 9
        assert_(eq(x1, x2))
        x1[1:3] = 99
        x2[1:3] = 99
        assert_(eq(x1, x2))
        # Masking entries of x2 keeps eq true, since eq skips masked slots.
        x2[1] = masked
        assert_(eq(x1, x2))
        x2[1:3] = masked
        assert_(eq(x1, x2))
        # Assigning unmasked data over all of x2 clears its earlier mask, so
        # after re-masking index 1 only that entry is masked.
        x2[:] = x1
        x2[1] = masked
        assert_(allequal(getmask(x2), array([0, 1, 0, 0])))
        # Slice-assigning a masked array transfers its mask, both into an
        # already-masked target (x3) and into an unmasked one (x4).
        x3[:] = masked_array([1, 2, 3, 4], [0, 1, 1, 0])
        assert_(allequal(getmask(x3), array([0, 1, 1, 0])))
        x4[:] = masked_array([1, 2, 3, 4], [0, 1, 1, 0])
        assert_(allequal(getmask(x4), array([0, 1, 1, 0])))
        assert_(allequal(x4, array([1, 2, 3, 4])))
        # masked_values masks entries equal to the given value and records
        # that value as the fill_value.
        x1 = np.arange(5) * 1.0
        x2 = masked_values(x1, 3.0)
        assert_(eq(x1, x2))
        assert_(allequal(array([0, 0, 0, 1, 0], MaskType), x2.mask))
        assert_(eq(3.0, x2.fill_value))
        # Object arrays: indexing one element returns the stored str, for
        # both the masked and the plain array.
        x1 = array([1, 'hello', 2, 3], object)
        x2 = np.array([1, 'hello', 2, 3], object)
        s1 = x1[1]
        s2 = x2[1]
        assert_equal(type(s2), str)
        assert_equal(type(s1), str)
        assert_equal(s1, s2)
        # An empty slice has shape (0,).
        assert_(x1[1:1].shape == (0,))

    def test_testCopySize(self):
        # Tests of some subtle points of copying and sizing.
        # make_mask returns an existing mask object unchanged unless
        # copy=True is requested.
        n = [0, 0, 1, 0, 0]
        m = make_mask(n)
        m2 = make_mask(m)
        assert_(m is m2)
        m3 = make_mask(m, copy=True)
        assert_(m is not m3)

        # array() wraps the data in a new object but adopts the given mask
        # object as-is.
        x1 = np.arange(5)
        y1 = array(x1, mask=m)
        assert_(y1._data is not x1)
        assert_(allequal(x1, y1._data))
        assert_(y1._mask is m)

        y1a = array(y1, copy=0)
        # For copy=False, one might expect the mask object to be passed on
        # unchanged, i.e. an "is" check instead of the "==" check below.
        # See gh-4043 for discussion. Comparing __array_interface__ checks
        # that both masks describe the same memory, whether or not they are
        # the same Python object.
        assert_(y1a._mask.__array_interface__ ==
                y1._mask.__array_interface__)

        # With copy=0 the given mask object is used directly, so writing a
        # value at a masked position clears that flag in m3 itself.
        y2 = array(x1, mask=m3, copy=0)
        assert_(y2._mask is m3)
        assert_(y2[2] is masked)
        y2[2] = 9
        assert_(y2[2] is not masked)
        assert_(y2._mask is m3)
        assert_(allequal(y2.mask, 0))

        # With copy=1 the masked array gets its own mask object, distinct
        # from m.
        y2a = array(x1, mask=m, copy=1)
        assert_(y2a._mask is not m)
        assert_(y2a[2] is masked)
        y2a[2] = 9
        assert_(y2a[2] is not masked)
        assert_(y2a._mask is not m)
        assert_(allequal(y2a.mask, 0))

        # filled() keeps the dtype of the underlying data.
        y3 = array(x1 * 1.0, mask=m)
        assert_(filled(y3).dtype is (x1 * 1.0).dtype)

        # resize tiles both values and mask; repeat with per-element counts
        # and with a scalar count give the same result.
        x4 = arange(4)
        x4[2] = masked
        y4 = resize(x4, (8,))
        assert_(eq(concatenate([x4, x4]), y4))
        assert_(eq(getmask(y4), [0, 0, 1, 0, 0, 0, 1, 0]))
        y5 = repeat(x4, (2, 2, 2, 2), axis=0)
        assert_(eq(y5, [0, 0, 1, 1, 2, 2, 3, 3]))
        y6 = repeat(x4, 2, axis=0)
        assert_(eq(y5, y6))

    def test_testPut(self):
        # Test of put
        # First case: fancy-index assignment on a masked array built from
        # mask m without copying. x keeps m as its mask object, and the
        # written positions become unmasked.
        d = arange(5)
        n = [0, 0, 0, 1, 1]
        m = make_mask(n)
        m2 = m.copy()
        x = array(d, mask=m)
        assert_(x[3] is masked)
        assert_(x[4] is masked)
        x[[1, 4]] = [10, 40]
        assert_(x._mask is m)
        assert_(x[3] is masked)
        assert_(x[4] is not masked)
        # Index 3 is still masked, so the -1 placeholder is ignored by eq.
        assert_(eq(x, [0, 10, 2, -1, 40]))

        # Second case: with copy=True x gets its own mask object, and put()
        # on unmasked positions leaves entries 3 and 4 masked.
        x = array(d, mask=m2, copy=True)
        x.put([0, 1, 2], [-1, 100, 200])
        assert_(x._mask is not m2)
        assert_(x[3] is masked)
        assert_(x[4] is masked)
        # The trailing 0 placeholders sit at masked positions and are ignored.
        assert_(eq(x, [-1, 100, 200, 0, 0]))

    def test_testPut2(self):
        # Test of put
        d = arange(5)
        x = array(d, mask=[0, 0, 0, 0, 0])
        z = array([10, 40], mask=[1, 0])
        assert_(x[2] is not masked)
        assert_(x[3] is not masked)
        x[2:4] = z
        assert_(x[2] is masked)
        assert_(x[3] is not masked)
        assert_(eq(x, [0, 1, 10, 40, 4]))

        d = arange(5)
        x = array(d, mask=[0, 0, 0, 0, 0])
        y = x[2:4]
