forked from OSchip/llvm-project
[X86] Add a helper function to pull some repeated code out of combineGatherScatter. NFC
This commit is contained in:
parent
13a97305ba
commit
89ab5c69c8
|
@ -44661,13 +44661,33 @@ static SDValue combineX86GatherScatter(SDNode *N, SelectionDAG &DAG,
|
||||||
return SDValue();
|
return SDValue();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static SDValue rebuildGatherScatter(MaskedGatherScatterSDNode *GorS,
|
||||||
|
SDValue Index, SDValue Base, SDValue Scale,
|
||||||
|
SelectionDAG &DAG) {
|
||||||
|
SDLoc DL(GorS);
|
||||||
|
|
||||||
|
if (auto *Gather = dyn_cast<MaskedGatherSDNode>(GorS)) {
|
||||||
|
SDValue Ops[] = { Gather->getChain(), Gather->getPassThru(),
|
||||||
|
Gather->getMask(), Base, Index, Scale } ;
|
||||||
|
return DAG.getMaskedGather(Gather->getVTList(),
|
||||||
|
Gather->getMemoryVT(), DL, Ops,
|
||||||
|
Gather->getMemOperand(),
|
||||||
|
Gather->getIndexType());
|
||||||
|
}
|
||||||
|
auto *Scatter = cast<MaskedScatterSDNode>(GorS);
|
||||||
|
SDValue Ops[] = { Scatter->getChain(), Scatter->getValue(),
|
||||||
|
Scatter->getMask(), Base, Index, Scale };
|
||||||
|
return DAG.getMaskedScatter(Scatter->getVTList(),
|
||||||
|
Scatter->getMemoryVT(), DL,
|
||||||
|
Ops, Scatter->getMemOperand(),
|
||||||
|
Scatter->getIndexType());
|
||||||
|
}
|
||||||
|
|
||||||
static SDValue combineGatherScatter(SDNode *N, SelectionDAG &DAG,
|
static SDValue combineGatherScatter(SDNode *N, SelectionDAG &DAG,
|
||||||
TargetLowering::DAGCombinerInfo &DCI) {
|
TargetLowering::DAGCombinerInfo &DCI) {
|
||||||
SDLoc DL(N);
|
SDLoc DL(N);
|
||||||
auto *GorS = cast<MaskedGatherScatterSDNode>(N);
|
auto *GorS = cast<MaskedGatherScatterSDNode>(N);
|
||||||
SDValue Chain = GorS->getChain();
|
|
||||||
SDValue Index = GorS->getIndex();
|
SDValue Index = GorS->getIndex();
|
||||||
SDValue Mask = GorS->getMask();
|
|
||||||
SDValue Base = GorS->getBasePtr();
|
SDValue Base = GorS->getBasePtr();
|
||||||
SDValue Scale = GorS->getScale();
|
SDValue Scale = GorS->getScale();
|
||||||
|
|
||||||
|
@ -44687,21 +44707,7 @@ static SDValue combineGatherScatter(SDNode *N, SelectionDAG &DAG,
|
||||||
unsigned NumElts = Index.getValueType().getVectorNumElements();
|
unsigned NumElts = Index.getValueType().getVectorNumElements();
|
||||||
EVT NewVT = EVT::getVectorVT(*DAG.getContext(), MVT::i32, NumElts);
|
EVT NewVT = EVT::getVectorVT(*DAG.getContext(), MVT::i32, NumElts);
|
||||||
Index = DAG.getNode(ISD::TRUNCATE, DL, NewVT, Index);
|
Index = DAG.getNode(ISD::TRUNCATE, DL, NewVT, Index);
|
||||||
if (auto *Gather = dyn_cast<MaskedGatherSDNode>(GorS)) {
|
return rebuildGatherScatter(GorS, Index, Base, Scale, DAG);
|
||||||
SDValue Ops[] = { Chain, Gather->getPassThru(),
|
|
||||||
Mask, Base, Index, Scale } ;
|
|
||||||
return DAG.getMaskedGather(Gather->getVTList(),
|
|
||||||
Gather->getMemoryVT(), DL, Ops,
|
|
||||||
Gather->getMemOperand(),
|
|
||||||
Gather->getIndexType());
|
|
||||||
}
|
|
||||||
auto *Scatter = cast<MaskedScatterSDNode>(GorS);
|
|
||||||
SDValue Ops[] = { Chain, Scatter->getValue(),
|
|
||||||
Mask, Base, Index, Scale };
|
|
||||||
return DAG.getMaskedScatter(Scatter->getVTList(),
|
|
||||||
Scatter->getMemoryVT(), DL,
|
|
||||||
Ops, Scatter->getMemOperand(),
|
|
||||||
Scatter->getIndexType());
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -44716,21 +44722,7 @@ static SDValue combineGatherScatter(SDNode *N, SelectionDAG &DAG,
|
||||||
unsigned NumElts = Index.getValueType().getVectorNumElements();
|
unsigned NumElts = Index.getValueType().getVectorNumElements();
|
||||||
EVT NewVT = EVT::getVectorVT(*DAG.getContext(), MVT::i32, NumElts);
|
EVT NewVT = EVT::getVectorVT(*DAG.getContext(), MVT::i32, NumElts);
|
||||||
Index = DAG.getNode(ISD::TRUNCATE, DL, NewVT, Index);
|
Index = DAG.getNode(ISD::TRUNCATE, DL, NewVT, Index);
|
||||||
if (auto *Gather = dyn_cast<MaskedGatherSDNode>(GorS)) {
|
return rebuildGatherScatter(GorS, Index, Base, Scale, DAG);
|
||||||
SDValue Ops[] = { Chain, Gather->getPassThru(),
|
|
||||||
Mask, Base, Index, Scale } ;
|
|
||||||
return DAG.getMaskedGather(Gather->getVTList(),
|
|
||||||
Gather->getMemoryVT(), DL, Ops,
|
|
||||||
Gather->getMemOperand(),
|
|
||||||
Gather->getIndexType());
|
|
||||||
}
|
|
||||||
auto *Scatter = cast<MaskedScatterSDNode>(GorS);
|
|
||||||
SDValue Ops[] = { Chain, Scatter->getValue(),
|
|
||||||
Mask, Base, Index, Scale };
|
|
||||||
return DAG.getMaskedScatter(Scatter->getVTList(),
|
|
||||||
Scatter->getMemoryVT(), DL,
|
|
||||||
Ops, Scatter->getMemOperand(),
|
|
||||||
Scatter->getIndexType());
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -44743,25 +44735,12 @@ static SDValue combineGatherScatter(SDNode *N, SelectionDAG &DAG,
|
||||||
EVT IndexVT = EVT::getVectorVT(*DAG.getContext(), EltVT,
|
EVT IndexVT = EVT::getVectorVT(*DAG.getContext(), EltVT,
|
||||||
Index.getValueType().getVectorNumElements());
|
Index.getValueType().getVectorNumElements());
|
||||||
Index = DAG.getSExtOrTrunc(Index, DL, IndexVT);
|
Index = DAG.getSExtOrTrunc(Index, DL, IndexVT);
|
||||||
if (auto *Gather = dyn_cast<MaskedGatherSDNode>(GorS)) {
|
return rebuildGatherScatter(GorS, Index, Base, Scale, DAG);
|
||||||
SDValue Ops[] = { Chain, Gather->getPassThru(),
|
|
||||||
Mask, Base, Index, Scale } ;
|
|
||||||
return DAG.getMaskedGather(Gather->getVTList(),
|
|
||||||
Gather->getMemoryVT(), DL, Ops,
|
|
||||||
Gather->getMemOperand(),
|
|
||||||
Gather->getIndexType());
|
|
||||||
}
|
|
||||||
auto *Scatter = cast<MaskedScatterSDNode>(GorS);
|
|
||||||
SDValue Ops[] = { Chain, Scatter->getValue(),
|
|
||||||
Mask, Base, Index, Scale };
|
|
||||||
return DAG.getMaskedScatter(Scatter->getVTList(),
|
|
||||||
Scatter->getMemoryVT(), DL,
|
|
||||||
Ops, Scatter->getMemOperand(),
|
|
||||||
Scatter->getIndexType());
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// With vector masks we only demand the upper bit of the mask.
|
// With vector masks we only demand the upper bit of the mask.
|
||||||
|
SDValue Mask = GorS->getMask();
|
||||||
if (Mask.getScalarValueSizeInBits() != 1) {
|
if (Mask.getScalarValueSizeInBits() != 1) {
|
||||||
const TargetLowering &TLI = DAG.getTargetLoweringInfo();
|
const TargetLowering &TLI = DAG.getTargetLoweringInfo();
|
||||||
APInt DemandedMask(APInt::getSignMask(Mask.getScalarValueSizeInBits()));
|
APInt DemandedMask(APInt::getSignMask(Mask.getScalarValueSizeInBits()));
|
||||||
|
|
Loading…
Reference in New Issue