[InstCombine] Try not to demand low order bits for Add

Don't demand low order bits from the LHS of an Add if:
- they are not demanded in the result, and
- they are known to be zero in the RHS, so they can't possibly
  overflow and affect higher bit positions

This is intended to avoid a regression from a future patch to change
the order of canonicalization of ADD and AND.

Differential Revision: https://reviews.llvm.org/D130075
This commit is contained in:
Jay Foad 2022-07-18 19:32:40 +01:00
parent 8b24e64014
commit 2754ff883d
4 changed files with 70 additions and 29 deletions

View File

@ -154,6 +154,20 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
if (Depth == 0 && !V->hasOneUse())
DemandedMask.setAllBits();
// Update flags after simplifying an operand based on the fact that some high
// order bits are not demanded.
auto disableWrapFlagsBasedOnUnusedHighBits = [](Instruction *I,
unsigned NLZ) {
if (NLZ > 0) {
// Disable the nsw and nuw flags here: We can no longer guarantee that
// we won't wrap after simplification. Removing the nsw/nuw flags is
// legal here because the top bit is not demanded.
I->setHasNoSignedWrap(false);
I->setHasNoUnsignedWrap(false);
}
return I;
};
// If the high-bits of an ADD/SUB/MUL are not demanded, then we do not care
// about the high bits of the operands.
auto simplifyOperandsBasedOnUnusedHighBits = [&](APInt &DemandedFromOps) {
@ -165,13 +179,7 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
SimplifyDemandedBits(I, 0, DemandedFromOps, LHSKnown, Depth + 1) ||
ShrinkDemandedConstant(I, 1, DemandedFromOps) ||
SimplifyDemandedBits(I, 1, DemandedFromOps, RHSKnown, Depth + 1)) {
if (NLZ > 0) {
// Disable the nsw and nuw flags here: We can no longer guarantee that
// we won't wrap after simplification. Removing the nsw/nuw flags is
// legal here because the top bit is not demanded.
I->setHasNoSignedWrap(false);
I->setHasNoUnsignedWrap(false);
}
disableWrapFlagsBasedOnUnusedHighBits(I, NLZ);
return true;
}
return false;
@ -461,7 +469,7 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
assert(!Known.hasConflict() && "Bits known to be one AND zero?");
break;
}
case Instruction::Add:
case Instruction::Add: {
if ((DemandedMask & 1) == 0) {
// If we do not need the low bit, try to convert bool math to logic:
// add iN (zext i1 X), (sext i1 Y) --> sext (~X & Y) to iN
@ -498,11 +506,48 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
return Builder.CreateSExt(Or, VTy);
}
}
[[fallthrough]];
// Right fill the mask of bits for the operands to demand the most
// significant bit and all those below it.
unsigned NLZ = DemandedMask.countLeadingZeros();
APInt DemandedFromOps = APInt::getLowBitsSet(BitWidth, BitWidth - NLZ);
if (ShrinkDemandedConstant(I, 1, DemandedFromOps) ||
SimplifyDemandedBits(I, 1, DemandedFromOps, RHSKnown, Depth + 1))
return disableWrapFlagsBasedOnUnusedHighBits(I, NLZ);
// If low order bits are not demanded and known to be zero in one operand,
// then we don't need to demand them from the other operand, since they
// can't cause overflow into any bits that are demanded in the result.
unsigned NTZ = (~DemandedMask & RHSKnown.Zero).countTrailingOnes();
APInt DemandedFromLHS = DemandedFromOps;
DemandedFromLHS.clearLowBits(NTZ);
if (ShrinkDemandedConstant(I, 0, DemandedFromLHS) ||
SimplifyDemandedBits(I, 0, DemandedFromLHS, LHSKnown, Depth + 1))
return disableWrapFlagsBasedOnUnusedHighBits(I, NLZ);
// If we are known to be adding/subtracting zeros to every bit below
// the highest demanded bit, we just return the other side.
if (DemandedFromOps.isSubsetOf(RHSKnown.Zero))
return I->getOperand(0);
if (DemandedFromOps.isSubsetOf(LHSKnown.Zero))
return I->getOperand(1);
// Otherwise just compute the known bits of the result.
bool NSW = cast<OverflowingBinaryOperator>(I)->hasNoSignedWrap();
Known = KnownBits::computeForAddSub(true, NSW, LHSKnown, RHSKnown);
break;
}
case Instruction::Sub: {
APInt DemandedFromOps;
if (simplifyOperandsBasedOnUnusedHighBits(DemandedFromOps))
return I;
// Right fill the mask of bits for the operands to demand the most
// significant bit and all those below it.
unsigned NLZ = DemandedMask.countLeadingZeros();
APInt DemandedFromOps = APInt::getLowBitsSet(BitWidth, BitWidth - NLZ);
if (ShrinkDemandedConstant(I, 1, DemandedFromOps) ||
SimplifyDemandedBits(I, 1, DemandedFromOps, RHSKnown, Depth + 1))
return disableWrapFlagsBasedOnUnusedHighBits(I, NLZ);
if (ShrinkDemandedConstant(I, 0, DemandedFromOps) ||
SimplifyDemandedBits(I, 0, DemandedFromOps, LHSKnown, Depth + 1))
return disableWrapFlagsBasedOnUnusedHighBits(I, NLZ);
// If we are known to be adding/subtracting zeros to every bit below
// the highest demanded bit, we just return the other side.
@ -510,14 +555,12 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
return I->getOperand(0);
// We can't do this with the LHS for subtraction, unless we are only
// demanding the LSB.
if ((I->getOpcode() == Instruction::Add || DemandedFromOps.isOne()) &&
DemandedFromOps.isSubsetOf(LHSKnown.Zero))
if (DemandedFromOps.isOne() && DemandedFromOps.isSubsetOf(LHSKnown.Zero))
return I->getOperand(1);
// Otherwise just compute the known bits of the result.
bool NSW = cast<OverflowingBinaryOperator>(I)->hasNoSignedWrap();
Known = KnownBits::computeForAddSub(I->getOpcode() == Instruction::Add,
NSW, LHSKnown, RHSKnown);
Known = KnownBits::computeForAddSub(false, NSW, LHSKnown, RHSKnown);
break;
}
case Instruction::Mul: {

View File

@ -477,8 +477,7 @@ define i32 @add_of_selects(i1 %A, i32 %B) {
define i32 @add_undemanded_low_bits(i32 %x) {
; CHECK-LABEL: @add_undemanded_low_bits(
; CHECK-NEXT: [[OR:%.*]] = or i32 [[X:%.*]], 15
; CHECK-NEXT: [[ADD:%.*]] = add i32 [[OR]], 1616
; CHECK-NEXT: [[ADD:%.*]] = add i32 [[X:%.*]], 1616
; CHECK-NEXT: [[SHR:%.*]] = lshr i32 [[ADD]], 4
; CHECK-NEXT: ret i32 [[SHR]]
;
@ -490,8 +489,7 @@ define i32 @add_undemanded_low_bits(i32 %x) {
define i32 @sub_undemanded_low_bits(i32 %x) {
; CHECK-LABEL: @sub_undemanded_low_bits(
; CHECK-NEXT: [[OR:%.*]] = or i32 [[X:%.*]], 15
; CHECK-NEXT: [[SUB:%.*]] = add i32 [[OR]], -1616
; CHECK-NEXT: [[SUB:%.*]] = add i32 [[X:%.*]], -1616
; CHECK-NEXT: [[SHR:%.*]] = lshr i32 [[SUB]], 4
; CHECK-NEXT: ret i32 [[SHR]]
;

View File

@ -28,8 +28,8 @@ define void @fp_iv_loop1(float* noalias nocapture %A, i32 %N) #0 {
; AUTO_VEC-NEXT: [[CAST_CRD:%.*]] = sitofp i64 [[N_VEC]] to float
; AUTO_VEC-NEXT: [[TMP0:%.*]] = fmul fast float [[CAST_CRD]], 5.000000e-01
; AUTO_VEC-NEXT: [[IND_END:%.*]] = fadd fast float [[TMP0]], 1.000000e+00
; AUTO_VEC-NEXT: [[TMP1:%.*]] = add nsw i64 [[N_VEC]], -32
; AUTO_VEC-NEXT: [[TMP2:%.*]] = lshr exact i64 [[TMP1]], 5
; AUTO_VEC-NEXT: [[TMP1:%.*]] = add nsw i64 [[ZEXT]], -32
; AUTO_VEC-NEXT: [[TMP2:%.*]] = lshr i64 [[TMP1]], 5
; AUTO_VEC-NEXT: [[TMP3:%.*]] = add nuw nsw i64 [[TMP2]], 1
; AUTO_VEC-NEXT: [[XTRAITER:%.*]] = and i64 [[TMP3]], 3
; AUTO_VEC-NEXT: [[TMP4:%.*]] = icmp ult i64 [[TMP1]], 96
@ -298,8 +298,8 @@ define double @external_use_with_fast_math(double* %a, i64 %n) {
; AUTO_VEC-NEXT: [[N_VEC:%.*]] = and i64 [[SMAX]], 9223372036854775792
; AUTO_VEC-NEXT: [[CAST_CRD:%.*]] = sitofp i64 [[N_VEC]] to double
; AUTO_VEC-NEXT: [[TMP0:%.*]] = fmul fast double [[CAST_CRD]], 3.000000e+00
; AUTO_VEC-NEXT: [[TMP1:%.*]] = add nsw i64 [[N_VEC]], -16
; AUTO_VEC-NEXT: [[TMP2:%.*]] = lshr exact i64 [[TMP1]], 4
; AUTO_VEC-NEXT: [[TMP1:%.*]] = add nsw i64 [[SMAX]], -16
; AUTO_VEC-NEXT: [[TMP2:%.*]] = lshr i64 [[TMP1]], 4
; AUTO_VEC-NEXT: [[TMP3:%.*]] = add nuw nsw i64 [[TMP2]], 1
; AUTO_VEC-NEXT: [[XTRAITER:%.*]] = and i64 [[TMP3]], 3
; AUTO_VEC-NEXT: [[TMP4:%.*]] = icmp ult i64 [[TMP1]], 48
@ -559,11 +559,11 @@ define void @fadd_reassoc_FMF(float* nocapture %p, i32 %N) {
; AUTO_VEC-NEXT: [[CAST_CRD:%.*]] = sitofp i64 [[N_VEC]] to float
; AUTO_VEC-NEXT: [[TMP1:%.*]] = fmul reassoc float [[CAST_CRD]], 4.200000e+01
; AUTO_VEC-NEXT: [[IND_END:%.*]] = fadd reassoc float [[TMP1]], 1.000000e+00
; AUTO_VEC-NEXT: [[TMP2:%.*]] = add nsw i64 [[N_VEC]], -32
; AUTO_VEC-NEXT: [[TMP3:%.*]] = lshr exact i64 [[TMP2]], 5
; AUTO_VEC-NEXT: [[TMP2:%.*]] = add nsw i64 [[TMP0]], -32
; AUTO_VEC-NEXT: [[TMP3:%.*]] = lshr i64 [[TMP2]], 5
; AUTO_VEC-NEXT: [[TMP4:%.*]] = add nuw nsw i64 [[TMP3]], 1
; AUTO_VEC-NEXT: [[XTRAITER:%.*]] = and i64 [[TMP4]], 1
; AUTO_VEC-NEXT: [[TMP5:%.*]] = icmp eq i64 [[TMP2]], 0
; AUTO_VEC-NEXT: [[TMP5:%.*]] = icmp ult i64 [[TMP2]], 32
; AUTO_VEC-NEXT: br i1 [[TMP5]], label [[MIDDLE_BLOCK_UNR_LCSSA:%.*]], label [[VECTOR_PH_NEW:%.*]]
; AUTO_VEC: vector.ph.new:
; AUTO_VEC-NEXT: [[UNROLL_ITER:%.*]] = and i64 [[TMP4]], 1152921504606846974

View File

@ -176,8 +176,8 @@ define void @test_runtime_trip_count(i32 %N) {
; CHECK-NEXT: br i1 [[MIN_ITERS_CHECK]], label [[FOR_BODY_PREHEADER7:%.*]], label [[VECTOR_PH:%.*]]
; CHECK: vector.ph:
; CHECK-NEXT: [[N_VEC:%.*]] = and i64 [[WIDE_TRIP_COUNT]], 4294967292
; CHECK-NEXT: [[TMP0:%.*]] = add nsw i64 [[N_VEC]], -4
; CHECK-NEXT: [[TMP1:%.*]] = lshr exact i64 [[TMP0]], 2
; CHECK-NEXT: [[TMP0:%.*]] = add nsw i64 [[WIDE_TRIP_COUNT]], -4
; CHECK-NEXT: [[TMP1:%.*]] = lshr i64 [[TMP0]], 2
; CHECK-NEXT: [[TMP2:%.*]] = add nuw nsw i64 [[TMP1]], 1
; CHECK-NEXT: [[XTRAITER:%.*]] = and i64 [[TMP2]], 7
; CHECK-NEXT: [[TMP3:%.*]] = icmp ult i64 [[TMP0]], 28