forked from OSchip/llvm-project
[mlir] Add math polynomial approximation pass
This gives ~30x speedup compared to expanding Tanh into exp operations: ``` name old cpu/op new cpu/op delta BM_mlir_Tanh_f32/10 253ns ± 3% 55ns ± 7% -78.35% (p=0.000 n=44+41) BM_mlir_Tanh_f32/100 2.21µs ± 4% 0.14µs ± 8% -93.85% (p=0.000 n=48+49) BM_mlir_Tanh_f32/1k 22.6µs ± 4% 0.7µs ± 5% -96.68% (p=0.000 n=32+42) BM_mlir_Tanh_f32/10k 225µs ± 5% 7µs ± 6% -96.88% (p=0.000 n=49+55) name old time/op new time/op delta BM_mlir_Tanh_f32/10 259ns ± 1% 56ns ± 2% -78.31% (p=0.000 n=41+39) BM_mlir_Tanh_f32/100 2.27µs ± 1% 0.14µs ± 5% -93.89% (p=0.000 n=46+49) BM_mlir_Tanh_f32/1k 22.9µs ± 1% 0.8µs ± 4% -96.67% (p=0.000 n=30+42) BM_mlir_Tanh_f32/10k 230µs ± 0% 7µs ± 3% -96.88% (p=0.000 n=37+55) ``` This approximations is based on Eigen::generic_fast_tanh function Reviewed By: mehdi_amini Differential Revision: https://reviews.llvm.org/D96739
This commit is contained in:
parent
0923a60ea7
commit
f99ccf6516
|
@ -19,6 +19,9 @@ class OwningRewritePatternList;
|
|||
void populateExpandTanhPattern(OwningRewritePatternList &patterns,
|
||||
MLIRContext *ctx);
|
||||
|
||||
void populateMathPolynomialApproximationPatterns(
|
||||
OwningRewritePatternList &patterns, MLIRContext *ctx);
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
#endif // MLIR_DIALECT_MATH_TRANSFORMS_PASSES_H_
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
add_mlir_dialect_library(MLIRMathTransforms
|
||||
ExpandTanh.cpp
|
||||
PolynomialApproximation.cpp
|
||||
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Math/Transforms
|
||||
|
|
|
@ -0,0 +1,194 @@
|
|||
//===- PolynomialApproximation.cpp - Approximate math operations ----------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This file implements expansion of math operations to fast approximations
|
||||
// that do not rely on any of the library functions.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Dialect/Math/IR/Math.h"
|
||||
#include "mlir/Dialect/Math/Transforms/Passes.h"
|
||||
#include "mlir/Dialect/Vector/VectorOps.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::vector;
|
||||
|
||||
static bool isValidFloatType(Type type) {
|
||||
if (auto vectorType = type.dyn_cast<VectorType>())
|
||||
return vectorType.getElementType().isa<FloatType>();
|
||||
return type.isa<FloatType>();
|
||||
}
|
||||
|
||||
//----------------------------------------------------------------------------//
|
||||
// A PatternRewriter wrapper that provides concise API for building expansions
|
||||
// for operations on float scalars or vectors.
|
||||
//----------------------------------------------------------------------------//
|
||||
|
||||
namespace {
|
||||
class FloatApproximationBuilder {
|
||||
public:
|
||||
FloatApproximationBuilder(Location loc, Type type, PatternRewriter &rewriter);
|
||||
|
||||
Value constant(double value) const;
|
||||
|
||||
Value abs(Value a) const;
|
||||
Value min(Value a, Value b) const;
|
||||
Value max(Value a, Value b) const;
|
||||
Value mul(Value a, Value b) const;
|
||||
Value div(Value a, Value b) const;
|
||||
|
||||
// Fused multiple-add operation: a * b + c.
|
||||
Value madd(Value a, Value b, Value c) const;
|
||||
|
||||
// Compares values `a` and `b` with the given `predicate`.
|
||||
Value cmp(CmpFPredicate predicate, Value a, Value b) const;
|
||||
|
||||
// Selects values from `a` or `b` based on the `predicate`.
|
||||
Value select(Value predicate, Value a, Value b) const;
|
||||
|
||||
private:
|
||||
Location loc;
|
||||
PatternRewriter &rewriter;
|
||||
VectorType vectorType; // can be null for scalar type
|
||||
FloatType elementType;
|
||||
};
|
||||
} // namespace
|
||||
|
||||
FloatApproximationBuilder::FloatApproximationBuilder(Location loc, Type type,
|
||||
PatternRewriter &rewriter)
|
||||
: loc(loc), rewriter(rewriter) {
|
||||
vectorType = type.dyn_cast<VectorType>();
|
||||
|
||||
if (vectorType)
|
||||
elementType = vectorType.getElementType().cast<FloatType>();
|
||||
else
|
||||
elementType = type.cast<FloatType>();
|
||||
}
|
||||
|
||||
Value FloatApproximationBuilder::constant(double value) const {
|
||||
auto attr = rewriter.getFloatAttr(elementType, value);
|
||||
Value scalar = rewriter.create<ConstantOp>(loc, attr);
|
||||
|
||||
if (vectorType)
|
||||
return rewriter.create<BroadcastOp>(loc, vectorType, scalar);
|
||||
return scalar;
|
||||
}
|
||||
|
||||
Value FloatApproximationBuilder::abs(Value a) const {
|
||||
return rewriter.create<AbsFOp>(loc, a);
|
||||
}
|
||||
|
||||
Value FloatApproximationBuilder::min(Value a, Value b) const {
|
||||
return select(cmp(CmpFPredicate::OLT, a, b), a, b);
|
||||
}
|
||||
Value FloatApproximationBuilder::max(Value a, Value b) const {
|
||||
return select(cmp(CmpFPredicate::OGT, a, b), a, b);
|
||||
}
|
||||
Value FloatApproximationBuilder::mul(Value a, Value b) const {
|
||||
return rewriter.create<MulFOp>(loc, a, b);
|
||||
}
|
||||
|
||||
Value FloatApproximationBuilder::div(Value a, Value b) const {
|
||||
return rewriter.create<DivFOp>(loc, a, b);
|
||||
}
|
||||
|
||||
Value FloatApproximationBuilder::madd(Value a, Value b, Value c) const {
|
||||
return rewriter.create<FmaFOp>(loc, a, b, c);
|
||||
}
|
||||
|
||||
Value FloatApproximationBuilder::cmp(CmpFPredicate predicate, Value a,
|
||||
Value b) const {
|
||||
return rewriter.create<CmpFOp>(loc, predicate, a, b);
|
||||
}
|
||||
|
||||
Value FloatApproximationBuilder::select(Value predicate, Value a,
|
||||
Value b) const {
|
||||
return rewriter.create<SelectOp>(loc, predicate, a, b);
|
||||
}
|
||||
|
||||
//----------------------------------------------------------------------------//
|
||||
// TanhOp approximation.
|
||||
//----------------------------------------------------------------------------//
|
||||
|
||||
namespace {
|
||||
struct TanhApproximation : public OpRewritePattern<math::TanhOp> {
|
||||
public:
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(math::TanhOp op,
|
||||
PatternRewriter &rewriter) const final;
|
||||
};
|
||||
} // namespace
|
||||
|
||||
LogicalResult
|
||||
TanhApproximation::matchAndRewrite(math::TanhOp op,
|
||||
PatternRewriter &rewriter) const {
|
||||
if (!isValidFloatType(op.operand().getType()))
|
||||
return rewriter.notifyMatchFailure(op, "unsupported operand type");
|
||||
|
||||
Value operand = op.operand();
|
||||
FloatApproximationBuilder builder(op->getLoc(), operand.getType(), rewriter);
|
||||
|
||||
// Clamp operand into [plusClamp, minusClamp] range.
|
||||
Value plusClamp = builder.constant(7.90531110763549805);
|
||||
Value minusClamp = builder.constant(-7.9053111076354980);
|
||||
Value x = builder.max(builder.min(operand, plusClamp), minusClamp);
|
||||
|
||||
// Mask for tiny values that are approximated with `operand`.
|
||||
Value tiny = builder.constant(0.0004f);
|
||||
Value tinyMask = builder.cmp(CmpFPredicate::OLT, builder.abs(operand), tiny);
|
||||
|
||||
// The monomial coefficients of the numerator polynomial (odd).
|
||||
Value alpha1 = builder.constant(4.89352455891786e-03);
|
||||
Value alpha3 = builder.constant(6.37261928875436e-04);
|
||||
Value alpha5 = builder.constant(1.48572235717979e-05);
|
||||
Value alpha7 = builder.constant(5.12229709037114e-08);
|
||||
Value alpha9 = builder.constant(-8.60467152213735e-11);
|
||||
Value alpha11 = builder.constant(2.00018790482477e-13);
|
||||
Value alpha13 = builder.constant(-2.76076847742355e-16);
|
||||
|
||||
// The monomial coefficients of the denominator polynomial (even).
|
||||
Value beta0 = builder.constant(4.89352518554385e-03);
|
||||
Value beta2 = builder.constant(2.26843463243900e-03);
|
||||
Value beta4 = builder.constant(1.18534705686654e-04);
|
||||
Value beta6 = builder.constant(1.19825839466702e-06);
|
||||
|
||||
// Since the polynomials are odd/even, we need x^2.
|
||||
Value x2 = builder.mul(x, x);
|
||||
|
||||
// Evaluate the numerator polynomial p.
|
||||
Value p = builder.madd(x2, alpha13, alpha11);
|
||||
p = builder.madd(x2, p, alpha9);
|
||||
p = builder.madd(x2, p, alpha7);
|
||||
p = builder.madd(x2, p, alpha5);
|
||||
p = builder.madd(x2, p, alpha3);
|
||||
p = builder.madd(x2, p, alpha1);
|
||||
p = builder.mul(x, p);
|
||||
|
||||
// Evaluate the denominator polynomial q.
|
||||
Value q = builder.madd(x2, beta6, beta4);
|
||||
q = builder.madd(x2, q, beta2);
|
||||
q = builder.madd(x2, q, beta0);
|
||||
|
||||
// Divide the numerator by the denominator.
|
||||
Value res = builder.select(tinyMask, x, builder.div(p, q));
|
||||
|
||||
rewriter.replaceOp(op, res);
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
//----------------------------------------------------------------------------//
|
||||
|
||||
void mlir::populateMathPolynomialApproximationPatterns(
|
||||
OwningRewritePatternList &patterns, MLIRContext *ctx) {
|
||||
patterns.insert<TanhApproximation>(ctx);
|
||||
}
|
|
@ -0,0 +1,15 @@
|
|||
// RUN: mlir-opt %s -test-math-polynomial-approximation | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: @tanh_scalar
|
||||
func @tanh_scalar(%arg0: f32) -> f32 {
|
||||
// CHECK-NOT: tanh
|
||||
%0 = math.tanh %arg0 : f32
|
||||
return %0 : f32
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @tanh_vector
|
||||
func @tanh_vector(%arg0: vector<8xf32>) -> vector<8xf32> {
|
||||
// CHECK-NOT: tanh
|
||||
%0 = math.tanh %arg0 : vector<8xf32>
|
||||
return %0 : vector<8xf32>
|
||||
}
|
|
@ -26,6 +26,7 @@ add_mlir_library(MLIRTestTransforms
|
|||
TestLoopUnrolling.cpp
|
||||
TestNumberOfExecutions.cpp
|
||||
TestOpaqueLoc.cpp
|
||||
TestPolynomialApproximation.cpp
|
||||
TestMemRefBoundCheck.cpp
|
||||
TestMemRefDependenceCheck.cpp
|
||||
TestMemRefStrideCalculation.cpp
|
||||
|
|
|
@ -0,0 +1,46 @@
|
|||
//===- TestPolynomialApproximation.cpp - Test math ops approximations -----===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This file contains test passes for expanding math operations into
|
||||
// polynomial approximations.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Dialect/Math/IR/Math.h"
|
||||
#include "mlir/Dialect/Math/Transforms/Passes.h"
|
||||
#include "mlir/Dialect/Vector/VectorOps.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace {
|
||||
struct TestMathPolynomialApproximationPass
|
||||
: public PassWrapper<TestMathPolynomialApproximationPass, FunctionPass> {
|
||||
void runOnFunction() override;
|
||||
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||
registry.insert<vector::VectorDialect, math::MathDialect>();
|
||||
}
|
||||
};
|
||||
} // end anonymous namespace
|
||||
|
||||
void TestMathPolynomialApproximationPass::runOnFunction() {
|
||||
OwningRewritePatternList patterns;
|
||||
populateMathPolynomialApproximationPatterns(patterns, &getContext());
|
||||
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
|
||||
}
|
||||
|
||||
namespace mlir {
|
||||
namespace test {
|
||||
void registerTestMathPolynomialApproximationPass() {
|
||||
PassRegistration<TestMathPolynomialApproximationPass> pass(
|
||||
"test-math-polynomial-approximation",
|
||||
"Test math polynomial approximations");
|
||||
}
|
||||
} // namespace test
|
||||
} // namespace mlir
|
|
@ -0,0 +1,32 @@
|
|||
// RUN: mlir-opt %s -test-math-polynomial-approximation \
|
||||
// RUN: -convert-vector-to-llvm \
|
||||
// RUN: -convert-std-to-llvm \
|
||||
// RUN: | mlir-cpu-runner \
|
||||
// RUN: -e main -entry-point-result=void -O0 \
|
||||
// RUN: -shared-libs=%linalg_test_lib_dir/libmlir_c_runner_utils%shlibext \
|
||||
// RUN: -shared-libs=%linalg_test_lib_dir/libmlir_runner_utils%shlibext \
|
||||
// RUN: | FileCheck %s
|
||||
|
||||
|
||||
func @main() {
|
||||
// ------------------------------------------------------------------------ //
|
||||
// Tanh.
|
||||
// ------------------------------------------------------------------------ //
|
||||
|
||||
// CHECK: 0.848284
|
||||
%0 = constant 1.25 : f32
|
||||
%1 = math.tanh %0 : f32
|
||||
vector.print %1 : f32
|
||||
|
||||
// CHECK: 0.244919, 0.635149, 0.761594, 0.848284
|
||||
%2 = constant dense<[0.25, 0.75, 1.0, 1.25]> : vector<4xf32>
|
||||
%3 = math.tanh %2 : vector<4xf32>
|
||||
vector.print %3 : vector<4xf32>
|
||||
|
||||
// CHECK: 0.099668, 0.197375, 0.291313, 0.379949, 0.462117, 0.53705, 0.604368, 0.664037
|
||||
%4 = constant dense<[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]> : vector<8xf32>
|
||||
%5 = math.tanh %4 : vector<8xf32>
|
||||
vector.print %5 : vector<8xf32>
|
||||
|
||||
return
|
||||
}
|
|
@ -84,6 +84,7 @@ void registerTestLivenessPass();
|
|||
void registerTestLoopFusion();
|
||||
void registerTestLoopMappingPass();
|
||||
void registerTestLoopUnrollingPass();
|
||||
void registerTestMathPolynomialApproximationPass();
|
||||
void registerTestMemRefDependenceCheck();
|
||||
void registerTestMemRefStrideCalculation();
|
||||
void registerTestNumberOfBlockExecutionsPass();
|
||||
|
@ -157,6 +158,7 @@ void registerTestPasses() {
|
|||
test::registerTestLoopFusion();
|
||||
test::registerTestLoopMappingPass();
|
||||
test::registerTestLoopUnrollingPass();
|
||||
test::registerTestMathPolynomialApproximationPass();
|
||||
test::registerTestMemRefDependenceCheck();
|
||||
test::registerTestMemRefStrideCalculation();
|
||||
test::registerTestNumberOfBlockExecutionsPass();
|
||||
|
|
Loading…
Reference in New Issue