forked from OSchip/llvm-project
[X86] matchScalarReduction - add support for partial reductions
Add optional support for opt-in partial reduction cases by providing an optional partial mask to indicate which elements have been extracted for the scalar reduction.
This commit is contained in:
parent
2e77362626
commit
ebb181cf40
|
@ -20964,9 +20964,12 @@ static SDValue getSETCC(X86::CondCode Cond, SDValue EFLAGS, const SDLoc &dl,
|
|||
}
|
||||
|
||||
/// Helper for matching OR(EXTRACTELT(X,0),OR(EXTRACTELT(X,1),...))
|
||||
/// style scalarized (associative) reduction patterns.
|
||||
/// style scalarized (associative) reduction patterns. Partial reductions
|
||||
/// are supported when the pointer SrcMask is non-null.
|
||||
/// TODO - move this to SelectionDAG?
|
||||
static bool matchScalarReduction(SDValue Op, ISD::NodeType BinOp,
|
||||
SmallVectorImpl<SDValue> &SrcOps) {
|
||||
SmallVectorImpl<SDValue> &SrcOps,
|
||||
SmallVectorImpl<APInt> *SrcMask = nullptr) {
|
||||
SmallVector<SDValue, 8> Opnds;
|
||||
DenseMap<SDValue, APInt> SrcOpMap;
|
||||
EVT VT = MVT::Other;
|
||||
|
@ -21018,12 +21021,18 @@ static bool matchScalarReduction(SDValue Op, ISD::NodeType BinOp,
|
|||
M->second.setBit(CIdx);
|
||||
}
|
||||
|
||||
// Quit if not all elements are used.
|
||||
for (DenseMap<SDValue, APInt>::const_iterator I = SrcOpMap.begin(),
|
||||
E = SrcOpMap.end();
|
||||
I != E; ++I) {
|
||||
if (!I->second.isAllOnesValue())
|
||||
return false;
|
||||
if (SrcMask) {
|
||||
// Collect the source partial masks.
|
||||
for (SDValue &SrcOp : SrcOps)
|
||||
SrcMask->push_back(SrcOpMap[SrcOp]);
|
||||
} else {
|
||||
// Quit if not all elements are used.
|
||||
for (DenseMap<SDValue, APInt>::const_iterator I = SrcOpMap.begin(),
|
||||
E = SrcOpMap.end();
|
||||
I != E; ++I) {
|
||||
if (!I->second.isAllOnesValue())
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
|
@ -41210,7 +41219,8 @@ static SDValue combineAnd(SDNode *N, SelectionDAG &DAG,
|
|||
// TODO: Support multiple SrcOps.
|
||||
if (VT == MVT::i1) {
|
||||
SmallVector<SDValue, 2> SrcOps;
|
||||
if (matchScalarReduction(SDValue(N, 0), ISD::AND, SrcOps) &&
|
||||
SmallVector<APInt, 2> SrcPartials;
|
||||
if (matchScalarReduction(SDValue(N, 0), ISD::AND, SrcOps, &SrcPartials) &&
|
||||
SrcOps.size() == 1) {
|
||||
SDLoc dl(N);
|
||||
const TargetLowering &TLI = DAG.getTargetLoweringInfo();
|
||||
|
@ -41220,9 +41230,11 @@ static SDValue combineAnd(SDNode *N, SelectionDAG &DAG,
|
|||
if (!Mask && TLI.isTypeLegal(SrcOps[0].getValueType()))
|
||||
Mask = DAG.getBitcast(MaskVT, SrcOps[0]);
|
||||
if (Mask) {
|
||||
APInt AllBits = APInt::getAllOnesValue(NumElts);
|
||||
return DAG.getSetCC(dl, MVT::i1, Mask,
|
||||
DAG.getConstant(AllBits, dl, MaskVT), ISD::SETEQ);
|
||||
assert(SrcPartials[0].getBitWidth() == NumElts &&
|
||||
"Unexpected partial reduction mask");
|
||||
SDValue PartialBits = DAG.getConstant(SrcPartials[0], dl, MaskVT);
|
||||
Mask = DAG.getNode(ISD::AND, dl, MaskVT, Mask, PartialBits);
|
||||
return DAG.getSetCC(dl, MVT::i1, Mask, PartialBits, ISD::SETEQ);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -41685,7 +41697,8 @@ static SDValue combineOr(SDNode *N, SelectionDAG &DAG,
|
|||
// TODO: Support multiple SrcOps.
|
||||
if (VT == MVT::i1) {
|
||||
SmallVector<SDValue, 2> SrcOps;
|
||||
if (matchScalarReduction(SDValue(N, 0), ISD::OR, SrcOps) &&
|
||||
SmallVector<APInt, 2> SrcPartials;
|
||||
if (matchScalarReduction(SDValue(N, 0), ISD::OR, SrcOps, &SrcPartials) &&
|
||||
SrcOps.size() == 1) {
|
||||
SDLoc dl(N);
|
||||
const TargetLowering &TLI = DAG.getTargetLoweringInfo();
|
||||
|
@ -41695,9 +41708,12 @@ static SDValue combineOr(SDNode *N, SelectionDAG &DAG,
|
|||
if (!Mask && TLI.isTypeLegal(SrcOps[0].getValueType()))
|
||||
Mask = DAG.getBitcast(MaskVT, SrcOps[0]);
|
||||
if (Mask) {
|
||||
APInt AllBits = APInt::getNullValue(NumElts);
|
||||
return DAG.getSetCC(dl, MVT::i1, Mask,
|
||||
DAG.getConstant(AllBits, dl, MaskVT), ISD::SETNE);
|
||||
assert(SrcPartials[0].getBitWidth() == NumElts &&
|
||||
"Unexpected partial reduction mask");
|
||||
SDValue ZeroBits = DAG.getConstant(0, dl, MaskVT);
|
||||
SDValue PartialBits = DAG.getConstant(SrcPartials[0], dl, MaskVT);
|
||||
Mask = DAG.getNode(ISD::AND, dl, MaskVT, Mask, PartialBits);
|
||||
return DAG.getSetCC(dl, MVT::i1, Mask, ZeroBits, ISD::SETNE);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -4225,40 +4225,25 @@ define i1 @movmsk_v16i8(<16 x i8> %x, <16 x i8> %y) {
|
|||
ret i1 %u2
|
||||
}
|
||||
|
||||
; TODO: Replace shift+mask chain with NOT+TEST+SETE
|
||||
define i1 @movmsk_v8i16(<8 x i16> %x, <8 x i16> %y) {
|
||||
; SSE2-LABEL: movmsk_v8i16:
|
||||
; SSE2: # %bb.0:
|
||||
; SSE2-NEXT: pcmpgtw %xmm1, %xmm0
|
||||
; SSE2-NEXT: packsswb %xmm0, %xmm0
|
||||
; SSE2-NEXT: pmovmskb %xmm0, %ecx
|
||||
; SSE2-NEXT: movl %ecx, %eax
|
||||
; SSE2-NEXT: shrb $7, %al
|
||||
; SSE2-NEXT: movl %ecx, %edx
|
||||
; SSE2-NEXT: andb $16, %dl
|
||||
; SSE2-NEXT: shrb $4, %dl
|
||||
; SSE2-NEXT: andb %al, %dl
|
||||
; SSE2-NEXT: movl %ecx, %eax
|
||||
; SSE2-NEXT: shrb %al
|
||||
; SSE2-NEXT: andb %dl, %al
|
||||
; SSE2-NEXT: andb %cl, %al
|
||||
; SSE2-NEXT: pmovmskb %xmm0, %eax
|
||||
; SSE2-NEXT: andb $-109, %al
|
||||
; SSE2-NEXT: cmpb $-109, %al
|
||||
; SSE2-NEXT: sete %al
|
||||
; SSE2-NEXT: retq
|
||||
;
|
||||
; AVX-LABEL: movmsk_v8i16:
|
||||
; AVX: # %bb.0:
|
||||
; AVX-NEXT: vpcmpgtw %xmm1, %xmm0, %xmm0
|
||||
; AVX-NEXT: vpacksswb %xmm0, %xmm0, %xmm0
|
||||
; AVX-NEXT: vpmovmskb %xmm0, %ecx
|
||||
; AVX-NEXT: movl %ecx, %eax
|
||||
; AVX-NEXT: shrb $7, %al
|
||||
; AVX-NEXT: movl %ecx, %edx
|
||||
; AVX-NEXT: andb $16, %dl
|
||||
; AVX-NEXT: shrb $4, %dl
|
||||
; AVX-NEXT: andb %al, %dl
|
||||
; AVX-NEXT: movl %ecx, %eax
|
||||
; AVX-NEXT: shrb %al
|
||||
; AVX-NEXT: andb %dl, %al
|
||||
; AVX-NEXT: andb %cl, %al
|
||||
; AVX-NEXT: vpmovmskb %xmm0, %eax
|
||||
; AVX-NEXT: andb $-109, %al
|
||||
; AVX-NEXT: cmpb $-109, %al
|
||||
; AVX-NEXT: sete %al
|
||||
; AVX-NEXT: retq
|
||||
;
|
||||
; KNL-LABEL: movmsk_v8i16:
|
||||
|
@ -4266,34 +4251,20 @@ define i1 @movmsk_v8i16(<8 x i16> %x, <8 x i16> %y) {
|
|||
; KNL-NEXT: vpcmpgtw %xmm1, %xmm0, %xmm0
|
||||
; KNL-NEXT: vpmovsxwq %xmm0, %zmm0
|
||||
; KNL-NEXT: vptestmq %zmm0, %zmm0, %k0
|
||||
; KNL-NEXT: kshiftrw $4, %k0, %k1
|
||||
; KNL-NEXT: kmovw %k1, %ecx
|
||||
; KNL-NEXT: kshiftrw $7, %k0, %k1
|
||||
; KNL-NEXT: kmovw %k1, %eax
|
||||
; KNL-NEXT: kshiftrw $1, %k0, %k1
|
||||
; KNL-NEXT: kmovw %k1, %edx
|
||||
; KNL-NEXT: kmovw %k0, %esi
|
||||
; KNL-NEXT: andb %cl, %al
|
||||
; KNL-NEXT: andb %dl, %al
|
||||
; KNL-NEXT: andb %sil, %al
|
||||
; KNL-NEXT: # kill: def $al killed $al killed $eax
|
||||
; KNL-NEXT: kmovw %k0, %eax
|
||||
; KNL-NEXT: andb $-109, %al
|
||||
; KNL-NEXT: cmpb $-109, %al
|
||||
; KNL-NEXT: sete %al
|
||||
; KNL-NEXT: vzeroupper
|
||||
; KNL-NEXT: retq
|
||||
;
|
||||
; SKX-LABEL: movmsk_v8i16:
|
||||
; SKX: # %bb.0:
|
||||
; SKX-NEXT: vpcmpgtw %xmm1, %xmm0, %k0
|
||||
; SKX-NEXT: kshiftrb $4, %k0, %k1
|
||||
; SKX-NEXT: kmovd %k1, %ecx
|
||||
; SKX-NEXT: kshiftrb $7, %k0, %k1
|
||||
; SKX-NEXT: kmovd %k1, %eax
|
||||
; SKX-NEXT: kshiftrb $1, %k0, %k1
|
||||
; SKX-NEXT: kmovd %k1, %edx
|
||||
; SKX-NEXT: kmovd %k0, %esi
|
||||
; SKX-NEXT: andb %cl, %al
|
||||
; SKX-NEXT: andb %dl, %al
|
||||
; SKX-NEXT: andb %sil, %al
|
||||
; SKX-NEXT: # kill: def $al killed $al killed $eax
|
||||
; SKX-NEXT: kmovd %k0, %eax
|
||||
; SKX-NEXT: andb $-109, %al
|
||||
; SKX-NEXT: cmpb $-109, %al
|
||||
; SKX-NEXT: sete %al
|
||||
; SKX-NEXT: retq
|
||||
%cmp = icmp sgt <8 x i16> %x, %y
|
||||
%e1 = extractelement <8 x i1> %cmp, i32 0
|
||||
|
@ -4478,30 +4449,18 @@ define i1 @movmsk_v4f32(<4 x float> %x, <4 x float> %y) {
|
|||
; KNL-NEXT: # kill: def $xmm1 killed $xmm1 def $zmm1
|
||||
; KNL-NEXT: # kill: def $xmm0 killed $xmm0 def $zmm0
|
||||
; KNL-NEXT: vcmpeq_uqps %zmm1, %zmm0, %k0
|
||||
; KNL-NEXT: kshiftrw $3, %k0, %k1
|
||||
; KNL-NEXT: kmovw %k1, %ecx
|
||||
; KNL-NEXT: kshiftrw $2, %k0, %k1
|
||||
; KNL-NEXT: kmovw %k1, %eax
|
||||
; KNL-NEXT: kshiftrw $1, %k0, %k0
|
||||
; KNL-NEXT: kmovw %k0, %edx
|
||||
; KNL-NEXT: orb %cl, %al
|
||||
; KNL-NEXT: orb %dl, %al
|
||||
; KNL-NEXT: # kill: def $al killed $al killed $eax
|
||||
; KNL-NEXT: kmovw %k0, %eax
|
||||
; KNL-NEXT: testb $14, %al
|
||||
; KNL-NEXT: setne %al
|
||||
; KNL-NEXT: vzeroupper
|
||||
; KNL-NEXT: retq
|
||||
;
|
||||
; SKX-LABEL: movmsk_v4f32:
|
||||
; SKX: # %bb.0:
|
||||
; SKX-NEXT: vcmpeq_uqps %xmm1, %xmm0, %k0
|
||||
; SKX-NEXT: kshiftrb $3, %k0, %k1
|
||||
; SKX-NEXT: kmovd %k1, %ecx
|
||||
; SKX-NEXT: kshiftrb $2, %k0, %k1
|
||||
; SKX-NEXT: kmovd %k1, %eax
|
||||
; SKX-NEXT: kshiftrb $1, %k0, %k0
|
||||
; SKX-NEXT: kmovd %k0, %edx
|
||||
; SKX-NEXT: orb %cl, %al
|
||||
; SKX-NEXT: orb %dl, %al
|
||||
; SKX-NEXT: # kill: def $al killed $al killed $eax
|
||||
; SKX-NEXT: kmovd %k0, %eax
|
||||
; SKX-NEXT: testb $14, %al
|
||||
; SKX-NEXT: setne %al
|
||||
; SKX-NEXT: retq
|
||||
%cmp = fcmp ueq <4 x float> %x, %y
|
||||
%e1 = extractelement <4 x i1> %cmp, i32 1
|
||||
|
|
Loading…
Reference in New Issue