forked from mindspore-Ecosystem/mindspore
!709 fix weight decay error in optimizer AdamWeightDecay
Merge pull request !709 from wangnan39/fix_bug_in_adamweightdecay
This commit is contained in:
commit
1f2ca74cd1
|
@ -31,8 +31,8 @@ _learning_rate_update_func = ['linear', 'cos', 'sin']
|
|||
adam_opt = C.MultitypeFuncGraph("adam_opt")
|
||||
|
||||
|
||||
@adam_opt.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor")
|
||||
def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, param, m, v, gradient):
|
||||
@adam_opt.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Bool")
|
||||
def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, param, m, v, gradient, decay_flag):
|
||||
"""
|
||||
Update parameters.
|
||||
|
||||
|
@ -67,7 +67,8 @@ def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, param, m, v, grad
|
|||
next_v = op_mul(beta2, v) + op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32) - beta2, op_square(gradient))
|
||||
|
||||
update = next_m / (op_sqrt(next_v) + eps)
|
||||
update = update + op_mul(weight_decay_tensor, param)
|
||||
if decay_flag:
|
||||
update = update + op_mul(weight_decay_tensor, param)
|
||||
|
||||
update_with_lr = op_mul(lr, update)
|
||||
next_param = param - op_reshape(update_with_lr, op_shape(param))
|
||||
|
@ -90,6 +91,17 @@ def _check_param_value(beta1, beta2, eps, weight_decay, prim_name):
|
|||
validator.check_number_range("weight_decay", weight_decay, 0.0, float("inf"), Rel.INC_LEFT, prim_name)
|
||||
|
||||
|
||||
def _check_learning_rate_value(learning_rate, end_learning_rate, decay_steps, power, prim_name):
|
||||
"""Check the type of inputs."""
|
||||
validator.check_float_positive('learning_rate', learning_rate, prim_name)
|
||||
validator.check_float_legal_value('learning_rate', learning_rate, prim_name)
|
||||
validator.check_float_positive('end_learning_rate', end_learning_rate, prim_name)
|
||||
validator.check_float_legal_value('end_learning_rate', end_learning_rate, prim_name)
|
||||
validator.check_float_positive('power', power, prim_name)
|
||||
validator.check_float_legal_value('power', power, prim_name)
|
||||
validator.check_integer('decay_steps', decay_steps, 0, Rel.GT, prim_name)
|
||||
|
||||
|
||||
@adam_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tensor", "Tensor",
|
||||
"Tensor")
|
||||
def _run_opt_with_one_number(opt, lr, beta1_power, beta2_power, beta1, beta2, eps, gradient, params, moment1,
|
||||
|
@ -126,8 +138,13 @@ class Adam(Optimizer):
|
|||
Args:
|
||||
params (list[Parameter]): A list of parameter, which will be updated. The element in `params`
|
||||
should be class mindspore.Parameter.
|
||||
learning_rate (Union[float, Tensor, Iterable]): The Learning rate.
|
||||
Iterable type is used for the dynamic learning rate.
|
||||
learning_rate (Union[float, Tensor, Iterable]): A value for the learning rate. When the learning_rate is
|
||||
Iterable or a Tensor and the dims of the Tensor is 1,
|
||||
use dynamic learning rate, then the i-th step will
|
||||
take the i-th value as the learning rate.
|
||||
When the learning_rate is float or learning_rate is a Tensor
|
||||
but the dims of the Tensor is 0, use fixed learning rate.
|
||||
Other cases are not supported. Default: 1e-3.
|
||||
beta1 (float): The exponential decay rate for the 1st moment estimates. Should be in range (0.0, 1.0).
|
||||
beta2 (float): The exponential decay rate for the 2nd moment estimates. Should be in range (0.0, 1.0).
|
||||
eps (float): Term added to the denominator to improve numerical stability. Should be greater than 0.
|
||||
|
@ -140,6 +157,8 @@ class Adam(Optimizer):
|
|||
weight_decay (float): Weight decay (L2 penalty). Default: 0.0.
|
||||
loss_scale (float): A floating point value for the loss scale. Default: 1.0.
|
||||
Should be equal to or greater than 1.
|
||||
decay_filter (Function): A function to determine whether to apply weight decay on parameters. Default:
|
||||
lambda x: 'LayerNorm' not in x.name and 'bias' not in x.name.
|
||||
|
||||
Inputs:
|
||||
- **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.
|
||||
|
@ -207,7 +226,13 @@ class AdamWeightDecay(Optimizer):
|
|||
Args:
|
||||
params (list[Parameter]): A list of parameter, which will be updated. The element in `params`
|
||||
should be class mindspore.Parameter.
|
||||
learning_rate (float): A floating point value for the learning rate. Default: 1e-3.
|
||||
learning_rate (Union[float, Tensor, Iterable]): A value for the learning rate. When the learning_rate is
|
||||
Iterable or a Tensor and the dims of the Tensor is 1,
|
||||
use dynamic learning rate, then the i-th step will
|
||||
take the i-th value as the learning rate.
|
||||
When the learning_rate is float or learning_rate is a Tensor
|
||||
but the dims of the Tensor is 0, use fixed learning rate.
|
||||
Other cases are not supported. Default: 1e-3.
|
||||
beta1 (float): The exponential decay rate for the 1st moment estimates. Default: 0.9.
|
||||
Should be in range (0.0, 1.0).
|
||||
beta2 (float): The exponential decay rate for the 2nd moment estimates. Default: 0.999.
|
||||
|
@ -215,6 +240,8 @@ class AdamWeightDecay(Optimizer):
|
|||
eps (float): Term added to the denominator to improve numerical stability. Default: 1e-6.
|
||||
Should be greater than 0.
|
||||
weight_decay (float): Weight decay (L2 penalty). Default: 0.0.
|
||||
decay_filter (Function): A function to determine whether to apply weight decay on parameters. Default:
|
||||
lambda x: 'LayerNorm' not in x.name and 'bias' not in x.name.
|
||||
|
||||
Inputs:
|
||||
- **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.
|
||||
|
@ -228,10 +255,10 @@ class AdamWeightDecay(Optimizer):
|
|||
>>> optim = nn.AdamWeightDecay(params=net.trainable_params())
|
||||
>>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None)
|
||||
"""
|
||||
def __init__(self, params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-6, weight_decay=0.0):
|
||||
def __init__(self, params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-6, weight_decay=0.0,
|
||||
decay_filter=lambda x: 'beta' not in x.name and 'gamma' not in x.name):
|
||||
super(AdamWeightDecay, self).__init__(learning_rate, params)
|
||||
_check_param_value(beta1, beta2, eps, weight_decay, self.cls_name)
|
||||
self.lr = Tensor(np.array([learning_rate]).astype(np.float32))
|
||||
self.beta1 = Tensor(np.array([beta1]).astype(np.float32))
|
||||
self.beta2 = Tensor(np.array([beta2]).astype(np.float32))
|
||||
self.eps = Tensor(np.array([eps]).astype(np.float32))
|
||||
|
@ -240,13 +267,15 @@ class AdamWeightDecay(Optimizer):
|
|||
self.params = self.parameters
|
||||
self.moments1 = self.params.clone(prefix="adam_m", init='zeros')
|
||||
self.moments2 = self.params.clone(prefix="adam_v", init='zeros')
|
||||
self.decay_flag = tuple(decay_filter(x) for x in self.params)
|
||||
|
||||
self.hyper_map = C.HyperMap()
|
||||
|
||||
def construct(self, gradients):
|
||||
updated_velocity = self.hyper_map(F.partial(adam_opt, self.beta1, self.beta2, self.eps, self.lr,
|
||||
lr = self.get_lr()
|
||||
updated_velocity = self.hyper_map(F.partial(adam_opt, self.beta1, self.beta2, self.eps, lr,
|
||||
self.weight_decay_tensor),
|
||||
self.params, self.moments1, self.moments2, gradients)
|
||||
self.params, self.moments1, self.moments2, gradients, self.decay_flag)
|
||||
|
||||
return updated_velocity
|
||||
|
||||
|
@ -269,6 +298,8 @@ class AdamWeightDecayDynamicLR(Optimizer):
|
|||
eps (float): Term added to the denominator to improve numerical stability. Default: 1e-6.
|
||||
Should be greater than 0.
|
||||
weight_decay (float): Weight decay (L2 penalty). Default: 0.0.
|
||||
decay_filter (Function): A function to determine whether to apply weight decay on parameters. Default:
|
||||
lambda x: 'LayerNorm' not in x.name and 'bias' not in x.name.
|
||||
|
||||
Inputs:
|
||||
- **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.
|
||||
|
@ -291,10 +322,11 @@ class AdamWeightDecayDynamicLR(Optimizer):
|
|||
beta1=0.9,
|
||||
beta2=0.999,
|
||||
eps=1e-6,
|
||||
weight_decay=0.0):
|
||||
weight_decay=0.0,
|
||||
decay_filter=lambda x: 'beta' not in x.name and 'gamma' not in x.name):
|
||||
super(AdamWeightDecayDynamicLR, self).__init__(learning_rate, params)
|
||||
_check_param_value(beta1, beta2, eps, weight_decay, self.cls_name)
|
||||
|
||||
_check_learning_rate_value(learning_rate, end_learning_rate, decay_steps, power, self.cls_name)
|
||||
# turn them to scalar when me support scalar/tensor mix operations
|
||||
self.global_step = Parameter(initializer(0, [1]), name="global_step")
|
||||
self.decay_steps = Tensor(np.array([decay_steps]).astype(np.float32))
|
||||
|
@ -308,7 +340,7 @@ class AdamWeightDecayDynamicLR(Optimizer):
|
|||
self.params = self.parameters
|
||||
self.moments1 = self.params.clone(prefix="adam_m", init='zeros')
|
||||
self.moments2 = self.params.clone(prefix="adam_v", init='zeros')
|
||||
|
||||
self.decay_flag = tuple(decay_filter(x) for x in self.params)
|
||||
self.hyper_map = C.HyperMap()
|
||||
self.min = P.Minimum()
|
||||
self.pow = P.Pow()
|
||||
|
@ -320,7 +352,7 @@ class AdamWeightDecayDynamicLR(Optimizer):
|
|||
lr = self.diff_learning_rate * self.pow(self.one - p, self.power) + self.end_learning_rate
|
||||
updated_velocity = self.hyper_map(F.partial(adam_opt, self.beta1, self.beta2, self.eps, lr,
|
||||
self.weight_decay_tensor),
|
||||
self.params, self.moments1, self.moments2, gradients)
|
||||
self.params, self.moments1, self.moments2, gradients, self.decay_flag)
|
||||
|
||||
added_global_step = self.global_step + self.one
|
||||
F.control_depend(lr, added_global_step)
|
||||
|
|
|
@ -112,16 +112,18 @@ def _check_param_value(decay_steps, warmup_steps, start_learning_rate,
|
|||
end_learning_rate, power, beta1, beta2, eps, weight_decay, prim_name):
|
||||
|
||||
"""Check the type of inputs."""
|
||||
validator.check_value_type("decay_steps", decay_steps, [int], prim_name)
|
||||
validator.check_value_type("warmup_steps", warmup_steps, [int], prim_name)
|
||||
validator.check_value_type("start_learning_rate", start_learning_rate, [float], prim_name)
|
||||
validator.check_value_type("end_learning_rate", end_learning_rate, [float], prim_name)
|
||||
validator.check_value_type("power", power, [float], prim_name)
|
||||
validator.check_float_positive('start_learning_rate', start_learning_rate, prim_name)
|
||||
validator.check_float_legal_value('start_learning_rate', start_learning_rate, prim_name)
|
||||
validator.check_float_positive('end_learning_rate', end_learning_rate, prim_name)
|
||||
validator.check_float_legal_value('end_learning_rate', end_learning_rate, prim_name)
|
||||
validator.check_float_positive('power', power, prim_name)
|
||||
validator.check_float_legal_value('power', power, prim_name)
|
||||
validator.check_integer('decay_steps', decay_steps, 0, Rel.GT, prim_name)
|
||||
validator.check_integer('warmup_steps', decay_steps, 0, Rel.GT, prim_name)
|
||||
validator.check_value_type("beta1", beta1, [float], prim_name)
|
||||
validator.check_value_type("beta2", beta2, [float], prim_name)
|
||||
validator.check_value_type("eps", eps, [float], prim_name)
|
||||
validator.check_value_type("weight_dacay", weight_decay, [float], prim_name)
|
||||
validator.check_number_range("decay_steps", decay_steps, 1, float("inf"), Rel.INC_LEFT, prim_name)
|
||||
validator.check_number_range("beta1", beta1, 0.0, 1.0, Rel.INC_NEITHER, prim_name)
|
||||
validator.check_number_range("beta2", beta2, 0.0, 1.0, Rel.INC_NEITHER, prim_name)
|
||||
validator.check_number_range("eps", eps, 0.0, float("inf"), Rel.INC_NEITHER, prim_name)
|
||||
|
|
|
@ -42,7 +42,13 @@ class SGD(Optimizer):
|
|||
Args:
|
||||
params (list[Parameter]): A list of parameter, which will be updated. The element in `params`
|
||||
should be class mindspore.Parameter.
|
||||
learning_rate (float): A floating point value for the learning rate. Default: 0.1.
|
||||
learning_rate (Union[float, Tensor, Iterable]): A value for the learning rate. When the learning_rate is
|
||||
Iterable or a Tensor and the dims of the Tensor is 1,
|
||||
use dynamic learning rate, then the i-th step will
|
||||
take the i-th value as the learning rate.
|
||||
When the learning_rate is float or learning_rate is a Tensor
|
||||
but the dims of the Tensor is 0, use fixed learning rate.
|
||||
Other cases are not supported. Default: 0.1.
|
||||
momentum (float): A floating point value the momentum. Default: 0.
|
||||
dampening (float): A floating point value of dampening for momentum. Default: 0.
|
||||
weight_decay (float): Weight decay (L2 penalty). Default: 0.
|
||||
|
|
Loading…
Reference in New Issue