forked from OSchip/llvm-project
Refactor ForStmt: having it contain a StmtBlock instead of subclassing
StmtBlock. This is more consistent with IfStmt and also conceptually makes more sense - a forstmt "isn't" its body, it contains its body. This is step 1/N towards merging BasicBlock and StmtBlock. This is required because in the new regime StmtBlock will have a use list (just like BasicBlock does) of operands, and ForStmt already has a use list for its induction variable. This is a mechanical patch, NFC. PiperOrigin-RevId: 226684158
This commit is contained in:
parent
4eef795a1d
commit
1301f907a1
|
@ -345,7 +345,7 @@ public:
|
|||
|
||||
/// Returns a builder for the body of a for Stmt.
|
||||
static MLFuncBuilder getForStmtBodyBuilder(ForStmt *forStmt) {
|
||||
return MLFuncBuilder(forStmt, forStmt->end());
|
||||
return MLFuncBuilder(forStmt->getBody(), forStmt->getBody()->end());
|
||||
}
|
||||
|
||||
/// Returns the current insertion point of the builder.
|
||||
|
|
|
@ -228,7 +228,7 @@ public:
|
|||
assert(index < getNumSuccessors());
|
||||
return getBasicBlockOperands()[index].get();
|
||||
}
|
||||
BasicBlock *getSuccessor(unsigned index) const {
|
||||
const BasicBlock *getSuccessor(unsigned index) const {
|
||||
return const_cast<Instruction *>(this)->getSuccessor(index);
|
||||
}
|
||||
void setSuccessor(BasicBlock *block, unsigned index);
|
||||
|
|
|
@ -216,8 +216,31 @@ private:
|
|||
}
|
||||
};
|
||||
|
||||
/// A ForStmtBody represents statements contained within a ForStmt.
|
||||
class ForStmtBody : public StmtBlock {
|
||||
public:
|
||||
explicit ForStmtBody(ForStmt *stmt)
|
||||
: StmtBlock(StmtBlockKind::ForBody), forStmt(stmt) {
|
||||
assert(stmt != nullptr && "ForStmtBody must have non-null parent");
|
||||
}
|
||||
|
||||
~ForStmtBody() {}
|
||||
|
||||
/// Methods for support type inquiry through isa, cast, and dyn_cast
|
||||
static bool classof(const StmtBlock *block) {
|
||||
return block->getStmtBlockKind() == StmtBlockKind::ForBody;
|
||||
}
|
||||
|
||||
/// Returns the 'for' statement that contains this body.
|
||||
ForStmt *getFor() { return forStmt; }
|
||||
const ForStmt *getFor() const { return forStmt; }
|
||||
|
||||
private:
|
||||
ForStmt *forStmt;
|
||||
};
|
||||
|
||||
/// For statement represents an affine loop nest.
|
||||
class ForStmt : public Statement, public MLValue, public StmtBlock {
|
||||
class ForStmt : public Statement, public MLValue {
|
||||
public:
|
||||
static ForStmt *create(Location location, ArrayRef<MLValue *> lbOperands,
|
||||
AffineMap lbMap, ArrayRef<MLValue *> ubOperands,
|
||||
|
@ -228,7 +251,7 @@ public:
|
|||
// since child statements need to be destroyed before the MLValue that this
|
||||
// for stmt represents is destroyed. Affine maps are immortal objects and
|
||||
// don't need to be deleted.
|
||||
clear();
|
||||
getBody()->clear();
|
||||
}
|
||||
|
||||
/// Resolve base class ambiguity.
|
||||
|
@ -242,6 +265,12 @@ public:
|
|||
using operand_range = llvm::iterator_range<operand_iterator>;
|
||||
using const_operand_range = llvm::iterator_range<const_operand_iterator>;
|
||||
|
||||
/// Get the body of the ForStmt.
|
||||
ForStmtBody *getBody() { return &body; }
|
||||
|
||||
/// Get the body of the ForStmt.
|
||||
const ForStmtBody *getBody() const { return &body; }
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
// Bounds and step
|
||||
//===--------------------------------------------------------------------===//
|
||||
|
@ -359,10 +388,6 @@ public:
|
|||
return ptr->getKind() == IROperandOwner::Kind::ForStmt;
|
||||
}
|
||||
|
||||
static bool classof(const StmtBlock *block) {
|
||||
return block->getStmtBlockKind() == StmtBlockKind::For;
|
||||
}
|
||||
|
||||
// For statement represents implicitly represents induction variable by
|
||||
// inheriting from MLValue class. Whenever you need to refer to the loop
|
||||
// induction variable, just use the for statement itself.
|
||||
|
@ -371,6 +396,9 @@ public:
|
|||
}
|
||||
|
||||
private:
|
||||
// The StmtBlock for the body.
|
||||
ForStmtBody body;
|
||||
|
||||
// Affine map for the lower bound.
|
||||
AffineMap lbMap;
|
||||
// Affine map for the upper bound. The upper bound is exclusive.
|
||||
|
@ -456,7 +484,9 @@ public:
|
|||
~IfClause() {}
|
||||
|
||||
/// Returns the if statement that contains this clause.
|
||||
IfStmt *getIf() const { return ifStmt; }
|
||||
const IfStmt *getIf() const { return ifStmt; }
|
||||
|
||||
IfStmt *getIf() { return ifStmt; }
|
||||
|
||||
private:
|
||||
IfStmt *ifStmt;
|
||||
|
|
|
@ -36,7 +36,7 @@ class StmtBlock {
|
|||
public:
|
||||
enum class StmtBlockKind {
|
||||
MLFunc, // MLFunction
|
||||
For, // ForStmt
|
||||
ForBody, // ForStmtBody
|
||||
IfClause // IfClause
|
||||
};
|
||||
|
||||
|
@ -53,7 +53,11 @@ public:
|
|||
|
||||
/// Returns the closest surrounding statement that contains this block or
|
||||
/// nullptr if this is a top-level statement block.
|
||||
Statement *getContainingStmt() const;
|
||||
Statement *getContainingStmt();
|
||||
|
||||
const Statement *getContainingStmt() const {
|
||||
return const_cast<StmtBlock *>(this)->getContainingStmt();
|
||||
}
|
||||
|
||||
/// Returns the function that this statement block is part of.
|
||||
/// The function is determined by traversing the chain of parent statements.
|
||||
|
|
|
@ -146,12 +146,13 @@ public:
|
|||
|
||||
void walkForStmt(ForStmt *forStmt) {
|
||||
static_cast<SubClass *>(this)->visitForStmt(forStmt);
|
||||
static_cast<SubClass *>(this)->walk(forStmt->begin(), forStmt->end());
|
||||
auto *body = forStmt->getBody();
|
||||
static_cast<SubClass *>(this)->walk(body->begin(), body->end());
|
||||
}
|
||||
|
||||
void walkForStmtPostOrder(ForStmt *forStmt) {
|
||||
static_cast<SubClass *>(this)->walkPostOrder(forStmt->begin(),
|
||||
forStmt->end());
|
||||
auto *body = forStmt->getBody();
|
||||
static_cast<SubClass *>(this)->walkPostOrder(body->begin(), body->end());
|
||||
static_cast<SubClass *>(this)->visitForStmt(forStmt);
|
||||
}
|
||||
|
||||
|
|
|
@ -905,7 +905,7 @@ static StmtBlock *getCommonStmtBlock(const MemRefAccess &srcAccess,
|
|||
}
|
||||
auto *commonForValue = srcDomain.getIdValue(numCommonLoops - 1);
|
||||
assert(isa<ForStmt>(commonForValue));
|
||||
return dyn_cast<ForStmt>(commonForValue);
|
||||
return cast<ForStmt>(commonForValue)->getBody();
|
||||
}
|
||||
|
||||
// Returns true if the ancestor operation statement of 'srcAccess' properly
|
||||
|
|
|
@ -305,9 +305,10 @@ bool mlir::isVectorizableLoop(const ForStmt &loop) {
|
|||
// violation when we have the support.
|
||||
bool mlir::isStmtwiseShiftValid(const ForStmt &forStmt,
|
||||
ArrayRef<uint64_t> shifts) {
|
||||
assert(shifts.size() == forStmt.getStatements().size());
|
||||
auto *forBody = forStmt.getBody();
|
||||
assert(shifts.size() == forBody->getStatements().size());
|
||||
unsigned s = 0;
|
||||
for (const auto &stmt : forStmt) {
|
||||
for (const auto &stmt : *forBody) {
|
||||
// A for or if stmt does not produce any def/results (that are used
|
||||
// outside).
|
||||
if (const auto *opStmt = dyn_cast<OperationStmt>(&stmt)) {
|
||||
|
@ -319,8 +320,8 @@ bool mlir::isStmtwiseShiftValid(const ForStmt &forStmt,
|
|||
// This is a naive way. If performance becomes an issue, a map can
|
||||
// be used to store 'shifts' - to look up the shift for a statement in
|
||||
// constant time.
|
||||
if (auto *ancStmt = forStmt.findAncestorStmtInBlock(*use.getOwner()))
|
||||
if (shifts[s] != shifts[forStmt.findStmtPosInBlock(*ancStmt)])
|
||||
if (auto *ancStmt = forBody->findAncestorStmtInBlock(*use.getOwner()))
|
||||
if (shifts[s] != shifts[forBody->findStmtPosInBlock(*ancStmt)])
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -362,7 +362,7 @@ static Statement *getStmtAtPosition(ArrayRef<unsigned> positions,
|
|||
if (level == positions.size() - 1)
|
||||
return &stmt;
|
||||
if (auto *childForStmt = dyn_cast<ForStmt>(&stmt))
|
||||
return getStmtAtPosition(positions, level + 1, childForStmt);
|
||||
return getStmtAtPosition(positions, level + 1, childForStmt->getBody());
|
||||
|
||||
if (auto *ifStmt = dyn_cast<IfStmt>(&stmt)) {
|
||||
auto *ret = getStmtAtPosition(positions, level + 1, ifStmt->getThen());
|
||||
|
@ -453,13 +453,13 @@ ForStmt *mlir::insertBackwardComputationSlice(MemRefAccess *srcAccess,
|
|||
// Clone src loop nest and insert it a the beginning of the statement block
|
||||
// of the loop at 'dstLoopDepth' in 'dstLoopNest'.
|
||||
auto *dstForStmt = dstLoopNest[dstLoopDepth - 1];
|
||||
MLFuncBuilder b(dstForStmt, dstForStmt->begin());
|
||||
MLFuncBuilder b(dstForStmt->getBody(), dstForStmt->getBody()->begin());
|
||||
DenseMap<const MLValue *, MLValue *> operandMap;
|
||||
auto *sliceLoopNest = cast<ForStmt>(b.clone(*srcLoopNest[0], operandMap));
|
||||
|
||||
// Lookup stmt in cloned 'sliceLoopNest' at 'positions'.
|
||||
Statement *sliceStmt =
|
||||
getStmtAtPosition(positions, /*level=*/0, sliceLoopNest);
|
||||
getStmtAtPosition(positions, /*level=*/0, sliceLoopNest->getBody());
|
||||
// Get loop nest surrounding 'sliceStmt'.
|
||||
SmallVector<ForStmt *, 4> sliceSurroundingLoops;
|
||||
getLoopIVs(*sliceStmt, &sliceSurroundingLoops);
|
||||
|
|
|
@ -288,8 +288,8 @@ bool MLFuncVerifier::verifyDominance() {
|
|||
HashTable::ScopeTy blockScope(liveValues);
|
||||
|
||||
// The induction variable of a for statement is live within its body.
|
||||
if (auto *forStmt = dyn_cast<ForStmt>(&block))
|
||||
liveValues.insert(forStmt, true);
|
||||
if (auto *forStmtBody = dyn_cast<ForStmtBody>(&block))
|
||||
liveValues.insert(forStmtBody->getFor(), true);
|
||||
|
||||
for (auto &stmt : block) {
|
||||
// Verify that each of the operands are live.
|
||||
|
@ -322,7 +322,7 @@ bool MLFuncVerifier::verifyDominance() {
|
|||
return true;
|
||||
}
|
||||
if (auto *forStmt = dyn_cast<ForStmt>(&stmt))
|
||||
if (walkBlock(*forStmt))
|
||||
if (walkBlock(*forStmt->getBody()))
|
||||
return true;
|
||||
}
|
||||
|
||||
|
|
|
@ -206,7 +206,7 @@ void ModuleState::visitForStmt(const ForStmt *forStmt) {
|
|||
if (!hasShorthandForm(ubMap))
|
||||
recordAffineMapReference(ubMap);
|
||||
|
||||
for (auto &childStmt : *forStmt)
|
||||
for (auto &childStmt : *forStmt->getBody())
|
||||
visitStatement(&childStmt);
|
||||
}
|
||||
|
||||
|
@ -1447,7 +1447,7 @@ void MLFunctionPrinter::print(const ForStmt *stmt) {
|
|||
os << " step " << stmt->getStep();
|
||||
|
||||
os << " {\n";
|
||||
print(static_cast<const StmtBlock *>(stmt));
|
||||
print(stmt->getBody());
|
||||
os.indent(numSpaces) << "}";
|
||||
}
|
||||
|
||||
|
|
|
@ -147,7 +147,7 @@ Instruction *Instruction::clone() const {
|
|||
int cloneOperandIt = operands.size() - 1, operandIt = getNumOperands() - 1;
|
||||
for (int succIt = getNumSuccessors() - 1, succE = 0; succIt >= succE;
|
||||
--succIt) {
|
||||
successors[succIt] = getSuccessor(succIt);
|
||||
successors[succIt] = const_cast<BasicBlock *>(getSuccessor(succIt));
|
||||
|
||||
// Add the successor operands in-place in reverse order.
|
||||
for (unsigned i = 0, e = getNumSuccessorOperands(succIt); i != e;
|
||||
|
|
|
@ -338,7 +338,7 @@ ForStmt::ForStmt(Location location, unsigned numOperands, AffineMap lbMap,
|
|||
: Statement(Kind::For, location),
|
||||
MLValue(MLValueKind::ForStmt,
|
||||
Type::getIndex(lbMap.getResult(0).getContext())),
|
||||
StmtBlock(StmtBlockKind::For), lbMap(lbMap), ubMap(ubMap), step(step) {
|
||||
body(this), lbMap(lbMap), ubMap(ubMap), step(step) {
|
||||
operands.reserve(numOperands);
|
||||
}
|
||||
|
||||
|
@ -544,8 +544,8 @@ Statement *Statement::clone(DenseMap<const MLValue *, MLValue *> &operandMap,
|
|||
operandMap[forStmt] = newFor;
|
||||
|
||||
// Recursively clone the body of the for loop.
|
||||
for (auto &subStmt : *forStmt)
|
||||
newFor->push_back(subStmt.clone(operandMap, context));
|
||||
for (auto &subStmt : *forStmt->getBody())
|
||||
newFor->getBody()->push_back(subStmt.clone(operandMap, context));
|
||||
|
||||
return newFor;
|
||||
}
|
||||
|
|
|
@ -24,18 +24,19 @@ using namespace mlir;
|
|||
// Statement block
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
Statement *StmtBlock::getContainingStmt() const {
|
||||
Statement *StmtBlock::getContainingStmt() {
|
||||
switch (kind) {
|
||||
case StmtBlockKind::MLFunc:
|
||||
return nullptr;
|
||||
case StmtBlockKind::For:
|
||||
return cast<ForStmt>(const_cast<StmtBlock *>(this));
|
||||
case StmtBlockKind::ForBody:
|
||||
return cast<ForStmtBody>(this)->getFor();
|
||||
case StmtBlockKind::IfClause:
|
||||
return cast<IfClause>(this)->getIf();
|
||||
}
|
||||
}
|
||||
|
||||
MLFunction *StmtBlock::findFunction() const {
|
||||
// FIXME: const incorrect.
|
||||
StmtBlock *block = const_cast<StmtBlock *>(this);
|
||||
|
||||
while (block->getContainingStmt()) {
|
||||
|
|
|
@ -2876,7 +2876,7 @@ ParseResult MLFunctionParser::parseForStmt() {
|
|||
// If parsing of the for statement body fails,
|
||||
// MLIR contains for statement with those nested statements that have been
|
||||
// successfully parsed.
|
||||
if (parseStmtBlock(forStmt))
|
||||
if (parseStmtBlock(forStmt->getBody()))
|
||||
return ParseFailure;
|
||||
|
||||
// Reset insertion point to the current block.
|
||||
|
|
|
@ -242,7 +242,7 @@ void FunctionConverter::visitForStmt(ForStmt *forStmt) {
|
|||
// Walking manually because we need custom logic before and after traversing
|
||||
// the list of children.
|
||||
builder.setInsertionPoint(loopBodyFirstBlock);
|
||||
visitStmtBlock(forStmt);
|
||||
visitStmtBlock(forStmt->getBody());
|
||||
|
||||
// Builder point is currently at the last block of the loop body. Append the
|
||||
// induction variable stepping to this block and branch back to the exit
|
||||
|
|
|
@ -365,7 +365,7 @@ bool DmaGeneration::generateDma(const MemRefRegion ®ion, ForStmt *forStmt,
|
|||
replaceAllMemRefUsesWith(memref, cast<MLValue>(fastMemRef),
|
||||
/*extraIndices=*/{}, indexRemap,
|
||||
/*extraOperands=*/outerIVs,
|
||||
/*domStmtFilter=*/&*forStmt->begin());
|
||||
/*domStmtFilter=*/&*forStmt->getBody()->begin());
|
||||
return true;
|
||||
}
|
||||
|
||||
|
@ -391,7 +391,7 @@ void DmaGeneration::runOnForStmt(ForStmt *forStmt) {
|
|||
// the pass has to be instantiated with additional information that we aren't
|
||||
// provided with at the moment.
|
||||
if (forStmt->getStep() != 1) {
|
||||
if (auto *innerFor = dyn_cast<ForStmt>(&*forStmt->begin())) {
|
||||
if (auto *innerFor = dyn_cast<ForStmt>(&*forStmt->getBody()->begin())) {
|
||||
runOnForStmt(innerFor);
|
||||
}
|
||||
return;
|
||||
|
|
|
@ -59,12 +59,12 @@ FunctionPass *mlir::createLoopTilingPass() { return new LoopTiling(); }
|
|||
// destination's body.
|
||||
static inline void moveLoopBody(ForStmt *src, ForStmt *dest,
|
||||
StmtBlock::iterator loc) {
|
||||
dest->getStatements().splice(loc, src->getStatements());
|
||||
dest->getBody()->getStatements().splice(loc, src->getBody()->getStatements());
|
||||
}
|
||||
|
||||
// Move the loop body of ForStmt 'src' from 'src' to the start of dest's body.
|
||||
static inline void moveLoopBody(ForStmt *src, ForStmt *dest) {
|
||||
moveLoopBody(src, dest, dest->begin());
|
||||
moveLoopBody(src, dest, dest->getBody()->begin());
|
||||
}
|
||||
|
||||
/// Constructs and sets new loop bounds after tiling for the case of
|
||||
|
@ -167,8 +167,9 @@ UtilResult mlir::tileCodeGen(ArrayRef<ForStmt *> band,
|
|||
MLFuncBuilder b(topLoop);
|
||||
// Loop bounds will be set later.
|
||||
auto *pointLoop = b.createFor(loc, 0, 0);
|
||||
pointLoop->getStatements().splice(
|
||||
pointLoop->begin(), topLoop->getBlock()->getStatements(), topLoop);
|
||||
pointLoop->getBody()->getStatements().splice(
|
||||
pointLoop->getBody()->begin(), topLoop->getBlock()->getStatements(),
|
||||
topLoop);
|
||||
newLoops[2 * width - 1 - i] = pointLoop;
|
||||
topLoop = pointLoop;
|
||||
if (i == 0)
|
||||
|
@ -180,8 +181,9 @@ UtilResult mlir::tileCodeGen(ArrayRef<ForStmt *> band,
|
|||
MLFuncBuilder b(topLoop);
|
||||
// Loop bounds will be set later.
|
||||
auto *tileSpaceLoop = b.createFor(loc, 0, 0);
|
||||
tileSpaceLoop->getStatements().splice(
|
||||
tileSpaceLoop->begin(), topLoop->getBlock()->getStatements(), topLoop);
|
||||
tileSpaceLoop->getBody()->getStatements().splice(
|
||||
tileSpaceLoop->getBody()->begin(), topLoop->getBlock()->getStatements(),
|
||||
topLoop);
|
||||
newLoops[2 * width - i - 1] = tileSpaceLoop;
|
||||
topLoop = tileSpaceLoop;
|
||||
}
|
||||
|
@ -223,8 +225,8 @@ static void getTileableBands(MLFunction *f,
|
|||
ForStmt *currStmt = root;
|
||||
do {
|
||||
band.push_back(currStmt);
|
||||
} while (currStmt->getStatements().size() == 1 &&
|
||||
(currStmt = dyn_cast<ForStmt>(&*currStmt->begin())));
|
||||
} while (currStmt->getBody()->getStatements().size() == 1 &&
|
||||
(currStmt = dyn_cast<ForStmt>(&*currStmt->getBody()->begin())));
|
||||
bands->push_back(band);
|
||||
};
|
||||
|
||||
|
|
|
@ -104,7 +104,8 @@ PassResult LoopUnroll::runOnMLFunction(MLFunction *f) {
|
|||
}
|
||||
|
||||
bool walkForStmtPostOrder(ForStmt *forStmt) {
|
||||
bool hasInnerLoops = walkPostOrder(forStmt->begin(), forStmt->end());
|
||||
bool hasInnerLoops =
|
||||
walkPostOrder(forStmt->getBody()->begin(), forStmt->getBody()->end());
|
||||
if (!hasInnerLoops)
|
||||
loops.push_back(forStmt);
|
||||
return true;
|
||||
|
|
|
@ -152,7 +152,7 @@ bool mlir::loopUnrollJamByFactor(ForStmt *forStmt, uint64_t unrollJamFactor) {
|
|||
|
||||
assert(unrollJamFactor >= 1 && "unroll jam factor should be >= 1");
|
||||
|
||||
if (unrollJamFactor == 1 || forStmt->getStatements().empty())
|
||||
if (unrollJamFactor == 1 || forStmt->getBody()->empty())
|
||||
return false;
|
||||
|
||||
Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(*forStmt);
|
||||
|
|
|
@ -147,7 +147,7 @@ static void rewriteAsLoops(VectorTransferOpTy *transfer,
|
|||
auto *forStmt = b.createFor(transfer->getLoc(), 0, it.value());
|
||||
loops.insert(forStmt);
|
||||
// Setting the insertion point to the innermost loop achieves nesting.
|
||||
b.setInsertionPointToStart(loops.back());
|
||||
b.setInsertionPointToStart(loops.back()->getBody());
|
||||
if (composed == getAffineConstantExpr(0, b.getContext())) {
|
||||
transfer->emitWarning(
|
||||
"Redundant copy can be implemented as a vector broadcast");
|
||||
|
|
|
@ -81,8 +81,9 @@ static unsigned getTagMemRefPos(const OperationStmt &dmaStmt) {
|
|||
/// the loop IV of the specified 'for' statement modulo 2. Returns false if such
|
||||
/// a replacement cannot be performed.
|
||||
static bool doubleBuffer(MLValue *oldMemRef, ForStmt *forStmt) {
|
||||
MLFuncBuilder bInner(forStmt, forStmt->begin());
|
||||
bInner.setInsertionPoint(forStmt, forStmt->begin());
|
||||
auto *forBody = forStmt->getBody();
|
||||
MLFuncBuilder bInner(forBody, forBody->begin());
|
||||
bInner.setInsertionPoint(forBody, forBody->begin());
|
||||
|
||||
// Doubles the shape with a leading dimension extent of 2.
|
||||
auto doubleShape = [&](MemRefType oldMemRefType) -> MemRefType {
|
||||
|
@ -127,7 +128,7 @@ static bool doubleBuffer(MLValue *oldMemRef, ForStmt *forStmt) {
|
|||
// non-deferencing uses of the memref.
|
||||
if (!replaceAllMemRefUsesWith(oldMemRef, cast<MLValue>(newMemRef),
|
||||
ivModTwoOp->getResult(0), AffineMap::Null(), {},
|
||||
&*forStmt->begin())) {
|
||||
&*forStmt->getBody()->begin())) {
|
||||
LLVM_DEBUG(llvm::dbgs()
|
||||
<< "memref replacement for double buffering failed\n";);
|
||||
ivModTwoOp->getOperation()->erase();
|
||||
|
@ -184,7 +185,7 @@ static void findMatchingStartFinishStmts(
|
|||
|
||||
// Collect outgoing DMA statements - needed to check for dependences below.
|
||||
SmallVector<OpPointer<DmaStartOp>, 4> outgoingDmaOps;
|
||||
for (auto &stmt : *forStmt) {
|
||||
for (auto &stmt : *forStmt->getBody()) {
|
||||
auto *opStmt = dyn_cast<OperationStmt>(&stmt);
|
||||
if (!opStmt)
|
||||
continue;
|
||||
|
@ -195,7 +196,7 @@ static void findMatchingStartFinishStmts(
|
|||
}
|
||||
|
||||
SmallVector<OperationStmt *, 4> dmaStartStmts, dmaFinishStmts;
|
||||
for (auto &stmt : *forStmt) {
|
||||
for (auto &stmt : *forStmt->getBody()) {
|
||||
auto *opStmt = dyn_cast<OperationStmt>(&stmt);
|
||||
if (!opStmt)
|
||||
continue;
|
||||
|
@ -228,7 +229,7 @@ static void findMatchingStartFinishStmts(
|
|||
cast<MLValue>(dmaStartOp->getOperand(dmaStartOp->getFasterMemPos()));
|
||||
bool escapingUses = false;
|
||||
for (const auto &use : memref->getUses()) {
|
||||
if (!dominates(*forStmt->begin(), *use.getOwner())) {
|
||||
if (!dominates(*forStmt->getBody()->begin(), *use.getOwner())) {
|
||||
LLVM_DEBUG(llvm::dbgs()
|
||||
<< "can't pipeline: buffer is live out of loop\n";);
|
||||
escapingUses = true;
|
||||
|
@ -339,16 +340,16 @@ PassResult PipelineDataTransfer::runOnForStmt(ForStmt *forStmt) {
|
|||
}
|
||||
}
|
||||
// Everything else (including compute ops and dma finish) are shifted by one.
|
||||
for (const auto &stmt : *forStmt) {
|
||||
for (const auto &stmt : *forStmt->getBody()) {
|
||||
if (stmtShiftMap.find(&stmt) == stmtShiftMap.end()) {
|
||||
stmtShiftMap[&stmt] = 1;
|
||||
}
|
||||
}
|
||||
|
||||
// Get shifts stored in map.
|
||||
std::vector<uint64_t> shifts(forStmt->getStatements().size());
|
||||
std::vector<uint64_t> shifts(forStmt->getBody()->getStatements().size());
|
||||
unsigned s = 0;
|
||||
for (auto &stmt : *forStmt) {
|
||||
for (auto &stmt : *forStmt->getBody()) {
|
||||
assert(stmtShiftMap.find(&stmt) != stmtShiftMap.end());
|
||||
shifts[s++] = stmtShiftMap[&stmt];
|
||||
LLVM_DEBUG(
|
||||
|
|
|
@ -119,7 +119,7 @@ bool mlir::promoteIfSingleIteration(ForStmt *forStmt) {
|
|||
// Move the loop body statements to the loop's containing block.
|
||||
auto *block = forStmt->getBlock();
|
||||
block->getStatements().splice(StmtBlock::iterator(forStmt),
|
||||
forStmt->getStatements());
|
||||
forStmt->getBody()->getStatements());
|
||||
forStmt->erase();
|
||||
return true;
|
||||
}
|
||||
|
@ -181,7 +181,7 @@ generateLoop(AffineMap lbMap, AffineMap ubMap,
|
|||
operandMap[srcForStmt] = loopChunk;
|
||||
}
|
||||
for (auto *stmt : stmts) {
|
||||
loopChunk->push_back(stmt->clone(operandMap, b->getContext()));
|
||||
loopChunk->getBody()->push_back(stmt->clone(operandMap, b->getContext()));
|
||||
}
|
||||
}
|
||||
if (promoteIfSingleIteration(loopChunk))
|
||||
|
@ -206,7 +206,7 @@ generateLoop(AffineMap lbMap, AffineMap ubMap,
|
|||
// method.
|
||||
UtilResult mlir::stmtBodySkew(ForStmt *forStmt, ArrayRef<uint64_t> shifts,
|
||||
bool unrollPrologueEpilogue) {
|
||||
if (forStmt->getStatements().empty())
|
||||
if (forStmt->getBody()->empty())
|
||||
return UtilResult::Success;
|
||||
|
||||
// If the trip counts aren't constant, we would need versioning and
|
||||
|
@ -225,7 +225,7 @@ UtilResult mlir::stmtBodySkew(ForStmt *forStmt, ArrayRef<uint64_t> shifts,
|
|||
|
||||
int64_t step = forStmt->getStep();
|
||||
|
||||
unsigned numChildStmts = forStmt->getStatements().size();
|
||||
unsigned numChildStmts = forStmt->getBody()->getStatements().size();
|
||||
|
||||
// Do a linear time (counting) sort for the shifts.
|
||||
uint64_t maxShift = 0;
|
||||
|
@ -243,7 +243,7 @@ UtilResult mlir::stmtBodySkew(ForStmt *forStmt, ArrayRef<uint64_t> shifts,
|
|||
// body of the 'for' stmt.
|
||||
std::vector<std::vector<Statement *>> sortedStmtGroups(maxShift + 1);
|
||||
unsigned pos = 0;
|
||||
for (auto &stmt : *forStmt) {
|
||||
for (auto &stmt : *forStmt->getBody()) {
|
||||
auto shift = shifts[pos++];
|
||||
sortedStmtGroups[shift].push_back(&stmt);
|
||||
}
|
||||
|
@ -352,7 +352,7 @@ bool mlir::loopUnrollUpToFactor(ForStmt *forStmt, uint64_t unrollFactor) {
|
|||
bool mlir::loopUnrollByFactor(ForStmt *forStmt, uint64_t unrollFactor) {
|
||||
assert(unrollFactor >= 1 && "unroll factor should be >= 1");
|
||||
|
||||
if (unrollFactor == 1 || forStmt->getStatements().empty())
|
||||
if (unrollFactor == 1 || forStmt->getBody()->empty())
|
||||
return false;
|
||||
|
||||
auto lbMap = forStmt->getLowerBoundMap();
|
||||
|
@ -406,11 +406,11 @@ bool mlir::loopUnrollByFactor(ForStmt *forStmt, uint64_t unrollFactor) {
|
|||
|
||||
// Builder to insert unrolled bodies right after the last statement in the
|
||||
// body of 'forStmt'.
|
||||
MLFuncBuilder builder(forStmt, StmtBlock::iterator(forStmt->end()));
|
||||
MLFuncBuilder builder(forStmt->getBody(), forStmt->getBody()->end());
|
||||
|
||||
// Keep a pointer to the last statement in the original block so that we know
|
||||
// what to clone (since we are doing this in-place).
|
||||
StmtBlock::iterator srcBlockEnd = std::prev(forStmt->end());
|
||||
StmtBlock::iterator srcBlockEnd = std::prev(forStmt->getBody()->end());
|
||||
|
||||
// Unroll the contents of 'forStmt' (append unrollFactor-1 additional copies).
|
||||
for (unsigned i = 1; i < unrollFactor; i++) {
|
||||
|
@ -429,7 +429,8 @@ bool mlir::loopUnrollByFactor(ForStmt *forStmt, uint64_t unrollFactor) {
|
|||
}
|
||||
|
||||
// Clone the original body of 'forStmt'.
|
||||
for (auto it = forStmt->begin(); it != std::next(srcBlockEnd); it++) {
|
||||
for (auto it = forStmt->getBody()->begin(); it != std::next(srcBlockEnd);
|
||||
it++) {
|
||||
builder.clone(*it, operandMap);
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue