forked from mindspore-Ecosystem/mindspore
!1986 fixed validator for CumSum
Merge pull request !1986 from jiangjinsheng/issue_fix2
This commit is contained in:
commit
af85b2cebf
|
@ -1001,15 +1001,16 @@ def get_bprop_bessel_i1e(self):
|
|||
reciprocal = P.Reciprocal()
|
||||
cast = P.Cast()
|
||||
dtype = P.DType()
|
||||
abs_ops = P.Abs()
|
||||
|
||||
def bprop(x, out, dout):
|
||||
zeros = zeros_like(x)
|
||||
np_eps = const_utils.get_np_eps(dtype(x))
|
||||
eps = cast(np_eps, dtype(x))
|
||||
x_is_valid = less(eps, x)
|
||||
x_is_valid = less(eps, abs_ops(x))
|
||||
x_safe = select(x_is_valid, x, eps + zeros)
|
||||
tmp = bessel_i0e(x_safe) - out * (sign(x) + reciprocal(x_safe))
|
||||
dx = select(x_is_valid, tmp, 0.5 + zeros)
|
||||
tmp = bessel_i0e(x_safe) - out * (sign(x_safe) + reciprocal(x_safe))
|
||||
dx = select(x_is_valid, tmp, cast(0.5, dtype(x)) + zeros) * dout
|
||||
return (dx,)
|
||||
return bprop
|
||||
|
||||
|
|
|
@ -672,6 +672,8 @@ class CumSum(PrimitiveWithInfer):
|
|||
def __infer__(self, x, axis):
|
||||
cls_name = self.name
|
||||
x_shp = x['shape']
|
||||
if axis['value'] is None:
|
||||
raise ValueError(f"For {self.name}, axis must be const.")
|
||||
validator.check_value_type('axis', axis['value'], [int], cls_name)
|
||||
valid_types = [mstype.uint8, mstype.int8, mstype.int32, mstype.float16, mstype.float32]
|
||||
validator.check_tensor_type_same({'x': x['dtype']}, valid_types, cls_name)
|
||||
|
@ -679,10 +681,6 @@ class CumSum(PrimitiveWithInfer):
|
|||
'dtype': x['dtype'],
|
||||
'value': None}
|
||||
|
||||
def infer_value(self, x, axis):
|
||||
if axis is None:
|
||||
raise ValueError(f"For {self.name}, axis must be const.")
|
||||
|
||||
|
||||
class AddN(PrimitiveWithInfer):
|
||||
"""
|
||||
|
|
|
@ -1767,9 +1767,6 @@ class ApplyRMSProp(PrimitiveWithInfer):
|
|||
def infer_value(self, var, mean_square, moment, learning_rate, grad, decay, momentum, epsilon):
|
||||
if decay is None or momentum is None or epsilon is None:
|
||||
raise ValueError(f"For {self.name}, decay, momentum, epsilon must be const.")
|
||||
if not self.is_ge and self.is_d:
|
||||
return None, None, None
|
||||
return None
|
||||
|
||||
|
||||
class ApplyCenteredRMSProp(PrimitiveWithInfer):
|
||||
|
|
Loading…
Reference in New Issue