[Clang][NVPTX]Add NVPTX intrinsics and builtins for CUDA PTX cvt sm80 instructions

Adds NVPTX intrinsics and builtins for CUDA PTX cvt instructions for sm80
architectures and above. Requires ptx 7.0.

PTX ISA description of cvt instructions :
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt

Signed-off-by: JackAKirk <jack.kirk@codeplay.com>

Differential Revision: https://reviews.llvm.org/D116673
This commit is contained in:
Jack Kirk 2022-01-13 12:01:20 -08:00 committed by Artem Belevich
parent 07f9fb8b51
commit bef3eb8344
8 changed files with 290 additions and 2 deletions

View File

@ -402,6 +402,23 @@ BUILTIN(__nvvm_ull2d_rp, "dULLi", "")
BUILTIN(__nvvm_f2h_rn_ftz, "Usf", "")
BUILTIN(__nvvm_f2h_rn, "Usf", "")
TARGET_BUILTIN(__nvvm_ff2bf16x2_rn, "ZUiff", "", AND(SM_80,PTX70))
TARGET_BUILTIN(__nvvm_ff2bf16x2_rn_relu, "ZUiff", "", AND(SM_80,PTX70))
TARGET_BUILTIN(__nvvm_ff2bf16x2_rz, "ZUiff", "", AND(SM_80,PTX70))
TARGET_BUILTIN(__nvvm_ff2bf16x2_rz_relu, "ZUiff", "", AND(SM_80,PTX70))
TARGET_BUILTIN(__nvvm_ff2f16x2_rn, "V2hff", "", AND(SM_80,PTX70))
TARGET_BUILTIN(__nvvm_ff2f16x2_rn_relu, "V2hff", "", AND(SM_80,PTX70))
TARGET_BUILTIN(__nvvm_ff2f16x2_rz, "V2hff", "", AND(SM_80,PTX70))
TARGET_BUILTIN(__nvvm_ff2f16x2_rz_relu, "V2hff", "", AND(SM_80,PTX70))
TARGET_BUILTIN(__nvvm_f2bf16_rn, "ZUsf", "", AND(SM_80,PTX70))
TARGET_BUILTIN(__nvvm_f2bf16_rn_relu, "ZUsf", "", AND(SM_80,PTX70))
TARGET_BUILTIN(__nvvm_f2bf16_rz, "ZUsf", "", AND(SM_80,PTX70))
TARGET_BUILTIN(__nvvm_f2bf16_rz_relu, "ZUsf", "", AND(SM_80,PTX70))
TARGET_BUILTIN(__nvvm_f2tf32_rna, "ZUif", "", AND(SM_80,PTX70))
// Bitcast
BUILTIN(__nvvm_bitcast_f2i, "if", "")

View File

@ -754,4 +754,40 @@ __device__ void nvvm_async_copy(__attribute__((address_space(3))) void* dst, __a
__nvvm_cp_async_wait_all();
#endif
// CHECK: ret void
}
}
// CHECK-LABEL: nvvm_cvt_sm80
__device__ void nvvm_cvt_sm80() {
#if __CUDA_ARCH__ >= 800
// CHECK_PTX70_SM80: call i32 @llvm.nvvm.ff2bf16x2.rn(float 1.000000e+00, float 1.000000e+00)
__nvvm_ff2bf16x2_rn(1, 1);
// CHECK_PTX70_SM80: call i32 @llvm.nvvm.ff2bf16x2.rn.relu(float 1.000000e+00, float 1.000000e+00)
__nvvm_ff2bf16x2_rn_relu(1, 1);
// CHECK_PTX70_SM80: call i32 @llvm.nvvm.ff2bf16x2.rz(float 1.000000e+00, float 1.000000e+00)
__nvvm_ff2bf16x2_rz(1, 1);
// CHECK_PTX70_SM80: call i32 @llvm.nvvm.ff2bf16x2.rz.relu(float 1.000000e+00, float 1.000000e+00)
__nvvm_ff2bf16x2_rz_relu(1, 1);
// CHECK_PTX70_SM80: call <2 x half> @llvm.nvvm.ff2f16x2.rn(float 1.000000e+00, float 1.000000e+00)
__nvvm_ff2f16x2_rn(1, 1);
// CHECK_PTX70_SM80: call <2 x half> @llvm.nvvm.ff2f16x2.rn.relu(float 1.000000e+00, float 1.000000e+00)
__nvvm_ff2f16x2_rn_relu(1, 1);
// CHECK_PTX70_SM80: call <2 x half> @llvm.nvvm.ff2f16x2.rz(float 1.000000e+00, float 1.000000e+00)
__nvvm_ff2f16x2_rz(1, 1);
// CHECK_PTX70_SM80: call <2 x half> @llvm.nvvm.ff2f16x2.rz.relu(float 1.000000e+00, float 1.000000e+00)
__nvvm_ff2f16x2_rz_relu(1, 1);
// CHECK_PTX70_SM80: call i16 @llvm.nvvm.f2bf16.rn(float 1.000000e+00)
__nvvm_f2bf16_rn(1);
// CHECK_PTX70_SM80: call i16 @llvm.nvvm.f2bf16.rn.relu(float 1.000000e+00)
__nvvm_f2bf16_rn_relu(1);
// CHECK_PTX70_SM80: call i16 @llvm.nvvm.f2bf16.rz(float 1.000000e+00)
__nvvm_f2bf16_rz(1);
// CHECK_PTX70_SM80: call i16 @llvm.nvvm.f2bf16.rz.relu(float 1.000000e+00)
__nvvm_f2bf16_rz_relu(1);
// CHECK_PTX70_SM80: call i32 @llvm.nvvm.f2tf32.rna(float 1.000000e+00)
__nvvm_f2tf32_rna(1);
#endif
// CHECK: ret void
}

View File

@ -1185,6 +1185,36 @@ let TargetPrefix = "nvvm" in {
def int_nvvm_f2h_rn : GCCBuiltin<"__nvvm_f2h_rn">,
DefaultAttrsIntrinsic<[llvm_i16_ty], [llvm_float_ty], [IntrNoMem, IntrSpeculatable]>;
def int_nvvm_ff2bf16x2_rn : GCCBuiltin<"__nvvm_ff2bf16x2_rn">,
Intrinsic<[llvm_i32_ty], [llvm_float_ty, llvm_float_ty], [IntrNoMem]>;
def int_nvvm_ff2bf16x2_rn_relu : GCCBuiltin<"__nvvm_ff2bf16x2_rn_relu">,
Intrinsic<[llvm_i32_ty], [llvm_float_ty, llvm_float_ty], [IntrNoMem]>;
def int_nvvm_ff2bf16x2_rz : GCCBuiltin<"__nvvm_ff2bf16x2_rz">,
Intrinsic<[llvm_i32_ty], [llvm_float_ty, llvm_float_ty], [IntrNoMem]>;
def int_nvvm_ff2bf16x2_rz_relu : GCCBuiltin<"__nvvm_ff2bf16x2_rz_relu">,
Intrinsic<[llvm_i32_ty], [llvm_float_ty, llvm_float_ty], [IntrNoMem]>;
def int_nvvm_ff2f16x2_rn : GCCBuiltin<"__nvvm_ff2f16x2_rn">,
Intrinsic<[llvm_v2f16_ty], [llvm_float_ty, llvm_float_ty], [IntrNoMem]>;
def int_nvvm_ff2f16x2_rn_relu : GCCBuiltin<"__nvvm_ff2f16x2_rn_relu">,
Intrinsic<[llvm_v2f16_ty], [llvm_float_ty, llvm_float_ty], [IntrNoMem]>;
def int_nvvm_ff2f16x2_rz : GCCBuiltin<"__nvvm_ff2f16x2_rz">,
Intrinsic<[llvm_v2f16_ty], [llvm_float_ty, llvm_float_ty], [IntrNoMem]>;
def int_nvvm_ff2f16x2_rz_relu : GCCBuiltin<"__nvvm_ff2f16x2_rz_relu">,
Intrinsic<[llvm_v2f16_ty], [llvm_float_ty, llvm_float_ty], [IntrNoMem]>;
def int_nvvm_f2bf16_rn : GCCBuiltin<"__nvvm_f2bf16_rn">,
Intrinsic<[llvm_i16_ty], [llvm_float_ty], [IntrNoMem]>;
def int_nvvm_f2bf16_rn_relu : GCCBuiltin<"__nvvm_f2bf16_rn_relu">,
Intrinsic<[llvm_i16_ty], [llvm_float_ty], [IntrNoMem]>;
def int_nvvm_f2bf16_rz : GCCBuiltin<"__nvvm_f2bf16_rz">,
Intrinsic<[llvm_i16_ty], [llvm_float_ty], [IntrNoMem]>;
def int_nvvm_f2bf16_rz_relu : GCCBuiltin<"__nvvm_f2bf16_rz_relu">,
Intrinsic<[llvm_i16_ty], [llvm_float_ty], [IntrNoMem]>;
def int_nvvm_f2tf32_rna : GCCBuiltin<"__nvvm_f2tf32_rna">,
Intrinsic<[llvm_i32_ty], [llvm_float_ty], [IntrNoMem]>;
//
// Bitcast
//

View File

@ -108,6 +108,10 @@ void NVPTXInstPrinter::printCvtMode(const MCInst *MI, int OpNum, raw_ostream &O,
// SAT flag
if (Imm & NVPTX::PTXCvtMode::SAT_FLAG)
O << ".sat";
} else if (strcmp(Modifier, "relu") == 0) {
// RELU flag
if (Imm & NVPTX::PTXCvtMode::RELU_FLAG)
O << ".relu";
} else if (strcmp(Modifier, "base") == 0) {
// Default operand
switch (Imm & NVPTX::PTXCvtMode::BASE_MASK) {
@ -139,6 +143,9 @@ void NVPTXInstPrinter::printCvtMode(const MCInst *MI, int OpNum, raw_ostream &O,
case NVPTX::PTXCvtMode::RP:
O << ".rp";
break;
case NVPTX::PTXCvtMode::RNA:
O << ".rna";
break;
}
} else {
llvm_unreachable("Invalid conversion modifier");

View File

@ -137,10 +137,12 @@ enum CvtMode {
RZ,
RM,
RP,
RNA,
BASE_MASK = 0x0F,
FTZ_FLAG = 0x10,
SAT_FLAG = 0x20
SAT_FLAG = 0x20,
RELU_FLAG = 0x40
};
}

View File

@ -48,6 +48,7 @@ def CvtRN : PatLeaf<(i32 0x5)>;
def CvtRZ : PatLeaf<(i32 0x6)>;
def CvtRM : PatLeaf<(i32 0x7)>;
def CvtRP : PatLeaf<(i32 0x8)>;
def CvtRNA : PatLeaf<(i32 0x9)>;
def CvtNONE_FTZ : PatLeaf<(i32 0x10)>;
def CvtRNI_FTZ : PatLeaf<(i32 0x11)>;
@ -62,6 +63,10 @@ def CvtRP_FTZ : PatLeaf<(i32 0x18)>;
def CvtSAT : PatLeaf<(i32 0x20)>;
def CvtSAT_FTZ : PatLeaf<(i32 0x30)>;
def CvtNONE_RELU : PatLeaf<(i32 0x40)>;
def CvtRN_RELU : PatLeaf<(i32 0x45)>;
def CvtRZ_RELU : PatLeaf<(i32 0x46)>;
def CvtMode : Operand<i32> {
let PrintMethod = "printCvtMode";
}
@ -526,6 +531,29 @@ let hasSideEffects = false in {
"cvt.s64.s16 \t$dst, $src;", []>;
def CVT_INREG_s64_s32 : NVPTXInst<(outs Int64Regs:$dst), (ins Int64Regs:$src),
"cvt.s64.s32 \t$dst, $src;", []>;
multiclass CVT_FROM_FLOAT_SM80<string FromName, RegisterClass RC> {
def _f32 :
NVPTXInst<(outs RC:$dst),
(ins Float32Regs:$src, CvtMode:$mode),
!strconcat("cvt${mode:base}${mode:relu}.",
FromName, ".f32 \t$dst, $src;"), []>,
Requires<[hasPTX70, hasSM80]>;
}
defm CVT_bf16 : CVT_FROM_FLOAT_SM80<"bf16", Int16Regs>;
multiclass CVT_FROM_FLOAT_V2_SM80<string FromName, RegisterClass RC> {
def _f32 :
NVPTXInst<(outs RC:$dst),
(ins Float32Regs:$src1, Float32Regs:$src2, CvtMode:$mode),
!strconcat("cvt${mode:base}${mode:relu}.",
FromName, ".f32 \t$dst, $src1, $src2;"), []>,
Requires<[hasPTX70, hasSM80]>;
}
defm CVT_f16x2 : CVT_FROM_FLOAT_V2_SM80<"f16x2", Float16x2Regs>;
defm CVT_bf16x2 : CVT_FROM_FLOAT_V2_SM80<"bf16x2", Int32Regs>;
}
//-----------------------------------

View File

@ -1046,6 +1046,38 @@ def : Pat<(int_nvvm_ui2f_rm Int32Regs:$a),
def : Pat<(int_nvvm_ui2f_rp Int32Regs:$a),
(CVT_f32_u32 Int32Regs:$a, CvtRP)>;
def : Pat<(int_nvvm_ff2bf16x2_rn Float32Regs:$a, Float32Regs:$b),
(CVT_bf16x2_f32 Float32Regs:$a, Float32Regs:$b, CvtRN)>;
def : Pat<(int_nvvm_ff2bf16x2_rn_relu Float32Regs:$a, Float32Regs:$b),
(CVT_bf16x2_f32 Float32Regs:$a, Float32Regs:$b, CvtRN_RELU)>;
def : Pat<(int_nvvm_ff2bf16x2_rz Float32Regs:$a, Float32Regs:$b),
(CVT_bf16x2_f32 Float32Regs:$a, Float32Regs:$b, CvtRZ)>;
def : Pat<(int_nvvm_ff2bf16x2_rz_relu Float32Regs:$a, Float32Regs:$b),
(CVT_bf16x2_f32 Float32Regs:$a, Float32Regs:$b, CvtRZ_RELU)>;
def : Pat<(int_nvvm_ff2f16x2_rn Float32Regs:$a, Float32Regs:$b),
(CVT_f16x2_f32 Float32Regs:$a, Float32Regs:$b, CvtRN)>;
def : Pat<(int_nvvm_ff2f16x2_rn_relu Float32Regs:$a, Float32Regs:$b),
(CVT_f16x2_f32 Float32Regs:$a, Float32Regs:$b, CvtRN_RELU)>;
def : Pat<(int_nvvm_ff2f16x2_rz Float32Regs:$a, Float32Regs:$b),
(CVT_f16x2_f32 Float32Regs:$a, Float32Regs:$b, CvtRZ)>;
def : Pat<(int_nvvm_ff2f16x2_rz_relu Float32Regs:$a, Float32Regs:$b),
(CVT_f16x2_f32 Float32Regs:$a, Float32Regs:$b, CvtRZ_RELU)>;
def : Pat<(int_nvvm_f2bf16_rn Float32Regs:$a),
(CVT_bf16_f32 Float32Regs:$a, CvtRN)>;
def : Pat<(int_nvvm_f2bf16_rn_relu Float32Regs:$a),
(CVT_bf16_f32 Float32Regs:$a, CvtRN_RELU)>;
def : Pat<(int_nvvm_f2bf16_rz Float32Regs:$a),
(CVT_bf16_f32 Float32Regs:$a, CvtRZ)>;
def : Pat<(int_nvvm_f2bf16_rz_relu Float32Regs:$a),
(CVT_bf16_f32 Float32Regs:$a, CvtRZ_RELU)>;
def CVT_tf32_f32 :
NVPTXInst<(outs Int32Regs:$dest), (ins Float32Regs:$a),
"cvt.rna.tf32.f32 \t$dest, $a;",
[(set Int32Regs:$dest, (int_nvvm_f2tf32_rna Float32Regs:$a))]>;
def INT_NVVM_LOHI_I2D : F_MATH_2<"mov.b64 \t$dst, {{$src0, $src1}};",
Float64Regs, Int32Regs, Int32Regs, int_nvvm_lohi_i2d>;

View File

@ -0,0 +1,136 @@
; RUN: llc < %s -march=nvptx64 -mcpu=sm_80 -mattr=+ptx70 | FileCheck %s
; CHECK-LABEL: cvt_rn_bf16x2_f32
define i32 @cvt_rn_bf16x2_f32(float %f1, float %f2) {
; CHECK: cvt.rn.bf16x2.f32
%val = call i32 @llvm.nvvm.ff2bf16x2.rn(float %f1, float %f2);
ret i32 %val
}
; CHECK-LABEL: cvt_rn_relu_bf16x2_f32
define i32 @cvt_rn_relu_bf16x2_f32(float %f1, float %f2) {
; CHECK: cvt.rn.relu.bf16x2.f32
%val = call i32 @llvm.nvvm.ff2bf16x2.rn.relu(float %f1, float %f2);
ret i32 %val
}
; CHECK-LABEL: cvt_rz_bf16x2_f32
define i32 @cvt_rz_bf16x2_f32(float %f1, float %f2) {
; CHECK: cvt.rz.bf16x2.f32
%val = call i32 @llvm.nvvm.ff2bf16x2.rz(float %f1, float %f2);
ret i32 %val
}
; CHECK-LABEL: cvt_rz_relu_bf16x2_f32
define i32 @cvt_rz_relu_bf16x2_f32(float %f1, float %f2) {
; CHECK: cvt.rz.relu.bf16x2.f32
%val = call i32 @llvm.nvvm.ff2bf16x2.rz.relu(float %f1, float %f2);
ret i32 %val
}
declare i32 @llvm.nvvm.ff2bf16x2.rn(float, float)
declare i32 @llvm.nvvm.ff2bf16x2.rn.relu(float, float)
declare i32 @llvm.nvvm.ff2bf16x2.rz(float, float)
declare i32 @llvm.nvvm.ff2bf16x2.rz.relu(float, float)
; CHECK-LABEL: cvt_rn_f16x2_f32
define <2 x half> @cvt_rn_f16x2_f32(float %f1, float %f2) {
; CHECK: cvt.rn.f16x2.f32
%val = call <2 x half> @llvm.nvvm.ff2f16x2.rn(float %f1, float %f2);
ret <2 x half> %val
}
; CHECK-LABEL: cvt_rn_relu_f16x2_f32
define <2 x half> @cvt_rn_relu_f16x2_f32(float %f1, float %f2) {
; CHECK: cvt.rn.relu.f16x2.f32
%val = call <2 x half> @llvm.nvvm.ff2f16x2.rn.relu(float %f1, float %f2);
ret <2 x half> %val
}
; CHECK-LABEL: cvt_rz_f16x2_f32
define <2 x half> @cvt_rz_f16x2_f32(float %f1, float %f2) {
; CHECK: cvt.rz.f16x2.f32
%val = call <2 x half> @llvm.nvvm.ff2f16x2.rz(float %f1, float %f2);
ret <2 x half> %val
}
; CHECK-LABEL: cvt_rz_relu_f16x2_f32
define <2 x half> @cvt_rz_relu_f16x2_f32(float %f1, float %f2) {
; CHECK: cvt.rz.relu.f16x2.f32
%val = call <2 x half> @llvm.nvvm.ff2f16x2.rz.relu(float %f1, float %f2);
ret <2 x half> %val
}
declare <2 x half> @llvm.nvvm.ff2f16x2.rn(float, float)
declare <2 x half> @llvm.nvvm.ff2f16x2.rn.relu(float, float)
declare <2 x half> @llvm.nvvm.ff2f16x2.rz(float, float)
declare <2 x half> @llvm.nvvm.ff2f16x2.rz.relu(float, float)
; CHECK-LABEL: cvt_rn_bf16_f32
define i16 @cvt_rn_bf16_f32(float %f1) {
; CHECK: cvt.rn.bf16.f32
%val = call i16 @llvm.nvvm.f2bf16.rn(float %f1);
ret i16 %val
}
; CHECK-LABEL: cvt_rn_relu_bf16_f32
define i16 @cvt_rn_relu_bf16_f32(float %f1) {
; CHECK: cvt.rn.relu.bf16.f32
%val = call i16 @llvm.nvvm.f2bf16.rn.relu(float %f1);
ret i16 %val
}
; CHECK-LABEL: cvt_rz_bf16_f32
define i16 @cvt_rz_bf16_f32(float %f1) {
; CHECK: cvt.rz.bf16.f32
%val = call i16 @llvm.nvvm.f2bf16.rz(float %f1);
ret i16 %val
}
; CHECK-LABEL: cvt_rz_relu_bf16_f32
define i16 @cvt_rz_relu_bf16_f32(float %f1) {
; CHECK: cvt.rz.relu.bf16.f32
%val = call i16 @llvm.nvvm.f2bf16.rz.relu(float %f1);
ret i16 %val
}
declare i16 @llvm.nvvm.f2bf16.rn(float)
declare i16 @llvm.nvvm.f2bf16.rn.relu(float)
declare i16 @llvm.nvvm.f2bf16.rz(float)
declare i16 @llvm.nvvm.f2bf16.rz.relu(float)
; CHECK-LABEL: cvt_rna_tf32_f32
define i32 @cvt_rna_tf32_f32(float %f1) {
; CHECK: cvt.rna.tf32.f32
%val = call i32 @llvm.nvvm.f2tf32.rna(float %f1);
ret i32 %val
}
declare i32 @llvm.nvvm.f2tf32.rna(float)