1. Purge unused APIs.

2. Adapt for collective_init.h.
This commit is contained in:
ZPaC 2021-11-27 10:57:36 +08:00
parent 7fb6df8384
commit 2b7429c5d2
28 changed files with 281 additions and 214 deletions

View File

@ -85,7 +85,7 @@ bool ClusterContext::Finalize() {
return true; return true;
} }
std::string ClusterContext::node_role() const { return node_role_; } const std::shared_ptr<ps::core::Node> &ClusterContext::node() const { return node_; }
void ClusterContext::InitClusterConfig() { void ClusterContext::InitClusterConfig() {
InitNodeRole(); InitNodeRole();

View File

@ -49,7 +49,8 @@ class ClusterContext {
// Finalize the cluster and process exits. // Finalize the cluster and process exits.
bool Finalize(); bool Finalize();
std::string node_role() const; // Return node object of this process.
const std::shared_ptr<ps::core::Node> &node() const;
private: private:
ClusterContext(); ClusterContext();

View File

@ -32,8 +32,6 @@ std::shared_ptr<ClusterContext> ClusterContext::instance() {
bool ClusterContext::Initialize() const { return true; } bool ClusterContext::Initialize() const { return true; }
bool ClusterContext::Finalize() const { return true; } bool ClusterContext::Finalize() const { return true; }
std::string ClusterContext::node_role() const { return ""; }
} // namespace cluster } // namespace cluster
} // namespace distributed } // namespace distributed
} // namespace mindspore } // namespace mindspore

View File

@ -39,7 +39,6 @@ class ClusterContext {
bool Initialize() const; bool Initialize() const;
bool Finalize() const; bool Finalize() const;
std::string node_role() const;
private: private:
ClusterContext() = default; ClusterContext() = default;

View File

@ -18,6 +18,7 @@
#include <string> #include <string>
#include <vector> #include <vector>
#include <memory> #include <memory>
#include "utils/ms_context.h"
namespace mindspore { namespace mindspore {
namespace distributed { namespace distributed {
@ -27,9 +28,7 @@ CollectiveManager::CollectiveManager()
finalized_(true), finalized_(true),
host_ctx_(nullptr), host_ctx_(nullptr),
device_ctx_(nullptr), device_ctx_(nullptr),
host_comm_lib_(nullptr),
host_comm_lib_instance_(nullptr), host_comm_lib_instance_(nullptr),
device_comm_lib_(nullptr),
device_comm_lib_instance_(nullptr), device_comm_lib_instance_(nullptr),
global_rank_id_(0), global_rank_id_(0),
local_rank_id_(0), local_rank_id_(0),
@ -51,11 +50,13 @@ std::shared_ptr<CollectiveManager> CollectiveManager::instance() {
return instance; return instance;
} }
bool CollectiveManager::Initialize(const std::string &backend, const std::string &global_group_name) { bool CollectiveManager::Initialize() {
if (inited_) { if (inited_) {
return true; return true;
} }
MS_LOG(INFO) << "Start initializing collective communication for backend: " << backend << "...";
device_type_ = MsContext::GetInstance()->get_param<std::string>(MS_CTX_DEVICE_TARGET);
MS_LOG(INFO) << "Start initializing collective communication for backend: " << device_type_ << "...";
// Step 1: Initialize host side collective communication. // Step 1: Initialize host side collective communication.
if (!InitHostCommlib()) { if (!InitHostCommlib()) {
@ -66,24 +67,25 @@ bool CollectiveManager::Initialize(const std::string &backend, const std::string
// Step 2, 3 and 4 are for device communication library. So if the training job is only launched on CPU, they will not // Step 2, 3 and 4 are for device communication library. So if the training job is only launched on CPU, they will not
// be necessary. // be necessary.
// Step 2: Assign local rank id(device id) for this process. // Step 2: Assign local rank id(device id) for this process.
if (!AssignLocalRank(global_group_name)) { if (!AssignLocalRank()) {
MS_LOG(ERROR) << "Failed to assign local rank id."; MS_LOG(ERROR) << "Failed to assign local rank id.";
return false; return false;
} }
// Step 3: Initialize device side collective communication. // Step 3: Initialize device side collective communication.
if (!InitDeviceCommLib(backend)) { if (!InitDeviceCommLib()) {
MS_LOG(ERROR) << "Failed to initialize device communication library."; MS_LOG(ERROR) << "Failed to initialize device communication library.";
return false; return false;
} }
// Step 4: Create global communication group. // Step 4: Create global communication group.
if (!CreateCommunicationGroup(global_group_name, global_group_ranks_)) { MS_EXCEPTION_IF_NULL(device_comm_lib_instance_);
if (!CreateCommunicationGroup(device_comm_lib_instance_->global_group_name(), global_group_ranks_)) {
MS_LOG(ERROR) << "Failed to initialize host communication library."; MS_LOG(ERROR) << "Failed to initialize host communication library.";
return false; return false;
} }
MS_LOG(INFO) << "End initializing collective communication for backend: " << backend << "."; MS_LOG(INFO) << "End initializing collective communication for backend: " << device_type_;
return true; return true;
} }
@ -91,7 +93,7 @@ bool CollectiveManager::CreateCommunicationGroup(const std::string &group_name,
const std::vector<uint32_t> &group_ranks) { const std::vector<uint32_t> &group_ranks) {
MS_EXCEPTION_IF_NULL(host_comm_lib_instance_); MS_EXCEPTION_IF_NULL(host_comm_lib_instance_);
MS_EXCEPTION_IF_NULL(device_comm_lib_instance_); MS_EXCEPTION_IF_NULL(device_comm_lib_instance_);
// Step 1: Create communication group on host side. // Step 1: Create communication group on host side.
if (!host_comm_lib_instance_->CreateCommunicationGroup(group_name, group_ranks)) { if (!host_comm_lib_instance_->CreateCommunicationGroup(group_name, group_ranks)) {
MS_LOG(ERROR) << "Failed to create communication group " << group_name << " on host side."; MS_LOG(ERROR) << "Failed to create communication group " << group_name << " on host side.";
return false; return false;
@ -167,6 +169,12 @@ bool CollectiveManager::Finalize() {
return true; return true;
} }
// Stores the caller-provided global rank id of this process in global_rank_id_.
void CollectiveManager::set_global_rank_id(uint32_t global_rank_id) {
  global_rank_id_ = global_rank_id;
}
// Stores the caller-provided global rank size in global_rank_size_.
void CollectiveManager::set_global_rank_size(uint32_t global_rank_size) {
  global_rank_size_ = global_rank_size;
}
// Returns the current value of local_rank_id_ for this process.
uint32_t CollectiveManager::local_rank_id() const {
  return local_rank_id_;
}
bool CollectiveManager::InitHostCommlib() { bool CollectiveManager::InitHostCommlib() {
device::DeviceContextKey host_key = {"CPU", 0}; device::DeviceContextKey host_key = {"CPU", 0};
host_ctx_ = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext(host_key); host_ctx_ = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext(host_key);
@ -175,10 +183,7 @@ bool CollectiveManager::InitHostCommlib() {
MS_LOG(ERROR) << "Failed to load communication library on the host side."; MS_LOG(ERROR) << "Failed to load communication library on the host side.";
return false; return false;
} }
host_comm_lib_ = host_ctx_->collective_comm_lib(); host_comm_lib_instance_ = host_ctx_->collective_comm_lib();
MS_EXCEPTION_IF_NULL(host_comm_lib_);
auto instance_func = DlsymFuncObj(communication_lib_instance, host_comm_lib_);
host_comm_lib_instance_ = instance_func();
MS_EXCEPTION_IF_NULL(host_comm_lib_instance_); MS_EXCEPTION_IF_NULL(host_comm_lib_instance_);
// For some communication libraries, 'global_rank_id_', 'global_rank_size_' should be set by caller, e.g., when using // For some communication libraries, 'global_rank_id_', 'global_rank_size_' should be set by caller, e.g., when using
@ -197,33 +202,30 @@ bool CollectiveManager::InitHostCommlib() {
global_group_ranks_.push_back(i); global_group_ranks_.push_back(i);
} }
// Create world group on host side for AllGather operation of host name while assigning local rank.
host_global_group_name_ = host_comm_lib_instance_->global_group_name();
if (!host_comm_lib_instance_->CreateCommunicationGroup(host_global_group_name_, global_group_ranks_)) {
MS_LOG(ERROR) << "Failed to create communication group " << host_global_group_name_ << " on host side.";
return false;
}
MS_LOG(INFO) << "Communication library on host side is successfully initialized. Global rank id: " << global_rank_id_ MS_LOG(INFO) << "Communication library on host side is successfully initialized. Global rank id: " << global_rank_id_
<< ", global rank size: " << global_rank_size_; << ", global rank size: " << global_rank_size_;
return true; return true;
} }
bool CollectiveManager::InitDeviceCommLib(const std::string &backend) { bool CollectiveManager::InitDeviceCommLib() {
std::string device_name; device::DeviceContextKey device_key = {device_type_, local_rank_id_};
if (backend == "nccl") {
device_name = "GPU";
} else if (backend == "hccl") {
device_name = "Ascend";
} else {
MS_LOG(ERROR) << "Backend " << backend << " is not supported.";
return false;
}
device::DeviceContextKey device_key = {device_name, local_rank_id_};
device_ctx_ = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext(device_key); device_ctx_ = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext(device_key);
MS_EXCEPTION_IF_NULL(device_ctx_); MS_EXCEPTION_IF_NULL(device_ctx_);
// We can initialize device context now because device id(local_rank_id_) is already assigned.
device_ctx_->Initialize();
if (!device_ctx_->LoadCollectiveCommLib()) { if (!device_ctx_->LoadCollectiveCommLib()) {
MS_LOG(ERROR) << "Failed to load communication library on the device side."; MS_LOG(ERROR) << "Failed to load communication library on the device side.";
return false; return false;
} }
device_comm_lib_ = device_ctx_->collective_comm_lib(); device_comm_lib_instance_ = device_ctx_->collective_comm_lib();
MS_EXCEPTION_IF_NULL(device_comm_lib_);
auto instance_func = DlsymFuncObj(communication_lib_instance, device_comm_lib_);
device_comm_lib_instance_ = instance_func();
MS_EXCEPTION_IF_NULL(device_comm_lib_instance_); MS_EXCEPTION_IF_NULL(device_comm_lib_instance_);
MS_LOG(INFO) << "Start initializing communication library on device side..."; MS_LOG(INFO) << "Start initializing communication library on device side...";
@ -235,7 +237,7 @@ bool CollectiveManager::InitDeviceCommLib(const std::string &backend) {
return true; return true;
} }
bool CollectiveManager::AssignLocalRank(const std::string &global_group_name) { bool CollectiveManager::AssignLocalRank() {
char host_name[MAX_HOSTNAME_LEN] = {0}; char host_name[MAX_HOSTNAME_LEN] = {0};
#ifndef _WIN32 #ifndef _WIN32
if (gethostname(host_name, MAX_HOSTNAME_LEN) != 0) { if (gethostname(host_name, MAX_HOSTNAME_LEN) != 0) {
@ -259,8 +261,8 @@ bool CollectiveManager::AssignLocalRank(const std::string &global_group_name) {
MS_EXCEPTION_IF_NULL(host_comm_lib_instance_); MS_EXCEPTION_IF_NULL(host_comm_lib_instance_);
// AllGather host names across the global communication group. // AllGather host names across the global communication group.
if (!host_comm_lib_instance_->AllGather(&host_hash, all_host_hashs, sizeof(size_t), TypeId::kNumberTypeInt, if (!host_comm_lib_instance_->AllGather(&host_hash, all_host_hashs, 1, TypeId::kNumberTypeInt,
global_group_name, nullptr)) { host_global_group_name_, nullptr)) {
MS_LOG(ERROR) << "AllGather for host names failed."; MS_LOG(ERROR) << "AllGather for host names failed.";
return false; return false;
} }
@ -274,7 +276,10 @@ bool CollectiveManager::AssignLocalRank(const std::string &global_group_name) {
local_rank_id_++; local_rank_id_++;
} }
} }
MS_LOG(INFO) << "The local rank id assigned for this process is " << local_rank_id_;
MsContext::GetInstance()->set_param<uint32_t>(MS_CTX_DEVICE_ID, local_rank_id_);
MS_LOG(INFO) << "The local rank id assigned for this process is " << local_rank_id_
<< ". device_id of ms_context is set.";
return true; return true;
} }
} // namespace collective } // namespace collective

View File

@ -43,8 +43,8 @@ class CollectiveManager {
DISABLE_COPY_AND_ASSIGN(CollectiveManager); DISABLE_COPY_AND_ASSIGN(CollectiveManager);
static std::shared_ptr<CollectiveManager> instance(); static std::shared_ptr<CollectiveManager> instance();
// Initialize the collective communication for distributed training with the backend name, e.g., NCCL or HCCL. // Initialize the collective communication for distributed training. The backend type is read from MindSpore context.
bool Initialize(const std::string &backend, const std::string &global_group_name); bool Initialize();
// Finalize the collective communication. // Finalize the collective communication.
bool Finalize(); bool Finalize();
@ -63,8 +63,10 @@ class CollectiveManager {
// In some cases global rank id and rank size should be set by caller, e.g., when using MindSpore communication // In some cases global rank id and rank size should be set by caller, e.g., when using MindSpore communication
// framework, they're generated by cluster::ClusterContext. // framework, they're generated by cluster::ClusterContext.
uint32_t set_global_rank_id(); void set_global_rank_id(uint32_t global_rank_id);
uint32_t set_global_rank_size(); void set_global_rank_size(uint32_t global_rank_size);
uint32_t local_rank_id() const;
private: private:
CollectiveManager(); CollectiveManager();
@ -73,14 +75,17 @@ class CollectiveManager {
bool InitHostCommlib(); bool InitHostCommlib();
// Initialize communication library on device side. // Initialize communication library on device side.
bool InitDeviceCommLib(const std::string &backend); bool InitDeviceCommLib();
// Assign the local rank id for this process. // Assign the local rank id for this process.
bool AssignLocalRank(const std::string &global_group_name); bool AssignLocalRank();
std::atomic_bool inited_; std::atomic_bool inited_;
std::atomic_bool finalized_; std::atomic_bool finalized_;
// The device type read from MindSpore context.
std::string device_type_;
// The device context on both host and device side. They are used to access the communication library on different // The device context on both host and device side. They are used to access the communication library on different
// devices. // devices.
DeviceContext *host_ctx_; DeviceContext *host_ctx_;
@ -88,12 +93,10 @@ class CollectiveManager {
// Host communication library refers to the communication library for CPU, e.g., OpenMPI and MindSpore communication // Host communication library refers to the communication library for CPU, e.g., OpenMPI and MindSpore communication
// framework. // framework.
void *host_comm_lib_;
CollectiveCommunicationLib *host_comm_lib_instance_; CollectiveCommunicationLib *host_comm_lib_instance_;
// Device communication library refers to the communication library for NPU or GPU, e.g., NCCL and HCCL. // Device communication library refers to the communication library for NPU or GPU, e.g., NCCL and HCCL.
// When only CPU backend is used, device communication library should not be initialized. // When only CPU backend is used, device communication library should not be initialized.
void *device_comm_lib_;
CollectiveCommunicationLib *device_comm_lib_instance_; CollectiveCommunicationLib *device_comm_lib_instance_;
// The global rank id of this process. Normally this range is 0 to `total process number - 1`. // The global rank id of this process. Normally this range is 0 to `total process number - 1`.
@ -107,6 +110,10 @@ class CollectiveManager {
// Global group ranks. // Global group ranks.
std::vector<uint32_t> global_group_ranks_; std::vector<uint32_t> global_group_ranks_;
// The global group name on the host side. This is used for creating global group on host side for AllGather operation
// of host name while assigning local rank.
std::string host_global_group_name_;
}; };
} // namespace collective } // namespace collective
} // namespace distributed } // namespace distributed

View File

@ -18,6 +18,7 @@
#define MINDSPORE_CCSRC_DISTRIBUTED_CONSTANTS_H_ #define MINDSPORE_CCSRC_DISTRIBUTED_CONSTANTS_H_
#include <set> #include <set>
#include <map>
#include <string> #include <string>
namespace mindspore { namespace mindspore {

View File

@ -20,16 +20,29 @@
namespace mindspore { namespace mindspore {
namespace distributed { namespace distributed {
bool Initialize(const std::string &backend, const std::string &global_group_name) { bool Initialize() {
if (!InitializeCluster()) { if (!InitializeCluster()) {
MS_LOG(ERROR) << "Failed to initialize cluster."; MS_LOG(ERROR) << "Failed to initialize cluster.";
return false; return false;
} }
if (!InitializeCollective(backend, global_group_name)) { #if ((defined ENABLE_CPU) && (!defined _WIN32))
MS_LOG(ERROR) << "Failed to initialize collective communication."; // Server and Scheduler don't use collective communication library.
return false; auto node = cluster::ClusterContext::instance()->node();
MS_EXCEPTION_IF_NULL(node);
if (node->role() != ps::core::NodeRole::SERVER && node->role() != ps::core::NodeRole::SCHEDULER) {
// Global rank id and size should be manually set if cluster is initialized by MindSpore communication framework.
auto abstract_node = std::dynamic_pointer_cast<ps::core::AbstractNode>(cluster::ClusterContext::instance()->node());
MS_EXCEPTION_IF_NULL(abstract_node);
collective::CollectiveManager::instance()->set_global_rank_id(abstract_node->rank_id());
collective::CollectiveManager::instance()->set_global_rank_size(abstract_node->worker_num());
if (!InitializeCollective()) {
MS_LOG(ERROR) << "Failed to initialize collective communication.";
return false;
}
} }
#endif
return true; return true;
} }
@ -51,9 +64,7 @@ bool InitializeCluster() { return cluster::ClusterContext::instance()->Initializ
bool FinalizeCluster() { return cluster::ClusterContext::instance()->Finalize(); } bool FinalizeCluster() { return cluster::ClusterContext::instance()->Finalize(); }
bool InitializeCollective(const std::string &backend, const std::string &global_group_name) { bool InitializeCollective() { return collective::CollectiveManager::instance()->Initialize(); }
return collective::CollectiveManager::instance()->Initialize(backend, global_group_name);
}
bool FinalizeCollective() { return collective::CollectiveManager::instance()->Finalize(); } bool FinalizeCollective() { return collective::CollectiveManager::instance()->Finalize(); }
} // namespace distributed } // namespace distributed

View File

@ -31,7 +31,7 @@ namespace distributed {
// The static methods of MindSpore distributed execution. They can be exported by Pybind. // The static methods of MindSpore distributed execution. They can be exported by Pybind.
// Initialize and finalize distributed execution. // Initialize and finalize distributed execution.
bool Initialize(const std::string &backend, const std::string &global_group_name); bool Initialize();
bool Finalize(); bool Finalize();
// Initialize and finalize the cluster based on MindSpore communication framework. // Initialize and finalize the cluster based on MindSpore communication framework.
@ -39,7 +39,7 @@ bool InitializeCluster();
bool FinalizeCluster(); bool FinalizeCluster();
// Initialize and finalize collective communication for distributed execution. // Initialize and finalize collective communication for distributed execution.
bool InitializeCollective(const std::string &backend, const std::string &global_group_name); bool InitializeCollective();
bool FinalizeCollective(); bool FinalizeCollective();
} // namespace distributed } // namespace distributed
} // namespace mindspore } // namespace mindspore

View File

@ -410,7 +410,8 @@ void PSContext::set_enable_ssl(bool enabled) { enable_ssl_ = enabled; }
core::ClusterConfig &PSContext::cluster_config() { core::ClusterConfig &PSContext::cluster_config() {
if (cluster_config_ == nullptr) { if (cluster_config_ == nullptr) {
MS_LOG(EXCEPTION) << "The cluster config is empty."; cluster_config_ = std::make_unique<core::ClusterConfig>(worker_num_, server_num_, scheduler_host_, scheduler_port_);
MS_EXCEPTION_IF_NULL(cluster_config_);
} }
return *cluster_config_; return *cluster_config_;
} }

View File

@ -16,6 +16,8 @@
#include "runtime/device/gpu/distribution/collective_init.h" #include "runtime/device/gpu/distribution/collective_init.h"
#include "utils/log_adapter.h" #include "utils/log_adapter.h"
#include "utils/ms_utils.h"
#include "distributed/init.h"
namespace mindspore { namespace mindspore {
namespace device { namespace device {
@ -30,23 +32,32 @@ bool CollectiveInitializer::collective_inited() const { return collective_inited
const void *CollectiveInitializer::collective_handle() { return collective_handle_; } const void *CollectiveInitializer::collective_handle() { return collective_handle_; }
void CollectiveInitializer::InitCollective() { void CollectiveInitializer::InitCollective() {
void *handle = dlopen("libgpu_collective.so", RTLD_LAZY); if (common::CheckUseMPI()) {
if (handle == nullptr) { void *handle = dlopen("libgpu_collective.so", RTLD_LAZY);
MS_LOG(EXCEPTION) if (handle == nullptr) {
<< "Loading libgpu_collective.so failed. Many reasons could cause this:\n" MS_LOG(EXCEPTION)
"1.libgpu_collective.so is not found, please check this MindSpore package is GPU version and built " << "Loading libgpu_collective.so failed. Many reasons could cause this:\n"
"with distributed feature.\n" "1.libgpu_collective.so is not found, please check this MindSpore package is GPU version and built "
"2.NCCL is not found or the user-installed NCCL version installed is incompatible: MindSpore " "with distributed feature.\n"
"requires NCCL-2.7.6.\n" "2.NCCL is not found or the user-installed NCCL version installed is incompatible: MindSpore "
"3.OpenMPI is not found or the user-installed OpenMPI version is incompatible: MindSpore " "requires NCCL-2.7.6.\n"
"requires OpenMPI-4.0.3.\n"; "3.OpenMPI is not found or the user-installed OpenMPI version is incompatible: MindSpore "
} "requires OpenMPI-4.0.3.\n";
auto mpi_init_funcptr = reinterpret_cast<InitMPI>(dlsym(handle, "InitMPI")); }
MS_EXCEPTION_IF_NULL(mpi_init_funcptr); auto mpi_init_funcptr = reinterpret_cast<InitMPI>(dlsym(handle, "InitMPI"));
(*mpi_init_funcptr)(); MS_EXCEPTION_IF_NULL(mpi_init_funcptr);
(*mpi_init_funcptr)();
CollectiveInitializer::instance().collective_inited_ = true; // Because this method InitCollective is static, the non-static member variables should be accessed by
CollectiveInitializer::instance().collective_handle_ = handle; // CollectiveInitializer::instance().
CollectiveInitializer::instance().use_mpi_ = true;
CollectiveInitializer::instance().collective_inited_ = true;
CollectiveInitializer::instance().collective_handle_ = handle;
} else {
if (!distributed::Initialize()) {
MS_LOG(EXCEPTION) << "Failed to initialize distributed execution for NCCL.";
}
}
} }
void CollectiveInitializer::FinalizeCollective() { void CollectiveInitializer::FinalizeCollective() {
@ -56,6 +67,69 @@ void CollectiveInitializer::FinalizeCollective() {
} }
} }
} }
// Returns the local rank id of this process.
// When MPI is in use, the value is obtained by calling the "local_rank_id" symbol resolved from
// collective_handle_; otherwise it is read from distributed::collective::CollectiveManager.
uint32_t CollectiveInitializer::local_rank_id() {
  if (!common::CheckUseMPI()) {
    return distributed::collective::CollectiveManager::instance()->local_rank_id();
  }

  // MPI path: look the function up in the dynamically loaded collective library.
  MS_EXCEPTION_IF_NULL(collective_handle_);
  auto get_local_rank_funcptr =
    reinterpret_cast<GetLocalRankId>(dlsym(const_cast<void *>(collective_handle_), "local_rank_id"));
  MS_EXCEPTION_IF_NULL(get_local_rank_funcptr);
  return IntToUint((*get_local_rank_funcptr)());
}
// Creates the communication group `group_name` consisting of `group_ranks`.
// Returns the result reported by the underlying implementation.
//
// The backend is chosen the same way as in InitCollective() and local_rank_id(): InitCollective() only loads
// the collective library (and therefore only sets collective_handle_) when MPI is used, and initializes the
// distributed module otherwise. So the dlsym path must be taken in the MPI case and CollectiveManager in the
// non-MPI case; the previous code had these two branches swapped.
bool CollectiveInitializer::CreateCommunicationGroup(const std::string &group_name,
                                                     const std::vector<uint32_t> &group_ranks) {
  if (common::CheckUseMPI()) {
    // MPI path: call "CreateCommGroup" from the dynamically loaded collective library.
    MS_EXCEPTION_IF_NULL(collective_handle_);
    auto create_comm_group_funcptr =
      reinterpret_cast<CreateCommGroupFunc>(dlsym(const_cast<void *>(collective_handle_), "CreateCommGroup"));
    MS_EXCEPTION_IF_NULL(create_comm_group_funcptr);
    return (*create_comm_group_funcptr)(group_name, group_ranks);
  } else {
    return distributed::collective::CollectiveManager::instance()->CreateCommunicationGroup(group_name, group_ranks);
  }
}
// Destroys the communication group `group_name`.
// Returns the result reported by the underlying implementation.
//
// InitCollective() only sets collective_handle_ when MPI is used and initializes the distributed module
// otherwise, so the dlsym path belongs to the MPI case and CollectiveManager to the non-MPI case; the
// previous code had these two branches swapped.
bool CollectiveInitializer::DestroyCommunicationGroup(const std::string &group_name) {
  if (common::CheckUseMPI()) {
    // MPI path: call "DestroyGroup" from the dynamically loaded collective library.
    MS_EXCEPTION_IF_NULL(collective_handle_);
    auto destroy_group_funcptr =
      reinterpret_cast<DestroyGroupFunc>(dlsym(const_cast<void *>(collective_handle_), "DestroyGroup"));
    MS_EXCEPTION_IF_NULL(destroy_group_funcptr);
    return (*destroy_group_funcptr)(group_name);
  } else {
    return distributed::collective::CollectiveManager::instance()->DestroyCommunicationGroup(group_name);
  }
}
// Returns the rank id of this process within the communication group `group_name`.
//
// InitCollective() only sets collective_handle_ when MPI is used and initializes the distributed module
// otherwise, so the dlsym path belongs to the MPI case and CollectiveManager to the non-MPI case; the
// previous code had these two branches swapped.
uint32_t CollectiveInitializer::GetRankIDByGroup(const std::string &group_name) {
  if (common::CheckUseMPI()) {
    // MPI path: call "GetRankIDByGroup" from the dynamically loaded collective library.
    MS_EXCEPTION_IF_NULL(collective_handle_);
    auto get_rank_id_funcptr =
      reinterpret_cast<GetRankIDByGroupFunc>(dlsym(const_cast<void *>(collective_handle_), "GetRankIDByGroup"));
    MS_EXCEPTION_IF_NULL(get_rank_id_funcptr);
    return IntToUint((*get_rank_id_funcptr)(group_name));
  } else {
    return distributed::collective::CollectiveManager::instance()->GetRankId(group_name);
  }
}
// Returns the number of ranks in the communication group `group_name`.
//
// InitCollective() only sets collective_handle_ when MPI is used and initializes the distributed module
// otherwise, so the dlsym path belongs to the MPI case and CollectiveManager to the non-MPI case; the
// previous code had these two branches swapped.
uint32_t CollectiveInitializer::GetGroupSize(const std::string &group_name) {
  if (common::CheckUseMPI()) {
    // MPI path: call "GetGroupSize" from the dynamically loaded collective library.
    MS_EXCEPTION_IF_NULL(collective_handle_);
    auto get_group_size_funcptr =
      reinterpret_cast<GetGroupSizeFunc>(dlsym(const_cast<void *>(collective_handle_), "GetGroupSize"));
    MS_EXCEPTION_IF_NULL(get_group_size_funcptr);
    return IntToUint((*get_group_size_funcptr)(group_name));
  } else {
    return distributed::collective::CollectiveManager::instance()->GetGroupSize(group_name);
  }
}
} // namespace gpu } // namespace gpu
} // namespace device } // namespace device
} // namespace mindspore } // namespace mindspore

View File

@ -42,12 +42,20 @@ class CollectiveInitializer {
static void InitCollective(); static void InitCollective();
static void FinalizeCollective(); static void FinalizeCollective();
// The encapsulation of the collective communication APIs for compatibility.
uint32_t local_rank_id();
bool CreateCommunicationGroup(const std::string &group_name, const std::vector<uint32_t> &group_ranks);
bool DestroyCommunicationGroup(const std::string &group_name);
uint32_t GetRankIDByGroup(const std::string &group_name);
uint32_t GetGroupSize(const std::string &group_name);
private: private:
CollectiveInitializer() : collective_inited_(false) {} CollectiveInitializer() : use_mpi_(false), collective_inited_(false), collective_handle_(nullptr) {}
~CollectiveInitializer() = default; ~CollectiveInitializer() = default;
bool use_mpi_;
bool collective_inited_; bool collective_inited_;
void *collective_handle_{nullptr}; void *collective_handle_;
}; };
} // namespace gpu } // namespace gpu
} // namespace device } // namespace device

View File

@ -23,6 +23,9 @@ endif()
if(ENABLE_CPU) if(ENABLE_CPU)
file(GLOB_RECURSE HARDWARE_CPU_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "cpu/*.cc") file(GLOB_RECURSE HARDWARE_CPU_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "cpu/*.cc")
list(REMOVE_ITEM HARDWARE_CPU_SRC_LIST "cpu/mpi_collective_comm_lib.cc" "cpu/mpi_communication_group.cc") list(REMOVE_ITEM HARDWARE_CPU_SRC_LIST "cpu/mpi_collective_comm_lib.cc" "cpu/mpi_communication_group.cc")
if(WIN32)
list(REMOVE_ITEM HARDWARE_CPU_SRC_LIST "cpu/ms_collective_comm_lib.cc")
endif()
if(ENABLE_MPI) if(ENABLE_MPI)
set(MPI_COLLECTIVE_SRCS "cpu/mpi_collective_comm_lib.cc" set(MPI_COLLECTIVE_SRCS "cpu/mpi_collective_comm_lib.cc"
"cpu/mpi_communication_group.cc" "cpu/mpi_communication_group.cc"

View File

@ -60,6 +60,8 @@ CommunicationGroupPtr CollectiveCommunicationLib::GetGroup(const std::string &gr
return groups_[group_name]; return groups_[group_name];
} }
// Returns a reference to the global group name stored in global_group_name_.
const std::string &CollectiveCommunicationLib::global_group_name() const {
  return global_group_name_;
}
uint32_t CollectiveCommunicationLib::global_rank_id() const { return global_rank_id_; } uint32_t CollectiveCommunicationLib::global_rank_id() const { return global_rank_id_; }
uint32_t CollectiveCommunicationLib::local_rank_id() const { return local_rank_id_; } uint32_t CollectiveCommunicationLib::local_rank_id() const { return local_rank_id_; }

View File

@ -77,6 +77,10 @@ class CollectiveCommunicationLib {
return true; return true;
} }
// Returns the global group name of this collective communication library. For NCCL, it's 'nccl_world_group'. For
// HCCL, it's 'hccl_world_group'.
const std::string &global_group_name() const;
// Returns global rank id of this process. // Returns global rank id of this process.
uint32_t global_rank_id() const; uint32_t global_rank_id() const;
@ -90,6 +94,9 @@ class CollectiveCommunicationLib {
// Whether this collective communication library is initialized. // Whether this collective communication library is initialized.
bool initialized_; bool initialized_;
// The global group name.
std::string global_group_name_;
// The global rank id of this process. Normally this range is 0 to `total process number - 1`. // The global rank id of this process. Normally this range is 0 to `total process number - 1`.
uint32_t global_rank_id_; uint32_t global_rank_id_;

View File

@ -47,5 +47,11 @@ uint32_t CommunicationGroup::GetGlobalRank(uint32_t group_rank) {
} }
uint32_t CommunicationGroup::group_size() const { return size_; } uint32_t CommunicationGroup::group_size() const { return size_; }
// Returns a reference to the rank list stored in group_ranks_.
const std::vector<uint32_t> &CommunicationGroup::group_ranks() const {
  return group_ranks_;
}
// Returns a reference to the map stored in global_to_group_ranks_.
const std::map<uint32_t, uint32_t> &CommunicationGroup::global_to_group_ranks() const {
  return global_to_group_ranks_;
}
// Returns a reference to the map stored in group_to_global_ranks_.
const std::map<uint32_t, uint32_t> &CommunicationGroup::group_to_global_ranks() const {
  return group_to_global_ranks_;
}
} // namespace device } // namespace device
} // namespace mindspore } // namespace mindspore

View File

@ -55,6 +55,11 @@ class CommunicationGroup {
// Return the size of this communication group. // Return the size of this communication group.
uint32_t group_size() const; uint32_t group_size() const;
// Return group ranks info.
const std::vector<uint32_t> &group_ranks() const;
const std::map<uint32_t, uint32_t> &global_to_group_ranks() const;
const std::map<uint32_t, uint32_t> &group_to_global_ranks() const;
protected: protected:
// Whether this communication group is initialized. // Whether this communication group is initialized.
bool initialized_; bool initialized_;

View File

@ -34,6 +34,9 @@
#include "backend/optimizer/graph_kernel/graph_kernel_optimization.h" #include "backend/optimizer/graph_kernel/graph_kernel_optimization.h"
#include "backend/session/anf_runtime_algorithm.h" #include "backend/session/anf_runtime_algorithm.h"
#include "profiler/device/cpu/cpu_profiling.h" #include "profiler/device/cpu/cpu_profiling.h"
#if ((defined ENABLE_CPU) && (!defined _WIN32))
#include "runtime/hardware/cpu/ms_collective_comm_lib.h"
#endif
#ifndef ENABLE_SECURITY #ifndef ENABLE_SECURITY
#include "debug/data_dump/dump_json_parser.h" #include "debug/data_dump/dump_json_parser.h"
#endif #endif
@ -296,6 +299,32 @@ bool CPUDeviceContext::LaunchKernel(const CNodePtr &kernel, const std::vector<Ad
return DoLaunchKernel(kernel_mod, inputs, workspace, outputs); return DoLaunchKernel(kernel_mod, inputs, workspace, outputs);
} }
// Selects the collective communication library for the CPU device context and stores it in
// collective_comm_lib_.
// - With MPI: loads "libmpi_collective.so" through CollectiveCommLibLoader and obtains the library instance
//   from its exported communication_lib_instance function.
// - Without MPI: uses MsCollectiveCommLib when built with ENABLE_CPU on a non-Windows platform; on other
//   builds collective_comm_lib_ is left untouched.
// Returns true on every path that does not raise an exception.
// NOTE(review): `loader` is a local object; whether the loaded library stays valid after it goes out of
// scope depends on CollectiveCommLibLoader, which is not visible here - confirm.
bool CPUDeviceContext::LoadCollectiveCommLib() {
  if (!common::CheckUseMPI()) {
#if ((defined ENABLE_CPU) && (!defined _WIN32))
    collective_comm_lib_ = &MsCollectiveCommLib::GetInstance();
    MS_EXCEPTION_IF_NULL(collective_comm_lib_);
#endif
    return true;
  }

  // MPI path: open the MPI collective library and fetch its singleton instance.
  std::string mpi_comm_lib_name = "libmpi_collective.so";
  auto loader = std::make_shared<CollectiveCommLibLoader>(mpi_comm_lib_name);
  MS_EXCEPTION_IF_NULL(loader);
  if (!loader->Initialize()) {
    MS_LOG(EXCEPTION) << "Failed to load mpi collective library.";
    return false;
  }

  void *collective_comm_lib_handle = loader->collective_comm_lib_ptr();
  MS_EXCEPTION_IF_NULL(collective_comm_lib_handle);

  auto instance_func = DlsymFuncObj(communication_lib_instance, collective_comm_lib_handle);
  collective_comm_lib_ = instance_func();
  MS_EXCEPTION_IF_NULL(collective_comm_lib_);
  return true;
}
bool CPUDeviceContext::LaunchKernelWithProfiling(const CNodePtr &kernel, const std::vector<AddressPtr> &inputs, bool CPUDeviceContext::LaunchKernelWithProfiling(const CNodePtr &kernel, const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &workspace, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) const { const std::vector<AddressPtr> &outputs) const {

View File

@ -57,6 +57,8 @@ class CPUDeviceContext : public DeviceContext {
const std::vector<AddressPtr> &workspace, const std::vector<AddressPtr> &outputs, const std::vector<AddressPtr> &workspace, const std::vector<AddressPtr> &outputs,
bool is_dynamic_shape = false) const override; bool is_dynamic_shape = false) const override;
bool LoadCollectiveCommLib() override;
private: private:
DISABLE_COPY_AND_ASSIGN(CPUDeviceContext); DISABLE_COPY_AND_ASSIGN(CPUDeviceContext);

View File

@ -19,6 +19,8 @@
namespace mindspore { namespace mindspore {
namespace device { namespace device {
namespace cpu { namespace cpu {
// Constructor: sets this library's global_group_name_ to kMPIGlobalGroupName.
MPICollectiveCommLib::MPICollectiveCommLib() { global_group_name_ = kMPIGlobalGroupName; }
bool MPICollectiveCommLib::Initialize(uint32_t, uint32_t) { bool MPICollectiveCommLib::Initialize(uint32_t, uint32_t) {
if (initialized_) { if (initialized_) {
return false; return false;
@ -56,49 +58,7 @@ bool MPICollectiveCommLib::CreateCommunicationGroup(const std::string &group_nam
} }
} // namespace cpu } // namespace cpu
// The exported APIs for 'dlsym' to load.
using MPICollectiveCommLib = mindspore::device::cpu::MPICollectiveCommLib; using MPICollectiveCommLib = mindspore::device::cpu::MPICollectiveCommLib;
CollectiveCommunicationLib *communication_lib_instance() { return &MPICollectiveCommLib::GetInstance(); } CollectiveCommunicationLib *communication_lib_instance() { return &MPICollectiveCommLib::GetInstance(); }
// Each function below forwards to the MPICollectiveCommLib singleton returned by GetInstance().

// Calls the singleton's Initialize() without arguments; both unnamed parameters are ignored.
bool InitializeCollectiveLib(uint32_t, uint32_t) { return MPICollectiveCommLib::GetInstance().Initialize(); }
// Calls the singleton's Finalize() and returns its result.
bool FinalizeCollectiveLib() { return MPICollectiveCommLib::GetInstance().Finalize(); }
// Forwards group_name and group_ranks to the singleton's CreateCommunicationGroup().
bool CreateCommunicationGroup(const std::string &group_name, const std::vector<uint32_t> &group_ranks) {
return MPICollectiveCommLib::GetInstance().CreateCommunicationGroup(group_name, group_ranks);
}
// Forwards group_name to the singleton's DestroyCommunicationGroup().
bool DestroyCommunicationGroup(const std::string &group_name) {
return MPICollectiveCommLib::GetInstance().DestroyCommunicationGroup(group_name);
}
// Returns the singleton's GetRankId() result for group_name.
uint32_t GetRankId(const std::string &group_name) { return MPICollectiveCommLib::GetInstance().GetRankId(group_name); }
// Returns the singleton's GetGroupSize() result for group_name.
// NOTE(review): the header declares this export as 'GetGroupSize', not 'GetCommunicationGroupSize' - confirm which
// name is intended.
uint32_t GetCommunicationGroupSize(const std::string &group_name) {
return MPICollectiveCommLib::GetInstance().GetGroupSize(group_name);
}
// Calls the singleton's AssignLocalRank() and returns its result.
bool AssignLocalRank() { return MPICollectiveCommLib::GetInstance().AssignLocalRank(); }
// Returns the singleton's GetGroup() result for group_name.
CommunicationGroupPtr GetGroup(const std::string &group_name) {
return MPICollectiveCommLib::GetInstance().GetGroup(group_name);
}
// Forwards all arguments unchanged to the singleton's AllGather().
bool AllGather(const void *send_buff, void *recv_buff, size_t send_count, TypeId data_type,
const std::string &group_name, void *stream) {
return MPICollectiveCommLib::GetInstance().AllGather(send_buff, recv_buff, send_count, data_type, group_name, stream);
}
// Forwards all arguments unchanged to the singleton's Broadcast().
bool Broadcast(const void *send_buff, void *recv_buff, size_t send_count, mindspore::TypeId data_type,
uint32_t root_rank, const std::string &group_name, void *stream) {
return MPICollectiveCommLib::GetInstance().Broadcast(send_buff, recv_buff, send_count, data_type, root_rank,
group_name, stream);
}
// Accessors that return the singleton's global_rank_id(), local_rank_id() and global_rank_size() values.
uint32_t global_rank_id() { return MPICollectiveCommLib::GetInstance().global_rank_id(); }
uint32_t local_rank_id() { return MPICollectiveCommLib::GetInstance().local_rank_id(); }
uint32_t global_rank_size() { return MPICollectiveCommLib::GetInstance().global_rank_size(); }
} // namespace device } // namespace device
} // namespace mindspore } // namespace mindspore

View File

@ -23,10 +23,15 @@
#include "runtime/hardware/collective/collective_communication_lib.h" #include "runtime/hardware/collective/collective_communication_lib.h"
#include "runtime/hardware/cpu/mpi_communication_group.h" #include "runtime/hardware/cpu/mpi_communication_group.h"
#ifndef EXPORT_MPI_WRAPPER
#define EXPORT_MPI_WRAPPER __attribute__((visibility("default")))
#endif
namespace mindspore { namespace mindspore {
namespace device { namespace device {
namespace cpu { namespace cpu {
class MPICollectiveCommLib : public CollectiveCommunicationLib { constexpr char kMPIGlobalGroupName[] = "mpi_world_group";
class EXPORT_MPI_WRAPPER MPICollectiveCommLib : public CollectiveCommunicationLib {
public: public:
static MPICollectiveCommLib &GetInstance() { static MPICollectiveCommLib &GetInstance() {
static MPICollectiveCommLib instance; static MPICollectiveCommLib instance;
@ -49,35 +54,14 @@ class MPICollectiveCommLib : public CollectiveCommunicationLib {
} }
private: private:
MPICollectiveCommLib() = default; MPICollectiveCommLib();
~MPICollectiveCommLib() override = default; ~MPICollectiveCommLib() override = default;
MPI_Group world_group_; MPI_Group world_group_;
}; };
} // namespace cpu } // namespace cpu
#ifndef EXPORT_MPI_WRAPPER
#define EXPORT_MPI_WRAPPER __attribute__((visibility("default")))
#endif
extern "C" EXPORT_MPI_WRAPPER CollectiveCommunicationLib *communication_lib_instance(); extern "C" EXPORT_MPI_WRAPPER CollectiveCommunicationLib *communication_lib_instance();
extern "C" EXPORT_MPI_WRAPPER bool InitializeCollectiveLib(uint32_t global_rank = UINT32_MAX,
uint32_t global_rank_size = UINT32_MAX);
extern "C" EXPORT_MPI_WRAPPER bool FinalizeCollectiveLib();
extern "C" EXPORT_MPI_WRAPPER bool CreateCommunicationGroup(const std::string &group_name,
const std::vector<uint32_t> &group_ranks);
extern "C" EXPORT_MPI_WRAPPER bool DestroyCommunicationGroup(const std::string &group_name);
extern "C" EXPORT_MPI_WRAPPER uint32_t GetRankId(const std::string &group_name);
extern "C" EXPORT_MPI_WRAPPER uint32_t GetGroupSize(const std::string &group_name);
extern "C" EXPORT_MPI_WRAPPER bool AssignLocalRank();
extern "C" EXPORT_MPI_WRAPPER CommunicationGroupPtr GetGroup(const std::string &group_name);
extern "C" EXPORT_MPI_WRAPPER bool AllGather(const void *send_buff, void *recv_buff, size_t send_count,
mindspore::TypeId data_type, const std::string &group_name, void *stream);
extern "C" EXPORT_MPI_WRAPPER bool Broadcast(const void *send_buff, void *recv_buff, size_t send_count,
mindspore::TypeId data_type, uint32_t root_rank,
const std::string &group_name, void *stream);
extern "C" EXPORT_MPI_WRAPPER uint32_t global_rank_id();
extern "C" EXPORT_MPI_WRAPPER uint32_t local_rank_id();
extern "C" EXPORT_MPI_WRAPPER uint32_t global_rank_size();
} // namespace device } // namespace device
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_CCSRC_RUNTIME_HARDWARE_CPU_MPI_COLLECTIVE_COMM_LIB_H_ #endif // MINDSPORE_CCSRC_RUNTIME_HARDWARE_CPU_MPI_COLLECTIVE_COMM_LIB_H_

View File

@ -19,6 +19,8 @@
namespace mindspore { namespace mindspore {
namespace device { namespace device {
namespace cpu { namespace cpu {
// Constructor: sets this library's global_group_name_ to kMSGlobalGroupName.
MsCollectiveCommLib::MsCollectiveCommLib() { global_group_name_ = kMSGlobalGroupName; }
bool MsCollectiveCommLib::Initialize(uint32_t global_rank, uint32_t global_rank_size) { bool MsCollectiveCommLib::Initialize(uint32_t global_rank, uint32_t global_rank_size) {
if (initialized_) { if (initialized_) {
return false; return false;
@ -30,8 +32,6 @@ bool MsCollectiveCommLib::Initialize(uint32_t global_rank, uint32_t global_rank_
return true; return true;
} }
// Finalize for MsCollectiveCommLib: performs no work and always returns true.
bool MsCollectiveCommLib::Finalize() { return true; }
bool MsCollectiveCommLib::CreateCommunicationGroup(const std::string &group_name, bool MsCollectiveCommLib::CreateCommunicationGroup(const std::string &group_name,
const std::vector<uint32_t> &group_ranks) { const std::vector<uint32_t> &group_ranks) {
CHECK_RET((groups_.count(group_name) == 0), true, "The group " + group_name + " has already existed."); CHECK_RET((groups_.count(group_name) == 0), true, "The group " + group_name + " has already existed.");

View File

@ -22,10 +22,17 @@
#include <string> #include <string>
#include "runtime/hardware/collective/collective_communication_lib.h" #include "runtime/hardware/collective/collective_communication_lib.h"
#include "runtime/hardware/cpu/ms_communication_group.h" #include "runtime/hardware/cpu/ms_communication_group.h"
#include "distributed/cluster/cluster_context.h"
#include "fl/server/collective_ops_impl.h"
namespace mindspore { namespace mindspore {
namespace device { namespace device {
namespace cpu { namespace cpu {
// Group name that MsCollectiveCommLib's constructor stores in global_group_name_.
constexpr char kMSGlobalGroupName[] = "ms_world_group";
// Short aliases for types declared in the distributed::cluster and fl::server namespaces.
using ClusterContext = mindspore::distributed::cluster::ClusterContext;
using CollectiveOpsImpl = mindspore::fl::server::CollectiveOpsImpl;
using CommunicationGroupInfo = mindspore::fl::server::CommunicationGroupInfo;
// The collective communication library for MindSpore self developed communication framework. // The collective communication library for MindSpore self developed communication framework.
class MsCollectiveCommLib : public CollectiveCommunicationLib { class MsCollectiveCommLib : public CollectiveCommunicationLib {
public: public:
@ -35,12 +42,11 @@ class MsCollectiveCommLib : public CollectiveCommunicationLib {
} }
bool Initialize(uint32_t global_rank = UINT32_MAX, uint32_t global_rank_size = UINT32_MAX) override; bool Initialize(uint32_t global_rank = UINT32_MAX, uint32_t global_rank_size = UINT32_MAX) override;
bool Finalize() override;
bool CreateCommunicationGroup(const std::string &group_name, const std::vector<uint32_t> &group_ranks) override; bool CreateCommunicationGroup(const std::string &group_name, const std::vector<uint32_t> &group_ranks) override;
private: private:
MsCollectiveCommLib() {} MsCollectiveCommLib();
~MsCollectiveCommLib() override = default; ~MsCollectiveCommLib() override = default;
}; };
} // namespace cpu } // namespace cpu

View File

@ -50,7 +50,7 @@ struct DeviceContextKey {
class DeviceContext { class DeviceContext {
public: public:
explicit DeviceContext(const DeviceContextKey &device_context_key) explicit DeviceContext(const DeviceContextKey &device_context_key)
: device_context_key_(device_context_key), collective_comm_lib_ptr_(nullptr) {} : device_context_key_(device_context_key), collective_comm_lib_(nullptr) {}
virtual ~DeviceContext() = default; virtual ~DeviceContext() = default;
// Initialize the device context. // Initialize the device context.
@ -150,7 +150,7 @@ class DeviceContext {
virtual bool LoadCollectiveCommLib() { return true; } virtual bool LoadCollectiveCommLib() { return true; }
// Return collective communication object for caller to access // Return collective communication object for caller to access
void *collective_comm_lib() const { return collective_comm_lib_ptr_; } CollectiveCommunicationLib *collective_comm_lib() const { return collective_comm_lib_; }
// TODO(jiaorui): will be delete // TODO(jiaorui): will be delete
// Dump all graphs. // Dump all graphs.
@ -159,8 +159,8 @@ class DeviceContext {
protected: protected:
DeviceContextKey device_context_key_; DeviceContextKey device_context_key_;
// The dynamically loaded handle for collective communication library by 'dlopen'. // The collective communication library.
void *collective_comm_lib_ptr_; CollectiveCommunicationLib *collective_comm_lib_;
}; };
using DeviceContextPtr = std::shared_ptr<DeviceContext>; using DeviceContextPtr = std::shared_ptr<DeviceContext>;
} // namespace device } // namespace device

View File

@ -534,8 +534,12 @@ bool GPUDeviceContext::LoadCollectiveCommLib() {
MS_LOG(EXCEPTION) << "Loading NCCL collective library failed."; MS_LOG(EXCEPTION) << "Loading NCCL collective library failed.";
return false; return false;
} }
collective_comm_lib_ptr_ = loader->collective_comm_lib_ptr(); void *collective_comm_lib_handle = loader->collective_comm_lib_ptr();
MS_EXCEPTION_IF_NULL(collective_comm_lib_ptr_); MS_EXCEPTION_IF_NULL(collective_comm_lib_handle);
auto instance_func = DlsymFuncObj(communication_lib_instance, collective_comm_lib_handle);
collective_comm_lib_ = instance_func();
MS_EXCEPTION_IF_NULL(collective_comm_lib_);
return true; return true;
#else #else
return false; return false;

View File

@ -19,6 +19,8 @@
namespace mindspore { namespace mindspore {
namespace device { namespace device {
namespace gpu { namespace gpu {
// Constructor: sets this library's global_group_name_ to kNCCLGlobalGroupName.
NvidiaCollectiveCommLib::NvidiaCollectiveCommLib() { global_group_name_ = kNCCLGlobalGroupName; }
bool NvidiaCollectiveCommLib::Initialize(uint32_t global_rank, uint32_t global_rank_size) { bool NvidiaCollectiveCommLib::Initialize(uint32_t global_rank, uint32_t global_rank_size) {
if (initialized_) { if (initialized_) {
return false; return false;
@ -42,50 +44,7 @@ bool NvidiaCollectiveCommLib::CreateCommunicationGroup(const std::string &group_
} }
} // namespace gpu } // namespace gpu
// The exported APIs for 'dlsym' to load.
using NvidiaCollectiveCommLib = mindspore::device::gpu::NvidiaCollectiveCommLib; using NvidiaCollectiveCommLib = mindspore::device::gpu::NvidiaCollectiveCommLib;
CollectiveCommunicationLib *communication_lib_instance() { return &NvidiaCollectiveCommLib::GetInstance(); } CollectiveCommunicationLib *communication_lib_instance() { return &NvidiaCollectiveCommLib::GetInstance(); }
// Each function below forwards to the NvidiaCollectiveCommLib singleton returned by GetInstance().

// Passes global_rank and global_rank_size to the singleton's Initialize().
bool InitializeCollectiveLib(uint32_t global_rank, uint32_t global_rank_size) {
return NvidiaCollectiveCommLib::GetInstance().Initialize(global_rank, global_rank_size);
}
// Calls the singleton's Finalize() and returns its result.
bool FinalizeCollectiveLib() { return NvidiaCollectiveCommLib::GetInstance().Finalize(); }
// Forwards group_name and group_ranks to the singleton's CreateCommunicationGroup().
bool CreateCommunicationGroup(const std::string &group_name, const std::vector<uint32_t> &group_ranks) {
return NvidiaCollectiveCommLib::GetInstance().CreateCommunicationGroup(group_name, group_ranks);
}
// Forwards group_name to the singleton's DestroyCommunicationGroup().
bool DestroyCommunicationGroup(const std::string &group_name) {
return NvidiaCollectiveCommLib::GetInstance().DestroyCommunicationGroup(group_name);
}
// Returns the singleton's GetRankId() result for group_name.
uint32_t GetRankId(const std::string &group_name) {
return NvidiaCollectiveCommLib::GetInstance().GetRankId(group_name);
}
// Returns the singleton's GetGroupSize() result for group_name.
uint32_t GetCommunicationGroupSize(const std::string &group_name) {
return NvidiaCollectiveCommLib::GetInstance().GetGroupSize(group_name);
}
// Calls the singleton's AssignLocalRank() and returns its result.
bool AssignLocalRank() { return NvidiaCollectiveCommLib::GetInstance().AssignLocalRank(); }
// Returns the singleton's GetGroup() result for group_name.
CommunicationGroupPtr GetGroup(const std::string &group_name) {
return NvidiaCollectiveCommLib::GetInstance().GetGroup(group_name);
}
// Forwards all arguments unchanged to the singleton's AllGather().
bool AllGather(const void *send_buff, void *recv_buff, size_t send_count, mindspore::TypeId data_type,
const std::string &group_name, void *stream) {
return NvidiaCollectiveCommLib::GetInstance().AllGather(send_buff, recv_buff, send_count, data_type, group_name,
stream);
}
// Forwards all arguments unchanged to the singleton's Broadcast().
bool Broadcast(const void *send_buff, void *recv_buff, size_t send_count, mindspore::TypeId data_type,
uint32_t root_rank, const std::string &group_name, void *stream) {
return NvidiaCollectiveCommLib::GetInstance().Broadcast(send_buff, recv_buff, send_count, data_type, root_rank,
group_name, stream);
}
// Returns the singleton's local_rank_id() value.
uint32_t local_rank_id() { return NvidiaCollectiveCommLib::GetInstance().local_rank_id(); }
} // namespace device } // namespace device
} // namespace mindspore } // namespace mindspore

View File

@ -24,11 +24,15 @@
#include "runtime/hardware/collective/collective_communication_lib.h" #include "runtime/hardware/collective/collective_communication_lib.h"
#include "runtime/hardware/gpu/nvidia_communication_group.h" #include "runtime/hardware/gpu/nvidia_communication_group.h"
#ifndef EXPORT_NCCL_WRAPPER
#define EXPORT_NCCL_WRAPPER __attribute__((visibility("default")))
#endif
namespace mindspore { namespace mindspore {
namespace device { namespace device {
namespace gpu { namespace gpu {
constexpr char NCCL_WORLD_GROUP[] = "nccl_world_group"; constexpr char kNCCLGlobalGroupName[] = "nccl_world_group";
class NvidiaCollectiveCommLib : public CollectiveCommunicationLib { class EXPORT_NCCL_WRAPPER NvidiaCollectiveCommLib : public CollectiveCommunicationLib {
public: public:
static NvidiaCollectiveCommLib &GetInstance() { static NvidiaCollectiveCommLib &GetInstance() {
static NvidiaCollectiveCommLib instance; static NvidiaCollectiveCommLib instance;
@ -50,31 +54,12 @@ class NvidiaCollectiveCommLib : public CollectiveCommunicationLib {
} }
private: private:
NvidiaCollectiveCommLib() = default; NvidiaCollectiveCommLib();
~NvidiaCollectiveCommLib() override = default; ~NvidiaCollectiveCommLib() override = default;
}; };
} // namespace gpu } // namespace gpu
#ifndef EXPORT_NCCL_WRAPPER
#define EXPORT_NCCL_WRAPPER __attribute__((visibility("default")))
#endif
extern "C" EXPORT_NCCL_WRAPPER CollectiveCommunicationLib *communication_lib_instance(); extern "C" EXPORT_NCCL_WRAPPER CollectiveCommunicationLib *communication_lib_instance();
extern "C" EXPORT_NCCL_WRAPPER bool InitializeCollectiveLib(uint32_t global_rank = UINT32_MAX,
uint32_t global_rank_size = UINT32_MAX);
extern "C" EXPORT_NCCL_WRAPPER bool FinalizeCollectiveLib();
extern "C" EXPORT_NCCL_WRAPPER bool CreateCommunicationGroup(const std::string &group_name,
const std::vector<uint32_t> &group_ranks);
extern "C" EXPORT_NCCL_WRAPPER bool DestroyCommunicationGroup(const std::string &group_name);
extern "C" EXPORT_NCCL_WRAPPER uint32_t GetRankId(const std::string &group_name);
extern "C" EXPORT_NCCL_WRAPPER uint32_t GetCommunicationGroupSize(const std::string &group_name);
extern "C" EXPORT_NCCL_WRAPPER bool AssignLocalRank();
extern "C" EXPORT_NCCL_WRAPPER CommunicationGroupPtr GetGroup(const std::string &group_name);
extern "C" EXPORT_NCCL_WRAPPER bool AllGather(const void *send_buff, void *recv_buff, size_t send_count,
mindspore::TypeId data_type, const std::string &group_name, void *stream);
extern "C" EXPORT_NCCL_WRAPPER bool Broadcast(const void *send_buff, void *recv_buff, size_t send_count,
mindspore::TypeId data_type, uint32_t root_rank,
const std::string &group_name, void *stream);
extern "C" EXPORT_NCCL_WRAPPER uint32_t local_rank_id();
} // namespace device } // namespace device
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_CCSRC_RUNTIME_HARDWARE_CPU_NVIDIA_COLLECTIVE_COMM_LIB_H_ #endif // MINDSPORE_CCSRC_RUNTIME_HARDWARE_CPU_NVIDIA_COLLECTIVE_COMM_LIB_H_

View File

@ -90,6 +90,16 @@ static inline bool IsLittleByteOrder() {
return false; return false;
} }
// Report whether this process is treated as having been launched by OpenMPI.
// The decision is based on two OpenMPI environment variables, OMPI_COMMAND and PMIX_RANK: the result is true only
// when GetEnv yields a non-empty value for both of them.
static inline bool CheckUseMPI() {
  const std::string ompi_command = GetEnv("OMPI_COMMAND");
  const std::string pmix_rank = GetEnv("PMIX_RANK");
  const bool any_missing = ompi_command.empty() || pmix_rank.empty();
  return !any_missing;
}
template <typename T> template <typename T>
bool IsEqual(const std::shared_ptr<T> &a, const std::shared_ptr<T> &b) { bool IsEqual(const std::shared_ptr<T> &a, const std::shared_ptr<T> &b) {
if (a == b) { if (a == b) {