diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp index 2e955f89049c..cae94d497596 100644 --- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp +++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp @@ -546,13 +546,19 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM, // These map to conversion instructions for scalar FP types. for (const auto &Op : {ISD::FCEIL, ISD::FFLOOR, ISD::FNEARBYINT, ISD::FRINT, - ISD::FROUND, ISD::FTRUNC}) { + ISD::FTRUNC}) { setOperationAction(Op, MVT::f16, Legal); setOperationAction(Op, MVT::f32, Legal); setOperationAction(Op, MVT::f64, Legal); setOperationAction(Op, MVT::v2f16, Expand); } + setOperationAction(ISD::FROUND, MVT::f16, Promote); + setOperationAction(ISD::FROUND, MVT::v2f16, Expand); + setOperationAction(ISD::FROUND, MVT::f32, Custom); + setOperationAction(ISD::FROUND, MVT::f64, Custom); + + // 'Expand' implements FCOPYSIGN without calling an external library. setOperationAction(ISD::FCOPYSIGN, MVT::f16, Expand); setOperationAction(ISD::FCOPYSIGN, MVT::v2f16, Expand); @@ -2068,6 +2074,100 @@ SDValue NVPTXTargetLowering::LowerShiftLeftParts(SDValue Op, } } +SDValue NVPTXTargetLowering::LowerFROUND(SDValue Op, SelectionDAG &DAG) const { + EVT VT = Op.getValueType(); + + if (VT == MVT::f32) + return LowerFROUND32(Op, DAG); + + if (VT == MVT::f64) + return LowerFROUND64(Op, DAG); + + llvm_unreachable("unhandled type"); +} + +// This is the the rounding method used in CUDA libdevice in C like code: +// float roundf(float A) +// { +// float RoundedA = (float) (int) ( A > 0 ? (A + 0.5f) : (A - 0.5f)); +// RoundedA = abs(A) > 0x1.0p23 ? A : RoundedA; +// return abs(A) < 0.5 ? (float)(int)A : RoundedA; +// } +SDValue NVPTXTargetLowering::LowerFROUND32(SDValue Op, + SelectionDAG &DAG) const { + SDLoc SL(Op); + SDValue A = Op.getOperand(0); + EVT VT = Op.getValueType(); + + SDValue AbsA = DAG.getNode(ISD::FABS, SL, VT, A); + + // RoundedA = (float) (int) ( A > 0 ? (A + 0.5f) : (A - 0.5f)) + SDValue Bitcast = DAG.getNode(ISD::BITCAST, SL, MVT::i32, A); + const int SignBitMask = 0x80000000; + SDValue Sign = DAG.getNode(ISD::AND, SL, MVT::i32, Bitcast, + DAG.getConstant(SignBitMask, SL, MVT::i32)); + const int PointFiveInBits = 0x3F000000; + SDValue PointFiveWithSignRaw = + DAG.getNode(ISD::OR, SL, MVT::i32, Sign, + DAG.getConstant(PointFiveInBits, SL, MVT::i32)); + SDValue PointFiveWithSign = + DAG.getNode(ISD::BITCAST, SL, VT, PointFiveWithSignRaw); + SDValue AdjustedA = DAG.getNode(ISD::FADD, SL, VT, A, PointFiveWithSign); + SDValue RoundedA = DAG.getNode(ISD::FTRUNC, SL, VT, AdjustedA); + + // RoundedA = abs(A) > 0x1.0p23 ? A : RoundedA; + EVT SetCCVT = getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT); + SDValue IsLarge = + DAG.getSetCC(SL, SetCCVT, AbsA, DAG.getConstantFP(pow(2.0, 23.0), SL, VT), + ISD::SETOGT); + RoundedA = DAG.getNode(ISD::SELECT, SL, VT, IsLarge, A, RoundedA); + + // return abs(A) < 0.5 ? (float)(int)A : RoundedA; + SDValue IsSmall =DAG.getSetCC(SL, SetCCVT, AbsA, + DAG.getConstantFP(0.5, SL, VT), ISD::SETOLT); + SDValue RoundedAForSmallA = DAG.getNode(ISD::FTRUNC, SL, VT, A); + return DAG.getNode(ISD::SELECT, SL, VT, IsSmall, RoundedAForSmallA, RoundedA); +} + +// The implementation of round(double) is similar to that of round(float) in +// that they both separate the value range into three regions and use a method +// specific to the region to round the values. However, round(double) first +// calculates the round of the absolute value and then adds the sign back while +// round(float) directly rounds the value with sign. +SDValue NVPTXTargetLowering::LowerFROUND64(SDValue Op, + SelectionDAG &DAG) const { + SDLoc SL(Op); + SDValue A = Op.getOperand(0); + EVT VT = Op.getValueType(); + + SDValue AbsA = DAG.getNode(ISD::FABS, SL, VT, A); + + // double RoundedA = (double) (int) (abs(A) + 0.5f); + SDValue AdjustedA = DAG.getNode(ISD::FADD, SL, VT, AbsA, + DAG.getConstantFP(0.5, SL, VT)); + SDValue RoundedA = DAG.getNode(ISD::FTRUNC, SL, VT, AdjustedA); + + // RoundedA = abs(A) < 0.5 ? (double)0 : RoundedA; + EVT SetCCVT = getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT); + SDValue IsSmall =DAG.getSetCC(SL, SetCCVT, AbsA, + DAG.getConstantFP(0.5, SL, VT), ISD::SETOLT); + RoundedA = DAG.getNode(ISD::SELECT, SL, VT, IsSmall, + DAG.getConstantFP(0, SL, VT), + RoundedA); + + // Add sign to rounded_A + RoundedA = DAG.getNode(ISD::FCOPYSIGN, SL, VT, RoundedA, A); + DAG.getNode(ISD::FTRUNC, SL, VT, A); + + // RoundedA = abs(A) > 0x1.0p52 ? A : RoundedA; + SDValue IsLarge = + DAG.getSetCC(SL, SetCCVT, AbsA, DAG.getConstantFP(pow(2.0, 52.0), SL, VT), + ISD::SETOGT); + return DAG.getNode(ISD::SELECT, SL, VT, IsLarge, A, RoundedA); +} + + + SDValue NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const { switch (Op.getOpcode()) { @@ -2098,6 +2198,8 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const { return LowerShiftRightParts(Op, DAG); case ISD::SELECT: return LowerSelect(Op, DAG); + case ISD::FROUND: + return LowerFROUND(Op, DAG); default: llvm_unreachable("Custom lowering not defined for operation"); } diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h index bbcc35f49d99..ef645fc1e541 100644 --- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h +++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h @@ -556,6 +556,10 @@ private: SDValue LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const; SDValue LowerEXTRACT_VECTOR_ELT(SDValue Op, SelectionDAG &DAG) const; + SDValue LowerFROUND(SDValue Op, SelectionDAG &DAG) const; + SDValue LowerFROUND32(SDValue Op, SelectionDAG &DAG) const; + SDValue LowerFROUND64(SDValue Op, SelectionDAG &DAG) const; + SDValue LowerLOAD(SDValue Op, SelectionDAG &DAG) const; SDValue LowerLOADi1(SDValue Op, SelectionDAG &DAG) const; diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td index 1bf556c9287a..2ee90abb4110 100644 --- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td +++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td @@ -3002,15 +3002,6 @@ def : Pat<(ffloor Float32Regs:$a), def : Pat<(ffloor Float64Regs:$a), (CVT_f64_f64 Float64Regs:$a, CvtRMI)>; -def : Pat<(f16 (fround Float16Regs:$a)), - (CVT_f16_f16 Float16Regs:$a, CvtRNI)>; -def : Pat<(fround Float32Regs:$a), - (CVT_f32_f32 Float32Regs:$a, CvtRNI_FTZ)>, Requires<[doF32FTZ]>; -def : Pat<(f32 (fround Float32Regs:$a)), - (CVT_f32_f32 Float32Regs:$a, CvtRNI)>, Requires<[doNoF32FTZ]>; -def : Pat<(f64 (fround Float64Regs:$a)), - (CVT_f64_f64 Float64Regs:$a, CvtRNI)>; - def : Pat<(ftrunc Float16Regs:$a), (CVT_f16_f16 Float16Regs:$a, CvtRZI)>; def : Pat<(ftrunc Float32Regs:$a), diff --git a/llvm/test/CodeGen/NVPTX/f16-instructions.ll b/llvm/test/CodeGen/NVPTX/f16-instructions.ll index 7788adc86989..9aa81dac1262 100644 --- a/llvm/test/CodeGen/NVPTX/f16-instructions.ll +++ b/llvm/test/CodeGen/NVPTX/f16-instructions.ll @@ -1107,9 +1107,11 @@ define half @test_nearbyint(half %a) #0 { } ; CHECK-LABEL: test_round( -; CHECK: ld.param.b16 [[A:%h[0-9]+]], [test_round_param_0]; -; CHECK: cvt.rni.f16.f16 [[R:%h[0-9]+]], [[A]]; -; CHECK: st.param.b16 [func_retval0+0], [[R]]; +; CHECK: ld.param.b16 {{.*}}, [test_round_param_0]; +; check the use of sign mask and 0.5 to implement round +; CHECK: and.b32 [[R:%r[0-9]+]], {{.*}}, -2147483648; +; CHECK: or.b32 {{.*}}, [[R]], 1056964608; +; CHECK: st.param.b16 [func_retval0+0], {{.*}}; ; CHECK: ret; define half @test_round(half %a) #0 { %r = call half @llvm.round.f16(half %a) diff --git a/llvm/test/CodeGen/NVPTX/f16x2-instructions.ll b/llvm/test/CodeGen/NVPTX/f16x2-instructions.ll index a8996815af41..44dda09a902d 100644 --- a/llvm/test/CodeGen/NVPTX/f16x2-instructions.ll +++ b/llvm/test/CodeGen/NVPTX/f16x2-instructions.ll @@ -1378,12 +1378,13 @@ define <2 x half> @test_nearbyint(<2 x half> %a) #0 { } ; CHECK-LABEL: test_round( -; CHECK: ld.param.b32 [[A:%hh[0-9]+]], [test_round_param_0]; -; CHECK-DAG: mov.b32 {[[A0:%h[0-9]+]], [[A1:%h[0-9]+]]}, [[A]]; -; CHECK-DAG: cvt.rni.f16.f16 [[R1:%h[0-9]+]], [[A1]]; -; CHECK-DAG: cvt.rni.f16.f16 [[R0:%h[0-9]+]], [[A0]]; -; CHECK: mov.b32 [[R:%hh[0-9]+]], {[[R0]], [[R1]]} -; CHECK: st.param.b32 [func_retval0+0], [[R]]; +; CHECK: ld.param.b32 {{.*}}, [test_round_param_0]; +; check the use of sign mask and 0.5 to implement round +; CHECK: and.b32 [[R1:%r[0-9]+]], {{.*}}, -2147483648; +; CHECK: or.b32 {{.*}}, [[R1]], 1056964608; +; CHECK: and.b32 [[R2:%r[0-9]+]], {{.*}}, -2147483648; +; CHECK: or.b32 {{.*}}, [[R2]], 1056964608; +; CHECK: st.param.b32 [func_retval0+0], {{.*}}; ; CHECK: ret; define <2 x half> @test_round(<2 x half> %a) #0 { %r = call <2 x half> @llvm.round.f16(<2 x half> %a) diff --git a/llvm/test/CodeGen/NVPTX/math-intrins.ll b/llvm/test/CodeGen/NVPTX/math-intrins.ll index 828a8807dcfa..412b25c7a3be 100644 --- a/llvm/test/CodeGen/NVPTX/math-intrins.ll +++ b/llvm/test/CodeGen/NVPTX/math-intrins.ll @@ -74,21 +74,27 @@ define double @floor_double(double %a) { ; CHECK-LABEL: round_float define float @round_float(float %a) { - ; CHECK: cvt.rni.f32.f32 +; check the use of sign mask and 0.5 to implement round +; CHECK: and.b32 [[R1:%r[0-9]+]], {{.*}}, -2147483648; +; CHECK: or.b32 {{.*}}, [[R1]], 1056964608; %b = call float @llvm.round.f32(float %a) ret float %b } ; CHECK-LABEL: round_float_ftz define float @round_float_ftz(float %a) #1 { - ; CHECK: cvt.rni.ftz.f32.f32 +; check the use of sign mask and 0.5 to implement round +; CHECK: and.b32 [[R1:%r[0-9]+]], {{.*}}, -2147483648; +; CHECK: or.b32 {{.*}}, [[R1]], 1056964608; %b = call float @llvm.round.f32(float %a) ret float %b } ; CHECK-LABEL: round_double define double @round_double(double %a) { - ; CHECK: cvt.rni.f64.f64 +; check the use of 0.5 to implement round +; CHECK: setp.lt.f64 {{.*}}, [[R:%fd[0-9]+]], 0d3FE0000000000000; +; CHECK: add.rn.f64 {{.*}}, [[R]], 0d3FE0000000000000; %b = call double @llvm.round.f64(double %a) ret double %b }