Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

edgify / torch   python

Repository URL to install this package:

Version: 2.0.1+cpu 

/ include / c10 / core / DeviceGuard.h

#pragma once

#include <c10/core/impl/InlineDeviceGuard.h>

namespace c10 {

/// RAII guard that sets a certain default device in its constructor, and
/// changes it back to the device that was originally active upon destruction.
///
/// The device is always reset to the one that was active at the time of
/// construction of the guard. Even if you call `reset_device` or `set_index`
/// after construction, the destructor will still reset the device to the one
/// that was active at construction time.
///
/// This device guard does NOT have an uninitialized state; it is guaranteed
/// to reset a device on exit.  If you are in a situation where you *might*
/// want to set up a guard (i.e., are looking for the moral equivalent
/// of optional<DeviceGuard>), see OptionalDeviceGuard.
///
/// Every operation forwards to the wrapped
/// `impl::InlineDeviceGuard<impl::VirtualGuardImpl>`; see
/// Note [Whither the DeviceGuard boilerplate] for why this wrapper exists.
class DeviceGuard {
 public:
  /// No default constructor; see Note [Omitted default constructor from RAII]
  explicit DeviceGuard() = delete;

  /// Set the current device to the passed Device.
  explicit DeviceGuard(Device device) : guard_(device) {}

  /// Like DeviceGuard(Device), but also passes an explicit guard
  /// implementation `impl` through to the underlying guard.
  /// This constructor is for testing only.
  explicit DeviceGuard(
      Device device,
      const impl::DeviceGuardImplInterface* impl)
      : guard_(device, impl) {}

  /// Copy is disallowed
  DeviceGuard(const DeviceGuard&) = delete;
  DeviceGuard& operator=(const DeviceGuard&) = delete;

  /// Move is disallowed, as DeviceGuard does not have an uninitialized state,
  /// which is required for moves on types with nontrivial destructors.
  DeviceGuard(DeviceGuard&& other) = delete;
  DeviceGuard& operator=(DeviceGuard&& other) = delete;

  /// Sets the device to the given one.  The specified device must be consistent
  /// with the device type originally specified during guard construction.
  /// This does not change the device restored on destruction.
  ///
  /// TODO: The consistency check here is inconsistent with StreamGuard's
  /// behavior with set_stream, where a stream on a different device than
  /// the original one isn't an error; we just reset the stream and then
  /// switch devices.
  void reset_device(at::Device device) {
    guard_.reset_device(device);
  }

  /// Like reset_device(Device), but also passes an explicit guard
  /// implementation `impl` through to the underlying guard.
  /// This method is for testing only.
  void reset_device(
      at::Device device,
      const impl::DeviceGuardImplInterface* impl) {
    guard_.reset_device(device, impl);
  }

  /// Sets the device index to the given one.  The device type is inferred
  /// from the original device type the guard was constructed with.
  void set_index(DeviceIndex index) {
    guard_.set_index(index);
  }

  /// Returns the device that was set at the time the guard was constructed.
  /// This is the device that will be restored on destruction.
  Device original_device() const {
    return guard_.original_device();
  }

  /// Returns the most recent device that was set using this device guard,
  /// either from construction, or via reset_device / set_index.
  Device current_device() const {
    return guard_.current_device();
  }

 private:
  impl::InlineDeviceGuard<impl::VirtualGuardImpl> guard_;
};

/**
 * An OptionalDeviceGuard is an RAII class that sets a device to some value on
 * initialization, and resets the device to its original value on destruction.
 * Morally, an OptionalDeviceGuard is equivalent to optional<DeviceGuard>, but
 * with extra constructors and methods as appropriate.
 *
 * Besides its obvious use (optionally applying a DeviceGuard),
 * OptionalDeviceGuard is often also used for the following idiom:
 *
 *    OptionalDeviceGuard g;
 *    for (const auto& t : tensors) {
 *      g.reset_device(t.device());
 *      do_something_with(t);
 *    }
 *
 * This usage is marginally more efficient than constructing a DeviceGuard every
 * iteration of the for loop, as it avoids an unnecessary device reset.
 *
 * Unlike DeviceGuard, an OptionalDeviceGuard may be uninitialized.  This occurs
 * when you use the nullary constructor, or pass a nullopt to the constructor.
 * Uninitialized OptionalDeviceGuards do *nothing*; they do not know what the
 * original device was and they do not reset on destruction.  This is why
 * original_device() and current_device() return optional<Device> rather than
 * Device (as they do in DeviceGuard), and also is why we didn't just
 * provide OptionalDeviceGuard by default and hide DeviceGuard from users.
 *
 * The semantics of an OptionalDeviceGuard are exactly explained by thinking
 * of it as an optional<DeviceGuard>.  In particular, an initialized
 * OptionalDeviceGuard doesn't restore device to its value at construction; it
 * restores device to its value *at initialization*.  So if you have the
 * program:
 *
 *     setDevice(1);
 *     OptionalDeviceGuard g;
 *     setDevice(2);
 *     g.reset_device(Device(DeviceType::CUDA, 3));  // initializes!
 *
 * On destruction, g will reset device to 2, rather than 1.
 *
 * An uninitialized OptionalDeviceGuard is distinct from an (initialized)
 * DeviceGuard whose original_device() and current_device() match, since the
 * DeviceGuard will still reset the device to its original_device().
 */
class OptionalDeviceGuard {
 public:
  /// Create an uninitialized guard.  Set the guard later using reset_device.
  explicit OptionalDeviceGuard() = default;

  /// Initialize the guard, setting the current device to the passed Device.
  explicit OptionalDeviceGuard(Device device) : guard_(device) {}

  /// Initialize the guard if a Device is passed; otherwise leave the
  /// guard uninitialized.
  explicit OptionalDeviceGuard(optional<Device> device) : guard_(device) {}

  /// Like OptionalDeviceGuard(Device), but also passes an explicit guard
  /// implementation `impl` through to the underlying guard.
  /// Constructor for testing only.
  explicit OptionalDeviceGuard(
      Device device,
      const impl::DeviceGuardImplInterface* impl)
      : guard_(device, impl) {}

  /// Copy is disallowed
  OptionalDeviceGuard(const OptionalDeviceGuard&) = delete;
  OptionalDeviceGuard& operator=(const OptionalDeviceGuard&) = delete;

  /// Move is disallowed
  /// See Note [Explicit initialization of optional fields]
  /// and Note [Move construction for RAII guards is tricky]
  /// for rationale.
  OptionalDeviceGuard(OptionalDeviceGuard&& other) = delete;
  OptionalDeviceGuard& operator=(OptionalDeviceGuard&& other) = delete;

  /// Sets the device to the given one.  If the guard is uninitialized, this
  /// initializes it, and the device active at that moment becomes the one
  /// restored on destruction.  If the guard is already initialized, the
  /// specified device must be consistent with the device type it was
  /// initialized with.
  void reset_device(at::Device device) {
    guard_.reset_device(device);
  }

  /// Like reset_device(Device), but also passes an explicit guard
  /// implementation `impl` through to the underlying guard.
  /// For testing only.
  void reset_device(
      at::Device device,
      const impl::DeviceGuardImplInterface* impl) {
    guard_.reset_device(device, impl);
  }

  /// Returns the device that was active when the guard was *initialized*
  /// (which may be later than construction), i.e. the device restored on
  /// destruction; nullopt if the guard is uninitialized.
  optional<Device> original_device() const {
    return guard_.original_device();
  }

  /// Returns the most recent device that was set using this device guard,
  /// either from construction, or via reset_device; nullopt if the guard is
  /// uninitialized.
  optional<Device> current_device() const {
    return guard_.current_device();
  }

 private:
  impl::InlineOptionalDeviceGuard<impl::VirtualGuardImpl> guard_{};
};

// Note [Whither the DeviceGuard boilerplate]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// Design note: in principle, we could avoid these wrappers using:
//
// using DeviceGuard = impl::InlineDeviceGuard<impl::VirtualGuardImpl>;
// using OptionalDeviceGuard =
// impl::InlineOptionalDeviceGuard<impl::VirtualGuardImpl>;
//
// But the error messages are worse, and our users can't just look at the
// header file to find out what's going on.  Furthermore, for specializations
// like CUDAStreamGuard, it can be profitable to replace some interfaces with
// refined types (e.g., return CUDAStream instead of Stream).  So, we eat
// the boilerplate and write out the API explicitly.

} // namespace c10