[LoopUnroll] Clean up exit folding (NFC)

This does some non-functional cleanup of exit folding during
unrolling. The two main changes are:

 * First rewrite latch->header edges, which is unrelated to exit
   folding.
 * Combine folding for latch and non-latch exits. After the
   previous change, the only difference in their logic is that
   for non-latch exits we currently only fold "known non-exit"
   cases, but not "known exit" cases.

I think this helps a lot to clarify this code and prepare it for
future changes.

Differential Revision: https://reviews.llvm.org/D103333
This commit is contained in:
Nikita Popov 2021-05-28 20:14:32 +02:00
parent a41309966a
commit f765445a69
1 changed files with 52 additions and 101 deletions

View File

@ -528,12 +528,6 @@ LoopUnrollResult llvm::UnrollLoop(Loop *L, UnrollLoopOptions ULO, LoopInfo *LI,
if (!LatchIsExiting) if (!LatchIsExiting)
++NumUnrolledNotLatch; ++NumUnrolledNotLatch;
Optional<bool> ContinueOnTrue = None;
BasicBlock *LoopExit = nullptr;
if (ExitingBI) {
ContinueOnTrue = L->contains(ExitingBI->getSuccessor(0));
LoopExit = ExitingBI->getSuccessor(*ContinueOnTrue);
}
// For the first iteration of the loop, we should use the precloned values for // For the first iteration of the loop, we should use the precloned values for
// PHI nodes. Insert associations now. // PHI nodes. Insert associations now.
@ -545,14 +539,11 @@ LoopUnrollResult llvm::UnrollLoop(Loop *L, UnrollLoopOptions ULO, LoopInfo *LI,
std::vector<BasicBlock *> Headers; std::vector<BasicBlock *> Headers;
std::vector<BasicBlock *> ExitingBlocks; std::vector<BasicBlock *> ExitingBlocks;
std::vector<BasicBlock *> ExitingSucc;
std::vector<BasicBlock *> Latches; std::vector<BasicBlock *> Latches;
Headers.push_back(Header); Headers.push_back(Header);
Latches.push_back(LatchBlock); Latches.push_back(LatchBlock);
if (ExitingBI) { if (ExitingBI)
ExitingBlocks.push_back(ExitingBI->getParent()); ExitingBlocks.push_back(ExitingBI->getParent());
ExitingSucc.push_back(ExitingBI->getSuccessor(!(*ContinueOnTrue)));
}
// The current on-the-fly SSA update requires blocks to be processed in // The current on-the-fly SSA update requires blocks to be processed in
// reverse postorder so that LastValueMap contains the correct value at each // reverse postorder so that LastValueMap contains the correct value at each
@ -652,12 +643,9 @@ LoopUnrollResult llvm::UnrollLoop(Loop *L, UnrollLoopOptions ULO, LoopInfo *LI,
// Keep track of the exiting block and its successor block contained in // Keep track of the exiting block and its successor block contained in
// the loop for the current iteration. // the loop for the current iteration.
if (ExitingBI) { if (ExitingBI)
if (*BB == ExitingBlocks[0]) if (*BB == ExitingBlocks[0])
ExitingBlocks.push_back(New); ExitingBlocks.push_back(New);
if (*BB == ExitingSucc[0])
ExitingSucc.push_back(New);
}
NewBlocks.push_back(New); NewBlocks.push_back(New);
UnrolledLoopBlocks.push_back(New); UnrolledLoopBlocks.push_back(New);
@ -714,111 +702,74 @@ LoopUnrollResult llvm::UnrollLoop(Loop *L, UnrollLoopOptions ULO, LoopInfo *LI,
} }
} }
auto setDest = [](BasicBlock *Src, BasicBlock *Dest, BasicBlock *BlockInLoop, // Connect latches of the unrolled iterations to the headers of the next
bool NeedConditional, Optional<bool> ContinueOnTrue, // iteration. Currently they point to the header of the same iteration.
bool IsDestLoopExit) { for (unsigned i = 0, e = Latches.size(); i != e; ++i) {
auto *Term = cast<BranchInst>(Src->getTerminator()); unsigned j = (i + 1) % e;
if (NeedConditional) { Latches[i]->getTerminator()->replaceSuccessorWith(Headers[i], Headers[j]);
// Update the conditional branch's successor for the following }
// iteration.
assert(ContinueOnTrue.hasValue() && if (ExitingBI) {
"Expecting valid ContinueOnTrue when NeedConditional is true"); auto SetDest = [](BasicBlock *Src, bool WillExit, bool ExitOnTrue) {
Term->setSuccessor(!(*ContinueOnTrue), Dest); auto *Term = cast<BranchInst>(Src->getTerminator());
} else { BasicBlock *Dest = Term->getSuccessor(ExitOnTrue ^ WillExit);
// Remove phi operands at this loop exit
if (!IsDestLoopExit) { // Remove predecessors from all non-Dest successors.
BasicBlock *BB = Src; for (BasicBlock *Succ : successors(Src)) {
for (BasicBlock *Succ : successors(BB)) { if (Succ == Dest)
// Preserve the incoming value from BB if we are jumping to the block continue;
// in the current loop. Succ->removePredecessor(Src, /* KeepOneInputPHIs */ true);
if (Succ == BlockInLoop)
continue;
for (PHINode &Phi : Succ->phis())
Phi.removeIncomingValue(BB, false);
}
} }
// Replace the conditional branch with an unconditional one. // Replace the conditional branch with an unconditional one.
BranchInst::Create(Dest, Term); BranchInst::Create(Dest, Term);
Term->eraseFromParent(); Term->eraseFromParent();
} };
};
// Connect latches of the unrolled iterations to the headers of the next auto WillExit = [&](unsigned i, unsigned j) -> Optional<bool> {
// iteration. If the latch is also the exiting block, the conditional branch
// may have to be preserved.
for (unsigned i = 0, e = Latches.size(); i != e; ++i) {
// The branch destination.
unsigned j = (i + 1) % e;
BasicBlock *Dest = Headers[j];
bool NeedConditional = LatchIsExiting;
if (LatchIsExiting) {
if (RuntimeTripCount && j != 0)
NeedConditional = false;
// For a complete unroll, make the last iteration end with a branch
// to the exit block.
if (CompletelyUnroll) { if (CompletelyUnroll) {
if (j == 0) if (ULO.PreserveCondBr && j && !(ULO.PreserveOnlyFirst && i != 0))
Dest = LoopExit; return None;
// If using trip count upper bound to completely unroll, we need to return j == 0;
// keep the conditional branch except the last one because the loop }
// may exit after any iteration.
assert(NeedConditional && if (RuntimeTripCount && j != 0)
"NeedCondition cannot be modified by both complete " return false;
"unrolling and runtime unrolling");
NeedConditional = if (j != BreakoutTrip &&
(ULO.PreserveCondBr && j && !(ULO.PreserveOnlyFirst && i != 0)); (ULO.TripMultiple == 0 || j % ULO.TripMultiple != 0)) {
} else if (j != BreakoutTrip &&
(ULO.TripMultiple == 0 || j % ULO.TripMultiple != 0)) {
// If we know the trip count or a multiple of it, we can safely use an // If we know the trip count or a multiple of it, we can safely use an
// unconditional branch for some iterations. // unconditional branch for some iterations.
NeedConditional = false; return false;
} }
} return None;
};
setDest(Latches[i], Dest, Headers[i], NeedConditional, ContinueOnTrue, // Fold branches for iterations where we know that they will exit or not
Dest == LoopExit); // exit.
} bool ExitOnTrue = !L->contains(ExitingBI->getSuccessor(0));
if (!LatchIsExiting) {
// If the latch is not exiting, we may be able to simplify the conditional
// branches in the unrolled exiting blocks.
for (unsigned i = 0, e = ExitingBlocks.size(); i != e; ++i) { for (unsigned i = 0, e = ExitingBlocks.size(); i != e; ++i) {
// The branch destination. // The branch destination.
unsigned j = (i + 1) % e; unsigned j = (i + 1) % e;
bool NeedConditional = true; Optional<bool> KnownWillExit = WillExit(i, j);
if (!KnownWillExit)
if (RuntimeTripCount && j != 0)
NeedConditional = false;
if (CompletelyUnroll)
// We cannot drop the conditional branch for the last condition, as we
// may have to execute the loop body depending on the condition.
NeedConditional = j == 0 || ULO.PreserveCondBr;
else if (j != BreakoutTrip &&
(ULO.TripMultiple == 0 || j % ULO.TripMultiple != 0))
// If we know the trip count or a multiple of it, we can safely use an
// unconditional branch for some iterations.
NeedConditional = false;
// Conditional branches from non-latch exiting block have successors
// either in the same loop iteration or outside the loop. The branches are
// already correct.
if (NeedConditional)
continue; continue;
setDest(ExitingBlocks[i], ExitingSucc[i], ExitingSucc[i], NeedConditional,
None, false);
}
// When completely unrolling, the last latch becomes unreachable. // TODO: Also fold known-exiting branches for non-latch exits.
if (CompletelyUnroll) { if (*KnownWillExit && !LatchIsExiting)
BranchInst *Term = cast<BranchInst>(Latches.back()->getTerminator()); continue;
new UnreachableInst(Term->getContext(), Term);
Term->eraseFromParent(); SetDest(ExitingBlocks[i], *KnownWillExit, ExitOnTrue);
} }
} }
// When completely unrolling, the last latch becomes unreachable.
if (!LatchIsExiting && CompletelyUnroll) {
BranchInst *Term = cast<BranchInst>(Latches.back()->getTerminator());
new UnreachableInst(Term->getContext(), Term);
Term->eraseFromParent();
}
// Update dominators of blocks we might reach through exits. // Update dominators of blocks we might reach through exits.
// Immediate dominator of such block might change, because we add more // Immediate dominator of such block might change, because we add more
// routes which can lead to the exit: we can now reach it from the copied // routes which can lead to the exit: we can now reach it from the copied