add checks for init
parent 07389e408a
commit aaf7045cdf
@@ -62,8 +62,8 @@ def init(backend_name=None):
     """
     if _is_role_pserver() or _is_role_sched():
         return
+    device_target = context.get_context("device_target")
     if backend_name is None:
-        device_target = context.get_context("device_target")
         if device_target == "Ascend":
             backend_name = "hccl"
         elif device_target == "GPU":
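With device_target hoisted out of the if block, default backend selection is unchanged when backend_name is omitted. A minimal usage sketch, assuming MindSpore's public context and communication APIs (the "nccl" default on GPU comes from the surrounding source, not this hunk):

    from mindspore import context
    from mindspore.communication.management import init

    # When backend_name is None, init() reads the device target and picks
    # the matching backend: "hccl" on Ascend, "nccl" on GPU.
    context.set_context(device_target="Ascend")
    init()  # equivalent to init("hccl") on an Ascend target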
@@ -74,6 +74,8 @@ def init(backend_name=None):
         raise TypeError("Backend name must be a string, but got {}".format(type(backend_name)))

     if backend_name == "hccl":
+        if device_target != "Ascend":
+            raise RuntimeError("Device target should be 'Ascend' to init hccl, but got {}".format(device_target))
         init_hccl()
         GlobalComm.BACKEND = Backend("hccl")
         GlobalComm.WORLD_COMM_GROUP = HCCL_WORLD_COMM_GROUP
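This is the check the commit title refers to: requesting the "hccl" backend on a non-Ascend target now fails fast with an explicit message. A hedged sketch of the new behavior (the "GPU" target value is illustrative):

    from mindspore import context
    from mindspore.communication.management import init

    context.set_context(device_target="GPU")
    try:
        init("hccl")  # backend does not match the device target
    except RuntimeError as err:
        print(err)  # Device target should be 'Ascend' to init hccl, but got GPU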
@@ -380,7 +380,8 @@ def set_auto_parallel_context(**kwargs):
         full_batch (bool): Whether to load the whole batch on each device. Default: False.
         enable_parallel_optimizer (bool): This is a developing feature, which shards the weight update computation in
             data parallel training in the benefit of time and memory saving.
-        all_reduce_fusion_config (list): Set allreduce fusion strategy by parameters indices.
+        all_reduce_fusion_config (list): Set allreduce fusion strategy by parameters indices. Only support ReduceOp.SUM
+            and HCCL_WORLD_GROUP/NCCL_WORLD_GROUP.

     Raises:
         ValueError: If input key is not attribute in auto parallel context.
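For context, a usage sketch of the parameter whose docstring this hunk tightens; the fusion indices below are illustrative and not part of this commit:

    from mindspore import context

    # Fuse allreduce by parameter index: gradients for parameters [0, 20)
    # go into one fused allreduce, [20, 35) into a second, the rest into a third.
    context.set_auto_parallel_context(all_reduce_fusion_config=[20, 35])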
@@ -33,6 +33,7 @@ from mindspore.ops.operations.comm_ops import Broadcast

 tag = 0

+context.set_context(device_target="Ascend")
 init("hccl")


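Setting the device target explicitly before init("hccl") keeps this example in line with the new guard above, which requires device_target to be "Ascend" for the hccl backend.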