Why Gemfury? Push, build, and install RubyGems, npm packages, Python packages, Maven artifacts, PHP packages, Go modules, Debian packages, RPM packages, and NuGet packages.

Repository URL to install this package:

Details    
tvault-contego / home / tvault / .virtenv / lib / python2.7 / site-packages / dask / tests / test_threaded.py
Size: Mime:
import os
import sys
import signal
import threading
from multiprocessing.pool import ThreadPool
from time import time, sleep

import pytest

import dask
from dask.compatibility import PY2
from dask.threaded import get
from dask.utils_test import inc, add


def test_get():
    """``get`` computes single keys and lists of keys from a small graph."""
    graph = {'x': 1, 'y': 2, 'z': (inc, 'x'), 'w': (add, 'z', 'y')}

    # Requesting one key yields that key's value directly.
    single = get(graph, 'w')
    assert single == 4

    # Requesting a list of keys yields a tuple, in request order.
    many = get(graph, ['w', 'z'])
    assert many == (4, 2)


def test_nested_get():
    """A task argument may be a list of keys; each key is replaced by its value."""
    graph = {
        'x': 1,
        'y': 2,
        'a': (add, 'x', 'y'),
        'b': (sum, ['x', 'y']),  # the key list is resolved before sum runs
    }
    result = get(graph, ['a', 'b'])
    assert result == (3, 3)


def test_get_without_computation():
    """A key mapped to a plain literal (not a task) comes back unchanged."""
    graph = {'x': 1}
    value = get(graph, 'x')
    assert value == 1


def test_broken_callback():
    """An exception raised by a callback's ``start`` hook propagates out of
    ``get``, even when another callback is registered at the same time."""
    from dask.callbacks import Callback

    def _noop(*args, **kwargs):
        pass

    def _explode(*args, **kwargs):
        raise ValueError('my_exception')

    graph = {'x': 1}

    # Equivalent to two nested ``with`` statements: the outer callback is
    # created and entered first, then the inner one.
    with Callback(start=_explode, finish=_noop), \
            Callback(start=_noop, finish=_noop):
        with pytest.raises(ValueError, match='my_exception'):
            get(graph, 'x')


def bad(x):
    """Task body that ignores ``x`` and always raises ``ValueError``."""
    raise ValueError


def test_exceptions_rise_to_top():
    """An exception raised inside a task is re-raised to the caller of ``get``."""
    graph = {'x': 1, 'y': (bad, 'x')}
    with pytest.raises(ValueError):
        get(graph, 'y')


def test_reuse_pool():
    """A pool installed through ``dask.config.set(pool=...)`` serves several
    consecutive ``get`` calls."""
    pool = ThreadPool()
    try:
        with dask.config.set(pool=pool):
            assert get({'x': (inc, 1)}, 'x') == 2
            assert get({'x': (inc, 1)}, 'x') == 2
    finally:
        # Pool objects only became context managers in Python 3.3, and this
        # file still supports Python 2. So shut the pool down explicitly
        # rather than leaking its worker threads into later tests.
        pool.close()
        pool.join()


@pytest.mark.skipif(PY2, reason="threading API changed")
def test_pool_kwarg():
    """Tasks run on the pool passed via ``pool=``.

    A three-thread pool should produce exactly three distinct thread ids.
    """
    n_tasks = 30

    def thread_id():
        sleep(0.01)  # keep each task busy long enough for work to spread out
        return threading.get_ident()

    graph = dict((('x', k), (thread_id,)) for k in range(n_tasks))
    graph['x'] = (len, (set, [('x', k) for k in range(n_tasks)]))

    with ThreadPool(3) as pool:
        assert get(graph, 'x', pool=pool) == 3


def test_threaded_within_thread():
    """``get`` works when called from a non-main thread.

    Repeated calls must not leave more than a handful of extra threads alive.
    """
    results = []

    def worker(value):
        results.append(get({'x': (lambda: value,)}, 'x', num_workers=2))

    baseline = threading.active_count()

    for _ in range(20):
        runner = threading.Thread(target=worker, args=(1,))
        runner.daemon = True
        runner.start()
        runner.join()
        assert results == [1]
        results[:] = []

    # Allow up to 5 seconds for most of the spawned threads to finish.
    deadline = time() + 5
    while threading.active_count() > baseline + 10:
        sleep(0.01)
        assert time() < deadline


def test_dont_spawn_too_many_threads():
    """Repeated ``get`` calls with ``num_workers=4`` keep the number of live
    threads bounded, and still compute the right answer."""
    before = threading.active_count()

    # Bind ``i`` as a default argument. A bare ``lambda: i`` inside the
    # comprehension late-binds, so every task would return the final value 9.
    dsk = {('x', i): (lambda i=i: i,) for i in range(10)}
    dsk['x'] = (sum, list(dsk))
    expected = sum(range(10))
    for _ in range(20):
        assert get(dsk, 'x', num_workers=4) == expected

    after = threading.active_count()

    assert after <= before + 8


def test_thread_safety():
    """Twenty threads calling ``get`` on one shared graph all get ``1`` back."""
    def constant_one(x):
        return 1

    dsk = {'x': (sleep, 0.05), 'y': (constant_one, 'x')}

    results = []

    def run_get():
        results.append(get(dsk, 'y'))

    workers = [threading.Thread(target=run_get) for _ in range(20)]
    for worker in workers:
        worker.daemon = True
        worker.start()

    for worker in workers:
        worker.join()

    assert results == [1] * 20


@pytest.mark.xfail('xdist' in sys.modules,
                   reason=("This test fails intermittently when using "
                           "pytest-xdist (maybe)"))
def test_interrupt():
    """A KeyboardInterrupt delivered to the main thread aborts a running
    ``get`` well before its 5-second tasks could complete."""
    # Python 2 and Windows (Python 2 & 3) implement `queue.get` using
    # polling, which means we can set an exception to interrupt the call to
    # `get`. Python 3 on other platforms requires sending SIGINT to the main
    # thread.
    if PY2:
        from thread import interrupt_main
    elif os.name == 'nt':
        from _thread import interrupt_main
    else:
        main_thread = threading.get_ident()

        def interrupt_main():
            signal.pthread_kill(main_thread, signal.SIGINT)

    def long_task():
        sleep(5)

    dsk = {('x', i): (long_task,) for i in range(20)}
    dsk['x'] = (len, list(dsk.keys()))

    interrupter = threading.Timer(0.5, interrupt_main)
    # Record the start time before anything can raise, so it is always bound.
    start = time()
    interrupter.start()
    try:
        get(dsk, 'x')
    except KeyboardInterrupt:
        pass
    except Exception as e:
        # Report the actual error instead of swallowing it. pytest.fail also
        # still works under ``python -O``, unlike ``assert False``.
        pytest.fail("Failed to interrupt: unexpected %r" % (e,))
    finally:
        # Never leave a pending timer that could interrupt a later test.
        interrupter.cancel()
        interrupter.join()
    stop = time()
    assert stop - start <= 4, "Failed to interrupt"