[X86] Remove mask output from X86 gather/scatter ISD opcodes.

Instead add it when we make the machine nodes during instruction
selections.

This makes this ISD node closer to ISD::MGATHER. Trying to see
if we remove the X86 specific ones.
This commit is contained in:
Craig Topper 2020-02-24 22:54:20 -08:00
parent 6a0c066c61
commit 9238dfb4d8
2 changed files with 26 additions and 24 deletions

View File

@ -5474,7 +5474,8 @@ void X86DAGToDAGISel::Select(SDNode *Node) {
SDValue PassThru = Mgt->getPassThru();
SDValue Chain = Mgt->getChain();
SDVTList VTs = Mgt->getVTList();
// Gather instructions have a mask output not in the ISD node.
SDVTList VTs = CurDAG->getVTList(ValueVT, MaskVT, MVT::Other);
MachineSDNode *NewNode;
if (AVX512Gather) {
@ -5487,7 +5488,9 @@ void X86DAGToDAGISel::Select(SDNode *Node) {
NewNode = CurDAG->getMachineNode(Opc, SDLoc(dl), VTs, Ops);
}
CurDAG->setNodeMemRefs(NewNode, {Mgt->getMemOperand()});
ReplaceNode(Node, NewNode);
ReplaceUses(SDValue(Node, 0), SDValue(NewNode, 0));
ReplaceUses(SDValue(Node, 1), SDValue(NewNode, 2));
CurDAG->RemoveDeadNode(Node);
return;
}
case X86ISD::MSCATTER: {
@ -5544,12 +5547,14 @@ void X86DAGToDAGISel::Select(SDNode *Node) {
SDValue Mask = Sc->getMask();
SDValue Chain = Sc->getChain();
SDVTList VTs = Sc->getVTList();
// Scatter instructions have a mask output not in the ISD node.
SDVTList VTs = CurDAG->getVTList(Mask.getValueType(), MVT::Other);
SDValue Ops[] = {Base, Scale, Index, Disp, Segment, Mask, Value, Chain};
MachineSDNode *NewNode = CurDAG->getMachineNode(Opc, SDLoc(dl), VTs, Ops);
CurDAG->setNodeMemRefs(NewNode, {Sc->getMemOperand()});
ReplaceNode(Node, NewNode);
ReplaceUses(SDValue(Node, 0), SDValue(NewNode, 1));
CurDAG->RemoveDeadNode(Node);
return;
}
}

View File

@ -24771,7 +24771,7 @@ static SDValue getAVX2GatherNode(unsigned Opc, SDValue Op, SelectionDAG &DAG,
SDValue Scale = DAG.getTargetConstant(C->getZExtValue(), dl,
TLI.getPointerTy(DAG.getDataLayout()));
EVT MaskVT = Mask.getValueType().changeVectorElementTypeToInteger();
SDVTList VTs = DAG.getVTList(Op.getValueType(), MaskVT, MVT::Other);
SDVTList VTs = DAG.getVTList(Op.getValueType(), MVT::Other);
// If source is undef or we know it won't be used, use a zero vector
// to break register dependency.
// TODO: use undef instead and let BreakFalseDeps deal with it?
@ -24787,7 +24787,7 @@ static SDValue getAVX2GatherNode(unsigned Opc, SDValue Op, SelectionDAG &DAG,
SDValue Res =
DAG.getMemIntrinsicNode(X86ISD::MGATHER, dl, VTs, Ops,
MemIntr->getMemoryVT(), MemIntr->getMemOperand());
return DAG.getMergeValues({ Res, Res.getValue(2) }, dl);
return DAG.getMergeValues({Res, Res.getValue(1)}, dl);
}
static SDValue getGatherNode(SDValue Op, SelectionDAG &DAG,
@ -24812,7 +24812,7 @@ static SDValue getGatherNode(SDValue Op, SelectionDAG &DAG,
if (Mask.getValueType() != MaskVT)
Mask = getMaskNode(Mask, MaskVT, Subtarget, DAG, dl);
SDVTList VTs = DAG.getVTList(Op.getValueType(), MaskVT, MVT::Other);
SDVTList VTs = DAG.getVTList(Op.getValueType(), MVT::Other);
// If source is undef or we know it won't be used, use a zero vector
// to break register dependency.
// TODO: use undef instead and let BreakFalseDeps deal with it?
@ -24825,7 +24825,7 @@ static SDValue getGatherNode(SDValue Op, SelectionDAG &DAG,
SDValue Res =
DAG.getMemIntrinsicNode(X86ISD::MGATHER, dl, VTs, Ops,
MemIntr->getMemoryVT(), MemIntr->getMemOperand());
return DAG.getMergeValues({ Res, Res.getValue(2) }, dl);
return DAG.getMergeValues({Res, Res.getValue(1)}, dl);
}
static SDValue getScatterNode(unsigned Opc, SDValue Op, SelectionDAG &DAG,
@ -24851,12 +24851,12 @@ static SDValue getScatterNode(unsigned Opc, SDValue Op, SelectionDAG &DAG,
MemIntrinsicSDNode *MemIntr = cast<MemIntrinsicSDNode>(Op);
SDVTList VTs = DAG.getVTList(MaskVT, MVT::Other);
SDVTList VTs = DAG.getVTList(MVT::Other);
SDValue Ops[] = {Chain, Src, Mask, Base, Index, Scale};
SDValue Res =
DAG.getMemIntrinsicNode(X86ISD::MSCATTER, dl, VTs, Ops,
MemIntr->getMemoryVT(), MemIntr->getMemOperand());
return Res.getValue(1);
return Res;
}
static SDValue getPrefetchNode(unsigned Opc, SDValue Op, SelectionDAG &DAG,
@ -28523,11 +28523,10 @@ static SDValue LowerMSCATTER(SDValue Op, const X86Subtarget &Subtarget,
const TargetLowering &TLI = DAG.getTargetLoweringInfo();
EVT WideVT = TLI.getTypeToTransformTo(*DAG.getContext(), VT);
Src = DAG.getNode(ISD::CONCAT_VECTORS, dl, WideVT, Src, DAG.getUNDEF(VT));
SDVTList VTs = DAG.getVTList(MVT::v2i1, MVT::Other);
SDVTList VTs = DAG.getVTList(MVT::Other);
SDValue Ops[] = {Chain, Src, Mask, BasePtr, Index, Scale};
SDValue NewScatter = DAG.getMemIntrinsicNode(
X86ISD::MSCATTER, dl, VTs, Ops, N->getMemoryVT(), N->getMemOperand());
return SDValue(NewScatter.getNode(), 1);
return DAG.getMemIntrinsicNode(X86ISD::MSCATTER, dl, VTs, Ops,
N->getMemoryVT(), N->getMemOperand());
}
return SDValue();
}
@ -28558,11 +28557,10 @@ static SDValue LowerMSCATTER(SDValue Op, const X86Subtarget &Subtarget,
Mask = ExtendToType(Mask, MaskVT, DAG, true);
}
SDVTList VTs = DAG.getVTList(MaskVT, MVT::Other);
SDVTList VTs = DAG.getVTList(MVT::Other);
SDValue Ops[] = {Chain, Src, Mask, BasePtr, Index, Scale};
SDValue NewScatter = DAG.getMemIntrinsicNode(
X86ISD::MSCATTER, dl, VTs, Ops, N->getMemoryVT(), N->getMemOperand());
return SDValue(NewScatter.getNode(), 1);
return DAG.getMemIntrinsicNode(X86ISD::MSCATTER, dl, VTs, Ops,
N->getMemoryVT(), N->getMemOperand());
}
static SDValue LowerMLOAD(SDValue Op, const X86Subtarget &Subtarget,
@ -28717,11 +28715,11 @@ static SDValue LowerMGATHER(SDValue Op, const X86Subtarget &Subtarget,
SDValue Ops[] = { N->getChain(), PassThru, Mask, N->getBasePtr(), Index,
N->getScale() };
SDValue NewGather = DAG.getMemIntrinsicNode(
X86ISD::MGATHER, dl, DAG.getVTList(VT, MaskVT, MVT::Other), Ops,
N->getMemoryVT(), N->getMemOperand());
X86ISD::MGATHER, dl, DAG.getVTList(VT, MVT::Other), Ops, N->getMemoryVT(),
N->getMemOperand());
SDValue Extract = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, OrigVT,
NewGather, DAG.getIntPtrConstant(0, dl));
return DAG.getMergeValues({Extract, NewGather.getValue(2)}, dl);
return DAG.getMergeValues({Extract, NewGather.getValue(1)}, dl);
}
static SDValue LowerADDRSPACECAST(SDValue Op, SelectionDAG &DAG) {
@ -29833,11 +29831,10 @@ void X86TargetLowering::ReplaceNodeResults(SDNode *N,
SDValue Ops[] = { Gather->getChain(), PassThru, Mask,
Gather->getBasePtr(), Index, Gather->getScale() };
SDValue Res = DAG.getMemIntrinsicNode(
X86ISD::MGATHER, dl,
DAG.getVTList(WideVT, Mask.getValueType(), MVT::Other), Ops,
X86ISD::MGATHER, dl, DAG.getVTList(WideVT, MVT::Other), Ops,
Gather->getMemoryVT(), Gather->getMemOperand());
Results.push_back(Res);
Results.push_back(Res.getValue(2));
Results.push_back(Res.getValue(1));
return;
}
return;