forked from OSchip/llvm-project
Extend loop unroll/unroll-and-jam to affine bounds + refactor related code.
- extend loop unroll-jam similar to loop unroll for affine bounds - extend both loop unroll/unroll-jam to deal with cleanup loop for non multiple of unroll factor. - extend promotion of single iteration loops to work with affine bounds - fix typo bugs in loop unroll - refactor common code b/w loop unroll and loop unroll-jam - move prototypes of non-pass transforms to LoopUtils.h - add additional builder methods. - introduce loopUnrollUpTo(factor) to unroll by either factor or trip count, whichever is less. - remove Statement::isInnermost (not used for now - will come back at the right place/in right form later) PiperOrigin-RevId: 213471227
This commit is contained in:
parent
7103779fb8
commit
ab4797229c
|
@ -32,7 +32,7 @@ class ForStmt;
|
|||
/// Returns the trip count of the loop as an affine expression if the latter is
|
||||
/// expressible as an affine expression, and nullptr otherwise. The trip count
|
||||
/// expression is simplified before returning.
|
||||
AffineExpr *getTripCount(const ForStmt &forStmt);
|
||||
AffineExpr *getTripCountExpr(const ForStmt &forStmt);
|
||||
|
||||
/// Returns the trip count of the loop if it's a constant, None otherwise. This
|
||||
/// uses affine expression analysis and is able to determine constant trip count
|
||||
|
|
|
@ -101,7 +101,9 @@ public:
|
|||
AffineSymbolExpr *getSymbolExpr(unsigned position);
|
||||
AffineConstantExpr *getConstantExpr(int64_t constant);
|
||||
AffineExpr *getAddExpr(AffineExpr *lhs, AffineExpr *rhs);
|
||||
AffineExpr *getAddExpr(AffineExpr *lhs, int64_t rhs);
|
||||
AffineExpr *getSubExpr(AffineExpr *lhs, AffineExpr *rhs);
|
||||
AffineExpr *getSubExpr(AffineExpr *lhs, int64_t rhs);
|
||||
AffineExpr *getMulExpr(AffineExpr *lhs, AffineExpr *rhs);
|
||||
AffineExpr *getMulExpr(AffineExpr *lhs, int64_t rhs);
|
||||
AffineExpr *getModExpr(AffineExpr *lhs, AffineExpr *rhs);
|
||||
|
|
|
@ -82,9 +82,6 @@ public:
|
|||
/// Returns nullptr if the statement is unlinked.
|
||||
MLFunction *findFunction() const;
|
||||
|
||||
/// Returns true if there are no more loops nested under this stmt.
|
||||
bool isInnermost() const;
|
||||
|
||||
/// Destroys this statement and its subclass data.
|
||||
void destroy();
|
||||
|
||||
|
|
|
@ -275,6 +275,10 @@ public:
|
|||
/// Sets the upper bound to the given constant value.
|
||||
void setConstantUpperBound(int64_t value);
|
||||
|
||||
/// Returns true if both the lower and upper bound have the same operand lists
|
||||
/// (same operands in the same order).
|
||||
bool matchingBoundOperandList() const;
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
// Operands
|
||||
//===--------------------------------------------------------------------===//
|
||||
|
@ -343,7 +347,9 @@ private:
|
|||
AffineMap *ubMap;
|
||||
// Constant step.
|
||||
int64_t step;
|
||||
// Operands for the lower and upper bounds.
|
||||
// Operands for the lower and upper bounds, with the former followed by the
|
||||
// latter. Dimensional operands are followed by symbolic operands for each
|
||||
// bound.
|
||||
std::vector<StmtOperand> operands;
|
||||
|
||||
explicit ForStmt(Location *location, unsigned numOperands, AffineMap *lbMap,
|
||||
|
|
|
@ -0,0 +1,77 @@
|
|||
//===- LoopUtils.h - Loop transformation utilities --------------*- C++ -*-===//
|
||||
//
|
||||
// Copyright 2019 The MLIR Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
// =============================================================================
|
||||
//
|
||||
// This header file defines prototypes for various loop transformation utility
|
||||
// methods: these are not passes by themselves but are used either by passes,
|
||||
// optimization sequences, or in turn by other transformation utilities.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_TRANSFORMS_LOOP_UTILS_H
|
||||
#define MLIR_TRANSFORMS_LOOP_UTILS_H
|
||||
|
||||
#include "mlir/Support/LLVM.h"
|
||||
|
||||
namespace mlir {
|
||||
|
||||
class AffineMap;
|
||||
class ForStmt;
|
||||
class MLFunction;
|
||||
class MLFuncBuilder;
|
||||
|
||||
/// Unrolls this for statement completely if the trip count is known to be
|
||||
/// constant. Returns false otherwise.
|
||||
bool loopUnrollFull(ForStmt *forStmt);
|
||||
/// Unrolls this for statement by the specified unroll factor. Returns false if
|
||||
/// the loop cannot be unrolled either due to restrictions or due to invalid
|
||||
/// unroll factors.
|
||||
bool loopUnrollByFactor(ForStmt *forStmt, uint64_t unrollFactor);
|
||||
/// Unrolls this loop by the specified unroll factor or its trip count,
|
||||
/// whichever is lower.
|
||||
bool loopUnrollUpToFactor(ForStmt *forStmt, uint64_t unrollFactor);
|
||||
|
||||
/// Unrolls and jams this loop by the specified factor. Returns true if the loop
|
||||
/// is successfully unroll-jammed.
|
||||
bool loopUnrollJamByFactor(ForStmt *forStmt, uint64_t unrollJamFactor);
|
||||
|
||||
/// Unrolls and jams this loop by the specified factor or by the trip count (if
|
||||
/// constant), whichever is lower.
|
||||
bool loopUnrollJamUpToFactor(ForStmt *forStmt, uint64_t unrollJamFactor);
|
||||
|
||||
/// Promotes the loop body of a ForStmt to its containing block if the ForStmt
|
||||
/// was known to have a single iteration. Returns false otherwise.
|
||||
bool promoteIfSingleIteration(ForStmt *forStmt);
|
||||
|
||||
/// Promotes all single iteration ForStmt's in the MLFunction, i.e., moves
|
||||
/// their body into the containing StmtBlock.
|
||||
void promoteSingleIterationLoops(MLFunction *f);
|
||||
|
||||
/// Returns the lower bound of the cleanup loop when unrolling a loop
|
||||
/// with the specified unroll factor.
|
||||
AffineMap *getCleanupLoopLowerBound(const ForStmt &forStmt,
|
||||
unsigned unrollFactor,
|
||||
MLFuncBuilder *builder);
|
||||
|
||||
/// Returns the upper bound of an unrolled loop when unrolling with
|
||||
/// the specified trip count, stride, and unroll factor.
|
||||
AffineMap *getUnrolledLoopUpperBound(const ForStmt &forStmt,
|
||||
unsigned unrollFactor,
|
||||
MLFuncBuilder *builder);
|
||||
|
||||
} // end namespace mlir
|
||||
|
||||
#endif // MLIR_TRANSFORMS_LOOP_UTILS_H
|
|
@ -27,9 +27,7 @@
|
|||
|
||||
namespace mlir {
|
||||
|
||||
class ForStmt;
|
||||
class FunctionPass;
|
||||
class MLFunction;
|
||||
class MLFunctionPass;
|
||||
class ModulePass;
|
||||
|
||||
|
@ -38,19 +36,11 @@ class ModulePass;
|
|||
MLFunctionPass *createLoopUnrollPass(int unrollFactor = -1,
|
||||
int unrollFull = -1);
|
||||
|
||||
/// Unrolls this loop completely.
|
||||
bool loopUnrollFull(ForStmt *forStmt);
|
||||
/// Unrolls this loop by the specified unroll factor.
|
||||
bool loopUnrollByFactor(ForStmt *forStmt, uint64_t unrollFactor);
|
||||
|
||||
/// Creates a loop unroll jam pass to unroll jam by the specified factor. A
|
||||
/// factor of -1 lets the pass use the default factor or the one on the command
|
||||
/// line if provided.
|
||||
MLFunctionPass *createLoopUnrollAndJamPass(int unrollJamFactor = -1);
|
||||
|
||||
/// Unrolls and jams this loop by the specified factor.
|
||||
bool loopUnrollJamByFactor(ForStmt *forStmt, uint64_t unrollJamFactor);
|
||||
|
||||
/// Creates an affine expression simplification pass.
|
||||
FunctionPass *createSimplifyAffineExprPass();
|
||||
|
||||
|
@ -59,14 +49,6 @@ FunctionPass *createSimplifyAffineExprPass();
|
|||
/// generated CFG functions.
|
||||
ModulePass *createConvertToCFGPass();
|
||||
|
||||
/// Promotes the loop body of a ForStmt to its containing block if the ForStmt
|
||||
/// was known to have a single iteration. Returns false otherwise.
|
||||
bool promoteIfSingleIteration(ForStmt *forStmt);
|
||||
|
||||
/// Promotes all single iteration ForStmt's in the MLFunction, i.e., moves
|
||||
/// their body into the containing StmtBlock.
|
||||
void promoteSingleIterationLoops(MLFunction *f);
|
||||
|
||||
} // end namespace mlir
|
||||
|
||||
#endif // MLIR_TRANSFORMS_LOOP_H
|
||||
#endif // MLIR_TRANSFORMS_PASSES_H
|
||||
|
|
|
@ -31,7 +31,7 @@ using mlir::AffineExpr;
|
|||
/// Returns the trip count of the loop as an affine expression if the latter is
|
||||
/// expressible as an affine expression, and nullptr otherwise. The trip count
|
||||
/// expression is simplified before returning.
|
||||
AffineExpr *mlir::getTripCount(const ForStmt &forStmt) {
|
||||
AffineExpr *mlir::getTripCountExpr(const ForStmt &forStmt) {
|
||||
// upper_bound - lower_bound + 1
|
||||
int64_t loopSpan;
|
||||
|
||||
|
@ -43,32 +43,22 @@ AffineExpr *mlir::getTripCount(const ForStmt &forStmt) {
|
|||
int64_t ub = forStmt.getConstantUpperBound();
|
||||
loopSpan = ub - lb + 1;
|
||||
} else {
|
||||
const AffineBound lb = forStmt.getLowerBound();
|
||||
const AffineBound ub = forStmt.getUpperBound();
|
||||
auto lbMap = lb.getMap();
|
||||
auto ubMap = ub.getMap();
|
||||
auto *lbMap = forStmt.getLowerBoundMap();
|
||||
auto *ubMap = forStmt.getUpperBoundMap();
|
||||
// TODO(bondhugula): handle max/min of multiple expressions.
|
||||
if (lbMap->getNumResults() != 1 || ubMap->getNumResults() != 1 ||
|
||||
lbMap->getNumDims() != ubMap->getNumDims() ||
|
||||
lbMap->getNumSymbols() != ubMap->getNumSymbols()) {
|
||||
if (lbMap->getNumResults() != 1 || ubMap->getNumResults() != 1)
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// TODO(bondhugula): handle bounds with different operands.
|
||||
unsigned i, e = lb.getNumOperands();
|
||||
for (i = 0; i < e; i++) {
|
||||
if (lb.getStmtOperand(i).get() != ub.getStmtOperand(i).get())
|
||||
break;
|
||||
}
|
||||
// Bounds have different operands, unhandled for now.
|
||||
if (i != e)
|
||||
if (!forStmt.matchingBoundOperandList())
|
||||
return nullptr;
|
||||
|
||||
// ub_expr - lb_expr + 1
|
||||
auto *lbExpr = lbMap->getResult(0);
|
||||
auto *ubExpr = ubMap->getResult(0);
|
||||
auto *loopSpanExpr = AffineBinaryOpExpr::getAdd(
|
||||
AffineBinaryOpExpr::getSub(ubMap->getResult(0), lbMap->getResult(0),
|
||||
context),
|
||||
1, context);
|
||||
AffineBinaryOpExpr::getSub(ubExpr, lbExpr, context), 1, context);
|
||||
|
||||
if (auto *expr = simplifyAffineExpr(loopSpanExpr, lbMap->getNumDims(),
|
||||
lbMap->getNumSymbols(), context))
|
||||
|
@ -95,7 +85,7 @@ AffineExpr *mlir::getTripCount(const ForStmt &forStmt) {
|
|||
/// method uses affine expression analysis (in turn using getTripCount) and is
|
||||
/// able to determine constant trip count in non-trivial cases.
|
||||
llvm::Optional<uint64_t> mlir::getConstantTripCount(const ForStmt &forStmt) {
|
||||
AffineExpr *tripCountExpr = getTripCount(forStmt);
|
||||
AffineExpr *tripCountExpr = getTripCountExpr(forStmt);
|
||||
|
||||
if (auto *constExpr = dyn_cast_or_null<AffineConstantExpr>(tripCountExpr))
|
||||
return constExpr->getValue();
|
||||
|
@ -107,7 +97,7 @@ llvm::Optional<uint64_t> mlir::getConstantTripCount(const ForStmt &forStmt) {
|
|||
/// expression analysis is used (indirectly through getTripCount), and
|
||||
/// this method is thus able to determine non-trivial divisors.
|
||||
uint64_t mlir::getLargestDivisorOfTripCount(const ForStmt &forStmt) {
|
||||
AffineExpr *tripCountExpr = getTripCount(forStmt);
|
||||
AffineExpr *tripCountExpr = getTripCountExpr(forStmt);
|
||||
|
||||
if (!tripCountExpr)
|
||||
return 1;
|
||||
|
|
|
@ -157,6 +157,10 @@ AffineExpr *Builder::getAddExpr(AffineExpr *lhs, AffineExpr *rhs) {
|
|||
return AffineBinaryOpExpr::get(AffineExpr::Kind::Add, lhs, rhs, context);
|
||||
}
|
||||
|
||||
AffineExpr *Builder::getAddExpr(AffineExpr *lhs, int64_t rhs) {
|
||||
return AffineBinaryOpExpr::getAdd(lhs, rhs, context);
|
||||
}
|
||||
|
||||
AffineExpr *Builder::getMulExpr(AffineExpr *lhs, AffineExpr *rhs) {
|
||||
return AffineBinaryOpExpr::get(AffineExpr::Kind::Mul, lhs, rhs, context);
|
||||
}
|
||||
|
@ -171,6 +175,10 @@ AffineExpr *Builder::getSubExpr(AffineExpr *lhs, AffineExpr *rhs) {
|
|||
return getAddExpr(lhs, getMulExpr(rhs, getConstantExpr(-1)));
|
||||
}
|
||||
|
||||
AffineExpr *Builder::getSubExpr(AffineExpr *lhs, int64_t rhs) {
|
||||
return AffineBinaryOpExpr::getAdd(lhs, -rhs, context);
|
||||
}
|
||||
|
||||
AffineExpr *Builder::getModExpr(AffineExpr *lhs, AffineExpr *rhs) {
|
||||
return AffineBinaryOpExpr::get(AffineExpr::Kind::Mod, lhs, rhs, context);
|
||||
}
|
||||
|
|
|
@ -84,18 +84,6 @@ MLFunction *Statement::findFunction() const {
|
|||
return block ? block->findFunction() : nullptr;
|
||||
}
|
||||
|
||||
bool Statement::isInnermost() const {
|
||||
struct NestedLoopCounter : public StmtWalker<NestedLoopCounter> {
|
||||
unsigned numNestedLoops;
|
||||
NestedLoopCounter() : numNestedLoops(0) {}
|
||||
void walkForStmt(const ForStmt *fs) { numNestedLoops++; }
|
||||
};
|
||||
|
||||
NestedLoopCounter nlc;
|
||||
nlc.walk(const_cast<Statement *>(this));
|
||||
return nlc.numNestedLoops == 1;
|
||||
}
|
||||
|
||||
MLValue *Statement::getOperand(unsigned idx) {
|
||||
return getStmtOperand(idx).get();
|
||||
}
|
||||
|
@ -361,6 +349,20 @@ void ForStmt::setConstantUpperBound(int64_t value) {
|
|||
setUpperBound({}, AffineMap::getConstantMap(value, getContext()));
|
||||
}
|
||||
|
||||
bool ForStmt::matchingBoundOperandList() const {
|
||||
if (lbMap->getNumDims() != ubMap->getNumDims() ||
|
||||
lbMap->getNumSymbols() != ubMap->getNumSymbols())
|
||||
return false;
|
||||
|
||||
unsigned numOperands = lbMap->getNumInputs();
|
||||
for (unsigned i = 0, e = lbMap->getNumInputs(); i < e; i++) {
|
||||
// Compare MLValue *'s.
|
||||
if (getOperand(i) != getOperand(numOperands + i))
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// IfStmt
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -27,6 +27,7 @@
|
|||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/StandardOps.h"
|
||||
#include "mlir/IR/StmtVisitor.h"
|
||||
#include "mlir/Transforms/LoopUtils.h"
|
||||
#include "mlir/Transforms/Pass.h"
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
#include "llvm/Support/CommandLine.h"
|
||||
|
@ -176,76 +177,41 @@ bool mlir::loopUnrollFull(ForStmt *forStmt) {
|
|||
return false;
|
||||
}
|
||||
|
||||
/// Returns the upper bound of an unrolled loop with lower bound 'lb' and with
|
||||
/// the specified trip count, stride, and unroll factor.
|
||||
static AffineMap *getUnrolledLoopUpperBound(AffineMap *lbMap,
|
||||
uint64_t tripCount,
|
||||
unsigned unrollFactor, int64_t step,
|
||||
MLFuncBuilder *builder) {
|
||||
assert(lbMap->getNumResults() == 1);
|
||||
auto *lbExpr = lbMap->getResult(0);
|
||||
// lbExpr + (count - count % unrollFactor - 1) * step).
|
||||
auto *expr = builder->getAddExpr(
|
||||
lbExpr, builder->getConstantExpr(
|
||||
(tripCount - tripCount % unrollFactor - 1) * step));
|
||||
return builder->getAffineMap(lbMap->getNumDims(), lbMap->getNumSymbols(),
|
||||
{expr}, {});
|
||||
/// Unrolls and jams this loop by the specified factor or by the trip count (if
|
||||
/// constant) whichever is lower.
|
||||
bool mlir::loopUnrollUpToFactor(ForStmt *forStmt, uint64_t unrollFactor) {
|
||||
Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(*forStmt);
|
||||
|
||||
if (mayBeConstantTripCount.hasValue() &&
|
||||
mayBeConstantTripCount.getValue() < unrollFactor)
|
||||
return loopUnrollByFactor(forStmt, mayBeConstantTripCount.getValue());
|
||||
return loopUnrollByFactor(forStmt, unrollFactor);
|
||||
}
|
||||
|
||||
/// Returns the lower bound of the cleanup loop when unrolling a loop with lower
|
||||
/// bound 'lb' and with the specified trip count, stride, and unroll factor.
|
||||
static AffineMap *getCleanupLoopLowerBound(AffineMap *lbMap, uint64_t tripCount,
|
||||
unsigned unrollFactor, int64_t step,
|
||||
MLFuncBuilder *builder) {
|
||||
assert(lbMap->getNumResults() == 1);
|
||||
auto *lbExpr = lbMap->getResult(0);
|
||||
// lbExpr + (count - count % unrollFactor) * step);
|
||||
auto *expr = builder->getAddExpr(
|
||||
lbExpr,
|
||||
builder->getConstantExpr((tripCount - tripCount % unrollFactor) * step));
|
||||
return builder->getAffineMap(lbMap->getNumDims(), lbMap->getNumSymbols(),
|
||||
{expr}, {});
|
||||
}
|
||||
|
||||
/// Unrolls this loop by the specified unroll factor.
|
||||
/// Unrolls this loop by the specified factor. Returns true if the loop
|
||||
/// is successfully unrolled.
|
||||
bool mlir::loopUnrollByFactor(ForStmt *forStmt, uint64_t unrollFactor) {
|
||||
assert(unrollFactor >= 1 && "unroll factor shoud be >= 1");
|
||||
assert(unrollFactor >= 1 && "unroll factor should be >= 1");
|
||||
|
||||
if (unrollFactor == 1 || forStmt->getStatements().empty())
|
||||
return false;
|
||||
|
||||
Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(*forStmt);
|
||||
|
||||
if (!mayBeConstantTripCount.hasValue() &&
|
||||
getLargestDivisorOfTripCount(*forStmt) % unrollFactor != 0)
|
||||
return false;
|
||||
|
||||
const AffineBound &lb = forStmt->getLowerBound();
|
||||
const AffineBound &ub = forStmt->getLowerBound();
|
||||
auto lbMap = lb.getMap();
|
||||
auto ubMap = lb.getMap();
|
||||
auto *lbMap = forStmt->getLowerBoundMap();
|
||||
auto *ubMap = forStmt->getUpperBoundMap();
|
||||
|
||||
// Loops with max/min expressions won't be unrolled here (the output can't be
|
||||
// expressed as an MLFunction in the general case). However, the right way to
|
||||
// do such unrolling for an MLFunction would be to specialize the loop for the
|
||||
// 'hotspot' case and unroll that hotspot case.
|
||||
// 'hotspot' case and unroll that hotspot.
|
||||
if (lbMap->getNumResults() != 1 || ubMap->getNumResults() != 1)
|
||||
return false;
|
||||
|
||||
// TODO(bondhugula): handle bounds with different sets of operands.
|
||||
// Same operand list for now.
|
||||
if (lbMap->getNumDims() != ubMap->getNumDims() ||
|
||||
lbMap->getNumSymbols() != ubMap->getNumSymbols())
|
||||
return false;
|
||||
unsigned i, e = lb.getNumOperands();
|
||||
for (i = 0; i < e; i++) {
|
||||
if (lb.getStmtOperand(i).get() != ub.getStmtOperand(i).get())
|
||||
break;
|
||||
}
|
||||
if (i != e)
|
||||
// Same operand list for lower and upper bound for now.
|
||||
// TODO(bondhugula): handle bounds with different operand lists.
|
||||
if (!forStmt->matchingBoundOperandList())
|
||||
return false;
|
||||
|
||||
int64_t step = forStmt->getStep();
|
||||
Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(*forStmt);
|
||||
|
||||
// If the trip count is lower than the unroll factor, no unrolled body.
|
||||
// TODO(bondhugula): option to specify cleanup loop unrolling.
|
||||
|
@ -254,43 +220,29 @@ bool mlir::loopUnrollByFactor(ForStmt *forStmt, uint64_t unrollFactor) {
|
|||
return false;
|
||||
|
||||
// Generate the cleanup loop if trip count isn't a multiple of unrollFactor.
|
||||
// If the trip count is unknown, we currently unroll only when the unknown
|
||||
// trip count is known to be a multiple of unroll factor - hence, no cleanup
|
||||
// loop will be necessary in those cases.
|
||||
// TODO(bondhugula): handle generation of cleanup loop for unknown trip count
|
||||
// when it's not known to be a multiple of unroll factor (still for single
|
||||
// result / same operands case).
|
||||
if (mayBeConstantTripCount.hasValue() &&
|
||||
mayBeConstantTripCount.getValue() % unrollFactor != 0) {
|
||||
uint64_t tripCount = mayBeConstantTripCount.getValue();
|
||||
if (getLargestDivisorOfTripCount(*forStmt) % unrollFactor != 0) {
|
||||
DenseMap<const MLValue *, MLValue *> operandMap;
|
||||
MLFuncBuilder builder(forStmt->getBlock(), ++StmtBlock::iterator(forStmt));
|
||||
auto *cleanupForStmt = cast<ForStmt>(builder.clone(*forStmt, operandMap));
|
||||
if (forStmt->hasConstantLowerBound()) {
|
||||
cleanupForStmt->setConstantLowerBound(
|
||||
forStmt->getConstantLowerBound() +
|
||||
(tripCount - tripCount % unrollFactor) * step);
|
||||
} else {
|
||||
cleanupForStmt->setLowerBoundMap(
|
||||
getCleanupLoopLowerBound(forStmt->getLowerBoundMap(), tripCount,
|
||||
unrollFactor, step, &builder));
|
||||
}
|
||||
auto *clLbMap = getCleanupLoopLowerBound(*forStmt, unrollFactor, &builder);
|
||||
assert(clLbMap &&
|
||||
"cleanup loop lower bound map for single result bound maps can "
|
||||
"always be determined");
|
||||
cleanupForStmt->setLowerBoundMap(clLbMap);
|
||||
// Promote the loop body up if this has turned into a single iteration loop.
|
||||
promoteIfSingleIteration(cleanupForStmt);
|
||||
|
||||
// The upper bound needs to be adjusted.
|
||||
if (forStmt->hasConstantUpperBound()) {
|
||||
forStmt->setConstantUpperBound(
|
||||
forStmt->getConstantLowerBound() +
|
||||
(tripCount - tripCount % unrollFactor - 1) * step);
|
||||
} else {
|
||||
forStmt->setUpperBoundMap(
|
||||
getUnrolledLoopUpperBound(forStmt->getLowerBoundMap(), tripCount,
|
||||
unrollFactor, step, &builder));
|
||||
}
|
||||
// Adjust upper bound.
|
||||
auto *unrolledUbMap =
|
||||
getUnrolledLoopUpperBound(*forStmt, unrollFactor, &builder);
|
||||
assert(unrolledUbMap &&
|
||||
"upper bound map can alwayys be determined for an unrolled loop "
|
||||
"with single result bounds");
|
||||
forStmt->setUpperBoundMap(unrolledUbMap);
|
||||
}
|
||||
|
||||
// Scale the step of loop being unrolled by unroll factor.
|
||||
int64_t step = forStmt->getStep();
|
||||
forStmt->setStep(step * unrollFactor);
|
||||
|
||||
// Builder to insert unrolled bodies right after the last statement in the
|
||||
|
|
|
@ -41,15 +41,16 @@
|
|||
//
|
||||
// Note: 'if/else' blocks are not jammed. So, if there are loops inside if
|
||||
// stmt's, bodies of those loops will not be jammed.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
#include "mlir/Transforms/Passes.h"
|
||||
|
||||
#include "mlir/Analysis/LoopAnalysis.h"
|
||||
#include "mlir/IR/AffineExpr.h"
|
||||
#include "mlir/IR/AffineMap.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/StandardOps.h"
|
||||
#include "mlir/IR/StmtVisitor.h"
|
||||
#include "mlir/Transforms/LoopUtils.h"
|
||||
#include "mlir/Transforms/Pass.h"
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
#include "llvm/Support/CommandLine.h"
|
||||
|
@ -108,6 +109,15 @@ bool LoopUnrollAndJam::runOnForStmt(ForStmt *forStmt) {
|
|||
return loopUnrollJamByFactor(forStmt, kDefaultUnrollJamFactor);
|
||||
}
|
||||
|
||||
bool mlir::loopUnrollJamUpToFactor(ForStmt *forStmt, uint64_t unrollJamFactor) {
|
||||
Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(*forStmt);
|
||||
|
||||
if (mayBeConstantTripCount.hasValue() &&
|
||||
mayBeConstantTripCount.getValue() < unrollJamFactor)
|
||||
return loopUnrollJamByFactor(forStmt, mayBeConstantTripCount.getValue());
|
||||
return loopUnrollJamByFactor(forStmt, unrollJamFactor);
|
||||
}
|
||||
|
||||
/// Unrolls and jams this loop by the specified factor.
|
||||
bool mlir::loopUnrollJamByFactor(ForStmt *forStmt, uint64_t unrollJamFactor) {
|
||||
// Gathers all maximal sub-blocks of statements that do not themselves include
|
||||
|
@ -140,19 +150,32 @@ bool mlir::loopUnrollJamByFactor(ForStmt *forStmt, uint64_t unrollJamFactor) {
|
|||
if (unrollJamFactor == 1 || forStmt->getStatements().empty())
|
||||
return false;
|
||||
|
||||
Optional<uint64_t> mayTripCount = getConstantTripCount(*forStmt).getValue();
|
||||
Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(*forStmt);
|
||||
|
||||
if (!mayTripCount.hasValue())
|
||||
if (!mayBeConstantTripCount.hasValue() &&
|
||||
getLargestDivisorOfTripCount(*forStmt) % unrollJamFactor != 0)
|
||||
return false;
|
||||
|
||||
uint64_t tripCount = mayTripCount.getValue();
|
||||
int64_t lb = forStmt->getConstantLowerBound();
|
||||
int64_t step = forStmt->getStep();
|
||||
auto *lbMap = forStmt->getLowerBoundMap();
|
||||
auto *ubMap = forStmt->getUpperBoundMap();
|
||||
|
||||
// If the trip count is lower than the unroll jam factor, no unrolled body.
|
||||
// Loops with max/min expressions won't be unrolled here (the output can't be
|
||||
// expressed as an MLFunction in the general case). However, the right way to
|
||||
// do such unrolling for an MLFunction would be to specialize the loop for the
|
||||
// 'hotspot' case and unroll that hotspot.
|
||||
if (lbMap->getNumResults() != 1 || ubMap->getNumResults() != 1)
|
||||
return false;
|
||||
|
||||
// Same operand list for lower and upper bound for now.
|
||||
// TODO(bondhugula): handle bounds with different sets of operands.
|
||||
if (!forStmt->matchingBoundOperandList())
|
||||
return false;
|
||||
|
||||
// If the trip count is lower than the unroll jam factor, no unroll jam.
|
||||
// TODO(bondhugula): option to specify cleanup loop unrolling.
|
||||
if (tripCount < unrollJamFactor)
|
||||
return true;
|
||||
if (mayBeConstantTripCount.hasValue() &&
|
||||
mayBeConstantTripCount.getValue() < unrollJamFactor)
|
||||
return false;
|
||||
|
||||
// Gather all sub-blocks to jam upon the loop being unrolled.
|
||||
JamBlockGatherer jbg;
|
||||
|
@ -161,23 +184,27 @@ bool mlir::loopUnrollJamByFactor(ForStmt *forStmt, uint64_t unrollJamFactor) {
|
|||
|
||||
// Generate the cleanup loop if trip count isn't a multiple of
|
||||
// unrollJamFactor.
|
||||
if (tripCount % unrollJamFactor) {
|
||||
if (mayBeConstantTripCount.hasValue() &&
|
||||
mayBeConstantTripCount.getValue() % unrollJamFactor != 0) {
|
||||
DenseMap<const MLValue *, MLValue *> operandMap;
|
||||
// Insert the cleanup loop right after 'forStmt'.
|
||||
MLFuncBuilder builder(forStmt->getBlock(),
|
||||
std::next(StmtBlock::iterator(forStmt)));
|
||||
auto *cleanupForStmt = cast<ForStmt>(builder.clone(*forStmt, operandMap));
|
||||
cleanupForStmt->setConstantLowerBound(
|
||||
lb + (tripCount - tripCount % unrollJamFactor) * step);
|
||||
cleanupForStmt->setLowerBoundMap(
|
||||
getCleanupLoopLowerBound(*forStmt, unrollJamFactor, &builder));
|
||||
|
||||
// The upper bound needs to be adjusted.
|
||||
forStmt->setUpperBoundMap(
|
||||
getUnrolledLoopUpperBound(*forStmt, unrollJamFactor, &builder));
|
||||
|
||||
// Promote the loop body up if this has turned into a single iteration loop.
|
||||
promoteIfSingleIteration(cleanupForStmt);
|
||||
}
|
||||
|
||||
MLFuncBuilder b(forStmt);
|
||||
// Scale the step of loop being unroll-jammed by the unroll-jam factor.
|
||||
int64_t step = forStmt->getStep();
|
||||
forStmt->setStep(step * unrollJamFactor);
|
||||
forStmt->setConstantUpperBound(
|
||||
lb + (tripCount - tripCount % unrollJamFactor - 1) * step);
|
||||
|
||||
for (auto &subBlock : subBlocks) {
|
||||
// Builder to insert unroll-jammed bodies. Insert right at the end of
|
||||
|
|
|
@ -19,32 +19,129 @@
|
|||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Transforms/Passes.h"
|
||||
#include "mlir/Transforms/LoopUtils.h"
|
||||
|
||||
#include "mlir/Analysis/LoopAnalysis.h"
|
||||
#include "mlir/IR/AffineExpr.h"
|
||||
#include "mlir/IR/AffineMap.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/StandardOps.h"
|
||||
#include "mlir/IR/Statements.h"
|
||||
#include "mlir/IR/StmtVisitor.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
/// Returns the upper bound of an unrolled loop with lower bound 'lb' and with
|
||||
/// the specified trip count, stride, and unroll factor. Returns nullptr when
|
||||
/// the trip count can't be expressed as an affine expression.
|
||||
AffineMap *mlir::getUnrolledLoopUpperBound(const ForStmt &forStmt,
|
||||
unsigned unrollFactor,
|
||||
MLFuncBuilder *builder) {
|
||||
auto *lbMap = forStmt.getLowerBoundMap();
|
||||
|
||||
// Single result lower bound map only.
|
||||
if (lbMap->getNumResults() != 1)
|
||||
return nullptr;
|
||||
|
||||
// Sometimes, the trip count cannot be expressed as an affine expression.
|
||||
auto *tripCountExpr = getTripCountExpr(forStmt);
|
||||
if (!tripCountExpr)
|
||||
return nullptr;
|
||||
|
||||
AffineExpr *newUbExpr;
|
||||
auto *lbExpr = lbMap->getResult(0);
|
||||
int64_t step = forStmt.getStep();
|
||||
// lbExpr + (count - count % unrollFactor - 1) * step).
|
||||
if (auto *cTripCountExpr = dyn_cast<AffineConstantExpr>(tripCountExpr)) {
|
||||
uint64_t tripCount = static_cast<uint64_t>(cTripCountExpr->getValue());
|
||||
newUbExpr = builder->getAddExpr(
|
||||
lbExpr, builder->getConstantExpr(
|
||||
(tripCount - tripCount % unrollFactor - 1) * step));
|
||||
} else {
|
||||
newUbExpr = builder->getAddExpr(
|
||||
lbExpr, builder->getMulExpr(
|
||||
builder->getSubExpr(
|
||||
builder->getSubExpr(
|
||||
tripCountExpr,
|
||||
builder->getModExpr(tripCountExpr, unrollFactor)),
|
||||
1),
|
||||
step));
|
||||
}
|
||||
return builder->getAffineMap(lbMap->getNumDims(), lbMap->getNumSymbols(),
|
||||
{newUbExpr}, {});
|
||||
}
|
||||
|
||||
/// Returns the lower bound of the cleanup loop when unrolling a loop with lower
|
||||
/// bound 'lb' and with the specified trip count, stride, and unroll factor.
|
||||
/// Returns nullptr when the trip count can't be expressed as an affine
|
||||
/// expression.
|
||||
AffineMap *mlir::getCleanupLoopLowerBound(const ForStmt &forStmt,
|
||||
unsigned unrollFactor,
|
||||
MLFuncBuilder *builder) {
|
||||
auto *lbMap = forStmt.getLowerBoundMap();
|
||||
|
||||
// Single result lower bound map only.
|
||||
if (lbMap->getNumResults() != 1)
|
||||
return nullptr;
|
||||
|
||||
// Sometimes the trip count cannot be expressed as an affine expression.
|
||||
auto *tripCountExpr = getTripCountExpr(forStmt);
|
||||
if (!tripCountExpr)
|
||||
return nullptr;
|
||||
|
||||
AffineExpr *newLbExpr;
|
||||
auto *lbExpr = lbMap->getResult(0);
|
||||
int64_t step = forStmt.getStep();
|
||||
|
||||
// lbExpr + (count - count % unrollFactor) * step);
|
||||
if (auto *cTripCountExpr = dyn_cast<AffineConstantExpr>(tripCountExpr)) {
|
||||
uint64_t tripCount = static_cast<uint64_t>(cTripCountExpr->getValue());
|
||||
newLbExpr = builder->getAddExpr(
|
||||
lbExpr, builder->getConstantExpr(
|
||||
(tripCount - tripCount % unrollFactor) * step));
|
||||
} else {
|
||||
newLbExpr = builder->getAddExpr(
|
||||
lbExpr, builder->getMulExpr(
|
||||
builder->getSubExpr(
|
||||
tripCountExpr,
|
||||
builder->getModExpr(tripCountExpr, unrollFactor)),
|
||||
step));
|
||||
}
|
||||
return builder->getAffineMap(lbMap->getNumDims(), lbMap->getNumSymbols(),
|
||||
{newLbExpr}, {});
|
||||
}
|
||||
|
||||
/// Promotes the loop body of a forStmt to its containing block if the forStmt
|
||||
/// was known to have a single iteration. Returns false otherwise.
|
||||
// TODO(bondhugula): extend this for arbitrary affine bounds.
|
||||
bool mlir::promoteIfSingleIteration(ForStmt *forStmt) {
|
||||
Optional<uint64_t> tripCount = getConstantTripCount(*forStmt);
|
||||
if (!tripCount.hasValue() || !forStmt->hasConstantLowerBound())
|
||||
if (!tripCount.hasValue() || tripCount.getValue() != 1)
|
||||
return false;
|
||||
|
||||
if (tripCount.getValue() != 1)
|
||||
// TODO(mlir-team): there is no builder for a max.
|
||||
if (forStmt->getLowerBoundMap()->getNumResults() != 1)
|
||||
return false;
|
||||
|
||||
// Replaces all IV uses to its single iteration value.
|
||||
if (!forStmt->use_empty()) {
|
||||
if (forStmt->hasConstantLowerBound()) {
|
||||
auto *mlFunc = forStmt->findFunction();
|
||||
MLFuncBuilder topBuilder(&mlFunc->front());
|
||||
auto constOp = topBuilder.create<ConstantAffineIntOp>(
|
||||
forStmt->getLoc(), forStmt->getConstantLowerBound());
|
||||
forStmt->replaceAllUsesWith(constOp->getResult());
|
||||
// Move the statements to the containing block.
|
||||
} else {
|
||||
const AffineBound lb = forStmt->getLowerBound();
|
||||
SmallVector<SSAValue *, 4> lbOperands(lb.operand_begin(),
|
||||
lb.operand_end());
|
||||
MLFuncBuilder builder(forStmt->getBlock(), StmtBlock::iterator(forStmt));
|
||||
auto affineApplyOp = builder.create<AffineApplyOp>(
|
||||
forStmt->getLoc(), lb.getMap(), lbOperands);
|
||||
forStmt->replaceAllUsesWith(affineApplyOp->getResult(0));
|
||||
}
|
||||
}
|
||||
// Move the loop body statements to the loop's containing block.
|
||||
auto *block = forStmt->getBlock();
|
||||
block->getStatements().splice(StmtBlock::iterator(forStmt),
|
||||
forStmt->getStatements());
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
// RUN: mlir-opt %s -o - -loop-unroll-jam -unroll-jam-factor=2 | FileCheck %s
|
||||
|
||||
// CHECK: #map0 = (d0) -> (d0 + 1)
|
||||
// This should be matched to M1, but M1 is defined later.
|
||||
// CHECK: {{#map[0-9]+}} = ()[s0] -> (s0 + 8)
|
||||
|
||||
// CHECK-LABEL: mlfunc @unroll_jam_imperfect_nest() {
|
||||
mlfunc @unroll_jam_imperfect_nest() {
|
||||
|
@ -34,3 +36,55 @@ mlfunc @unroll_jam_imperfect_nest() {
|
|||
// CHECK-NEXT: %14 = "addi32"(%c100, %c100) : (affineint, affineint) -> i32
|
||||
return
|
||||
}
|
||||
|
||||
// UNROLL-BY-4-LABEL: mlfunc @loop_nest_unknown_count_1(%arg0 : affineint) {
|
||||
mlfunc @loop_nest_unknown_count_1(%N : affineint) {
|
||||
// UNROLL-BY-4-NEXT: for %i0 = 1 to #map{{[0-9]+}}()[%arg0] step 4 {
|
||||
// UNROLL-BY-4-NEXT: for %i1 = 1 to 100 {
|
||||
// UNROLL-BY-4-NEXT: %0 = "foo"() : () -> i32
|
||||
// UNROLL-BY-4-NEXT: %1 = "foo"() : () -> i32
|
||||
// UNROLL-BY-4-NEXT: %2 = "foo"() : () -> i32
|
||||
// UNROLL-BY-4-NEXT: %3 = "foo"() : () -> i32
|
||||
// UNROLL-BY-4-NEXT: }
|
||||
// UNROLL-BY-4-NEXT: }
|
||||
// A cleanup loop should be generated here.
|
||||
// UNROLL-BY-4-NEXT: for %i2 = #map{{[0-9]+}}()[%arg0] to %arg0 {
|
||||
// UNROLL-BY-4-NEXT: for %i3 = 1 to 100 {
|
||||
// UNROLL-BY-4-NEXT: %4 = "foo"() : () -> i32
|
||||
// UNROLL-BY-4_NEXT: }
|
||||
// UNROLL-BY-4_NEXT: }
|
||||
// Specify the lower bound in a form so that both lb and ub operands match.
|
||||
for %i = ()[s0] -> (1)()[%N] to %N {
|
||||
for %j = 1 to 100 {
|
||||
%x = "foo"() : () -> i32
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// UNROLL-BY-4-LABEL: mlfunc @loop_nest_unknown_count_2(%arg0 : affineint) {
|
||||
mlfunc @loop_nest_unknown_count_2(%arg : affineint) {
|
||||
// UNROLL-BY-4-NEXT: for %i0 = %arg0 to #map{{[0-9]+}}()[%arg0] step 4 {
|
||||
// UNROLL-BY-4-NEXT: for %i1 = 1 to 100 {
|
||||
// UNROLL-BY-4-NEXT: %0 = "foo"(%i0) : (affineint) -> i32
|
||||
// UNROLL-BY-4-NEXT: %1 = affine_apply #map{{[0-9]+}}(%i0)
|
||||
// UNROLL-BY-4-NEXT: %2 = "foo"(%1) : (affineint) -> i32
|
||||
// UNROLL-BY-4-NEXT: %3 = affine_apply #map{{[0-9]+}}(%i0)
|
||||
// UNROLL-BY-4-NEXT: %4 = "foo"(%3) : (affineint) -> i32
|
||||
// UNROLL-BY-4-NEXT: %5 = affine_apply #map{{[0-9]+}}(%i0)
|
||||
// UNROLL-BY-4-NEXT: %6 = "foo"(%5) : (affineint) -> i32
|
||||
// UNROLL-BY-4-NEXT: }
|
||||
// UNROLL-BY-4-NEXT: }
|
||||
// The cleanup loop is a single iteration one and is promoted.
|
||||
// UNROLL-BY-4-NEXT: %7 = affine_apply [[M1:#map{{[0-9]+}}]]()[%arg0]
|
||||
// UNROLL-BY-4-NEXT: for %i3 = 1 to 100 {
|
||||
// UNROLL-BY-4-NEXT: %8 = "foo"() : () -> i32
|
||||
// UNROLL-BY-4_NEXT: }
|
||||
// Specify the lower bound in a form so that both lb and ub operands match.
|
||||
for %i = ()[s0] -> (s0) ()[%arg] to ()[s0] -> (s0+8) ()[%arg] {
|
||||
for %j = 1 to 100 {
|
||||
%x = "foo"(%i) : (affineint) -> i32
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
|
|
@ -462,7 +462,7 @@ mlfunc @loop_nest_operand2() {
|
|||
}
|
||||
|
||||
// Difference between loop bounds is constant, but not a multiple of unroll
|
||||
// factor. A cleanup loop is generated.
|
||||
// factor. The cleanup loop happens to be a single iteration one and is promoted.
|
||||
// UNROLL-BY-4-LABEL: mlfunc @loop_nest_operand3() {
|
||||
mlfunc @loop_nest_operand3() {
|
||||
// UNROLL-BY-4: for %i0 = 1 to 100 step 2 {
|
||||
|
@ -473,30 +473,30 @@ mlfunc @loop_nest_operand3() {
|
|||
// UNROLL-BY-4-NEXT: %2 = "foo"() : () -> i32
|
||||
// UNROLL-BY-4-NEXT: %3 = "foo"() : () -> i32
|
||||
// UNROLL-BY-4-NEXT: }
|
||||
// UNROLL-BY-4-NEXT: for %i2 = #map{{[0-9]+}}(%i0) to #map{{[0-9]+}}(%i0) {
|
||||
// UNROLL-BY-4-NEXT: %4 = "foo"() : () -> i32
|
||||
// UNROLL-BY-4-NEXT: }
|
||||
for %j = (d0) -> (d0) (%i) to (d0) -> (d0 + 4) (%i) {
|
||||
for %j = (d0) -> (d0) (%i) to (d0) -> (d0 + 8) (%i) {
|
||||
%x = "foo"() : () -> i32
|
||||
}
|
||||
} // UNROLL-BY-4: }
|
||||
return
|
||||
}
|
||||
|
||||
// Will not be unrolled for now. TODO(bondhugula): handle this.
|
||||
// xUNROLL-BY-4-LABEL: mlfunc @loop_nest_operand4(%arg0 : affineint) {
|
||||
// UNROLL-BY-4-LABEL: mlfunc @loop_nest_operand4(%arg0 : affineint) {
|
||||
mlfunc @loop_nest_operand4(%N : affineint) {
|
||||
// UNROLL-BY-4: for %i0 = 1 to 100 step 2 {
|
||||
for %i = 1 to 100 step 2 {
|
||||
// UNROLL-BY-4: for %i1 = 0 to %arg0 {
|
||||
// xUNROLL-BY-4: for %i1 = 0 to #map{{[0-9]+}}(%N) step 4 {
|
||||
// xUNROLL-BY-4: %0 = "foo"() : () -> i32
|
||||
// xUNROLL-BY-4-NEXT: %1 = "foo"() : () -> i32
|
||||
// xUNROLL-BY-4-NEXT: %2 = "foo"() : () -> i32
|
||||
// xUNROLL-BY-4-NEXT: %3 = "foo"() : () -> i32
|
||||
// xUNROLL-BY-4-NEXT: }
|
||||
// a cleanup loop should be generated here.
|
||||
for %j = (d0) -> (0) (%N) to %N {
|
||||
// UNROLL-BY-4: for %i0 = 1 to 100 {
|
||||
for %i = 1 to 100 {
|
||||
// UNROLL-BY-4: for %i1 = 1 to #map{{[0-9]+}}()[%arg0] step 4 {
|
||||
// UNROLL-BY-4: %0 = "foo"() : () -> i32
|
||||
// UNROLL-BY-4-NEXT: %1 = "foo"() : () -> i32
|
||||
// UNROLL-BY-4-NEXT: %2 = "foo"() : () -> i32
|
||||
// UNROLL-BY-4-NEXT: %3 = "foo"() : () -> i32
|
||||
// UNROLL-BY-4-NEXT: }
|
||||
// A cleanup loop will be be generated here.
|
||||
// UNROLL-BY-4-NEXT: for %i2 = #map{{[0-9]+}}()[%arg0] to %arg0 {
|
||||
// UNROLL-BY-4-NEXT: %4 = "foo"() : () -> i32
|
||||
// UNROLL-BY-4_NEXT: }
|
||||
// Specify the lower bound so that both lb and ub operands match.
|
||||
for %j = ()[s0] -> (1)()[%N] to %N {
|
||||
%x = "foo"() : () -> i32
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue