[DAG] DAGCombiner::visitVECTOR_SHUFFLE - recognise INSERT_SUBVECTOR patterns

IR typically creates INSERT_SUBVECTOR patterns as a widening of the subvector with undefs to pad it to the destination size, followed by a shuffle for the actual insertion; SelectionDAGBuilder has to do something similar for shuffles whose source and destination vectors are different sizes.
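
As a minimal sketch of the pattern being described (hypothetical function and value names), a <2 x i32> subvector is widened with undefs and then shuffled into element 4 of an <8 x i32> base vector:

  define <8 x i32> @insert_v2i32_into_v8i32(<8 x i32> %base, <2 x i32> %sub) {
    ; widen the subvector with undefs to the destination width
    %widened = shufflevector <2 x i32> %sub, <2 x i32> undef,
                             <8 x i32> <i32 0, i32 1, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef>
    ; shuffle the widened subvector into the base vector at element 4
    %ins = shufflevector <8 x i32> %base, <8 x i32> %widened,
                         <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 8, i32 9, i32 6, i32 7>
    ret <8 x i32> %ins
  }

With this patch the second shuffle above can be folded to an INSERT_SUBVECTOR of %sub into %base at element 4.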

This combine attempts to recognize these patterns by looking for a shuffle that inserts a subvector (taken from a CONCAT_VECTORS operand) at an offset that is a multiple of the subvector's size into an otherwise identity shuffle of the base vector.
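
Conversely, a sketch (again with hypothetical names) of a shuffle this combine does not recognise, because the subvector lands at offset 3, which is not a multiple of its size:

  define <8 x i32> @no_insert_at_odd_offset(<8 x i32> %base, <2 x i32> %sub) {
    %widened = shufflevector <2 x i32> %sub, <2 x i32> undef,
                             <8 x i32> <i32 0, i32 1, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef>
    ; subvector placed at elements 3 and 4 - not a multiple of the subvector size
    %ins = shufflevector <8 x i32> %base, <8 x i32> %widened,
                         <8 x i32> <i32 0, i32 1, i32 2, i32 8, i32 9, i32 5, i32 6, i32 7>
    ret <8 x i32> %ins
  }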

This uncovered a couple of target-specific issues, as we haven't often created INSERT_SUBVECTOR nodes in generic code: aarch64 could only handle insertions into the bottom of an undef vector (i.e. a vector widening), and x86-avx512 vXi1 insertion wasn't keeping track of undef elements in the base vector.
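
As a hedged illustration of the aarch64 case (hypothetical function and value names), the sketch below inserts a <2 x i32> subvector into the upper half of a non-undef <4 x i32> base, so it is not a simple vector widening; the new aarch64 INSERT_SUBVECTOR combine turns such an insertion into a concat_vectors of the low half of the base with the subvector:

  define <4 x i32> @insert_high_v2i32(<4 x i32> %base, <2 x i32> %sub) {
    %widened = shufflevector <2 x i32> %sub, <2 x i32> undef,
                             <4 x i32> <i32 0, i32 1, i32 undef, i32 undef>
    %ins = shufflevector <4 x i32> %base, <4 x i32> %widened,
                         <4 x i32> <i32 0, i32 1, i32 4, i32 5>
    ret <4 x i32> %ins
  }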

Fixes PR50053

Differential Revision: https://reviews.llvm.org/D107068
Simon Pilgrim 2021-08-05 15:04:06 +01:00
parent 38b098be66
commit 2cbf9fd402
9 changed files with 147 additions and 36 deletions


@@ -21299,6 +21299,70 @@ SDValue DAGCombiner::visitVECTOR_SHUFFLE(SDNode *N) {
    }
  }
  // See if we can replace a shuffle with an insert_subvector.
  // e.g. v2i32 into v8i32:
  // shuffle(lhs,concat(rhs0,rhs1,rhs2,rhs3),0,1,2,3,10,11,6,7).
  // --> insert_subvector(lhs,rhs1,4).
  if (Level < AfterLegalizeVectorOps && TLI.isTypeLegal(VT) &&
      TLI.isOperationLegalOrCustom(ISD::INSERT_SUBVECTOR, VT)) {
    auto ShuffleToInsert = [&](SDValue LHS, SDValue RHS, ArrayRef<int> Mask) {
      // Ensure RHS subvectors are legal.
      assert(RHS.getOpcode() == ISD::CONCAT_VECTORS && "Can't find subvectors");
      EVT SubVT = RHS.getOperand(0).getValueType();
      int NumSubVecs = RHS.getNumOperands();
      int NumSubElts = SubVT.getVectorNumElements();
      assert((NumElts % NumSubElts) == 0 && "Subvector mismatch");
      if (!TLI.isTypeLegal(SubVT))
        return SDValue();
      // Don't bother if we have an unary shuffle (matches undef + LHS elts).
      if (all_of(Mask, [NumElts](int M) { return M < (int)NumElts; }))
        return SDValue();
      // Search [NumSubElts] spans for RHS sequence.
      // TODO: Can we avoid nested loops to increase performance?
      SmallVector<int> InsertionMask(NumElts);
      for (int SubVec = 0; SubVec != NumSubVecs; ++SubVec) {
        for (int SubIdx = 0; SubIdx != (int)NumElts; SubIdx += NumSubElts) {
          // Reset mask to identity.
          std::iota(InsertionMask.begin(), InsertionMask.end(), 0);
          // Add subvector insertion.
          std::iota(InsertionMask.begin() + SubIdx,
                    InsertionMask.begin() + SubIdx + NumSubElts,
                    NumElts + (SubVec * NumSubElts));
          // See if the shuffle mask matches the reference insertion mask.
          bool MatchingShuffle = true;
          for (int i = 0; i != (int)NumElts; ++i) {
            int ExpectIdx = InsertionMask[i];
            int ActualIdx = Mask[i];
            if (0 <= ActualIdx && ExpectIdx != ActualIdx) {
              MatchingShuffle = false;
              break;
            }
          }
          if (MatchingShuffle)
            return DAG.getNode(ISD::INSERT_SUBVECTOR, SDLoc(N), VT, LHS,
                               RHS.getOperand(SubVec),
                               DAG.getVectorIdxConstant(SubIdx, SDLoc(N)));
        }
      }
      return SDValue();
    };
    ArrayRef<int> Mask = SVN->getMask();
    if (N1.getOpcode() == ISD::CONCAT_VECTORS)
      if (SDValue InsertN1 = ShuffleToInsert(N0, N1, Mask))
        return InsertN1;
    if (N0.getOpcode() == ISD::CONCAT_VECTORS) {
      SmallVector<int> CommuteMask(Mask.begin(), Mask.end());
      ShuffleVectorSDNode::commuteMask(CommuteMask);
      if (SDValue InsertN0 = ShuffleToInsert(N1, N0, CommuteMask))
        return InsertN0;
    }
  }
  // Attempt to combine a shuffle of 2 inputs of 'scalar sources' -
  // BUILD_VECTOR or SCALAR_TO_VECTOR into a single BUILD_VECTOR.
  if (Level < AfterLegalizeDAG && TLI.isTypeLegal(VT))


@@ -905,6 +905,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
  setTargetDAGCombine(ISD::SIGN_EXTEND_INREG);
  setTargetDAGCombine(ISD::TRUNCATE);
  setTargetDAGCombine(ISD::CONCAT_VECTORS);
  setTargetDAGCombine(ISD::INSERT_SUBVECTOR);
  setTargetDAGCombine(ISD::STORE);
  if (Subtarget->supportsAddressTopByteIgnored())
    setTargetDAGCombine(ISD::LOAD);
@@ -13617,6 +13618,48 @@ static SDValue performConcatVectorsCombine(SDNode *N,
                     RHS));
}

static SDValue
performInsertSubvectorCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
                              SelectionDAG &DAG) {
  SDValue Vec = N->getOperand(0);
  SDValue SubVec = N->getOperand(1);
  uint64_t IdxVal = N->getConstantOperandVal(2);
  EVT VecVT = Vec.getValueType();
  EVT SubVT = SubVec.getValueType();

  // Only do this for legal fixed vector types.
  if (!VecVT.isFixedLengthVector() ||
      !DAG.getTargetLoweringInfo().isTypeLegal(VecVT) ||
      !DAG.getTargetLoweringInfo().isTypeLegal(SubVT))
    return SDValue();

  // Ignore widening patterns.
  if (IdxVal == 0 && Vec.isUndef())
    return SDValue();

  // Subvector must be half the width and an "aligned" insertion.
  unsigned NumSubElts = SubVT.getVectorNumElements();
  if ((SubVT.getSizeInBits() * 2) != VecVT.getSizeInBits() ||
      (IdxVal != 0 && IdxVal != NumSubElts))
    return SDValue();

  // Fold insert_subvector -> concat_vectors
  // insert_subvector(Vec,Sub,lo) -> concat_vectors(Sub,extract(Vec,hi))
  // insert_subvector(Vec,Sub,hi) -> concat_vectors(extract(Vec,lo),Sub)
  SDLoc DL(N);
  SDValue Lo, Hi;
  if (IdxVal == 0) {
    Lo = SubVec;
    Hi = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, SubVT, Vec,
                     DAG.getVectorIdxConstant(NumSubElts, DL));
  } else {
    Lo = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, SubVT, Vec,
                     DAG.getVectorIdxConstant(0, DL));
    Hi = SubVec;
  }
  return DAG.getNode(ISD::CONCAT_VECTORS, DL, VecVT, Lo, Hi);
}

static SDValue tryCombineFixedPointConvert(SDNode *N,
                                           TargetLowering::DAGCombinerInfo &DCI,
                                           SelectionDAG &DAG) {
@@ -16673,6 +16716,8 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
    return performVectorTruncateCombine(N, DCI, DAG);
  case ISD::CONCAT_VECTORS:
    return performConcatVectorsCombine(N, DCI, DAG);
  case ISD::INSERT_SUBVECTOR:
    return performInsertSubvectorCombine(N, DCI, DAG);
  case ISD::SELECT:
    return performSelectCombine(N, DCI);
  case ISD::VSELECT:


@@ -6206,14 +6206,21 @@ static SDValue insert1BitVector(SDValue Op, SelectionDAG &DAG,
  if (ISD::isBuildVectorAllZeros(Vec.getNode())) {
    assert(IdxVal != 0 && "Unexpected index");
    NumElems = WideOpVT.getVectorNumElements();
    unsigned ShiftLeft = NumElems - SubVecNumElems;
    unsigned ShiftRight = NumElems - SubVecNumElems - IdxVal;
    SubVec = DAG.getNode(X86ISD::KSHIFTL, dl, WideOpVT, SubVec,
                         DAG.getTargetConstant(ShiftLeft, dl, MVT::i8));
    if (ShiftRight != 0)
      SubVec = DAG.getNode(X86ISD::KSHIFTR, dl, WideOpVT, SubVec,
                           DAG.getTargetConstant(ShiftRight, dl, MVT::i8));
    // If upper elements of Vec are known undef, then just shift into place.
    if (llvm::all_of(Vec->ops().slice(IdxVal + SubVecNumElems),
                     [](SDValue V) { return V.isUndef(); })) {
      SubVec = DAG.getNode(X86ISD::KSHIFTL, dl, WideOpVT, SubVec,
                           DAG.getTargetConstant(IdxVal, dl, MVT::i8));
    } else {
      NumElems = WideOpVT.getVectorNumElements();
      unsigned ShiftLeft = NumElems - SubVecNumElems;
      unsigned ShiftRight = NumElems - SubVecNumElems - IdxVal;
      SubVec = DAG.getNode(X86ISD::KSHIFTL, dl, WideOpVT, SubVec,
                           DAG.getTargetConstant(ShiftLeft, dl, MVT::i8));
      if (ShiftRight != 0)
        SubVec = DAG.getNode(X86ISD::KSHIFTR, dl, WideOpVT, SubVec,
                             DAG.getTargetConstant(ShiftRight, dl, MVT::i8));
    }
    return DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, OpVT, SubVec, ZeroIdx);
  }


@@ -1794,7 +1794,7 @@ define <2 x i64> @test_concat_v2i64_v2i64_v1i64(<2 x i64> %x, <1 x i64> %y) #0 {
; CHECK-LABEL: test_concat_v2i64_v2i64_v1i64:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: // kill: def $d1 killed $d1 def $q1
; CHECK-NEXT: zip1 v0.2d, v0.2d, v1.2d
; CHECK-NEXT: mov v0.d[1], v1.d[0]
; CHECK-NEXT: ret
entry:
%vecext = extractelement <2 x i64> %x, i32 0


@@ -14,8 +14,8 @@ define void @func(<4 x float> %a, <16 x i8> %b, <16 x i8> %c, <8 x float> %d, <8
; CHECK-NEXT: vaddps %xmm1, %xmm0, %xmm0
; CHECK-NEXT: vaddps %xmm0, %xmm0, %xmm0
; CHECK-NEXT: vmulps %xmm0, %xmm0, %xmm0
; CHECK-NEXT: vperm2f128 {{.*#+}} ymm0 = zero,zero,ymm0[0,1]
; CHECK-NEXT: vxorps %xmm1, %xmm1, %xmm1
; CHECK-NEXT: vinsertf128 $1, %xmm0, %ymm1, %ymm0
; CHECK-NEXT: vaddps %ymm0, %ymm0, %ymm0
; CHECK-NEXT: vhaddps %ymm4, %ymm0, %ymm0
; CHECK-NEXT: vsubps %ymm0, %ymm0, %ymm0


@@ -87,7 +87,7 @@ define <8 x i32> @test_x86_avx_vinsertf128_si_256_2(<8 x i32> %a0, <4 x i32> %a1
; CHECK-LABEL: test_x86_avx_vinsertf128_si_256_2:
; CHECK: # %bb.0:
; CHECK-NEXT: # kill: def $xmm1 killed $xmm1 def $ymm1
; CHECK-NEXT: vblendps $240, %ymm0, %ymm1, %ymm0 # encoding: [0xc4,0xe3,0x75,0x0c,0xc0,0xf0]
; CHECK-NEXT: vblendps $15, %ymm1, %ymm0, %ymm0 # encoding: [0xc4,0xe3,0x7d,0x0c,0xc1,0x0f]
; CHECK-NEXT: # ymm0 = ymm1[0,1,2,3],ymm0[4,5,6,7]
; CHECK-NEXT: ret{{[l|q]}} # encoding: [0xc3]
%res = call <8 x i32> @llvm.x86.avx.vinsertf128.si.256(<8 x i32> %a0, <4 x i32> %a1, i8 2)


@@ -695,11 +695,9 @@ define void @PR50053(<4 x i64>* nocapture %0, <4 x i64>* nocapture readonly %1)
; ALL-LABEL: PR50053:
; ALL: # %bb.0:
; ALL-NEXT: vmovaps (%rsi), %ymm0
; ALL-NEXT: vmovaps 32(%rsi), %xmm1
; ALL-NEXT: vmovaps 48(%rsi), %xmm2
; ALL-NEXT: vperm2f128 {{.*#+}} ymm1 = ymm0[0,1],ymm1[0,1]
; ALL-NEXT: vinsertf128 $1, 32(%rsi), %ymm0, %ymm1
; ALL-NEXT: vinsertf128 $0, 48(%rsi), %ymm0, %ymm0
; ALL-NEXT: vmovaps %ymm1, (%rdi)
; ALL-NEXT: vblendps {{.*#+}} ymm0 = ymm2[0,1,2,3],ymm0[4,5,6,7]
; ALL-NEXT: vmovaps %ymm0, 32(%rdi)
; ALL-NEXT: vzeroupper
; ALL-NEXT: retq


@@ -14,35 +14,35 @@ define <16 x i64> @pluto(<16 x i64> %arg, <16 x i64> %arg1, <16 x i64> %arg2, <1
; CHECK-NEXT: vmovaps %ymm4, %ymm10
; CHECK-NEXT: vmovaps %ymm3, %ymm9
; CHECK-NEXT: vmovaps %ymm1, %ymm8
; CHECK-NEXT: vmovaps %ymm0, %ymm3
; CHECK-NEXT: vmovaps %ymm0, %ymm4
; CHECK-NEXT: vmovaps 240(%rbp), %ymm1
; CHECK-NEXT: vmovaps 208(%rbp), %ymm4
; CHECK-NEXT: vmovaps 208(%rbp), %ymm3
; CHECK-NEXT: vmovaps 176(%rbp), %ymm0
; CHECK-NEXT: vmovaps 144(%rbp), %ymm0
; CHECK-NEXT: vmovaps 112(%rbp), %ymm11
; CHECK-NEXT: vmovaps 80(%rbp), %ymm11
; CHECK-NEXT: vmovaps 48(%rbp), %ymm11
; CHECK-NEXT: vmovaps 16(%rbp), %ymm11
; CHECK-NEXT: vpblendd {{.*#+}} ymm3 = ymm6[0,1,2,3,4,5],ymm2[6,7]
; CHECK-NEXT: vmovaps %xmm4, %xmm6
; CHECK-NEXT: vpblendd {{.*#+}} ymm4 = ymm6[0,1,2,3,4,5],ymm2[6,7]
; CHECK-NEXT: vmovaps %xmm3, %xmm6
; CHECK-NEXT: # implicit-def: $ymm2
; CHECK-NEXT: vinserti128 $1, %xmm6, %ymm2, %ymm2
; CHECK-NEXT: vpalignr {{.*#+}} ymm0 = ymm3[8,9,10,11,12,13,14,15],ymm0[0,1,2,3,4,5,6,7],ymm3[24,25,26,27,28,29,30,31],ymm0[16,17,18,19,20,21,22,23]
; CHECK-NEXT: vpalignr {{.*#+}} ymm0 = ymm4[8,9,10,11,12,13,14,15],ymm0[0,1,2,3,4,5,6,7],ymm4[24,25,26,27,28,29,30,31],ymm0[16,17,18,19,20,21,22,23]
; CHECK-NEXT: vpermq {{.*#+}} ymm0 = ymm0[2,3,2,0]
; CHECK-NEXT: vpblendd {{.*#+}} ymm0 = ymm0[0,1,2,3],ymm2[4,5],ymm0[6,7]
; CHECK-NEXT: vextracti128 $1, %ymm7, %xmm2
; CHECK-NEXT: vmovq {{.*#+}} xmm6 = xmm2[0],zero
; CHECK-NEXT: # implicit-def: $ymm2
; CHECK-NEXT: vmovaps %xmm6, %xmm2
; CHECK-NEXT: # kill: def $xmm3 killed $xmm3 killed $ymm3
; CHECK-NEXT: vinserti128 $1, %xmm3, %ymm2, %ymm2
; CHECK-NEXT: vmovaps %xmm7, %xmm3
; CHECK-NEXT: vpslldq {{.*#+}} xmm6 = zero,zero,zero,zero,zero,zero,zero,zero,xmm3[0,1,2,3,4,5,6,7]
; CHECK-NEXT: # implicit-def: $ymm3
; CHECK-NEXT: vmovaps %xmm6, %xmm3
; CHECK-NEXT: vpalignr {{.*#+}} ymm4 = ymm4[8,9,10,11,12,13,14,15],ymm5[0,1,2,3,4,5,6,7],ymm4[24,25,26,27,28,29,30,31],ymm5[16,17,18,19,20,21,22,23]
; CHECK-NEXT: vpermq {{.*#+}} ymm4 = ymm4[0,1,0,3]
; CHECK-NEXT: vpblendd {{.*#+}} ymm3 = ymm3[0,1,2,3],ymm4[4,5,6,7]
; CHECK-NEXT: # kill: def $xmm4 killed $xmm4 killed $ymm4
; CHECK-NEXT: vinserti128 $1, %xmm4, %ymm2, %ymm2
; CHECK-NEXT: vmovaps %xmm7, %xmm4
; CHECK-NEXT: vpslldq {{.*#+}} xmm6 = zero,zero,zero,zero,zero,zero,zero,zero,xmm4[0,1,2,3,4,5,6,7]
; CHECK-NEXT: # implicit-def: $ymm4
; CHECK-NEXT: vmovaps %xmm6, %xmm4
; CHECK-NEXT: vpalignr {{.*#+}} ymm3 = ymm3[8,9,10,11,12,13,14,15],ymm5[0,1,2,3,4,5,6,7],ymm3[24,25,26,27,28,29,30,31],ymm5[16,17,18,19,20,21,22,23]
; CHECK-NEXT: vpermq {{.*#+}} ymm3 = ymm3[0,1,0,3]
; CHECK-NEXT: vblendps {{.*#+}} ymm3 = ymm4[0,1,2,3],ymm3[4,5,6,7]
; CHECK-NEXT: vpblendd {{.*#+}} ymm1 = ymm7[0,1],ymm1[2,3],ymm7[4,5,6,7]
; CHECK-NEXT: vpermq {{.*#+}} ymm1 = ymm1[2,1,1,3]
; CHECK-NEXT: vpshufd {{.*#+}} ymm4 = ymm5[0,1,0,1,4,5,4,5]


@@ -563,9 +563,7 @@ define <16 x float> @insert_sub0_0(<16 x float> %base, <4 x float> %sub1, <4 x f
define <16 x float> @insert_sub1_12(<16 x float> %base, <4 x float> %sub1, <4 x float> %sub2, <4 x float> %sub3, <4 x float> %sub4) {
; ALL-LABEL: insert_sub1_12:
; ALL: # %bb.0:
; ALL-NEXT: vinsertf32x4 $1, %xmm2, %zmm0, %zmm1
; ALL-NEXT: vmovapd {{.*#+}} zmm2 = [0,1,2,3,4,5,10,11]
; ALL-NEXT: vpermt2pd %zmm1, %zmm2, %zmm0
; ALL-NEXT: vinsertf32x4 $3, %xmm2, %zmm0, %zmm0
; ALL-NEXT: retq
%sub12 = shufflevector <4 x float> %sub1, <4 x float> %sub2, <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
%sub34 = shufflevector <4 x float> %sub3, <4 x float> %sub4, <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
@@ -591,8 +589,8 @@ define <16 x float> @insert_sub2_4(<16 x float> %base, <4 x float> %sub1, <4 x f
define <16 x float> @insert_sub01_8(<16 x float> %base, <4 x float> %sub1, <4 x float> %sub2, <4 x float> %sub3, <4 x float> %sub4) {
; ALL-LABEL: insert_sub01_8:
; ALL: # %bb.0:
; ALL-NEXT: # kill: def $xmm1 killed $xmm1 def $zmm1
; ALL-NEXT: vinsertf32x4 $1, %xmm2, %zmm1, %zmm1
; ALL-NEXT: # kill: def $xmm1 killed $xmm1 def $ymm1
; ALL-NEXT: vinsertf128 $1, %xmm2, %ymm1, %ymm1
; ALL-NEXT: vinsertf64x4 $1, %ymm1, %zmm0, %zmm0
; ALL-NEXT: retq
%sub12 = shufflevector <4 x float> %sub1, <4 x float> %sub2, <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
@@ -607,8 +605,7 @@ define <16 x float> @insert_sub23_0(<16 x float> %base, <4 x float> %sub1, <4 x
; ALL: # %bb.0:
; ALL-NEXT: # kill: def $xmm3 killed $xmm3 def $ymm3
; ALL-NEXT: vinsertf128 $1, %xmm4, %ymm3, %ymm1
; ALL-NEXT: vinsertf64x4 $1, %ymm1, %zmm0, %zmm1
; ALL-NEXT: vshuff64x2 {{.*#+}} zmm0 = zmm1[4,5,6,7],zmm0[4,5,6,7]
; ALL-NEXT: vinsertf64x4 $0, %ymm1, %zmm0, %zmm0
; ALL-NEXT: retq
%sub12 = shufflevector <4 x float> %sub1, <4 x float> %sub2, <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
%sub34 = shufflevector <4 x float> %sub3, <4 x float> %sub4, <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>