"""
Test the parallel module.
"""

# Author: Gael Varoquaux <gael dot varoquaux at normalesup dot org>
# Copyright (c) 2010-2011 Gael Varoquaux
# License: BSD Style, 3 clauses.

import os
import sys
import time
import mmap
import threading
from traceback import format_exception
from math import sqrt
from time import sleep
from pickle import PicklingError
from multiprocessing import TimeoutError
import pytest

import joblib
from joblib import parallel
from joblib import dump, load
from joblib.externals.loky import get_reusable_executor

from joblib.test.common import np, with_numpy
from joblib.test.common import with_multiprocessing
from joblib.testing import (parametrize, raises, check_subprocess_call,
                            skipif, SkipTest, warns)
from joblib._compat import PY3_OR_LATER, PY27

try:
    import cPickle as pickle
except ImportError:
    import pickle

try:
    from queue import Queue
except ImportError:
    # Backward compat
    from Queue import Queue

try:
    import posix
except ImportError:
    posix = None

try:
    RecursionError
except NameError:
    RecursionError = RuntimeError

try:
    reload         # Python 2
except NameError:  # Python 3
    from importlib import reload

try:
    from ._openmp_test_helper.parallel_sum import parallel_sum
except ImportError:
    parallel_sum = None

try:
    import distributed
except ImportError:
    distributed = None

from joblib._parallel_backends import SequentialBackend
from joblib._parallel_backends import ThreadingBackend
from joblib._parallel_backends import MultiprocessingBackend
from joblib._parallel_backends import ParallelBackendBase
from joblib._parallel_backends import LokyBackend
from joblib._parallel_backends import SafeFunction

from joblib.parallel import Parallel, delayed
from joblib.parallel import register_parallel_backend, parallel_backend
from joblib.parallel import effective_n_jobs, cpu_count

from joblib.parallel import mp, BACKENDS, DEFAULT_BACKEND, EXTERNAL_BACKENDS
from joblib.my_exceptions import JoblibException
from joblib.my_exceptions import TransportableException
from joblib.my_exceptions import JoblibValueError
from joblib.my_exceptions import WorkerInterrupt


ALL_VALID_BACKENDS = [None] + sorted(BACKENDS.keys())
# Add instances of backend classes deriving from ParallelBackendBase
ALL_VALID_BACKENDS += [BACKENDS[backend_str]() for backend_str in BACKENDS]
PROCESS_BACKENDS = ['multiprocessing', 'loky']
PARALLEL_BACKENDS = PROCESS_BACKENDS + ['threading']

if hasattr(mp, 'get_context'):
    # Custom multiprocessing context in Python 3.4+
    ALL_VALID_BACKENDS.append(mp.get_context('spawn'))

DefaultBackend = BACKENDS[DEFAULT_BACKEND]
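

# A minimal usage sketch (not part of the original suite) of the pattern
# exercised throughout this file: ``delayed`` wraps a call into a picklable
# (function, args, kwargs) triple and ``Parallel`` dispatches those triples
# to the active backend.
def _parallel_usage_sketch():
    return Parallel(n_jobs=2)(delayed(sqrt)(i ** 2) for i in range(10))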


def get_workers(backend):
    # Return the worker pool of a backend, whatever its attribute name
    # (``_pool`` for the multiprocessing/threading backends, ``_workers``
    # for loky); None if the backend has neither.
    return getattr(backend, '_pool', getattr(backend, '_workers', None))


def division(x, y):
    return x / y


def square(x):
    return x ** 2


class MyExceptionWithFinickyInit(Exception):
    """An exception class with a non-trivial __init__
    """
    def __init__(self, a, b, c, d):
        pass


def exception_raiser(x, custom_exception=False):
    if x == 7:
        raise (MyExceptionWithFinickyInit('a', 'b', 'c', 'd')
               if custom_exception else ValueError)
    return x


def interrupt_raiser(x):
    time.sleep(.05)
    raise KeyboardInterrupt


def f(x, y=0, z=0):
    """ A module-level function so that it can be spawn with
    multiprocessing.
    """
    return x ** 2 + y + z


def _active_backend_type():
    return type(parallel.get_active_backend()[0])
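

# A small sketch (not from the original suite) of the helper above:
# ``parallel.get_active_backend()`` returns a ``(backend, n_jobs)`` pair,
# so the active backend type can be checked, e.g. under a
# ``parallel_backend`` context manager.
def _active_backend_sketch():
    with parallel_backend('threading'):
        assert _active_backend_type() == ThreadingBackend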


def parallel_func(inner_n_jobs, backend):
    return Parallel(n_jobs=inner_n_jobs, backend=backend)(
        delayed(square)(i) for i in range(3))
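

# A hedged sketch of steering nested calls without passing ``backend``
# explicitly: ``parallel_backend`` installs a backend (and a default
# n_jobs) for every ``Parallel`` call in its scope.
def _nested_backend_sketch():
    with parallel_backend('threading', n_jobs=2):
        # picks up the threading backend from the surrounding context
        return Parallel()(delayed(square)(i) for i in range(3))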


###############################################################################
def test_cpu_count():
    assert cpu_count() > 0


def test_effective_n_jobs():
    assert effective_n_jobs() > 0


###############################################################################
# Test parallel

@parametrize('backend', ALL_VALID_BACKENDS)
@parametrize('n_jobs', [1, 2, -1, -2])
@parametrize('verbose', [2, 11, 100])
def test_simple_parallel(backend, n_jobs, verbose):
    assert ([square(x) for x in range(5)] ==
            Parallel(n_jobs=n_jobs, backend=backend,
                     verbose=verbose)(
                delayed(square)(x) for x in range(5)))


@parametrize('backend', ALL_VALID_BACKENDS)
def test_main_thread_renamed_no_warning(backend, monkeypatch):
    # Check that no default backend relies on the name of the main thread:
    # https://github.com/joblib/joblib/issues/180#issuecomment-253266247
    # Some programs use a different name for the main thread. This is the case
    # for uWSGI apps for instance.
    monkeypatch.setattr(target=threading.current_thread(), name='name',
                        value='some_new_name_for_the_main_thread')

    with warns(None) as warninfo:
        results = Parallel(n_jobs=2, backend=backend)(
            delayed(square)(x) for x in range(3))
        assert results == [0, 1, 4]

    # Due to the default parameters of LokyBackend, there is a chance that
    # warninfo catches warnings caused by worker timeouts. We remove them if
    # they exist.

    # The multiprocessing backend will raise a warning when detecting that it
    # is started from the non-main thread. Let's check that there is no false
    # positive because of the name change.
    assert len(warninfo) == 0


def _assert_warning_nested(backend, inner_n_jobs, expected):
    with warns(None) as records:
        parallel_func(backend=backend, inner_n_jobs=inner_n_jobs)

    if expected:
        # with threading, we might see more than one record
        if len(records) > 0:
            return 'backed parallel loops cannot' in records[0].message.args[0]
        return False
    else:
        assert len(records) == 0
        return True


@with_multiprocessing
@parametrize('parent_backend,child_backend,expected', [
    ('loky', 'multiprocessing', True), ('loky', 'loky', False),
    ('multiprocessing', 'multiprocessing', True),
    ('multiprocessing', 'loky', True),
    ('threading', 'multiprocessing', True),
    ('threading', 'loky', True),
])
def test_nested_parallel_warnings(parent_backend, child_backend, expected):

    # no warnings if inner_n_jobs=1
    Parallel(n_jobs=2, backend=parent_backend)(
        delayed(_assert_warning_nested)(
            backend=child_backend, inner_n_jobs=1,
            expected=False)
        for _ in range(5))

    # warnings are expected if inner_n_jobs != 1 and expected is True
    res = Parallel(n_jobs=2, backend=parent_backend)(
        delayed(_assert_warning_nested)(
            backend=child_backend, inner_n_jobs=2,
            expected=expected)
        for _ in range(5))

    # warning handling is not thread safe. One thread might see multiple
    # warnings or no warning at all.
    if parent_backend == "threading":
        assert any(res)
    else:
        assert all(res)


@with_multiprocessing
@parametrize('backend', ['loky', 'multiprocessing', 'threading'])
def test_background_thread_parallelism(backend):
    is_run_parallel = [False]

    def background_thread(is_run_parallel):
        with warns(None) as records:
            Parallel(n_jobs=2)(
                delayed(sleep)(.1) for _ in range(4))
        print(len(records))
        is_run_parallel[0] = len(records) == 0

    t = threading.Thread(target=background_thread, args=(is_run_parallel,))
    t.start()
    t.join()
    assert is_run_parallel[0]


def nested_loop(backend):
    Parallel(n_jobs=2, backend=backend)(
        delayed(square)(.01) for _ in range(2))


@parametrize('child_backend', BACKENDS)
@parametrize('parent_backend', BACKENDS)
def test_nested_loop(parent_backend, child_backend):
    Parallel(n_jobs=2, backend=parent_backend)(
        delayed(nested_loop)(child_backend) for _ in range(2))


def raise_exception(backend):
    raise ValueError


def test_nested_loop_with_exception_with_loky():
    with raises(ValueError):
        with Parallel(n_jobs=2, backend="loky") as parallel:
            parallel([delayed(nested_loop)("loky"),
                      delayed(raise_exception)("loky")])


def test_mutate_input_with_threads():
    """Input is mutable when using the threading backend"""
    q = Queue(maxsize=5)
    Parallel(n_jobs=2, backend="threading")(
        delayed(q.put)(1) for _ in range(5))
    assert q.full()
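

# By contrast, a sketch (not from the original suite): process-based
# backends such as 'loky' send pickled *copies* of the inputs to the
# workers, so in-place mutations stay invisible to the parent process.
def _append_one_sketch(shared_list):
    shared_list.append(1)


def _no_mutation_with_processes_sketch():
    shared = []
    Parallel(n_jobs=2, backend='loky')(
        delayed(_append_one_sketch)(shared) for _ in range(5))
    assert shared == []  # workers mutated their own pickled copies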


@parametrize('n_jobs', [1, 2, 3])
def test_parallel_kwargs(n_jobs):
    """Check the keyword argument processing of pmap."""
    lst = range(10)
    assert ([f(x, y=1) for x in lst] ==
            Parallel(n_jobs=n_jobs)(delayed(f)(x, y=1) for x in lst))


@parametrize('backend', PARALLEL_BACKENDS)
def test_parallel_as_context_manager(backend):
    lst = range(10)
    expected = [f(x, y=1) for x in lst]

    with Parallel(n_jobs=4, backend=backend) as p:
        # Internally a pool instance has been eagerly created and is managed
        # via the context manager protocol
        managed_backend = p._backend

        # We make call with the managed parallel object several times inside
        # the managed block:
        assert expected == p(delayed(f)(x, y=1) for x in lst)
        assert expected == p(delayed(f)(x, y=1) for x in lst)

        # Those calls have all used the same pool instance:
        if mp is not None:
            assert get_workers(managed_backend) is get_workers(p._backend)

    # As soon as we exit the context manager block, the pool is terminated and
    # no longer referenced from the parallel object:
    if mp is not None:
        assert get_workers(p._backend) is None

    # It's still possible to use the parallel instance in non-managed mode:
    assert expected == p(delayed(f)(x, y=1) for x in lst)
    if mp is not None:
        assert get_workers(p._backend) is None


@with_multiprocessing
def test_parallel_pickling():
    """ Check that pmap captures the errors when it is passed an object
        that cannot be pickled.
    """
    class UnpicklableObject(object):
        def __reduce__(self):
            raise RuntimeError('123')

    with raises(PicklingError, match=r"the task to send"):
        Parallel(n_jobs=2)(delayed(id)(UnpicklableObject()) for _ in range(10))


@parametrize('backend', PARALLEL_BACKENDS)
def test_parallel_timeout_success(backend):
    # Check that no TimeoutError is raised when the function is fast enough
    assert len(Parallel(n_jobs=2, backend=backend, timeout=10)(
        delayed(sleep)(0.001) for x in range(10))) == 10
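

# A counterpart sketch (not part of the original suite): when tasks exceed
# ``timeout``, collecting their results raises multiprocessing.TimeoutError
# with process-based backends.
def _timeout_failure_sketch():
    with raises(TimeoutError):
        Parallel(n_jobs=2, backend='loky', timeout=0.01)(
            delayed(sleep)(10) for _ in range(10))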


@with_multiprocessing