diff --git a/mindspore/nn/optim/lars.py b/mindspore/nn/optim/lars.py index b55d1c55746..7b05b372eb2 100755 --- a/mindspore/nn/optim/lars.py +++ b/mindspore/nn/optim/lars.py @@ -13,22 +13,18 @@ # limitations under the License. # ============================================================================ """lars optimizer""" -from typing import Iterable -from mindspore.common import dtype as mstype -from mindspore.common import Tensor -from mindspore.common.initializer import initializer -from mindspore.common.parameter import Parameter from mindspore.ops import operations as P from mindspore.ops import composite as C from mindspore.ops import functional as F from mindspore._checkparam import Validator as validator +from mindspore.common import Tensor, Parameter, dtype as mstype from .optimizer import _grad_scale, Optimizer _lars_opt = C.MultitypeFuncGraph("lars_opt") -@_lars_opt.register("Function", "Number", "Tensor", "Tensor", "Tensor", "Bool", "Bool") -def _tensor_run_opt(lars, weight_decay, learning_rate, gradient, weight, decay_flag, lars_flag): +@_lars_opt.register("Function", "Tensor", "Number", "Tensor", "Tensor", "Bool", "Bool") +def _tensor_run_opt(lars, learning_rate, weight_decay, gradient, weight, decay_flag, lars_flag): """Apply lars optimizer to the weight parameter.""" if lars_flag: op_reduce_sum = P.SquareSumAll() @@ -42,10 +38,12 @@ def _tensor_run_opt(lars, weight_decay, learning_rate, gradient, weight, decay_f return gradient -def _check_param_value(optimizer, epsilon, hyperpara, use_clip, prim_name): +def _check_param_value(optimizer, epsilon, coefficient, use_clip, prim_name): validator.check_value_type("optimizer", optimizer, Optimizer, prim_name) + if "Adam" in optimizer.cls_name or "Lamb" in optimizer.cls_name: + raise TypeError("LARS can not be used with ", optimizer.cls_name) validator.check_value_type("epsilon", epsilon, [float], prim_name) - validator.check_value_type("hyperpara", hyperpara, [float], prim_name) + validator.check_value_type("coefficient", coefficient, [float], prim_name) validator.check_value_type("use_clip", use_clip, [bool], prim_name) class LARS(Optimizer): @@ -58,14 +56,10 @@ class LARS(Optimizer): Args: optimizer (Optimizer): MindSpore optimizer for which to wrap and modify gradients. epsilon (float): Term added to the denominator to improve numerical stability. Default: 1e-05. - hyperpara (float): Trust coefficient for calculating the local learning rate. Default: 0.001. - weight_decay (float): Weight decay (L2 penalty). It should be equal to or greater than 0. Default: 0.0. + coefficient (float): Trust coefficient for calculating the local learning rate. Default: 0.001. use_clip (bool): Whether to use clip operation for calculating the local learning rate. Default: False. - decay_filter (Function): A function to determine whether apply weight decay on parameters. Default: - lambda x: 'LayerNorm' not in x.name and 'bias' not in x.name. lars_filter (Function): A function to determine whether apply lars algorithm. Default: lambda x: 'LayerNorm' not in x.name and 'bias' not in x.name. - loss_scale (float): A floating point value for the loss scale. It should be greater than 0. Default: 1.0. Inputs: - **gradients** (tuple[Tensor]) - The gradients of `params` in optimizer, the shape is @@ -78,51 +72,54 @@ class LARS(Optimizer): >>> net = Net() >>> loss = nn.SoftmaxCrossEntropyWithLogits() >>> opt = nn.Momentum(net.trainable_params(), 0.1, 0.9) - >>> opt_lars = nn.LARS(opt, epsilon=1e-08, hyperpara=0.02) + >>> opt_lars = nn.LARS(opt, epsilon=1e-08, coefficient=0.02) >>> model = Model(net, loss_fn=loss, optimizer=opt_lars, metrics=None) """ - def __init__(self, optimizer, epsilon=1e-05, hyperpara=0.001, weight_decay=0.0, use_clip=False, - decay_filter=lambda x: 'LayerNorm' not in x.name and 'bias' not in x.name, - lars_filter=lambda x: 'LayerNorm' not in x.name and 'bias' not in x.name, loss_scale=1.0): - super(LARS, self).__init__(0.0, [Parameter(Tensor(0.0), name="trivial")], weight_decay, loss_scale) - if optimizer.is_group: - raise RuntimeError(f"The {self.cls_name} optimizer cannot support group setting.") - _check_param_value(optimizer, epsilon, hyperpara, use_clip, self.cls_name) + def __init__(self, optimizer, epsilon=1e-05, coefficient=0.001, use_clip=False, + lars_filter=lambda x: 'LayerNorm' not in x.name and 'bias' not in x.name): + super(LARS, self).__init__(0.0, [Parameter(Tensor(0.0), name="fake_param")]) + _check_param_value(optimizer, epsilon, coefficient, use_clip, self.cls_name) self.opt = optimizer - self.parameters = optimizer.parameters - self.learning_rate = optimizer.learning_rate - self.lars = P.LARSUpdate(epsilon, hyperpara, use_clip) - self.reciprocal_scale = 1.0 / loss_scale - self.weight_decay = weight_decay + self.lars = P.LARSUpdate(epsilon, coefficient, use_clip) self.cast = P.Cast() - self.decay_flag = tuple(decay_filter(x) for x in self.parameters) + self.parameters = optimizer.parameters + if use_clip is True: + self.learning_rate = optimizer.learning_rate + self.dynamic_lr = optimizer.dynamic_lr + self.gather = optimizer.gather + self.assignadd = optimizer.assignadd + self.global_step = optimizer.global_step + else: + self.learning_rate = Parameter(Tensor(0.0, dtype=mstype.float32), name="fake_lr") + self.reciprocal_scale = optimizer.reciprocal_scale + optimizer.reciprocal_scale = 1.0 + self.is_group = optimizer.is_group + if self.is_group: + self.weight_decay = tuple(map(lambda x: x / optimizer.loss_scale, optimizer.weight_decay)) + else: + self.weight_decay = optimizer.weight_decay / optimizer.loss_scale + optimizer.exec_weight_decay = False + optimizer.weight_decay = 0.0 + self.decay_flags = optimizer.decay_flags self.lars_flag = tuple(lars_filter(x) for x in self.parameters) self.hyper_map = C.HyperMap() - self.dynamic_lr = False - self.gather = None - self.global_step = None - self.axis = None - if isinstance(self.learning_rate.default_input, Iterable) or \ - (isinstance(self.learning_rate.default_input, Tensor) and self.learning_rate.default_input.dim() == 1): - self.dynamic_lr = True - self.assignadd = P.AssignAdd() - self.gather = P.GatherV2() - self.global_step = Parameter(initializer(0, [1], mstype.int32), name="lars_global_step") - self.axis = 0 def construct(self, gradients): params = self.parameters if self.dynamic_lr: - lr = self.gather(self.learning_rate, self.global_step, self.axis) + lr = self.gather(self.learning_rate, self.global_step, 0) F.control_depend(lr, self.assignadd(self.global_step, 1)) else: lr = self.learning_rate if self.reciprocal_scale != 1.0: gradients = self.hyper_map(F.partial(_grad_scale, self.reciprocal_scale), gradients) - - grad_t = self.hyper_map(F.partial(_lars_opt, self.lars, self.weight_decay, lr), - gradients, params, self.decay_flag, self.lars_flag) + if self.is_group: + grad_t = self.hyper_map(F.partial(_lars_opt, self.lars, lr), self.weight_decay, + gradients, params, self.decay_flags, self.lars_flag) + else: + grad_t = self.hyper_map(F.partial(_lars_opt, self.lars, lr, self.weight_decay), + gradients, params, self.decay_flags, self.lars_flag) success = self.opt(grad_t) return success diff --git a/tests/st/networks/models/resnet50/test_resnet50_imagenet.py b/tests/st/networks/models/resnet50/test_resnet50_imagenet.py index c991b469eea..64f4adda998 100644 --- a/tests/st/networks/models/resnet50/test_resnet50_imagenet.py +++ b/tests/st/networks/models/resnet50/test_resnet50_imagenet.py @@ -182,13 +182,11 @@ def train_process(q, device_id, epoch_size, device_num, enable_hccl): {'order_params': net.trainable_params()}] if config.use_lars: - momentum = nn.Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, config.momentum, + momentum = nn.Momentum(group_params, lr, config.momentum, + weight_decay=config.weight_decay, loss_scale=config.loss_scale, use_nesterov=config.use_nesterov) - opt = nn.LARS(momentum, epsilon=config.lars_epsilon, hyperpara=config.lars_coefficient, - weight_decay=config.weight_decay, - decay_filter=lambda x: 'beta' not in x.name and 'gamma' not in x.name and 'bias' not in x.name, - lars_filter=lambda x: 'beta' not in x.name and 'gamma' not in x.name and 'bias' not in x.name, - loss_scale=config.loss_scale) + opt = nn.LARS(momentum, epsilon=config.lars_epsilon, coefficient=config.lars_coefficient, + lars_filter=lambda x: 'beta' not in x.name and 'gamma' not in x.name and 'bias' not in x.name) else: opt = nn.Momentum(group_params, lr, config.momentum, diff --git a/tests/ut/python/nn/optim/test_lars.py b/tests/ut/python/nn/optim/test_lars.py index d6fd4cd90eb..1373691b72b 100644 --- a/tests/ut/python/nn/optim/test_lars.py +++ b/tests/ut/python/nn/optim/test_lars.py @@ -56,7 +56,7 @@ def test_lars_multi_step_lr(): lr = multisteplr(10, [2, 6]) SGD = Momentum(net.trainable_params(), lr, 0.9) - optimizer = LARS(SGD, epsilon=1e-08, hyperpara=0.02, decay_filter=lambda x: 'bn' not in x.name, + optimizer = LARS(SGD, epsilon=1e-08, coefficient=0.02, use_clip=True, lars_filter=lambda x: 'bn' not in x.name) net_with_loss = WithLossCell(net, loss) @@ -73,7 +73,7 @@ def test_lars_float_lr(): lr = 0.1 SGD = Momentum(net.trainable_params(), lr, 0.9) - optimizer = LARS(SGD, epsilon=1e-08, hyperpara=0.02, decay_filter=lambda x: 'bn' not in x.name, + optimizer = LARS(SGD, epsilon=1e-08, coefficient=0.02, lars_filter=lambda x: 'bn' not in x.name) net_with_loss = WithLossCell(net, loss) diff --git a/tests/ut/python/parallel/test_loss_and_optimizer.py b/tests/ut/python/parallel/test_loss_and_optimizer.py index b4cf62c29eb..91be7682abd 100644 --- a/tests/ut/python/parallel/test_loss_and_optimizer.py +++ b/tests/ut/python/parallel/test_loss_and_optimizer.py @@ -205,7 +205,7 @@ def test_lars(): lr = Tensor(np.ones([6]), dtype=ms.float32) sgd = Momentum(net.trainable_params(), lr, 0.9) - optimizer = LARS(sgd, epsilon=1e-08, hyperpara=0.02, decay_filter=lambda x: 'bn' not in x.name, + optimizer = LARS(sgd, epsilon=1e-08, coefficient=0.02, lars_filter=lambda x: 'bn' not in x.name) net_with_loss = NetWithLoss(net, strategy3) train_net = TrainOneStepCell(net_with_loss, optimizer)