diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp index e3a424ebdeea..2ba052b7e02d 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp @@ -503,9 +503,9 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, KnownZero.setLowBits(ShiftAmt); } break; - case Instruction::LShr: - // For a logical shift right - if (ConstantInt *SA = dyn_cast(I->getOperand(1))) { + case Instruction::LShr: { + const APInt *SA; + if (match(I->getOperand(1), m_APInt(SA))) { uint64_t ShiftAmt = SA->getLimitedValue(BitWidth-1); // Unsigned shift right. @@ -526,7 +526,8 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, KnownZero.setHighBits(ShiftAmt); // high bits known zero. } break; - case Instruction::AShr: + } + case Instruction::AShr: { // If this is an arithmetic shift right and only the low-bit is set, we can // always convert this into a logical shr, even if the shift amount is // variable. The low bit of the shift cannot be an input sign bit unless @@ -543,12 +544,13 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, if (DemandedMask.isSignMask()) return I->getOperand(0); - if (ConstantInt *SA = dyn_cast(I->getOperand(1))) { + const APInt *SA; + if (match(I->getOperand(1), m_APInt(SA))) { uint32_t ShiftAmt = SA->getLimitedValue(BitWidth-1); // Signed shift right. APInt DemandedMaskIn(DemandedMask.shl(ShiftAmt)); - // If any of the "high bits" are demanded, we should set the sign bit as + // If any of the high bits are demanded, we should set the sign bit as // demanded. if (DemandedMask.countLeadingZeros() <= ShiftAmt) DemandedMaskIn.setSignBit(); @@ -561,6 +563,7 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, if (SimplifyDemandedBits(I, 0, DemandedMaskIn, KnownZero, KnownOne, Depth + 1)) return I; + assert(!(KnownZero & KnownOne) && "Bits known to be one AND zero?"); // Compute the new bits that are at the top now. APInt HighBits(APInt::getHighBitsSet(BitWidth, ShiftAmt)); @@ -576,16 +579,16 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, // are demanded, turn this into an unsigned shift right. if (BitWidth <= ShiftAmt || KnownZero[BitWidth-ShiftAmt-1] || (HighBits & ~DemandedMask) == HighBits) { - // Perform the logical shift right. - BinaryOperator *NewVal = BinaryOperator::CreateLShr(I->getOperand(0), - SA, I->getName()); - NewVal->setIsExact(cast(I)->isExact()); - return InsertNewInstWith(NewVal, *I); + BinaryOperator *LShr = BinaryOperator::CreateLShr(I->getOperand(0), + I->getOperand(1)); + LShr->setIsExact(cast(I)->isExact()); + return InsertNewInstWith(LShr, *I); } else if ((KnownOne & SignMask) != 0) { // New bits are known one. KnownOne |= HighBits; } } break; + } case Instruction::SRem: if (ConstantInt *Rem = dyn_cast(I->getOperand(1))) { // X % -1 demands all the bits because we don't want to introduce diff --git a/llvm/test/Transforms/InstCombine/shift.ll b/llvm/test/Transforms/InstCombine/shift.ll index 72d361e8b6db..d5f489280a03 100644 --- a/llvm/test/Transforms/InstCombine/shift.ll +++ b/llvm/test/Transforms/InstCombine/shift.ll @@ -1270,8 +1270,7 @@ define <2 x i64> @test_64_splat_vec(<2 x i32> %t) { define <2 x i8> @ashr_demanded_bits_splat(<2 x i8> %x) { ; CHECK-LABEL: @ashr_demanded_bits_splat( -; CHECK-NEXT: [[AND:%.*]] = and <2 x i8> %x, -; CHECK-NEXT: [[SHR:%.*]] = ashr exact <2 x i8> [[AND]], +; CHECK-NEXT: [[SHR:%.*]] = ashr <2 x i8> %x, ; CHECK-NEXT: ret <2 x i8> [[SHR]] ; %and = and <2 x i8> %x, @@ -1281,8 +1280,7 @@ define <2 x i8> @ashr_demanded_bits_splat(<2 x i8> %x) { define <2 x i8> @lshr_demanded_bits_splat(<2 x i8> %x) { ; CHECK-LABEL: @lshr_demanded_bits_splat( -; CHECK-NEXT: [[AND:%.*]] = and <2 x i8> %x, -; CHECK-NEXT: [[SHR:%.*]] = lshr exact <2 x i8> [[AND]], +; CHECK-NEXT: [[SHR:%.*]] = lshr <2 x i8> %x, ; CHECK-NEXT: ret <2 x i8> [[SHR]] ; %and = and <2 x i8> %x, diff --git a/llvm/test/Transforms/InstCombine/vector-casts.ll b/llvm/test/Transforms/InstCombine/vector-casts.ll index 643ab6c5348f..2197c250ace2 100644 --- a/llvm/test/Transforms/InstCombine/vector-casts.ll +++ b/llvm/test/Transforms/InstCombine/vector-casts.ll @@ -15,9 +15,9 @@ define <2 x i1> @test1(<2 x i64> %a) { ; The ashr turns into an lshr. define <2 x i64> @test2(<2 x i64> %a) { ; CHECK-LABEL: @test2( -; CHECK-NEXT: [[B:%.*]] = and <2 x i64> %a, -; CHECK-NEXT: [[T:%.*]] = lshr <2 x i64> [[B]], -; CHECK-NEXT: ret <2 x i64> [[T]] +; CHECK-NEXT: [[B:%.*]] = and <2 x i64> %a, +; CHECK-NEXT: [[TMP1:%.*]] = lshr exact <2 x i64> [[B]], +; CHECK-NEXT: ret <2 x i64> [[TMP1]] ; %b = and <2 x i64> %a, %t = ashr <2 x i64> %b,