forked from mindspore-Ecosystem/mindspore
optional init hccl for graph mode and pynative mode
Signed-off-by: zhoufeng <zhoufeng54@huawei.com>
This commit is contained in:
parent
7ab5ba3a97
commit
68ffbc7c55
|
@ -197,7 +197,8 @@ const std::vector<size_t> &HcclKernel::GetWorkspaceSizeList() const {
|
|||
MS_EXCEPTION_IF_NULL(context_ptr);
|
||||
bool is_task_sink = context_ptr->get_param<bool>(MS_CTX_ENABLE_TASK_SINK);
|
||||
auto mode = context_ptr->get_param<int>(MS_CTX_EXECUTION_MODE);
|
||||
if (!workspace_size_list_.empty() || hccl_data_type_list_.empty() || (!is_task_sink && mode == kGraphMode)) {
|
||||
if (!workspace_size_list_.empty() || hccl_data_type_list_.empty() || (!is_task_sink && mode == kGraphMode) ||
|
||||
mode == kPynativeMode) {
|
||||
return workspace_size_list_;
|
||||
}
|
||||
workspace_size_list_.emplace_back(
|
||||
|
|
|
@ -873,7 +873,7 @@ bool AscendKernelRuntime::HcclInit() {
|
|||
}
|
||||
MS_LOG(INFO) << "MINDSPORE_HCCL_CONFIG_PATH : " << full_path << ", RANK_ID: " << rank_id_str;
|
||||
bool ret = hccl::HcclAdapter::GetInstance().InitHccl(context_ptr->get_param<uint32_t>(MS_CTX_DEVICE_ID), rank_id_str,
|
||||
full_path);
|
||||
full_path, mode == kGraphMode);
|
||||
free(full_path);
|
||||
if (!ret) {
|
||||
MS_LOG(ERROR) << "Hcom init failed.";
|
||||
|
|
|
@ -137,43 +137,53 @@ bool HcclAdapter::InitHccl() {
|
|||
return true;
|
||||
}
|
||||
|
||||
bool HcclAdapter::InitHccl(uint32_t device_id, std::string_view rank_id, std::string_view rank_file) {
|
||||
MS_LOG(INFO) << "Start init hccl adapter.";
|
||||
bool HcclAdapter::InitHccl(uint32_t device_id, std::string_view rank_id, std::string_view rank_file,
|
||||
bool is_graph_mode) {
|
||||
MS_LOG(INFO) << "Start init hccl adapter for " << (is_graph_mode ? "graph mode." : "pynative mode.");
|
||||
std::lock_guard<std::mutex> lock(init_mutex_);
|
||||
if (init_flag_) {
|
||||
MS_LOG(INFO) << "Hccl has been inited, skip.";
|
||||
return true;
|
||||
}
|
||||
|
||||
is_graph_mode_ = is_graph_mode;
|
||||
InitPlugin();
|
||||
bool ret = InitKernelInfoStore(device_id, rank_id, rank_file);
|
||||
if (!ret) {
|
||||
return false;
|
||||
}
|
||||
ret = InitHcclComm(rank_id, rank_file);
|
||||
if (!ret) {
|
||||
return false;
|
||||
}
|
||||
ret = InitHcclExec();
|
||||
if (!ret) {
|
||||
return false;
|
||||
if (is_graph_mode_) {
|
||||
bool ret = InitKernelInfoStore(device_id, rank_id, rank_file);
|
||||
if (!ret) {
|
||||
return false;
|
||||
}
|
||||
|
||||
ret = InitHcclExec();
|
||||
if (!ret) {
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
bool ret = InitHcclComm(rank_id, rank_file);
|
||||
if (!ret) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
init_flag_ = true;
|
||||
MS_LOG(INFO) << "Init hccl adapter success.";
|
||||
return true;
|
||||
}
|
||||
|
||||
bool HcclAdapter::FinalizeHccl() {
|
||||
MS_LOG(INFO) << "Start destroy hccl adapter.";
|
||||
std::lock_guard<std::mutex> lock(init_mutex_);
|
||||
MS_LOG(INFO) << "Start destroy hccl adapter for " << (is_graph_mode_ ? "graph mode." : "pynative mode.");
|
||||
if (!init_flag_) {
|
||||
MS_LOG(INFO) << "Hccl has never been inited, skip.";
|
||||
return true;
|
||||
}
|
||||
|
||||
(void)FinalizeHcclExec();
|
||||
(void)FinalizeHcclComm();
|
||||
(void)FinalizeKernelInfoStore();
|
||||
if (is_graph_mode_) {
|
||||
(void)FinalizeHcclExec();
|
||||
(void)FinalizeKernelInfoStore();
|
||||
} else {
|
||||
(void)FinalizeHcclComm();
|
||||
}
|
||||
|
||||
FinalizePlugin();
|
||||
init_flag_ = false;
|
||||
MS_LOG(INFO) << "Destroy hccl adapter success.";
|
||||
|
|
|
@ -42,7 +42,7 @@ class HcclAdapter {
|
|||
static HcclAdapter &GetInstance();
|
||||
|
||||
// common
|
||||
bool InitHccl(uint32_t device_id, std::string_view rank_id, std::string_view rank_file);
|
||||
bool InitHccl(uint32_t device_id, std::string_view rank_id, std::string_view rank_file, bool is_graph_mode);
|
||||
bool InitHccl();
|
||||
bool FinalizeHccl();
|
||||
|
||||
|
@ -121,6 +121,7 @@ class HcclAdapter {
|
|||
std::shared_ptr<::ge::OpsKernelBuilder> ops_kernel_builder_ = nullptr;
|
||||
|
||||
bool init_flag_ = false;
|
||||
bool is_graph_mode_ = false;
|
||||
std::mutex init_mutex_;
|
||||
};
|
||||
} // namespace mindspore::hccl
|
||||
|
|
|
@ -23,7 +23,7 @@ HcclAdapter &HcclAdapter::GetInstance() {
|
|||
return instance;
|
||||
}
|
||||
bool HcclAdapter::InitHccl() { return true; }
|
||||
bool HcclAdapter::InitHccl(uint32_t, std::string_view, std::string_view) { return true; }
|
||||
bool HcclAdapter::InitHccl(uint32_t, std::string_view, std::string_view, bool) { return true; }
|
||||
bool HcclAdapter::FinalizeHccl() { return true; }
|
||||
HcclResult HcclAdapter::HcclCreateGroup(const std::string &, uint32_t, uint32_t *) const { return HCCL_SUCCESS; }
|
||||
HcclResult HcclAdapter::HcclDestroyGroup(const std::string &) const { return HCCL_SUCCESS; }
|
||||
|
|
Loading…
Reference in New Issue