ascend support mpi
This commit is contained in:
parent
2ae65ae387
commit
def6af8158
|
@ -17,6 +17,8 @@ endif()
|
|||
|
||||
if(ENABLE_D)
|
||||
file(GLOB_RECURSE D_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "ascend/*.cc" "kernel_adjust.cc")
|
||||
list(REMOVE_ITEM D_SRC_LIST "ascend/distribute/mpi_collective_group.cc"
|
||||
"ascend/distribute/collective_group_wrapper.cc" "ascend/distribute/mpi_pycc.cc")
|
||||
endif()
|
||||
if(ENABLE_TDTQUE)
|
||||
file(GLOB_RECURSE TDT_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
|
||||
|
@ -43,6 +45,13 @@ if(ENABLE_MPI)
|
|||
pybind11_add_module(_ms_mpi "gpu/mpi/mpi_initializer.cc")
|
||||
target_link_libraries(_ms_mpi PRIVATE mindspore::pybind11_module mindspore::ompi)
|
||||
endif()
|
||||
|
||||
if(ENABLE_D)
|
||||
set_property(SOURCE "ascend/distribute/mpi_pycc.cc"
|
||||
PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_DEVICE)
|
||||
pybind11_add_module(_ascend_mpi "ascend/distribute/mpi_pycc.cc")
|
||||
target_link_libraries(_ascend_mpi PRIVATE mindspore::pybind11_module mindspore::ompi)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
# gpu
|
||||
|
@ -84,4 +93,17 @@ if(ENABLE_D)
|
|||
set_property(SOURCE ${GE_SRC_LIST} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_GE)
|
||||
target_include_directories(_mindspore_runtime_device_obj PRIVATE ${CMAKE_BINARY_DIR}/proto/ge)
|
||||
add_dependencies(_mindspore_runtime_device_obj graph)
|
||||
if(ENABLE_MPI)
|
||||
set(ASCEND_PATH /usr/local/Ascend)
|
||||
set(ASCEND_TOOLKIT_RUNTIME_PATH ${ASCEND_PATH}/ascend-toolkit/latest/fwkacllib/lib64)
|
||||
set(ASCEND_RUNTIME_PATH ${ASCEND_PATH}/fwkacllib/lib64)
|
||||
find_library(HCCL hccl ${ASCEND_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH})
|
||||
file(GLOB_RECURSE ASCEND_COLLECTIVE_LIST "ascend/distribute/mpi_collective_group.cc"
|
||||
"ascend/distribute/collective_group_wrapper.cc")
|
||||
set_property(SOURCE ${ASCEND_COLLECTIVE_LIST}
|
||||
PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_DEVICE)
|
||||
add_library(ascend_collective SHARED ${ASCEND_COLLECTIVE_LIST})
|
||||
target_link_libraries(ascend_collective PRIVATE ${HCCL} mindspore::ompi)
|
||||
target_link_libraries(_ascend_mpi PRIVATE ascend_collective)
|
||||
endif()
|
||||
endif()
|
||||
|
|
|
@ -0,0 +1,94 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "runtime/device/ascend/distribute/ascend_collective.h"
#include "utils/log_adapter.h"

// Shared library exporting the MPI/HCCL collective entry points declared with
// ORIGIN_METHOD in ascend_collective.h; loaded lazily via dlopen.
static constexpr const char *kAscendCollectiveFileName = "libascend_collective.so";
namespace mindspore {
namespace device {
namespace ascend {
namespace collective {
// Process-wide singleton (Meyers singleton, thread-safe since C++11).
HcclCollectiveGroup &HcclCollectiveGroup::instance() {
  static HcclCollectiveGroup instance;
  return instance;
}

// Finalizes MPI through the plugin and unloads the shared library.
// Safe to call when the plugin was never loaded. After a successful call the
// object is back in its initial state, so InitCollective() may run again.
void HcclCollectiveGroup::FinalizeCollective() {
  MS_LOG(INFO) << "Finalize Collective";
  if (collective_handle_ != nullptr) {
    MS_EXCEPTION_IF_NULL(finalize_mpi_);
    finalize_mpi_();
    if (dlclose(collective_handle_) != 0) {
      MS_LOG(EXCEPTION) << "Closing libascend_collective.so handle failed.";
    }
    // Drop every symbol resolved from the now-closed library so that a second
    // FinalizeCollective() cannot dlclose twice and no dangling function
    // pointer can be invoked afterwards.
    collective_handle_ = nullptr;
    init_mpi_ = nullptr;
    finalize_mpi_ = nullptr;
    get_group_comm_ = nullptr;
    get_group_size_ = nullptr;
    get_rank_id_by_group_ = nullptr;
    get_device_id_ = nullptr;
    create_comm_for_group_ = nullptr;
    destroy_hccl_comm_ = nullptr;
    inited_ = false;
  }
}

// Loads libascend_collective.so, resolves its exported symbols and calls the
// plugin's InitMPI(). Idempotent: a second call is a no-op returning true.
// Raises an exception when the library or a symbol cannot be resolved.
bool HcclCollectiveGroup::InitCollective() {
  MS_LOG(INFO) << "InitCollective";
  if (inited_) {
    return true;
  }
  collective_handle_ = dlopen(kAscendCollectiveFileName, RTLD_NOW);
  if (collective_handle_ == nullptr) {
    MS_LOG(EXCEPTION)
      << "Loading libascend_collective.so failed. Many reasons could cause this:\n1.libascend_collective.so is not "
         "installed.\n2.hccl is not "
         "installed or found.\n3.mpi is not installed or found";
  }
  init_mpi_ = DlsymFuncObj(InitMPI, collective_handle_);
  finalize_mpi_ = DlsymFuncObj(FinalizeMPI, collective_handle_);
  get_group_comm_ = DlsymFuncObj(GetGroupComm, collective_handle_);
  get_group_size_ = DlsymFuncObj(GetGroupSize, collective_handle_);
  get_rank_id_by_group_ = DlsymFuncObj(GetRankIdByGroup, collective_handle_);
  get_device_id_ = DlsymFuncObj(GetDeviceId, collective_handle_);
  create_comm_for_group_ = DlsymFuncObj(CreateCommForGroup, collective_handle_);
  destroy_hccl_comm_ = DlsymFuncObj(DestroyHcclComm, collective_handle_);
  MS_EXCEPTION_IF_NULL(init_mpi_);
  init_mpi_();
  inited_ = true;
  MS_LOG(INFO) << "InitCollective success";
  return true;
}

// Returns the HCCL communicator created for group `name` by the plugin.
HcclComm HcclCollectiveGroup::GetGroupComm(const std::string &name) {
  MS_EXCEPTION_IF_NULL(get_group_comm_);
  return get_group_comm_(name);
}

// Returns the number of ranks in group `name`.
int HcclCollectiveGroup::GetRankSize(const std::string &name) const {
  MS_EXCEPTION_IF_NULL(get_group_size_);
  return get_group_size_(name);
}

// Returns this process's rank within group `name`.
int HcclCollectiveGroup::GetRankId(const std::string &name) const {
  MS_EXCEPTION_IF_NULL(get_rank_id_by_group_);
  return get_rank_id_by_group_(name);
}

// Returns the local (per-host) device id assigned by the plugin.
int HcclCollectiveGroup::GetDeviceId() const {
  MS_EXCEPTION_IF_NULL(get_device_id_);
  return get_device_id_();
}

// Creates an HCCL communication group `name` over the given world ranks.
void HcclCollectiveGroup::CreateCommGroup(const std::string &name, const std::vector<unsigned int> &ranks) {
  MS_EXCEPTION_IF_NULL(create_comm_for_group_);
  create_comm_for_group_(name, ranks);
}

// Destroys every HCCL communicator held by the plugin.
void HcclCollectiveGroup::DestroyCommGroup() {
  MS_EXCEPTION_IF_NULL(destroy_hccl_comm_);
  destroy_hccl_comm_();
}
}  // namespace collective
}  // namespace ascend
}  // namespace device
}  // namespace mindspore
|
|
@ -0,0 +1,74 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_DISTRIBUTE_ASCEND_COLLECTIVE_H
#define MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_DISTRIBUTE_ASCEND_COLLECTIVE_H

#include <dlfcn.h>
#include <vector>
#include <string>
#include <map>
#include "hccl/hccl_types.h"
#include "utils/utils.h"
#include "utils/dlopen_macro.h"
namespace mindspore {
namespace device {
namespace ascend {
namespace collective {

// Declarations of the C entry points exported by libascend_collective.so.
// ORIGIN_METHOD (see utils/dlopen_macro.h) also generates k<Name>Name string
// constants and <Name>FunObj/<Name>FunPtr types used for dlsym-based binding.
ORIGIN_METHOD(InitMPI, void);
ORIGIN_METHOD(FinalizeMPI, void);
ORIGIN_METHOD(GetGroupComm, HcclComm, const std::string &);
ORIGIN_METHOD(GetGroupSize, int, const std::string &);
ORIGIN_METHOD(GetRankIdByGroup, int, const std::string &);
ORIGIN_METHOD(GetDeviceId, int);
ORIGIN_METHOD(CreateCommForGroup, bool, const std::string &, const std::vector<unsigned int> &);
ORIGIN_METHOD(DestroyHcclComm, void);

// Singleton facade over libascend_collective.so: loads the plugin at runtime
// and forwards group/rank queries and group creation to it, so the core
// runtime has no hard link-time dependency on MPI.
class HcclCollectiveGroup {
 public:
  HcclCollectiveGroup(HcclCollectiveGroup const &) = delete;
  HcclCollectiveGroup &operator=(const HcclCollectiveGroup &) = delete;
  // Process-wide accessor.
  static HcclCollectiveGroup &instance();
  // dlopen the plugin and run its InitMPI; idempotent.
  bool InitCollective();
  // Finalize MPI through the plugin and dlclose it.
  void FinalizeCollective();
  // HCCL communicator previously created for `name`.
  HcclComm GetGroupComm(const std::string &name);
  // Local (per-host) device id computed by the plugin.
  int GetDeviceId() const;
  // Rank of this process within `name` (defaults to the world group).
  int GetRankId(const std::string &name = kHcclWorldGroup) const;
  // Number of ranks in `name` (defaults to the world group).
  int GetRankSize(const std::string &name = kHcclWorldGroup) const;
  // Create an HCCL group `name` over the given world ranks.
  void CreateCommGroup(const std::string &name, const std::vector<unsigned int> &ranks);
  // Destroy all HCCL communicators held by the plugin.
  void DestroyCommGroup();
  // Raw dlopen handle; nullptr when the plugin is not loaded.
  const void *collective_handle() const { return collective_handle_; }

 private:
  HcclCollectiveGroup() = default;
  ~HcclCollectiveGroup() = default;
  bool inited_ = false;                                   // set by InitCollective()
  void *collective_handle_ = nullptr;                     // dlopen handle of the plugin
  // Function objects resolved from the plugin via dlsym; null until inited.
  InitMPIFunObj init_mpi_ = nullptr;
  FinalizeMPIFunObj finalize_mpi_ = nullptr;
  GetGroupCommFunObj get_group_comm_ = nullptr;
  GetGroupSizeFunObj get_group_size_ = nullptr;
  GetRankIdByGroupFunObj get_rank_id_by_group_ = nullptr;
  GetDeviceIdFunObj get_device_id_ = nullptr;
  CreateCommForGroupFunObj create_comm_for_group_ = nullptr;
  DestroyHcclCommFunObj destroy_hccl_comm_ = nullptr;
};
}  // namespace collective
}  // namespace ascend
}  // namespace device
}  // namespace mindspore
#endif  // MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_DISTRIBUTE_ASCEND_COLLECTIVE_H
|
|
@ -0,0 +1,28 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "runtime/device/ascend/distribute/collective_group_wrapper.h"

// Thin extern "C" wrappers around the MPICollective singleton. These are the
// symbols HcclCollectiveGroup resolves via dlsym from libascend_collective.so,
// so their names and signatures must match the ORIGIN_METHOD declarations in
// ascend_collective.h exactly.
void InitMPI() { MPICollective::instance().Init(); }
void FinalizeMPI() { MPICollective::instance().FinalizeMPI(); }
int GetRankIdByGroup(const std::string &name) { return MPICollective::instance().GetRankIdByGroup(name); }
int GetGroupSize(const std::string &name) { return MPICollective::instance().GetGroupSize(name); }
int GetDeviceId() { return MPICollective::instance().GetDeviceId(); }
HcclComm GetGroupComm(const std::string &name) { return MPICollective::instance().GetGroupComm(name); }
bool CreateCommForGroup(const std::string &name, const std::vector<unsigned int> &ranks) {
  return MPICollective::instance().CreateCommGroup(name, ranks);
}
void DestroyHcclComm() { MPICollective::instance().DestroyHcclComm(); }
|
|
@ -0,0 +1,36 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_DISTRIBUTE_COLLECTIVE_GROUP_WRAPPER_H
#define MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_DISTRIBUTE_COLLECTIVE_GROUP_WRAPPER_H

#include <vector>
#include <string>
#include "runtime/device/ascend/distribute/mpi_collective_group.h"
// Exported with default visibility so the symbols stay visible even when the
// library is built with -fvisibility=hidden; they are looked up via dlsym.
#ifndef EXPORT_WRAPPER
#define EXPORT_WRAPPER __attribute__((visibility("default")))
#endif
using MPICollective = mindspore::device::ascend::collective::MPICollective;

// C-linkage entry points of libascend_collective.so; each forwards to the
// MPICollective singleton (see collective_group_wrapper.cc). Signatures must
// stay in sync with the ORIGIN_METHOD declarations in ascend_collective.h.
extern "C" EXPORT_WRAPPER void InitMPI();
extern "C" EXPORT_WRAPPER void FinalizeMPI();
extern "C" EXPORT_WRAPPER int GetRankIdByGroup(const std::string &name);
extern "C" EXPORT_WRAPPER int GetGroupSize(const std::string &name);
extern "C" EXPORT_WRAPPER int GetDeviceId();
extern "C" EXPORT_WRAPPER HcclComm GetGroupComm(const std::string &name);
extern "C" EXPORT_WRAPPER bool CreateCommForGroup(const std::string &name, const std::vector<unsigned int> &ranks);
extern "C" EXPORT_WRAPPER void DestroyHcclComm();
#endif  // MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_DISTRIBUTE_COLLECTIVE_GROUP_WRAPPER_H
|
|
@ -0,0 +1,132 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "hccl/hccl.h"
#include "runtime/rt.h"
#include "runtime/device/ascend/distribute/mpi_collective_group.h"
namespace mindspore {
namespace device {
namespace ascend {
namespace collective {
MPICollective::MPICollective() : mpi_inited_(false), rank_id_(0), local_rank_id_(0), rank_size_(0) {}

// Drops cached group metadata and finalizes MPI unless it was already
// finalized (e.g. by another component).
void MPICollective::FinalizeMPI() {
  group_info_.clear();
  group_comm_.clear();
  int finalized;
  MPI_Finalized(&finalized);
  if (finalized == 0) {
    MPI_Finalize();
  }
}

// Destroys every HCCL communicator created by CreateCommGroup and forgets
// them, so a second call is a harmless no-op instead of a double destroy.
void MPICollective::DestroyHcclComm() {
  for (auto &it : group_comm_) {
    CHECK_RET(HcclCommDestroy(it.second), HCCL_SUCCESS, "HcclCommDestroy failed");
  }
  group_comm_.clear();
}

// Process-wide singleton (Meyers singleton, thread-safe since C++11).
MPICollective &MPICollective::instance() {
  static MPICollective instance;
  return instance;
}

// Rank of this process inside group `name`; fails if the group is unknown.
int MPICollective::GetRankIdByGroup(const std::string &name) {
  CHECK_RET(group_info_.count(name), 1, "Failed to get MPI group rank by group name " + name);
  return group_info_[name].first;
}

// Number of ranks in group `name`; fails if the group is unknown.
int MPICollective::GetGroupSize(const std::string &name) {
  CHECK_RET(group_info_.count(name), 1, "Failed to get MPI group size by group name " + name);
  return group_info_[name].second;
}

// HCCL communicator for group `name`; fails if the group is unknown.
HcclComm MPICollective::GetGroupComm(const std::string &name) {
  CHECK_RET(group_comm_.count(name), 1, "Failed to get MPI group comm by group name " + name);
  return group_comm_[name];
}

// Local device id == this rank's index among ranks on the same host.
int MPICollective::GetDeviceId() { return local_rank_id_; }

// Initializes MPI (if needed), records world rank/size, derives the local
// rank id from host names, and registers the implicit world group.
bool MPICollective::Init() {
  int init_flag = 0;
  CHECK_RET(MPI_Initialized(&init_flag), MPI_SUCCESS, "Check mpi initialized fail!");
  if (init_flag == 0) {
    CHECK_RET(MPI_Init(nullptr, nullptr), MPI_SUCCESS, "Failed to init mpi!");
  }

  CHECK_RET(MPI_Comm_group(MPI_COMM_WORLD, &comm_group_world_), MPI_SUCCESS, "comm_group_world_ init fail!");

  CHECK_RET(MPI_Comm_rank(MPI_COMM_WORLD, &rank_id_), MPI_SUCCESS, "Failed to init mpi rank id!");

  CHECK_RET(MPI_Comm_size(MPI_COMM_WORLD, &rank_size_), MPI_SUCCESS, "Failed to init mpi rank size!");
  AssignLocalRankID();
  group_info_["hccl_world_group"] = {rank_id_, rank_size_};
  mpi_inited_ = true;
  return true;
}

// Creates an HCCL communicator for `name` over the given world ranks.
// The HCCL root info is generated on ranks[0] and distributed to the other
// members over a temporary MPI communicator. Returns false on ranks that are
// not part of the group. Must be called collectively on all world ranks.
bool MPICollective::CreateCommGroup(const std::string &name, const std::vector<unsigned int> &ranks) {
  CHECK_RET(mpi_inited_, true, "HcclCollectiveGroup has not been inited.");
  CHECK_RET(ranks.empty(), false, "Ranks is empty.");
  std::vector<int> group_ranks(ranks.begin(), ranks.end());
  CHECK_RET(group_comm_.count(name), 0, "Group comm has already been created.");
  CHECK_RET(rtSetDevice(local_rank_id_), RT_ERROR_NONE, "Call rtSetDevice error.");
  HcclRootInfo rootInfo;
  if (static_cast<size_t>(rank_id_) == ranks[0]) {
    CHECK_RET(HcclGetRootInfo(&rootInfo), HCCL_SUCCESS, "HcclGetRootInfo failed.");
  }
  MPI_Group mpi_group = MPI_GROUP_NULL;
  CHECK_RET(MPI_Group_incl(comm_group_world_, static_cast<int>(group_ranks.size()), group_ranks.data(), &mpi_group),
            MPI_SUCCESS, "Create mpi group failed!");
  MPI_Comm mpi_group_comm;

  CHECK_RET(MPI_Comm_create(MPI_COMM_WORLD, mpi_group, &mpi_group_comm), MPI_SUCCESS, "Create mpi comm fail!");

  // MPI_Group_translate_ranks is a purely local operation, so membership is
  // checked BEFORE the broadcast: non-member ranks received MPI_COMM_NULL
  // from MPI_Comm_create and must not participate in MPI_Bcast.
  int group_rank[1];
  int global_rank[1] = {rank_id_};
  CHECK_RET(MPI_Group_translate_ranks(comm_group_world_, 1, global_rank, mpi_group, group_rank), MPI_SUCCESS,
            "Failed to translate global rank to group rank.");
  if (group_rank[0] == MPI_UNDEFINED) {
    (void)MPI_Group_free(&mpi_group);
    return false;
  }

  // Group rank 0 is ranks[0] (MPI_Comm_create orders by the group), i.e. the
  // rank that generated the root info above.
  CHECK_RET(MPI_Bcast(&rootInfo, sizeof(rootInfo), MPI_BYTE, 0, mpi_group_comm), MPI_SUCCESS,
            "Mpi broadcast root info failed!");

  HcclComm group_hcomm = nullptr;
  CHECK_RET(HcclCommInitRootInfo(ranks.size(), &rootInfo, static_cast<uint32_t>(group_rank[0]), &group_hcomm),
            HCCL_SUCCESS, "HcclCommInitRootInfo failed.");
  group_comm_[name] = group_hcomm;
  group_info_[name] = {group_rank[0], static_cast<int>(ranks.size())};
  // The temporary MPI group/comm were only needed to exchange the root info.
  (void)MPI_Group_free(&mpi_group);
  (void)MPI_Comm_free(&mpi_group_comm);
  return true;
}

// Computes local_rank_id_: the number of lower-ranked processes whose host
// name hashes to the same value as ours (i.e. our index on this host).
void MPICollective::AssignLocalRankID() {
  char host_name[MAX_HOSTNAME_LEN] = {0};
  // gethostname returns 0 on success, matching MPI_SUCCESS's value.
  CHECK_RET(gethostname(host_name, MAX_HOSTNAME_LEN), MPI_SUCCESS, "Getting host name failed!");
  size_t host_hash = std::hash<std::string>()(host_name);

  // std::vector instead of the original variable-length array, which is not
  // standard C++.
  std::vector<size_t> all_host_hashs(rank_size_);
  all_host_hashs[rank_id_] = host_hash;
  CHECK_RET(
    MPI_Allgather(MPI_IN_PLACE, 0, MPI_DATATYPE_NULL, all_host_hashs.data(), sizeof(size_t), MPI_BYTE, MPI_COMM_WORLD),
    MPI_SUCCESS, "MPI_Allgather host hash failed.");
  for (int global_rank = 0; global_rank < rank_size_; global_rank++) {
    if (global_rank == rank_id_) {
      break;
    }
    if (all_host_hashs[global_rank] == all_host_hashs[rank_id_]) {
      local_rank_id_++;
    }
  }
}
}  // namespace collective
}  // namespace ascend
}  // namespace device
}  // namespace mindspore
|
|
@ -0,0 +1,74 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_DISTRIBUTE_MPI_COLLECTIVE_INIT_H
#define MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_DISTRIBUTE_MPI_COLLECTIVE_INIT_H

#include <mpi.h>
#include <unistd.h>
#include <map>
#include <string>
#include <vector>
#include <utility>
#include <sstream>
#include "hccl/hccl_types.h"
#include "pybind11/pybind11.h"
namespace mindspore {
namespace device {
namespace ascend {
namespace collective {
// Buffer size passed to gethostname() in AssignLocalRankID().
constexpr int MAX_HOSTNAME_LEN = 1024;

// Singleton that bootstraps MPI, derives world/local rank information, and
// creates/owns HCCL communicators for named groups. Lives inside
// libascend_collective.so and is reached through the extern "C" wrappers in
// collective_group_wrapper.h.
class MPICollective {
 public:
  MPICollective(MPICollective const &) = delete;
  MPICollective &operator=(const MPICollective &) = delete;
  // Process-wide accessor.
  static MPICollective &instance();
  // Compute local_rank_id_ from per-host hostname hashes.
  void AssignLocalRankID();
  // Initialize MPI (if needed) and record rank/size; idempotent-safe entry.
  bool Init();
  // Clear group state and finalize MPI if not already finalized.
  void FinalizeMPI();
  // Rank of this process inside group `name`.
  int GetRankIdByGroup(const std::string &name);
  // Number of ranks in group `name`.
  int GetGroupSize(const std::string &name);
  // HCCL communicator previously created for group `name`.
  HcclComm GetGroupComm(const std::string &name);
  // Local (per-host) device id.
  int GetDeviceId();
  // Collectively create an HCCL communicator for `name` over world `ranks`.
  bool CreateCommGroup(const std::string &name, const std::vector<unsigned int> &ranks);
  // Destroy all HCCL communicators created so far.
  void DestroyHcclComm();

 private:
  MPICollective();
  ~MPICollective() = default;
  bool mpi_inited_;             // set once Init() completes
  int rank_id_;                 // rank in MPI_COMM_WORLD
  int local_rank_id_;           // index among ranks on this host
  int rank_size_;               // size of MPI_COMM_WORLD
  MPI_Group comm_group_world_;  // group of MPI_COMM_WORLD
  // group name -> {rank in group, group size}
  std::map<std::string, std::pair<int, int>> group_info_;
  // group name -> HCCL communicator
  std::map<std::string, HcclComm> group_comm_;
};
// Evaluates `expression` once and raises a pybind11 error (which throws)
// unless the value equals `result`. Wrapped in do { } while (0) so it expands
// to a single statement and is safe in unbraced if/else bodies.
#define CHECK_RET(expression, result, message)                                         \
  do {                                                                                 \
    auto ret = (expression);                                                           \
    if (ret != (result)) {                                                             \
      std::ostringstream oss;                                                          \
      oss << "Error in file " << __FILE__ << " | Error on line " << __LINE__           \
          << " | Ascend collective Error: " << (message) << " | Error Number " << ret; \
      pybind11::pybind11_fail(oss.str());                                              \
    }                                                                                  \
  } while (0)
}  // namespace collective
}  // namespace ascend
}  // namespace device
}  // namespace mindspore
#endif  // MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_DISTRIBUTE_MPI_COLLECTIVE_INIT_H
|
|
@ -0,0 +1,41 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "runtime/device/ascend/distribute/mpi_pycc.h"
#include <pybind11/operators.h>
namespace mindspore {
namespace device {
namespace ascend {
namespace collective {
// Process-wide singleton (Meyers singleton, thread-safe since C++11).
MpiPycc &MpiPycc::instance() {
  static MpiPycc instance;
  return instance;
}

// Static forwarders to the extern "C" wrappers declared in
// collective_group_wrapper.h (linked in via ascend_collective).
int MpiPycc::GetDeviceID() { return GetDeviceId(); }
int MpiPycc::GetRankId(const std::string &group) { return GetRankIdByGroup(group); }
int MpiPycc::GetRankSize(const std::string &group) { return GetGroupSize(group); }

// Python extension module _ascend_mpi exposing device/rank queries.
// cppcheck-suppress syntaxError
PYBIND11_MODULE(_ascend_mpi, mpi_initializer) {
  mpi_initializer.def("get_device_id", &MpiPycc::GetDeviceID, "get device id");
  mpi_initializer.def("get_rank_id", &MpiPycc::GetRankId, "get rank id");
  mpi_initializer.def("get_rank_size", &MpiPycc::GetRankSize, "get rank size");
}
}  // namespace collective
}  // namespace ascend
}  // namespace device
}  // namespace mindspore
|
|
@ -0,0 +1,44 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_DISTRIBUTE_MPI_PYCC_H
#define MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_DISTRIBUTE_MPI_PYCC_H

#include <string>
#include "runtime/device/ascend/distribute/collective_group_wrapper.h"

namespace mindspore {
namespace device {
namespace ascend {
namespace collective {
// Bridge class for the _ascend_mpi Python module: exposes device id and
// rank queries by forwarding to the extern "C" collective wrappers.
class MpiPycc {
 public:
  MpiPycc(MpiPycc const &) = delete;
  MpiPycc &operator=(const MpiPycc &) = delete;
  // Process-wide accessor.
  static MpiPycc &instance();
  // Local (per-host) device id.
  static int GetDeviceID();
  // Rank of this process inside `group`.
  static int GetRankId(const std::string &group);
  // Number of ranks in `group`.
  static int GetRankSize(const std::string &group);

 private:
  MpiPycc() = default;
  ~MpiPycc() = default;
};
}  // namespace collective
}  // namespace ascend
}  // namespace device
}  // namespace mindspore
#endif  // MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_DISTRIBUTE_MPI_PYCC_H
|
|
@ -32,22 +32,6 @@ static constexpr const char *kHcclPluginFileName = "libhccl_plugin.so";
|
|||
static constexpr const char *kHcclDeployModeEnv = "DEPLOY_MODE";
|
||||
static constexpr const char *kHcclAlgoEnv = "HCCL_ALGO";
|
||||
|
||||
inline static std::string GetDlErrorMsg() {
|
||||
const char *result = dlerror();
|
||||
return (result == nullptr) ? "Unknown" : result;
|
||||
}
|
||||
|
||||
template <class T>
|
||||
static T DlsymWithCast(void *handle, const char *symbol_name) {
|
||||
T symbol = reinterpret_cast<T>(dlsym(handle, symbol_name));
|
||||
if (symbol == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Dlsym symbol " << symbol_name << " failed, result = " << GetDlErrorMsg();
|
||||
}
|
||||
return symbol;
|
||||
}
|
||||
|
||||
#define DlsymFuncObj(func_name) DlsymWithCast<func_name##FunPtr>(plugin_handle_, k##func_name##Name);
|
||||
|
||||
static std::map<std::string, std::string> GenHcclOptions(uint32_t device_id, std::string_view rank_id,
|
||||
std::string_view rank_file) {
|
||||
auto env_deploy_mode = mindspore::common::GetEnv(kHcclDeployModeEnv);
|
||||
|
@ -92,21 +76,21 @@ void HcclAdapter::InitPlugin() {
|
|||
MS_LOG(EXCEPTION) << "Dlopen " << kHcclPluginFileName << " failed, result = " << GetDlErrorMsg();
|
||||
}
|
||||
|
||||
init_hcom_graph_adapter_ = DlsymFuncObj(InitHcomGraphAdapter);
|
||||
finalize_hcom_graph_adapter_ = DlsymFuncObj(FinalizeHcomGraphAdapter);
|
||||
get_hccl_kernel_info_store_ = DlsymFuncObj(GetHcclKernelInfoStore);
|
||||
get_all_kernel_builder_ = DlsymFuncObj(GetAllKernelBuilder);
|
||||
init_hccl_comm_ = DlsymFuncObj(HcclCommInitClusterInfo);
|
||||
finalize_hccl_comm_ = DlsymFuncObj(HcclCommDestroy);
|
||||
launch_hccl_broadcast_ = DlsymFuncObj(HcclBroadcast);
|
||||
launch_hccl_all_reduce_ = DlsymFuncObj(HcclAllReduce);
|
||||
hccl_create_group_ = DlsymFuncObj(HcomCreateGroup);
|
||||
hccl_destroy_group_ = DlsymFuncObj(HcomDestroyGroup);
|
||||
hccl_get_rank_id_ = DlsymFuncObj(HcomGetRankId);
|
||||
hccl_get_rank_size_ = DlsymFuncObj(HcomGetRankSize);
|
||||
hccl_exec_initialize_ = DlsymFuncObj(HcomExecInitialize);
|
||||
hccl_exec_finalize_ = DlsymFuncObj(HcomExecFinalize);
|
||||
hccl_exec_enqueue_op_ = DlsymFuncObj(HcomExecEnqueueOperation);
|
||||
init_hcom_graph_adapter_ = DlsymFuncObj(InitHcomGraphAdapter, plugin_handle_);
|
||||
finalize_hcom_graph_adapter_ = DlsymFuncObj(FinalizeHcomGraphAdapter, plugin_handle_);
|
||||
get_hccl_kernel_info_store_ = DlsymFuncObj(GetHcclKernelInfoStore, plugin_handle_);
|
||||
get_all_kernel_builder_ = DlsymFuncObj(GetAllKernelBuilder, plugin_handle_);
|
||||
init_hccl_comm_ = DlsymFuncObj(HcclCommInitClusterInfo, plugin_handle_);
|
||||
finalize_hccl_comm_ = DlsymFuncObj(HcclCommDestroy, plugin_handle_);
|
||||
launch_hccl_broadcast_ = DlsymFuncObj(HcclBroadcast, plugin_handle_);
|
||||
launch_hccl_all_reduce_ = DlsymFuncObj(HcclAllReduce, plugin_handle_);
|
||||
hccl_create_group_ = DlsymFuncObj(HcomCreateGroup, plugin_handle_);
|
||||
hccl_destroy_group_ = DlsymFuncObj(HcomDestroyGroup, plugin_handle_);
|
||||
hccl_get_rank_id_ = DlsymFuncObj(HcomGetRankId, plugin_handle_);
|
||||
hccl_get_rank_size_ = DlsymFuncObj(HcomGetRankSize, plugin_handle_);
|
||||
hccl_exec_initialize_ = DlsymFuncObj(HcomExecInitialize, plugin_handle_);
|
||||
hccl_exec_finalize_ = DlsymFuncObj(HcomExecFinalize, plugin_handle_);
|
||||
hccl_exec_enqueue_op_ = DlsymFuncObj(HcomExecEnqueueOperation, plugin_handle_);
|
||||
}
|
||||
|
||||
void HcclAdapter::FinalizePlugin() {
|
||||
|
|
|
@ -22,6 +22,7 @@
|
|||
#include <functional>
|
||||
#include "external/ge/ge_api_types.h"
|
||||
#include "hccl/hccl.h"
|
||||
#include "utils/dlopen_macro.h"
|
||||
|
||||
constexpr const char *kHcclOpsKernelInfoStore = "ops_kernel_info_hccl";
|
||||
|
||||
|
@ -38,22 +39,6 @@ using OptionsType = std::map<std::string, std::string>;
|
|||
using OpsKernelBuilderMap = std::map<std::string, std::shared_ptr<ge::OpsKernelBuilder>>;
|
||||
using HExecCallBack = std::function<void(HcclResult)>;
|
||||
|
||||
#define PLUGIN_METHOD(name, return_type, params...) \
|
||||
extern "C" { \
|
||||
__attribute__((visibility("default"))) return_type Plugin##name(params); \
|
||||
} \
|
||||
constexpr const char *k##name##Name = "Plugin" #name; \
|
||||
using name##FunObj = std::function<return_type(params)>; \
|
||||
using name##FunPtr = return_type (*)(params);
|
||||
|
||||
#define ORIGIN_METHOD(name, return_type, params...) \
|
||||
extern "C" { \
|
||||
return_type name(params); \
|
||||
} \
|
||||
constexpr const char *k##name##Name = #name; \
|
||||
using name##FunObj = std::function<return_type(params)>; \
|
||||
using name##FunPtr = return_type (*)(params);
|
||||
|
||||
PLUGIN_METHOD(InitHcomGraphAdapter, ge::Status, const OptionsType &);
|
||||
PLUGIN_METHOD(FinalizeHcomGraphAdapter, ge::Status);
|
||||
PLUGIN_METHOD(GetHcclKernelInfoStore, void, std::shared_ptr<ge::OpsKernelInfoStore> *);
|
||||
|
|
|
@ -0,0 +1,55 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_UTILS_DLOPEN_MACRO_H
#define MINDSPORE_CCSRC_UTILS_DLOPEN_MACRO_H
#include <dlfcn.h>
#include <string>
#include <functional>
#include "utils/log_adapter.h"

// Declares a plugin entry point named Plugin<name> with default visibility,
// plus the k<name>Name lookup string and <name>FunObj/<name>FunPtr types used
// for dlsym-based binding.
#define PLUGIN_METHOD(name, return_type, params...)                         \
  extern "C" {                                                              \
  __attribute__((visibility("default"))) return_type Plugin##name(params);  \
  }                                                                         \
  constexpr const char *k##name##Name = "Plugin" #name;                     \
  using name##FunObj = std::function<return_type(params)>;                  \
  using name##FunPtr = return_type (*)(params);

// Same as PLUGIN_METHOD, but for symbols exported under their original name
// (no "Plugin" prefix).
#define ORIGIN_METHOD(name, return_type, params...)  \
  extern "C" {                                       \
  return_type name(params);                          \
  }                                                  \
  constexpr const char *k##name##Name = #name;       \
  using name##FunObj = std::function<return_type(params)>;                  \
  using name##FunPtr = return_type (*)(params);

// Last dlopen/dlsym error message, or "Unknown" when dlerror() reports none.
inline static std::string GetDlErrorMsg() {
  const char *result = dlerror();
  return (result == nullptr) ? "Unknown" : result;
}

// Resolves `symbol_name` from `handle` and casts it to function-pointer type
// T; raises an exception if the symbol is absent.
template <class T>
static T DlsymWithCast(void *handle, const char *symbol_name) {
  T symbol = reinterpret_cast<T>(dlsym(handle, symbol_name));
  if (symbol == nullptr) {
    MS_LOG(EXCEPTION) << "Dlsym symbol " << symbol_name << " failed, result = " << GetDlErrorMsg();
  }
  return symbol;
}

// Resolves func_name from plugin_handle using the k<func_name>Name string and
// <func_name>FunPtr type generated by PLUGIN_METHOD/ORIGIN_METHOD.
// Note: no trailing semicolon — the macro is a plain expression, so it can be
// used inside conditions or initializers without an extra empty statement.
#define DlsymFuncObj(func_name, plugin_handle) DlsymWithCast<func_name##FunPtr>(plugin_handle, k##func_name##Name)
#endif  // MINDSPORE_CCSRC_UTILS_DLOPEN_MACRO_H
|
Loading…
Reference in New Issue