forked from mindspore-Ecosystem/mindspore
!9473 Remove float64 from IGamma's supported dtype
From: @peixu_ren Reviewed-by: @zichun_ye,@liangchenghui Signed-off-by: @liangchenghui
This commit is contained in:
commit
f019a4a0af
|
@ -433,7 +433,6 @@ class DiGamma(Cell):
|
|||
nan, real_result)
|
||||
|
||||
|
||||
eps_fp64 = Tensor(np.finfo(np.float64).eps, mstype.float64)
|
||||
eps_fp32 = Tensor(np.finfo(np.float32).eps, mstype.float32)
|
||||
|
||||
def _while_helper_func(cond, body, vals):
|
||||
|
@ -452,10 +451,8 @@ def _IgammaSeries(ax, x, a, enabled):
|
|||
dtype = P.DType()
|
||||
select = P.Select()
|
||||
|
||||
if dtype(ax) == mstype.float64:
|
||||
epsilon = eps_fp64
|
||||
else:
|
||||
epsilon = eps_fp32
|
||||
# If more data types are supported, this epsilon need to be selected.
|
||||
epsilon = eps_fp32
|
||||
|
||||
def cond(vals):
|
||||
enabled = vals[0]
|
||||
|
@ -504,10 +501,8 @@ def _IgammacContinuedFraction(ax, x, a, enabled):
|
|||
dtype = P.DType()
|
||||
select = P.Select()
|
||||
|
||||
if dtype(ax) == mstype.float64:
|
||||
epsilon = eps_fp64
|
||||
else:
|
||||
epsilon = eps_fp32
|
||||
# If more data types are supported, this epsilon need to be selected.
|
||||
epsilon = eps_fp32
|
||||
|
||||
def cond(vals):
|
||||
enabled = vals[0]
|
||||
|
@ -624,9 +619,9 @@ class IGamma(Cell):
|
|||
``Ascend``
|
||||
|
||||
Inputs:
|
||||
- **a** (Tensor) - The input tensor. With float32 or float64 data type. `a` should have
|
||||
- **a** (Tensor) - The input tensor. With float32 data type. `a` should have
|
||||
the same dtype with `x`.
|
||||
- **x** (Tensor) - The input tensor. With float32 or float64 data type. `x` should have
|
||||
- **x** (Tensor) - The input tensor. With float32 data type. `x` should have
|
||||
the same dtype with `a`.
|
||||
|
||||
Outputs:
|
||||
|
@ -644,7 +639,7 @@ class IGamma(Cell):
|
|||
def __init__(self):
|
||||
super(IGamma, self).__init__()
|
||||
# const numbers
|
||||
self.log_maxfloat64 = Tensor(np.log(np.finfo(np.float64).max), mstype.float64)
|
||||
# If more data types are supported, this float max value need to be selected.
|
||||
self.log_maxfloat32 = Tensor(np.log(np.finfo(np.float32).max), mstype.float32)
|
||||
|
||||
# operations
|
||||
|
@ -669,7 +664,7 @@ class IGamma(Cell):
|
|||
def construct(self, a, x):
|
||||
a_dtype = self.dtype(a)
|
||||
x_dtype = self.dtype(x)
|
||||
_check_input_dtype("input_a", a_dtype, [mstype.float32, mstype.float64], self.cls_name)
|
||||
_check_input_dtype("input_a", a_dtype, [mstype.float32], self.cls_name)
|
||||
_check_input_dtype("input_x", x_dtype, a_dtype, self.cls_name)
|
||||
domain_error = self.logicalor(self.less(x, 0), self.less(a, 0))
|
||||
use_igammac = self.logicaland(self.greater(x, 1), self.greater(x, a))
|
||||
|
@ -680,10 +675,7 @@ class IGamma(Cell):
|
|||
x = boradcastto(x)
|
||||
y = boradcastto(y)
|
||||
x_is_zero = self.equal(x, 0)
|
||||
if a_dtype == mstype.float64:
|
||||
log_maxfloat = self.log_maxfloat64
|
||||
else:
|
||||
log_maxfloat = self.log_maxfloat32
|
||||
log_maxfloat = self.log_maxfloat32
|
||||
underflow = self.less(ax, self.neg(log_maxfloat))
|
||||
ax = self.exp(ax)
|
||||
enabled = self.logicalnot(self.logicalor(self.logicalor(x_is_zero, domain_error), underflow))
|
||||
|
|
Loading…
Reference in New Issue