From 6b46a0d1e89aeef8af8e11d0def22e1a5ad15f20 Mon Sep 17 00:00:00 2001 From: Sanjoy Das Date: Wed, 9 Nov 2016 18:22:43 +0000 Subject: [PATCH] [SCEV] Refactor out a useful pattern; NFC llvm-svn: 286386 --- .../Analysis/ScalarEvolutionExpressions.h | 25 +++ llvm/lib/Analysis/ScalarEvolution.cpp | 154 +++--------------- 2 files changed, 45 insertions(+), 134 deletions(-) diff --git a/llvm/include/llvm/Analysis/ScalarEvolutionExpressions.h b/llvm/include/llvm/Analysis/ScalarEvolutionExpressions.h index 9113880ef25e..fdcd8be00dde 100644 --- a/llvm/include/llvm/Analysis/ScalarEvolutionExpressions.h +++ b/llvm/include/llvm/Analysis/ScalarEvolutionExpressions.h @@ -537,6 +537,31 @@ namespace llvm { T.visitAll(Root); } + /// Return true if any node in \p Root satisfies the predicate \p Pred. + template + bool SCEVExprContains(const SCEV *Root, PredTy Pred) { + struct FindClosure { + bool Found = false; + PredTy Pred; + + FindClosure(PredTy Pred) : Pred(Pred) {} + + bool follow(const SCEV *S) { + if (!Pred(S)) + return true; + + Found = true; + return false; + } + + bool isDone() const { return Found; } + }; + + FindClosure FC(Pred); + visitAll(Root, FC); + return FC.Found; + } + /// This visitor recursively visits a SCEV expression and re-writes it. /// The result from each visit is cached, so it will return the same /// SCEV for the same input. diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp index ce9fade782f7..c93301148178 100644 --- a/llvm/lib/Analysis/ScalarEvolution.cpp +++ b/llvm/lib/Analysis/ScalarEvolution.cpp @@ -3356,69 +3356,24 @@ const SCEV *ScalarEvolution::getCouldNotCompute() { return CouldNotCompute.get(); } - bool ScalarEvolution::checkValidity(const SCEV *S) const { - // Helper class working with SCEVTraversal to figure out if a SCEV contains - // a SCEVUnknown with null value-pointer. FindInvalidSCEVUnknown::FindOne - // is set iff if find such SCEVUnknown. - // - struct FindInvalidSCEVUnknown { - bool FindOne; - FindInvalidSCEVUnknown() { FindOne = false; } - bool follow(const SCEV *S) { - switch (static_cast(S->getSCEVType())) { - case scConstant: - return false; - case scUnknown: - if (!cast(S)->getValue()) - FindOne = true; - return false; - default: - return true; - } - } - bool isDone() const { return FindOne; } - }; + bool ContainsNulls = SCEVExprContains(S, [](const SCEV *S) { + auto *SU = dyn_cast(S); + return SU && SU->getValue() == nullptr; + }); - FindInvalidSCEVUnknown F; - SCEVTraversal ST(F); - ST.visitAll(S); - - return !F.FindOne; + return !ContainsNulls; } bool ScalarEvolution::containsAddRecurrence(const SCEV *S) { - // Helper class working with SCEVTraversal to figure out if a SCEV contains a - // sub SCEV of scAddRecExpr type. FindInvalidSCEVUnknown::FoundOne is set iff - // if such sub scAddRecExpr type SCEV is found. - struct FindAddRecurrence { - bool FoundOne; - FindAddRecurrence() : FoundOne(false) {} - - bool follow(const SCEV *S) { - switch (static_cast(S->getSCEVType())) { - case scAddRecExpr: - FoundOne = true; - case scConstant: - case scUnknown: - case scCouldNotCompute: - return false; - default: - return true; - } - } - bool isDone() const { return FoundOne; } - }; - HasRecMapType::iterator I = HasRecMap.find(S); if (I != HasRecMap.end()) return I->second; - FindAddRecurrence F; - SCEVTraversal ST(F); - ST.visitAll(S); - HasRecMap.insert({S, F.FoundOne}); - return F.FoundOne; + bool FoundAddRec = + SCEVExprContains(S, [](const SCEV *S) { return isa(S); }); + HasRecMap.insert({S, FoundAddRec}); + return FoundAddRec; } /// Try to split a SCEVAddExpr into a pair of {SCEV, ConstantInt}. @@ -8993,38 +8948,15 @@ const SCEV *SCEVAddRecExpr::getNumIterationsInRange(const ConstantRange &Range, return SE.getCouldNotCompute(); } -namespace { -struct FindUndefs { - bool Found; - FindUndefs() : Found(false) {} - - bool follow(const SCEV *S) { - if (const SCEVUnknown *C = dyn_cast(S)) { - if (isa(C->getValue())) - Found = true; - } else if (const SCEVConstant *C = dyn_cast(S)) { - if (isa(C->getValue())) - Found = true; - } - - // Keep looking if we haven't found it yet. - return !Found; - } - bool isDone() const { - // Stop recursion if we have found an undef. - return Found; - } -}; -} - // Return true when S contains at least an undef value. -static inline bool -containsUndefs(const SCEV *S) { - FindUndefs F; - SCEVTraversal ST(F); - ST.visitAll(S); - - return F.Found; +static inline bool containsUndefs(const SCEV *S) { + return SCEVExprContains(S, [](const SCEV *S) { + if (const auto *SU = dyn_cast(S)) + return isa(SU->getValue()); + else if (const auto *SC = dyn_cast(S)) + return isa(SC->getValue()); + return false; + }); } namespace { @@ -9217,40 +9149,11 @@ static bool findArrayDimensionsRec(ScalarEvolution &SE, return true; } -// Returns true when S contains at least a SCEVUnknown parameter. -static inline bool -containsParameters(const SCEV *S) { - struct FindParameter { - bool FoundParameter; - FindParameter() : FoundParameter(false) {} - - bool follow(const SCEV *S) { - if (isa(S)) { - FoundParameter = true; - // Stop recursion: we found a parameter. - return false; - } - // Keep looking. - return true; - } - bool isDone() const { - // Stop recursion if we have found a parameter. - return FoundParameter; - } - }; - - FindParameter F; - SCEVTraversal ST(F); - ST.visitAll(S); - - return F.FoundParameter; -} // Returns true when one of the SCEVs of Terms contains a SCEVUnknown parameter. -static inline bool -containsParameters(SmallVectorImpl &Terms) { +static inline bool containsParameters(SmallVectorImpl &Terms) { for (const SCEV *T : Terms) - if (containsParameters(T)) + if (SCEVExprContains(T, [](const SCEV *S) { return isa(S); })) return true; return false; } @@ -9977,24 +9880,7 @@ bool ScalarEvolution::properlyDominates(const SCEV *S, const BasicBlock *BB) { } bool ScalarEvolution::hasOperand(const SCEV *S, const SCEV *Op) const { - // Search for a SCEV expression node within an expression tree. - // Implements SCEVTraversal::Visitor. - struct SCEVSearch { - const SCEV *Node; - bool IsFound; - - SCEVSearch(const SCEV *N): Node(N), IsFound(false) {} - - bool follow(const SCEV *S) { - IsFound |= (S == Node); - return !IsFound; - } - bool isDone() const { return IsFound; } - }; - - SCEVSearch Search(Op); - visitAll(S, Search); - return Search.IsFound; + return SCEVExprContains(S, [&](const SCEV *Expr) { return Expr == Op; }); } void ScalarEvolution::forgetMemoizedResults(const SCEV *S) {