Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

neilisaac / torch   python

Repository URL to install this package:

Version: 1.8.0 

/ include / torch / csrc / jit / tensorexpr / ir_simplifier.h

#pragma once

#include <torch/csrc/jit/tensorexpr/eval.h>
#include <torch/csrc/jit/tensorexpr/hash_provider.h>
#include <torch/csrc/jit/tensorexpr/ir.h>
#include <torch/csrc/jit/tensorexpr/ir_mutator.h>
#include <torch/csrc/jit/tensorexpr/ir_visitor.h>
#include <torch/csrc/jit/tensorexpr/types.h>

/* IR Simplification
 *
 * Simplifies expressions in two stages:
 *  1. Recursively traverse the tree, combining similar operations into Terms
 * (interacted via Multiplication) and Polynomials (interacted via Addition). We
 * reorder the components of each Term or Polynomial into a consistent order to
 * allow combination or cancelling of like terms.
 *  2. Once the format of the tree is minimal, expand each Term into a sequence
 * of Muls, and each Polynomial into a sequence of Adds.
 */

namespace torch {
namespace jit {
namespace tensorexpr {

// A bunch of helpers for determining the Dtype of the output of a
// multi-argument Term or Polynomial.
// Computes the Dtype of a node built from the scalar `s` and the component
// list `v`.
//
// The result starts as the dtype of `s`. When `v` is non-empty, that dtype is
// first rebuilt with the scalar type of `s` and the lane count of the first
// component (presumably so a scalar constant adopts the vector width of the
// components — confirm against promoteTypes), and is then promoted against
// every component in order. When `v` is empty the dtype of `s` is returned
// unchanged.
//
// `v` is only read, so it is taken by const reference; this also allows
// callers to pass const vectors and temporaries.
template <class ExprType>
Dtype promoteTypesVec(const Expr* s, const std::vector<const ExprType*>& v) {
  Dtype t = s->dtype();
  if (v.empty()) {
    return t;
  }

  // Adopt the lane count of the first component before promoting.
  t = Dtype(t.scalar_type(), v.front()->dtype().lanes());
  for (const ExprType* e : v) {
    t = promoteTypes(t, e->dtype());
  }
  return t;
}

// Computes the promoted Dtype of a list of components that has no separate
// scalar: starts from the dtype of the first element and promotes it against
// every element of `v` in order.
//
// Throws malformed_input if `v` is empty, since there is no dtype to start
// from.
//
// `v` is only read, so it is taken by const reference; this also allows
// callers to pass const vectors and temporaries.
template <class ExprType>
Dtype promoteTypesVec(const std::vector<const ExprType*>& v) {
  if (v.empty()) {
    throw malformed_input("empty list of types");
  }

  Dtype t = v[0]->dtype();
  for (const ExprType* e : v) {
    t = promoteTypes(t, e->dtype());
  }
  return t;
}

// Map flavour of promoteTypesVec(s, v): computes the Dtype of a node built
// from the scalar `s` and the values of the hash -> component map `m`.
//
// The result starts as the dtype of `s`. When `m` is non-empty, that dtype is
// rebuilt with the lane count of the first entry visited and then promoted
// against every entry's dtype. When `m` is empty the dtype of `s` is returned
// unchanged. Entries are visited in the map's (unspecified) iteration order.
//
// `m` is only read, so it is taken by const reference; this also allows
// callers to pass const maps and temporaries.
template <class ExprType>
Dtype promoteTypesMap(
    const Expr* s,
    const std::unordered_map<SimplifierHashType, const ExprType*>& m) {
  Dtype t = s->dtype();
  bool first = true;
  for (const auto& e : m) {
    if (first) {
      // Adopt the lane count of the first entry before promoting.
      t = Dtype(t.scalar_type(), e.second->dtype().lanes());
      first = false;
    }
    t = promoteTypes(t, e.second->dtype());
  }
  return t;
}

// Base case of the variadic promotion helper: with a single expression there
// is nothing to promote against, so its own dtype is the answer.
template <class ExprType>
Dtype promoteTypesVar(const ExprType* e) {
  Dtype only = e->dtype();
  return only;
}

// Recursive case of the variadic promotion helper: promotes the dtype of `e`
// against the promoted dtype of the remaining expressions `es...`.
//
// If `e` is a constant, its dtype is rebuilt with the lane count of the
// promoted remainder before the final promotion.
template <class ExprType, class... Args>
Dtype promoteTypesVar(const ExprType* e, Args... es) {
  // Dtype of everything after `e`.
  Dtype rest = promoteTypesVar(es...);

  Dtype head = e->dtype();
  if (e->isConstant()) {
    head = Dtype(head.scalar_type(), rest.lanes());
  }

  return promoteTypes(head, rest);
}

// Creates a new binary-op Expr of kind `expr_type` with operands `lhs` and
// `rhs`.
//
// `option` is forwarded as the third constructor argument of Max and Min and
// is ignored for every other node type (presumably the NaN-propagation flag —
// confirm against the Max/Min definitions).
//
// Supported kinds: kAdd, kSub, kMul, kDiv, kMod, kMax, kMin, kAnd, kOr, kXor,
// kLshift, kRshift. Any other kind is reported through LOG(FATAL); the
// trailing nullptr only exists to satisfy the return type.
inline const Expr* newBinaryOpOfType(
    IRNodeType expr_type,
    const Expr* lhs,
    const Expr* rhs,
    bool option) {
  switch (expr_type) {
    case IRNodeType::kAdd:
      return new Add(lhs, rhs);
    case IRNodeType::kSub:
      return new Sub(lhs, rhs);
    case IRNodeType::kMul:
      return new Mul(lhs, rhs);
    case IRNodeType::kDiv:
      return new Div(lhs, rhs);
    case IRNodeType::kMod:
      return new Mod(lhs, rhs);
    case IRNodeType::kMax:
      return new Max(lhs, rhs, option);
    case IRNodeType::kMin:
      return new Min(lhs, rhs, option);
    case IRNodeType::kAnd:
      return new And(lhs, rhs);
    // Bitwise Or was missing from the original list although And and Xor were
    // handled; without this case an Or request ended in LOG(FATAL).
    case IRNodeType::kOr:
      return new Or(lhs, rhs);
    case IRNodeType::kXor:
      return new Xor(lhs, rhs);
    case IRNodeType::kLshift:
      return new Lshift(lhs, rhs);
    case IRNodeType::kRshift:
      return new Rshift(lhs, rhs);
    default:
      LOG(FATAL) << "unsupported expr_type: " << static_cast<int>(expr_type);
      return nullptr;
  }
}

// Constant-folds an expression by running it through the IR evaluator.
// E.g. evaluateOp(Add(3, 4)) => 7.
//
// Precondition: `v` must not contain any unbound Vars.
// Returns an immediate of v's scalar type holding the evaluated value. A
// scalar type outside AT_FORALL_SCALAR_TYPES_AND2(Half, Bool, ...) is reported
// through LOG(FATAL).
inline Expr* evaluateOp(const Expr* v) {
  ExprHandle wrapped(v);
  ExprEval<SimpleIREvaluator> evaluator(wrapped);

  const auto result_type = v->dtype().scalar_type();
  switch (result_type) {
// One case per supported scalar type: read the evaluated value as the matching
// C++ type and wrap it in an immediate of the same scalar type.
#define TYPE_CASE(Type, Name)                         \
  case ScalarType::Name: {                            \
    Type folded = evaluator.value<Type>();            \
    return getImmediateByType(result_type, folded);   \
  }
    AT_FORALL_SCALAR_TYPES_AND2(Half, Bool, TYPE_CASE);
#undef TYPE_CASE
    default:
      break;
  }

  // Only reached when no supported scalar type matched.
  LOG(FATAL) << "Unsupported datatype: " << v->dtype();
  return nullptr;
}

// A Term groups Exprs through multiplication: one scalar plus a list of
// variable components, i.e. product(scalar, *variables).
class Term : public ExprNode<Term> {
 public:
  // Builds a Term from a scalar `s` and components given as separate
  // arguments. The node dtype is promoted across `s` and all components.
  // `s` must be a constant (enforced with CHECK).
  template <class... Args>
  Term(HashProvider& hasher, const Expr* s, Args... ts)
      : ExprNodeBase(promoteTypesVar(s, ts...)), scalar_(s), hasher_(hasher) {
    CHECK(s->isConstant());
    addComponent(ts...);
    sort();
  }

  // Builds a Term from a scalar `s` and a ready-made component list, which is
  // moved into the Term.
  Term(HashProvider& hasher, const Expr* s, std::vector<const Expr*> v)
      : ExprNodeBase(promoteTypesVec(s, v)),
        variables_(std::move(v)),
        scalar_(s),
        hasher_(hasher) {
    sort();
  }

  // Convenience constructor from a map of hash -> var, used when merging Terms.
  // Only the mapped values are kept; the hash keys are discarded.
  Term(
      HashProvider& hasher,
      const Expr* s,
      std::unordered_map<SimplifierHashType, const Expr*> varmap)
      : ExprNodeBase(promoteTypesMap(s, varmap)), scalar_(s), hasher_(hasher) {
    for (const auto& entry : varmap) {
      variables_.push_back(entry.second);
    }
    sort();
  }

  // The scalar factor of the product.
  const Expr* scalar() const {
    return scalar_;
  }

  // The non-scalar factors, in the order established by sort().
  const std::vector<const Expr*>& variables() const {
    return variables_;
  }

  // The hash provider this Term was constructed with.
  HashProvider& hasher() const {
    return hasher_;
  }

  // Produce a hash of just the variable components of this term, to determine
  // if it can be combined with another term.
  SimplifierHashType hashVars() const;

 private:
  std::vector<const Expr*> variables_;
  const Expr* scalar_;
  HashProvider& hasher_;

  // No components to add: lets the variadic constructor accept an empty pack.
  void addComponent() {}

  // Appends a single component.
  void addComponent(const Expr* e) {
    variables_.push_back(e);
  }

  // Appends `e` and then each of `es...`, left to right.
  template <class... Es>
  void addComponent(const Expr* e, Es... es) {
    addComponent(e);
    // Braced-list elements are evaluated in order, so components are appended
    // in argument order; the leading 0 keeps the array non-empty.
    const int expand[] = {0, (addComponent(es), 0)...};
    (void)expand;
  }

  // Sort by hash to normalize order of components.
  void sort();
};

// Polynomial represents a grouping of Exprs by addition: a list of Terms plus
// one scalar, i.e. sum(*variables, scalar).
// This would better be called Expression, but, naming conflict...
class Polynomial : public ExprNode<Polynomial> {
 public:
  // Builds a Polynomial from a scalar `s` and Terms given as separate
  // arguments (possibly none). The node dtype is promoted across `s` and all
  // Terms. `s` must be a constant (enforced with CHECK).
  template <class... Args>
  Polynomial(HashProvider& hasher, const Expr* s, Args... ts)
      : ExprNodeBase(promoteTypesVar(s, ts...)), scalar_(s), hasher_(hasher) {
    CHECK(s->isConstant());
    addTerm(ts...);
    sort();
  }

  // Builds a Polynomial from a scalar `s` and a ready-made list of Terms,
  // which is moved into the Polynomial.
  Polynomial(HashProvider& hasher, const Expr* s, std::vector<const Term*> v)
      : ExprNodeBase(promoteTypesVec(s, v)),
        variables_(std::move(v)),
        scalar_(s),
        hasher_(hasher) {
    sort();
  }

  // Helper constructor for list of terms with no scalar component.
  // The scalar is set to an immediate 0 of the promoted dtype. `terms` must be
  // non-empty: promoteTypesVec throws malformed_input on an empty list.
  Polynomial(HashProvider& hasher, std::vector<const Term*> terms)
      : ExprNodeBase(promoteTypesVec(terms)),
        variables_(std::move(terms)),
        scalar_(getImmediateByType(dtype(), 0)),
        hasher_(hasher) {
    sort();
  }

  // Convenience constructor for map of hash -> var, used when merging
  // Polynomials. Only the mapped Terms are kept; the hash keys are discarded.
  Polynomial(
      HashProvider& hasher,
      const Expr* s,
      std::unordered_map<SimplifierHashType, const Term*> varmap)
      : ExprNodeBase(promoteTypesMap(s, varmap)), scalar_(s), hasher_(hasher) {
    for (auto& p : varmap) {
      addTerm(p.second);
    }
    sort();
  }

  // The scalar addend of the sum.
  const Expr* scalar() const {
    return scalar_;
  }

  // The Term addends, in the order established by sort().
  const std::vector<const Term*>& variables() const {
    return variables_;
  }

  // The hash provider this Polynomial was constructed with.
  HashProvider& hasher() const {
    return hasher_;
  }

  // Hash of the variable (Term) part of this Polynomial; defined out of line.
  SimplifierHashType hashVars() const;

 private:
  std::vector<const Term*> variables_;
  const Expr* scalar_;
  HashProvider& hasher_;

  // No terms to add. Mirrors Term::addComponent() so that the variadic
  // constructor also compiles when called with an empty term pack.
  void addTerm() {}

  // Appends a single Term.
  void addTerm(const Term* t) {
    variables_.push_back(t);
  }

  // Appends `t` and then each of `ts...`, left to right.
  template <class... Ts>
  void addTerm(const Term* t, Ts... ts) {
    addTerm(t);
    addTerm(ts...);
  }

  // Sort by hash to normalize order of terms.
  void sort();
};

// Binary node tagged IRNodeType::kRoundOff, holding a `lhs` and `rhs` operand.
// NOTE(review): the arithmetic meaning is not defined in this header —
// presumably "lhs rounded down to a multiple of rhs"; confirm in the
// simplifier implementation.
class RoundOff : public BinaryOpNode<RoundOff> {
 public:
  RoundOff(const Expr* lhs, const Expr* rhs)
      : BinaryOpNode<RoundOff>(lhs, rhs, IRNodeType::kRoundOff) {}
};

// Groups an optional scalar and a list of variable components, plus a
// propagate_nans flag. Presumably models max(scalar, *variables) — confirm in
// the simplifier implementation. The scalar may be null, in which case the
// node dtype is promoted from the components alone.
class MaxTerm : public ExprNode<MaxTerm> {
 public:
  // Builds a MaxTerm from an optional scalar `s`, the NaN flag `p`, and
  // components given as separate arguments.
  template <class... Args>
  MaxTerm(HashProvider& hasher, const Expr* s, bool p, Args... ts)
      : ExprNodeBase(s ? promoteTypesVar(s, ts...) : promoteTypesVar(ts...)),
        scalar_(s),
        hasher_(hasher),
        propagate_nans_(p) {
    addComponent(ts...);
    uniquefy();
  }

  // Builds a MaxTerm from an optional scalar `s`, the NaN flag `p`, and a
  // ready-made component list, which is moved into the MaxTerm.
  MaxTerm(
      HashProvider& hasher,
      const Expr* s,
      bool p,
      std::vector<const Expr*> v)
      : ExprNodeBase(s ? promoteTypesVec(s, v) : promoteTypesVec(v)),
        variables_(std::move(v)),
        scalar_(s),
        hasher_(hasher),
        propagate_nans_(p) {
    uniquefy();
  }

  // The flag passed as `p` at construction.
  bool propagate_nans() const {
    return propagate_nans_;
  }

  // The scalar component; may be null.
  const Expr* scalar() const {
    return scalar_;
  }

  // The non-scalar components, after uniquefy() has run.
  const std::vector<const Expr*>& variables() const {
    return variables_;
  }

  // The hash provider this MaxTerm was constructed with.
  HashProvider& hasher() const {
    return hasher_;
  }

 private:
  std::vector<const Expr*> variables_;
  const Expr* scalar_;
  HashProvider& hasher_;
  bool propagate_nans_;

  // No components to add: terminates an empty argument pack.
  void addComponent() {}

  // Appends a single component.
  void addComponent(const Expr* e) {
    variables_.push_back(e);
  }

  // Appends `e` and then each of `es...`, left to right.
  template <class... Es>
  void addComponent(const Expr* e, Es... es) {
    addComponent(e);
    // Braced-list elements are evaluated in order, so components are appended
    // in argument order; the leading 0 keeps the array non-empty.
    const int expand[] = {0, (addComponent(es), 0)...};
    (void)expand;
  }

  // Uniquefy the terms using their hash.
  void uniquefy();
};
Loading ...