forked from mindspore-Ecosystem/mindspore
!16502 check that Optimizer parameters must be a list
From: @wangnan39 Reviewed-by: @zh_qh,@kingxian Signed-off-by: @zh_qh,@kingxian
commit c47356d77b
@@ -121,14 +121,9 @@ class Optimizer(Cell):
     def __init__(self, learning_rate, parameters, weight_decay=0.0, loss_scale=1.0):
         super(Optimizer, self).__init__(auto_prefix=False)
-        if parameters is not None and not isinstance(parameters, list):
-            parameters = list(parameters)
-
-        if not parameters:
-            raise ValueError("Optimizer got an empty parameter list.")
-
-        if not isinstance(parameters[0], (dict, Parameter)):
-            raise TypeError("Only a list of Parameter or dict can be supported.")
+        parameters = self._parameters_base_check(parameters, "parameters")
+        if not all(isinstance(x, Parameter) for x in parameters) and not all(isinstance(x, dict) for x in parameters):
+            raise TypeError("All elements of the optimizer parameters must be of type `Parameter` or `dict`.")
 
         if isinstance(loss_scale, int):
             loss_scale = float(loss_scale)
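Note on the hunk above: the old constructor only inspected parameters[0], so a list that mixed Parameter objects with arbitrary values slipped past the type check. The new all()-based condition validates every element and also rejects mixing the Parameter form with the dict (group) form. A standalone sketch of that behavior, using a stand-in Parameter class rather than mindspore.Parameter:

# Sketch only: stand-in Parameter class, not mindspore.Parameter.
class Parameter:
    pass

def check_homogeneous(parameters):
    # Every element must be a Parameter, or every element must be a dict
    # (the group-parameter form); mixed lists are rejected.
    if not all(isinstance(x, Parameter) for x in parameters) and \
            not all(isinstance(x, dict) for x in parameters):
        raise TypeError("All elements of the optimizer parameters must be "
                        "of type `Parameter` or `dict`.")

check_homogeneous([Parameter(), Parameter()])            # accepted
check_homogeneous([{'params': []}, {'params': []}])      # accepted
try:
    check_homogeneous([Parameter(), "not-a-parameter"])  # old check missed this
except TypeError as err:
    print(err)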
@@ -373,6 +368,17 @@ class Optimizer(Cell):
             return _IteratorLearningRate(learning_rate, name)
         return learning_rate
 
+    def _parameters_base_check(self, parameters, param_info):
+        if parameters is None:
+            raise ValueError(f"Optimizer {param_info} can not be None.")
+        if not isinstance(parameters, Iterable):
+            raise TypeError(f"Optimizer {param_info} must be Iterable.")
+        parameters = list(parameters)
+
+        if not parameters:
+            raise ValueError(f"Optimizer got an empty {param_info} list.")
+        return parameters
+
     def _check_group_params(self, parameters):
         """Check group params."""
         parse_keys = ['params', 'lr', 'weight_decay', 'order_params', 'grad_centralization']
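The helper above concentrates the None / non-iterable / empty checks in one place. The following standalone sketch inlines the same logic as a free function (so it runs without MindSpore) and shows what it accepts and rejects:

from collections.abc import Iterable

def parameters_base_check(parameters, param_info):
    # Mirrors Optimizer._parameters_base_check from the hunk above.
    if parameters is None:
        raise ValueError(f"Optimizer {param_info} can not be None.")
    if not isinstance(parameters, Iterable):
        raise TypeError(f"Optimizer {param_info} must be Iterable.")
    parameters = list(parameters)
    if not parameters:
        raise ValueError(f"Optimizer got an empty {param_info} list.")
    return parameters

print(parameters_base_check((x for x in [1, 2]), "parameters"))  # generators work: [1, 2]
for bad in (None, 42, []):
    try:
        parameters_base_check(bad, "parameters")
    except (TypeError, ValueError) as err:
        print(type(err).__name__, "-", err)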
@@ -389,12 +395,9 @@ class Optimizer(Cell):
                     raise TypeError("The value of 'order_params' should be an Iterable type.")
                 continue
 
-            if not group_param['params']:
-                raise ValueError("Optimizer got an empty group parameter list.")
-
-            for param in group_param['params']:
-                if not isinstance(param, Parameter):
-                    raise TypeError("The group param should be an iterator of Parameter type.")
+            parameters = self._parameters_base_check(group_param['params'], "group `params`")
+            if not all(isinstance(x, Parameter) for x in parameters):
+                raise TypeError("The group `params` should be an iterator of Parameter type.")
 
     def _parse_group_params(self, parameters, learning_rate):
         """Parse group params."""
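Because the helper interpolates its param_info argument into each message, the same code now yields context-specific errors for the top-level list and for each group. A short sketch of the resulting messages:

for param_info in ("parameters", "group `params`"):
    print(f"Optimizer got an empty {param_info} list.")
# Optimizer got an empty parameters list.
# Optimizer got an empty group `params` list.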
@@ -70,7 +70,7 @@ class NetWithSparseGatherV2(nn.Cell):
 def test_adamwithoutparam():
     net = NetWithoutWeight()
     net.set_train()
-    with pytest.raises(ValueError, match=r"Optimizer got an empty parameter list"):
+    with pytest.raises(ValueError, match=r"Optimizer got an empty parameters list"):
         AdamWeightDecay(net.trainable_params(), learning_rate=0.1)
@@ -205,7 +205,6 @@ def test_adamoffload_group():
 
 def test_AdamWeightDecay_beta1():
     net = Net()
-    print("**********", net.get_parameters())
     with pytest.raises(ValueError):
         AdamWeightDecay(net.get_parameters(), beta1=1.0, learning_rate=0.1)
@@ -224,5 +223,5 @@ def test_AdamWeightDecay_e():
 
 def test_adam_mindspore_with_empty_params():
     net = nn.Flatten()
-    with pytest.raises(ValueError, match=r"Optimizer got an empty parameter list"):
+    with pytest.raises(ValueError, match=r"Optimizer got an empty parameters list"):
         AdamWeightDecay(net.get_parameters())
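One nuance the updated tests don't exercise: with the old guard, passing None fell through the list-conversion branch and surfaced as the misleading "empty parameter list" error, while the new helper names the real problem. A standalone contrast of the two paths (sketch, not the MindSpore source):

from collections.abc import Iterable

def old_check(parameters):
    if parameters is not None and not isinstance(parameters, list):
        parameters = list(parameters)
    if not parameters:  # None is falsy, so None ended up here
        raise ValueError("Optimizer got an empty parameter list.")

def new_check(parameters):
    if parameters is None:
        raise ValueError("Optimizer parameters can not be None.")
    if not isinstance(parameters, Iterable):
        raise TypeError("Optimizer parameters must be Iterable.")
    if not list(parameters):
        raise ValueError("Optimizer got an empty parameters list.")

for check in (old_check, new_check):
    try:
        check(None)
    except ValueError as err:
        print(err)
# old: Optimizer got an empty parameter list.   (misleading)
# new: Optimizer parameters can not be None.    (explicit)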