[SVE][CodeGen] Add the ExtensionType flag to MGATHER

Adds the ExtensionType flag, which reflects the LoadExtType of a MaskedGatherSDNode.
Also updated SelectionDAGDumper::print_details so that details of the gather
load (is signed, is scaled & extension type) are printed.

Reviewed By: sdesmalen

Differential Revision: https://reviews.llvm.org/D91084
This commit is contained in:
Kerry McLaughlin 2020-12-09 10:49:43 +00:00
parent 0bf4a82a5a
commit 4519ff4b6f
9 changed files with 56 additions and 17 deletions

View File

@ -1362,7 +1362,7 @@ public:
ISD::MemIndexedMode AM);
SDValue getMaskedGather(SDVTList VTs, EVT VT, const SDLoc &dl,
ArrayRef<SDValue> Ops, MachineMemOperand *MMO,
ISD::MemIndexType IndexType);
ISD::MemIndexType IndexType, ISD::LoadExtType ExtTy);
SDValue getMaskedScatter(SDVTList VTs, EVT VT, const SDLoc &dl,
ArrayRef<SDValue> Ops, MachineMemOperand *MMO,
ISD::MemIndexType IndexType,

View File

@ -512,6 +512,7 @@ BEGIN_TWO_BYTE_PACK()
class LoadSDNodeBitfields {
friend class LoadSDNode;
friend class MaskedLoadSDNode;
friend class MaskedGatherSDNode;
uint16_t : NumLSBaseSDNodeBits;
@ -2451,12 +2452,18 @@ public:
MaskedGatherSDNode(unsigned Order, const DebugLoc &dl, SDVTList VTs,
EVT MemVT, MachineMemOperand *MMO,
ISD::MemIndexType IndexType)
ISD::MemIndexType IndexType, ISD::LoadExtType ETy)
: MaskedGatherScatterSDNode(ISD::MGATHER, Order, dl, VTs, MemVT, MMO,
IndexType) {}
IndexType) {
LoadSDNodeBits.ExtTy = ETy;
}
const SDValue &getPassThru() const { return getOperand(1); }
ISD::LoadExtType getExtensionType() const {
return ISD::LoadExtType(LoadSDNodeBits.ExtTy);
}
static bool classof(const SDNode *N) {
return N->getOpcode() == ISD::MGATHER;
}

View File

@ -9499,14 +9499,16 @@ SDValue DAGCombiner::visitMGATHER(SDNode *N) {
SDValue Ops[] = {Chain, PassThru, Mask, BasePtr, Index, Scale};
return DAG.getMaskedGather(DAG.getVTList(N->getValueType(0), MVT::Other),
PassThru.getValueType(), DL, Ops,
MGT->getMemOperand(), MGT->getIndexType());
MGT->getMemOperand(), MGT->getIndexType(),
MGT->getExtensionType());
}
if (refineIndexType(MGT, Index, MGT->isIndexScaled(), DAG)) {
SDValue Ops[] = {Chain, PassThru, Mask, BasePtr, Index, Scale};
return DAG.getMaskedGather(DAG.getVTList(N->getValueType(0), MVT::Other),
PassThru.getValueType(), DL, Ops,
MGT->getMemOperand(), MGT->getIndexType());
MGT->getMemOperand(), MGT->getIndexType(),
MGT->getExtensionType());
}
return SDValue();

View File

@ -679,12 +679,17 @@ SDValue DAGTypeLegalizer::PromoteIntRes_MGATHER(MaskedGatherSDNode *N) {
assert(NVT == ExtPassThru.getValueType() &&
"Gather result type and the passThru argument type should be the same");
ISD::LoadExtType ExtType = N->getExtensionType();
if (ExtType == ISD::NON_EXTLOAD)
ExtType = ISD::EXTLOAD;
SDLoc dl(N);
SDValue Ops[] = {N->getChain(), ExtPassThru, N->getMask(), N->getBasePtr(),
N->getIndex(), N->getScale() };
SDValue Res = DAG.getMaskedGather(DAG.getVTList(NVT, MVT::Other),
N->getMemoryVT(), dl, Ops,
N->getMemOperand(), N->getIndexType());
N->getMemOperand(), N->getIndexType(),
ExtType);
// Legalize the chain result - switch anything that used the old chain to
// use the new one.
ReplaceValueWith(SDValue(N, 1), Res.getValue(1));

View File

@ -1748,6 +1748,7 @@ void DAGTypeLegalizer::SplitVecRes_MGATHER(MaskedGatherSDNode *MGT,
SDValue Scale = MGT->getScale();
EVT MemoryVT = MGT->getMemoryVT();
Align Alignment = MGT->getOriginalAlign();
ISD::LoadExtType ExtType = MGT->getExtensionType();
// Split Mask operand
SDValue MaskLo, MaskHi;
@ -1783,11 +1784,11 @@ void DAGTypeLegalizer::SplitVecRes_MGATHER(MaskedGatherSDNode *MGT,
SDValue OpsLo[] = {Ch, PassThruLo, MaskLo, Ptr, IndexLo, Scale};
Lo = DAG.getMaskedGather(DAG.getVTList(LoVT, MVT::Other), LoMemVT, dl, OpsLo,
MMO, MGT->getIndexType());
MMO, MGT->getIndexType(), ExtType);
SDValue OpsHi[] = {Ch, PassThruHi, MaskHi, Ptr, IndexHi, Scale};
Hi = DAG.getMaskedGather(DAG.getVTList(HiVT, MVT::Other), HiMemVT, dl, OpsHi,
MMO, MGT->getIndexType());
MMO, MGT->getIndexType(), ExtType);
// Build a factor node to remember that this load is independent of the
// other one.
@ -2392,6 +2393,7 @@ SDValue DAGTypeLegalizer::SplitVecOp_MGATHER(MaskedGatherSDNode *MGT,
SDValue Mask = MGT->getMask();
SDValue PassThru = MGT->getPassThru();
Align Alignment = MGT->getOriginalAlign();
ISD::LoadExtType ExtType = MGT->getExtensionType();
SDValue MaskLo, MaskHi;
if (getTypeAction(Mask.getValueType()) == TargetLowering::TypeSplitVector)
@ -2423,11 +2425,11 @@ SDValue DAGTypeLegalizer::SplitVecOp_MGATHER(MaskedGatherSDNode *MGT,
SDValue OpsLo[] = {Ch, PassThruLo, MaskLo, Ptr, IndexLo, Scale};
SDValue Lo = DAG.getMaskedGather(DAG.getVTList(LoVT, MVT::Other), LoMemVT, dl,
OpsLo, MMO, MGT->getIndexType());
OpsLo, MMO, MGT->getIndexType(), ExtType);
SDValue OpsHi[] = {Ch, PassThruHi, MaskHi, Ptr, IndexHi, Scale};
SDValue Hi = DAG.getMaskedGather(DAG.getVTList(HiVT, MVT::Other), HiMemVT, dl,
OpsHi, MMO, MGT->getIndexType());
OpsHi, MMO, MGT->getIndexType(), ExtType);
// Build a factor node to remember that this load is independent of the
// other one.
@ -3928,7 +3930,8 @@ SDValue DAGTypeLegalizer::WidenVecRes_MGATHER(MaskedGatherSDNode *N) {
Scale };
SDValue Res = DAG.getMaskedGather(DAG.getVTList(WideVT, MVT::Other),
N->getMemoryVT(), dl, Ops,
N->getMemOperand(), N->getIndexType());
N->getMemOperand(), N->getIndexType(),
N->getExtensionType());
// Legalize the chain result - switch anything that used the old chain to
// use the new one.
@ -4722,7 +4725,8 @@ SDValue DAGTypeLegalizer::WidenVecOp_MGATHER(SDNode *N, unsigned OpNo) {
SDValue Ops[] = {MG->getChain(), DataOp, Mask, MG->getBasePtr(), Index,
Scale};
SDValue Res = DAG.getMaskedGather(MG->getVTList(), MG->getMemoryVT(), dl, Ops,
MG->getMemOperand(), MG->getIndexType());
MG->getMemOperand(), MG->getIndexType(),
MG->getExtensionType());
ReplaceValueWith(SDValue(N, 1), Res.getValue(1));
ReplaceValueWith(SDValue(N, 0), Res.getValue(0));
return SDValue();

View File

@ -7295,14 +7295,15 @@ SDValue SelectionDAG::getIndexedMaskedStore(SDValue OrigStore, const SDLoc &dl,
SDValue SelectionDAG::getMaskedGather(SDVTList VTs, EVT VT, const SDLoc &dl,
ArrayRef<SDValue> Ops,
MachineMemOperand *MMO,
ISD::MemIndexType IndexType) {
ISD::MemIndexType IndexType,
ISD::LoadExtType ExtTy) {
assert(Ops.size() == 6 && "Incompatible number of operands");
FoldingSetNodeID ID;
AddNodeIDNode(ID, ISD::MGATHER, VTs, Ops);
ID.AddInteger(VT.getRawBits());
ID.AddInteger(getSyntheticNodeSubclassData<MaskedGatherSDNode>(
dl.getIROrder(), VTs, VT, MMO, IndexType));
dl.getIROrder(), VTs, VT, MMO, IndexType, ExtTy));
ID.AddInteger(MMO->getPointerInfo().getAddrSpace());
void *IP = nullptr;
if (SDNode *E = FindNodeOrInsertPos(ID, dl, IP)) {
@ -7312,7 +7313,7 @@ SDValue SelectionDAG::getMaskedGather(SDVTList VTs, EVT VT, const SDLoc &dl,
IndexType = TLI->getCanonicalIndexType(IndexType, VT, Ops[4]);
auto *N = newSDNode<MaskedGatherSDNode>(dl.getIROrder(), dl.getDebugLoc(),
VTs, VT, MMO, IndexType);
VTs, VT, MMO, IndexType, ExtTy);
createOperands(N, Ops);
assert(N->getPassThru().getValueType() == N->getValueType(0) &&

View File

@ -4421,7 +4421,7 @@ void SelectionDAGBuilder::visitMaskedGather(const CallInst &I) {
}
SDValue Ops[] = { Root, Src0, Mask, Base, Index, Scale };
SDValue Gather = DAG.getMaskedGather(DAG.getVTList(VT, MVT::Other), VT, sdl,
Ops, MMO, IndexType);
Ops, MMO, IndexType, ISD::NON_EXTLOAD);
PendingLoads.push_back(Gather.getValue(1));
setValue(&I, Gather);

View File

@ -743,6 +743,25 @@ void SDNode::print_details(raw_ostream &OS, const SelectionDAG *G) const {
if (MSt->isCompressingStore())
OS << ", compressing";
OS << ">";
} else if (const auto *MGather = dyn_cast<MaskedGatherSDNode>(this)) {
OS << "<";
printMemOperand(OS, *MGather->getMemOperand(), G);
bool doExt = true;
switch (MGather->getExtensionType()) {
default: doExt = false; break;
case ISD::EXTLOAD: OS << ", anyext"; break;
case ISD::SEXTLOAD: OS << ", sext"; break;
case ISD::ZEXTLOAD: OS << ", zext"; break;
}
if (doExt)
OS << " from " << MGather->getMemoryVT().getEVTString();
auto Signed = MGather->isIndexSigned() ? "signed" : "unsigned";
auto Scaled = MGather->isIndexScaled() ? "scaled" : "unscaled";
OS << ", " << Signed << " " << Scaled << " offset";
OS << ">";
} else if (const auto *MScatter = dyn_cast<MaskedScatterSDNode>(this)) {
OS << "<";

View File

@ -47438,7 +47438,8 @@ static SDValue rebuildGatherScatter(MaskedGatherScatterSDNode *GorS,
return DAG.getMaskedGather(Gather->getVTList(),
Gather->getMemoryVT(), DL, Ops,
Gather->getMemOperand(),
Gather->getIndexType());
Gather->getIndexType(),
Gather->getExtensionType());
}
auto *Scatter = cast<MaskedScatterSDNode>(GorS);
SDValue Ops[] = { Scatter->getChain(), Scatter->getValue(),