!1986 fixed validator for CumSum

Merge pull request !1986 from jiangjinsheng/issue_fix2
This commit is contained in:
mindspore-ci-bot 2020-06-11 20:47:38 +08:00 committed by Gitee
commit af85b2cebf
3 changed files with 6 additions and 10 deletions

View File

@ -1001,15 +1001,16 @@ def get_bprop_bessel_i1e(self):
reciprocal = P.Reciprocal()
cast = P.Cast()
dtype = P.DType()
abs_ops = P.Abs()
def bprop(x, out, dout):
zeros = zeros_like(x)
np_eps = const_utils.get_np_eps(dtype(x))
eps = cast(np_eps, dtype(x))
x_is_valid = less(eps, x)
x_is_valid = less(eps, abs_ops(x))
x_safe = select(x_is_valid, x, eps + zeros)
tmp = bessel_i0e(x_safe) - out * (sign(x) + reciprocal(x_safe))
dx = select(x_is_valid, tmp, 0.5 + zeros)
tmp = bessel_i0e(x_safe) - out * (sign(x_safe) + reciprocal(x_safe))
dx = select(x_is_valid, tmp, cast(0.5, dtype(x)) + zeros) * dout
return (dx,)
return bprop

View File

@ -672,6 +672,8 @@ class CumSum(PrimitiveWithInfer):
def __infer__(self, x, axis):
cls_name = self.name
x_shp = x['shape']
if axis['value'] is None:
raise ValueError(f"For {self.name}, axis must be const.")
validator.check_value_type('axis', axis['value'], [int], cls_name)
valid_types = [mstype.uint8, mstype.int8, mstype.int32, mstype.float16, mstype.float32]
validator.check_tensor_type_same({'x': x['dtype']}, valid_types, cls_name)
@ -679,10 +681,6 @@ class CumSum(PrimitiveWithInfer):
'dtype': x['dtype'],
'value': None}
def infer_value(self, x, axis):
if axis is None:
raise ValueError(f"For {self.name}, axis must be const.")
class AddN(PrimitiveWithInfer):
"""

View File

@ -1767,9 +1767,6 @@ class ApplyRMSProp(PrimitiveWithInfer):
def infer_value(self, var, mean_square, moment, learning_rate, grad, decay, momentum, epsilon):
if decay is None or momentum is None or epsilon is None:
raise ValueError(f"For {self.name}, decay, momentum, epsilon must be const.")
if not self.is_ge and self.is_d:
return None, None, None
return None
class ApplyCenteredRMSProp(PrimitiveWithInfer):