Refactor ForStmt: having it contain a StmtBlock instead of subclassing

StmtBlock.  This is more consistent with IfStmt and also conceptually makes
more sense - a forstmt "isn't" its body, it contains its body.

This is step 1/N towards merging BasicBlock and StmtBlock.  This is required
because in the new regime StmtBlock will have a use list (just like BasicBlock
does) of operands, and ForStmt already has a use list for its induction
variable.

This is a mechanical patch, NFC.

PiperOrigin-RevId: 226684158
This commit is contained in:
Chris Lattner 2018-12-23 08:17:48 -08:00 committed by jpienaar
parent 4eef795a1d
commit 1301f907a1
22 changed files with 109 additions and 67 deletions

View File

@ -345,7 +345,7 @@ public:
/// Returns a builder for the body of a for Stmt.
static MLFuncBuilder getForStmtBodyBuilder(ForStmt *forStmt) {
return MLFuncBuilder(forStmt, forStmt->end());
return MLFuncBuilder(forStmt->getBody(), forStmt->getBody()->end());
}
/// Returns the current insertion point of the builder.

View File

@ -228,7 +228,7 @@ public:
assert(index < getNumSuccessors());
return getBasicBlockOperands()[index].get();
}
BasicBlock *getSuccessor(unsigned index) const {
const BasicBlock *getSuccessor(unsigned index) const {
return const_cast<Instruction *>(this)->getSuccessor(index);
}
void setSuccessor(BasicBlock *block, unsigned index);

View File

@ -216,8 +216,31 @@ private:
}
};
/// A ForStmtBody represents statements contained within a ForStmt.
class ForStmtBody : public StmtBlock {
public:
explicit ForStmtBody(ForStmt *stmt)
: StmtBlock(StmtBlockKind::ForBody), forStmt(stmt) {
assert(stmt != nullptr && "ForStmtBody must have non-null parent");
}
~ForStmtBody() {}
/// Methods for support type inquiry through isa, cast, and dyn_cast
static bool classof(const StmtBlock *block) {
return block->getStmtBlockKind() == StmtBlockKind::ForBody;
}
/// Returns the 'for' statement that contains this body.
ForStmt *getFor() { return forStmt; }
const ForStmt *getFor() const { return forStmt; }
private:
ForStmt *forStmt;
};
/// For statement represents an affine loop nest.
class ForStmt : public Statement, public MLValue, public StmtBlock {
class ForStmt : public Statement, public MLValue {
public:
static ForStmt *create(Location location, ArrayRef<MLValue *> lbOperands,
AffineMap lbMap, ArrayRef<MLValue *> ubOperands,
@ -228,7 +251,7 @@ public:
// since child statements need to be destroyed before the MLValue that this
// for stmt represents is destroyed. Affine maps are immortal objects and
// don't need to be deleted.
clear();
getBody()->clear();
}
/// Resolve base class ambiguity.
@ -242,6 +265,12 @@ public:
using operand_range = llvm::iterator_range<operand_iterator>;
using const_operand_range = llvm::iterator_range<const_operand_iterator>;
/// Get the body of the ForStmt.
ForStmtBody *getBody() { return &body; }
/// Get the body of the ForStmt.
const ForStmtBody *getBody() const { return &body; }
//===--------------------------------------------------------------------===//
// Bounds and step
//===--------------------------------------------------------------------===//
@ -359,10 +388,6 @@ public:
return ptr->getKind() == IROperandOwner::Kind::ForStmt;
}
static bool classof(const StmtBlock *block) {
return block->getStmtBlockKind() == StmtBlockKind::For;
}
// For statement represents implicitly represents induction variable by
// inheriting from MLValue class. Whenever you need to refer to the loop
// induction variable, just use the for statement itself.
@ -371,6 +396,9 @@ public:
}
private:
// The StmtBlock for the body.
ForStmtBody body;
// Affine map for the lower bound.
AffineMap lbMap;
// Affine map for the upper bound. The upper bound is exclusive.
@ -456,7 +484,9 @@ public:
~IfClause() {}
/// Returns the if statement that contains this clause.
IfStmt *getIf() const { return ifStmt; }
const IfStmt *getIf() const { return ifStmt; }
IfStmt *getIf() { return ifStmt; }
private:
IfStmt *ifStmt;

View File

@ -36,7 +36,7 @@ class StmtBlock {
public:
enum class StmtBlockKind {
MLFunc, // MLFunction
For, // ForStmt
ForBody, // ForStmtBody
IfClause // IfClause
};
@ -53,7 +53,11 @@ public:
/// Returns the closest surrounding statement that contains this block or
/// nullptr if this is a top-level statement block.
Statement *getContainingStmt() const;
Statement *getContainingStmt();
const Statement *getContainingStmt() const {
return const_cast<StmtBlock *>(this)->getContainingStmt();
}
/// Returns the function that this statement block is part of.
/// The function is determined by traversing the chain of parent statements.

View File

@ -146,12 +146,13 @@ public:
void walkForStmt(ForStmt *forStmt) {
static_cast<SubClass *>(this)->visitForStmt(forStmt);
static_cast<SubClass *>(this)->walk(forStmt->begin(), forStmt->end());
auto *body = forStmt->getBody();
static_cast<SubClass *>(this)->walk(body->begin(), body->end());
}
void walkForStmtPostOrder(ForStmt *forStmt) {
static_cast<SubClass *>(this)->walkPostOrder(forStmt->begin(),
forStmt->end());
auto *body = forStmt->getBody();
static_cast<SubClass *>(this)->walkPostOrder(body->begin(), body->end());
static_cast<SubClass *>(this)->visitForStmt(forStmt);
}

View File

@ -905,7 +905,7 @@ static StmtBlock *getCommonStmtBlock(const MemRefAccess &srcAccess,
}
auto *commonForValue = srcDomain.getIdValue(numCommonLoops - 1);
assert(isa<ForStmt>(commonForValue));
return dyn_cast<ForStmt>(commonForValue);
return cast<ForStmt>(commonForValue)->getBody();
}
// Returns true if the ancestor operation statement of 'srcAccess' properly

View File

@ -305,9 +305,10 @@ bool mlir::isVectorizableLoop(const ForStmt &loop) {
// violation when we have the support.
bool mlir::isStmtwiseShiftValid(const ForStmt &forStmt,
ArrayRef<uint64_t> shifts) {
assert(shifts.size() == forStmt.getStatements().size());
auto *forBody = forStmt.getBody();
assert(shifts.size() == forBody->getStatements().size());
unsigned s = 0;
for (const auto &stmt : forStmt) {
for (const auto &stmt : *forBody) {
// A for or if stmt does not produce any def/results (that are used
// outside).
if (const auto *opStmt = dyn_cast<OperationStmt>(&stmt)) {
@ -319,8 +320,8 @@ bool mlir::isStmtwiseShiftValid(const ForStmt &forStmt,
// This is a naive way. If performance becomes an issue, a map can
// be used to store 'shifts' - to look up the shift for a statement in
// constant time.
if (auto *ancStmt = forStmt.findAncestorStmtInBlock(*use.getOwner()))
if (shifts[s] != shifts[forStmt.findStmtPosInBlock(*ancStmt)])
if (auto *ancStmt = forBody->findAncestorStmtInBlock(*use.getOwner()))
if (shifts[s] != shifts[forBody->findStmtPosInBlock(*ancStmt)])
return false;
}
}

View File

@ -362,7 +362,7 @@ static Statement *getStmtAtPosition(ArrayRef<unsigned> positions,
if (level == positions.size() - 1)
return &stmt;
if (auto *childForStmt = dyn_cast<ForStmt>(&stmt))
return getStmtAtPosition(positions, level + 1, childForStmt);
return getStmtAtPosition(positions, level + 1, childForStmt->getBody());
if (auto *ifStmt = dyn_cast<IfStmt>(&stmt)) {
auto *ret = getStmtAtPosition(positions, level + 1, ifStmt->getThen());
@ -453,13 +453,13 @@ ForStmt *mlir::insertBackwardComputationSlice(MemRefAccess *srcAccess,
// Clone src loop nest and insert it a the beginning of the statement block
// of the loop at 'dstLoopDepth' in 'dstLoopNest'.
auto *dstForStmt = dstLoopNest[dstLoopDepth - 1];
MLFuncBuilder b(dstForStmt, dstForStmt->begin());
MLFuncBuilder b(dstForStmt->getBody(), dstForStmt->getBody()->begin());
DenseMap<const MLValue *, MLValue *> operandMap;
auto *sliceLoopNest = cast<ForStmt>(b.clone(*srcLoopNest[0], operandMap));
// Lookup stmt in cloned 'sliceLoopNest' at 'positions'.
Statement *sliceStmt =
getStmtAtPosition(positions, /*level=*/0, sliceLoopNest);
getStmtAtPosition(positions, /*level=*/0, sliceLoopNest->getBody());
// Get loop nest surrounding 'sliceStmt'.
SmallVector<ForStmt *, 4> sliceSurroundingLoops;
getLoopIVs(*sliceStmt, &sliceSurroundingLoops);

View File

@ -288,8 +288,8 @@ bool MLFuncVerifier::verifyDominance() {
HashTable::ScopeTy blockScope(liveValues);
// The induction variable of a for statement is live within its body.
if (auto *forStmt = dyn_cast<ForStmt>(&block))
liveValues.insert(forStmt, true);
if (auto *forStmtBody = dyn_cast<ForStmtBody>(&block))
liveValues.insert(forStmtBody->getFor(), true);
for (auto &stmt : block) {
// Verify that each of the operands are live.
@ -322,7 +322,7 @@ bool MLFuncVerifier::verifyDominance() {
return true;
}
if (auto *forStmt = dyn_cast<ForStmt>(&stmt))
if (walkBlock(*forStmt))
if (walkBlock(*forStmt->getBody()))
return true;
}

View File

@ -206,7 +206,7 @@ void ModuleState::visitForStmt(const ForStmt *forStmt) {
if (!hasShorthandForm(ubMap))
recordAffineMapReference(ubMap);
for (auto &childStmt : *forStmt)
for (auto &childStmt : *forStmt->getBody())
visitStatement(&childStmt);
}
@ -1447,7 +1447,7 @@ void MLFunctionPrinter::print(const ForStmt *stmt) {
os << " step " << stmt->getStep();
os << " {\n";
print(static_cast<const StmtBlock *>(stmt));
print(stmt->getBody());
os.indent(numSpaces) << "}";
}

View File

@ -147,7 +147,7 @@ Instruction *Instruction::clone() const {
int cloneOperandIt = operands.size() - 1, operandIt = getNumOperands() - 1;
for (int succIt = getNumSuccessors() - 1, succE = 0; succIt >= succE;
--succIt) {
successors[succIt] = getSuccessor(succIt);
successors[succIt] = const_cast<BasicBlock *>(getSuccessor(succIt));
// Add the successor operands in-place in reverse order.
for (unsigned i = 0, e = getNumSuccessorOperands(succIt); i != e;

View File

@ -338,7 +338,7 @@ ForStmt::ForStmt(Location location, unsigned numOperands, AffineMap lbMap,
: Statement(Kind::For, location),
MLValue(MLValueKind::ForStmt,
Type::getIndex(lbMap.getResult(0).getContext())),
StmtBlock(StmtBlockKind::For), lbMap(lbMap), ubMap(ubMap), step(step) {
body(this), lbMap(lbMap), ubMap(ubMap), step(step) {
operands.reserve(numOperands);
}
@ -544,8 +544,8 @@ Statement *Statement::clone(DenseMap<const MLValue *, MLValue *> &operandMap,
operandMap[forStmt] = newFor;
// Recursively clone the body of the for loop.
for (auto &subStmt : *forStmt)
newFor->push_back(subStmt.clone(operandMap, context));
for (auto &subStmt : *forStmt->getBody())
newFor->getBody()->push_back(subStmt.clone(operandMap, context));
return newFor;
}

View File

@ -24,18 +24,19 @@ using namespace mlir;
// Statement block
//===----------------------------------------------------------------------===//
Statement *StmtBlock::getContainingStmt() const {
Statement *StmtBlock::getContainingStmt() {
switch (kind) {
case StmtBlockKind::MLFunc:
return nullptr;
case StmtBlockKind::For:
return cast<ForStmt>(const_cast<StmtBlock *>(this));
case StmtBlockKind::ForBody:
return cast<ForStmtBody>(this)->getFor();
case StmtBlockKind::IfClause:
return cast<IfClause>(this)->getIf();
}
}
MLFunction *StmtBlock::findFunction() const {
// FIXME: const incorrect.
StmtBlock *block = const_cast<StmtBlock *>(this);
while (block->getContainingStmt()) {

View File

@ -2876,7 +2876,7 @@ ParseResult MLFunctionParser::parseForStmt() {
// If parsing of the for statement body fails,
// MLIR contains for statement with those nested statements that have been
// successfully parsed.
if (parseStmtBlock(forStmt))
if (parseStmtBlock(forStmt->getBody()))
return ParseFailure;
// Reset insertion point to the current block.

View File

@ -242,7 +242,7 @@ void FunctionConverter::visitForStmt(ForStmt *forStmt) {
// Walking manually because we need custom logic before and after traversing
// the list of children.
builder.setInsertionPoint(loopBodyFirstBlock);
visitStmtBlock(forStmt);
visitStmtBlock(forStmt->getBody());
// Builder point is currently at the last block of the loop body. Append the
// induction variable stepping to this block and branch back to the exit

View File

@ -365,7 +365,7 @@ bool DmaGeneration::generateDma(const MemRefRegion &region, ForStmt *forStmt,
replaceAllMemRefUsesWith(memref, cast<MLValue>(fastMemRef),
/*extraIndices=*/{}, indexRemap,
/*extraOperands=*/outerIVs,
/*domStmtFilter=*/&*forStmt->begin());
/*domStmtFilter=*/&*forStmt->getBody()->begin());
return true;
}
@ -391,7 +391,7 @@ void DmaGeneration::runOnForStmt(ForStmt *forStmt) {
// the pass has to be instantiated with additional information that we aren't
// provided with at the moment.
if (forStmt->getStep() != 1) {
if (auto *innerFor = dyn_cast<ForStmt>(&*forStmt->begin())) {
if (auto *innerFor = dyn_cast<ForStmt>(&*forStmt->getBody()->begin())) {
runOnForStmt(innerFor);
}
return;

View File

@ -59,12 +59,12 @@ FunctionPass *mlir::createLoopTilingPass() { return new LoopTiling(); }
// destination's body.
static inline void moveLoopBody(ForStmt *src, ForStmt *dest,
StmtBlock::iterator loc) {
dest->getStatements().splice(loc, src->getStatements());
dest->getBody()->getStatements().splice(loc, src->getBody()->getStatements());
}
// Move the loop body of ForStmt 'src' from 'src' to the start of dest's body.
static inline void moveLoopBody(ForStmt *src, ForStmt *dest) {
moveLoopBody(src, dest, dest->begin());
moveLoopBody(src, dest, dest->getBody()->begin());
}
/// Constructs and sets new loop bounds after tiling for the case of
@ -167,8 +167,9 @@ UtilResult mlir::tileCodeGen(ArrayRef<ForStmt *> band,
MLFuncBuilder b(topLoop);
// Loop bounds will be set later.
auto *pointLoop = b.createFor(loc, 0, 0);
pointLoop->getStatements().splice(
pointLoop->begin(), topLoop->getBlock()->getStatements(), topLoop);
pointLoop->getBody()->getStatements().splice(
pointLoop->getBody()->begin(), topLoop->getBlock()->getStatements(),
topLoop);
newLoops[2 * width - 1 - i] = pointLoop;
topLoop = pointLoop;
if (i == 0)
@ -180,8 +181,9 @@ UtilResult mlir::tileCodeGen(ArrayRef<ForStmt *> band,
MLFuncBuilder b(topLoop);
// Loop bounds will be set later.
auto *tileSpaceLoop = b.createFor(loc, 0, 0);
tileSpaceLoop->getStatements().splice(
tileSpaceLoop->begin(), topLoop->getBlock()->getStatements(), topLoop);
tileSpaceLoop->getBody()->getStatements().splice(
tileSpaceLoop->getBody()->begin(), topLoop->getBlock()->getStatements(),
topLoop);
newLoops[2 * width - i - 1] = tileSpaceLoop;
topLoop = tileSpaceLoop;
}
@ -223,8 +225,8 @@ static void getTileableBands(MLFunction *f,
ForStmt *currStmt = root;
do {
band.push_back(currStmt);
} while (currStmt->getStatements().size() == 1 &&
(currStmt = dyn_cast<ForStmt>(&*currStmt->begin())));
} while (currStmt->getBody()->getStatements().size() == 1 &&
(currStmt = dyn_cast<ForStmt>(&*currStmt->getBody()->begin())));
bands->push_back(band);
};

View File

@ -104,7 +104,8 @@ PassResult LoopUnroll::runOnMLFunction(MLFunction *f) {
}
bool walkForStmtPostOrder(ForStmt *forStmt) {
bool hasInnerLoops = walkPostOrder(forStmt->begin(), forStmt->end());
bool hasInnerLoops =
walkPostOrder(forStmt->getBody()->begin(), forStmt->getBody()->end());
if (!hasInnerLoops)
loops.push_back(forStmt);
return true;

View File

@ -152,7 +152,7 @@ bool mlir::loopUnrollJamByFactor(ForStmt *forStmt, uint64_t unrollJamFactor) {
assert(unrollJamFactor >= 1 && "unroll jam factor should be >= 1");
if (unrollJamFactor == 1 || forStmt->getStatements().empty())
if (unrollJamFactor == 1 || forStmt->getBody()->empty())
return false;
Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(*forStmt);

View File

@ -147,7 +147,7 @@ static void rewriteAsLoops(VectorTransferOpTy *transfer,
auto *forStmt = b.createFor(transfer->getLoc(), 0, it.value());
loops.insert(forStmt);
// Setting the insertion point to the innermost loop achieves nesting.
b.setInsertionPointToStart(loops.back());
b.setInsertionPointToStart(loops.back()->getBody());
if (composed == getAffineConstantExpr(0, b.getContext())) {
transfer->emitWarning(
"Redundant copy can be implemented as a vector broadcast");

View File

@ -81,8 +81,9 @@ static unsigned getTagMemRefPos(const OperationStmt &dmaStmt) {
/// the loop IV of the specified 'for' statement modulo 2. Returns false if such
/// a replacement cannot be performed.
static bool doubleBuffer(MLValue *oldMemRef, ForStmt *forStmt) {
MLFuncBuilder bInner(forStmt, forStmt->begin());
bInner.setInsertionPoint(forStmt, forStmt->begin());
auto *forBody = forStmt->getBody();
MLFuncBuilder bInner(forBody, forBody->begin());
bInner.setInsertionPoint(forBody, forBody->begin());
// Doubles the shape with a leading dimension extent of 2.
auto doubleShape = [&](MemRefType oldMemRefType) -> MemRefType {
@ -127,7 +128,7 @@ static bool doubleBuffer(MLValue *oldMemRef, ForStmt *forStmt) {
// non-deferencing uses of the memref.
if (!replaceAllMemRefUsesWith(oldMemRef, cast<MLValue>(newMemRef),
ivModTwoOp->getResult(0), AffineMap::Null(), {},
&*forStmt->begin())) {
&*forStmt->getBody()->begin())) {
LLVM_DEBUG(llvm::dbgs()
<< "memref replacement for double buffering failed\n";);
ivModTwoOp->getOperation()->erase();
@ -184,7 +185,7 @@ static void findMatchingStartFinishStmts(
// Collect outgoing DMA statements - needed to check for dependences below.
SmallVector<OpPointer<DmaStartOp>, 4> outgoingDmaOps;
for (auto &stmt : *forStmt) {
for (auto &stmt : *forStmt->getBody()) {
auto *opStmt = dyn_cast<OperationStmt>(&stmt);
if (!opStmt)
continue;
@ -195,7 +196,7 @@ static void findMatchingStartFinishStmts(
}
SmallVector<OperationStmt *, 4> dmaStartStmts, dmaFinishStmts;
for (auto &stmt : *forStmt) {
for (auto &stmt : *forStmt->getBody()) {
auto *opStmt = dyn_cast<OperationStmt>(&stmt);
if (!opStmt)
continue;
@ -228,7 +229,7 @@ static void findMatchingStartFinishStmts(
cast<MLValue>(dmaStartOp->getOperand(dmaStartOp->getFasterMemPos()));
bool escapingUses = false;
for (const auto &use : memref->getUses()) {
if (!dominates(*forStmt->begin(), *use.getOwner())) {
if (!dominates(*forStmt->getBody()->begin(), *use.getOwner())) {
LLVM_DEBUG(llvm::dbgs()
<< "can't pipeline: buffer is live out of loop\n";);
escapingUses = true;
@ -339,16 +340,16 @@ PassResult PipelineDataTransfer::runOnForStmt(ForStmt *forStmt) {
}
}
// Everything else (including compute ops and dma finish) are shifted by one.
for (const auto &stmt : *forStmt) {
for (const auto &stmt : *forStmt->getBody()) {
if (stmtShiftMap.find(&stmt) == stmtShiftMap.end()) {
stmtShiftMap[&stmt] = 1;
}
}
// Get shifts stored in map.
std::vector<uint64_t> shifts(forStmt->getStatements().size());
std::vector<uint64_t> shifts(forStmt->getBody()->getStatements().size());
unsigned s = 0;
for (auto &stmt : *forStmt) {
for (auto &stmt : *forStmt->getBody()) {
assert(stmtShiftMap.find(&stmt) != stmtShiftMap.end());
shifts[s++] = stmtShiftMap[&stmt];
LLVM_DEBUG(

View File

@ -119,7 +119,7 @@ bool mlir::promoteIfSingleIteration(ForStmt *forStmt) {
// Move the loop body statements to the loop's containing block.
auto *block = forStmt->getBlock();
block->getStatements().splice(StmtBlock::iterator(forStmt),
forStmt->getStatements());
forStmt->getBody()->getStatements());
forStmt->erase();
return true;
}
@ -181,7 +181,7 @@ generateLoop(AffineMap lbMap, AffineMap ubMap,
operandMap[srcForStmt] = loopChunk;
}
for (auto *stmt : stmts) {
loopChunk->push_back(stmt->clone(operandMap, b->getContext()));
loopChunk->getBody()->push_back(stmt->clone(operandMap, b->getContext()));
}
}
if (promoteIfSingleIteration(loopChunk))
@ -206,7 +206,7 @@ generateLoop(AffineMap lbMap, AffineMap ubMap,
// method.
UtilResult mlir::stmtBodySkew(ForStmt *forStmt, ArrayRef<uint64_t> shifts,
bool unrollPrologueEpilogue) {
if (forStmt->getStatements().empty())
if (forStmt->getBody()->empty())
return UtilResult::Success;
// If the trip counts aren't constant, we would need versioning and
@ -225,7 +225,7 @@ UtilResult mlir::stmtBodySkew(ForStmt *forStmt, ArrayRef<uint64_t> shifts,
int64_t step = forStmt->getStep();
unsigned numChildStmts = forStmt->getStatements().size();
unsigned numChildStmts = forStmt->getBody()->getStatements().size();
// Do a linear time (counting) sort for the shifts.
uint64_t maxShift = 0;
@ -243,7 +243,7 @@ UtilResult mlir::stmtBodySkew(ForStmt *forStmt, ArrayRef<uint64_t> shifts,
// body of the 'for' stmt.
std::vector<std::vector<Statement *>> sortedStmtGroups(maxShift + 1);
unsigned pos = 0;
for (auto &stmt : *forStmt) {
for (auto &stmt : *forStmt->getBody()) {
auto shift = shifts[pos++];
sortedStmtGroups[shift].push_back(&stmt);
}
@ -352,7 +352,7 @@ bool mlir::loopUnrollUpToFactor(ForStmt *forStmt, uint64_t unrollFactor) {
bool mlir::loopUnrollByFactor(ForStmt *forStmt, uint64_t unrollFactor) {
assert(unrollFactor >= 1 && "unroll factor should be >= 1");
if (unrollFactor == 1 || forStmt->getStatements().empty())
if (unrollFactor == 1 || forStmt->getBody()->empty())
return false;
auto lbMap = forStmt->getLowerBoundMap();
@ -406,11 +406,11 @@ bool mlir::loopUnrollByFactor(ForStmt *forStmt, uint64_t unrollFactor) {
// Builder to insert unrolled bodies right after the last statement in the
// body of 'forStmt'.
MLFuncBuilder builder(forStmt, StmtBlock::iterator(forStmt->end()));
MLFuncBuilder builder(forStmt->getBody(), forStmt->getBody()->end());
// Keep a pointer to the last statement in the original block so that we know
// what to clone (since we are doing this in-place).
StmtBlock::iterator srcBlockEnd = std::prev(forStmt->end());
StmtBlock::iterator srcBlockEnd = std::prev(forStmt->getBody()->end());
// Unroll the contents of 'forStmt' (append unrollFactor-1 additional copies).
for (unsigned i = 1; i < unrollFactor; i++) {
@ -429,7 +429,8 @@ bool mlir::loopUnrollByFactor(ForStmt *forStmt, uint64_t unrollFactor) {
}
// Clone the original body of 'forStmt'.
for (auto it = forStmt->begin(); it != std::next(srcBlockEnd); it++) {
for (auto it = forStmt->getBody()->begin(); it != std::next(srcBlockEnd);
it++) {
builder.clone(*it, operandMap);
}
}