diff --git a/llvm/include/llvm/Analysis/ScalarEvolution.h b/llvm/include/llvm/Analysis/ScalarEvolution.h index 73faa0afcdc8..df50611832ce 100644 --- a/llvm/include/llvm/Analysis/ScalarEvolution.h +++ b/llvm/include/llvm/Analysis/ScalarEvolution.h @@ -1378,6 +1378,8 @@ private: /// includes an exact count and a maximum count. /// class BackedgeTakenInfo { + friend class ScalarEvolution; + /// A list of computable exits and their not-taken counts. Loops almost /// never have more than one computable exit. SmallVector ExitNotTaken; @@ -1398,9 +1400,6 @@ private: /// True iff the backedge is taken either exactly Max or zero times. bool MaxOrZero = false; - /// SCEV expressions used in any of the ExitNotTakenInfo counts. - SmallPtrSet Operands; - bool isComplete() const { return IsComplete; } const SCEV *getConstantMax() const { return ConstantMax; } @@ -1466,10 +1465,6 @@ private: /// Return true if the number of times this backedge is taken is either the /// value returned by getConstantMax or zero. bool isConstantMaxOrZero(ScalarEvolution *SE) const; - - /// Return true if any backedge taken count expressions refer to the given - /// subexpression. - bool hasOperand(const SCEV *S) const; }; /// Cache the backedge-taken count of the loops for this function as they @@ -1480,6 +1475,10 @@ private: /// function as they are computed. DenseMap PredicatedBackedgeTakenCounts; + /// Loops whose backedge taken counts directly use this non-constant SCEV. + DenseMap, 4>> + BECountUsers; + /// This map contains entries for all of the PHI instructions that we /// attempt to compute constant evolutions for. This allows us to avoid /// potentially expensive recomputation of these properties. An instruction @@ -1911,6 +1910,9 @@ private: bool splitBinaryAdd(const SCEV *Expr, const SCEV *&L, const SCEV *&R, SCEV::NoWrapFlags &Flags); + /// Forget predicated/non-predicated backedge taken counts for the given loop. + void forgetBackedgeTakenCounts(const Loop *L, bool Predicated); + /// Drop memoized information for all \p SCEVs. void forgetMemoizedResults(ArrayRef SCEVs); diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp index ece19099d65e..7dc7f9904c70 100644 --- a/llvm/lib/Analysis/ScalarEvolution.cpp +++ b/llvm/lib/Analysis/ScalarEvolution.cpp @@ -7603,6 +7603,7 @@ void ScalarEvolution::forgetAllLoops() { // result. BackedgeTakenCounts.clear(); PredicatedBackedgeTakenCounts.clear(); + BECountUsers.clear(); LoopPropertiesCache.clear(); ConstantEvolutionLoopExitValue.clear(); ValueExprMap.clear(); @@ -7629,8 +7630,8 @@ void ScalarEvolution::forgetLoop(const Loop *L) { auto *CurrL = LoopWorklist.pop_back_val(); // Drop any stored trip count value. - BackedgeTakenCounts.erase(CurrL); - PredicatedBackedgeTakenCounts.erase(CurrL); + forgetBackedgeTakenCounts(CurrL, /* Predicated */ false); + forgetBackedgeTakenCounts(CurrL, /* Predicated */ true); // Drop information about predicated SCEV rewrites for this loop. for (auto I = PredicatedSCEVRewrites.begin(); @@ -7804,10 +7805,6 @@ bool ScalarEvolution::BackedgeTakenInfo::isConstantMaxOrZero( return MaxOrZero && !any_of(ExitNotTaken, PredicateNotAlwaysTrue); } -bool ScalarEvolution::BackedgeTakenInfo::hasOperand(const SCEV *S) const { - return Operands.contains(S); -} - ScalarEvolution::ExitLimit::ExitLimit(const SCEV *E) : ExitLimit(E, E, false, None) { } @@ -7848,19 +7845,6 @@ ScalarEvolution::ExitLimit::ExitLimit(const SCEV *E, const SCEV *M, : ExitLimit(E, M, MaxOrZero, None) { } -class SCEVRecordOperands { - SmallPtrSetImpl &Operands; - -public: - SCEVRecordOperands(SmallPtrSetImpl &Operands) - : Operands(Operands) {} - bool follow(const SCEV *S) { - Operands.insert(S); - return true; - } - bool isDone() { return false; } -}; - /// Allocate memory for BackedgeTakenInfo and copy the not-taken count of each /// computable exit into a persistent ExitNotTakenInfo array. ScalarEvolution::BackedgeTakenInfo::BackedgeTakenInfo( @@ -7889,14 +7873,6 @@ ScalarEvolution::BackedgeTakenInfo::BackedgeTakenInfo( assert((isa(ConstantMax) || isa(ConstantMax)) && "No point in having a non-constant max backedge taken count!"); - - SCEVRecordOperands RecordOperands(Operands); - SCEVTraversal ST(RecordOperands); - if (!isa(ConstantMax)) - ST.visitAll(ConstantMax); - for (auto &ENT : ExitNotTaken) - if (!isa(ENT.ExactNotTaken)) - ST.visitAll(ENT.ExactNotTaken); } /// Compute the number of times the backedge of the specified loop will execute. @@ -7978,6 +7954,13 @@ ScalarEvolution::computeBackedgeTakenCount(const Loop *L, // The loop backedge will be taken the maximum or zero times if there's // a single exit that must be taken the maximum or zero times. bool MaxOrZero = (MustExitMaxOrZero && ExitingBlocks.size() == 1); + + // Remember which SCEVs are used in exit limits for invalidation purposes. + // We only care about non-constant SCEVs here, so we can ignore EL.MaxNotTaken + // and MaxBECount, which must be SCEVConstant. + for (const auto &Pair : ExitCounts) + if (!isa(Pair.second.ExactNotTaken)) + BECountUsers[Pair.second.ExactNotTaken].insert({L, AllowPredicates}); return BackedgeTakenInfo(std::move(ExitCounts), CouldComputeBECount, MaxBECount, MaxOrZero); } @@ -12466,6 +12449,7 @@ ScalarEvolution::ScalarEvolution(ScalarEvolution &&Arg) BackedgeTakenCounts(std::move(Arg.BackedgeTakenCounts)), PredicatedBackedgeTakenCounts( std::move(Arg.PredicatedBackedgeTakenCounts)), + BECountUsers(std::move(Arg.BECountUsers)), ConstantEvolutionLoopExitValue( std::move(Arg.ConstantEvolutionLoopExitValue)), ValuesAtScopes(std::move(Arg.ValuesAtScopes)), @@ -12882,6 +12866,23 @@ bool ScalarEvolution::hasOperand(const SCEV *S, const SCEV *Op) const { return SCEVExprContains(S, [&](const SCEV *Expr) { return Expr == Op; }); } +void ScalarEvolution::forgetBackedgeTakenCounts(const Loop *L, + bool Predicated) { + auto &BECounts = + Predicated ? PredicatedBackedgeTakenCounts : BackedgeTakenCounts; + auto It = BECounts.find(L); + if (It != BECounts.end()) { + for (const ExitNotTakenInfo &ENT : It->second.ExitNotTaken) { + if (!isa(ENT.ExactNotTaken)) { + auto UserIt = BECountUsers.find(ENT.ExactNotTaken); + assert(UserIt != BECountUsers.end()); + UserIt->second.erase({L, Predicated}); + } + } + BECounts.erase(It); + } +} + void ScalarEvolution::forgetMemoizedResults(ArrayRef SCEVs) { SmallPtrSet ToForget(SCEVs.begin(), SCEVs.end()); SmallVector Worklist(ToForget.begin(), ToForget.end()); @@ -12906,21 +12907,6 @@ void ScalarEvolution::forgetMemoizedResults(ArrayRef SCEVs) { else ++I; } - - auto RemoveSCEVFromBackedgeMap = [&ToForget]( - DenseMap &Map) { - for (auto I = Map.begin(), E = Map.end(); I != E;) { - BackedgeTakenInfo &BEInfo = I->second; - if (any_of(ToForget, - [&BEInfo](const SCEV *S) { return BEInfo.hasOperand(S); })) - Map.erase(I++); - else - ++I; - } - }; - - RemoveSCEVFromBackedgeMap(BackedgeTakenCounts); - RemoveSCEVFromBackedgeMap(PredicatedBackedgeTakenCounts); } void ScalarEvolution::forgetMemoizedResultsImpl(const SCEV *S) { @@ -12958,6 +12944,15 @@ void ScalarEvolution::forgetMemoizedResultsImpl(const SCEV *S) { erase_value(ValuesAtScopes[Pair.second], std::make_pair(Pair.first, S)); ValuesAtScopesUsers.erase(ScopeUserIt); } + + auto BEUsersIt = BECountUsers.find(S); + if (BEUsersIt != BECountUsers.end()) { + // Work on a copy, as forgetBackedgeTakenCounts() will modify the original. + auto Copy = BEUsersIt->second; + for (const auto &Pair : Copy) + forgetBackedgeTakenCounts(Pair.getPointer(), Pair.getInt()); + BECountUsers.erase(BEUsersIt); + } } void @@ -13144,10 +13139,31 @@ void ScalarEvolution::verify() const { is_contained(It->second, std::make_pair(L, ValueAtScope))) continue; dbgs() << "Value: " << *Value << ", Loop: " << *L << ", ValueAtScope: " - << ValueAtScope << " missing in ValuesAtScopes"; + << ValueAtScope << " missing in ValuesAtScopes\n"; std::abort(); } } + + // Verify integrity of BECountUsers. + auto VerifyBECountUsers = [&](bool Predicated) { + auto &BECounts = + Predicated ? PredicatedBackedgeTakenCounts : BackedgeTakenCounts; + for (const auto &LoopAndBEInfo : BECounts) { + for (const ExitNotTakenInfo &ENT : LoopAndBEInfo.second.ExitNotTaken) { + if (!isa(ENT.ExactNotTaken)) { + auto UserIt = BECountUsers.find(ENT.ExactNotTaken); + if (UserIt != BECountUsers.end() && + UserIt->second.contains({ LoopAndBEInfo.first, Predicated })) + continue; + dbgs() << "Value " << *ENT.ExactNotTaken << " for loop " + << *LoopAndBEInfo.first << " missing from BECountUsers\n"; + std::abort(); + } + } + } + }; + VerifyBECountUsers(/* Predicated */ false); + VerifyBECountUsers(/* Predicated */ true); } bool ScalarEvolution::invalidate(