neilisaac / torch (Python), version 1.8.0
include/c10/cuda/impl/CUDAGuardImpl.h

#pragma once

#include <c10/core/impl/DeviceGuardImplInterface.h>
#include <c10/macros/Macros.h>
#include <c10/util/Exception.h>

#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAStream.h>
#include <c10/cuda/CUDAFunctions.h>

#include <cuda_runtime_api.h>

namespace c10 {
namespace cuda {
namespace impl {

struct CUDAGuardImpl final : public c10::impl::DeviceGuardImplInterface {
  static constexpr DeviceType static_type = DeviceType::CUDA;

  CUDAGuardImpl() {}
  explicit CUDAGuardImpl(DeviceType t) {
    TORCH_INTERNAL_ASSERT(t == DeviceType::CUDA);
  }
  DeviceType type() const override {
    return DeviceType::CUDA;
  }
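  // Sets the current CUDA device to d (if it differs from the current one)
  // and returns the device that was previously current.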
  Device exchangeDevice(Device d) const override {
    TORCH_INTERNAL_ASSERT(d.type() == DeviceType::CUDA);
    Device old_device = getDevice();
    if (old_device.index() != d.index()) {
      C10_CUDA_CHECK(cudaSetDevice(d.index()));
    }
    return old_device;
  }
  Device getDevice() const override {
    int device;
    C10_CUDA_CHECK(cudaGetDevice(&device));
    return Device(DeviceType::CUDA, device);
  }
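  // Like getDevice(), but reports failure by returning nullopt (with a
  // warning) instead of throwing, so it is safe from noexcept code paths.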
  c10::optional<Device> uncheckedGetDevice() const noexcept {
    int device;
    auto err = cudaGetDevice(&device);
    C10_CUDA_CHECK_WARN(err);
    if (err != cudaSuccess) {
      return c10::nullopt;
    }
    return Device(DeviceType::CUDA, device);
  }
  void setDevice(Device d) const override {
    TORCH_INTERNAL_ASSERT(d.type() == DeviceType::CUDA);
    Device current_device = getDevice();
    if (current_device != d) {
      C10_CUDA_CHECK(cudaSetDevice(d.index()));
    }
  }
  void uncheckedSetDevice(Device d) const noexcept override {
    auto current_device = uncheckedGetDevice();
    if (!current_device.has_value() || current_device.value() != d) {
      C10_CUDA_CHECK_WARN(cudaSetDevice(d.index()));
    }
  }
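  // Returns the CUDA stream that is currently set for device d.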
  Stream getStream(Device d) const noexcept override {
    return getCurrentCUDAStream(d.index()).unwrap();
  }
  Stream getDefaultStream(Device d) const override {
    return getDefaultCUDAStream(d.index());
  }
  // NB: These do NOT set the current device
  Stream exchangeStream(Stream s) const noexcept override {
    CUDAStream cs(s);
    auto old_stream = getCurrentCUDAStream(s.device().index());
    setCurrentCUDAStream(cs);
    return old_stream.unwrap();
  }
  DeviceIndex deviceCount() const noexcept override {
    return device_count();
  }

  // Event-related functions
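  // Creates a CUDA event for the given PyTorch EventFlag.  Note that the
  // PyTorch default maps to cudaEventDisableTiming, which is the cheaper
  // option when elapsed-time queries are not needed.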
  void createEvent(
    cudaEvent_t* cuda_event,
    const EventFlag flag) const {
    // Maps PyTorch's Event::Flag to CUDA flag
    auto cuda_flag = cudaEventDefault;
    switch (flag) {
      case EventFlag::PYTORCH_DEFAULT:
      case EventFlag::CUDA_EVENT_DISABLE_TIMING:
        cuda_flag = cudaEventDisableTiming;
        break;
      case EventFlag::BACKEND_DEFAULT:
      case EventFlag::CUDA_EVENT_DEFAULT:
        cuda_flag = cudaEventDefault;
        break;
      default:
        TORCH_CHECK(false, "CUDA event received unknown flag");
    }

    C10_CUDA_CHECK(cudaEventCreateWithFlags(cuda_event, cuda_flag));
  }

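  // Destroys the event after switching to the device it lives on.  Errors are
  // only warned about (never thrown), since this is noexcept and may run
  // during cleanup.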
  void destroyEvent(
    void* event,
    const DeviceIndex device_index) const noexcept override {
    if (!event) return;
    auto cuda_event = static_cast<cudaEvent_t>(event);
    int orig_device;
    C10_CUDA_CHECK_WARN(cudaGetDevice(&orig_device));
    C10_CUDA_CHECK_WARN(cudaSetDevice(device_index));
    C10_CUDA_CHECK_WARN(cudaEventDestroy(cuda_event));
    C10_CUDA_CHECK_WARN(cudaSetDevice(orig_device));
  }

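  // Records the event on the given stream.  The event is created lazily on
  // first record; *event is updated to the resulting cudaEvent_t.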
  void record(
    void** event,
    const Stream& stream,
    const DeviceIndex device_index,
    const EventFlag flag) const override {
    TORCH_CHECK(device_index == -1 || device_index == stream.device_index(),
      "Event device index ",
      device_index,
      " does not match recording stream's device index ",
      stream.device_index(),
      ".");

    cudaEvent_t cuda_event = static_cast<cudaEvent_t>(*event);
    CUDAStream cuda_stream{stream};

    // Moves to stream's device to record
    const auto orig_device = getDevice();
    setDevice(stream.device());

    // Creates the event (lazily)
    if (!cuda_event) createEvent(&cuda_event, flag);
    C10_CUDA_CHECK(cudaEventRecord(cuda_event, cuda_stream));
    // Makes the void* point to the (possibly just allocated) CUDA event
    *event = cuda_event;

    // Resets device
    setDevice(orig_device);
  }

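  // Makes all future work submitted to the stream wait for the event.  The
  // wait is enqueued on the device; the host thread is not blocked.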
  void block(
    void* event,
    const Stream& stream) const override {
    if (!event) return;
    cudaEvent_t cuda_event = static_cast<cudaEvent_t>(event);
    CUDAStream cuda_stream{stream};
    const auto orig_device = getDevice();
    setDevice(stream.device());
    C10_CUDA_CHECK(cudaStreamWaitEvent(
      cuda_stream,
      cuda_event,
      /*flags (must be zero)=*/ 0));
    setDevice(orig_device);
  }

  // May be called from any device
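  // Returns true if the work captured by the event has completed (or if the
  // event was never recorded); cudaErrorNotReady is not treated as an error.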
  bool queryEvent(void* event) const override {
    if (!event) return true;
    cudaEvent_t cuda_event = static_cast<cudaEvent_t>(event);
    const cudaError_t err = cudaEventQuery(cuda_event);
    if (err != cudaErrorNotReady) {
      C10_CUDA_CHECK(err);
    }
    return (err == cudaSuccess);
  }
};

}}} // namespace c10::cuda::impl
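
// Usage sketch (illustrative, not part of this header): callers normally reach
// this implementation through the RAII guard types rather than invoking it
// directly.  Assuming c10/cuda/CUDAGuard.h is available, a typical pattern is:
//
//   #include <c10/cuda/CUDAGuard.h>
//
//   void run_on(c10::DeviceIndex idx) {
//     c10::cuda::CUDAGuard guard(idx); // switches the current device via this impl
//     // ... launch kernels; the current device is `idx` within this scope ...
//   }                                  // guard's destructor restores the prior device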