forked from mindspore-Ecosystem/mindspore
!4733 Synchronize some bugfixes of bert to branch r0.5
Merge pull request !4733 from chenhaozhe/r0.5
commit 034453e48f
@@ -382,7 +382,7 @@ class AdamWeightDecayDynamicLR(Optimizer):
                  beta2=0.999,
                  eps=1e-6,
                  weight_decay=0.0,
-                 decay_filter=lambda x: 'beta' not in x.name and 'gamma' not in x.name):
+                 decay_filter=lambda x: 'layernorm' not in x.name.lower() and 'bias' not in x.name.lower()):
         super(AdamWeightDecayDynamicLR, self).__init__(0.0, params)
         if self.is_group:
             raise RuntimeError(f"The {self.cls_name} optimizer cannot support group setting.")

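Note on the decay_filter change above: the exclusion rule moves from matching 'beta'/'gamma' in parameter names to a case-insensitive match on 'layernorm'/'bias', so LayerNorm and bias parameters are excluded from weight decay regardless of how they are named. A minimal sketch of how the new filter behaves, assuming only that each parameter exposes a .name attribute (the DummyParam class and the sample names are illustrative, not from the source):

class DummyParam:
    def __init__(self, name):
        self.name = name

# The new filter from the diff above.
decay_filter = lambda x: 'layernorm' not in x.name.lower() and 'bias' not in x.name.lower()

params = [DummyParam("bert.encoder.dense.weight"),      # kept -> receives weight decay
          DummyParam("bert.encoder.dense.bias"),        # excluded by 'bias'
          DummyParam("bert.encoder.LayerNorm.gamma")]   # excluded by 'layernorm'

print([p.name for p in params if decay_filter(p)])
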
@@ -18,6 +18,7 @@ python run_pretrain.py
 """
 
 import os
+import math
 import argparse
 import numpy
 import mindspore.communication.management as D

@@ -44,15 +45,16 @@ class LossCallBack(Callback):
     Args:
-        per_print_times (int): Print loss every times. Default: 1.
     """
-    def __init__(self, per_print_times=1):
+    def __init__(self, data_epoch_size=1):
         super(LossCallBack, self).__init__()
-        if not isinstance(per_print_times, int) or per_print_times < 0:
-            raise ValueError("print_step must be int and >= 0")
-        self._per_print_times = per_print_times
+        if not isinstance(data_epoch_size, int) or data_epoch_size < 0:
+            raise ValueError("data_epoch_size must be int and >= 0")
+        self._data_epoch_size = data_epoch_size
     def step_end(self, run_context):
         cb_params = run_context.original_args()
-        print("epoch: {}, step: {}, outputs are {}".format(cb_params.cur_epoch_num, cb_params.cur_step_num,
-                                                            str(cb_params.net_outputs)))
+        percent, epoch = math.modf(cb_params.cur_epoch_num / self._data_epoch_size)
+        print("epoch: {}, current epoch percent: {}, step: {}, outputs are {}"
+              .format(epoch, "%.3f" % percent, cb_params.cur_step_num, str(cb_params.net_outputs)))
 
 def run_pretrain():
     """pre-train bert_clue"""

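For reference, math.modf splits the ratio of the callback's epoch counter to data_epoch_size into fractional and integral parts, which is how the new log line reports progress within a pass over the data. A quick worked example with assumed numbers (not taken from the source):

import math

data_epoch_size = 4   # assumed: 4 sink-sized sub-epochs per pass over the data
cur_epoch_num = 6     # assumed: callback is currently on sub-epoch 6

percent, epoch = math.modf(cur_epoch_num / data_epoch_size)
print(epoch, "%.3f" % percent)  # 1.0 0.500 -> one full data pass done, halfway through the next
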
@@ -120,6 +122,7 @@ def run_pretrain():
     ds, new_repeat_count = create_bert_dataset(args_opt.epoch_size, device_num, rank, args_opt.do_shuffle,
                                                args_opt.enable_data_sink, args_opt.data_sink_steps,
                                                args_opt.data_dir, args_opt.schema_dir)
+    data_epoch_size = new_repeat_count // args_opt.epoch_size  # Epoch nums in one dataset.
     if args_opt.train_steps > 0:
         new_repeat_count = min(new_repeat_count, args_opt.train_steps // args_opt.data_sink_steps)
     netwithloss = BertNetworkWithLoss(bert_net_cfg, True)

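One way to read the new data_epoch_size bookkeeping, under the assumption that new_repeat_count counts sink-sized repeats accumulated over all requested epochs (the numbers below are illustrative, not from the source):

new_repeat_count = 40   # assumed value returned by create_bert_dataset
epoch_size = 10         # assumed --epoch_size from the command line
data_epoch_size = new_repeat_count // epoch_size
print(data_epoch_size)  # 4 repeats make up one pass over the dataset
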
@@ -144,7 +147,7 @@ def run_pretrain():
     else:
         raise ValueError("Don't support optimizer {}, only support [Lamb, Momentum, AdamWeightDecayDynamicLR]".
                          format(cfg.optimizer))
-    callback = [TimeMonitor(ds.get_dataset_size()), LossCallBack()]
+    callback = [TimeMonitor(ds.get_dataset_size()), LossCallBack(data_epoch_size)]
     if args_opt.enable_save_ckpt == "true":
         config_ck = CheckpointConfig(save_checkpoint_steps=args_opt.save_checkpoint_steps,
                                      keep_checkpoint_max=args_opt.save_checkpoint_num)

@@ -54,7 +54,7 @@ do
     export GLOG_log_dir=${CUR_DIR}/ms_log
     export GLOG_logtostderr=0
     env > env.log
-    taskset -c $cmdopt python ../run_pretrain.py \
+    taskset -c $cmdopt nohup python ../run_pretrain.py \
         --distribute="true" \
         --epoch_size=$EPOCH_SIZE \
         --device_id=$DEVICE_ID \

@@ -29,7 +29,7 @@ mkdir -p ms_log
 CUR_DIR=`pwd`
 export GLOG_log_dir=${CUR_DIR}/ms_log
 export GLOG_logtostderr=0
-python run_pretrain.py \
+nohup python run_pretrain.py \
     --distribute="false" \
     --epoch_size=$EPOCH_SIZE \
     --device_id=$DEVICE_ID \

@@ -36,8 +36,8 @@ def create_bert_dataset(epoch_size=1, device_num=1, rank=0, do_shuffle="true", e
     ds = de.TFRecordDataset(data_files, schema_dir if schema_dir != "" else None,
                             columns_list=["input_ids", "input_mask", "segment_ids", "next_sentence_labels",
                                           "masked_lm_positions", "masked_lm_ids", "masked_lm_weights"],
-                            shuffle=(do_shuffle == "true"), num_shards=device_num, shard_id=rank,
-                            shard_equal_rows=True)
+                            shuffle=de.Shuffle.FILES if do_shuffle == "true" else False,
+                            num_shards=device_num, shard_id=rank, shard_equal_rows=True)
     ori_dataset_size = ds.get_dataset_size()
     print('origin dataset size: ', ori_dataset_size)
     new_size = ori_dataset_size

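The dataset change above swaps the boolean shuffle argument for the de.Shuffle.FILES enum, shuffling at file granularity rather than globally while keeping the sharding arguments. A minimal construction sketch, assuming mindspore.dataset is available and using placeholder TFRecord paths (real runs pass the shards and schema from the data directory):

import mindspore.dataset as de

data_files = ["part-00000.tfrecord", "part-00001.tfrecord"]  # placeholder paths, not from the source
do_shuffle = "true"

ds = de.TFRecordDataset(data_files, None,
                        columns_list=["input_ids", "input_mask", "segment_ids"],
                        shuffle=de.Shuffle.FILES if do_shuffle == "true" else False,
                        num_shards=1, shard_id=0, shard_equal_rows=True)
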