[SCEV] Refactor out a useful pattern; NFC

llvm-svn: 286386
This commit is contained in:
Sanjoy Das 2016-11-09 18:22:43 +00:00
parent a9cadeddd4
commit 6b46a0d1e8
2 changed files with 45 additions and 134 deletions

View File

@ -537,6 +537,31 @@ namespace llvm {
T.visitAll(Root);
}
/// Return true if any node in \p Root satisfies the predicate \p Pred.
template <typename PredTy>
bool SCEVExprContains(const SCEV *Root, PredTy Pred) {
struct FindClosure {
bool Found = false;
PredTy Pred;
FindClosure(PredTy Pred) : Pred(Pred) {}
bool follow(const SCEV *S) {
if (!Pred(S))
return true;
Found = true;
return false;
}
bool isDone() const { return Found; }
};
FindClosure FC(Pred);
visitAll(Root, FC);
return FC.Found;
}
/// This visitor recursively visits a SCEV expression and re-writes it.
/// The result from each visit is cached, so it will return the same
/// SCEV for the same input.

View File

@ -3356,69 +3356,24 @@ const SCEV *ScalarEvolution::getCouldNotCompute() {
return CouldNotCompute.get();
}
bool ScalarEvolution::checkValidity(const SCEV *S) const {
// Helper class working with SCEVTraversal to figure out if a SCEV contains
// a SCEVUnknown with null value-pointer. FindInvalidSCEVUnknown::FindOne
// is set iff if find such SCEVUnknown.
//
struct FindInvalidSCEVUnknown {
bool FindOne;
FindInvalidSCEVUnknown() { FindOne = false; }
bool follow(const SCEV *S) {
switch (static_cast<SCEVTypes>(S->getSCEVType())) {
case scConstant:
return false;
case scUnknown:
if (!cast<SCEVUnknown>(S)->getValue())
FindOne = true;
return false;
default:
return true;
}
}
bool isDone() const { return FindOne; }
};
bool ContainsNulls = SCEVExprContains(S, [](const SCEV *S) {
auto *SU = dyn_cast<SCEVUnknown>(S);
return SU && SU->getValue() == nullptr;
});
FindInvalidSCEVUnknown F;
SCEVTraversal<FindInvalidSCEVUnknown> ST(F);
ST.visitAll(S);
return !F.FindOne;
return !ContainsNulls;
}
bool ScalarEvolution::containsAddRecurrence(const SCEV *S) {
// Helper class working with SCEVTraversal to figure out if a SCEV contains a
// sub SCEV of scAddRecExpr type. FindInvalidSCEVUnknown::FoundOne is set iff
// if such sub scAddRecExpr type SCEV is found.
struct FindAddRecurrence {
bool FoundOne;
FindAddRecurrence() : FoundOne(false) {}
bool follow(const SCEV *S) {
switch (static_cast<SCEVTypes>(S->getSCEVType())) {
case scAddRecExpr:
FoundOne = true;
case scConstant:
case scUnknown:
case scCouldNotCompute:
return false;
default:
return true;
}
}
bool isDone() const { return FoundOne; }
};
HasRecMapType::iterator I = HasRecMap.find(S);
if (I != HasRecMap.end())
return I->second;
FindAddRecurrence F;
SCEVTraversal<FindAddRecurrence> ST(F);
ST.visitAll(S);
HasRecMap.insert({S, F.FoundOne});
return F.FoundOne;
bool FoundAddRec =
SCEVExprContains(S, [](const SCEV *S) { return isa<SCEVAddRecExpr>(S); });
HasRecMap.insert({S, FoundAddRec});
return FoundAddRec;
}
/// Try to split a SCEVAddExpr into a pair of {SCEV, ConstantInt}.
@ -8993,38 +8948,15 @@ const SCEV *SCEVAddRecExpr::getNumIterationsInRange(const ConstantRange &Range,
return SE.getCouldNotCompute();
}
namespace {
struct FindUndefs {
bool Found;
FindUndefs() : Found(false) {}
bool follow(const SCEV *S) {
if (const SCEVUnknown *C = dyn_cast<SCEVUnknown>(S)) {
if (isa<UndefValue>(C->getValue()))
Found = true;
} else if (const SCEVConstant *C = dyn_cast<SCEVConstant>(S)) {
if (isa<UndefValue>(C->getValue()))
Found = true;
}
// Keep looking if we haven't found it yet.
return !Found;
}
bool isDone() const {
// Stop recursion if we have found an undef.
return Found;
}
};
}
// Return true when S contains at least an undef value.
static inline bool
containsUndefs(const SCEV *S) {
FindUndefs F;
SCEVTraversal<FindUndefs> ST(F);
ST.visitAll(S);
return F.Found;
static inline bool containsUndefs(const SCEV *S) {
return SCEVExprContains(S, [](const SCEV *S) {
if (const auto *SU = dyn_cast<SCEVUnknown>(S))
return isa<UndefValue>(SU->getValue());
else if (const auto *SC = dyn_cast<SCEVConstant>(S))
return isa<UndefValue>(SC->getValue());
return false;
});
}
namespace {
@ -9217,40 +9149,11 @@ static bool findArrayDimensionsRec(ScalarEvolution &SE,
return true;
}
// Returns true when S contains at least a SCEVUnknown parameter.
static inline bool
containsParameters(const SCEV *S) {
struct FindParameter {
bool FoundParameter;
FindParameter() : FoundParameter(false) {}
bool follow(const SCEV *S) {
if (isa<SCEVUnknown>(S)) {
FoundParameter = true;
// Stop recursion: we found a parameter.
return false;
}
// Keep looking.
return true;
}
bool isDone() const {
// Stop recursion if we have found a parameter.
return FoundParameter;
}
};
FindParameter F;
SCEVTraversal<FindParameter> ST(F);
ST.visitAll(S);
return F.FoundParameter;
}
// Returns true when one of the SCEVs of Terms contains a SCEVUnknown parameter.
static inline bool
containsParameters(SmallVectorImpl<const SCEV *> &Terms) {
static inline bool containsParameters(SmallVectorImpl<const SCEV *> &Terms) {
for (const SCEV *T : Terms)
if (containsParameters(T))
if (SCEVExprContains(T, [](const SCEV *S) { return isa<SCEVUnknown>(S); }))
return true;
return false;
}
@ -9977,24 +9880,7 @@ bool ScalarEvolution::properlyDominates(const SCEV *S, const BasicBlock *BB) {
}
bool ScalarEvolution::hasOperand(const SCEV *S, const SCEV *Op) const {
// Search for a SCEV expression node within an expression tree.
// Implements SCEVTraversal::Visitor.
struct SCEVSearch {
const SCEV *Node;
bool IsFound;
SCEVSearch(const SCEV *N): Node(N), IsFound(false) {}
bool follow(const SCEV *S) {
IsFound |= (S == Node);
return !IsFound;
}
bool isDone() const { return IsFound; }
};
SCEVSearch Search(Op);
visitAll(S, Search);
return Search.IsFound;
return SCEVExprContains(S, [&](const SCEV *Expr) { return Expr == Op; });
}
void ScalarEvolution::forgetMemoizedResults(const SCEV *S) {