[RISCV] Support scalable-vector masked scatter operations

This patch adds support for masked scatter intrinsics on scalable vector
types. It is mostly an extension of the earlier masked gather support
introduced in D96263, since the addressing mode legalization is the
same.

Reviewed By: craig.topper

Differential Revision: https://reviews.llvm.org/D96486
Author: Fraser Cormack, 2021-02-08 15:33:23 +00:00
Commit: 3495031a39 (parent 251fe986af)
3 changed files with 1915 additions and 37 deletions
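As an illustration of what the patch enables, the (suppressed) test diff exercises IR of roughly the following shape. The names and element types below are illustrative rather than copied from the test file: a masked scatter of a scalable vector, which the RISC-V backend can now custom-lower, ultimately selecting an indexed-store (vsoxei) form under the V extension.

  declare void @llvm.masked.scatter.nxv2i32.nxv2p0i32(<vscale x 2 x i32>, <vscale x 2 x i32*>, i32, <vscale x 2 x i1>)

  define void @mscatter_nxv2i32(<vscale x 2 x i32> %val, <vscale x 2 x i32*> %ptrs, <vscale x 2 x i1> %m) {
    call void @llvm.masked.scatter.nxv2i32.nxv2p0i32(<vscale x 2 x i32> %val, <vscale x 2 x i32*> %ptrs, i32 4, <vscale x 2 x i1> %m)
    ret void
  }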

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

@@ -475,6 +475,7 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
       setOperationAction(ISD::VECREDUCE_UMIN, VT, Custom);
       setOperationAction(ISD::MGATHER, VT, Custom);
+      setOperationAction(ISD::MSCATTER, VT, Custom);
       setOperationAction(ISD::CONCAT_VECTORS, VT, Custom);
       setOperationAction(ISD::INSERT_SUBVECTOR, VT, Custom);
@@ -517,6 +518,7 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
       setOperationAction(ISD::FCOPYSIGN, VT, Legal);
       setOperationAction(ISD::MGATHER, VT, Custom);
+      setOperationAction(ISD::MSCATTER, VT, Custom);
       setOperationAction(ISD::CONCAT_VECTORS, VT, Custom);
       setOperationAction(ISD::INSERT_SUBVECTOR, VT, Custom);
@@ -695,6 +697,7 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
   if (Subtarget.hasStdExtV()) {
     setTargetDAGCombine(ISD::FCOPYSIGN);
     setTargetDAGCombine(ISD::MGATHER);
+    setTargetDAGCombine(ISD::MSCATTER);
   }
 }
@@ -1719,7 +1722,8 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
   case ISD::FCOPYSIGN:
     return lowerFixedLengthVectorFCOPYSIGNToRVV(Op, DAG);
   case ISD::MGATHER:
-    return lowerMGATHER(Op, DAG);
+  case ISD::MSCATTER:
+    return lowerMGATHERMSCATTER(Op, DAG);
   }
 }
@@ -3467,39 +3471,50 @@ SDValue RISCVTargetLowering::lowerToScalableOp(SDValue Op, SelectionDAG &DAG,
 // "unsigned unscaled" addressing mode; indices are implicitly zero-extended or
 // truncated to XLEN and are treated as byte offsets. Any signed or scaled
 // indexing is extended to the XLEN value type and scaled accordingly.
-SDValue RISCVTargetLowering::lowerMGATHER(SDValue Op, SelectionDAG &DAG) const {
-  MaskedGatherSDNode *N = cast<MaskedGatherSDNode>(Op.getNode());
+SDValue RISCVTargetLowering::lowerMGATHERMSCATTER(SDValue Op,
+                                                  SelectionDAG &DAG) const {
+  auto *N = cast<MaskedGatherScatterSDNode>(Op.getNode());
   SDLoc DL(Op);
-  MVT VT = Op.getSimpleValueType();
   SDValue Index = N->getIndex();
   SDValue Mask = N->getMask();
-  SDValue PassThru = N->getPassThru();
 
   MVT XLenVT = Subtarget.getXLenVT();
   assert(N->getBasePtr().getSimpleValueType() == XLenVT &&
          "Unexpected pointer type");
-  // Targets have to explicitly opt-in for extending vector loads.
-  assert(N->getExtensionType() == ISD::NON_EXTLOAD &&
+  // Targets have to explicitly opt-in for extending vector loads and
+  // truncating vector stores.
+  const auto *MGN = dyn_cast<MaskedGatherSDNode>(N);
+  const auto *MSN = dyn_cast<MaskedScatterSDNode>(N);
+  assert((!MGN || MGN->getExtensionType() == ISD::NON_EXTLOAD) &&
          "Unexpected extending MGATHER");
+  assert((!MSN || !MSN->isTruncatingStore()) &&
+         "Unexpected extending MSCATTER");
 
-  SDValue VL = getDefaultVLOps(VT, VT, DL, DAG, Subtarget).second;
   // If the mask is known to be all ones, optimize to an unmasked intrinsic;
   // the selection of the masked intrinsics doesn't do this for us.
-  if (ISD::isConstantSplatVectorAllOnes(Mask.getNode())) {
-    SDValue IntID = DAG.getTargetConstant(Intrinsic::riscv_vloxei, DL, XLenVT);
-    SDValue Ops[] = {N->getChain(), IntID, N->getBasePtr(), Index, VL};
-    return DAG.getMemIntrinsicNode(ISD::INTRINSIC_W_CHAIN, DL,
-                                   DAG.getVTList(VT, MVT::Other), Ops,
-                                   N->getMemoryVT(), N->getMemOperand());
-  }
+  unsigned IntID = 0;
+  MVT IndexVT = Index.getSimpleValueType();
+  SDValue VL = getDefaultVLOps(IndexVT, IndexVT, DL, DAG, Subtarget).second;
+  bool IsUnmasked = ISD::isConstantSplatVectorAllOnes(Mask.getNode());
 
-  SDValue IntID =
-      DAG.getTargetConstant(Intrinsic::riscv_vloxei_mask, DL, XLenVT);
-  SDValue Ops[] = {N->getChain(), IntID, PassThru, N->getBasePtr(),
-                   Index, Mask, VL};
-  return DAG.getMemIntrinsicNode(ISD::INTRINSIC_W_CHAIN, DL,
-                                 DAG.getVTList(VT, MVT::Other), Ops,
-                                 N->getMemoryVT(), N->getMemOperand());
+  if (IsUnmasked)
+    IntID = MGN ? Intrinsic::riscv_vloxei : Intrinsic::riscv_vsoxei;
+  else
+    IntID = MGN ? Intrinsic::riscv_vloxei_mask : Intrinsic::riscv_vsoxei_mask;
+  SmallVector<SDValue, 8> Ops{N->getChain(),
+                              DAG.getTargetConstant(IntID, DL, XLenVT)};
+  if (MSN)
+    Ops.push_back(MSN->getValue());
+  else if (!IsUnmasked)
+    Ops.push_back(MGN->getPassThru());
+  Ops.push_back(N->getBasePtr());
+  Ops.push_back(Index);
+  if (!IsUnmasked)
+    Ops.push_back(Mask);
+  Ops.push_back(VL);
+
+  return DAG.getMemIntrinsicNode(
+      MGN ? ISD::INTRINSIC_W_CHAIN : ISD::INTRINSIC_VOID, DL, N->getVTList(),
+      Ops, N->getMemoryVT(), N->getMemOperand());
 }
 
 // Returns the opcode of the target-specific SDNode that implements the 32-bit
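A sketch of the kind of input that takes the new IsUnmasked path, reusing the intrinsic declaration from the earlier example (names are illustrative). When the mask is a constant all-ones splat, the lowering above should select the unmasked riscv_vsoxei intrinsic and drop the pass-thru and mask operands:

  ; All-ones scalable mask built with the usual splat idiom; this should hit
  ; the IsUnmasked path and pick the unmasked vsoxei intrinsic.
  define void @mscatter_truemask_nxv2i32(<vscale x 2 x i32> %val, <vscale x 2 x i32*> %ptrs) {
    %mhead = insertelement <vscale x 2 x i1> undef, i1 true, i32 0
    %mtrue = shufflevector <vscale x 2 x i1> %mhead, <vscale x 2 x i1> undef, <vscale x 2 x i32> zeroinitializer
    call void @llvm.masked.scatter.nxv2i32.nxv2p0i32(<vscale x 2 x i32> %val, <vscale x 2 x i32*> %ptrs, i32 4, <vscale x 2 x i1> %mtrue)
    ret void
  }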
@@ -4519,18 +4534,19 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
     return DAG.getNode(ISD::FCOPYSIGN, DL, VT, N->getOperand(0),
                        DAG.getNode(ISD::FNEG, DL, VT, NewFPExtRound));
   }
-  case ISD::MGATHER: {
+  case ISD::MGATHER:
+  case ISD::MSCATTER: {
     if (!DCI.isBeforeLegalize())
       break;
-    MaskedGatherSDNode *MGN = cast<MaskedGatherSDNode>(N);
-    SDValue Index = MGN->getIndex();
+    MaskedGatherScatterSDNode *MGSN = cast<MaskedGatherScatterSDNode>(N);
+    SDValue Index = MGSN->getIndex();
     EVT IndexVT = Index.getValueType();
     MVT XLenVT = Subtarget.getXLenVT();
     // RISCV indexed loads only support the "unsigned unscaled" addressing
     // mode, so anything else must be manually legalized.
-    bool NeedsIdxLegalization =
-        MGN->isIndexScaled() ||
-        (MGN->isIndexSigned() && IndexVT.getVectorElementType().bitsLT(XLenVT));
+    bool NeedsIdxLegalization = MGSN->isIndexScaled() ||
+                                (MGSN->isIndexSigned() &&
+                                 IndexVT.getVectorElementType().bitsLT(XLenVT));
     if (!NeedsIdxLegalization)
       break;
@@ -4541,13 +4557,13 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
     // LLVM's legalization take care of the splitting.
     if (IndexVT.getVectorElementType().bitsLT(XLenVT)) {
       IndexVT = IndexVT.changeVectorElementType(XLenVT);
-      Index = DAG.getNode(MGN->isIndexSigned() ? ISD::SIGN_EXTEND
-                              : ISD::ZERO_EXTEND,
+      Index = DAG.getNode(MGSN->isIndexSigned() ? ISD::SIGN_EXTEND
+                                                : ISD::ZERO_EXTEND,
                           DL, IndexVT, Index);
     }
 
     unsigned Scale = N->getConstantOperandVal(5);
-    if (MGN->isIndexScaled() && Scale != 1) {
+    if (MGSN->isIndexScaled() && Scale != 1) {
       // Manually scale the indices by the element size.
       // TODO: Sanitize the scale operand here?
       assert(isPowerOf2_32(Scale) && "Expecting power-of-two types");
@@ -4556,11 +4572,19 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
     }
 
     ISD::MemIndexType NewIndexTy = ISD::UNSIGNED_UNSCALED;
-    return DAG.getMaskedGather(
-        N->getVTList(), MGN->getMemoryVT(), DL,
-        {MGN->getChain(), MGN->getPassThru(), MGN->getMask(), MGN->getBasePtr(),
-         Index, MGN->getScale()},
-        MGN->getMemOperand(), NewIndexTy, MGN->getExtensionType());
+    if (const auto *MGN = dyn_cast<MaskedGatherSDNode>(N)) {
+      return DAG.getMaskedGather(
+          N->getVTList(), MGSN->getMemoryVT(), DL,
+          {MGSN->getChain(), MGN->getPassThru(), MGSN->getMask(),
+           MGSN->getBasePtr(), Index, MGN->getScale()},
+          MGN->getMemOperand(), NewIndexTy, MGN->getExtensionType());
+    }
+    const auto *MSN = cast<MaskedScatterSDNode>(N);
+    return DAG.getMaskedScatter(
+        N->getVTList(), MGSN->getMemoryVT(), DL,
+        {MGSN->getChain(), MSN->getValue(), MGSN->getMask(), MGSN->getBasePtr(),
+         Index, MGSN->getScale()},
+        MGSN->getMemOperand(), NewIndexTy, MSN->isTruncatingStore());
   }
   }
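A rough example of IR that exercises this combine, again with illustrative names and reusing the earlier intrinsic declaration. A scatter whose addresses come from a GEP with narrow signed indices reaches the DAG as an MSCATTER with signed, scaled indices; the combine above sign-extends the indices to XLEN (i64 on rv64) and scales them by the element size so the scatter can use the unsigned-unscaled form the indexed store instructions expect:

  ; Indices are nxv2i16 and signed; the scale implied by the i32 GEP element
  ; type is 4. The combine should rewrite this to XLEN-wide indices in the
  ; UNSIGNED_UNSCALED addressing mode.
  define void @mscatter_baseidx_nxv2i16(<vscale x 2 x i32> %val, i32* %base, <vscale x 2 x i16> %idxs, <vscale x 2 x i1> %m) {
    %ptrs = getelementptr inbounds i32, i32* %base, <vscale x 2 x i16> %idxs
    call void @llvm.masked.scatter.nxv2i32.nxv2p0i32(<vscale x 2 x i32> %val, <vscale x 2 x i32*> %ptrs, i32 4, <vscale x 2 x i1> %m)
    ret void
  }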

llvm/lib/Target/RISCV/RISCVISelLowering.h

@@ -477,7 +477,7 @@ private:
   SDValue lowerABS(SDValue Op, SelectionDAG &DAG) const;
   SDValue lowerFixedLengthVectorFCOPYSIGNToRVV(SDValue Op,
                                                SelectionDAG &DAG) const;
-  SDValue lowerMGATHER(SDValue Op, SelectionDAG &DAG) const;
+  SDValue lowerMGATHERMSCATTER(SDValue Op, SelectionDAG &DAG) const;
   SDValue lowerFixedLengthVectorLoadToRVV(SDValue Op, SelectionDAG &DAG) const;
   SDValue lowerFixedLengthVectorStoreToRVV(SDValue Op, SelectionDAG &DAG) const;
   SDValue lowerFixedLengthVectorMaskedLoadToRVV(SDValue Op,

(Third changed file: diff suppressed because it is too large.)