forked from mindspore-Ecosystem/mindspore
!16467 modify definition of hccl origin function in hccl plugin
From: @zhoufeng54 Reviewed-by: @jjfeing,@xu-yfei Signed-off-by: @xu-yfei
This commit is contained in:
commit
84837b0bc9
|
@ -29,7 +29,6 @@
|
|||
#include "runtime/hccl_adapter/converter.h"
|
||||
|
||||
static constexpr const char *kHcclPluginFileName = "libhccl_plugin.so";
|
||||
static constexpr const char *kHcclOpsKernelInfoStore = "ops_kernel_info_hccl";
|
||||
static constexpr const char *kHcclDeployModeEnv = "DEPLOY_MODE";
|
||||
static constexpr const char *kHcclAlgoEnv = "HCCL_ALGO";
|
||||
|
||||
|
@ -47,7 +46,7 @@ static T DlsymWithCast(void *handle, const char *symbol_name) {
|
|||
return symbol;
|
||||
}
|
||||
|
||||
#define DlsymFuncObj(handle, func_name) DlsymWithCast<func_name##FunPtr>(handle, k##func_name##Name);
|
||||
// Expands to a DlsymWithCast call instantiated with <func_name>FunPtr, passing plugin_handle_ and the
// k<func_name>Name constant. plugin_handle_ must therefore be in scope wherever the macro is used.
// The macro deliberately does not end in ';': the call site supplies the terminator, so the expansion stays a
// single expression instead of a statement followed by an empty statement.
#define DlsymFuncObj(func_name) DlsymWithCast<func_name##FunPtr>(plugin_handle_, k##func_name##Name)
|
||||
|
||||
static std::map<std::string, std::string> GenHcclOptions(uint32_t device_id, std::string_view rank_id,
|
||||
std::string_view rank_file) {
|
||||
|
@ -93,21 +92,21 @@ void HcclAdapter::InitPlugin() {
|
|||
MS_LOG(EXCEPTION) << "Dlopen " << kHcclPluginFileName << " failed, result = " << GetDlErrorMsg();
|
||||
}
|
||||
|
||||
init_hcom_graph_adapter_ = DlsymFuncObj(plugin_handle_, InitHcomGraphAdapter);
|
||||
finalize_hcom_graph_adapter_ = DlsymFuncObj(plugin_handle_, FinalizeHcomGraphAdapter);
|
||||
get_hccl_kernel_info_store_ = DlsymFuncObj(plugin_handle_, GetHcclKernelInfoStore);
|
||||
get_all_kernel_builder_ = DlsymFuncObj(plugin_handle_, GetAllKernelBuilder);
|
||||
init_hccl_comm_ = DlsymFuncObj(plugin_handle_, InitHcclComm);
|
||||
finalize_hccl_comm_ = DlsymFuncObj(plugin_handle_, FinalizeHcclComm);
|
||||
launch_hccl_broadcast_ = DlsymFuncObj(plugin_handle_, LaunchHcclBroadcast);
|
||||
launch_hccl_all_reduce_ = DlsymFuncObj(plugin_handle_, LaunchHcclAllReduce);
|
||||
hccl_create_group_ = DlsymFuncObj(plugin_handle_, HcclCreateGroup);
|
||||
hccl_destroy_group_ = DlsymFuncObj(plugin_handle_, HcclDestroyGroup);
|
||||
hccl_get_rank_id_ = DlsymFuncObj(plugin_handle_, HcclGetRankId);
|
||||
hccl_get_rank_size_ = DlsymFuncObj(plugin_handle_, HcclGetRankSize);
|
||||
hccl_exec_initialize_ = DlsymFuncObj(plugin_handle_, HcclExecInitialize);
|
||||
hccl_exec_finalize_ = DlsymFuncObj(plugin_handle_, HcclExecFinalize);
|
||||
hccl_exec_enqueue_op_ = DlsymFuncObj(plugin_handle_, HcclExecEnqueueOp);
|
||||
init_hcom_graph_adapter_ = DlsymFuncObj(InitHcomGraphAdapter);
|
||||
finalize_hcom_graph_adapter_ = DlsymFuncObj(FinalizeHcomGraphAdapter);
|
||||
get_hccl_kernel_info_store_ = DlsymFuncObj(GetHcclKernelInfoStore);
|
||||
get_all_kernel_builder_ = DlsymFuncObj(GetAllKernelBuilder);
|
||||
init_hccl_comm_ = DlsymFuncObj(HcclCommInitClusterInfo);
|
||||
finalize_hccl_comm_ = DlsymFuncObj(HcclCommDestroy);
|
||||
launch_hccl_broadcast_ = DlsymFuncObj(HcclBroadcast);
|
||||
launch_hccl_all_reduce_ = DlsymFuncObj(HcclAllReduce);
|
||||
hccl_create_group_ = DlsymFuncObj(HcomCreateGroup);
|
||||
hccl_destroy_group_ = DlsymFuncObj(HcomDestroyGroup);
|
||||
hccl_get_rank_id_ = DlsymFuncObj(HcomGetRankId);
|
||||
hccl_get_rank_size_ = DlsymFuncObj(HcomGetRankSize);
|
||||
hccl_exec_initialize_ = DlsymFuncObj(HcomExecInitialize);
|
||||
hccl_exec_finalize_ = DlsymFuncObj(HcomExecFinalize);
|
||||
hccl_exec_enqueue_op_ = DlsymFuncObj(HcomExecEnqueueOperation);
|
||||
}
|
||||
|
||||
void HcclAdapter::FinalizePlugin() {
|
||||
|
@ -409,7 +408,7 @@ bool HcclAdapter::FinalizeHcclExec() {
|
|||
return true;
|
||||
}
|
||||
|
||||
HcclResult HcclAdapter::HcclExecEnqueueOp(const ::HcomOperation &op_info, HExecCallBack callback) const {
|
||||
HcclResult HcclAdapter::HcclExecEnqueueOp(const ::HcomOperation &op_info, const HExecCallBack &callback) const {
|
||||
MS_EXCEPTION_IF_NULL(hccl_exec_enqueue_op_);
|
||||
return hccl_exec_enqueue_op_(op_info, callback);
|
||||
}
|
||||
|
|
|
@ -62,7 +62,7 @@ class HcclAdapter {
|
|||
aclrtStream stream) const;
|
||||
|
||||
// for enqueue op
|
||||
HcclResult HcclExecEnqueueOp(const ::HcomOperation &op_info, HExecCallBack callback) const;
|
||||
HcclResult HcclExecEnqueueOp(const ::HcomOperation &op_info, const HExecCallBack &callback) const;
|
||||
|
||||
private:
|
||||
HcclAdapter() = default;
|
||||
|
@ -86,19 +86,19 @@ class HcclAdapter {
|
|||
GetHcclKernelInfoStoreFunObj get_hccl_kernel_info_store_ = nullptr;
|
||||
GetAllKernelBuilderFunObj get_all_kernel_builder_ = nullptr;
|
||||
|
||||
InitHcclCommFunObj init_hccl_comm_ = nullptr;
|
||||
FinalizeHcclCommFunObj finalize_hccl_comm_ = nullptr;
|
||||
LaunchHcclBroadcastFunObj launch_hccl_broadcast_ = nullptr;
|
||||
LaunchHcclAllReduceFunObj launch_hccl_all_reduce_ = nullptr;
|
||||
HcclCommInitClusterInfoFunObj init_hccl_comm_ = nullptr;
|
||||
HcclCommDestroyFunObj finalize_hccl_comm_ = nullptr;
|
||||
HcclBroadcastFunObj launch_hccl_broadcast_ = nullptr;
|
||||
HcclAllReduceFunObj launch_hccl_all_reduce_ = nullptr;
|
||||
|
||||
HcclCreateGroupFunObj hccl_create_group_ = nullptr;
|
||||
HcclDestroyGroupFunObj hccl_destroy_group_ = nullptr;
|
||||
HcclGetRankIdFunObj hccl_get_rank_id_ = nullptr;
|
||||
HcclGetRankSizeFunObj hccl_get_rank_size_ = nullptr;
|
||||
HcomCreateGroupFunObj hccl_create_group_ = nullptr;
|
||||
HcomDestroyGroupFunObj hccl_destroy_group_ = nullptr;
|
||||
HcomGetRankIdFunObj hccl_get_rank_id_ = nullptr;
|
||||
HcomGetRankSizeFunObj hccl_get_rank_size_ = nullptr;
|
||||
|
||||
HcclExecInitializeFunObj hccl_exec_initialize_ = nullptr;
|
||||
HcclExecFinalizeFunObj hccl_exec_finalize_ = nullptr;
|
||||
HcclExecEnqueueOpFunObj hccl_exec_enqueue_op_ = nullptr;
|
||||
HcomExecInitializeFunObj hccl_exec_initialize_ = nullptr;
|
||||
HcomExecFinalizeFunObj hccl_exec_finalize_ = nullptr;
|
||||
HcomExecEnqueueOperationFunObj hccl_exec_enqueue_op_ = nullptr;
|
||||
|
||||
HcclComm hccl_comm_ = nullptr;
|
||||
|
||||
|
|
|
@ -57,34 +57,4 @@ void PluginGetAllKernelBuilder(std::map<std::string, ge::OpsKernelBuilderPtr> *a
|
|||
|
||||
*all_ops_kernel_builder = ge::OpsKernelBuilderRegistry::GetInstance().GetAll();
|
||||
}
|
||||
|
||||
HcclResult PluginLaunchHcclBroadcast(void *buf, uint64_t count, HcclDataType data_type, uint32_t root, HcclComm comm,
|
||||
aclrtStream stream) {
|
||||
return HcclBroadcast(buf, count, data_type, root, comm, stream);
|
||||
}
|
||||
|
||||
HcclResult PluginLaunchHcclAllReduce(void *send_buf, void *recv_buf, uint64_t count, HcclDataType data_type,
|
||||
HcclReduceOp op, HcclComm comm, aclrtStream stream) {
|
||||
return HcclAllReduce(send_buf, recv_buf, count, data_type, op, comm, stream);
|
||||
}
|
||||
|
||||
HcclResult PluginInitHcclComm(const char *cluster_info, uint32_t rank, HcclComm *comm) {
|
||||
return HcclCommInitClusterInfo(cluster_info, rank, comm);
|
||||
}
|
||||
|
||||
HcclResult PluginFinalizeHcclComm(HcclComm comm) { return HcclCommDestroy(comm); }
|
||||
|
||||
HcclResult PluginHcclCreateGroup(const char *group, uint32_t rank_num, uint32_t *rank_ids) {
|
||||
return HcomCreateGroup(group, rank_num, rank_ids);
|
||||
}
|
||||
|
||||
HcclResult PluginHcclDestroyGroup(const char *group) { return HcomDestroyGroup(group); }
|
||||
HcclResult PluginHcclGetRankId(const char *group, uint32_t *rank_id) { return HcomGetRankId(group, rank_id); }
|
||||
HcclResult PluginHcclGetRankSize(const char *group, uint32_t *rank_size) { return HcomGetRankSize(group, rank_size); }
|
||||
|
||||
HcclResult PluginHcclExecInitialize() { return HcomExecInitialize(); }
|
||||
HcclResult PluginHcclExecFinalize() { return HcomExecFinalize(); }
|
||||
HcclResult PluginHcclExecEnqueueOp(const ::HcomOperation &op_info, HExecCallBack callback) {
|
||||
return HcomExecEnqueueOperation(op_info, callback);
|
||||
}
|
||||
}; // extern C
|
||||
|
|
|
@ -34,7 +34,7 @@ struct HcomOperation;
|
|||
|
||||
using OptionsType = std::map<std::string, std::string>;
|
||||
using OpsKernelBuilderMap = std::map<std::string, std::shared_ptr<ge::OpsKernelBuilder>>;
|
||||
using HExecCallBack = std::function<void(HcclResult status)>;
|
||||
using HExecCallBack = std::function<void(HcclResult)>;
|
||||
|
||||
#define PLUGIN_METHOD(name, return_type, params...) \
|
||||
extern "C" { \
|
||||
|
@ -44,20 +44,28 @@ using HExecCallBack = std::function<void(HcclResult status)>;
|
|||
using name##FunObj = std::function<return_type(params)>; \
|
||||
using name##FunPtr = return_type (*)(params);
|
||||
|
||||
#define ORIGIN_METHOD(name, return_type, params...) \
|
||||
extern "C" { \
|
||||
return_type name(params); \
|
||||
} \
|
||||
constexpr const char *k##name##Name = #name; \
|
||||
using name##FunObj = std::function<return_type(params)>; \
|
||||
using name##FunPtr = return_type (*)(params);
|
||||
|
||||
PLUGIN_METHOD(InitHcomGraphAdapter, ge::Status, const OptionsType &);
|
||||
PLUGIN_METHOD(FinalizeHcomGraphAdapter, ge::Status);
|
||||
PLUGIN_METHOD(GetHcclKernelInfoStore, void, std::shared_ptr<ge::OpsKernelInfoStore> *);
|
||||
PLUGIN_METHOD(GetAllKernelBuilder, void, OpsKernelBuilderMap *);
|
||||
PLUGIN_METHOD(LaunchHcclBroadcast, HcclResult, void *, uint64_t, HcclDataType, uint32_t, HcclComm, aclrtStream);
|
||||
PLUGIN_METHOD(LaunchHcclAllReduce, HcclResult, void *, void *, uint64_t, HcclDataType, HcclReduceOp, HcclComm,
|
||||
aclrtStream);
|
||||
PLUGIN_METHOD(InitHcclComm, HcclResult, const char *, uint32_t, HcclComm *);
|
||||
PLUGIN_METHOD(FinalizeHcclComm, HcclResult, HcclComm);
|
||||
PLUGIN_METHOD(HcclCreateGroup, HcclResult, const char *, uint32_t, uint32_t *);
|
||||
PLUGIN_METHOD(HcclDestroyGroup, HcclResult, const char *);
|
||||
PLUGIN_METHOD(HcclGetRankId, HcclResult, const char *, uint32_t *);
|
||||
PLUGIN_METHOD(HcclGetRankSize, HcclResult, const char *, uint32_t *);
|
||||
PLUGIN_METHOD(HcclExecInitialize, HcclResult);
|
||||
PLUGIN_METHOD(HcclExecFinalize, HcclResult);
|
||||
PLUGIN_METHOD(HcclExecEnqueueOp, HcclResult, const ::HcomOperation &, HExecCallBack);
|
||||
|
||||
ORIGIN_METHOD(HcclBroadcast, HcclResult, void *, uint64_t, HcclDataType, uint32_t, HcclComm, aclrtStream);
|
||||
ORIGIN_METHOD(HcclAllReduce, HcclResult, void *, void *, uint64_t, HcclDataType, HcclReduceOp, HcclComm, aclrtStream);
|
||||
ORIGIN_METHOD(HcclCommInitClusterInfo, HcclResult, const char *, uint32_t, HcclComm *);
|
||||
ORIGIN_METHOD(HcclCommDestroy, HcclResult, HcclComm);
|
||||
ORIGIN_METHOD(HcomCreateGroup, HcclResult, const char *, uint32_t, uint32_t *);
|
||||
ORIGIN_METHOD(HcomDestroyGroup, HcclResult, const char *);
|
||||
ORIGIN_METHOD(HcomGetRankId, HcclResult, const char *, uint32_t *);
|
||||
ORIGIN_METHOD(HcomGetRankSize, HcclResult, const char *, uint32_t *);
|
||||
ORIGIN_METHOD(HcomExecInitialize, HcclResult);
|
||||
ORIGIN_METHOD(HcomExecFinalize, HcclResult);
|
||||
ORIGIN_METHOD(HcomExecEnqueueOperation, HcclResult, ::HcomOperation, HExecCallBack);
|
||||
#endif // MINDSPORE_RUNTIME_HCCL_ADAPTER_PLUGIN_HCCL_PLUGIN_H
|
||||
|
|
|
@ -38,6 +38,8 @@ HcclResult HcclAdapter::HcclBroadcast(void *, uint64_t, HcclDataType, uint32_t,
|
|||
HcclResult HcclAdapter::HcclAllReduce(void *, void *, uint64_t, HcclDataType, HcclReduceOp, aclrtStream) const {
|
||||
return HCCL_SUCCESS;
|
||||
}
|
||||
HcclResult HcclAdapter::HcclExecEnqueueOp(const ::HcomOperation &, HExecCallBack) const { return HCCL_SUCCESS; }
|
||||
HcclResult HcclAdapter::HcclExecEnqueueOp(const ::HcomOperation &op_info, const HExecCallBack &callback) const {
|
||||
return HCCL_SUCCESS;
|
||||
}
|
||||
} // namespace hccl
|
||||
} // namespace mindspore
|
||||
|
|
Loading…
Reference in New Issue