[DAG] Add BuildVectorSDNode::getRepeatedSequence helper to recognise multi-element splat patterns

Replace the X86 specific isSplatZeroExtended helper with a generic BuildVectorSDNode method.

I've just used this to simplify the X86ISD::VBROADCASTM lowering so far (and remove isSplatZeroExtended), but we should be able to use this in more places to lower to complex broadcast patterns.

Differential Revision: https://reviews.llvm.org/D87930
This commit is contained in:
Simon Pilgrim 2020-10-24 12:23:09 +01:00
parent 62b17a7697
commit ce356e1546
4 changed files with 229 additions and 57 deletions

View File

@ -1908,6 +1908,33 @@ public:
/// the vector width and set the bits where elements are undef.
SDValue getSplatValue(BitVector *UndefElements = nullptr) const;
/// Find the shortest repeating sequence of values in the build vector.
///
/// e.g. { u, X, u, X, u, u, X, u } -> { X }
///      { X, Y, u, Y, u, u, X, u } -> { X, Y }
///
/// Currently this must be a power-of-2 build vector with at least two
/// operands; anything else yields false. Candidate sequence lengths are the
/// powers of 2 strictly smaller than the operand count, tried shortest
/// first, so a vector that only matches at its full width is rejected.
///
/// The DemandedElts mask indicates the elements that must be present. Its
/// bit width must equal the operand count, and an all-zero mask yields
/// false. Undemanded elements are ignored while matching, so the
/// corresponding entries in Sequence may be null (SDValue()). An entry whose
/// demanded elements are all undef holds that undef value.
///
/// If passed a non-null UndefElements bitvector, it will be cleared, resized
/// to match the original vector width and have the bits set where demanded
/// elements are undef (this is done even if no sequence is found).
/// If result is false, Sequence will be empty.
bool getRepeatedSequence(const APInt &DemandedElts,
SmallVectorImpl<SDValue> &Sequence,
BitVector *UndefElements = nullptr) const;
/// Find the shortest repeating sequence of values in the build vector,
/// treating every element as demanded.
///
/// e.g. { u, X, u, X, u, u, X, u } -> { X }
///      { X, Y, u, Y, u, u, X, u } -> { X, Y }
///
/// Currently this must be a power-of-2 build vector with at least two
/// operands; anything else yields false. The sequence found is always
/// strictly shorter than the build vector itself.
///
/// If passed a non-null UndefElements bitvector, it will resize it to match
/// the original vector width and set the bits where elements are undef.
/// If result is false, Sequence will be empty.
bool getRepeatedSequence(SmallVectorImpl<SDValue> &Sequence,
BitVector *UndefElements = nullptr) const;
/// Returns the demanded splatted constant or null if this is not a constant
/// splat.
///

View File

@ -9870,6 +9870,58 @@ SDValue BuildVectorSDNode::getSplatValue(BitVector *UndefElements) const {
return getSplatValue(DemandedElts, UndefElements);
}
bool BuildVectorSDNode::getRepeatedSequence(const APInt &DemandedElts,
SmallVectorImpl<SDValue> &Sequence,
BitVector *UndefElements) const {
unsigned NumOps = getNumOperands();
Sequence.clear();
if (UndefElements) {
UndefElements->clear();
UndefElements->resize(NumOps);
}
assert(NumOps == DemandedElts.getBitWidth() && "Unexpected vector size");
if (!DemandedElts || NumOps < 2 || !isPowerOf2_32(NumOps))
return false;
// Set the undefs even if we don't find a sequence (like getSplatValue).
if (UndefElements)
for (unsigned I = 0; I != NumOps; ++I)
if (DemandedElts[I] && getOperand(I).isUndef())
(*UndefElements)[I] = true;
// Iteratively widen the sequence length looking for repetitions.
for (unsigned SeqLen = 1; SeqLen < NumOps; SeqLen *= 2) {
Sequence.append(SeqLen, SDValue());
for (unsigned I = 0; I != NumOps; ++I) {
if (!DemandedElts[I])
continue;
SDValue &SeqOp = Sequence[I % SeqLen];
SDValue Op = getOperand(I);
if (Op.isUndef()) {
if (!SeqOp)
SeqOp = Op;
continue;
}
if (SeqOp && !SeqOp.isUndef() && SeqOp != Op) {
Sequence.clear();
break;
}
SeqOp = Op;
}
if (!Sequence.empty())
return true;
}
assert(Sequence.empty() && "Failed to empty non-repeating sequence pattern");
return false;
}
/// Convenience overload: look for a repeating sequence with every element of
/// the build vector demanded.
bool BuildVectorSDNode::getRepeatedSequence(SmallVectorImpl<SDValue> &Sequence,
                                            BitVector *UndefElements) const {
  // An all-ones mask as wide as the operand list demands every element.
  return getRepeatedSequence(APInt::getAllOnesValue(getNumOperands()), Sequence,
                             UndefElements);
}
ConstantSDNode *
BuildVectorSDNode::getConstantSplatNode(const APInt &DemandedElts,
BitVector *UndefElements) const {

View File

@ -8659,43 +8659,6 @@ static bool isFoldableUseOfShuffle(SDNode *N) {
return false;
}
/// Check whether a build vector is a splat of a zero extended value, i.e.
/// operand 0 recurs at a fixed power-of-2 stride and every operand in between
/// is either a zero constant or undef.
/// For example, in operand order: (a,0,0,0,a,0,0,0,a,0,0,0,a,0,0,0) returns a.
/// If operand 0 never recurs, the stride is the whole operand count.
///
/// \param Op      The build vector to inspect.
/// \param NumElt  On success, the number of zero extended identical values.
/// \param EltType On success, the integer type of the value including the
///                zero extension.
/// \returns the repeated value, or a null SDValue if the pattern is absent.
static SDValue isSplatZeroExtended(const BuildVectorSDNode *Op,
                                   unsigned &NumElt, MVT &EltType) {
  // The filler between copies of the value must be zero or undef.
  auto IsZeroOrUndef = [](SDValue V) {
    return V.isUndef() || isNullConstant(V);
  };

  const SDValue ExtValue = Op->getOperand(0);
  const unsigned NumElts = Op->getNumOperands();

  // Find the stride: the index at which operand 0 first shows up again.
  unsigned Delta = NumElts;
  for (unsigned Idx = 1; Idx < NumElts; ++Idx) {
    SDValue Elt = Op->getOperand(Idx);
    if (Elt == ExtValue) {
      Delta = Idx;
      break;
    }
    if (!IsZeroOrUndef(Elt))
      return SDValue();
  }

  // A stride of 1 is a plain splat, and only power-of-2 strides are accepted.
  if (Delta == 1 || !isPowerOf2_32(Delta))
    return SDValue();

  // Check the rest of the vector follows the same layout.
  for (unsigned Idx = Delta; Idx < NumElts; ++Idx) {
    SDValue Elt = Op->getOperand(Idx);
    bool Matches = (Idx % Delta == 0) ? (Elt == ExtValue) : IsZeroOrUndef(Elt);
    if (!Matches)
      return SDValue();
  }

  // The extended element spans Delta original elements.
  unsigned ScalarBits = Op->getSimpleValueType(0).getScalarSizeInBits();
  EltType = MVT::getIntegerVT(ScalarBits * Delta);
  NumElt = NumElts / Delta;
  return ExtValue;
}
/// Attempt to use the vbroadcast instruction to generate a splat value
/// from a splat BUILD_VECTOR which uses:
/// a. A single scalar load, or a constant.
@ -8713,13 +8676,21 @@ static SDValue lowerBuildVectorAsBroadcast(BuildVectorSDNode *BVOp,
return SDValue();
MVT VT = BVOp->getSimpleValueType(0);
unsigned NumElts = VT.getVectorNumElements();
SDLoc dl(BVOp);
assert((VT.is128BitVector() || VT.is256BitVector() || VT.is512BitVector()) &&
"Unsupported vector type for broadcast.");
// See if the build vector is a repeating sequence of scalars (inc. splat).
SDValue Ld;
BitVector UndefElements;
SDValue Ld = BVOp->getSplatValue(&UndefElements);
SmallVector<SDValue, 16> Sequence;
if (BVOp->getRepeatedSequence(Sequence, &UndefElements)) {
assert((NumElts % Sequence.size()) == 0 && "Sequence doesn't fit.");
if (Sequence.size() == 1)
Ld = Sequence[0];
}
// Attempt to use VBROADCASTM
// From this pattern:
@ -8727,29 +8698,29 @@ static SDValue lowerBuildVectorAsBroadcast(BuildVectorSDNode *BVOp,
// b. t1 = (build_vector t0 t0)
//
// Create (VBROADCASTM v2i1 X)
if (Subtarget.hasCDI()) {
MVT EltType = VT.getScalarType();
unsigned NumElts = VT.getVectorNumElements();
SDValue BOperand;
SDValue ZeroExtended = isSplatZeroExtended(BVOp, NumElts, EltType);
if ((ZeroExtended && ZeroExtended.getOpcode() == ISD::BITCAST) ||
(ZeroExtended && ZeroExtended.getOpcode() == ISD::ZERO_EXTEND &&
ZeroExtended.getOperand(0).getOpcode() == ISD::BITCAST) ||
(Ld && Ld.getOpcode() == ISD::ZERO_EXTEND &&
Ld.getOperand(0).getOpcode() == ISD::BITCAST)) {
if (ZeroExtended && ZeroExtended.getOpcode() == ISD::BITCAST)
BOperand = ZeroExtended.getOperand(0);
else if (ZeroExtended)
BOperand = ZeroExtended.getOperand(0).getOperand(0);
else
BOperand = Ld.getOperand(0).getOperand(0);
if (!Sequence.empty() && Subtarget.hasCDI()) {
// If not a splat, are the upper sequence values zeroable?
unsigned SeqLen = Sequence.size();
bool UpperZeroOrUndef =
SeqLen == 1 ||
llvm::all_of(makeArrayRef(Sequence).drop_front(), [](SDValue V) {
return !V || V.isUndef() || isNullConstant(V);
});
SDValue Op0 = Sequence[0];
if (UpperZeroOrUndef && ((Op0.getOpcode() == ISD::BITCAST) ||
(Op0.getOpcode() == ISD::ZERO_EXTEND &&
Op0.getOperand(0).getOpcode() == ISD::BITCAST))) {
SDValue BOperand = Op0.getOpcode() == ISD::BITCAST
? Op0.getOperand(0)
: Op0.getOperand(0).getOperand(0);
MVT MaskVT = BOperand.getSimpleValueType();
MVT EltType = MVT::getIntegerVT(VT.getScalarSizeInBits() * SeqLen);
if ((EltType == MVT::i64 && MaskVT == MVT::v8i1) || // for broadcastmb2q
(EltType == MVT::i32 && MaskVT == MVT::v16i1)) { // for broadcastmw2d
MVT BcstVT = MVT::getVectorVT(EltType, NumElts);
MVT BcstVT = MVT::getVectorVT(EltType, NumElts / SeqLen);
if (!VT.is512BitVector() && !Subtarget.hasVLX()) {
unsigned Scale = 512 / VT.getSizeInBits();
BcstVT = MVT::getVectorVT(EltType, NumElts * Scale);
BcstVT = MVT::getVectorVT(EltType, Scale * (NumElts / SeqLen));
}
SDValue Bcst = DAG.getNode(X86ISD::VBROADCASTM, dl, BcstVT, BOperand);
if (BcstVT.getSizeInBits() != VT.getSizeInBits())
@ -8759,7 +8730,6 @@ static SDValue lowerBuildVectorAsBroadcast(BuildVectorSDNode *BVOp,
}
}
unsigned NumElts = VT.getVectorNumElements();
unsigned NumUndefElts = UndefElements.count();
if (!Ld || (NumElts - NumUndefElts) <= 1) {
APInt SplatValue, Undef;
@ -8833,6 +8803,8 @@ static SDValue lowerBuildVectorAsBroadcast(BuildVectorSDNode *BVOp,
(Ld.getOpcode() == ISD::Constant || Ld.getOpcode() == ISD::ConstantFP);
bool IsLoad = ISD::isNormalLoad(Ld.getNode());
// TODO: Handle broadcasts of non-constant sequences.
// Make sure that all of the users of a non-constant load are from the
// BUILD_VECTOR node.
// FIXME: Is the use count needed for non-constant, non-load case?

View File

@ -472,6 +472,127 @@ TEST_F(AArch64SelectionDAGTest, getSplatSourceVector_Scalable_ADD_of_SPLAT_VECTO
EXPECT_EQ(SplatIdx, 0);
}
// Exercises BuildVectorSDNode::getRepeatedSequence on v16i8 build vectors:
// splats, multi-element repeats, undef handling, a non-power-of-2 period, a
// non-repeating vector, and DemandedElts masks.
TEST_F(AArch64SelectionDAGTest, getRepeatedSequence_Patterns) {
  if (!TM)
    return;

  TargetLowering TL(*TM);

  SDLoc Loc;
  unsigned NumElts = 16;
  MVT IntVT = MVT::i8;
  MVT VecVT = MVT::getVectorVT(IntVT, NumElts);

  // Base scalar constants.
  SDValue Val0 = DAG->getConstant(0, Loc, IntVT);
  SDValue Val1 = DAG->getConstant(1, Loc, IntVT);
  SDValue Val2 = DAG->getConstant(2, Loc, IntVT);
  SDValue Val3 = DAG->getConstant(3, Loc, IntVT);
  SDValue UndefVal = DAG->getUNDEF(IntVT);

  // Build some repeating sequences.
  SmallVector<SDValue, 16> Pattern1111, Pattern1133, Pattern0123;
  for (int I = 0; I != 4; ++I) {
    Pattern1111.append(4, Val1);
    Pattern1133.append(2, Val1);
    Pattern1133.append(2, Val3);
    Pattern0123.push_back(Val0);
    Pattern0123.push_back(Val1);
    Pattern0123.push_back(Val2);
    Pattern0123.push_back(Val3);
  }

  // Build a non-pow2 repeating sequence: { 0, 2, 2 } with period 3.
  SmallVector<SDValue, 16> Pattern022;
  for (unsigned I = 0; I != NumElts; ++I)
    Pattern022.push_back((I % 3) == 0 ? Val0 : Val2);

  // Build a non-repeating sequence.
  SmallVector<SDValue, 16> Pattern1_3;
  Pattern1_3.append(8, Val1);
  Pattern1_3.append(8, Val3);

  // Add some undefs to make it trickier.
  Pattern1111[1] = Pattern1111[2] = Pattern1111[15] = UndefVal;
  Pattern1133[0] = Pattern1133[2] = UndefVal;

  auto *BV1111 =
      cast<BuildVectorSDNode>(DAG->getBuildVector(VecVT, Loc, Pattern1111));
  auto *BV1133 =
      cast<BuildVectorSDNode>(DAG->getBuildVector(VecVT, Loc, Pattern1133));
  auto *BV0123 =
      cast<BuildVectorSDNode>(DAG->getBuildVector(VecVT, Loc, Pattern0123));
  auto *BV022 =
      cast<BuildVectorSDNode>(DAG->getBuildVector(VecVT, Loc, Pattern022));
  auto *BV1_3 =
      cast<BuildVectorSDNode>(DAG->getBuildVector(VecVT, Loc, Pattern1_3));

  // Check for sequences. Counts and sizes are unsigned, so compare against
  // unsigned literals to keep gtest's comparisons free of -Wsign-compare.
  SmallVector<SDValue, 16> Seq1111, Seq1133, Seq0123, Seq022, Seq1_3;
  BitVector Undefs1111, Undefs1133, Undefs0123, Undefs022, Undefs1_3;

  EXPECT_TRUE(BV1111->getRepeatedSequence(Seq1111, &Undefs1111));
  EXPECT_EQ(Undefs1111.count(), 3u);
  EXPECT_EQ(Seq1111.size(), 1u);
  EXPECT_EQ(Seq1111[0], Val1);

  EXPECT_TRUE(BV1133->getRepeatedSequence(Seq1133, &Undefs1133));
  EXPECT_EQ(Undefs1133.count(), 2u);
  EXPECT_EQ(Seq1133.size(), 4u);
  EXPECT_EQ(Seq1133[0], Val1);
  EXPECT_EQ(Seq1133[1], Val1);
  EXPECT_EQ(Seq1133[2], Val3);
  EXPECT_EQ(Seq1133[3], Val3);

  EXPECT_TRUE(BV0123->getRepeatedSequence(Seq0123, &Undefs0123));
  EXPECT_EQ(Undefs0123.count(), 0u);
  EXPECT_EQ(Seq0123.size(), 4u);
  EXPECT_EQ(Seq0123[0], Val0);
  EXPECT_EQ(Seq0123[1], Val1);
  EXPECT_EQ(Seq0123[2], Val2);
  EXPECT_EQ(Seq0123[3], Val3);

  EXPECT_FALSE(BV022->getRepeatedSequence(Seq022, &Undefs022));
  EXPECT_FALSE(BV1_3->getRepeatedSequence(Seq1_3, &Undefs1_3));

  // Try again with DemandedElts masks.
  APInt Mask1111_0 = APInt::getOneBitSet(NumElts, 0);
  EXPECT_TRUE(BV1111->getRepeatedSequence(Mask1111_0, Seq1111, &Undefs1111));
  EXPECT_EQ(Undefs1111.count(), 0u);
  EXPECT_EQ(Seq1111.size(), 1u);
  EXPECT_EQ(Seq1111[0], Val1);

  // Only an undef element is demanded, so the sequence is that undef.
  APInt Mask1111_1 = APInt::getOneBitSet(NumElts, 2);
  EXPECT_TRUE(BV1111->getRepeatedSequence(Mask1111_1, Seq1111, &Undefs1111));
  EXPECT_EQ(Undefs1111.count(), 1u);
  EXPECT_EQ(Seq1111.size(), 1u);
  EXPECT_EQ(Seq1111[0], UndefVal);

  // Every fourth element is undemanded, leaving a null sequence entry.
  APInt Mask0123 = APInt(NumElts, 0x7777);
  EXPECT_TRUE(BV0123->getRepeatedSequence(Mask0123, Seq0123, &Undefs0123));
  EXPECT_EQ(Undefs0123.count(), 0u);
  EXPECT_EQ(Seq0123.size(), 4u);
  EXPECT_EQ(Seq0123[0], Val0);
  EXPECT_EQ(Seq0123[1], Val1);
  EXPECT_EQ(Seq0123[2], Val2);
  EXPECT_EQ(Seq0123[3], SDValue());

  // Demanding only the upper half turns the non-repeating vector into a splat.
  APInt Mask1_3 = APInt::getHighBitsSet(NumElts, 8);
  EXPECT_TRUE(BV1_3->getRepeatedSequence(Mask1_3, Seq1_3, &Undefs1_3));
  EXPECT_EQ(Undefs1_3.count(), 0u);
  EXPECT_EQ(Seq1_3.size(), 1u);
  EXPECT_EQ(Seq1_3[0], Val3);
}
TEST_F(AArch64SelectionDAGTest, getTypeConversion_SplitScalableMVT) {
if (!TM)
return;