Why Gemfury? Push, build, and install RubyGems, npm packages, Python packages, Maven artifacts, PHP packages, Go Modules, Debian packages, RPM packages, and NuGet packages.

Repository URL to install this package:

Details    
dask / dask / array / cupy_entry_point.py
Size: Mime:
import dask.array as da
from dask import config
from dask.array.backends import ArrayBackendEntrypoint, register_cupy


def _cupy(strict=True):
    try:
        import cupy
    except ImportError:
        if strict:
            raise ImportError("Please install `cupy` to use `CupyBackendEntrypoint`")
        return None
    return cupy


def _da_with_cupy_meta(attr, *args, meta=None, **kwargs):
    """Call the ``dask.array`` function named ``attr`` with a cupy ``meta``.

    When ``meta`` is not supplied, ``cupy.empty(())`` is used in its place;
    this raises ``ImportError`` if cupy is not installed. The call runs with
    the ``array.backend`` config option temporarily set to ``"numpy"``.
    Presumably this keeps the ``dask.array`` function from dispatching back
    into the cupy backend entrypoint (confirm).

    Parameters
    ----------
    attr : str
        Name of the ``dask.array`` function to call, e.g. ``"ones"``.
    *args, **kwargs
        Forwarded unchanged to that function.
    meta : optional
        Explicit ``meta`` to pass through. Defaults to ``cupy.empty(())``.
    """
    if meta is None:
        meta = _cupy().empty(())
    numpy_backend = {"array.backend": "numpy"}
    with config.set(numpy_backend):
        array_func = getattr(da, attr)
        return array_func(*args, meta=meta, **kwargs)


class CupyBackendEntrypoint(ArrayBackendEntrypoint):
    """Array backend entrypoint whose creation functions produce cupy-backed arrays.

    Each creation function delegates to the same-named ``dask.array``
    function. It supplies a cupy ``meta`` (or, for ``arange``, ``like``)
    array whenever the caller did not pass one.
    """

    def __init__(self):
        """Register cupy's data-directed dispatch functions if cupy imports."""
        if not _cupy(strict=False):
            # cupy is unavailable: skip registration instead of failing.
            return
        register_cupy()

    @property
    def RandomState(self):
        """``cupy.random.RandomState``. Raises ImportError if cupy is missing."""
        cupy_random = _cupy().random
        return cupy_random.RandomState

    @staticmethod
    def ones(*args, **kwargs):
        """``dask.array.ones`` with a cupy ``meta`` unless one is given."""
        func_name = "ones"
        return _da_with_cupy_meta(func_name, *args, **kwargs)

    @staticmethod
    def zeros(*args, **kwargs):
        """``dask.array.zeros`` with a cupy ``meta`` unless one is given."""
        func_name = "zeros"
        return _da_with_cupy_meta(func_name, *args, **kwargs)

    @staticmethod
    def empty(*args, **kwargs):
        """``dask.array.empty`` with a cupy ``meta`` unless one is given."""
        func_name = "empty"
        return _da_with_cupy_meta(func_name, *args, **kwargs)

    @staticmethod
    def full(*args, **kwargs):
        """``dask.array.full`` with a cupy ``meta`` unless one is given."""
        func_name = "full"
        return _da_with_cupy_meta(func_name, *args, **kwargs)

    @staticmethod
    def arange(*args, like=None, **kwargs):
        """``dask.array.arange`` with a cupy ``like`` unless one is given."""
        if like is None:
            like = _cupy().empty(())
        # Presumably forces the numpy backend so ``da.arange`` does not
        # dispatch back into this entrypoint (confirm).
        with config.set({"array.backend": "numpy"}):
            return da.arange(*args, like=like, **kwargs)