!48353 fix parameter check and error of KLDiv when reduction='mean'

Merge pull request !48353 from yanzhenxiang2020/fix_parameter
This commit is contained in:
i-robot 2023-02-03 07:02:10 +00:00 committed by Gitee
commit 091ac0b702
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
2 changed files with 11 additions and 11 deletions

View File

@@ -5397,7 +5397,7 @@ class KLDivLoss(Primitive):
elif device_target == "GPU":
support_mode = ['none', 'mean', 'sum']
elif device_target == "Ascend":
support_mode = ['none', 'batchmean', 'sum']
support_mode = ['none', 'batchmean', 'sum', 'mean']
else:
raise ValueError(f"'{self.name}' unknown device target: '{device_target}'")
@@ -8689,10 +8689,10 @@ class SparseApplyCenteredRMSProp(Primitive):
"""
__mindspore_signature__ = (
sig.make_sig('var', dtype=sig.sig_dtype.T),
sig.make_sig('mg', dtype=sig.sig_dtype.T),
sig.make_sig('ms', dtype=sig.sig_dtype.T),
sig.make_sig('mom', dtype=sig.sig_dtype.T),
sig.make_sig('var', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
sig.make_sig('mg', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
sig.make_sig('ms', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
sig.make_sig('mom', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
sig.make_sig('lr', dtype=sig.sig_dtype.T),
sig.make_sig('rho', dtype=sig.sig_dtype.T),
sig.make_sig('momentum', dtype=sig.sig_dtype.T),
@@ -9957,7 +9957,7 @@ class SparseApplyProximalGradientDescent(Primitive):
"""
__mindspore_signature__ = (
sig.make_sig('var', dtype=sig.sig_dtype.T),
sig.make_sig('var', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
sig.make_sig('alpha', dtype=sig.sig_dtype.T),
sig.make_sig('l1', dtype=sig.sig_dtype.T),
sig.make_sig('l2', dtype=sig.sig_dtype.T),

View File

@@ -3545,10 +3545,10 @@ test_case_nn_ops = [
'skip': ['backward']}),
('SparseApplyCenteredRMSProp', {
'block': SparseApplyCenteredRMSPropNet(),
'desc_inputs': [Tensor(np.array([[0.6, 0.4], [0.1, 0.5]]).astype(np.float32)),
Tensor(np.array([[0.1, 0.3], [0.1, 0.5]]).astype(np.float32)),
Tensor(np.array([[0.2, 0.1], [0.1, 0.2]]).astype(np.float32)),
Tensor(np.array([[0.2, 0.1], [0.1, 0.2]]).astype(np.float32)),
'desc_inputs': [Parameter(Tensor(np.array([[0.6, 0.4], [0.1, 0.5]]).astype(np.float32))),
Parameter(Tensor(np.array([[0.1, 0.3], [0.1, 0.5]]).astype(np.float32))),
Parameter(Tensor(np.array([[0.2, 0.1], [0.1, 0.2]]).astype(np.float32))),
Parameter(Tensor(np.array([[0.2, 0.1], [0.1, 0.2]]).astype(np.float32))),
Tensor(0.001, mstype.float32),
Tensor(1e-10, mstype.float32),
Tensor(0.001, mstype.float32),
@@ -3621,7 +3621,7 @@ test_case_nn_ops = [
'skip': ['backward']}),
('SparseApplyProximalGradientDescent', {
'block': SparseApplyProximalGradientDescentNet(),
'desc_inputs': [Tensor(np.array([[0.4, 0.5], [0.3, 0.1]]).astype(np.float32)),
'desc_inputs': [Parameter(Tensor(np.array([[0.4, 0.5], [0.3, 0.1]]).astype(np.float32))),
Tensor(0.01, mstype.float32),
Tensor(0.88, mstype.float32),
Tensor(0.3, mstype.float32),