Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

arrow-nightlies / nanoarrow   python

Repository URL to install this package:

/ src / nanoarrow / _device.pyx

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

# cython: language_level = 3

from libc.stdint cimport uintptr_t, int64_t

from nanoarrow_device_c cimport (
    ARROW_DEVICE_CPU,
    ARROW_DEVICE_CUDA,
    ARROW_DEVICE_CUDA_HOST,
    ARROW_DEVICE_OPENCL,
    ARROW_DEVICE_VULKAN,
    ARROW_DEVICE_METAL,
    ARROW_DEVICE_VPI,
    ARROW_DEVICE_ROCM,
    ARROW_DEVICE_ROCM_HOST,
    ARROW_DEVICE_EXT_DEV,
    ARROW_DEVICE_CUDA_MANAGED,
    ARROW_DEVICE_ONEAPI,
    ARROW_DEVICE_WEBGPU,
    ARROW_DEVICE_HEXAGON,
    ArrowDevice,
    ArrowDeviceCpu,
    ArrowDeviceResolve
)

from nanoarrow._utils cimport Error

from enum import Enum

from nanoarrow import _repr_utils


class DeviceType(Enum):
    """
    An enumerator providing access to the device constant values
    defined in the Arrow C Device interface. Unlike the other enum
    accessors, this Python Enum is defined in Cython so that we can use
    the built-in functionality to do better printing of device identifiers
    for classes defined in Cython. Unlike the other enums, users don't
    typically need to specify these (but would probably like them printed
    nicely).
    """

    # Each member's value is the corresponding ARROW_DEVICE_* constant
    # cimported from nanoarrow_device_c at the top of this module.
    CPU = ARROW_DEVICE_CPU
    CUDA = ARROW_DEVICE_CUDA
    CUDA_HOST = ARROW_DEVICE_CUDA_HOST
    OPENCL = ARROW_DEVICE_OPENCL
    VULKAN =  ARROW_DEVICE_VULKAN
    METAL = ARROW_DEVICE_METAL
    VPI = ARROW_DEVICE_VPI
    ROCM = ARROW_DEVICE_ROCM
    ROCM_HOST = ARROW_DEVICE_ROCM_HOST
    EXT_DEV = ARROW_DEVICE_EXT_DEV
    CUDA_MANAGED = ARROW_DEVICE_CUDA_MANAGED
    ONEAPI = ARROW_DEVICE_ONEAPI
    WEBGPU = ARROW_DEVICE_WEBGPU
    HEXAGON = ARROW_DEVICE_HEXAGON


cdef class Device:
    """ArrowDevice wrapper

    The ArrowDevice structure is a nanoarrow internal struct (i.e.,
    not ABI stable) that contains callbacks for device operations
    beyond its type and identifier (e.g., copy buffers to or from
    a device).

    Parameters
    ----------
    base : object
        Object stored on the wrapper alongside the pointer (``None`` is
        passed for devices that are not owned by a Python object).
    addr : int
        Address of the underlying ``ArrowDevice`` struct.
    """

    def __cinit__(self, object base, uintptr_t addr):
        # Store base itself: a trailing comma here would silently wrap it
        # in a 1-tuple.
        self._base = base
        self._ptr = <ArrowDevice*>addr

    def __eq__(self, other) -> bool:
        # Compare the raw integer identifiers so that equality never has to
        # construct a DeviceType member (which raises ValueError for a type
        # id that is not part of the enum).
        return (
            isinstance(other, Device) and
            other.device_type_id == self.device_type_id and
            other.device_id == self.device_id
        )

    def __hash__(self):
        # Hash the same fields that __eq__ compares so that equal devices
        # hash equally.
        return hash((self._ptr.device_type, self._ptr.device_id))

    def __repr__(self):
        return _repr_utils.device_repr(self)

    @property
    def device_type(self):
        """The device type as a :class:`DeviceType` member."""
        return DeviceType(self._ptr.device_type)

    @property
    def device_type_id(self):
        """The device type as the raw integer stored in the ArrowDevice."""
        return self._ptr.device_type

    @property
    def device_id(self):
        """The device identifier stored in the ArrowDevice."""
        return self._ptr.device_id

    @staticmethod
    def resolve(device_type, int64_t device_id):
        """Look up the Device for a device type/identifier pair.

        Parameters
        ----------
        device_type : int or DeviceType
            The device type, either as an integer constant or as a
            :class:`DeviceType` member (e.g., the value returned by
            :attr:`Device.device_type`).
        device_id : int
            The device identifier.

        Returns
        -------
        Device
            The cached CPU device when ``device_type`` is the CPU type;
            otherwise a wrapper around the pointer returned by
            ``ArrowDeviceResolve()``.

        Raises
        ------
        ValueError
            If ``ArrowDeviceResolve()`` returns NULL.
        """
        # DeviceType is a plain Enum (not IntEnum), so int() on a member
        # would fail; unwrap it to its integer value first.
        if isinstance(device_type, DeviceType):
            device_type = device_type.value

        if int(device_type) == ARROW_DEVICE_CPU:
            return DEVICE_CPU

        cdef ArrowDevice* c_device = ArrowDeviceResolve(device_type, device_id)
        if c_device == NULL:
            raise ValueError(f"Device not found for type {device_type}/{device_id}")

        return Device(None, <uintptr_t>c_device)


# Cache the CPU device as a module-level singleton; Device.resolve() returns
# this object whenever the CPU device type is requested.
# The CPU device is statically allocated (so base is None)
DEVICE_CPU = Device(None, <uintptr_t>ArrowDeviceCpu())


cdef class CSharedSyncEvent:
    """Pairing of a Device with an opaque synchronization event pointer

    Parameters
    ----------
    device : Device
        The device whose ``synchronize_event`` callback is invoked.
    sync_event : int, optional
        Address of the sync event. The default of 0 is stored as NULL.
    """

    def __cinit__(self, Device device, uintptr_t sync_event=0):
        self.sync_event = <void*>sync_event
        self.device = device

    cdef synchronize(self):
        # Invoke the device's synchronize_event callback with a NULL stream.
        # Nothing is done when no event is held. After the error check
        # passes the stored event is reset to NULL, so a later call is a
        # no-op.
        cdef Error err
        cdef ArrowDevice* dev
        cdef int status

        if self.sync_event != NULL:
            err = Error()
            dev = self.device._ptr
            status = dev.synchronize_event(dev, self.sync_event, NULL, &err.c_error)
            err.raise_message_not_ok("ArrowDevice::synchronize_event", status)

            self.sync_event = NULL

    cdef synchronize_stream(self, uintptr_t stream):
        # Invoke the device's synchronize_event callback, passing the given
        # stream address along with the stored event. Unlike synchronize(),
        # the stored event is neither checked for NULL nor reset afterwards.
        cdef ArrowDevice* dev = self.device._ptr
        cdef Error err = Error()
        cdef int status = dev.synchronize_event(
            dev,
            self.sync_event,
            <void*>stream,
            &err.c_error
        )
        err.raise_message_not_ok("ArrowDevice::synchronize_event with stream", status)