diff --git a/llvm/include/llvm/Analysis/ScalarEvolutionExpressions.h b/llvm/include/llvm/Analysis/ScalarEvolutionExpressions.h index 0245569aefe1..21e6374b66df 100644 --- a/llvm/include/llvm/Analysis/ScalarEvolutionExpressions.h +++ b/llvm/include/llvm/Analysis/ScalarEvolutionExpressions.h @@ -810,35 +810,30 @@ class Type; }; using ValueToValueMap = DenseMap; + using ValueToSCEVMapTy = DenseMap; /// The SCEVParameterRewriter takes a scalar evolution expression and updates - /// the SCEVUnknown components following the Map (Value -> Value). + /// the SCEVUnknown components following the Map (Value -> SCEV). class SCEVParameterRewriter : public SCEVRewriteVisitor { public: static const SCEV *rewrite(const SCEV *Scev, ScalarEvolution &SE, - ValueToValueMap &Map, - bool InterpretConsts = false) { - SCEVParameterRewriter Rewriter(SE, Map, InterpretConsts); + ValueToSCEVMapTy &Map) { + SCEVParameterRewriter Rewriter(SE, Map); return Rewriter.visit(Scev); } - SCEVParameterRewriter(ScalarEvolution &SE, ValueToValueMap &M, bool C) - : SCEVRewriteVisitor(SE), Map(M), InterpretConsts(C) {} + SCEVParameterRewriter(ScalarEvolution &SE, ValueToSCEVMapTy &M) + : SCEVRewriteVisitor(SE), Map(M) {} const SCEV *visitUnknown(const SCEVUnknown *Expr) { - Value *V = Expr->getValue(); - if (Map.count(V)) { - Value *NV = Map[V]; - if (InterpretConsts && isa(NV)) - return SE.getConstant(cast(NV)); - return SE.getUnknown(NV); - } - return Expr; + auto I = Map.find(Expr->getValue()); + if (I == Map.end()) + return Expr; + return I->second; } private: - ValueToValueMap ⤅ - bool InterpretConsts; + ValueToSCEVMapTy ⤅ }; using LoopToScevMapT = DenseMap; diff --git a/llvm/lib/Analysis/ScalarEvolutionDivision.cpp b/llvm/lib/Analysis/ScalarEvolutionDivision.cpp index 19bf5766f448..64e908bdf342 100644 --- a/llvm/lib/Analysis/ScalarEvolutionDivision.cpp +++ b/llvm/lib/Analysis/ScalarEvolutionDivision.cpp @@ -215,16 +215,14 @@ void SCEVDivision::visitMulExpr(const SCEVMulExpr *Numerator) { return cannotDivide(Numerator); // The Remainder is obtained by replacing Denominator by 0 in Numerator. - ValueToValueMap RewriteMap; - RewriteMap[cast(Denominator)->getValue()] = - cast(Zero)->getValue(); - Remainder = SCEVParameterRewriter::rewrite(Numerator, SE, RewriteMap, true); + ValueToSCEVMapTy RewriteMap; + RewriteMap[cast(Denominator)->getValue()] = Zero; + Remainder = SCEVParameterRewriter::rewrite(Numerator, SE, RewriteMap); if (Remainder->isZero()) { // The Quotient is obtained by replacing Denominator by 1 in Numerator. - RewriteMap[cast(Denominator)->getValue()] = - cast(One)->getValue(); - Quotient = SCEVParameterRewriter::rewrite(Numerator, SE, RewriteMap, true); + RewriteMap[cast(Denominator)->getValue()] = One; + Quotient = SCEVParameterRewriter::rewrite(Numerator, SE, RewriteMap); return; } diff --git a/llvm/unittests/Analysis/ScalarEvolutionTest.cpp b/llvm/unittests/Analysis/ScalarEvolutionTest.cpp index f78fc396ebc1..156427f56ead 100644 --- a/llvm/unittests/Analysis/ScalarEvolutionTest.cpp +++ b/llvm/unittests/Analysis/ScalarEvolutionTest.cpp @@ -149,6 +149,12 @@ static Instruction *getInstructionByName(Function &F, StringRef Name) { llvm_unreachable("Expected to find instruction!"); } +static Value *getArgByName(Function &F, StringRef Name) { + for (auto &Arg : F.args()) + if (Arg.getName() == Name) + return &Arg; + llvm_unreachable("Expected to find instruction!"); +} TEST_F(ScalarEvolutionsTest, CommutativeExprOperandOrder) { LLVMContext C; SMDiagnostic Err; @@ -1120,4 +1126,43 @@ TEST_F(ScalarEvolutionsTest, SCEVComputeConstantDifference) { }); } +TEST_F(ScalarEvolutionsTest, SCEVrewriteUnknowns) { + LLVMContext C; + SMDiagnostic Err; + std::unique_ptr M = parseAssemblyString( + "define void @foo(i32 %i) { " + "entry: " + " %cmp3 = icmp ult i32 %i, 16 " + " br i1 %cmp3, label %loop.body, label %exit " + "loop.body: " + " %iv = phi i32 [ %iv.next, %loop.body ], [ %i, %entry ] " + " %iv.next = add nsw i32 %iv, 1 " + " %cmp = icmp eq i32 %iv.next, 16 " + " br i1 %cmp, label %exit, label %loop.body " + "exit: " + " ret void " + "} ", + Err, C); + + ASSERT_TRUE(M && "Could not parse module?"); + ASSERT_TRUE(!verifyModule(*M) && "Must have been well formed!"); + + runWithSE(*M, "foo", [](Function &F, LoopInfo &LI, ScalarEvolution &SE) { + auto *ScevIV = SE.getSCEV(getInstructionByName(F, "iv")); // {0,+,1} + auto *ScevI = SE.getSCEV(getArgByName(F, "i")); // {0,+,1} + + ValueToSCEVMapTy RewriteMap; + RewriteMap[cast(ScevI)->getValue()] = + SE.getUMinExpr(ScevI, SE.getConstant(ScevI->getType(), 17)); + auto *WithUMin = SCEVParameterRewriter::rewrite(ScevIV, SE, RewriteMap); + + EXPECT_NE(WithUMin, ScevIV); + auto *AR = dyn_cast(WithUMin); + EXPECT_TRUE(AR); + EXPECT_EQ(AR->getStart(), + SE.getUMinExpr(ScevI, SE.getConstant(ScevI->getType(), 17))); + EXPECT_EQ(AR->getStepRecurrence(SE), + cast(ScevIV)->getStepRecurrence(SE)); + }); +} } // end namespace llvm