diff --git a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp index fb90814b6ecc..edd51baf2a89 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp @@ -373,18 +373,6 @@ foldShiftByConstOfShiftByConst(BinaryOperator &I, const APInt *COp1, if (ShiftAmt2 < ShiftAmt1) { uint32_t ShiftDiff = ShiftAmt1 - ShiftAmt2; - // (X >>?exact C1) << C2 --> X >>?exact (C1-C2) - // The inexact version is deferred to DAGCombine so we don't hide shl - // behind a bit mask. - if (I.getOpcode() == Instruction::Shl && - ShiftOp->getOpcode() != Instruction::Shl && ShiftOp->isExact()) { - ConstantInt *ShiftDiffCst = ConstantInt::get(Ty, ShiftDiff); - BinaryOperator *NewShr = - BinaryOperator::Create(ShiftOp->getOpcode(), X, ShiftDiffCst); - NewShr->setIsExact(true); - return NewShr; - } - // (X << C1) >>u C2 --> X << (C1-C2) & (-1 >> C2) if (I.getOpcode() == Instruction::LShr && ShiftOp->getOpcode() == Instruction::Shl) { @@ -670,18 +658,28 @@ Instruction *InstCombiner::visitShl(BinaryOperator &I) { return BinaryOperator::CreateAnd(X, ConstantInt::get(Ty, Mask)); } - const APInt *ShrAmt; - if (match(Op0, m_CombineOr(m_Exact(m_LShr(m_Value(X), m_APInt(ShrAmt))), - m_Exact(m_AShr(m_Value(X), m_APInt(ShrAmt))))) && - ShrAmt->ult(*ShAmtAPInt)) { - // If C1 < C2: (X >>?,exact C1) << C2 --> X << (C2 - C1) - // The inexact version is deferred to DAGCombine, so we don't hide shl - // behind a bit mask. - Constant *ShiftDiffCst = ConstantInt::get(Ty, *ShAmtAPInt - *ShrAmt); - auto *NewShl = BinaryOperator::CreateShl(X, ShiftDiffCst); - NewShl->setHasNoUnsignedWrap(I.hasNoUnsignedWrap()); - NewShl->setHasNoSignedWrap(I.hasNoSignedWrap()); - return NewShl; + // The inexact versions are deferred to DAGCombine, so we don't hide shl + // behind a bit mask. + const APInt *ShrOp1; + if (match(Op0, m_CombineOr(m_Exact(m_LShr(m_Value(X), m_APInt(ShrOp1))), + m_Exact(m_AShr(m_Value(X), m_APInt(ShrOp1)))))) { + unsigned ShrAmt = ShrOp1->getZExtValue(); + if (ShrAmt < ShAmt) { + // If C1 < C2: (X >>?,exact C1) << C2 --> X << (C2 - C1) + Constant *ShiftDiff = ConstantInt::get(Ty, ShAmt - ShrAmt); + auto *NewShl = BinaryOperator::CreateShl(X, ShiftDiff); + NewShl->setHasNoUnsignedWrap(I.hasNoUnsignedWrap()); + NewShl->setHasNoSignedWrap(I.hasNoSignedWrap()); + return NewShl; + } + if (ShrAmt > ShAmt) { + // If C1 > C2: (X >>?exact C1) << C2 --> X >>?exact (C1 - C2) + Constant *ShiftDiff = ConstantInt::get(Ty, ShrAmt - ShAmt); + auto *NewShr = BinaryOperator::Create( + cast(Op0)->getOpcode(), X, ShiftDiff); + NewShr->setIsExact(true); + return NewShr; + } } // If the shifted-out value is known-zero, then this is a NUW shift. diff --git a/llvm/test/Transforms/InstCombine/shift.ll b/llvm/test/Transforms/InstCombine/shift.ll index 0c6655c800de..ea4936e37374 100644 --- a/llvm/test/Transforms/InstCombine/shift.ll +++ b/llvm/test/Transforms/InstCombine/shift.ll @@ -806,8 +806,7 @@ define i32 @test46(i32 %a) { define <2 x i32> @test46_splat_vec(<2 x i32> %a) { ; CHECK-LABEL: @test46_splat_vec( -; CHECK-NEXT: [[Y:%.*]] = ashr exact <2 x i32> %a, -; CHECK-NEXT: [[Z:%.*]] = shl nsw <2 x i32> [[Y]], +; CHECK-NEXT: [[Z:%.*]] = ashr exact <2 x i32> %a, ; CHECK-NEXT: ret <2 x i32> [[Z]] ; %y = ashr exact <2 x i32> %a, @@ -831,8 +830,7 @@ define i8 @test47(i8 %a) { define <2 x i8> @test47_splat_vec(<2 x i8> %a) { ; CHECK-LABEL: @test47_splat_vec( -; CHECK-NEXT: [[Y:%.*]] = lshr exact <2 x i8> %a, -; CHECK-NEXT: [[Z:%.*]] = shl nuw nsw <2 x i8> [[Y]], +; CHECK-NEXT: [[Z:%.*]] = lshr exact <2 x i8> %a, ; CHECK-NEXT: ret <2 x i8> [[Z]] ; %y = lshr exact <2 x i8> %a,