forked from OSchip/llvm-project
[SelectionDAG] Make binop reduction matcher available to all targets
There is nothing x86-specific about this code, so it'd be nice to make this available for other targets to use in the future (and get it out of X86ISelLowering!). Differential Revision: https://reviews.llvm.org/D50083 llvm-svn: 338586
This commit is contained in:
parent
bed4babc56
commit
a3548c960e
|
@ -1503,6 +1503,15 @@ public:
|
|||
/// allow an 'add' to be transformed into an 'or'.
|
||||
bool haveNoCommonBitsSet(SDValue A, SDValue B) const;
|
||||
|
||||
/// Match a binop + shuffle pyramid that represents a horizontal reduction
|
||||
/// over the elements of a vector starting from the EXTRACT_VECTOR_ELT node /p
|
||||
/// Extract. The reduction must use one of the opcodes listed in /p
|
||||
/// CandidateBinOps and on success /p BinOp will contain the matching opcode.
|
||||
/// Returns the vector that is being reduced on, or SDValue() if a reduction
|
||||
/// was not matched.
|
||||
SDValue matchBinOpReduction(SDNode *Extract, ISD::NodeType &BinOp,
|
||||
ArrayRef<ISD::NodeType> CandidateBinOps);
|
||||
|
||||
/// Utility function used by legalize and lowering to
|
||||
/// "unroll" a vector operation by splitting out the scalars and operating
|
||||
/// on each element individually. If the ResNE is 0, fully unroll the vector
|
||||
|
|
|
@ -8318,6 +8318,64 @@ void SDNode::intersectFlagsWith(const SDNodeFlags Flags) {
|
|||
this->Flags.intersectWith(Flags);
|
||||
}
|
||||
|
||||
SDValue
|
||||
SelectionDAG::matchBinOpReduction(SDNode *Extract, ISD::NodeType &BinOp,
|
||||
ArrayRef<ISD::NodeType> CandidateBinOps) {
|
||||
// The pattern must end in an extract from index 0.
|
||||
if (Extract->getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
|
||||
!isNullConstant(Extract->getOperand(1)))
|
||||
return SDValue();
|
||||
|
||||
SDValue Op = Extract->getOperand(0);
|
||||
unsigned Stages = Log2_32(Op.getValueType().getVectorNumElements());
|
||||
|
||||
// Match against one of the candidate binary ops.
|
||||
if (llvm::none_of(CandidateBinOps, [Op](ISD::NodeType BinOp) {
|
||||
return Op.getOpcode() == unsigned(BinOp);
|
||||
}))
|
||||
return SDValue();
|
||||
|
||||
// At each stage, we're looking for something that looks like:
|
||||
// %s = shufflevector <8 x i32> %op, <8 x i32> undef,
|
||||
// <8 x i32> <i32 2, i32 3, i32 undef, i32 undef,
|
||||
// i32 undef, i32 undef, i32 undef, i32 undef>
|
||||
// %a = binop <8 x i32> %op, %s
|
||||
// Where the mask changes according to the stage. E.g. for a 3-stage pyramid,
|
||||
// we expect something like:
|
||||
// <4,5,6,7,u,u,u,u>
|
||||
// <2,3,u,u,u,u,u,u>
|
||||
// <1,u,u,u,u,u,u,u>
|
||||
unsigned CandidateBinOp = Op.getOpcode();
|
||||
for (unsigned i = 0; i < Stages; ++i) {
|
||||
if (Op.getOpcode() != CandidateBinOp)
|
||||
return SDValue();
|
||||
|
||||
SDValue Op0 = Op.getOperand(0);
|
||||
SDValue Op1 = Op.getOperand(1);
|
||||
|
||||
ShuffleVectorSDNode *Shuffle = dyn_cast<ShuffleVectorSDNode>(Op0);
|
||||
if (Shuffle) {
|
||||
Op = Op1;
|
||||
} else {
|
||||
Shuffle = dyn_cast<ShuffleVectorSDNode>(Op1);
|
||||
Op = Op0;
|
||||
}
|
||||
|
||||
// The first operand of the shuffle should be the same as the other operand
|
||||
// of the binop.
|
||||
if (!Shuffle || Shuffle->getOperand(0) != Op)
|
||||
return SDValue();
|
||||
|
||||
// Verify the shuffle has the expected (at this stage of the pyramid) mask.
|
||||
for (int Index = 0, MaskEnd = 1 << i; Index < MaskEnd; ++Index)
|
||||
if (Shuffle->getMaskElt(Index) != MaskEnd + Index)
|
||||
return SDValue();
|
||||
}
|
||||
|
||||
BinOp = (ISD::NodeType)CandidateBinOp;
|
||||
return Op;
|
||||
}
|
||||
|
||||
SDValue SelectionDAG::UnrollVectorOp(SDNode *N, unsigned ResNE) {
|
||||
assert(N->getNumValues() == 1 &&
|
||||
"Can't unroll a vector with multiple results!");
|
||||
|
|
|
@ -31807,65 +31807,6 @@ static SDValue combineBitcast(SDNode *N, SelectionDAG &DAG,
|
|||
return SDValue();
|
||||
}
|
||||
|
||||
// Match a binop + shuffle pyramid that represents a horizontal reduction over
|
||||
// the elements of a vector.
|
||||
// Returns the vector that is being reduced on, or SDValue() if a reduction
|
||||
// was not matched.
|
||||
static SDValue matchBinOpReduction(SDNode *Extract, unsigned &BinOp,
|
||||
ArrayRef<ISD::NodeType> CandidateBinOps) {
|
||||
// The pattern must end in an extract from index 0.
|
||||
if ((Extract->getOpcode() != ISD::EXTRACT_VECTOR_ELT) ||
|
||||
!isNullConstant(Extract->getOperand(1)))
|
||||
return SDValue();
|
||||
|
||||
SDValue Op = Extract->getOperand(0);
|
||||
unsigned Stages = Log2_32(Op.getValueType().getVectorNumElements());
|
||||
|
||||
// Match against one of the candidate binary ops.
|
||||
if (llvm::none_of(CandidateBinOps, [Op](ISD::NodeType BinOp) {
|
||||
return Op.getOpcode() == unsigned(BinOp);
|
||||
}))
|
||||
return SDValue();
|
||||
|
||||
// At each stage, we're looking for something that looks like:
|
||||
// %s = shufflevector <8 x i32> %op, <8 x i32> undef,
|
||||
// <8 x i32> <i32 2, i32 3, i32 undef, i32 undef,
|
||||
// i32 undef, i32 undef, i32 undef, i32 undef>
|
||||
// %a = binop <8 x i32> %op, %s
|
||||
// Where the mask changes according to the stage. E.g. for a 3-stage pyramid,
|
||||
// we expect something like:
|
||||
// <4,5,6,7,u,u,u,u>
|
||||
// <2,3,u,u,u,u,u,u>
|
||||
// <1,u,u,u,u,u,u,u>
|
||||
unsigned CandidateBinOp = Op.getOpcode();
|
||||
for (unsigned i = 0; i < Stages; ++i) {
|
||||
if (Op.getOpcode() != CandidateBinOp)
|
||||
return SDValue();
|
||||
|
||||
ShuffleVectorSDNode *Shuffle =
|
||||
dyn_cast<ShuffleVectorSDNode>(Op.getOperand(0).getNode());
|
||||
if (Shuffle) {
|
||||
Op = Op.getOperand(1);
|
||||
} else {
|
||||
Shuffle = dyn_cast<ShuffleVectorSDNode>(Op.getOperand(1).getNode());
|
||||
Op = Op.getOperand(0);
|
||||
}
|
||||
|
||||
// The first operand of the shuffle should be the same as the other operand
|
||||
// of the binop.
|
||||
if (!Shuffle || Shuffle->getOperand(0) != Op)
|
||||
return SDValue();
|
||||
|
||||
// Verify the shuffle has the expected (at this stage of the pyramid) mask.
|
||||
for (int Index = 0, MaskEnd = 1 << i; Index < MaskEnd; ++Index)
|
||||
if (Shuffle->getMaskElt(Index) != MaskEnd + Index)
|
||||
return SDValue();
|
||||
}
|
||||
|
||||
BinOp = CandidateBinOp;
|
||||
return Op;
|
||||
}
|
||||
|
||||
// Given a select, detect the following pattern:
|
||||
// 1: %2 = zext <N x i8> %0 to <N x i32>
|
||||
// 2: %3 = zext <N x i8> %1 to <N x i32>
|
||||
|
@ -31980,8 +31921,8 @@ static SDValue combineHorizontalMinMaxResult(SDNode *Extract, SelectionDAG &DAG,
|
|||
return SDValue();
|
||||
|
||||
// Check for SMAX/SMIN/UMAX/UMIN horizontal reduction patterns.
|
||||
unsigned BinOp;
|
||||
SDValue Src = matchBinOpReduction(
|
||||
ISD::NodeType BinOp;
|
||||
SDValue Src = DAG.matchBinOpReduction(
|
||||
Extract, BinOp, {ISD::SMAX, ISD::SMIN, ISD::UMAX, ISD::UMIN});
|
||||
if (!Src)
|
||||
return SDValue();
|
||||
|
@ -32060,8 +32001,8 @@ static SDValue combineHorizontalPredicateResult(SDNode *Extract,
|
|||
return SDValue();
|
||||
|
||||
// Check for OR(any_of) and AND(all_of) horizontal reduction patterns.
|
||||
unsigned BinOp = 0;
|
||||
SDValue Match = matchBinOpReduction(Extract, BinOp, {ISD::OR, ISD::AND});
|
||||
ISD::NodeType BinOp;
|
||||
SDValue Match = DAG.matchBinOpReduction(Extract, BinOp, {ISD::OR, ISD::AND});
|
||||
if (!Match)
|
||||
return SDValue();
|
||||
|
||||
|
@ -32143,8 +32084,8 @@ static SDValue combineBasicSADPattern(SDNode *Extract, SelectionDAG &DAG,
|
|||
return SDValue();
|
||||
|
||||
// Match shuffle + add pyramid.
|
||||
unsigned BinOp = 0;
|
||||
SDValue Root = matchBinOpReduction(Extract, BinOp, {ISD::ADD});
|
||||
ISD::NodeType BinOp;
|
||||
SDValue Root = DAG.matchBinOpReduction(Extract, BinOp, {ISD::ADD});
|
||||
|
||||
// The operand is expected to be zero extended from i8
|
||||
// (verified in detectZextAbsDiff).
|
||||
|
|
Loading…
Reference in New Issue