!37121 intercept non-Parameter input types for AdamWeightDecay

Merge pull request !37121 from 李林杰/fix_0703
This commit is contained in:
i-robot 2022-07-04 02:34:53 +00:00 committed by Gitee
commit acc7a58399
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
1 changed files with 14 additions and 3 deletions

View File

@ -4591,11 +4591,11 @@ class AdamWeightDecay(PrimitiveWithInfer):
If false, the result is unpredictable. Default: False.
Inputs:
- **var** (Tensor) - Weights to be updated. The shape is :math:`(N, *)` where :math:`*` means,
- **var** (Parameter) - Weights to be updated. The shape is :math:`(N, *)` where :math:`*` means,
any number of additional dimensions. The data type can be float16 or float32.
- **m** (Tensor) - The 1st moment vector in the updating formula,
- **m** (Parameter) - The 1st moment vector in the updating formula,
the shape and data type value should be the same as `var`.
- **v** (Tensor) - the 2nd moment vector in the updating formula,
- **v** (Parameter) - the 2nd moment vector in the updating formula,
the shape and data type value should be the same as `var`. Mean square gradients with the same type as `var`.
- **lr** (float) - :math:`l` in the updating formula. The paper suggested value is :math:`10^{-8}`,
the data type value should be the same as `var`.
@ -4639,6 +4639,17 @@ class AdamWeightDecay(PrimitiveWithInfer):
[[0.999 0.999]
[0.999 0.999]]
"""
__mindspore_signature__ = (
sig.make_sig('var', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
sig.make_sig('m', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
sig.make_sig('v', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
sig.make_sig('lr', dtype=sig.sig_dtype.T),
sig.make_sig('beta1', dtype=sig.sig_dtype.T),
sig.make_sig('beta2', dtype=sig.sig_dtype.T),
sig.make_sig('epsilon', dtype=sig.sig_dtype.T),
sig.make_sig('decay', dtype=sig.sig_dtype.T),
sig.make_sig('gradient', dtype=sig.sig_dtype.T)
)
@prim_attr_register
def __init__(self, use_locking=False):