From 082433183d7e247501f9d908379625e5245fee94 Mon Sep 17 00:00:00 2001 From: "wangnan39@huawei.com" Date: Tue, 30 Jun 2020 17:25:16 +0800 Subject: [PATCH] uniform learning_rate behavior of optimizers --- mindspore/nn/dynamic_lr.py | 59 ++- mindspore/nn/learning_rate_schedule.py | 368 ++++++++++++++++++ mindspore/nn/optim/__init__.py | 4 +- mindspore/nn/optim/adam.py | 267 +++++-------- mindspore/nn/optim/ftrl.py | 82 ++-- mindspore/nn/optim/lamb.py | 214 +++++----- mindspore/nn/optim/lars.py | 74 ++-- mindspore/nn/optim/lazyadam.py | 30 +- mindspore/nn/optim/momentum.py | 22 +- mindspore/nn/optim/optimizer.py | 284 ++++++++------ mindspore/nn/optim/proximal_ada_grad.py | 65 +++- mindspore/nn/optim/rmsprop.py | 22 +- mindspore/nn/optim/sgd.py | 25 +- model_zoo/mass/src/utils/lr_scheduler.py | 35 +- model_zoo/mass/train.py | 16 +- model_zoo/official/nlp/bert/README.md | 2 +- model_zoo/official/nlp/bert/run_classifier.py | 40 +- model_zoo/official/nlp/bert/run_ner.py | 39 +- model_zoo/official/nlp/bert/run_pretrain.py | 43 +- model_zoo/official/nlp/bert/run_squad.py | 40 +- model_zoo/official/nlp/bert/src/config.py | 6 +- .../nlp/bert/src/finetune_eval_config.py | 7 +- model_zoo/official/nlp/bert/src/utils.py | 23 ++ .../apps/test_lamb_check_loss.py | 2 +- .../pipeline/gradient/check_training.py | 2 +- tests/perf_test/bert/test_bert_train.py | 34 +- tests/st/networks/models/bert/src/config.py | 4 +- .../models/bert/src/finetune_config.py | 4 +- .../models/bert/test_bert_graph_kernel.py | 33 +- .../models/bert/test_bert_tdt_lossscale.py | 48 ++- tests/ut/python/nn/optim/test_adam.py | 74 +++- tests/ut/python/nn/optim/test_lamb.py | 62 ++- tests/ut/python/nn/optim/test_optimizer.py | 10 +- .../python/nn/optim/test_proximal_ada_grad.py | 1 + tests/ut/python/nn/test_dynamic_lr.py | 7 +- .../python/nn/test_learning_rate_schedule.py | 157 ++++++++ .../test_optimizer_with_loss_scale.py | 2 +- .../test_optimizer_with_parameter_groups.py | 47 ++- .../parallel/test_parallel_optimizer.py | 25 +- 39 files changed, 1572 insertions(+), 707 deletions(-) create mode 100644 mindspore/nn/learning_rate_schedule.py create mode 100644 tests/ut/python/nn/test_learning_rate_schedule.py diff --git a/mindspore/nn/dynamic_lr.py b/mindspore/nn/dynamic_lr.py index 6eeba415a7c..aca241419cb 100644 --- a/mindspore/nn/dynamic_lr.py +++ b/mindspore/nn/dynamic_lr.py @@ -231,8 +231,9 @@ def cosine_decay_lr(min_lr, max_lr, total_step, step_per_epoch, decay_epoch): >>> cosine_decay_lr(min_lr, max_lr, total_step, step_per_epoch, decay_epoch) [0.1, 0.1, 0.05500000000000001, 0.05500000000000001, 0.01, 0.01] """ - validator.check_float_positive('min_lr', min_lr, None) - validator.check_float_legal_value('min_lr', min_lr, None) + if not isinstance(min_lr, float): + raise TypeError("min_lr must be float.") + validator.check_number_range("min_lr", min_lr, 0.0, float("inf"), Rel.INC_LEFT, None) validator.check_float_positive('max_lr', max_lr, None) validator.check_float_legal_value('max_lr', max_lr, None) validator.check_integer('total_step', total_step, 0, Rel.GT, None) @@ -288,8 +289,9 @@ def polynomial_decay_lr(learning_rate, end_learning_rate, total_step, step_per_e """ validator.check_float_positive('learning_rate', learning_rate, None) validator.check_float_legal_value('learning_rate', learning_rate, None) - validator.check_float_positive('end_learning_rate', end_learning_rate, None) - validator.check_float_legal_value('end_learning_rate', end_learning_rate, None) + if not isinstance(end_learning_rate, float): + raise TypeError("end_learning_rate must be float.") + validator.check_number_range("end_learning_rate", end_learning_rate, 0.0, float("inf"), Rel.INC_LEFT, None) validator.check_float_positive('power', power, None) validator.check_float_legal_value('power', power, None) validator.check_integer('total_step', total_step, 0, Rel.GT, None) @@ -311,11 +313,58 @@ def polynomial_decay_lr(learning_rate, end_learning_rate, total_step, step_per_e return lr +def warmup_lr(learning_rate, total_step, step_per_epoch, warmup_epoch): + r""" + Get learning rate warming up. + + For the i-th step, the formula of computing warmup_learning_rate[i] is: + + .. math:: + warmup\_learning\_rate[i] = learning\_rate * tmp\_epoch / tmp\_warmup\_epoch + + Where :math:`tmp\_epoch=min(current\_epoch, warmup\_epoch),\ current\_epoch=floor(\frac{i}{step\_per\_epoch})` + + Args: + learning_rate (float): The initial value of learning rate. + warmup_steps (int): The warm up steps of learning rate. + + Inputs: + Tensor. The current step number. + + Returns: + Tensor. The learning rate value for the current step. + + Examples: + >>> learning_rate = 0.1 + >>> total_step = 6 + >>> step_per_epoch = 2 + >>> warmup_epoch = 2 + >>> warmup_lr(learning_rate, total_step, step_per_epoch, warmup_epoch) + [0.0, 0.0, 0.05, 0.05, 0.1, 0.1] + """ + if not isinstance(learning_rate, float): + raise TypeError("learning_rate must be float.") + validator.check_number_range("learning_rate", learning_rate, 0.0, float("inf"), Rel.INC_LEFT, None) + validator.check_integer('warmup_epoch', warmup_epoch, 0, Rel.GT, None) + validator.check_integer('total_step', total_step, 0, Rel.GT, None) + validator.check_integer('step_per_epoch', step_per_epoch, 0, Rel.GT, None) + + function = lambda x, y: (x, min(x, y)) + + lr = [] + for i in range(total_step): + current_epoch = math.floor(i / step_per_epoch) + warmup_epoch, tmp_epoch = function(warmup_epoch, current_epoch) + lr.append(learning_rate * tmp_epoch/ warmup_epoch) + return lr + + __all__ = [ 'piecewise_constant_lr', 'exponential_decay_lr', 'natural_exp_decay_lr', 'inverse_decay_lr', 'cosine_decay_lr', - 'polynomial_decay_lr' + 'polynomial_decay_lr', + 'warmup_lr' ] diff --git a/mindspore/nn/learning_rate_schedule.py b/mindspore/nn/learning_rate_schedule.py new file mode 100644 index 00000000000..b8cde1673b0 --- /dev/null +++ b/mindspore/nn/learning_rate_schedule.py @@ -0,0 +1,368 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Learning rate schedule.""" + +import math + +from ..common import dtype as mstype +from ..ops import operations as P +from .cell import Cell +from .._checkparam import Validator as validator +from .._checkparam import Rel + + +class LearningRateSchedule(Cell): + def __init__(self): + super(LearningRateSchedule, self).__init__() + + def construct(self, global_step): + raise NotImplementedError + + +def _check_inputs(learning_rate, decay_rate, decay_steps, is_stair, cls_name): + validator.check_integer('decay_steps', decay_steps, 0, Rel.GT, cls_name) + validator.check_float_positive('learning_rate', learning_rate, cls_name) + validator.check_float_legal_value('learning_rate', learning_rate, cls_name) + validator.check_float_positive('decay_rate', decay_rate, cls_name) + validator.check_float_legal_value('decay_rate', decay_rate, cls_name) + validator.check_value_type('is_stair', is_stair, [bool], cls_name) + + +class ExponentialDecayLR(LearningRateSchedule): + r""" + Calculate learning rate base on exponential decay function. + + For the i-th step, the formula of computing decayed_learning_rate[i] is: + + .. math:: + decayed\_learning\_rate[i] = learning\_rate * decay\_rate^{p}} + + Where :math:`p = \frac{current\_step}{decay\_steps}`, if `is_stair` is True, The formula + is :math:`p = floor(\frac{current\_step}{decay\_steps})`. + + Args: + learning_rate (float): The initial value of learning rate. + decay_rate (float): The decay rate. + decay_steps (int): A value used to calculate decayed learning rate. + is_stair (bool): If true, learning rate decay once every `decay_steps` times. Default: False. + + Inputs: + Tensor. The current step number. + + Returns: + Tensor. The learning rate value for the current step. + + Examples: + >>> learning_rate = 0.1 + >>> decay_rate = 0.9 + >>> decay_steps = 4 + >>> global_step = Tenosr(2, mstype.int32) + >>> exponential_decay_lr = ExponentialDecayLR(learning_rate, decay_rate, decay_steps) + >>> exponential_decay_lr(global_step) + """ + def __init__(self, learning_rate, decay_rate, decay_steps, is_stair=False): + super(ExponentialDecayLR, self).__init__() + _check_inputs(learning_rate, decay_rate, decay_steps, is_stair, self.cls_name) + self.learning_rate = learning_rate + self.decay_rate = decay_rate + self.decay_steps = decay_steps + self.is_stair = is_stair + self.pow = P.Pow() + self.cast = P.Cast() + + def construct(self, global_step): + p = self.cast(global_step, mstype.float32) / self.decay_steps + if self.is_stair: + p = P.Floor()(p) + return self.learning_rate * self.pow(self.decay_rate, p) + + +class NaturalExpDecayLR(LearningRateSchedule): + r""" + Calculate learning rate base on natural exponential decay function. + + For the i-th step, the formula of computing decayed_learning_rate[i] is: + + .. math:: + decayed\_learning\_rate[i] = learning\_rate * e^{-decay\_rate * p} + + Where :math:`p = \frac{current\_step}{decay\_steps}`, if `is_stair` is True, The formula + is :math:`p = floor(\frac{current\_step}{decay\_steps})`. + + Args: + learning_rate (float): The initial value of learning rate. + decay_rate (float): The decay rate. + decay_steps (int): A value used to calculate decayed learning rate. + is_stair (bool): If true, learning rate decay once every `decay_steps` times. Default: False. + + Inputs: + Tensor. The current step number. + + Returns: + Tensor. The learning rate value for the current step. + + Examples: + >>> learning_rate = 0.1 + >>> decay_rate = 0.9 + >>> decay_steps = 4 + >>> global_step = Tenosr(2, mstype.int32) + >>> natural_exp_decay_lr = NaturalExpDecayLR(learning_rate, decay_rate, decay_steps, True) + >>> natural_exp_decay_lr(global_step) + """ + def __init__(self, learning_rate, decay_rate, decay_steps, is_stair=False): + super(NaturalExpDecayLR, self).__init__() + _check_inputs(learning_rate, decay_rate, decay_steps, is_stair, self.cls_name) + self.learning_rate = learning_rate + self.decay_rate = decay_rate + self.decay_steps = decay_steps + self.is_stair = is_stair + self.math_e = math.e + self.pow = P.Pow() + self.cast = P.Cast() + + def construct(self, global_step): + p = self.cast(global_step, mstype.float32) + if self.is_stair: + p = P.FloorDiv()(p, self.decay_steps) * self.decay_steps + return self.learning_rate * self.pow(self.math_e, -self.decay_rate * p) + + +class InverseDecayLR(LearningRateSchedule): + r""" + Calculate learning rate base on inverse-time decay function. + + For the i-th step, the formula of computing decayed_learning_rate[i] is: + + .. math:: + decayed\_learning\_rate[i] = learning\_rate / (1 + decay\_rate * p} + + Where :math:`p = \frac{current\_step}{decay\_steps}`, if `is_stair` is True, The formula + is :math:`p = floor(\frac{current\_step}{decay\_steps})`. + + Args: + learning_rate (float): The initial value of learning rate. + decay_rate (float): The decay rate. + decay_epoch (int): A value used to calculate decayed learning rate. + is_stair (bool): If true, learning rate decay once every `decay_steps` times. Default: False. + + Inputs: + Tensor. The current step number. + + Returns: + Tensor. The learning rate value for the current step. + + Examples: + >>> learning_rate = 0.1 + >>> decay_rate = 0.9 + >>> decay_steps = 4 + >>> global_step = Tenosr(2, mstype.int32) + >>> inverse_decay_lr = InverseDecayLR(learning_rate, decay_rate, decay_steps, True) + >>> inverse_decay_lr(global_step) + """ + def __init__(self, learning_rate, decay_rate, decay_steps, is_stair=False): + super(InverseDecayLR, self).__init__() + _check_inputs(learning_rate, decay_rate, decay_steps, is_stair, self.cls_name) + self.learning_rate = learning_rate + self.decay_rate = decay_rate + self.decay_steps = decay_steps + self.is_stair = is_stair + self.cast = P.Cast() + + def construct(self, global_step): + p = self.cast(global_step, mstype.float32) / self.decay_steps + if self.is_stair: + p = P.Floor()(p) + return self.learning_rate / (1 + self.decay_rate * p) + + +class CosineDecayLR(LearningRateSchedule): + r""" + Calculate learning rate base on cosine decay function. + + For the i-th step, the formula of computing decayed_learning_rate[i] is: + + .. math:: + decayed\_learning\_rate[i] = min\_learning\_rate + 0.5 * (max\_learning\_rate - min\_learning\_rate) * + (1 + cos(\frac{current\_epoch}{decay\_epoch}\pi)) + + Where :math:`current\_epoch=floor(\frac{i}{step\_per\_epoch})`. + + Args: + min_lr (float): The minimum value of learning rate. + max_lr (float): The maximum value of learning rate. + decay_steps (int): A value used to calculate decayed learning rate. + + Inputs: + Tensor. The current step number. + + Returns: + Tensor. The learning rate value for the current step. + + Examples: + >>> min_lr = 0.01 + >>> max_lr = 0.1 + >>> decay_steps = 4 + >>> global_step = Tenosr(2, mstype.int32) + >>> cosine_decay_lr = CosineDecayLR(min_lr, max_lr, decay_steps) + >>> cosine_decay_lr(global_steps) + """ + def __init__(self, min_lr, max_lr, decay_steps): + super(CosineDecayLR, self).__init__() + if not isinstance(min_lr, float): + raise TypeError("min_lr must be float.") + validator.check_number_range("min_lr", min_lr, 0.0, float("inf"), Rel.INC_LEFT, self.cls_name) + validator.check_float_positive('max_lr', max_lr, self.cls_name) + validator.check_float_legal_value('max_lr', max_lr, self.cls_name) + validator.check_integer('decay_steps', decay_steps, 0, Rel.GT, self.cls_name) + if min_lr >= max_lr: + raise ValueError('`max_lr` should be greater than `min_lr`.') + self.min_lr = min_lr + self.max_lr = max_lr + self.decay_steps = decay_steps + self.math_pi = math.pi + self.delta = 0.5 * (max_lr - min_lr) + self.cos = P.Cos() + self.min = P.Minimum() + self.cast = P.Cast() + + def construct(self, global_step): + p = self.cast(self.min(global_step, self.decay_steps), mstype.float32) + return self.min_lr + self.delta * (1.0 + self.cos(self.math_pi * p / self.decay_steps)) + + +class PolynomialDecayLR(LearningRateSchedule): + r""" + Calculate learning rate base on polynomial decay function. + + For the i-th step, the formula of computing decayed_learning_rate[i] is: + + .. math:: + decayed\_learning\_rate[i] = (learning\_rate - end\_learning\_rate) * + (1 - tmp\_step / tmp\_decay\_step)^{power} + end\_learning\_rate + + Where :math:`tmp\_step=min(global\_step, decay\_step). + If `update_decay_steps` is true, update the value of `tmp_decay_step` every `decay_steps`. The formula + is :math:`tmp\_decay\_step = decay\_step * ceil(global\_step / decay\_steps)` + + Args: + learning_rate (float): The initial value of learning rate. + end_learning_rate (float): The end value of learning rate. + decay_steps (int): A value used to calculate decayed learning rate. + power (float): A value used to calculate decayed learning rate. This parameter should be greater than 0. + update_decay_steps (bool): If true, learning rate decay once every `decay_steps` times. Default: False. + + Inputs: + Tensor. The current step number. + + Returns: + Tensor. The learning rate value for the current step. + + Examples: + >>> learning_rate = 0.1 + >>> end_learning_rate = 0.01 + >>> decay_steps = 4 + >>> power = 0.5 + >>> global_step = Tenosr(2, mstype.int32) + >>> polynomial_decay_lr = PolynomialDecayLR(learning_rate, end_learning_rate, decay_steps, power) + >>> polynomial_decay_lr(global_step) + """ + def __init__(self, learning_rate, end_learning_rate, decay_steps, power, update_decay_steps=False): + super(PolynomialDecayLR, self).__init__() + validator.check_float_positive('learning_rate', learning_rate, None) + validator.check_float_legal_value('learning_rate', learning_rate, None) + if not isinstance(end_learning_rate, float): + raise TypeError("end_learning_rate must be float.") + validator.check_number_range("end_learning_rate", end_learning_rate, 0.0, float("inf"), Rel.INC_LEFT, + self.cls_name) + validator.check_integer('decay_steps', decay_steps, 0, Rel.GT, self.cls_name) + validator.check_value_type('update_decay_steps', update_decay_steps, [bool], self.cls_name) + validator.check_float_positive('power', power, self.cls_name) + validator.check_float_legal_value('power', power, self.cls_name) + + self.decay_steps = decay_steps + self.start_learning_rate = learning_rate + self.end_learning_rate = end_learning_rate + self.diff_learning_rate = learning_rate - end_learning_rate + self.power = power + self.update_decay_steps = update_decay_steps + self.pow = P.Pow() + self.ceil = P.Ceil() + self.min = P.Minimum() + self.max = P.Maximum() + + def construct(self, global_step): + tmp_global_step = P.Cast()(global_step, mstype.float32) + tmp_decay_step = self.decay_steps + if self.update_decay_steps: + tmp_decay_step = tmp_decay_step * self.max(self.ceil(tmp_global_step / tmp_decay_step), 1) + else: + tmp_global_step = self.min(tmp_global_step, tmp_decay_step) + p = tmp_global_step / tmp_decay_step + lr = self.diff_learning_rate * self.pow(1.0 - p, self.power) + self.end_learning_rate + return lr + + +class WarmUpLR(LearningRateSchedule): + r""" + Get learning rate warming up. + + For the i-th step, the formula of computing warmup_learning_rate[i] is: + + .. math:: + warmup\_learning\_rate[i] = learning\_rate * tmp\_step / warmup\_steps + + Where :math:`tmp\_step=min(global\_step, warmup\_steps). + + Args: + learning_rate (float): The initial value of learning rate. + warmup_steps (int): The warm up steps of learning rate. + + Inputs: + Tensor. The current step number. + + Returns: + Tensor. The learning rate value for the current step. + + Examples: + >>> learning_rate = 0.1 + >>> warmup_steps = 2 + >>> global_step = Tenosr(2, mstype.int32) + >>> warmup_lr = WarmUpLR(learning_rate, warmup_steps) + >>> warmup_lr(global_step) + """ + def __init__(self, learning_rate, warmup_steps): + super(WarmUpLR, self).__init__() + if not isinstance(learning_rate, float): + raise TypeError("learning_rate must be float.") + validator.check_number_range("learning_rate", learning_rate, 0.0, float("inf"), Rel.INC_LEFT, self.cls_name) + validator.check_integer('warmup_steps', warmup_steps, 0, Rel.GT, self.cls_name) + self.warmup_steps = warmup_steps + self.learning_rate = learning_rate + self.min = P.Minimum() + self.cast = P.Cast() + + def construct(self, global_step): + warmup_percent = self.cast(self.min(global_step, self.warmup_steps), mstype.float32)/ self.warmup_steps + return self.learning_rate * warmup_percent + + +__all__ = [ + 'ExponentialDecayLR', + 'NaturalExpDecayLR', + 'InverseDecayLR', + 'CosineDecayLR', + 'PolynomialDecayLR', + 'WarmUpLR' +] diff --git a/mindspore/nn/optim/__init__.py b/mindspore/nn/optim/__init__.py index 538c4000678..1bf49d9ec70 100644 --- a/mindspore/nn/optim/__init__.py +++ b/mindspore/nn/optim/__init__.py @@ -20,7 +20,7 @@ The optimizer is used to calculate and update the gradients. """ from .optimizer import Optimizer from .momentum import Momentum -from .adam import Adam, PSAdam, AdamWeightDecay, AdamWeightDecayDynamicLR +from .adam import Adam, PSAdam, AdamWeightDecay from .lamb import Lamb from .sgd import SGD from .lars import LARS @@ -30,4 +30,4 @@ from .proximal_ada_grad import ProximalAdagrad from .lazyadam import LazyAdam __all__ = ['Optimizer', 'Momentum', 'LARS', 'Adam', 'PSAdam', 'AdamWeightDecay', 'LazyAdam', - 'AdamWeightDecayDynamicLR', 'Lamb', 'SGD', 'FTRL', 'PSFTRL', 'RMSProp', 'ProximalAdagrad'] + 'Lamb', 'SGD', 'FTRL', 'PSFTRL', 'RMSProp', 'ProximalAdagrad'] diff --git a/mindspore/nn/optim/adam.py b/mindspore/nn/optim/adam.py index cb8833269cd..e302cc599ac 100755 --- a/mindspore/nn/optim/adam.py +++ b/mindspore/nn/optim/adam.py @@ -30,9 +30,9 @@ _adam_opt = C.MultitypeFuncGraph("adam_opt") _adam_push_pull_opt = C.MultitypeFuncGraph("_adam_push_pull_opt") -@_adam_opt.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", +@_adam_opt.register("Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tensor", "Tensor", "Tensor", "Bool", "Bool") -def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, param, m, v, gradient, decay_flag, optim_filter): +def _update_run_op(beta1, beta2, eps, lr, weight_decay, param, m, v, gradient, decay_flag, optim_filter): """ Update parameters. @@ -41,7 +41,7 @@ def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, param, m, v, grad beta2 (Tensor): The exponential decay rate for the 2nd moment estimates. Should be in range (0.0, 1.0). eps (Tensor): Term added to the denominator to improve numerical stability. Should be greater than 0. lr (Tensor): Learning rate. - weight_decay_tensor (Tensor): Weight decay. Should be in range [0.0, 1.0]. + weight_decay (Number): Weight decay. Should be in range [0.0, 1.0]. param (Tensor): Parameters. m (Tensor): m value of parameters. v (Tensor): v value of parameters. @@ -73,7 +73,7 @@ def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, param, m, v, grad update = next_m / (eps + op_sqrt(next_v)) if decay_flag: - update = op_mul(weight_decay_tensor, param_fp32) + update + update = op_mul(weight_decay, param_fp32) + update update_with_lr = op_mul(lr, update) next_param = param_fp32 - op_reshape(update_with_lr, op_shape(param_fp32)) @@ -85,29 +85,6 @@ def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, param, m, v, grad return gradient -def _check_param_value(beta1, beta2, eps, weight_decay, prim_name): - """Check the type of inputs.""" - validator.check_value_type("beta1", beta1, [float], prim_name) - validator.check_value_type("beta2", beta2, [float], prim_name) - validator.check_value_type("eps", eps, [float], prim_name) - validator.check_value_type("weight_dacay", weight_decay, [float], prim_name) - validator.check_number_range("beta1", beta1, 0.0, 1.0, Rel.INC_NEITHER, prim_name) - validator.check_number_range("beta2", beta2, 0.0, 1.0, Rel.INC_NEITHER, prim_name) - validator.check_number_range("eps", eps, 0.0, float("inf"), Rel.INC_NEITHER, prim_name) - validator.check_number_range("weight_decay", weight_decay, 0.0, float("inf"), Rel.INC_LEFT, prim_name) - - -def _check_learning_rate_value(learning_rate, end_learning_rate, decay_steps, power, prim_name): - """Check the type of inputs.""" - validator.check_value_type("learning_rate", learning_rate, [float], prim_name) - validator.check_number_range("learning_rate", learning_rate, 0.0, float("inf"), Rel.INC_LEFT, prim_name) - validator.check_value_type("end_learning_rate", end_learning_rate, [float], prim_name) - validator.check_number_range("end_learning_rate", end_learning_rate, 0.0, float("inf"), Rel.INC_LEFT, prim_name) - validator.check_float_positive('power', power, prim_name) - validator.check_float_legal_value('power', power, prim_name) - validator.check_integer('decay_steps', decay_steps, 0, Rel.GT, prim_name) - - @_adam_opt.register("Function", "Function", "Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "IndexedSlices", "Tensor", "Tensor", "Tensor", "Bool") def _run_opt_with_sparse(opt, sparse_opt, beta1_power, beta2_power, beta1, beta2, eps, lr, gradient, params, @@ -179,6 +156,16 @@ def _run_push_pull_opt_with_one_number(push, pull, beta1_power, beta2_power, bet return success +def _check_param_value(beta1, beta2, eps, prim_name): + """Check the type of inputs.""" + validator.check_value_type("beta1", beta1, [float], prim_name) + validator.check_value_type("beta2", beta2, [float], prim_name) + validator.check_value_type("eps", eps, [float], prim_name) + validator.check_number_range("beta1", beta1, 0.0, 1.0, Rel.INC_NEITHER, prim_name) + validator.check_number_range("beta2", beta2, 0.0, 1.0, Rel.INC_NEITHER, prim_name) + validator.check_number_range("eps", eps, 0.0, float("inf"), Rel.INC_NEITHER, prim_name) + + class Adam(Optimizer): r""" Updates gradients by Adaptive Moment Estimation (Adam) algorithm. @@ -202,12 +189,9 @@ class Adam(Optimizer): :math:`\epsilon` represents `eps`. Note: - The Adam optimizer supports separating parameter groups. Different parameter groups can set different - `learning_rate` and `weight_decay`. - When separating parameter groups, the weight decay in each group will be applied on the parameters if the - value of weight_decay > 0. When not separating parameter groups, the `weight_decay` in the API will be - applied on the parameters if `weight_decay` > 0 and the 'beta' and 'gamma' are not in the name of parameters. + weight decay is positive. When not separating parameter groups, the `weight_decay` in the API will be applied + on the parameters without 'beta' or 'gamma' in their names if `weight_decay` is positive. To improve parameter groups performance, the customized order of parameters can be supported. @@ -232,14 +216,14 @@ class Adam(Optimizer): the order will be followed in optimizer. There are no other keys in the `dict` and the parameters which in the value of 'order_params' should be in one of group parameters. - learning_rate (Union[int, float, Tensor, Iterable]): A value for the learning rate. When the learning_rate is - Iterable or a Tensor and the dims of the Tensor is 1, - use dynamic learning rate, then the i-th step will - take the i-th value as the learning rate. - When the learning_rate is float or learning_rate is a - Tensor but the dims of the Tensor is 0, use fixed learning - rate. Other cases are not supported. It should be equal to - or greater than 0. Default: 1e-3. + learning_rate (Union[float, Tensor, Iterable, LearningRateSchedule]): A value or graph for the learning rate. + When the learning_rate is a Iterable or a Tensor with dimension of 1, use dynamic learning rate, then + the i-th step will take the i-th value as the learning rate. When the learning_rate is LearningRateSchedule, + use dynamic learning rate, the i-th learning rate will be calculated during the process of training + according to the formula of LearningRateSchedule. When the learning_rate is a float or a Tensor with + dimension of 0, use fixed learning rate. Other cases are not supported. The float learning rate should be + equal to or greater than 0. If the type of `learning_rate` is int, it will be converted to float. + Default: 1e-3. beta1 (float): The exponential decay rate for the 1st moment estimates. Should be in range (0.0, 1.0). Default: 0.9. beta2 (float): The exponential decay rate for the 2nd moment estimates. Should be in range (0.0, 1.0). Default: @@ -272,9 +256,9 @@ class Adam(Optimizer): >>> group_params = [{'params': conv_params, 'weight_decay': 0.01}, >>> {'params': no_conv_params, 'lr': 0.01}, >>> {'order_params': net.trainable_params()}] - >>> optim = nn.Adam(group_params, learning_rate=0.1, weight_decay=0.0) - >>> # The conv_params's parameters will use a learning rate of default value 0.1 and a weight decay of 0.01. - >>> # The no_conv_params's parameters will use a learning rate of 0.01 and a weight decay of default value 0.0. + >>> optm = nn.Adam(group_params, learning_rate=0.1, weight_decay=0.0) + >>> # The conv_params's parameters will use default learning rate of 0.1 and weight decay of 0.01. + >>> # The no_conv_params's parameters will use learning rate of 0.01 and defaule weight decay of 0.0. >>> # The final parameters order in which the optimizer will be followed is the value of 'order_params'. >>> >>> loss = nn.SoftmaxCrossEntropyWithLogits() @@ -284,7 +268,7 @@ class Adam(Optimizer): def __init__(self, params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-8, use_locking=False, use_nesterov=False, weight_decay=0.0, loss_scale=1.0): super(Adam, self).__init__(learning_rate, params, weight_decay, loss_scale) - _check_param_value(beta1, beta2, eps, weight_decay, self.cls_name) + _check_param_value(beta1, beta2, eps, self.cls_name) validator.check_value_type("use_locking", use_locking, [bool], self.cls_name) validator.check_value_type("use_nesterov", use_nesterov, [bool], self.cls_name) @@ -329,7 +313,7 @@ class PSAdam(Optimizer): def __init__(self, params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-8, use_locking=False, use_nesterov=False, weight_decay=0.0, loss_scale=1.0): super(PSAdam, self).__init__(learning_rate, params, weight_decay, loss_scale) - _check_param_value(beta1, beta2, eps, weight_decay, self.cls_name) + _check_param_value(beta1, beta2, eps, self.cls_name) validator.check_value_type("use_locking", use_locking, [bool], self.cls_name) validator.check_value_type("use_nesterov", use_nesterov, [bool], self.cls_name) @@ -375,17 +359,38 @@ class AdamWeightDecay(Optimizer): """ Implements Adam algorithm weight decay fix. + Note: + When separating parameter groups, the weight decay in each group will be applied on the parameters if the + weight decay is posigive. When not separating parameter groups, the `weight_decay` in the API will be applied + on the parameters without 'beta' or 'gamma' in their names if `weight_decay` is positive. + + To improve parameter groups performance, the customized order of parameters can be supported. + Args: - params (list[Parameter]): A list of parameter, which will be updated. The element in `params` - should be class mindspore.Parameter. - learning_rate (Union[float, Tensor, Iterable]): A value for the learning rate. When the learning_rate is - Iterable or a Tensor and the dims of the Tensor is 1, - use dynamic learning rate, then the i-th step will - take the i-th value as the learning rate. - When the learning_rate is float or learning_rate is a Tensor - but the dims of the Tensor is 0, use fixed learning rate. - Other cases are not supported. It should be equal to or - greater than 0. Default: 1e-3. + params (Union[list[Parameter], list[dict]]): When the `params` is a list of `Parameter` which will be updated, + the element in `params` should be class `Parameter`. When the `params` is a list of `dict`, the "params", + "lr", "weight_decay" and "order_params" are the keys can be parsed. + + - params: Required. The value should be a list of `Parameter`. + + - lr: Optional. If "lr" in the keys, the value of corresponding learning rate will be used. + If not, the `learning_rate` in the API will be used. + + - weight_decay: Optional. If "weight_decay" in the keys, the value of corresponding weight decay + will be used. If not, the `weight_decay` in the API will be used. + + - order_params: Optional. If "order_params" in the keys, the value should be the order of parameters and + the order will be followed in optimizer. There are no other keys in the `dict` and the parameters which + in the value of 'order_params' should be in one of group parameters. + + learning_rate (Union[float, Tensor, Iterable, LearningRateSchedule]): A value or graph for the learning rate. + When the learning_rate is a Iterable or a Tensor with dimension of 1, use dynamic learning rate, then + the i-th step will take the i-th value as the learning rate. When the learning_rate is LearningRateSchedule, + use dynamic learning rate, the i-th learning rate will be calculated during the process of training + according to the formula of LearningRateSchedule. When the learning_rate is a float or a Tensor with + dimension of 0, use fixed learning rate. Other cases are not supported. The float learning rate should be + equal to or greater than 0. If the type of `learning_rate` is int, it will be converted to float. + Default: 1e-3. beta1 (float): The exponential decay rate for the 1st moment estimates. Default: 0.9. Should be in range (0.0, 1.0). beta2 (float): The exponential decay rate for the 2nd moment estimates. Default: 0.999. @@ -404,136 +409,48 @@ class AdamWeightDecay(Optimizer): Examples: >>> net = Net() - >>> loss = nn.SoftmaxCrossEntropyWithLogits() + >>> #1) All parameters use the same learning rate and weight decay >>> optim = nn.AdamWeightDecay(params=net.trainable_params()) - >>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None) + >>> + >>> #2) Use parameter groups and set different values + >>> conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params())) + >>> no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params())) + >>> group_params = [{'params': conv_params, 'weight_decay': 0.01}, + >>> {'params': no_conv_params, 'lr': 0.01}, + >>> {'order_params': net.trainable_params()}] + >>> optim = nn.AdamWeightDecay(group_params, learning_rate=0.1, weight_decay=0.0) + >>> # The conv_params's parameters will use default learning rate of 0.1 and weight decay of 0.01. + >>> # The no_conv_params's parameters will use learning rate of 0.01 and default weight decay of 0.0. + >>> # The final parameters order in which the optimizer will be followed is the value of 'order_params'. + >>> + >>> loss = nn.SoftmaxCrossEntropyWithLogits() + >>> model = Model(net, loss_fn=loss, optimizer=optim) """ - def __init__(self, params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-6, weight_decay=0.0, - decay_filter=lambda x: 'beta' not in x.name and 'gamma' not in x.name): - super(AdamWeightDecay, self).__init__(learning_rate, params) - if self.is_group: - raise RuntimeError(f"The {self.cls_name} optimizer cannot support group setting.") - _check_param_value(beta1, beta2, eps, weight_decay, self.cls_name) + def __init__(self, params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-6, weight_decay=0.0): + super(AdamWeightDecay, self).__init__(learning_rate, params, weight_decay) + _check_param_value(beta1, beta2, eps, self.cls_name) self.beta1 = Tensor(np.array([beta1]).astype(np.float32)) self.beta2 = Tensor(np.array([beta2]).astype(np.float32)) self.eps = Tensor(np.array([eps]).astype(np.float32)) - self.weight_decay_tensor = Tensor(np.array([weight_decay]).astype(np.float32)) - - self.params = self.parameters - self.moments1 = self.params.clone(prefix="adam_m", init='zeros') - self.moments2 = self.params.clone(prefix="adam_v", init='zeros') - self.decay_flag = tuple(decay_filter(x) for x in self.params) - + self.moments1 = self.parameters.clone(prefix="adam_m", init='zeros') + self.moments2 = self.parameters.clone(prefix="adam_v", init='zeros') self.hyper_map = C.HyperMap() def construct(self, gradients): lr = self.get_lr() - optim_result = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps, lr, - self.weight_decay_tensor), - self.params, self.moments1, self.moments2, gradients, - self.decay_flag, self.optim_filter) - if self.use_parallel: - optim_result = self.broadcast_params(optim_result) - return optim_result - - -class AdamWeightDecayDynamicLR(Optimizer): - """ - Adam Weight Decay Dynamic Learning Rate (LR). - - Args: - params (list[Parameter]): A list of parameter, which will be updated. The element in `params` - should be class mindspore.Parameter. - decay_steps (int): The steps of the decay. It must be int and positive. - warmup_steps (int): The steps of lr warm up. Default: 0. - learning_rate (float): A floating point value for the learning rate. It should be equal to or - greater than 0. Default: 0.001. - end_learning_rate (float): A floating point value for the end learning rate. It should be equal - to or greater than 0. Default: 0.0001. - power (float): The Power of the polynomial. It must be positive. Default: 10.0. - beta1 (float): The exponential decay rate for the 1st moment estimates. Default: 0.9. - Should be in range (0.0, 1.0). - beta2 (float): The exponential decay rate for the 2nd moment estimates. Default: 0.999. - Should be in range (0.0, 1.0). - eps (float): Term added to the denominator to improve numerical stability. Default: 1e-6. - Should be greater than 0. - weight_decay (float): Weight decay (L2 penalty). It should be in range [0.0, 1.0]. Default: 0.0. - decay_filter (Function): A function to determine whether to apply weight decay on parameters. Default: - lambda x: 'LayerNorm' not in x.name and 'bias' not in x.name. - - Inputs: - - **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`. - - Outputs: - tuple[bool], all elements are True. - - Examples: - >>> net = Net() - >>> loss = nn.SoftmaxCrossEntropyWithLogits() - >>> optim = nn.AdamWeightDecayDynamicLR(params=net.trainable_params(), decay_steps=10) - >>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None) - """ - def __init__(self, - params, - decay_steps, - warmup_steps=0, - learning_rate=0.001, - end_learning_rate=0.0001, - power=10.0, - beta1=0.9, - beta2=0.999, - eps=1e-6, - weight_decay=0.0, - decay_filter=lambda x: 'layernorm' not in x.name.lower() and 'bias' not in x.name.lower()): - super(AdamWeightDecayDynamicLR, self).__init__(0.0, params) if self.is_group: - raise RuntimeError(f"The {self.cls_name} optimizer cannot support group setting.") - _check_param_value(beta1, beta2, eps, weight_decay, self.cls_name) - _check_learning_rate_value(learning_rate, end_learning_rate, decay_steps, power, self.cls_name) - validator.check_integer('warmup_steps', warmup_steps, 0, Rel.GE, self.cls_name) - # turn them to scalar when me support scalar/tensor mix operations - self.global_step = Parameter(initializer(0, [1]), name="global_step") - self.warmup_steps = Tensor(np.array([warmup_steps]).astype(np.float32)) - self.warmup_flag = False - if warmup_steps > 0: - self.warmup_flag = True - self.decay_steps = Tensor(np.array([decay_steps]).astype(np.float32)) - self.end_learning_rate = Tensor(np.array([end_learning_rate]).astype(np.float32)) - self.diff_learning_rate = Tensor(np.array([learning_rate - end_learning_rate]).astype(np.float32)) - self.power = power - self.beta1 = Tensor(np.array([beta1]).astype(np.float32)) - self.beta2 = Tensor(np.array([beta2]).astype(np.float32)) - self.eps = Tensor(np.array([eps]).astype(np.float32)) - self.weight_decay_tensor = Tensor(np.array([weight_decay]).astype(np.float32)) - self.params = self.parameters - self.moments1 = self.params.clone(prefix="adam_m", init='zeros') - self.moments2 = self.params.clone(prefix="adam_v", init='zeros') - self.decay_flag = tuple(decay_filter(x) for x in self.params) - self.hyper_map = C.HyperMap() - self.min = P.Minimum() - self.pow = P.Pow() - self.greater = P.Greater() - self.one = Tensor(np.array([1.0]).astype(np.float32)) - self.cast = P.Cast() - self.start_learning_rate = Tensor(np.array([learning_rate]).astype(np.float32)) - - def construct(self, gradients): - step = self.min(self.global_step, self.decay_steps) - p = step / self.decay_steps - lr = self.diff_learning_rate * self.pow(self.one - p, self.power) + self.end_learning_rate - if self.warmup_flag: - warmup_percent = self.global_step / self.warmup_steps - warmup_lr = self.start_learning_rate * warmup_percent - is_warmup = self.cast(self.greater(self.warmup_steps, self.global_step), mstype.float32) - lr = (self.one - is_warmup) * lr + is_warmup * warmup_lr - optim_result = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps, lr, - self.weight_decay_tensor), - self.params, self.moments1, self.moments2, gradients, - self.decay_flag, self.optim_filter) + if self.is_group_lr: + optim_result = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps), + lr, self.weight_decay, self.parameters, self.moments1, self.moments2, + gradients, self.decay_flags, self.optim_filter) + else: + optim_result = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps, lr), + self.weight_decay, self.parameters, self.moments1, self.moments2, + gradients, self.decay_flags, self.optim_filter) + else: + optim_result = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps, lr, self.weight_decay), + self.parameters, self.moments1, self.moments2, + gradients, self.decay_flags, self.optim_filter) if self.use_parallel: optim_result = self.broadcast_params(optim_result) - added_global_step = self.global_step + self.one - F.control_depend(lr, added_global_step) - self.global_step = added_global_step - return optim_result diff --git a/mindspore/nn/optim/ftrl.py b/mindspore/nn/optim/ftrl.py index e2e20507634..15051c22c04 100644 --- a/mindspore/nn/optim/ftrl.py +++ b/mindspore/nn/optim/ftrl.py @@ -24,9 +24,9 @@ _ftrl_opt = C.MultitypeFuncGraph("ftrl_opt") _ftrl_push_pull_opt = C.MultitypeFuncGraph("ftrl_opt") -@_ftrl_opt.register("Function", "Function", "Tensor", "Number", "Number", "Number", "Tensor", "IndexedSlices", "Tensor", +@_ftrl_opt.register("Function", "Function", "Number", "Number", "Number", "Tensor", "Tensor", "IndexedSlices", "Tensor", "Tensor", "Bool") -def _tensor_run_opt_with_sparse(opt, spars_opt, learning_rate, l1, l2, lr_power, linear, gradient, weight, moment, +def _tensor_run_opt_with_sparse(opt, spars_opt, l1, l2, lr_power, learning_rate, linear, gradient, weight, moment, ps_parameter): """Apply sparse ftrl optimizer to the weight parameter when the gradient is sparse.""" success = True @@ -43,9 +43,9 @@ def _tensor_run_opt_with_sparse(opt, spars_opt, learning_rate, l1, l2, lr_power, return success -@_ftrl_opt.register("Function", "Function", "Tensor", "Number", "Number", "Number", "Tensor", "Tensor", "Tensor", +@_ftrl_opt.register("Function", "Function", "Number", "Number", "Number", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Bool") -def _tensor_run_opt(opt, spars_opt, learning_rate, l1, l2, lr_power, linear, gradient, weight, moment, ps_parameter): +def _tensor_run_opt(opt, spars_opt, l1, l2, lr_power, learning_rate, linear, gradient, weight, moment, ps_parameter): """Apply ftrl optimizer to the weight parameter.""" success = True if ps_parameter: @@ -83,7 +83,7 @@ def _tensor_run_push_pull_opt_with_one_number(push, pull, learning_rate, l1, l2, return success -def _check_param(initial_accum, lr_power, l1, l2, use_locking, weight_decay=0.0, prim_name=None): +def _check_param(initial_accum, lr_power, l1, l2, use_locking, prim_name=None): """Check param.""" validator.check_value_type("initial_accum", initial_accum, [float], prim_name) validator.check_number("initial_accum", initial_accum, 0.0, Rel.GE, prim_name) @@ -99,9 +99,6 @@ def _check_param(initial_accum, lr_power, l1, l2, use_locking, weight_decay=0.0, validator.check_value_type("use_locking", use_locking, [bool], prim_name) - validator.check_value_type("weight_decay", weight_decay, [float], prim_name) - validator.check_number("weight_decay", weight_decay, 0.0, Rel.GE, prim_name) - class FTRL(Optimizer): """ @@ -113,15 +110,34 @@ class FTRL(Optimizer): `_ for engineering document. Note: + When separating parameter groups, the weight decay in each group will be applied on the parameters if the + weight decay is positive. When not separating parameter groups, the `weight_decay` in the API will be applied + on all of the parameters. + + To improve parameter groups performance, the customized order of parameters can be supported. + The sparse strategy is applied while the SparseGatherV2 operator being used for forward network. - The sparse feature is under continuous development. The sparse - behavior is currently performed on the CPU. + The sparse feature is under continuous development. The sparse behavior is currently performed on the CPU. Args: - params (list[Parameter]): A list of parameter, which will be updated. The element in `params` - should be Parameter. + params (Union[list[Parameter], list[dict]]): When the `params` is a list of `Parameter` which will be updated, + the element in `params` should be class `Parameter`. When the `params` is a list of `dict`, the "params", + "lr", "weight_decay" and "order_params" are the keys can be parsed. + + - params: Required. The value should be a list of `Parameter`. + + - lr: Using different learning rate by separating parameters is currently not supported. + + - weight_decay: Optional. If "weight_decay" in the keys, the value of corresponding weight decay + will be used. If not, the `weight_decay` in the API will be used. + + - order_params: Optional. If "order_params" in the keys, the value should be the order of parameters and + the order will be followed in optimizer. There are no other keys in the `dict` and the parameters which + in the value of 'order_params' should be in one of group parameters. + initial_accum (float): The starting value for accumulators, must be zero or positive values. Default: 0.1. - learning_rate (float): The learning rate value, should be positive. Default: 0.001. + learning_rate (float): The learning rate value, should be zero or positive, dynamic learning rate is currently + not supported. Default: 0.001. lr_power (float): Learning rate power controls how the learning rate decreases during training, must be less than or equal to zero. Use fixed learning rate if lr_power is zero. Default: -0.5. l1 (float): l1 regularization strength, must be greater than or equal to zero. Default: 0.0. @@ -139,23 +155,36 @@ class FTRL(Optimizer): Examples: >>> net = Net() + >>> #1) All parameters use the same learning rate and weight decay + >>> optim = nn.FTRL(params=net.trainable_params()) + >>> + >>> #2) Use parameter groups and set different values + >>> conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params())) + >>> no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params())) + >>> group_params = [{'params': conv_params, 'weight_decay': 0.01}, + >>> {'params': no_conv_params}, + >>> {'order_params': net.trainable_params()}] + >>> optim = nn.FTRL(group_params, learning_rate=0.1, weight_decay=0.0) + >>> # The conv_params's parameters will use weight decay of 0.01. + >>> # The no_conv_params's parameters will use default weight decay of 0.0. + >>> # The final parameters order in which the optimizer will be followed is the value of 'order_params'. + >>> >>> loss = nn.SoftmaxCrossEntropyWithLogits() - >>> opt = nn.FTRL(net.trainable_params()) - >>> model = Model(net, loss_fn=loss, optimizer=opt, metrics=None) + >>> model = Model(net, loss_fn=loss, optimizer=optim) """ def __init__(self, params, initial_accum=0.1, learning_rate=0.001, lr_power=-0.5, l1=0.0, l2=0.0, use_locking=False, loss_scale=1.0, weight_decay=0.0): - super(FTRL, self).__init__(learning_rate, params, loss_scale=loss_scale) - if self.is_group: - raise RuntimeError(f"The {self.cls_name} optimizer cannot support group setting.") - _check_param(initial_accum, lr_power, l1, l2, use_locking, weight_decay, self.cls_name) + super(FTRL, self).__init__(learning_rate, params, weight_decay, loss_scale=loss_scale) + if self.dynamic_lr or self.is_group_lr: + raise ValueError('Dynamic learning rate or group learning rate is currently not supported.') + _check_param(initial_accum, lr_power, l1, l2, use_locking, self.cls_name) self.moments = self.parameters.clone(prefix="moments", init=initial_accum) self.linear = self.parameters.clone(prefix="linear", init='zeros') self.l1 = l1 self.l2 = l2 self.lr_power = lr_power - self.weight_decay = weight_decay - self.decay_tf = tuple((lambda: True)() for x in self.parameters) + if not self.is_group: + self.decay_flags = tuple((lambda: True)() for x in self.parameters) self.hyper_map = C.HyperMap() self.opt = P.ApplyFtrl(use_locking=use_locking) self.sparse_opt = P.FusedSparseFtrl(learning_rate, l1, l2, lr_power, use_locking=use_locking) @@ -164,12 +193,11 @@ class FTRL(Optimizer): params = self.parameters moments = self.moments linear = self.linear - lr = self.learning_rate - if self.weight_decay > 0.0: - grads = self.map_(F.partial(_apply_decay, self.weight_decay), self.decay_tf, params, grads) - + grads = self.decay_weight(grads) grads = self.scale_grad(grads) - success = self.map_(F.partial(_ftrl_opt, self.opt, self.sparse_opt, lr, self.l1, self.l2, self.lr_power), + lr = self.get_lr() + + success = self.map_(F.partial(_ftrl_opt, self.opt, self.sparse_opt, self.l1, self.l2, self.lr_power, lr), linear, grads, params, moments, self.ps_parameters) return success @@ -180,7 +208,7 @@ class PSFTRL(Optimizer): super(PSFTRL, self).__init__(learning_rate, params, loss_scale=loss_scale) if self.is_group: raise RuntimeError(f"The {self.cls_name} optimizer cannot support group setting.") - _check_param(initial_accum, lr_power, l1, l2, use_locking, weight_decay, self.cls_name) + _check_param(initial_accum, lr_power, l1, l2, use_locking, self.cls_name) self.moments = self.parameters.clone(prefix="moments", init=initial_accum) self.linear = self.parameters.clone(prefix="linear", init='zeros') self.l1 = l1 diff --git a/mindspore/nn/optim/lamb.py b/mindspore/nn/optim/lamb.py index 143917336d3..d80facfbb18 100755 --- a/mindspore/nn/optim/lamb.py +++ b/mindspore/nn/optim/lamb.py @@ -32,10 +32,9 @@ num_one = Tensor(np.ones([1]), mstype.float32) _lamb_opt = C.MultitypeFuncGraph("lamb_opt") -@_lamb_opt.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", +@_lamb_opt.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tensor", "Tensor", "Tensor", "Bool", "Bool") -def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, global_step, param, m, v, - gradient, decay_flag, optim_filter): +def _update_run_op(beta1, beta2, eps, global_step, lr, weight_decay, param, m, v, gradient, decay_flag, optim_filter): """ Update parameters. @@ -44,7 +43,7 @@ def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, global_step, para beta2 (Tensor): The exponential decay rate for the 2nd moment estimates. Should be in range (0.0, 1.0). eps (Tensor): Term added to the denominator to improve numerical stability. Should be greater than 0. lr (Tensor): Learning rate. - weight_decay_tensor (Tensor): Weight decay. Should be in range [0.0, 1.0]. + weight_decay (Number): Weight decay. Should be in range [0.0, 1.0]. global_step (Tensor): Global step. param (Tensor): Parameters. m (Tensor): m value of parameters. @@ -87,7 +86,7 @@ def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, global_step, para w_norm = op_norm(param_fp32) g_norm = op_norm(gradient_fp32) - g_norm_hat = op_norm(op_mul(next_mm, op_rsqrt(next_vv + eps)) + weight_decay_tensor * param_fp32) + g_norm_hat = op_norm(op_mul(next_mm, op_rsqrt(next_vv + eps)) + weight_decay * param_fp32) zeros = F.zeros_like(w_norm) ones = op_fill(op_dtype(w_norm), op_shape(w_norm), 1.0) trust_ratio = op_select( @@ -99,7 +98,7 @@ def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, global_step, para update = next_mm / (op_sqrt(next_vv) + eps) if decay_flag: - update = update + op_mul(weight_decay_tensor, param_fp32) + update = update + op_mul(weight_decay, param_fp32) update_with_lr = op_mul(op_mul(trust_ratio, lr), update) @@ -116,10 +115,9 @@ def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, global_step, para lamb_opt_graph_kernel = C.MultitypeFuncGraph("lamb_opt_graph_kernel") -@lamb_opt_graph_kernel.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", +@lamb_opt_graph_kernel.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tensor", "Tensor", "Tensor", "Bool") -def _update_run_op_graph_kernel(beta1, beta2, eps, lr, weight_decay_tensor, - global_step, param, m, v, gradient, decay_flag): +def _update_run_op_graph_kernel(beta1, beta2, eps, global_step, lr, weight_decay, param, m, v, gradient, decay_flag): """ Update parameters. @@ -128,7 +126,7 @@ def _update_run_op_graph_kernel(beta1, beta2, eps, lr, weight_decay_tensor, beta2 (Tensor): The exponential decay rate for the 2nd moment estimates. Should be in range (0.0, 1.0). eps (Tensor): Term added to the denominator to improve numerical stability. Should be greater than 0. lr (Tensor): Learning rate. - weight_decay_tensor (Tensor): Weight decay. Should be in range [0.0, 1.0]. + weight_decay (Number): Weight decay. Should be in range [0.0, 1.0]. global_step (Tensor): Global step. param (Tensor): Parameters. m (Tensor): m value of parameters. @@ -157,11 +155,10 @@ def _update_run_op_graph_kernel(beta1, beta2, eps, lr, weight_decay_tensor, i6 = op_cast(num_one, mstype.float32) - op_pow(beta1, i6_ex) i3 = op_cast(num_one, mstype.float32) - op_pow(beta2, i6_ex) i1 = op_square(gradient_fp32) - add3, update = G.LambNextMV()(i1, v, i3, gradient, m, i6, param, beta1, - i9, beta2, x1, weight_decay_tensor, eps) + add3, update = G.LambNextMV()(i1, v, i3, gradient, m, i6, param, beta1, i9, beta2, x1, weight_decay, eps) if decay_flag: - update = update + op_mul(weight_decay_tensor, param_fp32) + update = update + op_mul(weight_decay, param_fp32) w_norm = op_norm(param_fp32) g_norm = op_norm(gradient_fp32) @@ -171,38 +168,18 @@ def _update_run_op_graph_kernel(beta1, beta2, eps, lr, weight_decay_tensor, ones = op_fill(op_dtype(w_norm), op_shape(w_norm), 1.0) tens = op_fill(op_dtype(w_norm), op_shape(w_norm), 10.0) - next_param = G.LambUpdateWithLR()(g_norm, w_norm, g_norm_hat, lr, update, - param, zeros, ones, tens) + next_param = G.LambUpdateWithLR()(g_norm, w_norm, g_norm_hat, lr, update, param, zeros, ones, tens) next_v = F.control_depend(add3, next_param) return next_v -def _check_param_value(decay_steps, warmup_steps, start_learning_rate, - end_learning_rate, power, beta1, beta2, eps, weight_decay, prim_name): - """Check the type of inputs.""" - validator.check_value_type("start_learning_rate", start_learning_rate, [float], prim_name) - validator.check_number_range("start_learning_rate rate", start_learning_rate, 0.0, float("inf"), Rel.INC_LEFT, - prim_name) - validator.check_value_type("end_learning_rate", end_learning_rate, [float], prim_name) - validator.check_number_range("end_learning_rate", end_learning_rate, 0.0, float("inf"), Rel.INC_LEFT, - prim_name) - validator.check_float_positive('power', power, prim_name) - validator.check_float_legal_value('power', power, prim_name) - validator.check_integer('decay_steps', decay_steps, 0, Rel.GT, prim_name) - validator.check_integer('warmup_steps', warmup_steps, 0, Rel.GE, prim_name) +def _check_param_value(beta1, beta2, eps, prim_name): validator.check_value_type("beta1", beta1, [float], prim_name) validator.check_value_type("beta2", beta2, [float], prim_name) validator.check_value_type("eps", eps, [float], prim_name) - validator.check_value_type( - "weight_dacay", weight_decay, [float], prim_name) - validator.check_number_range( - "beta1", beta1, 0.0, 1.0, Rel.INC_NEITHER, prim_name) - validator.check_number_range( - "beta2", beta2, 0.0, 1.0, Rel.INC_NEITHER, prim_name) - validator.check_number_range( - "eps", eps, 0.0, float("inf"), Rel.INC_NEITHER, prim_name) - validator.check_number_range( - "weight_decay", weight_decay, 0.0, float("inf"), Rel.INC_LEFT, prim_name) + validator.check_number_range("beta1", beta1, 0.0, 1.0, Rel.INC_NEITHER, prim_name) + validator.check_number_range("beta2", beta2, 0.0, 1.0, Rel.INC_NEITHER, prim_name) + validator.check_number_range("eps", eps, 0.0, float("inf"), Rel.INC_NEITHER, prim_name) class Lamb(Optimizer): @@ -213,16 +190,37 @@ class Lamb(Optimizer): optimization technique. Refer to the paper `LARGE BATCH OPTIMIZATION FOR DEEP LEARNING: TRAINING BERT IN 76 MINUTES `_. + Note: + When separating parameter groups, the weight decay in each group will be applied on the parameters if the + weight decay is positive. When not separating parameter groups, the `weight_decay` in the API will be applied + on the parameters without 'beta' or 'gamma' in their names if `weight_decay` is positive. + + To improve parameter groups performance, the customized order of parameters can be supported. + Args: - params (list[Parameter]): A list of parameter, which will be updated. The element in `params` - should be class mindspore.Parameter. - decay_steps (int): The steps of the lr decay. Should be equal to or greater than 1. - warmup_steps (int): The steps of lr warm up. Should be equal to or greater than 0. Default: 0. - start_learning_rate (float): A floating point value for the learning rate. Should be equal to - or greater than 0. Default: 0.1. - end_learning_rate (float): A floating point value for the end learning rate. Should be equal to - or greater than 0. Default: 0.0001. - power (float): The power of the polynomial. It must be positive. Default: 1.0. + params (Union[list[Parameter], list[dict]]): When the `params` is a list of `Parameter` which will be updated, + the element in `params` should be class `Parameter`. When the `params` is a list of `dict`, the "params", + "lr", "weight_decay" and "order_params" are the keys can be parsed. + + - params: Required. The value should be a list of `Parameter`. + + - lr: Optional. If "lr" in the keys, the value of corresponding learning rate will be used. + If not, the `learning_rate` in the API will be used. + + - weight_decay: Optional. If "weight_decay" in the keys, the value of corresponding weight decay + will be used. If not, the `weight_decay` in the API will be used. + + - order_params: Optional. If "order_params" in the keys, the value should be the order of parameters and + the order will be followed in optimizer. There are no other keys in the `dict` and the parameters which + in the value of 'order_params' should be in one of group parameters. + + learning_rate (Union[float, Tensor, Iterable, LearningRateSchedule]): A value or graph for the learning rate. + When the learning_rate is a Iterable or a Tensor with dimension of 1, use dynamic learning rate, then + the i-th step will take the i-th value as the learning rate. When the learning_rate is LearningRateSchedule, + use dynamic learning rate, the i-th learning rate will be calculated during the process of training + according to the formula of LearningRateSchedule. When the learning_rate is a float or a Tensor with + dimension of 0, use fixed learning rate. Other cases are not supported. The float learning rate should be + equal to or greater than 0. If the type of `learning_rate` is int, it will be converted to float. beta1 (float): The exponential decay rate for the 1st moment estimates. Default: 0.9. Should be in range (0.0, 1.0). beta2 (float): The exponential decay rate for the 2nd moment estimates. Default: 0.999. @@ -241,90 +239,84 @@ class Lamb(Optimizer): Examples: >>> net = Net() + >>> #1) All parameters use the same learning rate and weight decay + >>> optim = nn.Lamb(params=net.trainable_params()) + >>> + >>> #2) Use parameter groups and set different values + >>> poly_decay_lr = learning_rate_schedule.PolynomialDecayLR() + >>> conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params())) + >>> no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params())) + >>> group_params = [{'params': conv_params, 'weight_decay': 0.01}, + >>> {'params': no_conv_params, 'lr': poly_decay_lr}, + >>> {'order_params': net.trainable_params(0.01, 0.0001, 10, 0.5)}] + >>> optim = nn.Lamb(group_params, learning_rate=0.1, weight_decay=0.0) + >>> # The conv_params's parameters will use default learning rate of 0.1 and weight decay of 0.01. + >>> # The no_conv_params's parameters will use dynamic learning rate of poly decay learning rate and default + >>> # weight decay of 0.0. + >>> # The final parameters order in which the optimizer will be followed is the value of 'order_params'. + >>> >>> loss = nn.SoftmaxCrossEntropyWithLogits() - >>> optim = nn.Lamb(params=net.trainable_params(), decay_steps=10) - >>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None) + >>> model = Model(net, loss_fn=loss, optimizer=optim) """ - def __init__(self, - params, - decay_steps, - warmup_steps=0, - start_learning_rate=0.1, - end_learning_rate=0.0001, - power=1.0, - beta1=0.9, - beta2=0.999, - eps=1e-6, - weight_decay=0.0, - decay_filter=lambda x: 'layernorm' not in x.name.lower() and 'bias' not in x.name.lower()): - super(Lamb, self).__init__(0.0, params) - if self.is_group: - raise RuntimeError( - f"The {self.cls_name} optimizer cannot support group setting.") - _check_param_value(decay_steps, warmup_steps, start_learning_rate, end_learning_rate, - power, beta1, beta2, eps, weight_decay, self.cls_name) + def __init__(self, params, learning_rate, beta1=0.9, beta2=0.999, eps=1e-6, weight_decay=0.0): + super(Lamb, self).__init__(learning_rate, params, weight_decay) + _check_param_value(beta1, beta2, eps, self.cls_name) # turn them to scalar when me support scalar/tensor mix operations - self.global_step = Parameter(initializer(0, [1]), name="global_step") - - self.warmup_steps = Tensor(np.array([warmup_steps]).astype(np.float32)) - self.warmup_flag = False - if warmup_steps > 0: - self.warmup_flag = True - self.decay_steps = Tensor(np.array([decay_steps]).astype(np.float32)) - self.start_learning_rate = Tensor( - np.array([start_learning_rate]).astype(np.float32)) - self.end_learning_rate = Tensor( - np.array([end_learning_rate]).astype(np.float32)) - self.diff_learning_rate = Tensor( - np.array([start_learning_rate - end_learning_rate]).astype(np.float32)) - self.power = power self.beta1 = Tensor(np.array([beta1]).astype(np.float32)) self.beta2 = Tensor(np.array([beta2]).astype(np.float32)) self.eps = Tensor(np.array([eps]).astype(np.float32)) - self.weight_decay_tensor = Tensor( - np.array([weight_decay]).astype(np.float32)) self.params = self.parameters self.moments1 = self.params.clone(prefix="lamb_m", init='zeros') self.moments2 = self.params.clone(prefix="lamb_v", init='zeros') - self.decay_flag = tuple(decay_filter(x) for x in self.params) + if not self.dynamic_lr: + self.global_step = Parameter(initializer(0, [1]), name='global_step') + self.assignadd = P.AssignAdd() self.hyper_map = C.HyperMap() - self.min = P.Minimum() - self.pow = P.Pow() - self.greater = P.Greater() - self.one = Tensor(np.array([1.0]).astype(np.float32)) - self.cast = P.Cast() self.enable_graph_kernel = context.get_context("enable_graph_kernel") def construct(self, gradients): - step = self.min(self.global_step, self.decay_steps) - p = step / self.decay_steps - lr = self.diff_learning_rate * \ - self.pow(self.one - p, self.power) + self.end_learning_rate - if self.warmup_flag: - warmup_percent = self.global_step / self.warmup_steps - warmup_lr = self.start_learning_rate * warmup_percent - is_warmup = self.cast(self.greater( - self.warmup_steps, self.global_step), mstype.float32) - lr = (self.one - is_warmup) * lr + is_warmup * warmup_lr + lr = self.get_lr() if self.enable_graph_kernel: - optim_result = self.hyper_map(F.partial(lamb_opt_graph_kernel, - self.beta1, self.beta2, self.eps, lr, - self.weight_decay_tensor, self.global_step), - self.params, self.moments1, self.moments2, gradients, self.decay_flag) + if self.is_group: + if self.is_group_lr: + optim_result = self.hyper_map(F.partial(lamb_opt_graph_kernel, self.beta1, self.beta2, self.eps, + self.global_step), + lr, self.weight_decay, self.params, self.moments1, self.moments2, + gradients, self.decay_flags) + else: + optim_result = self.hyper_map(F.partial(lamb_opt_graph_kernel, self.beta1, self.beta2, self.eps, + self.global_step, lr), + self.weight_decay, self.params, self.moments1, self.moments2, + gradients, self.decay_flags) + else: + optim_result = self.hyper_map(F.partial(lamb_opt_graph_kernel, self.beta1, self.beta2, self.eps, + self.global_step, lr, self.weight_decay), + self.params, self.moments1, self.moments2, gradients, self.decay_flags) else: - optim_result = self.hyper_map(F.partial(_lamb_opt, - self.beta1, self.beta2, self.eps, lr, - self.weight_decay_tensor, self.global_step), - self.params, self.moments1, self.moments2, gradients, - self.decay_flag, self.optim_filter) + if self.is_group: + if self.is_group_lr: + optim_result = self.hyper_map(F.partial(_lamb_opt, self.beta1, self.beta2, self.eps, + self.global_step), + lr, self.weight_decay, self.params, self.moments1, self.moments2, + gradients, self.decay_flags, self.optim_filter) + else: + optim_result = self.hyper_map(F.partial(_lamb_opt, self.beta1, self.beta2, self.eps, + self.global_step, lr), + self.weight_decay, self.params, self.moments1, self.moments2, + gradients, self.decay_flags, self.optim_filter) + else: + optim_result = self.hyper_map(F.partial(_lamb_opt, self.beta1, self.beta2, self.eps, + self.global_step, lr, self.weight_decay), + self.params, self.moments1, self.moments2, gradients, + self.decay_flags, self.optim_filter) + if self.use_parallel: optim_result = self.broadcast_params(optim_result) - added_global_step = self.global_step + self.one - F.control_depend(lr, added_global_step) - self.global_step = added_global_step + if not self.dynamic_lr: + F.control_depend(lr, self.assignadd(self.global_step, 1)) return optim_result diff --git a/mindspore/nn/optim/lars.py b/mindspore/nn/optim/lars.py index 7b05b372eb2..91ca9a4b22a 100755 --- a/mindspore/nn/optim/lars.py +++ b/mindspore/nn/optim/lars.py @@ -38,14 +38,14 @@ def _tensor_run_opt(lars, learning_rate, weight_decay, gradient, weight, decay_f return gradient + def _check_param_value(optimizer, epsilon, coefficient, use_clip, prim_name): validator.check_value_type("optimizer", optimizer, Optimizer, prim_name) - if "Adam" in optimizer.cls_name or "Lamb" in optimizer.cls_name: - raise TypeError("LARS can not be used with ", optimizer.cls_name) validator.check_value_type("epsilon", epsilon, [float], prim_name) validator.check_value_type("coefficient", coefficient, [float], prim_name) validator.check_value_type("use_clip", use_clip, [bool], prim_name) + class LARS(Optimizer): """ Implements the LARS algorithm with LARSUpdate Operator. @@ -81,45 +81,71 @@ class LARS(Optimizer): super(LARS, self).__init__(0.0, [Parameter(Tensor(0.0), name="fake_param")]) _check_param_value(optimizer, epsilon, coefficient, use_clip, self.cls_name) self.opt = optimizer + self.parameters = optimizer.parameters + self.use_clip = use_clip + self.lars_flag = tuple(lars_filter(x) for x in self.parameters) + self.is_group = optimizer.is_group + self.learning_rate = Parameter(Tensor(0.0, dtype=mstype.float32), name="fake_lr") + self.decay_flags = optimizer.decay_flags + self.reciprocal_scale = optimizer.reciprocal_scale + self.hyper_map = C.HyperMap() self.lars = P.LARSUpdate(epsilon, coefficient, use_clip) self.cast = P.Cast() - self.parameters = optimizer.parameters - if use_clip is True: - self.learning_rate = optimizer.learning_rate + + if use_clip: + self.is_group_lr = optimizer.is_group_lr self.dynamic_lr = optimizer.dynamic_lr - self.gather = optimizer.gather - self.assignadd = optimizer.assignadd + self.origin_learning_rate = optimizer.learning_rate self.global_step = optimizer.global_step - else: - self.learning_rate = Parameter(Tensor(0.0, dtype=mstype.float32), name="fake_lr") - self.reciprocal_scale = optimizer.reciprocal_scale - optimizer.reciprocal_scale = 1.0 - self.is_group = optimizer.is_group + if self.is_group_lr and self.dynamic_lr: + raise ValueError('Grouped dynamic learning rate is currently not supported for the inputs optimizer ' \ + 'of lars.') + if self.is_group: self.weight_decay = tuple(map(lambda x: x / optimizer.loss_scale, optimizer.weight_decay)) + optimizer.weight_decay = tuple(map(lambda x: 0.0, optimizer.weight_decay)) else: self.weight_decay = optimizer.weight_decay / optimizer.loss_scale + optimizer.weight_decay = 0.0 + + optimizer.decay_flags = tuple(map(lambda x: False, self.decay_flags)) + optimizer.reciprocal_scale = 1.0 optimizer.exec_weight_decay = False - optimizer.weight_decay = 0.0 - self.decay_flags = optimizer.decay_flags - self.lars_flag = tuple(lars_filter(x) for x in self.parameters) - self.hyper_map = C.HyperMap() + + def _get_lr(self): + """Get the learning rate of current step.""" + lr = self.origin_learning_rate + if self.dynamic_lr: + if self.is_group_lr: + lr = () + for learning_rate in self.origin_learning_rate: + current_dynamic_lr = learning_rate(self.global_step) + lr += (current_dynamic_lr,) + else: + lr = self.origin_learning_rate(self.global_step) + + return lr def construct(self, gradients): params = self.parameters - if self.dynamic_lr: - lr = self.gather(self.learning_rate, self.global_step, 0) - F.control_depend(lr, self.assignadd(self.global_step, 1)) + if self.use_clip: + lr = self._get_lr() else: lr = self.learning_rate + if self.reciprocal_scale != 1.0: gradients = self.hyper_map(F.partial(_grad_scale, self.reciprocal_scale), gradients) + if self.is_group: - grad_t = self.hyper_map(F.partial(_lars_opt, self.lars, lr), self.weight_decay, - gradients, params, self.decay_flags, self.lars_flag) + if self.is_group_lr: + gradients = self.hyper_map(F.partial(_lars_opt, self.lars), lr, self.weight_decay, + gradients, params, self.decay_flags, self.lars_flag) + else: + gradients = self.hyper_map(F.partial(_lars_opt, self.lars, lr), self.weight_decay, + gradients, params, self.decay_flags, self.lars_flag) else: - grad_t = self.hyper_map(F.partial(_lars_opt, self.lars, lr, self.weight_decay), - gradients, params, self.decay_flags, self.lars_flag) - success = self.opt(grad_t) + gradients = self.hyper_map(F.partial(_lars_opt, self.lars, lr, self.weight_decay), + gradients, params, self.decay_flags, self.lars_flag) + success = self.opt(gradients) return success diff --git a/mindspore/nn/optim/lazyadam.py b/mindspore/nn/optim/lazyadam.py index d784c88e448..b8341babba7 100644 --- a/mindspore/nn/optim/lazyadam.py +++ b/mindspore/nn/optim/lazyadam.py @@ -84,12 +84,11 @@ class LazyAdam(Optimizer): :math:`\epsilon` represents `eps`. Note: - The LazyAdam optimizer supports separating parameter groups. Different parameter groups can set different - `learning_rate` and `weight_decay`. - When separating parameter groups, the weight decay in each group will be applied on the parameters if the - value of weight_decay > 0. When not separating parameter groups, the `weight_decay` in the API will be - applied on the parameters if `weight_decay` > 0 and the 'beta' and 'gamma' are not in the name of parameters. + weight decay is positive. When not separating parameter groups, the `weight_decay` in the API will be applied + on the parameters without 'beta' or 'gamma' in their names if `weight_decay` is positive. + + To improve parameter groups performance, the customized order of parameters can be supported. The sparse strategy is applied while the SparseGatherV2 operator being used for forward network. The sparse behavior, to be notice, is not equivalent to the @@ -113,13 +112,14 @@ class LazyAdam(Optimizer): the order will be followed in optimizer. There are no other keys in the `dict` and the parameters which in the value of 'order_params' should be in one of group parameters. - learning_rate (Union[float, Tensor, Iterable]): A value for the learning rate. When the learning_rate is - Iterable or a Tensor and the dims of the Tensor is 1, - use dynamic learning rate, then the i-th step will - take the i-th value as the learning rate. - When the learning_rate is float or learning_rate is a Tensor - but the dims of the Tensor is 0, use fixed learning rate. - Other cases are not supported. Default: 1e-3. + learning_rate (Union[float, Tensor, Iterable, LearningRateSchedule]): A value or graph for the learning rate. + When the learning_rate is a Iterable or a Tensor with dimension of 1, use dynamic learning rate, then + the i-th step will take the i-th value as the learning rate. When the learning_rate is LearningRateSchedule, + use dynamic learning rate, the i-th learning rate will be calculated during the process of training + according to the formula of LearningRateSchedule. When the learning_rate is a float or a Tensor with + dimension of 0, use fixed learning rate. Other cases are not supported. The float learning rate should be + equal to or greater than 0. If the type of `learning_rate` is int, it will be converted to float. + Default: 1e-3. beta1 (float): The exponential decay rate for the 1st moment estimates. Should be in range (0.0, 1.0). Default: 0.9. beta2 (float): The exponential decay rate for the 2nd moment estimates. Should be in range (0.0, 1.0). Default: @@ -153,9 +153,9 @@ class LazyAdam(Optimizer): >>> group_params = [{'params': conv_params, 'weight_decay': 0.01}, >>> {'params': no_conv_params, 'lr': 0.01}, >>> {'order_params': net.trainable_params()}] - >>> optim = nn.LazyAdam(group_params, learning_rate=0.1, weight_decay=0.0) - >>> # The conv_params's parameters will use a learning rate of default value 0.1 and a weight decay of 0.01. - >>> # The no_conv_params's parameters will use a learning rate of 0.01 and a weight decay of default value 0.0. + >>> opt = nn.LazyAdam(group_params, learning_rate=0.1, weight_decay=0.0) + >>> # The conv_params's parameters will use default learning rate of 0.1 and weight decay of 0.01. + >>> # The no_conv_params's parameters will use learning rate of 0.01 and default weight decay of 0.0. >>> # The final parameters order in which the optimizer will be followed is the value of 'order_params'. >>> >>> loss = nn.SoftmaxCrossEntropyWithLogits() diff --git a/mindspore/nn/optim/momentum.py b/mindspore/nn/optim/momentum.py index 3beafa0775f..61c06591944 100755 --- a/mindspore/nn/optim/momentum.py +++ b/mindspore/nn/optim/momentum.py @@ -47,12 +47,9 @@ class Momentum(Optimizer): Refer to the paper on the importance of initialization and momentum in deep learning for more details. Note: - The Momentum optimizer supports separating parameter groups. Different parameter groups can set different - `learning_rate` and `weight_decay`. - When separating parameter groups, the weight decay in each group will be applied on the parameters if the - value of weight_decay > 0. When not separating parameter groups, the `weight_decay` in the API will be - applied on the parameters if `weight_decay` > 0 and the 'beta' and 'gamma' are not in the name of parameters. + weight decay is positive. When not separating parameter groups, the `weight_decay` in the API will be applied + on the parameters without 'beta' or 'gamma' in their names if `weight_decay` is positive. To improve parameter groups performance, the customized order of parameters can be supported. @@ -73,14 +70,13 @@ class Momentum(Optimizer): the order will be followed in optimizer. There are no other keys in the `dict` and the parameters which in the value of 'order_params' should be in one of group parameters. - learning_rate (Union[int, float, Tensor, Iterable]): A value for the learning rate. When the learning_rate is - Iterable or a Tensor and the dims of the Tensor is 1, - use dynamic learning rate, then the i-th step will - take the i-th value as the learning rate. - When the learning_rate is float or learning_rate is a - Tensor but the dims of the Tensor is 0, use fixed learning - rate. Other cases are not supported. It should be equal to - or greater than 0.0. + learning_rate (Union[float, Tensor, Iterable, LearningRateSchedule]): A value or graph for the learning rate. + When the learning_rate is a Iterable or a Tensor with dimension of 1, use dynamic learning rate, then + the i-th step will take the i-th value as the learning rate. When the learning_rate is LearningRateSchedule, + use dynamic learning rate, the i-th learning rate will be calculated during the process of training + according to the formula of LearningRateSchedule. When the learning_rate is a float or a Tensor with + dimension of 0, use fixed learning rate. Other cases are not supported. The float learning rate should be + equal to or greater than 0. If the type of `learning_rate` is int, it will be converted to float. momentum (float): Hyperparameter of type float, means momentum for the moving average. It should be at least 0.0. weight_decay (int, float): Weight decay (L2 penalty). It should be in range [0.0, 1.0]. Default: 0.0. diff --git a/mindspore/nn/optim/optimizer.py b/mindspore/nn/optim/optimizer.py index 54d4f44f871..9379e395aef 100755 --- a/mindspore/nn/optim/optimizer.py +++ b/mindspore/nn/optim/optimizer.py @@ -20,6 +20,7 @@ import numpy as np import mindspore from mindspore.ops import functional as F, composite as C, operations as P from mindspore.nn.cell import Cell +from mindspore.nn.layer.container import CellList from mindspore.common.parameter import Parameter, ParameterTuple from mindspore.common.initializer import initializer from mindspore.common.tensor import Tensor, IndexedSlices @@ -30,6 +31,7 @@ from mindspore import log as logger from mindspore.parallel._utils import _get_global_rank, _get_device_num, _get_parallel_mode from mindspore.train.parallel_utils import ParallelMode from mindspore import context +from mindspore.nn.learning_rate_schedule import LearningRateSchedule __all__ = ['Optimizer'] @@ -44,25 +46,22 @@ class Optimizer(Cell): This class defines the API to add Ops to train a model. Never use this class directly, but instead instantiate one of its subclasses. - Some optimizers support separating parameter groups. Different parameter groups can set different - `learning_rate` and `weight_decay`. + Different parameter groups can set different `learning_rate` and `weight_decay`. When separating parameter groups, the weight decay in each group will be applied on the parameters if the - value of weight_decay > 0. When not separating parameter groups, the `weight_decay` in the API will be - applied on the parameters if `weight_decay` > 0 and the 'beta' and 'gamma' are not in the name of parameters. + weight_decay is positive. For most optimizer, when not separating parameters, the `weight_decay` in the API will + be applied on the parameters without 'beta' or 'gamma' in their names if `weight_decay` is positive. To improve parameter groups performance, the customized order of parameters can be supported. Args: - learning_rate (Union[float, Tensor, Iterable]): A value for the learning rate. When the learning_rate is - Iterable or a Tensor and the dims of the Tensor is 1, - use dynamic learning rate, then the i-th step will - take the i-th value as the learning rate. - When the learning_rate is float or learning_rate is a Tensor - but the dims of the Tensor is 0, use fixed learning rate. - Other cases are not supported. It should be equal to or greater - than 0. If the type of `learning_rate` input is int, it will be - converted to float. + learning_rate (Union[float, Tensor, Iterable, LearningRateSchedule]): A value or graph for the learning + rate. When the learning_rate is a Iterable or a Tensor with dimension of 1, use dynamic learning rate, then + the i-th step will take the i-th value as the learning rate. When the learning_rate is LearningRateSchedule, + use dynamic learning rate, the i-th learning rate will be calculated during the process of training + according to the formula of LearningRateSchedule. When the learning_rate is a float or a Tensor with + dimension of 0, use fixed learning rate. Other cases are not supported. The float learning rate should be + equal to or greater than 0. If the type of `learning_rate` is int, it will be converted to float. parameters (Union[list[Parameter], list[dict]]): When the `parameters` is a list of `Parameter` which will be updated, the element in `parameters` should be class `Parameter`. When the `parameters` is a list of `dict`, the "params", "lr", "weight_decay" and "order_params" are the keys can be parsed. @@ -104,32 +103,17 @@ class Optimizer(Cell): loss_scale = float(loss_scale) validator.check_value_type("loss_scale", loss_scale, [float], self.cls_name) validator.check_number_range("loss_scale", loss_scale, 1.0, float("inf"), Rel.INC_LEFT, self.cls_name) + self.loss_scale = loss_scale - if isinstance(weight_decay, int): - weight_decay = float(weight_decay) - validator.check_value_type("weight_decay", weight_decay, [float], self.cls_name) - validator.check_number_range("weight_decay", weight_decay, 0.0, 1.0, Rel.INC_BOTH, self.cls_name) + weight_decay = self._preprocess_weight_decay(weight_decay) + self.dynamic_lr = False + self.assignadd = None + self.global_step = None self.is_group = False self.is_group_lr = False self.is_group_params_ordered = False - self.loss_scale = loss_scale - if isinstance(learning_rate, int): - learning_rate = float(learning_rate) - if isinstance(learning_rate, float): - self.dynamic_lr = False - self.gather = None - self.assignadd = None - self.global_step = None - self.scalar_lr = learning_rate - else: - self.dynamic_lr = True - self.gather = P.GatherV2() - self.assignadd = P.AssignAdd() - self.global_step = Parameter(initializer(0, [1], mindspore.int32), name='global_step') - self.scalar_lr = None - - learning_rate = self._get_single_lr(learning_rate) + learning_rate = self._preprocess_single_lr(learning_rate) if isinstance(parameters[0], dict): self.is_group = True self.group_params = [] @@ -137,32 +121,40 @@ class Optimizer(Cell): self.group_weight_decay = [] self._init_group_params(parameters, learning_rate, weight_decay) - if self.is_group_lr: - self.learning_rate = ParameterTuple(self.group_lr) - else: - self.learning_rate = Parameter(Tensor(learning_rate, mstype.float32), name="learning_rate") + # The final value of dynamic_lr can be determined after the process of parse_single_lr and init_group_params + if self.dynamic_lr: + self.assignadd = P.AssignAdd() + self.global_step = Parameter(initializer(0, [1], mindspore.int32), name='global_step') + if self.is_group_lr: + if self.dynamic_lr: + self.learning_rate = CellList(self.group_lr) + else: + self.learning_rate = tuple(self.group_lr) + else: + self.learning_rate = self._build_single_lr(learning_rate, 'learning_rate') if self.is_group: self.parameters = ParameterTuple(self.group_params) self.weight_decay = tuple(self.group_weight_decay) decay_filter = lambda x: x > 0 self.decay_flags = tuple(decay_filter(x) for x in self.weight_decay) + self.exec_weight_decay = any(self.decay_flags) else: self.parameters = ParameterTuple(parameters) self.weight_decay = weight_decay * loss_scale decay_filter = lambda x: 'beta' not in x.name and 'gamma' not in x.name self.decay_flags = tuple(decay_filter(x) for x in self.parameters) + self.exec_weight_decay = self.weight_decay > 0 ps_filter = lambda x: x.is_param_ps self.ps_parameters = tuple(ps_filter(x) for x in self.parameters) self.reciprocal_scale = 1.0 / loss_scale - self.exec_weight_decay = any(self.decay_flags) self.param_length = len(self.parameters) self.map_ = C.Map() use_parallel = context.get_auto_parallel_context("enable_parallel_optimizer") self.use_parallel = use_parallel if use_parallel: - if self.cls_name not in ["Lamb", "AdamWeightDecayDynamicLR", "AdamWeightDecay"]: + if self.cls_name not in ["Lamb", "AdamWeightDecay"]: raise RuntimeError("Optimizer segmentation does not support optimizer {}".format(self.cls_name)) if _get_parallel_mode() != ParallelMode.DATA_PARALLEL: raise RuntimeError("Optimizer segmentation does not support parallel mode {}".format @@ -193,13 +185,12 @@ class Optimizer(Cell): Returns: tuple[Tensor], The gradients after weight decay. """ - params = self.parameters - if self.is_group: - if self.exec_weight_decay: + if self.exec_weight_decay: + params = self.parameters + if self.is_group: gradients = self.map_(F.partial(_apply_decay), self.weight_decay, self.decay_flags, params, gradients) - else: - if self.weight_decay > 0: + else: gradients = self.map_(F.partial(_apply_decay, self.weight_decay), self.decay_flags, params, gradients) @@ -225,24 +216,53 @@ class Optimizer(Cell): return gradients - def _get_single_lr(self, learning_rate): - """Get learning rate in Tensor type.""" - if isinstance(learning_rate, float): + def _preprocess_weight_decay(self, weight_decay): + """Check weight decay, and convert int to float.""" + if isinstance(weight_decay, (float, int)): + weight_decay = float(weight_decay) + validator.check_number_range("weight_decay", weight_decay, 0.0, 1.0, Rel.INC_BOTH, self.cls_name) + return weight_decay + raise TypeError("Weight decay should be int or float.") + + def _preprocess_single_lr(self, learning_rate): + """Check lr value, and convert lr to a float, a Tensor or a LearningRateSchedule.""" + if isinstance(learning_rate, (float, int)): + learning_rate = float(learning_rate) validator.check_number_range("learning rate", learning_rate, 0.0, float("inf"), Rel.INC_LEFT, self.cls_name) - lr = Tensor(learning_rate, mstype.float32) - elif isinstance(learning_rate, Iterable): - lr = Tensor(np.array(list(learning_rate)).astype(np.float32)) - elif isinstance(learning_rate, Tensor): + return learning_rate + if isinstance(learning_rate, Tensor) and learning_rate.dim() == 0: + return learning_rate + + self.dynamic_lr = True + if isinstance(learning_rate, Iterable): + return Tensor(np.array(list(learning_rate)).astype(np.float32)) + if isinstance(learning_rate, Tensor): if learning_rate.dim() > 1: - raise ValueError("Learning rate should be a 0 or 1 dim `Tensor`," + raise ValueError("The dim of `Tensor` type Learning rate should be a 0 or 1," f"but got {learning_rate.dim()}.") if learning_rate.dim() == 1 and learning_rate.size() < 2: - logger.warning("If want to use the dynamic learning rate, please make sure that the number " - "of elements in the list, tuple or tensor passed is greater than 1.") - lr = learning_rate - else: - raise TypeError("Learning rate should be float, Tensor or Iterable.") - return lr + logger.warning("If use `Tensor` type dynamic learning rate, please make sure that the number" + "of elements in the tensor passed is greater than 1.") + return learning_rate + if isinstance(learning_rate, LearningRateSchedule): + return learning_rate + raise TypeError("Learning rate should be int, float, Tensor, Iterable or LearningRateSchedule.") + + def _build_single_lr(self, learning_rate, name): + """Build learning rate value, convert learning rate to a Parameter or a LearningRateSchedule.""" + if isinstance(learning_rate, float): + learning_rate = Parameter(Tensor(learning_rate, mstype.float32), name) + if self.is_group_lr and self.dynamic_lr: + learning_rate = _ConvertToCell(learning_rate) + return learning_rate + if isinstance(learning_rate, Tensor) and learning_rate.dim() == 0: + learning_rate = Parameter(learning_rate, name) + if self.is_group_lr and self.dynamic_lr: + learning_rate = _ConvertToCell(learning_rate) + return learning_rate + if isinstance(learning_rate, Tensor) and learning_rate.dim() == 1: + return _IteratorLearningRate(learning_rate, name) + return learning_rate def _check_group_params(self, parameters): """Check group params.""" @@ -270,13 +290,12 @@ class Optimizer(Cell): def _parse_group_params(self, parameters, learning_rate): """Parse group params.""" self._check_group_params(parameters) - if self.dynamic_lr: - dynamic_lr_length = learning_rate.size() + if isinstance(learning_rate, Tensor) and learning_rate.dim() == 1: + tensor_lr_length = learning_rate.size() else: - dynamic_lr_length = 0 + tensor_lr_length = 0 for group_param in parameters: - lr_length = dynamic_lr_length if 'order_params' in group_param.keys(): if len(group_param.keys()) > 1: raise ValueError("The order params dict in group parameters should " @@ -288,53 +307,38 @@ class Optimizer(Cell): if 'lr' in group_param.keys(): self.is_group_lr = True - self._get_single_lr(group_param['lr']) - if isinstance(group_param['lr'], Iterable): - lr_length = len(group_param['lr']) - self.dynamic_lr = True - elif isinstance(group_param['lr'], Tensor): - lr_length = group_param['lr'].size() - self.dynamic_lr = True + group_lr = self._preprocess_single_lr(group_param['lr']) - if dynamic_lr_length not in (lr_length, 0): - raise ValueError("The dynamic learning rate in group should be the same size.") - - dynamic_lr_length = lr_length - self.dynamic_lr_length = dynamic_lr_length + if isinstance(group_lr, Tensor) and group_lr.dim() == 1: + group_lr_length = group_lr.size() + if tensor_lr_length == 0: + tensor_lr_length = group_lr_length + elif group_lr_length != tensor_lr_length: + raise ValueError("The Tensor type dynamic learning rate in group should be the same size.") def _init_group_params(self, parameters, learning_rate, weight_decay): """Init learning rate or weight decay in group params.""" - origin_dynamic_lr = self.dynamic_lr self._parse_group_params(parameters, learning_rate) - if self.dynamic_lr and not origin_dynamic_lr: - self.gather = P.GatherV2() - self.assignadd = P.AssignAdd() - self.global_step = Parameter(initializer(0, [1], mindspore.int32), name='global_step') + default_lr = self._build_single_lr(learning_rate, 'learning_rate') params_store = [] - for group_param in parameters: + for group_num, group_param in enumerate(parameters): if 'order_params' in group_param.keys(): ordered_parameters = group_param['order_params'] continue self.group_params += group_param['params'] + if 'lr' in group_param.keys(): - params_dynamic_lr = isinstance(group_param['lr'], (Iterable, Tensor)) - if self.dynamic_lr and not params_dynamic_lr: - lr = Tensor(np.array([group_param['lr']] * self.dynamic_lr_length).astype(np.float32)) - else: - lr = self._get_single_lr(group_param['lr']) + lr_param_name = 'learning_rate_group_' + str(group_num) + lr = self._preprocess_single_lr(group_param['lr']) + lr = self._build_single_lr(lr, lr_param_name) else: - if self.dynamic_lr and not origin_dynamic_lr: - lr = Tensor(np.array([self.scalar_lr] * self.dynamic_lr_length).astype(np.float32)) - else: - lr = learning_rate + lr = default_lr if 'weight_decay' in group_param.keys(): - validator.check_float_legal_value('weight_decay', group_param['weight_decay'], None) - validator.check_number_range('weight_decay', group_param['weight_decay'], 0.0, 1.0, - Rel.INC_BOTH, self.cls_name) - weight_decay_ = group_param['weight_decay'] * self.loss_scale + cur_weight_decay = self._preprocess_weight_decay(group_param['weight_decay']) + weight_decay_ = cur_weight_decay * self.loss_scale else: weight_decay_ = weight_decay * self.loss_scale @@ -348,7 +352,7 @@ class Optimizer(Cell): raise RuntimeError(f"The {param.name} parameter has appeared in parameter groups.") params_store.append(param.name) - self.group_lr.append(Parameter(lr, name="lr_" + param.name)) + self.group_lr.append(lr) self.group_weight_decay.append(weight_decay_) if self.is_group_params_ordered: @@ -384,19 +388,17 @@ class Optimizer(Cell): Returns: float, the learning rate of current step. """ - if self.is_group_lr: - lr = self.learning_rate - if self.dynamic_lr: + lr = self.learning_rate + if self.dynamic_lr: + if self.is_group_lr: lr = () - for i in range(self.param_length): - current_dynamic_lr = self.gather(self.learning_rate[i], self.global_step, 0) + for learning_rate in self.learning_rate: + current_dynamic_lr = learning_rate(self.global_step) lr += (current_dynamic_lr,) - F.control_depend(lr, self.assignadd(self.global_step, 1)) - else: - lr = self.learning_rate - if self.dynamic_lr: - lr = self.gather(self.learning_rate, self.global_step, 0) - F.control_depend(lr, self.assignadd(self.global_step, 1)) + else: + lr = self.learning_rate(self.global_step) + + F.control_depend(lr, self.assignadd(self.global_step, 1)) return lr def get_lr_parameter(self, param): @@ -409,29 +411,31 @@ class Optimizer(Cell): Returns: Parameter, single `Parameter` or `list[Parameter]` according to the input type. """ - if not isinstance(param, (Parameter, list)): + def get_lr_value(learning_rate): + if isinstance(learning_rate, (_ConvertToCell, _IteratorLearningRate)): + return learning_rate.learning_rate + + return learning_rate + + if isinstance(param, Parameter): + param_list = [param] + elif isinstance(param, list): + param_list = param + else: raise TypeError(f"The parameter only support 'Parameter' or 'list' type.") - if isinstance(param, list): - lr = [] - for p in param: - validator.check_value_type("parameter", p, [Parameter], self.cls_name) - if p not in self.parameters: - raise ValueError(f"The parameter {p.name} is not in optimizer.") - if self.is_group_lr: - index = self.parameters.index(p) - lr.append(self.learning_rate[index]) - else: - lr.append(self.learning_rate) - else: - if param not in self.parameters: - raise ValueError(f"The parameter {param.name} is not in optimizer.") + lr = [] + for p in param_list: + validator.check_value_type("parameter", p, [Parameter], self.cls_name) + if p not in self.parameters: + raise ValueError(f"The parameter {p.name} is not in optimizer.") if self.is_group_lr: - index = self.parameters.index(param) - lr = self.learning_rate[index] + index = self.parameters.index(p) + lr.append(get_lr_value(self.learning_rate[index])) else: - lr = self.learning_rate - return lr + lr.append(get_lr_value(self.learning_rate)) + + return lr if isinstance(param, list) else lr[0] def _get_parameter_group_id(self): """ @@ -524,3 +528,33 @@ def tensor_grad_scale_with_sparse(scale, grad): if scale == 1.0: return grad return IndexedSlices(grad.indices(), grad.values() * scale, grad.dense_shape()) + + +class _ConvertToCell(LearningRateSchedule): + """Inner api, convert learning rate of scalar to LearningRateSchedule.""" + def __init__(self, learning_rate): + super(_ConvertToCell, self).__init__() + if not isinstance(learning_rate, Parameter): + raise TypeError('Learning rate must be Parameter.') + self.learning_rate = learning_rate + + def construct(self, global_step): + return self.learning_rate + 1.0 - 1.0 + + +class _IteratorLearningRate(LearningRateSchedule): + """Inner api, convert learning rate of Tensor(list) to LearningRateSchedule.""" + def __init__(self, learning_rate, name): + super(_IteratorLearningRate, self).__init__() + if isinstance(learning_rate, Tensor): + if learning_rate.dim() != 1: + raise ValueError("The dim of `Tensor` type dynamic learning rate should be a 1," + f"but got {learning_rate.dim()}.") + else: + raise TypeError("Learning rate should be Tensor.") + + self.learning_rate = Parameter(learning_rate, name) + self.gather = P.GatherV2() + + def construct(self, global_step): + return self.gather(self.learning_rate, global_step, 0) diff --git a/mindspore/nn/optim/proximal_ada_grad.py b/mindspore/nn/optim/proximal_ada_grad.py index daa41d1ae8a..2ef320fd9c3 100644 --- a/mindspore/nn/optim/proximal_ada_grad.py +++ b/mindspore/nn/optim/proximal_ada_grad.py @@ -32,7 +32,7 @@ def _tensor_run_opt_with_sparse(opt, sparse_opt, learning_rate, l1, l2, gradient @_proximal_ada_grad_opt.register("Function", "Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor") -def _tensor_run_opt(opt, sparse_opt, learning_rate, l1, l2, gradient, weight, accum): +def _tensor_run_opt(opt, sparse_opt, l1, l2, learning_rate, gradient, weight, accum): """Apply proximal_ada_grad optimizer to the weight parameter.""" success = True success = F.depend(success, opt(weight, accum, learning_rate, l1, l2, gradient)) @@ -59,15 +59,42 @@ class ProximalAdagrad(Optimizer): `_. Note: + When separating parameter groups, the weight decay in each group will be applied on the parameters if the + weight decay is positive. When not separating parameter groups, the `weight_decay` in the API will be applied + on the parameters without 'beta' or 'gamma' in their names if `weight_decay` is positive. + + To improve parameter groups performance, the customized order of parameters can be supported. + The sparse strategy is applied while the SparseGatherV2 operator being used for forward network. The sparse feature is under continuous development. The sparse behavior is currently performed on the CPU. Args: - params (list[Parameter]): A list of parameter, which will be updated. The element in `params` - should be Parameter. + params (Union[list[Parameter], list[dict]]): When the `params` is a list of `Parameter` which will be updated, + the element in `params` should be class `Parameter`. When the `params` is a list of `dict`, the "params", + "lr", "weight_decay" and "order_params" are the keys can be parsed. + + - params: Required. The value should be a list of `Parameter`. + + - lr: Optional. If "lr" in the keys, the value of corresponding learning rate will be used. + If not, the `learning_rate` in the API will be used. + + - weight_decay: Optional. If "weight_decay" in the keys, the value of corresponding weight decay + will be used. If not, the `weight_decay` in the API will be used. + + - order_params: Optional. If "order_params" in the keys, the value should be the order of parameters and + the order will be followed in optimizer. There are no other keys in the `dict` and the parameters which + in the value of 'order_params' should be in one of group parameters. + accum (float): The starting value for accumulators, must be zero or positive values. Default: 0.1. - learning_rate (float): The learning rate value, must be greater than or equal to zero. Default: 0.001. + learning_rate (Union[float, Tensor, Iterable, LearningRateSchedule]): A value or graph for the learning rate. + When the learning_rate is a Iterable or a Tensor with dimension of 1, use dynamic learning rate, then + the i-th step will take the i-th value as the learning rate. When the learning_rate is LearningRateSchedule, + use dynamic learning rate, the i-th learning rate will be calculated during the process of training + according to the formula of LearningRateSchedule. When the learning_rate is a float or a Tensor with + dimension of 0, use fixed learning rate. Other cases are not supported. The float learning rate should be + equal to or greater than 0. If the type of `learning_rate` is int, it will be converted to float. + Default: 0.001. l1 (float): l1 regularization strength, must be greater than or equal to zero. Default: 0.0. l2 (float): l2 regularization strength, must be greater than or equal to zero. Default: 0.0. use_locking (bool): If True use locks for update operation. Default: False. @@ -83,21 +110,31 @@ class ProximalAdagrad(Optimizer): Examples: >>> net = Net() + >>> #1) All parameters use the same learning rate and weight decay + >>> optim = nn.ProximalAdagrad(params=net.trainable_params()) + >>> + >>> #2) Use parameter groups and set different values + >>> conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params())) + >>> no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params())) + >>> group_params = [{'params': conv_params, 'weight_decay': 0.01}, + >>> {'params': no_conv_params, 'lr': 0.01}, + >>> {'order_params': net.trainable_params()}] + >>> optim = nn.ProximalAdagrad(group_params, learning_rate=0.1, weight_decay=0.0) + >>> # The conv_params's parameters will use default learning rate of 0.1 and weight decay of 0.01. + >>> # The no_conv_params's parameters will use learning rate of 0.01 and default weight decay of 0.0. + >>> # The final parameters order in which the optimizer will be followed is the value of 'order_params'. + >>> >>> loss = nn.SoftmaxCrossEntropyWithLogits() - >>> opt = nn.ProximalAdagrad(net.trainable_params()) - >>> model = Model(net, loss_fn=loss, optimizer=opt, metrics=None) + >>> model = Model(net, loss_fn=loss, optimizer=optim) """ def __init__(self, params, accum=0.1, learning_rate=0.001, l1=0.0, l2=0.0, use_locking=False, loss_scale=1.0, weight_decay=0.0): super(ProximalAdagrad, self).__init__(learning_rate, params, weight_decay, loss_scale) - if self.is_group: - raise RuntimeError(f"The {self.cls_name} optimizer cannot support group setting.") _check_param_value(accum, l1, l2, use_locking, self.cls_name) self.accum = self.parameters.clone(prefix="accum", init=accum) self.l1 = Tensor(l1, mstype.float32) self.l2 = Tensor(l2, mstype.float32) - self.weight_decay = weight_decay self.hyper_map = C.HyperMap() self.opt = P.ApplyProximalAdagrad(use_locking=use_locking) self.sparse_opt = P.FusedSparseProximalAdagrad(use_locking=use_locking) @@ -107,7 +144,11 @@ class ProximalAdagrad(Optimizer): accum = self.accum grads = self.decay_weight(grads) grads = self.scale_grad(grads) - lr = self.learning_rate - success = self.map_(F.partial(_proximal_ada_grad_opt, self.opt, self.sparse_opt, lr, self.l1, self.l2), - grads, params, accum) + lr = self.get_lr() + if self.is_group_lr: + success = self.map_(F.partial(_proximal_ada_grad_opt, self.opt, self.sparse_opt, self.l1, self.l2), lr, + grads, params, accum) + else: + success = self.map_(F.partial(_proximal_ada_grad_opt, self.opt, self.sparse_opt, self.l1, self.l2, lr), + grads, params, accum) return success diff --git a/mindspore/nn/optim/rmsprop.py b/mindspore/nn/optim/rmsprop.py index 75d4d6b0aca..ac7dff68e7a 100644 --- a/mindspore/nn/optim/rmsprop.py +++ b/mindspore/nn/optim/rmsprop.py @@ -44,12 +44,9 @@ class RMSProp(Optimizer): Implements Root Mean Squared Propagation (RMSProp) algorithm. Note: - The RMSProp optimizer supports separating parameter groups. Different parameter groups can set different - `learning_rate` and `weight_decay`. - When separating parameter groups, the weight decay in each group will be applied on the parameters if the - value of weight_decay > 0. When not separating parameter groups, the `weight_decay` in the API will be - applied on the parameters if `weight_decay` > 0 and the 'beta' and 'gamma' are not in the name of parameters. + weight decay is positive. When not separating parameter groups, the `weight_decay` in the API will be applied + on the parameters without 'beta' or 'gamma' in their names if `weight_decay` is positive. To improve parameter groups performance, the customized order of parameters can be supported. @@ -109,13 +106,14 @@ class RMSProp(Optimizer): the order will be followed in optimizer. There are no other keys in the `dict` and the parameters which in the value of 'order_params' should be in one of group parameters. - learning_rate (Union[float, Tensor, Iterable]): A value for the learning rate. When the learning_rate is - Iterable or a Tensor and the dims of the Tensor is 1, - use dynamic learning rate, then the i-th step will - take the i-th value as the learning rate. - When the learning_rate is float or learning_rate is a Tensor - but the dims of the Tensor is 0, use fixed learning rate. - Other cases are not supported. Default: 0.1. + learning_rate (Union[float, Tensor, Iterable, LearningRateSchedule]): A value or graph for the learning rate. + When the learning_rate is a Iterable or a Tensor with dimension of 1, use dynamic learning rate, then + the i-th step will take the i-th value as the learning rate. When the learning_rate is LearningRateSchedule, + use dynamic learning rate, the i-th learning rate will be calculated during the process of training + according to the formula of LearningRateSchedule. When the learning_rate is a float or a Tensor with + dimension of 0, use fixed learning rate. Other cases are not supported. The float learning rate should be + equal to or greater than 0. If the type of `learning_rate` is int, it will be converted to float. + Default: 0.1. decay (float): Decay rate. Should be equal to or greater than 0. Default: 0.9. momentum (float): Hyperparameter of type float, means momentum for the moving average. Should be equal to or greater than 0. Default: 0.0. diff --git a/mindspore/nn/optim/sgd.py b/mindspore/nn/optim/sgd.py index 410539d9304..f093200906b 100755 --- a/mindspore/nn/optim/sgd.py +++ b/mindspore/nn/optim/sgd.py @@ -40,14 +40,11 @@ class SGD(Optimizer): momentum in deep learning `_. Note: - The SGD optimizer supports separating parameter groups. Different parameter groups can set different - `learning_rate` and `weight_decay`. - When separating parameter groups, the weight decay in each group will be applied on the parameters if the - value of weight_decay > 0. When not separating parameter groups, the `weight_decay` in the API will be - applied on the parameters if `weight_decay` > 0 and the 'beta' and 'gamma' are not in the name of parameters. + weight decay is positive. When not separating parameter groups, the `weight_decay` in the API will be applied + on the parameters without 'beta' or 'gamma' in their names if `weight_decay` is positive. - To improve parameter groups performance, the customized order of parameters can be supported. + To improve parameter groups performance, the customized order of parameters can be supported. Args: params (Union[list[Parameter], list[dict]]): When the `params` is a list of `Parameter` which will be updated, @@ -66,14 +63,14 @@ class SGD(Optimizer): the order will be followed in optimizer. There are no other keys in the `dict` and the parameters which in the value of 'order_params' should be in one of group parameters. - learning_rate (Union[float, Tensor, Iterable]): A value for the learning rate. When the learning_rate is - Iterable or a Tensor and the dims of the Tensor is 1, - use dynamic learning rate, then the i-th step will - take the i-th value as the learning rate. - When the learning_rate is float or learning_rate is a Tensor - but the dims of the Tensor is 0, use fixed learning rate. - Other cases are not supported. It should be equal to or - greater than 0. Default: 0.1. + learning_rate (Union[float, Tensor, Iterable, LearningRateSchedule]): A value or graph for the learning rate. + When the learning_rate is a Iterable or a Tensor with dimension of 1, use dynamic learning rate, then + the i-th step will take the i-th value as the learning rate. When the learning_rate is LearningRateSchedule, + use dynamic learning rate, the i-th learning rate will be calculated during the process of training + according to the formula of LearningRateSchedule. When the learning_rate is a float or a Tensor with + dimension of 0, use fixed learning rate. Other cases are not supported. The float learning rate should be + equal to or greater than 0. If the type of `learning_rate` is int, it will be converted to float. + Default: 0.1. momentum (float): A floating point value the momentum. should be at least 0.0. Default: 0.0. dampening (float): A floating point value of dampening for momentum. should be at least 0.0. Default: 0.0. weight_decay (float): Weight decay (L2 penalty). It should be in range [0.0, 1.0]. Default: 0.0. diff --git a/model_zoo/mass/src/utils/lr_scheduler.py b/model_zoo/mass/src/utils/lr_scheduler.py index 44ef397fdd1..16607678e54 100644 --- a/model_zoo/mass/src/utils/lr_scheduler.py +++ b/model_zoo/mass/src/utils/lr_scheduler.py @@ -14,9 +14,10 @@ # ============================================================================ """Learning scheduler.""" from math import ceil - import numpy as np +import mindspore.nn.learning_rate_schedule as lr_schedules + def square_root_schedule(lr, update_num, decay_start_step, warmup_steps=2000, @@ -105,3 +106,35 @@ def polynomial_decay_scheduler(lr, min_lr, decay_steps, total_update_num, warmup lrs[step] = (lr - min_lr) * pow(1 - _step / _decay_steps, power) + min_lr return lrs + + +class BertLearningRate(lr_schedules.LearningRateSchedule): + """ + Implements of warmup-polydecay learning rate scheduler. + + Args: + learning_rate (float): The initial value of learning rate. + end_learning_rate (float): The end value of learning rate. + warmup_steps (int): The warm up steps of learning rate. + decay_steps (int): A value used to calculate decayed learning rate. + power (float): A value used to calculate decayed learning rate. + + Returns: + Tensor. The learning rate value for the current step. + """ + def __init__(self, learning_rate, end_learning_rate, warmup_steps, decay_steps, power): + super(BertLearningRate, self).__init__() + self.warmup_lr = lr_schedules.WarmUpLR(learning_rate, warmup_steps) + self.decay_lr = lr_schedules.PolynomialDecayLR(learning_rate, end_learning_rate, decay_steps, power) + self.warmup_steps = Tensor(np.array([warmup_steps]).astype(np.float32)) + + self.greater = P.Greater() + self.one = Tensor(np.array([1.0]).astype(np.float32)) + self.cast = P.Cast() + + def construct(self, global_step): + is_warmup = self.cast(self.greater(self.warmup_steps, global_step), mstype.float32) + warmup_lr = self.warmup_lr(global_step) + decay_lr = self.decay_lr(global_step) + lr = (self.one - is_warmup) * decay_lr + is_warmup * warmup_lr + return lr diff --git a/model_zoo/mass/train.py b/model_zoo/mass/train.py index 4d297aa5180..a0c19592659 100644 --- a/model_zoo/mass/train.py +++ b/model_zoo/mass/train.py @@ -37,7 +37,7 @@ from src.transformer.infer_mass import infer from src.utils import LossCallBack from src.utils import one_weight, zero_weight, weight_variable from src.utils import square_root_schedule -from src.utils.lr_scheduler import polynomial_decay_scheduler +from src.utils.lr_scheduler import polynomial_decay_scheduler, BertLearningRate parser = argparse.ArgumentParser(description='MASS train entry point.') parser.add_argument("--config", type=str, required=True, help="model config json file path.") @@ -178,10 +178,16 @@ def _build_training_pipeline(config: TransformerConfig, if config.optimizer.lower() == "adam": optimizer = Adam(net_with_loss.trainable_params(), lr, beta1=0.9, beta2=0.98) elif config.optimizer.lower() == "lamb": - optimizer = Lamb(net_with_loss.trainable_params(), decay_steps=12000, - start_learning_rate=config.lr, end_learning_rate=config.min_lr, - power=10.0, warmup_steps=config.warmup_steps, weight_decay=0.01, - eps=1e-6) + lr = BertLearningRate(decay_steps=12000, learning_rate=config.lr, end_learning_rate=config.min_lr, + power=10.0, warmup_steps=config.warmup_steps) + decay_params = list(filter(lambda x: 'layernorm' not in x.name.lower() and 'bias' not in x.name.lower(), + net_with_loss.trainable_params())) + other_params = list(filter(lambda x: 'layernorm' in x.name.lower() or 'bias' in x.name.lower(), + net_with_loss.trainable_params())) + group_params = [{'params': decay_params, 'weight_decay': 0.01}, + {'params': other_params}] + + optimizer = Lamb(group_params, lr, eps=1e-6) elif config.optimizer.lower() == "momentum": optimizer = Momentum(net_with_loss.trainable_params(), lr, momentum=0.9) else: diff --git a/model_zoo/official/nlp/bert/README.md b/model_zoo/official/nlp/bert/README.md index 45928da4e3f..b971efd0f9f 100644 --- a/model_zoo/official/nlp/bert/README.md +++ b/model_zoo/official/nlp/bert/README.md @@ -147,7 +147,7 @@ Parameters for dataset and network (Pre-Training/Fine-Tuning/Evaluation): compute_type compute type in BertTransformer: mstype.float16 | mstype.float32, default is mstype.float16 Parameters for optimizer: - AdamWeightDecayDynamicLR: + AdamWeightDecay: decay_steps steps of the learning rate decay: N learning_rate value of learning rate: Q end_learning_rate value of end learning rate: Q, must be positive diff --git a/model_zoo/official/nlp/bert/run_classifier.py b/model_zoo/official/nlp/bert/run_classifier.py index 73f41a858b0..97b01cceb94 100644 --- a/model_zoo/official/nlp/bert/run_classifier.py +++ b/model_zoo/official/nlp/bert/run_classifier.py @@ -23,12 +23,12 @@ from src.bert_for_finetune import BertFinetuneCell, BertCLS from src.finetune_eval_config import optimizer_cfg, bert_net_cfg from src.dataset import create_classification_dataset from src.assessment_method import Accuracy, F1, MCC, Spearman_Correlation -from src.utils import make_directory, LossCallBack, LoadNewestCkpt +from src.utils import make_directory, LossCallBack, LoadNewestCkpt, BertLearningRate import mindspore.common.dtype as mstype from mindspore import context from mindspore import log as logger from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell -from mindspore.nn.optim import AdamWeightDecayDynamicLR, Lamb, Momentum +from mindspore.nn.optim import AdamWeightDecay, Lamb, Momentum from mindspore.common.tensor import Tensor from mindspore.train.model import Model from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, TimeMonitor @@ -42,27 +42,31 @@ def do_train(dataset=None, network=None, load_checkpoint_path="", save_checkpoin raise ValueError("Pretrain model missed, finetune task must load pretrain model!") steps_per_epoch = dataset.get_dataset_size() # optimizer - if optimizer_cfg.optimizer == 'AdamWeightDecayDynamicLR': - optimizer = AdamWeightDecayDynamicLR(network.trainable_params(), - decay_steps=steps_per_epoch * epoch_num, - learning_rate=optimizer_cfg.AdamWeightDecayDynamicLR.learning_rate, - end_learning_rate=optimizer_cfg.AdamWeightDecayDynamicLR.end_learning_rate, - power=optimizer_cfg.AdamWeightDecayDynamicLR.power, - warmup_steps=int(steps_per_epoch * epoch_num * 0.1), - weight_decay=optimizer_cfg.AdamWeightDecayDynamicLR.weight_decay, - eps=optimizer_cfg.AdamWeightDecayDynamicLR.eps) + if optimizer_cfg.optimizer == 'AdamWeightDecay': + lr_schedule = BertLearningRate(learning_rate=optimizer_cfg.AdamWeightDecay.learning_rate, + end_learning_rate=optimizer_cfg.AdamWeightDecay.end_learning_rate, + warmup_steps=int(steps_per_epoch * epoch_num * 0.1), + decay_steps=steps_per_epoch * epoch_num, + power=optimizer_cfg.AdamWeightDecay.power) + params = net_with_loss.trainable_params() + decay_params = list(filter(optimizer_cfg.AdamWeightDecay.decay_filter, params)) + other_params = list(filter(lambda x: x not in decay_params, params)) + group_params = [{'params': decay_params, 'weight_decay': optimizer_cfg.AdamWeightDecay.weight_decay}, + {'params': other_params, 'weight_decay': 0.0}] + + optimizer = AdamWeightDecay(group_params, lr_schedule, eps=optimizer_cfg.AdamWeightDecay.eps) elif optimizer_cfg.optimizer == 'Lamb': - optimizer = Lamb(network.trainable_params(), decay_steps=steps_per_epoch * epoch_num, - start_learning_rate=optimizer_cfg.Lamb.start_learning_rate, - end_learning_rate=optimizer_cfg.Lamb.end_learning_rate, - power=optimizer_cfg.Lamb.power, weight_decay=optimizer_cfg.Lamb.weight_decay, - warmup_steps=int(steps_per_epoch * epoch_num * 0.1), - decay_filter=optimizer_cfg.Lamb.decay_filter) + lr_schedule = BertLearningRate(learning_rate=optimizer_cfg.Lamb.learning_rate, + end_learning_rate=optimizer_cfg.Lamb.end_learning_rate, + warmup_steps=int(steps_per_epoch * epoch_num * 0.1), + decay_steps=steps_per_epoch * epoch_num, + power=optimizer_cfg.Lamb.power) + optimizer = Lamb(network.trainable_params(), learning_rate=lr_schedule) elif optimizer_cfg.optimizer == 'Momentum': optimizer = Momentum(network.trainable_params(), learning_rate=optimizer_cfg.Momentum.learning_rate, momentum=optimizer_cfg.Momentum.momentum) else: - raise Exception("Optimizer not supported. support: [AdamWeightDecayDynamicLR, Lamb, Momentum]") + raise Exception("Optimizer not supported. support: [AdamWeightDecay, Lamb, Momentum]") # load checkpoint into network ckpt_config = CheckpointConfig(save_checkpoint_steps=steps_per_epoch, keep_checkpoint_max=1) diff --git a/model_zoo/official/nlp/bert/run_ner.py b/model_zoo/official/nlp/bert/run_ner.py index 4b1a1cead75..c9314bf39cb 100644 --- a/model_zoo/official/nlp/bert/run_ner.py +++ b/model_zoo/official/nlp/bert/run_ner.py @@ -23,13 +23,13 @@ import argparse from src.bert_for_finetune import BertFinetuneCell, BertNER from src.finetune_eval_config import optimizer_cfg, bert_net_cfg from src.dataset import create_ner_dataset -from src.utils import make_directory, LossCallBack, LoadNewestCkpt +from src.utils import make_directory, LossCallBack, LoadNewestCkpt, BertLearningRate from src.assessment_method import Accuracy, F1, MCC, Spearman_Correlation import mindspore.common.dtype as mstype from mindspore import context from mindspore import log as logger from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell -from mindspore.nn.optim import AdamWeightDecayDynamicLR, Lamb, Momentum +from mindspore.nn.optim import AdamWeightDecay, Lamb, Momentum from mindspore.common.tensor import Tensor from mindspore.train.model import Model from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, TimeMonitor @@ -44,27 +44,30 @@ def do_train(dataset=None, network=None, load_checkpoint_path="", save_checkpoin raise ValueError("Pretrain model missed, finetune task must load pretrain model!") steps_per_epoch = dataset.get_dataset_size() # optimizer - if optimizer_cfg.optimizer == 'AdamWeightDecayDynamicLR': - optimizer = AdamWeightDecayDynamicLR(network.trainable_params(), - decay_steps=steps_per_epoch * epoch_num, - learning_rate=optimizer_cfg.AdamWeightDecayDynamicLR.learning_rate, - end_learning_rate=optimizer_cfg.AdamWeightDecayDynamicLR.end_learning_rate, - power=optimizer_cfg.AdamWeightDecayDynamicLR.power, - warmup_steps=int(steps_per_epoch * epoch_num * 0.1), - weight_decay=optimizer_cfg.AdamWeightDecayDynamicLR.weight_decay, - eps=optimizer_cfg.AdamWeightDecayDynamicLR.eps) + if optimizer_cfg.optimizer == 'AdamWeightDecay': + lr_schedule = BertLearningRate(learning_rate=optimizer_cfg.AdamWeightDecay.learning_rate, + end_learning_rate=optimizer_cfg.AdamWeightDecay.end_learning_rate, + warmup_steps=int(steps_per_epoch * epoch_num * 0.1), + decay_steps=steps_per_epoch * epoch_num, + power=optimizer_cfg.AdamWeightDecay.power) + params = network.trainable_params() + decay_params = list(filter(optimizer_cfg.AdamWeightDecay.decay_filter, params)) + other_params = list(filter(lambda x: x not in decay_params, params)) + group_params = [{'params': decay_params, 'weight_decay': optimizer_cfg.AdamWeightDecay.weight_decay}, + {'params': other_params, 'weight_decay': 0.0}] + optimizer = AdamWeightDecay(group_params, lr_schedule, eps=optimizer_cfg.AdamWeightDecay.eps) elif optimizer_cfg.optimizer == 'Lamb': - optimizer = Lamb(network.trainable_params(), decay_steps=steps_per_epoch * epoch_num, - start_learning_rate=optimizer_cfg.Lamb.start_learning_rate, - end_learning_rate=optimizer_cfg.Lamb.end_learning_rate, - power=optimizer_cfg.Lamb.power, weight_decay=optimizer_cfg.Lamb.weight_decay, - warmup_steps=int(steps_per_epoch * epoch_num * 0.1), - decay_filter=optimizer_cfg.Lamb.decay_filter) + lr_schedule = BertLearningRate(learning_rate=optimizer_cfg.Lamb.learning_rate, + end_learning_rate=optimizer_cfg.Lamb.end_learning_rate, + warmup_steps=int(steps_per_epoch * epoch_num * 0.1), + decay_steps=steps_per_epoch * epoch_num, + power=optimizer_cfg.Lamb.power) + optimizer = Lamb(network.trainable_params(), learning_rate=lr_schedule) elif optimizer_cfg.optimizer == 'Momentum': optimizer = Momentum(network.trainable_params(), learning_rate=optimizer_cfg.Momentum.learning_rate, momentum=optimizer_cfg.Momentum.momentum) else: - raise Exception("Optimizer not supported. support: [AdamWeightDecayDynamicLR, Lamb, Momentum]") + raise Exception("Optimizer not supported. support: [AdamWeightDecay, Lamb, Momentum]") # load checkpoint into network ckpt_config = CheckpointConfig(save_checkpoint_steps=steps_per_epoch, keep_checkpoint_max=1) diff --git a/model_zoo/official/nlp/bert/run_pretrain.py b/model_zoo/official/nlp/bert/run_pretrain.py index 54769c011bc..b8fab8ad2ec 100644 --- a/model_zoo/official/nlp/bert/run_pretrain.py +++ b/model_zoo/official/nlp/bert/run_pretrain.py @@ -28,12 +28,12 @@ from mindspore.train.parallel_utils import ParallelMode from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor from mindspore.train.serialization import load_checkpoint, load_param_into_net -from mindspore.nn.optim import Lamb, Momentum, AdamWeightDecayDynamicLR +from mindspore.nn.optim import Lamb, Momentum, AdamWeightDecay from mindspore import log as logger from src import BertNetworkWithLoss, BertTrainOneStepCell, BertTrainOneStepWithLossScaleCell from src.dataset import create_bert_dataset from src.config import cfg, bert_net_cfg -from src.utils import LossCallBack +from src.utils import LossCallBack, BertLearningRate _current_dir = os.path.dirname(os.path.realpath(__file__)) @@ -109,24 +109,35 @@ def run_pretrain(): netwithloss = BertNetworkWithLoss(bert_net_cfg, True) if cfg.optimizer == 'Lamb': - optimizer = Lamb(netwithloss.trainable_params(), decay_steps=ds.get_dataset_size() * new_repeat_count, - start_learning_rate=cfg.Lamb.start_learning_rate, end_learning_rate=cfg.Lamb.end_learning_rate, - power=cfg.Lamb.power, warmup_steps=cfg.Lamb.warmup_steps, weight_decay=cfg.Lamb.weight_decay, - eps=cfg.Lamb.eps) + lr_schedule = BertLearningRate(learning_rate=cfg.Lamb.learning_rate, + end_learning_rate=cfg.Lamb.end_learning_rate, + warmup_steps=cfg.Lamb.warmup_steps, + decay_steps=ds.get_dataset_size() * new_repeat_count, + power=cfg.Lamb.power) + params = net_with_loss.trainable_params() + decay_params = list(filter(cfg.Lamb.decay_filter, params)) + other_params = list(filter(lambda x: x not in decay_params, params)) + group_params = [{'params': decay_params, 'weight_decay': cfg.Lamb.weight_decay}, + {'params': other_params}] + optimizer = Lamb(group_params, learning_rate=lr_schedule, eps=cfg.Lamb.eps) elif cfg.optimizer == 'Momentum': optimizer = Momentum(netwithloss.trainable_params(), learning_rate=cfg.Momentum.learning_rate, momentum=cfg.Momentum.momentum) - elif cfg.optimizer == 'AdamWeightDecayDynamicLR': - optimizer = AdamWeightDecayDynamicLR(netwithloss.trainable_params(), - decay_steps=ds.get_dataset_size() * new_repeat_count, - learning_rate=cfg.AdamWeightDecayDynamicLR.learning_rate, - end_learning_rate=cfg.AdamWeightDecayDynamicLR.end_learning_rate, - power=cfg.AdamWeightDecayDynamicLR.power, - weight_decay=cfg.AdamWeightDecayDynamicLR.weight_decay, - eps=cfg.AdamWeightDecayDynamicLR.eps, - warmup_steps=cfg.AdamWeightDecayDynamicLR.warmup_steps) + elif cfg.optimizer == 'AdamWeightDecay': + lr_schedule = BertLearningRate(learning_rate=cfg.AdamWeightDecay.learning_rate, + end_learning_rate=cfg.AdamWeightDecay.end_learning_rate, + warmup_steps=cfg.AdamWeightDecay.warmup_steps, + decay_steps=ds.get_dataset_size() * new_repeat_count, + power=cfg.AdamWeightDecay.power) + params = net_with_loss.trainable_params() + decay_params = list(filter(cfg.AdamWeightDecay.decay_filter, params)) + other_params = list(filter(lambda x: x not in decay_params, params)) + group_params = [{'params': decay_params, 'weight_decay': cfg.AdamWeightDecay.weight_decay}, + {'params': other_params, 'weight_decay': 0.0}] + + optimizer = AdamWeightDecay(group_params, learning_rate=lr_schedule, eps=cfg.AdamWeightDecay.eps) else: - raise ValueError("Don't support optimizer {}, only support [Lamb, Momentum, AdamWeightDecayDynamicLR]". + raise ValueError("Don't support optimizer {}, only support [Lamb, Momentum, AdamWeightDecay]". format(cfg.optimizer)) callback = [TimeMonitor(ds.get_dataset_size()), LossCallBack()] if args_opt.enable_save_ckpt == "true": diff --git a/model_zoo/official/nlp/bert/run_squad.py b/model_zoo/official/nlp/bert/run_squad.py index 55fa2a0fc32..bc2b75fa322 100644 --- a/model_zoo/official/nlp/bert/run_squad.py +++ b/model_zoo/official/nlp/bert/run_squad.py @@ -25,12 +25,12 @@ from src.dataset import create_squad_dataset from src import tokenization from src.create_squad_data import read_squad_examples, convert_examples_to_features from src.run_squad import write_predictions -from src.utils import make_directory, LossCallBack, LoadNewestCkpt +from src.utils import make_directory, LossCallBack, LoadNewestCkpt, BertLearningRate import mindspore.common.dtype as mstype from mindspore import context from mindspore import log as logger from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell -from mindspore.nn.optim import AdamWeightDecayDynamicLR, Lamb, Momentum +from mindspore.nn.optim import AdamWeightDecay, Lamb, Momentum from mindspore.common.tensor import Tensor from mindspore.train.model import Model from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, TimeMonitor @@ -44,27 +44,31 @@ def do_train(dataset=None, network=None, load_checkpoint_path="", save_checkpoin raise ValueError("Pretrain model missed, finetune task must load pretrain model!") steps_per_epoch = dataset.get_dataset_size() # optimizer - if optimizer_cfg.optimizer == 'AdamWeightDecayDynamicLR': - optimizer = AdamWeightDecayDynamicLR(network.trainable_params(), - decay_steps=steps_per_epoch * epoch_num, - learning_rate=optimizer_cfg.AdamWeightDecayDynamicLR.learning_rate, - end_learning_rate=optimizer_cfg.AdamWeightDecayDynamicLR.end_learning_rate, - power=optimizer_cfg.AdamWeightDecayDynamicLR.power, - warmup_steps=int(steps_per_epoch * epoch_num * 0.1), - weight_decay=optimizer_cfg.AdamWeightDecayDynamicLR.weight_decay, - eps=optimizer_cfg.AdamWeightDecayDynamicLR.eps) + if optimizer_cfg.optimizer == 'AdamWeightDecay': + lr_schedule = BertLearningRate(learning_rate=optimizer_cfg.AdamWeightDecay.learning_rate, + end_learning_rate=optimizer_cfg.AdamWeightDecay.end_learning_rate, + warmup_steps=int(steps_per_epoch * epoch_num * 0.1), + decay_steps=steps_per_epoch * epoch_num, + power=optimizer_cfg.AdamWeightDecay.power) + params = network.trainable_params() + decay_params = list(filter(optimizer_cfg.AdamWeightDecay.decay_filter, params)) + other_params = list(filter(lambda x: x not in decay_params, params)) + group_params = [{'params': decay_params, 'weight_decay': optimizer_cfg.AdamWeightDecay.weight_decay}, + {'params': other_params, 'weight_decay': 0.0}] + + optimizer = AdamWeightDecay(group_params, lr_schedule, eps=optimizer_cfg.AdamWeightDecay.eps) elif optimizer_cfg.optimizer == 'Lamb': - optimizer = Lamb(network.trainable_params(), decay_steps=steps_per_epoch * epoch_num, - start_learning_rate=optimizer_cfg.Lamb.start_learning_rate, - end_learning_rate=optimizer_cfg.Lamb.end_learning_rate, - power=optimizer_cfg.Lamb.power, weight_decay=optimizer_cfg.Lamb.weight_decay, - warmup_steps=int(steps_per_epoch * epoch_num * 0.1), - decay_filter=optimizer_cfg.Lamb.decay_filter) + lr_schedule = BertLearningRate(learning_rate=optimizer_cfg.Lamb.learning_rate, + end_learning_rate=optimizer_cfg.Lamb.end_learning_rate, + warmup_steps=int(steps_per_epoch * epoch_num * 0.1), + decay_steps=steps_per_epoch * epoch_num, + power=optimizer_cfg.Lamb.power) + optimizer = Lamb(network.trainable_params(), learning_rate=lr_schedule) elif optimizer_cfg.optimizer == 'Momentum': optimizer = Momentum(network.trainable_params(), learning_rate=optimizer_cfg.Momentum.learning_rate, momentum=optimizer_cfg.Momentum.momentum) else: - raise Exception("Optimizer not supported. support: [AdamWeightDecayDynamicLR, Lamb, Momentum]") + raise Exception("Optimizer not supported. support: [AdamWeightDecay, Lamb, Momentum]") # load checkpoint into network ckpt_config = CheckpointConfig(save_checkpoint_steps=steps_per_epoch, keep_checkpoint_max=1) diff --git a/model_zoo/official/nlp/bert/src/config.py b/model_zoo/official/nlp/bert/src/config.py index 812f0c2f180..e553f4b0385 100644 --- a/model_zoo/official/nlp/bert/src/config.py +++ b/model_zoo/official/nlp/bert/src/config.py @@ -24,20 +24,22 @@ cfg = edict({ 'scale_factor': 2, 'scale_window': 1000, 'optimizer': 'Lamb', - 'AdamWeightDecayDynamicLR': edict({ + 'AdamWeightDecay': edict({ 'learning_rate': 3e-5, 'end_learning_rate': 1e-10, 'power': 5.0, 'weight_decay': 1e-5, + 'decay_filter': lambda x: 'layernorm' not in x.name.lower() and 'bias' not in x.name.lower(), 'eps': 1e-6, 'warmup_steps': 10000, }), 'Lamb': edict({ - 'start_learning_rate': 3e-5, + 'learning_rate': 3e-5, 'end_learning_rate': 1e-10, 'power': 10.0, 'warmup_steps': 10000, 'weight_decay': 0.01, + 'decay_filter': lambda x: 'layernorm' not in x.name.lower() and 'bias' not in x.name.lower(), 'eps': 1e-6, }), 'Momentum': edict({ diff --git a/model_zoo/official/nlp/bert/src/finetune_eval_config.py b/model_zoo/official/nlp/bert/src/finetune_eval_config.py index 4b8e121e095..4a9f05a3fc0 100644 --- a/model_zoo/official/nlp/bert/src/finetune_eval_config.py +++ b/model_zoo/official/nlp/bert/src/finetune_eval_config.py @@ -23,19 +23,20 @@ from .bert_model import BertConfig optimizer_cfg = edict({ 'optimizer': 'Lamb', - 'AdamWeightDecayDynamicLR': edict({ + 'AdamWeightDecay': edict({ 'learning_rate': 2e-5, 'end_learning_rate': 1e-7, 'power': 1.0, 'weight_decay': 1e-5, + 'decay_filter': lambda x: 'layernorm' not in x.name.lower() and 'bias' not in x.name.lower(), 'eps': 1e-6, }), 'Lamb': edict({ - 'start_learning_rate': 2e-5, + 'learning_rate': 2e-5, 'end_learning_rate': 1e-7, 'power': 1.0, 'weight_decay': 0.01, - 'decay_filter': lambda x: False, + 'decay_filter': lambda x: 'layernorm' not in x.name.lower() and 'bias' not in x.name.lower(), }), 'Momentum': edict({ 'learning_rate': 2e-5, diff --git a/model_zoo/official/nlp/bert/src/utils.py b/model_zoo/official/nlp/bert/src/utils.py index dfb6ffa5fef..775931b23af 100644 --- a/model_zoo/official/nlp/bert/src/utils.py +++ b/model_zoo/official/nlp/bert/src/utils.py @@ -23,6 +23,7 @@ from mindspore.ops import operations as P from mindspore.common.tensor import Tensor from mindspore.common import dtype as mstype from mindspore.train.callback import Callback +from mindspore.nn.learning_rate_schedule import LearningRateSchedule, PolynomialDecayLR, WarmUpLR class CrossEntropyCalculation(nn.Cell): @@ -123,3 +124,25 @@ def LoadNewestCkpt(load_finetune_checkpoint_dir, steps_per_epoch, epoch_num, pre max_num = int(num) load_finetune_checkpoint_path = os.path.join(load_finetune_checkpoint_dir, filename) return load_finetune_checkpoint_path + + +class BertLearningRate(LearningRateSchedule): + """ + Warmup-decay learning rate for Bert network. + """ + def __init__(self, learning_rate, end_learning_rate, warmup_steps, decay_steps, power): + super(BertLearningRate, self).__init__() + self.warmup_lr = WarmUpLR(learning_rate, warmup_steps) + self.decay_lr = PolynomialDecayLR(learning_rate, end_learning_rate, decay_steps, power) + self.warmup_steps = Tensor(np.array([warmup_steps]).astype(np.float32)) + + self.greater = P.Greater() + self.one = Tensor(np.array([1.0]).astype(np.float32)) + self.cast = P.Cast() + + def construct(self, global_step): + is_warmup = self.cast(self.greater(self.warmup_steps, global_step), mstype.float32) + warmup_lr = self.warmup_lr(global_step) + decay_lr = self.decay_lr(global_step) + lr = (self.one - is_warmup) * decay_lr + is_warmup * warmup_lr + return lr diff --git a/tests/mindspore_test_framework/apps/test_lamb_check_loss.py b/tests/mindspore_test_framework/apps/test_lamb_check_loss.py index 44989596208..11e13261e6d 100644 --- a/tests/mindspore_test_framework/apps/test_lamb_check_loss.py +++ b/tests/mindspore_test_framework/apps/test_lamb_check_loss.py @@ -30,7 +30,7 @@ verification_set = [ 'block': { 'model': network, 'loss': SquaredLoss(), - 'opt': Lamb(network.trainable_params(), decay_steps=num_epochs, warmup_steps=10, weight_decay=0.01), + 'opt': Lamb(network.trainable_params(), 0.02, weight_decay=0.01), 'num_epochs': num_epochs, 'loss_upper_bound': 0.3, }, diff --git a/tests/mindspore_test_framework/pipeline/gradient/check_training.py b/tests/mindspore_test_framework/pipeline/gradient/check_training.py index 135b162ec77..61ed61af6de 100644 --- a/tests/mindspore_test_framework/pipeline/gradient/check_training.py +++ b/tests/mindspore_test_framework/pipeline/gradient/check_training.py @@ -31,7 +31,7 @@ Example: 'block': { 'model': network, 'loss': SquaredLoss(), - 'opt': Lamb(network.trainable_params(), decay_steps=num_epochs, warmup_steps=10, weight_decay=0.01), + 'opt': Lamb(network.trainable_params(), lr=0.02, weight_decay=0.01), 'num_epochs': num_epochs, 'loss_upper_bound': 0.3, }, diff --git a/tests/perf_test/bert/test_bert_train.py b/tests/perf_test/bert/test_bert_train.py index 096571adea0..058cf7221ad 100644 --- a/tests/perf_test/bert/test_bert_train.py +++ b/tests/perf_test/bert/test_bert_train.py @@ -22,8 +22,9 @@ import os import mindspore.common.dtype as mstype import mindspore.context as context from mindspore import Tensor -from mindspore.nn.optim import AdamWeightDecayDynamicLR +from mindspore.nn.optim import AdamWeightDecay from mindspore.train.loss_scale_manager import DynamicLossScaleManager +from mindspore.nn import learning_rate_schedule as lr_schedules from model_zoo.bert.src import BertConfig, BertNetworkWithLoss, BertTrainOneStepCell, BertTrainOneStepWithLossScaleCell from ...dataset_mock import MindData from ...ops_common import nn, np, batch_tuple_tensor, build_construct_graph @@ -98,6 +99,25 @@ def get_config(version='base', batch_size=1): return BertConfig(batch_size=batch_size) +class BertLearningRate(lr_schedules.LearningRateSchedule): + def __init__(self, decay_steps, warmup_steps=0, learning_rate=0.1, end_learning_rate=0.0001, power=1.0): + super(BertLearningRate, self).__init__() + self.warmup_lr = lr_schedules.WarmUpLR(learning_rate, warmup_steps) + self.decay_lr = lr_schedules.PolynomialDecayLR(learning_rate, end_learning_rate, decay_steps, power) + self.warmup_steps = Tensor(np.array([warmup_steps]).astype(np.float32)) + + self.greater = P.Greater() + self.one = Tensor(np.array([1.0]).astype(np.float32)) + self.cast = P.Cast() + + def construct(self, global_step): + is_warmup = self.cast(self.greater(self.warmup_steps, global_step), mstype.float32) + warmup_lr = self.warmup_lr(global_step) + decay_lr = self.decay_lr(global_step) + lr = (self.one - is_warmup) * decay_lr + is_warmup * warmup_lr + return lr + + def test_bert_train(): """ the main function @@ -123,7 +143,8 @@ def test_bert_train(): config = get_config(version=version, batch_size=batch_size) netwithloss = BertNetworkWithLoss(config, True) - optimizer = AdamWeightDecayDynamicLR(netwithloss.trainable_params(), 10) + lr = BertLearningRate(10) + optimizer = AdamWeightDecay(netwithloss.trainable_params(), lr) net = ModelBert(netwithloss, optimizer=optimizer) net.set_train() build_construct_graph(net, *inputs, execute=False) @@ -147,7 +168,8 @@ def test_bert_withlossscale_train(): config = get_config(version=version, batch_size=batch_size) netwithloss = BertNetworkWithLoss(config, True) - optimizer = AdamWeightDecayDynamicLR(netwithloss.trainable_params(), 10) + lr = BertLearningRate(10) + optimizer = AdamWeightDecay(netwithloss.trainable_params(), lr) net = ModelBert(netwithloss, optimizer=optimizer) net.set_train() build_construct_graph(net, *inputs, execute=True) @@ -173,7 +195,8 @@ def bert_withlossscale_manager_train(): config = get_config(version=version, batch_size=batch_size) netwithloss = BertNetworkWithLoss(config, True) - optimizer = AdamWeightDecayDynamicLR(netwithloss.trainable_params(), 10) + lr = BertLearningRate(10) + optimizer = AdamWeightDecay(netwithloss.trainable_params(), lr) net = ModelBert(netwithloss, optimizer=optimizer) net.set_train() build_construct_graph(net, *inputs, execute=True) @@ -200,7 +223,8 @@ def bert_withlossscale_manager_train_feed(): config = get_config(version=version, batch_size=batch_size) netwithloss = BertNetworkWithLoss(config, True) - optimizer = AdamWeightDecayDynamicLR(netwithloss.trainable_params(), 10) + lr = BertLearningRate(10) + optimizer = AdamWeightDecay(netwithloss.trainable_params(), lr) net = ModelBert(netwithloss, optimizer=optimizer) net.set_train() build_construct_graph(net, *inputs, execute=True) diff --git a/tests/st/networks/models/bert/src/config.py b/tests/st/networks/models/bert/src/config.py index 812f0c2f180..0aef2bc8c94 100644 --- a/tests/st/networks/models/bert/src/config.py +++ b/tests/st/networks/models/bert/src/config.py @@ -24,7 +24,7 @@ cfg = edict({ 'scale_factor': 2, 'scale_window': 1000, 'optimizer': 'Lamb', - 'AdamWeightDecayDynamicLR': edict({ + 'AdamWeightDecay': edict({ 'learning_rate': 3e-5, 'end_learning_rate': 1e-10, 'power': 5.0, @@ -33,7 +33,7 @@ cfg = edict({ 'warmup_steps': 10000, }), 'Lamb': edict({ - 'start_learning_rate': 3e-5, + 'learning_rate': 3e-5, 'end_learning_rate': 1e-10, 'power': 10.0, 'warmup_steps': 10000, diff --git a/tests/st/networks/models/bert/src/finetune_config.py b/tests/st/networks/models/bert/src/finetune_config.py index e92842489b9..466676fd2e4 100644 --- a/tests/st/networks/models/bert/src/finetune_config.py +++ b/tests/st/networks/models/bert/src/finetune_config.py @@ -32,7 +32,7 @@ cfg = edict({ 'pre_training_ckpt': '/your/path/pre_training.ckpt', 'use_crf': False, 'optimizer': 'Lamb', - 'AdamWeightDecayDynamicLR': edict({ + 'AdamWeightDecay': edict({ 'learning_rate': 2e-5, 'end_learning_rate': 1e-7, 'power': 1.0, @@ -40,7 +40,7 @@ cfg = edict({ 'eps': 1e-6, }), 'Lamb': edict({ - 'start_learning_rate': 2e-5, + 'learning_rate': 2e-5, 'end_learning_rate': 1e-7, 'power': 1.0, 'decay_filter': lambda x: False, diff --git a/tests/st/networks/models/bert/test_bert_graph_kernel.py b/tests/st/networks/models/bert/test_bert_graph_kernel.py index 4c9673e0767..24aec5084e0 100644 --- a/tests/st/networks/models/bert/test_bert_graph_kernel.py +++ b/tests/st/networks/models/bert/test_bert_graph_kernel.py @@ -29,9 +29,11 @@ from mindspore.nn.optim import Lamb from mindspore.train.callback import Callback from mindspore.train.loss_scale_manager import DynamicLossScaleManager from mindspore.train.model import Model +from mindspore.nn import learning_rate_schedule as lr_schedules from src.bert_for_pre_training import BertNetworkWithLoss, BertTrainOneStepWithLossScaleCell from src.bert_model import BertConfig + DATA_DIR = ["/home/workspace/mindspore_dataset/bert/example/examples.tfrecord"] SCHEMA_DIR = "/home/workspace/mindspore_dataset/bert/example/datasetSchema.json" @@ -111,6 +113,25 @@ def weight_variable(shape): return Tensor(ones) +class BertLearningRate(lr_schedules.LearningRateSchedule): + def __init__(self, learning_rate, end_learning_rate, warmup_steps, decay_steps, power): + super(BertLearningRate, self).__init__() + self.warmup_lr = lr_schedules.WarmUpLR(learning_rate, warmup_steps) + self.decay_lr = lr_schedules.PolynomialDecayLR(learning_rate, end_learning_rate, decay_steps, power) + self.warmup_steps = Tensor(np.array([warmup_steps]).astype(np.float32)) + + self.greater = P.Greater() + self.one = Tensor(np.array([1.0]).astype(np.float32)) + self.cast = P.Cast() + + def construct(self, global_step): + is_warmup = self.cast(self.greater(self.warmup_steps, global_step), mstype.float32) + warmup_lr = self.warmup_lr(global_step) + decay_lr = self.decay_lr(global_step) + lr = (self.one - is_warmup) * decay_lr + is_warmup * warmup_lr + return lr + + class ModelCallback(Callback): def __init__(self): super(ModelCallback, self).__init__() @@ -134,9 +155,15 @@ def test_bert_tdt(): ds = me_de_train_dataset() config = get_config(version='large', batch_size=16) netwithloss = BertNetworkWithLoss(config, True) - optimizer = Lamb(netwithloss.trainable_params(), decay_steps=ds.get_dataset_size()*ds.get_repeat_count(), - start_learning_rate=5e-5, end_learning_rate=1e-9, - power=10.0, warmup_steps=0, weight_decay=0.01) + lr = BertLearningRate(decay_steps=ds.get_dataset_size()*ds.get_repeat_count(), learning_rate=5e-5, + end_learning_rate=1e-9, power=10.0, warmup_steps=0) + decay_filter = lambda x: 'layernorm' not in x.name.lower() and 'bias' not in x.name.lower() + no_decay_filter = lambda x: 'layernorm' in x.name.lower() or 'bias' in x.name.lower() + decay_params = list(filter(decay_filter, net_with_loss.trainable_params())) + other_params = list(filter(no_decay_filter, net_with_loss.trainable_params())) + group_params = [{'params': decay_params, 'weight_decay': 0.01}, + {'params': other_params}] + optimizer = Lamb(group_params, lr) scale_window = 3 scale_manager = DynamicLossScaleManager(262144, 2, scale_window) netwithgrads = BertTrainOneStepWithLossScaleCell(netwithloss, optimizer=optimizer, diff --git a/tests/st/networks/models/bert/test_bert_tdt_lossscale.py b/tests/st/networks/models/bert/test_bert_tdt_lossscale.py index c0b2d3231b1..e33def066d1 100644 --- a/tests/st/networks/models/bert/test_bert_tdt_lossscale.py +++ b/tests/st/networks/models/bert/test_bert_tdt_lossscale.py @@ -33,6 +33,7 @@ from mindspore.nn.optim import Lamb from mindspore.train.callback import Callback from mindspore.train.loss_scale_manager import DynamicLossScaleManager from mindspore.train.model import Model +import mindspore.nn.learning_rate_schedule as lr_schedules _current_dir = os.path.dirname(os.path.realpath(__file__)) DATA_DIR = ["/home/workspace/mindspore_dataset/bert/example/examples.tfrecord"] @@ -125,6 +126,25 @@ def weight_variable(shape): return Tensor(ones) +class BertLearningRate(lr_schedules.LearningRateSchedule): + def __init__(self, learning_rate, end_learning_rate, warmup_steps, decay_steps, power): + super(BertLearningRate, self).__init__() + self.warmup_lr = lr_schedules.WarmUpLR(learning_rate, warmup_steps) + self.decay_lr = lr_schedules.PolynomialDecayLR(learning_rate, end_learning_rate, decay_steps, power) + self.warmup_steps = Tensor(np.array([warmup_steps]).astype(np.float32)) + + self.greater = P.Greater() + self.one = Tensor(np.array([1.0]).astype(np.float32)) + self.cast = P.Cast() + + def construct(self, global_step): + is_warmup = self.cast(self.greater(self.warmup_steps, global_step), mstype.float32) + warmup_lr = self.warmup_lr(global_step) + decay_lr = self.decay_lr(global_step) + lr = (self.one - is_warmup) * decay_lr + is_warmup * warmup_lr + return lr + + class ModelCallback(Callback): def __init__(self): super(ModelCallback, self).__init__() @@ -162,9 +182,16 @@ def test_bert_percision(): batch_size = 16 config = get_config(version=version, batch_size=batch_size) netwithloss = BertNetworkWithLoss(config, True) - optimizer = Lamb(netwithloss.trainable_params(), decay_steps=ds.get_dataset_size()*new_repeat_count, - start_learning_rate=5e-5, end_learning_rate=1e-9, - power=10.0, warmup_steps=0, weight_decay=0.01) + lr = BertLearningRate(decay_steps=ds.get_dataset_size()*new_repeat_count, + learning_rate=5e-5, end_learning_rate=1e-9, + power=10.0, warmup_steps=0) + decay_filter = lambda x: 'layernorm' not in x.name.lower() and 'bias' not in x.name.lower() + no_decay_filter = lambda x: 'layernorm' in x.name.lower() or 'bias' in x.name.lower() + decay_params = list(filter(decay_filter, net_with_loss.trainable_params())) + other_params = list(filter(no_decay_filter, net_with_loss.trainable_params())) + group_params = [{'params': decay_params, 'weight_decay': 0.01}, + {'params': other_params}] + optimizer = Lamb(group_params, lr) scale_window = 3 scale_manager = DynamicLossScaleManager(2 ** 16, 2, scale_window) netwithgrads = BertTrainOneStepWithLossScaleCell(netwithloss, optimizer=optimizer, @@ -220,9 +247,18 @@ def test_bert_performance(): batch_size = 16 config = get_config(version=version, batch_size=batch_size) netwithloss = BertNetworkWithLoss(config, True) - optimizer = Lamb(netwithloss.trainable_params(), decay_steps=ds.get_dataset_size()*new_repeat_count, - start_learning_rate=5e-5, end_learning_rate=1e-9, - power=10.0, warmup_steps=0, weight_decay=0.01) + + lr = BertLearningRate(decay_steps=ds.get_dataset_size()*new_repeat_count, + learning_rate=5e-5, end_learning_rate=1e-9, + power=10.0, warmup_steps=0) + decay_filter = lambda x: 'layernorm' not in x.name.lower() and 'bias' not in x.name.lower() + no_decay_filter = lambda x: 'layernorm' in x.name.lower() or 'bias' in x.name.lower() + decay_params = list(filter(decay_filter, net_with_loss.trainable_params())) + other_params = list(filter(no_decay_filter, net_with_loss.trainable_params())) + group_params = [{'params': decay_params, 'weight_decay': 0.01}, + {'params': other_params}] + optimizer = Lamb(group_params, lr) + scale_window = 3 scale_manager = DynamicLossScaleManager(2 ** 16, 2, scale_window) netwithgrads = BertTrainOneStepWithLossScaleCell(netwithloss, optimizer=optimizer, diff --git a/tests/ut/python/nn/optim/test_adam.py b/tests/ut/python/nn/optim/test_adam.py index 03a73893c50..bebbc008801 100644 --- a/tests/ut/python/nn/optim/test_adam.py +++ b/tests/ut/python/nn/optim/test_adam.py @@ -20,8 +20,10 @@ import mindspore.nn as nn from mindspore import Tensor, Parameter, context from mindspore.common.api import _executor from mindspore.nn import TrainOneStepCell, WithLossCell -from mindspore.nn.optim import Adam, AdamWeightDecay, AdamWeightDecayDynamicLR +from mindspore.nn.optim import Adam, AdamWeightDecay from mindspore.ops import operations as P +import mindspore.nn.learning_rate_schedule as lr_schedules +from mindspore.nn.dynamic_lr import polynomial_decay_lr context.set_context(enable_sparse=True) @@ -112,6 +114,62 @@ def test_sparse_adam_compile(): _executor.compile(train_network, indices, label) +def test_adam_group1(): + """ test_adam_group_lr_and_weight_decay """ + inputs = Tensor(np.ones([1, 64]).astype(np.float32)) + label = Tensor(np.zeros([1, 10]).astype(np.float32)) + net = Net() + net.set_train() + loss = nn.SoftmaxCrossEntropyWithLogits() + net_with_loss = WithLossCell(net, loss) + all_params = net.trainable_params() + + poly_decay_lr = polynomial_decay_lr(0.01, 0.0001, total_step=10, step_per_epoch=1, decay_epoch=3, power=1.0) + + group_params = [{'params': [all_params[0]], 'lr': poly_decay_lr, 'weight_decay': 0.9}, + {'params': [all_params[1]]}] + optimizer = nn.Adam(group_params, learning_rate=0.1) + + train_network = TrainOneStepCell(net_with_loss, optimizer) + _executor.compile(train_network, inputs, label) + + +def test_adam_group2(): + """ test_adam_group_lr_and_weight_decay """ + inputs = Tensor(np.ones([1, 64]).astype(np.float32)) + label = Tensor(np.zeros([1, 10]).astype(np.float32)) + net = Net() + net.set_train() + loss = nn.SoftmaxCrossEntropyWithLogits() + net_with_loss = WithLossCell(net, loss) + all_params = net.trainable_params() + + schedule_lr = lr_schedules.PolynomialDecayLR(0.01, 0.0001, 3, power=1.0) + group_params = [{'params': [all_params[0]], 'lr': 0.02, 'weight_decay': 0.9}, + {'params': [all_params[1]]}] + optimizer = nn.Adam(group_params, learning_rate=schedule_lr) + train_network = TrainOneStepCell(net_with_loss, optimizer) + _executor.compile(train_network, inputs, label) + + +def test_adamweightdecay_group(): + """ test_adam_group_lr_and_weight_decay """ + inputs = Tensor(np.ones([1, 64]).astype(np.float32)) + label = Tensor(np.zeros([1, 10]).astype(np.float32)) + net = Net() + net.set_train() + loss = nn.SoftmaxCrossEntropyWithLogits() + net_with_loss = WithLossCell(net, loss) + all_params = net.trainable_params() + + schedule_lr = lr_schedules.PolynomialDecayLR(0.01, 0.0001, 3, power=1.0) + group_params = [{'params': [all_params[0]], 'lr': 0.02, 'weight_decay': 0.9}, + {'params': [all_params[1]]}] + optimizer = nn.AdamWeightDecay(group_params, learning_rate=schedule_lr) + train_network = TrainOneStepCell(net_with_loss, optimizer) + _executor.compile(train_network, inputs, label) + + def test_AdamWeightDecay_beta1(): net = Net() print("**********", net.get_parameters()) @@ -131,20 +189,6 @@ def test_AdamWeightDecay_e(): AdamWeightDecay(net.get_parameters(), eps=-0.1, learning_rate=0.1) -def test_AdamWeightDecayDynamicLR(): - """ test_AdamWeightDecayDynamicLR """ - inputs = Tensor(np.ones([1, 64]).astype(np.float32)) - label = Tensor(np.zeros([1, 10]).astype(np.float32)) - net = Net() - net.set_train() - loss = nn.SoftmaxCrossEntropyWithLogits() - optimizer = AdamWeightDecayDynamicLR(net.trainable_params(), decay_steps=20, learning_rate=0.1) - - net_with_loss = WithLossCell(net, loss) - train_network = TrainOneStepCell(net_with_loss, optimizer) - _executor.compile(train_network, inputs, label) - - def test_adam_mindspore_with_empty_params(): net = nn.Flatten() with pytest.raises(ValueError, match=r"Optimizer got an empty parameter list"): diff --git a/tests/ut/python/nn/optim/test_lamb.py b/tests/ut/python/nn/optim/test_lamb.py index 4d229f0837d..b2963fc9501 100644 --- a/tests/ut/python/nn/optim/test_lamb.py +++ b/tests/ut/python/nn/optim/test_lamb.py @@ -14,7 +14,6 @@ # ============================================================================ """ test lamb """ import numpy as np -import pytest import mindspore.nn as nn from mindspore import Tensor, Parameter @@ -22,6 +21,27 @@ from mindspore.common.api import _executor from mindspore.nn import TrainOneStepCell, WithLossCell from mindspore.nn.optim import Lamb from mindspore.ops import operations as P +import mindspore.common.dtype as mstype +from mindspore.nn.learning_rate_schedule import LearningRateSchedule, PolynomialDecayLR, WarmUpLR + + +class LambLearningRate(LearningRateSchedule): + def __init__(self, learning_rate, end_learning_rate, warmup_steps, decay_steps, power): + super(LambLearningRate, self).__init__() + self.warmup_lr = WarmUpLR(learning_rate, warmup_steps) + self.decay_lr = PolynomialDecayLR(learning_rate, end_learning_rate, decay_steps, power) + self.warmup_steps = Tensor(np.array([warmup_steps]).astype(np.float32)) + + self.greater = P.Greater() + self.one = Tensor(np.array([1.0]).astype(np.float32)) + self.cast = P.Cast() + + def construct(self, global_step): + is_warmup = self.cast(self.greater(self.warmup_steps, global_step), mstype.float32) + warmup_lr = self.warmup_lr(global_step) + decay_lr = self.decay_lr(global_step) + lr = (self.one - is_warmup) * decay_lr + is_warmup * warmup_lr + return lr class Net(nn.Cell): @@ -51,6 +71,21 @@ class NetWithoutWeight(nn.Cell): return x +def test_lamb_compile_dynamic_lr(): + """ test_Lamb_compile """ + inputs = Tensor(np.ones([1, 64]).astype(np.float32)) + label = Tensor(np.zeros([1, 10]).astype(np.float32)) + net = Net() + net.set_train() + loss = nn.SoftmaxCrossEntropyWithLogits() + warmup_decay_lr = LambLearningRate(0.01, 0.0001, 10, 20, 1.0) + optimizer = Lamb(net.trainable_params(), warmup_decay_lr) + + net_with_loss = WithLossCell(net, loss) + train_network = TrainOneStepCell(net_with_loss, optimizer) + _executor.compile(train_network, inputs, label) + + def test_lamb_compile(): """ test_Lamb_compile """ inputs = Tensor(np.ones([1, 64]).astype(np.float32)) @@ -58,20 +93,27 @@ def test_lamb_compile(): net = Net() net.set_train() loss = nn.SoftmaxCrossEntropyWithLogits() - optimizer = Lamb(net.trainable_params(), decay_steps=10) + + optimizer = Lamb(net.trainable_params(), 0.02, 0.9) net_with_loss = WithLossCell(net, loss) train_network = TrainOneStepCell(net_with_loss, optimizer) _executor.compile(train_network, inputs, label) -def test_lamb_error(): +def test_lamb_group(): + """ test_Lamb_group_compile """ + inputs = Tensor(np.ones([1, 64]).astype(np.float32)) + label = Tensor(np.zeros([1, 10]).astype(np.float32)) net = Net() - with pytest.raises(TypeError): - Lamb(net.get_parameters(), decay_steps=6, warmup_steps=5.0) + net.set_train() + loss = nn.SoftmaxCrossEntropyWithLogits() + warmup_decay_lr = LambLearningRate(0.01, 0.0001, 10, 20, 1.0) + all_params = net.trainable_params() + group_params = [{'params': [all_params[0]], 'lr': warmup_decay_lr, 'weight_decay': 0.9}, + {'params': [all_params[1]]}] + optimizer = Lamb(group_params, 0.02) - with pytest.raises(TypeError): - Lamb(net.get_parameters(), decay_steps=1.0) - - with pytest.raises(ValueError): - Lamb(net.get_parameters(), decay_steps=0) + net_with_loss = WithLossCell(net, loss) + train_network = TrainOneStepCell(net_with_loss, optimizer) + _executor.compile(train_network, inputs, label) diff --git a/tests/ut/python/nn/optim/test_optimizer.py b/tests/ut/python/nn/optim/test_optimizer.py index 70b79e97d7b..32d9c5b4fe2 100644 --- a/tests/ut/python/nn/optim/test_optimizer.py +++ b/tests/ut/python/nn/optim/test_optimizer.py @@ -18,7 +18,7 @@ import pytest from mindspore import Tensor from mindspore.common.parameter import Parameter -from mindspore.nn.optim import Optimizer, SGD, Adam, AdamWeightDecay, AdamWeightDecayDynamicLR +from mindspore.nn.optim import Optimizer, SGD, Adam, AdamWeightDecay class IterableObjc: @@ -81,10 +81,6 @@ class TestNullParam(): with pytest.raises(ValueError): AdamWeightDecay(None) - def test_AdamWeightDecayDynamicLR_init(self): - with pytest.raises(ValueError): - AdamWeightDecayDynamicLR(None, 10) - def test_Sgd_init(self): with pytest.raises(ValueError): SGD(None) @@ -101,10 +97,6 @@ class TestUnsupportParam(): with pytest.raises(TypeError): AdamWeightDecay(9) - def test_AdamWeightDecayDynamicLR_init(self): - with pytest.raises(TypeError): - AdamWeightDecayDynamicLR(0.5, 10) - def test_Sgd_init(self): with pytest.raises(TypeError): paramsTensor = Parameter(Tensor(np.zeros([1, 2, 3])), "x") diff --git a/tests/ut/python/nn/optim/test_proximal_ada_grad.py b/tests/ut/python/nn/optim/test_proximal_ada_grad.py index 3077896fed5..d88c55fd700 100644 --- a/tests/ut/python/nn/optim/test_proximal_ada_grad.py +++ b/tests/ut/python/nn/optim/test_proximal_ada_grad.py @@ -37,6 +37,7 @@ class Net(nn.Cell): x = self.biasAdd(self.matmul(x, self.weight), self.bias) return x + class NetWithSparseGatherV2(nn.Cell): """ NetWithSparseGatherV2 definition """ def __init__(self): diff --git a/tests/ut/python/nn/test_dynamic_lr.py b/tests/ut/python/nn/test_dynamic_lr.py index c53f28d5f7f..44a803a2196 100644 --- a/tests/ut/python/nn/test_dynamic_lr.py +++ b/tests/ut/python/nn/test_dynamic_lr.py @@ -28,7 +28,7 @@ decay_epoch = 2 min_lr = 0.01 max_lr = 0.1 power = 0.5 - +warmup_epoch = 2 class TestInputs: def test_milestone1(self): @@ -234,3 +234,8 @@ def test_polynomial_decay(): lr2 = dr.polynomial_decay_lr(learning_rate, end_learning_rate, total_step, step_per_epoch, decay_epoch, power, True) assert len(lr2) == total_step + + +def test_warmup(): + lr1 = dr.warmup_lr(learning_rate, total_step, step_per_epoch, warmup_epoch) + assert len(lr1) == total_step diff --git a/tests/ut/python/nn/test_learning_rate_schedule.py b/tests/ut/python/nn/test_learning_rate_schedule.py new file mode 100644 index 00000000000..74f261a02ed --- /dev/null +++ b/tests/ut/python/nn/test_learning_rate_schedule.py @@ -0,0 +1,157 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" Test Dynamic Learning Rate """ +import pytest + +from mindspore import Tensor, Parameter +from mindspore.nn import learning_rate_schedule as lr_schedules +from mindspore.common.api import _executor +import mindspore.common.dtype as mstype + + +learning_rate = 0.1 +end_learning_rate = 0.01 +decay_rate = 0.9 +decay_steps = 4 +warmup_steps = 2 +min_lr = 0.01 +max_lr = 0.1 +power = 0.5 +global_step = Parameter(Tensor(2, mstype.int32), 'global_step') + + +class TestInit: + def test_learning_rate_type(self): + lr = True + with pytest.raises(TypeError): + lr_schedules.ExponentialDecayLR(lr, decay_rate, decay_steps) + + with pytest.raises(TypeError): + lr_schedules.PolynomialDecayLR(lr, end_learning_rate, decay_steps, power) + + def test_learning_rate_value(self): + lr = -1.0 + with pytest.raises(ValueError): + lr_schedules.ExponentialDecayLR(lr, decay_rate, decay_steps) + + with pytest.raises(ValueError): + lr_schedules.PolynomialDecayLR(lr, end_learning_rate, decay_steps, power) + + def test_end_learning_rate_type(self): + lr = True + with pytest.raises(TypeError): + lr_schedules.PolynomialDecayLR(learning_rate, lr, decay_steps, power) + + def test_end_learning_rate_value(self): + lr = -1.0 + with pytest.raises(ValueError): + lr_schedules.PolynomialDecayLR(learning_rate, lr, decay_steps, power) + + def test_decay_rate_type(self): + rate = 'a' + with pytest.raises(TypeError): + lr_schedules.ExponentialDecayLR(learning_rate, rate, decay_steps) + + def test_decay_rate_value(self): + rate = -1.0 + with pytest.raises(ValueError): + lr_schedules.ExponentialDecayLR(learning_rate, rate, decay_steps) + + def test_decay_steps_type(self): + decay_steps_e = 'm' + with pytest.raises(TypeError): + lr_schedules.ExponentialDecayLR(learning_rate, decay_rate, decay_steps_e) + + with pytest.raises(TypeError): + lr_schedules.CosineDecayLR(min_lr, max_lr, decay_steps_e) + + with pytest.raises(TypeError): + lr_schedules.PolynomialDecayLR(learning_rate, end_learning_rate, decay_steps_e, power) + + def test_decay_steps_value(self): + decay_steps_e = -2 + with pytest.raises(ValueError): + lr_schedules.ExponentialDecayLR(learning_rate, decay_rate, decay_steps_e) + + with pytest.raises(ValueError): + lr_schedules.CosineDecayLR(min_lr, max_lr, decay_steps_e) + + with pytest.raises(ValueError): + lr_schedules.PolynomialDecayLR(learning_rate, end_learning_rate, decay_steps_e, power) + + def test_is_stair(self): + is_stair = 1 + with pytest.raises(TypeError): + lr_schedules.ExponentialDecayLR(learning_rate, decay_rate, decay_steps, is_stair) + + def test_min_lr_type(self): + min_lr1 = True + with pytest.raises(TypeError): + lr_schedules.CosineDecayLR(min_lr1, max_lr, decay_steps) + + def test_min_lr_value(self): + min_lr1 = -1.0 + with pytest.raises(ValueError): + lr_schedules.CosineDecayLR(min_lr1, max_lr, decay_steps) + + def test_max_lr_type(self): + max_lr1 = 'a' + with pytest.raises(TypeError): + lr_schedules.CosineDecayLR(min_lr, max_lr1, decay_steps) + + def test_max_lr_value(self): + max_lr1 = -1.0 + with pytest.raises(ValueError): + lr_schedules.CosineDecayLR(min_lr, max_lr1, decay_steps) + + def test_power(self): + power1 = True + with pytest.raises(TypeError): + lr_schedules.PolynomialDecayLR(learning_rate, end_learning_rate, decay_steps, power1) + + +def test_exponential_decay(): + lr_schedule = lr_schedules.ExponentialDecayLR(learning_rate, decay_rate, decay_steps, True) + _executor.compile(lr_schedule, global_step) + + +def test_enatural_exp_decay(): + lr_schedule = lr_schedules.NaturalExpDecayLR(learning_rate, decay_rate, decay_steps, True) + _executor.compile(lr_schedule, global_step) + + +def test_inverse_decay(): + lr_schedule = lr_schedules.InverseDecayLR(learning_rate, decay_rate, decay_steps, True) + _executor.compile(lr_schedule, global_step) + + +def test_cosine_decay(): + lr_schedule = lr_schedules.CosineDecayLR(min_lr, max_lr, decay_steps) + _executor.compile(lr_schedule, global_step) + + +def test_polynomial_decay(): + lr_schedule = lr_schedules.PolynomialDecayLR(learning_rate, end_learning_rate, decay_steps, power) + _executor.compile(lr_schedule, global_step) + + +def test_polynomial_decay2(): + lr_schedule = lr_schedules.PolynomialDecayLR(learning_rate, end_learning_rate, decay_steps, power, True) + _executor.compile(lr_schedule, global_step) + + +def test_warmup(): + lr_schedule = lr_schedules.WarmUpLR(learning_rate, warmup_steps) + _executor.compile(lr_schedule, global_step) diff --git a/tests/ut/python/optimizer/test_optimizer_with_loss_scale.py b/tests/ut/python/optimizer/test_optimizer_with_loss_scale.py index ea60f1f09b8..6f77c4a3614 100644 --- a/tests/ut/python/optimizer/test_optimizer_with_loss_scale.py +++ b/tests/ut/python/optimizer/test_optimizer_with_loss_scale.py @@ -152,7 +152,7 @@ def test_compile_fp16_overflow(): net = NetFP16(16, 16) loss = MSELoss() - optimizer = Lamb(net.trainable_params(), decay_steps=10, warmup_steps=5) + optimizer = Lamb(net.trainable_params(), learning_rate=0.01) net_with_loss = WithLossCell(net, loss) train_network = TrainOneStepWithLossScaleCell(net_with_loss, optimizer) train_network.set_train() diff --git a/tests/ut/python/optimizer/test_optimizer_with_parameter_groups.py b/tests/ut/python/optimizer/test_optimizer_with_parameter_groups.py index 2f93eb61867..19b82e16e45 100644 --- a/tests/ut/python/optimizer/test_optimizer_with_parameter_groups.py +++ b/tests/ut/python/optimizer/test_optimizer_with_parameter_groups.py @@ -104,9 +104,11 @@ def test_group_dynamic_1(): assert opt.is_group_params_ordered is True for lr, param, order_param in zip(opt.learning_rate, opt.parameters, net.trainable_params()): if param in conv_params: - assert np.all(lr.data.asnumpy() == Tensor(np.array([conv_lr] * 3).astype(np.float32)).asnumpy()) + assert np.all(lr.learning_rate.data.asnumpy() == \ + Tensor(np.array([conv_lr] * 3).astype(np.float32)).asnumpy()) else: - assert np.all(lr.data.asnumpy() == Tensor(np.array(list(default_lr)).astype(np.float32)).asnumpy()) + assert np.all(lr.learning_rate.data.asnumpy() == \ + Tensor(np.array(list(default_lr)).astype(np.float32)).asnumpy()) assert param.name == order_param.name @@ -134,9 +136,11 @@ def test_group_dynamic_2(): assert opt.dynamic_lr is True for lr, param in zip(opt.learning_rate, opt.parameters): if param in conv_params: - assert np.all(lr.data.asnumpy() == Tensor(np.array(list(conv_lr)).astype(np.float32)).asnumpy()) + assert np.all(lr.learning_rate.data.asnumpy() == \ + Tensor(np.array(list(conv_lr)).astype(np.float32)).asnumpy()) else: - assert np.all(lr.data.asnumpy() == Tensor(np.array([default_lr] * 3).astype(np.float32)).asnumpy()) + assert np.all(lr.learning_rate.data.asnumpy() == \ + Tensor(np.array([default_lr] * 3).astype(np.float32)).asnumpy()) net_with_loss = WithLossCell(net, loss) train_network = TrainOneStepCell(net_with_loss, opt) @@ -157,7 +161,7 @@ def test_group_dynamic_no_same_size(): def test_group_not_float_lr(): net = LeNet5() - conv_lr = 1 + conv_lr = np.array(1) default_lr = 0.3 conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params())) no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params())) @@ -169,7 +173,7 @@ def test_group_not_float_lr(): def test_group_not_float_weight_decay(): net = LeNet5() - conv_weight_decay = 1 + conv_weight_decay = np.array(1) conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params())) no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params())) group_params = [{'params': conv_params, 'weight_decay': conv_weight_decay}, @@ -238,11 +242,15 @@ def test_get_lr_parameter_with_group(): assert opt.is_group_lr is True for param in opt.parameters: lr = opt.get_lr_parameter(param) - assert lr.name == 'lr_' + param.name + if 'conv' in param.name: + cur_name = 'learning_rate_group_' + '0' + else: + cur_name = 'learning_rate_group_' + '1' + assert lr.name == cur_name lr_list = opt.get_lr_parameter(conv_params) for lr, param in zip(lr_list, conv_params): - assert lr.name == 'lr_' + param.name + assert lr.name == 'learning_rate_group_' + '0' def test_get_lr_parameter_with_order_group(): @@ -256,7 +264,11 @@ def test_get_lr_parameter_with_order_group(): assert opt.is_group_lr is True for param in opt.parameters: lr = opt.get_lr_parameter(param) - assert lr.name == 'lr_' + param.name + if 'conv' in param.name: + cur_name = 'learning_rate_group_' + '0' + else: + cur_name = 'learning_rate' + assert lr.name == cur_name def test_get_lr_parameter_with_no_group(): @@ -271,7 +283,7 @@ def test_get_lr_parameter_with_no_group(): assert opt.is_group_lr is False for param in opt.parameters: lr = opt.get_lr_parameter(param) - assert lr.name == opt.learning_rate.name + assert lr.name == 'learning_rate' params_error = [1, 2, 3] with pytest.raises(TypeError): @@ -305,7 +317,11 @@ def test_order_params_1(): assert decay_flags is False assert param.name == order_param.name - assert lr.name == 'lr_' + param.name + if 'conv' in param.name: + assert lr.name == 'learning_rate' + elif 'bias' in param.name: + assert lr.name == 'learning_rate_group_' + '1' + def test_order_params_2(): @@ -323,8 +339,9 @@ def test_order_params_2(): assert opt.is_group is True assert opt.is_group_lr is True assert opt.is_group_params_ordered is True + all_lr = opt.get_lr_parameter(fc1_params+conv_params) for weight_decay, decay_flags, lr, param, order_param in zip( - opt.weight_decay, opt.decay_flags, opt.learning_rate, opt.parameters, fc1_params+conv_params): + opt.weight_decay, opt.decay_flags, all_lr, opt.parameters, fc1_params+conv_params): if param in conv_params: assert np.all(lr.data.asnumpy() == Tensor(np.array([default_lr] * 3), mstype.float32).asnumpy()) assert weight_decay == conv_weight_decay @@ -339,8 +356,10 @@ def test_order_params_2(): assert decay_flags is False assert param.name == order_param.name - assert lr.name == 'lr_' + param.name - + if 'conv' in param.name: + assert lr.name == 'learning_rate' + elif 'fc1' in param.name: + assert lr.name == 'learning_rate_group_' + '0' def test_get_order_params_with_not_same(): net = LeNet5() diff --git a/tests/ut/python/parallel/test_parallel_optimizer.py b/tests/ut/python/parallel/test_parallel_optimizer.py index ee9291fb98a..f6173b24c40 100644 --- a/tests/ut/python/parallel/test_parallel_optimizer.py +++ b/tests/ut/python/parallel/test_parallel_optimizer.py @@ -20,7 +20,7 @@ import mindspore.nn as nn from mindspore import Tensor from mindspore.common.api import _executor from mindspore.nn import TrainOneStepCell, WithLossCell -from mindspore.nn.optim import Adam, AdamWeightDecay, AdamWeightDecayDynamicLR, Lamb +from mindspore.nn.optim import Adam, AdamWeightDecay, Lamb from mindspore.ops import operations as P from mindspore import context @@ -51,23 +51,8 @@ class Net(nn.Cell): return s -def test_AdamWeightDecayDynamicLR(): - """ test_AdamWeightDecayDynamicLR """ - context.set_auto_parallel_context(parallel_mode="data_parallel", device_num=2, enable_parallel_optimizer=True) - inputs = Tensor(np.ones([32, 128]).astype(np.float32)) - label = Tensor(np.zeros([32, 768]).astype(np.float32)) - net = Net() - net.set_train() - loss = nn.SoftmaxCrossEntropyWithLogits() - optimizer = AdamWeightDecayDynamicLR(net.trainable_params(), decay_steps=20, learning_rate=0.1) - - net_with_loss = WithLossCell(net, loss) - train_network = TrainOneStepCell(net_with_loss, optimizer) - _executor.compile(train_network, inputs, label) - - def test_AdamWeightDecay(): - """ test_AdamWeightDecayDynamicLR """ + """ test_AdamWeightDecay """ context.set_auto_parallel_context(parallel_mode="data_parallel", device_num=2, enable_parallel_optimizer=True) inputs = Tensor(np.ones([32, 128]).astype(np.float32)) label = Tensor(np.zeros([32, 768]).astype(np.float32)) @@ -89,7 +74,7 @@ def test_lamb_compile(): net = Net() net.set_train() loss = nn.SoftmaxCrossEntropyWithLogits() - optimizer = Lamb(net.trainable_params(), decay_steps=10) + optimizer = Lamb(net.trainable_params(), learning_rate=0.1) net_with_loss = WithLossCell(net, loss) train_network = TrainOneStepCell(net_with_loss, optimizer) @@ -102,9 +87,9 @@ def test_edge_case(): net = Net() with pytest.raises(RuntimeError): context.set_auto_parallel_context(parallel_mode="stand_alone") - Lamb(net.trainable_params(), decay_steps=10) + Lamb(net.trainable_params(), learning_rate=0.1) with pytest.raises(RuntimeError): Adam(net.trainable_params(), learning_rate=0.1) with pytest.raises(RuntimeError): context.set_auto_parallel_context(device_num=16) - Lamb(net.trainable_params(), decay_steps=10) + Lamb(net.trainable_params(), learning_rate=0.1)