[mlir] Add polynomial approximation for math::ExpOp

Similar to fast_exp in https://github.com/boulos/syrah

Differential Revision: https://reviews.llvm.org/D97599
This commit is contained in:
Ahmed Taei 2021-02-26 17:05:44 -08:00
parent 74c883f7e5
commit ea7f211b2e
3 changed files with 200 additions and 5 deletions

View File

@ -10,7 +10,6 @@
// that do not rely on any of the library functions.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/Dialect/Math/IR/Math.h"
@ -20,6 +19,7 @@
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include <limits.h>
using namespace mlir;
using namespace mlir::vector;
@ -28,6 +28,8 @@ using TypePredicate = llvm::function_ref<bool(Type)>;
static bool isF32(Type type) { return type.isF32(); }
static bool isI32(Type type) { return type.isInteger(32); }
// Returns vector width if the element type is matching the predicate (scalars
// that do match the predicate have width equal to `1`).
static Optional<int> vectorWidth(Type type, TypePredicate pred) {
@ -153,6 +155,30 @@ static std::pair<Value, Value> frexp(ImplicitLocOpBuilder &builder, Value arg,
return {normalizedFraction, exponent};
}
// Computes exp2 for an i32 argument.
static Value exp2I32(ImplicitLocOpBuilder &builder, Value arg) {
assert(isI32(elementType(arg.getType())) && "argument must be i32 type");
int width = vectorWidth(arg.getType());
auto bcast = [&](Value value) -> Value {
return broadcast(builder, value, width);
};
auto f32Vec = broadcast(builder.getF32Type(), width);
// The exponent of f32 located at 23-bit.
auto exponetBitLocation = bcast(i32Cst(builder, 23));
// Set the exponent bias to zero.
auto bias = bcast(i32Cst(builder, 127));
Value biasedArg = builder.create<AddIOp>(arg, bias);
Value exp2ValueInt =
builder.create<ShiftLeftOp>(biasedArg, exponetBitLocation);
Value exp2ValueF32 = builder.create<LLVM::BitcastOp>(f32Vec, exp2ValueInt);
return exp2ValueF32;
}
//----------------------------------------------------------------------------//
// TanhOp approximation.
//----------------------------------------------------------------------------//
@ -230,6 +256,11 @@ TanhApproximation::matchAndRewrite(math::TanhOp op,
return success();
}
#define LN2_VALUE \
0.693147180559945309417232121458176568075500134360255254120680009493393621L
#define LN2E_VALUE \
1.442695040888963407359924681001892137426645954152985934135449406931109219L
//----------------------------------------------------------------------------//
// LogOp approximation.
//----------------------------------------------------------------------------//
@ -247,9 +278,6 @@ public:
};
} // namespace
#define LN2_VALUE \
0.693147180559945309417232121458176568075500134360255254120680009493393621L
LogicalResult
LogApproximation::matchAndRewrite(math::LogOp op,
PatternRewriter &rewriter) const {
@ -353,9 +381,125 @@ LogApproximation::matchAndRewrite(math::LogOp op,
return success();
}
//----------------------------------------------------------------------------//
// Exp approximation.
//----------------------------------------------------------------------------//
namespace {
struct ExpApproximation : public OpRewritePattern<math::ExpOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(math::ExpOp op,
PatternRewriter &rewriter) const final;
};
} // namespace
// Approximate exp(x) using its reduced range exp(y) where y is in the range
// [0, ln(2)], let y = x - floor(x / ln(2)) * ln(2) = x - k * ln(2), exp(x)
// = exp(y) * 2^k. exp(y).
LogicalResult
ExpApproximation::matchAndRewrite(math::ExpOp op,
PatternRewriter &rewriter) const {
auto width = vectorWidth(op.operand().getType(), isF32);
if (!width.hasValue())
return rewriter.notifyMatchFailure(op, "unsupported operand type");
ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
// TODO: Consider a common pattern rewriter with all methods below to
// write the approximations.
auto bcast = [&](Value value) -> Value {
return broadcast(builder, value, *width);
};
auto fmla = [&](Value a, Value b, Value c) {
return builder.create<FmaFOp>(a, b, c);
};
auto mul = [&](Value a, Value b) -> Value {
return builder.create<MulFOp>(a, b);
};
auto sub = [&](Value a, Value b) -> Value {
return builder.create<SubFOp>(a, b);
};
auto floor = [&](Value a) { return builder.create<FloorFOp>(a); };
Value cstLn2 = bcast(f32Cst(builder, static_cast<float>(LN2_VALUE)));
Value cstLN2E = bcast(f32Cst(builder, static_cast<float>(LN2E_VALUE)));
// Polynomial coefficients.
Value cstCephesExpP0 = bcast(f32Cst(builder, 1.0));
Value cstCephesExpP1 = bcast(f32Cst(builder, 1.0));
Value cstCephesExpP2 = bcast(f32Cst(builder, 0.49970514590562437052f));
Value cstCephesExpP3 = bcast(f32Cst(builder, 0.16873890085469545053f));
Value cstCephesExpP4 = bcast(f32Cst(builder, 0.03668965196652099192f));
Value cstCephesExpP5 = bcast(f32Cst(builder, 0.01314350012789660196f));
Value x = op.operand();
// Reduced y = x - floor(x / ln(2)) * ln(2) = x - k * ln(2)
Value xL2Inv = mul(x, cstLN2E);
Value kF32 = floor(xL2Inv);
Value kLn2 = mul(kF32, cstLn2);
Value y = sub(x, kLn2);
// Use Estrin's evaluation scheme with 3 independent parts:
// P(y)^y : (c0 + c1 y) + (c2 + c3 y) y^2 + (c4 + c5 y) y^4
Value y2 = mul(y, y);
Value y4 = mul(y2, y2);
Value q0 = fmla(cstCephesExpP1, y, cstCephesExpP0);
Value q1 = fmla(cstCephesExpP3, y, cstCephesExpP2);
Value q2 = fmla(cstCephesExpP5, y, cstCephesExpP4);
Value expY = fmla(q1, y2, q0);
expY = fmla(q2, y4, expY);
auto i32Vec = broadcast(builder.getI32Type(), *width);
// exp2(k)
Value k = builder.create<FPToSIOp>(kF32, i32Vec);
Value exp2KValue = exp2I32(builder, k);
// exp(x) = exp(y) * exp2(k)
expY = mul(expY, exp2KValue);
// Handle overflow, inf and underflow of exp(x). exp(x) range is [0, inf], its
// partitioned as the following:
// exp(x) = 0, x <= -inf
// exp(x) = underflow (min_float), x <= -88
// exp(x) = inf (min_float), x >= 88
// Note: |k| = 127 is the value where the 8-bits exponent saturates.
Value zerof32Const = bcast(f32Cst(builder, 0));
auto constPosInfinity =
bcast(f32Cst(builder, std::numeric_limits<float>::infinity()));
auto constNegIfinity =
bcast(f32Cst(builder, -std::numeric_limits<float>::infinity()));
auto underflow = bcast(f32Cst(builder, std::numeric_limits<float>::min()));
Value kMaxConst = bcast(i32Cst(builder, 127));
Value kMaxNegConst = bcast(i32Cst(builder, -127));
Value rightBound = builder.create<CmpIOp>(CmpIPredicate::sle, k, kMaxConst);
Value leftBound = builder.create<CmpIOp>(CmpIPredicate::sge, k, kMaxNegConst);
Value isNegInfinityX =
builder.create<CmpFOp>(CmpFPredicate::OEQ, x, constNegIfinity);
Value isPostiveX =
builder.create<CmpFOp>(CmpFPredicate::OGT, x, zerof32Const);
Value isComputable = builder.create<AndOp>(rightBound, leftBound);
expY = builder.create<SelectOp>(
isComputable, expY,
builder.create<SelectOp>(
isPostiveX, constPosInfinity,
builder.create<SelectOp>(isNegInfinityX, zerof32Const, underflow)));
rewriter.replaceOp(op, expY);
return success();
}
//----------------------------------------------------------------------------//
void mlir::populateMathPolynomialApproximationPatterns(
OwningRewritePatternList &patterns, MLIRContext *ctx) {
patterns.insert<TanhApproximation, LogApproximation>(ctx);
patterns.insert<TanhApproximation, LogApproximation, ExpApproximation>(ctx);
}

View File

@ -20,3 +20,16 @@ func @vector(%arg0: vector<8xf32>) -> vector<8xf32> {
%1 = math.log %0 : vector<8xf32>
return %1 : vector<8xf32>
}
// CHECK-LABEL: @exp_scalar
func @exp_scalar(%arg0: f32) -> f32 {
%0 = math.exp %arg0 : f32
return %0 : f32
}
// CHECK-LABEL: @exp_vector
func @exp_vector(%arg0: vector<8xf32>) -> vector<8xf32> {
// CHECK-NOT: math.exp
%0 = math.exp %arg0 : vector<8xf32>
return %0 : vector<8xf32>
}

View File

@ -71,8 +71,46 @@ func @log() {
return
}
// -------------------------------------------------------------------------- //
// Log.
// -------------------------------------------------------------------------- //
func @exp() {
// CHECK: 2.71828
%0 = constant 1.0 : f32
%1 = math.exp %0 : f32
vector.print %1 : f32
// CHECK: 0.778802, 2.117, 2.71828, 3.85742
%2 = constant dense<[-0.25, 0.75, 1.0, 1.35]> : vector<4xf32>
%3 = math.exp %2 : vector<4xf32>
vector.print %3 : vector<4xf32>
// CHECK: 1
%zero = constant 0.0 : f32
%exp_zero = math.exp %zero : f32
vector.print %exp_zero : f32
// CHECK: 1.17549e-38, 1.38879e-11, 7.20049e+10, inf
%special_vec = constant dense<[-89.0, -25.0, 25.0, 89.0]> : vector<4xf32>
%exp_special_vec = math.exp %special_vec : vector<4xf32>
vector.print %exp_special_vec : vector<4xf32>
// CHECK: inf
%inf = constant 0x7f800000 : f32
%exp_inf = math.exp %inf : f32
vector.print %exp_inf : f32
// CHECK: 0
%negative_inf = constant 0xff800000 : f32
%exp_negative_inf = math.exp %negative_inf : f32
vector.print %exp_negative_inf : f32
return
}
func @main() {
call @tanh(): () -> ()
call @log(): () -> ()
call @exp(): () -> ()
return
}