fix group parameter code for check

This commit is contained in:
guohongzilong 2020-05-21 22:15:33 +08:00
parent bddd743ca9
commit 2d2f9ba8fd
3 changed files with 11 additions and 5 deletions

View File

@ -80,6 +80,8 @@ class LARS(Optimizer):
decay_filter=lambda x: 'LayerNorm' not in x.name and 'bias' not in x.name,
lars_filter=lambda x: 'LayerNorm' not in x.name and 'bias' not in x.name, loss_scale=1.0):
super(LARS, self).__init__(0.0, [Parameter(Tensor(0.0), name="trivial")])
if optimizer.is_group:
raise RuntimeError(f"The {self.cls_name} optimizer cannot support group setting.")
self.opt = optimizer
self.parameters = optimizer.parameters
self.learning_rate = optimizer.learning_rate

View File

@ -81,7 +81,7 @@ class Optimizer(Cell):
raise ValueError("Optimizer got an empty parameter list.")
if not isinstance(parameters[0], (dict, Parameter)):
raise ValueError("Only a list of Parameter or dict can be supported.")
raise TypeError("Only a list of Parameter or dict can be supported.")
if isinstance(loss_scale, int):
loss_scale = float(loss_scale)
@ -258,9 +258,9 @@ class Optimizer(Cell):
for param in group_param['params']:
validator.check_value_type("parameter", param, [Parameter], self.cls_name)
if param in params_store:
if param.name in params_store:
raise RuntimeError(f"The {param.name} parameter has appeared in parameter groups.")
params_store.append(param)
params_store.append(param.name)
self.group_lr.append(Parameter(lr, name="lr_" + param.name))
self.group_weight_decay.append(weight_decay_)
@ -298,18 +298,22 @@ class Optimizer(Cell):
Parameter, single `Parameter` or `list[Parameter]` according to the input type.
"""
if not isinstance(param, (Parameter, list)):
raise TypeError(f"The 'param' only support 'Parameter' or 'list' type.")
raise TypeError(f"The parameter only support 'Parameter' or 'list' type.")
if isinstance(param, list):
lr = []
for p in param:
validator.check_value_type("parameter", p, [Parameter], self.cls_name)
if p not in self.parameters:
raise ValueError(f"The parameter {p.name} is not in optimizer.")
if self.is_group_lr:
index = self.parameters.index(p)
lr.append(self.learning_rate[index])
else:
lr.append(self.learning_rate)
else:
if param not in self.parameters:
raise ValueError(f"The parameter {param.name} is not in optimizer.")
if self.is_group_lr:
index = self.parameters.index(param)
lr = self.learning_rate[index]

View File

@ -94,7 +94,7 @@ class TestUnsupportParam():
""" TestUnsupportParam definition """
def test_optim_init(self):
with pytest.raises(ValueError):
with pytest.raises(TypeError):
Optimizer(0.1, (1, 2, 3))
def test_AdamWightDecay_init(self):