[mlir][spirv] Add math to OpenCL conversion

Differential Revision: https://reviews.llvm.org/D113780
This commit is contained in:
Butygin 2021-10-28 19:04:35 +03:00
parent 26d1edfb10
commit 75a1bee05d
8 changed files with 278 additions and 10 deletions

View File

@ -22,7 +22,15 @@ include "mlir/Interfaces/SideEffectInterfaces.td"
// Base class for all GLSL ops.
class SPV_GLSLOp<string mnemonic, int opcode, list<OpTrait> traits = []> :
SPV_ExtInstOp<mnemonic, "GLSL", "GLSL.std.450", opcode, traits>;
SPV_ExtInstOp<mnemonic, "GLSL", "GLSL.std.450", opcode, traits> {
let availability = [
MinVersion<SPV_V_1_0>,
MaxVersion<SPV_V_1_5>,
Extension<[]>,
Capability<[SPV_C_Shader]>
];
}
// Base class for GLSL unary ops.
class SPV_GLSLUnaryOp<string mnemonic, int opcode, Type resultType,

View File

@ -21,7 +21,15 @@ include "mlir/Dialect/SPIRV/IR/SPIRVBase.td"
// Base class for all OpenCL ops.
class SPV_OCLOp<string mnemonic, int opcode, list<OpTrait> traits = []> :
SPV_ExtInstOp<mnemonic, "OCL", "OpenCL.std", opcode, traits>;
SPV_ExtInstOp<mnemonic, "OCL", "OpenCL.std", opcode, traits> {
let availability = [
MinVersion<SPV_V_1_0>,
MaxVersion<SPV_V_1_5>,
Extension<[]>,
Capability<[SPV_C_Kernel]>
];
}
// Base class for OpenCL unary ops.
class SPV_OCLUnaryOp<string mnemonic, int opcode, Type resultType,
@ -78,6 +86,69 @@ class SPV_OCLBinaryArithmeticOp<string mnemonic, int opcode, Type type,
// -----
def SPV_OCLTanhOp : SPV_OCLUnaryArithmeticOp<"tanh", 63, SPV_Float> {
let summary = "Compute hyperbolic tangent of x radians.";
let description = [{
Result Type and x must be floating-point or vector(2,3,4,8,16) of
floating-point values.
All of the operands, including the Result Type operand, must be of the
same type.
<!-- End of AutoGen section -->
```
float-scalar-vector-type ::= float-type |
`vector<` integer-literal `x` float-type `>`
tanh-op ::= ssa-id `=` `spv.OCL.tanh` ssa-use `:`
float-scalar-vector-type
```mlir
#### Example:
```
%2 = spv.OCL.tanh %0 : f32
%3 = spv.OCL.tanh %1 : vector<3xf16>
```
}];
}
// -----
def SPV_OCLCeilOp : SPV_OCLUnaryArithmeticOp<"ceil", 12, SPV_Float> {
let summary = [{
Round x to integral value using the round to positive infinity rounding
mode.
}];
let description = [{
Result Type and x must be floating-point or vector(2,3,4,8,16) of
floating-point values.
All of the operands, including the Result Type operand, must be of the
same type.
<!-- End of AutoGen section -->
```
float-scalar-vector-type ::= float-type |
`vector<` integer-literal `x` float-type `>`
ceil-op ::= ssa-id `=` `spv.OCL.ceil` ssa-use `:`
float-scalar-vector-type
```mlir
#### Example:
```
%2 = spv.OCL.ceil %0 : f32
%3 = spv.OCL.ceil %1 : vector<3xf16>
```
}];
}
// -----
def SPV_OCLCosOp : SPV_OCLUnaryArithmeticOp<"cos", 14, SPV_Float> {
let summary = "Compute the cosine of x radians.";
@ -93,7 +164,7 @@ def SPV_OCLCosOp : SPV_OCLUnaryArithmeticOp<"cos", 14, SPV_Float> {
```
float-scalar-vector-type ::= float-type |
`vector<` integer-literal `x` float-type `>`
abs-op ::= ssa-id `=` `spv.OCL.cos` ssa-use `:`
cos-op ::= ssa-id `=` `spv.OCL.cos` ssa-use `:`
float-scalar-vector-type
```mlir
@ -168,6 +239,39 @@ def SPV_OCLFAbsOp : SPV_OCLUnaryArithmeticOp<"fabs", 23, SPV_Float> {
// -----
def SPV_OCLFloorOp : SPV_OCLUnaryArithmeticOp<"floor", 25, SPV_Float> {
let summary = [{
Round x to the integral value using the round to negative infinity
rounding mode.
}];
let description = [{
Result Type and x must be floating-point or vector(2,3,4,8,16) of
floating-point values.
All of the operands, including the Result Type operand, must be of the
same type.
<!-- End of AutoGen section -->
```
float-scalar-vector-type ::= float-type |
`vector<` integer-literal `x` float-type `>`
floor-op ::= ssa-id `=` `spv.OCL.floor` ssa-use `:`
float-scalar-vector-type
```mlir
#### Example:
```
%2 = spv.OCL.floor %0 : f32
%3 = spv.OCL.ceifloorl %1 : vector<3xf16>
```
}];
}
// -----
def SPV_OCLLogOp : SPV_OCLUnaryArithmeticOp<"log", 37, SPV_Float> {
let summary = "Compute the natural logarithm of x.";
@ -183,7 +287,7 @@ def SPV_OCLLogOp : SPV_OCLUnaryArithmeticOp<"log", 37, SPV_Float> {
```
float-scalar-vector-type ::= float-type |
`vector<` integer-literal `x` float-type `>`
abs-op ::= ssa-id `=` `spv.OCL.log` ssa-use `:`
log-op ::= ssa-id `=` `spv.OCL.log` ssa-use `:`
float-scalar-vector-type
```mlir
@ -198,6 +302,67 @@ def SPV_OCLLogOp : SPV_OCLUnaryArithmeticOp<"log", 37, SPV_Float> {
// -----
def SPV_OCLPowOp : SPV_OCLBinaryArithmeticOp<"pow", 48, SPV_Float> {
let summary = "Compute x to the power y.";
let description = [{
Result Type, x and y must be floating-point or vector(2,3,4,8,16) of
floating-point values.
All of the operands, including the Result Type operand, must be of the
same type.
<!-- End of AutoGen section -->
```
restricted-float-scalar-type ::= `f16` | `f32`
restricted-float-scalar-vector-type ::=
restricted-float-scalar-type |
`vector<` integer-literal `x` restricted-float-scalar-type `>`
pow-op ::= ssa-id `=` `spv.OCL.pow` ssa-use `:`
restricted-float-scalar-vector-type
```
#### Example:
```mlir
%2 = spv.OCL.pow %0, %1 : f32
%3 = spv.OCL.pow %0, %1 : vector<3xf16>
```
}];
}
// -----
def SPV_OCLRsqrtOp : SPV_OCLUnaryArithmeticOp<"rsqrt", 56, SPV_Float> {
let summary = "Compute inverse square root of x.";
let description = [{
Result Type and x must be floating-point or vector(2,3,4,8,16) of
floating-point values.
All of the operands, including the Result Type operand, must be of the
same type.
<!-- End of AutoGen section -->
```
float-scalar-vector-type ::= float-type |
`vector<` integer-literal `x` float-type `>`
rsqrt-op ::= ssa-id `=` `spv.OCL.rsqrt` ssa-use `:`
float-scalar-vector-type
```mlir
#### Example:
```
%2 = spv.OCL.rsqrt %0 : f32
%3 = spv.OCL.rsqrt %1 : vector<3xf16>
```
}];
}
// -----
def SPV_OCLSinOp : SPV_OCLUnaryArithmeticOp<"sin", 57, SPV_Float> {
let summary = "Compute sine of x radians.";
@ -213,7 +378,7 @@ def SPV_OCLSinOp : SPV_OCLUnaryArithmeticOp<"sin", 57, SPV_Float> {
```
float-scalar-vector-type ::= float-type |
`vector<` integer-literal `x` float-type `>`
abs-op ::= ssa-id `=` `spv.OCL.sin` ssa-use `:`
sin-op ::= ssa-id `=` `spv.OCL.sin` ssa-use `:`
float-scalar-vector-type
```mlir
@ -243,7 +408,7 @@ def SPV_OCLSqrtOp : SPV_OCLUnaryArithmeticOp<"sqrt", 61, SPV_Float> {
```
float-scalar-vector-type ::= float-type |
`vector<` integer-literal `x` float-type `>`
abs-op ::= ssa-id `=` `spv.OCL.sqrt` ssa-use `:`
sqrt-op ::= ssa-id `=` `spv.OCL.sqrt` ssa-use `:`
float-scalar-vector-type
```mlir

View File

@ -34,6 +34,7 @@ namespace {
///
/// SPIR-V does not have a direct operations for log(1+x). Explicitly lower to
/// these operations.
template <typename LogOp>
class Log1pOpPattern final : public OpConversionPattern<math::Log1pOp> {
public:
using OpConversionPattern<math::Log1pOp>::OpConversionPattern;
@ -48,7 +49,7 @@ public:
auto one = spirv::ConstantOp::getOne(type, operation.getLoc(), rewriter);
auto onePlus =
rewriter.create<spirv::FAddOp>(loc, one, adaptor.getOperands()[0]);
rewriter.replaceOpWithNewOp<spirv::GLSLLogOp>(operation, type, onePlus);
rewriter.replaceOpWithNewOp<LogOp>(operation, type, onePlus);
return success();
}
};
@ -61,8 +62,10 @@ public:
namespace mlir {
void populateMathToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
RewritePatternSet &patterns) {
// GLSL patterns
patterns.add<
Log1pOpPattern,
Log1pOpPattern<spirv::GLSLLogOp>,
spirv::UnaryAndBinaryOpPattern<math::AbsOp, spirv::GLSLFAbsOp>,
spirv::UnaryAndBinaryOpPattern<math::CeilOp, spirv::GLSLCeilOp>,
spirv::UnaryAndBinaryOpPattern<math::CosOp, spirv::GLSLCosOp>,
@ -75,6 +78,21 @@ void populateMathToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
spirv::UnaryAndBinaryOpPattern<math::SqrtOp, spirv::GLSLSqrtOp>,
spirv::UnaryAndBinaryOpPattern<math::TanhOp, spirv::GLSLTanhOp>>(
typeConverter, patterns.getContext());
// OpenCL patterns
patterns.add<Log1pOpPattern<spirv::OCLLogOp>,
spirv::UnaryAndBinaryOpPattern<math::AbsOp, spirv::OCLFAbsOp>,
spirv::UnaryAndBinaryOpPattern<math::CeilOp, spirv::OCLCeilOp>,
spirv::UnaryAndBinaryOpPattern<math::CosOp, spirv::OCLCosOp>,
spirv::UnaryAndBinaryOpPattern<math::ExpOp, spirv::OCLExpOp>,
spirv::UnaryAndBinaryOpPattern<math::FloorOp, spirv::OCLFloorOp>,
spirv::UnaryAndBinaryOpPattern<math::LogOp, spirv::OCLLogOp>,
spirv::UnaryAndBinaryOpPattern<math::PowFOp, spirv::OCLPowOp>,
spirv::UnaryAndBinaryOpPattern<math::RsqrtOp, spirv::OCLRsqrtOp>,
spirv::UnaryAndBinaryOpPattern<math::SinOp, spirv::OCLSinOp>,
spirv::UnaryAndBinaryOpPattern<math::SqrtOp, spirv::OCLSqrtOp>,
spirv::UnaryAndBinaryOpPattern<math::TanhOp, spirv::OCLTanhOp>>(
typeConverter, patterns.getContext());
}
} // namespace mlir

View File

@ -6,7 +6,7 @@
module attributes {
spv.target_env = #spv.target_env<
#spv.vce<v1.0, [Int8, Int16, Int64, Float16, Float64], []>, {}>
#spv.vce<v1.0, [Int8, Int16, Int64, Float16, Float64, Shader], []>, {}>
} {
// Check integer operation conversions.

View File

@ -1,5 +1,7 @@
// RUN: mlir-opt -split-input-file -convert-math-to-spirv -verify-diagnostics %s -o - | FileCheck %s
module attributes { spv.target_env = #spv.target_env<#spv.vce<v1.0, [Shader], []>, {}> } {
// CHECK-LABEL: @float32_unary_scalar
func @float32_unary_scalar(%arg0: f32) {
// CHECK: spv.GLSL.Cos %{{.*}}: f32
@ -59,3 +61,5 @@ func @float32_binary_vector(%lhs: vector<4xf32>, %rhs: vector<4xf32>) {
%0 = math.powf %lhs, %rhs : vector<4xf32>
return
}
} // end module

View File

@ -0,0 +1,65 @@
// RUN: mlir-opt -split-input-file -convert-math-to-spirv -verify-diagnostics %s -o - | FileCheck %s
module attributes { spv.target_env = #spv.target_env<#spv.vce<v1.0, [Kernel], []>, {}> } {
// CHECK-LABEL: @float32_unary_scalar
func @float32_unary_scalar(%arg0: f32) {
// CHECK: spv.OCL.cos %{{.*}}: f32
%0 = math.cos %arg0 : f32
// CHECK: spv.OCL.exp %{{.*}}: f32
%1 = math.exp %arg0 : f32
// CHECK: spv.OCL.log %{{.*}}: f32
%2 = math.log %arg0 : f32
// CHECK: %[[ONE:.+]] = spv.Constant 1.000000e+00 : f32
// CHECK: %[[ADDONE:.+]] = spv.FAdd %[[ONE]], %{{.+}}
// CHECK: spv.OCL.log %[[ADDONE]]
%3 = math.log1p %arg0 : f32
// CHECK: spv.OCL.rsqrt %{{.*}}: f32
%4 = math.rsqrt %arg0 : f32
// CHECK: spv.OCL.sqrt %{{.*}}: f32
%5 = math.sqrt %arg0 : f32
// CHECK: spv.OCL.tanh %{{.*}}: f32
%6 = math.tanh %arg0 : f32
// CHECK: spv.OCL.sin %{{.*}}: f32
%7 = math.sin %arg0 : f32
return
}
// CHECK-LABEL: @float32_unary_vector
func @float32_unary_vector(%arg0: vector<3xf32>) {
// CHECK: spv.OCL.cos %{{.*}}: vector<3xf32>
%0 = math.cos %arg0 : vector<3xf32>
// CHECK: spv.OCL.exp %{{.*}}: vector<3xf32>
%1 = math.exp %arg0 : vector<3xf32>
// CHECK: spv.OCL.log %{{.*}}: vector<3xf32>
%2 = math.log %arg0 : vector<3xf32>
// CHECK: %[[ONE:.+]] = spv.Constant dense<1.000000e+00> : vector<3xf32>
// CHECK: %[[ADDONE:.+]] = spv.FAdd %[[ONE]], %{{.+}}
// CHECK: spv.OCL.log %[[ADDONE]]
%3 = math.log1p %arg0 : vector<3xf32>
// CHECK: spv.OCL.rsqrt %{{.*}}: vector<3xf32>
%4 = math.rsqrt %arg0 : vector<3xf32>
// CHECK: spv.OCL.sqrt %{{.*}}: vector<3xf32>
%5 = math.sqrt %arg0 : vector<3xf32>
// CHECK: spv.OCL.tanh %{{.*}}: vector<3xf32>
%6 = math.tanh %arg0 : vector<3xf32>
// CHECK: spv.OCL.sin %{{.*}}: vector<3xf32>
%7 = math.sin %arg0 : vector<3xf32>
return
}
// CHECK-LABEL: @float32_binary_scalar
func @float32_binary_scalar(%lhs: f32, %rhs: f32) {
// CHECK: spv.OCL.pow %{{.*}}: f32
%0 = math.powf %lhs, %rhs : f32
return
}
// CHECK-LABEL: @float32_binary_vector
func @float32_binary_vector(%lhs: vector<4xf32>, %rhs: vector<4xf32>) {
// CHECK: spv.OCL.pow %{{.*}}: vector<4xf32>
%0 = math.powf %lhs, %rhs : vector<4xf32>
return
}
} // end module

View File

@ -6,7 +6,7 @@
module attributes {
spv.target_env = #spv.target_env<
#spv.vce<v1.0, [Int8, Int16, Int64, Float16, Float64], []>, {}>
#spv.vce<v1.0, [Int8, Int16, Int64, Float16, Float64, Shader], []>, {}>
} {
// Check integer operation conversions.

View File

@ -14,6 +14,14 @@ spv.module Physical64 OpenCL requires #spv.vce<v1.0, [Kernel, Addresses], []> {
%4 = spv.OCL.log %arg0 : f32
// CHECK: {{%.*}} = spv.OCL.sqrt {{%.*}} : f32
%5 = spv.OCL.sqrt %arg0 : f32
// CHECK: {{%.*}} = spv.OCL.ceil {{%.*}} : f32
%6 = spv.OCL.ceil %arg0 : f32
// CHECK: {{%.*}} = spv.OCL.floor {{%.*}} : f32
%7 = spv.OCL.floor %arg0 : f32
// CHECK: {{%.*}} = spv.OCL.pow {{%.*}}, {{%.*}} : f32
%8 = spv.OCL.pow %arg0, %arg0 : f32
// CHECK: {{%.*}} = spv.OCL.rsqrt {{%.*}} : f32
%9 = spv.OCL.rsqrt %arg0 : f32
spv.Return
}