forked from mindspore-Ecosystem/mindspore
Add Device Number Check
parent 5ed7040d4c
commit 741e3604eb

@@ -125,6 +125,9 @@ class ParallelContext {
   }
   bool enable_parallel_optimizer() const { return enable_parallel_optimizer_; }
 
+  void set_hccl_test_available(bool hccl_test_available) { hccl_test_available_ = hccl_test_available; }
+  bool hccl_test_available() const { return hccl_test_available_; }
+
   bool set_communi_parallel_mode(const std::string &communi_parallel_mode);
   std::string communi_parallel_mode() const { return communi_parallel_mode_; }
   void set_sharding_propagation(const bool);
@@ -178,6 +181,7 @@ class ParallelContext {
   bool enable_all2all_;
   std::vector<std::vector<int64_t>> dataset_strategy_;
   bool dataset_repeat_dim_right_ = false;
+  bool hccl_test_available_ = false;
 };
 
 }  // namespace parallel
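The two header hunks above add the new state: a private hccl_test_available_ member that defaults to false, plus a setter and a const getter on ParallelContext. A minimal Python sketch of the same accessor pattern (illustrative only, not MindSpore code; the class name is invented):

class ParallelContextSketch:
    """Toy stand-in for the flag added to ParallelContext."""

    def __init__(self):
        self._hccl_test_available = False  # mirrors `bool hccl_test_available_ = false;`

    def set_hccl_test_available(self, available: bool) -> None:
        self._hccl_test_available = available

    def hccl_test_available(self) -> bool:
        return self._hccl_test_available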
@@ -2929,15 +2929,15 @@ CommInfo GetCommInfo() {
     device_num = UintToInt(world_rank_size);
     MS_LOG(INFO) << "Get device num from communication model, the device num is " << device_num;
   }
-#if defined(ENABLE_GPU)
-  if (ParallelContext::GetInstance()->device_num_is_set() && backend == kGPUDevice) {
-    if (world_rank_size != device_num) {
-      MS_LOG(EXCEPTION) << "The device_num " << device_num
-                        << " set in the context is not consist with the word group size " << world_rank_size;
-    }
+#if ENABLE_D || ENABLE_GPU
+  if (ParallelContext::GetInstance()->device_num_is_set() && world_rank_size != device_num &&
+      !ParallelContext::GetInstance()->hccl_test_available()) {
+    // hccl_test_available is used when we compile graphs in real ascend card environment, but with hccl_test.
+    MS_LOG(EXCEPTION) << "The device_num " << device_num << " set in the context is not consist with "
+                      << world_rank_size << " devices you have "
+                      << ". Please check your rank_table file(for Ascend) or host file(for GPU).";
   }
 #endif
 
   uint32_t rank_id = 0;
   if (!ParallelContext::GetInstance()->global_rank_is_set()) {
     if (!CommManager::GetInstance().GetRankID(world_group, &rank_id)) {
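The rewritten guard above now covers both Ascend (ENABLE_D) and GPU builds: if device_num was set in the context but disagrees with the size of the communication world group, graph compilation aborts with an exception, unless the context was told that only the hccl_test stub is available. A hedged Python sketch of the same condition (function and argument names are invented for illustration):

def check_device_num_consistency(device_num_is_set, device_num, world_rank_size, hccl_test_available):
    """Approximation of the new check in GetCommInfo (sketch only)."""
    if device_num_is_set and world_rank_size != device_num and not hccl_test_available:
        raise RuntimeError(
            "The device_num {} set in the context is not consistent with the {} devices you have. "
            "Please check your rank_table file (for Ascend) or host file (for GPU)."
            .format(device_num, world_rank_size))

# Example: device_num=8 configured but only 4 ranks in the world group would raise;
# with hccl_test_available=True the mismatch is tolerated.
check_device_num_consistency(True, 8, 4, True)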
@@ -135,6 +135,7 @@ PYBIND11_MODULE(_c_expression, m) {
   (void)py::class_<ParallelContext, std::shared_ptr<ParallelContext>>(m, "AutoParallelContext")
     .def_static("get_instance", &ParallelContext::GetInstance, "Get auto parallel context instance.")
     .def("get_device_num", &ParallelContext::device_num, "Get device num.")
+    .def("set_hccl_test_avaible", &ParallelContext::set_hccl_test_available, "Set hccl test available.")
     .def("set_device_num", &ParallelContext::set_device_num, "Set device num.")
     .def("get_device_num_is_set", &ParallelContext::device_num_is_set, "Get device num is set.")
     .def("get_global_rank", &ParallelContext::global_rank, "Get global rank.")
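This binding exposes the new C++ setter to Python as AutoParallelContext.set_hccl_test_avaible (the spelling follows the string used in the .def call). Assuming the compiled _c_expression module sits under the mindspore package, as in a normal build, a direct call would look roughly like this (sketch only; user code normally goes through the auto-parallel context wrapper shown further below):

from mindspore._c_expression import AutoParallelContext

ctx = AutoParallelContext.get_instance()
ctx.set_hccl_test_avaible(True)   # declare that only the hccl_test stub backend is present
ctx.set_device_num(8)             # the later device-number consistency check is then skipped
print(ctx.get_device_num())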
@@ -19,6 +19,7 @@ from mindspore import log as logger
 from ._hccl_management import load_lib as hccl_load_lib
 
 _HCCL_AVAILABLE = False
+_HCCL_TEST_AVAILABLE = False
 _NCCL_AVAILABLE = False
 _MPI_AVAILABLE = False
 try:
@@ -45,6 +46,7 @@ else:
     try:
         import hccl_test.manage.api as hccl
         _HCCL_AVAILABLE = True
+        _HCCL_TEST_AVAILABLE = True
     except ImportError:
         _HCCL_AVAILABLE = False
 
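In the communication helper, _HCCL_TEST_AVAILABLE starts as False at module level and flips to True only when the hccl_test stub package can be imported, the same try/except probing already used for _HCCL_AVAILABLE. The general import-based feature detection pattern, as a standalone sketch (only hccl_test.manage.api comes from the diff; the helper function is hypothetical):

_HCCL_TEST_AVAILABLE = False
try:
    import hccl_test.manage.api as hccl  # present only in the hccl_test simulation environment
    _HCCL_TEST_AVAILABLE = True
except ImportError:
    _HCCL_TEST_AVAILABLE = False

def hccl_test_available():
    """Hypothetical helper: report whether the simulated HCCL backend was found."""
    return _HCCL_TEST_AVAILABLE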
@@ -72,6 +72,8 @@ class _AutoParallelContext:
         self.check_context_handle()
         if device_num < 1 or device_num > 4096:
             raise ValueError("Device num must be in [1, 4096], but got {}".format(device_num))
+        from mindspore.communication._comm_helper import _HCCL_TEST_AVAILABLE
+        self._context_handle.set_hccl_test_avaible(_HCCL_TEST_AVAILABLE)
         self._context_handle.set_device_num(device_num)
 
     def get_device_num(self):
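With this last hunk, _AutoParallelContext.set_device_num validates the [1, 4096] range and then forwards the probed _HCCL_TEST_AVAILABLE flag into the C++ context via set_hccl_test_avaible before setting the device number, so the consistency check in GetCommInfo can be relaxed for hccl_test runs. From user code the chain is reached through the usual context API; a hedged usage sketch (assumes a working MindSpore install and an initialized communication backend):

import mindspore.context as context
from mindspore.communication import init

init()                                            # establish the HCCL/NCCL world group
context.set_auto_parallel_context(device_num=8)   # must match the actual world size,
                                                  # unless only the hccl_test stub is importable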