Selectively initialize HCCL components for graph mode and pynative mode

Signed-off-by: zhoufeng <zhoufeng54@huawei.com>
This commit is contained in:
zhoufeng 2021-08-09 10:59:25 +08:00
parent 7ab5ba3a97
commit 68ffbc7c55
5 changed files with 34 additions and 22 deletions

View File

@ -197,7 +197,8 @@ const std::vector<size_t> &HcclKernel::GetWorkspaceSizeList() const {
MS_EXCEPTION_IF_NULL(context_ptr);
bool is_task_sink = context_ptr->get_param<bool>(MS_CTX_ENABLE_TASK_SINK);
auto mode = context_ptr->get_param<int>(MS_CTX_EXECUTION_MODE);
if (!workspace_size_list_.empty() || hccl_data_type_list_.empty() || (!is_task_sink && mode == kGraphMode)) {
if (!workspace_size_list_.empty() || hccl_data_type_list_.empty() || (!is_task_sink && mode == kGraphMode) ||
mode == kPynativeMode) {
return workspace_size_list_;
}
workspace_size_list_.emplace_back(

View File

@ -873,7 +873,7 @@ bool AscendKernelRuntime::HcclInit() {
}
MS_LOG(INFO) << "MINDSPORE_HCCL_CONFIG_PATH : " << full_path << ", RANK_ID: " << rank_id_str;
bool ret = hccl::HcclAdapter::GetInstance().InitHccl(context_ptr->get_param<uint32_t>(MS_CTX_DEVICE_ID), rank_id_str,
full_path);
full_path, mode == kGraphMode);
free(full_path);
if (!ret) {
MS_LOG(ERROR) << "Hcom init failed.";

View File

@ -137,43 +137,53 @@ bool HcclAdapter::InitHccl() {
return true;
}
// NOTE(review): this span comes from a rendered diff with the +/- markers stripped, so two
// versions of the signature and of the init sequence appear back to back. Presumably the
// three-argument signature, the first log line and the unconditional
// InitKernelInfoStore/InitHcclComm/InitHcclExec sequence are the removed side -- confirm
// against the repository before treating this as compilable source.
//
// Initializes the HCCL adapter.
//   device_id, rank_id, rank_file: passed to InitKernelInfoStore (all three) and to
//     InitHcclComm (rank_id and rank_file only).
//   is_graph_mode: selects which components are initialized and is stored in is_graph_mode_,
//     which FinalizeHccl reads to pick the matching teardown calls.
// Returns true when initialization succeeds or the adapter was already initialized, and
// false as soon as one of the initialization steps reports failure.
bool HcclAdapter::InitHccl(uint32_t device_id, std::string_view rank_id, std::string_view rank_file) {
MS_LOG(INFO) << "Start init hccl adapter.";
bool HcclAdapter::InitHccl(uint32_t device_id, std::string_view rank_id, std::string_view rank_file,
bool is_graph_mode) {
MS_LOG(INFO) << "Start init hccl adapter for " << (is_graph_mode ? "graph mode." : "pynative mode.");
// init_mutex_ stays locked until the function returns, covering the init_flag_ check and update.
std::lock_guard<std::mutex> lock(init_mutex_);
if (init_flag_) {
// NOTE(review): this early return happens before is_graph_mode_ is assigned, so a second
// call with a different mode keeps the mode of the first successful call.
MS_LOG(INFO) << "Hccl has been inited, skip.";
return true;
}
is_graph_mode_ = is_graph_mode;
InitPlugin();
// NOTE(review): the failure returns below leave init_flag_ false; no call that undoes
// InitPlugin() is visible on those paths.
bool ret = InitKernelInfoStore(device_id, rank_id, rank_file);
if (!ret) {
return false;
}
ret = InitHcclComm(rank_id, rank_file);
if (!ret) {
return false;
}
ret = InitHcclExec();
if (!ret) {
return false;
// Graph mode initializes the kernel info store and then the HCCL executor.
if (is_graph_mode_) {
bool ret = InitKernelInfoStore(device_id, rank_id, rank_file);
if (!ret) {
return false;
}
ret = InitHcclExec();
if (!ret) {
return false;
}
} else {
// Any other mode initializes only the HCCL communicator.
bool ret = InitHcclComm(rank_id, rank_file);
if (!ret) {
return false;
}
}
// Set only after every step for the selected mode succeeded.
init_flag_ = true;
MS_LOG(INFO) << "Init hccl adapter success.";
return true;
}
bool HcclAdapter::FinalizeHccl() {
MS_LOG(INFO) << "Start destroy hccl adapter.";
std::lock_guard<std::mutex> lock(init_mutex_);
MS_LOG(INFO) << "Start destroy hccl adapter for " << (is_graph_mode_ ? "graph mode." : "pynative mode.");
if (!init_flag_) {
MS_LOG(INFO) << "Hccl has never been inited, skip.";
return true;
}
(void)FinalizeHcclExec();
(void)FinalizeHcclComm();
(void)FinalizeKernelInfoStore();
if (is_graph_mode_) {
(void)FinalizeHcclExec();
(void)FinalizeKernelInfoStore();
} else {
(void)FinalizeHcclComm();
}
FinalizePlugin();
init_flag_ = false;
MS_LOG(INFO) << "Destroy hccl adapter success.";

View File

@ -42,7 +42,7 @@ class HcclAdapter {
static HcclAdapter &GetInstance();
// common
bool InitHccl(uint32_t device_id, std::string_view rank_id, std::string_view rank_file);
bool InitHccl(uint32_t device_id, std::string_view rank_id, std::string_view rank_file, bool is_graph_mode);
bool InitHccl();
bool FinalizeHccl();
@ -121,6 +121,7 @@ class HcclAdapter {
std::shared_ptr<::ge::OpsKernelBuilder> ops_kernel_builder_ = nullptr;
bool init_flag_ = false;
bool is_graph_mode_ = false;
std::mutex init_mutex_;
};
} // namespace mindspore::hccl

View File

@ -23,7 +23,7 @@ HcclAdapter &HcclAdapter::GetInstance() {
return instance;
}
// Stub implementations of the HcclAdapter interface: every function ignores its (unnamed)
// arguments and returns a fixed success value (true or HCCL_SUCCESS).
// NOTE(review): this span comes from a rendered diff with the +/- markers stripped; the
// three-argument InitHccl overload is presumably the removed line and the four-argument one
// its replacement -- confirm against the repository.
bool HcclAdapter::InitHccl() { return true; }
bool HcclAdapter::InitHccl(uint32_t, std::string_view, std::string_view) { return true; }
bool HcclAdapter::InitHccl(uint32_t, std::string_view, std::string_view, bool) { return true; }
bool HcclAdapter::FinalizeHccl() { return true; }
HcclResult HcclAdapter::HcclCreateGroup(const std::string &, uint32_t, uint32_t *) const { return HCCL_SUCCESS; }
HcclResult HcclAdapter::HcclDestroyGroup(const std::string &) const { return HCCL_SUCCESS; }