diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp index dc9bfa3296bf..0a0dc5475f56 100644 --- a/llvm/lib/Target/X86/X86ISelLowering.cpp +++ b/llvm/lib/Target/X86/X86ISelLowering.cpp @@ -30093,16 +30093,22 @@ static SDValue combineBitcast(SDNode *N, SelectionDAG &DAG, // the elements of a vector. // Returns the vector that is being reduced on, or SDValue() if a reduction // was not matched. -static SDValue matchBinOpReduction(SDNode *Extract, ISD::NodeType BinOp) { +static SDValue matchBinOpReduction(SDNode *Extract, unsigned &BinOp, + ArrayRef CandidateBinOps) { // The pattern must end in an extract from index 0. if ((Extract->getOpcode() != ISD::EXTRACT_VECTOR_ELT) || !isNullConstant(Extract->getOperand(1))) return SDValue(); - unsigned Stages = - Log2_32(Extract->getOperand(0).getValueType().getVectorNumElements()); - SDValue Op = Extract->getOperand(0); + unsigned Stages = Log2_32(Op.getValueType().getVectorNumElements()); + + // Match against one of the candidate binary ops. + if (llvm::none_of(CandidateBinOps, [Op](ISD::NodeType BinOp) { + return Op.getOpcode() == BinOp; + })) + return SDValue(); + // At each stage, we're looking for something that looks like: // %s = shufflevector <8 x i32> %op, <8 x i32> undef, // <8 x i32> // <2,3,u,u,u,u,u,u> // <1,u,u,u,u,u,u,u> + unsigned CandidateBinOp = Op.getOpcode(); for (unsigned i = 0; i < Stages; ++i) { - if (Op.getOpcode() != BinOp) + if (Op.getOpcode() != CandidateBinOp) return SDValue(); ShuffleVectorSDNode *Shuffle = @@ -30127,8 +30134,8 @@ static SDValue matchBinOpReduction(SDNode *Extract, ISD::NodeType BinOp) { } // The first operand of the shuffle should be the same as the other operand - // of the add. - if (!Shuffle || (Shuffle->getOperand(0) != Op)) + // of the binop. + if (!Shuffle || Shuffle->getOperand(0) != Op) return SDValue(); // Verify the shuffle has the expected (at this stage of the pyramid) mask. @@ -30137,6 +30144,7 @@ static SDValue matchBinOpReduction(SDNode *Extract, ISD::NodeType BinOp) { return SDValue(); } + BinOp = CandidateBinOp; return Op; } @@ -30250,66 +30258,63 @@ static SDValue combineHorizontalPredicateResult(SDNode *Extract, return SDValue(); // Check for OR(any_of) and AND(all_of) horizontal reduction patterns. - for (ISD::NodeType Op : {ISD::OR, ISD::AND}) { - SDValue Match = matchBinOpReduction(Extract, Op); - if (!Match) - continue; + unsigned BinOp = 0; + SDValue Match = matchBinOpReduction(Extract, BinOp, {ISD::OR, ISD::AND}); + if (!Match) + return SDValue(); - // EXTRACT_VECTOR_ELT can require implicit extension of the vector element - // which we can't support here for now. - if (Match.getScalarValueSizeInBits() != BitWidth) - continue; + // EXTRACT_VECTOR_ELT can require implicit extension of the vector element + // which we can't support here for now. + if (Match.getScalarValueSizeInBits() != BitWidth) + return SDValue(); - // We require AVX2 for PMOVMSKB for v16i16/v32i8; - unsigned MatchSizeInBits = Match.getValueSizeInBits(); - if (!(MatchSizeInBits == 128 || - (MatchSizeInBits == 256 && - ((Subtarget.hasAVX() && BitWidth >= 32) || Subtarget.hasAVX2())))) - return SDValue(); + // We require AVX2 for PMOVMSKB for v16i16/v32i8; + unsigned MatchSizeInBits = Match.getValueSizeInBits(); + if (!(MatchSizeInBits == 128 || + (MatchSizeInBits == 256 && + ((Subtarget.hasAVX() && BitWidth >= 32) || Subtarget.hasAVX2())))) + return SDValue(); - // Don't bother performing this for 2-element vectors. - if (Match.getValueType().getVectorNumElements() <= 2) - return SDValue(); + // Don't bother performing this for 2-element vectors. + if (Match.getValueType().getVectorNumElements() <= 2) + return SDValue(); - // Check that we are extracting a reduction of all sign bits. - if (DAG.ComputeNumSignBits(Match) != BitWidth) - return SDValue(); + // Check that we are extracting a reduction of all sign bits. + if (DAG.ComputeNumSignBits(Match) != BitWidth) + return SDValue(); - // For 32/64 bit comparisons use MOVMSKPS/MOVMSKPD, else PMOVMSKB. - MVT MaskVT; - if (64 == BitWidth || 32 == BitWidth) - MaskVT = MVT::getVectorVT(MVT::getFloatingPointVT(BitWidth), - MatchSizeInBits / BitWidth); - else - MaskVT = MVT::getVectorVT(MVT::i8, MatchSizeInBits / 8); + // For 32/64 bit comparisons use MOVMSKPS/MOVMSKPD, else PMOVMSKB. + MVT MaskVT; + if (64 == BitWidth || 32 == BitWidth) + MaskVT = MVT::getVectorVT(MVT::getFloatingPointVT(BitWidth), + MatchSizeInBits / BitWidth); + else + MaskVT = MVT::getVectorVT(MVT::i8, MatchSizeInBits / 8); - APInt CompareBits; - ISD::CondCode CondCode; - if (Op == ISD::OR) { - // any_of -> MOVMSK != 0 - CompareBits = APInt::getNullValue(32); - CondCode = ISD::CondCode::SETNE; - } else { - // all_of -> MOVMSK == ((1 << NumElts) - 1) - CompareBits = APInt::getLowBitsSet(32, MaskVT.getVectorNumElements()); - CondCode = ISD::CondCode::SETEQ; - } - - // Perform the select as i32/i64 and then truncate to avoid partial register - // stalls. - unsigned ResWidth = std::max(BitWidth, 32u); - EVT ResVT = EVT::getIntegerVT(*DAG.getContext(), ResWidth); - SDLoc DL(Extract); - SDValue Zero = DAG.getConstant(0, DL, ResVT); - SDValue Ones = DAG.getAllOnesConstant(DL, ResVT); - SDValue Res = DAG.getBitcast(MaskVT, Match); - Res = DAG.getNode(X86ISD::MOVMSK, DL, MVT::i32, Res); - Res = DAG.getSelectCC(DL, Res, DAG.getConstant(CompareBits, DL, MVT::i32), - Ones, Zero, CondCode); - return DAG.getSExtOrTrunc(Res, DL, ExtractVT); + APInt CompareBits; + ISD::CondCode CondCode; + if (BinOp == ISD::OR) { + // any_of -> MOVMSK != 0 + CompareBits = APInt::getNullValue(32); + CondCode = ISD::CondCode::SETNE; + } else { + // all_of -> MOVMSK == ((1 << NumElts) - 1) + CompareBits = APInt::getLowBitsSet(32, MaskVT.getVectorNumElements()); + CondCode = ISD::CondCode::SETEQ; } - return SDValue(); + // Perform the select as i32/i64 and then truncate to avoid partial register + // stalls. + unsigned ResWidth = std::max(BitWidth, 32u); + EVT ResVT = EVT::getIntegerVT(*DAG.getContext(), ResWidth); + SDLoc DL(Extract); + SDValue Zero = DAG.getConstant(0, DL, ResVT); + SDValue Ones = DAG.getAllOnesConstant(DL, ResVT); + SDValue Res = DAG.getBitcast(MaskVT, Match); + Res = DAG.getNode(X86ISD::MOVMSK, DL, MVT::i32, Res); + Res = DAG.getSelectCC(DL, Res, DAG.getConstant(CompareBits, DL, MVT::i32), + Ones, Zero, CondCode); + return DAG.getSExtOrTrunc(Res, DL, ExtractVT); } static SDValue combineBasicSADPattern(SDNode *Extract, SelectionDAG &DAG, @@ -30336,7 +30341,8 @@ static SDValue combineBasicSADPattern(SDNode *Extract, SelectionDAG &DAG, return SDValue(); // Match shuffle + add pyramid. - SDValue Root = matchBinOpReduction(Extract, ISD::ADD); + unsigned BinOp = 0; + SDValue Root = matchBinOpReduction(Extract, BinOp, {ISD::ADD}); // The operand is expected to be zero extended from i8 // (verified in detectZextAbsDiff).