!32938 optimize global step for optimizer

Merge pull request !32938 from zhangbuxue/optimize_global_step_for_optimizer
i-robot 2022-04-16 03:53:04 +00:00 committed by Gitee
commit 84dcf74bbd
6 changed files with 20 additions and 2 deletions
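
The optimizer changes below all apply one pattern: the base Optimizer now advances global_step inside get_lr() only when the learning rate or weight decay is actually dynamic, while the optimizers that consume the step counter themselves (ASGD, Lamb, LARS, Rprop) advance it in their own construct when get_lr() did not; the remaining test change updates an expected dump-file count. A minimal sketch of that control flow, using plain-Python stand-ins rather than MindSpore's real Parameter/AssignAdd machinery (names mirror the diff, logic is simplified):

# Simplified sketch only; not MindSpore's actual classes or graph operations.
class SketchOptimizer:
    def __init__(self, learning_rate, dynamic_weight_decay=False):
        # dynamic_lr is True when the learning rate is a schedule, not a scalar
        self.dynamic_lr = callable(learning_rate)
        self.dynamic_weight_decay = dynamic_weight_decay
        self.learning_rate = learning_rate
        self.global_step = 0  # a Parameter updated via AssignAdd in the real code

    def is_dynamic_lr_or_weight_decay(self):
        return self.dynamic_lr or self.dynamic_weight_decay

    def get_lr(self):
        lr = self.learning_rate(self.global_step) if self.dynamic_lr else self.learning_rate
        # after this commit the step is only bumped here when something depends on it
        if self.is_dynamic_lr_or_weight_decay():
            self.global_step += 1
        return lr

class SketchASGD(SketchOptimizer):
    def construct(self, gradients):
        lrs = self.get_lr()
        # mirrors the ASGD/Lamb/Rprop hunks: keep the counter moving when
        # get_lr() skipped the increment (static lr and static weight decay)
        if not self.is_dynamic_lr_or_weight_decay():
            self.global_step += 1
        return lrs, gradients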

@@ -177,6 +177,8 @@ class ASGD(Optimizer):
         gradients = self.gradients_centralization(gradients)
         gradients = self.scale_grad(gradients)
         lrs = self.get_lr()
+        if not self.is_dynamic_lr_or_weight_decay():
+            self.assignadd(self.global_step, self.global_step_increase_tensor)
         success = True
         for index, (grad, param, mu, eta, ax) in enumerate(zip(gradients, self.parameters, self.mu, self.eta, self.ax)):

@@ -355,6 +355,8 @@ class Lamb(Optimizer):
     def construct(self, gradients):
         weight_decay = self.get_weight_decay()
         lr = self.get_lr()
+        if not self.is_dynamic_lr_or_weight_decay():
+            self.assignadd(self.global_step, self.global_step_increase_tensor)
         lamb_opt = _lamb_opt_ascend if self.device_ascend else _lamb_opt
         gradients = self.gradients_centralization(gradients)
         if self.is_group:

@@ -187,4 +187,6 @@ class LARS(Optimizer):
         gradients = self.hyper_map(F.partial(_lars_opt, self.lars, self.loss_scale, lr, weight_decay),
                                    gradients, params, self.decay_flags, self.lars_flag)
         success = self.opt(gradients)
+        if self.is_dynamic_lr_or_weight_decay() and not self.opt.is_dynamic_lr_or_weight_decay():
+            self.assignadd(self.global_step, self.global_step_increase_tensor)
         return success

@@ -675,7 +675,8 @@ class Optimizer(Cell):
                     lr += (current_dynamic_lr,)
             else:
                 lr = self.learning_rate(self.global_step)
-        self.assignadd(self.global_step, self.global_step_increase_tensor)
+        if self.is_dynamic_lr_or_weight_decay():
+            self.assignadd(self.global_step, self.global_step_increase_tensor)
         return lr

     def get_lr_parameter(self, param):
@@ -732,6 +733,15 @@ class Optimizer(Cell):
         return lr if isinstance(param, list) else lr[0]

+    def is_dynamic_lr_or_weight_decay(self):
+        """
+        Determine whether the learning rate or weight decay is dynamic.
+
+        Returns:
+            bool, represents the learning rate or weight decay is dynamic or not.
+        """
+        return self.dynamic_lr or self.dynamic_weight_decay
+
     def _get_parameter_group_id(self):
         """
         Get the parameter partition group id, which is less than the number of devices.
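
As a usage-level illustration of the new helper (hedged: the constructors below come from MindSpore's public API, not from this diff, and the network and argument values are made up):

import mindspore.nn as nn

net = nn.Dense(4, 2)

# Scalar learning rate and static weight decay: is_dynamic_lr_or_weight_decay()
# is False, so get_lr() leaves global_step untouched and ASGD's construct issues
# the single AssignAdd per step (see the ASGD hunk above).
opt_static = nn.ASGD(net.trainable_params(), learning_rate=0.01)

# Learning-rate schedule: dynamic_lr is True, so the increment stays in get_lr().
lr_schedule = nn.ExponentialDecayLR(learning_rate=0.1, decay_rate=0.9, decay_steps=100)
opt_dynamic = nn.ASGD(net.trainable_params(), learning_rate=lr_schedule)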

@@ -192,6 +192,8 @@ class Rprop(Optimizer):
         gradients = self.gradients_centralization(gradients)
         gradients = self.scale_grad(gradients)
         lrs = self.get_lr()
+        if not self.is_dynamic_lr_or_weight_decay():
+            self.assignadd(self.global_step, self.global_step_increase_tensor)
         success = True
         for index, (grad, param, prev, step_size) in enumerate(zip(gradients, self.parameters,

@@ -146,7 +146,7 @@ def test_ascend_not_cell_dump():
         check_dump_structure(dump_path, dump_config_path, 1, 1, 1)
         # make sure set_dump is ignored and all cell layer are dumped
-        assert len(os.listdir(dump_file_path)) == 11
+        assert len(os.listdir(dump_file_path)) == 10
         del os.environ['MINDSPORE_DUMP_CONFIG']