[RISCV] Support scalable-vector masked scatter operations

This patch adds support for masked scatter intrinsics on scalable vector
types. It is mostly an extension of the earlier masked gather support
introduced in D96263, since the addressing mode legalization is the same.

Reviewed By: craig.topper

Differential Revision: https://reviews.llvm.org/D96486
commit 3495031a39
parent 251fe986af
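For illustration only (not part of the patch, and not taken from its suppressed test diff): a minimal LLVM IR sketch, in the typed-pointer syntax of the time, of the kind of scalable-vector masked scatter this change makes lowerable. The function name, element type, and alignment are arbitrary choices for the example.

declare void @llvm.masked.scatter.nxv2i32.nxv2p0i32(<vscale x 2 x i32>, <vscale x 2 x i32*>, i32, <vscale x 2 x i1>)

; With this patch the RISC-V backend custom-lowers the scatter below to an
; RVV indexed store (vsoxei); previously scalable-vector scatters were not
; supported by the backend.
define void @example_mscatter_nxv2i32(<vscale x 2 x i32> %val, <vscale x 2 x i32*> %ptrs, <vscale x 2 x i1> %m) {
  call void @llvm.masked.scatter.nxv2i32.nxv2p0i32(<vscale x 2 x i32> %val, <vscale x 2 x i32*> %ptrs, i32 4, <vscale x 2 x i1> %m)
  ret void
}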
llvm/lib/Target/RISCV/RISCVISelLowering.cpp

@@ -475,6 +475,7 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
       setOperationAction(ISD::VECREDUCE_UMIN, VT, Custom);
 
       setOperationAction(ISD::MGATHER, VT, Custom);
+      setOperationAction(ISD::MSCATTER, VT, Custom);
 
       setOperationAction(ISD::CONCAT_VECTORS, VT, Custom);
       setOperationAction(ISD::INSERT_SUBVECTOR, VT, Custom);
@@ -517,6 +518,7 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
       setOperationAction(ISD::FCOPYSIGN, VT, Legal);
 
       setOperationAction(ISD::MGATHER, VT, Custom);
+      setOperationAction(ISD::MSCATTER, VT, Custom);
 
       setOperationAction(ISD::CONCAT_VECTORS, VT, Custom);
       setOperationAction(ISD::INSERT_SUBVECTOR, VT, Custom);
@@ -695,6 +697,7 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
   if (Subtarget.hasStdExtV()) {
     setTargetDAGCombine(ISD::FCOPYSIGN);
     setTargetDAGCombine(ISD::MGATHER);
+    setTargetDAGCombine(ISD::MSCATTER);
   }
 }
 
@@ -1719,7 +1722,8 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
   case ISD::FCOPYSIGN:
     return lowerFixedLengthVectorFCOPYSIGNToRVV(Op, DAG);
   case ISD::MGATHER:
-    return lowerMGATHER(Op, DAG);
+  case ISD::MSCATTER:
+    return lowerMGATHERMSCATTER(Op, DAG);
   }
 }
 
@@ -3467,39 +3471,50 @@ SDValue RISCVTargetLowering::lowerToScalableOp(SDValue Op, SelectionDAG &DAG,
 // "unsigned unscaled" addressing mode; indices are implicitly zero-extended or
 // truncated to XLEN and are treated as byte offsets. Any signed or scaled
 // indexing is extended to the XLEN value type and scaled accordingly.
-SDValue RISCVTargetLowering::lowerMGATHER(SDValue Op, SelectionDAG &DAG) const {
-  MaskedGatherSDNode *N = cast<MaskedGatherSDNode>(Op.getNode());
+SDValue RISCVTargetLowering::lowerMGATHERMSCATTER(SDValue Op,
+                                                  SelectionDAG &DAG) const {
+  auto *N = cast<MaskedGatherScatterSDNode>(Op.getNode());
   SDLoc DL(Op);
-  MVT VT = Op.getSimpleValueType();
   SDValue Index = N->getIndex();
   SDValue Mask = N->getMask();
-  SDValue PassThru = N->getPassThru();
 
   MVT XLenVT = Subtarget.getXLenVT();
   assert(N->getBasePtr().getSimpleValueType() == XLenVT &&
          "Unexpected pointer type");
-  // Targets have to explicitly opt-in for extending vector loads.
-  assert(N->getExtensionType() == ISD::NON_EXTLOAD &&
+  // Targets have to explicitly opt-in for extending vector loads and
+  // truncating vector stores.
+  const auto *MGN = dyn_cast<MaskedGatherSDNode>(N);
+  const auto *MSN = dyn_cast<MaskedScatterSDNode>(N);
+  assert((!MGN || MGN->getExtensionType() == ISD::NON_EXTLOAD) &&
         "Unexpected extending MGATHER");
+  assert((!MSN || !MSN->isTruncatingStore()) &&
+         "Unexpected extending MSCATTER");
 
-  SDValue VL = getDefaultVLOps(VT, VT, DL, DAG, Subtarget).second;
   // If the mask is known to be all ones, optimize to an unmasked intrinsic;
   // the selection of the masked intrinsics doesn't do this for us.
-  if (ISD::isConstantSplatVectorAllOnes(Mask.getNode())) {
-    SDValue IntID = DAG.getTargetConstant(Intrinsic::riscv_vloxei, DL, XLenVT);
-    SDValue Ops[] = {N->getChain(), IntID, N->getBasePtr(), Index, VL};
-    return DAG.getMemIntrinsicNode(ISD::INTRINSIC_W_CHAIN, DL,
-                                   DAG.getVTList(VT, MVT::Other), Ops,
-                                   N->getMemoryVT(), N->getMemOperand());
-  }
+  unsigned IntID = 0;
+  MVT IndexVT = Index.getSimpleValueType();
+  SDValue VL = getDefaultVLOps(IndexVT, IndexVT, DL, DAG, Subtarget).second;
+  bool IsUnmasked = ISD::isConstantSplatVectorAllOnes(Mask.getNode());
 
-  SDValue IntID =
-      DAG.getTargetConstant(Intrinsic::riscv_vloxei_mask, DL, XLenVT);
-  SDValue Ops[] = {N->getChain(), IntID, PassThru, N->getBasePtr(),
-                   Index, Mask, VL};
-  return DAG.getMemIntrinsicNode(ISD::INTRINSIC_W_CHAIN, DL,
-                                 DAG.getVTList(VT, MVT::Other), Ops,
-                                 N->getMemoryVT(), N->getMemOperand());
+  if (IsUnmasked)
+    IntID = MGN ? Intrinsic::riscv_vloxei : Intrinsic::riscv_vsoxei;
+  else
+    IntID = MGN ? Intrinsic::riscv_vloxei_mask : Intrinsic::riscv_vsoxei_mask;
+  SmallVector<SDValue, 8> Ops{N->getChain(),
+                              DAG.getTargetConstant(IntID, DL, XLenVT)};
+  if (MSN)
+    Ops.push_back(MSN->getValue());
+  else if (!IsUnmasked)
+    Ops.push_back(MGN->getPassThru());
+  Ops.push_back(N->getBasePtr());
+  Ops.push_back(Index);
+  if (!IsUnmasked)
+    Ops.push_back(Mask);
+  Ops.push_back(VL);
+  return DAG.getMemIntrinsicNode(
+      MGN ? ISD::INTRINSIC_W_CHAIN : ISD::INTRINSIC_VOID, DL, N->getVTList(),
+      Ops, N->getMemoryVT(), N->getMemOperand());
 }
 
 // Returns the opcode of the target-specific SDNode that implements the 32-bit
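As a sketch of the shared path handled by the new lowerMGATHERMSCATTER (example names and types are assumptions, not from the patch): a masked gather takes the load (vloxei / vloxei_mask) side of the function and a masked scatter the store (vsoxei / vsoxei_mask) side; when the mask is a constant all-ones splat, the IsUnmasked path drops the mask and pass-through operands and uses the unmasked intrinsic IDs.

declare <vscale x 2 x i32> @llvm.masked.gather.nxv2i32.nxv2p0i32(<vscale x 2 x i32*>, i32, <vscale x 2 x i1>, <vscale x 2 x i32>)
declare void @llvm.masked.scatter.nxv2i32.nxv2p0i32(<vscale x 2 x i32>, <vscale x 2 x i32*>, i32, <vscale x 2 x i1>)

; Gather then scatter through the same pointer vector: both nodes reach
; lowerMGATHERMSCATTER, one selected as an indexed load, the other as an
; indexed store.
define <vscale x 2 x i32> @example_gather_then_scatter(<vscale x 2 x i32*> %ptrs, <vscale x 2 x i1> %m, <vscale x 2 x i32> %passthru) {
  %v = call <vscale x 2 x i32> @llvm.masked.gather.nxv2i32.nxv2p0i32(<vscale x 2 x i32*> %ptrs, i32 4, <vscale x 2 x i1> %m, <vscale x 2 x i32> %passthru)
  call void @llvm.masked.scatter.nxv2i32.nxv2p0i32(<vscale x 2 x i32> %v, <vscale x 2 x i32*> %ptrs, i32 4, <vscale x 2 x i1> %m)
  ret <vscale x 2 x i32> %v
}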
@@ -4519,18 +4534,19 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
     return DAG.getNode(ISD::FCOPYSIGN, DL, VT, N->getOperand(0),
                        DAG.getNode(ISD::FNEG, DL, VT, NewFPExtRound));
   }
-  case ISD::MGATHER: {
+  case ISD::MGATHER:
+  case ISD::MSCATTER: {
     if (!DCI.isBeforeLegalize())
       break;
-    MaskedGatherSDNode *MGN = cast<MaskedGatherSDNode>(N);
-    SDValue Index = MGN->getIndex();
+    MaskedGatherScatterSDNode *MGSN = cast<MaskedGatherScatterSDNode>(N);
+    SDValue Index = MGSN->getIndex();
     EVT IndexVT = Index.getValueType();
     MVT XLenVT = Subtarget.getXLenVT();
     // RISCV indexed loads only support the "unsigned unscaled" addressing
     // mode, so anything else must be manually legalized.
-    bool NeedsIdxLegalization =
-        MGN->isIndexScaled() ||
-        (MGN->isIndexSigned() && IndexVT.getVectorElementType().bitsLT(XLenVT));
+    bool NeedsIdxLegalization = MGSN->isIndexScaled() ||
+                                (MGSN->isIndexSigned() &&
+                                 IndexVT.getVectorElementType().bitsLT(XLenVT));
     if (!NeedsIdxLegalization)
       break;
 
@@ -4541,13 +4557,13 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
     // LLVM's legalization take care of the splitting.
     if (IndexVT.getVectorElementType().bitsLT(XLenVT)) {
       IndexVT = IndexVT.changeVectorElementType(XLenVT);
-      Index = DAG.getNode(MGN->isIndexSigned() ? ISD::SIGN_EXTEND
-                                               : ISD::ZERO_EXTEND,
+      Index = DAG.getNode(MGSN->isIndexSigned() ? ISD::SIGN_EXTEND
+                                                : ISD::ZERO_EXTEND,
                           DL, IndexVT, Index);
     }
 
     unsigned Scale = N->getConstantOperandVal(5);
-    if (MGN->isIndexScaled() && Scale != 1) {
+    if (MGSN->isIndexScaled() && Scale != 1) {
       // Manually scale the indices by the element size.
       // TODO: Sanitize the scale operand here?
       assert(isPowerOf2_32(Scale) && "Expecting power-of-two types");
@@ -4556,11 +4572,19 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
     }
 
     ISD::MemIndexType NewIndexTy = ISD::UNSIGNED_UNSCALED;
-    return DAG.getMaskedGather(
-        N->getVTList(), MGN->getMemoryVT(), DL,
-        {MGN->getChain(), MGN->getPassThru(), MGN->getMask(), MGN->getBasePtr(),
-         Index, MGN->getScale()},
-        MGN->getMemOperand(), NewIndexTy, MGN->getExtensionType());
+    if (const auto *MGN = dyn_cast<MaskedGatherSDNode>(N)) {
+      return DAG.getMaskedGather(
+          N->getVTList(), MGSN->getMemoryVT(), DL,
+          {MGSN->getChain(), MGN->getPassThru(), MGSN->getMask(),
+           MGSN->getBasePtr(), Index, MGN->getScale()},
+          MGN->getMemOperand(), NewIndexTy, MGN->getExtensionType());
+    }
+    const auto *MSN = cast<MaskedScatterSDNode>(N);
+    return DAG.getMaskedScatter(
+        N->getVTList(), MGSN->getMemoryVT(), DL,
+        {MGSN->getChain(), MSN->getValue(), MGSN->getMask(), MGSN->getBasePtr(),
+         Index, MGSN->getScale()},
+        MGSN->getMemOperand(), NewIndexTy, MSN->isTruncatingStore());
   }
   }
 
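To make the index-legalization combine above concrete, here is a hypothetical example (not taken from the patch's tests) whose addresses come from a GEP with narrow signed indices: at the DAG level this produces a signed, element-scaled index vector, which the combine sign-extends to XLEN and multiplies by the element size (4 here) so that only the "unsigned unscaled" byte-offset form reaches instruction selection.

declare void @llvm.masked.scatter.nxv2i32.nxv2p0i32(<vscale x 2 x i32>, <vscale x 2 x i32*>, i32, <vscale x 2 x i1>)

; The i16 indices are narrower than XLEN and address i32 elements, so they
; need both sign extension and scaling before an RVV indexed store can be
; selected.
define void @example_scatter_baseidx_i16(<vscale x 2 x i32> %val, i32* %base, <vscale x 2 x i16> %idxs, <vscale x 2 x i1> %m) {
  %ptrs = getelementptr inbounds i32, i32* %base, <vscale x 2 x i16> %idxs
  call void @llvm.masked.scatter.nxv2i32.nxv2p0i32(<vscale x 2 x i32> %val, <vscale x 2 x i32*> %ptrs, i32 4, <vscale x 2 x i1> %m)
  ret void
}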
llvm/lib/Target/RISCV/RISCVISelLowering.h

@@ -477,7 +477,7 @@ private:
   SDValue lowerABS(SDValue Op, SelectionDAG &DAG) const;
   SDValue lowerFixedLengthVectorFCOPYSIGNToRVV(SDValue Op,
                                                SelectionDAG &DAG) const;
-  SDValue lowerMGATHER(SDValue Op, SelectionDAG &DAG) const;
+  SDValue lowerMGATHERMSCATTER(SDValue Op, SelectionDAG &DAG) const;
   SDValue lowerFixedLengthVectorLoadToRVV(SDValue Op, SelectionDAG &DAG) const;
   SDValue lowerFixedLengthVectorStoreToRVV(SDValue Op, SelectionDAG &DAG) const;
   SDValue lowerFixedLengthVectorMaskedLoadToRVV(SDValue Op,
File diff suppressed because it is too large