!3662 add multiple node support for resnet50

Merge pull request !3662 from gengdongjie/master
This commit is contained in:
mindspore-ci-bot 2020-07-30 16:06:38 +08:00 committed by Gitee
commit 3527f0c16f
2 changed files with 23 additions and 7 deletions

View File

@ -79,10 +79,13 @@ export RANK_SIZE=8
export MINDSPORE_HCCL_CONFIG_PATH=$PATH1
export RANK_TABLE_FILE=$PATH1
export SERVER_ID=0
rank_start=$((DEVICE_NUM * SERVER_ID))
for((i=0; i<${DEVICE_NUM}; i++))
do
export DEVICE_ID=$i
export RANK_ID=$i
export RANK_ID=$((rank_start + i))
rm -rf ./train_parallel$i
mkdir ./train_parallel$i
cp ../*.py ./train_parallel$i

View File

@ -37,8 +37,7 @@ def create_dataset1(dataset_path, do_train, repeat_num=1, batch_size=32, target=
dataset
"""
if target == "Ascend":
device_num = int(os.getenv("DEVICE_NUM"))
rank_id = int(os.getenv("RANK_ID"))
device_num, rank_id = _get_rank_info()
else:
init("nccl")
rank_id = get_rank()
@ -93,8 +92,7 @@ def create_dataset2(dataset_path, do_train, repeat_num=1, batch_size=32, target=
dataset
"""
if target == "Ascend":
device_num = int(os.getenv("DEVICE_NUM"))
rank_id = int(os.getenv("RANK_ID"))
device_num, rank_id = _get_rank_info()
else:
init("nccl")
rank_id = get_rank()
@ -153,8 +151,7 @@ def create_dataset3(dataset_path, do_train, repeat_num=1, batch_size=32):
Returns:
dataset
"""
device_num = int(os.getenv("RANK_SIZE"))
rank_id = int(os.getenv("RANK_ID"))
device_num, rank_id = _get_rank_info()
if device_num == 1:
ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=True)
@ -203,3 +200,19 @@ def create_dataset3(dataset_path, do_train, repeat_num=1, batch_size=32):
ds = ds.repeat(repeat_num)
return ds
def _get_rank_info():
"""
get rank size and rank id
"""
rank_size = int(os.environ.get("RANK_SIZE", 1))
if rank_size > 1:
rank_size = get_group_size()
rank_id = get_rank()
else:
rank_size = 1
rank_id = 0
return rank_size, rank_id