#pragma once
#include <c10/util/FunctionRef.h>
#include <c10/util/SmallVector.h>
#include <c10/util/TypeCast.h>
#include <ATen/core/Range.h>
#include <bitset>
#include <ATen/NamedTensorUtils.h>
#include <ATen/TensorMeta.h>
// TensorIterator is a helper class for element-wise operations, such as
// arithmetic, comparisons, and trigonometric functions. It handles
// broadcasting and type conversions of operands.
//
// This is inspired by NumPy's Array Iterator API (NpyIter).
//
// The files Loops.h and Loops.cuh provide functions to build kernels that
// use TensorIterator.
//
// Example:
//
//   auto iter = TensorIteratorConfig()
//     .add_output(output)
//     .add_input(input)
//     .add_input(other)
//     .build();
//
// [MyKernel.cpp / MyKernel.cu]
//   cpu_kernel(iter, [](float a, float b) {
//     return a + b;
//   });
//
//   gpu_kernel(iter, []GPU_LAMBDA(float a, float b) -> float {
//     return a + b;
//   });
//
// Note [Common Dtype Computation]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// Some operations have a natural notion of a "common dtype" or
// "computation dtype" where all inputs are cast to one dtype, the
// operation is performed, and then the results are cast to all outputs.
//
// TensorIterator infers a common dtype if all inputs have the same dtype,
// and it computes one using type promotion rules on its inputs if
// promote_inputs_to_common_dtype_ is true. Attempting to query
// a common dtype otherwise will throw an exception.
//
// Note that the outputs are not considered when computing a common dtype.
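//
// For example, a config like the following (an illustrative sketch; `out`,
// `float_tensor`, and `int_tensor` are hypothetical operands) opts in to type
// promotion, so common_dtype() reports the promoted type (here kFloat):
//
//   auto iter = TensorIteratorConfig()
//     .add_output(out)
//     .add_input(float_tensor)   // kFloat
//     .add_input(int_tensor)     // kInt
//     .promote_inputs_to_common_dtype(true)
//     .build();
//   // iter.common_dtype() == kFloat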
namespace at {
namespace internal {
// This parameter is heuristically chosen to determine the minimum amount of
// work that warrants parallelism. For example, when summing an array, it is
// deemed inefficient to parallelise over arrays shorter than 32768. Further,
// no parallel algorithm (such as parallel_reduce) should split work into
// smaller than GRAIN_SIZE chunks.
constexpr int64_t GRAIN_SIZE = 32768;
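// For example, a kernel built directly on ATen's parallel primitives would
// normally pass GRAIN_SIZE through unchanged (an illustrative sketch; `data`
// and `n` are hypothetical):
//
//   at::parallel_for(0, n, at::internal::GRAIN_SIZE, [&](int64_t begin, int64_t end) {
//     for (int64_t i = begin; i < end; i++) {
//       data[i] = std::sqrt(data[i]);
//     }
//   });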
} // namespace internal
struct DimCounter {
DimCounter(IntArrayRef shape, Range range);
void increment(const std::array<int64_t, 2>& step);
bool is_done() const;
std::array<int64_t, 2> max_2d_step() const;
IntArrayRef shape;
Range range;
DimVector values;
int64_t offset;
};
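// DimCounter is used internally (e.g. by serial_for_each) to walk an
// iteration range in at most two dimensions at a time. A typical loop looks
// roughly like this (an illustrative sketch, not the exact implementation):
//
//   DimCounter counter(shape, range);
//   while (!counter.is_done()) {
//     auto step = counter.max_2d_step();
//     // ... process a step[0] x step[1] tile at the offsets in counter.values ...
//     counter.increment(step);
//   }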
struct TORCH_API OperandInfo {
using StrideVector = SmallVector<int64_t, 6>;
OperandInfo() {}
explicit OperandInfo(Tensor t) : tensor(std::move(t)) {
  if (tensor.defined()) {
    device = tensor.device();
    target_dtype = tensor.scalar_type();
    current_dtype = target_dtype;
  }
  validate();
}
/// Stride after broadcasting. The stride is in bytes, not number of elements.
StrideVector stride_bytes;
/// The tensor operand. Note that the strides, data pointer, and
/// other attributes may differ due to dimension reordering and
/// coalescing.
Tensor tensor;
// Stores the original tensor operand when an output is modified
// (e.g. when its dtype is changed)
Tensor original_tensor;
/// The desired device and type for the operand. For inputs, this specifies that
/// the input should be converted to this type if necessary. For outputs, this
/// specifies which type to allocate. target_dtype and device are initialized with
/// the dtype and device of the tensor, but during type promotion target_dtype may
/// become different from the tensor's dtype; also, during type promotion
/// target_dtype and device can be set for an undefined tensor so that the tensor
/// can be properly constructed later.
Device device = kCPU;
ScalarType target_dtype = ScalarType::Undefined;
// Caches the dtype of the tensor, because scalar_type() is an expensive operation.
// If the dtype of the tensor is changed (e.g. as a result of type promotion or in
// allocate_or_resize_outputs), this value should be updated too.
ScalarType current_dtype = ScalarType::Undefined;
bool is_type_defined() const { return target_dtype != ScalarType::Undefined; }
TensorOptions options() const {
return TensorOptions(target_dtype).device(device);
}
/// The data pointer. This may be different from tensor.data_ptr() if the
/// iterator is split.
void* data = nullptr;
bool is_output = false;
bool will_resize = false;
bool is_read_write = false;
void validate() {
TORCH_CHECK(
!tensor.defined() || tensor.layout() == kStrided,
"unsupported tensor layout: ", tensor.layout());
}
};
struct SplitUntil32Bit;
enum class FastSetupType : uint8_t {
NONE,
CONTIGUOUS,
CHANNELS_LAST,
NON_OVERLAPPING_DENSE
};
class TensorIteratorConfig;
struct TensorIterator;
struct TORCH_API TensorIteratorBase : public impl::MetaBase {
using DimMask = std::bitset<64>;
using PtrVector = SmallVector<char*, 4>;
using StrideVector = SmallVector<int64_t, 6>;
TensorIteratorBase();
void build(TensorIteratorConfig&);
// The inner-loop function operates on the fastest moving dimension. It
// implements element-wise operations in terms of 1-d strided tensors.
//
// Arguments:
// data: data pointers for each operand (length `ntensors`)
// strides: stride for each operand (length `ntensors`)
// size: size of inner loop
//
// The `size` often matches shape[0], but may be smaller due to
// parallelization of the inner loop.
using loop_t = c10::function_ref<void(char** data, const int64_t* strides, int64_t size)>;
using loop2d_t = c10::function_ref<void(char** data, const int64_t* strides, int64_t size0, int64_t size1)>;
using loop_subiter_t = c10::function_ref<void(TensorIteratorBase& subiter)>;
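// A hand-written loop_t for a unary float kernel might look like the
// following (an illustrative sketch; real kernels are normally generated
// via the helpers in Loops.h / Loops.cuh rather than written by hand):
//
//   auto loop = [](char** data, const int64_t* strides, int64_t size) {
//     char* out = data[0];       // operand 0 is the output
//     char* in = data[1];        // operand 1 is the input
//     for (int64_t i = 0; i < size; i++) {
//       *reinterpret_cast<float*>(out + i * strides[0]) =
//           *reinterpret_cast<float*>(in + i * strides[1]) + 1.0f;
//     }
//   };
//   iter.for_each(loop);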
void foreach_reduced_elt(loop_subiter_t loop, bool parallelize=true);
int ndim() const { return shape_.size(); }
IntArrayRef shape() const { return shape_; }
int64_t numel() const;
int ntensors() const { return operands_.size(); }
int noutputs() const { return num_outputs_; }
int ninputs() const { return ntensors() - noutputs(); }
IntArrayRef view_offsets() const { return view_offsets_; }
/// Number of elements in the output operand. This is the same as numel() for
/// operations that are not reductions.
int64_t num_output_elements() const;
/// Number of reduced dimensions in a reduction operation
int num_reduce_dims() const;
/// 1-dimensional iteration and no buffering or type conversion
bool is_trivial_1d() const;
/// Reducible to 1-dimensional and all operands are contiguous
bool is_contiguous() const;
bool is_dim_reduced(int dim) const;
/// Accessors for each operand
IntArrayRef strides(int arg) const { return operands_[arg].stride_bytes; }
void* data_ptr(int arg) const;
ScalarType dtype(int arg=0) const { return operands_[arg].current_dtype; }
ScalarType common_dtype() const {
TORCH_INTERNAL_ASSERT(common_dtype_ != ScalarType::Undefined, "Queried for invalid common dtype!");
return common_dtype_;
}
ScalarType input_dtype(int arg=0) const { return operands_[num_outputs_ + arg].current_dtype; }
Device device(int arg=0) const { return operands_[arg].device; }
DeviceType device_type(int arg=0) const { return device(arg).type(); }
int64_t element_size(int arg) const { return elementSize(dtype(arg)); }
bool is_scalar(int arg) const;
bool is_cpu_scalar(int arg) const;
const Tensor& tensor(int arg) const { return operands_[arg].tensor; }
Tensor& tensor(int arg) { return operands_[arg].tensor; }
Tensor output(int arg=0) const {
AT_ASSERT(arg < num_outputs_);
return operands_[arg].tensor;
}
// Copies from temporary outputs back to the original outputs
// NOTE: only used on CPU
void cast_outputs();
Tensor input(int arg=0) const {
AT_ASSERT(arg >= 0 && arg < ntensors() - num_outputs_);
return operands_[num_outputs_ + arg].tensor;
}
/// Removes an operand from this iterator
void remove_operand(int arg);
/// Shrinks an iterated dimension
void narrow(int dim, int64_t start, int64_t size);
/// Narrows every dim after and including `start_dim` to size one.
void select_all_keeping_dim(int start_dim, IntArrayRef starts);
/// Replaces the data pointer for the operand at index `arg`.
/// The new pointer should have the same sizes, strides and dtype as the
/// original
void unsafe_replace_operand(int arg, void* data);
/// Splits this TensorIterator into two iterators. Together they iterate over
/// the entire operation. Used by `with_32bit_indexing()`.
std::unique_ptr<TensorIterator> split(int dim);
/// Returns the dimension with the largest extent: (size[dim]-1) * stride[dim]
int get_dim_to_split() const;
template <typename T>
T scalar_value(int arg) {
auto& op = operands_[arg];
return c10::fetch_and_cast<T>(op.tensor.scalar_type(), op.data);
}
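// A common pattern (an illustrative sketch) is to lift a CPU scalar operand
// out of the iterator before launching a device kernel, so the value can be
// captured by the lambda instead of being dereferenced on the device:
//
//   if (iter.is_cpu_scalar(2)) {
//     auto alpha = iter.scalar_value<float>(2);
//     iter.remove_operand(2);
//     // ... launch a kernel that captures `alpha` by value ...
//   }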
void for_each(loop_t loop, int64_t grain_size = at::internal::GRAIN_SIZE);
void for_each(loop2d_t loop, int64_t grain_size = at::internal::GRAIN_SIZE);
void parallel_reduce(loop2d_t loop);
void serial_for_each(loop_t loop, Range range) const;
void serial_for_each(loop2d_t loop, Range range) const;
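// for_each may run `loop` on multiple threads (splitting the iteration range
// according to `grain_size`), while serial_for_each runs a given sub-range on
// the calling thread. A serial full-range invocation looks like (an
// illustrative sketch, reusing the hypothetical `loop` above):
//
//   iter.serial_for_each(loop, {0, iter.numel()});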
/// Creates a strides array for a Tensor with the shape of this iterator. The
/// parameter `element_size` specifies the size of the Tensor's data type in
/// bytes (e.g. `4` for `float`).
StrideVector compatible_stride(int element_size) const;
/// Inverts the re-ordering done by reorder_dimensions. This can only be
/// called *before* coalesce_dimensions() is called.
DimVector invert_perm(IntArrayRef input) const;
/// Reapplies the same re-ordering as done by reorder_dimensions. This can
/// only be called *before* coalesce_dimensions() is called.
DimVector apply_perm_and_mul(IntArrayRef input, int mul) const;
/// Helper functions for CPU iteration
StrideVector get_dim_strides(int dim) const;
StrideVector get_strides() const;
StrideVector get_inner_strides() const { return get_dim_strides(0); }
PtrVector get_data_ptrs(ArrayRef<char*> base, IntArrayRef counter) const;
PtrVector get_base_ptrs() const;
/// true if the stride computation can use 32-bit arithmetic. Used by GPU kernels
bool can_use_32bit_indexing() const;
/// An "iteratable" object that recursively splits this iterator into sub-iterators
/// that can use 32-bit indexing.
SplitUntil32Bit with_32bit_indexing() const;
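// Typical GPU usage (an illustrative sketch; `kernel_impl` is a hypothetical
// per-piece launch helper): recursively split the iterator until every piece
// can use 32-bit indexing, then launch on each piece.
//
//   if (!iter.can_use_32bit_indexing()) {
//     for (auto& sub_iter : iter.with_32bit_indexing()) {
//       kernel_impl(sub_iter);
//     }
//     return;
//   }
//   // ... otherwise launch directly on `iter` ...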
/// Whether the kernel should accumulate into the output. Only relevant for CUDA
/// reductions.
bool should_accumulate() const { return accumulate_; }
/// Whether this iterator produces the actual output,
/// as opposed to something that will be accumulated further. Only relevant for
/// CUDA reductions.
bool is_final_output() const { return final_output_; }
bool has_contiguous_first_dim() const {
  int num_tensors = ntensors();
  for (int i = 0; i < num_tensors; i++) {
    if (strides(i)[0] != element_size(i)) {
      return false;
    }
  }
  return true;
}
void set_output(int64_t output_idx, IntArrayRef sizes, IntArrayRef strides, TensorOptions options, DimnameList names) override;
void build_binary_op(const Tensor& out, const Tensor& a, const Tensor& b);
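// build_binary_op is the build entry point used by structured binary kernels;
// a structured meta function would typically invoke it as (an illustrative
// sketch, where `self` and `other` are the meta function's arguments):
//
//   build_binary_op(maybe_get_output(), self, other);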
protected:
// Mutable reference as it moves tensors out of TensorIteratorConfig
void populate_operands(TensorIteratorConfig&);
void mark_outputs();
void mark_resize_outputs(const TensorIteratorConfig&);
void compute_mem_overlaps(const TensorIteratorConfig&);
void compute_shape(const TensorIteratorConfig&);
void compute_strides(const TensorIteratorConfig&);
void reorder_dimensions();
void permute_dimensions(IntArrayRef perm);
void compute_types(const TensorIteratorConfig&);
ScalarType compute_common_dtype();
void allocate_or_resize_outputs();
bool fast_set_up(const TensorIteratorConfig&);
FastSetupType compute_fast_setup_type(const TensorIteratorConfig&);
void compute_names(const TensorIteratorConfig&);
void propagate_names_to_outputs();
void coalesce_dimensions();
protected:
/// Records the "computation" shape of the output tensor. The computation
/// shape is different from the regular shape in a few ways:
///
///   - The shape may be permuted (via permute_dimensions) so that we
///     process the dimensions in the most computationally efficient order
///     (rather than the logical order given to us by the users.)
///   - The shape may have adjacent dimensions collapsed (via
///     coalesce_dimensions) so that we minimize the number of
///     dimensions we have to explicitly iterate over. For example,
///     a pointwise operation on a contiguous tensor "computationally"
///     consists of only a single dimension.
///
/// In other words, the computation shape is the output shape as it
/// actually matters for implementing the kernel, but not necessarily the
/// output shape that the user will see in the end.
///
/// The lifecycle of mutations to shape_ in TensorIterator:
///   - declare_static_shape() sets an initial shape explicitly
///     provided by the user; otherwise
///   - compute_shape() computes the true (non-computational) shape
///     specified by the user.