# Repository URL to install this package: (URL missing from this copy)
# Version: 2022.10.0
import dask.array as da
from dask import config
from dask.array.backends import ArrayBackendEntrypoint, register_cupy
def _cupy(strict=True):
try:
import cupy
except ImportError:
if strict:
raise ImportError("Please install `cupy` to use `CupyBackendEntrypoint`")
return None
return cupy
def _da_with_cupy_meta(attr, *args, meta=None, **kwargs):
    """Call the ``dask.array`` function named *attr* with CuPy-based ``meta``.

    Parameters
    ----------
    attr : str
        Name of the ``dask.array`` function to call (for example ``"ones"``).
    *args, **kwargs
        Forwarded unchanged to that function.
    meta : optional
        Meta object passed through as ``meta=``. When omitted, a 0-d array
        from ``cupy.empty(())`` is used, which requires ``cupy`` to be
        importable (an ImportError is raised otherwise).

    Returns
    -------
    Whatever ``dask.array.<attr>`` returns.
    """
    if meta is None:
        meta = _cupy().empty(())
    # Run under the "numpy" array backend. Presumably this keeps the
    # dask.array creation function from dispatching back into the CuPy
    # entrypoint; confirm against dask's backend dispatch.
    with config.set({"array.backend": "numpy"}):
        creator = getattr(da, attr)
        return creator(*args, meta=meta, **kwargs)
class CupyBackendEntrypoint(ArrayBackendEntrypoint):
    """Array-backend entrypoint whose creation functions use CuPy.

    ``ones``, ``zeros``, ``empty``, ``full`` and ``arange`` delegate to the
    matching ``dask.array`` function. They pass a CuPy array as ``meta``
    (or ``like`` for ``arange``) unless the caller supplies one.
    """

    def __init__(self):
        """Register data-directed dispatch functions.

        ``register_cupy()`` is called only when ``cupy`` is importable. If it
        is not, construction still succeeds silently.
        """
        # Initialise the base entrypoint so any base-class setup is not skipped.
        super().__init__()
        if _cupy(strict=False):
            register_cupy()

    @property
    def RandomState(self):
        """Return ``cupy.random.RandomState``.

        Raises ImportError if ``cupy`` is not importable.
        """
        return _cupy().random.RandomState

    @staticmethod
    def ones(*args, **kwargs):
        """Call ``dask.array.ones`` with a CuPy-based ``meta`` by default."""
        return _da_with_cupy_meta("ones", *args, **kwargs)

    @staticmethod
    def zeros(*args, **kwargs):
        """Call ``dask.array.zeros`` with a CuPy-based ``meta`` by default."""
        return _da_with_cupy_meta("zeros", *args, **kwargs)

    @staticmethod
    def empty(*args, **kwargs):
        """Call ``dask.array.empty`` with a CuPy-based ``meta`` by default."""
        return _da_with_cupy_meta("empty", *args, **kwargs)

    @staticmethod
    def full(*args, **kwargs):
        """Call ``dask.array.full`` with a CuPy-based ``meta`` by default."""
        return _da_with_cupy_meta("full", *args, **kwargs)

    @staticmethod
    def arange(*args, like=None, **kwargs):
        """Call ``dask.array.arange`` with a CuPy-based ``like`` by default.

        Unlike the other creation functions here, ``arange`` is given the
        template array through ``like`` rather than ``meta``. When ``like``
        is omitted, ``cupy.empty(())`` is used, which requires ``cupy``.
        """
        like = _cupy().empty(()) if like is None else like
        # Presumably forcing the "numpy" backend prevents da.arange from
        # dispatching back into this entrypoint; confirm against dask.
        with config.set({"array.backend": "numpy"}):
            return da.arange(*args, like=like, **kwargs)