forked from OSchip/llvm-project
[NVPTX] Fix the codegen for llvm.round.
Summary: Previously, we translate llvm.round to PTX cvt.rni, which rounds to the even interger when the source is equidistant between two integers. This is not correct as llvm.round should round away from zero. This change replaces llvm.round with a round away from zero implementation through target specific custom lowering. Modify a few affected tests to not check for cvt.rni. Instead, we check for the use of a few constants used in implementing round. We are also adding CUDA runnable tests to check for the values produced by llvm.round to test-suites/External/CUDA. Reviewers: tra Subscribers: jholewinski, sanjoy, jlebar, hiraditya, llvm-commits Tags: #llvm Differential Revision: https://reviews.llvm.org/D59947 llvm-svn: 357407
This commit is contained in:
parent
d109e2a7c3
commit
6c21ccd245
|
@ -546,13 +546,19 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
|
|||
|
||||
// These map to conversion instructions for scalar FP types.
|
||||
for (const auto &Op : {ISD::FCEIL, ISD::FFLOOR, ISD::FNEARBYINT, ISD::FRINT,
|
||||
ISD::FROUND, ISD::FTRUNC}) {
|
||||
ISD::FTRUNC}) {
|
||||
setOperationAction(Op, MVT::f16, Legal);
|
||||
setOperationAction(Op, MVT::f32, Legal);
|
||||
setOperationAction(Op, MVT::f64, Legal);
|
||||
setOperationAction(Op, MVT::v2f16, Expand);
|
||||
}
|
||||
|
||||
setOperationAction(ISD::FROUND, MVT::f16, Promote);
|
||||
setOperationAction(ISD::FROUND, MVT::v2f16, Expand);
|
||||
setOperationAction(ISD::FROUND, MVT::f32, Custom);
|
||||
setOperationAction(ISD::FROUND, MVT::f64, Custom);
|
||||
|
||||
|
||||
// 'Expand' implements FCOPYSIGN without calling an external library.
|
||||
setOperationAction(ISD::FCOPYSIGN, MVT::f16, Expand);
|
||||
setOperationAction(ISD::FCOPYSIGN, MVT::v2f16, Expand);
|
||||
|
@ -2068,6 +2074,100 @@ SDValue NVPTXTargetLowering::LowerShiftLeftParts(SDValue Op,
|
|||
}
|
||||
}
|
||||
|
||||
SDValue NVPTXTargetLowering::LowerFROUND(SDValue Op, SelectionDAG &DAG) const {
|
||||
EVT VT = Op.getValueType();
|
||||
|
||||
if (VT == MVT::f32)
|
||||
return LowerFROUND32(Op, DAG);
|
||||
|
||||
if (VT == MVT::f64)
|
||||
return LowerFROUND64(Op, DAG);
|
||||
|
||||
llvm_unreachable("unhandled type");
|
||||
}
|
||||
|
||||
// This is the the rounding method used in CUDA libdevice in C like code:
|
||||
// float roundf(float A)
|
||||
// {
|
||||
// float RoundedA = (float) (int) ( A > 0 ? (A + 0.5f) : (A - 0.5f));
|
||||
// RoundedA = abs(A) > 0x1.0p23 ? A : RoundedA;
|
||||
// return abs(A) < 0.5 ? (float)(int)A : RoundedA;
|
||||
// }
|
||||
SDValue NVPTXTargetLowering::LowerFROUND32(SDValue Op,
|
||||
SelectionDAG &DAG) const {
|
||||
SDLoc SL(Op);
|
||||
SDValue A = Op.getOperand(0);
|
||||
EVT VT = Op.getValueType();
|
||||
|
||||
SDValue AbsA = DAG.getNode(ISD::FABS, SL, VT, A);
|
||||
|
||||
// RoundedA = (float) (int) ( A > 0 ? (A + 0.5f) : (A - 0.5f))
|
||||
SDValue Bitcast = DAG.getNode(ISD::BITCAST, SL, MVT::i32, A);
|
||||
const int SignBitMask = 0x80000000;
|
||||
SDValue Sign = DAG.getNode(ISD::AND, SL, MVT::i32, Bitcast,
|
||||
DAG.getConstant(SignBitMask, SL, MVT::i32));
|
||||
const int PointFiveInBits = 0x3F000000;
|
||||
SDValue PointFiveWithSignRaw =
|
||||
DAG.getNode(ISD::OR, SL, MVT::i32, Sign,
|
||||
DAG.getConstant(PointFiveInBits, SL, MVT::i32));
|
||||
SDValue PointFiveWithSign =
|
||||
DAG.getNode(ISD::BITCAST, SL, VT, PointFiveWithSignRaw);
|
||||
SDValue AdjustedA = DAG.getNode(ISD::FADD, SL, VT, A, PointFiveWithSign);
|
||||
SDValue RoundedA = DAG.getNode(ISD::FTRUNC, SL, VT, AdjustedA);
|
||||
|
||||
// RoundedA = abs(A) > 0x1.0p23 ? A : RoundedA;
|
||||
EVT SetCCVT = getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT);
|
||||
SDValue IsLarge =
|
||||
DAG.getSetCC(SL, SetCCVT, AbsA, DAG.getConstantFP(pow(2.0, 23.0), SL, VT),
|
||||
ISD::SETOGT);
|
||||
RoundedA = DAG.getNode(ISD::SELECT, SL, VT, IsLarge, A, RoundedA);
|
||||
|
||||
// return abs(A) < 0.5 ? (float)(int)A : RoundedA;
|
||||
SDValue IsSmall =DAG.getSetCC(SL, SetCCVT, AbsA,
|
||||
DAG.getConstantFP(0.5, SL, VT), ISD::SETOLT);
|
||||
SDValue RoundedAForSmallA = DAG.getNode(ISD::FTRUNC, SL, VT, A);
|
||||
return DAG.getNode(ISD::SELECT, SL, VT, IsSmall, RoundedAForSmallA, RoundedA);
|
||||
}
|
||||
|
||||
// The implementation of round(double) is similar to that of round(float) in
|
||||
// that they both separate the value range into three regions and use a method
|
||||
// specific to the region to round the values. However, round(double) first
|
||||
// calculates the round of the absolute value and then adds the sign back while
|
||||
// round(float) directly rounds the value with sign.
|
||||
SDValue NVPTXTargetLowering::LowerFROUND64(SDValue Op,
|
||||
SelectionDAG &DAG) const {
|
||||
SDLoc SL(Op);
|
||||
SDValue A = Op.getOperand(0);
|
||||
EVT VT = Op.getValueType();
|
||||
|
||||
SDValue AbsA = DAG.getNode(ISD::FABS, SL, VT, A);
|
||||
|
||||
// double RoundedA = (double) (int) (abs(A) + 0.5f);
|
||||
SDValue AdjustedA = DAG.getNode(ISD::FADD, SL, VT, AbsA,
|
||||
DAG.getConstantFP(0.5, SL, VT));
|
||||
SDValue RoundedA = DAG.getNode(ISD::FTRUNC, SL, VT, AdjustedA);
|
||||
|
||||
// RoundedA = abs(A) < 0.5 ? (double)0 : RoundedA;
|
||||
EVT SetCCVT = getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT);
|
||||
SDValue IsSmall =DAG.getSetCC(SL, SetCCVT, AbsA,
|
||||
DAG.getConstantFP(0.5, SL, VT), ISD::SETOLT);
|
||||
RoundedA = DAG.getNode(ISD::SELECT, SL, VT, IsSmall,
|
||||
DAG.getConstantFP(0, SL, VT),
|
||||
RoundedA);
|
||||
|
||||
// Add sign to rounded_A
|
||||
RoundedA = DAG.getNode(ISD::FCOPYSIGN, SL, VT, RoundedA, A);
|
||||
DAG.getNode(ISD::FTRUNC, SL, VT, A);
|
||||
|
||||
// RoundedA = abs(A) > 0x1.0p52 ? A : RoundedA;
|
||||
SDValue IsLarge =
|
||||
DAG.getSetCC(SL, SetCCVT, AbsA, DAG.getConstantFP(pow(2.0, 52.0), SL, VT),
|
||||
ISD::SETOGT);
|
||||
return DAG.getNode(ISD::SELECT, SL, VT, IsLarge, A, RoundedA);
|
||||
}
|
||||
|
||||
|
||||
|
||||
SDValue
|
||||
NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
|
||||
switch (Op.getOpcode()) {
|
||||
|
@ -2098,6 +2198,8 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
|
|||
return LowerShiftRightParts(Op, DAG);
|
||||
case ISD::SELECT:
|
||||
return LowerSelect(Op, DAG);
|
||||
case ISD::FROUND:
|
||||
return LowerFROUND(Op, DAG);
|
||||
default:
|
||||
llvm_unreachable("Custom lowering not defined for operation");
|
||||
}
|
||||
|
|
|
@ -556,6 +556,10 @@ private:
|
|||
SDValue LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const;
|
||||
SDValue LowerEXTRACT_VECTOR_ELT(SDValue Op, SelectionDAG &DAG) const;
|
||||
|
||||
SDValue LowerFROUND(SDValue Op, SelectionDAG &DAG) const;
|
||||
SDValue LowerFROUND32(SDValue Op, SelectionDAG &DAG) const;
|
||||
SDValue LowerFROUND64(SDValue Op, SelectionDAG &DAG) const;
|
||||
|
||||
SDValue LowerLOAD(SDValue Op, SelectionDAG &DAG) const;
|
||||
SDValue LowerLOADi1(SDValue Op, SelectionDAG &DAG) const;
|
||||
|
||||
|
|
|
@ -3002,15 +3002,6 @@ def : Pat<(ffloor Float32Regs:$a),
|
|||
def : Pat<(ffloor Float64Regs:$a),
|
||||
(CVT_f64_f64 Float64Regs:$a, CvtRMI)>;
|
||||
|
||||
def : Pat<(f16 (fround Float16Regs:$a)),
|
||||
(CVT_f16_f16 Float16Regs:$a, CvtRNI)>;
|
||||
def : Pat<(fround Float32Regs:$a),
|
||||
(CVT_f32_f32 Float32Regs:$a, CvtRNI_FTZ)>, Requires<[doF32FTZ]>;
|
||||
def : Pat<(f32 (fround Float32Regs:$a)),
|
||||
(CVT_f32_f32 Float32Regs:$a, CvtRNI)>, Requires<[doNoF32FTZ]>;
|
||||
def : Pat<(f64 (fround Float64Regs:$a)),
|
||||
(CVT_f64_f64 Float64Regs:$a, CvtRNI)>;
|
||||
|
||||
def : Pat<(ftrunc Float16Regs:$a),
|
||||
(CVT_f16_f16 Float16Regs:$a, CvtRZI)>;
|
||||
def : Pat<(ftrunc Float32Regs:$a),
|
||||
|
|
|
@ -1107,9 +1107,11 @@ define half @test_nearbyint(half %a) #0 {
|
|||
}
|
||||
|
||||
; CHECK-LABEL: test_round(
|
||||
; CHECK: ld.param.b16 [[A:%h[0-9]+]], [test_round_param_0];
|
||||
; CHECK: cvt.rni.f16.f16 [[R:%h[0-9]+]], [[A]];
|
||||
; CHECK: st.param.b16 [func_retval0+0], [[R]];
|
||||
; CHECK: ld.param.b16 {{.*}}, [test_round_param_0];
|
||||
; check the use of sign mask and 0.5 to implement round
|
||||
; CHECK: and.b32 [[R:%r[0-9]+]], {{.*}}, -2147483648;
|
||||
; CHECK: or.b32 {{.*}}, [[R]], 1056964608;
|
||||
; CHECK: st.param.b16 [func_retval0+0], {{.*}};
|
||||
; CHECK: ret;
|
||||
define half @test_round(half %a) #0 {
|
||||
%r = call half @llvm.round.f16(half %a)
|
||||
|
|
|
@ -1378,12 +1378,13 @@ define <2 x half> @test_nearbyint(<2 x half> %a) #0 {
|
|||
}
|
||||
|
||||
; CHECK-LABEL: test_round(
|
||||
; CHECK: ld.param.b32 [[A:%hh[0-9]+]], [test_round_param_0];
|
||||
; CHECK-DAG: mov.b32 {[[A0:%h[0-9]+]], [[A1:%h[0-9]+]]}, [[A]];
|
||||
; CHECK-DAG: cvt.rni.f16.f16 [[R1:%h[0-9]+]], [[A1]];
|
||||
; CHECK-DAG: cvt.rni.f16.f16 [[R0:%h[0-9]+]], [[A0]];
|
||||
; CHECK: mov.b32 [[R:%hh[0-9]+]], {[[R0]], [[R1]]}
|
||||
; CHECK: st.param.b32 [func_retval0+0], [[R]];
|
||||
; CHECK: ld.param.b32 {{.*}}, [test_round_param_0];
|
||||
; check the use of sign mask and 0.5 to implement round
|
||||
; CHECK: and.b32 [[R1:%r[0-9]+]], {{.*}}, -2147483648;
|
||||
; CHECK: or.b32 {{.*}}, [[R1]], 1056964608;
|
||||
; CHECK: and.b32 [[R2:%r[0-9]+]], {{.*}}, -2147483648;
|
||||
; CHECK: or.b32 {{.*}}, [[R2]], 1056964608;
|
||||
; CHECK: st.param.b32 [func_retval0+0], {{.*}};
|
||||
; CHECK: ret;
|
||||
define <2 x half> @test_round(<2 x half> %a) #0 {
|
||||
%r = call <2 x half> @llvm.round.f16(<2 x half> %a)
|
||||
|
|
|
@ -74,21 +74,27 @@ define double @floor_double(double %a) {
|
|||
|
||||
; CHECK-LABEL: round_float
|
||||
define float @round_float(float %a) {
|
||||
; CHECK: cvt.rni.f32.f32
|
||||
; check the use of sign mask and 0.5 to implement round
|
||||
; CHECK: and.b32 [[R1:%r[0-9]+]], {{.*}}, -2147483648;
|
||||
; CHECK: or.b32 {{.*}}, [[R1]], 1056964608;
|
||||
%b = call float @llvm.round.f32(float %a)
|
||||
ret float %b
|
||||
}
|
||||
|
||||
; CHECK-LABEL: round_float_ftz
|
||||
define float @round_float_ftz(float %a) #1 {
|
||||
; CHECK: cvt.rni.ftz.f32.f32
|
||||
; check the use of sign mask and 0.5 to implement round
|
||||
; CHECK: and.b32 [[R1:%r[0-9]+]], {{.*}}, -2147483648;
|
||||
; CHECK: or.b32 {{.*}}, [[R1]], 1056964608;
|
||||
%b = call float @llvm.round.f32(float %a)
|
||||
ret float %b
|
||||
}
|
||||
|
||||
; CHECK-LABEL: round_double
|
||||
define double @round_double(double %a) {
|
||||
; CHECK: cvt.rni.f64.f64
|
||||
; check the use of 0.5 to implement round
|
||||
; CHECK: setp.lt.f64 {{.*}}, [[R:%fd[0-9]+]], 0d3FE0000000000000;
|
||||
; CHECK: add.rn.f64 {{.*}}, [[R]], 0d3FE0000000000000;
|
||||
%b = call double @llvm.round.f64(double %a)
|
||||
ret double %b
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue