forked from mindspore-Ecosystem/mindspore

Dynamic load ompi and nccl

parent: 214d5dec46
commit: 87057fdc27
@@ -188,6 +188,11 @@ if(ENABLE_MPI)
    DESTINATION ${INSTALL_LIB_DIR}
    COMPONENT mindspore
)
install(
    TARGETS mpi_collective
    DESTINATION ${INSTALL_LIB_DIR}
    COMPONENT mindspore
)
endif()
if(ENABLE_D)
install(
@@ -205,6 +210,11 @@ if(ENABLE_GPU)
    DESTINATION ${INSTALL_LIB_DIR}
    COMPONENT mindspore
)
install(
    TARGETS nvidia_collective
    DESTINATION ${INSTALL_LIB_DIR}
    COMPONENT mindspore
)
endif()
install(
    TARGETS gpu_queue

@@ -142,6 +142,11 @@ if(ENABLE_MPI)
    DESTINATION ${INSTALL_LIB_DIR}
    COMPONENT mindspore
)
install(
    TARGETS mpi_collective
    DESTINATION ${INSTALL_LIB_DIR}
    COMPONENT mindspore
)
endif()
endif()

@@ -152,6 +157,11 @@ if(ENABLE_GPU)
    DESTINATION ${INSTALL_LIB_DIR}
    COMPONENT mindspore
)
install(
    TARGETS nvidia_collective
    DESTINATION ${INSTALL_LIB_DIR}
    COMPONENT mindspore
)
endif()
install(
    TARGETS gpu_queue

@@ -814,9 +814,9 @@ bool GraphExecutorPy::Compile(const py::object &source_obj, const py::tuple &arg
  } catch (const py::error_already_set &ex) {
    if (!StaticAnalysisException::Instance().HasException()) {
      // print function call stack info before release
      std::string exception_info = GetCompileExceptionInfo();
      if (!exception_info.empty()) {
        MS_LOG(ERROR) << exception_info;
      std::string compile_exception_info = GetCompileExceptionInfo();
      if (!compile_exception_info.empty()) {
        MS_LOG(ERROR) << compile_exception_info;
      }
    }
    ReleaseResource(phase);

@@ -7,11 +7,32 @@ endif()

if(ENABLE_GPU)
    file(GLOB_RECURSE HARDWARE_GPU_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "gpu/*.cc")
    list(REMOVE_ITEM HARDWARE_GPU_SRC_LIST "gpu/nvidia_collective_comm_lib.cc" "gpu/nvidia_communication_group.cc")
    if(ENABLE_MPI)
        set(NVIDIA_COLLECTIVE_SRCS "gpu/nvidia_collective_comm_lib.cc"
            "gpu/nvidia_communication_group.cc"
            "collective/collective_communication_lib.cc"
            "collective/communication_group.cc")
        set_property(SOURCE ${NVIDIA_COLLECTIVE_SRCS} PROPERTY COMPILE_DEFINITIONS
            SUBMODULE_ID=mindspore::SubModuleId::SM_DEVICE)
        add_library(nvidia_collective SHARED ${NVIDIA_COLLECTIVE_SRCS})
        target_link_libraries(nvidia_collective PRIVATE mindspore::nccl)
    endif()
endif()

if(ENABLE_CPU)
    file(GLOB_RECURSE HARDWARE_CPU_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "cpu/*.cc")
    list(REMOVE_ITEM HARDWARE_CPU_SRC_LIST "cpu/mpi_collective_comm_lib.cc" "cpu/mpi_communication_group.cc")
    if(ENABLE_MPI)
        set(MPI_COLLECTIVE_SRCS "cpu/mpi_collective_comm_lib.cc"
            "cpu/mpi_communication_group.cc"
            "collective/collective_communication_lib.cc"
            "collective/communication_group.cc")
        set_property(SOURCE ${MPI_COLLECTIVE_SRCS} PROPERTY COMPILE_DEFINITIONS
            SUBMODULE_ID=mindspore::SubModuleId::SM_DEVICE)
        add_library(mpi_collective SHARED ${MPI_COLLECTIVE_SRCS})
        target_link_libraries(mpi_collective PRIVATE mindspore::ompi)
    endif()
endif()

@@ -0,0 +1,54 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "runtime/hardware/collective/collective_comm_lib_loader.h"

namespace mindspore {
namespace device {
bool CollectiveCommLibLoader::Initialize() {
  std::string err_msg = "";
#ifndef _WIN32
  collective_comm_lib_ptr_ = dlopen(comm_lib_name_.c_str(), RTLD_LAZY);
  err_msg = GetDlErrorMsg();
#else
  collective_comm_lib_ptr_ = LoadLibrary(comm_lib_name_.c_str());
  err_msg = std::to_string(GetLastError());
#endif
  if (collective_comm_lib_ptr_ == nullptr) {
    MS_LOG(EXCEPTION) << "Loading " + comm_lib_name_ + " failed. Error: " + err_msg;
    return false;
  }
  return true;
}

bool CollectiveCommLibLoader::Finalize() {
  MS_EXCEPTION_IF_NULL(collective_comm_lib_ptr_);

#ifndef _WIN32
  if (dlclose(collective_comm_lib_ptr_) != 0) {
    MS_LOG(EXCEPTION) << "Closing " + comm_lib_name_ + " handle failed. Error: " + GetDlErrorMsg();
    return false;
  }
#else
  if (!FreeLibrary(reinterpret_cast<HINSTANCE__ *>(collective_comm_lib_ptr_))) {
    MS_LOG(EXCEPTION) << "Closing " + comm_lib_name_ + " handle failed. Error: " + std::to_string(GetLastError());
    return false;
  }
#endif
  return true;
}
}  // namespace device
}  // namespace mindspore

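As an aside for readers new to this pattern: the loader above is a thin wrapper around the POSIX dynamic-loading calls. A minimal standalone sketch of that flow follows; the library name and the exit handling are placeholders chosen for illustration, not the exact artifacts this commit ships.

// Minimal sketch of the dlopen/dlsym flow that CollectiveCommLibLoader wraps.
// "libexample_collective.so" is a placeholder library name.
#include <dlfcn.h>
#include <cstdint>
#include <iostream>
#include <string>

int main() {
  const std::string lib_name = "libexample_collective.so";
  void *handle = dlopen(lib_name.c_str(), RTLD_LAZY);
  if (handle == nullptr) {
    std::cerr << "dlopen failed: " << dlerror() << std::endl;
    return 1;
  }
  // Resolve an extern "C" entry point exported by the plug-in library.
  using InitFunc = bool (*)(uint32_t, uint32_t);
  auto init = reinterpret_cast<InitFunc>(dlsym(handle, "InitializeCollectiveLib"));
  if (init == nullptr) {
    std::cerr << "dlsym failed: " << dlerror() << std::endl;
    (void)dlclose(handle);
    return 1;
  }
  const bool ok = init(0, 0);
  std::cout << "InitializeCollectiveLib returned " << std::boolalpha << ok << std::endl;
  (void)dlclose(handle);
  return 0;
}
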
@@ -0,0 +1,73 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_RUNTIME_HARDWARE_COLLECTIVE_COLLECTIVE_LIB_LOADER_H_
#define MINDSPORE_CCSRC_RUNTIME_HARDWARE_COLLECTIVE_COLLECTIVE_LIB_LOADER_H_

#ifndef _WIN32
#include "utils/dlopen_macro.h"
#else
#include <windows.h>
#undef ERROR
#undef SM_DEBUG
#undef Yield
#include "utils/log_adapter.h"
#endif
#include <string>
#include <memory>
#include <vector>

namespace mindspore {
namespace device {
class CollectiveCommLibLoader {
 public:
  explicit CollectiveCommLibLoader(const std::string &comm_lib_name)
      : collective_comm_lib_ptr_(nullptr), comm_lib_name_(comm_lib_name) {}
  ~CollectiveCommLibLoader() = default;

  // Dynamically load the communication library.
  bool Initialize();

  // Finalize the communication library.
  bool Finalize();

  // Return the handle for this collective communication library. Caller should use this handle to call all methods of
  // the collective communication.
  void *collective_comm_lib_ptr() const { return collective_comm_lib_ptr_; }

 private:
  // The library handle 'dlopen' returns.
  void *collective_comm_lib_ptr_;

  // Name of the communication library.
  std::string comm_lib_name_;
};
using CollectiveCommLibLoaderPtr = std::shared_ptr<CollectiveCommLibLoader>;
}  // namespace device
}  // namespace mindspore

#ifndef _WIN32
// The exported symbols of collective communication shared library is registered here.
ORIGIN_METHOD(InitializeCollectiveLib, bool, uint32_t, uint32_t)
ORIGIN_METHOD(FinalizeCollectiveLib, bool)
ORIGIN_METHOD(CreateCommunicationGroup, bool, const std::string &, const std::vector<uint32_t> &)
ORIGIN_METHOD(DestroyCommunicationGroup, bool, const std::string &)
ORIGIN_METHOD(GetRankId, uint32_t, const std::string &)
ORIGIN_METHOD(GetCommunicationGroupSize, uint32_t, const std::string &)
ORIGIN_METHOD(AssignLocalRank, bool)
ORIGIN_METHOD(local_rank_id, uint32_t)
#endif
#endif  // MINDSPORE_CCSRC_RUNTIME_HARDWARE_COLLECTIVE_COLLECTIVE_LIB_LOADER_H_

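The ORIGIN_METHOD entries above register the symbols that the loaded library is expected to export. The real macro comes from utils/dlopen_macro.h; as a rough, hypothetical illustration of the idea, such a macro only needs to produce a function-pointer alias per symbol:

// Hypothetical stand-in for an ORIGIN_METHOD-style macro; MindSpore's actual
// dlopen_macro.h may define it differently.
#include <cstdint>
#include <string>
#include <vector>

#define ORIGIN_METHOD(name, return_type, ...) \
  using name##FunPtr = return_type (*)(__VA_ARGS__);

// After expansion, every registered symbol has a matching pointer type:
ORIGIN_METHOD(InitializeCollectiveLib, bool, uint32_t, uint32_t)
ORIGIN_METHOD(CreateCommunicationGroup, bool, const std::string &, const std::vector<uint32_t> &)

// A caller can then dlsym("InitializeCollectiveLib") from the library handle,
// store the result in an InitializeCollectiveLibFunPtr, and invoke it.
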
@@ -24,11 +24,8 @@ bool CollectiveCommunicationLib::Finalize() {
  }

  for (const auto &group : groups_) {
    MS_EXCEPTION_IF_NULL(group.second);
    if (!group.second->Finalize()) {
      MS_LOG(EXCEPTION) << "Finalizing group failed.";
      return false;
    }
    CHECK_IF_NULL(group.second);
    CHECK_RET(group.second->Finalize(), true, "Finalizing group failed.");
  }
  groups_.clear();
  initialized_ = false;
@@ -36,33 +33,25 @@ bool CollectiveCommunicationLib::Finalize() {
}

bool CollectiveCommunicationLib::DestroyCommunicationGroup(const std::string &group_name) {
  if (groups_.count(group_name) == 0) {
    MS_LOG(EXCEPTION) << "The group " << group_name << " is not created.";
    return false;
  }
  CHECK_RET(groups_.count(group_name) != 0, true, "The group " + group_name + " is not created.");
  auto group = groups_[group_name];
  MS_EXCEPTION_IF_NULL(group);
  group->Finalize();
  CHECK_IF_NULL(group);
  CHECK_RET(group->Finalize(), true, "Finalizing group failed.");
  (void)groups_.erase(group_name);
  return true;
}

uint32_t CollectiveCommunicationLib::GetRankId(const std::string &group_name) {
  if (groups_.count(group_name) == 0) {
    MS_LOG(EXCEPTION) << "The group " << group_name << " does not exist.";
    return UINT32_MAX;
  }
  CHECK_RET(groups_.count(group_name) != 0, true, "The group " + group_name + " does not exist.");
  auto group = groups_[group_name];
  MS_EXCEPTION_IF_NULL(group);
  CHECK_IF_NULL(group);
  return group->GetGroupRank(global_rank_id_);
}

uint32_t CollectiveCommunicationLib::GetGroupSize(const std::string &group_name) {
  if (groups_.count(group_name) == 0) {
    MS_LOG(EXCEPTION) << "The group " << group_name << " does not exist.";
    return UINT32_MAX;
  }
  CHECK_RET(groups_.count(group_name) != 0, true, "The group " + group_name + " does not exist.");
  auto group = groups_[group_name];
  MS_EXCEPTION_IF_NULL(group);
  CHECK_IF_NULL(group);
  return group->group_size();
}

@@ -31,12 +31,7 @@ namespace device {
// MsCollectiveCommLib which uses the host-side communication library developed by MindSpore.
class CollectiveCommunicationLib {
 public:
  CollectiveCommunicationLib()
      : collective_comm_lib_ptr_(nullptr),
        initialized_(false),
        global_rank_id_(0),
        local_rank_id_(0),
        global_rank_size_(0) {}
  CollectiveCommunicationLib() : initialized_(false), global_rank_id_(0), local_rank_id_(0), global_rank_size_(0) {}
  virtual ~CollectiveCommunicationLib() { groups_.clear(); }

  // Initialize collecitve communication library.
@@ -70,9 +65,6 @@ class CollectiveCommunicationLib {
  uint32_t local_rank_id() const;

 protected:
  // The third party collective communication library. They are dynamically loaded by MindSpore.
  const void *collective_comm_lib_ptr_;

  // Whether this collective communication library is initialized.
  bool initialized_;

@@ -20,8 +20,7 @@ namespace mindspore {
namespace device {
CommunicationGroup::CommunicationGroup(const std::string name, const std::vector<uint32_t> &group_ranks,
                                       uint32_t global_rank)
    : collective_comm_lib_ptr_(nullptr),
      initialized_(false),
    : initialized_(false),
      global_rank_(global_rank),
      size_(group_ranks.size()),
      name_(name),
@@ -36,18 +35,14 @@ CommunicationGroup::CommunicationGroup(const std::string name, const std::vector
}

uint32_t CommunicationGroup::GetGroupRank(uint32_t global_rank) {
  if (global_to_group_ranks_.count(global_rank) == 0) {
    MS_LOG(EXCEPTION) << "Group " << name_ << " doesn't contain the global rank " << global_rank;
    return UINT32_MAX;
  }
  CHECK_RET((global_to_group_ranks_.count(global_rank) != 0), true,
            "Group " + name_ + " doesn't contain the global rank " + std::to_string(global_rank));
  return global_to_group_ranks_[global_rank];
}

uint32_t CommunicationGroup::GetGlobalRank(uint32_t group_rank) {
  if (group_to_global_ranks_.count(group_rank) == 0) {
    MS_LOG(EXCEPTION) << "Group " << name_ << " doesn't contain the group rank " << group_rank;
    return UINT32_MAX;
  }
  CHECK_RET((group_to_global_ranks_.count(group_rank) != 0), true,
            "Group " + name_ + " doesn't contain the group rank " + std::to_string(group_rank));
  return group_to_global_ranks_[group_rank];
}

@@ -21,8 +21,9 @@
#include <string>
#include <vector>
#include <memory>
#include "mindspore/core/utils/log_adapter.h"
#include "mindspore/core/utils/convert_utils_base.h"
#include <sstream>
#include <algorithm>
#include "pybind11/pybind11.h"

namespace mindspore {
namespace device {
@@ -56,9 +57,6 @@ class CommunicationGroup {
  uint32_t group_size() const;

 protected:
  // The third party collective communication libraries. They are dynamically loaded by MindSpore.
  const void *collective_comm_lib_ptr_;

  // Whether this communication group is initialized.
  bool initialized_;

@@ -81,4 +79,24 @@ class CommunicationGroup {
using CommunicationGroupPtr = std::shared_ptr<CommunicationGroup>;
}  // namespace device
}  // namespace mindspore

#define CHECK_RET(expression, result, message)                                                    \
  do {                                                                                            \
    auto ret = (expression);                                                                      \
    if (ret != result) {                                                                          \
      std::ostringstream oss;                                                                     \
      oss << "Error in file " << __FILE__ << " | Error on line " << __LINE__ << ": " << message;  \
      pybind11::pybind11_fail(oss.str());                                                         \
    }                                                                                             \
  } while (0)

#define CHECK_IF_NULL(ptr)                                                                             \
  do {                                                                                                 \
    if ((ptr) == nullptr) {                                                                            \
      std::ostringstream oss;                                                                          \
      oss << "Error in file " << __FILE__ << " | Error on line " << __LINE__ << ": The pointer[" << #ptr \
          << "] is null.";                                                                             \
      pybind11::pybind11_fail(oss.str());                                                              \
    }                                                                                                  \
  } while (0)
#endif  // MINDSPORE_CCSRC_RUNTIME_HARDWARE_COLLECTIVE_COMMUNICATION_GROUP_H_

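A short usage sketch of the new checks follows. It re-declares CHECK_RET exactly as in the header above so that it stands alone; the helper function itself is an invented example, not part of the commit.

// A failed CHECK_RET raises a pybind11 error (surfacing as a Python exception)
// instead of calling MS_LOG(EXCEPTION). GetGroupRankOrFail is illustrative only.
#include <cstdint>
#include <map>
#include <sstream>
#include <string>
#include "pybind11/pybind11.h"

#define CHECK_RET(expression, result, message)                                                    \
  do {                                                                                            \
    auto ret = (expression);                                                                      \
    if (ret != result) {                                                                          \
      std::ostringstream oss;                                                                     \
      oss << "Error in file " << __FILE__ << " | Error on line " << __LINE__ << ": " << message;  \
      pybind11::pybind11_fail(oss.str());                                                         \
    }                                                                                             \
  } while (0)

uint32_t GetGroupRankOrFail(const std::map<uint32_t, uint32_t> &global_to_group_ranks, uint32_t global_rank,
                            const std::string &name) {
  CHECK_RET(global_to_group_ranks.count(global_rank) != 0, true,
            "Group " + name + " doesn't contain the global rank " + std::to_string(global_rank));
  return global_to_group_ranks.at(global_rank);
}
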
@@ -24,47 +24,62 @@ bool MPICollectiveCommLib::Initialize(uint32_t, uint32_t) {
    return false;
  }

  // Initialize MPI interface.
  int initialized = 0;
  CHECK_MPI_RET(MPI_Initialized(&initialized), "Failed to check MPI initialization status.");
  CHECK_RET(MPI_Initialized(&initialized), MPI_SUCCESS, "Failed to check MPI initialization status.");
  if (initialized == 0) {
    CHECK_MPI_RET(MPI_Init(nullptr, nullptr), "Failed to initialize MPI.");
    CHECK_RET(MPI_Init(nullptr, nullptr), MPI_SUCCESS, "Failed to initialize MPI.");
  }

  // Generated MPI global rank id and rank size for the world group MPI_COMM_WORLD.
  int rank_id = 0;
  int rank_size = 0;
  CHECK_MPI_RET(MPI_Comm_rank(MPI_COMM_WORLD, &rank_id), "Failed to initialize MPI global rank id.");
  CHECK_MPI_RET(MPI_Comm_rank(MPI_COMM_WORLD, &rank_size), "Failed to initialize MPI global rank size.");
  global_rank_id_ = IntToUint(rank_id);
  global_rank_size_ = IntToUint(rank_size);
  MS_LOG(INFO) << "The MPI global rank id of this process is " << global_rank_id_ << ", global rank size is "
               << global_rank_size_;
  CHECK_RET(MPI_Comm_rank(MPI_COMM_WORLD, &rank_id), MPI_SUCCESS, "Failed to initialize MPI global rank id.");
  CHECK_RET(MPI_Comm_size(MPI_COMM_WORLD, &rank_size), MPI_SUCCESS, "Failed to initialize MPI global rank size.");
  global_rank_id_ = static_cast<uint32_t>(rank_id);
  global_rank_size_ = static_cast<uint32_t>(rank_size);

  // Create the world group of MPI because every other group is generated from MPI world group.
  CHECK_MPI_RET(MPI_Comm_group(MPI_COMM_WORLD, &world_group_), "Failed to get group of MPI_COMM_WORLD.");
  CHECK_RET(MPI_Comm_group(MPI_COMM_WORLD, &world_group_), MPI_SUCCESS, "Failed to get group of MPI_COMM_WORLD.");
  initialized_ = true;
  return true;
}

bool MPICollectiveCommLib::CreateCommunicationGroup(const std::string &group_name,
                                                    const std::vector<uint32_t> &group_ranks) {
  if (groups_.count(group_name) != 0) {
    MS_LOG(EXCEPTION) << "The MPI group " << group_name << " has already existed.";
    return false;
  }
  CHECK_RET((groups_.count(group_name) == 0), true, "The MPI group " + group_name + " has already existed.");

  MPICommunicationGroupPtr group = std::make_shared<MPICommunicationGroup>(group_name, group_ranks, global_rank_id_);
  MS_EXCEPTION_IF_NULL(group);
  if (!group->Initialize(world_group_)) {
    MS_LOG(EXCEPTION) << "Initializing group failed.";
    return false;
  }

  CHECK_IF_NULL(group);
  CHECK_RET(group->Initialize(world_group_), true, "Initializing group failed.");
  groups_[group_name] = group;
  MS_LOG(INFO) << "MPI group of " << group_name << " is created.";
  return true;
}
}  // namespace cpu
}  // namespace device
}  // namespace mindspore

// The exported APIs for 'dlsym' to load.
using MPICollectiveCommLib = mindspore::device::cpu::MPICollectiveCommLib;
bool InitializeCollectiveLib(uint32_t global_rank, uint32_t global_rank_size) {
  return MPICollectiveCommLib::GetInstance().Initialize(global_rank, global_rank_size);
}

bool FinalizeCollectiveLib() { return MPICollectiveCommLib::GetInstance().Finalize(); }

bool CreateCommunicationGroup(const std::string &group_name, const std::vector<uint32_t> &group_ranks) {
  return MPICollectiveCommLib::GetInstance().CreateCommunicationGroup(group_name, group_ranks);
}

bool DestroyCommunicationGroup(const std::string &group_name) {
  return MPICollectiveCommLib::GetInstance().DestroyCommunicationGroup(group_name);
}

uint32_t GetRankId(const std::string &group_name) { return MPICollectiveCommLib::GetInstance().GetRankId(group_name); }

uint32_t GetCommunicationGroupSize(const std::string &group_name) {
  return MPICollectiveCommLib::GetInstance().GetGroupSize(group_name);
}

bool AssignLocalRank() { return MPICollectiveCommLib::GetInstance().AssignLocalRank(); }

uint32_t local_rank_id() { return MPICollectiveCommLib::GetInstance().local_rank_id(); }

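The initialization above follows the standard MPI bootstrap sequence. A minimal standalone equivalent (build with mpicxx; error handling reduced to early returns) is sketched below for reference.

// Standalone sketch of the MPI bootstrap the library performs: check whether
// MPI is already initialized, query world rank/size, and grab the world group
// from which every custom group is later derived.
#include <mpi.h>
#include <cstdio>

int main(int argc, char **argv) {
  int initialized = 0;
  if (MPI_Initialized(&initialized) != MPI_SUCCESS) return 1;
  if (initialized == 0 && MPI_Init(&argc, &argv) != MPI_SUCCESS) return 1;

  int rank_id = 0;
  int rank_size = 0;
  if (MPI_Comm_rank(MPI_COMM_WORLD, &rank_id) != MPI_SUCCESS) return 1;
  if (MPI_Comm_size(MPI_COMM_WORLD, &rank_size) != MPI_SUCCESS) return 1;

  MPI_Group world_group;
  if (MPI_Comm_group(MPI_COMM_WORLD, &world_group) != MPI_SUCCESS) return 1;
  std::printf("global rank %d of %d\n", rank_id, rank_size);

  MPI_Group_free(&world_group);
  MPI_Finalize();
  return 0;
}
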
@@ -47,4 +47,18 @@ class MPICollectiveCommLib : public CollectiveCommunicationLib {
}  // namespace cpu
}  // namespace device
}  // namespace mindspore

#ifndef EXPORT_MPI_WRAPPER
#define EXPORT_MPI_WRAPPER __attribute__((visibility("default")))
#endif
extern "C" EXPORT_MPI_WRAPPER bool InitializeCollectiveLib(uint32_t global_rank = UINT32_MAX,
                                                           uint32_t global_rank_size = UINT32_MAX);
extern "C" EXPORT_MPI_WRAPPER bool FinalizeCollectiveLib();
extern "C" EXPORT_MPI_WRAPPER bool CreateCommunicationGroup(const std::string &group_name,
                                                            const std::vector<uint32_t> &group_ranks);
extern "C" EXPORT_MPI_WRAPPER bool DestroyCommunicationGroup(const std::string &group_name);
extern "C" EXPORT_MPI_WRAPPER uint32_t GetRankId(const std::string &group_name);
extern "C" EXPORT_MPI_WRAPPER uint32_t GetGroupSize(const std::string &group_name);
extern "C" EXPORT_MPI_WRAPPER bool AssignLocalRank();
extern "C" EXPORT_MPI_WRAPPER uint32_t local_rank_id();
#endif  // MINDSPORE_CCSRC_RUNTIME_HARDWARE_CPU_MPI_COLLECTIVE_COMM_LIB_H_

@@ -27,8 +27,10 @@ bool MPICommunicationGroup::Finalize() {
  if (!initialized_) {
    return false;
  }
  CHECK_MPI_RET(MPI_Comm_free(&group_communicator_), "Freeing MPI group communicator for " + name_ + " failed.");
  CHECK_MPI_RET(MPI_Group_free(&group_), "Freeing MPI group for " + name_ + " failed.");

  CHECK_RET(MPI_Comm_free(&group_communicator_), MPI_SUCCESS,
            "Freeing MPI group communicator for " + name_ + " failed.");
  CHECK_RET(MPI_Group_free(&group_), MPI_SUCCESS, "Freeing MPI group for " + name_ + " failed.");
  initialized_ = false;
  return true;
}
@@ -38,15 +40,12 @@ bool MPICommunicationGroup::Initialize(const MPI_Group &world_group) {
    return false;
  }
  std::vector<int> ranks(group_ranks_.begin(), group_ranks_.end());
  CHECK_MPI_RET(MPI_Group_incl(world_group, ranks.size(), ranks.data(), &group_),
                "Creating MPI group for " + name_ + " failed.");
  CHECK_RET(MPI_Group_incl(world_group, ranks.size(), ranks.data(), &group_), MPI_SUCCESS,
            "Creating MPI group for " + name_ + " failed.");
  CHECK_RET(MPI_Comm_create(MPI_COMM_WORLD, group_, &group_communicator_), MPI_SUCCESS,
            "Creating MPI group communicator for " + name_ + " failed.");

  CHECK_MPI_RET(MPI_Comm_create(MPI_COMM_WORLD, group_, &group_communicator_),
                "Creating MPI group communicator for " + name_ + " failed.");
  if (group_communicator_ == MPI_COMM_NULL) {
    MS_LOG(EXCEPTION) << "The MPI communicator for group " << name_ << " failed.";
    return false;
  }
  CHECK_RET(group_communicator_ != MPI_COMM_NULL, true, "The MPI communicator for group " + name_ + " failed.");
  initialized_ = true;
  return true;
}

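For context, the group initialization above carves a sub-communicator out of MPI_COMM_WORLD. A stripped-down sketch of that pattern, without the class bookkeeping, follows; CreateSubComm is an illustrative helper name, not part of the commit.

// Sketch of creating an MPI group/communicator for a subset of world ranks,
// the pattern MPICommunicationGroup::Initialize follows.
#include <mpi.h>
#include <vector>

bool CreateSubComm(const std::vector<int> &ranks, const MPI_Group &world_group, MPI_Group *group, MPI_Comm *comm) {
  if (MPI_Group_incl(world_group, static_cast<int>(ranks.size()), ranks.data(), group) != MPI_SUCCESS) {
    return false;
  }
  if (MPI_Comm_create(MPI_COMM_WORLD, *group, comm) != MPI_SUCCESS) {
    return false;
  }
  // Ranks that are not members of the group receive MPI_COMM_NULL and must not
  // use the communicator.
  return *comm != MPI_COMM_NULL;
}
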
@@ -22,6 +22,7 @@
#include <vector>
#include <memory>
#include "runtime/hardware/collective/communication_group.h"
#include "utils/dlopen_macro.h"

namespace mindspore {
namespace device {
@@ -44,16 +45,6 @@ class MPICommunicationGroup : public CommunicationGroup {
  MPI_Comm group_communicator_;
};
using MPICommunicationGroupPtr = std::shared_ptr<MPICommunicationGroup>;

#define CHECK_MPI_RET(expression, message) \
  do {                                     \
    {                                      \
      auto ret = (expression);             \
      if (ret != MPI_SUCCESS) {            \
        MS_LOG(EXCEPTION) << (message);    \
      }                                    \
    }                                      \
  } while (false)
}  // namespace cpu
}  // namespace device
}  // namespace mindspore

@@ -23,6 +23,7 @@
#include "runtime/device/device_address.h"
#include "runtime/device/bucket.h"
#include "runtime/hardware/collective/collective_communication_lib.h"
#include "runtime/hardware/collective/collective_comm_lib_loader.h"
#include "backend/session/kernel_graph.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "backend/optimizer/common/common_backend_optimization.h"
@@ -49,7 +50,7 @@ struct DeviceContextKey {
class DeviceContext {
 public:
  explicit DeviceContext(const DeviceContextKey &device_context_key)
      : device_context_key_(device_context_key), collective_comm_lib_(nullptr) {}
      : device_context_key_(device_context_key), collective_comm_lib_ptr_(nullptr) {}
  virtual ~DeviceContext() = default;

  // Initialize the device context.
@@ -150,7 +151,7 @@ class DeviceContext {
  virtual bool InitCollectiveCommLib() { return true; }

  // Return collective communication object for caller to access
  CollectiveCommunicationLibPtr collective_comm_lib() const { return collective_comm_lib_; }
  void *collective_comm_lib() const { return collective_comm_lib_ptr_; }

  // TODO(jiaorui): will be delete
  // Dump all graphs.
@@ -158,10 +159,11 @@ class DeviceContext {

 protected:
  DeviceContextKey device_context_key_;
  CollectiveCommunicationLibPtr collective_comm_lib_;

  // The dynamic loaded handle for collective communication library.
  void *collective_comm_lib_ptr_;
};
using DeviceContextPtr = std::shared_ptr<DeviceContext>;
}  // namespace device
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_RUNTIME_HARDWARE_DEVICE_CONTEXT_H_

@@ -30,9 +30,6 @@
#include "backend/kernel_compiler/common_utils.h"
#include "runtime/device/gpu/gpu_common.h"
#include "runtime/hardware/gpu/optimizer.h"
#ifdef ENABLE_MPI
#include "runtime/hardware/gpu/nvidia_collective_comm_lib.h"
#endif
#include "common/trans.h"
#include "utils/context/graph_kernel_flags.h"
#include "runtime/device/gpu/gpu_bucket.h"
@@ -530,8 +527,21 @@ std::shared_ptr<Bucket> GPUDeviceContext::CreateBucket(uint32_t bucket_id, uint3

bool GPUDeviceContext::InitCollectiveCommLib() {
#ifdef ENABLE_MPI
  collective_comm_lib_ = &NvidiaCollectiveCommLib::GetInstance();
  collective_comm_lib_->Initialize();
  std::string nvidia_comm_lib_name = "libnvidia_collective.so";
  auto loader = std::make_shared<CollectiveCommLibLoader>(nvidia_comm_lib_name);
  MS_EXCEPTION_IF_NULL(loader);
  if (!loader->Initialize()) {
    MS_LOG(EXCEPTION) << "Loading NCCL collective library failed.";
    return false;
  }
  collective_comm_lib_ptr_ = loader->collective_comm_lib_ptr();
  MS_EXCEPTION_IF_NULL(collective_comm_lib_ptr_);

  auto init_collecitve_lib_func = DlsymFuncObj(InitializeCollectiveLib, collective_comm_lib_ptr_);
  if (!init_collecitve_lib_func(0, 0)) {
    MS_LOG(EXCEPTION) << "Initializing for NCCL library failed.";
    return false;
  }
  return true;
#else
  return false;

@@ -19,10 +19,6 @@
namespace mindspore {
namespace device {
namespace gpu {
NvidiaCollectiveCommLib::NvidiaCollectiveCommLib() {
  collective_comm_lib_ptr_ = CollectiveInitializer::instance().collective_handle();
}

bool NvidiaCollectiveCommLib::Initialize(uint32_t global_rank, uint32_t global_rank_size) {
  if (initialized_) {
    return false;
@@ -30,26 +26,48 @@ bool NvidiaCollectiveCommLib::Initialize(uint32_t global_rank, uint32_t global_r

  global_rank_id_ = global_rank;
  global_rank_size_ = global_rank_size;
  MS_LOG(INFO) << "The global rank id of this process is " << global_rank_id_
               << ", global rank size of nccl_world_group is " << global_rank_size_;
  initialized_ = true;
  return true;
}

bool NvidiaCollectiveCommLib::CreateCommunicationGroup(const std::string &group_name,
                                                       const std::vector<uint32_t> &group_ranks) {
  if (groups_.count(group_name) != 0) {
    MS_LOG(EXCEPTION) << "The NCCL group " << group_name << " has already existed.";
    return false;
  }
  CHECK_RET((groups_.count(group_name) == 0), true, "The NCCL group " + group_name + " has already existed.");

  NvidiaCommunicationGroupPtr group =
      std::make_shared<NvidiaCommunicationGroup>(group_name, group_ranks, global_rank_id_);
  MS_EXCEPTION_IF_NULL(group);
  CHECK_IF_NULL(group);
  groups_[group_name] = group;
  MS_LOG(INFO) << "NCCL group of " << group_name << " is created. But it's not initialized yet.";
  return true;
}
}  // namespace gpu
}  // namespace device
}  // namespace mindspore

// The exported APIs for 'dlsym' to load.
using NvidiaCollectiveCommLib = mindspore::device::gpu::NvidiaCollectiveCommLib;
bool InitializeCollectiveLib(uint32_t global_rank, uint32_t global_rank_size) {
  return NvidiaCollectiveCommLib::GetInstance().Initialize(global_rank, global_rank_size);
}

bool FinalizeCollectiveLib() { return NvidiaCollectiveCommLib::GetInstance().Finalize(); }

bool CreateCommunicationGroup(const std::string &group_name, const std::vector<uint32_t> &group_ranks) {
  return NvidiaCollectiveCommLib::GetInstance().CreateCommunicationGroup(group_name, group_ranks);
}

bool DestroyCommunicationGroup(const std::string &group_name) {
  return NvidiaCollectiveCommLib::GetInstance().DestroyCommunicationGroup(group_name);
}

uint32_t GetRankId(const std::string &group_name) {
  return NvidiaCollectiveCommLib::GetInstance().GetRankId(group_name);
}

uint32_t GetCommunicationGroupSize(const std::string &group_name) {
  return NvidiaCollectiveCommLib::GetInstance().GetGroupSize(group_name);
}

bool AssignLocalRank() { return NvidiaCollectiveCommLib::GetInstance().AssignLocalRank(); }

uint32_t local_rank_id() { return NvidiaCollectiveCommLib::GetInstance().local_rank_id(); }

@@ -40,10 +40,24 @@ class NvidiaCollectiveCommLib : public CollectiveCommunicationLib {
  bool CreateCommunicationGroup(const std::string &group_name, const std::vector<uint32_t> &group_ranks) override;

 private:
  NvidiaCollectiveCommLib();
  NvidiaCollectiveCommLib() = default;
  ~NvidiaCollectiveCommLib() override = default;
};
}  // namespace gpu
}  // namespace device
}  // namespace mindspore

#ifndef EXPORT_NCCL_WRAPPER
#define EXPORT_NCCL_WRAPPER __attribute__((visibility("default")))
#endif
extern "C" EXPORT_NCCL_WRAPPER bool InitializeCollectiveLib(uint32_t global_rank = UINT32_MAX,
                                                            uint32_t global_rank_size = UINT32_MAX);
extern "C" EXPORT_NCCL_WRAPPER bool FinalizeCollectiveLib();
extern "C" EXPORT_NCCL_WRAPPER bool CreateCommunicationGroup(const std::string &group_name,
                                                             const std::vector<uint32_t> &group_ranks);
extern "C" EXPORT_NCCL_WRAPPER bool DestroyCommunicationGroup(const std::string &group_name);
extern "C" EXPORT_NCCL_WRAPPER uint32_t GetRankId(const std::string &group_name);
extern "C" EXPORT_NCCL_WRAPPER uint32_t GetCommunicationGroupSize(const std::string &group_name);
extern "C" EXPORT_NCCL_WRAPPER bool AssignLocalRank();
extern "C" EXPORT_NCCL_WRAPPER uint32_t local_rank_id();
#endif  // MINDSPORE_CCSRC_RUNTIME_HARDWARE_CPU_NVIDIA_COLLECTIVE_COMM_LIB_H_

@@ -21,9 +21,7 @@ namespace device {
namespace gpu {
NvidiaCommunicationGroup::NvidiaCommunicationGroup(const std::string name, const std::vector<uint32_t> &group_ranks,
                                                   uint32_t global_rank)
    : CommunicationGroup(name, group_ranks, global_rank) {
  collective_comm_lib_ptr_ = CollectiveInitializer::instance().collective_handle();
}
    : CommunicationGroup(name, group_ranks, global_rank) {}

bool NvidiaCommunicationGroup::Initialize(void *root_info) {
  if (initialized_) {
@@ -32,20 +30,11 @@ bool NvidiaCommunicationGroup::Initialize(void *root_info) {

  // The unique id is broadcasted by the root rank.
  unique_id_ = *(static_cast<ncclUniqueId *>(root_info));

  uint32_t group_rank = GetGroupRank(global_rank_);
  // Initialize the NCCL communicator while the group created. Pay attention that 'ncclCommInitRank' should be called
  // after GPU device id is set.
  MS_EXCEPTION_IF_NULL(collective_comm_lib_ptr_);
  auto comm_init_rank =
      reinterpret_cast<NCCLCommInitRank>(dlsym(const_cast<void *>(collective_comm_lib_ptr_), "NCCLCommInitRank"));
  MS_EXCEPTION_IF_NULL(comm_init_rank);

  MS_LOG(INFO) << "Start initializing NCCL communicator for group " << name_;
  uint32_t group_rank = GetGroupRank(global_rank_);
  CHECK_NCCL_RET((*comm_init_rank)(&comm_, SizeToInt(size_), unique_id_, UintToInt(group_rank)),
                 "Initializing NCCL communicator failed.");
  MS_LOG(INFO) << "NCCL communicator for group " << name_ << " is successfully initialized.";

  CHECK_RET(ncclCommInitRank(&comm_, static_cast<int>(size_), unique_id_, static_cast<int>(group_rank)), ncclSuccess,
            "Initializing NCCL communicator failed.");
  initialized_ = true;
  return true;
}
@@ -55,26 +44,14 @@ bool NvidiaCommunicationGroup::Finalize() {
    return false;
  }

  MS_EXCEPTION_IF_NULL(collective_comm_lib_ptr_);
  auto comm_abort =
      reinterpret_cast<NCCLCommAbort>(dlsym(const_cast<void *>(collective_comm_lib_ptr_), "NCCLCommAbort"));
  MS_EXCEPTION_IF_NULL(comm_abort);
  auto comm_destroy =
      reinterpret_cast<NCCLCommDestroy>(dlsym(const_cast<void *>(collective_comm_lib_ptr_), "NCCLCommDestroy"));
  MS_EXCEPTION_IF_NULL(comm_destroy);

  CHECK_NCCL_RET((*comm_abort)(comm_), "Failed to abort NCCL communicator.");
  CHECK_NCCL_RET((*comm_destroy)(comm_), "Failed to destroy NCCL communicator.");
  CHECK_RET(ncclCommAbort(comm_), ncclSuccess, "Failed to abort NCCL communicator.");
  CHECK_RET(ncclCommDestroy(comm_), ncclSuccess, "Failed to destroy NCCL communicator.");
  initialized_ = false;
  return true;
}

void *NvidiaCommunicationGroup::GenerateRootInfo() {
  MS_EXCEPTION_IF_NULL(collective_comm_lib_ptr_);
  auto nccl_id_funcptr =
      reinterpret_cast<NcclUniqueId>(dlsym(const_cast<void *>(collective_comm_lib_ptr_), "nccl_unique_id"));
  MS_EXCEPTION_IF_NULL(nccl_id_funcptr);
  unique_id_ = (*nccl_id_funcptr)();
  CHECK_RET(ncclGetUniqueId(&unique_id_), ncclSuccess, "Failed to get NCCL unique id.");
  return &unique_id_;
}
}  // namespace gpu

|
|||
#include <vector>
|
||||
#include <memory>
|
||||
#include "runtime/hardware/collective/communication_group.h"
|
||||
#include "runtime/device/gpu/distribution/collective_init.h"
|
||||
#include "utils/dlopen_macro.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace device {
|
||||
namespace gpu {
|
||||
using NcclUniqueId = ncclUniqueId (*)();
|
||||
using NCCLCommInitRank = ncclResult_t (*)(ncclComm_t *, int, ncclUniqueId, int);
|
||||
using NCCLCommAbort = ncclResult_t (*)(ncclComm_t);
|
||||
using NCCLCommDestroy = ncclResult_t (*)(ncclComm_t);
|
||||
|
||||
class NvidiaCommunicationGroup : public CommunicationGroup {
|
||||
public:
|
||||
explicit NvidiaCommunicationGroup(const std::string name, const std::vector<uint32_t> &group_ranks,
|
||||
|
@ -52,17 +47,6 @@ class NvidiaCommunicationGroup : public CommunicationGroup {
|
|||
ncclComm_t comm_;
|
||||
};
|
||||
using NvidiaCommunicationGroupPtr = std::shared_ptr<NvidiaCommunicationGroup>;
|
||||
using CollectiveInitializer = device::gpu::CollectiveInitializer;
|
||||
|
||||
#define CHECK_NCCL_RET(expression, message) \
|
||||
do { \
|
||||
{ \
|
||||
auto ret = (expression); \
|
||||
if (ret != ncclSuccess) { \
|
||||
MS_LOG(EXCEPTION) << (message); \
|
||||
} \
|
||||
} \
|
||||
} while (false)
|
||||
} // namespace gpu
|
||||
} // namespace device
|
||||
} // namespace mindspore
|
||||
|