forked from mindspore-Ecosystem/mindspore

!3729 modify readme and TimeMonitor steps

Merge pull request !3729 from wanghua/c75_tinybert
Commit 6875095b3a

@@ -1,6 +1,6 @@
 # TinyBERT Example
 
 ## Description
 
-This example implements general distill and task distill of [BERT-base](https://github.com/google-research/bert)(the base version of BERT model).
+[TinyBERT](https://github.com/huawei-noah/Pretrained-Model/tree/master/TinyBERT) is 7.5x smaller and 9.4x faster on inference than [BERT-base](https://github.com/google-research/bert) (the base version of the BERT model), and it achieves competitive performance on natural language understanding tasks. It performs a novel transformer distillation at both the pre-training and task-specific learning stages.
 
 ## Requirements
 
 - Install [MindSpore](https://www.mindspore.cn/install/en).
@@ -87,8 +87,10 @@ def run_general_distill():
     if args_opt.enable_data_sink == "true":
         repeat_count = args_opt.epoch_size * dataset.get_dataset_size() // args_opt.data_sink_steps
+        time_monitor_steps = args_opt.data_sink_steps
     else:
         repeat_count = args_opt.epoch_size
+        time_monitor_steps = dataset_size
 
     lr_schedule = BertLearningRate(learning_rate=common_cfg.AdamWeightDecay.learning_rate,
                                    end_learning_rate=common_cfg.AdamWeightDecay.end_learning_rate,
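Why the new variable is needed: with data sinking enabled, `Model.train` treats every `data_sink_steps` steps as one reporting unit, so `TimeMonitor` must be told that unit rather than the full `dataset_size`. A minimal, self-contained sketch of the counting logic this hunk introduces; the helper name is ours, not the repository's:

```python
# Hypothetical helper (not in the repository) mirroring the patched logic:
# it returns how many times Model.train should repeat the dataset and how
# many steps TimeMonitor should assume per reported "epoch".
def compute_train_counts(epoch_size, dataset_size, enable_data_sink, data_sink_steps):
    if enable_data_sink:
        # With sinking, one "epoch" seen by Model.train is data_sink_steps steps,
        # so the total work is repackaged into that many sinks.
        repeat_count = epoch_size * dataset_size // data_sink_steps
        time_monitor_steps = data_sink_steps
    else:
        # Without sinking, one epoch is a full pass over the dataset.
        repeat_count = epoch_size
        time_monitor_steps = dataset_size
    return repeat_count, time_monitor_steps

# 10 epochs over a 1000-step dataset, sinking 100 steps at a time:
print(compute_train_counts(10, 1000, True, 100))   # (100, 100)
print(compute_train_counts(10, 1000, False, 100))  # (10, 1000)
```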
@@ -104,10 +106,10 @@ def run_general_distill():
 
     optimizer = AdamWeightDecay(group_params, learning_rate=lr_schedule, eps=common_cfg.AdamWeightDecay.eps)
 
-    callback = [TimeMonitor(dataset_size), LossCallBack(), ModelSaveCkpt(netwithloss.bert,
+    callback = [TimeMonitor(time_monitor_steps), LossCallBack(), ModelSaveCkpt(netwithloss.bert,
                                                                          args_opt.save_ckpt_step,
                                                                          args_opt.max_ckpt_num,
                                                                          save_ckpt_dir)]
 
     update_cell = DynamicLossScaleUpdateCell(loss_scale_value=common_cfg.loss_scale_value,
                                              scale_factor=common_cfg.scale_factor,
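For context, this is roughly how `repeat_count`, the callback list, and the sink flag meet in a MindSpore training call of this era; a hedged sketch, not verbatim from the script (`netwithgrads` stands in for the wrapped distillation train cell the script actually builds):

```python
from mindspore.train.model import Model

# Hedged sketch: netwithgrads, repeat_count, dataset, callback, and args_opt
# are assumed to be the script's own objects.
model = Model(netwithgrads)
model.train(repeat_count, dataset,
            callbacks=callback,
            dataset_sink_mode=(args_opt.enable_data_sink == "true"))
```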
@@ -92,8 +92,10 @@ def run_predistill():
     dataset_size = dataset.get_dataset_size()
     if args_opt.enable_data_sink == 'true':
         repeat_count = args_opt.td_phase1_epoch_size * dataset.get_dataset_size() // args_opt.data_sink_steps
+        time_monitor_steps = args_opt.data_sink_steps
     else:
         repeat_count = args_opt.td_phase1_epoch_size
+        time_monitor_steps = dataset_size
 
     optimizer_cfg = cfg.optimizer_cfg
@@ -110,10 +112,10 @@ def run_predistill():
                     {'order_params': params}]
 
     optimizer = AdamWeightDecay(group_params, learning_rate=lr_schedule, eps=optimizer_cfg.AdamWeightDecay.eps)
-    callback = [TimeMonitor(dataset_size), LossCallBack(), ModelSaveCkpt(netwithloss.bert,
+    callback = [TimeMonitor(time_monitor_steps), LossCallBack(), ModelSaveCkpt(netwithloss.bert,
                                                                          args_opt.save_ckpt_step,
                                                                          args_opt.max_ckpt_num,
                                                                          td_phase1_save_ckpt_dir)]
     update_cell = DynamicLossScaleUpdateCell(loss_scale_value=cfg.loss_scale_value,
                                              scale_factor=cfg.scale_factor,
                                              scale_window=cfg.scale_window)
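The same two-line change recurs here because `run_predistill` computes its own counters. Why `TimeMonitor(dataset_size)` printed wrong numbers in sink mode is easiest to see from what a TimeMonitor-style callback does; a simplified sketch, assuming the real `mindspore.train.callback.TimeMonitor` behaves along these lines (it is more elaborate):

```python
import time
from mindspore.train.callback import Callback

class SimpleTimeMonitor(Callback):
    """Simplified stand-in for TimeMonitor: times each reported "epoch" and
    averages over data_size steps. If Model.train reports every
    data_sink_steps steps but data_size is the full dataset_size, the
    printed per-step time is understated -- the bug this patch fixes."""

    def __init__(self, data_size):
        super(SimpleTimeMonitor, self).__init__()
        self.data_size = data_size
        self.epoch_time = 0.0

    def epoch_begin(self, run_context):
        self.epoch_time = time.time()

    def epoch_end(self, run_context):
        per_step_ms = (time.time() - self.epoch_time) * 1000 / self.data_size
        print("per step time: {:.3f} ms".format(per_step_ms))
```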
@@ -147,8 +149,10 @@ def run_task_distill(ckpt_file):
     dataset_size = train_dataset.get_dataset_size()
     if args_opt.enable_data_sink == 'true':
         repeat_count = args_opt.td_phase2_epoch_size * train_dataset.get_dataset_size() // args_opt.data_sink_steps
+        time_monitor_steps = args_opt.data_sink_steps
     else:
         repeat_count = args_opt.td_phase2_epoch_size
+        time_monitor_steps = dataset_size
 
     optimizer_cfg = cfg.optimizer_cfg
@@ -170,14 +174,14 @@ def run_task_distill(ckpt_file):
                                           device_num, rank, args_opt.do_shuffle,
                                           args_opt.eval_data_dir, args_opt.schema_dir)
     if args_opt.do_eval.lower() == "true":
-        callback = [TimeMonitor(dataset_size), LossCallBack(),
+        callback = [TimeMonitor(time_monitor_steps), LossCallBack(),
                     ModelSaveCkpt(netwithloss.bert,
                                   args_opt.save_ckpt_step,
                                   args_opt.max_ckpt_num,
                                   td_phase2_save_ckpt_dir),
                     EvalCallBack(netwithloss.bert, eval_dataset)]
     else:
-        callback = [TimeMonitor(dataset_size), LossCallBack(),
+        callback = [TimeMonitor(time_monitor_steps), LossCallBack(),
                     ModelSaveCkpt(netwithloss.bert,
                                   args_opt.save_ckpt_step,
                                   args_opt.max_ckpt_num,
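A possible follow-up cleanup, not part of this patch: the two branches above differ only in `EvalCallBack`, so the shared callbacks could be built once and the evaluation callback appended conditionally (all names are the script's own, used here in its context):

```python
# Sketch in the script's context; TimeMonitor, LossCallBack, ModelSaveCkpt,
# EvalCallBack, args_opt, etc. are defined there.
callback = [TimeMonitor(time_monitor_steps), LossCallBack(),
            ModelSaveCkpt(netwithloss.bert,
                          args_opt.save_ckpt_step,
                          args_opt.max_ckpt_num,
                          td_phase2_save_ckpt_dir)]
if args_opt.do_eval.lower() == "true":
    callback.append(EvalCallBack(netwithloss.bert, eval_dataset))
```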