diff --git a/mindspore/communication/management.py b/mindspore/communication/management.py index 53da7291d8c..5d13bb5e7b1 100755 --- a/mindspore/communication/management.py +++ b/mindspore/communication/management.py @@ -62,8 +62,8 @@ def init(backend_name=None): """ if _is_role_pserver() or _is_role_sched(): return + device_target = context.get_context("device_target") if backend_name is None: - device_target = context.get_context("device_target") if device_target == "Ascend": backend_name = "hccl" elif device_target == "GPU": @@ -74,6 +74,8 @@ def init(backend_name=None): raise TypeError("Backend name must be a string, but got {}".format(type(backend_name))) if backend_name == "hccl": + if device_target != "Ascend": + raise RuntimeError("Device target should be 'Ascend' to init hccl, but got {}".format(device_target)) init_hccl() GlobalComm.BACKEND = Backend("hccl") GlobalComm.WORLD_COMM_GROUP = HCCL_WORLD_COMM_GROUP diff --git a/mindspore/context.py b/mindspore/context.py index 0d97bcb0edd..b6ceca1183f 100644 --- a/mindspore/context.py +++ b/mindspore/context.py @@ -380,7 +380,8 @@ def set_auto_parallel_context(**kwargs): full_batch (bool): Whether to load the whole batch on each device. Default: False. enable_parallel_optimizer (bool): This is a developing feature, which shards the weight update computation in data parallel training in the benefit of time and memory saving. - all_reduce_fusion_config (list): Set allreduce fusion strategy by parameters indices. + all_reduce_fusion_config (list): Set allreduce fusion strategy by parameters indices. Only supports ReduceOp.SUM + and HCCL_WORLD_GROUP/NCCL_WORLD_GROUP. Raises: ValueError: If input key is not attribute in auto parallel context. 
diff --git a/tests/ut/python/communication/test_comm.py b/tests/ut/python/communication/test_comm.py index 7688adb41a5..7df6149e7b6 100644 --- a/tests/ut/python/communication/test_comm.py +++ b/tests/ut/python/communication/test_comm.py @@ -33,6 +33,7 @@ from mindspore.ops.operations.comm_ops import Broadcast tag = 0 +context.set_context(device_target="Ascend") init("hccl")