#pragma once
#include <c10/core/impl/InlineDeviceGuard.h>
namespace c10 {
/// RAII guard that sets a certain default device in its constructor, and
/// changes it back to the device that was originally active upon destruction.
///
/// The device is always reset to the one that was active at the time of
/// construction of the guard. Even if you call `reset_device` or `set_index`
/// after construction, the destructor will still reset the device to the one
/// that was active at construction time.
///
/// This device guard does NOT have an uninitialized state; it is guaranteed
/// to reset a device on exit. If you are in a situation where you *might*
/// want to setup a guard (i.e., are looking for the moral equivalent
/// of optional<DeviceGuard>), see OptionalDeviceGuard.
class DeviceGuard {
public:
/// No default constructor; see Note [Omitted default constructor from RAII]
explicit DeviceGuard() = delete;
/// Set the current device to the passed Device.
explicit DeviceGuard(Device device) : guard_(device) {}
/// This constructor is for testing only.
explicit DeviceGuard(Device device, const impl::DeviceGuardImplInterface* impl) : guard_(device, impl) {}
/// Copy is disallowed
DeviceGuard(const DeviceGuard&) = delete;
DeviceGuard& operator=(const DeviceGuard&) = delete;
/// Move is disallowed, as DeviceGuard does not have an uninitialized state,
/// which is required for moves on types with nontrivial destructors.
DeviceGuard(DeviceGuard&& other) = delete;
DeviceGuard& operator=(DeviceGuard&& other) = delete;
/// Sets the device to the given one. The specified device must be consistent
/// with the device type originally specified during guard construction.
///
/// TODO: The consistency check here is inconsistent with StreamGuard's
/// behavior with set_stream, where a stream on a different device than
/// the original one isn't an error; we just reset the stream and then
/// switch devices.
void reset_device(at::Device device) {
guard_.reset_device(device);
}
/// This method is for testing only.
void reset_device(at::Device device, const impl::DeviceGuardImplInterface* impl) {
guard_.reset_device(device, impl);
}
/// Sets the device index to the given one. The device type is inferred
/// from the original device type the guard was constructed with.
void set_index(DeviceIndex index) {
guard_.set_index(index);
}
/// Returns the device that was set at the time the guard was constructed.
Device original_device() const {
return guard_.original_device();
}
/// Returns the most recent device that was set using this device guard,
/// either from construction, or via set_device.
Device current_device() const {
return guard_.current_device();
}
private:
impl::InlineDeviceGuard<impl::VirtualGuardImpl> guard_;
};
/**
* An OptionalDeviceGuard is an RAII class that sets a device to some value on
* initialization, and resets the device to its original value on destruction.
* Morally, an OptionalDeviceGuard is equivalent to optional<DeviceGuard>, but
* with extra constructors and methods as appropriate.
*
* Besides its obvious use (optionally applying a DeviceGuard), OptionalDeviceGuard
* is often also used for the following idiom:
*
* OptionalDeviceGuard g;
* for (const auto& t : tensors) {
* g.reset_device(t.device());
* do_something_with(t);
* }
*
* This usage is marginally more efficient than constructing a DeviceGuard every
* iteration of the for loop, as it avoids an unnecessary device reset.
*
* Unlike DeviceGuard, an OptionalDeviceGuard may be uninitialized. This occurs
* when you use the nullary constructor, or pass a nullopt to the constructor.
* Uninitialized OptionalDeviceGuards do *nothing*; they do not know what the
* original device was and they do not reset on destruction. This is why
* original_device() and current_device() return optional<Device> rather than
* Device (as they do in DeviceGuard), and also is why we didn't just
* provide OptionalDeviceGuard by default and hide DeviceGuard from users.
*
* The semantics of an OptionalDeviceGuard are exactly explained by thinking
* of it as an optional<DeviceGuard>. In particular, an initialized
* OptionalDeviceGuard doesn't restore device to its value at construction; it
* restores device to its value *at initialization*. So if you have the
* program:
*
* setDevice(1);
* OptionalDeviceGuard g;
* setDevice(2);
* g.reset_device(Device(DeviceType::CUDA, 3)); // initializes!
*
* On destruction, g will reset device to 2, rather than 1.
*
* An uninitialized OptionalDeviceGuard is distinct from a (initialized)
* DeviceGuard whose original_device_ and current_device_ match, since the
* DeviceGuard will still reset the device to original_device_.
*/
class OptionalDeviceGuard {
public:
/// Create an uninitialized guard. Set the guard later using reset_device.
explicit OptionalDeviceGuard() : guard_() {}
/// Initialize the guard, setting the current device to the passed Device.
explicit OptionalDeviceGuard(Device device) : guard_(device) {}
/// Initialize the guard if a Device is passed; otherwise leave the
/// guard uninitialized.
explicit OptionalDeviceGuard(optional<Device> device) : guard_(device) {}
/// Constructor for testing only.
explicit OptionalDeviceGuard(Device device, const impl::DeviceGuardImplInterface* impl) : guard_(device, impl) {}
/// Copy is disallowed
OptionalDeviceGuard(const OptionalDeviceGuard&) = delete;
OptionalDeviceGuard& operator=(const OptionalDeviceGuard&) = delete;
/// Move is disallowed
/// See Note [Explicit initialization of optional fields]
/// and // Note [Move construction for RAII guards is tricky]
/// for rationale.
OptionalDeviceGuard(OptionalDeviceGuard&& other) = delete;
OptionalDeviceGuard& operator=(OptionalDeviceGuard&& other) = delete;
/// Sets the device to the given one. The specified device must be consistent
/// with the device type originally specified during guard construction.
void reset_device(at::Device device) {
guard_.reset_device(device);
}
/// For testing only
void reset_device(at::Device device, const impl::DeviceGuardImplInterface* impl) {
guard_.reset_device(device, impl);
}
/// Returns the device that was set at the time the guard was constructed.
optional<Device> original_device() const {
return guard_.original_device();
}
/// Returns the most recent device that was set using this device guard,
/// either from construction, or via reset_device.
optional<Device> current_device() const {
return guard_.current_device();
}
private:
impl::InlineOptionalDeviceGuard<impl::VirtualGuardImpl> guard_;
};
// Note [Whither the DeviceGuard boilerplate]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// Design note: in principle, we could avoid these wrappers using:
//
// using DeviceGuard = impl::InlineDeviceGuard<impl::VirtualGuardImpl>;
// using OptionalDeviceGuard = impl::InlineOptionalDeviceGuard<impl::VirtualGuardImpl>;
//
// But the error messages are worse, and our users can't just look at the
// header file to find out what's going on. Furthermore, for specializations
// like CUDAStreamGuard, it can be profitable to replace some interfaces with
// refined types (e.g., return CUDAStream instead of Stream). So, we eat
// the boilerplate and write out the API explicitly.
} // namespace c10