[InstCombine] Generalize sadd.sat combine to compute sign bits.

There is a combine in instcombine to transform a saturated add/sub into
a saddsat/ssubsat, currently handling inputs which are both sign
extended (https://alive2.llvm.org/ce/z/68qpTn). This can generalize to,
for example ashr of at least the bitwidth (https://alive2.llvm.org/ce/z/4TFyX-
and https://alive2.llvm.org/ce/z/qDWzFs for example). Which means it
generalizes further to "the number of sign bits", needing to be enough
to truncate to the size of the saturate. (An example using `or` for
instance: https://alive2.llvm.org/ce/z/EI_h_A).

So this patch makes use of ComputeNumSignBits (with the newly added
ComputeMinSignedBits) in matchSAddSubSat to generalize the fold to any
inputs with enough sign bits known, truncating the inputs to the new
size of the saturate.

Differential Revision: https://reviews.llvm.org/D112298
This commit is contained in:
David Green 2021-11-05 15:05:09 +00:00
parent ea55503d7c
commit 08056e1888
2 changed files with 33 additions and 61 deletions

View File

@ -2304,16 +2304,6 @@ Instruction *InstCombinerImpl::matchSAddSubSat(Instruction &MinMax1) {
// Create the new type (which can be a vector type) // Create the new type (which can be a vector type)
Type *NewTy = Ty->getWithNewBitWidth(NewBitWidth); Type *NewTy = Ty->getWithNewBitWidth(NewBitWidth);
// Match the two extends from the add/sub
Value *A, *B;
if(!match(AddSub, m_BinOp(m_SExt(m_Value(A)), m_SExt(m_Value(B)))))
return nullptr;
// And check the incoming values are of a type smaller than or equal to the
// size of the saturation. Otherwise the higher bits can cause different
// results.
if (A->getType()->getScalarSizeInBits() > NewBitWidth ||
B->getType()->getScalarSizeInBits() > NewBitWidth)
return nullptr;
Intrinsic::ID IntrinsicID; Intrinsic::ID IntrinsicID;
if (AddSub->getOpcode() == Instruction::Add) if (AddSub->getOpcode() == Instruction::Add)
@ -2323,10 +2313,16 @@ Instruction *InstCombinerImpl::matchSAddSubSat(Instruction &MinMax1) {
else else
return nullptr; return nullptr;
// The two operands of the add/sub must be nsw-truncatable to the NewTy. This
// is usually achieved via a sext from a smaller type.
if (ComputeMinSignedBits(AddSub->getOperand(0), 0, AddSub) > NewBitWidth ||
ComputeMinSignedBits(AddSub->getOperand(1), 0, AddSub) > NewBitWidth)
return nullptr;
// Finally create and return the sat intrinsic, truncated to the new type // Finally create and return the sat intrinsic, truncated to the new type
Function *F = Intrinsic::getDeclaration(MinMax1.getModule(), IntrinsicID, NewTy); Function *F = Intrinsic::getDeclaration(MinMax1.getModule(), IntrinsicID, NewTy);
Value *AT = Builder.CreateSExt(A, NewTy); Value *AT = Builder.CreateTrunc(AddSub->getOperand(0), NewTy);
Value *BT = Builder.CreateSExt(B, NewTy); Value *BT = Builder.CreateTrunc(AddSub->getOperand(1), NewTy);
Value *Sat = Builder.CreateCall(F, {AT, BT}); Value *Sat = Builder.CreateCall(F, {AT, BT});
return CastInst::Create(Instruction::SExt, Sat, Ty); return CastInst::Create(Instruction::SExt, Sat, Ty);
} }

View File

@ -698,13 +698,10 @@ entry:
define i32 @ashrA(i64 %a, i32 %b) { define i32 @ashrA(i64 %a, i32 %b) {
; CHECK-LABEL: @ashrA( ; CHECK-LABEL: @ashrA(
; CHECK-NEXT: entry: ; CHECK-NEXT: entry:
; CHECK-NEXT: [[CONV:%.*]] = ashr i64 [[A:%.*]], 32 ; CHECK-NEXT: [[TMP0:%.*]] = lshr i64 [[A:%.*]], 32
; CHECK-NEXT: [[CONV1:%.*]] = sext i32 [[B:%.*]] to i64 ; CHECK-NEXT: [[TMP1:%.*]] = trunc i64 [[TMP0]] to i32
; CHECK-NEXT: [[ADD:%.*]] = add nsw i64 [[CONV]], [[CONV1]] ; CHECK-NEXT: [[TMP2:%.*]] = call i32 @llvm.sadd.sat.i32(i32 [[TMP1]], i32 [[B:%.*]])
; CHECK-NEXT: [[SPEC_STORE_SELECT:%.*]] = call i64 @llvm.smin.i64(i64 [[ADD]], i64 2147483647) ; CHECK-NEXT: ret i32 [[TMP2]]
; CHECK-NEXT: [[SPEC_STORE_SELECT8:%.*]] = call i64 @llvm.smax.i64(i64 [[SPEC_STORE_SELECT]], i64 -2147483648)
; CHECK-NEXT: [[CONV7:%.*]] = trunc i64 [[SPEC_STORE_SELECT8]] to i32
; CHECK-NEXT: ret i32 [[CONV7]]
; ;
entry: entry:
%conv = ashr i64 %a, 32 %conv = ashr i64 %a, 32
@ -719,15 +716,10 @@ entry:
define i32 @ashrB(i32 %a, i64 %b) { define i32 @ashrB(i32 %a, i64 %b) {
; CHECK-LABEL: @ashrB( ; CHECK-LABEL: @ashrB(
; CHECK-NEXT: entry: ; CHECK-NEXT: entry:
; CHECK-NEXT: [[CONV:%.*]] = sext i32 [[A:%.*]] to i64 ; CHECK-NEXT: [[TMP0:%.*]] = lshr i64 [[B:%.*]], 32
; CHECK-NEXT: [[CONV1:%.*]] = ashr i64 [[B:%.*]], 32 ; CHECK-NEXT: [[TMP1:%.*]] = trunc i64 [[TMP0]] to i32
; CHECK-NEXT: [[ADD:%.*]] = add nsw i64 [[CONV1]], [[CONV]] ; CHECK-NEXT: [[TMP2:%.*]] = call i32 @llvm.sadd.sat.i32(i32 [[TMP1]], i32 [[A:%.*]])
; CHECK-NEXT: [[TMP0:%.*]] = icmp sgt i64 [[ADD]], -2147483648 ; CHECK-NEXT: ret i32 [[TMP2]]
; CHECK-NEXT: [[SPEC_STORE_SELECT:%.*]] = select i1 [[TMP0]], i64 [[ADD]], i64 -2147483648
; CHECK-NEXT: [[TMP1:%.*]] = icmp slt i64 [[SPEC_STORE_SELECT]], 2147483647
; CHECK-NEXT: [[SPEC_STORE_SELECT8:%.*]] = select i1 [[TMP1]], i64 [[SPEC_STORE_SELECT]], i64 2147483647
; CHECK-NEXT: [[CONV7:%.*]] = trunc i64 [[SPEC_STORE_SELECT8]] to i32
; CHECK-NEXT: ret i32 [[CONV7]]
; ;
entry: entry:
%conv = sext i32 %a to i64 %conv = sext i32 %a to i64
@ -744,15 +736,12 @@ entry:
define i32 @ashrAB(i64 %a, i64 %b) { define i32 @ashrAB(i64 %a, i64 %b) {
; CHECK-LABEL: @ashrAB( ; CHECK-LABEL: @ashrAB(
; CHECK-NEXT: entry: ; CHECK-NEXT: entry:
; CHECK-NEXT: [[CONV:%.*]] = ashr i64 [[A:%.*]], 32 ; CHECK-NEXT: [[TMP0:%.*]] = lshr i64 [[A:%.*]], 32
; CHECK-NEXT: [[CONV1:%.*]] = ashr i64 [[B:%.*]], 32 ; CHECK-NEXT: [[TMP1:%.*]] = lshr i64 [[B:%.*]], 32
; CHECK-NEXT: [[ADD:%.*]] = add nsw i64 [[CONV1]], [[CONV]] ; CHECK-NEXT: [[TMP2:%.*]] = trunc i64 [[TMP1]] to i32
; CHECK-NEXT: [[TMP0:%.*]] = icmp sgt i64 [[ADD]], -2147483648 ; CHECK-NEXT: [[TMP3:%.*]] = trunc i64 [[TMP0]] to i32
; CHECK-NEXT: [[SPEC_STORE_SELECT:%.*]] = select i1 [[TMP0]], i64 [[ADD]], i64 -2147483648 ; CHECK-NEXT: [[TMP4:%.*]] = call i32 @llvm.sadd.sat.i32(i32 [[TMP2]], i32 [[TMP3]])
; CHECK-NEXT: [[TMP1:%.*]] = icmp slt i64 [[SPEC_STORE_SELECT]], 2147483647 ; CHECK-NEXT: ret i32 [[TMP4]]
; CHECK-NEXT: [[SPEC_STORE_SELECT8:%.*]] = select i1 [[TMP1]], i64 [[SPEC_STORE_SELECT]], i64 2147483647
; CHECK-NEXT: [[CONV7:%.*]] = trunc i64 [[SPEC_STORE_SELECT8]] to i32
; CHECK-NEXT: ret i32 [[CONV7]]
; ;
entry: entry:
%conv = ashr i64 %a, 32 %conv = ashr i64 %a, 32
@ -795,14 +784,9 @@ define i32 @ashrA33(i64 %a, i32 %b) {
; CHECK-LABEL: @ashrA33( ; CHECK-LABEL: @ashrA33(
; CHECK-NEXT: entry: ; CHECK-NEXT: entry:
; CHECK-NEXT: [[CONV:%.*]] = ashr i64 [[A:%.*]], 33 ; CHECK-NEXT: [[CONV:%.*]] = ashr i64 [[A:%.*]], 33
; CHECK-NEXT: [[CONV1:%.*]] = sext i32 [[B:%.*]] to i64 ; CHECK-NEXT: [[TMP0:%.*]] = trunc i64 [[CONV]] to i32
; CHECK-NEXT: [[ADD:%.*]] = add nsw i64 [[CONV]], [[CONV1]] ; CHECK-NEXT: [[TMP1:%.*]] = call i32 @llvm.sadd.sat.i32(i32 [[TMP0]], i32 [[B:%.*]])
; CHECK-NEXT: [[TMP0:%.*]] = icmp sgt i64 [[ADD]], -2147483648 ; CHECK-NEXT: ret i32 [[TMP1]]
; CHECK-NEXT: [[SPEC_STORE_SELECT:%.*]] = select i1 [[TMP0]], i64 [[ADD]], i64 -2147483648
; CHECK-NEXT: [[TMP1:%.*]] = icmp slt i64 [[SPEC_STORE_SELECT]], 2147483647
; CHECK-NEXT: [[SPEC_STORE_SELECT8:%.*]] = select i1 [[TMP1]], i64 [[SPEC_STORE_SELECT]], i64 2147483647
; CHECK-NEXT: [[CONV7:%.*]] = trunc i64 [[SPEC_STORE_SELECT8]] to i32
; CHECK-NEXT: ret i32 [[CONV7]]
; ;
entry: entry:
%conv = ashr i64 %a, 33 %conv = ashr i64 %a, 33
@ -844,15 +828,10 @@ entry:
define <2 x i8> @ashrv2i8_s(<2 x i16> %a, <2 x i8> %b) { define <2 x i8> @ashrv2i8_s(<2 x i16> %a, <2 x i8> %b) {
; CHECK-LABEL: @ashrv2i8_s( ; CHECK-LABEL: @ashrv2i8_s(
; CHECK-NEXT: entry: ; CHECK-NEXT: entry:
; CHECK-NEXT: [[CONV:%.*]] = ashr <2 x i16> [[A:%.*]], <i16 8, i16 8> ; CHECK-NEXT: [[TMP0:%.*]] = lshr <2 x i16> [[A:%.*]], <i16 8, i16 8>
; CHECK-NEXT: [[CONV1:%.*]] = sext <2 x i8> [[B:%.*]] to <2 x i16> ; CHECK-NEXT: [[TMP1:%.*]] = trunc <2 x i16> [[TMP0]] to <2 x i8>
; CHECK-NEXT: [[ADD:%.*]] = add nsw <2 x i16> [[CONV]], [[CONV1]] ; CHECK-NEXT: [[TMP2:%.*]] = call <2 x i8> @llvm.sadd.sat.v2i8(<2 x i8> [[TMP1]], <2 x i8> [[B:%.*]])
; CHECK-NEXT: [[TMP0:%.*]] = icmp sgt <2 x i16> [[ADD]], <i16 -128, i16 -128> ; CHECK-NEXT: ret <2 x i8> [[TMP2]]
; CHECK-NEXT: [[SPEC_STORE_SELECT:%.*]] = select <2 x i1> [[TMP0]], <2 x i16> [[ADD]], <2 x i16> <i16 -128, i16 -128>
; CHECK-NEXT: [[TMP1:%.*]] = icmp slt <2 x i16> [[SPEC_STORE_SELECT]], <i16 127, i16 127>
; CHECK-NEXT: [[SPEC_STORE_SELECT8:%.*]] = select <2 x i1> [[TMP1]], <2 x i16> [[SPEC_STORE_SELECT]], <2 x i16> <i16 127, i16 127>
; CHECK-NEXT: [[CONV7:%.*]] = trunc <2 x i16> [[SPEC_STORE_SELECT8]] to <2 x i8>
; CHECK-NEXT: ret <2 x i8> [[CONV7]]
; ;
entry: entry:
%conv = ashr <2 x i16> %a, <i16 8, i16 8> %conv = ashr <2 x i16> %a, <i16 8, i16 8>
@ -868,13 +847,10 @@ entry:
define i16 @or(i8 %X, i16 %Y) { define i16 @or(i8 %X, i16 %Y) {
; CHECK-LABEL: @or( ; CHECK-LABEL: @or(
; CHECK-NEXT: [[CONV10:%.*]] = sext i8 [[X:%.*]] to i16 ; CHECK-NEXT: [[TMP1:%.*]] = trunc i16 [[Y:%.*]] to i8
; CHECK-NEXT: [[CONV14:%.*]] = or i16 [[Y:%.*]], -16 ; CHECK-NEXT: [[TMP2:%.*]] = or i8 [[TMP1]], -16
; CHECK-NEXT: [[SUB:%.*]] = sub nsw i16 [[CONV10]], [[CONV14]] ; CHECK-NEXT: [[TMP3:%.*]] = call i8 @llvm.ssub.sat.i8(i8 [[X:%.*]], i8 [[TMP2]])
; CHECK-NEXT: [[L9:%.*]] = icmp sgt i16 [[SUB]], -128 ; CHECK-NEXT: [[L12:%.*]] = sext i8 [[TMP3]] to i16
; CHECK-NEXT: [[L10:%.*]] = select i1 [[L9]], i16 [[SUB]], i16 -128
; CHECK-NEXT: [[L11:%.*]] = icmp slt i16 [[L10]], 127
; CHECK-NEXT: [[L12:%.*]] = select i1 [[L11]], i16 [[L10]], i16 127
; CHECK-NEXT: ret i16 [[L12]] ; CHECK-NEXT: ret i16 [[L12]]
; ;
%conv10 = sext i8 %X to i16 %conv10 = sext i8 %X to i16