forked from mindspore-Ecosystem/mindspore
adam parameter
This commit is contained in:
parent
8952325e13
commit
582fb6ab34
|
@ -26,9 +26,9 @@ mindspore.ops.Adam
|
|||
- **use_nesterov** (bool) - 是否使用Nesterov Accelerated Gradient (NAG)算法更新梯度。如果为True,则使用NAG更新梯度。如果为False,则在不使用NAG的情况下更新梯度。默认值:False。
|
||||
|
||||
输入:
|
||||
- **var** (Tensor) - 需更新的权重。shape: :math:`(N, *)` ,其中 :math:`*` 表示任意数量的附加维度,其数据类型可以是float16或float32。
|
||||
- **m** (Tensor) - 更新公式中的第一个动量矩阵,shape和数据类型应与 `var` 相同。
|
||||
- **v** (Tensor) - 更新公式中的第二个动量矩阵,shape和数据类型应与 `var` 相同。均方梯度的数据类型也应与 `var` 相同。
|
||||
- **var** (Parameter) - 需更新的权重。shape: :math:`(N, *)` ,其中 :math:`*` 表示任意数量的附加维度,其数据类型可以是float16或float32。
|
||||
- **m** (Parameter) - 更新公式中的第一个动量矩阵,shape和数据类型应与 `var` 相同。
|
||||
- **v** (Parameter) - 更新公式中的第二个动量矩阵,shape和数据类型应与 `var` 相同。均方梯度的数据类型也应与 `var` 相同。
|
||||
- **beta1_power** (float) - 在更新公式中的 :math:`beta_1^t(\beta_1^{t})` ,数据类型值应与 `var` 相同。
|
||||
- **beta2_power** (float) - 在更新公式中的 :math:`beta_2^t(\beta_2^{t})` ,数据类型值应与 `var` 相同。
|
||||
- **lr** (float) - 在更新公式中的 :math:`l` 。其论文建议取值为 :math:`10^{-8}` ,数据类型应与 `var` 相同。
|
||||
|
@ -46,5 +46,5 @@ mindspore.ops.Adam
|
|||
|
||||
异常:
|
||||
- **TypeError** - `use_locking` 和 `use_nesterov` 都不是bool。
|
||||
- **TypeError** - `var` 、 `m` 或 `v` 不是Tensor。
|
||||
- **TypeError** - `var` 、 `m` 或 `v` 不是Parameter。
|
||||
- **TypeError** - `beta1_power` 、 `beta2_power1` 、 `lr` 、 `beta1` 、 `beta2` 、 `epsilon` 或 `gradient` 不是Tensor。
|
||||
|
|
|
@ -4583,11 +4583,11 @@ class Adam(Primitive):
|
|||
If false, update the gradients without using NAG. Default: False.
|
||||
|
||||
Inputs:
|
||||
- **var** (Tensor) - Weights to be updated. The shape is :math:`(N, *)` where :math:`*` means,
|
||||
- **var** (Parameter) - Weights to be updated. The shape is :math:`(N, *)` where :math:`*` means,
|
||||
any number of additional dimensions. The data type can be float16 or float32.
|
||||
- **m** (Tensor) - The 1st moment vector in the updating formula,
|
||||
- **m** (Parameter) - The 1st moment vector in the updating formula,
|
||||
the shape and data type value should be the same as `var`.
|
||||
- **v** (Tensor) - the 2nd moment vector in the updating formula,
|
||||
- **v** (Parameter) - the 2nd moment vector in the updating formula,
|
||||
the shape and data type value should be the same as `var`. Mean square gradients with the same type as `var`.
|
||||
- **beta1_power** (float) - :math:`beta_1^t(\beta_1^{t})` in the updating formula,
|
||||
the data type value should be the same as `var`.
|
||||
|
@ -4611,7 +4611,7 @@ class Adam(Primitive):
|
|||
|
||||
Raises:
|
||||
TypeError: If neither `use_locking` nor `use_nesterov` is a bool.
|
||||
TypeError: If `var`, `m` or `v` is not a Tensor.
|
||||
TypeError: If `var`, `m` or `v` is not a Parameter.
|
||||
TypeError: If `beta1_power`, `beta2_power1`, `lr`, `beta1`, `beta2`, `epsilon` or `gradient` is not a Tensor.
|
||||
|
||||
Supported Platforms:
|
||||
|
|
Loading…
Reference in New Issue