Why Gemfury? Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

edgify / torch   python

Repository URL to install this package:

Version: 2.0.1+cpu 

/ include / torch / csrc / jit / tensorexpr / ir.h

#pragma once

#include <string>
#include <utility>
#include <vector>

#include <c10/util/string_utils.h>
#include <torch/csrc/jit/tensorexpr/exceptions.h>
#include <torch/csrc/jit/tensorexpr/expr.h>
#include <torch/csrc/jit/tensorexpr/fwd_decls.h>
#include <torch/csrc/jit/tensorexpr/stmt.h>

#include <ATen/core/ivalue.h>

namespace torch {
namespace jit {
namespace tensorexpr {

// Comparison operator stored in a CompareSelect node (returned by
// CompareSelect::compare_select_op()). The enumerators use the usual
// abbreviations EQ/GT/GE/LT/LE/NE; kEQ is pinned to 0 and the rest follow
// consecutively. How each operator is evaluated is defined outside this
// header.
enum CompareSelectOperation {
  kEQ = 0,
  kGT,
  kGE,
  kLT,
  kLE,
  kNE,
};

// Bias hint carried by a CompareSelect node (returned by
// CompareSelect::bias()). kUnbiased is the default used by the CompareSelect
// factories and constructors in this header.
// NOTE(review): presumably a branch-likelihood hint for code generators; the
// consumers are not visible here — confirm at the use sites.
enum CompareSelectBias {
  kUnbiased,
  kLikely,
  kUnlikely,
};

// Returns the printing precedence of an IR node type.
//
// The numbers follow the C++ operator precedence table (smaller values bind
// tighter) because some expressions are pretty-printed as C++.
// SEE: https://en.cppreference.com/w/cpp/language/operator_precedence
// Node types without an explicit entry get the catch-all value 99.
inline int getPrecedence(IRNodeType ty) {
  int precedence = 99;
  switch (ty) {
    case kPrimitive:
      precedence = 0;
      break;
    case kCast:
    case kBitCast:
      precedence = 2;
      break;
    case kMul:
    case kDiv:
    case kMod:
      precedence = 5;
      break;
    case kAdd:
    case kSub:
      precedence = 6;
      break;
    case kLshift:
    case kRshift:
      precedence = 7;
      break;
    case kAnd:
      precedence = 11;
      break;
    case kXor:
      precedence = 12;
      break;
    case kOr:
      precedence = 13;
      break;
    case kCompareSelect:
      precedence = 16;
      break;
    case kMax:
    case kMin:
    default:
      // Keep the catch-all value assigned above.
      break;
  }
  return precedence;
}

// Expression node representing a cast of `src_value` to a target dtype.
class TORCH_API Cast : public ExprNode<Cast> {
 public:
  // Creates a node of type kCast whose own dtype is `dtype` and whose operand
  // is `src_value`.
  Cast(Dtype dtype, ExprPtr src_value)
      : ExprNodeBase(dtype, kCast), src_value_(std::move(src_value)) {}

  // Allocates a Cast over the node held by `src_value` and wraps it in a
  // handle.
  static ExprHandle make(Dtype dtype, const ExprHandle& src_value) {
    return ExprHandle(alloc<Cast>(dtype, src_value.node()));
  }

  // The operand being cast.
  ExprPtr src_value() const {
    return src_value_;
  }

  // Replaces the operand being cast.
  void set_src_value(ExprPtr src_value) {
    src_value_ = std::move(src_value);
  }

  // A cast is constant exactly when its operand is constant.
  bool isConstant() const override {
    return src_value_->isConstant();
  }

 private:
  ExprPtr src_value_;
};

// Builds a Cast of `src_value` to the dtype corresponding to `T`, keeping the
// lane count of the source expression.
template <typename T>
ExprHandle cast(const ExprHandle& src_value) {
  const int lane_count = src_value.dtype().lanes();
  return Cast::make(Dtype(ToDtype<T>(), lane_count), src_value);
}

// Bitwise cast node (comparable to LLVM's bitcast): the operand is
// reinterpreted as `dtype`, so both dtypes must have the same byte size.
class TORCH_API BitCast : public ExprNode<BitCast> {
 public:
  // Creates a node of type kBitCast. TORCH_CHECK fails when the byte size of
  // the operand's dtype differs from the byte size of `dtype`.
  BitCast(Dtype dtype, ExprPtr src_value)
      : ExprNodeBase(dtype, kBitCast), src_value_(std::move(src_value)) {
    TORCH_CHECK(src_value_->dtype().byte_size() == dtype.byte_size());
  }

  // Allocates a BitCast over the node held by `src_value` and wraps it in a
  // handle.
  static ExprHandle make(Dtype dtype, const ExprHandle& src_value) {
    return ExprHandle(alloc<BitCast>(dtype, src_value.node()));
  }

  // The operand being reinterpreted.
  ExprPtr src_value() const {
    return src_value_;
  }

  // Replaces the operand being reinterpreted.
  void set_src_value(ExprPtr src_value) {
    src_value_ = std::move(src_value);
  }

  // A bit cast is constant exactly when its operand is constant.
  bool isConstant() const override {
    return src_value_->isConstant();
  }

 private:
  ExprPtr src_value_;
};

// Builds a BitCast of `src_value` to the dtype corresponding to `T`, keeping
// the lane count of the source expression.
template <typename T>
ExprHandle bitcast(const ExprHandle& src_value) {
  const int lane_count = src_value.dtype().lanes();
  return BitCast::make(Dtype(ToDtype<T>(), lane_count), src_value);
}

// Common base for two-operand expression nodes.
// Uses CRTP: `Op` is the concrete node class, so the shared accessors, the
// factory and the operand-casting logic live here once.
template <typename Op>
class BinaryOpNode : public ExprNode<Op> {
 public:
  // The node's dtype is BinaryOpDtype(lhs dtype, rhs dtype, ret_type); each
  // operand whose dtype differs from it is wrapped in a Cast to that dtype.
  // The base class is initialized first, so the operand dtypes are read
  // before the operands are moved into the members.
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
  BinaryOpNode(
      ExprPtr lhs_v,
      ExprPtr rhs_v,
      IRNodeType expr_type,
      ScalarType ret_type = ScalarType::Undefined)
      : ExprNode<Op>(
            // NOLINTNEXTLINE(clang-analyzer-core.CallAndMessage)
            BinaryOpDtype(lhs_v->dtype(), rhs_v->dtype(), ret_type),
            expr_type),
        lhs_(CastIfNeeded(std::move(lhs_v), ExprNode<Op>::dtype())),
        rhs_(CastIfNeeded(std::move(rhs_v), ExprNode<Op>::dtype())) {}

  // Allocates an `Op` from the nodes held by the two handles.
  static ExprHandle make(const ExprHandle& lhs, const ExprHandle& rhs) {
    return ExprHandle(alloc<Op>(lhs.node(), rhs.node()));
  }

  // Left operand (possibly wrapped in a Cast by the constructor).
  ExprPtr lhs() const {
    return lhs_;
  }

  // Right operand (possibly wrapped in a Cast by the constructor).
  ExprPtr rhs() const {
    return rhs_;
  }

  // Replaces the left operand as-is; no cast is inserted here.
  void set_lhs(ExprPtr lhs) {
    lhs_ = std::move(lhs);
  }

  // Replaces the right operand as-is; no cast is inserted here.
  void set_rhs(ExprPtr rhs) {
    rhs_ = std::move(rhs);
  }

 private:
  // Returns `expr` unchanged when it already has `dst_dtype`; otherwise
  // returns a Cast of `expr` to `dst_dtype`.
  static ExprPtr CastIfNeeded(ExprPtr expr, Dtype dst_dtype) {
    const bool already_typed = (expr->dtype() == dst_dtype);
    if (!already_typed) {
      return Cast::make(dst_dtype, ExprHandle(std::move(expr))).node();
    }
    return expr;
  }

  ExprPtr lhs_;
  ExprPtr rhs_;
};

namespace detail {
// Overload pair that is only declared (never defined) in this header. The
// template overload returns void and takes a BinaryOpNode<T> by value; the
// variadic overload returns bool and accepts any other argument.
// NOTE(review): presumably meant for unevaluated contexts (decltype/SFINAE)
// that tell binary-op node types apart by the deduced return type — confirm
// at the use sites.
template <typename T>
void bin_op_deducer(BinaryOpNode<T>);
bool bin_op_deducer(...);
} // namespace detail

// Binary expression node tagged IRNodeType::kAdd.
class TORCH_API Add : public BinaryOpNode<Add> {
 public:
  // Forwards both operands to the shared binary-op base.
  Add(ExprPtr lhs, ExprPtr rhs)
      : BinaryOpNode<Add>(std::move(lhs), std::move(rhs), IRNodeType::kAdd) {}
};

// Binary expression node tagged IRNodeType::kSub.
class TORCH_API Sub : public BinaryOpNode<Sub> {
 public:
  // Forwards both operands to the shared binary-op base.
  Sub(ExprPtr lhs, ExprPtr rhs)
      : BinaryOpNode<Sub>(std::move(lhs), std::move(rhs), IRNodeType::kSub) {}
};

// Binary expression node tagged IRNodeType::kMul.
class TORCH_API Mul : public BinaryOpNode<Mul> {
 public:
  // Forwards both operands to the shared binary-op base.
  Mul(ExprPtr lhs, ExprPtr rhs)
      : BinaryOpNode<Mul>(std::move(lhs), std::move(rhs), IRNodeType::kMul) {}
};

// Binary expression node tagged IRNodeType::kDiv.
class TORCH_API Div : public BinaryOpNode<Div> {
 public:
  // Forwards both operands to the shared binary-op base.
  Div(ExprPtr lhs, ExprPtr rhs)
      : BinaryOpNode<Div>(std::move(lhs), std::move(rhs), IRNodeType::kDiv) {}
};

// Binary expression node tagged IRNodeType::kMod.
class TORCH_API Mod : public BinaryOpNode<Mod> {
 public:
  // Forwards both operands to the shared binary-op base.
  Mod(ExprPtr lhs, ExprPtr rhs)
      : BinaryOpNode<Mod>(std::move(lhs), std::move(rhs), IRNodeType::kMod) {}
};

// Base for the bitwise/shift nodes. It adds operand validation to the
// factory on top of BinaryOpNode.
template <typename Op>
class BitwiseOpNode : public BinaryOpNode<Op> {
 public:
  BitwiseOpNode(ExprPtr lhs, ExprPtr rhs, IRNodeType type)
      : BinaryOpNode<Op>(std::move(lhs), std::move(rhs), type) {}

  // Validated factory.
  // Throws unsupported_dtype when the lhs dtype is not integral, and
  // malformed_input when the two operand dtypes differ. The checks run in
  // that order.
  static ExprHandle make(const ExprHandle& lhs, const ExprHandle& rhs) {
    Dtype lhs_dtype = lhs.dtype();
    if (!lhs_dtype.is_integral()) {
      throw unsupported_dtype();
    }
    if (lhs_dtype != rhs.dtype()) {
      throw malformed_input("lhs/rhs dtype mismatch");
    }
    return BinaryOpNode<Op>::make(lhs, rhs);
  }
};

// Bitwise-op expression node tagged IRNodeType::kAnd.
class TORCH_API And : public BitwiseOpNode<And> {
 public:
  // Forwards both operands to the shared bitwise-op base.
  And(ExprPtr lhs, ExprPtr rhs)
      : BitwiseOpNode<And>(std::move(lhs), std::move(rhs), IRNodeType::kAnd) {}
};

// Bitwise-op expression node tagged IRNodeType::kOr.
class TORCH_API Or : public BitwiseOpNode<Or> {
 public:
  // Forwards both operands to the shared bitwise-op base.
  Or(ExprPtr lhs, ExprPtr rhs)
      : BitwiseOpNode<Or>(std::move(lhs), std::move(rhs), IRNodeType::kOr) {}
};

// Bitwise-op expression node tagged IRNodeType::kXor.
class TORCH_API Xor : public BitwiseOpNode<Xor> {
 public:
  // Forwards both operands to the shared bitwise-op base.
  Xor(ExprPtr lhs, ExprPtr rhs)
      : BitwiseOpNode<Xor>(std::move(lhs), std::move(rhs), IRNodeType::kXor) {}
};

// Bitwise-op expression node tagged IRNodeType::kLshift.
class TORCH_API Lshift : public BitwiseOpNode<Lshift> {
 public:
  // Forwards both operands to the shared bitwise-op base.
  Lshift(ExprPtr lhs, ExprPtr rhs)
      : BitwiseOpNode<Lshift>(
            std::move(lhs),
            std::move(rhs),
            IRNodeType::kLshift) {}
};

// Bitwise-op expression node tagged IRNodeType::kRshift.
class TORCH_API Rshift : public BitwiseOpNode<Rshift> {
 public:
  // Forwards both operands to the shared bitwise-op base.
  Rshift(ExprPtr lhs, ExprPtr rhs)
      : BitwiseOpNode<Rshift>(
            std::move(lhs),
            std::move(rhs),
            IRNodeType::kRshift) {}
};

// TODO: add TORCH_API
// Currently adding it results in a compilation error on Windows
//
// Binary expression node tagged IRNodeType::kMax that additionally records a
// `propagate_nans` flag. The two-argument factory inherited from
// BinaryOpNode is deleted so callers must always supply the flag.
class Max : public BinaryOpNode<Max> {
 public:
  Max(ExprPtr lhs, ExprPtr rhs, bool propagate_nans)
      : BinaryOpNode<Max>(std::move(lhs), std::move(rhs), IRNodeType::kMax),
        propagate_nans_(propagate_nans) {}

  static ExprHandle make(const ExprHandle& lhs, const ExprHandle& rhs) = delete;

  // Allocates a Max from the nodes held by the handles plus the NaN flag.
  static ExprHandle make(
      const ExprHandle& lhs,
      const ExprHandle& rhs,
      bool propagate_nans) {
    return ExprHandle(alloc<Max>(lhs.node(), rhs.node(), propagate_nans));
  }

  // The flag given at construction.
  bool propagate_nans() const {
    return propagate_nans_;
  }

 private:
  bool propagate_nans_;
};

// TODO: add TORCH_API
// Currently adding it results in a compilation error on Windows
//
// Binary expression node tagged IRNodeType::kMin that additionally records a
// `propagate_nans` flag. The two-argument factory inherited from
// BinaryOpNode is deleted so callers must always supply the flag.
class Min : public BinaryOpNode<Min> {
 public:
  Min(ExprPtr lhs, ExprPtr rhs, bool propagate_nans)
      : BinaryOpNode<Min>(std::move(lhs), std::move(rhs), IRNodeType::kMin),
        propagate_nans_(propagate_nans) {}

  static ExprHandle make(const ExprHandle& lhs, const ExprHandle& rhs) = delete;

  // Allocates a Min from the nodes held by the handles plus the NaN flag.
  static ExprHandle make(
      const ExprHandle& lhs,
      const ExprHandle& rhs,
      bool propagate_nans) {
    return ExprHandle(alloc<Min>(lhs.node(), rhs.node(), propagate_nans));
  }

  // The flag given at construction.
  bool propagate_nans() const {
    return propagate_nans_;
  }

 private:
  bool propagate_nans_;
};

// Encode typed immediate values e.g. IntImm, FloatImm.
//
// IMM_DECLARE(Type, Name) declares class Name##Imm, an expression node that
// stores one `Type` value. Each such node:
//   - is constructed with dtype k##Name and node type kPrimitive,
//   - always reports isConstant() == true,
//   - exposes the stored value via value(),
//   - can be allocated and wrapped in an ExprHandle via make(value).
// The macro is instantiated once per (Type, Name) pair supplied by
// AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, ...) and undefined
// immediately afterwards.
#define IMM_DECLARE(Type, Name)                               \
  class TORCH_API Name##Imm : public ExprNode<Name##Imm> {    \
   public:                                                    \
    Name##Imm(Type value)                                     \
        : ExprNodeBase(k##Name, kPrimitive), value_(value) {} \
    bool isConstant() const override {                        \
      return true;                                            \
    }                                                         \
    Type value() const {                                      \
      return value_;                                          \
    }                                                         \
    static ExprHandle make(Type value) {                      \
      return ExprHandle(alloc<Name##Imm>(value));             \
    }                                                         \
                                                              \
   private:                                                   \
    Type value_;                                              \
  };
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, IMM_DECLARE);
#undef IMM_DECLARE

// Creates the immediate node matching `immType`, holding `initialVal`
// converted to that scalar type's C++ type.
// Throws unsupported_dtype when `immType` is not one of the types covered by
// AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, ...).
template <typename T>
ExprPtr getImmediateByType(ScalarType immType, T initialVal) {
#define IMM_FOR_SCALAR_TYPE(Type, Name)        \
  case ScalarType::Name: {                     \
    return alloc<Name##Imm>(Type(initialVal)); \
  }
  switch (immType) {
    // NOLINTNEXTLINE(bugprone-branch-clone)
    AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, IMM_FOR_SCALAR_TYPE);
    default:
      throw unsupported_dtype();
  }
#undef IMM_FOR_SCALAR_TYPE
  // Not reachable: every switch path above returns or throws.
  return nullptr;
}

// Dtype convenience overload: dispatches on the dtype's scalar type.
template <typename T>
ExprPtr getImmediateByType(Dtype dtype, T initialVal) {
  const ScalarType scalar = dtype.scalar_type();
  return getImmediateByType<T>(scalar, initialVal);
}

// Creates an immediate holding `v` whose type is chosen from the dtype of
// `e`.
template <typename T>
ExprPtr immLike(const ExprPtr& e, T v) {
  Dtype like = e->dtype();
  return getImmediateByType<T>(like, v);
}

// Handle overload: unwraps the node and defers to the ExprPtr version.
template <typename T>
ExprPtr immLike(const ExprHandle& e, T v) {
  const ExprPtr node = e.node();
  return immLike(node, v);
}

// Returns the value of `e` widened to int64_t when `e` is an immediate of one
// of the integer types covered by AT_FORALL_INT_TYPES; otherwise returns
// c10::nullopt.
inline c10::optional<int64_t> intValue(const ExprPtr& e) {
#define RETURN_IF_INT_IMM(Type, Name)            \
  if (auto int_imm = to<Name##Imm>(e)) {         \
    return static_cast<int64_t>(int_imm->value()); \
  }
  AT_FORALL_INT_TYPES(RETURN_IF_INT_IMM);
#undef RETURN_IF_INT_IMM
  return c10::nullopt;
}

// Handle overload: unwraps the node and defers to the ExprPtr version.
inline c10::optional<int64_t> intValue(const ExprHandle& e) {
  const ExprPtr node = e.node();
  return intValue(node);
}

template <typename T>
T immediateAs(const ExprPtr& e) {
#define TYPE_CASE(Type, Name)                \
  if (Name##ImmPtr imm = to<Name##Imm>(e)) { \
    return imm->value();                     \
  }
  AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TYPE_CASE);
#undef TYPE_CASE
  throw unsupported_dtype();
  return 0;
}

template <typename T>
T immediateAs(const ExprHandle& e) {
  return immediateAs<T>(e.node());
}

template <typename T>
bool immediateEquals(const ExprPtr& e, T val) {
#define TYPE_CASE(Type, Name)                \
  if (Name##ImmPtr imm = to<Name##Imm>(e)) { \
    return imm->value() == val;              \
  }
  AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TYPE_CASE);
#undef TYPE_CASE
  throw unsupported_dtype();
  return false;
}

// Predicates on an expression node; declared here, defined out of line.
// NOTE(review): judging by the names they test the sign / zero-ness of an
// immediate's value; the behavior for non-immediate nodes is not visible in
// this header — confirm against the definitions.
TORCH_API bool immediateIsNegative(const ExprPtr& e);

TORCH_API bool immediateIsPositive(const ExprPtr& e);

TORCH_API bool immediateIsZero(const ExprPtr& e);

// Represents a ramp vector node:
//     [base, base + 1 * stride, ... , base + (lanes - 1) * stride]
class TORCH_API Ramp : public ExprNode<Ramp> {
 public:
  // The node's dtype is the dtype of `base` widened to `lanes` lanes. The
  // base class is initialized before `base` is moved into the member.
  Ramp(ExprPtr base, ExprPtr stride, int lanes)
      : ExprNodeBase(Dtype(base->dtype(), lanes)),
        base_(std::move(base)),
        stride_(std::move(stride)),
        lanes_(lanes) {}

  // Validated factory: throws malformed_input when `stride` and `base` do
  // not share the same dtype.
  static ExprHandle make(
      const ExprHandle& base,
      const ExprHandle& stride,
      int lanes) {
    const bool dtype_mismatch = (stride.dtype() != base.dtype());
    if (dtype_mismatch) {
      throw malformed_input("Bad stride in Ramp");
    }
    return ExprHandle(alloc<Ramp>(base.node(), stride.node(), lanes));
  }

  // First element of the ramp.
  ExprPtr base() const {
    return base_;
  }

  // Increment between consecutive elements.
  ExprPtr stride() const {
    return stride_;
  }

  // Number of elements in the ramp.
  int lanes() const {
    return lanes_;
  }

  void set_base(ExprPtr base) {
    base_ = std::move(base);
  }

  void set_stride(ExprPtr stride) {
    stride_ = std::move(stride);
  }

 private:
  ExprPtr base_;
  ExprPtr stride_;
  int lanes_;
};

// Expression node that refers to a buffer together with a list of index
// expressions. Constructors and factories are defined out of line.
class TORCH_API Load : public ExprNode<Load> {
 public:
  Load(Dtype dtype, BufPtr base_handle, std::vector<ExprPtr> indices);
  Load(BufPtr base_handle, const std::vector<ExprPtr>& indices);

  static ExprHandle make(
      Dtype dtype,
      const BufHandle& buf,
      const std::vector<ExprHandle>& indices);
  static ExprHandle make(
      const BufHandle& buf,
      const std::vector<ExprHandle>& indices);

  // The buffer this node refers to.
  BufPtr buf() const {
    return buf_;
  }

  // Shortcut for buf()->base_handle().
  VarPtr base_handle() const {
    return buf_->base_handle();
  }

  // Returns a copy of the index list.
  std::vector<ExprPtr> indices() const {
    return indices_;
  }

  // Returns the single index expression; TORCH_CHECK fails unless exactly
  // one index is stored.
  ExprPtr flat_index() const {
    TORCH_CHECK(indices_.size() == 1, "Indices haven't been flattened.");
    return indices_.front();
  }

  void set_buf(BufPtr buf) {
    buf_ = std::move(buf);
  }

  void set_indices(std::vector<ExprPtr> indices) {
    indices_ = std::move(indices);
  }

 private:
  BufPtr buf_;
  std::vector<ExprPtr> indices_;
};

// Expression node pairing a value expression with a lane count; its dtype is
// the value's dtype widened to `lanes` lanes.
class TORCH_API Broadcast : public ExprNode<Broadcast> {
 public:
  // The base class is initialized before `value` is moved into the member.
  Broadcast(ExprPtr value, int lanes)
      : ExprNodeBase(Dtype(value->dtype(), lanes)),
        value_(std::move(value)),
        lanes_(lanes) {}

  // Allocates a Broadcast over the node held by `value`.
  static ExprHandle make(const ExprHandle& value, int lanes) {
    return ExprHandle(alloc<Broadcast>(value.node(), lanes));
  }

  // The expression being broadcast.
  ExprPtr value() const {
    return value_;
  }

  // The lane count given at construction.
  int lanes() const {
    return lanes_;
  }

  void set_value(ExprPtr value) {
    value_ = std::move(value);
  }

 private:
  ExprPtr value_;
  int lanes_;
};

// Conditional expression node holding a condition and two branch
// expressions. The node's dtype is the dtype of the true branch.
class TORCH_API IfThenElse : public ExprNode<IfThenElse> {
 public:
  // The base class is initialized before `t` is moved into the member.
  IfThenElse(ExprPtr c, ExprPtr t, ExprPtr f)
      : ExprNodeBase(t->dtype()),
        condition_(std::move(c)),
        true_(std::move(t)),
        false_(std::move(f)) {}

  // Validated factory.
  // Throws unsupported_dtype when the condition is not integral or has more
  // than one lane, and malformed_input when the two branches have different
  // dtypes.
  static ExprHandle make(
      const ExprHandle& c,
      const ExprHandle& t,
      const ExprHandle& f) {
    if (!c.dtype().is_integral() || c.dtype().lanes() != 1) {
      throw unsupported_dtype();
    }
    if (t.dtype() != f.dtype()) {
      throw malformed_input("Bad dtype in IfThenElse");
    }
    return ExprHandle(alloc<IfThenElse>(c.node(), t.node(), f.node()));
  }

  ExprPtr condition() const {
    return condition_;
  }

  // Lazily evaluated only if condition is true
  ExprPtr true_value() const {
    return true_;
  }

  // Lazily evaluated only if condition is false
  ExprPtr false_value() const {
    return false_;
  }

  void set_condition(ExprPtr condition) {
    condition_ = std::move(condition);
  }

  void set_true_value(ExprPtr true_value) {
    true_ = std::move(true_value);
  }

  void set_false_value(ExprPtr false_value) {
    false_ = std::move(false_value);
  }

 private:
  ExprPtr condition_;
  ExprPtr true_;
  ExprPtr false_;
};

// Expression node that stores a comparison (`lhs` <op> `rhs`), two result
// expressions (`ret_val1`, `ret_val2`) and a bias hint. When no result
// expressions are supplied they default to the IntImm constants 1 and 0.
// NOTE(review): presumably evaluates to ret_val1 when the comparison holds
// and to ret_val2 otherwise — the evaluator is not in this header; confirm.
class TORCH_API CompareSelect : public ExprNode<CompareSelect> {
 public:
  // Full form. The node's dtype is the dtype of `ret_val1`, read by the base
  // class initializer before the argument is moved into the member.
  CompareSelect(
      ExprPtr lhs,
      ExprPtr rhs,
      ExprPtr ret_val1,
      ExprPtr ret_val2,
      CompareSelectOperation cmp_op,
      CompareSelectBias bias = kUnbiased)
      : ExprNodeBase(ret_val1->dtype()),
        lhs_(std::move(lhs)),
        rhs_(std::move(rhs)),
        ret_val1_(std::move(ret_val1)),
        ret_val2_(std::move(ret_val2)),
        compare_op_(cmp_op),
        bias_(bias) {}

  // Short form: dtype kInt, with result expressions IntImm(1) and IntImm(0).
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
  CompareSelect(
      ExprPtr lhs,
      ExprPtr rhs,
      CompareSelectOperation cmp_op,
      CompareSelectBias bias = kUnbiased)
      : ExprNodeBase(kInt),
        lhs_(std::move(lhs)),
        rhs_(std::move(rhs)),
        ret_val1_(alloc<IntImm>(1)),
        ret_val2_(alloc<IntImm>(0)),
        compare_op_(cmp_op),
        bias_(bias) {}

  // Factory without explicit result expressions (uses IntImm 1 and 0).
  // Throws malformed_input when `lhs` and `rhs` have different dtypes.
  static ExprHandle make(
      const ExprHandle& lhs,
      const ExprHandle& rhs,
      CompareSelectOperation cmp_op,
      CompareSelectBias bias = kUnbiased) {
    if (lhs.dtype() != rhs.dtype()) {
      throw malformed_input("bad dtype in CompareSelect");
    }
    ExprPtr one = IntImm::make(1).node();
    ExprPtr zero = IntImm::make(0).node();
    return ExprHandle(alloc<CompareSelect>(
        lhs.node(),
        rhs.node(),
        std::move(one),
        std::move(zero),
        cmp_op,
        bias));
  }

  // Factory with explicit result expressions.
  // Throws malformed_input when `lhs`/`rhs` have different dtypes or when
  // `ret_val1`/`ret_val2` have different dtypes.
  static ExprHandle make(
      const ExprHandle& lhs,
      const ExprHandle& rhs,
      const ExprHandle& ret_val1,
      const ExprHandle& ret_val2,
      CompareSelectOperation cmp_op,
      CompareSelectBias bias = kUnbiased) {
    if (lhs.dtype() != rhs.dtype()) {
      throw malformed_input("bad dtype in CompareSelect");
    }
    if (ret_val1.dtype() != ret_val2.dtype()) {
      throw malformed_input("bad dtype in CompareSelect");
    }
    return ExprHandle(alloc<CompareSelect>(
        lhs.node(),
        rhs.node(),
        ret_val1.node(),
        ret_val2.node(),
        cmp_op,
        bias));
  }

  // The comparison operator.
  CompareSelectOperation compare_select_op() const {
    return compare_op_;
  }

  // The bias hint given at construction.
  CompareSelectBias bias() const {
    return bias_;
  }

  ExprPtr lhs() const {
    return lhs_;
  }

  ExprPtr rhs() const {
    return rhs_;
  }

  ExprPtr ret_val1() const {
    return ret_val1_;
  }

  ExprPtr ret_val2() const {
    return ret_val2_;
  }

  void set_lhs(ExprPtr lhs) {
    lhs_ = std::move(lhs);
  }

  void set_rhs(ExprPtr rhs) {
    rhs_ = std::move(rhs);
  }

  void set_ret_val1(ExprPtr ret_val1) {
    ret_val1_ = std::move(ret_val1);
  }

  void set_ret_val2(ExprPtr ret_val2) {
    ret_val2_ = std::move(ret_val2);
  }

 private:
  ExprPtr lhs_;
  ExprPtr rhs_;
  ExprPtr ret_val1_;
  ExprPtr ret_val2_;
  CompareSelectOperation compare_op_;
  CompareSelectBias bias_;
};

// Identifies which intrinsic an Intrinsics node represents.
// Intrinsics::func_name() maps every enumerator except kMaxIntrinsicsOp to a
// lowercase function name (e.g. kSin -> "sin").
// NOTE(review): kMaxIntrinsicsOp is last and looks like an end-of-range
// sentinel/count rather than a real operation — confirm before relying on it.
enum IntrinsicsOp {
  kSin,
  kCos,
  kTan,
  kAsin,
  kAcos,
  kAtan,
  kAtan2,
  kSinh,
  kCosh,
  kTanh,
  kSigmoid,
  kExp,
  kExpm1,
  kAbs,
  kLog,
  kLog2,
  kLog10,
  kLog1p,
  kErf,
  kErfc,
  kSqrt,
  kRsqrt,
  kPow,
  kCeil,
  kFloor,
  kRound,
  kTrunc,
  kFmod,
  kRemainder,
  kLgamma,
  kFrac,
  kIsNan,
  kRand, // We need more discussions on this. Should we consider stateful?
  kMaxIntrinsicsOp,
};

// Expression node for an intrinsic call: an IntrinsicsOp plus its argument
// expressions. Every constructor throws malformed_input when the number of
// supplied arguments differs from OpArgCount(op_type).
class TORCH_API Intrinsics : public ExprNode<Intrinsics> {
 public:
  // Zero-argument form; the dtype is derived from the given `dtype`.
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
  Intrinsics(IntrinsicsOp op_type, Dtype dtype)
      : ExprNodeBase(IntrinsicsDtype(op_type, dtype)),
        params_({}),
        op_type_(op_type) {
    CheckArgCount(op_type, 0);
  }

  // One-argument form; the dtype is derived from the argument's dtype, which
  // is read before the argument is moved into the parameter list.
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
  Intrinsics(IntrinsicsOp op_type, ExprPtr v1)
      : ExprNodeBase(IntrinsicsDtype(op_type, v1->dtype())),
        params_({std::move(v1)}),
        op_type_(op_type) {
    CheckArgCount(op_type, 1);
  }

  // Two-argument form; the dtype is derived from both argument dtypes.
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
  Intrinsics(IntrinsicsOp op_type, ExprPtr v1, ExprPtr v2)
      : ExprNodeBase(IntrinsicsDtype(op_type, v1->dtype(), v2->dtype())),
        params_({std::move(v1), std::move(v2)}),
        op_type_(op_type) {
    CheckArgCount(op_type, 2);
  }

  // Argument-list form; the dtype is derived from the argument list.
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
  Intrinsics(IntrinsicsOp op_type, const std::vector<ExprPtr>& params)
      : ExprNodeBase(IntrinsicsDtype(op_type, params)),
        params_(params),
        op_type_(op_type) {
    CheckArgCount(op_type, nparams());
  }

  // Argument-list form with an explicitly supplied dtype.
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
  Intrinsics(
      IntrinsicsOp op_type,
      Dtype dtype,
      const std::vector<ExprPtr>& params)
      : ExprNodeBase(IntrinsicsDtype(op_type, dtype)),
        params_(params),
        op_type_(op_type) {
    CheckArgCount(op_type, nparams());
  }

  // Factory for the one-argument form.
  static ExprHandle make(IntrinsicsOp op_type, const ExprHandle& v1) {
    return ExprHandle(alloc<Intrinsics>(op_type, v1.node()));
  }

  // Factory for the two-argument form.
  static ExprHandle make(
      IntrinsicsOp op_type,
      const ExprHandle& v1,
      const ExprHandle& v2) {
    return ExprHandle(alloc<Intrinsics>(op_type, v1.node(), v2.node()));
  }

  // Factory for the argument-list form: unwraps every handle to its node.
  static ExprHandle make(
      IntrinsicsOp op_type,
      const std::vector<ExprHandle>& params) {
    std::vector<ExprPtr> nodes;
    nodes.reserve(params.size());
    for (const ExprHandle& handle : params) {
      nodes.push_back(handle.node());
    }
    return ExprHandle(alloc<Intrinsics>(op_type, nodes));
  }

  // Factory for the zero-argument form.
  static ExprHandle make(IntrinsicsOp op_type, Dtype dtype) {
    return ExprHandle(alloc<Intrinsics>(op_type, dtype));
  }

  // Number of arguments `op_type` expects; defined out of line.
  static int OpArgCount(IntrinsicsOp op_type);

  IntrinsicsOp op_type() const {
    return op_type_;
  }

  // Returns the lowercase function name of this node's op.
  // Throws std::runtime_error for any value without a name (including
  // kMaxIntrinsicsOp).
  std::string func_name() const {
    const char* name = nullptr;
    switch (op_type()) {
      case kSin:
        name = "sin";
        break;
      case kCos:
        name = "cos";
        break;
      case kTan:
        name = "tan";
        break;
      case kAsin:
        name = "asin";
        break;
      case kAcos:
        name = "acos";
        break;
      case kAtan:
        name = "atan";
        break;
      case kAtan2:
        name = "atan2";
        break;
      case kSinh:
        name = "sinh";
        break;
      case kCosh:
        name = "cosh";
        break;
      case kTanh:
        name = "tanh";
        break;
      case kSigmoid:
        name = "sigmoid";
        break;
      case kExp:
        name = "exp";
        break;
      case kExpm1:
        name = "expm1";
        break;
      case kAbs:
        name = "abs";
        break;
      case kLog:
        name = "log";
        break;
      case kLog2:
        name = "log2";
        break;
      case kLog10:
        name = "log10";
        break;
      case kLog1p:
        name = "log1p";
        break;
      case kErf:
        name = "erf";
        break;
      case kErfc:
        name = "erfc";
        break;
      case kSqrt:
        name = "sqrt";
        break;
      case kRsqrt:
        name = "rsqrt";
        break;
      case kPow:
        name = "pow";
        break;
      case kCeil:
        name = "ceil";
        break;
      case kFloor:
        name = "floor";
        break;
      case kRound:
        name = "round";
        break;
      case kTrunc:
        name = "trunc";
        break;
      case kFmod:
        name = "fmod";
        break;
      case kRemainder:
        name = "remainder";
        break;
      case kLgamma:
        name = "lgamma";
        break;
      case kFrac:
        name = "frac";
        break;
      case kIsNan:
        name = "isnan";
        break;
      case kRand:
        name = "rand";
        break;
      default:
        throw std::runtime_error(
            "invalid op_type: " + c10::to_string(op_type()));
    }
    return name;
  }

  // Every op except kRand is reported as pure.
  bool isPure() const {
    return op_type_ != kRand;
  }

  // Number of stored arguments.
  int nparams() const {
    return static_cast<int>(params_.size());
  }

  // Argument at `index`; the index is not range-checked.
  ExprPtr param(int index) const {
    return params_[index];
  }

  const std::vector<ExprPtr>& params() const {
    return params_;
  }

  // Replaces the argument list; the argument count is not re-validated.
  void set_params(std::vector<ExprPtr> params) {
    params_ = std::move(params);
  }

 private:
  // Throws malformed_input unless `op_type` expects exactly `actual`
  // arguments.
  static void CheckArgCount(IntrinsicsOp op_type, int actual) {
    if (OpArgCount(op_type) != actual) {
      throw malformed_input("bad arg count in Intrinsics");
    }
  }

  // Result-dtype helpers; defined out of line.
  static Dtype IntrinsicsDtype(IntrinsicsOp op_type, Dtype dt1);
  static Dtype IntrinsicsDtype(IntrinsicsOp op_type, Dtype dt1, Dtype dt2);
  static Dtype IntrinsicsDtype(
      IntrinsicsOp op_type,
      const std::vector<ExprPtr>& params);

  std::vector<ExprPtr> params_;
  IntrinsicsOp op_type_;
};

// Conversions between vectors of handles and vectors of node pointers, for
// both expressions and variables. Declared here, defined out of line.
// NOTE(review): presumably element-wise, order-preserving conversions —
// confirm against the definitions.
TORCH_API std::vector<ExprPtr> ExprHandleVectorToExprVector(
    const std::vector<ExprHandle>&);
TORCH_API std::vector<ExprHandle> ExprVectorToExprHandleVector(
    const std::vector<ExprPtr>&);
TORCH_API std::vector<VarPtr> VarHandleVectorToVarVector(
    const std::vector<VarHandle>&);
TORCH_API std::vector<VarHandle> VarVectorToVarHandleVector(
    const std::vector<VarPtr>&);
// Declared here, defined out of line.
// NOTE(review): presumably combines `indices` with `strides` (and `dims`)
// into a single flat index expression; the exact formula and any size
// validation are not visible in this header — confirm.
TORCH_API ExprPtr flatten_index(
    const std::vector<ExprPtr>& dims,
    const std::vector<ExprPtr>& indices,
    const std::vector<ExprPtr>& strides);

} // namespace tensorexpr
} // namespace jit
} // namespace torch