[SelectionDAG] Rename memory VT argument for getMaskedGather/getMaskedScatter from VT to MemVT.

Use getMemoryVT() in MGATHER/MSCATTER DAG combines instead of
using the passthru or store value VT for this argument.
This commit is contained in:
Craig Topper 2021-07-02 17:36:27 -07:00
parent 252a1eecc0
commit af331e8284
3 changed files with 16 additions and 16 deletions

View File

@ -1316,10 +1316,10 @@ public:
SDValue getIndexedMaskedStore(SDValue OrigStore, const SDLoc &dl,
SDValue Base, SDValue Offset,
ISD::MemIndexedMode AM);
SDValue getMaskedGather(SDVTList VTs, EVT VT, const SDLoc &dl,
SDValue getMaskedGather(SDVTList VTs, EVT MemVT, const SDLoc &dl,
ArrayRef<SDValue> Ops, MachineMemOperand *MMO,
ISD::MemIndexType IndexType, ISD::LoadExtType ExtTy);
SDValue getMaskedScatter(SDVTList VTs, EVT VT, const SDLoc &dl,
SDValue getMaskedScatter(SDVTList VTs, EVT MemVT, const SDLoc &dl,
ArrayRef<SDValue> Ops, MachineMemOperand *MMO,
ISD::MemIndexType IndexType,
bool IsTruncating = false);

View File

@ -9736,14 +9736,14 @@ SDValue DAGCombiner::visitMSCATTER(SDNode *N) {
if (refineUniformBase(BasePtr, Index, DAG)) {
SDValue Ops[] = {Chain, StoreVal, Mask, BasePtr, Index, Scale};
return DAG.getMaskedScatter(
DAG.getVTList(MVT::Other), StoreVal.getValueType(), DL, Ops,
DAG.getVTList(MVT::Other), MSC->getMemoryVT(), DL, Ops,
MSC->getMemOperand(), MSC->getIndexType(), MSC->isTruncatingStore());
}
if (refineIndexType(MSC, Index, MSC->isIndexScaled(), DAG)) {
SDValue Ops[] = {Chain, StoreVal, Mask, BasePtr, Index, Scale};
return DAG.getMaskedScatter(
DAG.getVTList(MVT::Other), StoreVal.getValueType(), DL, Ops,
DAG.getVTList(MVT::Other), MSC->getMemoryVT(), DL, Ops,
MSC->getMemOperand(), MSC->getIndexType(), MSC->isTruncatingStore());
}
@ -9792,7 +9792,7 @@ SDValue DAGCombiner::visitMGATHER(SDNode *N) {
if (refineUniformBase(BasePtr, Index, DAG)) {
SDValue Ops[] = {Chain, PassThru, Mask, BasePtr, Index, Scale};
return DAG.getMaskedGather(DAG.getVTList(N->getValueType(0), MVT::Other),
PassThru.getValueType(), DL, Ops,
MGT->getMemoryVT(), DL, Ops,
MGT->getMemOperand(), MGT->getIndexType(),
MGT->getExtensionType());
}
@ -9800,7 +9800,7 @@ SDValue DAGCombiner::visitMGATHER(SDNode *N) {
if (refineIndexType(MGT, Index, MGT->isIndexScaled(), DAG)) {
SDValue Ops[] = {Chain, PassThru, Mask, BasePtr, Index, Scale};
return DAG.getMaskedGather(DAG.getVTList(N->getValueType(0), MVT::Other),
PassThru.getValueType(), DL, Ops,
MGT->getMemoryVT(), DL, Ops,
MGT->getMemOperand(), MGT->getIndexType(),
MGT->getExtensionType());
}

View File

@ -7644,7 +7644,7 @@ SDValue SelectionDAG::getIndexedMaskedStore(SDValue OrigStore, const SDLoc &dl,
AM, ST->isTruncatingStore(), ST->isCompressingStore());
}
SDValue SelectionDAG::getMaskedGather(SDVTList VTs, EVT VT, const SDLoc &dl,
SDValue SelectionDAG::getMaskedGather(SDVTList VTs, EVT MemVT, const SDLoc &dl,
ArrayRef<SDValue> Ops,
MachineMemOperand *MMO,
ISD::MemIndexType IndexType,
@ -7653,9 +7653,9 @@ SDValue SelectionDAG::getMaskedGather(SDVTList VTs, EVT VT, const SDLoc &dl,
FoldingSetNodeID ID;
AddNodeIDNode(ID, ISD::MGATHER, VTs, Ops);
ID.AddInteger(VT.getRawBits());
ID.AddInteger(MemVT.getRawBits());
ID.AddInteger(getSyntheticNodeSubclassData<MaskedGatherSDNode>(
dl.getIROrder(), VTs, VT, MMO, IndexType, ExtTy));
dl.getIROrder(), VTs, MemVT, MMO, IndexType, ExtTy));
ID.AddInteger(MMO->getPointerInfo().getAddrSpace());
void *IP = nullptr;
if (SDNode *E = FindNodeOrInsertPos(ID, dl, IP)) {
@ -7663,9 +7663,9 @@ SDValue SelectionDAG::getMaskedGather(SDVTList VTs, EVT VT, const SDLoc &dl,
return SDValue(E, 0);
}
IndexType = TLI->getCanonicalIndexType(IndexType, VT, Ops[4]);
IndexType = TLI->getCanonicalIndexType(IndexType, MemVT, Ops[4]);
auto *N = newSDNode<MaskedGatherSDNode>(dl.getIROrder(), dl.getDebugLoc(),
VTs, VT, MMO, IndexType, ExtTy);
VTs, MemVT, MMO, IndexType, ExtTy);
createOperands(N, Ops);
assert(N->getPassThru().getValueType() == N->getValueType(0) &&
@ -7691,7 +7691,7 @@ SDValue SelectionDAG::getMaskedGather(SDVTList VTs, EVT VT, const SDLoc &dl,
return V;
}
SDValue SelectionDAG::getMaskedScatter(SDVTList VTs, EVT VT, const SDLoc &dl,
SDValue SelectionDAG::getMaskedScatter(SDVTList VTs, EVT MemVT, const SDLoc &dl,
ArrayRef<SDValue> Ops,
MachineMemOperand *MMO,
ISD::MemIndexType IndexType,
@ -7700,9 +7700,9 @@ SDValue SelectionDAG::getMaskedScatter(SDVTList VTs, EVT VT, const SDLoc &dl,
FoldingSetNodeID ID;
AddNodeIDNode(ID, ISD::MSCATTER, VTs, Ops);
ID.AddInteger(VT.getRawBits());
ID.AddInteger(MemVT.getRawBits());
ID.AddInteger(getSyntheticNodeSubclassData<MaskedScatterSDNode>(
dl.getIROrder(), VTs, VT, MMO, IndexType, IsTrunc));
dl.getIROrder(), VTs, MemVT, MMO, IndexType, IsTrunc));
ID.AddInteger(MMO->getPointerInfo().getAddrSpace());
void *IP = nullptr;
if (SDNode *E = FindNodeOrInsertPos(ID, dl, IP)) {
@ -7710,9 +7710,9 @@ SDValue SelectionDAG::getMaskedScatter(SDVTList VTs, EVT VT, const SDLoc &dl,
return SDValue(E, 0);
}
IndexType = TLI->getCanonicalIndexType(IndexType, VT, Ops[4]);
IndexType = TLI->getCanonicalIndexType(IndexType, MemVT, Ops[4]);
auto *N = newSDNode<MaskedScatterSDNode>(dl.getIROrder(), dl.getDebugLoc(),
VTs, VT, MMO, IndexType, IsTrunc);
VTs, MemVT, MMO, IndexType, IsTrunc);
createOperands(N, Ops);
assert(N->getMask().getValueType().getVectorElementCount() ==