!3662 add multiple node support for resnet50

Merge pull request !3662 from gengdongjie/master
This commit is contained in:
mindspore-ci-bot 2020-07-30 16:06:38 +08:00 committed by Gitee
commit 3527f0c16f
2 changed files with 23 additions and 7 deletions

View File

@ -79,10 +79,13 @@ export RANK_SIZE=8
export MINDSPORE_HCCL_CONFIG_PATH=$PATH1 export MINDSPORE_HCCL_CONFIG_PATH=$PATH1
export RANK_TABLE_FILE=$PATH1 export RANK_TABLE_FILE=$PATH1
export SERVER_ID=0
rank_start=$((DEVICE_NUM * SERVER_ID))
for((i=0; i<${DEVICE_NUM}; i++)) for((i=0; i<${DEVICE_NUM}; i++))
do do
export DEVICE_ID=$i export DEVICE_ID=$i
export RANK_ID=$i export RANK_ID=$((rank_start + i))
rm -rf ./train_parallel$i rm -rf ./train_parallel$i
mkdir ./train_parallel$i mkdir ./train_parallel$i
cp ../*.py ./train_parallel$i cp ../*.py ./train_parallel$i

View File

@ -37,8 +37,7 @@ def create_dataset1(dataset_path, do_train, repeat_num=1, batch_size=32, target=
dataset dataset
""" """
if target == "Ascend": if target == "Ascend":
device_num = int(os.getenv("DEVICE_NUM")) device_num, rank_id = _get_rank_info()
rank_id = int(os.getenv("RANK_ID"))
else: else:
init("nccl") init("nccl")
rank_id = get_rank() rank_id = get_rank()
@ -93,8 +92,7 @@ def create_dataset2(dataset_path, do_train, repeat_num=1, batch_size=32, target=
dataset dataset
""" """
if target == "Ascend": if target == "Ascend":
device_num = int(os.getenv("DEVICE_NUM")) device_num, rank_id = _get_rank_info()
rank_id = int(os.getenv("RANK_ID"))
else: else:
init("nccl") init("nccl")
rank_id = get_rank() rank_id = get_rank()
@ -153,8 +151,7 @@ def create_dataset3(dataset_path, do_train, repeat_num=1, batch_size=32):
Returns: Returns:
dataset dataset
""" """
device_num = int(os.getenv("RANK_SIZE")) device_num, rank_id = _get_rank_info()
rank_id = int(os.getenv("RANK_ID"))
if device_num == 1: if device_num == 1:
ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=True) ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=True)
@ -203,3 +200,19 @@ def create_dataset3(dataset_path, do_train, repeat_num=1, batch_size=32):
ds = ds.repeat(repeat_num) ds = ds.repeat(repeat_num)
return ds return ds
def _get_rank_info():
    """
    Get the rank size and rank id of the current process.

    The ``RANK_SIZE`` environment variable decides which path is taken:

    - greater than 1: the values are queried from ``get_group_size()`` and
      ``get_rank()`` (names resolved at file level, not visible in this block).
    - unset, blank, or not greater than 1: single-device values are returned.

    Returns:
        tuple, ``(rank_size, rank_id)``; ``(1, 0)`` on the single-device path.

    Raises:
        ValueError: if ``RANK_SIZE`` is set to a non-blank, non-integer value.
    """
    # A launcher may export the variable blank (e.g. ``export RANK_SIZE=$1``
    # with a missing argument); treat that the same as "not set" instead of
    # failing inside int() with a message that does not name the variable.
    raw_size = str(os.environ.get("RANK_SIZE", 1)).strip()
    if not raw_size:
        return 1, 0

    try:
        env_rank_size = int(raw_size)
    except ValueError:
        raise ValueError(
            "RANK_SIZE must be an integer, but got {!r}".format(raw_size)
        ) from None

    if env_rank_size > 1:
        # The environment value only selects the path; the returned numbers
        # are whatever the two helper calls report.
        rank_size = get_group_size()
        rank_id = get_rank()
    else:
        rank_size = 1
        rank_id = 0

    return rank_size, rank_id