Why Gemfury? Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

edgify / torch   python

Repository URL to install this package:

Version: 2.0.1+cpu 

/ include / torch / csrc / utils / python_arg_parser.h

#pragma once

// Parse arguments to Python functions implemented in C++
// This is similar to PyArg_ParseTupleAndKeywords(), but specifically handles
// the types relevant to PyTorch and distinguishes between overloaded function
// signatures.
//
// Example:
//
//   static PythonArgParser parser({
//     "norm(Scalar p, int64_t dim, bool keepdim=False)",
//     "norm(Scalar p=2)",
//   });
//   ParsedArgs<3> parsed_args;
//   auto r = parser.parse(args, kwargs, parsed_args);
//   if (r.idx == 0) {
//     norm(r.scalar(0), r.toInt64(1), r.toBool(2));
//   } else {
//     norm(r.scalar(0));
//   }
//
// We auto-generate most uses of PythonArgParser; the generated files
// are torch/csrc/autograd/generated/python_*.cpp
//
// Some gotchas that you should watch out for:
//
//    - Note [Order of overloads matters]
//      Order of overloads matters.  A set of input arguments may
//      bind to multiple argument specs; we will always pick the
//      first one in PythonArgParser.  However, when you are writing
//      overloads in, e.g., native_functions.yaml, you don't have to
//      worry about what order you write them, because the code
//      generation logic always gives the overloads a canonical
//      order, where Tensor overloads come first, before Scalar overloads.
//      This logic is in sort_declarations in
//      tools/autograd/gen_python_functions.py
//
//    - Zero-dim tensors (e.g., torch.tensor(2)) bind to both
//      Scalar and Tensor, UNLESS they require grad (in which case
//      they only bind to Tensor).

#include <pybind11/pytypes.h>
#include <torch/csrc/python_headers.h>

#include <torch/csrc/Device.h>
#include <torch/csrc/Dtype.h>
#include <torch/csrc/DynamicTypes.h>
#include <torch/csrc/Exceptions.h>
#include <torch/csrc/Generator.h>
#include <torch/csrc/Layout.h>
#include <torch/csrc/MemoryFormat.h>
#include <torch/csrc/QScheme.h>
#include <torch/csrc/Stream.h>
#include <torch/csrc/autograd/python_variable.h>
#include <torch/csrc/autograd/variable.h>
#include <torch/csrc/jit/frontend/tracer.h>
#include <torch/csrc/python_dimname.h>
#include <torch/csrc/tensor/python_tensor.h>
#include <torch/csrc/utils/disable_torch_function.h>
#include <torch/csrc/utils/object_ptr.h>
#include <torch/csrc/utils/pybind.h>
#include <torch/csrc/utils/python_numbers.h>
#include <torch/csrc/utils/python_strings.h>
#include <torch/csrc/utils/python_symnode.h>
#include <torch/csrc/utils/six.h>

#include <ATen/PythonTorchFunctionTLS.h>
#include <ATen/core/Tensor.h>
#include <c10/util/Exception.h>
#include <c10/util/irange.h>

#include <c10/core/SymFloat.h>
#include <c10/core/SymNodeImpl.h>

#include <array>
#include <cstddef>
#include <memory>
#include <sstream>
#include <string>
#include <vector>

// Returns true when `obj` is accepted as a scalar: a NumPy scalar (only when
// built with USE_NUMPY), a Python float / int / complex, or a SymInt /
// SymFloat wrapper.
inline bool THPUtils_checkScalar(PyObject* obj) {
#ifdef USE_NUMPY
  if (torch::utils::is_numpy_scalar(obj)) {
    return true;
  }
#endif
  // Builtin Python numeric types are tested before the symbolic wrappers,
  // in the same order as before.
  if (PyFloat_Check(obj) || PyLong_Check(obj) || PyComplex_Check(obj)) {
    return true;
  }
  if (torch::is_symint(py::handle(obj))) {
    return true;
  }
  return torch::is_symfloat(py::handle(obj));
}

namespace torch {

// Declared here, defined elsewhere.
// NOTE(review): presumably reports whether the function called `name` accepts
// plain Python numbers where a Tensor is expected (cf.
// FunctionParameter::allow_numbers_as_tensors) — confirm against the
// definition.
bool should_allow_numbers_as_tensors(const std::string& name);

// The kinds of parameter a signature can declare. Each FunctionParameter
// stores one of these in its `type_` member.
// NOTE(review): the enumerators have implicit values, so their order
// determines the underlying integers — check for users of the numeric values
// before reordering.
enum class ParameterType {
  TENSOR,
  SCALAR,
  INT64,
  SYM_INT,
  DOUBLE,
  COMPLEX,
  TENSOR_LIST,
  INT_LIST,
  GENERATOR,
  BOOL,
  STORAGE,
  PYOBJECT,
  SCALARTYPE,
  LAYOUT,
  MEMORY_FORMAT,
  DEVICE,
  STREAM,
  STRING,
  DIMNAME,
  DIMNAME_LIST,
  QSCHEME,
  FLOAT_LIST,
  SCALAR_LIST,
  SYM_INT_LIST
};

// Forward declarations; the full definitions appear further down in this
// header. They are needed because PythonArgParser refers to these types
// before they are defined.
struct FunctionParameter;
struct FunctionSignature;
struct PythonArgs;

// Contains bound Python arguments in declaration order.
//
// `N` is the number of slots. PythonArgParser::parse throws if `N` is smaller
// than the parser's `max_args`. The constructor value-initializes the array,
// so every slot starts out as nullptr; the PythonArgs accessors treat a
// nullptr slot as "argument not supplied".
template <int N>
struct ParsedArgs {
  ParsedArgs() : args() {}
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
  PyObject* args[N];
};

// Parses Python call arguments against a set of overloaded signatures (see
// the usage example at the top of this header).
struct PythonArgParser {
  // `fmts` holds one signature format string per overload, e.g.
  // "norm(Scalar p=2)". `traceable` is stored and handed on to the PythonArgs
  // that parse() returns, where it gates the JIT-tracer stashing done by some
  // accessors.
  explicit PythonArgParser(
      std::vector<std::string> fmts,
      bool traceable = false);

  // meant only for `torch` functions.
  // Throws ValueError when `dst` has fewer than `max_args` slots; otherwise
  // forwards to raw_parse() with `dst.args` as the output buffer.
  template <int N>
  inline PythonArgs parse(
      PyObject* self,
      PyObject* args,
      PyObject* kwargs,
      ParsedArgs<N>& dst);

  // Same as the overload above with `self` passed as nullptr.
  template <int N>
  inline PythonArgs parse(PyObject* args, PyObject* kwargs, ParsedArgs<N>& dst);

  // Parses a call that carries only `self`: `args` and `kwargs` are passed as
  // nullptr.
  inline PythonArgs parse(PyObject* self, ParsedArgs<0>& dst);

  // Formatted strings of non-hidden signatures
  std::vector<std::string> get_signatures() const;

 private:
  // Never returns (declared [[noreturn]]).
  // NOTE(review): presumably raises the "no overload matched" error — confirm
  // in the implementation file.
  [[noreturn]]
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
  void
  print_error(
      PyObject* self,
      PyObject* args,
      PyObject* kwargs,
      PyObject* parsed_args[]);
  void check_deprecated(const FunctionSignature& signature);
  // Shared implementation behind every parse() overload; `parsed_args` is the
  // caller-provided ParsedArgs buffer.
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
  PythonArgs raw_parse(
      PyObject* self,
      PyObject* args,
      PyObject* kwargs,
      PyObject* parsed_args[]);

  std::vector<FunctionSignature> signatures_;
  std::string function_name;
  // Minimum number of ParsedArgs slots that parse() requires.
  size_t max_args;
  bool traceable;
};

// One overload of a function, built from a signature format string.
struct PYBIND11_EXPORT FunctionSignature {
  // `index` is stored in the `index` member and later copied into
  // PythonArgs::idx, which callers use to tell which overload matched.
  explicit FunctionSignature(const std::string& fmt, int index);

  // NOTE(review): presumably tries to bind self/args/kwargs to this overload,
  // writing the bound objects into `dst` and returning whether it matched;
  // `raise_exception` presumably selects throwing over returning false —
  // confirm in the implementation file.
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
  bool parse(
      PyObject* self,
      PyObject* args,
      PyObject* kwargs,
      PyObject* dst[],
      bool raise_exception);

  std::string toString() const;

  // Function name; returned by PythonArgs::get_func_name() and used in the
  // list-element error messages.
  std::string name;
  // Declared parameters. The PythonArgs accessors index this with the same
  // `i` they use for the bound-argument array.
  std::vector<FunctionParameter> params;
  // When non-empty, PythonArgs::has_torch_function() returns true.
  std::vector<py::handle> overloaded_args;
  size_t min_args;
  size_t max_args;
  size_t max_pos_args;
  int index;
  bool hidden;
  bool deprecated;
  bool disable_torch_function;
};

// Result of a successful parse: the matched signature plus the bound argument
// objects. Every accessor takes the position `i` of the argument inside
// `args`; a nullptr slot means the argument was not supplied, in which case
// most accessors return a default value or c10::nullopt (see the inline
// definitions below).
//
// The object keeps a reference to `signature` and a raw pointer to the
// argument buffer without owning either one, so both must outlive it.
struct PythonArgs {
  PythonArgs(
      bool traceable,
      const FunctionSignature& signature,
      PyObject** args)
      : idx(signature.index),
        traceable(traceable),
        signature(signature),
        args(args) {}

  // Index of the matched signature (copied from FunctionSignature::index).
  int idx;
  // When true, some integer accessors stash values for the JIT tracer while
  // tracing is active.
  bool traceable;
  const FunctionSignature& signature;
  // Bound argument objects, one slot per declared parameter.
  PyObject** args;

  inline bool has_torch_function();
  inline std::string get_func_name();
  inline at::Tensor tensor(int i);
  inline c10::optional<at::Tensor> optionalTensor(int i);
  inline at::Scalar scalar(int i);
  inline at::Scalar scalarWithDefault(int i, const at::Scalar& default_scalar);
  inline std::vector<at::Scalar> scalarlist(int i);
  inline std::vector<at::Tensor> tensorlist(int i);
  inline torch::List<c10::optional<at::Tensor>> list_of_optional_tensors(int i);
  template <int N>
  inline std::array<at::Tensor, N> tensorlist_n(int i);
  inline std::vector<int64_t> intlist(int i);
  inline std::vector<c10::SymInt> symintlist(int i);
  inline c10::OptionalArray<int64_t> intlistOptional(int i);
  inline c10::OptionalArray<c10::SymInt> symintlistOptional(int i);
  inline std::vector<int64_t> intlistWithDefault(
      int i,
      std::vector<int64_t> default_intlist);
  inline c10::optional<at::Generator> generator(int i);
  inline at::Storage storage(int i);
  inline at::Storage storage(
      int i,
      at::ScalarType& storage_scalar_type,
      bool& is_typed_storage);
  inline c10::Stream stream(int i);
  inline at::ScalarType scalartype(int i);
  inline at::ScalarType scalartypeWithDefault(
      int i,
      at::ScalarType default_scalartype);
  inline c10::optional<at::ScalarType> scalartypeOptional(int i);
  inline c10::optional<at::Scalar> scalarOptional(int i);
  inline c10::optional<int64_t> toInt64Optional(int i);
  inline c10::optional<c10::SymInt> toSymIntOptional(int i);
  inline c10::optional<bool> toBoolOptional(int i);
  inline c10::optional<double> toDoubleOptional(int i);
  inline c10::OptionalArray<double> doublelistOptional(int i);
  inline std::vector<double> doublelist(int i);
  // Unlike doublelist(), this does not check for a nullptr slot first.
  inline std::vector<double> getDoublelist(int i);
  inline at::Layout layout(int i);
  inline at::Layout layoutWithDefault(int i, at::Layout default_layout);
  inline c10::optional<at::Layout> layoutOptional(int i);
  inline at::Device device(int i);
  inline at::Device deviceWithDefault(int i, const at::Device& default_device);
  inline c10::optional<at::Device> deviceOptional(int i);
  inline at::Dimname dimname(int i);
  inline std::vector<at::Dimname> dimnamelist(int i);
  inline c10::optional<std::vector<at::Dimname>> toDimnameListOptional(int i);
  inline at::MemoryFormat memoryformat(int i);
  inline c10::optional<at::MemoryFormat> memoryformatOptional(int i);
  inline at::QScheme toQScheme(int i);
  inline std::string string(int i);
  inline std::string stringWithDefault(int i, const std::string& default_str);
  inline c10::optional<std::string> stringOptional(int i);
  inline c10::string_view stringView(int i);
  inline c10::string_view stringViewWithDefault(
      int i,
      const c10::string_view default_str);
  inline c10::optional<c10::string_view> stringViewOptional(int i);
  inline PyObject* pyobject(int i);
  inline int64_t toInt64(int i);
  inline c10::SymInt toSymInt(int i);
  inline int64_t toInt64WithDefault(int i, int64_t default_int);
  inline double toDouble(int i);
  inline double toDoubleWithDefault(int i, double default_double);
  inline c10::complex<double> toComplex(int i);
  inline c10::complex<double> toComplexWithDefault(
      int i,
      c10::complex<double> default_complex);
  inline bool toBool(int i);
  inline bool toBoolWithDefault(int i, bool default_bool);
  // True when slot `i` is nullptr, i.e. the argument was not bound to an
  // object.
  inline bool isNone(int i);

 private:
  // Out-of-line fallbacks used by tensor() and the scalar accessors when the
  // inline fast path does not apply.
  at::Tensor tensor_slow(int i);
  at::Scalar scalar_slow(int i);
  at::Scalar scalar_slow(PyObject* arg);
};

// One declared parameter of a FunctionSignature, including its default value.
struct FunctionParameter {
  FunctionParameter(const std::string& fmt, bool keyword_only);

  // NOTE(review): presumably tests whether `obj` is acceptable for this
  // parameter, collecting __torch_function__ candidates into
  // `overloaded_args` and reporting the first offending list position
  // through `failed_idx` — confirm in the implementation file.
  bool check(
      PyObject* obj,
      std::vector<py::handle>& overloaded_args,
      int argnum,
      int64_t* failed_idx = nullptr);

  void set_default_str(const std::string& str);
  // Human-readable type name; used in the "must be %s" list-element error
  // messages.
  std::string type_name() const;

  ParameterType type_;
  bool optional;
  bool allow_none;
  bool keyword_only;
  bool allow_numbers_as_tensors = false;
  // Declared list size. The int-list accessors expand a single integer into
  // `size` copies when size > 0, and dimnamelist() asserts that it is 0 or 1.
  int size;
  // Parameter name; used in error messages and as the tracer stash key.
  std::string name;
  // having this as a raw PyObject * will presumably leak it, but these are only
  // held by static objects anyway, and Py_Finalize can already be called when
  // this is destructed.
  PyObject* python_name;
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  at::SmallVector<PyObject*, 5> numpy_python_names;
  // Defaults with non-trivial types. They are returned by PythonArgs::scalar,
  // intlist / symintlist and string / stringView when the argument is absent.
  at::Scalar default_scalar;
  std::vector<int64_t> default_intlist;
  std::string default_string;
  // Defaults with trivial types share storage; which member is meaningful
  // depends on the accessor that reads it (toBool reads default_bool,
  // toInt64 reads default_int, and so on).
  union {
    bool default_bool;
    int64_t default_int;
    double default_double;
    // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
    double default_complex[2]; // see Scalar
    at::ScalarType default_scalartype;
    at::Layout default_layout;
  };
};

// Parses `self` / `args` / `kwargs` into `dst` and returns the bound
// arguments. Throws ValueError when `dst` has fewer slots than `max_args`.
template <int N>
inline PythonArgs PythonArgParser::parse(
    PyObject* self,
    PyObject* args,
    PyObject* kwargs,
    ParsedArgs<N>& dst) {
  // Same comparison as the implicit int -> size_t promotion, spelled out.
  const bool buffer_too_small = static_cast<size_t>(N) < max_args;
  if (buffer_too_small) {
    throw ValueError(
        "PythonArgParser: dst ParsedArgs buffer does not have enough capacity, expected %d (got %d)",
        static_cast<int>(max_args),
        N);
  }
  PyObject** const slots = dst.args;
  return raw_parse(self, args, kwargs, slots);
}

// Convenience overload for calls without a `self`: forwards to the
// four-argument parse() with `self` set to nullptr.
template <int N>
inline PythonArgs PythonArgParser::parse(
    PyObject* args,
    PyObject* kwargs,
    ParsedArgs<N>& dst) {
  PyObject* const no_self = nullptr;
  return parse(no_self, args, kwargs, dst);
}

// Overload for calls that carry only `self`: forwards to the four-argument
// parse() with null `args` and `kwargs`.
inline PythonArgs PythonArgParser::parse(PyObject* self, ParsedArgs<0>& dst) {
  PyObject* const no_args = nullptr;
  PyObject* const no_kwargs = nullptr;
  return parse(self, no_args, no_kwargs, dst);
}

// True when the matched signature recorded any overloaded arguments, or when
// a torch-function mode is currently enabled.
inline bool PythonArgs::has_torch_function() {
  if (!this->signature.overloaded_args.empty()) {
    return true;
  }
  return at::impl::torch_function_mode_enabled();
}

inline std::string PythonArgs::get_func_name() {
  return signature.name;
}

// Returns argument `i` as a Tensor. Objects that pass
// THPVariable_CheckExact are unpacked inline; everything else (including a
// missing argument) goes through tensor_slow().
// TODO: this can return MaybeOwned
inline at::Tensor PythonArgs::tensor(int i) {
  PyObject* const obj = args[i];
  if (!obj || !THPVariable_CheckExact(obj)) {
    return tensor_slow(i);
  }
  return THPVariable_Unpack(obj);
}

// Like tensor(), but maps an undefined Tensor to c10::nullopt.
inline c10::optional<at::Tensor> PythonArgs::optionalTensor(int i) {
  at::Tensor t = tensor(i);
  if (!t.defined()) {
    return c10::nullopt;
  }
  return t;
}

// Returns argument `i` as a Scalar, or the parameter's declared default when
// the argument was not supplied.
inline at::Scalar PythonArgs::scalar(int i) {
  if (args[i]) {
    return scalar_slow(i);
  }
  return signature.params[i].default_scalar;
}

inline std::vector<at::Scalar> PythonArgs::scalarlist(int i) {
  if (!args[i])
    return std::vector<at::Scalar>();
  auto tuple = six::isTuple(args[i]);
  THPObjectPtr arg = six::maybeAsTuple(args[i]);
  // NOLINTNEXTLINE(bugprone-branch-clone)
  auto size = tuple ? PyTuple_GET_SIZE(arg.get()) : PyList_GET_SIZE(arg.get());
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  std::vector<at::Scalar> res(size);
  for (const auto idx : c10::irange(size)) {
    PyObject* obj = tuple ? PyTuple_GET_ITEM(arg.get(), idx)
                          : PyList_GET_ITEM(arg.get(), idx);
    res[idx] = scalar_slow(obj);
  }
  return res;
}

// Returns argument `i` as a Scalar, or the caller-supplied `default_scalar`
// when the argument was not supplied.
inline at::Scalar PythonArgs::scalarWithDefault(
    int i,
    const at::Scalar& default_scalar) {
  if (args[i]) {
    return scalar_slow(i);
  }
  return default_scalar;
}

// Returns argument `i` as a Scalar, or c10::nullopt when it was not supplied.
inline c10::optional<at::Scalar> PythonArgs::scalarOptional(int i) {
  if (args[i]) {
    return scalar_slow(i);
  }
  return c10::nullopt;
}

inline std::vector<at::Tensor> PythonArgs::tensorlist(int i) {
  if (!args[i])
    return std::vector<at::Tensor>();
  auto tuple = six::isTuple(args[i]);
  THPObjectPtr arg = six::maybeAsTuple(args[i]);
  // NOLINTNEXTLINE(bugprone-branch-clone)
  auto size = tuple ? PyTuple_GET_SIZE(arg.get()) : PyList_GET_SIZE(arg.get());
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  std::vector<at::Tensor> res(size);
  for (const auto idx : c10::irange(size)) {
    PyObject* obj = tuple ? PyTuple_GET_ITEM(arg.get(), idx)
                          : PyList_GET_ITEM(arg.get(), idx);
    // This is checked by the argument parser so it's safe to cast without
    // checking if this is a tensor first
    res[idx] = THPVariable_Unpack(obj);
  }
  return res;
}

// Converts a tuple or list argument into a torch::List of optional Tensors.
// A missing argument gives an empty list.
inline torch::List<c10::optional<at::Tensor>> PythonArgs::
    list_of_optional_tensors(int i) {
  if (!args[i]) {
    return torch::List<c10::optional<at::Tensor>>();
  }
  const auto is_tuple = six::isTuple(args[i]);
  THPObjectPtr seq = six::maybeAsTuple(args[i]);
  PyObject* const raw = seq.get();
  // NOLINTNEXTLINE(bugprone-branch-clone)
  const auto count = is_tuple ? PyTuple_GET_SIZE(raw) : PyList_GET_SIZE(raw);
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  torch::List<c10::optional<at::Tensor>> tensors;
  tensors.reserve(count);
  for (const auto pos : c10::irange(count)) {
    PyObject* item =
        is_tuple ? PyTuple_GET_ITEM(raw, pos) : PyList_GET_ITEM(raw, pos);
    // This is checked by the argument parser so it's safe to cast without
    // checking if this is a tensor first
    tensors.push_back(THPVariable_Unpack(item));
  }
  return tensors;
}

template <int N>
inline std::array<at::Tensor, N> PythonArgs::tensorlist_n(int i) {
  auto res = std::array<at::Tensor, N>();
  if (!args[i])
    return res;
  auto tuple = six::isTuple(args[i]);
  THPObjectPtr arg = six::maybeAsTuple(args[i]);
  // NOLINTNEXTLINE(bugprone-branch-clone)
  auto size = tuple ? PyTuple_GET_SIZE(arg.get()) : PyList_GET_SIZE(arg.get());
  if (size != N) {
    throw TypeError("expected tuple of %d elements but got %d", N, (int)size);
  }
  for (const auto idx : c10::irange(size)) {
    PyObject* obj = tuple ? PyTuple_GET_ITEM(arg.get(), idx)
                          : PyList_GET_ITEM(arg.get(), idx);
    // This is checked by the argument parser so it's safe to cast without
    // checking if this is a tensor first
    res[idx] = THPVariable_Unpack(obj);
  }
  return res;
}

// Returns argument `i` as a list of int64_t, falling back to the parameter's
// declared default list when the argument was not supplied.
inline std::vector<int64_t> PythonArgs::intlist(int i) {
  const std::vector<int64_t>& declared_default =
      signature.params[i].default_intlist;
  return intlistWithDefault(i, declared_default);
}

// Converts a SymInt to a new Python object: symbolic values are cast through
// pybind11 and released to the caller, concrete values are packed with
// THPUtils_packInt64.
inline PyObject* toPyObject(c10::SymInt symint) {
  if (!symint.is_symbolic()) {
    return THPUtils_packInt64(symint.as_int_unchecked());
  }
  // The local stays named `r` because the assert macro stringifies its
  // condition into the failure message.
  PyObject* r = py::cast(symint).release().ptr();
  TORCH_INTERNAL_ASSERT(r);
  return r;
}

// Raises the TypeError reported when element `idx` of list argument `i` has
// the wrong type. Always throws.
//
//   args - parsed arguments; supplies the function and parameter names
//   i    - position of the offending parameter in the signature
//   obj  - the offending element (only its Python type name is reported)
//   idx  - zero-based position of the element; reported one-based
inline void throw_intlist_exception(
    const torch::PythonArgs* args,
    size_t i,
    PyObject* obj,
    size_t idx) {
  // "%ld" expects a `long`. Passing a size_t through the varargs call is a
  // format/argument mismatch (and reads the wrong width where long is
  // 32-bit), so convert explicitly.
  const long one_based_pos = static_cast<long>(idx + 1);
  throw TypeError(
      "%s(): argument '%s' must be %s, but found element of type %s at pos %ld",
      args->signature.name.c_str(),
      args->signature.params[i].name.c_str(),
      args->signature.params[i].type_name().c_str(),
      Py_TYPE(obj)->tp_name,
      one_based_pos);
}

inline std::vector<c10::SymInt> PythonArgs::symintlist(int i) {
  if (!args[i]) {
    return c10::fmap(signature.params[i].default_intlist, [](int64_t di) {
      return c10::SymInt(di);
    });
  }

  const auto size1 = signature.params[i].size;
  if (size1 > 0 && THPUtils_checkLong(args[i])) {
    return std::vector<c10::SymInt>(
        size1, c10::SymInt(THPUtils_unpackIndex(args[i])));
  }

  if (size1 > 0 && torch::is_symint(py::handle(args[i]))) {
    auto si = py::handle(args[i]).cast<c10::SymInt>();
    return std::vector<c10::SymInt>(size1, si);
  }

  PyObject* arg = args[i];
  auto tuple = PyTuple_Check(arg);
  // NOLINTNEXTLINE(bugprone-branch-clone)
  const auto size2 = tuple ? PyTuple_GET_SIZE(arg) : PyList_GET_SIZE(arg);
  std::vector<c10::SymInt> res;
  res.reserve(size2);
  for (const auto idx : c10::irange(size2)) {
    PyObject* obj =
        tuple ? PyTuple_GET_ITEM(arg, idx) : PyList_GET_ITEM(arg, idx);

    // Elements of torch.Size are tensors during tracing, and we need to
    // record extra information before they are turned into an IntArrayRef
    if (traceable && jit::tracer::isTracing() && THPVariable_Check(obj)) {
      auto& var = THPVariable_Unpack(obj);
      jit::tracer::ArgumentStash::stashIntArrayRefElem(
          signature.params[i].name, size2, idx, var);
      try {
        res.emplace_back(var.item<int64_t>());
        continue;
      } catch (std::exception& e) {
        throw_intlist_exception(this, i, obj, idx);
      }
      continue;
    } else {
      // convert tensor to scalar outside of try / catch,
      // so that Tensor subclass exceptions will not be caught.
      if (THPVariable_Check(obj)) {
        auto& var = THPVariable_Unpack(obj);
        if (var.numel() != 1 ||
            !at::isIntegralType(
                var.dtype().toScalarType(), /*include_bool*/ true)) {
          throw_intlist_exception(this, i, obj, idx);
        }
        auto scalar = var.item();
        TORCH_CHECK(scalar.isIntegral(/*include bool*/ false));
        res.push_back(scalar.toSymInt());
      } else {
        try {
          if (is_symint(py::handle(obj))) {
            res.push_back(py::handle(obj).cast<c10::SymInt>());
          } else {
            res.emplace_back(THPUtils_unpackIndex(obj));
          }
        } catch (std::exception& e) {
          throw_intlist_exception(this, i, obj, idx);
        }
      }
    }
  }

  return res;
}

// Returns argument `i` as a list of int64_t.
//
//  - Missing argument: `default_intlist` is returned unchanged.
//  - Parameter declared with size > 0 and a single int passed: that value
//    repeated `size` times.
//  - Otherwise the argument is read as a tuple or list; every element must be
//    an int-like object or a one-element integral Tensor. Any other element
//    raises TypeError via throw_intlist_exception().
inline std::vector<int64_t> PythonArgs::intlistWithDefault(
    int i,
    std::vector<int64_t> default_intlist) {
  if (!args[i])
    return default_intlist;
  PyObject* arg = args[i];
  const auto size1 = signature.params[i].size;
  if (size1 > 0 && THPUtils_checkLong(arg)) {
    return std::vector<int64_t>(size1, THPUtils_unpackIndex(arg));
  }
  auto tuple = PyTuple_Check(arg);
  // NOLINTNEXTLINE(bugprone-branch-clone)
  const auto size2 = tuple ? PyTuple_GET_SIZE(arg) : PyList_GET_SIZE(arg);
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  std::vector<int64_t> res(size2);
  for (const auto idx : c10::irange(size2)) {
    PyObject* obj =
        tuple ? PyTuple_GET_ITEM(arg, idx) : PyList_GET_ITEM(arg, idx);
    // Elements of torch.Size are tensors during tracing, and we need to
    // record extra information before they are turned into an IntArrayRef
    if (traceable && jit::tracer::isTracing() && THPVariable_Check(obj)) {
      auto& var = THPVariable_Unpack(obj);
      // The stash must happen before the value is extracted below.
      jit::tracer::ArgumentStash::stashIntArrayRefElem(
          signature.params[i].name, size2, idx, var);
      try {
        res[idx] = var.item<int64_t>();
        continue;
      } catch (std::exception& e) {
        // Any failure to extract the value is reported as a TypeError.
        throw_intlist_exception(this, i, obj, idx);
      }
    } else {
      // convert tensor to scalar outside of try / catch,
      // so that Tensor subclass exceptions will not be caught.
      if (THPVariable_Check(obj)) {
        auto& var = THPVariable_Unpack(obj);
        // Only single-element tensors of an integral (or bool) dtype qualify.
        if (var.numel() != 1 ||
            !at::isIntegralType(
                var.dtype().toScalarType(), /*include_bool*/ true)) {
          throw_intlist_exception(this, i, obj, idx);
        }
        res[idx] = var.item<int64_t>();
      } else {
        try {
          res[idx] = THPUtils_unpackIndex(obj);
        } catch (std::exception& e) {
          throw_intlist_exception(this, i, obj, idx);
        }
      }
    }
  }
  return res;
}

// Returns argument `i` as an int list, or an empty OptionalArray when the
// argument was not supplied.
inline c10::OptionalArray<int64_t> PythonArgs::intlistOptional(int i) {
  if (args[i]) {
    return intlist(i);
  }
  return {};
}

// Returns argument `i` as a SymInt list, or an empty OptionalArray when the
// argument was not supplied.
inline c10::OptionalArray<c10::SymInt> PythonArgs::symintlistOptional(int i) {
  if (args[i]) {
    return symintlist(i);
  }
  return {};
}

// Converts tuple or list argument `i` into a vector of doubles.
//
// The slot is dereferenced without a null check, so callers must make sure
// the argument is present (doublelist() and doublelistOptional() do).
// Raises TypeError naming the one-based position of the first element that
// cannot be unpacked as a double.
inline std::vector<double> PythonArgs::getDoublelist(int i) {
  PyObject* arg = args[i];
  auto tuple = PyTuple_Check(arg);
  // NOLINTNEXTLINE(bugprone-branch-clone)
  auto size = tuple ? PyTuple_GET_SIZE(arg) : PyList_GET_SIZE(arg);
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  std::vector<double> res(size);
  for (const auto idx : c10::irange(size)) {
    PyObject* obj =
        tuple ? PyTuple_GET_ITEM(arg, idx) : PyList_GET_ITEM(arg, idx);
    try {
      res[idx] = THPUtils_unpackDouble(obj);
    } catch (const std::exception&) {
      // "%ld" expects a `long`; the loop index has the sequence-size type,
      // which is not `long` on every platform, so convert explicitly.
      throw TypeError(
          "%s(): argument '%s' must be %s, but found element of type %s at pos %ld",
          signature.name.c_str(),
          signature.params[i].name.c_str(),
          signature.params[i].type_name().c_str(),
          Py_TYPE(obj)->tp_name,
          static_cast<long>(idx + 1));
    }
  }
  return res;
}

// Returns argument `i` as a double list, or an empty OptionalArray when the
// argument was not supplied.
inline c10::OptionalArray<double> PythonArgs::doublelistOptional(int i) {
  if (args[i]) {
    return this->getDoublelist(i);
  }
  return {};
}

// Returns argument `i` as a double list, or an empty vector when the argument
// was not supplied.
inline std::vector<double> PythonArgs::doublelist(int i) {
  if (args[i]) {
    return this->getDoublelist(i);
  }
  return {};
}

// Returns argument `i` as a ScalarType, or `default_scalartype` when the
// argument was not supplied.
inline at::ScalarType PythonArgs::scalartypeWithDefault(
    int i,
    at::ScalarType default_scalartype) {
  if (args[i]) {
    return scalartype(i);
  }
  return default_scalartype;
}

// Returns argument `i` as a ScalarType.
//
// A missing argument gives the parameter's declared default, or the global
// default scalar type when that declared default is Undefined. The Python
// builtin types float, bool and int map to Double, Bool and Long; anything
// else is read as a THPDtype.
inline at::ScalarType PythonArgs::scalartype(int i) {
  PyObject* const obj = args[i];
  if (!obj) {
    const auto declared_default = signature.params[i].default_scalartype;
    if (declared_default == at::ScalarType::Undefined) {
      return torch::tensors::get_default_scalar_type();
    }
    return declared_default;
  }
  if (obj == reinterpret_cast<PyObject*>(&PyFloat_Type)) {
    return at::ScalarType::Double;
  }
  if (obj == reinterpret_cast<PyObject*>(&PyBool_Type)) {
    return at::ScalarType::Bool;
  }
  if (obj == reinterpret_cast<PyObject*>(&PyLong_Type)) {
    return at::ScalarType::Long;
  }
  const auto* dtype = reinterpret_cast<THPDtype*>(obj);
  return dtype->scalar_type;
}

// Returns argument `i` as a ScalarType, or c10::nullopt when it was not
// supplied.
inline c10::optional<at::ScalarType> PythonArgs::scalartypeOptional(int i) {
  if (args[i]) {
    return scalartype(i);
  }
  return c10::nullopt;
}

// Reads the at::Layout stored in a THPLayout object. `obj` is not
// type-checked here.
inline at::Layout toLayout(PyObject* obj) {
  return reinterpret_cast<THPLayout*>(obj)->layout;
}

// Returns argument `i` as a Layout, or the parameter's declared default when
// the argument was not supplied.
inline at::Layout PythonArgs::layout(int i) {
  if (args[i]) {
    return toLayout(args[i]);
  }
  return signature.params[i].default_layout;
}

// Returns argument `i` as a Layout, or `default_layout` when the argument was
// not supplied.
inline at::Layout PythonArgs::layoutWithDefault(
    int i,
    at::Layout default_layout) {
  if (args[i]) {
    return layout(i);
  }
  return default_layout;
}

// Returns argument `i` as a Layout, or c10::nullopt when it was not supplied.
inline c10::optional<at::Layout> PythonArgs::layoutOptional(int i) {
  if (args[i]) {
    return layout(i);
  }
  return c10::nullopt;
}

// Converts a Python object into an at::Device.
//
//  - A THPDevice object: its stored device.
//  - An integer: a CUDA device with that index (negative indices are
//    rejected).
//  - Anything else: unpacked as a string and parsed by at::Device.
//
// NOTE(review): the unpacked integer is handed straight to the at::Device
// constructor; if the device-index type is narrower than the unpacked value,
// large indices would be truncated — confirm whether an upper-bound check is
// wanted.
inline at::Device toDevice(PyObject* obj) {
  if (THPDevice_Check(obj)) {
    return reinterpret_cast<THPDevice*>(obj)->device;
  }
  if (!THPUtils_checkLong(obj)) {
    const std::string& device_str = THPUtils_unpackString(obj);
    return at::Device(device_str);
  }
  const auto device_index = THPUtils_unpackLong(obj);
  TORCH_CHECK(device_index >= 0, "Device index must not be negative");
  return at::Device(DeviceType::CUDA, device_index);
}

// Returns argument `i` as a Device, or the global default device when the
// argument was not supplied.
inline at::Device PythonArgs::device(int i) {
  if (args[i]) {
    return toDevice(args[i]);
  }
  return torch::tensors::get_default_device();
}

// Returns argument `i` as a Device, or `default_device` when the argument was
// not supplied.
inline at::Device PythonArgs::deviceWithDefault(
    int i,
    const at::Device& default_device) {
  if (args[i]) {
    return device(i);
  }
  return default_device;
}

// Returns argument `i` as a Device, or c10::nullopt when it was not supplied.
inline c10::optional<at::Device> PythonArgs::deviceOptional(int i) {
  if (args[i]) {
    return device(i);
  }
  return c10::nullopt;
}

// Returns argument `i` parsed as a Dimname. The argument must be present;
// this is enforced with an internal assert.
inline at::Dimname PythonArgs::dimname(int i) {
  TORCH_INTERNAL_ASSERT(args[i] != nullptr);
  PyObject* const name_obj = args[i];
  return THPDimname_parse(name_obj);
}

// Parses every element of a tuple or list into a Dimname, preserving order.
inline std::vector<at::Dimname> parseDimnameList(PyObject* arg) {
  const auto is_tuple = PyTuple_Check(arg);
  // NOLINTNEXTLINE(bugprone-branch-clone)
  const auto count = is_tuple ? PyTuple_GET_SIZE(arg) : PyList_GET_SIZE(arg);
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  std::vector<at::Dimname> names;
  names.reserve(count);
  for (const auto pos : c10::irange(count)) {
    PyObject* item =
        is_tuple ? PyTuple_GET_ITEM(arg, pos) : PyList_GET_ITEM(arg, pos);
    names.push_back(THPDimname_parse(item));
  }
  return names;
}

// Returns argument `i` as a list of Dimnames, or c10::nullopt when it was not
// supplied.
inline c10::optional<std::vector<at::Dimname>> PythonArgs::
    toDimnameListOptional(int i) {
  if (args[i]) {
    return parseDimnameList(args[i]);
  }
  return c10::nullopt;
}

// Returns argument `i` as a list of Dimnames. The argument must be present.
// When the parameter is declared with size 1 and a single dimname is passed,
// it is wrapped in a one-element list.
inline std::vector<at::Dimname> PythonArgs::dimnamelist(int i) {
  TORCH_INTERNAL_ASSERT(args[i]);
  PyObject* arg = args[i];
  // The local stays named `size` because the assert macro stringifies its
  // condition into the failure message.
  auto size = signature.params[i].size;
  TORCH_INTERNAL_ASSERT(size == 0 || size == 1);
  const bool single_name = size == 1 && THPUtils_checkDimname(arg);
  if (!single_name) {
    return parseDimnameList(arg);
  }
  return {THPDimname_parse(arg)};
}

// Returns argument `i` as a MemoryFormat; a missing argument gives
// Contiguous. Fails a TORCH_CHECK if the object is not a torch.memory_format.
inline at::MemoryFormat PythonArgs::memoryformat(int i) {
  PyObject* const obj = args[i];
  if (!obj) {
    return at::MemoryFormat::Contiguous;
  }
  TORCH_CHECK(
      THPMemoryFormat_Check(obj),
      "memory_format arg must be an instance of the torch.memory_format");
  return reinterpret_cast<THPMemoryFormat*>(obj)->memory_format;
}

// Returns argument `i` as a MemoryFormat, or c10::nullopt when it was not
// supplied.
inline c10::optional<at::MemoryFormat> PythonArgs::memoryformatOptional(int i) {
  if (args[i]) {
    return memoryformat(i);
  }
  return c10::nullopt;
}

// Returns argument `i` as a QScheme; a missing argument gives
// kPerTensorAffine. Fails a TORCH_CHECK if the object is not a torch.qscheme.
inline at::QScheme PythonArgs::toQScheme(int i) {
  PyObject* const obj = args[i];
  if (!obj) {
    return at::kPerTensorAffine;
  }
  TORCH_CHECK(
      THPQScheme_Check(obj),
      "qscheme arg must be an instance of the torch.qscheme");
  return reinterpret_cast<THPQScheme*>(obj)->qscheme;
}

// Returns argument `i` as a std::string, falling back to the parameter's
// declared default string.
inline std::string PythonArgs::string(int i) {
  const std::string& declared_default = signature.params[i].default_string;
  return stringWithDefault(i, declared_default);
}

// Returns argument `i` as a std::string, or `default_str` when the argument
// was not supplied.
inline std::string PythonArgs::stringWithDefault(
    int i,
    const std::string& default_str) {
  if (args[i]) {
    return THPUtils_unpackString(args[i]);
  }
  return default_str;
}

// Returns argument `i` as a std::string, or c10::nullopt when it was not
// supplied.
inline c10::optional<std::string> PythonArgs::stringOptional(int i) {
  if (args[i]) {
    return THPUtils_unpackString(args[i]);
  }
  return c10::nullopt;
}

// Returns argument `i` as a string_view, falling back to a view of the
// parameter's declared default string.
inline c10::string_view PythonArgs::stringView(int i) {
  const std::string& declared_default = signature.params[i].default_string;
  return stringViewWithDefault(i, declared_default);
}

// Returns argument `i` as a string_view, or `default_str` when the argument
// was not supplied.
inline c10::string_view PythonArgs::stringViewWithDefault(
    int i,
    const c10::string_view default_str) {
  if (args[i]) {
    return THPUtils_unpackStringView(args[i]);
  }
  return default_str;
}

// Returns argument `i` as a string_view, or c10::nullopt when it was not
// supplied.
inline c10::optional<c10::string_view> PythonArgs::stringViewOptional(int i) {
  if (args[i]) {
    return THPUtils_unpackStringView(args[i]);
  }
  return c10::nullopt;
}

inline int64_t PythonArgs::toInt64(int i) {
  if (!args[i])
    return signature.params[i].default_int;
  if (traceable && jit::tracer::isTracing() && THPVariable_Check(args[i])) {
    auto& var = THPVariable_Unpack(args[i]);
    jit::tracer::ArgumentStash::stashValue(
        signature.params[i].name, idx, var, c10::IntType::get());
  }
  return THPUtils_unpackLong(args[i]);
}

// Returns argument `i` as a SymInt, or the parameter's declared int default
// when the argument was not supplied. While tracing a traceable call, a
// Tensor argument is first stashed for the JIT tracer.
inline c10::SymInt PythonArgs::toSymInt(int i) {
  PyObject* const obj = args[i];
  if (!obj) {
    return c10::SymInt(signature.params[i].default_int);
  }
  const bool stash_for_tracer =
      traceable && jit::tracer::isTracing() && THPVariable_Check(obj);
  if (stash_for_tracer) {
    auto& var = THPVariable_Unpack(obj);
    // NOTE(review): `this->idx` is the matched-signature index, not the
    // argument position `i` — confirm this is what stashValue expects.
    jit::tracer::ArgumentStash::stashValue(
        signature.params[i].name, this->idx, var, c10::IntType::get());
  }

  return py::cast<c10::SymInt>(py::handle(obj));
}

// Returns argument `i` as an int64_t, or `default_int` when the argument was
// not supplied.
inline int64_t PythonArgs::toInt64WithDefault(int i, int64_t default_int) {
  if (args[i]) {
    return toInt64(i);
  }
  return default_int;
}

// Returns argument `i` as an int64_t, or c10::nullopt when it was not
// supplied.
inline c10::optional<int64_t> PythonArgs::toInt64Optional(int i) {
  if (args[i]) {
    return toInt64(i);
  }
  return c10::nullopt;
}

// Returns argument `i` as a SymInt, or c10::nullopt when it was not supplied.
inline c10::optional<c10::SymInt> PythonArgs::toSymIntOptional(int i) {
  if (args[i]) {
    return toSymInt(i);
  }
  return c10::nullopt;
}

// Returns argument `i` as a bool, or c10::nullopt when it was not supplied.
inline c10::optional<bool> PythonArgs::toBoolOptional(int i) {
  if (args[i]) {
    return toBool(i);
  }
  return c10::nullopt;
}

// Returns argument `i` as a double, or c10::nullopt when it was not supplied.
inline c10::optional<double> PythonArgs::toDoubleOptional(int i) {
  if (args[i]) {
    return toDouble(i);
  }
  return c10::nullopt;
}

// Returns argument `i` as a double, or the parameter's declared default when
// the argument was not supplied.
inline double PythonArgs::toDouble(int i) {
  if (args[i]) {
    return THPUtils_unpackDouble(args[i]);
  }
  return signature.params[i].default_double;
}

// Returns argument `i` as a double, or `default_double` when the argument was
// not supplied.
inline double PythonArgs::toDoubleWithDefault(int i, double default_double) {
  if (args[i]) {
    return toDouble(i);
  }
  return default_double;
}

// Returns argument `i` as a complex<double>, or the parameter's declared
// default when the argument was not supplied.
inline c10::complex<double> PythonArgs::toComplex(int i) {
  if (args[i]) {
    return THPUtils_unpackComplexDouble(args[i]);
  }
  // The default is stored as {real, imag} in a double[2]. Build the value
  // from its two components instead of type-punning the array through
  // reinterpret_cast + const_cast, and only when it is actually needed.
  const auto& parts = signature.params[i].default_complex;
  return c10::complex<double>(parts[0], parts[1]);
}

// Returns argument `i` as a complex<double>, or `default_value` when the
// argument was not supplied.
inline c10::complex<double> PythonArgs::toComplexWithDefault(
    int i,
    c10::complex<double> default_value) {
  if (args[i]) {
    return toComplex(i);
  }
  return default_value;
}

inline bool PythonArgs::toBool(int i) {
  if (!args[i])
    return signature.params[i].default_bool;
  return args[i] == Py_True;
}

// Returns argument `i` as a bool, or `default_bool` when the argument was not
// supplied.
inline bool PythonArgs::toBoolWithDefault(int i, bool default_bool) {
  if (args[i]) {
    return toBool(i);
  }
  return default_bool;
}

inline bool PythonArgs::isNone(int i) {
  return args[i] == nullptr;
}

// Returns the at::Generator held by argument `i`, or c10::nullopt when it was
// not supplied. The object is read as a THPGenerator without a type check.
inline c10::optional<at::Generator> PythonArgs::generator(int i) {
  if (args[i]) {
    return reinterpret_cast<THPGenerator*>(args[i])->cdata;
  }
  return c10::nullopt;
}

// Returns argument `i` as a Storage, or a default-constructed Storage when
// the argument was not supplied.
inline at::Storage PythonArgs::storage(int i) {
  if (args[i]) {
    return createStorage(args[i]);
  }
  return at::Storage();
}

// Returns argument `i` as a Storage and reports, through the two out
// parameters, its scalar type and whether it is a typed storage. A missing
// argument gives a default-constructed Storage, `is_typed_storage == false`
// and `storage_scalar_type == Undefined`.
inline at::Storage PythonArgs::storage(
    int i,
    at::ScalarType& storage_scalar_type,
    bool& is_typed_storage) {
  if (args[i]) {
    return createStorageGetType(
        args[i], storage_scalar_type, is_typed_storage);
  }
  is_typed_storage = false;
  storage_scalar_type = at::ScalarType::Undefined;
  return at::Storage();
}

// Returns argument `i` as a c10::Stream. A missing argument gives the default
// stream on a CPU device with index -1; an object that is not a THPStream
// raises TypeError.
inline c10::Stream PythonArgs::stream(int i) {
  PyObject* const obj = args[i];
  if (!obj) {
    return c10::Stream(
        c10::Stream::Default::DEFAULT, c10::Device(DeviceType::CPU, -1));
  }
  if (!THPStream_Check(obj)) {
    throw TypeError("expected Stream object. Got '%s'", Py_TYPE(obj)->tp_name);
  }
  const auto* py_stream = reinterpret_cast<THPStream*>(obj);
  return c10::Stream::unpack3(
      py_stream->stream_id,
      py_stream->device_index,
      static_cast<DeviceType>(py_stream->device_type));
}

// Returns the raw object bound to argument `i`, or Py_None when it was not
// supplied. No reference count is changed here.
inline PyObject* PythonArgs::pyobject(int i) {
  PyObject* const obj = args[i];
  if (obj) {
    return obj;
  }
  return Py_None;
}

/*
 *
 * Handle __torch_function__ overrides if we know that there are overloaded
 * arguments.  All objects stored in r.overloaded_args must have a
 * __torch_function__ implementation and the arguments must be ordered in order
 * of precedence. Precedence goes from left to right in the order of the
 * signature of the function the overloaded arguments were passed to, except
 * subclasses are always considered before superclasses.
 *
 * If the result of calling __torch_function__ is NotImplemented, the
 * next implementation in the precedence order is called. If all
 * arguments return NotImplemented from their __torch_function__
 * implementation, a TypeError is raised in Python.
 *
 * Assumes overloaded_args has at least one entry. All entries must have
 * a __torch_function__ attribute that resolves to a callable that
 * accepts a torch API function, a tuple of arguments, and a dict of
 * keyword arguments for the torch API function.
 *
 * It is sufficient to call PythonArgs::has_torch_function before
 * calling this function to verify that there are valid arguments
 * present. If that is not done then special care must be taken to
 * ensure there are arguments that are overloaded with
 * __torch_function__.
 *
 * See torch._overrides.handle_torch_function for the equivalent
 * code in the pure-python implementation.
 *
 * 'r' is a parsed PythonArgs instance, returned from
 * PythonArgParser::parse.
 *
 * 'args' is a reference to the python tuple of arguments to the torch
 * API function.
 *
 * 'kwargs' is a reference to the python dict of keyword arguments to
 * the torch API function.
 *
 * 'torch_api' is a reference to a python torch API namespace.
 *
 * 'torch_api_function' is the reference to the original torch method, usually,
 * we can use torch_api and func_name to get torch_api_function. In some cases,
 * e.g., torch custom op, we create the function in C++, if we still use
 * torch_api and func_name to fetch original api, a cyclic call will happen.
 *
 * 'overloaded_args' is the args which have overloaded __torch_function__.
 *
 * 'func_name' is the name of the original torch method.
 *
 * TODO: we could use different names for the following 'handle_torch_function'
 * instead of overloading.
 *
 */
// Used for Tensor methods with arguments.
//
// 'r', 'args', 'kwargs' and 'torch_api' are as described in the comment
// above: the parsed PythonArgs instance, the python tuple and dict of
// arguments to the torch API function, and the python torch API namespace.
// 'self' is passed separately from 'args'; NOTE(review): presumably the
// Tensor the method was called on -- confirm against the implementation.
// 'module_name' is presumably the name of the python module owning the
// method (the overload without argument parsing defaults it to
// "torch.Tensor") -- TODO confirm.
// 'func_name_override' is optional and defaults to nullptr; NOTE(review):
// presumably replaces the function name otherwise derived from 'r'.
auto handle_torch_function(
    PythonArgs& r,
    PyObject* self,
    PyObject* args,
    PyObject* kwargs,
    PyObject* torch_api,
    const char* module_name,
    const char* func_name_override = nullptr) -> PyObject*;

// Used for functions which need to parse python args.
//
// Takes the same parameters as the Tensor-method overload except that there
// is no separate 'self' argument: 'r' is the parsed PythonArgs instance,
// 'args'/'kwargs' are the python tuple and dict of arguments to the torch API
// function, and 'torch_api' is the python torch API namespace.
// 'module_name' is presumably the name of the python module owning the
// function -- TODO confirm.
// 'func_name_override' is optional and defaults to nullptr; NOTE(review):
// presumably replaces the function name otherwise derived from 'r'.
auto handle_torch_function(
    PythonArgs& r,
    PyObject* args,
    PyObject* kwargs,
    PyObject* torch_api,
    const char* module_name,
    const char* func_name_override = nullptr) -> PyObject*;

// Used for functions that have no argument parsing.
//
// 'func_name' is the name of the original torch method. 'args' and 'kwargs'
// may be omitted (both default to nullptr). 'torch_api' defaults to
// THPVariableClass and 'module_name' defaults to "torch.Tensor".
// NOTE(review): 'self' is presumably the object whose __torch_function__ is
// to be invoked -- confirm against the implementation.
auto handle_torch_function(
    PyObject* self,
    const std::string& func_name,
    PyObject* args = nullptr,
    PyObject* kwargs = nullptr,
    PyObject* torch_api = THPVariableClass,
    const std::string& module_name = "torch.Tensor") -> PyObject*;

// Type of the 'torch_function_name' parameter of
// handle_torch_function_no_python_arg_parser below.
// NOTE(review): presumably TorchFunction selects the __torch_function__
// protocol and TorchDispatch selects __torch_dispatch__ -- confirm against
// the implementation.
enum class TorchFunctionName { TorchFunction, TorchDispatch };

// Used for functions created in C++, e.g., C++ custom op, which doesn't use
// PythonArgParser to get overloaded_args.
//
// 'overloaded_args' is the args which have overloaded __torch_function__.
// 'args'/'kwargs' are the python tuple and dict of arguments to the torch API
// function. 'func_name' is the name of the original torch method.
// 'torch_api_function' is the reference to the original torch method; it is
// passed in directly rather than being looked up by name.
// 'module_name' is presumably the name of the python module owning the
// function -- TODO confirm.
// 'torch_function_name' defaults to TorchFunctionName::TorchFunction.
auto TORCH_API handle_torch_function_no_python_arg_parser(
    at::ArrayRef<py::handle> overloaded_args,
    PyObject* args,
    PyObject* kwargs,
    const char* func_name,
    PyObject* torch_api_function,
    const char* module_name,
    TorchFunctionName torch_function_name = TorchFunctionName::TorchFunction)
    -> PyObject*;

// Used for getters of Tensor properties.
//
// 'self' is the THPVariable the property belongs to and 'property_name' is
// the name of the property being read.
// NOTE(review): whether the returned PyObject* is a new or borrowed reference
// is not visible from this declaration -- confirm against the implementation.
auto handle_torch_function_getter(
    THPVariable* self,
    const std::string& property_name) -> PyObject*;

// Used for setters of Tensor properties.
//
// 'self' is the THPVariable the property belongs to, 'property_name' is the
// name of the property being written, and 'value' is the object passed to the
// setter.
// NOTE(review): the int return presumably follows the CPython setter
// convention (0 on success, -1 on failure) -- confirm against the
// implementation.
auto handle_torch_function_setter(
    THPVariable* self,
    const std::string& property_name,
    PyObject* value) -> int;

// Used for __getitem__ and __setitem__.
//
// 'self' is the object being indexed and 'index' is the index expression.
// 'val' is optional and defaults to nullptr; NOTE(review): presumably a null
// 'val' means __getitem__ and a non-null 'val' is the value assigned by
// __setitem__ -- confirm against the implementation.
auto handle_torch_function_indexing(
    PyObject* self,
    PyObject* index,
    PyObject* val = nullptr) -> PyObject*;

/*
 * Check if the input obj is Tensor type, including its subclass, or overloaded
 * type. If the type defines __torch_function__, it also returns true.
 * Otherwise returns false. If the class is not torch.Tensor, and it defines
 * __torch_function__, we append obj to overloaded_args.
 *
 * 'obj': the input argument to be checked
 * 'overloaded_args': the vector to append the overloaded args.
 */
bool is_tensor_and_append_overloaded(
    PyObject* obj,
    std::vector<py::handle>* overloaded_args);

/*
 * Check if the input obj is Tensor List or Tensor Tuple type. First check
 * whether obj is Tuple or List type; if true, iterate over each element and
 * check whether it is Tensor type, including its subclass or overloaded type.
 * At the same time, each overloaded arg is appended to overloaded_args.
 *
 * 'obj': the input argument to be checked
 * 'overloaded_args': the vector to append the overloaded args.
 * 'argnum': the number of total arguments of the function being checked.
 * 'throw_error': whether to throw an error if any element in the list or
 *                tuple is not a tensor type or an overloaded type.
 */
bool is_tensor_list_and_append_overloaded(
    PyObject* obj,
    std::vector<py::handle>* overloaded_args,
    int argnum,
    bool throw_error);

/* Given an argument that is definitely a tensor and is definitely overloaded,
 * append it to the overloaded arguments list.  Use this instead of
 * is_tensor_and_append_overloaded in situations where you have a PyObject
 * and you know it definitely is a Tensor and it is definitely overloaded.
 *
 * Note that the parameter order is the reverse of
 * is_tensor_and_append_overloaded: the vector comes first here.
 *
 * 'overloaded_args': the vector to append the overloaded args
 * 'obj': the input tensor that is overloaded
 */
void append_overloaded_tensor(
    std::vector<py::handle>* overloaded_args,
    PyObject* obj);

/* Given an argument that is definitely a type and is definitely overloaded,
 * append it to the overloaded arguments list. Use this only with
 * __torch_dispatch__, where we operate on classes that have a
 * __torch_dispatch__ classmethod.
 *
 * Note that the parameter order is the reverse of
 * is_tensor_and_append_overloaded: the vector comes first here.
 *
 * 'overloaded_args': the vector to append the overloaded type
 * 'obj': the input class that has a __torch_dispatch__ classmethod.
 */
void append_overloaded_type(
    std::vector<py::handle>* overloaded_args,
    PyObject* obj);

} // namespace torch