[AggressiveInstCombine] Add arithmetic shift right instr to `TruncInstCombine` DAG

Add `ashr` instruction to the DAG post-dominated by `trunc`, allowing
`TruncInstCombine` to reduce bitwidth of expressions containing
these instructions.

We should be shifting by less than the target bitwidth.
Also it is sufficient to require that all truncated bits
of the value-to-be-shifted are sign bits (all zeros or ones) and
one sign bit is left untruncated: https://alive2.llvm.org/ce/z/Ajo2__

Part of https://reviews.llvm.org/D107766

Differential Revision: https://reviews.llvm.org/D108355
This commit is contained in:
Anton Afanasyev 2021-08-19 19:36:54 +03:00
parent 8614cb9f99
commit bed587631f
2 changed files with 52 additions and 53 deletions

View File

@ -65,6 +65,7 @@ static void getRelevantOperands(Instruction *I, SmallVectorImpl<Value *> &Ops) {
case Instruction::Xor:
case Instruction::Shl:
case Instruction::LShr:
case Instruction::AShr:
Ops.push_back(I->getOperand(0));
Ops.push_back(I->getOperand(1));
break;
@ -133,6 +134,7 @@ bool TruncInstCombine::buildTruncExpressionDag() {
case Instruction::Xor:
case Instruction::Shl:
case Instruction::LShr:
case Instruction::AShr:
case Instruction::Select: {
SmallVector<Value *, 2> Operands;
getRelevantOperands(I, Operands);
@ -143,8 +145,7 @@ bool TruncInstCombine::buildTruncExpressionDag() {
// TODO: Can handle more cases here:
// 1. shufflevector, extractelement, insertelement
// 2. udiv, urem
// 3. ashr
// 4. phi node(and loop handling)
// 3. phi node(and loop handling)
// ...
return false;
}
@ -277,14 +278,16 @@ Type *TruncInstCombine::getBestTruncatedType() {
CurrentTruncInst->getOperand(0)->getType()->getScalarSizeInBits();
// Initialize MinBitWidth for shift instructions with the minimum number
// that is greater than shift amount (i.e. shift amount + 1). For `lshr`
// adjust MinBitWidth so that all potentially truncated bits of
// the value-to-be-shifted are zeros.
// Also normalize MinBitWidth not to be greater than source bitwidth.
// that is greater than shift amount (i.e. shift amount + 1).
// For `lshr` adjust MinBitWidth so that all potentially truncated
// bits of the value-to-be-shifted are zeros.
// For `ashr` adjust MinBitWidth so that all potentially truncated
// bits of the value-to-be-shifted are sign bits (all zeros or ones)
// and even one (first) untruncated bit is sign bit.
// Exit early if MinBitWidth is not less than original bitwidth.
for (auto &Itr : InstInfoMap) {
Instruction *I = Itr.first;
if (I->getOpcode() == Instruction::Shl ||
I->getOpcode() == Instruction::LShr) {
if (I->isShift()) {
KnownBits KnownRHS = computeKnownBits(I->getOperand(1), DL);
unsigned MinBitWidth = KnownRHS.getMaxValue()
.uadd_sat(APInt(OrigBitWidth, 1))
@ -295,9 +298,13 @@ Type *TruncInstCombine::getBestTruncatedType() {
KnownBits KnownLHS = computeKnownBits(I->getOperand(0), DL);
MinBitWidth =
std::max(MinBitWidth, KnownLHS.getMaxValue().getActiveBits());
if (MinBitWidth >= OrigBitWidth)
return nullptr;
}
if (I->getOpcode() == Instruction::AShr) {
unsigned NumSignBits = ComputeNumSignBits(I->getOperand(0), DL);
MinBitWidth = std::max(MinBitWidth, OrigBitWidth - NumSignBits + 1);
}
if (MinBitWidth >= OrigBitWidth)
return nullptr;
Itr.second.MinBitWidth = MinBitWidth;
}
}
@ -390,14 +397,15 @@ void TruncInstCombine::ReduceExpressionDag(Type *SclTy) {
case Instruction::Or:
case Instruction::Xor:
case Instruction::Shl:
case Instruction::LShr: {
case Instruction::LShr:
case Instruction::AShr: {
Value *LHS = getReducedOperand(I->getOperand(0), SclTy);
Value *RHS = getReducedOperand(I->getOperand(1), SclTy);
Res = Builder.CreateBinOp((Instruction::BinaryOps)Opc, LHS, RHS);
// Preserve `exact` flag since truncation doesn't change exactness
if (Opc == Instruction::LShr)
if (auto *PEO = dyn_cast<PossiblyExactOperator>(I))
if (auto *ResI = dyn_cast<Instruction>(Res))
ResI->setIsExact(I->isExact());
ResI->setIsExact(PEO->isExact());
break;
}
case Instruction::Select: {

View File

@ -19,10 +19,8 @@ define i16 @ashr_15_zext(i16 %x) {
define i16 @ashr_sext_15(i16 %x) {
; CHECK-LABEL: @ashr_sext_15(
; CHECK-NEXT: [[SEXT:%.*]] = sext i16 [[X:%.*]] to i32
; CHECK-NEXT: [[ASHR:%.*]] = ashr i32 [[SEXT]], 15
; CHECK-NEXT: [[TRUNC:%.*]] = trunc i32 [[ASHR]] to i16
; CHECK-NEXT: ret i16 [[TRUNC]]
; CHECK-NEXT: [[ASHR:%.*]] = ashr i16 [[X:%.*]], 15
; CHECK-NEXT: ret i16 [[ASHR]]
;
%sext = sext i16 %x to i32
%ashr = ashr i32 %sext, 15
@ -68,14 +66,13 @@ define i16 @ashr_var_shift_amount(i8 %x, i8 %amt) {
define i16 @ashr_var_bounded_shift_amount(i8 %x, i8 %amt) {
; CHECK-LABEL: @ashr_var_bounded_shift_amount(
; CHECK-NEXT: [[Z:%.*]] = zext i8 [[X:%.*]] to i32
; CHECK-NEXT: [[ZA:%.*]] = zext i8 [[AMT:%.*]] to i32
; CHECK-NEXT: [[ZA2:%.*]] = and i32 [[ZA]], 15
; CHECK-NEXT: [[S:%.*]] = ashr i32 [[Z]], [[ZA2]]
; CHECK-NEXT: [[A:%.*]] = add i32 [[S]], [[Z]]
; CHECK-NEXT: [[S2:%.*]] = ashr i32 [[A]], 2
; CHECK-NEXT: [[T:%.*]] = trunc i32 [[S2]] to i16
; CHECK-NEXT: ret i16 [[T]]
; CHECK-NEXT: [[Z:%.*]] = zext i8 [[X:%.*]] to i16
; CHECK-NEXT: [[ZA:%.*]] = zext i8 [[AMT:%.*]] to i16
; CHECK-NEXT: [[ZA2:%.*]] = and i16 [[ZA]], 15
; CHECK-NEXT: [[S:%.*]] = ashr i16 [[Z]], [[ZA2]]
; CHECK-NEXT: [[A:%.*]] = add i16 [[S]], [[Z]]
; CHECK-NEXT: [[S2:%.*]] = ashr i16 [[A]], 2
; CHECK-NEXT: ret i16 [[S2]]
;
%z = zext i8 %x to i32
%za = zext i8 %amt to i32
@ -108,16 +105,15 @@ define i32 @ashr_check_no_overflow(i32 %x, i16 %amt) {
define void @ashr_big_dag(i16* %a, i8 %b, i8 %c) {
; CHECK-LABEL: @ashr_big_dag(
; CHECK-NEXT: [[ZEXT1:%.*]] = zext i8 [[B:%.*]] to i32
; CHECK-NEXT: [[ZEXT2:%.*]] = zext i8 [[C:%.*]] to i32
; CHECK-NEXT: [[ADD1:%.*]] = add i32 [[ZEXT1]], [[ZEXT2]]
; CHECK-NEXT: [[SFT1:%.*]] = and i32 [[ADD1]], 15
; CHECK-NEXT: [[SHR1:%.*]] = ashr i32 [[ADD1]], [[SFT1]]
; CHECK-NEXT: [[ADD2:%.*]] = add i32 [[ADD1]], [[SHR1]]
; CHECK-NEXT: [[SFT2:%.*]] = and i32 [[ADD2]], 7
; CHECK-NEXT: [[SHR2:%.*]] = ashr i32 [[ADD2]], [[SFT2]]
; CHECK-NEXT: [[TRUNC:%.*]] = trunc i32 [[SHR2]] to i16
; CHECK-NEXT: store i16 [[TRUNC]], i16* [[A:%.*]], align 2
; CHECK-NEXT: [[ZEXT1:%.*]] = zext i8 [[B:%.*]] to i16
; CHECK-NEXT: [[ZEXT2:%.*]] = zext i8 [[C:%.*]] to i16
; CHECK-NEXT: [[ADD1:%.*]] = add i16 [[ZEXT1]], [[ZEXT2]]
; CHECK-NEXT: [[SFT1:%.*]] = and i16 [[ADD1]], 15
; CHECK-NEXT: [[SHR1:%.*]] = ashr i16 [[ADD1]], [[SFT1]]
; CHECK-NEXT: [[ADD2:%.*]] = add i16 [[ADD1]], [[SHR1]]
; CHECK-NEXT: [[SFT2:%.*]] = and i16 [[ADD2]], 7
; CHECK-NEXT: [[SHR2:%.*]] = ashr i16 [[ADD2]], [[SFT2]]
; CHECK-NEXT: store i16 [[SHR2]], i16* [[A:%.*]], align 2
; CHECK-NEXT: ret void
;
%zext1 = zext i8 %b to i32
@ -152,13 +148,12 @@ define i8 @ashr_check_not_i8_trunc(i16 %x) {
define <2 x i16> @ashr_vector(<2 x i8> %x) {
; CHECK-LABEL: @ashr_vector(
; CHECK-NEXT: [[Z:%.*]] = zext <2 x i8> [[X:%.*]] to <2 x i32>
; CHECK-NEXT: [[ZA:%.*]] = and <2 x i32> [[Z]], <i32 7, i32 8>
; CHECK-NEXT: [[S:%.*]] = ashr <2 x i32> [[Z]], [[ZA]]
; CHECK-NEXT: [[A:%.*]] = add <2 x i32> [[S]], [[Z]]
; CHECK-NEXT: [[S2:%.*]] = ashr <2 x i32> [[A]], <i32 4, i32 5>
; CHECK-NEXT: [[T:%.*]] = trunc <2 x i32> [[S2]] to <2 x i16>
; CHECK-NEXT: ret <2 x i16> [[T]]
; CHECK-NEXT: [[Z:%.*]] = zext <2 x i8> [[X:%.*]] to <2 x i16>
; CHECK-NEXT: [[ZA:%.*]] = and <2 x i16> [[Z]], <i16 7, i16 8>
; CHECK-NEXT: [[S:%.*]] = ashr <2 x i16> [[Z]], [[ZA]]
; CHECK-NEXT: [[A:%.*]] = add <2 x i16> [[S]], [[Z]]
; CHECK-NEXT: [[S2:%.*]] = ashr <2 x i16> [[A]], <i16 4, i16 5>
; CHECK-NEXT: ret <2 x i16> [[S2]]
;
%z = zext <2 x i8> %x to <2 x i32>
%za = and <2 x i32> %z, <i32 7, i32 8>
@ -213,11 +208,9 @@ define <2 x i16> @ashr_vector_large_shift_amount(<2 x i8> %x) {
define i16 @ashr_exact(i16 %x) {
; CHECK-LABEL: @ashr_exact(
; CHECK-NEXT: [[ZEXT:%.*]] = zext i16 [[X:%.*]] to i32
; CHECK-NEXT: [[AND:%.*]] = and i32 [[ZEXT]], 32767
; CHECK-NEXT: [[ASHR:%.*]] = ashr exact i32 [[AND]], 15
; CHECK-NEXT: [[TRUNC:%.*]] = trunc i32 [[ASHR]] to i16
; CHECK-NEXT: ret i16 [[TRUNC]]
; CHECK-NEXT: [[AND:%.*]] = and i16 [[X:%.*]], 32767
; CHECK-NEXT: [[ASHR:%.*]] = ashr exact i16 [[AND]], 15
; CHECK-NEXT: ret i16 [[ASHR]]
;
%zext = zext i16 %x to i32
%and = and i32 %zext, 32767
@ -245,12 +238,10 @@ define i16 @ashr_negative_operand(i16 %x) {
define i16 @ashr_negative_operand_but_short(i16 %x) {
; CHECK-LABEL: @ashr_negative_operand_but_short(
; CHECK-NEXT: [[ZEXT:%.*]] = zext i16 [[X:%.*]] to i32
; CHECK-NEXT: [[AND:%.*]] = and i32 [[ZEXT]], 32767
; CHECK-NEXT: [[XOR:%.*]] = xor i32 -1, [[AND]]
; CHECK-NEXT: [[LSHR2:%.*]] = ashr i32 [[XOR]], 2
; CHECK-NEXT: [[TRUNC:%.*]] = trunc i32 [[LSHR2]] to i16
; CHECK-NEXT: ret i16 [[TRUNC]]
; CHECK-NEXT: [[AND:%.*]] = and i16 [[X:%.*]], 32767
; CHECK-NEXT: [[XOR:%.*]] = xor i16 -1, [[AND]]
; CHECK-NEXT: [[LSHR2:%.*]] = ashr i16 [[XOR]], 2
; CHECK-NEXT: ret i16 [[LSHR2]]
;
%zext = zext i16 %x to i32
%and = and i32 %zext, 32767