[OpenMPIRBuilder] Implicitly defined control blocks. NFC.

Do not explicitly store the BasicBlocks for Preheader, Body and After inside CanonicalLoopInfo, but look the up when needed using their position relative to the other loop control blocks. By definition, instructions inside these are not managed by CanonicalLoopInfo (except terminator for Preheader) hence it makes sense to think of them as connections to the CanonicalLoopInfo instead of part of the CanonicalLoopInfo itself.

In particular for Preheader, it makes using SplitBasicBlock easier since inserting control flow at an InsertPoint may otherwise require updating the CanonicalLoopInfo's Preheader because the branch that jumps to the header is moved to another BasicBlock.

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D114368
This commit is contained in:
Michael Kruse 2021-12-06 14:08:07 -06:00
parent c5fef77bc3
commit fed966f2a4
2 changed files with 33 additions and 25 deletions

View File

@ -1408,13 +1408,10 @@ class CanonicalLoopInfo {
friend class OpenMPIRBuilder; friend class OpenMPIRBuilder;
private: private:
BasicBlock *Preheader = nullptr;
BasicBlock *Header = nullptr; BasicBlock *Header = nullptr;
BasicBlock *Cond = nullptr; BasicBlock *Cond = nullptr;
BasicBlock *Body = nullptr;
BasicBlock *Latch = nullptr; BasicBlock *Latch = nullptr;
BasicBlock *Exit = nullptr; BasicBlock *Exit = nullptr;
BasicBlock *After = nullptr;
/// Add the control blocks of this loop to \p BBs. /// Add the control blocks of this loop to \p BBs.
/// ///
@ -1436,10 +1433,7 @@ public:
/// Code that must be execute before any loop iteration can be emitted here, /// Code that must be execute before any loop iteration can be emitted here,
/// such as computing the loop trip count and begin lifetime markers. Code in /// such as computing the loop trip count and begin lifetime markers. Code in
/// the preheader is not considered part of the canonical loop. /// the preheader is not considered part of the canonical loop.
BasicBlock *getPreheader() const { BasicBlock *getPreheader() const;
assert(isValid() && "Requires a valid canonical loop");
return Preheader;
}
/// The header is the entry for each iteration. In the canonical control flow, /// The header is the entry for each iteration. In the canonical control flow,
/// it only contains the PHINode for the induction variable. /// it only contains the PHINode for the induction variable.
@ -1460,7 +1454,7 @@ public:
/// eventually branch to the \p Latch block. /// eventually branch to the \p Latch block.
BasicBlock *getBody() const { BasicBlock *getBody() const {
assert(isValid() && "Requires a valid canonical loop"); assert(isValid() && "Requires a valid canonical loop");
return Body; return cast<BranchInst>(Cond->getTerminator())->getSuccessor(0);
} }
/// Reaching the latch indicates the end of the loop body code. In the /// Reaching the latch indicates the end of the loop body code. In the
@ -1484,7 +1478,7 @@ public:
/// statements/cancellations). /// statements/cancellations).
BasicBlock *getAfter() const { BasicBlock *getAfter() const {
assert(isValid() && "Requires a valid canonical loop"); assert(isValid() && "Requires a valid canonical loop");
return After; return Exit->getSingleSuccessor();
} }
/// Returns the llvm::Value containing the number of loop iterations. It must /// Returns the llvm::Value containing the number of loop iterations. It must
@ -1515,18 +1509,21 @@ public:
/// Return the insertion point for user code before the loop. /// Return the insertion point for user code before the loop.
OpenMPIRBuilder::InsertPointTy getPreheaderIP() const { OpenMPIRBuilder::InsertPointTy getPreheaderIP() const {
assert(isValid() && "Requires a valid canonical loop"); assert(isValid() && "Requires a valid canonical loop");
BasicBlock *Preheader = getPreheader();
return {Preheader, std::prev(Preheader->end())}; return {Preheader, std::prev(Preheader->end())};
}; };
/// Return the insertion point for user code in the body. /// Return the insertion point for user code in the body.
OpenMPIRBuilder::InsertPointTy getBodyIP() const { OpenMPIRBuilder::InsertPointTy getBodyIP() const {
assert(isValid() && "Requires a valid canonical loop"); assert(isValid() && "Requires a valid canonical loop");
BasicBlock *Body = getBody();
return {Body, Body->begin()}; return {Body, Body->begin()};
}; };
/// Return the insertion point for user code after the loop. /// Return the insertion point for user code after the loop.
OpenMPIRBuilder::InsertPointTy getAfterIP() const { OpenMPIRBuilder::InsertPointTy getAfterIP() const {
assert(isValid() && "Requires a valid canonical loop"); assert(isValid() && "Requires a valid canonical loop");
BasicBlock *After = getAfter();
return {After, After->begin()}; return {After, After->begin()};
}; };

View File

@ -1329,13 +1329,10 @@ CanonicalLoopInfo *OpenMPIRBuilder::createLoopSkeleton(
LoopInfos.emplace_front(); LoopInfos.emplace_front();
CanonicalLoopInfo *CL = &LoopInfos.front(); CanonicalLoopInfo *CL = &LoopInfos.front();
CL->Preheader = Preheader;
CL->Header = Header; CL->Header = Header;
CL->Cond = Cond; CL->Cond = Cond;
CL->Body = Body;
CL->Latch = Latch; CL->Latch = Latch;
CL->Exit = Exit; CL->Exit = Exit;
CL->After = After;
#ifndef NDEBUG #ifndef NDEBUG
CL->assertOK(); CL->assertOK();
@ -1359,7 +1356,7 @@ OpenMPIRBuilder::createCanonicalLoop(const LocationDescription &Loc,
// Split the loop at the insertion point: Branch to the preheader and move // Split the loop at the insertion point: Branch to the preheader and move
// every following instruction to after the loop (the After BB). Also, the // every following instruction to after the loop (the After BB). Also, the
// new successor is the loop's after block. // new successor is the loop's after block.
Builder.CreateBr(CL->Preheader); Builder.CreateBr(CL->getPreheader());
After->getInstList().splice(After->begin(), BB->getInstList(), After->getInstList().splice(After->begin(), BB->getInstList(),
Builder.GetInsertPoint(), BB->end()); Builder.GetInsertPoint(), BB->end());
After->replaceSuccessorsPhiUsesWith(BB, After); After->replaceSuccessorsPhiUsesWith(BB, After);
@ -1791,6 +1788,12 @@ OpenMPIRBuilder::collapseLoops(DebugLoc DL, ArrayRef<CanonicalLoopInfo *> Loops,
BasicBlock *OrigAfter = Outermost->getAfter(); BasicBlock *OrigAfter = Outermost->getAfter();
Function *F = OrigPreheader->getParent(); Function *F = OrigPreheader->getParent();
// Loop control blocks that may become orphaned later.
SmallVector<BasicBlock *, 12> OldControlBBs;
OldControlBBs.reserve(6 * Loops.size());
for (CanonicalLoopInfo *Loop : Loops)
Loop->collectControlBlocks(OldControlBBs);
// Setup the IRBuilder for inserting the trip count computation. // Setup the IRBuilder for inserting the trip count computation.
Builder.SetCurrentDebugLocation(DL); Builder.SetCurrentDebugLocation(DL);
if (ComputeIP.isSet()) if (ComputeIP.isSet())
@ -1886,10 +1889,6 @@ OpenMPIRBuilder::collapseLoops(DebugLoc DL, ArrayRef<CanonicalLoopInfo *> Loops,
Loops[i]->getIndVar()->replaceAllUsesWith(NewIndVars[i]); Loops[i]->getIndVar()->replaceAllUsesWith(NewIndVars[i]);
// Remove unused parts of the input loops. // Remove unused parts of the input loops.
SmallVector<BasicBlock *, 12> OldControlBBs;
OldControlBBs.reserve(6 * Loops.size());
for (CanonicalLoopInfo *Loop : Loops)
Loop->collectControlBlocks(OldControlBBs);
removeUnusedBlocksFromParent(OldControlBBs); removeUnusedBlocksFromParent(OldControlBBs);
for (CanonicalLoopInfo *L : Loops) for (CanonicalLoopInfo *L : Loops)
@ -1915,6 +1914,12 @@ OpenMPIRBuilder::tileLoops(DebugLoc DL, ArrayRef<CanonicalLoopInfo *> Loops,
BasicBlock *InnerEnter = InnermostLoop->getBody(); BasicBlock *InnerEnter = InnermostLoop->getBody();
BasicBlock *InnerLatch = InnermostLoop->getLatch(); BasicBlock *InnerLatch = InnermostLoop->getLatch();
// Loop control blocks that may become orphaned later.
SmallVector<BasicBlock *, 12> OldControlBBs;
OldControlBBs.reserve(6 * Loops.size());
for (CanonicalLoopInfo *Loop : Loops)
Loop->collectControlBlocks(OldControlBBs);
// Collect original trip counts and induction variable to be accessible by // Collect original trip counts and induction variable to be accessible by
// index. Also, the structure of the original loops is not preserved during // index. Also, the structure of the original loops is not preserved during
// the construction of the tiled loops, so do it before we scavenge the BBs of // the construction of the tiled loops, so do it before we scavenge the BBs of
@ -2074,10 +2079,6 @@ OpenMPIRBuilder::tileLoops(DebugLoc DL, ArrayRef<CanonicalLoopInfo *> Loops,
} }
// Remove unused parts of the original loops. // Remove unused parts of the original loops.
SmallVector<BasicBlock *, 12> OldControlBBs;
OldControlBBs.reserve(6 * Loops.size());
for (CanonicalLoopInfo *Loop : Loops)
Loop->collectControlBlocks(OldControlBBs);
removeUnusedBlocksFromParent(OldControlBBs); removeUnusedBlocksFromParent(OldControlBBs);
for (CanonicalLoopInfo *L : Loops) for (CanonicalLoopInfo *L : Loops)
@ -3321,7 +3322,16 @@ void CanonicalLoopInfo::collectControlBlocks(
// flow. For consistency, this also means we do not add the Body block, which // flow. For consistency, this also means we do not add the Body block, which
// is just the entry to the body code. // is just the entry to the body code.
BBs.reserve(BBs.size() + 6); BBs.reserve(BBs.size() + 6);
BBs.append({Preheader, Header, Cond, Latch, Exit, After}); BBs.append({getPreheader(), Header, Cond, Latch, Exit, getAfter()});
}
BasicBlock *CanonicalLoopInfo::getPreheader() const {
assert(isValid() && "Requires a valid canonical loop");
for (BasicBlock *Pred : predecessors(Header)) {
if (Pred != Latch)
return Pred;
}
llvm_unreachable("Missing preheader");
} }
void CanonicalLoopInfo::assertOK() const { void CanonicalLoopInfo::assertOK() const {
@ -3330,6 +3340,10 @@ void CanonicalLoopInfo::assertOK() const {
if (!isValid()) if (!isValid())
return; return;
BasicBlock *Preheader = getPreheader();
BasicBlock *Body = getBody();
BasicBlock *After = getAfter();
// Verify standard control-flow we use for OpenMP loops. // Verify standard control-flow we use for OpenMP loops.
assert(Preheader); assert(Preheader);
assert(isa<BranchInst>(Preheader->getTerminator()) && assert(isa<BranchInst>(Preheader->getTerminator()) &&
@ -3415,11 +3429,8 @@ void CanonicalLoopInfo::assertOK() const {
} }
void CanonicalLoopInfo::invalidate() { void CanonicalLoopInfo::invalidate() {
Preheader = nullptr;
Header = nullptr; Header = nullptr;
Cond = nullptr; Cond = nullptr;
Body = nullptr;
Latch = nullptr; Latch = nullptr;
Exit = nullptr; Exit = nullptr;
After = nullptr;
} }