Extend loop unroll/unroll-and-jam to affine bounds + refactor related code.

- extend loop unroll-jam similar to loop unroll for affine bounds
- extend both loop unroll/unroll-jam to generate a cleanup loop when the trip
  count is not a multiple of the unroll factor.
- extend promotion of single iteration loops to work with affine bounds
- fix typo bugs in loop unroll
- refactor common code b/w loop unroll and loop unroll-jam
- move prototypes of non-pass transforms to LoopUtils.h
- add additional builder methods.
- introduce loopUnrollUpToFactor(factor) to unroll by either factor or trip
  count, whichever is less.
- remove Statement::isInnermost (not used for now - will come back at the right
  place/in right form later)

PiperOrigin-RevId: 213471227
This commit is contained in:
Uday Bondhugula 2018-09-18 10:22:03 -07:00 committed by jpienaar
parent 7103779fb8
commit ab4797229c
14 changed files with 373 additions and 179 deletions

View File

@ -32,7 +32,7 @@ class ForStmt;
/// Returns the trip count of the loop as an affine expression if the latter is
/// expressible as an affine expression, and nullptr otherwise. The trip count
/// expression is simplified before returning.
AffineExpr *getTripCount(const ForStmt &forStmt);
AffineExpr *getTripCountExpr(const ForStmt &forStmt);
/// Returns the trip count of the loop if it's a constant, None otherwise. This
/// uses affine expression analysis and is able to determine constant trip count

View File

@ -101,7 +101,9 @@ public:
AffineSymbolExpr *getSymbolExpr(unsigned position);
AffineConstantExpr *getConstantExpr(int64_t constant);
AffineExpr *getAddExpr(AffineExpr *lhs, AffineExpr *rhs);
AffineExpr *getAddExpr(AffineExpr *lhs, int64_t rhs);
AffineExpr *getSubExpr(AffineExpr *lhs, AffineExpr *rhs);
AffineExpr *getSubExpr(AffineExpr *lhs, int64_t rhs);
AffineExpr *getMulExpr(AffineExpr *lhs, AffineExpr *rhs);
AffineExpr *getMulExpr(AffineExpr *lhs, int64_t rhs);
AffineExpr *getModExpr(AffineExpr *lhs, AffineExpr *rhs);

View File

@ -82,9 +82,6 @@ public:
/// Returns nullptr if the statement is unlinked.
MLFunction *findFunction() const;
/// Returns true if there are no more loops nested under this stmt.
bool isInnermost() const;
/// Destroys this statement and its subclass data.
void destroy();

View File

@ -275,6 +275,10 @@ public:
/// Sets the upper bound to the given constant value.
void setConstantUpperBound(int64_t value);
/// Returns true if both the lower and upper bound have the same operand lists
/// (same operands in the same order).
bool matchingBoundOperandList() const;
//===--------------------------------------------------------------------===//
// Operands
//===--------------------------------------------------------------------===//
@ -343,7 +347,9 @@ private:
AffineMap *ubMap;
// Constant step.
int64_t step;
// Operands for the lower and upper bounds.
// Operands for the lower and upper bounds, with the former followed by the
// latter. Dimensional operands are followed by symbolic operands for each
// bound.
std::vector<StmtOperand> operands;
explicit ForStmt(Location *location, unsigned numOperands, AffineMap *lbMap,

View File

@ -0,0 +1,77 @@
//===- LoopUtils.h - Loop transformation utilities --------------*- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This header file defines prototypes for various loop transformation utility
// methods: these are not passes by themselves but are used either by passes,
// optimization sequences, or in turn by other transformation utilities.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_TRANSFORMS_LOOP_UTILS_H
#define MLIR_TRANSFORMS_LOOP_UTILS_H
#include "mlir/Support/LLVM.h"
namespace mlir {
class AffineMap;
class ForStmt;
class MLFunction;
class MLFuncBuilder;
/// Unrolls this for statement completely if the trip count is known to be
/// constant. Returns false otherwise.
bool loopUnrollFull(ForStmt *forStmt);
/// Unrolls this for statement by the specified unroll factor. Returns false if
/// the loop cannot be unrolled either due to restrictions or due to invalid
/// unroll factors.
bool loopUnrollByFactor(ForStmt *forStmt, uint64_t unrollFactor);
/// Unrolls this loop by the specified unroll factor or its trip count,
/// whichever is lower.
bool loopUnrollUpToFactor(ForStmt *forStmt, uint64_t unrollFactor);
/// Unrolls and jams this loop by the specified factor. Returns true if the loop
/// is successfully unroll-jammed.
bool loopUnrollJamByFactor(ForStmt *forStmt, uint64_t unrollJamFactor);
/// Unrolls and jams this loop by the specified factor or by the trip count (if
/// constant), whichever is lower.
bool loopUnrollJamUpToFactor(ForStmt *forStmt, uint64_t unrollJamFactor);
/// Promotes the loop body of a ForStmt to its containing block if the ForStmt
/// was known to have a single iteration. Returns false otherwise.
bool promoteIfSingleIteration(ForStmt *forStmt);
/// Promotes all single iteration ForStmt's in the MLFunction, i.e., moves
/// their body into the containing StmtBlock.
void promoteSingleIterationLoops(MLFunction *f);
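/// Returns the lower bound of the cleanup loop when unrolling a loop
/// with the specified unroll factor. Returns nullptr if the lower bound map
/// has more than one result or if the trip count can't be expressed as an
/// affine expression.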
AffineMap *getCleanupLoopLowerBound(const ForStmt &forStmt,
unsigned unrollFactor,
MLFuncBuilder *builder);
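/// Returns the upper bound of an unrolled loop when unrolling with
/// the specified unroll factor. Returns nullptr if the lower bound map has
/// more than one result or if the trip count can't be expressed as an affine
/// expression.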
AffineMap *getUnrolledLoopUpperBound(const ForStmt &forStmt,
unsigned unrollFactor,
MLFuncBuilder *builder);
} // end namespace mlir
#endif // MLIR_TRANSFORMS_LOOP_UTILS_H

View File

@ -27,9 +27,7 @@
namespace mlir {
class ForStmt;
class FunctionPass;
class MLFunction;
class MLFunctionPass;
class ModulePass;
@ -38,19 +36,11 @@ class ModulePass;
MLFunctionPass *createLoopUnrollPass(int unrollFactor = -1,
int unrollFull = -1);
/// Unrolls this loop completely.
bool loopUnrollFull(ForStmt *forStmt);
/// Unrolls this loop by the specified unroll factor.
bool loopUnrollByFactor(ForStmt *forStmt, uint64_t unrollFactor);
/// Creates a loop unroll jam pass to unroll jam by the specified factor. A
/// factor of -1 lets the pass use the default factor or the one on the command
/// line if provided.
MLFunctionPass *createLoopUnrollAndJamPass(int unrollJamFactor = -1);
/// Unrolls and jams this loop by the specified factor.
bool loopUnrollJamByFactor(ForStmt *forStmt, uint64_t unrollJamFactor);
/// Creates an affine expression simplification pass.
FunctionPass *createSimplifyAffineExprPass();
@ -59,14 +49,6 @@ FunctionPass *createSimplifyAffineExprPass();
/// generated CFG functions.
ModulePass *createConvertToCFGPass();
/// Promotes the loop body of a ForStmt to its containing block if the ForStmt
/// was known to have a single iteration. Returns false otherwise.
bool promoteIfSingleIteration(ForStmt *forStmt);
/// Promotes all single iteration ForStmt's in the MLFunction, i.e., moves
/// their body into the containing StmtBlock.
void promoteSingleIterationLoops(MLFunction *f);
} // end namespace mlir
#endif // MLIR_TRANSFORMS_LOOP_H
#endif // MLIR_TRANSFORMS_PASSES_H

View File

@ -31,7 +31,7 @@ using mlir::AffineExpr;
/// Returns the trip count of the loop as an affine expression if the latter is
/// expressible as an affine expression, and nullptr otherwise. The trip count
/// expression is simplified before returning.
AffineExpr *mlir::getTripCount(const ForStmt &forStmt) {
AffineExpr *mlir::getTripCountExpr(const ForStmt &forStmt) {
// upper_bound - lower_bound + 1
int64_t loopSpan;
@ -43,32 +43,22 @@ AffineExpr *mlir::getTripCount(const ForStmt &forStmt) {
int64_t ub = forStmt.getConstantUpperBound();
loopSpan = ub - lb + 1;
} else {
const AffineBound lb = forStmt.getLowerBound();
const AffineBound ub = forStmt.getUpperBound();
auto lbMap = lb.getMap();
auto ubMap = ub.getMap();
auto *lbMap = forStmt.getLowerBoundMap();
auto *ubMap = forStmt.getUpperBoundMap();
// TODO(bondhugula): handle max/min of multiple expressions.
if (lbMap->getNumResults() != 1 || ubMap->getNumResults() != 1 ||
lbMap->getNumDims() != ubMap->getNumDims() ||
lbMap->getNumSymbols() != ubMap->getNumSymbols()) {
if (lbMap->getNumResults() != 1 || ubMap->getNumResults() != 1)
return nullptr;
}
// TODO(bondhugula): handle bounds with different operands.
unsigned i, e = lb.getNumOperands();
for (i = 0; i < e; i++) {
if (lb.getStmtOperand(i).get() != ub.getStmtOperand(i).get())
break;
}
// Bounds have different operands, unhandled for now.
if (i != e)
if (!forStmt.matchingBoundOperandList())
return nullptr;
// ub_expr - lb_expr + 1
auto *lbExpr = lbMap->getResult(0);
auto *ubExpr = ubMap->getResult(0);
auto *loopSpanExpr = AffineBinaryOpExpr::getAdd(
AffineBinaryOpExpr::getSub(ubMap->getResult(0), lbMap->getResult(0),
context),
1, context);
AffineBinaryOpExpr::getSub(ubExpr, lbExpr, context), 1, context);
if (auto *expr = simplifyAffineExpr(loopSpanExpr, lbMap->getNumDims(),
lbMap->getNumSymbols(), context))
@ -95,7 +85,7 @@ AffineExpr *mlir::getTripCount(const ForStmt &forStmt) {
/// method uses affine expression analysis (in turn using getTripCountExpr) and is
/// able to determine constant trip count in non-trivial cases.
llvm::Optional<uint64_t> mlir::getConstantTripCount(const ForStmt &forStmt) {
AffineExpr *tripCountExpr = getTripCount(forStmt);
AffineExpr *tripCountExpr = getTripCountExpr(forStmt);
if (auto *constExpr = dyn_cast_or_null<AffineConstantExpr>(tripCountExpr))
return constExpr->getValue();
@ -107,7 +97,7 @@ llvm::Optional<uint64_t> mlir::getConstantTripCount(const ForStmt &forStmt) {
/// expression analysis is used (indirectly through getTripCountExpr), and
/// this method is thus able to determine non-trivial divisors.
uint64_t mlir::getLargestDivisorOfTripCount(const ForStmt &forStmt) {
AffineExpr *tripCountExpr = getTripCount(forStmt);
AffineExpr *tripCountExpr = getTripCountExpr(forStmt);
if (!tripCountExpr)
return 1;

View File

@ -157,6 +157,10 @@ AffineExpr *Builder::getAddExpr(AffineExpr *lhs, AffineExpr *rhs) {
return AffineBinaryOpExpr::get(AffineExpr::Kind::Add, lhs, rhs, context);
}
/// Returns the affine expression 'lhs + rhs' where 'rhs' is an integer
/// constant. Simply forwards to AffineBinaryOpExpr::getAdd with this builder's
/// context.
AffineExpr *Builder::getAddExpr(AffineExpr *lhs, int64_t rhs) {
return AffineBinaryOpExpr::getAdd(lhs, rhs, context);
}
AffineExpr *Builder::getMulExpr(AffineExpr *lhs, AffineExpr *rhs) {
return AffineBinaryOpExpr::get(AffineExpr::Kind::Mul, lhs, rhs, context);
}
@ -171,6 +175,10 @@ AffineExpr *Builder::getSubExpr(AffineExpr *lhs, AffineExpr *rhs) {
return getAddExpr(lhs, getMulExpr(rhs, getConstantExpr(-1)));
}
/// Returns the affine expression 'lhs - rhs' where 'rhs' is an integer
/// constant.
AffineExpr *Builder::getSubExpr(AffineExpr *lhs, int64_t rhs) {
// Subtraction of a constant is expressed as addition of its negation.
return AffineBinaryOpExpr::getAdd(lhs, -rhs, context);
}
AffineExpr *Builder::getModExpr(AffineExpr *lhs, AffineExpr *rhs) {
return AffineBinaryOpExpr::get(AffineExpr::Kind::Mod, lhs, rhs, context);
}

View File

@ -84,18 +84,6 @@ MLFunction *Statement::findFunction() const {
return block ? block->findFunction() : nullptr;
}
bool Statement::isInnermost() const {
struct NestedLoopCounter : public StmtWalker<NestedLoopCounter> {
unsigned numNestedLoops;
NestedLoopCounter() : numNestedLoops(0) {}
void walkForStmt(const ForStmt *fs) { numNestedLoops++; }
};
NestedLoopCounter nlc;
nlc.walk(const_cast<Statement *>(this));
return nlc.numNestedLoops == 1;
}
MLValue *Statement::getOperand(unsigned idx) {
return getStmtOperand(idx).get();
}
@ -361,6 +349,20 @@ void ForStmt::setConstantUpperBound(int64_t value) {
setUpperBound({}, AffineMap::getConstantMap(value, getContext()));
}
/// Returns true if the lower and upper bound are fed by the same operands in
/// the same order, and false otherwise.
bool ForStmt::matchingBoundOperandList() const {
  // Bound maps with different dimension or symbol counts cannot take matching
  // operand lists.
  if (!(lbMap->getNumDims() == ubMap->getNumDims() &&
        lbMap->getNumSymbols() == ubMap->getNumSymbols()))
    return false;

  // The lower bound operands come first in the operand list and are followed
  // by the upper bound operands; compare them pairwise (MLValue pointers).
  const unsigned numLbOperands = lbMap->getNumInputs();
  for (unsigned idx = 0; idx != numLbOperands; ++idx) {
    if (getOperand(idx) != getOperand(numLbOperands + idx))
      return false;
  }
  return true;
}
//===----------------------------------------------------------------------===//
// IfStmt
//===----------------------------------------------------------------------===//

View File

@ -27,6 +27,7 @@
#include "mlir/IR/Builders.h"
#include "mlir/IR/StandardOps.h"
#include "mlir/IR/StmtVisitor.h"
#include "mlir/Transforms/LoopUtils.h"
#include "mlir/Transforms/Pass.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/Support/CommandLine.h"
@ -176,76 +177,41 @@ bool mlir::loopUnrollFull(ForStmt *forStmt) {
return false;
}
/// Returns the upper bound of an unrolled loop with lower bound 'lb' and with
/// the specified trip count, stride, and unroll factor.
static AffineMap *getUnrolledLoopUpperBound(AffineMap *lbMap,
uint64_t tripCount,
unsigned unrollFactor, int64_t step,
MLFuncBuilder *builder) {
assert(lbMap->getNumResults() == 1);
auto *lbExpr = lbMap->getResult(0);
// lbExpr + (count - count % unrollFactor - 1) * step).
auto *expr = builder->getAddExpr(
lbExpr, builder->getConstantExpr(
(tripCount - tripCount % unrollFactor - 1) * step));
return builder->getAffineMap(lbMap->getNumDims(), lbMap->getNumSymbols(),
{expr}, {});
/// Unrolls this loop by the specified unroll factor or by its trip count (if
/// the trip count is a known constant), whichever is lower. Returns true if
/// the loop was unrolled, and false otherwise.
bool mlir::loopUnrollUpToFactor(ForStmt *forStmt, uint64_t unrollFactor) {
  Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(*forStmt);
  if (mayBeConstantTripCount.hasValue()) {
    uint64_t tripCount = mayBeConstantTripCount.getValue();
    // A loop that never executes has nothing to unroll. Forwarding a factor of
    // zero would violate loopUnrollByFactor's 'unrollFactor >= 1'
    // precondition.
    if (tripCount == 0)
      return false;
    // Fewer iterations than the requested factor: unroll by the trip count.
    if (tripCount < unrollFactor)
      return loopUnrollByFactor(forStmt, tripCount);
  }
  return loopUnrollByFactor(forStmt, unrollFactor);
}
/// Returns the lower bound of the cleanup loop when unrolling a loop with lower
/// bound 'lb' and with the specified trip count, stride, and unroll factor.
static AffineMap *getCleanupLoopLowerBound(AffineMap *lbMap, uint64_t tripCount,
unsigned unrollFactor, int64_t step,
MLFuncBuilder *builder) {
assert(lbMap->getNumResults() == 1);
auto *lbExpr = lbMap->getResult(0);
// lbExpr + (count - count % unrollFactor) * step);
auto *expr = builder->getAddExpr(
lbExpr,
builder->getConstantExpr((tripCount - tripCount % unrollFactor) * step));
return builder->getAffineMap(lbMap->getNumDims(), lbMap->getNumSymbols(),
{expr}, {});
}
/// Unrolls this loop by the specified unroll factor.
/// Unrolls this loop by the specified factor. Returns true if the loop
/// is successfully unrolled.
bool mlir::loopUnrollByFactor(ForStmt *forStmt, uint64_t unrollFactor) {
assert(unrollFactor >= 1 && "unroll factor shoud be >= 1");
assert(unrollFactor >= 1 && "unroll factor should be >= 1");
if (unrollFactor == 1 || forStmt->getStatements().empty())
return false;
Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(*forStmt);
if (!mayBeConstantTripCount.hasValue() &&
getLargestDivisorOfTripCount(*forStmt) % unrollFactor != 0)
return false;
const AffineBound &lb = forStmt->getLowerBound();
const AffineBound &ub = forStmt->getLowerBound();
auto lbMap = lb.getMap();
auto ubMap = lb.getMap();
auto *lbMap = forStmt->getLowerBoundMap();
auto *ubMap = forStmt->getUpperBoundMap();
// Loops with max/min expressions won't be unrolled here (the output can't be
// expressed as an MLFunction in the general case). However, the right way to
// do such unrolling for an MLFunction would be to specialize the loop for the
// 'hotspot' case and unroll that hotspot case.
// 'hotspot' case and unroll that hotspot.
if (lbMap->getNumResults() != 1 || ubMap->getNumResults() != 1)
return false;
// TODO(bondhugula): handle bounds with different sets of operands.
// Same operand list for now.
if (lbMap->getNumDims() != ubMap->getNumDims() ||
lbMap->getNumSymbols() != ubMap->getNumSymbols())
return false;
unsigned i, e = lb.getNumOperands();
for (i = 0; i < e; i++) {
if (lb.getStmtOperand(i).get() != ub.getStmtOperand(i).get())
break;
}
if (i != e)
// Same operand list for lower and upper bound for now.
// TODO(bondhugula): handle bounds with different operand lists.
if (!forStmt->matchingBoundOperandList())
return false;
int64_t step = forStmt->getStep();
Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(*forStmt);
// If the trip count is lower than the unroll factor, no unrolled body.
// TODO(bondhugula): option to specify cleanup loop unrolling.
@ -254,43 +220,29 @@ bool mlir::loopUnrollByFactor(ForStmt *forStmt, uint64_t unrollFactor) {
return false;
// Generate the cleanup loop if trip count isn't a multiple of unrollFactor.
// If the trip count is unknown, we currently unroll only when the unknown
// trip count is known to be a multiple of unroll factor - hence, no cleanup
// loop will be necessary in those cases.
// TODO(bondhugula): handle generation of cleanup loop for unknown trip count
// when it's not known to be a multiple of unroll factor (still for single
// result / same operands case).
if (mayBeConstantTripCount.hasValue() &&
mayBeConstantTripCount.getValue() % unrollFactor != 0) {
uint64_t tripCount = mayBeConstantTripCount.getValue();
if (getLargestDivisorOfTripCount(*forStmt) % unrollFactor != 0) {
DenseMap<const MLValue *, MLValue *> operandMap;
MLFuncBuilder builder(forStmt->getBlock(), ++StmtBlock::iterator(forStmt));
auto *cleanupForStmt = cast<ForStmt>(builder.clone(*forStmt, operandMap));
if (forStmt->hasConstantLowerBound()) {
cleanupForStmt->setConstantLowerBound(
forStmt->getConstantLowerBound() +
(tripCount - tripCount % unrollFactor) * step);
} else {
cleanupForStmt->setLowerBoundMap(
getCleanupLoopLowerBound(forStmt->getLowerBoundMap(), tripCount,
unrollFactor, step, &builder));
}
auto *clLbMap = getCleanupLoopLowerBound(*forStmt, unrollFactor, &builder);
assert(clLbMap &&
"cleanup loop lower bound map for single result bound maps can "
"always be determined");
cleanupForStmt->setLowerBoundMap(clLbMap);
// Promote the loop body up if this has turned into a single iteration loop.
promoteIfSingleIteration(cleanupForStmt);
// The upper bound needs to be adjusted.
if (forStmt->hasConstantUpperBound()) {
forStmt->setConstantUpperBound(
forStmt->getConstantLowerBound() +
(tripCount - tripCount % unrollFactor - 1) * step);
} else {
forStmt->setUpperBoundMap(
getUnrolledLoopUpperBound(forStmt->getLowerBoundMap(), tripCount,
unrollFactor, step, &builder));
}
// Adjust upper bound.
auto *unrolledUbMap =
getUnrolledLoopUpperBound(*forStmt, unrollFactor, &builder);
assert(unrolledUbMap &&
"upper bound map can alwayys be determined for an unrolled loop "
"with single result bounds");
forStmt->setUpperBoundMap(unrolledUbMap);
}
// Scale the step of loop being unrolled by unroll factor.
int64_t step = forStmt->getStep();
forStmt->setStep(step * unrollFactor);
// Builder to insert unrolled bodies right after the last statement in the

View File

@ -41,15 +41,16 @@
//
// Note: 'if/else' blocks are not jammed. So, if there are loops inside if
// stmt's, bodies of those loops will not be jammed.
//
//===----------------------------------------------------------------------===//
#include "mlir/Transforms/Passes.h"
#include "mlir/Analysis/LoopAnalysis.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/StandardOps.h"
#include "mlir/IR/StmtVisitor.h"
#include "mlir/Transforms/LoopUtils.h"
#include "mlir/Transforms/Pass.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/Support/CommandLine.h"
@ -108,6 +109,15 @@ bool LoopUnrollAndJam::runOnForStmt(ForStmt *forStmt) {
return loopUnrollJamByFactor(forStmt, kDefaultUnrollJamFactor);
}
/// Unrolls and jams this loop by the specified factor or by its trip count (if
/// the trip count is a known constant), whichever is lower. Returns the result
/// of loopUnrollJamByFactor, or false if the loop is known to never execute.
bool mlir::loopUnrollJamUpToFactor(ForStmt *forStmt, uint64_t unrollJamFactor) {
  Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(*forStmt);
  if (mayBeConstantTripCount.hasValue()) {
    uint64_t tripCount = mayBeConstantTripCount.getValue();
    // A loop that never executes has nothing to unroll-jam. A factor of zero
    // must not be forwarded since the factor is used as a modulus when
    // deciding whether a cleanup loop is needed.
    if (tripCount == 0)
      return false;
    // Fewer iterations than the requested factor: use the trip count instead.
    if (tripCount < unrollJamFactor)
      return loopUnrollJamByFactor(forStmt, tripCount);
  }
  return loopUnrollJamByFactor(forStmt, unrollJamFactor);
}
/// Unrolls and jams this loop by the specified factor.
bool mlir::loopUnrollJamByFactor(ForStmt *forStmt, uint64_t unrollJamFactor) {
// Gathers all maximal sub-blocks of statements that do not themselves include
@ -140,19 +150,32 @@ bool mlir::loopUnrollJamByFactor(ForStmt *forStmt, uint64_t unrollJamFactor) {
if (unrollJamFactor == 1 || forStmt->getStatements().empty())
return false;
Optional<uint64_t> mayTripCount = getConstantTripCount(*forStmt).getValue();
Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(*forStmt);
if (!mayTripCount.hasValue())
if (!mayBeConstantTripCount.hasValue() &&
getLargestDivisorOfTripCount(*forStmt) % unrollJamFactor != 0)
return false;
uint64_t tripCount = mayTripCount.getValue();
int64_t lb = forStmt->getConstantLowerBound();
int64_t step = forStmt->getStep();
auto *lbMap = forStmt->getLowerBoundMap();
auto *ubMap = forStmt->getUpperBoundMap();
// If the trip count is lower than the unroll jam factor, no unrolled body.
// Loops with max/min expressions won't be unrolled here (the output can't be
// expressed as an MLFunction in the general case). However, the right way to
// do such unrolling for an MLFunction would be to specialize the loop for the
// 'hotspot' case and unroll that hotspot.
if (lbMap->getNumResults() != 1 || ubMap->getNumResults() != 1)
return false;
// Same operand list for lower and upper bound for now.
// TODO(bondhugula): handle bounds with different sets of operands.
if (!forStmt->matchingBoundOperandList())
return false;
// If the trip count is lower than the unroll jam factor, no unroll jam.
// TODO(bondhugula): option to specify cleanup loop unrolling.
if (tripCount < unrollJamFactor)
return true;
if (mayBeConstantTripCount.hasValue() &&
mayBeConstantTripCount.getValue() < unrollJamFactor)
return false;
// Gather all sub-blocks to jam upon the loop being unrolled.
JamBlockGatherer jbg;
@ -161,23 +184,27 @@ bool mlir::loopUnrollJamByFactor(ForStmt *forStmt, uint64_t unrollJamFactor) {
// Generate the cleanup loop if trip count isn't a multiple of
// unrollJamFactor.
if (tripCount % unrollJamFactor) {
if (mayBeConstantTripCount.hasValue() &&
mayBeConstantTripCount.getValue() % unrollJamFactor != 0) {
DenseMap<const MLValue *, MLValue *> operandMap;
// Insert the cleanup loop right after 'forStmt'.
MLFuncBuilder builder(forStmt->getBlock(),
std::next(StmtBlock::iterator(forStmt)));
auto *cleanupForStmt = cast<ForStmt>(builder.clone(*forStmt, operandMap));
cleanupForStmt->setConstantLowerBound(
lb + (tripCount - tripCount % unrollJamFactor) * step);
cleanupForStmt->setLowerBoundMap(
getCleanupLoopLowerBound(*forStmt, unrollJamFactor, &builder));
// The upper bound needs to be adjusted.
forStmt->setUpperBoundMap(
getUnrolledLoopUpperBound(*forStmt, unrollJamFactor, &builder));
// Promote the loop body up if this has turned into a single iteration loop.
promoteIfSingleIteration(cleanupForStmt);
}
MLFuncBuilder b(forStmt);
// Scale the step of loop being unroll-jammed by the unroll-jam factor.
int64_t step = forStmt->getStep();
forStmt->setStep(step * unrollJamFactor);
forStmt->setConstantUpperBound(
lb + (tripCount - tripCount % unrollJamFactor - 1) * step);
for (auto &subBlock : subBlocks) {
// Builder to insert unroll-jammed bodies. Insert right at the end of

View File

@ -19,32 +19,129 @@
//
//===----------------------------------------------------------------------===//
#include "mlir/Transforms/Passes.h"
#include "mlir/Transforms/LoopUtils.h"
#include "mlir/Analysis/LoopAnalysis.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/StandardOps.h"
#include "mlir/IR/Statements.h"
#include "mlir/IR/StmtVisitor.h"
using namespace mlir;
/// Builds the upper bound map of 'forStmt' for when it is unrolled by
/// 'unrollFactor'. The new bound is
///   lb + (tripCount - tripCount % unrollFactor - 1) * step
/// expressed over the dimensions and symbols of the lower bound map. Returns
/// nullptr when the lower bound map does not have exactly one result or when
/// the trip count can't be expressed as an affine expression.
AffineMap *mlir::getUnrolledLoopUpperBound(const ForStmt &forStmt,
                                           unsigned unrollFactor,
                                           MLFuncBuilder *builder) {
  auto *lowerMap = forStmt.getLowerBoundMap();

  // Only single result lower bound maps are handled.
  if (lowerMap->getNumResults() != 1)
    return nullptr;

  // The trip count is not always expressible as an affine expression.
  auto *countExpr = getTripCountExpr(forStmt);
  if (!countExpr)
    return nullptr;

  auto *lowerExpr = lowerMap->getResult(0);
  int64_t step = forStmt.getStep();

  // Offset from the lower bound to the last iteration covered by the unrolled
  // loop: (count - count % unrollFactor - 1) * step.
  AffineExpr *offsetExpr;
  if (auto *constCount = dyn_cast<AffineConstantExpr>(countExpr)) {
    // Constant trip count: fold the offset into a single constant.
    uint64_t tripCount = static_cast<uint64_t>(constCount->getValue());
    offsetExpr = builder->getConstantExpr(
        (tripCount - tripCount % unrollFactor - 1) * step);
  } else {
    // Non-constant trip count: build the offset symbolically.
    auto *remainderExpr = builder->getModExpr(countExpr, unrollFactor);
    auto *roundedCountExpr = builder->getSubExpr(countExpr, remainderExpr);
    auto *lastIterExpr = builder->getSubExpr(roundedCountExpr, 1);
    offsetExpr = builder->getMulExpr(lastIterExpr, step);
  }

  AffineExpr *newUbExpr = builder->getAddExpr(lowerExpr, offsetExpr);
  return builder->getAffineMap(lowerMap->getNumDims(),
                               lowerMap->getNumSymbols(), {newUbExpr}, {});
}
/// Builds the lower bound map of the cleanup loop that follows 'forStmt' when
/// the latter is unrolled by 'unrollFactor'. The new bound is
///   lb + (tripCount - tripCount % unrollFactor) * step
/// expressed over the dimensions and symbols of the lower bound map. Returns
/// nullptr when the lower bound map does not have exactly one result or when
/// the trip count can't be expressed as an affine expression.
AffineMap *mlir::getCleanupLoopLowerBound(const ForStmt &forStmt,
                                          unsigned unrollFactor,
                                          MLFuncBuilder *builder) {
  auto *lowerMap = forStmt.getLowerBoundMap();

  // Only single result lower bound maps are handled.
  if (lowerMap->getNumResults() != 1)
    return nullptr;

  // The trip count is not always expressible as an affine expression.
  auto *countExpr = getTripCountExpr(forStmt);
  if (!countExpr)
    return nullptr;

  auto *lowerExpr = lowerMap->getResult(0);
  int64_t step = forStmt.getStep();

  // Offset from the lower bound to the first iteration left over by the
  // unrolled loop: (count - count % unrollFactor) * step.
  AffineExpr *offsetExpr;
  if (auto *constCount = dyn_cast<AffineConstantExpr>(countExpr)) {
    // Constant trip count: fold the offset into a single constant.
    uint64_t tripCount = static_cast<uint64_t>(constCount->getValue());
    offsetExpr = builder->getConstantExpr(
        (tripCount - tripCount % unrollFactor) * step);
  } else {
    // Non-constant trip count: build the offset symbolically.
    auto *remainderExpr = builder->getModExpr(countExpr, unrollFactor);
    auto *roundedCountExpr = builder->getSubExpr(countExpr, remainderExpr);
    offsetExpr = builder->getMulExpr(roundedCountExpr, step);
  }

  AffineExpr *newLbExpr = builder->getAddExpr(lowerExpr, offsetExpr);
  return builder->getAffineMap(lowerMap->getNumDims(),
                               lowerMap->getNumSymbols(), {newLbExpr}, {});
}
/// Promotes the loop body of a forStmt to its containing block if the forStmt
/// was known to have a single iteration. Returns false otherwise.
// TODO(bondhugula): extend this for arbitrary affine bounds.
bool mlir::promoteIfSingleIteration(ForStmt *forStmt) {
Optional<uint64_t> tripCount = getConstantTripCount(*forStmt);
if (!tripCount.hasValue() || !forStmt->hasConstantLowerBound())
if (!tripCount.hasValue() || tripCount.getValue() != 1)
return false;
if (tripCount.getValue() != 1)
// TODO(mlir-team): there is no builder for a max.
if (forStmt->getLowerBoundMap()->getNumResults() != 1)
return false;
// Replaces all IV uses to its single iteration value.
auto *mlFunc = forStmt->findFunction();
MLFuncBuilder topBuilder(&mlFunc->front());
auto constOp = topBuilder.create<ConstantAffineIntOp>(
forStmt->getLoc(), forStmt->getConstantLowerBound());
forStmt->replaceAllUsesWith(constOp->getResult());
// Move the statements to the containing block.
if (!forStmt->use_empty()) {
if (forStmt->hasConstantLowerBound()) {
auto *mlFunc = forStmt->findFunction();
MLFuncBuilder topBuilder(&mlFunc->front());
auto constOp = topBuilder.create<ConstantAffineIntOp>(
forStmt->getLoc(), forStmt->getConstantLowerBound());
forStmt->replaceAllUsesWith(constOp->getResult());
} else {
const AffineBound lb = forStmt->getLowerBound();
SmallVector<SSAValue *, 4> lbOperands(lb.operand_begin(),
lb.operand_end());
MLFuncBuilder builder(forStmt->getBlock(), StmtBlock::iterator(forStmt));
auto affineApplyOp = builder.create<AffineApplyOp>(
forStmt->getLoc(), lb.getMap(), lbOperands);
forStmt->replaceAllUsesWith(affineApplyOp->getResult(0));
}
}
// Move the loop body statements to the loop's containing block.
auto *block = forStmt->getBlock();
block->getStatements().splice(StmtBlock::iterator(forStmt),
forStmt->getStatements());

View File

@ -1,6 +1,8 @@
// RUN: mlir-opt %s -o - -loop-unroll-jam -unroll-jam-factor=2 | FileCheck %s
// CHECK: #map0 = (d0) -> (d0 + 1)
// This should be matched to M1, but M1 is defined later.
// CHECK: {{#map[0-9]+}} = ()[s0] -> (s0 + 8)
// CHECK-LABEL: mlfunc @unroll_jam_imperfect_nest() {
mlfunc @unroll_jam_imperfect_nest() {
@ -34,3 +36,55 @@ mlfunc @unroll_jam_imperfect_nest() {
// CHECK-NEXT: %14 = "addi32"(%c100, %c100) : (affineint, affineint) -> i32
return
}
// UNROLL-BY-4-LABEL: mlfunc @loop_nest_unknown_count_1(%arg0 : affineint) {
mlfunc @loop_nest_unknown_count_1(%N : affineint) {
// UNROLL-BY-4-NEXT: for %i0 = 1 to #map{{[0-9]+}}()[%arg0] step 4 {
// UNROLL-BY-4-NEXT: for %i1 = 1 to 100 {
// UNROLL-BY-4-NEXT: %0 = "foo"() : () -> i32
// UNROLL-BY-4-NEXT: %1 = "foo"() : () -> i32
// UNROLL-BY-4-NEXT: %2 = "foo"() : () -> i32
// UNROLL-BY-4-NEXT: %3 = "foo"() : () -> i32
// UNROLL-BY-4-NEXT: }
// UNROLL-BY-4-NEXT: }
// A cleanup loop should be generated here.
// UNROLL-BY-4-NEXT: for %i2 = #map{{[0-9]+}}()[%arg0] to %arg0 {
// UNROLL-BY-4-NEXT: for %i3 = 1 to 100 {
// UNROLL-BY-4-NEXT: %4 = "foo"() : () -> i32
// UNROLL-BY-4_NEXT: }
// UNROLL-BY-4_NEXT: }
// Specify the lower bound in a form so that both lb and ub operands match.
for %i = ()[s0] -> (1)()[%N] to %N {
for %j = 1 to 100 {
%x = "foo"() : () -> i32
}
}
return
}
// UNROLL-BY-4-LABEL: mlfunc @loop_nest_unknown_count_2(%arg0 : affineint) {
mlfunc @loop_nest_unknown_count_2(%arg : affineint) {
// UNROLL-BY-4-NEXT: for %i0 = %arg0 to #map{{[0-9]+}}()[%arg0] step 4 {
// UNROLL-BY-4-NEXT: for %i1 = 1 to 100 {
// UNROLL-BY-4-NEXT: %0 = "foo"(%i0) : (affineint) -> i32
// UNROLL-BY-4-NEXT: %1 = affine_apply #map{{[0-9]+}}(%i0)
// UNROLL-BY-4-NEXT: %2 = "foo"(%1) : (affineint) -> i32
// UNROLL-BY-4-NEXT: %3 = affine_apply #map{{[0-9]+}}(%i0)
// UNROLL-BY-4-NEXT: %4 = "foo"(%3) : (affineint) -> i32
// UNROLL-BY-4-NEXT: %5 = affine_apply #map{{[0-9]+}}(%i0)
// UNROLL-BY-4-NEXT: %6 = "foo"(%5) : (affineint) -> i32
// UNROLL-BY-4-NEXT: }
// UNROLL-BY-4-NEXT: }
// The cleanup loop is a single iteration one and is promoted.
// UNROLL-BY-4-NEXT: %7 = affine_apply [[M1:#map{{[0-9]+}}]]()[%arg0]
// UNROLL-BY-4-NEXT: for %i3 = 1 to 100 {
// UNROLL-BY-4-NEXT: %8 = "foo"() : () -> i32
// UNROLL-BY-4_NEXT: }
// Specify the lower bound in a form so that both lb and ub operands match.
for %i = ()[s0] -> (s0) ()[%arg] to ()[s0] -> (s0+8) ()[%arg] {
for %j = 1 to 100 {
%x = "foo"(%i) : (affineint) -> i32
}
}
return
}

View File

@ -462,7 +462,7 @@ mlfunc @loop_nest_operand2() {
}
// Difference between loop bounds is constant, but not a multiple of unroll
// factor. A cleanup loop is generated.
// factor. The cleanup loop happens to be a single iteration one and is promoted.
// UNROLL-BY-4-LABEL: mlfunc @loop_nest_operand3() {
mlfunc @loop_nest_operand3() {
// UNROLL-BY-4: for %i0 = 1 to 100 step 2 {
@ -473,30 +473,30 @@ mlfunc @loop_nest_operand3() {
// UNROLL-BY-4-NEXT: %2 = "foo"() : () -> i32
// UNROLL-BY-4-NEXT: %3 = "foo"() : () -> i32
// UNROLL-BY-4-NEXT: }
// UNROLL-BY-4-NEXT: for %i2 = #map{{[0-9]+}}(%i0) to #map{{[0-9]+}}(%i0) {
// UNROLL-BY-4-NEXT: %4 = "foo"() : () -> i32
// UNROLL-BY-4-NEXT: }
for %j = (d0) -> (d0) (%i) to (d0) -> (d0 + 4) (%i) {
for %j = (d0) -> (d0) (%i) to (d0) -> (d0 + 8) (%i) {
%x = "foo"() : () -> i32
}
} // UNROLL-BY-4: }
return
}
// Will not be unrolled for now. TODO(bondhugula): handle this.
// xUNROLL-BY-4-LABEL: mlfunc @loop_nest_operand4(%arg0 : affineint) {
// UNROLL-BY-4-LABEL: mlfunc @loop_nest_operand4(%arg0 : affineint) {
mlfunc @loop_nest_operand4(%N : affineint) {
// UNROLL-BY-4: for %i0 = 1 to 100 step 2 {
for %i = 1 to 100 step 2 {
// UNROLL-BY-4: for %i1 = 0 to %arg0 {
// xUNROLL-BY-4: for %i1 = 0 to #map{{[0-9]+}}(%N) step 4 {
// xUNROLL-BY-4: %0 = "foo"() : () -> i32
// xUNROLL-BY-4-NEXT: %1 = "foo"() : () -> i32
// xUNROLL-BY-4-NEXT: %2 = "foo"() : () -> i32
// xUNROLL-BY-4-NEXT: %3 = "foo"() : () -> i32
// xUNROLL-BY-4-NEXT: }
// a cleanup loop should be generated here.
for %j = (d0) -> (0) (%N) to %N {
// UNROLL-BY-4: for %i0 = 1 to 100 {
for %i = 1 to 100 {
// UNROLL-BY-4: for %i1 = 1 to #map{{[0-9]+}}()[%arg0] step 4 {
// UNROLL-BY-4: %0 = "foo"() : () -> i32
// UNROLL-BY-4-NEXT: %1 = "foo"() : () -> i32
// UNROLL-BY-4-NEXT: %2 = "foo"() : () -> i32
// UNROLL-BY-4-NEXT: %3 = "foo"() : () -> i32
// UNROLL-BY-4-NEXT: }
// A cleanup loop will be generated here.
// UNROLL-BY-4-NEXT: for %i2 = #map{{[0-9]+}}()[%arg0] to %arg0 {
// UNROLL-BY-4-NEXT: %4 = "foo"() : () -> i32
// UNROLL-BY-4_NEXT: }
// Specify the lower bound so that both lb and ub operands match.
for %j = ()[s0] -> (1)()[%N] to %N {
%x = "foo"() : () -> i32
}
}