diff --git a/llvm/lib/Transforms/Scalar/InductiveRangeCheckElimination.cpp b/llvm/lib/Transforms/Scalar/InductiveRangeCheckElimination.cpp index be7527b3efd1..ef5ea30ffeac 100644 --- a/llvm/lib/Transforms/Scalar/InductiveRangeCheckElimination.cpp +++ b/llvm/lib/Transforms/Scalar/InductiveRangeCheckElimination.cpp @@ -233,6 +233,8 @@ public: SmallVectorImpl &Checks); }; +struct LoopStructure; + class InductiveRangeCheckElimination { ScalarEvolution &SE; BranchProbabilityInfo *BPI; @@ -243,6 +245,10 @@ class InductiveRangeCheckElimination { llvm::Optional >; GetBFIFunc GetBFI; + // Returns true if it is profitable to do a transform basing on estimation of + // number of iterations. + bool isProfitableToTransform(const Loop &L, LoopStructure &LS); + public: InductiveRangeCheckElimination(ScalarEvolution &SE, BranchProbabilityInfo *BPI, DominatorTree &DT, @@ -505,9 +511,8 @@ struct LoopStructure { return Result; } - static Optional parseLoopStructure(ScalarEvolution &, - BranchProbabilityInfo *BPI, - Loop &, const char *&); + static Optional parseLoopStructure(ScalarEvolution &, Loop &, + const char *&); }; /// This class is used to constrain loops to run within a given iteration space. @@ -751,8 +756,7 @@ static bool isSafeIncreasingBound(const SCEV *Start, } Optional -LoopStructure::parseLoopStructure(ScalarEvolution &SE, - BranchProbabilityInfo *BPI, Loop &L, +LoopStructure::parseLoopStructure(ScalarEvolution &SE, Loop &L, const char *&FailureReason) { if (!L.isLoopSimplifyForm()) { FailureReason = "loop not in LoopSimplify form"; @@ -787,16 +791,6 @@ LoopStructure::parseLoopStructure(ScalarEvolution &SE, unsigned LatchBrExitIdx = LatchBr->getSuccessor(0) == Header ? 1 : 0; - BranchProbability ExitProbability = - BPI ? BPI->getEdgeProbability(LatchBr->getParent(), LatchBrExitIdx) - : BranchProbability::getZero(); - - if (!SkipProfitabilityChecks && - ExitProbability > BranchProbability(1, MaxExitProbReciprocal)) { - FailureReason = "short running loop, not profitable"; - return None; - } - ICmpInst *ICI = dyn_cast(LatchBr->getCondition()); if (!ICI || !isa(ICI->getOperand(0)->getType())) { FailureReason = "latch terminator branch not conditional on integral icmp"; @@ -1855,6 +1849,37 @@ bool IRCELegacyPass::runOnFunction(Function &F) { return Changed; } +bool +InductiveRangeCheckElimination::isProfitableToTransform(const Loop &L, + LoopStructure &LS) { + if (SkipProfitabilityChecks) + return true; + if (GetBFI.hasValue()) { + BlockFrequencyInfo &BFI = (*GetBFI)(); + uint64_t hFreq = BFI.getBlockFreq(LS.Header).getFrequency(); + uint64_t phFreq = BFI.getBlockFreq(L.getLoopPreheader()).getFrequency(); + if (phFreq != 0 && hFreq != 0 && (hFreq / phFreq < MinRuntimeIterations)) { + LLVM_DEBUG(dbgs() << "irce: could not prove profitability: " + << "the estimated number of iterations basing on " + "frequency info is " << (hFreq / phFreq) << "\n";); + return false; + } + return true; + } + + if (!BPI) + return true; + BranchProbability ExitProbability = + BPI->getEdgeProbability(LS.Latch, LS.LatchBrExitIdx); + if (ExitProbability > BranchProbability(1, MaxExitProbReciprocal)) { + LLVM_DEBUG(dbgs() << "irce: could not prove profitability: " + << "the exit probability is too big " << ExitProbability + << "\n";); + return false; + } + return true; +} + bool InductiveRangeCheckElimination::run( Loop *L, function_ref LPMAddNewLoop) { if (L->getBlocks().size() >= LoopSizeCutoff) { @@ -1894,25 +1919,15 @@ bool InductiveRangeCheckElimination::run( const char *FailureReason = nullptr; Optional MaybeLoopStructure = - LoopStructure::parseLoopStructure(SE, BPI, *L, FailureReason); + LoopStructure::parseLoopStructure(SE, *L, FailureReason); if (!MaybeLoopStructure.hasValue()) { LLVM_DEBUG(dbgs() << "irce: could not parse loop structure: " << FailureReason << "\n";); return false; } LoopStructure LS = MaybeLoopStructure.getValue(); - // Profitability check. - if (!SkipProfitabilityChecks && GetBFI.hasValue()) { - BlockFrequencyInfo &BFI = (*GetBFI)(); - uint64_t hFreq = BFI.getBlockFreq(LS.Header).getFrequency(); - uint64_t phFreq = BFI.getBlockFreq(Preheader).getFrequency(); - if (phFreq != 0 && hFreq != 0 && (hFreq / phFreq < MinRuntimeIterations)) { - LLVM_DEBUG(dbgs() << "irce: could not prove profitability: " - << "the estimated number of iterations basing on " - "frequency info is " << (hFreq / phFreq) << "\n";); - return false; - } - } + if (!isProfitableToTransform(*L, LS)) + return false; const SCEVAddRecExpr *IndVar = cast(SE.getMinusSCEV(SE.getSCEV(LS.IndVarBase), SE.getSCEV(LS.IndVarStep)));