From 4cf23db838a9e44c0b5cab8badd82e61f3d11218 Mon Sep 17 00:00:00 2001 From: ms_yan <6576637+ms_yan@user.noreply.gitee.com> Date: Fri, 31 Jul 2020 12:27:43 +0800 Subject: [PATCH] repair distribute training for vgg in D --- .../official/cv/vgg16/scripts/run_distribute_train.sh | 2 +- model_zoo/official/cv/vgg16/train.py | 7 ++++--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/model_zoo/official/cv/vgg16/scripts/run_distribute_train.sh b/model_zoo/official/cv/vgg16/scripts/run_distribute_train.sh index ff47070abf0..1a9e022fd2f 100755 --- a/model_zoo/official/cv/vgg16/scripts/run_distribute_train.sh +++ b/model_zoo/official/cv/vgg16/scripts/run_distribute_train.sh @@ -47,6 +47,6 @@ do cd ./train_parallel$i || exit echo "start training for rank $RANK_ID, device $DEVICE_ID" env > env.log - python train.py --data_path=$2 --device_target="Ascend" --device_id=$i &> log & + python train.py --data_path=$2 --device_target="Ascend" --device_id=$i --is_distributed=1 &> log & cd .. done \ No newline at end of file diff --git a/model_zoo/official/cv/vgg16/train.py b/model_zoo/official/cv/vgg16/train.py index 3b1d85e890d..2ddf89e9776 100644 --- a/model_zoo/official/cv/vgg16/train.py +++ b/model_zoo/official/cv/vgg16/train.py @@ -191,12 +191,13 @@ if __name__ == '__main__': if args.is_distributed: if args.device_target == "Ascend": init() + context.set_context(device_id=args.device_id) elif args.device_target == "GPU": init("nccl") - args.rank = get_rank() - args.group_size = get_group_size() - device_num = args.group_size + args.rank = get_rank() + args.group_size = get_group_size() + device_num = args.group_size context.reset_auto_parallel_context() context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL, mirror_mean=True)