From 9c3138bd6d8b3e303f0f711753506b330ffa8df0 Mon Sep 17 00:00:00 2001 From: Simon Pilgrim Date: Tue, 13 Oct 2020 14:35:02 +0100 Subject: [PATCH] [InstCombine] visitTrunc - pass through undefs for trunc(shift(trunc/ext(x),c)) patterns Based on the recent patches D88475 and D88429 where we are losing undef values due to extension/comparisons. I've added a Constant::mergeUndefsWith method that merges the undef scalar/elements from another Constant into a specific Constant. Differential Revision: https://reviews.llvm.org/D88687 --- llvm/include/llvm/IR/Constant.h | 6 ++++ llvm/lib/IR/Constants.cpp | 34 +++++++++++++++++++ .../InstCombine/InstCombineCasts.cpp | 8 ++--- llvm/test/Transforms/InstCombine/cast.ll | 4 +-- .../InstCombine/trunc-shift-trunc.ll | 4 +-- 5 files changed, 47 insertions(+), 9 deletions(-) diff --git a/llvm/include/llvm/IR/Constant.h b/llvm/include/llvm/IR/Constant.h index f4cdef2af774..97650c2051ca 100644 --- a/llvm/include/llvm/IR/Constant.h +++ b/llvm/include/llvm/IR/Constant.h @@ -204,6 +204,12 @@ public: /// Try to replace undefined constant C or undefined elements in C with /// Replacement. If no changes are made, the constant C is returned. static Constant *replaceUndefsWith(Constant *C, Constant *Replacement); + + /// Merges undefs of a Constant with another Constant, along with the + /// undefs already present. Other doesn't have to be the same type as C, but + /// both must either be scalars or vectors with the same element count. If no + /// changes are made, the constant C is returned. + static Constant *mergeUndefsWith(Constant *C, Constant *Other); }; } // end namespace llvm diff --git a/llvm/lib/IR/Constants.cpp b/llvm/lib/IR/Constants.cpp index 9f83861f2aa6..7eca7dddf4a4 100644 --- a/llvm/lib/IR/Constants.cpp +++ b/llvm/lib/IR/Constants.cpp @@ -737,6 +737,40 @@ Constant *Constant::replaceUndefsWith(Constant *C, Constant *Replacement) { return ConstantVector::get(NewC); } +Constant *Constant::mergeUndefsWith(Constant *C, Constant *Other) { + assert(C && Other && "Expected non-nullptr constant arguments"); + if (match(C, m_Undef())) + return C; + + Type *Ty = C->getType(); + if (match(Other, m_Undef())) + return UndefValue::get(Ty); + + auto *VTy = dyn_cast(Ty); + if (!VTy) + return C; + + Type *EltTy = VTy->getElementType(); + unsigned NumElts = VTy->getNumElements(); + assert(isa(Other->getType()) && + cast(Other->getType())->getNumElements() == NumElts && + "Type mismatch"); + + bool FoundExtraUndef = false; + SmallVector NewC(NumElts); + for (unsigned I = 0; I != NumElts; ++I) { + NewC[I] = C->getAggregateElement(I); + Constant *OtherEltC = Other->getAggregateElement(I); + assert(NewC[I] && OtherEltC && "Unknown vector element"); + if (!match(NewC[I], m_Undef()) && match(OtherEltC, m_Undef())) { + NewC[I] = UndefValue::get(EltTy); + FoundExtraUndef = true; + } + } + if (FoundExtraUndef) + return ConstantVector::get(NewC); + return C; +} //===----------------------------------------------------------------------===// // ConstantInt diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp index e259b898351d..478032f56bdf 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp @@ -810,8 +810,6 @@ Instruction *InstCombinerImpl::visitTrunc(TruncInst &Trunc) { // If the shift is small enough, all zero bits created by the shift are // removed by the trunc. - // TODO: Support passing through undef shift amounts - these currently get - // clamped to MaxAmt. if (match(C, m_SpecificInt_ICMP(ICmpInst::ICMP_ULE, APInt(SrcWidth, MaxShiftAmt)))) { // trunc (lshr (sext A), C) --> ashr A, C @@ -819,6 +817,7 @@ Instruction *InstCombinerImpl::visitTrunc(TruncInst &Trunc) { Constant *MaxAmt = ConstantInt::get(SrcTy, DestWidth - 1, false); Constant *ShAmt = ConstantExpr::getUMin(C, MaxAmt); ShAmt = ConstantExpr::getTrunc(ShAmt, A->getType()); + ShAmt = Constant::mergeUndefsWith(ShAmt, C); return IsExact ? BinaryOperator::CreateExactAShr(A, ShAmt) : BinaryOperator::CreateAShr(A, ShAmt); } @@ -841,13 +840,12 @@ Instruction *InstCombinerImpl::visitTrunc(TruncInst &Trunc) { // If the shift is small enough, all zero/sign bits created by the shift are // removed by the trunc. - // TODO: Support passing through undef shift amounts - these currently get - // zero'd by getIntegerCast. if (match(C, m_SpecificInt_ICMP(ICmpInst::ICMP_ULE, APInt(SrcWidth, MaxShiftAmt)))) { auto *OldShift = cast(Src); - auto *ShAmt = ConstantExpr::getIntegerCast(C, A->getType(), true); bool IsExact = OldShift->isExact(); + auto *ShAmt = ConstantExpr::getIntegerCast(C, A->getType(), true); + ShAmt = Constant::mergeUndefsWith(ShAmt, C); Value *Shift = OldShift->getOpcode() == Instruction::AShr ? Builder.CreateAShr(A, ShAmt, OldShift->getName(), IsExact) diff --git a/llvm/test/Transforms/InstCombine/cast.ll b/llvm/test/Transforms/InstCombine/cast.ll index c5f18b4c625e..c7217b9f4dd3 100644 --- a/llvm/test/Transforms/InstCombine/cast.ll +++ b/llvm/test/Transforms/InstCombine/cast.ll @@ -1570,7 +1570,7 @@ define <2 x i8> @trunc_lshr_sext_uniform(<2 x i8> %A) { define <2 x i8> @trunc_lshr_sext_uniform_undef(<2 x i8> %A) { ; ALL-LABEL: @trunc_lshr_sext_uniform_undef( -; ALL-NEXT: [[D:%.*]] = ashr <2 x i8> [[A:%.*]], +; ALL-NEXT: [[D:%.*]] = ashr <2 x i8> [[A:%.*]], ; ALL-NEXT: ret <2 x i8> [[D]] ; %B = sext <2 x i8> %A to <2 x i32> @@ -1592,7 +1592,7 @@ define <2 x i8> @trunc_lshr_sext_nonuniform(<2 x i8> %A) { define <3 x i8> @trunc_lshr_sext_nonuniform_undef(<3 x i8> %A) { ; ALL-LABEL: @trunc_lshr_sext_nonuniform_undef( -; ALL-NEXT: [[D:%.*]] = ashr <3 x i8> [[A:%.*]], +; ALL-NEXT: [[D:%.*]] = ashr <3 x i8> [[A:%.*]], ; ALL-NEXT: ret <3 x i8> [[D]] ; %B = sext <3 x i8> %A to <3 x i32> diff --git a/llvm/test/Transforms/InstCombine/trunc-shift-trunc.ll b/llvm/test/Transforms/InstCombine/trunc-shift-trunc.ll index 7a4a9c189727..269b4619974b 100644 --- a/llvm/test/Transforms/InstCombine/trunc-shift-trunc.ll +++ b/llvm/test/Transforms/InstCombine/trunc-shift-trunc.ll @@ -45,7 +45,7 @@ define <2 x i8> @trunc_lshr_trunc_nonuniform(<2 x i64> %a) { define <2 x i8> @trunc_lshr_trunc_uniform_undef(<2 x i64> %a) { ; CHECK-LABEL: @trunc_lshr_trunc_uniform_undef( -; CHECK-NEXT: [[C1:%.*]] = lshr <2 x i64> [[A:%.*]], +; CHECK-NEXT: [[C1:%.*]] = lshr <2 x i64> [[A:%.*]], ; CHECK-NEXT: [[D:%.*]] = trunc <2 x i64> [[C1]] to <2 x i8> ; CHECK-NEXT: ret <2 x i8> [[D]] ; @@ -131,7 +131,7 @@ define <2 x i8> @trunc_ashr_trunc_nonuniform(<2 x i64> %a) { define <2 x i8> @trunc_ashr_trunc_uniform_undef(<2 x i64> %a) { ; CHECK-LABEL: @trunc_ashr_trunc_uniform_undef( -; CHECK-NEXT: [[C1:%.*]] = ashr <2 x i64> [[A:%.*]], +; CHECK-NEXT: [[C1:%.*]] = ashr <2 x i64> [[A:%.*]], ; CHECK-NEXT: [[D:%.*]] = trunc <2 x i64> [[C1]] to <2 x i8> ; CHECK-NEXT: ret <2 x i8> [[D]] ;