forked from mindspore-Ecosystem/mindspore
parent
7fb6df8384
commit
2b7429c5d2
|
@ -85,7 +85,7 @@ bool ClusterContext::Finalize() {
|
|||
return true;
|
||||
}
|
||||
|
||||
std::string ClusterContext::node_role() const { return node_role_; }
|
||||
const std::shared_ptr<ps::core::Node> &ClusterContext::node() const { return node_; }
|
||||
|
||||
void ClusterContext::InitClusterConfig() {
|
||||
InitNodeRole();
|
||||
|
|
|
@ -49,7 +49,8 @@ class ClusterContext {
|
|||
// Finalize the cluster and process exits.
|
||||
bool Finalize();
|
||||
|
||||
std::string node_role() const;
|
||||
// Return node object of this process.
|
||||
const std::shared_ptr<ps::core::Node> &node() const;
|
||||
|
||||
private:
|
||||
ClusterContext();
|
||||
|
|
|
@ -32,8 +32,6 @@ std::shared_ptr<ClusterContext> ClusterContext::instance() {
|
|||
bool ClusterContext::Initialize() const { return true; }
|
||||
|
||||
bool ClusterContext::Finalize() const { return true; }
|
||||
|
||||
std::string ClusterContext::node_role() const { return ""; }
|
||||
} // namespace cluster
|
||||
} // namespace distributed
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -39,7 +39,6 @@ class ClusterContext {
|
|||
|
||||
bool Initialize() const;
|
||||
bool Finalize() const;
|
||||
std::string node_role() const;
|
||||
|
||||
private:
|
||||
ClusterContext() = default;
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
#include <string>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include "utils/ms_context.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace distributed {
|
||||
|
@ -27,9 +28,7 @@ CollectiveManager::CollectiveManager()
|
|||
finalized_(true),
|
||||
host_ctx_(nullptr),
|
||||
device_ctx_(nullptr),
|
||||
host_comm_lib_(nullptr),
|
||||
host_comm_lib_instance_(nullptr),
|
||||
device_comm_lib_(nullptr),
|
||||
device_comm_lib_instance_(nullptr),
|
||||
global_rank_id_(0),
|
||||
local_rank_id_(0),
|
||||
|
@ -51,11 +50,13 @@ std::shared_ptr<CollectiveManager> CollectiveManager::instance() {
|
|||
return instance;
|
||||
}
|
||||
|
||||
bool CollectiveManager::Initialize(const std::string &backend, const std::string &global_group_name) {
|
||||
bool CollectiveManager::Initialize() {
|
||||
if (inited_) {
|
||||
return true;
|
||||
}
|
||||
MS_LOG(INFO) << "Start initializing collective communication for backend: " << backend << "...";
|
||||
|
||||
device_type_ = MsContext::GetInstance()->get_param<std::string>(MS_CTX_DEVICE_TARGET);
|
||||
MS_LOG(INFO) << "Start initializing collective communication for backend: " << device_type_ << "...";
|
||||
|
||||
// Step 1: Initialize host side collective communication.
|
||||
if (!InitHostCommlib()) {
|
||||
|
@ -66,24 +67,25 @@ bool CollectiveManager::Initialize(const std::string &backend, const std::string
|
|||
// Step 2, 3 and 4 are for device communication library. So if the training job is only launched on CPU, they will not
|
||||
// be necessary.
|
||||
// Step 2: Assign local rank id(device id) for this process.
|
||||
if (!AssignLocalRank(global_group_name)) {
|
||||
if (!AssignLocalRank()) {
|
||||
MS_LOG(ERROR) << "Failed to assign local rank id.";
|
||||
return false;
|
||||
}
|
||||
|
||||
// Step 3: Initialize device side collective communication.
|
||||
if (!InitDeviceCommLib(backend)) {
|
||||
if (!InitDeviceCommLib()) {
|
||||
MS_LOG(ERROR) << "Failed to initialize device communication library.";
|
||||
return false;
|
||||
}
|
||||
|
||||
// Step 4: Create global communication group.
|
||||
if (!CreateCommunicationGroup(global_group_name, global_group_ranks_)) {
|
||||
MS_EXCEPTION_IF_NULL(device_comm_lib_instance_);
|
||||
if (!CreateCommunicationGroup(device_comm_lib_instance_->global_group_name(), global_group_ranks_)) {
|
||||
MS_LOG(ERROR) << "Failed to initialize host communication library.";
|
||||
return false;
|
||||
}
|
||||
|
||||
MS_LOG(INFO) << "End initializing collective communication for backend: " << backend << ".";
|
||||
MS_LOG(INFO) << "End initializing collective communication for backend: " << device_type_;
|
||||
return true;
|
||||
}
|
||||
|
||||
|
@ -91,7 +93,7 @@ bool CollectiveManager::CreateCommunicationGroup(const std::string &group_name,
|
|||
const std::vector<uint32_t> &group_ranks) {
|
||||
MS_EXCEPTION_IF_NULL(host_comm_lib_instance_);
|
||||
MS_EXCEPTION_IF_NULL(device_comm_lib_instance_);
|
||||
// Step 1: Create communication group on host side.
|
||||
// Step 1: Create communication group on host side if.
|
||||
if (!host_comm_lib_instance_->CreateCommunicationGroup(group_name, group_ranks)) {
|
||||
MS_LOG(ERROR) << "Failed to create communication group " << group_name << " on host side.";
|
||||
return false;
|
||||
|
@ -167,6 +169,12 @@ bool CollectiveManager::Finalize() {
|
|||
return true;
|
||||
}
|
||||
|
||||
void CollectiveManager::set_global_rank_id(uint32_t global_rank_id) { global_rank_id_ = global_rank_id; }
|
||||
|
||||
void CollectiveManager::set_global_rank_size(uint32_t global_rank_size) { global_rank_size_ = global_rank_size; }
|
||||
|
||||
uint32_t CollectiveManager::local_rank_id() const { return local_rank_id_; }
|
||||
|
||||
bool CollectiveManager::InitHostCommlib() {
|
||||
device::DeviceContextKey host_key = {"CPU", 0};
|
||||
host_ctx_ = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext(host_key);
|
||||
|
@ -175,10 +183,7 @@ bool CollectiveManager::InitHostCommlib() {
|
|||
MS_LOG(ERROR) << "Failed to load communication library on the host side.";
|
||||
return false;
|
||||
}
|
||||
host_comm_lib_ = host_ctx_->collective_comm_lib();
|
||||
MS_EXCEPTION_IF_NULL(host_comm_lib_);
|
||||
auto instance_func = DlsymFuncObj(communication_lib_instance, host_comm_lib_);
|
||||
host_comm_lib_instance_ = instance_func();
|
||||
host_comm_lib_instance_ = host_ctx_->collective_comm_lib();
|
||||
MS_EXCEPTION_IF_NULL(host_comm_lib_instance_);
|
||||
|
||||
// For some communication libraries, global_rank_id_', 'global_rank_size_' should be set by caller, e.g., when using
|
||||
|
@ -197,33 +202,30 @@ bool CollectiveManager::InitHostCommlib() {
|
|||
global_group_ranks_.push_back(i);
|
||||
}
|
||||
|
||||
// Create world group on host side for AllGather operation of host name while assigning local rank.
|
||||
host_global_group_name_ = host_comm_lib_instance_->global_group_name();
|
||||
if (!host_comm_lib_instance_->CreateCommunicationGroup(host_global_group_name_, global_group_ranks_)) {
|
||||
MS_LOG(ERROR) << "Failed to create communication group " << host_global_group_name_ << " on host side.";
|
||||
return false;
|
||||
}
|
||||
|
||||
MS_LOG(INFO) << "Communication library on host side is successfully initialized. Global rank id: " << global_rank_id_
|
||||
<< ", global rank size: " << global_rank_size_;
|
||||
return true;
|
||||
}
|
||||
|
||||
bool CollectiveManager::InitDeviceCommLib(const std::string &backend) {
|
||||
std::string device_name;
|
||||
if (backend == "nccl") {
|
||||
device_name = "GPU";
|
||||
} else if (backend == "hccl") {
|
||||
device_name = "Ascend";
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Backend " << backend << " is not supported.";
|
||||
return false;
|
||||
}
|
||||
|
||||
device::DeviceContextKey device_key = {device_name, local_rank_id_};
|
||||
bool CollectiveManager::InitDeviceCommLib() {
|
||||
device::DeviceContextKey device_key = {device_type_, local_rank_id_};
|
||||
device_ctx_ = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext(device_key);
|
||||
MS_EXCEPTION_IF_NULL(device_ctx_);
|
||||
// We can initialize device context now because device id(local_rank_id_) is already assigned.
|
||||
device_ctx_->Initialize();
|
||||
|
||||
if (!device_ctx_->LoadCollectiveCommLib()) {
|
||||
MS_LOG(ERROR) << "Failed to load communication library on the device side.";
|
||||
return false;
|
||||
}
|
||||
device_comm_lib_ = device_ctx_->collective_comm_lib();
|
||||
MS_EXCEPTION_IF_NULL(device_comm_lib_);
|
||||
auto instance_func = DlsymFuncObj(communication_lib_instance, device_comm_lib_);
|
||||
device_comm_lib_instance_ = instance_func();
|
||||
device_comm_lib_instance_ = device_ctx_->collective_comm_lib();
|
||||
MS_EXCEPTION_IF_NULL(device_comm_lib_instance_);
|
||||
|
||||
MS_LOG(INFO) << "Start initializing communication library on device side...";
|
||||
|
@ -235,7 +237,7 @@ bool CollectiveManager::InitDeviceCommLib(const std::string &backend) {
|
|||
return true;
|
||||
}
|
||||
|
||||
bool CollectiveManager::AssignLocalRank(const std::string &global_group_name) {
|
||||
bool CollectiveManager::AssignLocalRank() {
|
||||
char host_name[MAX_HOSTNAME_LEN] = {0};
|
||||
#ifndef _WIN32
|
||||
if (gethostname(host_name, MAX_HOSTNAME_LEN) != 0) {
|
||||
|
@ -259,8 +261,8 @@ bool CollectiveManager::AssignLocalRank(const std::string &global_group_name) {
|
|||
|
||||
MS_EXCEPTION_IF_NULL(host_comm_lib_instance_);
|
||||
// AllGather host names across the global communication group.
|
||||
if (!host_comm_lib_instance_->AllGather(&host_hash, all_host_hashs, sizeof(size_t), TypeId::kNumberTypeInt,
|
||||
global_group_name, nullptr)) {
|
||||
if (!host_comm_lib_instance_->AllGather(&host_hash, all_host_hashs, 1, TypeId::kNumberTypeInt,
|
||||
host_global_group_name_, nullptr)) {
|
||||
MS_LOG(ERROR) << "AllGather for host names failed.";
|
||||
return false;
|
||||
}
|
||||
|
@ -274,7 +276,10 @@ bool CollectiveManager::AssignLocalRank(const std::string &global_group_name) {
|
|||
local_rank_id_++;
|
||||
}
|
||||
}
|
||||
MS_LOG(INFO) << "The local rank id assigned for this process is " << local_rank_id_;
|
||||
|
||||
MsContext::GetInstance()->set_param<uint32_t>(MS_CTX_DEVICE_ID, local_rank_id_);
|
||||
MS_LOG(INFO) << "The local rank id assigned for this process is " << local_rank_id_
|
||||
<< ". device_id of ms_context is set.";
|
||||
return true;
|
||||
}
|
||||
} // namespace collective
|
||||
|
|
|
@ -43,8 +43,8 @@ class CollectiveManager {
|
|||
DISABLE_COPY_AND_ASSIGN(CollectiveManager);
|
||||
static std::shared_ptr<CollectiveManager> instance();
|
||||
|
||||
// Initialize the collective communication for distributed training with the backend name, e.g., NCCL or HCCL.
|
||||
bool Initialize(const std::string &backend, const std::string &global_group_name);
|
||||
// Initialize the collective communication for distributed training. The backend type is read from MindSpore context.
|
||||
bool Initialize();
|
||||
|
||||
// Finalize the collective communication.
|
||||
bool Finalize();
|
||||
|
@ -63,8 +63,10 @@ class CollectiveManager {
|
|||
|
||||
// In some cases global rank id and rank size should be set by caller, e.g., when using MindSpore communication
|
||||
// framework, they're generated by cluster::ClusterContext.
|
||||
uint32_t set_global_rank_id();
|
||||
uint32_t set_global_rank_size();
|
||||
void set_global_rank_id(uint32_t global_rank_id);
|
||||
void set_global_rank_size(uint32_t global_rank_size);
|
||||
|
||||
uint32_t local_rank_id() const;
|
||||
|
||||
private:
|
||||
CollectiveManager();
|
||||
|
@ -73,14 +75,17 @@ class CollectiveManager {
|
|||
bool InitHostCommlib();
|
||||
|
||||
// Initialize communication library on device side.
|
||||
bool InitDeviceCommLib(const std::string &backend);
|
||||
bool InitDeviceCommLib();
|
||||
|
||||
// Assign the local rank id for this process.
|
||||
bool AssignLocalRank(const std::string &global_group_name);
|
||||
bool AssignLocalRank();
|
||||
|
||||
std::atomic_bool inited_;
|
||||
std::atomic_bool finalized_;
|
||||
|
||||
// The device type read from MindSpore context.
|
||||
std::string device_type_;
|
||||
|
||||
// The device context on both host and device side. They are used to access the communication library on different
|
||||
// devices.
|
||||
DeviceContext *host_ctx_;
|
||||
|
@ -88,12 +93,10 @@ class CollectiveManager {
|
|||
|
||||
// Host communication library refers to the communication libaray for CPU, e.g., OpenMPI and MindSpore communication
|
||||
// framework.
|
||||
void *host_comm_lib_;
|
||||
CollectiveCommunicationLib *host_comm_lib_instance_;
|
||||
|
||||
// Device communication library refers to the communication libaray for NPU or GPU, e.g., NCCL and HCCL.
|
||||
// When only CPU backend is used, device communication library should not be initialized.
|
||||
void *device_comm_lib_;
|
||||
CollectiveCommunicationLib *device_comm_lib_instance_;
|
||||
|
||||
// The global rank id of this process. Normally this range is 0 to `total process number - 1`.
|
||||
|
@ -107,6 +110,10 @@ class CollectiveManager {
|
|||
|
||||
// Global group ranks.
|
||||
std::vector<uint32_t> global_group_ranks_;
|
||||
|
||||
// The global group name on the host side. This is used for Creating global group on host side for AllGather operation
|
||||
// of host name while assigning local rank.
|
||||
std::string host_global_group_name_;
|
||||
};
|
||||
} // namespace collective
|
||||
} // namespace distributed
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
#define MINDSPORE_CCSRC_DISTRIBUTED_CONSTANTS_H_
|
||||
|
||||
#include <set>
|
||||
#include <map>
|
||||
#include <string>
|
||||
|
||||
namespace mindspore {
|
||||
|
|
|
@ -20,16 +20,29 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace distributed {
|
||||
bool Initialize(const std::string &backend, const std::string &global_group_name) {
|
||||
bool Initialize() {
|
||||
if (!InitializeCluster()) {
|
||||
MS_LOG(ERROR) << "Failed to initialize cluster.";
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!InitializeCollective(backend, global_group_name)) {
|
||||
MS_LOG(ERROR) << "Failed to initialize collective communication.";
|
||||
return false;
|
||||
#if ((defined ENABLE_CPU) && (!defined _WIN32))
|
||||
// Server and Scheduler don't use collective communication library.
|
||||
auto node = cluster::ClusterContext::instance()->node();
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
if (node->role() != ps::core::NodeRole::SERVER && node->role() != ps::core::NodeRole::SCHEDULER) {
|
||||
// Global rank id and size should be manually set if cluster is initialized by MindSpore communication framework.
|
||||
auto abstract_node = std::dynamic_pointer_cast<ps::core::AbstractNode>(cluster::ClusterContext::instance()->node());
|
||||
MS_EXCEPTION_IF_NULL(abstract_node);
|
||||
collective::CollectiveManager::instance()->set_global_rank_id(abstract_node->rank_id());
|
||||
collective::CollectiveManager::instance()->set_global_rank_size(abstract_node->worker_num());
|
||||
|
||||
if (!InitializeCollective()) {
|
||||
MS_LOG(ERROR) << "Failed to initialize collective communication.";
|
||||
return false;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
return true;
|
||||
}
|
||||
|
||||
|
@ -51,9 +64,7 @@ bool InitializeCluster() { return cluster::ClusterContext::instance()->Initializ
|
|||
|
||||
bool FinalizeCluster() { return cluster::ClusterContext::instance()->Finalize(); }
|
||||
|
||||
bool InitializeCollective(const std::string &backend, const std::string &global_group_name) {
|
||||
return collective::CollectiveManager::instance()->Initialize(backend, global_group_name);
|
||||
}
|
||||
bool InitializeCollective() { return collective::CollectiveManager::instance()->Initialize(); }
|
||||
|
||||
bool FinalizeCollective() { return collective::CollectiveManager::instance()->Finalize(); }
|
||||
} // namespace distributed
|
||||
|
|
|
@ -31,7 +31,7 @@ namespace distributed {
|
|||
// The static methods of MindSpore distributed execution. They can be exported by Pybind.
|
||||
|
||||
// Initialize and finalize distributed execution.
|
||||
bool Initialize(const std::string &backend, const std::string &global_group_name);
|
||||
bool Initialize();
|
||||
bool Finalize();
|
||||
|
||||
// Initialize and finalize the cluster based on MindSpore communication framework.
|
||||
|
@ -39,7 +39,7 @@ bool InitializeCluster();
|
|||
bool FinalizeCluster();
|
||||
|
||||
// Initialize and finalize collective communication for distributed execution.
|
||||
bool InitializeCollective(const std::string &backend, const std::string &global_group_name);
|
||||
bool InitializeCollective();
|
||||
bool FinalizeCollective();
|
||||
} // namespace distributed
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -410,7 +410,8 @@ void PSContext::set_enable_ssl(bool enabled) { enable_ssl_ = enabled; }
|
|||
|
||||
core::ClusterConfig &PSContext::cluster_config() {
|
||||
if (cluster_config_ == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "The cluster config is empty.";
|
||||
cluster_config_ = std::make_unique<core::ClusterConfig>(worker_num_, server_num_, scheduler_host_, scheduler_port_);
|
||||
MS_EXCEPTION_IF_NULL(cluster_config_);
|
||||
}
|
||||
return *cluster_config_;
|
||||
}
|
||||
|
|
|
@ -16,6 +16,8 @@
|
|||
|
||||
#include "runtime/device/gpu/distribution/collective_init.h"
|
||||
#include "utils/log_adapter.h"
|
||||
#include "utils/ms_utils.h"
|
||||
#include "distributed/init.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace device {
|
||||
|
@ -30,23 +32,32 @@ bool CollectiveInitializer::collective_inited() const { return collective_inited
|
|||
const void *CollectiveInitializer::collective_handle() { return collective_handle_; }
|
||||
|
||||
void CollectiveInitializer::InitCollective() {
|
||||
void *handle = dlopen("libgpu_collective.so", RTLD_LAZY);
|
||||
if (handle == nullptr) {
|
||||
MS_LOG(EXCEPTION)
|
||||
<< "Loading libgpu_collective.so failed. Many reasons could cause this:\n"
|
||||
"1.libgpu_collective.so is not found, please check this MindSpore package is GPU version and built "
|
||||
"with distributed feature.\n"
|
||||
"2.NCCL is not found or the user-installed NCCL version installed is incompatible: MindSpore "
|
||||
"requires NCCL-2.7.6.\n"
|
||||
"3.OpenMPI is not found or the user-installed OpenMPI version is incompatible: MindSpore "
|
||||
"requires OpenMPI-4.0.3.\n";
|
||||
}
|
||||
auto mpi_init_funcptr = reinterpret_cast<InitMPI>(dlsym(handle, "InitMPI"));
|
||||
MS_EXCEPTION_IF_NULL(mpi_init_funcptr);
|
||||
(*mpi_init_funcptr)();
|
||||
if (common::CheckUseMPI()) {
|
||||
void *handle = dlopen("libgpu_collective.so", RTLD_LAZY);
|
||||
if (handle == nullptr) {
|
||||
MS_LOG(EXCEPTION)
|
||||
<< "Loading libgpu_collective.so failed. Many reasons could cause this:\n"
|
||||
"1.libgpu_collective.so is not found, please check this MindSpore package is GPU version and built "
|
||||
"with distributed feature.\n"
|
||||
"2.NCCL is not found or the user-installed NCCL version installed is incompatible: MindSpore "
|
||||
"requires NCCL-2.7.6.\n"
|
||||
"3.OpenMPI is not found or the user-installed OpenMPI version is incompatible: MindSpore "
|
||||
"requires OpenMPI-4.0.3.\n";
|
||||
}
|
||||
auto mpi_init_funcptr = reinterpret_cast<InitMPI>(dlsym(handle, "InitMPI"));
|
||||
MS_EXCEPTION_IF_NULL(mpi_init_funcptr);
|
||||
(*mpi_init_funcptr)();
|
||||
|
||||
CollectiveInitializer::instance().collective_inited_ = true;
|
||||
CollectiveInitializer::instance().collective_handle_ = handle;
|
||||
// Because this method InitCollective is static, the non-static member variables should be accessed by
|
||||
// CollectiveInitializer::instance().
|
||||
CollectiveInitializer::instance().use_mpi_ = true;
|
||||
CollectiveInitializer::instance().collective_inited_ = true;
|
||||
CollectiveInitializer::instance().collective_handle_ = handle;
|
||||
} else {
|
||||
if (!distributed::Initialize()) {
|
||||
MS_LOG(EXCEPTION) << "Failed to initialize distributed execution for NCCL.";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void CollectiveInitializer::FinalizeCollective() {
|
||||
|
@ -56,6 +67,69 @@ void CollectiveInitializer::FinalizeCollective() {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
uint32_t CollectiveInitializer::local_rank_id() {
|
||||
uint32_t local_rank_id;
|
||||
if (common::CheckUseMPI()) {
|
||||
MS_EXCEPTION_IF_NULL(collective_handle_);
|
||||
auto get_local_rank_funcptr =
|
||||
reinterpret_cast<GetLocalRankId>(dlsym(const_cast<void *>(collective_handle_), "local_rank_id"));
|
||||
MS_EXCEPTION_IF_NULL(get_local_rank_funcptr);
|
||||
local_rank_id = IntToUint((*get_local_rank_funcptr)());
|
||||
} else {
|
||||
local_rank_id = distributed::collective::CollectiveManager::instance()->local_rank_id();
|
||||
}
|
||||
return local_rank_id;
|
||||
}
|
||||
|
||||
bool CollectiveInitializer::CreateCommunicationGroup(const std::string &group_name,
|
||||
const std::vector<uint32_t> &group_ranks) {
|
||||
if (common::CheckUseMPI()) {
|
||||
return distributed::collective::CollectiveManager::instance()->CreateCommunicationGroup(group_name, group_ranks);
|
||||
} else {
|
||||
MS_EXCEPTION_IF_NULL(collective_handle_);
|
||||
auto create_comm_group_funcptr =
|
||||
reinterpret_cast<CreateCommGroupFunc>(dlsym(const_cast<void *>(collective_handle_), "CreateCommGroup"));
|
||||
MS_EXCEPTION_IF_NULL(create_comm_group_funcptr);
|
||||
return (*create_comm_group_funcptr)(group_name, group_ranks);
|
||||
}
|
||||
}
|
||||
|
||||
bool CollectiveInitializer::DestroyCommunicationGroup(const std::string &group_name) {
|
||||
if (common::CheckUseMPI()) {
|
||||
return distributed::collective::CollectiveManager::instance()->DestroyCommunicationGroup(group_name);
|
||||
} else {
|
||||
MS_EXCEPTION_IF_NULL(collective_handle_);
|
||||
auto destroy_group_funcptr =
|
||||
reinterpret_cast<DestroyGroupFunc>(dlsym(const_cast<void *>(collective_handle_), "DestroyGroup"));
|
||||
MS_EXCEPTION_IF_NULL(destroy_group_funcptr);
|
||||
return (*destroy_group_funcptr)(group_name);
|
||||
}
|
||||
}
|
||||
|
||||
uint32_t CollectiveInitializer::GetRankIDByGroup(const std::string &group_name) {
|
||||
if (common::CheckUseMPI()) {
|
||||
return distributed::collective::CollectiveManager::instance()->GetRankId(group_name);
|
||||
} else {
|
||||
MS_EXCEPTION_IF_NULL(collective_handle_);
|
||||
auto get_rank_id_funcptr =
|
||||
reinterpret_cast<GetRankIDByGroupFunc>(dlsym(const_cast<void *>(collective_handle_), "GetRankIDByGroup"));
|
||||
MS_EXCEPTION_IF_NULL(get_rank_id_funcptr);
|
||||
return IntToUint((*get_rank_id_funcptr)(group_name));
|
||||
}
|
||||
}
|
||||
|
||||
uint32_t CollectiveInitializer::GetGroupSize(const std::string &group_name) {
|
||||
if (common::CheckUseMPI()) {
|
||||
return distributed::collective::CollectiveManager::instance()->GetGroupSize(group_name);
|
||||
} else {
|
||||
MS_EXCEPTION_IF_NULL(collective_handle_);
|
||||
auto get_group_size_funcptr =
|
||||
reinterpret_cast<GetGroupSizeFunc>(dlsym(const_cast<void *>(collective_handle_), "GetGroupSize"));
|
||||
MS_EXCEPTION_IF_NULL(get_group_size_funcptr);
|
||||
return IntToUint((*get_group_size_funcptr)(group_name));
|
||||
}
|
||||
}
|
||||
} // namespace gpu
|
||||
} // namespace device
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -42,12 +42,20 @@ class CollectiveInitializer {
|
|||
static void InitCollective();
|
||||
static void FinalizeCollective();
|
||||
|
||||
// The capsulation of the collective communication APIs for compatibility.
|
||||
uint32_t local_rank_id();
|
||||
bool CreateCommunicationGroup(const std::string &group_name, const std::vector<uint32_t> &group_ranks);
|
||||
bool DestroyCommunicationGroup(const std::string &group_name);
|
||||
uint32_t GetRankIDByGroup(const std::string &group_name);
|
||||
uint32_t GetGroupSize(const std::string &group_name);
|
||||
|
||||
private:
|
||||
CollectiveInitializer() : collective_inited_(false) {}
|
||||
CollectiveInitializer() : use_mpi_(false), collective_inited_(false), collective_handle_(nullptr) {}
|
||||
~CollectiveInitializer() = default;
|
||||
|
||||
bool use_mpi_;
|
||||
bool collective_inited_;
|
||||
void *collective_handle_{nullptr};
|
||||
void *collective_handle_;
|
||||
};
|
||||
} // namespace gpu
|
||||
} // namespace device
|
||||
|
|
|
@ -23,6 +23,9 @@ endif()
|
|||
if(ENABLE_CPU)
|
||||
file(GLOB_RECURSE HARDWARE_CPU_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "cpu/*.cc")
|
||||
list(REMOVE_ITEM HARDWARE_CPU_SRC_LIST "cpu/mpi_collective_comm_lib.cc" "cpu/mpi_communication_group.cc")
|
||||
if(WIN32)
|
||||
list(REMOVE_ITEM HARDWARE_CPU_SRC_LIST "cpu/ms_collective_comm_lib.cc")
|
||||
endif()
|
||||
if(ENABLE_MPI)
|
||||
set(MPI_COLLECTIVE_SRCS "cpu/mpi_collective_comm_lib.cc"
|
||||
"cpu/mpi_communication_group.cc"
|
||||
|
|
|
@ -60,6 +60,8 @@ CommunicationGroupPtr CollectiveCommunicationLib::GetGroup(const std::string &gr
|
|||
return groups_[group_name];
|
||||
}
|
||||
|
||||
const std::string &CollectiveCommunicationLib::global_group_name() const { return global_group_name_; }
|
||||
|
||||
uint32_t CollectiveCommunicationLib::global_rank_id() const { return global_rank_id_; }
|
||||
|
||||
uint32_t CollectiveCommunicationLib::local_rank_id() const { return local_rank_id_; }
|
||||
|
|
|
@ -77,6 +77,10 @@ class CollectiveCommunicationLib {
|
|||
return true;
|
||||
}
|
||||
|
||||
// Returns the global group name of this collective communication library. For NCCL, it's 'nccl_world_group'. For
|
||||
// HCCL, it's 'hccl_world_group'.
|
||||
const std::string &global_group_name() const;
|
||||
|
||||
// Returns global rank id of this process.
|
||||
uint32_t global_rank_id() const;
|
||||
|
||||
|
@ -90,6 +94,9 @@ class CollectiveCommunicationLib {
|
|||
// Whether this collective communication library is initialized.
|
||||
bool initialized_;
|
||||
|
||||
// The global group name.
|
||||
std::string global_group_name_;
|
||||
|
||||
// The global rank id of this process. Normally this range is 0 to `total process number - 1`.
|
||||
uint32_t global_rank_id_;
|
||||
|
||||
|
|
|
@ -47,5 +47,11 @@ uint32_t CommunicationGroup::GetGlobalRank(uint32_t group_rank) {
|
|||
}
|
||||
|
||||
uint32_t CommunicationGroup::group_size() const { return size_; }
|
||||
|
||||
const std::vector<uint32_t> &CommunicationGroup::group_ranks() const { return group_ranks_; }
|
||||
|
||||
const std::map<uint32_t, uint32_t> &CommunicationGroup::global_to_group_ranks() const { return global_to_group_ranks_; }
|
||||
|
||||
const std::map<uint32_t, uint32_t> &CommunicationGroup::group_to_global_ranks() const { return group_to_global_ranks_; }
|
||||
} // namespace device
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -55,6 +55,11 @@ class CommunicationGroup {
|
|||
// Return the size of this communication group.
|
||||
uint32_t group_size() const;
|
||||
|
||||
// Return group ranks info.
|
||||
const std::vector<uint32_t> &group_ranks() const;
|
||||
const std::map<uint32_t, uint32_t> &global_to_group_ranks() const;
|
||||
const std::map<uint32_t, uint32_t> &group_to_global_ranks() const;
|
||||
|
||||
protected:
|
||||
// Whether this communication group is initialized.
|
||||
bool initialized_;
|
||||
|
|
|
@ -34,6 +34,9 @@
|
|||
#include "backend/optimizer/graph_kernel/graph_kernel_optimization.h"
|
||||
#include "backend/session/anf_runtime_algorithm.h"
|
||||
#include "profiler/device/cpu/cpu_profiling.h"
|
||||
#if ((defined ENABLE_CPU) && (!defined _WIN32))
|
||||
#include "runtime/hardware/cpu/ms_collective_comm_lib.h"
|
||||
#endif
|
||||
#ifndef ENABLE_SECURITY
|
||||
#include "debug/data_dump/dump_json_parser.h"
|
||||
#endif
|
||||
|
@ -296,6 +299,32 @@ bool CPUDeviceContext::LaunchKernel(const CNodePtr &kernel, const std::vector<Ad
|
|||
return DoLaunchKernel(kernel_mod, inputs, workspace, outputs);
|
||||
}
|
||||
|
||||
bool CPUDeviceContext::LoadCollectiveCommLib() {
|
||||
bool using_mpi = common::CheckUseMPI();
|
||||
if (using_mpi) {
|
||||
std::string mpi_comm_lib_name = "libmpi_collective.so";
|
||||
auto loader = std::make_shared<CollectiveCommLibLoader>(mpi_comm_lib_name);
|
||||
MS_EXCEPTION_IF_NULL(loader);
|
||||
if (!loader->Initialize()) {
|
||||
MS_LOG(EXCEPTION) << "Failed to load mpi collective library.";
|
||||
return false;
|
||||
}
|
||||
|
||||
void *collective_comm_lib_handle = loader->collective_comm_lib_ptr();
|
||||
MS_EXCEPTION_IF_NULL(collective_comm_lib_handle);
|
||||
|
||||
auto instance_func = DlsymFuncObj(communication_lib_instance, collective_comm_lib_handle);
|
||||
collective_comm_lib_ = instance_func();
|
||||
MS_EXCEPTION_IF_NULL(collective_comm_lib_);
|
||||
} else {
|
||||
#if ((defined ENABLE_CPU) && (!defined _WIN32))
|
||||
collective_comm_lib_ = &MsCollectiveCommLib::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(collective_comm_lib_);
|
||||
#endif
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool CPUDeviceContext::LaunchKernelWithProfiling(const CNodePtr &kernel, const std::vector<AddressPtr> &inputs,
|
||||
const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) const {
|
||||
|
|
|
@ -57,6 +57,8 @@ class CPUDeviceContext : public DeviceContext {
|
|||
const std::vector<AddressPtr> &workspace, const std::vector<AddressPtr> &outputs,
|
||||
bool is_dynamic_shape = false) const override;
|
||||
|
||||
bool LoadCollectiveCommLib() override;
|
||||
|
||||
private:
|
||||
DISABLE_COPY_AND_ASSIGN(CPUDeviceContext);
|
||||
|
||||
|
|
|
@ -19,6 +19,8 @@
|
|||
namespace mindspore {
|
||||
namespace device {
|
||||
namespace cpu {
|
||||
MPICollectiveCommLib::MPICollectiveCommLib() { global_group_name_ = kMPIGlobalGroupName; }
|
||||
|
||||
bool MPICollectiveCommLib::Initialize(uint32_t, uint32_t) {
|
||||
if (initialized_) {
|
||||
return false;
|
||||
|
@ -56,49 +58,7 @@ bool MPICollectiveCommLib::CreateCommunicationGroup(const std::string &group_nam
|
|||
}
|
||||
} // namespace cpu
|
||||
|
||||
// The exported APIs for 'dlsym' to load.
|
||||
using MPICollectiveCommLib = mindspore::device::cpu::MPICollectiveCommLib;
|
||||
|
||||
CollectiveCommunicationLib *communication_lib_instance() { return &MPICollectiveCommLib::GetInstance(); }
|
||||
|
||||
bool InitializeCollectiveLib(uint32_t, uint32_t) { return MPICollectiveCommLib::GetInstance().Initialize(); }
|
||||
|
||||
bool FinalizeCollectiveLib() { return MPICollectiveCommLib::GetInstance().Finalize(); }
|
||||
|
||||
bool CreateCommunicationGroup(const std::string &group_name, const std::vector<uint32_t> &group_ranks) {
|
||||
return MPICollectiveCommLib::GetInstance().CreateCommunicationGroup(group_name, group_ranks);
|
||||
}
|
||||
|
||||
bool DestroyCommunicationGroup(const std::string &group_name) {
|
||||
return MPICollectiveCommLib::GetInstance().DestroyCommunicationGroup(group_name);
|
||||
}
|
||||
|
||||
uint32_t GetRankId(const std::string &group_name) { return MPICollectiveCommLib::GetInstance().GetRankId(group_name); }
|
||||
|
||||
uint32_t GetCommunicationGroupSize(const std::string &group_name) {
|
||||
return MPICollectiveCommLib::GetInstance().GetGroupSize(group_name);
|
||||
}
|
||||
|
||||
bool AssignLocalRank() { return MPICollectiveCommLib::GetInstance().AssignLocalRank(); }
|
||||
|
||||
CommunicationGroupPtr GetGroup(const std::string &group_name) {
|
||||
return MPICollectiveCommLib::GetInstance().GetGroup(group_name);
|
||||
}
|
||||
|
||||
bool AllGather(const void *send_buff, void *recv_buff, size_t send_count, TypeId data_type,
|
||||
const std::string &group_name, void *stream) {
|
||||
return MPICollectiveCommLib::GetInstance().AllGather(send_buff, recv_buff, send_count, data_type, group_name, stream);
|
||||
}
|
||||
bool Broadcast(const void *send_buff, void *recv_buff, size_t send_count, mindspore::TypeId data_type,
|
||||
uint32_t root_rank, const std::string &group_name, void *stream) {
|
||||
return MPICollectiveCommLib::GetInstance().Broadcast(send_buff, recv_buff, send_count, data_type, root_rank,
|
||||
group_name, stream);
|
||||
}
|
||||
|
||||
uint32_t global_rank_id() { return MPICollectiveCommLib::GetInstance().global_rank_id(); }
|
||||
|
||||
uint32_t local_rank_id() { return MPICollectiveCommLib::GetInstance().local_rank_id(); }
|
||||
|
||||
uint32_t global_rank_size() { return MPICollectiveCommLib::GetInstance().global_rank_size(); }
|
||||
} // namespace device
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -23,10 +23,15 @@
|
|||
#include "runtime/hardware/collective/collective_communication_lib.h"
|
||||
#include "runtime/hardware/cpu/mpi_communication_group.h"
|
||||
|
||||
#ifndef EXPORT_MPI_WRAPPER
|
||||
#define EXPORT_MPI_WRAPPER __attribute__((visibility("default")))
|
||||
#endif
|
||||
|
||||
namespace mindspore {
|
||||
namespace device {
|
||||
namespace cpu {
|
||||
class MPICollectiveCommLib : public CollectiveCommunicationLib {
|
||||
constexpr char kMPIGlobalGroupName[] = "mpi_world_group";
|
||||
class EXPORT_MPI_WRAPPER MPICollectiveCommLib : public CollectiveCommunicationLib {
|
||||
public:
|
||||
static MPICollectiveCommLib &GetInstance() {
|
||||
static MPICollectiveCommLib instance;
|
||||
|
@ -49,35 +54,14 @@ class MPICollectiveCommLib : public CollectiveCommunicationLib {
|
|||
}
|
||||
|
||||
private:
|
||||
MPICollectiveCommLib() = default;
|
||||
MPICollectiveCommLib();
|
||||
~MPICollectiveCommLib() override = default;
|
||||
|
||||
MPI_Group world_group_;
|
||||
};
|
||||
} // namespace cpu
|
||||
|
||||
#ifndef EXPORT_MPI_WRAPPER
|
||||
#define EXPORT_MPI_WRAPPER __attribute__((visibility("default")))
|
||||
#endif
|
||||
extern "C" EXPORT_MPI_WRAPPER CollectiveCommunicationLib *communication_lib_instance();
|
||||
extern "C" EXPORT_MPI_WRAPPER bool InitializeCollectiveLib(uint32_t global_rank = UINT32_MAX,
|
||||
uint32_t global_rank_size = UINT32_MAX);
|
||||
extern "C" EXPORT_MPI_WRAPPER bool FinalizeCollectiveLib();
|
||||
extern "C" EXPORT_MPI_WRAPPER bool CreateCommunicationGroup(const std::string &group_name,
|
||||
const std::vector<uint32_t> &group_ranks);
|
||||
extern "C" EXPORT_MPI_WRAPPER bool DestroyCommunicationGroup(const std::string &group_name);
|
||||
extern "C" EXPORT_MPI_WRAPPER uint32_t GetRankId(const std::string &group_name);
|
||||
extern "C" EXPORT_MPI_WRAPPER uint32_t GetGroupSize(const std::string &group_name);
|
||||
extern "C" EXPORT_MPI_WRAPPER bool AssignLocalRank();
|
||||
extern "C" EXPORT_MPI_WRAPPER CommunicationGroupPtr GetGroup(const std::string &group_name);
|
||||
extern "C" EXPORT_MPI_WRAPPER bool AllGather(const void *send_buff, void *recv_buff, size_t send_count,
|
||||
mindspore::TypeId data_type, const std::string &group_name, void *stream);
|
||||
extern "C" EXPORT_MPI_WRAPPER bool Broadcast(const void *send_buff, void *recv_buff, size_t send_count,
|
||||
mindspore::TypeId data_type, uint32_t root_rank,
|
||||
const std::string &group_name, void *stream);
|
||||
extern "C" EXPORT_MPI_WRAPPER uint32_t global_rank_id();
|
||||
extern "C" EXPORT_MPI_WRAPPER uint32_t local_rank_id();
|
||||
extern "C" EXPORT_MPI_WRAPPER uint32_t global_rank_size();
|
||||
} // namespace device
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_RUNTIME_HARDWARE_CPU_MPI_COLLECTIVE_COMM_LIB_H_
|
||||
|
|
|
@ -19,6 +19,8 @@
|
|||
namespace mindspore {
|
||||
namespace device {
|
||||
namespace cpu {
|
||||
MsCollectiveCommLib::MsCollectiveCommLib() { global_group_name_ = kMSGlobalGroupName; }
|
||||
|
||||
bool MsCollectiveCommLib::Initialize(uint32_t global_rank, uint32_t global_rank_size) {
|
||||
if (initialized_) {
|
||||
return false;
|
||||
|
@ -30,8 +32,6 @@ bool MsCollectiveCommLib::Initialize(uint32_t global_rank, uint32_t global_rank_
|
|||
return true;
|
||||
}
|
||||
|
||||
bool MsCollectiveCommLib::Finalize() { return true; }
|
||||
|
||||
bool MsCollectiveCommLib::CreateCommunicationGroup(const std::string &group_name,
|
||||
const std::vector<uint32_t> &group_ranks) {
|
||||
CHECK_RET((groups_.count(group_name) == 0), true, "The group " + group_name + " has already existed.");
|
||||
|
|
|
@ -22,10 +22,17 @@
|
|||
#include <string>
|
||||
#include "runtime/hardware/collective/collective_communication_lib.h"
|
||||
#include "runtime/hardware/cpu/ms_communication_group.h"
|
||||
#include "distributed/cluster/cluster_context.h"
|
||||
#include "fl/server/collective_ops_impl.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace device {
|
||||
namespace cpu {
|
||||
constexpr char kMSGlobalGroupName[] = "ms_world_group";
|
||||
using ClusterContext = mindspore::distributed::cluster::ClusterContext;
|
||||
using CollectiveOpsImpl = mindspore::fl::server::CollectiveOpsImpl;
|
||||
using CommunicationGroupInfo = mindspore::fl::server::CommunicationGroupInfo;
|
||||
|
||||
// The collective communication library for MindSpore self developed communication framework.
|
||||
class MsCollectiveCommLib : public CollectiveCommunicationLib {
|
||||
public:
|
||||
|
@ -35,12 +42,11 @@ class MsCollectiveCommLib : public CollectiveCommunicationLib {
|
|||
}
|
||||
|
||||
bool Initialize(uint32_t global_rank = UINT32_MAX, uint32_t global_rank_size = UINT32_MAX) override;
|
||||
bool Finalize() override;
|
||||
|
||||
bool CreateCommunicationGroup(const std::string &group_name, const std::vector<uint32_t> &group_ranks) override;
|
||||
|
||||
private:
|
||||
MsCollectiveCommLib() {}
|
||||
MsCollectiveCommLib();
|
||||
~MsCollectiveCommLib() override = default;
|
||||
};
|
||||
} // namespace cpu
|
||||
|
|
|
@ -50,7 +50,7 @@ struct DeviceContextKey {
|
|||
class DeviceContext {
|
||||
public:
|
||||
explicit DeviceContext(const DeviceContextKey &device_context_key)
|
||||
: device_context_key_(device_context_key), collective_comm_lib_ptr_(nullptr) {}
|
||||
: device_context_key_(device_context_key), collective_comm_lib_(nullptr) {}
|
||||
virtual ~DeviceContext() = default;
|
||||
|
||||
// Initialize the device context.
|
||||
|
@ -150,7 +150,7 @@ class DeviceContext {
|
|||
virtual bool LoadCollectiveCommLib() { return true; }
|
||||
|
||||
// Return collective communication object for caller to access
|
||||
void *collective_comm_lib() const { return collective_comm_lib_ptr_; }
|
||||
CollectiveCommunicationLib *collective_comm_lib() const { return collective_comm_lib_; }
|
||||
|
||||
// TODO(jiaorui): will be delete
|
||||
// Dump all graphs.
|
||||
|
@ -159,8 +159,8 @@ class DeviceContext {
|
|||
protected:
|
||||
DeviceContextKey device_context_key_;
|
||||
|
||||
// The dynamically loaded handle for collective communication library by 'dlopen'.
|
||||
void *collective_comm_lib_ptr_;
|
||||
// The collective communication library.
|
||||
CollectiveCommunicationLib *collective_comm_lib_;
|
||||
};
|
||||
using DeviceContextPtr = std::shared_ptr<DeviceContext>;
|
||||
} // namespace device
|
||||
|
|
|
@ -534,8 +534,12 @@ bool GPUDeviceContext::LoadCollectiveCommLib() {
|
|||
MS_LOG(EXCEPTION) << "Loading NCCL collective library failed.";
|
||||
return false;
|
||||
}
|
||||
collective_comm_lib_ptr_ = loader->collective_comm_lib_ptr();
|
||||
MS_EXCEPTION_IF_NULL(collective_comm_lib_ptr_);
|
||||
void *collective_comm_lib_handle = loader->collective_comm_lib_ptr();
|
||||
MS_EXCEPTION_IF_NULL(collective_comm_lib_handle);
|
||||
|
||||
auto instance_func = DlsymFuncObj(communication_lib_instance, collective_comm_lib_handle);
|
||||
collective_comm_lib_ = instance_func();
|
||||
MS_EXCEPTION_IF_NULL(collective_comm_lib_);
|
||||
return true;
|
||||
#else
|
||||
return false;
|
||||
|
|
|
@ -19,6 +19,8 @@
|
|||
namespace mindspore {
|
||||
namespace device {
|
||||
namespace gpu {
|
||||
NvidiaCollectiveCommLib::NvidiaCollectiveCommLib() { global_group_name_ = kNCCLGlobalGroupName; }
|
||||
|
||||
bool NvidiaCollectiveCommLib::Initialize(uint32_t global_rank, uint32_t global_rank_size) {
|
||||
if (initialized_) {
|
||||
return false;
|
||||
|
@ -42,50 +44,7 @@ bool NvidiaCollectiveCommLib::CreateCommunicationGroup(const std::string &group_
|
|||
}
|
||||
} // namespace gpu
|
||||
|
||||
// The exported APIs for 'dlsym' to load.
|
||||
using NvidiaCollectiveCommLib = mindspore::device::gpu::NvidiaCollectiveCommLib;
|
||||
CollectiveCommunicationLib *communication_lib_instance() { return &NvidiaCollectiveCommLib::GetInstance(); }
|
||||
|
||||
bool InitializeCollectiveLib(uint32_t global_rank, uint32_t global_rank_size) {
|
||||
return NvidiaCollectiveCommLib::GetInstance().Initialize(global_rank, global_rank_size);
|
||||
}
|
||||
|
||||
bool FinalizeCollectiveLib() { return NvidiaCollectiveCommLib::GetInstance().Finalize(); }
|
||||
|
||||
bool CreateCommunicationGroup(const std::string &group_name, const std::vector<uint32_t> &group_ranks) {
|
||||
return NvidiaCollectiveCommLib::GetInstance().CreateCommunicationGroup(group_name, group_ranks);
|
||||
}
|
||||
|
||||
bool DestroyCommunicationGroup(const std::string &group_name) {
|
||||
return NvidiaCollectiveCommLib::GetInstance().DestroyCommunicationGroup(group_name);
|
||||
}
|
||||
|
||||
uint32_t GetRankId(const std::string &group_name) {
|
||||
return NvidiaCollectiveCommLib::GetInstance().GetRankId(group_name);
|
||||
}
|
||||
|
||||
uint32_t GetCommunicationGroupSize(const std::string &group_name) {
|
||||
return NvidiaCollectiveCommLib::GetInstance().GetGroupSize(group_name);
|
||||
}
|
||||
|
||||
bool AssignLocalRank() { return NvidiaCollectiveCommLib::GetInstance().AssignLocalRank(); }
|
||||
|
||||
CommunicationGroupPtr GetGroup(const std::string &group_name) {
|
||||
return NvidiaCollectiveCommLib::GetInstance().GetGroup(group_name);
|
||||
}
|
||||
|
||||
bool AllGather(const void *send_buff, void *recv_buff, size_t send_count, mindspore::TypeId data_type,
|
||||
const std::string &group_name, void *stream) {
|
||||
return NvidiaCollectiveCommLib::GetInstance().AllGather(send_buff, recv_buff, send_count, data_type, group_name,
|
||||
stream);
|
||||
}
|
||||
|
||||
bool Broadcast(const void *send_buff, void *recv_buff, size_t send_count, mindspore::TypeId data_type,
|
||||
uint32_t root_rank, const std::string &group_name, void *stream) {
|
||||
return NvidiaCollectiveCommLib::GetInstance().Broadcast(send_buff, recv_buff, send_count, data_type, root_rank,
|
||||
group_name, stream);
|
||||
}
|
||||
|
||||
uint32_t local_rank_id() { return NvidiaCollectiveCommLib::GetInstance().local_rank_id(); }
|
||||
} // namespace device
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -24,11 +24,15 @@
|
|||
#include "runtime/hardware/collective/collective_communication_lib.h"
|
||||
#include "runtime/hardware/gpu/nvidia_communication_group.h"
|
||||
|
||||
#ifndef EXPORT_NCCL_WRAPPER
|
||||
#define EXPORT_NCCL_WRAPPER __attribute__((visibility("default")))
|
||||
#endif
|
||||
|
||||
namespace mindspore {
|
||||
namespace device {
|
||||
namespace gpu {
|
||||
constexpr char NCCL_WORLD_GROUP[] = "nccl_world_group";
|
||||
class NvidiaCollectiveCommLib : public CollectiveCommunicationLib {
|
||||
constexpr char kNCCLGlobalGroupName[] = "nccl_world_group";
|
||||
class EXPORT_NCCL_WRAPPER NvidiaCollectiveCommLib : public CollectiveCommunicationLib {
|
||||
public:
|
||||
static NvidiaCollectiveCommLib &GetInstance() {
|
||||
static NvidiaCollectiveCommLib instance;
|
||||
|
@ -50,31 +54,12 @@ class NvidiaCollectiveCommLib : public CollectiveCommunicationLib {
|
|||
}
|
||||
|
||||
private:
|
||||
NvidiaCollectiveCommLib() = default;
|
||||
NvidiaCollectiveCommLib();
|
||||
~NvidiaCollectiveCommLib() override = default;
|
||||
};
|
||||
} // namespace gpu
|
||||
|
||||
#ifndef EXPORT_NCCL_WRAPPER
|
||||
#define EXPORT_NCCL_WRAPPER __attribute__((visibility("default")))
|
||||
#endif
|
||||
extern "C" EXPORT_NCCL_WRAPPER CollectiveCommunicationLib *communication_lib_instance();
|
||||
extern "C" EXPORT_NCCL_WRAPPER bool InitializeCollectiveLib(uint32_t global_rank = UINT32_MAX,
|
||||
uint32_t global_rank_size = UINT32_MAX);
|
||||
extern "C" EXPORT_NCCL_WRAPPER bool FinalizeCollectiveLib();
|
||||
extern "C" EXPORT_NCCL_WRAPPER bool CreateCommunicationGroup(const std::string &group_name,
|
||||
const std::vector<uint32_t> &group_ranks);
|
||||
extern "C" EXPORT_NCCL_WRAPPER bool DestroyCommunicationGroup(const std::string &group_name);
|
||||
extern "C" EXPORT_NCCL_WRAPPER uint32_t GetRankId(const std::string &group_name);
|
||||
extern "C" EXPORT_NCCL_WRAPPER uint32_t GetCommunicationGroupSize(const std::string &group_name);
|
||||
extern "C" EXPORT_NCCL_WRAPPER bool AssignLocalRank();
|
||||
extern "C" EXPORT_NCCL_WRAPPER CommunicationGroupPtr GetGroup(const std::string &group_name);
|
||||
extern "C" EXPORT_NCCL_WRAPPER bool AllGather(const void *send_buff, void *recv_buff, size_t send_count,
|
||||
mindspore::TypeId data_type, const std::string &group_name, void *stream);
|
||||
extern "C" EXPORT_NCCL_WRAPPER bool Broadcast(const void *send_buff, void *recv_buff, size_t send_count,
|
||||
mindspore::TypeId data_type, uint32_t root_rank,
|
||||
const std::string &group_name, void *stream);
|
||||
extern "C" EXPORT_NCCL_WRAPPER uint32_t local_rank_id();
|
||||
} // namespace device
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_RUNTIME_HARDWARE_CPU_NVIDIA_COLLECTIVE_COMM_LIB_H_
|
||||
|
|
|
@ -90,6 +90,16 @@ static inline bool IsLittleByteOrder() {
|
|||
return false;
|
||||
}
|
||||
|
||||
static inline bool CheckUseMPI() {
|
||||
// If these OpenMPI environment variables are set, we consider this process is launched by OpenMPI.
|
||||
std::string ompi_command_env = GetEnv("OMPI_COMMAND");
|
||||
std::string pmix_rank_env = GetEnv("PMIX_RANK");
|
||||
if (!ompi_command_env.empty() && !pmix_rank_env.empty()) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool IsEqual(const std::shared_ptr<T> &a, const std::shared_ptr<T> &b) {
|
||||
if (a == b) {
|
||||
|
|
Loading…
Reference in New Issue