Add Device Number Check

huangxinjing 2021-11-01 11:56:24 +08:00
parent 5ed7040d4c
commit 741e3604eb
5 changed files with 16 additions and 7 deletions

View File

@@ -125,6 +125,9 @@ class ParallelContext {
}
bool enable_parallel_optimizer() const { return enable_parallel_optimizer_; }
void set_hccl_test_available(bool hccl_test_available) { hccl_test_available_ = hccl_test_available; }
bool hccl_test_available() const { return hccl_test_available_; }
bool set_communi_parallel_mode(const std::string &communi_parallel_mode);
std::string communi_parallel_mode() const { return communi_parallel_mode_; }
void set_sharding_propagation(const bool);
@@ -178,6 +181,7 @@ class ParallelContext {
bool enable_all2all_;
std::vector<std::vector<int64_t>> dataset_strategy_;
bool dataset_repeat_dim_right_ = false;
bool hccl_test_available_ = false;
};
} // namespace parallel

View File

@@ -2929,15 +2929,15 @@ CommInfo GetCommInfo() {
device_num = UintToInt(world_rank_size);
MS_LOG(INFO) << "Get device num from communication model, the device num is " << device_num;
}
#if defined(ENABLE_GPU)
if (ParallelContext::GetInstance()->device_num_is_set() && backend == kGPUDevice) {
if (world_rank_size != device_num) {
MS_LOG(EXCEPTION) << "The device_num " << device_num
<< " set in the context is not consist with the word group size " << world_rank_size;
}
#if ENABLE_D || ENABLE_GPU
if (ParallelContext::GetInstance()->device_num_is_set() && world_rank_size != device_num &&
!ParallelContext::GetInstance()->hccl_test_available()) {
// hccl_test_available covers the case where graphs are compiled in a real Ascend card environment but with the hccl_test stub.
MS_LOG(EXCEPTION) << "The device_num " << device_num << " set in the context is not consistent with the "
<< world_rank_size << " devices you actually have."
<< " Please check your rank_table file (for Ascend) or host file (for GPU).";
}
#endif
uint32_t rank_id = 0;
if (!ParallelContext::GetInstance()->global_rank_is_set()) {
if (!CommManager::GetInstance().GetRankID(world_group, &rank_id)) {

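In user terms, the new guard means the device_num passed to the auto parallel context must match the size of the communication group that was actually initialized, unless the hccl_test stub is loaded. A minimal sketch of the intended usage, assuming a distributed Ascend/GPU job where the communication backend has been brought up; the APIs below are the standard MindSpore ones, not part of this patch:

from mindspore import context
from mindspore.communication.management import init, get_group_size

init()                          # start HCCL (Ascend) or NCCL (GPU)
world_size = get_group_size()   # devices actually in the world group
# device_num must equal world_size, otherwise graph compilation now raises
# the exception added above (unless hccl_test is available).
context.set_auto_parallel_context(device_num=world_size, parallel_mode="semi_auto_parallel")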
View File

@@ -135,6 +135,7 @@ PYBIND11_MODULE(_c_expression, m) {
(void)py::class_<ParallelContext, std::shared_ptr<ParallelContext>>(m, "AutoParallelContext")
.def_static("get_instance", &ParallelContext::GetInstance, "Get auto parallel context instance.")
.def("get_device_num", &ParallelContext::device_num, "Get device num.")
.def("set_hccl_test_avaible", &ParallelContext::set_hccl_test_available, "Set hccl test available.")
.def("set_device_num", &ParallelContext::set_device_num, "Set device num.")
.def("get_device_num_is_set", &ParallelContext::device_num_is_set, "Get device num is set.")
.def("get_global_rank", &ParallelContext::global_rank, "Get global rank.")

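The binding registered here is what the Python layer calls. A rough sketch of reaching it directly, assuming the usual _c_expression import path (normally this goes through _AutoParallelContext, as the last file shows; note the exported name keeps the "avaible" spelling used above):

from mindspore._c_expression import AutoParallelContext

ctx = AutoParallelContext.get_instance()   # singleton exposed via the static binding
ctx.set_hccl_test_avaible(True)            # forwards to ParallelContext::set_hccl_test_available
ctx.set_device_num(8)                      # stored device num, later checked in GetCommInfo()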
View File

@@ -19,6 +19,7 @@ from mindspore import log as logger
from ._hccl_management import load_lib as hccl_load_lib
_HCCL_AVAILABLE = False
_HCCL_TEST_AVAILABLE = False
_NCCL_AVAILABLE = False
_MPI_AVAILABLE = False
try:
@@ -45,6 +46,7 @@ else:
try:
import hccl_test.manage.api as hccl
_HCCL_AVAILABLE = True
_HCCL_TEST_AVAILABLE = True
except ImportError:
_HCCL_AVAILABLE = False

View File

@@ -72,6 +72,8 @@ class _AutoParallelContext:
self.check_context_handle()
if device_num < 1 or device_num > 4096:
raise ValueError("Device num must be in [1, 4096], but got {}".format(device_num))
from mindspore.communication._comm_helper import _HCCL_TEST_AVAILABLE
self._context_handle.set_hccl_test_avaible(_HCCL_TEST_AVAILABLE)
self._context_handle.set_device_num(device_num)
def get_device_num(self):
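Putting the pieces together, the only user-visible change is that set_device_num now tells the C++ side whether the hccl_test stub is loaded before the number is stored, so the consistency check in GetCommInfo() can be skipped in that case. A hedged sketch of the hccl_test path this preserves (the 8-device figure is just an example):

from mindspore import context
from mindspore.communication._comm_helper import _HCCL_TEST_AVAILABLE

# _HCCL_TEST_AVAILABLE is True only when hccl_test.manage.api imports successfully;
# set_auto_parallel_context() forwards it first, so a device_num that differs from
# the real world size no longer aborts graph compilation in that environment.
context.set_auto_parallel_context(device_num=8)
print("hccl_test loaded:", _HCCL_TEST_AVAILABLE)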