[NVPTX] Lower fp16 fminnum, fmaxnum to native on sm_80.

Reviewed By: bkramer, tra

Differential Revision: https://reviews.llvm.org/D117122
This commit is contained in:
Christian Sigg 2022-01-12 20:52:53 +01:00
parent 9c9119ab36
commit cc1b9acf55
3 changed files with 75 additions and 5 deletions

View File

@ -560,10 +560,14 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
setOperationAction(Op, MVT::f64, Legal);
setOperationAction(Op, MVT::v2f16, Expand);
}
setOperationAction(ISD::FMINNUM, MVT::f16, Promote);
setOperationAction(ISD::FMAXNUM, MVT::f16, Promote);
setOperationAction(ISD::FMINIMUM, MVT::f16, Promote);
setOperationAction(ISD::FMAXIMUM, MVT::f16, Promote);
// max.f16 is supported on sm_80+.
if (STI.allowFP16Math() && STI.getSmVersion() >= 80 &&
STI.getPTXVersion() >= 70) {
setOperationAction(ISD::FMINNUM, MVT::f16, Legal);
setOperationAction(ISD::FMAXNUM, MVT::f16, Legal);
setOperationAction(ISD::FMINNUM, MVT::v2f16, Legal);
setOperationAction(ISD::FMAXNUM, MVT::v2f16, Legal);
}
// No FEXP2, FLOG2. The PTX ex2 and log2 functions are always approximate.
// No FPOW or FREM in PTX.

View File

@ -249,6 +249,32 @@ multiclass F3<string OpcStr, SDNode OpNode> {
(ins Float32Regs:$a, f32imm:$b),
!strconcat(OpcStr, ".f32 \t$dst, $a, $b;"),
[(set Float32Regs:$dst, (OpNode Float32Regs:$a, fpimm:$b))]>;
def f16rr_ftz :
NVPTXInst<(outs Float16Regs:$dst),
(ins Float16Regs:$a, Float16Regs:$b),
!strconcat(OpcStr, ".ftz.f16 \t$dst, $a, $b;"),
[(set Float16Regs:$dst, (OpNode Float16Regs:$a, Float16Regs:$b))]>,
Requires<[useFP16Math, doF32FTZ]>;
def f16rr :
NVPTXInst<(outs Float16Regs:$dst),
(ins Float16Regs:$a, Float16Regs:$b),
!strconcat(OpcStr, ".f16 \t$dst, $a, $b;"),
[(set Float16Regs:$dst, (OpNode Float16Regs:$a, Float16Regs:$b))]>,
Requires<[useFP16Math]>;
def f16x2rr_ftz :
NVPTXInst<(outs Float16x2Regs:$dst),
(ins Float16x2Regs:$a, Float16x2Regs:$b),
!strconcat(OpcStr, ".ftz.f16x2 \t$dst, $a, $b;"),
[(set Float16x2Regs:$dst, (OpNode Float16x2Regs:$a, Float16x2Regs:$b))]>,
Requires<[useFP16Math, doF32FTZ]>;
def f16x2rr :
NVPTXInst<(outs Float16x2Regs:$dst),
(ins Float16x2Regs:$a, Float16x2Regs:$b),
!strconcat(OpcStr, ".f16x2 \t$dst, $a, $b;"),
[(set Float16x2Regs:$dst, (OpNode Float16x2Regs:$a, Float16x2Regs:$b))]>,
Requires<[useFP16Math]>;
}
// Template for instructions which take three FP args. The

View File

@ -1,4 +1,6 @@
; RUN: llc < %s | FileCheck %s
; RUN: llc < %s | FileCheck %s --check-prefixes=CHECK,CHECK-NOF16
; RUN: llc < %s -mcpu=sm_80 | FileCheck %s --check-prefixes=CHECK,CHECK-F16
; RUN: llc < %s -mcpu=sm_80 --nvptx-no-f16-math | FileCheck %s --check-prefixes=CHECK,CHECK-NOF16
target triple = "nvptx64-nvidia-cuda"
; Checks that llvm intrinsics for math functions are correctly lowered to PTX.
@ -17,10 +19,14 @@ declare float @llvm.trunc.f32(float) #0
declare double @llvm.trunc.f64(double) #0
declare float @llvm.fabs.f32(float) #0
declare double @llvm.fabs.f64(double) #0
declare half @llvm.minnum.f16(half, half) #0
declare float @llvm.minnum.f32(float, float) #0
declare double @llvm.minnum.f64(double, double) #0
declare <2 x half> @llvm.minnum.v2f16(<2 x half>, <2 x half>) #0
declare half @llvm.maxnum.f16(half, half) #0
declare float @llvm.maxnum.f32(float, float) #0
declare double @llvm.maxnum.f64(double, double) #0
declare <2 x half> @llvm.maxnum.v2f16(<2 x half>, <2 x half>) #0
declare float @llvm.fma.f32(float, float, float) #0
declare double @llvm.fma.f64(double, double, double) #0
@ -193,6 +199,14 @@ define double @abs_double(double %a) {
; ---- min ----
; CHECK-LABEL: min_half
define half @min_half(half %a, half %b) {
; CHECK-NOF16: min.f32
; CHECK-F16: min.f16
%x = call half @llvm.minnum.f16(half %a, half %b)
ret half %x
}
; CHECK-LABEL: min_float
define float @min_float(float %a, float %b) {
; CHECK: min.f32
@ -228,8 +242,25 @@ define double @min_double(double %a, double %b) {
ret double %x
}
; CHECK-LABEL: min_v2half
define <2 x half> @min_v2half(<2 x half> %a, <2 x half> %b) {
; CHECK-NOF16: min.f32
; CHECK-NOF16: min.f32
; CHECK-F16: min.f16x2
%x = call <2 x half> @llvm.minnum.v2f16(<2 x half> %a, <2 x half> %b)
ret <2 x half> %x
}
; ---- max ----
; CHECK-LABEL: max_half
define half @max_half(half %a, half %b) {
; CHECK-NOF16: max.f32
; CHECK-F16: max.f16
%x = call half @llvm.maxnum.f16(half %a, half %b)
ret half %x
}
; CHECK-LABEL: max_imm1
define float @max_imm1(float %a) {
; CHECK: max.f32
@ -265,6 +296,15 @@ define double @max_double(double %a, double %b) {
ret double %x
}
; CHECK-LABEL: max_v2half
define <2 x half> @max_v2half(<2 x half> %a, <2 x half> %b) {
; CHECK-NOF16: max.f32
; CHECK-NOF16: max.f32
; CHECK-F16: max.f16x2
%x = call <2 x half> @llvm.maxnum.v2f16(<2 x half> %a, <2 x half> %b)
ret <2 x half> %x
}
; ---- fma ----
; CHECK-LABEL: @fma_float