[X86] matchScalarReduction - add support for partial reductions

Add opt-in support for partial reduction cases by providing an optional partial mask that indicates which elements have been extracted for the scalar reduction.
commit ebb181cf40 (parent 2e77362626)
Simon Pilgrim
2020-03-16 17:58:37 +00:00
2 changed files with 54 additions and 79 deletions
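
For orientation before the diff: matchScalarReduction recognizes chains like OR(EXTRACTELT(X,0),OR(EXTRACTELT(X,1),...)) that reduce individually extracted lanes of one vector, and a partial reduction is one that touches only a subset of the lanes. The standalone C++ sketch below models the per-source mask this patch introduces; it is not LLVM API, and the lane choice is invented for illustration:

// A partial scalar reduction over a 4-lane vector, e.g.
//   OR(EXTRACTELT(X,1), OR(EXTRACTELT(X,2), EXTRACTELT(X,3))),
// touches only lanes 1..3, so its recorded mask is 0b1110, not all-ones.
#include <bitset>
#include <cstdio>
#include <initializer_list>

int main() {
  std::bitset<4> SrcMask; // stands in for the APInt kept per source vector
  for (unsigned Lane : {1u, 2u, 3u})
    SrcMask.set(Lane);
  std::printf("mask = 0b%s, full reduction = %s\n",
              SrcMask.to_string().c_str(), SrcMask.all() ? "yes" : "no");
  return 0;
}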

llvm/lib/Target/X86/X86ISelLowering.cpp

@@ -20964,9 +20964,12 @@ static SDValue getSETCC(X86::CondCode Cond, SDValue EFLAGS, const SDLoc &dl,
 }
 /// Helper for matching OR(EXTRACTELT(X,0),OR(EXTRACTELT(X,1),...))
-/// style scalarized (associative) reduction patterns.
+/// style scalarized (associative) reduction patterns. Partial reductions
+/// are supported when the pointer SrcMask is non-null.
 /// TODO - move this to SelectionDAG?
 static bool matchScalarReduction(SDValue Op, ISD::NodeType BinOp,
-                                 SmallVectorImpl<SDValue> &SrcOps) {
+                                 SmallVectorImpl<SDValue> &SrcOps,
+                                 SmallVectorImpl<APInt> *SrcMask = nullptr) {
   SmallVector<SDValue, 8> Opnds;
   DenseMap<SDValue, APInt> SrcOpMap;
   EVT VT = MVT::Other;
@@ -21018,12 +21021,18 @@ static bool matchScalarReduction(SDValue Op, ISD::NodeType BinOp,
     M->second.setBit(CIdx);
   }
 
-  // Quit if not all elements are used.
-  for (DenseMap<SDValue, APInt>::const_iterator I = SrcOpMap.begin(),
-                                                E = SrcOpMap.end();
-       I != E; ++I) {
-    if (!I->second.isAllOnesValue())
-      return false;
+  if (SrcMask) {
+    // Collect the source partial masks.
+    for (SDValue &SrcOp : SrcOps)
+      SrcMask->push_back(SrcOpMap[SrcOp]);
+  } else {
+    // Quit if not all elements are used.
+    for (DenseMap<SDValue, APInt>::const_iterator I = SrcOpMap.begin(),
+                                                  E = SrcOpMap.end();
+         I != E; ++I) {
+      if (!I->second.isAllOnesValue())
+        return false;
+    }
   }
 
   return true;
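
In other words, the helper's tail now has two modes. A rough standalone model of that contract, with std:: containers standing in for SDValue/APInt/DenseMap (finishMatch and the type aliases are invented names, not LLVM code):

#include <cassert>
#include <cstdint>
#include <map>
#include <vector>

using SrcId = int;         // stand-in for an SDValue source vector
using LaneMask = uint64_t; // stand-in for the per-source APInt

// With a non-null OutMasks, the per-source lane masks are reported and the
// caller decides what to do about partial coverage; with null, the old
// behaviour remains: every lane of every source must have been extracted.
static bool finishMatch(const std::vector<SrcId> &SrcOps,
                        const std::map<SrcId, LaneMask> &SrcOpMap,
                        unsigned NumLanes,
                        std::vector<LaneMask> *OutMasks = nullptr) {
  const LaneMask Full = (NumLanes >= 64) ? ~0ull : ((1ull << NumLanes) - 1);
  if (OutMasks) {
    for (SrcId Op : SrcOps) // collect the source partial masks
      OutMasks->push_back(SrcOpMap.at(Op));
    return true;
  }
  for (const auto &KV : SrcOpMap) // quit if not all lanes are used
    if (KV.second != Full)
      return false;
  return true;
}

int main() {
  std::map<SrcId, LaneMask> SrcOpMap{{0, 0b1011}}; // lanes 0,1,3 of 4 used
  std::vector<LaneMask> Masks;
  assert(!finishMatch({0}, SrcOpMap, 4));        // old mode rejects partial
  assert(finishMatch({0}, SrcOpMap, 4, &Masks)); // new mode reports the mask
  assert(Masks[0] == 0b1011);
  return 0;
}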
@@ -41210,7 +41219,8 @@ static SDValue combineAnd(SDNode *N, SelectionDAG &DAG,
   // TODO: Support multiple SrcOps.
   if (VT == MVT::i1) {
     SmallVector<SDValue, 2> SrcOps;
-    if (matchScalarReduction(SDValue(N, 0), ISD::AND, SrcOps) &&
+    SmallVector<APInt, 2> SrcPartials;
+    if (matchScalarReduction(SDValue(N, 0), ISD::AND, SrcOps, &SrcPartials) &&
         SrcOps.size() == 1) {
       SDLoc dl(N);
       const TargetLowering &TLI = DAG.getTargetLoweringInfo();
@@ -41220,9 +41230,11 @@ static SDValue combineAnd(SDNode *N, SelectionDAG &DAG,
       if (!Mask && TLI.isTypeLegal(SrcOps[0].getValueType()))
         Mask = DAG.getBitcast(MaskVT, SrcOps[0]);
       if (Mask) {
-        APInt AllBits = APInt::getAllOnesValue(NumElts);
-        return DAG.getSetCC(dl, MVT::i1, Mask,
-                            DAG.getConstant(AllBits, dl, MaskVT), ISD::SETEQ);
+        assert(SrcPartials[0].getBitWidth() == NumElts &&
+               "Unexpected partial reduction mask");
+        SDValue PartialBits = DAG.getConstant(SrcPartials[0], dl, MaskVT);
+        Mask = DAG.getNode(ISD::AND, dl, MaskVT, Mask, PartialBits);
+        return DAG.getSetCC(dl, MVT::i1, Mask, PartialBits, ISD::SETEQ);
       }
     }
   }
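
The SETEQ form can be sanity-checked against the numbers from the movmsk_v8i16 test below: that reduction ANDs lanes 0, 1, 4 and 7, whose mask is 0x93 (printed as the signed byte -109 in the CHECK lines), and all of those lanes are set exactly when (movmsk & 0x93) == 0x93. A scalar model of the fold, with invented names:

#include <cassert>
#include <cstdint>

// combineAnd's new output shape: an all-of reduction over the lanes in
// PartialBits becomes one AND plus one compare-equal against PartialBits.
static bool allOfLanes(uint8_t Movmsk, uint8_t PartialBits) {
  return (Movmsk & PartialBits) == PartialBits;
}

int main() {
  const uint8_t PartialBits = 0x93; // lanes {0,1,4,7}; 0x93 == -109 as i8
  assert(allOfLanes(0xFF, PartialBits));  // every lane true
  assert(allOfLanes(0x93, PartialBits));  // exactly the reduced lanes true
  assert(!allOfLanes(0x92, PartialBits)); // lane 0 false -> all-of fails
  assert(allOfLanes(0xDB, PartialBits));  // unreduced lanes are don't-cares
  return 0;
}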
@@ -41685,7 +41697,8 @@ static SDValue combineOr(SDNode *N, SelectionDAG &DAG,
   // TODO: Support multiple SrcOps.
   if (VT == MVT::i1) {
     SmallVector<SDValue, 2> SrcOps;
-    if (matchScalarReduction(SDValue(N, 0), ISD::OR, SrcOps) &&
+    SmallVector<APInt, 2> SrcPartials;
+    if (matchScalarReduction(SDValue(N, 0), ISD::OR, SrcOps, &SrcPartials) &&
         SrcOps.size() == 1) {
       SDLoc dl(N);
       const TargetLowering &TLI = DAG.getTargetLoweringInfo();
@@ -41695,9 +41708,12 @@ static SDValue combineOr(SDNode *N, SelectionDAG &DAG,
       if (!Mask && TLI.isTypeLegal(SrcOps[0].getValueType()))
         Mask = DAG.getBitcast(MaskVT, SrcOps[0]);
       if (Mask) {
-        APInt AllBits = APInt::getNullValue(NumElts);
-        return DAG.getSetCC(dl, MVT::i1, Mask,
-                            DAG.getConstant(AllBits, dl, MaskVT), ISD::SETNE);
+        assert(SrcPartials[0].getBitWidth() == NumElts &&
+               "Unexpected partial reduction mask");
+        SDValue ZeroBits = DAG.getConstant(0, dl, MaskVT);
+        SDValue PartialBits = DAG.getConstant(SrcPartials[0], dl, MaskVT);
+        Mask = DAG.getNode(ISD::AND, dl, MaskVT, Mask, PartialBits);
+        return DAG.getSetCC(dl, MVT::i1, Mask, ZeroBits, ISD::SETNE);
       }
     }
   }
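
The OR case is the dual: an any-of reduction over the chosen lanes is true when (movmsk & PartialBits) != 0, which the movmsk_v4f32 test below exercises as testb $14 (mask 0b1110, lanes 1, 2 and 3). Again a scalar model with invented names:

#include <cassert>
#include <cstdint>

// combineOr's new output shape: an any-of reduction over the lanes in
// PartialBits becomes one AND (foldable into TEST) plus a compare vs zero.
static bool anyOfLanes(uint8_t Movmsk, uint8_t PartialBits) {
  return (Movmsk & PartialBits) != 0;
}

int main() {
  const uint8_t PartialBits = 0x0E; // lanes {1,2,3}; the testb $14 below
  assert(anyOfLanes(0x02, PartialBits));  // lane 1 alone is enough
  assert(!anyOfLanes(0x01, PartialBits)); // only unreduced lane 0 is set
  assert(!anyOfLanes(0x00, PartialBits)); // no lanes set
  return 0;
}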

llvm/test/CodeGen/X86/movmsk-cmp.ll

@@ -4225,40 +4225,25 @@ define i1 @movmsk_v16i8(<16 x i8> %x, <16 x i8> %y) {
   ret i1 %u2
 }
 
-; TODO: Replace shift+mask chain with NOT+TEST+SETE
 define i1 @movmsk_v8i16(<8 x i16> %x, <8 x i16> %y) {
 ; SSE2-LABEL: movmsk_v8i16:
 ; SSE2:       # %bb.0:
 ; SSE2-NEXT:    pcmpgtw %xmm1, %xmm0
 ; SSE2-NEXT:    packsswb %xmm0, %xmm0
-; SSE2-NEXT:    pmovmskb %xmm0, %ecx
-; SSE2-NEXT:    movl %ecx, %eax
-; SSE2-NEXT:    shrb $7, %al
-; SSE2-NEXT:    movl %ecx, %edx
-; SSE2-NEXT:    andb $16, %dl
-; SSE2-NEXT:    shrb $4, %dl
-; SSE2-NEXT:    andb %al, %dl
-; SSE2-NEXT:    movl %ecx, %eax
-; SSE2-NEXT:    shrb %al
-; SSE2-NEXT:    andb %dl, %al
-; SSE2-NEXT:    andb %cl, %al
+; SSE2-NEXT:    pmovmskb %xmm0, %eax
+; SSE2-NEXT:    andb $-109, %al
+; SSE2-NEXT:    cmpb $-109, %al
+; SSE2-NEXT:    sete %al
 ; SSE2-NEXT:    retq
 ;
 ; AVX-LABEL: movmsk_v8i16:
 ; AVX:       # %bb.0:
 ; AVX-NEXT:    vpcmpgtw %xmm1, %xmm0, %xmm0
 ; AVX-NEXT:    vpacksswb %xmm0, %xmm0, %xmm0
-; AVX-NEXT:    vpmovmskb %xmm0, %ecx
-; AVX-NEXT:    movl %ecx, %eax
-; AVX-NEXT:    shrb $7, %al
-; AVX-NEXT:    movl %ecx, %edx
-; AVX-NEXT:    andb $16, %dl
-; AVX-NEXT:    shrb $4, %dl
-; AVX-NEXT:    andb %al, %dl
-; AVX-NEXT:    movl %ecx, %eax
-; AVX-NEXT:    shrb %al
-; AVX-NEXT:    andb %dl, %al
-; AVX-NEXT:    andb %cl, %al
+; AVX-NEXT:    vpmovmskb %xmm0, %eax
+; AVX-NEXT:    andb $-109, %al
+; AVX-NEXT:    cmpb $-109, %al
+; AVX-NEXT:    sete %al
 ; AVX-NEXT:    retq
 ;
 ; KNL-LABEL: movmsk_v8i16:
@@ -4266,34 +4251,20 @@ define i1 @movmsk_v8i16(<8 x i16> %x, <8 x i16> %y) {
 ; KNL-NEXT:    vpcmpgtw %xmm1, %xmm0, %xmm0
 ; KNL-NEXT:    vpmovsxwq %xmm0, %zmm0
 ; KNL-NEXT:    vptestmq %zmm0, %zmm0, %k0
-; KNL-NEXT:    kshiftrw $4, %k0, %k1
-; KNL-NEXT:    kmovw %k1, %ecx
-; KNL-NEXT:    kshiftrw $7, %k0, %k1
-; KNL-NEXT:    kmovw %k1, %eax
-; KNL-NEXT:    kshiftrw $1, %k0, %k1
-; KNL-NEXT:    kmovw %k1, %edx
-; KNL-NEXT:    kmovw %k0, %esi
-; KNL-NEXT:    andb %cl, %al
-; KNL-NEXT:    andb %dl, %al
-; KNL-NEXT:    andb %sil, %al
-; KNL-NEXT:    # kill: def $al killed $al killed $eax
+; KNL-NEXT:    kmovw %k0, %eax
+; KNL-NEXT:    andb $-109, %al
+; KNL-NEXT:    cmpb $-109, %al
+; KNL-NEXT:    sete %al
 ; KNL-NEXT:    vzeroupper
 ; KNL-NEXT:    retq
 ;
 ; SKX-LABEL: movmsk_v8i16:
 ; SKX:       # %bb.0:
 ; SKX-NEXT:    vpcmpgtw %xmm1, %xmm0, %k0
-; SKX-NEXT:    kshiftrb $4, %k0, %k1
-; SKX-NEXT:    kmovd %k1, %ecx
-; SKX-NEXT:    kshiftrb $7, %k0, %k1
-; SKX-NEXT:    kmovd %k1, %eax
-; SKX-NEXT:    kshiftrb $1, %k0, %k1
-; SKX-NEXT:    kmovd %k1, %edx
-; SKX-NEXT:    kmovd %k0, %esi
-; SKX-NEXT:    andb %cl, %al
-; SKX-NEXT:    andb %dl, %al
-; SKX-NEXT:    andb %sil, %al
-; SKX-NEXT:    # kill: def $al killed $al killed $eax
+; SKX-NEXT:    kmovd %k0, %eax
+; SKX-NEXT:    andb $-109, %al
+; SKX-NEXT:    cmpb $-109, %al
+; SKX-NEXT:    sete %al
 ; SKX-NEXT:    retq
   %cmp = icmp sgt <8 x i16> %x, %y
   %e1 = extractelement <8 x i1> %cmp, i32 0
@@ -4478,30 +4449,18 @@ define i1 @movmsk_v4f32(<4 x float> %x, <4 x float> %y) {
 ; KNL-NEXT:    # kill: def $xmm1 killed $xmm1 def $zmm1
 ; KNL-NEXT:    # kill: def $xmm0 killed $xmm0 def $zmm0
 ; KNL-NEXT:    vcmpeq_uqps %zmm1, %zmm0, %k0
-; KNL-NEXT:    kshiftrw $3, %k0, %k1
-; KNL-NEXT:    kmovw %k1, %ecx
-; KNL-NEXT:    kshiftrw $2, %k0, %k1
-; KNL-NEXT:    kmovw %k1, %eax
-; KNL-NEXT:    kshiftrw $1, %k0, %k0
-; KNL-NEXT:    kmovw %k0, %edx
-; KNL-NEXT:    orb %cl, %al
-; KNL-NEXT:    orb %dl, %al
-; KNL-NEXT:    # kill: def $al killed $al killed $eax
+; KNL-NEXT:    kmovw %k0, %eax
+; KNL-NEXT:    testb $14, %al
+; KNL-NEXT:    setne %al
 ; KNL-NEXT:    vzeroupper
 ; KNL-NEXT:    retq
 ;
 ; SKX-LABEL: movmsk_v4f32:
 ; SKX:       # %bb.0:
 ; SKX-NEXT:    vcmpeq_uqps %xmm1, %xmm0, %k0
-; SKX-NEXT:    kshiftrb $3, %k0, %k1
-; SKX-NEXT:    kmovd %k1, %ecx
-; SKX-NEXT:    kshiftrb $2, %k0, %k1
-; SKX-NEXT:    kmovd %k1, %eax
-; SKX-NEXT:    kshiftrb $1, %k0, %k0
-; SKX-NEXT:    kmovd %k0, %edx
-; SKX-NEXT:    orb %cl, %al
-; SKX-NEXT:    orb %dl, %al
-; SKX-NEXT:    # kill: def $al killed $al killed $eax
+; SKX-NEXT:    kmovd %k0, %eax
+; SKX-NEXT:    testb $14, %al
+; SKX-NEXT:    setne %al
 ; SKX-NEXT:    retq
   %cmp = fcmp ueq <4 x float> %x, %y
   %e1 = extractelement <4 x i1> %cmp, i32 1