[mlir][spirv] Add OpenCL fma op and lowering

Also, it seems Khronos has changed html spec format so small adjustment to script was needed.
Base op parsing is also probably broken.

Differential Revision: https://reviews.llvm.org/D119678
This commit is contained in:
Ivan Butygin 2022-02-13 23:25:19 +03:00
parent 290e482342
commit 32389d0c2e
7 changed files with 114 additions and 14 deletions

View File

@ -82,15 +82,46 @@ class SPV_OCLBinaryArithmeticOp<string mnemonic, int opcode, Type type,
let assemblyFormat = "operands attr-dict `:` type($result)";
}
// Base class for OpenCL binary ops.
class SPV_OCLTernaryOp<string mnemonic, int opcode, Type resultType,
Type operandType, list<Trait> traits = []> :
SPV_OCLOp<mnemonic, opcode, !listconcat([NoSideEffect], traits)> {
let arguments = (ins
SPV_ScalarOrVectorOf<operandType>:$x,
SPV_ScalarOrVectorOf<operandType>:$y,
SPV_ScalarOrVectorOf<operandType>:$z
);
let results = (outs
SPV_ScalarOrVectorOf<resultType>:$result
);
let hasVerifier = 0;
}
// Base class for OpenCL Ternary arithmetic ops where operand types and
// return type matches.
class SPV_OCLTernaryArithmeticOp<string mnemonic, int opcode, Type type,
list<Trait> traits = []> :
SPV_OCLTernaryOp<mnemonic, opcode, type, type,
traits # [SameOperandsAndResultType]> {
let assemblyFormat = "operands attr-dict `:` type($result)";
}
// -----
def SPV_OCLErfOp : SPV_OCLUnaryArithmeticOp<"erf", 18, SPV_Float> {
def SPV_OCLFmaOp : SPV_OCLTernaryArithmeticOp<"fma", 26, SPV_Float> {
let summary = [{
Error function of x encountered in integrating the normal distribution.
Compute the correctly rounded floating-point representation of the sum
of c with the infinitely precise product of a and b. Rounding of
intermediate products shall not occur. Edge case results are per the
IEEE 754-2008 standard.
}];
let description = [{
Result Type and x must be floating-point or vector(2,3,4,8,16) of
Result Type, a, b and c must be floating-point or vector(2,3,4,8,16) of
floating-point values.
All of the operands, including the Result Type operand, must be of the
@ -99,17 +130,13 @@ def SPV_OCLErfOp : SPV_OCLUnaryArithmeticOp<"erf", 18, SPV_Float> {
<!-- End of AutoGen section -->
```
float-scalar-vector-type ::= float-type |
`vector<` integer-literal `x` float-type `>`
erf-op ::= ssa-id `=` `spv.OCL.erf` ssa-use `:`
fma-op ::= ssa-id `=` `spv.OCL.fma` ssa-use, ssa-use, ssa-use `:`
float-scalar-vector-type
```mlir
#### Example:
```
%2 = spv.OCL.erf %0 : f32
%3 = spv.OCL.erf %1 : vector<3xf16>
%0 = spv.OCL.fma %a, %b, %c : f32
%1 = spv.OCL.fma %a, %b, %c : vector<3xf16>
```
}];
}
@ -179,6 +206,38 @@ def SPV_OCLCosOp : SPV_OCLUnaryArithmeticOp<"cos", 14, SPV_Float> {
// -----
def SPV_OCLErfOp : SPV_OCLUnaryArithmeticOp<"erf", 18, SPV_Float> {
let summary = [{
Error function of x encountered in integrating the normal distribution.
}];
let description = [{
Result Type and x must be floating-point or vector(2,3,4,8,16) of
floating-point values.
All of the operands, including the Result Type operand, must be of the
same type.
<!-- End of AutoGen section -->
```
float-scalar-vector-type ::= float-type |
`vector<` integer-literal `x` float-type `>`
erf-op ::= ssa-id `=` `spv.OCL.erf` ssa-use `:`
float-scalar-vector-type
```mlir
#### Example:
```
%2 = spv.OCL.erf %0 : f32
%3 = spv.OCL.erf %1 : vector<3xf16>
```
}];
}
// -----
def SPV_OCLExpOp : SPV_OCLUnaryArithmeticOp<"exp", 19, SPV_Float> {
let summary = "Exponentiation of Operand 1";

View File

@ -92,13 +92,13 @@ void populateMathToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
spirv::ElementwiseOpPattern<math::CosOp, spirv::GLSLCosOp>,
spirv::ElementwiseOpPattern<math::ExpOp, spirv::GLSLExpOp>,
spirv::ElementwiseOpPattern<math::FloorOp, spirv::GLSLFloorOp>,
spirv::ElementwiseOpPattern<math::FmaOp, spirv::GLSLFmaOp>,
spirv::ElementwiseOpPattern<math::LogOp, spirv::GLSLLogOp>,
spirv::ElementwiseOpPattern<math::PowFOp, spirv::GLSLPowOp>,
spirv::ElementwiseOpPattern<math::RsqrtOp, spirv::GLSLInverseSqrtOp>,
spirv::ElementwiseOpPattern<math::SinOp, spirv::GLSLSinOp>,
spirv::ElementwiseOpPattern<math::SqrtOp, spirv::GLSLSqrtOp>,
spirv::ElementwiseOpPattern<math::TanhOp, spirv::GLSLTanhOp>,
spirv::ElementwiseOpPattern<math::FmaOp, spirv::GLSLFmaOp>>(
spirv::ElementwiseOpPattern<math::TanhOp, spirv::GLSLTanhOp>>(
typeConverter, patterns.getContext());
// OpenCL patterns
@ -109,6 +109,7 @@ void populateMathToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
spirv::ElementwiseOpPattern<math::ErfOp, spirv::OCLErfOp>,
spirv::ElementwiseOpPattern<math::ExpOp, spirv::OCLExpOp>,
spirv::ElementwiseOpPattern<math::FloorOp, spirv::OCLFloorOp>,
spirv::ElementwiseOpPattern<math::FmaOp, spirv::OCLFmaOp>,
spirv::ElementwiseOpPattern<math::LogOp, spirv::OCLLogOp>,
spirv::ElementwiseOpPattern<math::PowFOp, spirv::OCLPowOp>,
spirv::ElementwiseOpPattern<math::RsqrtOp, spirv::OCLRsqrtOp>,

View File

@ -76,7 +76,7 @@ func @float32_binary_vector(%lhs: vector<4xf32>, %rhs: vector<4xf32>) {
return
}
// CHECK-LABEL: @float32_ternary_scalar
// CHECK-LABEL: @float32_ternary_scalar
func @float32_ternary_scalar(%a: f32, %b: f32, %c: f32) {
// CHECK: spv.GLSL.Fma %{{.*}}: f32
%0 = math.fma %a, %b, %c : f32

View File

@ -78,4 +78,19 @@ func @float32_binary_vector(%lhs: vector<4xf32>, %rhs: vector<4xf32>) {
return
}
// CHECK-LABEL: @float32_ternary_scalar
func @float32_ternary_scalar(%a: f32, %b: f32, %c: f32) {
// CHECK: spv.OCL.fma %{{.*}}: f32
%0 = math.fma %a, %b, %c : f32
return
}
// CHECK-LABEL: @float32_ternary_vector
func @float32_ternary_vector(%a: vector<4xf32>, %b: vector<4xf32>,
%c: vector<4xf32>) {
// CHECK: spv.OCL.fma %{{.*}}: vector<4xf32>
%0 = math.fma %a, %b, %c : vector<4xf32>
return
}
} // end module

View File

@ -166,3 +166,22 @@ func @sabs(%arg0 : i32) -> () {
return
}
// -----
//===----------------------------------------------------------------------===//
// spv.OCL.fma
//===----------------------------------------------------------------------===//
func @fma(%a : f32, %b : f32, %c : f32) -> () {
// CHECK: spv.OCL.fma {{%[^,]*}}, {{%[^,]*}}, {{%[^,]*}} : f32
%2 = spv.OCL.fma %a, %b, %c : f32
return
}
// -----
func @fma(%a : vector<3xf32>, %b : vector<3xf32>, %c : vector<3xf32>) -> () {
// CHECK: spv.OCL.fma {{%[^,]*}}, {{%[^,]*}}, {{%[^,]*}} : vector<3xf32>
%2 = spv.OCL.fma %a, %b, %c : vector<3xf32>
return
}

View File

@ -38,4 +38,10 @@ spv.module Physical64 OpenCL requires #spv.vce<v1.0, [Kernel, Addresses], []> {
%0 = spv.OCL.fabs %arg0 : vector<16xf32>
spv.Return
}
spv.func @fma(%arg0 : f32, %arg1 : f32, %arg2 : f32) "None" {
// CHECK: spv.OCL.fma {{%[^,]*}}, {{%[^,]*}}, {{%[^,]*}} : f32
%13 = spv.OCL.fma %arg0, %arg1, %arg2 : f32
spv.Return
}
}

View File

@ -51,7 +51,7 @@ def get_spirv_doc_from_html_spec(url, settings):
doc = {}
if settings.gen_ocl_ops:
section_anchor = spirv.find('h2', {'id': '_a_id_binary_a_binary_form'})
section_anchor = spirv.find('h2', {'id': '_binary_form'})
for section in section_anchor.parent.find_all('div', {'class': 'sect2'}):
for table in section.find_all('table'):
inst_html = table.tbody.tr.td