!23045 Add Device Check For GPU
Merge pull request !23045 from huangxinjing/add_device_check
This commit is contained in:
commit
f4f587e619
|
@ -3030,13 +3030,22 @@ CommInfo GetCommInfo() {
|
|||
MS_LOG(EXCEPTION) << "Invalid communication backend: " << backend;
|
||||
}
|
||||
uint32_t world_rank_size = 0;
|
||||
if (!CommManager::GetInstance().GetRankSize(world_group, &world_rank_size)) {
|
||||
MS_LOG(EXCEPTION) << "Get rank size failed";
|
||||
}
|
||||
|
||||
if (!ParallelContext::GetInstance()->device_num_is_set()) {
|
||||
if (!CommManager::GetInstance().GetRankSize(world_group, &world_rank_size)) {
|
||||
MS_LOG(EXCEPTION) << "Get rank size failed";
|
||||
}
|
||||
device_num = UintToInt(world_rank_size);
|
||||
MS_LOG(INFO) << "Get device num from communication model, the device num is " << device_num;
|
||||
}
|
||||
#if defined(ENABLE_GPU)
|
||||
if (ParallelContext::GetInstance()->device_num_is_set() && backend == kGPUDevice) {
|
||||
if (world_rank_size != device_num) {
|
||||
MS_LOG(EXCEPTION) << "The device_num " << device_num
|
||||
<< " set in the context is not consist with the word group size " << world_rank_size;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
uint32_t rank_id = 0;
|
||||
if (!ParallelContext::GetInstance()->global_rank_is_set()) {
|
||||
|
|
Loading…
Reference in New Issue