diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp index 1f132953d023..42c65051c96f 100644 --- a/llvm/lib/Analysis/ScalarEvolution.cpp +++ b/llvm/lib/Analysis/ScalarEvolution.cpp @@ -4085,37 +4085,49 @@ void ScalarEvolution::forgetSymbolicName(Instruction *PN, const SCEV *SymName) { namespace { +/// Takes SCEV S and Loop L. For each AddRec sub-expression, use its start +/// expression in case its Loop is L. If it is not L then +/// if IgnoreOtherLoops is true then use AddRec itself +/// otherwise rewrite cannot be done. +/// If SCEV contains non-invariant unknown SCEV rewrite cannot be done. class SCEVInitRewriter : public SCEVRewriteVisitor { public: - static const SCEV *rewrite(const SCEV *S, const Loop *L, - ScalarEvolution &SE) { + static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE, + bool IgnoreOtherLoops = false) { SCEVInitRewriter Rewriter(L, SE); const SCEV *Result = Rewriter.visit(S); - return Rewriter.isValid() ? Result : SE.getCouldNotCompute(); + if (Rewriter.hasSeenLoopVariantSCEVUnknown()) + return SE.getCouldNotCompute(); + return Rewriter.hasSeenOtherLoops() && !IgnoreOtherLoops + ? SE.getCouldNotCompute() + : Result; } const SCEV *visitUnknown(const SCEVUnknown *Expr) { if (!SE.isLoopInvariant(Expr, L)) - Valid = false; + SeenLoopVariantSCEVUnknown = true; return Expr; } const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) { - // Only allow AddRecExprs for this loop. + // Only re-write AddRecExprs for this loop. if (Expr->getLoop() == L) return Expr->getStart(); - Valid = false; + SeenOtherLoops = true; return Expr; } - bool isValid() { return Valid; } + bool hasSeenLoopVariantSCEVUnknown() { return SeenLoopVariantSCEVUnknown; } + + bool hasSeenOtherLoops() { return SeenOtherLoops; } private: explicit SCEVInitRewriter(const Loop *L, ScalarEvolution &SE) : SCEVRewriteVisitor(SE), L(L) {} const Loop *L; - bool Valid = true; + bool SeenLoopVariantSCEVUnknown = false; + bool SeenOtherLoops = false; }; /// This class evaluates the compare condition by matching it against the