mindspore/docs/api/api_python/ops/mindspore.ops.ApplyAdaMax.rst

48 lines
3.1 KiB
ReStructuredText
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

mindspore.ops.ApplyAdaMax
==========================
.. py:class:: mindspore.ops.ApplyAdaMax
根据AdaMax算法更新相关参数。
AdaMax优化器是参考 `Adam论文 <https://arxiv.org/abs/1412.6980>`_ 中Adamax优化相关内容所实现的。
更新公式如下:
.. math::
\begin{array}{ll} \\
m_{t+1} = \beta_1 * m_{t} + (1 - \beta_1) * g \\
v_{t+1} = \max(\beta_2 * v_{t}, \left| g \right|) \\
var = var - \frac{l}{1 - \beta_1^{t+1}} * \frac{m_{t+1}}{v_{t+1} + \epsilon}
\end{array}
:math:`t` 表示更新步长,而 :math:`m` 表示第一个动量矩阵, :math:`m_{t}`:math:`m_{t+1}` 的最后时刻, :math:`v` 代表第二个动量矩阵, :math:`v_{t}`:math:`v_{t+1}` 的最后时刻, :math:`l` 代表学习率 `lr` :math:`g` 代表 `grad` :math:`\beta_1, \beta_2` 代表 `beta1``beta2` :math:`\beta_1^{t+1}` 代表 `beta1_power` :math:`var` 代表要更新的变量, :math:`\epsilon` 代表 `epsilon`
`var``m``v``grad` 的输入符合隐式类型转换规则,使数据类型一致。如果它们具有不同的数据类型,则低精度数据类型将转换为相对最高精度的数据类型。
**输入:**
- **var** (Parameter) - 要更新的权重为任意维度。数据类型为float32或float16。
- **m** (Parameter) - 更新公式中的第一个动量矩阵shape和数据类型与 `var` 相同。数据类型为float32或float16。
- **v** (Parameter) - 更新公式中的第二个动量矩阵。shape和类型与 `var` 相同。数据类型为float32或float16。
- **beta1_power** (Union[Number, Tensor]) - :math:`beta_1^t` 必须是Scalar。数据类型为float32或float16。
- **lr** (Union[Number, Tensor]) - 学习率,公式中的 :math:`l` 必须是Scalar。数据类型为float32或float16。
- **beta1** (Union[Number, Tensor]) - 第一个动量矩阵的指数衰减率必须是Scalar。数据类型为float32或float16。
- **beta2** (Union[Number, Tensor]) - 第二个动量矩阵的指数衰减率必须是Scalar。数据类型为float32或float16。
- **epsilon** (Union[Number, Tensor]) - 加在分母上的值以确保数值稳定必须是Scalar。数据类型为float32或float16。
- **grad** (Tensor) - 为梯度是一个Tensorshape和数据类型与 `var` 相同。数据类型为float32或float16。
**输出:**
3个Tensor组成的tuple更新后的数据。
- **var** (Tensor) - shape和数据类型与 `var` 相同。
- **m** (Tensor) - shape和数据类型与 `m` 相同。
- **v** (Tensor) - shape和数据类型与 `v` 相同。
**异常:**
- **TypeError** - 如果 `var``m``v``beta_power``lr``beta1``beta2``epsilon``grad` 的数据类型既不是float16也不是float32。
- **TypeError** - 如果 `beta_power``lr``beta1``beta2``epsilon` 既不是数值型也不是Tensor。
- **TypeError** - 如果 `grad` 不是Tensor。
- **RuntimeError** - 如果 `var``m``v``grad` 不支持数据类型转换。