[mlir][Math] Fix NaN handling in Exp approximation

Differential Revision: https://reviews.llvm.org/D119832
This commit is contained in:
Adrian Kuegel 2022-02-15 13:49:30 +01:00
parent 97db9d32f5
commit b122cbebec
3 changed files with 25 additions and 9 deletions

View File

@ -930,6 +930,8 @@ ExpApproximation::matchAndRewrite(math::ExpOp op,
Value x = op.getOperand();
Value isNan = builder.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, x, x);
// Reduced y = x - floor(x / ln(2)) * ln(2) = x - k * ln(2)
Value xL2Inv = mul(x, cstLog2E);
Value kF32 = floor(xL2Inv);
@ -985,13 +987,15 @@ ExpApproximation::matchAndRewrite(math::ExpOp op,
Value isComputable = builder.create<arith::AndIOp>(rightBound, leftBound);
expY = builder.create<arith::SelectOp>(
isNegInfinityX, zerof32Const,
isNan, x,
builder.create<arith::SelectOp>(
isPosInfinityX, constPosInfinity,
isNegInfinityX, zerof32Const,
builder.create<arith::SelectOp>(
isComputable, expY,
builder.create<arith::SelectOp>(isPostiveX, constPosInfinity,
underflow))));
isPosInfinityX, constPosInfinity,
builder.create<arith::SelectOp>(
isComputable, expY,
builder.create<arith::SelectOp>(isPostiveX, constPosInfinity,
underflow)))));
rewriter.replaceOp(op, expY);

View File

@ -110,6 +110,7 @@ func @erf_vector(%arg0: vector<8xf32>) -> vector<8xf32> {
// CHECK-DAG: %[[VAL_12:.*]] = arith.constant 1.17549435E-38 : f32
// CHECK-DAG: %[[VAL_13:.*]] = arith.constant 127 : i32
// CHECK-DAG: %[[VAL_14:.*]] = arith.constant -127 : i32
// CHECK: %[[IS_NAN:.*]] = arith.cmpf uno, %[[VAL_0]], %[[VAL_0]] : f32
// CHECK: %[[VAL_15:.*]] = arith.mulf %[[VAL_0]], %[[VAL_2]] : f32
// CHECK: %[[VAL_16:.*]] = math.floor %[[VAL_15]] : f32
// CHECK: %[[VAL_17:.*]] = arith.mulf %[[VAL_16]], %[[VAL_1]] : f32
@ -136,7 +137,8 @@ func @erf_vector(%arg0: vector<8xf32>) -> vector<8xf32> {
// CHECK: %[[VAL_38:.*]] = arith.select %[[VAL_36]], %[[VAL_30]], %[[VAL_37]] : f32
// CHECK: %[[VAL_39:.*]] = arith.select %[[VAL_34]], %[[VAL_10]], %[[VAL_38]] : f32
// CHECK: %[[VAL_40:.*]] = arith.select %[[VAL_33]], %[[VAL_9]], %[[VAL_39]] : f32
// CHECK: return %[[VAL_40]] : f32
// CHECK: %[[VAL_41:.*]] = arith.select %[[IS_NAN]], %[[VAL_0]], %[[VAL_40]] : f32
// CHECK: return %[[VAL_41]] : f32
func @exp_scalar(%arg0: f32) -> f32 {
%0 = math.exp %arg0 : f32
return %0 : f32
@ -146,7 +148,7 @@ func @exp_scalar(%arg0: f32) -> f32 {
// CHECK-SAME: %[[VAL_0:.*]]: vector<8xf32>) -> vector<8xf32> {
// CHECK: %[[VAL_1:.*]] = arith.constant dense<0.693147182> : vector<8xf32>
// CHECK-NOT: exp
// CHECK-COUNT-3: select
// CHECK-COUNT-4: select
// CHECK: %[[VAL_40:.*]] = arith.select
// CHECK: return %[[VAL_40]] : vector<8xf32>
func @exp_vector(%arg0: vector<8xf32>) -> vector<8xf32> {
@ -161,7 +163,7 @@ func @exp_vector(%arg0: vector<8xf32>) -> vector<8xf32> {
// CHECK-DAG: %[[CST_ONE:.*]] = arith.constant 1.000000e+00 : f32
// CHECK: %[[BEGIN_EXP_X:.*]] = arith.mulf %[[X]], %[[CST_LOG2E]] : f32
// CHECK-NOT: exp
// CHECK-COUNT-3: select
// CHECK-COUNT-4: select
// CHECK: %[[EXP_X:.*]] = arith.select
// CHECK: %[[IS_ONE_OR_NAN:.*]] = arith.cmpf ueq, %[[EXP_X]], %[[CST_ONE]] : f32
// CHECK: %[[VAL_59:.*]] = arith.subf %[[EXP_X]], %[[CST_ONE]] : f32
@ -186,7 +188,7 @@ func @expm1_scalar(%arg0: f32) -> f32 {
// CHECK-SAME: %[[VAL_0:.*]]: vector<8x8xf32>) -> vector<8x8xf32> {
// CHECK: %[[VAL_1:.*]] = arith.constant dense<-1.000000e+00> : vector<8x8xf32>
// CHECK-NOT: exp
// CHECK-COUNT-4: select
// CHECK-COUNT-5: select
// CHECK-NOT: log
// CHECK-COUNT-5: select
// CHECK-NOT: expm1

View File

@ -258,6 +258,11 @@ func @exp() {
%exp_negative_inf = math.exp %negative_inf : f32
vector.print %exp_negative_inf : f32
// CHECK: nan
%nan = arith.constant 0x7fc00000 : f32
%exp_nan = math.exp %nan : f32
vector.print %exp_nan : f32
return
}
@ -292,6 +297,11 @@ func @expm1() {
%log_special_vec = math.expm1 %special_vec : vector<3xf32>
vector.print %log_special_vec : vector<3xf32>
// CHECK: nan
%nan = arith.constant 0x7fc00000 : f32
%exp_nan = math.expm1 %nan : f32
vector.print %exp_nan : f32
return
}
// -------------------------------------------------------------------------- //