[mlir] Add math polynomial approximation pass

This gives ~30x speedup compared to expanding Tanh into exp operations:

```
name                  old cpu/op  new cpu/op  delta
BM_mlir_Tanh_f32/10    253ns ± 3%    55ns ± 7%  -78.35%  (p=0.000 n=44+41)
BM_mlir_Tanh_f32/100  2.21µs ± 4%  0.14µs ± 8%  -93.85%  (p=0.000 n=48+49)
BM_mlir_Tanh_f32/1k   22.6µs ± 4%   0.7µs ± 5%  -96.68%  (p=0.000 n=32+42)
BM_mlir_Tanh_f32/10k   225µs ± 5%     7µs ± 6%  -96.88%  (p=0.000 n=49+55)

name                  old time/op             new time/op             delta
BM_mlir_Tanh_f32/10    259ns ± 1%               56ns ± 2%  -78.31%        (p=0.000 n=41+39)
BM_mlir_Tanh_f32/100  2.27µs ± 1%             0.14µs ± 5%  -93.89%        (p=0.000 n=46+49)
BM_mlir_Tanh_f32/1k   22.9µs ± 1%              0.8µs ± 4%  -96.67%        (p=0.000 n=30+42)
BM_mlir_Tanh_f32/10k   230µs ± 0%                7µs ± 3%  -96.88%        (p=0.000 n=37+55)
```
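
For reference, the headline ~30x comes from the larger sizes worked out from the table above: 22.6µs / 0.7µs ≈ 32x at 1k elements and 225µs / 7µs ≈ 32x at 10k elements, while the 10-element case sees about 4.6x (253ns / 55ns).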

This approximation is based on the Eigen::generic_fast_tanh function.
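
For readers unfamiliar with the Eigen kernel, the approximation implemented in PolynomialApproximation.cpp below is, schematically (coefficients as in the patch):

```
\tanh(x) \approx
  \begin{cases}
    x, & |x| < 0.0004,\\[4pt]
    \dfrac{x\,(\alpha_1 + \alpha_3 x^2 + \alpha_5 x^4 + \alpha_7 x^6
               + \alpha_9 x^8 + \alpha_{11} x^{10} + \alpha_{13} x^{12})}
          {\beta_0 + \beta_2 x^2 + \beta_4 x^4 + \beta_6 x^6},
      & \text{otherwise, with } x \text{ clamped to } [-7.9053111,\, 7.9053111],
  \end{cases}
```

with both polynomials evaluated in x^2 via Horner's scheme using fused multiply-adds.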

Reviewed By: mehdi_amini

Differential Revision: https://reviews.llvm.org/D96739
Eugene Zhulenev 2021-02-18 16:24:56 -08:00
parent 0923a60ea7
commit f99ccf6516
8 changed files with 294 additions and 0 deletions


@@ -19,6 +19,9 @@ class OwningRewritePatternList;
void populateExpandTanhPattern(OwningRewritePatternList &patterns,
                               MLIRContext *ctx);

void populateMathPolynomialApproximationPatterns(
    OwningRewritePatternList &patterns, MLIRContext *ctx);
} // namespace mlir
#endif // MLIR_DIALECT_MATH_TRANSFORMS_PASSES_H_


@@ -1,5 +1,6 @@
add_mlir_dialect_library(MLIRMathTransforms
  ExpandTanh.cpp
  PolynomialApproximation.cpp

  ADDITIONAL_HEADER_DIRS
  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Math/Transforms


@@ -0,0 +1,194 @@
//===- PolynomialApproximation.cpp - Approximate math operations ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements expansion of math operations to fast approximations
// that do not rely on any of the library functions.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/Math/Transforms/Passes.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
using namespace mlir;
using namespace mlir::vector;
static bool isValidFloatType(Type type) {
if (auto vectorType = type.dyn_cast<VectorType>())
return vectorType.getElementType().isa<FloatType>();
return type.isa<FloatType>();
}
//----------------------------------------------------------------------------//
// A PatternRewriter wrapper that provides concise API for building expansions
// for operations on float scalars or vectors.
//----------------------------------------------------------------------------//
namespace {
class FloatApproximationBuilder {
public:
  FloatApproximationBuilder(Location loc, Type type, PatternRewriter &rewriter);

  Value constant(double value) const;

  Value abs(Value a) const;
  Value min(Value a, Value b) const;
  Value max(Value a, Value b) const;
  Value mul(Value a, Value b) const;
  Value div(Value a, Value b) const;

  // Fused multiply-add operation: a * b + c.
  Value madd(Value a, Value b, Value c) const;

  // Compares values `a` and `b` with the given `predicate`.
  Value cmp(CmpFPredicate predicate, Value a, Value b) const;

  // Selects values from `a` or `b` based on the `predicate`.
  Value select(Value predicate, Value a, Value b) const;

private:
  Location loc;
  PatternRewriter &rewriter;
  VectorType vectorType; // can be null for scalar type
  FloatType elementType;
};
} // namespace
FloatApproximationBuilder::FloatApproximationBuilder(Location loc, Type type,
                                                     PatternRewriter &rewriter)
    : loc(loc), rewriter(rewriter) {
  vectorType = type.dyn_cast<VectorType>();
  if (vectorType)
    elementType = vectorType.getElementType().cast<FloatType>();
  else
    elementType = type.cast<FloatType>();
}

Value FloatApproximationBuilder::constant(double value) const {
  auto attr = rewriter.getFloatAttr(elementType, value);
  Value scalar = rewriter.create<ConstantOp>(loc, attr);
  if (vectorType)
    return rewriter.create<BroadcastOp>(loc, vectorType, scalar);
  return scalar;
}

Value FloatApproximationBuilder::abs(Value a) const {
  return rewriter.create<AbsFOp>(loc, a);
}

Value FloatApproximationBuilder::min(Value a, Value b) const {
  return select(cmp(CmpFPredicate::OLT, a, b), a, b);
}

Value FloatApproximationBuilder::max(Value a, Value b) const {
  return select(cmp(CmpFPredicate::OGT, a, b), a, b);
}

Value FloatApproximationBuilder::mul(Value a, Value b) const {
  return rewriter.create<MulFOp>(loc, a, b);
}

Value FloatApproximationBuilder::div(Value a, Value b) const {
  return rewriter.create<DivFOp>(loc, a, b);
}

Value FloatApproximationBuilder::madd(Value a, Value b, Value c) const {
  return rewriter.create<FmaFOp>(loc, a, b, c);
}

Value FloatApproximationBuilder::cmp(CmpFPredicate predicate, Value a,
                                     Value b) const {
  return rewriter.create<CmpFOp>(loc, predicate, a, b);
}

Value FloatApproximationBuilder::select(Value predicate, Value a,
                                        Value b) const {
  return rewriter.create<SelectOp>(loc, predicate, a, b);
}
//----------------------------------------------------------------------------//
// TanhOp approximation.
//----------------------------------------------------------------------------//
namespace {
struct TanhApproximation : public OpRewritePattern<math::TanhOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(math::TanhOp op,
                                PatternRewriter &rewriter) const final;
};
} // namespace
LogicalResult
TanhApproximation::matchAndRewrite(math::TanhOp op,
                                   PatternRewriter &rewriter) const {
  if (!isValidFloatType(op.operand().getType()))
    return rewriter.notifyMatchFailure(op, "unsupported operand type");

  Value operand = op.operand();
  FloatApproximationBuilder builder(op->getLoc(), operand.getType(), rewriter);

  // Clamp operand into [minusClamp, plusClamp] range.
  Value plusClamp = builder.constant(7.90531110763549805);
  Value minusClamp = builder.constant(-7.9053111076354980);
  Value x = builder.max(builder.min(operand, plusClamp), minusClamp);

  // Mask for tiny values that are approximated with `operand`.
  Value tiny = builder.constant(0.0004f);
  Value tinyMask = builder.cmp(CmpFPredicate::OLT, builder.abs(operand), tiny);

  // The monomial coefficients of the numerator polynomial (odd).
  Value alpha1 = builder.constant(4.89352455891786e-03);
  Value alpha3 = builder.constant(6.37261928875436e-04);
  Value alpha5 = builder.constant(1.48572235717979e-05);
  Value alpha7 = builder.constant(5.12229709037114e-08);
  Value alpha9 = builder.constant(-8.60467152213735e-11);
  Value alpha11 = builder.constant(2.00018790482477e-13);
  Value alpha13 = builder.constant(-2.76076847742355e-16);

  // The monomial coefficients of the denominator polynomial (even).
  Value beta0 = builder.constant(4.89352518554385e-03);
  Value beta2 = builder.constant(2.26843463243900e-03);
  Value beta4 = builder.constant(1.18534705686654e-04);
  Value beta6 = builder.constant(1.19825839466702e-06);

  // Since the polynomials are odd/even, we need x^2.
  Value x2 = builder.mul(x, x);

  // Evaluate the numerator polynomial p.
  Value p = builder.madd(x2, alpha13, alpha11);
  p = builder.madd(x2, p, alpha9);
  p = builder.madd(x2, p, alpha7);
  p = builder.madd(x2, p, alpha5);
  p = builder.madd(x2, p, alpha3);
  p = builder.madd(x2, p, alpha1);
  p = builder.mul(x, p);

  // Evaluate the denominator polynomial q.
  Value q = builder.madd(x2, beta6, beta4);
  q = builder.madd(x2, q, beta2);
  q = builder.madd(x2, q, beta0);

  // Divide the numerator by the denominator.
  Value res = builder.select(tinyMask, x, builder.div(p, q));

  rewriter.replaceOp(op, res);

  return success();
}
//----------------------------------------------------------------------------//
void mlir::populateMathPolynomialApproximationPatterns(
OwningRewritePatternList &patterns, MLIRContext *ctx) {
patterns.insert<TanhApproximation>(ctx);
}


@@ -0,0 +1,15 @@
// RUN: mlir-opt %s -test-math-polynomial-approximation | FileCheck %s

// CHECK-LABEL: @tanh_scalar
func @tanh_scalar(%arg0: f32) -> f32 {
  // CHECK-NOT: tanh
  %0 = math.tanh %arg0 : f32
  return %0 : f32
}

// CHECK-LABEL: @tanh_vector
func @tanh_vector(%arg0: vector<8xf32>) -> vector<8xf32> {
  // CHECK-NOT: tanh
  %0 = math.tanh %arg0 : vector<8xf32>
  return %0 : vector<8xf32>
}


@@ -26,6 +26,7 @@ add_mlir_library(MLIRTestTransforms
  TestLoopUnrolling.cpp
  TestNumberOfExecutions.cpp
  TestOpaqueLoc.cpp
  TestPolynomialApproximation.cpp
  TestMemRefBoundCheck.cpp
  TestMemRefDependenceCheck.cpp
  TestMemRefStrideCalculation.cpp


@@ -0,0 +1,46 @@
//===- TestPolynomialApproximation.cpp - Test math ops approximations -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains test passes for expanding math operations into
// polynomial approximations.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/Math/Transforms/Passes.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
using namespace mlir;
namespace {
struct TestMathPolynomialApproximationPass
    : public PassWrapper<TestMathPolynomialApproximationPass, FunctionPass> {
  void runOnFunction() override;

  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<vector::VectorDialect, math::MathDialect>();
  }
};
} // end anonymous namespace

void TestMathPolynomialApproximationPass::runOnFunction() {
  OwningRewritePatternList patterns;
  populateMathPolynomialApproximationPatterns(patterns, &getContext());
  (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}

namespace mlir {
namespace test {
void registerTestMathPolynomialApproximationPass() {
  PassRegistration<TestMathPolynomialApproximationPass> pass(
      "test-math-polynomial-approximation",
      "Test math polynomial approximations");
}
} // namespace test
} // namespace mlir


@@ -0,0 +1,32 @@
// RUN: mlir-opt %s -test-math-polynomial-approximation \
// RUN: -convert-vector-to-llvm \
// RUN: -convert-std-to-llvm \
// RUN: | mlir-cpu-runner \
// RUN: -e main -entry-point-result=void -O0 \
// RUN: -shared-libs=%linalg_test_lib_dir/libmlir_c_runner_utils%shlibext \
// RUN: -shared-libs=%linalg_test_lib_dir/libmlir_runner_utils%shlibext \
// RUN: | FileCheck %s
func @main() {
  // ------------------------------------------------------------------------ //
  // Tanh.
  // ------------------------------------------------------------------------ //

  // CHECK: 0.848284
  %0 = constant 1.25 : f32
  %1 = math.tanh %0 : f32
  vector.print %1 : f32

  // CHECK: 0.244919, 0.635149, 0.761594, 0.848284
  %2 = constant dense<[0.25, 0.75, 1.0, 1.25]> : vector<4xf32>
  %3 = math.tanh %2 : vector<4xf32>
  vector.print %3 : vector<4xf32>

  // CHECK: 0.099668, 0.197375, 0.291313, 0.379949, 0.462117, 0.53705, 0.604368, 0.664037
  %4 = constant dense<[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]> : vector<8xf32>
  %5 = math.tanh %4 : vector<8xf32>
  vector.print %5 : vector<8xf32>

  return
}
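
As a side note (not part of the patch), the CHECK constants in the integration test above match libm's tanh to six significant digits; a throwaway snippet like the following reproduces them:

```
// Hypothetical helper, not part of the patch: regenerate the CHECK values
// above by printing libm's tanh to 6 significant digits (printf %g default).
#include <cmath>
#include <cstdio>

int main() {
  const double inputs[] = {1.25,                                    // scalar case
                           0.25, 0.75, 1.0, 1.25,                   // vector<4xf32>
                           0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8}; // vector<8xf32>
  for (double x : inputs)
    std::printf("tanh(%g) = %g\n", x, std::tanh(x));
  return 0;
}
```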


@@ -84,6 +84,7 @@ void registerTestLivenessPass();
void registerTestLoopFusion();
void registerTestLoopMappingPass();
void registerTestLoopUnrollingPass();
void registerTestMathPolynomialApproximationPass();
void registerTestMemRefDependenceCheck();
void registerTestMemRefStrideCalculation();
void registerTestNumberOfBlockExecutionsPass();
@@ -157,6 +158,7 @@ void registerTestPasses() {
  test::registerTestLoopFusion();
  test::registerTestLoopMappingPass();
  test::registerTestLoopUnrollingPass();
  test::registerTestMathPolynomialApproximationPass();
  test::registerTestMemRefDependenceCheck();
  test::registerTestMemRefStrideCalculation();
  test::registerTestNumberOfBlockExecutionsPass();