Ascend: support MPI

This commit is contained in:
baihuawei 2021-07-30 16:22:48 +08:00
parent 2ae65ae387
commit def6af8158
12 changed files with 616 additions and 47 deletions

View File

@@ -17,6 +17,8 @@ endif()
if(ENABLE_D)
file(GLOB_RECURSE D_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "ascend/*.cc" "kernel_adjust.cc")
list(REMOVE_ITEM D_SRC_LIST "ascend/distribute/mpi_collective_group.cc"
"ascend/distribute/collective_group_wrapper.cc" "ascend/distribute/mpi_pycc.cc")
endif()
if(ENABLE_TDTQUE)
file(GLOB_RECURSE TDT_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
@@ -43,6 +45,13 @@ if(ENABLE_MPI)
pybind11_add_module(_ms_mpi "gpu/mpi/mpi_initializer.cc")
target_link_libraries(_ms_mpi PRIVATE mindspore::pybind11_module mindspore::ompi)
endif()
if(ENABLE_D)
set_property(SOURCE "ascend/distribute/mpi_pycc.cc"
PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_DEVICE)
pybind11_add_module(_ascend_mpi "ascend/distribute/mpi_pycc.cc")
target_link_libraries(_ascend_mpi PRIVATE mindspore::pybind11_module mindspore::ompi)
endif()
endif()
# gpu
@@ -84,4 +93,17 @@ if(ENABLE_D)
set_property(SOURCE ${GE_SRC_LIST} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_GE)
target_include_directories(_mindspore_runtime_device_obj PRIVATE ${CMAKE_BINARY_DIR}/proto/ge)
add_dependencies(_mindspore_runtime_device_obj graph)
if(ENABLE_MPI)
set(ASCEND_PATH /usr/local/Ascend)
set(ASCEND_TOOLKIT_RUNTIME_PATH ${ASCEND_PATH}/ascend-toolkit/latest/fwkacllib/lib64)
set(ASCEND_RUNTIME_PATH ${ASCEND_PATH}/fwkacllib/lib64)
find_library(HCCL hccl ${ASCEND_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH})
file(GLOB_RECURSE ASCEND_COLLECTIVE_LIST "ascend/distribute/mpi_collective_group.cc"
"ascend/distribute/collective_group_wrapper.cc")
set_property(SOURCE ${ASCEND_COLLECTIVE_LIST}
PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_DEVICE)
add_library(ascend_collective SHARED ${ASCEND_COLLECTIVE_LIST})
target_link_libraries(ascend_collective PRIVATE ${HCCL} mindspore::ompi)
target_link_libraries(_ascend_mpi PRIVATE ascend_collective)
endif()
endif()

View File

@@ -0,0 +1,94 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "runtime/device/ascend/distribute/ascend_collective.h"
#include "utils/log_adapter.h"
static constexpr const char *kAscendCollectiveFileName = "libascend_collective.so";
namespace mindspore {
namespace device {
namespace ascend {
namespace collective {
HcclCollectiveGroup &HcclCollectiveGroup::instance() {
static HcclCollectiveGroup instance;
return instance;
}
void HcclCollectiveGroup::FinalizeCollective() {
MS_LOG(INFO) << "Finalize Collective";
if (collective_handle_ != nullptr) {
MS_EXCEPTION_IF_NULL(finalize_mpi_);
finalize_mpi_();
if (dlclose(collective_handle_) != 0) {
MS_LOG(EXCEPTION) << "Closing libascend_collective.so handle failed.";
}
}
}
bool HcclCollectiveGroup::InitCollective() {
MS_LOG(INFO) << "InitCollective";
if (inited_) {
return true;
}
collective_handle_ = dlopen(kAscendCollectiveFileName, RTLD_NOW);
if (collective_handle_ == nullptr) {
MS_LOG(EXCEPTION)
<< "Loading libascend_collective.so failed. Many reasons could cause this:\n1.libascend_collective.so is not "
"installed.\n2.hccl is not "
"installed or found.\n3.mpi is not installed or found";
}
init_mpi_ = DlsymFuncObj(InitMPI, collective_handle_);
finalize_mpi_ = DlsymFuncObj(FinalizeMPI, collective_handle_);
get_group_comm_ = DlsymFuncObj(GetGroupComm, collective_handle_);
get_group_size_ = DlsymFuncObj(GetGroupSize, collective_handle_);
get_rank_id_by_group_ = DlsymFuncObj(GetRankIdByGroup, collective_handle_);
get_device_id_ = DlsymFuncObj(GetDeviceId, collective_handle_);
create_comm_for_group_ = DlsymFuncObj(CreateCommForGroup, collective_handle_);
destroy_hccl_comm_ = DlsymFuncObj(DestroyHcclComm, collective_handle_);
MS_EXCEPTION_IF_NULL(init_mpi_);
init_mpi_();
inited_ = true;
MS_LOG(INFO) << "InitCollective success";
return true;
}
HcclComm HcclCollectiveGroup::GetGroupComm(const std::string &name) {
MS_EXCEPTION_IF_NULL(get_group_comm_);
return get_group_comm_(name);
}
int HcclCollectiveGroup::GetRankSize(const std::string &name) const {
MS_EXCEPTION_IF_NULL(get_group_size_);
return get_group_size_(name);
}
int HcclCollectiveGroup::GetRankId(const std::string &name) const {
MS_EXCEPTION_IF_NULL(get_rank_id_by_group_);
return get_rank_id_by_group_(name);
}
int HcclCollectiveGroup::GetDeviceId() const {
MS_EXCEPTION_IF_NULL(get_device_id_);
return get_device_id_();
}
void HcclCollectiveGroup::CreateCommGroup(const std::string &name, const std::vector<unsigned int> &ranks) {
MS_EXCEPTION_IF_NULL(create_comm_for_group_);
create_comm_for_group_(name, ranks);
}
void HcclCollectiveGroup::DestroyCommGroup() {
MS_EXCEPTION_IF_NULL(destroy_hccl_comm_);
destroy_hccl_comm_();
}
} // namespace collective
} // namespace ascend
} // namespace device
} // namespace mindspore
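
For orientation only (not part of this commit), a minimal C++ sketch of how the HcclCollectiveGroup facade above is meant to be driven; the helper SetupDistributedContext and the group name "group01" are illustrative:

#include "runtime/device/ascend/distribute/ascend_collective.h"

namespace {
using mindspore::device::ascend::collective::HcclCollectiveGroup;

void SetupDistributedContext() {
  auto &group = HcclCollectiveGroup::instance();
  // dlopen libascend_collective.so, resolve the wrapper symbols and initialise MPI via InitMPI().
  group.InitCollective();
  int device_id = group.GetDeviceId();  // local rank on this host
  int rank_id = group.GetRankId();      // rank in hccl_world_group
  int rank_size = group.GetRankSize();  // size of hccl_world_group
  // Build an HCCL communicator spanning global ranks 0 and 1.
  group.CreateCommGroup("group01", {0, 1});
  HcclComm comm = group.GetGroupComm("group01");
  (void)device_id;
  (void)rank_id;
  (void)rank_size;
  (void)comm;
}
}  // namespace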

View File

@@ -0,0 +1,74 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_DISTRIBUTE_ASCEND_COLLECTIVE_H
#define MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_DISTRIBUTE_ASCEND_COLLECTIVE_H
#include <dlfcn.h>
#include <vector>
#include <string>
#include <map>
#include "hccl/hccl_types.h"
#include "utils/utils.h"
#include "utils/dlopen_macro.h"
namespace mindspore {
namespace device {
namespace ascend {
namespace collective {
ORIGIN_METHOD(InitMPI, void);
ORIGIN_METHOD(FinalizeMPI, void);
ORIGIN_METHOD(GetGroupComm, HcclComm, const std::string &);
ORIGIN_METHOD(GetGroupSize, int, const std::string &);
ORIGIN_METHOD(GetRankIdByGroup, int, const std::string &);
ORIGIN_METHOD(GetDeviceId, int);
ORIGIN_METHOD(CreateCommForGroup, bool, const std::string &, const std::vector<unsigned int> &);
ORIGIN_METHOD(DestroyHcclComm, void);
class HcclCollectiveGroup {
public:
HcclCollectiveGroup(HcclCollectiveGroup const &) = delete;
HcclCollectiveGroup &operator=(const HcclCollectiveGroup &) = delete;
static HcclCollectiveGroup &instance();
bool InitCollective();
void FinalizeCollective();
HcclComm GetGroupComm(const std::string &name);
int GetDeviceId() const;
int GetRankId(const std::string &name = kHcclWorldGroup) const;
int GetRankSize(const std::string &name = kHcclWorldGroup) const;
void CreateCommGroup(const std::string &name, const std::vector<unsigned int> &ranks);
void DestroyCommGroup();
const void *collective_handle() const { return collective_handle_; }
private:
HcclCollectiveGroup() = default;
~HcclCollectiveGroup() = default;
bool inited_ = false;
void *collective_handle_ = nullptr;
InitMPIFunObj init_mpi_ = nullptr;
FinalizeMPIFunObj finalize_mpi_ = nullptr;
GetGroupCommFunObj get_group_comm_ = nullptr;
GetGroupSizeFunObj get_group_size_ = nullptr;
GetRankIdByGroupFunObj get_rank_id_by_group_ = nullptr;
GetDeviceIdFunObj get_device_id_ = nullptr;
CreateCommForGroupFunObj create_comm_for_group_ = nullptr;
DestroyHcclCommFunObj destroy_hccl_comm_ = nullptr;
};
} // namespace collective
} // namespace ascend
} // namespace device
} // namespace mindspore
#endif // MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_DISTRIBUTE_ASCEND_COLLECTIVE_H

View File

@@ -0,0 +1,28 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "runtime/device/ascend/distribute/collective_group_wrapper.h"
void InitMPI() { MPICollective::instance().Init(); }
void FinalizeMPI() { MPICollective::instance().FinalizeMPI(); }
int GetRankIdByGroup(const std::string &name) { return MPICollective::instance().GetRankIdByGroup(name); }
int GetGroupSize(const std::string &name) { return MPICollective::instance().GetGroupSize(name); }
int GetDeviceId() { return MPICollective::instance().GetDeviceId(); }
HcclComm GetGroupComm(const std::string &name) { return MPICollective::instance().GetGroupComm(name); }
bool CreateCommForGroup(const std::string &name, const std::vector<unsigned int> &ranks) {
return MPICollective::instance().CreateCommGroup(name, ranks);
}
void DestroyHcclComm() { MPICollective::instance().DestroyHcclComm(); }

View File

@@ -0,0 +1,36 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_DISTRIBUTE_COLLECTIVE_GROUP_WRAPPER_H
#define MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_DISTRIBUTE_COLLECTIVE_GROUP_WRAPPER_H
#include <vector>
#include <string>
#include "runtime/device/ascend/distribute/mpi_collective_group.h"
#ifndef EXPORT_WRAPPER
#define EXPORT_WRAPPER __attribute__((visibility("default")))
#endif
using MPICollective = mindspore::device::ascend::collective::MPICollective;
extern "C" EXPORT_WRAPPER void InitMPI();
extern "C" EXPORT_WRAPPER void FinalizeMPI();
extern "C" EXPORT_WRAPPER int GetRankIdByGroup(const std::string &name);
extern "C" EXPORT_WRAPPER int GetGroupSize(const std::string &name);
extern "C" EXPORT_WRAPPER int GetDeviceId();
extern "C" EXPORT_WRAPPER HcclComm GetGroupComm(const std::string &name);
extern "C" EXPORT_WRAPPER bool CreateCommForGroup(const std::string &name, const std::vector<unsigned int> &ranks);
extern "C" EXPORT_WRAPPER void DestroyHcclComm();
#endif // MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_DISTRIBUTE_COLLECTIVE_GROUP_WRAPPER_H

View File

@@ -0,0 +1,132 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "hccl/hccl.h"
#include "runtime/rt.h"
#include "runtime/device/ascend/distribute/mpi_collective_group.h"
namespace mindspore {
namespace device {
namespace ascend {
namespace collective {
MPICollective::MPICollective() : mpi_inited_(false), rank_id_(0), local_rank_id_(0), rank_size_(0) {}
void MPICollective::FinalizeMPI() {
group_info_.clear();
group_comm_.clear();
int finalized;
MPI_Finalized(&finalized);
if (finalized == 0) {
MPI_Finalize();
}
}
void MPICollective::DestroyHcclComm() {
for (auto &it : group_comm_) {
CHECK_RET(HcclCommDestroy(it.second), HCCL_SUCCESS, "HcclCommDestroy failed");
}
}
MPICollective &MPICollective::instance() {
static MPICollective instance;
return instance;
}
int MPICollective::GetRankIdByGroup(const std::string &name) {
CHECK_RET(group_info_.count(name), 1, "Failed to get MPI group rank by group name " + name);
return group_info_[name].first;
}
int MPICollective::GetGroupSize(const std::string &name) {
CHECK_RET(group_info_.count(name), 1, "Failed to get MPI group size by group name " + name);
return group_info_[name].second;
}
HcclComm MPICollective::GetGroupComm(const std::string &name) {
CHECK_RET(group_comm_.count(name), 1, "Failed to get MPI group comm by group name " + name);
return group_comm_[name];
}
int MPICollective::GetDeviceId() { return local_rank_id_; }
bool MPICollective::Init() {
int init_flag = 0;
CHECK_RET(MPI_Initialized(&init_flag), MPI_SUCCESS, "Check mpi initialized fail!");
if (init_flag == 0) {
CHECK_RET(MPI_Init(nullptr, nullptr), MPI_SUCCESS, "Failed to init mpi!");
}
CHECK_RET(MPI_Comm_group(MPI_COMM_WORLD, &comm_group_world_), MPI_SUCCESS, "comm_group_world_ init fail!");
CHECK_RET(MPI_Comm_rank(MPI_COMM_WORLD, &rank_id_), MPI_SUCCESS, "Failed to init mpi rank id!");
CHECK_RET(MPI_Comm_size(MPI_COMM_WORLD, &rank_size_), MPI_SUCCESS, "Failed to init mpi rank size!");
AssignLocalRankID();
group_info_["hccl_world_group"] = {rank_id_, rank_size_};
mpi_inited_ = true;
return true;
}
bool MPICollective::CreateCommGroup(const std::string &name, const std::vector<unsigned int> &ranks) {
CHECK_RET(mpi_inited_, true, "HcclCollectiveGroup has not been inited.");
CHECK_RET(ranks.empty(), false, "Ranks is empty.");
std::vector<int> group_ranks(ranks.begin(), ranks.end());
CHECK_RET(group_comm_.count(name), 0, "Group comm has already been created.");
CHECK_RET(rtSetDevice(local_rank_id_), RT_ERROR_NONE, "Call rtSetDevice error.");
HcclRootInfo rootInfo;
if (static_cast<size_t>(rank_id_) == ranks[0]) {
CHECK_RET(HcclGetRootInfo(&rootInfo), HCCL_SUCCESS, "HcclGetRootInfo failed.");
}
MPI_Group mpi_group = MPI_GROUP_NULL;
CHECK_RET(MPI_Group_incl(comm_group_world_, group_ranks.size(), group_ranks.data(), &mpi_group), MPI_SUCCESS,
"Create mpi group failed!");
MPI_Comm mpi_group_comm;
CHECK_RET(MPI_Comm_create(MPI_COMM_WORLD, mpi_group, &mpi_group_comm), MPI_SUCCESS, "Create mpi comm fail!");
CHECK_RET(MPI_Bcast(&rootInfo, sizeof(rootInfo), MPI_BYTE, 0, mpi_group_comm), MPI_SUCCESS,
"MPI_Bcast of HcclRootInfo failed!");
HcclComm group_hcomm = nullptr;
int group_rank[1];
int global_rank[1] = {rank_id_};
CHECK_RET(MPI_Group_translate_ranks(comm_group_world_, 1, global_rank, mpi_group, group_rank), MPI_SUCCESS,
"Failed to translate global rank to group rank.");
if (group_rank[0] == MPI_UNDEFINED) {
return false;
}
CHECK_RET(HcclCommInitRootInfo(ranks.size(), &rootInfo, static_cast<uint32_t>(group_rank[0]), &group_hcomm),
HCCL_SUCCESS, "HcclCommInitRootInfo failed.");
group_comm_[name] = group_hcomm;
group_info_[name] = {group_rank[0], static_cast<int>(ranks.size())};
return true;
}
void MPICollective::AssignLocalRankID() {
char host_name[MAX_HOSTNAME_LEN] = {0};
CHECK_RET(gethostname(host_name, MAX_HOSTNAME_LEN), MPI_SUCCESS, "Getting host name failed!");
size_t host_hash = std::hash<std::string>()(host_name);
const int kRankSize = rank_size_;
size_t all_host_hashs[kRankSize];
all_host_hashs[rank_id_] = host_hash;
CHECK_RET(MPI_Allgather(MPI_IN_PLACE, 0, MPI_DATATYPE_NULL, all_host_hashs, sizeof(size_t), MPI_BYTE, MPI_COMM_WORLD),
MPI_SUCCESS, "MPI_Allgather host hash failed.");
for (int global_rank = 0; global_rank < kRankSize; global_rank++) {
if (global_rank == rank_id_) {
break;
}
if (all_host_hashs[global_rank] == all_host_hashs[rank_id_]) {
local_rank_id_++;
}
}
return;
}
} // namespace collective
} // namespace ascend
} // namespace device
} // namespace mindspore
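
As a side note (not part of the diff), the rule in AssignLocalRankID can be read as: a process's local rank is the number of lower-numbered global ranks that share its host hash. A standalone sketch, with the helper name LocalRankFromHashes purely illustrative:

#include <cstddef>
#include <vector>

// Count how many lower-numbered global ranks landed on the same host.
int LocalRankFromHashes(const std::vector<size_t> &host_hashes, int my_rank) {
  int local_rank = 0;
  for (int rank = 0; rank < my_rank; ++rank) {
    if (host_hashes[rank] == host_hashes[my_rank]) {
      ++local_rank;
    }
  }
  return local_rank;
}
// Example: hashes {A, A, B, B, A} yield local ranks {0, 1, 0, 1, 2},
// matching the per-host device ids handed to rtSetDevice above.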

View File

@@ -0,0 +1,74 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_DISTRIBUTE_MPI_COLLECTIVE_INIT_H
#define MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_DISTRIBUTE_MPI_COLLECTIVE_INIT_H
#include <mpi.h>
#include <unistd.h>
#include <map>
#include <string>
#include <vector>
#include <utility>
#include <sstream>
#include "hccl/hccl_types.h"
#include "pybind11/pybind11.h"
namespace mindspore {
namespace device {
namespace ascend {
namespace collective {
constexpr int MAX_HOSTNAME_LEN = 1024;
class MPICollective {
public:
MPICollective(MPICollective const &) = delete;
MPICollective &operator=(const MPICollective &) = delete;
static MPICollective &instance();
void AssignLocalRankID();
bool Init();
void FinalizeMPI();
int GetRankIdByGroup(const std::string &name);
int GetGroupSize(const std::string &name);
HcclComm GetGroupComm(const std::string &name);
int GetDeviceId();
bool CreateCommGroup(const std::string &name, const std::vector<unsigned int> &ranks);
void DestroyHcclComm();
private:
MPICollective();
~MPICollective() = default;
bool mpi_inited_;
int rank_id_;
int local_rank_id_;
int rank_size_;
MPI_Group comm_group_world_;
std::map<std::string, std::pair<int, int>> group_info_;
std::map<std::string, HcclComm> group_comm_;
};
#define CHECK_RET(expression, result, message) \
{ \
auto ret = (expression); \
if (ret != result) { \
std::ostringstream oss; \
oss << "Error in file " << __FILE__ << " | Error on line " << __LINE__ \
<< " | Ascend collective Error: " << message << " | Error Number " << ret; \
pybind11::pybind11_fail(oss.str()); \
} \
}
} // namespace collective
} // namespace ascend
} // namespace device
} // namespace mindspore
#endif // MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_DISTRIBUTE_MPI_COLLECTIVE_INIT_H

View File

@@ -0,0 +1,41 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "runtime/device/ascend/distribute/mpi_pycc.h"
#include <pybind11/operators.h>
namespace mindspore {
namespace device {
namespace ascend {
namespace collective {
MpiPycc &MpiPycc::instance() {
static MpiPycc instance;
return instance;
}
int MpiPycc::GetDeviceID() { return GetDeviceId(); }
int MpiPycc::GetRankId(const std::string &group) { return GetRankIdByGroup(group); }
int MpiPycc::GetRankSize(const std::string &group) { return GetGroupSize(group); }
// cppcheck-suppress syntaxError
PYBIND11_MODULE(_ascend_mpi, mpi_initializer) {
mpi_initializer.def("get_device_id", &MpiPycc::GetDeviceID, "get device id");
mpi_initializer.def("get_rank_id", &MpiPycc::GetRankId, "get rank id");
mpi_initializer.def("get_rank_size", &MpiPycc::GetRankSize, "get rank size");
}
} // namespace collective
} // namespace ascend
} // namespace device
} // namespace mindspore

View File

@@ -0,0 +1,44 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_DISTRIBUTE_MPI_PYCC_H
#define MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_DISTRIBUTE_MPI_PYCC_H
#include <string>
#include "runtime/device/ascend/distribute/collective_group_wrapper.h"
namespace mindspore {
namespace device {
namespace ascend {
namespace collective {
class MpiPycc {
public:
MpiPycc(MpiPycc const &) = delete;
MpiPycc &operator=(const MpiPycc &) = delete;
static MpiPycc &instance();
static int GetDeviceID();
static int GetRankId(const std::string &group);
static int GetRankSize(const std::string &group);
private:
MpiPycc() = default;
~MpiPycc() = default;
};
} // namespace collective
} // namespace ascend
} // namespace device
} // namespace mindspore
#endif // MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_DISTRIBUTE_MPI_PYCC_H

View File

@@ -32,22 +32,6 @@ static constexpr const char *kHcclPluginFileName = "libhccl_plugin.so";
static constexpr const char *kHcclDeployModeEnv = "DEPLOY_MODE";
static constexpr const char *kHcclAlgoEnv = "HCCL_ALGO";
inline static std::string GetDlErrorMsg() {
const char *result = dlerror();
return (result == nullptr) ? "Unknown" : result;
}
template <class T>
static T DlsymWithCast(void *handle, const char *symbol_name) {
T symbol = reinterpret_cast<T>(dlsym(handle, symbol_name));
if (symbol == nullptr) {
MS_LOG(EXCEPTION) << "Dlsym symbol " << symbol_name << " failed, result = " << GetDlErrorMsg();
}
return symbol;
}
#define DlsymFuncObj(func_name) DlsymWithCast<func_name##FunPtr>(plugin_handle_, k##func_name##Name);
static std::map<std::string, std::string> GenHcclOptions(uint32_t device_id, std::string_view rank_id,
std::string_view rank_file) {
auto env_deploy_mode = mindspore::common::GetEnv(kHcclDeployModeEnv);
@@ -92,21 +76,21 @@ void HcclAdapter::InitPlugin() {
MS_LOG(EXCEPTION) << "Dlopen " << kHcclPluginFileName << " failed, result = " << GetDlErrorMsg();
}
init_hcom_graph_adapter_ = DlsymFuncObj(InitHcomGraphAdapter);
finalize_hcom_graph_adapter_ = DlsymFuncObj(FinalizeHcomGraphAdapter);
get_hccl_kernel_info_store_ = DlsymFuncObj(GetHcclKernelInfoStore);
get_all_kernel_builder_ = DlsymFuncObj(GetAllKernelBuilder);
init_hccl_comm_ = DlsymFuncObj(HcclCommInitClusterInfo);
finalize_hccl_comm_ = DlsymFuncObj(HcclCommDestroy);
launch_hccl_broadcast_ = DlsymFuncObj(HcclBroadcast);
launch_hccl_all_reduce_ = DlsymFuncObj(HcclAllReduce);
hccl_create_group_ = DlsymFuncObj(HcomCreateGroup);
hccl_destroy_group_ = DlsymFuncObj(HcomDestroyGroup);
hccl_get_rank_id_ = DlsymFuncObj(HcomGetRankId);
hccl_get_rank_size_ = DlsymFuncObj(HcomGetRankSize);
hccl_exec_initialize_ = DlsymFuncObj(HcomExecInitialize);
hccl_exec_finalize_ = DlsymFuncObj(HcomExecFinalize);
hccl_exec_enqueue_op_ = DlsymFuncObj(HcomExecEnqueueOperation);
init_hcom_graph_adapter_ = DlsymFuncObj(InitHcomGraphAdapter, plugin_handle_);
finalize_hcom_graph_adapter_ = DlsymFuncObj(FinalizeHcomGraphAdapter, plugin_handle_);
get_hccl_kernel_info_store_ = DlsymFuncObj(GetHcclKernelInfoStore, plugin_handle_);
get_all_kernel_builder_ = DlsymFuncObj(GetAllKernelBuilder, plugin_handle_);
init_hccl_comm_ = DlsymFuncObj(HcclCommInitClusterInfo, plugin_handle_);
finalize_hccl_comm_ = DlsymFuncObj(HcclCommDestroy, plugin_handle_);
launch_hccl_broadcast_ = DlsymFuncObj(HcclBroadcast, plugin_handle_);
launch_hccl_all_reduce_ = DlsymFuncObj(HcclAllReduce, plugin_handle_);
hccl_create_group_ = DlsymFuncObj(HcomCreateGroup, plugin_handle_);
hccl_destroy_group_ = DlsymFuncObj(HcomDestroyGroup, plugin_handle_);
hccl_get_rank_id_ = DlsymFuncObj(HcomGetRankId, plugin_handle_);
hccl_get_rank_size_ = DlsymFuncObj(HcomGetRankSize, plugin_handle_);
hccl_exec_initialize_ = DlsymFuncObj(HcomExecInitialize, plugin_handle_);
hccl_exec_finalize_ = DlsymFuncObj(HcomExecFinalize, plugin_handle_);
hccl_exec_enqueue_op_ = DlsymFuncObj(HcomExecEnqueueOperation, plugin_handle_);
}
void HcclAdapter::FinalizePlugin() {

View File

@@ -22,6 +22,7 @@
#include <functional>
#include "external/ge/ge_api_types.h"
#include "hccl/hccl.h"
#include "utils/dlopen_macro.h"
constexpr const char *kHcclOpsKernelInfoStore = "ops_kernel_info_hccl";
@@ -38,22 +39,6 @@ using OptionsType = std::map<std::string, std::string>;
using OpsKernelBuilderMap = std::map<std::string, std::shared_ptr<ge::OpsKernelBuilder>>;
using HExecCallBack = std::function<void(HcclResult)>;
#define PLUGIN_METHOD(name, return_type, params...) \
extern "C" { \
__attribute__((visibility("default"))) return_type Plugin##name(params); \
} \
constexpr const char *k##name##Name = "Plugin" #name; \
using name##FunObj = std::function<return_type(params)>; \
using name##FunPtr = return_type (*)(params);
#define ORIGIN_METHOD(name, return_type, params...) \
extern "C" { \
return_type name(params); \
} \
constexpr const char *k##name##Name = #name; \
using name##FunObj = std::function<return_type(params)>; \
using name##FunPtr = return_type (*)(params);
PLUGIN_METHOD(InitHcomGraphAdapter, ge::Status, const OptionsType &);
PLUGIN_METHOD(FinalizeHcomGraphAdapter, ge::Status);
PLUGIN_METHOD(GetHcclKernelInfoStore, void, std::shared_ptr<ge::OpsKernelInfoStore> *);

View File

@@ -0,0 +1,55 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_UTILS_DLOPEN_MACRO_H
#define MINDSPORE_CCSRC_UTILS_DLOPEN_MACRO_H
#include <dlfcn.h>
#include <string>
#include <functional>
#include "utils/log_adapter.h"
#define PLUGIN_METHOD(name, return_type, params...) \
extern "C" { \
__attribute__((visibility("default"))) return_type Plugin##name(params); \
} \
constexpr const char *k##name##Name = "Plugin" #name; \
using name##FunObj = std::function<return_type(params)>; \
using name##FunPtr = return_type (*)(params);
#define ORIGIN_METHOD(name, return_type, params...) \
extern "C" { \
return_type name(params); \
} \
constexpr const char *k##name##Name = #name; \
using name##FunObj = std::function<return_type(params)>; \
using name##FunPtr = return_type (*)(params);
inline static std::string GetDlErrorMsg() {
const char *result = dlerror();
return (result == nullptr) ? "Unknown" : result;
}
template <class T>
static T DlsymWithCast(void *handle, const char *symbol_name) {
T symbol = reinterpret_cast<T>(dlsym(handle, symbol_name));
if (symbol == nullptr) {
MS_LOG(EXCEPTION) << "Dlsym symbol " << symbol_name << " failed, result = " << GetDlErrorMsg();
}
return symbol;
}
#define DlsymFuncObj(func_name, plugin_handle) DlsymWithCast<func_name##FunPtr>(plugin_handle, k##func_name##Name);
#endif // MINDSPORE_CCSRC_UTILS_DLOPEN_MACRO_H
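
To make the macro trio concrete, a hedged sketch of the intended usage pattern; libfoo.so and FooAdd are placeholders, not MindSpore symbols:

#include "utils/dlopen_macro.h"

// Declares extern "C" int FooAdd(int, int); plus kFooAddName, FooAddFunObj and FooAddFunPtr.
ORIGIN_METHOD(FooAdd, int, int, int);

void LoadAndCallFooAdd() {
  void *handle = dlopen("libfoo.so", RTLD_NOW);
  if (handle == nullptr) {
    MS_LOG(EXCEPTION) << "Dlopen libfoo.so failed, result = " << GetDlErrorMsg();
  }
  // Resolve the "FooAdd" symbol and wrap it in a std::function.
  FooAddFunObj foo_add = DlsymFuncObj(FooAdd, handle);
  int sum = foo_add(1, 2);  // call through the resolved pointer
  MS_LOG(INFO) << "FooAdd(1, 2) = " << sum;
  if (dlclose(handle) != 0) {
    MS_LOG(ERROR) << "Dlclose libfoo.so failed, result = " << GetDlErrorMsg();
  }
}

This mirrors what HcclAdapter::InitPlugin does with libhccl_plugin.so and what HcclCollectiveGroup::InitCollective does with libascend_collective.so in the files above.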