!31106 Revise optimizer CN docs
Merge pull request !31106 from wanyiming/code_docs_cn_opt2
This commit is contained in:
commit
41cc5a351a
|
@ -55,7 +55,7 @@ mindspore.nn.Adam
|
|||
|
||||
.. include:: mindspore.nn.optim_group_param.rst
|
||||
.. include:: mindspore.nn.optim_group_lr.rst
|
||||
.. include:: mindspore.nn.optim_group_weight_decay.rst
|
||||
.. include:: mindspore.nn.optim_group_dynamic_weight_decay.rst
|
||||
.. include:: mindspore.nn.optim_group_gc.rst
|
||||
.. include:: mindspore.nn.optim_group_order.rst
|
||||
|
||||
|
@ -68,8 +68,10 @@ mindspore.nn.Adam
|
|||
- **eps** (float) - 将添加到分母中,以提高数值稳定性。必须大于0。默认值:1e-8。
|
||||
- **use_locking** (bool) - 是否对参数更新加锁保护。如果为True,则 `w` 、`m` 和 `v` 的tensor更新将受到锁的保护。如果为False,则结果不可预测。默认值:False。
|
||||
- **use_nesterov** (bool) - 是否使用Nesterov Accelerated Gradient (NAG)算法更新梯度。如果为True,使用NAG更新梯度。如果为False,则在不使用NAG的情况下更新梯度。默认值:False。
|
||||
- **weight_decay** (float) - 权重衰减(L2 penalty)。必须大于等于0。默认值:0.0。
|
||||
- **weight_decay** (Union[float, int, Cell]) - 权重衰减(L2 penalty)。默认值:0.0。
|
||||
|
||||
.. include:: mindspore.nn.optim_arg_dynamic_wd.rst
|
||||
|
||||
.. include:: mindspore.nn.optim_arg_loss_scale.rst
|
||||
|
||||
**输入:**
|
||||
|
|
|
@ -30,7 +30,7 @@ mindspore.nn.AdamOffload
|
|||
|
||||
.. include:: mindspore.nn.optim_group_param.rst
|
||||
.. include:: mindspore.nn.optim_group_lr.rst
|
||||
.. include:: mindspore.nn.optim_group_weight_decay.rst
|
||||
.. include:: mindspore.nn.optim_group_dynamic_weight_decay.rst
|
||||
.. include:: mindspore.nn.optim_group_order.rst
|
||||
|
||||
- **learning_rate** (Union[float, Tensor, Iterable, LearningRateSchedule]): 默认值:1e-3。
|
||||
|
@ -42,8 +42,10 @@ mindspore.nn.AdamOffload
|
|||
- **eps** (float) - 将添加到分母中,以提高数值稳定性。必须大于0。默认值:1e-8。
|
||||
- **use_locking** (bool) - 是否对参数更新加锁保护。如果为True,则 `w` 、`m` 和 `v` 的更新将受到锁保护。如果为False,则结果不可预测。默认值:False。
|
||||
- **use_nesterov** (bool) - 是否使用Nesterov Accelerated Gradient (NAG)算法更新梯度。如果为True,使用NAG更新梯度。如果为False,则在不使用NAG的情况下更新梯度。默认值:False。
|
||||
- **weight_decay** (float) - 权重衰减(L2 penalty)。必须大于等于0。默认值:0.0。
|
||||
- **weight_decay** (Union[float, int, Cell]) - 权重衰减(L2 penalty)。默认值:0.0。
|
||||
|
||||
.. include:: mindspore.nn.optim_arg_dynamic_wd.rst
|
||||
|
||||
.. include:: mindspore.nn.optim_arg_loss_scale.rst
|
||||
|
||||
**输入:**
|
||||
|
|
|
@ -51,7 +51,7 @@ mindspore.nn.AdamWeightDecay
|
|||
|
||||
.. include:: mindspore.nn.optim_group_lr.rst
|
||||
|
||||
.. include:: mindspore.nn.optim_group_weight_decay.rst
|
||||
.. include:: mindspore.nn.optim_group_dynamic_weight_decay.rst
|
||||
|
||||
.. include:: mindspore.nn.optim_group_order.rst
|
||||
|
||||
|
@ -63,7 +63,9 @@ mindspore.nn.AdamWeightDecay
|
|||
- **beta1** (float):`moment1` 的指数衰减率。参数范围(0.0,1.0)。默认值:0.9。
|
||||
- **beta2** (float):`moment2` 的指数衰减率。参数范围(0.0,1.0)。默认值:0.999。
|
||||
- **eps** (float) - 将添加到分母中,以提高数值稳定性。必须大于0。默认值:1e-6。
|
||||
- **weight_decay** (float) - 权重衰减(L2 penalty)。必须大于等于0。默认值:0.0。
|
||||
- **weight_decay** (Union[float, int, Cell]) - 权重衰减(L2 penalty)。默认值:0.0。
|
||||
|
||||
.. include:: mindspore.nn.optim_arg_dynamic_wd.rst
|
||||
|
||||
**输入:**
|
||||
|
||||
|
|
|
@ -38,7 +38,7 @@ mindspore.nn.FTRL
|
|||
|
||||
- **lr** - 学习率当前不支持参数分组。
|
||||
|
||||
.. include:: mindspore.nn.optim_group_weight_decay.rst
|
||||
.. include:: mindspore.nn.optim_group_dynamic_weight_decay.rst
|
||||
|
||||
.. include:: mindspore.nn.optim_group_gc.rst
|
||||
|
||||
|
@ -53,7 +53,9 @@ mindspore.nn.FTRL
|
|||
|
||||
.. include:: mindspore.nn.optim_arg_loss_scale.rst
|
||||
|
||||
- **weight_decay** (Union[float, int]) - 要乘以权重的权重衰减值,必须为零或正值。默认值:0.0。
|
||||
- **weight_decay** (Union[float, int, Cell]) - 权重衰减(L2 penalty)。默认值:0.0。
|
||||
|
||||
.. include:: mindspore.nn.optim_arg_dynamic_wd.rst
|
||||
|
||||
**输入:**
|
||||
|
||||
|
|
|
@ -61,7 +61,7 @@ mindspore.nn.Lamb
|
|||
|
||||
.. include:: mindspore.nn.optim_group_param.rst
|
||||
.. include:: mindspore.nn.optim_group_lr.rst
|
||||
.. include:: mindspore.nn.optim_group_weight_decay.rst
|
||||
.. include:: mindspore.nn.optim_group_dynamic_weight_decay.rst
|
||||
.. include:: mindspore.nn.optim_group_gc.rst
|
||||
.. include:: mindspore.nn.optim_group_order.rst
|
||||
|
||||
|
@ -72,7 +72,9 @@ mindspore.nn.Lamb
|
|||
- **beta1** (float):第一矩的指数衰减率。参数范围(0.0,1.0)。默认值:0.9。
|
||||
- **beta2** (float):第二矩的指数衰减率。参数范围(0.0,1.0)。默认值:0.999。
|
||||
- **eps** (float) - 将添加到分母中,以提高数值稳定性。必须大于0。默认值:1e-6。
|
||||
- **weight_decay** (float) - 权重衰减(L2 penalty)。必须大于等于0。默认值:0.0。
|
||||
- **weight_decay** (Union[float, int, Cell]) - 权重衰减(L2 penalty)。默认值:0.0。
|
||||
|
||||
.. include:: mindspore.nn.optim_arg_dynamic_wd.rst
|
||||
|
||||
**输入:**
|
||||
|
||||
|
|
|
@ -33,7 +33,7 @@ mindspore.nn.LazyAdam
|
|||
|
||||
.. include:: mindspore.nn.optim_group_param.rst
|
||||
.. include:: mindspore.nn.optim_group_lr.rst
|
||||
.. include:: mindspore.nn.optim_group_weight_decay.rst
|
||||
.. include:: mindspore.nn.optim_group_dynamic_weight_decay.rst
|
||||
.. include:: mindspore.nn.optim_group_gc.rst
|
||||
.. include:: mindspore.nn.optim_group_order.rst
|
||||
|
||||
|
@ -46,8 +46,10 @@ mindspore.nn.LazyAdam
|
|||
- **eps** (float) - 将添加到分母中,以提高数值稳定性。必须大于0。默认值:1e-8。
|
||||
- **use_locking** (bool) - 是否对参数更新加锁保护。如果为True,则 `w` 、`m` 和 `v` 的Tensor更新将受到锁的保护。如果为False,则结果不可预测。默认值:False。
|
||||
- **use_nesterov** (bool) - 是否使用Nesterov Accelerated Gradient (NAG)算法更新梯度。如果为True,使用NAG更新梯度。如果为False,则在不使用NAG的情况下更新梯度。默认值:False。
|
||||
- **weight_decay** (Union[float, int]) - 权重衰减(L2 penalty)。必须大于等于0。默认值:0.0。
|
||||
- **weight_decay** (Union[float, int, Cell]) - 权重衰减(L2 penalty)。默认值:0.0。
|
||||
|
||||
.. include:: mindspore.nn.optim_arg_dynamic_wd.rst
|
||||
|
||||
.. include:: mindspore.nn.optim_arg_loss_scale.rst
|
||||
|
||||
**输入:**
|
||||
|
|
|
@ -31,7 +31,7 @@ mindspore.nn.Momentum
|
|||
|
||||
.. include:: mindspore.nn.optim_group_param.rst
|
||||
.. include:: mindspore.nn.optim_group_lr.rst
|
||||
.. include:: mindspore.nn.optim_group_weight_decay.rst
|
||||
.. include:: mindspore.nn.optim_group_dynamic_weight_decay.rst
|
||||
.. include:: mindspore.nn.optim_group_gc.rst
|
||||
.. include:: mindspore.nn.optim_group_order.rst
|
||||
|
||||
|
@ -40,7 +40,9 @@ mindspore.nn.Momentum
|
|||
.. include:: mindspore.nn.optim_arg_dynamic_lr.rst
|
||||
|
||||
- **momentum** (float) - 浮点数类型的超参,表示移动平均的动量。必须等于或大于0.0。
|
||||
- **weight_decay** (int, float) - 权重衰减(L2 penalty)值。必须大于等于0.0。默认值:0.0。
|
||||
- **weight_decay** (Union[float, int, Cell]) - 权重衰减(L2 penalty)。默认值:0.0。
|
||||
|
||||
.. include:: mindspore.nn.optim_arg_dynamic_wd.rst
|
||||
|
||||
.. include:: mindspore.nn.optim_arg_loss_scale.rst
|
||||
|
||||
|
|
|
@ -30,7 +30,7 @@ mindspore.nn.ProximalAdagrad
|
|||
|
||||
.. include:: mindspore.nn.optim_group_param.rst
|
||||
.. include:: mindspore.nn.optim_group_lr.rst
|
||||
.. include:: mindspore.nn.optim_group_weight_decay.rst
|
||||
.. include:: mindspore.nn.optim_group_dynamic_weight_decay.rst
|
||||
.. include:: mindspore.nn.optim_group_gc.rst
|
||||
.. include:: mindspore.nn.optim_group_order.rst
|
||||
|
||||
|
@ -46,7 +46,9 @@ mindspore.nn.ProximalAdagrad
|
|||
|
||||
.. include:: mindspore.nn.optim_arg_loss_scale.rst
|
||||
|
||||
- **weight_decay** (Union[float, int]) - 要乘以权重的权重衰减值,必须为零或正值。默认值:0.0。
|
||||
- **weight_decay** (Union[float, int, Cell]) - 权重衰减(L2 penalty)。默认值:0.0。
|
||||
|
||||
.. include:: mindspore.nn.optim_arg_dynamic_wd.rst
|
||||
|
||||
**输入:**
|
||||
|
||||
|
|
|
@ -52,7 +52,7 @@ mindspore.nn.RMSProp
|
|||
|
||||
.. include:: mindspore.nn.optim_group_param.rst
|
||||
.. include:: mindspore.nn.optim_group_lr.rst
|
||||
.. include:: mindspore.nn.optim_group_weight_decay.rst
|
||||
.. include:: mindspore.nn.optim_group_dynamic_weight_decay.rst
|
||||
.. include:: mindspore.nn.optim_group_gc.rst
|
||||
.. include:: mindspore.nn.optim_group_order.rst
|
||||
|
||||
|
@ -68,7 +68,9 @@ mindspore.nn.RMSProp
|
|||
|
||||
.. include:: mindspore.nn.optim_arg_loss_scale.rst
|
||||
|
||||
- **weight_decay** (Union[float, int]) - 权重衰减(L2 penalty)。必须大于等于0。默认值:0.0。
|
||||
- **weight_decay** (Union[float, int, Cell]) - 权重衰减(L2 penalty)。默认值:0.0。
|
||||
|
||||
.. include:: mindspore.nn.optim_arg_dynamic_wd.rst
|
||||
|
||||
**输入:**
|
||||
|
||||
|
|
|
@ -0,0 +1,3 @@
|
|||
- **float**: 固定值,必须大于或者等于0。
|
||||
- **int**: 固定值,必须大于或者等于0,会被转换成float。
|
||||
- **Cell**: 动态weight decay。在训练过程中,优化器会使用步数(step)作为输入,调用该Cell实例来计算当前weight decay值。
|
|
@ -0,0 +1,3 @@
|
|||
- **weight_decay** - 可选。如果键中存在"weight_decay”,则使用对应的值作为权重衰减值。如果没有,则使用优化器中配置的 `weight_decay` 作为权重衰减值。
|
||||
值得注意的是,`weight_decay`可以是常量,也可以是Cell类型。Cell类型的weight decay用于实现动态weight decay算法。这和动态学习率相似。
|
||||
用户需要自定义一个输入为global step的weight_decay_schedule。在训练的过程中,优化器会调用WeightDecaySchedule的实例来获取当前step的weight decay值。
|
Loading…
Reference in New Issue