forked from OSchip/llvm-project
[mlir] Add polynomial approximation for math::ExpM1
This approximation matches the one in Eigen. ``` name old cpu/op new cpu/op delta BM_mlir_Expm1_f32/10 90.9ns ± 4% 52.2ns ± 4% -42.60% (p=0.000 n=74+87) BM_mlir_Expm1_f32/100 837ns ± 3% 231ns ± 4% -72.43% (p=0.000 n=79+69) BM_mlir_Expm1_f32/1k 8.43µs ± 3% 1.58µs ± 5% -81.30% (p=0.000 n=77+83) BM_mlir_Expm1_f32/10k 83.8µs ± 3% 15.4µs ± 5% -81.65% (p=0.000 n=83+69) BM_eigen_s_Expm1_f32/10 68.8ns ±17% 72.5ns ±14% +5.40% (p=0.000 n=118+115) BM_eigen_s_Expm1_f32/100 694ns ±11% 717ns ± 2% +3.34% (p=0.000 n=120+75) BM_eigen_s_Expm1_f32/1k 7.69µs ± 2% 7.97µs ±11% +3.56% (p=0.000 n=95+117) BM_eigen_s_Expm1_f32/10k 88.0µs ± 1% 89.3µs ± 6% +1.45% (p=0.000 n=74+106) BM_eigen_v_Expm1_f32/10 44.3ns ± 6% 45.0ns ± 8% +1.45% (p=0.018 n=81+111) BM_eigen_v_Expm1_f32/100 351ns ± 1% 360ns ± 9% +2.58% (p=0.000 n=73+99) BM_eigen_v_Expm1_f32/1k 3.31µs ± 1% 3.42µs ± 9% +3.37% (p=0.000 n=71+100) BM_eigen_v_Expm1_f32/10k 33.7µs ± 8% 34.1µs ± 9% +1.04% (p=0.007 n=99+98) ``` Reviewed By: ezhulenev Differential Revision: https://reviews.llvm.org/D101852
This commit is contained in:
parent
a11489ae3e
commit
0edc4bc84a
|
@ -576,10 +576,65 @@ ExpApproximation::matchAndRewrite(math::ExpOp op,
|
|||
return success();
|
||||
}
|
||||
|
||||
//----------------------------------------------------------------------------//
|
||||
// ExpM1 approximation.
|
||||
//----------------------------------------------------------------------------//
|
||||
|
||||
namespace {
|
||||
|
||||
struct ExpM1Approximation : public OpRewritePattern<math::ExpM1Op> {
|
||||
public:
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(math::ExpM1Op op,
|
||||
PatternRewriter &rewriter) const final;
|
||||
};
|
||||
} // namespace
|
||||
|
||||
LogicalResult
|
||||
ExpM1Approximation::matchAndRewrite(math::ExpM1Op op,
|
||||
PatternRewriter &rewriter) const {
|
||||
auto width = vectorWidth(op.operand().getType(), isF32);
|
||||
if (!width.hasValue())
|
||||
return rewriter.notifyMatchFailure(op, "unsupported operand type");
|
||||
|
||||
ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
|
||||
auto bcast = [&](Value value) -> Value {
|
||||
return broadcast(builder, value, *width);
|
||||
};
|
||||
|
||||
// expm1(x) = exp(x) - 1 = u - 1.
|
||||
// We have to handle it carefully when x is near 0, i.e. u ~= 1,
|
||||
// and when the input is ~= -inf, i.e. u - 1 ~= -1.
|
||||
Value cstOne = bcast(f32Cst(builder, 1.0f));
|
||||
Value cstNegOne = bcast(f32Cst(builder, -1.0f));
|
||||
Value x = op.operand();
|
||||
Value u = builder.create<math::ExpOp>(x);
|
||||
Value uEqOne = builder.create<CmpFOp>(CmpFPredicate::OEQ, u, cstOne);
|
||||
Value uMinusOne = builder.create<SubFOp>(u, cstOne);
|
||||
Value uMinusOneEqNegOne =
|
||||
builder.create<CmpFOp>(CmpFPredicate::OEQ, uMinusOne, cstNegOne);
|
||||
// logU = log(u) ~= x
|
||||
Value logU = builder.create<math::LogOp>(u);
|
||||
|
||||
// Detect exp(x) = +inf; written this way to avoid having to form +inf.
|
||||
Value isInf = builder.create<CmpFOp>(CmpFPredicate::OEQ, logU, u);
|
||||
|
||||
// (u - 1) * (x / ~x)
|
||||
Value expm1 =
|
||||
builder.create<MulFOp>(uMinusOne, builder.create<DivFOp>(x, logU));
|
||||
expm1 = builder.create<SelectOp>(isInf, u, expm1);
|
||||
Value approximation = builder.create<SelectOp>(
|
||||
uEqOne, x, builder.create<SelectOp>(uMinusOneEqNegOne, cstNegOne, expm1));
|
||||
rewriter.replaceOp(op, approximation);
|
||||
return success();
|
||||
}
|
||||
|
||||
//----------------------------------------------------------------------------//
|
||||
|
||||
void mlir::populateMathPolynomialApproximationPatterns(
|
||||
RewritePatternSet &patterns) {
|
||||
patterns.add<TanhApproximation, LogApproximation, Log2Approximation,
|
||||
Log1pApproximation, ExpApproximation>(patterns.getContext());
|
||||
Log1pApproximation, ExpApproximation, ExpM1Approximation>(
|
||||
patterns.getContext());
|
||||
}
|
||||
|
|
|
@ -11,7 +11,10 @@ func @scalar(%arg0: f32) -> f32 {
|
|||
%1 = math.log %0 : f32
|
||||
%2 = math.log2 %1 : f32
|
||||
%3 = math.log1p %2 : f32
|
||||
return %3 : f32
|
||||
// CHECK-NOT: exp
|
||||
%4 = math.exp %3 : f32
|
||||
%5 = math.expm1 %4 : f32
|
||||
return %5 : f32
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @vector
|
||||
|
@ -22,18 +25,8 @@ func @vector(%arg0: vector<8xf32>) -> vector<8xf32> {
|
|||
%1 = math.log %0 : vector<8xf32>
|
||||
%2 = math.log2 %1 : vector<8xf32>
|
||||
%3 = math.log1p %2 : vector<8xf32>
|
||||
return %3 : vector<8xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @exp_scalar
|
||||
func @exp_scalar(%arg0: f32) -> f32 {
|
||||
%0 = math.exp %arg0 : f32
|
||||
return %0 : f32
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @exp_vector
|
||||
func @exp_vector(%arg0: vector<8xf32>) -> vector<8xf32> {
|
||||
// CHECK-NOT: math.exp
|
||||
%0 = math.exp %arg0 : vector<8xf32>
|
||||
return %0 : vector<8xf32>
|
||||
// CHECK-NOT: exp
|
||||
%4 = math.exp %3 : vector<8xf32>
|
||||
%5 = math.expm1 %4 : vector<8xf32>
|
||||
return %5 : vector<8xf32>
|
||||
}
|
||||
|
|
|
@ -186,11 +186,46 @@ func @exp() {
|
|||
return
|
||||
}
|
||||
|
||||
func @expm1() {
|
||||
// CHECK: 1e-10
|
||||
%0 = constant 1.0e-10 : f32
|
||||
%1 = math.expm1 %0 : f32
|
||||
vector.print %1 : f32
|
||||
|
||||
// CHECK: -0.00995016, 0.0100502, 0.648721, 6.38905
|
||||
%2 = constant dense<[-0.01, 0.01, 0.5, 2.0]> : vector<4xf32>
|
||||
%3 = math.expm1 %2 : vector<4xf32>
|
||||
vector.print %3 : vector<4xf32>
|
||||
|
||||
// CHECK: -0.181269, 0, 0.221403, 0.491825, 0.822119, 1.22554, 1.71828, 2.32012
|
||||
%4 = constant dense<[-0.2, 0.0, 0.2, 0.4, 0.6, 0.8, 1.0, 1.2]> : vector<8xf32>
|
||||
%5 = math.expm1 %4 : vector<8xf32>
|
||||
vector.print %5 : vector<8xf32>
|
||||
|
||||
// CHECK: -1
|
||||
%neg_inf = constant 0xff800000 : f32
|
||||
%expm1_neg_inf = math.expm1 %neg_inf : f32
|
||||
vector.print %expm1_neg_inf : f32
|
||||
|
||||
// CHECK: inf
|
||||
%inf = constant 0x7f800000 : f32
|
||||
%expm1_inf = math.expm1 %inf : f32
|
||||
vector.print %expm1_inf : f32
|
||||
|
||||
// CHECK: -1, inf, 1e-10
|
||||
%special_vec = constant dense<[0xff800000, 0x7f800000, 1.0e-10]> : vector<3xf32>
|
||||
%log_special_vec = math.expm1 %special_vec : vector<3xf32>
|
||||
vector.print %log_special_vec : vector<3xf32>
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func @main() {
|
||||
call @tanh(): () -> ()
|
||||
call @log(): () -> ()
|
||||
call @log2(): () -> ()
|
||||
call @log1p(): () -> ()
|
||||
call @exp(): () -> ()
|
||||
call @expm1(): () -> ()
|
||||
return
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue