diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp index 69d09199a43d..6fd4afffbb5f 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp @@ -1571,11 +1571,22 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) { return BinaryOperator::CreateSub(Y, X); if (Constant *C = dyn_cast(Op0)) { + bool IsNegate = match(C, m_ZeroInt()); Value *X; - // C - zext(bool) -> bool ? C - 1 : C - if (match(Op1, m_ZExt(m_Value(X))) && - X->getType()->getScalarSizeInBits() == 1) + if (match(Op1, m_ZExt(m_Value(X))) && X->getType()->isIntOrIntVectorTy(1)) { + // 0 - (zext bool) --> sext bool + // C - (zext bool) --> bool ? C - 1 : C + if (IsNegate) + return CastInst::CreateSExtOrBitCast(X, I.getType()); return SelectInst::Create(X, SubOne(C), C); + } + if (match(Op1, m_SExt(m_Value(X))) && X->getType()->isIntOrIntVectorTy(1)) { + // 0 - (sext bool) --> zext bool + // C - (sext bool) --> bool ? C + 1 : C + if (IsNegate) + return CastInst::CreateZExtOrBitCast(X, I.getType()); + return SelectInst::Create(X, AddOne(C), C); + } // C - ~X == X + (1+C) if (match(Op1, m_Not(m_Value(X)))) @@ -1595,16 +1606,6 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) { Constant *C2; if (match(Op1, m_Add(m_Value(X), m_Constant(C2)))) return BinaryOperator::CreateSub(ConstantExpr::getSub(C, C2), X); - - // Fold (sub 0, (zext bool to B)) --> (sext bool to B) - if (C->isNullValue() && match(Op1, m_ZExt(m_Value(X)))) - if (X->getType()->isIntOrIntVectorTy(1)) - return CastInst::CreateSExtOrBitCast(X, Op1->getType()); - - // Fold (sub 0, (sext bool to B)) --> (zext bool to B) - if (C->isNullValue() && match(Op1, m_SExt(m_Value(X)))) - if (X->getType()->isIntOrIntVectorTy(1)) - return CastInst::CreateZExtOrBitCast(X, Op1->getType()); } const APInt *Op0C; diff --git a/llvm/test/Transforms/InstCombine/zext-bool-add-sub.ll b/llvm/test/Transforms/InstCombine/zext-bool-add-sub.ll index 68ee35a438c5..3fe5bddd1178 100644 --- a/llvm/test/Transforms/InstCombine/zext-bool-add-sub.ll +++ b/llvm/test/Transforms/InstCombine/zext-bool-add-sub.ll @@ -214,8 +214,8 @@ define <2 x i64> @sext_negate_vec(<2 x i1> %A) { define <2 x i64> @sext_negate_vec_undef_elt(<2 x i1> %A) { ; CHECK-LABEL: @sext_negate_vec_undef_elt( -; CHECK-NEXT: [[TMP1:%.*]] = zext <2 x i1> [[A:%.*]] to <2 x i64> -; CHECK-NEXT: ret <2 x i64> [[TMP1]] +; CHECK-NEXT: [[SUB:%.*]] = zext <2 x i1> [[A:%.*]] to <2 x i64> +; CHECK-NEXT: ret <2 x i64> [[SUB]] ; %ext = sext <2 x i1> %A to <2 x i64> %sub = sub <2 x i64> , %ext @@ -232,12 +232,10 @@ define i64 @sext_sub_const(i1 %A) { ret i64 %sub } -; FIXME: This doesn't correspond to the zext pattern above. We should have a select. - define i64 @sext_sub_const_extra_use(i1 %A) { ; CHECK-LABEL: @sext_sub_const_extra_use( ; CHECK-NEXT: [[EXT:%.*]] = sext i1 [[A:%.*]] to i64 -; CHECK-NEXT: [[SUB:%.*]] = sub nsw i64 42, [[EXT]] +; CHECK-NEXT: [[SUB:%.*]] = select i1 [[A]], i64 43, i64 42 ; CHECK-NEXT: call void @use(i64 [[EXT]]) ; CHECK-NEXT: ret i64 [[SUB]] ;