From 421f1b7294ef4dbe8f02d83fcd3b9eb604465bf5 Mon Sep 17 00:00:00 2001 From: Cameron McInally Date: Wed, 14 Oct 2020 09:11:58 -0500 Subject: [PATCH] [SVE] Lower fixed length VECREDUCE_FADD operation Differential Revision: https://reviews.llvm.org/D89263 --- .../Target/AArch64/AArch64ISelLowering.cpp | 9 + .../AArch64/sve-fixed-length-fp-reduce.ll | 261 ++++++++++++++++++ 2 files changed, 270 insertions(+) diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp index ef90de7a1003..eff29d92199d 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -1125,6 +1125,10 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM, setOperationAction(ISD::VECREDUCE_OR, VT, Custom); setOperationAction(ISD::VECREDUCE_XOR, VT, Custom); } + + // Use SVE for vectors with more than 2 elements. + for (auto VT : {MVT::v4f16, MVT::v8f16, MVT::v4f32}) + setOperationAction(ISD::VECREDUCE_FADD, VT, Custom); } } @@ -1261,6 +1265,7 @@ void AArch64TargetLowering::addTypeForFixedLengthSVE(MVT VT) { setOperationAction(ISD::UMIN, VT, Custom); setOperationAction(ISD::VECREDUCE_ADD, VT, Custom); setOperationAction(ISD::VECREDUCE_AND, VT, Custom); + setOperationAction(ISD::VECREDUCE_FADD, VT, Custom); setOperationAction(ISD::VECREDUCE_FMAX, VT, Custom); setOperationAction(ISD::VECREDUCE_FMIN, VT, Custom); setOperationAction(ISD::VECREDUCE_OR, VT, Custom); @@ -3963,6 +3968,7 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op, case ISD::VECREDUCE_SMIN: case ISD::VECREDUCE_UMAX: case ISD::VECREDUCE_UMIN: + case ISD::VECREDUCE_FADD: case ISD::VECREDUCE_FMAX: case ISD::VECREDUCE_FMIN: return LowerVECREDUCE(Op, DAG); @@ -9749,6 +9755,7 @@ SDValue AArch64TargetLowering::LowerVECREDUCE(SDValue Op, bool OverrideNEON = Op.getOpcode() == ISD::VECREDUCE_AND || Op.getOpcode() == ISD::VECREDUCE_OR || Op.getOpcode() == ISD::VECREDUCE_XOR || + Op.getOpcode() == ISD::VECREDUCE_FADD || (Op.getOpcode() != ISD::VECREDUCE_ADD && SrcVT.getVectorElementType() == MVT::i64); if (useSVEForFixedLengthVectorVT(SrcVT, OverrideNEON)) { @@ -9769,6 +9776,8 @@ SDValue AArch64TargetLowering::LowerVECREDUCE(SDValue Op, return LowerFixedLengthReductionToSVE(AArch64ISD::UMINV_PRED, Op, DAG); case ISD::VECREDUCE_XOR: return LowerFixedLengthReductionToSVE(AArch64ISD::EORV_PRED, Op, DAG); + case ISD::VECREDUCE_FADD: + return LowerFixedLengthReductionToSVE(AArch64ISD::FADDV_PRED, Op, DAG); case ISD::VECREDUCE_FMAX: return LowerFixedLengthReductionToSVE(AArch64ISD::FMAXNMV_PRED, Op, DAG); case ISD::VECREDUCE_FMIN: diff --git a/llvm/test/CodeGen/AArch64/sve-fixed-length-fp-reduce.ll b/llvm/test/CodeGen/AArch64/sve-fixed-length-fp-reduce.ll index 6991c0ad3a68..92b04a068861 100644 --- a/llvm/test/CodeGen/AArch64/sve-fixed-length-fp-reduce.ll +++ b/llvm/test/CodeGen/AArch64/sve-fixed-length-fp-reduce.ll @@ -20,6 +20,246 @@ target triple = "aarch64-unknown-linux-gnu" ; Don't use SVE when its registers are no bigger than NEON. ; NO_SVE-NOT: ptrue +; +; FADDV +; + +; No single instruction NEON support for 4 element vectors. +define half @faddv_v4f16(half %start, <4 x half> %a) #0 { +; CHECK-LABEL: faddv_v4f16: +; CHECK: ptrue [[PG:p[0-9]+]].h, vl4 +; CHECK-NEXT: faddv [[RDX:h[0-9]+]], [[PG]], z1.h +; CHECK-NEXT: fadd h0, h0, [[RDX]] +; CHECK-NEXT: ret + %res = call fast half @llvm.vector.reduce.fadd.v4f16(half %start, <4 x half> %a) + ret half %res +} + +; No single instruction NEON support for 8 element vectors. +define half @faddv_v8f16(half %start, <8 x half> %a) #0 { +; CHECK-LABEL: faddv_v8f16: +; CHECK: ptrue [[PG:p[0-9]+]].h, vl8 +; CHECK-NEXT: faddv [[RDX:h[0-9]+]], [[PG]], z1.h +; CHECK-NEXT: fadd h0, h0, [[RDX]] +; CHECK-NEXT: ret + %res = call fast half @llvm.vector.reduce.fadd.v8f16(half %start, <8 x half> %a) + ret half %res +} + +define half @faddv_v16f16(half %start, <16 x half>* %a) #0 { +; CHECK-LABEL: faddv_v16f16: +; CHECK: ptrue [[PG:p[0-9]+]].h, vl16 +; CHECK-NEXT: ld1h { [[OP:z[0-9]+]].h }, [[PG]]/z, [x0] +; CHECK-NEXT: faddv [[RDX:h[0-9]+]], [[PG]], [[OP]].h +; CHECK-NEXT: fadd h0, h0, [[RDX]] +; CHECK-NEXT: ret + %op = load <16 x half>, <16 x half>* %a + %res = call fast half @llvm.vector.reduce.fadd.v16f16(half %start, <16 x half> %op) + ret half %res +} + +define half @faddv_v32f16(half %start, <32 x half>* %a) #0 { +; CHECK-LABEL: faddv_v32f16: +; VBITS_GE_512: ptrue [[PG:p[0-9]+]].h, vl32 +; VBITS_GE_512-NEXT: ld1h { [[OP:z[0-9]+]].h }, [[PG]]/z, [x0] +; VBITS_GE_512-NEXT: faddv [[RDX:h[0-9]+]], [[PG]], [[OP]].h +; VBITS_GE_512-NEXT: fadd h0, h0, [[RDX]] +; VBITS_GE_512-NEXT: ret + +; Ensure sensible type legalisation. +; VBITS_EQ_256-DAG: ptrue [[PG:p[0-9]+]].h, vl16 +; VBITS_EQ_256-DAG: add x[[A_HI:[0-9]+]], x0, #32 +; VBITS_EQ_256-DAG: ld1h { [[LO:z[0-9]+]].h }, [[PG]]/z, [x0] +; VBITS_EQ_256-DAG: ld1h { [[HI:z[0-9]+]].h }, [[PG]]/z, [x[[A_HI]]] +; VBITS_EQ_256-DAG: fadd [[ADD:z[0-9]+]].h, [[PG]]/m, [[LO]].h, [[HI]].h +; VBITS_EQ_256-DAG: faddv h1, [[PG]], [[ADD]].h +; VBITS_EQ_256-DAG: fadd h0, h0, [[RDX]] +; VBITS_EQ_256-NEXT: ret + %op = load <32 x half>, <32 x half>* %a + %res = call fast half @llvm.vector.reduce.fadd.v32f16(half %start, <32 x half> %op) + ret half %res +} + +define half @faddv_v64f16(half %start, <64 x half>* %a) #0 { +; CHECK-LABEL: faddv_v64f16: +; VBITS_GE_1024: ptrue [[PG:p[0-9]+]].h, vl64 +; VBITS_GE_1024-NEXT: ld1h { [[OP:z[0-9]+]].h }, [[PG]]/z, [x0] +; VBITS_GE_1024-NEXT: faddv [[RDX:h[0-9]+]], [[PG]], [[OP]].h +; VBITS_GE_1024-NEXT: fadd h0, h0, [[RDX]] +; VBITS_GE_1024-NEXT: ret + %op = load <64 x half>, <64 x half>* %a + %res = call fast half @llvm.vector.reduce.fadd.v64f16(half %start, <64 x half> %op) + ret half %res +} + +define half @faddv_v128f16(half %start, <128 x half>* %a) #0 { +; CHECK-LABEL: faddv_v128f16: +; VBITS_GE_2048: ptrue [[PG:p[0-9]+]].h, vl128 +; VBITS_GE_2048-NEXT: ld1h { [[OP:z[0-9]+]].h }, [[PG]]/z, [x0] +; VBITS_GE_2048-NEXT: faddv [[RDX:h[0-9]+]], [[PG]], [[OP]].h +; VBITS_GE_2048-NEXT: fadd h0, h0, [[RDX]] +; VBITS_GE_2048-NEXT: ret + %op = load <128 x half>, <128 x half>* %a + %res = call fast half @llvm.vector.reduce.fadd.v128f16(half %start, <128 x half> %op) + ret half %res +} + +; Don't use SVE for 2 element vectors. +define float @faddv_v2f32(float %start, <2 x float> %a) #0 { +; CHECK-LABEL: faddv_v2f32: +; CHECK: faddp s1, v1.2s +; CHECK-NEXT: fadd s0, s0, s1 +; CHECK-NEXT: ret + %res = call fast float @llvm.vector.reduce.fadd.v2f32(float %start, <2 x float> %a) + ret float %res +} + +; No single instruction NEON support for 4 element vectors. +define float @faddv_v4f32(float %start, <4 x float> %a) #0 { +; CHECK-LABEL: faddv_v4f32: +; CHECK: ptrue [[PG:p[0-9]+]].s, vl4 +; CHECK-NEXT: faddv [[RDX:s[0-9]+]], [[PG]], z1.s +; CHECK-NEXT: fadd s0, s0, [[RDX]] +; CHECK-NEXT: ret + %res = call fast float @llvm.vector.reduce.fadd.v4f32(float %start, <4 x float> %a) + ret float %res +} + +define float @faddv_v8f32(float %start, <8 x float>* %a) #0 { +; CHECK-LABEL: faddv_v8f32: +; CHECK: ptrue [[PG:p[0-9]+]].s, vl8 +; CHECK-NEXT: ld1w { [[OP:z[0-9]+]].s }, [[PG]]/z, [x0] +; CHECK-NEXT: faddv [[RDX:s[0-9]+]], [[PG]], [[OP]].s +; CHECK-NEXT: fadd s0, s0, [[RDX]] +; CHECK-NEXT: ret + %op = load <8 x float>, <8 x float>* %a + %res = call fast float @llvm.vector.reduce.fadd.v8f32(float %start, <8 x float> %op) + ret float %res +} + +define float @faddv_v16f32(float %start, <16 x float>* %a) #0 { +; CHECK-LABEL: faddv_v16f32: +; VBITS_GE_512: ptrue [[PG:p[0-9]+]].s, vl16 +; VBITS_GE_512-NEXT: ld1w { [[OP:z[0-9]+]].s }, [[PG]]/z, [x0] +; VBITS_GE_512-NEXT: faddv [[RDX:s[0-9]+]], [[PG]], [[OP]].s +; VBITS_GE_512-NEXT: fadd s0, s0, [[RDX]] +; VBITS_GE_512-NEXT: ret + +; Ensure sensible type legalisation. +; VBITS_EQ_256-DAG: ptrue [[PG:p[0-9]+]].s, vl8 +; VBITS_EQ_256-DAG: add x[[A_LO:[0-9]+]], x0, #32 +; VBITS_EQ_256-DAG: ld1w { [[LO:z[0-9]+]].s }, [[PG]]/z, [x0] +; VBITS_EQ_256-DAG: ld1w { [[HI:z[0-9]+]].s }, [[PG]]/z, [x[[A_LO]]] +; VBITS_EQ_256-DAG: fadd [[ADD:z[0-9]+]].s, [[PG]]/m, [[LO]].s, [[HI]].s +; VBITS_EQ_256-DAG: faddv [[RDX:s[0-9]+]], [[PG]], [[ADD]].s +; VBITS_EQ_256-DAG: fadd s0, s0, [[RDX]] +; VBITS_EQ_256-NEXT: ret + %op = load <16 x float>, <16 x float>* %a + %res = call fast float @llvm.vector.reduce.fadd.v16f32(float %start, <16 x float> %op) + ret float %res +} + +define float @faddv_v32f32(float %start, <32 x float>* %a) #0 { +; CHECK-LABEL: faddv_v32f32: +; VBITS_GE_1024: ptrue [[PG:p[0-9]+]].s, vl32 +; VBITS_GE_1024-NEXT: ld1w { [[OP:z[0-9]+]].s }, [[PG]]/z, [x0] +; VBITS_GE_1024-NEXT: faddv [[RDX:s[0-9]+]], [[PG]], [[OP]].s +; VBITS_GE_1024-NEXT: fadd s0, s0, [[RDX]] +; VBITS_GE_1024-NEXT: ret + %op = load <32 x float>, <32 x float>* %a + %res = call fast float @llvm.vector.reduce.fadd.v32f32(float %start, <32 x float> %op) + ret float %res +} + +define float @faddv_v64f32(float %start, <64 x float>* %a) #0 { +; CHECK-LABEL: faddv_v64f32: +; VBITS_GE_2048: ptrue [[PG:p[0-9]+]].s, vl64 +; VBITS_GE_2048-NEXT: ld1w { [[OP:z[0-9]+]].s }, [[PG]]/z, [x0] +; VBITS_GE_2048-NEXT: faddv [[RDX:s[0-9]+]], [[PG]], [[OP]].s +; VBITS_GE_2048-NEXT: fadd s0, s0, [[RDX]] +; VBITS_GE_2048-NEXT: ret + %op = load <64 x float>, <64 x float>* %a + %res = call fast float @llvm.vector.reduce.fadd.v64f32(float %start, <64 x float> %op) + ret float %res +} + +; Don't use SVE for 1 element vectors. +define double @faddv_v1f64(double %start, <1 x double> %a) #0 { +; CHECK-LABEL: faddv_v1f64: +; CHECK: fadd d0, d0, d1 +; CHECK-NEXT: ret + %res = call fast double @llvm.vector.reduce.fadd.v1f64(double %start, <1 x double> %a) + ret double %res +} + +; Don't use SVE for 2 element vectors. +define double @faddv_v2f64(double %start, <2 x double> %a) #0 { +; CHECK-LABEL: faddv_v2f64: +; CHECK: faddp d1, v1.2d +; CHECK-NEXT: fadd d0, d0, d1 +; CHECK-NEXT: ret + %res = call fast double @llvm.vector.reduce.fadd.v2f64(double %start, <2 x double> %a) + ret double %res +} + +define double @faddv_v4f64(double %start, <4 x double>* %a) #0 { +; CHECK-LABEL: faddv_v4f64: +; CHECK: ptrue [[PG:p[0-9]+]].d, vl4 +; CHECK-NEXT: ld1d { [[OP:z[0-9]+]].d }, [[PG]]/z, [x0] +; CHECK-NEXT: faddv [[RDX:d[0-9]+]], [[PG]], [[OP]].d +; CHECK-NEXT: fadd d0, d0, [[RDX]] +; CHECK-NEXT: ret + %op = load <4 x double>, <4 x double>* %a + %res = call fast double @llvm.vector.reduce.fadd.v4f64(double %start, <4 x double> %op) + ret double %res +} + +define double @faddv_v8f64(double %start, <8 x double>* %a) #0 { +; CHECK-LABEL: faddv_v8f64: +; VBITS_GE_512: ptrue [[PG:p[0-9]+]].d, vl8 +; VBITS_GE_512-NEXT: ld1d { [[OP:z[0-9]+]].d }, [[PG]]/z, [x0] +; VBITS_GE_512-NEXT: faddv [[RDX:d[0-9]+]], [[PG]], [[OP]].d +; VBITS_GE_512-NEXT: fadd d0, d0, [[RDX]] +; VBITS_GE_512-NEXT: ret + +; Ensure sensible type legalisation. +; VBITS_EQ_256-DAG: ptrue [[PG:p[0-9]+]].d, vl4 +; VBITS_EQ_256-DAG: add x[[A_LO:[0-9]+]], x0, #32 +; VBITS_EQ_256-DAG: ld1d { [[LO:z[0-9]+]].d }, [[PG]]/z, [x0] +; VBITS_EQ_256-DAG: ld1d { [[HI:z[0-9]+]].d }, [[PG]]/z, [x[[A_LO]]] +; VBITS_EQ_256-DAG: fadd [[ADD:z[0-9]+]].d, [[PG]]/m, [[LO]].d, [[HI]].d +; VBITS_EQ_256-DAG: faddv [[RDX:d[0-9]+]], [[PG]], [[ADD]].d +; VBITS_EQ_256-DAG: fadd d0, d0, [[RDX]] +; VBITS_EQ_256-NEXT: ret + %op = load <8 x double>, <8 x double>* %a + %res = call fast double @llvm.vector.reduce.fadd.v8f64(double %start, <8 x double> %op) + ret double %res +} + +define double @faddv_v16f64(double %start, <16 x double>* %a) #0 { +; CHECK-LABEL: faddv_v16f64: +; VBITS_GE_1024: ptrue [[PG:p[0-9]+]].d, vl16 +; VBITS_GE_1024-NEXT: ld1d { [[OP:z[0-9]+]].d }, [[PG]]/z, [x0] +; VBITS_GE_1024-NEXT: faddv [[RDX:d[0-9]+]], [[PG]], [[OP]].d +; VBITS_GE_1024-NEXT: fadd d0, d0, [[RDX]] +; VBITS_GE_1024-NEXT: ret + %op = load <16 x double>, <16 x double>* %a + %res = call fast double @llvm.vector.reduce.fadd.v16f64(double %start, <16 x double> %op) + ret double %res +} + +define double @faddv_v32f64(double %start, <32 x double>* %a) #0 { +; CHECK-LABEL: faddv_v32f64: +; VBITS_GE_2048: ptrue [[PG:p[0-9]+]].d, vl32 +; VBITS_GE_2048-NEXT: ld1d { [[OP:z[0-9]+]].d }, [[PG]]/z, [x0] +; VBITS_GE_2048-NEXT: faddv [[RDX:d[0-9]+]], [[PG]], [[OP]].d +; VBITS_GE_2048-NEXT: fadd d0, d0, [[RDX]] +; VBITS_GE_2048-NEXT: ret + %op = load <32 x double>, <32 x double>* %a + %res = call fast double @llvm.vector.reduce.fadd.v32f64(double %start, <32 x double> %op) + ret double %res +} + ; ; FMAXV ; @@ -456,6 +696,27 @@ define double @fminv_v32f64(<32 x double>* %a) #0 { attributes #0 = { "target-features"="+sve" } +declare half @llvm.vector.reduce.fadd.v4f16(half, <4 x half>) +declare half @llvm.vector.reduce.fadd.v8f16(half, <8 x half>) +declare half @llvm.vector.reduce.fadd.v16f16(half, <16 x half>) +declare half @llvm.vector.reduce.fadd.v32f16(half, <32 x half>) +declare half @llvm.vector.reduce.fadd.v64f16(half, <64 x half>) +declare half @llvm.vector.reduce.fadd.v128f16(half, <128 x half>) + +declare float @llvm.vector.reduce.fadd.v2f32(float, <2 x float>) +declare float @llvm.vector.reduce.fadd.v4f32(float, <4 x float>) +declare float @llvm.vector.reduce.fadd.v8f32(float, <8 x float>) +declare float @llvm.vector.reduce.fadd.v16f32(float, <16 x float>) +declare float @llvm.vector.reduce.fadd.v32f32(float, <32 x float>) +declare float @llvm.vector.reduce.fadd.v64f32(float, <64 x float>) + +declare double @llvm.vector.reduce.fadd.v1f64(double, <1 x double>) +declare double @llvm.vector.reduce.fadd.v2f64(double, <2 x double>) +declare double @llvm.vector.reduce.fadd.v4f64(double, <4 x double>) +declare double @llvm.vector.reduce.fadd.v8f64(double, <8 x double>) +declare double @llvm.vector.reduce.fadd.v16f64(double, <16 x double>) +declare double @llvm.vector.reduce.fadd.v32f64(double, <32 x double>) + declare half @llvm.vector.reduce.fmax.v4f16(<4 x half>) declare half @llvm.vector.reduce.fmax.v8f16(<8 x half>) declare half @llvm.vector.reduce.fmax.v16f16(<16 x half>)