!26720 Add implementation of collective manager API

Merge pull request !26720 from ZPaC/dir-of-distributed
This commit is contained in:
i-robot 2021-11-25 13:28:14 +00:00 committed by Gitee
commit 6c20f55a51
13 changed files with 292 additions and 57 deletions

View File

@ -28,7 +28,9 @@ CollectiveManager::CollectiveManager()
host_ctx_(nullptr),
device_ctx_(nullptr),
host_comm_lib_(nullptr),
host_comm_lib_instance_(nullptr),
device_comm_lib_(nullptr),
device_comm_lib_instance_(nullptr),
global_rank_id_(0),
local_rank_id_(0),
global_rank_size_(0),
@ -61,13 +63,25 @@ bool CollectiveManager::Initialize(const std::string &backend, const std::string
return false;
}
MS_EXCEPTION_IF_NULL(host_comm_lib_);
// Step 2: Create global communication group on host side.
if (!CreateHostGlobalCommGroup(global_group_name)) {
// Steps 2, 3 and 4 involve the device communication library, so they are not necessary when the training job is
// launched only on CPU.
// Step 2: Assign the local rank id (device id) for this process.
if (!AssignLocalRank(global_group_name)) {
MS_LOG(ERROR) << "Failed to assign local rank id.";
return false;
}
// Step 3: Initialize device side collective communication.
if (!InitDeviceCommLib(backend)) {
MS_LOG(ERROR) << "Failed to initialize device communication library.";
return false;
}
// Step 4: Create global communication group.
if (!CreateCommunicationGroup(global_group_name, global_group_ranks_)) {
MS_LOG(ERROR) << "Failed to initialize host communication library.";
return false;
}
// Step 3: Assign local rank id(device id) for this process.
MS_LOG(INFO) << "End initializing collective communication for backend: " << backend << ".";
return true;
@ -75,25 +89,81 @@ bool CollectiveManager::Initialize(const std::string &backend, const std::string
bool CollectiveManager::CreateCommunicationGroup(const std::string &group_name,
const std::vector<uint32_t> &group_ranks) {
MS_EXCEPTION_IF_NULL(host_comm_lib_);
MS_EXCEPTION_IF_NULL(device_comm_lib_);
MS_EXCEPTION_IF_NULL(host_comm_lib_instance_);
MS_EXCEPTION_IF_NULL(device_comm_lib_instance_);
// Step 1: Create communication group on host side.
// Step 2: Generate device information of the root node.
// Step 3: Broadcast the device root information to all nodes.
// Step 4: Create communication group on device side.
if (!host_comm_lib_instance_->CreateCommunicationGroup(group_name, group_ranks)) {
MS_LOG(ERROR) << "Failed to create communication group " << group_name << " on host side.";
return false;
}
// Step 2: Create communication group on device side.
if (!device_comm_lib_instance_->CreateCommunicationGroup(group_name, group_ranks)) {
MS_LOG(ERROR) << "Failed to create communication group " << group_name << " on device side.";
return false;
}
// Step 3: Generate device information of the root node.
CommunicationGroupPtr group = device_comm_lib_instance_->GetGroup(group_name);
MS_EXCEPTION_IF_NULL(group);
size_t root_info_size = 0;
void *root_info = group->GenerateRootInfo(&root_info_size);
MS_EXCEPTION_IF_NULL(root_info);
// Step 4: Broadcast the device root information to all nodes on host side.
if (!host_comm_lib_instance_->Broadcast(root_info, root_info, root_info_size, TypeId::kNumberTypeInt, 0, group_name,
nullptr)) {
MS_LOG(ERROR) << "Broadcast for device root info failed on the host side.";
return false;
}
// Step 5: Initialize communication group on the device side.
if (!group->Initialize(root_info)) {
MS_LOG(ERROR) << "Initialize group on the device side failed.";
return false;
}
return true;
}
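The five steps above mirror the standard NCCL bootstrap over an out-of-band host channel. As a hedged reference sketch using raw NCCL and MPI calls (not this codebase's wrappers, assuming <nccl.h> and <mpi.h>):

// Reference sketch: bootstrap one NCCL communicator over MPI.
ncclComm_t BootstrapNcclComm(int group_rank, int group_size) {
  ncclUniqueId unique_id{};
  if (group_rank == 0) {
    ncclGetUniqueId(&unique_id);  // step 3: only the root generates the root info
  }
  // step 4: broadcast the root info to every rank through the host-side channel
  MPI_Bcast(&unique_id, static_cast<int>(sizeof(unique_id)), MPI_BYTE, 0, MPI_COMM_WORLD);
  ncclComm_t comm = nullptr;
  ncclCommInitRank(&comm, group_size, unique_id, group_rank);  // step 5: device-side group initialization
  return comm;
}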
bool CollectiveManager::DestroyCommunicationGroup(const std::string &group_name) { return true; }
bool CollectiveManager::DestroyCommunicationGroup(const std::string &group_name) {
MS_EXCEPTION_IF_NULL(host_comm_lib_instance_);
if (!host_comm_lib_instance_->DestroyCommunicationGroup(group_name)) {
MS_LOG(ERROR) << "Failed to destroy communication group " << group_name << " on the host side.";
return false;
}
MS_EXCEPTION_IF_NULL(device_comm_lib_instance_);
if (!device_comm_lib_instance_->DestroyCommunicationGroup(group_name)) {
MS_LOG(ERROR) << "Failed to destroy communication group " << group_name << " on the device side.";
return false;
}
return true;
}
uint32_t CollectiveManager::GetRankId(const std::string &group_name) { return 0; }
uint32_t CollectiveManager::GetRankId(const std::string &group_name) {
MS_EXCEPTION_IF_NULL(host_comm_lib_instance_);
return host_comm_lib_instance_->GetRankId(group_name);
}
uint32_t CollectiveManager::GetGroupSize(const std::string &group_name) { return 0; }
uint32_t CollectiveManager::GetGroupSize(const std::string &group_name) {
MS_EXCEPTION_IF_NULL(host_comm_lib_instance_);
return host_comm_lib_instance_->GetGroupSize(group_name);
}
bool CollectiveManager::Finalize() {
if (finalized_) {
return true;
}
MS_EXCEPTION_IF_NULL(host_comm_lib_instance_);
if (!host_comm_lib_instance_->Finalize()) {
MS_LOG(WARNING) << "Failed to finalize host communication library.";
}
MS_EXCEPTION_IF_NULL(device_comm_lib_instance_);
if (!device_comm_lib_instance_->Finalize()) {
MS_LOG(WARNING) << "Failed to finalize device communication library.";
}
return true;
}
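A hedged sketch of the intended teardown order, destroying user-created groups before finalizing both libraries; the accessor and the bookkeeping of created group names are assumptions:

// Hypothetical shutdown helper built on the APIs above.
bool ShutdownCollective(const std::vector<std::string> &created_groups) {
  auto manager = CollectiveManager::instance();  // assumed accessor
  for (const auto &name : created_groups) {
    if (!manager->DestroyCommunicationGroup(name)) {
      return false;
    }
  }
  return manager->Finalize();
}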
@ -105,19 +175,34 @@ bool CollectiveManager::InitHostCommlib() {
MS_LOG(ERROR) << "Failed to load communication library on the host side.";
return false;
}
return true;
}
bool CollectiveManager::CreateHostGlobalCommGroup(const std::string &global_group_name) {
host_comm_lib_ = host_ctx_->collective_comm_lib();
MS_EXCEPTION_IF_NULL(host_comm_lib_);
if (global_group_ranks_.empty()) {
MS_LOG(ERROR) << "The global group rank list is empty.";
auto instance_func = DlsymFuncObj(communication_lib_instance, host_comm_lib_);
host_comm_lib_instance_ = instance_func();
MS_EXCEPTION_IF_NULL(host_comm_lib_instance_);
// For some communication libraries, 'global_rank_id_' and 'global_rank_size_' should be set by the caller, e.g., when
// using the MindSpore communication framework. For others, e.g., OpenMPI, the global rank id and size are generated by
// the library itself, and the parameters 'global_rank_id_' and 'global_rank_size_' are not used.
MS_LOG(INFO) << "Start initializing communication library on host side...";
if (!host_comm_lib_instance_->Initialize(global_rank_id_, global_rank_size_)) {
MS_LOG(ERROR) << "Failed to initialize communication library on host side.";
return false;
}
// Reassign 'global_rank_id_' and 'global_rank_size_'. Generate global communication group ranks.
global_rank_id_ = host_comm_lib_instance_->global_rank_id();
global_rank_size_ = host_comm_lib_instance_->global_rank_size();
for (uint32_t i = 0; i < global_rank_size_; i++) {
global_group_ranks_.push_back(i);
}
MS_LOG(INFO) << "Communication library on host side is successfully initialized. Global rank id: " << global_rank_id_
<< ", global rank size: " << global_rank_size_;
return true;
}
bool CollectiveManager::InitDeviceCommLib(const std::string &backend, uint32_t device_id) {
bool CollectiveManager::InitDeviceCommLib(const std::string &backend) {
std::string device_name;
if (backend == "nccl") {
device_name = "GPU";
@ -128,13 +213,68 @@ bool CollectiveManager::InitDeviceCommLib(const std::string &backend, uint32_t d
return false;
}
device::DeviceContextKey device_key = {device_name, device_id};
device::DeviceContextKey device_key = {device_name, local_rank_id_};
device_ctx_ = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext(device_key);
MS_EXCEPTION_IF_NULL(device_ctx_);
if (!device_ctx_->LoadCollectiveCommLib()) {
MS_LOG(ERROR) << "Failed to load communication library on the device side.";
return false;
}
device_comm_lib_ = device_ctx_->collective_comm_lib();
MS_EXCEPTION_IF_NULL(device_comm_lib_);
auto instance_func = DlsymFuncObj(communication_lib_instance, device_comm_lib_);
device_comm_lib_instance_ = instance_func();
MS_EXCEPTION_IF_NULL(device_comm_lib_instance_);
MS_LOG(INFO) << "Start initializing communication library on device side...";
if (!device_comm_lib_instance_->Initialize(global_rank_id_, global_rank_size_)) {
MS_LOG(ERROR) << "Failed to initialize communication library on device side.";
return false;
}
MS_LOG(INFO) << "Communication library on device side is successfully initialized.";
return true;
}
bool CollectiveManager::AssignLocalRank(const std::string &global_group_name) {
char host_name[MAX_HOSTNAME_LEN] = {0};
#ifndef _WIN32
if (gethostname(host_name, MAX_HOSTNAME_LEN) != 0) {
MS_LOG(ERROR) << "Failed to get host name.";
return false;
}
#endif
MS_LOG(INFO) << "Host name for rank " << global_rank_id_ << " is " << host_name;
// Generate a host name hash for every process. Host names of different physical machines should differ so that local
// rank ids do not repeat.
size_t host_hash = std::hash<std::string>()(host_name);
// Use a dynamically sized buffer because 'global_rank_size_' is only known at run time.
std::vector<size_t> all_host_hashs(global_rank_size_);
if (global_rank_id_ >= global_rank_size_) {
MS_LOG(ERROR) << "The global rank id " << global_rank_id_ << " should be less than global rank size "
<< global_rank_size_;
return false;
}
all_host_hashs[global_rank_id_] = host_hash;
MS_EXCEPTION_IF_NULL(host_comm_lib_instance_);
// AllGather host names across the global communication group.
if (!host_comm_lib_instance_->AllGather(&host_hash, all_host_hashs.data(), sizeof(size_t), TypeId::kNumberTypeInt,
global_group_name, nullptr)) {
MS_LOG(ERROR) << "AllGather for host names failed.";
return false;
}
// Accumulate rank id.
for (uint32_t rank = 0; rank < global_rank_size_; rank++) {
if (rank == global_rank_id_) {
break;
}
if (all_host_hashs[rank] == all_host_hashs[global_rank_id_]) {
local_rank_id_++;
}
}
MS_LOG(INFO) << "The local rank id assigned for this process is " << local_rank_id_;
return true;
}
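The local-rank assignment boils down to counting how many lower global ranks share this node's host-name hash. A minimal standalone sketch of that counting step (hypothetical helper, same logic as the loop above):

// Hypothetical helper mirroring the counting logic in AssignLocalRank.
uint32_t ComputeLocalRank(const std::vector<size_t> &all_host_hashs, uint32_t my_global_rank) {
  uint32_t local_rank = 0;
  for (uint32_t rank = 0; rank < my_global_rank; ++rank) {
    // Every lower-ranked process on the same physical machine bumps this process's local rank (device id).
    if (all_host_hashs[rank] == all_host_hashs[my_global_rank]) {
      ++local_rank;
    }
  }
  return local_rank;
}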
} // namespace collective

View File

@ -22,6 +22,7 @@
#include <vector>
#include <atomic>
#include "utils/ms_utils.h"
#include "distributed/constants.h"
#include "runtime/hardware/device_context_manager.h"
namespace mindspore {
@ -30,6 +31,8 @@ namespace collective {
using DeviceContext = device::DeviceContext;
using DeviceContextKey = device::DeviceContextKey;
using DeviceContextManager = device::DeviceContextManager;
using CollectiveCommunicationLib = device::CollectiveCommunicationLib;
using CommunicationGroupPtr = device::CommunicationGroupPtr;
// The collective communication API.
// MindSpore uses OpenMPI on CPU, NCCL on GPU and HCCL on Ascend to achieve distributed training.
@ -43,6 +46,9 @@ class CollectiveManager {
// Initialize the collective communication for distributed training with the backend name, e.g., NCCL or HCCL.
bool Initialize(const std::string &backend, const std::string &global_group_name);
// Finalize the collective communication.
bool Finalize();
// Create communication group.
bool CreateCommunicationGroup(const std::string &group_name, const std::vector<uint32_t> &group_ranks);
@ -55,8 +61,10 @@ class CollectiveManager {
// Get the size of the specified group.
uint32_t GetGroupSize(const std::string &group_name);
// Finalize the collective communication.
bool Finalize();
// In some cases the global rank id and rank size should be set by the caller, e.g., when using the MindSpore
// communication framework, where they are generated by cluster::ClusterContext.
void set_global_rank_id(uint32_t global_rank_id);
void set_global_rank_size(uint32_t global_rank_size);
private:
CollectiveManager();
@ -64,14 +72,11 @@ class CollectiveManager {
// Initialize communication library on host side.
bool InitHostCommlib();
// Create world communication group on the host side.
bool CreateHostGlobalCommGroup(const std::string &global_group_name);
// Initialize communication library on device side.
bool InitDeviceCommLib(const std::string &backend, uint32_t device_id);
bool InitDeviceCommLib(const std::string &backend);
// Create world communication group on the device side.
bool CreateDeviceGlobalCommGroup(const std::string &global_group_name);
// Assign the local rank id for this process.
bool AssignLocalRank(const std::string &global_group_name);
std::atomic_bool inited_;
std::atomic_bool finalized_;
@ -81,9 +86,15 @@ class CollectiveManager {
DeviceContext *host_ctx_;
DeviceContext *device_ctx_;
// The dynamically loaded handle for collective communication library by 'dlopen'.
// The host communication library refers to the communication library for CPU, e.g., OpenMPI or the MindSpore
// communication framework.
void *host_comm_lib_;
CollectiveCommunicationLib *host_comm_lib_instance_;
// The device communication library refers to the communication library for NPU or GPU, e.g., NCCL or HCCL.
// When only CPU backend is used, device communication library should not be initialized.
void *device_comm_lib_;
CollectiveCommunicationLib *device_comm_lib_instance_;
// The global rank id of this process. Normally this ranges from 0 to `total process number - 1`.
uint32_t global_rank_id_;

View File

@ -34,6 +34,7 @@ constexpr char kEnvRoleOfScheduler[] = "MS_SCHED";
const std::set<std::string> kValidRoleName = {kEnvRoleOfServer, kEnvRoleOfWorker, kEnvRoleOfScheduler};
constexpr char kLocalHost[] = "127.0.0.1";
constexpr int MAX_HOSTNAME_LEN = 1024;
const uint16_t kDefaultSchedPort = 6667;
const uint16_t kMaxPort = 65535;
} // namespace distributed

View File

@ -21,6 +21,7 @@
#include <memory>
#include <vector>
#include "utils/dlopen_macro.h"
#include "runtime/hardware/collective/collective_communication_lib.h"
namespace mindspore {
namespace device {
@ -51,17 +52,6 @@ using CollectiveCommLibLoaderPtr = std::shared_ptr<CollectiveCommLibLoader>;
} // namespace device
} // namespace mindspore
#ifndef _WIN32
// The exported symbols of collective communication shared library is registered here.
ORIGIN_METHOD(InitializeCollectiveLib, bool, uint32_t, uint32_t)
ORIGIN_METHOD(FinalizeCollectiveLib, bool)
ORIGIN_METHOD(CreateCommunicationGroup, bool, const std::string &, const std::vector<uint32_t> &)
ORIGIN_METHOD(DestroyCommunicationGroup, bool, const std::string &)
ORIGIN_METHOD(GetRankId, uint32_t, const std::string &)
ORIGIN_METHOD(GetCommunicationGroupSize, uint32_t, const std::string &)
ORIGIN_METHOD(AssignLocalRank, bool)
ORIGIN_METHOD(global_rank_id, uint32_t)
ORIGIN_METHOD(local_rank_id, uint32_t)
ORIGIN_METHOD(global_rank_size, uint32_t)
#endif
ORIGIN_METHOD(communication_lib_instance, mindspore::device::CollectiveCommunicationLib *)
#endif // MINDSPORE_CCSRC_RUNTIME_HARDWARE_COLLECTIVE_COLLECTIVE_LIB_LOADER_H_
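With the per-API ORIGIN_METHOD registrations removed, only one factory symbol has to be resolved at run time. Stripped of the project's macros, the underlying mechanism is plain dlopen/dlsym; a hedged sketch with a hypothetical library path:

// Hedged sketch of resolving the single exported factory symbol (POSIX <dlfcn.h>).
#include <dlfcn.h>
using InstanceFunc = mindspore::device::CollectiveCommunicationLib *(*)();

mindspore::device::CollectiveCommunicationLib *LoadCommLib(const char *lib_path /* hypothetical .so path */) {
  void *handle = dlopen(lib_path, RTLD_LAZY | RTLD_LOCAL);
  if (handle == nullptr) {
    return nullptr;
  }
  // 'communication_lib_instance' is the extern "C" symbol each backend library exports.
  auto factory = reinterpret_cast<InstanceFunc>(dlsym(handle, "communication_lib_instance"));
  return factory != nullptr ? factory() : nullptr;
}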

View File

@ -55,6 +55,11 @@ uint32_t CollectiveCommunicationLib::GetGroupSize(const std::string &group_name)
return group->group_size();
}
CommunicationGroupPtr CollectiveCommunicationLib::GetGroup(const std::string &group_name) {
CHECK_RET(groups_.count(group_name) != 0, true, "The group " + group_name + " does not exist.");
return groups_[group_name];
}
uint32_t CollectiveCommunicationLib::global_rank_id() const { return global_rank_id_; }
uint32_t CollectiveCommunicationLib::local_rank_id() const { return local_rank_id_; }

View File

@ -21,6 +21,7 @@
#include <memory>
#include <vector>
#include <string>
#include "ir/dtype/type_id.h"
#include "runtime/hardware/collective/communication_group.h"
namespace mindspore {
@ -61,6 +62,21 @@ class CollectiveCommunicationLib {
// Assign the local rank id for this process. Normally used by collective communication library on the host side.
virtual bool AssignLocalRank() { return true; }
// Return communication group pointer.
virtual CommunicationGroupPtr GetGroup(const std::string &group_name);
// Primitive of AllGather operation.
virtual bool AllGather(const void *send_buff, void *recv_buff, size_t send_count, TypeId data_type,
const std::string &group_name, void *stream) {
return true;
}
// Primitive of Broadcast operation.
virtual bool Broadcast(const void *send_buff, void *recv_buff, size_t send_count, TypeId data_type,
uint32_t root_rank, const std::string &group_name, void *stream) {
return true;
}
// Returns global rank id of this process.
uint32_t global_rank_id() const;

View File

@ -44,10 +44,9 @@ class CommunicationGroup {
// Finalize the communication group. For example, destroy the group, etc.
virtual bool Finalize() = 0;
// Return the root rank's information. Only root rank of one group could call this method.Normally this is used for
// collective libraries on the device side. For NCCL group, it returns 'ncclUniqueId'. For HCCL group, it returns
// 'HcclRootInfo'.
virtual void *GenerateRootInfo() { return nullptr; }
// Return the root rank's information and its size. Normally this is used for collective libraries on the device side.
// For NCCL group, it returns a pointer to 'ncclUniqueId'. For HCCL group, it returns a pointer to 'HcclRootInfo'.
virtual void *GenerateRootInfo(size_t *root_info_size) { return nullptr; }
// Get group or global rank for the given rank.
uint32_t GetGroupRank(uint32_t global_rank);

View File

@ -55,11 +55,12 @@ bool MPICollectiveCommLib::CreateCommunicationGroup(const std::string &group_nam
return true;
}
} // namespace cpu
} // namespace device
} // namespace mindspore
// The exported APIs for 'dlsym' to load.
using MPICollectiveCommLib = mindspore::device::cpu::MPICollectiveCommLib;
CollectiveCommunicationLib *communication_lib_instance() { return &MPICollectiveCommLib::GetInstance(); }
bool InitializeCollectiveLib(uint32_t, uint32_t) { return MPICollectiveCommLib::GetInstance().Initialize(); }
bool FinalizeCollectiveLib() { return MPICollectiveCommLib::GetInstance().Finalize(); }
@ -80,8 +81,24 @@ uint32_t GetCommunicationGroupSize(const std::string &group_name) {
bool AssignLocalRank() { return MPICollectiveCommLib::GetInstance().AssignLocalRank(); }
CommunicationGroupPtr GetGroup(const std::string &group_name) {
return MPICollectiveCommLib::GetInstance().GetGroup(group_name);
}
bool AllGather(const void *send_buff, void *recv_buff, size_t send_count, TypeId data_type,
const std::string &group_name, void *stream) {
return MPICollectiveCommLib::GetInstance().AllGather(send_buff, recv_buff, send_count, data_type, group_name, stream);
}
bool Broadcast(const void *send_buff, void *recv_buff, size_t send_count, mindspore::TypeId data_type,
uint32_t root_rank, const std::string &group_name, void *stream) {
return MPICollectiveCommLib::GetInstance().Broadcast(send_buff, recv_buff, send_count, data_type, root_rank,
group_name, stream);
}
uint32_t global_rank_id() { return MPICollectiveCommLib::GetInstance().global_rank_id(); }
uint32_t local_rank_id() { return MPICollectiveCommLib::GetInstance().local_rank_id(); }
uint32_t global_rank_size() { return MPICollectiveCommLib::GetInstance().global_rank_size(); }
} // namespace device
} // namespace mindspore

View File

@ -38,6 +38,16 @@ class MPICollectiveCommLib : public CollectiveCommunicationLib {
// Override creating method. Reuse destroying method in base class CollectiveCommunicationLib.
bool CreateCommunicationGroup(const std::string &group_name, const std::vector<uint32_t> &group_ranks) override;
bool AllGather(const void *send_buff, void *recv_buff, size_t send_count, TypeId data_type,
const std::string &group_name, void *stream) override {
return true;
}
bool Broadcast(const void *send_buff, void *recv_buff, size_t send_count, TypeId data_type, uint32_t root_rank,
const std::string &group_name, void *stream) override {
return true;
}
private:
MPICollectiveCommLib() = default;
~MPICollectiveCommLib() override = default;
@ -45,12 +55,11 @@ class MPICollectiveCommLib : public CollectiveCommunicationLib {
MPI_Group world_group_;
};
} // namespace cpu
} // namespace device
} // namespace mindspore
#ifndef EXPORT_MPI_WRAPPER
#define EXPORT_MPI_WRAPPER __attribute__((visibility("default")))
#endif
extern "C" EXPORT_MPI_WRAPPER CollectiveCommunicationLib *communication_lib_instance();
extern "C" EXPORT_MPI_WRAPPER bool InitializeCollectiveLib(uint32_t global_rank = UINT32_MAX,
uint32_t global_rank_size = UINT32_MAX);
extern "C" EXPORT_MPI_WRAPPER bool FinalizeCollectiveLib();
@ -60,7 +69,15 @@ extern "C" EXPORT_MPI_WRAPPER bool DestroyCommunicationGroup(const std::string &
extern "C" EXPORT_MPI_WRAPPER uint32_t GetRankId(const std::string &group_name);
extern "C" EXPORT_MPI_WRAPPER uint32_t GetGroupSize(const std::string &group_name);
extern "C" EXPORT_MPI_WRAPPER bool AssignLocalRank();
extern "C" EXPORT_MPI_WRAPPER CommunicationGroupPtr GetGroup(const std::string &group_name);
extern "C" EXPORT_MPI_WRAPPER bool AllGather(const void *send_buff, void *recv_buff, size_t send_count,
mindspore::TypeId data_type, const std::string &group_name, void *stream);
extern "C" EXPORT_MPI_WRAPPER bool Broadcast(const void *send_buff, void *recv_buff, size_t send_count,
mindspore::TypeId data_type, uint32_t root_rank,
const std::string &group_name, void *stream);
extern "C" EXPORT_MPI_WRAPPER uint32_t global_rank_id();
extern "C" EXPORT_MPI_WRAPPER uint32_t local_rank_id();
extern "C" EXPORT_MPI_WRAPPER uint32_t global_rank_size();
} // namespace device
} // namespace mindspore
#endif // MINDSPORE_CCSRC_RUNTIME_HARDWARE_CPU_MPI_COLLECTIVE_COMM_LIB_H_
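The AllGather and Broadcast overrides in MPICollectiveCommLib are stubs in this commit. For illustration only, a hedged sketch of what a byte-wise MPI-backed AllGather could look like; treating 'send_count' as a byte count and using MPI_COMM_WORLD instead of a per-group communicator are assumptions:

// Hedged sketch, not this commit's implementation (assumes <mpi.h>).
bool AllGatherSketch(const void *send_buff, void *recv_buff, size_t send_count, const std::string &group_name) {
  (void)group_name;  // a real version would resolve the group's communicator here
  int count = static_cast<int>(send_count);
  return MPI_Allgather(send_buff, count, MPI_BYTE, recv_buff, count, MPI_BYTE, MPI_COMM_WORLD) == MPI_SUCCESS;
}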

View File

@ -41,11 +41,11 @@ bool NvidiaCollectiveCommLib::CreateCommunicationGroup(const std::string &group_
return true;
}
} // namespace gpu
} // namespace device
} // namespace mindspore
// The exported APIs for 'dlsym' to load.
using NvidiaCollectiveCommLib = mindspore::device::gpu::NvidiaCollectiveCommLib;
CollectiveCommunicationLib *communication_lib_instance() { return &NvidiaCollectiveCommLib::GetInstance(); }
bool InitializeCollectiveLib(uint32_t global_rank, uint32_t global_rank_size) {
return NvidiaCollectiveCommLib::GetInstance().Initialize(global_rank, global_rank_size);
}
@ -70,4 +70,22 @@ uint32_t GetCommunicationGroupSize(const std::string &group_name) {
bool AssignLocalRank() { return NvidiaCollectiveCommLib::GetInstance().AssignLocalRank(); }
CommunicationGroupPtr GetGroup(const std::string &group_name) {
return NvidiaCollectiveCommLib::GetInstance().GetGroup(group_name);
}
bool AllGather(const void *send_buff, void *recv_buff, size_t send_count, mindspore::TypeId data_type,
const std::string &group_name, void *stream) {
return NvidiaCollectiveCommLib::GetInstance().AllGather(send_buff, recv_buff, send_count, data_type, group_name,
stream);
}
bool Broadcast(const void *send_buff, void *recv_buff, size_t send_count, mindspore::TypeId data_type,
uint32_t root_rank, const std::string &group_name, void *stream) {
return NvidiaCollectiveCommLib::GetInstance().Broadcast(send_buff, recv_buff, send_count, data_type, root_rank,
group_name, stream);
}
uint32_t local_rank_id() { return NvidiaCollectiveCommLib::GetInstance().local_rank_id(); }
} // namespace device
} // namespace mindspore

View File

@ -39,17 +39,26 @@ class NvidiaCollectiveCommLib : public CollectiveCommunicationLib {
bool CreateCommunicationGroup(const std::string &group_name, const std::vector<uint32_t> &group_ranks) override;
bool AllGather(const void *send_buff, void *recv_buff, size_t send_count, TypeId data_type,
const std::string &group_name, void *stream) override {
return true;
}
bool Broadcast(const void *send_buff, void *recv_buff, size_t send_count, TypeId data_type, uint32_t root_rank,
const std::string &group_name, void *stream) override {
return true;
}
private:
NvidiaCollectiveCommLib() = default;
~NvidiaCollectiveCommLib() override = default;
};
} // namespace gpu
} // namespace device
} // namespace mindspore
#ifndef EXPORT_NCCL_WRAPPER
#define EXPORT_NCCL_WRAPPER __attribute__((visibility("default")))
#endif
extern "C" EXPORT_NCCL_WRAPPER CollectiveCommunicationLib *communication_lib_instance();
extern "C" EXPORT_NCCL_WRAPPER bool InitializeCollectiveLib(uint32_t global_rank = UINT32_MAX,
uint32_t global_rank_size = UINT32_MAX);
extern "C" EXPORT_NCCL_WRAPPER bool FinalizeCollectiveLib();
@ -59,5 +68,13 @@ extern "C" EXPORT_NCCL_WRAPPER bool DestroyCommunicationGroup(const std::string
extern "C" EXPORT_NCCL_WRAPPER uint32_t GetRankId(const std::string &group_name);
extern "C" EXPORT_NCCL_WRAPPER uint32_t GetCommunicationGroupSize(const std::string &group_name);
extern "C" EXPORT_NCCL_WRAPPER bool AssignLocalRank();
extern "C" EXPORT_NCCL_WRAPPER CommunicationGroupPtr GetGroup(const std::string &group_name);
extern "C" EXPORT_NCCL_WRAPPER bool AllGather(const void *send_buff, void *recv_buff, size_t send_count,
mindspore::TypeId data_type, const std::string &group_name, void *stream);
extern "C" EXPORT_NCCL_WRAPPER bool Broadcast(const void *send_buff, void *recv_buff, size_t send_count,
mindspore::TypeId data_type, uint32_t root_rank,
const std::string &group_name, void *stream);
extern "C" EXPORT_NCCL_WRAPPER uint32_t local_rank_id();
} // namespace device
} // namespace mindspore
#endif // MINDSPORE_CCSRC_RUNTIME_HARDWARE_CPU_NVIDIA_COLLECTIVE_COMM_LIB_H_
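Likewise, the NCCL overrides above are placeholders. A hedged sketch of an NCCL-backed Broadcast, assuming the communicator has already been looked up from the group, 'send_count' counts elements of the mapped data type, and the opaque stream is a cudaStream_t:

// Hedged sketch, not this commit's implementation (assumes <nccl.h>); the TypeId-to-ncclDataType_t
// mapping is omitted and ncclInt8 is used as a stand-in element type.
bool BroadcastSketch(const void *send_buff, void *recv_buff, size_t send_count, uint32_t root_rank,
                     ncclComm_t comm, cudaStream_t stream) {
  return ncclBroadcast(send_buff, recv_buff, send_count, ncclInt8, static_cast<int>(root_rank), comm, stream) ==
         ncclSuccess;
}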

View File

@ -21,7 +21,7 @@ namespace device {
namespace gpu {
NvidiaCommunicationGroup::NvidiaCommunicationGroup(const std::string name, const std::vector<uint32_t> &group_ranks,
uint32_t global_rank)
: CommunicationGroup(name, group_ranks, global_rank) {}
: CommunicationGroup(name, group_ranks, global_rank), unique_id_({}), comm_(nullptr) {}
bool NvidiaCommunicationGroup::Initialize(void *root_info) {
if (initialized_) {
@ -50,8 +50,12 @@ bool NvidiaCommunicationGroup::Finalize() {
return true;
}
void *NvidiaCommunicationGroup::GenerateRootInfo() {
CHECK_RET(ncclGetUniqueId(&unique_id_), ncclSuccess, "Failed to get NCCL unique id.");
void *NvidiaCommunicationGroup::GenerateRootInfo(size_t *root_info_size) {
*root_info_size = sizeof(unique_id_);
uint32_t group_rank = GetGroupRank(global_rank_);
if (group_rank == 0) {
CHECK_RET(ncclGetUniqueId(&unique_id_), ncclSuccess, "Failed to get NCCL unique id.");
}
return &unique_id_;
}
} // namespace gpu

View File

@ -37,7 +37,7 @@ class NvidiaCommunicationGroup : public CommunicationGroup {
bool Initialize(void *root_info) override;
bool Finalize() override;
void *GenerateRootInfo() override;
void *GenerateRootInfo(size_t *root_info_size) override;
private:
// The NCCL unique id for this group. Used to initialize this group's communicator.