[SVE][CodeGen] Add a DAG combine to extend mscatter indices

This patch adds a target-specific DAG combine for mscatter that promotes indices
with element types i8 or i16 to i32 before legalisation, plus various tests with illegal types.
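
As an illustration, the first of the new tests exercises exactly this pattern: a GEP feeding a masked scatter whose offset vector has i8 elements, which SVE scatter instructions cannot consume directly:

  %ptrs = getelementptr i8, i8* %base, <vscale x 16 x i8> %offsets
  call void @llvm.masked.scatter.nxv16i8(<vscale x 16 x i8> %data, <vscale x 16 x i8*> %ptrs, i32 1, <vscale x 16 x i1> %mask)

Promoting such indices to i32 before type legalisation means the resulting scatter can be split into several legal operations, which the CHECK lines in the new test verify.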

Reviewed By: sdesmalen

Differential Revision: https://reviews.llvm.org/D90945
commit 603d40da9d (parent 1c3451c665)
Kerry McLaughlin, 2020-11-25 10:54:31 +00:00
2 changed files, 101 insertions(+), 0 deletions(-)

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

@@ -835,6 +835,8 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
  if (Subtarget->supportsAddressTopByteIgnored())
    setTargetDAGCombine(ISD::LOAD);

  setTargetDAGCombine(ISD::MSCATTER);

  setTargetDAGCombine(ISD::MUL);
  setTargetDAGCombine(ISD::SELECT);
@@ -13944,6 +13946,44 @@ static SDValue performSTORECombine(SDNode *N,
  return SDValue();
}

static SDValue performMSCATTERCombine(SDNode *N,
                                      TargetLowering::DAGCombinerInfo &DCI,
                                      SelectionDAG &DAG) {
  MaskedScatterSDNode *MSC = cast<MaskedScatterSDNode>(N);
  assert(MSC && "Can only combine scatter store nodes");

  SDLoc DL(MSC);
  SDValue Chain = MSC->getChain();
  SDValue Scale = MSC->getScale();
  SDValue Index = MSC->getIndex();
  SDValue Data = MSC->getValue();
  SDValue Mask = MSC->getMask();
  SDValue BasePtr = MSC->getBasePtr();
  ISD::MemIndexType IndexType = MSC->getIndexType();

  EVT IdxVT = Index.getValueType();

  if (DCI.isBeforeLegalize()) {
    // SVE gather/scatter requires indices of i32/i64. Promote anything smaller
    // prior to legalisation so the result can be split if required.
    if ((IdxVT.getVectorElementType() == MVT::i8) ||
        (IdxVT.getVectorElementType() == MVT::i16)) {
      EVT NewIdxVT = IdxVT.changeVectorElementType(MVT::i32);
      if (MSC->isIndexSigned())
        Index = DAG.getNode(ISD::SIGN_EXTEND, DL, NewIdxVT, Index);
      else
        Index = DAG.getNode(ISD::ZERO_EXTEND, DL, NewIdxVT, Index);
      SDValue Ops[] = { Chain, Data, Mask, BasePtr, Index, Scale };
      return DAG.getMaskedScatter(DAG.getVTList(MVT::Other),
                                  MSC->getMemoryVT(), DL, Ops,
                                  MSC->getMemOperand(), IndexType,
                                  MSC->isTruncatingStore());
    }
  }

  return SDValue();
}

/// Target-specific DAG combine function for NEON load/store intrinsics
/// to merge base address updates.
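
In IR terms, the rewritten node behaves as if the offsets had been widened before the pointers were formed; a hand-written sketch of the signed case (illustrative only, not compiler output):

  %wide = sext <vscale x 16 x i8> %offsets to <vscale x 16 x i32>
  %ptrs = getelementptr i8, i8* %base, <vscale x 16 x i32> %wide

When MSC->isIndexSigned() is false, the combine emits ISD::ZERO_EXTEND instead, preserving the unsigned interpretation of the original offsets.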
@@ -15136,6 +15176,8 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
    break;
  case ISD::STORE:
    return performSTORECombine(N, DCI, DAG, Subtarget);
  case ISD::MSCATTER:
    return performMSCATTERCombine(N, DCI, DAG);
  case AArch64ISD::BRCOND:
    return performBRCONDCombine(N, DCI, DAG);
  case AArch64ISD::TBNZ:

llvm/test/CodeGen/AArch64/ (new test file)

@@ -0,0 +1,59 @@
; RUN: llc -mtriple=aarch64--linux-gnu -mattr=+sve < %s | FileCheck %s

; Tests that exercise various type legalisation scenarios for ISD::MSCATTER.

; Code generation for the scenario where the offset vector type is illegal.
define void @masked_scatter_nxv16i8(<vscale x 16 x i8> %data, i8* %base, <vscale x 16 x i8> %offsets, <vscale x 16 x i1> %mask) {
; CHECK-LABEL: masked_scatter_nxv16i8:
; CHECK-DAG: st1b { {{z[0-9]+}}.s }, {{p[0-9]+}}, [x0, {{z[0-9]+}}.s, sxtw]
; CHECK-DAG: st1b { {{z[0-9]+}}.s }, {{p[0-9]+}}, [x0, {{z[0-9]+}}.s, sxtw]
; CHECK-DAG: st1b { {{z[0-9]+}}.s }, {{p[0-9]+}}, [x0, {{z[0-9]+}}.s, sxtw]
; CHECK-DAG: st1b { {{z[0-9]+}}.s }, {{p[0-9]+}}, [x0, {{z[0-9]+}}.s, sxtw]
; CHECK: ret
  %ptrs = getelementptr i8, i8* %base, <vscale x 16 x i8> %offsets
  call void @llvm.masked.scatter.nxv16i8(<vscale x 16 x i8> %data, <vscale x 16 x i8*> %ptrs, i32 1, <vscale x 16 x i1> %mask)
  ret void
}

define void @masked_scatter_nxv8i16(<vscale x 8 x i16> %data, i16* %base, <vscale x 8 x i16> %offsets, <vscale x 8 x i1> %mask) {
; CHECK-LABEL: masked_scatter_nxv8i16:
; CHECK-DAG: st1h { {{z[0-9]+}}.s }, {{p[0-9]+}}, [x0, {{z[0-9]+}}.s, sxtw #1]
; CHECK-DAG: st1h { {{z[0-9]+}}.s }, {{p[0-9]+}}, [x0, {{z[0-9]+}}.s, sxtw #1]
; CHECK: ret
  %ptrs = getelementptr i16, i16* %base, <vscale x 8 x i16> %offsets
  call void @llvm.masked.scatter.nxv8i16(<vscale x 8 x i16> %data, <vscale x 8 x i16*> %ptrs, i32 1, <vscale x 8 x i1> %mask)
  ret void
}

define void @masked_scatter_nxv8f32(<vscale x 8 x float> %data, float* %base, <vscale x 8 x i32> %indexes, <vscale x 8 x i1> %masks) {
; CHECK-LABEL: masked_scatter_nxv8f32:
; CHECK-DAG: st1w { z0.s }, {{p[0-9]+}}, [x0, {{z[0-9]+}}.s, uxtw #2]
; CHECK-DAG: st1w { z1.s }, {{p[0-9]+}}, [x0, {{z[0-9]+}}.s, uxtw #2]
  %ext = zext <vscale x 8 x i32> %indexes to <vscale x 8 x i64>
  %ptrs = getelementptr float, float* %base, <vscale x 8 x i64> %ext
  call void @llvm.masked.scatter.nxv8f32(<vscale x 8 x float> %data, <vscale x 8 x float*> %ptrs, i32 0, <vscale x 8 x i1> %masks)
  ret void
}

; Code generation for the worst-case scenario, where all vector types are illegal.
define void @masked_scatter_nxv32i32(<vscale x 32 x i32> %data, i32* %base, <vscale x 32 x i32> %offsets, <vscale x 32 x i1> %mask) {
; CHECK-LABEL: masked_scatter_nxv32i32:
; CHECK-NOT: unpkhi
; CHECK-DAG: st1w { z0.s }, {{p[0-9]+}}, [x0, {{z[0-9]+}}.s, sxtw #2]
; CHECK-DAG: st1w { z1.s }, {{p[0-9]+}}, [x0, {{z[0-9]+}}.s, sxtw #2]
; CHECK-DAG: st1w { z2.s }, {{p[0-9]+}}, [x0, {{z[0-9]+}}.s, sxtw #2]
; CHECK-DAG: st1w { z3.s }, {{p[0-9]+}}, [x0, {{z[0-9]+}}.s, sxtw #2]
; CHECK-DAG: st1w { z4.s }, {{p[0-9]+}}, [x0, {{z[0-9]+}}.s, sxtw #2]
; CHECK-DAG: st1w { z5.s }, {{p[0-9]+}}, [x0, {{z[0-9]+}}.s, sxtw #2]
; CHECK-DAG: st1w { z6.s }, {{p[0-9]+}}, [x0, {{z[0-9]+}}.s, sxtw #2]
; CHECK-DAG: st1w { z7.s }, {{p[0-9]+}}, [x0, {{z[0-9]+}}.s, sxtw #2]
; CHECK: ret
  %ptrs = getelementptr i32, i32* %base, <vscale x 32 x i32> %offsets
  call void @llvm.masked.scatter.nxv32i32(<vscale x 32 x i32> %data, <vscale x 32 x i32*> %ptrs, i32 4, <vscale x 32 x i1> %mask)
  ret void
}

declare void @llvm.masked.scatter.nxv16i8(<vscale x 16 x i8>, <vscale x 16 x i8*>, i32, <vscale x 16 x i1>)
declare void @llvm.masked.scatter.nxv8i16(<vscale x 8 x i16>, <vscale x 8 x i16*>, i32, <vscale x 8 x i1>)
declare void @llvm.masked.scatter.nxv8f32(<vscale x 8 x float>, <vscale x 8 x float*>, i32, <vscale x 8 x i1>)
declare void @llvm.masked.scatter.nxv32i32(<vscale x 32 x i32>, <vscale x 32 x i32*>, i32, <vscale x 32 x i1>)