From 2b7429c5d2b0dbc2930b6ee046fbcc4592ddf7a6 Mon Sep 17 00:00:00 2001 From: ZPaC Date: Sat, 27 Nov 2021 10:57:36 +0800 Subject: [PATCH] 1.Purge not used API. 2.Adapt for collective_init.h --- .../distributed/cluster/cluster_context.cc | 2 +- .../distributed/cluster/cluster_context.h | 3 +- .../cluster/dummy_cluster_context.cc | 2 - .../cluster/dummy_cluster_context.h | 1 - .../collective/collective_manager.cc | 71 ++++++------ .../collective/collective_manager.h | 23 ++-- mindspore/ccsrc/distributed/constants.h | 1 + mindspore/ccsrc/distributed/init.cc | 25 +++-- mindspore/ccsrc/distributed/init.h | 4 +- mindspore/ccsrc/ps/ps_context.cc | 3 +- .../gpu/distribution/collective_init.cc | 106 +++++++++++++++--- .../device/gpu/distribution/collective_init.h | 12 +- .../ccsrc/runtime/hardware/CMakeLists.txt | 3 + .../collective_communication_lib.cc | 2 + .../collective/collective_communication_lib.h | 7 ++ .../collective/communication_group.cc | 6 + .../hardware/collective/communication_group.h | 5 + .../hardware/cpu/cpu_device_context.cc | 29 +++++ .../runtime/hardware/cpu/cpu_device_context.h | 2 + .../hardware/cpu/mpi_collective_comm_lib.cc | 44 +------- .../hardware/cpu/mpi_collective_comm_lib.h | 30 ++--- .../hardware/cpu/ms_collective_comm_lib.cc | 4 +- .../hardware/cpu/ms_collective_comm_lib.h | 10 +- .../ccsrc/runtime/hardware/device_context.h | 8 +- .../hardware/gpu/gpu_device_context.cc | 8 +- .../gpu/nvidia_collective_comm_lib.cc | 45 +------- .../hardware/gpu/nvidia_collective_comm_lib.h | 29 ++--- mindspore/core/utils/ms_utils.h | 10 ++ 28 files changed, 281 insertions(+), 214 deletions(-) diff --git a/mindspore/ccsrc/distributed/cluster/cluster_context.cc b/mindspore/ccsrc/distributed/cluster/cluster_context.cc index 0dad24fe23f..349206fd3fa 100644 --- a/mindspore/ccsrc/distributed/cluster/cluster_context.cc +++ b/mindspore/ccsrc/distributed/cluster/cluster_context.cc @@ -85,7 +85,7 @@ bool ClusterContext::Finalize() { return true; } -std::string ClusterContext::node_role() const { return node_role_; } +const std::shared_ptr &ClusterContext::node() const { return node_; } void ClusterContext::InitClusterConfig() { InitNodeRole(); diff --git a/mindspore/ccsrc/distributed/cluster/cluster_context.h b/mindspore/ccsrc/distributed/cluster/cluster_context.h index f510617d304..1968c257571 100644 --- a/mindspore/ccsrc/distributed/cluster/cluster_context.h +++ b/mindspore/ccsrc/distributed/cluster/cluster_context.h @@ -49,7 +49,8 @@ class ClusterContext { // Finalize the cluster and process exits. bool Finalize(); - std::string node_role() const; + // Return node object of this process. 
+ const std::shared_ptr &node() const; private: ClusterContext(); diff --git a/mindspore/ccsrc/distributed/cluster/dummy_cluster_context.cc b/mindspore/ccsrc/distributed/cluster/dummy_cluster_context.cc index f76324616bd..0fff74f946f 100644 --- a/mindspore/ccsrc/distributed/cluster/dummy_cluster_context.cc +++ b/mindspore/ccsrc/distributed/cluster/dummy_cluster_context.cc @@ -32,8 +32,6 @@ std::shared_ptr ClusterContext::instance() { bool ClusterContext::Initialize() const { return true; } bool ClusterContext::Finalize() const { return true; } - -std::string ClusterContext::node_role() const { return ""; } } // namespace cluster } // namespace distributed } // namespace mindspore diff --git a/mindspore/ccsrc/distributed/cluster/dummy_cluster_context.h b/mindspore/ccsrc/distributed/cluster/dummy_cluster_context.h index 8e41ea9a607..51c8eb8f547 100644 --- a/mindspore/ccsrc/distributed/cluster/dummy_cluster_context.h +++ b/mindspore/ccsrc/distributed/cluster/dummy_cluster_context.h @@ -39,7 +39,6 @@ class ClusterContext { bool Initialize() const; bool Finalize() const; - std::string node_role() const; private: ClusterContext() = default; diff --git a/mindspore/ccsrc/distributed/collective/collective_manager.cc b/mindspore/ccsrc/distributed/collective/collective_manager.cc index fab23e69346..1f99b41371b 100644 --- a/mindspore/ccsrc/distributed/collective/collective_manager.cc +++ b/mindspore/ccsrc/distributed/collective/collective_manager.cc @@ -18,6 +18,7 @@ #include #include #include +#include "utils/ms_context.h" namespace mindspore { namespace distributed { @@ -27,9 +28,7 @@ CollectiveManager::CollectiveManager() finalized_(true), host_ctx_(nullptr), device_ctx_(nullptr), - host_comm_lib_(nullptr), host_comm_lib_instance_(nullptr), - device_comm_lib_(nullptr), device_comm_lib_instance_(nullptr), global_rank_id_(0), local_rank_id_(0), @@ -51,11 +50,13 @@ std::shared_ptr CollectiveManager::instance() { return instance; } -bool CollectiveManager::Initialize(const std::string &backend, const std::string &global_group_name) { +bool CollectiveManager::Initialize() { if (inited_) { return true; } - MS_LOG(INFO) << "Start initializing collective communication for backend: " << backend << "..."; + + device_type_ = MsContext::GetInstance()->get_param(MS_CTX_DEVICE_TARGET); + MS_LOG(INFO) << "Start initializing collective communication for backend: " << device_type_ << "..."; // Step 1: Initialize host side collective communication. if (!InitHostCommlib()) { @@ -66,24 +67,25 @@ bool CollectiveManager::Initialize(const std::string &backend, const std::string // Step 2, 3 and 4 are for device communication library. So if the training job is only launched on CPU, they will not // be necessary. // Step 2: Assign local rank id(device id) for this process. - if (!AssignLocalRank(global_group_name)) { + if (!AssignLocalRank()) { MS_LOG(ERROR) << "Failed to assign local rank id."; return false; } // Step 3: Initialize device side collective communication. - if (!InitDeviceCommLib(backend)) { + if (!InitDeviceCommLib()) { MS_LOG(ERROR) << "Failed to initialize device communication library."; return false; } // Step 4: Create global communication group. 
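// Illustrative sketch (not part of the patch): how the backend string read a few lines above is
// assumed to be obtained. The template argument of get_param is stripped in this listing; the
// device target is a string, so get_param<std::string> is the assumed form.
#include <string>
#include "utils/ms_context.h"

std::string GetCollectiveBackendFromContext() {
  auto ms_context = mindspore::MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  // Returns "GPU", "Ascend" or "CPU"; CollectiveManager uses it to pick the device-side library.
  return ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
}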
- if (!CreateCommunicationGroup(global_group_name, global_group_ranks_)) { + MS_EXCEPTION_IF_NULL(device_comm_lib_instance_); + if (!CreateCommunicationGroup(device_comm_lib_instance_->global_group_name(), global_group_ranks_)) { MS_LOG(ERROR) << "Failed to initialize host communication library."; return false; } - MS_LOG(INFO) << "End initializing collective communication for backend: " << backend << "."; + MS_LOG(INFO) << "End initializing collective communication for backend: " << device_type_; return true; } @@ -91,7 +93,7 @@ bool CollectiveManager::CreateCommunicationGroup(const std::string &group_name, const std::vector &group_ranks) { MS_EXCEPTION_IF_NULL(host_comm_lib_instance_); MS_EXCEPTION_IF_NULL(device_comm_lib_instance_); - // Step 1: Create communication group on host side. + // Step 1: Create communication group on host side if. if (!host_comm_lib_instance_->CreateCommunicationGroup(group_name, group_ranks)) { MS_LOG(ERROR) << "Failed to create communication group " << group_name << " on host side."; return false; @@ -167,6 +169,12 @@ bool CollectiveManager::Finalize() { return true; } +void CollectiveManager::set_global_rank_id(uint32_t global_rank_id) { global_rank_id_ = global_rank_id; } + +void CollectiveManager::set_global_rank_size(uint32_t global_rank_size) { global_rank_size_ = global_rank_size; } + +uint32_t CollectiveManager::local_rank_id() const { return local_rank_id_; } + bool CollectiveManager::InitHostCommlib() { device::DeviceContextKey host_key = {"CPU", 0}; host_ctx_ = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext(host_key); @@ -175,10 +183,7 @@ bool CollectiveManager::InitHostCommlib() { MS_LOG(ERROR) << "Failed to load communication library on the host side."; return false; } - host_comm_lib_ = host_ctx_->collective_comm_lib(); - MS_EXCEPTION_IF_NULL(host_comm_lib_); - auto instance_func = DlsymFuncObj(communication_lib_instance, host_comm_lib_); - host_comm_lib_instance_ = instance_func(); + host_comm_lib_instance_ = host_ctx_->collective_comm_lib(); MS_EXCEPTION_IF_NULL(host_comm_lib_instance_); // For some communication libraries, global_rank_id_', 'global_rank_size_' should be set by caller, e.g., when using @@ -197,33 +202,30 @@ bool CollectiveManager::InitHostCommlib() { global_group_ranks_.push_back(i); } + // Create world group on host side for AllGather operation of host name while assigning local rank. + host_global_group_name_ = host_comm_lib_instance_->global_group_name(); + if (!host_comm_lib_instance_->CreateCommunicationGroup(host_global_group_name_, global_group_ranks_)) { + MS_LOG(ERROR) << "Failed to create communication group " << host_global_group_name_ << " on host side."; + return false; + } + MS_LOG(INFO) << "Communication library on host side is successfully initialized. 
Global rank id: " << global_rank_id_ << ", global rank size: " << global_rank_size_; return true; } -bool CollectiveManager::InitDeviceCommLib(const std::string &backend) { - std::string device_name; - if (backend == "nccl") { - device_name = "GPU"; - } else if (backend == "hccl") { - device_name = "Ascend"; - } else { - MS_LOG(ERROR) << "Backend " << backend << " is not supported."; - return false; - } - - device::DeviceContextKey device_key = {device_name, local_rank_id_}; +bool CollectiveManager::InitDeviceCommLib() { + device::DeviceContextKey device_key = {device_type_, local_rank_id_}; device_ctx_ = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext(device_key); MS_EXCEPTION_IF_NULL(device_ctx_); + // We can initialize device context now because device id(local_rank_id_) is already assigned. + device_ctx_->Initialize(); + if (!device_ctx_->LoadCollectiveCommLib()) { MS_LOG(ERROR) << "Failed to load communication library on the device side."; return false; } - device_comm_lib_ = device_ctx_->collective_comm_lib(); - MS_EXCEPTION_IF_NULL(device_comm_lib_); - auto instance_func = DlsymFuncObj(communication_lib_instance, device_comm_lib_); - device_comm_lib_instance_ = instance_func(); + device_comm_lib_instance_ = device_ctx_->collective_comm_lib(); MS_EXCEPTION_IF_NULL(device_comm_lib_instance_); MS_LOG(INFO) << "Start initializing communication library on device side..."; @@ -235,7 +237,7 @@ bool CollectiveManager::InitDeviceCommLib(const std::string &backend) { return true; } -bool CollectiveManager::AssignLocalRank(const std::string &global_group_name) { +bool CollectiveManager::AssignLocalRank() { char host_name[MAX_HOSTNAME_LEN] = {0}; #ifndef _WIN32 if (gethostname(host_name, MAX_HOSTNAME_LEN) != 0) { @@ -259,8 +261,8 @@ bool CollectiveManager::AssignLocalRank(const std::string &global_group_name) { MS_EXCEPTION_IF_NULL(host_comm_lib_instance_); // AllGather host names across the global communication group. - if (!host_comm_lib_instance_->AllGather(&host_hash, all_host_hashs, sizeof(size_t), TypeId::kNumberTypeInt, - global_group_name, nullptr)) { + if (!host_comm_lib_instance_->AllGather(&host_hash, all_host_hashs, 1, TypeId::kNumberTypeInt, + host_global_group_name_, nullptr)) { MS_LOG(ERROR) << "AllGather for host names failed."; return false; } @@ -274,7 +276,10 @@ bool CollectiveManager::AssignLocalRank(const std::string &global_group_name) { local_rank_id_++; } } - MS_LOG(INFO) << "The local rank id assigned for this process is " << local_rank_id_; + + MsContext::GetInstance()->set_param(MS_CTX_DEVICE_ID, local_rank_id_); + MS_LOG(INFO) << "The local rank id assigned for this process is " << local_rank_id_ + << ". device_id of ms_context is set."; return true; } } // namespace collective diff --git a/mindspore/ccsrc/distributed/collective/collective_manager.h b/mindspore/ccsrc/distributed/collective/collective_manager.h index 9fa9de2f956..5688745a1b3 100644 --- a/mindspore/ccsrc/distributed/collective/collective_manager.h +++ b/mindspore/ccsrc/distributed/collective/collective_manager.h @@ -43,8 +43,8 @@ class CollectiveManager { DISABLE_COPY_AND_ASSIGN(CollectiveManager); static std::shared_ptr instance(); - // Initialize the collective communication for distributed training with the backend name, e.g., NCCL or HCCL. - bool Initialize(const std::string &backend, const std::string &global_group_name); + // Initialize the collective communication for distributed training. The backend type is read from MindSpore context. 
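// Illustrative sketch of the local-rank assignment performed by AssignLocalRank above: each
// process hashes its host name, the hashes are exchanged via AllGather over the host-side world
// group, and the local rank is the number of lower-global-rank processes sharing the same host
// hash. The helper below is hypothetical, not patch code.
#include <cstdint>
#include <functional>
#include <string>
#include <vector>

uint32_t ComputeLocalRank(const std::vector<std::string> &all_host_names, uint32_t my_global_rank) {
  std::vector<size_t> host_hashes;
  host_hashes.reserve(all_host_names.size());
  for (const auto &name : all_host_names) {
    host_hashes.push_back(std::hash<std::string>()(name));
  }
  uint32_t local_rank = 0;
  for (uint32_t rank = 0; rank < my_global_rank; ++rank) {
    if (host_hashes[rank] == host_hashes[my_global_rank]) {
      ++local_rank;  // another process with a smaller global rank runs on the same host
    }
  }
  return local_rank;
}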
+ bool Initialize(); // Finalize the collective communication. bool Finalize(); @@ -63,8 +63,10 @@ class CollectiveManager { // In some cases global rank id and rank size should be set by caller, e.g., when using MindSpore communication // framework, they're generated by cluster::ClusterContext. - uint32_t set_global_rank_id(); - uint32_t set_global_rank_size(); + void set_global_rank_id(uint32_t global_rank_id); + void set_global_rank_size(uint32_t global_rank_size); + + uint32_t local_rank_id() const; private: CollectiveManager(); @@ -73,14 +75,17 @@ class CollectiveManager { bool InitHostCommlib(); // Initialize communication library on device side. - bool InitDeviceCommLib(const std::string &backend); + bool InitDeviceCommLib(); // Assign the local rank id for this process. - bool AssignLocalRank(const std::string &global_group_name); + bool AssignLocalRank(); std::atomic_bool inited_; std::atomic_bool finalized_; + // The device type read from MindSpore context. + std::string device_type_; + // The device context on both host and device side. They are used to access the communication library on different // devices. DeviceContext *host_ctx_; @@ -88,12 +93,10 @@ class CollectiveManager { // Host communication library refers to the communication libaray for CPU, e.g., OpenMPI and MindSpore communication // framework. - void *host_comm_lib_; CollectiveCommunicationLib *host_comm_lib_instance_; // Device communication library refers to the communication libaray for NPU or GPU, e.g., NCCL and HCCL. // When only CPU backend is used, device communication library should not be initialized. - void *device_comm_lib_; CollectiveCommunicationLib *device_comm_lib_instance_; // The global rank id of this process. Normally this range is 0 to `total process number - 1`. @@ -107,6 +110,10 @@ class CollectiveManager { // Global group ranks. std::vector global_group_ranks_; + + // The global group name on the host side. This is used for Creating global group on host side for AllGather operation + // of host name while assigning local rank. + std::string host_global_group_name_; }; } // namespace collective } // namespace distributed diff --git a/mindspore/ccsrc/distributed/constants.h b/mindspore/ccsrc/distributed/constants.h index e23c7b338be..bf92fba4fc0 100644 --- a/mindspore/ccsrc/distributed/constants.h +++ b/mindspore/ccsrc/distributed/constants.h @@ -18,6 +18,7 @@ #define MINDSPORE_CCSRC_DISTRIBUTED_CONSTANTS_H_ #include +#include #include namespace mindspore { diff --git a/mindspore/ccsrc/distributed/init.cc b/mindspore/ccsrc/distributed/init.cc index aea5c0cb479..63a89109a8d 100644 --- a/mindspore/ccsrc/distributed/init.cc +++ b/mindspore/ccsrc/distributed/init.cc @@ -20,16 +20,29 @@ namespace mindspore { namespace distributed { -bool Initialize(const std::string &backend, const std::string &global_group_name) { +bool Initialize() { if (!InitializeCluster()) { MS_LOG(ERROR) << "Failed to initialize cluster."; return false; } - if (!InitializeCollective(backend, global_group_name)) { - MS_LOG(ERROR) << "Failed to initialize collective communication."; - return false; +#if ((defined ENABLE_CPU) && (!defined _WIN32)) + // Server and Scheduler don't use collective communication library. + auto node = cluster::ClusterContext::instance()->node(); + MS_EXCEPTION_IF_NULL(node); + if (node->role() != ps::core::NodeRole::SERVER && node->role() != ps::core::NodeRole::SCHEDULER) { + // Global rank id and size should be manually set if cluster is initialized by MindSpore communication framework. 
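// Hedged restatement of the cast on the next line: the template argument is stripped in this
// listing and is assumed to be ps::core::AbstractNode, i.e. the node type that exposes
// rank_id() and worker_num():
//   auto abstract_node =
//       std::dynamic_pointer_cast<ps::core::AbstractNode>(cluster::ClusterContext::instance()->node());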
+ auto abstract_node = std::dynamic_pointer_cast(cluster::ClusterContext::instance()->node()); + MS_EXCEPTION_IF_NULL(abstract_node); + collective::CollectiveManager::instance()->set_global_rank_id(abstract_node->rank_id()); + collective::CollectiveManager::instance()->set_global_rank_size(abstract_node->worker_num()); + + if (!InitializeCollective()) { + MS_LOG(ERROR) << "Failed to initialize collective communication."; + return false; + } } +#endif return true; } @@ -51,9 +64,7 @@ bool InitializeCluster() { return cluster::ClusterContext::instance()->Initializ bool FinalizeCluster() { return cluster::ClusterContext::instance()->Finalize(); } -bool InitializeCollective(const std::string &backend, const std::string &global_group_name) { - return collective::CollectiveManager::instance()->Initialize(backend, global_group_name); -} +bool InitializeCollective() { return collective::CollectiveManager::instance()->Initialize(); } bool FinalizeCollective() { return collective::CollectiveManager::instance()->Finalize(); } } // namespace distributed diff --git a/mindspore/ccsrc/distributed/init.h b/mindspore/ccsrc/distributed/init.h index 36a1d922ca0..60d60c5a16e 100644 --- a/mindspore/ccsrc/distributed/init.h +++ b/mindspore/ccsrc/distributed/init.h @@ -31,7 +31,7 @@ namespace distributed { // The static methods of MindSpore distributed execution. They can be exported by Pybind. // Initialize and finalize distributed execution. -bool Initialize(const std::string &backend, const std::string &global_group_name); +bool Initialize(); bool Finalize(); // Initialize and finalize the cluster based on MindSpore communication framework. @@ -39,7 +39,7 @@ bool InitializeCluster(); bool FinalizeCluster(); // Initialize and finalize collective communication for distributed execution. -bool InitializeCollective(const std::string &backend, const std::string &global_group_name); +bool InitializeCollective(); bool FinalizeCollective(); } // namespace distributed } // namespace mindspore diff --git a/mindspore/ccsrc/ps/ps_context.cc b/mindspore/ccsrc/ps/ps_context.cc index 1302f4df1ee..8f9df67e29f 100644 --- a/mindspore/ccsrc/ps/ps_context.cc +++ b/mindspore/ccsrc/ps/ps_context.cc @@ -410,7 +410,8 @@ void PSContext::set_enable_ssl(bool enabled) { enable_ssl_ = enabled; } core::ClusterConfig &PSContext::cluster_config() { if (cluster_config_ == nullptr) { - MS_LOG(EXCEPTION) << "The cluster config is empty."; + cluster_config_ = std::make_unique(worker_num_, server_num_, scheduler_host_, scheduler_port_); + MS_EXCEPTION_IF_NULL(cluster_config_); } return *cluster_config_; } diff --git a/mindspore/ccsrc/runtime/device/gpu/distribution/collective_init.cc b/mindspore/ccsrc/runtime/device/gpu/distribution/collective_init.cc index 8886488c3dc..eb12c9ad57f 100644 --- a/mindspore/ccsrc/runtime/device/gpu/distribution/collective_init.cc +++ b/mindspore/ccsrc/runtime/device/gpu/distribution/collective_init.cc @@ -16,6 +16,8 @@ #include "runtime/device/gpu/distribution/collective_init.h" #include "utils/log_adapter.h" +#include "utils/ms_utils.h" +#include "distributed/init.h" namespace mindspore { namespace device { @@ -30,23 +32,32 @@ bool CollectiveInitializer::collective_inited() const { return collective_inited const void *CollectiveInitializer::collective_handle() { return collective_handle_; } void CollectiveInitializer::InitCollective() { - void *handle = dlopen("libgpu_collective.so", RTLD_LAZY); - if (handle == nullptr) { - MS_LOG(EXCEPTION) - << "Loading libgpu_collective.so failed. 
Many reasons could cause this:\n" - "1.libgpu_collective.so is not found, please check this MindSpore package is GPU version and built " - "with distributed feature.\n" - "2.NCCL is not found or the user-installed NCCL version installed is incompatible: MindSpore " - "requires NCCL-2.7.6.\n" - "3.OpenMPI is not found or the user-installed OpenMPI version is incompatible: MindSpore " - "requires OpenMPI-4.0.3.\n"; - } - auto mpi_init_funcptr = reinterpret_cast(dlsym(handle, "InitMPI")); - MS_EXCEPTION_IF_NULL(mpi_init_funcptr); - (*mpi_init_funcptr)(); + if (common::CheckUseMPI()) { + void *handle = dlopen("libgpu_collective.so", RTLD_LAZY); + if (handle == nullptr) { + MS_LOG(EXCEPTION) + << "Loading libgpu_collective.so failed. Many reasons could cause this:\n" + "1.libgpu_collective.so is not found, please check this MindSpore package is GPU version and built " + "with distributed feature.\n" + "2.NCCL is not found or the user-installed NCCL version installed is incompatible: MindSpore " + "requires NCCL-2.7.6.\n" + "3.OpenMPI is not found or the user-installed OpenMPI version is incompatible: MindSpore " + "requires OpenMPI-4.0.3.\n"; + } + auto mpi_init_funcptr = reinterpret_cast(dlsym(handle, "InitMPI")); + MS_EXCEPTION_IF_NULL(mpi_init_funcptr); + (*mpi_init_funcptr)(); - CollectiveInitializer::instance().collective_inited_ = true; - CollectiveInitializer::instance().collective_handle_ = handle; + // Because this method InitCollective is static, the non-static member variables should be accessed by + // CollectiveInitializer::instance(). + CollectiveInitializer::instance().use_mpi_ = true; + CollectiveInitializer::instance().collective_inited_ = true; + CollectiveInitializer::instance().collective_handle_ = handle; + } else { + if (!distributed::Initialize()) { + MS_LOG(EXCEPTION) << "Failed to initialize distributed execution for NCCL."; + } + } } void CollectiveInitializer::FinalizeCollective() { @@ -56,6 +67,69 @@ void CollectiveInitializer::FinalizeCollective() { } } } + +uint32_t CollectiveInitializer::local_rank_id() { + uint32_t local_rank_id; + if (common::CheckUseMPI()) { + MS_EXCEPTION_IF_NULL(collective_handle_); + auto get_local_rank_funcptr = + reinterpret_cast(dlsym(const_cast(collective_handle_), "local_rank_id")); + MS_EXCEPTION_IF_NULL(get_local_rank_funcptr); + local_rank_id = IntToUint((*get_local_rank_funcptr)()); + } else { + local_rank_id = distributed::collective::CollectiveManager::instance()->local_rank_id(); + } + return local_rank_id; +} + +bool CollectiveInitializer::CreateCommunicationGroup(const std::string &group_name, + const std::vector &group_ranks) { + if (common::CheckUseMPI()) { + return distributed::collective::CollectiveManager::instance()->CreateCommunicationGroup(group_name, group_ranks); + } else { + MS_EXCEPTION_IF_NULL(collective_handle_); + auto create_comm_group_funcptr = + reinterpret_cast(dlsym(const_cast(collective_handle_), "CreateCommGroup")); + MS_EXCEPTION_IF_NULL(create_comm_group_funcptr); + return (*create_comm_group_funcptr)(group_name, group_ranks); + } +} + +bool CollectiveInitializer::DestroyCommunicationGroup(const std::string &group_name) { + if (common::CheckUseMPI()) { + return distributed::collective::CollectiveManager::instance()->DestroyCommunicationGroup(group_name); + } else { + MS_EXCEPTION_IF_NULL(collective_handle_); + auto destroy_group_funcptr = + reinterpret_cast(dlsym(const_cast(collective_handle_), "DestroyGroup")); + MS_EXCEPTION_IF_NULL(destroy_group_funcptr); + return (*destroy_group_funcptr)(group_name); 
+ } +} + +uint32_t CollectiveInitializer::GetRankIDByGroup(const std::string &group_name) { + if (common::CheckUseMPI()) { + return distributed::collective::CollectiveManager::instance()->GetRankId(group_name); + } else { + MS_EXCEPTION_IF_NULL(collective_handle_); + auto get_rank_id_funcptr = + reinterpret_cast(dlsym(const_cast(collective_handle_), "GetRankIDByGroup")); + MS_EXCEPTION_IF_NULL(get_rank_id_funcptr); + return IntToUint((*get_rank_id_funcptr)(group_name)); + } +} + +uint32_t CollectiveInitializer::GetGroupSize(const std::string &group_name) { + if (common::CheckUseMPI()) { + return distributed::collective::CollectiveManager::instance()->GetGroupSize(group_name); + } else { + MS_EXCEPTION_IF_NULL(collective_handle_); + auto get_group_size_funcptr = + reinterpret_cast(dlsym(const_cast(collective_handle_), "GetGroupSize")); + MS_EXCEPTION_IF_NULL(get_group_size_funcptr); + return IntToUint((*get_group_size_funcptr)(group_name)); + } +} } // namespace gpu } // namespace device } // namespace mindspore diff --git a/mindspore/ccsrc/runtime/device/gpu/distribution/collective_init.h b/mindspore/ccsrc/runtime/device/gpu/distribution/collective_init.h index 5851b31dcaa..4ccad0235ba 100644 --- a/mindspore/ccsrc/runtime/device/gpu/distribution/collective_init.h +++ b/mindspore/ccsrc/runtime/device/gpu/distribution/collective_init.h @@ -42,12 +42,20 @@ class CollectiveInitializer { static void InitCollective(); static void FinalizeCollective(); + // The capsulation of the collective communication APIs for compatibility. + uint32_t local_rank_id(); + bool CreateCommunicationGroup(const std::string &group_name, const std::vector &group_ranks); + bool DestroyCommunicationGroup(const std::string &group_name); + uint32_t GetRankIDByGroup(const std::string &group_name); + uint32_t GetGroupSize(const std::string &group_name); + private: - CollectiveInitializer() : collective_inited_(false) {} + CollectiveInitializer() : use_mpi_(false), collective_inited_(false), collective_handle_(nullptr) {} ~CollectiveInitializer() = default; + bool use_mpi_; bool collective_inited_; - void *collective_handle_{nullptr}; + void *collective_handle_; }; } // namespace gpu } // namespace device diff --git a/mindspore/ccsrc/runtime/hardware/CMakeLists.txt b/mindspore/ccsrc/runtime/hardware/CMakeLists.txt index c58bb11ed68..696361ebe80 100644 --- a/mindspore/ccsrc/runtime/hardware/CMakeLists.txt +++ b/mindspore/ccsrc/runtime/hardware/CMakeLists.txt @@ -23,6 +23,9 @@ endif() if(ENABLE_CPU) file(GLOB_RECURSE HARDWARE_CPU_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "cpu/*.cc") list(REMOVE_ITEM HARDWARE_CPU_SRC_LIST "cpu/mpi_collective_comm_lib.cc" "cpu/mpi_communication_group.cc") + if(WIN32) + list(REMOVE_ITEM HARDWARE_CPU_SRC_LIST "cpu/ms_collective_comm_lib.cc") + endif() if(ENABLE_MPI) set(MPI_COLLECTIVE_SRCS "cpu/mpi_collective_comm_lib.cc" "cpu/mpi_communication_group.cc" diff --git a/mindspore/ccsrc/runtime/hardware/collective/collective_communication_lib.cc b/mindspore/ccsrc/runtime/hardware/collective/collective_communication_lib.cc index c0b7e3e879e..8d5ecdc3f90 100644 --- a/mindspore/ccsrc/runtime/hardware/collective/collective_communication_lib.cc +++ b/mindspore/ccsrc/runtime/hardware/collective/collective_communication_lib.cc @@ -60,6 +60,8 @@ CommunicationGroupPtr CollectiveCommunicationLib::GetGroup(const std::string &gr return groups_[group_name]; } +const std::string &CollectiveCommunicationLib::global_group_name() const { return global_group_name_; } + uint32_t 
CollectiveCommunicationLib::global_rank_id() const { return global_rank_id_; } uint32_t CollectiveCommunicationLib::local_rank_id() const { return local_rank_id_; } diff --git a/mindspore/ccsrc/runtime/hardware/collective/collective_communication_lib.h b/mindspore/ccsrc/runtime/hardware/collective/collective_communication_lib.h index 31b1206610b..61222719b26 100644 --- a/mindspore/ccsrc/runtime/hardware/collective/collective_communication_lib.h +++ b/mindspore/ccsrc/runtime/hardware/collective/collective_communication_lib.h @@ -77,6 +77,10 @@ class CollectiveCommunicationLib { return true; } + // Returns the global group name of this collective communication library. For NCCL, it's 'nccl_world_group'. For + // HCCL, it's 'hccl_world_group'. + const std::string &global_group_name() const; + // Returns global rank id of this process. uint32_t global_rank_id() const; @@ -90,6 +94,9 @@ class CollectiveCommunicationLib { // Whether this collective communication library is initialized. bool initialized_; + // The global group name. + std::string global_group_name_; + // The global rank id of this process. Normally this range is 0 to `total process number - 1`. uint32_t global_rank_id_; diff --git a/mindspore/ccsrc/runtime/hardware/collective/communication_group.cc b/mindspore/ccsrc/runtime/hardware/collective/communication_group.cc index 2ec3bfdb411..1fdb783c92f 100644 --- a/mindspore/ccsrc/runtime/hardware/collective/communication_group.cc +++ b/mindspore/ccsrc/runtime/hardware/collective/communication_group.cc @@ -47,5 +47,11 @@ uint32_t CommunicationGroup::GetGlobalRank(uint32_t group_rank) { } uint32_t CommunicationGroup::group_size() const { return size_; } + +const std::vector &CommunicationGroup::group_ranks() const { return group_ranks_; } + +const std::map &CommunicationGroup::global_to_group_ranks() const { return global_to_group_ranks_; } + +const std::map &CommunicationGroup::group_to_global_ranks() const { return group_to_global_ranks_; } } // namespace device } // namespace mindspore diff --git a/mindspore/ccsrc/runtime/hardware/collective/communication_group.h b/mindspore/ccsrc/runtime/hardware/collective/communication_group.h index c2b75ac134e..b771ae9d711 100644 --- a/mindspore/ccsrc/runtime/hardware/collective/communication_group.h +++ b/mindspore/ccsrc/runtime/hardware/collective/communication_group.h @@ -55,6 +55,11 @@ class CommunicationGroup { // Return the size of this communication group. uint32_t group_size() const; + // Return group ranks info. + const std::vector &group_ranks() const; + const std::map &global_to_group_ranks() const; + const std::map &group_to_global_ranks() const; + protected: // Whether this communication group is initialized. 
bool initialized_; diff --git a/mindspore/ccsrc/runtime/hardware/cpu/cpu_device_context.cc b/mindspore/ccsrc/runtime/hardware/cpu/cpu_device_context.cc index ffd2f97d642..73c17978260 100644 --- a/mindspore/ccsrc/runtime/hardware/cpu/cpu_device_context.cc +++ b/mindspore/ccsrc/runtime/hardware/cpu/cpu_device_context.cc @@ -34,6 +34,9 @@ #include "backend/optimizer/graph_kernel/graph_kernel_optimization.h" #include "backend/session/anf_runtime_algorithm.h" #include "profiler/device/cpu/cpu_profiling.h" +#if ((defined ENABLE_CPU) && (!defined _WIN32)) +#include "runtime/hardware/cpu/ms_collective_comm_lib.h" +#endif #ifndef ENABLE_SECURITY #include "debug/data_dump/dump_json_parser.h" #endif @@ -296,6 +299,32 @@ bool CPUDeviceContext::LaunchKernel(const CNodePtr &kernel, const std::vector(mpi_comm_lib_name); + MS_EXCEPTION_IF_NULL(loader); + if (!loader->Initialize()) { + MS_LOG(EXCEPTION) << "Failed to load mpi collective library."; + return false; + } + + void *collective_comm_lib_handle = loader->collective_comm_lib_ptr(); + MS_EXCEPTION_IF_NULL(collective_comm_lib_handle); + + auto instance_func = DlsymFuncObj(communication_lib_instance, collective_comm_lib_handle); + collective_comm_lib_ = instance_func(); + MS_EXCEPTION_IF_NULL(collective_comm_lib_); + } else { +#if ((defined ENABLE_CPU) && (!defined _WIN32)) + collective_comm_lib_ = &MsCollectiveCommLib::GetInstance(); + MS_EXCEPTION_IF_NULL(collective_comm_lib_); +#endif + } + return true; +} + bool CPUDeviceContext::LaunchKernelWithProfiling(const CNodePtr &kernel, const std::vector &inputs, const std::vector &workspace, const std::vector &outputs) const { diff --git a/mindspore/ccsrc/runtime/hardware/cpu/cpu_device_context.h b/mindspore/ccsrc/runtime/hardware/cpu/cpu_device_context.h index 730d4c36f99..970bc2e9a21 100644 --- a/mindspore/ccsrc/runtime/hardware/cpu/cpu_device_context.h +++ b/mindspore/ccsrc/runtime/hardware/cpu/cpu_device_context.h @@ -57,6 +57,8 @@ class CPUDeviceContext : public DeviceContext { const std::vector &workspace, const std::vector &outputs, bool is_dynamic_shape = false) const override; + bool LoadCollectiveCommLib() override; + private: DISABLE_COPY_AND_ASSIGN(CPUDeviceContext); diff --git a/mindspore/ccsrc/runtime/hardware/cpu/mpi_collective_comm_lib.cc b/mindspore/ccsrc/runtime/hardware/cpu/mpi_collective_comm_lib.cc index b232003cf2f..2489df40042 100644 --- a/mindspore/ccsrc/runtime/hardware/cpu/mpi_collective_comm_lib.cc +++ b/mindspore/ccsrc/runtime/hardware/cpu/mpi_collective_comm_lib.cc @@ -19,6 +19,8 @@ namespace mindspore { namespace device { namespace cpu { +MPICollectiveCommLib::MPICollectiveCommLib() { global_group_name_ = kMPIGlobalGroupName; } + bool MPICollectiveCommLib::Initialize(uint32_t, uint32_t) { if (initialized_) { return false; @@ -56,49 +58,7 @@ bool MPICollectiveCommLib::CreateCommunicationGroup(const std::string &group_nam } } // namespace cpu -// The exported APIs for 'dlsym' to load. 
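// Illustrative sketch of the loading convention this patch converges on: each dynamically loaded
// backend keeps a single exported C symbol, communication_lib_instance, which returns the typed
// singleton. The real code goes through CollectiveCommLibLoader and the DlsymFuncObj helper; the
// function and library names below are assumptions for illustration only.
#include <dlfcn.h>
#include "runtime/hardware/collective/collective_communication_lib.h"

mindspore::device::CollectiveCommunicationLib *LoadCommLibSketch(const char *lib_name) {
  void *handle = dlopen(lib_name, RTLD_LAZY);  // lib_name is a hypothetical backend library path
  if (handle == nullptr) {
    return nullptr;
  }
  using InstanceFunc = mindspore::device::CollectiveCommunicationLib *(*)();
  auto instance_func = reinterpret_cast<InstanceFunc>(dlsym(handle, "communication_lib_instance"));
  return instance_func == nullptr ? nullptr : instance_func();
}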
using MPICollectiveCommLib = mindspore::device::cpu::MPICollectiveCommLib; - CollectiveCommunicationLib *communication_lib_instance() { return &MPICollectiveCommLib::GetInstance(); } - -bool InitializeCollectiveLib(uint32_t, uint32_t) { return MPICollectiveCommLib::GetInstance().Initialize(); } - -bool FinalizeCollectiveLib() { return MPICollectiveCommLib::GetInstance().Finalize(); } - -bool CreateCommunicationGroup(const std::string &group_name, const std::vector &group_ranks) { - return MPICollectiveCommLib::GetInstance().CreateCommunicationGroup(group_name, group_ranks); -} - -bool DestroyCommunicationGroup(const std::string &group_name) { - return MPICollectiveCommLib::GetInstance().DestroyCommunicationGroup(group_name); -} - -uint32_t GetRankId(const std::string &group_name) { return MPICollectiveCommLib::GetInstance().GetRankId(group_name); } - -uint32_t GetCommunicationGroupSize(const std::string &group_name) { - return MPICollectiveCommLib::GetInstance().GetGroupSize(group_name); -} - -bool AssignLocalRank() { return MPICollectiveCommLib::GetInstance().AssignLocalRank(); } - -CommunicationGroupPtr GetGroup(const std::string &group_name) { - return MPICollectiveCommLib::GetInstance().GetGroup(group_name); -} - -bool AllGather(const void *send_buff, void *recv_buff, size_t send_count, TypeId data_type, - const std::string &group_name, void *stream) { - return MPICollectiveCommLib::GetInstance().AllGather(send_buff, recv_buff, send_count, data_type, group_name, stream); -} -bool Broadcast(const void *send_buff, void *recv_buff, size_t send_count, mindspore::TypeId data_type, - uint32_t root_rank, const std::string &group_name, void *stream) { - return MPICollectiveCommLib::GetInstance().Broadcast(send_buff, recv_buff, send_count, data_type, root_rank, - group_name, stream); -} - -uint32_t global_rank_id() { return MPICollectiveCommLib::GetInstance().global_rank_id(); } - -uint32_t local_rank_id() { return MPICollectiveCommLib::GetInstance().local_rank_id(); } - -uint32_t global_rank_size() { return MPICollectiveCommLib::GetInstance().global_rank_size(); } } // namespace device } // namespace mindspore diff --git a/mindspore/ccsrc/runtime/hardware/cpu/mpi_collective_comm_lib.h b/mindspore/ccsrc/runtime/hardware/cpu/mpi_collective_comm_lib.h index 1329bd0c6e3..47ebccd011b 100644 --- a/mindspore/ccsrc/runtime/hardware/cpu/mpi_collective_comm_lib.h +++ b/mindspore/ccsrc/runtime/hardware/cpu/mpi_collective_comm_lib.h @@ -23,10 +23,15 @@ #include "runtime/hardware/collective/collective_communication_lib.h" #include "runtime/hardware/cpu/mpi_communication_group.h" +#ifndef EXPORT_MPI_WRAPPER +#define EXPORT_MPI_WRAPPER __attribute__((visibility("default"))) +#endif + namespace mindspore { namespace device { namespace cpu { -class MPICollectiveCommLib : public CollectiveCommunicationLib { +constexpr char kMPIGlobalGroupName[] = "mpi_world_group"; +class EXPORT_MPI_WRAPPER MPICollectiveCommLib : public CollectiveCommunicationLib { public: static MPICollectiveCommLib &GetInstance() { static MPICollectiveCommLib instance; @@ -49,35 +54,14 @@ class MPICollectiveCommLib : public CollectiveCommunicationLib { } private: - MPICollectiveCommLib() = default; + MPICollectiveCommLib(); ~MPICollectiveCommLib() override = default; MPI_Group world_group_; }; } // namespace cpu -#ifndef EXPORT_MPI_WRAPPER -#define EXPORT_MPI_WRAPPER __attribute__((visibility("default"))) -#endif extern "C" EXPORT_MPI_WRAPPER CollectiveCommunicationLib *communication_lib_instance(); -extern "C" EXPORT_MPI_WRAPPER bool 
InitializeCollectiveLib(uint32_t global_rank = UINT32_MAX, - uint32_t global_rank_size = UINT32_MAX); -extern "C" EXPORT_MPI_WRAPPER bool FinalizeCollectiveLib(); -extern "C" EXPORT_MPI_WRAPPER bool CreateCommunicationGroup(const std::string &group_name, - const std::vector &group_ranks); -extern "C" EXPORT_MPI_WRAPPER bool DestroyCommunicationGroup(const std::string &group_name); -extern "C" EXPORT_MPI_WRAPPER uint32_t GetRankId(const std::string &group_name); -extern "C" EXPORT_MPI_WRAPPER uint32_t GetGroupSize(const std::string &group_name); -extern "C" EXPORT_MPI_WRAPPER bool AssignLocalRank(); -extern "C" EXPORT_MPI_WRAPPER CommunicationGroupPtr GetGroup(const std::string &group_name); -extern "C" EXPORT_MPI_WRAPPER bool AllGather(const void *send_buff, void *recv_buff, size_t send_count, - mindspore::TypeId data_type, const std::string &group_name, void *stream); -extern "C" EXPORT_MPI_WRAPPER bool Broadcast(const void *send_buff, void *recv_buff, size_t send_count, - mindspore::TypeId data_type, uint32_t root_rank, - const std::string &group_name, void *stream); -extern "C" EXPORT_MPI_WRAPPER uint32_t global_rank_id(); -extern "C" EXPORT_MPI_WRAPPER uint32_t local_rank_id(); -extern "C" EXPORT_MPI_WRAPPER uint32_t global_rank_size(); } // namespace device } // namespace mindspore #endif // MINDSPORE_CCSRC_RUNTIME_HARDWARE_CPU_MPI_COLLECTIVE_COMM_LIB_H_ diff --git a/mindspore/ccsrc/runtime/hardware/cpu/ms_collective_comm_lib.cc b/mindspore/ccsrc/runtime/hardware/cpu/ms_collective_comm_lib.cc index a32701e965f..bd57a14c57e 100644 --- a/mindspore/ccsrc/runtime/hardware/cpu/ms_collective_comm_lib.cc +++ b/mindspore/ccsrc/runtime/hardware/cpu/ms_collective_comm_lib.cc @@ -19,6 +19,8 @@ namespace mindspore { namespace device { namespace cpu { +MsCollectiveCommLib::MsCollectiveCommLib() { global_group_name_ = kMSGlobalGroupName; } + bool MsCollectiveCommLib::Initialize(uint32_t global_rank, uint32_t global_rank_size) { if (initialized_) { return false; @@ -30,8 +32,6 @@ bool MsCollectiveCommLib::Initialize(uint32_t global_rank, uint32_t global_rank_ return true; } -bool MsCollectiveCommLib::Finalize() { return true; } - bool MsCollectiveCommLib::CreateCommunicationGroup(const std::string &group_name, const std::vector &group_ranks) { CHECK_RET((groups_.count(group_name) == 0), true, "The group " + group_name + " has already existed."); diff --git a/mindspore/ccsrc/runtime/hardware/cpu/ms_collective_comm_lib.h b/mindspore/ccsrc/runtime/hardware/cpu/ms_collective_comm_lib.h index b7685e470ef..bcff4d23a71 100644 --- a/mindspore/ccsrc/runtime/hardware/cpu/ms_collective_comm_lib.h +++ b/mindspore/ccsrc/runtime/hardware/cpu/ms_collective_comm_lib.h @@ -22,10 +22,17 @@ #include #include "runtime/hardware/collective/collective_communication_lib.h" #include "runtime/hardware/cpu/ms_communication_group.h" +#include "distributed/cluster/cluster_context.h" +#include "fl/server/collective_ops_impl.h" namespace mindspore { namespace device { namespace cpu { +constexpr char kMSGlobalGroupName[] = "ms_world_group"; +using ClusterContext = mindspore::distributed::cluster::ClusterContext; +using CollectiveOpsImpl = mindspore::fl::server::CollectiveOpsImpl; +using CommunicationGroupInfo = mindspore::fl::server::CommunicationGroupInfo; + // The collective communication library for MindSpore self developed communication framework. 
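// Illustrative sketch of the pattern now shared by the MPI, NCCL and MindSpore ("MS") backends,
// including MsCollectiveCommLib below: a Meyers singleton derived from CollectiveCommunicationLib
// whose constructor sets its own world-group name and whose Initialize() only records the global
// rank information. The class name and group name here are hypothetical.
class ExampleCollectiveCommLib : public CollectiveCommunicationLib {
 public:
  static ExampleCollectiveCommLib &GetInstance() {
    static ExampleCollectiveCommLib instance;
    return instance;
  }

  bool Initialize(uint32_t global_rank = UINT32_MAX, uint32_t global_rank_size = UINT32_MAX) override {
    if (initialized_) {
      return false;
    }
    global_rank_id_ = global_rank;
    global_rank_size_ = global_rank_size;
    initialized_ = true;
    return true;
  }

 private:
  ExampleCollectiveCommLib() { global_group_name_ = "example_world_group"; }
  ~ExampleCollectiveCommLib() override = default;
};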
class MsCollectiveCommLib : public CollectiveCommunicationLib { public: @@ -35,12 +42,11 @@ class MsCollectiveCommLib : public CollectiveCommunicationLib { } bool Initialize(uint32_t global_rank = UINT32_MAX, uint32_t global_rank_size = UINT32_MAX) override; - bool Finalize() override; bool CreateCommunicationGroup(const std::string &group_name, const std::vector &group_ranks) override; private: - MsCollectiveCommLib() {} + MsCollectiveCommLib(); ~MsCollectiveCommLib() override = default; }; } // namespace cpu diff --git a/mindspore/ccsrc/runtime/hardware/device_context.h b/mindspore/ccsrc/runtime/hardware/device_context.h index 8fcfe815719..4f3682f6e72 100644 --- a/mindspore/ccsrc/runtime/hardware/device_context.h +++ b/mindspore/ccsrc/runtime/hardware/device_context.h @@ -50,7 +50,7 @@ struct DeviceContextKey { class DeviceContext { public: explicit DeviceContext(const DeviceContextKey &device_context_key) - : device_context_key_(device_context_key), collective_comm_lib_ptr_(nullptr) {} + : device_context_key_(device_context_key), collective_comm_lib_(nullptr) {} virtual ~DeviceContext() = default; // Initialize the device context. @@ -150,7 +150,7 @@ class DeviceContext { virtual bool LoadCollectiveCommLib() { return true; } // Return collective communication object for caller to access - void *collective_comm_lib() const { return collective_comm_lib_ptr_; } + CollectiveCommunicationLib *collective_comm_lib() const { return collective_comm_lib_; } // TODO(jiaorui): will be delete // Dump all graphs. @@ -159,8 +159,8 @@ class DeviceContext { protected: DeviceContextKey device_context_key_; - // The dynamically loaded handle for collective communication library by 'dlopen'. - void *collective_comm_lib_ptr_; + // The collective communication library. 
+ CollectiveCommunicationLib *collective_comm_lib_; }; using DeviceContextPtr = std::shared_ptr; } // namespace device diff --git a/mindspore/ccsrc/runtime/hardware/gpu/gpu_device_context.cc b/mindspore/ccsrc/runtime/hardware/gpu/gpu_device_context.cc index f8c7abe6e1a..8a0a69869de 100644 --- a/mindspore/ccsrc/runtime/hardware/gpu/gpu_device_context.cc +++ b/mindspore/ccsrc/runtime/hardware/gpu/gpu_device_context.cc @@ -534,8 +534,12 @@ bool GPUDeviceContext::LoadCollectiveCommLib() { MS_LOG(EXCEPTION) << "Loading NCCL collective library failed."; return false; } - collective_comm_lib_ptr_ = loader->collective_comm_lib_ptr(); - MS_EXCEPTION_IF_NULL(collective_comm_lib_ptr_); + void *collective_comm_lib_handle = loader->collective_comm_lib_ptr(); + MS_EXCEPTION_IF_NULL(collective_comm_lib_handle); + + auto instance_func = DlsymFuncObj(communication_lib_instance, collective_comm_lib_handle); + collective_comm_lib_ = instance_func(); + MS_EXCEPTION_IF_NULL(collective_comm_lib_); return true; #else return false; diff --git a/mindspore/ccsrc/runtime/hardware/gpu/nvidia_collective_comm_lib.cc b/mindspore/ccsrc/runtime/hardware/gpu/nvidia_collective_comm_lib.cc index 334841a9edb..47ca8ea38cb 100644 --- a/mindspore/ccsrc/runtime/hardware/gpu/nvidia_collective_comm_lib.cc +++ b/mindspore/ccsrc/runtime/hardware/gpu/nvidia_collective_comm_lib.cc @@ -19,6 +19,8 @@ namespace mindspore { namespace device { namespace gpu { +NvidiaCollectiveCommLib::NvidiaCollectiveCommLib() { global_group_name_ = kNCCLGlobalGroupName; } + bool NvidiaCollectiveCommLib::Initialize(uint32_t global_rank, uint32_t global_rank_size) { if (initialized_) { return false; @@ -42,50 +44,7 @@ bool NvidiaCollectiveCommLib::CreateCommunicationGroup(const std::string &group_ } } // namespace gpu -// The exported APIs for 'dlsym' to load. 
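// Illustrative usage sketch: with collective_comm_lib() now returning a typed
// CollectiveCommunicationLib pointer instead of a raw dlopen handle, callers such as
// CollectiveManager no longer need a dlsym step (compare InitHostCommlib above). The device_ctx
// argument is assumed to be an already-initialized DeviceContext*.
void QueryRankSketch(mindspore::device::DeviceContext *device_ctx) {
  MS_EXCEPTION_IF_NULL(device_ctx);
  auto comm_lib = device_ctx->collective_comm_lib();
  MS_EXCEPTION_IF_NULL(comm_lib);
  uint32_t global_rank = comm_lib->global_rank_id();
  uint32_t local_rank = comm_lib->local_rank_id();
  (void)global_rank;
  (void)local_rank;
}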
using NvidiaCollectiveCommLib = mindspore::device::gpu::NvidiaCollectiveCommLib; CollectiveCommunicationLib *communication_lib_instance() { return &NvidiaCollectiveCommLib::GetInstance(); } - -bool InitializeCollectiveLib(uint32_t global_rank, uint32_t global_rank_size) { - return NvidiaCollectiveCommLib::GetInstance().Initialize(global_rank, global_rank_size); -} - -bool FinalizeCollectiveLib() { return NvidiaCollectiveCommLib::GetInstance().Finalize(); } - -bool CreateCommunicationGroup(const std::string &group_name, const std::vector &group_ranks) { - return NvidiaCollectiveCommLib::GetInstance().CreateCommunicationGroup(group_name, group_ranks); -} - -bool DestroyCommunicationGroup(const std::string &group_name) { - return NvidiaCollectiveCommLib::GetInstance().DestroyCommunicationGroup(group_name); -} - -uint32_t GetRankId(const std::string &group_name) { - return NvidiaCollectiveCommLib::GetInstance().GetRankId(group_name); -} - -uint32_t GetCommunicationGroupSize(const std::string &group_name) { - return NvidiaCollectiveCommLib::GetInstance().GetGroupSize(group_name); -} - -bool AssignLocalRank() { return NvidiaCollectiveCommLib::GetInstance().AssignLocalRank(); } - -CommunicationGroupPtr GetGroup(const std::string &group_name) { - return NvidiaCollectiveCommLib::GetInstance().GetGroup(group_name); -} - -bool AllGather(const void *send_buff, void *recv_buff, size_t send_count, mindspore::TypeId data_type, - const std::string &group_name, void *stream) { - return NvidiaCollectiveCommLib::GetInstance().AllGather(send_buff, recv_buff, send_count, data_type, group_name, - stream); -} - -bool Broadcast(const void *send_buff, void *recv_buff, size_t send_count, mindspore::TypeId data_type, - uint32_t root_rank, const std::string &group_name, void *stream) { - return NvidiaCollectiveCommLib::GetInstance().Broadcast(send_buff, recv_buff, send_count, data_type, root_rank, - group_name, stream); -} - -uint32_t local_rank_id() { return NvidiaCollectiveCommLib::GetInstance().local_rank_id(); } } // namespace device } // namespace mindspore diff --git a/mindspore/ccsrc/runtime/hardware/gpu/nvidia_collective_comm_lib.h b/mindspore/ccsrc/runtime/hardware/gpu/nvidia_collective_comm_lib.h index 16565a52d90..a0637deadd9 100644 --- a/mindspore/ccsrc/runtime/hardware/gpu/nvidia_collective_comm_lib.h +++ b/mindspore/ccsrc/runtime/hardware/gpu/nvidia_collective_comm_lib.h @@ -24,11 +24,15 @@ #include "runtime/hardware/collective/collective_communication_lib.h" #include "runtime/hardware/gpu/nvidia_communication_group.h" +#ifndef EXPORT_NCCL_WRAPPER +#define EXPORT_NCCL_WRAPPER __attribute__((visibility("default"))) +#endif + namespace mindspore { namespace device { namespace gpu { -constexpr char NCCL_WORLD_GROUP[] = "nccl_world_group"; -class NvidiaCollectiveCommLib : public CollectiveCommunicationLib { +constexpr char kNCCLGlobalGroupName[] = "nccl_world_group"; +class EXPORT_NCCL_WRAPPER NvidiaCollectiveCommLib : public CollectiveCommunicationLib { public: static NvidiaCollectiveCommLib &GetInstance() { static NvidiaCollectiveCommLib instance; @@ -50,31 +54,12 @@ class NvidiaCollectiveCommLib : public CollectiveCommunicationLib { } private: - NvidiaCollectiveCommLib() = default; + NvidiaCollectiveCommLib(); ~NvidiaCollectiveCommLib() override = default; }; } // namespace gpu -#ifndef EXPORT_NCCL_WRAPPER -#define EXPORT_NCCL_WRAPPER __attribute__((visibility("default"))) -#endif extern "C" EXPORT_NCCL_WRAPPER CollectiveCommunicationLib *communication_lib_instance(); -extern "C" EXPORT_NCCL_WRAPPER bool 
InitializeCollectiveLib(uint32_t global_rank = UINT32_MAX, - uint32_t global_rank_size = UINT32_MAX); -extern "C" EXPORT_NCCL_WRAPPER bool FinalizeCollectiveLib(); -extern "C" EXPORT_NCCL_WRAPPER bool CreateCommunicationGroup(const std::string &group_name, - const std::vector &group_ranks); -extern "C" EXPORT_NCCL_WRAPPER bool DestroyCommunicationGroup(const std::string &group_name); -extern "C" EXPORT_NCCL_WRAPPER uint32_t GetRankId(const std::string &group_name); -extern "C" EXPORT_NCCL_WRAPPER uint32_t GetCommunicationGroupSize(const std::string &group_name); -extern "C" EXPORT_NCCL_WRAPPER bool AssignLocalRank(); -extern "C" EXPORT_NCCL_WRAPPER CommunicationGroupPtr GetGroup(const std::string &group_name); -extern "C" EXPORT_NCCL_WRAPPER bool AllGather(const void *send_buff, void *recv_buff, size_t send_count, - mindspore::TypeId data_type, const std::string &group_name, void *stream); -extern "C" EXPORT_NCCL_WRAPPER bool Broadcast(const void *send_buff, void *recv_buff, size_t send_count, - mindspore::TypeId data_type, uint32_t root_rank, - const std::string &group_name, void *stream); -extern "C" EXPORT_NCCL_WRAPPER uint32_t local_rank_id(); } // namespace device } // namespace mindspore #endif // MINDSPORE_CCSRC_RUNTIME_HARDWARE_CPU_NVIDIA_COLLECTIVE_COMM_LIB_H_ diff --git a/mindspore/core/utils/ms_utils.h b/mindspore/core/utils/ms_utils.h index b9e4630a166..212b72597ed 100644 --- a/mindspore/core/utils/ms_utils.h +++ b/mindspore/core/utils/ms_utils.h @@ -90,6 +90,16 @@ static inline bool IsLittleByteOrder() { return false; } +static inline bool CheckUseMPI() { + // If these OpenMPI environment variables are set, we consider this process is launched by OpenMPI. + std::string ompi_command_env = GetEnv("OMPI_COMMAND"); + std::string pmix_rank_env = GetEnv("PMIX_RANK"); + if (!ompi_command_env.empty() && !pmix_rank_env.empty()) { + return true; + } + return false; +} + template bool IsEqual(const std::shared_ptr &a, const std::shared_ptr &b) { if (a == b) {
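// Illustrative usage sketch of the new CheckUseMPI() helper: a process started by mpirun has both
// OMPI_COMMAND and PMIX_RANK in its environment and keeps the original OpenMPI/NCCL path,
// otherwise initialization falls through to the MindSpore-managed distributed module (see
// CollectiveInitializer::InitCollective above). The wrapper function is hypothetical.
#include "utils/ms_utils.h"
#include "distributed/init.h"

void ChooseCollectivePathSketch() {
  if (mindspore::common::CheckUseMPI()) {
    // Launched by OpenMPI: dlopen("libgpu_collective.so") and initialize MPI/NCCL.
  } else {
    // Not launched by OpenMPI: use the MindSpore cluster and collective manager instead.
    (void)mindspore::distributed::Initialize();
  }
}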