diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp index 2cc6362e870e..d0700399e506 100644 --- a/llvm/lib/Analysis/ScalarEvolution.cpp +++ b/llvm/lib/Analysis/ScalarEvolution.cpp @@ -10341,11 +10341,26 @@ bool ScalarEvolution::isImpliedCondBalancedTypes( // Check whether swapping the found predicate makes it the same as the // desired predicate. if (ICmpInst::getSwappedPredicate(FoundPred) == Pred) { - if (isa(RHS)) + // We can write the implication + // 0. LHS Pred RHS <- FoundLHS SwapPred FoundRHS + // using one of the following ways: + // 1. LHS Pred RHS <- FoundRHS Pred FoundLHS + // 2. RHS SwapPred LHS <- FoundLHS SwapPred FoundRHS + // 3. LHS Pred RHS <- ~FoundLHS Pred ~FoundRHS + // 4. ~LHS SwapPred ~RHS <- FoundLHS SwapPred FoundRHS + // Forms 1. and 2. require swapping the operands of one condition. Don't + // do this if it would break canonical constant/addrec ordering. + if (!isa(RHS) && !isa(LHS)) + return isImpliedCondOperands(FoundPred, RHS, LHS, FoundLHS, FoundRHS, + Context); + if (!isa(FoundRHS) && !isa(FoundLHS)) return isImpliedCondOperands(Pred, LHS, RHS, FoundRHS, FoundLHS, Context); - else - return isImpliedCondOperands(ICmpInst::getSwappedPredicate(Pred), RHS, - LHS, FoundLHS, FoundRHS, Context); + + // There's no clear preference between forms 3. and 4., try both. + return isImpliedCondOperands(FoundPred, getNotSCEV(LHS), getNotSCEV(RHS), + FoundLHS, FoundRHS, Context) || + isImpliedCondOperands(Pred, LHS, RHS, getNotSCEV(FoundLHS), + getNotSCEV(FoundRHS), Context); } // Unsigned comparison is the same as signed comparison when both the operands @@ -10768,11 +10783,7 @@ bool ScalarEvolution::isImpliedCondOperands(ICmpInst::Predicate Pred, return true; return isImpliedCondOperandsHelper(Pred, LHS, RHS, - FoundLHS, FoundRHS) || - // ~x < ~y --> x > y - isImpliedCondOperandsHelper(Pred, LHS, RHS, - getNotSCEV(FoundRHS), - getNotSCEV(FoundLHS)); + FoundLHS, FoundRHS); } /// Is MaybeMinMaxExpr an (U|S)(Min|Max) of Candidate and some other values? diff --git a/llvm/test/Analysis/ScalarEvolution/zext-wrap.ll b/llvm/test/Analysis/ScalarEvolution/zext-wrap.ll index 66bedcea7edf..d52589a15dde 100644 --- a/llvm/test/Analysis/ScalarEvolution/zext-wrap.ll +++ b/llvm/test/Analysis/ScalarEvolution/zext-wrap.ll @@ -15,7 +15,7 @@ bb.i: ; preds = %bb1.i, %bb.nph ; This cast shouldn't be folded into the addrec. ; CHECK: %tmp = zext i8 %l_95.0.i1 to i16 -; CHECK: --> (zext i8 {0,+,-1}<%bb.i> to i16){{ U: [^ ]+ S: [^ ]+}}{{ *}}Exits: 2 +; CHECK: --> (zext i8 {0,+,-1}<%bb.i> to i16){{ U: [^ ]+ S: [^ ]+}}{{ *}}Exits: 2 %tmp = zext i8 %l_95.0.i1 to i16 diff --git a/llvm/unittests/Analysis/ScalarEvolutionTest.cpp b/llvm/unittests/Analysis/ScalarEvolutionTest.cpp index 7fa588566c55..3014fa4cb379 100644 --- a/llvm/unittests/Analysis/ScalarEvolutionTest.cpp +++ b/llvm/unittests/Analysis/ScalarEvolutionTest.cpp @@ -68,6 +68,13 @@ protected: const SCEV *&RHS) { return SE.matchURem(Expr, LHS, RHS); } + + static bool isImpliedCond( + ScalarEvolution &SE, ICmpInst::Predicate Pred, const SCEV *LHS, + const SCEV *RHS, ICmpInst::Predicate FoundPred, const SCEV *FoundLHS, + const SCEV *FoundRHS) { + return SE.isImpliedCond(Pred, LHS, RHS, FoundPred, FoundLHS, FoundRHS); + } }; TEST_F(ScalarEvolutionsTest, SCEVUnknownRAUW) { @@ -1368,6 +1375,45 @@ TEST_F(ScalarEvolutionsTest, ProveImplicationViaNarrowing) { }); } +TEST_F(ScalarEvolutionsTest, ImpliedCond) { + LLVMContext C; + SMDiagnostic Err; + std::unique_ptr M = parseAssemblyString( + "define void @foo(i32 %len) { " + "entry: " + " br label %loop " + "loop: " + " %iv = phi i32 [ 0, %entry], [%iv.next, %loop] " + " %iv.next = add nsw i32 %iv, 1 " + " %cmp = icmp slt i32 %iv, %len " + " br i1 %cmp, label %loop, label %exit " + "exit:" + " ret void " + "}", + Err, C); + + ASSERT_TRUE(M && "Could not parse module?"); + ASSERT_TRUE(!verifyModule(*M) && "Must have been well formed!"); + + runWithSE(*M, "foo", [](Function &F, LoopInfo &LI, ScalarEvolution &SE) { + Instruction *IV = getInstructionByName(F, "iv"); + Type *Ty = IV->getType(); + const SCEV *Zero = SE.getZero(Ty); + const SCEV *MinusOne = SE.getMinusOne(Ty); + // {0,+,1} + const SCEV *AddRec_0_1 = SE.getSCEV(IV); + // {0,+,-1} + const SCEV *AddRec_0_N1 = SE.getNegativeSCEV(AddRec_0_1); + + // {0,+,1} > 0 -> {0,+,-1} < 0 + EXPECT_TRUE(isImpliedCond(SE, ICmpInst::ICMP_SLT, AddRec_0_N1, Zero, + ICmpInst::ICMP_SGT, AddRec_0_1, Zero)); + // {0,+,-1} < -1 -> {0,+,1} > 0 + EXPECT_TRUE(isImpliedCond(SE, ICmpInst::ICMP_SGT, AddRec_0_1, Zero, + ICmpInst::ICMP_SLT, AddRec_0_N1, MinusOne)); + }); +} + TEST_F(ScalarEvolutionsTest, MatchURem) { LLVMContext C; SMDiagnostic Err;