Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

alkaline-ml / numpy   python

Repository URL to install this package:

Version: 1.19.1 

/ random / tests / test_randomstate.py

import hashlib
import pickle
import sys
import warnings

import numpy as np
import pytest
from numpy.testing import (
        assert_, assert_raises, assert_equal, assert_warns,
        assert_no_warnings, assert_array_equal, assert_array_almost_equal,
        suppress_warnings
        )

from numpy.random import MT19937, PCG64
from numpy import random

# Integer-valued distributions exercised through the ``int_func`` fixture,
# each mapped to the parameter tuple that fixture hands out (presumably
# unpacked as the sampler's positional arguments; the consumer is not in view).
INT_FUNCS = dict(
    binomial=(100.0, 0.6),
    geometric=(.5,),
    hypergeometric=(20, 20, 10),
    logseries=(.5,),
    multinomial=(20, np.ones(6) / 6.0),
    negative_binomial=(100, .5),
    poisson=(10.0,),
    zipf=(2,),
)

# Reference hex digests for the distributions in INT_FUNCS, looked up by name
# in the int_func fixture.  Two tables are kept because the expected digests
# depend on the width of NumPy's default integer, probed via np.iinfo(int).
if np.iinfo(int).max < 2**32:
    # Default integer is 32 bits or narrower.
    # Windows and some 32-bit platforms, e.g., ARM
    INT_FUNC_HASHES = {'binomial': '670e1c04223ffdbab27e08fbbad7bdba',
                       'logseries': '6bd0183d2f8030c61b0d6e11aaa60caf',
                       'geometric': '6e9df886f3e1e15a643168568d5280c0',
                       'hypergeometric': '7964aa611b046aecd33063b90f4dec06',
                       'multinomial': '68a0b049c16411ed0aa4aff3572431e4',
                       'negative_binomial': 'dc265219eec62b4338d39f849cd36d09',
                       'poisson': '7b4dce8e43552fc82701c2fa8e94dc6e',
                       'zipf': 'fcd2a2095f34578723ac45e43aca48c5',
                       }
else:
    # Default integer wider than 32 bits (typically 64-bit Linux/macOS).
    INT_FUNC_HASHES = {'binomial': 'b5f8dcd74f172836536deb3547257b14',
                       'geometric': '8814571f45c87c59699d62ccd3d6c350',
                       'hypergeometric': 'bc64ae5976eac452115a16dad2dcf642',
                       'logseries': '84be924b37485a27c4a98797bc88a7a4',
                       'multinomial': 'ec3c7f9cf9664044bb0c6fb106934200',
                       'negative_binomial': '210533b2234943591364d0117a552969',
                       'poisson': '0536a8850c79da0c78defd742dccc3e0',
                       'zipf': 'f2841f504dd2525cd67cdcad7561e532',
                       }


@pytest.fixture(scope='module', params=INT_FUNCS)
def int_func(request):
    """Module-scoped fixture parametrized over the keys of ``INT_FUNCS``.

    Returns a ``(name, params, expected_hash)`` triple for the distribution
    named by ``request.param``, drawing ``params`` from ``INT_FUNCS`` and
    ``expected_hash`` from ``INT_FUNC_HASHES``.
    """
    name = request.param
    params = INT_FUNCS[name]
    expected_hash = INT_FUNC_HASHES[name]
    return name, params, expected_hash


def assert_mt19937_state_equal(a, b):
    """Assert that two non-legacy MT19937 state dicts are identical.

    Compares the bit-generator name, the ``key`` and ``pos`` entries of the
    nested ``state`` dict, and the cached-Gaussian fields ``has_gauss`` and
    ``gauss``; raises AssertionError on the first mismatch.
    """
    assert_equal(a['bit_generator'], b['bit_generator'])
    inner_a, inner_b = a['state'], b['state']
    for field in ('key', 'pos'):
        assert_array_equal(inner_a[field], inner_b[field])
    for field in ('has_gauss', 'gauss'):
        assert_equal(a[field], b[field])


class TestSeed:
    """Seeding ``RandomState`` from scalars, sequences and bit generators."""

    def test_scalar(self):
        # The smallest and largest legal 32-bit seeds, each paired with the
        # first randint(1000) draw it is expected to produce.
        for seed, first_draw in ((0, 684), (4294967295, 419)):
            rs = random.RandomState(seed)
            assert_equal(rs.randint(1000), first_draw)

    def test_array(self):
        # Sequence seeds: range and ndarray forms of the same values must
        # agree, and one-element sequences are accepted too.
        cases = (
            (range(10), 468),
            (np.arange(10), 468),
            ([0], 973),
            ([4294967295], 265),
        )
        for seed, first_draw in cases:
            rs = random.RandomState(seed)
            assert_equal(rs.randint(1000), first_draw)

    def test_invalid_scalar(self):
        # seed must be an unsigned 32 bit integer: a float is a TypeError,
        # a negative integer a ValueError.
        for error, seed in ((TypeError, -0.5), (ValueError, -1)):
            assert_raises(error, random.RandomState, seed)

    def test_invalid_array(self):
        # Every element of a sequence seed must be an unsigned 32 bit integer.
        cases = (
            (TypeError, [-0.5]),
            (ValueError, [-1]),
            (ValueError, [4294967296]),
            (ValueError, [1, 2, 4294967296]),
            (ValueError, [1, -2, 4294967296]),
        )
        for error, seed in cases:
            assert_raises(error, random.RandomState, seed)

    def test_invalid_array_shape(self):
        # gh-9832: empty and two-dimensional seed arrays are rejected.
        bad_seeds = (
            np.array([], dtype=np.int64),
            [[1, 2, 3]],
            [[1, 2, 3], [4, 5, 6]],
        )
        for seed in bad_seeds:
            assert_raises(ValueError, random.RandomState, seed)

    def test_cannot_seed(self):
        # A RandomState wrapping a PCG64 bit generator refuses seed().
        legacy = random.RandomState(PCG64(0))
        with assert_raises(TypeError):
            legacy.seed(1234)

    def test_invalid_initialization(self):
        # Passing the MT19937 class itself, not an instance, is an error.
        assert_raises(ValueError, random.RandomState, MT19937)


class TestBinomial:
    """Corner cases of the legacy ``random.binomial`` sampler."""

    def test_n_zero(self):
        # gh-3480: binomial(0, p) must be zero for every p in [0, 1], for
        # both a scalar n and an array of n.
        zero_trials = np.zeros(2, dtype='int')
        for prob in (0, .5, 1):
            assert_(random.binomial(0, prob) == 0)
            assert_array_equal(random.binomial(zero_trials, prob), zero_trials)

    def test_p_is_nan(self):
        # gh-4571: a NaN success probability is rejected.
        with assert_raises(ValueError):
            random.binomial(1, np.nan)


class TestMultinomial:
    """Tests for the legacy ``random.multinomial`` sampler."""

    def test_basic(self):
        # Every one of the n trials lands in exactly one category.
        counts = random.multinomial(100, [0.2, 0.8])
        assert_equal(counts.shape, (2,))
        assert_equal(counts.sum(), 100)

    def test_zero_probability(self):
        # Zero-probability categories are accepted and never receive a count.
        counts = random.multinomial(100, [0.2, 0.8, 0.0, 0.0, 0.0])
        assert_equal(counts.sum(), 100)
        assert_array_equal(counts[2:], 0)

    def test_int_negative_interval(self):
        # NOTE(review): this exercises ``randint``, not ``multinomial``; it
        # looks like it belongs with the randint tests.
        assert_(-5 <= random.randint(-5, -1) < -1)
        x = random.randint(-5, -1, 5)
        assert_(np.all(-5 <= x))
        assert_(np.all(x < -1))

    def test_size(self):
        # gh-3173: ``size`` may be a Python int or a NumPy integer scalar
        # (previously the same np.uint32 case was asserted three times),
        # or a list/tuple/array giving the leading output shape.
        p = [0.5, 0.5]
        for size in (1, np.uint32(1), np.int32(1), np.int64(1)):
            assert_equal(random.multinomial(1, p, size).shape, (1, 2))
        for size in ([2, 2], (2, 2), np.array((2, 2))):
            assert_equal(random.multinomial(1, p, size).shape, (2, 2, 2))

        # A float size is rejected even when it is integral.
        assert_raises(TypeError, random.multinomial, 1, p,
                      float(1))

    def test_invalid_prob(self):
        # [1.1, 0.2] (entry above 1) and [-.1, 0.9] (negative entry) are
        # both rejected.
        assert_raises(ValueError, random.multinomial, 100, [1.1, 0.2])
        assert_raises(ValueError, random.multinomial, 100, [-.1, 0.9])

    def test_invalid_n(self):
        # A negative number of trials is rejected.
        assert_raises(ValueError, random.multinomial, -1, [0.8, 0.2])

    def test_p_non_contiguous(self):
        # A strided (non-contiguous) pvals view must give exactly the same
        # draws as a contiguous copy under the same seed.
        p = np.arange(15.)
        p /= np.sum(p[1::3])
        pvals = p[1::3]
        random.seed(1432985819)
        non_contig = random.multinomial(100, pvals=pvals)
        random.seed(1432985819)
        contig = random.multinomial(100, pvals=np.ascontiguousarray(pvals))
        assert_array_equal(non_contig, contig)


class TestSetState:
    def setup(self):
        self.seed = 1234567890
        self.random_state = random.RandomState(self.seed)
        self.state = self.random_state.get_state()

    def test_basic(self):
        old = self.random_state.tomaxint(16)
        self.random_state.set_state(self.state)
        new = self.random_state.tomaxint(16)
        assert_(np.all(old == new))

    def test_gaussian_reset(self):
        # Make sure the cached every-other-Gaussian is reset.
        old = self.random_state.standard_normal(size=3)
        self.random_state.set_state(self.state)
        new = self.random_state.standard_normal(size=3)
        assert_(np.all(old == new))

    def test_gaussian_reset_in_media_res(self):
        # When the state is saved with a cached Gaussian, make sure the
        # cached Gaussian is restored.

        self.random_state.standard_normal()
        state = self.random_state.get_state()
        old = self.random_state.standard_normal(size=3)
        self.random_state.set_state(state)
        new = self.random_state.standard_normal(size=3)
        assert_(np.all(old == new))

    def test_backwards_compatibility(self):
        # Make sure we can accept old state tuples that do not have the
        # cached Gaussian value.
        old_state = self.state[:-2]
        x1 = self.random_state.standard_normal(size=16)
        self.random_state.set_state(old_state)
        x2 = self.random_state.standard_normal(size=16)
        self.random_state.set_state(self.state)
        x3 = self.random_state.standard_normal(size=16)
        assert_(np.all(x1 == x2))
        assert_(np.all(x1 == x3))

    def test_negative_binomial(self):
        # Ensure that the negative binomial results take floating point
        # arguments without truncation.
        self.random_state.negative_binomial(0.5, 0.5)

    def test_get_state_warning(self):
        rs = random.RandomState(PCG64())
        with suppress_warnings() as sup:
            w = sup.record(RuntimeWarning)
            state = rs.get_state()
            assert_(len(w) == 1)
            assert isinstance(state, dict)
            assert state['bit_generator'] == 'PCG64'

    def test_invalid_legacy_state_setting(self):
        state = self.random_state.get_state()
        new_state = ('Unknown', ) + state[1:]
        assert_raises(ValueError, self.random_state.set_state, new_state)
        assert_raises(TypeError, self.random_state.set_state,
                      np.array(new_state, dtype=object))
        state = self.random_state.get_state(legacy=False)
        del state['bit_generator']
        assert_raises(ValueError, self.random_state.set_state, state)

    def test_pickle(self):
        self.random_state.seed(0)
        self.random_state.random_sample(100)
        self.random_state.standard_normal()
        pickled = self.random_state.get_state(legacy=False)
        assert_equal(pickled['has_gauss'], 1)
        rs_unpick = pickle.loads(pickle.dumps(self.random_state))
        unpickled = rs_unpick.get_state(legacy=False)
        assert_mt19937_state_equal(pickled, unpickled)

    def test_state_setting(self):
        attr_state = self.random_state.__getstate__()
        self.random_state.standard_normal()
        self.random_state.__setstate__(attr_state)
        state = self.random_state.get_state(legacy=False)
        assert_mt19937_state_equal(attr_state, state)

    def test_repr(self):
        assert repr(self.random_state).startswith('RandomState(MT19937)')


class TestRandint:

    rfunc = random.randint

    # valid integer/boolean types
    itype = [np.bool_, np.int8, np.uint8, np.int16, np.uint16,
             np.int32, np.uint32, np.int64, np.uint64]

    def test_unsupported_type(self):
        assert_raises(TypeError, self.rfunc, 1, dtype=float)

    def test_bounds_checking(self):
        for dt in self.itype:
            lbnd = 0 if dt is np.bool_ else np.iinfo(dt).min
            ubnd = 2 if dt is np.bool_ else np.iinfo(dt).max + 1
            assert_raises(ValueError, self.rfunc, lbnd - 1, ubnd, dtype=dt)
            assert_raises(ValueError, self.rfunc, lbnd, ubnd + 1, dtype=dt)
            assert_raises(ValueError, self.rfunc, ubnd, lbnd, dtype=dt)
            assert_raises(ValueError, self.rfunc, 1, 0, dtype=dt)

    def test_rng_zero_and_extremes(self):
        for dt in self.itype:
            lbnd = 0 if dt is np.bool_ else np.iinfo(dt).min
            ubnd = 2 if dt is np.bool_ else np.iinfo(dt).max + 1

            tgt = ubnd - 1
            assert_equal(self.rfunc(tgt, tgt + 1, size=1000, dtype=dt), tgt)

            tgt = lbnd
            assert_equal(self.rfunc(tgt, tgt + 1, size=1000, dtype=dt), tgt)

            tgt = (lbnd + ubnd)//2
            assert_equal(self.rfunc(tgt, tgt + 1, size=1000, dtype=dt), tgt)

    def test_full_range(self):
        # Test for ticket #1690

        for dt in self.itype:
            lbnd = 0 if dt is np.bool_ else np.iinfo(dt).min
            ubnd = 2 if dt is np.bool_ else np.iinfo(dt).max + 1

            try:
                self.rfunc(lbnd, ubnd, dtype=dt)
            except Exception as e:
                raise AssertionError("No error should have been raised, "
                                     "but one was with the following "
                                     "message:\n\n%s" % str(e))

    def test_in_bounds_fuzz(self):
        # Don't use fixed seed
        random.seed()

        for dt in self.itype[1:]:
            for ubnd in [4, 8, 16]:
                vals = self.rfunc(2, ubnd, size=2**16, dtype=dt)
                assert_(vals.max() < ubnd)
                assert_(vals.min() >= 2)

        vals = self.rfunc(0, 2, size=2**16, dtype=np.bool_)

        assert_(vals.max() < 2)
        assert_(vals.min() >= 0)

    def test_repeatability(self):
        # We use a md5 hash of generated sequences of 1000 samples
        # in the range [0, 6) for all but bool, where the range
        # is [0, 2). Hashes are for little endian numbers.
        tgt = {'bool': '7dd3170d7aa461d201a65f8bcf3944b0',
               'int16': '1b7741b80964bb190c50d541dca1cac1',
               'int32': '4dc9fcc2b395577ebb51793e58ed1a05',
               'int64': '17db902806f448331b5a758d7d2ee672',
               'int8': '27dd30c4e08a797063dffac2490b0be6',
               'uint16': '1b7741b80964bb190c50d541dca1cac1',
               'uint32': '4dc9fcc2b395577ebb51793e58ed1a05',
               'uint64': '17db902806f448331b5a758d7d2ee672',
               'uint8': '27dd30c4e08a797063dffac2490b0be6'}

        for dt in self.itype[1:]:
            random.seed(1234)

            # view as little endian for hash
            if sys.byteorder == 'little':
                val = self.rfunc(0, 6, size=1000, dtype=dt)
            else:
                val = self.rfunc(0, 6, size=1000, dtype=dt).byteswap()

            res = hashlib.md5(val.view(np.int8)).hexdigest()
            assert_(tgt[np.dtype(dt).name] == res)
Loading ...