Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

neilisaac / torch   python

Repository URL to install this package:

Version: 1.8.0 

/ include / torch / csrc / jit / tensorexpr / eval.h

#pragma once

#include <cmath>
#include <cstring>
#include <memory>
#include <type_traits>
#include <unordered_map>
#include <utility>
#include <vector>

#include <c10/macros/Macros.h>
#include <c10/util/Logging.h>
#include <c10/util/math_compat.h>
#include <c10/util/string_utils.h>
#include <torch/csrc/jit/tensorexpr/codegen.h>
#include <torch/csrc/jit/tensorexpr/exceptions.h>
#include <torch/csrc/jit/tensorexpr/execution_counter.h>
#include <torch/csrc/jit/tensorexpr/ir.h>
#include <torch/csrc/jit/tensorexpr/ir_printer.h>
#include <torch/csrc/jit/tensorexpr/tensor.h>
#include <torch/csrc/jit/tensorexpr/types.h>
#include <torch/csrc/jit/tensorexpr/var_substitutor.h>

namespace torch {
namespace jit {
namespace tensorexpr {

// Declares the execution-counter trigger `simple_ir_eval_executed` (macro from
// execution_counter.h). NOTE(review): presumably bumped by the
// SimpleIREvaluator implementation each time it runs, so tests can observe
// which backend executed; confirm against the definition in the .cpp.
DECLARE_TRIGGER(simple_ir_eval_executed);

// Runtime value used by the IR evaluator: a Dtype tag plus typed element
// storage.
//
// One std::vector is declared per supported scalar type (every
// AT_FORALL_SCALAR_TYPES type plus Bool and Half). Each constructor fills only
// the vector that matches the dtype it records. A scalar Value holds exactly
// one element. A vector Value holds v.size() elements and passes that count as
// the second Dtype argument (presumably the lane count).
class Value {
 public:
  // Default value: a scalar int 0.
  Value() : dtype_(kInt) {
    Intvalues.push_back(0);
  }

  // Scalar constructors, one per supported type. They are not `explicit`.
  // NOTE(review): code outside this header may rely on the implicit
  // conversion, so keep it that way unless all callers are checked.
#define VALUE_CTOR(Type, Name)      \
  Value(Type v) : dtype_(k##Name) { \
    Name##values.push_back(v);      \
  }
  AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, VALUE_CTOR);
#undef VALUE_CTOR

  // Vector constructors: copy all of `v` and tag the value with the matching
  // scalar type and v.size() elements.
#define VALUE_VEC_CTOR(Type, Name)  \
  Value(const std::vector<Type>& v) \
      : dtype_(Dtype(k##Name, v.size())), Name##values(v) {}
  AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, VALUE_VEC_CTOR);
#undef VALUE_VEC_CTOR

  // Returns the first stored element as T. The specializations defined after
  // this class throw unsupported_dtype unless dtype() compares equal to T's
  // scalar dtype.
  template <typename T>
  T as() const;

  // Returns every stored element as T. The specializations defined after this
  // class throw unsupported_dtype unless the scalar type matches T; the
  // element count is not checked.
  template <typename T>
  const std::vector<T>& as_vec() const;

  Dtype dtype() const {
    return dtype_;
  }

 private:
  Dtype dtype_;

  // Per-type element storage (`Intvalues`, `Floatvalues`, ...); only the
  // vector matching dtype_'s scalar type is populated by the constructors.
#define VALUE_STORAGE(Type, Name) std::vector<Type> Name##values;
  AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, VALUE_STORAGE);
#undef VALUE_STORAGE
};

// Scalar accessor specializations of Value::as<T>(). Each one returns the
// first stored element when the whole dtype (scalar type and element count)
// equals the scalar dtype k<Name>. Any other dtype, presumably including a
// multi-element value, raises unsupported_dtype.
#define VALUE_AS_DISPATCH(Type, Name)   \
  template <>                           \
  inline Type Value::as<Type>() const { \
    if (dtype_ == k##Name) {            \
      return Name##values[0];           \
    }                                   \
    throw unsupported_dtype();          \
  }
AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, VALUE_AS_DISPATCH);
#undef VALUE_AS_DISPATCH

// Vector accessor specializations of Value::as_vec<T>(). Only the scalar type
// is compared, so a value with any element count is accepted. On a match, the
// backing storage vector is returned by const reference; otherwise
// unsupported_dtype is thrown.
#define VALUE_AS_VEC_DISPATCH(Type, Name)                       \
  template <>                                                   \
  inline const std::vector<Type>& Value::as_vec<Type>() const { \
    if (dtype_.scalar_type() == ScalarType::Name) {             \
      return Name##values;                                      \
    }                                                           \
    throw unsupported_dtype();                                  \
  }
AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, VALUE_AS_VEC_DISPATCH);
#undef VALUE_AS_VEC_DISPATCH

// Reinterprets the object representation of `src` as a `To` by copying its
// bytes into a default-constructed `To`.
//
// Mismatched sizes are rejected at runtime through TORCH_CHECK.
// NOTE(review): a static_assert would be stricter, but it would also reject
// instantiations that callers only reach after their own runtime size check.
// Confirm no caller instantiates mismatched type pairs before tightening.
template <typename To, typename From>
To raw_bitcast(const From& src) {
  TORCH_CHECK(sizeof(To) == sizeof(From), "Invalid bitcast invocation");
  To bits;
  std::memcpy(&bits, &src, sizeof(From));
  return bits;
}

class SimpleIREvaluatorImpl;

// CodeGen backend for evaluating an IR statement over caller-supplied
// buffers. All evaluation state is hidden behind `impl_` (pimpl).
// SimpleIREvaluatorImpl and the out-of-line members are defined in the
// implementation file.
class TORCH_API SimpleIREvaluator : public CodeGen {
 public:
  SimpleIREvaluator(
      Stmt* stmt,
      const std::vector<BufferArg>& buffer_args,
      at::Device device = at::kCPU,
      const std::string& kernel_func_name = "func");

  // Defined out of line so that impl_'s deleter sees the complete type.
  ~SimpleIREvaluator() override;

  void call(const std::vector<CallArg>& args) override;

  // Convenience overload: wraps every argument in a CallArg, in order, and
  // forwards the list to call().
  template <typename... Ts>
  void operator()(const Ts&... ts) {
    call(std::vector<CallArg>{CallArg(ts)...});
  }

  void bindVar(const Var* v, const Expr* e);

  // NOTE(review): presumably the result of the most recent evaluation;
  // confirm in the implementation file.
  Value value() const;

 private:
  void bindArg(const BufferArg& buf, const CallArg& data);

  // Runs a GenericIntrinsicsExpander over the held statement.
  void expand_intrinsics() {
    GenericIntrinsicsExpander expander;
    apply_mutator(&expander);
  }

  std::unique_ptr<SimpleIREvaluatorImpl> impl_;
};

template <class CodeGenType>
class ExprEval {
 public:
  using BufferArg = CodeGen::BufferArg;
  using CallArg = CodeGen::CallArg;

  template <typename... Ts>
  ExprEval(const ExprHandle& expr, Ts... ts)
      : ExprEval(expr, {BufferArg(ts)...}) {}

  ExprEval(const ExprHandle& expr, const std::vector<BufferArg>& buffer_args)
      : dtype_(expr.dtype()) {
    std::vector<BufferArg> buffer_args_extended = buffer_args;
    Placeholder ret_buf("ret_val", dtype_, {1});
    std::vector<const Expr*> indices;
    const Expr* zero = new IntImm(0);
    for (size_t i = 0; i < ret_buf.data()->ndim(); i++) {
      indices.push_back(zero);
    }
    Stmt* store_stmt =
        new Store(ret_buf.data(), indices, expr.node(), new IntImm(1));
    buffer_args_extended.push_back(ret_buf);
    codegen_.reset(new CodeGenType(store_stmt, buffer_args_extended));
  }

  template <typename... Ts>
  void operator()(Ts... ts) {
    call(ts...);
  }

  void operator()(const std::vector<CallArg>& call_args) {
    call(call_args);
  }

  void bindVar(const Var* v, const Expr* e) {
    codegen_->bindVar(v, e);
  }

  void bindVar(const VarHandle& v, const ExprHandle& e) {
    codegen_->bindVar(v.node(), e.node());
  }

  template <typename... Ts>
  void call(Ts... ts) {
    call({CallArg(ts)...});
  }

  void call(const std::vector<CallArg>& call_args) {
    std::vector<CallArg> call_args_extended = call_args;
    switch (dtype_.scalar_type()) {
#define TYPE_CASE(Type, Name)                           \
  case ScalarType::Name: {                              \
    std::vector<Type> ret_val_arg(1);                   \
    call_args_extended.push_back(CallArg(ret_val_arg)); \
    codegen_->call(call_args_extended);                 \
    ret_value_ = Value(ret_val_arg[0]);                 \
  } break;
      AT_FORALL_SCALAR_TYPES_AND(Half, TYPE_CASE);
#undef TYPE_CASE
      case ScalarType::Bool: {
        std::vector<unsigned char> ret_val_arg(1);
        call_args_extended.push_back(CallArg(ret_val_arg.data()));
        codegen_->call(call_args_extended);
        ret_value_ = Value((bool)ret_val_arg[0]);
      } break;
      default:
        throw unsupported_dtype();
    }
  }

  template <typename T>
  T value(std::vector<void*>& args) {
    call(args);
    return ret_value_.as<T>();
  }

  template <typename T, typename... Ts>
  T value(Ts... ts) {
    call(std::forward<Ts>(ts)...);
    return ret_value_.as<T>();
  }

  Dtype dtype() {
    return dtype_;
  }

 private:
  Dtype dtype_;
  std::unique_ptr<CodeGenType> codegen_;
  Value ret_value_;
};

// Applies `var_mapping` to `expr` by running a VarSubMutator over it and
// returns whatever the mutator produces. NOTE(review): whether nodes are
// rebuilt or reused depends on VarSubMutator; confirm in var_substitutor.h.
inline const Expr* Substitute(const Expr* expr, const VarMapping& var_mapping) {
  VarSubMutator substitutor(var_mapping);
  const Expr* substituted = expr->accept_mutator(&substitutor);
  return substituted;
}

// Statement counterpart of the Expr overload: runs a VarSubMutator built from
// `var_mapping` over `stmt` and returns the mutator's result.
inline Stmt* Substitute(Stmt* stmt, const VarMapping& var_mapping) {
  VarSubMutator substitutor(var_mapping);
  Stmt* substituted = stmt->accept_mutator(&substitutor);
  return substituted;
}

} // namespace tensorexpr
} // namespace jit
} // namespace torch