From 36658376d5d4103b3828c726f211030ebc4f84b6 Mon Sep 17 00:00:00 2001 From: Craig Topper Date: Fri, 12 Feb 2021 15:10:18 -0800 Subject: [PATCH] [RISCV] Add support for fixed vector sqrt. --- llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 4 ++ llvm/lib/Target/RISCV/RISCVISelLowering.h | 1 + .../Target/RISCV/RISCVInstrInfoVVLPatterns.td | 55 +++++++++++-------- .../CodeGen/RISCV/rvv/fixed-vectors-fp.ll | 48 ++++++++++++++++ 4 files changed, 84 insertions(+), 24 deletions(-) diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp index 06bdbfd75f19..1ce754c37214 100644 --- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp +++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp @@ -577,6 +577,7 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM, setOperationAction(ISD::FMUL, VT, Custom); setOperationAction(ISD::FDIV, VT, Custom); setOperationAction(ISD::FNEG, VT, Custom); + setOperationAction(ISD::FSQRT, VT, Custom); setOperationAction(ISD::FMA, VT, Custom); } } @@ -1209,6 +1210,8 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op, return lowerToScalableOp(Op, DAG, RISCVISD::FDIV_VL); case ISD::FNEG: return lowerToScalableOp(Op, DAG, RISCVISD::FNEG_VL); + case ISD::FSQRT: + return lowerToScalableOp(Op, DAG, RISCVISD::FSQRT_VL); case ISD::FMA: return lowerToScalableOp(Op, DAG, RISCVISD::FMA_VL); case ISD::SMIN: @@ -4739,6 +4742,7 @@ const char *RISCVTargetLowering::getTargetNodeName(unsigned Opcode) const { NODE_NAME_CASE(FMUL_VL) NODE_NAME_CASE(FDIV_VL) NODE_NAME_CASE(FNEG_VL) + NODE_NAME_CASE(FSQRT_VL) NODE_NAME_CASE(FMA_VL) NODE_NAME_CASE(SMIN_VL) NODE_NAME_CASE(SMAX_VL) diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h index 78177c8451cb..edb14c60bf9a 100644 --- a/llvm/lib/Target/RISCV/RISCVISelLowering.h +++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h @@ -162,6 +162,7 @@ enum NodeType : unsigned { FMUL_VL, FDIV_VL, FNEG_VL, + FSQRT_VL, FMA_VL, SMIN_VL, SMAX_VL, diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td index a9cb535d8901..bc45922b89eb 100644 --- a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td +++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td @@ -51,28 +51,29 @@ def riscv_vle_vl : SDNode<"RISCVISD::VLE_VL", SDT_RISCVVLE_VL, def riscv_vse_vl : SDNode<"RISCVISD::VSE_VL", SDT_RISCVVSE_VL, [SDNPHasChain, SDNPMayStore, SDNPMemOperand]>; -def riscv_add_vl : SDNode<"RISCVISD::ADD_VL", SDT_RISCVIntBinOp_VL, [SDNPCommutative]>; -def riscv_sub_vl : SDNode<"RISCVISD::SUB_VL", SDT_RISCVIntBinOp_VL>; -def riscv_mul_vl : SDNode<"RISCVISD::MUL_VL", SDT_RISCVIntBinOp_VL, [SDNPCommutative]>; -def riscv_and_vl : SDNode<"RISCVISD::AND_VL", SDT_RISCVIntBinOp_VL, [SDNPCommutative]>; -def riscv_or_vl : SDNode<"RISCVISD::OR_VL", SDT_RISCVIntBinOp_VL, [SDNPCommutative]>; -def riscv_xor_vl : SDNode<"RISCVISD::XOR_VL", SDT_RISCVIntBinOp_VL, [SDNPCommutative]>; -def riscv_sdiv_vl : SDNode<"RISCVISD::SDIV_VL", SDT_RISCVIntBinOp_VL>; -def riscv_srem_vl : SDNode<"RISCVISD::SREM_VL", SDT_RISCVIntBinOp_VL>; -def riscv_udiv_vl : SDNode<"RISCVISD::UDIV_VL", SDT_RISCVIntBinOp_VL>; -def riscv_urem_vl : SDNode<"RISCVISD::UREM_VL", SDT_RISCVIntBinOp_VL>; -def riscv_shl_vl : SDNode<"RISCVISD::SHL_VL", SDT_RISCVIntBinOp_VL>; -def riscv_sra_vl : SDNode<"RISCVISD::SRA_VL", SDT_RISCVIntBinOp_VL>; -def riscv_srl_vl : SDNode<"RISCVISD::SRL_VL", SDT_RISCVIntBinOp_VL>; -def riscv_smin_vl : SDNode<"RISCVISD::SMIN_VL", SDT_RISCVIntBinOp_VL>; -def riscv_smax_vl : SDNode<"RISCVISD::SMAX_VL", SDT_RISCVIntBinOp_VL>; -def riscv_umin_vl : SDNode<"RISCVISD::UMIN_VL", SDT_RISCVIntBinOp_VL>; -def riscv_umax_vl : SDNode<"RISCVISD::UMAX_VL", SDT_RISCVIntBinOp_VL>; -def riscv_fadd_vl : SDNode<"RISCVISD::FADD_VL", SDT_RISCVFPBinOp_VL, [SDNPCommutative]>; -def riscv_fsub_vl : SDNode<"RISCVISD::FSUB_VL", SDT_RISCVFPBinOp_VL>; -def riscv_fmul_vl : SDNode<"RISCVISD::FMUL_VL", SDT_RISCVFPBinOp_VL, [SDNPCommutative]>; -def riscv_fdiv_vl : SDNode<"RISCVISD::FDIV_VL", SDT_RISCVFPBinOp_VL>; -def riscv_fneg_vl : SDNode<"RISCVISD::FNEG_VL", SDT_RISCVFPUnOp_VL>; +def riscv_add_vl : SDNode<"RISCVISD::ADD_VL", SDT_RISCVIntBinOp_VL, [SDNPCommutative]>; +def riscv_sub_vl : SDNode<"RISCVISD::SUB_VL", SDT_RISCVIntBinOp_VL>; +def riscv_mul_vl : SDNode<"RISCVISD::MUL_VL", SDT_RISCVIntBinOp_VL, [SDNPCommutative]>; +def riscv_and_vl : SDNode<"RISCVISD::AND_VL", SDT_RISCVIntBinOp_VL, [SDNPCommutative]>; +def riscv_or_vl : SDNode<"RISCVISD::OR_VL", SDT_RISCVIntBinOp_VL, [SDNPCommutative]>; +def riscv_xor_vl : SDNode<"RISCVISD::XOR_VL", SDT_RISCVIntBinOp_VL, [SDNPCommutative]>; +def riscv_sdiv_vl : SDNode<"RISCVISD::SDIV_VL", SDT_RISCVIntBinOp_VL>; +def riscv_srem_vl : SDNode<"RISCVISD::SREM_VL", SDT_RISCVIntBinOp_VL>; +def riscv_udiv_vl : SDNode<"RISCVISD::UDIV_VL", SDT_RISCVIntBinOp_VL>; +def riscv_urem_vl : SDNode<"RISCVISD::UREM_VL", SDT_RISCVIntBinOp_VL>; +def riscv_shl_vl : SDNode<"RISCVISD::SHL_VL", SDT_RISCVIntBinOp_VL>; +def riscv_sra_vl : SDNode<"RISCVISD::SRA_VL", SDT_RISCVIntBinOp_VL>; +def riscv_srl_vl : SDNode<"RISCVISD::SRL_VL", SDT_RISCVIntBinOp_VL>; +def riscv_smin_vl : SDNode<"RISCVISD::SMIN_VL", SDT_RISCVIntBinOp_VL>; +def riscv_smax_vl : SDNode<"RISCVISD::SMAX_VL", SDT_RISCVIntBinOp_VL>; +def riscv_umin_vl : SDNode<"RISCVISD::UMIN_VL", SDT_RISCVIntBinOp_VL>; +def riscv_umax_vl : SDNode<"RISCVISD::UMAX_VL", SDT_RISCVIntBinOp_VL>; +def riscv_fadd_vl : SDNode<"RISCVISD::FADD_VL", SDT_RISCVFPBinOp_VL, [SDNPCommutative]>; +def riscv_fsub_vl : SDNode<"RISCVISD::FSUB_VL", SDT_RISCVFPBinOp_VL>; +def riscv_fmul_vl : SDNode<"RISCVISD::FMUL_VL", SDT_RISCVFPBinOp_VL, [SDNPCommutative]>; +def riscv_fdiv_vl : SDNode<"RISCVISD::FDIV_VL", SDT_RISCVFPBinOp_VL>; +def riscv_fneg_vl : SDNode<"RISCVISD::FNEG_VL", SDT_RISCVFPUnOp_VL>; +def riscv_fsqrt_vl : SDNode<"RISCVISD::FSQRT_VL", SDT_RISCVFPUnOp_VL>; def SDT_RISCVVecFMA_VL : SDTypeProfile<1, 5, [SDTCisSameAs<0, 1>, SDTCisSameAs<0, 2>, @@ -440,9 +441,15 @@ foreach vti = AllFloatVectors in { GPR:$vl, vti.SEW)>; } -// 14.12. Vector Floating-Point Sign-Injection Instructions -// Handle fneg with VFSGNJN using the same input for both operands. foreach vti = AllFloatVectors in { + // 14.8. Vector Floating-Point Square-Root Instruction + def : Pat<(riscv_fsqrt_vl (vti.Vector vti.RegClass:$rs2), (vti.Mask true_mask), + (XLenVT (VLOp GPR:$vl))), + (!cast("PseudoVFSQRT_V_"# vti.LMul.MX) + vti.RegClass:$rs2, GPR:$vl, vti.SEW)>; + + // 14.12. Vector Floating-Point Sign-Injection Instructions + // Handle fneg with VFSGNJN using the same input for both operands. def : Pat<(riscv_fneg_vl (vti.Vector vti.RegClass:$rs), (vti.Mask true_mask), (XLenVT (VLOp GPR:$vl))), (!cast("PseudoVFSGNJN_VV_"# vti.LMul.MX) diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-fp.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-fp.ll index 7407aa8aa5b7..2c54c4690b08 100644 --- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-fp.ll +++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-fp.ll @@ -253,6 +253,54 @@ define void @fneg_v2f64(<2 x double>* %x) { ret void } +define void @sqrt_v8f16(<8 x half>* %x) { +; CHECK-LABEL: sqrt_v8f16: +; CHECK: # %bb.0: +; CHECK-NEXT: addi a1, zero, 8 +; CHECK-NEXT: vsetvli a1, a1, e16,m1,ta,mu +; CHECK-NEXT: vle16.v v25, (a0) +; CHECK-NEXT: vfsqrt.v v25, v25 +; CHECK-NEXT: vse16.v v25, (a0) +; CHECK-NEXT: ret + %a = load <8 x half>, <8 x half>* %x + %b = call <8 x half> @llvm.sqrt.v8f16(<8 x half> %a) + store <8 x half> %b, <8 x half>* %x + ret void +} +declare <8 x half> @llvm.sqrt.v8f16(<8 x half>) + +define void @sqrt_v4f32(<4 x float>* %x) { +; CHECK-LABEL: sqrt_v4f32: +; CHECK: # %bb.0: +; CHECK-NEXT: addi a1, zero, 4 +; CHECK-NEXT: vsetvli a1, a1, e32,m1,ta,mu +; CHECK-NEXT: vle32.v v25, (a0) +; CHECK-NEXT: vfsqrt.v v25, v25 +; CHECK-NEXT: vse32.v v25, (a0) +; CHECK-NEXT: ret + %a = load <4 x float>, <4 x float>* %x + %b = call <4 x float> @llvm.sqrt.v4f32(<4 x float> %a) + store <4 x float> %b, <4 x float>* %x + ret void +} +declare <4 x float> @llvm.sqrt.v4f32(<4 x float>) + +define void @sqrt_v2f64(<2 x double>* %x) { +; CHECK-LABEL: sqrt_v2f64: +; CHECK: # %bb.0: +; CHECK-NEXT: addi a1, zero, 2 +; CHECK-NEXT: vsetvli a1, a1, e64,m1,ta,mu +; CHECK-NEXT: vle64.v v25, (a0) +; CHECK-NEXT: vfsqrt.v v25, v25 +; CHECK-NEXT: vse64.v v25, (a0) +; CHECK-NEXT: ret + %a = load <2 x double>, <2 x double>* %x + %b = call <2 x double> @llvm.sqrt.v2f64(<2 x double> %a) + store <2 x double> %b, <2 x double>* %x + ret void +} +declare <2 x double> @llvm.sqrt.v2f64(<2 x double>) + define void @fma_v8f16(<8 x half>* %x, <8 x half>* %y, <8 x half>* %z) { ; CHECK-LABEL: fma_v8f16: ; CHECK: # %bb.0: