[X86] Teach combineLoopSADPattern to handle cases where there is no loop and the add has two absolute difference inputs

Previously we asumed a vector reduction add is part of a loop and one of the input is a phi. But the code in SelectionDAGBuilder that sets vector reduction flag handles more cases than that. It just requires that the use chain ends in a horizontal reduction. And there are no other uses. This means it can handle unrolled reduction loops.

If the initial value of the reduction was 0, an unrolled loop would begin with a vector reduction add that has two sad inputs. Previously we would only transform one side of the add, but for this case we need to transform both sides.

I've created a lambda to reuse some of the code for both sides. And fixed the variables names to remove reference to "phi".

Differential Revision: https://reviews.llvm.org/D50817

llvm-svn: 340478
This commit is contained in:
Craig Topper 2018-08-22 23:19:01 +00:00
parent 903ef6a03f
commit cf9df99d79
2 changed files with 50 additions and 109 deletions

View File

@ -39132,45 +39132,53 @@ static SDValue combineLoopSADPattern(SDNode *N, SelectionDAG &DAG,
// We know N is a reduction add, which means one of its operands is a phi.
// To match SAD, we need the other operand to be a vector select.
SDValue SelectOp, Phi;
if (Op0.getOpcode() == ISD::VSELECT) {
SelectOp = Op0;
Phi = Op1;
} else if (Op1.getOpcode() == ISD::VSELECT) {
SelectOp = Op1;
Phi = Op0;
} else
if (Op0.getOpcode() != ISD::VSELECT)
std::swap(Op0, Op1);
if (Op0.getOpcode() != ISD::VSELECT)
return SDValue();
auto BuildPSADBW = [&](SDValue Op0, SDValue Op1) {
// SAD pattern detected. Now build a SAD instruction and an addition for
// reduction. Note that the number of elements of the result of SAD is less
// than the number of elements of its input. Therefore, we could only update
// part of elements in the reduction vector.
SDValue Sad = createPSADBW(DAG, Op0, Op1, DL, Subtarget);
// The output of PSADBW is a vector of i64.
// We need to turn the vector of i64 into a vector of i32.
// If the reduction vector is at least as wide as the psadbw result, just
// bitcast. If it's narrower, truncate - the high i32 of each i64 is zero
// anyway.
MVT ResVT = MVT::getVectorVT(MVT::i32, Sad.getValueSizeInBits() / 32);
if (VT.getSizeInBits() >= ResVT.getSizeInBits())
Sad = DAG.getNode(ISD::BITCAST, DL, ResVT, Sad);
else
Sad = DAG.getNode(ISD::TRUNCATE, DL, VT, Sad);
if (VT.getSizeInBits() > ResVT.getSizeInBits()) {
// Fill the upper elements with zero to match the add width.
SDValue Zero = DAG.getConstant(0, DL, VT);
Sad = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, VT, Zero, Sad,
DAG.getIntPtrConstant(0, DL));
}
return Sad;
};
// Check whether we have an abs-diff pattern feeding into the select.
if(!detectZextAbsDiff(SelectOp, Op0, Op1))
SDValue SadOp0, SadOp1;
if (!detectZextAbsDiff(Op0, SadOp0, SadOp1))
return SDValue();
// SAD pattern detected. Now build a SAD instruction and an addition for
// reduction. Note that the number of elements of the result of SAD is less
// than the number of elements of its input. Therefore, we could only update
// part of elements in the reduction vector.
SDValue Sad = createPSADBW(DAG, Op0, Op1, DL, Subtarget);
Op0 = BuildPSADBW(SadOp0, SadOp1);
// The output of PSADBW is a vector of i64.
// We need to turn the vector of i64 into a vector of i32.
// If the reduction vector is at least as wide as the psadbw result, just
// bitcast. If it's narrower, truncate - the high i32 of each i64 is zero
// anyway.
MVT ResVT = MVT::getVectorVT(MVT::i32, Sad.getValueSizeInBits() / 32);
if (VT.getSizeInBits() >= ResVT.getSizeInBits())
Sad = DAG.getNode(ISD::BITCAST, DL, ResVT, Sad);
else
Sad = DAG.getNode(ISD::TRUNCATE, DL, VT, Sad);
if (VT.getSizeInBits() > ResVT.getSizeInBits()) {
// Fill the upper elements with zero to match the add width.
SDValue Zero = DAG.getConstant(0, DL, VT);
Sad = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, VT, Zero, Sad,
DAG.getIntPtrConstant(0, DL));
// It's possible we have a sad on the other side too.
if (Op1.getOpcode() == ISD::VSELECT &&
detectZextAbsDiff(Op1, SadOp0, SadOp1)) {
Op1 = BuildPSADBW(SadOp0, SadOp1);
}
return DAG.getNode(ISD::ADD, DL, VT, Sad, Phi);
return DAG.getNode(ISD::ADD, DL, VT, Op0, Op1);
}
/// Convert vector increment or decrement to sub/add with an all-ones constant:

View File

@ -1511,53 +1511,12 @@ define i32 @sad_double_reduction(<16 x i8>* %arg, <16 x i8>* %arg1, <16 x i8>* %
; SSE2-LABEL: sad_double_reduction:
; SSE2: # %bb.0: # %bb
; SSE2-NEXT: movdqu (%rdi), %xmm0
; SSE2-NEXT: movdqu (%rsi), %xmm4
; SSE2-NEXT: pxor %xmm5, %xmm5
; SSE2-NEXT: movdqa %xmm0, %xmm3
; SSE2-NEXT: punpckhbw {{.*#+}} xmm3 = xmm3[8],xmm5[8],xmm3[9],xmm5[9],xmm3[10],xmm5[10],xmm3[11],xmm5[11],xmm3[12],xmm5[12],xmm3[13],xmm5[13],xmm3[14],xmm5[14],xmm3[15],xmm5[15]
; SSE2-NEXT: movdqa %xmm3, %xmm1
; SSE2-NEXT: punpcklwd {{.*#+}} xmm1 = xmm1[0],xmm5[0],xmm1[1],xmm5[1],xmm1[2],xmm5[2],xmm1[3],xmm5[3]
; SSE2-NEXT: punpckhwd {{.*#+}} xmm3 = xmm3[4],xmm5[4],xmm3[5],xmm5[5],xmm3[6],xmm5[6],xmm3[7],xmm5[7]
; SSE2-NEXT: punpcklbw {{.*#+}} xmm0 = xmm0[0],xmm5[0],xmm0[1],xmm5[1],xmm0[2],xmm5[2],xmm0[3],xmm5[3],xmm0[4],xmm5[4],xmm0[5],xmm5[5],xmm0[6],xmm5[6],xmm0[7],xmm5[7]
; SSE2-NEXT: movdqa %xmm0, %xmm2
; SSE2-NEXT: punpckhwd {{.*#+}} xmm2 = xmm2[4],xmm5[4],xmm2[5],xmm5[5],xmm2[6],xmm5[6],xmm2[7],xmm5[7]
; SSE2-NEXT: punpcklwd {{.*#+}} xmm0 = xmm0[0],xmm5[0],xmm0[1],xmm5[1],xmm0[2],xmm5[2],xmm0[3],xmm5[3]
; SSE2-NEXT: movdqa %xmm4, %xmm6
; SSE2-NEXT: punpckhbw {{.*#+}} xmm6 = xmm6[8],xmm5[8],xmm6[9],xmm5[9],xmm6[10],xmm5[10],xmm6[11],xmm5[11],xmm6[12],xmm5[12],xmm6[13],xmm5[13],xmm6[14],xmm5[14],xmm6[15],xmm5[15]
; SSE2-NEXT: movdqa %xmm6, %xmm7
; SSE2-NEXT: punpcklwd {{.*#+}} xmm7 = xmm7[0],xmm5[0],xmm7[1],xmm5[1],xmm7[2],xmm5[2],xmm7[3],xmm5[3]
; SSE2-NEXT: psubd %xmm7, %xmm1
; SSE2-NEXT: punpckhwd {{.*#+}} xmm6 = xmm6[4],xmm5[4],xmm6[5],xmm5[5],xmm6[6],xmm5[6],xmm6[7],xmm5[7]
; SSE2-NEXT: psubd %xmm6, %xmm3
; SSE2-NEXT: punpcklbw {{.*#+}} xmm4 = xmm4[0],xmm5[0],xmm4[1],xmm5[1],xmm4[2],xmm5[2],xmm4[3],xmm5[3],xmm4[4],xmm5[4],xmm4[5],xmm5[5],xmm4[6],xmm5[6],xmm4[7],xmm5[7]
; SSE2-NEXT: movdqa %xmm4, %xmm6
; SSE2-NEXT: punpckhwd {{.*#+}} xmm6 = xmm6[4],xmm5[4],xmm6[5],xmm5[5],xmm6[6],xmm5[6],xmm6[7],xmm5[7]
; SSE2-NEXT: psubd %xmm6, %xmm2
; SSE2-NEXT: punpcklwd {{.*#+}} xmm4 = xmm4[0],xmm5[0],xmm4[1],xmm5[1],xmm4[2],xmm5[2],xmm4[3],xmm5[3]
; SSE2-NEXT: psubd %xmm4, %xmm0
; SSE2-NEXT: movdqa %xmm1, %xmm4
; SSE2-NEXT: psrad $31, %xmm4
; SSE2-NEXT: paddd %xmm4, %xmm1
; SSE2-NEXT: pxor %xmm4, %xmm1
; SSE2-NEXT: movdqa %xmm3, %xmm4
; SSE2-NEXT: psrad $31, %xmm4
; SSE2-NEXT: paddd %xmm4, %xmm3
; SSE2-NEXT: pxor %xmm4, %xmm3
; SSE2-NEXT: movdqa %xmm2, %xmm4
; SSE2-NEXT: psrad $31, %xmm4
; SSE2-NEXT: paddd %xmm4, %xmm2
; SSE2-NEXT: pxor %xmm4, %xmm2
; SSE2-NEXT: paddd %xmm3, %xmm2
; SSE2-NEXT: movdqa %xmm0, %xmm3
; SSE2-NEXT: psrad $31, %xmm3
; SSE2-NEXT: paddd %xmm3, %xmm0
; SSE2-NEXT: pxor %xmm3, %xmm0
; SSE2-NEXT: paddd %xmm1, %xmm0
; SSE2-NEXT: paddd %xmm2, %xmm0
; SSE2-NEXT: movdqu (%rdx), %xmm1
; SSE2-NEXT: movdqu (%rsi), %xmm1
; SSE2-NEXT: psadbw %xmm0, %xmm1
; SSE2-NEXT: movdqu (%rdx), %xmm0
; SSE2-NEXT: movdqu (%rcx), %xmm2
; SSE2-NEXT: psadbw %xmm1, %xmm2
; SSE2-NEXT: paddd %xmm0, %xmm2
; SSE2-NEXT: psadbw %xmm0, %xmm2
; SSE2-NEXT: paddd %xmm1, %xmm2
; SSE2-NEXT: pshufd {{.*#+}} xmm0 = xmm2[2,3,0,1]
; SSE2-NEXT: paddd %xmm2, %xmm0
; SSE2-NEXT: pshufd {{.*#+}} xmm1 = xmm0[1,1,2,3]
@ -1567,26 +1526,9 @@ define i32 @sad_double_reduction(<16 x i8>* %arg, <16 x i8>* %arg1, <16 x i8>* %
;
; AVX1-LABEL: sad_double_reduction:
; AVX1: # %bb.0: # %bb
; AVX1-NEXT: vpmovzxbd {{.*#+}} xmm0 = mem[0],zero,zero,zero,mem[1],zero,zero,zero,mem[2],zero,zero,zero,mem[3],zero,zero,zero
; AVX1-NEXT: vpmovzxbd {{.*#+}} xmm1 = mem[0],zero,zero,zero,mem[1],zero,zero,zero,mem[2],zero,zero,zero,mem[3],zero,zero,zero
; AVX1-NEXT: vpmovzxbd {{.*#+}} xmm2 = mem[0],zero,zero,zero,mem[1],zero,zero,zero,mem[2],zero,zero,zero,mem[3],zero,zero,zero
; AVX1-NEXT: vpmovzxbd {{.*#+}} xmm3 = mem[0],zero,zero,zero,mem[1],zero,zero,zero,mem[2],zero,zero,zero,mem[3],zero,zero,zero
; AVX1-NEXT: vpmovzxbd {{.*#+}} xmm4 = mem[0],zero,zero,zero,mem[1],zero,zero,zero,mem[2],zero,zero,zero,mem[3],zero,zero,zero
; AVX1-NEXT: vpsubd %xmm4, %xmm0, %xmm0
; AVX1-NEXT: vpmovzxbd {{.*#+}} xmm4 = mem[0],zero,zero,zero,mem[1],zero,zero,zero,mem[2],zero,zero,zero,mem[3],zero,zero,zero
; AVX1-NEXT: vpsubd %xmm4, %xmm1, %xmm1
; AVX1-NEXT: vpmovzxbd {{.*#+}} xmm4 = mem[0],zero,zero,zero,mem[1],zero,zero,zero,mem[2],zero,zero,zero,mem[3],zero,zero,zero
; AVX1-NEXT: vpsubd %xmm4, %xmm2, %xmm2
; AVX1-NEXT: vpmovzxbd {{.*#+}} xmm4 = mem[0],zero,zero,zero,mem[1],zero,zero,zero,mem[2],zero,zero,zero,mem[3],zero,zero,zero
; AVX1-NEXT: vpsubd %xmm4, %xmm3, %xmm3
; AVX1-NEXT: vpabsd %xmm0, %xmm0
; AVX1-NEXT: vpabsd %xmm1, %xmm1
; AVX1-NEXT: vpabsd %xmm2, %xmm2
; AVX1-NEXT: vpaddd %xmm1, %xmm2, %xmm1
; AVX1-NEXT: vpabsd %xmm3, %xmm2
; AVX1-NEXT: vpaddd %xmm0, %xmm2, %xmm0
; AVX1-NEXT: vpaddd %xmm1, %xmm0, %xmm0
; AVX1-NEXT: vmovdqu (%rdi), %xmm0
; AVX1-NEXT: vmovdqu (%rdx), %xmm1
; AVX1-NEXT: vpsadbw (%rsi), %xmm0, %xmm0
; AVX1-NEXT: vpsadbw (%rcx), %xmm1, %xmm1
; AVX1-NEXT: vpaddd %xmm0, %xmm1, %xmm0
; AVX1-NEXT: vpshufd {{.*#+}} xmm1 = xmm0[2,3,0,1]
@ -1597,16 +1539,9 @@ define i32 @sad_double_reduction(<16 x i8>* %arg, <16 x i8>* %arg1, <16 x i8>* %
;
; AVX2-LABEL: sad_double_reduction:
; AVX2: # %bb.0: # %bb
; AVX2-NEXT: vpmovzxbd {{.*#+}} ymm0 = mem[0],zero,zero,zero,mem[1],zero,zero,zero,mem[2],zero,zero,zero,mem[3],zero,zero,zero,mem[4],zero,zero,zero,mem[5],zero,zero,zero,mem[6],zero,zero,zero,mem[7],zero,zero,zero
; AVX2-NEXT: vpmovzxbd {{.*#+}} ymm1 = mem[0],zero,zero,zero,mem[1],zero,zero,zero,mem[2],zero,zero,zero,mem[3],zero,zero,zero,mem[4],zero,zero,zero,mem[5],zero,zero,zero,mem[6],zero,zero,zero,mem[7],zero,zero,zero
; AVX2-NEXT: vpmovzxbd {{.*#+}} ymm2 = mem[0],zero,zero,zero,mem[1],zero,zero,zero,mem[2],zero,zero,zero,mem[3],zero,zero,zero,mem[4],zero,zero,zero,mem[5],zero,zero,zero,mem[6],zero,zero,zero,mem[7],zero,zero,zero
; AVX2-NEXT: vpsubd %ymm2, %ymm0, %ymm0
; AVX2-NEXT: vpmovzxbd {{.*#+}} ymm2 = mem[0],zero,zero,zero,mem[1],zero,zero,zero,mem[2],zero,zero,zero,mem[3],zero,zero,zero,mem[4],zero,zero,zero,mem[5],zero,zero,zero,mem[6],zero,zero,zero,mem[7],zero,zero,zero
; AVX2-NEXT: vpsubd %ymm2, %ymm1, %ymm1
; AVX2-NEXT: vpabsd %ymm0, %ymm0
; AVX2-NEXT: vpabsd %ymm1, %ymm1
; AVX2-NEXT: vpaddd %ymm0, %ymm1, %ymm0
; AVX2-NEXT: vmovdqu (%rdi), %xmm0
; AVX2-NEXT: vmovdqu (%rdx), %xmm1
; AVX2-NEXT: vpsadbw (%rsi), %xmm0, %xmm0
; AVX2-NEXT: vpsadbw (%rcx), %xmm1, %xmm1
; AVX2-NEXT: vpaddd %ymm0, %ymm1, %ymm0
; AVX2-NEXT: vextracti128 $1, %ymm0, %xmm1
@ -1620,11 +1555,9 @@ define i32 @sad_double_reduction(<16 x i8>* %arg, <16 x i8>* %arg1, <16 x i8>* %
;
; AVX512-LABEL: sad_double_reduction:
; AVX512: # %bb.0: # %bb
; AVX512-NEXT: vpmovzxbd {{.*#+}} zmm0 = mem[0],zero,zero,zero,mem[1],zero,zero,zero,mem[2],zero,zero,zero,mem[3],zero,zero,zero,mem[4],zero,zero,zero,mem[5],zero,zero,zero,mem[6],zero,zero,zero,mem[7],zero,zero,zero,mem[8],zero,zero,zero,mem[9],zero,zero,zero,mem[10],zero,zero,zero,mem[11],zero,zero,zero,mem[12],zero,zero,zero,mem[13],zero,zero,zero,mem[14],zero,zero,zero,mem[15],zero,zero,zero
; AVX512-NEXT: vpmovzxbd {{.*#+}} zmm1 = mem[0],zero,zero,zero,mem[1],zero,zero,zero,mem[2],zero,zero,zero,mem[3],zero,zero,zero,mem[4],zero,zero,zero,mem[5],zero,zero,zero,mem[6],zero,zero,zero,mem[7],zero,zero,zero,mem[8],zero,zero,zero,mem[9],zero,zero,zero,mem[10],zero,zero,zero,mem[11],zero,zero,zero,mem[12],zero,zero,zero,mem[13],zero,zero,zero,mem[14],zero,zero,zero,mem[15],zero,zero,zero
; AVX512-NEXT: vpsubd %zmm1, %zmm0, %zmm0
; AVX512-NEXT: vpabsd %zmm0, %zmm0
; AVX512-NEXT: vmovdqu (%rdi), %xmm0
; AVX512-NEXT: vmovdqu (%rdx), %xmm1
; AVX512-NEXT: vpsadbw (%rsi), %xmm0, %xmm0
; AVX512-NEXT: vpsadbw (%rcx), %xmm1, %xmm1
; AVX512-NEXT: vpaddd %zmm0, %zmm1, %zmm0
; AVX512-NEXT: vextracti64x4 $1, %zmm0, %ymm1