forked from OSchip/llvm-project
[DAG] Add BuildVectorSDNode::getRepeatedSequence helper to recognise multi-element splat patterns
Replace the X86 specific isSplatZeroExtended helper with a generic BuildVectorSDNode method. I've just used this to simplify the X86ISD::BROADCASTM lowering so far (and remove isSplatZeroExtended), but we should be able to use this in more places to lower to complex broadcast patterns. Differential Revision: https://reviews.llvm.org/D87930
This commit is contained in:
parent
62b17a7697
commit
ce356e1546
|
@ -1908,6 +1908,33 @@ public:
|
|||
/// the vector width and set the bits where elements are undef.
|
||||
SDValue getSplatValue(BitVector *UndefElements = nullptr) const;
|
||||
|
||||
/// Find the shortest repeating sequence of values in the build vector.
|
||||
///
|
||||
/// e.g. { u, X, u, X, u, u, X, u } -> { X }
|
||||
///      { X, Y, u, Y, u, u, X, u } -> { X, Y }
|
||||
///
|
||||
/// Currently the build vector must have a power-of-2 number of operands, and
/// only power-of-2 sequence lengths shorter than the vector are tried.
|
||||
/// The DemandedElts mask indicates the elements that must be present;
|
||||
/// undemanded elements in Sequence may be null (SDValue()). If passed a
|
||||
/// non-null UndefElements bitvector, it will resize it to match the original
|
||||
/// vector width and set the bits where demanded elements are undef. If the
|
||||
/// result is false, Sequence will be empty.
|
||||
bool getRepeatedSequence(const APInt &DemandedElts,
|
||||
SmallVectorImpl<SDValue> &Sequence,
|
||||
BitVector *UndefElements = nullptr) const;
|
||||
|
||||
/// Find the shortest repeating sequence of values in the build vector,
/// treating every element of the vector as demanded.
|
||||
///
|
||||
/// e.g. { u, X, u, X, u, u, X, u } -> { X }
|
||||
///      { X, Y, u, Y, u, u, X, u } -> { X, Y }
|
||||
///
|
||||
/// Currently the build vector must have a power-of-2 number of operands.
|
||||
/// If passed a non-null UndefElements bitvector, it will resize it to match
|
||||
/// the original vector width and set the bits where elements are undef.
|
||||
/// If the result is false, Sequence will be empty.
|
||||
bool getRepeatedSequence(SmallVectorImpl<SDValue> &Sequence,
|
||||
BitVector *UndefElements = nullptr) const;
|
||||
|
||||
/// Returns the demanded splatted constant or null if this is not a constant
|
||||
/// splat.
|
||||
///
|
||||
|
|
|
@ -9870,6 +9870,58 @@ SDValue BuildVectorSDNode::getSplatValue(BitVector *UndefElements) const {
|
|||
return getSplatValue(DemandedElts, UndefElements);
|
||||
}
|
||||
|
||||
bool BuildVectorSDNode::getRepeatedSequence(const APInt &DemandedElts,
|
||||
SmallVectorImpl<SDValue> &Sequence,
|
||||
BitVector *UndefElements) const {
|
||||
unsigned NumOps = getNumOperands();
|
||||
Sequence.clear();
|
||||
if (UndefElements) {
|
||||
UndefElements->clear();
|
||||
UndefElements->resize(NumOps);
|
||||
}
|
||||
assert(NumOps == DemandedElts.getBitWidth() && "Unexpected vector size");
|
||||
if (!DemandedElts || NumOps < 2 || !isPowerOf2_32(NumOps))
|
||||
return false;
|
||||
|
||||
// Set the undefs even if we don't find a sequence (like getSplatValue).
|
||||
if (UndefElements)
|
||||
for (unsigned I = 0; I != NumOps; ++I)
|
||||
if (DemandedElts[I] && getOperand(I).isUndef())
|
||||
(*UndefElements)[I] = true;
|
||||
|
||||
// Iteratively widen the sequence length looking for repetitions.
|
||||
for (unsigned SeqLen = 1; SeqLen < NumOps; SeqLen *= 2) {
|
||||
Sequence.append(SeqLen, SDValue());
|
||||
for (unsigned I = 0; I != NumOps; ++I) {
|
||||
if (!DemandedElts[I])
|
||||
continue;
|
||||
SDValue &SeqOp = Sequence[I % SeqLen];
|
||||
SDValue Op = getOperand(I);
|
||||
if (Op.isUndef()) {
|
||||
if (!SeqOp)
|
||||
SeqOp = Op;
|
||||
continue;
|
||||
}
|
||||
if (SeqOp && !SeqOp.isUndef() && SeqOp != Op) {
|
||||
Sequence.clear();
|
||||
break;
|
||||
}
|
||||
SeqOp = Op;
|
||||
}
|
||||
if (!Sequence.empty())
|
||||
return true;
|
||||
}
|
||||
|
||||
assert(Sequence.empty() && "Failed to empty non-repeating sequence pattern");
|
||||
return false;
|
||||
}
|
||||
|
||||
// Convenience overload: searches for a repeated sequence with every operand
// of the build vector demanded.
bool BuildVectorSDNode::getRepeatedSequence(SmallVectorImpl<SDValue> &Sequence,
                                            BitVector *UndefElements) const {
  return getRepeatedSequence(APInt::getAllOnesValue(getNumOperands()),
                             Sequence, UndefElements);
}
|
||||
|
||||
ConstantSDNode *
|
||||
BuildVectorSDNode::getConstantSplatNode(const APInt &DemandedElts,
|
||||
BitVector *UndefElements) const {
|
||||
|
|
|
@ -8659,43 +8659,6 @@ static bool isFoldableUseOfShuffle(SDNode *N) {
|
|||
return false;
|
||||
}
|
||||
|
||||
// Check if the current node of build vector is a zero extended vector.
|
||||
// // If so, return the value extended.
|
||||
// // For example: (0,0,0,a,0,0,0,a,0,0,0,a,0,0,0,a) returns a.
|
||||
// // NumElt - return the number of zero extended identical values.
|
||||
// // EltType - return the type of the value include the zero extend.
|
||||
static SDValue isSplatZeroExtended(const BuildVectorSDNode *Op,
|
||||
unsigned &NumElt, MVT &EltType) {
|
||||
SDValue ExtValue = Op->getOperand(0);
|
||||
unsigned NumElts = Op->getNumOperands();
|
||||
unsigned Delta = NumElts;
|
||||
|
||||
for (unsigned i = 1; i < NumElts; i++) {
|
||||
if (Op->getOperand(i) == ExtValue) {
|
||||
Delta = i;
|
||||
break;
|
||||
}
|
||||
if (!(Op->getOperand(i).isUndef() || isNullConstant(Op->getOperand(i))))
|
||||
return SDValue();
|
||||
}
|
||||
if (!isPowerOf2_32(Delta) || Delta == 1)
|
||||
return SDValue();
|
||||
|
||||
for (unsigned i = Delta; i < NumElts; i++) {
|
||||
if (i % Delta == 0) {
|
||||
if (Op->getOperand(i) != ExtValue)
|
||||
return SDValue();
|
||||
} else if (!(isNullConstant(Op->getOperand(i)) ||
|
||||
Op->getOperand(i).isUndef()))
|
||||
return SDValue();
|
||||
}
|
||||
unsigned EltSize = Op->getSimpleValueType(0).getScalarSizeInBits();
|
||||
unsigned ExtVTSize = EltSize * Delta;
|
||||
EltType = MVT::getIntegerVT(ExtVTSize);
|
||||
NumElt = NumElts / Delta;
|
||||
return ExtValue;
|
||||
}
|
||||
|
||||
/// Attempt to use the vbroadcast instruction to generate a splat value
|
||||
/// from a splat BUILD_VECTOR which uses:
|
||||
/// a. A single scalar load, or a constant.
|
||||
|
@ -8713,13 +8676,21 @@ static SDValue lowerBuildVectorAsBroadcast(BuildVectorSDNode *BVOp,
|
|||
return SDValue();
|
||||
|
||||
MVT VT = BVOp->getSimpleValueType(0);
|
||||
unsigned NumElts = VT.getVectorNumElements();
|
||||
SDLoc dl(BVOp);
|
||||
|
||||
assert((VT.is128BitVector() || VT.is256BitVector() || VT.is512BitVector()) &&
|
||||
"Unsupported vector type for broadcast.");
|
||||
|
||||
// See if the build vector is a repeating sequence of scalars (inc. splat).
|
||||
SDValue Ld;
|
||||
BitVector UndefElements;
|
||||
SDValue Ld = BVOp->getSplatValue(&UndefElements);
|
||||
SmallVector<SDValue, 16> Sequence;
|
||||
if (BVOp->getRepeatedSequence(Sequence, &UndefElements)) {
|
||||
assert((NumElts % Sequence.size()) == 0 && "Sequence doesn't fit.");
|
||||
if (Sequence.size() == 1)
|
||||
Ld = Sequence[0];
|
||||
}
|
||||
|
||||
// Attempt to use VBROADCASTM
|
||||
// From this pattern:
|
||||
|
@ -8727,29 +8698,29 @@ static SDValue lowerBuildVectorAsBroadcast(BuildVectorSDNode *BVOp,
|
|||
// b. t1 = (build_vector t0 t0)
|
||||
//
|
||||
// Create (VBROADCASTM v2i1 X)
|
||||
if (Subtarget.hasCDI()) {
|
||||
MVT EltType = VT.getScalarType();
|
||||
unsigned NumElts = VT.getVectorNumElements();
|
||||
SDValue BOperand;
|
||||
SDValue ZeroExtended = isSplatZeroExtended(BVOp, NumElts, EltType);
|
||||
if ((ZeroExtended && ZeroExtended.getOpcode() == ISD::BITCAST) ||
|
||||
(ZeroExtended && ZeroExtended.getOpcode() == ISD::ZERO_EXTEND &&
|
||||
ZeroExtended.getOperand(0).getOpcode() == ISD::BITCAST) ||
|
||||
(Ld && Ld.getOpcode() == ISD::ZERO_EXTEND &&
|
||||
Ld.getOperand(0).getOpcode() == ISD::BITCAST)) {
|
||||
if (ZeroExtended && ZeroExtended.getOpcode() == ISD::BITCAST)
|
||||
BOperand = ZeroExtended.getOperand(0);
|
||||
else if (ZeroExtended)
|
||||
BOperand = ZeroExtended.getOperand(0).getOperand(0);
|
||||
else
|
||||
BOperand = Ld.getOperand(0).getOperand(0);
|
||||
if (!Sequence.empty() && Subtarget.hasCDI()) {
|
||||
// If not a splat, are the upper sequence values zeroable?
|
||||
unsigned SeqLen = Sequence.size();
|
||||
bool UpperZeroOrUndef =
|
||||
SeqLen == 1 ||
|
||||
llvm::all_of(makeArrayRef(Sequence).drop_front(), [](SDValue V) {
|
||||
return !V || V.isUndef() || isNullConstant(V);
|
||||
});
|
||||
SDValue Op0 = Sequence[0];
|
||||
if (UpperZeroOrUndef && ((Op0.getOpcode() == ISD::BITCAST) ||
|
||||
(Op0.getOpcode() == ISD::ZERO_EXTEND &&
|
||||
Op0.getOperand(0).getOpcode() == ISD::BITCAST))) {
|
||||
SDValue BOperand = Op0.getOpcode() == ISD::BITCAST
|
||||
? Op0.getOperand(0)
|
||||
: Op0.getOperand(0).getOperand(0);
|
||||
MVT MaskVT = BOperand.getSimpleValueType();
|
||||
MVT EltType = MVT::getIntegerVT(VT.getScalarSizeInBits() * SeqLen);
|
||||
if ((EltType == MVT::i64 && MaskVT == MVT::v8i1) || // for broadcastmb2q
|
||||
(EltType == MVT::i32 && MaskVT == MVT::v16i1)) { // for broadcastmw2d
|
||||
MVT BcstVT = MVT::getVectorVT(EltType, NumElts);
|
||||
MVT BcstVT = MVT::getVectorVT(EltType, NumElts / SeqLen);
|
||||
if (!VT.is512BitVector() && !Subtarget.hasVLX()) {
|
||||
unsigned Scale = 512 / VT.getSizeInBits();
|
||||
BcstVT = MVT::getVectorVT(EltType, NumElts * Scale);
|
||||
BcstVT = MVT::getVectorVT(EltType, Scale * (NumElts / SeqLen));
|
||||
}
|
||||
SDValue Bcst = DAG.getNode(X86ISD::VBROADCASTM, dl, BcstVT, BOperand);
|
||||
if (BcstVT.getSizeInBits() != VT.getSizeInBits())
|
||||
|
@ -8759,7 +8730,6 @@ static SDValue lowerBuildVectorAsBroadcast(BuildVectorSDNode *BVOp,
|
|||
}
|
||||
}
|
||||
|
||||
unsigned NumElts = VT.getVectorNumElements();
|
||||
unsigned NumUndefElts = UndefElements.count();
|
||||
if (!Ld || (NumElts - NumUndefElts) <= 1) {
|
||||
APInt SplatValue, Undef;
|
||||
|
@ -8833,6 +8803,8 @@ static SDValue lowerBuildVectorAsBroadcast(BuildVectorSDNode *BVOp,
|
|||
(Ld.getOpcode() == ISD::Constant || Ld.getOpcode() == ISD::ConstantFP);
|
||||
bool IsLoad = ISD::isNormalLoad(Ld.getNode());
|
||||
|
||||
// TODO: Handle broadcasts of non-constant sequences.
|
||||
|
||||
// Make sure that all of the users of a non-constant load are from the
|
||||
// BUILD_VECTOR node.
|
||||
// FIXME: Is the use count needed for non-constant, non-load case?
|
||||
|
|
|
@ -472,6 +472,127 @@ TEST_F(AArch64SelectionDAGTest, getSplatSourceVector_Scalable_ADD_of_SPLAT_VECTO
|
|||
EXPECT_EQ(SplatIdx, 0);
|
||||
}
|
||||
|
||||
// Exercises BuildVectorSDNode::getRepeatedSequence on v16i8 build vectors:
// a splat, power-of-2 repeating sequences (with some undef elements), a
// non-power-of-2 period and a non-repeating vector, both with every element
// demanded and with explicit DemandedElts masks.
TEST_F(AArch64SelectionDAGTest, getRepeatedSequence_Patterns) {
  if (!TM)
    return;

  TargetLowering TL(*TM);

  SDLoc Loc;
  unsigned NumElts = 16;
  MVT IntVT = MVT::i8;
  MVT VecVT = MVT::getVectorVT(IntVT, NumElts);

  // Base scalar constants.
  SDValue Val0 = DAG->getConstant(0, Loc, IntVT);
  SDValue Val1 = DAG->getConstant(1, Loc, IntVT);
  SDValue Val2 = DAG->getConstant(2, Loc, IntVT);
  SDValue Val3 = DAG->getConstant(3, Loc, IntVT);
  SDValue UndefVal = DAG->getUNDEF(IntVT);

  // Build some repeating sequences.
  SmallVector<SDValue, 16> Pattern1111, Pattern1133, Pattern0123;
  for (int I = 0; I != 4; ++I) {
    Pattern1111.append(4, Val1);
    Pattern1133.append(2, Val1);
    Pattern1133.append(2, Val3);
    Pattern0123.push_back(Val0);
    Pattern0123.push_back(Val1);
    Pattern0123.push_back(Val2);
    Pattern0123.push_back(Val3);
  }

  // Build a non-pow2 repeating sequence: { 0, 2, 2 } repeated across the
  // vector (the last period is truncated after its leading 0).
  SmallVector<SDValue, 16> Pattern022;
  for (unsigned I = 0; I != NumElts; ++I)
    Pattern022.push_back((I % 3) == 0 ? Val0 : Val2);

  // Build a non-repeating sequence.
  SmallVector<SDValue, 16> Pattern1_3;
  Pattern1_3.append(8, Val1);
  Pattern1_3.append(8, Val3);

  // Add some undefs to make it trickier.
  Pattern1111[1] = Pattern1111[2] = Pattern1111[15] = UndefVal;
  Pattern1133[0] = Pattern1133[2] = UndefVal;

  auto *BV1111 =
      cast<BuildVectorSDNode>(DAG->getBuildVector(VecVT, Loc, Pattern1111));
  auto *BV1133 =
      cast<BuildVectorSDNode>(DAG->getBuildVector(VecVT, Loc, Pattern1133));
  auto *BV0123 =
      cast<BuildVectorSDNode>(DAG->getBuildVector(VecVT, Loc, Pattern0123));
  auto *BV022 =
      cast<BuildVectorSDNode>(DAG->getBuildVector(VecVT, Loc, Pattern022));
  auto *BV1_3 =
      cast<BuildVectorSDNode>(DAG->getBuildVector(VecVT, Loc, Pattern1_3));

  // Check for sequences.
  // Note: counts and sizes are unsigned, so compare against unsigned literals
  // to avoid sign-compare warnings inside the EXPECT_EQ templates.
  SmallVector<SDValue, 16> Seq1111, Seq1133, Seq0123, Seq022, Seq1_3;
  BitVector Undefs1111, Undefs1133, Undefs0123, Undefs022, Undefs1_3;

  EXPECT_TRUE(BV1111->getRepeatedSequence(Seq1111, &Undefs1111));
  EXPECT_EQ(Undefs1111.count(), 3u);
  EXPECT_EQ(Seq1111.size(), 1u);
  EXPECT_EQ(Seq1111[0], Val1);

  EXPECT_TRUE(BV1133->getRepeatedSequence(Seq1133, &Undefs1133));
  EXPECT_EQ(Undefs1133.count(), 2u);
  EXPECT_EQ(Seq1133.size(), 4u);
  EXPECT_EQ(Seq1133[0], Val1);
  EXPECT_EQ(Seq1133[1], Val1);
  EXPECT_EQ(Seq1133[2], Val3);
  EXPECT_EQ(Seq1133[3], Val3);

  EXPECT_TRUE(BV0123->getRepeatedSequence(Seq0123, &Undefs0123));
  EXPECT_EQ(Undefs0123.count(), 0u);
  EXPECT_EQ(Seq0123.size(), 4u);
  EXPECT_EQ(Seq0123[0], Val0);
  EXPECT_EQ(Seq0123[1], Val1);
  EXPECT_EQ(Seq0123[2], Val2);
  EXPECT_EQ(Seq0123[3], Val3);

  EXPECT_FALSE(BV022->getRepeatedSequence(Seq022, &Undefs022));
  EXPECT_FALSE(BV1_3->getRepeatedSequence(Seq1_3, &Undefs1_3));

  // Try again with DemandedElts masks.
  // Only element 0 (a defined Val1) is demanded.
  APInt Mask1111_0 = APInt::getOneBitSet(NumElts, 0);
  EXPECT_TRUE(BV1111->getRepeatedSequence(Mask1111_0, Seq1111, &Undefs1111));
  EXPECT_EQ(Undefs1111.count(), 0u);
  EXPECT_EQ(Seq1111.size(), 1u);
  EXPECT_EQ(Seq1111[0], Val1);

  // Only element 2 (an undef) is demanded.
  APInt Mask1111_1 = APInt::getOneBitSet(NumElts, 2);
  EXPECT_TRUE(BV1111->getRepeatedSequence(Mask1111_1, Seq1111, &Undefs1111));
  EXPECT_EQ(Undefs1111.count(), 1u);
  EXPECT_EQ(Seq1111.size(), 1u);
  EXPECT_EQ(Seq1111[0], UndefVal);

  // The last element of every group of four is not demanded, so its slot in
  // the sequence stays null.
  APInt Mask0123 = APInt(NumElts, 0x7777);
  EXPECT_TRUE(BV0123->getRepeatedSequence(Mask0123, Seq0123, &Undefs0123));
  EXPECT_EQ(Undefs0123.count(), 0u);
  EXPECT_EQ(Seq0123.size(), 4u);
  EXPECT_EQ(Seq0123[0], Val0);
  EXPECT_EQ(Seq0123[1], Val1);
  EXPECT_EQ(Seq0123[2], Val2);
  EXPECT_EQ(Seq0123[3], SDValue());

  // Only the upper half (all Val3) is demanded.
  APInt Mask1_3 = APInt::getHighBitsSet(NumElts, 8);
  EXPECT_TRUE(BV1_3->getRepeatedSequence(Mask1_3, Seq1_3, &Undefs1_3));
  EXPECT_EQ(Undefs1_3.count(), 0u);
  EXPECT_EQ(Seq1_3.size(), 1u);
  EXPECT_EQ(Seq1_3[0], Val3);
}
|
||||
|
||||
TEST_F(AArch64SelectionDAGTest, getTypeConversion_SplitScalableMVT) {
|
||||
if (!TM)
|
||||
return;
|
||||
|
|
Loading…
Reference in New Issue