[DAGCombine] matchBinOpReduction - add partial reduction matching

This patch adds support for recognizing cases where a larger vector type is being used to reduce just the elements in the lower subvector:

e.g. <8 x i32> reduction pattern in a <16 x i32> vector:

<4,5,6,7,u,u,u,u,u,u,u,u,u,u,u,u>
<2,3,u,u,u,u,u,u,u,u,u,u,u,u,u,u>
<1,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u>

In such cases matchBinOpReduction returns the extracted lower subvector, assuming isExtractSubvectorCheap accepts the extraction.
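
To make the shape concrete, here is a small standalone C++ model of the pyramid being matched (illustrative only; the real matcher walks SelectionDAG shuffle/binop nodes rather than arrays):

```cpp
// Standalone model of the reduction pyramid (illustrative only; the real
// matcher inspects SelectionDAG shuffle masks, not array indices).
#include <cassert>
#include <cstdio>
#include <vector>

int main() {
  // A <16 x i32> value of which only the low 8 lanes are being reduced,
  // i.e. the three masks quoted above: <4,5,6,7,u,...>, <2,3,u,...>, <1,u,...>.
  std::vector<int> V(16);
  for (int I = 0; I != 16; ++I)
    V[I] = I + 1;

  const unsigned Stages = 3; // log2(8) stages reduce the low 8 lanes.
  for (unsigned S = Stages; S-- != 0;) {
    unsigned MaskEnd = 1u << S; // each stage pairs lane j with lane j + MaskEnd
    for (unsigned J = 0; J != MaskEnd; ++J)
      V[J] += V[J + MaskEnd];
  }

  // Lane 0 now holds the sum of the low 8 lanes, so extracting the low
  // <8 x i32> subvector and reading element 0 gives the reduction result.
  assert(V[0] == 36); // 1 + 2 + ... + 8
  std::printf("low-half reduction = %d\n", V[0]);
}
```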

So far I've only enabled this for X86 reduction sums. I intend to enable it for the bitop/minmax cases in future patches, and eventually I think it's worth turning it on all the time. That is mainly a matter of ensuring that callers of matchBinOpReduction don't make assumptions about the vector width based on the original vector extraction.

Fixes the x86 partial reduction sum cases in PR33758 and PR42023.

Differential Revision: https://reviews.llvm.org/D65047

llvm-svn: 366933
Simon Pilgrim, 2019-07-24 17:29:56 +00:00
commit 7d318b2bb1 (parent e8bffd3ff0)
4 changed files with 69 additions and 137 deletions

@@ -1588,9 +1588,12 @@ public:
   /// Extract. The reduction must use one of the opcodes listed in /p
   /// CandidateBinOps and on success /p BinOp will contain the matching opcode.
   /// Returns the vector that is being reduced on, or SDValue() if a reduction
-  /// was not matched.
+  /// was not matched. If \p AllowPartials is set then in the case of a
+  /// reduction pattern that only matches the first few stages, the extracted
+  /// subvector of the start of the reduction is returned.
   SDValue matchBinOpReduction(SDNode *Extract, ISD::NodeType &BinOp,
-                              ArrayRef<ISD::NodeType> CandidateBinOps);
+                              ArrayRef<ISD::NodeType> CandidateBinOps,
+                              bool AllowPartials = false);

   /// Utility function used by legalize and lowering to
   /// "unroll" a vector operation by splitting out the scalars and operating

@@ -9005,7 +9005,8 @@ void SDNode::intersectFlagsWith(const SDNodeFlags Flags) {

 SDValue
 SelectionDAG::matchBinOpReduction(SDNode *Extract, ISD::NodeType &BinOp,
-                                  ArrayRef<ISD::NodeType> CandidateBinOps) {
+                                  ArrayRef<ISD::NodeType> CandidateBinOps,
+                                  bool AllowPartials) {
   // The pattern must end in an extract from index 0.
   if (Extract->getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
       !isNullConstant(Extract->getOperand(1)))
@@ -9019,6 +9020,23 @@ SelectionDAG::matchBinOpReduction(SDNode *Extract, ISD::NodeType &BinOp,
         return Op.getOpcode() == unsigned(BinOp);
       }))
     return SDValue();
+
+  unsigned CandidateBinOp = Op.getOpcode();
+  // Matching failed - attempt to see if we did enough stages that a partial
+  // reduction from a subvector is possible.
+  auto PartialReduction = [&](SDValue Op, unsigned NumSubElts) {
+    if (!AllowPartials || !Op)
+      return SDValue();
+    EVT OpVT = Op.getValueType();
+    EVT OpSVT = OpVT.getScalarType();
+    EVT SubVT = EVT::getVectorVT(*getContext(), OpSVT, NumSubElts);
+    if (!TLI->isExtractSubvectorCheap(SubVT, OpVT, 0))
+      return SDValue();
+    BinOp = (ISD::NodeType)CandidateBinOp;
+    return getNode(
+        ISD::EXTRACT_SUBVECTOR, SDLoc(Op), SubVT, Op,
+        getConstant(0, SDLoc(Op), TLI->getVectorIdxTy(getDataLayout())));
+  };

   // At each stage, we're looking for something that looks like:
   // %s = shufflevector <8 x i32> %op, <8 x i32> undef,
@@ -9030,10 +9048,15 @@ SelectionDAG::matchBinOpReduction(SDNode *Extract, ISD::NodeType &BinOp,
   // <4,5,6,7,u,u,u,u>
   // <2,3,u,u,u,u,u,u>
   // <1,u,u,u,u,u,u,u>
-  unsigned CandidateBinOp = Op.getOpcode();
+  // While a partial reduction match would be:
+  // <2,3,u,u,u,u,u,u>
+  // <1,u,u,u,u,u,u,u>
+  SDValue PrevOp;
   for (unsigned i = 0; i < Stages; ++i) {
+    unsigned MaskEnd = (1 << i);
+
     if (Op.getOpcode() != CandidateBinOp)
-      return SDValue();
+      return PartialReduction(PrevOp, MaskEnd);

     SDValue Op0 = Op.getOperand(0);
     SDValue Op1 = Op.getOperand(1);
@@ -9049,12 +9072,14 @@ SelectionDAG::matchBinOpReduction(SDNode *Extract, ISD::NodeType &BinOp,
     // The first operand of the shuffle should be the same as the other operand
     // of the binop.
     if (!Shuffle || Shuffle->getOperand(0) != Op)
-      return SDValue();
+      return PartialReduction(PrevOp, MaskEnd);

     // Verify the shuffle has the expected (at this stage of the pyramid) mask.
-    for (int Index = 0, MaskEnd = 1 << i; Index < MaskEnd; ++Index)
-      if (Shuffle->getMaskElt(Index) != MaskEnd + Index)
-        return SDValue();
+    for (int Index = 0; Index < (int)MaskEnd; ++Index)
+      if (Shuffle->getMaskElt(Index) != (MaskEnd + Index))
+        return PartialReduction(PrevOp, MaskEnd);
+
+    PrevOp = Op;
   }

   BinOp = (ISD::NodeType)CandidateBinOp;
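
To summarize the control flow just added: a mismatch at stage i falls back to PartialReduction(PrevOp, 1 << i), i.e. the value produced by stage i - 1, which has already reduced the low 1 << i lanes; at stage 0 there is no PrevOp yet, so the match fails outright, and the extract only happens if the target reports an index-0 subvector extract as cheap. A hedged standalone model of that fallback (the Step type and matchReduction helper are hypothetical, for illustration only):

```cpp
#include <cstdio>
#include <vector>

// One matched stage, identified by the first element of its shuffle mask
// (stage i of a full reduction shuffles lanes down by 1 << i).
struct Step { unsigned ShuffleBase; };

// Returns how many lanes end up reduced: 1 << Stages on a full match,
// 1 << FailedStage on a partial match, or 0 on no match (modelling a
// returned SDValue()).
unsigned matchReduction(const std::vector<Step> &Steps, unsigned Stages,
                        bool AllowPartials) {
  for (unsigned I = 0; I != Stages; ++I) {
    unsigned MaskEnd = 1u << I;
    if (I >= Steps.size() || Steps[I].ShuffleBase != MaskEnd)
      return (AllowPartials && I != 0) ? MaskEnd : 0;
  }
  return 1u << Stages;
}

int main() {
  // Two stages (<1,u,...> then <2,3,u,...>) out of the four needed for a
  // full 16-lane reduction: AllowPartials recovers a 4-lane reduction.
  std::vector<Step> Steps = {{1}, {2}};
  std::printf("full-only: %u lanes\n", matchReduction(Steps, 4, false));
  std::printf("partials:  %u lanes\n", matchReduction(Steps, 4, true));
}
```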

@@ -35688,7 +35688,7 @@ static SDValue combineReductionToHorizontal(SDNode *ExtElt, SelectionDAG &DAG,

   // TODO: Allow FADD with reduction and/or reassociation and no-signed-zeros.
   ISD::NodeType Opc;
-  SDValue Rdx = DAG.matchBinOpReduction(ExtElt, Opc, {ISD::ADD});
+  SDValue Rdx = DAG.matchBinOpReduction(ExtElt, Opc, {ISD::ADD}, true);
   if (!Rdx)
     return SDValue();
@@ -35697,7 +35697,7 @@ static SDValue combineReductionToHorizontal(SDNode *ExtElt, SelectionDAG &DAG,
          "Reduction doesn't end in an extract from index 0");

   EVT VT = ExtElt->getValueType(0);
-  EVT VecVT = ExtElt->getOperand(0).getValueType();
+  EVT VecVT = Rdx.getValueType();
   if (VecVT.getScalarType() != VT)
     return SDValue();
@@ -35711,14 +35711,14 @@ static SDValue combineReductionToHorizontal(SDNode *ExtElt, SelectionDAG &DAG,
   // vXi8 reduction - sum lo/hi halves then use PSADBW.
   if (VT == MVT::i8) {
     while (Rdx.getValueSizeInBits() > 128) {
-      EVT RdxVT = Rdx.getValueType();
-      unsigned HalfSize = RdxVT.getSizeInBits() / 2;
-      unsigned HalfElts = RdxVT.getVectorNumElements() / 2;
+      unsigned HalfSize = VecVT.getSizeInBits() / 2;
+      unsigned HalfElts = VecVT.getVectorNumElements() / 2;
       SDValue Lo = extractSubVector(Rdx, 0, DAG, DL, HalfSize);
       SDValue Hi = extractSubVector(Rdx, HalfElts, DAG, DL, HalfSize);
       Rdx = DAG.getNode(ISD::ADD, DL, Lo.getValueType(), Lo, Hi);
+      VecVT = Rdx.getValueType();
     }
-    assert(Rdx.getValueType() == MVT::v16i8 && "v16i8 reduction expected");
+    assert(VecVT == MVT::v16i8 && "v16i8 reduction expected");

     SDValue Hi = DAG.getVectorShuffle(
         MVT::v16i8, DL, Rdx, Rdx,
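
For reference, the PSADBW trick the vXi8 path relies on: psadbw computes, for each 8-byte group, the sum of absolute differences with the second operand, so against zero it simply sums 8 unsigned bytes into a 64-bit lane. A standalone model of the arithmetic (not the DAG lowering itself):

```cpp
#include <cstdint>
#include <cstdio>
#include <numeric>

int main() {
  uint8_t Bytes[16];
  for (int I = 0; I != 16; ++I)
    Bytes[I] = uint8_t(I + 1);

  // psadbw xmm, zero: each 64-bit lane receives the sum of its 8 bytes.
  uint64_t LoLane = std::accumulate(Bytes, Bytes + 8, uint64_t(0));
  uint64_t HiLane = std::accumulate(Bytes + 8, Bytes + 16, uint64_t(0));

  // The shuffle+add in the surrounding code folds the two lanes together,
  // completing the 16-byte reduction sum.
  std::printf("reduction = %llu\n", (unsigned long long)(LoLane + HiLane));
}
```
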
@@ -35746,15 +35746,14 @@ static SDValue combineReductionToHorizontal(SDNode *ExtElt, SelectionDAG &DAG,
     unsigned NumElts = VecVT.getVectorNumElements();
     SDValue Hi = extract128BitVector(Rdx, NumElts / 2, DAG, DL);
     SDValue Lo = extract128BitVector(Rdx, 0, DAG, DL);
-    VecVT = EVT::getVectorVT(*DAG.getContext(), VT, NumElts / 2);
-    Rdx = DAG.getNode(HorizOpcode, DL, VecVT, Hi, Lo);
+    Rdx = DAG.getNode(HorizOpcode, DL, Lo.getValueType(), Hi, Lo);
+    VecVT = Rdx.getValueType();
   }
   if (!((VecVT == MVT::v8i16 || VecVT == MVT::v4i32) && Subtarget.hasSSSE3()) &&
       !((VecVT == MVT::v4f32 || VecVT == MVT::v2f64) && Subtarget.hasSSE3()))
     return SDValue();

   // extract (add (shuf X), X), 0 --> extract (hadd X, X), 0
-  assert(Rdx.getValueType() == VecVT && "Unexpected reduction match");
   unsigned ReductionSteps = Log2_32(VecVT.getVectorNumElements());
   for (unsigned i = 0; i != ReductionSteps; ++i)
     Rdx = DAG.getNode(HorizOpcode, DL, VecVT, Rdx, Rdx);
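
Why Log2_32(NumElts) steps suffice: a horizontal add of a vector with itself replaces each pair of adjacent elements with their sum (duplicated into both halves of the result), so every step halves the number of distinct partial sums. A standalone model for 4 x i32, i.e. what the back-to-back vphaddd pairs in the updated tests below compute (hadd4 is a hypothetical stand-in for the instruction):

```cpp
#include <cassert>
#include <cstdio>

// Model of hadd(X, X) on a 4-lane integer vector: pairwise sums of X,
// duplicated into both halves of the result.
static void hadd4(const int In[4], int Out[4]) {
  Out[0] = In[0] + In[1];
  Out[1] = In[2] + In[3];
  Out[2] = In[0] + In[1];
  Out[3] = In[2] + In[3];
}

int main() {
  int V[4] = {1, 2, 3, 4}, T[4];
  hadd4(V, T); // step 1: <3,7,3,7>
  hadd4(T, V); // step 2: <10,10,10,10>
  assert(V[0] == 10); // Log2(4) == 2 steps leave the full sum in lane 0
  std::printf("sum = %d\n", V[0]);
}
```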

@@ -1699,8 +1699,7 @@ define i32 @partial_reduction_add_v8i32(<8 x i32> %x) {
 ;
 ; AVX-FAST-LABEL: partial_reduction_add_v8i32:
 ; AVX-FAST: # %bb.0:
-; AVX-FAST-NEXT: vpshufd {{.*#+}} xmm1 = xmm0[2,3,0,1]
-; AVX-FAST-NEXT: vpaddd %xmm1, %xmm0, %xmm0
+; AVX-FAST-NEXT: vphaddd %xmm0, %xmm0, %xmm0
 ; AVX-FAST-NEXT: vphaddd %xmm0, %xmm0, %xmm0
 ; AVX-FAST-NEXT: vmovd %xmm0, %eax
 ; AVX-FAST-NEXT: vzeroupper
@@ -1741,34 +1740,13 @@ define i32 @partial_reduction_add_v16i32(<16 x i32> %x) {
 ; AVX-SLOW-NEXT: vzeroupper
 ; AVX-SLOW-NEXT: retq
 ;
-; AVX1-FAST-LABEL: partial_reduction_add_v16i32:
-; AVX1-FAST: # %bb.0:
-; AVX1-FAST-NEXT: vpshufd {{.*#+}} xmm1 = xmm0[2,3,0,1]
-; AVX1-FAST-NEXT: vpaddd %xmm1, %xmm0, %xmm0
-; AVX1-FAST-NEXT: vphaddd %xmm0, %xmm0, %xmm0
-; AVX1-FAST-NEXT: vmovd %xmm0, %eax
-; AVX1-FAST-NEXT: vzeroupper
-; AVX1-FAST-NEXT: retq
-;
-; AVX2-FAST-LABEL: partial_reduction_add_v16i32:
-; AVX2-FAST: # %bb.0:
-; AVX2-FAST-NEXT: vpshufd {{.*#+}} xmm1 = xmm0[2,3,0,1]
-; AVX2-FAST-NEXT: vpaddd %xmm1, %xmm0, %xmm0
-; AVX2-FAST-NEXT: vpshufd {{.*#+}} xmm1 = xmm0[1,1,2,3]
-; AVX2-FAST-NEXT: vpaddd %xmm1, %xmm0, %xmm0
-; AVX2-FAST-NEXT: vmovd %xmm0, %eax
-; AVX2-FAST-NEXT: vzeroupper
-; AVX2-FAST-NEXT: retq
-;
-; AVX512-FAST-LABEL: partial_reduction_add_v16i32:
-; AVX512-FAST: # %bb.0:
-; AVX512-FAST-NEXT: vpshufd {{.*#+}} xmm1 = xmm0[2,3,0,1]
-; AVX512-FAST-NEXT: vpaddd %xmm1, %xmm0, %xmm0
-; AVX512-FAST-NEXT: vpshufd {{.*#+}} xmm1 = xmm0[1,1,2,3]
-; AVX512-FAST-NEXT: vpaddd %xmm1, %xmm0, %xmm0
-; AVX512-FAST-NEXT: vmovd %xmm0, %eax
-; AVX512-FAST-NEXT: vzeroupper
-; AVX512-FAST-NEXT: retq
+; AVX-FAST-LABEL: partial_reduction_add_v16i32:
+; AVX-FAST: # %bb.0:
+; AVX-FAST-NEXT: vphaddd %xmm0, %xmm0, %xmm0
+; AVX-FAST-NEXT: vphaddd %xmm0, %xmm0, %xmm0
+; AVX-FAST-NEXT: vmovd %xmm0, %eax
+; AVX-FAST-NEXT: vzeroupper
+; AVX-FAST-NEXT: retq
   %x23 = shufflevector <16 x i32> %x, <16 x i32> undef, <16 x i32> <i32 2, i32 3, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef>
   %x0213 = add <16 x i32> %x, %x23
   %x13 = shufflevector <16 x i32> %x0213, <16 x i32> undef, <16 x i32> <i32 1, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef>
@@ -2010,8 +1988,7 @@ define i32 @hadd32_8(<8 x i32> %x225) {
 ;
 ; AVX-FAST-LABEL: hadd32_8:
 ; AVX-FAST: # %bb.0:
-; AVX-FAST-NEXT: vpshufd {{.*#+}} xmm1 = xmm0[2,3,0,1]
-; AVX-FAST-NEXT: vpaddd %xmm1, %xmm0, %xmm0
+; AVX-FAST-NEXT: vphaddd %xmm0, %xmm0, %xmm0
 ; AVX-FAST-NEXT: vphaddd %xmm0, %xmm0, %xmm0
 ; AVX-FAST-NEXT: vmovd %xmm0, %eax
 ; AVX-FAST-NEXT: vzeroupper
@@ -2052,34 +2029,13 @@ define i32 @hadd32_16(<16 x i32> %x225) {
 ; AVX-SLOW-NEXT: vzeroupper
 ; AVX-SLOW-NEXT: retq
 ;
-; AVX1-FAST-LABEL: hadd32_16:
-; AVX1-FAST: # %bb.0:
-; AVX1-FAST-NEXT: vpshufd {{.*#+}} xmm1 = xmm0[2,3,0,1]
-; AVX1-FAST-NEXT: vpaddd %xmm1, %xmm0, %xmm0
-; AVX1-FAST-NEXT: vphaddd %xmm0, %xmm0, %xmm0
-; AVX1-FAST-NEXT: vmovd %xmm0, %eax
-; AVX1-FAST-NEXT: vzeroupper
-; AVX1-FAST-NEXT: retq
-;
-; AVX2-FAST-LABEL: hadd32_16:
-; AVX2-FAST: # %bb.0:
-; AVX2-FAST-NEXT: vpshufd {{.*#+}} xmm1 = xmm0[2,3,0,1]
-; AVX2-FAST-NEXT: vpaddd %xmm1, %xmm0, %xmm0
-; AVX2-FAST-NEXT: vpshufd {{.*#+}} xmm1 = xmm0[1,1,2,3]
-; AVX2-FAST-NEXT: vpaddd %xmm1, %xmm0, %xmm0
-; AVX2-FAST-NEXT: vmovd %xmm0, %eax
-; AVX2-FAST-NEXT: vzeroupper
-; AVX2-FAST-NEXT: retq
-;
-; AVX512-FAST-LABEL: hadd32_16:
-; AVX512-FAST: # %bb.0:
-; AVX512-FAST-NEXT: vpshufd {{.*#+}} xmm1 = xmm0[2,3,0,1]
-; AVX512-FAST-NEXT: vpaddd %xmm1, %xmm0, %xmm0
-; AVX512-FAST-NEXT: vpshufd {{.*#+}} xmm1 = xmm0[1,1,2,3]
-; AVX512-FAST-NEXT: vpaddd %xmm1, %xmm0, %xmm0
-; AVX512-FAST-NEXT: vmovd %xmm0, %eax
-; AVX512-FAST-NEXT: vzeroupper
-; AVX512-FAST-NEXT: retq
+; AVX-FAST-LABEL: hadd32_16:
+; AVX-FAST: # %bb.0:
+; AVX-FAST-NEXT: vphaddd %xmm0, %xmm0, %xmm0
+; AVX-FAST-NEXT: vphaddd %xmm0, %xmm0, %xmm0
+; AVX-FAST-NEXT: vmovd %xmm0, %eax
+; AVX-FAST-NEXT: vzeroupper
+; AVX-FAST-NEXT: retq
   %x226 = shufflevector <16 x i32> %x225, <16 x i32> undef, <16 x i32> <i32 2, i32 3, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef>
   %x227 = add <16 x i32> %x225, %x226
   %x228 = shufflevector <16 x i32> %x227, <16 x i32> undef, <16 x i32> <i32 1, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef>
@@ -2149,8 +2105,7 @@ define i32 @hadd32_8_optsize(<8 x i32> %x225) optsize {
 ;
 ; AVX-LABEL: hadd32_8_optsize:
 ; AVX: # %bb.0:
-; AVX-NEXT: vpshufd {{.*#+}} xmm1 = xmm0[2,3,0,1]
-; AVX-NEXT: vpaddd %xmm1, %xmm0, %xmm0
+; AVX-NEXT: vphaddd %xmm0, %xmm0, %xmm0
 ; AVX-NEXT: vphaddd %xmm0, %xmm0, %xmm0
 ; AVX-NEXT: vmovd %xmm0, %eax
 ; AVX-NEXT: vzeroupper
@@ -2172,63 +2127,13 @@ define i32 @hadd32_16_optsize(<16 x i32> %x225) optsize {
 ; SSE3-NEXT: movd %xmm1, %eax
 ; SSE3-NEXT: retq
 ;
-; AVX1-SLOW-LABEL: hadd32_16_optsize:
-; AVX1-SLOW: # %bb.0:
-; AVX1-SLOW-NEXT: vpshufd {{.*#+}} xmm1 = xmm0[2,3,0,1]
-; AVX1-SLOW-NEXT: vpaddd %xmm1, %xmm0, %xmm0
-; AVX1-SLOW-NEXT: vphaddd %xmm0, %xmm0, %xmm0
-; AVX1-SLOW-NEXT: vmovd %xmm0, %eax
-; AVX1-SLOW-NEXT: vzeroupper
-; AVX1-SLOW-NEXT: retq
-;
-; AVX1-FAST-LABEL: hadd32_16_optsize:
-; AVX1-FAST: # %bb.0:
-; AVX1-FAST-NEXT: vpshufd {{.*#+}} xmm1 = xmm0[2,3,0,1]
-; AVX1-FAST-NEXT: vpaddd %xmm1, %xmm0, %xmm0
-; AVX1-FAST-NEXT: vphaddd %xmm0, %xmm0, %xmm0
-; AVX1-FAST-NEXT: vmovd %xmm0, %eax
-; AVX1-FAST-NEXT: vzeroupper
-; AVX1-FAST-NEXT: retq
-;
-; AVX2-SLOW-LABEL: hadd32_16_optsize:
-; AVX2-SLOW: # %bb.0:
-; AVX2-SLOW-NEXT: vpshufd {{.*#+}} xmm1 = xmm0[2,3,0,1]
-; AVX2-SLOW-NEXT: vpaddd %xmm1, %xmm0, %xmm0
-; AVX2-SLOW-NEXT: vpshufd {{.*#+}} xmm1 = xmm0[1,1,2,3]
-; AVX2-SLOW-NEXT: vpaddd %xmm1, %xmm0, %xmm0
-; AVX2-SLOW-NEXT: vmovd %xmm0, %eax
-; AVX2-SLOW-NEXT: vzeroupper
-; AVX2-SLOW-NEXT: retq
-;
-; AVX2-FAST-LABEL: hadd32_16_optsize:
-; AVX2-FAST: # %bb.0:
-; AVX2-FAST-NEXT: vpshufd {{.*#+}} xmm1 = xmm0[2,3,0,1]
-; AVX2-FAST-NEXT: vpaddd %xmm1, %xmm0, %xmm0
-; AVX2-FAST-NEXT: vpshufd {{.*#+}} xmm1 = xmm0[1,1,2,3]
-; AVX2-FAST-NEXT: vpaddd %xmm1, %xmm0, %xmm0
-; AVX2-FAST-NEXT: vmovd %xmm0, %eax
-; AVX2-FAST-NEXT: vzeroupper
-; AVX2-FAST-NEXT: retq
-;
-; AVX512-SLOW-LABEL: hadd32_16_optsize:
-; AVX512-SLOW: # %bb.0:
-; AVX512-SLOW-NEXT: vpshufd {{.*#+}} xmm1 = xmm0[2,3,0,1]
-; AVX512-SLOW-NEXT: vpaddd %xmm1, %xmm0, %xmm0
-; AVX512-SLOW-NEXT: vpshufd {{.*#+}} xmm1 = xmm0[1,1,2,3]
-; AVX512-SLOW-NEXT: vpaddd %xmm1, %xmm0, %xmm0
-; AVX512-SLOW-NEXT: vmovd %xmm0, %eax
-; AVX512-SLOW-NEXT: vzeroupper
-; AVX512-SLOW-NEXT: retq
-;
-; AVX512-FAST-LABEL: hadd32_16_optsize:
-; AVX512-FAST: # %bb.0:
-; AVX512-FAST-NEXT: vpshufd {{.*#+}} xmm1 = xmm0[2,3,0,1]
-; AVX512-FAST-NEXT: vpaddd %xmm1, %xmm0, %xmm0
-; AVX512-FAST-NEXT: vpshufd {{.*#+}} xmm1 = xmm0[1,1,2,3]
-; AVX512-FAST-NEXT: vpaddd %xmm1, %xmm0, %xmm0
-; AVX512-FAST-NEXT: vmovd %xmm0, %eax
-; AVX512-FAST-NEXT: vzeroupper
-; AVX512-FAST-NEXT: retq
+; AVX-LABEL: hadd32_16_optsize:
+; AVX: # %bb.0:
+; AVX-NEXT: vphaddd %xmm0, %xmm0, %xmm0
+; AVX-NEXT: vphaddd %xmm0, %xmm0, %xmm0
+; AVX-NEXT: vmovd %xmm0, %eax
+; AVX-NEXT: vzeroupper
+; AVX-NEXT: retq
   %x226 = shufflevector <16 x i32> %x225, <16 x i32> undef, <16 x i32> <i32 2, i32 3, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef>
   %x227 = add <16 x i32> %x225, %x226
   %x228 = shufflevector <16 x i32> %x227, <16 x i32> undef, <16 x i32> <i32 1, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef>