forked from OSchip/llvm-project
[RISCV] Add support for fixed vector masked gather/scatter.
I've split the gather/scatter custom handler to avoid complicating it with even more differences between gather/scatter. Tests are the scalable vector tests with the vscale removed and dropped the tests that used vector.insert. We're probably not as thorough on the splitting cases since we use 128 for VLEN here but scalable vector use a known min size of 64. Reviewed By: frasercrmck Differential Revision: https://reviews.llvm.org/D98991
This commit is contained in:
parent
5184f69041
commit
294efcd6f7
|
@ -585,6 +585,8 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
|
|||
|
||||
setOperationAction(ISD::MLOAD, VT, Custom);
|
||||
setOperationAction(ISD::MSTORE, VT, Custom);
|
||||
setOperationAction(ISD::MGATHER, VT, Custom);
|
||||
setOperationAction(ISD::MSCATTER, VT, Custom);
|
||||
setOperationAction(ISD::ADD, VT, Custom);
|
||||
setOperationAction(ISD::MUL, VT, Custom);
|
||||
setOperationAction(ISD::SUB, VT, Custom);
|
||||
|
@ -656,6 +658,8 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
|
|||
setOperationAction(ISD::STORE, VT, Custom);
|
||||
setOperationAction(ISD::MLOAD, VT, Custom);
|
||||
setOperationAction(ISD::MSTORE, VT, Custom);
|
||||
setOperationAction(ISD::MGATHER, VT, Custom);
|
||||
setOperationAction(ISD::MSCATTER, VT, Custom);
|
||||
setOperationAction(ISD::FADD, VT, Custom);
|
||||
setOperationAction(ISD::FSUB, VT, Custom);
|
||||
setOperationAction(ISD::FMUL, VT, Custom);
|
||||
|
@ -1724,8 +1728,9 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
|
|||
case ISD::FCOPYSIGN:
|
||||
return lowerFixedLengthVectorFCOPYSIGNToRVV(Op, DAG);
|
||||
case ISD::MGATHER:
|
||||
return lowerMGATHER(Op, DAG);
|
||||
case ISD::MSCATTER:
|
||||
return lowerMGATHERMSCATTER(Op, DAG);
|
||||
return lowerMSCATTER(Op, DAG);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -3487,54 +3492,154 @@ SDValue RISCVTargetLowering::lowerToScalableOp(SDValue Op, SelectionDAG &DAG,
|
|||
}
|
||||
|
||||
// Custom lower MGATHER to a legalized form for RVV. It will then be matched to
|
||||
// a RVV indexed load. The RVV indexed load/store instructions only support the
|
||||
// a RVV indexed load. The RVV indexed load instructions only support the
|
||||
// "unsigned unscaled" addressing mode; indices are implicitly zero-extended or
|
||||
// truncated to XLEN and are treated as byte offsets. Any signed or scaled
|
||||
// indexing is extended to the XLEN value type and scaled accordingly.
|
||||
SDValue RISCVTargetLowering::lowerMGATHERMSCATTER(SDValue Op,
|
||||
SelectionDAG &DAG) const {
|
||||
auto *N = cast<MaskedGatherScatterSDNode>(Op.getNode());
|
||||
SDValue RISCVTargetLowering::lowerMGATHER(SDValue Op, SelectionDAG &DAG) const {
|
||||
auto *MGN = cast<MaskedGatherSDNode>(Op.getNode());
|
||||
SDLoc DL(Op);
|
||||
SDValue Index = N->getIndex();
|
||||
SDValue Mask = N->getMask();
|
||||
|
||||
SDValue Index = MGN->getIndex();
|
||||
SDValue Mask = MGN->getMask();
|
||||
SDValue PassThru = MGN->getPassThru();
|
||||
|
||||
MVT VT = Op.getSimpleValueType();
|
||||
MVT IndexVT = Index.getSimpleValueType();
|
||||
MVT XLenVT = Subtarget.getXLenVT();
|
||||
assert(N->getBasePtr().getSimpleValueType() == XLenVT &&
|
||||
|
||||
assert(VT.getVectorElementCount() == IndexVT.getVectorElementCount() &&
|
||||
"Unexpected VTs!");
|
||||
assert(MGN->getBasePtr().getSimpleValueType() == XLenVT &&
|
||||
"Unexpected pointer type");
|
||||
// Targets have to explicitly opt-in for extending vector loads and
|
||||
// truncating vector stores.
|
||||
const auto *MGN = dyn_cast<MaskedGatherSDNode>(N);
|
||||
const auto *MSN = dyn_cast<MaskedScatterSDNode>(N);
|
||||
assert((!MGN || MGN->getExtensionType() == ISD::NON_EXTLOAD) &&
|
||||
// Targets have to explicitly opt-in for extending vector loads.
|
||||
assert(MGN->getExtensionType() == ISD::NON_EXTLOAD &&
|
||||
"Unexpected extending MGATHER");
|
||||
assert((!MSN || !MSN->isTruncatingStore()) &&
|
||||
"Unexpected extending MSCATTER");
|
||||
|
||||
// If the mask is known to be all ones, optimize to an unmasked intrinsic;
|
||||
// the selection of the masked intrinsics doesn't do this for us.
|
||||
unsigned IntID = 0;
|
||||
MVT IndexVT = Index.getSimpleValueType();
|
||||
SDValue VL = getDefaultVLOps(IndexVT, IndexVT, DL, DAG, Subtarget).second;
|
||||
bool IsUnmasked = ISD::isConstantSplatVectorAllOnes(Mask.getNode());
|
||||
|
||||
if (IsUnmasked)
|
||||
IntID = MGN ? Intrinsic::riscv_vloxei : Intrinsic::riscv_vsoxei;
|
||||
else
|
||||
IntID = MGN ? Intrinsic::riscv_vloxei_mask : Intrinsic::riscv_vsoxei_mask;
|
||||
SmallVector<SDValue, 8> Ops{N->getChain(),
|
||||
SDValue VL;
|
||||
MVT ContainerVT = VT;
|
||||
if (VT.isFixedLengthVector()) {
|
||||
// We need to use the larger of the result and index type to determine the
|
||||
// scalable type to use so we don't increase LMUL for any operand/result.
|
||||
if (VT.bitsGE(IndexVT)) {
|
||||
ContainerVT = getContainerForFixedLengthVector(VT);
|
||||
IndexVT = MVT::getVectorVT(IndexVT.getVectorElementType(),
|
||||
ContainerVT.getVectorElementCount());
|
||||
} else {
|
||||
IndexVT = getContainerForFixedLengthVector(IndexVT);
|
||||
ContainerVT = MVT::getVectorVT(ContainerVT.getVectorElementType(),
|
||||
IndexVT.getVectorElementCount());
|
||||
}
|
||||
|
||||
Index = convertToScalableVector(IndexVT, Index, DAG, Subtarget);
|
||||
|
||||
if (!IsUnmasked) {
|
||||
MVT MaskVT =
|
||||
MVT::getVectorVT(MVT::i1, ContainerVT.getVectorElementCount());
|
||||
Mask = convertToScalableVector(MaskVT, Mask, DAG, Subtarget);
|
||||
PassThru = convertToScalableVector(ContainerVT, PassThru, DAG, Subtarget);
|
||||
}
|
||||
|
||||
VL = DAG.getConstant(VT.getVectorNumElements(), DL, XLenVT);
|
||||
} else
|
||||
VL = DAG.getRegister(RISCV::X0, XLenVT);
|
||||
|
||||
unsigned IntID =
|
||||
IsUnmasked ? Intrinsic::riscv_vloxei : Intrinsic::riscv_vloxei_mask;
|
||||
SmallVector<SDValue, 8> Ops{MGN->getChain(),
|
||||
DAG.getTargetConstant(IntID, DL, XLenVT)};
|
||||
if (MSN)
|
||||
Ops.push_back(MSN->getValue());
|
||||
else if (!IsUnmasked)
|
||||
Ops.push_back(MGN->getPassThru());
|
||||
Ops.push_back(N->getBasePtr());
|
||||
if (!IsUnmasked)
|
||||
Ops.push_back(PassThru);
|
||||
Ops.push_back(MGN->getBasePtr());
|
||||
Ops.push_back(Index);
|
||||
if (!IsUnmasked)
|
||||
Ops.push_back(Mask);
|
||||
Ops.push_back(VL);
|
||||
return DAG.getMemIntrinsicNode(
|
||||
MGN ? ISD::INTRINSIC_W_CHAIN : ISD::INTRINSIC_VOID, DL, N->getVTList(),
|
||||
Ops, N->getMemoryVT(), N->getMemOperand());
|
||||
|
||||
SDVTList VTs = DAG.getVTList({ContainerVT, MVT::Other});
|
||||
SDValue Result =
|
||||
DAG.getMemIntrinsicNode(ISD::INTRINSIC_W_CHAIN, DL, VTs, Ops,
|
||||
MGN->getMemoryVT(), MGN->getMemOperand());
|
||||
SDValue Chain = Result.getValue(1);
|
||||
|
||||
if (VT.isFixedLengthVector())
|
||||
Result = convertFromScalableVector(VT, Result, DAG, Subtarget);
|
||||
|
||||
return DAG.getMergeValues({Result, Chain}, DL);
|
||||
}
|
||||
|
||||
// Custom lower MSCATTER to a legalized form for RVV. It will then be matched to
|
||||
// a RVV indexed store. The RVV indexed store instructions only support the
|
||||
// "unsigned unscaled" addressing mode; indices are implicitly zero-extended or
|
||||
// truncated to XLEN and are treated as byte offsets. Any signed or scaled
|
||||
// indexing is extended to the XLEN value type and scaled accordingly.
|
||||
SDValue RISCVTargetLowering::lowerMSCATTER(SDValue Op,
|
||||
SelectionDAG &DAG) const {
|
||||
auto *MSN = cast<MaskedScatterSDNode>(Op.getNode());
|
||||
SDLoc DL(Op);
|
||||
SDValue Index = MSN->getIndex();
|
||||
SDValue Mask = MSN->getMask();
|
||||
SDValue Val = MSN->getValue();
|
||||
|
||||
MVT VT = Val.getSimpleValueType();
|
||||
MVT IndexVT = Index.getSimpleValueType();
|
||||
MVT XLenVT = Subtarget.getXLenVT();
|
||||
|
||||
assert(VT.getVectorElementCount() == IndexVT.getVectorElementCount() &&
|
||||
"Unexpected VTs!");
|
||||
assert(MSN->getBasePtr().getSimpleValueType() == XLenVT &&
|
||||
"Unexpected pointer type");
|
||||
// Targets have to explicitly opt-in for extending vector loads and
|
||||
// truncating vector stores.
|
||||
assert(!MSN->isTruncatingStore() && "Unexpected extending MSCATTER");
|
||||
|
||||
// If the mask is known to be all ones, optimize to an unmasked intrinsic;
|
||||
// the selection of the masked intrinsics doesn't do this for us.
|
||||
bool IsUnmasked = ISD::isConstantSplatVectorAllOnes(Mask.getNode());
|
||||
|
||||
SDValue VL;
|
||||
if (VT.isFixedLengthVector()) {
|
||||
// We need to use the larger of the value and index type to determine the
|
||||
// scalable type to use so we don't increase LMUL for any operand/result.
|
||||
if (VT.bitsGE(IndexVT)) {
|
||||
VT = getContainerForFixedLengthVector(VT);
|
||||
IndexVT = MVT::getVectorVT(IndexVT.getVectorElementType(),
|
||||
VT.getVectorElementCount());
|
||||
} else {
|
||||
IndexVT = getContainerForFixedLengthVector(IndexVT);
|
||||
VT = MVT::getVectorVT(VT.getVectorElementType(),
|
||||
IndexVT.getVectorElementCount());
|
||||
}
|
||||
|
||||
Index = convertToScalableVector(IndexVT, Index, DAG, Subtarget);
|
||||
Val = convertToScalableVector(VT, Val, DAG, Subtarget);
|
||||
|
||||
if (!IsUnmasked) {
|
||||
MVT MaskVT = MVT::getVectorVT(MVT::i1, VT.getVectorElementCount());
|
||||
Mask = convertToScalableVector(MaskVT, Mask, DAG, Subtarget);
|
||||
}
|
||||
|
||||
VL = DAG.getConstant(VT.getVectorNumElements(), DL, XLenVT);
|
||||
} else
|
||||
VL = DAG.getRegister(RISCV::X0, XLenVT);
|
||||
|
||||
unsigned IntID =
|
||||
IsUnmasked ? Intrinsic::riscv_vsoxei : Intrinsic::riscv_vsoxei_mask;
|
||||
SmallVector<SDValue, 8> Ops{MSN->getChain(),
|
||||
DAG.getTargetConstant(IntID, DL, XLenVT)};
|
||||
Ops.push_back(Val);
|
||||
Ops.push_back(MSN->getBasePtr());
|
||||
Ops.push_back(Index);
|
||||
if (!IsUnmasked)
|
||||
Ops.push_back(Mask);
|
||||
Ops.push_back(VL);
|
||||
|
||||
return DAG.getMemIntrinsicNode(ISD::INTRINSIC_VOID, DL, MSN->getVTList(), Ops,
|
||||
MSN->getMemoryVT(), MSN->getMemOperand());
|
||||
}
|
||||
|
||||
// Returns the opcode of the target-specific SDNode that implements the 32-bit
|
||||
|
|
|
@ -479,7 +479,8 @@ private:
|
|||
SDValue lowerMSTORE(SDValue Op, SelectionDAG &DAG) const;
|
||||
SDValue lowerFixedLengthVectorFCOPYSIGNToRVV(SDValue Op,
|
||||
SelectionDAG &DAG) const;
|
||||
SDValue lowerMGATHERMSCATTER(SDValue Op, SelectionDAG &DAG) const;
|
||||
SDValue lowerMGATHER(SDValue Op, SelectionDAG &DAG) const;
|
||||
SDValue lowerMSCATTER(SDValue Op, SelectionDAG &DAG) const;
|
||||
SDValue lowerFixedLengthVectorLoadToRVV(SDValue Op, SelectionDAG &DAG) const;
|
||||
SDValue lowerFixedLengthVectorStoreToRVV(SDValue Op, SelectionDAG &DAG) const;
|
||||
SDValue lowerFixedLengthVectorSetccToRVV(SDValue Op, SelectionDAG &DAG) const;
|
||||
|
|
|
@ -61,15 +61,7 @@ public:
|
|||
return ST->getXLen();
|
||||
}
|
||||
|
||||
bool isLegalMaskedLoadStore(Type *DataType, Align Alignment) {
|
||||
if (!ST->hasStdExtV())
|
||||
return false;
|
||||
|
||||
// Only support fixed vectors if we know the minimum vector size.
|
||||
if (isa<FixedVectorType>(DataType) && ST->getMinRVVVectorSizeInBits() == 0)
|
||||
return false;
|
||||
|
||||
Type *ScalarTy = DataType->getScalarType();
|
||||
bool isLegalElementTypeForRVV(Type *ScalarTy) {
|
||||
if (ScalarTy->isPointerTy())
|
||||
return true;
|
||||
|
||||
|
@ -87,12 +79,41 @@ public:
|
|||
return false;
|
||||
}
|
||||
|
||||
bool isLegalMaskedLoadStore(Type *DataType, Align Alignment) {
|
||||
if (!ST->hasStdExtV())
|
||||
return false;
|
||||
|
||||
// Only support fixed vectors if we know the minimum vector size.
|
||||
if (isa<FixedVectorType>(DataType) && ST->getMinRVVVectorSizeInBits() == 0)
|
||||
return false;
|
||||
|
||||
return isLegalElementTypeForRVV(DataType->getScalarType());
|
||||
}
|
||||
|
||||
bool isLegalMaskedLoad(Type *DataType, Align Alignment) {
|
||||
return isLegalMaskedLoadStore(DataType, Alignment);
|
||||
}
|
||||
bool isLegalMaskedStore(Type *DataType, Align Alignment) {
|
||||
return isLegalMaskedLoadStore(DataType, Alignment);
|
||||
}
|
||||
|
||||
bool isLegalMaskedGatherScatter(Type *DataType, Align Alignment) {
|
||||
if (!ST->hasStdExtV())
|
||||
return false;
|
||||
|
||||
// Only support fixed vectors if we know the minimum vector size.
|
||||
if (isa<FixedVectorType>(DataType) && ST->getMinRVVVectorSizeInBits() == 0)
|
||||
return false;
|
||||
|
||||
return isLegalElementTypeForRVV(DataType->getScalarType());
|
||||
}
|
||||
|
||||
bool isLegalMaskedGather(Type *DataType, Align Alignment) {
|
||||
return isLegalMaskedGatherScatter(DataType, Alignment);
|
||||
}
|
||||
bool isLegalMaskedScatter(Type *DataType, Align Alignment) {
|
||||
return isLegalMaskedGatherScatter(DataType, Alignment);
|
||||
}
|
||||
};
|
||||
|
||||
} // end namespace llvm
|
||||
|
|
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue