From 2c2568f39ec641aa8f1dcc011f2ce642c2d3423f Mon Sep 17 00:00:00 2001 From: Sanjay Patel Date: Thu, 14 Apr 2022 14:02:56 -0400 Subject: [PATCH] [InstCombine] canonicalize select with signbit test This is part of solving issue #54750 - in that example we have both forms of the compare and do not recognize the equivalence. --- .../InstCombine/InstCombineSelect.cpp | 17 ++++ llvm/test/Transforms/InstCombine/ashr-lshr.ll | 80 +++++++++---------- .../Transforms/InstCombine/logical-select.ll | 4 +- .../InstCombine/truncating-saturate.ll | 4 +- 4 files changed, 61 insertions(+), 44 deletions(-) diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp index 7fc3f5cfe69a..7dea7433fd55 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp @@ -1602,6 +1602,23 @@ Instruction *InstCombinerImpl::foldSelectInstWithICmp(SelectInst &SI, } } + // Canonicalize a signbit condition to use zero constant by swapping: + // (CmpLHS > -1) ? TV : FV --> (CmpLHS < 0) ? FV : TV + // To avoid conflicts (infinite loops) with other canonicalizations, this is + // not applied with any constant select arm. + if (Pred == ICmpInst::ICMP_SGT && match(CmpRHS, m_AllOnes()) && + !match(TrueVal, m_Constant()) && !match(FalseVal, m_Constant()) && + ICI->hasOneUse()) { + InstCombiner::BuilderTy::InsertPointGuard Guard(Builder); + Builder.SetInsertPoint(&SI); + Value *IsNeg = Builder.CreateICmpSLT( + CmpLHS, ConstantInt::getNullValue(CmpLHS->getType()), ICI->getName()); + replaceOperand(SI, 0, IsNeg); + SI.swapValues(); + SI.swapProfMetadata(); + return &SI; + } + // FIXME: This code is nearly duplicated in InstSimplify. Using/refactoring // decomposeBitTestICmp() might help. { diff --git a/llvm/test/Transforms/InstCombine/ashr-lshr.ll b/llvm/test/Transforms/InstCombine/ashr-lshr.ll index b1afb975435f..d14258fa0482 100644 --- a/llvm/test/Transforms/InstCombine/ashr-lshr.ll +++ b/llvm/test/Transforms/InstCombine/ashr-lshr.ll @@ -3,8 +3,8 @@ define i32 @ashr_lshr_exact_ashr_only(i32 %x, i32 %y) { ; CHECK-LABEL: @ashr_lshr_exact_ashr_only( -; CHECK-NEXT: [[CMP1:%.*]] = ashr i32 [[X:%.*]], [[Y:%.*]] -; CHECK-NEXT: ret i32 [[CMP1]] +; CHECK-NEXT: [[CMP12:%.*]] = ashr i32 [[X:%.*]], [[Y:%.*]] +; CHECK-NEXT: ret i32 [[CMP12]] ; %cmp = icmp sgt i32 %x, -1 %l = lshr i32 %x, %y @@ -15,8 +15,8 @@ define i32 @ashr_lshr_exact_ashr_only(i32 %x, i32 %y) { define i32 @ashr_lshr_no_exact(i32 %x, i32 %y) { ; CHECK-LABEL: @ashr_lshr_no_exact( -; CHECK-NEXT: [[CMP1:%.*]] = ashr i32 [[X:%.*]], [[Y:%.*]] -; CHECK-NEXT: ret i32 [[CMP1]] +; CHECK-NEXT: [[CMP12:%.*]] = ashr i32 [[X:%.*]], [[Y:%.*]] +; CHECK-NEXT: ret i32 [[CMP12]] ; %cmp = icmp sgt i32 %x, -1 %l = lshr i32 %x, %y @@ -27,8 +27,8 @@ define i32 @ashr_lshr_no_exact(i32 %x, i32 %y) { define i32 @ashr_lshr_exact_both(i32 %x, i32 %y) { ; CHECK-LABEL: @ashr_lshr_exact_both( -; CHECK-NEXT: [[CMP1:%.*]] = ashr exact i32 [[X:%.*]], [[Y:%.*]] -; CHECK-NEXT: ret i32 [[CMP1]] +; CHECK-NEXT: [[CMP12:%.*]] = ashr exact i32 [[X:%.*]], [[Y:%.*]] +; CHECK-NEXT: ret i32 [[CMP12]] ; %cmp = icmp sgt i32 %x, -1 %l = lshr exact i32 %x, %y @@ -39,8 +39,8 @@ define i32 @ashr_lshr_exact_both(i32 %x, i32 %y) { define i32 @ashr_lshr_exact_lshr_only(i32 %x, i32 %y) { ; CHECK-LABEL: @ashr_lshr_exact_lshr_only( -; CHECK-NEXT: [[CMP1:%.*]] = ashr i32 [[X:%.*]], [[Y:%.*]] -; CHECK-NEXT: ret i32 [[CMP1]] +; CHECK-NEXT: [[CMP12:%.*]] = ashr i32 [[X:%.*]], [[Y:%.*]] +; CHECK-NEXT: ret i32 [[CMP12]] ; %cmp = icmp sgt i32 %x, -1 %l = lshr exact i32 %x, %y @@ -63,8 +63,8 @@ define i32 @ashr_lshr2(i32 %x, i32 %y) { define <2 x i32> @ashr_lshr_splat_vec(<2 x i32> %x, <2 x i32> %y) { ; CHECK-LABEL: @ashr_lshr_splat_vec( -; CHECK-NEXT: [[CMP1:%.*]] = ashr <2 x i32> [[X:%.*]], [[Y:%.*]] -; CHECK-NEXT: ret <2 x i32> [[CMP1]] +; CHECK-NEXT: [[CMP12:%.*]] = ashr <2 x i32> [[X:%.*]], [[Y:%.*]] +; CHECK-NEXT: ret <2 x i32> [[CMP12]] ; %cmp = icmp sgt <2 x i32> %x, %l = lshr <2 x i32> %x, %y @@ -75,8 +75,8 @@ define <2 x i32> @ashr_lshr_splat_vec(<2 x i32> %x, <2 x i32> %y) { define <2 x i32> @ashr_lshr_splat_vec2(<2 x i32> %x, <2 x i32> %y) { ; CHECK-LABEL: @ashr_lshr_splat_vec2( -; CHECK-NEXT: [[CMP1:%.*]] = ashr exact <2 x i32> [[X:%.*]], [[Y:%.*]] -; CHECK-NEXT: ret <2 x i32> [[CMP1]] +; CHECK-NEXT: [[CMP12:%.*]] = ashr exact <2 x i32> [[X:%.*]], [[Y:%.*]] +; CHECK-NEXT: ret <2 x i32> [[CMP12]] ; %cmp = icmp sgt <2 x i32> %x, %l = lshr exact <2 x i32> %x, %y @@ -87,8 +87,8 @@ define <2 x i32> @ashr_lshr_splat_vec2(<2 x i32> %x, <2 x i32> %y) { define <2 x i32> @ashr_lshr_splat_vec3(<2 x i32> %x, <2 x i32> %y) { ; CHECK-LABEL: @ashr_lshr_splat_vec3( -; CHECK-NEXT: [[CMP1:%.*]] = ashr <2 x i32> [[X:%.*]], [[Y:%.*]] -; CHECK-NEXT: ret <2 x i32> [[CMP1]] +; CHECK-NEXT: [[CMP12:%.*]] = ashr <2 x i32> [[X:%.*]], [[Y:%.*]] +; CHECK-NEXT: ret <2 x i32> [[CMP12]] ; %cmp = icmp sgt <2 x i32> %x, %l = lshr exact <2 x i32> %x, %y @@ -99,8 +99,8 @@ define <2 x i32> @ashr_lshr_splat_vec3(<2 x i32> %x, <2 x i32> %y) { define <2 x i32> @ashr_lshr_splat_vec4(<2 x i32> %x, <2 x i32> %y) { ; CHECK-LABEL: @ashr_lshr_splat_vec4( -; CHECK-NEXT: [[CMP1:%.*]] = ashr <2 x i32> [[X:%.*]], [[Y:%.*]] -; CHECK-NEXT: ret <2 x i32> [[CMP1]] +; CHECK-NEXT: [[CMP12:%.*]] = ashr <2 x i32> [[X:%.*]], [[Y:%.*]] +; CHECK-NEXT: ret <2 x i32> [[CMP12]] ; %cmp = icmp sgt <2 x i32> %x, %l = lshr <2 x i32> %x, %y @@ -171,8 +171,8 @@ define i32 @ashr_lshr_cst(i32 %x, i32 %y) { define i32 @ashr_lshr_cst2(i32 %x, i32 %y) { ; CHECK-LABEL: @ashr_lshr_cst2( -; CHECK-NEXT: [[CMP1:%.*]] = ashr i32 [[X:%.*]], 8 -; CHECK-NEXT: ret i32 [[CMP1]] +; CHECK-NEXT: [[CMP12:%.*]] = ashr i32 [[X:%.*]], 8 +; CHECK-NEXT: ret i32 [[CMP12]] ; %cmp = icmp sgt i32 %x, -1 %l = lshr i32 %x, 8 @@ -231,8 +231,8 @@ define <2 x i32> @ashr_lshr_inv_nonsplat_vec(<2 x i32> %x, <2 x i32> %y) { define <2 x i32> @ashr_lshr_vec_undef(<2 x i32> %x, <2 x i32> %y) { ; CHECK-LABEL: @ashr_lshr_vec_undef( -; CHECK-NEXT: [[CMP1:%.*]] = ashr <2 x i32> [[X:%.*]], [[Y:%.*]] -; CHECK-NEXT: ret <2 x i32> [[CMP1]] +; CHECK-NEXT: [[CMP12:%.*]] = ashr <2 x i32> [[X:%.*]], [[Y:%.*]] +; CHECK-NEXT: ret <2 x i32> [[CMP12]] ; %cmp = icmp sgt <2 x i32> %x, %l = lshr <2 x i32> %x, %y @@ -317,10 +317,10 @@ define i32 @ashr_lshr_shift_wrong_pred(i32 %x, i32 %y, i32 %z) { define i32 @ashr_lshr_shift_wrong_pred2(i32 %x, i32 %y, i32 %z) { ; CHECK-LABEL: @ashr_lshr_shift_wrong_pred2( -; CHECK-NEXT: [[CMP:%.*]] = icmp sgt i32 [[Z:%.*]], -1 ; CHECK-NEXT: [[L:%.*]] = lshr i32 [[X:%.*]], [[Y:%.*]] ; CHECK-NEXT: [[R:%.*]] = ashr i32 [[X]], [[Y]] -; CHECK-NEXT: [[RET:%.*]] = select i1 [[CMP]], i32 [[L]], i32 [[R]] +; CHECK-NEXT: [[CMP1:%.*]] = icmp slt i32 [[Z:%.*]], 0 +; CHECK-NEXT: [[RET:%.*]] = select i1 [[CMP1]], i32 [[R]], i32 [[L]] ; CHECK-NEXT: ret i32 [[RET]] ; %cmp = icmp sge i32 %z, 0 @@ -332,10 +332,10 @@ define i32 @ashr_lshr_shift_wrong_pred2(i32 %x, i32 %y, i32 %z) { define i32 @ashr_lshr_wrong_operands(i32 %x, i32 %y) { ; CHECK-LABEL: @ashr_lshr_wrong_operands( -; CHECK-NEXT: [[CMP:%.*]] = icmp sgt i32 [[X:%.*]], -1 -; CHECK-NEXT: [[L:%.*]] = lshr i32 [[X]], [[Y:%.*]] +; CHECK-NEXT: [[L:%.*]] = lshr i32 [[X:%.*]], [[Y:%.*]] ; CHECK-NEXT: [[R:%.*]] = ashr i32 [[X]], [[Y]] -; CHECK-NEXT: [[RET:%.*]] = select i1 [[CMP]], i32 [[R]], i32 [[L]] +; CHECK-NEXT: [[CMP1:%.*]] = icmp slt i32 [[X]], 0 +; CHECK-NEXT: [[RET:%.*]] = select i1 [[CMP1]], i32 [[L]], i32 [[R]] ; CHECK-NEXT: ret i32 [[RET]] ; %cmp = icmp sge i32 %x, 0 @@ -347,10 +347,10 @@ define i32 @ashr_lshr_wrong_operands(i32 %x, i32 %y) { define i32 @ashr_lshr_no_ashr(i32 %x, i32 %y) { ; CHECK-LABEL: @ashr_lshr_no_ashr( -; CHECK-NEXT: [[CMP:%.*]] = icmp sgt i32 [[X:%.*]], -1 -; CHECK-NEXT: [[L:%.*]] = lshr i32 [[X]], [[Y:%.*]] +; CHECK-NEXT: [[L:%.*]] = lshr i32 [[X:%.*]], [[Y:%.*]] ; CHECK-NEXT: [[R:%.*]] = xor i32 [[X]], [[Y]] -; CHECK-NEXT: [[RET:%.*]] = select i1 [[CMP]], i32 [[L]], i32 [[R]] +; CHECK-NEXT: [[CMP1:%.*]] = icmp slt i32 [[X]], 0 +; CHECK-NEXT: [[RET:%.*]] = select i1 [[CMP1]], i32 [[R]], i32 [[L]] ; CHECK-NEXT: ret i32 [[RET]] ; %cmp = icmp sge i32 %x, 0 @@ -362,10 +362,10 @@ define i32 @ashr_lshr_no_ashr(i32 %x, i32 %y) { define i32 @ashr_lshr_shift_amt_mismatch(i32 %x, i32 %y, i32 %z) { ; CHECK-LABEL: @ashr_lshr_shift_amt_mismatch( -; CHECK-NEXT: [[CMP:%.*]] = icmp sgt i32 [[X:%.*]], -1 -; CHECK-NEXT: [[L:%.*]] = lshr i32 [[X]], [[Y:%.*]] +; CHECK-NEXT: [[L:%.*]] = lshr i32 [[X:%.*]], [[Y:%.*]] ; CHECK-NEXT: [[R:%.*]] = ashr i32 [[X]], [[Z:%.*]] -; CHECK-NEXT: [[RET:%.*]] = select i1 [[CMP]], i32 [[L]], i32 [[R]] +; CHECK-NEXT: [[CMP1:%.*]] = icmp slt i32 [[X]], 0 +; CHECK-NEXT: [[RET:%.*]] = select i1 [[CMP1]], i32 [[R]], i32 [[L]] ; CHECK-NEXT: ret i32 [[RET]] ; %cmp = icmp sge i32 %x, 0 @@ -377,10 +377,10 @@ define i32 @ashr_lshr_shift_amt_mismatch(i32 %x, i32 %y, i32 %z) { define i32 @ashr_lshr_shift_base_mismatch(i32 %x, i32 %y, i32 %z) { ; CHECK-LABEL: @ashr_lshr_shift_base_mismatch( -; CHECK-NEXT: [[CMP:%.*]] = icmp sgt i32 [[X:%.*]], -1 -; CHECK-NEXT: [[L:%.*]] = lshr i32 [[X]], [[Y:%.*]] +; CHECK-NEXT: [[L:%.*]] = lshr i32 [[X:%.*]], [[Y:%.*]] ; CHECK-NEXT: [[R:%.*]] = ashr i32 [[Z:%.*]], [[Y]] -; CHECK-NEXT: [[RET:%.*]] = select i1 [[CMP]], i32 [[L]], i32 [[R]] +; CHECK-NEXT: [[CMP1:%.*]] = icmp slt i32 [[X]], 0 +; CHECK-NEXT: [[RET:%.*]] = select i1 [[CMP1]], i32 [[R]], i32 [[L]] ; CHECK-NEXT: ret i32 [[RET]] ; %cmp = icmp sge i32 %x, 0 @@ -392,10 +392,10 @@ define i32 @ashr_lshr_shift_base_mismatch(i32 %x, i32 %y, i32 %z) { define i32 @ashr_lshr_no_lshr(i32 %x, i32 %y) { ; CHECK-LABEL: @ashr_lshr_no_lshr( -; CHECK-NEXT: [[CMP:%.*]] = icmp sgt i32 [[X:%.*]], -1 -; CHECK-NEXT: [[L:%.*]] = add i32 [[X]], [[Y:%.*]] +; CHECK-NEXT: [[L:%.*]] = add i32 [[X:%.*]], [[Y:%.*]] ; CHECK-NEXT: [[R:%.*]] = ashr i32 [[X]], [[Y]] -; CHECK-NEXT: [[RET:%.*]] = select i1 [[CMP]], i32 [[L]], i32 [[R]] +; CHECK-NEXT: [[CMP1:%.*]] = icmp slt i32 [[X]], 0 +; CHECK-NEXT: [[RET:%.*]] = select i1 [[CMP1]], i32 [[R]], i32 [[L]] ; CHECK-NEXT: ret i32 [[RET]] ; %cmp = icmp sge i32 %x, 0 @@ -422,10 +422,10 @@ define <2 x i32> @ashr_lshr_vec_wrong_pred(<2 x i32> %x, <2 x i32> %y) { define <2 x i32> @ashr_lshr_inv_vec_wrong_pred(<2 x i32> %x, <2 x i32> %y) { ; CHECK-LABEL: @ashr_lshr_inv_vec_wrong_pred( -; CHECK-NEXT: [[CMP:%.*]] = icmp sgt <2 x i32> [[X:%.*]], -; CHECK-NEXT: [[L:%.*]] = lshr <2 x i32> [[X]], [[Y:%.*]] +; CHECK-NEXT: [[L:%.*]] = lshr <2 x i32> [[X:%.*]], [[Y:%.*]] ; CHECK-NEXT: [[R:%.*]] = ashr <2 x i32> [[X]], [[Y]] -; CHECK-NEXT: [[RET:%.*]] = select <2 x i1> [[CMP]], <2 x i32> [[R]], <2 x i32> [[L]] +; CHECK-NEXT: [[CMP1:%.*]] = icmp slt <2 x i32> [[X]], zeroinitializer +; CHECK-NEXT: [[RET:%.*]] = select <2 x i1> [[CMP1]], <2 x i32> [[L]], <2 x i32> [[R]] ; CHECK-NEXT: ret <2 x i32> [[RET]] ; %cmp = icmp sge <2 x i32> %x, zeroinitializer diff --git a/llvm/test/Transforms/InstCombine/logical-select.ll b/llvm/test/Transforms/InstCombine/logical-select.ll index 826ae9062df1..921c5223dc92 100644 --- a/llvm/test/Transforms/InstCombine/logical-select.ll +++ b/llvm/test/Transforms/InstCombine/logical-select.ll @@ -758,10 +758,10 @@ define <2 x i16> @bitcast_vec_cond_commute3(<4 x i8> %cond, <2 x i16> %pc, <2 x ; CHECK-LABEL: @bitcast_vec_cond_commute3( ; CHECK-NEXT: [[C:%.*]] = mul <2 x i16> [[PC:%.*]], [[PC]] ; CHECK-NEXT: [[D:%.*]] = mul <2 x i16> [[PD:%.*]], [[PD]] -; CHECK-NEXT: [[DOTNOT:%.*]] = icmp sgt <4 x i8> [[COND:%.*]], ; CHECK-NEXT: [[TMP1:%.*]] = bitcast <2 x i16> [[D]] to <4 x i8> ; CHECK-NEXT: [[TMP2:%.*]] = bitcast <2 x i16> [[C]] to <4 x i8> -; CHECK-NEXT: [[TMP3:%.*]] = select <4 x i1> [[DOTNOT]], <4 x i8> [[TMP2]], <4 x i8> [[TMP1]] +; CHECK-NEXT: [[DOTNOT2:%.*]] = icmp slt <4 x i8> [[COND:%.*]], zeroinitializer +; CHECK-NEXT: [[TMP3:%.*]] = select <4 x i1> [[DOTNOT2]], <4 x i8> [[TMP1]], <4 x i8> [[TMP2]] ; CHECK-NEXT: [[TMP4:%.*]] = bitcast <4 x i8> [[TMP3]] to <2 x i16> ; CHECK-NEXT: ret <2 x i16> [[TMP4]] ; diff --git a/llvm/test/Transforms/InstCombine/truncating-saturate.ll b/llvm/test/Transforms/InstCombine/truncating-saturate.ll index 7759224ee1f5..0f1697906d7f 100644 --- a/llvm/test/Transforms/InstCombine/truncating-saturate.ll +++ b/llvm/test/Transforms/InstCombine/truncating-saturate.ll @@ -111,8 +111,8 @@ define i16 @testtrunclowhigh(i32 %add, i16 %low, i16 %high) { ; CHECK-NEXT: [[A:%.*]] = add i32 [[ADD:%.*]], 128 ; CHECK-NEXT: [[CMP:%.*]] = icmp ult i32 [[A]], 256 ; CHECK-NEXT: [[T:%.*]] = trunc i32 [[ADD]] to i16 -; CHECK-NEXT: [[C:%.*]] = icmp sgt i32 [[ADD]], -1 -; CHECK-NEXT: [[F:%.*]] = select i1 [[C]], i16 [[HIGH:%.*]], i16 [[LOW:%.*]] +; CHECK-NEXT: [[C1:%.*]] = icmp slt i32 [[ADD]], 0 +; CHECK-NEXT: [[F:%.*]] = select i1 [[C1]], i16 [[LOW:%.*]], i16 [[HIGH:%.*]] ; CHECK-NEXT: [[R:%.*]] = select i1 [[CMP]], i16 [[T]], i16 [[F]] ; CHECK-NEXT: ret i16 [[R]] ;