Revert "[X86][AVX] Add getBROADCAST_LOAD helper function. NFCI."

This reverts commit 1cfecf4fc4.

This commit broke LLVM code generated through XLA by removing a
conditional on Ld->getExtensionType() == ISD::NON_EXTLOAD

This is not a perfect revert. The new function is left as other uses of
it exist now.
This commit is contained in:
Tres Popp 2021-07-27 16:49:15 +02:00
parent 70fa9479b2
commit d225de60c9
1 changed files with 60 additions and 25 deletions

View File

@ -16084,12 +16084,21 @@ static SDValue lowerV2X128Shuffle(const SDLoc &DL, MVT VT, SDValue V1,
bool SplatHi = isShuffleEquivalent(Mask, {2, 3, 2, 3}, V1);
if ((SplatLo || SplatHi) && !Subtarget.hasAVX512() && V1.hasOneUse() &&
MayFoldLoad(peekThroughOneUseBitcasts(V1))) {
MVT MemVT = VT.getHalfNumVectorElementsVT();
unsigned Ofs = SplatLo ? 0 : MemVT.getStoreSize();
auto *Ld = cast<LoadSDNode>(peekThroughOneUseBitcasts(V1));
if (SDValue BcstLd = getBROADCAST_LOAD(X86ISD::SUBV_BROADCAST_LOAD, DL,
VT, MemVT, Ld, Ofs, DAG))
return BcstLd;
if (!Ld->isNonTemporal()) {
MVT MemVT = VT.getHalfNumVectorElementsVT();
unsigned Ofs = SplatLo ? 0 : MemVT.getStoreSize();
SDVTList Tys = DAG.getVTList(VT, MVT::Other);
SDValue Ptr = DAG.getMemBasePlusOffset(Ld->getBasePtr(),
TypeSize::Fixed(Ofs), DL);
SDValue Ops[] = {Ld->getChain(), Ptr};
SDValue BcastLd = DAG.getMemIntrinsicNode(
X86ISD::SUBV_BROADCAST_LOAD, DL, Tys, Ops, MemVT,
DAG.getMachineFunction().getMachineMemOperand(
Ld->getMemOperand(), Ofs, MemVT.getStoreSize()));
DAG.ReplaceAllUsesOfValueWith(SDValue(Ld, 1), BcastLd.getValue(1));
return BcastLd;
}
}
// With AVX2, use VPERMQ/VPERMPD for unary shuffles to allow memory folding.
@ -38011,7 +38020,7 @@ static SDValue combineTargetShuffle(SDValue N, SelectionDAG &DAG,
return Res;
// Fold vperm2x128 subvector shuffle with an inner concat pattern.
// vperm2x128(concat(X,Y),concat(Z,W)) --> concat X,Y etc.
// vperm2x128(concat(X,Y),concat(Z,W)) --> concat X,Y etc.
auto FindSubVector128 = [&](unsigned Idx) {
if (Idx > 3)
return SDValue();
@ -38992,10 +39001,10 @@ bool X86TargetLowering::SimplifyDemandedVectorEltsForTargetNode(
}
// Subvector broadcast.
case X86ISD::SUBV_BROADCAST_LOAD: {
SDLoc DL(Op);
auto *MemIntr = cast<MemIntrinsicSDNode>(Op);
EVT MemVT = MemIntr->getMemoryVT();
if (ExtSizeInBits == MemVT.getStoreSizeInBits()) {
SDLoc DL(Op);
SDValue Ld =
TLO.DAG.getLoad(MemVT, DL, MemIntr->getChain(),
MemIntr->getBasePtr(), MemIntr->getMemOperand());
@ -39004,13 +39013,18 @@ bool X86TargetLowering::SimplifyDemandedVectorEltsForTargetNode(
return TLO.CombineTo(Op, insertSubVector(TLO.DAG.getUNDEF(VT), Ld, 0,
TLO.DAG, DL, ExtSizeInBits));
} else if ((ExtSizeInBits % MemVT.getStoreSizeInBits()) == 0) {
SDLoc DL(Op);
EVT BcstVT = EVT::getVectorVT(*TLO.DAG.getContext(), VT.getScalarType(),
ExtSizeInBits / VT.getScalarSizeInBits());
if (SDValue BcstLd =
getBROADCAST_LOAD(Opc, DL, BcstVT, MemVT, MemIntr, 0, TLO.DAG))
return TLO.CombineTo(Op,
insertSubVector(TLO.DAG.getUNDEF(VT), BcstLd, 0,
TLO.DAG, DL, ExtSizeInBits));
SDVTList Tys = TLO.DAG.getVTList(BcstVT, MVT::Other);
SDValue Ops[] = {MemIntr->getOperand(0), MemIntr->getOperand(1)};
SDValue Bcst =
TLO.DAG.getMemIntrinsicNode(X86ISD::SUBV_BROADCAST_LOAD, DL, Tys,
Ops, MemVT, MemIntr->getMemOperand());
TLO.DAG.makeEquivalentMemoryOrdering(SDValue(MemIntr, 1),
Bcst.getValue(1));
return TLO.CombineTo(Op, insertSubVector(TLO.DAG.getUNDEF(VT), Bcst, 0,
TLO.DAG, DL, ExtSizeInBits));
}
break;
}
@ -50083,21 +50097,36 @@ static SDValue combineConcatVectorOps(const SDLoc &DL, MVT VT,
if (Op0.getOpcode() == X86ISD::VBROADCAST)
return DAG.getNode(Op0.getOpcode(), DL, VT, Op0.getOperand(0));
// If this simple subvector or scalar/subvector broadcast_load is inserted
// into both halves, use a larger broadcast_load. Update other uses to use
// an extracted subvector.
if (Op0.getOpcode() == ISD::LOAD ||
Op0.getOpcode() == X86ISD::VBROADCAST_LOAD ||
// If this scalar/subvector broadcast_load is inserted into both halves, use
// a larger broadcast_load. Update other uses to use an extracted subvector.
if (Op0.getOpcode() == X86ISD::VBROADCAST_LOAD ||
Op0.getOpcode() == X86ISD::SUBV_BROADCAST_LOAD) {
auto *Mem = cast<MemSDNode>(Op0);
unsigned Opcode = Op0.getOpcode() == X86ISD::VBROADCAST_LOAD
? X86ISD::VBROADCAST_LOAD
: X86ISD::SUBV_BROADCAST_LOAD;
if (SDValue BcastLd = getBROADCAST_LOAD(
Opcode, DL, VT, Mem->getMemoryVT(), Mem, 0, DAG)) {
auto *MemIntr = cast<MemIntrinsicSDNode>(Op0);
SDVTList Tys = DAG.getVTList(VT, MVT::Other);
SDValue Ops[] = {MemIntr->getChain(), MemIntr->getBasePtr()};
SDValue BcastLd = DAG.getMemIntrinsicNode(Op0.getOpcode(), DL, Tys, Ops,
MemIntr->getMemoryVT(),
MemIntr->getMemOperand());
DAG.ReplaceAllUsesOfValueWith(
Op0, extractSubVector(BcastLd, 0, DAG, DL, Op0.getValueSizeInBits()));
DAG.ReplaceAllUsesOfValueWith(SDValue(MemIntr, 1), BcastLd.getValue(1));
return BcastLd;
}
// If this is a simple subvector load repeated across multiple lanes, then
// broadcast the load. Update other uses to use an extracted subvector.
if (auto *Ld = dyn_cast<LoadSDNode>(Op0)) {
if (Ld->isSimple() && !Ld->isNonTemporal() &&
Ld->getExtensionType() == ISD::NON_EXTLOAD) {
SDVTList Tys = DAG.getVTList(VT, MVT::Other);
SDValue Ops[] = {Ld->getChain(), Ld->getBasePtr()};
SDValue BcastLd =
DAG.getMemIntrinsicNode(X86ISD::SUBV_BROADCAST_LOAD, DL, Tys, Ops,
Ld->getMemoryVT(), Ld->getMemOperand());
DAG.ReplaceAllUsesOfValueWith(
Op0,
extractSubVector(BcastLd, 0, DAG, DL, Op0.getValueSizeInBits()));
DAG.ReplaceAllUsesOfValueWith(SDValue(Ld, 1), BcastLd.getValue(1));
return BcastLd;
}
}
@ -50461,8 +50490,14 @@ static SDValue combineInsertSubvector(SDNode *N, SelectionDAG &DAG,
if (Vec.isUndef() && IdxVal != 0 && SubVec.hasOneUse() &&
SubVec.getOpcode() == X86ISD::VBROADCAST_LOAD) {
auto *MemIntr = cast<MemIntrinsicSDNode>(SubVec);
return getBROADCAST_LOAD(X86ISD::VBROADCAST_LOAD, dl, OpVT,
MemIntr->getMemoryVT(), MemIntr, 0, DAG);
SDVTList Tys = DAG.getVTList(OpVT, MVT::Other);
SDValue Ops[] = { MemIntr->getChain(), MemIntr->getBasePtr() };
SDValue BcastLd =
DAG.getMemIntrinsicNode(X86ISD::VBROADCAST_LOAD, dl, Tys, Ops,
MemIntr->getMemoryVT(),
MemIntr->getMemOperand());
DAG.ReplaceAllUsesOfValueWith(SDValue(MemIntr, 1), BcastLd.getValue(1));
return BcastLd;
}
// If we're splatting the lower half subvector of a full vector load into the