import sys
import hashlib

import pytest

import numpy as np
from numpy.linalg import LinAlgError
from numpy.testing import (
    assert_, assert_raises, assert_equal, assert_allclose,
    assert_warns, assert_no_warnings, assert_array_equal,
    assert_array_almost_equal, suppress_warnings)

from numpy.random import Generator, MT19937, SeedSequence

# Shared, deliberately unseeded generator for tests that only check
# properties of the output (shapes, bounds, raised errors), never exact values.
random = Generator(MT19937(seed=None))

        "seed": 0,
        "steps": 10,
        "initial": {"key_md5": "64eaf265d2203179fb5ffb73380cd589", "pos": 9},
        "jumped": {"key_md5": "8cb7b061136efceef5217a9ce2cc9a5a", "pos": 598},
        "initial": {"key_md5": "e99708a47b82ff51a2c7b0625b81afb5", "pos": 311},
        "jumped": {"key_md5": "2ecdbfc47a895b253e6e19ccb2e74b90", "pos": 276},
        "seed": [839438204, 980239840, 859048019, 821],
        "steps": 511,
        "initial": {"key_md5": "9fcd6280df9199785e17e93162ce283c", "pos": 510},
        "jumped": {"key_md5": "433b85229f2ed853cde06cd872818305", "pos": 475},

@pytest.fixture(scope='module', params=[True, False])
def endpoint(request):
    """Supply each ``endpoint`` flag value (True, then False) to a test.

    Tests that request this fixture forward it as the ``endpoint=`` keyword,
    so every such test runs once per value.
    """
    flag = request.param
    return flag

class TestSeed:
    """Seeding ``MT19937`` from scalars, sequences and a ``SeedSequence``."""

    def test_scalar(self):
        # The smallest seed and 2**32 - 1 each give a fixed first draw.
        for seed, expected in ((0, 479), (4294967295, 324)):
            gen = Generator(MT19937(seed))
            assert_equal(gen.integers(1000), expected)

    def test_array(self):
        # A range and the equivalent ndarray agree, and one-element lists
        # give the same first draw as the matching scalar seeds above.
        cases = [
            (range(10), 465),
            (np.arange(10), 465),
            ([0], 479),
            ([4294967295], 324),
        ]
        for seed, expected in cases:
            gen = Generator(MT19937(seed))
            assert_equal(gen.integers(1000), expected)

    def test_seedsequence(self):
        bit_gen = MT19937(SeedSequence(0))
        assert_equal(bit_gen.random_raw(1), 2058676884)

    def test_invalid_scalar(self):
        # Non-integer seeds raise TypeError; negative seeds raise ValueError.
        for exc, bad_seed in ((TypeError, -0.5), (ValueError, -1)):
            assert_raises(exc, MT19937, bad_seed)

    def test_invalid_array(self):
        # The same rules apply element-wise to sequence seeds.
        bad_cases = (
            (TypeError, [-0.5]),
            (ValueError, [-1]),
            (ValueError, [1, -2, 4294967296]),
        )
        for exc, bad_seed in bad_cases:
            assert_raises(exc, MT19937, bad_seed)

    def test_noninstantized_bitgen(self):
        # Passing the bit generator class instead of an instance is rejected.
        with assert_raises(ValueError):
            Generator(MT19937)

class TestBinomial:
    """Corner cases of ``Generator.binomial``."""

    def test_n_zero(self):
        # With n == 0 there are no trials, so every draw must be 0 for any
        # p in [0, 1], for scalar and array n alike (issue #3480).
        zeros = np.zeros(2, dtype='int')
        for prob in (0, .5, 1):
            assert_(random.binomial(0, prob) == 0)
            assert_array_equal(random.binomial(zeros, prob), zeros)

    def test_p_is_nan(self):
        # A NaN success probability must be rejected (issue #4571).
        with assert_raises(ValueError):
            random.binomial(1, np.nan)

class TestMultinomial:
    def test_basic(self):
        random.multinomial(100, [0.2, 0.8])

    def test_zero_probability(self):
        random.multinomial(100, [0.2, 0.8, 0.0, 0.0, 0.0])

    def test_int_negative_interval(self):
        assert_(-5 <= random.integers(-5, -1) < -1)
        x = random.integers(-5, -1, 5)
        assert_(np.all(-5 <= x))
        assert_(np.all(x < -1))

    def test_size(self):
        # gh-3173
        p = [0.5, 0.5]
        assert_equal(random.multinomial(1, p, np.uint32(1)).shape, (1, 2))
        assert_equal(random.multinomial(1, p, np.uint32(1)).shape, (1, 2))
        assert_equal(random.multinomial(1, p, np.uint32(1)).shape, (1, 2))
        assert_equal(random.multinomial(1, p, [2, 2]).shape, (2, 2, 2))
        assert_equal(random.multinomial(1, p, (2, 2)).shape, (2, 2, 2))
        assert_equal(random.multinomial(1, p, np.array((2, 2))).shape,
                     (2, 2, 2))

        assert_raises(TypeError, random.multinomial, 1, p,

    def test_invalid_prob(self):
        assert_raises(ValueError, random.multinomial, 100, [1.1, 0.2])
        assert_raises(ValueError, random.multinomial, 100, [-.1, 0.9])

    def test_invalid_n(self):
        assert_raises(ValueError, random.multinomial, -1, [0.8, 0.2])
        assert_raises(ValueError, random.multinomial, [-1] * 10, [0.8, 0.2])

    def test_p_non_contiguous(self):
        p = np.arange(15.)
        p /= np.sum(p[1::3])
        pvals = p[1::3]
        random = Generator(MT19937(1432985819))
        non_contig = random.multinomial(100, pvals=pvals)
        random = Generator(MT19937(1432985819))
        contig = random.multinomial(100, pvals=np.ascontiguousarray(pvals))
        assert_array_equal(non_contig, contig)

    def test_multidimensional_pvals(self):
        assert_raises(ValueError, random.multinomial, 10, [[0, 1]])
        assert_raises(ValueError, random.multinomial, 10, [[0], [1]])
        assert_raises(ValueError, random.multinomial, 10, [[[0], [1]], [[1], [0]]])
        assert_raises(ValueError, random.multinomial, 10, np.array([[0, 1], [1, 0]]))

class TestMultivariateHypergeometric:
    """Tests for ``Generator.multivariate_hypergeometric``."""

    def setup_method(self):
        # pytest runs ``setup_method`` before each test. The nose-style
        # ``setup`` hook is deprecated (pytest 7.2) and no longer called by
        # pytest 8, which would leave ``self.seed`` undefined.
        self.seed = 8675309

    def test_argument_validation(self):
        # Error cases...

        # `colors` must be a 1-d sequence
        assert_raises(ValueError, random.multivariate_hypergeometric,
                      10, 4)

        # Negative nsample
        assert_raises(ValueError, random.multivariate_hypergeometric,
                      [2, 3, 4], -1)

        # Negative color
        assert_raises(ValueError, random.multivariate_hypergeometric,
                      [-1, 2, 3], 2)

        # nsample exceeds sum(colors)
        assert_raises(ValueError, random.multivariate_hypergeometric,
                      [2, 3, 4], 10)

        # nsample exceeds sum(colors) (edge case of empty colors)
        assert_raises(ValueError, random.multivariate_hypergeometric,
                      [], 1)

        # Validation errors associated with very large values in colors.
        # NOTE(review): sum(colors) here is 10**9 + 100, presumably above
        # the limit the 'marginals' method accepts; confirm against docs.
        assert_raises(ValueError, random.multivariate_hypergeometric,
                      [999999999, 101], 5, 1, 'marginals')

        # sum(colors) is max_int64 // itemsize + 1, which the 'count'
        # method is expected to reject.
        int64_info = np.iinfo(np.int64)
        max_int64 = int64_info.max
        max_int64_index = max_int64 // int64_info.dtype.itemsize
        assert_raises(ValueError, random.multivariate_hypergeometric,
                      [max_int64_index - 100, 101], 5, 1, 'count')

    @pytest.mark.parametrize('method', ['count', 'marginals'])
    def test_edge_cases(self, method):
        # Set the seed, but in fact, all the results in this test are
        # deterministic, so we don't really need this.
        random = Generator(MT19937(self.seed))

        x = random.multivariate_hypergeometric([0, 0, 0], 0, method=method)
        assert_array_equal(x, [0, 0, 0])

        x = random.multivariate_hypergeometric([], 0, method=method)
        assert_array_equal(x, [])

        x = random.multivariate_hypergeometric([], 0, size=1, method=method)
        assert_array_equal(x, np.empty((1, 0), dtype=np.int64))

        x = random.multivariate_hypergeometric([1, 2, 3], 0, method=method)
        assert_array_equal(x, [0, 0, 0])

        x = random.multivariate_hypergeometric([9, 0, 0], 3, method=method)
        assert_array_equal(x, [3, 0, 0])

        # Drawing every item returns the whole population.
        colors = [1, 1, 0, 1, 1]
        x = random.multivariate_hypergeometric(colors, sum(colors),
                                               method=method)
        assert_array_equal(x, colors)

        x = random.multivariate_hypergeometric([3, 4, 5], 12, size=3,
                                               method=method)
        assert_array_equal(x, [[3, 4, 5]]*3)

    # Cases for nsample:
    #     nsample < 10
    #     10 <= nsample < colors.sum()/2
    #     colors.sum()/2 < nsample < colors.sum() - 10
    #     colors.sum() - 10 < nsample < colors.sum()
    @pytest.mark.parametrize('nsample', [8, 25, 45, 55])
    @pytest.mark.parametrize('method', ['count', 'marginals'])
    @pytest.mark.parametrize('size', [5, (2, 3), 150000])
    def test_typical_cases(self, nsample, method, size):
        """Check shape, bounds and row sums of samples for each method/size.

        For the large ``size`` the sample mean is also compared with the
        expected per-color value ``nsample * colors / colors.sum()``.
        """
        random = Generator(MT19937(self.seed))

        colors = np.array([10, 5, 20, 25])
        sample = random.multivariate_hypergeometric(colors, nsample, size,
                                                    method=method)
        if isinstance(size, int):
            expected_shape = (size,) + colors.shape
        else:
            expected_shape = size + colors.shape
        assert_equal(sample.shape, expected_shape)
        assert_((sample >= 0).all())
        assert_((sample <= colors).all())
        # Each variate draws exactly nsample items across all colors.
        assert_array_equal(sample.sum(axis=-1),
                           np.full(size, fill_value=nsample, dtype=int))
        if isinstance(size, int) and size >= 100000:
            # This sample is large enough to compare its mean to
            # the expected values.
            assert_allclose(sample.mean(axis=0),
                            nsample * colors / colors.sum(),
                            rtol=1e-3, atol=0.005)

    def test_repeatability1(self):
        random = Generator(MT19937(self.seed))
        sample = random.multivariate_hypergeometric([3, 4, 5], 5, size=5,
                                                    method='count')
        expected = np.array([[2, 1, 2],
                             [2, 1, 2],
                             [1, 1, 3],
                             [2, 0, 3],
                             [2, 1, 2]])
        assert_array_equal(sample, expected)

    def test_repeatability2(self):
        random = Generator(MT19937(self.seed))
        sample = random.multivariate_hypergeometric([20, 30, 50], 50,
                                                    size=5,
                                                    method='marginals')
        expected = np.array([[ 9, 17, 24],
                             [ 7, 13, 30],
                             [ 9, 15, 26],
                             [ 9, 17, 24],
                             [12, 14, 24]])
        assert_array_equal(sample, expected)

    def test_repeatability3(self):
        random = Generator(MT19937(self.seed))
        sample = random.multivariate_hypergeometric([20, 30, 50], 12,
                                                    size=5,
                                                    method='marginals')
        expected = np.array([[2, 3, 7],
                             [5, 3, 4],
                             [2, 5, 5],
                             [5, 3, 4],
                             [1, 5, 6]])
        assert_array_equal(sample, expected)

class TestSetState:
    """Restoring ``bit_generator.state`` replays the same output stream."""

    def setup_method(self):
        # pytest runs ``setup_method`` before each test. The nose-style
        # ``setup`` hook is no longer called by pytest 8.
        self.seed = 1234567890
        self.rg = Generator(MT19937(self.seed))
        self.bit_generator = self.rg.bit_generator
        self.state = self.bit_generator.state
        # Tuple of (bit generator name, key, pos) taken from the state dict.
        self.legacy_state = (self.state['bit_generator'],
                             self.state['state']['key'],
                             self.state['state']['pos'])

    def test_gaussian_reset(self):
        # Make sure the cached every-other-Gaussian is reset.
        old = self.rg.standard_normal(size=3)
        self.bit_generator.state = self.state
        new = self.rg.standard_normal(size=3)
        assert_(np.all(old == new))

    def test_gaussian_reset_in_media_res(self):
        # When the state is saved with a cached Gaussian, make sure the
        # cached Gaussian is restored.

        # Draw once first so the state is saved mid-stream rather than at
        # the freshly seeded starting point.
        self.rg.standard_normal()
        state = self.bit_generator.state
        old = self.rg.standard_normal(size=3)
        self.bit_generator.state = state
        new = self.rg.standard_normal(size=3)
        assert_(np.all(old == new))

    def test_negative_binomial(self):
        # Ensure that the negative binomial results take floating point
        # arguments without truncation.
        self.rg.negative_binomial(0.5, 0.5)

class TestIntegers:
    rfunc = random.integers

    # valid integer/boolean types
    itype = [bool, np.int8, np.uint8, np.int16, np.uint16,
             np.int32, np.uint32, np.int64, np.uint64]

    def test_unsupported_type(self, endpoint):
        # A floating-point dtype is not a valid target for integer sampling.
        with assert_raises(TypeError):
            self.rfunc(1, endpoint=endpoint, dtype=float)

    def test_bounds_checking(self, endpoint):
        """Out-of-range or inverted bounds raise ValueError for every dtype.

        Each case is checked with scalar bounds and again with list bounds.
        """
        for dt in self.itype:
            lbnd = 0 if dt is bool else np.iinfo(dt).min
            ubnd = 2 if dt is bool else np.iinfo(dt).max + 1
            # With endpoint=True the upper bound is inclusive, so the
            # largest valid ``high`` is one smaller.
            ubnd = ubnd - 1 if endpoint else ubnd
            assert_raises(ValueError, self.rfunc, lbnd - 1, ubnd,
                          endpoint=endpoint, dtype=dt)
            assert_raises(ValueError, self.rfunc, lbnd, ubnd + 1,
                          endpoint=endpoint, dtype=dt)
            assert_raises(ValueError, self.rfunc, ubnd, lbnd,
                          endpoint=endpoint, dtype=dt)
            assert_raises(ValueError, self.rfunc, 1, 0, endpoint=endpoint,
                          dtype=dt)

            # Same cases with array-like bounds.
            assert_raises(ValueError, self.rfunc, [lbnd - 1], ubnd,
                          endpoint=endpoint, dtype=dt)
            assert_raises(ValueError, self.rfunc, [lbnd], [ubnd + 1],
                          endpoint=endpoint, dtype=dt)
            assert_raises(ValueError, self.rfunc, [ubnd], [lbnd],
                          endpoint=endpoint, dtype=dt)
            assert_raises(ValueError, self.rfunc, 1, [0],
                          endpoint=endpoint, dtype=dt)
