[mlir][spirv] Add path for math.round to spirv for OCL and GLSL

OpenCL's round function matches `math.round`, so we can lower directly to
the op; this includes adding the op definition to the SPIRV OCL ops.
GLSL does not guarantee the rounding direction, so we include custom rounding
code to guarantee the correct rounding direction.

Reviewed By: antiagainst

Differential Revision: https://reviews.llvm.org/D129236
This commit is contained in:
Robert Suderman 2022-07-07 19:19:13 +00:00
parent d3712b0852
commit b9e642afd1
4 changed files with 137 additions and 31 deletions

View File

@ -110,36 +110,6 @@ class SPV_OCLTernaryArithmeticOp<string mnemonic, int opcode, Type type,
}
// -----
// spv.OCL.fma: ternary arithmetic op over SPV_Float, instantiated with the
// mnemonic "fma" and opcode 26 (see SPV_OCLTernaryArithmeticOp's
// `mnemonic` / `opcode` parameters).
def SPV_OCLFmaOp : SPV_OCLTernaryArithmeticOp<"fma", 26, SPV_Float> {
  let summary = [{
    Compute the correctly rounded floating-point representation of the sum
    of c with the infinitely precise product of a and b. Rounding of
    intermediate products shall not occur. Edge case results are per the
    IEEE 754-2008 standard.
  }];

  let description = [{
    Result Type, a, b and c must be floating-point or vector(2,3,4,8,16) of
    floating-point values.
    All of the operands, including the Result Type operand, must be of the
    same type.
    <!-- End of AutoGen section -->

    ```
    fma-op ::= ssa-id `=` `spv.OCL.fma` ssa-use, ssa-use, ssa-use `:`
               float-scalar-vector-type
    ```

    #### Example:

    ```mlir
    %0 = spv.OCL.fma %a, %b, %c : f32
    %1 = spv.OCL.fma %a, %b, %c : vector<3xf16>
    ```
  }];
}
// -----
@ -331,6 +301,37 @@ def SPV_OCLFloorOp : SPV_OCLUnaryArithmeticOp<"floor", 25, SPV_Float> {
// -----
// spv.OCL.fma: ternary arithmetic op over SPV_Float, instantiated with the
// mnemonic "fma" and opcode 26 (see SPV_OCLTernaryArithmeticOp's
// `mnemonic` / `opcode` parameters).
def SPV_OCLFmaOp : SPV_OCLTernaryArithmeticOp<"fma", 26, SPV_Float> {
  let summary = [{
    Compute the correctly rounded floating-point representation of the sum
    of c with the infinitely precise product of a and b. Rounding of
    intermediate products shall not occur. Edge case results are per the
    IEEE 754-2008 standard.
  }];

  let description = [{
    Result Type, a, b and c must be floating-point or vector(2,3,4,8,16) of
    floating-point values.
    All of the operands, including the Result Type operand, must be of the
    same type.
    <!-- End of AutoGen section -->

    ```
    fma-op ::= ssa-id `=` `spv.OCL.fma` ssa-use, ssa-use, ssa-use `:`
               float-scalar-vector-type
    ```

    #### Example:

    ```mlir
    %0 = spv.OCL.fma %a, %b, %c : f32
    %1 = spv.OCL.fma %a, %b, %c : vector<3xf16>
    ```
  }];
}
// -----
def SPV_OCLLogOp : SPV_OCLUnaryArithmeticOp<"log", 37, SPV_Float> {
let summary = "Compute the natural logarithm of x.";
@ -392,6 +393,38 @@ def SPV_OCLPowOp : SPV_OCLBinaryArithmeticOp<"pow", 48, SPV_Float> {
// -----
// spv.OCL.round: unary arithmetic op over SPV_Float, instantiated with the
// mnemonic "round" and the number 55 (presumably the OpenCL extended
// instruction opcode -- confirm against SPV_OCLUnaryArithmeticOp's
// parameter list, which is not visible here).
def SPV_OCLRoundOp : SPV_OCLUnaryArithmeticOp<"round", 55, SPV_Float> {
let summary = [{
Return the integral value nearest to x rounding halfway cases away from
zero, regardless of the current rounding direction.
}];
let description = [{
Result Type and x must be floating-point or vector(2,3,4,8,16) of
floating-point values.
All of the operands, including the Result Type operand, must be of the
same type.
<!-- End of AutoGen section -->
```
float-scalar-vector-type ::= float-type |
`vector<` integer-literal `x` float-type `>`
round-op ::= ssa-id `=` `spv.OCL.round` ssa-use `:`
float-scalar-vector-type
```
#### Example:
```mlir
%2 = spv.OCL.round %0 : f32
%3 = spv.OCL.round %0 : vector<3xf16>
```
}];
}
// -----
def SPV_OCLRsqrtOp : SPV_OCLUnaryArithmeticOp<"rsqrt", 56, SPV_Float> {
let summary = "Compute inverse square root of x.";

View File

@ -16,6 +16,7 @@
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/Support/Debug.h"
@ -233,6 +234,43 @@ struct PowFOpPattern final : public OpConversionPattern<math::PowFOp> {
}
};
/// Converts math.round to GLSL SPIRV extended ops.
///
/// The rounding is expanded by hand instead of mapping to a single GLSL
/// rounding instruction, so that halfway cases are always rounded away from
/// zero. The emitted sequence computes
///
///   copysign(floor(|x|) + (|x| - floor(|x|) >= 0.5 ? 1 : 0), x)
///
/// for scalar and vector operands alike. The trailing math.copysign is emitted
/// as a math op and is expected to be legalized by a separate pattern.
struct RoundOpPattern final : public OpConversionPattern<math::RoundOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(math::RoundOp roundOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = roundOp.getLoc();

    // Build the replacement from the remapped operand supplied by the
    // conversion framework rather than from the original op's operand, so the
    // new ops (and the constants derived from the operand type) use the
    // already-converted value and type.
    Value operand = adaptor.getOperand();
    Type ty = operand.getType();
    Type ety = getElementTypeOrSelf(ty);

    auto zero = spirv::ConstantOp::getZero(ty, loc, rewriter);
    auto one = spirv::ConstantOp::getOne(ty, loc, rewriter);

    // The 0.5 constant needs a splat attribute for vectors and a plain float
    // attribute for scalars.
    Value half;
    if (VectorType vty = ty.dyn_cast<VectorType>()) {
      half = rewriter.create<spirv::ConstantOp>(
          loc, vty,
          DenseElementsAttr::get(vty,
                                 rewriter.getFloatAttr(ety, 0.5).getValue()));
    } else {
      half = rewriter.create<spirv::ConstantOp>(
          loc, ty, rewriter.getFloatAttr(ety, 0.5));
    }

    // Work on the magnitude: frac = |x| - floor(|x|).
    auto abs = rewriter.create<spirv::GLSLFAbsOp>(loc, operand);
    auto floor = rewriter.create<spirv::GLSLFloorOp>(loc, abs);
    auto sub = rewriter.create<spirv::FSubOp>(loc, abs, floor);

    // Bump the floored magnitude by one when frac >= 0.5.
    auto greater =
        rewriter.create<spirv::FOrdGreaterThanEqualOp>(loc, sub, half);
    auto select = rewriter.create<spirv::SelectOp>(loc, greater, one, zero);
    auto add = rewriter.create<spirv::FAddOp>(loc, floor, select);

    // Restore the sign of the input on the rounded magnitude.
    rewriter.replaceOpWithNewOp<math::CopySignOp>(roundOp, add, operand);
    return success();
  }
};
} // namespace
//===----------------------------------------------------------------------===//
@ -248,7 +286,7 @@ void populateMathToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
// GLSL patterns
patterns
.add<CountLeadingZerosPattern, Log1pOpPattern<spirv::GLSLLogOp>,
ExpM1OpPattern<spirv::GLSLExpOp>, PowFOpPattern,
ExpM1OpPattern<spirv::GLSLExpOp>, PowFOpPattern, RoundOpPattern,
spirv::ElementwiseOpPattern<math::AbsOp, spirv::GLSLFAbsOp>,
spirv::ElementwiseOpPattern<math::CeilOp, spirv::GLSLCeilOp>,
spirv::ElementwiseOpPattern<math::CosOp, spirv::GLSLCosOp>,
@ -273,6 +311,7 @@ void populateMathToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
spirv::ElementwiseOpPattern<math::FmaOp, spirv::OCLFmaOp>,
spirv::ElementwiseOpPattern<math::LogOp, spirv::OCLLogOp>,
spirv::ElementwiseOpPattern<math::PowFOp, spirv::OCLPowOp>,
spirv::ElementwiseOpPattern<math::RoundOp, spirv::OCLRoundOp>,
spirv::ElementwiseOpPattern<math::RsqrtOp, spirv::OCLRsqrtOp>,
spirv::ElementwiseOpPattern<math::SinOp, spirv::OCLSinOp>,
spirv::ElementwiseOpPattern<math::SqrtOp, spirv::OCLSqrtOp>,

View File

@ -145,6 +145,38 @@ func.func @powf_vector(%lhs: vector<4xf32>, %rhs: vector<4xf32>) -> vector<4xf32
return %0: vector<4xf32>
}
// Scalar math.round lowering to GLSL ops. The expected output is: the
// constants 0.0, 1.0 and 0.5; |x|; floor(|x|); their difference; a
// greater-than-or-equal comparison against 0.5 that selects 1.0 or 0.0; the
// add onto the floored value; and finally a bitcast of that sum (presumably
// the start of the sign-restoring copysign lowering -- confirm against the
// copysign pattern).
// CHECK-LABEL: @round_scalar
func.func @round_scalar(%x: f32) -> f32 {
// CHECK: %[[ZERO:.+]] = spv.Constant 0.000000e+00
// CHECK: %[[ONE:.+]] = spv.Constant 1.000000e+00
// CHECK: %[[HALF:.+]] = spv.Constant 5.000000e-01
// CHECK: %[[ABS:.+]] = spv.GLSL.FAbs %arg0
// CHECK: %[[FLOOR:.+]] = spv.GLSL.Floor %[[ABS]]
// CHECK: %[[SUB:.+]] = spv.FSub %[[ABS]], %[[FLOOR]]
// CHECK: %[[GE:.+]] = spv.FOrdGreaterThanEqual %[[SUB]], %[[HALF]]
// CHECK: %[[SEL:.+]] = spv.Select %[[GE]], %[[ONE]], %[[ZERO]]
// CHECK: %[[ADD:.+]] = spv.FAdd %[[FLOOR]], %[[SEL]]
// CHECK: %[[BITCAST:.+]] = spv.Bitcast %[[ADD]]
%0 = math.round %x : f32
return %0: f32
}
// Vector math.round lowering to GLSL ops. Same expected sequence as the
// scalar case, except that the 0.0, 1.0 and 0.5 constants are matched as
// dense (splat) vector constants.
// CHECK-LABEL: @round_vector
func.func @round_vector(%x: vector<4xf32>) -> vector<4xf32> {
// CHECK: %[[ZERO:.+]] = spv.Constant dense<0.000000e+00>
// CHECK: %[[ONE:.+]] = spv.Constant dense<1.000000e+00>
// CHECK: %[[HALF:.+]] = spv.Constant dense<5.000000e-01>
// CHECK: %[[ABS:.+]] = spv.GLSL.FAbs %arg0
// CHECK: %[[FLOOR:.+]] = spv.GLSL.Floor %[[ABS]]
// CHECK: %[[SUB:.+]] = spv.FSub %[[ABS]], %[[FLOOR]]
// CHECK: %[[GE:.+]] = spv.FOrdGreaterThanEqual %[[SUB]], %[[HALF]]
// CHECK: %[[SEL:.+]] = spv.Select %[[GE]], %[[ONE]], %[[ZERO]]
// CHECK: %[[ADD:.+]] = spv.FAdd %[[FLOOR]], %[[SEL]]
// CHECK: %[[BITCAST:.+]] = spv.Bitcast %[[ADD]]
%0 = math.round %x : vector<4xf32>
return %0: vector<4xf32>
}
} // end module
// -----

View File

@ -34,6 +34,8 @@ func.func @float32_unary_scalar(%arg0: f32) {
%11 = math.floor %arg0 : f32
// CHECK: spv.OCL.erf %{{.*}}: f32
%12 = math.erf %arg0 : f32
// CHECK: spv.OCL.round %{{.*}}: f32
%13 = math.round %arg0 : f32
return
}