diff --git a/model_zoo/official/cv/vgg16/train.py b/model_zoo/official/cv/vgg16/train.py index 57cc9da59cc..bae5d0adee9 100644 --- a/model_zoo/official/cv/vgg16/train.py +++ b/model_zoo/official/cv/vgg16/train.py @@ -140,7 +140,7 @@ if __name__ == '__main__': device_num = args.group_size context.reset_auto_parallel_context() context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL, - gradients_mean=True) + gradients_mean=True, all_reduce_fusion_config=[3, 10, 12, 15]) else: if args.device_target == "Ascend": context.set_context(device_id=args.device_id) diff --git a/model_zoo/official/recommend/deepfm/train.py b/model_zoo/official/recommend/deepfm/train.py index 6edd0f4c6ae..a55b1289de1 100644 --- a/model_zoo/official/recommend/deepfm/train.py +++ b/model_zoo/official/recommend/deepfm/train.py @@ -57,7 +57,9 @@ if __name__ == '__main__': device_id = int(os.getenv('DEVICE_ID')) context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=device_id) context.reset_auto_parallel_context() - context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True) + context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, + gradients_mean=True, + all_reduce_fusion_config=[9, 11]) init() rank_id = int(os.environ.get('RANK_ID')) elif args_opt.device_target == "GPU": diff --git a/model_zoo/official/recommend/wide_and_deep/train_and_eval_distribute.py b/model_zoo/official/recommend/wide_and_deep/train_and_eval_distribute.py index 1f8ad32f066..cb81e5e5782 100644 --- a/model_zoo/official/recommend/wide_and_deep/train_and_eval_distribute.py +++ b/model_zoo/official/recommend/wide_and_deep/train_and_eval_distribute.py @@ -125,6 +125,6 @@ if __name__ == "__main__": init() context.set_context(save_graphs_path='./graphs_of_device_id_'+str(get_rank())) context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True, - device_num=get_group_size()) + device_num=get_group_size(), all_reduce_fusion_config=[6, 12]) train_and_eval(wide_deep_config)