From ae3315af075975873df7a33e3835f2170f860b46 Mon Sep 17 00:00:00 2001
From: Roman Lebedev
Date: Wed, 2 Oct 2019 23:02:12 +0000
Subject: [PATCH] [InstCombine] Bypass high bit extract before variable
 sign-extension (PR43523)

https://rise4fun.com/Alive/8BY - valid for lshr+trunc+variable sext
https://rise4fun.com/Alive/7jk - the variable sext can be redundant

https://rise4fun.com/Alive/Qslu - 'exact'-ness of the first shift can be
preserved

https://rise4fun.com/Alive/IF63 - without trunc we could view this as the
more general "drop redundant mask before right-shift", but let's handle it
here for now

https://rise4fun.com/Alive/iip - likewise, without trunc, the variable sext
can be redundant

There are more patterns for sure - e.g. we can have 'lshr' as the final
shift, but that might be best handled by some more generic transform, e.g.
"drop redundant masking before right-shift" (PR42456)

I'm singling out this sext pattern because you can only extract high bits
with `*shr` (unlike abstract bit masking), and I *know* this fold is wanted
by existing code.

I don't believe there is much to review here, so I'm gonna opt into
post-review mode here.
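
For illustration, here is the basic scalar shape of the fold - a sketch that
mirrors the @t0 test updated below, not new functionality:

  ; Before: extract the high %nbits of %data, then variable-sign-extend them.
  %skip_high = sub i32 64, %nbits
  %skip_high_wide = zext i32 %skip_high to i64
  %extracted = lshr i64 %data, %skip_high_wide
  %extracted_narrow = trunc i64 %extracted to i32
  %num_high_bits_to_smear = sub i32 32, %nbits
  %signbit_positioned = shl i32 %extracted_narrow, %num_high_bits_to_smear
  %signextended = ashr i32 %signbit_positioned, %num_high_bits_to_smear

  ; After: a single wide ashr performs both the extraction and the
  ; sign-extension; only the trunc remains.
  %skip_high = sub i32 64, %nbits
  %skip_high_wide = zext i32 %skip_high to i64
  %signextended_wide = ashr i64 %data, %skip_high_wide
  %signextended = trunc i64 %signextended_wide to i32
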
https://bugs.llvm.org/show_bug.cgi?id=43523

llvm-svn: 373542
---
 .../InstCombine/InstCombineInternal.h         |  2 +
 .../InstCombine/InstCombineShifts.cpp         | 72 +++++++++++++++++++
 ...signext-of-variable-high-bit-extraction.ll | 43 +++++------
 3 files changed, 91 insertions(+), 26 deletions(-)

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
index 673099436b79..dcdbee15fe56 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
+++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -351,6 +351,8 @@ public:
   Instruction *visitOr(BinaryOperator &I);
   Instruction *visitXor(BinaryOperator &I);
   Instruction *visitShl(BinaryOperator &I);
+  Instruction *foldVariableSignZeroExtensionOfVariableHighBitExtract(
+      BinaryOperator &OldAShr);
   Instruction *visitAShr(BinaryOperator &I);
   Instruction *visitLShr(BinaryOperator &I);
   Instruction *commonShiftTransforms(BinaryOperator &I);
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
index bc4affbecdfa..9d96ddc4040d 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
@@ -1039,6 +1039,75 @@ Instruction *InstCombiner::visitLShr(BinaryOperator &I) {
   return nullptr;
 }
 
+Instruction *
+InstCombiner::foldVariableSignZeroExtensionOfVariableHighBitExtract(
+    BinaryOperator &OldAShr) {
+  assert(OldAShr.getOpcode() == Instruction::AShr &&
+         "Must be called with arithmetic right-shift instruction only.");
+
+  // Check that constant C is a splat of the element-wise bitwidth of V.
+  auto BitWidthSplat = [](Constant *C, Value *V) {
+    return match(
+        C, m_SpecificInt_ICMP(ICmpInst::Predicate::ICMP_EQ,
+                              APInt(C->getType()->getScalarSizeInBits(),
+                                    V->getType()->getScalarSizeInBits())));
+  };
+
+  // It should look like variable-length sign-extension on the outside:
+  //   (Val << (bitwidth(Val)-Nbits)) a>> (bitwidth(Val)-Nbits)
+  Value *NBits;
+  Instruction *MaybeTrunc;
+  Constant *C1, *C2;
+  if (!match(&OldAShr,
+             m_AShr(m_Shl(m_Instruction(MaybeTrunc),
+                          m_ZExtOrSelf(m_Sub(m_Constant(C1),
+                                             m_ZExtOrSelf(m_Value(NBits))))),
+                    m_ZExtOrSelf(m_Sub(m_Constant(C2),
+                                       m_ZExtOrSelf(m_Deferred(NBits)))))) ||
+      !BitWidthSplat(C1, &OldAShr) || !BitWidthSplat(C2, &OldAShr))
+    return nullptr;
+
+  // There may or may not be a truncation after the outer two shifts.
+  Instruction *HighBitExtract;
+  match(MaybeTrunc, m_TruncOrSelf(m_Instruction(HighBitExtract)));
+  bool HadTrunc = MaybeTrunc != HighBitExtract;
+
+  // And finally, the innermost part of the pattern must be a right-shift.
+  Value *X, *NumLowBitsToSkip;
+  if (!match(HighBitExtract, m_Shr(m_Value(X), m_Value(NumLowBitsToSkip))))
+    return nullptr;
+
+  // Said right-shift must extract the high NBits bits - C0 must be its
+  // bitwidth.
+  Constant *C0;
+  if (!match(NumLowBitsToSkip,
+             m_ZExtOrSelf(
+                 m_Sub(m_Constant(C0), m_ZExtOrSelf(m_Specific(NBits))))) ||
+      !BitWidthSplat(C0, HighBitExtract))
+    return nullptr;
+
+  // Since NBits is identical for all shifts, if the outermost and innermost
+  // shifts are identical, then the outermost shifts are redundant.
+  // If we had truncation, do keep it though.
+  if (HighBitExtract->getOpcode() == OldAShr.getOpcode())
+    return replaceInstUsesWith(OldAShr, MaybeTrunc);
+
+  // Else, if there was a truncation, then we need to ensure that one
+  // instruction will go away.
+  if (HadTrunc && !match(&OldAShr, m_c_BinOp(m_OneUse(m_Value()), m_Value())))
+    return nullptr;
+
+  // Finally, bypass the two innermost shifts, and perform the outermost
+  // shift on the operands of the innermost shift.
+  Instruction *NewAShr =
+      BinaryOperator::Create(OldAShr.getOpcode(), X, NumLowBitsToSkip);
+  NewAShr->copyIRFlags(HighBitExtract); // We can preserve 'exact'-ness.
+  if (!HadTrunc)
+    return NewAShr;
+
+  Builder.Insert(NewAShr);
+  return TruncInst::CreateTruncOrBitCast(NewAShr, OldAShr.getType());
+}
+
 Instruction *InstCombiner::visitAShr(BinaryOperator &I) {
   if (Value *V = SimplifyAShrInst(I.getOperand(0), I.getOperand(1), I.isExact(),
                                   SQ.getWithInstruction(&I)))
@@ -1113,6 +1182,9 @@ Instruction *InstCombiner::visitAShr(BinaryOperator &I) {
     }
   }
 
+  if (Instruction *R = foldVariableSignZeroExtensionOfVariableHighBitExtract(I))
+    return R;
+
   // See if we can turn a signed shr into an unsigned shr.
   if (MaskedValueIsZero(Op0, APInt::getSignMask(BitWidth), 0, &I))
     return BinaryOperator::CreateLShr(Op0, Op1);
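
For reference, the early return above (where the innermost shift is already
an ashr, so the outer shl+ashr pair is fully redundant) corresponds to IR
like the @t3_notrunc_redundant_sext test below - a sketch using that test's
own values:

  %skip_high = sub i64 64, %nbits
  ; the inner ashr already sign-extends the high %nbits of %data
  %extracted = ashr i64 %data, %skip_high
  %num_high_bits_to_smear = sub i64 64, %nbits
  %signbit_positioned = shl i64 %extracted, %num_high_bits_to_smear
  %signextended = ashr i64 %signbit_positioned, %num_high_bits_to_smear
  ; the outer shl+ashr pair is a no-op, so %signextended is just %extracted
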
diff --git a/llvm/test/Transforms/InstCombine/variable-signext-of-variable-high-bit-extraction.ll b/llvm/test/Transforms/InstCombine/variable-signext-of-variable-high-bit-extraction.ll
index 61343c7feb8a..a5f38735a373 100644
--- a/llvm/test/Transforms/InstCombine/variable-signext-of-variable-high-bit-extraction.ll
+++ b/llvm/test/Transforms/InstCombine/variable-signext-of-variable-high-bit-extraction.ll
@@ -17,8 +17,8 @@ define i32 @t0(i64 %data, i32 %nbits) {
 ; CHECK-NEXT:    call void @use32(i32 [[EXTRACTED_NARROW]])
 ; CHECK-NEXT:    [[NUM_HIGH_BITS_TO_SMEAR_NARROW:%.*]] = sub i32 32, [[NBITS]]
 ; CHECK-NEXT:    call void @use32(i32 [[NUM_HIGH_BITS_TO_SMEAR_NARROW]])
-; CHECK-NEXT:    [[SIGNBIT_POSITIONED:%.*]] = shl i32 [[EXTRACTED_NARROW]], [[NUM_HIGH_BITS_TO_SMEAR_NARROW]]
-; CHECK-NEXT:    [[SIGNEXTENDED:%.*]] = ashr i32 [[SIGNBIT_POSITIONED]], [[NUM_HIGH_BITS_TO_SMEAR_NARROW]]
+; CHECK-NEXT:    [[TMP1:%.*]] = ashr i64 [[DATA]], [[SKIP_HIGH_WIDE]]
+; CHECK-NEXT:    [[SIGNEXTENDED:%.*]] = trunc i64 [[TMP1]] to i32
 ; CHECK-NEXT:    ret i32 [[SIGNEXTENDED]]
 ;
   %skip_high = sub i32 64, %nbits
@@ -51,8 +51,8 @@ define i32 @t0_zext_of_nbits(i64 %data, i8 %nbits_narrow) {
 ; CHECK-NEXT:    call void @use16(i16 [[NUM_HIGH_BITS_TO_SMEAR_NARROW_NARROW]])
 ; CHECK-NEXT:    [[NUM_HIGH_BITS_TO_SMEAR_NARROW:%.*]] = zext i16 [[NUM_HIGH_BITS_TO_SMEAR_NARROW_NARROW]] to i32
 ; CHECK-NEXT:    call void @use32(i32 [[NUM_HIGH_BITS_TO_SMEAR_NARROW]])
-; CHECK-NEXT:    [[SIGNBIT_POSITIONED:%.*]] = shl i32 [[EXTRACTED_NARROW]], [[NUM_HIGH_BITS_TO_SMEAR_NARROW]]
-; CHECK-NEXT:    [[SIGNEXTENDED:%.*]] = ashr i32 [[SIGNBIT_POSITIONED]], [[NUM_HIGH_BITS_TO_SMEAR_NARROW]]
+; CHECK-NEXT:    [[TMP1:%.*]] = ashr i64 [[DATA]], [[SKIP_HIGH_WIDE]]
+; CHECK-NEXT:    [[SIGNEXTENDED:%.*]] = trunc i64 [[TMP1]] to i32
 ; CHECK-NEXT:    ret i32 [[SIGNEXTENDED]]
 ;
   %nbits = zext i8 %nbits_narrow to i16
@@ -85,8 +85,8 @@ define i32 @t0_exact(i64 %data, i32 %nbits) {
 ; CHECK-NEXT:    call void @use32(i32 [[EXTRACTED_NARROW]])
 ; CHECK-NEXT:    [[NUM_HIGH_BITS_TO_SMEAR_NARROW:%.*]] = sub i32 32, [[NBITS]]
 ; CHECK-NEXT:    call void @use32(i32 [[NUM_HIGH_BITS_TO_SMEAR_NARROW]])
-; CHECK-NEXT:    [[SIGNBIT_POSITIONED:%.*]] = shl i32 [[EXTRACTED_NARROW]], [[NUM_HIGH_BITS_TO_SMEAR_NARROW]]
-; CHECK-NEXT:    [[SIGNEXTENDED:%.*]] = ashr i32 [[SIGNBIT_POSITIONED]], [[NUM_HIGH_BITS_TO_SMEAR_NARROW]]
+; CHECK-NEXT:    [[TMP1:%.*]] = ashr exact i64 [[DATA]], [[SKIP_HIGH_WIDE]]
+; CHECK-NEXT:    [[SIGNEXTENDED:%.*]] = trunc i64 [[TMP1]] to i32
 ; CHECK-NEXT:    ret i32 [[SIGNEXTENDED]]
 ;
   %skip_high = sub i32 64, %nbits
@@ -118,8 +118,7 @@ define i32 @t1_redundant_sext(i64 %data, i32 %nbits) {
 ; CHECK-NEXT:    call void @use32(i32 [[NUM_HIGH_BITS_TO_SMEAR_NARROW]])
 ; CHECK-NEXT:    [[SIGNBIT_POSITIONED:%.*]] = shl i32 [[EXTRACTED_WITH_SIGNEXTENSION_NARROW]], [[NUM_HIGH_BITS_TO_SMEAR_NARROW]]
 ; CHECK-NEXT:    call void @use32(i32 [[SIGNBIT_POSITIONED]])
-; CHECK-NEXT:    [[SIGNEXTENDED:%.*]] = ashr i32 [[SIGNBIT_POSITIONED]], [[NUM_HIGH_BITS_TO_SMEAR_NARROW]]
-; CHECK-NEXT:    ret i32 [[SIGNEXTENDED]]
+; CHECK-NEXT:    ret i32 [[EXTRACTED_WITH_SIGNEXTENSION_NARROW]]
 ;
   %skip_high = sub i32 64, %nbits
   call void @use32(i32 %skip_high)
@@ -147,7 +146,7 @@ define i64 @t2_notrunc(i64 %data, i64 %nbits) {
 ; CHECK-NEXT:    call void @use64(i64 [[NUM_HIGH_BITS_TO_SMEAR]])
 ; CHECK-NEXT:    [[SIGNBIT_POSITIONED:%.*]] = shl i64 [[EXTRACTED]], [[NUM_HIGH_BITS_TO_SMEAR]]
 ; CHECK-NEXT:    call void @use64(i64 [[SIGNBIT_POSITIONED]])
-; CHECK-NEXT:    [[SIGNEXTENDED:%.*]] = ashr i64 [[SIGNBIT_POSITIONED]], [[NUM_HIGH_BITS_TO_SMEAR]]
+; CHECK-NEXT:    [[SIGNEXTENDED:%.*]] = ashr i64 [[DATA]], [[SKIP_HIGH]]
 ; CHECK-NEXT:    ret i64 [[SIGNEXTENDED]]
 ;
   %skip_high = sub i64 64, %nbits
@@ -172,8 +171,7 @@ define i64 @t3_notrunc_redundant_sext(i64 %data, i64 %nbits) {
 ; CHECK-NEXT:    call void @use64(i64 [[NUM_HIGH_BITS_TO_SMEAR]])
 ; CHECK-NEXT:    [[SIGNBIT_POSITIONED:%.*]] = shl i64 [[EXTRACTED]], [[NUM_HIGH_BITS_TO_SMEAR]]
 ; CHECK-NEXT:    call void @use64(i64 [[SIGNBIT_POSITIONED]])
-; CHECK-NEXT:    [[SIGNEXTENDED:%.*]] = ashr i64 [[SIGNBIT_POSITIONED]], [[NUM_HIGH_BITS_TO_SMEAR]]
-; CHECK-NEXT:    ret i64 [[SIGNEXTENDED]]
+; CHECK-NEXT:    ret i64 [[EXTRACTED]]
 ;
   %skip_high = sub i64 64, %nbits
   call void @use64(i64 %skip_high)
@@ -191,11 +189,8 @@ define <2 x i32> @t4_vec(<2 x i64> %data, <2 x i32> %nbits) {
 ; CHECK-LABEL: @t4_vec(
 ; CHECK-NEXT:    [[SKIP_HIGH:%.*]] = sub <2 x i32> <i32 64, i32 64>, [[NBITS:%.*]]
 ; CHECK-NEXT:    [[SKIP_HIGH_WIDE:%.*]] = zext <2 x i32> [[SKIP_HIGH]] to <2 x i64>
-; CHECK-NEXT:    [[EXTRACTED:%.*]] = lshr <2 x i64> [[DATA:%.*]], [[SKIP_HIGH_WIDE]]
-; CHECK-NEXT:    [[EXTRACTED_NARROW:%.*]] = trunc <2 x i64> [[EXTRACTED]] to <2 x i32>
-; CHECK-NEXT:    [[NUM_HIGH_BITS_TO_SMEAR_NARROW:%.*]] = sub <2 x i32> <i32 32, i32 32>, [[NBITS]]
-; CHECK-NEXT:    [[SIGNBIT_POSITIONED:%.*]] = shl <2 x i32> [[EXTRACTED_NARROW]], [[NUM_HIGH_BITS_TO_SMEAR_NARROW]]
-; CHECK-NEXT:    [[SIGNEXTENDED:%.*]] = ashr <2 x i32> [[SIGNBIT_POSITIONED]], [[NUM_HIGH_BITS_TO_SMEAR_NARROW]]
+; CHECK-NEXT:    [[TMP1:%.*]] = ashr <2 x i64> [[DATA:%.*]], [[SKIP_HIGH_WIDE]]
+; CHECK-NEXT:    [[SIGNEXTENDED:%.*]] = trunc <2 x i64> [[TMP1]] to <2 x i32>
 ; CHECK-NEXT:    ret <2 x i32> [[SIGNEXTENDED]]
 ;
   %skip_high = sub <2 x i32> <i32 64, i32 64>, %nbits
@@ -212,12 +207,8 @@ define <3 x i32> @t5_vec_undef(<3 x i64> %data, <3 x i32> %nbits) {
 ; CHECK-LABEL: @t5_vec_undef(
 ; CHECK-NEXT:    [[SKIP_HIGH:%.*]] = sub <3 x i32> <i32 64, i32 undef, i32 64>, [[NBITS:%.*]]
 ; CHECK-NEXT:    [[SKIP_HIGH_WIDE:%.*]] = zext <3 x i32> [[SKIP_HIGH]] to <3 x i64>
-; CHECK-NEXT:    [[EXTRACTED:%.*]] = lshr <3 x i64> [[DATA:%.*]], [[SKIP_HIGH_WIDE]]
-; CHECK-NEXT:    [[EXTRACTED_NARROW:%.*]] = trunc <3 x i64> [[EXTRACTED]] to <3 x i32>
-; CHECK-NEXT:    [[NUM_HIGH_BITS_TO_SMEAR_NARROW0:%.*]] = sub <3 x i32> <i32 32, i32 undef, i32 32>, [[NBITS]]
-; CHECK-NEXT:    [[NUM_HIGH_BITS_TO_SMEAR_NARROW1:%.*]] = sub <3 x i32> <i32 32, i32 32, i32 undef>, [[NBITS]]
-; CHECK-NEXT:    [[SIGNBIT_POSITIONED:%.*]] = shl <3 x i32> [[EXTRACTED_NARROW]], [[NUM_HIGH_BITS_TO_SMEAR_NARROW0]]
-; CHECK-NEXT:    [[SIGNEXTENDED:%.*]] = ashr <3 x i32> [[SIGNBIT_POSITIONED]], [[NUM_HIGH_BITS_TO_SMEAR_NARROW1]]
+; CHECK-NEXT:    [[TMP1:%.*]] = ashr <3 x i64> [[DATA:%.*]], [[SKIP_HIGH_WIDE]]
+; CHECK-NEXT:    [[SIGNEXTENDED:%.*]] = trunc <3 x i64> [[TMP1]] to <3 x i32>
 ; CHECK-NEXT:    ret <3 x i32> [[SIGNEXTENDED]]
 ;
   %skip_high = sub <3 x i32> <i32 64, i32 undef, i32 64>, %nbits
@@ -244,8 +235,8 @@ define i32 @t6_extrause_good0(i64 %data, i32 %nbits) {
 ; CHECK-NEXT:    call void @use32(i32 [[EXTRACTED_NARROW]])
 ; CHECK-NEXT:    [[NUM_HIGH_BITS_TO_SMEAR_NARROW:%.*]] = sub i32 32, [[NBITS]]
 ; CHECK-NEXT:    call void @use32(i32 [[NUM_HIGH_BITS_TO_SMEAR_NARROW]])
-; CHECK-NEXT:    [[SIGNBIT_POSITIONED:%.*]] = shl i32 [[EXTRACTED_NARROW]], [[NUM_HIGH_BITS_TO_SMEAR_NARROW]]
-; CHECK-NEXT:    [[SIGNEXTENDED:%.*]] = ashr i32 [[SIGNBIT_POSITIONED]], [[NUM_HIGH_BITS_TO_SMEAR_NARROW]]
+; CHECK-NEXT:    [[TMP1:%.*]] = ashr i64 [[DATA]], [[SKIP_HIGH_WIDE]]
+; CHECK-NEXT:    [[SIGNEXTENDED:%.*]] = trunc i64 [[TMP1]] to i32
 ; CHECK-NEXT:    ret i32 [[SIGNEXTENDED]]
 ;
   %skip_high = sub i32 64, %nbits
@@ -274,10 +265,10 @@ define i32 @t7_extrause_good1(i64 %data, i32 %nbits) {
 ; CHECK-NEXT:    call void @use32(i32 [[EXTRACTED_NARROW]])
 ; CHECK-NEXT:    [[NUM_HIGH_BITS_TO_SMEAR_NARROW0:%.*]] = sub i32 32, [[NBITS]]
 ; CHECK-NEXT:    call void @use32(i32 [[NUM_HIGH_BITS_TO_SMEAR_NARROW0]])
-; CHECK-NEXT:    [[NUM_HIGH_BITS_TO_SMEAR_NARROW1:%.*]] = sub i32 32, [[NBITS]]
 ; CHECK-NEXT:    [[SIGNBIT_POSITIONED:%.*]] = shl i32 [[EXTRACTED_NARROW]], [[NUM_HIGH_BITS_TO_SMEAR_NARROW0]]
 ; CHECK-NEXT:    call void @use32(i32 [[SIGNBIT_POSITIONED]])
-; CHECK-NEXT:    [[SIGNEXTENDED:%.*]] = ashr i32 [[SIGNBIT_POSITIONED]], [[NUM_HIGH_BITS_TO_SMEAR_NARROW1]]
+; CHECK-NEXT:    [[TMP1:%.*]] = ashr i64 [[DATA]], [[SKIP_HIGH_WIDE]]
+; CHECK-NEXT:    [[SIGNEXTENDED:%.*]] = trunc i64 [[TMP1]] to i32
 ; CHECK-NEXT:    ret i32 [[SIGNEXTENDED]]
 ;
   %skip_high = sub i32 64, %nbits