From 741e3604eb058a030987df76751864ff5b2812ef Mon Sep 17 00:00:00 2001
From: huangxinjing
Date: Mon, 1 Nov 2021 11:56:24 +0800
Subject: [PATCH] Add Device Number Check

Check that the device_num set in the auto parallel context matches the
actual world group size on Ascend and GPU. The check is skipped when
graphs are compiled with hccl_test in a real Ascend card environment
(hccl_test_available).
---
 mindspore/ccsrc/frontend/parallel/context.h        |  4 ++++
 mindspore/ccsrc/frontend/parallel/step_parallel.cc | 14 +++++++-------
 mindspore/ccsrc/pipeline/jit/init.cc               |  1 +
 mindspore/communication/_comm_helper.py            |  2 ++
 mindspore/parallel/_auto_parallel_context.py       |  2 ++
 5 files changed, 16 insertions(+), 7 deletions(-)

diff --git a/mindspore/ccsrc/frontend/parallel/context.h b/mindspore/ccsrc/frontend/parallel/context.h
index d447c4081fd..01ce0abfc79 100644
--- a/mindspore/ccsrc/frontend/parallel/context.h
+++ b/mindspore/ccsrc/frontend/parallel/context.h
@@ -125,6 +125,9 @@ class ParallelContext {
   }
   bool enable_parallel_optimizer() const { return enable_parallel_optimizer_; }
 
+  void set_hccl_test_available(bool hccl_test_available) { hccl_test_available_ = hccl_test_available; }
+  bool hccl_test_available() const { return hccl_test_available_; }
+
   bool set_communi_parallel_mode(const std::string &communi_parallel_mode);
   std::string communi_parallel_mode() const { return communi_parallel_mode_; }
   void set_sharding_propagation(const bool);
@@ -178,6 +181,7 @@ class ParallelContext {
   bool enable_all2all_;
   std::vector<std::vector<int64_t>> dataset_strategy_;
   bool dataset_repeat_dim_right_ = false;
+  bool hccl_test_available_ = false;
 };
 
 }  // namespace parallel
diff --git a/mindspore/ccsrc/frontend/parallel/step_parallel.cc b/mindspore/ccsrc/frontend/parallel/step_parallel.cc
index 8a9d3da7046..fcdb96cde0e 100644
--- a/mindspore/ccsrc/frontend/parallel/step_parallel.cc
+++ b/mindspore/ccsrc/frontend/parallel/step_parallel.cc
@@ -2929,15 +2929,15 @@ CommInfo GetCommInfo() {
     device_num = UintToInt(world_rank_size);
     MS_LOG(INFO) << "Get device num from communication model, the device num is " << device_num;
   }
-#if defined(ENABLE_GPU)
-  if (ParallelContext::GetInstance()->device_num_is_set() && backend == kGPUDevice) {
-    if (world_rank_size != device_num) {
-      MS_LOG(EXCEPTION) << "The device_num " << device_num
-                        << " set in the context is not consist with the word group size " << world_rank_size;
-    }
+#if defined(ENABLE_D) || defined(ENABLE_GPU)
+  if (ParallelContext::GetInstance()->device_num_is_set() && world_rank_size != device_num &&
+      !ParallelContext::GetInstance()->hccl_test_available()) {
+    // hccl_test_available is used when we compile graphs in a real Ascend card environment, but with hccl_test.
+    MS_LOG(EXCEPTION) << "The device_num " << device_num << " set in the context is not consistent with the "
+                      << world_rank_size << " devices you have."
+                      << " Please check your rank_table file (for Ascend) or hostfile (for GPU).";
   }
 #endif
-
   uint32_t rank_id = 0;
   if (!ParallelContext::GetInstance()->global_rank_is_set()) {
     if (!CommManager::GetInstance().GetRankID(world_group, &rank_id)) {
diff --git a/mindspore/ccsrc/pipeline/jit/init.cc b/mindspore/ccsrc/pipeline/jit/init.cc
index 72d69649b7e..81bfafe5014 100644
--- a/mindspore/ccsrc/pipeline/jit/init.cc
+++ b/mindspore/ccsrc/pipeline/jit/init.cc
@@ -135,6 +135,7 @@ PYBIND11_MODULE(_c_expression, m) {
   (void)py::class_<ParallelContext, std::shared_ptr<ParallelContext>>(m, "AutoParallelContext")
     .def_static("get_instance", &ParallelContext::GetInstance, "Get auto parallel context instance.")
     .def("get_device_num", &ParallelContext::device_num, "Get device num.")
+    .def("set_hccl_test_avaible", &ParallelContext::set_hccl_test_available, "Set hccl test available.")
     .def("set_device_num", &ParallelContext::set_device_num, "Set device num.")
     .def("get_device_num_is_set", &ParallelContext::device_num_is_set, "Get device num is set.")
     .def("get_global_rank", &ParallelContext::global_rank, "Get global rank.")
diff --git a/mindspore/communication/_comm_helper.py b/mindspore/communication/_comm_helper.py
index 9e5d8513319..b9dc8b452c8 100644
--- a/mindspore/communication/_comm_helper.py
+++ b/mindspore/communication/_comm_helper.py
@@ -19,6 +19,7 @@ from mindspore import log as logger
 from ._hccl_management import load_lib as hccl_load_lib
 
 _HCCL_AVAILABLE = False
+_HCCL_TEST_AVAILABLE = False
 _NCCL_AVAILABLE = False
 _MPI_AVAILABLE = False
 try:
@@ -45,6 +46,7 @@ else:
     try:
         import hccl_test.manage.api as hccl
         _HCCL_AVAILABLE = True
+        _HCCL_TEST_AVAILABLE = True
     except ImportError:
         _HCCL_AVAILABLE = False
 
diff --git a/mindspore/parallel/_auto_parallel_context.py b/mindspore/parallel/_auto_parallel_context.py
index d15b39b3018..ebacec38adb 100644
--- a/mindspore/parallel/_auto_parallel_context.py
+++ b/mindspore/parallel/_auto_parallel_context.py
@@ -72,6 +72,8 @@ class _AutoParallelContext:
         self.check_context_handle()
         if device_num < 1 or device_num > 4096:
             raise ValueError("Device num must be in [1, 4096], but got {}".format(device_num))
+        from mindspore.communication._comm_helper import _HCCL_TEST_AVAILABLE
+        self._context_handle.set_hccl_test_avaible(_HCCL_TEST_AVAILABLE)
         self._context_handle.set_device_num(device_num)
 
     def get_device_num(self):