forked from mindspore-Ecosystem/mindspore
!16260 hccl decouple
From: @zhoufeng54 Reviewed-by: @kisnwang,@xu-yfei Signed-off-by: @xu-yfei
This commit is contained in:
commit
5f79f7d229
|
@ -248,6 +248,13 @@ if(NOT ENABLE_GE)
|
|||
DESTINATION ${INSTALL_LIB_DIR}
|
||||
COMPONENT mindspore
|
||||
)
|
||||
|
||||
install(
|
||||
TARGETS hccl_plugin
|
||||
DESTINATION ${INSTALL_LIB_DIR}
|
||||
COMPONENT mindspore
|
||||
)
|
||||
|
||||
install(
|
||||
FILES
|
||||
${CMAKE_BINARY_DIR}/graphengine/metadef/graph/libgraph.so
|
||||
|
|
|
@ -293,8 +293,6 @@ if(ENABLE_D)
|
|||
endif()
|
||||
|
||||
MESSAGE("USE DAV LIB PATH: ${ASCEND_PATH}")
|
||||
find_library(HCCL hccl ${ASCEND_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH})
|
||||
find_library(CCE_LIB cce ${ASCEND_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH})
|
||||
find_library(RUNTIME_LIB runtime ${ASCEND_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH})
|
||||
find_library(TSDCLIENT tsdclient HINTS ${ASCEND_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH}
|
||||
${ASCEND_DRIVER_BACK_PATH})
|
||||
|
@ -302,14 +300,7 @@ if(ENABLE_D)
|
|||
${ASCEND_DRIVER_BACK_PATH})
|
||||
find_library(PROFILING msprofiler_fwkacl ${ASCEND_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH})
|
||||
find_library(ACL ascendcl ${ASCEND_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH})
|
||||
find_library(REGISTER register ${ASCEND_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH})
|
||||
find_library(PLATFORM platform ${ASCEND_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH})
|
||||
find_library(OPTILING optiling ${ASCEND_OPP_PATH} ${ASCEND_TOOLKIT_OPP_PATH})
|
||||
# hccl_adpter
|
||||
find_library(HCCL_ADPTER hcom_graph_adaptor ${ASCEND_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH})
|
||||
find_library(HCCL_RA ra ${ASCEND_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH})
|
||||
find_library(HCCL_BUILDER hcom_opskernel_builder ${ASCEND_RUNTIME_PATH}/plugin/opskernel
|
||||
${ASCEND_TOOLKIT_RUNTIME_PATH}/plugin/opskernel)
|
||||
|
||||
add_library(ms_profile SHARED
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/runtime/device/ascend/profiling/profiling_callback_register.cc)
|
||||
|
@ -317,9 +308,7 @@ if(ENABLE_D)
|
|||
target_link_options(ms_profile PRIVATE -Wl,-init,common_log_init)
|
||||
target_link_libraries(ms_profile -Wl,--start-group -Wl,--whole-archive ${PROFILING} -Wl,--no-whole-archive
|
||||
mindspore::protobuf -Wl,--end-group)
|
||||
target_link_libraries(mindspore ${CCE_LIB} ${RUNTIME_LIB} ${TSDCLIENT} ${HCCL} ${DATATRANSFER}
|
||||
${HCCL_ADPTER} ${REGISTER} -Wl,--no-as-needed ${OPTILING} ${HCCL_BUILDER}
|
||||
${HCCL_RA} ${PLATFORM} ${ACL})
|
||||
target_link_libraries(mindspore ${RUNTIME_LIB} ${TSDCLIENT} ${DATATRANSFER} -Wl,--no-as-needed ${OPTILING} ${ACL})
|
||||
target_link_libraries(mindspore -Wl,--start-group proto_input mindspore::protobuf -Wl,--end-group)
|
||||
elseif(CMAKE_SYSTEM_NAME MATCHES "Windows")
|
||||
target_link_libraries(mindspore -Wl,--start-group proto_input mindspore::protobuf mindspore::sentencepiece
|
||||
|
@ -353,10 +342,6 @@ if(ENABLE_D)
|
|||
set(MINDSPORE_RPATH ${MINDSPORE_RPATH}:/usr/local/Ascend/nnae/latest/fwkacllib/lib64)
|
||||
set(MINDSPORE_RPATH ${MINDSPORE_RPATH}:/usr/local/Ascend/ascend-toolkit/latest/fwkacllib/lib64)
|
||||
set(MINDSPORE_RPATH ${MINDSPORE_RPATH}:/usr/local/Ascend/fwkacllib/lib64)
|
||||
set(MINDSPORE_RPATH ${MINDSPORE_RPATH}:/usr/local/Ascend/nnae/latest/fwkacllib/lib64/plugin/opskernel)
|
||||
set(MINDSPORE_RPATH ${MINDSPORE_RPATH}:/usr/local/Ascend/ascend-toolkit/latest/fwkacllib/lib64/plugin/opskernel)
|
||||
set(MINDSPORE_RPATH ${MINDSPORE_RPATH}:/usr/local/Ascend/fwkacllib/lib64/plugin/opskernel)
|
||||
set(MINDSPORE_RPATH ${MINDSPORE_RPATH}:/usr/local/Ascend/add-ons)
|
||||
set(MINDSPORE_RPATH ${MINDSPORE_RPATH}:/usr/local/Ascend/opp/op_impl/built-in/ai_core/tbe/op_tiling)
|
||||
set(MINDSPORE_RPATH ${MINDSPORE_RPATH}:/usr/local/Ascend/nnae/latest/opp/op_impl/built-in/ai_core/tbe/op_tiling)
|
||||
set(MINDSPORE_RPATH
|
||||
|
|
|
@ -1,80 +0,0 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "backend/kernel_compiler/hccl/hccl_context.h"
|
||||
#include "utils/log_adapter.h"
|
||||
#include "hccl/hccl.h"
|
||||
|
||||
constexpr auto kHcclConfigFile = "MINDSPORE_HCCL_CONFIG_PATH";
|
||||
constexpr auto kHcclConfigFileOld = "RANK_TABLE_FILE";
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
int GetRankId() {
|
||||
auto rank_id_env = std::getenv("RANK_ID");
|
||||
if (rank_id_env == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "No RANK_ID, please export RANK_ID";
|
||||
}
|
||||
try {
|
||||
return std::stoi(rank_id_env);
|
||||
} catch (std::invalid_argument &e) {
|
||||
MS_LOG(EXCEPTION) << "Invalid rankd id env:" << rank_id_env;
|
||||
}
|
||||
}
|
||||
|
||||
bool HcclContext::InitHccl() {
|
||||
if (hccl_comm_ != nullptr) {
|
||||
return true;
|
||||
}
|
||||
auto config_file = std::getenv(kHcclConfigFile);
|
||||
if (config_file == nullptr) {
|
||||
config_file = std::getenv(kHcclConfigFileOld);
|
||||
if (config_file == nullptr) {
|
||||
MS_LOG(ERROR) << "Get hccl rank table file failed. Please export MINDSPORE_HCCL_CONFIG_PATH or RANK_TABLE_FILE";
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
rank_id_ = GetRankId();
|
||||
if (rank_id_ < 0 || rank_id_ > 7) {
|
||||
MS_LOG(ERROR) << "rank_id needs to be between 0-7";
|
||||
return false;
|
||||
}
|
||||
|
||||
auto hccl_result = HcclCommInitClusterInfo(config_file, rank_id_, &hccl_comm_);
|
||||
if (hccl_result != HCCL_SUCCESS) {
|
||||
MS_LOG(ERROR) << "HcclCommInitClusterInfo failed, ret:" << hccl_result;
|
||||
return false;
|
||||
}
|
||||
MS_LOG(INFO) << "HcclCommInitClusterInfo success";
|
||||
return true;
|
||||
}
|
||||
|
||||
bool HcclContext::Finalize() {
|
||||
if (hccl_comm_ == nullptr) {
|
||||
return true;
|
||||
}
|
||||
auto hccl_result = HcclCommDestroy(hccl_comm_);
|
||||
if (hccl_result != HCCL_SUCCESS) {
|
||||
MS_LOG(ERROR) << "HcclComm destroy failed, ret:" << hccl_result;
|
||||
return false;
|
||||
}
|
||||
MS_LOG(INFO) << "HcclComm destroy success";
|
||||
hccl_comm_ = nullptr;
|
||||
return true;
|
||||
}
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
|
@ -1,47 +0,0 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_HCCL_HCCL_CONTEXT_H_
|
||||
#define MINDSPORE_MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_HCCL_HCCL_CONTEXT_H_
|
||||
|
||||
#include <string>
|
||||
#include "hccl/hccl_types.h"
|
||||
#include "utils/ms_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
class HcclContext {
|
||||
public:
|
||||
static HcclContext &GetInstance() {
|
||||
static HcclContext instance;
|
||||
return instance;
|
||||
}
|
||||
|
||||
bool InitHccl();
|
||||
bool Finalize();
|
||||
HcclComm hccl_comm() { return hccl_comm_; }
|
||||
|
||||
private:
|
||||
HcclContext() = default;
|
||||
~HcclContext() = default;
|
||||
DISABLE_COPY_AND_ASSIGN(HcclContext);
|
||||
HcclComm hccl_comm_{nullptr};
|
||||
int rank_id_{0};
|
||||
uint32_t device_id_{0};
|
||||
};
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_HCCL_HCCL_CONTEXT_H_
|
|
@ -188,7 +188,8 @@ const std::vector<size_t> &HcclKernel::GetWorkspaceSizeList() const {
|
|||
return workspace_size_list_;
|
||||
}
|
||||
|
||||
workspace_size_list_.emplace_back(hccl::CalcWorkspaceSize(anf_node_.lock(), hccl_data_type_list_[0]));
|
||||
workspace_size_list_.emplace_back(
|
||||
hccl::HcclAdapter::GetInstance().CalcWorkspaceSize(anf_node_.lock(), hccl_data_type_list_[0]));
|
||||
return workspace_size_list_;
|
||||
}
|
||||
|
||||
|
@ -218,7 +219,7 @@ std::vector<TaskInfoPtr> HcclKernel::GenTask(const std::vector<AddressPtr> &inpu
|
|||
std::vector<uint8_t> private_def;
|
||||
HcclDataType data_type = hccl_data_type_list_[0];
|
||||
std::vector<hccl::HcclTaskInfo> task_info;
|
||||
bool ret = hccl::GenTask(anf_node, data_type, &task_info);
|
||||
bool ret = hccl::HcclAdapter::GetInstance().GenTask(anf_node, data_type, &task_info);
|
||||
if (!ret) {
|
||||
MS_LOG(EXCEPTION) << "Gen Task for " << anf_node->DebugString() << " failed.";
|
||||
}
|
||||
|
@ -245,10 +246,11 @@ std::vector<TaskInfoPtr> HcclKernel::GenTask(const std::vector<AddressPtr> &inpu
|
|||
workspace_addr = workspace.at(0)->addr;
|
||||
}
|
||||
|
||||
results.emplace_back(std::make_shared<HcclTaskInfo>(
|
||||
kernel_name_, stream_id, hccl::GetHcclType(anf_node), input_data_addr, output_data_addr, workspace_addr,
|
||||
task.workspace_size, task.stream_num, private_def, hccl::GetHcclOpsKernelInfoStore(), hccl_count_, root_id_,
|
||||
op_type_, data_type, group_, NeedDump()));
|
||||
results.emplace_back(
|
||||
std::make_shared<HcclTaskInfo>(kernel_name_, stream_id, hccl::HcclAdapter::GetHcclType(anf_node), input_data_addr,
|
||||
output_data_addr, workspace_addr, task.workspace_size, task.stream_num,
|
||||
private_def, hccl::HcclAdapter::GetInstance().GetHcclOpsKernelInfoStore(),
|
||||
hccl_count_, root_id_, op_type_, data_type, group_, NeedDump()));
|
||||
}
|
||||
|
||||
return results;
|
||||
|
|
|
@ -17,8 +17,7 @@
|
|||
#include "backend/kernel_compiler/hccl/hcom_all_broadcast.h"
|
||||
#include <memory>
|
||||
#include "utils/ms_context.h"
|
||||
#include "backend/kernel_compiler/hccl/hccl_context.h"
|
||||
#include "external/hccl/hccl.h"
|
||||
#include "runtime/hccl_adapter/hccl_adapter.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
|
@ -31,8 +30,8 @@ bool HcomAllBroadCastKernel::Launch(const std::vector<AddressPtr> &inputs,
|
|||
}
|
||||
MS_EXCEPTION_IF_NULL(inputs[0]);
|
||||
MS_EXCEPTION_IF_NULL(stream_ptr);
|
||||
auto hccl_result = HcclBroadcast(inputs[0]->addr, hccl_count_, hccl_data_type_list_[0], root_id_,
|
||||
HcclContext::GetInstance().hccl_comm(), stream_ptr);
|
||||
auto hccl_result = hccl::HcclAdapter::GetInstance().HcclBroadcast(inputs[0]->addr, hccl_count_,
|
||||
hccl_data_type_list_[0], root_id_, stream_ptr);
|
||||
if (hccl_result != HCCL_SUCCESS) {
|
||||
MS_LOG(ERROR) << "HcomBroadcastOp : hcom_broadcast fail, return: " << hccl_result;
|
||||
return false;
|
||||
|
|
|
@ -16,9 +16,7 @@
|
|||
|
||||
#include "backend/kernel_compiler/hccl/hcom_all_reduce.h"
|
||||
#include <memory>
|
||||
#include "utils/ms_context.h"
|
||||
#include "backend/kernel_compiler/hccl/hccl_context.h"
|
||||
#include "external/hccl/hccl.h"
|
||||
#include "runtime/hccl_adapter/hccl_adapter.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
|
@ -32,8 +30,8 @@ bool HcomAllReduceKernel::Launch(const std::vector<AddressPtr> &inputs, const st
|
|||
MS_EXCEPTION_IF_NULL(inputs[0]);
|
||||
MS_EXCEPTION_IF_NULL(outputs[0]);
|
||||
MS_EXCEPTION_IF_NULL(stream_ptr);
|
||||
auto hccl_result = HcclAllReduce(inputs[0]->addr, outputs[0]->addr, hccl_count_, hccl_data_type_list_[0], op_type_,
|
||||
HcclContext::GetInstance().hccl_comm(), stream_ptr);
|
||||
auto hccl_result = hccl::HcclAdapter::GetInstance().HcclAllReduce(inputs[0]->addr, outputs[0]->addr, hccl_count_,
|
||||
hccl_data_type_list_[0], op_type_, stream_ptr);
|
||||
if (hccl_result != HCCL_SUCCESS) {
|
||||
MS_LOG(ERROR) << "HcclAllReduce faled, ret:" << hccl_result;
|
||||
return false;
|
||||
|
|
|
@ -126,10 +126,6 @@ if(CMAKE_SYSTEM_NAME MATCHES "Linux")
|
|||
set(MINDSPORE_RPATH ${MINDSPORE_RPATH}:/usr/local/Ascend/nnae/latest/fwkacllib/lib64)
|
||||
set(MINDSPORE_RPATH ${MINDSPORE_RPATH}:/usr/local/Ascend/ascend-toolkit/latest/fwkacllib/lib64)
|
||||
set(MINDSPORE_RPATH ${MINDSPORE_RPATH}:/usr/local/Ascend/fwkacllib/lib64)
|
||||
set(MINDSPORE_RPATH ${MINDSPORE_RPATH}:/usr/local/Ascend/nnae/latest/fwkacllib/lib64/plugin/opskernel)
|
||||
set(MINDSPORE_RPATH ${MINDSPORE_RPATH}:/usr/local/Ascend/ascend-toolkit/latest/fwkacllib/lib64/plugin/opskernel)
|
||||
set(MINDSPORE_RPATH ${MINDSPORE_RPATH}:/usr/local/Ascend/fwkacllib/lib64/plugin/opskernel)
|
||||
set(MINDSPORE_RPATH ${MINDSPORE_RPATH}:/usr/local/Ascend/add-ons)
|
||||
set(MINDSPORE_RPATH ${MINDSPORE_RPATH}:/usr/local/Ascend/opp/op_impl/built-in/ai_core/tbe/op_tiling)
|
||||
set(MINDSPORE_RPATH ${MINDSPORE_RPATH}:/usr/local/Ascend/nnae/latest/opp/op_impl/built-in/ai_core/tbe/op_tiling)
|
||||
set(MINDSPORE_RPATH
|
||||
|
@ -141,7 +137,6 @@ if(CMAKE_SYSTEM_NAME MATCHES "Linux")
|
|||
set(MINDSPORE_RPATH ${MINDSPORE_RPATH}:/usr/local/Ascend/nnae/latest/acllib/lib64)
|
||||
set(MINDSPORE_RPATH ${MINDSPORE_RPATH}:/usr/local/Ascend/ascend-toolkit/latest/acllib/lib64)
|
||||
set(MINDSPORE_RPATH ${MINDSPORE_RPATH}:/usr/local/Ascend/acllib/lib64)
|
||||
set(MINDSPORE_RPATH ${MINDSPORE_RPATH}:/usr/local/Ascend/add-ons)
|
||||
set(MINDSPORE_RPATH ${MINDSPORE_RPATH}:/usr/local/Ascend/opp/op_impl/built-in/ai_core/tbe/op_tiling)
|
||||
set(MINDSPORE_RPATH ${MINDSPORE_RPATH}:/usr/local/Ascend/nnae/latest/opp/op_impl/built-in/ai_core/tbe/op_tiling)
|
||||
set(MINDSPORE_RPATH
|
||||
|
|
|
@ -22,12 +22,12 @@
|
|||
#include "external/hccl/hccl.h"
|
||||
#include "runtime/device/ascend/ascend_memory_pool.h"
|
||||
#include "backend/kernel_compiler/hccl/hcom_util.h"
|
||||
#include "backend/kernel_compiler/hccl/hccl_context.h"
|
||||
#include "runtime/device/memory_manager.h"
|
||||
#include "runtime/device/kernel_runtime_manager.h"
|
||||
#include "runtime/device/ascend/ascend_event.h"
|
||||
#include "runtime/device/ascend/ascend_launch_mul.h"
|
||||
#include "runtime/device/ascend/ascend_launch_atomic_clean.h"
|
||||
#include "runtime/hccl_adapter/hccl_adapter.h"
|
||||
#include "utils/profile.h"
|
||||
|
||||
#define CHECK_ASCEND_RT_WITH_EXCEPTION(expression, message) \
|
||||
|
@ -147,8 +147,8 @@ void AscendBucket::LaunchAllReduce() {
|
|||
auto hccl_count = total_size_ / type_size;
|
||||
|
||||
HcclReduceOp op_type = HcclReduceOp::HCCL_REDUCE_SUM;
|
||||
auto hccl_result = HcclAllReduce(ar_input_addr_, ar_output_addr_, hccl_count, iter->second, op_type,
|
||||
kernel::HcclContext::GetInstance().hccl_comm(), stream_);
|
||||
auto hccl_result = hccl::HcclAdapter::GetInstance().HcclAllReduce(ar_input_addr_, ar_output_addr_, hccl_count,
|
||||
iter->second, op_type, stream_);
|
||||
if (hccl_result != HCCL_SUCCESS) {
|
||||
MS_LOG(EXCEPTION) << "HcclAllReduce faled, ret:" << hccl_result;
|
||||
}
|
||||
|
|
|
@ -51,7 +51,6 @@
|
|||
#include "runtime/device/ascend/profiling/reporter/op_name_task_stream_reporter.h"
|
||||
#include "runtime/hccl_adapter/hccl_adapter.h"
|
||||
#include "runtime/device/ascend/profiling/profiling_callback_register.h"
|
||||
#include "backend/kernel_compiler/hccl/hccl_context.h"
|
||||
#ifdef ENABLE_TDTQUE
|
||||
#include "minddata/dataset/engine/tdt/tdt_handle.h"
|
||||
using mindspore::dataset::TdtHandle;
|
||||
|
@ -260,7 +259,6 @@ void AscendKernelRuntime::ReleaseDeviceRes() {
|
|||
MS_LOG(EXCEPTION) << "Reg SetTaskFailCallback failed, error: " << rt_ret;
|
||||
}
|
||||
|
||||
(void)DestroySingleOpHccl();
|
||||
(void)DestroyHccl();
|
||||
(void)ResetDevice(device_id);
|
||||
(void)ProfilingManager::GetInstance().StopProfiling();
|
||||
|
@ -340,11 +338,7 @@ bool AscendKernelRuntime::Load(session::KernelGraph *graph, bool is_task_sink) {
|
|||
GenKernelEvents(graph);
|
||||
return true;
|
||||
}
|
||||
// Do HcomExecutorInitialize
|
||||
if (graph->is_dynamic_shape() && !HcclExecutorManager::GetInstance().Initialize()) {
|
||||
MS_LOG(ERROR) << "Init Hccl Executor Failed";
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!GenTask(graph)) {
|
||||
return false;
|
||||
}
|
||||
|
@ -837,25 +831,13 @@ bool AscendKernelRuntime::HcclInit() {
|
|||
return false;
|
||||
}
|
||||
MS_LOG(INFO) << "MINDSPORE_HCCL_CONFIG_PATH : " << full_path << ", RANK_ID: " << rank_id_str;
|
||||
bool ret = hccl::InitHccl(context_ptr->get_param<uint32_t>(MS_CTX_DEVICE_ID), rank_id_str, full_path);
|
||||
bool ret = hccl::HcclAdapter::GetInstance().InitHccl(context_ptr->get_param<uint32_t>(MS_CTX_DEVICE_ID), rank_id_str,
|
||||
full_path);
|
||||
free(full_path);
|
||||
if (!ret) {
|
||||
MS_LOG(ERROR) << "Hcom init failed.";
|
||||
return false;
|
||||
}
|
||||
auto task_sink = context_ptr->get_param<bool>(MS_CTX_ENABLE_TASK_SINK);
|
||||
if (context_ptr->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode || !task_sink) {
|
||||
MS_LOG(INFO) << "Hccl comm init.";
|
||||
return kernel::HcclContext::GetInstance().InitHccl();
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool AscendKernelRuntime::DestroySingleOpHccl() {
|
||||
if (!kernel::HcclContext::GetInstance().Finalize()) {
|
||||
MS_LOG(ERROR) << "Hccl finalize failed";
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
|
@ -869,11 +851,7 @@ bool AscendKernelRuntime::DestroyHccl() {
|
|||
MS_LOG(INFO) << "Hccl is not enable, no need to close.";
|
||||
return true;
|
||||
}
|
||||
// Dynamic Shape Hccl Finalize
|
||||
if (!HcclExecutorManager::GetInstance().Finalize()) {
|
||||
MS_LOG(ERROR) << "Dynamic Shape Hccl Finalize Failed";
|
||||
}
|
||||
bool res = hccl::FinalizeHccl();
|
||||
bool res = hccl::HcclAdapter::GetInstance().FinalizeHccl();
|
||||
if (!res) {
|
||||
MS_LOG(ERROR) << "Hccl destroy failed";
|
||||
return false;
|
||||
|
|
|
@ -76,7 +76,6 @@ class AscendKernelRuntime : public KernelRuntime {
|
|||
static bool HcclInit();
|
||||
static bool NeedDestroyHccl();
|
||||
static bool DestroyHccl();
|
||||
static bool DestroySingleOpHccl();
|
||||
void SetCurrentContext();
|
||||
|
||||
void ClearGraphModelMap();
|
||||
|
|
|
@ -23,6 +23,7 @@
|
|||
#include "utils/log_adapter.h"
|
||||
#include "runtime/device/kernel_runtime.h"
|
||||
#include "backend/kernel_compiler/hccl/hcom_util.h"
|
||||
#include "runtime/hccl_adapter/hccl_adapter.h"
|
||||
|
||||
namespace {
|
||||
// Find so in RPATH or LD_LIBRARY_PATH (/usr/local/Ascend/fwkacllib/lib64/)
|
||||
|
@ -90,23 +91,12 @@ void HcclDynamicKernel::StaticShapeExecute() {
|
|||
|
||||
void HcclDynamicKernel::Execute() {
|
||||
MS_LOG(INFO) << "Start Execute";
|
||||
|
||||
auto EnqueueHcomOperation =
|
||||
(HcclResult(*)(ge::HcomOpertion, std::function<void(HcclResult status)>))HcclExecutorManager::GetInstance()
|
||||
.GetHcomOpertion();
|
||||
if (EnqueueHcomOperation == nullptr) {
|
||||
MS_LOG(ERROR) << "Failed to get EnqueueHcomOperation function";
|
||||
HcclExecutorManager::GetInstance().CloseHandle();
|
||||
MS_LOG(EXCEPTION) << "Hccl dynamic kernel execute failed";
|
||||
return;
|
||||
}
|
||||
|
||||
ge::HcomOpertion op_info;
|
||||
::HcomOperation op_info;
|
||||
op_info.hcclType = hccl_type_;
|
||||
op_info.inputPtr = input_ptr_;
|
||||
op_info.outputPtr = output_ptr_;
|
||||
op_info.dataType = data_type_;
|
||||
op_info.opType = op_type_;
|
||||
op_info.dataType = static_cast<HcclDataType>(data_type_);
|
||||
op_info.opType = static_cast<HcclReduceOp>(op_type_);
|
||||
op_info.root = root_;
|
||||
op_info.count = count_;
|
||||
|
||||
|
@ -119,7 +109,7 @@ void HcclDynamicKernel::Execute() {
|
|||
MS_LOG(INFO) << "hccl callback success.";
|
||||
};
|
||||
|
||||
auto hccl_ret = EnqueueHcomOperation(op_info, callback);
|
||||
auto hccl_ret = hccl::HcclAdapter::GetInstance().HcclExecEnqueueOp(op_info, callback);
|
||||
if (hccl_ret != HCCL_SUCCESS) {
|
||||
MS_LOG(EXCEPTION) << "Call EnqueueHcomOperation failed";
|
||||
}
|
||||
|
@ -130,70 +120,6 @@ void HcclDynamicKernel::Execute() {
|
|||
}
|
||||
|
||||
void HcclDynamicKernel::PostExecute() {}
|
||||
|
||||
bool HcclExecutorManager::Initialize() {
|
||||
if (initialized_) {
|
||||
return true;
|
||||
}
|
||||
auto context = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context);
|
||||
if (!context->get_param<bool>(MS_CTX_ENABLE_HCCL)) {
|
||||
return true;
|
||||
}
|
||||
initialized_ = true;
|
||||
MS_LOG(INFO) << "Start Initialize Hccl DynamicKernel";
|
||||
handle_ = dlopen(kHcomGraphAdaptorPath, RTLD_NOW | RTLD_GLOBAL);
|
||||
if (handle_ == nullptr) {
|
||||
MS_LOG(ERROR) << "dlopen failed, path:" << kHcomGraphAdaptorPath;
|
||||
return false;
|
||||
}
|
||||
|
||||
auto HcomExecutorInitialize = (HcclResult(*)())dlsym(handle_, "HcomExecInitialize");
|
||||
if (HcomExecutorInitialize == nullptr) {
|
||||
MS_LOG(ERROR) << "dlsym HcomExecutorInitialize failed";
|
||||
return false;
|
||||
}
|
||||
|
||||
HcclResult hccl_ret = HcomExecutorInitialize();
|
||||
if (hccl_ret == HCCL_E_PTR) {
|
||||
MS_LOG(WARNING) << "Hccl comm is null, hcom executor initialize is not required";
|
||||
} else if (hccl_ret == HCCL_SUCCESS) {
|
||||
MS_LOG(INFO) << "Hcom DynamicKernel Initialize success";
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Hcom DynamicKernel Initialize failed";
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool HcclExecutorManager::Finalize() {
|
||||
if (!initialized_) {
|
||||
return true;
|
||||
}
|
||||
auto HcomExecutorFinalize = (HcclResult(*)())dlsym(handle_, "HcomExecFinalize");
|
||||
if (HcomExecutorFinalize == nullptr) {
|
||||
MS_LOG(ERROR) << "Fail to dlsym HcomExecutorFinalize";
|
||||
return false;
|
||||
}
|
||||
HcclResult hccl_ret = HcomExecutorFinalize();
|
||||
if (hccl_ret != HCCL_SUCCESS) {
|
||||
MS_LOG(ERROR) << "Hcom DynamicKernel Finalize failed";
|
||||
return false;
|
||||
}
|
||||
if (dlclose(handle_) != 0) {
|
||||
MS_LOG(ERROR) << "Failed to close hcom handle";
|
||||
return false;
|
||||
}
|
||||
MS_LOG(INFO) << "Hccl DynamicKernel Finalize success";
|
||||
return true;
|
||||
}
|
||||
|
||||
void *HcclExecutorManager::GetHcomOpertion() { return dlsym(handle_, "HcomExecEnqueueOperation"); }
|
||||
void HcclExecutorManager::CloseHandle() {
|
||||
if (dlclose(handle_) != 0) {
|
||||
MS_LOG(WARNING) << "Failed to close hcom handle";
|
||||
}
|
||||
}
|
||||
} // namespace ascend
|
||||
} // namespace device
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -56,27 +56,6 @@ class HcclDynamicKernel : public DynamicKernel {
|
|||
|
||||
void StaticShapeExecute();
|
||||
};
|
||||
|
||||
class HcclExecutorManager {
|
||||
public:
|
||||
static HcclExecutorManager &GetInstance() {
|
||||
static HcclExecutorManager instance;
|
||||
return instance;
|
||||
}
|
||||
|
||||
bool Initialize();
|
||||
bool Finalize();
|
||||
void *GetHcomOpertion();
|
||||
void CloseHandle();
|
||||
|
||||
private:
|
||||
HcclExecutorManager() = default;
|
||||
~HcclExecutorManager() = default;
|
||||
DISABLE_COPY_AND_ASSIGN(HcclExecutorManager);
|
||||
|
||||
void *handle_{nullptr};
|
||||
bool initialized_{false};
|
||||
};
|
||||
} // namespace ascend
|
||||
} // namespace device
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -1,8 +1,9 @@
|
|||
file(GLOB_RECURSE HCCL_ADAPTER_SRC_LIST ./*.cc)
|
||||
file(GLOB HCCL_ADAPTER_SRC_LIST ./*.cc)
|
||||
set_property(SOURCE ${HCCL_ADAPTER_SRC_LIST} PROPERTY COMPILE_DEFINITIONS
|
||||
SUBMODULE_ID=mindspore::SubModuleId::SM_HCCL_ADPT)
|
||||
if(ENABLE_D)
|
||||
add_library(_mindspore_runtime_hccl_adapter_obj OBJECT ${HCCL_ADAPTER_SRC_LIST})
|
||||
target_include_directories(_mindspore_runtime_hccl_adapter_obj PRIVATE ${CMAKE_BINARY_DIR}/proto/ge)
|
||||
add_dependencies(_mindspore_runtime_hccl_adapter_obj graph)
|
||||
add_subdirectory(plugin)
|
||||
endif()
|
|
@ -14,29 +14,44 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
#include "runtime/hccl_adapter/hccl_adapter.h"
|
||||
#include <dlfcn.h>
|
||||
#include <map>
|
||||
#include <algorithm>
|
||||
#define google ascend_private
|
||||
#include "register/ops_kernel_builder_registry.h"
|
||||
#include "common/opskernel/ops_kernel_info_store.h"
|
||||
#include "common/opskernel/ops_kernel_builder.h"
|
||||
#include "external/ge/ge_api_types.h"
|
||||
#undef google
|
||||
#include "hccl/hccl.h"
|
||||
#include "hccl/hcom.h"
|
||||
#include "utils/log_adapter.h"
|
||||
#include "utils/ms_utils.h"
|
||||
#include "runtime/hccl_adapter/converter.h"
|
||||
#include "runtime/hccl_adapter/hcom_graph_adaptor.h"
|
||||
|
||||
static constexpr const char *kHcclPluginFileName = "libhccl_plugin.so";
|
||||
static constexpr const char *kHcclOpsKernelInfoStore = "ops_kernel_info_hccl";
|
||||
static constexpr const char *kHcclDeployModeEnv = "DEPLOY_MODE";
|
||||
static constexpr const char *kHcclAlgoEnv = "HCCL_ALGO";
|
||||
// following global var, thread safety is not guaranteed
|
||||
static std::shared_ptr<ge::OpsKernelInfoStore> ops_kernel_info_store = nullptr;
|
||||
static ge::OpsKernelBuilderPtr ops_kernel_builder = nullptr;
|
||||
|
||||
namespace mindspore::hccl {
|
||||
inline static std::string GetDlErrorMsg() {
|
||||
const char *result = dlerror();
|
||||
return (result == nullptr) ? "Unknown" : result;
|
||||
}
|
||||
|
||||
template <class T>
|
||||
static T DlsymWithCast(void *handle, const char *symbol_name) {
|
||||
T symbol = reinterpret_cast<T>(dlsym(handle, symbol_name));
|
||||
if (symbol == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Dlsym symbol " << symbol_name << " failed, result = " << GetDlErrorMsg();
|
||||
}
|
||||
return symbol;
|
||||
}
|
||||
|
||||
#define DlsymFuncObj(handle, func_name) DlsymWithCast<func_name##FunPtr>(handle, k##func_name##Name);
|
||||
|
||||
static std::map<std::string, std::string> GenHcclOptions(uint32_t device_id, std::string_view rank_id,
|
||||
std::string_view rank_file) {
|
||||
auto env_deploy_mode = common::GetEnv(kHcclDeployModeEnv);
|
||||
auto env_deploy_mode = mindspore::common::GetEnv(kHcclDeployModeEnv);
|
||||
if (env_deploy_mode.empty()) {
|
||||
MS_LOG(WARNING) << kHcclDeployModeEnv << " is not set in ENV. Now set to default value 0";
|
||||
env_deploy_mode = "0";
|
||||
|
@ -53,7 +68,7 @@ static std::map<std::string, std::string> GenHcclOptions(uint32_t device_id, std
|
|||
{ge::OPTION_EXEC_HCCL_FLAG, "1"},
|
||||
{ge::OPTION_EXEC_DEPLOY_MODE, env_deploy_mode}};
|
||||
|
||||
auto env_hccl_algo = common::GetEnv(kHcclAlgoEnv);
|
||||
auto env_hccl_algo = mindspore::common::GetEnv(kHcclAlgoEnv);
|
||||
if (!env_hccl_algo.empty()) {
|
||||
std::string ge_hccl_algo = "HCCL_algorithm";
|
||||
default_options_map.emplace(ge_hccl_algo, env_hccl_algo);
|
||||
|
@ -62,74 +77,111 @@ static std::map<std::string, std::string> GenHcclOptions(uint32_t device_id, std
|
|||
return default_options_map;
|
||||
}
|
||||
|
||||
bool InitHccl(uint32_t device_id, std::string_view rank_id, std::string_view rank_file) {
|
||||
namespace mindspore::hccl {
|
||||
HcclAdapter &HcclAdapter::GetInstance() {
|
||||
static HcclAdapter instance;
|
||||
return instance;
|
||||
}
|
||||
|
||||
void HcclAdapter::InitPlugin() {
|
||||
if (plugin_handle_ != nullptr) {
|
||||
return;
|
||||
}
|
||||
|
||||
plugin_handle_ = dlopen(kHcclPluginFileName, RTLD_NOW | RTLD_GLOBAL);
|
||||
if (plugin_handle_ == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Dlopen " << kHcclPluginFileName << " failed, result = " << GetDlErrorMsg();
|
||||
}
|
||||
|
||||
init_hcom_graph_adapter_ = DlsymFuncObj(plugin_handle_, InitHcomGraphAdapter);
|
||||
finalize_hcom_graph_adapter_ = DlsymFuncObj(plugin_handle_, FinalizeHcomGraphAdapter);
|
||||
get_hccl_kernel_info_store_ = DlsymFuncObj(plugin_handle_, GetHcclKernelInfoStore);
|
||||
get_all_kernel_builder_ = DlsymFuncObj(plugin_handle_, GetAllKernelBuilder);
|
||||
init_hccl_comm_ = DlsymFuncObj(plugin_handle_, InitHcclComm);
|
||||
finalize_hccl_comm_ = DlsymFuncObj(plugin_handle_, FinalizeHcclComm);
|
||||
launch_hccl_broadcast_ = DlsymFuncObj(plugin_handle_, LaunchHcclBroadcast);
|
||||
launch_hccl_all_reduce_ = DlsymFuncObj(plugin_handle_, LaunchHcclAllReduce);
|
||||
hccl_create_group_ = DlsymFuncObj(plugin_handle_, HcclCreateGroup);
|
||||
hccl_destroy_group_ = DlsymFuncObj(plugin_handle_, HcclDestroyGroup);
|
||||
hccl_get_rank_id_ = DlsymFuncObj(plugin_handle_, HcclGetRankId);
|
||||
hccl_get_rank_size_ = DlsymFuncObj(plugin_handle_, HcclGetRankSize);
|
||||
hccl_exec_initialize_ = DlsymFuncObj(plugin_handle_, HcclExecInitialize);
|
||||
hccl_exec_finalize_ = DlsymFuncObj(plugin_handle_, HcclExecFinalize);
|
||||
hccl_exec_enqueue_op_ = DlsymFuncObj(plugin_handle_, HcclExecEnqueueOp);
|
||||
}
|
||||
|
||||
void HcclAdapter::FinalizePlugin() {
|
||||
if (plugin_handle_ == nullptr) {
|
||||
return;
|
||||
}
|
||||
|
||||
init_hcom_graph_adapter_ = nullptr;
|
||||
finalize_hcom_graph_adapter_ = nullptr;
|
||||
get_hccl_kernel_info_store_ = nullptr;
|
||||
get_all_kernel_builder_ = nullptr;
|
||||
init_hccl_comm_ = nullptr;
|
||||
finalize_hccl_comm_ = nullptr;
|
||||
launch_hccl_broadcast_ = nullptr;
|
||||
launch_hccl_all_reduce_ = nullptr;
|
||||
hccl_create_group_ = nullptr;
|
||||
hccl_destroy_group_ = nullptr;
|
||||
hccl_get_rank_id_ = nullptr;
|
||||
hccl_get_rank_size_ = nullptr;
|
||||
hccl_exec_initialize_ = nullptr;
|
||||
hccl_exec_finalize_ = nullptr;
|
||||
hccl_exec_enqueue_op_ = nullptr;
|
||||
(void)dlclose(plugin_handle_);
|
||||
plugin_handle_ = nullptr;
|
||||
}
|
||||
|
||||
bool HcclAdapter::InitHccl(uint32_t device_id, std::string_view rank_id, std::string_view rank_file) {
|
||||
MS_LOG(INFO) << "Start init hccl adapter.";
|
||||
// get ops_kernel_builder
|
||||
std::map<std::string, ge::OpsKernelBuilderPtr> all_builders = ge::OpsKernelBuilderRegistry::GetInstance().GetAll();
|
||||
if (all_builders.size() != 1) {
|
||||
MS_LOG(EXCEPTION) << "Builders size should be 1 (hccl builder), but is " << all_builders.size();
|
||||
std::lock_guard<std::mutex> lock(init_mutex_);
|
||||
if (init_flag_) {
|
||||
MS_LOG(INFO) << "Hccl has been inited, skip.";
|
||||
return true;
|
||||
}
|
||||
|
||||
MS_LOG(INFO) << "Get builder " << all_builders.begin()->first;
|
||||
ops_kernel_builder = all_builders.begin()->second;
|
||||
MS_EXCEPTION_IF_NULL(ops_kernel_builder);
|
||||
// init ops_kernel_builder
|
||||
auto options = GenHcclOptions(device_id, rank_id, rank_file);
|
||||
auto ret = ops_kernel_builder->Initialize(options);
|
||||
if (ret != ge::SUCCESS) {
|
||||
MS_LOG(EXCEPTION) << "Init builder failed, ret = " << ret;
|
||||
InitPlugin();
|
||||
bool ret = InitKernelInfoStore(device_id, rank_id, rank_file);
|
||||
if (!ret) {
|
||||
return false;
|
||||
}
|
||||
ret = InitHcclComm(rank_id, rank_file);
|
||||
if (!ret) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// get ops_kernel_info_store
|
||||
ret = ::Initialize(options);
|
||||
if (ret != ge::SUCCESS) {
|
||||
MS_LOG(EXCEPTION) << "Init plugin so failed, ret = " << ret;
|
||||
ret = InitHcclExec();
|
||||
if (!ret) {
|
||||
return false;
|
||||
}
|
||||
|
||||
std::map<std::string, std::shared_ptr<ge::OpsKernelInfoStore>> all_ops_kernel_info_stores;
|
||||
::GetOpsKernelInfoStores(all_ops_kernel_info_stores);
|
||||
for (auto &[name, ptr] : all_ops_kernel_info_stores) {
|
||||
if (name == kHcclOpsKernelInfoStore) {
|
||||
ops_kernel_info_store = ptr;
|
||||
break;
|
||||
}
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(ops_kernel_info_store);
|
||||
ret = ops_kernel_info_store->Initialize(options);
|
||||
if (ret != ge::SUCCESS) {
|
||||
MS_LOG(EXCEPTION) << "Init info store failed, ret = " << ret;
|
||||
}
|
||||
init_flag_ = true;
|
||||
MS_LOG(INFO) << "Init hccl adapter success.";
|
||||
return true;
|
||||
}
|
||||
|
||||
bool FinalizeHccl() {
|
||||
bool HcclAdapter::FinalizeHccl() {
|
||||
MS_LOG(INFO) << "Start destroy hccl adapter.";
|
||||
if (ops_kernel_info_store != nullptr) {
|
||||
auto ret = ops_kernel_info_store->Finalize();
|
||||
if (ret != ge::SUCCESS) {
|
||||
MS_LOG(ERROR) << "Destroy info store failed, ret = " << ret;
|
||||
return false;
|
||||
}
|
||||
std::lock_guard<std::mutex> lock(init_mutex_);
|
||||
if (!init_flag_) {
|
||||
MS_LOG(INFO) << "Hccl has never been inited, skip.";
|
||||
return true;
|
||||
}
|
||||
|
||||
if (ops_kernel_builder != nullptr) {
|
||||
auto ret = ops_kernel_builder->Finalize();
|
||||
if (ret != ge::SUCCESS) {
|
||||
MS_LOG(ERROR) << "Destroy builder failed, ret = " << ret;
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
::Finalize();
|
||||
ops_kernel_info_store.reset();
|
||||
ops_kernel_builder.reset();
|
||||
(void)FinalizeHcclExec();
|
||||
(void)FinalizeHcclComm();
|
||||
(void)FinalizeKernelInfoStore();
|
||||
FinalizePlugin();
|
||||
init_flag_ = false;
|
||||
MS_LOG(INFO) << "Destroy hccl adapter success.";
|
||||
return true;
|
||||
}
|
||||
|
||||
bool GenTask(const AnfNodePtr &node, HcclDataType datatype, std::vector<HcclTaskInfo> *task_info_lists) {
|
||||
MS_EXCEPTION_IF_NULL(ops_kernel_builder);
|
||||
bool HcclAdapter::GenTask(const AnfNodePtr &node, HcclDataType datatype,
|
||||
std::vector<HcclTaskInfo> *task_info_lists) const {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
MS_EXCEPTION_IF_NULL(task_info_lists);
|
||||
MS_LOG(INFO) << "Start generate task for hccl node " << node->DebugString();
|
||||
auto [ge_node, ge_graph] = GenerateStubGeNode(node, datatype);
|
||||
|
@ -138,7 +190,8 @@ bool GenTask(const AnfNodePtr &node, HcclDataType datatype, std::vector<HcclTask
|
|||
MS_EXCEPTION_IF_NULL(op);
|
||||
|
||||
MS_LOG(INFO) << "Start to call CalcOpRunningParam";
|
||||
ge::Status ret = ops_kernel_builder->CalcOpRunningParam(*ge_node);
|
||||
MS_EXCEPTION_IF_NULL(ops_kernel_builder_);
|
||||
ge::Status ret = ops_kernel_builder_->CalcOpRunningParam(*ge_node);
|
||||
if (ret != ge::SUCCESS) {
|
||||
MS_LOG(ERROR) << "OpsKernelBuilder CalcOpRunningParam failed, ret = " << ret;
|
||||
return false;
|
||||
|
@ -146,7 +199,7 @@ bool GenTask(const AnfNodePtr &node, HcclDataType datatype, std::vector<HcclTask
|
|||
MS_LOG(INFO) << "Start to call GenerateTask";
|
||||
ge::RunContext unused_ctx;
|
||||
std::vector<domi::TaskDef> domi_tasks;
|
||||
ret = ops_kernel_builder->GenerateTask(*ge_node, unused_ctx, domi_tasks);
|
||||
ret = ops_kernel_builder_->GenerateTask(*ge_node, unused_ctx, domi_tasks);
|
||||
if (ret != ge::SUCCESS) {
|
||||
MS_LOG(ERROR) << "OpsKernelBuilder GenerateTask failed, ret = " << ret;
|
||||
return false;
|
||||
|
@ -160,8 +213,8 @@ bool GenTask(const AnfNodePtr &node, HcclDataType datatype, std::vector<HcclTask
|
|||
return true;
|
||||
}
|
||||
|
||||
int64_t CalcWorkspaceSize(const AnfNodePtr &node, HcclDataType datatype) {
|
||||
MS_EXCEPTION_IF_NULL(ops_kernel_builder);
|
||||
int64_t HcclAdapter::CalcWorkspaceSize(const AnfNodePtr &node, HcclDataType datatype) const {
|
||||
MS_EXCEPTION_IF_NULL(ops_kernel_builder_);
|
||||
MS_LOG(INFO) << "Start calc workspace size for hccl node " << node->DebugString() << " ,dtype is " << datatype;
|
||||
auto [ge_node, ge_graph] = GenerateStubGeNode(node, datatype);
|
||||
MS_EXCEPTION_IF_NULL(ge_node);
|
||||
|
@ -169,7 +222,7 @@ int64_t CalcWorkspaceSize(const AnfNodePtr &node, HcclDataType datatype) {
|
|||
MS_EXCEPTION_IF_NULL(op);
|
||||
|
||||
MS_LOG(INFO) << "Start to call CalcOpRunningParam";
|
||||
ge::Status ret = ops_kernel_builder->CalcOpRunningParam(*ge_node);
|
||||
ge::Status ret = ops_kernel_builder_->CalcOpRunningParam(*ge_node);
|
||||
if (ret != ge::SUCCESS) {
|
||||
MS_LOG(ERROR) << "OpsKernelBuilder CalcOpRunningParam failed, ret = " << ret;
|
||||
return false;
|
||||
|
@ -185,12 +238,179 @@ int64_t CalcWorkspaceSize(const AnfNodePtr &node, HcclDataType datatype) {
|
|||
return workspace_size;
|
||||
}
|
||||
|
||||
void *GetHcclOpsKernelInfoStore() { return ops_kernel_info_store.get(); }
|
||||
void *HcclAdapter::GetHcclOpsKernelInfoStore() const { return ops_kernel_info_store_.get(); }
|
||||
|
||||
std::string GetHcclType(const AnfNodePtr &node) {
|
||||
std::string HcclAdapter::GetHcclType(const AnfNodePtr &node) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
return GetGeNodeName(cnode);
|
||||
}
|
||||
|
||||
// Launches a broadcast through the launch_hccl_broadcast_ entry point, passing the adapter's own communicator
// (hccl_comm_) together with the caller-supplied buffer, element count, data type, root rank and stream.
// The entry point is null-checked first.
HcclResult HcclAdapter::HcclBroadcast(void *buf, uint64_t count, HcclDataType dataType, uint32_t root,
                                      aclrtStream stream) const {
  MS_EXCEPTION_IF_NULL(launch_hccl_broadcast_);
  const HcclResult launch_ret = launch_hccl_broadcast_(buf, count, dataType, root, hccl_comm_, stream);
  return launch_ret;
}
|
||||
|
||||
// Launches an all-reduce through the launch_hccl_all_reduce_ entry point, passing the adapter's own communicator
// (hccl_comm_) together with the caller-supplied buffers, element count, data type, reduce op and stream.
// The entry point is null-checked first.
HcclResult HcclAdapter::HcclAllReduce(void *sendBuf, void *recvBuf, uint64_t count, HcclDataType dataType,
                                      HcclReduceOp op, aclrtStream stream) const {
  MS_EXCEPTION_IF_NULL(launch_hccl_all_reduce_);
  const HcclResult launch_ret = launch_hccl_all_reduce_(sendBuf, recvBuf, count, dataType, op, hccl_comm_, stream);
  return launch_ret;
}
|
||||
|
||||
bool HcclAdapter::InitKernelInfoStore(uint32_t device_id, std::string_view rank_id, std::string_view rank_file) {
|
||||
MS_LOG(INFO) << "Start init hccl kernel info store.";
|
||||
MS_EXCEPTION_IF_NULL(init_hcom_graph_adapter_);
|
||||
MS_EXCEPTION_IF_NULL(get_hccl_kernel_info_store_);
|
||||
// get ops_kernel_builder
|
||||
std::map<std::string, std::shared_ptr<ge::OpsKernelBuilder>> all_builders;
|
||||
get_all_kernel_builder_(&all_builders);
|
||||
if (all_builders.size() != 1) {
|
||||
MS_LOG(EXCEPTION) << "Builders size should be 1 (hccl builder), but is " << all_builders.size();
|
||||
}
|
||||
|
||||
MS_LOG(INFO) << "Get builder " << all_builders.begin()->first;
|
||||
ops_kernel_builder_ = all_builders.begin()->second;
|
||||
MS_EXCEPTION_IF_NULL(ops_kernel_builder_);
|
||||
// init ops_kernel_builder
|
||||
auto options = GenHcclOptions(device_id, rank_id, rank_file);
|
||||
auto ret = ops_kernel_builder_->Initialize(options);
|
||||
if (ret != ge::SUCCESS) {
|
||||
MS_LOG(EXCEPTION) << "Init hccl kernel builder failed, ret = " << ret;
|
||||
}
|
||||
|
||||
// get ops_kernel_info_store
|
||||
ret = init_hcom_graph_adapter_(options);
|
||||
if (ret != ge::SUCCESS) {
|
||||
MS_LOG(EXCEPTION) << "Init hccl graph adapter failed, ret = " << ret;
|
||||
}
|
||||
|
||||
get_hccl_kernel_info_store_(&ops_kernel_info_store_);
|
||||
MS_EXCEPTION_IF_NULL(ops_kernel_info_store_);
|
||||
ret = ops_kernel_info_store_->Initialize(options);
|
||||
if (ret != ge::SUCCESS) {
|
||||
MS_LOG(EXCEPTION) << "Init info store failed, ret = " << ret;
|
||||
}
|
||||
MS_LOG(INFO) << "Init hccl kernel info store success.";
|
||||
return true;
|
||||
}
|
||||
|
||||
bool HcclAdapter::FinalizeKernelInfoStore() {
|
||||
MS_LOG(INFO) << "Start destroy hccl kernel info store.";
|
||||
if (ops_kernel_info_store_ != nullptr) {
|
||||
auto ret = ops_kernel_info_store_->Finalize();
|
||||
if (ret != ge::SUCCESS) {
|
||||
MS_LOG(ERROR) << "Destroy info store failed, ret = " << ret;
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if (ops_kernel_builder_ != nullptr) {
|
||||
auto ret = ops_kernel_builder_->Finalize();
|
||||
if (ret != ge::SUCCESS) {
|
||||
MS_LOG(ERROR) << "Destroy builder failed, ret = " << ret;
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
MS_EXCEPTION_IF_NULL(finalize_hcom_graph_adapter_);
|
||||
finalize_hcom_graph_adapter_();
|
||||
ops_kernel_info_store_.reset();
|
||||
ops_kernel_builder_.reset();
|
||||
MS_LOG(INFO) << "Destroy hccl kernel info store success.";
|
||||
return true;
|
||||
}
|
||||
|
||||
bool HcclAdapter::InitHcclComm(std::string_view rank_id, std::string_view rank_file) {
|
||||
MS_LOG(INFO) << "Start init hccl comm.";
|
||||
int rank_id_i = -1;
|
||||
try {
|
||||
rank_id_i = std::stoi(rank_id.data());
|
||||
} catch (std::invalid_argument &) {
|
||||
MS_LOG(EXCEPTION) << "Invalid rank id env:" << rank_id;
|
||||
}
|
||||
if (rank_id_i < 0 || rank_id_i > 7) {
|
||||
MS_LOG(ERROR) << "rank_id needs to be between 0-7";
|
||||
return false;
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(init_hccl_comm_);
|
||||
auto hccl_result = init_hccl_comm_(rank_file.data(), rank_id_i, &hccl_comm_);
|
||||
if (hccl_result != HCCL_SUCCESS) {
|
||||
MS_LOG(ERROR) << "HcclCommInitClusterInfo failed, ret:" << hccl_result;
|
||||
return false;
|
||||
}
|
||||
MS_LOG(INFO) << "InitHcclComm success";
|
||||
return true;
|
||||
}
|
||||
|
||||
bool HcclAdapter::FinalizeHcclComm() {
|
||||
MS_LOG(INFO) << "Start finalize hccl comm.";
|
||||
if (hccl_comm_ == nullptr) {
|
||||
return true;
|
||||
}
|
||||
|
||||
MS_EXCEPTION_IF_NULL(finalize_hccl_comm_);
|
||||
auto hccl_result = finalize_hccl_comm_(hccl_comm_);
|
||||
if (hccl_result != HCCL_SUCCESS) {
|
||||
MS_LOG(ERROR) << "HcclComm destroy failed, ret:" << hccl_result;
|
||||
return false;
|
||||
}
|
||||
hccl_comm_ = nullptr;
|
||||
MS_LOG(INFO) << "HcclComm destroy success";
|
||||
return true;
|
||||
}
|
||||
|
||||
// Creates a communication group by forwarding to the hccl_create_group_ entry point (null-checked first).
// The group name is passed on as a C string; rank_num and rank_ids are forwarded unchanged.
HcclResult HcclAdapter::HcclCreateGroup(const std::string &group, uint32_t rank_num, uint32_t *rank_ids) const {
  MS_EXCEPTION_IF_NULL(hccl_create_group_);
  const char *group_name = group.c_str();
  return hccl_create_group_(group_name, rank_num, rank_ids);
}
|
||||
|
||||
// Destroys a communication group by forwarding its name to the hccl_destroy_group_ entry point
// (null-checked first).
HcclResult HcclAdapter::HcclDestroyGroup(const std::string &group) const {
  MS_EXCEPTION_IF_NULL(hccl_destroy_group_);
  const char *group_name = group.c_str();
  return hccl_destroy_group_(group_name);
}
|
||||
|
||||
// Queries the rank id within the given group by forwarding to the hccl_get_rank_id_ entry point
// (null-checked first). The result is written through rank_id by the callee.
HcclResult HcclAdapter::HcclGetRankId(const std::string &group, uint32_t *rank_id) const {
  MS_EXCEPTION_IF_NULL(hccl_get_rank_id_);
  const char *group_name = group.c_str();
  return hccl_get_rank_id_(group_name, rank_id);
}
|
||||
|
||||
// Queries the size of the given group by forwarding to the hccl_get_rank_size_ entry point
// (null-checked first). The result is written through rank_size by the callee.
HcclResult HcclAdapter::HcclGetRankSize(const std::string &group, uint32_t *rank_size) const {
  MS_EXCEPTION_IF_NULL(hccl_get_rank_size_);
  const char *group_name = group.c_str();
  return hccl_get_rank_size_(group_name, rank_size);
}
|
||||
|
||||
bool HcclAdapter::InitHcclExec() {
|
||||
MS_LOG(INFO) << "Start init hccl exec.";
|
||||
MS_EXCEPTION_IF_NULL(hccl_exec_initialize_);
|
||||
HcclResult hccl_ret = hccl_exec_initialize_();
|
||||
if (hccl_ret == HCCL_E_PTR) {
|
||||
MS_LOG(WARNING) << "Hccl comm is null, hcom executor initialize is not required";
|
||||
} else if (hccl_ret == HCCL_SUCCESS) {
|
||||
MS_LOG(INFO) << "Hcom DynamicKernel Initialize success";
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Hcom DynamicKernel Initialize failed";
|
||||
return false;
|
||||
}
|
||||
MS_LOG(INFO) << "InitHcclExec success";
|
||||
return true;
|
||||
}
|
||||
|
||||
bool HcclAdapter::FinalizeHcclExec() {
|
||||
MS_LOG(INFO) << "Start finalize hccl exec.";
|
||||
MS_EXCEPTION_IF_NULL(hccl_exec_finalize_);
|
||||
HcclResult hccl_ret = hccl_exec_finalize_();
|
||||
if (hccl_ret != HCCL_SUCCESS) {
|
||||
MS_LOG(ERROR) << "Hcom DynamicKernel Finalize failed";
|
||||
return false;
|
||||
}
|
||||
MS_LOG(INFO) << "HcclExec destroy success";
|
||||
return true;
|
||||
}
|
||||
|
||||
// Enqueues an hcom operation, with its completion callback, through the hccl_exec_enqueue_op_ entry point
// (null-checked first) and returns the status reported by that entry point.
HcclResult HcclAdapter::HcclExecEnqueueOp(const ::HcomOperation &op_info, HExecCallBack callback) const {
  MS_EXCEPTION_IF_NULL(hccl_exec_enqueue_op_);
  const HcclResult enqueue_ret = hccl_exec_enqueue_op_(op_info, callback);
  return enqueue_ret;
}
|
||||
} // namespace mindspore::hccl
|
||||
|
|
|
@ -20,8 +20,15 @@
|
|||
#include <string>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <mutex>
|
||||
#include "mindspore/core/ir/anf.h"
|
||||
#include "hccl/hccl_types.h"
|
||||
#include "runtime/hccl_adapter/plugin/hccl_plugin.h"
|
||||
|
||||
namespace ge {
|
||||
class OpsKernelInfoStore;
|
||||
class OpsKernelBuilder;
|
||||
} // namespace ge
|
||||
|
||||
namespace mindspore::hccl {
|
||||
struct HcclTaskInfo {
|
||||
|
@ -30,11 +37,76 @@ struct HcclTaskInfo {
|
|||
int64_t stream_num;
|
||||
};
|
||||
|
||||
bool InitHccl(uint32_t device_id, std::string_view rank_id, std::string_view rank_file);
|
||||
bool FinalizeHccl();
|
||||
bool GenTask(const AnfNodePtr &node, HcclDataType datatype, std::vector<HcclTaskInfo> *task_info_lists);
|
||||
int64_t CalcWorkspaceSize(const AnfNodePtr &node, HcclDataType datatype);
|
||||
void *GetHcclOpsKernelInfoStore();
|
||||
std::string GetHcclType(const AnfNodePtr &node);
|
||||
class HcclAdapter {
|
||||
public:
|
||||
static HcclAdapter &GetInstance();
|
||||
|
||||
// common
|
||||
bool InitHccl(uint32_t device_id, std::string_view rank_id, std::string_view rank_file);
|
||||
bool FinalizeHccl();
|
||||
|
||||
HcclResult HcclCreateGroup(const std::string &group, uint32_t rank_num, uint32_t *rank_ids) const;
|
||||
HcclResult HcclDestroyGroup(const std::string &group) const;
|
||||
HcclResult HcclGetRankId(const std::string &group, uint32_t *rank_id) const;
|
||||
HcclResult HcclGetRankSize(const std::string &group, uint32_t *rank_size) const;
|
||||
|
||||
// for ge node
|
||||
bool GenTask(const AnfNodePtr &node, HcclDataType datatype, std::vector<HcclTaskInfo> *task_info_lists) const;
|
||||
int64_t CalcWorkspaceSize(const AnfNodePtr &node, HcclDataType datatype) const;
|
||||
void *GetHcclOpsKernelInfoStore() const;
|
||||
static std::string GetHcclType(const AnfNodePtr &node);
|
||||
|
||||
// for single op
|
||||
HcclResult HcclBroadcast(void *buf, uint64_t count, HcclDataType dataType, uint32_t root, aclrtStream stream) const;
|
||||
HcclResult HcclAllReduce(void *sendBuf, void *recvBuf, uint64_t count, HcclDataType dataType, HcclReduceOp op,
|
||||
aclrtStream stream) const;
|
||||
|
||||
// for enqueue op
|
||||
HcclResult HcclExecEnqueueOp(const ::HcomOperation &op_info, HExecCallBack callback) const;
|
||||
|
||||
private:
|
||||
HcclAdapter() = default;
|
||||
~HcclAdapter() = default;
|
||||
void InitPlugin();
|
||||
void FinalizePlugin();
|
||||
|
||||
bool InitKernelInfoStore(uint32_t device_id, std::string_view rank_id, std::string_view rank_file);
|
||||
bool FinalizeKernelInfoStore();
|
||||
|
||||
bool InitHcclComm(std::string_view rank_id, std::string_view rank_file);
|
||||
bool FinalizeHcclComm();
|
||||
|
||||
bool InitHcclExec();
|
||||
bool FinalizeHcclExec();
|
||||
|
||||
void *plugin_handle_ = nullptr;
|
||||
|
||||
InitHcomGraphAdapterFunObj init_hcom_graph_adapter_ = nullptr;
|
||||
FinalizeHcomGraphAdapterFunObj finalize_hcom_graph_adapter_ = nullptr;
|
||||
GetHcclKernelInfoStoreFunObj get_hccl_kernel_info_store_ = nullptr;
|
||||
GetAllKernelBuilderFunObj get_all_kernel_builder_ = nullptr;
|
||||
|
||||
InitHcclCommFunObj init_hccl_comm_ = nullptr;
|
||||
FinalizeHcclCommFunObj finalize_hccl_comm_ = nullptr;
|
||||
LaunchHcclBroadcastFunObj launch_hccl_broadcast_ = nullptr;
|
||||
LaunchHcclAllReduceFunObj launch_hccl_all_reduce_ = nullptr;
|
||||
|
||||
HcclCreateGroupFunObj hccl_create_group_ = nullptr;
|
||||
HcclDestroyGroupFunObj hccl_destroy_group_ = nullptr;
|
||||
HcclGetRankIdFunObj hccl_get_rank_id_ = nullptr;
|
||||
HcclGetRankSizeFunObj hccl_get_rank_size_ = nullptr;
|
||||
|
||||
HcclExecInitializeFunObj hccl_exec_initialize_ = nullptr;
|
||||
HcclExecFinalizeFunObj hccl_exec_finalize_ = nullptr;
|
||||
HcclExecEnqueueOpFunObj hccl_exec_enqueue_op_ = nullptr;
|
||||
|
||||
HcclComm hccl_comm_ = nullptr;
|
||||
|
||||
std::shared_ptr<::ge::OpsKernelInfoStore> ops_kernel_info_store_ = nullptr;
|
||||
std::shared_ptr<::ge::OpsKernelBuilder> ops_kernel_builder_ = nullptr;
|
||||
|
||||
bool init_flag_ = false;
|
||||
std::mutex init_mutex_;
|
||||
};
|
||||
} // namespace mindspore::hccl
|
||||
#endif // MINDSPORE_RUNTIME_HCCL_ADAPTER_HCCL_ADAPTER_H
|
||||
|
|
|
@ -1,32 +0,0 @@
|
|||
/**
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_RUNTIME_HCCL_ADAPTER_HCOM_GRAPH_ADAPTOR_H
|
||||
#define MINDSPORE_RUNTIME_HCCL_ADAPTER_HCOM_GRAPH_ADAPTOR_H
|
||||
|
||||
#include <string>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include "mindspore/core/ir/anf.h"
|
||||
#include "common/opskernel/ops_kernel_info_store.h"
|
||||
|
||||
extern "C" {
|
||||
ge::Status Initialize(const std::map<std::string, std::string> &);
|
||||
ge::Status Finalize();
|
||||
void GetOpsKernelInfoStores(std::map<std::string, std::shared_ptr<ge::OpsKernelInfoStore>> &);
|
||||
}
|
||||
|
||||
#endif // MINDSPORE_RUNTIME_HCCL_ADAPTER_HCOM_GRAPH_ADAPTOR_H
|
|
@ -0,0 +1,41 @@
|
|||
# Builds hccl_plugin, the shared library that links directly against the Ascend hccl libraries.
add_library(hccl_plugin SHARED hccl_plugin.cc)
target_include_directories(hccl_plugin PRIVATE ${CMAKE_BINARY_DIR}/proto/ge)
add_dependencies(hccl_plugin graph)

# Extend the install RPATH with the default Ascend install locations so the plugin can resolve the hccl
# libraries at load time. MINDSPORE_RPATH is a colon-separated string; the runtime directories are appended
# first, followed by their plugin/opskernel sub-directories.
set(hccl_plugin_fwkacl_dirs
    /usr/local/Ascend/nnae/latest/fwkacllib/lib64
    /usr/local/Ascend/ascend-toolkit/latest/fwkacllib/lib64
    /usr/local/Ascend/fwkacllib/lib64)
foreach(fwkacl_dir IN LISTS hccl_plugin_fwkacl_dirs)
    set(MINDSPORE_RPATH ${MINDSPORE_RPATH}:${fwkacl_dir})
endforeach()
foreach(fwkacl_dir IN LISTS hccl_plugin_fwkacl_dirs)
    set(MINDSPORE_RPATH ${MINDSPORE_RPATH}:${fwkacl_dir}/plugin/opskernel)
endforeach()
# Quoted so the value is always passed as a single property value.
set_target_properties(hccl_plugin PROPERTIES INSTALL_RPATH "${MINDSPORE_RPATH}")

# Select where the Ascend libraries are searched for:
#   - D_LINK_PATH set:      use the architecture-specific sub-directory of that path (D_LIB_PATH);
#   - otherwise:            use ASCEND_CUSTOM_PATH if set, else the default /usr/local/Ascend layout.
if(DEFINED ENV{D_LINK_PATH})
    if(CMAKE_HOST_SYSTEM_PROCESSOR MATCHES "aarch64")
        message("system processor matches aarch64")
        set(D_LIB_PATH $ENV{D_LINK_PATH}/aarch64)
    elseif(CMAKE_HOST_SYSTEM_PROCESSOR MATCHES "x86_64")
        message("system processor matches x86_64")
        set(D_LIB_PATH $ENV{D_LINK_PATH}/x86_64)
    else()
        message("system ${CMAKE_HOST_SYSTEM_PROCESSOR} not support")
    endif()
else()
    message("use system default lib")
    if(DEFINED ENV{ASCEND_CUSTOM_PATH})
        set(ASCEND_PATH $ENV{ASCEND_CUSTOM_PATH})
    else()
        set(ASCEND_PATH /usr/local/Ascend)
    endif()
    set(ASCEND_RUNTIME_PATH ${ASCEND_PATH}/fwkacllib/lib64)
    set(ASCEND_PLUGIN_PATH ${ASCEND_RUNTIME_PATH}/plugin/opskernel)
    set(ASCEND_TOOLKIT_RUNTIME_PATH ${ASCEND_PATH}/ascend-toolkit/latest/fwkacllib/lib64)
    set(ASCEND_TOOLKIT_PLUGIN_PATH ${ASCEND_TOOLKIT_RUNTIME_PATH}/plugin/opskernel)
endif()

# D_LIB_PATH is included in the search paths so that the D_LINK_PATH branch above actually influences where
# the libraries are found; it expands to nothing when D_LINK_PATH is not set.
find_library(HCCL hccl PATHS ${D_LIB_PATH} ${ASCEND_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH})
find_library(REGISTER register PATHS ${D_LIB_PATH} ${ASCEND_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH})
find_library(HCCL_ADPTER hcom_graph_adaptor PATHS ${D_LIB_PATH} ${ASCEND_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH})
find_library(HCCL_RA ra PATHS ${D_LIB_PATH} ${ASCEND_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH})
find_library(HCCL_BUILDER hcom_opskernel_builder PATHS ${D_LIB_PATH} ${ASCEND_PLUGIN_PATH} ${ASCEND_TOOLKIT_PLUGIN_PATH})

# --no-as-needed keeps every listed library as a dependency of the plugin even when it is only referenced
# indirectly. PRIVATE: these libraries are an implementation detail of the plugin.
target_link_libraries(hccl_plugin PRIVATE
    -Wl,--no-as-needed ${HCCL} ${HCCL_ADPTER} ${REGISTER} ${HCCL_BUILDER} ${HCCL_RA})
|
|
@ -0,0 +1,90 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "runtime/hccl_adapter/plugin/hccl_plugin.h"
|
||||
#define google ascend_private
|
||||
#include "register/ops_kernel_builder_registry.h"
|
||||
#include "common/opskernel/ops_kernel_info_store.h"
|
||||
#undef google
|
||||
#include "hccl/hcom.h"
|
||||
|
||||
static constexpr const char *kHcclOpsKernelInfoStore = "ops_kernel_info_hccl";
|
||||
|
||||
extern "C" {
|
||||
ge::Status Initialize(const std::map<std::string, std::string> &);
|
||||
ge::Status Finalize();
|
||||
void GetOpsKernelInfoStores(std::map<std::string, std::shared_ptr<ge::OpsKernelInfoStore>> &);
|
||||
|
||||
ge::Status PluginInitHcomGraphAdapter(const std::map<std::string, std::string> &options) {
|
||||
return ::Initialize(options);
|
||||
}
|
||||
|
||||
ge::Status PluginFinalizeHcomGraphAdapter() { return ::Finalize(); }
|
||||
|
||||
void PluginGetHcclKernelInfoStore(std::shared_ptr<ge::OpsKernelInfoStore> *hccl_kernel_info_store) {
|
||||
if (hccl_kernel_info_store == nullptr) {
|
||||
return;
|
||||
}
|
||||
|
||||
std::map<std::string, std::shared_ptr<ge::OpsKernelInfoStore>> all_ops_kernel_info_stores;
|
||||
::GetOpsKernelInfoStores(all_ops_kernel_info_stores);
|
||||
for (auto &[name, ptr] : all_ops_kernel_info_stores) {
|
||||
if (name == kHcclOpsKernelInfoStore) {
|
||||
*hccl_kernel_info_store = ptr;
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
*hccl_kernel_info_store = nullptr;
|
||||
}
|
||||
|
||||
void PluginGetAllKernelBuilder(std::map<std::string, ge::OpsKernelBuilderPtr> *all_ops_kernel_builder) {
|
||||
if (all_ops_kernel_builder == nullptr) {
|
||||
return;
|
||||
}
|
||||
|
||||
*all_ops_kernel_builder = ge::OpsKernelBuilderRegistry::GetInstance().GetAll();
|
||||
}
|
||||
|
||||
HcclResult PluginLaunchHcclBroadcast(void *buf, uint64_t count, HcclDataType data_type, uint32_t root, HcclComm comm,
|
||||
aclrtStream stream) {
|
||||
return HcclBroadcast(buf, count, data_type, root, comm, stream);
|
||||
}
|
||||
|
||||
HcclResult PluginLaunchHcclAllReduce(void *send_buf, void *recv_buf, uint64_t count, HcclDataType data_type,
|
||||
HcclReduceOp op, HcclComm comm, aclrtStream stream) {
|
||||
return HcclAllReduce(send_buf, recv_buf, count, data_type, op, comm, stream);
|
||||
}
|
||||
|
||||
HcclResult PluginInitHcclComm(const char *cluster_info, uint32_t rank, HcclComm *comm) {
|
||||
return HcclCommInitClusterInfo(cluster_info, rank, comm);
|
||||
}
|
||||
|
||||
HcclResult PluginFinalizeHcclComm(HcclComm comm) { return HcclCommDestroy(comm); }
|
||||
|
||||
HcclResult PluginHcclCreateGroup(const char *group, uint32_t rank_num, uint32_t *rank_ids) {
|
||||
return HcomCreateGroup(group, rank_num, rank_ids);
|
||||
}
|
||||
|
||||
HcclResult PluginHcclDestroyGroup(const char *group) { return HcomDestroyGroup(group); }
|
||||
HcclResult PluginHcclGetRankId(const char *group, uint32_t *rank_id) { return HcomGetRankId(group, rank_id); }
|
||||
HcclResult PluginHcclGetRankSize(const char *group, uint32_t *rank_size) { return HcomGetRankSize(group, rank_size); }
|
||||
|
||||
HcclResult PluginHcclExecInitialize() { return HcomExecInitialize(); }
|
||||
HcclResult PluginHcclExecFinalize() { return HcomExecFinalize(); }
|
||||
HcclResult PluginHcclExecEnqueueOp(const ::HcomOperation &op_info, HExecCallBack callback) {
|
||||
return HcomExecEnqueueOperation(op_info, callback);
|
||||
}
|
||||
}; // extern C
|
|
@ -0,0 +1,63 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_RUNTIME_HCCL_ADAPTER_PLUGIN_HCCL_PLUGIN_H
|
||||
#define MINDSPORE_RUNTIME_HCCL_ADAPTER_PLUGIN_HCCL_PLUGIN_H
|
||||
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <map>
|
||||
#include <functional>
|
||||
#include "external/ge/ge_api_types.h"
|
||||
#include "hccl/hccl.h"
|
||||
|
||||
namespace ge {
|
||||
class OpsKernelBuilder;
|
||||
class OpsKernelInfoStore;
|
||||
} // namespace ge
|
||||
|
||||
extern "C" {
|
||||
struct HcomOperation;
|
||||
} // extern C
|
||||
|
||||
using OptionsType = std::map<std::string, std::string>;
|
||||
using OpsKernelBuilderMap = std::map<std::string, std::shared_ptr<ge::OpsKernelBuilder>>;
|
||||
using HExecCallBack = std::function<void(HcclResult status)>;
|
||||
|
||||
#define PLUGIN_METHOD(name, return_type, params...) \
|
||||
extern "C" { \
|
||||
__attribute__((visibility("default"))) return_type Plugin##name(params); \
|
||||
} \
|
||||
constexpr const char *k##name##Name = "Plugin" #name; \
|
||||
using name##FunObj = std::function<return_type(params)>; \
|
||||
using name##FunPtr = return_type (*)(params);
|
||||
|
||||
PLUGIN_METHOD(InitHcomGraphAdapter, ge::Status, const OptionsType &);
|
||||
PLUGIN_METHOD(FinalizeHcomGraphAdapter, ge::Status);
|
||||
PLUGIN_METHOD(GetHcclKernelInfoStore, void, std::shared_ptr<ge::OpsKernelInfoStore> *);
|
||||
PLUGIN_METHOD(GetAllKernelBuilder, void, OpsKernelBuilderMap *);
|
||||
PLUGIN_METHOD(LaunchHcclBroadcast, HcclResult, void *, uint64_t, HcclDataType, uint32_t, HcclComm, aclrtStream);
|
||||
PLUGIN_METHOD(LaunchHcclAllReduce, HcclResult, void *, void *, uint64_t, HcclDataType, HcclReduceOp, HcclComm,
|
||||
aclrtStream);
|
||||
PLUGIN_METHOD(InitHcclComm, HcclResult, const char *, uint32_t, HcclComm *);
|
||||
PLUGIN_METHOD(FinalizeHcclComm, HcclResult, HcclComm);
|
||||
PLUGIN_METHOD(HcclCreateGroup, HcclResult, const char *, uint32_t, uint32_t *);
|
||||
PLUGIN_METHOD(HcclDestroyGroup, HcclResult, const char *);
|
||||
PLUGIN_METHOD(HcclGetRankId, HcclResult, const char *, uint32_t *);
|
||||
PLUGIN_METHOD(HcclGetRankSize, HcclResult, const char *, uint32_t *);
|
||||
PLUGIN_METHOD(HcclExecInitialize, HcclResult);
|
||||
PLUGIN_METHOD(HcclExecFinalize, HcclResult);
|
||||
PLUGIN_METHOD(HcclExecEnqueueOp, HcclResult, const ::HcomOperation &, HExecCallBack);
|
||||
#endif // MINDSPORE_RUNTIME_HCCL_ADAPTER_PLUGIN_HCCL_PLUGIN_H
|
|
@ -18,7 +18,7 @@
|
|||
#include "utils/convert_utils.h"
|
||||
|
||||
#ifndef NO_DLIB
|
||||
#include "hccl/hcom.h"
|
||||
#include "runtime/hccl_adapter/hccl_adapter.h"
|
||||
#endif
|
||||
|
||||
#if defined(ENABLE_GPU)
|
||||
|
@ -67,26 +67,27 @@ bool CommManager::CreateGroupSync(const string &group, const vector<unsigned int
|
|||
HCCL_GROUP_CHECK_EMPTY(group);
|
||||
HCCL_GROUP_CHECK_IS_WORLD(group);
|
||||
HCCL_RUN_CHECK(string("create communicate group"), group,
|
||||
HcomCreateGroup(group.c_str(), UlongToUint(rank_size), vector<unsigned int>(rank_id_list).data()));
|
||||
hccl::HcclAdapter::GetInstance().HcclCreateGroup(group, UlongToUint(rank_size),
|
||||
vector<unsigned int>(rank_id_list).data()));
|
||||
return true;
|
||||
}
|
||||
|
||||
bool CommManager::GetRankID(const string &group, unsigned int *rank_id) const {
|
||||
HCCL_GROUP_CHECK_EMPTY(group);
|
||||
HCCL_RUN_CHECK(string("get rank_id"), group, HcomGetRankId(group.c_str(), rank_id));
|
||||
HCCL_RUN_CHECK(string("get rank_id"), group, hccl::HcclAdapter::GetInstance().HcclGetRankId(group, rank_id));
|
||||
return true;
|
||||
}
|
||||
|
||||
bool CommManager::GetRankSize(const string &group, unsigned int *rank_size) const {
|
||||
HCCL_GROUP_CHECK_EMPTY(group);
|
||||
HCCL_RUN_CHECK(string("get rank size"), group, HcomGetRankSize(group.c_str(), rank_size));
|
||||
HCCL_RUN_CHECK(string("get rank size"), group, hccl::HcclAdapter::GetInstance().HcclGetRankSize(group, rank_size));
|
||||
return true;
|
||||
}
|
||||
|
||||
bool CommManager::DestroyGroup(const string &group) const {
|
||||
HCCL_GROUP_CHECK_EMPTY(group);
|
||||
HCCL_GROUP_CHECK_IS_WORLD(group);
|
||||
HCCL_RUN_CHECK(string("destroy communicate group"), group, HcomDestroyGroup(group.c_str()));
|
||||
HCCL_RUN_CHECK(string("destroy communicate group"), group, hccl::HcclAdapter::GetInstance().HcclDestroyGroup(group));
|
||||
return true;
|
||||
}
|
||||
#elif defined(ENABLE_GPU)
|
||||
|
|
|
@ -41,9 +41,6 @@ void AiCoreDynamicKernel::UpdateArgs() {}
|
|||
void AiCoreDynamicKernel::Initialize() {}
|
||||
void AiCoreDynamicKernel::PostExecute() {}
|
||||
|
||||
bool HcclExecutorManager::Initialize() { return true; }
|
||||
bool HcclExecutorManager::Finalize() { return true; }
|
||||
|
||||
void OpTilingCalculater::Init() {}
|
||||
void OpTilingCalculater::CalculateTiling(const NotNull<CNodePtr> &cnode, const optiling::OpCompileInfo &op_compile_info,
|
||||
const std::map<uint32_t, tensor::TensorPtr> &depend_tensor_map,
|
||||
|
|
|
@ -18,11 +18,26 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace hccl {
|
||||
bool InitHccl(uint32_t, std::string_view, std::string_view) { return true; }
|
||||
bool FinalizeHccl() { return true; }
|
||||
bool GenTask(const AnfNodePtr &, HcclDataType, std::vector<HcclTaskInfo> *) { return true; }
|
||||
int64_t CalcWorkspaceSize(const AnfNodePtr &, HcclDataType) { return 0; }
|
||||
void *GetHcclOpsKernelInfoStore() { return nullptr; }
|
||||
std::string GetHcclType(const AnfNodePtr &) { return ""; }
|
||||
HcclAdapter &HcclAdapter::GetInstance() {
|
||||
static HcclAdapter instance;
|
||||
return instance;
|
||||
}
|
||||
bool HcclAdapter::InitHccl(uint32_t, std::string_view, std::string_view) { return true; }
|
||||
bool HcclAdapter::FinalizeHccl() { return true; }
|
||||
HcclResult HcclAdapter::HcclCreateGroup(const std::string &, uint32_t, uint32_t *) const { return HCCL_SUCCESS; }
|
||||
HcclResult HcclAdapter::HcclDestroyGroup(const std::string &) const { return HCCL_SUCCESS; }
|
||||
HcclResult HcclAdapter::HcclGetRankId(const std::string &, uint32_t *) const { return HCCL_SUCCESS; }
|
||||
HcclResult HcclAdapter::HcclGetRankSize(const std::string &, uint32_t *) const { return HCCL_SUCCESS; }
|
||||
bool HcclAdapter::GenTask(const AnfNodePtr &, HcclDataType, std::vector<HcclTaskInfo> *) const { return true; }
|
||||
int64_t HcclAdapter::CalcWorkspaceSize(const AnfNodePtr &, HcclDataType) const { return 0; }
|
||||
void *HcclAdapter::GetHcclOpsKernelInfoStore() const { return nullptr; }
|
||||
std::string HcclAdapter::GetHcclType(const AnfNodePtr &) { return ""; }
|
||||
HcclResult HcclAdapter::HcclBroadcast(void *, uint64_t, HcclDataType, uint32_t, aclrtStream) const {
|
||||
return HCCL_SUCCESS;
|
||||
}
|
||||
HcclResult HcclAdapter::HcclAllReduce(void *, void *, uint64_t, HcclDataType, HcclReduceOp, aclrtStream) const {
|
||||
return HCCL_SUCCESS;
|
||||
}
|
||||
HcclResult HcclAdapter::HcclExecEnqueueOp(const ::HcomOperation &, HExecCallBack) const { return HCCL_SUCCESS; }
|
||||
} // namespace hccl
|
||||
} // namespace mindspore
|
||||
|
|
Loading…
Reference in New Issue