diff --git a/docs/api/api_python/nn/mindspore.nn.Adam.rst b/docs/api/api_python/nn/mindspore.nn.Adam.rst index 2689400bf74..f2b7f7ba4c9 100644 --- a/docs/api/api_python/nn/mindspore.nn.Adam.rst +++ b/docs/api/api_python/nn/mindspore.nn.Adam.rst @@ -55,7 +55,7 @@ mindspore.nn.Adam .. include:: mindspore.nn.optim_group_param.rst .. include:: mindspore.nn.optim_group_lr.rst - .. include:: mindspore.nn.optim_group_weight_decay.rst + .. include:: mindspore.nn.optim_group_dynamic_weight_decay.rst .. include:: mindspore.nn.optim_group_gc.rst .. include:: mindspore.nn.optim_group_order.rst @@ -68,8 +68,10 @@ mindspore.nn.Adam - **eps** (float) - 将添加到分母中,以提高数值稳定性。必须大于0。默认值:1e-8。 - **use_locking** (bool) - 是否对参数更新加锁保护。如果为True,则 `w` 、`m` 和 `v` 的tensor更新将受到锁的保护。如果为False,则结果不可预测。默认值:False。 - **use_nesterov** (bool) - 是否使用Nesterov Accelerated Gradient (NAG)算法更新梯度。如果为True,使用NAG更新梯度。如果为False,则在不使用NAG的情况下更新梯度。默认值:False。 - - **weight_decay** (float) - 权重衰减(L2 penalty)。必须大于等于0。默认值:0.0。 + - **weight_decay** (Union[float, int, Cell]) - 权重衰减(L2 penalty)。默认值:0.0。 + .. include:: mindspore.nn.optim_arg_dynamic_wd.rst + .. include:: mindspore.nn.optim_arg_loss_scale.rst **输入:** diff --git a/docs/api/api_python/nn/mindspore.nn.AdamOffload.rst b/docs/api/api_python/nn/mindspore.nn.AdamOffload.rst index 5e57017ae33..eaf95bfeaab 100644 --- a/docs/api/api_python/nn/mindspore.nn.AdamOffload.rst +++ b/docs/api/api_python/nn/mindspore.nn.AdamOffload.rst @@ -30,7 +30,7 @@ mindspore.nn.AdamOffload .. include:: mindspore.nn.optim_group_param.rst .. include:: mindspore.nn.optim_group_lr.rst - .. include:: mindspore.nn.optim_group_weight_decay.rst + .. include:: mindspore.nn.optim_group_dynamic_weight_decay.rst .. include:: mindspore.nn.optim_group_order.rst - **learning_rate** (Union[float, Tensor, Iterable, LearningRateSchedule]): 默认值:1e-3。 @@ -42,8 +42,10 @@ mindspore.nn.AdamOffload - **eps** (float) - 将添加到分母中,以提高数值稳定性。必须大于0。默认值:1e-8。 - **use_locking** (bool) - 是否对参数更新加锁保护。如果为True,则 `w` 、`m` 和 `v` 的更新将受到锁保护。如果为False,则结果不可预测。默认值:False。 - **use_nesterov** (bool) - 是否使用Nesterov Accelerated Gradient (NAG)算法更新梯度。如果为True,使用NAG更新梯度。如果为False,则在不使用NAG的情况下更新梯度。默认值:False。 - - **weight_decay** (float) - 权重衰减(L2 penalty)。必须大于等于0。默认值:0.0。 + - **weight_decay** (Union[float, int, Cell]) - 权重衰减(L2 penalty)。默认值:0.0。 + .. include:: mindspore.nn.optim_arg_dynamic_wd.rst + .. include:: mindspore.nn.optim_arg_loss_scale.rst **输入:** diff --git a/docs/api/api_python/nn/mindspore.nn.AdamWeightDecay.rst b/docs/api/api_python/nn/mindspore.nn.AdamWeightDecay.rst index 4438e22cecb..7bceaf09164 100644 --- a/docs/api/api_python/nn/mindspore.nn.AdamWeightDecay.rst +++ b/docs/api/api_python/nn/mindspore.nn.AdamWeightDecay.rst @@ -51,7 +51,7 @@ mindspore.nn.AdamWeightDecay .. include:: mindspore.nn.optim_group_lr.rst - .. include:: mindspore.nn.optim_group_weight_decay.rst + .. include:: mindspore.nn.optim_group_dynamic_weight_decay.rst .. include:: mindspore.nn.optim_group_order.rst @@ -63,7 +63,9 @@ mindspore.nn.AdamWeightDecay - **beta1** (float):`moment1` 的指数衰减率。参数范围(0.0,1.0)。默认值:0.9。 - **beta2** (float):`moment2` 的指数衰减率。参数范围(0.0,1.0)。默认值:0.999。 - **eps** (float) - 将添加到分母中,以提高数值稳定性。必须大于0。默认值:1e-6。 - - **weight_decay** (float) - 权重衰减(L2 penalty)。必须大于等于0。默认值:0.0。 + - **weight_decay** (Union[float, int, Cell]) - 权重衰减(L2 penalty)。默认值:0.0。 + + .. include:: mindspore.nn.optim_arg_dynamic_wd.rst **输入:** diff --git a/docs/api/api_python/nn/mindspore.nn.FTRL.rst b/docs/api/api_python/nn/mindspore.nn.FTRL.rst index 8db976ee94a..8252c12bd50 100644 --- a/docs/api/api_python/nn/mindspore.nn.FTRL.rst +++ b/docs/api/api_python/nn/mindspore.nn.FTRL.rst @@ -38,7 +38,7 @@ mindspore.nn.FTRL - **lr** - 学习率当前不支持参数分组。 - .. include:: mindspore.nn.optim_group_weight_decay.rst + .. include:: mindspore.nn.optim_group_dynamic_weight_decay.rst .. include:: mindspore.nn.optim_group_gc.rst @@ -53,7 +53,9 @@ mindspore.nn.FTRL .. include:: mindspore.nn.optim_arg_loss_scale.rst - - **weight_decay** (Union[float, int]) - 要乘以权重的权重衰减值,必须为零或正值。默认值:0.0。 + - **weight_decay** (Union[float, int, Cell]) - 权重衰减(L2 penalty)。默认值:0.0。 + + .. include:: mindspore.nn.optim_arg_dynamic_wd.rst **输入:** diff --git a/docs/api/api_python/nn/mindspore.nn.Lamb.rst b/docs/api/api_python/nn/mindspore.nn.Lamb.rst index f81935dd948..89ac46441d0 100644 --- a/docs/api/api_python/nn/mindspore.nn.Lamb.rst +++ b/docs/api/api_python/nn/mindspore.nn.Lamb.rst @@ -61,7 +61,7 @@ mindspore.nn.Lamb .. include:: mindspore.nn.optim_group_param.rst .. include:: mindspore.nn.optim_group_lr.rst - .. include:: mindspore.nn.optim_group_weight_decay.rst + .. include:: mindspore.nn.optim_group_dynamic_weight_decay.rst .. include:: mindspore.nn.optim_group_gc.rst .. include:: mindspore.nn.optim_group_order.rst @@ -72,7 +72,9 @@ mindspore.nn.Lamb - **beta1** (float):第一矩的指数衰减率。参数范围(0.0,1.0)。默认值:0.9。 - **beta2** (float):第二矩的指数衰减率。参数范围(0.0,1.0)。默认值:0.999。 - **eps** (float) - 将添加到分母中,以提高数值稳定性。必须大于0。默认值:1e-6。 - - **weight_decay** (float) - 权重衰减(L2 penalty)。必须大于等于0。默认值:0.0。 + - **weight_decay** (Union[float, int, Cell]) - 权重衰减(L2 penalty)。默认值:0.0。 + + .. include:: mindspore.nn.optim_arg_dynamic_wd.rst **输入:** diff --git a/docs/api/api_python/nn/mindspore.nn.LazyAdam.rst b/docs/api/api_python/nn/mindspore.nn.LazyAdam.rst index 271450a2849..0163e35b3b1 100644 --- a/docs/api/api_python/nn/mindspore.nn.LazyAdam.rst +++ b/docs/api/api_python/nn/mindspore.nn.LazyAdam.rst @@ -33,7 +33,7 @@ mindspore.nn.LazyAdam .. include:: mindspore.nn.optim_group_param.rst .. include:: mindspore.nn.optim_group_lr.rst - .. include:: mindspore.nn.optim_group_weight_decay.rst + .. include:: mindspore.nn.optim_group_dynamic_weight_decay.rst .. include:: mindspore.nn.optim_group_gc.rst .. include:: mindspore.nn.optim_group_order.rst @@ -46,8 +46,10 @@ mindspore.nn.LazyAdam - **eps** (float) - 将添加到分母中,以提高数值稳定性。必须大于0。默认值:1e-8。 - **use_locking** (bool) - 是否对参数更新加锁保护。如果为True,则 `w` 、`m` 和 `v` 的Tensor更新将受到锁的保护。如果为False,则结果不可预测。默认值:False。 - **use_nesterov** (bool) - 是否使用Nesterov Accelerated Gradient (NAG)算法更新梯度。如果为True,使用NAG更新梯度。如果为False,则在不使用NAG的情况下更新梯度。默认值:False。 - - **weight_decay** (Union[float, int]) - 权重衰减(L2 penalty)。必须大于等于0。默认值:0.0。 + - **weight_decay** (Union[float, int, Cell]) - 权重衰减(L2 penalty)。默认值:0.0。 + .. include:: mindspore.nn.optim_arg_dynamic_wd.rst + .. include:: mindspore.nn.optim_arg_loss_scale.rst **输入:** diff --git a/docs/api/api_python/nn/mindspore.nn.Momentum.rst b/docs/api/api_python/nn/mindspore.nn.Momentum.rst index 3c866e084d5..5904ee74947 100644 --- a/docs/api/api_python/nn/mindspore.nn.Momentum.rst +++ b/docs/api/api_python/nn/mindspore.nn.Momentum.rst @@ -31,7 +31,7 @@ mindspore.nn.Momentum .. include:: mindspore.nn.optim_group_param.rst .. include:: mindspore.nn.optim_group_lr.rst - .. include:: mindspore.nn.optim_group_weight_decay.rst + .. include:: mindspore.nn.optim_group_dynamic_weight_decay.rst .. include:: mindspore.nn.optim_group_gc.rst .. include:: mindspore.nn.optim_group_order.rst @@ -40,7 +40,9 @@ mindspore.nn.Momentum .. include:: mindspore.nn.optim_arg_dynamic_lr.rst - **momentum** (float) - 浮点数类型的超参,表示移动平均的动量。必须等于或大于0.0。 - - **weight_decay** (int, float) - 权重衰减(L2 penalty)值。必须大于等于0.0。默认值:0.0。 + - **weight_decay** (Union[float, int, Cell]) - 权重衰减(L2 penalty)。默认值:0.0。 + + .. include:: mindspore.nn.optim_arg_dynamic_wd.rst .. include:: mindspore.nn.optim_arg_loss_scale.rst diff --git a/docs/api/api_python/nn/mindspore.nn.ProximalAdagrad.rst b/docs/api/api_python/nn/mindspore.nn.ProximalAdagrad.rst index ad792247bab..81c53100cd1 100644 --- a/docs/api/api_python/nn/mindspore.nn.ProximalAdagrad.rst +++ b/docs/api/api_python/nn/mindspore.nn.ProximalAdagrad.rst @@ -30,7 +30,7 @@ mindspore.nn.ProximalAdagrad .. include:: mindspore.nn.optim_group_param.rst .. include:: mindspore.nn.optim_group_lr.rst - .. include:: mindspore.nn.optim_group_weight_decay.rst + .. include:: mindspore.nn.optim_group_dynamic_weight_decay.rst .. include:: mindspore.nn.optim_group_gc.rst .. include:: mindspore.nn.optim_group_order.rst @@ -46,7 +46,9 @@ mindspore.nn.ProximalAdagrad .. include:: mindspore.nn.optim_arg_loss_scale.rst - - **weight_decay** (Union[float, int]) - 要乘以权重的权重衰减值,必须为零或正值。默认值:0.0。 + - **weight_decay** (Union[float, int, Cell]) - 权重衰减(L2 penalty)。默认值:0.0。 + + .. include:: mindspore.nn.optim_arg_dynamic_wd.rst **输入:** diff --git a/docs/api/api_python/nn/mindspore.nn.RMSProp.rst b/docs/api/api_python/nn/mindspore.nn.RMSProp.rst index 82161fd9c3d..8b53bdf964c 100644 --- a/docs/api/api_python/nn/mindspore.nn.RMSProp.rst +++ b/docs/api/api_python/nn/mindspore.nn.RMSProp.rst @@ -52,7 +52,7 @@ mindspore.nn.RMSProp .. include:: mindspore.nn.optim_group_param.rst .. include:: mindspore.nn.optim_group_lr.rst - .. include:: mindspore.nn.optim_group_weight_decay.rst + .. include:: mindspore.nn.optim_group_dynamic_weight_decay.rst .. include:: mindspore.nn.optim_group_gc.rst .. include:: mindspore.nn.optim_group_order.rst @@ -68,7 +68,9 @@ mindspore.nn.RMSProp .. include:: mindspore.nn.optim_arg_loss_scale.rst - - **weight_decay** (Union[float, int]) - 权重衰减(L2 penalty)。必须大于等于0。默认值:0.0。 + - **weight_decay** (Union[float, int, Cell]) - 权重衰减(L2 penalty)。默认值:0.0。 + + .. include:: mindspore.nn.optim_arg_dynamic_wd.rst **输入:** diff --git a/docs/api/api_python/nn/mindspore.nn.optim_arg_dynamic_wd.rst b/docs/api/api_python/nn/mindspore.nn.optim_arg_dynamic_wd.rst new file mode 100644 index 00000000000..31478a4b48e --- /dev/null +++ b/docs/api/api_python/nn/mindspore.nn.optim_arg_dynamic_wd.rst @@ -0,0 +1,3 @@ +- **float**: 固定值,必须大于或者等于0。 +- **int**: 固定值,必须大于或者等于0,会被转换成float。 +- **Cell**: 动态weight decay。在训练过程中,优化器会使用步数(step)作为输入,调用该Cell实例来计算当前weight decay值。 diff --git a/docs/api/api_python/nn/mindspore.nn.optim_group_dynamic_weight_decay.rst b/docs/api/api_python/nn/mindspore.nn.optim_group_dynamic_weight_decay.rst new file mode 100644 index 00000000000..bfc5545e154 --- /dev/null +++ b/docs/api/api_python/nn/mindspore.nn.optim_group_dynamic_weight_decay.rst @@ -0,0 +1,3 @@ +- **weight_decay** - 可选。如果键中存在"weight_decay”,则使用对应的值作为权重衰减值。如果没有,则使用优化器中配置的 `weight_decay` 作为权重衰减值。 + 值得注意的是,`weight_decay`可以是常量,也可以是Cell类型。Cell类型的weight decay用于实现动态weight decay算法。这和动态学习率相似。 + 用户需要自定义一个输入为global step的weight_decay_schedule。在训练的过程中,优化器会调用WeightDecaySchedule的实例来获取当前step的weight decay值。