[DAG] Fold shuffle(bop(shuffle(x,y),shuffle(z,w)),bop(shuffle(a,b),shuffle(c,d)))

Fold shuffle(bop(shuffle(x,y),shuffle(z,w)),bop(shuffle(a,b),shuffle(c,d))) -> bop(shuffle(x,y),shuffle(z,w)),bop(shuffle(a,b),shuffle(c,d))

Attempt to fold from a shuffle of a pair of binops to a binop of shuffles, as long as one/both of the binop sources are also shuffles that can be merged with the outer shuffle. This should guarantee that we remove one binop without introducing any additional shuffles.

Technically there's potential for a merged shuffle's lowering to be poorer than the original shuffle, but it could also be better, and I'm not seeing any regressions as long as we keep the 'don't merge splats' rule already present in MergeInnerShuffle.

This expands and generalizes an existing X86 combine and attempts to merge either of each binop's sources (with an on-the-fly commutation of the shuffle mask) - we couldn't do that in the x86 version as it had to stay in a form that DAGCombine's MergeInnerShuffle would still recognise.

Differential Revision: https://reviews.llvm.org/D96345
This commit is contained in:
Simon Pilgrim 2021-02-16 15:24:23 +00:00
parent c320e8196a
commit 5dfba562dd
4 changed files with 136 additions and 96 deletions

View File

@ -20919,11 +20919,13 @@ SDValue DAGCombiner::visitVECTOR_SHUFFLE(SDNode *N) {
// Compute the combined shuffle mask for a shuffle with SV0 as the first
// operand, and SV1 as the second operand.
// i.e. Merge SVN(OtherSVN, N1) -> shuffle(SV0, SV1, Mask).
// i.e. Merge SVN(OtherSVN, N1) -> shuffle(SV0, SV1, Mask) iff Commute = false
// Merge SVN(N1, OtherSVN) -> shuffle(SV0, SV1, Mask') iff Commute = true
auto MergeInnerShuffle =
[NumElts, &VT](ShuffleVectorSDNode *SVN, ShuffleVectorSDNode *OtherSVN,
SDValue N1, const TargetLowering &TLI, SDValue &SV0,
SDValue &SV1, SmallVectorImpl<int> &Mask) -> bool {
[NumElts, &VT](bool Commute, ShuffleVectorSDNode *SVN,
ShuffleVectorSDNode *OtherSVN, SDValue N1,
const TargetLowering &TLI, SDValue &SV0, SDValue &SV1,
SmallVectorImpl<int> &Mask) -> bool {
// Don't try to fold splats; they're likely to simplify somehow, or they
// might be free.
if (OtherSVN->isSplat())
@ -20940,6 +20942,9 @@ SDValue DAGCombiner::visitVECTOR_SHUFFLE(SDNode *N) {
continue;
}
if (Commute)
Idx = (Idx < (int)NumElts) ? (Idx + NumElts) : (Idx - NumElts);
SDValue CurrentVec;
if (Idx < (int)NumElts) {
// This shuffle index refers to the inner shuffle N0. Lookup the inner
@ -21045,7 +21050,7 @@ SDValue DAGCombiner::visitVECTOR_SHUFFLE(SDNode *N) {
SDValue SV0, SV1;
SmallVector<int, 4> Mask;
if (MergeInnerShuffle(SVN, OtherSV, N1, TLI, SV0, SV1, Mask)) {
if (MergeInnerShuffle(false, SVN, OtherSV, N1, TLI, SV0, SV1, Mask)) {
// Check if all indices in Mask are Undef. In case, propagate Undef.
if (llvm::all_of(Mask, [](int M) { return M < 0; }))
return DAG.getUNDEF(VT);
@ -21055,6 +21060,77 @@ SDValue DAGCombiner::visitVECTOR_SHUFFLE(SDNode *N) {
}
}
// Merge shuffles through binops if we are able to merge it with at least one
// other shuffles.
// shuffle(bop(shuffle(x,y),shuffle(z,w)),bop(shuffle(a,b),shuffle(c,d)))
if (Level < AfterLegalizeDAG && TLI.isTypeLegal(VT)) {
unsigned SrcOpcode = N0.getOpcode();
if (SrcOpcode == N1.getOpcode() && TLI.isBinOp(SrcOpcode) &&
N->isOnlyUserOf(N0.getNode()) && N->isOnlyUserOf(N1.getNode())) {
SDValue Op00 = N0.getOperand(0);
SDValue Op10 = N1.getOperand(0);
SDValue Op01 = N0.getOperand(1);
SDValue Op11 = N1.getOperand(1);
// TODO: We might be able to relax the VT check but we don't currently
// have any isBinOp() that has different result/ops VTs so play safe until
// we have test coverage.
if (Op00.getValueType() == VT && Op10.getValueType() == VT &&
Op01.getValueType() == VT && Op11.getValueType() == VT &&
(Op00.getOpcode() == ISD::VECTOR_SHUFFLE ||
Op10.getOpcode() == ISD::VECTOR_SHUFFLE ||
Op01.getOpcode() == ISD::VECTOR_SHUFFLE ||
Op11.getOpcode() == ISD::VECTOR_SHUFFLE)) {
auto CanMergeInnerShuffle = [&](SDValue &SV0, SDValue &SV1,
SmallVectorImpl<int> &Mask, bool LeftOp,
bool Commute) {
SDValue InnerN = Commute ? N1 : N0;
SDValue Op0 = LeftOp ? Op00 : Op01;
SDValue Op1 = LeftOp ? Op10 : Op11;
if (Commute)
std::swap(Op0, Op1);
return Op0.getOpcode() == ISD::VECTOR_SHUFFLE &&
InnerN->isOnlyUserOf(Op0.getNode()) &&
MergeInnerShuffle(Commute, SVN, cast<ShuffleVectorSDNode>(Op0),
Op1, TLI, SV0, SV1, Mask) &&
llvm::none_of(Mask, [](int M) { return M < 0; });
};
// Ensure we don't increase the number of shuffles - we must merge a
// shuffle from at least one of the LHS and RHS ops.
bool MergedLeft = false;
SDValue LeftSV0, LeftSV1;
SmallVector<int, 4> LeftMask;
if (CanMergeInnerShuffle(LeftSV0, LeftSV1, LeftMask, true, false) ||
CanMergeInnerShuffle(LeftSV0, LeftSV1, LeftMask, true, true)) {
MergedLeft = true;
} else {
LeftMask.assign(SVN->getMask().begin(), SVN->getMask().end());
LeftSV0 = Op00, LeftSV1 = Op10;
}
bool MergedRight = false;
SDValue RightSV0, RightSV1;
SmallVector<int, 4> RightMask;
if (CanMergeInnerShuffle(RightSV0, RightSV1, RightMask, false, false) ||
CanMergeInnerShuffle(RightSV0, RightSV1, RightMask, false, true)) {
MergedRight = true;
} else {
RightMask.assign(SVN->getMask().begin(), SVN->getMask().end());
RightSV0 = Op01, RightSV1 = Op11;
}
if (MergedLeft || MergedRight) {
SDLoc DL(N);
SDValue LHS =
DAG.getVectorShuffle(VT, DL, LeftSV0, LeftSV1, LeftMask);
SDValue RHS =
DAG.getVectorShuffle(VT, DL, RightSV0, RightSV1, RightMask);
return DAG.getNode(SrcOpcode, DL, VT, LHS, RHS);
}
}
}
}
if (SDValue V = foldShuffleOfConcatUndefs(SVN, DAG))
return V;

View File

@ -38029,34 +38029,6 @@ static SDValue combineShuffle(SDNode *N, SelectionDAG &DAG,
if (SDValue HAddSub = foldShuffleOfHorizOp(N, DAG))
return HAddSub;
// Merge shuffles through binops if its likely we'll be able to merge it
// with other shuffles (as long as they aren't splats).
// shuffle(bop(shuffle(x,y),shuffle(z,w)),bop(shuffle(a,b),shuffle(c,d)))
// TODO: We might be able to move this to DAGCombiner::visitVECTOR_SHUFFLE.
if (auto *SVN = dyn_cast<ShuffleVectorSDNode>(N)) {
unsigned SrcOpcode = N->getOperand(0).getOpcode();
if (SrcOpcode == N->getOperand(1).getOpcode() && TLI.isBinOp(SrcOpcode) &&
N->isOnlyUserOf(N->getOperand(0).getNode()) &&
N->isOnlyUserOf(N->getOperand(1).getNode())) {
SDValue Op00 = N->getOperand(0).getOperand(0);
SDValue Op10 = N->getOperand(1).getOperand(0);
SDValue Op01 = N->getOperand(0).getOperand(1);
SDValue Op11 = N->getOperand(1).getOperand(1);
auto *SVN00 = dyn_cast<ShuffleVectorSDNode>(Op00);
auto *SVN10 = dyn_cast<ShuffleVectorSDNode>(Op10);
auto *SVN01 = dyn_cast<ShuffleVectorSDNode>(Op01);
auto *SVN11 = dyn_cast<ShuffleVectorSDNode>(Op11);
if (((SVN00 && !SVN00->isSplat()) || (SVN10 && !SVN10->isSplat())) &&
((SVN01 && !SVN01->isSplat()) || (SVN11 && !SVN11->isSplat()))) {
SDLoc DL(N);
ArrayRef<int> Mask = SVN->getMask();
SDValue LHS = DAG.getVectorShuffle(VT, DL, Op00, Op10, Mask);
SDValue RHS = DAG.getVectorShuffle(VT, DL, Op01, Op11, Mask);
return DAG.getNode(SrcOpcode, DL, VT, LHS, RHS);
}
}
}
}
// Attempt to combine into a vector load/broadcast.

View File

@ -9,47 +9,41 @@
define i4 @v4i64(<4 x i64> %a, <4 x i64> %b, <4 x i64> %c, <4 x i64> %d) {
; SSE2-SSSE3-LABEL: v4i64:
; SSE2-SSSE3: # %bb.0:
; SSE2-SSSE3-NEXT: movdqa {{.*#+}} xmm8 = [2147483648,2147483648]
; SSE2-SSSE3-NEXT: pxor %xmm8, %xmm3
; SSE2-SSSE3-NEXT: pxor %xmm8, %xmm1
; SSE2-SSSE3-NEXT: movdqa %xmm1, %xmm9
; SSE2-SSSE3-NEXT: pcmpgtd %xmm3, %xmm9
; SSE2-SSSE3-NEXT: movdqa {{.*#+}} xmm9 = [2147483648,2147483648]
; SSE2-SSSE3-NEXT: pxor %xmm9, %xmm3
; SSE2-SSSE3-NEXT: pxor %xmm9, %xmm1
; SSE2-SSSE3-NEXT: movdqa %xmm1, %xmm10
; SSE2-SSSE3-NEXT: pcmpgtd %xmm3, %xmm10
; SSE2-SSSE3-NEXT: pxor %xmm9, %xmm2
; SSE2-SSSE3-NEXT: pxor %xmm9, %xmm0
; SSE2-SSSE3-NEXT: movdqa %xmm0, %xmm8
; SSE2-SSSE3-NEXT: pcmpgtd %xmm2, %xmm8
; SSE2-SSSE3-NEXT: movdqa %xmm8, %xmm11
; SSE2-SSSE3-NEXT: shufps {{.*#+}} xmm11 = xmm11[0,2],xmm10[0,2]
; SSE2-SSSE3-NEXT: pcmpeqd %xmm3, %xmm1
; SSE2-SSSE3-NEXT: pshufd {{.*#+}} xmm1 = xmm1[1,1,3,3]
; SSE2-SSSE3-NEXT: pand %xmm9, %xmm1
; SSE2-SSSE3-NEXT: pshufd {{.*#+}} xmm3 = xmm9[1,1,3,3]
; SSE2-SSSE3-NEXT: por %xmm1, %xmm3
; SSE2-SSSE3-NEXT: pxor %xmm8, %xmm2
; SSE2-SSSE3-NEXT: pxor %xmm8, %xmm0
; SSE2-SSSE3-NEXT: movdqa %xmm0, %xmm1
; SSE2-SSSE3-NEXT: pcmpgtd %xmm2, %xmm1
; SSE2-SSSE3-NEXT: pcmpeqd %xmm2, %xmm0
; SSE2-SSSE3-NEXT: pshufd {{.*#+}} xmm2 = xmm0[1,1,3,3]
; SSE2-SSSE3-NEXT: pand %xmm1, %xmm2
; SSE2-SSSE3-NEXT: pshufd {{.*#+}} xmm0 = xmm1[1,1,3,3]
; SSE2-SSSE3-NEXT: por %xmm2, %xmm0
; SSE2-SSSE3-NEXT: shufps {{.*#+}} xmm0 = xmm0[0,2],xmm3[0,2]
; SSE2-SSSE3-NEXT: pxor %xmm8, %xmm7
; SSE2-SSSE3-NEXT: pxor %xmm8, %xmm5
; SSE2-SSSE3-NEXT: movdqa %xmm5, %xmm1
; SSE2-SSSE3-NEXT: pcmpgtd %xmm7, %xmm1
; SSE2-SSSE3-NEXT: shufps {{.*#+}} xmm0 = xmm0[1,3],xmm1[1,3]
; SSE2-SSSE3-NEXT: andps %xmm11, %xmm0
; SSE2-SSSE3-NEXT: shufps {{.*#+}} xmm8 = xmm8[1,3],xmm10[1,3]
; SSE2-SSSE3-NEXT: orps %xmm0, %xmm8
; SSE2-SSSE3-NEXT: pxor %xmm9, %xmm7
; SSE2-SSSE3-NEXT: pxor %xmm9, %xmm5
; SSE2-SSSE3-NEXT: movdqa %xmm5, %xmm0
; SSE2-SSSE3-NEXT: pcmpgtd %xmm7, %xmm0
; SSE2-SSSE3-NEXT: pxor %xmm9, %xmm6
; SSE2-SSSE3-NEXT: pxor %xmm9, %xmm4
; SSE2-SSSE3-NEXT: movdqa %xmm4, %xmm1
; SSE2-SSSE3-NEXT: pcmpgtd %xmm6, %xmm1
; SSE2-SSSE3-NEXT: movdqa %xmm1, %xmm2
; SSE2-SSSE3-NEXT: shufps {{.*#+}} xmm2 = xmm2[0,2],xmm0[0,2]
; SSE2-SSSE3-NEXT: pcmpeqd %xmm7, %xmm5
; SSE2-SSSE3-NEXT: pshufd {{.*#+}} xmm2 = xmm5[1,1,3,3]
; SSE2-SSSE3-NEXT: pand %xmm1, %xmm2
; SSE2-SSSE3-NEXT: pshufd {{.*#+}} xmm1 = xmm1[1,1,3,3]
; SSE2-SSSE3-NEXT: por %xmm2, %xmm1
; SSE2-SSSE3-NEXT: pxor %xmm8, %xmm6
; SSE2-SSSE3-NEXT: pxor %xmm8, %xmm4
; SSE2-SSSE3-NEXT: movdqa %xmm4, %xmm2
; SSE2-SSSE3-NEXT: pcmpgtd %xmm6, %xmm2
; SSE2-SSSE3-NEXT: pcmpeqd %xmm6, %xmm4
; SSE2-SSSE3-NEXT: pshufd {{.*#+}} xmm3 = xmm4[1,1,3,3]
; SSE2-SSSE3-NEXT: pand %xmm2, %xmm3
; SSE2-SSSE3-NEXT: pshufd {{.*#+}} xmm2 = xmm2[1,1,3,3]
; SSE2-SSSE3-NEXT: por %xmm3, %xmm2
; SSE2-SSSE3-NEXT: shufps {{.*#+}} xmm2 = xmm2[0,2],xmm1[0,2]
; SSE2-SSSE3-NEXT: andps %xmm0, %xmm2
; SSE2-SSSE3-NEXT: movmskps %xmm2, %eax
; SSE2-SSSE3-NEXT: shufps {{.*#+}} xmm4 = xmm4[1,3],xmm5[1,3]
; SSE2-SSSE3-NEXT: andps %xmm2, %xmm4
; SSE2-SSSE3-NEXT: shufps {{.*#+}} xmm1 = xmm1[1,3],xmm0[1,3]
; SSE2-SSSE3-NEXT: orps %xmm4, %xmm1
; SSE2-SSSE3-NEXT: andps %xmm8, %xmm1
; SSE2-SSSE3-NEXT: movmskps %xmm1, %eax
; SSE2-SSSE3-NEXT: # kill: def $al killed $al killed $eax
; SSE2-SSSE3-NEXT: retq
;

View File

@ -8,36 +8,34 @@ define <4 x i64> @PR45808(<4 x i64> %0, <4 x i64> %1) {
; SSE2-LABEL: PR45808:
; SSE2: # %bb.0:
; SSE2-NEXT: movdqa {{.*#+}} xmm4 = [2147483648,2147483648]
; SSE2-NEXT: movdqa %xmm3, %xmm5
; SSE2-NEXT: pxor %xmm4, %xmm5
; SSE2-NEXT: movdqa %xmm3, %xmm9
; SSE2-NEXT: pxor %xmm4, %xmm9
; SSE2-NEXT: movdqa %xmm1, %xmm6
; SSE2-NEXT: pxor %xmm4, %xmm6
; SSE2-NEXT: movdqa %xmm6, %xmm7
; SSE2-NEXT: pcmpgtd %xmm5, %xmm7
; SSE2-NEXT: pcmpeqd %xmm5, %xmm6
; SSE2-NEXT: pshufd {{.*#+}} xmm5 = xmm6[1,1,3,3]
; SSE2-NEXT: pand %xmm7, %xmm5
; SSE2-NEXT: pshufd {{.*#+}} xmm6 = xmm7[1,1,3,3]
; SSE2-NEXT: por %xmm5, %xmm6
; SSE2-NEXT: movdqa %xmm2, %xmm5
; SSE2-NEXT: pxor %xmm4, %xmm5
; SSE2-NEXT: movdqa %xmm6, %xmm8
; SSE2-NEXT: pcmpgtd %xmm9, %xmm8
; SSE2-NEXT: movdqa %xmm2, %xmm7
; SSE2-NEXT: pxor %xmm4, %xmm7
; SSE2-NEXT: pxor %xmm0, %xmm4
; SSE2-NEXT: movdqa %xmm4, %xmm7
; SSE2-NEXT: pcmpgtd %xmm5, %xmm7
; SSE2-NEXT: pcmpeqd %xmm5, %xmm4
; SSE2-NEXT: movdqa %xmm4, %xmm5
; SSE2-NEXT: pcmpgtd %xmm7, %xmm5
; SSE2-NEXT: movdqa %xmm5, %xmm10
; SSE2-NEXT: shufps {{.*#+}} xmm10 = xmm10[0,2],xmm8[0,2]
; SSE2-NEXT: pcmpeqd %xmm9, %xmm6
; SSE2-NEXT: pcmpeqd %xmm7, %xmm4
; SSE2-NEXT: shufps {{.*#+}} xmm4 = xmm4[1,3],xmm6[1,3]
; SSE2-NEXT: andps %xmm10, %xmm4
; SSE2-NEXT: shufps {{.*#+}} xmm5 = xmm5[1,3],xmm8[1,3]
; SSE2-NEXT: orps %xmm4, %xmm5
; SSE2-NEXT: pshufd {{.*#+}} xmm4 = xmm5[2,1,3,3]
; SSE2-NEXT: pxor {{.*}}(%rip), %xmm5
; SSE2-NEXT: psllq $63, %xmm4
; SSE2-NEXT: psrad $31, %xmm4
; SSE2-NEXT: pshufd {{.*#+}} xmm4 = xmm4[1,1,3,3]
; SSE2-NEXT: pand %xmm7, %xmm4
; SSE2-NEXT: pshufd {{.*#+}} xmm5 = xmm7[1,1,3,3]
; SSE2-NEXT: por %xmm4, %xmm5
; SSE2-NEXT: pshufd {{.*#+}} xmm4 = xmm5[0,2,2,3]
; SSE2-NEXT: pxor {{.*}}(%rip), %xmm4
; SSE2-NEXT: psllq $63, %xmm6
; SSE2-NEXT: psrad $31, %xmm6
; SSE2-NEXT: pshufd {{.*#+}} xmm5 = xmm6[1,1,3,3]
; SSE2-NEXT: pand %xmm5, %xmm1
; SSE2-NEXT: pandn %xmm3, %xmm5
; SSE2-NEXT: por %xmm5, %xmm1
; SSE2-NEXT: pshufd {{.*#+}} xmm3 = xmm4[0,1,1,3]
; SSE2-NEXT: pand %xmm4, %xmm1
; SSE2-NEXT: pandn %xmm3, %xmm4
; SSE2-NEXT: por %xmm4, %xmm1
; SSE2-NEXT: pshufd {{.*#+}} xmm3 = xmm5[0,1,1,3]
; SSE2-NEXT: psllq $63, %xmm3
; SSE2-NEXT: psrad $31, %xmm3
; SSE2-NEXT: pshufd {{.*#+}} xmm3 = xmm3[1,1,3,3]