# Repository URL to install this package: <not included in this copy>
# Version: 2022.10.0
import datetime
import functools
import operator
import pickle
from array import array
import pytest
from tlz import curry
from dask import get
from dask.highlevelgraph import HighLevelGraph
from dask.optimization import SubgraphCallable
from dask.utils import (
Dispatch,
M,
SerializableLock,
_deprecated,
asciitable,
cached_cumsum,
derived_from,
ensure_bytes,
ensure_dict,
ensure_set,
ensure_unicode,
extra_titles,
format_bytes,
format_time,
funcname,
getargspec,
has_keyword,
is_arraylike,
itemgetter,
iter_chunks,
memory_repr,
methodcaller,
ndeepmap,
parse_bytes,
parse_timedelta,
partial_by_order,
random_state_data,
skip_doctest,
stringify,
stringify_collection_keys,
takes_multiple_arguments,
tmpfile,
typename,
)
from dask.utils_test import inc
def test_ensure_bytes():
    """``ensure_bytes`` yields ``b"1"`` for bytes, str, memoryview, bytearray and array inputs."""
    inputs = [b"1", "1", memoryview(b"1"), bytearray(b"1"), array("B", b"1")]
    for value in inputs:
        converted = ensure_bytes(value)
        assert isinstance(converted, bytes)
        assert converted == b"1"
def test_ensure_bytes_ndarray():
    """``ensure_bytes`` turns a NumPy array into a ``bytes`` object."""
    numpy = pytest.importorskip("numpy")
    converted = ensure_bytes(numpy.arange(12))
    assert isinstance(converted, bytes)
def test_ensure_bytes_pyarrow_buffer():
    """``ensure_bytes`` turns a pyarrow buffer into a ``bytes`` object."""
    pyarrow = pytest.importorskip("pyarrow")
    converted = ensure_bytes(pyarrow.py_buffer(b"123"))
    assert isinstance(converted, bytes)
def test_ensure_unicode():
    """``ensure_unicode`` yields ``"1"`` for bytes, str, memoryview, bytearray and array inputs."""
    inputs = [b"1", "1", memoryview(b"1"), bytearray(b"1"), array("B", b"1")]
    for value in inputs:
        converted = ensure_unicode(value)
        assert isinstance(converted, str)
        assert converted == "1"
def test_ensure_unicode_ndarray():
    """``ensure_unicode`` decodes a ``u1`` NumPy array built from ``b"123"``."""
    numpy = pytest.importorskip("numpy")
    raw = numpy.frombuffer(b"123", dtype="u1")
    converted = ensure_unicode(raw)
    assert isinstance(converted, str)
    assert converted == "123"
def test_ensure_unicode_pyarrow_buffer():
    """``ensure_unicode`` decodes a pyarrow buffer wrapping ``b"123"``."""
    pyarrow = pytest.importorskip("pyarrow")
    converted = ensure_unicode(pyarrow.py_buffer(b"123"))
    assert isinstance(converted, str)
    assert converted == "123"
def test_getargspec():
    """``getargspec`` reports argument names for a function, a partial, a wrapper and a class."""

    def func(x, y):
        pass

    xy = ["x", "y"]
    assert getargspec(func).args == xy

    # this is a bit of a lie, but maybe close enough
    bound_first = functools.partial(func, 2)
    assert getargspec(bound_first).args == xy

    def wrapper(*args, **kwargs):
        pass

    # Point the wrapper at ``func`` through ``__wrapped__``.
    wrapper.__wrapped__ = func
    assert getargspec(wrapper).args == xy

    class MyType:
        def __init__(self, x, y):
            pass

    assert getargspec(MyType).args == ["self", "x", "y"]
def test_takes_multiple_arguments():
    """Check ``takes_multiple_arguments`` on builtins, plain functions and classes."""

    # Fixtures first, assertions afterwards (same assertion order as before).
    def multi(a, b, c):
        return a, b, c

    class Singular:
        def __init__(self, a):
            pass

    class Multi:
        def __init__(self, a, b):
            pass

    def f():
        pass

    def vararg(*args):
        pass

    assert takes_multiple_arguments(map)
    assert not takes_multiple_arguments(sum)

    assert takes_multiple_arguments(multi)
    assert not takes_multiple_arguments(Singular)
    assert takes_multiple_arguments(Multi)

    assert not takes_multiple_arguments(f)

    assert takes_multiple_arguments(vararg)
    assert not takes_multiple_arguments(vararg, varargs=False)
def test_dispatch():
    """Register handlers for several types on a ``Dispatch`` and check the lookups."""
    dispatcher = Dispatch()
    dispatcher.register(int, lambda a: a + 1)
    dispatcher.register(float, lambda a: a - 1)
    # The tuple handler re-enters the dispatcher for every element.
    dispatcher.register(tuple, lambda a: tuple(dispatcher(i) for i in a))

    def f(a):
        """My Docstring"""
        return a

    dispatcher.register(object, f)

    class Bar:
        pass

    instance = Bar()
    assert dispatcher(1) == 2
    assert dispatcher.dispatch(int)(1) == 2
    assert dispatcher(1.0) == 0.0
    assert dispatcher(instance) == instance
    assert dispatcher((1, 2.0, instance)) == (2, 1.0, instance)

    assert dispatcher.__doc__ == f.__doc__
def test_dispatch_kwargs():
    """Keyword arguments are forwarded to the registered handler."""
    dispatcher = Dispatch()
    dispatcher.register(int, lambda a, b=10: a + b)

    outcome = dispatcher(1, b=20)
    assert outcome == 21
def test_dispatch_variadic_on_first_argument():
    """With two positional arguments, the handler is chosen by the first one."""
    dispatcher = Dispatch()
    dispatcher.register(int, lambda a, b: a + b)
    dispatcher.register(float, lambda a, b: a - b)

    int_result = dispatcher(1, 2)
    assert int_result == 3
    float_result = dispatcher(1.0, 2.0)
    assert float_result == -1
def test_dispatch_lazy():
    """Lazily registered handlers are found on the first lookup."""
    # this tests the recursive component of dispatch
    dispatcher = Dispatch()
    dispatcher.register(int, lambda a: a)

    import decimal

    # keep it outside lazy dec for test
    def foo_dec(a):
        return a + 1

    @dispatcher.register_lazy("decimal")
    def register_decimal():
        import decimal

        dispatcher.register(decimal.Decimal, foo_dec)

    # This test needs to be *before* any other calls
    assert dispatcher.dispatch(decimal.Decimal) == foo_dec
    one = decimal.Decimal(1)
    assert dispatcher(one) == decimal.Decimal(2)
    assert dispatcher(1) == 1
def test_dispatch_lazy_walks_mro():
    """Subclasses of a lazily registered class fall back to the parent's handler.

    ``Eager`` has its own handler registered up front and must keep using it.
    """
    import decimal

    class Lazy(decimal.Decimal):
        pass

    class Eager(Lazy):
        pass

    dispatcher = Dispatch()

    @dispatcher.register(Eager)
    def eager_handler(x):
        return "eager"

    def lazy_handler(a):
        return "lazy"

    @dispatcher.register_lazy("decimal")
    def register_decimal():
        dispatcher.register(decimal.Decimal, lazy_handler)

    # Subclass without its own handler.
    assert dispatcher.dispatch(Lazy) == lazy_handler
    assert dispatcher(Lazy(1)) == "lazy"
    # The lazily registered class itself.
    assert dispatcher.dispatch(decimal.Decimal) == lazy_handler
    assert dispatcher(decimal.Decimal(1)) == "lazy"
    # Subclass with an eagerly registered handler.
    assert dispatcher.dispatch(Eager) == eager_handler
    assert dispatcher(Eager(1)) == "eager"
def test_random_state_data():
    """``random_state_data`` agrees for int and ``RandomState`` seeds and is prefix-stable."""
    np = pytest.importorskip("numpy")
    seed = 37
    rng = np.random.RandomState(seed)
    count = 10000

    # Use an integer
    from_int = random_state_data(count, seed)
    assert len(from_int) == count

    # Use RandomState object
    from_rng = random_state_data(count, rng)
    for left, right in zip(from_int, from_rng):
        assert left.shape == (624,)
        assert (left == right).all()

    # Consistent ordering
    short = random_state_data(10, 1234)
    long_prefix = random_state_data(20, 1234)[:10]

    for left, right in zip(short, long_prefix):
        assert (left == right).all()
def test_memory_repr():
    """``memory_repr`` formats successive powers of 1024 with the expected unit."""
    expected_by_power = ("1.0 bytes", "1.0 KB", "1.0 MB", "1.0 GB")
    for exponent, text in enumerate(expected_by_power):
        assert memory_repr(1024**exponent) == text
def test_method_caller():
    """``methodcaller("count")`` is one shared object, reachable via ``M`` and pickle."""
    values = [1, 2, 3, 3, 3]
    counter = methodcaller("count")
    assert counter(values, 3) == values.count(3)

    # The same object comes back from a new call, from ``M`` and from a pickle round-trip.
    assert methodcaller("count") is counter
    assert M.count is counter
    assert pickle.loads(pickle.dumps(counter)) is counter

    assert "count" in dir(M)
    assert "count" in str(methodcaller("count"))
    assert "count" in repr(methodcaller("count"))
def test_skip_doctest():
    """``skip_doctest`` appends a SKIP directive to doctest statement lines."""
    source = """>>> xxx
>>>
>>> # comment
>>> xxx"""
    skipped = """>>> xxx # doctest: +SKIP
>>>
>>> # comment
>>> xxx # doctest: +SKIP"""
    assert skip_doctest(source) == skipped

    assert skip_doctest(None) == ""

    # A line that already carries a directive gets SKIP added to it.
    source = """
>>> 1 + 2 # doctest: +ELLIPSES
3"""
    skipped = """
>>> 1 + 2 # doctest: +ELLIPSES, +SKIP
3"""
    assert skip_doctest(source) == skipped
def test_extra_titles():
    """A repeated ``Notes`` section title is renamed to ``Extra Notes``."""
    before = """
Notes
-----
hello
Foo
---
Notes
-----
bar
"""
    after = """
Notes
-----
hello
Foo
---
Extra Notes
-----------
bar
"""
    rewritten = extra_titles(before)
    assert rewritten == after
def test_asciitable():
    """``asciitable`` renders a header and four rows as a boxed text table."""
    header = ["fruit", "color"]
    body = [("apple", "red"), ("banana", "yellow"), ("tomato", "red"), ("pear", "green")]
    rendered = asciitable(header, body)
    expected = (
        "+--------+--------+\n"
        "| fruit | color |\n"
        "+--------+--------+\n"
        "| apple | red |\n"
        "| banana | yellow |\n"
        "| tomato | red |\n"
        "| pear | green |\n"
        "+--------+--------+"
    )
    assert rendered == expected
def test_SerializableLock():
    """Pickled copies of a ``SerializableLock`` exclude each other; unrelated locks do not."""
    a = SerializableLock()
    b = SerializableLock()
    with a:
        pass

    with a, b:
        pass

    with a:
        assert not a.acquire(False)

    a2 = pickle.loads(pickle.dumps(a))
    a3 = pickle.loads(pickle.dumps(a))
    a4 = pickle.loads(pickle.dumps(a2))
    a_family = [a, a2, a3, a4]

    # Holding any copy of ``a`` blocks a non-blocking acquire of every copy.
    for holder in a_family:
        for other in a_family:
            with holder:
                assert not other.acquire(False)

    b2 = pickle.loads(pickle.dumps(b))
    b3 = pickle.loads(pickle.dumps(b2))
    b_family = [b, b2, b3]

    # Copies of ``a`` and copies of ``b`` can be nested in either order.
    for holder in a_family:
        for other in b_family:
            with holder, other:
                pass
            with other, holder:
                pass
def test_SerializableLock_name_collision():
    """Locks built with the same name share ``.lock``; other names and unnamed locks do not."""
    first_a = SerializableLock("a")
    only_b = SerializableLock("b")
    second_a = SerializableLock("a")
    unnamed = SerializableLock()

    assert first_a.lock is not only_b.lock
    assert first_a.lock is second_a.lock
    named_locks = (first_a.lock, only_b.lock, second_a.lock)
    assert unnamed.lock not in named_locks
def test_SerializableLock_locked():
    """``locked()`` is true only while the lock is held by the ``with`` block."""
    lock = SerializableLock("a")
    assert not lock.locked()
    with lock:
        assert lock.locked()
    assert not lock.locked()
def test_SerializableLock_acquire_blocking():
    """``acquire`` accepts ``blocking`` as a keyword argument.

    A blocking acquire on a free lock succeeds, and a non-blocking acquire
    while the lock is held fails.
    """
    lock = SerializableLock("a")
    assert lock.acquire(blocking=True)
    try:
        assert not lock.acquire(blocking=False)
    finally:
        # Locks created with the same name share one underlying lock, so a
        # failed assertion must not leave "a" held for the other tests.
        lock.release()
def test_funcname():
    """``funcname`` names functions, partials, ``M`` callers, lambdas, classes and instances."""

    def foo(a, b, c):
        pass

    class Foo:
        pass

    assert funcname(foo) == "foo"
    partial_foo = functools.partial(foo, a=1)
    assert funcname(partial_foo) == "foo"
    assert funcname(M.sum) == "sum"
    assert funcname(lambda: 1) == "lambda"

    assert funcname(Foo) == "Foo"
    assert "Foo" in funcname(Foo())
def test_funcname_long():
    """A very long function name is shortened to fewer than 60 characters."""

    def a_long_function_name_11111111111111111111111111111111111111111111111():
        pass

    target = a_long_function_name_11111111111111111111111111111111111111111111111
    name = funcname(target)
    assert "a_long_function_name" in name
    assert len(name) < 60
def test_funcname_toolz():
    """``funcname`` sees through ``curry``, with and without a bound argument."""

    @curry
    def foo(a, b, c):
        pass

    assert funcname(foo) == "foo"
    one_bound = foo(1)
    assert funcname(one_bound) == "foo"
def test_funcname_multipledispatch():
    """``funcname`` works for a ``multipledispatch`` function and a partial of it."""
    multipledispatch = pytest.importorskip("multipledispatch")

    @multipledispatch.dispatch(int, int, int)
    def foo(a, b, c):
        pass

    assert funcname(foo) == "foo"
    partial_foo = functools.partial(foo, a=1)
    assert funcname(partial_foo) == "foo"
def test_funcname_numpy_vectorize():
    """``funcname`` prefixes ``vectorize_`` to the name of the wrapped callable."""
    np = pytest.importorskip("numpy")

    vectorized_int = np.vectorize(int)
    assert funcname(vectorized_int) == "vectorize_int"

    # Regression test for https://github.com/pydata/xarray/issues/3303
    # Partial functions don't have a __name__ attribute
    add_partial = functools.partial(np.add, out=None)
    vectorized_add = np.vectorize(add_partial)
    assert funcname(vectorized_add) == "vectorize_add"
def test_ndeepmap():
    """``ndeepmap`` applies ``inc`` at the requested nesting depth."""
    # (depth, input, expected result)
    cases = [
        (0, 1, 2),
        (0, [1], 2),
        (1, [1, 2, 3], [2, 3, 4]),
        (2, [[1, 2], [3, 4]], [[2, 3], [4, 5]]),
        (3, [[[1, 2], [3, 4, 5]], [[6], []]], [[[2, 3], [4, 5, 6]], [[7], []]]),
    ]
    for depth, nested, expected in cases:
        assert ndeepmap(depth, inc, nested) == expected
def test_ensure_dict():
    """``ensure_dict`` returns a plain dict as-is and otherwise builds an equal new ``dict``."""
    original = {"x": 1}
    assert ensure_dict(original) is original

    class mydict(dict):
        pass

    # Explicit copy, HighLevelGraph input and dict-subclass input.
    converted = (
        ensure_dict(original, copy=True),
        ensure_dict(HighLevelGraph.from_collections("x", original)),
        ensure_dict(mydict(original)),
    )

    for result in converted:
        assert type(result) is dict
        assert result is not original
        assert result == original
def test_ensure_set():
    """``ensure_set`` returns a plain set as-is and otherwise builds an equal new ``set``."""
    original = {1}
    assert ensure_set(original) is original

    class myset(set):
        pass

    # Explicit copy and set-subclass input.
    converted = (
        ensure_set(original, copy=True),
        ensure_set(myset(original)),
    )

    for result in converted:
        assert type(result) is set
        assert result is not original
        assert result == original
def test_itemgetter():
    """``itemgetter`` indexes, survives pickling and compares by index."""
    values = [1, 2, 3]
    getter = itemgetter(1)
    assert getter(values) == 2

    restored = pickle.loads(pickle.dumps(getter))
    assert restored(values) == 2
    assert restored.index == 1

    # Equality: same index, different index, unrelated object.
    assert itemgetter(1) == itemgetter(1)
    assert itemgetter(1) != itemgetter(2)
    assert itemgetter(1) != 123
def test_partial_by_order():
    """``partial_by_order`` with ``operator.add`` and ``other=[(1, 20)]`` adds 20 to 5."""
    total = partial_by_order(5, function=operator.add, other=[(1, 20)])
    assert total == 25
def test_has_keyword():
    """``has_keyword`` finds parameters on a function and on a partial of it."""

    def foo(a, b, c=None):
        pass

    for name in ("a", "b", "c"):
        assert has_keyword(foo, name)

    bar = functools.partial(foo, a=1)
    for name in ("b", "c"):
        assert has_keyword(bar, name)
def test_derived_from():
    # Methods decorated with ``derived_from(Foo)`` take over the docstring of ``Foo.f``.
    class Foo:
        def f(a, b):
            """A super docstring
An explanation
Parameters
----------
a: int
an explanation of a
b: float
an explanation of b
"""

    class Bar:
        @derived_from(Foo)
        def f(a, c):
            pass

    class Zap:
        @derived_from(Foo)
        def f(a, c):
            "extra docstring"
            pass

    bar_doc = Bar.f.__doc__
    bar_lines = bar_doc.split("\n")

    assert bar_doc.strip().startswith("A super docstring")
    assert "Foo.f" in bar_doc
    assert any("inconsistencies" in line for line in bar_lines[:7])

    # ``b`` is a parameter of ``Foo.f`` but not of ``Bar.f``.
    [b_line] = [line for line in bar_lines if "b:" in line]
    assert "not supported" in b_line.lower()
    assert "dask" in b_line.lower()

    assert " extra docstring\n\n" in Zap.f.__doc__
def test_derived_from_func():
    # ``derived_from`` also works with a module: ``sum`` picks up the ``builtins.sum`` docs.
    import builtins

    @derived_from(builtins)
    def sum():
        "extra docstring"
        pass

    combined_doc = sum.__doc__
    assert "extra docstring\n\n" in combined_doc
    assert "Return the sum of" in combined_doc
    assert "This docstring was copied from builtins.sum" in combined_doc
def test_derived_from_dask_dataframe():
    """Check the derived docstrings of ``DataFrame.dropna`` and ``DataFrame.ffill``."""
    dd = pytest.importorskip("dask.dataframe")

    dropna_doc = dd.DataFrame.dropna.__doc__
    assert "inconsistencies" in dropna_doc

    # Exactly one line of the docstring must mention the ``axis`` parameter.
    [axis_line] = [line for line in dropna_doc.split("\n") if "axis :" in line]
    assert "not supported" in axis_line.lower()
    assert "dask" in axis_line.lower()

    assert "Object with missing values filled" in dd.DataFrame.ffill.__doc__
def test_parse_bytes():
    """``parse_bytes`` converts size strings (and plain ints) to a byte count."""
    # (input, expected number of bytes)
    cases = [
        ("100", 100),
        ("100 MB", 100000000),
        ("100M", 100000000),
        ("5kB", 5000),
        ("5.4 kB", 5400),
        ("1kiB", 1024),
        ("1Mi", 2**20),
        ("1e6", 1000000),
        ("1e6 kB", 1000000000),
        ("MB", 1000000),
        (123, 123),
        (".5GB", 500000000),
    ]
    for text, expected in cases:
        assert parse_bytes(text) == expected
def test_parse_timedelta():
    """``parse_timedelta`` converts strings, numbers and timedeltas to seconds."""
    # (input, expected seconds)
    cases = [
        ("1s", 1),
        ("100ms", 0.1),
        ("5S", 5),
        ("5.5s", 5.5),
        ("5.5 s", 5.5),
        ("1 second", 1),
        ("3.3 seconds", 3.3),
        ("3.3 milliseconds", 0.0033),
        ("3500 us", 0.0035),
        ("1 ns", 1e-9),
        ("2m", 120),
        ("5 days", 5 * 24 * 60 * 60),
        ("2 w", 2 * 7 * 24 * 60 * 60),
        ("2 minutes", 120),
        (None, None),
        (3, 3),
        (datetime.timedelta(seconds=2), 2),
        (datetime.timedelta(milliseconds=100), 0.1),
    ]
    for raw, expected in cases:
        parsed = parse_timedelta(raw)
        # Exact match first (covers ``None``); otherwise allow float rounding error.
        assert parsed == expected or abs(parsed - expected) < 1e-14

    # ``default`` supplies the unit when the input carries none.
    assert parse_timedelta("1ms", default="seconds") == 0.001
    assert parse_timedelta("1", default="seconds") == 1
    assert parse_timedelta("1", default="ms") == 0.001
    assert parse_timedelta(1, default="ms") == 0.001

    # ``default=False``: input with a unit still parses, unit-less input raises.
    assert parse_timedelta("1ms", default=False) == 0.001
    for unitless in (1, "1"):
        with pytest.raises(ValueError):
            parse_timedelta(unitless, default=False)

    with pytest.raises(TypeError):
        parse_timedelta("1", default=None)
def test_is_arraylike():
    """``is_arraylike`` is ``False`` for scalars and builtin sequences, ``True`` for NumPy arrays.

    The results are compared with ``is`` so that only real booleans pass.
    """
    np = pytest.importorskip("numpy")

    # Each input is listed once; the original asserted ``0`` twice.
    not_arraylike = [0, (), [], [0]]
    arraylike = [np.empty(()), np.empty((0,)), np.empty((0, 0))]

    for obj in not_arraylike:
        assert is_arraylike(obj) is False
    for obj in arraylike:
        assert is_arraylike(obj) is True
def test_iter_chunks():
    """``iter_chunks`` groups consecutive sizes under three different limits."""
    sizes = [14, 8, 5, 9, 7, 9, 1, 19, 8, 19]
    # limit -> expected grouping of ``sizes``
    expected_by_limit = [
        (19, [[14], [8, 5], [9, 7], [9, 1], [19], [8], [19]]),
        (28, [[14, 8, 5], [9, 7, 9, 1], [19, 8], [19]]),
        (67, [[14, 8, 5, 9, 7, 9, 1], [19, 8, 19]]),
    ]
    for limit, expected in expected_by_limit:
        assert list(iter_chunks(sizes, limit)) == expected
def test_stringify():
    """Exercise ``stringify`` on strings, graphs and ``SubgraphCallable`` tasks."""
    # str and bytes are returned unchanged (same object).
    for passthrough in ("Hello", b"Hello"):
        assert stringify(passthrough) is passthrough

    graph = {"x": 1}
    assert stringify(graph) == str(graph)
    assert stringify(graph, exclusive=()) == graph

    graph = {("x", 1): (inc, 1)}
    assert stringify(graph) == str({("x", 1): (inc, 1)})
    assert stringify(graph, exclusive=()) == {("x", 1): (inc, 1)}

    # With ``exclusive`` given, only the key reference inside a task is stringified.
    graph = {("x", 1): (inc, 1), ("x", 2): (inc, ("x", 1))}
    assert stringify(graph, exclusive=graph) == {
        ("x", 1): (inc, 1),
        ("x", 2): (inc, str(("x", 1))),
    }

    graphs = [
        {"x": 1},
        {("x", 1): (inc, 1), ("x", 2): (inc, ("x", 1))},
        {("x", 1): (sum, [1, 2, 3]), ("x", 2): (sum, [("x", 1), ("x", 1)])},
    ]
    for graph in graphs:
        # A fully stringified graph must compute the same results as the original.
        string_graph = {
            stringify(key): stringify(task, exclusive=graph)
            for key, task in graph.items()
        }
        keys = list(graph)
        string_keys = [str(key) for key in keys]
        assert all(isinstance(key, str) for key in string_graph)
        assert get(graph, keys) == get(string_graph, string_keys)

    subgraph = SubgraphCallable({"x": ("y", 1)}, "x", (("y", 1),))
    graph = {("y", 1): (subgraph, (("z", 1),))}
    graph = stringify(graph, exclusive=set(graph) | {("z", 1)})
    assert graph[("y", 1)][0].dsk["x"] == "('y', 1)"
    assert graph[("y", 1)][1][0] == "('z', 1)"
def test_stringify_collection_keys():
    """``stringify_collection_keys`` converts the first two sample keys and keeps ``(1, 1)``."""
    text = "Hello"
    assert stringify_collection_keys(text) is text

    keys = [("a", 0), (b"a", 0), (1, 1)]
    converted = stringify_collection_keys(keys)
    assert converted[0] == str(keys[0])
    assert converted[1] == str(keys[1])
    assert converted[2] == keys[2]
@pytest.mark.parametrize(
    "n,expect",
    [
        (0, "0 B"),
        (920, "920 B"),
        (930, "0.91 kiB"),
        (921.23 * 2**10, "921.23 kiB"),
        (931.23 * 2**10, "0.91 MiB"),
        (921.23 * 2**20, "921.23 MiB"),
        (931.23 * 2**20, "0.91 GiB"),
        (921.23 * 2**30, "921.23 GiB"),
        (931.23 * 2**30, "0.91 TiB"),
        (921.23 * 2**40, "921.23 TiB"),
        (931.23 * 2**40, "0.91 PiB"),
        (2**60, "1024.00 PiB"),
    ],
)
def test_format_bytes(n, expect):
    """``format_bytes`` renders the truncated integer ``n`` with a binary-prefix unit."""
    formatted = format_bytes(int(n))
    assert formatted == expect
def test_format_time():
    """``format_time`` renders seconds as s, m/s, hr/m or d/hr depending on magnitude."""
    # (seconds, expected text)
    cases = [
        (1.4, "1.40 s"),
        (10.4, "10.40 s"),
        (100.4, "100.40 s"),
        (1000.4, "16m 40s"),
        (10000.4, "2hr 46m"),
        (1234.567, "20m 34s"),
        (12345.67, "3hr 25m"),
        (123456.78, "34hr 17m"),
        (1234567.8, "14d 6hr"),
    ]
    for seconds, expected in cases:
        assert format_time(seconds) == expected
def test_deprecated():
    """A bare ``_deprecated()`` emits one ``FutureWarning`` naming the function."""

    @_deprecated()
    def foo():
        return "bar"

    with pytest.warns(FutureWarning) as caught:
        assert foo() == "bar"

    assert len(caught) == 1
    text = str(caught[0].message)
    assert "foo is deprecated" in text
    assert "removed in a future release" in text
def test_deprecated_version():
    """``version=`` shows up in the warning as "deprecated in version ..."."""

    @_deprecated(version="1.2.3")
    def foo():
        return "bar"

    expected_message = "deprecated in version 1.2.3"
    with pytest.warns(FutureWarning, match=expected_message):
        assert foo() == "bar"
def test_deprecated_after_version():
    """``after_version=`` shows up in the warning as "deprecated after version ..."."""

    @_deprecated(after_version="1.2.3")
    def foo():
        return "bar"

    expected_message = "deprecated after version 1.2.3"
    with pytest.warns(FutureWarning, match=expected_message):
        assert foo() == "bar"
def test_deprecated_category():
    """``category=`` selects the warning class that is emitted."""

    @_deprecated(category=DeprecationWarning)
    def foo():
        return "bar"

    with pytest.warns(DeprecationWarning):
        outcome = foo()
        assert outcome == "bar"
def test_deprecated_message():
    """``message=`` replaces the whole warning text."""

    @_deprecated(message="woohoo")
    def foo():
        return "bar"

    with pytest.warns(FutureWarning) as caught:
        assert foo() == "bar"

    assert len(caught) == 1
    only_warning = caught[0]
    assert str(only_warning.message) == "woohoo"
def test_typename():
    """``typename`` gives the full dotted path, or a shortened one with ``short=True``."""
    full_name = typename(HighLevelGraph)
    assert full_name == "dask.highlevelgraph.HighLevelGraph"
    short_name = typename(HighLevelGraph, short=True)
    assert short_name == "dask.HighLevelGraph"
# Empty module-level class; test_typename_on_instances compares the typename
# of an instance with the typename of the class itself.
class MyType:
    pass
def test_typename_on_instances():
    """``typename`` of an instance equals ``typename`` of its class."""
    from_instance = typename(MyType())
    from_class = typename(MyType)
    assert from_instance == from_class
def test_cached_cumsum():
    """``cached_cumsum`` returns running totals, optionally with a leading zero."""
    values = (1, 2, 3, 4)
    plain = cached_cumsum(values)
    zero_led = cached_cumsum(values, initial_zero=True)
    assert plain == (1, 3, 6, 10)
    assert zero_led == (0, 1, 3, 6, 10)
def test_cached_cumsum_nan():
    """A NaN in the input makes every later running total NaN."""
    np = pytest.importorskip("numpy")
    values = (1, np.nan, 3)
    plain = cached_cumsum(values)
    zero_led = cached_cumsum(values, initial_zero=True)
    np.testing.assert_equal(plain, (1, np.nan, np.nan))
    np.testing.assert_equal(zero_led, (0, 1, np.nan, np.nan))
def test_cached_cumsum_non_tuple():
    """For a list input, the result reflects an in-place change to the list."""
    values = [1, 2, 3]
    before = cached_cumsum(values)
    assert before == (1, 3, 6)

    values[1] = 4
    after = cached_cumsum(values)
    assert after == (1, 5, 8)
def test_tmpfile_naming():
    """``tmpfile`` names never end in a bare period and carry exactly one extension dot."""
    with tmpfile() as path:
        # Do not end file or directory name with a period.
        # This causes issues on Windows.
        assert path[-1] != "."

    with tmpfile(extension="jpg") as path:
        assert path[-4:] == ".jpg"

    # An extension given with its own leading dot must not produce "..jpg".
    with tmpfile(extension=".jpg") as path:
        assert path[-4:] == ".jpg"
        assert path[-5] != "."