forked from OSchip/llvm-project
[X86] Attempt to match multiple binary reduction ops at once. NFCI
matchBinOpReduction currently matches against a single opcode, but we already have a case where we repeat calls to try to match against AND/OR, and I'll shortly be adding another case for SMAX/SMIN/UMAX/UMIN (D39729). This NFCI patch alters matchBinOpReduction to try to pattern-match against any of the provided list of candidate bin ops at once, to save time. Differential Revision: https://reviews.llvm.org/D39726 llvm-svn: 317985
This commit is contained in:
parent
1f3a2af902
commit
294b87b432
|
@ -30093,16 +30093,22 @@ static SDValue combineBitcast(SDNode *N, SelectionDAG &DAG,
|
|||
// the elements of a vector.
|
||||
// Returns the vector that is being reduced on, or SDValue() if a reduction
|
||||
// was not matched.
|
||||
static SDValue matchBinOpReduction(SDNode *Extract, ISD::NodeType BinOp) {
|
||||
static SDValue matchBinOpReduction(SDNode *Extract, unsigned &BinOp,
|
||||
ArrayRef<ISD::NodeType> CandidateBinOps) {
|
||||
// The pattern must end in an extract from index 0.
|
||||
if ((Extract->getOpcode() != ISD::EXTRACT_VECTOR_ELT) ||
|
||||
!isNullConstant(Extract->getOperand(1)))
|
||||
return SDValue();
|
||||
|
||||
unsigned Stages =
|
||||
Log2_32(Extract->getOperand(0).getValueType().getVectorNumElements());
|
||||
|
||||
SDValue Op = Extract->getOperand(0);
|
||||
unsigned Stages = Log2_32(Op.getValueType().getVectorNumElements());
|
||||
|
||||
// Match against one of the candidate binary ops.
|
||||
if (llvm::none_of(CandidateBinOps, [Op](ISD::NodeType BinOp) {
|
||||
return Op.getOpcode() == BinOp;
|
||||
}))
|
||||
return SDValue();
|
||||
|
||||
// At each stage, we're looking for something that looks like:
|
||||
// %s = shufflevector <8 x i32> %op, <8 x i32> undef,
|
||||
// <8 x i32> <i32 2, i32 3, i32 undef, i32 undef,
|
||||
|
@ -30113,8 +30119,9 @@ static SDValue matchBinOpReduction(SDNode *Extract, ISD::NodeType BinOp) {
|
|||
// <4,5,6,7,u,u,u,u>
|
||||
// <2,3,u,u,u,u,u,u>
|
||||
// <1,u,u,u,u,u,u,u>
|
||||
unsigned CandidateBinOp = Op.getOpcode();
|
||||
for (unsigned i = 0; i < Stages; ++i) {
|
||||
if (Op.getOpcode() != BinOp)
|
||||
if (Op.getOpcode() != CandidateBinOp)
|
||||
return SDValue();
|
||||
|
||||
ShuffleVectorSDNode *Shuffle =
|
||||
|
@ -30127,8 +30134,8 @@ static SDValue matchBinOpReduction(SDNode *Extract, ISD::NodeType BinOp) {
|
|||
}
|
||||
|
||||
// The first operand of the shuffle should be the same as the other operand
|
||||
// of the add.
|
||||
if (!Shuffle || (Shuffle->getOperand(0) != Op))
|
||||
// of the binop.
|
||||
if (!Shuffle || Shuffle->getOperand(0) != Op)
|
||||
return SDValue();
|
||||
|
||||
// Verify the shuffle has the expected (at this stage of the pyramid) mask.
|
||||
|
@ -30137,6 +30144,7 @@ static SDValue matchBinOpReduction(SDNode *Extract, ISD::NodeType BinOp) {
|
|||
return SDValue();
|
||||
}
|
||||
|
||||
BinOp = CandidateBinOp;
|
||||
return Op;
|
||||
}
|
||||
|
||||
|
@ -30250,66 +30258,63 @@ static SDValue combineHorizontalPredicateResult(SDNode *Extract,
|
|||
return SDValue();
|
||||
|
||||
// Check for OR(any_of) and AND(all_of) horizontal reduction patterns.
|
||||
for (ISD::NodeType Op : {ISD::OR, ISD::AND}) {
|
||||
SDValue Match = matchBinOpReduction(Extract, Op);
|
||||
if (!Match)
|
||||
continue;
|
||||
unsigned BinOp = 0;
|
||||
SDValue Match = matchBinOpReduction(Extract, BinOp, {ISD::OR, ISD::AND});
|
||||
if (!Match)
|
||||
return SDValue();
|
||||
|
||||
// EXTRACT_VECTOR_ELT can require implicit extension of the vector element
|
||||
// which we can't support here for now.
|
||||
if (Match.getScalarValueSizeInBits() != BitWidth)
|
||||
continue;
|
||||
// EXTRACT_VECTOR_ELT can require implicit extension of the vector element
|
||||
// which we can't support here for now.
|
||||
if (Match.getScalarValueSizeInBits() != BitWidth)
|
||||
return SDValue();
|
||||
|
||||
// We require AVX2 for PMOVMSKB for v16i16/v32i8;
|
||||
unsigned MatchSizeInBits = Match.getValueSizeInBits();
|
||||
if (!(MatchSizeInBits == 128 ||
|
||||
(MatchSizeInBits == 256 &&
|
||||
((Subtarget.hasAVX() && BitWidth >= 32) || Subtarget.hasAVX2()))))
|
||||
return SDValue();
|
||||
// We require AVX2 for PMOVMSKB for v16i16/v32i8;
|
||||
unsigned MatchSizeInBits = Match.getValueSizeInBits();
|
||||
if (!(MatchSizeInBits == 128 ||
|
||||
(MatchSizeInBits == 256 &&
|
||||
((Subtarget.hasAVX() && BitWidth >= 32) || Subtarget.hasAVX2()))))
|
||||
return SDValue();
|
||||
|
||||
// Don't bother performing this for 2-element vectors.
|
||||
if (Match.getValueType().getVectorNumElements() <= 2)
|
||||
return SDValue();
|
||||
// Don't bother performing this for 2-element vectors.
|
||||
if (Match.getValueType().getVectorNumElements() <= 2)
|
||||
return SDValue();
|
||||
|
||||
// Check that we are extracting a reduction of all sign bits.
|
||||
if (DAG.ComputeNumSignBits(Match) != BitWidth)
|
||||
return SDValue();
|
||||
// Check that we are extracting a reduction of all sign bits.
|
||||
if (DAG.ComputeNumSignBits(Match) != BitWidth)
|
||||
return SDValue();
|
||||
|
||||
// For 32/64 bit comparisons use MOVMSKPS/MOVMSKPD, else PMOVMSKB.
|
||||
MVT MaskVT;
|
||||
if (64 == BitWidth || 32 == BitWidth)
|
||||
MaskVT = MVT::getVectorVT(MVT::getFloatingPointVT(BitWidth),
|
||||
MatchSizeInBits / BitWidth);
|
||||
else
|
||||
MaskVT = MVT::getVectorVT(MVT::i8, MatchSizeInBits / 8);
|
||||
// For 32/64 bit comparisons use MOVMSKPS/MOVMSKPD, else PMOVMSKB.
|
||||
MVT MaskVT;
|
||||
if (64 == BitWidth || 32 == BitWidth)
|
||||
MaskVT = MVT::getVectorVT(MVT::getFloatingPointVT(BitWidth),
|
||||
MatchSizeInBits / BitWidth);
|
||||
else
|
||||
MaskVT = MVT::getVectorVT(MVT::i8, MatchSizeInBits / 8);
|
||||
|
||||
APInt CompareBits;
|
||||
ISD::CondCode CondCode;
|
||||
if (Op == ISD::OR) {
|
||||
// any_of -> MOVMSK != 0
|
||||
CompareBits = APInt::getNullValue(32);
|
||||
CondCode = ISD::CondCode::SETNE;
|
||||
} else {
|
||||
// all_of -> MOVMSK == ((1 << NumElts) - 1)
|
||||
CompareBits = APInt::getLowBitsSet(32, MaskVT.getVectorNumElements());
|
||||
CondCode = ISD::CondCode::SETEQ;
|
||||
}
|
||||
|
||||
// Perform the select as i32/i64 and then truncate to avoid partial register
|
||||
// stalls.
|
||||
unsigned ResWidth = std::max(BitWidth, 32u);
|
||||
EVT ResVT = EVT::getIntegerVT(*DAG.getContext(), ResWidth);
|
||||
SDLoc DL(Extract);
|
||||
SDValue Zero = DAG.getConstant(0, DL, ResVT);
|
||||
SDValue Ones = DAG.getAllOnesConstant(DL, ResVT);
|
||||
SDValue Res = DAG.getBitcast(MaskVT, Match);
|
||||
Res = DAG.getNode(X86ISD::MOVMSK, DL, MVT::i32, Res);
|
||||
Res = DAG.getSelectCC(DL, Res, DAG.getConstant(CompareBits, DL, MVT::i32),
|
||||
Ones, Zero, CondCode);
|
||||
return DAG.getSExtOrTrunc(Res, DL, ExtractVT);
|
||||
APInt CompareBits;
|
||||
ISD::CondCode CondCode;
|
||||
if (BinOp == ISD::OR) {
|
||||
// any_of -> MOVMSK != 0
|
||||
CompareBits = APInt::getNullValue(32);
|
||||
CondCode = ISD::CondCode::SETNE;
|
||||
} else {
|
||||
// all_of -> MOVMSK == ((1 << NumElts) - 1)
|
||||
CompareBits = APInt::getLowBitsSet(32, MaskVT.getVectorNumElements());
|
||||
CondCode = ISD::CondCode::SETEQ;
|
||||
}
|
||||
|
||||
return SDValue();
|
||||
// Perform the select as i32/i64 and then truncate to avoid partial register
|
||||
// stalls.
|
||||
unsigned ResWidth = std::max(BitWidth, 32u);
|
||||
EVT ResVT = EVT::getIntegerVT(*DAG.getContext(), ResWidth);
|
||||
SDLoc DL(Extract);
|
||||
SDValue Zero = DAG.getConstant(0, DL, ResVT);
|
||||
SDValue Ones = DAG.getAllOnesConstant(DL, ResVT);
|
||||
SDValue Res = DAG.getBitcast(MaskVT, Match);
|
||||
Res = DAG.getNode(X86ISD::MOVMSK, DL, MVT::i32, Res);
|
||||
Res = DAG.getSelectCC(DL, Res, DAG.getConstant(CompareBits, DL, MVT::i32),
|
||||
Ones, Zero, CondCode);
|
||||
return DAG.getSExtOrTrunc(Res, DL, ExtractVT);
|
||||
}
|
||||
|
||||
static SDValue combineBasicSADPattern(SDNode *Extract, SelectionDAG &DAG,
|
||||
|
@ -30336,7 +30341,8 @@ static SDValue combineBasicSADPattern(SDNode *Extract, SelectionDAG &DAG,
|
|||
return SDValue();
|
||||
|
||||
// Match shuffle + add pyramid.
|
||||
SDValue Root = matchBinOpReduction(Extract, ISD::ADD);
|
||||
unsigned BinOp = 0;
|
||||
SDValue Root = matchBinOpReduction(Extract, BinOp, {ISD::ADD});
|
||||
|
||||
// The operand is expected to be zero extended from i8
|
||||
// (verified in detectZextAbsDiff).
|
||||
|
|
Loading…
Reference in New Issue