"""Test the numpy pickler as a replacement of the standard pickler."""

import copy
import os
import random
import sys
import re
import io
import warnings
import gzip
import zlib
import bz2
import pickle
import socket
from contextlib import closing
import mmap

from joblib.test.common import np, with_numpy, with_lz4, without_lz4
from joblib.test.common import with_memory_profiler, memory_used
from joblib.testing import parametrize, raises, SkipTest, warns

# numpy_pickle is not a drop-in replacement for pickle, as it takes
# filenames instead of open file objects as arguments (see the
# illustrative snippet after the imports below).
from joblib import numpy_pickle, register_compressor
from joblib.test import data

from joblib._compat import PY3_OR_LATER
from joblib.numpy_pickle_utils import _IO_BUFFER_SIZE
from joblib.numpy_pickle_utils import _detect_compressor
from joblib.compressor import (_COMPRESSORS, _LZ4_PREFIX, CompressorWrapper,
                               LZ4_NOT_INSTALLED_ERROR, BinaryZlibFile)
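
# Illustrative sketch of that API difference (comments only, not run as
# part of the suite; '/tmp/out.pkl' is a placeholder path):
#
#     with open('/tmp/out.pkl', 'wb') as f:
#         pickle.dump(obj, f)                 # stdlib pickle: open file
#     numpy_pickle.dump(obj, '/tmp/out.pkl')  # joblib: filename
#     obj2 = numpy_pickle.load('/tmp/out.pkl')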

###############################################################################
# Define a list of standard types.
# Borrowed from dill, initial author: Michael McKerns:
# http://dev.danse.us/trac/pathos/browser/dill/dill_test2.py

typelist = []

# testing types
_none = None
typelist.append(_none)
_type = type
typelist.append(_type)
_bool = bool(1)
typelist.append(_bool)
_int = int(1)
typelist.append(_int)
try:
    _long = long(1)
    typelist.append(_long)
except NameError:
    # long is not defined in Python 3
    pass
_float = float(1)
typelist.append(_float)
_complex = complex(1)
typelist.append(_complex)
_string = str(1)
typelist.append(_string)
try:
    _unicode = unicode(1)
    typelist.append(_unicode)
except NameError:
    # unicode is not defined in python 3
    pass
_tuple = ()
typelist.append(_tuple)
_list = []
typelist.append(_list)
_dict = {}
typelist.append(_dict)
try:
    _file = file
    typelist.append(_file)
except NameError:
    pass  # file does not exist in Python 3
try:
    _buffer = buffer
    typelist.append(_buffer)
except NameError:
    # buffer does not exist in Python 3
    pass
_builtin = len
typelist.append(_builtin)


def _function(x):
    yield x


class _class:
    def _method(self):
        pass


class _newclass(object):
    def _method(self):
        pass


typelist.append(_function)
typelist.append(_class)
typelist.append(_newclass)  # <type 'type'>
_instance = _class()
typelist.append(_instance)
_object = _newclass()
typelist.append(_object)  # <type 'class'>


###############################################################################
# Tests

@parametrize('compress', [0, 1])
@parametrize('member', typelist)
def test_standard_types(tmpdir, compress, member):
    # Test pickling and saving with standard types.
    filename = tmpdir.join('test.pkl').strpath
    numpy_pickle.dump(member, filename, compress=compress)
    _member = numpy_pickle.load(filename)
    # We compare the pickled instance to the reloaded one only if it
    # can be compared to a copied one
    if member == copy.deepcopy(member):
        assert member == _member


def test_value_error():
    # Test inverting the input arguments to dump
    with raises(ValueError):
        numpy_pickle.dump('foo', dict())


@parametrize('wrong_compress', [-1, 10, dict()])
def test_compress_level_error(wrong_compress):
    # Verify that passing an invalid compress argument raises an error.
    exception_msg = ('Non valid compress level given: '
                     '"{0}"'.format(wrong_compress))
    with raises(ValueError) as excinfo:
        numpy_pickle.dump('dummy', 'foo', compress=wrong_compress)
    excinfo.match(exception_msg)
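
# For reference, joblib accepts compress as a bool, an int in [0, 9], a
# compressor name such as 'zlib' or 'gzip', or a (name, level) tuple,
# e.g. numpy_pickle.dump(obj, filename, compress=('zlib', 3)); anything
# else raises the ValueError checked above.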


@with_numpy
@parametrize('compress', [False, True, 0, 3, 'zlib'])
def test_numpy_persistence(tmpdir, compress):
    filename = tmpdir.join('test.pkl').strpath
    rnd = np.random.RandomState(0)
    a = rnd.random_sample((10, 2))
    # We use 'a.T' to have a non C-contiguous array.
    for index, obj in enumerate(((a,), (a.T,), (a, a), [a, a, a])):
        filenames = numpy_pickle.dump(obj, filename, compress=compress)

        # All is cached in one file
        assert len(filenames) == 1
        # Check that only one file was created
        assert filenames[0] == filename
        # Check that this file does exist
        assert os.path.exists(filenames[0])

        # Unpickle the object
        obj_ = numpy_pickle.load(filename)
        # Check that the items are indeed arrays
        for item in obj_:
            assert isinstance(item, np.ndarray)
        # And finally, check that all the values are equal.
        np.testing.assert_array_equal(np.array(obj), np.array(obj_))

    # Now test with array subclasses
    for obj in (np.matrix(np.zeros(10)),
                np.memmap(filename + 'mmap',
                          mode='w+', shape=4, dtype=np.float)):
        filenames = numpy_pickle.dump(obj, filename, compress=compress)
        # All is cached in one file
        assert len(filenames) == 1

        obj_ = numpy_pickle.load(filename)
        if (type(obj) is not np.memmap and
                hasattr(obj, '__array_prepare__')):
            # We don't reconstruct memmaps
            assert isinstance(obj_, type(obj))

        np.testing.assert_array_equal(obj_, obj)

    # Test with an object containing multiple numpy arrays
    obj = ComplexTestObject()
    filenames = numpy_pickle.dump(obj, filename, compress=compress)
    # All is cached in one file
    assert len(filenames) == 1

    obj_loaded = numpy_pickle.load(filename)
    assert isinstance(obj_loaded, type(obj))
    np.testing.assert_array_equal(obj_loaded.array_float, obj.array_float)
    np.testing.assert_array_equal(obj_loaded.array_int, obj.array_int)
    np.testing.assert_array_equal(obj_loaded.array_obj, obj.array_obj)


@with_numpy
def test_numpy_persistence_buffered_array_compression(tmpdir):
    big_array = np.ones((_IO_BUFFER_SIZE + 100), dtype=np.uint8)
    filename = tmpdir.join('test.pkl').strpath
    numpy_pickle.dump(big_array, filename, compress=True)
    arr_reloaded = numpy_pickle.load(filename)

    np.testing.assert_array_equal(big_array, arr_reloaded)


@with_numpy
def test_memmap_persistence(tmpdir):
    rnd = np.random.RandomState(0)
    a = rnd.random_sample(10)
    filename = tmpdir.join('test1.pkl').strpath
    numpy_pickle.dump(a, filename)
    b = numpy_pickle.load(filename, mmap_mode='r')

    assert isinstance(b, np.memmap)

    # Test with an object containing multiple numpy arrays
    filename = tmpdir.join('test2.pkl').strpath
    obj = ComplexTestObject()
    numpy_pickle.dump(obj, filename)
    obj_loaded = numpy_pickle.load(filename, mmap_mode='r')
    assert isinstance(obj_loaded, type(obj))
    assert isinstance(obj_loaded.array_float, np.memmap)
    assert not obj_loaded.array_float.flags.writeable
    assert isinstance(obj_loaded.array_int, np.memmap)
    assert not obj_loaded.array_int.flags.writeable
    # Memory map not allowed for numpy object arrays
    assert not isinstance(obj_loaded.array_obj, np.memmap)
    np.testing.assert_array_equal(obj_loaded.array_float,
                                  obj.array_float)
    np.testing.assert_array_equal(obj_loaded.array_int,
                                  obj.array_int)
    np.testing.assert_array_equal(obj_loaded.array_obj,
                                  obj.array_obj)

    # Test we can write in memmapped arrays
    obj_loaded = numpy_pickle.load(filename, mmap_mode='r+')
    assert obj_loaded.array_float.flags.writeable
    obj_loaded.array_float[0:10] = 10.0
    assert obj_loaded.array_int.flags.writeable
    obj_loaded.array_int[0:10] = 10

    obj_reloaded = numpy_pickle.load(filename, mmap_mode='r')
    np.testing.assert_array_equal(obj_reloaded.array_float,
                                  obj_loaded.array_float)
    np.testing.assert_array_equal(obj_reloaded.array_int,
                                  obj_loaded.array_int)

    # Test that 'w+' mode is caught and switched to 'r+'
    numpy_pickle.load(filename, mmap_mode='w+')
    assert obj_loaded.array_int.flags.writeable
    assert obj_loaded.array_int.mode == 'r+'
    assert obj_loaded.array_float.flags.writeable
    assert obj_loaded.array_float.mode == 'r+'
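
    # Note: mmap_mode follows numpy.load semantics ('r', 'r+', 'c', 'w+');
    # as checked above, joblib switches 'w+' to 'r+' when loading so the
    # pickle being read is not truncated.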


@with_numpy
def test_memmap_persistence_mixed_dtypes(tmpdir):
    # loading datastructures that have sub-arrays with dtype=object
    # should not prevent memmapping on fixed size dtype sub-arrays.
    rnd = np.random.RandomState(0)
    a = rnd.random_sample(10)
    b = np.array([1, 'b'], dtype=object)
    construct = (a, b)
    filename = tmpdir.join('test.pkl').strpath
    numpy_pickle.dump(construct, filename)
    a_clone, b_clone = numpy_pickle.load(filename, mmap_mode='r')

    # the floating point array has been memory mapped
    assert isinstance(a_clone, np.memmap)

    # the object-dtype array has been loaded in memory
    assert not isinstance(b_clone, np.memmap)


@with_numpy
def test_masked_array_persistence(tmpdir):
    # The special-case pickler does not implement saving masked arrays,
    # so it delegates to the standard pickler.
    rnd = np.random.RandomState(0)
    a = rnd.random_sample(10)
    a = np.ma.masked_greater(a, 0.5)
    filename = tmpdir.join('test.pkl').strpath
    numpy_pickle.dump(a, filename)
    b = numpy_pickle.load(filename, mmap_mode='r')
    assert isinstance(b, np.ma.masked_array)


@with_numpy
def test_compress_mmap_mode_warning(tmpdir):
    # Test the warning in case of compress + mmap_mode
    rnd = np.random.RandomState(0)
    a = rnd.random_sample(10)
    this_filename = tmpdir.join('test.pkl').strpath
    numpy_pickle.dump(a, this_filename, compress=1)
    with warns(UserWarning) as warninfo:
        numpy_pickle.load(this_filename, mmap_mode='r+')
    assert len(warninfo) == 1
    assert (str(warninfo[0].message) ==
            'mmap_mode "%(mmap_mode)s" is not compatible with compressed '
            'file %(filename)s. "%(mmap_mode)s" flag will be ignored.' %
            {'filename': this_filename, 'mmap_mode': 'r+'})


@with_numpy
@parametrize('cache_size', [None, 0, 10])
def test_cache_size_warning(tmpdir, cache_size):
    # Check deprecation warning raised when cache size is not None
    filename = tmpdir.join('test.pkl').strpath
    rnd = np.random.RandomState(0)
    a = rnd.random_sample((10, 2))

    warnings.simplefilter("always")
    with warns(None) as warninfo:
        numpy_pickle.dump(a, filename, cache_size=cache_size)
    expected_nb_warnings = 1 if cache_size is not None else 0
    assert len(warninfo) == expected_nb_warnings
    for w in warninfo:
        assert w.category == DeprecationWarning
        assert (str(w.message) ==
                "Please do not set 'cache_size' in joblib.dump, this "
                "parameter has no effect and will be removed. You "
                "used 'cache_size={0}'".format(cache_size))


@with_numpy
@with_memory_profiler
@parametrize('compress', [True, False])
def test_memory_usage(tmpdir, compress):
    # Verify memory stays within expected bounds.
    filename = tmpdir.join('test.pkl').strpath
    small_array = np.ones((10, 10))
    big_array = np.ones(shape=100 * int(1e6), dtype=np.uint8)
    small_matrix = np.matrix(small_array)
    big_matrix = np.matrix(big_array)

    for obj in (small_array, big_array, small_matrix, big_matrix):
        size = obj.nbytes / 1e6
        obj_filename = filename + str(np.random.randint(0, 1000))
        mem_used = memory_used(numpy_pickle.dump,
                               obj, obj_filename, compress=compress)

        # The memory used to dump the object shouldn't exceed the buffer
        # size used to write array chunks (16MB).
        write_buf_size = _IO_BUFFER_SIZE + 16 * 1024 ** 2 / 1e6
        assert mem_used <= write_buf_size