diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
index f1766fd4fce7..63f67416f1e9 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -1021,6 +1021,33 @@ static Value *foldSignedTruncationCheck(ICmpInst *ICmp0, ICmpInst *ICmp1,
                             CxtI.getName() + ".simplified");
 }
 
+/// Reduce a pair of compares that check if a value has exactly 1 bit set.
+///
+/// Matches the 'and' of a not-equal-to-zero compare of X with an
+/// unsigned-less-than-2 compare of ctpop(X), in either operand order.
+/// \returns a new 'icmp eq ctpop(X), 1' created with \p Builder, or nullptr
+/// if the pattern does not match.
+static Value *foldIsPowerOf2(ICmpInst *Cmp0, ICmpInst *Cmp1,
+                             InstCombiner::BuilderTy &Builder) {
+  // Handle 'and' commutation: make the not-equal compare the first operand.
+  if (Cmp1->getPredicate() == ICmpInst::ICMP_NE)
+    std::swap(Cmp0, Cmp1);
+
+  // (X != 0) && (ctpop(X) u< 2) --> ctpop(X) == 1
+  CmpInst::Predicate Pred0, Pred1;
+  Value *X;
+  if (match(Cmp0, m_ICmp(Pred0, m_Value(X), m_ZeroInt())) &&
+      match(Cmp1, m_ICmp(Pred1, m_Intrinsic<Intrinsic::ctpop>(m_Specific(X)),
+                         m_SpecificInt(2))) &&
+      Pred0 == ICmpInst::ICMP_NE && Pred1 == ICmpInst::ICMP_ULT) {
+    // The match above pins ctpop(X) as operand 0 of the unsigned compare;
+    // reuse that existing call rather than creating a new one.
+    Value *CtPop = Cmp1->getOperand(0);
+    return Builder.CreateICmpEQ(CtPop, ConstantInt::get(CtPop->getType(), 1));
+  }
+  return nullptr;
+}
+
 /// Fold (icmp)&(icmp) if possible.
 Value *InstCombiner::foldAndOfICmps(ICmpInst *LHS, ICmpInst *RHS,
                                     Instruction &CxtI) {
@@ -1063,6 +1090,9 @@ Value *InstCombiner::foldAndOfICmps(ICmpInst *LHS, ICmpInst *RHS,
   if (Value *V = foldSignedTruncationCheck(LHS, RHS, CxtI, Builder))
     return V;
 
+  if (Value *V = foldIsPowerOf2(LHS, RHS, Builder))
+    return V;
+
   // This only handles icmp of constants: (icmp1 A, C1) & (icmp2 B, C2).
   Value *LHS0 = LHS->getOperand(0), *RHS0 = RHS->getOperand(0);
   ConstantInt *LHSC = dyn_cast<ConstantInt>(LHS->getOperand(1));
diff --git a/llvm/test/Transforms/InstCombine/ispow2.ll b/llvm/test/Transforms/InstCombine/ispow2.ll
index 6ce825be20b1..513eb85421d9 100644
--- a/llvm/test/Transforms/InstCombine/ispow2.ll
+++ b/llvm/test/Transforms/InstCombine/ispow2.ll
@@ -187,13 +187,13 @@ define i1 @is_pow2or0_negate_op_extra_use2(i32 %x) {
 declare i32 @llvm.ctpop.i32(i32)
 declare <2 x i8> @llvm.ctpop.v2i8(<2 x i8>)
 
+; (X != 0) && (ctpop(X) u< 2) --> ctpop(X) == 1
+
 define i1 @is_pow2_ctpop(i32 %x) {
 ; CHECK-LABEL: @is_pow2_ctpop(
 ; CHECK-NEXT:    [[T0:%.*]] = tail call i32 @llvm.ctpop.i32(i32 [[X:%.*]]), !range !0
-; CHECK-NEXT:    [[CMP:%.*]] = icmp ult i32 [[T0]], 2
-; CHECK-NEXT:    [[NOTZERO:%.*]] = icmp ne i32 [[X]], 0
-; CHECK-NEXT:    [[R:%.*]] = and i1 [[NOTZERO]], [[CMP]]
-; CHECK-NEXT:    ret i1 [[R]]
+; CHECK-NEXT:    [[TMP1:%.*]] = icmp eq i32 [[T0]], 1
+; CHECK-NEXT:    ret i1 [[TMP1]]
 ;
   %t0 = tail call i32 @llvm.ctpop.i32(i32 %x)
   %cmp = icmp ult i32 %t0, 2
@@ -212,8 +212,8 @@ define i1 @is_pow2_ctpop_extra_uses(i32 %x) {
 ; CHECK-NEXT:    call void @use_i1(i1 [[CMP]])
 ; CHECK-NEXT:    [[NOTZERO:%.*]] = icmp ne i32 [[X]], 0
 ; CHECK-NEXT:    call void @use_i1(i1 [[NOTZERO]])
-; CHECK-NEXT:    [[R:%.*]] = and i1 [[NOTZERO]], [[CMP]]
-; CHECK-NEXT:    ret i1 [[R]]
+; CHECK-NEXT:    [[TMP1:%.*]] = icmp eq i32 [[T0]], 1
+; CHECK-NEXT:    ret i1 [[TMP1]]
 ;
   %t0 = tail call i32 @llvm.ctpop.i32(i32 %x)
   %cmp = icmp ult i32 %t0, 2
@@ -229,10 +229,8 @@ define i1 @is_pow2_ctpop_extra_uses(i32 %x) {
 define <2 x i1> @is_pow2_ctpop_commute_vec(<2 x i8> %x) {
 ; CHECK-LABEL: @is_pow2_ctpop_commute_vec(
 ; CHECK-NEXT:    [[T0:%.*]] = tail call <2 x i8> @llvm.ctpop.v2i8(<2 x i8> [[X:%.*]])
-; CHECK-NEXT:    [[CMP:%.*]] = icmp ult <2 x i8> [[T0]], <i8 2, i8 2>
-; CHECK-NEXT:    [[NOTZERO:%.*]] = icmp ne <2 x i8> [[X]], zeroinitializer
-; CHECK-NEXT:    [[R:%.*]] = and <2 x i1> [[CMP]], [[NOTZERO]]
-; CHECK-NEXT:    ret <2 x i1> [[R]]
+; CHECK-NEXT:    [[TMP1:%.*]] = icmp eq <2 x i8> [[T0]], <i8 1, i8 1>
+; CHECK-NEXT:    ret <2 x i1> [[TMP1]]
 ;
   %t0 = tail call <2 x i8> @llvm.ctpop.v2i8(<2 x i8> %x)
   %cmp = icmp ult <2 x i8> %t0, <i8 2, i8 2>
@@ -241,6 +239,8 @@ define <2 x i1> @is_pow2_ctpop_commute_vec(<2 x i8> %x) {
   ret <2 x i1> %r
 }
 
+; Negative test - wrong constant.
+
 define i1 @is_pow2_ctpop_wrong_cmp_op1(i32 %x) {
 ; CHECK-LABEL: @is_pow2_ctpop_wrong_cmp_op1(
 ; CHECK-NEXT:    [[T0:%.*]] = tail call i32 @llvm.ctpop.i32(i32 [[X:%.*]]), !range !0
@@ -256,6 +256,8 @@ define i1 @is_pow2_ctpop_wrong_cmp_op1(i32 %x) {
   ret i1 %r
 }
 
+; Negative test - wrong constant.
+
 define i1 @is_pow2_ctpop_wrong_cmp_op2(i32 %x) {
 ; CHECK-LABEL: @is_pow2_ctpop_wrong_cmp_op2(
 ; CHECK-NEXT:    [[T0:%.*]] = tail call i32 @llvm.ctpop.i32(i32 [[X:%.*]]), !range !0
@@ -271,6 +273,8 @@ define i1 @is_pow2_ctpop_wrong_cmp_op2(i32 %x) {
   ret i1 %r
 }
 
+; Negative test - wrong predicate.
+
 define i1 @is_pow2_ctpop_wrong_pred1(i32 %x) {
 ; CHECK-LABEL: @is_pow2_ctpop_wrong_pred1(
 ; CHECK-NEXT:    [[T0:%.*]] = tail call i32 @llvm.ctpop.i32(i32 [[X:%.*]]), !range !0
@@ -286,6 +290,8 @@ define i1 @is_pow2_ctpop_wrong_pred1(i32 %x) {
   ret i1 %r
 }
 
+; Negative test - wrong predicate.
+
 define i1 @is_pow2_ctpop_wrong_pred2(i32 %x) {
 ; CHECK-LABEL: @is_pow2_ctpop_wrong_pred2(
 ; CHECK-NEXT:    [[T0:%.*]] = tail call i32 @llvm.ctpop.i32(i32 [[X:%.*]]), !range !0