fix the linear ratio of vgg16, deepfm and wide_deep

This commit is contained in:
root 2021-01-12 14:24:49 +08:00
parent a44d8386d8
commit bb5a354ea6
3 changed files with 5 additions and 3 deletions

View File

@ -140,7 +140,7 @@ if __name__ == '__main__':
device_num = args.group_size
context.reset_auto_parallel_context()
context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
gradients_mean=True)
gradients_mean=True, all_reduce_fusion_config=[3, 10, 12, 15])
else:
if args.device_target == "Ascend":
context.set_context(device_id=args.device_id)

View File

@ -57,7 +57,9 @@ if __name__ == '__main__':
device_id = int(os.getenv('DEVICE_ID'))
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=device_id)
context.reset_auto_parallel_context()
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True)
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL,
gradients_mean=True,
all_reduce_fusion_config=[9, 11])
init()
rank_id = int(os.environ.get('RANK_ID'))
elif args_opt.device_target == "GPU":

View File

@ -125,6 +125,6 @@ if __name__ == "__main__":
init()
context.set_context(save_graphs_path='./graphs_of_device_id_'+str(get_rank()))
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True,
device_num=get_group_size())
device_num=get_group_size(), all_reduce_fusion_config=[6, 12])
train_and_eval(wide_deep_config)