[SelectionDAG] Rename memory VT argument for getMaskedGather/getMaskedScatter from VT to MemVT.

Use getMemoryVT() in MGATHER/MSCATTER DAG combines instead of
using the passthru or store value VT for this argument.
This commit is contained in:
Craig Topper 2021-07-02 17:36:27 -07:00
parent 252a1eecc0
commit af331e8284
3 changed files with 16 additions and 16 deletions

View File

@ -1316,10 +1316,10 @@ public:
SDValue getIndexedMaskedStore(SDValue OrigStore, const SDLoc &dl, SDValue getIndexedMaskedStore(SDValue OrigStore, const SDLoc &dl,
SDValue Base, SDValue Offset, SDValue Base, SDValue Offset,
ISD::MemIndexedMode AM); ISD::MemIndexedMode AM);
SDValue getMaskedGather(SDVTList VTs, EVT VT, const SDLoc &dl, SDValue getMaskedGather(SDVTList VTs, EVT MemVT, const SDLoc &dl,
ArrayRef<SDValue> Ops, MachineMemOperand *MMO, ArrayRef<SDValue> Ops, MachineMemOperand *MMO,
ISD::MemIndexType IndexType, ISD::LoadExtType ExtTy); ISD::MemIndexType IndexType, ISD::LoadExtType ExtTy);
SDValue getMaskedScatter(SDVTList VTs, EVT VT, const SDLoc &dl, SDValue getMaskedScatter(SDVTList VTs, EVT MemVT, const SDLoc &dl,
ArrayRef<SDValue> Ops, MachineMemOperand *MMO, ArrayRef<SDValue> Ops, MachineMemOperand *MMO,
ISD::MemIndexType IndexType, ISD::MemIndexType IndexType,
bool IsTruncating = false); bool IsTruncating = false);

View File

@ -9736,14 +9736,14 @@ SDValue DAGCombiner::visitMSCATTER(SDNode *N) {
if (refineUniformBase(BasePtr, Index, DAG)) { if (refineUniformBase(BasePtr, Index, DAG)) {
SDValue Ops[] = {Chain, StoreVal, Mask, BasePtr, Index, Scale}; SDValue Ops[] = {Chain, StoreVal, Mask, BasePtr, Index, Scale};
return DAG.getMaskedScatter( return DAG.getMaskedScatter(
DAG.getVTList(MVT::Other), StoreVal.getValueType(), DL, Ops, DAG.getVTList(MVT::Other), MSC->getMemoryVT(), DL, Ops,
MSC->getMemOperand(), MSC->getIndexType(), MSC->isTruncatingStore()); MSC->getMemOperand(), MSC->getIndexType(), MSC->isTruncatingStore());
} }
if (refineIndexType(MSC, Index, MSC->isIndexScaled(), DAG)) { if (refineIndexType(MSC, Index, MSC->isIndexScaled(), DAG)) {
SDValue Ops[] = {Chain, StoreVal, Mask, BasePtr, Index, Scale}; SDValue Ops[] = {Chain, StoreVal, Mask, BasePtr, Index, Scale};
return DAG.getMaskedScatter( return DAG.getMaskedScatter(
DAG.getVTList(MVT::Other), StoreVal.getValueType(), DL, Ops, DAG.getVTList(MVT::Other), MSC->getMemoryVT(), DL, Ops,
MSC->getMemOperand(), MSC->getIndexType(), MSC->isTruncatingStore()); MSC->getMemOperand(), MSC->getIndexType(), MSC->isTruncatingStore());
} }
@ -9792,7 +9792,7 @@ SDValue DAGCombiner::visitMGATHER(SDNode *N) {
if (refineUniformBase(BasePtr, Index, DAG)) { if (refineUniformBase(BasePtr, Index, DAG)) {
SDValue Ops[] = {Chain, PassThru, Mask, BasePtr, Index, Scale}; SDValue Ops[] = {Chain, PassThru, Mask, BasePtr, Index, Scale};
return DAG.getMaskedGather(DAG.getVTList(N->getValueType(0), MVT::Other), return DAG.getMaskedGather(DAG.getVTList(N->getValueType(0), MVT::Other),
PassThru.getValueType(), DL, Ops, MGT->getMemoryVT(), DL, Ops,
MGT->getMemOperand(), MGT->getIndexType(), MGT->getMemOperand(), MGT->getIndexType(),
MGT->getExtensionType()); MGT->getExtensionType());
} }
@ -9800,7 +9800,7 @@ SDValue DAGCombiner::visitMGATHER(SDNode *N) {
if (refineIndexType(MGT, Index, MGT->isIndexScaled(), DAG)) { if (refineIndexType(MGT, Index, MGT->isIndexScaled(), DAG)) {
SDValue Ops[] = {Chain, PassThru, Mask, BasePtr, Index, Scale}; SDValue Ops[] = {Chain, PassThru, Mask, BasePtr, Index, Scale};
return DAG.getMaskedGather(DAG.getVTList(N->getValueType(0), MVT::Other), return DAG.getMaskedGather(DAG.getVTList(N->getValueType(0), MVT::Other),
PassThru.getValueType(), DL, Ops, MGT->getMemoryVT(), DL, Ops,
MGT->getMemOperand(), MGT->getIndexType(), MGT->getMemOperand(), MGT->getIndexType(),
MGT->getExtensionType()); MGT->getExtensionType());
} }

View File

@ -7644,7 +7644,7 @@ SDValue SelectionDAG::getIndexedMaskedStore(SDValue OrigStore, const SDLoc &dl,
AM, ST->isTruncatingStore(), ST->isCompressingStore()); AM, ST->isTruncatingStore(), ST->isCompressingStore());
} }
SDValue SelectionDAG::getMaskedGather(SDVTList VTs, EVT VT, const SDLoc &dl, SDValue SelectionDAG::getMaskedGather(SDVTList VTs, EVT MemVT, const SDLoc &dl,
ArrayRef<SDValue> Ops, ArrayRef<SDValue> Ops,
MachineMemOperand *MMO, MachineMemOperand *MMO,
ISD::MemIndexType IndexType, ISD::MemIndexType IndexType,
@ -7653,9 +7653,9 @@ SDValue SelectionDAG::getMaskedGather(SDVTList VTs, EVT VT, const SDLoc &dl,
FoldingSetNodeID ID; FoldingSetNodeID ID;
AddNodeIDNode(ID, ISD::MGATHER, VTs, Ops); AddNodeIDNode(ID, ISD::MGATHER, VTs, Ops);
ID.AddInteger(VT.getRawBits()); ID.AddInteger(MemVT.getRawBits());
ID.AddInteger(getSyntheticNodeSubclassData<MaskedGatherSDNode>( ID.AddInteger(getSyntheticNodeSubclassData<MaskedGatherSDNode>(
dl.getIROrder(), VTs, VT, MMO, IndexType, ExtTy)); dl.getIROrder(), VTs, MemVT, MMO, IndexType, ExtTy));
ID.AddInteger(MMO->getPointerInfo().getAddrSpace()); ID.AddInteger(MMO->getPointerInfo().getAddrSpace());
void *IP = nullptr; void *IP = nullptr;
if (SDNode *E = FindNodeOrInsertPos(ID, dl, IP)) { if (SDNode *E = FindNodeOrInsertPos(ID, dl, IP)) {
@ -7663,9 +7663,9 @@ SDValue SelectionDAG::getMaskedGather(SDVTList VTs, EVT VT, const SDLoc &dl,
return SDValue(E, 0); return SDValue(E, 0);
} }
IndexType = TLI->getCanonicalIndexType(IndexType, VT, Ops[4]); IndexType = TLI->getCanonicalIndexType(IndexType, MemVT, Ops[4]);
auto *N = newSDNode<MaskedGatherSDNode>(dl.getIROrder(), dl.getDebugLoc(), auto *N = newSDNode<MaskedGatherSDNode>(dl.getIROrder(), dl.getDebugLoc(),
VTs, VT, MMO, IndexType, ExtTy); VTs, MemVT, MMO, IndexType, ExtTy);
createOperands(N, Ops); createOperands(N, Ops);
assert(N->getPassThru().getValueType() == N->getValueType(0) && assert(N->getPassThru().getValueType() == N->getValueType(0) &&
@ -7691,7 +7691,7 @@ SDValue SelectionDAG::getMaskedGather(SDVTList VTs, EVT VT, const SDLoc &dl,
return V; return V;
} }
SDValue SelectionDAG::getMaskedScatter(SDVTList VTs, EVT VT, const SDLoc &dl, SDValue SelectionDAG::getMaskedScatter(SDVTList VTs, EVT MemVT, const SDLoc &dl,
ArrayRef<SDValue> Ops, ArrayRef<SDValue> Ops,
MachineMemOperand *MMO, MachineMemOperand *MMO,
ISD::MemIndexType IndexType, ISD::MemIndexType IndexType,
@ -7700,9 +7700,9 @@ SDValue SelectionDAG::getMaskedScatter(SDVTList VTs, EVT VT, const SDLoc &dl,
FoldingSetNodeID ID; FoldingSetNodeID ID;
AddNodeIDNode(ID, ISD::MSCATTER, VTs, Ops); AddNodeIDNode(ID, ISD::MSCATTER, VTs, Ops);
ID.AddInteger(VT.getRawBits()); ID.AddInteger(MemVT.getRawBits());
ID.AddInteger(getSyntheticNodeSubclassData<MaskedScatterSDNode>( ID.AddInteger(getSyntheticNodeSubclassData<MaskedScatterSDNode>(
dl.getIROrder(), VTs, VT, MMO, IndexType, IsTrunc)); dl.getIROrder(), VTs, MemVT, MMO, IndexType, IsTrunc));
ID.AddInteger(MMO->getPointerInfo().getAddrSpace()); ID.AddInteger(MMO->getPointerInfo().getAddrSpace());
void *IP = nullptr; void *IP = nullptr;
if (SDNode *E = FindNodeOrInsertPos(ID, dl, IP)) { if (SDNode *E = FindNodeOrInsertPos(ID, dl, IP)) {
@ -7710,9 +7710,9 @@ SDValue SelectionDAG::getMaskedScatter(SDVTList VTs, EVT VT, const SDLoc &dl,
return SDValue(E, 0); return SDValue(E, 0);
} }
IndexType = TLI->getCanonicalIndexType(IndexType, VT, Ops[4]); IndexType = TLI->getCanonicalIndexType(IndexType, MemVT, Ops[4]);
auto *N = newSDNode<MaskedScatterSDNode>(dl.getIROrder(), dl.getDebugLoc(), auto *N = newSDNode<MaskedScatterSDNode>(dl.getIROrder(), dl.getDebugLoc(),
VTs, VT, MMO, IndexType, IsTrunc); VTs, MemVT, MMO, IndexType, IsTrunc);
createOperands(N, Ops); createOperands(N, Ops);
assert(N->getMask().getValueType().getVectorElementCount() == assert(N->getMask().getValueType().getVectorElementCount() ==