!32938 optimize global step for optimizer
Merge pull request !32938 from zhangbuxue/optimize_global_step_for_optimizer
commit 84dcf74bbd
@@ -177,6 +177,8 @@ class ASGD(Optimizer):
         gradients = self.gradients_centralization(gradients)
         gradients = self.scale_grad(gradients)
         lrs = self.get_lr()
+        if not self.is_dynamic_lr_or_weight_decay():
+            self.assignadd(self.global_step, self.global_step_increase_tensor)
         success = True

         for index, (grad, param, mu, eta, ax) in enumerate(zip(gradients, self.parameters, self.mu, self.eta, self.ax)):
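Note: the guarded increment added here (and, identically, in the Lamb and Rprop hunks below) moves the global-step bookkeeping into construct for the static case; when the learning rate or weight decay is dynamic, get_lr() performs the increment instead (see the Optimizer hunk further down). A minimal plain-Python sketch of that contract follows; ToyOptimizer, the integer counter, and the constant lr are illustrative stand-ins, not MindSpore code.

# Illustrative sketch only: mirrors the new step-counting contract.
class ToyOptimizer:
    def __init__(self, dynamic_lr=False, dynamic_weight_decay=False):
        self.dynamic_lr = dynamic_lr
        self.dynamic_weight_decay = dynamic_weight_decay
        self.global_step = 0  # stands in for the Parameter updated via AssignAdd

    def is_dynamic_lr_or_weight_decay(self):
        return self.dynamic_lr or self.dynamic_weight_decay

    def get_lr(self):
        lr = 0.01  # a constant stands in for evaluating the (possibly scheduled) lr
        if self.is_dynamic_lr_or_weight_decay():
            self.global_step += 1  # dynamic case: get_lr() owns the increment
        return lr

    def construct(self, gradients):
        lr = self.get_lr()
        if not self.is_dynamic_lr_or_weight_decay():
            self.global_step += 1  # static case: construct() owns the increment
        return lr

# Either way, the step advances exactly once per construct() call.
for dyn in (False, True):
    opt = ToyOptimizer(dynamic_lr=dyn)
    opt.construct([0.0])
    assert opt.global_step == 1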
@@ -355,6 +355,8 @@ class Lamb(Optimizer):
     def construct(self, gradients):
         weight_decay = self.get_weight_decay()
         lr = self.get_lr()
+        if not self.is_dynamic_lr_or_weight_decay():
+            self.assignadd(self.global_step, self.global_step_increase_tensor)
         lamb_opt = _lamb_opt_ascend if self.device_ascend else _lamb_opt
         gradients = self.gradients_centralization(gradients)
         if self.is_group:
@@ -187,4 +187,6 @@ class LARS(Optimizer):
         gradients = self.hyper_map(F.partial(_lars_opt, self.lars, self.loss_scale, lr, weight_decay),
                                    gradients, params, self.decay_flags, self.lars_flag)
         success = self.opt(gradients)
+        if self.is_dynamic_lr_or_weight_decay() and not self.opt.is_dynamic_lr_or_weight_decay():
+            self.assignadd(self.global_step, self.global_step_increase_tensor)
         return success
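LARS delegates the actual update to the wrapped optimizer (self.opt), so the added guard advances the step only when the wrapper reports dynamic lr/weight decay while the wrapped optimizer does not, presumably so the step is not advanced twice for one update. A rough sketch of that check, reusing the ToyOptimizer stand-in from the sketch after the ASGD hunk; in this sketch each object keeps its own counter, which is a simplification rather than the real parameter wiring.

# Illustrative sketch: a LARS-like wrapper that avoids counting the step twice.
class ToyWrapper:
    def __init__(self, inner, dynamic_lr=False, dynamic_weight_decay=False):
        self.opt = inner
        self.dynamic_lr = dynamic_lr
        self.dynamic_weight_decay = dynamic_weight_decay
        self.global_step = 0

    def is_dynamic_lr_or_weight_decay(self):
        return self.dynamic_lr or self.dynamic_weight_decay

    def construct(self, gradients):
        success = self.opt.construct(gradients)  # inner optimizer counts its own step
        # Count here only when the wrapper is dynamic but the inner optimizer is not.
        if self.is_dynamic_lr_or_weight_decay() and not self.opt.is_dynamic_lr_or_weight_decay():
            self.global_step += 1
        return success

wrapper = ToyWrapper(ToyOptimizer(), dynamic_lr=True)
wrapper.construct([0.0])
assert wrapper.global_step == 1 and wrapper.opt.global_step == 1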
@@ -675,7 +675,8 @@ class Optimizer(Cell):
                 lr += (current_dynamic_lr,)
         else:
             lr = self.learning_rate(self.global_step)
-        self.assignadd(self.global_step, self.global_step_increase_tensor)
+        if self.is_dynamic_lr_or_weight_decay():
+            self.assignadd(self.global_step, self.global_step_increase_tensor)
         return lr

     def get_lr_parameter(self, param):
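In the dynamic case, get_lr() still evaluates the schedule at the current global_step and only afterwards advances it, so a per-step schedule sees 0, 1, 2, ... across successive calls. A small sketch of that evaluate-then-advance ordering; the closure and the exponential schedule are examples of mine, not taken from this diff.

# Illustrative: evaluate the schedule at the current step, then advance the step.
def make_get_lr(schedule):
    state = {"global_step": 0}

    def get_lr():
        lr = schedule(state["global_step"])  # lr for the step being executed
        state["global_step"] += 1            # dynamic case: advance afterwards
        return lr

    return get_lr

get_lr = make_get_lr(lambda step: 0.1 * (0.9 ** step))
assert [round(get_lr(), 4) for _ in range(3)] == [0.1, 0.09, 0.081]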
@@ -732,6 +733,15 @@ class Optimizer(Cell):

         return lr if isinstance(param, list) else lr[0]

+    def is_dynamic_lr_or_weight_decay(self):
+        """
+        Determine whether the learning rate or weight decay is dynamic.
+
+        Returns:
+            bool, whether the learning rate or weight decay is dynamic.
+        """
+        return self.dynamic_lr or self.dynamic_weight_decay
+
     def _get_parameter_group_id(self):
         """
         Get the parameter partition group id, which is less than the number of devices.
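The new predicate simply ORs two flags that are fixed at construction time. How those flags get set is outside this diff; the rule in the sketch below (treat anything callable as a per-step schedule) is only an assumed illustration.

# Illustrative only: an assumed rule for deriving the two flags.
class ToyConfig:
    def __init__(self, learning_rate, weight_decay=0.0):
        self.dynamic_lr = callable(learning_rate)            # e.g. a step -> lr schedule
        self.dynamic_weight_decay = callable(weight_decay)   # e.g. a step -> decay schedule

    def is_dynamic_lr_or_weight_decay(self):
        return self.dynamic_lr or self.dynamic_weight_decay

assert not ToyConfig(0.01).is_dynamic_lr_or_weight_decay()
assert ToyConfig(lambda step: 0.01 * 0.95 ** step).is_dynamic_lr_or_weight_decay()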
@@ -192,6 +192,8 @@ class Rprop(Optimizer):
         gradients = self.gradients_centralization(gradients)
         gradients = self.scale_grad(gradients)
         lrs = self.get_lr()
+        if not self.is_dynamic_lr_or_weight_decay():
+            self.assignadd(self.global_step, self.global_step_increase_tensor)
         success = True

         for index, (grad, param, prev, step_size) in enumerate(zip(gradients, self.parameters,
@@ -146,7 +146,7 @@ def test_ascend_not_cell_dump():
     check_dump_structure(dump_path, dump_config_path, 1, 1, 1)

     # make sure set_dump is ignored and all cell layers are dumped
-    assert len(os.listdir(dump_file_path)) == 11
+    assert len(os.listdir(dump_file_path)) == 10
     del os.environ['MINDSPORE_DUMP_CONFIG']

