forked from OSchip/llvm-project
[X86] Attempt to match multiple binary reduction ops at once. NFCI
matchBinOpReduction currently matches against a single opcode, but we already have a case where we repeat calls to try to match against AND/OR and I'll be shortly adding another case for SMAX/SMIN/UMAX/UMIN (D39729). This NFCI patch alters matchBinOpReduction to try and pattern match against any of the provided list of candidate bin ops at once to save time. Differential Revision: https://reviews.llvm.org/D39726 llvm-svn: 317985
This commit is contained in:
parent
1f3a2af902
commit
294b87b432
|
@ -30093,16 +30093,22 @@ static SDValue combineBitcast(SDNode *N, SelectionDAG &DAG,
|
||||||
// the elements of a vector.
|
// the elements of a vector.
|
||||||
// Returns the vector that is being reduced on, or SDValue() if a reduction
|
// Returns the vector that is being reduced on, or SDValue() if a reduction
|
||||||
// was not matched.
|
// was not matched.
|
||||||
static SDValue matchBinOpReduction(SDNode *Extract, ISD::NodeType BinOp) {
|
static SDValue matchBinOpReduction(SDNode *Extract, unsigned &BinOp,
|
||||||
|
ArrayRef<ISD::NodeType> CandidateBinOps) {
|
||||||
// The pattern must end in an extract from index 0.
|
// The pattern must end in an extract from index 0.
|
||||||
if ((Extract->getOpcode() != ISD::EXTRACT_VECTOR_ELT) ||
|
if ((Extract->getOpcode() != ISD::EXTRACT_VECTOR_ELT) ||
|
||||||
!isNullConstant(Extract->getOperand(1)))
|
!isNullConstant(Extract->getOperand(1)))
|
||||||
return SDValue();
|
return SDValue();
|
||||||
|
|
||||||
unsigned Stages =
|
|
||||||
Log2_32(Extract->getOperand(0).getValueType().getVectorNumElements());
|
|
||||||
|
|
||||||
SDValue Op = Extract->getOperand(0);
|
SDValue Op = Extract->getOperand(0);
|
||||||
|
unsigned Stages = Log2_32(Op.getValueType().getVectorNumElements());
|
||||||
|
|
||||||
|
// Match against one of the candidate binary ops.
|
||||||
|
if (llvm::none_of(CandidateBinOps, [Op](ISD::NodeType BinOp) {
|
||||||
|
return Op.getOpcode() == BinOp;
|
||||||
|
}))
|
||||||
|
return SDValue();
|
||||||
|
|
||||||
// At each stage, we're looking for something that looks like:
|
// At each stage, we're looking for something that looks like:
|
||||||
// %s = shufflevector <8 x i32> %op, <8 x i32> undef,
|
// %s = shufflevector <8 x i32> %op, <8 x i32> undef,
|
||||||
// <8 x i32> <i32 2, i32 3, i32 undef, i32 undef,
|
// <8 x i32> <i32 2, i32 3, i32 undef, i32 undef,
|
||||||
|
@ -30113,8 +30119,9 @@ static SDValue matchBinOpReduction(SDNode *Extract, ISD::NodeType BinOp) {
|
||||||
// <4,5,6,7,u,u,u,u>
|
// <4,5,6,7,u,u,u,u>
|
||||||
// <2,3,u,u,u,u,u,u>
|
// <2,3,u,u,u,u,u,u>
|
||||||
// <1,u,u,u,u,u,u,u>
|
// <1,u,u,u,u,u,u,u>
|
||||||
|
unsigned CandidateBinOp = Op.getOpcode();
|
||||||
for (unsigned i = 0; i < Stages; ++i) {
|
for (unsigned i = 0; i < Stages; ++i) {
|
||||||
if (Op.getOpcode() != BinOp)
|
if (Op.getOpcode() != CandidateBinOp)
|
||||||
return SDValue();
|
return SDValue();
|
||||||
|
|
||||||
ShuffleVectorSDNode *Shuffle =
|
ShuffleVectorSDNode *Shuffle =
|
||||||
|
@ -30127,8 +30134,8 @@ static SDValue matchBinOpReduction(SDNode *Extract, ISD::NodeType BinOp) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// The first operand of the shuffle should be the same as the other operand
|
// The first operand of the shuffle should be the same as the other operand
|
||||||
// of the add.
|
// of the binop.
|
||||||
if (!Shuffle || (Shuffle->getOperand(0) != Op))
|
if (!Shuffle || Shuffle->getOperand(0) != Op)
|
||||||
return SDValue();
|
return SDValue();
|
||||||
|
|
||||||
// Verify the shuffle has the expected (at this stage of the pyramid) mask.
|
// Verify the shuffle has the expected (at this stage of the pyramid) mask.
|
||||||
|
@ -30137,6 +30144,7 @@ static SDValue matchBinOpReduction(SDNode *Extract, ISD::NodeType BinOp) {
|
||||||
return SDValue();
|
return SDValue();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
BinOp = CandidateBinOp;
|
||||||
return Op;
|
return Op;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -30250,15 +30258,15 @@ static SDValue combineHorizontalPredicateResult(SDNode *Extract,
|
||||||
return SDValue();
|
return SDValue();
|
||||||
|
|
||||||
// Check for OR(any_of) and AND(all_of) horizontal reduction patterns.
|
// Check for OR(any_of) and AND(all_of) horizontal reduction patterns.
|
||||||
for (ISD::NodeType Op : {ISD::OR, ISD::AND}) {
|
unsigned BinOp = 0;
|
||||||
SDValue Match = matchBinOpReduction(Extract, Op);
|
SDValue Match = matchBinOpReduction(Extract, BinOp, {ISD::OR, ISD::AND});
|
||||||
if (!Match)
|
if (!Match)
|
||||||
continue;
|
return SDValue();
|
||||||
|
|
||||||
// EXTRACT_VECTOR_ELT can require implicit extension of the vector element
|
// EXTRACT_VECTOR_ELT can require implicit extension of the vector element
|
||||||
// which we can't support here for now.
|
// which we can't support here for now.
|
||||||
if (Match.getScalarValueSizeInBits() != BitWidth)
|
if (Match.getScalarValueSizeInBits() != BitWidth)
|
||||||
continue;
|
return SDValue();
|
||||||
|
|
||||||
// We require AVX2 for PMOVMSKB for v16i16/v32i8;
|
// We require AVX2 for PMOVMSKB for v16i16/v32i8;
|
||||||
unsigned MatchSizeInBits = Match.getValueSizeInBits();
|
unsigned MatchSizeInBits = Match.getValueSizeInBits();
|
||||||
|
@ -30285,7 +30293,7 @@ static SDValue combineHorizontalPredicateResult(SDNode *Extract,
|
||||||
|
|
||||||
APInt CompareBits;
|
APInt CompareBits;
|
||||||
ISD::CondCode CondCode;
|
ISD::CondCode CondCode;
|
||||||
if (Op == ISD::OR) {
|
if (BinOp == ISD::OR) {
|
||||||
// any_of -> MOVMSK != 0
|
// any_of -> MOVMSK != 0
|
||||||
CompareBits = APInt::getNullValue(32);
|
CompareBits = APInt::getNullValue(32);
|
||||||
CondCode = ISD::CondCode::SETNE;
|
CondCode = ISD::CondCode::SETNE;
|
||||||
|
@ -30307,9 +30315,6 @@ static SDValue combineHorizontalPredicateResult(SDNode *Extract,
|
||||||
Res = DAG.getSelectCC(DL, Res, DAG.getConstant(CompareBits, DL, MVT::i32),
|
Res = DAG.getSelectCC(DL, Res, DAG.getConstant(CompareBits, DL, MVT::i32),
|
||||||
Ones, Zero, CondCode);
|
Ones, Zero, CondCode);
|
||||||
return DAG.getSExtOrTrunc(Res, DL, ExtractVT);
|
return DAG.getSExtOrTrunc(Res, DL, ExtractVT);
|
||||||
}
|
|
||||||
|
|
||||||
return SDValue();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
static SDValue combineBasicSADPattern(SDNode *Extract, SelectionDAG &DAG,
|
static SDValue combineBasicSADPattern(SDNode *Extract, SelectionDAG &DAG,
|
||||||
|
@ -30336,7 +30341,8 @@ static SDValue combineBasicSADPattern(SDNode *Extract, SelectionDAG &DAG,
|
||||||
return SDValue();
|
return SDValue();
|
||||||
|
|
||||||
// Match shuffle + add pyramid.
|
// Match shuffle + add pyramid.
|
||||||
SDValue Root = matchBinOpReduction(Extract, ISD::ADD);
|
unsigned BinOp = 0;
|
||||||
|
SDValue Root = matchBinOpReduction(Extract, BinOp, {ISD::ADD});
|
||||||
|
|
||||||
// The operand is expected to be zero extended from i8
|
// The operand is expected to be zero extended from i8
|
||||||
// (verified in detectZextAbsDiff).
|
// (verified in detectZextAbsDiff).
|
||||||
|
|
Loading…
Reference in New Issue