forked from mindspore-Ecosystem/mindspore
!48353 fix parameter check and error of KLDiv when reduction='mean'
Merge pull request !48353 from yanzhenxiang2020/fix_parameter
This commit is contained in:
commit
091ac0b702
|
@ -5397,7 +5397,7 @@ class KLDivLoss(Primitive):
|
|||
elif device_target == "GPU":
|
||||
support_mode = ['none', 'mean', 'sum']
|
||||
elif device_target == "Ascend":
|
||||
support_mode = ['none', 'batchmean', 'sum']
|
||||
support_mode = ['none', 'batchmean', 'sum', 'mean']
|
||||
else:
|
||||
raise ValueError(f"'{self.name}' unknown device target: '{device_target}'")
|
||||
|
||||
|
@ -8689,10 +8689,10 @@ class SparseApplyCenteredRMSProp(Primitive):
|
|||
"""
|
||||
|
||||
__mindspore_signature__ = (
|
||||
sig.make_sig('var', dtype=sig.sig_dtype.T),
|
||||
sig.make_sig('mg', dtype=sig.sig_dtype.T),
|
||||
sig.make_sig('ms', dtype=sig.sig_dtype.T),
|
||||
sig.make_sig('mom', dtype=sig.sig_dtype.T),
|
||||
sig.make_sig('var', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
|
||||
sig.make_sig('mg', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
|
||||
sig.make_sig('ms', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
|
||||
sig.make_sig('mom', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
|
||||
sig.make_sig('lr', dtype=sig.sig_dtype.T),
|
||||
sig.make_sig('rho', dtype=sig.sig_dtype.T),
|
||||
sig.make_sig('momentum', dtype=sig.sig_dtype.T),
|
||||
|
@ -9957,7 +9957,7 @@ class SparseApplyProximalGradientDescent(Primitive):
|
|||
"""
|
||||
|
||||
__mindspore_signature__ = (
|
||||
sig.make_sig('var', dtype=sig.sig_dtype.T),
|
||||
sig.make_sig('var', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
|
||||
sig.make_sig('alpha', dtype=sig.sig_dtype.T),
|
||||
sig.make_sig('l1', dtype=sig.sig_dtype.T),
|
||||
sig.make_sig('l2', dtype=sig.sig_dtype.T),
|
||||
|
|
|
@ -3545,10 +3545,10 @@ test_case_nn_ops = [
|
|||
'skip': ['backward']}),
|
||||
('SparseApplyCenteredRMSProp', {
|
||||
'block': SparseApplyCenteredRMSPropNet(),
|
||||
'desc_inputs': [Tensor(np.array([[0.6, 0.4], [0.1, 0.5]]).astype(np.float32)),
|
||||
Tensor(np.array([[0.1, 0.3], [0.1, 0.5]]).astype(np.float32)),
|
||||
Tensor(np.array([[0.2, 0.1], [0.1, 0.2]]).astype(np.float32)),
|
||||
Tensor(np.array([[0.2, 0.1], [0.1, 0.2]]).astype(np.float32)),
|
||||
'desc_inputs': [Parameter(Tensor(np.array([[0.6, 0.4], [0.1, 0.5]]).astype(np.float32))),
|
||||
Parameter(Tensor(np.array([[0.1, 0.3], [0.1, 0.5]]).astype(np.float32))),
|
||||
Parameter(Tensor(np.array([[0.2, 0.1], [0.1, 0.2]]).astype(np.float32))),
|
||||
Parameter(Tensor(np.array([[0.2, 0.1], [0.1, 0.2]]).astype(np.float32))),
|
||||
Tensor(0.001, mstype.float32),
|
||||
Tensor(1e-10, mstype.float32),
|
||||
Tensor(0.001, mstype.float32),
|
||||
|
@ -3621,7 +3621,7 @@ test_case_nn_ops = [
|
|||
'skip': ['backward']}),
|
||||
('SparseApplyProximalGradientDescent', {
|
||||
'block': SparseApplyProximalGradientDescentNet(),
|
||||
'desc_inputs': [Tensor(np.array([[0.4, 0.5], [0.3, 0.1]]).astype(np.float32)),
|
||||
'desc_inputs': [Parameter(Tensor(np.array([[0.4, 0.5], [0.3, 0.1]]).astype(np.float32))),
|
||||
Tensor(0.01, mstype.float32),
|
||||
Tensor(0.88, mstype.float32),
|
||||
Tensor(0.3, mstype.float32),
|
||||
|
|
Loading…
Reference in New Issue