[InstCombine] visitTrunc - pass through undefs for trunc(shift(trunc/ext(x),c)) patterns

Based on the recent patches D88475 and D88429 where we are losing undef values due to extension/comparisons.

I've added a Constant::mergeUndefsWith method that merges the undef scalar/elements from another Constant into a specific Constant.

Differential Revision: https://reviews.llvm.org/D88687
This commit is contained in:
Simon Pilgrim 2020-10-13 14:35:02 +01:00
parent 2e604d23b4
commit 9c3138bd6d
5 changed files with 47 additions and 9 deletions

View File

@ -204,6 +204,12 @@ public:
/// Try to replace undefined constant C or undefined elements in C with /// Try to replace undefined constant C or undefined elements in C with
/// Replacement. If no changes are made, the constant C is returned. /// Replacement. If no changes are made, the constant C is returned.
static Constant *replaceUndefsWith(Constant *C, Constant *Replacement); static Constant *replaceUndefsWith(Constant *C, Constant *Replacement);
/// Merges the undefs of Other into C: returns a constant of C's type that is
/// undef wherever C is already undef or the corresponding scalar/element of
/// Other is undef, and that keeps C's value everywhere else. Other doesn't
/// have to be the same type as C, but both must either be scalars or vectors
/// with the same element count. If no changes are made, the constant C is
/// returned.
static Constant *mergeUndefsWith(Constant *C, Constant *Other);
}; };
} // end namespace llvm } // end namespace llvm

View File

@ -737,6 +737,40 @@ Constant *Constant::replaceUndefsWith(Constant *C, Constant *Replacement) {
return ConstantVector::get(NewC); return ConstantVector::get(NewC);
} }
// Propagate the undef-ness of Other onto C. The result always has C's type;
// C itself is returned whenever no additional undef has to be introduced.
Constant *Constant::mergeUndefsWith(Constant *C, Constant *Other) {
  assert(C && Other && "Expected non-nullptr constant arguments");

  // C is already entirely undef: there is nothing further to merge in.
  if (match(C, m_Undef()))
    return C;

  // Other is entirely undef: the whole result becomes undef, typed like C.
  Type *ResultTy = C->getType();
  if (match(Other, m_Undef()))
    return UndefValue::get(ResultTy);

  // Anything that is not a fixed-width vector is handed back unchanged; only
  // fixed vectors are merged element by element below.
  auto *VecTy = dyn_cast<FixedVectorType>(ResultTy);
  if (!VecTy)
    return C;

  unsigned NumElts = VecTy->getNumElements();
  Type *EltTy = VecTy->getElementType();
  assert(isa<FixedVectorType>(Other->getType()) &&
         cast<FixedVectorType>(Other->getType())->getNumElements() == NumElts &&
         "Type mismatch");

  // Copy C's elements, switching an element to undef only where C has a
  // defined value and Other has undef in the same position.
  SmallVector<Constant *, 32> Merged(NumElts);
  bool Changed = false;
  for (unsigned Idx = 0; Idx != NumElts; ++Idx) {
    Constant *Elt = C->getAggregateElement(Idx);
    Constant *OtherElt = Other->getAggregateElement(Idx);
    assert(Elt && OtherElt && "Unknown vector element");

    if (!match(Elt, m_Undef()) && match(OtherElt, m_Undef())) {
      Elt = UndefValue::get(EltTy);
      Changed = true;
    }
    Merged[Idx] = Elt;
  }

  // Only build a new vector constant when at least one element was replaced.
  if (!Changed)
    return C;
  return ConstantVector::get(Merged);
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// ConstantInt // ConstantInt

View File

@ -810,8 +810,6 @@ Instruction *InstCombinerImpl::visitTrunc(TruncInst &Trunc) {
// If the shift is small enough, all zero bits created by the shift are // If the shift is small enough, all zero bits created by the shift are
// removed by the trunc. // removed by the trunc.
// TODO: Support passing through undef shift amounts - these currently get
// clamped to MaxAmt.
if (match(C, m_SpecificInt_ICMP(ICmpInst::ICMP_ULE, if (match(C, m_SpecificInt_ICMP(ICmpInst::ICMP_ULE,
APInt(SrcWidth, MaxShiftAmt)))) { APInt(SrcWidth, MaxShiftAmt)))) {
// trunc (lshr (sext A), C) --> ashr A, C // trunc (lshr (sext A), C) --> ashr A, C
@ -819,6 +817,7 @@ Instruction *InstCombinerImpl::visitTrunc(TruncInst &Trunc) {
Constant *MaxAmt = ConstantInt::get(SrcTy, DestWidth - 1, false); Constant *MaxAmt = ConstantInt::get(SrcTy, DestWidth - 1, false);
Constant *ShAmt = ConstantExpr::getUMin(C, MaxAmt); Constant *ShAmt = ConstantExpr::getUMin(C, MaxAmt);
ShAmt = ConstantExpr::getTrunc(ShAmt, A->getType()); ShAmt = ConstantExpr::getTrunc(ShAmt, A->getType());
ShAmt = Constant::mergeUndefsWith(ShAmt, C);
return IsExact ? BinaryOperator::CreateExactAShr(A, ShAmt) return IsExact ? BinaryOperator::CreateExactAShr(A, ShAmt)
: BinaryOperator::CreateAShr(A, ShAmt); : BinaryOperator::CreateAShr(A, ShAmt);
} }
@ -841,13 +840,12 @@ Instruction *InstCombinerImpl::visitTrunc(TruncInst &Trunc) {
// If the shift is small enough, all zero/sign bits created by the shift are // If the shift is small enough, all zero/sign bits created by the shift are
// removed by the trunc. // removed by the trunc.
// TODO: Support passing through undef shift amounts - these currently get
// zero'd by getIntegerCast.
if (match(C, m_SpecificInt_ICMP(ICmpInst::ICMP_ULE, if (match(C, m_SpecificInt_ICMP(ICmpInst::ICMP_ULE,
APInt(SrcWidth, MaxShiftAmt)))) { APInt(SrcWidth, MaxShiftAmt)))) {
auto *OldShift = cast<Instruction>(Src); auto *OldShift = cast<Instruction>(Src);
auto *ShAmt = ConstantExpr::getIntegerCast(C, A->getType(), true);
bool IsExact = OldShift->isExact(); bool IsExact = OldShift->isExact();
auto *ShAmt = ConstantExpr::getIntegerCast(C, A->getType(), true);
ShAmt = Constant::mergeUndefsWith(ShAmt, C);
Value *Shift = Value *Shift =
OldShift->getOpcode() == Instruction::AShr OldShift->getOpcode() == Instruction::AShr
? Builder.CreateAShr(A, ShAmt, OldShift->getName(), IsExact) ? Builder.CreateAShr(A, ShAmt, OldShift->getName(), IsExact)

View File

@ -1570,7 +1570,7 @@ define <2 x i8> @trunc_lshr_sext_uniform(<2 x i8> %A) {
define <2 x i8> @trunc_lshr_sext_uniform_undef(<2 x i8> %A) { define <2 x i8> @trunc_lshr_sext_uniform_undef(<2 x i8> %A) {
; ALL-LABEL: @trunc_lshr_sext_uniform_undef( ; ALL-LABEL: @trunc_lshr_sext_uniform_undef(
; ALL-NEXT: [[D:%.*]] = ashr <2 x i8> [[A:%.*]], <i8 6, i8 7> ; ALL-NEXT: [[D:%.*]] = ashr <2 x i8> [[A:%.*]], <i8 6, i8 undef>
; ALL-NEXT: ret <2 x i8> [[D]] ; ALL-NEXT: ret <2 x i8> [[D]]
; ;
%B = sext <2 x i8> %A to <2 x i32> %B = sext <2 x i8> %A to <2 x i32>
@ -1592,7 +1592,7 @@ define <2 x i8> @trunc_lshr_sext_nonuniform(<2 x i8> %A) {
define <3 x i8> @trunc_lshr_sext_nonuniform_undef(<3 x i8> %A) { define <3 x i8> @trunc_lshr_sext_nonuniform_undef(<3 x i8> %A) {
; ALL-LABEL: @trunc_lshr_sext_nonuniform_undef( ; ALL-LABEL: @trunc_lshr_sext_nonuniform_undef(
; ALL-NEXT: [[D:%.*]] = ashr <3 x i8> [[A:%.*]], <i8 6, i8 2, i8 7> ; ALL-NEXT: [[D:%.*]] = ashr <3 x i8> [[A:%.*]], <i8 6, i8 2, i8 undef>
; ALL-NEXT: ret <3 x i8> [[D]] ; ALL-NEXT: ret <3 x i8> [[D]]
; ;
%B = sext <3 x i8> %A to <3 x i32> %B = sext <3 x i8> %A to <3 x i32>

View File

@ -45,7 +45,7 @@ define <2 x i8> @trunc_lshr_trunc_nonuniform(<2 x i64> %a) {
define <2 x i8> @trunc_lshr_trunc_uniform_undef(<2 x i64> %a) { define <2 x i8> @trunc_lshr_trunc_uniform_undef(<2 x i64> %a) {
; CHECK-LABEL: @trunc_lshr_trunc_uniform_undef( ; CHECK-LABEL: @trunc_lshr_trunc_uniform_undef(
; CHECK-NEXT: [[C1:%.*]] = lshr <2 x i64> [[A:%.*]], <i64 24, i64 0> ; CHECK-NEXT: [[C1:%.*]] = lshr <2 x i64> [[A:%.*]], <i64 24, i64 undef>
; CHECK-NEXT: [[D:%.*]] = trunc <2 x i64> [[C1]] to <2 x i8> ; CHECK-NEXT: [[D:%.*]] = trunc <2 x i64> [[C1]] to <2 x i8>
; CHECK-NEXT: ret <2 x i8> [[D]] ; CHECK-NEXT: ret <2 x i8> [[D]]
; ;
@ -131,7 +131,7 @@ define <2 x i8> @trunc_ashr_trunc_nonuniform(<2 x i64> %a) {
define <2 x i8> @trunc_ashr_trunc_uniform_undef(<2 x i64> %a) { define <2 x i8> @trunc_ashr_trunc_uniform_undef(<2 x i64> %a) {
; CHECK-LABEL: @trunc_ashr_trunc_uniform_undef( ; CHECK-LABEL: @trunc_ashr_trunc_uniform_undef(
; CHECK-NEXT: [[C1:%.*]] = ashr <2 x i64> [[A:%.*]], <i64 8, i64 0> ; CHECK-NEXT: [[C1:%.*]] = ashr <2 x i64> [[A:%.*]], <i64 8, i64 undef>
; CHECK-NEXT: [[D:%.*]] = trunc <2 x i64> [[C1]] to <2 x i8> ; CHECK-NEXT: [[D:%.*]] = trunc <2 x i64> [[C1]] to <2 x i8>
; CHECK-NEXT: ret <2 x i8> [[D]] ; CHECK-NEXT: ret <2 x i8> [[D]]
; ;