!2061 set allreduce fusion split indices for NEZHA bert
Merge pull request !2061 from shibeiji/master
This commit is contained in:
commit
d251ba9e11
|
@ -84,9 +84,15 @@ def run_pretrain():
|
|||
device_num=device_num)
|
||||
from mindspore.parallel._auto_parallel_context import auto_parallel_context
|
||||
if bert_net_cfg.num_hidden_layers == 12:
|
||||
auto_parallel_context().set_all_reduce_fusion_split_indices([28, 55, 82, 109, 136, 163, 190, 205])
|
||||
if bert_net_cfg.use_relative_positions:
|
||||
auto_parallel_context().set_all_reduce_fusion_split_indices([29, 58, 87, 116, 145, 174, 203, 217])
|
||||
else:
|
||||
auto_parallel_context().set_all_reduce_fusion_split_indices([28, 55, 82, 109, 136, 163, 190, 205])
|
||||
elif bert_net_cfg.num_hidden_layers == 24:
|
||||
auto_parallel_context().set_all_reduce_fusion_split_indices([38, 93, 148, 203, 258, 313, 368, 397])
|
||||
if bert_net_cfg.use_relative_positions:
|
||||
auto_parallel_context().set_all_reduce_fusion_split_indices([30, 90, 150, 210, 270, 330, 390, 421])
|
||||
else:
|
||||
auto_parallel_context().set_all_reduce_fusion_split_indices([38, 93, 148, 203, 258, 313, 368, 397])
|
||||
D.init()
|
||||
rank = args_opt.device_id % device_num
|
||||
else:
|
||||
|
|
Loading…
Reference in New Issue