# Repository URL to install this package: (not captured)
# Version: 2022.10.0
import pickle
from collections import namedtuple
import pytest
from dask.core import (
flatten,
get,
get_dependencies,
get_deps,
getcycle,
has_tasks,
istask,
literal,
preorder_traversal,
quote,
subs,
)
from dask.utils_test import GetFunctionTestMixin, add, inc
def contains(a, b):
    """Return True when every key/value pair of ``b`` is also present in ``a``.

    A key that is absent from ``a`` never matches, even when the value
    expected by ``b`` is ``None``.

    >>> contains({'x': 1, 'y': 2}, {'x': 1})
    True
    >>> contains({'x': 1, 'y': 2}, {'z': 3})
    False
    >>> contains({'x': 1}, {'z': None})
    False
    """
    # ``a.get(k) == v`` would report a missing key as equal to an expected
    # ``None``; check membership explicitly before comparing values.
    return all(k in a and a[k] == v for k, v in b.items())
def test_istask():
    """Check ``istask`` on a task tuple, a scalar, a data tuple and a namedtuple."""
    record_type = namedtuple("f", ["x", "y"])

    assert istask((inc, 1))
    for non_task in (1, (1, 2), record_type(sum, 2)):
        assert not istask(non_task)
def test_has_tasks():
    """Only the plain-data entry ``"a"`` is reported as having no tasks."""
    graph = {
        "a": [1, 2, 3],
        "b": "a",
        "c": [1, (inc, 1)],
        "d": [(sum, "a")],
        "e": ["a", "b"],
        "f": [["a", "b"], 2, 3],
    }

    assert not has_tasks(graph, graph["a"])
    # Every remaining entry is a key reference, a task, or a nesting of those.
    for key in ("b", "c", "d", "e", "f"):
        assert has_tasks(graph, graph[key])
def test_preorder_traversal():
    """Compare ``preorder_traversal`` output with the expected visit order."""
    cases = [
        # flat task
        ((add, 1, 2), [add, 1, 2]),
        # nested tasks
        ((add, (add, 1, 2), (add, 3, 4)), [add, add, 1, 2, add, 3, 4]),
        # a list argument shows up as the ``list`` type followed by its items
        ((add, (sum, [1, 2]), 3), [add, sum, list, 1, 2, 3]),
    ]
    for task, expected in cases:
        assert list(preorder_traversal(task)) == expected
class TestGet(GetFunctionTestMixin):
    """Run the tests inherited from ``GetFunctionTestMixin`` against ``dask.core.get``."""

    # Wrapped as a staticmethod so the mixin can call ``self.get(...)``
    # without ``self`` being passed as the first argument.
    get = staticmethod(get)
def test_GetFunctionTestMixin_class():
    """The mixin's ``test_get`` fails for a bogus ``get`` and passes for the real one."""

    class TestCustomGetFail(GetFunctionTestMixin):
        # A ``get`` that ignores its inputs and always returns 1.
        get = staticmethod(lambda x, y: 1)

    class TestCustomGetPass(GetFunctionTestMixin):
        get = staticmethod(get)

    failing = TestCustomGetFail()
    with pytest.raises(AssertionError):
        failing.test_get()

    passing = TestCustomGetPass()
    passing.test_get()
def test_get_dependencies_nested():
    """Keys buried inside nested tasks and lists are reported as dependencies."""
    graph = {"x": 1, "y": 2, "z": (add, (inc, [["x"]]), "y")}
    expected = ["x", "y"]

    assert get_dependencies(graph, "z") == set(expected)
    listed = get_dependencies(graph, "z", as_list=True)
    assert sorted(listed) == expected
def test_get_dependencies_empty():
    """A task with no arguments has no dependencies, as a set or as a list."""
    graph = {"x": (inc,)}

    as_set = get_dependencies(graph, "x")
    assert as_set == set()
    as_list = get_dependencies(graph, "x", as_list=True)
    assert as_list == []
def test_get_dependencies_list():
    """Dependencies are found when the graph value is a (nested) list."""
    graph = {"x": 1, "y": 2, "z": ["x", [(inc, "y")]]}
    expected = ["x", "y"]

    assert get_dependencies(graph, "z") == set(expected)
    listed = get_dependencies(graph, "z", as_list=True)
    assert sorted(listed) == expected
def test_get_dependencies_task():
    """Dependencies can be computed for an explicit ``task=`` argument."""
    graph = {"x": 1, "y": 2, "z": ["x", [(inc, "y")]]}

    as_set = get_dependencies(graph, task=(inc, "x"))
    assert as_set == {"x"}
    as_list = get_dependencies(graph, task=(inc, "x"), as_list=True)
    assert as_list == ["x"]
def test_get_dependencies_nothing():
    """Calling ``get_dependencies`` with neither a key nor a task raises ``ValueError``."""
    empty_graph = {}
    with pytest.raises(ValueError):
        get_dependencies(empty_graph)
def test_get_dependencies_many():
    """Dependencies of a list of tasks are combined; an empty list gives none."""
    graph = {
        "a": [1, 2, 3],
        "b": "a",
        "c": [1, (inc, 1)],
        "d": [(sum, "c")],
        "e": ["a", "b", "zzz"],
        "f": [["a", "b"], 2, 3],
    }
    selected = [graph["d"], graph["f"]]

    found = get_dependencies(graph, task=selected)
    assert found == {"a", "b", "c"}

    found = get_dependencies(graph, task=selected, as_list=True)
    assert sorted(found) == ["a", "b", "c"]

    # An empty task list yields an empty result of the requested container type.
    found = get_dependencies(graph, task=[])
    assert found == set()

    found = get_dependencies(graph, task=[], as_list=True)
    assert found == []
def test_get_dependencies_task_none():
    """A ``None`` task has no dependencies."""
    # Regression test for https://github.com/dask/distributed/issues/2756
    graph = {"foo": None}
    found = get_dependencies(graph, task=graph["foo"])
    assert found == set()
def test_get_deps():
    """``get_deps`` returns matching dependency and dependent mappings."""
    graph = {
        "a": [1, 2, 3],
        "b": "a",
        "c": [1, (inc, 1)],
        "d": [(sum, "c")],
        "e": ["b", "zzz", "b"],
        "f": [["a", "b"], 2, 3],
    }
    # "zzz" is not a key of the graph, so it appears in neither mapping.
    expected_dependencies = {
        "a": set(),
        "b": {"a"},
        "c": set(),
        "d": {"c"},
        "e": {"b"},
        "f": {"a", "b"},
    }
    expected_dependents = {
        "a": {"b", "f"},
        "b": {"e", "f"},
        "c": {"d"},
        "d": set(),
        "e": set(),
        "f": set(),
    }

    dependencies, dependents = get_deps(graph)

    assert dependencies == expected_dependencies
    assert dependents == expected_dependents
def test_flatten():
    """``flatten`` yields nothing for an empty tuple and a string as a single item."""
    cases = [((), []), ("foo", ["foo"])]
    for value, expected in cases:
        assert list(flatten(value)) == expected
def test_subs():
    """``subs`` replaces the key inside list arguments, nested or not."""
    cases = [
        ((sum, [1, "x"]), (sum, [1, 2])),
        ((sum, [1, ["x"]]), (sum, [1, [2]])),
    ]
    for task, expected in cases:
        assert subs(task, "x", 2) == expected
class MutateOnEq:
    """Helper whose ``==`` never matches and counts how often it was invoked."""

    # Class-level default; the first comparison creates a per-instance counter.
    hit_eq = 0

    def __eq__(self, other):
        self.hit_eq = self.hit_eq + 1
        return False
def test_subs_no_key_data_eq():
    # Numpy throws a deprecation warning on bool(array == scalar), which
    # pollutes the terminal. This test checks that `subs` never tries to
    # compare keys (scalars) with values (which could be arrays).
    probe = MutateOnEq()
    # Try the probe on its own and as an argument inside a task.
    for task in (probe, (add, probe, "x")):
        subs(task, "x", 1)
        assert probe.hit_eq == 0
def test_subs_with_unfriendly_eq():
    """``subs`` copes with values whose ``==`` is awkward or raises.

    The first check uses a numpy array argument and runs only when numpy is
    installed. The second check uses an object whose ``__eq__`` raises; it
    does not need numpy and therefore always runs.
    """
    try:
        import numpy as np
    except ImportError:
        # numpy is optional: skip only the array check below instead of
        # returning early and silently dropping the remaining assertions.
        np = None

    if np is not None:
        task = (np.sum, np.array([1, 2]))
        assert (subs(task, (4, 5), 1) == task) is True

    class MyException(Exception):
        pass

    class F:
        def __eq__(self, other):
            raise MyException()

    task = F()
    # The raising ``__eq__`` must not surface; the object is returned as-is.
    assert subs(task, 1, 2) is task
def test_subs_with_surprisingly_friendly_eq():
    """``subs`` returns a pandas DataFrame untouched when the key is a string.

    When pandas is not installed the test is reported as skipped, rather
    than passing without having checked anything.
    """
    pd = pytest.importorskip("pandas")

    df = pd.DataFrame()
    assert subs(df, "x", 1) is df
def test_subs_unexpected_hashable_key():
    """``subs`` substitutes a custom hashable key that compares equal by type."""

    class UnexpectedButHashable:
        def __init__(self):
            self.name = "a"

        def __hash__(self):
            return hash(self.name)

        def __eq__(self, other):
            # Any two instances of this class compare equal.
            return isinstance(other, UnexpectedButHashable)

    task = (id, UnexpectedButHashable())
    key = UnexpectedButHashable()
    assert subs(task, key, 1) == (id, 1)
def test_quote():
    """A quoted value is returned by ``get`` unchanged, even if it looks like a task."""
    samples = [
        [1, 2, 3],
        (add, 1, 2),
        [1, [2, 3]],
        (add, 1, (add, 2, 3)),
        {"x": "x"},
    ]
    for value in samples:
        graph = {"x": quote(value)}
        assert get(graph, "x") == value
def test_literal_serializable():
    """A ``literal`` keeps its ``data`` through a pickle round trip."""
    wrapped = literal((add, 1, 2))
    payload = pickle.dumps(wrapped)
    restored = pickle.loads(payload)
    assert restored.data == (add, 1, 2)
def test_getcycle():
    """``getcycle`` returns a cycle no longer than the bound for each graph."""
    first_graph = {
        0: [7, 13, 7, 9, 13, 3, 9, 18, 18, 17],
        1: [14, 14, 12, 1, 9, 16, 4, 5, 9, 8],
        2: [3, 1, 7, 7, 2, 0, 0, 6, 3, 2],
        3: [4, 8, 3, 14, 15, 19, 14, 1, 9, 1],
        4: [9, 13, 19, 4, 16, 8, 11, 1, 16, 1],
        5: [9, 8, 12, 13, 10, 13, 19, 3, 18, 18],
        6: [10, 2, 13, 16, 3, 12, 7, 16, 5, 17],
        7: [16, 8, 6, 4, 4, 10, 3, 1, 13, 10],
        8: [11, 4, 12, 10, 14, 6, 18, 15, 16, 12],
        9: [17, 18, 5, 16, 19, 16, 3, 6, 16, 18],
    }
    cycle = getcycle(first_graph, list(first_graph))
    assert len(cycle) <= 3  # 7->6->7

    second_graph = {
        0: [1, 27, 6, 25, 30, 16, 1, 1, 35, 17],
        1: [35, 22, 21, 31, 2, 28, 36, 0, 29, 29],
        2: [14, 27, 10, 1, 38, 18, 28, 28, 6, 0],
        3: [0, 37, 7, 16, 38, 28, 34, 13, 30, 10],
        4: [22, 22, 27, 13, 29, 36, 22, 9, 39, 19],
        5: [38, 7, 18, 17, 33, 5, 29, 11, 23, 30],
        6: [3, 30, 28, 38, 4, 13, 14, 27, 29, 38],
        7: [22, 27, 12, 2, 22, 16, 34, 15, 18, 16],
        8: [36, 21, 24, 22, 28, 38, 21, 2, 4, 24],
        9: [38, 32, 38, 7, 31, 34, 39, 20, 30, 18],
    }
    cycle = getcycle(second_graph, list(second_graph))
    assert len(cycle) <= 4  # 0->1->2->0