edgify / torch (python), version 2.0.1+cpu

include / torch / csrc / jit / tensorexpr / llvm_codegen.h

#pragma once

#ifdef TORCH_ENABLE_LLVM
#include <torch/csrc/Export.h>

#include <torch/csrc/jit/tensorexpr/codegen.h>
#include <torch/csrc/jit/tensorexpr/ir.h>
#include <torch/csrc/jit/tensorexpr/ir_visitor.h>

#include <c10/util/Optional.h>

#include <unordered_map>
#include <vector>

namespace torch {
namespace jit {
namespace tensorexpr {

class LLVMCodeGenImpl;
class LLVMCodeGenCallee;

class TORCH_API LLVMCodeGen : public CodeGen {
 public:
  explicit LLVMCodeGen(
      StmtPtr stmt,
      const std::vector<BufferArg>& args,
      at::Device device = at::kCPU,
      const std::string& kernel_func_name = "func",
      Dtype dtype = kInt,
      c10::optional<std::string> triple = c10::nullopt,
      c10::optional<std::string> cpu = c10::nullopt,
      c10::optional<std::string> attrs = c10::nullopt);
  explicit LLVMCodeGen(StmtPtr stmt);

  LLVMCodeGen() = delete;
  ~LLVMCodeGen() override;

  // Cleans up all the memory used during the LLVM code generation pass,
  // except for the generated kernel. After calling this method, users should
  // not call methods like `getCodeText` that require the LLVMCodeGenImpl
  // data. However, users can continue to invoke the generated kernel via
  // `call` and `call_raw`.
  void cleanup_memory();
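
  // A minimal lifecycle sketch (illustrative only; `stmt`, `buffer_args`,
  // and `call_args` are placeholders assumed to have been built elsewhere):
  //
  //   LLVMCodeGen cg(stmt, buffer_args);
  //   cg.call(call_args);     // OK: full codegen state is still alive
  //   cg.cleanup_memory();    // drops LLVM state, keeps the jitted kernel
  //   cg.call(call_args);     // still OK
  //   cg.getCodeText();       // NOT OK: requires the LLVMCodeGenImpl data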

  // Entry points for running the jitted kernel: `call` takes typed CallArgs,
  // `call_raw` takes an opaque argument array, and `call_with_numel`
  // additionally forwards the number of elements.
  TORCH_API void call(const std::vector<CallArg>& args) override;
  TORCH_API void call_raw(const std::vector<void*>& args) override;
  TORCH_API void call_with_numel(void** args, int64_t numel) override;

  // Allocates a strided tensor for use by the kernel; the signature mirrors
  // at::empty_strided.
  at::Tensor empty_strided(
      c10::IntArrayRef size,
      c10::IntArrayRef stride,
      c10::optional<c10::ScalarType> dtype_opt,
      c10::optional<c10::Layout> layout_opt,
      c10::optional<c10::Device> device_opt,
      c10::optional<bool> pin_memory_opt) override;

  // Invokes the jitted kernel with no arguments and returns its result.
  template <typename T>
  T value() {
    return value<T>(nullptr);
  }

  // Invokes the jitted kernel, forwarding `args` as the packed argument
  // array, and returns its result.
  template <typename T>
  T value(std::vector<void*>& args) {
    return value<T>(args.data());
  }

  // Casts the kernel address to a `T (*)(void**)` function pointer and
  // calls it directly with the packed argument array.
  template <typename T>
  T value(void** args) {
    T (*fp)(void**) =
        reinterpret_cast<T (*)(void**)>(getKernelAddress(callee_.get()));
    T rv = fp(args);
    return rv;
  }

  std::string getCodeText(const std::string& attr = "") override;

 private:
  void* getKernelAddress(LLVMCodeGenCallee* callee);

  std::unique_ptr<LLVMCodeGenCallee> callee_;
  std::unique_ptr<LLVMCodeGenImpl> impl_;
};
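
// A minimal usage sketch (illustrative only; `stmt` and `buffer_args` are
// placeholders for values built from the tensorexpr IR elsewhere, and the
// raw argument packing is an assumption):
//
//   LLVMCodeGen cg(stmt, buffer_args, at::kCPU, "my_kernel", kFloat);
//   std::vector<void*> raw_args = /* packed buffer/scalar pointers */;
//   float result = cg.value<float>(raw_args);  // run and fetch the result
//   std::string code = cg.getCodeText();       // inspect the generated code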

struct TORCH_API LLVMCodeGenBuilder {
  using BufferArg = CodeGen::BufferArg;

  LLVMCodeGenBuilder(StmtPtr stmt, std::vector<BufferArg> args)
      : stmt_(stmt), args_(std::move(args)) {}

  LLVMCodeGenBuilder& device(at::Device device) {
    device_ = device;
    return *this;
  }

  LLVMCodeGenBuilder& kernelFuncName(std::string name) {
    kernelFuncName_ = std::move(name);
    return *this;
  }

  LLVMCodeGenBuilder& dtype(Dtype d) {
    dtype_ = d;
    return *this;
  }

  LLVMCodeGenBuilder& triple(std::string triple) {
    triple_ = std::move(triple);
    return *this;
  }

  LLVMCodeGenBuilder& cpu(std::string cpu) {
    cpu_ = std::move(cpu);
    return *this;
  }

  LLVMCodeGenBuilder& attrs(std::string attrs) {
    attrs_ = std::move(attrs);
    return *this;
  }

  std::unique_ptr<LLVMCodeGen> build() {
    return std::make_unique<LLVMCodeGen>(
        stmt_, args_, device_, kernelFuncName_, dtype_, triple_, cpu_, attrs_);
  }

 private:
  StmtPtr stmt_;
  std::vector<BufferArg> args_;
  at::Device device_ = at::kCPU;
  std::string kernelFuncName_ = "func";
  Dtype dtype_ = kInt;
  c10::optional<std::string> triple_ = c10::nullopt;
  c10::optional<std::string> cpu_ = c10::nullopt;
  c10::optional<std::string> attrs_ = c10::nullopt;
};
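
// Equivalent construction through the fluent builder (illustrative; the
// same placeholder `stmt`, `buffer_args`, and `call_args` are assumed):
//
//   auto cg = LLVMCodeGenBuilder(stmt, buffer_args)
//                 .device(at::kCPU)
//                 .kernelFuncName("my_kernel")
//                 .dtype(kFloat)
//                 .build();
//   cg->call(call_args);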

TORCH_API c10::optional<std::string>& LLVMTargetTriple();
TORCH_API c10::optional<std::string>& LLVMTargetCPU();
TORCH_API c10::optional<std::string>& LLVMTargetAttrs();
TORCH_API bool& LLVMAOTWorkflow();
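
// These accessors return mutable references, so callers can both read and
// override what appear to be process-wide target defaults (inferred from the
// mutable-reference return types; illustrative only):
//
//   LLVMTargetTriple() = std::string("x86_64-unknown-linux-gnu");
//   LLVMAOTWorkflow() = true;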

} // namespace tensorexpr
} // namespace jit
} // namespace torch

#endif // TORCH_ENABLE_LLVM