From af4923123ca07c759faa4d79e18cc9b5c8e7758d Mon Sep 17 00:00:00 2001 From: shibeiji Date: Wed, 22 Jul 2020 18:06:02 +0800 Subject: [PATCH] script update for bert --- model_zoo/official/nlp/bert/run_pretrain.py | 6 ++---- .../official/nlp/bert/src/bert_for_finetune.py | 6 +++--- .../official/nlp/bert/src/bert_for_pre_training.py | 6 +++--- model_zoo/official/nlp/bert/src/utils.py | 14 ++++++++++---- 4 files changed, 18 insertions(+), 14 deletions(-) diff --git a/model_zoo/official/nlp/bert/run_pretrain.py b/model_zoo/official/nlp/bert/run_pretrain.py index bab9fa2f583..749f6f6236e 100644 --- a/model_zoo/official/nlp/bert/run_pretrain.py +++ b/model_zoo/official/nlp/bert/run_pretrain.py @@ -117,8 +117,7 @@ def run_pretrain(): decay_params = list(filter(cfg.Lamb.decay_filter, params)) other_params = list(filter(lambda x: x not in decay_params, params)) group_params = [{'params': decay_params, 'weight_decay': cfg.Lamb.weight_decay}, - {'params': other_params}, - {'order_params': params}] + {'params': other_params}] optimizer = Lamb(group_params, learning_rate=lr_schedule, eps=cfg.Lamb.eps) elif cfg.optimizer == 'Momentum': optimizer = Momentum(net_with_loss.trainable_params(), learning_rate=cfg.Momentum.learning_rate, @@ -133,8 +132,7 @@ def run_pretrain(): decay_params = list(filter(cfg.AdamWeightDecay.decay_filter, params)) other_params = list(filter(lambda x: x not in decay_params, params)) group_params = [{'params': decay_params, 'weight_decay': cfg.AdamWeightDecay.weight_decay}, - {'params': other_params, 'weight_decay': 0.0}, - {'order_params': params}] + {'params': other_params, 'weight_decay': 0.0}] optimizer = AdamWeightDecay(group_params, learning_rate=lr_schedule, eps=cfg.AdamWeightDecay.eps) else: diff --git a/model_zoo/official/nlp/bert/src/bert_for_finetune.py b/model_zoo/official/nlp/bert/src/bert_for_finetune.py index 32ac0823b97..5fbf1d81b9b 100644 --- a/model_zoo/official/nlp/bert/src/bert_for_finetune.py +++ b/model_zoo/official/nlp/bert/src/bert_for_finetune.py @@ -22,7 +22,7 @@ from mindspore.ops import operations as P from mindspore.ops import functional as F from mindspore.ops import composite as C from mindspore.common.tensor import Tensor -from mindspore.common.parameter import Parameter, ParameterTuple +from mindspore.common.parameter import Parameter from mindspore.common import dtype as mstype from mindspore.nn.wrap.grad_reducer import DistributedGradReducer from mindspore.train.parallel_utils import ParallelMode @@ -55,7 +55,7 @@ class BertFinetuneCell(nn.Cell): super(BertFinetuneCell, self).__init__(auto_prefix=False) self.network = network - self.weights = ParameterTuple(network.trainable_params()) + self.weights = optimizer.parameters self.optimizer = optimizer self.grad = C.GradOperation('grad', get_by_list=True, @@ -158,7 +158,7 @@ class BertSquadCell(nn.Cell): def __init__(self, network, optimizer, scale_update_cell=None): super(BertSquadCell, self).__init__(auto_prefix=False) self.network = network - self.weights = ParameterTuple(network.trainable_params()) + self.weights = optimizer.parameters self.optimizer = optimizer self.grad = C.GradOperation('grad', get_by_list=True, sens_param=True) self.reducer_flag = False diff --git a/model_zoo/official/nlp/bert/src/bert_for_pre_training.py b/model_zoo/official/nlp/bert/src/bert_for_pre_training.py index 802391ee861..1d12ddaf061 100644 --- a/model_zoo/official/nlp/bert/src/bert_for_pre_training.py +++ b/model_zoo/official/nlp/bert/src/bert_for_pre_training.py @@ -21,7 +21,7 @@ from mindspore.ops import operations as P from mindspore.ops import functional as F from mindspore.ops import composite as C from mindspore.common.tensor import Tensor -from mindspore.common.parameter import Parameter, ParameterTuple +from mindspore.common.parameter import Parameter from mindspore.common import dtype as mstype from mindspore.nn.wrap.grad_reducer import DistributedGradReducer from mindspore.train.parallel_utils import ParallelMode @@ -270,7 +270,7 @@ class BertTrainOneStepCell(nn.Cell): def __init__(self, network, optimizer, sens=1.0): super(BertTrainOneStepCell, self).__init__(auto_prefix=False) self.network = network - self.weights = ParameterTuple(network.trainable_params()) + self.weights = optimizer.parameters self.optimizer = optimizer self.grad = C.GradOperation('grad', get_by_list=True, sens_param=True) self.sens = sens @@ -349,7 +349,7 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell): def __init__(self, network, optimizer, scale_update_cell=None): super(BertTrainOneStepWithLossScaleCell, self).__init__(auto_prefix=False) self.network = network - self.weights = ParameterTuple(network.trainable_params()) + self.weights = optimizer.parameters self.optimizer = optimizer self.grad = C.GradOperation('grad', get_by_list=True, diff --git a/model_zoo/official/nlp/bert/src/utils.py b/model_zoo/official/nlp/bert/src/utils.py index 6e8ea6ed643..71d37039e90 100644 --- a/model_zoo/official/nlp/bert/src/utils.py +++ b/model_zoo/official/nlp/bert/src/utils.py @@ -133,7 +133,10 @@ class BertLearningRate(LearningRateSchedule): """ def __init__(self, learning_rate, end_learning_rate, warmup_steps, decay_steps, power): super(BertLearningRate, self).__init__() - self.warmup_lr = WarmUpLR(learning_rate, warmup_steps) + self.warmup_flag = False + if warmup_steps > 0: + self.warmup_flag = True + self.warmup_lr = WarmUpLR(learning_rate, warmup_steps) self.decay_lr = PolynomialDecayLR(learning_rate, end_learning_rate, decay_steps, power) self.warmup_steps = Tensor(np.array([warmup_steps]).astype(np.float32)) @@ -142,8 +145,11 @@ class BertLearningRate(LearningRateSchedule): self.cast = P.Cast() def construct(self, global_step): - is_warmup = self.cast(self.greater(self.warmup_steps, global_step), mstype.float32) - warmup_lr = self.warmup_lr(global_step) decay_lr = self.decay_lr(global_step) - lr = (self.one - is_warmup) * decay_lr + is_warmup * warmup_lr + if self.warmup_flag: + is_warmup = self.cast(self.greater(self.warmup_steps, global_step), mstype.float32) + warmup_lr = self.warmup_lr(global_step) + lr = (self.one - is_warmup) * decay_lr + is_warmup * warmup_lr + else: + lr = decay_lr return lr