forked from mindspore-Ecosystem/mindspore
!16467 modify definition of hccl origin function in hccl plugin
From: @zhoufeng54 Reviewed-by: @jjfeing,@xu-yfei Signed-off-by: @xu-yfei
This commit is contained in:
commit
84837b0bc9
|
@ -29,7 +29,6 @@
|
||||||
#include "runtime/hccl_adapter/converter.h"
|
#include "runtime/hccl_adapter/converter.h"
|
||||||
|
|
||||||
static constexpr const char *kHcclPluginFileName = "libhccl_plugin.so";
|
static constexpr const char *kHcclPluginFileName = "libhccl_plugin.so";
|
||||||
static constexpr const char *kHcclOpsKernelInfoStore = "ops_kernel_info_hccl";
|
|
||||||
static constexpr const char *kHcclDeployModeEnv = "DEPLOY_MODE";
|
static constexpr const char *kHcclDeployModeEnv = "DEPLOY_MODE";
|
||||||
static constexpr const char *kHcclAlgoEnv = "HCCL_ALGO";
|
static constexpr const char *kHcclAlgoEnv = "HCCL_ALGO";
|
||||||
|
|
||||||
|
@ -47,7 +46,7 @@ static T DlsymWithCast(void *handle, const char *symbol_name) {
|
||||||
return symbol;
|
return symbol;
|
||||||
}
|
}
|
||||||
|
|
||||||
#define DlsymFuncObj(handle, func_name) DlsymWithCast<func_name##FunPtr>(handle, k##func_name##Name);
|
#define DlsymFuncObj(func_name) DlsymWithCast<func_name##FunPtr>(plugin_handle_, k##func_name##Name);
|
||||||
|
|
||||||
static std::map<std::string, std::string> GenHcclOptions(uint32_t device_id, std::string_view rank_id,
|
static std::map<std::string, std::string> GenHcclOptions(uint32_t device_id, std::string_view rank_id,
|
||||||
std::string_view rank_file) {
|
std::string_view rank_file) {
|
||||||
|
@ -93,21 +92,21 @@ void HcclAdapter::InitPlugin() {
|
||||||
MS_LOG(EXCEPTION) << "Dlopen " << kHcclPluginFileName << " failed, result = " << GetDlErrorMsg();
|
MS_LOG(EXCEPTION) << "Dlopen " << kHcclPluginFileName << " failed, result = " << GetDlErrorMsg();
|
||||||
}
|
}
|
||||||
|
|
||||||
init_hcom_graph_adapter_ = DlsymFuncObj(plugin_handle_, InitHcomGraphAdapter);
|
init_hcom_graph_adapter_ = DlsymFuncObj(InitHcomGraphAdapter);
|
||||||
finalize_hcom_graph_adapter_ = DlsymFuncObj(plugin_handle_, FinalizeHcomGraphAdapter);
|
finalize_hcom_graph_adapter_ = DlsymFuncObj(FinalizeHcomGraphAdapter);
|
||||||
get_hccl_kernel_info_store_ = DlsymFuncObj(plugin_handle_, GetHcclKernelInfoStore);
|
get_hccl_kernel_info_store_ = DlsymFuncObj(GetHcclKernelInfoStore);
|
||||||
get_all_kernel_builder_ = DlsymFuncObj(plugin_handle_, GetAllKernelBuilder);
|
get_all_kernel_builder_ = DlsymFuncObj(GetAllKernelBuilder);
|
||||||
init_hccl_comm_ = DlsymFuncObj(plugin_handle_, InitHcclComm);
|
init_hccl_comm_ = DlsymFuncObj(HcclCommInitClusterInfo);
|
||||||
finalize_hccl_comm_ = DlsymFuncObj(plugin_handle_, FinalizeHcclComm);
|
finalize_hccl_comm_ = DlsymFuncObj(HcclCommDestroy);
|
||||||
launch_hccl_broadcast_ = DlsymFuncObj(plugin_handle_, LaunchHcclBroadcast);
|
launch_hccl_broadcast_ = DlsymFuncObj(HcclBroadcast);
|
||||||
launch_hccl_all_reduce_ = DlsymFuncObj(plugin_handle_, LaunchHcclAllReduce);
|
launch_hccl_all_reduce_ = DlsymFuncObj(HcclAllReduce);
|
||||||
hccl_create_group_ = DlsymFuncObj(plugin_handle_, HcclCreateGroup);
|
hccl_create_group_ = DlsymFuncObj(HcomCreateGroup);
|
||||||
hccl_destroy_group_ = DlsymFuncObj(plugin_handle_, HcclDestroyGroup);
|
hccl_destroy_group_ = DlsymFuncObj(HcomDestroyGroup);
|
||||||
hccl_get_rank_id_ = DlsymFuncObj(plugin_handle_, HcclGetRankId);
|
hccl_get_rank_id_ = DlsymFuncObj(HcomGetRankId);
|
||||||
hccl_get_rank_size_ = DlsymFuncObj(plugin_handle_, HcclGetRankSize);
|
hccl_get_rank_size_ = DlsymFuncObj(HcomGetRankSize);
|
||||||
hccl_exec_initialize_ = DlsymFuncObj(plugin_handle_, HcclExecInitialize);
|
hccl_exec_initialize_ = DlsymFuncObj(HcomExecInitialize);
|
||||||
hccl_exec_finalize_ = DlsymFuncObj(plugin_handle_, HcclExecFinalize);
|
hccl_exec_finalize_ = DlsymFuncObj(HcomExecFinalize);
|
||||||
hccl_exec_enqueue_op_ = DlsymFuncObj(plugin_handle_, HcclExecEnqueueOp);
|
hccl_exec_enqueue_op_ = DlsymFuncObj(HcomExecEnqueueOperation);
|
||||||
}
|
}
|
||||||
|
|
||||||
void HcclAdapter::FinalizePlugin() {
|
void HcclAdapter::FinalizePlugin() {
|
||||||
|
@ -409,7 +408,7 @@ bool HcclAdapter::FinalizeHcclExec() {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
HcclResult HcclAdapter::HcclExecEnqueueOp(const ::HcomOperation &op_info, HExecCallBack callback) const {
|
HcclResult HcclAdapter::HcclExecEnqueueOp(const ::HcomOperation &op_info, const HExecCallBack &callback) const {
|
||||||
MS_EXCEPTION_IF_NULL(hccl_exec_enqueue_op_);
|
MS_EXCEPTION_IF_NULL(hccl_exec_enqueue_op_);
|
||||||
return hccl_exec_enqueue_op_(op_info, callback);
|
return hccl_exec_enqueue_op_(op_info, callback);
|
||||||
}
|
}
|
||||||
|
|
|
@ -62,7 +62,7 @@ class HcclAdapter {
|
||||||
aclrtStream stream) const;
|
aclrtStream stream) const;
|
||||||
|
|
||||||
// for enqueue op
|
// for enqueue op
|
||||||
HcclResult HcclExecEnqueueOp(const ::HcomOperation &op_info, HExecCallBack callback) const;
|
HcclResult HcclExecEnqueueOp(const ::HcomOperation &op_info, const HExecCallBack &callback) const;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
HcclAdapter() = default;
|
HcclAdapter() = default;
|
||||||
|
@ -86,19 +86,19 @@ class HcclAdapter {
|
||||||
GetHcclKernelInfoStoreFunObj get_hccl_kernel_info_store_ = nullptr;
|
GetHcclKernelInfoStoreFunObj get_hccl_kernel_info_store_ = nullptr;
|
||||||
GetAllKernelBuilderFunObj get_all_kernel_builder_ = nullptr;
|
GetAllKernelBuilderFunObj get_all_kernel_builder_ = nullptr;
|
||||||
|
|
||||||
InitHcclCommFunObj init_hccl_comm_ = nullptr;
|
HcclCommInitClusterInfoFunObj init_hccl_comm_ = nullptr;
|
||||||
FinalizeHcclCommFunObj finalize_hccl_comm_ = nullptr;
|
HcclCommDestroyFunObj finalize_hccl_comm_ = nullptr;
|
||||||
LaunchHcclBroadcastFunObj launch_hccl_broadcast_ = nullptr;
|
HcclBroadcastFunObj launch_hccl_broadcast_ = nullptr;
|
||||||
LaunchHcclAllReduceFunObj launch_hccl_all_reduce_ = nullptr;
|
HcclAllReduceFunObj launch_hccl_all_reduce_ = nullptr;
|
||||||
|
|
||||||
HcclCreateGroupFunObj hccl_create_group_ = nullptr;
|
HcomCreateGroupFunObj hccl_create_group_ = nullptr;
|
||||||
HcclDestroyGroupFunObj hccl_destroy_group_ = nullptr;
|
HcomDestroyGroupFunObj hccl_destroy_group_ = nullptr;
|
||||||
HcclGetRankIdFunObj hccl_get_rank_id_ = nullptr;
|
HcomGetRankIdFunObj hccl_get_rank_id_ = nullptr;
|
||||||
HcclGetRankSizeFunObj hccl_get_rank_size_ = nullptr;
|
HcomGetRankSizeFunObj hccl_get_rank_size_ = nullptr;
|
||||||
|
|
||||||
HcclExecInitializeFunObj hccl_exec_initialize_ = nullptr;
|
HcomExecInitializeFunObj hccl_exec_initialize_ = nullptr;
|
||||||
HcclExecFinalizeFunObj hccl_exec_finalize_ = nullptr;
|
HcomExecFinalizeFunObj hccl_exec_finalize_ = nullptr;
|
||||||
HcclExecEnqueueOpFunObj hccl_exec_enqueue_op_ = nullptr;
|
HcomExecEnqueueOperationFunObj hccl_exec_enqueue_op_ = nullptr;
|
||||||
|
|
||||||
HcclComm hccl_comm_ = nullptr;
|
HcclComm hccl_comm_ = nullptr;
|
||||||
|
|
||||||
|
|
|
@ -57,34 +57,4 @@ void PluginGetAllKernelBuilder(std::map<std::string, ge::OpsKernelBuilderPtr> *a
|
||||||
|
|
||||||
*all_ops_kernel_builder = ge::OpsKernelBuilderRegistry::GetInstance().GetAll();
|
*all_ops_kernel_builder = ge::OpsKernelBuilderRegistry::GetInstance().GetAll();
|
||||||
}
|
}
|
||||||
|
|
||||||
HcclResult PluginLaunchHcclBroadcast(void *buf, uint64_t count, HcclDataType data_type, uint32_t root, HcclComm comm,
|
|
||||||
aclrtStream stream) {
|
|
||||||
return HcclBroadcast(buf, count, data_type, root, comm, stream);
|
|
||||||
}
|
|
||||||
|
|
||||||
HcclResult PluginLaunchHcclAllReduce(void *send_buf, void *recv_buf, uint64_t count, HcclDataType data_type,
|
|
||||||
HcclReduceOp op, HcclComm comm, aclrtStream stream) {
|
|
||||||
return HcclAllReduce(send_buf, recv_buf, count, data_type, op, comm, stream);
|
|
||||||
}
|
|
||||||
|
|
||||||
HcclResult PluginInitHcclComm(const char *cluster_info, uint32_t rank, HcclComm *comm) {
|
|
||||||
return HcclCommInitClusterInfo(cluster_info, rank, comm);
|
|
||||||
}
|
|
||||||
|
|
||||||
HcclResult PluginFinalizeHcclComm(HcclComm comm) { return HcclCommDestroy(comm); }
|
|
||||||
|
|
||||||
HcclResult PluginHcclCreateGroup(const char *group, uint32_t rank_num, uint32_t *rank_ids) {
|
|
||||||
return HcomCreateGroup(group, rank_num, rank_ids);
|
|
||||||
}
|
|
||||||
|
|
||||||
HcclResult PluginHcclDestroyGroup(const char *group) { return HcomDestroyGroup(group); }
|
|
||||||
HcclResult PluginHcclGetRankId(const char *group, uint32_t *rank_id) { return HcomGetRankId(group, rank_id); }
|
|
||||||
HcclResult PluginHcclGetRankSize(const char *group, uint32_t *rank_size) { return HcomGetRankSize(group, rank_size); }
|
|
||||||
|
|
||||||
HcclResult PluginHcclExecInitialize() { return HcomExecInitialize(); }
|
|
||||||
HcclResult PluginHcclExecFinalize() { return HcomExecFinalize(); }
|
|
||||||
HcclResult PluginHcclExecEnqueueOp(const ::HcomOperation &op_info, HExecCallBack callback) {
|
|
||||||
return HcomExecEnqueueOperation(op_info, callback);
|
|
||||||
}
|
|
||||||
}; // extern C
|
}; // extern C
|
||||||
|
|
|
@ -34,7 +34,7 @@ struct HcomOperation;
|
||||||
|
|
||||||
using OptionsType = std::map<std::string, std::string>;
|
using OptionsType = std::map<std::string, std::string>;
|
||||||
using OpsKernelBuilderMap = std::map<std::string, std::shared_ptr<ge::OpsKernelBuilder>>;
|
using OpsKernelBuilderMap = std::map<std::string, std::shared_ptr<ge::OpsKernelBuilder>>;
|
||||||
using HExecCallBack = std::function<void(HcclResult status)>;
|
using HExecCallBack = std::function<void(HcclResult)>;
|
||||||
|
|
||||||
#define PLUGIN_METHOD(name, return_type, params...) \
|
#define PLUGIN_METHOD(name, return_type, params...) \
|
||||||
extern "C" { \
|
extern "C" { \
|
||||||
|
@ -44,20 +44,28 @@ using HExecCallBack = std::function<void(HcclResult status)>;
|
||||||
using name##FunObj = std::function<return_type(params)>; \
|
using name##FunObj = std::function<return_type(params)>; \
|
||||||
using name##FunPtr = return_type (*)(params);
|
using name##FunPtr = return_type (*)(params);
|
||||||
|
|
||||||
|
#define ORIGIN_METHOD(name, return_type, params...) \
|
||||||
|
extern "C" { \
|
||||||
|
return_type name(params); \
|
||||||
|
} \
|
||||||
|
constexpr const char *k##name##Name = #name; \
|
||||||
|
using name##FunObj = std::function<return_type(params)>; \
|
||||||
|
using name##FunPtr = return_type (*)(params);
|
||||||
|
|
||||||
PLUGIN_METHOD(InitHcomGraphAdapter, ge::Status, const OptionsType &);
|
PLUGIN_METHOD(InitHcomGraphAdapter, ge::Status, const OptionsType &);
|
||||||
PLUGIN_METHOD(FinalizeHcomGraphAdapter, ge::Status);
|
PLUGIN_METHOD(FinalizeHcomGraphAdapter, ge::Status);
|
||||||
PLUGIN_METHOD(GetHcclKernelInfoStore, void, std::shared_ptr<ge::OpsKernelInfoStore> *);
|
PLUGIN_METHOD(GetHcclKernelInfoStore, void, std::shared_ptr<ge::OpsKernelInfoStore> *);
|
||||||
PLUGIN_METHOD(GetAllKernelBuilder, void, OpsKernelBuilderMap *);
|
PLUGIN_METHOD(GetAllKernelBuilder, void, OpsKernelBuilderMap *);
|
||||||
PLUGIN_METHOD(LaunchHcclBroadcast, HcclResult, void *, uint64_t, HcclDataType, uint32_t, HcclComm, aclrtStream);
|
|
||||||
PLUGIN_METHOD(LaunchHcclAllReduce, HcclResult, void *, void *, uint64_t, HcclDataType, HcclReduceOp, HcclComm,
|
ORIGIN_METHOD(HcclBroadcast, HcclResult, void *, uint64_t, HcclDataType, uint32_t, HcclComm, aclrtStream);
|
||||||
aclrtStream);
|
ORIGIN_METHOD(HcclAllReduce, HcclResult, void *, void *, uint64_t, HcclDataType, HcclReduceOp, HcclComm, aclrtStream);
|
||||||
PLUGIN_METHOD(InitHcclComm, HcclResult, const char *, uint32_t, HcclComm *);
|
ORIGIN_METHOD(HcclCommInitClusterInfo, HcclResult, const char *, uint32_t, HcclComm *);
|
||||||
PLUGIN_METHOD(FinalizeHcclComm, HcclResult, HcclComm);
|
ORIGIN_METHOD(HcclCommDestroy, HcclResult, HcclComm);
|
||||||
PLUGIN_METHOD(HcclCreateGroup, HcclResult, const char *, uint32_t, uint32_t *);
|
ORIGIN_METHOD(HcomCreateGroup, HcclResult, const char *, uint32_t, uint32_t *);
|
||||||
PLUGIN_METHOD(HcclDestroyGroup, HcclResult, const char *);
|
ORIGIN_METHOD(HcomDestroyGroup, HcclResult, const char *);
|
||||||
PLUGIN_METHOD(HcclGetRankId, HcclResult, const char *, uint32_t *);
|
ORIGIN_METHOD(HcomGetRankId, HcclResult, const char *, uint32_t *);
|
||||||
PLUGIN_METHOD(HcclGetRankSize, HcclResult, const char *, uint32_t *);
|
ORIGIN_METHOD(HcomGetRankSize, HcclResult, const char *, uint32_t *);
|
||||||
PLUGIN_METHOD(HcclExecInitialize, HcclResult);
|
ORIGIN_METHOD(HcomExecInitialize, HcclResult);
|
||||||
PLUGIN_METHOD(HcclExecFinalize, HcclResult);
|
ORIGIN_METHOD(HcomExecFinalize, HcclResult);
|
||||||
PLUGIN_METHOD(HcclExecEnqueueOp, HcclResult, const ::HcomOperation &, HExecCallBack);
|
ORIGIN_METHOD(HcomExecEnqueueOperation, HcclResult, ::HcomOperation, HExecCallBack);
|
||||||
#endif // MINDSPORE_RUNTIME_HCCL_ADAPTER_PLUGIN_HCCL_PLUGIN_H
|
#endif // MINDSPORE_RUNTIME_HCCL_ADAPTER_PLUGIN_HCCL_PLUGIN_H
|
||||||
|
|
|
@ -38,6 +38,8 @@ HcclResult HcclAdapter::HcclBroadcast(void *, uint64_t, HcclDataType, uint32_t,
|
||||||
HcclResult HcclAdapter::HcclAllReduce(void *, void *, uint64_t, HcclDataType, HcclReduceOp, aclrtStream) const {
|
HcclResult HcclAdapter::HcclAllReduce(void *, void *, uint64_t, HcclDataType, HcclReduceOp, aclrtStream) const {
|
||||||
return HCCL_SUCCESS;
|
return HCCL_SUCCESS;
|
||||||
}
|
}
|
||||||
HcclResult HcclAdapter::HcclExecEnqueueOp(const ::HcomOperation &, HExecCallBack) const { return HCCL_SUCCESS; }
|
HcclResult HcclAdapter::HcclExecEnqueueOp(const ::HcomOperation &op_info, const HExecCallBack &callback) const {
|
||||||
|
return HCCL_SUCCESS;
|
||||||
|
}
|
||||||
} // namespace hccl
|
} // namespace hccl
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
Loading…
Reference in New Issue