!3662 Add multi-node support for ResNet50
Merge pull request !3662 from gengdongjie/master
This commit is contained in:
commit
3527f0c16f
|
@ -79,10 +79,13 @@ export RANK_SIZE=8
|
|||
export MINDSPORE_HCCL_CONFIG_PATH=$PATH1
|
||||
export RANK_TABLE_FILE=$PATH1
|
||||
|
||||
export SERVER_ID=0
|
||||
rank_start=$((DEVICE_NUM * SERVER_ID))
|
||||
|
||||
for((i=0; i<${DEVICE_NUM}; i++))
|
||||
do
|
||||
export DEVICE_ID=$i
|
||||
export RANK_ID=$i
|
||||
export RANK_ID=$((rank_start + i))
|
||||
rm -rf ./train_parallel$i
|
||||
mkdir ./train_parallel$i
|
||||
cp ../*.py ./train_parallel$i
|
||||
|
|
|
@ -37,8 +37,7 @@ def create_dataset1(dataset_path, do_train, repeat_num=1, batch_size=32, target=
|
|||
dataset
|
||||
"""
|
||||
if target == "Ascend":
|
||||
device_num = int(os.getenv("DEVICE_NUM"))
|
||||
rank_id = int(os.getenv("RANK_ID"))
|
||||
device_num, rank_id = _get_rank_info()
|
||||
else:
|
||||
init("nccl")
|
||||
rank_id = get_rank()
|
||||
|
@ -93,8 +92,7 @@ def create_dataset2(dataset_path, do_train, repeat_num=1, batch_size=32, target=
|
|||
dataset
|
||||
"""
|
||||
if target == "Ascend":
|
||||
device_num = int(os.getenv("DEVICE_NUM"))
|
||||
rank_id = int(os.getenv("RANK_ID"))
|
||||
device_num, rank_id = _get_rank_info()
|
||||
else:
|
||||
init("nccl")
|
||||
rank_id = get_rank()
|
||||
|
@ -153,8 +151,7 @@ def create_dataset3(dataset_path, do_train, repeat_num=1, batch_size=32):
|
|||
Returns:
|
||||
dataset
|
||||
"""
|
||||
device_num = int(os.getenv("RANK_SIZE"))
|
||||
rank_id = int(os.getenv("RANK_ID"))
|
||||
device_num, rank_id = _get_rank_info()
|
||||
|
||||
if device_num == 1:
|
||||
ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=True)
|
||||
|
@ -203,3 +200,19 @@ def create_dataset3(dataset_path, do_train, repeat_num=1, batch_size=32):
|
|||
ds = ds.repeat(repeat_num)
|
||||
|
||||
return ds
|
||||
|
||||
|
||||
def _get_rank_info():
    """
    get rank size and rank id

    Reads the ``RANK_SIZE`` environment variable (treated as 1 when unset).
    When it is greater than 1, the values are taken from ``get_group_size()``
    and ``get_rank()``; otherwise the single-device values ``(1, 0)`` are
    returned.

    Returns:
        tuple, ``(rank_size, rank_id)``.

    Raises:
        ValueError: if ``RANK_SIZE`` is set to a string that ``int`` cannot
            parse.
    """
    rank_size = int(os.environ.get("RANK_SIZE", 1))

    if rank_size > 1:
        # NOTE(review): presumably the communication backend must already be
        # initialised before these calls - confirm against the callers.
        rank_size = get_group_size()
        rank_id = get_rank()
    else:
        # Any RANK_SIZE <= 1 (including 0 or negative) is normalised to the
        # single-device case.
        rank_size = 1
        rank_id = 0

    return rank_size, rank_id
|
||||
|
|
Loading…
Reference in New Issue