!3915 Revert modification of opt

Merge pull request !3915 from Simson/push-to-opensource
This commit is contained in:
mindspore-ci-bot 2020-08-04 15:59:20 +08:00 committed by Gitee
commit 8b396cea98
10 changed files with 23 additions and 22 deletions

View File

@ -40,7 +40,7 @@ def _update_run_op(beta1, beta2, eps, lr, weight_decay, param, m, v, gradient, d
beta2 (Tensor): The exponential decay rate for the 2nd moment estimates. Should be in range (0.0, 1.0).
eps (Tensor): Term added to the denominator to improve numerical stability. Should be greater than 0.
lr (Tensor): Learning rate.
weight_decay (Number): Weight decay. Should be in range [0.0, 1.0].
weight_decay (Number): Weight decay. Should be equal to or greater than 0.
param (Tensor): Parameters.
m (Tensor): m value of parameters.
v (Tensor): v value of parameters.
@ -200,8 +200,8 @@ class Adam(Optimizer):
use_nesterov (bool): Whether to use Nesterov Accelerated Gradient (NAG) algorithm to update the gradients.
If True, updates the gradients using NAG.
If False, updates the gradients without using NAG. Default: False.
weight_decay (float): Weight decay (L2 penalty). It should be in range [0.0, 1.0]. Default: 0.0.
loss_scale (float): A floating point value for the loss scale. Should be not less than 1.0. Default: 1.0.
weight_decay (float): Weight decay (L2 penalty). It should be equal to or greater than 0. Default: 0.0.
loss_scale (float): A floating point value for the loss scale. Should be greater than 0. Default: 1.0.
Inputs:
- **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.
@ -318,7 +318,7 @@ class AdamWeightDecay(Optimizer):
Should be in range (0.0, 1.0).
eps (float): Term added to the denominator to improve numerical stability. Default: 1e-6.
Should be greater than 0.
weight_decay (float): Weight decay (L2 penalty). It should be in range [0.0, 1.0]. Default: 0.0.
weight_decay (float): Weight decay (L2 penalty). It should be equal to or greater than 0. Default: 0.0.
Inputs:
- **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.

View File

@ -116,7 +116,7 @@ class FTRL(Optimizer):
l2 (float): l2 regularization strength, must be greater than or equal to zero. Default: 0.0.
use_locking (bool): If True use locks for update operation. Default: False.
loss_scale (float): Value for the loss scale. It should be equal to or greater than 1.0. Default: 1.0.
weight_decay (float): Weight decay value to multiply weight, should be in range [0.0, 1.0]. Default: 0.0.
weight_decay (float): Weight decay value to multiply weight, must be zero or positive value. Default: 0.0.
Inputs:
- **grads** (tuple[Tensor]) - The gradients of `params` in optimizer, the shape is as same as the `params`

View File

@ -43,7 +43,7 @@ def _update_run_op(beta1, beta2, eps, global_step, lr, weight_decay, param, m, v
beta2 (Tensor): The exponential decay rate for the 2nd moment estimates. Should be in range (0.0, 1.0).
eps (Tensor): Term added to the denominator to improve numerical stability. Should be greater than 0.
lr (Tensor): Learning rate.
weight_decay (Number): Weight decay. Should be in range [0.0, 1.0].
weight_decay (Number): Weight decay. Should be equal to or greater than 0.
global_step (Tensor): Global step.
param (Tensor): Parameters.
m (Tensor): m value of parameters.
@ -126,7 +126,7 @@ def _update_run_op_graph_kernel(beta1, beta2, eps, global_step, lr, weight_decay
beta2 (Tensor): The exponential decay rate for the 2nd moment estimates. Should be in range (0.0, 1.0).
eps (Tensor): Term added to the denominator to improve numerical stability. Should be greater than 0.
lr (Tensor): Learning rate.
weight_decay (Number): Weight decay. Should be in range [0.0, 1.0].
weight_decay (Number): Weight decay. Should be equal to or greater than 0.
global_step (Tensor): Global step.
param (Tensor): Parameters.
m (Tensor): m value of parameters.
@ -227,7 +227,7 @@ class Lamb(Optimizer):
Should be in range (0.0, 1.0).
eps (float): Term added to the denominator to improve numerical stability. Default: 1e-6.
Should be greater than 0.
weight_decay (float): Weight decay (L2 penalty). Default: 0.0. Should be in range [0.0, 1.0].
weight_decay (float): Weight decay (L2 penalty). Default: 0.0. Should be equal to or greater than 0.
Inputs:
- **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.

View File

@ -133,7 +133,7 @@ class LazyAdam(Optimizer):
If True, updates the gradients using NAG.
If False, updates the gradients without using NAG. Default: False.
weight_decay (float): Weight decay (L2 penalty). Default: 0.0.
loss_scale (float): A floating point value for the loss scale. It should be not less than 1.0. Default:
loss_scale (float): A floating point value for the loss scale. Should be equal to or greater than 1. Default:
1.0.
Inputs:

View File

@ -92,8 +92,8 @@ class Momentum(Optimizer):
equal to or greater than 0. If the type of `learning_rate` is int, it will be converted to float.
momentum (float): Hyperparameter of type float, means momentum for the moving average.
It should be at least 0.0.
weight_decay (int, float): Weight decay (L2 penalty). It should be in range [0.0, 1.0]. Default: 0.0.
loss_scale (int, float): A floating point value for the loss scale. Should be not less than 1.0. Default: 1.0.
weight_decay (int, float): Weight decay (L2 penalty). It should be equal to or greater than 0.0. Default: 0.0.
loss_scale (int, float): A floating point value for the loss scale. It should be greater than 0.0. Default: 1.0.
use_nesterov (bool): Enable Nesterov momentum. Default: False.
Inputs:

View File

@ -78,9 +78,9 @@ class Optimizer(Cell):
the order will be followed in optimizer. There are no other keys in the `dict` and the parameters which
in the value of 'order_params' should be in one of group parameters.
weight_decay (float): A floating point value for the weight decay. It should be in range [0.0, 1.0].
weight_decay (float): A floating point value for the weight decay. It should be equal to or greater than 0.
If the type of `weight_decay` input is int, it will be converted to float. Default: 0.0.
loss_scale (float): A floating point value for the loss scale. It should be not less than 1.0. If the
loss_scale (float): A floating point value for the loss scale. It should be greater than 0. If the
type of `loss_scale` input is int, it will be converted to float. Default: 1.0.
Raises:
@ -102,7 +102,7 @@ class Optimizer(Cell):
if isinstance(loss_scale, int):
loss_scale = float(loss_scale)
validator.check_value_type("loss_scale", loss_scale, [float], self.cls_name)
validator.check_number_range("loss_scale", loss_scale, 1.0, float("inf"), Rel.INC_LEFT, self.cls_name)
validator.check_number_range("loss_scale", loss_scale, 0.0, float("inf"), Rel.INC_NEITHER, self.cls_name)
self.loss_scale = loss_scale
weight_decay = self._preprocess_weight_decay(weight_decay)

View File

@ -98,8 +98,8 @@ class ProximalAdagrad(Optimizer):
l1 (float): l1 regularization strength, must be greater than or equal to zero. Default: 0.0.
l2 (float): l2 regularization strength, must be greater than or equal to zero. Default: 0.0.
use_locking (bool): If True use locks for update operation. Default: False.
loss_scale (float): Value for the loss scale. It should be not less than 1.0. Default: 1.0.
weight_decay (float): Weight decay value to multiply weight, should be in range [0.0, 1.0]. Default: 0.0.
loss_scale (float): Value for the loss scale. It should be greater than 0.0. Default: 1.0.
weight_decay (float): Weight decay value to multiply weight, must be zero or positive value. Default: 0.0.
Inputs:
- **grads** (tuple[Tensor]) - The gradients of `params` in optimizer, the shape is as same as the `params`

View File

@ -121,8 +121,8 @@ class RMSProp(Optimizer):
0. Default: 1e-10.
use_locking (bool): Enable a lock to protect the update of variable and accumlation tensors. Default: False.
centered (bool): If True, gradients are normalized by the estimated variance of the gradient. Default: False.
loss_scale (float): A floating point value for the loss scale. Should be not less than 1.0. Default: 1.0.
weight_decay (float): Weight decay (L2 penalty). Should be in range [0.0, 1.0]. Default: 0.0.
loss_scale (float): A floating point value for the loss scale. Should be greater than 0. Default: 1.0.
weight_decay (float): Weight decay (L2 penalty). Should be equal to or greater than 0. Default: 0.0.
Inputs:
- **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.

View File

@ -88,10 +88,11 @@ class SGD(Optimizer):
Default: 0.1.
momentum (float): A floating point value the momentum. should be at least 0.0. Default: 0.0.
dampening (float): A floating point value of dampening for momentum. should be at least 0.0. Default: 0.0.
weight_decay (float): Weight decay (L2 penalty). It should be in range [0.0, 1.0]. Default: 0.0.
weight_decay (float): Weight decay (L2 penalty). It should be equal to or greater than 0. Default: 0.0.
nesterov (bool): Enables the Nesterov momentum. If use nesterov, momentum must be positive,
and dampening must equal to 0.0. Default: False.
loss_scale (float): A floating point value for the loss scale. Should be not less than 1.0. Default: 1.0.
loss_scale (float): A floating point value for the loss scale, which should be larger
than 0.0. Default: 1.0.
Inputs:
- **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.

View File

@ -98,7 +98,7 @@ def test_momentum_with_loss_scale():
net = Net(strategy1, strategy2, weight)
optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9, loss_scale=1.0)
optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9, loss_scale=0.5)
net_with_loss = NetWithLoss(net, strategy3)
@ -169,7 +169,7 @@ def test_momentum_with_loss_scale_and_dynamic_lr():
net = Net(strategy1, strategy2, weight)
lr = Tensor(np.ones([6]), dtype=ms.float32)
optimizer = Momentum(net.trainable_params(), learning_rate=lr, momentum=0.9, loss_scale=1.0)
optimizer = Momentum(net.trainable_params(), learning_rate=lr, momentum=0.9, loss_scale=0.5)
net_with_loss = NetWithLoss(net, strategy3)