forked from mindspore-Ecosystem/mindspore

!1494 Add check for empty group parameters

Merge pull request !1494 from ghzl/add-check-parameter-for-group-learning-rate

commit 7b1031bf86
@@ -122,7 +122,7 @@ class Optimizer(Cell):
             learning_rate = self._get_single_lr(learning_rate)
         if isinstance(parameters[0], dict):
             self.is_group = True
-            self.params = []
+            self.group_params = []
             self.group_lr = []
             self.group_weight_decay = []
             self._init_group_params(parameters, learning_rate, weight_decay)
@@ -133,7 +133,7 @@ class Optimizer(Cell):
             self.learning_rate = Parameter(learning_rate, name="learning_rate")

         if self.is_group:
-            self.parameters = ParameterTuple(self.params)
+            self.parameters = ParameterTuple(self.group_params)
             self.weight_decay = tuple(self.group_weight_decay)
             decay_filter = lambda x: x > 0
             self.decay_flags = tuple(decay_filter(x) for x in self.weight_decay)
@@ -240,7 +240,10 @@ class Optimizer(Cell):
         params_store = []
         for group_param in parameters:
-            self.params += group_param['params']
+            if not group_param['params']:
+                raise ValueError("Optimizer got an empty parameter list.")
+
+            self.group_params += group_param['params']
             if 'lr' in group_param.keys():
                 params_dynamic_lr = isinstance(group_param['lr'], (Iterable, Tensor))