#pragma once
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include <torch/csrc/Export.h>
#include <torch/csrc/jit/tensorexpr/fwd_decls.h>
namespace torch {
namespace jit {
namespace tensorexpr {
class Expr;
class Var;
class Buf;
class Tensor;
class Function;
class Stmt;
class For;
class Block;
class Store;
class Dtype;
class TORCH_API LoopNest {
public:
// A constructor for building a LoopNest from a list of Tensors
LoopNest(
const std::vector<Tensor>& output_tensors,
const std::vector<Tensor>& tensors_to_compute);
// A convenience constructor for the case when all tensors are output tensors
LoopNest(const std::vector<Tensor>& output_tensors);
// A constructor for building a LoopNest from a Stmt and a list of output
// buffers.
LoopNest(StmtPtr stmt, std::unordered_set<BufPtr> output_bufs);
// A constructor for building a LoopNest from another loopnest. It clones the
// other loopnest's stmt.
LoopNest(const LoopNest& other);
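//
// Example usage (a hedged sketch; `t` is an illustrative Tensor created
// elsewhere, e.g. via Compute):
//   LoopNest ln({t});            // all tensors are output tensors
//   StmtPtr s = ln.root_stmt();  // root of the generated loop nest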
StmtPtr root_stmt() const {
return root_stmt_;
}
std::vector<ForPtr> getLoopStmtsFor(Tensor) const;
std::vector<ForPtr> getLoopStmtsFor(BufPtr) const;
std::vector<ForPtr> getLoopStmtsFor(StmtPtr) const;
StmtPtr getLoopBodyFor(Tensor) const;
StmtPtr getLoopBodyFor(BufPtr) const;
// Returns the For stmt indexed by 'indices' in the 'root' For stmt.
// 'indices' indicates the path from 'root' to the returned loop in the AST,
// e.g.,
//
// root: for(int i...){
//   j_loop: for (int j...){
//     k1_loop: for (int k1...){
//       A[i, j, k1] = ....
//     }
//     B[i, j] = ...
//     k2_loop: for (int k2...){
//       A[i, j, k2] = ...
//     }
//   }
// }
//
// the path from 'root' to 'j_loop' is [0]
// the path from 'root' to 'k1_loop' is [0, 0]
// the path from 'root' to 'k2_loop' is [0, 2]
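//
// Example usage (a sketch; `ln` and `root` are illustrative):
//   ForPtr k2_loop = ln.getLoopAt(root, {0, 2});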
ForPtr getLoopAt(ForPtr root, const std::vector<int>& indices) const;
// Returns the For stmt that is immediately enclosing the given stmt.
static ForPtr getParentLoop(StmtPtr st);
// Returns the list of For stmts corresponding to the loopnest that is
// enclosing the given stmt.
static std::vector<ForPtr> getEnclosingLoopNest(StmtPtr st);
// Returns a list of all Stmts that write to the given buf.
std::vector<StmtPtr> getAllWritesToBuf(BufPtr) const;
// The following methods return the For loops that contain writes to
// the given buf.
//
// For example, consider the following code:
//   for i1
//     for j1
//       a[i1,j1] =
//   for i2
//     for j2
//       for k2
//         a[i2,j2] =
//     for j3
//       a[i2,j3] =
// Returns a list of For loops which directly contain a Stmt that writes
// to buf.
// For the above example:
// getAllInnermostLoopsWritingToBuf(a) => {j1, k2, j3}
std::vector<ForPtr> getAllInnermostLoopsWritingToBuf(BufPtr) const;
// Returns a list of For loopnests which contain a Stmt that writes to
// the given buf. Each loopnest here is a vector of For loops.
// For the above example:
// getAllLoopNestsWritingToBuf(a) => {{i1,j1}, {i2,j2,k2}, {i2,j3}}
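//
// Example usage (a sketch; `ln` and `a_buf` are illustrative):
//   auto nests = ln.getAllLoopNestsWritingToBuf(a_buf);
//   // For the example above, nests[0] would be {i1, j1}.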
std::vector<std::vector<ForPtr>> getAllLoopNestsWritingToBuf(BufPtr) const;
StmtPtr simplify();
// Sanitize variable and buffer names.
// The pass assigns predefined names to loop index variables
// (i,j,k,l,m,n,o,p,i1,j1,k1,...) and ensures these names do not conflict
// anywhere. It also removes duplicates from other Buf and Var names and
// replaces illegal characters in them with underscores.
//
// Note: since it's currently technically possible to use the same variable
// as index in two different loops, this transformation finds such cases and
// introduces new variables to avoid duplication.
static StmtPtr sanitizeNames(StmtPtr s);
bool computeInline(StmtPtr s);
bool computeInline(BufPtr b);
void inlineIntermediateBufs(bool allow_duplicated_work);
// Optimizes conditionals.
//
// Currently, only the following pattern of conditionals is optimized.
// This corresponds to the conditional format that is generated to handle
// `aten::cat` op.
//
// for (int i = 0; i < 20; i++) {
//   A[i] = IfThenElse(i<5 ? 1 : 0, B[i], C[i-5])
// }
//
// Constraints that must be satisfied for this optimization:
// * All conditions should be of the form "var < expr".
// * All conditions should have the same variable, say v.
// * The condition variable found should be the same as the inner-most
// loop variable. TODO: Remove this constraint.
// * If there are multiple stores that contain conditionals using the same
// loop variable, only the first conditional will be optimized.
// TODO: Remove this constraint.
bool optimizeConditionals();
// Splits the given loop into 2 nested loops with the given factor as the
// inner loop bound. If the factor does not evenly divide the loop bound,
// then the remaining iterations are extracted into a tail loop that is
// added after the given loop.
//
// For example, consider the following code:
// for (int i = 0; i < 100; ++i) {
//   A[i] =
// }
//
// splitWithTail(i, 8, ...) will result in:
// for (int i_outer = 0; i_outer < 12; ++i_outer) {
//   for (int i_inner = 0; i_inner < 8; ++i_inner) {
//     A[i_outer * 8 + i_inner] =
//   }
// }
// for (int i_tail = 0; i_tail < 4; ++i_tail) {
//   A[i_tail + 96] =
// }
//
// The given loop will be transformed to the outer loop after splitting.
// So, the pointer to the input loop should be valid after splitting and
// will point to the outer loop. The `inner` and `tail` parameters will be
// set to point to the inner and tail loops that are generated.
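//
// Example usage (a sketch; `ln` and `tensor` are illustrative):
//   std::vector<ForPtr> loops = ln.getLoopStmtsFor(tensor);
//   ForPtr inner, tail;
//   LoopNest::splitWithTail(loops[0], 8, &inner, &tail);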
static void splitWithTail(ForPtr f, int factor, ForPtr* inner, ForPtr* tail);
// A convenience wrapper when the caller does not need to access the
// split loops.
static void splitWithTail(ForPtr f, int factor);
// Splits the given loop into 2 nested loops with the given factor as the
// inner loop bound. If the factor does not evenly divide the loop bound,
// then a conditional is inserted into the body to handle the remaining
// iterations appropriately.
//
// For example, consider the following code:
// for (int i = 0; i < 100; ++i) {
//   A[i] =
// }
//
// splitWithMask(i, 8, ...) will result in:
// for (int i_outer = 0; i_outer < 13; ++i_outer) {
//   for (int i_inner = 0; i_inner < 8; ++i_inner) {
//     if (i_outer * 8 + i_inner < 100) {
//       A[i_outer * 8 + i_inner] =
//     }
//   }
// }
//
// The given loop will be transformed to the outer loop after splitting.
// So, the pointer to the input loop should be valid after splitting and
// will point to the outer loop. The `inner` parameter will be set to point
// to the inner loop that is generated.
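//
// Example usage (a sketch; `loop` is an illustrative ForPtr):
//   ForPtr inner;
//   LoopNest::splitWithMask(loop, 8, &inner);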
static void splitWithMask(ForPtr f, int factor, ForPtr* inner);
// A convenience wrapper when the caller does not need to access the
// split loops.
static void splitWithMask(ForPtr f, int factor);
// The following methods support loop distribution.
// For example, consider the following code. This will be used to
// demonstrate the methods below.
//
// S0:  for m
// S1:    for i
// S2:      A[i] = 0
// S3:      for j
// S4:        A[i] = A[i] +
// S5:      B[i] = A[i]
// S6:      for k
// S7:        B[i] = B[i] +
// This method distributes the given loop over its body by splitting
// after every given pivot stmt.
//
// NOTE: Pivot stmts that are not in the given loop's body will be ignored.
//
// For the above example:
// distributeLoop(S1, {S3, S5})
// will result in:
// S0:  for m
// S1:    for i
// S2:      A[i] = 0
// S3:      for j
// S4:        A[i] = A[i] +
//   :    for i
// S5:      B[i] = A[i]
//   :    for i
// S6:      for k
// S7:        B[i] = B[i] +
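//
// Example usage (a sketch; `i_loop`, `s3`, and `s5` are illustrative
// handles to the stmts labeled above):
//   std::vector<ForPtr> new_loops =
//       LoopNest::distributeLoop(i_loop, {s3, s5});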
static std::vector<ForPtr> distributeLoop(
ForPtr loop,
const std::unordered_set<StmtPtr>& pivots);
// This method distributes the given loop over every stmt in its body.
//
// For the above example:
// distributeLoop(S1)
// will result in:
// S0:  for m
// S1:    for i
// S2:      A[i] = 0
//   :    for i
// S3:      for j
// S4:        A[i] = A[i] +
//   :    for i
// S5:      B[i] = A[i]
//   :    for i
// S6:      for k
// S7:        B[i] = B[i] +
static std::vector<ForPtr> distributeLoop(ForPtr loop);
// Same as above, but also distribute parent loops.
// Returns the result of distributing the outermost loop.
//
// For the above example:
// distributeLoopAndParents(S1) will result in:
// S0:  for m
// S1:    for i
// S2:      A[i] = 0
//   :  for m
//   :    for i
// S3:      for j
// S4:        A[i] = A[i] +
//   :  for m
//   :    for i
// S5:      B[i] = A[i]
//   :  for m
//   :    for i
// S6:      for k
// S7:        B[i] = B[i] +
static std::vector<ForPtr> distributeLoopAndParents(ForPtr loop);
// This method distributes the given loop over its body by splitting
// after every For stmt in its body.
//
// For the above example:
// distributeLoopOverInnerLoops(S1)
// will result in:
// S0:  for m
// S1:    for i
// S2:      A[i] = 0
// S3:      for j
// S4:        A[i] = A[i] +
//   :    for i
// S5:      B[i] = A[i]
// S6:      for k
// S7:        B[i] = B[i] +
static std::vector<ForPtr> distributeLoopOverInnerLoops(ForPtr loop);
// Same as above, but also distribute parent loops.
// Returns the result of distributing the outermost loop.
//
// For the above example:
// distributeLoopAndParentsOverInnerLoops(S1)
// will result in:
// S0:  for m
// S1:    for i
// S2:      A[i] = 0
// S3:      for j
// S4:        A[i] = A[i] +
//   :  for m
//   :    for i
// S5:      B[i] = A[i]
// S6:      for k
// S7:        B[i] = B[i] +
static std::vector<ForPtr> distributeLoopAndParentsOverInnerLoops(
ForPtr loop);
// This method performs loop fusion.
// For example, consider the following code.
//
// S1:  for m
// S2:    A[m] = 0
// S3:    for j
// S4:      A[m] = A[m] +
// S5:  for n
// S6:    B[n] = A[n]
// S7:    for k
// S8:      B[n] = B[n] +
//
// fuseLoops({S1, S5}) will return the following loop:
// S1:  for m
// S2:    A[m] = 0
// S3:    for j
// S4:      A[m] = A[m] +
// S6:    B[m] = A[m]
// S7:    for k
// S8:      B[m] = B[m] +
//
// This transformation is unsafe as it simply adds all loops into the body
// of the first loop for fusion without any correctness checks.
//
// Below are the two requirements to apply unsafeFuseLoops:
// * All the loops have the same parent.
// * There are no statements between these loops in their parent body.
static bool unsafeFuseLoops(const std::vector<ForPtr>& loops, ForPtr* fused);
// Loop fusion is done only when all the conditions below are satisfied.
// * All the loops have the same parent.
// * There are no statements between these loops in their parent body.
// * The start bounds are the same for all loops.
// * The stop bounds are the same for all loops.
// * Fusing the loops does not violate or add any dependencies.
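//
// Example usage (a sketch; `m_loop` and `n_loop` are illustrative handles
// to the loops labeled S1 and S5 above):
//   ForPtr fused;
//   if (LoopNest::fuseLoops({m_loop, n_loop}, &fused)) {
//     // `fused` now points to the single fused loop.
//   }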
static bool fuseLoops(const std::vector<ForPtr>& loops, ForPtr* fused);
static void reorderAxis(ForPtr a, ForPtr b);
// Reorder the given list of loops according to the permutation specified.
// Here `permutation[i]` represents the position of the loop in the input
// which will end up at position `i` after the reorder.
//
// For example, consider the following code:
// for p
//   for q
//     for r
//       for s
//         A[p,q,r,s] =
//
// reorder({p, q, r, s}, {2, 3, 0, 1}) will return the list of loops in the
// following form:
// for r
//   for s
//     for p
//       for q
//         A[p,q,r,s] =
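//
// Example usage (a sketch; `loops` holds the four loops above, outermost
// first):
//   std::vector<ForPtr> reordered = LoopNest::reorder(loops, {2, 3, 0, 1});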
static std::vector<ForPtr> reorder(
const std::vector<ForPtr>& loops,
const std::vector<size_t>& permutation);
// Tile takes a 2d domain (x, y) and splits it into small rectangular blocks
// each with shape (x_factor, y_factor). The traversal over the domain turns
// into an outer iteration over the blocks and an inner traversal over all
// points in the block.
// Note that if the x dimension is not evenly divisible by x_factor, or the
// y dimension by y_factor, corresponding tail loops are generated.
// The transformation is in-place and returns 'xtail'.
//
// For example, consider the following code:
// for i: [0, 64)
//   for j: [0, 64)
//     for k: [0, 32)
//       A[i, j] = B[i, k] + C[j, k]
//
// tile(i, j, 4, 8) will transform "i" for-stmt into the following nested
// loop:
// for i_outer: [0, 16)
//   for j_outer: [0, 8)
//     for i_inner: [0, 4)
//       for j_inner: [0, 8)
//         for k: [0, 32)
//           A[i_outer * 4 + i_inner, j_outer * 8 + j_inner] =
//               B[i_outer * 4 + i_inner, k] + C[j_outer * 8 + j_inner, k]
//
// tile(i, j, 4, 9) will transform "i" for-stmt into the following nested
// loop:
// for i_outer: [0, 16)
//   for j_outer: [0, 7)
//     for i_inner: [0, 4)
//       for j_inner: [0, 9)
//         for k: [0, 32)
//           A[i_outer * 4 + i_inner, j_outer * 9 + j_inner] =
//               B[i_outer * 4 + i_inner, k] + C[j_outer * 9 + j_inner, k]
//   for j_tail: [0, 1)
//     for i_inner: [0, 4)
//       for k: [0, 32)
//         A[i_outer * 4 + i_inner, 7 * 9 + j_tail] =
//             B[i_outer * 4 + i_inner, k] + C[7 * 9 + j_tail, k]
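//
// Example usage (a sketch; `ln`, `i_loop`, and `j_loop` are illustrative):
//   ForPtr xtail = ln.tile(i_loop, j_loop, 4, 8);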
ForPtr tile(ForPtr x, ForPtr y, int x_factor, int y_factor);
// Returns true if the given loops are perfectly nested, i.e., every loop
// (except the innermost) has exactly one statement in its body, and that
// statement is the next inner loop.
static bool areLoopsPerfectlyNested(const std::vector<ForPtr>& loops);
// Returns true if the given loop has a loop-carried dependence.
static bool hasLoopCarriedDependence(ForPtr loop);
// Unrolls all the iterations of the given loop.
// Requires that the loop bounds are constant.
static void fullUnroll(ForPtr f, StmtPtr* unrolled);
static void fullUnroll(ForPtr f);
// Unrolls the given loop for the specified factor.
// This does not require constant bounds for the loop being unrolled.
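//
// Example usage (a sketch; `loop` is an illustrative ForPtr):
//   ForPtr tail;
//   LoopNest::unroll(loop, 4, &tail);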
static void unroll(ForPtr f, int factor, ForPtr* tail);
static void unroll(ForPtr f, int factor);
static bool normalize(ForPtr f);
static bool isNormalized(ForPtr f);
static bool flatten(const std::vector<ForPtr>& f, ForPtr* flattened);
static bool flatten(const std::vector<ForPtr>& f);
// Compresses the given buffer based on its use in the given Stmts.
//
// NOTE: This API assumes that there are no accesses to the given buffer
// outside the given statement. So, this should be called with the entire
// kernel statement to avoid incorrect buffer compressions.
//
// For example, given the input:
//
// for (int i = 0; i < 100; ++i) {
//   for (int j = 0; j < 200; ++j) {
//     A[i,j] = sin(i*j)
//   }
//   for (int j = 0; j < 199; ++j) {
//     B[i,j] = A[i,j] + A[i, j+1]
//   }
// }
//
// compressBuffer(A, ...) will compress buffer A from
// [100, 200] to [1, 200] and modify the code as follows:
//
// for (int i = 0; i < 100; ++i) {
//   for (int j = 0; j < 200; ++j) {
//     A[0,j] = sin(i*j)
//   }
//   for (int j = 0; j < 199; ++j) {
//     B[i,j] = A[0,j] + A[0, j+1]
//   }
// }
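//
// Example usage (a sketch; `ln` and `a_buf` are illustrative):
//   LoopNest::compressBuffer(a_buf, ln.root_stmt());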
static void compressBuffer(BufPtr buf, StmtPtr stmt);
// Compresses all buffers in the given statement.
//
// NOTE: This API assumes that there are no accesses to buffers outside
// the given statement. So, this should be called with the entire
// kernel statement to avoid incorrect buffer compressions.
//
// TODO: Add an IR verifier check to detect invalidly compressed buffers.
static void compressAllBuffers(StmtPtr stmt);
// Get 'num' loops from the loopnest starting at 'f'.
static std::vector<ForPtr> getLoopStmtsInLoopNest(ForPtr f, size_t num);
// LoopOptions are propagated to tail.
static void sliceHead(ForPtr f, int factor, ForPtr* head, ForPtr* tail);
static void sliceHead(ForPtr f, int factor);
// LoopOptions are propagated to head.
static void sliceTail(ForPtr f, int factor, ForPtr* head, ForPtr* tail);
static void sliceTail(ForPtr f, int factor);
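// A sketch of the expected effect (assuming a loop f over [0, 100)):
//   sliceHead(f, 4) peels the first 4 iterations into a separate loop,
//   leaving two loops over [0, 4) and [4, 100); sliceTail(f, 4)
//   analogously peels off the last 4 iterations.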
using AccessResult = std::pair<BufPtr, StmtPtr>;
// Inserts a cache for the accesses to the given producer buffer within the
// given consumer stmt, and redirects reads and writes of that buffer in the
// consumer to the cache.
// Returns a pair of the new cache buffer and the new rewritten consumer.
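//
// Example usage (a sketch; `a_buf` and `consumer` are illustrative):
//   LoopNest::AccessResult res =
//       LoopNest::cacheAccesses(a_buf, "A_local", consumer);
//   // res.first is the cache buffer; res.second is the rewritten consumer.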
static AccessResult cacheAccesses(
BufPtr producer,
const std::string& name,
StmtPtr consumer);
// Insert a temporary computation of statement S in the scope of loop AT.
// S is assumed to be a Store or a Block containing a Store. Along with the
// computation itself, this transformation inserts Alloc/Free statements for
// the temporary buffer used in the computation.
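//
// Example usage (a sketch; `producer` and `loop` are illustrative):
//   LoopNest::computeAt(ln.getLoopBodyFor(producer), loop);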
static void computeAt(StmtPtr s, ForPtr at);
// Rfactor a reduction axis into a normal axis.
//
// Requirements:
// * S is the reduction store
// * S is the only statement in the innermost loop
// * There are at least two reduction arguments in S
// * OUTER_REDUCTION_FOR loop corresponds to the outermost reduction variable
// used in the store and all other reduction variables are index variables of
// children loops of OUTER_REDUCTION_FOR
// * OUTER_REDUCTION_FOR is a perfect loop nest, i.e. it has only loops
// corresponding to the other reduction variables and the store, nested into
// each other
//
// What it does:
// * Introduce a new buffer with an extra dimension of a size equal to the
// span of the loop OUTER_REDUCTION_FOR (the new buffer is returned via
// RFAC_BUF_PTR)
// * Insert an initialization store for the new buffer in
// OUTER_REDUCTION_FOR before its nested loop
// * Replace the reduction store to the original buffer with the reduction
// store to the temp buffer, removing the index var of OUTER_REDUCTION_FOR
// from reduction arguments
// * Insert a final reduction store over the extra dimension of the new
// buffer to the original buffer
// * Returns TRUE if the transformation succeeded and FALSE otherwise
//
// Example:
// Original IR:
// S1: for i          # normal axis
// S2:   X[i] = 0
// S3:   for j        # reduction axis
// S4:     for k      # reduction axis
// S5:       X[i] = ReduceOp(X[i] + Y[i,j,k], reduce_axis={j,k})
//
// After RFACTOR(S5, S3)
// S1: for i          # normal axis
// S2:   X[i] = 0
// S3:   for j        # reduction axis for X, normal axis for X_rfac
//         X_rfac[i,j] = 0
// S4:     for k      # reduction axis
//           X_rfac[i,j] = ReduceOp(X_rfac[i,j] + Y[i,j,k], reduce_axis={k})
//         X[i] = ReduceOp(X[i] + X_rfac[i,j], reduce_axis={j})
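//
// Example usage (a sketch; `s5` and `j_loop` are illustrative handles to
// the stmts labeled above):
//   BufPtr rfac_buf;
//   if (LoopNest::rfactor(s5, j_loop, &rfac_buf)) {
//     // `rfac_buf` now points to the new X_rfac buffer.
//   }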
static bool rfactor(StmtPtr s, ForPtr outer_reduction_for);
static bool rfactor(
StmtPtr s,
ForPtr outer_reduction_for,
BufPtr* rfac_buf_ptr);
// Vectorize the given loop. This method requires that the given loop
// does not perform a reduction.
// It returns true if vectorization is successful and false otherwise.
static bool vectorize(ForPtr);
// Find the inner-most loops and vectorize them. Currently, this only works
// for the LLVM backend, when no reductions are involved.
void vectorizeInnerLoops();
void eliminateDeadStores();
void prepareForCodegen();
const std::unordered_set<BufPtr> getInputBufs() const;
const std::unordered_set<BufPtr> getOutputBufs() const {
return output_bufs_;
}
std::vector<BufPtr> getIntermediateBufs() const;
// Returns the outer of the two given For loops, i.e., whichever one is an
// ancestor of the other. If neither loop is an ancestor of the other,
// returns nullptr.
static ForPtr findOuterFor(ForPtr a, ForPtr b);
private:
void initialize(
const std::vector<Tensor>& output_tensors,
const std::vector<Tensor>& tensors_to_compute);
StmtPtr root_stmt_;
std::unordered_set<BufPtr> output_bufs_;
};
TORCH_API StmtPtr FlattenIndexes(StmtPtr s);
// TODO: Revisit this once we decide on what dependency analysis should look
// like. We may choose a different API and remove BufLoadOrStoreUse, or if
// we decide to keep it, we need to properly document its API.
struct BufLoadOrStoreUse {
StmtPtr s;
bool isStore;
};
/*
 * Returns a map (Buf -> uses of this Buf), where the uses are represented as
 * vectors of BufLoadOrStoreUse elements, each consisting of a StmtPtr and a
 * bool isStore flag. The order of uses in each vector reflects the order in
 * which the uses appear in the given statement.
 */
std::unordered_map<BufPtr, std::vector<BufLoadOrStoreUse>> findLoadOrStoreUses(
StmtPtr s);
// Replaces all invalid characters in the given name with underscores.
TORCH_API std::string sanitizeName(const std::string& input_name);
} // namespace tensorexpr
} // namespace jit
} // namespace torch