!23045 Add Device Check For GPU
Merge pull request !23045 from huangxinjing/add_device_check
This commit is contained in:
commit
f4f587e619
|
@ -3030,13 +3030,22 @@ CommInfo GetCommInfo() {
|
||||||
MS_LOG(EXCEPTION) << "Invalid communication backend: " << backend;
|
MS_LOG(EXCEPTION) << "Invalid communication backend: " << backend;
|
||||||
}
|
}
|
||||||
uint32_t world_rank_size = 0;
|
uint32_t world_rank_size = 0;
|
||||||
|
if (!CommManager::GetInstance().GetRankSize(world_group, &world_rank_size)) {
|
||||||
|
MS_LOG(EXCEPTION) << "Get rank size failed";
|
||||||
|
}
|
||||||
|
|
||||||
if (!ParallelContext::GetInstance()->device_num_is_set()) {
|
if (!ParallelContext::GetInstance()->device_num_is_set()) {
|
||||||
if (!CommManager::GetInstance().GetRankSize(world_group, &world_rank_size)) {
|
|
||||||
MS_LOG(EXCEPTION) << "Get rank size failed";
|
|
||||||
}
|
|
||||||
device_num = UintToInt(world_rank_size);
|
device_num = UintToInt(world_rank_size);
|
||||||
MS_LOG(INFO) << "Get device num from communication model, the device num is " << device_num;
|
MS_LOG(INFO) << "Get device num from communication model, the device num is " << device_num;
|
||||||
}
|
}
|
||||||
|
#if defined(ENABLE_GPU)
|
||||||
|
if (ParallelContext::GetInstance()->device_num_is_set() && backend == kGPUDevice) {
|
||||||
|
if (world_rank_size != device_num) {
|
||||||
|
MS_LOG(EXCEPTION) << "The device_num " << device_num
|
||||||
|
<< " set in the context is not consist with the word group size " << world_rank_size;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
uint32_t rank_id = 0;
|
uint32_t rank_id = 0;
|
||||||
if (!ParallelContext::GetInstance()->global_rank_is_set()) {
|
if (!ParallelContext::GetInstance()->global_rank_is_set()) {
|
||||||
|
|
Loading…
Reference in New Issue