Add Device Number Check

huangxinjing 2021-11-01 11:56:24 +08:00
parent 5ed7040d4c
commit 741e3604eb
5 changed files with 16 additions and 7 deletions

View File

@@ -125,6 +125,9 @@ class ParallelContext {
   }
   bool enable_parallel_optimizer() const { return enable_parallel_optimizer_; }
+  void set_hccl_test_available(bool hccl_test_available) { hccl_test_available_ = hccl_test_available; }
+  bool hccl_test_available() const { return hccl_test_available_; }
   bool set_communi_parallel_mode(const std::string &communi_parallel_mode);
   std::string communi_parallel_mode() const { return communi_parallel_mode_; }
   void set_sharding_propagation(const bool);
@@ -178,6 +181,7 @@ class ParallelContext {
   bool enable_all2all_;
   std::vector<std::vector<int64_t>> dataset_strategy_;
   bool dataset_repeat_dim_right_ = false;
+  bool hccl_test_available_ = false;
 };
 }  // namespace parallel

View File

@@ -2929,15 +2929,15 @@ CommInfo GetCommInfo() {
     device_num = UintToInt(world_rank_size);
     MS_LOG(INFO) << "Get device num from communication model, the device num is " << device_num;
   }
-#if defined(ENABLE_GPU)
-  if (ParallelContext::GetInstance()->device_num_is_set() && backend == kGPUDevice) {
-    if (world_rank_size != device_num) {
-      MS_LOG(EXCEPTION) << "The device_num " << device_num
-                        << " set in the context is not consist with the word group size " << world_rank_size;
-    }
+#if ENABLE_D || ENABLE_GPU
+  if (ParallelContext::GetInstance()->device_num_is_set() && world_rank_size != device_num &&
+      !ParallelContext::GetInstance()->hccl_test_available()) {
+    // hccl_test_available is used when we compile graphs in real ascend card environment, but with hccl_test.
+    MS_LOG(EXCEPTION) << "The device_num " << device_num << " set in the context is not consist with "
+                      << world_rank_size << " devices you have "
+                      << ". Please check your rank_table file(for Ascend) or host file(for GPU).";
   }
 #endif
   uint32_t rank_id = 0;
   if (!ParallelContext::GetInstance()->global_rank_is_set()) {
     if (!CommManager::GetInstance().GetRankID(world_group, &rank_id)) {
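The consolidated check above replaces the old GPU-only branch: on both Ascend (ENABLE_D) and GPU builds, a user-set device_num that disagrees with the actual world group size now raises at graph compilation, unless the hccl_test stub is active. A minimal Python sketch of the same rule, not part of the commit (check_configured_device_num is a hypothetical helper; mindspore.communication.get_group_size is a real query that requires init() first):

from mindspore.communication import get_group_size

def check_configured_device_num(device_num, hccl_test_available=False):
    # Mirror of the GetCommInfo() check: compare the configured device_num
    # against the size of the world communication group, and skip the check
    # when the hccl_test stub stands in for real Ascend cards.
    world_rank_size = get_group_size()
    if world_rank_size != device_num and not hccl_test_available:
        raise RuntimeError(
            "The device_num {} set in the context is not consistent with the "
            "{} devices you have. Please check your rank_table file (for "
            "Ascend) or host file (for GPU).".format(device_num, world_rank_size))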

View File

@@ -135,6 +135,7 @@ PYBIND11_MODULE(_c_expression, m) {
   (void)py::class_<ParallelContext, std::shared_ptr<ParallelContext>>(m, "AutoParallelContext")
     .def_static("get_instance", &ParallelContext::GetInstance, "Get auto parallel context instance.")
     .def("get_device_num", &ParallelContext::device_num, "Get device num.")
+    .def("set_hccl_test_avaible", &ParallelContext::set_hccl_test_available, "Set hccl test available.")
     .def("set_device_num", &ParallelContext::set_device_num, "Set device num.")
     .def("get_device_num_is_set", &ParallelContext::device_num_is_set, "Get device num is set.")
     .def("get_global_rank", &ParallelContext::global_rank, "Get global rank.")

View File

@@ -19,6 +19,7 @@ from mindspore import log as logger
 from ._hccl_management import load_lib as hccl_load_lib

 _HCCL_AVAILABLE = False
+_HCCL_TEST_AVAILABLE = False
 _NCCL_AVAILABLE = False
 _MPI_AVAILABLE = False
 try:
@@ -45,6 +46,7 @@ else:
     try:
         import hccl_test.manage.api as hccl
         _HCCL_AVAILABLE = True
+        _HCCL_TEST_AVAILABLE = True
     except ImportError:
         _HCCL_AVAILABLE = False
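The new module-level flag follows this file's existing probe pattern: importing the hccl_test stub only succeeds in the simulated environment, so a successful import flips both flags. A small sketch of consuming the flag (assuming the import path above):

from mindspore.communication._comm_helper import _HCCL_TEST_AVAILABLE

if _HCCL_TEST_AVAILABLE:
    print("hccl_test stub detected; the device-number check will be skipped")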

View File

@@ -72,6 +72,8 @@ class _AutoParallelContext:
         self.check_context_handle()
         if device_num < 1 or device_num > 4096:
             raise ValueError("Device num must be in [1, 4096], but got {}".format(device_num))
+        from mindspore.communication._comm_helper import _HCCL_TEST_AVAILABLE
+        self._context_handle.set_hccl_test_avaible(_HCCL_TEST_AVAILABLE)
         self._context_handle.set_device_num(device_num)

     def get_device_num(self):
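Taken together, set_device_num() now forwards the hccl_test flag before recording device_num, so the later GetCommInfo() check can fail fast on a mismatch. A hedged end-to-end usage sketch (real public APIs, illustrative values):

import mindspore.context as context
from mindspore.communication import init

init()  # join the communication world (HCCL on Ascend, NCCL on GPU)
# device_num must equal the real world group size unless hccl_test is
# active; otherwise graph compilation raises the new device_num exception.
context.set_auto_parallel_context(device_num=8)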