From 8012dbde544fca42f76f5f62d29ae0e235cfbd81 Mon Sep 17 00:00:00 2001
From: wangmin
Date: Fri, 18 Sep 2020 20:33:36 +0800
Subject: [PATCH] add split allreduce testcase for bert_thor

---
Review notes (this section sits after the "---" separator and is not part of
the commit message):
 * Adds _set_bert_all_reduce_split(), which sets all-reduce fusion split
   indices for the "hccl_world_groupsum1" and "hccl_world_groupsum3" groups
   when bert_net_cfg.num_hidden_layers is 12 or 24; for any other value the
   helper does nothing.
 * NOTE(review): the helper is called before
   "bert_net_cfg.num_hidden_layers = 4" is assigned, so it branches on the
   value the config held beforehand, not on 4 -- confirm this is intended.
 * NOTE(review): line breaks, blank context lines and leading whitespace were
   reconstructed from the hunk line counts; verify the context lines against
   the target file before applying.

 .../bert_performance/test_bert_thor_mlperf.py | 33 +++++++++++++++++--
 1 file changed, 30 insertions(+), 3 deletions(-)

diff --git a/tests/st/networks/models/bert/bert_performance/test_bert_thor_mlperf.py b/tests/st/networks/models/bert/bert_performance/test_bert_thor_mlperf.py
index c5c3bbb2356..e10962aaf62 100644
--- a/tests/st/networks/models/bert/bert_performance/test_bert_thor_mlperf.py
+++ b/tests/st/networks/models/bert/bert_performance/test_bert_thor_mlperf.py
@@ -107,6 +107,32 @@ def create_bert_dataset(device_num=1, rank=0, do_shuffle="true", data_dir=None,
     logger.info("repeat count: {}".format(ds.get_repeat_count()))
     return ds
 
+
+def _set_bert_all_reduce_split():
+    """set bert all_reduce fusion split, support num_hidden_layers is 12 and 24."""
+    from mindspore.parallel._auto_parallel_context import auto_parallel_context
+    if bert_net_cfg.num_hidden_layers == 12:
+        if bert_net_cfg.use_relative_positions:
+            auto_parallel_context().set_all_reduce_fusion_split_indices([29, 58, 87, 116, 145, 174, 203, 217],
+                                                                        "hccl_world_groupsum1")
+            auto_parallel_context().set_all_reduce_fusion_split_indices([29, 58, 87, 116, 145, 174, 203, 217],
+                                                                        "hccl_world_groupsum3")
+        else:
+            auto_parallel_context().set_all_reduce_fusion_split_indices([28, 55, 82, 109, 136, 163, 190, 205],
+                                                                        "hccl_world_groupsum1")
+            auto_parallel_context().set_all_reduce_fusion_split_indices([28, 55, 82, 109, 136, 163, 190, 205],
+                                                                        "hccl_world_groupsum3")
+    elif bert_net_cfg.num_hidden_layers == 24:
+        if bert_net_cfg.use_relative_positions:
+            auto_parallel_context().set_all_reduce_fusion_split_indices([30, 90, 150, 210, 270, 330, 390, 421],
+                                                                        "hccl_world_groupsum1")
+            auto_parallel_context().set_all_reduce_fusion_split_indices([30, 90, 150, 210, 270, 330, 390, 421],
+                                                                        "hccl_world_groupsum3")
+        else:
+            auto_parallel_context().set_all_reduce_fusion_split_indices([38, 77], "hccl_world_groupsum1")
+            auto_parallel_context().set_all_reduce_fusion_split_indices([38, 77], "hccl_world_groupsum3")
+
+
 def train_process_bert_thor(q, device_id, epoch_size, device_num):
     os.system("mkdir " + str(device_id))
     os.chdir(str(device_id))
@@ -120,10 +146,11 @@ def train_process_bert_thor(q, device_id, epoch_size, device_num):
     D.init()
     rank = device_id % device_num
     context.reset_auto_parallel_context()
+    _set_bert_all_reduce_split()
     context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True,
                                       device_num=device_num)
 
-    bert_net_cfg.num_hidden_layers = 2
+    bert_net_cfg.num_hidden_layers = 4
     ds = create_bert_dataset(device_num=device_num, rank=rank, do_shuffle=False, data_dir=DATASET_PATH, schema_dir=None)
     net_with_loss = BertNetworkWithLoss(bert_net_cfg, True)
 
@@ -200,8 +227,8 @@ def test_bert_thor_mlperf_8p():
         os.system("rm -rf " + str(i))
 
     print("End training...")
-    assert mean_cost < 51
-    assert mean_loss < 8.5
+    assert mean_cost < 64.2
+    assert mean_loss < 7.9
 
 if __name__ == '__main__':
     test_bert_thor_mlperf_8p()