// Listing provenance: package "neilisaac / torch" (python), version 1.8.0,
// as published on Gemfury.
// File: include/c10/core/impl/InlineStreamGuard.h

#pragma once

#include <c10/core/impl/InlineDeviceGuard.h>

namespace c10 {
namespace impl {

/**
 * RAII helper that, for as long as it is alive, makes the device of a given
 * stream the current device and makes that stream the current stream on it.
 * On destruction the stream that was displaced is put back.
 *
 * InlineStreamGuard is the building block from which concrete StreamGuard
 * classes are implemented.  See InlineDeviceGuard for guidance on how this
 * family of helper classes is meant to be used.
 */
template <typename T>
class InlineStreamGuard : private InlineDeviceGuard<T> {
public:
  /// Deliberately not default-constructible; see
  /// Note [Omitted default constructor from RAII]
  explicit InlineStreamGuard() = delete;

  /// Makes `stream.device()` the current device and installs `stream` as the
  /// current stream on it.  The stream of the original device is queried
  /// first, then `stream` is exchanged in and the stream it replaces is
  /// remembered so the destructor can put it back.
  explicit InlineStreamGuard(Stream stream)
      : InlineDeviceGuard<T>(stream.device()),
        original_stream_of_original_device_(
            this->impl_.getStream(original_device())),
        original_stream_of_current_device_(this->impl_.exchangeStream(stream)),
        current_stream_(stream) {}

  /// Testing-only constructor, available solely when T is VirtualGuardImpl.
  /// Uses `impl` when it is non-null; otherwise falls back to
  /// getDeviceGuardImpl(stream.device_type()).
  template <
      typename U = T,
      typename = typename std::enable_if<
          std::is_same<U, VirtualGuardImpl>::value>::type>
  explicit InlineStreamGuard(
      Stream stream,
      const DeviceGuardImplInterface* impl)
      : InlineDeviceGuard<T>(
            stream.device(),
            impl ? impl : getDeviceGuardImpl(stream.device_type())),
        original_stream_of_original_device_(
            this->impl_.getStream(original_device())),
        original_stream_of_current_device_(this->impl_.exchangeStream(stream)),
        current_stream_(stream) {}

  /// Copying a guard is not permitted
  InlineStreamGuard(const InlineStreamGuard<T>&) = delete;
  InlineStreamGuard<T>& operator=(const InlineStreamGuard<T>&) = delete;

  /// Moving is not permitted either: a StreamGuard has no uninitialized
  /// state, and such a state is required to move a type whose destructor
  /// is nontrivial.
  InlineStreamGuard(InlineStreamGuard<T>&& other) = delete;
  InlineStreamGuard& operator=(InlineStreamGuard<T>&& other) = delete;

  /// Puts back the stream that this guard displaced on the current device.
  ~InlineStreamGuard() {
    this->impl_.exchangeStream(original_stream_of_current_device_);
  }

  /// Re-targets the guard at `stream`.  Conceptually this undoes the guard
  /// (original stream and original device are restored) and then sets it up
  /// again: the current device becomes the device of `stream`, and `stream`
  /// becomes the current stream on that device.
  ///
  /// NOTE: stream/device switches that are provably redundant may be
  /// skipped by this implementation.
  ///
  /// WARNING: reset_stream does NOT preserve previously set streams on
  /// different devices.  If you need to set streams on multiple devices
  /// on CUDA, use CUDAMultiStreamGuard instead.
  void reset_stream(Stream stream) {
    // TODO: make a version that takes an impl argument.  Unfortunately,
    // that will require SFINAE because impl is only valid for the
    // VirtualGuardImpl specialization.
    const Device target_device = stream.device();
    if (target_device != this->current_device()) {
      // Different device: tear the guard down and set it up again in place.
      this->impl_.exchangeStream(original_stream_of_current_device_);
      this->reset_device(target_device);
      original_stream_of_current_device_ = this->impl_.exchangeStream(stream);
    } else {
      // Same device: only the stream has to be swapped; the stream to
      // restore on destruction stays as it was.
      this->impl_.exchangeStream(stream);
    }
    current_stream_ = stream;
  }

  // set_device is intentionally absent: it is unclear whether it should
  // also reset the current stream when the device does not change.
  // reset_device would be less ambiguous, but it is still an odd operation
  // for a stream guard, so it is not offered either.

  /// The stream that was current on the *original* device before this guard
  /// took effect -- in other words, the stream your computation would have
  /// run on had this guard not intervened.  Note that this is not the
  /// stream that was displaced on the guard's device.  This is usually
  /// what you want.
  Stream original_stream() const {
    return original_stream_of_original_device_;
  }

  /// The stream most recently installed through this guard, whether by the
  /// constructor or by reset_stream.
  Stream current_stream() const {
    return current_stream_;
  }

  /// The device most recently set through this guard; forwarded from the
  /// underlying InlineDeviceGuard.
  Device current_device() const {
    return InlineDeviceGuard<T>::current_device();
  }

  /// The device that was in effect at the most recent reset_stream(), or
  /// else at construction time; forwarded from the underlying
  /// InlineDeviceGuard.
  Device original_device() const {
    return InlineDeviceGuard<T>::original_device();
  }

private:
  // NB: declaration order matters -- the constructors initialize these
  // members in exactly this sequence.
  Stream original_stream_of_original_device_; // reported by original_stream()
  Stream original_stream_of_current_device_; // restored by the destructor
  Stream current_stream_;
};

/**
 * An OptionalStreamGuard is a stream guard that may be empty.  When
 * uninitialized it does nothing; once initialized it wraps an
 * InlineStreamGuard<T>, which sets the current device and stream and
 * restores the displaced stream when it is destroyed.
 * See InlineOptionalDeviceGuard for more guidance on how to use this class.
 */
template <typename T>
class InlineOptionalStreamGuard {
public:
  /// Constructs a guard in the uninitialized state.
  explicit InlineOptionalStreamGuard()
      : guard_() // See Note [Explicit initialization of optional fields]
  {}

  /// If `stream_opt` holds a stream, makes that stream's device the current
  /// device and the stream the current stream on it.  If it is nullopt, the
  /// guard is left uninitialized.
  explicit InlineOptionalStreamGuard(optional<Stream> stream_opt) : guard_() {
    if (stream_opt) {
      guard_.emplace(*stream_opt);
    }
  }

  /// Forwards to any constructor of the wrapped StreamGuard; all of them
  /// are valid here as well.
  template <typename... Args>
  explicit InlineOptionalStreamGuard(Args&&... args)
      : guard_(in_place, std::forward<Args>(args)...) {}

  // See Note [Move construction for RAII guards is tricky]
  InlineOptionalStreamGuard(InlineOptionalStreamGuard<T>&& other) = delete;

  // See Note [Move assignment for RAII guards is tricky]
  InlineOptionalStreamGuard& operator=(InlineOptionalStreamGuard&& other) =
      delete;

  /// Points the guard at `stream`.  An already-initialized guard is
  /// re-targeted via its own reset_stream (original stream and device are
  /// restored, then the device and stream of `stream` are made current);
  /// an uninitialized guard is initialized with `stream`.
  void reset_stream(Stream stream) {
    if (!guard_.has_value()) {
      guard_.emplace(stream);
      return;
    }
    guard_->reset_stream(stream);
  }

  /// The original stream recorded when the guard was most recently
  /// initialized, or nullopt while the guard is uninitialized.
  optional<Stream> original_stream() const {
    if (!guard_.has_value()) {
      return nullopt;
    }
    return make_optional(guard_->original_stream());
  }

  /// The stream most recently set through this guard (by construction or by
  /// reset_stream), or nullopt while the guard is uninitialized.
  optional<Stream> current_stream() const {
    if (!guard_.has_value()) {
      return nullopt;
    }
    return make_optional(guard_->current_stream());
  }

  /// Destroys the wrapped guard, if any, which restores the original device
  /// and stream, and leaves this object uninitialized.
  void reset() {
    guard_.reset();
  }

private:
  optional<InlineStreamGuard<T>> guard_;
};

}} // namespace c10::impl