forked from mindspore-Ecosystem/mindspore
!27015 Use MindSpore communication framework as OpenMPI
Merge pull request !27015 from ZPaC/dir-of-distributed
This commit is contained in: commit f72bce0377
@@ -94,6 +94,7 @@ mindspore/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/deconv_winograd
 mindspore/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/deconv_winograd_fp32.c:DeConvWgMerge
 mindspore/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/intrinsics/avx/TiledC8MatMulFp32.c:TiledC8MatmulFp32
 mindspore/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/quant_dtype_cast_fp16.c:Fp16ToInt8_arm64
+mindspore/mindspore/ccsrc/backend/session/gpu_session.cc:mindspore::session::gpu::GPUSession::LoadInputData
 mindspore/mindspore/ccsrc/debug/dump_proto.cc:mindspore::ProtoExporter::SetNodeOutputType
 mindspore/mindspore/ccsrc/debug/dump_proto.cc:mindspore::ProtoExporter::SetValueToProto
 mindspore/mindspore/ccsrc/debug/dump_proto.cc:mindspore::ProtoExporter::SetScalarToProto
@@ -101,13 +101,8 @@ using GetLocalRankId = device::gpu::GetLocalRankId;
 using InitNCCLComm = device::gpu::InitNCCLComm;

 void GPUSession::Init(uint32_t device_id) {
-  const void *collective_handle_ = CollectiveInitializer::instance().collective_handle();
-  bool collective_inited = CollectiveInitializer::instance().collective_inited();
-  if (collective_inited && collective_handle_ != nullptr) {
-    auto get_local_rank_funcptr =
-      reinterpret_cast<GetLocalRankId>(dlsym(const_cast<void *>(collective_handle_), "local_rank_id"));
-    MS_EXCEPTION_IF_NULL(get_local_rank_funcptr);
-    device_id = IntToUint((*get_local_rank_funcptr)());
+  if (CollectiveInitializer::instance().collective_inited()) {
+    device_id = CollectiveInitializer::instance().local_rank_id();
   }
   bool ret = device::gpu::CudaDriver::SetDevice(UintToInt(device_id));
   if (!ret) {
@@ -116,11 +111,12 @@ void GPUSession::Init(uint32_t device_id) {
   auto ms_context = MsContext::GetInstance();
   MS_EXCEPTION_IF_NULL(ms_context);
   ms_context->set_param<uint32_t>(MS_CTX_DEVICE_ID, device_id);
-  if (collective_inited) {
-    if (collective_handle_ != nullptr) {
+  if (CollectiveInitializer::instance().collective_inited()) {
+    auto collective_handle = CollectiveInitializer::instance().collective_handle();
+    if (collective_handle != nullptr) {
       MS_LOG(INFO) << "Start initializing NCCL communicator for device " << device_id;
       auto init_nccl_comm_funcptr =
-        reinterpret_cast<InitNCCLComm>(dlsym(const_cast<void *>(collective_handle_), "InitNCCLComm"));
+        reinterpret_cast<InitNCCLComm>(dlsym(const_cast<void *>(collective_handle), "InitNCCLComm"));
       MS_EXCEPTION_IF_NULL(init_nccl_comm_funcptr);
       (*init_nccl_comm_funcptr)();
       MS_LOG(INFO) << "End initializing NCCL communicator.";
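The two hunks above replace per-call-site dlsym lookups with queries on the CollectiveInitializer singleton. The following minimal sketch shows the shape of that accessor pattern; the class below is a hypothetical stand-in for illustration, not MindSpore's actual CollectiveInitializer.

#include <cstdint>

// Hypothetical stand-in: the symbol lookup happens once during initialization,
// and call sites only read cached state afterwards.
class FakeCollectiveInitializer {
 public:
  static FakeCollectiveInitializer &instance() {
    static FakeCollectiveInitializer inst;
    return inst;
  }
  bool collective_inited() const { return inited_; }
  uint32_t local_rank_id() const { return local_rank_id_; }
  // In the real class this would resolve "local_rank_id" from the loaded library.
  void Init(uint32_t rank) {
    inited_ = true;
    local_rank_id_ = rank;
  }

 private:
  bool inited_ = false;
  uint32_t local_rank_id_ = 0;
};

int main() {
  FakeCollectiveInitializer::instance().Init(3);
  uint32_t device_id = 0;
  // The call site shrinks to the two lines the diff introduces:
  if (FakeCollectiveInitializer::instance().collective_inited()) {
    device_id = FakeCollectiveInitializer::instance().local_rank_id();
  }
  return device_id == 3 ? 0 : 1;
}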
@@ -30,7 +30,7 @@ ClusterContext::ClusterContext()
     scheduler_host_(kLocalHost),
     scheduler_port_(kDefaultSchedPort),
     node_(nullptr),
-    node_role_(kEnvRoleOfWorker),
+    node_role_(""),
     cluster_config_(nullptr) {}

 ClusterContext::~ClusterContext() {
@@ -59,7 +59,12 @@ bool ClusterContext::Initialize() {

   // Step 2: Build network for this cluster. Every process will block in this method until networking is done.
   if (!BuildCluster()) {
-    MS_LOG(EXCEPTION) << "Building networking for " << node_role_ << " failed.";
+    MS_EXCEPTION_IF_NULL(node_);
+    if (!node_->Stop()) {
+      MS_LOG(ERROR) << "Failed to stop node after the failure of BuildCluster";
+      return false;
+    }
+    MS_LOG(ERROR) << "Building networking for " << node_role_ << " failed.";
+    return false;
   }
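This hunk turns a fatal MS_LOG(EXCEPTION) into a recoverable failure: the node is stopped first, then the error is logged and false is returned so the caller can decide what to do. Below is a minimal sketch of that stop-then-report shape, using stand-in types rather than ps::core::Node.

#include <iostream>
#include <memory>

struct Node {  // stand-in for ps::core::Node
  bool Stop() { return true; }
};

bool BuildCluster() { return false; }  // assume networking failed for the demo

bool Initialize(const std::shared_ptr<Node> &node) {
  if (!BuildCluster()) {
    // Stop the node before reporting, mirroring the order in the diff.
    if (node == nullptr || !node->Stop()) {
      std::cerr << "Failed to stop node after the failure of BuildCluster\n";
      return false;
    }
    std::cerr << "Building networking failed.\n";
    return false;
  }
  return true;
}

int main() {
  std::cout << (Initialize(std::make_shared<Node>()) ? "initialized" : "failed, shut down cleanly") << "\n";
  return 0;
}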
@@ -87,10 +92,13 @@ bool ClusterContext::Finalize() {

 const std::shared_ptr<ps::core::Node> &ClusterContext::node() const { return node_; }

+bool ClusterContext::initialized() const { return inited_; }
+
 void ClusterContext::InitClusterConfig() {
   InitNodeRole();
   InitSchedulerIp();
   InitSchedulerPort();
   ps::PSContext::instance()->set_ms_role(node_role_);
   ps::PSContext::instance()->set_worker_num(node_num_each_role_[kEnvRoleOfWorker]);
   ps::PSContext::instance()->set_server_num(node_num_each_role_[kEnvRoleOfServer]);
   ps::PSContext::instance()->set_scheduler_ip(scheduler_host_);
@@ -117,7 +125,7 @@ bool ClusterContext::BuildCluster() {

   RegisterEventCallback();
   if (!node_->Start()) {
-    MS_LOG(EXCEPTION) << "Building network failed.";
+    MS_LOG(ERROR) << "Building network failed.";
     return false;
   }
   MS_LOG(INFO) << "Cluster is successfully initialized.";
@@ -148,12 +156,7 @@ void ClusterContext::InitNodeRole() {
   }
 }

-void ClusterContext::InitSchedulerIp() {
-  scheduler_host_ = common::GetEnv(kEnvSchedulerHost);
-  if (scheduler_host_ != kLocalHost) {
-    MS_LOG(EXCEPTION) << "Scheduler IP should be 127.0.0.1";
-  }
-}
+void ClusterContext::InitSchedulerIp() { scheduler_host_ = common::GetEnv(kEnvSchedulerHost); }

 void ClusterContext::InitSchedulerPort() {
   TRY_AND_CATCH_WITH_EXCEPTION((scheduler_port_ = static_cast<uint16_t>(std::stoi(common::GetEnv(kEnvSchedulerPort)))),
@@ -52,6 +52,9 @@ class ClusterContext {
   // Return node object of this process.
   const std::shared_ptr<ps::core::Node> &node() const;

+  // Return cluster is initialized.
+  bool initialized() const;
+
  private:
   ClusterContext();
@@ -32,6 +32,8 @@ std::shared_ptr<ClusterContext> ClusterContext::instance() {
 bool ClusterContext::Initialize() const { return true; }

 bool ClusterContext::Finalize() const { return true; }
+
+bool ClusterContext::initialized() const { return false; }
 }  // namespace cluster
 }  // namespace distributed
 }  // namespace mindspore
@@ -39,6 +39,7 @@ class ClusterContext {

   bool Initialize() const;
   bool Finalize() const;
+  bool initialized() const;

  private:
   ClusterContext() = default;
@@ -52,6 +52,7 @@
 #include "ps/worker.h"
 #include "fl/worker/fl_worker.h"
 #include "fl/server/server.h"
+#include "distributed/cluster/cluster_context.h"
 #endif

 namespace mindspore {
@@ -834,6 +835,15 @@ bool StartFLWorkerAction(const ResourcePtr &) {
 }

 bool StartPSServerAction(const ResourcePtr &res) {
+  if (distributed::cluster::ClusterContext::instance()->initialized()) {
+    MS_LOG(INFO) << "This node is server. Start wait for finalizing.";
+    if (!distributed::cluster::ClusterContext::instance()->Finalize()) {
+      MS_LOG(ERROR) << "Failed to finalize server.";
+      return false;
+    }
+    MS_LOG(INFO) << "Server is successfully finalized.";
+    return true;
+  }
   MS_EXCEPTION_IF_NULL(res);
   FuncGraphPtr func_graph = res->func_graph();
   auto &ps = ps::ParameterServer::GetInstance();
@@ -927,6 +937,15 @@ bool StartServerAction(const ResourcePtr &res) {
 }

 bool StartPSSchedulerAction(const ResourcePtr &) {
+  if (distributed::cluster::ClusterContext::instance()->initialized()) {
+    MS_LOG(INFO) << "This node is scheduler. Start wait for finalizing.";
+    if (!distributed::cluster::ClusterContext::instance()->Finalize()) {
+      MS_LOG(ERROR) << "Failed to finalize server.";
+      return false;
+    }
+    MS_LOG(INFO) << "Scheduler is successfully finalized.";
+    return true;
+  }
   ps::Scheduler::GetInstance().Run();
   return true;
 }
@@ -1174,11 +1193,15 @@ std::vector<ActionItem> VmPipeline() {
   (void)actions.emplace_back(std::make_pair("validate", ValidateAction));
 #if ((defined ENABLE_CPU) && (!defined _WIN32))
   if (ps::PSContext::instance()->is_worker()) {
-    std::string server_mode = ps::PSContext::instance()->server_mode();
-    if (server_mode == ps::kServerModeFL || server_mode == ps::kServerModeHybrid) {
-      (void)actions.emplace_back(std::make_pair("worker", StartFLWorkerAction));
+    if (distributed::cluster::ClusterContext::instance()->initialized()) {
+      MS_LOG(INFO) << "This worker is initialized. No need to add worker action.";
     } else {
-      (void)actions.emplace_back(std::make_pair("worker", StartPSWorkerAction));
+      std::string server_mode = ps::PSContext::instance()->server_mode();
+      if (server_mode == ps::kServerModeFL || server_mode == ps::kServerModeHybrid) {
+        (void)actions.emplace_back(std::make_pair("worker", StartFLWorkerAction));
+      } else {
+        (void)actions.emplace_back(std::make_pair("worker", StartPSWorkerAction));
+      }
     }
   }
 #endif
@@ -1229,7 +1252,10 @@ std::vector<ActionItem> PServerPipeline() {
 }

 std::vector<ActionItem> PSchedulerPipeline() {
-  std::vector<ActionItem> actions;
+  auto actions = CommonPipeline();
+  (void)actions.emplace_back(std::make_pair("optimize", VmOptimizeAction));
+  (void)actions.emplace_back(std::make_pair("auto_monad_reorder", OrderEnforceAction));
+  (void)actions.emplace_back(std::make_pair("validate", ValidateAction));
   (void)actions.emplace_back(std::make_pair("scheduler", StartPSSchedulerAction));
   return actions;
 }
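For context, a pipeline here is a vector of (name, action) pairs, and the change makes PSchedulerPipeline start from the shared front-end stages instead of an empty vector. The compilable sketch below shows that composition style with illustrative stage names, not MindSpore's real actions.

#include <functional>
#include <iostream>
#include <string>
#include <utility>
#include <vector>

using Action = std::function<bool()>;
using ActionItem = std::pair<std::string, Action>;

std::vector<ActionItem> CommonPipeline() {
  // Shared front-end stages every pipeline variant runs first.
  return {{"parse", [] { return true; }}, {"infer_type", [] { return true; }}};
}

std::vector<ActionItem> SchedulerPipeline() {
  auto actions = CommonPipeline();  // reuse the common prefix, as the diff now does
  (void)actions.emplace_back("validate", Action([] { return true; }));
  (void)actions.emplace_back("scheduler", Action([] { return true; }));
  return actions;
}

int main() {
  for (const auto &[name, act] : SchedulerPipeline()) {
    std::cout << name << ": " << (act() ? "ok" : "failed") << "\n";
  }
  return 0;
}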
@@ -74,6 +74,7 @@
 #include "ps/ps_cache/ps_cache_manager.h"
 #include "fl/server/server.h"
 #include "fl/worker/fl_worker.h"
+#include "distributed/cluster/cluster_context.h"
 #endif

 #if ((defined ENABLE_GE) || (defined ENABLE_D))
@@ -925,6 +926,19 @@ std::vector<ActionItem> GetPipeline(const ResourcePtr &resource, const std::stri
   if (ps::PSContext::instance()->is_scheduler()) {
     return PSchedulerPipeline();
   }
+  if (distributed::cluster::ClusterContext::instance()->initialized()) {
+    auto node = distributed::cluster::ClusterContext::instance()->node();
+    MS_EXCEPTION_IF_NULL(node);
+    MS_LOG(INFO) << "Cluster is initialized. This node role is " << node->role();
+    switch (node->role()) {
+      case ps::core::NodeRole::SERVER:
+        return PServerPipeline();
+      case ps::core::NodeRole::SCHEDULER:
+        return PSchedulerPipeline();
+      default:
+        break;
+    }
+  }
 #endif

   if (use_vm && backend != "ge" && !is_air) {
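The added block selects a pipeline from the node's role once the cluster context reports it is initialized. A tiny sketch of that dispatch, with an illustrative enum in place of ps::core::NodeRole:

#include <iostream>
#include <string>

enum class NodeRole { WORKER, SERVER, SCHEDULER };  // stand-in for ps::core::NodeRole

std::string SelectPipeline(NodeRole role) {
  switch (role) {
    case NodeRole::SERVER:
      return "PServerPipeline";
    case NodeRole::SCHEDULER:
      return "PSchedulerPipeline";
    default:
      return "VmPipeline";  // workers fall through to the default pipeline
  }
}

int main() {
  std::cout << SelectPipeline(NodeRole::SCHEDULER) << "\n";
  return 0;
}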
@@ -21,6 +21,9 @@
 #if ((defined ENABLE_CPU) && (!defined _WIN32))
 #include "ps/ps_cache/ps_cache_manager.h"
 #include "ps/ps_cache/ps_data/ps_data_prefetch.h"
+#include "distributed/cluster/cluster_context.h"
+#else
+#include "distributed/cluster/dummy_cluster_context.h"
 #endif

 namespace mindspore {
@@ -115,6 +118,9 @@ bool PSContext::is_worker() const {
   if ((server_mode_ == kServerModeFL || server_mode_ == kServerModeHybrid) && ps_enabled_) {
     return role_ == kEnvRoleOfWorker;
   }
+  if (distributed::cluster::ClusterContext::instance()->initialized()) {
+    return role_ == kEnvRoleOfWorker;
+  }
   return is_worker_;
 }

@@ -122,6 +128,9 @@ bool PSContext::is_server() const {
   if ((server_mode_ == kServerModeFL || server_mode_ == kServerModeHybrid) && ps_enabled_) {
     return role_ == kEnvRoleOfServer;
   }
+  if (distributed::cluster::ClusterContext::instance()->initialized()) {
+    return role_ == kEnvRoleOfServer;
+  }
   return is_pserver_;
 }

@@ -129,6 +138,9 @@ bool PSContext::is_scheduler() const {
   if ((server_mode_ == kServerModeFL || server_mode_ == kServerModeHybrid) && ps_enabled_) {
     return role_ == kEnvRoleOfScheduler;
   }
+  if (distributed::cluster::ClusterContext::instance()->initialized()) {
+    return role_ == kEnvRoleOfScheduler;
+  }
   return is_sched_;
 }
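All three getters now share the same layered lookup: answer from the FL/hybrid configuration when it applies, then from the new cluster context when it is initialized, and only then from the legacy cached flag. A condensed sketch of one such getter, with stand-in fields whose names are illustrative:

#include <string>

// Stand-in for the state PSContext consults.
struct RoleState {
  bool fl_mode = false;              // kServerModeFL / kServerModeHybrid with PS enabled
  bool cluster_initialized = false;  // distributed::cluster::ClusterContext::initialized()
  std::string role;                  // the MS_ROLE-style role string
  bool legacy_is_worker = false;     // the old cached flag
};

bool IsWorker(const RoleState &s) {
  if (s.fl_mode) {
    return s.role == "MS_WORKER";
  }
  if (s.cluster_initialized) {
    return s.role == "MS_WORKER";  // the newly added fallback layer
  }
  return s.legacy_is_worker;
}

int main() {
  RoleState s;
  s.cluster_initialized = true;
  s.role = "MS_WORKER";
  return IsWorker(s) ? 0 : 1;
}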
@@ -242,10 +254,6 @@ void PSContext::set_dp_norm_clip(float dp_norm_clip) {
 float PSContext::dp_norm_clip() const { return dp_norm_clip_; }

 void PSContext::set_ms_role(const std::string &role) {
-  if (server_mode_ != kServerModeFL && server_mode_ != kServerModeHybrid) {
-    MS_LOG(EXCEPTION) << "Only federated learning supports to set role by fl context.";
-    return;
-  }
   if (role != kEnvRoleOfWorker && role != kEnvRoleOfServer && role != kEnvRoleOfScheduler) {
     MS_LOG(EXCEPTION) << "ms_role " << role << " is invalid.";
     return;
@@ -263,13 +271,7 @@ void PSContext::set_worker_num(uint32_t worker_num) {
 }
 uint32_t PSContext::worker_num() const { return worker_num_; }

-void PSContext::set_server_num(uint32_t server_num) {
-  if (server_num == 0) {
-    MS_LOG(EXCEPTION) << "Server number must be greater than 0.";
-    return;
-  }
-  server_num_ = server_num;
-}
+void PSContext::set_server_num(uint32_t server_num) { server_num_ = server_num; }
 uint32_t PSContext::server_num() const { return server_num_; }

 void PSContext::set_scheduler_ip(const std::string &sched_ip) { scheduler_host_ = sched_ip; }
@@ -51,13 +51,13 @@ void CollectiveInitializer::InitCollective() {
     // Because this method InitCollective is static, the non-static member variables should be accessed by
     // CollectiveInitializer::instance().
     CollectiveInitializer::instance().use_mpi_ = true;
-    CollectiveInitializer::instance().collective_inited_ = true;
     CollectiveInitializer::instance().collective_handle_ = handle;
+  } else {
+    if (!distributed::Initialize()) {
+      MS_LOG(EXCEPTION) << "Failed to initialize distributed execution for NCCL.";
+    }
   }
+  CollectiveInitializer::instance().collective_inited_ = true;
 }

 void CollectiveInitializer::FinalizeCollective() {
@@ -85,49 +85,49 @@ uint32_t CollectiveInitializer::local_rank_id() {
 bool CollectiveInitializer::CreateCommunicationGroup(const std::string &group_name,
                                                      const std::vector<uint32_t> &group_ranks) {
   if (common::CheckUseMPI()) {
-    return distributed::collective::CollectiveManager::instance()->CreateCommunicationGroup(group_name, group_ranks);
-  } else {
     MS_EXCEPTION_IF_NULL(collective_handle_);
     auto create_comm_group_funcptr =
       reinterpret_cast<CreateCommGroupFunc>(dlsym(const_cast<void *>(collective_handle_), "CreateCommGroup"));
     MS_EXCEPTION_IF_NULL(create_comm_group_funcptr);
     return (*create_comm_group_funcptr)(group_name, group_ranks);
+  } else {
+    return distributed::collective::CollectiveManager::instance()->CreateCommunicationGroup(group_name, group_ranks);
   }
 }

 bool CollectiveInitializer::DestroyCommunicationGroup(const std::string &group_name) {
   if (common::CheckUseMPI()) {
-    return distributed::collective::CollectiveManager::instance()->DestroyCommunicationGroup(group_name);
-  } else {
     MS_EXCEPTION_IF_NULL(collective_handle_);
     auto destroy_group_funcptr =
       reinterpret_cast<DestroyGroupFunc>(dlsym(const_cast<void *>(collective_handle_), "DestroyGroup"));
     MS_EXCEPTION_IF_NULL(destroy_group_funcptr);
     return (*destroy_group_funcptr)(group_name);
+  } else {
+    return distributed::collective::CollectiveManager::instance()->DestroyCommunicationGroup(group_name);
   }
 }

 uint32_t CollectiveInitializer::GetRankIDByGroup(const std::string &group_name) {
   if (common::CheckUseMPI()) {
-    return distributed::collective::CollectiveManager::instance()->GetRankId(group_name);
-  } else {
     MS_EXCEPTION_IF_NULL(collective_handle_);
     auto get_rank_id_funcptr =
       reinterpret_cast<GetRankIDByGroupFunc>(dlsym(const_cast<void *>(collective_handle_), "GetRankIDByGroup"));
     MS_EXCEPTION_IF_NULL(get_rank_id_funcptr);
     return IntToUint((*get_rank_id_funcptr)(group_name));
+  } else {
+    return distributed::collective::CollectiveManager::instance()->GetRankId(group_name);
   }
 }

 uint32_t CollectiveInitializer::GetGroupSize(const std::string &group_name) {
   if (common::CheckUseMPI()) {
-    return distributed::collective::CollectiveManager::instance()->GetGroupSize(group_name);
-  } else {
     MS_EXCEPTION_IF_NULL(collective_handle_);
     auto get_group_size_funcptr =
       reinterpret_cast<GetGroupSizeFunc>(dlsym(const_cast<void *>(collective_handle_), "GetGroupSize"));
     MS_EXCEPTION_IF_NULL(get_group_size_funcptr);
     return IntToUint((*get_group_size_funcptr)(group_name));
+  } else {
+    return distributed::collective::CollectiveManager::instance()->GetGroupSize(group_name);
   }
 }
 }  // namespace gpu
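The four hunks above share one dispatch rule, sketched below: one branch calls the symbol resolved from the dynamically loaded (OpenMPI-based) collective library, the other calls MindSpore's own CollectiveManager. The manager type and wiring here are illustrative stand-ins, not the real classes.

#include <string>
#include <vector>

using CreateCommGroupFunc = bool (*)(const std::string &, const std::vector<unsigned int> &);

struct FakeCollectiveManager {  // stand-in for distributed::collective::CollectiveManager
  bool CreateCommunicationGroup(const std::string &, const std::vector<unsigned int> &) { return true; }
};

// use_mpi selects between the dlsym'd plugin entry point and the built-in manager.
bool CreateGroup(bool use_mpi, FakeCollectiveManager *manager, CreateCommGroupFunc plugin_func,
                 const std::string &name, const std::vector<unsigned int> &ranks) {
  if (use_mpi) {
    // Legacy path: call the symbol resolved from the dlopen'd collective library.
    return plugin_func != nullptr && (*plugin_func)(name, ranks);
  }
  // MindSpore communication framework path.
  return manager->CreateCommunicationGroup(name, ranks);
}

int main() {
  FakeCollectiveManager manager;
  return CreateGroup(false, &manager, nullptr, "nccl_world_group", {0, 1}) ? 0 : 1;
}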
@@ -95,15 +95,16 @@ bool GPUKernelRuntime::Init() {
   mem_manager_ = std::make_shared<GPUMemoryManager>();
   MS_EXCEPTION_IF_NULL(mem_manager_);
   mem_manager_->MallocDeviceMemory();
-  const void *collective_handle_ = CollectiveInitializer::instance().collective_handle();
-  bool collective_inited = CollectiveInitializer::instance().collective_inited();
-  if (collective_inited && collective_handle_ != nullptr) {
-    MS_LOG(INFO) << "Start initializing NCCL communicator for device " << device_id_;
-    auto init_nccl_comm_funcptr =
-      reinterpret_cast<InitNCCLComm>(dlsym(const_cast<void *>(collective_handle_), "InitNCCLComm"));
-    MS_EXCEPTION_IF_NULL(init_nccl_comm_funcptr);
-    (*init_nccl_comm_funcptr)();
-    MS_LOG(INFO) << "End initializing NCCL communicator.";
+  if (CollectiveInitializer::instance().collective_inited()) {
+    auto collective_handle = CollectiveInitializer::instance().collective_handle();
+    if (collective_handle != nullptr) {
+      MS_LOG(INFO) << "Start initializing NCCL communicator for device " << device_id_;
+      auto init_nccl_comm_funcptr =
+        reinterpret_cast<InitNCCLComm>(dlsym(const_cast<void *>(collective_handle), "InitNCCLComm"));
+      MS_EXCEPTION_IF_NULL(init_nccl_comm_funcptr);
+      (*init_nccl_comm_funcptr)();
+      MS_LOG(INFO) << "End initializing NCCL communicator.";
+    }
   }
   device_init_ = true;
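For background, the dlsym pattern repeated in these hunks resolves an entry point by name from a library loaded elsewhere with dlopen and casts it to a typed function pointer. A standalone sketch (Linux, link with -ldl; the library name is illustrative):

#include <dlfcn.h>
#include <iostream>

using InitNCCLComm = void (*)();

int main() {
  // Hypothetical library name; MindSpore loads its own collective plugin.
  void *handle = dlopen("libgpu_collective.so", RTLD_LAZY);
  if (handle == nullptr) {
    std::cerr << "dlopen failed: " << dlerror() << "\n";
    return 1;
  }
  // Resolve the entry point by symbol name and cast it to the typed pointer.
  auto init_nccl_comm_funcptr = reinterpret_cast<InitNCCLComm>(dlsym(handle, "InitNCCLComm"));
  if (init_nccl_comm_funcptr == nullptr) {
    std::cerr << "dlsym failed: " << dlerror() << "\n";
    dlclose(handle);
    return 1;
  }
  (*init_nccl_comm_funcptr)();
  dlclose(handle);
  return 0;
}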
@@ -22,11 +22,22 @@
 #include <vector>
 #include <string>
 #include "ir/dtype/type_id.h"
 #include "utils/check_convert_utils.h"
 #include "runtime/hardware/collective/communication_group.h"

 namespace mindspore {
 namespace device {
+// The reduce type of collective operations.
+enum CollectiveOpReduceType : int64_t {
+  Reduce_Mean = 0,
+  Reduce_Max = 1,
+  Reduce_Min = 2,
+  Reduce_Prod = 3,
+  Reduce_Sum = 4,
+  Reduce_Sum_Square = 5,
+  Reduce_ASum = 6,
+  Reduce_All = 7
+};
+
 // The base class of collective communication library.
 // For collective communication on the device side like GPU, the entry is NvidiaCollectiveCommLib which calls NCCL.
 // For collective communication on the host side, the entry is MPICollectiveCommLib which call OpenMPI, or
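Each backend consumes the new CollectiveOpReduceType by mapping it to its own op constant and rejecting anything unmapped, as CheckNCCLReduceType does further down. A self-contained sketch of that lookup, using a stand-in for the backend op type and only the subset of enumerators the maps below actually cover:

#include <cstdint>
#include <iostream>
#include <map>

// Subset of the enum above, with the same values.
enum CollectiveOpReduceType : int64_t { Reduce_Max = 1, Reduce_Min = 2, Reduce_Prod = 3, Reduce_Sum = 4 };

enum FakeBackendOp { kSum, kProd, kMin, kMax };  // stand-in for ncclRedOp_t / MPI_Op

const std::map<CollectiveOpReduceType, FakeBackendOp> kReduceTypeMap = {
    {Reduce_Sum, kSum}, {Reduce_Prod, kProd}, {Reduce_Min, kMin}, {Reduce_Max, kMax}};

bool CheckReduceType(CollectiveOpReduceType reduce_op) {
  if (kReduceTypeMap.count(reduce_op) == 0) {
    std::cerr << "Reduce type " << reduce_op << " is not supported by this backend.\n";
    return false;
  }
  return true;
}

int main() { return CheckReduceType(Reduce_Sum) ? 0 : 1; }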
@@ -72,7 +83,7 @@ class CollectiveCommunicationLib {
     return true;
   }
   virtual bool AllReduce(const void *send_buff, void *recv_buff, size_t send_count, TypeId data_type,
-                         ReduceMode reduce_op, const std::string &group_name, void *stream = nullptr) {
+                         CollectiveOpReduceType reduce_op, const std::string &group_name, void *stream = nullptr) {
     return true;
   }
   virtual bool Broadcast(const void *send_buff, void *recv_buff, size_t send_count, TypeId data_type,

@@ -80,7 +91,7 @@ class CollectiveCommunicationLib {
     return true;
   }
   virtual bool ReduceScatter(const void *send_buff, void *recv_buff, size_t recv_count, TypeId data_type,
-                             ReduceMode reduce_op, const std::string &group_name, void *stream = nullptr) {
+                             CollectiveOpReduceType reduce_op, const std::string &group_name, void *stream = nullptr) {
     return true;
   }
@@ -44,8 +44,10 @@ const std::map<TypeId, MPI_Datatype> kMPIDataTypeMap = {{TypeId::kNumberTypeInt8
                                                         {TypeId::kNumberTypeFloat64, MPI_DOUBLE}};

 // Map of reduce type to MPI reduce type.
-const std::map<ReduceMode, MPI_Op> kMPIReduceTypeMap = {
-    {Reduce_Sum, MPI_SUM}, {Reduce_Prod, MPI_PROD}, {Reduce_Min, MPI_MIN}, {Reduce_Max, MPI_MAX}};
+const std::map<CollectiveOpReduceType, MPI_Op> kMPIReduceTypeMap = {{CollectiveOpReduceType::Reduce_Sum, MPI_SUM},
+                                                                    {CollectiveOpReduceType::Reduce_Prod, MPI_PROD},
+                                                                    {CollectiveOpReduceType::Reduce_Min, MPI_MIN},
+                                                                    {CollectiveOpReduceType::Reduce_Max, MPI_MAX}};

 constexpr char kMPIGlobalGroupName[] = "mpi_world_group";
 class EXPORT_MPI_WRAPPER MPICollectiveCommLib : public CollectiveCommunicationLib {
@@ -48,16 +48,16 @@ class MsCollectiveCommLib : public CollectiveCommunicationLib {
   bool AllGather(const void *send_buff, void *recv_buff, size_t send_count, TypeId data_type,
                  const std::string &group_name, void *stream = nullptr) override;

-  bool AllReduce(const void *send_buff, void *recv_buff, size_t send_count, TypeId data_type, ReduceMode reduce_op,
-                 const std::string &group_name, void *stream = nullptr) override {
+  bool AllReduce(const void *send_buff, void *recv_buff, size_t send_count, TypeId data_type,
+                 CollectiveOpReduceType reduce_op, const std::string &group_name, void *stream = nullptr) override {
     return true;
   }

   bool Broadcast(const void *send_buff, void *recv_buff, size_t send_count, TypeId data_type, uint32_t root_rank,
                  const std::string &group_name, void *stream = nullptr) override;

-  bool ReduceScatter(const void *send_buff, void *recv_buff, size_t recv_count, TypeId data_type, ReduceMode reduce_op,
-                     const std::string &group_name, void *stream = nullptr) override {
+  bool ReduceScatter(const void *send_buff, void *recv_buff, size_t recv_count, TypeId data_type,
+                     CollectiveOpReduceType reduce_op, const std::string &group_name, void *stream = nullptr) override {
     return true;
   }
@@ -63,14 +63,9 @@ void GPUDeviceContext::Initialize() {
   }

   // Set device id
-  const void *collective_handle_ = CollectiveInitializer::instance().collective_handle();
-  bool collective_inited = CollectiveInitializer::instance().collective_inited();
-  if (collective_inited && collective_handle_ != nullptr) {
+  if (CollectiveInitializer::instance().collective_inited()) {
     DeviceContextKey old_key = device_context_key_;
-    auto get_local_rank_funcptr =
-      reinterpret_cast<GetLocalRankId>(dlsym(const_cast<void *>(collective_handle_), "local_rank_id"));
-    MS_EXCEPTION_IF_NULL(get_local_rank_funcptr);
-    device_context_key_.device_id_ = IntToUint((*get_local_rank_funcptr)());
+    device_context_key_.device_id_ = CollectiveInitializer::instance().local_rank_id();

     DeviceContextManager::GetInstance().UpdateDeviceContextKey(old_key, device_context_key_);
@@ -91,13 +86,16 @@ void GPUDeviceContext::Initialize() {
   mem_manager_->MallocDeviceMemory();

   // Initialize NCCL.
-  if (collective_inited && collective_handle_ != nullptr) {
-    MS_LOG(INFO) << "Start initializing NCCL communicator for device " << device_context_key_.device_id_;
-    auto init_nccl_comm_funcptr =
-      reinterpret_cast<InitNCCLComm>(dlsym(const_cast<void *>(collective_handle_), "InitNCCLComm"));
-    MS_EXCEPTION_IF_NULL(init_nccl_comm_funcptr);
-    (*init_nccl_comm_funcptr)();
-    MS_LOG(INFO) << "End initializing NCCL communicator.";
+  if (CollectiveInitializer::instance().collective_inited()) {
+    auto collective_handle = CollectiveInitializer::instance().collective_handle();
+    if (collective_handle != nullptr) {
+      MS_LOG(INFO) << "Start initializing NCCL communicator for device " << device_context_key_.device_id_;
+      auto init_nccl_comm_funcptr =
+        reinterpret_cast<InitNCCLComm>(dlsym(const_cast<void *>(collective_handle), "InitNCCLComm"));
+      MS_EXCEPTION_IF_NULL(init_nccl_comm_funcptr);
+      (*init_nccl_comm_funcptr)();
+      MS_LOG(INFO) << "End initializing NCCL communicator.";
+    }
   }

 #ifndef ENABLE_SECURITY
@@ -501,10 +499,9 @@ bool GPUDeviceContext::SyncStream(size_t stream_id) const {
 }

 uint32_t GPUDeviceContext::GetRankID() const {
-  const void *collective_handle_ = CollectiveInitializer::instance().collective_handle();
   bool collective_inited = CollectiveInitializer::instance().collective_inited();
   uint32_t rank_id = 0;
-  if (collective_inited && collective_handle_ != nullptr) {
+  if (collective_inited) {
     if (!CommManager::GetInstance().GetRankID(kNcclWorldGroup, &rank_id)) {
       MS_LOG(EXCEPTION) << "Failed to get rank id.";
     }
@@ -60,7 +60,7 @@ bool NvidiaCollectiveCommLib::AllGather(const void *send_buff, void *recv_buff,
 }

 bool NvidiaCollectiveCommLib::AllReduce(const void *send_buff, void *recv_buff, size_t send_count, TypeId data_type,
-                                        ReduceMode reduce_op, const std::string &group_name, void *stream) {
+                                        CollectiveOpReduceType reduce_op, const std::string &group_name, void *stream) {
   if (!CheckNCCLDataType(data_type)) {
     return false;
   }
@@ -96,7 +96,8 @@ bool NvidiaCollectiveCommLib::Broadcast(const void *send_buff, void *recv_buff,
 }

 bool NvidiaCollectiveCommLib::ReduceScatter(const void *send_buff, void *recv_buff, size_t recv_count, TypeId data_type,
-                                            ReduceMode reduce_op, const std::string &group_name, void *stream) {
+                                            CollectiveOpReduceType reduce_op, const std::string &group_name,
+                                            void *stream) {
   if (!CheckNCCLDataType(data_type)) {
     return false;
   }
@@ -153,7 +154,7 @@ bool NvidiaCollectiveCommLib::CheckNCCLDataType(TypeId data_type) {
   return true;
 }

-bool NvidiaCollectiveCommLib::CheckNCCLReduceType(ReduceMode reduce_op) {
+bool NvidiaCollectiveCommLib::CheckNCCLReduceType(CollectiveOpReduceType reduce_op) {
   CHECK_RET((kNCCLReduceTypeMap.count(reduce_op) != 0), true,
             "Reduce type " + std::to_string(reduce_op) + " is not supported in NCCL.");
   return true;
@@ -42,8 +42,11 @@ const std::map<TypeId, ncclDataType_t> kNCCLDataTypeMap = {
     {TypeId::kNumberTypeFloat64, ncclFloat64}};

 // Map of reduce type to NCCL reduce type.
-const std::map<ReduceMode, ncclRedOp_t> kNCCLReduceTypeMap = {
-    {Reduce_Sum, ncclSum}, {Reduce_Prod, ncclProd}, {Reduce_Min, ncclMin}, {Reduce_Max, ncclMax}};
+const std::map<CollectiveOpReduceType, ncclRedOp_t> kNCCLReduceTypeMap = {
+    {CollectiveOpReduceType::Reduce_Sum, ncclSum},
+    {CollectiveOpReduceType::Reduce_Prod, ncclProd},
+    {CollectiveOpReduceType::Reduce_Min, ncclMin},
+    {CollectiveOpReduceType::Reduce_Max, ncclMax}};

 constexpr char kNCCLGlobalGroupName[] = "nccl_world_group";
@@ -61,14 +64,14 @@ class EXPORT_NCCL_WRAPPER NvidiaCollectiveCommLib : public CollectiveCommunicati
   bool AllGather(const void *send_buff, void *recv_buff, size_t send_count, TypeId data_type,
                  const std::string &group_name, void *stream = nullptr) override;

-  bool AllReduce(const void *send_buff, void *recv_buff, size_t send_count, TypeId data_type, ReduceMode reduce_op,
-                 const std::string &group_name, void *stream = nullptr) override;
+  bool AllReduce(const void *send_buff, void *recv_buff, size_t send_count, TypeId data_type,
+                 CollectiveOpReduceType reduce_op, const std::string &group_name, void *stream = nullptr) override;

   bool Broadcast(const void *send_buff, void *recv_buff, size_t send_count, TypeId data_type, uint32_t root_rank,
                  const std::string &group_name, void *stream = nullptr) override;

-  bool ReduceScatter(const void *send_buff, void *recv_buff, size_t recv_count, TypeId data_type, ReduceMode reduce_op,
-                     const std::string &group_name, void *stream = nullptr) override;
+  bool ReduceScatter(const void *send_buff, void *recv_buff, size_t recv_count, TypeId data_type,
+                     CollectiveOpReduceType reduce_op, const std::string &group_name, void *stream = nullptr) override;

   bool Send(const void *send_buff, size_t count, TypeId data_type, uint32_t peer, const std::string &group_name,
             void *stream = nullptr) override;

@@ -84,7 +87,7 @@ class EXPORT_NCCL_WRAPPER NvidiaCollectiveCommLib : public CollectiveCommunicati
   bool CheckNCCLDataType(TypeId data_type);

   // Check reduce type of collective operation is valid for NCCL.
-  bool CheckNCCLReduceType(ReduceMode reduce_op);
+  bool CheckNCCLReduceType(CollectiveOpReduceType reduce_op);
 };
 }  // namespace gpu
@@ -131,64 +131,36 @@ CommManager &CommManager::GetInstance() noexcept {
 }

 bool CommManager::CreateGroupSync(const string &group, const vector<unsigned int> &rank_id_list) const {
-  const void *collective_handle_ = CollectiveInitializer::instance().collective_handle();
-  if (!collective_handle_) {
-    MS_LOG(EXCEPTION) << "GPU collective handle is not initialized.";
-  }
   MS_LOG(INFO) << "Create communication group " << group << " by rank id list " << rank_id_list;
-  auto create_comm_group_funcptr =
-    reinterpret_cast<CreateCommGroupFunc>(dlsym(const_cast<void *>(collective_handle_), "CreateCommGroup"));
-  MS_EXCEPTION_IF_NULL(create_comm_group_funcptr);
-  bool ret = (*create_comm_group_funcptr)(group, rank_id_list);
+  bool ret = CollectiveInitializer::instance().CreateCommunicationGroup(group, rank_id_list);
   if (!ret) {
-    MS_LOG(ERROR) << "Creating group " << group << "for rank id list" << rank_id_list << "failed.";
+    MS_LOG(ERROR) << "Failed to create group " << group << " for rank id list " << rank_id_list;
    return ret;
   }

   MS_LOG(INFO) << "Successfully create group " << group << " for rank id list " << rank_id_list;
   return ret;
 }

 bool CommManager::GetRankID(const string &group, unsigned int *rank_id) const {
-  const void *collective_handle_ = CollectiveInitializer::instance().collective_handle();
-  if (!collective_handle_) {
-    MS_LOG(EXCEPTION) << "GPU collective handle is not initialized.";
-  }
-  auto get_rank_id_funcptr =
-    reinterpret_cast<GetRankIDByGroupFunc>(dlsym(const_cast<void *>(collective_handle_), "GetRankIDByGroup"));
-  MS_EXCEPTION_IF_NULL(get_rank_id_funcptr);
-  int rank = (*get_rank_id_funcptr)(group);
-  *rank_id = static_cast<unsigned int>(rank);
+  *rank_id = CollectiveInitializer::instance().GetRankIDByGroup(group);
   MS_LOG(INFO) << "This process rank id is " << *rank_id << " in group " << group;
   return true;
 }

 bool CommManager::GetRankSize(const string &group, unsigned int *rank_size) const {
-  const void *collective_handle_ = CollectiveInitializer::instance().collective_handle();
-  if (!collective_handle_) {
-    MS_LOG(EXCEPTION) << "GPU collective handle is not initialized.";
-  }
-  auto get_group_size_funcptr =
-    reinterpret_cast<GetGroupSizeFunc>(dlsym(const_cast<void *>(collective_handle_), "GetGroupSize"));
-  MS_EXCEPTION_IF_NULL(get_group_size_funcptr);
-  int size = (*get_group_size_funcptr)(group);
-  *rank_size = static_cast<unsigned int>(size);
+  *rank_size = CollectiveInitializer::instance().GetGroupSize(group);
   MS_LOG(INFO) << "Group " << group << " size is " << *rank_size;
   return true;
 }

 bool CommManager::DestroyGroup(const string &group) const {
-  const void *collective_handle_ = CollectiveInitializer::instance().collective_handle();
-  if (!collective_handle_) {
-    MS_LOG(EXCEPTION) << "GPU collective handle is not initialized.";
-  }
-  auto destroy_group_funcptr =
-    reinterpret_cast<DestroyGroupFunc>(dlsym(const_cast<void *>(collective_handle_), "DestroyGroup"));
-  MS_EXCEPTION_IF_NULL(destroy_group_funcptr);
-
-  bool ret = (*destroy_group_funcptr)(group);
+  bool ret = CollectiveInitializer::instance().DestroyCommunicationGroup(group);
   if (!ret) {
-    MS_LOG(ERROR) << "Destroying group " << group << " failed.";
+    MS_LOG(ERROR) << "Failed to destroy group " << group;
     return ret;
   }

   MS_LOG(INFO) << "Successfully destroy group " << group;
   return ret;
 }
 #else
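After this change CommManager no longer re-resolves dlsym symbols itself; every method is a thin facade over CollectiveInitializer, which hides whether the MPI plugin or MindSpore's own framework sits underneath. A compact sketch of that delegation, with a hypothetical initializer in place of the real one:

#include <iostream>
#include <string>
#include <vector>

struct FakeCollectiveInitializer {  // stand-in for CollectiveInitializer
  static FakeCollectiveInitializer &instance() {
    static FakeCollectiveInitializer inst;
    return inst;
  }
  bool CreateCommunicationGroup(const std::string &, const std::vector<unsigned int> &) { return true; }
  unsigned int GetRankIDByGroup(const std::string &) { return 0; }
  unsigned int GetGroupSize(const std::string &) { return 2; }
};

bool CreateGroupSync(const std::string &group, const std::vector<unsigned int> &ranks) {
  // The manager only forwards and logs; all backend selection lives in the initializer.
  bool ret = FakeCollectiveInitializer::instance().CreateCommunicationGroup(group, ranks);
  if (!ret) {
    std::cerr << "Failed to create group " << group << "\n";
    return ret;
  }
  std::cout << "Successfully create group " << group << "\n";
  return ret;
}

int main() { return CreateGroupSync("nccl_world_group", {0, 1}) ? 0 : 1; }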
@@ -32,7 +32,7 @@ from .tensor import Tensor as MsTensor
 from .tensor import CSRTensor as MsCSRTensor
 from .._c_expression import generate_arguments_key, GraphExecutor_, Tensor, MetaTensor, CSRTensor, PynativeExecutor_
 from .._c_expression import verify_inputs_signature, init_exec_dataset, _set_dataset_mode_config, init_pipeline
-from ..parallel._ps_context import _is_role_pserver
+from ..parallel._ps_context import _is_role_pserver, _is_role_sched
 from ..parallel._utils import _get_device_num, _get_global_rank, _need_to_full, _check_full_batch, _to_full_tensor, \
     _get_parameter_broadcast, _get_pipeline_stages
 from .._checkparam import Validator
@@ -737,7 +737,7 @@ class _CellGraphExecutor:
         return self._graph_executor.has_compiled(phase)

     def __call__(self, obj, *args, phase='predict'):
-        if context.get_context("precompile_only") or _is_role_pserver():
+        if context.get_context("precompile_only") or _is_role_pserver() or _is_role_sched():
             return None
         return self.run(obj, *args, phase=phase)
@@ -929,18 +929,23 @@ def set_ps_context(**kwargs):

         MS_WORKER: represents the worker,

-        MS_PSERVER: represents the Server
+        MS_PSERVER/MS_SERVER: represents the Server

     Args:
         enable_ps (bool): Whether to enable parameter server training mode.
                     Only after enable_ps is set True, the environment variables will be effective.
                     Default: False.
+        config_file_path (string): Configuration file path used by recovery. Default: ''.
+        scheduler_manage_port (int): scheduler manage port used to scale out/in. Default: 11202.
+        enable_ssl (bool): Set PS SSL mode enabled or disabled. Default: true.
+        client_password (str): Password to decrypt the secret key stored in the client certificate.
+        server_password (str): Password to decrypt the secret key stored in the server certificate.

     Raises:
         ValueError: If input key is not the attribute in parameter server training mode context.

     Examples:
         >>> context.set_ps_context(enable_ps=True)
         >>> context.set_ps_context(enable_ps=True, enable_ssl=True, client_password='123456', server_password='123456')
     """
     _set_ps_context(**kwargs)
@@ -27,13 +27,13 @@
   ClassType(const ClassType &) = delete; \
   ClassType &operator=(const ClassType &) = delete;

-#define TRY_AND_CATCH_WITH_EXCEPTION(expr, error_msg)                             \
-  do {                                                                            \
-    try {                                                                         \
-      (expr);                                                                     \
-    } catch (const std::exception &e) {                                           \
-      MS_LOG(EXCEPTION) << "Caught exception " << e.what() << ". " << error_msg;  \
-    }                                                                             \
+#define TRY_AND_CATCH_WITH_EXCEPTION(expr, error_msg)                               \
+  do {                                                                              \
+    try {                                                                           \
+      (expr);                                                                       \
+    } catch (const std::exception &e) {                                             \
+      MS_LOG(EXCEPTION) << "Caught exception of " << e.what() << ". " << error_msg; \
+    }                                                                               \
   } while (0)

 namespace mindspore {
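A usage sketch of the macro above, mirroring how InitSchedulerPort in this PR wraps std::stoi: the macro converts any thrown exception into an MS_LOG(EXCEPTION) carrying extra context. MS_LOG is stubbed with a plain throw so the snippet stands alone.

#include <cstdint>
#include <iostream>
#include <stdexcept>
#include <string>

// Stand-in for MS_LOG(EXCEPTION); the real macro logs and throws internally.
#define MS_LOG_EXCEPTION_STUB(msg) throw std::runtime_error(msg)

#define TRY_AND_CATCH_WITH_EXCEPTION(expr, error_msg)                                             \
  do {                                                                                            \
    try {                                                                                         \
      (expr);                                                                                     \
    } catch (const std::exception &e) {                                                           \
      MS_LOG_EXCEPTION_STUB(std::string("Caught exception of ") + e.what() + ". " + (error_msg)); \
    }                                                                                             \
  } while (0)

int main() {
  uint16_t scheduler_port = 0;
  // std::stoi throws std::invalid_argument on malformed input; the macro re-reports it.
  TRY_AND_CATCH_WITH_EXCEPTION((scheduler_port = static_cast<uint16_t>(std::stoi("8118"))),
                               "The scheduler port must be a number.");
  std::cout << "port: " << scheduler_port << "\n";
  return 0;
}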
@@ -157,12 +157,17 @@ def _set_ps_context(**kwargs):
         enable_ps (bool): Whether to enable parameter server training mode.
                     Only after enable_ps is set True, the environment variables will be effective.
                     Default: False.
+        config_file_path (string): Configuration file path used by recovery. Default: ''.
+        scheduler_manage_port (int): scheduler manage port used to scale out/in. Default: 11202.
+        enable_ssl (bool): Set PS SSL mode enabled or disabled. Default: true.
+        client_password (str): Password to decrypt the secret key stored in the client certificate.
+        server_password (str): Password to decrypt the secret key stored in the server certificate.

     Raises:
         ValueError: If input key is not the attribute in parameter server training mode context.

     Examples:
         >>> context.set_ps_context(enable_ps=True)
         >>> context.set_ps_context(enable_ps=True, enable_ssl=True, client_password='123456', server_password='123456')
     """
     for key, value in kwargs.items():
         if key not in _set_ps_context_func_map:
@@ -14,13 +14,13 @@
 # ============================================================================
 """Dataset help for minddata dataset"""
 import math
-import os

 from mindspore._checkparam import Validator
 from mindspore.common.dtype import pytype_to_dtype
 from .. import context, nn
 from ._utils import _exec_datagraph, _get_types_and_shapes, _construct_tensor_list
 from ..parallel._utils import _get_device_num, _get_global_rank, _need_to_full, _to_full_shapes, _get_pipeline_stages
+from ..parallel._ps_context import _is_role_worker, _is_role_pserver, _is_role_sched
 from ..ops import operations as P
@@ -52,10 +52,9 @@ def _dynamic_sink_data(dataset, dataset_iter):

 def _dynamic_sink_exception_scenario(dataset_iter):
     """The exception scenario for dynamic data is not applicable."""
-    ms_role = os.getenv("MS_ROLE")
     _, dataset_shapes = dataset_iter.types_shapes()

-    if _has_dynamic_shape(dataset_shapes) or ms_role == "MS_WORKER" or \
+    if _has_dynamic_shape(dataset_shapes) or _is_role_worker() or \
         context.get_context("mode") != context.GRAPH_MODE:
         return True
     return False
@@ -171,8 +170,7 @@ def connect_network_with_dataset(network, dataset_helper):
     if isinstance(dataset_iter, _DatasetIterNormal):
         raise RuntimeError("The API 'connect_network_with_dataset' should be called in dataset sink mode.")

-    ms_role = os.getenv("MS_ROLE")
-    if ms_role in ("MS_PSERVER", "MS_SCHED"):
+    if _is_role_sched() or _is_role_pserver():
         return network

     queue_name = dataset.__transfer_dataset__.queue_name
@@ -260,10 +258,9 @@ class DatasetHelper:
             iterclass = _DatasetIterGE
         else:
             if context.get_context("mode") == context.GRAPH_MODE:
-                ms_role = os.getenv("MS_ROLE")
-                if ms_role in ("MS_PSERVER", "MS_SCHED"):
+                if _is_role_sched() or _is_role_pserver():
                     iterclass = _DatasetIterPSServer
-                elif ms_role == "MS_WORKER":
+                elif _is_role_worker():
                     iterclass = _DatasetIterPSWork
                 elif (context.get_context("device_target") == "Ascend") or \
                         (context.get_context("device_target") == "GPU"):
@@ -365,9 +362,8 @@ class _DatasetIter:

         if not hasattr(dataset, '__transfer_dataset__'):
             if hasattr(dataset, '__loop_size__'):
-                ms_role = os.getenv("MS_ROLE")
                 # PS mode does not support loop sink and need get the real sink size.
-                if ms_role != "MS_WORKER":
+                if not _is_role_worker():
                     self.sink_size = dataset.__loop_size__
             create_data_info_queue = (sink_size == 1 and self.sink_count == 1 and context.get_context(
                 "device_target") == "Ascend")
@@ -413,10 +409,9 @@ class _DatasetIter:
     def get_sink_size(self):
         """get sink_size to device"""
         sink_size = 1
-        ms_role = os.getenv("MS_ROLE")
         if hasattr(self.dataset, '__loop_size__'):
             sink_size = self.dataset.__loop_size__
-        elif ms_role == "MS_WORKER":
+        elif _is_role_worker():
             # PS mode does not support loop sink.
             sink_size = 1
         else:
@@ -14,12 +14,12 @@
 # ============================================================================
 """Dataset help for minddata dataset"""
 import math
-import os
 from mindspore._checkparam import Validator
 from mindspore import context
 from mindspore.train._utils import _exec_datagraph, _get_types_and_shapes
 from mindspore.nn.wrap import GetNextSingleOp
 from mindspore.parallel._utils import _get_device_num, _need_to_full, _to_full_shapes
+from mindspore.parallel._ps_context import _is_role_pserver, _is_role_sched


 def _send_data(dataset, epoch_num):
@@ -160,8 +160,7 @@ class _DatasetIterMSLoopSink(_DatasetIter):
         loop_size = dataset.__loop_size__ + iter_first_order
         sink_count = int(sink_size / loop_size) * 2
         self.sink_count = sink_count
-        ms_role = os.getenv("MS_ROLE")
-        if ms_role in ("MS_PSERVER", "MS_SCHED"):
+        if _is_role_pserver() or _is_role_sched():
             self.sink_count = 1
         # for self._parallel_mode equal to semi_auto_parallel or auto_parallel, and not using full_batch,
         # use a complete tensor to compile, and slice tensor to run. The batch dimension of tensors for