forked from OSchip/llvm-project
[mlir] Add polynomial approximation for math::ExpOp
Similar to fast_exp in https://github.com/boulos/syrah. Differential Revision: https://reviews.llvm.org/D97599
This commit is contained in:
parent
74c883f7e5
commit
ea7f211b2e
|
@ -10,7 +10,6 @@
|
|||
// that do not rely on any of the library functions.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
||||
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
|
||||
#include "mlir/Dialect/Math/IR/Math.h"
|
||||
|
@ -20,6 +19,7 @@
|
|||
#include "mlir/IR/ImplicitLocOpBuilder.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
#include <limits.h>
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::vector;
|
||||
|
@ -28,6 +28,8 @@ using TypePredicate = llvm::function_ref<bool(Type)>;
|
|||
|
||||
// Type predicate: returns true if `type` is f32.
static bool isF32(Type type) {
  return type.isF32();
}
|
||||
|
||||
// Type predicate: returns true if `type` is a 32-bit integer.
static bool isI32(Type type) {
  return type.isInteger(32);
}
|
||||
|
||||
// Returns vector width if the element type is matching the predicate (scalars
|
||||
// that do match the predicate have width equal to `1`).
|
||||
static Optional<int> vectorWidth(Type type, TypePredicate pred) {
|
||||
|
@ -153,6 +155,30 @@ static std::pair<Value, Value> frexp(ImplicitLocOpBuilder &builder, Value arg,
|
|||
return {normalizedFraction, exponent};
|
||||
}
|
||||
|
||||
// Computes exp2 for an i32 argument.
|
||||
static Value exp2I32(ImplicitLocOpBuilder &builder, Value arg) {
|
||||
assert(isI32(elementType(arg.getType())) && "argument must be i32 type");
|
||||
|
||||
int width = vectorWidth(arg.getType());
|
||||
|
||||
auto bcast = [&](Value value) -> Value {
|
||||
return broadcast(builder, value, width);
|
||||
};
|
||||
|
||||
auto f32Vec = broadcast(builder.getF32Type(), width);
|
||||
// The exponent of f32 located at 23-bit.
|
||||
auto exponetBitLocation = bcast(i32Cst(builder, 23));
|
||||
// Set the exponent bias to zero.
|
||||
auto bias = bcast(i32Cst(builder, 127));
|
||||
|
||||
Value biasedArg = builder.create<AddIOp>(arg, bias);
|
||||
Value exp2ValueInt =
|
||||
builder.create<ShiftLeftOp>(biasedArg, exponetBitLocation);
|
||||
Value exp2ValueF32 = builder.create<LLVM::BitcastOp>(f32Vec, exp2ValueInt);
|
||||
|
||||
return exp2ValueF32;
|
||||
}
|
||||
|
||||
//----------------------------------------------------------------------------//
|
||||
// TanhOp approximation.
|
||||
//----------------------------------------------------------------------------//
|
||||
|
@ -230,6 +256,11 @@ TanhApproximation::matchAndRewrite(math::TanhOp op,
|
|||
return success();
|
||||
}
|
||||
|
||||
// Natural logarithm of 2, ln(2), as a long double literal.
#define LN2_VALUE \
|
||||
0.693147180559945309417232121458176568075500134360255254120680009493393621L
|
||||
// NOTE: despite its name, this value is 1 / ln(2) (i.e. log2(e)); the exp
// approximation multiplies by it to compute x / ln(2).
#define LN2E_VALUE \
|
||||
1.442695040888963407359924681001892137426645954152985934135449406931109219L
|
||||
|
||||
//----------------------------------------------------------------------------//
|
||||
// LogOp approximation.
|
||||
//----------------------------------------------------------------------------//
|
||||
|
@ -247,9 +278,6 @@ public:
|
|||
};
|
||||
} // namespace
|
||||
|
||||
#define LN2_VALUE \
|
||||
0.693147180559945309417232121458176568075500134360255254120680009493393621L
|
||||
|
||||
LogicalResult
|
||||
LogApproximation::matchAndRewrite(math::LogOp op,
|
||||
PatternRewriter &rewriter) const {
|
||||
|
@ -353,9 +381,125 @@ LogApproximation::matchAndRewrite(math::LogOp op,
|
|||
return success();
|
||||
}
|
||||
|
||||
//----------------------------------------------------------------------------//
|
||||
// Exp approximation.
|
||||
//----------------------------------------------------------------------------//
|
||||
|
||||
namespace {
|
||||
|
||||
struct ExpApproximation : public OpRewritePattern<math::ExpOp> {
|
||||
public:
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(math::ExpOp op,
|
||||
PatternRewriter &rewriter) const final;
|
||||
};
|
||||
} // namespace
|
||||
|
||||
// Approximate exp(x) using its reduced range exp(y) where y is in the range
|
||||
// [0, ln(2)], let y = x - floor(x / ln(2)) * ln(2) = x - k * ln(2), exp(x)
|
||||
// = exp(y) * 2^k. exp(y).
|
||||
LogicalResult
|
||||
ExpApproximation::matchAndRewrite(math::ExpOp op,
|
||||
PatternRewriter &rewriter) const {
|
||||
auto width = vectorWidth(op.operand().getType(), isF32);
|
||||
if (!width.hasValue())
|
||||
return rewriter.notifyMatchFailure(op, "unsupported operand type");
|
||||
ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
|
||||
|
||||
// TODO: Consider a common pattern rewriter with all methods below to
|
||||
// write the approximations.
|
||||
auto bcast = [&](Value value) -> Value {
|
||||
return broadcast(builder, value, *width);
|
||||
};
|
||||
auto fmla = [&](Value a, Value b, Value c) {
|
||||
return builder.create<FmaFOp>(a, b, c);
|
||||
};
|
||||
auto mul = [&](Value a, Value b) -> Value {
|
||||
return builder.create<MulFOp>(a, b);
|
||||
};
|
||||
auto sub = [&](Value a, Value b) -> Value {
|
||||
return builder.create<SubFOp>(a, b);
|
||||
};
|
||||
auto floor = [&](Value a) { return builder.create<FloorFOp>(a); };
|
||||
|
||||
Value cstLn2 = bcast(f32Cst(builder, static_cast<float>(LN2_VALUE)));
|
||||
Value cstLN2E = bcast(f32Cst(builder, static_cast<float>(LN2E_VALUE)));
|
||||
|
||||
// Polynomial coefficients.
|
||||
Value cstCephesExpP0 = bcast(f32Cst(builder, 1.0));
|
||||
Value cstCephesExpP1 = bcast(f32Cst(builder, 1.0));
|
||||
Value cstCephesExpP2 = bcast(f32Cst(builder, 0.49970514590562437052f));
|
||||
Value cstCephesExpP3 = bcast(f32Cst(builder, 0.16873890085469545053f));
|
||||
Value cstCephesExpP4 = bcast(f32Cst(builder, 0.03668965196652099192f));
|
||||
Value cstCephesExpP5 = bcast(f32Cst(builder, 0.01314350012789660196f));
|
||||
|
||||
Value x = op.operand();
|
||||
|
||||
// Reduced y = x - floor(x / ln(2)) * ln(2) = x - k * ln(2)
|
||||
Value xL2Inv = mul(x, cstLN2E);
|
||||
Value kF32 = floor(xL2Inv);
|
||||
Value kLn2 = mul(kF32, cstLn2);
|
||||
Value y = sub(x, kLn2);
|
||||
|
||||
// Use Estrin's evaluation scheme with 3 independent parts:
|
||||
// P(y)^y : (c0 + c1 y) + (c2 + c3 y) y^2 + (c4 + c5 y) y^4
|
||||
Value y2 = mul(y, y);
|
||||
Value y4 = mul(y2, y2);
|
||||
|
||||
Value q0 = fmla(cstCephesExpP1, y, cstCephesExpP0);
|
||||
Value q1 = fmla(cstCephesExpP3, y, cstCephesExpP2);
|
||||
Value q2 = fmla(cstCephesExpP5, y, cstCephesExpP4);
|
||||
Value expY = fmla(q1, y2, q0);
|
||||
expY = fmla(q2, y4, expY);
|
||||
|
||||
auto i32Vec = broadcast(builder.getI32Type(), *width);
|
||||
|
||||
// exp2(k)
|
||||
Value k = builder.create<FPToSIOp>(kF32, i32Vec);
|
||||
Value exp2KValue = exp2I32(builder, k);
|
||||
|
||||
// exp(x) = exp(y) * exp2(k)
|
||||
expY = mul(expY, exp2KValue);
|
||||
|
||||
// Handle overflow, inf and underflow of exp(x). exp(x) range is [0, inf], its
|
||||
// partitioned as the following:
|
||||
// exp(x) = 0, x <= -inf
|
||||
// exp(x) = underflow (min_float), x <= -88
|
||||
// exp(x) = inf (min_float), x >= 88
|
||||
// Note: |k| = 127 is the value where the 8-bits exponent saturates.
|
||||
Value zerof32Const = bcast(f32Cst(builder, 0));
|
||||
auto constPosInfinity =
|
||||
bcast(f32Cst(builder, std::numeric_limits<float>::infinity()));
|
||||
auto constNegIfinity =
|
||||
bcast(f32Cst(builder, -std::numeric_limits<float>::infinity()));
|
||||
auto underflow = bcast(f32Cst(builder, std::numeric_limits<float>::min()));
|
||||
|
||||
Value kMaxConst = bcast(i32Cst(builder, 127));
|
||||
Value kMaxNegConst = bcast(i32Cst(builder, -127));
|
||||
Value rightBound = builder.create<CmpIOp>(CmpIPredicate::sle, k, kMaxConst);
|
||||
Value leftBound = builder.create<CmpIOp>(CmpIPredicate::sge, k, kMaxNegConst);
|
||||
|
||||
Value isNegInfinityX =
|
||||
builder.create<CmpFOp>(CmpFPredicate::OEQ, x, constNegIfinity);
|
||||
Value isPostiveX =
|
||||
builder.create<CmpFOp>(CmpFPredicate::OGT, x, zerof32Const);
|
||||
Value isComputable = builder.create<AndOp>(rightBound, leftBound);
|
||||
|
||||
expY = builder.create<SelectOp>(
|
||||
isComputable, expY,
|
||||
builder.create<SelectOp>(
|
||||
isPostiveX, constPosInfinity,
|
||||
builder.create<SelectOp>(isNegInfinityX, zerof32Const, underflow)));
|
||||
|
||||
rewriter.replaceOp(op, expY);
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
//----------------------------------------------------------------------------//
|
||||
|
||||
// Adds the polynomial approximation patterns for tanh, log and exp to
// `patterns`. Each pattern is registered exactly once.
void mlir::populateMathPolynomialApproximationPatterns(
    OwningRewritePatternList &patterns, MLIRContext *ctx) {
  patterns.insert<TanhApproximation, LogApproximation, ExpApproximation>(ctx);
}
|
||||
|
|
|
@ -20,3 +20,16 @@ func @vector(%arg0: vector<8xf32>) -> vector<8xf32> {
|
|||
%1 = math.log %0 : vector<8xf32>
|
||||
return %1 : vector<8xf32>
|
||||
}
|
||||
|
||||
// Scalar math.exp must be fully expanded by the approximation pass.
// CHECK-LABEL: @exp_scalar
func @exp_scalar(%arg0: f32) -> f32 {
  // CHECK-NOT: math.exp
  %0 = math.exp %arg0 : f32
  return %0 : f32
}
|
||||
|
||||
// Vector math.exp must be fully expanded by the approximation pass.
// CHECK-LABEL: @exp_vector
func @exp_vector(%input: vector<8xf32>) -> vector<8xf32> {
  // CHECK-NOT: math.exp
  %result = math.exp %input : vector<8xf32>
  return %result : vector<8xf32>
}
|
||||
|
|
|
@ -71,8 +71,46 @@ func @log() {
|
|||
return
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------------- //
|
||||
// Exp.
|
||||
// -------------------------------------------------------------------------- //
|
||||
func @exp() {
  // Scalar input.
  // CHECK: 2.71828
  %one = constant 1.0 : f32
  %exp_one = math.exp %one : f32
  vector.print %exp_one : f32

  // Vector input with small magnitudes.
  // CHECK: 0.778802, 2.117, 2.71828, 3.85742
  %small_vec = constant dense<[-0.25, 0.75, 1.0, 1.35]> : vector<4xf32>
  %exp_small_vec = math.exp %small_vec : vector<4xf32>
  vector.print %exp_small_vec : vector<4xf32>

  // Zero input.
  // CHECK: 1
  %zero_in = constant 0.0 : f32
  %exp_zero_in = math.exp %zero_in : f32
  vector.print %exp_zero_in : f32

  // Inputs that reach the underflow and overflow handling.
  // CHECK: 1.17549e-38, 1.38879e-11, 7.20049e+10, inf
  %extreme_vec = constant dense<[-89.0, -25.0, 25.0, 89.0]> : vector<4xf32>
  %exp_extreme_vec = math.exp %extreme_vec : vector<4xf32>
  vector.print %exp_extreme_vec : vector<4xf32>

  // Positive infinity input (f32 bit pattern 0x7f800000).
  // CHECK: inf
  %pos_inf = constant 0x7f800000 : f32
  %exp_pos_inf = math.exp %pos_inf : f32
  vector.print %exp_pos_inf : f32

  // Negative infinity input (f32 bit pattern 0xff800000).
  // CHECK: 0
  %neg_inf = constant 0xff800000 : f32
  %exp_neg_inf = math.exp %neg_inf : f32
  vector.print %exp_neg_inf : f32

  return
}
|
||||
|
||||
// Entry point: runs the tanh, log and exp test functions in order.
func @main() {
  call @tanh() : () -> ()
  call @log() : () -> ()
  call @exp() : () -> ()
  return
}
|
||||
|
|
Loading…
Reference in New Issue