forked from OSchip/llvm-project
[X86] Attempt to match multiple binary reduction ops at once. NFCI
matchBinOpReduction currently matches against a single opcode, but we already have a case where we repeat calls to try to match against AND/OR and I'll be shortly adding another case for SMAX/SMIN/UMAX/UMIN (D39729). This NFCI patch alters matchBinOpReduction to try and pattern match against any of the provided list of candidate bin ops at once to save time. Differential Revision: https://reviews.llvm.org/D39726 llvm-svn: 317985
This commit is contained in:
parent
1f3a2af902
commit
294b87b432
|
@ -30093,16 +30093,22 @@ static SDValue combineBitcast(SDNode *N, SelectionDAG &DAG,
|
|||
// the elements of a vector.
|
||||
// Returns the vector that is being reduced on, or SDValue() if a reduction
|
||||
// was not matched.
|
||||
static SDValue matchBinOpReduction(SDNode *Extract, ISD::NodeType BinOp) {
|
||||
static SDValue matchBinOpReduction(SDNode *Extract, unsigned &BinOp,
|
||||
ArrayRef<ISD::NodeType> CandidateBinOps) {
|
||||
// The pattern must end in an extract from index 0.
|
||||
if ((Extract->getOpcode() != ISD::EXTRACT_VECTOR_ELT) ||
|
||||
!isNullConstant(Extract->getOperand(1)))
|
||||
return SDValue();
|
||||
|
||||
unsigned Stages =
|
||||
Log2_32(Extract->getOperand(0).getValueType().getVectorNumElements());
|
||||
|
||||
SDValue Op = Extract->getOperand(0);
|
||||
unsigned Stages = Log2_32(Op.getValueType().getVectorNumElements());
|
||||
|
||||
// Match against one of the candidate binary ops.
|
||||
if (llvm::none_of(CandidateBinOps, [Op](ISD::NodeType BinOp) {
|
||||
return Op.getOpcode() == BinOp;
|
||||
}))
|
||||
return SDValue();
|
||||
|
||||
// At each stage, we're looking for something that looks like:
|
||||
// %s = shufflevector <8 x i32> %op, <8 x i32> undef,
|
||||
// <8 x i32> <i32 2, i32 3, i32 undef, i32 undef,
|
||||
|
@ -30113,8 +30119,9 @@ static SDValue matchBinOpReduction(SDNode *Extract, ISD::NodeType BinOp) {
|
|||
// <4,5,6,7,u,u,u,u>
|
||||
// <2,3,u,u,u,u,u,u>
|
||||
// <1,u,u,u,u,u,u,u>
|
||||
unsigned CandidateBinOp = Op.getOpcode();
|
||||
for (unsigned i = 0; i < Stages; ++i) {
|
||||
if (Op.getOpcode() != BinOp)
|
||||
if (Op.getOpcode() != CandidateBinOp)
|
||||
return SDValue();
|
||||
|
||||
ShuffleVectorSDNode *Shuffle =
|
||||
|
@ -30127,8 +30134,8 @@ static SDValue matchBinOpReduction(SDNode *Extract, ISD::NodeType BinOp) {
|
|||
}
|
||||
|
||||
// The first operand of the shuffle should be the same as the other operand
|
||||
// of the add.
|
||||
if (!Shuffle || (Shuffle->getOperand(0) != Op))
|
||||
// of the binop.
|
||||
if (!Shuffle || Shuffle->getOperand(0) != Op)
|
||||
return SDValue();
|
||||
|
||||
// Verify the shuffle has the expected (at this stage of the pyramid) mask.
|
||||
|
@ -30137,6 +30144,7 @@ static SDValue matchBinOpReduction(SDNode *Extract, ISD::NodeType BinOp) {
|
|||
return SDValue();
|
||||
}
|
||||
|
||||
BinOp = CandidateBinOp;
|
||||
return Op;
|
||||
}
|
||||
|
||||
|
@ -30250,15 +30258,15 @@ static SDValue combineHorizontalPredicateResult(SDNode *Extract,
|
|||
return SDValue();
|
||||
|
||||
// Check for OR(any_of) and AND(all_of) horizontal reduction patterns.
|
||||
for (ISD::NodeType Op : {ISD::OR, ISD::AND}) {
|
||||
SDValue Match = matchBinOpReduction(Extract, Op);
|
||||
unsigned BinOp = 0;
|
||||
SDValue Match = matchBinOpReduction(Extract, BinOp, {ISD::OR, ISD::AND});
|
||||
if (!Match)
|
||||
continue;
|
||||
return SDValue();
|
||||
|
||||
// EXTRACT_VECTOR_ELT can require implicit extension of the vector element
|
||||
// which we can't support here for now.
|
||||
if (Match.getScalarValueSizeInBits() != BitWidth)
|
||||
continue;
|
||||
return SDValue();
|
||||
|
||||
// We require AVX2 for PMOVMSKB for v16i16/v32i8;
|
||||
unsigned MatchSizeInBits = Match.getValueSizeInBits();
|
||||
|
@ -30285,7 +30293,7 @@ static SDValue combineHorizontalPredicateResult(SDNode *Extract,
|
|||
|
||||
APInt CompareBits;
|
||||
ISD::CondCode CondCode;
|
||||
if (Op == ISD::OR) {
|
||||
if (BinOp == ISD::OR) {
|
||||
// any_of -> MOVMSK != 0
|
||||
CompareBits = APInt::getNullValue(32);
|
||||
CondCode = ISD::CondCode::SETNE;
|
||||
|
@ -30307,9 +30315,6 @@ static SDValue combineHorizontalPredicateResult(SDNode *Extract,
|
|||
Res = DAG.getSelectCC(DL, Res, DAG.getConstant(CompareBits, DL, MVT::i32),
|
||||
Ones, Zero, CondCode);
|
||||
return DAG.getSExtOrTrunc(Res, DL, ExtractVT);
|
||||
}
|
||||
|
||||
return SDValue();
|
||||
}
|
||||
|
||||
static SDValue combineBasicSADPattern(SDNode *Extract, SelectionDAG &DAG,
|
||||
|
@ -30336,7 +30341,8 @@ static SDValue combineBasicSADPattern(SDNode *Extract, SelectionDAG &DAG,
|
|||
return SDValue();
|
||||
|
||||
// Match shuffle + add pyramid.
|
||||
SDValue Root = matchBinOpReduction(Extract, ISD::ADD);
|
||||
unsigned BinOp = 0;
|
||||
SDValue Root = matchBinOpReduction(Extract, BinOp, {ISD::ADD});
|
||||
|
||||
// The operand is expected to be zero extended from i8
|
||||
// (verified in detectZextAbsDiff).
|
||||
|
|
Loading…
Reference in New Issue