!16467 modify defination of hccl origin function in hccl plugin

From: @zhoufeng54
Reviewed-by: @jjfeing,@xu-yfei
Signed-off-by: @xu-yfei
This commit is contained in:
mindspore-ci-bot 2021-05-18 17:44:45 +08:00 committed by Gitee
commit 84837b0bc9
5 changed files with 53 additions and 74 deletions

View File

@ -29,7 +29,6 @@
#include "runtime/hccl_adapter/converter.h" #include "runtime/hccl_adapter/converter.h"
static constexpr const char *kHcclPluginFileName = "libhccl_plugin.so"; static constexpr const char *kHcclPluginFileName = "libhccl_plugin.so";
static constexpr const char *kHcclOpsKernelInfoStore = "ops_kernel_info_hccl";
static constexpr const char *kHcclDeployModeEnv = "DEPLOY_MODE"; static constexpr const char *kHcclDeployModeEnv = "DEPLOY_MODE";
static constexpr const char *kHcclAlgoEnv = "HCCL_ALGO"; static constexpr const char *kHcclAlgoEnv = "HCCL_ALGO";
@ -47,7 +46,7 @@ static T DlsymWithCast(void *handle, const char *symbol_name) {
return symbol; return symbol;
} }
#define DlsymFuncObj(handle, func_name) DlsymWithCast<func_name##FunPtr>(handle, k##func_name##Name); #define DlsymFuncObj(func_name) DlsymWithCast<func_name##FunPtr>(plugin_handle_, k##func_name##Name);
static std::map<std::string, std::string> GenHcclOptions(uint32_t device_id, std::string_view rank_id, static std::map<std::string, std::string> GenHcclOptions(uint32_t device_id, std::string_view rank_id,
std::string_view rank_file) { std::string_view rank_file) {
@ -93,21 +92,21 @@ void HcclAdapter::InitPlugin() {
MS_LOG(EXCEPTION) << "Dlopen " << kHcclPluginFileName << " failed, result = " << GetDlErrorMsg(); MS_LOG(EXCEPTION) << "Dlopen " << kHcclPluginFileName << " failed, result = " << GetDlErrorMsg();
} }
init_hcom_graph_adapter_ = DlsymFuncObj(plugin_handle_, InitHcomGraphAdapter); init_hcom_graph_adapter_ = DlsymFuncObj(InitHcomGraphAdapter);
finalize_hcom_graph_adapter_ = DlsymFuncObj(plugin_handle_, FinalizeHcomGraphAdapter); finalize_hcom_graph_adapter_ = DlsymFuncObj(FinalizeHcomGraphAdapter);
get_hccl_kernel_info_store_ = DlsymFuncObj(plugin_handle_, GetHcclKernelInfoStore); get_hccl_kernel_info_store_ = DlsymFuncObj(GetHcclKernelInfoStore);
get_all_kernel_builder_ = DlsymFuncObj(plugin_handle_, GetAllKernelBuilder); get_all_kernel_builder_ = DlsymFuncObj(GetAllKernelBuilder);
init_hccl_comm_ = DlsymFuncObj(plugin_handle_, InitHcclComm); init_hccl_comm_ = DlsymFuncObj(HcclCommInitClusterInfo);
finalize_hccl_comm_ = DlsymFuncObj(plugin_handle_, FinalizeHcclComm); finalize_hccl_comm_ = DlsymFuncObj(HcclCommDestroy);
launch_hccl_broadcast_ = DlsymFuncObj(plugin_handle_, LaunchHcclBroadcast); launch_hccl_broadcast_ = DlsymFuncObj(HcclBroadcast);
launch_hccl_all_reduce_ = DlsymFuncObj(plugin_handle_, LaunchHcclAllReduce); launch_hccl_all_reduce_ = DlsymFuncObj(HcclAllReduce);
hccl_create_group_ = DlsymFuncObj(plugin_handle_, HcclCreateGroup); hccl_create_group_ = DlsymFuncObj(HcomCreateGroup);
hccl_destroy_group_ = DlsymFuncObj(plugin_handle_, HcclDestroyGroup); hccl_destroy_group_ = DlsymFuncObj(HcomDestroyGroup);
hccl_get_rank_id_ = DlsymFuncObj(plugin_handle_, HcclGetRankId); hccl_get_rank_id_ = DlsymFuncObj(HcomGetRankId);
hccl_get_rank_size_ = DlsymFuncObj(plugin_handle_, HcclGetRankSize); hccl_get_rank_size_ = DlsymFuncObj(HcomGetRankSize);
hccl_exec_initialize_ = DlsymFuncObj(plugin_handle_, HcclExecInitialize); hccl_exec_initialize_ = DlsymFuncObj(HcomExecInitialize);
hccl_exec_finalize_ = DlsymFuncObj(plugin_handle_, HcclExecFinalize); hccl_exec_finalize_ = DlsymFuncObj(HcomExecFinalize);
hccl_exec_enqueue_op_ = DlsymFuncObj(plugin_handle_, HcclExecEnqueueOp); hccl_exec_enqueue_op_ = DlsymFuncObj(HcomExecEnqueueOperation);
} }
void HcclAdapter::FinalizePlugin() { void HcclAdapter::FinalizePlugin() {
@ -409,7 +408,7 @@ bool HcclAdapter::FinalizeHcclExec() {
return true; return true;
} }
HcclResult HcclAdapter::HcclExecEnqueueOp(const ::HcomOperation &op_info, HExecCallBack callback) const { HcclResult HcclAdapter::HcclExecEnqueueOp(const ::HcomOperation &op_info, const HExecCallBack &callback) const {
MS_EXCEPTION_IF_NULL(hccl_exec_enqueue_op_); MS_EXCEPTION_IF_NULL(hccl_exec_enqueue_op_);
return hccl_exec_enqueue_op_(op_info, callback); return hccl_exec_enqueue_op_(op_info, callback);
} }

View File

@ -62,7 +62,7 @@ class HcclAdapter {
aclrtStream stream) const; aclrtStream stream) const;
// for enqueue op // for enqueue op
HcclResult HcclExecEnqueueOp(const ::HcomOperation &op_info, HExecCallBack callback) const; HcclResult HcclExecEnqueueOp(const ::HcomOperation &op_info, const HExecCallBack &callback) const;
private: private:
HcclAdapter() = default; HcclAdapter() = default;
@ -86,19 +86,19 @@ class HcclAdapter {
GetHcclKernelInfoStoreFunObj get_hccl_kernel_info_store_ = nullptr; GetHcclKernelInfoStoreFunObj get_hccl_kernel_info_store_ = nullptr;
GetAllKernelBuilderFunObj get_all_kernel_builder_ = nullptr; GetAllKernelBuilderFunObj get_all_kernel_builder_ = nullptr;
InitHcclCommFunObj init_hccl_comm_ = nullptr; HcclCommInitClusterInfoFunObj init_hccl_comm_ = nullptr;
FinalizeHcclCommFunObj finalize_hccl_comm_ = nullptr; HcclCommDestroyFunObj finalize_hccl_comm_ = nullptr;
LaunchHcclBroadcastFunObj launch_hccl_broadcast_ = nullptr; HcclBroadcastFunObj launch_hccl_broadcast_ = nullptr;
LaunchHcclAllReduceFunObj launch_hccl_all_reduce_ = nullptr; HcclAllReduceFunObj launch_hccl_all_reduce_ = nullptr;
HcclCreateGroupFunObj hccl_create_group_ = nullptr; HcomCreateGroupFunObj hccl_create_group_ = nullptr;
HcclDestroyGroupFunObj hccl_destroy_group_ = nullptr; HcomDestroyGroupFunObj hccl_destroy_group_ = nullptr;
HcclGetRankIdFunObj hccl_get_rank_id_ = nullptr; HcomGetRankIdFunObj hccl_get_rank_id_ = nullptr;
HcclGetRankSizeFunObj hccl_get_rank_size_ = nullptr; HcomGetRankSizeFunObj hccl_get_rank_size_ = nullptr;
HcclExecInitializeFunObj hccl_exec_initialize_ = nullptr; HcomExecInitializeFunObj hccl_exec_initialize_ = nullptr;
HcclExecFinalizeFunObj hccl_exec_finalize_ = nullptr; HcomExecFinalizeFunObj hccl_exec_finalize_ = nullptr;
HcclExecEnqueueOpFunObj hccl_exec_enqueue_op_ = nullptr; HcomExecEnqueueOperationFunObj hccl_exec_enqueue_op_ = nullptr;
HcclComm hccl_comm_ = nullptr; HcclComm hccl_comm_ = nullptr;

View File

@ -57,34 +57,4 @@ void PluginGetAllKernelBuilder(std::map<std::string, ge::OpsKernelBuilderPtr> *a
*all_ops_kernel_builder = ge::OpsKernelBuilderRegistry::GetInstance().GetAll(); *all_ops_kernel_builder = ge::OpsKernelBuilderRegistry::GetInstance().GetAll();
} }
HcclResult PluginLaunchHcclBroadcast(void *buf, uint64_t count, HcclDataType data_type, uint32_t root, HcclComm comm,
aclrtStream stream) {
return HcclBroadcast(buf, count, data_type, root, comm, stream);
}
HcclResult PluginLaunchHcclAllReduce(void *send_buf, void *recv_buf, uint64_t count, HcclDataType data_type,
HcclReduceOp op, HcclComm comm, aclrtStream stream) {
return HcclAllReduce(send_buf, recv_buf, count, data_type, op, comm, stream);
}
HcclResult PluginInitHcclComm(const char *cluster_info, uint32_t rank, HcclComm *comm) {
return HcclCommInitClusterInfo(cluster_info, rank, comm);
}
HcclResult PluginFinalizeHcclComm(HcclComm comm) { return HcclCommDestroy(comm); }
HcclResult PluginHcclCreateGroup(const char *group, uint32_t rank_num, uint32_t *rank_ids) {
return HcomCreateGroup(group, rank_num, rank_ids);
}
HcclResult PluginHcclDestroyGroup(const char *group) { return HcomDestroyGroup(group); }
HcclResult PluginHcclGetRankId(const char *group, uint32_t *rank_id) { return HcomGetRankId(group, rank_id); }
HcclResult PluginHcclGetRankSize(const char *group, uint32_t *rank_size) { return HcomGetRankSize(group, rank_size); }
HcclResult PluginHcclExecInitialize() { return HcomExecInitialize(); }
HcclResult PluginHcclExecFinalize() { return HcomExecFinalize(); }
HcclResult PluginHcclExecEnqueueOp(const ::HcomOperation &op_info, HExecCallBack callback) {
return HcomExecEnqueueOperation(op_info, callback);
}
}; // extern C }; // extern C

View File

@ -34,7 +34,7 @@ struct HcomOperation;
using OptionsType = std::map<std::string, std::string>; using OptionsType = std::map<std::string, std::string>;
using OpsKernelBuilderMap = std::map<std::string, std::shared_ptr<ge::OpsKernelBuilder>>; using OpsKernelBuilderMap = std::map<std::string, std::shared_ptr<ge::OpsKernelBuilder>>;
using HExecCallBack = std::function<void(HcclResult status)>; using HExecCallBack = std::function<void(HcclResult)>;
#define PLUGIN_METHOD(name, return_type, params...) \ #define PLUGIN_METHOD(name, return_type, params...) \
extern "C" { \ extern "C" { \
@ -44,20 +44,28 @@ using HExecCallBack = std::function<void(HcclResult status)>;
using name##FunObj = std::function<return_type(params)>; \ using name##FunObj = std::function<return_type(params)>; \
using name##FunPtr = return_type (*)(params); using name##FunPtr = return_type (*)(params);
#define ORIGIN_METHOD(name, return_type, params...) \
extern "C" { \
return_type name(params); \
} \
constexpr const char *k##name##Name = #name; \
using name##FunObj = std::function<return_type(params)>; \
using name##FunPtr = return_type (*)(params);
PLUGIN_METHOD(InitHcomGraphAdapter, ge::Status, const OptionsType &); PLUGIN_METHOD(InitHcomGraphAdapter, ge::Status, const OptionsType &);
PLUGIN_METHOD(FinalizeHcomGraphAdapter, ge::Status); PLUGIN_METHOD(FinalizeHcomGraphAdapter, ge::Status);
PLUGIN_METHOD(GetHcclKernelInfoStore, void, std::shared_ptr<ge::OpsKernelInfoStore> *); PLUGIN_METHOD(GetHcclKernelInfoStore, void, std::shared_ptr<ge::OpsKernelInfoStore> *);
PLUGIN_METHOD(GetAllKernelBuilder, void, OpsKernelBuilderMap *); PLUGIN_METHOD(GetAllKernelBuilder, void, OpsKernelBuilderMap *);
PLUGIN_METHOD(LaunchHcclBroadcast, HcclResult, void *, uint64_t, HcclDataType, uint32_t, HcclComm, aclrtStream);
PLUGIN_METHOD(LaunchHcclAllReduce, HcclResult, void *, void *, uint64_t, HcclDataType, HcclReduceOp, HcclComm, ORIGIN_METHOD(HcclBroadcast, HcclResult, void *, uint64_t, HcclDataType, uint32_t, HcclComm, aclrtStream);
aclrtStream); ORIGIN_METHOD(HcclAllReduce, HcclResult, void *, void *, uint64_t, HcclDataType, HcclReduceOp, HcclComm, aclrtStream);
PLUGIN_METHOD(InitHcclComm, HcclResult, const char *, uint32_t, HcclComm *); ORIGIN_METHOD(HcclCommInitClusterInfo, HcclResult, const char *, uint32_t, HcclComm *);
PLUGIN_METHOD(FinalizeHcclComm, HcclResult, HcclComm); ORIGIN_METHOD(HcclCommDestroy, HcclResult, HcclComm);
PLUGIN_METHOD(HcclCreateGroup, HcclResult, const char *, uint32_t, uint32_t *); ORIGIN_METHOD(HcomCreateGroup, HcclResult, const char *, uint32_t, uint32_t *);
PLUGIN_METHOD(HcclDestroyGroup, HcclResult, const char *); ORIGIN_METHOD(HcomDestroyGroup, HcclResult, const char *);
PLUGIN_METHOD(HcclGetRankId, HcclResult, const char *, uint32_t *); ORIGIN_METHOD(HcomGetRankId, HcclResult, const char *, uint32_t *);
PLUGIN_METHOD(HcclGetRankSize, HcclResult, const char *, uint32_t *); ORIGIN_METHOD(HcomGetRankSize, HcclResult, const char *, uint32_t *);
PLUGIN_METHOD(HcclExecInitialize, HcclResult); ORIGIN_METHOD(HcomExecInitialize, HcclResult);
PLUGIN_METHOD(HcclExecFinalize, HcclResult); ORIGIN_METHOD(HcomExecFinalize, HcclResult);
PLUGIN_METHOD(HcclExecEnqueueOp, HcclResult, const ::HcomOperation &, HExecCallBack); ORIGIN_METHOD(HcomExecEnqueueOperation, HcclResult, ::HcomOperation, HExecCallBack);
#endif // MINDSPORE_RUNTIME_HCCL_ADAPTER_PLUGIN_HCCL_PLUGIN_H #endif // MINDSPORE_RUNTIME_HCCL_ADAPTER_PLUGIN_HCCL_PLUGIN_H

View File

@ -38,6 +38,8 @@ HcclResult HcclAdapter::HcclBroadcast(void *, uint64_t, HcclDataType, uint32_t,
HcclResult HcclAdapter::HcclAllReduce(void *, void *, uint64_t, HcclDataType, HcclReduceOp, aclrtStream) const { HcclResult HcclAdapter::HcclAllReduce(void *, void *, uint64_t, HcclDataType, HcclReduceOp, aclrtStream) const {
return HCCL_SUCCESS; return HCCL_SUCCESS;
} }
HcclResult HcclAdapter::HcclExecEnqueueOp(const ::HcomOperation &, HExecCallBack) const { return HCCL_SUCCESS; } HcclResult HcclAdapter::HcclExecEnqueueOp(const ::HcomOperation &op_info, const HExecCallBack &callback) const {
return HCCL_SUCCESS;
}
} // namespace hccl } // namespace hccl
} // namespace mindspore } // namespace mindspore