[MLIR][SPIRVToLLVM] Conversion of GLSL ops to LLVM intrinsics

This patch introduces new intrinsics in LLVM dialect:
-  `llvm.intr.floor`
-  `llvm.intr.maxnum`
-  `llvm.intr.minnum`
-  `llvm.intr.smax`
-  `llvm.intr.smin`
These intrinsics correspond to SPIR-V ops from GLSL
extended instruction set (`spv.GLSL.Floor`, `spv.GLSL.FMax`,
`spv.GLSL.FMin`,  `spv.GLSL.SMax` and `spv.GLSL.SMin`
respectively). Also conversion patterns for them were added.

Reviewed By: antiagainst

Differential Revision: https://reviews.llvm.org/D84661
This commit is contained in:
George Mitenkov 2020-07-30 10:53:06 +03:00
parent 0037a5f894
commit 1880532036
4 changed files with 120 additions and 0 deletions
mlir
include/mlir/Dialect/LLVMIR
lib/Conversion/SPIRVToLLVM
test

View File

@ -852,6 +852,7 @@ def LLVM_ExpOp : LLVM_UnaryIntrinsicOp<"exp">;
def LLVM_Exp2Op : LLVM_UnaryIntrinsicOp<"exp2">;
def LLVM_FAbsOp : LLVM_UnaryIntrinsicOp<"fabs">;
def LLVM_FCeilOp : LLVM_UnaryIntrinsicOp<"ceil">;
def LLVM_FFloorOp : LLVM_UnaryIntrinsicOp<"floor">;
def LLVM_FMAOp : LLVM_TernarySameArgsIntrinsicOp<"fma">;
def LLVM_FMulAddOp : LLVM_TernarySameArgsIntrinsicOp<"fmuladd">;
def LLVM_Log10Op : LLVM_UnaryIntrinsicOp<"log10">;
@ -865,6 +866,10 @@ def LLVM_SqrtOp : LLVM_UnaryIntrinsicOp<"sqrt">;
def LLVM_PowOp : LLVM_BinarySameArgsIntrinsicOp<"pow">;
def LLVM_BitReverseOp : LLVM_UnaryIntrinsicOp<"bitreverse">;
def LLVM_CtPopOp : LLVM_UnaryIntrinsicOp<"ctpop">;
def LLVM_MaxNumOp : LLVM_BinarySameArgsIntrinsicOp<"maxnum">;
def LLVM_MinNumOp : LLVM_BinarySameArgsIntrinsicOp<"minnum">;
def LLVM_SMaxOp : LLVM_BinarySameArgsIntrinsicOp<"smax">;
def LLVM_SMinOp : LLVM_BinarySameArgsIntrinsicOp<"smin">;
def LLVM_MemcpyOp : LLVM_ZeroResultIntrOp<"memcpy", [0, 1, 2]>,
Arguments<(ins LLVM_Type:$dst, LLVM_Type:$src,

View File

@ -1120,8 +1120,13 @@ void mlir::populateSPIRVToLLVMConversionPatterns(
DirectConversionPattern<spirv::GLSLCosOp, LLVM::CosOp>,
DirectConversionPattern<spirv::GLSLExpOp, LLVM::ExpOp>,
DirectConversionPattern<spirv::GLSLFAbsOp, LLVM::FAbsOp>,
DirectConversionPattern<spirv::GLSLFloorOp, LLVM::FFloorOp>,
DirectConversionPattern<spirv::GLSLFMaxOp, LLVM::MaxNumOp>,
DirectConversionPattern<spirv::GLSLFMinOp, LLVM::MinNumOp>,
DirectConversionPattern<spirv::GLSLLogOp, LLVM::LogOp>,
DirectConversionPattern<spirv::GLSLSinOp, LLVM::SinOp>,
DirectConversionPattern<spirv::GLSLSMaxOp, LLVM::SMaxOp>,
DirectConversionPattern<spirv::GLSLSMinOp, LLVM::SMinOp>,
DirectConversionPattern<spirv::GLSLSqrtOp, LLVM::SqrtOp>,
InverseSqrtPattern, TanPattern, TanhPattern,

View File

@ -52,6 +52,45 @@ func @fabs(%arg0: f32, %arg1: vector<3xf16>) {
return
}
//===----------------------------------------------------------------------===//
// spv.GLSL.Floor
//===----------------------------------------------------------------------===//
// CHECK-LABEL: @floor
func @floor(%arg0: f32, %arg1: vector<3xf16>) {
// CHECK: "llvm.intr.floor"(%{{.*}}) : (!llvm.float) -> !llvm.float
%0 = spv.GLSL.Floor %arg0 : f32
// CHECK: "llvm.intr.floor"(%{{.*}}) : (!llvm<"<3 x half>">) -> !llvm<"<3 x half>">
%1 = spv.GLSL.Floor %arg1 : vector<3xf16>
return
}
//===----------------------------------------------------------------------===//
// spv.GLSL.FMax
//===----------------------------------------------------------------------===//
// CHECK-LABEL: @fmax
func @fmax(%arg0: f32, %arg1: vector<3xf16>) {
// CHECK: "llvm.intr.maxnum"(%{{.*}}, %{{.*}}) : (!llvm.float, !llvm.float) -> !llvm.float
%0 = spv.GLSL.FMax %arg0, %arg0 : f32
// CHECK: "llvm.intr.maxnum"(%{{.*}}, %{{.*}}) : (!llvm<"<3 x half>">, !llvm<"<3 x half>">) -> !llvm<"<3 x half>">
%1 = spv.GLSL.FMax %arg1, %arg1 : vector<3xf16>
return
}
//===----------------------------------------------------------------------===//
// spv.GLSL.FMin
//===----------------------------------------------------------------------===//
// CHECK-LABEL: @fmin
func @fmin(%arg0: f32, %arg1: vector<3xf16>) {
// CHECK: "llvm.intr.minnum"(%{{.*}}, %{{.*}}) : (!llvm.float, !llvm.float) -> !llvm.float
%0 = spv.GLSL.FMin %arg0, %arg0 : f32
// CHECK: "llvm.intr.minnum"(%{{.*}}, %{{.*}}) : (!llvm<"<3 x half>">, !llvm<"<3 x half>">) -> !llvm<"<3 x half>">
%1 = spv.GLSL.FMin %arg1, %arg1 : vector<3xf16>
return
}
//===----------------------------------------------------------------------===//
// spv.GLSL.Log
//===----------------------------------------------------------------------===//
@ -78,6 +117,32 @@ func @sin(%arg0: f32, %arg1: vector<3xf16>) {
return
}
//===----------------------------------------------------------------------===//
// spv.GLSL.SMax
//===----------------------------------------------------------------------===//
// CHECK-LABEL: @smax
func @smax(%arg0: i16, %arg1: vector<3xi32>) {
// CHECK: "llvm.intr.smax"(%{{.*}}, %{{.*}}) : (!llvm.i16, !llvm.i16) -> !llvm.i16
%0 = spv.GLSL.SMax %arg0, %arg0 : i16
// CHECK: "llvm.intr.smax"(%{{.*}}, %{{.*}}) : (!llvm<"<3 x i32>">, !llvm<"<3 x i32>">) -> !llvm<"<3 x i32>">
%1 = spv.GLSL.SMax %arg1, %arg1 : vector<3xi32>
return
}
//===----------------------------------------------------------------------===//
// spv.GLSL.SMin
//===----------------------------------------------------------------------===//
// CHECK-LABEL: @smin
func @smin(%arg0: i16, %arg1: vector<3xi32>) {
// CHECK: "llvm.intr.smin"(%{{.*}}, %{{.*}}) : (!llvm.i16, !llvm.i16) -> !llvm.i16
%0 = spv.GLSL.SMin %arg0, %arg0 : i16
// CHECK: "llvm.intr.smin"(%{{.*}}, %{{.*}}) : (!llvm<"<3 x i32>">, !llvm<"<3 x i32>">) -> !llvm<"<3 x i32>">
%1 = spv.GLSL.SMin %arg1, %arg1 : vector<3xi32>
return
}
//===----------------------------------------------------------------------===//
// spv.GLSL.Sqrt
//===----------------------------------------------------------------------===//

View File

@ -90,6 +90,15 @@ llvm.func @ceil_test(%arg0: !llvm.float, %arg1: !llvm<"<8 x float>">) {
llvm.return
}
// CHECK-LABEL: @floor_test
llvm.func @floor_test(%arg0: !llvm.float, %arg1: !llvm<"<8 x float>">) {
// CHECK: call float @llvm.floor.f32
"llvm.intr.floor"(%arg0) : (!llvm.float) -> !llvm.float
// CHECK: call <8 x float> @llvm.floor.v8f32
"llvm.intr.floor"(%arg1) : (!llvm<"<8 x float>">) -> !llvm<"<8 x float>">
llvm.return
}
// CHECK-LABEL: @cos_test
llvm.func @cos_test(%arg0: !llvm.float, %arg1: !llvm<"<8 x float>">) {
// CHECK: call float @llvm.cos.f32
@ -135,6 +144,42 @@ llvm.func @ctpop_test(%arg0: !llvm.i32, %arg1: !llvm<"<8 x i32>">) {
llvm.return
}
// CHECK-LABEL: @maxnum_test
llvm.func @maxnum_test(%arg0: !llvm.float, %arg1: !llvm.float, %arg2: !llvm<"<8 x float>">, %arg3: !llvm<"<8 x float>">) {
// CHECK: call float @llvm.maxnum.f32
"llvm.intr.maxnum"(%arg0, %arg1) : (!llvm.float, !llvm.float) -> !llvm.float
// CHECK: call <8 x float> @llvm.maxnum.v8f32
"llvm.intr.maxnum"(%arg2, %arg3) : (!llvm<"<8 x float>">, !llvm<"<8 x float>">) -> !llvm<"<8 x float>">
llvm.return
}
// CHECK-LABEL: @minnum_test
llvm.func @minnum_test(%arg0: !llvm.float, %arg1: !llvm.float, %arg2: !llvm<"<8 x float>">, %arg3: !llvm<"<8 x float>">) {
// CHECK: call float @llvm.minnum.f32
"llvm.intr.minnum"(%arg0, %arg1) : (!llvm.float, !llvm.float) -> !llvm.float
// CHECK: call <8 x float> @llvm.minnum.v8f32
"llvm.intr.minnum"(%arg2, %arg3) : (!llvm<"<8 x float>">, !llvm<"<8 x float>">) -> !llvm<"<8 x float>">
llvm.return
}
// CHECK-LABEL: @smax_test
llvm.func @smax_test(%arg0: !llvm.i32, %arg1: !llvm.i32, %arg2: !llvm<"<8 x i32>">, %arg3: !llvm<"<8 x i32>">) {
// CHECK: call i32 @llvm.smax.i32
"llvm.intr.smax"(%arg0, %arg1) : (!llvm.i32, !llvm.i32) -> !llvm.i32
// CHECK: call <8 x i32> @llvm.smax.v8i32
"llvm.intr.smax"(%arg2, %arg3) : (!llvm<"<8 x i32>">, !llvm<"<8 x i32>">) -> !llvm<"<8 x i32>">
llvm.return
}
// CHECK-LABEL: @smin_test
llvm.func @smin_test(%arg0: !llvm.i32, %arg1: !llvm.i32, %arg2: !llvm<"<8 x i32>">, %arg3: !llvm<"<8 x i32>">) {
// CHECK: call i32 @llvm.smin.i32
"llvm.intr.smin"(%arg0, %arg1) : (!llvm.i32, !llvm.i32) -> !llvm.i32
// CHECK: call <8 x i32> @llvm.smin.v8i32
"llvm.intr.smin"(%arg2, %arg3) : (!llvm<"<8 x i32>">, !llvm<"<8 x i32>">) -> !llvm<"<8 x i32>">
llvm.return
}
// CHECK-LABEL: @vector_reductions
llvm.func @vector_reductions(%arg0: !llvm.float, %arg1: !llvm<"<8 x float>">, %arg2: !llvm<"<8 x i32>">) {
// CHECK: call i32 @llvm.experimental.vector.reduce.add.v8i32