[RISCV] Add support for fixed vector masked gather/scatter.

I've split the combined gather/scatter custom handler in two to avoid
complicating it with even more differences between gather and scatter.

The tests are the scalable vector tests with the vscale removed and with
the tests that used vector.insert dropped. We're probably not as thorough
on the splitting cases, since we use 128 for VLEN here whereas the
scalable vector tests use a known minimum size of 64.

Reviewed By: frasercrmck

Differential Revision: https://reviews.llvm.org/D98991
Craig Topper 2021-03-22 09:54:17 -07:00
parent 5184f69041
commit 294efcd6f7
5 changed files with 4422 additions and 41 deletions
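
As a concrete illustration of the fixed-vector tests described above (this sketch is not copied from the added test files, whose diffs are suppressed below; the element count and types are illustrative), a masked gather over a plain fixed-length vector now goes down the new custom-lowering path:

declare <8 x i32> @llvm.masked.gather.v8i32.v8p0i32(<8 x i32*>, i32, <8 x i1>, <8 x i32>)

define <8 x i32> @gather_v8i32(<8 x i32*> %ptrs, <8 x i1> %mask, <8 x i32> %passthru) {
  ; With the V extension and a known minimum vector length, this call is kept as a
  ; gather and custom-lowered to an RVV indexed load (vloxei) instead of being scalarized.
  %v = call <8 x i32> @llvm.masked.gather.v8i32.v8p0i32(<8 x i32*> %ptrs, i32 4, <8 x i1> %mask, <8 x i32> %passthru)
  ret <8 x i32> %v
}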


@@ -585,6 +585,8 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
setOperationAction(ISD::MLOAD, VT, Custom);
setOperationAction(ISD::MSTORE, VT, Custom);
setOperationAction(ISD::MGATHER, VT, Custom);
setOperationAction(ISD::MSCATTER, VT, Custom);
setOperationAction(ISD::ADD, VT, Custom);
setOperationAction(ISD::MUL, VT, Custom);
setOperationAction(ISD::SUB, VT, Custom);
@@ -656,6 +658,8 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
setOperationAction(ISD::STORE, VT, Custom);
setOperationAction(ISD::MLOAD, VT, Custom);
setOperationAction(ISD::MSTORE, VT, Custom);
setOperationAction(ISD::MGATHER, VT, Custom);
setOperationAction(ISD::MSCATTER, VT, Custom);
setOperationAction(ISD::FADD, VT, Custom);
setOperationAction(ISD::FSUB, VT, Custom);
setOperationAction(ISD::FMUL, VT, Custom);
@@ -1724,8 +1728,9 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
case ISD::FCOPYSIGN:
return lowerFixedLengthVectorFCOPYSIGNToRVV(Op, DAG);
case ISD::MGATHER:
return lowerMGATHER(Op, DAG);
case ISD::MSCATTER:
return lowerMGATHERMSCATTER(Op, DAG);
return lowerMSCATTER(Op, DAG);
}
}
@@ -3487,54 +3492,154 @@ SDValue RISCVTargetLowering::lowerToScalableOp(SDValue Op, SelectionDAG &DAG,
}
// Custom lower MGATHER to a legalized form for RVV. It will then be matched to
// a RVV indexed load. The RVV indexed load/store instructions only support the
// a RVV indexed load. The RVV indexed load instructions only support the
// "unsigned unscaled" addressing mode; indices are implicitly zero-extended or
// truncated to XLEN and are treated as byte offsets. Any signed or scaled
// indexing is extended to the XLEN value type and scaled accordingly.
SDValue RISCVTargetLowering::lowerMGATHERMSCATTER(SDValue Op,
SelectionDAG &DAG) const {
auto *N = cast<MaskedGatherScatterSDNode>(Op.getNode());
SDValue RISCVTargetLowering::lowerMGATHER(SDValue Op, SelectionDAG &DAG) const {
auto *MGN = cast<MaskedGatherSDNode>(Op.getNode());
SDLoc DL(Op);
SDValue Index = N->getIndex();
SDValue Mask = N->getMask();
SDValue Index = MGN->getIndex();
SDValue Mask = MGN->getMask();
SDValue PassThru = MGN->getPassThru();
MVT VT = Op.getSimpleValueType();
MVT IndexVT = Index.getSimpleValueType();
MVT XLenVT = Subtarget.getXLenVT();
assert(N->getBasePtr().getSimpleValueType() == XLenVT &&
assert(VT.getVectorElementCount() == IndexVT.getVectorElementCount() &&
"Unexpected VTs!");
assert(MGN->getBasePtr().getSimpleValueType() == XLenVT &&
"Unexpected pointer type");
// Targets have to explicitly opt-in for extending vector loads and
// truncating vector stores.
const auto *MGN = dyn_cast<MaskedGatherSDNode>(N);
const auto *MSN = dyn_cast<MaskedScatterSDNode>(N);
assert((!MGN || MGN->getExtensionType() == ISD::NON_EXTLOAD) &&
// Targets have to explicitly opt-in for extending vector loads.
assert(MGN->getExtensionType() == ISD::NON_EXTLOAD &&
"Unexpected extending MGATHER");
assert((!MSN || !MSN->isTruncatingStore()) &&
"Unexpected extending MSCATTER");
// If the mask is known to be all ones, optimize to an unmasked intrinsic;
// the selection of the masked intrinsics doesn't do this for us.
unsigned IntID = 0;
MVT IndexVT = Index.getSimpleValueType();
SDValue VL = getDefaultVLOps(IndexVT, IndexVT, DL, DAG, Subtarget).second;
bool IsUnmasked = ISD::isConstantSplatVectorAllOnes(Mask.getNode());
if (IsUnmasked)
IntID = MGN ? Intrinsic::riscv_vloxei : Intrinsic::riscv_vsoxei;
else
IntID = MGN ? Intrinsic::riscv_vloxei_mask : Intrinsic::riscv_vsoxei_mask;
SmallVector<SDValue, 8> Ops{N->getChain(),
SDValue VL;
MVT ContainerVT = VT;
if (VT.isFixedLengthVector()) {
// We need to use the larger of the result and index type to determine the
// scalable type to use so we don't increase LMUL for any operand/result.
if (VT.bitsGE(IndexVT)) {
ContainerVT = getContainerForFixedLengthVector(VT);
IndexVT = MVT::getVectorVT(IndexVT.getVectorElementType(),
ContainerVT.getVectorElementCount());
} else {
IndexVT = getContainerForFixedLengthVector(IndexVT);
ContainerVT = MVT::getVectorVT(ContainerVT.getVectorElementType(),
IndexVT.getVectorElementCount());
}
Index = convertToScalableVector(IndexVT, Index, DAG, Subtarget);
if (!IsUnmasked) {
MVT MaskVT =
MVT::getVectorVT(MVT::i1, ContainerVT.getVectorElementCount());
Mask = convertToScalableVector(MaskVT, Mask, DAG, Subtarget);
PassThru = convertToScalableVector(ContainerVT, PassThru, DAG, Subtarget);
}
VL = DAG.getConstant(VT.getVectorNumElements(), DL, XLenVT);
} else
VL = DAG.getRegister(RISCV::X0, XLenVT);
unsigned IntID =
IsUnmasked ? Intrinsic::riscv_vloxei : Intrinsic::riscv_vloxei_mask;
SmallVector<SDValue, 8> Ops{MGN->getChain(),
DAG.getTargetConstant(IntID, DL, XLenVT)};
if (MSN)
Ops.push_back(MSN->getValue());
else if (!IsUnmasked)
Ops.push_back(MGN->getPassThru());
Ops.push_back(N->getBasePtr());
if (!IsUnmasked)
Ops.push_back(PassThru);
Ops.push_back(MGN->getBasePtr());
Ops.push_back(Index);
if (!IsUnmasked)
Ops.push_back(Mask);
Ops.push_back(VL);
return DAG.getMemIntrinsicNode(
MGN ? ISD::INTRINSIC_W_CHAIN : ISD::INTRINSIC_VOID, DL, N->getVTList(),
Ops, N->getMemoryVT(), N->getMemOperand());
SDVTList VTs = DAG.getVTList({ContainerVT, MVT::Other});
SDValue Result =
DAG.getMemIntrinsicNode(ISD::INTRINSIC_W_CHAIN, DL, VTs, Ops,
MGN->getMemoryVT(), MGN->getMemOperand());
SDValue Chain = Result.getValue(1);
if (VT.isFixedLengthVector())
Result = convertFromScalableVector(VT, Result, DAG, Subtarget);
return DAG.getMergeValues({Result, Chain}, DL);
}
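
The branch above that picks the wider of the result and index types keeps either operand from being legalized into a larger LMUL than necessary. A hedged, hypothetical example of the mixed-width case it handles (types and names invented for this sketch): gathering i8 data through i64 indices, where the index vector needs the larger container and the data container's element count is matched to it.

define <4 x i8> @gather_i8_i64_idx(i8* %base, <4 x i64> %idx, <4 x i1> %mask, <4 x i8> %passthru) {
  ; The vector GEP carries <4 x i64> indices while the gathered data is <4 x i8>, so
  ; after legalization the index operand needs a wider container than the result.
  ; Whether the GEP is folded into the MGATHER index operand is up to DAG combining;
  ; this only illustrates the type mix.
  %ptrs = getelementptr inbounds i8, i8* %base, <4 x i64> %idx
  %v = call <4 x i8> @llvm.masked.gather.v4i8.v4p0i8(<4 x i8*> %ptrs, i32 1, <4 x i1> %mask, <4 x i8> %passthru)
  ret <4 x i8> %v
}

declare <4 x i8> @llvm.masked.gather.v4i8.v4p0i8(<4 x i8*>, i32, <4 x i1>, <4 x i8>)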
// Custom lower MSCATTER to a legalized form for RVV. It will then be matched to
// a RVV indexed store. The RVV indexed store instructions only support the
// "unsigned unscaled" addressing mode; indices are implicitly zero-extended or
// truncated to XLEN and are treated as byte offsets. Any signed or scaled
// indexing is extended to the XLEN value type and scaled accordingly.
SDValue RISCVTargetLowering::lowerMSCATTER(SDValue Op,
SelectionDAG &DAG) const {
auto *MSN = cast<MaskedScatterSDNode>(Op.getNode());
SDLoc DL(Op);
SDValue Index = MSN->getIndex();
SDValue Mask = MSN->getMask();
SDValue Val = MSN->getValue();
MVT VT = Val.getSimpleValueType();
MVT IndexVT = Index.getSimpleValueType();
MVT XLenVT = Subtarget.getXLenVT();
assert(VT.getVectorElementCount() == IndexVT.getVectorElementCount() &&
"Unexpected VTs!");
assert(MSN->getBasePtr().getSimpleValueType() == XLenVT &&
"Unexpected pointer type");
// Targets have to explicitly opt-in for extending vector loads and
// truncating vector stores.
assert(!MSN->isTruncatingStore() && "Unexpected extending MSCATTER");
// If the mask is known to be all ones, optimize to an unmasked intrinsic;
// the selection of the masked intrinsics doesn't do this for us.
bool IsUnmasked = ISD::isConstantSplatVectorAllOnes(Mask.getNode());
SDValue VL;
if (VT.isFixedLengthVector()) {
// We need to use the larger of the value and index type to determine the
// scalable type to use so we don't increase LMUL for any operand/result.
if (VT.bitsGE(IndexVT)) {
VT = getContainerForFixedLengthVector(VT);
IndexVT = MVT::getVectorVT(IndexVT.getVectorElementType(),
VT.getVectorElementCount());
} else {
IndexVT = getContainerForFixedLengthVector(IndexVT);
VT = MVT::getVectorVT(VT.getVectorElementType(),
IndexVT.getVectorElementCount());
}
Index = convertToScalableVector(IndexVT, Index, DAG, Subtarget);
Val = convertToScalableVector(VT, Val, DAG, Subtarget);
if (!IsUnmasked) {
MVT MaskVT = MVT::getVectorVT(MVT::i1, VT.getVectorElementCount());
Mask = convertToScalableVector(MaskVT, Mask, DAG, Subtarget);
}
VL = DAG.getConstant(VT.getVectorNumElements(), DL, XLenVT);
} else
VL = DAG.getRegister(RISCV::X0, XLenVT);
unsigned IntID =
IsUnmasked ? Intrinsic::riscv_vsoxei : Intrinsic::riscv_vsoxei_mask;
SmallVector<SDValue, 8> Ops{MSN->getChain(),
DAG.getTargetConstant(IntID, DL, XLenVT)};
Ops.push_back(Val);
Ops.push_back(MSN->getBasePtr());
Ops.push_back(Index);
if (!IsUnmasked)
Ops.push_back(Mask);
Ops.push_back(VL);
return DAG.getMemIntrinsicNode(ISD::INTRINSIC_VOID, DL, MSN->getVTList(), Ops,
MSN->getMemoryVT(), MSN->getMemOperand());
}
// Returns the opcode of the target-specific SDNode that implements the 32-bit
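
For completeness, a minimal sketch (again not taken from the suppressed test files; types are illustrative) of the fixed-vector masked scatter form that lowerMSCATTER now handles, which is selected to an RVV indexed store (vsoxei):

declare void @llvm.masked.scatter.v8i32.v8p0i32(<8 x i32>, <8 x i32*>, i32, <8 x i1>)

define void @scatter_v8i32(<8 x i32> %val, <8 x i32*> %ptrs, <8 x i1> %mask) {
  ; Operand order: stored value, pointer vector, alignment, mask.
  call void @llvm.masked.scatter.v8i32.v8p0i32(<8 x i32> %val, <8 x i32*> %ptrs, i32 4, <8 x i1> %mask)
  ret void
}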


@@ -479,7 +479,8 @@ private:
SDValue lowerMSTORE(SDValue Op, SelectionDAG &DAG) const;
SDValue lowerFixedLengthVectorFCOPYSIGNToRVV(SDValue Op,
SelectionDAG &DAG) const;
SDValue lowerMGATHERMSCATTER(SDValue Op, SelectionDAG &DAG) const;
SDValue lowerMGATHER(SDValue Op, SelectionDAG &DAG) const;
SDValue lowerMSCATTER(SDValue Op, SelectionDAG &DAG) const;
SDValue lowerFixedLengthVectorLoadToRVV(SDValue Op, SelectionDAG &DAG) const;
SDValue lowerFixedLengthVectorStoreToRVV(SDValue Op, SelectionDAG &DAG) const;
SDValue lowerFixedLengthVectorSetccToRVV(SDValue Op, SelectionDAG &DAG) const;


@@ -61,15 +61,7 @@ public:
return ST->getXLen();
}
bool isLegalMaskedLoadStore(Type *DataType, Align Alignment) {
if (!ST->hasStdExtV())
return false;
// Only support fixed vectors if we know the minimum vector size.
if (isa<FixedVectorType>(DataType) && ST->getMinRVVVectorSizeInBits() == 0)
return false;
Type *ScalarTy = DataType->getScalarType();
bool isLegalElementTypeForRVV(Type *ScalarTy) {
if (ScalarTy->isPointerTy())
return true;
@@ -87,12 +79,41 @@ public:
return false;
}
bool isLegalMaskedLoadStore(Type *DataType, Align Alignment) {
if (!ST->hasStdExtV())
return false;
// Only support fixed vectors if we know the minimum vector size.
if (isa<FixedVectorType>(DataType) && ST->getMinRVVVectorSizeInBits() == 0)
return false;
return isLegalElementTypeForRVV(DataType->getScalarType());
}
bool isLegalMaskedLoad(Type *DataType, Align Alignment) {
return isLegalMaskedLoadStore(DataType, Alignment);
}
bool isLegalMaskedStore(Type *DataType, Align Alignment) {
return isLegalMaskedLoadStore(DataType, Alignment);
}
bool isLegalMaskedGatherScatter(Type *DataType, Align Alignment) {
if (!ST->hasStdExtV())
return false;
// Only support fixed vectors if we know the minimum vector size.
if (isa<FixedVectorType>(DataType) && ST->getMinRVVVectorSizeInBits() == 0)
return false;
return isLegalElementTypeForRVV(DataType->getScalarType());
}
bool isLegalMaskedGather(Type *DataType, Align Alignment) {
return isLegalMaskedGatherScatter(DataType, Alignment);
}
bool isLegalMaskedScatter(Type *DataType, Align Alignment) {
return isLegalMaskedGatherScatter(DataType, Alignment);
}
};
} // end namespace llvm

File diff suppressed because it is too large.

File diff suppressed because it is too large.