forked from mindspore-Ecosystem/mindspore
!10656 [MD] Fix a rank id runtime error
From: @xiefangqi Reviewed-by: @liucunwei,@pandoublefeng Signed-off-by: @liucunwei
This commit is contained in:
commit
2675d13804
|
@ -41,7 +41,7 @@ def _init_device_info():
|
|||
"""
|
||||
from mindspore import context
|
||||
from mindspore.parallel._auto_parallel_context import auto_parallel_context
|
||||
from mindspore.parallel._utils import _get_global_rank, _get_device_num
|
||||
from mindspore.parallel._utils import _get_global_rank
|
||||
if context.get_context("device_target") == "GPU":
|
||||
rank_id = _get_global_rank()
|
||||
parallel_mode = auto_parallel_context().get_parallel_mode()
|
||||
|
@ -53,11 +53,15 @@ def _init_device_info():
|
|||
rank_id = cuda_id
|
||||
_config.set_rank_id(rank_id)
|
||||
elif context.get_context("device_target") == "Ascend":
|
||||
rank_id = _get_global_rank()
|
||||
device_num = _get_device_num()
|
||||
# Ascend only support multi-process scenario
|
||||
if device_num > 1:
|
||||
_config.set_rank_id(rank_id)
|
||||
# Ascend is a special scenario, we'd better get rank info from env
|
||||
env_rank_size = os.getenv("RANK_SIZE", None)
|
||||
env_rank_id = os.getenv("RANK_ID", None)
|
||||
if env_rank_size and env_rank_id:
|
||||
# Ascend only support multi-process scenario
|
||||
rank_size = int(env_rank_size.strip())
|
||||
rank_id = int(env_rank_id.strip())
|
||||
if rank_size > 1:
|
||||
_config.set_rank_id(rank_id)
|
||||
|
||||
|
||||
def set_seed(seed):
|
||||
|
|
Loading…
Reference in New Issue