[AArch64][SVE] Fix selection failures for scalable MLOAD nodes with passthru

Differential Revision: https://reviews.llvm.org/D105348
Bradley Smith 2021-07-01 16:48:24 +00:00
parent 5ffa051447
commit 5ab9000fbb
5 changed files with 65 additions and 1 deletion
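For context, here is an IR-level sketch of what the new lowering does (the function names are hypothetical, and the actual transform runs on SelectionDAG nodes in LowerMLOAD below, not on IR): a scalable masked load whose passthru is neither undef nor zero is re-emitted with an undef passthru, and the original passthru is merged back in with a select over the mask. Undef and zero passthru values are left unchanged, since the zeroing form of the SVE predicated load already produces zeros in the inactive lanes.

; Hypothetical example: a scalable masked load with a real (non-undef, non-zero)
; passthru operand, the pattern that previously failed to select.
define <vscale x 4 x i32> @mload_passthru(<vscale x 4 x i32> *%p, <vscale x 4 x i1> %mask, <vscale x 4 x i32> %passthru) {
  %v = call <vscale x 4 x i32> @llvm.masked.load.nxv4i32(<vscale x 4 x i32> *%p, i32 4, <vscale x 4 x i1> %mask, <vscale x 4 x i32> %passthru)
  ret <vscale x 4 x i32> %v
}

; The new lowering handles this as if it had been written with an undef passthru
; plus a select that merges the original passthru back into the inactive lanes.
define <vscale x 4 x i32> @mload_undef_plus_select(<vscale x 4 x i32> *%p, <vscale x 4 x i1> %mask, <vscale x 4 x i32> %passthru) {
  %l = call <vscale x 4 x i32> @llvm.masked.load.nxv4i32(<vscale x 4 x i32> *%p, i32 4, <vscale x 4 x i1> %mask, <vscale x 4 x i32> undef)
  %v = select <vscale x 4 x i1> %mask, <vscale x 4 x i32> %l, <vscale x 4 x i32> %passthru
  ret <vscale x 4 x i32> %v
}

declare <vscale x 4 x i32> @llvm.masked.load.nxv4i32(<vscale x 4 x i32>*, i32, <vscale x 4 x i1>, <vscale x 4 x i32>)

That select corresponds to the predicated merging move (e.g. mov z0.s, p0/m, z1.s) seen in the updated tests below.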


@@ -1154,6 +1154,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
setOperationAction(ISD::FP_TO_SINT, VT, Custom);
setOperationAction(ISD::MGATHER, VT, Custom);
setOperationAction(ISD::MSCATTER, VT, Custom);
setOperationAction(ISD::MLOAD, VT, Custom);
setOperationAction(ISD::MUL, VT, Custom);
setOperationAction(ISD::MULHS, VT, Custom);
setOperationAction(ISD::MULHU, VT, Custom);
@@ -1245,6 +1246,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
setOperationAction(ISD::INSERT_SUBVECTOR, VT, Custom);
setOperationAction(ISD::MGATHER, VT, Custom);
setOperationAction(ISD::MSCATTER, VT, Custom);
setOperationAction(ISD::MLOAD, VT, Custom);
setOperationAction(ISD::SPLAT_VECTOR, VT, Custom);
setOperationAction(ISD::SELECT, VT, Custom);
setOperationAction(ISD::FADD, VT, Custom);
@@ -1280,6 +1282,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
setOperationAction(ISD::CONCAT_VECTORS, VT, Custom);
setOperationAction(ISD::MGATHER, VT, Custom);
setOperationAction(ISD::MSCATTER, VT, Custom);
setOperationAction(ISD::MLOAD, VT, Custom);
}
setOperationAction(ISD::SPLAT_VECTOR, MVT::nxv8bf16, Custom);
@@ -4476,6 +4479,32 @@ SDValue AArch64TargetLowering::LowerMSCATTER(SDValue Op,
return DAG.getNode(Opcode, DL, VTs, Ops);
}

SDValue AArch64TargetLowering::LowerMLOAD(SDValue Op, SelectionDAG &DAG) const {
  SDLoc DL(Op);
  MaskedLoadSDNode *LoadNode = cast<MaskedLoadSDNode>(Op);
  assert(LoadNode && "Expected custom lowering of a masked load node");
  EVT VT = Op->getValueType(0);

  if (useSVEForFixedLengthVectorVT(VT, true))
    return LowerFixedLengthVectorMLoadToSVE(Op, DAG);

  SDValue PassThru = LoadNode->getPassThru();
  SDValue Mask = LoadNode->getMask();

  if (PassThru->isUndef() || isZerosVector(PassThru.getNode()))
    return Op;

  SDValue Load = DAG.getMaskedLoad(
      VT, DL, LoadNode->getChain(), LoadNode->getBasePtr(),
      LoadNode->getOffset(), Mask, DAG.getUNDEF(VT), LoadNode->getMemoryVT(),
      LoadNode->getMemOperand(), LoadNode->getAddressingMode(),
      LoadNode->getExtensionType());

  SDValue Result = DAG.getSelect(DL, VT, Mask, Load, PassThru);

  return DAG.getMergeValues({Result, Load.getValue(1)}, DL);
}
// Custom lower trunc store for v4i8 vectors, since it is promoted to v4i16.
static SDValue LowerTruncateVectorStore(SDLoc DL, StoreSDNode *ST,
EVT VT, EVT MemVT,
@@ -4854,7 +4883,7 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
case ISD::TRUNCATE:
return LowerTRUNCATE(Op, DAG);
case ISD::MLOAD:
return LowerFixedLengthVectorMLoadToSVE(Op, DAG);
return LowerMLOAD(Op, DAG);
case ISD::LOAD:
if (useSVEForFixedLengthVectorVT(Op.getValueType()))
return LowerFixedLengthVectorLoadToSVE(Op, DAG);


@@ -859,6 +859,8 @@ private:
SDValue LowerMGATHER(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerMSCATTER(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerMLOAD(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerINTRINSIC_WO_CHAIN(SDValue Op, SelectionDAG &DAG) const;
bool isEligibleForTailCallOptimization(


@@ -92,6 +92,15 @@ define <vscale x 8 x bfloat> @masked_load_nxv8bf16(<vscale x 8 x bfloat> *%a, <v
ret <vscale x 8 x bfloat> %load
}

define <vscale x 4 x i32> @masked_load_passthru(<vscale x 4 x i32> *%a, <vscale x 4 x i1> %mask, <vscale x 4 x i32> %passthru) nounwind {
; CHECK-LABEL: masked_load_passthru:
; CHECK-NEXT: ld1w { z1.s }, p0/z, [x0]
; CHECK-NEXT: mov z0.s, p0/m, z1.s
; CHECK-NEXT: ret
  %load = call <vscale x 4 x i32> @llvm.masked.load.nxv4i32(<vscale x 4 x i32> *%a, i32 4, <vscale x 4 x i1> %mask, <vscale x 4 x i32> %passthru)
  ret <vscale x 4 x i32> %load
}
;
; Masked Stores
;


@@ -58,6 +58,18 @@ define <vscale x 8 x i16> @masked_sload_nxv8i8(<vscale x 8 x i8> *%a, <vscale x
ret <vscale x 8 x i16> %ext
}

define <vscale x 2 x i64> @masked_sload_passthru(<vscale x 2 x i32> *%a, <vscale x 2 x i1> %mask, <vscale x 2 x i32> %passthru) {
; CHECK-LABEL: masked_sload_passthru:
; CHECK: ld1sw { [[IN:z[0-9]+]].d }, [[PG1:p[0-9]+]]/z, [x0]
; CHECK-NEXT: ptrue [[PG2:p[0-9]+]].d
; CHECK-NEXT: sxtw z0.d, [[PG2]]/m, z0.d
; CHECK-NEXT: mov z0.d, [[PG1]]/m, [[IN]].d
; CHECK-NEXT: ret
  %load = call <vscale x 2 x i32> @llvm.masked.load.nxv2i32(<vscale x 2 x i32> *%a, i32 1, <vscale x 2 x i1> %mask, <vscale x 2 x i32> %passthru)
  %ext = sext <vscale x 2 x i32> %load to <vscale x 2 x i64>
  ret <vscale x 2 x i64> %ext
}
declare <vscale x 2 x i8> @llvm.masked.load.nxv2i8(<vscale x 2 x i8>*, i32, <vscale x 2 x i1>, <vscale x 2 x i8>)
declare <vscale x 2 x i16> @llvm.masked.load.nxv2i16(<vscale x 2 x i16>*, i32, <vscale x 2 x i1>, <vscale x 2 x i16>)
declare <vscale x 2 x i32> @llvm.masked.load.nxv2i32(<vscale x 2 x i32>*, i32, <vscale x 2 x i1>, <vscale x 2 x i32>)


@@ -64,6 +64,18 @@ define <vscale x 8 x i16> @masked_zload_nxv8i8(<vscale x 8 x i8>* %src, <vscale
ret <vscale x 8 x i16> %ext
}

define <vscale x 2 x i64> @masked_zload_passthru(<vscale x 2 x i32>* %src, <vscale x 2 x i1> %mask, <vscale x 2 x i32> %passthru) {
; CHECK-LABEL: masked_zload_passthru:
; CHECK-NOT: ld1sw
; CHECK: ld1w { [[IN:z[0-9]+]].d }, [[PG:p[0-9]+]]/z, [x0]
; CHECK-NEXT: and z0.d, z0.d, #0xffffffff
; CHECK-NEXT: mov z0.d, [[PG]]/m, [[IN]].d
; CHECK-NEXT: ret
  %load = call <vscale x 2 x i32> @llvm.masked.load.nxv2i32(<vscale x 2 x i32>* %src, i32 1, <vscale x 2 x i1> %mask, <vscale x 2 x i32> %passthru)
  %ext = zext <vscale x 2 x i32> %load to <vscale x 2 x i64>
  ret <vscale x 2 x i64> %ext
}
declare <vscale x 2 x i8> @llvm.masked.load.nxv2i8(<vscale x 2 x i8>*, i32, <vscale x 2 x i1>, <vscale x 2 x i8>)
declare <vscale x 2 x i16> @llvm.masked.load.nxv2i16(<vscale x 2 x i16>*, i32, <vscale x 2 x i1>, <vscale x 2 x i16>)
declare <vscale x 2 x i32> @llvm.masked.load.nxv2i32(<vscale x 2 x i32>*, i32, <vscale x 2 x i1>, <vscale x 2 x i32>)