diff --git a/tests/st/networks/models/bert/test_bert_tdt_lossscale.py b/tests/st/networks/models/bert/test_bert_tdt_lossscale.py index 66b751782e8..fb64d3db2a1 100644 --- a/tests/st/networks/models/bert/test_bert_tdt_lossscale.py +++ b/tests/st/networks/models/bert/test_bert_tdt_lossscale.py @@ -101,12 +101,8 @@ def me_de_train_dataset(sink_mode=False): type_cast_op = C.TypeCast(mstype.int32) new_repeat_count = repeat_count if sink_mode: - repeat_count = 30 sink_size = 100 - ori_dataaet_size = ds.get_dataset_size() - new_size = sink_size * batch_size - ds.set_dataset_size(new_size) - new_repeat_count = int(repeat_count * ori_dataaet_size // ds.get_dataset_size()) + new_repeat_count = 3 ds = ds.map(input_columns="masked_lm_ids", operations=type_cast_op) ds = ds.map(input_columns="masked_lm_positions", operations=type_cast_op) ds = ds.map(input_columns="next_sentence_labels", operations=type_cast_op) @@ -264,7 +260,7 @@ def test_bert_performance(): config = get_config(version=version, batch_size=batch_size) netwithloss = BertNetworkWithLoss(config, True) - lr = BertLearningRate(decay_steps=ds.get_dataset_size()*new_repeat_count, + lr = BertLearningRate(decay_steps=sink_size * new_repeat_count, learning_rate=5e-5, end_learning_rate=1e-9, power=10.0, warmup_steps=0) decay_filter = lambda x: 'layernorm' not in x.name.lower() and 'bias' not in x.name.lower() @@ -302,7 +298,7 @@ def test_bert_performance(): else: logger.info("***************** BERT param name is 3 {}".format(name)) param.default_input = weight_variable(value.asnumpy().shape) - time_monitor_callback = TimeMonitor(ds.get_dataset_size()) + time_monitor_callback = TimeMonitor(sink_size) model.train(new_repeat_count, ds, callbacks=[time_monitor_callback, callback], dataset_sink_mode=True, sink_size=sink_size)