forked from mindspore-Ecosystem/mindspore
support multi server muli process
This commit is contained in:
parent
bf699955b1
commit
d9ecfb1858
|
@ -33,10 +33,12 @@ MINDSPORE_HCCL_CONFIG_PATH=$(realpath $1)
|
||||||
export MINDSPORE_HCCL_CONFIG_PATH
|
export MINDSPORE_HCCL_CONFIG_PATH
|
||||||
echo "MINDSPORE_HCCL_CONFIG_PATH=${MINDSPORE_HCCL_CONFIG_PATH}"
|
echo "MINDSPORE_HCCL_CONFIG_PATH=${MINDSPORE_HCCL_CONFIG_PATH}"
|
||||||
|
|
||||||
|
export SERVER_ID=0
|
||||||
|
rank_start=$((DEVICE_NUM * SERVER_ID))
|
||||||
for((i=0; i<${DEVICE_NUM}; i++))
|
for((i=0; i<${DEVICE_NUM}; i++))
|
||||||
do
|
do
|
||||||
export DEVICE_ID=$i
|
export DEVICE_ID=$i
|
||||||
export RANK_ID=$i
|
export RANK_ID=$((rank_start + i))
|
||||||
rm -rf ./train_parallel$i
|
rm -rf ./train_parallel$i
|
||||||
mkdir ./train_parallel$i
|
mkdir ./train_parallel$i
|
||||||
cp -r ./src ./train_parallel$i
|
cp -r ./src ./train_parallel$i
|
||||||
|
|
|
@ -31,8 +31,7 @@ def create_dataset(data_home, repeat_num=1, training=True):
|
||||||
if not training:
|
if not training:
|
||||||
data_dir = os.path.join(data_home, "cifar-10-verify-bin")
|
data_dir = os.path.join(data_home, "cifar-10-verify-bin")
|
||||||
|
|
||||||
rank_size = int(os.environ.get("RANK_SIZE")) if os.environ.get("RANK_SIZE") else None
|
rank_size, rank_id = _get_rank_info()
|
||||||
rank_id = int(os.environ.get("RANK_ID")) if os.environ.get("RANK_ID") else None
|
|
||||||
data_set = ds.Cifar10Dataset(data_dir, num_shards=rank_size, shard_id=rank_id)
|
data_set = ds.Cifar10Dataset(data_dir, num_shards=rank_size, shard_id=rank_id)
|
||||||
|
|
||||||
resize_height = cfg.image_height
|
resize_height = cfg.image_height
|
||||||
|
@ -65,3 +64,19 @@ def create_dataset(data_home, repeat_num=1, training=True):
|
||||||
data_set = data_set.batch(batch_size=cfg.batch_size, drop_remainder=True)
|
data_set = data_set.batch(batch_size=cfg.batch_size, drop_remainder=True)
|
||||||
|
|
||||||
return data_set
|
return data_set
|
||||||
|
|
||||||
|
|
||||||
|
def _get_rank_info():
|
||||||
|
"""
|
||||||
|
get rank size and rank id
|
||||||
|
"""
|
||||||
|
rank_size = int(os.environ.get("RANK_SIZE", 1))
|
||||||
|
|
||||||
|
if rank_size > 1:
|
||||||
|
from mindspore.communication.management import get_rank, get_group_size
|
||||||
|
rank_size = get_group_size()
|
||||||
|
rank_id = get_rank()
|
||||||
|
else:
|
||||||
|
rank_size = rank_id = None
|
||||||
|
|
||||||
|
return rank_size, rank_id
|
||||||
|
|
Loading…
Reference in New Issue