forked from OSSInnovation/mindspore
fix params KeyError in group params
This commit is contained in:
parent
1cadea12f0
commit
9409f83245
|
@ -219,8 +219,32 @@ class Optimizer(Cell):
|
|||
raise TypeError("Learning rate should be float, Tensor or Iterable.")
|
||||
return lr
|
||||
|
||||
def _check_group_params(self, parameters):
|
||||
"""Check group params."""
|
||||
parse_keys = ['params', 'lr', 'weight_decay', 'order_params']
|
||||
for group_param in parameters:
|
||||
invalid_key = list(filter(lambda x: x not in parse_keys, group_param.keys()))
|
||||
if invalid_key:
|
||||
raise KeyError(f'The key "{invalid_key}" cannot be recognized in group params.')
|
||||
|
||||
if 'order_params' in group_param.keys():
|
||||
if len(group_param.keys()) > 1:
|
||||
raise ValueError("The order params dict in group parameters should "
|
||||
"only include the 'order_params' key.")
|
||||
if not isinstance(group_param['order_params'], Iterable):
|
||||
raise TypeError("The value of 'order_params' should be an Iterable type.")
|
||||
continue
|
||||
|
||||
if not group_param['params']:
|
||||
raise ValueError("Optimizer got an empty group parameter list.")
|
||||
|
||||
for param in group_param['params']:
|
||||
if not isinstance(param, Parameter):
|
||||
raise TypeError("The group param should be an iterator of Parameter type.")
|
||||
|
||||
def _parse_group_params(self, parameters, learning_rate):
|
||||
"""Parse group params."""
|
||||
self._check_group_params(parameters)
|
||||
if self.dynamic_lr:
|
||||
dynamic_lr_length = learning_rate.size()
|
||||
else:
|
||||
|
@ -250,9 +274,6 @@ class Optimizer(Cell):
|
|||
if dynamic_lr_length not in (lr_length, 0):
|
||||
raise ValueError("The dynamic learning rate in group should be the same size.")
|
||||
|
||||
if not group_param['params']:
|
||||
raise ValueError("Optimizer got an empty group parameter list.")
|
||||
|
||||
dynamic_lr_length = lr_length
|
||||
self.dynamic_lr_length = dynamic_lr_length
|
||||
|
||||
|
|
|
@ -309,12 +309,6 @@ class Print(PrimitiveWithInfer):
|
|||
Output tensor or string to stdout.
|
||||
|
||||
Note:
|
||||
The print operation cannot support the following cases currently.
|
||||
|
||||
1. The type of tensor is float64 or bool.
|
||||
|
||||
2. The data of tensor is a scalar type.
|
||||
|
||||
In pynative mode, please use python print function.
|
||||
|
||||
Inputs:
|
||||
|
|
Loading…
Reference in New Issue