[X86] Attempt to match multiple binary reduction ops at once. NFCI

matchBinOpReduction currently matches against a single opcode, but we already have a case where we call it repeatedly to try to match AND/OR, and I'll shortly be adding another case for SMAX/SMIN/UMAX/UMIN (D39729).

This NFCI patch alters matchBinOpReduction to try to pattern-match against any of the provided list of candidate binary ops at once, which saves time by avoiding the repeated calls.

Differential Revision: https://reviews.llvm.org/D39726

llvm-svn: 317985
This commit is contained in:
Simon Pilgrim 2017-11-11 18:16:55 +00:00
parent 1f3a2af902
commit 294b87b432
1 changed file with 66 additions and 60 deletions

View File

@ -30093,16 +30093,22 @@ static SDValue combineBitcast(SDNode *N, SelectionDAG &DAG,
// the elements of a vector. // the elements of a vector.
// Returns the vector that is being reduced on, or SDValue() if a reduction // Returns the vector that is being reduced on, or SDValue() if a reduction
// was not matched. // was not matched.
static SDValue matchBinOpReduction(SDNode *Extract, ISD::NodeType BinOp) { static SDValue matchBinOpReduction(SDNode *Extract, unsigned &BinOp,
ArrayRef<ISD::NodeType> CandidateBinOps) {
// The pattern must end in an extract from index 0. // The pattern must end in an extract from index 0.
if ((Extract->getOpcode() != ISD::EXTRACT_VECTOR_ELT) || if ((Extract->getOpcode() != ISD::EXTRACT_VECTOR_ELT) ||
!isNullConstant(Extract->getOperand(1))) !isNullConstant(Extract->getOperand(1)))
return SDValue(); return SDValue();
unsigned Stages =
Log2_32(Extract->getOperand(0).getValueType().getVectorNumElements());
SDValue Op = Extract->getOperand(0); SDValue Op = Extract->getOperand(0);
unsigned Stages = Log2_32(Op.getValueType().getVectorNumElements());
// Match against one of the candidate binary ops.
if (llvm::none_of(CandidateBinOps, [Op](ISD::NodeType BinOp) {
return Op.getOpcode() == BinOp;
}))
return SDValue();
// At each stage, we're looking for something that looks like: // At each stage, we're looking for something that looks like:
// %s = shufflevector <8 x i32> %op, <8 x i32> undef, // %s = shufflevector <8 x i32> %op, <8 x i32> undef,
// <8 x i32> <i32 2, i32 3, i32 undef, i32 undef, // <8 x i32> <i32 2, i32 3, i32 undef, i32 undef,
@ -30113,8 +30119,9 @@ static SDValue matchBinOpReduction(SDNode *Extract, ISD::NodeType BinOp) {
// <4,5,6,7,u,u,u,u> // <4,5,6,7,u,u,u,u>
// <2,3,u,u,u,u,u,u> // <2,3,u,u,u,u,u,u>
// <1,u,u,u,u,u,u,u> // <1,u,u,u,u,u,u,u>
unsigned CandidateBinOp = Op.getOpcode();
for (unsigned i = 0; i < Stages; ++i) { for (unsigned i = 0; i < Stages; ++i) {
if (Op.getOpcode() != BinOp) if (Op.getOpcode() != CandidateBinOp)
return SDValue(); return SDValue();
ShuffleVectorSDNode *Shuffle = ShuffleVectorSDNode *Shuffle =
@ -30127,8 +30134,8 @@ static SDValue matchBinOpReduction(SDNode *Extract, ISD::NodeType BinOp) {
} }
// The first operand of the shuffle should be the same as the other operand // The first operand of the shuffle should be the same as the other operand
// of the add. // of the binop.
if (!Shuffle || (Shuffle->getOperand(0) != Op)) if (!Shuffle || Shuffle->getOperand(0) != Op)
return SDValue(); return SDValue();
// Verify the shuffle has the expected (at this stage of the pyramid) mask. // Verify the shuffle has the expected (at this stage of the pyramid) mask.
@ -30137,6 +30144,7 @@ static SDValue matchBinOpReduction(SDNode *Extract, ISD::NodeType BinOp) {
return SDValue(); return SDValue();
} }
BinOp = CandidateBinOp;
return Op; return Op;
} }
@ -30250,15 +30258,15 @@ static SDValue combineHorizontalPredicateResult(SDNode *Extract,
return SDValue(); return SDValue();
// Check for OR(any_of) and AND(all_of) horizontal reduction patterns. // Check for OR(any_of) and AND(all_of) horizontal reduction patterns.
for (ISD::NodeType Op : {ISD::OR, ISD::AND}) { unsigned BinOp = 0;
SDValue Match = matchBinOpReduction(Extract, Op); SDValue Match = matchBinOpReduction(Extract, BinOp, {ISD::OR, ISD::AND});
if (!Match) if (!Match)
continue; return SDValue();
// EXTRACT_VECTOR_ELT can require implicit extension of the vector element // EXTRACT_VECTOR_ELT can require implicit extension of the vector element
// which we can't support here for now. // which we can't support here for now.
if (Match.getScalarValueSizeInBits() != BitWidth) if (Match.getScalarValueSizeInBits() != BitWidth)
continue; return SDValue();
// We require AVX2 for PMOVMSKB for v16i16/v32i8; // We require AVX2 for PMOVMSKB for v16i16/v32i8;
unsigned MatchSizeInBits = Match.getValueSizeInBits(); unsigned MatchSizeInBits = Match.getValueSizeInBits();
@ -30285,7 +30293,7 @@ static SDValue combineHorizontalPredicateResult(SDNode *Extract,
APInt CompareBits; APInt CompareBits;
ISD::CondCode CondCode; ISD::CondCode CondCode;
if (Op == ISD::OR) { if (BinOp == ISD::OR) {
// any_of -> MOVMSK != 0 // any_of -> MOVMSK != 0
CompareBits = APInt::getNullValue(32); CompareBits = APInt::getNullValue(32);
CondCode = ISD::CondCode::SETNE; CondCode = ISD::CondCode::SETNE;
@ -30307,9 +30315,6 @@ static SDValue combineHorizontalPredicateResult(SDNode *Extract,
Res = DAG.getSelectCC(DL, Res, DAG.getConstant(CompareBits, DL, MVT::i32), Res = DAG.getSelectCC(DL, Res, DAG.getConstant(CompareBits, DL, MVT::i32),
Ones, Zero, CondCode); Ones, Zero, CondCode);
return DAG.getSExtOrTrunc(Res, DL, ExtractVT); return DAG.getSExtOrTrunc(Res, DL, ExtractVT);
}
return SDValue();
} }
static SDValue combineBasicSADPattern(SDNode *Extract, SelectionDAG &DAG, static SDValue combineBasicSADPattern(SDNode *Extract, SelectionDAG &DAG,
@ -30336,7 +30341,8 @@ static SDValue combineBasicSADPattern(SDNode *Extract, SelectionDAG &DAG,
return SDValue(); return SDValue();
// Match shuffle + add pyramid. // Match shuffle + add pyramid.
SDValue Root = matchBinOpReduction(Extract, ISD::ADD); unsigned BinOp = 0;
SDValue Root = matchBinOpReduction(Extract, BinOp, {ISD::ADD});
// The operand is expected to be zero extended from i8 // The operand is expected to be zero extended from i8
// (verified in detectZextAbsDiff). // (verified in detectZextAbsDiff).