forked from mindspore-Ecosystem/mindspore

Add Device Number Check

parent 5ed7040d4c
commit 741e3604eb
@@ -125,6 +125,9 @@ class ParallelContext {
   }

   bool enable_parallel_optimizer() const { return enable_parallel_optimizer_; }
+  void set_hccl_test_available(bool hccl_test_available) { hccl_test_available_ = hccl_test_available; }
+  bool hccl_test_available() const { return hccl_test_available_; }
+
   bool set_communi_parallel_mode(const std::string &communi_parallel_mode);
   std::string communi_parallel_mode() const { return communi_parallel_mode_; }
   void set_sharding_propagation(const bool);
@@ -178,6 +181,7 @@ class ParallelContext {
   bool enable_all2all_;
   std::vector<std::vector<int64_t>> dataset_strategy_;
   bool dataset_repeat_dim_right_ = false;
+  bool hccl_test_available_ = false;
 };

 }  // namespace parallel
@@ -2929,15 +2929,15 @@ CommInfo GetCommInfo() {
     device_num = UintToInt(world_rank_size);
     MS_LOG(INFO) << "Get device num from communication model, the device num is " << device_num;
   }
-#if defined(ENABLE_GPU)
-  if (ParallelContext::GetInstance()->device_num_is_set() && backend == kGPUDevice) {
-    if (world_rank_size != device_num) {
-      MS_LOG(EXCEPTION) << "The device_num " << device_num
-                        << " set in the context is not consist with the word group size " << world_rank_size;
-    }
+#if ENABLE_D || ENABLE_GPU
+  if (ParallelContext::GetInstance()->device_num_is_set() && world_rank_size != device_num &&
+      !ParallelContext::GetInstance()->hccl_test_available()) {
+    // hccl_test_available is used when we compile graphs in real ascend card environment, but with hccl_test.
+    MS_LOG(EXCEPTION) << "The device_num " << device_num << " set in the context is not consist with "
+                      << world_rank_size << " devices you have "
+                      << ". Please check your rank_table file(for Ascend) or host file(for GPU).";
   }
 #endif

   uint32_t rank_id = 0;
   if (!ParallelContext::GetInstance()->global_rank_is_set()) {
     if (!CommManager::GetInstance().GetRankID(world_group, &rank_id)) {
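The new guard in GetCommInfo() can be read as a single condition: fail hard only when a device_num was explicitly set, it disagrees with the communication world size, and the hccl_test stub is not in use. As a reading aid, here is a minimal standalone Python mirror of that logic (an illustration only, not the actual C++ implementation; the helper name and its arguments are made up for this sketch):

    def check_device_num(device_num_is_set, device_num, world_rank_size, hccl_test_available):
        # Mirrors the condition added above: all three clauses must hold for the error to fire.
        if device_num_is_set and world_rank_size != device_num and not hccl_test_available:
            raise RuntimeError(
                "The device_num {} set in the context is not consist with {} devices you have. "
                "Please check your rank_table file(for Ascend) or host file(for GPU)."
                .format(device_num, world_rank_size))

    check_device_num(True, 8, 8, False)   # passes: configured device_num matches the world size
    check_device_num(True, 8, 4, True)    # passes: hccl_test stub loaded, so the check is skipped
    check_device_num(True, 8, 4, False)   # raises: 8 configured vs 4 actual ranks on real hardware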
@@ -135,6 +135,7 @@ PYBIND11_MODULE(_c_expression, m) {
   (void)py::class_<ParallelContext, std::shared_ptr<ParallelContext>>(m, "AutoParallelContext")
     .def_static("get_instance", &ParallelContext::GetInstance, "Get auto parallel context instance.")
     .def("get_device_num", &ParallelContext::device_num, "Get device num.")
+    .def("set_hccl_test_avaible", &ParallelContext::set_hccl_test_available, "Set hccl test available.")
     .def("set_device_num", &ParallelContext::set_device_num, "Set device num.")
     .def("get_device_num_is_set", &ParallelContext::device_num_is_set, "Get device num is set.")
     .def("get_global_rank", &ParallelContext::global_rank, "Get global rank.")
@@ -19,6 +19,7 @@ from mindspore import log as logger
 from ._hccl_management import load_lib as hccl_load_lib

 _HCCL_AVAILABLE = False
+_HCCL_TEST_AVAILABLE = False
 _NCCL_AVAILABLE = False
 _MPI_AVAILABLE = False
 try:
@@ -45,6 +46,7 @@ else:
     try:
         import hccl_test.manage.api as hccl
         _HCCL_AVAILABLE = True
+        _HCCL_TEST_AVAILABLE = True
     except ImportError:
         _HCCL_AVAILABLE = False

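Taken together, the two _comm_helper.py hunks above reuse the probe-and-flag pattern already used for _HCCL_AVAILABLE: try to import the hccl_test stub and record whether it is present. A consolidated sketch of that pattern (simplified; in the real module the import sits inside a backend-specific branch):

    _HCCL_TEST_AVAILABLE = False
    try:
        # If the hccl_test stub package can be imported, graphs are being compiled
        # against the stub rather than a live HCCL world group.
        import hccl_test.manage.api as hccl
        _HCCL_TEST_AVAILABLE = True
    except ImportError:
        _HCCL_TEST_AVAILABLE = False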
@@ -72,6 +72,8 @@ class _AutoParallelContext:
         self.check_context_handle()
         if device_num < 1 or device_num > 4096:
             raise ValueError("Device num must be in [1, 4096], but got {}".format(device_num))
+        from mindspore.communication._comm_helper import _HCCL_TEST_AVAILABLE
+        self._context_handle.set_hccl_test_avaible(_HCCL_TEST_AVAILABLE)
         self._context_handle.set_device_num(device_num)

     def get_device_num(self):
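End to end: _comm_helper.py records whether the hccl_test stub could be imported, set_device_num() forwards that flag to the C++ ParallelContext through the new set_hccl_test_avaible binding, and GetCommInfo() skips the world-size check when the flag is set. A hedged usage sketch of how the stricter check surfaces to a user (assumes MindSpore with this commit, a real 4-rank launch, and no hccl_test stub; set_auto_parallel_context and init are existing MindSpore APIs):

    from mindspore import context
    from mindspore.communication import init

    init()                                           # join the 4-rank world group (HCCL or NCCL)
    context.set_auto_parallel_context(device_num=8)  # claims 8 devices
    # Compiling a graph now fails fast with roughly:
    #   The device_num 8 set in the context is not consist with 4 devices you have.
    #   Please check your rank_table file(for Ascend) or host file(for GPU).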