forked from mindspore-Ecosystem/mindspore
!3486 Fix the performance test case of bert
Merge pull request !3486 from chenhaozhe/fix_bert_performance_test_case
This commit is contained in:
commit
7be664fa85
|
@ -101,12 +101,8 @@ def me_de_train_dataset(sink_mode=False):
|
||||||
type_cast_op = C.TypeCast(mstype.int32)
|
type_cast_op = C.TypeCast(mstype.int32)
|
||||||
new_repeat_count = repeat_count
|
new_repeat_count = repeat_count
|
||||||
if sink_mode:
|
if sink_mode:
|
||||||
repeat_count = 30
|
|
||||||
sink_size = 100
|
sink_size = 100
|
||||||
ori_dataaet_size = ds.get_dataset_size()
|
new_repeat_count = 3
|
||||||
new_size = sink_size * batch_size
|
|
||||||
ds.set_dataset_size(new_size)
|
|
||||||
new_repeat_count = int(repeat_count * ori_dataaet_size // ds.get_dataset_size())
|
|
||||||
ds = ds.map(input_columns="masked_lm_ids", operations=type_cast_op)
|
ds = ds.map(input_columns="masked_lm_ids", operations=type_cast_op)
|
||||||
ds = ds.map(input_columns="masked_lm_positions", operations=type_cast_op)
|
ds = ds.map(input_columns="masked_lm_positions", operations=type_cast_op)
|
||||||
ds = ds.map(input_columns="next_sentence_labels", operations=type_cast_op)
|
ds = ds.map(input_columns="next_sentence_labels", operations=type_cast_op)
|
||||||
|
@ -264,7 +260,7 @@ def test_bert_performance():
|
||||||
config = get_config(version=version, batch_size=batch_size)
|
config = get_config(version=version, batch_size=batch_size)
|
||||||
netwithloss = BertNetworkWithLoss(config, True)
|
netwithloss = BertNetworkWithLoss(config, True)
|
||||||
|
|
||||||
lr = BertLearningRate(decay_steps=ds.get_dataset_size()*new_repeat_count,
|
lr = BertLearningRate(decay_steps=sink_size * new_repeat_count,
|
||||||
learning_rate=5e-5, end_learning_rate=1e-9,
|
learning_rate=5e-5, end_learning_rate=1e-9,
|
||||||
power=10.0, warmup_steps=0)
|
power=10.0, warmup_steps=0)
|
||||||
decay_filter = lambda x: 'layernorm' not in x.name.lower() and 'bias' not in x.name.lower()
|
decay_filter = lambda x: 'layernorm' not in x.name.lower() and 'bias' not in x.name.lower()
|
||||||
|
@ -302,7 +298,7 @@ def test_bert_performance():
|
||||||
else:
|
else:
|
||||||
logger.info("***************** BERT param name is 3 {}".format(name))
|
logger.info("***************** BERT param name is 3 {}".format(name))
|
||||||
param.default_input = weight_variable(value.asnumpy().shape)
|
param.default_input = weight_variable(value.asnumpy().shape)
|
||||||
time_monitor_callback = TimeMonitor(ds.get_dataset_size())
|
time_monitor_callback = TimeMonitor(sink_size)
|
||||||
model.train(new_repeat_count, ds, callbacks=[time_monitor_callback, callback],
|
model.train(new_repeat_count, ds, callbacks=[time_monitor_callback, callback],
|
||||||
dataset_sink_mode=True, sink_size=sink_size)
|
dataset_sink_mode=True, sink_size=sink_size)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue