// torch/csrc/jit/tensorexpr/expr.h

/**
 * This file implements the core classes for Tensor Expressions.
 *
 * The structure of the expressions is inspired by Halide/TVM IR.
 */
#pragma once

#include <c10/core/MemoryFormat.h>
#include <c10/util/Optional.h>
#include <torch/csrc/jit/tensorexpr/fwd_decls.h>
#include <torch/csrc/jit/tensorexpr/ir_mutator.h>
#include <torch/csrc/jit/tensorexpr/ir_visitor.h>
#include <torch/csrc/jit/tensorexpr/types.h>

#include <utility>

namespace torch {
namespace jit {
namespace tensorexpr {

enum IRNodeType {
  kPrimitive,
  kAdd,
  kSub,
  kMul,
  kDiv,
  kMod,
  kMax,
  kMin,
  kAnd,
  kOr,
  kLshift,
  kRshift,
  kXor,
  kCompareSelect,
  kCast,
  kBitCast,
  kOther,
};

// The common base class for all expression nodes.
class TORCH_API Expr : public std::enable_shared_from_this<Expr> {
 public:
  explicit Expr(Dtype dtype, IRNodeType expr_type = kOther)
      : dtype_(dtype), expr_type_(expr_type) {}
  virtual ~Expr() = default;
  Dtype dtype() const {
    return dtype_;
  }
  virtual void accept(IRVisitor* visitor) = 0;
  virtual ExprPtr accept_mutator(IRMutator* mutator) = 0;

  IRNodeType expr_type() const {
    return expr_type_;
  }
  // Returns whether this is a fixed (constant) immediate value.
  virtual bool isConstant() const {
    return false;
  }

  void set_dtype(Dtype dtype) {
    dtype_ = dtype;
  }

  /*
   * Make a deep copy of the given expression.
   *
   * All sub-expressions inside the given expression are also cloned. Note
   * that variables are not deep-copied, since they are immutable.
   */
  static ExprPtr clone(ExprPtr s);

 protected:
  std::shared_ptr<Expr> getptr() {
    return shared_from_this();
  }

 private:
  Dtype dtype_;
  IRNodeType expr_type_;
};

// A CRTP pattern to accept visitors for child classes and dispatch back to
// the concrete child type.
template <class Op, class Base = Expr>
class ExprNode : public Base {
 public:
  using ExprNodeBase = ExprNode<Op>;
  void accept(IRVisitor* visitor) override {
    visitor->visit(static_to<Op>(Base::getptr()));
  }
  ExprPtr accept_mutator(IRMutator* mutator) override;
  // Pass the constructors through to the base class.
  using Base::Base;
};
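
// For illustration: a concrete IR node such as Var below derives from
// ExprNode<Op> (e.g. ExprNode<Var>), which supplies the accept() and
// accept_mutator() overrides that dispatch to IRVisitor::visit and
// IRMutator::mutate with the concrete NodePtr<Op>.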

// A wrapper object around the underlying ExprNode.
// Also serves as the primary way to build and operate on expressions.
class TORCH_API ExprHandle {
 public:
  ExprHandle() = default;
  explicit ExprHandle(ExprPtr node) : base_expr_node_(std::move(node)) {}

  ExprPtr node() {
    return base_expr_node_;
  }

  ExprPtr node() const {
    return base_expr_node_;
  }

  bool empty() const {
    return base_expr_node_ == nullptr;
  }

#define IMM_EXPR_DECLARE(Type, Name) ExprHandle(Type v);
  AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, IMM_EXPR_DECLARE);
#undef IMM_EXPR_DECLARE

  template <class Op>
  NodePtr<Op> AsNode() {
    return to<Op>(this->node());
  }

  template <class Op>
  NodePtr<Op> AsNode() const {
    // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
    return const_cast<ExprHandle*>(this)->AsNode<Op>();
  }

  Dtype dtype() const {
    return node()->dtype();
  }

  // Handling the math operators.
  ExprHandle operator+(const ExprHandle& other) const;
  ExprHandle operator-(const ExprHandle& other) const;
  ExprHandle operator*(const ExprHandle& other) const;
  ExprHandle operator/(const ExprHandle& other) const;
  ExprHandle operator%(const ExprHandle& other) const;
  ExprHandle operator==(const ExprHandle& other) const;
  ExprHandle operator!=(const ExprHandle& other) const;
  ExprHandle operator>(const ExprHandle& other) const;
  ExprHandle operator>=(const ExprHandle& other) const;
  ExprHandle operator<(const ExprHandle& other) const;
  ExprHandle operator<=(const ExprHandle& other) const;
  ExprHandle operator&(const ExprHandle& other) const;
  ExprHandle operator|(const ExprHandle& other) const;
  ExprHandle operator&&(const ExprHandle& other) const;
  ExprHandle operator||(const ExprHandle& other) const;
  ExprHandle operator^(const ExprHandle& other) const;
  ExprHandle operator<<(const ExprHandle& other) const;
  ExprHandle operator>>(const ExprHandle& other) const;

 private:
  ExprPtr base_expr_node_ = nullptr;
};
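
// Example usage (an illustrative sketch; the variable names below are
// hypothetical):
//
//   VarHandle x("x", kFloat);
//   VarHandle y("y", kFloat);
//   ExprHandle z = x * y + ExprHandle(1.0f);
//   ExprHandle c = (z > ExprHandle(0.0f)) && (x < y);
//
// Each operator builds a new IR node and returns a handle to it; note that the
// comparison operators return expression handles, not C++ bools.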

// The underlying representation node for a Var.
// Currently, each Var object represents a unique variable, even if the names
// are the same. We should consider adding a unique_name as well.
class TORCH_API Var : public ExprNode<Var> {
 public:
  static ExprHandle make(const std::string& name_hint, Dtype dtype) {
    return ExprHandle(alloc<Var>(name_hint, dtype));
  }
  static ExprHandle make(Dtype dtype) {
    return ExprHandle(alloc<Var>("", dtype));
  }

  // TODO: unique_name
  const std::string& name_hint() const {
    return name_hint_;
  }

  void set_name_hint(const std::string& name) {
    name_hint_ = name;
  }

  void set_name_hint(std::string&& name) {
    name_hint_ = std::move(name);
  }

  Var(std::string name_hint, Dtype dtype)
      : ExprNodeBase(dtype, kPrimitive), name_hint_(std::move(name_hint)) {}

 private:
  std::string name_hint_;
};
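
// Example (illustrative; the names below are hypothetical):
//
//   ExprHandle i = Var::make("i", kInt);   // a named integer variable
//   ExprHandle t = Var::make(kFloat);      // an unnamed variable
//
// VarHandle (declared below) wraps the same factory with a typed handle.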

TORCH_API std::vector<ExprPtr> make_contiguous_strides(
    const std::vector<ExprHandle>& dims);
TORCH_API std::vector<ExprPtr> make_channels_last_strides(
    const std::vector<ExprHandle>& dims);
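
// For example (illustrative), given 4-D dims {N, C, H, W},
// make_contiguous_strides yields {C*H*W, H*W, W, 1} and
// make_channels_last_strides yields {H*W*C, 1, W*C, C}.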

class TORCH_API Buf : public ExprNode<Buf> {
 public:
  static BufHandle make(const std::vector<ExprHandle>& dims, Dtype dtype);

  static BufHandle make(
      const std::string& name_hint,
      const std::vector<ExprHandle>& dims,
      const std::vector<ExprHandle>& strides,
      Dtype dtype);

  static BufHandle make(
      const std::string& name_hint,
      const std::vector<ExprHandle>& dims,
      Dtype dtype,
      c10::optional<ExprHandle> initializer = c10::nullopt,
      c10::optional<std::vector<ExprHandle>> strides = c10::nullopt,
      c10::optional<ExprHandle> qscale = c10::nullopt,
      c10::optional<ExprHandle> qzero = c10::nullopt);

  // TODO: unique_name
  VarPtr base_handle() const {
    return base_handle_;
  }
  void set_base_handle(VarPtr base_handle) {
    base_handle_ = std::move(base_handle);
  }

  const std::string& name_hint() const {
    return base_handle_->name_hint();
  }
  void set_name_hint(const std::string& name_hint) {
    base_handle_->set_name_hint(name_hint);
  }

  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
  Buf(const std::string& name_hint,
      const std::vector<ExprPtr>& dims,
      Dtype dtype,
      ExprPtr initializer = nullptr,
      c10::optional<std::vector<ExprPtr>> strides = c10::nullopt,
      ExprPtr qscale = nullptr,
      ExprPtr qzero = nullptr)
      : Buf(alloc<Var>(name_hint, kHandle),
            dims,
            dtype,
            std::move(initializer),
            std::move(strides),
            std::move(qscale),
            std::move(qzero)) {}

  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
  Buf(VarPtr var,
      std::vector<ExprPtr> dims,
      Dtype dtype,
      ExprPtr initializer = nullptr,
      c10::optional<std::vector<ExprPtr>> strides = c10::nullopt,
      ExprPtr qscale = nullptr,
      ExprPtr qzero = nullptr);

  size_t ndim() const {
    return dims_.size();
  }
  ExprPtr dim(size_t index) const {
    if (index >= ndim()) {
      throw out_of_range_index();
    }
    return dims_[index];
  }
  std::vector<ExprPtr> dims() const {
    return dims_;
  }
  void set_dims(std::vector<ExprPtr> dims) {
    dims_ = std::move(dims);
  }

  std::vector<ExprPtr> strides() const {
    return strides_;
  }

  void set_strides(std::vector<ExprPtr> strides) {
    strides_ = std::move(strides);
  }

  ExprPtr initializer() const {
    return initializer_;
  }

  ExprPtr qzero() const {
    return qzero_;
  }

  ExprPtr qscale() const {
    return qscale_;
  }

  void set_qzero(ExprPtr qzero) {
    qzero_ = std::move(qzero);
  }

  void set_qscale(ExprPtr qscale) {
    qscale_ = std::move(qscale);
  }

  bool hasConstantDims() const {
    for (const auto& d : dims_) {
      if (!d->isConstant()) {
        return false;
      }
    }
    return true;
  }

  bool is_contiguous(
      at::MemoryFormat memory_format = at::MemoryFormat::Contiguous) const;

  // The channels-last 1d layout can benefit the performance of some operators,
  // such as conv1d, but the MemoryFormat enum does not cover this layout yet.
  // Hence, we provide a dedicated function to check channels-last 1d
  // contiguity.
  //
  // Channels-last 1d:
  //   dims:              n   c    l
  //   strides(nlc):    c*l   1    c
  bool is_channels_last_1d_contiguous() const {
    if (dims_.size() != 3) {
      return false;
    }
    return is_stride_one(1) && is_cont_with(2, 1) && is_cont_with(0, 2);
  }

 private:
  bool is_cont_with(int cur_dim, int adjacent_dim) const;
  bool is_stride_one(int cur_dim) const;

  VarPtr base_handle_;
  std::vector<ExprPtr> dims_;
  std::vector<ExprPtr> strides_;
  ExprPtr initializer_;
  // qscale_ and qzero_ are used only for Bufs with quantized dtypes: kQUInt8,
  // kQInt8
  ExprPtr qscale_;
  ExprPtr qzero_;
};
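
// Example (an illustrative sketch; the buffer name and sizes are
// hypothetical):
//
//   BufHandle a("A", {ExprHandle(64), ExprHandle(32)}, kFloat);
//
// This creates a 2-D float buffer named "A"; when no strides are supplied,
// they are expected to default to a contiguous layout (cf.
// make_contiguous_strides above).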

class TORCH_API BufHandle : public ExprHandle {
 public:
  BufHandle(
      const std::string& name_hint,
      const std::vector<ExprHandle>& dims,
      Dtype dtype)
      : ExprHandle(Buf::make(name_hint, dims, dtype)) {}

  BufHandle(
      const std::string& name_hint,
      const std::vector<ExprHandle>& dims,
      const std::vector<ExprHandle>& strides,
      Dtype dtype)
      : ExprHandle(Buf::make(name_hint, dims, strides, dtype)) {}

  BufHandle(const std::vector<ExprHandle>& dims, Dtype dtype)
      : ExprHandle(Buf::make("_", dims, dtype)) {}

  explicit BufHandle(Dtype dtype) : ExprHandle(Buf::make("_", {}, dtype)) {}

  explicit BufHandle(BufPtr node) : ExprHandle(std::move(node)) {}
  BufPtr node() const {
    return static_to<Buf>(ExprHandle::node());
  }
  BufPtr node() {
    return static_to<Buf>(ExprHandle::node());
  }

  template <typename... Ts>
  inline ExprHandle load(const Ts&... ts) const;

  template <typename T>
  inline ExprHandle load(const std::vector<T>& args) const;

  inline ExprHandle load(const std::vector<ExprHandle>& args) const;

  StorePtr store(const std::vector<ExprHandle>& args, const ExprHandle& val)
      const;

  bool operator==(const BufHandle& other) const {
    return this->node() == other.node();
  }
  bool operator!=(const BufHandle& other) const {
    return !(*this == other);
  }

  const std::string& name_hint() const {
    return this->node()->name_hint();
  }

  bool empty() const {
    return (this->node() == nullptr);
  }

  size_t ndim() const {
    return node()->ndim();
  }

  std::vector<ExprHandle> dims() const;

  ExprHandle dim(size_t index) const {
    return ExprHandle(node()->dim(index));
  }

  bool is_contiguous(
      at::MemoryFormat memory_format = at::MemoryFormat::Contiguous) const {
    return node()->is_contiguous(memory_format);
  }

  bool is_channels_last_1d_contiguous() const {
    return node()->is_channels_last_1d_contiguous();
  }
};
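
// Example of load/store on a buffer handle (an illustrative sketch; the names
// are hypothetical, and load/store are defined outside this header):
//
//   BufHandle a("A", {ExprHandle(64)}, kFloat);
//   VarHandle i("i", kInt);
//   ExprHandle v = a.load(i);                         // reads A[i]
//   StorePtr s = a.store({i}, v * ExprHandle(2.0f));  // stores 2 * A[i] back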

// An expression handle that wraps the underlying variable node.
// Note: do not store any extra info here, since this object can easily be
// sliced. For example: VarHandle x("x", kInt); ExprHandle x2 = x;
class TORCH_API VarHandle : public ExprHandle {
 public:
  // Creates an empty VarHandle whose base Var is set to nullptr.
  VarHandle() : ExprHandle() {}

  explicit VarHandle(Dtype dtype) : ExprHandle(Var::make(dtype)) {}

  VarHandle(const std::string& name_hint, Dtype dtype)
      : ExprHandle(Var::make(name_hint, dtype)) {}

  explicit VarHandle(VarPtr node) : ExprHandle(std::move(node)) {}

  VarPtr node() const {
    return static_to<Var>(ExprHandle::node());
  }
  bool operator==(const VarHandle& other) const {
    return this->node() == other.node();
  }
  bool operator!=(const VarHandle& other) const {
    return !(*this == other);
  }

  const std::string& name_hint() const {
    return this->node()->name_hint();
  }
  bool empty() const {
    return (this->node() == nullptr);
  }
};

template <class Op, class Base>
ExprPtr ExprNode<Op, Base>::accept_mutator(IRMutator* mutator) {
  return mutator->mutate(static_to<Op>(Base::getptr()));
}

inline bool same_node(const ExprHandle& expr1, const ExprHandle& expr2) {
  return expr1.AsNode<Expr>() == expr2.AsNode<Expr>();
}
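
// Note: same_node compares the underlying node pointers (shared_ptr equality),
// not structural equality; two separately constructed but structurally
// identical expressions are not considered the same node.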

TORCH_API ExprHandle sin(const ExprHandle& v);
TORCH_API ExprHandle cos(const ExprHandle& v);
TORCH_API ExprHandle tan(const ExprHandle& v);
TORCH_API ExprHandle asin(const ExprHandle& v);
TORCH_API ExprHandle acos(const ExprHandle& v);
TORCH_API ExprHandle atan(const ExprHandle& v);
TORCH_API ExprHandle sinh(const ExprHandle& v);
TORCH_API ExprHandle cosh(const ExprHandle& v);
TORCH_API ExprHandle tanh(const ExprHandle& v);
TORCH_API ExprHandle sigmoid(const ExprHandle& v);
TORCH_API ExprHandle exp(const ExprHandle& v);
TORCH_API ExprHandle expm1(const ExprHandle& v);
TORCH_API ExprHandle abs(const ExprHandle& v);
TORCH_API ExprHandle log(const ExprHandle& v);
TORCH_API ExprHandle fast_tanh(const ExprHandle& v);
TORCH_API ExprHandle fast_sigmoid(const ExprHandle& v);
TORCH_API ExprHandle fast_log(const ExprHandle& v);
TORCH_API ExprHandle log_vml(const ExprHandle& v);
TORCH_API ExprHandle log2(const ExprHandle& v);
TORCH_API ExprHandle log10(const ExprHandle& v);
TORCH_API ExprHandle log1p(const ExprHandle& v);
TORCH_API ExprHandle erf(const ExprHandle& v);
TORCH_API ExprHandle erfc(const ExprHandle& v);
TORCH_API ExprHandle sqrt(const ExprHandle& v);
TORCH_API ExprHandle rsqrt(const ExprHandle& v);
TORCH_API ExprHandle ceil(const ExprHandle& v);
TORCH_API ExprHandle floor(const ExprHandle& v);
TORCH_API ExprHandle round(const ExprHandle& v);
TORCH_API ExprHandle trunc(const ExprHandle& v);
TORCH_API ExprHandle frac(const ExprHandle& v);
TORCH_API ExprHandle lgamma(const ExprHandle& v);
TORCH_API ExprHandle atan2(const ExprHandle& v1, const ExprHandle& v2);
TORCH_API ExprHandle pow(const ExprHandle& v1, const ExprHandle& v2);
TORCH_API ExprHandle fmod(const ExprHandle& v1, const ExprHandle& v2);
TORCH_API ExprHandle remainder(const ExprHandle& v1, const ExprHandle& v2);
TORCH_API ExprHandle isnan(const ExprHandle& v1);
TORCH_API ExprHandle Relu(const ExprHandle& v1);

TORCH_API ExprHandle
ifThenElse(const ExprHandle& c, const ExprHandle& t, const ExprHandle& f);

TORCH_API ExprHandle expr_to_vec(ExprHandle v, int lanes);
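
// Example of composing the intrinsics above (an illustrative sketch; the
// variable names are hypothetical):
//
//   VarHandle x("x", kFloat);
//   ExprHandle y = sigmoid(x) * exp(x);
//   ExprHandle z = ifThenElse(x > ExprHandle(0.0f), y, ExprHandle(0.0f));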

} // namespace tensorexpr
} // namespace jit
} // namespace torch