forked from OSchip/llvm-project
[mlir][spirv] Add math to OpenCL conversion
Differential Revision: https://reviews.llvm.org/D113780
This commit is contained in:
parent
26d1edfb10
commit
75a1bee05d
|
@ -22,7 +22,15 @@ include "mlir/Interfaces/SideEffectInterfaces.td"
|
|||
|
||||
// Base class for all GLSL ops.
|
||||
class SPV_GLSLOp<string mnemonic, int opcode, list<OpTrait> traits = []> :
|
||||
SPV_ExtInstOp<mnemonic, "GLSL", "GLSL.std.450", opcode, traits>;
|
||||
SPV_ExtInstOp<mnemonic, "GLSL", "GLSL.std.450", opcode, traits> {
|
||||
|
||||
let availability = [
|
||||
MinVersion<SPV_V_1_0>,
|
||||
MaxVersion<SPV_V_1_5>,
|
||||
Extension<[]>,
|
||||
Capability<[SPV_C_Shader]>
|
||||
];
|
||||
}
|
||||
|
||||
// Base class for GLSL unary ops.
|
||||
class SPV_GLSLUnaryOp<string mnemonic, int opcode, Type resultType,
|
||||
|
|
|
@ -21,7 +21,15 @@ include "mlir/Dialect/SPIRV/IR/SPIRVBase.td"
|
|||
|
||||
// Base class for all OpenCL ops.
|
||||
class SPV_OCLOp<string mnemonic, int opcode, list<OpTrait> traits = []> :
|
||||
SPV_ExtInstOp<mnemonic, "OCL", "OpenCL.std", opcode, traits>;
|
||||
SPV_ExtInstOp<mnemonic, "OCL", "OpenCL.std", opcode, traits> {
|
||||
|
||||
let availability = [
|
||||
MinVersion<SPV_V_1_0>,
|
||||
MaxVersion<SPV_V_1_5>,
|
||||
Extension<[]>,
|
||||
Capability<[SPV_C_Kernel]>
|
||||
];
|
||||
}
|
||||
|
||||
// Base class for OpenCL unary ops.
|
||||
class SPV_OCLUnaryOp<string mnemonic, int opcode, Type resultType,
|
||||
|
@ -78,6 +86,69 @@ class SPV_OCLBinaryArithmeticOp<string mnemonic, int opcode, Type type,
|
|||
|
||||
// -----
|
||||
|
||||
def SPV_OCLTanhOp : SPV_OCLUnaryArithmeticOp<"tanh", 63, SPV_Float> {
|
||||
let summary = "Compute hyperbolic tangent of x radians.";
|
||||
|
||||
let description = [{
|
||||
Result Type and x must be floating-point or vector(2,3,4,8,16) of
|
||||
floating-point values.
|
||||
|
||||
All of the operands, including the Result Type operand, must be of the
|
||||
same type.
|
||||
|
||||
<!-- End of AutoGen section -->
|
||||
|
||||
```
|
||||
float-scalar-vector-type ::= float-type |
|
||||
`vector<` integer-literal `x` float-type `>`
|
||||
tanh-op ::= ssa-id `=` `spv.OCL.tanh` ssa-use `:`
|
||||
float-scalar-vector-type
|
||||
```mlir
|
||||
|
||||
#### Example:
|
||||
|
||||
```
|
||||
%2 = spv.OCL.tanh %0 : f32
|
||||
%3 = spv.OCL.tanh %1 : vector<3xf16>
|
||||
```
|
||||
}];
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
def SPV_OCLCeilOp : SPV_OCLUnaryArithmeticOp<"ceil", 12, SPV_Float> {
|
||||
let summary = [{
|
||||
Round x to integral value using the round to positive infinity rounding
|
||||
mode.
|
||||
}];
|
||||
|
||||
let description = [{
|
||||
Result Type and x must be floating-point or vector(2,3,4,8,16) of
|
||||
floating-point values.
|
||||
|
||||
All of the operands, including the Result Type operand, must be of the
|
||||
same type.
|
||||
|
||||
<!-- End of AutoGen section -->
|
||||
|
||||
```
|
||||
float-scalar-vector-type ::= float-type |
|
||||
`vector<` integer-literal `x` float-type `>`
|
||||
ceil-op ::= ssa-id `=` `spv.OCL.ceil` ssa-use `:`
|
||||
float-scalar-vector-type
|
||||
```mlir
|
||||
|
||||
#### Example:
|
||||
|
||||
```
|
||||
%2 = spv.OCL.ceil %0 : f32
|
||||
%3 = spv.OCL.ceil %1 : vector<3xf16>
|
||||
```
|
||||
}];
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
def SPV_OCLCosOp : SPV_OCLUnaryArithmeticOp<"cos", 14, SPV_Float> {
|
||||
let summary = "Compute the cosine of x radians.";
|
||||
|
||||
|
@ -93,7 +164,7 @@ def SPV_OCLCosOp : SPV_OCLUnaryArithmeticOp<"cos", 14, SPV_Float> {
|
|||
```
|
||||
float-scalar-vector-type ::= float-type |
|
||||
`vector<` integer-literal `x` float-type `>`
|
||||
abs-op ::= ssa-id `=` `spv.OCL.cos` ssa-use `:`
|
||||
cos-op ::= ssa-id `=` `spv.OCL.cos` ssa-use `:`
|
||||
float-scalar-vector-type
|
||||
```mlir
|
||||
|
||||
|
@ -168,6 +239,39 @@ def SPV_OCLFAbsOp : SPV_OCLUnaryArithmeticOp<"fabs", 23, SPV_Float> {
|
|||
|
||||
// -----
|
||||
|
||||
def SPV_OCLFloorOp : SPV_OCLUnaryArithmeticOp<"floor", 25, SPV_Float> {
|
||||
let summary = [{
|
||||
Round x to the integral value using the round to negative infinity
|
||||
rounding mode.
|
||||
}];
|
||||
|
||||
let description = [{
|
||||
Result Type and x must be floating-point or vector(2,3,4,8,16) of
|
||||
floating-point values.
|
||||
|
||||
All of the operands, including the Result Type operand, must be of the
|
||||
same type.
|
||||
|
||||
<!-- End of AutoGen section -->
|
||||
|
||||
```
|
||||
float-scalar-vector-type ::= float-type |
|
||||
`vector<` integer-literal `x` float-type `>`
|
||||
floor-op ::= ssa-id `=` `spv.OCL.floor` ssa-use `:`
|
||||
float-scalar-vector-type
|
||||
```mlir
|
||||
|
||||
#### Example:
|
||||
|
||||
```
|
||||
%2 = spv.OCL.floor %0 : f32
|
||||
%3 = spv.OCL.ceifloorl %1 : vector<3xf16>
|
||||
```
|
||||
}];
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
def SPV_OCLLogOp : SPV_OCLUnaryArithmeticOp<"log", 37, SPV_Float> {
|
||||
let summary = "Compute the natural logarithm of x.";
|
||||
|
||||
|
@ -183,7 +287,7 @@ def SPV_OCLLogOp : SPV_OCLUnaryArithmeticOp<"log", 37, SPV_Float> {
|
|||
```
|
||||
float-scalar-vector-type ::= float-type |
|
||||
`vector<` integer-literal `x` float-type `>`
|
||||
abs-op ::= ssa-id `=` `spv.OCL.log` ssa-use `:`
|
||||
log-op ::= ssa-id `=` `spv.OCL.log` ssa-use `:`
|
||||
float-scalar-vector-type
|
||||
```mlir
|
||||
|
||||
|
@ -198,6 +302,67 @@ def SPV_OCLLogOp : SPV_OCLUnaryArithmeticOp<"log", 37, SPV_Float> {
|
|||
|
||||
// -----
|
||||
|
||||
def SPV_OCLPowOp : SPV_OCLBinaryArithmeticOp<"pow", 48, SPV_Float> {
|
||||
let summary = "Compute x to the power y.";
|
||||
|
||||
let description = [{
|
||||
Result Type, x and y must be floating-point or vector(2,3,4,8,16) of
|
||||
floating-point values.
|
||||
|
||||
All of the operands, including the Result Type operand, must be of the
|
||||
same type.
|
||||
|
||||
<!-- End of AutoGen section -->
|
||||
|
||||
```
|
||||
restricted-float-scalar-type ::= `f16` | `f32`
|
||||
restricted-float-scalar-vector-type ::=
|
||||
restricted-float-scalar-type |
|
||||
`vector<` integer-literal `x` restricted-float-scalar-type `>`
|
||||
pow-op ::= ssa-id `=` `spv.OCL.pow` ssa-use `:`
|
||||
restricted-float-scalar-vector-type
|
||||
```
|
||||
#### Example:
|
||||
|
||||
```mlir
|
||||
%2 = spv.OCL.pow %0, %1 : f32
|
||||
%3 = spv.OCL.pow %0, %1 : vector<3xf16>
|
||||
```
|
||||
}];
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
def SPV_OCLRsqrtOp : SPV_OCLUnaryArithmeticOp<"rsqrt", 56, SPV_Float> {
|
||||
let summary = "Compute inverse square root of x.";
|
||||
|
||||
let description = [{
|
||||
Result Type and x must be floating-point or vector(2,3,4,8,16) of
|
||||
floating-point values.
|
||||
|
||||
All of the operands, including the Result Type operand, must be of the
|
||||
same type.
|
||||
|
||||
<!-- End of AutoGen section -->
|
||||
|
||||
```
|
||||
float-scalar-vector-type ::= float-type |
|
||||
`vector<` integer-literal `x` float-type `>`
|
||||
rsqrt-op ::= ssa-id `=` `spv.OCL.rsqrt` ssa-use `:`
|
||||
float-scalar-vector-type
|
||||
```mlir
|
||||
|
||||
#### Example:
|
||||
|
||||
```
|
||||
%2 = spv.OCL.rsqrt %0 : f32
|
||||
%3 = spv.OCL.rsqrt %1 : vector<3xf16>
|
||||
```
|
||||
}];
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
def SPV_OCLSinOp : SPV_OCLUnaryArithmeticOp<"sin", 57, SPV_Float> {
|
||||
let summary = "Compute sine of x radians.";
|
||||
|
||||
|
@ -213,7 +378,7 @@ def SPV_OCLSinOp : SPV_OCLUnaryArithmeticOp<"sin", 57, SPV_Float> {
|
|||
```
|
||||
float-scalar-vector-type ::= float-type |
|
||||
`vector<` integer-literal `x` float-type `>`
|
||||
abs-op ::= ssa-id `=` `spv.OCL.sin` ssa-use `:`
|
||||
sin-op ::= ssa-id `=` `spv.OCL.sin` ssa-use `:`
|
||||
float-scalar-vector-type
|
||||
```mlir
|
||||
|
||||
|
@ -243,7 +408,7 @@ def SPV_OCLSqrtOp : SPV_OCLUnaryArithmeticOp<"sqrt", 61, SPV_Float> {
|
|||
```
|
||||
float-scalar-vector-type ::= float-type |
|
||||
`vector<` integer-literal `x` float-type `>`
|
||||
abs-op ::= ssa-id `=` `spv.OCL.sqrt` ssa-use `:`
|
||||
sqrt-op ::= ssa-id `=` `spv.OCL.sqrt` ssa-use `:`
|
||||
float-scalar-vector-type
|
||||
```mlir
|
||||
|
||||
|
|
|
@ -34,6 +34,7 @@ namespace {
|
|||
///
|
||||
/// SPIR-V does not have a direct operations for log(1+x). Explicitly lower to
|
||||
/// these operations.
|
||||
template <typename LogOp>
|
||||
class Log1pOpPattern final : public OpConversionPattern<math::Log1pOp> {
|
||||
public:
|
||||
using OpConversionPattern<math::Log1pOp>::OpConversionPattern;
|
||||
|
@ -48,7 +49,7 @@ public:
|
|||
auto one = spirv::ConstantOp::getOne(type, operation.getLoc(), rewriter);
|
||||
auto onePlus =
|
||||
rewriter.create<spirv::FAddOp>(loc, one, adaptor.getOperands()[0]);
|
||||
rewriter.replaceOpWithNewOp<spirv::GLSLLogOp>(operation, type, onePlus);
|
||||
rewriter.replaceOpWithNewOp<LogOp>(operation, type, onePlus);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
@ -61,8 +62,10 @@ public:
|
|||
namespace mlir {
|
||||
void populateMathToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
|
||||
RewritePatternSet &patterns) {
|
||||
|
||||
// GLSL patterns
|
||||
patterns.add<
|
||||
Log1pOpPattern,
|
||||
Log1pOpPattern<spirv::GLSLLogOp>,
|
||||
spirv::UnaryAndBinaryOpPattern<math::AbsOp, spirv::GLSLFAbsOp>,
|
||||
spirv::UnaryAndBinaryOpPattern<math::CeilOp, spirv::GLSLCeilOp>,
|
||||
spirv::UnaryAndBinaryOpPattern<math::CosOp, spirv::GLSLCosOp>,
|
||||
|
@ -75,6 +78,21 @@ void populateMathToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
|
|||
spirv::UnaryAndBinaryOpPattern<math::SqrtOp, spirv::GLSLSqrtOp>,
|
||||
spirv::UnaryAndBinaryOpPattern<math::TanhOp, spirv::GLSLTanhOp>>(
|
||||
typeConverter, patterns.getContext());
|
||||
|
||||
// OpenCL patterns
|
||||
patterns.add<Log1pOpPattern<spirv::OCLLogOp>,
|
||||
spirv::UnaryAndBinaryOpPattern<math::AbsOp, spirv::OCLFAbsOp>,
|
||||
spirv::UnaryAndBinaryOpPattern<math::CeilOp, spirv::OCLCeilOp>,
|
||||
spirv::UnaryAndBinaryOpPattern<math::CosOp, spirv::OCLCosOp>,
|
||||
spirv::UnaryAndBinaryOpPattern<math::ExpOp, spirv::OCLExpOp>,
|
||||
spirv::UnaryAndBinaryOpPattern<math::FloorOp, spirv::OCLFloorOp>,
|
||||
spirv::UnaryAndBinaryOpPattern<math::LogOp, spirv::OCLLogOp>,
|
||||
spirv::UnaryAndBinaryOpPattern<math::PowFOp, spirv::OCLPowOp>,
|
||||
spirv::UnaryAndBinaryOpPattern<math::RsqrtOp, spirv::OCLRsqrtOp>,
|
||||
spirv::UnaryAndBinaryOpPattern<math::SinOp, spirv::OCLSinOp>,
|
||||
spirv::UnaryAndBinaryOpPattern<math::SqrtOp, spirv::OCLSqrtOp>,
|
||||
spirv::UnaryAndBinaryOpPattern<math::TanhOp, spirv::OCLTanhOp>>(
|
||||
typeConverter, patterns.getContext());
|
||||
}
|
||||
|
||||
} // namespace mlir
|
||||
|
|
|
@ -6,7 +6,7 @@
|
|||
|
||||
module attributes {
|
||||
spv.target_env = #spv.target_env<
|
||||
#spv.vce<v1.0, [Int8, Int16, Int64, Float16, Float64], []>, {}>
|
||||
#spv.vce<v1.0, [Int8, Int16, Int64, Float16, Float64, Shader], []>, {}>
|
||||
} {
|
||||
|
||||
// Check integer operation conversions.
|
||||
|
|
|
@ -1,5 +1,7 @@
|
|||
// RUN: mlir-opt -split-input-file -convert-math-to-spirv -verify-diagnostics %s -o - | FileCheck %s
|
||||
|
||||
module attributes { spv.target_env = #spv.target_env<#spv.vce<v1.0, [Shader], []>, {}> } {
|
||||
|
||||
// CHECK-LABEL: @float32_unary_scalar
|
||||
func @float32_unary_scalar(%arg0: f32) {
|
||||
// CHECK: spv.GLSL.Cos %{{.*}}: f32
|
||||
|
@ -59,3 +61,5 @@ func @float32_binary_vector(%lhs: vector<4xf32>, %rhs: vector<4xf32>) {
|
|||
%0 = math.powf %lhs, %rhs : vector<4xf32>
|
||||
return
|
||||
}
|
||||
|
||||
} // end module
|
|
@ -0,0 +1,65 @@
|
|||
// RUN: mlir-opt -split-input-file -convert-math-to-spirv -verify-diagnostics %s -o - | FileCheck %s
|
||||
|
||||
module attributes { spv.target_env = #spv.target_env<#spv.vce<v1.0, [Kernel], []>, {}> } {
|
||||
|
||||
// CHECK-LABEL: @float32_unary_scalar
|
||||
func @float32_unary_scalar(%arg0: f32) {
|
||||
// CHECK: spv.OCL.cos %{{.*}}: f32
|
||||
%0 = math.cos %arg0 : f32
|
||||
// CHECK: spv.OCL.exp %{{.*}}: f32
|
||||
%1 = math.exp %arg0 : f32
|
||||
// CHECK: spv.OCL.log %{{.*}}: f32
|
||||
%2 = math.log %arg0 : f32
|
||||
// CHECK: %[[ONE:.+]] = spv.Constant 1.000000e+00 : f32
|
||||
// CHECK: %[[ADDONE:.+]] = spv.FAdd %[[ONE]], %{{.+}}
|
||||
// CHECK: spv.OCL.log %[[ADDONE]]
|
||||
%3 = math.log1p %arg0 : f32
|
||||
// CHECK: spv.OCL.rsqrt %{{.*}}: f32
|
||||
%4 = math.rsqrt %arg0 : f32
|
||||
// CHECK: spv.OCL.sqrt %{{.*}}: f32
|
||||
%5 = math.sqrt %arg0 : f32
|
||||
// CHECK: spv.OCL.tanh %{{.*}}: f32
|
||||
%6 = math.tanh %arg0 : f32
|
||||
// CHECK: spv.OCL.sin %{{.*}}: f32
|
||||
%7 = math.sin %arg0 : f32
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @float32_unary_vector
|
||||
func @float32_unary_vector(%arg0: vector<3xf32>) {
|
||||
// CHECK: spv.OCL.cos %{{.*}}: vector<3xf32>
|
||||
%0 = math.cos %arg0 : vector<3xf32>
|
||||
// CHECK: spv.OCL.exp %{{.*}}: vector<3xf32>
|
||||
%1 = math.exp %arg0 : vector<3xf32>
|
||||
// CHECK: spv.OCL.log %{{.*}}: vector<3xf32>
|
||||
%2 = math.log %arg0 : vector<3xf32>
|
||||
// CHECK: %[[ONE:.+]] = spv.Constant dense<1.000000e+00> : vector<3xf32>
|
||||
// CHECK: %[[ADDONE:.+]] = spv.FAdd %[[ONE]], %{{.+}}
|
||||
// CHECK: spv.OCL.log %[[ADDONE]]
|
||||
%3 = math.log1p %arg0 : vector<3xf32>
|
||||
// CHECK: spv.OCL.rsqrt %{{.*}}: vector<3xf32>
|
||||
%4 = math.rsqrt %arg0 : vector<3xf32>
|
||||
// CHECK: spv.OCL.sqrt %{{.*}}: vector<3xf32>
|
||||
%5 = math.sqrt %arg0 : vector<3xf32>
|
||||
// CHECK: spv.OCL.tanh %{{.*}}: vector<3xf32>
|
||||
%6 = math.tanh %arg0 : vector<3xf32>
|
||||
// CHECK: spv.OCL.sin %{{.*}}: vector<3xf32>
|
||||
%7 = math.sin %arg0 : vector<3xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @float32_binary_scalar
|
||||
func @float32_binary_scalar(%lhs: f32, %rhs: f32) {
|
||||
// CHECK: spv.OCL.pow %{{.*}}: f32
|
||||
%0 = math.powf %lhs, %rhs : f32
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @float32_binary_vector
|
||||
func @float32_binary_vector(%lhs: vector<4xf32>, %rhs: vector<4xf32>) {
|
||||
// CHECK: spv.OCL.pow %{{.*}}: vector<4xf32>
|
||||
%0 = math.powf %lhs, %rhs : vector<4xf32>
|
||||
return
|
||||
}
|
||||
|
||||
} // end module
|
|
@ -6,7 +6,7 @@
|
|||
|
||||
module attributes {
|
||||
spv.target_env = #spv.target_env<
|
||||
#spv.vce<v1.0, [Int8, Int16, Int64, Float16, Float64], []>, {}>
|
||||
#spv.vce<v1.0, [Int8, Int16, Int64, Float16, Float64, Shader], []>, {}>
|
||||
} {
|
||||
|
||||
// Check integer operation conversions.
|
||||
|
|
|
@ -14,6 +14,14 @@ spv.module Physical64 OpenCL requires #spv.vce<v1.0, [Kernel, Addresses], []> {
|
|||
%4 = spv.OCL.log %arg0 : f32
|
||||
// CHECK: {{%.*}} = spv.OCL.sqrt {{%.*}} : f32
|
||||
%5 = spv.OCL.sqrt %arg0 : f32
|
||||
// CHECK: {{%.*}} = spv.OCL.ceil {{%.*}} : f32
|
||||
%6 = spv.OCL.ceil %arg0 : f32
|
||||
// CHECK: {{%.*}} = spv.OCL.floor {{%.*}} : f32
|
||||
%7 = spv.OCL.floor %arg0 : f32
|
||||
// CHECK: {{%.*}} = spv.OCL.pow {{%.*}}, {{%.*}} : f32
|
||||
%8 = spv.OCL.pow %arg0, %arg0 : f32
|
||||
// CHECK: {{%.*}} = spv.OCL.rsqrt {{%.*}} : f32
|
||||
%9 = spv.OCL.rsqrt %arg0 : f32
|
||||
spv.Return
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue