[X86] Add a helper function to pull some repeated code out of combineGatherScatter. NFC

This commit is contained in:
Craig Topper 2020-02-18 11:10:01 -08:00
parent 13a97305ba
commit 89ab5c69c8
1 changed files with 26 additions and 47 deletions

View File

@ -44661,13 +44661,33 @@ static SDValue combineX86GatherScatter(SDNode *N, SelectionDAG &DAG,
return SDValue();
}
static SDValue rebuildGatherScatter(MaskedGatherScatterSDNode *GorS,
SDValue Index, SDValue Base, SDValue Scale,
SelectionDAG &DAG) {
SDLoc DL(GorS);
if (auto *Gather = dyn_cast<MaskedGatherSDNode>(GorS)) {
SDValue Ops[] = { Gather->getChain(), Gather->getPassThru(),
Gather->getMask(), Base, Index, Scale } ;
return DAG.getMaskedGather(Gather->getVTList(),
Gather->getMemoryVT(), DL, Ops,
Gather->getMemOperand(),
Gather->getIndexType());
}
auto *Scatter = cast<MaskedScatterSDNode>(GorS);
SDValue Ops[] = { Scatter->getChain(), Scatter->getValue(),
Scatter->getMask(), Base, Index, Scale };
return DAG.getMaskedScatter(Scatter->getVTList(),
Scatter->getMemoryVT(), DL,
Ops, Scatter->getMemOperand(),
Scatter->getIndexType());
}
static SDValue combineGatherScatter(SDNode *N, SelectionDAG &DAG,
TargetLowering::DAGCombinerInfo &DCI) {
SDLoc DL(N);
auto *GorS = cast<MaskedGatherScatterSDNode>(N);
SDValue Chain = GorS->getChain();
SDValue Index = GorS->getIndex();
SDValue Mask = GorS->getMask();
SDValue Base = GorS->getBasePtr();
SDValue Scale = GorS->getScale();
@ -44687,21 +44707,7 @@ static SDValue combineGatherScatter(SDNode *N, SelectionDAG &DAG,
unsigned NumElts = Index.getValueType().getVectorNumElements();
EVT NewVT = EVT::getVectorVT(*DAG.getContext(), MVT::i32, NumElts);
Index = DAG.getNode(ISD::TRUNCATE, DL, NewVT, Index);
if (auto *Gather = dyn_cast<MaskedGatherSDNode>(GorS)) {
SDValue Ops[] = { Chain, Gather->getPassThru(),
Mask, Base, Index, Scale } ;
return DAG.getMaskedGather(Gather->getVTList(),
Gather->getMemoryVT(), DL, Ops,
Gather->getMemOperand(),
Gather->getIndexType());
}
auto *Scatter = cast<MaskedScatterSDNode>(GorS);
SDValue Ops[] = { Chain, Scatter->getValue(),
Mask, Base, Index, Scale };
return DAG.getMaskedScatter(Scatter->getVTList(),
Scatter->getMemoryVT(), DL,
Ops, Scatter->getMemOperand(),
Scatter->getIndexType());
return rebuildGatherScatter(GorS, Index, Base, Scale, DAG);
}
}
@ -44716,21 +44722,7 @@ static SDValue combineGatherScatter(SDNode *N, SelectionDAG &DAG,
unsigned NumElts = Index.getValueType().getVectorNumElements();
EVT NewVT = EVT::getVectorVT(*DAG.getContext(), MVT::i32, NumElts);
Index = DAG.getNode(ISD::TRUNCATE, DL, NewVT, Index);
if (auto *Gather = dyn_cast<MaskedGatherSDNode>(GorS)) {
SDValue Ops[] = { Chain, Gather->getPassThru(),
Mask, Base, Index, Scale } ;
return DAG.getMaskedGather(Gather->getVTList(),
Gather->getMemoryVT(), DL, Ops,
Gather->getMemOperand(),
Gather->getIndexType());
}
auto *Scatter = cast<MaskedScatterSDNode>(GorS);
SDValue Ops[] = { Chain, Scatter->getValue(),
Mask, Base, Index, Scale };
return DAG.getMaskedScatter(Scatter->getVTList(),
Scatter->getMemoryVT(), DL,
Ops, Scatter->getMemOperand(),
Scatter->getIndexType());
return rebuildGatherScatter(GorS, Index, Base, Scale, DAG);
}
}
@ -44743,25 +44735,12 @@ static SDValue combineGatherScatter(SDNode *N, SelectionDAG &DAG,
EVT IndexVT = EVT::getVectorVT(*DAG.getContext(), EltVT,
Index.getValueType().getVectorNumElements());
Index = DAG.getSExtOrTrunc(Index, DL, IndexVT);
if (auto *Gather = dyn_cast<MaskedGatherSDNode>(GorS)) {
SDValue Ops[] = { Chain, Gather->getPassThru(),
Mask, Base, Index, Scale } ;
return DAG.getMaskedGather(Gather->getVTList(),
Gather->getMemoryVT(), DL, Ops,
Gather->getMemOperand(),
Gather->getIndexType());
}
auto *Scatter = cast<MaskedScatterSDNode>(GorS);
SDValue Ops[] = { Chain, Scatter->getValue(),
Mask, Base, Index, Scale };
return DAG.getMaskedScatter(Scatter->getVTList(),
Scatter->getMemoryVT(), DL,
Ops, Scatter->getMemOperand(),
Scatter->getIndexType());
return rebuildGatherScatter(GorS, Index, Base, Scale, DAG);
}
}
// With vector masks we only demand the upper bit of the mask.
SDValue Mask = GorS->getMask();
if (Mask.getScalarValueSizeInBits() != 1) {
const TargetLowering &TLI = DAG.getTargetLoweringInfo();
APInt DemandedMask(APInt::getSignMask(Mask.getScalarValueSizeInBits()));