From 9fc54826e0fffbbd491c4782920978cf4575da28 Mon Sep 17 00:00:00 2001 From: Evandro Menezes Date: Mon, 14 Nov 2016 23:29:01 +0000 Subject: [PATCH] [AArch64] Compute the Newton series for reciprocals natively Implement the Newton series for square root, its reciprocal and reciprocal natively using the specialized instructions in AArch64 to perform each series iteration. Differential revision: https://reviews.llvm.org/D26518 llvm-svn: 286907 --- .../Target/AArch64/AArch64ISelLowering.cpp | 51 +++++++++- llvm/lib/Target/AArch64/AArch64ISelLowering.h | 6 +- llvm/lib/Target/AArch64/AArch64InstrInfo.td | 24 +++++ llvm/test/CodeGen/AArch64/recp-fastmath.ll | 34 +++---- llvm/test/CodeGen/AArch64/sqrt-fastmath.ll | 93 +++++++++++-------- 5 files changed, 149 insertions(+), 59 deletions(-) diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp index 803481e47718..403021e87d36 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -959,8 +959,10 @@ const char *AArch64TargetLowering::getTargetNodeName(unsigned Opcode) const { case AArch64ISD::ST4LANEpost: return "AArch64ISD::ST4LANEpost"; case AArch64ISD::SMULL: return "AArch64ISD::SMULL"; case AArch64ISD::UMULL: return "AArch64ISD::UMULL"; - case AArch64ISD::FRSQRTE: return "AArch64ISD::FRSQRTE"; case AArch64ISD::FRECPE: return "AArch64ISD::FRECPE"; + case AArch64ISD::FRECPS: return "AArch64ISD::FRECPS"; + case AArch64ISD::FRSQRTE: return "AArch64ISD::FRSQRTE"; + case AArch64ISD::FRSQRTS: return "AArch64ISD::FRSQRTS"; } return nullptr; } @@ -4653,7 +4655,34 @@ SDValue AArch64TargetLowering::getSqrtEstimate(SDValue Operand, (Enabled == ReciprocalEstimate::Unspecified && Subtarget->useRSqrt())) if (SDValue Estimate = getEstimate(Subtarget, AArch64ISD::FRSQRTE, Operand, DAG, ExtraSteps)) { - UseOneConst = true; + SDLoc DL(Operand); + EVT VT = Operand.getValueType(); + + SDNodeFlags Flags; + Flags.setUnsafeAlgebra(true); + + // Newton reciprocal square root iteration: E * 0.5 * (3 - X * E^2) + // AArch64 reciprocal square root iteration instruction: 0.5 * (3 - M * N) + for (int i = ExtraSteps; i > 0; --i) { + SDValue Step = DAG.getNode(ISD::FMUL, DL, VT, Estimate, Estimate, + &Flags); + Step = DAG.getNode(AArch64ISD::FRSQRTS, DL, VT, Operand, Step, &Flags); + Estimate = DAG.getNode(ISD::FMUL, DL, VT, Estimate, Step, &Flags); + } + + if (!Reciprocal) { + EVT CCVT = getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), + VT); + SDValue FPZero = DAG.getConstantFP(0.0, DL, VT); + SDValue Eq = DAG.getSetCC(DL, CCVT, Operand, FPZero, ISD::SETEQ); + + Estimate = DAG.getNode(ISD::FMUL, DL, VT, Operand, Estimate, &Flags); + // Correct the result if the operand is 0.0. + Estimate = DAG.getNode(VT.isVector() ? ISD::VSELECT : ISD::SELECT, DL, + VT, Eq, Operand, Estimate); + } + + ExtraSteps = 0; return Estimate; } @@ -4665,8 +4694,24 @@ SDValue AArch64TargetLowering::getRecipEstimate(SDValue Operand, int &ExtraSteps) const { if (Enabled == ReciprocalEstimate::Enabled) if (SDValue Estimate = getEstimate(Subtarget, AArch64ISD::FRECPE, Operand, - DAG, ExtraSteps)) + DAG, ExtraSteps)) { + SDLoc DL(Operand); + EVT VT = Operand.getValueType(); + + SDNodeFlags Flags; + Flags.setUnsafeAlgebra(true); + + // Newton reciprocal iteration: E * (2 - X * E) + // AArch64 reciprocal iteration instruction: (2 - M * N) + for (int i = ExtraSteps; i > 0; --i) { + SDValue Step = DAG.getNode(AArch64ISD::FRECPS, DL, VT, Operand, + Estimate, &Flags); + Estimate = DAG.getNode(ISD::FMUL, DL, VT, Estimate, Step, &Flags); + } + + ExtraSteps = 0; return Estimate; + } return SDValue(); } diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h index 7b317d6ff5cc..7867d7c2b427 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h @@ -187,9 +187,9 @@ enum NodeType : unsigned { SMULL, UMULL, - // Reciprocal estimates. - FRECPE, - FRSQRTE, + // Reciprocal estimates and steps. + FRECPE, FRECPS, + FRSQRTE, FRSQRTS, // NEON Load/Store with post-increment base updates LD2post = ISD::FIRST_TARGET_MEMORY_OPCODE, diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.td b/llvm/lib/Target/AArch64/AArch64InstrInfo.td index e8386a65325e..3bed50016b40 100644 --- a/llvm/lib/Target/AArch64/AArch64InstrInfo.td +++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.td @@ -287,7 +287,9 @@ def AArch64smull : SDNode<"AArch64ISD::SMULL", SDT_AArch64mull>; def AArch64umull : SDNode<"AArch64ISD::UMULL", SDT_AArch64mull>; def AArch64frecpe : SDNode<"AArch64ISD::FRECPE", SDTFPUnaryOp>; +def AArch64frecps : SDNode<"AArch64ISD::FRECPS", SDTFPBinOp>; def AArch64frsqrte : SDNode<"AArch64ISD::FRSQRTE", SDTFPUnaryOp>; +def AArch64frsqrts : SDNode<"AArch64ISD::FRSQRTS", SDTFPBinOp>; def AArch64saddv : SDNode<"AArch64ISD::SADDV", SDT_AArch64UnaryVec>; def AArch64uaddv : SDNode<"AArch64ISD::UADDV", SDT_AArch64UnaryVec>; @@ -3422,6 +3424,17 @@ def : Pat<(v1f64 (AArch64frecpe (v1f64 FPR64:$Rn))), def : Pat<(v2f64 (AArch64frecpe (v2f64 FPR128:$Rn))), (FRECPEv2f64 FPR128:$Rn)>; +def : Pat<(f32 (AArch64frecps (f32 FPR32:$Rn), (f32 FPR32:$Rm))), + (FRECPS32 FPR32:$Rn, FPR32:$Rm)>; +def : Pat<(v2f32 (AArch64frecps (v2f32 V64:$Rn), (v2f32 V64:$Rm))), + (FRECPSv2f32 V64:$Rn, V64:$Rm)>; +def : Pat<(v4f32 (AArch64frecps (v4f32 FPR128:$Rn), (v4f32 FPR128:$Rm))), + (FRECPSv4f32 FPR128:$Rn, FPR128:$Rm)>; +def : Pat<(f64 (AArch64frecps (f64 FPR64:$Rn), (f64 FPR64:$Rm))), + (FRECPS64 FPR64:$Rn, FPR64:$Rm)>; +def : Pat<(v2f64 (AArch64frecps (v2f64 FPR128:$Rn), (v2f64 FPR128:$Rm))), + (FRECPSv2f64 FPR128:$Rn, FPR128:$Rm)>; + def : Pat<(f32 (int_aarch64_neon_frecpx (f32 FPR32:$Rn))), (FRECPXv1i32 FPR32:$Rn)>; def : Pat<(f64 (int_aarch64_neon_frecpx (f64 FPR64:$Rn))), @@ -3447,6 +3460,17 @@ def : Pat<(v1f64 (AArch64frsqrte (v1f64 FPR64:$Rn))), def : Pat<(v2f64 (AArch64frsqrte (v2f64 FPR128:$Rn))), (FRSQRTEv2f64 FPR128:$Rn)>; +def : Pat<(f32 (AArch64frsqrts (f32 FPR32:$Rn), (f32 FPR32:$Rm))), + (FRSQRTS32 FPR32:$Rn, FPR32:$Rm)>; +def : Pat<(v2f32 (AArch64frsqrts (v2f32 V64:$Rn), (v2f32 V64:$Rm))), + (FRSQRTSv2f32 V64:$Rn, V64:$Rm)>; +def : Pat<(v4f32 (AArch64frsqrts (v4f32 FPR128:$Rn), (v4f32 FPR128:$Rm))), + (FRSQRTSv4f32 FPR128:$Rn, FPR128:$Rm)>; +def : Pat<(f64 (AArch64frsqrts (f64 FPR64:$Rn), (f64 FPR64:$Rm))), + (FRSQRTS64 FPR64:$Rn, FPR64:$Rm)>; +def : Pat<(v2f64 (AArch64frsqrts (v2f64 FPR128:$Rn), (v2f64 FPR128:$Rm))), + (FRSQRTSv2f64 FPR128:$Rn, FPR128:$Rm)>; + // If an integer is about to be converted to a floating point value, // just load it on the floating point unit. // Here are the patterns for 8 and 16-bits to float. diff --git a/llvm/test/CodeGen/AArch64/recp-fastmath.ll b/llvm/test/CodeGen/AArch64/recp-fastmath.ll index 280ef75b8918..38e0fb360e49 100644 --- a/llvm/test/CodeGen/AArch64/recp-fastmath.ll +++ b/llvm/test/CodeGen/AArch64/recp-fastmath.ll @@ -16,8 +16,8 @@ define float @frecp1(float %x) #1 { ; CHECK-LABEL: frecp1: ; CHECK-NEXT: BB#0 -; CHECK-NEXT: frecpe -; CHECK-NEXT: fmov +; CHECK-NEXT: frecpe [[R:s[0-7]]] +; CHECK-NEXT: frecps {{s[0-7](, s[0-7])?}}, [[R]] } define <2 x float> @f2recp0(<2 x float> %x) #0 { @@ -36,8 +36,8 @@ define <2 x float> @f2recp1(<2 x float> %x) #1 { ; CHECK-LABEL: f2recp1: ; CHECK-NEXT: BB#0 -; CHECK-NEXT: fmov -; CHECK-NEXT: frecpe +; CHECK-NEXT: frecpe [[R:v[0-7]\.2s]] +; CHECK-NEXT: frecps {{v[0-7]\.2s(, v[0-7].2s)?}}, [[R]] } define <4 x float> @f4recp0(<4 x float> %x) #0 { @@ -56,8 +56,8 @@ define <4 x float> @f4recp1(<4 x float> %x) #1 { ; CHECK-LABEL: f4recp1: ; CHECK-NEXT: BB#0 -; CHECK-NEXT: fmov -; CHECK-NEXT: frecpe +; CHECK-NEXT: frecpe [[R:v[0-7]\.4s]] +; CHECK-NEXT: frecps {{v[0-7]\.4s(, v[0-7].4s)?}}, [[R]] } define <8 x float> @f8recp0(<8 x float> %x) #0 { @@ -77,9 +77,10 @@ define <8 x float> @f8recp1(<8 x float> %x) #1 { ; CHECK-LABEL: f8recp1: ; CHECK-NEXT: BB#0 -; CHECK-NEXT: fmov -; CHECK-NEXT: frecpe -; CHECK: frecpe +; CHECK-NEXT: frecpe [[RA:v[0-7]\.4s]] +; CHECK-NEXT: frecpe [[RB:v[0-7]\.4s]] +; CHECK-NEXT: frecps {{v[0-7]\.4s(, v[0-7].4s)?}}, [[RA]] +; CHECK: frecps {{v[0-7]\.4s(, v[0-7].4s)?}}, [[RB]] } define double @drecp0(double %x) #0 { @@ -98,8 +99,8 @@ define double @drecp1(double %x) #1 { ; CHECK-LABEL: drecp1: ; CHECK-NEXT: BB#0 -; CHECK-NEXT: frecpe -; CHECK-NEXT: fmov +; CHECK-NEXT: frecpe [[R:d[0-7]]] +; CHECK-NEXT: frecps {{d[0-7](, d[0-7])?}}, [[R]] } define <2 x double> @d2recp0(<2 x double> %x) #0 { @@ -118,8 +119,8 @@ define <2 x double> @d2recp1(<2 x double> %x) #1 { ; CHECK-LABEL: d2recp1: ; CHECK-NEXT: BB#0 -; CHECK-NEXT: fmov -; CHECK-NEXT: frecpe +; CHECK-NEXT: frecpe [[R:v[0-7]\.2d]] +; CHECK-NEXT: frecps {{v[0-7]\.2d(, v[0-7].2d)?}}, [[R]] } define <4 x double> @d4recp0(<4 x double> %x) #0 { @@ -139,9 +140,10 @@ define <4 x double> @d4recp1(<4 x double> %x) #1 { ; CHECK-LABEL: d4recp1: ; CHECK-NEXT: BB#0 -; CHECK-NEXT: fmov -; CHECK-NEXT: frecpe -; CHECK: frecpe +; CHECK-NEXT: frecpe [[RA:v[0-7]\.2d]] +; CHECK-NEXT: frecpe [[RB:v[0-7]\.2d]] +; CHECK-NEXT: frecps {{v[0-7]\.2d(, v[0-7].2d)?}}, [[RA]] +; CHECK: frecps {{v[0-7]\.2d(, v[0-7].2d)?}}, [[RB]] } attributes #0 = { nounwind "unsafe-fp-math"="true" } diff --git a/llvm/test/CodeGen/AArch64/sqrt-fastmath.ll b/llvm/test/CodeGen/AArch64/sqrt-fastmath.ll index cd8200bdfd23..079562c05819 100644 --- a/llvm/test/CodeGen/AArch64/sqrt-fastmath.ll +++ b/llvm/test/CodeGen/AArch64/sqrt-fastmath.ll @@ -19,8 +19,10 @@ define float @fsqrt(float %a) #0 { ; CHECK-LABEL: fsqrt: ; CHECK-NEXT: BB#0 -; CHECK-NEXT: fmov -; CHECK-NEXT: frsqrte +; CHECK-NEXT: frsqrte [[RA:s[0-7]]] +; CHECK-NEXT: fmul [[RB:s[0-7]]], [[RA]], [[RA]] +; CHECK-NEXT: frsqrts {{s[0-7](, s[0-7])?}}, [[RB]] +; CHECK: fcmp s0, #0 } define <2 x float> @f2sqrt(<2 x float> %a) #0 { @@ -33,9 +35,10 @@ define <2 x float> @f2sqrt(<2 x float> %a) #0 { ; CHECK-LABEL: f2sqrt: ; CHECK-NEXT: BB#0 -; CHECK-NEXT: fmov -; CHECK-NEXT: mov -; CHECK-NEXT: frsqrte +; CHECK-NEXT: frsqrte [[RA:v[0-7]\.2s]] +; CHECK-NEXT: fmul [[RB:v[0-7]\.2s]], [[RA]], [[RA]] +; CHECK-NEXT: frsqrts {{v[0-7]\.2s(, v[0-7]\.2s)?}}, [[RB]] +; CHECK: fcmeq {{v[0-7]\.2s, v0\.2s}}, #0 } define <4 x float> @f4sqrt(<4 x float> %a) #0 { @@ -48,9 +51,10 @@ define <4 x float> @f4sqrt(<4 x float> %a) #0 { ; CHECK-LABEL: f4sqrt: ; CHECK-NEXT: BB#0 -; CHECK-NEXT: fmov -; CHECK-NEXT: mov -; CHECK-NEXT: frsqrte +; CHECK-NEXT: frsqrte [[RA:v[0-7]\.4s]] +; CHECK-NEXT: fmul [[RB:v[0-7]\.4s]], [[RA]], [[RA]] +; CHECK-NEXT: frsqrts {{v[0-7]\.4s(, v[0-7]\.4s)?}}, [[RB]] +; CHECK: fcmeq {{v[0-7]\.4s, v0\.4s}}, #0 } define <8 x float> @f8sqrt(<8 x float> %a) #0 { @@ -64,10 +68,10 @@ define <8 x float> @f8sqrt(<8 x float> %a) #0 { ; CHECK-LABEL: f8sqrt: ; CHECK-NEXT: BB#0 -; CHECK-NEXT: fmov -; CHECK-NEXT: mov -; CHECK-NEXT: frsqrte -; CHECK: frsqrte +; CHECK-NEXT: frsqrte [[RA:v[0-7]\.4s]] +; CHECK: fmul [[RB:v[0-7]\.4s]], [[RA]], [[RA]] +; CHECK: frsqrts {{v[0-7]\.4s(, v[0-7]\.4s)?}}, [[RB]] +; CHECK: fcmeq {{v[0-7]\.4s, v[0-1]\.4s}}, #0 } define double @dsqrt(double %a) #0 { @@ -80,8 +84,10 @@ define double @dsqrt(double %a) #0 { ; CHECK-LABEL: dsqrt: ; CHECK-NEXT: BB#0 -; CHECK-NEXT: fmov -; CHECK-NEXT: frsqrte +; CHECK-NEXT: frsqrte [[RA:d[0-7]]] +; CHECK-NEXT: fmul [[RB:d[0-7]]], [[RA]], [[RA]] +; CHECK-NEXT: frsqrts {{d[0-7](, d[0-7])?}}, [[RB]] +; CHECK: fcmp d0, #0 } define <2 x double> @d2sqrt(<2 x double> %a) #0 { @@ -94,9 +100,10 @@ define <2 x double> @d2sqrt(<2 x double> %a) #0 { ; CHECK-LABEL: d2sqrt: ; CHECK-NEXT: BB#0 -; CHECK-NEXT: fmov -; CHECK-NEXT: mov -; CHECK-NEXT: frsqrte +; CHECK-NEXT: frsqrte [[RA:v[0-7]\.2d]] +; CHECK-NEXT: fmul [[RB:v[0-7]\.2d]], [[RA]], [[RA]] +; CHECK-NEXT: frsqrts {{v[0-7]\.2d(, v[0-7]\.2d)?}}, [[RB]] +; CHECK: fcmeq {{v[0-7]\.2d, v0\.2d}}, #0 } define <4 x double> @d4sqrt(<4 x double> %a) #0 { @@ -110,10 +117,10 @@ define <4 x double> @d4sqrt(<4 x double> %a) #0 { ; CHECK-LABEL: d4sqrt: ; CHECK-NEXT: BB#0 -; CHECK-NEXT: fmov -; CHECK-NEXT: mov -; CHECK-NEXT: frsqrte -; CHECK: frsqrte +; CHECK-NEXT: frsqrte [[RA:v[0-7]\.2d]] +; CHECK: fmul [[RB:v[0-7]\.2d]], [[RA]], [[RA]] +; CHECK: frsqrts {{v[0-7]\.2d(, v[0-7]\.2d)?}}, [[RB]] +; CHECK: fcmeq {{v[0-7]\.2d, v[0-1]\.2d}}, #0 } define float @frsqrt(float %a) #0 { @@ -127,8 +134,10 @@ define float @frsqrt(float %a) #0 { ; CHECK-LABEL: frsqrt: ; CHECK-NEXT: BB#0 -; CHECK-NEXT: fmov -; CHECK-NEXT: frsqrte +; CHECK-NEXT: frsqrte [[RA:s[0-7]]] +; CHECK-NEXT: fmul [[RB:s[0-7]]], [[RA]], [[RA]] +; CHECK-NEXT: frsqrts {{s[0-7](, s[0-7])?}}, [[RB]] +; CHECK-NOT: fcmp {{s[0-7]}}, #0 } define <2 x float> @f2rsqrt(<2 x float> %a) #0 { @@ -142,8 +151,10 @@ define <2 x float> @f2rsqrt(<2 x float> %a) #0 { ; CHECK-LABEL: f2rsqrt: ; CHECK-NEXT: BB#0 -; CHECK-NEXT: fmov -; CHECK-NEXT: frsqrte +; CHECK-NEXT: frsqrte [[RA:v[0-7]\.2s]] +; CHECK-NEXT: fmul [[RB:v[0-7]\.2s]], [[RA]], [[RA]] +; CHECK-NEXT: frsqrts {{v[0-7]\.2s(, v[0-7]\.2s)?}}, [[RB]] +; CHECK-NOT: fcmeq {{v[0-7]\.2s, v0\.2s}}, #0 } define <4 x float> @f4rsqrt(<4 x float> %a) #0 { @@ -157,8 +168,10 @@ define <4 x float> @f4rsqrt(<4 x float> %a) #0 { ; CHECK-LABEL: f4rsqrt: ; CHECK-NEXT: BB#0 -; CHECK-NEXT: fmov -; CHECK-NEXT: frsqrte +; CHECK-NEXT: frsqrte [[RA:v[0-7]\.4s]] +; CHECK-NEXT: fmul [[RB:v[0-7]\.4s]], [[RA]], [[RA]] +; CHECK-NEXT: frsqrts {{v[0-7]\.4s(, v[0-7]\.4s)?}}, [[RB]] +; CHECK-NOT: fcmeq {{v[0-7]\.4s, v0\.4s}}, #0 } define <8 x float> @f8rsqrt(<8 x float> %a) #0 { @@ -173,9 +186,10 @@ define <8 x float> @f8rsqrt(<8 x float> %a) #0 { ; CHECK-LABEL: f8rsqrt: ; CHECK-NEXT: BB#0 -; CHECK-NEXT: fmov -; CHECK-NEXT: frsqrte -; CHECK: frsqrte +; CHECK-NEXT: frsqrte [[RA:v[0-7]\.4s]] +; CHECK: fmul [[RB:v[0-7]\.4s]], [[RA]], [[RA]] +; CHECK: frsqrts {{v[0-7]\.4s(, v[0-7]\.4s)?}}, [[RB]] +; CHECK-NOT: fcmeq {{v[0-7]\.4s, v0\.4s}}, #0 } define double @drsqrt(double %a) #0 { @@ -189,8 +203,10 @@ define double @drsqrt(double %a) #0 { ; CHECK-LABEL: drsqrt: ; CHECK-NEXT: BB#0 -; CHECK-NEXT: fmov -; CHECK-NEXT: frsqrte +; CHECK-NEXT: frsqrte [[RA:d[0-7]]] +; CHECK-NEXT: fmul [[RB:d[0-7]]], [[RA]], [[RA]] +; CHECK-NEXT: frsqrts {{d[0-7](, d[0-7])?}}, [[RB]] +; CHECK-NOT: fcmp d0, #0 } define <2 x double> @d2rsqrt(<2 x double> %a) #0 { @@ -204,8 +220,10 @@ define <2 x double> @d2rsqrt(<2 x double> %a) #0 { ; CHECK-LABEL: d2rsqrt: ; CHECK-NEXT: BB#0 -; CHECK-NEXT: fmov -; CHECK-NEXT: frsqrte +; CHECK-NEXT: frsqrte [[RA:v[0-7]\.2d]] +; CHECK-NEXT: fmul [[RB:v[0-7]\.2d]], [[RA]], [[RA]] +; CHECK-NEXT: frsqrts {{v[0-7]\.2d(, v[0-7]\.2d)?}}, [[RB]] +; CHECK-NOT: fcmeq {{v[0-7]\.2d, v0\.2d}}, #0 } define <4 x double> @d4rsqrt(<4 x double> %a) #0 { @@ -220,9 +238,10 @@ define <4 x double> @d4rsqrt(<4 x double> %a) #0 { ; CHECK-LABEL: d4rsqrt: ; CHECK-NEXT: BB#0 -; CHECK-NEXT: fmov -; CHECK-NEXT: frsqrte -; CHECK: frsqrte +; CHECK-NEXT: frsqrte [[RA:v[0-7]\.2d]] +; CHECK: fmul [[RB:v[0-7]\.2d]], [[RA]], [[RA]] +; CHECK: frsqrts {{v[0-7]\.2d(, v[0-7]\.2d)?}}, [[RB]] +; CHECK-NOT: fcmeq {{v[0-7]\.2d, v0\.2d}}, #0 } attributes #0 = { nounwind "unsafe-fp-math"="true" }