From c7ad4c80a9e13e20f78487202f83bbe03423b7db Mon Sep 17 00:00:00 2001 From: yanzhenxiang2020 Date: Thu, 12 Jan 2023 11:01:47 +0800 Subject: [PATCH] fix parameter check --- mindspore/python/mindspore/ops/operations/nn_ops.py | 12 ++++++------ tests/ut/python/ops/test_ops.py | 10 +++++----- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/mindspore/python/mindspore/ops/operations/nn_ops.py b/mindspore/python/mindspore/ops/operations/nn_ops.py index 666c536d672..ffcce9f5369 100644 --- a/mindspore/python/mindspore/ops/operations/nn_ops.py +++ b/mindspore/python/mindspore/ops/operations/nn_ops.py @@ -5443,7 +5443,7 @@ class KLDivLoss(Primitive): elif device_target == "GPU": support_mode = ['none', 'mean', 'sum'] elif device_target == "Ascend": - support_mode = ['none', 'batchmean', 'sum'] + support_mode = ['none', 'batchmean', 'sum', 'mean'] else: raise ValueError(f"'{self.name}' unknown device target: '{device_target}'") @@ -8735,10 +8735,10 @@ class SparseApplyCenteredRMSProp(Primitive): """ __mindspore_signature__ = ( - sig.make_sig('var', dtype=sig.sig_dtype.T), - sig.make_sig('mg', dtype=sig.sig_dtype.T), - sig.make_sig('ms', dtype=sig.sig_dtype.T), - sig.make_sig('mom', dtype=sig.sig_dtype.T), + sig.make_sig('var', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T), + sig.make_sig('mg', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T), + sig.make_sig('ms', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T), + sig.make_sig('mom', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T), sig.make_sig('lr', dtype=sig.sig_dtype.T), sig.make_sig('rho', dtype=sig.sig_dtype.T), sig.make_sig('momentum', dtype=sig.sig_dtype.T), @@ -10007,7 +10007,7 @@ class SparseApplyProximalGradientDescent(Primitive): """ __mindspore_signature__ = ( - sig.make_sig('var', dtype=sig.sig_dtype.T), + sig.make_sig('var', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T), sig.make_sig('alpha', dtype=sig.sig_dtype.T), sig.make_sig('l1', dtype=sig.sig_dtype.T), sig.make_sig('l2', dtype=sig.sig_dtype.T), diff --git a/tests/ut/python/ops/test_ops.py b/tests/ut/python/ops/test_ops.py index 2c96c38810f..0eabd63cf21 100644 --- a/tests/ut/python/ops/test_ops.py +++ b/tests/ut/python/ops/test_ops.py @@ -3544,10 +3544,10 @@ test_case_nn_ops = [ 'skip': ['backward']}), ('SparseApplyCenteredRMSProp', { 'block': SparseApplyCenteredRMSPropNet(), - 'desc_inputs': [Tensor(np.array([[0.6, 0.4], [0.1, 0.5]]).astype(np.float32)), - Tensor(np.array([[0.1, 0.3], [0.1, 0.5]]).astype(np.float32)), - Tensor(np.array([[0.2, 0.1], [0.1, 0.2]]).astype(np.float32)), - Tensor(np.array([[0.2, 0.1], [0.1, 0.2]]).astype(np.float32)), + 'desc_inputs': [Parameter(Tensor(np.array([[0.6, 0.4], [0.1, 0.5]]).astype(np.float32))), + Parameter(Tensor(np.array([[0.1, 0.3], [0.1, 0.5]]).astype(np.float32))), + Parameter(Tensor(np.array([[0.2, 0.1], [0.1, 0.2]]).astype(np.float32))), + Parameter(Tensor(np.array([[0.2, 0.1], [0.1, 0.2]]).astype(np.float32))), Tensor(0.001, mstype.float32), Tensor(1e-10, mstype.float32), Tensor(0.001, mstype.float32), @@ -3620,7 +3620,7 @@ test_case_nn_ops = [ 'skip': ['backward']}), ('SparseApplyProximalGradientDescent', { 'block': SparseApplyProximalGradientDescentNet(), - 'desc_inputs': [Tensor(np.array([[0.4, 0.5], [0.3, 0.1]]).astype(np.float32)), + 'desc_inputs': [Parameter(Tensor(np.array([[0.4, 0.5], [0.3, 0.1]]).astype(np.float32))), Tensor(0.01, mstype.float32), Tensor(0.88, mstype.float32), Tensor(0.3, mstype.float32),