Why Gemfury? Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

neilisaac / torch   python

Repository URL to install this package:

Version: 1.8.0 

/ include / c10 / core / impl / InlineDeviceGuard.h

#pragma once

// This file provides implementations of InlineDeviceGuard and InlineOptionalDeviceGuard.

#include <c10/core/Device.h>
#include <c10/core/impl/DeviceGuardImplInterface.h>
#include <c10/core/impl/VirtualGuardImpl.h>
#include <c10/util/Optional.h>
#include <c10/util/C++17.h>

namespace c10 {
namespace impl {



/**
 * A DeviceGuard is an RAII class that sets a device to some value
 * on construction, and resets the device to its original value on
 * destruction.
 *
 * InlineDeviceGuard is a helper class for implementing DeviceGuards.
 * It is templated over a DeviceGuardImpl (anything that implements
 * DeviceGuardImplInterface).  There are two primary ways to instantiate
 * InlineDeviceGuard:
 *
 *  - With a concrete implementation of DeviceGuardImpl, e.g., CUDAGuardImpl.
 *    This is the best way to use InlineDeviceGuard, as all calls are
 *    devirtualized, giving you code as efficient as straight line
 *    calls to cudaGetDevice/cudaSetDevice.
 *
 *  - With VirtualGuardImpl, which does a virtual dispatch to a DeviceGuardImpl
 *    retrieved from a DeviceType registry.  We have explicitly instantiated
 *    InlineDeviceGuard this way as c10::DeviceGuard.
 *
 * If you are in a hurry, you can use InlineDeviceGuard directly:
 *
 *    using CUDAGuard = impl::InlineDeviceGuard<CUDAGuardImpl>;
 *
 * However, you can provide a better user experience if you explicitly write a
 * wrapper class that itself contains the template instantiation:
 *
 *    class CUDAGuard {
 *    public:
 *      // ... the API ...
 *    private:
 *      impl::InlineDeviceGuard<CUDAGuardImpl> guard_;
 *    }
 *
 * The wrapper class provides a good place to write documentation, and helps
 * avoid weird template instantiation errors when a user incorrectly uses the
 * class.
 *
 * If you need to test this class, consider instantiating it with FakeGuardImpl.
 */
template <typename T>
class InlineDeviceGuard {
public:
  // Note [Omitted default constructor from RAII]
  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  // In principle, we could add a default constructor to
  // DeviceGuard which reads the current device and promises to
  // restore to that device on exit.  However, most cases where you
  // would have written this, you probably meant to actually just
  // use OptionalDeviceGuard (since you don't actually need the
  // restore to happen if you don't ever actually set the device).
  // We remove the constructor here to encourage you to think about
  // what you actually want to happen.
  explicit InlineDeviceGuard() = delete;

  /// Set the current device to the passed Device.
  explicit InlineDeviceGuard(Device device)
    : impl_(device.type())
    , original_device_(device.index() == -1 ? impl_.getDevice() : impl_.exchangeDevice(device))
    , current_device_(device.index() == -1 ? original_device_ : device)
    {}

  /// Set the current device index to the passed DeviceIndex.  (The
  /// device type is inferred from the template parameter T).
  template <typename U=T, typename=typename std::enable_if<!std::is_same<U, VirtualGuardImpl>::value>::type>
  explicit InlineDeviceGuard(DeviceIndex device_index)
    : InlineDeviceGuard(Device(U::static_type, device_index)) {}

  /// Construct an InlineDeviceGuard using VirtualGuardImpl with an explicit
  /// DeviceGuardImplInterface pointer.
  template <typename U=T, typename=typename std::enable_if<std::is_same<U, VirtualGuardImpl>::value>::type>
  explicit InlineDeviceGuard(Device device, const DeviceGuardImplInterface* impl)
    : impl_(VirtualGuardImpl(impl ? impl : getDeviceGuardImpl(device.type())))
    , original_device_(device.index() == -1 ? impl_.getDevice() : impl_.exchangeDevice(device))
    , current_device_(device.index() == -1 ? original_device_ : device)
    {}

  /// Copy is disallowed
  InlineDeviceGuard(const InlineDeviceGuard<T>&) = delete;
  InlineDeviceGuard<T>& operator=(const InlineDeviceGuard<T>&) = delete;

  /// Move is disallowed, as DeviceGuard does not have an uninitialized state,
  /// which is required for moves on types with nontrivial destructors.
  InlineDeviceGuard(InlineDeviceGuard<T>&& other) = delete;
  InlineDeviceGuard& operator=(InlineDeviceGuard<T>&& other) = delete;

  ~InlineDeviceGuard() {
    impl_.uncheckedSetDevice(original_device_);
  }

  /// Sets the device to the given one.
  template <typename U=T, typename std::enable_if<!std::is_same<U, VirtualGuardImpl>::value, int>::type = 0>
  void set_device(at::Device device) {
    AT_ASSERT((U::static_type == DeviceType::HIP && device.type() == DeviceType::CUDA) ||
              device.type() == U::static_type);
    auto index = device.index();
    if (index == -1) return;
    impl_.setDevice(device);
    current_device_ = device;
  }

  /// Resets the currently set device to its original device, and then sets the
  /// current device to the passed device.  This is effectively equivalent to
  /// set_device when a guard supports only a single device type.
  template <typename U=T>
  typename std::enable_if<!std::is_same<U, VirtualGuardImpl>::value >::type
  reset_device(at::Device device) {
    set_device(device);
  }

  /// Resets the currently set device to its original device, and then sets the
  /// current device to the passed device (for a possibly different device
  /// type).
  ///
  /// This method is named reset_device to highlight the fact that previous
  /// device settings from this guard are NOT preserved, even if the device
  /// has a different device type.  For example:
  ///
  ///   // CUDA device is 0
  ///   DeviceGuard g(Device(kCUDA, 1));
  ///   g.reset_device(Device(kHIP, 2));
  ///   // CUDA device is 0 (!!)
  ///
  /// NOTE: this implementation may skip some device setting if it can prove
  /// that it is unnecessary.
  ///
  /// Optional argument is for testing only.
  template <typename U=T>
  typename std::enable_if<std::is_same<U, VirtualGuardImpl>::value >::type
  reset_device(at::Device device, const impl::DeviceGuardImplInterface* impl = nullptr) {
    auto index = device.index();
    if (index == -1) return;
    if (device.type() == original_device_.type()) {
      AT_ASSERT(impl == nullptr || impl->type() == device.type());
      impl_.setDevice(device);
      current_device_ = device;
    } else {
      // Destruct and reconstruct the DeviceGuard in place
      impl_.setDevice(original_device_);
      impl_ = !impl ? VirtualGuardImpl(device.type()) : VirtualGuardImpl(impl);
      original_device_ = impl_.exchangeDevice(device);
      current_device_ = device;
    }
  }

  /// Sets the device index to the given one.  The device type is inferred
  /// from the original device type.
  void set_index(DeviceIndex index) {
    reset_device(Device(original_device_.type(), index));
  }

  /// Returns the device that was set at the time the most recent
  /// reset_device(), or otherwise the device at construction time.
  Device original_device() const {
    return original_device_;
  }

  /// Returns the most recent device that was set using this device guard,
  /// either from construction, or via set_device/reset_device/set_index.
  Device current_device() const {
    return current_device_;
  }

protected:
  T impl_;

private:
  Device original_device_;
  Device current_device_;
};

/**
 * A OptionalDeviceGuard is an RAII class that sets a device to some value on
 * initialization, and resets the device to its original value on destruction.
 *
 * InlineOptionalDeviceGuard is a helper class for implementing
 * OptionalDeviceGuards.  See guidance in InlineDeviceGuard on how to
 * use this.  See OptionalDeviceGuard for user-oriented usage notes.
 */
template <typename T>
class InlineOptionalDeviceGuard {
public:
  // Note [Explicit initialization of optional fields]
  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  // Explicit initialization of optional fields
  // required to workaround an nvcc bug; see https://github.com/pytorch/pytorch/issues/12117

  /// Creates an uninitialized OptionalDeviceGuard.
  explicit InlineOptionalDeviceGuard()
    : guard_() // See Note [Explicit initialization of optional fields]
    {}

  /// Set the current device to the passed Device, if it is not nullopt.
  explicit InlineOptionalDeviceGuard(optional<Device> device_opt)
    : guard_() { // See Note [Explicit initialization of optional fields]
    if (device_opt.has_value()) {
      guard_.emplace(device_opt.value());
    }
  }

  /// Set the current device to the passed DeviceIndex, if it is not nullopt.
  template <typename U=T, typename=typename std::enable_if<!std::is_same<U, VirtualGuardImpl>::value>::type>
  explicit InlineOptionalDeviceGuard(optional<DeviceIndex> device_index_opt)
    : guard_() { // See Note [Explicit initialization of optional fields]
    if (device_index_opt.has_value()) {
      guard_.emplace(device_index_opt.value());
    }
  }

  /// All constructors of DeviceGuard are valid for OptionalDeviceGuard
  /// and result in initialized OptionalDeviceGuard.
  template <typename... Args>
  explicit InlineOptionalDeviceGuard(Args&&... args)
    : guard_(in_place, std::forward<Args>(args)...) {}

  // TODO: Consider readding Tensor and TensorList constructors here, when
  // Tensor moves to c10.  (These are only valid on OptionalDeviceGuard,
  // because a Tensor may be undefined, in which case we need an uninitialized
  // tensor guard.)

  // Note [Move construction for RAII guards is tricky]
  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  // In principle, move construction is useful for terminating
  // the lifetime of a `OptionalDeviceGuard` early; for example:
  //
  //     // current device is d0
  //     OptionalDeviceGuard g1(d1);
  //     // current device is d1
  //     {
  //       OptionalDeviceGuard g2(std::move(g1));
  //     }
  //     // current device is d0!!
  //
  // However, it's difficult to implement the move constructor
  // in a way that works in all situations.  For example, consider
  // the following example:
  //
  //     OptionalDeviceGuard g1(d1);
  //     {
  //       OptionalDeviceGuard g2(d2);
  //       {
  //         OptionalDeviceGuard g3(std::move(g1)); // !!!
  //       }
  //     }
  //
  // What should the current device be while g3 in scope... and what
  // should it be after it goes out of scope?  What about g2?
  // There don't seem to be satisfactory answers for these questions.
  //
  // It's in principle possible to raise an error when this occurs
  // by doing some extra thread-local bookkeeping.  But why bother?
  // Just don't provide the constructor.
  InlineOptionalDeviceGuard(InlineOptionalDeviceGuard<T>&& other) = delete;

  // Note [Move assignment for RAII guards is tricky]
  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  // Move assignment is deleted, because you need to know which guard was
  // defined "first", as that guard's original_device_ wins--with the current
  // representation, we have no way of telling which is the case.  (Move
  // construction does not have this problem, as one guard is always
  // uninitialized.)
  //
  // We can make this clear by way of a pair of examples:
  //
  // Example 1:
  //
  //  // initial device is n0
  //  {
  //    CUDAGuard g1(n1);
  //    {
  //      CUDAGuard g2(n2);
  //      // current device should be n2
  //      g1 = std::move(g2);
  //      // current device should still be n2
  //    }
  //    // current device should still be n2
  //  }
  //  // current device should be n0
  //
  //  Example 2 (flip the order of the two guards):
  //
  //  // initial device is n0
  //  {
  //    CUDAGuard g2(n2);
  //    {
  //      CUDAGuard g1(n1);
  //      // current device should be n1
  //      g1 = std::move(g2);
  //      // current device should be n2
  //    }
  //    // current device should be n0 (since g2 has been vacated)
  //  }
  //
  // In both examples, we need g1 to restore to n0 after move assignment.
  // However, in example 1, this is determined by the restore value of g1
  // (prior to the move). In example 2, however, it is determined by the the
  // restore value of g2(!!). We don't know which one should win, without having
  // a way of telling which guard was allocated first.
  //
  // We could solve this with an extra thread-local variable.  But no one is
  // actually using move-assignment.  So just get rid of it.
  InlineOptionalDeviceGuard& operator=(InlineOptionalDeviceGuard&& other) = delete;

  /// Sets the device to the given one.  Initializes OptionalDeviceGuard if it
  /// is not already initialized.
  template <typename U=T, typename=typename std::enable_if<!std::is_same<U, VirtualGuardImpl>::value>::type>
  void set_device(at::Device device) {
    if (!guard_.has_value()) {
      guard_.emplace(device);
    } else {
      guard_->set_device(device);
    }
  }

  /// Resets the currently set device to its original device, and then sets the
  /// current device to the passed device (for a possibly different device
  /// type).  Initializes OptionalDeviceGuard if it is not already initialized.
  ///
  /// See notes on why this is called reset_device on InlineDeviceGuard.
  ///
  /// Optional argument is for testing only.
  template <typename U=T, typename=typename std::enable_if<std::is_same<U, VirtualGuardImpl>::value>::type>
  void reset_device(at::Device device, const DeviceGuardImplInterface* impl = nullptr) {
    if (!guard_.has_value()) {
      guard_.emplace(device, impl);
    } else {
      guard_->reset_device(device, impl);
    }
  }

  /// Resets the currently set device to its original device, and then sets the
  /// current device to the passed device.  Initializes the guard if it is
  /// not already initialized.  This is effectively equivalent to set_device
  /// when a guard supports only a single device type.
  template <typename U=T, typename=typename std::enable_if<!std::is_same<U, VirtualGuardImpl>::value>::type>
  void reset_device(at::Device device) {
    if (!guard_.has_value()) {
      guard_.emplace(device);
    } else {
      guard_->reset_device(device);
    }
  }

  /// Sets the device index to the given one.  The device type is statically
  /// known.
  template <typename U=T, typename=typename std::enable_if<!std::is_same<U, VirtualGuardImpl>::value >::type>
  void set_index(DeviceIndex index) {
    if (!guard_.has_value()) {
      guard_.emplace(index);
    } else {
      guard_->set_index(index);
    }
  }

  /// Returns the device that was set immediately prior to initialization of the,
  /// guard, or nullopt if the guard is uninitialized.
  optional<Device> original_device() const {
    return guard_.has_value() ? make_optional(guard_->original_device()) : nullopt;
  }

  /// Returns the most recent device that was set using this device guard,
  /// either from construction, or via set_device, if the guard is initialized,
  /// or nullopt if the guard is uninitialized.
  optional<Device> current_device() const {
    return guard_.has_value() ? make_optional(guard_->current_device()) : nullopt;
  }

  /// Restore the original device, resetting this guard to uninitialized state.
  void reset() {
    guard_.reset();
  }

private:
  optional<InlineDeviceGuard<T>> guard_;
};

}} // namespace c10::impl