!27015 Use MindSpore communication framework as an alternative to OpenMPI

Merge pull request !27015 from ZPaC/dir-of-distributed
Author: i-robot, 2021-12-03 03:09:09 +00:00 (committed by Gitee)
Commit: f72bce0377
24 changed files with 192 additions and 153 deletions


@ -94,6 +94,7 @@ mindspore/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/deconv_winograd
mindspore/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/deconv_winograd_fp32.c:DeConvWgMerge
mindspore/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/intrinsics/avx/TiledC8MatMulFp32.c:TiledC8MatmulFp32
mindspore/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/quant_dtype_cast_fp16.c:Fp16ToInt8_arm64
mindspore/mindspore/ccsrc/backend/session/gpu_session.cc:mindspore::session::gpu::GPUSession::LoadInputData
mindspore/mindspore/ccsrc/debug/dump_proto.cc:mindspore::ProtoExporter::SetNodeOutputType
mindspore/mindspore/ccsrc/debug/dump_proto.cc:mindspore::ProtoExporter::SetValueToProto
mindspore/mindspore/ccsrc/debug/dump_proto.cc:mindspore::ProtoExporter::SetScalarToProto


@ -101,13 +101,8 @@ using GetLocalRankId = device::gpu::GetLocalRankId;
using InitNCCLComm = device::gpu::InitNCCLComm;
void GPUSession::Init(uint32_t device_id) {
const void *collective_handle_ = CollectiveInitializer::instance().collective_handle();
bool collective_inited = CollectiveInitializer::instance().collective_inited();
if (collective_inited && collective_handle_ != nullptr) {
auto get_local_rank_funcptr =
reinterpret_cast<GetLocalRankId>(dlsym(const_cast<void *>(collective_handle_), "local_rank_id"));
MS_EXCEPTION_IF_NULL(get_local_rank_funcptr);
device_id = IntToUint((*get_local_rank_funcptr)());
if (CollectiveInitializer::instance().collective_inited()) {
device_id = CollectiveInitializer::instance().local_rank_id();
}
bool ret = device::gpu::CudaDriver::SetDevice(UintToInt(device_id));
if (!ret) {
@ -116,11 +111,12 @@ void GPUSession::Init(uint32_t device_id) {
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
ms_context->set_param<uint32_t>(MS_CTX_DEVICE_ID, device_id);
if (collective_inited) {
if (collective_handle_ != nullptr) {
if (CollectiveInitializer::instance().collective_inited()) {
auto collective_handle = CollectiveInitializer::instance().collective_handle();
if (collective_handle != nullptr) {
MS_LOG(INFO) << "Start initializing NCCL communicator for device " << device_id;
auto init_nccl_comm_funcptr =
reinterpret_cast<InitNCCLComm>(dlsym(const_cast<void *>(collective_handle_), "InitNCCLComm"));
reinterpret_cast<InitNCCLComm>(dlsym(const_cast<void *>(collective_handle), "InitNCCLComm"));
MS_EXCEPTION_IF_NULL(init_nccl_comm_funcptr);
(*init_nccl_comm_funcptr)();
MS_LOG(INFO) << "End initializing NCCL communicator.";
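
Note: the hunk above folds the old dlsym("local_rank_id") lookup into CollectiveInitializer's own accessors. Below is a minimal, self-contained sketch of that pattern, not MindSpore code; the class name CollectiveState and its members are hypothetical stand-ins for CollectiveInitializer.

```cpp
// Minimal sketch of the accessor pattern adopted above (hypothetical class):
// callers query a singleton for the local rank instead of doing their own
// dlsym lookup.
#include <cstdint>
#include <iostream>

class CollectiveState {
 public:
  static CollectiveState &instance() {
    static CollectiveState inst;
    return inst;
  }
  void Init(uint32_t local_rank) {
    local_rank_ = local_rank;
    inited_ = true;
  }
  bool inited() const { return inited_; }
  uint32_t local_rank_id() const { return inited_ ? local_rank_ : 0; }

 private:
  CollectiveState() = default;
  bool inited_{false};
  uint32_t local_rank_{0};
};

int main() {
  CollectiveState::instance().Init(3);
  uint32_t device_id = 0;
  // Mirrors the new GPUSession::Init logic: override device_id only when
  // collective communication has been initialized.
  if (CollectiveState::instance().inited()) {
    device_id = CollectiveState::instance().local_rank_id();
  }
  std::cout << "device_id = " << device_id << std::endl;
  return 0;
}
```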


@ -30,7 +30,7 @@ ClusterContext::ClusterContext()
scheduler_host_(kLocalHost),
scheduler_port_(kDefaultSchedPort),
node_(nullptr),
node_role_(kEnvRoleOfWorker),
node_role_(""),
cluster_config_(nullptr) {}
ClusterContext::~ClusterContext() {
@ -59,7 +59,12 @@ bool ClusterContext::Initialize() {
// Step 2: Build network for this cluster. Every process will block in this method until networking is done.
if (!BuildCluster()) {
MS_LOG(EXCEPTION) << "Building networking for " << node_role_ << " failed.";
MS_EXCEPTION_IF_NULL(node_);
if (!node_->Stop()) {
MS_LOG(ERROR) << "Failed to stop node after the failure of BuildCluster";
return false;
}
MS_LOG(ERROR) << "Building networking for " << node_role_ << " failed.";
return false;
}
@ -87,10 +92,13 @@ bool ClusterContext::Finalize() {
const std::shared_ptr<ps::core::Node> &ClusterContext::node() const { return node_; }
bool ClusterContext::initialized() const { return inited_; }
void ClusterContext::InitClusterConfig() {
InitNodeRole();
InitSchedulerIp();
InitSchedulerPort();
ps::PSContext::instance()->set_ms_role(node_role_);
ps::PSContext::instance()->set_worker_num(node_num_each_role_[kEnvRoleOfWorker]);
ps::PSContext::instance()->set_server_num(node_num_each_role_[kEnvRoleOfServer]);
ps::PSContext::instance()->set_scheduler_ip(scheduler_host_);
@ -117,7 +125,7 @@ bool ClusterContext::BuildCluster() {
RegisterEventCallback();
if (!node_->Start()) {
MS_LOG(EXCEPTION) << "Building network failed.";
MS_LOG(ERROR) << "Building network failed.";
return false;
}
MS_LOG(INFO) << "Cluster is successfully initialized.";
@ -148,12 +156,7 @@ void ClusterContext::InitNodeRole() {
}
}
void ClusterContext::InitSchedulerIp() {
scheduler_host_ = common::GetEnv(kEnvSchedulerHost);
if (scheduler_host_ != kLocalHost) {
MS_LOG(EXCEPTION) << "Scheduler IP should be 127.0.0.1";
}
}
void ClusterContext::InitSchedulerIp() { scheduler_host_ = common::GetEnv(kEnvSchedulerHost); }
void ClusterContext::InitSchedulerPort() {
TRY_AND_CATCH_WITH_EXCEPTION((scheduler_port_ = static_cast<uint16_t>(std::stoi(common::GetEnv(kEnvSchedulerPort)))),
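
The Initialize()/BuildCluster() hunks above replace MS_LOG(EXCEPTION) with a stop-and-return-false path. A minimal sketch of that error-handling shape follows; the Node type, the InitializeCluster function, and the simulated failure are hypothetical and only illustrate the recoverable-error pattern, not the real ps::core::Node API.

```cpp
// Sketch of the recoverable-error shape used above: stop the node, log, and
// return false instead of throwing, so the caller decides how to recover.
#include <iostream>
#include <memory>
#include <string>

struct Node {
  bool Start() { return false; }  // pretend networking failed
  bool Stop() { return true; }
};

bool InitializeCluster(const std::shared_ptr<Node> &node, const std::string &role) {
  if (!node->Start()) {
    if (!node->Stop()) {
      std::cerr << "Failed to stop node after the failure of building the cluster\n";
      return false;
    }
    std::cerr << "Building networking for " << role << " failed.\n";
    return false;
  }
  return true;
}

int main() {
  auto node = std::make_shared<Node>();
  if (!InitializeCluster(node, "MS_WORKER")) {
    std::cerr << "Initialization failed; the caller can now exit gracefully.\n";
    return 1;
  }
  return 0;
}
```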


@ -52,6 +52,9 @@ class ClusterContext {
// Return node object of this process.
const std::shared_ptr<ps::core::Node> &node() const;
// Return cluster is initialized.
bool initialized() const;
private:
ClusterContext();


@ -32,6 +32,8 @@ std::shared_ptr<ClusterContext> ClusterContext::instance() {
bool ClusterContext::Initialize() const { return true; }
bool ClusterContext::Finalize() const { return true; }
bool ClusterContext::initialized() const { return false; }
} // namespace cluster
} // namespace distributed
} // namespace mindspore


@ -39,6 +39,7 @@ class ClusterContext {
bool Initialize() const;
bool Finalize() const;
bool initialized() const;
private:
ClusterContext() = default;


@ -52,6 +52,7 @@
#include "ps/worker.h"
#include "fl/worker/fl_worker.h"
#include "fl/server/server.h"
#include "distributed/cluster/cluster_context.h"
#endif
namespace mindspore {
@ -834,6 +835,15 @@ bool StartFLWorkerAction(const ResourcePtr &) {
}
bool StartPSServerAction(const ResourcePtr &res) {
if (distributed::cluster::ClusterContext::instance()->initialized()) {
MS_LOG(INFO) << "This node is a server. Start waiting for finalization.";
if (!distributed::cluster::ClusterContext::instance()->Finalize()) {
MS_LOG(ERROR) << "Failed to finalize server.";
return false;
}
MS_LOG(INFO) << "Server is successfully finalized.";
return true;
}
MS_EXCEPTION_IF_NULL(res);
FuncGraphPtr func_graph = res->func_graph();
auto &ps = ps::ParameterServer::GetInstance();
@ -927,6 +937,15 @@ bool StartServerAction(const ResourcePtr &res) {
}
bool StartPSSchedulerAction(const ResourcePtr &) {
if (distributed::cluster::ClusterContext::instance()->initialized()) {
MS_LOG(INFO) << "This node is the scheduler. Start waiting for finalization.";
if (!distributed::cluster::ClusterContext::instance()->Finalize()) {
MS_LOG(ERROR) << "Failed to finalize scheduler.";
return false;
}
MS_LOG(INFO) << "Scheduler is successfully finalized.";
return true;
}
ps::Scheduler::GetInstance().Run();
return true;
}
@ -1174,11 +1193,15 @@ std::vector<ActionItem> VmPipeline() {
(void)actions.emplace_back(std::make_pair("validate", ValidateAction));
#if ((defined ENABLE_CPU) && (!defined _WIN32))
if (ps::PSContext::instance()->is_worker()) {
std::string server_mode = ps::PSContext::instance()->server_mode();
if (server_mode == ps::kServerModeFL || server_mode == ps::kServerModeHybrid) {
(void)actions.emplace_back(std::make_pair("worker", StartFLWorkerAction));
if (distributed::cluster::ClusterContext::instance()->initialized()) {
MS_LOG(INFO) << "This worker is initialized. No need to add worker action.";
} else {
(void)actions.emplace_back(std::make_pair("worker", StartPSWorkerAction));
std::string server_mode = ps::PSContext::instance()->server_mode();
if (server_mode == ps::kServerModeFL || server_mode == ps::kServerModeHybrid) {
(void)actions.emplace_back(std::make_pair("worker", StartFLWorkerAction));
} else {
(void)actions.emplace_back(std::make_pair("worker", StartPSWorkerAction));
}
}
}
#endif
@ -1229,7 +1252,10 @@ std::vector<ActionItem> PServerPipeline() {
}
std::vector<ActionItem> PSchedulerPipeline() {
std::vector<ActionItem> actions;
auto actions = CommonPipeline();
(void)actions.emplace_back(std::make_pair("optimize", VmOptimizeAction));
(void)actions.emplace_back(std::make_pair("auto_monad_reorder", OrderEnforceAction));
(void)actions.emplace_back(std::make_pair("validate", ValidateAction));
(void)actions.emplace_back(std::make_pair("scheduler", StartPSSchedulerAction));
return actions;
}
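
PSchedulerPipeline() above now starts from CommonPipeline() and appends optimize, auto_monad_reorder, validate, and scheduler actions instead of holding a single action. Below is a rough sketch of that composition style with a trivial Action type and made-up action bodies; the real ActionItem holds ResourcePtr-based actions.

```cpp
// Sketch of the pipeline composition used above: build the shared front end
// once, then append role-specific steps.
#include <functional>
#include <iostream>
#include <string>
#include <utility>
#include <vector>

using Action = std::function<bool()>;
using ActionItem = std::pair<std::string, Action>;

std::vector<ActionItem> CommonPipeline() {
  return {{"parse", [] { return true; }},
          {"abstract_specialize", [] { return true; }}};
}

std::vector<ActionItem> SchedulerPipeline() {
  // Start from the shared front end, then append scheduler-specific steps.
  auto actions = CommonPipeline();
  actions.emplace_back("optimize", [] { return true; });
  actions.emplace_back("auto_monad_reorder", [] { return true; });
  actions.emplace_back("validate", [] { return true; });
  actions.emplace_back("scheduler", [] { return true; });
  return actions;
}

int main() {
  for (const auto &[name, act] : SchedulerPipeline()) {
    std::cout << name << ": " << (act() ? "ok" : "failed") << "\n";
  }
  return 0;
}
```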


@ -74,6 +74,7 @@
#include "ps/ps_cache/ps_cache_manager.h"
#include "fl/server/server.h"
#include "fl/worker/fl_worker.h"
#include "distributed/cluster/cluster_context.h"
#endif
#if ((defined ENABLE_GE) || (defined ENABLE_D))
@ -925,6 +926,19 @@ std::vector<ActionItem> GetPipeline(const ResourcePtr &resource, const std::stri
if (ps::PSContext::instance()->is_scheduler()) {
return PSchedulerPipeline();
}
if (distributed::cluster::ClusterContext::instance()->initialized()) {
auto node = distributed::cluster::ClusterContext::instance()->node();
MS_EXCEPTION_IF_NULL(node);
MS_LOG(INFO) << "Cluster is initialized. This node role is " << node->role();
switch (node->role()) {
case ps::core::NodeRole::SERVER:
return PServerPipeline();
case ps::core::NodeRole::SCHEDULER:
return PSchedulerPipeline();
default:
break;
}
}
#endif
if (use_vm && backend != "ge" && !is_air) {
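
The GetPipeline() hunk above adds a role switch driven by the cluster context. The following sketch shows the same dispatch shape with hypothetical names; the real functions return std::vector<ActionItem> pipelines rather than strings.

```cpp
// Sketch of the role-based pipeline dispatch added above: when the cluster
// context is initialized, the node's role picks the pipeline.
#include <iostream>
#include <string>

enum class NodeRole { WORKER, SERVER, SCHEDULER };

std::string ServerPipeline() { return "server pipeline"; }
std::string SchedulerPipeline() { return "scheduler pipeline"; }
std::string DefaultVmPipeline() { return "vm pipeline"; }

std::string SelectPipeline(bool cluster_initialized, NodeRole role) {
  if (cluster_initialized) {
    switch (role) {
      case NodeRole::SERVER:
        return ServerPipeline();
      case NodeRole::SCHEDULER:
        return SchedulerPipeline();
      default:
        break;  // workers fall through to the normal compile pipeline
    }
  }
  return DefaultVmPipeline();
}

int main() {
  std::cout << SelectPipeline(true, NodeRole::SCHEDULER) << "\n";
  std::cout << SelectPipeline(true, NodeRole::WORKER) << "\n";
  return 0;
}
```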


@ -21,6 +21,9 @@
#if ((defined ENABLE_CPU) && (!defined _WIN32))
#include "ps/ps_cache/ps_cache_manager.h"
#include "ps/ps_cache/ps_data/ps_data_prefetch.h"
#include "distributed/cluster/cluster_context.h"
#else
#include "distributed/cluster/dummy_cluster_context.h"
#endif
namespace mindspore {
@ -115,6 +118,9 @@ bool PSContext::is_worker() const {
if ((server_mode_ == kServerModeFL || server_mode_ == kServerModeHybrid) && ps_enabled_) {
return role_ == kEnvRoleOfWorker;
}
if (distributed::cluster::ClusterContext::instance()->initialized()) {
return role_ == kEnvRoleOfWorker;
}
return is_worker_;
}
@ -122,6 +128,9 @@ bool PSContext::is_server() const {
if ((server_mode_ == kServerModeFL || server_mode_ == kServerModeHybrid) && ps_enabled_) {
return role_ == kEnvRoleOfServer;
}
if (distributed::cluster::ClusterContext::instance()->initialized()) {
return role_ == kEnvRoleOfServer;
}
return is_pserver_;
}
@ -129,6 +138,9 @@ bool PSContext::is_scheduler() const {
if ((server_mode_ == kServerModeFL || server_mode_ == kServerModeHybrid) && ps_enabled_) {
return role_ == kEnvRoleOfScheduler;
}
if (distributed::cluster::ClusterContext::instance()->initialized()) {
return role_ == kEnvRoleOfScheduler;
}
return is_sched_;
}
@ -242,10 +254,6 @@ void PSContext::set_dp_norm_clip(float dp_norm_clip) {
float PSContext::dp_norm_clip() const { return dp_norm_clip_; }
void PSContext::set_ms_role(const std::string &role) {
if (server_mode_ != kServerModeFL && server_mode_ != kServerModeHybrid) {
MS_LOG(EXCEPTION) << "Only federated learning supports to set role by fl context.";
return;
}
if (role != kEnvRoleOfWorker && role != kEnvRoleOfServer && role != kEnvRoleOfScheduler) {
MS_LOG(EXCEPTION) << "ms_role " << role << " is invalid.";
return;
@ -263,13 +271,7 @@ void PSContext::set_worker_num(uint32_t worker_num) {
}
uint32_t PSContext::worker_num() const { return worker_num_; }
void PSContext::set_server_num(uint32_t server_num) {
if (server_num == 0) {
MS_LOG(EXCEPTION) << "Server number must be greater than 0.";
return;
}
server_num_ = server_num;
}
void PSContext::set_server_num(uint32_t server_num) { server_num_ = server_num; }
uint32_t PSContext::server_num() const { return server_num_; }
void PSContext::set_scheduler_ip(const std::string &sched_ip) { scheduler_host_ = sched_ip; }
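
is_worker()/is_server()/is_scheduler() above now also honor the role when the new cluster context is initialized, and set_ms_role()/set_server_num() drop their previous restrictions. A compressed sketch of the widened role check follows; RoleContext is hypothetical, and the two conditions are merged into one branch here for brevity, whereas the real PSContext keeps them separate.

```cpp
// Compressed sketch of the widened role checks in the hunks above.
#include <iostream>
#include <string>

struct RoleContext {
  bool fl_mode_enabled{false};      // FL/hybrid server mode with PS enabled
  bool cluster_initialized{false};  // new: distributed cluster context is up
  std::string role{"MS_WORKER"};
  bool legacy_is_worker{false};     // old environment-variable-driven flag

  bool is_worker() const {
    if (fl_mode_enabled || cluster_initialized) {
      return role == "MS_WORKER";
    }
    return legacy_is_worker;
  }
};

int main() {
  RoleContext ctx;
  ctx.cluster_initialized = true;
  std::cout << std::boolalpha << ctx.is_worker() << "\n";  // true
  return 0;
}
```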


@ -51,13 +51,13 @@ void CollectiveInitializer::InitCollective() {
// Because this method InitCollective is static, the non-static member variables should be accessed by
// CollectiveInitializer::instance().
CollectiveInitializer::instance().use_mpi_ = true;
CollectiveInitializer::instance().collective_inited_ = true;
CollectiveInitializer::instance().collective_handle_ = handle;
} else {
if (!distributed::Initialize()) {
MS_LOG(EXCEPTION) << "Failed to initialize distributed execution for NCCL.";
}
}
CollectiveInitializer::instance().collective_inited_ = true;
}
void CollectiveInitializer::FinalizeCollective() {
@ -85,49 +85,49 @@ uint32_t CollectiveInitializer::local_rank_id() {
bool CollectiveInitializer::CreateCommunicationGroup(const std::string &group_name,
const std::vector<uint32_t> &group_ranks) {
if (common::CheckUseMPI()) {
return distributed::collective::CollectiveManager::instance()->CreateCommunicationGroup(group_name, group_ranks);
} else {
MS_EXCEPTION_IF_NULL(collective_handle_);
auto create_comm_group_funcptr =
reinterpret_cast<CreateCommGroupFunc>(dlsym(const_cast<void *>(collective_handle_), "CreateCommGroup"));
MS_EXCEPTION_IF_NULL(create_comm_group_funcptr);
return (*create_comm_group_funcptr)(group_name, group_ranks);
} else {
return distributed::collective::CollectiveManager::instance()->CreateCommunicationGroup(group_name, group_ranks);
}
}
bool CollectiveInitializer::DestroyCommunicationGroup(const std::string &group_name) {
if (common::CheckUseMPI()) {
return distributed::collective::CollectiveManager::instance()->DestroyCommunicationGroup(group_name);
} else {
MS_EXCEPTION_IF_NULL(collective_handle_);
auto destroy_group_funcptr =
reinterpret_cast<DestroyGroupFunc>(dlsym(const_cast<void *>(collective_handle_), "DestroyGroup"));
MS_EXCEPTION_IF_NULL(destroy_group_funcptr);
return (*destroy_group_funcptr)(group_name);
} else {
return distributed::collective::CollectiveManager::instance()->DestroyCommunicationGroup(group_name);
}
}
uint32_t CollectiveInitializer::GetRankIDByGroup(const std::string &group_name) {
if (common::CheckUseMPI()) {
return distributed::collective::CollectiveManager::instance()->GetRankId(group_name);
} else {
MS_EXCEPTION_IF_NULL(collective_handle_);
auto get_rank_id_funcptr =
reinterpret_cast<GetRankIDByGroupFunc>(dlsym(const_cast<void *>(collective_handle_), "GetRankIDByGroup"));
MS_EXCEPTION_IF_NULL(get_rank_id_funcptr);
return IntToUint((*get_rank_id_funcptr)(group_name));
} else {
return distributed::collective::CollectiveManager::instance()->GetRankId(group_name);
}
}
uint32_t CollectiveInitializer::GetGroupSize(const std::string &group_name) {
if (common::CheckUseMPI()) {
return distributed::collective::CollectiveManager::instance()->GetGroupSize(group_name);
} else {
MS_EXCEPTION_IF_NULL(collective_handle_);
auto get_group_size_funcptr =
reinterpret_cast<GetGroupSizeFunc>(dlsym(const_cast<void *>(collective_handle_), "GetGroupSize"));
MS_EXCEPTION_IF_NULL(get_group_size_funcptr);
return IntToUint((*get_group_size_funcptr)(group_name));
} else {
return distributed::collective::CollectiveManager::instance()->GetGroupSize(group_name);
}
}
} // namespace gpu
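
This file now branches on common::CheckUseMPI(): the dlopen/dlsym-based collective library is kept for the OpenMPI path, while everything else goes through distributed::collective::CollectiveManager. The sketch below shows only that dispatch shape; CheckUseMPIEnv, the MS_USE_MPI variable name, and both backend functions are invented for illustration and are not the real implementations.

```cpp
// Sketch of the CheckUseMPI()-style dispatch this file now follows.
#include <cstdlib>
#include <iostream>
#include <string>
#include <vector>

bool CheckUseMPIEnv() {
  const char *v = std::getenv("MS_USE_MPI");  // hypothetical switch
  return v != nullptr && std::string(v) == "1";
}

bool CreateGroupWithMPI(const std::string &group, const std::vector<unsigned> &ranks) {
  std::cout << "create " << group << " (" << ranks.size() << " ranks) via the dlsym'ed MPI collective lib\n";
  return true;
}

bool CreateGroupWithBuiltinFramework(const std::string &group, const std::vector<unsigned> &ranks) {
  std::cout << "create " << group << " (" << ranks.size() << " ranks) via the built-in collective manager\n";
  return true;
}

bool CreateCommunicationGroup(const std::string &group, const std::vector<unsigned> &ranks) {
  if (CheckUseMPIEnv()) {
    return CreateGroupWithMPI(group, ranks);
  }
  return CreateGroupWithBuiltinFramework(group, ranks);
}

int main() { return CreateCommunicationGroup("nccl_world_group", {0, 1}) ? 0 : 1; }
```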


@ -95,15 +95,16 @@ bool GPUKernelRuntime::Init() {
mem_manager_ = std::make_shared<GPUMemoryManager>();
MS_EXCEPTION_IF_NULL(mem_manager_);
mem_manager_->MallocDeviceMemory();
const void *collective_handle_ = CollectiveInitializer::instance().collective_handle();
bool collective_inited = CollectiveInitializer::instance().collective_inited();
if (collective_inited && collective_handle_ != nullptr) {
MS_LOG(INFO) << "Start initializing NCCL communicator for device " << device_id_;
auto init_nccl_comm_funcptr =
reinterpret_cast<InitNCCLComm>(dlsym(const_cast<void *>(collective_handle_), "InitNCCLComm"));
MS_EXCEPTION_IF_NULL(init_nccl_comm_funcptr);
(*init_nccl_comm_funcptr)();
MS_LOG(INFO) << "End initializing NCCL communicator.";
if (CollectiveInitializer::instance().collective_inited()) {
auto collective_handle = CollectiveInitializer::instance().collective_handle();
if (collective_handle != nullptr) {
MS_LOG(INFO) << "Start initializing NCCL communicator for device " << device_id_;
auto init_nccl_comm_funcptr =
reinterpret_cast<InitNCCLComm>(dlsym(const_cast<void *>(collective_handle), "InitNCCLComm"));
MS_EXCEPTION_IF_NULL(init_nccl_comm_funcptr);
(*init_nccl_comm_funcptr)();
MS_LOG(INFO) << "End initializing NCCL communicator.";
}
}
device_init_ = true;


@ -22,11 +22,22 @@
#include <vector>
#include <string>
#include "ir/dtype/type_id.h"
#include "utils/check_convert_utils.h"
#include "runtime/hardware/collective/communication_group.h"
namespace mindspore {
namespace device {
// The reduce type of collective operations.
enum CollectiveOpReduceType : int64_t {
Reduce_Mean = 0,
Reduce_Max = 1,
Reduce_Min = 2,
Reduce_Prod = 3,
Reduce_Sum = 4,
Reduce_Sum_Square = 5,
Reduce_ASum = 6,
Reduce_All = 7
};
// The base class of collective communication library.
// For collective communication on the device side like GPU, the entry is NvidiaCollectiveCommLib which calls NCCL.
// For collective communication on the host side, the entry is MPICollectiveCommLib which call OpenMPI, or
@ -72,7 +83,7 @@ class CollectiveCommunicationLib {
return true;
}
virtual bool AllReduce(const void *send_buff, void *recv_buff, size_t send_count, TypeId data_type,
ReduceMode reduce_op, const std::string &group_name, void *stream = nullptr) {
CollectiveOpReduceType reduce_op, const std::string &group_name, void *stream = nullptr) {
return true;
}
virtual bool Broadcast(const void *send_buff, void *recv_buff, size_t send_count, TypeId data_type,
@ -80,7 +91,7 @@ class CollectiveCommunicationLib {
return true;
}
virtual bool ReduceScatter(const void *send_buff, void *recv_buff, size_t recv_count, TypeId data_type,
ReduceMode reduce_op, const std::string &group_name, void *stream = nullptr) {
CollectiveOpReduceType reduce_op, const std::string &group_name, void *stream = nullptr) {
return true;
}
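
The header above introduces CollectiveOpReduceType and threads it through AllReduce/ReduceScatter in place of ReduceMode. The sketch below shows how a backend might map the enum onto its native reduce ops and reject unsupported values, mirroring the kMPIReduceTypeMap/kNCCLReduceTypeMap tables later in this diff; the map contents and CheckReduceType are illustrative only, and only part of the enum is reproduced.

```cpp
// Sketch: translating the new CollectiveOpReduceType into a backend op and
// rejecting values the backend cannot express.
#include <cstdint>
#include <iostream>
#include <map>
#include <string>

enum CollectiveOpReduceType : int64_t {
  Reduce_Mean = 0,
  Reduce_Max = 1,
  Reduce_Min = 2,
  Reduce_Prod = 3,
  Reduce_Sum = 4,
};

const std::map<CollectiveOpReduceType, std::string> kBackendReduceOp = {
    {Reduce_Sum, "sum"}, {Reduce_Prod, "prod"}, {Reduce_Min, "min"}, {Reduce_Max, "max"}};

bool CheckReduceType(CollectiveOpReduceType op) {
  // Reduce_Mean has no counterpart in this sketch, so it is rejected, just as
  // CheckNCCLReduceType rejects unmapped values later in this diff.
  return kBackendReduceOp.count(op) != 0;
}

int main() {
  std::cout << std::boolalpha << CheckReduceType(Reduce_Sum) << " "
            << CheckReduceType(Reduce_Mean) << "\n";  // true false
  return 0;
}
```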


@ -44,8 +44,10 @@ const std::map<TypeId, MPI_Datatype> kMPIDataTypeMap = {{TypeId::kNumberTypeInt8
{TypeId::kNumberTypeFloat64, MPI_DOUBLE}};
// Map of reduce type to MPI reduce type.
const std::map<ReduceMode, MPI_Op> kMPIReduceTypeMap = {
{Reduce_Sum, MPI_SUM}, {Reduce_Prod, MPI_PROD}, {Reduce_Min, MPI_MIN}, {Reduce_Max, MPI_MAX}};
const std::map<CollectiveOpReduceType, MPI_Op> kMPIReduceTypeMap = {{CollectiveOpReduceType::Reduce_Sum, MPI_SUM},
{CollectiveOpReduceType::Reduce_Prod, MPI_PROD},
{CollectiveOpReduceType::Reduce_Min, MPI_MIN},
{CollectiveOpReduceType::Reduce_Max, MPI_MAX}};
constexpr char kMPIGlobalGroupName[] = "mpi_world_group";
class EXPORT_MPI_WRAPPER MPICollectiveCommLib : public CollectiveCommunicationLib {


@ -48,16 +48,16 @@ class MsCollectiveCommLib : public CollectiveCommunicationLib {
bool AllGather(const void *send_buff, void *recv_buff, size_t send_count, TypeId data_type,
const std::string &group_name, void *stream = nullptr) override;
bool AllReduce(const void *send_buff, void *recv_buff, size_t send_count, TypeId data_type, ReduceMode reduce_op,
const std::string &group_name, void *stream = nullptr) override {
bool AllReduce(const void *send_buff, void *recv_buff, size_t send_count, TypeId data_type,
CollectiveOpReduceType reduce_op, const std::string &group_name, void *stream = nullptr) override {
return true;
}
bool Broadcast(const void *send_buff, void *recv_buff, size_t send_count, TypeId data_type, uint32_t root_rank,
const std::string &group_name, void *stream = nullptr) override;
bool ReduceScatter(const void *send_buff, void *recv_buff, size_t recv_count, TypeId data_type, ReduceMode reduce_op,
const std::string &group_name, void *stream = nullptr) override {
bool ReduceScatter(const void *send_buff, void *recv_buff, size_t recv_count, TypeId data_type,
CollectiveOpReduceType reduce_op, const std::string &group_name, void *stream = nullptr) override {
return true;
}


@ -63,14 +63,9 @@ void GPUDeviceContext::Initialize() {
}
// Set device id
const void *collective_handle_ = CollectiveInitializer::instance().collective_handle();
bool collective_inited = CollectiveInitializer::instance().collective_inited();
if (collective_inited && collective_handle_ != nullptr) {
if (CollectiveInitializer::instance().collective_inited()) {
DeviceContextKey old_key = device_context_key_;
auto get_local_rank_funcptr =
reinterpret_cast<GetLocalRankId>(dlsym(const_cast<void *>(collective_handle_), "local_rank_id"));
MS_EXCEPTION_IF_NULL(get_local_rank_funcptr);
device_context_key_.device_id_ = IntToUint((*get_local_rank_funcptr)());
device_context_key_.device_id_ = CollectiveInitializer::instance().local_rank_id();
DeviceContextManager::GetInstance().UpdateDeviceContextKey(old_key, device_context_key_);
@ -91,13 +86,16 @@ void GPUDeviceContext::Initialize() {
mem_manager_->MallocDeviceMemory();
// Initialize NCCL.
if (collective_inited && collective_handle_ != nullptr) {
MS_LOG(INFO) << "Start initializing NCCL communicator for device " << device_context_key_.device_id_;
auto init_nccl_comm_funcptr =
reinterpret_cast<InitNCCLComm>(dlsym(const_cast<void *>(collective_handle_), "InitNCCLComm"));
MS_EXCEPTION_IF_NULL(init_nccl_comm_funcptr);
(*init_nccl_comm_funcptr)();
MS_LOG(INFO) << "End initializing NCCL communicator.";
if (CollectiveInitializer::instance().collective_inited()) {
auto collective_handle = CollectiveInitializer::instance().collective_handle();
if (collective_handle != nullptr) {
MS_LOG(INFO) << "Start initializing NCCL communicator for device " << device_context_key_.device_id_;
auto init_nccl_comm_funcptr =
reinterpret_cast<InitNCCLComm>(dlsym(const_cast<void *>(collective_handle), "InitNCCLComm"));
MS_EXCEPTION_IF_NULL(init_nccl_comm_funcptr);
(*init_nccl_comm_funcptr)();
MS_LOG(INFO) << "End initializing NCCL communicator.";
}
}
#ifndef ENABLE_SECURITY
@ -501,10 +499,9 @@ bool GPUDeviceContext::SyncStream(size_t stream_id) const {
}
uint32_t GPUDeviceContext::GetRankID() const {
const void *collective_handle_ = CollectiveInitializer::instance().collective_handle();
bool collective_inited = CollectiveInitializer::instance().collective_inited();
uint32_t rank_id = 0;
if (collective_inited && collective_handle_ != nullptr) {
if (collective_inited) {
if (!CommManager::GetInstance().GetRankID(kNcclWorldGroup, &rank_id)) {
MS_LOG(EXCEPTION) << "Failed to get rank id.";
}


@ -60,7 +60,7 @@ bool NvidiaCollectiveCommLib::AllGather(const void *send_buff, void *recv_buff,
}
bool NvidiaCollectiveCommLib::AllReduce(const void *send_buff, void *recv_buff, size_t send_count, TypeId data_type,
ReduceMode reduce_op, const std::string &group_name, void *stream) {
CollectiveOpReduceType reduce_op, const std::string &group_name, void *stream) {
if (!CheckNCCLDataType(data_type)) {
return false;
}
@ -96,7 +96,8 @@ bool NvidiaCollectiveCommLib::Broadcast(const void *send_buff, void *recv_buff,
}
bool NvidiaCollectiveCommLib::ReduceScatter(const void *send_buff, void *recv_buff, size_t recv_count, TypeId data_type,
ReduceMode reduce_op, const std::string &group_name, void *stream) {
CollectiveOpReduceType reduce_op, const std::string &group_name,
void *stream) {
if (!CheckNCCLDataType(data_type)) {
return false;
}
@ -153,7 +154,7 @@ bool NvidiaCollectiveCommLib::CheckNCCLDataType(TypeId data_type) {
return true;
}
bool NvidiaCollectiveCommLib::CheckNCCLReduceType(ReduceMode reduce_op) {
bool NvidiaCollectiveCommLib::CheckNCCLReduceType(CollectiveOpReduceType reduce_op) {
CHECK_RET((kNCCLReduceTypeMap.count(reduce_op) != 0), true,
"Reduce type " + std::to_string(reduce_op) + " is not supported in NCCL.");
return true;


@ -42,8 +42,11 @@ const std::map<TypeId, ncclDataType_t> kNCCLDataTypeMap = {
{TypeId::kNumberTypeFloat64, ncclFloat64}};
// Map of reduce type to NCCL reduce type.
const std::map<ReduceMode, ncclRedOp_t> kNCCLReduceTypeMap = {
{Reduce_Sum, ncclSum}, {Reduce_Prod, ncclProd}, {Reduce_Min, ncclMin}, {Reduce_Max, ncclMax}};
const std::map<CollectiveOpReduceType, ncclRedOp_t> kNCCLReduceTypeMap = {
{CollectiveOpReduceType::Reduce_Sum, ncclSum},
{CollectiveOpReduceType::Reduce_Prod, ncclProd},
{CollectiveOpReduceType::Reduce_Min, ncclMin},
{CollectiveOpReduceType::Reduce_Max, ncclMax}};
constexpr char kNCCLGlobalGroupName[] = "nccl_world_group";
@ -61,14 +64,14 @@ class EXPORT_NCCL_WRAPPER NvidiaCollectiveCommLib : public CollectiveCommunicati
bool AllGather(const void *send_buff, void *recv_buff, size_t send_count, TypeId data_type,
const std::string &group_name, void *stream = nullptr) override;
bool AllReduce(const void *send_buff, void *recv_buff, size_t send_count, TypeId data_type, ReduceMode reduce_op,
const std::string &group_name, void *stream = nullptr) override;
bool AllReduce(const void *send_buff, void *recv_buff, size_t send_count, TypeId data_type,
CollectiveOpReduceType reduce_op, const std::string &group_name, void *stream = nullptr) override;
bool Broadcast(const void *send_buff, void *recv_buff, size_t send_count, TypeId data_type, uint32_t root_rank,
const std::string &group_name, void *stream = nullptr) override;
bool ReduceScatter(const void *send_buff, void *recv_buff, size_t recv_count, TypeId data_type, ReduceMode reduce_op,
const std::string &group_name, void *stream = nullptr) override;
bool ReduceScatter(const void *send_buff, void *recv_buff, size_t recv_count, TypeId data_type,
CollectiveOpReduceType reduce_op, const std::string &group_name, void *stream = nullptr) override;
bool Send(const void *send_buff, size_t count, TypeId data_type, uint32_t peer, const std::string &group_name,
void *stream = nullptr) override;
@ -84,7 +87,7 @@ class EXPORT_NCCL_WRAPPER NvidiaCollectiveCommLib : public CollectiveCommunicati
bool CheckNCCLDataType(TypeId data_type);
// Check reduce type of collective operation is valid for NCCL.
bool CheckNCCLReduceType(ReduceMode reduce_op);
bool CheckNCCLReduceType(CollectiveOpReduceType reduce_op);
};
} // namespace gpu


@ -131,64 +131,36 @@ CommManager &CommManager::GetInstance() noexcept {
}
bool CommManager::CreateGroupSync(const string &group, const vector<unsigned int> &rank_id_list) const {
const void *collective_handle_ = CollectiveInitializer::instance().collective_handle();
if (!collective_handle_) {
MS_LOG(EXCEPTION) << "GPU collective handle is not initialized.";
}
MS_LOG(INFO) << "Create communication group " << group << " by rank id list " << rank_id_list;
auto create_comm_group_funcptr =
reinterpret_cast<CreateCommGroupFunc>(dlsym(const_cast<void *>(collective_handle_), "CreateCommGroup"));
MS_EXCEPTION_IF_NULL(create_comm_group_funcptr);
bool ret = (*create_comm_group_funcptr)(group, rank_id_list);
bool ret = CollectiveInitializer::instance().CreateCommunicationGroup(group, rank_id_list);
if (!ret) {
MS_LOG(ERROR) << "Creating group " << group << "for rank id list" << rank_id_list << "failed.";
MS_LOG(ERROR) << "Failed to create group " << group << " for rank id list " << rank_id_list;
return ret;
}
MS_LOG(INFO) << "Successfully create group " << group << " for rank id list " << rank_id_list;
return ret;
}
bool CommManager::GetRankID(const string &group, unsigned int *rank_id) const {
const void *collective_handle_ = CollectiveInitializer::instance().collective_handle();
if (!collective_handle_) {
MS_LOG(EXCEPTION) << "GPU collective handle is not initialized.";
}
auto get_rank_id_funcptr =
reinterpret_cast<GetRankIDByGroupFunc>(dlsym(const_cast<void *>(collective_handle_), "GetRankIDByGroup"));
MS_EXCEPTION_IF_NULL(get_rank_id_funcptr);
int rank = (*get_rank_id_funcptr)(group);
*rank_id = static_cast<unsigned int>(rank);
*rank_id = CollectiveInitializer::instance().GetRankIDByGroup(group);
MS_LOG(INFO) << "This process rank id is " << *rank_id << " in group " << group;
return true;
}
bool CommManager::GetRankSize(const string &group, unsigned int *rank_size) const {
const void *collective_handle_ = CollectiveInitializer::instance().collective_handle();
if (!collective_handle_) {
MS_LOG(EXCEPTION) << "GPU collective handle is not initialized.";
}
auto get_group_size_funcptr =
reinterpret_cast<GetGroupSizeFunc>(dlsym(const_cast<void *>(collective_handle_), "GetGroupSize"));
MS_EXCEPTION_IF_NULL(get_group_size_funcptr);
int size = (*get_group_size_funcptr)(group);
*rank_size = static_cast<unsigned int>(size);
*rank_size = CollectiveInitializer::instance().GetGroupSize(group);
MS_LOG(INFO) << "Group " << group << " size is " << *rank_size;
return true;
}
bool CommManager::DestroyGroup(const string &group) const {
const void *collective_handle_ = CollectiveInitializer::instance().collective_handle();
if (!collective_handle_) {
MS_LOG(EXCEPTION) << "GPU collective handle is not initialized.";
}
auto destroy_group_funcptr =
reinterpret_cast<DestroyGroupFunc>(dlsym(const_cast<void *>(collective_handle_), "DestroyGroup"));
MS_EXCEPTION_IF_NULL(destroy_group_funcptr);
bool ret = (*destroy_group_funcptr)(group);
bool ret = CollectiveInitializer::instance().DestroyCommunicationGroup(group);
if (!ret) {
MS_LOG(ERROR) << "Destroying group " << group << " failed.";
MS_LOG(ERROR) << "Failed to destroy group " << group;
return ret;
}
MS_LOG(INFO) << "Successfully destroy group " << group;
return ret;
}
#else
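
CommManager above no longer opens the collective handle itself; every group operation is forwarded to CollectiveInitializer, which already knows which backend is active. A stripped-down sketch of that delegation shape follows; all names are hypothetical and the facade here just prints instead of talking to a backend.

```cpp
// Sketch of the delegation shape above: the manager keeps no backend
// knowledge and simply forwards to one facade object.
#include <iostream>
#include <string>
#include <vector>

class CollectiveFacade {
 public:
  static CollectiveFacade &instance() {
    static CollectiveFacade inst;
    return inst;
  }
  bool CreateCommunicationGroup(const std::string &group, const std::vector<unsigned> &ranks) {
    std::cout << "create group " << group << " with " << ranks.size() << " ranks\n";
    return true;
  }
  unsigned GetRankIDByGroup(const std::string &) { return 0; }
  unsigned GetGroupSize(const std::string &) { return 1; }
  bool DestroyCommunicationGroup(const std::string &) { return true; }
};

// CommManager-style wrapper: thin and backend-agnostic.
bool CreateGroupSync(const std::string &group, const std::vector<unsigned> &ranks) {
  bool ret = CollectiveFacade::instance().CreateCommunicationGroup(group, ranks);
  if (!ret) {
    std::cerr << "Failed to create group " << group << "\n";
  }
  return ret;
}

int main() { return CreateGroupSync("nccl_world_group", {0, 1, 2, 3}) ? 0 : 1; }
```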


@ -32,7 +32,7 @@ from .tensor import Tensor as MsTensor
from .tensor import CSRTensor as MsCSRTensor
from .._c_expression import generate_arguments_key, GraphExecutor_, Tensor, MetaTensor, CSRTensor, PynativeExecutor_
from .._c_expression import verify_inputs_signature, init_exec_dataset, _set_dataset_mode_config, init_pipeline
from ..parallel._ps_context import _is_role_pserver
from ..parallel._ps_context import _is_role_pserver, _is_role_sched
from ..parallel._utils import _get_device_num, _get_global_rank, _need_to_full, _check_full_batch, _to_full_tensor, \
_get_parameter_broadcast, _get_pipeline_stages
from .._checkparam import Validator
@ -737,7 +737,7 @@ class _CellGraphExecutor:
return self._graph_executor.has_compiled(phase)
def __call__(self, obj, *args, phase='predict'):
if context.get_context("precompile_only") or _is_role_pserver():
if context.get_context("precompile_only") or _is_role_pserver() or _is_role_sched():
return None
return self.run(obj, *args, phase=phase)


@ -929,18 +929,23 @@ def set_ps_context(**kwargs):
MS_WORKER: represents the worker,
MS_PSERVER: represents the Server
MS_PSERVER/MS_SERVER: represents the Server
Args:
enable_ps (bool): Whether to enable parameter server training mode.
Only after enable_ps is set True, the environment variables will be effective.
Default: False.
config_file_path (string): Configuration file path used by recovery. Default: ''.
scheduler_manage_port (int): scheduler manage port used to scale out/in. Default: 11202.
enable_ssl (bool): Set PS SSL mode enabled or disabled. Default: true.
client_password (str): Password to decrypt the secret key stored in the client certificate.
server_password (str): Password to decrypt the secret key stored in the server certificate.
Raises:
ValueError: If input key is not the attribute in parameter server training mode context.
Examples:
>>> context.set_ps_context(enable_ps=True)
>>> context.set_ps_context(enable_ps=True, enable_ssl=True, client_password='123456', server_password='123456')
"""
_set_ps_context(**kwargs)


@ -27,13 +27,13 @@
ClassType(const ClassType &) = delete; \
ClassType &operator=(const ClassType &) = delete;
#define TRY_AND_CATCH_WITH_EXCEPTION(expr, error_msg) \
do { \
try { \
(expr); \
} catch (const std::exception &e) { \
MS_LOG(EXCEPTION) << "Caught exception " << e.what() << ". " << error_msg; \
} \
#define TRY_AND_CATCH_WITH_EXCEPTION(expr, error_msg) \
do { \
try { \
(expr); \
} catch (const std::exception &e) { \
MS_LOG(EXCEPTION) << "Caught exception of " << e.what() << ". " << error_msg; \
} \
} while (0)
namespace mindspore {
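
The TRY_AND_CATCH_WITH_EXCEPTION hunk above only reflows the macro and tweaks its message. For readers unfamiliar with it, here is a self-contained sketch of how such a macro is used; std::cerr plus abort() stands in for MS_LOG(EXCEPTION), and the stoi call mirrors the scheduler-port parsing wrapped by this macro earlier in the diff.

```cpp
// Sketch: a TRY_AND_CATCH_WITH_EXCEPTION-style macro wraps a statement that
// may throw and turns the exception into a logged fatal error with context.
#include <cstdint>
#include <cstdlib>
#include <iostream>
#include <stdexcept>
#include <string>

#define TRY_AND_CATCH_WITH_EXCEPTION(expr, error_msg)                       \
  do {                                                                      \
    try {                                                                   \
      (expr);                                                               \
    } catch (const std::exception &e) {                                     \
      std::cerr << "Caught exception of " << e.what() << ". " << (error_msg)\
                << std::endl;                                               \
      std::abort();                                                         \
    }                                                                       \
  } while (0)

int main() {
  uint16_t port = 0;
  // std::stoi throws std::invalid_argument for a non-numeric string; the
  // macro converts that into a logged fatal error with the extra message.
  TRY_AND_CATCH_WITH_EXCEPTION((port = static_cast<uint16_t>(std::stoi("8080"))),
                               "The scheduler port is invalid.");
  std::cout << "port = " << port << std::endl;
  return 0;
}
```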


@ -157,12 +157,17 @@ def _set_ps_context(**kwargs):
enable_ps (bool): Whether to enable parameter server training mode.
Only after enable_ps is set True, the environment variables will be effective.
Default: False.
config_file_path (string): Configuration file path used by recovery. Default: ''.
scheduler_manage_port (int): scheduler manage port used to scale out/in. Default: 11202.
enable_ssl (bool): Set PS SSL mode enabled or disabled. Default: true.
client_password (str): Password to decrypt the secret key stored in the client certificate.
server_password (str): Password to decrypt the secret key stored in the server certificate.
Raises:
ValueError: If input key is not the attribute in parameter server training mode context.
Examples:
>>> context.set_ps_context(enable_ps=True)
>>> context.set_ps_context(enable_ps=True, enable_ssl=True, client_password='123456', server_password='123456')
"""
for key, value in kwargs.items():
if key not in _set_ps_context_func_map:


@ -14,13 +14,13 @@
# ============================================================================
"""Dataset help for minddata dataset"""
import math
import os
from mindspore._checkparam import Validator
from mindspore.common.dtype import pytype_to_dtype
from .. import context, nn
from ._utils import _exec_datagraph, _get_types_and_shapes, _construct_tensor_list
from ..parallel._utils import _get_device_num, _get_global_rank, _need_to_full, _to_full_shapes, _get_pipeline_stages
from ..parallel._ps_context import _is_role_worker, _is_role_pserver, _is_role_sched
from ..ops import operations as P
@ -52,10 +52,9 @@ def _dynamic_sink_data(dataset, dataset_iter):
def _dynamic_sink_exception_scenario(dataset_iter):
"""The exception scenario for dynamic data is not applicable."""
ms_role = os.getenv("MS_ROLE")
_, dataset_shapes = dataset_iter.types_shapes()
if _has_dynamic_shape(dataset_shapes) or ms_role == "MS_WORKER" or \
if _has_dynamic_shape(dataset_shapes) or _is_role_worker() or \
context.get_context("mode") != context.GRAPH_MODE:
return True
return False
@ -171,8 +170,7 @@ def connect_network_with_dataset(network, dataset_helper):
if isinstance(dataset_iter, _DatasetIterNormal):
raise RuntimeError("The API 'connect_network_with_dataset' should be called in dataset sink mode.")
ms_role = os.getenv("MS_ROLE")
if ms_role in ("MS_PSERVER", "MS_SCHED"):
if _is_role_sched() or _is_role_pserver():
return network
queue_name = dataset.__transfer_dataset__.queue_name
@ -260,10 +258,9 @@ class DatasetHelper:
iterclass = _DatasetIterGE
else:
if context.get_context("mode") == context.GRAPH_MODE:
ms_role = os.getenv("MS_ROLE")
if ms_role in ("MS_PSERVER", "MS_SCHED"):
if _is_role_sched() or _is_role_pserver():
iterclass = _DatasetIterPSServer
elif ms_role == "MS_WORKER":
elif _is_role_worker():
iterclass = _DatasetIterPSWork
elif (context.get_context("device_target") == "Ascend") or \
(context.get_context("device_target") == "GPU"):
@ -365,9 +362,8 @@ class _DatasetIter:
if not hasattr(dataset, '__transfer_dataset__'):
if hasattr(dataset, '__loop_size__'):
ms_role = os.getenv("MS_ROLE")
# PS mode does not support loop sink and need get the real sink size.
if ms_role != "MS_WORKER":
if not _is_role_worker():
self.sink_size = dataset.__loop_size__
create_data_info_queue = (sink_size == 1 and self.sink_count == 1 and context.get_context(
"device_target") == "Ascend")
@ -413,10 +409,9 @@ class _DatasetIter:
def get_sink_size(self):
"""get sink_size to device"""
sink_size = 1
ms_role = os.getenv("MS_ROLE")
if hasattr(self.dataset, '__loop_size__'):
sink_size = self.dataset.__loop_size__
elif ms_role == "MS_WORKER":
elif _is_role_worker():
# PS mode does not support loop sink.
sink_size = 1
else:


@ -14,12 +14,12 @@
# ============================================================================
"""Dataset help for minddata dataset"""
import math
import os
from mindspore._checkparam import Validator
from mindspore import context
from mindspore.train._utils import _exec_datagraph, _get_types_and_shapes
from mindspore.nn.wrap import GetNextSingleOp
from mindspore.parallel._utils import _get_device_num, _need_to_full, _to_full_shapes
from mindspore.parallel._ps_context import _is_role_pserver, _is_role_sched
def _send_data(dataset, epoch_num):
@ -160,8 +160,7 @@ class _DatasetIterMSLoopSink(_DatasetIter):
loop_size = dataset.__loop_size__ + iter_first_order
sink_count = int(sink_size / loop_size) * 2
self.sink_count = sink_count
ms_role = os.getenv("MS_ROLE")
if ms_role in ("MS_PSERVER", "MS_SCHED"):
if _is_role_pserver() or _is_role_sched():
self.sink_count = 1
# for self._parallel_mode equal to semi_auto_parallel or auto_parallel, and not using full_batch,
# use a complete tensor to compile, and slice tensor to run. The batch dimension of tensors for