adam parameter

This commit is contained in:
nomindcarry 2023-03-03 11:57:18 +08:00
parent 8952325e13
commit 582fb6ab34
2 changed files with 8 additions and 8 deletions

View File

@ -26,9 +26,9 @@ mindspore.ops.Adam
- **use_nesterov** (bool) - 是否使用Nesterov Accelerated Gradient (NAG)算法更新梯度。如果为True则使用NAG更新梯度。如果为False则在不使用NAG的情况下更新梯度。默认值False。
输入:
- **var** (Tensor) - 需更新的权重。shape :math:`(N, *)` ,其中 :math:`*` 表示任意数量的附加维度其数据类型可以是float16或float32。
- **m** (Tensor) - 更新公式中的第一个动量矩阵shape和数据类型应与 `var` 相同。
- **v** (Tensor) - 更新公式中的第二个动量矩阵shape和数据类型应与 `var` 相同。均方梯度的数据类型也应与 `var` 相同。
- **var** (Parameter) - 需更新的权重。shape :math:`(N, *)` ,其中 :math:`*` 表示任意数量的附加维度其数据类型可以是float16或float32。
- **m** (Parameter) - 更新公式中的第一个动量矩阵shape和数据类型应与 `var` 相同。
- **v** (Parameter) - 更新公式中的第二个动量矩阵shape和数据类型应与 `var` 相同。均方梯度的数据类型也应与 `var` 相同。
- **beta1_power** (float) - 在更新公式中的 :math:`beta_1^t(\beta_1^{t})` ,数据类型值应与 `var` 相同。
- **beta2_power** (float) - 在更新公式中的 :math:`beta_2^t(\beta_2^{t})` ,数据类型值应与 `var` 相同。
- **lr** (float) - 在更新公式中的 :math:`l` 。其论文建议取值为 :math:`10^{-8}` ,数据类型应与 `var` 相同。
@ -46,5 +46,5 @@ mindspore.ops.Adam
异常:
- **TypeError** - `use_locking``use_nesterov` 都不是bool。
- **TypeError** - `var``m``v` 不是Tensor。
- **TypeError** - `var``m``v` 不是Parameter。
- **TypeError** - `beta1_power``beta2_power1``lr``beta1``beta2``epsilon``gradient` 不是Tensor。

View File

@ -4583,11 +4583,11 @@ class Adam(Primitive):
If false, update the gradients without using NAG. Default: False.
Inputs:
- **var** (Tensor) - Weights to be updated. The shape is :math:`(N, *)` where :math:`*` means,
- **var** (Parameter) - Weights to be updated. The shape is :math:`(N, *)` where :math:`*` means,
any number of additional dimensions. The data type can be float16 or float32.
- **m** (Tensor) - The 1st moment vector in the updating formula,
- **m** (Parameter) - The 1st moment vector in the updating formula,
the shape and data type value should be the same as `var`.
- **v** (Tensor) - the 2nd moment vector in the updating formula,
- **v** (Parameter) - the 2nd moment vector in the updating formula,
the shape and data type value should be the same as `var`. Mean square gradients with the same type as `var`.
- **beta1_power** (float) - :math:`beta_1^t(\beta_1^{t})` in the updating formula,
the data type value should be the same as `var`.
@ -4611,7 +4611,7 @@ class Adam(Primitive):
Raises:
TypeError: If neither `use_locking` nor `use_nesterov` is a bool.
TypeError: If `var`, `m` or `v` is not a Tensor.
TypeError: If `var`, `m` or `v` is not a Parameter.
TypeError: If `beta1_power`, `beta2_power1`, `lr`, `beta1`, `beta2`, `epsilon` or `gradient` is not a Tensor.
Supported Platforms: