diff --git a/mindspore/python/mindspore/nn/optim/asgd.py b/mindspore/python/mindspore/nn/optim/asgd.py
index ac20f3d35ad..4b44129237e 100755
--- a/mindspore/python/mindspore/nn/optim/asgd.py
+++ b/mindspore/python/mindspore/nn/optim/asgd.py
@@ -177,6 +177,8 @@ class ASGD(Optimizer):
         gradients = self.gradients_centralization(gradients)
         gradients = self.scale_grad(gradients)
         lrs = self.get_lr()
+        if not self.is_dynamic_lr_or_weight_decay():
+            self.assignadd(self.global_step, self.global_step_increase_tensor)
         success = True
         for index, (grad, param, mu, eta, ax) in enumerate(zip(gradients, self.parameters, self.mu,
                                                                self.eta, self.ax)):
diff --git a/mindspore/python/mindspore/nn/optim/lamb.py b/mindspore/python/mindspore/nn/optim/lamb.py
index aaa4eb3d3f5..bbc37cf15ac 100755
--- a/mindspore/python/mindspore/nn/optim/lamb.py
+++ b/mindspore/python/mindspore/nn/optim/lamb.py
@@ -355,6 +355,8 @@ class Lamb(Optimizer):
     def construct(self, gradients):
         weight_decay = self.get_weight_decay()
         lr = self.get_lr()
+        if not self.is_dynamic_lr_or_weight_decay():
+            self.assignadd(self.global_step, self.global_step_increase_tensor)
         lamb_opt = _lamb_opt_ascend if self.device_ascend else _lamb_opt
         gradients = self.gradients_centralization(gradients)
         if self.is_group:
diff --git a/mindspore/python/mindspore/nn/optim/lars.py b/mindspore/python/mindspore/nn/optim/lars.py
index 5132cb83ad1..a7717ede4ac 100755
--- a/mindspore/python/mindspore/nn/optim/lars.py
+++ b/mindspore/python/mindspore/nn/optim/lars.py
@@ -187,4 +187,6 @@ class LARS(Optimizer):
         gradients = self.hyper_map(F.partial(_lars_opt, self.lars, self.loss_scale, lr, weight_decay),
                                    gradients, params, self.decay_flags, self.lars_flag)
         success = self.opt(gradients)
+        if self.is_dynamic_lr_or_weight_decay() and not self.opt.is_dynamic_lr_or_weight_decay():
+            self.assignadd(self.global_step, self.global_step_increase_tensor)
         return success
diff --git a/mindspore/python/mindspore/nn/optim/optimizer.py b/mindspore/python/mindspore/nn/optim/optimizer.py
index 5271197d1c8..199a75d95d9 100644
--- a/mindspore/python/mindspore/nn/optim/optimizer.py
+++ b/mindspore/python/mindspore/nn/optim/optimizer.py
@@ -675,7 +675,8 @@ class Optimizer(Cell):
                     lr += (current_dynamic_lr,)
             else:
                 lr = self.learning_rate(self.global_step)
-        self.assignadd(self.global_step, self.global_step_increase_tensor)
+        if self.is_dynamic_lr_or_weight_decay():
+            self.assignadd(self.global_step, self.global_step_increase_tensor)
         return lr
 
     def get_lr_parameter(self, param):
@@ -732,6 +733,15 @@
 
         return lr if isinstance(param, list) else lr[0]
 
+    def is_dynamic_lr_or_weight_decay(self):
+        """
+        Determine whether the learning rate or weight decay is dynamic.
+
+        Returns:
+            bool, represents the learning rate or weight decay is dynamic or not.
+        """
+        return self.dynamic_lr or self.dynamic_weight_decay
+
     def _get_parameter_group_id(self):
         """
         Get the parameter partition group id, which is less than the number of devices.
diff --git a/mindspore/python/mindspore/nn/optim/rprop.py b/mindspore/python/mindspore/nn/optim/rprop.py
index 8ce0abff56a..39932e73262 100755
--- a/mindspore/python/mindspore/nn/optim/rprop.py
+++ b/mindspore/python/mindspore/nn/optim/rprop.py
@@ -192,6 +192,8 @@ class Rprop(Optimizer):
         gradients = self.gradients_centralization(gradients)
         gradients = self.scale_grad(gradients)
         lrs = self.get_lr()
+        if not self.is_dynamic_lr_or_weight_decay():
+            self.assignadd(self.global_step, self.global_step_increase_tensor)
         success = True
 
         for index, (grad, param, prev, step_size) in enumerate(zip(gradients, self.parameters,
diff --git a/tests/st/dump/test_cell_dump.py b/tests/st/dump/test_cell_dump.py
index c080fe2c6fb..9d6ac70c22b 100644
--- a/tests/st/dump/test_cell_dump.py
+++ b/tests/st/dump/test_cell_dump.py
@@ -146,7 +146,7 @@ def test_ascend_not_cell_dump():
         check_dump_structure(dump_path, dump_config_path, 1, 1, 1)
 
         # make sure set_dump is ignored and all cell layer are dumped
-        assert len(os.listdir(dump_file_path)) == 11
+        assert len(os.listdir(dump_file_path)) == 10
 
         del os.environ['MINDSPORE_DUMP_CONFIG']
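For context, a minimal sketch of the bookkeeping pattern this patch introduces, written in plain Python rather than against the real MindSpore API (ToyOptimizer and its attributes are illustrative names only): get_lr() advances global_step only when the learning rate or weight decay schedule is dynamic, and each optimizer's construct() advances it when neither is dynamic, so the counter moves exactly once per step instead of being incremented unconditionally inside get_lr().

# Plain-Python sketch of the gating pattern; not MindSpore code.
class ToyOptimizer:
    def __init__(self, dynamic_lr=False, dynamic_weight_decay=False, base_lr=0.1):
        self.dynamic_lr = dynamic_lr
        self.dynamic_weight_decay = dynamic_weight_decay
        self.base_lr = base_lr
        self.global_step = 0  # stands in for the Parameter updated via assignadd

    def is_dynamic_lr_or_weight_decay(self):
        # Mirrors the helper added to Optimizer in optimizer.py.
        return self.dynamic_lr or self.dynamic_weight_decay

    def get_lr(self):
        # A dynamic schedule is evaluated at the current global_step, so the
        # counter is advanced here (as in the patched Optimizer.get_lr).
        lr = self.base_lr / (1 + self.global_step) if self.dynamic_lr else self.base_lr
        if self.is_dynamic_lr_or_weight_decay():
            self.global_step += 1
        return lr

    def construct(self, gradients):
        lr = self.get_lr()
        # With a static lr and weight decay, get_lr() no longer bumps the
        # counter, so the update step does it once instead (as in asgd.py,
        # lamb.py and rprop.py above).
        if not self.is_dynamic_lr_or_weight_decay():
            self.global_step += 1
        return [lr * g for g in gradients]


# Either way, global_step advances exactly once per construct() call.
static_opt = ToyOptimizer(dynamic_lr=False)
dynamic_opt = ToyOptimizer(dynamic_lr=True)
static_opt.construct([1.0, 2.0])
dynamic_opt.construct([1.0, 2.0])
assert static_opt.global_step == 1
assert dynamic_opt.global_step == 1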