forked from mindspore-Ecosystem/mindspore
!26720 Add implementation of collective manager API
Merge pull request !26720 from ZPaC/dir-of-distributed
commit 6c20f55a51
@@ -28,7 +28,9 @@ CollectiveManager::CollectiveManager()
      host_ctx_(nullptr),
      device_ctx_(nullptr),
      host_comm_lib_(nullptr),
      host_comm_lib_instance_(nullptr),
      device_comm_lib_(nullptr),
      device_comm_lib_instance_(nullptr),
      global_rank_id_(0),
      local_rank_id_(0),
      global_rank_size_(0),
@@ -61,13 +63,25 @@ bool CollectiveManager::Initialize(const std::string &backend, const std::string
    return false;
  }

  MS_EXCEPTION_IF_NULL(host_comm_lib_);
  // Step 2: Create global communication group on host side.
  if (!CreateHostGlobalCommGroup(global_group_name)) {
  // Steps 2, 3 and 4 apply to the device communication library, so they are not necessary when the training job is
  // launched on CPU only.
  // Step 2: Assign local rank id (device id) for this process.
  if (!AssignLocalRank(global_group_name)) {
    MS_LOG(ERROR) << "Failed to assign local rank id.";
    return false;
  }

  // Step 3: Initialize device side collective communication.
  if (!InitDeviceCommLib(backend)) {
    MS_LOG(ERROR) << "Failed to initialize device communication library.";
    return false;
  }

  // Step 4: Create global communication group.
  if (!CreateCommunicationGroup(global_group_name, global_group_ranks_)) {
    MS_LOG(ERROR) << "Failed to create global communication group " << global_group_name << ".";
    return false;
  }
  // Step 3: Assign local rank id (device id) for this process.

  MS_LOG(INFO) << "End initializing collective communication for backend: " << backend << ".";
  return true;
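
The four numbered steps above are the manager's whole setup lifecycle. A minimal caller-side sketch, assuming a singleton accessor ('CollectiveManager::instance()' is a hypothetical name; this diff does not show how the manager is obtained, and "world_group" is an illustrative group name):

  // Hypothetical driver: initialize, create a sub-group, query ranks, tear down.
  auto manager = CollectiveManager::instance();             // assumed accessor
  if (!manager->Initialize("nccl", "world_group")) {        // backend name + global group name
    return;
  }
  (void)manager->CreateCommunicationGroup("sub_group", {0, 1, 2, 3});
  uint32_t rank = manager->GetRankId("world_group");        // this process's rank in the group
  uint32_t size = manager->GetGroupSize("world_group");     // number of processes in the group
  (void)rank; (void)size;
  (void)manager->Finalize();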

@@ -75,25 +89,81 @@ bool CollectiveManager::Initialize(const std::string &backend, const std::string

bool CollectiveManager::CreateCommunicationGroup(const std::string &group_name,
                                                 const std::vector<uint32_t> &group_ranks) {
  MS_EXCEPTION_IF_NULL(host_comm_lib_);
  MS_EXCEPTION_IF_NULL(device_comm_lib_);
  MS_EXCEPTION_IF_NULL(host_comm_lib_instance_);
  MS_EXCEPTION_IF_NULL(device_comm_lib_instance_);
  // Step 1: Create communication group on host side.
  // Step 2: Generate device information of the root node.
  // Step 3: Broadcast the device root information to all nodes.
  // Step 4: Create communication group on device side.
  if (!host_comm_lib_instance_->CreateCommunicationGroup(group_name, group_ranks)) {
    MS_LOG(ERROR) << "Failed to create communication group " << group_name << " on host side.";
    return false;
  }

  // Step 2: Create communication group on device side.
  if (!device_comm_lib_instance_->CreateCommunicationGroup(group_name, group_ranks)) {
    MS_LOG(ERROR) << "Failed to create communication group " << group_name << " on device side.";
    return false;
  }

  // Step 3: Generate device information of the root node.
  CommunicationGroupPtr group = device_comm_lib_instance_->GetGroup(group_name);
  MS_EXCEPTION_IF_NULL(group);
  size_t root_info_size = 0;
  void *root_info = group->GenerateRootInfo(&root_info_size);
  MS_EXCEPTION_IF_NULL(root_info);

  // Step 4: Broadcast the device root information to all nodes on host side.
  if (!host_comm_lib_instance_->Broadcast(root_info, root_info, root_info_size, TypeId::kNumberTypeInt, 0, group_name,
                                          nullptr)) {
    MS_LOG(ERROR) << "Broadcast for device root info failed on the host side.";
    return false;
  }

  // Step 5: Initialize communication group on the device side.
  if (!group->Initialize(root_info)) {
    MS_LOG(ERROR) << "Initialize group on the device side failed.";
    return false;
  }
  return true;
}
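
This is the standard split-transport rendezvous: the opaque device token travels over the host-side channel before the device communicator exists. Written directly against NCCL with MPI as the host channel, the same flow looks like this (an illustrative sketch, not code from this commit; assumes MPI_Init has already been called and omits error checks for brevity):

  #include <mpi.h>
  #include <nccl.h>

  ncclComm_t BootstrapNcclComm(int global_rank, int rank_size) {
    ncclUniqueId id;                                          // the "device root info" for NCCL
    if (global_rank == 0) {
      ncclGetUniqueId(&id);                                   // generated on the root rank only
    }
    MPI_Bcast(&id, sizeof(id), MPI_BYTE, 0, MPI_COMM_WORLD);  // broadcast on the host side
    ncclComm_t comm = nullptr;
    ncclCommInitRank(&comm, rank_size, id, global_rank);      // initialize on the device side
    return comm;
  }

Passing 'root_info' as both send and receive buffer in the Broadcast call works because the root rank already holds the token, while every other rank overwrites its own zero-initialized copy in place.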

bool CollectiveManager::DestroyCommunicationGroup(const std::string &group_name) { return true; }
bool CollectiveManager::DestroyCommunicationGroup(const std::string &group_name) {
  MS_EXCEPTION_IF_NULL(host_comm_lib_instance_);
  if (!host_comm_lib_instance_->DestroyCommunicationGroup(group_name)) {
    MS_LOG(ERROR) << "Failed to destroy communication group of " << group_name << " on the host side.";
    return false;
  }

  MS_EXCEPTION_IF_NULL(device_comm_lib_instance_);
  if (!device_comm_lib_instance_->DestroyCommunicationGroup(group_name)) {
    MS_LOG(ERROR) << "Failed to destroy communication group of " << group_name << " on the device side.";
    return false;
  }
  return true;
}

uint32_t CollectiveManager::GetRankId(const std::string &group_name) { return 0; }
uint32_t CollectiveManager::GetRankId(const std::string &group_name) {
  MS_EXCEPTION_IF_NULL(host_comm_lib_instance_);
  return host_comm_lib_instance_->GetRankId(group_name);
}

uint32_t CollectiveManager::GetGroupSize(const std::string &group_name) { return 0; }
uint32_t CollectiveManager::GetGroupSize(const std::string &group_name) {
  MS_EXCEPTION_IF_NULL(host_comm_lib_instance_);
  return host_comm_lib_instance_->GetGroupSize(group_name);
}

bool CollectiveManager::Finalize() {
  if (finalized_) {
    return true;
  }

  MS_EXCEPTION_IF_NULL(host_comm_lib_instance_);
  if (!host_comm_lib_instance_->Finalize()) {
    MS_LOG(WARNING) << "Failed to finalize host communication library.";
  }

  MS_EXCEPTION_IF_NULL(device_comm_lib_instance_);
  if (!device_comm_lib_instance_->Finalize()) {
    MS_LOG(WARNING) << "Failed to finalize device communication library.";
  }
  return true;
}
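
Note the design choice: Finalize demotes backend failures to warnings so teardown can proceed past a misbehaving library, and the finalized_ check makes repeated calls harmless. A caller could therefore bind teardown to scope exit; a hedged sketch ('CollectiveManager::instance()' is an assumed accessor, not shown in this commit):

  // RAII guard: best-effort finalization of collective communication at scope exit.
  struct CollectiveFinalizeGuard {
    ~CollectiveFinalizeGuard() { (void)CollectiveManager::instance()->Finalize(); }
  };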

@@ -105,19 +175,34 @@ bool CollectiveManager::InitHostCommlib() {
    MS_LOG(ERROR) << "Failed to load communication library on the host side.";
    return false;
  }
  return true;
}

bool CollectiveManager::CreateHostGlobalCommGroup(const std::string &global_group_name) {
  host_comm_lib_ = host_ctx_->collective_comm_lib();
  MS_EXCEPTION_IF_NULL(host_comm_lib_);
  if (global_group_ranks_.empty()) {
    MS_LOG(ERROR) << "The global group rank list is empty.";
  auto instance_func = DlsymFuncObj(communication_lib_instance, host_comm_lib_);
  host_comm_lib_instance_ = instance_func();
  MS_EXCEPTION_IF_NULL(host_comm_lib_instance_);

  // For some communication libraries, 'global_rank_id_' and 'global_rank_size_' should be set by the caller, e.g.,
  // when using the MindSpore communication framework. For other communication libraries, e.g., OpenMPI, the global
  // rank id and size are generated by the library itself, and the parameters 'global_rank_id_' and
  // 'global_rank_size_' will not be used.
  MS_LOG(INFO) << "Start initializing communication library on host side...";
  if (!host_comm_lib_instance_->Initialize(global_rank_id_, global_rank_size_)) {
    MS_LOG(ERROR) << "Failed to initialize communication library on host side.";
    return false;
  }

  // Reassign 'global_rank_id_' and 'global_rank_size_'. Generate global communication group ranks.
  global_rank_id_ = host_comm_lib_instance_->global_rank_id();
  global_rank_size_ = host_comm_lib_instance_->global_rank_size();
  for (uint32_t i = 0; i < global_rank_size_; i++) {
    global_group_ranks_.push_back(i);
  }

  MS_LOG(INFO) << "Communication library on host side is successfully initialized. Global rank id: " << global_rank_id_
               << ", global rank size: " << global_rank_size_;
  return true;
}
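
"Generated by the library itself" means the host library can answer rank queries without any caller-provided state. For OpenMPI that reduces to the following (an illustrative sketch, not code from this commit; assumes MPI_Init has already been called):

  #include <mpi.h>

  // Query this process's global rank and the world size directly from MPI.
  void QueryMpiRanks(uint32_t *global_rank_id, uint32_t *global_rank_size) {
    int rank = 0, size = 0;
    MPI_Comm_rank(MPI_COMM_WORLD, &rank);  // rank of the calling process
    MPI_Comm_size(MPI_COMM_WORLD, &size);  // total number of processes
    *global_rank_id = static_cast<uint32_t>(rank);
    *global_rank_size = static_cast<uint32_t>(size);
  }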

bool CollectiveManager::InitDeviceCommLib(const std::string &backend, uint32_t device_id) {
bool CollectiveManager::InitDeviceCommLib(const std::string &backend) {
  std::string device_name;
  if (backend == "nccl") {
    device_name = "GPU";

@@ -128,13 +213,68 @@ bool CollectiveManager::InitDeviceCommLib(const std::string &backend, uint32_t d
    return false;
  }

  device::DeviceContextKey device_key = {device_name, device_id};
  device::DeviceContextKey device_key = {device_name, local_rank_id_};
  device_ctx_ = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext(device_key);
  MS_EXCEPTION_IF_NULL(device_ctx_);
  if (!device_ctx_->LoadCollectiveCommLib()) {
    MS_LOG(ERROR) << "Failed to load communication library on the device side.";
    return false;
  }
  device_comm_lib_ = device_ctx_->collective_comm_lib();
  MS_EXCEPTION_IF_NULL(device_comm_lib_);
  auto instance_func = DlsymFuncObj(communication_lib_instance, device_comm_lib_);
  device_comm_lib_instance_ = instance_func();
  MS_EXCEPTION_IF_NULL(device_comm_lib_instance_);

  MS_LOG(INFO) << "Start initializing communication library on device side...";
  if (!device_comm_lib_instance_->Initialize(global_rank_id_, global_rank_size_)) {
    MS_LOG(ERROR) << "Failed to initialize communication library on device side.";
    return false;
  }
  MS_LOG(INFO) << "Communication library on device side is successfully initialized.";
  return true;
}

bool CollectiveManager::AssignLocalRank(const std::string &global_group_name) {
  char host_name[MAX_HOSTNAME_LEN] = {0};
#ifndef _WIN32
  if (gethostname(host_name, MAX_HOSTNAME_LEN) != 0) {
    MS_LOG(ERROR) << "Failed to get host name.";
    return false;
  }
#endif
  MS_LOG(INFO) << "Host name for rank " << global_rank_id_ << " is " << host_name;

  // Generate a host name hash for every process. The host names of different physical machines must differ so that
  // local rank ids do not repeat.
  size_t host_hash = std::hash<std::string>()(host_name);
  const uint32_t kGlobalRankSize = global_rank_size_;
  size_t all_host_hashs[kGlobalRankSize];
  if (global_rank_id_ >= global_rank_size_) {
    MS_LOG(ERROR) << "The global rank id " << global_rank_id_ << " should be less than global rank size "
                  << global_rank_size_;
    return false;
  }
  all_host_hashs[global_rank_id_] = host_hash;

  MS_EXCEPTION_IF_NULL(host_comm_lib_instance_);
  // AllGather the host name hashes across the global communication group.
  if (!host_comm_lib_instance_->AllGather(&host_hash, all_host_hashs, sizeof(size_t), TypeId::kNumberTypeInt,
                                          global_group_name, nullptr)) {
    MS_LOG(ERROR) << "AllGather for host names failed.";
    return false;
  }

  // Accumulate the local rank id.
  for (uint32_t rank = 0; rank < global_rank_size_; rank++) {
    if (rank == global_rank_id_) {
      break;
    }
    if (all_host_hashs[rank] == all_host_hashs[global_rank_id_]) {
      local_rank_id_++;
    }
  }
  MS_LOG(INFO) << "The local rank id assigned for this process is " << local_rank_id_;
  return true;
}
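
The counting loop is the whole algorithm: after the AllGather, a process's local rank is simply how many lower global ranks share its host hash. A self-contained restatement (illustrative only, assuming the gathered hashes are already available):

  #include <cstdint>
  #include <vector>

  // Local rank = number of lower global ranks whose host hash matches ours.
  uint32_t ComputeLocalRank(const std::vector<size_t> &all_host_hashs, uint32_t global_rank_id) {
    uint32_t local_rank = 0;
    for (uint32_t rank = 0; rank < global_rank_id; ++rank) {
      if (all_host_hashs[rank] == all_host_hashs[global_rank_id]) {
        ++local_rank;
      }
    }
    return local_rank;
  }

For example, with ranks 0..3 on host A and ranks 4..7 on host B, rank 5 sees one earlier match (rank 4) and gets local rank 1, i.e., device 1 on host B.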
} // namespace collective

@@ -22,6 +22,7 @@
#include <vector>
#include <atomic>
#include "utils/ms_utils.h"
#include "distributed/constants.h"
#include "runtime/hardware/device_context_manager.h"

namespace mindspore {

@@ -30,6 +31,8 @@ namespace collective {
using DeviceContext = device::DeviceContext;
using DeviceContextKey = device::DeviceContextKey;
using DeviceContextManager = device::DeviceContextManager;
using CollectiveCommunicationLib = device::CollectiveCommunicationLib;
using CommunicationGroupPtr = device::CommunicationGroupPtr;

// The collective communication API.
// MindSpore uses OpenMPI on CPU, NCCL on GPU, and HCCL on Ascend to achieve distributed training.

@@ -43,6 +46,9 @@ class CollectiveManager {
  // Initialize the collective communication for distributed training with the backend name, e.g., NCCL or HCCL.
  bool Initialize(const std::string &backend, const std::string &global_group_name);

  // Finalize the collective communication.
  bool Finalize();

  // Create communication group.
  bool CreateCommunicationGroup(const std::string &group_name, const std::vector<uint32_t> &group_ranks);

@@ -55,8 +61,10 @@ class CollectiveManager {
  // Get the size of the specified group.
  uint32_t GetGroupSize(const std::string &group_name);

  // Finalize the collective communication.
  bool Finalize();
  // In some cases the global rank id and rank size should be set by the caller, e.g., when using the MindSpore
  // communication framework, where they're generated by cluster::ClusterContext.
  void set_global_rank_id(uint32_t global_rank_id);
  void set_global_rank_size(uint32_t global_rank_size);

 private:
  CollectiveManager();

@@ -64,14 +72,11 @@ class CollectiveManager {
  // Initialize communication library on host side.
  bool InitHostCommlib();

  // Create world communication group on the host side.
  bool CreateHostGlobalCommGroup(const std::string &global_group_name);

  // Initialize communication library on device side.
  bool InitDeviceCommLib(const std::string &backend, uint32_t device_id);
  bool InitDeviceCommLib(const std::string &backend);

  // Create world communication group on the device side.
  bool CreateDeviceGlobalCommGroup(const std::string &global_group_name);
  // Assign the local rank id for this process.
  bool AssignLocalRank(const std::string &global_group_name);

  std::atomic_bool inited_;
  std::atomic_bool finalized_;

@@ -81,9 +86,15 @@ class CollectiveManager {
  DeviceContext *host_ctx_;
  DeviceContext *device_ctx_;

  // The handle for the collective communication library, dynamically loaded by 'dlopen'.
  // The host communication library refers to the communication library for CPU, e.g., OpenMPI or the MindSpore
  // communication framework.
  void *host_comm_lib_;
  CollectiveCommunicationLib *host_comm_lib_instance_;

  // The device communication library refers to the communication library for NPU or GPU, e.g., NCCL or HCCL.
  // When only the CPU backend is used, the device communication library should not be initialized.
  void *device_comm_lib_;
  CollectiveCommunicationLib *device_comm_lib_instance_;

  // The global rank id of this process. Normally this ranges from 0 to `total process number - 1`.
  uint32_t global_rank_id_;

@@ -34,6 +34,7 @@ constexpr char kEnvRoleOfScheduler[] = "MS_SCHED";
const std::set<std::string> kValidRoleName = {kEnvRoleOfServer, kEnvRoleOfWorker, kEnvRoleOfScheduler};

constexpr char kLocalHost[] = "127.0.0.1";
constexpr int MAX_HOSTNAME_LEN = 1024;
const uint16_t kDefaultSchedPort = 6667;
const uint16_t kMaxPort = 65535;
} // namespace distributed

@@ -21,6 +21,7 @@
#include <memory>
#include <vector>
#include "utils/dlopen_macro.h"
#include "runtime/hardware/collective/collective_communication_lib.h"

namespace mindspore {
namespace device {

@@ -51,17 +52,6 @@ using CollectiveCommLibLoaderPtr = std::shared_ptr<CollectiveCommLibLoader>;
} // namespace device
} // namespace mindspore

#ifndef _WIN32
// The exported symbols of the collective communication shared library are registered here.
ORIGIN_METHOD(InitializeCollectiveLib, bool, uint32_t, uint32_t)
ORIGIN_METHOD(FinalizeCollectiveLib, bool)
ORIGIN_METHOD(CreateCommunicationGroup, bool, const std::string &, const std::vector<uint32_t> &)
ORIGIN_METHOD(DestroyCommunicationGroup, bool, const std::string &)
ORIGIN_METHOD(GetRankId, uint32_t, const std::string &)
ORIGIN_METHOD(GetCommunicationGroupSize, uint32_t, const std::string &)
ORIGIN_METHOD(AssignLocalRank, bool)
ORIGIN_METHOD(global_rank_id, uint32_t)
ORIGIN_METHOD(local_rank_id, uint32_t)
ORIGIN_METHOD(global_rank_size, uint32_t)
#endif
ORIGIN_METHOD(communication_lib_instance, mindspore::device::CollectiveCommunicationLib *)
#endif  // MINDSPORE_CCSRC_RUNTIME_HARDWARE_COLLECTIVE_COLLECTIVE_LIB_LOADER_H_
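
ORIGIN_METHOD registers a typed alias for an exported C symbol so that DlsymFuncObj can bind it from the dlopen'ed backend library at runtime. Stripped of the macros, the mechanism is plain dlopen/dlsym; a hedged sketch of the idea, not the macro's exact expansion ('libnvidia_collective.so' is an assumed library name for illustration):

  #include <dlfcn.h>

  // Resolve the factory symbol from a collective backend shared library.
  using CommLibInstanceFunc = mindspore::device::CollectiveCommunicationLib *(*)();

  mindspore::device::CollectiveCommunicationLib *LoadCommLib() {
    void *handle = dlopen("libnvidia_collective.so", RTLD_NOW);
    if (handle == nullptr) {
      return nullptr;  // library not found or unresolved symbols
    }
    auto func = reinterpret_cast<CommLibInstanceFunc>(dlsym(handle, "communication_lib_instance"));
    return (func != nullptr) ? func() : nullptr;
  }

Because the wrappers are declared extern "C", the symbol names are unmangled and can be looked up verbatim.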

@@ -55,6 +55,11 @@ uint32_t CollectiveCommunicationLib::GetGroupSize(const std::string &group_name)
  return group->group_size();
}

CommunicationGroupPtr CollectiveCommunicationLib::GetGroup(const std::string &group_name) {
  CHECK_RET(groups_.count(group_name) != 0, true, "The group " + group_name + " does not exist.");
  return groups_[group_name];
}

uint32_t CollectiveCommunicationLib::global_rank_id() const { return global_rank_id_; }

uint32_t CollectiveCommunicationLib::local_rank_id() const { return local_rank_id_; }

@@ -21,6 +21,7 @@
#include <memory>
#include <vector>
#include <string>
#include "ir/dtype/type_id.h"
#include "runtime/hardware/collective/communication_group.h"

namespace mindspore {

@@ -61,6 +62,21 @@ class CollectiveCommunicationLib {
  // Assign the local rank id for this process. Normally used by collective communication library on the host side.
  virtual bool AssignLocalRank() { return true; }

  // Return communication group pointer.
  virtual CommunicationGroupPtr GetGroup(const std::string &group_name);

  // Primitive of AllGather operation.
  virtual bool AllGather(const void *send_buff, void *recv_buff, size_t send_count, TypeId data_type,
                         const std::string &group_name, void *stream) {
    return true;
  }

  // Primitive of Broadcast operation.
  virtual bool Broadcast(const void *send_buff, void *recv_buff, size_t send_count, TypeId data_type,
                         uint32_t root_rank, const std::string &group_name, void *stream) {
    return true;
  }

  // Returns global rank id of this process.
  uint32_t global_rank_id() const;

@@ -44,10 +44,9 @@ class CommunicationGroup {
  // Finalize the communication group. For example, destroy the group, etc.
  virtual bool Finalize() = 0;

  // Return the root rank's information. Only the root rank of a group could call this method. Normally this is used
  // for collective libraries on the device side. For a NCCL group, it returns 'ncclUniqueId'. For a HCCL group, it
  // returns 'HcclRootInfo'.
  virtual void *GenerateRootInfo() { return nullptr; }
  // Return the root rank's information and its size. Normally this is used for collective libraries on the device
  // side. For a NCCL group, it returns a pointer to 'ncclUniqueId'. For a HCCL group, it returns a pointer to
  // 'HcclRootInfo'.
  virtual void *GenerateRootInfo(size_t *root_info_size) { return nullptr; }

  // Get group or global rank for the given rank.
  uint32_t GetGroupRank(uint32_t global_rank);

@@ -55,11 +55,12 @@ bool MPICollectiveCommLib::CreateCommunicationGroup(const std::string &group_nam
  return true;
}
} // namespace cpu
} // namespace device
} // namespace mindspore

// The exported APIs for 'dlsym' to load.
using MPICollectiveCommLib = mindspore::device::cpu::MPICollectiveCommLib;

CollectiveCommunicationLib *communication_lib_instance() { return &MPICollectiveCommLib::GetInstance(); }

bool InitializeCollectiveLib(uint32_t, uint32_t) { return MPICollectiveCommLib::GetInstance().Initialize(); }

bool FinalizeCollectiveLib() { return MPICollectiveCommLib::GetInstance().Finalize(); }

@@ -80,8 +81,24 @@ uint32_t GetCommunicationGroupSize(const std::string &group_name) {

bool AssignLocalRank() { return MPICollectiveCommLib::GetInstance().AssignLocalRank(); }

CommunicationGroupPtr GetGroup(const std::string &group_name) {
  return MPICollectiveCommLib::GetInstance().GetGroup(group_name);
}

bool AllGather(const void *send_buff, void *recv_buff, size_t send_count, TypeId data_type,
               const std::string &group_name, void *stream) {
  return MPICollectiveCommLib::GetInstance().AllGather(send_buff, recv_buff, send_count, data_type, group_name, stream);
}

bool Broadcast(const void *send_buff, void *recv_buff, size_t send_count, mindspore::TypeId data_type,
               uint32_t root_rank, const std::string &group_name, void *stream) {
  return MPICollectiveCommLib::GetInstance().Broadcast(send_buff, recv_buff, send_count, data_type, root_rank,
                                                       group_name, stream);
}

uint32_t global_rank_id() { return MPICollectiveCommLib::GetInstance().global_rank_id(); }

uint32_t local_rank_id() { return MPICollectiveCommLib::GetInstance().local_rank_id(); }

uint32_t global_rank_size() { return MPICollectiveCommLib::GetInstance().global_rank_size(); }
} // namespace device
} // namespace mindspore

@@ -38,6 +38,16 @@ class MPICollectiveCommLib : public CollectiveCommunicationLib {
  // Override creating method. Reuse destroying method in base class CollectiveCommunicationLib.
  bool CreateCommunicationGroup(const std::string &group_name, const std::vector<uint32_t> &group_ranks) override;

  bool AllGather(const void *send_buff, void *recv_buff, size_t send_count, TypeId data_type,
                 const std::string &group_name, void *stream) override {
    return true;
  }

  bool Broadcast(const void *send_buff, void *recv_buff, size_t send_count, TypeId data_type, uint32_t root_rank,
                 const std::string &group_name, void *stream) override {
    return true;
  }

 private:
  MPICollectiveCommLib() = default;
  ~MPICollectiveCommLib() override = default;

@@ -45,12 +55,11 @@ class MPICollectiveCommLib : public CollectiveCommunicationLib {
  MPI_Group world_group_;
};
} // namespace cpu
} // namespace device
} // namespace mindspore

#ifndef EXPORT_MPI_WRAPPER
#define EXPORT_MPI_WRAPPER __attribute__((visibility("default")))
#endif
extern "C" EXPORT_MPI_WRAPPER CollectiveCommunicationLib *communication_lib_instance();
extern "C" EXPORT_MPI_WRAPPER bool InitializeCollectiveLib(uint32_t global_rank = UINT32_MAX,
                                                           uint32_t global_rank_size = UINT32_MAX);
extern "C" EXPORT_MPI_WRAPPER bool FinalizeCollectiveLib();

@@ -60,7 +69,15 @@ extern "C" EXPORT_MPI_WRAPPER bool DestroyCommunicationGroup(const std::string &
extern "C" EXPORT_MPI_WRAPPER uint32_t GetRankId(const std::string &group_name);
extern "C" EXPORT_MPI_WRAPPER uint32_t GetGroupSize(const std::string &group_name);
extern "C" EXPORT_MPI_WRAPPER bool AssignLocalRank();
extern "C" EXPORT_MPI_WRAPPER CommunicationGroupPtr GetGroup(const std::string &group_name);
extern "C" EXPORT_MPI_WRAPPER bool AllGather(const void *send_buff, void *recv_buff, size_t send_count,
                                             mindspore::TypeId data_type, const std::string &group_name, void *stream);
extern "C" EXPORT_MPI_WRAPPER bool Broadcast(const void *send_buff, void *recv_buff, size_t send_count,
                                             mindspore::TypeId data_type, uint32_t root_rank,
                                             const std::string &group_name, void *stream);
extern "C" EXPORT_MPI_WRAPPER uint32_t global_rank_id();
extern "C" EXPORT_MPI_WRAPPER uint32_t local_rank_id();
extern "C" EXPORT_MPI_WRAPPER uint32_t global_rank_size();
} // namespace device
} // namespace mindspore
#endif  // MINDSPORE_CCSRC_RUNTIME_HARDWARE_CPU_MPI_COLLECTIVE_COMM_LIB_H_

@@ -41,11 +41,11 @@ bool NvidiaCollectiveCommLib::CreateCommunicationGroup(const std::string &group_
  return true;
}
} // namespace gpu
} // namespace device
} // namespace mindspore

// The exported APIs for 'dlsym' to load.
using NvidiaCollectiveCommLib = mindspore::device::gpu::NvidiaCollectiveCommLib;
CollectiveCommunicationLib *communication_lib_instance() { return &NvidiaCollectiveCommLib::GetInstance(); }

bool InitializeCollectiveLib(uint32_t global_rank, uint32_t global_rank_size) {
  return NvidiaCollectiveCommLib::GetInstance().Initialize(global_rank, global_rank_size);
}

@@ -70,4 +70,22 @@ uint32_t GetCommunicationGroupSize(const std::string &group_name) {

bool AssignLocalRank() { return NvidiaCollectiveCommLib::GetInstance().AssignLocalRank(); }

CommunicationGroupPtr GetGroup(const std::string &group_name) {
  return NvidiaCollectiveCommLib::GetInstance().GetGroup(group_name);
}

bool AllGather(const void *send_buff, void *recv_buff, size_t send_count, mindspore::TypeId data_type,
               const std::string &group_name, void *stream) {
  return NvidiaCollectiveCommLib::GetInstance().AllGather(send_buff, recv_buff, send_count, data_type, group_name,
                                                          stream);
}

bool Broadcast(const void *send_buff, void *recv_buff, size_t send_count, mindspore::TypeId data_type,
               uint32_t root_rank, const std::string &group_name, void *stream) {
  return NvidiaCollectiveCommLib::GetInstance().Broadcast(send_buff, recv_buff, send_count, data_type, root_rank,
                                                          group_name, stream);
}

uint32_t local_rank_id() { return NvidiaCollectiveCommLib::GetInstance().local_rank_id(); }
} // namespace device
} // namespace mindspore

@@ -39,17 +39,26 @@ class NvidiaCollectiveCommLib : public CollectiveCommunicationLib {

  bool CreateCommunicationGroup(const std::string &group_name, const std::vector<uint32_t> &group_ranks) override;

  bool AllGather(const void *send_buff, void *recv_buff, size_t send_count, TypeId data_type,
                 const std::string &group_name, void *stream) override {
    return true;
  }

  bool Broadcast(const void *send_buff, void *recv_buff, size_t send_count, TypeId data_type, uint32_t root_rank,
                 const std::string &group_name, void *stream) override {
    return true;
  }

 private:
  NvidiaCollectiveCommLib() = default;
  ~NvidiaCollectiveCommLib() override = default;
};
} // namespace gpu
} // namespace device
} // namespace mindspore

#ifndef EXPORT_NCCL_WRAPPER
#define EXPORT_NCCL_WRAPPER __attribute__((visibility("default")))
#endif
extern "C" EXPORT_NCCL_WRAPPER CollectiveCommunicationLib *communication_lib_instance();
extern "C" EXPORT_NCCL_WRAPPER bool InitializeCollectiveLib(uint32_t global_rank = UINT32_MAX,
                                                            uint32_t global_rank_size = UINT32_MAX);
extern "C" EXPORT_NCCL_WRAPPER bool FinalizeCollectiveLib();

@@ -59,5 +68,13 @@ extern "C" EXPORT_NCCL_WRAPPER bool DestroyCommunicationGroup(const std::string
extern "C" EXPORT_NCCL_WRAPPER uint32_t GetRankId(const std::string &group_name);
extern "C" EXPORT_NCCL_WRAPPER uint32_t GetCommunicationGroupSize(const std::string &group_name);
extern "C" EXPORT_NCCL_WRAPPER bool AssignLocalRank();
extern "C" EXPORT_NCCL_WRAPPER CommunicationGroupPtr GetGroup(const std::string &group_name);
extern "C" EXPORT_NCCL_WRAPPER bool AllGather(const void *send_buff, void *recv_buff, size_t send_count,
                                              mindspore::TypeId data_type, const std::string &group_name, void *stream);
extern "C" EXPORT_NCCL_WRAPPER bool Broadcast(const void *send_buff, void *recv_buff, size_t send_count,
                                              mindspore::TypeId data_type, uint32_t root_rank,
                                              const std::string &group_name, void *stream);
extern "C" EXPORT_NCCL_WRAPPER uint32_t local_rank_id();
} // namespace device
} // namespace mindspore
#endif  // MINDSPORE_CCSRC_RUNTIME_HARDWARE_CPU_NVIDIA_COLLECTIVE_COMM_LIB_H_

@@ -21,7 +21,7 @@ namespace device {
namespace gpu {
NvidiaCommunicationGroup::NvidiaCommunicationGroup(const std::string name, const std::vector<uint32_t> &group_ranks,
                                                   uint32_t global_rank)
    : CommunicationGroup(name, group_ranks, global_rank) {}
    : CommunicationGroup(name, group_ranks, global_rank), unique_id_({}), comm_(nullptr) {}

bool NvidiaCommunicationGroup::Initialize(void *root_info) {
  if (initialized_) {

@@ -50,8 +50,12 @@ bool NvidiaCommunicationGroup::Finalize() {
  return true;
}

void *NvidiaCommunicationGroup::GenerateRootInfo() {
  CHECK_RET(ncclGetUniqueId(&unique_id_), ncclSuccess, "Failed to get NCCL unique id.");
void *NvidiaCommunicationGroup::GenerateRootInfo(size_t *root_info_size) {
  *root_info_size = sizeof(unique_id_);
  uint32_t group_rank = GetGroupRank(global_rank_);
  if (group_rank == 0) {
    CHECK_RET(ncclGetUniqueId(&unique_id_), ncclSuccess, "Failed to get NCCL unique id.");
  }
  return &unique_id_;
}
} // namespace gpu
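
The change in this signature is that only group rank 0 actually generates the NCCL id; every other rank returns its zero-initialized 'unique_id_' storage, which CollectiveManager then fills via the host-side broadcast before calling Initialize. An HCCL group could satisfy the same interface analogously; a hypothetical sketch (no Ascend group class is added in this commit, and 'unique_id_' here would be an 'HcclRootInfo' member):

  void *AscendCommunicationGroup::GenerateRootInfo(size_t *root_info_size) {
    *root_info_size = sizeof(unique_id_);
    if (GetGroupRank(global_rank_) == 0) {
      // HcclGetRootInfo is the HCCL counterpart of ncclGetUniqueId.
      CHECK_RET(HcclGetRootInfo(&unique_id_), HCCL_SUCCESS, "Failed to get HCCL root info.");
    }
    return &unique_id_;
  }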

@@ -37,7 +37,7 @@ class NvidiaCommunicationGroup : public CommunicationGroup {
  bool Initialize(void *root_info) override;
  bool Finalize() override;

  void *GenerateRootInfo() override;
  void *GenerateRootInfo(size_t *root_info_size) override;

 private:
  // The NCCL unique id for this group. Used to initialize this group's communicator.