Dynamically load ompi and nccl

ZPaC 2021-11-09 09:41:49 +08:00
parent 214d5dec46
commit 87057fdc27
20 changed files with 344 additions and 158 deletions

View File

@ -188,6 +188,11 @@ if(ENABLE_MPI)
DESTINATION ${INSTALL_LIB_DIR}
COMPONENT mindspore
)
install(
TARGETS mpi_collective
DESTINATION ${INSTALL_LIB_DIR}
COMPONENT mindspore
)
endif()
if(ENABLE_D)
install(
@ -205,6 +210,11 @@ if(ENABLE_GPU)
DESTINATION ${INSTALL_LIB_DIR}
COMPONENT mindspore
)
install(
TARGETS nvidia_collective
DESTINATION ${INSTALL_LIB_DIR}
COMPONENT mindspore
)
endif()
install(
TARGETS gpu_queue

View File

@ -142,6 +142,11 @@ if(ENABLE_MPI)
DESTINATION ${INSTALL_LIB_DIR}
COMPONENT mindspore
)
install(
TARGETS mpi_collective
DESTINATION ${INSTALL_LIB_DIR}
COMPONENT mindspore
)
endif()
endif()
@ -152,6 +157,11 @@ if(ENABLE_GPU)
DESTINATION ${INSTALL_LIB_DIR}
COMPONENT mindspore
)
install(
TARGETS nvidia_collective
DESTINATION ${INSTALL_LIB_DIR}
COMPONENT mindspore
)
endif()
install(
TARGETS gpu_queue

View File

@ -814,9 +814,9 @@ bool GraphExecutorPy::Compile(const py::object &source_obj, const py::tuple &arg
} catch (const py::error_already_set &ex) {
if (!StaticAnalysisException::Instance().HasException()) {
// print function call stack info before release
std::string exception_info = GetCompileExceptionInfo();
if (!exception_info.empty()) {
MS_LOG(ERROR) << exception_info;
std::string compile_exception_info = GetCompileExceptionInfo();
if (!compile_exception_info.empty()) {
MS_LOG(ERROR) << compile_exception_info;
}
}
ReleaseResource(phase);

View File

@ -7,11 +7,32 @@ endif()
if(ENABLE_GPU)
file(GLOB_RECURSE HARDWARE_GPU_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "gpu/*.cc")
list(REMOVE_ITEM HARDWARE_GPU_SRC_LIST "gpu/nvidia_collective_comm_lib.cc" "gpu/nvidia_communication_group.cc")
if(ENABLE_MPI)
set(NVIDIA_COLLECTIVE_SRCS "gpu/nvidia_collective_comm_lib.cc"
"gpu/nvidia_communication_group.cc"
"collective/collective_communication_lib.cc"
"collective/communication_group.cc")
set_property(SOURCE ${NVIDIA_COLLECTIVE_SRCS} PROPERTY COMPILE_DEFINITIONS
SUBMODULE_ID=mindspore::SubModuleId::SM_DEVICE)
add_library(nvidia_collective SHARED ${NVIDIA_COLLECTIVE_SRCS})
target_link_libraries(nvidia_collective PRIVATE mindspore::nccl)
endif()
endif()
if(ENABLE_CPU)
file(GLOB_RECURSE HARDWARE_CPU_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "cpu/*.cc")
list(REMOVE_ITEM HARDWARE_CPU_SRC_LIST "cpu/mpi_collective_comm_lib.cc" "cpu/mpi_communication_group.cc")
if(ENABLE_MPI)
set(MPI_COLLECTIVE_SRCS "cpu/mpi_collective_comm_lib.cc"
"cpu/mpi_communication_group.cc"
"collective/collective_communication_lib.cc"
"collective/communication_group.cc")
set_property(SOURCE ${MPI_COLLECTIVE_SRCS} PROPERTY COMPILE_DEFINITIONS
SUBMODULE_ID=mindspore::SubModuleId::SM_DEVICE)
add_library(mpi_collective SHARED ${MPI_COLLECTIVE_SRCS})
target_link_libraries(mpi_collective PRIVATE mindspore::ompi)
endif()
endif()

View File

@ -0,0 +1,54 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "runtime/hardware/collective/collective_comm_lib_loader.h"
namespace mindspore {
namespace device {
bool CollectiveCommLibLoader::Initialize() {
std::string err_msg = "";
#ifndef _WIN32
collective_comm_lib_ptr_ = dlopen(comm_lib_name_.c_str(), RTLD_LAZY);
err_msg = GetDlErrorMsg();
#else
collective_comm_lib_ptr_ = LoadLibrary(comm_lib_name_.c_str());
err_msg = std::to_string(GetLastError());
#endif
if (collective_comm_lib_ptr_ == nullptr) {
MS_LOG(EXCEPTION) << "Loading " + comm_lib_name_ + " failed. Error: " + err_msg;
return false;
}
return true;
}
bool CollectiveCommLibLoader::Finalize() {
MS_EXCEPTION_IF_NULL(collective_comm_lib_ptr_);
#ifndef _WIN32
if (dlclose(collective_comm_lib_ptr_) != 0) {
MS_LOG(EXCEPTION) << "Closing " + comm_lib_name_ + " handle failed. Error: " + GetDlErrorMsg();
return false;
}
#else
if (!FreeLibrary(reinterpret_cast<HINSTANCE__ *>(collective_comm_lib_ptr_))) {
MS_LOG(EXCEPTION) << "Closing " + comm_lib_name_ + " handle failed. Error: " + std::to_string(GetLastError());
return false;
}
#endif
return true;
}
} // namespace device
} // namespace mindspore
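
For orientation, here is a minimal standalone sketch of the dlopen/dlsym/dlclose flow this loader wraps. The library name and resolved symbol mirror the exported C API added elsewhere in this commit, but the snippet itself is illustrative and not part of the change.

#include <dlfcn.h>
#include <cstdint>
#include <iostream>
#include <string>

int main() {
  // Hypothetical library name; on Linux this commit installs libmpi_collective.so / libnvidia_collective.so.
  const std::string lib_name = "libmpi_collective.so";
  void *handle = dlopen(lib_name.c_str(), RTLD_LAZY);
  if (handle == nullptr) {
    std::cerr << "dlopen failed: " << dlerror() << std::endl;
    return 1;
  }
  // Resolve one of the extern "C" symbols registered in the loader header below.
  using InitFunc = bool (*)(uint32_t, uint32_t);
  auto init = reinterpret_cast<InitFunc>(dlsym(handle, "InitializeCollectiveLib"));
  if (init != nullptr) {
    std::cout << "InitializeCollectiveLib returned " << init(0, 0) << std::endl;
  } else {
    std::cerr << "dlsym failed: " << dlerror() << std::endl;
  }
  if (dlclose(handle) != 0) {
    std::cerr << "dlclose failed: " << dlerror() << std::endl;
  }
  return 0;
}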

View File

@ -0,0 +1,73 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_RUNTIME_HARDWARE_COLLECTIVE_COLLECTIVE_LIB_LOADER_H_
#define MINDSPORE_CCSRC_RUNTIME_HARDWARE_COLLECTIVE_COLLECTIVE_LIB_LOADER_H_
#ifndef _WIN32
#include "utils/dlopen_macro.h"
#else
#include <windows.h>
#undef ERROR
#undef SM_DEBUG
#undef Yield
#include "utils/log_adapter.h"
#endif
#include <string>
#include <memory>
#include <vector>
namespace mindspore {
namespace device {
class CollectiveCommLibLoader {
public:
explicit CollectiveCommLibLoader(const std::string &comm_lib_name)
: collective_comm_lib_ptr_(nullptr), comm_lib_name_(comm_lib_name) {}
~CollectiveCommLibLoader() = default;
// Dynamically load the communication library.
bool Initialize();
// Finalize the communication library.
bool Finalize();
// Return the handle for this collective communication library. The caller should use this handle to call all methods
// of the collective communication library.
void *collective_comm_lib_ptr() const { return collective_comm_lib_ptr_; }
private:
// The library handle 'dlopen' returns.
void *collective_comm_lib_ptr_;
// Name of the communication library.
std::string comm_lib_name_;
};
using CollectiveCommLibLoaderPtr = std::shared_ptr<CollectiveCommLibLoader>;
} // namespace device
} // namespace mindspore
#ifndef _WIN32
// The exported symbols of the collective communication shared library are registered here.
ORIGIN_METHOD(InitializeCollectiveLib, bool, uint32_t, uint32_t)
ORIGIN_METHOD(FinalizeCollectiveLib, bool)
ORIGIN_METHOD(CreateCommunicationGroup, bool, const std::string &, const std::vector<uint32_t> &)
ORIGIN_METHOD(DestroyCommunicationGroup, bool, const std::string &)
ORIGIN_METHOD(GetRankId, uint32_t, const std::string &)
ORIGIN_METHOD(GetCommunicationGroupSize, uint32_t, const std::string &)
ORIGIN_METHOD(AssignLocalRank, bool)
ORIGIN_METHOD(local_rank_id, uint32_t)
#endif
#endif // MINDSPORE_CCSRC_RUNTIME_HARDWARE_COLLECTIVE_COLLECTIVE_LIB_LOADER_H_
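
ORIGIN_METHOD and DlsymFuncObj come from utils/dlopen_macro.h, which is not shown in this diff. As a rough, hypothetical illustration of the pattern they implement (a per-symbol function-pointer alias plus a dlsym lookup keyed on the symbol name), one could write something like the following; the _SKETCH names are made up for this example and are not the real macros.

#include <dlfcn.h>
#include <cstdint>
#include <string>

// Hypothetical stand-ins for ORIGIN_METHOD / DlsymFuncObj; the real macros in
// utils/dlopen_macro.h may differ in details such as error handling.
#define ORIGIN_METHOD_SKETCH(name, return_type, ...) \
  using name##FunPtr = return_type (*)(__VA_ARGS__);

#define DLSYM_FUNC_OBJ_SKETCH(name, handle) \
  reinterpret_cast<name##FunPtr>(dlsym(handle, #name))

ORIGIN_METHOD_SKETCH(GetRankId, uint32_t, const std::string &)

uint32_t QueryRankSketch(void *handle, const std::string &group) {
  // Resolve the exported 'GetRankId' symbol by name and call it through the alias.
  auto get_rank_id = DLSYM_FUNC_OBJ_SKETCH(GetRankId, handle);
  return get_rank_id != nullptr ? get_rank_id(group) : UINT32_MAX;
}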

View File

@ -24,11 +24,8 @@ bool CollectiveCommunicationLib::Finalize() {
}
for (const auto &group : groups_) {
MS_EXCEPTION_IF_NULL(group.second);
if (!group.second->Finalize()) {
MS_LOG(EXCEPTION) << "Finalizing group failed.";
return false;
}
CHECK_IF_NULL(group.second);
CHECK_RET(group.second->Finalize(), true, "Finalizing group failed.");
}
groups_.clear();
initialized_ = false;
@ -36,33 +33,25 @@ bool CollectiveCommunicationLib::Finalize() {
}
bool CollectiveCommunicationLib::DestroyCommunicationGroup(const std::string &group_name) {
if (groups_.count(group_name) == 0) {
MS_LOG(EXCEPTION) << "The group " << group_name << " is not created.";
return false;
}
CHECK_RET(groups_.count(group_name) != 0, true, "The group " + group_name + " is not created.");
auto group = groups_[group_name];
MS_EXCEPTION_IF_NULL(group);
group->Finalize();
CHECK_IF_NULL(group);
CHECK_RET(group->Finalize(), true, "Finalizing group failed.");
(void)groups_.erase(group_name);
return true;
}
uint32_t CollectiveCommunicationLib::GetRankId(const std::string &group_name) {
if (groups_.count(group_name) == 0) {
MS_LOG(EXCEPTION) << "The group " << group_name << " does not exist.";
return UINT32_MAX;
}
CHECK_RET(groups_.count(group_name) != 0, true, "The group " + group_name + " does not exist.");
auto group = groups_[group_name];
MS_EXCEPTION_IF_NULL(group);
CHECK_IF_NULL(group);
return group->GetGroupRank(global_rank_id_);
}
uint32_t CollectiveCommunicationLib::GetGroupSize(const std::string &group_name) {
if (groups_.count(group_name) == 0) {
MS_LOG(EXCEPTION) << "The group " << group_name << " does not exist.";
return UINT32_MAX;
}
CHECK_RET(groups_.count(group_name) != 0, true, "The group " + group_name + " does not exist.");
auto group = groups_[group_name];
MS_EXCEPTION_IF_NULL(group);
CHECK_IF_NULL(group);
return group->group_size();
}

View File

@ -31,12 +31,7 @@ namespace device {
// MsCollectiveCommLib which uses the host-side communication library developed by MindSpore.
class CollectiveCommunicationLib {
public:
CollectiveCommunicationLib()
: collective_comm_lib_ptr_(nullptr),
initialized_(false),
global_rank_id_(0),
local_rank_id_(0),
global_rank_size_(0) {}
CollectiveCommunicationLib() : initialized_(false), global_rank_id_(0), local_rank_id_(0), global_rank_size_(0) {}
virtual ~CollectiveCommunicationLib() { groups_.clear(); }
// Initialize the collective communication library.
@ -70,9 +65,6 @@ class CollectiveCommunicationLib {
uint32_t local_rank_id() const;
protected:
// The third party collective communication library. They are dynamically loaded by MindSpore.
const void *collective_comm_lib_ptr_;
// Whether this collective communication library is initialized.
bool initialized_;

View File

@ -20,8 +20,7 @@ namespace mindspore {
namespace device {
CommunicationGroup::CommunicationGroup(const std::string name, const std::vector<uint32_t> &group_ranks,
uint32_t global_rank)
: collective_comm_lib_ptr_(nullptr),
initialized_(false),
: initialized_(false),
global_rank_(global_rank),
size_(group_ranks.size()),
name_(name),
@ -36,18 +35,14 @@ CommunicationGroup::CommunicationGroup(const std::string name, const std::vector
}
uint32_t CommunicationGroup::GetGroupRank(uint32_t global_rank) {
if (global_to_group_ranks_.count(global_rank) == 0) {
MS_LOG(EXCEPTION) << "Group " << name_ << " doesn't contain the global rank " << global_rank;
return UINT32_MAX;
}
CHECK_RET((global_to_group_ranks_.count(global_rank) != 0), true,
"Group " + name_ + " doesn't contain the global rank " + std::to_string(global_rank));
return global_to_group_ranks_[global_rank];
}
uint32_t CommunicationGroup::GetGlobalRank(uint32_t group_rank) {
if (group_to_global_ranks_.count(group_rank) == 0) {
MS_LOG(EXCEPTION) << "Group " << name_ << " doesn't contain the group rank " << group_rank;
return UINT32_MAX;
}
CHECK_RET((group_to_global_ranks_.count(group_rank) != 0), true,
"Group " + name_ + " doesn't contain the group rank " + std::to_string(group_rank));
return group_to_global_ranks_[group_rank];
}

View File

@ -21,8 +21,9 @@
#include <string>
#include <vector>
#include <memory>
#include "mindspore/core/utils/log_adapter.h"
#include "mindspore/core/utils/convert_utils_base.h"
#include <sstream>
#include <algorithm>
#include "pybind11/pybind11.h"
namespace mindspore {
namespace device {
@ -56,9 +57,6 @@ class CommunicationGroup {
uint32_t group_size() const;
protected:
// The third party collective communication libraries. They are dynamically loaded by MindSpore.
const void *collective_comm_lib_ptr_;
// Whether this communication group is initialized.
bool initialized_;
@ -81,4 +79,24 @@ class CommunicationGroup {
using CommunicationGroupPtr = std::shared_ptr<CommunicationGroup>;
} // namespace device
} // namespace mindspore
#define CHECK_RET(expression, result, message) \
do { \
auto ret = (expression); \
if (ret != result) { \
std::ostringstream oss; \
oss << "Error in file " << __FILE__ << " | Error on line " << __LINE__ << ": " << message; \
pybind11::pybind11_fail(oss.str()); \
} \
} while (0)
#define CHECK_IF_NULL(ptr) \
do { \
if ((ptr) == nullptr) { \
std::ostringstream oss; \
oss << "Error in file " << __FILE__ << " | Error on line " << __LINE__ << ": The pointer[" << #ptr \
<< "] is null."; \
pybind11::pybind11_fail(oss.str()); \
} \
} while (0)
#endif // MINDSPORE_CCSRC_RUNTIME_HARDWARE_COLLECTIVE_COMMUNICATION_GROUP_H_
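
Unlike MS_LOG(EXCEPTION), these macros report failure through pybind11::pybind11_fail, which throws a std::runtime_error that pybind11 surfaces to Python as a RuntimeError. A small hedged usage sketch, not part of the commit:

#include <map>
#include <cstdint>
#include <string>
// Assumes the CHECK_RET / CHECK_IF_NULL definitions above are in scope.

uint32_t GroupRankOrFailSketch(const std::map<uint32_t, uint32_t> &global_to_group, uint32_t global_rank) {
  // Fails with file/line context if the rank is not part of the group.
  CHECK_RET(global_to_group.count(global_rank) != 0, true,
            "Global rank " + std::to_string(global_rank) + " is not in this group.");
  return global_to_group.at(global_rank);
}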

View File

@ -24,47 +24,62 @@ bool MPICollectiveCommLib::Initialize(uint32_t, uint32_t) {
return false;
}
// Initialize MPI interface.
int initialized = 0;
CHECK_MPI_RET(MPI_Initialized(&initialized), "Failed to check MPI initialization status.");
CHECK_RET(MPI_Initialized(&initialized), MPI_SUCCESS, "Failed to check MPI initialization status.");
if (initialized == 0) {
CHECK_MPI_RET(MPI_Init(nullptr, nullptr), "Failed to initialize MPI.");
CHECK_RET(MPI_Init(nullptr, nullptr), MPI_SUCCESS, "Failed to initialize MPI.");
}
// Generate the MPI global rank id and rank size for the world group MPI_COMM_WORLD.
int rank_id = 0;
int rank_size = 0;
CHECK_MPI_RET(MPI_Comm_rank(MPI_COMM_WORLD, &rank_id), "Failed to initialize MPI global rank id.");
CHECK_MPI_RET(MPI_Comm_rank(MPI_COMM_WORLD, &rank_size), "Failed to initialize MPI global rank size.");
global_rank_id_ = IntToUint(rank_id);
global_rank_size_ = IntToUint(rank_size);
MS_LOG(INFO) << "The MPI global rank id of this process is " << global_rank_id_ << ", global rank size is "
<< global_rank_size_;
CHECK_RET(MPI_Comm_rank(MPI_COMM_WORLD, &rank_id), MPI_SUCCESS, "Failed to initialize MPI global rank id.");
CHECK_RET(MPI_Comm_size(MPI_COMM_WORLD, &rank_size), MPI_SUCCESS, "Failed to initialize MPI global rank size.");
global_rank_id_ = static_cast<uint32_t>(rank_id);
global_rank_size_ = static_cast<uint32_t>(rank_size);
// Create the MPI world group because every other group is generated from the world group.
CHECK_MPI_RET(MPI_Comm_group(MPI_COMM_WORLD, &world_group_), "Failed to get group of MPI_COMM_WORLD.");
CHECK_RET(MPI_Comm_group(MPI_COMM_WORLD, &world_group_), MPI_SUCCESS, "Failed to get group of MPI_COMM_WORLD.");
initialized_ = true;
return true;
}
bool MPICollectiveCommLib::CreateCommunicationGroup(const std::string &group_name,
const std::vector<uint32_t> &group_ranks) {
if (groups_.count(group_name) != 0) {
MS_LOG(EXCEPTION) << "The MPI group " << group_name << " has already existed.";
return false;
}
CHECK_RET((groups_.count(group_name) == 0), true, "The MPI group " + group_name + " has already existed.");
MPICommunicationGroupPtr group = std::make_shared<MPICommunicationGroup>(group_name, group_ranks, global_rank_id_);
MS_EXCEPTION_IF_NULL(group);
if (!group->Initialize(world_group_)) {
MS_LOG(EXCEPTION) << "Initializing group failed.";
return false;
}
CHECK_IF_NULL(group);
CHECK_RET(group->Initialize(world_group_), true, "Initializing group failed.");
groups_[group_name] = group;
MS_LOG(INFO) << "MPI group of " << group_name << " is created.";
return true;
}
} // namespace cpu
} // namespace device
} // namespace mindspore
// The exported APIs for 'dlsym' to load.
using MPICollectiveCommLib = mindspore::device::cpu::MPICollectiveCommLib;
bool InitializeCollectiveLib(uint32_t global_rank, uint32_t global_rank_size) {
return MPICollectiveCommLib::GetInstance().Initialize(global_rank, global_rank_size);
}
bool FinalizeCollectiveLib() { return MPICollectiveCommLib::GetInstance().Finalize(); }
bool CreateCommunicationGroup(const std::string &group_name, const std::vector<uint32_t> &group_ranks) {
return MPICollectiveCommLib::GetInstance().CreateCommunicationGroup(group_name, group_ranks);
}
bool DestroyCommunicationGroup(const std::string &group_name) {
return MPICollectiveCommLib::GetInstance().DestroyCommunicationGroup(group_name);
}
uint32_t GetRankId(const std::string &group_name) { return MPICollectiveCommLib::GetInstance().GetRankId(group_name); }
uint32_t GetCommunicationGroupSize(const std::string &group_name) {
return MPICollectiveCommLib::GetInstance().GetGroupSize(group_name);
}
bool AssignLocalRank() { return MPICollectiveCommLib::GetInstance().AssignLocalRank(); }
uint32_t local_rank_id() { return MPICollectiveCommLib::GetInstance().local_rank_id(); }
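
These wrappers can also be exercised by linking a small driver against the shared library directly instead of dlopen-ing it. The following sketch is illustrative only; the group name and ranks are placeholders, and it would need to run under mpirun with matching ranks.

#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

// Redeclarations matching the extern "C" exports in mpi_collective_comm_lib.h.
extern "C" bool InitializeCollectiveLib(uint32_t global_rank, uint32_t global_rank_size);
extern "C" bool CreateCommunicationGroup(const std::string &group_name, const std::vector<uint32_t> &group_ranks);
extern "C" uint32_t GetRankId(const std::string &group_name);
extern "C" bool FinalizeCollectiveLib();

int main() {
  // Initialize() ignores its arguments and queries MPI_COMM_WORLD itself.
  if (!InitializeCollectiveLib(0, 0)) return 1;
  std::vector<uint32_t> ranks = {0, 1};  // hypothetical two-rank group, e.g. mpirun -n 2
  if (!CreateCommunicationGroup("mpi_test_group", ranks)) return 1;
  std::cout << "rank in mpi_test_group: " << GetRankId("mpi_test_group") << std::endl;
  return FinalizeCollectiveLib() ? 0 : 1;
}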

View File

@ -47,4 +47,18 @@ class MPICollectiveCommLib : public CollectiveCommunicationLib {
} // namespace cpu
} // namespace device
} // namespace mindspore
#ifndef EXPORT_MPI_WRAPPER
#define EXPORT_MPI_WRAPPER __attribute__((visibility("default")))
#endif
extern "C" EXPORT_MPI_WRAPPER bool InitializeCollectiveLib(uint32_t global_rank = UINT32_MAX,
uint32_t global_rank_size = UINT32_MAX);
extern "C" EXPORT_MPI_WRAPPER bool FinalizeCollectiveLib();
extern "C" EXPORT_MPI_WRAPPER bool CreateCommunicationGroup(const std::string &group_name,
const std::vector<uint32_t> &group_ranks);
extern "C" EXPORT_MPI_WRAPPER bool DestroyCommunicationGroup(const std::string &group_name);
extern "C" EXPORT_MPI_WRAPPER uint32_t GetRankId(const std::string &group_name);
extern "C" EXPORT_MPI_WRAPPER uint32_t GetGroupSize(const std::string &group_name);
extern "C" EXPORT_MPI_WRAPPER bool AssignLocalRank();
extern "C" EXPORT_MPI_WRAPPER uint32_t local_rank_id();
#endif // MINDSPORE_CCSRC_RUNTIME_HARDWARE_CPU_MPI_COLLECTIVE_COMM_LIB_H_

View File

@ -27,8 +27,10 @@ bool MPICommunicationGroup::Finalize() {
if (!initialized_) {
return false;
}
CHECK_MPI_RET(MPI_Comm_free(&group_communicator_), "Freeing MPI group communicator for " + name_ + " failed.");
CHECK_MPI_RET(MPI_Group_free(&group_), "Freeing MPI group for " + name_ + " failed.");
CHECK_RET(MPI_Comm_free(&group_communicator_), MPI_SUCCESS,
"Freeing MPI group communicator for " + name_ + " failed.");
CHECK_RET(MPI_Group_free(&group_), MPI_SUCCESS, "Freeing MPI group for " + name_ + " failed.");
initialized_ = false;
return true;
}
@ -38,15 +40,12 @@ bool MPICommunicationGroup::Initialize(const MPI_Group &world_group) {
return false;
}
std::vector<int> ranks(group_ranks_.begin(), group_ranks_.end());
CHECK_MPI_RET(MPI_Group_incl(world_group, ranks.size(), ranks.data(), &group_),
"Creating MPI group for " + name_ + " failed.");
CHECK_RET(MPI_Group_incl(world_group, ranks.size(), ranks.data(), &group_), MPI_SUCCESS,
"Creating MPI group for " + name_ + " failed.");
CHECK_RET(MPI_Comm_create(MPI_COMM_WORLD, group_, &group_communicator_), MPI_SUCCESS,
"Creating MPI group communicator for " + name_ + " failed.");
CHECK_MPI_RET(MPI_Comm_create(MPI_COMM_WORLD, group_, &group_communicator_),
"Creating MPI group communicator for " + name_ + " failed.");
if (group_communicator_ == MPI_COMM_NULL) {
MS_LOG(EXCEPTION) << "The MPI communicator for group " << name_ << " failed.";
return false;
}
CHECK_RET(group_communicator_ != MPI_COMM_NULL, true, "Creating MPI communicator for group " + name_ + " failed: got MPI_COMM_NULL.");
initialized_ = true;
return true;
}

View File

@ -22,6 +22,7 @@
#include <vector>
#include <memory>
#include "runtime/hardware/collective/communication_group.h"
#include "utils/dlopen_macro.h"
namespace mindspore {
namespace device {
@ -44,16 +45,6 @@ class MPICommunicationGroup : public CommunicationGroup {
MPI_Comm group_communicator_;
};
using MPICommunicationGroupPtr = std::shared_ptr<MPICommunicationGroup>;
#define CHECK_MPI_RET(expression, message) \
do { \
{ \
auto ret = (expression); \
if (ret != MPI_SUCCESS) { \
MS_LOG(EXCEPTION) << (message); \
} \
} \
} while (false)
} // namespace cpu
} // namespace device
} // namespace mindspore

View File

@ -23,6 +23,7 @@
#include "runtime/device/device_address.h"
#include "runtime/device/bucket.h"
#include "runtime/hardware/collective/collective_communication_lib.h"
#include "runtime/hardware/collective/collective_comm_lib_loader.h"
#include "backend/session/kernel_graph.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "backend/optimizer/common/common_backend_optimization.h"
@ -49,7 +50,7 @@ struct DeviceContextKey {
class DeviceContext {
public:
explicit DeviceContext(const DeviceContextKey &device_context_key)
: device_context_key_(device_context_key), collective_comm_lib_(nullptr) {}
: device_context_key_(device_context_key), collective_comm_lib_ptr_(nullptr) {}
virtual ~DeviceContext() = default;
// Initialize the device context.
@ -150,7 +151,7 @@ class DeviceContext {
virtual bool InitCollectiveCommLib() { return true; }
// Return the collective communication library handle for the caller to access.
CollectiveCommunicationLibPtr collective_comm_lib() const { return collective_comm_lib_; }
void *collective_comm_lib() const { return collective_comm_lib_ptr_; }
// TODO(jiaorui): will be deleted
// Dump all graphs.
@ -158,10 +159,11 @@ class DeviceContext {
protected:
DeviceContextKey device_context_key_;
CollectiveCommunicationLibPtr collective_comm_lib_;
// The dynamically loaded handle for the collective communication library.
void *collective_comm_lib_ptr_;
};
using DeviceContextPtr = std::shared_ptr<DeviceContext>;
} // namespace device
} // namespace mindspore
#endif // MINDSPORE_CCSRC_RUNTIME_HARDWARE_DEVICE_CONTEXT_H_

View File

@ -30,9 +30,6 @@
#include "backend/kernel_compiler/common_utils.h"
#include "runtime/device/gpu/gpu_common.h"
#include "runtime/hardware/gpu/optimizer.h"
#ifdef ENABLE_MPI
#include "runtime/hardware/gpu/nvidia_collective_comm_lib.h"
#endif
#include "common/trans.h"
#include "utils/context/graph_kernel_flags.h"
#include "runtime/device/gpu/gpu_bucket.h"
@ -530,8 +527,21 @@ std::shared_ptr<Bucket> GPUDeviceContext::CreateBucket(uint32_t bucket_id, uint3
bool GPUDeviceContext::InitCollectiveCommLib() {
#ifdef ENABLE_MPI
collective_comm_lib_ = &NvidiaCollectiveCommLib::GetInstance();
collective_comm_lib_->Initialize();
std::string nvidia_comm_lib_name = "libnvidia_collective.so";
auto loader = std::make_shared<CollectiveCommLibLoader>(nvidia_comm_lib_name);
MS_EXCEPTION_IF_NULL(loader);
if (!loader->Initialize()) {
MS_LOG(EXCEPTION) << "Loading NCCL collective library failed.";
return false;
}
collective_comm_lib_ptr_ = loader->collective_comm_lib_ptr();
MS_EXCEPTION_IF_NULL(collective_comm_lib_ptr_);
auto init_collective_lib_func = DlsymFuncObj(InitializeCollectiveLib, collective_comm_lib_ptr_);
if (!init_collective_lib_func(0, 0)) {
MS_LOG(EXCEPTION) << "Initializing NCCL collective library failed.";
return false;
}
return true;
#else
return false;
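
Once the handle is cached in collective_comm_lib_ptr_, the remaining exported symbols can presumably be resolved the same way. A hedged continuation sketch (not part of this commit) using the symbols registered in collective_comm_lib_loader.h; the group name and ranks are placeholders:

// Illustrative only: resolve further exported symbols from the cached handle and create a group.
void CreateNcclGroupSketch(void *comm_lib_handle) {
  auto create_group_func = DlsymFuncObj(CreateCommunicationGroup, comm_lib_handle);
  auto get_rank_id_func = DlsymFuncObj(GetRankId, comm_lib_handle);
  std::vector<uint32_t> world_ranks = {0, 1, 2, 3};  // hypothetical 4-rank job
  if (!create_group_func("nccl_world_group", world_ranks)) {
    MS_LOG(EXCEPTION) << "Creating nccl_world_group failed.";
  }
  MS_LOG(INFO) << "Group rank: " << get_rank_id_func("nccl_world_group");
}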

View File

@ -19,10 +19,6 @@
namespace mindspore {
namespace device {
namespace gpu {
NvidiaCollectiveCommLib::NvidiaCollectiveCommLib() {
collective_comm_lib_ptr_ = CollectiveInitializer::instance().collective_handle();
}
bool NvidiaCollectiveCommLib::Initialize(uint32_t global_rank, uint32_t global_rank_size) {
if (initialized_) {
return false;
@ -30,26 +26,48 @@ bool NvidiaCollectiveCommLib::Initialize(uint32_t global_rank, uint32_t global_r
global_rank_id_ = global_rank;
global_rank_size_ = global_rank_size;
MS_LOG(INFO) << "The global rank id of this process is " << global_rank_id_
<< ", global rank size of nccl_world_group is " << global_rank_size_;
initialized_ = true;
return true;
}
bool NvidiaCollectiveCommLib::CreateCommunicationGroup(const std::string &group_name,
const std::vector<uint32_t> &group_ranks) {
if (groups_.count(group_name) != 0) {
MS_LOG(EXCEPTION) << "The NCCL group " << group_name << " has already existed.";
return false;
}
CHECK_RET((groups_.count(group_name) == 0), true, "The NCCL group " + group_name + " has already existed.");
NvidiaCommunicationGroupPtr group =
std::make_shared<NvidiaCommunicationGroup>(group_name, group_ranks, global_rank_id_);
MS_EXCEPTION_IF_NULL(group);
CHECK_IF_NULL(group);
groups_[group_name] = group;
MS_LOG(INFO) << "NCCL group of " << group_name << " is created. But it's not initialized yet.";
return true;
}
} // namespace gpu
} // namespace device
} // namespace mindspore
// The exported APIs for 'dlsym' to load.
using NvidiaCollectiveCommLib = mindspore::device::gpu::NvidiaCollectiveCommLib;
bool InitializeCollectiveLib(uint32_t global_rank, uint32_t global_rank_size) {
return NvidiaCollectiveCommLib::GetInstance().Initialize(global_rank, global_rank_size);
}
bool FinalizeCollectiveLib() { return NvidiaCollectiveCommLib::GetInstance().Finalize(); }
bool CreateCommunicationGroup(const std::string &group_name, const std::vector<uint32_t> &group_ranks) {
return NvidiaCollectiveCommLib::GetInstance().CreateCommunicationGroup(group_name, group_ranks);
}
bool DestroyCommunicationGroup(const std::string &group_name) {
return NvidiaCollectiveCommLib::GetInstance().DestroyCommunicationGroup(group_name);
}
uint32_t GetRankId(const std::string &group_name) {
return NvidiaCollectiveCommLib::GetInstance().GetRankId(group_name);
}
uint32_t GetCommunicationGroupSize(const std::string &group_name) {
return NvidiaCollectiveCommLib::GetInstance().GetGroupSize(group_name);
}
bool AssignLocalRank() { return NvidiaCollectiveCommLib::GetInstance().AssignLocalRank(); }
uint32_t local_rank_id() { return NvidiaCollectiveCommLib::GetInstance().local_rank_id(); }

View File

@ -40,10 +40,24 @@ class NvidiaCollectiveCommLib : public CollectiveCommunicationLib {
bool CreateCommunicationGroup(const std::string &group_name, const std::vector<uint32_t> &group_ranks) override;
private:
NvidiaCollectiveCommLib();
NvidiaCollectiveCommLib() = default;
~NvidiaCollectiveCommLib() override = default;
};
} // namespace gpu
} // namespace device
} // namespace mindspore
#ifndef EXPORT_NCCL_WRAPPER
#define EXPORT_NCCL_WRAPPER __attribute__((visibility("default")))
#endif
extern "C" EXPORT_NCCL_WRAPPER bool InitializeCollectiveLib(uint32_t global_rank = UINT32_MAX,
uint32_t global_rank_size = UINT32_MAX);
extern "C" EXPORT_NCCL_WRAPPER bool FinalizeCollectiveLib();
extern "C" EXPORT_NCCL_WRAPPER bool CreateCommunicationGroup(const std::string &group_name,
const std::vector<uint32_t> &group_ranks);
extern "C" EXPORT_NCCL_WRAPPER bool DestroyCommunicationGroup(const std::string &group_name);
extern "C" EXPORT_NCCL_WRAPPER uint32_t GetRankId(const std::string &group_name);
extern "C" EXPORT_NCCL_WRAPPER uint32_t GetCommunicationGroupSize(const std::string &group_name);
extern "C" EXPORT_NCCL_WRAPPER bool AssignLocalRank();
extern "C" EXPORT_NCCL_WRAPPER uint32_t local_rank_id();
#endif // MINDSPORE_CCSRC_RUNTIME_HARDWARE_CPU_NVIDIA_COLLECTIVE_COMM_LIB_H_

View File

@ -21,9 +21,7 @@ namespace device {
namespace gpu {
NvidiaCommunicationGroup::NvidiaCommunicationGroup(const std::string name, const std::vector<uint32_t> &group_ranks,
uint32_t global_rank)
: CommunicationGroup(name, group_ranks, global_rank) {
collective_comm_lib_ptr_ = CollectiveInitializer::instance().collective_handle();
}
: CommunicationGroup(name, group_ranks, global_rank) {}
bool NvidiaCommunicationGroup::Initialize(void *root_info) {
if (initialized_) {
@ -32,20 +30,11 @@ bool NvidiaCommunicationGroup::Initialize(void *root_info) {
// The unique id is broadcasted by the root rank.
unique_id_ = *(static_cast<ncclUniqueId *>(root_info));
uint32_t group_rank = GetGroupRank(global_rank_);
// Initialize the NCCL communicator when the group is created. Note that 'ncclCommInitRank' must be called
// after the GPU device id is set.
MS_EXCEPTION_IF_NULL(collective_comm_lib_ptr_);
auto comm_init_rank =
reinterpret_cast<NCCLCommInitRank>(dlsym(const_cast<void *>(collective_comm_lib_ptr_), "NCCLCommInitRank"));
MS_EXCEPTION_IF_NULL(comm_init_rank);
MS_LOG(INFO) << "Start initializing NCCL communicator for group " << name_;
uint32_t group_rank = GetGroupRank(global_rank_);
CHECK_NCCL_RET((*comm_init_rank)(&comm_, SizeToInt(size_), unique_id_, UintToInt(group_rank)),
"Initializing NCCL communicator failed.");
MS_LOG(INFO) << "NCCL communicator for group " << name_ << " is successfully initialized.";
CHECK_RET(ncclCommInitRank(&comm_, static_cast<int>(size_), unique_id_, static_cast<int>(group_rank)), ncclSuccess,
"Initializing NCCL communicator failed.");
initialized_ = true;
return true;
}
@ -55,26 +44,14 @@ bool NvidiaCommunicationGroup::Finalize() {
return false;
}
MS_EXCEPTION_IF_NULL(collective_comm_lib_ptr_);
auto comm_abort =
reinterpret_cast<NCCLCommAbort>(dlsym(const_cast<void *>(collective_comm_lib_ptr_), "NCCLCommAbort"));
MS_EXCEPTION_IF_NULL(comm_abort);
auto comm_destroy =
reinterpret_cast<NCCLCommDestroy>(dlsym(const_cast<void *>(collective_comm_lib_ptr_), "NCCLCommDestroy"));
MS_EXCEPTION_IF_NULL(comm_destroy);
CHECK_NCCL_RET((*comm_abort)(comm_), "Failed to abort NCCL communicator.");
CHECK_NCCL_RET((*comm_destroy)(comm_), "Failed to destroy NCCL communicator.");
CHECK_RET(ncclCommAbort(comm_), ncclSuccess, "Failed to abort NCCL communicator.");
CHECK_RET(ncclCommDestroy(comm_), ncclSuccess, "Failed to destroy NCCL communicator.");
initialized_ = false;
return true;
}
void *NvidiaCommunicationGroup::GenerateRootInfo() {
MS_EXCEPTION_IF_NULL(collective_comm_lib_ptr_);
auto nccl_id_funcptr =
reinterpret_cast<NcclUniqueId>(dlsym(const_cast<void *>(collective_comm_lib_ptr_), "nccl_unique_id"));
MS_EXCEPTION_IF_NULL(nccl_id_funcptr);
unique_id_ = (*nccl_id_funcptr)();
CHECK_RET(ncclGetUniqueId(&unique_id_), ncclSuccess, "Failed to get NCCL unique id.");
return &unique_id_;
}
} // namespace gpu
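
The group now calls NCCL directly rather than resolving symbols from a pre-loaded handle. For reference, a condensed standalone sketch of the NCCL sequence these methods rely on; unique-id distribution between ranks is out of scope and left as a placeholder.

#include <nccl.h>
// Condensed illustration of the NCCL calls used above; error handling elided.
void NcclLifecycleSketch(int group_size, int group_rank, const ncclUniqueId &unique_id) {
  // 'unique_id' must come from ncclGetUniqueId() on the root rank and be
  // shared with every member of the group before this call.
  ncclComm_t comm;
  ncclCommInitRank(&comm, group_size, unique_id, group_rank);  // collective across the group
  // ... issue collectives on 'comm' ...
  ncclCommDestroy(comm);
}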

View File

@ -22,16 +22,11 @@
#include <vector>
#include <memory>
#include "runtime/hardware/collective/communication_group.h"
#include "runtime/device/gpu/distribution/collective_init.h"
#include "utils/dlopen_macro.h"
namespace mindspore {
namespace device {
namespace gpu {
using NcclUniqueId = ncclUniqueId (*)();
using NCCLCommInitRank = ncclResult_t (*)(ncclComm_t *, int, ncclUniqueId, int);
using NCCLCommAbort = ncclResult_t (*)(ncclComm_t);
using NCCLCommDestroy = ncclResult_t (*)(ncclComm_t);
class NvidiaCommunicationGroup : public CommunicationGroup {
public:
explicit NvidiaCommunicationGroup(const std::string name, const std::vector<uint32_t> &group_ranks,
@ -52,17 +47,6 @@ class NvidiaCommunicationGroup : public CommunicationGroup {
ncclComm_t comm_;
};
using NvidiaCommunicationGroupPtr = std::shared_ptr<NvidiaCommunicationGroup>;
using CollectiveInitializer = device::gpu::CollectiveInitializer;
#define CHECK_NCCL_RET(expression, message) \
do { \
{ \
auto ret = (expression); \
if (ret != ncclSuccess) { \
MS_LOG(EXCEPTION) << (message); \
} \
} \
} while (false)
} // namespace gpu
} // namespace device
} // namespace mindspore