!16502 check Optimizer parameters must be list

From: @wangnan39
Reviewed-by: @zh_qh,@kingxian
Signed-off-by: @zh_qh,@kingxian
This commit is contained in:
mindspore-ci-bot 2021-05-19 09:23:56 +08:00 committed by Gitee
commit c47356d77b
2 changed files with 19 additions and 17 deletions

View File

@ -121,14 +121,9 @@ class Optimizer(Cell):
def __init__(self, learning_rate, parameters, weight_decay=0.0, loss_scale=1.0):
super(Optimizer, self).__init__(auto_prefix=False)
if parameters is not None and not isinstance(parameters, list):
parameters = list(parameters)
if not parameters:
raise ValueError("Optimizer got an empty parameter list.")
if not isinstance(parameters[0], (dict, Parameter)):
raise TypeError("Only a list of Parameter or dict can be supported.")
parameters = self._parameters_base_check(parameters, "parameters")
if not all(isinstance(x, Parameter) for x in parameters) and not all(isinstance(x, dict) for x in parameters):
raise TypeError("All elements of the optimizer parameters must be of type `Parameter` or `dict`.")
if isinstance(loss_scale, int):
loss_scale = float(loss_scale)
@ -373,6 +368,17 @@ class Optimizer(Cell):
return _IteratorLearningRate(learning_rate, name)
return learning_rate
def _parameters_base_check(self, parameters, param_info):
if parameters is None:
raise ValueError(f"Optimizer {param_info} can not be None.")
if not isinstance(parameters, Iterable):
raise TypeError(f"Optimizer {param_info} must be Iterable.")
parameters = list(parameters)
if not parameters:
raise ValueError(f"Optimizer got an empty {param_info} list.")
return parameters
def _check_group_params(self, parameters):
"""Check group params."""
parse_keys = ['params', 'lr', 'weight_decay', 'order_params', 'grad_centralization']
@ -389,12 +395,9 @@ class Optimizer(Cell):
raise TypeError("The value of 'order_params' should be an Iterable type.")
continue
if not group_param['params']:
raise ValueError("Optimizer got an empty group parameter list.")
for param in group_param['params']:
if not isinstance(param, Parameter):
raise TypeError("The group param should be an iterator of Parameter type.")
parameters = self._parameters_base_check(group_param['params'], "group `params`")
if not all(isinstance(x, Parameter) for x in parameters):
raise TypeError("The group `params` should be an iterator of Parameter type.")
def _parse_group_params(self, parameters, learning_rate):
"""Parse group params."""

View File

@ -70,7 +70,7 @@ class NetWithSparseGatherV2(nn.Cell):
def test_adamwithoutparam():
net = NetWithoutWeight()
net.set_train()
with pytest.raises(ValueError, match=r"Optimizer got an empty parameter list"):
with pytest.raises(ValueError, match=r"Optimizer got an empty parameters list"):
AdamWeightDecay(net.trainable_params(), learning_rate=0.1)
@ -205,7 +205,6 @@ def test_adamoffload_group():
def test_AdamWeightDecay_beta1():
net = Net()
print("**********", net.get_parameters())
with pytest.raises(ValueError):
AdamWeightDecay(net.get_parameters(), beta1=1.0, learning_rate=0.1)
@ -224,5 +223,5 @@ def test_AdamWeightDecay_e():
def test_adam_mindspore_with_empty_params():
net = nn.Flatten()
with pytest.raises(ValueError, match=r"Optimizer got an empty parameter list"):
with pytest.raises(ValueError, match=r"Optimizer got an empty parameters list"):
AdamWeightDecay(net.get_parameters())