Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

neilisaac / torch   python

Repository URL to install this package:

/ jit / cuda.py

# mypy: ignore-errors

r"""
This package adds support for JIT compilation for CUDA Streams and events,
This is similar to API's available in the eager mode
:ref:`cuda-semantics` has more details about working with CUDA.
"""

import torch
from typing import Optional, Any
from torch import device as _device

def get_current_device_index() -> int:
    r"""Checks if there are CUDA devices available and
    returns the device index of the current default CUDA device.
    Returns -1 in case there are no CUDA devices available.

    Arguments: ``None``
    """
    if torch.cuda.device_count() > 0:
        return torch.cuda._current_device()
    return -1

def get_device_index(device: Optional[_device] = None, optional: bool = False, allow_cpu: bool = False) -> int:
    r"""Gets the device index from :attr:`device`, which can be a torch.device
    object, a Python integer, or ``None``.

    If :attr:`device` is a torch.device object, returns the device index if it
    is a CUDA device. Note that for a CUDA device without a specified index,
    , this will return the current default CUDA device if :attr:`optional` is ``True``.
    If :attr:`allow_cpu` is ``True``,CPU devices will be accepted and ``-1`` will be
    returned in this case.

    If :attr:`device` is a Python integer, it is returned as is.

    If :attr:`device` is ``None``, this will return the current default CUDA
    device if :attr:`optional` is ``True``.
    """
    if device is None:
        if optional:
            return get_current_device_index()
        else:
            raise ValueError('Expected a torch.device with a specified index '
                             f'or an integer, but got: {device}')
    device_index = -1
    if isinstance(device, str):
        device = torch.device(device)

    if isinstance(device, torch.device):
        if not allow_cpu and device.type == 'cpu':
            raise ValueError(f'Expected a non cpu device, but got: {device}')
        device_index = -1 if device.type == 'cpu' else torch.cuda.device_index(device)

    if isinstance(device, int):
        device_index = device

    return device_index

class device(object):
    r"""Context-manager that changes the selected device.
    This is similar to device (torch.device or int), but has been
    introduced for JIT compatibility.
    Arguments:
        device (torch.device or int): device index to select. It's a no-op if
            this argument is a negative integer or ``None``.
    """
    def __init__(self, device: Optional[_device]):
        self.idx = -1
        self.prev_idx = -1
        self.device = device

    def __enter__(self):
        self.idx = get_device_index(self.device, optional=True)

        if self.idx == -1:
            return
        self.prev_idx = torch.cuda._current_device()

        if self.prev_idx != self.idx:
            torch.cuda._set_device(self.idx)

    def __exit__(self, type: Any, value: Any, traceback: Any):
        if self.prev_idx != self.idx:
            torch.cuda._set_device(self.prev_idx)

class StreamContext(object):
    r"""Context-manager that selects a given stream.
    All CUDA kernels queued within its context will be enqueued on a selected
    stream.
    Arguments:
        StreamContext (Stream): selected stream. This manager is a no-op if it's
            ``None``.
    .. note:: Streams are per-device. If the selected stream is not on the
        current device, this function will also change the current device to
        match the stream.
    """
    cur_stream : Optional['torch.classes.cuda.Stream']

    def __init__(self, stream: Optional['torch.classes.cuda.Stream']):
        self.idx = -1
        self.stream = stream
        # Initialize the below streams to default stream on the current device
        self.device_index = get_current_device_index()
        self.src_prev_stream = torch.cuda.default_stream(self.device_index)
        self.dst_prev_stream = torch.cuda.default_stream(self.device_index)

    def __enter__(self):
        self.idx = get_device_index(device=None, optional=True)
        # If there is no CUDA device available, return
        if self.idx == -1:
            return

        # Local cur_stream variable for type refinement
        cur_stream = self.stream
        # Return if stream is None
        if cur_stream is None:
            return
        self.src_prev_stream = torch.cuda.current_stream(self.idx)
        # If the stream is not on the current device, then change the device
        # and set the current stream on the device
        if self.src_prev_stream.device_index() != cur_stream.device_index():
            with device(cur_stream.device()):
                self.dst_prev_stream = torch.cuda.current_stream(cur_stream.device_index())
            torch.cuda._set_device(cur_stream.device_index())
        torch.cuda.set_stream(cur_stream)

    def __exit__(self, type: Any, value: Any, traceback: Any):
        # Local cur_stream variable for type refinement
        cur_stream = self.stream
        # If stream is None or no CUDA device available, return
        if cur_stream is None or self.idx == -1:
            return
        # If the stream was not on the current device, restore the previous stream on
        # the destination device and also reset the current device to the previous device.
        # Set the current stream on the device to the src_prev_stream
        if self.src_prev_stream.device_index() != cur_stream.device_index():
            torch.cuda.set_stream(self.dst_prev_stream)
            torch.cuda._set_device(self.idx)
        torch.cuda.set_stream(self.src_prev_stream)

def stream(stream: Optional['torch.classes.cuda.Stream']) -> StreamContext:
    r"""Wrapper around the Context-manager that selects a given stream.
    All CUDA kernels queued within its context will be enqueued on a selected
    stream.
    Arguments:
        stream (Stream): selected stream. This manager is a no-op if it's
            ``None``.
    """
    return StreamContext(stream)

def Stream(device: int = -1, priority: int = 0) -> 'torch.classes.cuda.Stream':
    r"""Wrapper around a CUDA stream.
    A CUDA stream is a linear sequence of execution that belongs to a specific
    device, independent from other streams.  See :ref:`cuda-semantics` for
    details.
    Arguments:
        device(int, optional): a device on which to allocate
            the stream. If :attr:`device` is ``None`` (default) or a negative
            integer, this will use the current device.
        priority(int, optional): priority of the stream. Can be either
            -1 (high priority) or 0 (low priority). By default, streams have
            priority 0.
    .. note:: Although CUDA versions >= 11 support more than two levels of
        priorities, in PyTorch, we only support two levels of priorities.
    """
    return torch.classes.cuda.Stream(device, priority)

def Event(enable_timing: bool = False, blocking: bool = False, interprocess: bool = False) -> 'torch.classes.cuda.Event':
    r"""Wrapper around a CUDA event.
    CUDA events are synchronization markers that can be used to monitor the
    device's progress, to accurately measure timing, and to synchronize CUDA
    streams.
    Arguments:
        enable_timing (bool, optional): indicates if the event should measure time
            (default: ``False``)
        blocking (bool, optional): if ``True``, :meth:`wait` will be blocking (default: ``False``)
        interprocess (bool): if ``True``, the event can be shared between processes
            (default: ``False``)
    .. _CUDA Event Documentation:
       https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__EVENT.html
    """
    return torch.classes.cuda.Event(enable_timing, blocking, interprocess)