forked from mindspore-Ecosystem/mindspore
!2819 Fix group param order
Merge pull request !2819 from ghzl/add_group_param_order
This commit is contained in:
commit
60de6aae02
|
@ -360,16 +360,18 @@ class Optimizer(Cell):
|
||||||
if len(ordered_parameters) != len(self.group_params):
|
if len(ordered_parameters) != len(self.group_params):
|
||||||
raise ValueError(f"The value of 'order_params' should be same with all group parameters.")
|
raise ValueError(f"The value of 'order_params' should be same with all group parameters.")
|
||||||
|
|
||||||
|
ordered_params = [None] * params_length
|
||||||
ordered_learning_rate = [None] * params_length
|
ordered_learning_rate = [None] * params_length
|
||||||
ordered_weight_decay = [None] * params_length
|
ordered_weight_decay = [None] * params_length
|
||||||
params_name = [param.name for param in ordered_parameters]
|
params_name = [param.name for param in ordered_parameters]
|
||||||
|
|
||||||
for param, lr, wd in zip(self.group_params, self.group_lr, self.group_weight_decay):
|
for param, lr, wd in zip(self.group_params, self.group_lr, self.group_weight_decay):
|
||||||
index = params_name.index(param.name)
|
index = params_name.index(param.name)
|
||||||
|
ordered_params[index] = param
|
||||||
ordered_learning_rate[index] = lr
|
ordered_learning_rate[index] = lr
|
||||||
ordered_weight_decay[index] = wd
|
ordered_weight_decay[index] = wd
|
||||||
|
|
||||||
self.group_params = list(ordered_parameters)
|
self.group_params = ordered_params
|
||||||
self.group_lr = ordered_learning_rate
|
self.group_lr = ordered_learning_rate
|
||||||
self.group_weight_decay = ordered_weight_decay
|
self.group_weight_decay = ordered_weight_decay
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue