Repository URL to install this package:
|
Version:
2022.10.0 ▾
|
from __future__ import annotations
from itertools import zip_longest
from numbers import Integral
from typing import Any, Callable
import numpy as np
from dask import config
from dask.array.chunk import getitem
from dask.array.core import getter, getter_inline, getter_nofancy
from dask.blockwise import fuse_roots, optimize_blockwise
from dask.core import flatten, reverse_dict
from dask.highlevelgraph import HighLevelGraph
from dask.optimization import SubgraphCallable, fuse, inline_functions
from dask.utils import ensure_dict
# All get* functions the optimizations know about. ``_is_getter_task`` checks
# the head of a task tuple against this collection.
GETTERS = (getter, getter_nofancy, getter_inline, getitem)
# These get* functions are never completely removed from the graph,
# even if the index should be a no-op by numpy semantics. Some array-likes
# don't completely follow numpy semantics, making indexing always necessary.
# ``optimize_slices`` consults this before dropping a full-slice getter.
GETNOREMOVE = (getter, getter_nofancy)
def optimize(
    dsk,
    keys,
    fuse_keys=None,
    fast_functions=None,
    inline_functions_fast_functions=(getter_inline,),
    rename_fused_keys=True,
    **kwargs,
):
    """Optimize dask for array computation

    1.  Cull tasks not necessary to evaluate keys
    2.  Remove full slicing, e.g. x[:]
    3.  Inline fast functions like getitem and np.transpose

    Parameters
    ----------
    dsk
        The graph to optimize. Anything that is not already a
        ``HighLevelGraph`` is wrapped into one first.
    keys
        A single key, or a (possibly nested) list/set of keys; it is
        flattened into a list before use.
    fuse_keys
        Optional extra keys handed to ``fuse`` together with the held and
        requested keys.
    fast_functions
        When not ``None``, takes precedence over
        ``inline_functions_fast_functions``.
    inline_functions_fast_functions
        Functions passed to ``inline_functions``; a falsy value skips the
        inlining step.
    rename_fused_keys
        Forwarded to ``fuse`` as ``rename_keys``.
    **kwargs
        Accepted but not used by this function.

    Returns
    -------
    The culled high-level graph when the ``optimization.fuse.active`` config
    value is exactly ``False``; otherwise the output of ``optimize_slices``
    applied to the fused (and optionally inlined) low-level graph.
    """
    requested = keys if isinstance(keys, (list, set)) else [keys]
    keys = list(flatten(requested))

    if isinstance(dsk, HighLevelGraph):
        graph = dsk
    else:
        graph = HighLevelGraph.from_collections(id(dsk), dsk, dependencies=())

    # High-level passes, finishing with a cull down to the requested keys
    graph = optimize_blockwise(graph, keys=keys)
    graph = fuse_roots(graph, keys=keys)
    graph = graph.cull(set(keys))

    # Low-level fusion runs unless the user explicitly set the option to
    # False (an unset / None value still fuses).
    if config.get("optimization.fuse.active") is False:
        return graph

    dependencies = graph.get_all_dependencies()
    tasks = ensure_dict(graph)

    # An explicit ``fast_functions`` wins over the default inline set
    to_inline = (
        inline_functions_fast_functions if fast_functions is None else fast_functions
    )

    protected = hold_keys(tasks, dependencies) + keys + (fuse_keys or [])
    tasks, dependencies = fuse(
        tasks,
        protected,
        dependencies,
        rename_keys=rename_fused_keys,
    )

    if to_inline:
        tasks = inline_functions(
            tasks,
            keys,
            dependencies=dependencies,
            fast_functions=to_inline,
        )

    return optimize_slices(tasks)
def hold_keys(dsk, dependencies):
    """Find keys to avoid fusion

    We don't want to fuse data present in the graph because it is easier to
    serialize as a raw value.

    We don't want to fuse chains after getitem/GETTERS because we want to
    move around only small pieces of data, rather than the underlying arrays.

    Returns a list holding every data key (any graph value that is neither a
    tuple nor a str) plus, for each getter task that depends directly on such
    data, the key at the end of its linear getter/alias chain.
    """
    dependents = reverse_dict(dependencies)
    data = {key for key, value in dsk.items() if type(value) not in (tuple, str)}

    held = list(data)
    for data_key in data:
        for consumer in dependents[data_key]:
            if not _is_getter_task(dsk[consumer]):
                continue
            # The consumer is a get* task: walk up the chain and stop when
            # there is more than one dependent, or the dependent is no longer
            # a get* function or an alias. The final key reached is the one
            # that must not be fused.
            try:
                while len(dependents[consumer]) == 1:
                    parent = next(iter(dependents[consumer]))
                    parent_task = dsk[parent]
                    # Anything other than a get* or an alias ends the chain
                    if not (_is_getter_task(parent_task) or parent_task in dsk):
                        break
                    consumer = parent
            except (IndexError, TypeError):
                # Best-effort walk: keep whatever key was reached. The
                # TypeError presumably comes from the membership test on an
                # unhashable task -- TODO confirm.
                pass
            held.append(consumer)
    return held
def _is_getter_task(
value,
) -> tuple[Callable, Any, Any, bool, bool | None] | None:
"""Check if a value in a Dask graph looks like a getter.
1. Is it a tuple with the first element a known getter.
2. Is it a SubgraphCallable with a single element in its
dsk which is a known getter.
If a getter is found, it returns a tuple with (getter, array, index, asarray, lock).
Otherwise it returns ``None``.
TODO: the second check is a hack to allow for slice fusion between tasks produced
from blockwise layers and slicing operations. Once slicing operations have
HighLevelGraph layers which can talk to Blockwise layers this check *should* be
removed, and we should not have to introspect SubgraphCallables.
"""
if type(value) is not tuple:
return None
first = value[0]
get: Callable | None = None
if first in GETTERS:
get = first
# We only accept SubgraphCallables with a single sub-task right now as it's
# not clear which task to inspect if there is more than one, or how to resolve
# conflicts if they occur.
elif isinstance(first, SubgraphCallable) and len(first.dsk) == 1:
v = next(iter(first.dsk.values()))
if type(v) is tuple and len(v) > 1 and v[0] in GETTERS:
get = v[0]
if get is None: # Didn't find a getter
return None
length = len(value)
if length == 3:
# getter defaults to asarray=True, getitem is semantically False
return get, value[1], value[2], get is not getitem, None
elif length == 5:
return get, *value[1:] # type: ignore
return None
def optimize_slices(dsk):
    """Optimize slices

    1.  Fuse repeated slices, like x[5:][2:6] -> x[7:11]
    2.  Remove full slices, like x[:] -> x

    Works on a shallow copy of ``dsk`` and returns that copy; only the values
    of keys holding getter tasks are rewritten.

    See also:
        fuse_slice_dict
    """
    fancy_ind_types = (list, np.ndarray)

    def _is_full_slice(s):
        # True for slices equivalent to ``[:]`` (start may be None or 0)
        return (
            type(s) is slice and not s.start and s.stop is None and s.step is None
        )

    dsk = dsk.copy()
    for key, task in dsk.items():
        outer = _is_getter_task(task)
        if not outer:
            continue
        get, arr, index, asarray, lock = outer

        # Keep folding the inner getter into the outer one while that is safe
        while True:
            inner = _is_getter_task(arr)
            if not inner:
                break
            inner_get, inner_arr, inner_index, inner_asarray, inner_lock = inner

            # An outer lock that differs from the inner one blocks fusion
            if lock and lock is not inner_lock:
                break

            outer_is_tuple = type(index) is tuple
            if outer_is_tuple != (type(inner_index) is tuple):
                break

            if outer_is_tuple:
                combined = inner_index + index
                if len(index) != len(inner_index) and any(
                    i is None for i in combined
                ):
                    break
                if inner_get is getter_nofancy and any(
                    isinstance(i, fancy_ind_types) for i in combined
                ):
                    break
            elif inner_get is getter_nofancy and (
                type(index) in fancy_ind_types
                or type(inner_index) in fancy_ind_types
            ):
                break

            try:
                fused_index = fuse_slice(inner_index, index)
            except NotImplementedError:
                break

            # rely on fact that nested gets never decrease in
            # strictness e.g. `(getter_nofancy, (getter, ...))` never
            # happens
            get = getter if inner_get is getter_inline else inner_get
            arr, index, lock = inner_arr, fused_index, inner_lock
            asarray |= inner_asarray

        # Skip the get call if not from from_array and nothing to do
        if get not in GETNOREMOVE and (
            _is_full_slice(index)
            or (type(index) is tuple and all(_is_full_slice(s) for s in index))
        ):
            dsk[key] = arr
        elif get is getitem or (asarray and not lock):
            # default settings are fine, drop the extra parameters. Since we
            # always fall back to inner `get` functions, `get is getitem`
            # can only occur if all gets are getitem, meaning all
            # parameters must be getitem defaults.
            dsk[key] = (get, arr, index)
        else:
            dsk[key] = (get, arr, index, asarray, lock)

    return dsk
def normalize_slice(s):
    """Replace Nones in slices with integers

    A missing start becomes 0 and a missing step becomes 1; the stop is
    passed through unchanged (it may stay ``None``).

    >>> normalize_slice(slice(None, None, None))
    slice(0, None, 1)

    Raises
    ------
    NotImplementedError
        If the start, the step, or a non-None stop is negative.
    """
    start = 0 if s.start is None else s.start
    step = 1 if s.step is None else s.step
    stop = s.stop

    # Negative components are not supported here
    has_negative = start < 0 or step < 0 or (stop is not None and stop < 0)
    if has_negative:
        raise NotImplementedError()
    return slice(start, stop, step)
def check_for_nonfusible_fancy_indexing(fancy, normal):
    """Reject fancy/normal index pairs that cannot be fused.

    ``fancy`` and ``normal`` are walked position by position, padding the
    shorter one with full slices. An integer in ``normal`` is only accepted
    at a position where ``fancy`` holds a list. This disallows things like::

        x[:, [1, 2], :][0, :, :] -> x[0, [1, 2], :] or
        x[0, :, :][:, [1, 2], :] -> x[0, [1, 2], :]

    Raises
    ------
    NotImplementedError
        If some position pairs an integer in ``normal`` with a non-list
        entry in ``fancy``.
    """
    aligned_pairs = zip_longest(fancy, normal, fillvalue=slice(None))
    misaligned = any(
        type(f) is not list and isinstance(n, Integral) for f, n in aligned_pairs
    )
    if misaligned:
        raise NotImplementedError(
            "Can't handle normal indexing with "
            "integers and fancy indexing if the "
            "integers and fancy indices don't "
            "align with the same dimensions."
        )
def fuse_slice(a, b):
    """Fuse stacked slices together

    Fuse a pair of repeated slices into a single slice:

    >>> fuse_slice(slice(1000, 2000), slice(10, 15))
    slice(1010, 1015, None)

    This also works for tuples of slices

    >>> fuse_slice((slice(100, 200), slice(100, 200, 10)),
    ...            (slice(10, 15), [5, 2]))
    (slice(110, 115, None), [150, 120])

    And a variety of other interesting cases

    >>> fuse_slice(slice(1000, 2000), 10)  # integers
    1010

    >>> fuse_slice(slice(1000, 2000, 5), slice(10, 20, 2))
    slice(1050, 1100, 10)

    >>> fuse_slice(slice(1000, 2000, 5), [1, 2, 3])  # lists
    [1005, 1010, 1015]

    >>> fuse_slice(None, slice(None, None))  # doctest: +SKIP
    None

    A second index that ends in ``None`` (newaxis) is supported too

    >>> fuse_slice((slice(None),), (None,))
    (None, slice(None, None, None))

    Parameters
    ----------
    a
        The index applied first (slice, list, tuple of indices, or None).
    b
        The index applied to the result of ``a`` (integer, slice, list, or
        tuple of indices).

    Raises
    ------
    NotImplementedError
        For combinations that are not handled: negative integers or slice
        components, list indexing on both sides of a tuple pair, misaligned
        fancy/integer indexing, or any unsupported pair of types.
    """
    # None only works if the second side is a full slice
    if a is None and isinstance(b, slice) and b == slice(None, None):
        return None

    # Replace None with 0 and one in start and step
    if isinstance(a, slice):
        a = normalize_slice(a)
    if isinstance(b, slice):
        b = normalize_slice(b)

    if isinstance(a, slice) and isinstance(b, Integral):
        if b < 0:
            raise NotImplementedError()
        return a.start + b * a.step

    if isinstance(a, slice) and isinstance(b, slice):
        start = a.start + a.step * b.start
        if b.stop is not None:
            stop = a.start + a.step * b.stop
        else:
            stop = None
        if a.stop is not None:
            # Never run past the end of the first slice
            if stop is not None:
                stop = min(a.stop, stop)
            else:
                stop = a.stop
        step = a.step * b.step
        if step == 1:
            step = None
        return slice(start, stop, step)

    if isinstance(b, list):
        return [fuse_slice(a, bb) for bb in b]
    if isinstance(a, list) and isinstance(b, (Integral, slice)):
        return a[b]

    if isinstance(a, tuple) and not isinstance(b, tuple):
        b = (b,)

    # If given two tuples walk through both, being mindful of uneven sizes
    # and newaxes
    if isinstance(a, tuple) and isinstance(b, tuple):

        # Check for non-fusible cases with fancy-indexing
        a_has_lists = any(isinstance(item, list) for item in a)
        b_has_lists = any(isinstance(item, list) for item in b)
        if a_has_lists and b_has_lists:
            raise NotImplementedError("Can't handle multiple list indexing")
        elif a_has_lists:
            check_for_nonfusible_fancy_indexing(a, b)
        elif b_has_lists:
            check_for_nonfusible_fancy_indexing(b, a)

        j = 0
        result = []
        for a_item in a:
            # axis ceased to exist or we're out of b
            if isinstance(a_item, Integral) or j == len(b):
                result.append(a_item)
                continue
            # Insert any Nones on the rhs. The scan is bounded so that a
            # ``b`` ending in None does not index past its end.
            while j < len(b) and b[j] is None:
                result.append(None)
                j += 1
            if j == len(b):
                # Only newaxes were left in b: this axis keeps a's index
                result.append(a_item)
                continue
            result.append(fuse_slice(a_item, b[j]))  # Common case
            j += 1
        result.extend(b[j:])  # anything leftover on the right?
        return tuple(result)
    raise NotImplementedError()