From a76d58b52d47d303ee12b2c4b68e7337fe474b7e Mon Sep 17 00:00:00 2001 From: zhoufeng Date: Mon, 10 May 2021 21:48:41 +0800 Subject: [PATCH] hccl decouple Signed-off-by: zhoufeng --- cmake/package.cmake | 7 + mindspore/ccsrc/CMakeLists.txt | 17 +- .../kernel_compiler/hccl/hccl_context.cc | 80 ---- .../kernel_compiler/hccl/hccl_context.h | 47 --- .../kernel_compiler/hccl/hccl_kernel.cc | 14 +- .../hccl/hcom_all_broadcast.cc | 7 +- .../kernel_compiler/hccl/hcom_all_reduce.cc | 8 +- mindspore/ccsrc/cxx_api/CMakeLists.txt | 5 - .../runtime/device/ascend/ascend_bucket.cc | 6 +- .../device/ascend/ascend_kernel_runtime.cc | 30 +- .../device/ascend/ascend_kernel_runtime.h | 1 - .../ascend/executor/hccl_dynamic_kernel.cc | 84 +---- .../ascend/executor/hccl_dynamic_kernel.h | 21 -- .../ccsrc/runtime/hccl_adapter/CMakeLists.txt | 3 +- .../runtime/hccl_adapter/hccl_adapter.cc | 350 ++++++++++++++---- .../ccsrc/runtime/hccl_adapter/hccl_adapter.h | 84 ++++- .../runtime/hccl_adapter/hcom_graph_adaptor.h | 32 -- .../hccl_adapter/plugin/CMakeLists.txt | 41 ++ .../hccl_adapter/plugin/hccl_plugin.cc | 90 +++++ .../runtime/hccl_adapter/plugin/hccl_plugin.h | 63 ++++ mindspore/ccsrc/utils/comm_manager.cc | 11 +- .../stub/dynamic_shape/dynamic_shape_stub.cc | 3 - tests/ut/cpp/stub/ge/ge_task_launch_stub.cc | 27 +- 23 files changed, 620 insertions(+), 411 deletions(-) delete mode 100644 mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_context.cc delete mode 100644 mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_context.h delete mode 100644 mindspore/ccsrc/runtime/hccl_adapter/hcom_graph_adaptor.h create mode 100644 mindspore/ccsrc/runtime/hccl_adapter/plugin/CMakeLists.txt create mode 100644 mindspore/ccsrc/runtime/hccl_adapter/plugin/hccl_plugin.cc create mode 100644 mindspore/ccsrc/runtime/hccl_adapter/plugin/hccl_plugin.h diff --git a/cmake/package.cmake b/cmake/package.cmake index eb9291a5de2..99f6f7bd197 100644 --- a/cmake/package.cmake +++ b/cmake/package.cmake @@ -248,6 +248,13 @@ if(NOT ENABLE_GE) DESTINATION ${INSTALL_LIB_DIR} COMPONENT mindspore ) + + install( + TARGETS hccl_plugin + DESTINATION ${INSTALL_LIB_DIR} + COMPONENT mindspore + ) + install( FILES ${CMAKE_BINARY_DIR}/graphengine/metadef/graph/libgraph.so diff --git a/mindspore/ccsrc/CMakeLists.txt b/mindspore/ccsrc/CMakeLists.txt index 45c83797998..251af24452d 100644 --- a/mindspore/ccsrc/CMakeLists.txt +++ b/mindspore/ccsrc/CMakeLists.txt @@ -293,8 +293,6 @@ if(ENABLE_D) endif() MESSAGE("USE DAV LIB PATH: ${ASCEND_PATH}") - find_library(HCCL hccl ${ASCEND_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH}) - find_library(CCE_LIB cce ${ASCEND_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH}) find_library(RUNTIME_LIB runtime ${ASCEND_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH}) find_library(TSDCLIENT tsdclient HINTS ${ASCEND_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH} ${ASCEND_DRIVER_BACK_PATH}) @@ -302,14 +300,7 @@ if(ENABLE_D) ${ASCEND_DRIVER_BACK_PATH}) find_library(PROFILING msprofiler_fwkacl ${ASCEND_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH}) find_library(ACL ascendcl ${ASCEND_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH}) - find_library(REGISTER register ${ASCEND_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH}) - find_library(PLATFORM platform ${ASCEND_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH}) find_library(OPTILING optiling ${ASCEND_OPP_PATH} ${ASCEND_TOOLKIT_OPP_PATH}) - # hccl_adpter - find_library(HCCL_ADPTER hcom_graph_adaptor ${ASCEND_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH}) - find_library(HCCL_RA ra ${ASCEND_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH}) - find_library(HCCL_BUILDER hcom_opskernel_builder ${ASCEND_RUNTIME_PATH}/plugin/opskernel - ${ASCEND_TOOLKIT_RUNTIME_PATH}/plugin/opskernel) add_library(ms_profile SHARED ${CMAKE_CURRENT_SOURCE_DIR}/runtime/device/ascend/profiling/profiling_callback_register.cc) @@ -317,9 +308,7 @@ if(ENABLE_D) target_link_options(ms_profile PRIVATE -Wl,-init,common_log_init) target_link_libraries(ms_profile -Wl,--start-group -Wl,--whole-archive ${PROFILING} -Wl,--no-whole-archive mindspore::protobuf -Wl,--end-group) - target_link_libraries(mindspore ${CCE_LIB} ${RUNTIME_LIB} ${TSDCLIENT} ${HCCL} ${DATATRANSFER} - ${HCCL_ADPTER} ${REGISTER} -Wl,--no-as-needed ${OPTILING} ${HCCL_BUILDER} - ${HCCL_RA} ${PLATFORM} ${ACL}) + target_link_libraries(mindspore ${RUNTIME_LIB} ${TSDCLIENT} ${DATATRANSFER} -Wl,--no-as-needed ${OPTILING} ${ACL}) target_link_libraries(mindspore -Wl,--start-group proto_input mindspore::protobuf -Wl,--end-group) elseif(CMAKE_SYSTEM_NAME MATCHES "Windows") target_link_libraries(mindspore -Wl,--start-group proto_input mindspore::protobuf mindspore::sentencepiece @@ -353,10 +342,6 @@ if(ENABLE_D) set(MINDSPORE_RPATH ${MINDSPORE_RPATH}:/usr/local/Ascend/nnae/latest/fwkacllib/lib64) set(MINDSPORE_RPATH ${MINDSPORE_RPATH}:/usr/local/Ascend/ascend-toolkit/latest/fwkacllib/lib64) set(MINDSPORE_RPATH ${MINDSPORE_RPATH}:/usr/local/Ascend/fwkacllib/lib64) - set(MINDSPORE_RPATH ${MINDSPORE_RPATH}:/usr/local/Ascend/nnae/latest/fwkacllib/lib64/plugin/opskernel) - set(MINDSPORE_RPATH ${MINDSPORE_RPATH}:/usr/local/Ascend/ascend-toolkit/latest/fwkacllib/lib64/plugin/opskernel) - set(MINDSPORE_RPATH ${MINDSPORE_RPATH}:/usr/local/Ascend/fwkacllib/lib64/plugin/opskernel) - set(MINDSPORE_RPATH ${MINDSPORE_RPATH}:/usr/local/Ascend/add-ons) set(MINDSPORE_RPATH ${MINDSPORE_RPATH}:/usr/local/Ascend/opp/op_impl/built-in/ai_core/tbe/op_tiling) set(MINDSPORE_RPATH ${MINDSPORE_RPATH}:/usr/local/Ascend/nnae/latest/opp/op_impl/built-in/ai_core/tbe/op_tiling) set(MINDSPORE_RPATH diff --git a/mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_context.cc b/mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_context.cc deleted file mode 100644 index aa4070db0ed..00000000000 --- a/mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_context.cc +++ /dev/null @@ -1,80 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "backend/kernel_compiler/hccl/hccl_context.h" -#include "utils/log_adapter.h" -#include "hccl/hccl.h" - -constexpr auto kHcclConfigFile = "MINDSPORE_HCCL_CONFIG_PATH"; -constexpr auto kHcclConfigFileOld = "RANK_TABLE_FILE"; - -namespace mindspore { -namespace kernel { -int GetRankId() { - auto rank_id_env = std::getenv("RANK_ID"); - if (rank_id_env == nullptr) { - MS_LOG(EXCEPTION) << "No RANK_ID, please export RANK_ID"; - } - try { - return std::stoi(rank_id_env); - } catch (std::invalid_argument &e) { - MS_LOG(EXCEPTION) << "Invalid rankd id env:" << rank_id_env; - } -} - -bool HcclContext::InitHccl() { - if (hccl_comm_ != nullptr) { - return true; - } - auto config_file = std::getenv(kHcclConfigFile); - if (config_file == nullptr) { - config_file = std::getenv(kHcclConfigFileOld); - if (config_file == nullptr) { - MS_LOG(ERROR) << "Get hccl rank table file failed. Please export MINDSPORE_HCCL_CONFIG_PATH or RANK_TABLE_FILE"; - return false; - } - } - - rank_id_ = GetRankId(); - if (rank_id_ < 0 || rank_id_ > 7) { - MS_LOG(ERROR) << "rank_id needs to be between 0-7"; - return false; - } - - auto hccl_result = HcclCommInitClusterInfo(config_file, rank_id_, &hccl_comm_); - if (hccl_result != HCCL_SUCCESS) { - MS_LOG(ERROR) << "HcclCommInitClusterInfo failed, ret:" << hccl_result; - return false; - } - MS_LOG(INFO) << "HcclCommInitClusterInfo success"; - return true; -} - -bool HcclContext::Finalize() { - if (hccl_comm_ == nullptr) { - return true; - } - auto hccl_result = HcclCommDestroy(hccl_comm_); - if (hccl_result != HCCL_SUCCESS) { - MS_LOG(ERROR) << "HcclComm destroy failed, ret:" << hccl_result; - return false; - } - MS_LOG(INFO) << "HcclComm destroy success"; - hccl_comm_ = nullptr; - return true; -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_context.h b/mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_context.h deleted file mode 100644 index 9ca54e7e5ab..00000000000 --- a/mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_context.h +++ /dev/null @@ -1,47 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_HCCL_HCCL_CONTEXT_H_ -#define MINDSPORE_MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_HCCL_HCCL_CONTEXT_H_ - -#include -#include "hccl/hccl_types.h" -#include "utils/ms_utils.h" - -namespace mindspore { -namespace kernel { -class HcclContext { - public: - static HcclContext &GetInstance() { - static HcclContext instance; - return instance; - } - - bool InitHccl(); - bool Finalize(); - HcclComm hccl_comm() { return hccl_comm_; } - - private: - HcclContext() = default; - ~HcclContext() = default; - DISABLE_COPY_AND_ASSIGN(HcclContext); - HcclComm hccl_comm_{nullptr}; - int rank_id_{0}; - uint32_t device_id_{0}; -}; -} // namespace kernel -} // namespace mindspore -#endif // MINDSPORE_MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_HCCL_HCCL_CONTEXT_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_kernel.cc index 89abbeb043f..d40dbcb400e 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_kernel.cc @@ -188,7 +188,8 @@ const std::vector &HcclKernel::GetWorkspaceSizeList() const { return workspace_size_list_; } - workspace_size_list_.emplace_back(hccl::CalcWorkspaceSize(anf_node_.lock(), hccl_data_type_list_[0])); + workspace_size_list_.emplace_back( + hccl::HcclAdapter::GetInstance().CalcWorkspaceSize(anf_node_.lock(), hccl_data_type_list_[0])); return workspace_size_list_; } @@ -218,7 +219,7 @@ std::vector HcclKernel::GenTask(const std::vector &inpu std::vector private_def; HcclDataType data_type = hccl_data_type_list_[0]; std::vector task_info; - bool ret = hccl::GenTask(anf_node, data_type, &task_info); + bool ret = hccl::HcclAdapter::GetInstance().GenTask(anf_node, data_type, &task_info); if (!ret) { MS_LOG(EXCEPTION) << "Gen Task for " << anf_node->DebugString() << " failed."; } @@ -245,10 +246,11 @@ std::vector HcclKernel::GenTask(const std::vector &inpu workspace_addr = workspace.at(0)->addr; } - results.emplace_back(std::make_shared( - kernel_name_, stream_id, hccl::GetHcclType(anf_node), input_data_addr, output_data_addr, workspace_addr, - task.workspace_size, task.stream_num, private_def, hccl::GetHcclOpsKernelInfoStore(), hccl_count_, root_id_, - op_type_, data_type, group_, NeedDump())); + results.emplace_back( + std::make_shared(kernel_name_, stream_id, hccl::HcclAdapter::GetHcclType(anf_node), input_data_addr, + output_data_addr, workspace_addr, task.workspace_size, task.stream_num, + private_def, hccl::HcclAdapter::GetInstance().GetHcclOpsKernelInfoStore(), + hccl_count_, root_id_, op_type_, data_type, group_, NeedDump())); } return results; diff --git a/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_all_broadcast.cc b/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_all_broadcast.cc index ec5073194c0..a2624cf6463 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_all_broadcast.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_all_broadcast.cc @@ -17,8 +17,7 @@ #include "backend/kernel_compiler/hccl/hcom_all_broadcast.h" #include #include "utils/ms_context.h" -#include "backend/kernel_compiler/hccl/hccl_context.h" -#include "external/hccl/hccl.h" +#include "runtime/hccl_adapter/hccl_adapter.h" namespace mindspore { namespace kernel { @@ -31,8 +30,8 @@ bool HcomAllBroadCastKernel::Launch(const std::vector &inputs, } MS_EXCEPTION_IF_NULL(inputs[0]); MS_EXCEPTION_IF_NULL(stream_ptr); - auto hccl_result = HcclBroadcast(inputs[0]->addr, hccl_count_, hccl_data_type_list_[0], root_id_, - HcclContext::GetInstance().hccl_comm(), stream_ptr); + auto hccl_result = hccl::HcclAdapter::GetInstance().HcclBroadcast(inputs[0]->addr, hccl_count_, + hccl_data_type_list_[0], root_id_, stream_ptr); if (hccl_result != HCCL_SUCCESS) { MS_LOG(ERROR) << "HcomBroadcastOp : hcom_broadcast fail, return: " << hccl_result; return false; diff --git a/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_all_reduce.cc b/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_all_reduce.cc index a2fe9c98783..4480b556bfa 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_all_reduce.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_all_reduce.cc @@ -16,9 +16,7 @@ #include "backend/kernel_compiler/hccl/hcom_all_reduce.h" #include -#include "utils/ms_context.h" -#include "backend/kernel_compiler/hccl/hccl_context.h" -#include "external/hccl/hccl.h" +#include "runtime/hccl_adapter/hccl_adapter.h" namespace mindspore { namespace kernel { @@ -32,8 +30,8 @@ bool HcomAllReduceKernel::Launch(const std::vector &inputs, const st MS_EXCEPTION_IF_NULL(inputs[0]); MS_EXCEPTION_IF_NULL(outputs[0]); MS_EXCEPTION_IF_NULL(stream_ptr); - auto hccl_result = HcclAllReduce(inputs[0]->addr, outputs[0]->addr, hccl_count_, hccl_data_type_list_[0], op_type_, - HcclContext::GetInstance().hccl_comm(), stream_ptr); + auto hccl_result = hccl::HcclAdapter::GetInstance().HcclAllReduce(inputs[0]->addr, outputs[0]->addr, hccl_count_, + hccl_data_type_list_[0], op_type_, stream_ptr); if (hccl_result != HCCL_SUCCESS) { MS_LOG(ERROR) << "HcclAllReduce faled, ret:" << hccl_result; return false; diff --git a/mindspore/ccsrc/cxx_api/CMakeLists.txt b/mindspore/ccsrc/cxx_api/CMakeLists.txt index dc5dc309925..5de27cf39ae 100644 --- a/mindspore/ccsrc/cxx_api/CMakeLists.txt +++ b/mindspore/ccsrc/cxx_api/CMakeLists.txt @@ -126,10 +126,6 @@ if(CMAKE_SYSTEM_NAME MATCHES "Linux") set(MINDSPORE_RPATH ${MINDSPORE_RPATH}:/usr/local/Ascend/nnae/latest/fwkacllib/lib64) set(MINDSPORE_RPATH ${MINDSPORE_RPATH}:/usr/local/Ascend/ascend-toolkit/latest/fwkacllib/lib64) set(MINDSPORE_RPATH ${MINDSPORE_RPATH}:/usr/local/Ascend/fwkacllib/lib64) - set(MINDSPORE_RPATH ${MINDSPORE_RPATH}:/usr/local/Ascend/nnae/latest/fwkacllib/lib64/plugin/opskernel) - set(MINDSPORE_RPATH ${MINDSPORE_RPATH}:/usr/local/Ascend/ascend-toolkit/latest/fwkacllib/lib64/plugin/opskernel) - set(MINDSPORE_RPATH ${MINDSPORE_RPATH}:/usr/local/Ascend/fwkacllib/lib64/plugin/opskernel) - set(MINDSPORE_RPATH ${MINDSPORE_RPATH}:/usr/local/Ascend/add-ons) set(MINDSPORE_RPATH ${MINDSPORE_RPATH}:/usr/local/Ascend/opp/op_impl/built-in/ai_core/tbe/op_tiling) set(MINDSPORE_RPATH ${MINDSPORE_RPATH}:/usr/local/Ascend/nnae/latest/opp/op_impl/built-in/ai_core/tbe/op_tiling) set(MINDSPORE_RPATH @@ -141,7 +137,6 @@ if(CMAKE_SYSTEM_NAME MATCHES "Linux") set(MINDSPORE_RPATH ${MINDSPORE_RPATH}:/usr/local/Ascend/nnae/latest/acllib/lib64) set(MINDSPORE_RPATH ${MINDSPORE_RPATH}:/usr/local/Ascend/ascend-toolkit/latest/acllib/lib64) set(MINDSPORE_RPATH ${MINDSPORE_RPATH}:/usr/local/Ascend/acllib/lib64) - set(MINDSPORE_RPATH ${MINDSPORE_RPATH}:/usr/local/Ascend/add-ons) set(MINDSPORE_RPATH ${MINDSPORE_RPATH}:/usr/local/Ascend/opp/op_impl/built-in/ai_core/tbe/op_tiling) set(MINDSPORE_RPATH ${MINDSPORE_RPATH}:/usr/local/Ascend/nnae/latest/opp/op_impl/built-in/ai_core/tbe/op_tiling) set(MINDSPORE_RPATH diff --git a/mindspore/ccsrc/runtime/device/ascend/ascend_bucket.cc b/mindspore/ccsrc/runtime/device/ascend/ascend_bucket.cc index 2a8bb0f5fc8..f2060816f9d 100644 --- a/mindspore/ccsrc/runtime/device/ascend/ascend_bucket.cc +++ b/mindspore/ccsrc/runtime/device/ascend/ascend_bucket.cc @@ -22,12 +22,12 @@ #include "external/hccl/hccl.h" #include "runtime/device/ascend/ascend_memory_pool.h" #include "backend/kernel_compiler/hccl/hcom_util.h" -#include "backend/kernel_compiler/hccl/hccl_context.h" #include "runtime/device/memory_manager.h" #include "runtime/device/kernel_runtime_manager.h" #include "runtime/device/ascend/ascend_event.h" #include "runtime/device/ascend/ascend_launch_mul.h" #include "runtime/device/ascend/ascend_launch_atomic_clean.h" +#include "runtime/hccl_adapter/hccl_adapter.h" #include "utils/profile.h" #define CHECK_ASCEND_RT_WITH_EXCEPTION(expression, message) \ @@ -147,8 +147,8 @@ void AscendBucket::LaunchAllReduce() { auto hccl_count = total_size_ / type_size; HcclReduceOp op_type = HcclReduceOp::HCCL_REDUCE_SUM; - auto hccl_result = HcclAllReduce(ar_input_addr_, ar_output_addr_, hccl_count, iter->second, op_type, - kernel::HcclContext::GetInstance().hccl_comm(), stream_); + auto hccl_result = hccl::HcclAdapter::GetInstance().HcclAllReduce(ar_input_addr_, ar_output_addr_, hccl_count, + iter->second, op_type, stream_); if (hccl_result != HCCL_SUCCESS) { MS_LOG(EXCEPTION) << "HcclAllReduce faled, ret:" << hccl_result; } diff --git a/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.cc b/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.cc index d1d1fb98108..690638762f8 100644 --- a/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.cc +++ b/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.cc @@ -51,7 +51,6 @@ #include "runtime/device/ascend/profiling/reporter/op_name_task_stream_reporter.h" #include "runtime/hccl_adapter/hccl_adapter.h" #include "runtime/device/ascend/profiling/profiling_callback_register.h" -#include "backend/kernel_compiler/hccl/hccl_context.h" #ifdef ENABLE_TDTQUE #include "minddata/dataset/engine/tdt/tdt_handle.h" using mindspore::dataset::TdtHandle; @@ -260,7 +259,6 @@ void AscendKernelRuntime::ReleaseDeviceRes() { MS_LOG(EXCEPTION) << "Reg SetTaskFailCallback failed, error: " << rt_ret; } - (void)DestroySingleOpHccl(); (void)DestroyHccl(); (void)ResetDevice(device_id); (void)ProfilingManager::GetInstance().StopProfiling(); @@ -340,11 +338,7 @@ bool AscendKernelRuntime::Load(session::KernelGraph *graph, bool is_task_sink) { GenKernelEvents(graph); return true; } - // Do HcomExecutorInitialize - if (graph->is_dynamic_shape() && !HcclExecutorManager::GetInstance().Initialize()) { - MS_LOG(ERROR) << "Init Hccl Executor Failed"; - return false; - } + if (!GenTask(graph)) { return false; } @@ -837,25 +831,13 @@ bool AscendKernelRuntime::HcclInit() { return false; } MS_LOG(INFO) << "MINDSPORE_HCCL_CONFIG_PATH : " << full_path << ", RANK_ID: " << rank_id_str; - bool ret = hccl::InitHccl(context_ptr->get_param(MS_CTX_DEVICE_ID), rank_id_str, full_path); + bool ret = hccl::HcclAdapter::GetInstance().InitHccl(context_ptr->get_param(MS_CTX_DEVICE_ID), rank_id_str, + full_path); free(full_path); if (!ret) { MS_LOG(ERROR) << "Hcom init failed."; return false; } - auto task_sink = context_ptr->get_param(MS_CTX_ENABLE_TASK_SINK); - if (context_ptr->get_param(MS_CTX_EXECUTION_MODE) == kPynativeMode || !task_sink) { - MS_LOG(INFO) << "Hccl comm init."; - return kernel::HcclContext::GetInstance().InitHccl(); - } - return true; -} - -bool AscendKernelRuntime::DestroySingleOpHccl() { - if (!kernel::HcclContext::GetInstance().Finalize()) { - MS_LOG(ERROR) << "Hccl finalize failed"; - return false; - } return true; } @@ -869,11 +851,7 @@ bool AscendKernelRuntime::DestroyHccl() { MS_LOG(INFO) << "Hccl is not enable, no need to close."; return true; } - // Dynamic Shape Hccl Finalize - if (!HcclExecutorManager::GetInstance().Finalize()) { - MS_LOG(ERROR) << "Dynamic Shape Hccl Finalize Failed"; - } - bool res = hccl::FinalizeHccl(); + bool res = hccl::HcclAdapter::GetInstance().FinalizeHccl(); if (!res) { MS_LOG(ERROR) << "Hccl destroy failed"; return false; diff --git a/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.h b/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.h index 7d18f7efa54..3f3b09c142f 100644 --- a/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.h +++ b/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.h @@ -76,7 +76,6 @@ class AscendKernelRuntime : public KernelRuntime { static bool HcclInit(); static bool NeedDestroyHccl(); static bool DestroyHccl(); - static bool DestroySingleOpHccl(); void SetCurrentContext(); void ClearGraphModelMap(); diff --git a/mindspore/ccsrc/runtime/device/ascend/executor/hccl_dynamic_kernel.cc b/mindspore/ccsrc/runtime/device/ascend/executor/hccl_dynamic_kernel.cc index c145af90590..c960552914a 100644 --- a/mindspore/ccsrc/runtime/device/ascend/executor/hccl_dynamic_kernel.cc +++ b/mindspore/ccsrc/runtime/device/ascend/executor/hccl_dynamic_kernel.cc @@ -23,6 +23,7 @@ #include "utils/log_adapter.h" #include "runtime/device/kernel_runtime.h" #include "backend/kernel_compiler/hccl/hcom_util.h" +#include "runtime/hccl_adapter/hccl_adapter.h" namespace { // Find so in RPATH or LD_LIBRARY_PATH (/usr/local/Ascend/fwkacllib/lib64/) @@ -90,23 +91,12 @@ void HcclDynamicKernel::StaticShapeExecute() { void HcclDynamicKernel::Execute() { MS_LOG(INFO) << "Start Execute"; - - auto EnqueueHcomOperation = - (HcclResult(*)(ge::HcomOpertion, std::function))HcclExecutorManager::GetInstance() - .GetHcomOpertion(); - if (EnqueueHcomOperation == nullptr) { - MS_LOG(ERROR) << "Failed to get EnqueueHcomOperation function"; - HcclExecutorManager::GetInstance().CloseHandle(); - MS_LOG(EXCEPTION) << "Hccl dynamic kernel execute failed"; - return; - } - - ge::HcomOpertion op_info; + ::HcomOperation op_info; op_info.hcclType = hccl_type_; op_info.inputPtr = input_ptr_; op_info.outputPtr = output_ptr_; - op_info.dataType = data_type_; - op_info.opType = op_type_; + op_info.dataType = static_cast(data_type_); + op_info.opType = static_cast(op_type_); op_info.root = root_; op_info.count = count_; @@ -119,7 +109,7 @@ void HcclDynamicKernel::Execute() { MS_LOG(INFO) << "hccl callback success."; }; - auto hccl_ret = EnqueueHcomOperation(op_info, callback); + auto hccl_ret = hccl::HcclAdapter::GetInstance().HcclExecEnqueueOp(op_info, callback); if (hccl_ret != HCCL_SUCCESS) { MS_LOG(EXCEPTION) << "Call EnqueueHcomOperation failed"; } @@ -130,70 +120,6 @@ void HcclDynamicKernel::Execute() { } void HcclDynamicKernel::PostExecute() {} - -bool HcclExecutorManager::Initialize() { - if (initialized_) { - return true; - } - auto context = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(context); - if (!context->get_param(MS_CTX_ENABLE_HCCL)) { - return true; - } - initialized_ = true; - MS_LOG(INFO) << "Start Initialize Hccl DynamicKernel"; - handle_ = dlopen(kHcomGraphAdaptorPath, RTLD_NOW | RTLD_GLOBAL); - if (handle_ == nullptr) { - MS_LOG(ERROR) << "dlopen failed, path:" << kHcomGraphAdaptorPath; - return false; - } - - auto HcomExecutorInitialize = (HcclResult(*)())dlsym(handle_, "HcomExecInitialize"); - if (HcomExecutorInitialize == nullptr) { - MS_LOG(ERROR) << "dlsym HcomExecutorInitialize failed"; - return false; - } - - HcclResult hccl_ret = HcomExecutorInitialize(); - if (hccl_ret == HCCL_E_PTR) { - MS_LOG(WARNING) << "Hccl comm is null, hcom executor initialize is not required"; - } else if (hccl_ret == HCCL_SUCCESS) { - MS_LOG(INFO) << "Hcom DynamicKernel Initialize success"; - } else { - MS_LOG(ERROR) << "Hcom DynamicKernel Initialize failed"; - return false; - } - return true; -} - -bool HcclExecutorManager::Finalize() { - if (!initialized_) { - return true; - } - auto HcomExecutorFinalize = (HcclResult(*)())dlsym(handle_, "HcomExecFinalize"); - if (HcomExecutorFinalize == nullptr) { - MS_LOG(ERROR) << "Fail to dlsym HcomExecutorFinalize"; - return false; - } - HcclResult hccl_ret = HcomExecutorFinalize(); - if (hccl_ret != HCCL_SUCCESS) { - MS_LOG(ERROR) << "Hcom DynamicKernel Finalize failed"; - return false; - } - if (dlclose(handle_) != 0) { - MS_LOG(ERROR) << "Failed to close hcom handle"; - return false; - } - MS_LOG(INFO) << "Hccl DynamicKernel Finalize success"; - return true; -} - -void *HcclExecutorManager::GetHcomOpertion() { return dlsym(handle_, "HcomExecEnqueueOperation"); } -void HcclExecutorManager::CloseHandle() { - if (dlclose(handle_) != 0) { - MS_LOG(WARNING) << "Failed to close hcom handle"; - } -} } // namespace ascend } // namespace device } // namespace mindspore diff --git a/mindspore/ccsrc/runtime/device/ascend/executor/hccl_dynamic_kernel.h b/mindspore/ccsrc/runtime/device/ascend/executor/hccl_dynamic_kernel.h index cebbdb0472e..af9ae12ddc5 100644 --- a/mindspore/ccsrc/runtime/device/ascend/executor/hccl_dynamic_kernel.h +++ b/mindspore/ccsrc/runtime/device/ascend/executor/hccl_dynamic_kernel.h @@ -56,27 +56,6 @@ class HcclDynamicKernel : public DynamicKernel { void StaticShapeExecute(); }; - -class HcclExecutorManager { - public: - static HcclExecutorManager &GetInstance() { - static HcclExecutorManager instance; - return instance; - } - - bool Initialize(); - bool Finalize(); - void *GetHcomOpertion(); - void CloseHandle(); - - private: - HcclExecutorManager() = default; - ~HcclExecutorManager() = default; - DISABLE_COPY_AND_ASSIGN(HcclExecutorManager); - - void *handle_{nullptr}; - bool initialized_{false}; -}; } // namespace ascend } // namespace device } // namespace mindspore diff --git a/mindspore/ccsrc/runtime/hccl_adapter/CMakeLists.txt b/mindspore/ccsrc/runtime/hccl_adapter/CMakeLists.txt index 8f7f421ed43..e42585b745c 100644 --- a/mindspore/ccsrc/runtime/hccl_adapter/CMakeLists.txt +++ b/mindspore/ccsrc/runtime/hccl_adapter/CMakeLists.txt @@ -1,8 +1,9 @@ -file(GLOB_RECURSE HCCL_ADAPTER_SRC_LIST ./*.cc) +file(GLOB HCCL_ADAPTER_SRC_LIST ./*.cc) set_property(SOURCE ${HCCL_ADAPTER_SRC_LIST} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_HCCL_ADPT) if(ENABLE_D) add_library(_mindspore_runtime_hccl_adapter_obj OBJECT ${HCCL_ADAPTER_SRC_LIST}) target_include_directories(_mindspore_runtime_hccl_adapter_obj PRIVATE ${CMAKE_BINARY_DIR}/proto/ge) add_dependencies(_mindspore_runtime_hccl_adapter_obj graph) + add_subdirectory(plugin) endif() \ No newline at end of file diff --git a/mindspore/ccsrc/runtime/hccl_adapter/hccl_adapter.cc b/mindspore/ccsrc/runtime/hccl_adapter/hccl_adapter.cc index a1783dea859..9dbf742ed17 100644 --- a/mindspore/ccsrc/runtime/hccl_adapter/hccl_adapter.cc +++ b/mindspore/ccsrc/runtime/hccl_adapter/hccl_adapter.cc @@ -14,29 +14,44 @@ * limitations under the License. */ #include "runtime/hccl_adapter/hccl_adapter.h" +#include #include #include #define google ascend_private -#include "register/ops_kernel_builder_registry.h" #include "common/opskernel/ops_kernel_info_store.h" +#include "common/opskernel/ops_kernel_builder.h" #include "external/ge/ge_api_types.h" #undef google +#include "hccl/hccl.h" +#include "hccl/hcom.h" #include "utils/log_adapter.h" #include "utils/ms_utils.h" #include "runtime/hccl_adapter/converter.h" -#include "runtime/hccl_adapter/hcom_graph_adaptor.h" +static constexpr const char *kHcclPluginFileName = "libhccl_plugin.so"; static constexpr const char *kHcclOpsKernelInfoStore = "ops_kernel_info_hccl"; static constexpr const char *kHcclDeployModeEnv = "DEPLOY_MODE"; static constexpr const char *kHcclAlgoEnv = "HCCL_ALGO"; -// following global var, thread safety is not guaranteed -static std::shared_ptr ops_kernel_info_store = nullptr; -static ge::OpsKernelBuilderPtr ops_kernel_builder = nullptr; -namespace mindspore::hccl { +inline static std::string GetDlErrorMsg() { + const char *result = dlerror(); + return (result == nullptr) ? "Unknown" : result; +} + +template +static T DlsymWithCast(void *handle, const char *symbol_name) { + T symbol = reinterpret_cast(dlsym(handle, symbol_name)); + if (symbol == nullptr) { + MS_LOG(EXCEPTION) << "Dlsym symbol " << symbol_name << " failed, result = " << GetDlErrorMsg(); + } + return symbol; +} + +#define DlsymFuncObj(handle, func_name) DlsymWithCast(handle, k##func_name##Name); + static std::map GenHcclOptions(uint32_t device_id, std::string_view rank_id, std::string_view rank_file) { - auto env_deploy_mode = common::GetEnv(kHcclDeployModeEnv); + auto env_deploy_mode = mindspore::common::GetEnv(kHcclDeployModeEnv); if (env_deploy_mode.empty()) { MS_LOG(WARNING) << kHcclDeployModeEnv << " is not set in ENV. Now set to default value 0"; env_deploy_mode = "0"; @@ -53,7 +68,7 @@ static std::map GenHcclOptions(uint32_t device_id, std {ge::OPTION_EXEC_HCCL_FLAG, "1"}, {ge::OPTION_EXEC_DEPLOY_MODE, env_deploy_mode}}; - auto env_hccl_algo = common::GetEnv(kHcclAlgoEnv); + auto env_hccl_algo = mindspore::common::GetEnv(kHcclAlgoEnv); if (!env_hccl_algo.empty()) { std::string ge_hccl_algo = "HCCL_algorithm"; default_options_map.emplace(ge_hccl_algo, env_hccl_algo); @@ -62,74 +77,111 @@ static std::map GenHcclOptions(uint32_t device_id, std return default_options_map; } -bool InitHccl(uint32_t device_id, std::string_view rank_id, std::string_view rank_file) { +namespace mindspore::hccl { +HcclAdapter &HcclAdapter::GetInstance() { + static HcclAdapter instance; + return instance; +} + +void HcclAdapter::InitPlugin() { + if (plugin_handle_ != nullptr) { + return; + } + + plugin_handle_ = dlopen(kHcclPluginFileName, RTLD_NOW | RTLD_GLOBAL); + if (plugin_handle_ == nullptr) { + MS_LOG(EXCEPTION) << "Dlopen " << kHcclPluginFileName << " failed, result = " << GetDlErrorMsg(); + } + + init_hcom_graph_adapter_ = DlsymFuncObj(plugin_handle_, InitHcomGraphAdapter); + finalize_hcom_graph_adapter_ = DlsymFuncObj(plugin_handle_, FinalizeHcomGraphAdapter); + get_hccl_kernel_info_store_ = DlsymFuncObj(plugin_handle_, GetHcclKernelInfoStore); + get_all_kernel_builder_ = DlsymFuncObj(plugin_handle_, GetAllKernelBuilder); + init_hccl_comm_ = DlsymFuncObj(plugin_handle_, InitHcclComm); + finalize_hccl_comm_ = DlsymFuncObj(plugin_handle_, FinalizeHcclComm); + launch_hccl_broadcast_ = DlsymFuncObj(plugin_handle_, LaunchHcclBroadcast); + launch_hccl_all_reduce_ = DlsymFuncObj(plugin_handle_, LaunchHcclAllReduce); + hccl_create_group_ = DlsymFuncObj(plugin_handle_, HcclCreateGroup); + hccl_destroy_group_ = DlsymFuncObj(plugin_handle_, HcclDestroyGroup); + hccl_get_rank_id_ = DlsymFuncObj(plugin_handle_, HcclGetRankId); + hccl_get_rank_size_ = DlsymFuncObj(plugin_handle_, HcclGetRankSize); + hccl_exec_initialize_ = DlsymFuncObj(plugin_handle_, HcclExecInitialize); + hccl_exec_finalize_ = DlsymFuncObj(plugin_handle_, HcclExecFinalize); + hccl_exec_enqueue_op_ = DlsymFuncObj(plugin_handle_, HcclExecEnqueueOp); +} + +void HcclAdapter::FinalizePlugin() { + if (plugin_handle_ == nullptr) { + return; + } + + init_hcom_graph_adapter_ = nullptr; + finalize_hcom_graph_adapter_ = nullptr; + get_hccl_kernel_info_store_ = nullptr; + get_all_kernel_builder_ = nullptr; + init_hccl_comm_ = nullptr; + finalize_hccl_comm_ = nullptr; + launch_hccl_broadcast_ = nullptr; + launch_hccl_all_reduce_ = nullptr; + hccl_create_group_ = nullptr; + hccl_destroy_group_ = nullptr; + hccl_get_rank_id_ = nullptr; + hccl_get_rank_size_ = nullptr; + hccl_exec_initialize_ = nullptr; + hccl_exec_finalize_ = nullptr; + hccl_exec_enqueue_op_ = nullptr; + (void)dlclose(plugin_handle_); + plugin_handle_ = nullptr; +} + +bool HcclAdapter::InitHccl(uint32_t device_id, std::string_view rank_id, std::string_view rank_file) { MS_LOG(INFO) << "Start init hccl adapter."; - // get ops_kernel_builder - std::map all_builders = ge::OpsKernelBuilderRegistry::GetInstance().GetAll(); - if (all_builders.size() != 1) { - MS_LOG(EXCEPTION) << "Builders size should be 1 (hccl builder), but is " << all_builders.size(); + std::lock_guard lock(init_mutex_); + if (init_flag_) { + MS_LOG(INFO) << "Hccl has been inited, skip."; + return true; } - MS_LOG(INFO) << "Get builder " << all_builders.begin()->first; - ops_kernel_builder = all_builders.begin()->second; - MS_EXCEPTION_IF_NULL(ops_kernel_builder); - // init ops_kernel_builder - auto options = GenHcclOptions(device_id, rank_id, rank_file); - auto ret = ops_kernel_builder->Initialize(options); - if (ret != ge::SUCCESS) { - MS_LOG(EXCEPTION) << "Init builder failed, ret = " << ret; + InitPlugin(); + bool ret = InitKernelInfoStore(device_id, rank_id, rank_file); + if (!ret) { + return false; + } + ret = InitHcclComm(rank_id, rank_file); + if (!ret) { + return false; } - // get ops_kernel_info_store - ret = ::Initialize(options); - if (ret != ge::SUCCESS) { - MS_LOG(EXCEPTION) << "Init plugin so failed, ret = " << ret; + ret = InitHcclExec(); + if (!ret) { + return false; } - std::map> all_ops_kernel_info_stores; - ::GetOpsKernelInfoStores(all_ops_kernel_info_stores); - for (auto &[name, ptr] : all_ops_kernel_info_stores) { - if (name == kHcclOpsKernelInfoStore) { - ops_kernel_info_store = ptr; - break; - } - } - MS_EXCEPTION_IF_NULL(ops_kernel_info_store); - ret = ops_kernel_info_store->Initialize(options); - if (ret != ge::SUCCESS) { - MS_LOG(EXCEPTION) << "Init info store failed, ret = " << ret; - } + init_flag_ = true; MS_LOG(INFO) << "Init hccl adapter success."; return true; } -bool FinalizeHccl() { +bool HcclAdapter::FinalizeHccl() { MS_LOG(INFO) << "Start destroy hccl adapter."; - if (ops_kernel_info_store != nullptr) { - auto ret = ops_kernel_info_store->Finalize(); - if (ret != ge::SUCCESS) { - MS_LOG(ERROR) << "Destroy info store failed, ret = " << ret; - return false; - } + std::lock_guard lock(init_mutex_); + if (!init_flag_) { + MS_LOG(INFO) << "Hccl has never been inited, skip."; + return true; } - if (ops_kernel_builder != nullptr) { - auto ret = ops_kernel_builder->Finalize(); - if (ret != ge::SUCCESS) { - MS_LOG(ERROR) << "Destroy builder failed, ret = " << ret; - return false; - } - } - - ::Finalize(); - ops_kernel_info_store.reset(); - ops_kernel_builder.reset(); + (void)FinalizeHcclExec(); + (void)FinalizeHcclComm(); + (void)FinalizeKernelInfoStore(); + FinalizePlugin(); + init_flag_ = false; MS_LOG(INFO) << "Destroy hccl adapter success."; return true; } -bool GenTask(const AnfNodePtr &node, HcclDataType datatype, std::vector *task_info_lists) { - MS_EXCEPTION_IF_NULL(ops_kernel_builder); +bool HcclAdapter::GenTask(const AnfNodePtr &node, HcclDataType datatype, + std::vector *task_info_lists) const { + MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(task_info_lists); MS_LOG(INFO) << "Start generate task for hccl node " << node->DebugString(); auto [ge_node, ge_graph] = GenerateStubGeNode(node, datatype); @@ -138,7 +190,8 @@ bool GenTask(const AnfNodePtr &node, HcclDataType datatype, std::vectorCalcOpRunningParam(*ge_node); + MS_EXCEPTION_IF_NULL(ops_kernel_builder_); + ge::Status ret = ops_kernel_builder_->CalcOpRunningParam(*ge_node); if (ret != ge::SUCCESS) { MS_LOG(ERROR) << "OpsKernelBuilder CalcOpRunningParam failed, ret = " << ret; return false; @@ -146,7 +199,7 @@ bool GenTask(const AnfNodePtr &node, HcclDataType datatype, std::vector domi_tasks; - ret = ops_kernel_builder->GenerateTask(*ge_node, unused_ctx, domi_tasks); + ret = ops_kernel_builder_->GenerateTask(*ge_node, unused_ctx, domi_tasks); if (ret != ge::SUCCESS) { MS_LOG(ERROR) << "OpsKernelBuilder GenerateTask failed, ret = " << ret; return false; @@ -160,8 +213,8 @@ bool GenTask(const AnfNodePtr &node, HcclDataType datatype, std::vectorDebugString() << " ,dtype is " << datatype; auto [ge_node, ge_graph] = GenerateStubGeNode(node, datatype); MS_EXCEPTION_IF_NULL(ge_node); @@ -169,7 +222,7 @@ int64_t CalcWorkspaceSize(const AnfNodePtr &node, HcclDataType datatype) { MS_EXCEPTION_IF_NULL(op); MS_LOG(INFO) << "Start to call CalcOpRunningParam"; - ge::Status ret = ops_kernel_builder->CalcOpRunningParam(*ge_node); + ge::Status ret = ops_kernel_builder_->CalcOpRunningParam(*ge_node); if (ret != ge::SUCCESS) { MS_LOG(ERROR) << "OpsKernelBuilder CalcOpRunningParam failed, ret = " << ret; return false; @@ -185,12 +238,179 @@ int64_t CalcWorkspaceSize(const AnfNodePtr &node, HcclDataType datatype) { return workspace_size; } -void *GetHcclOpsKernelInfoStore() { return ops_kernel_info_store.get(); } +void *HcclAdapter::GetHcclOpsKernelInfoStore() const { return ops_kernel_info_store_.get(); } -std::string GetHcclType(const AnfNodePtr &node) { +std::string HcclAdapter::GetHcclType(const AnfNodePtr &node) { MS_EXCEPTION_IF_NULL(node); auto cnode = node->cast(); MS_EXCEPTION_IF_NULL(cnode); return GetGeNodeName(cnode); } + +HcclResult HcclAdapter::HcclBroadcast(void *buf, uint64_t count, HcclDataType dataType, uint32_t root, + aclrtStream stream) const { + MS_EXCEPTION_IF_NULL(launch_hccl_broadcast_); + return launch_hccl_broadcast_(buf, count, dataType, root, hccl_comm_, stream); +} + +HcclResult HcclAdapter::HcclAllReduce(void *sendBuf, void *recvBuf, uint64_t count, HcclDataType dataType, + HcclReduceOp op, aclrtStream stream) const { + MS_EXCEPTION_IF_NULL(launch_hccl_all_reduce_); + return launch_hccl_all_reduce_(sendBuf, recvBuf, count, dataType, op, hccl_comm_, stream); +} + +bool HcclAdapter::InitKernelInfoStore(uint32_t device_id, std::string_view rank_id, std::string_view rank_file) { + MS_LOG(INFO) << "Start init hccl kernel info store."; + MS_EXCEPTION_IF_NULL(init_hcom_graph_adapter_); + MS_EXCEPTION_IF_NULL(get_hccl_kernel_info_store_); + // get ops_kernel_builder + std::map> all_builders; + get_all_kernel_builder_(&all_builders); + if (all_builders.size() != 1) { + MS_LOG(EXCEPTION) << "Builders size should be 1 (hccl builder), but is " << all_builders.size(); + } + + MS_LOG(INFO) << "Get builder " << all_builders.begin()->first; + ops_kernel_builder_ = all_builders.begin()->second; + MS_EXCEPTION_IF_NULL(ops_kernel_builder_); + // init ops_kernel_builder + auto options = GenHcclOptions(device_id, rank_id, rank_file); + auto ret = ops_kernel_builder_->Initialize(options); + if (ret != ge::SUCCESS) { + MS_LOG(EXCEPTION) << "Init hccl kernel builder failed, ret = " << ret; + } + + // get ops_kernel_info_store + ret = init_hcom_graph_adapter_(options); + if (ret != ge::SUCCESS) { + MS_LOG(EXCEPTION) << "Init hccl graph adapter failed, ret = " << ret; + } + + get_hccl_kernel_info_store_(&ops_kernel_info_store_); + MS_EXCEPTION_IF_NULL(ops_kernel_info_store_); + ret = ops_kernel_info_store_->Initialize(options); + if (ret != ge::SUCCESS) { + MS_LOG(EXCEPTION) << "Init info store failed, ret = " << ret; + } + MS_LOG(INFO) << "Init hccl kernel info store success."; + return true; +} + +bool HcclAdapter::FinalizeKernelInfoStore() { + MS_LOG(INFO) << "Start destroy hccl kernel info store."; + if (ops_kernel_info_store_ != nullptr) { + auto ret = ops_kernel_info_store_->Finalize(); + if (ret != ge::SUCCESS) { + MS_LOG(ERROR) << "Destroy info store failed, ret = " << ret; + return false; + } + } + + if (ops_kernel_builder_ != nullptr) { + auto ret = ops_kernel_builder_->Finalize(); + if (ret != ge::SUCCESS) { + MS_LOG(ERROR) << "Destroy builder failed, ret = " << ret; + return false; + } + } + + MS_EXCEPTION_IF_NULL(finalize_hcom_graph_adapter_); + finalize_hcom_graph_adapter_(); + ops_kernel_info_store_.reset(); + ops_kernel_builder_.reset(); + MS_LOG(INFO) << "Destroy hccl kernel info store success."; + return true; +} + +bool HcclAdapter::InitHcclComm(std::string_view rank_id, std::string_view rank_file) { + MS_LOG(INFO) << "Start init hccl comm."; + int rank_id_i = -1; + try { + rank_id_i = std::stoi(rank_id.data()); + } catch (std::invalid_argument &) { + MS_LOG(EXCEPTION) << "Invalid rank id env:" << rank_id; + } + if (rank_id_i < 0 || rank_id_i > 7) { + MS_LOG(ERROR) << "rank_id needs to be between 0-7"; + return false; + } + MS_EXCEPTION_IF_NULL(init_hccl_comm_); + auto hccl_result = init_hccl_comm_(rank_file.data(), rank_id_i, &hccl_comm_); + if (hccl_result != HCCL_SUCCESS) { + MS_LOG(ERROR) << "HcclCommInitClusterInfo failed, ret:" << hccl_result; + return false; + } + MS_LOG(INFO) << "InitHcclComm success"; + return true; +} + +bool HcclAdapter::FinalizeHcclComm() { + MS_LOG(INFO) << "Start finalize hccl comm."; + if (hccl_comm_ == nullptr) { + return true; + } + + MS_EXCEPTION_IF_NULL(finalize_hccl_comm_); + auto hccl_result = finalize_hccl_comm_(hccl_comm_); + if (hccl_result != HCCL_SUCCESS) { + MS_LOG(ERROR) << "HcclComm destroy failed, ret:" << hccl_result; + return false; + } + hccl_comm_ = nullptr; + MS_LOG(INFO) << "HcclComm destroy success"; + return true; +} + +HcclResult HcclAdapter::HcclCreateGroup(const std::string &group, uint32_t rank_num, uint32_t *rank_ids) const { + MS_EXCEPTION_IF_NULL(hccl_create_group_); + return hccl_create_group_(group.c_str(), rank_num, rank_ids); +} + +HcclResult HcclAdapter::HcclDestroyGroup(const std::string &group) const { + MS_EXCEPTION_IF_NULL(hccl_destroy_group_); + return hccl_destroy_group_(group.c_str()); +} + +HcclResult HcclAdapter::HcclGetRankId(const std::string &group, uint32_t *rank_id) const { + MS_EXCEPTION_IF_NULL(hccl_get_rank_id_); + return hccl_get_rank_id_(group.c_str(), rank_id); +} + +HcclResult HcclAdapter::HcclGetRankSize(const std::string &group, uint32_t *rank_size) const { + MS_EXCEPTION_IF_NULL(hccl_get_rank_size_); + return hccl_get_rank_size_(group.c_str(), rank_size); +} + +bool HcclAdapter::InitHcclExec() { + MS_LOG(INFO) << "Start init hccl exec."; + MS_EXCEPTION_IF_NULL(hccl_exec_initialize_); + HcclResult hccl_ret = hccl_exec_initialize_(); + if (hccl_ret == HCCL_E_PTR) { + MS_LOG(WARNING) << "Hccl comm is null, hcom executor initialize is not required"; + } else if (hccl_ret == HCCL_SUCCESS) { + MS_LOG(INFO) << "Hcom DynamicKernel Initialize success"; + } else { + MS_LOG(ERROR) << "Hcom DynamicKernel Initialize failed"; + return false; + } + MS_LOG(INFO) << "InitHcclExec success"; + return true; +} + +bool HcclAdapter::FinalizeHcclExec() { + MS_LOG(INFO) << "Start finalize hccl exec."; + MS_EXCEPTION_IF_NULL(hccl_exec_finalize_); + HcclResult hccl_ret = hccl_exec_finalize_(); + if (hccl_ret != HCCL_SUCCESS) { + MS_LOG(ERROR) << "Hcom DynamicKernel Finalize failed"; + return false; + } + MS_LOG(INFO) << "HcclExec destroy success"; + return true; +} + +HcclResult HcclAdapter::HcclExecEnqueueOp(const ::HcomOperation &op_info, HExecCallBack callback) const { + MS_EXCEPTION_IF_NULL(hccl_exec_enqueue_op_); + return hccl_exec_enqueue_op_(op_info, callback); +} } // namespace mindspore::hccl diff --git a/mindspore/ccsrc/runtime/hccl_adapter/hccl_adapter.h b/mindspore/ccsrc/runtime/hccl_adapter/hccl_adapter.h index 6488a0ffb6f..592e61c11e4 100644 --- a/mindspore/ccsrc/runtime/hccl_adapter/hccl_adapter.h +++ b/mindspore/ccsrc/runtime/hccl_adapter/hccl_adapter.h @@ -20,8 +20,15 @@ #include #include #include +#include #include "mindspore/core/ir/anf.h" #include "hccl/hccl_types.h" +#include "runtime/hccl_adapter/plugin/hccl_plugin.h" + +namespace ge { +class OpsKernelInfoStore; +class OpsKernelBuilder; +} // namespace ge namespace mindspore::hccl { struct HcclTaskInfo { @@ -30,11 +37,76 @@ struct HcclTaskInfo { int64_t stream_num; }; -bool InitHccl(uint32_t device_id, std::string_view rank_id, std::string_view rank_file); -bool FinalizeHccl(); -bool GenTask(const AnfNodePtr &node, HcclDataType datatype, std::vector *task_info_lists); -int64_t CalcWorkspaceSize(const AnfNodePtr &node, HcclDataType datatype); -void *GetHcclOpsKernelInfoStore(); -std::string GetHcclType(const AnfNodePtr &node); +class HcclAdapter { + public: + static HcclAdapter &GetInstance(); + + // common + bool InitHccl(uint32_t device_id, std::string_view rank_id, std::string_view rank_file); + bool FinalizeHccl(); + + HcclResult HcclCreateGroup(const std::string &group, uint32_t rank_num, uint32_t *rank_ids) const; + HcclResult HcclDestroyGroup(const std::string &group) const; + HcclResult HcclGetRankId(const std::string &group, uint32_t *rank_id) const; + HcclResult HcclGetRankSize(const std::string &group, uint32_t *rank_size) const; + + // for ge node + bool GenTask(const AnfNodePtr &node, HcclDataType datatype, std::vector *task_info_lists) const; + int64_t CalcWorkspaceSize(const AnfNodePtr &node, HcclDataType datatype) const; + void *GetHcclOpsKernelInfoStore() const; + static std::string GetHcclType(const AnfNodePtr &node); + + // for single op + HcclResult HcclBroadcast(void *buf, uint64_t count, HcclDataType dataType, uint32_t root, aclrtStream stream) const; + HcclResult HcclAllReduce(void *sendBuf, void *recvBuf, uint64_t count, HcclDataType dataType, HcclReduceOp op, + aclrtStream stream) const; + + // for enqueue op + HcclResult HcclExecEnqueueOp(const ::HcomOperation &op_info, HExecCallBack callback) const; + + private: + HcclAdapter() = default; + ~HcclAdapter() = default; + void InitPlugin(); + void FinalizePlugin(); + + bool InitKernelInfoStore(uint32_t device_id, std::string_view rank_id, std::string_view rank_file); + bool FinalizeKernelInfoStore(); + + bool InitHcclComm(std::string_view rank_id, std::string_view rank_file); + bool FinalizeHcclComm(); + + bool InitHcclExec(); + bool FinalizeHcclExec(); + + void *plugin_handle_ = nullptr; + + InitHcomGraphAdapterFunObj init_hcom_graph_adapter_ = nullptr; + FinalizeHcomGraphAdapterFunObj finalize_hcom_graph_adapter_ = nullptr; + GetHcclKernelInfoStoreFunObj get_hccl_kernel_info_store_ = nullptr; + GetAllKernelBuilderFunObj get_all_kernel_builder_ = nullptr; + + InitHcclCommFunObj init_hccl_comm_ = nullptr; + FinalizeHcclCommFunObj finalize_hccl_comm_ = nullptr; + LaunchHcclBroadcastFunObj launch_hccl_broadcast_ = nullptr; + LaunchHcclAllReduceFunObj launch_hccl_all_reduce_ = nullptr; + + HcclCreateGroupFunObj hccl_create_group_ = nullptr; + HcclDestroyGroupFunObj hccl_destroy_group_ = nullptr; + HcclGetRankIdFunObj hccl_get_rank_id_ = nullptr; + HcclGetRankSizeFunObj hccl_get_rank_size_ = nullptr; + + HcclExecInitializeFunObj hccl_exec_initialize_ = nullptr; + HcclExecFinalizeFunObj hccl_exec_finalize_ = nullptr; + HcclExecEnqueueOpFunObj hccl_exec_enqueue_op_ = nullptr; + + HcclComm hccl_comm_ = nullptr; + + std::shared_ptr<::ge::OpsKernelInfoStore> ops_kernel_info_store_ = nullptr; + std::shared_ptr<::ge::OpsKernelBuilder> ops_kernel_builder_ = nullptr; + + bool init_flag_ = false; + std::mutex init_mutex_; +}; } // namespace mindspore::hccl #endif // MINDSPORE_RUNTIME_HCCL_ADAPTER_HCCL_ADAPTER_H diff --git a/mindspore/ccsrc/runtime/hccl_adapter/hcom_graph_adaptor.h b/mindspore/ccsrc/runtime/hccl_adapter/hcom_graph_adaptor.h deleted file mode 100644 index 08fadfebde3..00000000000 --- a/mindspore/ccsrc/runtime/hccl_adapter/hcom_graph_adaptor.h +++ /dev/null @@ -1,32 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_RUNTIME_HCCL_ADAPTER_HCOM_GRAPH_ADAPTOR_H -#define MINDSPORE_RUNTIME_HCCL_ADAPTER_HCOM_GRAPH_ADAPTOR_H - -#include -#include -#include -#include "mindspore/core/ir/anf.h" -#include "common/opskernel/ops_kernel_info_store.h" - -extern "C" { -ge::Status Initialize(const std::map &); -ge::Status Finalize(); -void GetOpsKernelInfoStores(std::map> &); -} - -#endif // MINDSPORE_RUNTIME_HCCL_ADAPTER_HCOM_GRAPH_ADAPTOR_H diff --git a/mindspore/ccsrc/runtime/hccl_adapter/plugin/CMakeLists.txt b/mindspore/ccsrc/runtime/hccl_adapter/plugin/CMakeLists.txt new file mode 100644 index 00000000000..bb4e203620d --- /dev/null +++ b/mindspore/ccsrc/runtime/hccl_adapter/plugin/CMakeLists.txt @@ -0,0 +1,41 @@ +add_library(hccl_plugin SHARED hccl_plugin.cc) +target_include_directories(hccl_plugin PRIVATE ${CMAKE_BINARY_DIR}/proto/ge) +add_dependencies(hccl_plugin graph) + +set(MINDSPORE_RPATH ${MINDSPORE_RPATH}:/usr/local/Ascend/nnae/latest/fwkacllib/lib64) +set(MINDSPORE_RPATH ${MINDSPORE_RPATH}:/usr/local/Ascend/ascend-toolkit/latest/fwkacllib/lib64) +set(MINDSPORE_RPATH ${MINDSPORE_RPATH}:/usr/local/Ascend/fwkacllib/lib64) +set(MINDSPORE_RPATH ${MINDSPORE_RPATH}:/usr/local/Ascend/nnae/latest/fwkacllib/lib64/plugin/opskernel) +set(MINDSPORE_RPATH ${MINDSPORE_RPATH}:/usr/local/Ascend/ascend-toolkit/latest/fwkacllib/lib64/plugin/opskernel) +set(MINDSPORE_RPATH ${MINDSPORE_RPATH}:/usr/local/Ascend/fwkacllib/lib64/plugin/opskernel) +set_target_properties(hccl_plugin PROPERTIES INSTALL_RPATH ${MINDSPORE_RPATH}) + +if(DEFINED ENV{D_LINK_PATH}) + if(CMAKE_HOST_SYSTEM_PROCESSOR MATCHES "aarch64") + MESSAGE("system processor matches aarch64") + set(D_LIB_PATH $ENV{D_LINK_PATH}/aarch64) + elseif(CMAKE_HOST_SYSTEM_PROCESSOR MATCHES "x86_64") + MESSAGE("system processor matches x86_64") + set(D_LIB_PATH $ENV{D_LINK_PATH}/x86_64) + else() + MESSAGE("system ${CMAKE_HOST_SYSTEM_PROCESSOR} not support") + endif() +else() + MESSAGE("use system default lib") + if(DEFINED ENV{ASCEND_CUSTOM_PATH}) + set(ASCEND_PATH $ENV{ASCEND_CUSTOM_PATH}) + else() + set(ASCEND_PATH /usr/local/Ascend) + endif() + set(ASCEND_RUNTIME_PATH ${ASCEND_PATH}/fwkacllib/lib64) + set(ASCEND_PLUGIN_PATH ${ASCEND_RUNTIME_PATH}/plugin/opskernel) + set(ASCEND_TOOLKIT_RUNTIME_PATH ${ASCEND_PATH}/ascend-toolkit/latest/fwkacllib/lib64) + set(ASCEND_TOOLKIT_PLUGIN_PATH ${ASCEND_TOOLKIT_RUNTIME_PATH}/plugin/opskernel) +endif() + +find_library(HCCL hccl ${ASCEND_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH}) +find_library(REGISTER register ${ASCEND_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH}) +find_library(HCCL_ADPTER hcom_graph_adaptor ${ASCEND_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH}) +find_library(HCCL_RA ra ${ASCEND_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH}) +find_library(HCCL_BUILDER hcom_opskernel_builder ${ASCEND_PLUGIN_PATH} ${ASCEND_TOOLKIT_PLUGIN_PATH}) +target_link_libraries(hccl_plugin -Wl,--no-as-needed ${HCCL} ${HCCL_ADPTER} ${REGISTER} ${HCCL_BUILDER} ${HCCL_RA}) diff --git a/mindspore/ccsrc/runtime/hccl_adapter/plugin/hccl_plugin.cc b/mindspore/ccsrc/runtime/hccl_adapter/plugin/hccl_plugin.cc new file mode 100644 index 00000000000..81e9b029a76 --- /dev/null +++ b/mindspore/ccsrc/runtime/hccl_adapter/plugin/hccl_plugin.cc @@ -0,0 +1,90 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "runtime/hccl_adapter/plugin/hccl_plugin.h" +#define google ascend_private +#include "register/ops_kernel_builder_registry.h" +#include "common/opskernel/ops_kernel_info_store.h" +#undef google +#include "hccl/hcom.h" + +static constexpr const char *kHcclOpsKernelInfoStore = "ops_kernel_info_hccl"; + +extern "C" { +ge::Status Initialize(const std::map &); +ge::Status Finalize(); +void GetOpsKernelInfoStores(std::map> &); + +ge::Status PluginInitHcomGraphAdapter(const std::map &options) { + return ::Initialize(options); +} + +ge::Status PluginFinalizeHcomGraphAdapter() { return ::Finalize(); } + +void PluginGetHcclKernelInfoStore(std::shared_ptr *hccl_kernel_info_store) { + if (hccl_kernel_info_store == nullptr) { + return; + } + + std::map> all_ops_kernel_info_stores; + ::GetOpsKernelInfoStores(all_ops_kernel_info_stores); + for (auto &[name, ptr] : all_ops_kernel_info_stores) { + if (name == kHcclOpsKernelInfoStore) { + *hccl_kernel_info_store = ptr; + return; + } + } + + *hccl_kernel_info_store = nullptr; +} + +void PluginGetAllKernelBuilder(std::map *all_ops_kernel_builder) { + if (all_ops_kernel_builder == nullptr) { + return; + } + + *all_ops_kernel_builder = ge::OpsKernelBuilderRegistry::GetInstance().GetAll(); +} + +HcclResult PluginLaunchHcclBroadcast(void *buf, uint64_t count, HcclDataType data_type, uint32_t root, HcclComm comm, + aclrtStream stream) { + return HcclBroadcast(buf, count, data_type, root, comm, stream); +} + +HcclResult PluginLaunchHcclAllReduce(void *send_buf, void *recv_buf, uint64_t count, HcclDataType data_type, + HcclReduceOp op, HcclComm comm, aclrtStream stream) { + return HcclAllReduce(send_buf, recv_buf, count, data_type, op, comm, stream); +} + +HcclResult PluginInitHcclComm(const char *cluster_info, uint32_t rank, HcclComm *comm) { + return HcclCommInitClusterInfo(cluster_info, rank, comm); +} + +HcclResult PluginFinalizeHcclComm(HcclComm comm) { return HcclCommDestroy(comm); } + +HcclResult PluginHcclCreateGroup(const char *group, uint32_t rank_num, uint32_t *rank_ids) { + return HcomCreateGroup(group, rank_num, rank_ids); +} + +HcclResult PluginHcclDestroyGroup(const char *group) { return HcomDestroyGroup(group); } +HcclResult PluginHcclGetRankId(const char *group, uint32_t *rank_id) { return HcomGetRankId(group, rank_id); } +HcclResult PluginHcclGetRankSize(const char *group, uint32_t *rank_size) { return HcomGetRankSize(group, rank_size); } + +HcclResult PluginHcclExecInitialize() { return HcomExecInitialize(); } +HcclResult PluginHcclExecFinalize() { return HcomExecFinalize(); } +HcclResult PluginHcclExecEnqueueOp(const ::HcomOperation &op_info, HExecCallBack callback) { + return HcomExecEnqueueOperation(op_info, callback); +} +}; // extern C diff --git a/mindspore/ccsrc/runtime/hccl_adapter/plugin/hccl_plugin.h b/mindspore/ccsrc/runtime/hccl_adapter/plugin/hccl_plugin.h new file mode 100644 index 00000000000..bb40dee12b0 --- /dev/null +++ b/mindspore/ccsrc/runtime/hccl_adapter/plugin/hccl_plugin.h @@ -0,0 +1,63 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_RUNTIME_HCCL_ADAPTER_PLUGIN_HCCL_PLUGIN_H +#define MINDSPORE_RUNTIME_HCCL_ADAPTER_PLUGIN_HCCL_PLUGIN_H + +#include +#include +#include +#include +#include "external/ge/ge_api_types.h" +#include "hccl/hccl.h" + +namespace ge { +class OpsKernelBuilder; +class OpsKernelInfoStore; +} // namespace ge + +extern "C" { +struct HcomOperation; +} // extern C + +using OptionsType = std::map; +using OpsKernelBuilderMap = std::map>; +using HExecCallBack = std::function; + +#define PLUGIN_METHOD(name, return_type, params...) \ + extern "C" { \ + __attribute__((visibility("default"))) return_type Plugin##name(params); \ + } \ + constexpr const char *k##name##Name = "Plugin" #name; \ + using name##FunObj = std::function; \ + using name##FunPtr = return_type (*)(params); + +PLUGIN_METHOD(InitHcomGraphAdapter, ge::Status, const OptionsType &); +PLUGIN_METHOD(FinalizeHcomGraphAdapter, ge::Status); +PLUGIN_METHOD(GetHcclKernelInfoStore, void, std::shared_ptr *); +PLUGIN_METHOD(GetAllKernelBuilder, void, OpsKernelBuilderMap *); +PLUGIN_METHOD(LaunchHcclBroadcast, HcclResult, void *, uint64_t, HcclDataType, uint32_t, HcclComm, aclrtStream); +PLUGIN_METHOD(LaunchHcclAllReduce, HcclResult, void *, void *, uint64_t, HcclDataType, HcclReduceOp, HcclComm, + aclrtStream); +PLUGIN_METHOD(InitHcclComm, HcclResult, const char *, uint32_t, HcclComm *); +PLUGIN_METHOD(FinalizeHcclComm, HcclResult, HcclComm); +PLUGIN_METHOD(HcclCreateGroup, HcclResult, const char *, uint32_t, uint32_t *); +PLUGIN_METHOD(HcclDestroyGroup, HcclResult, const char *); +PLUGIN_METHOD(HcclGetRankId, HcclResult, const char *, uint32_t *); +PLUGIN_METHOD(HcclGetRankSize, HcclResult, const char *, uint32_t *); +PLUGIN_METHOD(HcclExecInitialize, HcclResult); +PLUGIN_METHOD(HcclExecFinalize, HcclResult); +PLUGIN_METHOD(HcclExecEnqueueOp, HcclResult, const ::HcomOperation &, HExecCallBack); +#endif // MINDSPORE_RUNTIME_HCCL_ADAPTER_PLUGIN_HCCL_PLUGIN_H diff --git a/mindspore/ccsrc/utils/comm_manager.cc b/mindspore/ccsrc/utils/comm_manager.cc index c92d33a9ab1..6b943ffc4f7 100644 --- a/mindspore/ccsrc/utils/comm_manager.cc +++ b/mindspore/ccsrc/utils/comm_manager.cc @@ -18,7 +18,7 @@ #include "utils/convert_utils.h" #ifndef NO_DLIB -#include "hccl/hcom.h" +#include "runtime/hccl_adapter/hccl_adapter.h" #endif #if defined(ENABLE_GPU) @@ -67,26 +67,27 @@ bool CommManager::CreateGroupSync(const string &group, const vector(rank_id_list).data())); + hccl::HcclAdapter::GetInstance().HcclCreateGroup(group, UlongToUint(rank_size), + vector(rank_id_list).data())); return true; } bool CommManager::GetRankID(const string &group, unsigned int *rank_id) const { HCCL_GROUP_CHECK_EMPTY(group); - HCCL_RUN_CHECK(string("get rank_id"), group, HcomGetRankId(group.c_str(), rank_id)); + HCCL_RUN_CHECK(string("get rank_id"), group, hccl::HcclAdapter::GetInstance().HcclGetRankId(group, rank_id)); return true; } bool CommManager::GetRankSize(const string &group, unsigned int *rank_size) const { HCCL_GROUP_CHECK_EMPTY(group); - HCCL_RUN_CHECK(string("get rank size"), group, HcomGetRankSize(group.c_str(), rank_size)); + HCCL_RUN_CHECK(string("get rank size"), group, hccl::HcclAdapter::GetInstance().HcclGetRankSize(group, rank_size)); return true; } bool CommManager::DestroyGroup(const string &group) const { HCCL_GROUP_CHECK_EMPTY(group); HCCL_GROUP_CHECK_IS_WORLD(group); - HCCL_RUN_CHECK(string("destroy communicate group"), group, HcomDestroyGroup(group.c_str())); + HCCL_RUN_CHECK(string("destroy communicate group"), group, hccl::HcclAdapter::GetInstance().HcclDestroyGroup(group)); return true; } #elif defined(ENABLE_GPU) diff --git a/tests/ut/cpp/stub/dynamic_shape/dynamic_shape_stub.cc b/tests/ut/cpp/stub/dynamic_shape/dynamic_shape_stub.cc index 2f0a0031f00..c10726d8928 100644 --- a/tests/ut/cpp/stub/dynamic_shape/dynamic_shape_stub.cc +++ b/tests/ut/cpp/stub/dynamic_shape/dynamic_shape_stub.cc @@ -41,9 +41,6 @@ void AiCoreDynamicKernel::UpdateArgs() {} void AiCoreDynamicKernel::Initialize() {} void AiCoreDynamicKernel::PostExecute() {} -bool HcclExecutorManager::Initialize() { return true; } -bool HcclExecutorManager::Finalize() { return true; } - void OpTilingCalculater::Init() {} void OpTilingCalculater::CalculateTiling(const NotNull &cnode, const optiling::OpCompileInfo &op_compile_info, const std::map &depend_tensor_map, diff --git a/tests/ut/cpp/stub/ge/ge_task_launch_stub.cc b/tests/ut/cpp/stub/ge/ge_task_launch_stub.cc index 18bd0929c0f..d875febfdfc 100644 --- a/tests/ut/cpp/stub/ge/ge_task_launch_stub.cc +++ b/tests/ut/cpp/stub/ge/ge_task_launch_stub.cc @@ -18,11 +18,26 @@ namespace mindspore { namespace hccl { -bool InitHccl(uint32_t, std::string_view, std::string_view) { return true; } -bool FinalizeHccl() { return true; } -bool GenTask(const AnfNodePtr &, HcclDataType, std::vector *) { return true; } -int64_t CalcWorkspaceSize(const AnfNodePtr &, HcclDataType) { return 0; } -void *GetHcclOpsKernelInfoStore() { return nullptr; } -std::string GetHcclType(const AnfNodePtr &) { return ""; } +HcclAdapter &HcclAdapter::GetInstance() { + static HcclAdapter instance; + return instance; +} +bool HcclAdapter::InitHccl(uint32_t, std::string_view, std::string_view) { return true; } +bool HcclAdapter::FinalizeHccl() { return true; } +HcclResult HcclAdapter::HcclCreateGroup(const std::string &, uint32_t, uint32_t *) const { return HCCL_SUCCESS; } +HcclResult HcclAdapter::HcclDestroyGroup(const std::string &) const { return HCCL_SUCCESS; } +HcclResult HcclAdapter::HcclGetRankId(const std::string &, uint32_t *) const { return HCCL_SUCCESS; } +HcclResult HcclAdapter::HcclGetRankSize(const std::string &, uint32_t *) const { return HCCL_SUCCESS; } +bool HcclAdapter::GenTask(const AnfNodePtr &, HcclDataType, std::vector *) const { return true; } +int64_t HcclAdapter::CalcWorkspaceSize(const AnfNodePtr &, HcclDataType) const { return 0; } +void *HcclAdapter::GetHcclOpsKernelInfoStore() const { return nullptr; } +std::string HcclAdapter::GetHcclType(const AnfNodePtr &) { return ""; } +HcclResult HcclAdapter::HcclBroadcast(void *, uint64_t, HcclDataType, uint32_t, aclrtStream) const { + return HCCL_SUCCESS; +} +HcclResult HcclAdapter::HcclAllReduce(void *, void *, uint64_t, HcclDataType, HcclReduceOp, aclrtStream) const { + return HCCL_SUCCESS; +} +HcclResult HcclAdapter::HcclExecEnqueueOp(const ::HcomOperation &, HExecCallBack) const { return HCCL_SUCCESS; } } // namespace hccl } // namespace mindspore