edgify / torch (Python package, version 2.0.1+cpu)
include/torch/csrc/jit/tensorexpr/codegen.h

#pragma once

#include <ATen/ATen.h>
#include <torch/csrc/jit/tensorexpr/ir.h>
#include <torch/csrc/jit/tensorexpr/tensor.h>

#include <cstring>
#include <functional>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

namespace torch {
namespace jit {
namespace tensorexpr {

template <typename T>
class PaddedBuffer;

// Base class for tensor-expression code generators. A CodeGen owns the
// lowered statement it compiles plus the buffer/scalar arguments the kernel
// expects, and exposes the `call*` entry points for running it.
class TORCH_API CodeGen {
 public:
  class BufferArg;
  class CallArg;

  template <typename... Ts>
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
  CodeGen(StmtPtr stmt, Ts... ts)
      : stmt_(std::move(stmt)), buffer_args_({BufferArg(ts)...}) {}

  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
  CodeGen(
      StmtPtr stmt,
      std::vector<BufferArg> buffer_args,
      at::Device device = at::kCPU,
      std::string kernel_func_name = "func");

  virtual ~CodeGen() = default;

  StmtPtr stmt() const {
    return stmt_;
  }

  void set_stmt(StmtPtr s) {
    stmt_ = s;
  }

  void apply_mutator(IRMutator* mutator) {
    stmt_ = stmt_->accept_mutator(mutator);
  }

  void apply_visitor(IRVisitor* visitor) {
    stmt_->accept(visitor);
  }

  std::vector<BufferArg>& buffer_args() {
    return buffer_args_;
  }

  const std::vector<BufferArg>& buffer_args() const {
    return buffer_args_;
  }

  at::Device device() {
    return device_;
  }

  // Returns the generated code as a string. The base implementation returns
  // an empty string; backends that produce textual code override it.
  virtual std::string getCodeText(const std::string& attr = "") {
    return "";
  }

  // TODO: Figure out how to unify these call interfaces.

  /// Call a function with a vector of CallArgs, which are tagged
  /// unions that properly type the arguments.
  virtual void call(const std::vector<CallArg>& args) = 0;

  /// Call a function faster than a regular `call` by assuming that
  /// the generated kernel already knows the type of the arguments, so
  /// they can be type-punned with `void*`s.
  virtual void call_raw(const std::vector<void*>& args) = 0;

  /// Call a function even faster than a regular call, by assuming
  /// that the number of thread blocks can be derived from `numel` via
  /// a simple division, rather than evaluating an expression.
  virtual void call_with_numel(void** args, int64_t numel);

  virtual at::Tensor empty_strided(
      c10::IntArrayRef size,
      c10::IntArrayRef stride,
      c10::optional<c10::ScalarType> dtype_opt,
      c10::optional<c10::Layout> layout_opt,
      c10::optional<c10::Device> device_opt,
      c10::optional<bool> pin_memory_opt) {
    return at::empty_strided(
        size, stride, dtype_opt, layout_opt, device_opt, pin_memory_opt);
  }

  const std::string& kernel_func_name() const {
    return kernel_func_name_;
  }

  // Allocates intermediate buffers, i.e. bufs used by the statement that are
  // not passed in as kernel arguments.
  void allocIntermediateBufs();

 protected:
  static void* argToPtr(const BufferArg& bufferArg, const CallArg& callArg);

 private:
  StmtPtr stmt_;
  std::vector<BufferArg> buffer_args_;
  at::Device device_ = at::kCPU;
  std::string kernel_func_name_ = "func";
};
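
// Usage sketch (illustrative only; `cg` stands for any concrete CodeGen
// subclass, and the at::Tensor inputs/outputs are hypothetical, listed in the
// same order as buffer_args()):
//
//   at::Tensor a = at::randn({64});
//   at::Tensor b = at::randn({64});
//   at::Tensor out = at::empty({64});
//
//   // Typed path: void* pointers convert implicitly to CallArg.
//   cg->call({a.data_ptr(), b.data_ptr(), out.data_ptr()});
//
//   // Type-punned fast path: raw pointers only, no CallArg wrapping.
//   std::vector<void*> raw_args = {a.data_ptr(), b.data_ptr(), out.data_ptr()};
//   cg->call_raw(raw_args);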

// IRMutator that rewrites ExternalCall nodes so that external calls can reuse
// buffer memory instead of allocating new outputs.
class TORCH_API ExtCallMemoryReuse : public IRMutator {
  static std::unordered_map<std::string, std::string> makeExtCallFuncNameMap();
  static const std::unordered_map<std::string, std::string> extCallFuncNameMap_;

 public:
  explicit ExtCallMemoryReuse(
      const std::vector<CodeGen::BufferArg>& bufferArgs);
  ~ExtCallMemoryReuse() override = default;
  StmtPtr mutate(ExternalCallPtr v) override;

 private:
  std::unordered_set<BufPtr> bufferArgs_;
};

// Describes one formal argument of the generated kernel: either a buffer
// (Tensor/BufHandle/BufPtr) or a scalar variable (VarHandle).
class CodeGen::BufferArg {
 public:
  BufferArg(const Tensor& tensor) : buf_(tensor.buf()) {}
  BufferArg(const VarHandle& var) : var_(var.node()), isVar_(true) {}
  BufferArg(const BufHandle& buf) : buf_(buf.node()) {}
  BufferArg(BufPtr buf) : buf_(std::move(buf)) {}

  VarPtr var() const {
    return isVar_ ? var_ : buf_->base_handle();
  }

  BufPtr buf() const {
    return buf_;
  }

  bool isVar() const {
    return isVar_;
  }

  Dtype dtype() const {
    return isVar_ ? var_->dtype() : buf_->dtype();
  }

 private:
  VarPtr var_ = nullptr;
  BufPtr buf_ = nullptr;
  bool isVar_ = false;
};
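
// Usage sketch (illustrative only; `result` is a hypothetical tensorexpr
// Tensor, and `a`, `n` are hypothetical BufHandle/VarHandle inputs):
//
//   std::vector<CodeGen::BufferArg> buffer_args = {
//       CodeGen::BufferArg(a),       // input buffer (BufHandle)
//       CodeGen::BufferArg(n),       // scalar argument (VarHandle, isVar() == true)
//       CodeGen::BufferArg(result)}; // output (tensorexpr Tensor)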

// Type-erased actual argument for CodeGen::call(). Buffer arguments are held
// as raw pointers; scalar arguments are copied byte-wise into the
// pointer-sized `data_` slot.
class CodeGen::CallArg {
 public:
  template <typename T>
  CallArg(const PaddedBuffer<T>& buffer);

  template <typename T>
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init,cppcoreguidelines-pro-type-const-cast)
  CallArg(const std::vector<T>& buffer)
      // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
      : data_(const_cast<T*>(buffer.data())) {}

  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
  CallArg(void* ptr) : data_(ptr) {}

#define ARG_TYPE_CTOR(Type, Name)     \
  CallArg(Type v) {                   \
    memcpy(&data_, &v, sizeof(Type)); \
  }
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
  AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, ARG_TYPE_CTOR);
#undef ARG_TYPE_CTOR

  void* data() const {
    return data_;
  }

#define ARG_PTR_DEFINE(Type, Name) \
  Type* Name##Ptr() const {        \
    return (Type*)&data_;          \
  }
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
  AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, ARG_PTR_DEFINE);
#undef ARG_PTR_DEFINE

 private:
  void* data_;
};
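
// Usage sketch (illustrative only): a buffer CallArg wraps the data pointer,
// while a scalar CallArg stores the value in-place.
//
//   float alpha = 2.0f;
//   std::vector<float> input(16, 1.0f);
//   std::vector<CodeGen::CallArg> call_args = {
//       CodeGen::CallArg(input),   // holds input.data()
//       CodeGen::CallArg(alpha)};  // holds the bytes of 2.0f
//   float scalar_back = *call_args[1].FloatPtr(); // reads back 2.0f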

// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
// Singleton registry mapping backend names to factory methods that construct
// the corresponding CodeGen subclass (see CreateCodeGen below).
class RegisterCodeGenList {
 public:
  TORCH_API static RegisterCodeGenList& GetInstance() {
    static RegisterCodeGenList codegen_list;
    return codegen_list;
  }

  using StmtFactoryMethod = std::function<std::unique_ptr<CodeGen>(
      StmtPtr stmt,
      const std::vector<CodeGen::BufferArg>&,
      at::Device device,
      const std::string& kernel_func_name)>;

  TORCH_API StmtFactoryMethod FindStmtFactoryMethod(const std::string& name);
  RegisterCodeGenList(const RegisterCodeGenList&) = delete;
  RegisterCodeGenList& operator=(const RegisterCodeGenList&) = delete;

 private:
  template <class CodeGenType>
  friend class RegisterCodeGen;
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
  RegisterCodeGenList() = default;
  TORCH_API void AddStmtFactoryMethod(
      const std::string& name,
      const StmtFactoryMethod& stmt_factory_method);

  std::unordered_map<std::string, StmtFactoryMethod> stmt_factory_methods_;
};

// Registers CodeGenType in RegisterCodeGenList under `name`. Backends
// typically instantiate this as a static object in their own source file.
template <class CodeGenType>
class RegisterCodeGen {
 public:
  explicit RegisterCodeGen(const std::string& name) {
    RegisterCodeGenList& codegen_list = RegisterCodeGenList::GetInstance();
    codegen_list.AddStmtFactoryMethod(
        name,
        [](StmtPtr stmt,
           const std::vector<CodeGen::BufferArg>& params,
           at::Device device,
           const std::string& kernel_func_name) {
          // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
          std::unique_ptr<CodeGen> method(
              new CodeGenType(stmt, params, device, kernel_func_name));
          return method;
        });
  }
};
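
// Usage sketch (illustrative only; `MyCodeGen` and the name "my_codegen" are
// hypothetical; concrete backends register themselves this way from their own
// translation units):
//
//   static RegisterCodeGen<MyCodeGen> my_codegen_reg("my_codegen");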

// Looks up `name` in the registry and constructs the corresponding CodeGen
// for the given statement, arguments, and device.
TORCH_API std::unique_ptr<CodeGen> CreateCodeGen(
    const std::string& name,
    StmtPtr stmt,
    const std::vector<CodeGen::BufferArg>& params,
    at::Device device = at::kCPU,
    const std::string& kernel_func_name = "func");
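
// Usage sketch (illustrative only; `stmt` and `buffer_args` are assumed to
// come from lowering a tensorexpr kernel elsewhere, and "my_codegen" is the
// hypothetical backend name registered above):
//
//   std::unique_ptr<CodeGen> cg =
//       CreateCodeGen("my_codegen", stmt, buffer_args, at::kCPU, "my_kernel");
//   cg->call_raw(raw_args);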

// IRMutator that expands generic intrinsics into equivalent expressions built
// from simpler primitives.
class TORCH_API GenericIntrinsicsExpander : public IRMutator {
 protected:
  ExprPtr mutate(IntrinsicsPtr v) override;
};

} // namespace tensorexpr
} // namespace jit
} // namespace torch