forked from OSchip/llvm-project
[mlir][Math] Fix NaN handling in Exp approximation
Differential Revision: https://reviews.llvm.org/D119832
This commit is contained in:
parent
97db9d32f5
commit
b122cbebec
|
@ -930,6 +930,8 @@ ExpApproximation::matchAndRewrite(math::ExpOp op,
|
|||
|
||||
Value x = op.getOperand();
|
||||
|
||||
Value isNan = builder.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, x, x);
|
||||
|
||||
// Reduced y = x - floor(x / ln(2)) * ln(2) = x - k * ln(2)
|
||||
Value xL2Inv = mul(x, cstLog2E);
|
||||
Value kF32 = floor(xL2Inv);
|
||||
|
@ -985,13 +987,15 @@ ExpApproximation::matchAndRewrite(math::ExpOp op,
|
|||
Value isComputable = builder.create<arith::AndIOp>(rightBound, leftBound);
|
||||
|
||||
expY = builder.create<arith::SelectOp>(
|
||||
isNegInfinityX, zerof32Const,
|
||||
isNan, x,
|
||||
builder.create<arith::SelectOp>(
|
||||
isPosInfinityX, constPosInfinity,
|
||||
isNegInfinityX, zerof32Const,
|
||||
builder.create<arith::SelectOp>(
|
||||
isComputable, expY,
|
||||
builder.create<arith::SelectOp>(isPostiveX, constPosInfinity,
|
||||
underflow))));
|
||||
isPosInfinityX, constPosInfinity,
|
||||
builder.create<arith::SelectOp>(
|
||||
isComputable, expY,
|
||||
builder.create<arith::SelectOp>(isPostiveX, constPosInfinity,
|
||||
underflow)))));
|
||||
|
||||
rewriter.replaceOp(op, expY);
|
||||
|
||||
|
|
|
@ -110,6 +110,7 @@ func @erf_vector(%arg0: vector<8xf32>) -> vector<8xf32> {
|
|||
// CHECK-DAG: %[[VAL_12:.*]] = arith.constant 1.17549435E-38 : f32
|
||||
// CHECK-DAG: %[[VAL_13:.*]] = arith.constant 127 : i32
|
||||
// CHECK-DAG: %[[VAL_14:.*]] = arith.constant -127 : i32
|
||||
// CHECK: %[[IS_NAN:.*]] = arith.cmpf uno, %[[VAL_0]], %[[VAL_0]] : f32
|
||||
// CHECK: %[[VAL_15:.*]] = arith.mulf %[[VAL_0]], %[[VAL_2]] : f32
|
||||
// CHECK: %[[VAL_16:.*]] = math.floor %[[VAL_15]] : f32
|
||||
// CHECK: %[[VAL_17:.*]] = arith.mulf %[[VAL_16]], %[[VAL_1]] : f32
|
||||
|
@ -136,7 +137,8 @@ func @erf_vector(%arg0: vector<8xf32>) -> vector<8xf32> {
|
|||
// CHECK: %[[VAL_38:.*]] = arith.select %[[VAL_36]], %[[VAL_30]], %[[VAL_37]] : f32
|
||||
// CHECK: %[[VAL_39:.*]] = arith.select %[[VAL_34]], %[[VAL_10]], %[[VAL_38]] : f32
|
||||
// CHECK: %[[VAL_40:.*]] = arith.select %[[VAL_33]], %[[VAL_9]], %[[VAL_39]] : f32
|
||||
// CHECK: return %[[VAL_40]] : f32
|
||||
// CHECK: %[[VAL_41:.*]] = arith.select %[[IS_NAN]], %[[VAL_0]], %[[VAL_40]] : f32
|
||||
// CHECK: return %[[VAL_41]] : f32
|
||||
func @exp_scalar(%arg0: f32) -> f32 {
|
||||
%0 = math.exp %arg0 : f32
|
||||
return %0 : f32
|
||||
|
@ -146,7 +148,7 @@ func @exp_scalar(%arg0: f32) -> f32 {
|
|||
// CHECK-SAME: %[[VAL_0:.*]]: vector<8xf32>) -> vector<8xf32> {
|
||||
// CHECK: %[[VAL_1:.*]] = arith.constant dense<0.693147182> : vector<8xf32>
|
||||
// CHECK-NOT: exp
|
||||
// CHECK-COUNT-3: select
|
||||
// CHECK-COUNT-4: select
|
||||
// CHECK: %[[VAL_40:.*]] = arith.select
|
||||
// CHECK: return %[[VAL_40]] : vector<8xf32>
|
||||
func @exp_vector(%arg0: vector<8xf32>) -> vector<8xf32> {
|
||||
|
@ -161,7 +163,7 @@ func @exp_vector(%arg0: vector<8xf32>) -> vector<8xf32> {
|
|||
// CHECK-DAG: %[[CST_ONE:.*]] = arith.constant 1.000000e+00 : f32
|
||||
// CHECK: %[[BEGIN_EXP_X:.*]] = arith.mulf %[[X]], %[[CST_LOG2E]] : f32
|
||||
// CHECK-NOT: exp
|
||||
// CHECK-COUNT-3: select
|
||||
// CHECK-COUNT-4: select
|
||||
// CHECK: %[[EXP_X:.*]] = arith.select
|
||||
// CHECK: %[[IS_ONE_OR_NAN:.*]] = arith.cmpf ueq, %[[EXP_X]], %[[CST_ONE]] : f32
|
||||
// CHECK: %[[VAL_59:.*]] = arith.subf %[[EXP_X]], %[[CST_ONE]] : f32
|
||||
|
@ -186,7 +188,7 @@ func @expm1_scalar(%arg0: f32) -> f32 {
|
|||
// CHECK-SAME: %[[VAL_0:.*]]: vector<8x8xf32>) -> vector<8x8xf32> {
|
||||
// CHECK: %[[VAL_1:.*]] = arith.constant dense<-1.000000e+00> : vector<8x8xf32>
|
||||
// CHECK-NOT: exp
|
||||
// CHECK-COUNT-4: select
|
||||
// CHECK-COUNT-5: select
|
||||
// CHECK-NOT: log
|
||||
// CHECK-COUNT-5: select
|
||||
// CHECK-NOT: expm1
|
||||
|
|
|
@ -258,6 +258,11 @@ func @exp() {
|
|||
%exp_negative_inf = math.exp %negative_inf : f32
|
||||
vector.print %exp_negative_inf : f32
|
||||
|
||||
// CHECK: nan
|
||||
%nan = arith.constant 0x7fc00000 : f32
|
||||
%exp_nan = math.exp %nan : f32
|
||||
vector.print %exp_nan : f32
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -292,6 +297,11 @@ func @expm1() {
|
|||
%log_special_vec = math.expm1 %special_vec : vector<3xf32>
|
||||
vector.print %log_special_vec : vector<3xf32>
|
||||
|
||||
// CHECK: nan
|
||||
%nan = arith.constant 0x7fc00000 : f32
|
||||
%exp_nan = math.expm1 %nan : f32
|
||||
vector.print %exp_nan : f32
|
||||
|
||||
return
|
||||
}
|
||||
// -------------------------------------------------------------------------- //
|
||||
|
|
Loading…
Reference in New Issue