diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 339e3c37ee25..8e8a7cce9fb1 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -20964,9 +20964,12 @@ static SDValue getSETCC(X86::CondCode Cond, SDValue EFLAGS, const SDLoc &dl,
 }
 
 /// Helper for matching OR(EXTRACTELT(X,0),OR(EXTRACTELT(X,1),...))
-/// style scalarized (associative) reduction patterns.
+/// style scalarized (associative) reduction patterns. Partial reductions
+/// are supported when the pointer SrcMask is non-null.
+/// TODO - move this to SelectionDAG?
 static bool matchScalarReduction(SDValue Op, ISD::NodeType BinOp,
-                                 SmallVectorImpl<SDValue> &SrcOps) {
+                                 SmallVectorImpl<SDValue> &SrcOps,
+                                 SmallVectorImpl<APInt> *SrcMask = nullptr) {
   SmallVector<SDValue, 8> Opnds;
   DenseMap<SDValue, APInt> SrcOpMap;
   EVT VT = MVT::Other;
@@ -21018,12 +21021,18 @@ static bool matchScalarReduction(SDValue Op, ISD::NodeType BinOp,
     M->second.setBit(CIdx);
   }
 
-  // Quit if not all elements are used.
-  for (DenseMap<SDValue, APInt>::const_iterator I = SrcOpMap.begin(),
-                                                E = SrcOpMap.end();
-       I != E; ++I) {
-    if (!I->second.isAllOnesValue())
-      return false;
+  if (SrcMask) {
+    // Collect the source partial masks.
+    for (SDValue &SrcOp : SrcOps)
+      SrcMask->push_back(SrcOpMap[SrcOp]);
+  } else {
+    // Quit if not all elements are used.
+    for (DenseMap<SDValue, APInt>::const_iterator I = SrcOpMap.begin(),
+                                                  E = SrcOpMap.end();
+         I != E; ++I) {
+      if (!I->second.isAllOnesValue())
+        return false;
+    }
   }
 
   return true;
@@ -41210,7 +41219,8 @@ static SDValue combineAnd(SDNode *N, SelectionDAG &DAG,
   // TODO: Support multiple SrcOps.
   if (VT == MVT::i1) {
     SmallVector<SDValue, 2> SrcOps;
-    if (matchScalarReduction(SDValue(N, 0), ISD::AND, SrcOps) &&
+    SmallVector<APInt, 2> SrcPartials;
+    if (matchScalarReduction(SDValue(N, 0), ISD::AND, SrcOps, &SrcPartials) &&
         SrcOps.size() == 1) {
       SDLoc dl(N);
       const TargetLowering &TLI = DAG.getTargetLoweringInfo();
@@ -41220,9 +41230,11 @@ static SDValue combineAnd(SDNode *N, SelectionDAG &DAG,
       if (!Mask && TLI.isTypeLegal(SrcOps[0].getValueType()))
         Mask = DAG.getBitcast(MaskVT, SrcOps[0]);
       if (Mask) {
-        APInt AllBits = APInt::getAllOnesValue(NumElts);
-        return DAG.getSetCC(dl, MVT::i1, Mask,
-                            DAG.getConstant(AllBits, dl, MaskVT), ISD::SETEQ);
+        assert(SrcPartials[0].getBitWidth() == NumElts &&
+               "Unexpected partial reduction mask");
+        SDValue PartialBits = DAG.getConstant(SrcPartials[0], dl, MaskVT);
+        Mask = DAG.getNode(ISD::AND, dl, MaskVT, Mask, PartialBits);
+        return DAG.getSetCC(dl, MVT::i1, Mask, PartialBits, ISD::SETEQ);
       }
     }
   }
@@ -41685,7 +41697,8 @@ static SDValue combineOr(SDNode *N, SelectionDAG &DAG,
   // TODO: Support multiple SrcOps.
   if (VT == MVT::i1) {
     SmallVector<SDValue, 2> SrcOps;
-    if (matchScalarReduction(SDValue(N, 0), ISD::OR, SrcOps) &&
+    SmallVector<APInt, 2> SrcPartials;
+    if (matchScalarReduction(SDValue(N, 0), ISD::OR, SrcOps, &SrcPartials) &&
         SrcOps.size() == 1) {
       SDLoc dl(N);
       const TargetLowering &TLI = DAG.getTargetLoweringInfo();
@@ -41695,9 +41708,12 @@ static SDValue combineOr(SDNode *N, SelectionDAG &DAG,
       if (!Mask && TLI.isTypeLegal(SrcOps[0].getValueType()))
         Mask = DAG.getBitcast(MaskVT, SrcOps[0]);
       if (Mask) {
-        APInt AllBits = APInt::getNullValue(NumElts);
-        return DAG.getSetCC(dl, MVT::i1, Mask,
-                            DAG.getConstant(AllBits, dl, MaskVT), ISD::SETNE);
+        assert(SrcPartials[0].getBitWidth() == NumElts &&
+               "Unexpected partial reduction mask");
+        SDValue ZeroBits = DAG.getConstant(0, dl, MaskVT);
+        SDValue PartialBits = DAG.getConstant(SrcPartials[0], dl, MaskVT);
+        Mask = DAG.getNode(ISD::AND, dl, MaskVT, Mask, PartialBits);
+        return DAG.getSetCC(dl, MVT::i1, Mask, ZeroBits, ISD::SETNE);
       }
     }
   }
diff --git a/llvm/test/CodeGen/X86/movmsk-cmp.ll b/llvm/test/CodeGen/X86/movmsk-cmp.ll
index 7f0a1418a719..4fdde8c06641 100644
--- a/llvm/test/CodeGen/X86/movmsk-cmp.ll
+++ b/llvm/test/CodeGen/X86/movmsk-cmp.ll
@@ -4225,40 +4225,25 @@ define i1 @movmsk_v16i8(<16 x i8> %x, <16 x i8> %y) {
   ret i1 %u2
 }
 
-; TODO: Replace shift+mask chain with NOT+TEST+SETE
 define i1 @movmsk_v8i16(<8 x i16> %x, <8 x i16> %y) {
 ; SSE2-LABEL: movmsk_v8i16:
 ; SSE2:       # %bb.0:
 ; SSE2-NEXT:    pcmpgtw %xmm1, %xmm0
 ; SSE2-NEXT:    packsswb %xmm0, %xmm0
-; SSE2-NEXT:    pmovmskb %xmm0, %ecx
-; SSE2-NEXT:    movl %ecx, %eax
-; SSE2-NEXT:    shrb $7, %al
-; SSE2-NEXT:    movl %ecx, %edx
-; SSE2-NEXT:    andb $16, %dl
-; SSE2-NEXT:    shrb $4, %dl
-; SSE2-NEXT:    andb %al, %dl
-; SSE2-NEXT:    movl %ecx, %eax
-; SSE2-NEXT:    shrb %al
-; SSE2-NEXT:    andb %dl, %al
-; SSE2-NEXT:    andb %cl, %al
+; SSE2-NEXT:    pmovmskb %xmm0, %eax
+; SSE2-NEXT:    andb $-109, %al
+; SSE2-NEXT:    cmpb $-109, %al
+; SSE2-NEXT:    sete %al
 ; SSE2-NEXT:    retq
 ;
 ; AVX-LABEL: movmsk_v8i16:
 ; AVX:       # %bb.0:
 ; AVX-NEXT:    vpcmpgtw %xmm1, %xmm0, %xmm0
 ; AVX-NEXT:    vpacksswb %xmm0, %xmm0, %xmm0
-; AVX-NEXT:    vpmovmskb %xmm0, %ecx
-; AVX-NEXT:    movl %ecx, %eax
-; AVX-NEXT:    shrb $7, %al
-; AVX-NEXT:    movl %ecx, %edx
-; AVX-NEXT:    andb $16, %dl
-; AVX-NEXT:    shrb $4, %dl
-; AVX-NEXT:    andb %al, %dl
-; AVX-NEXT:    movl %ecx, %eax
-; AVX-NEXT:    shrb %al
-; AVX-NEXT:    andb %dl, %al
-; AVX-NEXT:    andb %cl, %al
+; AVX-NEXT:    vpmovmskb %xmm0, %eax
+; AVX-NEXT:    andb $-109, %al
+; AVX-NEXT:    cmpb $-109, %al
+; AVX-NEXT:    sete %al
 ; AVX-NEXT:    retq
 ;
 ; KNL-LABEL: movmsk_v8i16:
@@ -4266,34 +4251,20 @@ define i1 @movmsk_v8i16(<8 x i16> %x, <8 x i16> %y) {
 ; KNL-NEXT:    vpcmpgtw %xmm1, %xmm0, %xmm0
 ; KNL-NEXT:    vpmovsxwq %xmm0, %zmm0
 ; KNL-NEXT:    vptestmq %zmm0, %zmm0, %k0
-; KNL-NEXT:    kshiftrw $4, %k0, %k1
-; KNL-NEXT:    kmovw %k1, %ecx
-; KNL-NEXT:    kshiftrw $7, %k0, %k1
-; KNL-NEXT:    kmovw %k1, %eax
-; KNL-NEXT:    kshiftrw $1, %k0, %k1
-; KNL-NEXT:    kmovw %k1, %edx
-; KNL-NEXT:    kmovw %k0, %esi
-; KNL-NEXT:    andb %cl, %al
-; KNL-NEXT:    andb %dl, %al
-; KNL-NEXT:    andb %sil, %al
-; KNL-NEXT:    # kill: def $al killed $al killed $eax
+; KNL-NEXT:    kmovw %k0, %eax
+; KNL-NEXT:    andb $-109, %al
+; KNL-NEXT:    cmpb $-109, %al
+; KNL-NEXT:    sete %al
 ; KNL-NEXT:    vzeroupper
 ; KNL-NEXT:    retq
 ;
 ; SKX-LABEL: movmsk_v8i16:
 ; SKX:       # %bb.0:
 ; SKX-NEXT:    vpcmpgtw %xmm1, %xmm0, %k0
-; SKX-NEXT:    kshiftrb $4, %k0, %k1
-; SKX-NEXT:    kmovd %k1, %ecx
-; SKX-NEXT:    kshiftrb $7, %k0, %k1
-; SKX-NEXT:    kmovd %k1, %eax
-; SKX-NEXT:    kshiftrb $1, %k0, %k1
-; SKX-NEXT:    kmovd %k1, %edx
-; SKX-NEXT:    kmovd %k0, %esi
-; SKX-NEXT:    andb %cl, %al
-; SKX-NEXT:    andb %dl, %al
-; SKX-NEXT:    andb %sil, %al
-; SKX-NEXT:    # kill: def $al killed $al killed $eax
+; SKX-NEXT:    kmovd %k0, %eax
+; SKX-NEXT:    andb $-109, %al
+; SKX-NEXT:    cmpb $-109, %al
+; SKX-NEXT:    sete %al
 ; SKX-NEXT:    retq
   %cmp = icmp sgt <8 x i16> %x, %y
   %e1 = extractelement <8 x i1> %cmp, i32 0
@@ -4478,30 +4449,18 @@ define i1 @movmsk_v4f32(<4 x float> %x, <4 x float> %y) {
 ; KNL-NEXT:    # kill: def $xmm1 killed $xmm1 def $zmm1
 ; KNL-NEXT:    # kill: def $xmm0 killed $xmm0 def $zmm0
 ; KNL-NEXT:    vcmpeq_uqps %zmm1, %zmm0, %k0
-; KNL-NEXT:    kshiftrw $3, %k0, %k1
-; KNL-NEXT:    kmovw %k1, %ecx
-; KNL-NEXT:    kshiftrw $2, %k0, %k1
-; KNL-NEXT:    kmovw %k1, %eax
-; KNL-NEXT:    kshiftrw $1, %k0, %k0
-; KNL-NEXT:    kmovw %k0, %edx
-; KNL-NEXT:    orb %cl, %al
-; KNL-NEXT:    orb %dl, %al
-; KNL-NEXT:    # kill: def $al killed $al killed $eax
+; KNL-NEXT:    kmovw %k0, %eax
+; KNL-NEXT:    testb $14, %al
+; KNL-NEXT:    setne %al
 ; KNL-NEXT:    vzeroupper
 ; KNL-NEXT:    retq
 ;
 ; SKX-LABEL: movmsk_v4f32:
 ; SKX:       # %bb.0:
 ; SKX-NEXT:    vcmpeq_uqps %xmm1, %xmm0, %k0
-; SKX-NEXT:    kshiftrb $3, %k0, %k1
-; SKX-NEXT:    kmovd %k1, %ecx
-; SKX-NEXT:    kshiftrb $2, %k0, %k1
-; SKX-NEXT:    kmovd %k1, %eax
-; SKX-NEXT:    kshiftrb $1, %k0, %k0
-; SKX-NEXT:    kmovd %k0, %edx
-; SKX-NEXT:    orb %cl, %al
-; SKX-NEXT:    orb %dl, %al
-; SKX-NEXT:    # kill: def $al killed $al killed $eax
+; SKX-NEXT:    kmovd %k0, %eax
+; SKX-NEXT:    testb $14, %al
+; SKX-NEXT:    setne %al
 ; SKX-NEXT:    retq
   %cmp = fcmp ueq <4 x float> %x, %y
   %e1 = extractelement <4 x i1> %cmp, i32 1
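
Commentary (not part of the patch): for a partial reduction, SrcPartials[0]
records one bit per vector lane that actually feeds the AND/OR chain, and the
combines above emit the scalar tests (movmsk & partial) == partial for AND
and (movmsk & partial) != 0 for OR. A minimal standalone C++ sketch of that
arithmetic follows; the helper names are invented for illustration, and the
lane sets are the ones implied by the updated CHECK constants:

  // Sketch only: plain C++, no LLVM types; names are illustrative.
  #include <cassert>
  #include <cstdint>
  #include <initializer_list>

  // One bit per extracted lane, i.e. what matchScalarReduction would
  // record in SrcPartials[0] for a partial reduction.
  static uint32_t partialMask(std::initializer_list<unsigned> Lanes) {
    uint32_t Mask = 0;
    for (unsigned L : Lanes)
      Mask |= 1u << L;
    return Mask;
  }

  // AND-reduction of the selected lanes: lowered as AND + CMP + SETE.
  static bool allOfLanes(uint32_t Movmsk, uint32_t Mask) {
    return (Movmsk & Mask) == Mask;
  }

  // OR-reduction of the selected lanes: (movmsk & mask) != 0, which
  // folds to a single TEST + SETNE.
  static bool anyOfLanes(uint32_t Movmsk, uint32_t Mask) {
    return (Movmsk & Mask) != 0;
  }

  int main() {
    // movmsk_v8i16 reduces lanes 0, 1, 4 and 7: 0b10010011 = 0x93,
    // printed as the signed 8-bit immediate $-109 in AT&T syntax.
    uint32_t V8I16Mask = partialMask({0, 1, 4, 7});
    assert(V8I16Mask == 0x93);
    assert(allOfLanes(0xFF, V8I16Mask));
    assert(!allOfLanes(0x83, V8I16Mask)); // lane 4 clear -> false
    // movmsk_v4f32 reduces lanes 1, 2 and 3: 0b1110 = 14, matching
    // the testb $14 in the updated output.
    assert(anyOfLanes(0x04, partialMask({1, 2, 3})));
    return 0;
  }

The andb/cmpb pair in the AND tests survives because the mask is partial; a
full reduction would still fold to the plain all-ones compare from the old
code path.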