forked from OSSInnovation/mindspore
modify dataset.py and add auto parallel split
This commit is contained in:
parent
0f22140331
commit
298ff4adc1
|
@ -52,7 +52,7 @@ def create_bert_dataset(epoch_size=1, device_num=1, rank=0, do_shuffle="true", e
|
|||
ds = ds.map(input_columns="input_ids", operations=type_cast_op)
|
||||
# apply batch operations
|
||||
ds = ds.batch(bert_net_cfg.batch_size, drop_remainder=True)
|
||||
ds = ds.repeat(repeat_count)
|
||||
ds = ds.repeat(new_repeat_count)
|
||||
logger.info("data size: {}".format(ds.get_dataset_size()))
|
||||
logger.info("repeatcount: {}".format(ds.get_repeat_count()))
|
||||
return ds, new_repeat_count
|
||||
|
|
|
@ -81,6 +81,11 @@ def run_pretrain():
|
|||
context.reset_auto_parallel_context()
|
||||
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, mirror_mean=True,
|
||||
device_num=device_num)
|
||||
from mindspore.parallel._auto_parallel_context import auto_parallel_context
|
||||
if bert_net_cfg.num_hidden_layers == 12:
|
||||
auto_parallel_context().set_all_reduce_fusion_split_indices([28, 55, 82, 109, 136, 163, 190, 205])
|
||||
elif bert_net_cfg.num_hidden_layers == 24:
|
||||
auto_parallel_context().set_all_reduce_fusion_split_indices([38, 93, 148, 203, 258, 313, 368, 397])
|
||||
D.init()
|
||||
rank = args_opt.device_id % device_num
|
||||
else:
|
||||
|
|
Loading…
Reference in New Issue