Why Gemfury? Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

edgify / torch   python

Repository URL to install this package:

Version: 2.0.1+cpu 

/ include / torch / csrc / jit / tensorexpr / ir_simplifier.h

#pragma once

#include <torch/csrc/jit/tensorexpr/bounds_overlap.h>
#include <torch/csrc/jit/tensorexpr/eval.h>
#include <torch/csrc/jit/tensorexpr/hash_provider.h>
#include <torch/csrc/jit/tensorexpr/ir.h>
#include <torch/csrc/jit/tensorexpr/ir_mutator.h>
#include <torch/csrc/jit/tensorexpr/ir_visitor.h>
#include <torch/csrc/jit/tensorexpr/types.h>

#include <utility>

/* IR Simplification
 *
 * Simplifies expressions in two stages:
 *  1. Recursively traverse the tree, combining similar operations into Terms
 * (combined via multiplication) and Polynomials (combined via addition). We
 * reorder the components of each Term or Polynomial into a consistent order to
 * allow combination or cancelling of like terms.
 *  2. Once the format of the tree is minimal, expand each Term into a sequence
 * of Muls, and each Polynomial into a sequence of Adds.
 */

namespace torch {
namespace jit {
namespace tensorexpr {

// Helpers for determining the Dtype of the output of a multi-argument
// Term or Polynomial.
template <class ExprType>
// Promotes the Dtype of the scalar `s` together with every expression in `v`.
// The lane count is first taken from the front element of `v` (the scalar's
// scalar type is kept, its lane count replaced), then each element's Dtype is
// promoted in turn. If `v` is empty, `s->dtype()` is returned unchanged.
// `v` is only read, so it is taken by const reference.
Dtype promoteTypesVec(ExprPtr s, const std::vector<ExprType>& v) {
  Dtype t = s->dtype();
  if (!v.empty()) {
    // Adopt the components' lane count before promoting; promoteTypes
    // presumably rejects mismatched lanes — confirm in types.h.
    t = Dtype(t.scalar_type(), v.front()->dtype().lanes());
  }
  for (const auto& e : v) {
    t = promoteTypes(t, e->dtype());
  }
  return t;
}

template <class ExprType>
// Promotes the Dtypes of all expressions in `v` into a single Dtype, starting
// from the Dtype of the first element.
// Throws malformed_input if `v` is empty, since there is no type to start from.
// `v` is only read, so it is taken by const reference.
Dtype promoteTypesVec(const std::vector<ExprType>& v) {
  if (v.empty()) {
    throw malformed_input("empty list of types");
  }

  Dtype t = v[0]->dtype();
  for (const auto& e : v) {
    t = promoteTypes(t, e->dtype());
  }
  return t;
}

template <class ExprType>
// Map counterpart of promoteTypesVec(s, v): promotes the Dtype of the scalar
// `s` with the Dtype of every mapped expression in `m` (keys are ignored).
// The lane count is taken from whichever entry the unordered_map yields first.
// Returns `s->dtype()` unchanged when `m` is empty.
// `m` is only read, so it is taken by const reference.
Dtype promoteTypesMap(
    ExprPtr s,
    const std::unordered_map<SimplifierHashType, ExprType>& m) {
  Dtype t = s->dtype();
  if (!m.empty()) {
    t = Dtype(t.scalar_type(), m.begin()->second->dtype().lanes());
  }
  for (const auto& e : m) {
    t = promoteTypes(t, e.second->dtype());
  }
  return t;
}

template <class ExprType>
// Base case of the variadic promotion: a single expression contributes its
// own Dtype.
Dtype promoteTypesVar(ExprType e) {
  const Dtype only = e->dtype();
  return only;
}

template <class ExprType, class... Args>
// Promotes the Dtype of `e` with the promoted Dtype of all remaining
// expressions. A constant head takes on the lane count of the tail so that a
// scalar immediate can be combined with vector operands.
Dtype promoteTypesVar(ExprType e, Args... es) {
  Dtype head = e->dtype();
  const Dtype tail = promoteTypesVar(es...);
  if (e->isConstant()) {
    head = Dtype(head.scalar_type(), tail.lanes());
  }
  return promoteTypes(head, tail);
}

// Folds an expression made only of constants by running it through the IR
// evaluator and wrapping the result in an immediate of the same scalar type.
// E.g. evaluateOp(Add(3, 4)) => 7.
// Precondition: `v` must not reference any unbound Vars.
// Scalar types outside AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16) are
// reported through LOG(FATAL).
inline ExprPtr evaluateOp(ExprPtr v) {
  ExprHandle expr_handle(v);
  ExprEval<SimpleIREvaluator> evaluator(expr_handle);
  const ScalarType scalar_kind = v->dtype().scalar_type();

  switch (scalar_kind) {
#define TYPE_CASE(Type, Name)                       \
  case ScalarType::Name: {                          \
    const Type folded = evaluator.value<Type>();    \
    return getImmediateByType(scalar_kind, folded); \
  }
    AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TYPE_CASE);
#undef TYPE_CASE
    default:
      LOG(FATAL) << "Unsupported datatype: " << v->dtype();
      return nullptr;
  }
  return nullptr;
}

// A Term is a product of a constant scalar and a list of variable components,
// i.e. product(scalar, *variables). Every constructor sorts the components by
// hash so that Terms over the same variables share a consistent order and can
// be combined or cancelled.
class Term : public ExprNode<Term> {
 public:
  // Variadic form: each trailing argument becomes one variable component.
  // The scalar must be a constant expression (enforced with CHECK).
  template <class... Args>
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
  Term(HashProvider& hash_provider, ExprPtr scalar_expr, Args... components)
      : ExprNodeBase(promoteTypesVar(scalar_expr, components...)),
        scalar_(scalar_expr),
        hasher_(hash_provider) {
    CHECK(scalar_expr->isConstant());
    addComponent(components...);
    sort();
  }

  // Builds a Term that takes over an existing list of components.
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
  Term(
      HashProvider& hash_provider,
      ExprPtr scalar_expr,
      std::vector<ExprPtr> components)
      : ExprNodeBase(promoteTypesVec(scalar_expr, components)),
        variables_(std::move(components)),
        scalar_(scalar_expr),
        hasher_(hash_provider) {
    sort();
  }

  // Convenience constructor from a map of hash -> var, used when merging Terms.
  // Only the mapped expressions are kept; the hash keys are discarded.
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
  Term(
      HashProvider& hash_provider,
      ExprPtr scalar_expr,
      std::unordered_map<SimplifierHashType, ExprPtr> varmap)
      : ExprNodeBase(promoteTypesMap(scalar_expr, varmap)),
        scalar_(scalar_expr),
        hasher_(hash_provider) {
    for (const auto& entry : varmap) {
      addComponent(entry.second);
    }
    sort();
  }

  // The constant multiplier of this Term.
  ExprPtr scalar() const {
    return scalar_;
  }
  // The variable components, in hash-sorted order.
  const std::vector<ExprPtr>& variables() const {
    return variables_;
  }
  // The hash provider this Term was built with.
  HashProvider& hasher() const {
    return hasher_;
  }

  // Produce a hash of just the variable components of this term, to determine
  // if it can be combined with another term.
  SimplifierHashType hashVars() const;

 private:
  std::vector<ExprPtr> variables_;
  ExprPtr scalar_;
  HashProvider& hasher_;

  // Recursive helpers that append each constructor argument to variables_.
  void addComponent() {}
  void addComponent(ExprPtr head) {
    variables_.emplace_back(std::move(head));
  }
  template <class... Rest>
  void addComponent(ExprPtr head, Rest&&... rest) {
    addComponent(std::move(head));
    addComponent(std::forward<Rest>(rest)...);
  }

  // Sort by hash to normalize order of components.
  void sort();
};

// Polynomial represents a grouping of Exprs by addition.
// E.g. sum(*variables, scalar).
// This would better be called Expression, but, naming conflict...
// Every constructor sorts the Terms by hash to normalize their order.
class Polynomial : public ExprNode<Polynomial> {
 public:
  // Variadic form: each trailing argument becomes one Term of the sum.
  // The scalar must be a constant expression (enforced with CHECK). Passing no
  // trailing Terms is allowed and yields a Polynomial holding only the scalar.
  template <class... Args>
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
  Polynomial(HashProvider& hasher, ExprPtr s, Args... ts)
      : ExprNodeBase(promoteTypesVar(s, ts...)), scalar_(s), hasher_(hasher) {
    CHECK(s->isConstant());
    addTerm(ts...);
    sort();
  }

  // Builds a Polynomial that takes over an existing list of Terms.
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
  Polynomial(HashProvider& hasher, ExprPtr s, std::vector<TermPtr> v)
      : ExprNodeBase(promoteTypesVec(s, v)),
        variables_(std::move(v)),
        scalar_(s),
        hasher_(hasher) {
    sort();
  }

  // Helper constructor for list of terms with no scalar component: the scalar
  // becomes an immediate 0 of the promoted Dtype. `terms` must be non-empty,
  // otherwise promoteTypesVec throws malformed_input.
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
  Polynomial(HashProvider& hasher, std::vector<TermPtr> terms)
      : ExprNodeBase(promoteTypesVec(terms)),
        variables_(std::move(terms)),
        scalar_(getImmediateByType(dtype(), 0)),
        hasher_(hasher) {
    sort();
  }

  // Convenience constructor for map of hash -> var, used when merging
  // Polynomials. Only the mapped Terms are kept; the hash keys are discarded.
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
  Polynomial(
      HashProvider& hasher,
      ExprPtr s,
      std::unordered_map<SimplifierHashType, TermPtr> varmap)
      : ExprNodeBase(promoteTypesMap(s, varmap)), scalar_(s), hasher_(hasher) {
    for (const auto& p : varmap) {
      addTerm(p.second);
    }
    sort();
  }

  // The constant addend of this Polynomial.
  ExprPtr scalar() const {
    return scalar_;
  }
  // The Terms of the sum, in hash-sorted order.
  const std::vector<TermPtr>& variables() const {
    return variables_;
  }
  // The hash provider this Polynomial was built with.
  HashProvider& hasher() const {
    return hasher_;
  }

  SimplifierHashType hashVars() const;

 private:
  std::vector<TermPtr> variables_;
  ExprPtr scalar_;
  HashProvider& hasher_;

  // Recursion base for the variadic constructor. Without it, constructing a
  // Polynomial from a scalar and zero Terms fails to compile (Term, MaxTerm
  // and MinTerm provide the equivalent addComponent() overload).
  void addTerm() {}
  void addTerm(TermPtr t) {
    variables_.push_back(std::move(t));
  }
  template <class... Ts>
  void addTerm(TermPtr t, Ts&&... ts) {
    addTerm(std::move(t));
    addTerm(std::forward<Ts>(ts)...);
  }

  // Sort by hash to normalize order of terms.
  void sort();
};

// RoundOff(lhs, rhs) is a single node standing for the rounding pattern
// Mul(Div(lhs, rhs), rhs). It is created when that pattern is matched
// (PolynomialTransformer::isRoundOff) and expanded back into the Mul/Div pair
// by TermExpander.
class RoundOff : public BinaryOpNode<RoundOff> {
 public:
  RoundOff(ExprPtr lhs, ExprPtr rhs)
      : BinaryOpNode(std::move(lhs), std::move(rhs), IRNodeType::kOther) {}
};

// MaxTerm groups a Max over an optional scalar and a list of variable
// components. Components with equal hashes are collapsed on construction.
class MaxTerm : public ExprNode<MaxTerm> {
 public:
  // Variadic form: each trailing argument becomes one component. When
  // `scalar_expr` is null, the Dtype is derived from the components alone.
  template <class... Args>
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
  MaxTerm(
      HashProvider& hash_provider,
      ExprPtr scalar_expr,
      bool nan_propagation,
      Args... components)
      : ExprNodeBase(
            scalar_expr == nullptr
                ? promoteTypesVar(components...)
                : promoteTypesVar(scalar_expr, components...)),
        scalar_(scalar_expr),
        hasher_(hash_provider),
        propagate_nans_(nan_propagation) {
    addComponent(components...);
    uniquefy();
  }

  // Builds a MaxTerm that takes over an existing list of components.
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
  MaxTerm(
      HashProvider& hash_provider,
      ExprPtr scalar_expr,
      bool nan_propagation,
      std::vector<ExprPtr> components)
      : ExprNodeBase(
            scalar_expr == nullptr
                ? promoteTypesVec(components)
                : promoteTypesVec(scalar_expr, components)),
        variables_(std::move(components)),
        scalar_(scalar_expr),
        hasher_(hash_provider),
        propagate_nans_(nan_propagation) {
    uniquefy();
  }

  // The flag passed at construction; presumably mirrors Max's NaN-propagation
  // semantics — confirm against its consumers.
  bool propagate_nans() const {
    return propagate_nans_;
  }

  // The scalar operand; may be null.
  ExprPtr scalar() const {
    return scalar_;
  }
  const std::vector<ExprPtr>& variables() const {
    return variables_;
  }
  HashProvider& hasher() const {
    return hasher_;
  }

 private:
  std::vector<ExprPtr> variables_;
  ExprPtr scalar_;
  HashProvider& hasher_;
  bool propagate_nans_;

  // Recursive helpers that append each constructor argument to variables_.
  void addComponent() {}
  void addComponent(ExprPtr head) {
    variables_.emplace_back(std::move(head));
  }
  template <class... Rest>
  void addComponent(ExprPtr head, Rest&&... rest) {
    addComponent(std::move(head));
    addComponent(std::forward<Rest>(rest)...);
  }

  // Uniquefy the terms using their hash.
  void uniquefy();
};

// MinTerm groups a Min over an optional scalar and a list of variable
// components. Components with equal hashes are collapsed on construction.
class MinTerm : public ExprNode<MinTerm> {
 public:
  // Variadic form: each trailing argument becomes one component. When
  // `scalar_expr` is null, the Dtype is derived from the components alone.
  template <class... Args>
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
  MinTerm(
      HashProvider& hash_provider,
      ExprPtr scalar_expr,
      bool nan_propagation,
      Args... components)
      : ExprNodeBase(
            scalar_expr == nullptr
                ? promoteTypesVar(components...)
                : promoteTypesVar(scalar_expr, components...)),
        scalar_(scalar_expr),
        hasher_(hash_provider),
        propagate_nans_(nan_propagation) {
    addComponent(components...);
    uniquefy();
  }

  // Builds a MinTerm that takes over an existing list of components.
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
  MinTerm(
      HashProvider& hash_provider,
      ExprPtr scalar_expr,
      bool nan_propagation,
      std::vector<ExprPtr> components)
      : ExprNodeBase(
            scalar_expr == nullptr
                ? promoteTypesVec(components)
                : promoteTypesVec(scalar_expr, components)),
        variables_(std::move(components)),
        scalar_(scalar_expr),
        hasher_(hash_provider),
        propagate_nans_(nan_propagation) {
    uniquefy();
  }

  // The flag passed at construction; presumably mirrors Min's NaN-propagation
  // semantics — confirm against its consumers.
  bool propagate_nans() const {
    return propagate_nans_;
  }

  // The scalar operand; may be null.
  ExprPtr scalar() const {
    return scalar_;
  }
  const std::vector<ExprPtr>& variables() const {
    return variables_;
  }
  HashProvider& hasher() const {
    return hasher_;
  }

 private:
  std::vector<ExprPtr> variables_;
  ExprPtr scalar_;
  HashProvider& hasher_;
  bool propagate_nans_;

  // Recursive helpers that append each constructor argument to variables_.
  void addComponent() {}
  void addComponent(ExprPtr head) {
    variables_.emplace_back(std::move(head));
  }
  template <class... Rest>
  void addComponent(ExprPtr head, Rest&&... rest) {
    addComponent(std::move(head));
    addComponent(std::forward<Rest>(rest)...);
  }

  // Uniquefy the terms using their hash.
  void uniquefy();
};

// Bound information per variable, used for context-sensitive IR simplification.
using VarBoundInfo = std::unordered_map<VarPtr, analysis::Bound>;

// Context-sensitive simplifier. Its For-loop handler adds boundary info for
// loop index variables; the Div/Mod/CompareSelect/IfThenElse handlers
// presumably consult that info (see the .cpp for the exact rules).
class TORCH_API SimplifierUnderContext : public IRMutator {
 public:
  ~SimplifierUnderContext() override = default;

  // Add boundary info for index variables in for-loops.
  StmtPtr mutate(ForPtr v) override;

  // Expression rewrites that may rely on the recorded bounds.
  ExprPtr mutate(DivPtr v) override;
  ExprPtr mutate(ModPtr v) override;
  ExprPtr mutate(CompareSelectPtr v) override;
  ExprPtr mutate(IfThenElsePtr v) override;

 protected:
  // Fills *loop_bound_info for `expr`; the bool result presumably reports
  // whether any bound info was found — verify in the implementation.
  bool getLoopBoundInfo(const ExprPtr& expr, analysis::Bound* loop_bound_info);

  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  HashProvider hasher_;
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  VarBoundInfo var_bound_info_;
};

// Common base of the polynomial passes (PolynomialTransformer and
// TermExpander). Statement-level simplification lives here because it should
// occur in both modes.
class TORCH_API PolynomialBase : public IRMutator {
 public:
  ~PolynomialBase() override = default;

  // Statement-level rewrites shared by every derived pass.
  StmtPtr mutate(BlockPtr v) override;
  StmtPtr mutate(CondPtr v) override;
  StmtPtr mutate(ForPtr v) override;

  // Trivially factorize terms by the GCD of their scalar components.
  TermPtr factorizePolynomial(PolynomialPtr poly);

  // Exposes the hash provider owned by this pass.
  HashProvider& hasher() {
    return hasher_;
  }

 protected:
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  HashProvider hasher_;
};

// Simplifies the IR by folding arithmetic into Terms and Polynomials and
// combining expressions that share common terms.
class TORCH_API PolynomialTransformer : public PolynomialBase {
 public:
  using PolynomialBase::mutate;

  // ---- Term / Polynomial arithmetic helpers ----

  // Inserts `term` into `varmap`; on a hash collision the term is combined
  // with the existing entry and the map is updated.
  void addOrUpdateTerm(
      std::unordered_map<SimplifierHashType, TermPtr>& varmap,
      TermPtr term);

  // Adds two Polynomials, combining Terms that represent the same variables.
  ExprPtr addPolynomials(PolynomialPtr lhs, PolynomialPtr rhs);

  // Inserts a new Term into `poly`, merging it with an existing Term that has
  // the same variables.
  ExprPtr insertTerm(PolynomialPtr poly, TermPtr term);

  // Subtracts one Term from another, cancelling if possible.
  ExprPtr subTerms(TermPtr lhs, TermPtr rhs, bool negated);

  // Subtracts the RHS Polynomial from the LHS Polynomial, cancelling where
  // possible.
  ExprPtr subPolynomials(PolynomialPtr lhs, PolynomialPtr rhs);

  // Multiplies two Terms, usually producing a new Term whose variable list is
  // the concatenation of both.
  TermPtr mulTerms(TermPtr lhs, TermPtr rhs);

  // Multiplies a Polynomial by a Term.
  ExprPtr polyByTerm(PolynomialPtr poly, TermPtr term);

  // Matches a rounding pattern and creates a RoundOff when one is found.
  ExprPtr isRoundOff(ExprPtr lhs, ExprPtr rhs);

  // Inserts a new component into a Term, simplifying if possible.
  ExprPtr insertIntoTerm(TermPtr term, ExprPtr expr);

  // ---- IRMutator overrides ----

  ExprPtr mutate(AddPtr v) override; // merge and simplify addition
  ExprPtr mutate(SubPtr v) override; // merge and simplify subtraction
  ExprPtr mutate(MulPtr v) override; // merge and simplify multiplication
  ExprPtr mutate(DivPtr v) override;
  ExprPtr mutate(ModPtr v) override;
  ExprPtr mutate(AndPtr v) override;
  ExprPtr mutate(XorPtr v) override;
  ExprPtr mutate(LshiftPtr v) override;
  ExprPtr mutate(RshiftPtr v) override;
  ExprPtr mutate(MaxPtr v) override;
  ExprPtr mutate(MinPtr v) override;
  ExprPtr mutate(CompareSelectPtr v) override;
  ExprPtr mutate(IntrinsicsPtr v) override;
  ExprPtr mutate(CastPtr v) override;
  ExprPtr mutate(IfThenElsePtr v) override;

  // ---- Static entry points ----

  static ExprPtr simplify(ExprPtr e);
  static ExprHandle simplify(const ExprHandle& e);
  static StmtPtr simplify(StmtPtr e);
};

// Expands Terms and Polynomial expressions into primitive operations.
// Does some simple factorization and reordering.
class TORCH_API TermExpander : public PolynomialBase {
  // Borrowed simplifier; this class never deletes it, so the pointee must stay
  // valid for as long as the expander uses it.
  PolynomialTransformer* simplifier_;
  // Presumably the buffer vars of allocations removed by mutate(AllocatePtr);
  // while it is non-empty, check_safe() reports false.
  std::set<VarPtr> eliminated_allocations_;

 public:
  using PolynomialBase::mutate;
  // `explicit`: a raw PolynomialTransformer* should not silently convert into
  // a TermExpander.
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
  explicit TermExpander(PolynomialTransformer* simplifier)
      : simplifier_(simplifier) {}

  // Returns true when no allocation has been eliminated by this expander.
  // Does not modify the expander, hence `const`.
  bool check_safe() const {
    return eliminated_allocations_.empty();
  }

  // Expand Terms out to a series of Muls.
  ExprPtr mutate(TermPtr v) override;

  // Expand Polynomials out to a series of Adds.
  ExprPtr mutate(PolynomialPtr v) override;

  // Expand MaxTerms to a series of Max ops.
  ExprPtr mutate(MaxTermPtr v) override;

  // Expand MinTerms to a series of Min ops.
  ExprPtr mutate(MinTermPtr v) override;

  // Expand RoundOff to its components: Mul(Div(lhs, rhs), rhs).
  ExprPtr mutate(RoundOffPtr v) override;

  // Eliminate zero length allocations.
  StmtPtr mutate(AllocatePtr v) override;
  StmtPtr mutate(FreePtr v) override;

  // Override to enable condition fusing.
  BlockPtr fuseConditions(BlockPtr v);
  StmtPtr fuseSyncThreads(BlockPtr block);
  StmtPtr mutate(BlockPtr v) override;
};

// Public entry points for simplifying statements and expressions.
class TORCH_API IRSimplifier {
 public:
  static StmtPtr simplify(StmtPtr s);
  static ExprPtr simplify(ExprPtr e);

  // Convenience overload: simplifies the node wrapped by `e` and wraps the
  // result in a new ExprHandle.
  static ExprHandle simplify(const ExprHandle& e) {
    ExprPtr simplified = simplify(e.node());
    return ExprHandle(std::move(simplified));
  }
};

// Flattens the buf and performs the simplifier on the flattened dims.
// Exported with TORCH_API for consistency with exprEquals below.
TORCH_API ExprPtr buf_flat_size(BufPtr v);
// Returns true if expressions A and B can be simplified to an equal expression.
TORCH_API bool exprEquals(ExprPtr A, ExprPtr B);

} // namespace tensorexpr
} // namespace jit
} // namespace torch