forked from mindspore-Ecosystem/mindspore
!1065 Seperate lr groups and weight_decay groups
Merge pull request !1065 from ghzl/improve-parameter-groups
This commit is contained in:
commit
fd72534a1c
|
@ -243,7 +243,7 @@ class Adam(Optimizer):
|
|||
self.beta1_power = beta1_power
|
||||
beta2_power = self.beta2_power * self.beta2
|
||||
self.beta2_power = beta2_power
|
||||
if self.is_group:
|
||||
if self.is_group_lr:
|
||||
success = self.hyper_map(F.partial(adam_opt, self.opt, beta1_power, beta2_power, self.beta1,
|
||||
self.beta2, self.eps),
|
||||
lr, gradients, params, moment1, moment2)
|
||||
|
|
|
@ -111,7 +111,7 @@ class Momentum(Optimizer):
|
|||
gradients = self.decay_weight(gradients)
|
||||
gradients = self.scale_grad(gradients)
|
||||
lr = self.get_lr()
|
||||
if self.is_group:
|
||||
if self.is_group_lr:
|
||||
success = self.hyper_map(F.partial(momentum_opt, self.opt, self.momentum), lr, gradients, params, moments)
|
||||
else:
|
||||
success = self.hyper_map(F.partial(momentum_opt, self.opt, self.momentum, lr), gradients, params, moments)
|
||||
|
|
|
@ -94,6 +94,7 @@ class Optimizer(Cell):
|
|||
validator.check_number_range("weight_decay", weight_decay, 0.0, float("inf"), Rel.INC_LEFT, None)
|
||||
|
||||
self.is_group = False
|
||||
self.is_group_lr = False
|
||||
self.loss_scale = loss_scale
|
||||
if isinstance(learning_rate, float):
|
||||
self.dynamic_lr = False
|
||||
|
@ -116,14 +117,17 @@ class Optimizer(Cell):
|
|||
self.group_weight_decay = []
|
||||
self._init_group_params(parameters, learning_rate, weight_decay)
|
||||
|
||||
if self.is_group:
|
||||
if self.is_group_lr:
|
||||
self.learning_rate = ParameterTuple(self.group_lr)
|
||||
else:
|
||||
self.learning_rate = Parameter(learning_rate, name="learning_rate")
|
||||
|
||||
if self.is_group:
|
||||
self.parameters = ParameterTuple(self.params)
|
||||
self.weight_decay = tuple(self.group_weight_decay)
|
||||
decay_filter = lambda x: x > 0
|
||||
self.decay_flags = tuple(decay_filter(x) for x in self.weight_decay)
|
||||
else:
|
||||
self.learning_rate = Parameter(learning_rate, name="learning_rate")
|
||||
self.parameters = ParameterTuple(parameters)
|
||||
self.weight_decay = weight_decay * loss_scale
|
||||
decay_filter = lambda x: 'beta' not in x.name and 'gamma' not in x.name
|
||||
|
@ -207,6 +211,7 @@ class Optimizer(Cell):
|
|||
for group_param in parameters:
|
||||
lr_length = dynamic_lr_length
|
||||
if 'lr' in group_param.keys():
|
||||
self.is_group_lr = True
|
||||
self._get_single_lr(group_param['lr'])
|
||||
if isinstance(group_param['lr'], Iterable):
|
||||
lr_length = len(group_param['lr'])
|
||||
|
@ -247,6 +252,10 @@ class Optimizer(Cell):
|
|||
else:
|
||||
weight_decay_ = weight_decay * self.loss_scale
|
||||
|
||||
for key in group_param.keys():
|
||||
if key not in ('params', 'lr', 'weight_decay'):
|
||||
logger.warning(f"The optimizer cannot parse '{key}' when setting parameter groups.")
|
||||
|
||||
for param in group_param['params']:
|
||||
if param in params_store:
|
||||
raise RuntimeError(f"The {param.name} parameter has appeared in parameter groups.")
|
||||
|
@ -261,7 +270,7 @@ class Optimizer(Cell):
|
|||
Returns:
|
||||
float, the learning rate of current step.
|
||||
"""
|
||||
if self.is_group:
|
||||
if self.is_group_lr:
|
||||
lr = self.learning_rate
|
||||
if self.dynamic_lr:
|
||||
lr = ()
|
||||
|
|
|
@ -176,7 +176,7 @@ class RMSProp(Optimizer):
|
|||
gradients = self.scale_grad(gradients)
|
||||
lr = self.get_lr()
|
||||
if self.centered:
|
||||
if self.is_group:
|
||||
if self.is_group_lr:
|
||||
success = self.hyper_map(F.partial(centered_rmsprop_opt, self.opt, self.decay, self.epsilon,
|
||||
self.momentum), lr, params, self.mg, self.ms, self.moment, gradients)
|
||||
else:
|
||||
|
@ -184,7 +184,7 @@ class RMSProp(Optimizer):
|
|||
self.momentum, lr), params, self.mg, self.ms, self.moment, gradients)
|
||||
|
||||
else:
|
||||
if self.is_group:
|
||||
if self.is_group_lr:
|
||||
success = self.hyper_map(F.partial(rmsprop_opt, self.opt, self.decay, self.epsilon,
|
||||
self.momentum), lr, params, self.ms, self.moment, gradients)
|
||||
else:
|
||||
|
|
|
@ -139,7 +139,7 @@ class SGD(Optimizer):
|
|||
gradients = self.decay_weight(gradients)
|
||||
gradients = self.scale_grad(gradients)
|
||||
lr = self.get_lr()
|
||||
if self.is_group:
|
||||
if self.is_group_lr:
|
||||
success = self.hyper_map(F.partial(sgd_opt, self.opt, self.momentum), lr, gradients, params, accum, stat)
|
||||
else:
|
||||
success = self.hyper_map(F.partial(sgd_opt, self.opt, self.momentum, lr), gradients, params, accum, stat)
|
||||
|
|
|
@ -65,12 +65,13 @@ def test_group_lr():
|
|||
|
||||
opt = Momentum(group_params, learning_rate=default_lr, momentum=0.9)
|
||||
assert opt.is_group is True
|
||||
assert opt.is_group_lr is True
|
||||
assert opt.dynamic_lr is False
|
||||
for lr, param in zip(opt.learning_rate, opt.parameters):
|
||||
if param in conv_params:
|
||||
assert lr.data == Tensor(conv_lr, mstype.float32)
|
||||
assert np.all(lr.data.asnumpy() == Tensor(conv_lr, mstype.float32).asnumpy())
|
||||
else:
|
||||
assert lr.data == Tensor(default_lr, mstype.float32)
|
||||
assert np.all(lr.data.asnumpy() == Tensor(default_lr, mstype.float32).asnumpy())
|
||||
|
||||
net_with_loss = WithLossCell(net, loss)
|
||||
train_network = TrainOneStepCell(net_with_loss, opt)
|
||||
|
@ -96,9 +97,9 @@ def test_group_dynamic_1():
|
|||
assert opt.dynamic_lr is True
|
||||
for lr, param in zip(opt.learning_rate, opt.parameters):
|
||||
if param in conv_params:
|
||||
assert lr.data == Tensor(np.array([conv_lr] * 3).astype(np.float32))
|
||||
assert np.all(lr.data.asnumpy() == Tensor(np.array([conv_lr] * 3).astype(np.float32)).asnumpy())
|
||||
else:
|
||||
assert lr.data == Tensor(np.array(list(default_lr)).astype(np.float32))
|
||||
assert np.all(lr.data.asnumpy() == Tensor(np.array(list(default_lr)).astype(np.float32)).asnumpy())
|
||||
|
||||
net_with_loss = WithLossCell(net, loss)
|
||||
train_network = TrainOneStepCell(net_with_loss, opt)
|
||||
|
@ -124,9 +125,9 @@ def test_group_dynamic_2():
|
|||
assert opt.dynamic_lr is True
|
||||
for lr, param in zip(opt.learning_rate, opt.parameters):
|
||||
if param in conv_params:
|
||||
assert lr.data == Tensor(np.array(list(conv_lr)).astype(np.float32))
|
||||
assert np.all(lr.data == Tensor(np.array(list(conv_lr)).astype(np.float32)))
|
||||
else:
|
||||
assert lr.data == Tensor(np.array([default_lr] * 3).astype(np.float32))
|
||||
assert np.all(lr.data == Tensor(np.array([default_lr] * 3).astype(np.float32)))
|
||||
|
||||
net_with_loss = WithLossCell(net, loss)
|
||||
train_network = TrainOneStepCell(net_with_loss, opt)
|
||||
|
@ -184,6 +185,7 @@ def test_weight_decay():
|
|||
|
||||
opt = SGD(group_params, learning_rate=0.1, weight_decay=default_weight_decay)
|
||||
assert opt.is_group is True
|
||||
assert opt.is_group_lr is False
|
||||
for weight_decay, decay_flags, param in zip(opt.weight_decay, opt.decay_flags, opt.parameters):
|
||||
if param in conv_params:
|
||||
assert weight_decay == conv_weight_decay
|
||||
|
|
Loading…
Reference in New Issue