forked from OSchip/llvm-project
[mlir][spirv] Add path for math.round to spirv for OCL and GLSL
OpenCL's `round` function matches the semantics of `math.round`, so we can lower directly to that op; this change also adds the op definition to the SPIR-V OCL ops. GLSL does not guarantee the rounding direction, so for GLSL we emit custom rounding code that guarantees halfway cases are rounded away from zero. Reviewed By: antiagainst Differential Revision: https://reviews.llvm.org/D129236
This commit is contained in:
parent
d3712b0852
commit
b9e642afd1
|
@ -110,36 +110,6 @@ class SPV_OCLTernaryArithmeticOp<string mnemonic, int opcode, Type type,
|
|||
}
|
||||
|
||||
|
||||
// -----
|
||||
|
||||
def SPV_OCLFmaOp : SPV_OCLTernaryArithmeticOp<"fma", 26, SPV_Float> {
|
||||
let summary = [{
|
||||
Compute the correctly rounded floating-point representation of the sum
|
||||
of c with the infinitely precise product of a and b. Rounding of
|
||||
intermediate products shall not occur. Edge case results are per the
|
||||
IEEE 754-2008 standard.
|
||||
}];
|
||||
|
||||
let description = [{
|
||||
Result Type, a, b and c must be floating-point or vector(2,3,4,8,16) of
|
||||
floating-point values.
|
||||
|
||||
All of the operands, including the Result Type operand, must be of the
|
||||
same type.
|
||||
|
||||
<!-- End of AutoGen section -->
|
||||
|
||||
```
|
||||
fma-op ::= ssa-id `=` `spv.OCL.fma` ssa-use, ssa-use, ssa-use `:`
|
||||
float-scalar-vector-type
|
||||
```
|
||||
|
||||
```mlir
|
||||
%0 = spv.OCL.fma %a, %b, %c : f32
|
||||
%1 = spv.OCL.fma %a, %b, %c : vector<3xf16>
|
||||
```
|
||||
}];
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
|
@ -331,6 +301,37 @@ def SPV_OCLFloorOp : SPV_OCLUnaryArithmeticOp<"floor", 25, SPV_Float> {
|
|||
|
||||
// -----
|
||||
|
||||
def SPV_OCLFmaOp : SPV_OCLTernaryArithmeticOp<"fma", 26, SPV_Float> {
|
||||
let summary = [{
|
||||
Compute the correctly rounded floating-point representation of the sum
|
||||
of c with the infinitely precise product of a and b. Rounding of
|
||||
intermediate products shall not occur. Edge case results are per the
|
||||
IEEE 754-2008 standard.
|
||||
}];
|
||||
|
||||
let description = [{
|
||||
Result Type, a, b and c must be floating-point or vector(2,3,4,8,16) of
|
||||
floating-point values.
|
||||
|
||||
All of the operands, including the Result Type operand, must be of the
|
||||
same type.
|
||||
|
||||
<!-- End of AutoGen section -->
|
||||
|
||||
```
|
||||
fma-op ::= ssa-id `=` `spv.OCL.fma` ssa-use, ssa-use, ssa-use `:`
|
||||
float-scalar-vector-type
|
||||
```
|
||||
|
||||
```mlir
|
||||
%0 = spv.OCL.fma %a, %b, %c : f32
|
||||
%1 = spv.OCL.fma %a, %b, %c : vector<3xf16>
|
||||
```
|
||||
}];
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
def SPV_OCLLogOp : SPV_OCLUnaryArithmeticOp<"log", 37, SPV_Float> {
|
||||
let summary = "Compute the natural logarithm of x.";
|
||||
|
||||
|
@ -392,6 +393,38 @@ def SPV_OCLPowOp : SPV_OCLBinaryArithmeticOp<"pow", 48, SPV_Float> {
|
|||
|
||||
// -----
|
||||
|
||||
def SPV_OCLRoundOp : SPV_OCLUnaryArithmeticOp<"round", 55, SPV_Float> {
|
||||
let summary = [{
|
||||
Return the integral value nearest to x rounding halfway cases away from
|
||||
zero, regardless of the current rounding direction.
|
||||
}];
|
||||
|
||||
let description = [{
|
||||
Result Type and x must be floating-point or vector(2,3,4,8,16) of
|
||||
floating-point values.
|
||||
|
||||
All of the operands, including the Result Type operand, must be of the
|
||||
same type.
|
||||
|
||||
<!-- End of AutoGen section -->
|
||||
|
||||
```
|
||||
float-scalar-vector-type ::= float-type |
|
||||
`vector<` integer-literal `x` float-type `>`
|
||||
round-op ::= ssa-id `=` `spv.OCL.round` ssa-use `:`
|
||||
float-scalar-vector-type
|
||||
```
|
||||
#### Example:
|
||||
|
||||
```mlir
|
||||
%2 = spv.OCL.round %0 : f32
|
||||
%3 = spv.OCL.round %0 : vector<3xf16>
|
||||
```
|
||||
}];
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
def SPV_OCLRsqrtOp : SPV_OCLUnaryArithmeticOp<"rsqrt", 56, SPV_Float> {
|
||||
let summary = "Compute inverse square root of x.";
|
||||
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
|
||||
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/TypeUtilities.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
#include "llvm/Support/Debug.h"
|
||||
|
||||
|
@ -233,6 +234,43 @@ struct PowFOpPattern final : public OpConversionPattern<math::PowFOp> {
|
|||
}
|
||||
};
|
||||
|
||||
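/// Converts math.round to GLSL SPIR-V extended ops, rounding halfway cases
/// away from zero: takes floor(|x|), adds one when the fractional part of |x|
/// is >= 0.5, and then copies the sign of x onto the result.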
|
||||
struct RoundOpPattern final : public OpConversionPattern<math::RoundOp> {
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(math::RoundOp roundOp, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Location loc = roundOp.getLoc();
|
||||
auto operand = roundOp.getOperand();
|
||||
auto ty = operand.getType();
|
||||
auto ety = getElementTypeOrSelf(ty);
|
||||
|
||||
auto zero = spirv::ConstantOp::getZero(ty, loc, rewriter);
|
||||
auto one = spirv::ConstantOp::getOne(ty, loc, rewriter);
|
||||
Value half;
|
||||
if (VectorType vty = ty.dyn_cast<VectorType>()) {
|
||||
half = rewriter.create<spirv::ConstantOp>(
|
||||
loc, vty,
|
||||
DenseElementsAttr::get(vty,
|
||||
rewriter.getFloatAttr(ety, 0.5).getValue()));
|
||||
} else {
|
||||
half = rewriter.create<spirv::ConstantOp>(
|
||||
loc, ty, rewriter.getFloatAttr(ety, 0.5));
|
||||
}
|
||||
|
||||
auto abs = rewriter.create<spirv::GLSLFAbsOp>(loc, operand);
|
||||
auto floor = rewriter.create<spirv::GLSLFloorOp>(loc, abs);
|
||||
auto sub = rewriter.create<spirv::FSubOp>(loc, abs, floor);
|
||||
auto greater =
|
||||
rewriter.create<spirv::FOrdGreaterThanEqualOp>(loc, sub, half);
|
||||
auto select = rewriter.create<spirv::SelectOp>(loc, greater, one, zero);
|
||||
auto add = rewriter.create<spirv::FAddOp>(loc, floor, select);
|
||||
rewriter.replaceOpWithNewOp<math::CopySignOp>(roundOp, add, operand);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -248,7 +286,7 @@ void populateMathToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
|
|||
// GLSL patterns
|
||||
patterns
|
||||
.add<CountLeadingZerosPattern, Log1pOpPattern<spirv::GLSLLogOp>,
|
||||
ExpM1OpPattern<spirv::GLSLExpOp>, PowFOpPattern,
|
||||
ExpM1OpPattern<spirv::GLSLExpOp>, PowFOpPattern, RoundOpPattern,
|
||||
spirv::ElementwiseOpPattern<math::AbsOp, spirv::GLSLFAbsOp>,
|
||||
spirv::ElementwiseOpPattern<math::CeilOp, spirv::GLSLCeilOp>,
|
||||
spirv::ElementwiseOpPattern<math::CosOp, spirv::GLSLCosOp>,
|
||||
|
@ -273,6 +311,7 @@ void populateMathToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
|
|||
spirv::ElementwiseOpPattern<math::FmaOp, spirv::OCLFmaOp>,
|
||||
spirv::ElementwiseOpPattern<math::LogOp, spirv::OCLLogOp>,
|
||||
spirv::ElementwiseOpPattern<math::PowFOp, spirv::OCLPowOp>,
|
||||
spirv::ElementwiseOpPattern<math::RoundOp, spirv::OCLRoundOp>,
|
||||
spirv::ElementwiseOpPattern<math::RsqrtOp, spirv::OCLRsqrtOp>,
|
||||
spirv::ElementwiseOpPattern<math::SinOp, spirv::OCLSinOp>,
|
||||
spirv::ElementwiseOpPattern<math::SqrtOp, spirv::OCLSqrtOp>,
|
||||
|
|
|
@ -145,6 +145,38 @@ func.func @powf_vector(%lhs: vector<4xf32>, %rhs: vector<4xf32>) -> vector<4xf32
|
|||
return %0: vector<4xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @round_scalar
|
||||
func.func @round_scalar(%x: f32) -> f32 {
|
||||
// CHECK: %[[ZERO:.+]] = spv.Constant 0.000000e+00
|
||||
// CHECK: %[[ONE:.+]] = spv.Constant 1.000000e+00
|
||||
// CHECK: %[[HALF:.+]] = spv.Constant 5.000000e-01
|
||||
// CHECK: %[[ABS:.+]] = spv.GLSL.FAbs %arg0
|
||||
// CHECK: %[[FLOOR:.+]] = spv.GLSL.Floor %[[ABS]]
|
||||
// CHECK: %[[SUB:.+]] = spv.FSub %[[ABS]], %[[FLOOR]]
|
||||
// CHECK: %[[GE:.+]] = spv.FOrdGreaterThanEqual %[[SUB]], %[[HALF]]
|
||||
// CHECK: %[[SEL:.+]] = spv.Select %[[GE]], %[[ONE]], %[[ZERO]]
|
||||
// CHECK: %[[ADD:.+]] = spv.FAdd %[[FLOOR]], %[[SEL]]
|
||||
// CHECK: %[[BITCAST:.+]] = spv.Bitcast %[[ADD]]
|
||||
%0 = math.round %x : f32
|
||||
return %0: f32
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @round_vector
|
||||
func.func @round_vector(%x: vector<4xf32>) -> vector<4xf32> {
|
||||
// CHECK: %[[ZERO:.+]] = spv.Constant dense<0.000000e+00>
|
||||
// CHECK: %[[ONE:.+]] = spv.Constant dense<1.000000e+00>
|
||||
// CHECK: %[[HALF:.+]] = spv.Constant dense<5.000000e-01>
|
||||
// CHECK: %[[ABS:.+]] = spv.GLSL.FAbs %arg0
|
||||
// CHECK: %[[FLOOR:.+]] = spv.GLSL.Floor %[[ABS]]
|
||||
// CHECK: %[[SUB:.+]] = spv.FSub %[[ABS]], %[[FLOOR]]
|
||||
// CHECK: %[[GE:.+]] = spv.FOrdGreaterThanEqual %[[SUB]], %[[HALF]]
|
||||
// CHECK: %[[SEL:.+]] = spv.Select %[[GE]], %[[ONE]], %[[ZERO]]
|
||||
// CHECK: %[[ADD:.+]] = spv.FAdd %[[FLOOR]], %[[SEL]]
|
||||
// CHECK: %[[BITCAST:.+]] = spv.Bitcast %[[ADD]]
|
||||
%0 = math.round %x : vector<4xf32>
|
||||
return %0: vector<4xf32>
|
||||
}
|
||||
|
||||
} // end module
|
||||
|
||||
// -----
|
||||
|
|
|
@ -34,6 +34,8 @@ func.func @float32_unary_scalar(%arg0: f32) {
|
|||
%11 = math.floor %arg0 : f32
|
||||
// CHECK: spv.OCL.erf %{{.*}}: f32
|
||||
%12 = math.erf %arg0 : f32
|
||||
// CHECK: spv.OCL.round %{{.*}}: f32
|
||||
%13 = math.round %arg0 : f32
|
||||
return
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue