[InstCombine] Support arbitrary const shift amount for `lshr (sext i1 ...)`

Add lshr (sext i1 X to iN), C --> select (X, -1 >> C, 0) case. This expands
C == N-1 case to arbitrary C.

Fixes PR52078.

Reviewed By: spatel, RKSimon, lebedev.ri

Differential Revision: https://reviews.llvm.org/D111330
This commit is contained in:
Anton Afanasyev 2021-10-09 11:18:31 +03:00
parent e23351cdc9
commit 7b07c01351
3 changed files with 27 additions and 30 deletions

View File

@ -1067,28 +1067,31 @@ Instruction *InstCombinerImpl::visitLShr(BinaryOperator &I) {
return new ZExtInst(NewLShr, Ty);
}
if (match(Op0, m_SExt(m_Value(X))) &&
(!Ty->isIntegerTy() || shouldChangeType(Ty, X->getType()))) {
// Are we moving the sign bit to the low bit and widening with high zeros?
if (match(Op0, m_SExt(m_Value(X)))) {
unsigned SrcTyBitWidth = X->getType()->getScalarSizeInBits();
if (ShAmtC == BitWidth - 1) {
// lshr (sext i1 X to iN), N-1 --> zext X to iN
if (SrcTyBitWidth == 1)
return new ZExtInst(X, Ty);
// lshr (sext i1 X to iN), C --> select (X, -1 >> C, 0)
if (SrcTyBitWidth == 1) {
auto *NewC = ConstantInt::get(
Ty, APInt::getLowBitsSet(BitWidth, BitWidth - ShAmtC));
return SelectInst::Create(X, NewC, ConstantInt::getNullValue(Ty));
}
// lshr (sext iM X to iN), N-1 --> zext (lshr X, M-1) to iN
if (Op0->hasOneUse()) {
if ((!Ty->isIntegerTy() || shouldChangeType(Ty, X->getType())) &&
Op0->hasOneUse()) {
// Are we moving the sign bit to the low bit and widening with high
// zeros? lshr (sext iM X to iN), N-1 --> zext (lshr X, M-1) to iN
if (ShAmtC == BitWidth - 1) {
Value *NewLShr = Builder.CreateLShr(X, SrcTyBitWidth - 1);
return new ZExtInst(NewLShr, Ty);
}
}
// lshr (sext iM X to iN), N-M --> zext (ashr X, min(N-M, M-1)) to iN
if (ShAmtC == BitWidth - SrcTyBitWidth && Op0->hasOneUse()) {
// The new shift amount can't be more than the narrow source type.
unsigned NewShAmt = std::min(ShAmtC, SrcTyBitWidth - 1);
Value *AShr = Builder.CreateAShr(X, NewShAmt);
return new ZExtInst(AShr, Ty);
// lshr (sext iM X to iN), N-M --> zext (ashr X, min(N-M, M-1)) to iN
if (ShAmtC == BitWidth - SrcTyBitWidth) {
// The new shift amount can't be more than the narrow source type.
unsigned NewShAmt = std::min(ShAmtC, SrcTyBitWidth - 1);
Value *AShr = Builder.CreateAShr(X, NewShAmt);
return new ZExtInst(AShr, Ty);
}
}
}

View File

@ -620,8 +620,7 @@ define i12 @trunc_sandwich_big_sum_shift2_use1(i32 %x) {
define i16 @lshr_sext_i1_to_i16(i1 %a) {
; CHECK-LABEL: @lshr_sext_i1_to_i16(
; CHECK-NEXT: [[SEXT:%.*]] = sext i1 [[A:%.*]] to i16
; CHECK-NEXT: [[LSHR:%.*]] = lshr i16 [[SEXT]], 4
; CHECK-NEXT: [[LSHR:%.*]] = select i1 [[A:%.*]], i16 4095, i16 0
; CHECK-NEXT: ret i16 [[LSHR]]
;
%sext = sext i1 %a to i16
@ -631,8 +630,7 @@ define i16 @lshr_sext_i1_to_i16(i1 %a) {
define i128 @lshr_sext_i1_to_i128(i1 %a) {
; CHECK-LABEL: @lshr_sext_i1_to_i128(
; CHECK-NEXT: [[SEXT:%.*]] = sext i1 [[A:%.*]] to i128
; CHECK-NEXT: [[LSHR:%.*]] = lshr i128 [[SEXT]], 42
; CHECK-NEXT: [[LSHR:%.*]] = select i1 [[A:%.*]], i128 77371252455336267181195263, i128 0
; CHECK-NEXT: ret i128 [[LSHR]]
;
%sext = sext i1 %a to i128
@ -644,7 +642,7 @@ define i32 @lshr_sext_i1_to_i32_use(i1 %a) {
; CHECK-LABEL: @lshr_sext_i1_to_i32_use(
; CHECK-NEXT: [[SEXT:%.*]] = sext i1 [[A:%.*]] to i32
; CHECK-NEXT: call void @use(i32 [[SEXT]])
; CHECK-NEXT: [[LSHR:%.*]] = lshr i32 [[SEXT]], 14
; CHECK-NEXT: [[LSHR:%.*]] = select i1 [[A]], i32 262143, i32 0
; CHECK-NEXT: ret i32 [[LSHR]]
;
%sext = sext i1 %a to i32
@ -657,7 +655,7 @@ define <3 x i14> @lshr_sext_i1_to_i14_splat_vec_use1(<3 x i1> %a) {
; CHECK-LABEL: @lshr_sext_i1_to_i14_splat_vec_use1(
; CHECK-NEXT: [[SEXT:%.*]] = sext <3 x i1> [[A:%.*]] to <3 x i14>
; CHECK-NEXT: call void @usevec(<3 x i14> [[SEXT]])
; CHECK-NEXT: [[LSHR:%.*]] = lshr <3 x i14> [[SEXT]], <i14 4, i14 4, i14 4>
; CHECK-NEXT: [[LSHR:%.*]] = select <3 x i1> [[A]], <3 x i14> <i14 1023, i14 1023, i14 1023>, <3 x i14> zeroinitializer
; CHECK-NEXT: ret <3 x i14> [[LSHR]]
;
%sext = sext <3 x i1> %a to <3 x i14>

View File

@ -16,8 +16,7 @@ define i16 @foo(i1 %a) {
; IC-NEXT: ret i16 [[TRUNC]]
;
; AIC_AND_IC-LABEL: @foo(
; AIC_AND_IC-NEXT: [[SEXT:%.*]] = sext i1 [[A:%.*]] to i16
; AIC_AND_IC-NEXT: [[LSHR:%.*]] = lshr i16 [[SEXT]], 1
; AIC_AND_IC-NEXT: [[LSHR:%.*]] = select i1 [[A:%.*]], i16 32767, i16 0
; AIC_AND_IC-NEXT: ret i16 [[LSHR]]
;
%sext = sext i1 %a to i16
@ -29,18 +28,15 @@ define i16 @foo(i1 %a) {
define i16 @foo2(i1 %a) {
; CHECK-LABEL: @foo2(
; CHECK-NEXT: [[S:%.*]] = sext i1 [[A:%.*]] to i16
; CHECK-NEXT: [[LSHR:%.*]] = lshr i16 [[S]], 1
; CHECK-NEXT: [[LSHR:%.*]] = select i1 [[A:%.*]], i16 32767, i16 0
; CHECK-NEXT: ret i16 [[LSHR]]
;
; IC-LABEL: @foo2(
; IC-NEXT: [[S:%.*]] = sext i1 [[A:%.*]] to i16
; IC-NEXT: [[LSHR:%.*]] = lshr i16 [[S]], 1
; IC-NEXT: [[LSHR:%.*]] = select i1 [[A:%.*]], i16 32767, i16 0
; IC-NEXT: ret i16 [[LSHR]]
;
; AIC_AND_IC-LABEL: @foo2(
; AIC_AND_IC-NEXT: [[S:%.*]] = sext i1 [[A:%.*]] to i16
; AIC_AND_IC-NEXT: [[LSHR:%.*]] = lshr i16 [[S]], 1
; AIC_AND_IC-NEXT: [[LSHR:%.*]] = select i1 [[A:%.*]], i16 32767, i16 0
; AIC_AND_IC-NEXT: ret i16 [[LSHR]]
;
%s = sext i1 %a to i16