forked from OSchip/llvm-project
Introduce loop body skewing / loop pipelining / loop shifting utility.
- loopBodySkew shifts statements of a loop body by stmt-wise delays, and is typically meant to be used to: - allow overlap of non-blocking start/wait until completion operations with other computation - allow shifting of statements (for better register reuse/locality/parallelism) - software pipelining (when applied to the innermost loop) - an additional argument specifies whether to unroll the prologue and epilogue. - add method to check SSA dominance preservation. - add a fake loop pipeline pass to test this utility. Sample input/output are below. While on this, fix/add following: - fix minor bug in getAddMulPureAffineExpr - add additional builder methods for common affine map cases - fix const_operand_iterator's for ForStmt, etc. When there is no such thing as 'const MLValue', the iterator shouldn't be returning const MLValue's. Returning MLValue is const correct. Sample input/output examples: 1) Simplest case: shift second statement by one. Input: for %i = 0 to 7 { %y = "foo"(%i) : (affineint) -> affineint %x = "bar"(%i) : (affineint) -> affineint } Output: #map0 = (d0) -> (d0 - 1) mlfunc @loop_nest_simple1() { %c8 = constant 8 : affineint %c0 = constant 0 : affineint %0 = "foo"(%c0) : (affineint) -> affineint for %i0 = 1 to 7 { %1 = "foo"(%i0) : (affineint) -> affineint %2 = affine_apply #map0(%i0) %3 = "bar"(%2) : (affineint) -> affineint } %4 = affine_apply #map0(%c8) %5 = "bar"(%4) : (affineint) -> affineint return } 2) DMA overlap: shift dma.wait and compute by one. Input for %i = 0 to 7 { %pingpong = affine_apply (d0) -> (d0 mod 2) (%i) "dma.enqueue"(%pingpong) : (affineint) -> affineint %pongping = affine_apply (d0) -> (d0 mod 2) (%i) "dma.wait"(%pongping) : (affineint) -> affineint "compute1"(%pongping) : (affineint) -> affineint } Output #map0 = (d0) -> (d0 mod 2) #map1 = (d0) -> (d0 - 1) #map2 = ()[s0] -> (s0 + 7) mlfunc @loop_nest_dma() { %c8 = constant 8 : affineint %c0 = constant 0 : affineint %0 = affine_apply #map0(%c0) %1 = "dma.enqueue"(%0) : (affineint) -> affineint for %i0 = 1 to 7 { %2 = affine_apply #map0(%i0) %3 = "dma.enqueue"(%2) : (affineint) -> affineint %4 = affine_apply #map1(%i0) %5 = affine_apply #map0(%4) %6 = "dma.wait"(%5) : (affineint) -> affineint %7 = "compute1"(%5) : (affineint) -> affineint } %8 = affine_apply #map1(%c8) %9 = affine_apply #map0(%8) %10 = "dma.wait"(%9) : (affineint) -> affineint %11 = "compute1"(%9) : (affineint) -> affineint return } 3) With arbitrary affine bound maps: Shift last two statements by two. Input: for %i = %N to ()[s0] -> (s0 + 7)()[%N] { %y = "foo"(%i) : (affineint) -> affineint %x = "bar"(%i) : (affineint) -> affineint %z = "foo_bar"(%i) : (affineint) -> (affineint) "bar_foo"(%i) : (affineint) -> (affineint) } Output #map0 = ()[s0] -> (s0 + 1) #map1 = ()[s0] -> (s0 + 2) #map2 = ()[s0] -> (s0 + 7) #map3 = (d0) -> (d0 - 2) #map4 = ()[s0] -> (s0 + 8) #map5 = ()[s0] -> (s0 + 9) for %i0 = %arg0 to #map0()[%arg0] { %0 = "foo"(%i0) : (affineint) -> affineint %1 = "bar"(%i0) : (affineint) -> affineint } for %i1 = #map1()[%arg0] to #map2()[%arg0] { %2 = "foo"(%i1) : (affineint) -> affineint %3 = "bar"(%i1) : (affineint) -> affineint %4 = affine_apply #map3(%i1) %5 = "foo_bar"(%4) : (affineint) -> affineint %6 = "bar_foo"(%4) : (affineint) -> affineint } for %i2 = #map4()[%arg0] to #map5()[%arg0] { %7 = affine_apply #map3(%i2) %8 = "foo_bar"(%7) : (affineint) -> affineint %9 = "bar_foo"(%7) : (affineint) -> affineint } 4) Shift one by zero, second by one, third by two for %i = 0 to 7 { %y = "foo"(%i) : (affineint) -> affineint %x = "bar"(%i) : (affineint) -> affineint %z = "foobar"(%i) : (affineint) -> affineint } #map0 = (d0) -> (d0 - 1) #map1 = (d0) -> (d0 - 2) #map2 = ()[s0] -> (s0 + 7) %c9 = constant 9 : affineint %c8 = constant 8 : affineint %c1 = constant 1 : affineint %c0 = constant 0 : affineint %0 = "foo"(%c0) : (affineint) -> affineint %1 = "foo"(%c1) : (affineint) -> affineint %2 = affine_apply #map0(%c1) %3 = "bar"(%2) : (affineint) -> affineint for %i0 = 2 to 7 { %4 = "foo"(%i0) : (affineint) -> affineint %5 = affine_apply #map0(%i0) %6 = "bar"(%5) : (affineint) -> affineint %7 = affine_apply #map1(%i0) %8 = "foobar"(%7) : (affineint) -> affineint } %9 = affine_apply #map0(%c8) %10 = "bar"(%9) : (affineint) -> affineint %11 = affine_apply #map1(%c8) %12 = "foobar"(%11) : (affineint) -> affineint %13 = affine_apply #map1(%c9) %14 = "foobar"(%13) : (affineint) -> affineint 5) SSA dominance violated; no shifting if a shift is specified for the second statement. for %i = 0 to 7 { %x = "foo"(%i) : (affineint) -> affineint "bar"(%x) : (affineint) -> affineint } PiperOrigin-RevId: 214975731
This commit is contained in:
parent
ec35e51f6d
commit
041817a45e
|
@ -140,6 +140,16 @@ public:
|
|||
// One symbol identity map: ()[s] -> (s).
|
||||
AffineMap *getSymbolIdentityMap();
|
||||
|
||||
/// Returns a map that shifts its (single) input dimension by 'shift'.
|
||||
/// (d0) -> (d0 + shift)
|
||||
AffineMap *getSingleDimShiftAffineMap(int64_t shift);
|
||||
|
||||
/// Returns an affine map that is a translation (shift) of all result
|
||||
/// expressions in 'map' by 'shift'.
|
||||
/// Eg: input: (d0, d1)[s0] -> (d0, d1 + s0), shift = 2
|
||||
/// returns: (d0, d1)[s0] -> (d0 + 2, d1 + s0 + 2)
|
||||
AffineMap *getShiftedAffineMap(AffineMap *map, int64_t shift);
|
||||
|
||||
// Integer set.
|
||||
IntegerSet *getIntegerSet(unsigned dimCount, unsigned symbolCount,
|
||||
ArrayRef<AffineExpr *> constraints,
|
||||
|
|
|
@ -311,9 +311,15 @@ public:
|
|||
const StmtOperand &getStmtOperand(unsigned idx) const {
|
||||
return getStmtOperands()[idx];
|
||||
}
|
||||
|
||||
// TODO: provide iterators for the lower and upper bound operands
|
||||
// if the current access via getLowerBound(), getUpperBound() is too slow.
|
||||
|
||||
/// Returns operands for the lower bound map.
|
||||
operand_range getLowerBoundOperands();
|
||||
/// Returns operands for the upper bound map.
|
||||
operand_range getUpperBoundOperands();
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
// Other
|
||||
//===--------------------------------------------------------------------===//
|
||||
|
|
|
@ -33,6 +33,15 @@ class ForStmt;
|
|||
class MLFunction;
|
||||
class MLFuncBuilder;
|
||||
|
||||
// Values that can be used to signal success/failure. This can be implicitly
|
||||
// converted to/from boolean values, with false representing success and true
|
||||
// failure.
|
||||
struct LLVM_NODISCARD UtilResult {
|
||||
enum ResultEnum { Success, Failure } value;
|
||||
UtilResult(ResultEnum v) : value(v) {}
|
||||
operator bool() const { return value == Failure; }
|
||||
};
|
||||
|
||||
/// Unrolls this for statement completely if the trip count is known to be
|
||||
/// constant. Returns false otherwise.
|
||||
bool loopUnrollFull(ForStmt *forStmt);
|
||||
|
@ -72,6 +81,16 @@ AffineMap *getUnrolledLoopUpperBound(const ForStmt &forStmt,
|
|||
unsigned unrollFactor,
|
||||
MLFuncBuilder *builder);
|
||||
|
||||
/// Skew the statements in the body of a 'for' statement with the specified
|
||||
/// statement-wise delays.
|
||||
UtilResult stmtBodySkew(ForStmt *forStmt, ArrayRef<uint64_t> delays,
|
||||
bool unrollPrologueEpilogue = false);
|
||||
|
||||
/// Checks if SSA dominance would be violated if a for stmt's child statements
|
||||
/// are shifted by the specified delays.
|
||||
bool checkDominancePreservationOnShift(const ForStmt &forStmt,
|
||||
ArrayRef<uint64_t> delays);
|
||||
|
||||
} // end namespace mlir
|
||||
|
||||
#endif // MLIR_TRANSFORMS_LOOP_UTILS_H
|
||||
|
|
|
@ -47,6 +47,10 @@ MLFunctionPass *createLoopUnrollAndJamPass(int unrollJamFactor = -1);
|
|||
/// Creates an affine expression simplification pass.
|
||||
FunctionPass *createSimplifyAffineExprPass();
|
||||
|
||||
/// Creates a pass to pipeline explicit movement of data across levels of the
|
||||
/// memory hierarchy.
|
||||
MLFunctionPass *createPipelineDataTransferPass();
|
||||
|
||||
/// Replaces all ML functions in the module with equivalent CFG functions.
|
||||
/// Function references are appropriately patched to refer to the newly
|
||||
/// generated CFG functions.
|
||||
|
|
|
@ -251,7 +251,7 @@ AffineExpr *Builder::getAddMulPureAffineExpr(unsigned numDims,
|
|||
expr = AffineBinaryOpExpr::getAdd(expr, term, context);
|
||||
}
|
||||
// Constant term.
|
||||
unsigned constTerm = coeffs[coeffs.size() - 1];
|
||||
int64_t constTerm = coeffs[coeffs.size() - 1];
|
||||
if (constTerm != 0)
|
||||
expr = AffineBinaryOpExpr::getAdd(expr, constTerm, context);
|
||||
return expr;
|
||||
|
@ -278,6 +278,22 @@ AffineMap *Builder::getSymbolIdentityMap() {
|
|||
context);
|
||||
}
|
||||
|
||||
AffineMap *Builder::getSingleDimShiftAffineMap(int64_t shift) {
|
||||
// expr = 1*d0 + shift.
|
||||
auto *expr = getAddMulPureAffineExpr(1, 0, {1, shift});
|
||||
return AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, expr, {}, context);
|
||||
}
|
||||
|
||||
AffineMap *Builder::getShiftedAffineMap(AffineMap *map, int64_t shift) {
|
||||
SmallVector<AffineExpr *, 4> shiftedResults;
|
||||
shiftedResults.reserve(map->getNumResults());
|
||||
for (auto *resultExpr : map->getResults()) {
|
||||
shiftedResults.push_back(getAddExpr(resultExpr, shift));
|
||||
}
|
||||
return AffineMap::get(map->getNumDims(), map->getNumSymbols(), shiftedResults,
|
||||
map->getRangeSizes(), context);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// CFG function elements.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -350,6 +350,15 @@ void ForStmt::setConstantUpperBound(int64_t value) {
|
|||
setUpperBound({}, AffineMap::getConstantMap(value, getContext()));
|
||||
}
|
||||
|
||||
ForStmt::operand_range ForStmt::getLowerBoundOperands() {
|
||||
return {operand_begin(),
|
||||
operand_begin() + getLowerBoundMap()->getNumInputs()};
|
||||
}
|
||||
|
||||
ForStmt::operand_range ForStmt::getUpperBoundOperands() {
|
||||
return {operand_begin() + getLowerBoundMap()->getNumInputs(), operand_end()};
|
||||
}
|
||||
|
||||
bool ForStmt::matchingBoundOperandList() const {
|
||||
if (lbMap->getNumDims() != ubMap->getNumDims() ||
|
||||
lbMap->getNumSymbols() != ubMap->getNumSymbols())
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
//===- LoopUtils.cpp - Misc loop utilities for simplification //-----------===//
|
||||
//===- LoopUtils.cpp ---- Misc utilities for loop transformation ----------===//
|
||||
//
|
||||
// Copyright 2019 The MLIR Authors.
|
||||
//
|
||||
|
@ -15,7 +15,7 @@
|
|||
// limitations under the License.
|
||||
// =============================================================================
|
||||
//
|
||||
// This file implements miscellaneous loop simplification routines.
|
||||
// This file implements miscellaneous loop transformation routines.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
|
@ -28,6 +28,7 @@
|
|||
#include "mlir/IR/StandardOps.h"
|
||||
#include "mlir/IR/Statements.h"
|
||||
#include "mlir/IR/StmtVisitor.h"
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
|
@ -161,3 +162,228 @@ void mlir::promoteSingleIterationLoops(MLFunction *f) {
|
|||
LoopBodyPromoter fsw;
|
||||
fsw.walkPostOrder(f);
|
||||
}
|
||||
|
||||
/// Generates a for 'stmt' with the specified lower and upper bounds while
|
||||
/// generating the right IV remappings for the delayed statements. The
|
||||
/// statement blocks that go into the loop are specified in stmtGroupQueue
|
||||
/// starting from the specified offset, and in that order; the first element of
|
||||
/// the pair specifies the delay applied to that group of statements. Returns
|
||||
/// nullptr if the generated loop simplifies to a single iteration one.
|
||||
static ForStmt *
|
||||
generateLoop(AffineMap *lb, AffineMap *ub,
|
||||
const std::vector<std::pair<uint64_t, ArrayRef<Statement *>>>
|
||||
&stmtGroupQueue,
|
||||
unsigned offset, ForStmt *srcForStmt, MLFuncBuilder *b) {
|
||||
SmallVector<MLValue *, 4> lbOperands(srcForStmt->getLowerBoundOperands());
|
||||
SmallVector<MLValue *, 4> ubOperands(srcForStmt->getUpperBoundOperands());
|
||||
|
||||
auto *loopChunk =
|
||||
b->createFor(srcForStmt->getLoc(), lbOperands, lb, ubOperands, ub);
|
||||
OperationStmt::OperandMapTy operandMap;
|
||||
|
||||
for (auto it = stmtGroupQueue.begin() + offset, e = stmtGroupQueue.end();
|
||||
it != e; ++it) {
|
||||
auto elt = *it;
|
||||
// All 'same delay' statements get added with the operands being remapped
|
||||
// (to results of cloned statements).
|
||||
// Generate the remapping if the delay is not zero: oldIV = newIV - delay.
|
||||
// TODO(bondhugula): check if srcForStmt is actually used in elt.second
|
||||
// instead of just checking if it's used at all.
|
||||
if (!srcForStmt->use_empty() && elt.first != 0) {
|
||||
auto b = MLFuncBuilder::getForStmtBodyBuilder(loopChunk);
|
||||
auto *oldIV =
|
||||
b.create<AffineApplyOp>(
|
||||
srcForStmt->getLoc(),
|
||||
b.getSingleDimShiftAffineMap(-static_cast<int64_t>(elt.first)),
|
||||
loopChunk)
|
||||
->getResult(0);
|
||||
operandMap[srcForStmt] = cast<MLValue>(oldIV);
|
||||
} else {
|
||||
operandMap[srcForStmt] = static_cast<MLValue *>(loopChunk);
|
||||
}
|
||||
for (auto *stmt : elt.second) {
|
||||
loopChunk->push_back(stmt->clone(operandMap, b->getContext()));
|
||||
}
|
||||
}
|
||||
if (promoteIfSingleIteration(loopChunk))
|
||||
return nullptr;
|
||||
return loopChunk;
|
||||
}
|
||||
|
||||
// Returns delay of that child statement of 'forStmt' which either has 'operand'
|
||||
// as one of its operands or has a descendant statement with operand 'operand'.
|
||||
// This is a naive implementation. If performance becomes an issue, a map can
|
||||
// be used to store 'delays' - to look up the delay for a statement in constant
|
||||
// time.
|
||||
static uint64_t getContainingStmtDelay(const StmtOperand &operand,
|
||||
const ForStmt &forStmt,
|
||||
ArrayRef<uint64_t> delays) {
|
||||
// Traverse up the statement hierarchy starting from the owner of operand to
|
||||
// find the ancestor statement that resides in the block of 'forStmt'.
|
||||
const Statement *stmt = operand.getOwner();
|
||||
assert(stmt != nullptr);
|
||||
while (stmt->getParentStmt() != &forStmt) {
|
||||
stmt = stmt->getParentStmt();
|
||||
assert(stmt && "traversing parent's should reach forStmt block");
|
||||
}
|
||||
// Look up the delay of 'stmt'.
|
||||
unsigned j = 0;
|
||||
for (const auto &s : forStmt) {
|
||||
if (&s == stmt)
|
||||
break;
|
||||
j++;
|
||||
}
|
||||
assert(j < forStmt.getStatements().size() && "child stmt should be found");
|
||||
return delays[j];
|
||||
}
|
||||
|
||||
/// Checks if SSA dominance would be violated if a for stmt's body statements
|
||||
/// are shifted by the specified delays. This method checks if a 'def' and all
|
||||
/// its uses have the same delay factor.
|
||||
bool mlir::checkDominancePreservationOnShift(const ForStmt &forStmt,
|
||||
ArrayRef<uint64_t> delays) {
|
||||
assert(delays.size() == forStmt.getStatements().size());
|
||||
unsigned s = 0;
|
||||
for (const auto &stmt : forStmt) {
|
||||
// A for or if stmt does not produce any def/results (that are used
|
||||
// outside).
|
||||
if (auto *opStmt = dyn_cast<OperationStmt>(&stmt)) {
|
||||
for (unsigned i = 0, e = opStmt->getNumResults(); i < e; ++i) {
|
||||
const MLValue *result = opStmt->getResult(i);
|
||||
for (const StmtOperand &use : result->getUses()) {
|
||||
if (delays[s] != getContainingStmtDelay(use, forStmt, delays))
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
s++;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Skew the statements in the body of a 'for' statement with the specified
|
||||
/// statement-wise delays. The delays are with respect to the original execution
|
||||
/// order. A delay of zero for each statement will lead to no change.
|
||||
// The skewing of statements with respect to one another can be used for example
|
||||
// to allow overlap of asynchronous operations (such as DMA communication) with
|
||||
// computation, or just relative shifting of statements for better register
|
||||
// reuse, locality or parallelism. As such, the delays are typically expected to
|
||||
// be at most of the order of the number of statements. This method should not
|
||||
// be used as a substitute for loop distribution/fission.
|
||||
// This method uses an algorithm// in time linear in the number of statements in
|
||||
// the body of the for loop - (using the 'sweep line' paradigm). This method
|
||||
// asserts preservation of SSA dominance. A check for that as well as that for
|
||||
// memory-based depedence preservation check rests with the users of this
|
||||
// method.
|
||||
UtilResult mlir::stmtBodySkew(ForStmt *forStmt, ArrayRef<uint64_t> delays,
|
||||
bool unrollPrologueEpilogue) {
|
||||
if (forStmt->getStatements().empty())
|
||||
return UtilResult::Success;
|
||||
|
||||
// If the trip counts aren't constant, we would need versioning and
|
||||
// conditional guards (or context information to prevent such versioning). The
|
||||
// better way to pipeline for such loops is to first tile them and extract
|
||||
// constant trip count "full tiles" before applying this.
|
||||
auto mayBeConstTripCount = getConstantTripCount(*forStmt);
|
||||
if (!mayBeConstTripCount.hasValue())
|
||||
return UtilResult::Failure;
|
||||
uint64_t tripCount = mayBeConstTripCount.getValue();
|
||||
|
||||
assert(checkDominancePreservationOnShift(*forStmt, delays) &&
|
||||
"dominance preservation failed\n");
|
||||
|
||||
unsigned numChildStmts = forStmt->getStatements().size();
|
||||
|
||||
// Do a linear time (counting) sort for the delays.
|
||||
uint64_t maxDelay = 0;
|
||||
for (unsigned i = 0; i < numChildStmts; i++) {
|
||||
maxDelay = std::max(maxDelay, delays[i]);
|
||||
}
|
||||
// Such large delays are not the typical use case.
|
||||
if (maxDelay >= numChildStmts)
|
||||
return UtilResult::Failure;
|
||||
|
||||
// An array of statement groups sorted by delay amount; each group has all
|
||||
// statements with the same delay in the order in which they appear in the
|
||||
// body of the 'for' stmt.
|
||||
std::vector<std::vector<Statement *>> sortedStmtGroups(maxDelay + 1);
|
||||
unsigned pos = 0;
|
||||
for (auto &stmt : *forStmt) {
|
||||
auto delay = delays[pos++];
|
||||
sortedStmtGroups[delay].push_back(&stmt);
|
||||
}
|
||||
|
||||
// Unless the shifts have a specific pattern (which actually would be the
|
||||
// common use case), prologue and epilogue are not meaningfully defined.
|
||||
// Nevertheless, if 'unrollPrologueEpilogue' is set, we will treat the first
|
||||
// loop generated as the prologue and the last as epilogue and unroll these
|
||||
// fully.
|
||||
ForStmt *prologue = nullptr;
|
||||
ForStmt *epilogue = nullptr;
|
||||
|
||||
// Do a sweep over the sorted delays while storing open groups in a
|
||||
// vector, and generating loop portions as necessary during the sweep. A block
|
||||
// of statements is paired with its delay.
|
||||
std::vector<std::pair<uint64_t, ArrayRef<Statement *>>> stmtGroupQueue;
|
||||
|
||||
auto *origLbMap = forStmt->getLowerBoundMap();
|
||||
uint64_t lbDelay = 0;
|
||||
MLFuncBuilder b(forStmt);
|
||||
for (uint64_t d = 0, e = sortedStmtGroups.size(); d < e; ++d) {
|
||||
// If nothing is delayed by d, continue.
|
||||
if (sortedStmtGroups[d].empty())
|
||||
continue;
|
||||
if (!stmtGroupQueue.empty()) {
|
||||
assert(d >= 1 &&
|
||||
"Queue expected to be empty when the first block is found");
|
||||
// The interval for which the loop needs to be generated here is:
|
||||
// ( lbDelay, min(lbDelay + tripCount - 1, d - 1) ] and the body of the
|
||||
// loop needs to have all statements in stmtQueue in that order.
|
||||
ForStmt *res;
|
||||
if (lbDelay + tripCount - 1 < d - 1) {
|
||||
res = generateLoop(
|
||||
b.getShiftedAffineMap(origLbMap, lbDelay),
|
||||
b.getShiftedAffineMap(origLbMap, lbDelay + tripCount - 1),
|
||||
stmtGroupQueue, 0, forStmt, &b);
|
||||
// Entire loop for the queued stmt groups generated, empty it.
|
||||
stmtGroupQueue.clear();
|
||||
lbDelay += tripCount;
|
||||
} else {
|
||||
res = generateLoop(b.getShiftedAffineMap(origLbMap, lbDelay),
|
||||
b.getShiftedAffineMap(origLbMap, d - 1),
|
||||
stmtGroupQueue, 0, forStmt, &b);
|
||||
lbDelay = d;
|
||||
}
|
||||
if (!prologue && res)
|
||||
prologue = res;
|
||||
epilogue = res;
|
||||
} else {
|
||||
// Start of first interval.
|
||||
lbDelay = d;
|
||||
}
|
||||
// Augment the list of statements that get into the current open interval.
|
||||
stmtGroupQueue.push_back({d, sortedStmtGroups[d]});
|
||||
}
|
||||
|
||||
// Those statements groups left in the queue now need to be processed (FIFO)
|
||||
// and their loops completed.
|
||||
for (unsigned i = 0, e = stmtGroupQueue.size(); i < e; ++i) {
|
||||
uint64_t ubDelay = stmtGroupQueue[i].first + tripCount - 1;
|
||||
epilogue = generateLoop(b.getShiftedAffineMap(origLbMap, lbDelay),
|
||||
b.getShiftedAffineMap(origLbMap, ubDelay),
|
||||
stmtGroupQueue, i, forStmt, &b);
|
||||
lbDelay = ubDelay + 1;
|
||||
if (!prologue)
|
||||
prologue = epilogue;
|
||||
}
|
||||
|
||||
// Erase the original for stmt.
|
||||
forStmt->eraseFromBlock();
|
||||
|
||||
if (unrollPrologueEpilogue && prologue)
|
||||
loopUnrollFull(prologue);
|
||||
if (unrollPrologueEpilogue && !epilogue && epilogue != prologue)
|
||||
loopUnrollFull(epilogue);
|
||||
|
||||
return UtilResult::Success;
|
||||
}
|
||||
|
|
|
@ -0,0 +1,72 @@
|
|||
//===- PipelineDataTransfer.cpp --- Pass for pipelining data movement ---*-===//
|
||||
//
|
||||
// Copyright 2019 The MLIR Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
// =============================================================================
|
||||
//
|
||||
// This file implements a pass to pipeline data transfers.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Transforms/Passes.h"
|
||||
|
||||
#include "mlir/IR/MLFunction.h"
|
||||
#include "mlir/IR/Statements.h"
|
||||
#include "mlir/Transforms/LoopUtils.h"
|
||||
#include "mlir/Transforms/Pass.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace {
|
||||
|
||||
struct PipelineDataTransfer : public MLFunctionPass {
|
||||
explicit PipelineDataTransfer() {}
|
||||
PassResult runOnMLFunction(MLFunction *f) override;
|
||||
};
|
||||
|
||||
} // end anonymous namespace
|
||||
|
||||
/// Creates a pass to pipeline explicit movement of data across levels of the
|
||||
/// memory hierarchy.
|
||||
MLFunctionPass *mlir::createPipelineDataTransferPass() {
|
||||
return new PipelineDataTransfer();
|
||||
}
|
||||
|
||||
// For testing purposes, this just runs on the first statement of the MLFunction
|
||||
// if that statement is a for stmt, and shifts the second half of its body by
|
||||
// one.
|
||||
PassResult PipelineDataTransfer::runOnMLFunction(MLFunction *f) {
|
||||
if (f->empty())
|
||||
return PassResult::Success;
|
||||
auto *forStmt = dyn_cast<ForStmt>(&f->front());
|
||||
if (!forStmt)
|
||||
return PassResult::Failure;
|
||||
|
||||
unsigned numStmts = forStmt->getStatements().size();
|
||||
if (numStmts == 0)
|
||||
return PassResult::Success;
|
||||
|
||||
std::vector<uint64_t> delays(numStmts);
|
||||
for (unsigned i = 0; i < numStmts; i++)
|
||||
delays[i] = (i < numStmts / 2) ? 0 : 1;
|
||||
|
||||
if (!checkDominancePreservationOnShift(*forStmt, delays))
|
||||
// Violates SSA dominance.
|
||||
return PassResult::Failure;
|
||||
|
||||
if (stmtBodySkew(forStmt, delays))
|
||||
return PassResult::Failure;
|
||||
|
||||
return PassResult::Success;
|
||||
}
|
|
@ -0,0 +1,79 @@
|
|||
// RUN: mlir-opt %s -pipeline-data-transfer | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: mlfunc @loop_nest_simple() {
|
||||
// CHECK: %c8 = constant 8 : affineint
|
||||
// CHECK-NEXT: %c0 = constant 0 : affineint
|
||||
// CHECK-NEXT: %0 = "foo"(%c0) : (affineint) -> affineint
|
||||
// CHECK-NEXT: for %i0 = 1 to 7 {
|
||||
// CHECK-NEXT: %1 = "foo"(%i0) : (affineint) -> affineint
|
||||
// CHECK-NEXT: %2 = affine_apply #map0(%i0)
|
||||
// CHECK-NEXT: %3 = "bar"(%2) : (affineint) -> affineint
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: %4 = affine_apply #map0(%c8)
|
||||
// CHECK-NEXT: %5 = "bar"(%4) : (affineint) -> affineint
|
||||
// CHECK-NEXT: return
|
||||
mlfunc @loop_nest_simple() {
|
||||
for %i = 0 to 7 {
|
||||
%y = "foo"(%i) : (affineint) -> affineint
|
||||
%x = "bar"(%i) : (affineint) -> affineint
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: mlfunc @loop_nest_dma() {
|
||||
// CHECK: %c8 = constant 8 : affineint
|
||||
// CHECK-NEXT: %c0 = constant 0 : affineint
|
||||
// CHECK-NEXT: %0 = affine_apply #map1(%c0)
|
||||
// CHECK-NEXT: %1 = "dma.enqueue"(%0) : (affineint) -> affineint
|
||||
// CHECK-NEXT: %2 = "dma.enqueue"(%0) : (affineint) -> affineint
|
||||
// CHECK-NEXT: for %i0 = 1 to 7 {
|
||||
// CHECK-NEXT: %3 = affine_apply #map1(%i0)
|
||||
// CHECK-NEXT: %4 = "dma.enqueue"(%3) : (affineint) -> affineint
|
||||
// CHECK-NEXT: %5 = "dma.enqueue"(%3) : (affineint) -> affineint
|
||||
// CHECK-NEXT: %6 = affine_apply #map0(%i0)
|
||||
// CHECK-NEXT: %7 = affine_apply #map1(%6)
|
||||
// CHECK-NEXT: %8 = "dma.wait"(%7) : (affineint) -> affineint
|
||||
// CHECK-NEXT: %9 = "compute1"(%7) : (affineint) -> affineint
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: %10 = affine_apply #map0(%c8)
|
||||
// CHECK-NEXT: %11 = affine_apply #map1(%10)
|
||||
// CHECK-NEXT: %12 = "dma.wait"(%11) : (affineint) -> affineint
|
||||
// CHECK-NEXT: %13 = "compute1"(%11) : (affineint) -> affineint
|
||||
// CHECK-NEXT: return
|
||||
mlfunc @loop_nest_dma() {
|
||||
for %i = 0 to 7 {
|
||||
%pingpong = affine_apply (d0) -> (d0 mod 2) (%i)
|
||||
"dma.enqueue"(%pingpong) : (affineint) -> affineint
|
||||
"dma.enqueue"(%pingpong) : (affineint) -> affineint
|
||||
%pongping = affine_apply (d0) -> (d0 mod 2) (%i)
|
||||
"dma.wait"(%pongping) : (affineint) -> affineint
|
||||
"compute1"(%pongping) : (affineint) -> affineint
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: mlfunc @loop_nest_bound_map(%arg0 : affineint) {
|
||||
// CHECK: %0 = affine_apply #map2()[%arg0]
|
||||
// CHECK-NEXT: %1 = "foo"(%0) : (affineint) -> affineint
|
||||
// CHECK-NEXT: %2 = "bar"(%0) : (affineint) -> affineint
|
||||
// CHECK-NEXT: for %i0 = #map3()[%arg0] to #map4()[%arg0] {
|
||||
// CHECK-NEXT: %3 = "foo"(%i0) : (affineint) -> affineint
|
||||
// CHECK-NEXT: %4 = "bar"(%i0) : (affineint) -> affineint
|
||||
// CHECK-NEXT: %5 = affine_apply #map0(%i0)
|
||||
// CHECK-NEXT: %6 = "foo_bar"(%5) : (affineint) -> affineint
|
||||
// CHECK-NEXT: %7 = "bar_foo"(%5) : (affineint) -> affineint
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: %8 = affine_apply #map5()[%arg0]
|
||||
// CHECK-NEXT: %9 = affine_apply #map0(%8)
|
||||
// CHECK-NEXT: %10 = "foo_bar"(%9) : (affineint) -> affineint
|
||||
// CHECK-NEXT: %11 = "bar_foo"(%9) : (affineint) -> affineint
|
||||
// CHECK-NEXT: return
|
||||
mlfunc @loop_nest_bound_map(%N : affineint) {
|
||||
for %i = %N to ()[s0] -> (s0 + 7)()[%N] {
|
||||
"foo"(%i) : (affineint) -> affineint
|
||||
"bar"(%i) : (affineint) -> affineint
|
||||
"foo_bar"(%i) : (affineint) -> (affineint)
|
||||
"bar_foo"(%i) : (affineint) -> (affineint)
|
||||
}
|
||||
return
|
||||
}
|
|
@ -70,6 +70,7 @@ enum Passes {
|
|||
ConvertToCFG,
|
||||
LoopUnroll,
|
||||
LoopUnrollAndJam,
|
||||
PipelineDataTransfer,
|
||||
PrintCFGGraph,
|
||||
SimplifyAffineExpr,
|
||||
TFRaiseControlFlow,
|
||||
|
@ -85,6 +86,9 @@ static cl::list<Passes> passList(
|
|||
clEnumValN(LoopUnroll, "loop-unroll", "Unroll loops"),
|
||||
clEnumValN(LoopUnrollAndJam, "loop-unroll-jam",
|
||||
"Unroll and jam loops"),
|
||||
clEnumValN(PipelineDataTransfer, "pipeline-data-transfer",
|
||||
"Pipeline non-blocking data transfers between"
|
||||
"explicitly managed levels of the memory hierarchy"),
|
||||
clEnumValN(PrintCFGGraph, "print-cfg-graph",
|
||||
"Print CFG graph per function"),
|
||||
clEnumValN(SimplifyAffineExpr, "simplify-affine-expr",
|
||||
|
@ -179,6 +183,9 @@ static OptResult performActions(SourceMgr &sourceMgr, MLIRContext *context) {
|
|||
case LoopUnrollAndJam:
|
||||
pass = createLoopUnrollAndJamPass();
|
||||
break;
|
||||
case PipelineDataTransfer:
|
||||
pass = createPipelineDataTransferPass();
|
||||
break;
|
||||
case PrintCFGGraph:
|
||||
pass = createPrintCFGGraphPass();
|
||||
break;
|
||||
|
|
Loading…
Reference in New Issue