Fix duplicate parameter names in boost optimizers.

linqingke 2022-03-17 16:45:31 +08:00
parent 3df01a5176
commit 9c1bafda44
5 changed files with 88 additions and 16 deletions

View File

@@ -69,13 +69,17 @@ class OptimizerProcess:
     def __init__(self, opt):
         if isinstance(opt, LARS):
             self.is_lars = True
+            self.single_opt = opt.opt
             self.opt_class = type(opt.opt)
             self.opt_init_args = opt.opt.init_args
             self.lars_init_args = opt.init_args
+            self.learning_rate = opt.opt.init_learning_rate
         else:
             self.is_lars = False
+            self.single_opt = opt
             self.opt_class = type(opt)
             self.opt_init_args = opt.init_args
+            self.learning_rate = opt.init_learning_rate
         self.origin_params = opt.init_params["params"]

     def build_params_dict(self, network):
@@ -155,10 +159,13 @@ class OptimizerProcess:
     def generate_new_optimizer(self):
         """Generate new optimizer."""
+        if self.learning_rate is None:
+            self.learning_rate = self.single_opt.learning_rate
         if not self.is_lars:
-            opt = self.opt_class(params=self.origin_params, **self.opt_init_args)
+            opt = self.opt_class(params=self.origin_params, learning_rate=self.learning_rate, **self.opt_init_args)
         else:
-            opt = LARS(self.opt_class(params=self.origin_params, **self.opt_init_args), **self.lars_init_args)
+            opt = LARS(self.opt_class(params=self.origin_params, learning_rate=self.learning_rate, \
+                       **self.opt_init_args), **self.lars_init_args)
         return opt
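
The hunk above makes OptimizerProcess keep a reference to the wrapped optimizer and to its captured learning rate, and pass that rate explicitly when the optimizer is rebuilt, so the regenerated copy no longer falls back to the class default. Below is a minimal, framework-free sketch of that idea; ToyMomentum and rebuild_optimizer are illustrative names, not MindSpore APIs.

# Minimal sketch, using no MindSpore code: rebuild an optimizer from its
# recorded class, captured init args, and captured learning rate, the way
# generate_new_optimizer() now forwards learning_rate explicitly.
class ToyMomentum:
    def __init__(self, params, learning_rate=0.1, momentum=0.9):
        self.params = params
        self.learning_rate = learning_rate
        self.momentum = momentum


def rebuild_optimizer(opt_class, params, init_args, learning_rate):
    # Forwarding learning_rate keeps the user's value; relying on init_args
    # alone would silently reuse the class default here.
    return opt_class(params=params, learning_rate=learning_rate, **init_args)


original = ToyMomentum(params=["w0", "w1"], learning_rate=0.01, momentum=0.8)
clone = rebuild_optimizer(ToyMomentum, original.params,
                          {"momentum": original.momentum}, original.learning_rate)
assert clone.learning_rate == original.learning_rate == 0.01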

View File

@@ -59,33 +59,45 @@ class FreezeOpt(Cell):
             self.opt_class = type(opt.opt)
             self.opt_init_args = opt.opt.init_args
             self.lars_init_args = opt.init_args
+            self.single_opt = opt.opt
+            self.parameters = opt.opt.parameters
+            self.learning_rate = opt.opt.init_learning_rate
+            self.dynamic_lr = opt.opt.dynamic_lr
         else:
             self.is_lars = False
             self.opt_class = type(opt)
             self.opt_init_args = opt.init_args
+            self.single_opt = opt
+            self.parameters = opt.parameters
-        self.opts = []
+            self.learning_rate = opt.init_learning_rate
+            self.dynamic_lr = opt.dynamic_lr
+        self.opts = []
         if train_parameter_groups is None:
-            groups_num = 10
+            self.groups_num = 10
             step = 6
             parameters = opt.parameters
-            para_groups = (parameters[(i * step):] for i in range(groups_num))
-            self.opts = [self._generate_new_optimizer(
-                params) for params in para_groups]
+            train_parameter_groups = (tuple(parameters[(i * step):]) for i in range(self.groups_num))
         else:
             if not isinstance(train_parameter_groups, (tuple, list)):
                 raise TypeError(
                     "The specified 'train_parameter_groups' should be tuple or list")
-            for params in train_parameter_groups:
-                if not isinstance(params, (tuple, list)):
-                    raise TypeError("The each element of 'train_parameter_groups' should be tuple or list "
-                                    "to store the Parameter")
+            self.groups_num = len(train_parameter_groups)
-                # generate one-to-one opt corresponding to the parameter group
-                self.opts.append(self._generate_new_optimizer(params))
         self._init_train_strategy(train_strategy)
+        self._create_new_group_learning_rate()
+        self.opt_index = 0
+        for params in train_parameter_groups:
+            if not isinstance(params, (tuple, list)):
+                raise TypeError("The each element of 'train_parameter_groups' should be tuple or list "
+                                "to store the Parameter")
+            # generate one-to-one opt corresponding to the parameter group
+            self.opts.append(self._generate_new_optimizer(params))
+            self.opt_index += 1

     def _init_train_strategy(self, train_strategy):
         """Init train strategy for gradient freeze."""
         if isinstance(train_strategy, (tuple, list)):
             for ele in train_strategy:
                 if not isinstance(ele, int):
@@ -103,13 +115,32 @@ class FreezeOpt(Cell):
             raise TypeError(
                 "The specified 'train_strategy' should be None, tuple, list or Tensor")

+    def _create_new_group_learning_rate(self):
+        """Create new learning rate for different global step."""
+        self.dynamic_learning_rate = [[] for _ in range(self.groups_num)]
+        if self.learning_rate is None:
+            self.learning_rate = self.single_opt.learning_rate
+            return
+        if self.dynamic_lr and isinstance(self.learning_rate, list) and isinstance(self.train_strategy, Tensor):
+            train_strategy = list(self.train_strategy.asnumpy())
+            if len(self.learning_rate) <= len(train_strategy):
+                for i, lr in enumerate(self.learning_rate):
+                    self.dynamic_learning_rate[train_strategy[i]].append(lr)

     def _generate_new_optimizer(self, params):
         """Generate new optimizer."""
-        if not self.is_lars:
-            opt = self.opt_class(params=params, **self.opt_init_args)
+        if self.dynamic_learning_rate[self.opt_index]:
+            lr = self.dynamic_learning_rate[self.opt_index]
         else:
-            opt = LARS(self.opt_class(params=params, **self.opt_init_args),
+            lr = self.learning_rate
+        if not self.is_lars:
+            opt = self.opt_class(params=params, learning_rate=lr, **self.opt_init_args)
+            opt._update_local_parameters_name("boost_{}".format(self.opt_index)) # pylint: disable=W0212
+        else:
+            opt = LARS(self.opt_class(params=params, learning_rate=lr, **self.opt_init_args),
                        **self.lars_init_args)
+            opt.opt._update_local_parameters_name("boost_{}".format(self.opt_index)) # pylint: disable=W0212
+            opt._update_local_parameters_name("boost_{}".format(self.opt_index)) # pylint: disable=W0212
         return opt
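
Two things happen in the FreezeOpt hunks above: each freeze group gets its own slice of the learning-rate schedule, and each per-group optimizer renames its local parameters with a boost_<index> prefix so their names no longer collide. The snippet below is a hedged, plain-Python sketch of the first part, the grouping done by _create_new_group_learning_rate(); the values and variable layout are made up for illustration.

# Sketch of the grouping logic, using plain lists instead of Tensors.
groups_num = 2
train_strategy = [0, 1, 0, 1]              # group index trained at each step
learning_rate = [0.10, 0.01, 0.05, 0.005]  # per-step learning rates

dynamic_learning_rate = [[] for _ in range(groups_num)]
if len(learning_rate) <= len(train_strategy):
    for i, lr in enumerate(learning_rate):
        # route the i-th learning rate to the group active at step i
        dynamic_learning_rate[train_strategy[i]].append(lr)

assert dynamic_learning_rate == [[0.10, 0.05], [0.01, 0.005]]
# Each group's optimizer is then built with its own list, or with the shared
# learning rate when no dynamic schedule is configured.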

View File

@@ -137,6 +137,7 @@ class Cell(Cell_):
         self._auto_parallel_compile_and_run = False
         self.cast = Cast()
         self._has_config_recompute = False
+        self._user_parameters = []

     def __getstate__(self):
         base = Cell_.__getstate__(self)
@@ -1158,6 +1159,27 @@ class Cell(Cell_):
                 param.is_init = False
             param.name = prefix + name

+    def _update_local_parameters_name(self, prefix='', recurse=True):
+        """
+        Updates the names of local parameters with given prefix string.
+
+        Adds the given prefix to the names of local parameters.
+
+        Local parameters means the parameters without user input.
+
+        Args:
+            prefix (str): The prefix string. Default: ''.
+            recurse (bool): Whether contains the parameters of subcells. Default: True.
+        """
+        Validator.check_str_by_regular(prefix)
+        for name, param in self.parameters_and_names(expand=recurse):
+            if name in self._user_parameters:
+                continue
+            if prefix != '':
+                param.is_init = False
+            param.name = prefix + name
+
     def trainable_params(self, recurse=True):
         """
         Returns all trainable parameters.
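
The new Cell._update_local_parameters_name walks the cell's parameters but skips every name recorded in _user_parameters, so only internally created state (for example optimizer accumulators) receives the prefix while user-supplied weights keep their names. A simplified stand-in follows; ToyParam and prefix_local_names are hypothetical helpers, not the MindSpore API.

# Sketch of the skip-user-parameters rule, independent of MindSpore.
class ToyParam:
    def __init__(self, name):
        self.name = name


def prefix_local_names(params, user_parameters, prefix):
    # Rename only parameters that were not supplied by the user.
    for name, param in params.items():
        if name in user_parameters:
            continue
        param.name = prefix + name


params = {
    "conv.weight": ToyParam("conv.weight"),                   # user weight, keeps its name
    "moments.conv.weight": ToyParam("moments.conv.weight"),   # optimizer-created state
}
prefix_local_names(params, user_parameters=["conv.weight"], prefix="boost_0.")
assert params["conv.weight"].name == "conv.weight"
assert params["moments.conv.weight"].name == "boost_0.moments.conv.weight"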

View File

@@ -114,6 +114,7 @@ class LARS(Optimizer):
         self.weight_decay = optimizer.weight_decay
         self.global_step = optimizer.global_step
         self.parameters = optimizer.parameters
+        self._user_parameters += [param.name for param in self.parameters]
         self.use_clip = use_clip
         self.lars_flag = tuple(lars_filter(x) for x in self.parameters)
         self.is_group = optimizer.is_group

View File

@@ -50,6 +50,14 @@ def opt_init_args_register(fn):
         if 'optimizer' in arguments.keys():
             setattr(self, 'init_params', dict({"params": arguments['optimizer'].init_params["params"]}))
             arguments.pop('optimizer')
+        if 'learning_rate' in arguments.keys():
+            if isinstance(arguments['learning_rate'], Tensor):
+                arguments['learning_rate'] = list(arguments['learning_rate'].asnumpy())
+            if isinstance(arguments['learning_rate'], Cell):
+                setattr(self, 'init_learning_rate', None)
+            else:
+                setattr(self, 'init_learning_rate', arguments['learning_rate'])
+            arguments.pop('learning_rate')
         setattr(self, 'init_args', arguments)
         fn(self, *args, **kwargs)
     return deco
@@ -193,6 +201,9 @@ class Optimizer(Cell):
             if param.unique:
                 self._unique = False
                 break
+        # set user's parameters as local parameters
+        for param in self.parameters:
+            self._user_parameters.append(param.name)
         ps_filter = lambda x: x.is_param_ps
         self.ps_parameters = tuple(ps_filter(x) for x in self.parameters)
         cache_filter = lambda x: x.cache_enable
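
The opt_init_args_register change records the learning_rate passed to the optimizer constructor so boost can rebuild optimizers with it later: a Tensor is flattened to a plain list, and a learning-rate schedule Cell is stored as None, meaning the value must be read from the live optimizer instead. The following is a stripped-down sketch of such a capture decorator using only the standard library; capture_init_args and ToyOptimizer are made up for illustration and do not reproduce the MindSpore implementation.

# Sketch of an init-args capture decorator, standard library only.
import functools
import inspect


def capture_init_args(fn):
    @functools.wraps(fn)
    def deco(self, *args, **kwargs):
        bound = inspect.signature(fn).bind(self, *args, **kwargs)
        bound.apply_defaults()
        arguments = dict(bound.arguments)
        arguments.pop('self')
        # Mirror the diff's intent: keep a plain record of learning_rate,
        # storing None when it is a callable schedule object.
        lr = arguments.get('learning_rate')
        self.init_learning_rate = None if callable(lr) else lr
        self.init_args = arguments
        fn(self, *args, **kwargs)
    return deco


class ToyOptimizer:
    @capture_init_args
    def __init__(self, params, learning_rate=0.1):
        self.params = params
        self.learning_rate = learning_rate


opt = ToyOptimizer(params=["w0"], learning_rate=0.01)
assert opt.init_learning_rate == 0.01
assert opt.init_args['learning_rate'] == 0.01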