# Repository URL to install this package: (not recorded)
# Version: 2022.10.0
import numpy as np
import pytest
import dask.array as da
from dask.array.reshape import contract_tuple, expand_tuple, reshape_rechunk
from dask.array.utils import assert_eq
@pytest.mark.parametrize(
    "inshape,outshape,prechunks,inchunks,outchunks",
    [
        ((4,), (4,), ((2, 2),), ((2, 2),), ((2, 2),)),
        ((4,), (2, 2), ((2, 2),), ((2, 2),), ((1, 1), (2,))),
        ((4,), (4, 1), ((2, 2),), ((2, 2),), ((2, 2), (1,))),
        ((4,), (1, 4), ((2, 2),), ((2, 2),), ((1,), (2, 2))),
        ((1, 4), (4,), ((1,), (2, 2)), ((1,), (2, 2)), ((2, 2),)),
        ((4, 1), (4,), ((2, 2), (1,)), ((2, 2), (1,)), ((2, 2),)),
        (
            (4, 1, 4),
            (4, 4),
            ((2, 2), (1,), (2, 2)),
            ((2, 2), (1,), (2, 2)),
            ((2, 2), (2, 2)),
        ),
        (
            (4, 4),
            (4, 1, 4),
            ((2, 2), (2, 2)),
            ((2, 2), (2, 2)),
            ((2, 2), (1,), (2, 2)),
        ),
        ((2, 2), (4,), ((2,), (2,)), ((2,), (2,)), ((4,),)),
        ((2, 2), (4,), ((1, 1), (2,)), ((1, 1), (2,)), ((2, 2),)),
        ((2, 2), (4,), ((2,), (1, 1)), ((1, 1), (2,)), ((2, 2),)),
        (
            (64,),
            (4, 4, 4),
            ((8, 8, 8, 8, 8, 8, 8, 8),),
            ((16, 16, 16, 16),),
            ((1, 1, 1, 1), (4,), (4,)),
        ),
        ((64,), (4, 4, 4), ((32, 32),), ((32, 32),), ((2, 2), (4,), (4,))),
        ((64,), (4, 4, 4), ((16, 48),), ((16, 48),), ((1, 3), (4,), (4,))),
        ((64,), (4, 4, 4), ((20, 44),), ((16, 48),), ((1, 3), (4,), (4,))),
        (
            (64, 4),
            (8, 8, 4),
            ((16, 16, 16, 16), (2, 2)),
            ((16, 16, 16, 16), (2, 2)),
            ((2, 2, 2, 2), (8,), (2, 2)),
        ),
    ],
)
def test_reshape_rechunk(inshape, outshape, prechunks, inchunks, outchunks):
    """Check the chunk plan ``reshape_rechunk`` computes for a reshape.

    ``prechunks`` are the chunks of the array before reshaping; the helper
    returns the chunks the input should be rechunked to (``inchunks``) and
    the chunks of the reshaped result (``outchunks``).
    """
    planned_in, planned_out = reshape_rechunk(inshape, outshape, prechunks)
    assert planned_in == inchunks
    assert planned_out == outchunks

    # The number of blocks (product of per-axis chunk counts) must be the
    # same on both sides of the reshape.
    blocks_in = np.prod([len(axis_chunks) for axis_chunks in planned_in])
    blocks_out = np.prod([len(axis_chunks) for axis_chunks in planned_out])
    assert blocks_in == blocks_out
def test_expand_tuple():
    """Each chunk is split into at most ``n`` pieces that sum back to it."""
    cases = [
        ((2, 4), 2, (1, 1, 2, 2)),
        ((2, 4), 3, (1, 1, 1, 1, 2)),
        ((3, 4), 2, (1, 2, 2, 2)),
        ((7, 4), 3, (2, 2, 3, 1, 1, 2)),
    ]
    for chunks, pieces, expected in cases:
        assert expand_tuple(chunks, pieces) == expected
def test_contract_tuple():
    """Adjacent chunks are merged into multiples of ``n``; the total is kept."""
    expectations = {
        ((1, 1, 2, 3, 1), 2): (2, 2, 2, 2),
        ((1, 1, 2, 5, 1), 2): (2, 2, 4, 2),
        ((2, 4), 2): (2, 4),
        ((2, 4), 3): (6,),
    }
    for (chunks, multiple), expected in expectations.items():
        assert contract_tuple(chunks, multiple) == expected
def test_reshape_unknown_sizes():
    """A single ``-1`` in the target shape is inferred; two are rejected.

    Both NumPy and dask must infer the missing dimension identically and
    both must raise ``ValueError`` when more than one dimension is ``-1``.
    """
    np_arr = np.random.random((10, 6, 6))
    dask_arr = da.from_array(np_arr, chunks=(5, 2, 3))

    expected = np_arr.reshape((60, -1))
    actual = dask_arr.reshape((60, -1))
    assert actual.shape == (60, 6)
    assert_eq(actual, expected)

    for arr in (np_arr, dask_arr):
        with pytest.raises(ValueError):
            arr.reshape((60, -1, -1))
@pytest.mark.parametrize(
    "inshape, inchunks, outshape, outchunks",
    [
        # (2, 3, 4) -> (6, 4)
        ((2, 3, 4), ((1, 1), (1, 2), (2, 2)), (6, 4), ((1, 2, 1, 2), (2, 2))),
        # (1, 2, 3, 4) -> (6, 4)
        ((1, 2, 3, 4), ((1,), (1, 1), (1, 2), (2, 2)), (6, 4), ((1, 2, 1, 2), (2, 2))),
        # (2, 2, 3, 4) -> (12, 4)
        (
            (2, 2, 3, 4),
            ((1, 1), (1, 1), (1, 2), (2, 2)),
            (12, 4),
            ((1, 2, 1, 2, 1, 2, 1, 2), (2, 2)),
        ),
        # (2, 2, 3, 4) -> (4, 3, 4)
        (
            (2, 2, 3, 4),
            ((1, 1), (1, 1), (1, 2), (2, 2)),
            (4, 3, 4),
            ((1, 1, 1, 1), (1, 2), (2, 2)),
        ),
        # (2, 2, 3, 4) -> (4, 3, 4)
        ((2, 2, 3, 4), ((1, 1), (2,), (1, 2), (4,)), (4, 3, 4), ((2, 2), (1, 2), (4,))),
        # (2, 3, 4) -> (24,).
        ((2, 3, 4), ((1, 1), (1, 1, 1), (2, 2)), (24,), ((2,) * 12,)),
        # (2, 3, 4) -> (2, 12)
        ((2, 3, 4), ((1, 1), (1, 1, 1), (4,)), (2, 12), ((1, 1), (4,) * 3)),
    ],
)
def test_reshape_all_chunked_no_merge(inshape, inchunks, outshape, outchunks):
    """Reshape without rechunking when the collapsed leading axes are fully chunked.

    Checks both the ``reshape_rechunk`` plan and the chunks/values produced
    by ``Array.reshape``.
    """
    # https://github.com/dask/dask/issues/5544#issuecomment-712280433
    # When the early axes are completely chunked then we are just moving blocks
    # and can avoid any rechunking. The result inchunks are the same as the
    # input chunks.
    base = np.arange(np.prod(inshape)).reshape(inshape)
    a = da.from_array(base, chunks=inchunks)
    # test directly
    inchunks2, outchunks2 = reshape_rechunk(a.shape, outshape, inchunks)
    assert inchunks2 == inchunks
    assert outchunks2 == outchunks
    # and via reshape
    result = a.reshape(outshape)
    assert result.chunks == outchunks
    assert_eq(result, base.reshape(outshape))
@pytest.mark.parametrize(
    "inshape, inchunks, expected_inchunks, outshape, outchunks",
    [
        # (2, 3, 4) -> (24,). This does merge, since the second dim isn't fully chunked!
        ((2, 3, 4), ((1, 1), (1, 2), (2, 2)), ((1, 1), (3,), (4,)), (24,), ((12, 12),)),
    ],
)
def test_reshape_all_not_chunked_merge(
    inshape, inchunks, expected_inchunks, outshape, outchunks
):
    """Input chunks are merged when a collapsed axis is not fully chunked.

    Verifies the ``reshape_rechunk`` plan (input is rechunked to
    ``expected_inchunks``) and the chunks/values of ``Array.reshape``.
    """
    source = np.arange(np.prod(inshape)).reshape(inshape)
    arr = da.from_array(source, chunks=inchunks)

    # The chunk planner on its own.
    planned_in, planned_out = reshape_rechunk(arr.shape, outshape, inchunks)
    assert planned_in == expected_inchunks
    assert planned_out == outchunks

    # The public reshape entry point.
    reshaped = arr.reshape(outshape)
    assert reshaped.chunks == outchunks
    assert_eq(reshaped, source.reshape(outshape))
@pytest.mark.parametrize(
    "inshape, inchunks, outshape, outchunks",
    [
        # (2, 3, 4) -> (6, 4)
        ((2, 3, 4), ((2,), (1, 2), (2, 2)), (6, 4), ((1, 2, 1, 2), (2, 2))),
        # (1, 2, 3, 4) -> (6, 4)
        ((1, 2, 3, 4), ((1,), (2,), (1, 2), (2, 2)), (6, 4), ((1, 2, 1, 2), (2, 2))),
        # (2, 2, 3, 4) -> (12, 4) (3 cases)
        (
            (2, 2, 3, 4),
            ((1, 1), (2,), (1, 2), (2, 2)),
            (12, 4),
            ((1, 2, 1, 2, 1, 2, 1, 2), (2, 2)),
        ),
        (
            (2, 2, 3, 4),
            ((2,), (1, 1), (1, 2), (2, 2)),
            (12, 4),
            ((1, 2, 1, 2, 1, 2, 1, 2), (2, 2)),
        ),
        (
            (2, 2, 3, 4),
            ((2,), (2,), (1, 2), (2, 2)),
            (12, 4),
            ((1, 2, 1, 2, 1, 2, 1, 2), (2, 2)),
        ),
        # (2, 2, 3, 4) -> (4, 3, 4)
        # TODO: the expected chunks for this case are unclear; re-enable once
        # the intended behavior is confirmed.
        # (
        #     (2, 2, 3, 4),
        #     ((2,), (2,), (1, 2), (2, 2)),
        #     (4, 3, 4),
        #     ((1, 1, 1, 1), (1, 2), (2, 2)),
        # ),
        # (2, 2, 3, 4) -> (4, 3, 4)
        ((2, 2, 3, 4), ((2,), (2,), (1, 2), (4,)), (4, 3, 4), ((2, 2), (1, 2), (4,))),
    ],
)
def test_reshape_merge_chunks(inshape, inchunks, outshape, outchunks):
    """Check ``Array.reshape(..., merge_chunks=False)`` chunking and values.

    The result must have exactly ``outchunks``, match NumPy's reshape, and
    differ in chunking from a default ``reshape`` of the same array.
    """
    # https://github.com/dask/dask/issues/5544#issuecomment-712280433
    # With merge_chunks=False the leading axes being collapsed are split into
    # single-element blocks, while the chunking of the innermost collapsed
    # axis is kept. This is what the expected outchunks encode, e.g.
    # ((2,), (1, 2), (2, 2)) -> (6, 4) gives ((1, 2, 1, 2), (2, 2)).
    # The final assertion checks this differs from the default reshape.
    base = np.arange(np.prod(inshape)).reshape(inshape)
    a = da.from_array(base, chunks=inchunks)
    # and via reshape
    result = a.reshape(outshape, merge_chunks=False)
    assert result.chunks == outchunks
    assert_eq(result, base.reshape(outshape))
    assert result.chunks != a.reshape(outshape).chunks