diff --git a/llvm/lib/Transforms/Scalar/LoopFlatten.cpp b/llvm/lib/Transforms/Scalar/LoopFlatten.cpp index 4d844093ff13..65ea4734e452 100644 --- a/llvm/lib/Transforms/Scalar/LoopFlatten.cpp +++ b/llvm/lib/Transforms/Scalar/LoopFlatten.cpp @@ -62,6 +62,23 @@ static cl::opt cl::desc("Assume that the product of the two iteration " "limits will never overflow")); +struct FlattenInfo { + Loop *OuterLoop = nullptr; + Loop *InnerLoop = nullptr; + PHINode *InnerInductionPHI = nullptr; + PHINode *OuterInductionPHI = nullptr; + Value *InnerLimit = nullptr; + Value *OuterLimit = nullptr; + BinaryOperator *InnerIncrement = nullptr; + BinaryOperator *OuterIncrement = nullptr; + BranchInst *InnerBranch = nullptr; + BranchInst *OuterBranch = nullptr; + SmallPtrSet LinearIVUses; + SmallPtrSet InnerPHIsToTransform; + + FlattenInfo(Loop *OL, Loop *IL) : OuterLoop(OL), InnerLoop(IL) {}; +}; + // Finds the induction variable, increment and limit for a simple loop that we // can flatten. static bool findLoopComponents( @@ -161,10 +178,8 @@ static bool findLoopComponents( return true; } -static bool checkPHIs(Loop *OuterLoop, Loop *InnerLoop, - SmallPtrSetImpl &InnerPHIsToTransform, - PHINode *InnerInductionPHI, PHINode *OuterInductionPHI, - TargetTransformInfo *TTI) { +static bool checkPHIs(struct FlattenInfo &FI, + const TargetTransformInfo *TTI) { // All PHIs in the inner and outer headers must either be: // - The induction PHI, which we are going to rewrite as one induction in // the new loop. This is already checked by findLoopComponents. @@ -180,29 +195,29 @@ static bool checkPHIs(Loop *OuterLoop, Loop *InnerLoop, // the exception of the induction variable), but we do need to check that // there are no unsafe PHI nodes. SmallPtrSet SafeOuterPHIs; - SafeOuterPHIs.insert(OuterInductionPHI); + SafeOuterPHIs.insert(FI.OuterInductionPHI); // Check that all PHI nodes in the inner loop header match one of the valid // patterns. - for (PHINode &InnerPHI : InnerLoop->getHeader()->phis()) { + for (PHINode &InnerPHI : FI.InnerLoop->getHeader()->phis()) { // The induction PHIs break these rules, and that's OK because we treat // them specially when doing the transformation. - if (&InnerPHI == InnerInductionPHI) + if (&InnerPHI == FI.InnerInductionPHI) continue; // Each inner loop PHI node must have two incoming values/blocks - one // from the pre-header, and one from the latch. assert(InnerPHI.getNumIncomingValues() == 2); Value *PreHeaderValue = - InnerPHI.getIncomingValueForBlock(InnerLoop->getLoopPreheader()); + InnerPHI.getIncomingValueForBlock(FI.InnerLoop->getLoopPreheader()); Value *LatchValue = - InnerPHI.getIncomingValueForBlock(InnerLoop->getLoopLatch()); + InnerPHI.getIncomingValueForBlock(FI.InnerLoop->getLoopLatch()); // The incoming value from the outer loop must be the PHI node in the // outer loop header, with no modifications made in the top of the outer // loop. PHINode *OuterPHI = dyn_cast(PreHeaderValue); - if (!OuterPHI || OuterPHI->getParent() != OuterLoop->getHeader()) { + if (!OuterPHI || OuterPHI->getParent() != FI.OuterLoop->getHeader()) { LLVM_DEBUG(dbgs() << "value modified in top of outer loop\n"); return false; } @@ -212,7 +227,7 @@ static bool checkPHIs(Loop *OuterLoop, Loop *InnerLoop, // so this will actually be a PHI in the inner loop's exit block, which // only uses values from inside the inner loop. PHINode *LCSSAPHI = dyn_cast( - OuterPHI->getIncomingValueForBlock(OuterLoop->getLoopLatch())); + OuterPHI->getIncomingValueForBlock(FI.OuterLoop->getLoopLatch())); if (!LCSSAPHI) { LLVM_DEBUG(dbgs() << "could not find LCSSA PHI\n"); return false; @@ -230,10 +245,10 @@ static bool checkPHIs(Loop *OuterLoop, Loop *InnerLoop, LLVM_DEBUG(dbgs() << " Inner: "; InnerPHI.dump()); LLVM_DEBUG(dbgs() << " Outer: "; OuterPHI->dump()); SafeOuterPHIs.insert(OuterPHI); - InnerPHIsToTransform.insert(&InnerPHI); + FI.InnerPHIsToTransform.insert(&InnerPHI); } - for (PHINode &OuterPHI : OuterLoop->getHeader()->phis()) { + for (PHINode &OuterPHI : FI.OuterLoop->getHeader()->phis()) { if (!SafeOuterPHIs.count(&OuterPHI)) { LLVM_DEBUG(dbgs() << "found unsafe PHI in outer loop: "; OuterPHI.dump()); return false; @@ -244,18 +259,17 @@ static bool checkPHIs(Loop *OuterLoop, Loop *InnerLoop, } static bool -checkOuterLoopInsts(Loop *OuterLoop, Loop *InnerLoop, +checkOuterLoopInsts(struct FlattenInfo &FI, SmallPtrSetImpl &IterationInstructions, - Value *InnerLimit, PHINode *OuterPHI, - TargetTransformInfo *TTI) { + const TargetTransformInfo *TTI) { // Check for instructions in the outer but not inner loop. If any of these // have side-effects then this transformation is not legal, and if there is // a significant amount of code here which can't be optimised out that it's // not profitable (as these instructions would get executed for each // iteration of the inner loop). unsigned RepeatedInstrCost = 0; - for (auto *B : OuterLoop->getBlocks()) { - if (InnerLoop->contains(B)) + for (auto *B : FI.OuterLoop->getBlocks()) { + if (FI.InnerLoop->contains(B)) continue; for (auto &I : *B) { @@ -276,11 +290,12 @@ checkOuterLoopInsts(Loop *OuterLoop, Loop *InnerLoop, // a fall-through, so adds no cost. BranchInst *Br = dyn_cast(&I); if (Br && Br->isUnconditional() && - Br->getSuccessor(0) == InnerLoop->getHeader()) + Br->getSuccessor(0) == FI.InnerLoop->getHeader()) continue; // Multiplies of the outer iteration variable and inner iteration // count will be optimised out. - if (match(&I, m_c_Mul(m_Specific(OuterPHI), m_Specific(InnerLimit)))) + if (match(&I, m_c_Mul(m_Specific(FI.OuterInductionPHI), + m_Specific(FI.InnerLimit)))) continue; int Cost = TTI->getUserCost(&I, TargetTransformInfo::TCK_SizeAndLatency); LLVM_DEBUG(dbgs() << "Cost " << Cost << ": "; I.dump()); @@ -298,10 +313,7 @@ checkOuterLoopInsts(Loop *OuterLoop, Loop *InnerLoop, return true; } -static bool checkIVUsers(PHINode *InnerPHI, PHINode *OuterPHI, - BinaryOperator *InnerIncrement, - BinaryOperator *OuterIncrement, Value *InnerLimit, - SmallPtrSetImpl &LinearIVUses) { +static bool checkIVUsers(struct FlattenInfo &FI) { // We require all uses of both induction variables to match this pattern: // // (OuterPHI * InnerLimit) + InnerPHI @@ -313,20 +325,22 @@ static bool checkIVUsers(PHINode *InnerPHI, PHINode *OuterPHI, // Check that all uses of the inner loop's induction variable match the // expected pattern, recording the uses of the outer IV. SmallPtrSet ValidOuterPHIUses; - for (User *U : InnerPHI->users()) { - if (U == InnerIncrement) + for (User *U : FI.InnerInductionPHI->users()) { + if (U == FI.InnerIncrement) continue; LLVM_DEBUG(dbgs() << "Found use of inner induction variable: "; U->dump()); Value *MatchedMul, *MatchedItCount; - if (match(U, m_c_Add(m_Specific(InnerPHI), m_Value(MatchedMul))) && + if (match(U, m_c_Add(m_Specific(FI.InnerInductionPHI), + m_Value(MatchedMul))) && match(MatchedMul, - m_c_Mul(m_Specific(OuterPHI), m_Value(MatchedItCount))) && - MatchedItCount == InnerLimit) { + m_c_Mul(m_Specific(FI.OuterInductionPHI), + m_Value(MatchedItCount))) && + MatchedItCount == FI.InnerLimit) { LLVM_DEBUG(dbgs() << "Use is optimisable\n"); ValidOuterPHIUses.insert(MatchedMul); - LinearIVUses.insert(U); + FI.LinearIVUses.insert(U); } else { LLVM_DEBUG(dbgs() << "Did not match expected pattern, bailing\n"); return false; @@ -335,8 +349,8 @@ static bool checkIVUsers(PHINode *InnerPHI, PHINode *OuterPHI, // Check that there are no uses of the outer IV other than the ones found // as part of the pattern above. - for (User *U : OuterPHI->users()) { - if (U == OuterIncrement) + for (User *U : FI.OuterInductionPHI->users()) { + if (U == FI.OuterIncrement) continue; LLVM_DEBUG(dbgs() << "Found use of outer induction variable: "; U->dump()); @@ -349,9 +363,9 @@ static bool checkIVUsers(PHINode *InnerPHI, PHINode *OuterPHI, } } - LLVM_DEBUG(dbgs() << "Found " << LinearIVUses.size() + LLVM_DEBUG(dbgs() << "Found " << FI.LinearIVUses.size() << " value(s) that can be replaced:\n"; - for (Value *V : LinearIVUses) { + for (Value *V : FI.LinearIVUses) { dbgs() << " "; V->dump(); }); @@ -361,11 +375,9 @@ static bool checkIVUsers(PHINode *InnerPHI, PHINode *OuterPHI, // Return an OverflowResult dependant on if overflow of the multiplication of // InnerLimit and OuterLimit can be assumed not to happen. -static OverflowResult checkOverflow(Loop *OuterLoop, Value *InnerLimit, - Value *OuterLimit, - SmallPtrSetImpl &LinearIVUses, +static OverflowResult checkOverflow(struct FlattenInfo &FI, DominatorTree *DT, AssumptionCache *AC) { - Function *F = OuterLoop->getHeader()->getParent(); + Function *F = FI.OuterLoop->getHeader()->getParent(); const DataLayout &DL = F->getParent()->getDataLayout(); // For debugging/testing. @@ -375,12 +387,12 @@ static OverflowResult checkOverflow(Loop *OuterLoop, Value *InnerLimit, // Check if the multiply could not overflow due to known ranges of the // input values. OverflowResult OR = computeOverflowForUnsignedMul( - InnerLimit, OuterLimit, DL, AC, - OuterLoop->getLoopPreheader()->getTerminator(), DT); + FI.InnerLimit, FI.OuterLimit, DL, AC, + FI.OuterLoop->getLoopPreheader()->getTerminator(), DT); if (OR != OverflowResult::MayOverflow) return OR; - for (Value *V : LinearIVUses) { + for (Value *V : FI.LinearIVUses) { for (Value *U : V->users()) { if (auto *GEP = dyn_cast(U)) { // The IV is used as the operand of a GEP, and the IV is at least as @@ -402,53 +414,45 @@ static OverflowResult checkOverflow(Loop *OuterLoop, Value *InnerLimit, return OverflowResult::MayOverflow; } -static bool FlattenLoopPair(Loop *OuterLoop, Loop *InnerLoop, DominatorTree *DT, +static bool FlattenLoopPair(struct FlattenInfo &FI, DominatorTree *DT, LoopInfo *LI, ScalarEvolution *SE, - AssumptionCache *AC, TargetTransformInfo *TTI, + AssumptionCache *AC, const TargetTransformInfo *TTI, std::function markLoopAsDeleted) { - Function *F = OuterLoop->getHeader()->getParent(); + Function *F = FI.OuterLoop->getHeader()->getParent(); LLVM_DEBUG(dbgs() << "Loop flattening running on outer loop " - << OuterLoop->getHeader()->getName() << " and inner loop " - << InnerLoop->getHeader()->getName() << " in " + << FI.OuterLoop->getHeader()->getName() << " and inner loop " + << FI.InnerLoop->getHeader()->getName() << " in " << F->getName() << "\n"); SmallPtrSet IterationInstructions; - PHINode *InnerInductionPHI, *OuterInductionPHI; - Value *InnerLimit, *OuterLimit; - BinaryOperator *InnerIncrement, *OuterIncrement; - BranchInst *InnerBranch, *OuterBranch; - - if (!findLoopComponents(InnerLoop, IterationInstructions, InnerInductionPHI, - InnerLimit, InnerIncrement, InnerBranch, SE)) + if (!findLoopComponents(FI.InnerLoop, IterationInstructions, FI.InnerInductionPHI, + FI.InnerLimit, FI.InnerIncrement, FI.InnerBranch, SE)) return false; - if (!findLoopComponents(OuterLoop, IterationInstructions, OuterInductionPHI, - OuterLimit, OuterIncrement, OuterBranch, SE)) + if (!findLoopComponents(FI.OuterLoop, IterationInstructions, FI.OuterInductionPHI, + FI.OuterLimit, FI.OuterIncrement, FI.OuterBranch, SE)) return false; // Both of the loop limit values must be invariant in the outer loop // (non-instructions are all inherently invariant). - if (!OuterLoop->isLoopInvariant(InnerLimit)) { + if (!FI.OuterLoop->isLoopInvariant(FI.InnerLimit)) { LLVM_DEBUG(dbgs() << "inner loop limit not invariant\n"); return false; } - if (!OuterLoop->isLoopInvariant(OuterLimit)) { + if (!FI.OuterLoop->isLoopInvariant(FI.OuterLimit)) { LLVM_DEBUG(dbgs() << "outer loop limit not invariant\n"); return false; } - SmallPtrSet InnerPHIsToTransform; - if (!checkPHIs(OuterLoop, InnerLoop, InnerPHIsToTransform, InnerInductionPHI, - OuterInductionPHI, TTI)) + if (!checkPHIs(FI, TTI)) return false; // FIXME: it should be possible to handle different types correctly. - if (InnerInductionPHI->getType() != OuterInductionPHI->getType()) + if (FI.InnerInductionPHI->getType() != FI.OuterInductionPHI->getType()) return false; - if (!checkOuterLoopInsts(OuterLoop, InnerLoop, IterationInstructions, - InnerLimit, OuterInductionPHI, TTI)) + if (!checkOuterLoopInsts(FI, IterationInstructions, TTI)) return false; // Find the values in the loop that can be replaced with the linearized @@ -456,9 +460,7 @@ static bool FlattenLoopPair(Loop *OuterLoop, Loop *InnerLoop, DominatorTree *DT, // or outer induction variable. If there were, we could still do this // transformation, but we'd have to insert a div/mod to calculate the // original IVs, so it wouldn't be profitable. - SmallPtrSet LinearIVUses; - if (!checkIVUsers(InnerInductionPHI, OuterInductionPHI, InnerIncrement, - OuterIncrement, InnerLimit, LinearIVUses)) + if (!checkIVUsers(FI)) return false; // Check if the new iteration variable might overflow. In this case, we @@ -468,8 +470,7 @@ static bool FlattenLoopPair(Loop *OuterLoop, Loop *InnerLoop, DominatorTree *DT, // TODO: it might be worth using a wider iteration variable rather than // versioning the loop, if a wide enough type is legal. bool MustVersionLoop = true; - OverflowResult OR = - checkOverflow(OuterLoop, InnerLimit, OuterLimit, LinearIVUses, DT, AC); + OverflowResult OR = checkOverflow(FI, DT, AC); if (OR == OverflowResult::AlwaysOverflowsHigh || OR == OverflowResult::AlwaysOverflowsLow) { LLVM_DEBUG(dbgs() << "Multiply would always overflow, so not profitable\n"); @@ -490,47 +491,47 @@ static bool FlattenLoopPair(Loop *OuterLoop, Loop *InnerLoop, DominatorTree *DT, { using namespace ore; - OptimizationRemark Remark(DEBUG_TYPE, "Flattened", InnerLoop->getStartLoc(), - InnerLoop->getHeader()); + OptimizationRemark Remark(DEBUG_TYPE, "Flattened", FI.InnerLoop->getStartLoc(), + FI.InnerLoop->getHeader()); OptimizationRemarkEmitter ORE(F); Remark << "Flattened into outer loop"; ORE.emit(Remark); } Value *NewTripCount = - BinaryOperator::CreateMul(InnerLimit, OuterLimit, "flatten.tripcount", - OuterLoop->getLoopPreheader()->getTerminator()); + BinaryOperator::CreateMul(FI.InnerLimit, FI.OuterLimit, "flatten.tripcount", + FI.OuterLoop->getLoopPreheader()->getTerminator()); LLVM_DEBUG(dbgs() << "Created new trip count in preheader: "; NewTripCount->dump()); // Fix up PHI nodes that take values from the inner loop back-edge, which // we are about to remove. - InnerInductionPHI->removeIncomingValue(InnerLoop->getLoopLatch()); - for (PHINode *PHI : InnerPHIsToTransform) - PHI->removeIncomingValue(InnerLoop->getLoopLatch()); + FI.InnerInductionPHI->removeIncomingValue(FI.InnerLoop->getLoopLatch()); + for (PHINode *PHI : FI.InnerPHIsToTransform) + PHI->removeIncomingValue(FI.InnerLoop->getLoopLatch()); // Modify the trip count of the outer loop to be the product of the two // trip counts. - cast(OuterBranch->getCondition())->setOperand(1, NewTripCount); + cast(FI.OuterBranch->getCondition())->setOperand(1, NewTripCount); // Replace the inner loop backedge with an unconditional branch to the exit. - BasicBlock *InnerExitBlock = InnerLoop->getExitBlock(); - BasicBlock *InnerExitingBlock = InnerLoop->getExitingBlock(); + BasicBlock *InnerExitBlock = FI.InnerLoop->getExitBlock(); + BasicBlock *InnerExitingBlock = FI.InnerLoop->getExitingBlock(); InnerExitingBlock->getTerminator()->eraseFromParent(); BranchInst::Create(InnerExitBlock, InnerExitingBlock); - DT->deleteEdge(InnerExitingBlock, InnerLoop->getHeader()); + DT->deleteEdge(InnerExitingBlock, FI.InnerLoop->getHeader()); // Replace all uses of the polynomial calculated from the two induction // variables with the one new one. - for (Value *V : LinearIVUses) - V->replaceAllUsesWith(OuterInductionPHI); + for (Value *V : FI.LinearIVUses) + V->replaceAllUsesWith(FI.OuterInductionPHI); // Tell LoopInfo, SCEV and the pass manager that the inner loop has been // deleted, and any information that have about the outer loop invalidated. - markLoopAsDeleted(InnerLoop); - SE->forgetLoop(OuterLoop); - SE->forgetLoop(InnerLoop); - LI->erase(InnerLoop); + markLoopAsDeleted(FI.InnerLoop); + SE->forgetLoop(FI.OuterLoop); + SE->forgetLoop(FI.InnerLoop); + LI->erase(FI.InnerLoop); return true; } @@ -543,8 +544,9 @@ PreservedAnalyses LoopFlattenPass::run(Loop &L, LoopAnalysisManager &AM, Loop *InnerLoop = *L.begin(); std::string LoopName(InnerLoop->getName()); + struct FlattenInfo FI(InnerLoop->getParentLoop(), InnerLoop); if (!FlattenLoopPair( - &L, InnerLoop, &AR.DT, &AR.LI, &AR.SE, &AR.AC, &AR.TTI, + FI, &AR.DT, &AR.LI, &AR.SE, &AR.AC, &AR.TTI, [&](Loop *L) { Updater.markLoopAsDeleted(*L, LoopName); })) return PreservedAnalyses::all(); return getLoopPassPreservedAnalyses(); @@ -600,6 +602,7 @@ bool LoopFlattenLegacyPass::runOnLoop(Loop *L, LPPassManager &LPM) { *L->getHeader()->getParent()); Loop *InnerLoop = *L->begin(); - return FlattenLoopPair(L, InnerLoop, DT, LI, SE, AC, TTI, + struct FlattenInfo FI(InnerLoop->getParentLoop(), InnerLoop); + return FlattenLoopPair(FI, DT, LI, SE, AC, TTI, [&](Loop *L) { LPM.markLoopAsDeleted(*L); }); }