#pragma once
#include <torch/csrc/jit/tensorexpr/eval.h>
#include <torch/csrc/jit/tensorexpr/hash_provider.h>
#include <torch/csrc/jit/tensorexpr/ir.h>
#include <torch/csrc/jit/tensorexpr/ir_mutator.h>
#include <torch/csrc/jit/tensorexpr/ir_visitor.h>
#include <torch/csrc/jit/tensorexpr/types.h>
/* IR Simplification
 *
 * Simplifies expressions in two stages:
 * 1. Recursively traverse the expression tree, combining similar operations
 *    into Terms (combined via multiplication) and Polynomials (combined via
 *    addition). We reorder the components of each Term or Polynomial into a
 *    consistent order to allow combination or cancellation of like terms.
 * 2. Once the format of the tree is minimal, expand each Term into a sequence
 *    of Muls, and each Polynomial into a sequence of Adds.
 */
namespace torch {
namespace jit {
namespace tensorexpr {
// Helpers for determining the Dtype of the output of a multi-argument Term,
// Polynomial or MaxTerm.
// Promotes the dtype of `s` (the scalar component of a Term, Polynomial or
// MaxTerm) together with the dtypes of every element of `v`.
//
// When `v` is non-empty, `s` first adopts the lane count of the first element
// of `v`; its own lane count is discarded. If `v` is empty, `s->dtype()` is
// returned unchanged.
//
// `v` is only read, so it is taken by const reference. This also lets callers
// pass const containers and temporaries.
template <class ExprType>
Dtype promoteTypesVec(const Expr* s, const std::vector<const ExprType*>& v) {
  Dtype t = s->dtype();
  if (!v.empty()) {
    // Give the scalar the lane count of the first operand before promoting.
    t = Dtype(t.scalar_type(), v.front()->dtype().lanes());
  }
  for (const auto* e : v) {
    t = promoteTypes(t, e->dtype());
  }
  return t;
}
// Promotes the dtypes of all elements of `v` into a single Dtype, starting
// from the dtype of the first element.
//
// Throws malformed_input if `v` is empty, since there is no type to start from.
// `v` is only read, so it is taken by const reference.
template <class ExprType>
Dtype promoteTypesVec(const std::vector<const ExprType*>& v) {
  if (v.empty()) {
    throw malformed_input("empty list of types");
  }
  Dtype t = v[0]->dtype();
  for (const auto* e : v) {
    t = promoteTypes(t, e->dtype());
  }
  return t;
}
// Promotes the dtype of `s` together with the dtypes of every value in `m`.
//
// As in promoteTypesVec, `s` first adopts the lane count of one operand. Here
// that is the first entry in `m`'s iteration order. Because `m` is an
// unordered_map, which entry that is depends on the map's iteration order.
// This only matters if the values in `m` do not all share one lane count.
// If `m` is empty, `s->dtype()` is returned unchanged.
//
// `m` is only read, so it is taken by const reference.
template <class ExprType>
Dtype promoteTypesMap(
    const Expr* s,
    const std::unordered_map<SimplifierHashType, const ExprType*>& m) {
  Dtype t = s->dtype();
  if (!m.empty()) {
    // Same entry the original loop used on its first iteration.
    t = Dtype(t.scalar_type(), m.begin()->second->dtype().lanes());
  }
  for (const auto& entry : m) {
    t = promoteTypes(t, entry.second->dtype());
  }
  return t;
}
// Base case of the variadic promoteTypesVar: a lone expression's dtype needs
// no promotion.
template <class ExprType>
Dtype promoteTypesVar(const ExprType* e) {
  Dtype single = e->dtype();
  return single;
}
// Variadic dtype promotion over individual expressions, folded right to left.
// The remaining arguments are promoted first, then combined with the head `e`.
// If `e` is a constant, it takes on the lane count of the remaining arguments
// before promotion.
template <class ExprType, class... Args>
Dtype promoteTypesVar(const ExprType* e, Args... es) {
  const Dtype rest = promoteTypesVar(es...);
  const Dtype own = e->dtype();
  const Dtype adjusted =
      e->isConstant() ? Dtype(own.scalar_type(), rest.lanes()) : own;
  return promoteTypes(adjusted, rest);
}
// Creates a new binary Expr node of kind `expr_type` with operands `lhs` and
// `rhs`.
//
// `option` is forwarded only to Max and Min, as their third constructor
// argument. It is ignored for every other kind. NOTE(review): it is presumably
// the NaN-propagation flag (compare MaxTerm::propagate_nans) — confirm.
// Any kind without a case below is reported with LOG(FATAL), and nullptr is
// returned if control comes back from that log call.
// The node is allocated with a raw `new`. NOTE(review): ownership is not
// visible here; presumably IR nodes are managed elsewhere — confirm.
inline const Expr* newBinaryOpOfType(
    IRNodeType expr_type,
    const Expr* lhs,
    const Expr* rhs,
    bool option) {
  const Expr* node = nullptr;
  switch (expr_type) {
    case IRNodeType::kAdd:
      node = new Add(lhs, rhs);
      break;
    case IRNodeType::kSub:
      node = new Sub(lhs, rhs);
      break;
    case IRNodeType::kMul:
      node = new Mul(lhs, rhs);
      break;
    case IRNodeType::kDiv:
      node = new Div(lhs, rhs);
      break;
    case IRNodeType::kMod:
      node = new Mod(lhs, rhs);
      break;
    case IRNodeType::kMax:
      node = new Max(lhs, rhs, option);
      break;
    case IRNodeType::kMin:
      node = new Min(lhs, rhs, option);
      break;
    case IRNodeType::kAnd:
      node = new And(lhs, rhs);
      break;
    case IRNodeType::kXor:
      node = new Xor(lhs, rhs);
      break;
    case IRNodeType::kLshift:
      node = new Lshift(lhs, rhs);
      break;
    case IRNodeType::kRshift:
      node = new Rshift(lhs, rhs);
      break;
    default:
      LOG(FATAL) << "unsupported expr_type: " << static_cast<int>(expr_type);
      break;
  }
  return node;
}
// Uses the evaluator to fold an Expression with constant terms.
// E.g. evaluateOp(Add(3, 4)) => 7.
// Expr v must not have any unbound Vars.
//
// The folded value is wrapped back into an immediate of v's scalar type.
// Scalar types outside AT_FORALL_SCALAR_TYPES_AND2(Half, Bool, ...) are
// reported via LOG(FATAL), and nullptr is returned if control comes back.
inline Expr* evaluateOp(const Expr* v) {
  const ScalarType result_type = v->dtype().scalar_type();
  // Keep the handle as a named local. Writing
  // `eval(ExprHandle(v))` would parse as a function declaration.
  ExprHandle handle(v);
  ExprEval<SimpleIREvaluator> eval(handle);
  switch (result_type) {
#define TYPE_CASE(Type, Name)                       \
  case ScalarType::Name: {                          \
    const Type folded = eval.value<Type>();         \
    return getImmediateByType(result_type, folded); \
  }
    AT_FORALL_SCALAR_TYPES_AND2(Half, Bool, TYPE_CASE);
#undef TYPE_CASE
    default:
      LOG(FATAL) << "Unsupported datatype: " << v->dtype();
      return nullptr;
  }
  return nullptr;
}
// A Term represents a grouping of Exprs through multiplication:
//   scalar * variables[0] * variables[1] * ...
//
// The variable factors are kept in a canonical order by sort(), which orders
// them by hash. This lets like Terms be recognized and combined.
class Term : public ExprNode<Term> {
 public:
  // Builds a Term from a scalar and any number of variable factors.
  // `s` must be a constant; this is enforced with CHECK.
  template <class... Args>
  Term(HashProvider& hasher, const Expr* s, Args... ts)
      : ExprNodeBase(promoteTypesVar(s, ts...)), scalar_(s), hasher_(hasher) {
    CHECK(s->isConstant());
    addComponent(ts...);
    sort();
  }

  // Builds a Term from a scalar and a ready-made list of factors.
  // Unlike the variadic constructor, this one does not CHECK that `s` is
  // constant.
  Term(HashProvider& hasher, const Expr* s, std::vector<const Expr*> v)
      : ExprNodeBase(promoteTypesVec(s, v)),
        variables_(std::move(v)),
        scalar_(s),
        hasher_(hasher) {
    sort();
  }

  // Convenience constructor from a map of hash -> factor, used when merging
  // Terms. Only the map's values are kept; sort() then fixes their order.
  Term(
      HashProvider& hasher,
      const Expr* s,
      std::unordered_map<SimplifierHashType, const Expr*> varmap)
      : ExprNodeBase(promoteTypesMap(s, varmap)), scalar_(s), hasher_(hasher) {
    for (const auto& hashAndFactor : varmap) {
      variables_.push_back(hashAndFactor.second);
    }
    sort();
  }

  // The scalar multiplier of this Term.
  const Expr* scalar() const { return scalar_; }

  // The variable factors of this Term, in sorted order.
  const std::vector<const Expr*>& variables() const { return variables_; }

  // The hasher supplied at construction.
  HashProvider& hasher() const { return hasher_; }

  // Produce a hash of just the variable components of this term, to determine
  // if it can be combined with another term.
  SimplifierHashType hashVars() const;

 private:
  std::vector<const Expr*> variables_;
  const Expr* scalar_;
  HashProvider& hasher_;

  // Terminates the variadic recursion below; also allows zero factors.
  void addComponent() {}

  // Appends a single factor.
  void addComponent(const Expr* e) { variables_.push_back(e); }

  // Appends each factor in argument order.
  template <class... Es>
  void addComponent(const Expr* e, Es... es) {
    variables_.push_back(e);
    addComponent(es...);
  }

  // Sort by hash to normalize order of components.
  void sort();
};
// Polynomial represents a grouping of Exprs by addition:
//   variables[0] + variables[1] + ... + scalar
// where each element of variables() is a Term.
// (This would better be called Expression, but that name is already taken.)
//
// Terms are kept in a canonical order by sort(), which orders them by hash.
// This lets like terms be located and combined.
class Polynomial : public ExprNode<Polynomial> {
 public:
  // Builds a Polynomial from a scalar and any number of Terms, including none.
  // `s` must be a constant; this is enforced with CHECK.
  template <class... Args>
  Polynomial(HashProvider& hasher, const Expr* s, Args... ts)
      : ExprNodeBase(promoteTypesVar(s, ts...)), scalar_(s), hasher_(hasher) {
    CHECK(s->isConstant());
    addTerm(ts...);
    sort();
  }

  // Builds a Polynomial from a scalar and a ready-made list of Terms.
  Polynomial(HashProvider& hasher, const Expr* s, std::vector<const Term*> v)
      : ExprNodeBase(promoteTypesVec(s, v)),
        variables_(std::move(v)),
        scalar_(s),
        hasher_(hasher) {
    sort();
  }

  // Helper constructor for a list of Terms with no scalar component. The
  // scalar becomes a zero immediate of the promoted dtype. promoteTypesVec
  // throws malformed_input if `terms` is empty.
  Polynomial(HashProvider& hasher, std::vector<const Term*> terms)
      : ExprNodeBase(promoteTypesVec(terms)),
        variables_(std::move(terms)),
        scalar_(getImmediateByType(dtype(), 0)),
        hasher_(hasher) {
    sort();
  }

  // Convenience constructor for map of hash -> Term, used when merging
  // Polynomials. Only the map's values are kept; sort() then fixes their order.
  Polynomial(
      HashProvider& hasher,
      const Expr* s,
      std::unordered_map<SimplifierHashType, const Term*> varmap)
      : ExprNodeBase(promoteTypesMap(s, varmap)), scalar_(s), hasher_(hasher) {
    // The final size is known up front, so avoid repeated regrowth.
    variables_.reserve(varmap.size());
    for (auto& p : varmap) {
      addTerm(p.second);
    }
    sort();
  }

  // The scalar addend of this Polynomial.
  const Expr* scalar() const {
    return scalar_;
  }

  // The Terms of this Polynomial, in sorted order.
  const std::vector<const Term*>& variables() const {
    return variables_;
  }

  // The hasher supplied at construction.
  HashProvider& hasher() const {
    return hasher_;
  }

  SimplifierHashType hashVars() const;

 private:
  std::vector<const Term*> variables_;
  const Expr* scalar_;
  HashProvider& hasher_;

  // Base case for the variadic addTerm: nothing left to add. Without it, the
  // variadic constructor could not be instantiated with zero Terms. This
  // mirrors Term::addComponent() and MaxTerm::addComponent().
  void addTerm() {}

  // Appends a single Term.
  void addTerm(const Term* t) {
    variables_.push_back(t);
  }

  // Appends each Term in argument order.
  template <class... Ts>
  void addTerm(const Term* t, Ts... ts) {
    addTerm(t);
    addTerm(ts...);
  }

  // Sort by hash to normalize order of terms.
  void sort();
};
// A binary IR node of kind IRNodeType::kRoundOff over `lhs` and `rhs`.
// NOTE(review): the arithmetic meaning is defined by code not visible here.
// Presumably it rounds lhs down to a multiple of rhs, i.e. (lhs / rhs) * rhs
// — confirm.
class RoundOff : public BinaryOpNode<RoundOff> {
 public:
  RoundOff(const Expr* lhs, const Expr* rhs)
      : BinaryOpNode<RoundOff>(lhs, rhs, IRNodeType::kRoundOff) {}
};
// A MaxTerm collects the operands of a Max: an optional scalar (`scalar()`,
// which may be nullptr) plus a list of variable operands.
// NOTE(review): it presumably flattens nested Max nodes so operands can be
// deduplicated and folded — confirm.
//
// Operands are deduplicated by hash in uniquefy(). propagate_nans() returns
// the flag supplied at construction.
class MaxTerm : public ExprNode<MaxTerm> {
 public:
  // Builds a MaxTerm from an optional scalar `s`, the flag `p` and any number
  // of operands. Both branches of the dtype selection are instantiated, so at
  // least one operand must always be supplied for this to compile.
  template <class... Args>
  MaxTerm(HashProvider& hasher, const Expr* s, bool p, Args... ts)
      : ExprNodeBase(
            s != nullptr ? promoteTypesVar(s, ts...) : promoteTypesVar(ts...)),
        scalar_(s),
        hasher_(hasher),
        propagate_nans_(p) {
    addComponent(ts...);
    uniquefy();
  }

  // Builds a MaxTerm from an optional scalar and a ready-made operand list.
  // If `s` is nullptr and `v` is empty, promoteTypesVec throws malformed_input.
  MaxTerm(
      HashProvider& hasher,
      const Expr* s,
      bool p,
      std::vector<const Expr*> v)
      : ExprNodeBase(s != nullptr ? promoteTypesVec(s, v) : promoteTypesVec(v)),
        variables_(std::move(v)),
        scalar_(s),
        hasher_(hasher),
        propagate_nans_(p) {
    uniquefy();
  }

  // The flag `p` given at construction.
  bool propagate_nans() const { return propagate_nans_; }

  // The optional scalar operand; may be nullptr.
  const Expr* scalar() const { return scalar_; }

  // The variable operands, after deduplication.
  const std::vector<const Expr*>& variables() const { return variables_; }

  // The hasher supplied at construction.
  HashProvider& hasher() const { return hasher_; }

 private:
  std::vector<const Expr*> variables_;
  const Expr* scalar_;
  HashProvider& hasher_;
  bool propagate_nans_;

  // Terminates the variadic recursion below.
  void addComponent() {}

  // Appends a single operand.
  void addComponent(const Expr* e) { variables_.push_back(e); }

  // Appends each operand in argument order.
  template <class... Es>
  void addComponent(const Expr* e, Es... es) {
    variables_.push_back(e);
    addComponent(es...);
  }

  // Uniquefy the terms using their hash.
  void uniquefy();
};
Loading ...