forked from mindspore-Ecosystem/mindspore
!6508 add split allreduce testcase for bert_thor ut
Merge pull request !6508 from wangshuangling/master
This commit is contained in:
commit
59bde6921f
|
@ -107,6 +107,32 @@ def create_bert_dataset(device_num=1, rank=0, do_shuffle="true", data_dir=None,
|
||||||
logger.info("repeat count: {}".format(ds.get_repeat_count()))
|
logger.info("repeat count: {}".format(ds.get_repeat_count()))
|
||||||
return ds
|
return ds
|
||||||
|
|
||||||
|
|
||||||
|
def _set_bert_all_reduce_split():
    """Configure the all-reduce fusion split indices for the BERT network.

    The indices are selected from ``bert_net_cfg.num_hidden_layers`` (12 or
    24) and ``bert_net_cfg.use_relative_positions``, then registered for both
    the "hccl_world_groupsum1" and "hccl_world_groupsum3" groups. For any
    other layer count no split is configured.
    """
    from mindspore.parallel._auto_parallel_context import auto_parallel_context

    layer_count = bert_net_cfg.num_hidden_layers
    if layer_count == 12:
        if bert_net_cfg.use_relative_positions:
            split_indices = [29, 58, 87, 116, 145, 174, 203, 217]
        else:
            split_indices = [28, 55, 82, 109, 136, 163, 190, 205]
    elif layer_count == 24:
        if bert_net_cfg.use_relative_positions:
            split_indices = [30, 90, 150, 210, 270, 330, 390, 421]
        else:
            split_indices = [38, 77]
    else:
        # Only the 12- and 24-layer configurations have a split defined.
        return

    # The same split is applied to both groups, in this order; each call gets
    # its own list object.
    for group_name in ("hccl_world_groupsum1", "hccl_world_groupsum3"):
        auto_parallel_context().set_all_reduce_fusion_split_indices(list(split_indices), group_name)
|
||||||
|
|
||||||
|
|
||||||
def train_process_bert_thor(q, device_id, epoch_size, device_num):
|
def train_process_bert_thor(q, device_id, epoch_size, device_num):
|
||||||
os.system("mkdir " + str(device_id))
|
os.system("mkdir " + str(device_id))
|
||||||
os.chdir(str(device_id))
|
os.chdir(str(device_id))
|
||||||
|
@ -120,10 +146,11 @@ def train_process_bert_thor(q, device_id, epoch_size, device_num):
|
||||||
D.init()
|
D.init()
|
||||||
rank = device_id % device_num
|
rank = device_id % device_num
|
||||||
context.reset_auto_parallel_context()
|
context.reset_auto_parallel_context()
|
||||||
|
_set_bert_all_reduce_split()
|
||||||
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True,
|
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True,
|
||||||
device_num=device_num)
|
device_num=device_num)
|
||||||
|
|
||||||
bert_net_cfg.num_hidden_layers = 2
|
bert_net_cfg.num_hidden_layers = 4
|
||||||
ds = create_bert_dataset(device_num=device_num, rank=rank, do_shuffle=False, data_dir=DATASET_PATH, schema_dir=None)
|
ds = create_bert_dataset(device_num=device_num, rank=rank, do_shuffle=False, data_dir=DATASET_PATH, schema_dir=None)
|
||||||
net_with_loss = BertNetworkWithLoss(bert_net_cfg, True)
|
net_with_loss = BertNetworkWithLoss(bert_net_cfg, True)
|
||||||
|
|
||||||
|
@ -200,8 +227,8 @@ def test_bert_thor_mlperf_8p():
|
||||||
os.system("rm -rf " + str(i))
|
os.system("rm -rf " + str(i))
|
||||||
|
|
||||||
print("End training...")
|
print("End training...")
|
||||||
assert mean_cost < 51
|
assert mean_cost < 64.2
|
||||||
assert mean_loss < 8.5
|
assert mean_loss < 7.9
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
test_bert_thor_mlperf_8p()
|
test_bert_thor_mlperf_8p()
|
||||||
|
|
Loading…
Reference in New Issue