From 8f8eee4b5e63b5128d915eeb92d2a55f0fc195b7 Mon Sep 17 00:00:00 2001
From: mwang
Date: Thu, 4 Feb 2021 11:32:35 +0800
Subject: [PATCH] bert thor supports lr configuration in config.py

---
 mindspore/train/train_thor/convert_utils.py |  2 ++
 model_zoo/official/cv/resnet/src/config.py  |  2 +-
 model_zoo/official/cv/resnet/train.py       |  2 +-
 model_zoo/official/nlp/bert/run_pretrain.py |  5 ++--
 model_zoo/official/nlp/bert/src/config.py   |  8 ++++++
 model_zoo/official/nlp/bert/src/utils.py    | 27 +++++++--------------
 6 files changed, 24 insertions(+), 22 deletions(-)

diff --git a/mindspore/train/train_thor/convert_utils.py b/mindspore/train/train_thor/convert_utils.py
index ac7c93a3201..4714dcc673d 100644
--- a/mindspore/train/train_thor/convert_utils.py
+++ b/mindspore/train/train_thor/convert_utils.py
@@ -40,6 +40,8 @@ class ConvertNetUntils():
         if subcell.activation_flag:
             act_class = subcell.activation.__class__.__name__
             act_name = act_class.lower()
+            if act_name == "fastgelu":
+                act_name = "fast_gelu"
         if subcell.out_channels == 1001:
             new_subcell = nn.Dense_Thor(in_channels=subcell.in_channels,
                                         out_channels=subcell.out_channels,
diff --git a/model_zoo/official/cv/resnet/src/config.py b/model_zoo/official/cv/resnet/src/config.py
index 3ebcf54251e..e64427a3fd5 100755
--- a/model_zoo/official/cv/resnet/src/config.py
+++ b/model_zoo/official/cv/resnet/src/config.py
@@ -18,7 +18,7 @@ network config setting, will be used in train.py and eval.py
 from easydict import EasyDict as ed
 # config optimizer for resnet50, imagenet2012. Momentum is default, Thor is optional.
 cfg = ed({
-    'optimizer': 'Thor',
+    'optimizer': 'Momentum',
 })
 
 # config for resent50, cifar10
diff --git a/model_zoo/official/cv/resnet/train.py b/model_zoo/official/cv/resnet/train.py
index 58806d5e477..f56b5d29f0d 100755
--- a/model_zoo/official/cv/resnet/train.py
+++ b/model_zoo/official/cv/resnet/train.py
@@ -18,7 +18,7 @@ import argparse
 import ast
 from mindspore import context
 from mindspore import Tensor
-from mindspore.nn.optim.momentum import Momentum, THOR
+from mindspore.nn.optim import Momentum, THOR
 from mindspore.train.model import Model
 from mindspore.context import ParallelMode
 from mindspore.train.train_thor import ConvertModelUtils
diff --git a/model_zoo/official/nlp/bert/run_pretrain.py b/model_zoo/official/nlp/bert/run_pretrain.py
index 7b146b46da2..0338fe196ac 100644
--- a/model_zoo/official/nlp/bert/run_pretrain.py
+++ b/model_zoo/official/nlp/bert/run_pretrain.py
@@ -100,8 +100,9 @@ def _get_optimizer(args_opt, network):
         optimizer = AdamWeightDecay(group_params, learning_rate=lr_schedule, eps=cfg.AdamWeightDecay.eps)
     elif cfg.optimizer == "Thor":
         from src.utils import get_bert_thor_lr, get_bert_thor_damping
-        lr = get_bert_thor_lr()
-        damping = get_bert_thor_damping()
+        lr = get_bert_thor_lr(cfg.Thor.lr_max, cfg.Thor.lr_min, cfg.Thor.lr_power, cfg.Thor.lr_total_steps)
+        damping = get_bert_thor_damping(cfg.Thor.damping_max, cfg.Thor.damping_min, cfg.Thor.damping_power,
+                                        cfg.Thor.damping_total_steps)
         split_indices = None
         if bert_net_cfg.num_hidden_layers == 12:
             if bert_net_cfg.use_relative_positions:
diff --git a/model_zoo/official/nlp/bert/src/config.py b/model_zoo/official/nlp/bert/src/config.py
index 14ed15750a2..b1d757f1f1f 100644
--- a/model_zoo/official/nlp/bert/src/config.py
+++ b/model_zoo/official/nlp/bert/src/config.py
@@ -49,6 +49,14 @@ cfg = edict({
         'momentum': 0.9,
     }),
     'Thor': edict({
+        'lr_max': 0.0034,
+        'lr_min': 3.244e-5,
+        'lr_power': 1.0,
+        'lr_total_steps': 30000,
+        'damping_max': 5e-2,
+        'damping_min': 1e-6,
+        'damping_power': 1.0,
+        'damping_total_steps': 30000,
         'momentum': 0.9,
         'weight_decay': 5e-4,
         'loss_scale': 1.0,
diff --git a/model_zoo/official/nlp/bert/src/utils.py b/model_zoo/official/nlp/bert/src/utils.py
index 20147a8b22c..75a1223cf72 100644
--- a/model_zoo/official/nlp/bert/src/utils.py
+++ b/model_zoo/official/nlp/bert/src/utils.py
@@ -22,7 +22,6 @@ import math
 import collections
 import numpy as np
 import mindspore.nn as nn
-from mindspore import context
 from mindspore import log as logger
 from mindspore.ops import operations as P
 from mindspore.common.tensor import Tensor
@@ -107,10 +106,11 @@ class LossCallBack(Callback):
                 percent = 1
                 epoch_num -= 1
             print("epoch: {}, current epoch percent: {}, step: {}, outputs are {}"
-                  .format(int(epoch_num), "%.3f" % percent, cb_params.cur_step_num, str(cb_params.net_outputs)))
+                  .format(int(epoch_num), "%.3f" % percent, cb_params.cur_step_num, str(cb_params.net_outputs)),
+                  flush=True)
         else:
             print("epoch: {}, step: {}, outputs are {}".format(cb_params.cur_epoch_num, cb_params.cur_step_num,
-                                                               str(cb_params.net_outputs)))
+                                                               str(cb_params.net_outputs)), flush=True)
 
 def LoadNewestCkpt(load_finetune_checkpoint_dir, steps_per_epoch, epoch_num, prefix):
     """
@@ -220,22 +220,13 @@ def _get_poly_lr(global_step, lr_init, lr_end, lr_max, warmup_steps, total_steps
     return learning_rate
 
 
-def get_bert_thor_lr():
-    if context.get_context("device_target") == "Ascend":
-        learning_rate = _get_poly_lr(global_step=0, lr_init=0.0, lr_end=3.244018779068399e-05,
-                                     lr_max=0.0034022148941459055, warmup_steps=0, total_steps=30000, poly_power=1)
-    else:
-        learning_rate = _get_poly_lr(global_step=0, lr_init=0.0, lr_end=1e-6, lr_max=1.7e-3, warmup_steps=0,
-                                     total_steps=30000, poly_power=1)
-
+def get_bert_thor_lr(lr_max=0.0034, lr_min=3.244e-05, lr_power=1.0, lr_total_steps=30000):
+    learning_rate = _get_poly_lr(global_step=0, lr_init=0.0, lr_end=lr_min, lr_max=lr_max, warmup_steps=0,
+                                 total_steps=lr_total_steps, poly_power=lr_power)
     return Tensor(learning_rate)
 
 
-def get_bert_thor_damping():
-    if context.get_context("device_target") == "Ascend":
-        damping = _get_poly_lr(global_step=0, lr_init=0.0, lr_end=1e-6, lr_max=5e-2, warmup_steps=0, total_steps=30000,
-                               poly_power=1)
-    else:
-        damping = _get_poly_lr(global_step=0, lr_init=0.0, lr_end=1e-6, lr_max=3.5e-2, warmup_steps=0,
-                               total_steps=30000, poly_power=1)
+def get_bert_thor_damping(damping_max=5e-2, damping_min=1e-6, damping_power=1.0, damping_total_steps=30000):
+    damping = _get_poly_lr(global_step=0, lr_init=0.0, lr_end=damping_min, lr_max=damping_max, warmup_steps=0,
+                           total_steps=damping_total_steps, poly_power=damping_power)
    return Tensor(damping)
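
Note on the convert_utils.py hunk: the THOR conversion derives a lookup key by
lowercasing the activation's class name, and FastGelu lowercases to "fastgelu"
while the Dense_Thor path expects "fast_gelu". A minimal sketch of that
normalization; ACTIVATION_KEYS and activation_key are hypothetical stand-ins
for illustration, not MindSpore API:

    # Hypothetical illustration of the rename the patch adds.
    ACTIVATION_KEYS = {"relu", "gelu", "fast_gelu"}

    class FastGelu:                               # stand-in for nn.FastGelu
        pass

    def activation_key(act):
        name = act.__class__.__name__.lower()     # FastGelu -> "fastgelu"
        if name == "fastgelu":                    # the case the patch handles
            name = "fast_gelu"
        assert name in ACTIVATION_KEYS, name
        return name

    print(activation_key(FastGelu()))             # -> fast_gelu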
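The new cfg.Thor keys feed straight into _get_poly_lr, which decays from
lr_max to lr_min over lr_total_steps (damping reuses the same machinery). A
stand-alone sketch of the curve, assuming warmup_steps=0 and lr_init=0.0 as in
the call sites above; poly_decay_schedule is a hypothetical helper, not part
of the patch:

    import numpy as np

    def poly_decay_schedule(peak, floor, power, total_steps):
        # value(i) = floor + (peak - floor) * (1 - i / total_steps) ** power
        steps = np.arange(total_steps, dtype=np.float32)
        return floor + (peak - floor) * (1.0 - steps / total_steps) ** power

    # Defaults added to cfg.Thor in src/config.py:
    lr = poly_decay_schedule(peak=0.0034, floor=3.244e-5, power=1.0, total_steps=30000)
    damping = poly_decay_schedule(peak=5e-2, floor=1e-6, power=1.0, total_steps=30000)
    print(lr[0], lr[-1])   # 0.0034 at step 0, approaching 3.244e-5 at the end

With lr_power=1.0 the decay is linear; a larger power front-loads it.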
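One behavioral change worth flagging: get_bert_thor_lr and
get_bert_thor_damping no longer branch on device_target, so GPU runs now pick
up the config defaults (the old Ascend constants, rounded) rather than the old
GPU constants. A sketch of the cfg.Thor values a GPU user would set in
model_zoo/official/nlp/bert/src/config.py to keep the pre-patch schedule,
copied from the deleted branch:

    from easydict import EasyDict as edict

    # Pre-patch GPU constants from the removed device_target branch.
    thor_gpu_cfg = edict({
        'lr_max': 1.7e-3,                # old GPU lr_max
        'lr_min': 1e-6,                  # old GPU lr_end
        'lr_power': 1.0,
        'lr_total_steps': 30000,
        'damping_max': 3.5e-2,           # old GPU damping peak
        'damping_min': 1e-6,
        'damping_power': 1.0,
        'damping_total_steps': 30000,
    })

run_pretrain.py then threads these values through unchanged, e.g.
get_bert_thor_lr(cfg.Thor.lr_max, cfg.Thor.lr_min, cfg.Thor.lr_power,
cfg.Thor.lr_total_steps).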