diff --git a/mindspore/communication/_comm_helper.py b/mindspore/communication/_comm_helper.py index 46f77f4ed8b..52ec1cebc3d 100644 --- a/mindspore/communication/_comm_helper.py +++ b/mindspore/communication/_comm_helper.py @@ -84,6 +84,7 @@ class GlobalComm: BACKEND = DEFAULT_BACKEND WORLD_COMM_GROUP = HCCL_WORLD_COMM_GROUP INITED = False + CHECK_ENVS = True def is_hccl_available(): """ diff --git a/mindspore/communication/_hccl_management.py b/mindspore/communication/_hccl_management.py index f9610cac6cb..4fe07110a9a 100644 --- a/mindspore/communication/_hccl_management.py +++ b/mindspore/communication/_hccl_management.py @@ -24,7 +24,7 @@ HCCL_LIB_CTYPES = "" def check_group(group): """ - A function that check if a collection communication group is leagal. + A function that checks if a collection communication group is legal. Returns: None @@ -39,7 +39,7 @@ def check_group(group): def check_rank_num(rank_num): """ - A function that check if a collection communication rank number is leagal.If not raise error. + A function that checks if a collection communication rank number is legal. If not, raise an error. Returns: None @@ -53,7 +53,7 @@ def check_rank_num(rank_num): def check_rank_id(rank_id): """ - A function that check if a collection communication rank id is leagal.If not raise error. + A function that checks if a collection communication rank id is legal. If not, raise an error. 
Returns: None @@ -112,7 +112,7 @@ def create_group(group, rank_num, rank_ids): c_group = c_str(group) ret = HCCL_LIB_CTYPES.HcomCreateGroup(c_group, c_rank_num, c_array_rank_ids) if ret != 0: - raise RuntimeError('Create group error.') + raise RuntimeError('Create group error, the error code is ', ret) else: raise TypeError('Rank ids must be a python list.') diff --git a/mindspore/communication/management.py b/mindspore/communication/management.py index a1793b1f5fe..de16ce2e99c 100755 --- a/mindspore/communication/management.py +++ b/mindspore/communication/management.py @@ -36,6 +36,28 @@ def _get_group(group): return group +def _check_parallel_envs(): + """ + Check whether parallel environment variables have been exported or not. + + Raises: + RuntimeError: If parallel environment variables have not been exported or have been exported to wrong values. + """ + if not GlobalComm.CHECK_ENVS: + return + import os + rank_id_str = os.getenv("RANK_ID") + if not rank_id_str: + raise RuntimeError("Environment variables RANK_ID has not been exported") + try: + int(rank_id_str) + except ValueError: + print("RANK_ID should be number") + rank_table_file_str = os.getenv("MINDSPORE_HCCL_CONFIG_PATH") + rank_table_file_str_old = os.getenv("RANK_TABLE_FILE") + if not rank_table_file_str and not rank_table_file_str_old: + raise RuntimeError("Get hccl rank_table_file failed, " + "please export MINDSPORE_HCCL_CONFIG_PATH or RANK_TABLE_FILE.") def init(backend_name=None): """ @@ -68,6 +90,7 @@ def init(backend_name=None): if backend_name == "hccl": if device_target != "Ascend": raise RuntimeError("Device target should be 'Ascend' to init hccl, but got {}".format(device_target)) + _check_parallel_envs() init_hccl() GlobalComm.BACKEND = Backend("hccl") GlobalComm.WORLD_COMM_GROUP = HCCL_WORLD_COMM_GROUP diff --git a/tests/ut/python/communication/test_comm.py b/tests/ut/python/communication/test_comm.py index 7f22a1a58df..9285b06b41e 100644 --- 
a/tests/ut/python/communication/test_comm.py +++ b/tests/ut/python/communication/test_comm.py @@ -30,13 +30,16 @@ from mindspore.ops.operations.comm_ops import Broadcast, AllSwap from mindspore.ops.operations.array_ops import Gather import mindspore + # pylint: disable=W0212 # W0212: protected-access tag = 0 context.set_context(device_target="Ascend") +GlobalComm.CHECK_ENVS = False init("hccl") +GlobalComm.CHECK_ENVS = True class AllReduceNet(nn.Cell): diff --git a/tests/ut/python/parallel/test_auto_parallel_resnet.py b/tests/ut/python/parallel/test_auto_parallel_resnet.py index fce38a32558..d5a3f1a20b5 100644 --- a/tests/ut/python/parallel/test_auto_parallel_resnet.py +++ b/tests/ut/python/parallel/test_auto_parallel_resnet.py @@ -31,11 +31,13 @@ from mindspore.parallel import set_algo_parameters from mindspore.parallel._utils import _reset_op_id as resset_op_id from mindspore.train.model import Model from mindspore.context import ParallelMode +from mindspore.communication._comm_helper import GlobalComm context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") context.set_context(device_id=0) +GlobalComm.CHECK_ENVS = False init() - +GlobalComm.CHECK_ENVS = True def weight_variable(): return TruncatedNormal(0.02) diff --git a/tests/ut/python/parallel/test_auto_parallel_resnet_predict.py b/tests/ut/python/parallel/test_auto_parallel_resnet_predict.py index a959bcadf9d..7510b02c265 100644 --- a/tests/ut/python/parallel/test_auto_parallel_resnet_predict.py +++ b/tests/ut/python/parallel/test_auto_parallel_resnet_predict.py @@ -18,11 +18,15 @@ from mindspore.communication.management import init from mindspore.parallel import set_algo_parameters from mindspore.train.model import Model from mindspore.context import ParallelMode +from mindspore.communication._comm_helper import GlobalComm from .test_auto_parallel_resnet import resnet50 + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") context.set_context(device_id=0) +GlobalComm.CHECK_ENVS 
= False init() +GlobalComm.CHECK_ENVS = True def test_train_32k_8p(batch_size=32, num_classes=32768): dev_num = 8 diff --git a/tests/ut/python/parallel/test_broadcast_dict.py b/tests/ut/python/parallel/test_broadcast_dict.py index ff02d045ca9..6128ad38932 100644 --- a/tests/ut/python/parallel/test_broadcast_dict.py +++ b/tests/ut/python/parallel/test_broadcast_dict.py @@ -19,7 +19,7 @@ import mindspore.nn as nn from mindspore import Tensor, Parameter from mindspore.communication.management import init from mindspore.ops import operations as P - +from mindspore.communication._comm_helper import GlobalComm class DataParallelNet(nn.Cell): def __init__(self): @@ -49,7 +49,9 @@ def test_param_broadcast(): context.set_context(mode=context.GRAPH_MODE) context.reset_auto_parallel_context() context.set_auto_parallel_context(parallel_mode="data_parallel", parameter_broadcast=True) + GlobalComm.CHECK_ENVS = False init() + GlobalComm.CHECK_ENVS = True network = DataParallelNet() network.set_train() @@ -62,7 +64,9 @@ def test_param_not_broadcast(): context.set_context(mode=context.GRAPH_MODE) context.reset_auto_parallel_context() context.set_auto_parallel_context(parallel_mode="data_parallel", parameter_broadcast=False) + GlobalComm.CHECK_ENVS = False init() + GlobalComm.CHECK_ENVS = True network = ModelParallelNet() network.set_train() diff --git a/tests/ut/python/parallel/test_gather_v2_primitive.py b/tests/ut/python/parallel/test_gather_v2_primitive.py index e7ccfe1f18e..10f5dd2fdb4 100644 --- a/tests/ut/python/parallel/test_gather_v2_primitive.py +++ b/tests/ut/python/parallel/test_gather_v2_primitive.py @@ -29,6 +29,7 @@ from mindspore.ops import functional as F from mindspore.ops import operations as P from mindspore.train import Model from mindspore.context import ParallelMode +from mindspore.communication._comm_helper import GlobalComm context.set_context(mode=context.GRAPH_MODE) device_number = 32 @@ -124,7 +125,9 @@ class TrainOneStepCell(Cell): def 
net_trains(criterion, rank): + GlobalComm.CHECK_ENVS = False init() + GlobalComm.CHECK_ENVS = True lr = 0.1 momentum = 0.9 max_epoch = 20 diff --git a/tests/ut/python/parallel/test_optimizer.py b/tests/ut/python/parallel/test_optimizer.py index 112069f5f19..55a22ddf608 100644 --- a/tests/ut/python/parallel/test_optimizer.py +++ b/tests/ut/python/parallel/test_optimizer.py @@ -24,7 +24,7 @@ from mindspore.nn import Momentum from mindspore.nn import TrainOneStepCell, WithLossCell from mindspore.ops import operations as P from mindspore.context import ParallelMode - +from mindspore.communication._comm_helper import GlobalComm class Net(nn.Cell): def __init__(self, input_channel, out_channel): @@ -47,7 +47,9 @@ def test_dense_gen_graph(): context.set_context(mode=context.GRAPH_MODE) context.reset_auto_parallel_context() context.set_auto_parallel_context(parallel_mode=ParallelMode.HYBRID_PARALLEL, gradients_mean=True, device_num=8) + GlobalComm.CHECK_ENVS = False init() + GlobalComm.CHECK_ENVS = True network = Net(512, 128) loss_fn = nn.SoftmaxCrossEntropyWithLogits() diff --git a/tests/ut/python/train/test_amp.py b/tests/ut/python/train/test_amp.py index 03102025046..34601ab39c5 100644 --- a/tests/ut/python/train/test_amp.py +++ b/tests/ut/python/train/test_amp.py @@ -21,6 +21,7 @@ from mindspore import Tensor from mindspore import amp from mindspore import nn from mindspore.communication.management import init +from mindspore.communication._comm_helper import GlobalComm from mindspore.context import ParallelMode from mindspore.train import Model from ....dataset_mock import MindData @@ -156,8 +157,8 @@ def test_compile_model_train_O2_parallel(): net = NetNoLoss(16, 16) loss = nn.MSELoss() optimizer = nn.Momentum(net.trainable_params(), 0.1, 0.9, 0.00004, 1024.0) - + GlobalComm.CHECK_ENVS = False init() - + GlobalComm.CHECK_ENVS = True model = Model(net, loss_fn=loss, optimizer=optimizer, metrics={"acc"}, amp_level="O2") model.train(2, dataset, dataset_sink_mode=False) 
diff --git a/tests/ut/python/train/test_dataset_helper.py b/tests/ut/python/train/test_dataset_helper.py index 916d87caad8..2cad5d6bb57 100644 --- a/tests/ut/python/train/test_dataset_helper.py +++ b/tests/ut/python/train/test_dataset_helper.py @@ -19,9 +19,9 @@ import numpy as np import mindspore.context as context from mindspore.communication.management import init from mindspore.train.dataset_helper import DatasetHelper +from mindspore.communication._comm_helper import GlobalComm from ....dataset_mock import MindData - def get_dataset(batch_size=1): dataset_types = (np.int32, np.int32, np.int32, np.int32, np.int32, np.int32, np.int32) dataset_shapes = ((batch_size, 128), (batch_size, 128), (batch_size, 128), (batch_size, 1), @@ -75,7 +75,9 @@ def test_dataset_iter_normal(): @pytest.mark.skipif('not context.get_context("enable_ge")') def test_dataset_iter_ge(): + GlobalComm.CHECK_ENVS = False init("hccl") + GlobalComm.CHECK_ENVS = True dataset = get_dataset(32) dataset_helper = DatasetHelper(dataset, dataset_sink_mode=True, sink_size=10) count = 0 @@ -87,7 +89,9 @@ def test_dataset_iter_ge(): @pytest.mark.skipif('context.get_context("enable_ge")') def test_dataset_iter_ms_loop_sink(): + GlobalComm.CHECK_ENVS = False init("hccl") + GlobalComm.CHECK_ENVS = True context.set_context(enable_loop_sink=True) dataset = get_dataset(32) dataset_helper = DatasetHelper(dataset, dataset_sink_mode=True, sink_size=10) @@ -101,7 +105,9 @@ def test_dataset_iter_ms_loop_sink(): @pytest.mark.skipif('context.get_context("enable_ge")') def test_dataset_iter_ms(): + GlobalComm.CHECK_ENVS = False init("hccl") + GlobalComm.CHECK_ENVS = True context.set_context(enable_loop_sink=False) dataset = get_dataset(32) DatasetHelper(dataset, dataset_sink_mode=True, sink_size=10)