forked from OSchip/llvm-project
[NVPTX] Lower fp16 fminnum, fmaxnum to native on sm_80.
Reviewed By: bkramer, tra Differential Revision: https://reviews.llvm.org/D117122
This commit is contained in:
parent
9c9119ab36
commit
cc1b9acf55
|
@ -560,10 +560,14 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
|
|||
setOperationAction(Op, MVT::f64, Legal);
|
||||
setOperationAction(Op, MVT::v2f16, Expand);
|
||||
}
|
||||
setOperationAction(ISD::FMINNUM, MVT::f16, Promote);
|
||||
setOperationAction(ISD::FMAXNUM, MVT::f16, Promote);
|
||||
setOperationAction(ISD::FMINIMUM, MVT::f16, Promote);
|
||||
setOperationAction(ISD::FMAXIMUM, MVT::f16, Promote);
|
||||
// max.f16 is supported on sm_80+.
|
||||
if (STI.allowFP16Math() && STI.getSmVersion() >= 80 &&
|
||||
STI.getPTXVersion() >= 70) {
|
||||
setOperationAction(ISD::FMINNUM, MVT::f16, Legal);
|
||||
setOperationAction(ISD::FMAXNUM, MVT::f16, Legal);
|
||||
setOperationAction(ISD::FMINNUM, MVT::v2f16, Legal);
|
||||
setOperationAction(ISD::FMAXNUM, MVT::v2f16, Legal);
|
||||
}
|
||||
|
||||
// No FEXP2, FLOG2. The PTX ex2 and log2 functions are always approximate.
|
||||
// No FPOW or FREM in PTX.
|
||||
|
|
|
@ -249,6 +249,32 @@ multiclass F3<string OpcStr, SDNode OpNode> {
|
|||
(ins Float32Regs:$a, f32imm:$b),
|
||||
!strconcat(OpcStr, ".f32 \t$dst, $a, $b;"),
|
||||
[(set Float32Regs:$dst, (OpNode Float32Regs:$a, fpimm:$b))]>;
|
||||
|
||||
def f16rr_ftz :
|
||||
NVPTXInst<(outs Float16Regs:$dst),
|
||||
(ins Float16Regs:$a, Float16Regs:$b),
|
||||
!strconcat(OpcStr, ".ftz.f16 \t$dst, $a, $b;"),
|
||||
[(set Float16Regs:$dst, (OpNode Float16Regs:$a, Float16Regs:$b))]>,
|
||||
Requires<[useFP16Math, doF32FTZ]>;
|
||||
def f16rr :
|
||||
NVPTXInst<(outs Float16Regs:$dst),
|
||||
(ins Float16Regs:$a, Float16Regs:$b),
|
||||
!strconcat(OpcStr, ".f16 \t$dst, $a, $b;"),
|
||||
[(set Float16Regs:$dst, (OpNode Float16Regs:$a, Float16Regs:$b))]>,
|
||||
Requires<[useFP16Math]>;
|
||||
|
||||
def f16x2rr_ftz :
|
||||
NVPTXInst<(outs Float16x2Regs:$dst),
|
||||
(ins Float16x2Regs:$a, Float16x2Regs:$b),
|
||||
!strconcat(OpcStr, ".ftz.f16x2 \t$dst, $a, $b;"),
|
||||
[(set Float16x2Regs:$dst, (OpNode Float16x2Regs:$a, Float16x2Regs:$b))]>,
|
||||
Requires<[useFP16Math, doF32FTZ]>;
|
||||
def f16x2rr :
|
||||
NVPTXInst<(outs Float16x2Regs:$dst),
|
||||
(ins Float16x2Regs:$a, Float16x2Regs:$b),
|
||||
!strconcat(OpcStr, ".f16x2 \t$dst, $a, $b;"),
|
||||
[(set Float16x2Regs:$dst, (OpNode Float16x2Regs:$a, Float16x2Regs:$b))]>,
|
||||
Requires<[useFP16Math]>;
|
||||
}
|
||||
|
||||
// Template for instructions which take three FP args. The
|
||||
|
|
|
@ -1,4 +1,6 @@
|
|||
; RUN: llc < %s | FileCheck %s
|
||||
; RUN: llc < %s | FileCheck %s --check-prefixes=CHECK,CHECK-NOF16
|
||||
; RUN: llc < %s -mcpu=sm_80 | FileCheck %s --check-prefixes=CHECK,CHECK-F16
|
||||
; RUN: llc < %s -mcpu=sm_80 --nvptx-no-f16-math | FileCheck %s --check-prefixes=CHECK,CHECK-NOF16
|
||||
target triple = "nvptx64-nvidia-cuda"
|
||||
|
||||
; Checks that llvm intrinsics for math functions are correctly lowered to PTX.
|
||||
|
@ -17,10 +19,14 @@ declare float @llvm.trunc.f32(float) #0
|
|||
declare double @llvm.trunc.f64(double) #0
|
||||
declare float @llvm.fabs.f32(float) #0
|
||||
declare double @llvm.fabs.f64(double) #0
|
||||
declare half @llvm.minnum.f16(half, half) #0
|
||||
declare float @llvm.minnum.f32(float, float) #0
|
||||
declare double @llvm.minnum.f64(double, double) #0
|
||||
declare <2 x half> @llvm.minnum.v2f16(<2 x half>, <2 x half>) #0
|
||||
declare half @llvm.maxnum.f16(half, half) #0
|
||||
declare float @llvm.maxnum.f32(float, float) #0
|
||||
declare double @llvm.maxnum.f64(double, double) #0
|
||||
declare <2 x half> @llvm.maxnum.v2f16(<2 x half>, <2 x half>) #0
|
||||
declare float @llvm.fma.f32(float, float, float) #0
|
||||
declare double @llvm.fma.f64(double, double, double) #0
|
||||
|
||||
|
@ -193,6 +199,14 @@ define double @abs_double(double %a) {
|
|||
|
||||
; ---- min ----
|
||||
|
||||
; CHECK-LABEL: min_half
|
||||
define half @min_half(half %a, half %b) {
|
||||
; CHECK-NOF16: min.f32
|
||||
; CHECK-F16: min.f16
|
||||
%x = call half @llvm.minnum.f16(half %a, half %b)
|
||||
ret half %x
|
||||
}
|
||||
|
||||
; CHECK-LABEL: min_float
|
||||
define float @min_float(float %a, float %b) {
|
||||
; CHECK: min.f32
|
||||
|
@ -228,8 +242,25 @@ define double @min_double(double %a, double %b) {
|
|||
ret double %x
|
||||
}
|
||||
|
||||
; CHECK-LABEL: min_v2half
|
||||
define <2 x half> @min_v2half(<2 x half> %a, <2 x half> %b) {
|
||||
; CHECK-NOF16: min.f32
|
||||
; CHECK-NOF16: min.f32
|
||||
; CHECK-F16: min.f16x2
|
||||
%x = call <2 x half> @llvm.minnum.v2f16(<2 x half> %a, <2 x half> %b)
|
||||
ret <2 x half> %x
|
||||
}
|
||||
|
||||
; ---- max ----
|
||||
|
||||
; CHECK-LABEL: max_half
|
||||
define half @max_half(half %a, half %b) {
|
||||
; CHECK-NOF16: max.f32
|
||||
; CHECK-F16: max.f16
|
||||
%x = call half @llvm.maxnum.f16(half %a, half %b)
|
||||
ret half %x
|
||||
}
|
||||
|
||||
; CHECK-LABEL: max_imm1
|
||||
define float @max_imm1(float %a) {
|
||||
; CHECK: max.f32
|
||||
|
@ -265,6 +296,15 @@ define double @max_double(double %a, double %b) {
|
|||
ret double %x
|
||||
}
|
||||
|
||||
; CHECK-LABEL: max_v2half
|
||||
define <2 x half> @max_v2half(<2 x half> %a, <2 x half> %b) {
|
||||
; CHECK-NOF16: max.f32
|
||||
; CHECK-NOF16: max.f32
|
||||
; CHECK-F16: max.f16x2
|
||||
%x = call <2 x half> @llvm.maxnum.v2f16(<2 x half> %a, <2 x half> %b)
|
||||
ret <2 x half> %x
|
||||
}
|
||||
|
||||
; ---- fma ----
|
||||
|
||||
; CHECK-LABEL: @fma_float
|
||||
|
|
Loading…
Reference in New Issue