fix error if hccl/nccl not inited

This commit is contained in:
huanghui 2021-09-03 10:51:42 +08:00
parent 032c8e28db
commit 8be552fc13
2 changed files with 14 additions and 0 deletions

View File

@ -47,6 +47,7 @@ class HcclAdapter {
bool InitHccl(uint32_t device_id, std::string_view rank_id, std::string_view rank_file, bool is_graph_mode);
bool InitHccl();
bool FinalizeHccl();
const bool Inited() const { return init_flag_; }
HcclResult HcclCreateGroup(const std::string &group, uint32_t rank_num, uint32_t *rank_ids) const;
HcclResult HcclDestroyGroup(const std::string &group) const;

View File

@ -227,6 +227,19 @@ uint32_t GetRank() {
auto parallel_context = parallel::ParallelContext::GetInstance();
MS_EXCEPTION_IF_NULL(parallel_context);
if (parallel_context->parallel_mode() != parallel::STAND_ALONE) {
#ifndef NO_DLIB
// Check HCCL inited.
if (!hccl::HcclAdapter::GetInstance().Inited()) {
MS_LOG(DEBUG) << "HCCL not inited, return rank_id = 0";
return rank_id;
}
#elif defined(ENABLE_GPU)
// Check NCCL inited.
if (!CollectiveInitializer::instance().collective_inited()) {
MS_LOG(DEBUG) << "NCLL not inited, return rank_id = 0";
return rank_id;
}
#endif
if (!CommManager::GetInstance().GetRankID(world_group, &rank_id)) {
MS_LOG(EXCEPTION) << "Get rank id failed.";
}