Why Gemfury? Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

edgify / torch   python

Repository URL to install this package:

Version: 2.0.1+cpu 

/ include / torch / csrc / jit / tensorexpr / analysis.h

#pragma once

#include <torch/csrc/jit/tensorexpr/ir.h>
#include <torch/csrc/jit/tensorexpr/ir_visitor.h>
#include <torch/csrc/jit/tensorexpr/stmt.h>
#include <torch/csrc/jit/tensorexpr/tensor.h>

#include <utility>

namespace torch {
namespace jit {
namespace tensorexpr {
// Visitor that checks whether a statement contains a kRand intrinsic.
// The scan runs eagerly in the constructor; query the result with has_rand().
class HasRand : public IRVisitor {
 public:
  HasRand(StmtPtr stmt) : stmt_(std::move(stmt)) {
    stmt_->accept(this);
  }

  // True if a kRand intrinsic was encountered while visiting the statement.
  bool has_rand() const {
    return has_rand_;
  }

 private:
  void visit(IntrinsicsPtr v) override {
    // Non-rand intrinsics: keep descending into their operands.
    if (v->op_type() != IntrinsicsOp::kRand) {
      IRVisitor::visit(std::move(v));
      return;
    }
    // Found a rand call; its operands need not be inspected.
    has_rand_ = true;
  }

  StmtPtr stmt_;
  bool has_rand_ = false;
};

// Collects every node of type Op reachable from a statement or expression,
// in visitation order. Use the static find() helpers for one-shot queries.
template <typename Op>
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
class NodeFinder : public IRVisitor {
 public:
  void visit(NodePtr<Op> v) override {
    // Record the node, then keep descending so nested Op nodes are found too.
    nodes.push_back(v);
    IRVisitor::visit(std::move(v));
  }

  // Returns all Op nodes found under `s`.
  static std::vector<NodePtr<Op>> find(StmtPtr s) {
    NodeFinder<Op> nf;
    s->accept(&nf);
    // `nf.nodes` is a member of a local, so it would be copied without an
    // explicit move; nf is discarded right after, so move the result out.
    return std::move(nf.nodes);
  }

  // Returns all Op nodes found under `e`.
  static std::vector<NodePtr<Op>> find(ExprPtr e) {
    NodeFinder<Op> nf;
    e->accept(&nf);
    return std::move(nf.nodes);
  }

  std::vector<NodePtr<Op>> nodes;
};

// Collects the set of distinct Var nodes reachable from a statement or
// expression.
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
class VarFinder : public IRVisitor {
 public:
  void visit(VarPtr v) override {
    vars_.insert(v);
    IRVisitor::visit(std::move(v));
  }

  // Returns every Var referenced under `s`.
  static std::unordered_set<VarPtr> find(StmtPtr s) {
    VarFinder nf;
    s->accept(&nf);
    // nf is discarded right after; move the set out instead of copying it.
    return std::move(nf.vars_);
  }

  // Returns every Var referenced under `e`.
  static std::unordered_set<VarPtr> find(ExprPtr e) {
    VarFinder nf;
    e->accept(&nf);
    return std::move(nf.vars_);
  }

  // Vars collected by the visits performed so far.
  const std::unordered_set<VarPtr>& vars() const {
    return vars_;
  }

 private:
  std::unordered_set<VarPtr> vars_;
};

// Collects the set of distinct Buf nodes reachable from a statement or
// expression.
class BufFinder : public IRVisitor {
 public:
  void visit(BufPtr v) override {
    bufs_.insert(v);
    IRVisitor::visit(std::move(v));
  }

  // Returns every Buf referenced under `s`.
  static std::unordered_set<BufPtr> find(StmtPtr s) {
    BufFinder nf;
    s->accept(&nf);
    // nf is discarded right after; move the set out instead of copying it.
    return std::move(nf.bufs_);
  }

  // Returns every Buf referenced under `e`.
  static std::unordered_set<BufPtr> find(ExprPtr e) {
    BufFinder nf;
    e->accept(&nf);
    return std::move(nf.bufs_);
  }

  // Bufs collected by the visits performed so far.
  const std::unordered_set<BufPtr>& bufs() const {
    return bufs_;
  }

 private:
  std::unordered_set<BufPtr> bufs_;
};

// Finds all kinds of write operations to the provided Buf: Store and
// AtomicAdd statements whose target buffer is exactly `target`.
class WritesToBuf : public IRVisitor {
 public:
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
  WritesToBuf(BufPtr target) : target_(std::move(target)) {}

  // Returns a copy of the writing statements found so far, in visit order.
  std::vector<StmtPtr> writes() const {
    return writes_;
  }

  // Returns every Store/AtomicAdd under `s` that writes to `b`.
  static std::vector<StmtPtr> find(StmtPtr s, BufPtr b) {
    WritesToBuf finder(std::move(b));
    s->accept(&finder);
    // finder is discarded right after; move the result out instead of copying.
    return std::move(finder.writes_);
  }

 private:
  void visit(StorePtr v) override {
    if (v->buf() == target_) {
      writes_.push_back(v);
    }
  }

  void visit(AtomicAddPtr v) override {
    if (v->buf() == target_) {
      writes_.push_back(v);
    }
  }

  BufPtr target_;
  std::vector<StmtPtr> writes_;
};

// Collects statements that read the provided Buf through a Load.
// Store, Let, Cond and AtomicAdd statements are checked as a whole and are not
// recursed into; in particular, a Cond is reported (rather than the statements
// inside it) if a Load of the buffer appears anywhere within it, including its
// branches. Other statement kinds are traversed by the base IRVisitor.
// NOTE(review): only Load nodes count as reads here; buffers passed to
// ExternalCall nodes are not considered — confirm that is intended.
class StmtsReadingBuf : public IRVisitor {
 public:
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
  StmtsReadingBuf(BufPtr target) : target_(std::move(target)) {}

  // Returns a copy of the reading statements found so far, in visit order.
  std::vector<StmtPtr> reads() const {
    return reads_;
  }

  // Returns the statements under `s` that load from `b`.
  static std::vector<StmtPtr> find(StmtPtr s, BufPtr b) {
    StmtsReadingBuf finder(std::move(b));
    s->accept(&finder);
    // finder is discarded right after; move the result out instead of copying.
    return std::move(finder.reads_);
  }

 private:
  // True if any Load anywhere inside `s` reads target_.
  bool readsBuffer(StmtPtr s) const {
    auto loads = NodeFinder<Load>::find(std::move(s));
    for (const auto& l : loads) {
      if (l->buf() == target_) {
        return true;
      }
    }
    return false;
  }

  void visit(StorePtr v) override {
    if (readsBuffer(v)) {
      reads_.push_back(v);
    }
  }

  void visit(LetPtr v) override {
    if (readsBuffer(v)) {
      reads_.push_back(v);
    }
  }

  void visit(CondPtr v) override {
    if (readsBuffer(v)) {
      reads_.push_back(v);
    }
  }

  void visit(AtomicAddPtr v) override {
    if (readsBuffer(v)) {
      reads_.push_back(v);
    }
  }

  BufPtr target_;
  std::vector<StmtPtr> reads_;
};

// Collects every buffer that appears among the output arguments
// (buf_out_args) of an ExternalCallWithAlloc node.
class ExternalAllocBufFinder : public IRVisitor {
 public:
  void visit(ExternalCallWithAllocPtr v) override {
    const auto& bufs_out = v->buf_out_args();
    bufs_.insert(bufs_out.begin(), bufs_out.end());
    IRVisitor::visit(std::move(v));
  }

  // Returns the output buffers of all ExternalCallWithAlloc nodes under `s`.
  static std::unordered_set<BufPtr> find(StmtPtr s) {
    ExternalAllocBufFinder f;
    s->accept(&f);
    // f is discarded right after; move the set out instead of copying it.
    return std::move(f.bufs_);
  }

  // Returns the output buffers of all ExternalCallWithAlloc nodes under `e`.
  static std::unordered_set<BufPtr> find(ExprPtr e) {
    ExternalAllocBufFinder f;
    e->accept(&f);
    return std::move(f.bufs_);
  }

  // Buffers collected by the visits performed so far.
  const std::unordered_set<BufPtr>& bufs() const {
    return bufs_;
  }

 private:
  std::unordered_set<BufPtr> bufs_;
};

// Traverses the IR to determine whether a particular Var is modified within it.
// The Var counts as modified if it is the base handle of the buffer targeted by
// a Store or AtomicAdd, the variable bound by a Let, or the loop variable of a
// For. Once a match is found, that node's children are not visited.
class ModifiesVarChecker : public IRVisitor {
 public:
  ModifiesVarChecker(VarPtr v) : var_(std::move(v)) {}

  // Returns true if `s` modifies `v` in any of the ways listed above.
  static bool check(StmtPtr s, VarPtr v) {
    ModifiesVarChecker checker(std::move(v));
    s->accept(&checker);
    return checker.found();
  }

  // True if a modification was seen by the visits performed so far.
  bool found() const {
    return found_;
  }

 private:
  void visit(StorePtr v) override {
    if (v->buf()->base_handle() == var_) {
      found_ = true;
      return;
    }
    IRVisitor::visit(std::move(v));
  }

  void visit(AtomicAddPtr v) override {
    if (v->buf()->base_handle() == var_) {
      found_ = true;
      return;
    }
    IRVisitor::visit(std::move(v));
  }

  void visit(LetPtr v) override {
    if (v->var() == var_) {
      found_ = true;
      return;
    }
    IRVisitor::visit(std::move(v));
  }

  void visit(ForPtr v) override {
    if (v->var() == var_) {
      found_ = true;
      return;
    }
    IRVisitor::visit(std::move(v));
  }

  VarPtr var_;
  bool found_{false};
};

// Traverse the Block stmt to identify the live range of the specified buf. The
// live range, indicated by a pair of integers, specifies the first and last
// stmt in block stmts that access to the buf.
class BufLiveRange : public IRVisitor {
 public:
  BufLiveRange(BufPtr b) : buf_(std::move(b)) {}

  static std::tuple<int32_t, int32_t> liveRange(StmtPtr s, BufPtr b) {
    BlockPtr block = to<Block>(std::move(s));
    // We Only analze buffer live ranges for block stmts.
    if (!block) {
      return std::make_tuple(0, 0);
    }

    BufLiveRange analyzer(std::move(b));
    block->accept(&analyzer);
    return analyzer.getLiveRange();
  }

 private:
  std::tuple<int32_t, int32_t> getLiveRange() {
    return std::make_tuple(begin_, end_);
  }

  bool hasBufReads(StmtPtr s) {
    auto loads1 = NodeFinder<Load>::find(s);
    for (const auto& l : loads1) {
      if (l->buf() == buf_) {
        return true;
      }
    }
    auto loads2 = NodeFinder<ExternalCall>::find(s);
    for (const auto& l : loads2) {
      for (const auto& lb : l->buf_args()) {
        if (lb == buf_) {
          return true;
        }
      }
    }
    auto loads3 = NodeFinder<ExternalCallWithAlloc>::find(std::move(s));
    for (const auto& l : loads3) {
      for (const auto& lb : l->buf_args()) {
        if (lb == buf_) {
          return true;
        }
      }
    }
    return false;
  }

  bool hasBufWrites(StmtPtr s) {
    auto writes1 = NodeFinder<Store>::find(s);
    for (const auto& w : writes1) {
      if (w->buf() == buf_) {
        return true;
      }
    }
    auto writes2 = NodeFinder<ExternalCall>::find(s);
    for (const auto& w : writes2) {
      if (w->buf() == buf_) {
        return true;
      }
    }
    auto writes3 = NodeFinder<ExternalCallWithAlloc>::find(std::move(s));
    for (const auto& w : writes3) {
      for (const auto& wb : w->buf_out_args()) {
        if (wb == buf_) {
          return true;
        }
      }
    }
    return false;
  }

  void findAccAndUpdateLiveRange(StmtPtr s) {
    bool has_reads = hasBufReads(s), has_writes = hasBufWrites(std::move(s));
    if (has_reads || has_writes) {
      if (begin_ == -1) {
        begin_ = curr_index_;
      };
      end_ = curr_index_;
    }
  }

  void visit(BlockPtr v) override {
    for (const StmtPtr& s : *v) {
      curr_index_ += 1;
      findAccAndUpdateLiveRange(s);
    }
  }

  BufPtr buf_;
  int32_t begin_ = -1;
  int32_t end_ = -1;
  int32_t curr_index_ = -1;
};

// Analysis pass for the Block backend. It walks the Store statements of a
// program and builds a map, keyed by buffer name, relating multi-dim buffers
// to their flat versions.
class CreateBufferMap : public IRVisitor {
 public:
  // Name -> buffer map accumulated by the visits performed so far.
  const std::unordered_map<std::string, BufPtr>& getBufferMap() const {
    return map_input_to_tensor_bufs_;
  }

 private:
  void visit(StorePtr v) override {
    auto stored = v->value();
    if (auto src_load = to<Load>(stored)) {
      // Store of a plain Load: key by the name of the buffer being read and
      // map it to the buffer being written.
      map_input_to_tensor_bufs_.emplace(
          src_load->buf()->name_hint(), v->buf());
    } else {
      // Otherwise, for now, the stored value must be an Add or a Mul.
      auto as_add = to<Add>(stored);
      auto as_mul = to<Mul>(stored);
      TORCH_INTERNAL_ASSERT(as_add || as_mul, buildErrorMessage());
      // Key by the destination buffer's own name.
      map_input_to_tensor_bufs_.emplace(v->buf()->name_hint(), v->buf());
    }
    stored->accept(this);
  }

  std::unordered_map<std::string, BufPtr> map_input_to_tensor_bufs_;
};

} // namespace tensorexpr
} // namespace jit
} // namespace torch