[RISCV] Lower scalable vector masked loads and stores to intrinsics to match fixed vectors and reduce isel patterns.

Reviewed By: frasercrmck

Differential Revision: https://reviews.llvm.org/D98840
Author: Craig Topper
Date:   2021-03-19 10:39:33 -07:00
Commit: 85f3f6b3cc (parent 3587728ed5)

3 files changed, 41 insertions(+), 63 deletions(-)
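
The heart of the patch: masked loads and stores now share one lowering path for fixed-length and scalable vectors, and the only difference between the two flavors is how the VL operand is computed. A condensed sketch of that shape (illustrative, not the verbatim upstream code; it assumes the RISCVTargetLowering context visible in the diff below, with DAG, DL, VT, XLenVT and Subtarget in scope):

    // Fixed-length vectors know their element count at compile time;
    // for scalable vectors, passing X0 as the AVL operand requests VLMAX.
    SDValue VL;
    if (VT.isFixedLengthVector())
      VL = DAG.getConstant(VT.getVectorNumElements(), DL, XLenVT);
    else
      VL = DAG.getRegister(RISCV::X0, XLenVT);
    // Both flavors then emit the same riscv_vle_mask / riscv_vse_mask
    // memory-intrinsic node, which the existing intrinsic isel handles.

Because scalable masked loads and stores no longer reach instruction selection as MLOAD/MSTORE nodes, the dedicated masked_load/masked_store PatFrags and the VPatUSLoadStoreSDNodeMask patterns become dead and are deleted below.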

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

@@ -474,6 +474,8 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
       setOperationAction(ISD::VECREDUCE_UMAX, VT, Custom);
       setOperationAction(ISD::VECREDUCE_UMIN, VT, Custom);
+      setOperationAction(ISD::MLOAD, VT, Custom);
+      setOperationAction(ISD::MSTORE, VT, Custom);
       setOperationAction(ISD::MGATHER, VT, Custom);
       setOperationAction(ISD::MSCATTER, VT, Custom);
@@ -517,6 +519,8 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
       setOperationAction(ISD::VECREDUCE_SEQ_FADD, VT, Custom);
       setOperationAction(ISD::FCOPYSIGN, VT, Legal);
+      setOperationAction(ISD::MLOAD, VT, Custom);
+      setOperationAction(ISD::MSTORE, VT, Custom);
       setOperationAction(ISD::MGATHER, VT, Custom);
       setOperationAction(ISD::MSCATTER, VT, Custom);
@@ -1651,9 +1655,9 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
   case ISD::STORE:
     return lowerFixedLengthVectorStoreToRVV(Op, DAG);
   case ISD::MLOAD:
-    return lowerFixedLengthVectorMaskedLoadToRVV(Op, DAG);
+    return lowerMLOAD(Op, DAG);
   case ISD::MSTORE:
-    return lowerFixedLengthVectorMaskedStoreToRVV(Op, DAG);
+    return lowerMSTORE(Op, DAG);
   case ISD::SETCC:
     return lowerFixedLengthVectorSetccToRVV(Op, DAG);
   case ISD::ADD:
@@ -3194,50 +3198,63 @@ RISCVTargetLowering::lowerFixedLengthVectorStoreToRVV(SDValue Op,
                              Store->getMemoryVT(), Store->getMemOperand());
 }
 
-SDValue RISCVTargetLowering::lowerFixedLengthVectorMaskedLoadToRVV(
-    SDValue Op, SelectionDAG &DAG) const {
+SDValue RISCVTargetLowering::lowerMLOAD(SDValue Op, SelectionDAG &DAG) const {
   auto *Load = cast<MaskedLoadSDNode>(Op);
 
   SDLoc DL(Op);
   MVT VT = Op.getSimpleValueType();
-  MVT ContainerVT = getContainerForFixedLengthVector(VT);
-  MVT MaskVT = MVT::getVectorVT(MVT::i1, ContainerVT.getVectorElementCount());
   MVT XLenVT = Subtarget.getXLenVT();
 
-  SDValue Mask =
-      convertToScalableVector(MaskVT, Load->getMask(), DAG, Subtarget);
-  SDValue PassThru =
-      convertToScalableVector(ContainerVT, Load->getPassThru(), DAG, Subtarget);
-  SDValue VL = DAG.getConstant(VT.getVectorNumElements(), DL, XLenVT);
+  SDValue Mask = Load->getMask();
+  SDValue PassThru = Load->getPassThru();
+  SDValue VL;
+
+  MVT ContainerVT = VT;
+  if (VT.isFixedLengthVector()) {
+    ContainerVT = getContainerForFixedLengthVector(VT);
+    MVT MaskVT = MVT::getVectorVT(MVT::i1, ContainerVT.getVectorElementCount());
+
+    Mask = convertToScalableVector(MaskVT, Mask, DAG, Subtarget);
+    PassThru = convertToScalableVector(ContainerVT, PassThru, DAG, Subtarget);
+    VL = DAG.getConstant(VT.getVectorNumElements(), DL, XLenVT);
+  } else
+    VL = DAG.getRegister(RISCV::X0, XLenVT);
 
   SDVTList VTs = DAG.getVTList({ContainerVT, MVT::Other});
 
   SDValue IntID = DAG.getTargetConstant(Intrinsic::riscv_vle_mask, DL, XLenVT);
   SDValue Ops[] = {Load->getChain(), IntID, PassThru,
                    Load->getBasePtr(), Mask, VL};
-  SDValue NewLoad =
+  SDValue Result =
       DAG.getMemIntrinsicNode(ISD::INTRINSIC_W_CHAIN, DL, VTs, Ops,
                               Load->getMemoryVT(), Load->getMemOperand());
+  SDValue Chain = Result.getValue(1);
 
-  SDValue Result = convertFromScalableVector(VT, NewLoad, DAG, Subtarget);
-  return DAG.getMergeValues({Result, NewLoad.getValue(1)}, DL);
+  if (VT.isFixedLengthVector())
+    Result = convertFromScalableVector(VT, Result, DAG, Subtarget);
+
+  return DAG.getMergeValues({Result, Chain}, DL);
 }
 
-SDValue RISCVTargetLowering::lowerFixedLengthVectorMaskedStoreToRVV(
-    SDValue Op, SelectionDAG &DAG) const {
+SDValue RISCVTargetLowering::lowerMSTORE(SDValue Op, SelectionDAG &DAG) const {
   auto *Store = cast<MaskedStoreSDNode>(Op);
 
   SDLoc DL(Op);
   SDValue Val = Store->getValue();
+  SDValue Mask = Store->getMask();
   MVT VT = Val.getSimpleValueType();
-  MVT ContainerVT = getContainerForFixedLengthVector(VT);
-  MVT MaskVT = MVT::getVectorVT(MVT::i1, ContainerVT.getVectorElementCount());
   MVT XLenVT = Subtarget.getXLenVT();
-  SDValue VL = DAG.getConstant(VT.getVectorNumElements(), DL, XLenVT);
+  SDValue VL;
 
-  Val = convertToScalableVector(ContainerVT, Val, DAG, Subtarget);
-  SDValue Mask =
-      convertToScalableVector(MaskVT, Store->getMask(), DAG, Subtarget);
+  MVT ContainerVT = VT;
+  if (VT.isFixedLengthVector()) {
+    ContainerVT = getContainerForFixedLengthVector(VT);
+    MVT MaskVT = MVT::getVectorVT(MVT::i1, ContainerVT.getVectorElementCount());
+
+    Val = convertToScalableVector(ContainerVT, Val, DAG, Subtarget);
+    Mask = convertToScalableVector(MaskVT, Mask, DAG, Subtarget);
+    VL = DAG.getConstant(VT.getVectorNumElements(), DL, XLenVT);
+  } else
+    VL = DAG.getRegister(RISCV::X0, XLenVT);
 
   SDValue IntID = DAG.getTargetConstant(Intrinsic::riscv_vse_mask, DL, XLenVT);
   return DAG.getMemIntrinsicNode(
       ISD::INTRINSIC_VOID, DL, DAG.getVTList(MVT::Other),
       {Store->getChain(), IntID, Val, Store->getBasePtr(), Mask, VL},
       Store->getMemoryVT(), Store->getMemOperand());
 }
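
For reference, both functions build the same operand list for the memory-intrinsic node (read off the Ops arrays above; the brace grouping below is illustrative, not an upstream structure):

    // { chain, intrinsic ID, vector operand (pass-thru for vle_mask,
    //   store value for vse_mask), base pointer, mask, VL }
    //
    // The load is emitted as INTRINSIC_W_CHAIN (value plus chain results);
    // the store as INTRINSIC_VOID (chain only), the same node kinds the
    // fixed-length-only path already used.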

llvm/lib/Target/RISCV/RISCVISelLowering.h

@@ -475,15 +475,13 @@ private:
   SDValue lowerEXTRACT_SUBVECTOR(SDValue Op, SelectionDAG &DAG) const;
   SDValue lowerVECTOR_REVERSE(SDValue Op, SelectionDAG &DAG) const;
   SDValue lowerABS(SDValue Op, SelectionDAG &DAG) const;
+  SDValue lowerMLOAD(SDValue Op, SelectionDAG &DAG) const;
+  SDValue lowerMSTORE(SDValue Op, SelectionDAG &DAG) const;
   SDValue lowerFixedLengthVectorFCOPYSIGNToRVV(SDValue Op,
                                                SelectionDAG &DAG) const;
   SDValue lowerMGATHERMSCATTER(SDValue Op, SelectionDAG &DAG) const;
   SDValue lowerFixedLengthVectorLoadToRVV(SDValue Op, SelectionDAG &DAG) const;
   SDValue lowerFixedLengthVectorStoreToRVV(SDValue Op, SelectionDAG &DAG) const;
-  SDValue lowerFixedLengthVectorMaskedLoadToRVV(SDValue Op,
-                                                SelectionDAG &DAG) const;
-  SDValue lowerFixedLengthVectorMaskedStoreToRVV(SDValue Op,
-                                                 SelectionDAG &DAG) const;
   SDValue lowerFixedLengthVectorSetccToRVV(SDValue Op, SelectionDAG &DAG) const;
   SDValue lowerFixedLengthVectorLogicOpToRVV(SDValue Op, SelectionDAG &DAG,
                                              unsigned MaskOpc,

llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td

@@ -33,21 +33,6 @@ def SplatPat : ComplexPattern<vAny, 1, "selectVSplat", [splat_vector,
 def SplatPat_simm5 : ComplexPattern<vAny, 1, "selectVSplatSimm5", [splat_vector, rv32_splat_i64], [], 2>;
 def SplatPat_uimm5 : ComplexPattern<vAny, 1, "selectVSplatUimm5", [splat_vector, rv32_splat_i64], [], 2>;
 
-def masked_load :
-  PatFrag<(ops node:$ptr, node:$mask, node:$maskedoff),
-          (masked_ld node:$ptr, undef, node:$mask, node:$maskedoff), [{
-    return !cast<MaskedLoadSDNode>(N)->isExpandingLoad() &&
-           cast<MaskedLoadSDNode>(N)->getExtensionType() == ISD::NON_EXTLOAD &&
-           cast<MaskedLoadSDNode>(N)->isUnindexed();
-  }]>;
-
-def masked_store :
-  PatFrag<(ops node:$val, node:$ptr, node:$mask),
-          (masked_st node:$val, node:$ptr, undef, node:$mask), [{
-    return !cast<MaskedStoreSDNode>(N)->isTruncatingStore() &&
-           !cast<MaskedStoreSDNode>(N)->isCompressingStore() &&
-           cast<MaskedStoreSDNode>(N)->isUnindexed();
-  }]>;
-
 class SwapHelper<dag Prefix, dag A, dag B, dag Suffix, bit swap> {
   dag Value = !con(Prefix, !if(swap, B, A), !if(swap, A, B), Suffix);
 }
@@ -68,25 +53,6 @@ multiclass VPatUSLoadStoreSDNode<ValueType type,
             (store_instr reg_class:$rs2, BaseAddr:$rs1, avl, sew)>;
 }
 
-multiclass VPatUSLoadStoreSDNodeMask<ValueType type,
-                                     ValueType mask_type,
-                                     int sew,
-                                     LMULInfo vlmul,
-                                     OutPatFrag avl,
-                                     VReg reg_class>
-{
-  defvar load_instr = !cast<Instruction>("PseudoVLE"#sew#"_V_"#vlmul.MX#"_MASK");
-  defvar store_instr = !cast<Instruction>("PseudoVSE"#sew#"_V_"#vlmul.MX#"_MASK");
-  // Load
-  def : Pat<(type (masked_load BaseAddr:$rs1, (mask_type V0), type:$merge)),
-            (load_instr reg_class:$merge, BaseAddr:$rs1, (mask_type V0),
-             avl, sew)>;
-  // Store
-  def : Pat<(masked_store type:$rs2, BaseAddr:$rs1, (mask_type V0)),
-            (store_instr reg_class:$rs2, BaseAddr:$rs1, (mask_type V0),
-             avl, sew)>;
-}
-
 multiclass VPatUSLoadStoreWholeVRSDNode<ValueType type,
                                         int sew,
                                         LMULInfo vlmul,
@@ -394,9 +360,6 @@ foreach vti = !listconcat(FractionalGroupIntegerVectors,
                           FractionalGroupFloatVectors) in
   defm "" : VPatUSLoadStoreSDNode<vti.Vector, vti.SEW, vti.LMul,
                                   vti.AVL, vti.RegClass>;
-foreach vti = AllVectors in
-  defm "" : VPatUSLoadStoreSDNodeMask<vti.Vector, vti.Mask, vti.SEW, vti.LMul,
-                                      vti.AVL, vti.RegClass>;
 foreach vti = [VI8M1, VI16M1, VI32M1, VI64M1, VF16M1, VF32M1, VF64M1] in
   defm "" : VPatUSLoadStoreWholeVRSDNode<vti.Vector, vti.SEW, vti.LMul,
                                          vti.RegClass>;