support distribution inference on ascend

This commit is contained in:
zhangxuetong 2023-02-07 20:03:14 +08:00
parent 87096fa4d7
commit c146585ce1
2 changed files with 38 additions and 0 deletions

View File

@ -102,6 +102,7 @@ void GeDeviceContext::GetGeOptions(const std::shared_ptr<MsContext> &ms_context_
}
SetDisableReuseMemoryFlag(ge_options);
SetHcclOptions(ge_options);
auto env_job_id = common::GetEnv("JOB_ID");
if (!env_job_id.empty()) {
@ -155,6 +156,42 @@ void GeDeviceContext::GetGeOptions(const std::shared_ptr<MsContext> &ms_context_
}
}
void GeDeviceContext::SetHcclOptions(std::map<std::string, std::string> *ge_options) {
auto env_table_file = common::GetEnv("RANK_TABLE_FILE");
auto env_rank_id = common::GetEnv("RANK_ID");
auto env_device_id = common::GetEnv("ASCEND_DEVICE_ID");
auto env_cluster_info = common::GetEnv("HELP_CLUSTER");
if (!(env_table_file.empty() || env_rank_id.empty()) || !(env_cluster_info.empty() || env_rank_id.empty())) {
MS_LOG(INFO) << "Initialize Ge for distribute parameter";
if (!env_table_file.empty()) {
MS_LOG(INFO) << "Use hccl, make sure hccl lib is set in OPTION_EXEC_EXTERN_PLUGIN_PATH.";
(*ge_options)["ge.exec.rankTableFile"] = env_table_file;
}
auto env_hccl_flag = common::GetEnv("HCCL_FLAG");
if (!env_hccl_flag.empty()) {
(*ge_options)["ge.exec.hcclFlag"] = env_hccl_flag;
}
(*ge_options)["ge.exec.isUseHcom"] = "1";
(*ge_options)["ge.exec.deviceId"] = env_device_id;
(*ge_options)["ge.exec.rankId"] = env_rank_id;
(*ge_options)["ge.exec.podName"] = env_rank_id;
(*ge_options)["ge.graphRunMode"] = "1";
} else {
// device id is still needed for non-distribute case
(*ge_options)["ge.exec.deviceId"] = env_device_id;
MS_LOG(INFO) << "No hccl mode. "
<< "If use hccl, make sure [RANK_TABLE_FILE,RANK_ID,DEVICE_ID,DEPLOY_MODE] all be set in ENV.";
}
auto env_deploy_mode = common::GetEnv("DEPLOY_MODE");
if (!env_deploy_mode.empty()) {
(*ge_options)["ge.exec.deployMode"] = env_deploy_mode;
} else {
(*ge_options)["ge.exec.deployMode"] = "0";
MS_LOG(WARNING) << "DEPLOY_MODE is not set in ENV. Now set to default value 0";
}
}
bool GeDeviceContext::FinalizeGe(const std::shared_ptr<MsContext> &inst_context) {
MS_EXCEPTION_IF_NULL(inst_context);
if (inst_context->get_param<uint32_t>(MS_CTX_GE_REF) == 0) {

View File

@ -39,6 +39,7 @@ class GeDeviceContext {
void InitGe(const std::shared_ptr<MsContext> &inst_context);
bool FinalizeGe(const std::shared_ptr<MsContext> &inst_context);
void GetGeOptions(const std::shared_ptr<MsContext> &inst_context, std::map<std::string, std::string> *ge_options);
void SetHcclOptions(std::map<std::string, std::string> *ge_options);
void SetDisableReuseMemoryFlag(std::map<std::string, std::string> *ge_options) const;
};
} // namespace mindspore