[SCEV] Use a SmallPtrSet as a temporary union predicate; NFC

Summary:
Instead of creating and destroying SCEVUnionPredicate instances (which
internally creates and destroys a DenseMap), use temporary SmallPtrSet
instances of remember the set of predicates that will get reified into a
SCEVUnionPredicate.

Reviewers: silviu.baranga, sbaranga

Subscribers: sanjoy, mcrosier, llvm-commits, mzolotukhin

Differential Revision: https://reviews.llvm.org/D25000

llvm-svn: 282606
This commit is contained in:
Sanjoy Das 2016-09-28 17:14:58 +00:00
parent 386236509e
commit f0022125e0
2 changed files with 90 additions and 63 deletions

View File

@ -551,19 +551,36 @@ private:
const SCEV *ExactNotTaken;
const SCEV *MaxNotTaken;
/// A predicate union guard for this ExitLimit. The result is only
/// valid if this predicate evaluates to 'true' at run-time.
SCEVUnionPredicate Predicate;
/// A set of predicate guards for this ExitLimit. The result is only valid
/// if all of the predicates in \c Predicates evaluate to 'true' at
/// run-time.
SmallPtrSet<const SCEVPredicate *, 4> Predicates;
void addPredicate(const SCEVPredicate *P) {
assert(!isa<SCEVUnionPredicate>(P) && "Only add leaf predicates here!");
Predicates.insert(P);
}
/*implicit*/ ExitLimit(const SCEV *E) : ExactNotTaken(E), MaxNotTaken(E) {}
ExitLimit(const SCEV *E, const SCEV *M, SCEVUnionPredicate &P)
: ExactNotTaken(E), MaxNotTaken(M), Predicate(P) {
ExitLimit(
const SCEV *E, const SCEV *M,
ArrayRef<const SmallPtrSetImpl<const SCEVPredicate *> *> PredSetList)
: ExactNotTaken(E), MaxNotTaken(M) {
assert((isa<SCEVCouldNotCompute>(ExactNotTaken) ||
!isa<SCEVCouldNotCompute>(MaxNotTaken)) &&
"Exact is not allowed to be less precise than Max");
for (auto *PredSet : PredSetList)
for (auto *P : *PredSet)
addPredicate(P);
}
ExitLimit(const SCEV *E, const SCEV *M,
const SmallPtrSetImpl<const SCEVPredicate *> &PredSet)
: ExitLimit(E, M, {&PredSet}) {}
ExitLimit(const SCEV *E, const SCEV *M) : ExitLimit(E, M, None) {}
/// Test whether this ExitLimit contains any computed information, or
/// whether it's all SCEVCouldNotCompute values.
bool hasAnyInfo() const {
@ -1581,9 +1598,9 @@ public:
SCEVUnionPredicate &A);
/// Tries to convert the \p S expression to an AddRec expression,
/// adding additional predicates to \p Preds as required.
const SCEVAddRecExpr *
convertSCEVToAddRecWithPredicates(const SCEV *S, const Loop *L,
SCEVUnionPredicate &Preds);
const SCEVAddRecExpr *convertSCEVToAddRecWithPredicates(
const SCEV *S, const Loop *L,
SmallPtrSetImpl<const SCEVPredicate *> &Preds);
private:
/// Compute the backedge taken count knowing the interval difference, the

View File

@ -5656,11 +5656,14 @@ ScalarEvolution::BackedgeTakenInfo::BackedgeTakenInfo(
[&](const EdgeExitInfo &EEI) {
BasicBlock *ExitBB = EEI.first;
const ExitLimit &EL = EEI.second;
if (EL.Predicate.isAlwaysTrue())
if (EL.Predicates.empty())
return ExitNotTakenInfo(ExitBB, EL.ExactNotTaken, nullptr);
return ExitNotTakenInfo(
ExitBB, EL.ExactNotTaken,
llvm::make_unique<SCEVUnionPredicate>(std::move(EL.Predicate)));
std::unique_ptr<SCEVUnionPredicate> Predicate(new SCEVUnionPredicate);
for (auto *Pred : EL.Predicates)
Predicate->add(Pred);
return ExitNotTakenInfo(ExitBB, EL.ExactNotTaken, std::move(Predicate));
});
}
@ -5691,7 +5694,7 @@ ScalarEvolution::computeBackedgeTakenCount(const Loop *L,
BasicBlock *ExitBB = ExitingBlocks[i];
ExitLimit EL = computeExitLimit(L, ExitBB, AllowPredicates);
assert((AllowPredicates || EL.Predicate.isAlwaysTrue()) &&
assert((AllowPredicates || EL.Predicates.empty()) &&
"Predicated exit limit when predicates are not allowed!");
// 1. For each exit that can be computed, add an entry to ExitCounts.
@ -5861,9 +5864,6 @@ ScalarEvolution::computeExitLimitFromCond(const Loop *L,
BECount = EL0.ExactNotTaken;
}
SCEVUnionPredicate NP;
NP.add(&EL0.Predicate);
NP.add(&EL1.Predicate);
// There are cases (e.g. PR26207) where computeExitLimitFromCond is able
// to be more aggressive when computing BECount than when computing
// MaxBECount. In these cases it is possible for EL0.ExactNotTaken and
@ -5873,7 +5873,7 @@ ScalarEvolution::computeExitLimitFromCond(const Loop *L,
!isa<SCEVCouldNotCompute>(BECount))
MaxBECount = BECount;
return ExitLimit(BECount, MaxBECount, NP);
return ExitLimit(BECount, MaxBECount, {&EL0.Predicates, &EL1.Predicates});
}
if (BO->getOpcode() == Instruction::Or) {
// Recurse on the operands of the or.
@ -5912,10 +5912,7 @@ ScalarEvolution::computeExitLimitFromCond(const Loop *L,
BECount = EL0.ExactNotTaken;
}
SCEVUnionPredicate NP;
NP.add(&EL0.Predicate);
NP.add(&EL1.Predicate);
return ExitLimit(BECount, MaxBECount, NP);
return ExitLimit(BECount, MaxBECount, {&EL0.Predicates, &EL1.Predicates});
}
}
@ -6300,8 +6297,7 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeShiftCompareExitLimit(
unsigned BitWidth = getTypeSizeInBits(RHS->getType());
const SCEV *UpperBound =
getConstant(getEffectiveSCEVType(RHS->getType()), BitWidth);
SCEVUnionPredicate P;
return ExitLimit(getCouldNotCompute(), UpperBound, P);
return ExitLimit(getCouldNotCompute(), UpperBound);
}
return getCouldNotCompute();
@ -7062,7 +7058,7 @@ ScalarEvolution::howFarToZero(const SCEV *V, const Loop *L, bool ControlsExit,
// effectively V != 0. We know and take advantage of the fact that this
// expression only being used in a comparison by zero context.
SCEVUnionPredicate P;
SmallPtrSet<const SCEVPredicate *, 4> Predicates;
// If the value is a constant
if (const SCEVConstant *C = dyn_cast<SCEVConstant>(V)) {
// If the value is already zero, the branch will execute zero times.
@ -7075,7 +7071,7 @@ ScalarEvolution::howFarToZero(const SCEV *V, const Loop *L, bool ControlsExit,
// Try to make this an AddRec using runtime tests, in the first X
// iterations of this loop, where X is the SCEV expression found by the
// algorithm below.
AddRec = convertSCEVToAddRecWithPredicates(V, L, P);
AddRec = convertSCEVToAddRecWithPredicates(V, L, Predicates);
if (!AddRec || AddRec->getLoop() != L)
return getCouldNotCompute();
@ -7097,7 +7093,7 @@ ScalarEvolution::howFarToZero(const SCEV *V, const Loop *L, bool ControlsExit,
// should not accept a root of 2.
const SCEV *Val = AddRec->evaluateAtIteration(R1, *this);
if (Val->isZero())
return ExitLimit(R1, R1, P); // We found a quadratic root!
return ExitLimit(R1, R1, Predicates); // We found a quadratic root!
}
}
return getCouldNotCompute();
@ -7154,7 +7150,7 @@ ScalarEvolution::howFarToZero(const SCEV *V, const Loop *L, bool ControlsExit,
else
MaxBECount = getConstant(CountDown ? CR.getUnsignedMax()
: -CR.getUnsignedMin());
return ExitLimit(Distance, MaxBECount, P);
return ExitLimit(Distance, MaxBECount, Predicates);
}
// As a special case, handle the instance where Step is a positive power of
@ -7209,7 +7205,7 @@ ScalarEvolution::howFarToZero(const SCEV *V, const Loop *L, bool ControlsExit,
const SCEV *Limit =
getZeroExtendExpr(getTruncateExpr(ModuloResult, NarrowTy), WideTy);
return ExitLimit(Limit, Limit, P);
return ExitLimit(Limit, Limit, Predicates);
}
}
@ -7222,14 +7218,14 @@ ScalarEvolution::howFarToZero(const SCEV *V, const Loop *L, bool ControlsExit,
loopHasNoAbnormalExits(AddRec->getLoop())) {
const SCEV *Exact =
getUDivExpr(Distance, CountDown ? getNegativeSCEV(Step) : Step);
return ExitLimit(Exact, Exact, P);
return ExitLimit(Exact, Exact, Predicates);
}
// Then, try to solve the above equation provided that Start is constant.
if (const SCEVConstant *StartC = dyn_cast<SCEVConstant>(Start)) {
const SCEV *E = SolveLinEquationWithOverflow(
StepC->getValue()->getValue(), -StartC->getValue()->getValue(), *this);
return ExitLimit(E, E, P);
return ExitLimit(E, E, Predicates);
}
return getCouldNotCompute();
}
@ -8634,7 +8630,7 @@ ScalarEvolution::ExitLimit
ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
const Loop *L, bool IsSigned,
bool ControlsExit, bool AllowPredicates) {
SCEVUnionPredicate P;
SmallPtrSet<const SCEVPredicate *, 4> Predicates;
// We handle only IV < Invariant
if (!isLoopInvariant(RHS, L))
return getCouldNotCompute();
@ -8646,7 +8642,7 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
// Try to make this an AddRec using runtime tests, in the first X
// iterations of this loop, where X is the SCEV expression found by the
// algorithm below.
IV = convertSCEVToAddRecWithPredicates(LHS, L, P);
IV = convertSCEVToAddRecWithPredicates(LHS, L, Predicates);
PredicatedIV = true;
}
@ -8762,14 +8758,14 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
if (isa<SCEVCouldNotCompute>(MaxBECount))
MaxBECount = BECount;
return ExitLimit(BECount, MaxBECount, P);
return ExitLimit(BECount, MaxBECount, Predicates);
}
ScalarEvolution::ExitLimit
ScalarEvolution::howManyGreaterThans(const SCEV *LHS, const SCEV *RHS,
const Loop *L, bool IsSigned,
bool ControlsExit, bool AllowPredicates) {
SCEVUnionPredicate P;
SmallPtrSet<const SCEVPredicate *, 4> Predicates;
// We handle only IV > Invariant
if (!isLoopInvariant(RHS, L))
return getCouldNotCompute();
@ -8779,7 +8775,7 @@ ScalarEvolution::howManyGreaterThans(const SCEV *LHS, const SCEV *RHS,
// Try to make this an AddRec using runtime tests, in the first X
// iterations of this loop, where X is the SCEV expression found by the
// algorithm below.
IV = convertSCEVToAddRecWithPredicates(LHS, L, P);
IV = convertSCEVToAddRecWithPredicates(LHS, L, Predicates);
// Avoid weird loops
if (!IV || IV->getLoop() != L || !IV->isAffine())
@ -8839,7 +8835,7 @@ ScalarEvolution::howManyGreaterThans(const SCEV *LHS, const SCEV *RHS,
if (isa<SCEVCouldNotCompute>(MaxBECount))
MaxBECount = BECount;
return ExitLimit(BECount, MaxBECount, P);
return ExitLimit(BECount, MaxBECount, Predicates);
}
const SCEV *SCEVAddRecExpr::getNumIterationsInRange(const ConstantRange &Range,
@ -10161,25 +10157,34 @@ namespace {
class SCEVPredicateRewriter : public SCEVRewriteVisitor<SCEVPredicateRewriter> {
public:
// Rewrites \p S in the context of a loop L and the predicate A.
// If Assume is true, rewrite is free to add further predicates to A
// such that the result will be an AddRecExpr.
/// Rewrites \p S in the context of a loop L and the SCEV predication
/// infrastructure.
///
/// If \p Pred is non-null, the SCEV expression is rewritten to respect the
/// equivalences present in \p Pred.
///
/// If \p NewPreds is non-null, rewrite is free to add further predicates to
/// \p NewPreds such that the result will be an AddRecExpr.
static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE,
SCEVUnionPredicate &A, bool Assume) {
SCEVPredicateRewriter Rewriter(L, SE, A, Assume);
SmallPtrSetImpl<const SCEVPredicate *> *NewPreds,
SCEVUnionPredicate *Pred) {
SCEVPredicateRewriter Rewriter(L, SE, NewPreds, Pred);
return Rewriter.visit(S);
}
SCEVPredicateRewriter(const Loop *L, ScalarEvolution &SE,
SCEVUnionPredicate &P, bool Assume)
: SCEVRewriteVisitor(SE), P(P), L(L), Assume(Assume) {}
SmallPtrSetImpl<const SCEVPredicate *> *NewPreds,
SCEVUnionPredicate *Pred)
: SCEVRewriteVisitor(SE), NewPreds(NewPreds), Pred(Pred), L(L) {}
const SCEV *visitUnknown(const SCEVUnknown *Expr) {
auto ExprPreds = P.getPredicatesForExpr(Expr);
for (auto *Pred : ExprPreds)
if (const auto *IPred = dyn_cast<SCEVEqualPredicate>(Pred))
if (IPred->getLHS() == Expr)
return IPred->getRHS();
if (Pred) {
auto ExprPreds = Pred->getPredicatesForExpr(Expr);
for (auto *Pred : ExprPreds)
if (const auto *IPred = dyn_cast<SCEVEqualPredicate>(Pred))
if (IPred->getLHS() == Expr)
return IPred->getRHS();
}
return Expr;
}
@ -10220,32 +10225,31 @@ private:
bool addOverflowAssumption(const SCEVAddRecExpr *AR,
SCEVWrapPredicate::IncrementWrapFlags AddedFlags) {
auto *A = SE.getWrapPredicate(AR, AddedFlags);
if (!Assume) {
if (!NewPreds) {
// Check if we've already made this assumption.
if (P.implies(A))
return true;
return false;
return Pred && Pred->implies(A);
}
P.add(A);
NewPreds->insert(A);
return true;
}
SCEVUnionPredicate &P;
SmallPtrSetImpl<const SCEVPredicate *> *NewPreds;
SCEVUnionPredicate *Pred;
const Loop *L;
bool Assume;
};
} // end anonymous namespace
const SCEV *ScalarEvolution::rewriteUsingPredicate(const SCEV *S, const Loop *L,
SCEVUnionPredicate &Preds) {
return SCEVPredicateRewriter::rewrite(S, L, *this, Preds, false);
return SCEVPredicateRewriter::rewrite(S, L, *this, nullptr, &Preds);
}
const SCEVAddRecExpr *
ScalarEvolution::convertSCEVToAddRecWithPredicates(const SCEV *S, const Loop *L,
SCEVUnionPredicate &Preds) {
SCEVUnionPredicate TransformPreds;
S = SCEVPredicateRewriter::rewrite(S, L, *this, TransformPreds, true);
const SCEVAddRecExpr *ScalarEvolution::convertSCEVToAddRecWithPredicates(
const SCEV *S, const Loop *L,
SmallPtrSetImpl<const SCEVPredicate *> &Preds) {
SmallPtrSet<const SCEVPredicate *, 4> TransformPreds;
S = SCEVPredicateRewriter::rewrite(S, L, *this, &TransformPreds, nullptr);
auto *AddRec = dyn_cast<SCEVAddRecExpr>(S);
if (!AddRec)
@ -10253,7 +10257,9 @@ ScalarEvolution::convertSCEVToAddRecWithPredicates(const SCEV *S, const Loop *L,
// Since the transformation was successful, we can now transfer the SCEV
// predicates.
Preds.add(&TransformPreds);
for (auto *P : TransformPreds)
Preds.insert(P);
return AddRec;
}
@ -10480,11 +10486,15 @@ bool PredicatedScalarEvolution::hasNoOverflow(
const SCEVAddRecExpr *PredicatedScalarEvolution::getAsAddRec(Value *V) {
const SCEV *Expr = this->getSCEV(V);
auto *New = SE.convertSCEVToAddRecWithPredicates(Expr, &L, Preds);
SmallPtrSet<const SCEVPredicate *, 4> NewPreds;
auto *New = SE.convertSCEVToAddRecWithPredicates(Expr, &L, NewPreds);
if (!New)
return nullptr;
for (auto *P : NewPreds)
Preds.add(P);
updateGeneration();
RewriteMap[SE.getSCEV(V)] = {Generation, New};
return New;