From 4b9a5d03c5659f6946cbefcd76f9279b07a08862 Mon Sep 17 00:00:00 2001
From: peixu_ren
Date: Thu, 3 Dec 2020 17:13:04 -0500
Subject: [PATCH] Remove float64 from IGamma's supported dtype

---
 mindspore/nn/layer/math.py | 26 +++++++++-----------------
 1 file changed, 9 insertions(+), 17 deletions(-)

diff --git a/mindspore/nn/layer/math.py b/mindspore/nn/layer/math.py
index 60b7acd3117..765683526d3 100644
--- a/mindspore/nn/layer/math.py
+++ b/mindspore/nn/layer/math.py
@@ -433,7 +433,6 @@ class DiGamma(Cell):
                       nan, real_result)
 
 
-eps_fp64 = Tensor(np.finfo(np.float64).eps, mstype.float64)
 eps_fp32 = Tensor(np.finfo(np.float32).eps, mstype.float32)
 
 def _while_helper_func(cond, body, vals):
@@ -452,10 +451,8 @@ def _IgammaSeries(ax, x, a, enabled):
     dtype = P.DType()
     select = P.Select()
 
-    if dtype(ax) == mstype.float64:
-        epsilon = eps_fp64
-    else:
-        epsilon = eps_fp32
+    # If more data types are supported, this epsilon needs to be selected according to the dtype.
+    epsilon = eps_fp32
 
     def cond(vals):
         enabled = vals[0]
@@ -504,10 +501,8 @@ def _IgammacContinuedFraction(ax, x, a, enabled):
     dtype = P.DType()
     select = P.Select()
 
-    if dtype(ax) == mstype.float64:
-        epsilon = eps_fp64
-    else:
-        epsilon = eps_fp32
+    # If more data types are supported, this epsilon needs to be selected according to the dtype.
+    epsilon = eps_fp32
 
     def cond(vals):
         enabled = vals[0]
@@ -624,9 +619,9 @@ class IGamma(Cell):
         ``Ascend``
 
     Inputs:
-        - **a** (Tensor) - The input tensor. With float32 or float64 data type. `a` should have
+        - **a** (Tensor) - The input tensor. With float32 data type. `a` should have
           the same dtype with `x`.
-        - **x** (Tensor) - The input tensor. With float32 or float64 data type. `x` should have
+        - **x** (Tensor) - The input tensor. With float32 data type. `x` should have
           the same dtype with `a`.
 
     Outputs:
@@ -644,7 +639,7 @@ class IGamma(Cell):
     def __init__(self):
         super(IGamma, self).__init__()
         # const numbers
-        self.log_maxfloat64 = Tensor(np.log(np.finfo(np.float64).max), mstype.float64)
+        # If more data types are supported, this float max value needs to be selected according to the dtype.
         self.log_maxfloat32 = Tensor(np.log(np.finfo(np.float32).max), mstype.float32)
 
         # operations
@@ -669,7 +664,7 @@ class IGamma(Cell):
     def construct(self, a, x):
         a_dtype = self.dtype(a)
         x_dtype = self.dtype(x)
-        _check_input_dtype("input_a", a_dtype, [mstype.float32, mstype.float64], self.cls_name)
+        _check_input_dtype("input_a", a_dtype, [mstype.float32], self.cls_name)
         _check_input_dtype("input_x", x_dtype, a_dtype, self.cls_name)
         domain_error = self.logicalor(self.less(x, 0), self.less(a, 0))
         use_igammac = self.logicaland(self.greater(x, 1), self.greater(x, a))
@@ -680,10 +675,7 @@ class IGamma(Cell):
         x = boradcastto(x)
         y = boradcastto(y)
         x_is_zero = self.equal(x, 0)
-        if a_dtype == mstype.float64:
-            log_maxfloat = self.log_maxfloat64
-        else:
-            log_maxfloat = self.log_maxfloat32
+        log_maxfloat = self.log_maxfloat32
         underflow = self.less(ax, self.neg(log_maxfloat))
         ax = self.exp(ax)
         enabled = self.logicalnot(self.logicalor(self.logicalor(x_is_zero, domain_error), underflow))
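
A minimal usage sketch (not part of the patch), assuming the public nn.IGamma interface is otherwise unchanged: after this change only float32 inputs pass the dtype check, so float64 tensors must be cast before calling the cell. The tensor values below are illustrative.

    import numpy as np
    import mindspore.nn as nn
    from mindspore import Tensor
    from mindspore import dtype as mstype

    # Elementwise lower regularized incomplete gamma P(a, x).
    igamma = nn.IGamma()
    a = Tensor(np.array([2.0, 4.0, 6.0, 8.0]), mstype.float32)  # float32 only after this patch
    x = Tensor(np.array([2.0, 3.0, 4.0, 5.0]), mstype.float32)  # must share a's dtype
    output = igamma(a, x)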