fix parameter check

This commit is contained in:
yanzhenxiang2020 2023-01-12 11:01:47 +08:00
parent 4e80bf4c72
commit c7ad4c80a9
2 changed files with 11 additions and 11 deletions

View File

@ -5443,7 +5443,7 @@ class KLDivLoss(Primitive):
elif device_target == "GPU":
support_mode = ['none', 'mean', 'sum']
elif device_target == "Ascend":
support_mode = ['none', 'batchmean', 'sum']
support_mode = ['none', 'batchmean', 'sum', 'mean']
else:
raise ValueError(f"'{self.name}' unknown device target: '{device_target}'")
@ -8735,10 +8735,10 @@ class SparseApplyCenteredRMSProp(Primitive):
"""
__mindspore_signature__ = (
sig.make_sig('var', dtype=sig.sig_dtype.T),
sig.make_sig('mg', dtype=sig.sig_dtype.T),
sig.make_sig('ms', dtype=sig.sig_dtype.T),
sig.make_sig('mom', dtype=sig.sig_dtype.T),
sig.make_sig('var', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
sig.make_sig('mg', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
sig.make_sig('ms', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
sig.make_sig('mom', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
sig.make_sig('lr', dtype=sig.sig_dtype.T),
sig.make_sig('rho', dtype=sig.sig_dtype.T),
sig.make_sig('momentum', dtype=sig.sig_dtype.T),
@ -10007,7 +10007,7 @@ class SparseApplyProximalGradientDescent(Primitive):
"""
__mindspore_signature__ = (
sig.make_sig('var', dtype=sig.sig_dtype.T),
sig.make_sig('var', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
sig.make_sig('alpha', dtype=sig.sig_dtype.T),
sig.make_sig('l1', dtype=sig.sig_dtype.T),
sig.make_sig('l2', dtype=sig.sig_dtype.T),

View File

@ -3544,10 +3544,10 @@ test_case_nn_ops = [
'skip': ['backward']}),
('SparseApplyCenteredRMSProp', {
'block': SparseApplyCenteredRMSPropNet(),
'desc_inputs': [Tensor(np.array([[0.6, 0.4], [0.1, 0.5]]).astype(np.float32)),
Tensor(np.array([[0.1, 0.3], [0.1, 0.5]]).astype(np.float32)),
Tensor(np.array([[0.2, 0.1], [0.1, 0.2]]).astype(np.float32)),
Tensor(np.array([[0.2, 0.1], [0.1, 0.2]]).astype(np.float32)),
'desc_inputs': [Parameter(Tensor(np.array([[0.6, 0.4], [0.1, 0.5]]).astype(np.float32))),
Parameter(Tensor(np.array([[0.1, 0.3], [0.1, 0.5]]).astype(np.float32))),
Parameter(Tensor(np.array([[0.2, 0.1], [0.1, 0.2]]).astype(np.float32))),
Parameter(Tensor(np.array([[0.2, 0.1], [0.1, 0.2]]).astype(np.float32))),
Tensor(0.001, mstype.float32),
Tensor(1e-10, mstype.float32),
Tensor(0.001, mstype.float32),
@ -3620,7 +3620,7 @@ test_case_nn_ops = [
'skip': ['backward']}),
('SparseApplyProximalGradientDescent', {
'block': SparseApplyProximalGradientDescentNet(),
'desc_inputs': [Tensor(np.array([[0.4, 0.5], [0.3, 0.1]]).astype(np.float32)),
'desc_inputs': [Parameter(Tensor(np.array([[0.4, 0.5], [0.3, 0.1]]).astype(np.float32))),
Tensor(0.01, mstype.float32),
Tensor(0.88, mstype.float32),
Tensor(0.3, mstype.float32),