!23045 Add Device Check For GPU

Merge pull request !23045 from huangxinjing/add_device_check
This commit is contained in:
i-robot 2021-09-09 01:27:51 +00:00 committed by Gitee
commit f4f587e619
1 changed files with 12 additions and 3 deletions

View File

@ -3030,13 +3030,22 @@ CommInfo GetCommInfo() {
MS_LOG(EXCEPTION) << "Invalid communication backend: " << backend;
}
uint32_t world_rank_size = 0;
if (!CommManager::GetInstance().GetRankSize(world_group, &world_rank_size)) {
MS_LOG(EXCEPTION) << "Get rank size failed";
}
if (!ParallelContext::GetInstance()->device_num_is_set()) {
if (!CommManager::GetInstance().GetRankSize(world_group, &world_rank_size)) {
MS_LOG(EXCEPTION) << "Get rank size failed";
}
device_num = UintToInt(world_rank_size);
MS_LOG(INFO) << "Get device num from communication model, the device num is " << device_num;
}
#if defined(ENABLE_GPU)
if (ParallelContext::GetInstance()->device_num_is_set() && backend == kGPUDevice) {
if (world_rank_size != device_num) {
MS_LOG(EXCEPTION) << "The device_num " << device_num
<< " set in the context is not consist with the word group size " << world_rank_size;
}
}
#endif
uint32_t rank_id = 0;
if (!ParallelContext::GetInstance()->global_rank_is_set()) {