forked from mindspore-Ecosystem/mindspore
fix the linear ratio of vgg16, deepfm and wide_deep
This commit is contained in:
parent
a44d8386d8
commit
bb5a354ea6
|
@ -140,7 +140,7 @@ if __name__ == '__main__':
|
|||
device_num = args.group_size
|
||||
context.reset_auto_parallel_context()
|
||||
context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
|
||||
gradients_mean=True)
|
||||
gradients_mean=True, all_reduce_fusion_config=[3, 10, 12, 15])
|
||||
else:
|
||||
if args.device_target == "Ascend":
|
||||
context.set_context(device_id=args.device_id)
|
||||
|
|
|
@ -57,7 +57,9 @@ if __name__ == '__main__':
|
|||
device_id = int(os.getenv('DEVICE_ID'))
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=device_id)
|
||||
context.reset_auto_parallel_context()
|
||||
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True)
|
||||
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL,
|
||||
gradients_mean=True,
|
||||
all_reduce_fusion_config=[9, 11])
|
||||
init()
|
||||
rank_id = int(os.environ.get('RANK_ID'))
|
||||
elif args_opt.device_target == "GPU":
|
||||
|
|
|
@ -125,6 +125,6 @@ if __name__ == "__main__":
|
|||
init()
|
||||
context.set_context(save_graphs_path='./graphs_of_device_id_'+str(get_rank()))
|
||||
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True,
|
||||
device_num=get_group_size())
|
||||
device_num=get_group_size(), all_reduce_fusion_config=[6, 12])
|
||||
|
||||
train_and_eval(wide_deep_config)
|
||||
|
|
Loading…
Reference in New Issue