[SelectionDAG] Make binop reduction matcher available to all targets

There is nothing x86-specific about this code, so it'd be nice to make this available for other targets to use in the future (and get it out of X86ISelLowering!).

Differential Revision: https://reviews.llvm.org/D50083

llvm-svn: 338586
This commit is contained in:
Simon Pilgrim 2018-08-01 16:52:28 +00:00
parent bed4babc56
commit a3548c960e
3 changed files with 73 additions and 65 deletions

View File

@ -1503,6 +1503,15 @@ public:
/// allow an 'add' to be transformed into an 'or'. /// allow an 'add' to be transformed into an 'or'.
bool haveNoCommonBitsSet(SDValue A, SDValue B) const; bool haveNoCommonBitsSet(SDValue A, SDValue B) const;
/// Match a binop + shuffle pyramid that represents a horizontal reduction
/// over the elements of a vector starting from the EXTRACT_VECTOR_ELT node /p
/// Extract. The reduction must use one of the opcodes listed in /p
/// CandidateBinOps and on success /p BinOp will contain the matching opcode.
/// Returns the vector that is being reduced on, or SDValue() if a reduction
/// was not matched.
SDValue matchBinOpReduction(SDNode *Extract, ISD::NodeType &BinOp,
ArrayRef<ISD::NodeType> CandidateBinOps);
/// Utility function used by legalize and lowering to /// Utility function used by legalize and lowering to
/// "unroll" a vector operation by splitting out the scalars and operating /// "unroll" a vector operation by splitting out the scalars and operating
/// on each element individually. If the ResNE is 0, fully unroll the vector /// on each element individually. If the ResNE is 0, fully unroll the vector

View File

@ -8318,6 +8318,64 @@ void SDNode::intersectFlagsWith(const SDNodeFlags Flags) {
this->Flags.intersectWith(Flags); this->Flags.intersectWith(Flags);
} }
SDValue
SelectionDAG::matchBinOpReduction(SDNode *Extract, ISD::NodeType &BinOp,
ArrayRef<ISD::NodeType> CandidateBinOps) {
// The pattern must end in an extract from index 0.
if (Extract->getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
!isNullConstant(Extract->getOperand(1)))
return SDValue();
SDValue Op = Extract->getOperand(0);
unsigned Stages = Log2_32(Op.getValueType().getVectorNumElements());
// Match against one of the candidate binary ops.
if (llvm::none_of(CandidateBinOps, [Op](ISD::NodeType BinOp) {
return Op.getOpcode() == unsigned(BinOp);
}))
return SDValue();
// At each stage, we're looking for something that looks like:
// %s = shufflevector <8 x i32> %op, <8 x i32> undef,
// <8 x i32> <i32 2, i32 3, i32 undef, i32 undef,
// i32 undef, i32 undef, i32 undef, i32 undef>
// %a = binop <8 x i32> %op, %s
// Where the mask changes according to the stage. E.g. for a 3-stage pyramid,
// we expect something like:
// <4,5,6,7,u,u,u,u>
// <2,3,u,u,u,u,u,u>
// <1,u,u,u,u,u,u,u>
unsigned CandidateBinOp = Op.getOpcode();
for (unsigned i = 0; i < Stages; ++i) {
if (Op.getOpcode() != CandidateBinOp)
return SDValue();
SDValue Op0 = Op.getOperand(0);
SDValue Op1 = Op.getOperand(1);
ShuffleVectorSDNode *Shuffle = dyn_cast<ShuffleVectorSDNode>(Op0);
if (Shuffle) {
Op = Op1;
} else {
Shuffle = dyn_cast<ShuffleVectorSDNode>(Op1);
Op = Op0;
}
// The first operand of the shuffle should be the same as the other operand
// of the binop.
if (!Shuffle || Shuffle->getOperand(0) != Op)
return SDValue();
// Verify the shuffle has the expected (at this stage of the pyramid) mask.
for (int Index = 0, MaskEnd = 1 << i; Index < MaskEnd; ++Index)
if (Shuffle->getMaskElt(Index) != MaskEnd + Index)
return SDValue();
}
BinOp = (ISD::NodeType)CandidateBinOp;
return Op;
}
SDValue SelectionDAG::UnrollVectorOp(SDNode *N, unsigned ResNE) { SDValue SelectionDAG::UnrollVectorOp(SDNode *N, unsigned ResNE) {
assert(N->getNumValues() == 1 && assert(N->getNumValues() == 1 &&
"Can't unroll a vector with multiple results!"); "Can't unroll a vector with multiple results!");

View File

@ -31807,65 +31807,6 @@ static SDValue combineBitcast(SDNode *N, SelectionDAG &DAG,
return SDValue(); return SDValue();
} }
// Match a binop + shuffle pyramid that represents a horizontal reduction over
// the elements of a vector.
// Returns the vector that is being reduced on, or SDValue() if a reduction
// was not matched.
static SDValue matchBinOpReduction(SDNode *Extract, unsigned &BinOp,
ArrayRef<ISD::NodeType> CandidateBinOps) {
// The pattern must end in an extract from index 0.
if ((Extract->getOpcode() != ISD::EXTRACT_VECTOR_ELT) ||
!isNullConstant(Extract->getOperand(1)))
return SDValue();
SDValue Op = Extract->getOperand(0);
unsigned Stages = Log2_32(Op.getValueType().getVectorNumElements());
// Match against one of the candidate binary ops.
if (llvm::none_of(CandidateBinOps, [Op](ISD::NodeType BinOp) {
return Op.getOpcode() == unsigned(BinOp);
}))
return SDValue();
// At each stage, we're looking for something that looks like:
// %s = shufflevector <8 x i32> %op, <8 x i32> undef,
// <8 x i32> <i32 2, i32 3, i32 undef, i32 undef,
// i32 undef, i32 undef, i32 undef, i32 undef>
// %a = binop <8 x i32> %op, %s
// Where the mask changes according to the stage. E.g. for a 3-stage pyramid,
// we expect something like:
// <4,5,6,7,u,u,u,u>
// <2,3,u,u,u,u,u,u>
// <1,u,u,u,u,u,u,u>
unsigned CandidateBinOp = Op.getOpcode();
for (unsigned i = 0; i < Stages; ++i) {
if (Op.getOpcode() != CandidateBinOp)
return SDValue();
ShuffleVectorSDNode *Shuffle =
dyn_cast<ShuffleVectorSDNode>(Op.getOperand(0).getNode());
if (Shuffle) {
Op = Op.getOperand(1);
} else {
Shuffle = dyn_cast<ShuffleVectorSDNode>(Op.getOperand(1).getNode());
Op = Op.getOperand(0);
}
// The first operand of the shuffle should be the same as the other operand
// of the binop.
if (!Shuffle || Shuffle->getOperand(0) != Op)
return SDValue();
// Verify the shuffle has the expected (at this stage of the pyramid) mask.
for (int Index = 0, MaskEnd = 1 << i; Index < MaskEnd; ++Index)
if (Shuffle->getMaskElt(Index) != MaskEnd + Index)
return SDValue();
}
BinOp = CandidateBinOp;
return Op;
}
// Given a select, detect the following pattern: // Given a select, detect the following pattern:
// 1: %2 = zext <N x i8> %0 to <N x i32> // 1: %2 = zext <N x i8> %0 to <N x i32>
// 2: %3 = zext <N x i8> %1 to <N x i32> // 2: %3 = zext <N x i8> %1 to <N x i32>
@ -31980,8 +31921,8 @@ static SDValue combineHorizontalMinMaxResult(SDNode *Extract, SelectionDAG &DAG,
return SDValue(); return SDValue();
// Check for SMAX/SMIN/UMAX/UMIN horizontal reduction patterns. // Check for SMAX/SMIN/UMAX/UMIN horizontal reduction patterns.
unsigned BinOp; ISD::NodeType BinOp;
SDValue Src = matchBinOpReduction( SDValue Src = DAG.matchBinOpReduction(
Extract, BinOp, {ISD::SMAX, ISD::SMIN, ISD::UMAX, ISD::UMIN}); Extract, BinOp, {ISD::SMAX, ISD::SMIN, ISD::UMAX, ISD::UMIN});
if (!Src) if (!Src)
return SDValue(); return SDValue();
@ -32060,8 +32001,8 @@ static SDValue combineHorizontalPredicateResult(SDNode *Extract,
return SDValue(); return SDValue();
// Check for OR(any_of) and AND(all_of) horizontal reduction patterns. // Check for OR(any_of) and AND(all_of) horizontal reduction patterns.
unsigned BinOp = 0; ISD::NodeType BinOp;
SDValue Match = matchBinOpReduction(Extract, BinOp, {ISD::OR, ISD::AND}); SDValue Match = DAG.matchBinOpReduction(Extract, BinOp, {ISD::OR, ISD::AND});
if (!Match) if (!Match)
return SDValue(); return SDValue();
@ -32143,8 +32084,8 @@ static SDValue combineBasicSADPattern(SDNode *Extract, SelectionDAG &DAG,
return SDValue(); return SDValue();
// Match shuffle + add pyramid. // Match shuffle + add pyramid.
unsigned BinOp = 0; ISD::NodeType BinOp;
SDValue Root = matchBinOpReduction(Extract, BinOp, {ISD::ADD}); SDValue Root = DAG.matchBinOpReduction(Extract, BinOp, {ISD::ADD});
// The operand is expected to be zero extended from i8 // The operand is expected to be zero extended from i8
// (verified in detectZextAbsDiff). // (verified in detectZextAbsDiff).