forked from mindspore-Ecosystem/mindspore
add optimizer.get_lr_parameter() method
This commit is contained in:
parent fd72534a1c
commit e70b2f5430
@@ -257,6 +257,7 @@ class Optimizer(Cell):
                     logger.warning(f"The optimizer cannot parse '{key}' when setting parameter groups.")
 
             for param in group_param['params']:
+                validator.check_value_type("parameter", param, [Parameter], self.cls_name)
                 if param in params_store:
                     raise RuntimeError(f"The {param.name} parameter has appeared in parameter groups.")
                 params_store.append(param)
@@ -286,6 +287,36 @@ class Optimizer(Cell):
             F.control_depend(lr, self.assignadd(self.global_step, 1))
         return lr
 
+    def get_lr_parameter(self, param):
+        """
+        Get the learning rate of parameter.
+
+        Args:
+            param (Union[Parameter, list[Parameter]]): The `Parameter` or list of `Parameter`.
+
+        Returns:
+            Parameter, single `Parameter` or `list[Parameter]` according to the input type.
+        """
+        if not isinstance(param, (Parameter, list)):
+            raise TypeError(f"The 'param' only support 'Parameter' or 'list' type.")
+
+        if isinstance(param, list):
+            lr = []
+            for p in param:
+                validator.check_value_type("parameter", p, [Parameter], self.cls_name)
+                if self.is_group_lr:
+                    index = self.parameters.index(p)
+                    lr.append(self.learning_rate[index])
+                else:
+                    lr.append(self.learning_rate)
+        else:
+            if self.is_group_lr:
+                index = self.parameters.index(param)
+                lr = self.learning_rate[index]
+            else:
+                lr = self.learning_rate
+        return lr
+
     def construct(self, *hyper_params):
         raise NotImplementedError
 
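A minimal usage sketch of the new method (illustrative, not part of the commit), assuming MindSpore with this patch applied; the parameter names and shapes below are made up rather than taken from the patch, and the grouped-learning-rate setup follows the tests that follow.

    import numpy as np
    from mindspore import Parameter, Tensor
    from mindspore.nn import SGD

    # Two plain Parameters stand in for real network weights (hypothetical names).
    conv_weight = Parameter(Tensor(np.ones((2, 2), np.float32)), name='conv.weight')
    fc_weight = Parameter(Tensor(np.ones((2, 2), np.float32)), name='fc.weight')

    # Each group carries its own learning rate.
    group_params = [{'params': [conv_weight], 'lr': 0.1},
                    {'params': [fc_weight], 'lr': 0.3}]
    opt = SGD(group_params)

    # A single Parameter returns the learning rate bound to its group ...
    lr_conv = opt.get_lr_parameter(conv_weight)
    # ... and a list of Parameters returns a list of learning rates.
    lr_list = opt.get_lr_parameter([conv_weight, fc_weight])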
@@ -210,3 +210,41 @@ def test_group_repeat_param():
                     {'params': no_conv_params}]
     with pytest.raises(RuntimeError):
         Adam(group_params, learning_rate=default_lr)
+
+
+def test_get_lr_parameter_with_group():
+    net = LeNet5()
+    conv_lr = 0.1
+    default_lr = 0.3
+    conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params()))
+    no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params()))
+    group_params = [{'params': conv_params, 'lr': conv_lr},
+                    {'params': no_conv_params, 'lr': default_lr}]
+    opt = SGD(group_params)
+    assert opt.is_group_lr is True
+    for param in opt.parameters:
+        lr = opt.get_lr_parameter(param)
+        assert lr.name == 'lr_' + param.name
+
+    lr_list = opt.get_lr_parameter(conv_params)
+    for lr, param in zip(lr_list, conv_params):
+        assert lr.name == 'lr_' + param.name
+
+
+def test_get_lr_parameter_with_no_group():
+    net = LeNet5()
+    conv_weight_decay = 0.8
+
+    conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params()))
+    no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params()))
+    group_params = [{'params': conv_params, 'weight_decay': conv_weight_decay},
+                    {'params': no_conv_params}]
+    opt = SGD(group_params)
+    assert opt.is_group_lr is False
+    for param in opt.parameters:
+        lr = opt.get_lr_parameter(param)
+        assert lr.name == opt.learning_rate.name
+
+    params_error = [1, 2, 3]
+    with pytest.raises(TypeError):
+        opt.get_lr_parameter(params_error)