forked from mindspore-Ecosystem/mindspore
!3729 modify readme and TimeMonitor steps
Merge pull request !3729 from wanghua/c75_tinybert
This commit is contained in:
commit 6875095b3a
@@ -1,6 +1,6 @@
 # TinyBERT Example
 ## Description
 This example implements general distillation and task distillation of [BERT-base](https://github.com/google-research/bert) (the base version of the BERT model).
 [TinyBERT](https://github.com/huawei-noah/Pretrained-Model/tree/master/TinyBERT) is 7.5x smaller and 9.4x faster on inference than [BERT-base](https://github.com/google-research/bert) and achieves competitive performance in natural language understanding tasks. It performs a novel transformer distillation at both the pre-training and task-specific learning stages.

 ## Requirements
 - Install [MindSpore](https://www.mindspore.cn/install/en).
@@ -87,8 +87,10 @@ def run_general_distill():
     if args_opt.enable_data_sink == "true":
         repeat_count = args_opt.epoch_size * dataset.get_dataset_size() // args_opt.data_sink_steps
+        time_monitor_steps = args_opt.data_sink_steps
     else:
         repeat_count = args_opt.epoch_size
+        time_monitor_steps = dataset_size

     lr_schedule = BertLearningRate(learning_rate=common_cfg.AdamWeightDecay.learning_rate,
                                    end_learning_rate=common_cfg.AdamWeightDecay.end_learning_rate,
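Why the two added time_monitor_steps lines matter: MindSpore's TimeMonitor divides the wall-clock time of one host-side "epoch" by the step count it is constructed with, and in data-sink mode one host-side epoch is data_sink_steps device steps rather than a full pass over the dataset. A minimal sketch of the arithmetic (names mirror the diff; an illustration, not the repository's code):

    def compute_counts(enable_data_sink, epoch_size, dataset_size, data_sink_steps):
        """Return (repeat_count, time_monitor_steps) as computed in the diff above."""
        if enable_data_sink:
            # The total step budget is split into sink cycles of
            # data_sink_steps steps each ...
            repeat_count = epoch_size * dataset_size // data_sink_steps
            # ... so per-step timing must average over one sink cycle.
            time_monitor_steps = data_sink_steps
        else:
            repeat_count = epoch_size
            time_monitor_steps = dataset_size
        return repeat_count, time_monitor_steps

    # E.g. a 1000-batch dataset, 3 epochs, 100 sink steps: 30 sink cycles,
    # and per-step time is cycle time / 100, not / 1000.
    assert compute_counts(True, 3, 1000, 100) == (30, 100)
    assert compute_counts(False, 3, 1000, 100) == (3, 1000)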
@@ -104,10 +106,10 @@ def run_general_distill():
     optimizer = AdamWeightDecay(group_params, learning_rate=lr_schedule, eps=common_cfg.AdamWeightDecay.eps)

-    callback = [TimeMonitor(dataset_size), LossCallBack(), ModelSaveCkpt(netwithloss.bert,
-                                                                         args_opt.save_ckpt_step,
-                                                                         args_opt.max_ckpt_num,
-                                                                         save_ckpt_dir)]
+    callback = [TimeMonitor(time_monitor_steps), LossCallBack(), ModelSaveCkpt(netwithloss.bert,
+                                                                               args_opt.save_ckpt_step,
+                                                                               args_opt.max_ckpt_num,
+                                                                               save_ckpt_dir)]

     update_cell = DynamicLossScaleUpdateCell(loss_scale_value=common_cfg.loss_scale_value,
                                              scale_factor=common_cfg.scale_factor,
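The bug the callback change fixes: TimeMonitor's step count is fixed at construction, and it reports per-step time as epoch time divided by that count. A TimeMonitor-like callback sketched below (an illustration of the behavior, not MindSpore's actual implementation) makes the dependency explicit:

    import time

    class SimpleTimeMonitor:
        """Print per-step time for one host-side epoch (one sink cycle in sink mode)."""

        def __init__(self, data_size):
            self.data_size = data_size  # steps executed per host-side epoch

        def epoch_begin(self, run_context):
            self.epoch_start = time.time()

        def epoch_end(self, run_context):
            epoch_seconds = time.time() - self.epoch_start
            per_step_ms = epoch_seconds * 1000.0 / self.data_size
            print(f"epoch time: {epoch_seconds:.3f} s, per step time: {per_step_ms:.3f} ms")

Passing dataset_size while only data_sink_steps steps actually run per sink cycle makes the reported per-step time too small by the ratio of the two values.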
@@ -92,8 +92,10 @@ def run_predistill():
     dataset_size = dataset.get_dataset_size()
     if args_opt.enable_data_sink == 'true':
         repeat_count = args_opt.td_phase1_epoch_size * dataset.get_dataset_size() // args_opt.data_sink_steps
+        time_monitor_steps = args_opt.data_sink_steps
     else:
         repeat_count = args_opt.td_phase1_epoch_size
+        time_monitor_steps = dataset_size

     optimizer_cfg = cfg.optimizer_cfg
@@ -110,10 +112,10 @@ def run_predistill():
                     {'order_params': params}]

     optimizer = AdamWeightDecay(group_params, learning_rate=lr_schedule, eps=optimizer_cfg.AdamWeightDecay.eps)
-    callback = [TimeMonitor(dataset_size), LossCallBack(), ModelSaveCkpt(netwithloss.bert,
-                                                                         args_opt.save_ckpt_step,
-                                                                         args_opt.max_ckpt_num,
-                                                                         td_phase1_save_ckpt_dir)]
+    callback = [TimeMonitor(time_monitor_steps), LossCallBack(), ModelSaveCkpt(netwithloss.bert,
+                                                                               args_opt.save_ckpt_step,
+                                                                               args_opt.max_ckpt_num,
+                                                                               td_phase1_save_ckpt_dir)]
     update_cell = DynamicLossScaleUpdateCell(loss_scale_value=cfg.loss_scale_value,
                                              scale_factor=cfg.scale_factor,
                                              scale_window=cfg.scale_window)
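The trailing context lines configure dynamic loss scaling for mixed-precision training. For readers unfamiliar with the three parameters, this is the standard scheme they control, sketched in plain Python (an illustration of the usual algorithm, not MindSpore's exact DynamicLossScaleUpdateCell code):

    class DynamicLossScale:
        """Grow the loss scale after scale_window clean steps; shrink it on overflow."""

        def __init__(self, loss_scale_value, scale_factor, scale_window):
            self.scale = loss_scale_value   # current multiplier applied to the loss
            self.factor = scale_factor      # growth/shrink ratio
            self.window = scale_window      # clean steps required before growing
            self.good_steps = 0

        def update(self, overflow):
            if overflow:
                # Gradients overflowed in fp16: back off and restart the window.
                self.scale = max(self.scale / self.factor, 1.0)
                self.good_steps = 0
            else:
                self.good_steps += 1
                if self.good_steps >= self.window:
                    self.scale *= self.factor
                    self.good_steps = 0
            return self.scale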
@@ -147,8 +149,10 @@ def run_task_distill(ckpt_file):
     dataset_size = train_dataset.get_dataset_size()
     if args_opt.enable_data_sink == 'true':
         repeat_count = args_opt.td_phase2_epoch_size * train_dataset.get_dataset_size() // args_opt.data_sink_steps
+        time_monitor_steps = args_opt.data_sink_steps
     else:
         repeat_count = args_opt.td_phase2_epoch_size
+        time_monitor_steps = dataset_size

     optimizer_cfg = cfg.optimizer_cfg
@@ -170,14 +174,14 @@ def run_task_distill(ckpt_file):
                                           device_num, rank, args_opt.do_shuffle,
                                           args_opt.eval_data_dir, args_opt.schema_dir)
     if args_opt.do_eval.lower() == "true":
-        callback = [TimeMonitor(dataset_size), LossCallBack(),
+        callback = [TimeMonitor(time_monitor_steps), LossCallBack(),
                     ModelSaveCkpt(netwithloss.bert,
                                   args_opt.save_ckpt_step,
                                   args_opt.max_ckpt_num,
                                   td_phase2_save_ckpt_dir),
                     EvalCallBack(netwithloss.bert, eval_dataset)]
     else:
-        callback = [TimeMonitor(dataset_size), LossCallBack(),
+        callback = [TimeMonitor(time_monitor_steps), LossCallBack(),
                     ModelSaveCkpt(netwithloss.bert,
                                   args_opt.save_ckpt_step,
                                   args_opt.max_ckpt_num,
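For context on how the values above are consumed: repeat_count becomes the epoch argument of Model.train, and the callback list is passed alongside the sink-mode flag. A hedged sketch of the wiring (assumed from MindSpore's Model.train API; the exact call in these scripts may differ, and how data_sink_steps is applied as the sink-cycle length depends on the MindSpore version):

    model = Model(netwithgrads)  # netwithgrads: the train-with-loss-scale cell (assumed name)
    model.train(repeat_count, dataset,
                callbacks=callback,
                dataset_sink_mode=(args_opt.enable_data_sink == 'true'))

With sinking enabled, each of the repeat_count host-side "epochs" corresponds to one sink cycle of data_sink_steps device steps, which is exactly why TimeMonitor must receive the per-cycle step count computed earlier.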