1.Purge not used API.

2.Adapt for collective_init.h
This commit is contained in:
ZPaC 2021-11-27 10:57:36 +08:00
parent 7fb6df8384
commit 2b7429c5d2
28 changed files with 281 additions and 214 deletions

View File

@ -85,7 +85,7 @@ bool ClusterContext::Finalize() {
return true;
}
std::string ClusterContext::node_role() const { return node_role_; }
const std::shared_ptr<ps::core::Node> &ClusterContext::node() const { return node_; }
void ClusterContext::InitClusterConfig() {
InitNodeRole();

View File

@ -49,7 +49,8 @@ class ClusterContext {
// Finalize the cluster and process exits.
bool Finalize();
std::string node_role() const;
// Return node object of this process.
const std::shared_ptr<ps::core::Node> &node() const;
private:
ClusterContext();

View File

@ -32,8 +32,6 @@ std::shared_ptr<ClusterContext> ClusterContext::instance() {
bool ClusterContext::Initialize() const { return true; }
bool ClusterContext::Finalize() const { return true; }
std::string ClusterContext::node_role() const { return ""; }
} // namespace cluster
} // namespace distributed
} // namespace mindspore

View File

@ -39,7 +39,6 @@ class ClusterContext {
bool Initialize() const;
bool Finalize() const;
std::string node_role() const;
private:
ClusterContext() = default;

View File

@ -18,6 +18,7 @@
#include <string>
#include <vector>
#include <memory>
#include "utils/ms_context.h"
namespace mindspore {
namespace distributed {
@ -27,9 +28,7 @@ CollectiveManager::CollectiveManager()
finalized_(true),
host_ctx_(nullptr),
device_ctx_(nullptr),
host_comm_lib_(nullptr),
host_comm_lib_instance_(nullptr),
device_comm_lib_(nullptr),
device_comm_lib_instance_(nullptr),
global_rank_id_(0),
local_rank_id_(0),
@ -51,11 +50,13 @@ std::shared_ptr<CollectiveManager> CollectiveManager::instance() {
return instance;
}
bool CollectiveManager::Initialize(const std::string &backend, const std::string &global_group_name) {
bool CollectiveManager::Initialize() {
if (inited_) {
return true;
}
MS_LOG(INFO) << "Start initializing collective communication for backend: " << backend << "...";
device_type_ = MsContext::GetInstance()->get_param<std::string>(MS_CTX_DEVICE_TARGET);
MS_LOG(INFO) << "Start initializing collective communication for backend: " << device_type_ << "...";
// Step 1: Initialize host side collective communication.
if (!InitHostCommlib()) {
@ -66,24 +67,25 @@ bool CollectiveManager::Initialize(const std::string &backend, const std::string
// Step 2, 3 and 4 are for device communication library. So if the training job is only launched on CPU, they will not
// be necessary.
// Step 2: Assign local rank id(device id) for this process.
if (!AssignLocalRank(global_group_name)) {
if (!AssignLocalRank()) {
MS_LOG(ERROR) << "Failed to assign local rank id.";
return false;
}
// Step 3: Initialize device side collective communication.
if (!InitDeviceCommLib(backend)) {
if (!InitDeviceCommLib()) {
MS_LOG(ERROR) << "Failed to initialize device communication library.";
return false;
}
// Step 4: Create global communication group.
if (!CreateCommunicationGroup(global_group_name, global_group_ranks_)) {
MS_EXCEPTION_IF_NULL(device_comm_lib_instance_);
if (!CreateCommunicationGroup(device_comm_lib_instance_->global_group_name(), global_group_ranks_)) {
MS_LOG(ERROR) << "Failed to initialize host communication library.";
return false;
}
MS_LOG(INFO) << "End initializing collective communication for backend: " << backend << ".";
MS_LOG(INFO) << "End initializing collective communication for backend: " << device_type_;
return true;
}
@ -91,7 +93,7 @@ bool CollectiveManager::CreateCommunicationGroup(const std::string &group_name,
const std::vector<uint32_t> &group_ranks) {
MS_EXCEPTION_IF_NULL(host_comm_lib_instance_);
MS_EXCEPTION_IF_NULL(device_comm_lib_instance_);
// Step 1: Create communication group on host side.
// Step 1: Create communication group on host side if.
if (!host_comm_lib_instance_->CreateCommunicationGroup(group_name, group_ranks)) {
MS_LOG(ERROR) << "Failed to create communication group " << group_name << " on host side.";
return false;
@ -167,6 +169,12 @@ bool CollectiveManager::Finalize() {
return true;
}
void CollectiveManager::set_global_rank_id(uint32_t global_rank_id) { global_rank_id_ = global_rank_id; }
void CollectiveManager::set_global_rank_size(uint32_t global_rank_size) { global_rank_size_ = global_rank_size; }
uint32_t CollectiveManager::local_rank_id() const { return local_rank_id_; }
bool CollectiveManager::InitHostCommlib() {
device::DeviceContextKey host_key = {"CPU", 0};
host_ctx_ = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext(host_key);
@ -175,10 +183,7 @@ bool CollectiveManager::InitHostCommlib() {
MS_LOG(ERROR) << "Failed to load communication library on the host side.";
return false;
}
host_comm_lib_ = host_ctx_->collective_comm_lib();
MS_EXCEPTION_IF_NULL(host_comm_lib_);
auto instance_func = DlsymFuncObj(communication_lib_instance, host_comm_lib_);
host_comm_lib_instance_ = instance_func();
host_comm_lib_instance_ = host_ctx_->collective_comm_lib();
MS_EXCEPTION_IF_NULL(host_comm_lib_instance_);
// For some communication libraries, global_rank_id_', 'global_rank_size_' should be set by caller, e.g., when using
@ -197,33 +202,30 @@ bool CollectiveManager::InitHostCommlib() {
global_group_ranks_.push_back(i);
}
// Create world group on host side for AllGather operation of host name while assigning local rank.
host_global_group_name_ = host_comm_lib_instance_->global_group_name();
if (!host_comm_lib_instance_->CreateCommunicationGroup(host_global_group_name_, global_group_ranks_)) {
MS_LOG(ERROR) << "Failed to create communication group " << host_global_group_name_ << " on host side.";
return false;
}
MS_LOG(INFO) << "Communication library on host side is successfully initialized. Global rank id: " << global_rank_id_
<< ", global rank size: " << global_rank_size_;
return true;
}
bool CollectiveManager::InitDeviceCommLib(const std::string &backend) {
std::string device_name;
if (backend == "nccl") {
device_name = "GPU";
} else if (backend == "hccl") {
device_name = "Ascend";
} else {
MS_LOG(ERROR) << "Backend " << backend << " is not supported.";
return false;
}
device::DeviceContextKey device_key = {device_name, local_rank_id_};
bool CollectiveManager::InitDeviceCommLib() {
device::DeviceContextKey device_key = {device_type_, local_rank_id_};
device_ctx_ = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext(device_key);
MS_EXCEPTION_IF_NULL(device_ctx_);
// We can initialize device context now because device id(local_rank_id_) is already assigned.
device_ctx_->Initialize();
if (!device_ctx_->LoadCollectiveCommLib()) {
MS_LOG(ERROR) << "Failed to load communication library on the device side.";
return false;
}
device_comm_lib_ = device_ctx_->collective_comm_lib();
MS_EXCEPTION_IF_NULL(device_comm_lib_);
auto instance_func = DlsymFuncObj(communication_lib_instance, device_comm_lib_);
device_comm_lib_instance_ = instance_func();
device_comm_lib_instance_ = device_ctx_->collective_comm_lib();
MS_EXCEPTION_IF_NULL(device_comm_lib_instance_);
MS_LOG(INFO) << "Start initializing communication library on device side...";
@ -235,7 +237,7 @@ bool CollectiveManager::InitDeviceCommLib(const std::string &backend) {
return true;
}
bool CollectiveManager::AssignLocalRank(const std::string &global_group_name) {
bool CollectiveManager::AssignLocalRank() {
char host_name[MAX_HOSTNAME_LEN] = {0};
#ifndef _WIN32
if (gethostname(host_name, MAX_HOSTNAME_LEN) != 0) {
@ -259,8 +261,8 @@ bool CollectiveManager::AssignLocalRank(const std::string &global_group_name) {
MS_EXCEPTION_IF_NULL(host_comm_lib_instance_);
// AllGather host names across the global communication group.
if (!host_comm_lib_instance_->AllGather(&host_hash, all_host_hashs, sizeof(size_t), TypeId::kNumberTypeInt,
global_group_name, nullptr)) {
if (!host_comm_lib_instance_->AllGather(&host_hash, all_host_hashs, 1, TypeId::kNumberTypeInt,
host_global_group_name_, nullptr)) {
MS_LOG(ERROR) << "AllGather for host names failed.";
return false;
}
@ -274,7 +276,10 @@ bool CollectiveManager::AssignLocalRank(const std::string &global_group_name) {
local_rank_id_++;
}
}
MS_LOG(INFO) << "The local rank id assigned for this process is " << local_rank_id_;
MsContext::GetInstance()->set_param<uint32_t>(MS_CTX_DEVICE_ID, local_rank_id_);
MS_LOG(INFO) << "The local rank id assigned for this process is " << local_rank_id_
<< ". device_id of ms_context is set.";
return true;
}
} // namespace collective

View File

@ -43,8 +43,8 @@ class CollectiveManager {
DISABLE_COPY_AND_ASSIGN(CollectiveManager);
static std::shared_ptr<CollectiveManager> instance();
// Initialize the collective communication for distributed training with the backend name, e.g., NCCL or HCCL.
bool Initialize(const std::string &backend, const std::string &global_group_name);
// Initialize the collective communication for distributed training. The backend type is read from MindSpore context.
bool Initialize();
// Finalize the collective communication.
bool Finalize();
@ -63,8 +63,10 @@ class CollectiveManager {
// In some cases global rank id and rank size should be set by caller, e.g., when using MindSpore communication
// framework, they're generated by cluster::ClusterContext.
uint32_t set_global_rank_id();
uint32_t set_global_rank_size();
void set_global_rank_id(uint32_t global_rank_id);
void set_global_rank_size(uint32_t global_rank_size);
uint32_t local_rank_id() const;
private:
CollectiveManager();
@ -73,14 +75,17 @@ class CollectiveManager {
bool InitHostCommlib();
// Initialize communication library on device side.
bool InitDeviceCommLib(const std::string &backend);
bool InitDeviceCommLib();
// Assign the local rank id for this process.
bool AssignLocalRank(const std::string &global_group_name);
bool AssignLocalRank();
std::atomic_bool inited_;
std::atomic_bool finalized_;
// The device type read from MindSpore context.
std::string device_type_;
// The device context on both host and device side. They are used to access the communication library on different
// devices.
DeviceContext *host_ctx_;
@ -88,12 +93,10 @@ class CollectiveManager {
// Host communication library refers to the communication libaray for CPU, e.g., OpenMPI and MindSpore communication
// framework.
void *host_comm_lib_;
CollectiveCommunicationLib *host_comm_lib_instance_;
// Device communication library refers to the communication libaray for NPU or GPU, e.g., NCCL and HCCL.
// When only CPU backend is used, device communication library should not be initialized.
void *device_comm_lib_;
CollectiveCommunicationLib *device_comm_lib_instance_;
// The global rank id of this process. Normally this range is 0 to `total process number - 1`.
@ -107,6 +110,10 @@ class CollectiveManager {
// Global group ranks.
std::vector<uint32_t> global_group_ranks_;
// The global group name on the host side. This is used for Creating global group on host side for AllGather operation
// of host name while assigning local rank.
std::string host_global_group_name_;
};
} // namespace collective
} // namespace distributed

View File

@ -18,6 +18,7 @@
#define MINDSPORE_CCSRC_DISTRIBUTED_CONSTANTS_H_
#include <set>
#include <map>
#include <string>
namespace mindspore {

View File

@ -20,16 +20,29 @@
namespace mindspore {
namespace distributed {
bool Initialize(const std::string &backend, const std::string &global_group_name) {
bool Initialize() {
if (!InitializeCluster()) {
MS_LOG(ERROR) << "Failed to initialize cluster.";
return false;
}
if (!InitializeCollective(backend, global_group_name)) {
#if ((defined ENABLE_CPU) && (!defined _WIN32))
// Server and Scheduler don't use collective communication library.
auto node = cluster::ClusterContext::instance()->node();
MS_EXCEPTION_IF_NULL(node);
if (node->role() != ps::core::NodeRole::SERVER && node->role() != ps::core::NodeRole::SCHEDULER) {
// Global rank id and size should be manually set if cluster is initialized by MindSpore communication framework.
auto abstract_node = std::dynamic_pointer_cast<ps::core::AbstractNode>(cluster::ClusterContext::instance()->node());
MS_EXCEPTION_IF_NULL(abstract_node);
collective::CollectiveManager::instance()->set_global_rank_id(abstract_node->rank_id());
collective::CollectiveManager::instance()->set_global_rank_size(abstract_node->worker_num());
if (!InitializeCollective()) {
MS_LOG(ERROR) << "Failed to initialize collective communication.";
return false;
}
}
#endif
return true;
}
@ -51,9 +64,7 @@ bool InitializeCluster() { return cluster::ClusterContext::instance()->Initializ
bool FinalizeCluster() { return cluster::ClusterContext::instance()->Finalize(); }
bool InitializeCollective(const std::string &backend, const std::string &global_group_name) {
return collective::CollectiveManager::instance()->Initialize(backend, global_group_name);
}
bool InitializeCollective() { return collective::CollectiveManager::instance()->Initialize(); }
bool FinalizeCollective() { return collective::CollectiveManager::instance()->Finalize(); }
} // namespace distributed

View File

@ -31,7 +31,7 @@ namespace distributed {
// The static methods of MindSpore distributed execution. They can be exported by Pybind.
// Initialize and finalize distributed execution.
bool Initialize(const std::string &backend, const std::string &global_group_name);
bool Initialize();
bool Finalize();
// Initialize and finalize the cluster based on MindSpore communication framework.
@ -39,7 +39,7 @@ bool InitializeCluster();
bool FinalizeCluster();
// Initialize and finalize collective communication for distributed execution.
bool InitializeCollective(const std::string &backend, const std::string &global_group_name);
bool InitializeCollective();
bool FinalizeCollective();
} // namespace distributed
} // namespace mindspore

View File

@ -410,7 +410,8 @@ void PSContext::set_enable_ssl(bool enabled) { enable_ssl_ = enabled; }
core::ClusterConfig &PSContext::cluster_config() {
if (cluster_config_ == nullptr) {
MS_LOG(EXCEPTION) << "The cluster config is empty.";
cluster_config_ = std::make_unique<core::ClusterConfig>(worker_num_, server_num_, scheduler_host_, scheduler_port_);
MS_EXCEPTION_IF_NULL(cluster_config_);
}
return *cluster_config_;
}

View File

@ -16,6 +16,8 @@
#include "runtime/device/gpu/distribution/collective_init.h"
#include "utils/log_adapter.h"
#include "utils/ms_utils.h"
#include "distributed/init.h"
namespace mindspore {
namespace device {
@ -30,6 +32,7 @@ bool CollectiveInitializer::collective_inited() const { return collective_inited
const void *CollectiveInitializer::collective_handle() { return collective_handle_; }
void CollectiveInitializer::InitCollective() {
if (common::CheckUseMPI()) {
void *handle = dlopen("libgpu_collective.so", RTLD_LAZY);
if (handle == nullptr) {
MS_LOG(EXCEPTION)
@ -45,8 +48,16 @@ void CollectiveInitializer::InitCollective() {
MS_EXCEPTION_IF_NULL(mpi_init_funcptr);
(*mpi_init_funcptr)();
// Because this method InitCollective is static, the non-static member variables should be accessed by
// CollectiveInitializer::instance().
CollectiveInitializer::instance().use_mpi_ = true;
CollectiveInitializer::instance().collective_inited_ = true;
CollectiveInitializer::instance().collective_handle_ = handle;
} else {
if (!distributed::Initialize()) {
MS_LOG(EXCEPTION) << "Failed to initialize distributed execution for NCCL.";
}
}
}
void CollectiveInitializer::FinalizeCollective() {
@ -56,6 +67,69 @@ void CollectiveInitializer::FinalizeCollective() {
}
}
}
uint32_t CollectiveInitializer::local_rank_id() {
uint32_t local_rank_id;
if (common::CheckUseMPI()) {
MS_EXCEPTION_IF_NULL(collective_handle_);
auto get_local_rank_funcptr =
reinterpret_cast<GetLocalRankId>(dlsym(const_cast<void *>(collective_handle_), "local_rank_id"));
MS_EXCEPTION_IF_NULL(get_local_rank_funcptr);
local_rank_id = IntToUint((*get_local_rank_funcptr)());
} else {
local_rank_id = distributed::collective::CollectiveManager::instance()->local_rank_id();
}
return local_rank_id;
}
bool CollectiveInitializer::CreateCommunicationGroup(const std::string &group_name,
const std::vector<uint32_t> &group_ranks) {
if (common::CheckUseMPI()) {
return distributed::collective::CollectiveManager::instance()->CreateCommunicationGroup(group_name, group_ranks);
} else {
MS_EXCEPTION_IF_NULL(collective_handle_);
auto create_comm_group_funcptr =
reinterpret_cast<CreateCommGroupFunc>(dlsym(const_cast<void *>(collective_handle_), "CreateCommGroup"));
MS_EXCEPTION_IF_NULL(create_comm_group_funcptr);
return (*create_comm_group_funcptr)(group_name, group_ranks);
}
}
bool CollectiveInitializer::DestroyCommunicationGroup(const std::string &group_name) {
if (common::CheckUseMPI()) {
return distributed::collective::CollectiveManager::instance()->DestroyCommunicationGroup(group_name);
} else {
MS_EXCEPTION_IF_NULL(collective_handle_);
auto destroy_group_funcptr =
reinterpret_cast<DestroyGroupFunc>(dlsym(const_cast<void *>(collective_handle_), "DestroyGroup"));
MS_EXCEPTION_IF_NULL(destroy_group_funcptr);
return (*destroy_group_funcptr)(group_name);
}
}
uint32_t CollectiveInitializer::GetRankIDByGroup(const std::string &group_name) {
if (common::CheckUseMPI()) {
return distributed::collective::CollectiveManager::instance()->GetRankId(group_name);
} else {
MS_EXCEPTION_IF_NULL(collective_handle_);
auto get_rank_id_funcptr =
reinterpret_cast<GetRankIDByGroupFunc>(dlsym(const_cast<void *>(collective_handle_), "GetRankIDByGroup"));
MS_EXCEPTION_IF_NULL(get_rank_id_funcptr);
return IntToUint((*get_rank_id_funcptr)(group_name));
}
}
uint32_t CollectiveInitializer::GetGroupSize(const std::string &group_name) {
if (common::CheckUseMPI()) {
return distributed::collective::CollectiveManager::instance()->GetGroupSize(group_name);
} else {
MS_EXCEPTION_IF_NULL(collective_handle_);
auto get_group_size_funcptr =
reinterpret_cast<GetGroupSizeFunc>(dlsym(const_cast<void *>(collective_handle_), "GetGroupSize"));
MS_EXCEPTION_IF_NULL(get_group_size_funcptr);
return IntToUint((*get_group_size_funcptr)(group_name));
}
}
} // namespace gpu
} // namespace device
} // namespace mindspore

View File

@ -42,12 +42,20 @@ class CollectiveInitializer {
static void InitCollective();
static void FinalizeCollective();
// The capsulation of the collective communication APIs for compatibility.
uint32_t local_rank_id();
bool CreateCommunicationGroup(const std::string &group_name, const std::vector<uint32_t> &group_ranks);
bool DestroyCommunicationGroup(const std::string &group_name);
uint32_t GetRankIDByGroup(const std::string &group_name);
uint32_t GetGroupSize(const std::string &group_name);
private:
CollectiveInitializer() : collective_inited_(false) {}
CollectiveInitializer() : use_mpi_(false), collective_inited_(false), collective_handle_(nullptr) {}
~CollectiveInitializer() = default;
bool use_mpi_;
bool collective_inited_;
void *collective_handle_{nullptr};
void *collective_handle_;
};
} // namespace gpu
} // namespace device

View File

@ -23,6 +23,9 @@ endif()
if(ENABLE_CPU)
file(GLOB_RECURSE HARDWARE_CPU_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "cpu/*.cc")
list(REMOVE_ITEM HARDWARE_CPU_SRC_LIST "cpu/mpi_collective_comm_lib.cc" "cpu/mpi_communication_group.cc")
if(WIN32)
list(REMOVE_ITEM HARDWARE_CPU_SRC_LIST "cpu/ms_collective_comm_lib.cc")
endif()
if(ENABLE_MPI)
set(MPI_COLLECTIVE_SRCS "cpu/mpi_collective_comm_lib.cc"
"cpu/mpi_communication_group.cc"

View File

@ -60,6 +60,8 @@ CommunicationGroupPtr CollectiveCommunicationLib::GetGroup(const std::string &gr
return groups_[group_name];
}
const std::string &CollectiveCommunicationLib::global_group_name() const { return global_group_name_; }
uint32_t CollectiveCommunicationLib::global_rank_id() const { return global_rank_id_; }
uint32_t CollectiveCommunicationLib::local_rank_id() const { return local_rank_id_; }

View File

@ -77,6 +77,10 @@ class CollectiveCommunicationLib {
return true;
}
// Returns the global group name of this collective communication library. For NCCL, it's 'nccl_world_group'. For
// HCCL, it's 'hccl_world_group'.
const std::string &global_group_name() const;
// Returns global rank id of this process.
uint32_t global_rank_id() const;
@ -90,6 +94,9 @@ class CollectiveCommunicationLib {
// Whether this collective communication library is initialized.
bool initialized_;
// The global group name.
std::string global_group_name_;
// The global rank id of this process. Normally this range is 0 to `total process number - 1`.
uint32_t global_rank_id_;

View File

@ -47,5 +47,11 @@ uint32_t CommunicationGroup::GetGlobalRank(uint32_t group_rank) {
}
uint32_t CommunicationGroup::group_size() const { return size_; }
const std::vector<uint32_t> &CommunicationGroup::group_ranks() const { return group_ranks_; }
const std::map<uint32_t, uint32_t> &CommunicationGroup::global_to_group_ranks() const { return global_to_group_ranks_; }
const std::map<uint32_t, uint32_t> &CommunicationGroup::group_to_global_ranks() const { return group_to_global_ranks_; }
} // namespace device
} // namespace mindspore

View File

@ -55,6 +55,11 @@ class CommunicationGroup {
// Return the size of this communication group.
uint32_t group_size() const;
// Return group ranks info.
const std::vector<uint32_t> &group_ranks() const;
const std::map<uint32_t, uint32_t> &global_to_group_ranks() const;
const std::map<uint32_t, uint32_t> &group_to_global_ranks() const;
protected:
// Whether this communication group is initialized.
bool initialized_;

View File

@ -34,6 +34,9 @@
#include "backend/optimizer/graph_kernel/graph_kernel_optimization.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "profiler/device/cpu/cpu_profiling.h"
#if ((defined ENABLE_CPU) && (!defined _WIN32))
#include "runtime/hardware/cpu/ms_collective_comm_lib.h"
#endif
#ifndef ENABLE_SECURITY
#include "debug/data_dump/dump_json_parser.h"
#endif
@ -296,6 +299,32 @@ bool CPUDeviceContext::LaunchKernel(const CNodePtr &kernel, const std::vector<Ad
return DoLaunchKernel(kernel_mod, inputs, workspace, outputs);
}
bool CPUDeviceContext::LoadCollectiveCommLib() {
bool using_mpi = common::CheckUseMPI();
if (using_mpi) {
std::string mpi_comm_lib_name = "libmpi_collective.so";
auto loader = std::make_shared<CollectiveCommLibLoader>(mpi_comm_lib_name);
MS_EXCEPTION_IF_NULL(loader);
if (!loader->Initialize()) {
MS_LOG(EXCEPTION) << "Failed to load mpi collective library.";
return false;
}
void *collective_comm_lib_handle = loader->collective_comm_lib_ptr();
MS_EXCEPTION_IF_NULL(collective_comm_lib_handle);
auto instance_func = DlsymFuncObj(communication_lib_instance, collective_comm_lib_handle);
collective_comm_lib_ = instance_func();
MS_EXCEPTION_IF_NULL(collective_comm_lib_);
} else {
#if ((defined ENABLE_CPU) && (!defined _WIN32))
collective_comm_lib_ = &MsCollectiveCommLib::GetInstance();
MS_EXCEPTION_IF_NULL(collective_comm_lib_);
#endif
}
return true;
}
bool CPUDeviceContext::LaunchKernelWithProfiling(const CNodePtr &kernel, const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) const {

View File

@ -57,6 +57,8 @@ class CPUDeviceContext : public DeviceContext {
const std::vector<AddressPtr> &workspace, const std::vector<AddressPtr> &outputs,
bool is_dynamic_shape = false) const override;
bool LoadCollectiveCommLib() override;
private:
DISABLE_COPY_AND_ASSIGN(CPUDeviceContext);

View File

@ -19,6 +19,8 @@
namespace mindspore {
namespace device {
namespace cpu {
MPICollectiveCommLib::MPICollectiveCommLib() { global_group_name_ = kMPIGlobalGroupName; }
bool MPICollectiveCommLib::Initialize(uint32_t, uint32_t) {
if (initialized_) {
return false;
@ -56,49 +58,7 @@ bool MPICollectiveCommLib::CreateCommunicationGroup(const std::string &group_nam
}
} // namespace cpu
// The exported APIs for 'dlsym' to load.
using MPICollectiveCommLib = mindspore::device::cpu::MPICollectiveCommLib;
CollectiveCommunicationLib *communication_lib_instance() { return &MPICollectiveCommLib::GetInstance(); }
bool InitializeCollectiveLib(uint32_t, uint32_t) { return MPICollectiveCommLib::GetInstance().Initialize(); }
bool FinalizeCollectiveLib() { return MPICollectiveCommLib::GetInstance().Finalize(); }
bool CreateCommunicationGroup(const std::string &group_name, const std::vector<uint32_t> &group_ranks) {
return MPICollectiveCommLib::GetInstance().CreateCommunicationGroup(group_name, group_ranks);
}
bool DestroyCommunicationGroup(const std::string &group_name) {
return MPICollectiveCommLib::GetInstance().DestroyCommunicationGroup(group_name);
}
uint32_t GetRankId(const std::string &group_name) { return MPICollectiveCommLib::GetInstance().GetRankId(group_name); }
uint32_t GetCommunicationGroupSize(const std::string &group_name) {
return MPICollectiveCommLib::GetInstance().GetGroupSize(group_name);
}
bool AssignLocalRank() { return MPICollectiveCommLib::GetInstance().AssignLocalRank(); }
CommunicationGroupPtr GetGroup(const std::string &group_name) {
return MPICollectiveCommLib::GetInstance().GetGroup(group_name);
}
bool AllGather(const void *send_buff, void *recv_buff, size_t send_count, TypeId data_type,
const std::string &group_name, void *stream) {
return MPICollectiveCommLib::GetInstance().AllGather(send_buff, recv_buff, send_count, data_type, group_name, stream);
}
bool Broadcast(const void *send_buff, void *recv_buff, size_t send_count, mindspore::TypeId data_type,
uint32_t root_rank, const std::string &group_name, void *stream) {
return MPICollectiveCommLib::GetInstance().Broadcast(send_buff, recv_buff, send_count, data_type, root_rank,
group_name, stream);
}
uint32_t global_rank_id() { return MPICollectiveCommLib::GetInstance().global_rank_id(); }
uint32_t local_rank_id() { return MPICollectiveCommLib::GetInstance().local_rank_id(); }
uint32_t global_rank_size() { return MPICollectiveCommLib::GetInstance().global_rank_size(); }
} // namespace device
} // namespace mindspore

View File

@ -23,10 +23,15 @@
#include "runtime/hardware/collective/collective_communication_lib.h"
#include "runtime/hardware/cpu/mpi_communication_group.h"
#ifndef EXPORT_MPI_WRAPPER
#define EXPORT_MPI_WRAPPER __attribute__((visibility("default")))
#endif
namespace mindspore {
namespace device {
namespace cpu {
class MPICollectiveCommLib : public CollectiveCommunicationLib {
constexpr char kMPIGlobalGroupName[] = "mpi_world_group";
class EXPORT_MPI_WRAPPER MPICollectiveCommLib : public CollectiveCommunicationLib {
public:
static MPICollectiveCommLib &GetInstance() {
static MPICollectiveCommLib instance;
@ -49,35 +54,14 @@ class MPICollectiveCommLib : public CollectiveCommunicationLib {
}
private:
MPICollectiveCommLib() = default;
MPICollectiveCommLib();
~MPICollectiveCommLib() override = default;
MPI_Group world_group_;
};
} // namespace cpu
#ifndef EXPORT_MPI_WRAPPER
#define EXPORT_MPI_WRAPPER __attribute__((visibility("default")))
#endif
extern "C" EXPORT_MPI_WRAPPER CollectiveCommunicationLib *communication_lib_instance();
extern "C" EXPORT_MPI_WRAPPER bool InitializeCollectiveLib(uint32_t global_rank = UINT32_MAX,
uint32_t global_rank_size = UINT32_MAX);
extern "C" EXPORT_MPI_WRAPPER bool FinalizeCollectiveLib();
extern "C" EXPORT_MPI_WRAPPER bool CreateCommunicationGroup(const std::string &group_name,
const std::vector<uint32_t> &group_ranks);
extern "C" EXPORT_MPI_WRAPPER bool DestroyCommunicationGroup(const std::string &group_name);
extern "C" EXPORT_MPI_WRAPPER uint32_t GetRankId(const std::string &group_name);
extern "C" EXPORT_MPI_WRAPPER uint32_t GetGroupSize(const std::string &group_name);
extern "C" EXPORT_MPI_WRAPPER bool AssignLocalRank();
extern "C" EXPORT_MPI_WRAPPER CommunicationGroupPtr GetGroup(const std::string &group_name);
extern "C" EXPORT_MPI_WRAPPER bool AllGather(const void *send_buff, void *recv_buff, size_t send_count,
mindspore::TypeId data_type, const std::string &group_name, void *stream);
extern "C" EXPORT_MPI_WRAPPER bool Broadcast(const void *send_buff, void *recv_buff, size_t send_count,
mindspore::TypeId data_type, uint32_t root_rank,
const std::string &group_name, void *stream);
extern "C" EXPORT_MPI_WRAPPER uint32_t global_rank_id();
extern "C" EXPORT_MPI_WRAPPER uint32_t local_rank_id();
extern "C" EXPORT_MPI_WRAPPER uint32_t global_rank_size();
} // namespace device
} // namespace mindspore
#endif // MINDSPORE_CCSRC_RUNTIME_HARDWARE_CPU_MPI_COLLECTIVE_COMM_LIB_H_

View File

@ -19,6 +19,8 @@
namespace mindspore {
namespace device {
namespace cpu {
MsCollectiveCommLib::MsCollectiveCommLib() { global_group_name_ = kMSGlobalGroupName; }
bool MsCollectiveCommLib::Initialize(uint32_t global_rank, uint32_t global_rank_size) {
if (initialized_) {
return false;
@ -30,8 +32,6 @@ bool MsCollectiveCommLib::Initialize(uint32_t global_rank, uint32_t global_rank_
return true;
}
bool MsCollectiveCommLib::Finalize() { return true; }
bool MsCollectiveCommLib::CreateCommunicationGroup(const std::string &group_name,
const std::vector<uint32_t> &group_ranks) {
CHECK_RET((groups_.count(group_name) == 0), true, "The group " + group_name + " has already existed.");

View File

@ -22,10 +22,17 @@
#include <string>
#include "runtime/hardware/collective/collective_communication_lib.h"
#include "runtime/hardware/cpu/ms_communication_group.h"
#include "distributed/cluster/cluster_context.h"
#include "fl/server/collective_ops_impl.h"
namespace mindspore {
namespace device {
namespace cpu {
constexpr char kMSGlobalGroupName[] = "ms_world_group";
using ClusterContext = mindspore::distributed::cluster::ClusterContext;
using CollectiveOpsImpl = mindspore::fl::server::CollectiveOpsImpl;
using CommunicationGroupInfo = mindspore::fl::server::CommunicationGroupInfo;
// The collective communication library for MindSpore self developed communication framework.
class MsCollectiveCommLib : public CollectiveCommunicationLib {
public:
@ -35,12 +42,11 @@ class MsCollectiveCommLib : public CollectiveCommunicationLib {
}
bool Initialize(uint32_t global_rank = UINT32_MAX, uint32_t global_rank_size = UINT32_MAX) override;
bool Finalize() override;
bool CreateCommunicationGroup(const std::string &group_name, const std::vector<uint32_t> &group_ranks) override;
private:
MsCollectiveCommLib() {}
MsCollectiveCommLib();
~MsCollectiveCommLib() override = default;
};
} // namespace cpu

View File

@ -50,7 +50,7 @@ struct DeviceContextKey {
class DeviceContext {
public:
explicit DeviceContext(const DeviceContextKey &device_context_key)
: device_context_key_(device_context_key), collective_comm_lib_ptr_(nullptr) {}
: device_context_key_(device_context_key), collective_comm_lib_(nullptr) {}
virtual ~DeviceContext() = default;
// Initialize the device context.
@ -150,7 +150,7 @@ class DeviceContext {
virtual bool LoadCollectiveCommLib() { return true; }
// Return collective communication object for caller to access
void *collective_comm_lib() const { return collective_comm_lib_ptr_; }
CollectiveCommunicationLib *collective_comm_lib() const { return collective_comm_lib_; }
// TODO(jiaorui): will be delete
// Dump all graphs.
@ -159,8 +159,8 @@ class DeviceContext {
protected:
DeviceContextKey device_context_key_;
// The dynamically loaded handle for collective communication library by 'dlopen'.
void *collective_comm_lib_ptr_;
// The collective communication library.
CollectiveCommunicationLib *collective_comm_lib_;
};
using DeviceContextPtr = std::shared_ptr<DeviceContext>;
} // namespace device

View File

@ -534,8 +534,12 @@ bool GPUDeviceContext::LoadCollectiveCommLib() {
MS_LOG(EXCEPTION) << "Loading NCCL collective library failed.";
return false;
}
collective_comm_lib_ptr_ = loader->collective_comm_lib_ptr();
MS_EXCEPTION_IF_NULL(collective_comm_lib_ptr_);
void *collective_comm_lib_handle = loader->collective_comm_lib_ptr();
MS_EXCEPTION_IF_NULL(collective_comm_lib_handle);
auto instance_func = DlsymFuncObj(communication_lib_instance, collective_comm_lib_handle);
collective_comm_lib_ = instance_func();
MS_EXCEPTION_IF_NULL(collective_comm_lib_);
return true;
#else
return false;

View File

@ -19,6 +19,8 @@
namespace mindspore {
namespace device {
namespace gpu {
NvidiaCollectiveCommLib::NvidiaCollectiveCommLib() { global_group_name_ = kNCCLGlobalGroupName; }
bool NvidiaCollectiveCommLib::Initialize(uint32_t global_rank, uint32_t global_rank_size) {
if (initialized_) {
return false;
@ -42,50 +44,7 @@ bool NvidiaCollectiveCommLib::CreateCommunicationGroup(const std::string &group_
}
} // namespace gpu
// The exported APIs for 'dlsym' to load.
using NvidiaCollectiveCommLib = mindspore::device::gpu::NvidiaCollectiveCommLib;
CollectiveCommunicationLib *communication_lib_instance() { return &NvidiaCollectiveCommLib::GetInstance(); }
bool InitializeCollectiveLib(uint32_t global_rank, uint32_t global_rank_size) {
return NvidiaCollectiveCommLib::GetInstance().Initialize(global_rank, global_rank_size);
}
bool FinalizeCollectiveLib() { return NvidiaCollectiveCommLib::GetInstance().Finalize(); }
bool CreateCommunicationGroup(const std::string &group_name, const std::vector<uint32_t> &group_ranks) {
return NvidiaCollectiveCommLib::GetInstance().CreateCommunicationGroup(group_name, group_ranks);
}
bool DestroyCommunicationGroup(const std::string &group_name) {
return NvidiaCollectiveCommLib::GetInstance().DestroyCommunicationGroup(group_name);
}
uint32_t GetRankId(const std::string &group_name) {
return NvidiaCollectiveCommLib::GetInstance().GetRankId(group_name);
}
uint32_t GetCommunicationGroupSize(const std::string &group_name) {
return NvidiaCollectiveCommLib::GetInstance().GetGroupSize(group_name);
}
bool AssignLocalRank() { return NvidiaCollectiveCommLib::GetInstance().AssignLocalRank(); }
CommunicationGroupPtr GetGroup(const std::string &group_name) {
return NvidiaCollectiveCommLib::GetInstance().GetGroup(group_name);
}
bool AllGather(const void *send_buff, void *recv_buff, size_t send_count, mindspore::TypeId data_type,
const std::string &group_name, void *stream) {
return NvidiaCollectiveCommLib::GetInstance().AllGather(send_buff, recv_buff, send_count, data_type, group_name,
stream);
}
bool Broadcast(const void *send_buff, void *recv_buff, size_t send_count, mindspore::TypeId data_type,
uint32_t root_rank, const std::string &group_name, void *stream) {
return NvidiaCollectiveCommLib::GetInstance().Broadcast(send_buff, recv_buff, send_count, data_type, root_rank,
group_name, stream);
}
uint32_t local_rank_id() { return NvidiaCollectiveCommLib::GetInstance().local_rank_id(); }
} // namespace device
} // namespace mindspore

View File

@ -24,11 +24,15 @@
#include "runtime/hardware/collective/collective_communication_lib.h"
#include "runtime/hardware/gpu/nvidia_communication_group.h"
#ifndef EXPORT_NCCL_WRAPPER
#define EXPORT_NCCL_WRAPPER __attribute__((visibility("default")))
#endif
namespace mindspore {
namespace device {
namespace gpu {
constexpr char NCCL_WORLD_GROUP[] = "nccl_world_group";
class NvidiaCollectiveCommLib : public CollectiveCommunicationLib {
constexpr char kNCCLGlobalGroupName[] = "nccl_world_group";
class EXPORT_NCCL_WRAPPER NvidiaCollectiveCommLib : public CollectiveCommunicationLib {
public:
static NvidiaCollectiveCommLib &GetInstance() {
static NvidiaCollectiveCommLib instance;
@ -50,31 +54,12 @@ class NvidiaCollectiveCommLib : public CollectiveCommunicationLib {
}
private:
NvidiaCollectiveCommLib() = default;
NvidiaCollectiveCommLib();
~NvidiaCollectiveCommLib() override = default;
};
} // namespace gpu
#ifndef EXPORT_NCCL_WRAPPER
#define EXPORT_NCCL_WRAPPER __attribute__((visibility("default")))
#endif
extern "C" EXPORT_NCCL_WRAPPER CollectiveCommunicationLib *communication_lib_instance();
extern "C" EXPORT_NCCL_WRAPPER bool InitializeCollectiveLib(uint32_t global_rank = UINT32_MAX,
uint32_t global_rank_size = UINT32_MAX);
extern "C" EXPORT_NCCL_WRAPPER bool FinalizeCollectiveLib();
extern "C" EXPORT_NCCL_WRAPPER bool CreateCommunicationGroup(const std::string &group_name,
const std::vector<uint32_t> &group_ranks);
extern "C" EXPORT_NCCL_WRAPPER bool DestroyCommunicationGroup(const std::string &group_name);
extern "C" EXPORT_NCCL_WRAPPER uint32_t GetRankId(const std::string &group_name);
extern "C" EXPORT_NCCL_WRAPPER uint32_t GetCommunicationGroupSize(const std::string &group_name);
extern "C" EXPORT_NCCL_WRAPPER bool AssignLocalRank();
extern "C" EXPORT_NCCL_WRAPPER CommunicationGroupPtr GetGroup(const std::string &group_name);
extern "C" EXPORT_NCCL_WRAPPER bool AllGather(const void *send_buff, void *recv_buff, size_t send_count,
mindspore::TypeId data_type, const std::string &group_name, void *stream);
extern "C" EXPORT_NCCL_WRAPPER bool Broadcast(const void *send_buff, void *recv_buff, size_t send_count,
mindspore::TypeId data_type, uint32_t root_rank,
const std::string &group_name, void *stream);
extern "C" EXPORT_NCCL_WRAPPER uint32_t local_rank_id();
} // namespace device
} // namespace mindspore
#endif // MINDSPORE_CCSRC_RUNTIME_HARDWARE_CPU_NVIDIA_COLLECTIVE_COMM_LIB_H_

View File

@ -90,6 +90,16 @@ static inline bool IsLittleByteOrder() {
return false;
}
static inline bool CheckUseMPI() {
// If these OpenMPI environment variables are set, we consider this process is launched by OpenMPI.
std::string ompi_command_env = GetEnv("OMPI_COMMAND");
std::string pmix_rank_env = GetEnv("PMIX_RANK");
if (!ompi_command_env.empty() && !pmix_rank_env.empty()) {
return true;
}
return false;
}
template <typename T>
bool IsEqual(const std::shared_ptr<T> &a, const std::shared_ptr<T> &b) {
if (a == b) {