// include/ATen/TensorIndexing.h

#pragma once

#include <c10/util/Optional.h>
#include <ATen/core/TensorBody.h>
#include <ATen/ExpandUtils.h>
#include <ATen/Functions.h>
#include <ATen/ScalarOps.h>

// TODO: try to remove this
// There is some back story, see https://github.com/pytorch/pytorch/issues/48684
#include <ATen/NativeFunctions.h>

#include <ATen/core/List.h>

namespace at {
namespace indexing {

// Largest / smallest values representable by an index. `Slice` uses them as
// the "unbounded" defaults when a start or stop bound is omitted.
constexpr int64_t INDEX_MAX = std::numeric_limits<int64_t>::max();
constexpr int64_t INDEX_MIN = std::numeric_limits<int64_t>::min();

// Discriminator recording which kind of index a `TensorIndex` holds.
// The enumerator order (and therefore the underlying values) is unchanged.
enum class TensorIndexType {
  None,      // `at::indexing::None`
  Ellipsis,  // `"..."` / `at::indexing::Ellipsis`
  Integer,   // integer value
  Boolean,   // `true` / `false`
  Slice,     // `at::indexing::Slice`
  Tensor     // tensor value
};

// C++ spelling of Python's `None` in an index expression. It is an alias for
// `c10::nullopt`, so it can also be passed for an omitted `Slice` bound.
constexpr c10::nullopt_t None = c10::nullopt;

// Empty tag type whose single instance, `Ellipsis`, stands for Python's `...`
// in an index expression.
// NOTE(review): the constructor is user-provided (empty body) rather than
// defaulted, presumably so that the `const` object declared below can be
// default-initialized where it is defined -- confirm before changing it.
struct TORCH_API EllipsisIndexType final { EllipsisIndexType() {} };
// Declared here only; the definition is not visible in this header.
TORCH_API extern const EllipsisIndexType Ellipsis;

// A Python-style `start:stop:step` slice whose three components are resolved
// to concrete `int64_t` values at construction time.
//
// Omitted components (`c10::nullopt` / `at::indexing::None`) are replaced by:
//   - step:  1
//   - start: 0 for a positive step, INDEX_MAX for a negative step
//   - stop:  INDEX_MAX for a positive step, INDEX_MIN for a negative step
struct TORCH_API Slice final {
  // This mirrors `__PySlice_Unpack` in torch/csrc/utils/python_compat.h
  //
  // Fails a TORCH_CHECK_VALUE if an explicit `step_index` of zero is given.
  Slice(
      c10::optional<int64_t> start_index = c10::nullopt,
      c10::optional<int64_t> stop_index = c10::nullopt,
      c10::optional<int64_t> step_index = c10::nullopt) {
    // The step is resolved first because the default start/stop depend on
    // its sign.
    if (!step_index.has_value()) {
      step_ = 1;
    } else {
      step_ = step_index.value();
      TORCH_CHECK_VALUE(step_ != 0, "slice step cannot be zero");

      // Here step might be -INDEX_MAX-1; in this case we replace it
      // with -INDEX_MAX.  This doesn't affect the semantics, and it
      // guards against later undefined behaviour resulting from code that
      // does "step = -step" as part of a slice reversal.
      if (step_ < -INDEX_MAX) {
        step_ = -INDEX_MAX;
      }
    }

    if (!start_index.has_value()) {
      start_ = step_ < 0 ? INDEX_MAX : 0;
    } else {
      start_ = start_index.value();
    }

    if (!stop_index.has_value()) {
      stop_ = step_ < 0 ? INDEX_MIN : INDEX_MAX;
    } else {
      stop_ = stop_index.value();
    }
  }

  // Resolved start bound.
  inline int64_t start() const {
    return start_;
  }

  // Resolved stop bound.
  inline int64_t stop() const {
    return stop_;
  }

  // Resolved step; never zero and never below -INDEX_MAX.
  inline int64_t step() const {
    return step_;
  }

  // NOTE(review): no access specifier is visible above these members, so they
  // are left with the struct's default access; read them via the accessors.
  int64_t start_;
  int64_t stop_;
  int64_t step_;
};

// Stream-insertion operator for `Slice`. Only declared here; the output
// format is determined by the out-of-line definition, which is not visible
// in this header.
TORCH_API std::ostream& operator<<(std::ostream& stream, const Slice& slice);

// `at::indexing::TensorIndex` is used for converting C++ tensor indices such as
// `{None, "...", Ellipsis, 0, true, Slice(1, None, 2), torch::tensor({1, 2})}`
// into its equivalent `std::vector<TensorIndex>`, so that further tensor indexing
// operations can be performed using the supplied indices.
// There is one-to-one correspondence between Python and C++ tensor index types:
// Python                  | C++
// -----------------------------------------------------
// `None`                  | `at::indexing::None`
// `Ellipsis`              | `at::indexing::Ellipsis`
// `...`                   | `"..."`
// `123`                   | `123`
// `True` / `False`        | `true` / `false`
// `:`                     | `Slice()` / `Slice(None, None)`
// `::`                    | `Slice()` / `Slice(None, None, None)`
// `1:`                    | `Slice(1, None)`
// `1::`                   | `Slice(1, None, None)`
// `:3`                    | `Slice(None, 3)`
// `:3:`                   | `Slice(None, 3, None)`
// `::2`                   | `Slice(None, None, 2)`
// `1:3`                   | `Slice(1, 3)`
// `1::2`                  | `Slice(1, None, 2)`
// `:3:2`                  | `Slice(None, 3, 2)`
// `1:3:2`                 | `Slice(1, 3, 2)`
// `torch.tensor([1, 2])`  | `torch::tensor({1, 2})`
// A single element of a multi-dimensional index expression. Each converting
// constructor below records which kind of index was supplied in `type_`;
// use the `is_*()` predicates to find out which payload accessor is
// meaningful.
struct TORCH_API TensorIndex final {
  // Case 1: `at::indexing::None`
  TensorIndex(c10::nullopt_t) : type_(TensorIndexType::None) {}

  // Case 2: "..." / `at::indexing::Ellipsis`
  TensorIndex(at::indexing::EllipsisIndexType) : type_(TensorIndexType::Ellipsis) {}
  // The only string accepted as an index is the literal "..."; anything else
  // fails the check below.
  TensorIndex(const char *str) : TensorIndex(at::indexing::Ellipsis) {
    TORCH_CHECK_VALUE(
      strcmp(str, "...") == 0,
      "Expected \"...\" to represent an ellipsis index, but got \"", str, "\"");
  }

  // Case 3: Integer value
  TensorIndex(int64_t integer) : integer_(integer), type_(TensorIndexType::Integer) {}
  TensorIndex(int integer) : TensorIndex(static_cast<int64_t>(integer)) {}

  // Case 4: Boolean value
  // Constrained to exactly `bool` so that other types do not convert
  // implicitly into a boolean index.
  template <class T,
            class = typename std::enable_if<std::is_same<bool, T>::value>::type >
  TensorIndex(T boolean) : boolean_(boolean), type_(TensorIndexType::Boolean) {}

  // Case 5: Slice represented in `at::indexing::Slice` form
  TensorIndex(Slice slice) : slice_(std::move(slice)), type_(TensorIndexType::Slice) {}

  // Case 6: Tensor value
  TensorIndex(Tensor tensor) : tensor_(std::move(tensor)), type_(TensorIndexType::Tensor) {}

  inline bool is_none() const {
    return type_ == TensorIndexType::None;
  }

  inline bool is_ellipsis() const {
    return type_ == TensorIndexType::Ellipsis;
  }

  inline bool is_integer() const {
    return type_ == TensorIndexType::Integer;
  }

  // Integer payload; only set by the integer constructors (0 otherwise).
  inline int64_t integer() const {
    return integer_;
  }

  inline bool is_boolean() const {
    return type_ == TensorIndexType::Boolean;
  }

  // Boolean payload; only set by the boolean constructor (false otherwise).
  inline bool boolean() const {
    return boolean_;
  }

  inline bool is_slice() const {
    return type_ == TensorIndexType::Slice;
  }

  // Slice payload; a default-constructed `Slice` unless built from a slice.
  inline const Slice& slice() const {
    return slice_;
  }

  inline bool is_tensor() const {
    return type_ == TensorIndexType::Tensor;
  }

  // Tensor payload; a default-constructed `Tensor` unless built from a tensor.
  inline const Tensor& tensor() const {
    return tensor_;
  }

  // NOTE(review): no access specifier is visible above these members, so they
  // are left with the struct's default access; read them via the accessors.
  // The scalar payloads are default-initialized so that calling an accessor
  // on an index of another kind never reads an indeterminate value.
  int64_t integer_ = 0;
  bool boolean_ = false;
  Slice slice_;
  Tensor tensor_;
  TensorIndexType type_;
};

// Stream-insertion operators for a single `TensorIndex` and for a sequence of
// them. Only declared here; the output format is determined by the
// out-of-line definitions, which are not visible in this header.
TORCH_API std::ostream& operator<<(std::ostream& stream, const TensorIndex& tensor_index);
TORCH_API std::ostream& operator<<(std::ostream& stream, const std::vector<TensorIndex>& tensor_indices);

namespace impl {
// Applies `start:stop:step` to dimension `dim` of `self`.
//
// - `disable_slice_optimization`: when false, a slice that covers the whole
//   dimension with unit step returns `self` itself instead of calling
//   `Tensor::slice`.
// - `self_device` / `self_sizes`: device and sizes of `self` supplied by the
//   caller; for CPU and CUDA the length of `dim` is read from `self_sizes`,
//   for any other device it is queried through `self.size(dim)`.
//
// Fails a TORCH_CHECK_VALUE unless `step` is strictly positive.
static inline Tensor applySlice(
    const Tensor& self,
    int64_t dim,
    int64_t start,
    int64_t stop,
    int64_t step,
    bool disable_slice_optimization,
    const at::Device& self_device,
    const IntArrayRef& self_sizes) {
  // TODO: implement negative step
  TORCH_CHECK_VALUE(step > 0, "step must be greater than zero");

  // Skip this optimization if we are tracing, as the trace may be polymorphic
  // over the shape of the `self` tensor, and we still want to record
  // the slice.
  int64_t length = (self_device == at::kCPU || self_device == at::kCUDA) ? self_sizes[dim] : self.size(dim);
  if (!disable_slice_optimization && start == 0 && stop == length && step == 1) {
    return self;
  }
  return self.slice(dim, start, stop, step);
}

// Selects position `index` along dimension `dim` of `self`.
//
// - `real_dim` is used only in the out-of-bounds error message; `dim` is the
//   dimension actually passed to `Tensor::select`.
// - `self_sizes` holds the sizes of `self` as supplied by the caller.
// - `self_device` is not used by this function.
//
// Fails an index check when a 0-dim tensor is indexed with `index == 0` at
// `dim == 0`, or when `index` lies outside `[-size, size)` for `dim`.
static inline Tensor applySelect(
    const Tensor& self,
    int64_t dim,
    int64_t index,
    int64_t real_dim,
    const at::Device& self_device,
    const IntArrayRef& self_sizes) {
  TORCH_CHECK_INDEX(
    !(index == 0 && dim == 0 && self_sizes.size() == 0),
    "invalid index of a 0-dim tensor. ",
    "Use `tensor.item()` in Python or `tensor.item<T>()` in C++ to convert a 0-dim tensor to a number");

  int64_t size = self_sizes[dim];
  TORCH_CHECK_INDEX(
    index >= -size && index < size,
    "index ", index, " is out of bounds for dimension ", real_dim, " with size ", size);

  // if the index is negative, do not normalize it because that would fix the index
  // on the current tensor size in the tracer.
  // aten::select also works on negative indices
  return self.select(dim, index);
}

// Converts a boolean index into an index tensor of dtype `kLong` that uses
// `self`'s other tensor options: `true` gives the one-element tensor `[0]`,
// `false` gives an empty tensor.
static inline Tensor boolToIndexingTensorCPUOrCUDA(const Tensor& self, bool value) {
  // booleans add a dimension of size 1. true indexes this dimension as if 0:, false as empty.
  if (value) {
    return at::empty({1}, {}, self.options().dtype(kLong)).fill_(0.);
  } else {
    return at::empty({0}, {}, self.options().dtype(kLong));
  }
}

// Same contract as `boolToIndexingTensorCPUOrCUDA`, but the `true` case is
// built with `at::zeros` rather than `at::empty(...).fill_(0.)`.
static inline Tensor boolToIndexingTensorNonNativeDeviceType(const Tensor& self, bool value) {
  // booleans add a dimension of size 1. true indexes this dimension as if 0:, false as empty.
  if (value) {
    return at::zeros({1}, {}, self.options().dtype(kLong));
  } else {
    return at::empty({0}, {}, self.options().dtype(kLong));
  }
}

// Converts a boolean index into an index tensor, dispatching on
// `self_device`: CPU and CUDA use the `...CPUOrCUDA` helper, every other
// device uses the `...NonNativeDeviceType` helper.
static inline Tensor boolToIndexingTensor(const Tensor& self, bool value, const at::Device& self_device) {
  if (self_device == at::kCPU || self_device == at::kCUDA) {
    return boolToIndexingTensorCPUOrCUDA(self, value);
  } else {
    return boolToIndexingTensorNonNativeDeviceType(self, value);
  }
}

// Wraps scalar `v` in a tensor created with `options`, via
// `at::scalar_tensor`.
static inline Tensor scalarToTensorNonNativeDeviceType(Scalar v, const TensorOptions& options) {
  return at::scalar_tensor(v, options);
}

// Stores `tensor` into `outIndices` at position `*dim_ptr`, after resizing
// `outIndices` to exactly `*dim_ptr + 1` entries.
// NOTE(review): `dim_ptr` is taken by pointer but is only read in the lines
// visible here, and the end of this function body is not visible either. The
// block looks truncated -- confirm whether `*dim_ptr` is meant to be advanced
// after the store.
static inline void recordTensorIndex(const Tensor& tensor, std::vector<Tensor>& outIndices, int64_t* dim_ptr) {
  // TODO: check scalarType
  // resize() also shrinks: entries past `*dim_ptr` are dropped.
  outIndices.resize(*dim_ptr + 1);
  outIndices[*dim_ptr] = tensor;

// Builds a `c10::List<c10::optional<Tensor>>` from `indices`, which is taken
// by rvalue reference.
// NOTE(review): both branch bodies below are empty in the lines visible here,
// so as written nothing is ever appended and the returned list is always
// empty. The block looks truncated -- presumably defined indices are converted
// relative to `self` and undefined ones are passed through; confirm against
// the intended implementation before relying on this function.
static inline c10::List<c10::optional<Tensor>> typeConvertIndices(const Tensor& self, std::vector<Tensor>&& indices) {
  c10::List<c10::optional<Tensor>> converted_inds;
  for (size_t i = 0; i < indices.size(); ++i) {
    const auto &ind = indices[i];
    // Defined and undefined index tensors are told apart here.
    if (ind.defined()) {
    } else {
  return converted_inds;

// NOTE: Why do we mirror instead of replace the `count_specified_dimensions` function
// in torch/csrc/autograd/python_variable_indexing.cpp? It's because
// `count_specified_dimensions` is on the hot path of Python tensor multi-dim indexing
// (i.e. it's called by `applySlicing` which is called by `THPVariable_getitem` /
// `THPVariable_setitem` when handling indexing of more than one dimension). If we were
// to merge the Python/C++ `count_specified_dimensions` function, on the Python side
// we would have to construct a `std::vector` container to be consumed by the C++
// `count_specified_dimensions` function, which adds 100s of nanoseconds overhead and
// is undesirable.
// Returns how many dimensions of the indexed tensor are consumed by
// `indices`. `None`, ellipsis and boolean entries consume none; a byte/bool
// mask tensor consumes one dimension per mask dimension; every other entry
// (integer, slice, non-mask tensor) consumes exactly one.
static inline int64_t count_specified_dimensions(const ArrayRef<TensorIndex>& indices) {
  // Count the number of indexed dimensions (everything but ellipsis and None)
  int64_t count = 0;
  for (const auto& obj : indices) {
    if (obj.is_tensor()) {
      const auto& tensor = obj.tensor();
      if (tensor.scalar_type() == kByte || tensor.scalar_type() == kBool) {
        count += tensor.dim();
      } else {
        count += 1;
      }
    } else if (!obj.is_none() && !obj.is_ellipsis() && !obj.is_boolean()) {
      count += 1;
    }
  }
  return count;
}
} // namespace impl

// NOTE: Many functions below are only for consumption from Python indexing
// implementation, they include:
// - `Tensor scalarToTensor(...)`
// - `IntArrayRef slicePrefix1sSize(...)`
// - `void copy_to(...)`
// - `Tensor handleDimInMultiDimIndexing(...)`
// - `Tensor dispatch_index(...)`
// - `Tensor dispatch_index_put_(...)`
// - `Tensor get_item(...)`
// - `void set_item(...)`
// The rest of the functions are in `at::indexing::impl` namespace, signifying
// that they shouldn't be used from Python indexing implementation.
static inline Tensor scalarToTensor(Scalar v, const TensorOptions& options, const at::Device& self_device) {
  // Wraps scalar `v` in a tensor. On CPU this goes through
  // `at::detail::scalar_tensor_static`, using the dtype carried by `options`;
  // every other device falls back to
  // `impl::scalarToTensorNonNativeDeviceType`.
  // NOTE(review): the CPU branch dereferences `options.dtype_opt()` without
  // checking it -- presumably callers always pass options with a dtype set;
  // confirm.
  if (self_device == at::kCPU) {
    return at::detail::scalar_tensor_static(v, options.dtype_opt()->toScalarType(), self_device);
  } else {
    return impl::scalarToTensorNonNativeDeviceType(v, options);
  }
}

// To match numpy semantics:
// As a special case for backwards compatibility,
// strip away unit dimensions from the left of 'src'
//
// Returns the suffix of `sizes` that starts at its first entry different
// from 1; the result is empty when every entry is 1 (or `sizes` is empty).
static inline IntArrayRef slicePrefix1sSize(const IntArrayRef& sizes) {
  size_t first_non1_src = sizes.size();
  for (size_t i = 0; i < sizes.size(); ++i) {
    if (sizes[i] != 1) {
      first_non1_src = i;
      // Stop at the first non-1 entry: only leading unit dimensions are
      // stripped, later ones must be kept.
      break;
    }
  }

  return sizes.slice(first_non1_src);
}

static inline void copy_to(const Tensor& dst, const Tensor& src) {
