[SCEV] Track backedge taken count users (NFCI)

Track which SCEVs are used as ExactNotTaken counts in
BackedgeTakenInfo structures, so we can directly determine which
loops need to be invalidated, rather than iterating over all BECounts.

This gives a small compile-time improvement on average, but the
motivation here is more to ensure there are no degenerate cases,
if the number of backedge taken counts is large.

Differential Revision: https://reviews.llvm.org/D114784
This commit is contained in:
Nikita Popov 2021-11-29 21:02:37 +01:00
parent ec15b7307f
commit 67704801c6
2 changed files with 68 additions and 50 deletions

View File

@ -1378,6 +1378,8 @@ private:
/// includes an exact count and a maximum count.
///
class BackedgeTakenInfo {
friend class ScalarEvolution;
/// A list of computable exits and their not-taken counts. Loops almost
/// never have more than one computable exit.
SmallVector<ExitNotTakenInfo, 1> ExitNotTaken;
@ -1398,9 +1400,6 @@ private:
/// True iff the backedge is taken either exactly Max or zero times.
bool MaxOrZero = false;
/// SCEV expressions used in any of the ExitNotTakenInfo counts.
SmallPtrSet<const SCEV *, 4> Operands;
bool isComplete() const { return IsComplete; }
const SCEV *getConstantMax() const { return ConstantMax; }
@ -1466,10 +1465,6 @@ private:
/// Return true if the number of times this backedge is taken is either the
/// value returned by getConstantMax or zero.
bool isConstantMaxOrZero(ScalarEvolution *SE) const;
/// Return true if any backedge taken count expressions refer to the given
/// subexpression.
bool hasOperand(const SCEV *S) const;
};
/// Cache the backedge-taken count of the loops for this function as they
@ -1480,6 +1475,10 @@ private:
/// function as they are computed.
DenseMap<const Loop *, BackedgeTakenInfo> PredicatedBackedgeTakenCounts;
/// Loops whose backedge taken counts directly use this non-constant SCEV.
DenseMap<const SCEV *, SmallPtrSet<PointerIntPair<const Loop *, 1, bool>, 4>>
BECountUsers;
/// This map contains entries for all of the PHI instructions that we
/// attempt to compute constant evolutions for. This allows us to avoid
/// potentially expensive recomputation of these properties. An instruction
@ -1911,6 +1910,9 @@ private:
bool splitBinaryAdd(const SCEV *Expr, const SCEV *&L, const SCEV *&R,
SCEV::NoWrapFlags &Flags);
/// Forget predicated/non-predicated backedge taken counts for the given loop.
void forgetBackedgeTakenCounts(const Loop *L, bool Predicated);
/// Drop memoized information for all \p SCEVs.
void forgetMemoizedResults(ArrayRef<const SCEV *> SCEVs);

View File

@ -7603,6 +7603,7 @@ void ScalarEvolution::forgetAllLoops() {
// result.
BackedgeTakenCounts.clear();
PredicatedBackedgeTakenCounts.clear();
BECountUsers.clear();
LoopPropertiesCache.clear();
ConstantEvolutionLoopExitValue.clear();
ValueExprMap.clear();
@ -7629,8 +7630,8 @@ void ScalarEvolution::forgetLoop(const Loop *L) {
auto *CurrL = LoopWorklist.pop_back_val();
// Drop any stored trip count value.
BackedgeTakenCounts.erase(CurrL);
PredicatedBackedgeTakenCounts.erase(CurrL);
forgetBackedgeTakenCounts(CurrL, /* Predicated */ false);
forgetBackedgeTakenCounts(CurrL, /* Predicated */ true);
// Drop information about predicated SCEV rewrites for this loop.
for (auto I = PredicatedSCEVRewrites.begin();
@ -7804,10 +7805,6 @@ bool ScalarEvolution::BackedgeTakenInfo::isConstantMaxOrZero(
return MaxOrZero && !any_of(ExitNotTaken, PredicateNotAlwaysTrue);
}
bool ScalarEvolution::BackedgeTakenInfo::hasOperand(const SCEV *S) const {
return Operands.contains(S);
}
ScalarEvolution::ExitLimit::ExitLimit(const SCEV *E)
: ExitLimit(E, E, false, None) {
}
@ -7848,19 +7845,6 @@ ScalarEvolution::ExitLimit::ExitLimit(const SCEV *E, const SCEV *M,
: ExitLimit(E, M, MaxOrZero, None) {
}
class SCEVRecordOperands {
SmallPtrSetImpl<const SCEV *> &Operands;
public:
SCEVRecordOperands(SmallPtrSetImpl<const SCEV *> &Operands)
: Operands(Operands) {}
bool follow(const SCEV *S) {
Operands.insert(S);
return true;
}
bool isDone() { return false; }
};
/// Allocate memory for BackedgeTakenInfo and copy the not-taken count of each
/// computable exit into a persistent ExitNotTakenInfo array.
ScalarEvolution::BackedgeTakenInfo::BackedgeTakenInfo(
@ -7889,14 +7873,6 @@ ScalarEvolution::BackedgeTakenInfo::BackedgeTakenInfo(
assert((isa<SCEVCouldNotCompute>(ConstantMax) ||
isa<SCEVConstant>(ConstantMax)) &&
"No point in having a non-constant max backedge taken count!");
SCEVRecordOperands RecordOperands(Operands);
SCEVTraversal<SCEVRecordOperands> ST(RecordOperands);
if (!isa<SCEVCouldNotCompute>(ConstantMax))
ST.visitAll(ConstantMax);
for (auto &ENT : ExitNotTaken)
if (!isa<SCEVCouldNotCompute>(ENT.ExactNotTaken))
ST.visitAll(ENT.ExactNotTaken);
}
/// Compute the number of times the backedge of the specified loop will execute.
@ -7978,6 +7954,13 @@ ScalarEvolution::computeBackedgeTakenCount(const Loop *L,
// The loop backedge will be taken the maximum or zero times if there's
// a single exit that must be taken the maximum or zero times.
bool MaxOrZero = (MustExitMaxOrZero && ExitingBlocks.size() == 1);
// Remember which SCEVs are used in exit limits for invalidation purposes.
// We only care about non-constant SCEVs here, so we can ignore EL.MaxNotTaken
// and MaxBECount, which must be SCEVConstant.
for (const auto &Pair : ExitCounts)
if (!isa<SCEVConstant>(Pair.second.ExactNotTaken))
BECountUsers[Pair.second.ExactNotTaken].insert({L, AllowPredicates});
return BackedgeTakenInfo(std::move(ExitCounts), CouldComputeBECount,
MaxBECount, MaxOrZero);
}
@ -12466,6 +12449,7 @@ ScalarEvolution::ScalarEvolution(ScalarEvolution &&Arg)
BackedgeTakenCounts(std::move(Arg.BackedgeTakenCounts)),
PredicatedBackedgeTakenCounts(
std::move(Arg.PredicatedBackedgeTakenCounts)),
BECountUsers(std::move(Arg.BECountUsers)),
ConstantEvolutionLoopExitValue(
std::move(Arg.ConstantEvolutionLoopExitValue)),
ValuesAtScopes(std::move(Arg.ValuesAtScopes)),
@ -12882,6 +12866,23 @@ bool ScalarEvolution::hasOperand(const SCEV *S, const SCEV *Op) const {
return SCEVExprContains(S, [&](const SCEV *Expr) { return Expr == Op; });
}
void ScalarEvolution::forgetBackedgeTakenCounts(const Loop *L,
bool Predicated) {
auto &BECounts =
Predicated ? PredicatedBackedgeTakenCounts : BackedgeTakenCounts;
auto It = BECounts.find(L);
if (It != BECounts.end()) {
for (const ExitNotTakenInfo &ENT : It->second.ExitNotTaken) {
if (!isa<SCEVConstant>(ENT.ExactNotTaken)) {
auto UserIt = BECountUsers.find(ENT.ExactNotTaken);
assert(UserIt != BECountUsers.end());
UserIt->second.erase({L, Predicated});
}
}
BECounts.erase(It);
}
}
void ScalarEvolution::forgetMemoizedResults(ArrayRef<const SCEV *> SCEVs) {
SmallPtrSet<const SCEV *, 8> ToForget(SCEVs.begin(), SCEVs.end());
SmallVector<const SCEV *, 8> Worklist(ToForget.begin(), ToForget.end());
@ -12906,21 +12907,6 @@ void ScalarEvolution::forgetMemoizedResults(ArrayRef<const SCEV *> SCEVs) {
else
++I;
}
auto RemoveSCEVFromBackedgeMap = [&ToForget](
DenseMap<const Loop *, BackedgeTakenInfo> &Map) {
for (auto I = Map.begin(), E = Map.end(); I != E;) {
BackedgeTakenInfo &BEInfo = I->second;
if (any_of(ToForget,
[&BEInfo](const SCEV *S) { return BEInfo.hasOperand(S); }))
Map.erase(I++);
else
++I;
}
};
RemoveSCEVFromBackedgeMap(BackedgeTakenCounts);
RemoveSCEVFromBackedgeMap(PredicatedBackedgeTakenCounts);
}
void ScalarEvolution::forgetMemoizedResultsImpl(const SCEV *S) {
@ -12958,6 +12944,15 @@ void ScalarEvolution::forgetMemoizedResultsImpl(const SCEV *S) {
erase_value(ValuesAtScopes[Pair.second], std::make_pair(Pair.first, S));
ValuesAtScopesUsers.erase(ScopeUserIt);
}
auto BEUsersIt = BECountUsers.find(S);
if (BEUsersIt != BECountUsers.end()) {
// Work on a copy, as forgetBackedgeTakenCounts() will modify the original.
auto Copy = BEUsersIt->second;
for (const auto &Pair : Copy)
forgetBackedgeTakenCounts(Pair.getPointer(), Pair.getInt());
BECountUsers.erase(BEUsersIt);
}
}
void
@ -13144,10 +13139,31 @@ void ScalarEvolution::verify() const {
is_contained(It->second, std::make_pair(L, ValueAtScope)))
continue;
dbgs() << "Value: " << *Value << ", Loop: " << *L << ", ValueAtScope: "
<< ValueAtScope << " missing in ValuesAtScopes";
<< ValueAtScope << " missing in ValuesAtScopes\n";
std::abort();
}
}
// Verify integrity of BECountUsers.
auto VerifyBECountUsers = [&](bool Predicated) {
auto &BECounts =
Predicated ? PredicatedBackedgeTakenCounts : BackedgeTakenCounts;
for (const auto &LoopAndBEInfo : BECounts) {
for (const ExitNotTakenInfo &ENT : LoopAndBEInfo.second.ExitNotTaken) {
if (!isa<SCEVConstant>(ENT.ExactNotTaken)) {
auto UserIt = BECountUsers.find(ENT.ExactNotTaken);
if (UserIt != BECountUsers.end() &&
UserIt->second.contains({ LoopAndBEInfo.first, Predicated }))
continue;
dbgs() << "Value " << *ENT.ExactNotTaken << " for loop "
<< *LoopAndBEInfo.first << " missing from BECountUsers\n";
std::abort();
}
}
}
};
VerifyBECountUsers(/* Predicated */ false);
VerifyBECountUsers(/* Predicated */ true);
}
bool ScalarEvolution::invalidate(