From 91183182f4d4dde4e2177bd49b46262bfdbf912e Mon Sep 17 00:00:00 2001 From: jiangjinsheng Date: Thu, 11 Jun 2020 10:12:49 +0800 Subject: [PATCH] fixed CumSum, BesselI1e etc --- mindspore/ops/_grad/grad_math_ops.py | 7 ++++--- mindspore/ops/operations/math_ops.py | 6 ++---- mindspore/ops/operations/nn_ops.py | 3 --- 3 files changed, 6 insertions(+), 10 deletions(-) diff --git a/mindspore/ops/_grad/grad_math_ops.py b/mindspore/ops/_grad/grad_math_ops.py index ba9973d81b4..eb638487478 100755 --- a/mindspore/ops/_grad/grad_math_ops.py +++ b/mindspore/ops/_grad/grad_math_ops.py @@ -1001,15 +1001,16 @@ def get_bprop_bessel_i1e(self): reciprocal = P.Reciprocal() cast = P.Cast() dtype = P.DType() + abs_ops = P.Abs() def bprop(x, out, dout): zeros = zeros_like(x) np_eps = const_utils.get_np_eps(dtype(x)) eps = cast(np_eps, dtype(x)) - x_is_valid = less(eps, x) + x_is_valid = less(eps, abs_ops(x)) x_safe = select(x_is_valid, x, eps + zeros) - tmp = bessel_i0e(x_safe) - out * (sign(x) + reciprocal(x_safe)) - dx = select(x_is_valid, tmp, 0.5 + zeros) + tmp = bessel_i0e(x_safe) - out * (sign(x_safe) + reciprocal(x_safe)) + dx = select(x_is_valid, tmp, cast(0.5, dtype(x)) + zeros) * dout return (dx,) return bprop diff --git a/mindspore/ops/operations/math_ops.py b/mindspore/ops/operations/math_ops.py index 01473e76041..1c83c4ac873 100644 --- a/mindspore/ops/operations/math_ops.py +++ b/mindspore/ops/operations/math_ops.py @@ -672,6 +672,8 @@ class CumSum(PrimitiveWithInfer): def __infer__(self, x, axis): cls_name = self.name x_shp = x['shape'] + if axis['value'] is None: + raise ValueError(f"For {self.name}, axis must be const.") validator.check_value_type('axis', axis['value'], [int], cls_name) valid_types = [mstype.uint8, mstype.int8, mstype.int32, mstype.float16, mstype.float32] validator.check_tensor_type_same({'x': x['dtype']}, valid_types, cls_name) @@ -679,10 +681,6 @@ class CumSum(PrimitiveWithInfer): 'dtype': x['dtype'], 'value': None} - def infer_value(self, x, axis): - if axis is None: - raise ValueError(f"For {self.name}, axis must be const.") - class AddN(PrimitiveWithInfer): """ diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py index 91b09d2553e..a1f89a0ebfe 100644 --- a/mindspore/ops/operations/nn_ops.py +++ b/mindspore/ops/operations/nn_ops.py @@ -1767,9 +1767,6 @@ class ApplyRMSProp(PrimitiveWithInfer): def infer_value(self, var, mean_square, moment, learning_rate, grad, decay, momentum, epsilon): if decay is None or momentum is None or epsilon is None: raise ValueError(f"For {self.name}, decay, momentum, epsilon must be const.") - if not self.is_ge and self.is_d: - return None, None, None - return None class ApplyCenteredRMSProp(PrimitiveWithInfer):