add checks for init

lichenever 2020-09-14 11:12:24 +08:00
parent 07389e408a
commit aaf7045cdf
3 changed files with 6 additions and 2 deletions


@@ -62,8 +62,8 @@ def init(backend_name=None):
     """
     if _is_role_pserver() or _is_role_sched():
         return
+    device_target = context.get_context("device_target")
     if backend_name is None:
-        device_target = context.get_context("device_target")
         if device_target == "Ascend":
             backend_name = "hccl"
         elif device_target == "GPU":
@@ -74,6 +74,8 @@ def init(backend_name=None):
         raise TypeError("Backend name must be a string, but got {}".format(type(backend_name)))
     if backend_name == "hccl":
+        if device_target != "Ascend":
+            raise RuntimeError("Device target should be 'Ascend' to init hccl, but got {}".format(device_target))
         init_hccl()
         GlobalComm.BACKEND = Backend("hccl")
         GlobalComm.WORLD_COMM_GROUP = HCCL_WORLD_COMM_GROUP
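
For context (not part of the commit): with the check above, requesting the hccl backend on a non-Ascend target now fails fast instead of reaching init_hccl(). A minimal sketch, assuming the usual mindspore.communication.management import path:

from mindspore import context
from mindspore.communication.management import init

context.set_context(device_target="GPU")   # any non-Ascend target
try:
    init("hccl")                            # hccl now requires an Ascend device target
except RuntimeError as err:
    print(err)  # Device target should be 'Ascend' to init hccl, but got GPU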


@@ -380,7 +380,8 @@ def set_auto_parallel_context(**kwargs):
         full_batch (bool): Whether to load the whole batch on each device. Default: False.
         enable_parallel_optimizer (bool): This is a developing feature, which shards the weight update computation in
             data parallel training in the benefit of time and memory saving.
-        all_reduce_fusion_config (list): Set allreduce fusion strategy by parameters indices.
+        all_reduce_fusion_config (list): Set allreduce fusion strategy by parameters indices. Only support ReduceOp.SUM
+            and HCCL_WORLD_GROUP/NCCL_WORLD_GROUP.

     Raises:
         ValueError: If input key is not attribute in auto parallel context.
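
A hedged usage sketch of the parameter documented above (the indices 20 and 35 are illustrative, not from the commit): the boundaries in all_reduce_fusion_config split gradient allreduce into fusion groups, and per the amended docstring only ReduceOp.SUM over HCCL_WORLD_GROUP/NCCL_WORLD_GROUP is supported.

from mindspore import context

# Fuse allreduce operations into groups split at parameter indices 20 and 35.
context.set_auto_parallel_context(all_reduce_fusion_config=[20, 35])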


@@ -33,6 +33,7 @@ from mindspore.ops.operations.comm_ops import Broadcast
 tag = 0
+context.set_context(device_target="Ascend")
 init("hccl")