forked from OSchip/llvm-project
[mlir][spirv] Add OpenCL fma op and lowering
Also, it seems Khronos has changed html spec format so small adjustment to script was needed. Base op parsing is also probably broken. Differential Revision: https://reviews.llvm.org/D119678
This commit is contained in:
parent
290e482342
commit
32389d0c2e
|
@ -82,15 +82,46 @@ class SPV_OCLBinaryArithmeticOp<string mnemonic, int opcode, Type type,
|
|||
let assemblyFormat = "operands attr-dict `:` type($result)";
|
||||
}
|
||||
|
||||
// Base class for OpenCL binary ops.
|
||||
class SPV_OCLTernaryOp<string mnemonic, int opcode, Type resultType,
|
||||
Type operandType, list<Trait> traits = []> :
|
||||
SPV_OCLOp<mnemonic, opcode, !listconcat([NoSideEffect], traits)> {
|
||||
|
||||
let arguments = (ins
|
||||
SPV_ScalarOrVectorOf<operandType>:$x,
|
||||
SPV_ScalarOrVectorOf<operandType>:$y,
|
||||
SPV_ScalarOrVectorOf<operandType>:$z
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
SPV_ScalarOrVectorOf<resultType>:$result
|
||||
);
|
||||
|
||||
let hasVerifier = 0;
|
||||
}
|
||||
|
||||
// Base class for OpenCL Ternary arithmetic ops where operand types and
|
||||
// return type matches.
|
||||
class SPV_OCLTernaryArithmeticOp<string mnemonic, int opcode, Type type,
|
||||
list<Trait> traits = []> :
|
||||
SPV_OCLTernaryOp<mnemonic, opcode, type, type,
|
||||
traits # [SameOperandsAndResultType]> {
|
||||
let assemblyFormat = "operands attr-dict `:` type($result)";
|
||||
}
|
||||
|
||||
|
||||
// -----
|
||||
|
||||
def SPV_OCLErfOp : SPV_OCLUnaryArithmeticOp<"erf", 18, SPV_Float> {
|
||||
def SPV_OCLFmaOp : SPV_OCLTernaryArithmeticOp<"fma", 26, SPV_Float> {
|
||||
let summary = [{
|
||||
Error function of x encountered in integrating the normal distribution.
|
||||
Compute the correctly rounded floating-point representation of the sum
|
||||
of c with the infinitely precise product of a and b. Rounding of
|
||||
intermediate products shall not occur. Edge case results are per the
|
||||
IEEE 754-2008 standard.
|
||||
}];
|
||||
|
||||
let description = [{
|
||||
Result Type and x must be floating-point or vector(2,3,4,8,16) of
|
||||
Result Type, a, b and c must be floating-point or vector(2,3,4,8,16) of
|
||||
floating-point values.
|
||||
|
||||
All of the operands, including the Result Type operand, must be of the
|
||||
|
@ -99,17 +130,13 @@ def SPV_OCLErfOp : SPV_OCLUnaryArithmeticOp<"erf", 18, SPV_Float> {
|
|||
<!-- End of AutoGen section -->
|
||||
|
||||
```
|
||||
float-scalar-vector-type ::= float-type |
|
||||
`vector<` integer-literal `x` float-type `>`
|
||||
erf-op ::= ssa-id `=` `spv.OCL.erf` ssa-use `:`
|
||||
fma-op ::= ssa-id `=` `spv.OCL.fma` ssa-use, ssa-use, ssa-use `:`
|
||||
float-scalar-vector-type
|
||||
```mlir
|
||||
|
||||
#### Example:
|
||||
|
||||
```
|
||||
%2 = spv.OCL.erf %0 : f32
|
||||
%3 = spv.OCL.erf %1 : vector<3xf16>
|
||||
%0 = spv.OCL.fma %a, %b, %c : f32
|
||||
%1 = spv.OCL.fma %a, %b, %c : vector<3xf16>
|
||||
```
|
||||
}];
|
||||
}
|
||||
|
@ -179,6 +206,38 @@ def SPV_OCLCosOp : SPV_OCLUnaryArithmeticOp<"cos", 14, SPV_Float> {
|
|||
|
||||
// -----
|
||||
|
||||
def SPV_OCLErfOp : SPV_OCLUnaryArithmeticOp<"erf", 18, SPV_Float> {
|
||||
let summary = [{
|
||||
Error function of x encountered in integrating the normal distribution.
|
||||
}];
|
||||
|
||||
let description = [{
|
||||
Result Type and x must be floating-point or vector(2,3,4,8,16) of
|
||||
floating-point values.
|
||||
|
||||
All of the operands, including the Result Type operand, must be of the
|
||||
same type.
|
||||
|
||||
<!-- End of AutoGen section -->
|
||||
|
||||
```
|
||||
float-scalar-vector-type ::= float-type |
|
||||
`vector<` integer-literal `x` float-type `>`
|
||||
erf-op ::= ssa-id `=` `spv.OCL.erf` ssa-use `:`
|
||||
float-scalar-vector-type
|
||||
```mlir
|
||||
|
||||
#### Example:
|
||||
|
||||
```
|
||||
%2 = spv.OCL.erf %0 : f32
|
||||
%3 = spv.OCL.erf %1 : vector<3xf16>
|
||||
```
|
||||
}];
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
def SPV_OCLExpOp : SPV_OCLUnaryArithmeticOp<"exp", 19, SPV_Float> {
|
||||
let summary = "Exponentiation of Operand 1";
|
||||
|
||||
|
|
|
@ -92,13 +92,13 @@ void populateMathToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
|
|||
spirv::ElementwiseOpPattern<math::CosOp, spirv::GLSLCosOp>,
|
||||
spirv::ElementwiseOpPattern<math::ExpOp, spirv::GLSLExpOp>,
|
||||
spirv::ElementwiseOpPattern<math::FloorOp, spirv::GLSLFloorOp>,
|
||||
spirv::ElementwiseOpPattern<math::FmaOp, spirv::GLSLFmaOp>,
|
||||
spirv::ElementwiseOpPattern<math::LogOp, spirv::GLSLLogOp>,
|
||||
spirv::ElementwiseOpPattern<math::PowFOp, spirv::GLSLPowOp>,
|
||||
spirv::ElementwiseOpPattern<math::RsqrtOp, spirv::GLSLInverseSqrtOp>,
|
||||
spirv::ElementwiseOpPattern<math::SinOp, spirv::GLSLSinOp>,
|
||||
spirv::ElementwiseOpPattern<math::SqrtOp, spirv::GLSLSqrtOp>,
|
||||
spirv::ElementwiseOpPattern<math::TanhOp, spirv::GLSLTanhOp>,
|
||||
spirv::ElementwiseOpPattern<math::FmaOp, spirv::GLSLFmaOp>>(
|
||||
spirv::ElementwiseOpPattern<math::TanhOp, spirv::GLSLTanhOp>>(
|
||||
typeConverter, patterns.getContext());
|
||||
|
||||
// OpenCL patterns
|
||||
|
@ -109,6 +109,7 @@ void populateMathToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
|
|||
spirv::ElementwiseOpPattern<math::ErfOp, spirv::OCLErfOp>,
|
||||
spirv::ElementwiseOpPattern<math::ExpOp, spirv::OCLExpOp>,
|
||||
spirv::ElementwiseOpPattern<math::FloorOp, spirv::OCLFloorOp>,
|
||||
spirv::ElementwiseOpPattern<math::FmaOp, spirv::OCLFmaOp>,
|
||||
spirv::ElementwiseOpPattern<math::LogOp, spirv::OCLLogOp>,
|
||||
spirv::ElementwiseOpPattern<math::PowFOp, spirv::OCLPowOp>,
|
||||
spirv::ElementwiseOpPattern<math::RsqrtOp, spirv::OCLRsqrtOp>,
|
||||
|
|
|
@ -76,7 +76,7 @@ func @float32_binary_vector(%lhs: vector<4xf32>, %rhs: vector<4xf32>) {
|
|||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @float32_ternary_scalar
|
||||
// CHECK-LABEL: @float32_ternary_scalar
|
||||
func @float32_ternary_scalar(%a: f32, %b: f32, %c: f32) {
|
||||
// CHECK: spv.GLSL.Fma %{{.*}}: f32
|
||||
%0 = math.fma %a, %b, %c : f32
|
||||
|
|
|
@ -78,4 +78,19 @@ func @float32_binary_vector(%lhs: vector<4xf32>, %rhs: vector<4xf32>) {
|
|||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @float32_ternary_scalar
|
||||
func @float32_ternary_scalar(%a: f32, %b: f32, %c: f32) {
|
||||
// CHECK: spv.OCL.fma %{{.*}}: f32
|
||||
%0 = math.fma %a, %b, %c : f32
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @float32_ternary_vector
|
||||
func @float32_ternary_vector(%a: vector<4xf32>, %b: vector<4xf32>,
|
||||
%c: vector<4xf32>) {
|
||||
// CHECK: spv.OCL.fma %{{.*}}: vector<4xf32>
|
||||
%0 = math.fma %a, %b, %c : vector<4xf32>
|
||||
return
|
||||
}
|
||||
|
||||
} // end module
|
||||
|
|
|
@ -166,3 +166,22 @@ func @sabs(%arg0 : i32) -> () {
|
|||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// spv.OCL.fma
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
func @fma(%a : f32, %b : f32, %c : f32) -> () {
|
||||
// CHECK: spv.OCL.fma {{%[^,]*}}, {{%[^,]*}}, {{%[^,]*}} : f32
|
||||
%2 = spv.OCL.fma %a, %b, %c : f32
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @fma(%a : vector<3xf32>, %b : vector<3xf32>, %c : vector<3xf32>) -> () {
|
||||
// CHECK: spv.OCL.fma {{%[^,]*}}, {{%[^,]*}}, {{%[^,]*}} : vector<3xf32>
|
||||
%2 = spv.OCL.fma %a, %b, %c : vector<3xf32>
|
||||
return
|
||||
}
|
||||
|
|
|
@ -38,4 +38,10 @@ spv.module Physical64 OpenCL requires #spv.vce<v1.0, [Kernel, Addresses], []> {
|
|||
%0 = spv.OCL.fabs %arg0 : vector<16xf32>
|
||||
spv.Return
|
||||
}
|
||||
|
||||
spv.func @fma(%arg0 : f32, %arg1 : f32, %arg2 : f32) "None" {
|
||||
// CHECK: spv.OCL.fma {{%[^,]*}}, {{%[^,]*}}, {{%[^,]*}} : f32
|
||||
%13 = spv.OCL.fma %arg0, %arg1, %arg2 : f32
|
||||
spv.Return
|
||||
}
|
||||
}
|
||||
|
|
|
@ -51,7 +51,7 @@ def get_spirv_doc_from_html_spec(url, settings):
|
|||
doc = {}
|
||||
|
||||
if settings.gen_ocl_ops:
|
||||
section_anchor = spirv.find('h2', {'id': '_a_id_binary_a_binary_form'})
|
||||
section_anchor = spirv.find('h2', {'id': '_binary_form'})
|
||||
for section in section_anchor.parent.find_all('div', {'class': 'sect2'}):
|
||||
for table in section.find_all('table'):
|
||||
inst_html = table.tbody.tr.td
|
||||
|
|
Loading…
Reference in New Issue