!3791 modelzoo: repair vgg distributed training problem
Merge pull request !3791 from ms_yan/vgg_8p_D
This commit is contained in:
commit
b55e5e2ce2
|
@ -47,6 +47,6 @@ do
|
|||
cd ./train_parallel$i || exit
|
||||
echo "start training for rank $RANK_ID, device $DEVICE_ID"
|
||||
env > env.log
|
||||
python train.py --data_path=$2 --device_target="Ascend" --device_id=$i &> log &
|
||||
python train.py --data_path=$2 --device_target="Ascend" --device_id=$i --is_distributed=1 &> log &
|
||||
cd ..
|
||||
done
|
|
@ -191,12 +191,13 @@ if __name__ == '__main__':
|
|||
if args.is_distributed:
|
||||
if args.device_target == "Ascend":
|
||||
init()
|
||||
context.set_context(device_id=args.device_id)
|
||||
elif args.device_target == "GPU":
|
||||
init("nccl")
|
||||
args.rank = get_rank()
|
||||
args.group_size = get_group_size()
|
||||
device_num = args.group_size
|
||||
|
||||
args.rank = get_rank()
|
||||
args.group_size = get_group_size()
|
||||
device_num = args.group_size
|
||||
context.reset_auto_parallel_context()
|
||||
context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
|
||||
mirror_mean=True)
|
||||
|
|
Loading…
Reference in New Issue