[SCEV] Use a SmallPtrSet as a temporary union predicate; NFC

Summary:
Instead of creating and destroying SCEVUnionPredicate instances (which
internally creates and destroys a DenseMap), use temporary SmallPtrSet
instances of remember the set of predicates that will get reified into a
SCEVUnionPredicate.

Reviewers: silviu.baranga, sbaranga

Subscribers: sanjoy, mcrosier, llvm-commits, mzolotukhin

Differential Revision: https://reviews.llvm.org/D25000

llvm-svn: 282606
This commit is contained in:
Sanjoy Das 2016-09-28 17:14:58 +00:00
parent 386236509e
commit f0022125e0
2 changed files with 90 additions and 63 deletions

View File

@ -551,19 +551,36 @@ private:
const SCEV *ExactNotTaken; const SCEV *ExactNotTaken;
const SCEV *MaxNotTaken; const SCEV *MaxNotTaken;
/// A predicate union guard for this ExitLimit. The result is only /// A set of predicate guards for this ExitLimit. The result is only valid
/// valid if this predicate evaluates to 'true' at run-time. /// if all of the predicates in \c Predicates evaluate to 'true' at
SCEVUnionPredicate Predicate; /// run-time.
SmallPtrSet<const SCEVPredicate *, 4> Predicates;
void addPredicate(const SCEVPredicate *P) {
assert(!isa<SCEVUnionPredicate>(P) && "Only add leaf predicates here!");
Predicates.insert(P);
}
/*implicit*/ ExitLimit(const SCEV *E) : ExactNotTaken(E), MaxNotTaken(E) {} /*implicit*/ ExitLimit(const SCEV *E) : ExactNotTaken(E), MaxNotTaken(E) {}
ExitLimit(const SCEV *E, const SCEV *M, SCEVUnionPredicate &P) ExitLimit(
: ExactNotTaken(E), MaxNotTaken(M), Predicate(P) { const SCEV *E, const SCEV *M,
ArrayRef<const SmallPtrSetImpl<const SCEVPredicate *> *> PredSetList)
: ExactNotTaken(E), MaxNotTaken(M) {
assert((isa<SCEVCouldNotCompute>(ExactNotTaken) || assert((isa<SCEVCouldNotCompute>(ExactNotTaken) ||
!isa<SCEVCouldNotCompute>(MaxNotTaken)) && !isa<SCEVCouldNotCompute>(MaxNotTaken)) &&
"Exact is not allowed to be less precise than Max"); "Exact is not allowed to be less precise than Max");
for (auto *PredSet : PredSetList)
for (auto *P : *PredSet)
addPredicate(P);
} }
ExitLimit(const SCEV *E, const SCEV *M,
const SmallPtrSetImpl<const SCEVPredicate *> &PredSet)
: ExitLimit(E, M, {&PredSet}) {}
ExitLimit(const SCEV *E, const SCEV *M) : ExitLimit(E, M, None) {}
/// Test whether this ExitLimit contains any computed information, or /// Test whether this ExitLimit contains any computed information, or
/// whether it's all SCEVCouldNotCompute values. /// whether it's all SCEVCouldNotCompute values.
bool hasAnyInfo() const { bool hasAnyInfo() const {
@ -1581,9 +1598,9 @@ public:
SCEVUnionPredicate &A); SCEVUnionPredicate &A);
/// Tries to convert the \p S expression to an AddRec expression, /// Tries to convert the \p S expression to an AddRec expression,
/// adding additional predicates to \p Preds as required. /// adding additional predicates to \p Preds as required.
const SCEVAddRecExpr * const SCEVAddRecExpr *convertSCEVToAddRecWithPredicates(
convertSCEVToAddRecWithPredicates(const SCEV *S, const Loop *L, const SCEV *S, const Loop *L,
SCEVUnionPredicate &Preds); SmallPtrSetImpl<const SCEVPredicate *> &Preds);
private: private:
/// Compute the backedge taken count knowing the interval difference, the /// Compute the backedge taken count knowing the interval difference, the

View File

@ -5656,11 +5656,14 @@ ScalarEvolution::BackedgeTakenInfo::BackedgeTakenInfo(
[&](const EdgeExitInfo &EEI) { [&](const EdgeExitInfo &EEI) {
BasicBlock *ExitBB = EEI.first; BasicBlock *ExitBB = EEI.first;
const ExitLimit &EL = EEI.second; const ExitLimit &EL = EEI.second;
if (EL.Predicate.isAlwaysTrue()) if (EL.Predicates.empty())
return ExitNotTakenInfo(ExitBB, EL.ExactNotTaken, nullptr); return ExitNotTakenInfo(ExitBB, EL.ExactNotTaken, nullptr);
return ExitNotTakenInfo(
ExitBB, EL.ExactNotTaken, std::unique_ptr<SCEVUnionPredicate> Predicate(new SCEVUnionPredicate);
llvm::make_unique<SCEVUnionPredicate>(std::move(EL.Predicate))); for (auto *Pred : EL.Predicates)
Predicate->add(Pred);
return ExitNotTakenInfo(ExitBB, EL.ExactNotTaken, std::move(Predicate));
}); });
} }
@ -5691,7 +5694,7 @@ ScalarEvolution::computeBackedgeTakenCount(const Loop *L,
BasicBlock *ExitBB = ExitingBlocks[i]; BasicBlock *ExitBB = ExitingBlocks[i];
ExitLimit EL = computeExitLimit(L, ExitBB, AllowPredicates); ExitLimit EL = computeExitLimit(L, ExitBB, AllowPredicates);
assert((AllowPredicates || EL.Predicate.isAlwaysTrue()) && assert((AllowPredicates || EL.Predicates.empty()) &&
"Predicated exit limit when predicates are not allowed!"); "Predicated exit limit when predicates are not allowed!");
// 1. For each exit that can be computed, add an entry to ExitCounts. // 1. For each exit that can be computed, add an entry to ExitCounts.
@ -5861,9 +5864,6 @@ ScalarEvolution::computeExitLimitFromCond(const Loop *L,
BECount = EL0.ExactNotTaken; BECount = EL0.ExactNotTaken;
} }
SCEVUnionPredicate NP;
NP.add(&EL0.Predicate);
NP.add(&EL1.Predicate);
// There are cases (e.g. PR26207) where computeExitLimitFromCond is able // There are cases (e.g. PR26207) where computeExitLimitFromCond is able
// to be more aggressive when computing BECount than when computing // to be more aggressive when computing BECount than when computing
// MaxBECount. In these cases it is possible for EL0.ExactNotTaken and // MaxBECount. In these cases it is possible for EL0.ExactNotTaken and
@ -5873,7 +5873,7 @@ ScalarEvolution::computeExitLimitFromCond(const Loop *L,
!isa<SCEVCouldNotCompute>(BECount)) !isa<SCEVCouldNotCompute>(BECount))
MaxBECount = BECount; MaxBECount = BECount;
return ExitLimit(BECount, MaxBECount, NP); return ExitLimit(BECount, MaxBECount, {&EL0.Predicates, &EL1.Predicates});
} }
if (BO->getOpcode() == Instruction::Or) { if (BO->getOpcode() == Instruction::Or) {
// Recurse on the operands of the or. // Recurse on the operands of the or.
@ -5912,10 +5912,7 @@ ScalarEvolution::computeExitLimitFromCond(const Loop *L,
BECount = EL0.ExactNotTaken; BECount = EL0.ExactNotTaken;
} }
SCEVUnionPredicate NP; return ExitLimit(BECount, MaxBECount, {&EL0.Predicates, &EL1.Predicates});
NP.add(&EL0.Predicate);
NP.add(&EL1.Predicate);
return ExitLimit(BECount, MaxBECount, NP);
} }
} }
@ -6300,8 +6297,7 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeShiftCompareExitLimit(
unsigned BitWidth = getTypeSizeInBits(RHS->getType()); unsigned BitWidth = getTypeSizeInBits(RHS->getType());
const SCEV *UpperBound = const SCEV *UpperBound =
getConstant(getEffectiveSCEVType(RHS->getType()), BitWidth); getConstant(getEffectiveSCEVType(RHS->getType()), BitWidth);
SCEVUnionPredicate P; return ExitLimit(getCouldNotCompute(), UpperBound);
return ExitLimit(getCouldNotCompute(), UpperBound, P);
} }
return getCouldNotCompute(); return getCouldNotCompute();
@ -7062,7 +7058,7 @@ ScalarEvolution::howFarToZero(const SCEV *V, const Loop *L, bool ControlsExit,
// effectively V != 0. We know and take advantage of the fact that this // effectively V != 0. We know and take advantage of the fact that this
// expression only being used in a comparison by zero context. // expression only being used in a comparison by zero context.
SCEVUnionPredicate P; SmallPtrSet<const SCEVPredicate *, 4> Predicates;
// If the value is a constant // If the value is a constant
if (const SCEVConstant *C = dyn_cast<SCEVConstant>(V)) { if (const SCEVConstant *C = dyn_cast<SCEVConstant>(V)) {
// If the value is already zero, the branch will execute zero times. // If the value is already zero, the branch will execute zero times.
@ -7075,7 +7071,7 @@ ScalarEvolution::howFarToZero(const SCEV *V, const Loop *L, bool ControlsExit,
// Try to make this an AddRec using runtime tests, in the first X // Try to make this an AddRec using runtime tests, in the first X
// iterations of this loop, where X is the SCEV expression found by the // iterations of this loop, where X is the SCEV expression found by the
// algorithm below. // algorithm below.
AddRec = convertSCEVToAddRecWithPredicates(V, L, P); AddRec = convertSCEVToAddRecWithPredicates(V, L, Predicates);
if (!AddRec || AddRec->getLoop() != L) if (!AddRec || AddRec->getLoop() != L)
return getCouldNotCompute(); return getCouldNotCompute();
@ -7097,7 +7093,7 @@ ScalarEvolution::howFarToZero(const SCEV *V, const Loop *L, bool ControlsExit,
// should not accept a root of 2. // should not accept a root of 2.
const SCEV *Val = AddRec->evaluateAtIteration(R1, *this); const SCEV *Val = AddRec->evaluateAtIteration(R1, *this);
if (Val->isZero()) if (Val->isZero())
return ExitLimit(R1, R1, P); // We found a quadratic root! return ExitLimit(R1, R1, Predicates); // We found a quadratic root!
} }
} }
return getCouldNotCompute(); return getCouldNotCompute();
@ -7154,7 +7150,7 @@ ScalarEvolution::howFarToZero(const SCEV *V, const Loop *L, bool ControlsExit,
else else
MaxBECount = getConstant(CountDown ? CR.getUnsignedMax() MaxBECount = getConstant(CountDown ? CR.getUnsignedMax()
: -CR.getUnsignedMin()); : -CR.getUnsignedMin());
return ExitLimit(Distance, MaxBECount, P); return ExitLimit(Distance, MaxBECount, Predicates);
} }
// As a special case, handle the instance where Step is a positive power of // As a special case, handle the instance where Step is a positive power of
@ -7209,7 +7205,7 @@ ScalarEvolution::howFarToZero(const SCEV *V, const Loop *L, bool ControlsExit,
const SCEV *Limit = const SCEV *Limit =
getZeroExtendExpr(getTruncateExpr(ModuloResult, NarrowTy), WideTy); getZeroExtendExpr(getTruncateExpr(ModuloResult, NarrowTy), WideTy);
return ExitLimit(Limit, Limit, P); return ExitLimit(Limit, Limit, Predicates);
} }
} }
@ -7222,14 +7218,14 @@ ScalarEvolution::howFarToZero(const SCEV *V, const Loop *L, bool ControlsExit,
loopHasNoAbnormalExits(AddRec->getLoop())) { loopHasNoAbnormalExits(AddRec->getLoop())) {
const SCEV *Exact = const SCEV *Exact =
getUDivExpr(Distance, CountDown ? getNegativeSCEV(Step) : Step); getUDivExpr(Distance, CountDown ? getNegativeSCEV(Step) : Step);
return ExitLimit(Exact, Exact, P); return ExitLimit(Exact, Exact, Predicates);
} }
// Then, try to solve the above equation provided that Start is constant. // Then, try to solve the above equation provided that Start is constant.
if (const SCEVConstant *StartC = dyn_cast<SCEVConstant>(Start)) { if (const SCEVConstant *StartC = dyn_cast<SCEVConstant>(Start)) {
const SCEV *E = SolveLinEquationWithOverflow( const SCEV *E = SolveLinEquationWithOverflow(
StepC->getValue()->getValue(), -StartC->getValue()->getValue(), *this); StepC->getValue()->getValue(), -StartC->getValue()->getValue(), *this);
return ExitLimit(E, E, P); return ExitLimit(E, E, Predicates);
} }
return getCouldNotCompute(); return getCouldNotCompute();
} }
@ -8634,7 +8630,7 @@ ScalarEvolution::ExitLimit
ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS, ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
const Loop *L, bool IsSigned, const Loop *L, bool IsSigned,
bool ControlsExit, bool AllowPredicates) { bool ControlsExit, bool AllowPredicates) {
SCEVUnionPredicate P; SmallPtrSet<const SCEVPredicate *, 4> Predicates;
// We handle only IV < Invariant // We handle only IV < Invariant
if (!isLoopInvariant(RHS, L)) if (!isLoopInvariant(RHS, L))
return getCouldNotCompute(); return getCouldNotCompute();
@ -8646,7 +8642,7 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
// Try to make this an AddRec using runtime tests, in the first X // Try to make this an AddRec using runtime tests, in the first X
// iterations of this loop, where X is the SCEV expression found by the // iterations of this loop, where X is the SCEV expression found by the
// algorithm below. // algorithm below.
IV = convertSCEVToAddRecWithPredicates(LHS, L, P); IV = convertSCEVToAddRecWithPredicates(LHS, L, Predicates);
PredicatedIV = true; PredicatedIV = true;
} }
@ -8762,14 +8758,14 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
if (isa<SCEVCouldNotCompute>(MaxBECount)) if (isa<SCEVCouldNotCompute>(MaxBECount))
MaxBECount = BECount; MaxBECount = BECount;
return ExitLimit(BECount, MaxBECount, P); return ExitLimit(BECount, MaxBECount, Predicates);
} }
ScalarEvolution::ExitLimit ScalarEvolution::ExitLimit
ScalarEvolution::howManyGreaterThans(const SCEV *LHS, const SCEV *RHS, ScalarEvolution::howManyGreaterThans(const SCEV *LHS, const SCEV *RHS,
const Loop *L, bool IsSigned, const Loop *L, bool IsSigned,
bool ControlsExit, bool AllowPredicates) { bool ControlsExit, bool AllowPredicates) {
SCEVUnionPredicate P; SmallPtrSet<const SCEVPredicate *, 4> Predicates;
// We handle only IV > Invariant // We handle only IV > Invariant
if (!isLoopInvariant(RHS, L)) if (!isLoopInvariant(RHS, L))
return getCouldNotCompute(); return getCouldNotCompute();
@ -8779,7 +8775,7 @@ ScalarEvolution::howManyGreaterThans(const SCEV *LHS, const SCEV *RHS,
// Try to make this an AddRec using runtime tests, in the first X // Try to make this an AddRec using runtime tests, in the first X
// iterations of this loop, where X is the SCEV expression found by the // iterations of this loop, where X is the SCEV expression found by the
// algorithm below. // algorithm below.
IV = convertSCEVToAddRecWithPredicates(LHS, L, P); IV = convertSCEVToAddRecWithPredicates(LHS, L, Predicates);
// Avoid weird loops // Avoid weird loops
if (!IV || IV->getLoop() != L || !IV->isAffine()) if (!IV || IV->getLoop() != L || !IV->isAffine())
@ -8839,7 +8835,7 @@ ScalarEvolution::howManyGreaterThans(const SCEV *LHS, const SCEV *RHS,
if (isa<SCEVCouldNotCompute>(MaxBECount)) if (isa<SCEVCouldNotCompute>(MaxBECount))
MaxBECount = BECount; MaxBECount = BECount;
return ExitLimit(BECount, MaxBECount, P); return ExitLimit(BECount, MaxBECount, Predicates);
} }
const SCEV *SCEVAddRecExpr::getNumIterationsInRange(const ConstantRange &Range, const SCEV *SCEVAddRecExpr::getNumIterationsInRange(const ConstantRange &Range,
@ -10161,25 +10157,34 @@ namespace {
class SCEVPredicateRewriter : public SCEVRewriteVisitor<SCEVPredicateRewriter> { class SCEVPredicateRewriter : public SCEVRewriteVisitor<SCEVPredicateRewriter> {
public: public:
// Rewrites \p S in the context of a loop L and the predicate A. /// Rewrites \p S in the context of a loop L and the SCEV predication
// If Assume is true, rewrite is free to add further predicates to A /// infrastructure.
// such that the result will be an AddRecExpr. ///
/// If \p Pred is non-null, the SCEV expression is rewritten to respect the
/// equivalences present in \p Pred.
///
/// If \p NewPreds is non-null, rewrite is free to add further predicates to
/// \p NewPreds such that the result will be an AddRecExpr.
static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE, static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE,
SCEVUnionPredicate &A, bool Assume) { SmallPtrSetImpl<const SCEVPredicate *> *NewPreds,
SCEVPredicateRewriter Rewriter(L, SE, A, Assume); SCEVUnionPredicate *Pred) {
SCEVPredicateRewriter Rewriter(L, SE, NewPreds, Pred);
return Rewriter.visit(S); return Rewriter.visit(S);
} }
SCEVPredicateRewriter(const Loop *L, ScalarEvolution &SE, SCEVPredicateRewriter(const Loop *L, ScalarEvolution &SE,
SCEVUnionPredicate &P, bool Assume) SmallPtrSetImpl<const SCEVPredicate *> *NewPreds,
: SCEVRewriteVisitor(SE), P(P), L(L), Assume(Assume) {} SCEVUnionPredicate *Pred)
: SCEVRewriteVisitor(SE), NewPreds(NewPreds), Pred(Pred), L(L) {}
const SCEV *visitUnknown(const SCEVUnknown *Expr) { const SCEV *visitUnknown(const SCEVUnknown *Expr) {
auto ExprPreds = P.getPredicatesForExpr(Expr); if (Pred) {
for (auto *Pred : ExprPreds) auto ExprPreds = Pred->getPredicatesForExpr(Expr);
if (const auto *IPred = dyn_cast<SCEVEqualPredicate>(Pred)) for (auto *Pred : ExprPreds)
if (IPred->getLHS() == Expr) if (const auto *IPred = dyn_cast<SCEVEqualPredicate>(Pred))
return IPred->getRHS(); if (IPred->getLHS() == Expr)
return IPred->getRHS();
}
return Expr; return Expr;
} }
@ -10220,32 +10225,31 @@ private:
bool addOverflowAssumption(const SCEVAddRecExpr *AR, bool addOverflowAssumption(const SCEVAddRecExpr *AR,
SCEVWrapPredicate::IncrementWrapFlags AddedFlags) { SCEVWrapPredicate::IncrementWrapFlags AddedFlags) {
auto *A = SE.getWrapPredicate(AR, AddedFlags); auto *A = SE.getWrapPredicate(AR, AddedFlags);
if (!Assume) { if (!NewPreds) {
// Check if we've already made this assumption. // Check if we've already made this assumption.
if (P.implies(A)) return Pred && Pred->implies(A);
return true;
return false;
} }
P.add(A); NewPreds->insert(A);
return true; return true;
} }
SCEVUnionPredicate &P; SmallPtrSetImpl<const SCEVPredicate *> *NewPreds;
SCEVUnionPredicate *Pred;
const Loop *L; const Loop *L;
bool Assume;
}; };
} // end anonymous namespace } // end anonymous namespace
const SCEV *ScalarEvolution::rewriteUsingPredicate(const SCEV *S, const Loop *L, const SCEV *ScalarEvolution::rewriteUsingPredicate(const SCEV *S, const Loop *L,
SCEVUnionPredicate &Preds) { SCEVUnionPredicate &Preds) {
return SCEVPredicateRewriter::rewrite(S, L, *this, Preds, false); return SCEVPredicateRewriter::rewrite(S, L, *this, nullptr, &Preds);
} }
const SCEVAddRecExpr * const SCEVAddRecExpr *ScalarEvolution::convertSCEVToAddRecWithPredicates(
ScalarEvolution::convertSCEVToAddRecWithPredicates(const SCEV *S, const Loop *L, const SCEV *S, const Loop *L,
SCEVUnionPredicate &Preds) { SmallPtrSetImpl<const SCEVPredicate *> &Preds) {
SCEVUnionPredicate TransformPreds;
S = SCEVPredicateRewriter::rewrite(S, L, *this, TransformPreds, true); SmallPtrSet<const SCEVPredicate *, 4> TransformPreds;
S = SCEVPredicateRewriter::rewrite(S, L, *this, &TransformPreds, nullptr);
auto *AddRec = dyn_cast<SCEVAddRecExpr>(S); auto *AddRec = dyn_cast<SCEVAddRecExpr>(S);
if (!AddRec) if (!AddRec)
@ -10253,7 +10257,9 @@ ScalarEvolution::convertSCEVToAddRecWithPredicates(const SCEV *S, const Loop *L,
// Since the transformation was successful, we can now transfer the SCEV // Since the transformation was successful, we can now transfer the SCEV
// predicates. // predicates.
Preds.add(&TransformPreds); for (auto *P : TransformPreds)
Preds.insert(P);
return AddRec; return AddRec;
} }
@ -10480,11 +10486,15 @@ bool PredicatedScalarEvolution::hasNoOverflow(
const SCEVAddRecExpr *PredicatedScalarEvolution::getAsAddRec(Value *V) { const SCEVAddRecExpr *PredicatedScalarEvolution::getAsAddRec(Value *V) {
const SCEV *Expr = this->getSCEV(V); const SCEV *Expr = this->getSCEV(V);
auto *New = SE.convertSCEVToAddRecWithPredicates(Expr, &L, Preds); SmallPtrSet<const SCEVPredicate *, 4> NewPreds;
auto *New = SE.convertSCEVToAddRecWithPredicates(Expr, &L, NewPreds);
if (!New) if (!New)
return nullptr; return nullptr;
for (auto *P : NewPreds)
Preds.add(P);
updateGeneration(); updateGeneration();
RewriteMap[SE.getSCEV(V)] = {Generation, New}; RewriteMap[SE.getSCEV(V)] = {Generation, New};
return New; return New;