Why Gemfury? Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

edgify / torch   python

Repository URL to install this package:

Version: 2.0.1+cpu 

/ include / ATen / TensorIterator.h

#pragma once

#include <ATen/TensorMeta.h>
#include <ATen/core/Dimname.h>
#include <ATen/core/Range.h>
#include <ATen/core/TensorBase.h>
#include <c10/core/DynamicCast.h>
#include <c10/util/FunctionRef.h>
#include <c10/util/MaybeOwned.h>
#include <c10/util/SmallVector.h>
#include <c10/util/TypeCast.h>
#include <c10/util/irange.h>

#include <array>
#include <bitset>

// Silence two clang warnings for the remainder of this header; the matching
// C10_CLANG_DIAGNOSTIC_POP() at the end of the file restores the previous
// state. Each warning is only ignored if the compiler knows about it.
//  - -Wshorten-64-to-32: some inline accessors narrow 64-bit sizes to `int`
//    (e.g. ndim() returns shape_.size() as an int).
//  - -Wdeprecated-copy-dtor: presumably because OperandInfo declares a
//    defaulted destructor while still being copied through its implicit copy
//    constructor (e.g. when a TensorIteratorBase is copied).
C10_CLANG_DIAGNOSTIC_PUSH()
#if C10_CLANG_HAS_WARNING("-Wshorten-64-to-32")
C10_CLANG_DIAGNOSTIC_IGNORE("-Wshorten-64-to-32")
#endif
#if C10_CLANG_HAS_WARNING("-Wdeprecated-copy-dtor")
C10_CLANG_DIAGNOSTIC_IGNORE("-Wdeprecated-copy-dtor")
#endif

// Forward declarations so this header can refer to Tensor and
// OptionalTensorRef without including Tensor.h. NameVector is the small-buffer
// container used for the dimension names computed for outputs
// (TensorIteratorBase::names_).
namespace at {
class Tensor;
class OptionalTensorRef;
using NameVector = SmallVector<Dimname, kDimVectorStaticSize>;
} // namespace at

// TensorIterator is a helper class for element-wise operations, such as
// arithmetic, comparisons, and trigonometric functions. It handles
// broadcasting and type conversions of operands.
//
// This is inspired by NumPy's Array Iterator API (NpyIter).
//
// The files Loops.h and Loops.cuh provide functions to build kernels that
// use TensorIterator.
//
// Example:
//
//   auto iter = TensorIteratorConfig()
//     .add_output(output)
//     .add_input(input1)
//     .add_input(input2)
//     .build();
//
// [MyKernel.cpp / MyKernel.cu]
//   cpu_kernel(iter, [](float a, float b) {
//     return a + b;
//   });
//
//   gpu_kernel(iter, []GPU_LAMBDA(float a, float b) -> float {
//     return a + b;
//   });
//
// Note [Order of Construction]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// When setting up the tensor iterator configuration, the output Tensors
// have to be added first via
// TensorIteratorConfig::add_owned_output(at::Tensor). After adding all outputs,
// the inputs can be added via
// TensorIteratorConfig::add_owned_input(at::Tensor).
// Adding another output after inputs have been added will raise an exception.
//
// Note [Common Dtype Computation]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// Some operations have a natural notion of a "common dtype" or
//   "computation dtype" where all inputs are cast to one dtype, the
//   operation is performed, and then the results are cast to all outputs.
//
// TensorIterator infers a common dtype if all inputs have the same dtype,
//   and it computes one using type promotion rules on its inputs if
//   promote_inputs_to_common_dtype_ is true. Attempting to query
//   a common dtype otherwise will throw an exception.
//
// Note that the outputs are not considered when computing a common dtype.

namespace at {

namespace internal {
// This parameter is heuristically chosen to determine the minimum amount of
// work that warrants parallelism. For example, when summing an array, it is
// deemed inefficient to parallelise over arrays shorter than 32768. Further,
// no parallel algorithm (such as parallel_reduce) should split work into
// smaller than GRAIN_SIZE chunks. It is also the default `grain_size`
// argument of TensorIteratorBase::for_each.
constexpr int64_t GRAIN_SIZE = 32768;

// Storage for a non-owning Tensor, without needing to include Tensor.h.
//
// The referenced object lives in a raw byte buffer that is sized and aligned
// like a TensorBase and is reinterpreted as an OptionalTensorRef on every
// access, so only a forward declaration of OptionalTensorRef is required in
// this header. The constructor and destructor are declared here but defined
// out of line.
// NOTE(review): this assumes an OptionalTensorRef fits in sizeof(TensorBase)
// bytes with compatible alignment — presumably enforced in the .cpp; confirm.
class TORCH_API OpaqueOptionalTensorRef {
  // Raw bytes that hold the OptionalTensorRef object.
  alignas(alignof(TensorBase)) std::array<char, sizeof(TensorBase)> data_;

 public:
  OpaqueOptionalTensorRef();
  ~OpaqueOptionalTensorRef();

  // Typed views of data_ as an OptionalTensorRef.
  OptionalTensorRef* get() {
    return reinterpret_cast<OptionalTensorRef*>(data_.data());
  }
  const OptionalTensorRef* get() const {
    return reinterpret_cast<const OptionalTensorRef*>(data_.data());
  }

  // Pointer-like access to the stored OptionalTensorRef.
  OptionalTensorRef& operator*() {
    return *get();
  }
  const OptionalTensorRef& operator*() const {
    return *get();
  }
  OptionalTensorRef* operator->() {
    return get();
  }
  const OptionalTensorRef* operator->() const {
    return get();
  }

  // Returns the stored reference as a `const Tensor&` (defined out of line).
  const Tensor& getTensor() const;
};
} // namespace internal

/// One operand (input or output) of a TensorIterator, together with the
/// per-operand state computed while the iterator is built.
struct TORCH_API OperandInfo {
  using StrideVector = SmallVector<int64_t, 6>;
  OperandInfo() = default;

  /// Wraps `t`, which may be owned or borrowed. A defined tensor seeds
  /// `device`, `target_dtype` and `current_dtype`. In every case the tensor is
  /// then installed through tensor() and checked by validate().
  C10_ALWAYS_INLINE explicit OperandInfo(c10::MaybeOwned<TensorBase>&& t) {
    if (t->defined()) {
      device = t->device();
      // Both dtypes start out as the tensor's own dtype.
      current_dtype = target_dtype = t->scalar_type();
    }
    tensor(std::move(t));
    validate();
  }

  C10_ALWAYS_INLINE ~OperandInfo() = default;

  /// Per-dimension strides after broadcasting, measured in bytes rather than
  /// in elements.
  StrideVector stride_bytes;

  /// Desired device and dtype of the operand. For an input, this is the type
  /// it should be converted to if necessary. For an output, it is the type to
  /// allocate. Both start out as the tensor's own dtype and device. Type
  /// promotion may later change target_dtype so that it no longer matches the
  /// tensor. It may also fill in target_dtype and device for an undefined
  /// tensor, so that the tensor can be constructed correctly later.
  c10::optional<Device> device = c10::nullopt;
  ScalarType target_dtype = ScalarType::Undefined;
  // Cached copy of the tensor's dtype, because scalar_type() is expensive.
  // Whenever the tensor's dtype changes (e.g. through type promotion or in
  // allocate_outputs), this value has to be updated as well.
  ScalarType current_dtype = ScalarType::Undefined;

  /// True once a device has been recorded for this operand.
  bool is_device_defined() const {
    return device.has_value();
  }
  /// True once a target dtype has been recorded for this operand.
  bool is_type_defined() const {
    return target_dtype != ScalarType::Undefined;
  }
  /// TensorOptions carrying target_dtype and device.
  TensorOptions options() const {
    TensorOptions opts(target_dtype);
    return opts.device(device);
  }

  /// Data pointer for this operand. It can differ from tensor->data_ptr()
  /// once the iterator has been split.
  void* data = nullptr;

  bool is_output = false;

  bool will_resize = false;

  bool is_read_write = false;

  /// Rejects defined tensors whose layout is not strided.
  void validate() {
    TORCH_CHECK(
        !tensor_base_->defined() || tensor_base_->layout() == kStrided,
        "unsupported tensor layout: ",
        tensor_base_->layout());
  }

  /// The tensor operand. Its strides, data pointer and other attributes may
  /// not match the iteration because of dimension reordering and coalescing.
  const Tensor& tensor() const {
    return tensor_storage_.getTensor();
  }
  const TensorBase& tensor_base() const {
    return *tensor_base_;
  }
  void tensor(c10::MaybeOwned<TensorBase>&& tensor);

  // The original tensor operand, kept when an output is replaced
  // (e.g. because its dtype changed).
  const Tensor& original_tensor() const {
    return original_tensor_storage_.getTensor();
  }
  const TensorBase& original_tensor_base() const {
    return *original_tensor_base_;
  }

  // Installs `new_tensor` and keeps the previous value as original_tensor.
  // Must be called at most once over the lifetime of an operand.
  void exchange_tensor(c10::MaybeOwned<TensorBase>&& new_tensor);

  // Moves original_tensor back into tensor. Requires a prior call to
  // exchange_tensor.
  void restore_original_tensor();

 private:
  c10::MaybeOwned<TensorBase> tensor_base_;
  c10::MaybeOwned<TensorBase> original_tensor_base_ =
      c10::MaybeOwned<TensorBase>::owned(c10::in_place);

  // The TensorBase values above are visible in the header so they can be
  // accessed inline. Parts of the TensorIterator API need a genuine
  // `const Tensor &`, though, so a non-owning `Tensor` is also kept in each of
  // these `_storage_` members.
  internal::OpaqueOptionalTensorRef tensor_storage_;
  internal::OpaqueOptionalTensorRef original_tensor_storage_;
};

// Forward declarations for types that are defined later in this header.
struct SplitUntil32Bit;
class TensorIteratorConfig;
struct TensorIterator;

// Setup path selected by TensorIteratorBase::compute_fast_setup_type().
// NONE presumably means that no fast path applies. The values are spelled
// out explicitly and match the implicit 0..3 numbering.
enum class FastSetupType : uint8_t {
  NONE = 0,
  CONTIGUOUS = 1,
  CHANNELS_LAST = 2,
  NON_OVERLAPPING_DENSE = 3,
};

/// Shared implementation of TensorIterator. build() sets the iterator up from
/// a TensorIteratorConfig, possibly moving tensors out of it (see
/// populate_operands). The for_each / serial_for_each / parallel_reduce family
/// then drives user-supplied loops over the resulting operands.
struct TORCH_API TensorIteratorBase : public impl::MetaBase {
  using DimMask = std::bitset<64>;
  using PtrVector = SmallVector<char*, 4>;
  using StrideVector = SmallVector<int64_t, 6>;

  TensorIteratorBase();
  void build(TensorIteratorConfig&);

  // The loop function implements element-wise operations over a 2-d strided
  // slab: an inner loop of `size0` elements that is repeated `size1` times.
  //
  // Arguments:
  //  data: data pointers for each operand (length `ntensors`)
  //  strides: byte strides, length `2 * ntensors`. The first `ntensors`
  //    entries are the inner-loop strides of each operand, and the next
  //    `ntensors` are the outer-loop strides.
  //  size0: size of the inner loop
  //  size1: size of the outer loop
  //
  // `size0` often matches shape[0], but may be smaller due to
  // parallelization of the inner loop.
  //
  // A 1-d loop `void(char** data, const int64_t* strides, int64_t size)` is
  // also accepted by for_each/serial_for_each. It receives only the first
  // `ntensors` strides (see loop_2d_from_1d).
  using loop2d_t = c10::function_ref<
      void(char** data, const int64_t* strides, int64_t size0, int64_t size1)>;

  using loop_subiter_t = c10::function_ref<void(TensorIteratorBase& subiter)>;

  void foreach_reduced_elt(loop_subiter_t loop, bool parallelize = true);

  int ndim() const {
    return shape_.size();
  }
  IntArrayRef shape() const {
    return shape_;
  }
  int64_t numel() const;
  int ntensors() const {
    return operands_.size();
  }
  int noutputs() const {
    return num_outputs_;
  }
  int ninputs() const {
    return ntensors() - noutputs();
  }
  IntArrayRef view_offsets() const {
    return view_offsets_;
  }

  /// number of elements in the output operand. this is the same as numel() for
  /// operations that are not reductions.
  int64_t num_output_elements() const;

  /// number of reduced dimensions in a reduction operation
  int num_reduce_dims() const;

  /// 1-dimensional iteration and no buffering or type conversion
  bool is_trivial_1d() const;
  /// Reducible to 1-dimensional and all operands are contiguous
  bool is_contiguous() const;
  bool is_dim_reduced(int dim) const;

  /// Accessors for each operand
  IntArrayRef strides(int arg) const {
    return operands_[arg].stride_bytes;
  }
  void* data_ptr(int arg) const;
  ScalarType dtype(int arg = 0) const {
    return operands_[arg].current_dtype;
  }
  ScalarType common_dtype() const {
    TORCH_INTERNAL_ASSERT(
        common_dtype_ != ScalarType::Undefined,
        "Queried for invalid common dtype!");
    return common_dtype_;
  }
  ScalarType input_dtype(int arg = 0) const {
    return operands_[num_outputs_ + arg].current_dtype;
  }
  Device device(int arg = 0) const {
    return operands_[arg].device.value();
  }
  DeviceType device_type(int arg = 0) const {
    return device(arg).type();
  }
  int64_t element_size(int arg) const {
    return elementSize(dtype(arg));
  }
  bool is_scalar(int arg) const;
  bool is_cpu_scalar(int arg) const;

  const TensorBase& tensor_base(int arg) const {
    return operands_[arg].tensor_base();
  }
  const Tensor& tensor(int arg) const {
    return operands_[arg].tensor();
  }

  // Outputs occupy operands_[0, num_outputs_). Both bounds are checked, so a
  // negative index cannot reach operands_[-1] (mirrors input()/input_base()).
  const TensorBase& output_base(int arg = 0) const {
    AT_ASSERT(arg >= 0 && arg < num_outputs_);
    return tensor_base(arg);
  }

  const Tensor& output(int arg = 0) const {
    AT_ASSERT(arg >= 0 && arg < num_outputs_);
    return tensor(arg);
  }

  // Inputs follow the outputs in operands_, so input `arg` is stored at
  // operands_[num_outputs_ + arg].
  const TensorBase& input_base(int arg = 0) const {
    AT_ASSERT(arg >= 0 && arg < ntensors() - num_outputs_);
    return tensor_base(num_outputs_ + arg);
  }
  const Tensor& input(int arg = 0) const {
    AT_ASSERT(arg >= 0 && arg < ntensors() - num_outputs_);
    return tensor(num_outputs_ + arg);
  }

  // Copies from temporary outputs back to the original outputs
  // NOTE: only used on CPU
  void cast_outputs();

  /// Removes an operand from this iterator
  void remove_operand(int arg);
  /// Shrinks an iterated dimension
  void narrow(int dim, int64_t start, int64_t size);
  /// Narrows every dim after and including `start_dim` to size one.
  void select_all_keeping_dim(int start_dim, IntArrayRef starts);
  /// Replaces the data pointer for the operand at index `arg`.
  /// The new pointer should have the same sizes, strides and dtype as the
  /// original
  void unsafe_replace_operand(int arg, void* data);

  /// Splits this TensorIterator into two iterators. Together they iterate over
  /// the entire operation. Used by `with_32bit_indexing()`.
  std::unique_ptr<TensorIterator> split(int dim);

  /// Returns the dimension with the largest extent: (size[dim]-1) * stride[dim]
  int get_dim_to_split() const;

  /// Reads operand `arg`'s element at its current data pointer, interpreting
  /// it with the tensor's scalar type, and converts it to T.
  template <typename T>
  T scalar_value(int arg) {
    auto& op = operands_[arg];
    return c10::fetch_and_cast<T>(op.tensor_base().scalar_type(), op.data);
  }

 private:
  // Adapts a 1-d loop to the loop2d_t signature. The 1-d loop is invoked
  // `size1` times with the inner strides (the first `ntensor` entries), and
  // each operand's data pointer is advanced by its outer stride
  // (`strides[ntensor + arg]`) between invocations. Only reads ntensors(), so
  // it is const and can back the const serial_for_each overload.
  template <typename loop1d_t>
  auto loop_2d_from_1d(const loop1d_t& loop) const {
    return
        [loop, ntensor = ntensors()](
            char** base, const int64_t* strides, int64_t size0, int64_t size1) {
          PtrVector data(base, base + ntensor);
          const int64_t* outer_strides = &strides[ntensor];
          for (const auto i : c10::irange(size1)) {
            if (i > 0) {
              for (const auto arg : c10::irange(ntensor)) {
                data[arg] += outer_strides[arg];
              }
            }
            loop(data.data(), strides, size0);
          }
        };
  }

 public:
  template <
      typename loop1d_t,
      std::enable_if_t<
          std::is_convertible<
              loop1d_t,
              c10::function_ref<
                  void(char**, const int64_t* strides, int64_t size)>>::value,
          int> = 0>
  void for_each(loop1d_t loop, int64_t grain_size = at::internal::GRAIN_SIZE) {
    for_each(loop_2d_from_1d(loop), grain_size);
  }

  void for_each(loop2d_t loop, int64_t grain_size = at::internal::GRAIN_SIZE);

  void parallel_reduce(loop2d_t loop);

  // Runs a 1-d loop serially over `range`. Like the loop2d_t overload below,
  // this is const, so it also works on a `const TensorIteratorBase&`.
  template <
      typename loop1d_t,
      std::enable_if_t<
          std::is_convertible<
              loop1d_t,
              c10::function_ref<
                  void(char**, const int64_t* strides, int64_t size)>>::value,
          int> = 0>
  void serial_for_each(loop1d_t loop, Range range) const {
    serial_for_each(loop_2d_from_1d(loop), range);
  }

  void serial_for_each(loop2d_t loop, Range range) const;

  /// Create a strides array for a Tensor with shape of this iterator. The
  /// parameter `element_size` specifies the size of Tensor's data type in
  /// bytes (e.g. `4` for `float`)
  StrideVector compatible_stride(int element_size) const;

  /// Inverts the re-ordering done by reorder_dimensions. This can only be
  /// called *before* coalesce_dimensions() is called.
  DimVector invert_perm(IntArrayRef input) const;

  /// Reapply same re-ordering as it is done by reorder_dimensions. This can
  /// only be called *before* coalesce_dimensions() is called.
  DimVector apply_perm_and_mul(IntArrayRef input, int mul) const;

  /// Helper functions for CPU iteration
  StrideVector get_dim_strides(int dim) const;
  StrideVector get_strides() const;
  StrideVector get_inner_strides() const {
    return get_dim_strides(0);
  }
  PtrVector get_base_ptrs() const;

  // Helper functions for advanced stride manipulations (e.g. torch.flip)
  void _unsafe_set_arg_strides(const int arg, IntArrayRef strides) {
    // IntArrayRef is a cheap non-owning view; the elements are copied into
    // stride_bytes, so std::move would have no effect here.
    operands_[arg].stride_bytes = strides;
  }
  void _unsafe_set_arg_data(const int arg, void* data) {
    operands_[arg].data = data;
  }

  /// true if the stride computation can use 32-bit arithmetic. Used by GPU
  /// kernels
  bool can_use_32bit_indexing() const;

  /// An "iteratable" object that recursively splits this iterator into
  /// sub-iterators that can use 32-bit indexing.
  SplitUntil32Bit with_32bit_indexing() const;

  /// If the kernel should accumulate into the output. Only relevant for CUDA
  /// reductions.
  bool should_accumulate() const {
    return accumulate_;
  }

  /// Whether this iterator produces the actual output,
  /// as opposed to something that will be accumulated further. Only relevant
  /// for CUDA reductions.
  bool is_final_output() const {
    return final_output_;
  }

  /// True when, for every operand, the stride of dimension 0 equals its
  /// element size (trivially true for 0-d iterators).
  bool has_contiguous_first_dim() const {
    if (ndim() == 0) {
      return true;
    }

    int num_tensors = ntensors();
    for (const auto i : c10::irange(num_tensors)) {
      if (strides(i)[0] != element_size(i)) {
        return false;
      }
    }
    return true;
  }

  void set_output_raw_strided(
      int64_t output_idx,
      IntArrayRef sizes,
      IntArrayRef strides,
      TensorOptions options,
      DimnameList names) override;

// Declares every rvalue-argument overload of `methodname` as deleted, so that
// the borrowing builders below cannot be called with temporaries.
// TORCH_DISALLOW_TEMPORARIES_IMPL stays defined for TensorIterator below.
#define TORCH_DISALLOW_TEMPORARIES_IMPL(methodname, maybestatic)            \
  maybestatic void methodname(                                              \
      TensorBase&& out, const TensorBase& a, const TensorBase& b) = delete; \
  maybestatic void methodname(                                              \
      const TensorBase& out, TensorBase&& a, const TensorBase& b) = delete; \
  maybestatic void methodname(                                              \
      const TensorBase& out, const TensorBase& a, TensorBase&& b) = delete; \
  maybestatic void methodname(                                              \
      TensorBase&& out, TensorBase&& a, const TensorBase& b) = delete;      \
  maybestatic void methodname(                                              \
      TensorBase&& out, const TensorBase& a, TensorBase&& b) = delete;      \
  maybestatic void methodname(                                              \
      const TensorBase& out, TensorBase&& a, TensorBase&& b) = delete;      \
  maybestatic void methodname(                                              \
      TensorBase&& out, TensorBase&& a, TensorBase&& b) = delete;

#define TORCH_DISALLOW_TEMPORARIES(methodname) \
  TORCH_DISALLOW_TEMPORARIES_IMPL(methodname, )

  void build_binary_float_op(
      const TensorBase& out,
      const TensorBase& a,
      const TensorBase& b);
  void build_borrowing_binary_float_op(
      const TensorBase& out,
      const TensorBase& a,
      const TensorBase& b);
  TORCH_DISALLOW_TEMPORARIES(build_borrowing_binary_float_op)
  void build_binary_op(
      const TensorBase& out,
      const TensorBase& a,
      const TensorBase& b);
  void build_borrowing_binary_op(
      const TensorBase& out,
      const TensorBase& a,
      const TensorBase& b);
  TORCH_DISALLOW_TEMPORARIES(build_borrowing_binary_op)
  void build_unary_float_op(const TensorBase& out, const TensorBase& a);
  void build_borrowing_unary_float_op(
      const TensorBase& out,
      const TensorBase& a);
  TORCH_DISALLOW_TEMPORARIES(build_borrowing_unary_float_op)
  void build_unary_op(const TensorBase& out, const TensorBase& a);
  // Odd special case needed for pow. Has to borrow the output because
  // it's a structured kernel, but the argument is potentially a copy.
  void build_output_borrowing_argument_owning_unary_op(
      const TensorBase& out,
      const TensorBase& a);
  void build_borrowing_unary_op(const TensorBase& out, const TensorBase& a);
  TORCH_DISALLOW_TEMPORARIES(build_borrowing_unary_op)
  void build_borrowing_unary_force_boolean_op(
      const TensorBase& out,
      const TensorBase& a);
  TORCH_DISALLOW_TEMPORARIES(build_borrowing_unary_force_boolean_op)
  void build_comparison_op(
      const TensorBase& out,
      const TensorBase& a,
      const TensorBase& b);
  void build_borrowing_comparison_op(
      const TensorBase& out,
      const TensorBase& a,
      const TensorBase& b);
  TORCH_DISALLOW_TEMPORARIES(build_borrowing_comparison_op)
  // Another special case: we need to own the second argument for comparison
  // ops.
  void build_borrowing_except_last_argument_comparison_op(
      const TensorBase& out,
      const TensorBase& a,
      const TensorBase& b);
  void build_ternary_op(
      const TensorBase& out,
      const TensorBase& a,
      const TensorBase& b,
      const TensorBase& c);

#undef TORCH_DISALLOW_TEMPORARIES
 protected:
  // Mutable reference as it moves tensors out of TensorIteratorConfig
  void populate_operands(TensorIteratorConfig&);
  void mark_outputs();
  void mark_resize_outputs(const TensorIteratorConfig&);
  void compute_mem_overlaps(const TensorIteratorConfig&);
  void compute_shape(const TensorIteratorConfig&);
  void compute_strides(const TensorIteratorConfig&);
  void reorder_dimensions();
  void permute_dimensions(IntArrayRef perm);
  void compute_types(const TensorIteratorConfig&);
  ScalarType compute_common_dtype();
  void allocate_or_resize_outputs();
  bool fast_set_up(const TensorIteratorConfig&);
  FastSetupType compute_fast_setup_type(const TensorIteratorConfig&);
  void compute_names(const TensorIteratorConfig&);
  void propagate_names_to_outputs();
  void coalesce_dimensions();

 protected:
  /// Records the "computation" shape of the output tensor. The computation
  /// shape is different from the regular shape in a few ways:
  ///
  ///   - The shape may be permuted (via permute_dimensions) so that we
  ///     process the dimensions in the most computationally efficient order
  ///     (rather than the logical order given to us by the users.)
  ///   - The shape may have adjacent dimensions collapsed (via
  ///     coalesce_dimensions) so that we minimize the number of
  ///     dimensions we have to explicitly iterate over.  For example,
  ///     a pointwise operation on a contiguous tensor "computationally"
  ///     consists of only a single dimension.
  ///
  /// In other words, the computation shape is the output shape as it
  /// actually matters for implementing the kernel, but not necessarily the
  /// output shape that the user will see in the end.
  ///
  /// The lifecycle of mutations to shape_ in TensorIterator:
  ///   - declare_static_shape() sets an initial shape explicitly
  ///     provided by user, otherwise
  ///   - compute_shape() computes the true (non-computational) shape
  ///     specified by the user.
  ///   - reorder_dimensions() reorders dimensions to improve coalescing.
  ///   - coalesce_dimensions() then coalesces adjacent dimensions when
  ///     possible.
  ///
  /// The shape may also be further modified if we create sub-TensorIterators,
  /// e.g., via narrow or select_all_keeping_dim.
  DimVector shape_;

  /// Temporarily records the permutation computed by reorder_dimensions.
  /// This permutation maps the computation output dimension (dim) to
  /// the original true output dimension (perm_[dim]).  It is used by
  /// invert_perm to undo the permutation.  After coalesce_dimensions is
  /// called, the permutation is no longer valid (as, in general, there
  /// is no permutation that will make computation dimensions to
  /// output dimensions); methods that manipulate perm_ are obligated
  /// to test that !has_coalesced_dimensions
  DimVector perm_;

  /// Has coalesce_dimensions() (or any moral equivalent, e.g., fast_build())
  /// been called?  This is SOLELY used to check validity of perm_.
  bool has_coalesced_dimensions_ = false;

  /// Whether iteration must be fixed. This disables dimension permuting and
  /// also changes how for_each divides work among threads.
  bool enforce_linear_iteration_ = false;

  /// The index offsets into the original tensors for each dimension.
  /// This is only non-zero when you narrow() a TensorIterator (e.g.,
  /// when you make sub-TensorIterators).
  DimVector view_offsets_;

  /// The computed names of the output tensor.  Computed by compute_names()
  NameVector names_;

  /// The operands of the TensorIterator: both the inputs and outputs.  The
  /// outputs MUST come first in the operands_ list.  There is always an
  /// operand for each output of the TensorIterator, even if TensorIterator
  /// will ultimately be responsible for allocating the output; in those
  /// cases, tensor is simply undefined (and will be populated later
  /// during build()).
  ///
  /// This list is initially populated prior to build(), but build() mutates
  /// OperandInfo to populate more information.
  SmallVector<OperandInfo, 4> operands_;

  /// Number of outputs in operands_ (the length of the outputs prefix
  /// in operands_).
  int num_outputs_ = 0;

  /// Whether or not all operands have the same shape and are 1d+. Having all
  /// the same shape affects whether or not the iterator is eligible for fast
  /// setup.
  bool all_ops_same_shape_ = false;
  /// Whether or not all operands are 0d, this affects type promotion
  bool all_ops_are_scalars_ = false;

  /// The "computation" dtype of TensorIterator, specifying what the dtype
  /// we will do the internal computation in TensorIterator.  Typically,
  /// this matches the dtype of the output tensors, but not always!
  ScalarType common_dtype_ = ScalarType::Undefined;

  /// This is currently defined as kCPU, or the device of the first non-CPU
  /// tensor argument. See TensorIteratorBase::compute_types for details.
  Device common_device_ = kCPU;

  /// Set by split(), see should_accumulate() and is_final_output()
  bool accumulate_ = false;
  bool final_output_ = true;

  // From TensorIteratorConfig
  bool is_reduction_ = false;

  /// Set by populate_operands(), says if we're handling meta tensors
  bool is_meta_ = false;
};

// Concrete, final iterator type. It is returned by TensorIteratorConfig::build()
// and by the static factory helpers declared below. It adds no data members of
// its own, which is why constructing it from a TensorIteratorBase (slicing) is
// safe.
struct TORCH_API TensorIterator final : public TensorIteratorBase {
  TensorIterator() : TensorIteratorBase() {}
  // Slicing is OK, TensorIterator guaranteed NOT to have any fields
  TensorIterator(const TensorIteratorBase& iter) : TensorIteratorBase(iter) {}

// Re-defined with `static` so that the deleted rvalue overloads generated by
// TORCH_DISALLOW_TEMPORARIES_IMPL match the static factories below.
#define TORCH_DISALLOW_TEMPORARIES(methodname) \
  TORCH_DISALLOW_TEMPORARIES_IMPL(methodname, static)

  // Factory helpers that configure and build an iterator for common op
  // shapes (implemented out of line).
  static TensorIterator binary_float_op(
      TensorBase& out,
      const TensorBase& a,
      const TensorBase& b);
  static TensorIterator binary_op(
      TensorBase& out,
      const TensorBase& a,
      const TensorBase& b);
  static TensorIterator borrowing_binary_op(
      const TensorBase& out,
      const TensorBase& a,
      const TensorBase& b);
  TORCH_DISALLOW_TEMPORARIES(borrowing_binary_op)
  static TensorIterator comparison_op(
      TensorBase& out,
      const TensorBase& a,
      const TensorBase& b);
  static TensorIterator unary_op(TensorBase& out, const TensorBase& a);
  static TensorIterator unary_float_op(TensorBase& out, const TensorBase& a);
  static TensorIterator nullary_op(TensorBase& out);
  static TensorIterator borrowing_nullary_op(const TensorBase& out);
  // Borrowing from a temporary output would leave a dangling reference.
  static TensorIterator borrowing_nullary_op(TensorBase&& out) = delete;
  static TensorIterator reduce_op(TensorBase& out, const TensorBase& a);
  static TensorIterator reduce_op(
      TensorBase& out1,
      TensorBase& out2,
      const TensorBase& a);
// Both helper macros are private to this header.
#undef TORCH_DISALLOW_TEMPORARIES
#undef TORCH_DISALLOW_TEMPORARIES_IMPL

  // Overrides of impl::MetaBase hooks (implemented out of line).
  const Tensor& maybe_get_output(int64_t output_idx) override;
  void set_output_raw_strided(
      int64_t output_idx,
      IntArrayRef sizes,
      IntArrayRef strides,
      TensorOptions options,
      DimnameList names) override;
};

/// Builder for TensorIterator. Collect operands (outputs first, then inputs)
/// and option flags with the chained setters, then call build(). Copying is
/// disabled.
class TORCH_API TensorIteratorConfig final {
 public:
  friend struct TensorIteratorBase;
  friend struct TensorIterator;

  TensorIteratorConfig() = default;

  C10_DISABLE_COPY_AND_ASSIGN(TensorIteratorConfig);

  /// Construction
  // Stores input/output Tensors without incrementing the reference count.
  // Important: the outputs have to be added before the inputs.
  TensorIteratorConfig& add_output(const TensorBase& output) {
    return add_borrowed_output(output);
  }
  TensorIteratorConfig& add_input(const TensorBase& input) {
    return add_borrowed_input(input);
  }

  // Borrowing from temporaries is unlikely to go well.
  TensorIteratorConfig& add_output(TensorBase&& output) = delete;
  TensorIteratorConfig& add_input(TensorBase&& input) = delete;

  // Stores input/output Tensors while incrementing the reference count.
  // Note that add_{in,out}put are nearly always what you
  // want, and the exception (adding an unnamed temporary) won't
  // compile.
  TensorIteratorConfig& add_owned_output(const TensorBase& output);
  TensorIteratorConfig& add_owned_input(const TensorBase& input);

  // Advanced API: stores input/output Tensors without incrementing
  // the reference count. The caller must ensure that these Tensors
  // live at least as long as this TensorIteratorConfig and any
  // TensorIteratorBase built from this TensorIteratorConfig.
  // Important: the outputs have to be added before the inputs.
  TensorIteratorConfig& add_borrowed_output(const TensorBase& output);
  TensorIteratorConfig& add_borrowed_input(const TensorBase& input);

  // Borrowing from temporaries is unlikely to go well.
  TensorIteratorConfig& add_borrowed_output(TensorBase&& output) = delete;
  TensorIteratorConfig& add_borrowed_input(TensorBase&& input) = delete;

  // Sets the check_mem_overlap_ flag, which is true by default.
  // If true, inputs are checked for partial overlap with the outputs and
  // outputs are checked for internal overlap (e.g. broadcasted views). An error
  // is raised if unacceptable overlap is detected.
  // If you're migrating an existing operator to using TensorIterator, please
  // consider if the previous implementation checked memory overlap. If it did
  // not, and if the operator is idempotent (for example, Tensor.fill_(0)), then
  // checking memory overlap is BC-breaking. Please don't check memory overlap
  // in that case.
  TensorIteratorConfig& set_check_mem_overlap(bool check_mem_overlap) {
    check_mem_overlap_ = check_mem_overlap;
    return *this;
  }

  // Sets the check_all_same_dtype_ flag, which is true by default.
  // If true, checks that all inputs and defined outputs have the same dtype.
  // Setting either promote_inputs_to_common_dtype_ or
  //   cast_common_dtype_to_outputs_ to true also sets
  //   check_all_same_dtype_ to false.
  TensorIteratorConfig& check_all_same_dtype(const bool _check_all_same_dtype) {
    check_all_same_dtype_ = _check_all_same_dtype;
    return *this;
  }

  // Sets the check_all_same_device_ flag, which is true by default.
  // If true, all operands must be on the same device, with the possible
  //   exception of CPU scalars, which can be passed to some CUDA kernels
  //   as kernel arguments.
  TensorIteratorConfig& check_all_same_device(
      const bool _check_all_same_device) {
    check_all_same_device_ = _check_all_same_device;
    return *this;
  }

  // Sets the enforce_safe_casting_to_output_ flag, which is false by default.
  // If true, the iterator's "common dtype" must be computable
  //   (see the [Common Dtype Computation] note) and
  //   canCast(common dtype, output dtype) must be true for all outputs.
  TensorIteratorConfig& enforce_safe_casting_to_output(
      const bool _enforce_safe_casting_to_output) {
    enforce_safe_casting_to_output_ = _enforce_safe_casting_to_output;
    return *this;
  }

  // Sets the enforce_linear_iteration_ flag, which is false by default.
  // If true, iteration goes in the same order as a C-contiguous tensor
  // is laid out in memory, i.e. the last dimension iterates fastest.
  //
  // This iteration order can be less efficient and may even prevent
  // vectorization. So only use if the correctness of your kernel depends on it.
  TensorIteratorConfig& enforce_linear_iteration(
      const bool _enforce_linear_iteration = true) {
    enforce_linear_iteration_ = _enforce_linear_iteration;
    return *this;
  }

  // Sets the promote_inputs_to_common_dtype_ flag, which is false by default.
  // If true, the iterator's "common dtype" is always computed (see the
  //   [Common Dtype Computation] note) and, on the CPU, temporary copies of
  //   the inputs in the common dtype are passed as the actual inputs to
  //   the operation.
  // Setting this flag to true sets check_all_same_dtype_ to false.
  TensorIteratorConfig& promote_inputs_to_common_dtype(
      const bool _promote_inputs_to_common_dtype) {
    promote_inputs_to_common_dtype_ = _promote_inputs_to_common_dtype;
    if (_promote_inputs_to_common_dtype) {
      check_all_same_dtype_ = false;
    }
    return *this;
  }

  // Sets the promote_integer_inputs_to_float_ flag, which is false by default.
  // If true, and the iterator's "common dtype" is an integral type
  //   (including bool), it is changed to the default float scalar type.
  // NOTE: setting this flag to true requires promote_inputs_to_common_dtype
  //   to have been set to true first. Otherwise TORCH_INTERNAL_ASSERT fails,
  //   and the configuration is left unchanged.
  TensorIteratorConfig& promote_integer_inputs_to_float(
      const bool _promote_integer_inputs_to_float) {
    // Validate before mutating, so that a failed check leaves *this untouched.
    TORCH_INTERNAL_ASSERT(
        !_promote_integer_inputs_to_float || promote_inputs_to_common_dtype_,
        "promote_integer_inputs_to_float(true) requires "
        "promote_inputs_to_common_dtype(true) to be set first");
    promote_integer_inputs_to_float_ = _promote_integer_inputs_to_float;
    return *this;
  }

  // Sets the is_reduction_ flag, which is false by default.
  // NOTE(review): presumably forwarded to TensorIteratorBase::is_reduction_
  // during build(); confirm.
  TensorIteratorConfig& is_reduction(const bool _is_reduction) {
    is_reduction_ = _is_reduction;
    return *this;
  }

  // Sets the allow_cpu_scalars_ flag, which is false by default.
  // NOTE(review): presumably permits CPU scalar operands alongside operands on
  // another device (cf. check_all_same_device); confirm in compute_types.
  TensorIteratorConfig& allow_cpu_scalars(const bool _allow_cpu_scalars) {
    allow_cpu_scalars_ = _allow_cpu_scalars;
    return *this;
  }

  // Sets the cast_common_dtype_to_outputs_ flag, which is false by default.
  // If true, the iterator's "common dtype" must be computable
  //   (see the [Common Dtype Computation] note) and, on the CPU, temporary
  //   copies of the outputs are passed as the actual output to the operation.
  //   These temporaries are then copied to the original outputs after
  //   the operation is performed (see cast_outputs()).
  // Setting this flag to true sets check_all_same_dtype_ to false.
  TensorIteratorConfig& cast_common_dtype_to_outputs(
      const bool _cast_common_dtype_to_outputs) {
    cast_common_dtype_to_outputs_ = _cast_common_dtype_to_outputs;
    if (_cast_common_dtype_to_outputs) {
      check_all_same_dtype_ = false;
    }
    return *this;
  }

  // Sets the resize_outputs_ flag, which is true by default.
  TensorIteratorConfig& resize_outputs(bool resize_outputs) {
    resize_outputs_ = resize_outputs;
    return *this;
  }

  // Bypass output dtype/device computation and fix the dtype/device as
  // specified here.
  TensorIteratorConfig& declare_static_dtype_and_device(
      ScalarType dtype,
      Device device);
  TensorIteratorConfig& declare_static_dtype(ScalarType dtype);
  TensorIteratorConfig& declare_static_device(Device device);
  TensorIteratorConfig& declare_static_shape(IntArrayRef shape);
  TensorIteratorConfig& declare_static_shape(
      IntArrayRef shape,
      IntArrayRef squash_dims);

  // Builds a TensorIterator from this configuration. TensorIteratorBase::build
  // takes the config by mutable reference because it moves tensors out of it.
  // It would be better if this was && qualified, but this would be at the cost
  // of a lot of boilerplate above
  TensorIterator build() {
    TensorIterator iter;
    iter.build(*this);
    return iter;
  }

 private:
  SmallVector<c10::MaybeOwned<TensorBase>, 4> tensors_;
  int num_outputs_ = 0;
  int num_inputs_ = 0;

  c10::optional<DimVector> static_shape_ = c10::nullopt;
  c10::optional<ScalarType> static_dtype_ = c10::nullopt;
  c10::optional<Device> static_device_ = c10::nullopt;
  bool check_mem_overlap_ = true;
  bool allow_cpu_scalars_ = false;
  bool is_reduction_ = false;
  bool resize_outputs_ = true;
  bool check_all_same_dtype_ = true;
  bool check_all_same_device_ = true;
  bool enforce_safe_casting_to_output_ = false;
  bool enforce_linear_iteration_ = false;
  bool promote_inputs_to_common_dtype_ = false;
  bool promote_integer_inputs_to_float_ = false;
  bool cast_common_dtype_to_outputs_ = false;
};

/// Range-like view over the splits of a TensorIterator that can each use
/// 32-bit indexing. Iterating it yields TensorIterators that, taken together,
/// cover the original iterator.
///
/// Only a reference to the source iterator is stored, so that
/// TensorIteratorBase must outlive this object and anything obtained from it.
struct TORCH_API SplitUntil32Bit {
  struct TORCH_API iterator {
    iterator() = default;
    iterator(const TensorIteratorBase& iter);
    iterator(iterator&&) = default;

    // Guaranteed to be a TensorIterator proper!
    TensorIterator& operator*() const;
    iterator& operator++();

    // Two iterators compare equal when they are the same object, or when both
    // have an empty work stack. Presumably end() yields an iterator with an
    // empty stack, which is how a range-for detects termination.
    bool operator==(const iterator& other) const {
      if (this == &other) {
        return true;
      }
      return vec.empty() && other.vec.empty();
    }
    // Required by the C++11 range-based for loop.
    bool operator!=(const iterator& other) const {
      return !operator==(other);
    }

    /// Pending TensorIterators that still have to be split or yielded.
    std::vector<std::unique_ptr<TensorIterator>> vec;
  };

  SplitUntil32Bit(const TensorIteratorBase& iter) : iter(iter) {}

  iterator begin() const;
  iterator end() const;

 private:
  const TensorIteratorBase& iter;
};

} // namespace at

// Restore the clang diagnostic state saved by C10_CLANG_DIAGNOSTIC_PUSH() near
// the top of this header.
C10_CLANG_DIAGNOSTIC_POP()