// edgify/torch 2.0.1+cpu
// include/torch/csrc/jit/runtime/register_ops_utils.h

#pragma once

#include <ATen/Context.h>
#include <c10/core/DeviceType.h>
#include <torch/csrc/autograd/autograd.h>
#include <torch/csrc/autograd/edge.h>
#include <torch/csrc/autograd/function.h>
#include <torch/csrc/autograd/generated/variable_factories.h>
#include <torch/csrc/autograd/variable.h>
#include <torch/csrc/jit/api/compilation_unit.h>
#include <torch/csrc/jit/api/module.h>
#include <torch/csrc/jit/codegen/fuser/interface.h>
#include <torch/csrc/jit/frontend/error_report.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/mobile/register_ops_common_utils.h>
#include <torch/csrc/jit/runtime/custom_operator.h>
#include <torch/csrc/jit/runtime/graph_executor.h>
#include <torch/csrc/jit/runtime/jit_exception.h>
#include <torch/csrc/jit/runtime/logging.h>
#include <torch/csrc/jit/runtime/operator.h>
#include <torch/csrc/jit/runtime/print_handler.h>
#include <torch/csrc/jit/runtime/profiling_record.h>
#include <torch/csrc/jit/runtime/vararg_functions.h>
#include <torch/csrc/jit/serialization/pickle.h>

#include <ATen/ExpandUtils.h>
#include <ATen/Parallel.h>
#include <ATen/WrapDimUtils.h>
#include <ATen/core/Dict.h>
#include <ATen/core/ivalue.h>
#include <c10/core/thread_pool.h>
#include <c10/util/SmallVector.h>
#include <c10/util/irange.h>
#include <c10/util/math_compat.h>
#include <c10/util/string_utils.h>

namespace torch {
namespace jit {
constexpr inline c10::AliasAnalysisKind aliasAnalysisFromSchema() {
  return c10::AliasAnalysisKind::FROM_SCHEMA;
}

constexpr inline c10::AliasAnalysisKind aliasAnalysisConservative() {
  return c10::AliasAnalysisKind::CONSERVATIVE;
}

constexpr inline c10::AliasAnalysisKind aliasAnalysisSpecialCase() {
  return c10::AliasAnalysisKind::INTERNAL_SPECIAL_CASE;
}

template <class T>
c10::List<T> make_result_list(const TypePtr& elemType) {
  return c10::List<T>();
}

template <>
c10::impl::GenericList make_result_list<IValue>(const TypePtr& elemType);

// As described in https://docs.python.org/3/library/functions.html#round
// When a number is exactly halfway between two integers, Python's builtin
// round function rounds to the nearest even number. We use round(x/2)*2 to
// handle the special halfway case. For positive 'x', round(x/2)*2 =
// round((x_e + x_r)/2)*2 = x_e + round(x_r/2)*2, where x_e is an even integer
// and x_r is either 0.5 or 1.5; round(x_r/2)*2 yields 0 or 2, so the final
// result is always an even number. By symmetry, the same holds for negative
// values.
inline double round_to_even(double a) {
  return a - std::floor(a) == 0.5 ? (std::round(a * 0.5) * 2.0) : std::round(a);
}

// Using the rules from python_arg_parser FunctionParameter::check:
// the tensor cannot have grad set, the tensor must be 0-dimensional,
// and if the destination is an int, the source must be an integral type.
void checkImplicitTensorToNum(const at::Tensor& t, bool toInt);

static C10_UNUSED int64_t floordiv(int64_t a, int64_t b) {
  if (b == 0) {
    throw std::runtime_error("division by 0");
  }
  if ((a > 0) == (b > 0)) {
    // simple case, both have same sign
    return a / b;
  } else {
    // in Python, division rounds down; it does not truncate toward zero
    // like in C++
    auto r = lldiv(a, b);
    return (r.rem) ? r.quot - 1 : r.quot;
  }
}
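
// Illustrative sketch (not part of the original header): Python floor
// division rounds toward negative infinity, while C++ integer division
// truncates toward zero, so the two differ for mixed-sign operands:
//
//   floordiv(7, 2);    // 3   (same as 7 / 2 in C++)
//   floordiv(-7, 2);   // -4  (C++ -7 / 2 truncates to -3)
//   floordiv(-6, 2);   // -3  (no remainder, so both agree)
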
TORCH_API void checkDoubleInRange(double a);
static C10_UNUSED int64_t floor(double a) {
  checkDoubleInRange(a);
  return std::floor(a);
}
static C10_UNUSED int64_t ceil(double a) {
  checkDoubleInRange(a);
  return std::ceil(a);
}

static C10_UNUSED int64_t gcd(int64_t a, int64_t b) {
  while (b != 0) {
    int64_t r = a % b;
    a = b;
    b = r;
  }
  // in Python, gcd returns non-negative values
  return std::abs(a);
}
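
// Illustrative sketch (not part of the original header): the Euclidean
// algorithm above can leave a negative value in 'a' for negative inputs, so
// the final std::abs makes the result match Python's math.gcd:
//
//   gcd(12, 18);   // 6
//   gcd(-4, 6);    // 2   (Python: math.gcd(-4, 6) == 2)
//   gcd(0, 0);     // 0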

int64_t partProduct(int n, int m);

void loop(int n, int64_t& p, int64_t& r);

int nminussumofbits(int v);

int64_t factorial(int n);
static const double degToRad = std::acos(-1.0) / 180.0;
static const double radToDeg = 180.0 / std::acos(-1.0);
double degrees(double x);
double radians(double x);

// Convert a Python index (which may be negative) into an index usable for a
// C++ container

// Equivalent to list.at(idx)
template <typename T>
T getItem(const c10::List<T>& list, int64_t idx) {
  const int64_t list_size = list.size();
  const int64_t normalized_idx = normalizeIndex(idx, list_size);
  if (normalized_idx < 0 || normalized_idx >= list_size) {
    throw std::out_of_range("list index out of range");
  }
  return list.get(normalized_idx);
}
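
// Illustrative sketch (not part of the original header): negative indices
// are normalized Python-style before the bounds check, e.g. for a list
// holding {10, 20, 30}:
//
//   getItem(list, 0);    // 10
//   getItem(list, -1);   // 30  (normalized to index 2)
//   getItem(list, 3);    // throws std::out_of_range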

template <typename T>
void setItem(const c10::List<T>& list, int64_t idx, T&& value) {
  const int64_t list_size = list.size();
  const int64_t normalized_idx = normalizeIndex(idx, list_size);
  if (normalized_idx < 0 || normalized_idx >= list_size) {
    throw std::out_of_range("list index out of range");
  }
  list.set(normalized_idx, std::forward<T>(value));
}

void listAppend(Stack& stack);

void listReverse(Stack& stack);

template <typename T>
void minList(Stack& stack) {
  c10::List<T> a = pop(stack).to<c10::List<T>>();
  c10::List<T> b = pop(stack).to<c10::List<T>>();

  size_t min_size = std::min(a.size(), b.size());
  for (const auto i : c10::irange(min_size)) {
    if (a[i] == b[i]) {
      continue;
    }

    push(stack, a[i] < b[i] ? a : b);
    return;
  }

  push(stack, b.size() < a.size() ? b : a);
}

template <typename T>
void maxList(Stack& stack) {
  c10::List<T> a = pop(stack).to<c10::List<T>>();
  c10::List<T> b = pop(stack).to<c10::List<T>>();

  size_t min_size = std::min(a.size(), b.size());
  for (const auto i : c10::irange(min_size)) {
    if (a[i] == b[i]) {
      continue;
    }

    push(stack, a[i] > b[i] ? a : b);
    return;
  }

  push(stack, b.size() > a.size() ? b : a);
}
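
// Illustrative sketch (not part of the original header): minList/maxList
// compare element-wise like Python's lexicographic list comparison, falling
// back to length when one list is a prefix of the other:
//
//   minList: {1, 2, 3} vs {1, 2, 4}  ->  {1, 2, 3}  (first differing element)
//   minList: {1, 2}    vs {1, 2, 0}  ->  {1, 2}     (shorter prefix wins)
//   maxList: {1, 2}    vs {1, 2, 0}  ->  {1, 2, 0}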

void listPopImpl(Stack& stack, const char* empty_message);

void listPop(Stack& stack);

void listClear(Stack& stack);

void listDelete(Stack& stack);

void listInsert(Stack& stack);

template <typename T>
void listRemove(Stack& stack) {
  T elem = pop(stack).to<T>();
  c10::List<T> list = pop(stack).to<c10::List<T>>();

  auto pos = std::find(list.begin(), list.end(), elem);

  if (pos != list.end()) {
    list.erase(pos);
  } else {
    AT_ERROR("list.remove(x): x not in list");
  }
}

template <typename T>
void listMin(Stack& stack) {
  c10::List<T> list = pop(stack).to<c10::List<T>>();
  size_t list_size = list.size();
  if (list_size == 0) {
    throw std::runtime_error("min() arg is an empty sequence");
  }

  T min_elem = list[0];
  for (const auto i : c10::irange(1, list_size)) {
    T elem = list[i];
    min_elem = elem < min_elem ? elem : min_elem;
  }

  stack.push_back(min_elem);
}

template <typename T>
void listMax(Stack& stack) {
  c10::List<T> list = pop(stack).to<c10::List<T>>();
  size_t list_size = list.size();
  if (list_size == 0) {
    throw std::runtime_error("max() arg is an empty sequence");
  }

  T max_elem = list[0];
  for (const auto i : c10::irange(1, list_size)) {
    T elem = list[i];
    max_elem = elem > max_elem ? elem : max_elem;
  }

  stack.push_back(max_elem);
}

template <>
void listRemove<at::Tensor>(Stack& stack);

template <typename T>
void listIndex(Stack& stack) {
  T elem = pop(stack).to<T>();
  c10::List<T> list = pop(stack).to<c10::List<T>>();

  auto pos = std::find(list.begin(), list.end(), elem);

  if (pos != list.end()) {
    push(stack, static_cast<int64_t>(std::distance(list.begin(), pos)));
  } else {
    AT_ERROR("'", elem, "' is not in list");
  }
}

template <>
void listIndex<at::Tensor>(Stack& stack);

template <typename T>
void listCount(Stack& stack) {
  T elem = pop(stack).to<T>();
  c10::List<T> list = pop(stack).to<c10::List<T>>();

  const int64_t count = std::count(list.begin(), list.end(), elem);
  push(stack, count);
}

template <>
void listCount<at::Tensor>(Stack& stack);

void listExtend(Stack& stack);

void listCopy(Stack& stack);

void listSelect(Stack& stack);

void listLen(Stack& stack);

template <typename T>
void listEq(Stack& stack) {
  c10::List<T> b = pop(stack).to<c10::List<T>>();
  c10::List<T> a = pop(stack).to<c10::List<T>>();
  push(stack, a == b);
}

template <typename T>
void listNe(Stack& stack) {
  c10::List<T> b = pop(stack).to<c10::List<T>>();
  c10::List<T> a = pop(stack).to<c10::List<T>>();
  push(stack, a != b);
}

inline bool tensor_list_equal(
    const c10::List<at::Tensor>& a,
    const c10::List<at::Tensor>& b) {
  if (a.size() != b.size()) {
    return false;
  }

  for (const auto i : c10::irange(a.size())) {
    const at::Tensor& a_element = a[i];
    const at::Tensor& b_element = b[i];
    // This preserves Python's semantics, which uses eq() to compare two
    // elements, then passes the result to bool().
    // see: https://docs.python.org/3.4/reference/datamodel.html#object.__ge__
    const auto cmp_result = a_element.eq(b_element);
    if (!at::native::is_nonzero(cmp_result)) {
      return false;
    }
  }

  return true;
}
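
// Illustrative sketch (not part of the original header): because the
// comparison goes through Tensor::eq and at::native::is_nonzero, two lists
// are equal only if every element-wise comparison holds, e.g. for lists of
// scalar tensors:
//
//   a = {tensor(1), tensor(2)}
//   b = {tensor(1), tensor(2)}   ->  tensor_list_equal(a, b) == true
//   c = {tensor(1), tensor(3)}   ->  tensor_list_equal(a, c) == false
//
// Multi-element tensors make is_nonzero throw, just as bool(tensor) raises
// in Python for tensors with more than one element.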

// Specialization for at::Tensor, since it doesn't define operator==
template <>
void listEq<at::Tensor>(Stack& stack);

// Specialization for at::Tensor, since it doesn't define operator==
template <>
void listNe<at::Tensor>(Stack& stack);

void listList(Stack& stack);

template <typename T>
void listContains(Stack& stack) {
  auto key = pop(stack).to<T>();
  auto list = pop(stack).to<c10::List<T>>();
  // NOLINTNEXTLINE(performance-implicit-conversion-in-loop)
  for (const T& item : list) {
    if (item == key) {
      push(stack, true);
      return;
    }
  }
  push(stack, false);
}

void listAdd(Stack& stack);

void listInplaceAdd(Stack& stack);

void listMulIntLeftInPlace(Stack& stack);

void listMulIntLeft(Stack& stack);

void listMulIntRight(Stack& stack);

void listSlice(Stack& stack);

template <typename T>
void listSort(Stack& stack) {
  bool reverse = pop(stack).toBool();
  c10::List<T> list = pop(stack).to<c10::List<T>>();
  std::sort(list.begin(), list.end(), [reverse](const T& a, const T& b) {
    // FBCode errors without this check - "strict weak ordering"
    // TODO: remove when possible, since it just slows down
    // sorting and doesn't do anything useful
    if (a == b) {
      return false;
    }
    return (a < b) != reverse;
  });
}
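
// Illustrative sketch (not part of the original header): the comparator's
// `(a < b) != reverse` flips the ordering when reverse is true without a
// second code path:
//
//   reverse == false:  (a < b) != false  ==  (a < b)   -> ascending
//   reverse == true:   (a < b) != true   ==  !(a < b)  -> descending
//
// The a == b early-out keeps the comparator a strict weak ordering, since
// !(a < b) alone would return true for equal elements in the reverse case.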

// Specialization for at::Tensor
template <>
void listSort<at::Tensor>(Stack& stack);

template <typename T>
void listCopyAndSort(Stack& stack) {
  c10::List<T> list = pop(stack).to<c10::List<T>>();
  auto list_copied = list.copy();
  std::sort(list_copied.begin(), list_copied.end(), [](const T& a, const T& b) {
    // "strict weak ordering" issue - see other sort
    if (a == b) {
      return false;
    }
    return a < b;
  });
  push(stack, list_copied);
}

// Specialization for at::Tensor
template <>
void listCopyAndSort<at::Tensor>(Stack& stack);

void listSetItem(Stack& stack);

struct OperatorGeneratorArgs {
  const char* schema_str;
  bool isOperationCreator;
  union {
    void (*operation)(Stack&);
    OperationCreator operationCreator;
  };
  AliasAnalysisKind aliasAnalysis;

  explicit constexpr OperatorGeneratorArgs(
      torch::detail::SelectiveStr<true> schema_str,
      void (*op)(Stack&),
      AliasAnalysisKind aa)
      : schema_str(schema_str),
        isOperationCreator(false),
        operation(op),
        aliasAnalysis(aa) {}

  explicit constexpr OperatorGeneratorArgs(
      torch::detail::SelectiveStr<true> schema_str,
      OperationCreator opCreator,
      AliasAnalysisKind aa)
      : schema_str(schema_str),
        isOperationCreator(true),
        operationCreator(opCreator),
        aliasAnalysis(aa) {}

  template <typename... Args>
  explicit constexpr OperatorGeneratorArgs(
      torch::detail::SelectiveStr<false>,
      Args...)
      : schema_str(nullptr),
        isOperationCreator(false),
        operation(nullptr),
        aliasAnalysis(AliasAnalysisKind::INTERNAL_SPECIAL_CASE) {}
};
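
// Illustrative sketch (not part of the original header): the
// SelectiveStr<true>/<false> constructor pair lets selective builds compile
// an operator out entirely. A hypothetical registration (op name invented
// for illustration) might look like:
//
//   OperatorGeneratorArgs(
//       TORCH_SELECTIVE_SCHEMA("aten::my_op.int(int a) -> int"),
//       [](Stack& stack) {
//         int64_t a;
//         pop(stack, a);
//         push(stack, a + 1);
//       },
//       aliasAnalysisFromSchema());
//
// When the schema is not selected, TORCH_SELECTIVE_SCHEMA yields a
// SelectiveStr<false>, and the variadic constructor produces an inert entry
// with a null schema_str.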

#define DEFINE_GENERIC_BINARY_OP(                                             \
    aten_op, op, int_float_result, complex_result)                            \
  OperatorGeneratorArgs(                                                      \
      TORCH_SELECTIVE_SCHEMA(#aten_op                                         \
                             ".int_int(int a, int b) -> " #int_float_result), \
      [](Stack& stack) {                                                      \
        int64_t a, b;                                                         \
        pop(stack, a, b);                                                     \
        push(stack, op);                                                      \
      },                                                                      \
      aliasAnalysisFromSchema()),                                             \
      OperatorGeneratorArgs(                                                  \
          TORCH_SELECTIVE_SCHEMA(                                             \
              #aten_op                                                        \
              ".float_float(float a, float b) -> " #int_float_result),        \
          [](Stack& stack) {                                                  \
            double a, b;                                                      \
            pop(stack, a, b);                                                 \
            push(stack, op);                                                  \
          },                                                                  \
          aliasAnalysisFromSchema()),                                         \
      OperatorGeneratorArgs(                                                  \
          TORCH_SELECTIVE_SCHEMA(                                             \
              #aten_op                                                        \
              ".complex_complex(complex a, complex b) -> " #complex_result),  \
          [](Stack& stack) {                                                  \
            c10::complex<double> a, b;                                        \
            pop(stack, a, b);                                                 \
            push(stack, op);                                                  \
          },                                                                  \
          aliasAnalysisFromSchema())

// define implementations for primitive number ops
#define DEFINE_GENERIC_OP(aten_op, int_op, float_op, int_result, float_result) \
  OperatorGeneratorArgs(                                                       \
      TORCH_SELECTIVE_SCHEMA(#aten_op ".int(int a, int b) -> " #int_result),   \
      [](Stack& stack) {                                                       \
        int64_t a, b;                                                          \
        pop(stack, a, b);                                                      \
        push(stack, int_op);                                                   \
      },                                                                       \
      aliasAnalysisFromSchema()),                                              \
      OperatorGeneratorArgs(                                                   \
          TORCH_SELECTIVE_SCHEMA(                                              \
              #aten_op ".float(float a, float b) -> " #float_result),          \
          [](Stack& stack) {                                                   \
            double a, b;                                                       \
            pop(stack, a, b);                                                  \
            push(stack, float_op);                                             \
          },                                                                   \
          aliasAnalysisFromSchema())

#define DEFINE_INT_FLOAT_OP(aten_op, op, result)                            \
  OperatorGeneratorArgs(                                                    \
      TORCH_SELECTIVE_SCHEMA(#aten_op                                       \
                             ".int_float(int a, float b) -> " #result),     \
      [](Stack& stack) {                                                    \
        int64_t a;                                                          \
        double b;                                                           \
        pop(stack, a, b);                                                   \
        push(stack, op);                                                    \
      },                                                                    \
      aliasAnalysisFromSchema()),                                           \
      OperatorGeneratorArgs(                                                \
          TORCH_SELECTIVE_SCHEMA(#aten_op                                   \
                                 ".float_int(float a, int b) -> " #result), \
          [](Stack& stack) {                                                \
            double a;                                                       \
            int64_t b;                                                      \
            pop(stack, a, b);                                               \
            push(stack, op);                                                \
          },                                                                \
          aliasAnalysisFromSchema())

#define DEFINE_INT_OP(aten_op, op)                                  \
  OperatorGeneratorArgs(                                            \
      TORCH_SELECTIVE_SCHEMA(#aten_op ".int(int a, int b) -> int"), \
      [](Stack& stack) {                                            \
        int64_t a, b;                                               \
        pop(stack, a, b);                                           \
        push(stack, op); /* NOLINT(hicpp-signed-bitwise) */         \
      },                                                            \
      aliasAnalysisFromSchema())

#define DEFINE_STR_CMP_OP(aten_op, op)                               \
  OperatorGeneratorArgs(                                             \
      TORCH_SELECTIVE_SCHEMA(#aten_op ".str(str a, str b) -> bool"), \
      [](Stack& stack) {                                             \
        auto b = pop(stack).toStringRef();                           \
        auto a = pop(stack).toStringRef();                           \
        push(stack, op);                                             \
      },                                                             \
      aliasAnalysisFromSchema())
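
// Illustrative sketch (not part of the original header): these macros are
// consumed in operator-registration tables, where hypothetical uses could
// look like:
//
//   DEFINE_INT_OP(aten::__and__, a & b)
//     -> registers "aten::__and__.int(int a, int b) -> int"
//   DEFINE_STR_CMP_OP(aten::lt, a < b)
//     -> registers "aten::lt.str(str a, str b) -> bool"
//
// The `op` argument is an expression over the locals `a` and `b` declared
// inside each generated lambda body.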

// define a primitive op over Scalar operands.
// It's necessary to register this overload following
// int/float variations to avoid trapping Scalar args
// in unintended implicit conversions.
#define DEFINE_SCALAR_BINARY_OP_AVOID_COLLISION_GENERIC(          \
    aten_op, int_op, float_op, result, string_val)                \
  OperatorGeneratorArgs(                                          \
      TORCH_SELECTIVE_SCHEMA(#aten_op string_val                  \
                             "(Scalar a, Scalar b) -> " #result), \
      [](Stack& stack) {                                          \
        IValue x, y;                                              \
        pop(stack, x, y);                                         \
        if (x.isDouble()) {                                       \
          if (y.isDouble()) {                                     \
            double a = x.toDouble();                              \
            double b = y.toDouble();                              \
            push(stack, float_op);                                \
          } else {                                                \
            double a = x.toDouble();                              \
            int64_t b = y.toInt();                                \
            push(stack, float_op);                                \
          }                                                       \
        } else {                                                  \
          if (y.isDouble()) {                                     \
            int64_t a = x.toInt();                                \
            double b = y.toDouble();                              \
            push(stack, float_op);                                \
          } else {                                                \
            int64_t a = x.toInt();                                \
            int64_t b = y.toInt();                                \
            push(stack, int_op);                                  \
          }                                                       \
        }                                                         \
      },                                                          \
      aliasAnalysisFromSchema())

#define DEFINE_SCALAR_BINARY_OP(aten_op, int_op, float_op, result) \
  DEFINE_SCALAR_BINARY_OP_AVOID_COLLISION_GENERIC(                 \
      aten_op, int_op, float_op, result, "")

#define DEFINE_SCALAR_BINARY_OP_AVOID_COLLISION(   \
    aten_op, int_op, float_op, result)             \
  DEFINE_SCALAR_BINARY_OP_AVOID_COLLISION_GENERIC( \
      aten_op, int_op, float_op, result, ".Scalar_Scalar")

#define DEFINE_BINARY_OP(aten_op, op)             \
  DEFINE_GENERIC_OP(aten_op, op, op, int, float), \
      DEFINE_INT_FLOAT_OP(aten_op, op, float),    \
      DEFINE_SCALAR_BINARY_OP(aten_op, op, op, Scalar)

#define DEFINE_BINARY_FLOAT_OP(aten_op, op)         \
  DEFINE_GENERIC_OP(aten_op, op, op, float, float), \
      DEFINE_INT_FLOAT_OP(aten_op, op, float),      \
      DEFINE_SCALAR_BINARY_OP(aten_op, op, op, float)

#define DEFINE_COMPARISON_OP(aten_op, op)             \
  DEFINE_GENERIC_OP(aten_op, op, op, bool, bool),     \
      DEFINE_INT_FLOAT_OP(aten_op, op, bool),         \
      DEFINE_SCALAR_BINARY_OP(aten_op, op, op, bool), \
      DEFINE_STR_CMP_OP(aten_op, op)
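
// Illustrative sketch (not part of the original header): the composite
// macros expand to several comma-separated OperatorGeneratorArgs entries, so
// a single hypothetical line such as
//
//   DEFINE_COMPARISON_OP(aten::lt, a < b)
//
// registers the .int, .float, .int_float, .float_int, Scalar, and .str
// overloads at once, with the Scalar overload listed after the int/float
// variants so that int/float arguments match before falling into Scalar's
// implicit conversions.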

#define DEFINE_UNARY_INT_OP(aten_op, op, result)                  \
  OperatorGeneratorArgs(                                          \
      TORCH_SELECTIVE_SCHEMA(#aten_op ".int(int a) -> " #result), \
      [](Stack& stack) {                                          \
        int64_t a;                                                \
        pop(stack, a);                                            \
        push(stack, op);                                          \
      },                                                          \
      aliasAnalysisFromSchema())

#define DEFINE_UNARY_FLOAT_OP(aten_op, op, result)                    \
  OperatorGeneratorArgs(                                              \
      TORCH_SELECTIVE_SCHEMA(#aten_op ".float(float a) -> " #result), \
      [](Stack& stack) {                                              \
        double a;                                                     \
        pop(stack, a);                                                \
        push(stack, op);                                              \
      },                                                              \
      aliasAnalysisFromSchema())

#define DEFINE_UNARY_OP(aten_op, op, int_result, float_result)            \
  DEFINE_UNARY_INT_OP(aten_op, op, int_result),                           \
      DEFINE_UNARY_FLOAT_OP(aten_op, op, float_result),                   \
      OperatorGeneratorArgs(                                              \
          TORCH_SELECTIVE_SCHEMA(#aten_op ".Scalar(Scalar a) -> Scalar"), \
          [](Stack& stack) {                                              \
            IValue x;                                                     \
            pop(stack, x);                                                \
            if (x.isDouble()) {                                           \
              double a = x.toDouble();                                    \
              push(stack, static_cast<float_result>(op));                 \
            } else {                                                      \
              int64_t a = x.toInt();                                      \
              push(stack, static_cast<int_result>(op));                   \
            }                                                             \
          },                                                              \
          aliasAnalysisFromSchema())
#define DEFINE_BOOL_OP(aten_op, op)                                     \
  OperatorGeneratorArgs(                                                \
      TORCH_SELECTIVE_SCHEMA(#aten_op ".bool(bool a, bool b) -> bool"), \
      [](Stack& stack) {                                                \
        bool a, b;                                                      \
        pop(stack, a, b);                                               \
        push(stack, op);                                                \
      },                                                                \
      aliasAnalysisFromSchema())
#define DEFINE_STRING_OP(op_name, string_op, result)                    \
  OperatorGeneratorArgs(                                                \
      TORCH_SELECTIVE_SCHEMA(#op_name ".str(str a, str b) -> " #result),   \
      [](Stack& stack) {                                                \
        auto b = pop(stack).toStringRef();                              \
        auto a = pop(stack).toStringRef();                              \
        push(stack, string_op);                                         \
      },                                                                \
      aliasAnalysisFromSchema())

//-----------------------------------------------------------------------------
// Complex number ops
//-----------------------------------------------------------------------------
#define DEFINE_UNARY_COMPLEX_OP(aten_op, op, result)                      \
  OperatorGeneratorArgs(                                                  \
      TORCH_SELECTIVE_SCHEMA(#aten_op ".complex(complex a) -> " #result), \
      [](Stack& stack) {                                                  \
        c10::complex<double> a;                                           \
        pop(stack, a);                                                    \
        push(stack, op);                                                  \
      },                                                                  \
      aliasAnalysisFromSchema())

// Some complex unary ops (like abs, angle) return real-valued output, while
// most other unary ops return complex-valued output. This macro is used in
// the former case, where we can explicitly pass the complex_result_cast
// argument; it is set to c10::complex<double> in the macro
// `DEFINE_UNARY_OP_WITH_COMPLEX` defined below.
#define DEFINE_UNARY_OP_WITH_COMPLEX_CAST(                                \
    aten_op,                                                              \
    op,                                                                   \
    int_result,                                                           \
    float_result,                                                         \
    complex_result,                                                       \
    complex_result_cast)                                                  \
  DEFINE_UNARY_INT_OP(aten_op, op, int_result),                           \
      DEFINE_UNARY_FLOAT_OP(aten_op, op, float_result),                   \
      DEFINE_UNARY_COMPLEX_OP(aten_op, op, complex_result),               \
      OperatorGeneratorArgs(                                              \
          TORCH_SELECTIVE_SCHEMA(#aten_op ".Scalar(Scalar a) -> Scalar"), \
          [](Stack& stack) {                                              \
            IValue x;                                                     \
            pop(stack, x);                                                \
            if (x.isDouble()) {                                           \
              double a = x.toDouble();                                    \
              push(stack, static_cast<float_result>(op));                 \
            } else if (x.isComplexDouble()) {                             \
              c10::complex<double> a = x.toComplexDouble();               \
              push(stack, static_cast<complex_result_cast>(op));          \
            } else {                                                      \
              int64_t a = x.toInt();                                      \
              push(stack, static_cast<int_result>(op));                   \
            }                                                             \
          },                                                              \
          aliasAnalysisFromSchema())

#define DEFINE_UNARY_OP_WITH_COMPLEX(aten_op, op, int_result, float_result) \
  DEFINE_UNARY_OP_WITH_COMPLEX_CAST(                                        \
      aten_op, op, int_result, float_result, complex, c10::complex<double>)
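
// Illustrative sketch (not part of the original header): an op that stays
// complex for complex input would use the wrapper macro, e.g. the
// hypothetical
//
//   DEFINE_UNARY_OP_WITH_COMPLEX(aten::exp, std::exp(a), float, float)
//
// which registers .int and .float overloads returning float, a .complex
// overload returning complex, and a Scalar overload that dispatches on the
// runtime tag of the IValue. Real-valued complex ops like abs would instead
// call DEFINE_UNARY_OP_WITH_COMPLEX_CAST directly with a real cast type.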

#define DEFINE_GENERIC_OP_WITH_COMPLEX(                                       \
    aten_op,                                                                  \
    int_op,                                                                   \
    float_op,                                                                 \
    complex_op,                                                               \
    int_result,                                                               \
    float_result,                                                             \
    complex_result)                                                           \
  OperatorGeneratorArgs(                                                      \
      TORCH_SELECTIVE_SCHEMA(#aten_op ".int(int a, int b) -> " #int_result),  \
      [](Stack& stack) {                                                      \
        int64_t a, b;                                                         \
        pop(stack, a, b);                                                     \
        push(stack, int_op);                                                  \
      },                                                                      \
      aliasAnalysisFromSchema()),                                             \
      OperatorGeneratorArgs(                                                  \
          TORCH_SELECTIVE_SCHEMA(                                             \
              #aten_op ".complex(complex a, complex b) -> " #complex_result), \
          [](Stack& stack) {                                                  \
            c10::complex<double> a, b;                                        \
            pop(stack, a, b);                                                 \
            push(stack, complex_op);                                          \
          },                                                                  \
          aliasAnalysisFromSchema()),                                         \
      OperatorGeneratorArgs(                                                  \
          TORCH_SELECTIVE_SCHEMA(                                             \
              #aten_op ".float(float a, float b) -> " #float_result),         \
          [](Stack& stack) {                                                  \
            double a, b;                                                      \
            pop(stack, a, b);                                                 \
            push(stack, float_op);                                            \
          },                                                                  \
          aliasAnalysisFromSchema())

#define DEFINE_INT_COMPLEX_OP(aten_op, op, result)                          \
  OperatorGeneratorArgs(                                                    \
      TORCH_SELECTIVE_SCHEMA(#aten_op                                       \
                             ".int_complex(int a, complex b) -> " #result), \
      [](Stack& stack) {                                                    \
        int64_t a;                                                          \
        c10::complex<double> b;                                             \
        pop(stack, a, b);                                                   \
        push(stack, op);                                                    \
      },                                                                    \
      aliasAnalysisFromSchema()),                                           \
      OperatorGeneratorArgs(                                                \
          TORCH_SELECTIVE_SCHEMA(                                           \
              #aten_op ".complex_int(complex a, int b) -> " #result),       \
          [](Stack& stack) {                                                \
            c10::complex<double> a;                                         \
            int64_t b;                                                      \
            pop(stack, a, b);                                               \
            push(stack, op);                                                \
          },                                                                \
          aliasAnalysisFromSchema())

#define DEFINE_FLOAT_COMPLEX_OP(aten_op, op, result)                      \
  OperatorGeneratorArgs(                                                  \
      TORCH_SELECTIVE_SCHEMA(                                             \
          #aten_op ".float_complex(float a, complex b) -> " #result),     \
      [](Stack& stack) {                                                  \
        double a;                                                         \
        c10::complex<double> b;                                           \
        pop(stack, a, b);                                                 \
        push(stack, op);                                                  \
      },                                                                  \
      aliasAnalysisFromSchema()),                                         \
      OperatorGeneratorArgs(                                              \
          TORCH_SELECTIVE_SCHEMA(                                         \
              #aten_op ".complex_float(complex a, float b) -> " #result), \
          [](Stack& stack) {                                              \
            c10::complex<double> a;                                       \
            double b;                                                     \
            pop(stack, a, b);                                             \
            push(stack, op);                                              \
          },                                                              \
          aliasAnalysisFromSchema())

#define DEFINE_SCALAR_BINARY_OP_WITH_COMPLEX_AVOID_COLLISION_GENERIC( \
    aten_op, int_op, float_op, complex_op, result, string_val)        \
  OperatorGeneratorArgs(                                              \
      TORCH_SELECTIVE_SCHEMA(#aten_op string_val                      \
                             "(Scalar a, Scalar b) -> " #result),     \
      [](Stack& stack) {                                              \
        IValue x, y;                                                  \
        pop(stack, x, y);                                             \
        if (x.isComplexDouble()) {                                    \
          c10::complex<double> a = x.toComplexDouble();               \
          if (y.isComplexDouble()) {                                  \
            c10::complex<double> b = y.toComplexDouble();             \
            push(stack, complex_op);                                  \
          } else if (y.isDouble()) {                                  \
            double b = y.toDouble();                                  \
            push(stack, complex_op);                                  \
          } else {                                                    \
            int64_t b = y.toInt();                                    \
            push(stack, complex_op);                                  \
          }                                                           \
        } else if (x.isDouble()) {                                    \
          double a = x.toDouble();                                    \
          if (y.isComplexDouble()) {                                  \
            c10::complex<double> b = y.toComplexDouble();             \
            push(stack, complex_op);                                  \
          } else if (y.isDouble()) {                                  \
            double b = y.toDouble();                                  \
            push(stack, float_op);                                    \
          } else {                                                    \
            int64_t b = y.toInt();                                    \
            push(stack, float_op);                                    \
          }                                                           \
        } else {                                                      \
          int64_t a = x.toInt();                                      \
          if (y.isComplexDouble()) {                                  \
            c10::complex<double> b = y.toComplexDouble();             \
            push(stack, complex_op);                                  \
          } else if (y.isDouble()) {                                  \
            double b = y.toDouble();                                  \
            push(stack, float_op);                                    \
          } else {                                                    \
            int64_t b = y.toInt();                                    \
            push(stack, int_op);                                      \
          }                                                           \
        }                                                             \
      },                                                              \
      aliasAnalysisFromSchema())

#define DEFINE_SCALAR_BINARY_OP_WITH_COMPLEX_WITHOUT_INT_COMPLEX_PAIR(     \
    aten_op, int_op, float_op, complex_op, result)                         \
  OperatorGeneratorArgs(                                                   \
      TORCH_SELECTIVE_SCHEMA(#aten_op "(Scalar a, Scalar b) -> " #result), \
      [](Stack& stack) {                                                   \
        IValue x, y;                                                       \
        pop(stack, x, y);                                                  \
        if (x.isComplexDouble()) {                                         \
          c10::complex<double> a = x.toComplexDouble();                    \
          if (y.isComplexDouble()) {                                       \
            c10::complex<double> b = y.toComplexDouble();                  \
            push(stack, complex_op);                                       \
          } else if (y.isDouble()) {                                       \
            double b = y.toDouble();                                       \
            push(stack, complex_op);                                       \
          }                                                                \
        } else if (x.isDouble()) {                                         \
          double a = x.toDouble();                                         \
          if (y.isComplexDouble()) {                                       \
            c10::complex<double> b = y.toComplexDouble();                  \
            push(stack, complex_op);                                       \
          } else if (y.isDouble()) {                                       \
            double b = y.toDouble();                                       \
            push(stack, float_op);                                         \
          } else {                                                         \
            int64_t b = y.toInt();                                         \
            push(stack, float_op);                                         \
          }                                                                \
        } else {                                                           \
          int64_t a = x.toInt();                                           \
          if (y.isDouble()) {                                              \
            double b = y.toDouble();                                       \
            push(stack, float_op);                                         \
          } else if (y.isInt()) {                                          \
            int64_t b = y.toInt();                                         \
            push(stack, int_op);                                           \
          }                                                                \
        }                                                                  \
      },                                                                   \
      aliasAnalysisFromSchema())

#define DEFINE_SCALAR_BINARY_OP_WITH_COMPLEX(                   \
    aten_op, int_op, float_op, complex_op, result)              \
  DEFINE_SCALAR_BINARY_OP_WITH_COMPLEX_AVOID_COLLISION_GENERIC( \
      aten_op, int_op, float_op, complex_op, result, "")

#define DEFINE_BINARY_OP_WITH_COMPLEX(aten_op, op)                          \
  DEFINE_GENERIC_OP_WITH_COMPLEX(aten_op, op, op, op, int, float, complex), \
      DEFINE_INT_COMPLEX_OP(aten_op, op, complex),                          \
      DEFINE_FLOAT_COMPLEX_OP(aten_op, op, complex),                        \
      DEFINE_INT_FLOAT_OP(aten_op, op, float),                              \
      DEFINE_SCALAR_BINARY_OP_WITH_COMPLEX(aten_op, op, op, op, Scalar)

#define DEFINE_COMPARISON_OP_WITH_COMPLEX(aten_op, op)                   \
  DEFINE_GENERIC_OP_WITH_COMPLEX(aten_op, op, op, op, bool, bool, bool), \
      DEFINE_INT_FLOAT_OP(aten_op, op, bool),                            \
      DEFINE_FLOAT_COMPLEX_OP(aten_op, op, bool),                        \
      DEFINE_SCALAR_BINARY_OP_WITH_COMPLEX_WITHOUT_INT_COMPLEX_PAIR(     \
          aten_op, op, op, op, bool),                                    \
      DEFINE_STR_CMP_OP(aten_op, op)

} // namespace jit
} // namespace torch