Why Gemfury? Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

edgify / torch   python

Repository URL to install this package:

Version: 2.0.1+cpu 

/ include / torch / csrc / jit / tensorexpr / block_codegen.h

#pragma once

#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>

#include <ATen/ATen.h>
#include <torch/csrc/jit/resource_guard.h>
#include <torch/csrc/jit/tensorexpr/analysis.h>
#include <torch/csrc/jit/tensorexpr/codegen.h>
#include <torch/csrc/jit/tensorexpr/ir.h>
#include <torch/csrc/jit/tensorexpr/ir_printer.h>
#include <torch/csrc/jit/tensorexpr/ir_visitor.h>
#include <torch/csrc/jit/tensorexpr/unique_name_manager.h>

namespace torch {
namespace jit {
namespace tensorexpr {

// A class that analyzes the given program relevant for Block backend.
class BlockAnalysis : public IRVisitor {
 public:
  bool is_buf_store_target(BufPtr buf) const {
    return store_targets_.count(buf) > 0;
  }

  const std::unordered_set<BufPtr>& loads() const {
    return loads_;
  }

  const std::unordered_set<BufPtr>& stores() const {
    return store_targets_;
  }

  int block_size() const {
    return block_size_;
  }

  bool areBufsInMap(const std::unordered_set<BufPtr>& bufs) const;

  BufPtr getMultiDimBuf(BufPtr buf) const;

  std::string getInputName(BufPtr buf) const;

  std::string getFlatInputName(BufPtr buf) const {
    return getInputName(std::move(buf)) + "_flat";
  }

  std::unordered_map<std::string, BufPtr> getBufferMap() const {
    return map_input_to_tensor_bufs_;
  }

 private:
  void visit(StorePtr v) override;
  void visit(LoadPtr v) override;
  void visit(ForPtr v) override;

  std::unordered_map<std::string, BufPtr> map_input_to_tensor_bufs_;
  std::unordered_set<BufPtr> store_targets_;
  std::unordered_set<BufPtr> loads_;
  int block_size_ = 32;
};

// IRPrinter specialization that produces Block output. It is constructed from
// an output stream and a BlockAnalysis instance.
class BlockPrinter : public IRPrinter {
 public:
  // `os` is dereferenced immediately to initialize the IRPrinter base, so it
  // must not be null. `block_analysis` is kept as a raw pointer; nothing in
  // this header takes ownership of it.
  BlockPrinter(std::ostream* os, BlockAnalysis* block_analysis)
      : IRPrinter(*os), block_analysis_(block_analysis) {}

  // Re-export the base-class name manager accessor and visit overloads.
  using IRPrinter::name_manager;
  using IRPrinter::visit;

 private:
  // Node kinds for which this printer overrides the base IRPrinter behaviour.
  void visit(ForPtr v) override;
  void visit(LoadPtr v) override;
  void visit(StorePtr v) override;
  void visit(BlockPtr v) override;
  void visit(AddPtr v) override;
  void visit(MulPtr v) override;

  // Printing helpers, each operating on a set of buffers. Defined out of line.
  void PrintTensorInfo(const std::unordered_set<BufPtr>& bufs);
  void PrintArguments(const std::unordered_set<BufPtr>& bufs);
  void PrintBufferInfo(const std::unordered_set<BufPtr>& bufs);
  void PrintDistribution(const std::unordered_set<BufPtr>& bufs);
  void PrintLoop(
      const std::unordered_set<BufPtr>& bufs,
      bool block_idx = true);
  void PrintReshapeInfo(
      const std::unordered_set<BufPtr>& bufs,
      bool reverse = false);
  void PrintDMAs(const std::unordered_set<BufPtr>& bufs);
  void PrintAdjustBuffers(const std::unordered_set<BufPtr>& bufs);

  // Printer state. The member names are referenced by the out-of-line
  // definitions, so they are kept as-is.
  BlockAnalysis* block_analysis_;
  std::unordered_map<std::string, int> dim_values_map;
  std::vector<std::string> dim_names{"N", "H", "W", "C"};
  std::vector<std::string> flat_dim_names{"N", "NH", "NHW", "NHWC"};
};

// CodeGen implementation for the Block backend. Every constructor forwards its
// arguments to CodeGen and then runs Initialize(); the text accumulated in
// `oss_` is what getCodeText() returns.
class TORCH_API BlockCodeGen : public CodeGen {
 public:
  // Convenience constructor: each extra argument is wrapped in a BufferArg and
  // the device is fixed to CPU.
  template <typename... Ts>
  /* implicit */
  BlockCodeGen(StmtPtr stmt, Ts... ts)
      : CodeGen(
            stmt,
            std::vector<BufferArg>({BufferArg(ts)...}),
            at::Device(at::kCPU)) {
    Initialize();
  }

  // General constructor taking an explicit buffer-argument list, a device
  // (CPU by default) and a kernel function name ("func" by default).
  BlockCodeGen(
      StmtPtr stmt,
      const std::vector<BufferArg>& buffer_args,
      at::Device device = at::Device(at::kCPU),
      const std::string& kernel_func_name = "func")
      : CodeGen(stmt, buffer_args, device, kernel_func_name) {
    Initialize();
  }

  ~BlockCodeGen() override;

  void call(const std::vector<CallArg>& args) override;
  void call_raw(const std::vector<void*>& args) override;

  // Invoked by both constructors; defined out of line.
  void Initialize();

  // Returns the current contents of the internal string stream. `attr` is
  // accepted to match the overridden signature but is not used here.
  std::string getCodeText(const std::string& attr = "") override {
    return oss_.str();
  }

 private:
  // Name manager of the underlying printer.
  // Throws std::runtime_error if the printer has not been created.
  UniqueNameManager* name_manager() {
    if (!printer_) {
      throw std::runtime_error("Null IRPrinter is not expected");
    }
    return printer_->name_manager();
  }

  // Output stream of the underlying printer.
  // Throws std::runtime_error if the printer has not been created. The guard
  // mirrors name_manager(): dereferencing an empty unique_ptr is undefined
  // behaviour, so fail loudly instead.
  std::ostream& os() {
    if (!printer_) {
      throw std::runtime_error("Null IRPrinter is not expected");
    }
    return printer_->os();
  }

  std::ostringstream oss_;
  std::unique_ptr<BlockPrinter> printer_;
  std::unique_ptr<BlockAnalysis> block_analysis_;

  std::string GetUniqueFuncName(const std::string& func_prefix);
};
} // namespace tensorexpr
} // namespace jit
} // namespace torch