From a9a0f590e66c1881bac9c8c9a91a0026d22feaa9 Mon Sep 17 00:00:00 2001 From: ZPaC Date: Thu, 15 Jul 2021 16:24:21 +0800 Subject: [PATCH] Fix master static check --- .../kernel_compiler/cpu/ps/pserver_kernel.cc | 4 +- .../ccsrc/backend/session/ascend_session.cc | 2 +- .../ccsrc/backend/session/cpu_session.cc | 6 +-- mindspore/ccsrc/backend/session/executor.cc | 2 +- .../ccsrc/backend/session/session_basic.cc | 4 +- .../fl/armour/cipher/cipher_reconstruct.cc | 10 ++-- .../fl/armour/cipher/cipher_reconstruct.h | 4 +- .../ccsrc/fl/armour/cipher/cipher_shares.cc | 8 +-- .../ccsrc/fl/armour/cipher/cipher_shares.h | 10 ++-- .../ccsrc/fl/armour/cipher/cipher_unmask.cc | 2 +- .../ccsrc/fl/server/collective_ops_impl.cc | 1 - .../ccsrc/fl/server/consistent_hash_ring.h | 1 + .../fl/server/distributed_count_service.cc | 12 ++--- .../fl/server/distributed_metadata_store.cc | 12 ++--- mindspore/ccsrc/fl/server/executor.cc | 9 ++-- mindspore/ccsrc/fl/server/executor.h | 2 + mindspore/ccsrc/fl/server/iteration.cc | 50 +++++++++---------- mindspore/ccsrc/fl/server/iteration.h | 2 +- .../server/kernel/round/get_model_kernel.cc | 3 +- .../server/kernel/round/pull_weight_kernel.cc | 9 ++-- .../server/kernel/round/pull_weight_kernel.h | 7 +-- .../server/kernel/round/push_weight_kernel.cc | 8 +-- .../server/kernel/round/push_weight_kernel.h | 6 +-- .../fl/server/kernel/round/round_kernel.cc | 6 +-- .../fl/server/kernel/round/round_kernel.h | 1 + .../kernel/round/start_fl_job_kernel.cc | 8 +-- .../kernel/round/update_model_kernel.cc | 4 +- .../ccsrc/fl/server/parameter_aggregator.cc | 24 ++++----- mindspore/ccsrc/fl/server/round.cc | 14 +++--- mindspore/ccsrc/fl/server/round.h | 4 +- mindspore/ccsrc/fl/server/server.cc | 33 +++++++----- mindspore/ccsrc/fl/server/server.h | 10 ++-- mindspore/ccsrc/fl/worker/fl_worker.cc | 18 ++++--- .../parallel/ops_info/gather_v2_p_info.cc | 6 +-- .../frontend/parallel/ops_info/unique_info.cc | 6 +-- .../ccsrc/frontend/parallel/step_parallel.cc | 2 +- 
mindspore/ccsrc/pipeline/jit/action.cc | 14 +++--- mindspore/ccsrc/pipeline/jit/init.cc | 2 +- mindspore/ccsrc/pipeline/jit/pass.cc | 6 +-- mindspore/ccsrc/ps/core/server_node.cc | 2 +- mindspore/ccsrc/ps/core/server_node.h | 2 +- mindspore/ccsrc/ps/ps_context.cc | 30 +++++------ mindspore/ccsrc/ps/util.cc | 3 +- mindspore/ccsrc/ps/util.h | 3 +- .../ccsrc/runtime/device/kernel_runtime.cc | 8 +-- mindspore/context.py | 1 + mindspore/ops/operations/other_ops.py | 3 ++ mindspore/parallel/_ps_context.py | 2 + 48 files changed, 207 insertions(+), 179 deletions(-) diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/pserver_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/pserver_kernel.cc index d4eef639de2..e2074c70a4a 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/pserver_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/pserver_kernel.cc @@ -20,8 +20,8 @@ namespace mindspore { namespace kernel { namespace ps { void PServerKernel::Shard(std::vector *shape, int axis) { - (*shape)[axis] = - LongToSize(Util::LocalShard(SizeToLong((*shape)[axis]), SizeToLong(rank_id_), SizeToLong(pserver_num_))); + (*shape)[IntToSize(axis)] = + LongToSize(Util::LocalShard(SizeToLong((*shape)[IntToSize(axis)]), SizeToLong(rank_id_), SizeToLong(pserver_num_))); } } // namespace ps } // namespace kernel diff --git a/mindspore/ccsrc/backend/session/ascend_session.cc b/mindspore/ccsrc/backend/session/ascend_session.cc index fe8f7790585..36f3ed6e44f 100644 --- a/mindspore/ccsrc/backend/session/ascend_session.cc +++ b/mindspore/ccsrc/backend/session/ascend_session.cc @@ -350,7 +350,7 @@ void AscendSession::LoadInputData(const std::shared_ptr &kernel_gra size = abstract::ShapeSize(shape_tmp) * abstract::TypeIdSize(tensor->data_type()); } if (AnfAlgo::OutputAddrExist(input_node, 0) && TensorNeedSync(input_node, tensor)) { -#if (ENABLE_CPU && !_WIN32) +#if ((defined ENABLE_CPU) && (!defined _WIN32)) const std::string ¶m_name = 
input_node->fullname_with_scope(); if (ps::ps_cache_instance.IsHashTable(param_name)) { continue; diff --git a/mindspore/ccsrc/backend/session/cpu_session.cc b/mindspore/ccsrc/backend/session/cpu_session.cc index 364bfde07e9..3ab9d41e929 100644 --- a/mindspore/ccsrc/backend/session/cpu_session.cc +++ b/mindspore/ccsrc/backend/session/cpu_session.cc @@ -34,7 +34,7 @@ #include "debug/anf_ir_dump.h" #include "debug/dump_proto.h" #include "debug/data_dump/dump_json_parser.h" -#if (ENABLE_CPU && !_WIN32) +#if ((defined ENABLE_CPU) && (!defined _WIN32)) #include "ps/util.h" #include "ps/ps_context.h" #endif @@ -77,7 +77,7 @@ void CPUSession::Reorder(std::vector *node_list) { AnfAlgo::ReorderPos void CPUSession::Optimize(const std::shared_ptr &kernel_graph) { auto optimizer = std::make_shared(); auto pm = std::make_shared(); -#if (ENABLE_CPU && !_WIN32) +#if ((defined ENABLE_CPU) && (!defined _WIN32)) auto ms_context = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(ms_context); if (ms_context->get_param(MS_CTX_EXECUTION_MODE) != kPynativeMode && ps::PSContext::instance()->is_ps_mode()) { @@ -193,7 +193,7 @@ void CPUSession::PreExecuteGraph(const std::shared_ptr &kernel_grap MS_LOG(INFO) << "Bind input output address"; runtime_.BindInputOutput(kernel_graph.get(), inputs, outputs); -#if (ENABLE_CPU && !_WIN32) +#if ((defined ENABLE_CPU) && (!defined _WIN32)) InitPSParamAndOptim(kernel_graph, inputs); #endif } diff --git a/mindspore/ccsrc/backend/session/executor.cc b/mindspore/ccsrc/backend/session/executor.cc index 01b9f230d14..bf19115f100 100644 --- a/mindspore/ccsrc/backend/session/executor.cc +++ b/mindspore/ccsrc/backend/session/executor.cc @@ -22,7 +22,7 @@ #include "utils/comm_manager.h" #include "utils/scoped_long_running.h" #include "pybind_api/ir/tensor_py.h" -#if (ENABLE_CPU && !_WIN32) +#if ((defined ENABLE_CPU) && (!defined _WIN32)) #include "ps/ps_cache/ps_cache_manager.h" #endif diff --git a/mindspore/ccsrc/backend/session/session_basic.cc 
b/mindspore/ccsrc/backend/session/session_basic.cc index 7d7df0418fb..118d5383391 100644 --- a/mindspore/ccsrc/backend/session/session_basic.cc +++ b/mindspore/ccsrc/backend/session/session_basic.cc @@ -44,7 +44,7 @@ #include "debug/common.h" #include "utils/trace_base.h" #include "frontend/parallel/context.h" -#if (ENABLE_CPU && !_WIN32) +#if ((defined ENABLE_CPU) && (!defined _WIN32)) #include "ps/ps_cache/ps_cache_manager.h" #include "ps/constants.h" #include "ps/util.h" @@ -2483,7 +2483,7 @@ void SessionBasic::DumpGraph(const std::shared_ptr &kernel_graph) { void SessionBasic::UnifyMindIR(const KernelGraphPtr &graph) { opt::CommonUnifyMindIROptimization(graph); } -#if (ENABLE_CPU && !_WIN32) +#if ((defined ENABLE_CPU) && (!defined _WIN32)) void SessionBasic::InitPsWorker(const KernelGraphPtr &kernel_graph) { if (!ps::PSContext::instance()->is_worker()) { return; diff --git a/mindspore/ccsrc/fl/armour/cipher/cipher_reconstruct.cc b/mindspore/ccsrc/fl/armour/cipher/cipher_reconstruct.cc index 51a9406f2fe..5e6138019df 100644 --- a/mindspore/ccsrc/fl/armour/cipher/cipher_reconstruct.cc +++ b/mindspore/ccsrc/fl/armour/cipher/cipher_reconstruct.cc @@ -157,10 +157,10 @@ bool CipherReconStruct::ReconstructSecretsGenNoise(const std::vector &cl } // reconstruct secrets -bool CipherReconStruct::ReconstructSecrets(const int cur_iterator, const std::string &next_req_time, - const schema::SendReconstructSecret *reconstruct_secret_req, - std::shared_ptr reconstruct_secret_resp_builder, - const std::vector &client_list) { +bool CipherReconStruct::ReconstructSecrets( + const int cur_iterator, const std::string &next_req_time, const schema::SendReconstructSecret *reconstruct_secret_req, + const std::shared_ptr &reconstruct_secret_resp_builder, + const std::vector &client_list) { MS_LOG(INFO) << "CipherReconStruct::ReconstructSecrets START"; clock_t start_time = clock(); if (reconstruct_secret_req == nullptr || reconstruct_secret_resp_builder == nullptr) { @@ -285,7 +285,7 @@ 
void CipherReconStruct::ClearReconstructSecrets() { MS_LOG(INFO) << "CipherReconStruct::ClearReconstructSecrets Success"; } -void CipherReconStruct::BuildReconstructSecretsRsp(std::shared_ptr fbb, +void CipherReconStruct::BuildReconstructSecretsRsp(const std::shared_ptr &fbb, const schema::ResponseCode retcode, const std::string &reason, const int iteration, const std::string &next_req_time) { auto fbs_reason = fbb->CreateString(reason); diff --git a/mindspore/ccsrc/fl/armour/cipher/cipher_reconstruct.h b/mindspore/ccsrc/fl/armour/cipher/cipher_reconstruct.h index d43fad277e9..83696da8126 100644 --- a/mindspore/ccsrc/fl/armour/cipher/cipher_reconstruct.h +++ b/mindspore/ccsrc/fl/armour/cipher/cipher_reconstruct.h @@ -44,11 +44,11 @@ class CipherReconStruct { // reconstruct secret mask bool ReconstructSecrets(const int cur_iterator, const std::string &next_req_time, const schema::SendReconstructSecret *reconstruct_secret_req, - std::shared_ptr reconstruct_secret_resp_builder, + const std::shared_ptr &reconstruct_secret_resp_builder, const std::vector &client_list); // build response code of reconstruct secret. - void BuildReconstructSecretsRsp(std::shared_ptr fbb, const schema::ResponseCode retcode, + void BuildReconstructSecretsRsp(const std::shared_ptr &fbb, const schema::ResponseCode retcode, const std::string &reason, const int iteration, const std::string &next_req_time); // clear the shared memory. 
diff --git a/mindspore/ccsrc/fl/armour/cipher/cipher_shares.cc b/mindspore/ccsrc/fl/armour/cipher/cipher_shares.cc index 6de5d973f97..8a80e1eca1c 100644 --- a/mindspore/ccsrc/fl/armour/cipher/cipher_shares.cc +++ b/mindspore/ccsrc/fl/armour/cipher/cipher_shares.cc @@ -21,7 +21,7 @@ namespace mindspore { namespace armour { bool CipherShares::ShareSecrets(const int cur_iterator, const schema::RequestShareSecrets *share_secrets_req, - std::shared_ptr share_secrets_resp_builder, + const std::shared_ptr &share_secrets_resp_builder, const string next_req_time) { MS_LOG(INFO) << "CipherShares::ShareSecrets START"; if (share_secrets_req == nullptr) { @@ -95,7 +95,7 @@ bool CipherShares::ShareSecrets(const int cur_iterator, const schema::RequestSha } bool CipherShares::GetSecrets(const schema::GetShareSecrets *get_secrets_req, - std::shared_ptr get_secrets_resp_builder, + const std::shared_ptr &get_secrets_resp_builder, const std::string &next_req_time) { MS_LOG(INFO) << "CipherShares::GetSecrets START"; clock_t start_time = clock(); @@ -180,7 +180,7 @@ bool CipherShares::GetSecrets(const schema::GetShareSecrets *get_secrets_req, } void CipherShares::BuildGetSecretsRsp( - std::shared_ptr get_secrets_resp_builder, schema::ResponseCode retcode, int iteration, + const std::shared_ptr &get_secrets_resp_builder, schema::ResponseCode retcode, int iteration, std::string next_req_time, std::vector> *encrypted_shares) { int rsp_retcode = retcode; int rsp_iteration = iteration; @@ -199,7 +199,7 @@ void CipherShares::BuildGetSecretsRsp( return; } -void CipherShares::BuildShareSecretsRsp(std::shared_ptr share_secrets_resp_builder, +void CipherShares::BuildShareSecretsRsp(const std::shared_ptr &share_secrets_resp_builder, const schema::ResponseCode retcode, const string &reason, const string &next_req_time, const int iteration) { auto rsp_reason = share_secrets_resp_builder->CreateString(reason); diff --git a/mindspore/ccsrc/fl/armour/cipher/cipher_shares.h 
b/mindspore/ccsrc/fl/armour/cipher/cipher_shares.h index 1915db0dfc9..498643b8f4f 100644 --- a/mindspore/ccsrc/fl/armour/cipher/cipher_shares.h +++ b/mindspore/ccsrc/fl/armour/cipher/cipher_shares.h @@ -43,17 +43,19 @@ class CipherShares { // handle the client's request of share secrets. bool ShareSecrets(const int cur_iterator, const schema::RequestShareSecrets *share_secrets_req, - std::shared_ptr share_secrets_resp_builder, const string next_req_time); + const std::shared_ptr &share_secrets_resp_builder, + const string next_req_time); // handle the client's request of get secrets. bool GetSecrets(const schema::GetShareSecrets *get_secrets_req, - std::shared_ptr get_secrets_resp_builder, const std::string &next_req_time); + const std::shared_ptr &get_secrets_resp_builder, + const std::string &next_req_time); // build response code of share secrets. - void BuildShareSecretsRsp(std::shared_ptr share_secrets_resp_builder, + void BuildShareSecretsRsp(const std::shared_ptr &share_secrets_resp_builder, const schema::ResponseCode retcode, const string &reason, const string &next_req_time, const int iteration); // build response code of get secrets. - void BuildGetSecretsRsp(std::shared_ptr get_secrets_resp_builder, + void BuildGetSecretsRsp(const std::shared_ptr &get_secrets_resp_builder, const schema::ResponseCode retcode, const int iteration, std::string next_req_time, std::vector> *encrypted_shares); // clear the shared memory. 
diff --git a/mindspore/ccsrc/fl/armour/cipher/cipher_unmask.cc b/mindspore/ccsrc/fl/armour/cipher/cipher_unmask.cc index b19279973e7..4b5a29d1cad 100644 --- a/mindspore/ccsrc/fl/armour/cipher/cipher_unmask.cc +++ b/mindspore/ccsrc/fl/armour/cipher/cipher_unmask.cc @@ -26,7 +26,7 @@ bool CipherUnmask::UnMask(const std::map &data) { clock_t start_time = clock(); std::vector noise; - cipher_init_->cipher_meta_storage_.GetClientNoisesFromServer(fl::server::kCtxClientNoises, &noise); + (void)cipher_init_->cipher_meta_storage_.GetClientNoisesFromServer(fl::server::kCtxClientNoises, &noise); if (noise.size() != cipher_init_->featuremap_) { MS_LOG(ERROR) << " CipherMgr UnMask ERROR"; return false; diff --git a/mindspore/ccsrc/fl/server/collective_ops_impl.cc b/mindspore/ccsrc/fl/server/collective_ops_impl.cc index cbd9e507a1e..1f33a0bb8f9 100644 --- a/mindspore/ccsrc/fl/server/collective_ops_impl.cc +++ b/mindspore/ccsrc/fl/server/collective_ops_impl.cc @@ -114,7 +114,6 @@ bool CollectiveOpsImpl::RingAllReduce(const void *sendbuff, void *recvbuff, size std::shared_ptr> recv_str; auto recv_req_id = server_node_->CollectiveReceiveAsync(ps::core::NodeRole::SERVER, recv_from_rank, &recv_str); - if (!server_node_->CollectiveWait(recv_req_id)) { MS_LOG(ERROR) << "CollectiveWait " << recv_req_id << " failed."; return false; diff --git a/mindspore/ccsrc/fl/server/consistent_hash_ring.h b/mindspore/ccsrc/fl/server/consistent_hash_ring.h index ddf19e66620..9eb14875d91 100644 --- a/mindspore/ccsrc/fl/server/consistent_hash_ring.h +++ b/mindspore/ccsrc/fl/server/consistent_hash_ring.h @@ -24,6 +24,7 @@ namespace mindspore { namespace fl { namespace server { +constexpr uint32_t kDefaultVirtualNodeNum = 32; // To support distributed storage and make servers easy to scale-out and scale-in for a large load of metadata in // server, we use class ConsistentHashRing to help servers find out which metadata is stored in which server node. 
diff --git a/mindspore/ccsrc/fl/server/distributed_count_service.cc b/mindspore/ccsrc/fl/server/distributed_count_service.cc index 02213123bf9..4d8a74507d7 100644 --- a/mindspore/ccsrc/fl/server/distributed_count_service.cc +++ b/mindspore/ccsrc/fl/server/distributed_count_service.cc @@ -104,7 +104,7 @@ bool DistributedCountService::Count(const std::string &name, const std::string & } CountResponse count_rsp; - count_rsp.ParseFromArray(report_cnt_rsp_msg->data(), SizeToInt(report_cnt_rsp_msg->size())); + (void)count_rsp.ParseFromArray(report_cnt_rsp_msg->data(), SizeToInt(report_cnt_rsp_msg->size())); if (!count_rsp.result()) { MS_LOG(ERROR) << "Reporting count failed:" << count_rsp.reason(); if (reason != nullptr && count_rsp.reason().find(kNetworkError) != std::string::npos) { @@ -138,8 +138,8 @@ bool DistributedCountService::CountReachThreshold(const std::string &name) { } CountReachThresholdResponse count_reach_threshold_rsp; - count_reach_threshold_rsp.ParseFromArray(query_cnt_enough_rsp_msg->data(), - SizeToInt(query_cnt_enough_rsp_msg->size())); + (void)count_reach_threshold_rsp.ParseFromArray(query_cnt_enough_rsp_msg->data(), + SizeToInt(query_cnt_enough_rsp_msg->size())); return count_reach_threshold_rsp.is_enough(); } } @@ -178,7 +178,7 @@ void DistributedCountService::HandleCountRequest(const std::shared_ptrdata(), SizeToInt(message->len())); + (void)report_count_req.ParseFromArray(message->data(), SizeToInt(message->len())); const std::string &name = report_count_req.name(); const std::string &id = report_count_req.id(); @@ -228,7 +228,7 @@ void DistributedCountService::HandleCountReachThresholdRequest( } CountReachThresholdRequest count_reach_threshold_req; - count_reach_threshold_req.ParseFromArray(message->data(), SizeToInt(message->len())); + (void)count_reach_threshold_req.ParseFromArray(message->data(), SizeToInt(message->len())); const std::string &name = count_reach_threshold_req.name(); std::unique_lock lock(mutex_[name]); @@ -256,7 +256,7 @@ 
void DistributedCountService::HandleCounterEvent(const std::shared_ptrSendResponse(couter_event_rsp_msg.data(), couter_event_rsp_msg.size(), message); CounterEvent counter_event; - counter_event.ParseFromArray(message->data(), SizeToInt(message->len())); + (void)counter_event.ParseFromArray(message->data(), SizeToInt(message->len())); const auto &type = counter_event.type(); const auto &name = counter_event.name(); diff --git a/mindspore/ccsrc/fl/server/distributed_metadata_store.cc b/mindspore/ccsrc/fl/server/distributed_metadata_store.cc index 0e5f166feb8..a7c32fec499 100644 --- a/mindspore/ccsrc/fl/server/distributed_metadata_store.cc +++ b/mindspore/ccsrc/fl/server/distributed_metadata_store.cc @@ -141,7 +141,7 @@ PBMetadata DistributedMetadataStore::GetMetadata(const std::string &name) { MS_LOG(ERROR) << "Sending getting metadata message to server " << stored_rank << " failed."; return get_metadata_rsp; } - get_metadata_rsp.ParseFromArray(get_meta_rsp_msg->data(), SizeToInt(get_meta_rsp_msg->size())); + (void)get_metadata_rsp.ParseFromArray(get_meta_rsp_msg->data(), SizeToInt(get_meta_rsp_msg->size())); return get_metadata_rsp; } } @@ -165,7 +165,7 @@ bool DistributedMetadataStore::ReInitForScaling() { } void DistributedMetadataStore::InitHashRing() { - router_ = std::make_shared(32); + router_ = std::make_shared(kDefaultVirtualNodeNum); MS_EXCEPTION_IF_NULL(router_); for (uint32_t i = 0; i < server_num_; i++) { bool ret = router_->Insert(i); @@ -184,7 +184,7 @@ void DistributedMetadataStore::HandleUpdateMetadataRequest(const std::shared_ptr } PBMetadataWithName meta_with_name; - meta_with_name.ParseFromArray(message->data(), SizeToInt(message->len())); + (void)meta_with_name.ParseFromArray(message->data(), SizeToInt(message->len())); const std::string &name = meta_with_name.name(); MS_LOG(INFO) << "Update metadata for " << name; @@ -195,7 +195,7 @@ void DistributedMetadataStore::HandleUpdateMetadataRequest(const std::shared_ptr } else { update_meta_rsp_msg = 
"Success"; } - communicator_->SendResponse(update_meta_rsp_msg.data(), update_meta_rsp_msg.size(), message); + (void)communicator_->SendResponse(update_meta_rsp_msg.data(), update_meta_rsp_msg.size(), message); return; } @@ -206,14 +206,14 @@ void DistributedMetadataStore::HandleGetMetadataRequest(const std::shared_ptrdata(), message->len()); + (void)get_metadata_req.ParseFromArray(message->data(), message->len()); const std::string &name = get_metadata_req.name(); MS_LOG(INFO) << "Getting metadata for " << name; std::unique_lock lock(mutex_[name]); PBMetadata stored_meta = metadata_[name]; std::string getting_meta_rsp_msg = stored_meta.SerializeAsString(); - communicator_->SendResponse(getting_meta_rsp_msg.data(), getting_meta_rsp_msg.size(), message); + (void)communicator_->SendResponse(getting_meta_rsp_msg.data(), getting_meta_rsp_msg.size(), message); return; } diff --git a/mindspore/ccsrc/fl/server/executor.cc b/mindspore/ccsrc/fl/server/executor.cc index 151a48014d6..671d0252d17 100644 --- a/mindspore/ccsrc/fl/server/executor.cc +++ b/mindspore/ccsrc/fl/server/executor.cc @@ -67,7 +67,7 @@ bool Executor::HandlePush(const std::string ¶m_name, const UploadData &uploa // Push operation needs to wait until the pulling process is done. while (!param_aggr->IsPullingDone()) { lock.unlock(); - std::this_thread::sleep_for(std::chrono::milliseconds(5)); + std::this_thread::sleep_for(std::chrono::milliseconds(kThreadSleepTime)); lock.lock(); } @@ -192,7 +192,7 @@ AddressPtr Executor::HandlePull(const std::string ¶m_name) { // Pulling must wait until the optimizing process is done. 
while (!param_aggr->IsOptimizingDone()) { lock.unlock(); - std::this_thread::sleep_for(std::chrono::milliseconds(5)); + std::this_thread::sleep_for(std::chrono::milliseconds(kThreadSleepTime)); lock.lock(); } AddressPtr addr = param_aggr->Pull(); @@ -314,7 +314,10 @@ bool Executor::InitParamAggregator(const FuncGraphPtr &func_graph) { param_names_.push_back(param_name); param_aggrs_[param_name] = param_aggr; parameter_mutex_[param_name]; - param_aggr->Init(cnode, aggregation_count_); + if (!param_aggr->Init(cnode, aggregation_count_)) { + MS_LOG(EXCEPTION) << "Initializing parameter aggregator failed for " << param_name; + return false; + } MS_LOG(DEBUG) << "Initializing control flow for param_name " << param_name << " success."; } return true; diff --git a/mindspore/ccsrc/fl/server/executor.h b/mindspore/ccsrc/fl/server/executor.h index 462f967aab4..1ba82d9a852 100644 --- a/mindspore/ccsrc/fl/server/executor.h +++ b/mindspore/ccsrc/fl/server/executor.h @@ -33,6 +33,8 @@ namespace mindspore { namespace fl { namespace server { +constexpr int kThreadSleepTime = 5; + // Executor is the entrance for server to handle aggregation, optimizing, model querying, etc. It handles // logics relevant to kernel launching. class Executor { diff --git a/mindspore/ccsrc/fl/server/iteration.cc b/mindspore/ccsrc/fl/server/iteration.cc index 4097c22e14d..eaca0c2e8f2 100644 --- a/mindspore/ccsrc/fl/server/iteration.cc +++ b/mindspore/ccsrc/fl/server/iteration.cc @@ -74,10 +74,10 @@ void Iteration::InitRounds(const std::vector &round) { - return round->check_timeout() ? total + round->time_window() : total; - }); + size_t iteration_time_window = std::accumulate(rounds_.begin(), rounds_.end(), IntToSize(0), + [](size_t total, const std::shared_ptr &round) { + return round->check_timeout() ? 
total + round->time_window() : total; + }); LocalMetaStore::GetInstance().put_value(kCtxTotalTimeoutDuration, iteration_time_window); MS_LOG(INFO) << "Time window for one iteration is " << iteration_time_window; return; @@ -162,7 +162,7 @@ bool Iteration::ReInitForScaling(uint32_t server_num, uint32_t server_rank) { return true; } -const std::vector> &Iteration::rounds() { return rounds_; } +const std::vector> &Iteration::rounds() const { return rounds_; } bool Iteration::is_last_iteration_valid() const { return is_last_iteration_valid_; } @@ -182,7 +182,7 @@ bool Iteration::SyncIteration(uint32_t rank) { } SyncIterationResponse sync_iter_rsp; - sync_iter_rsp.ParseFromArray(sync_iter_rsp_msg->data(), sync_iter_rsp_msg->size()); + (void)sync_iter_rsp.ParseFromArray(sync_iter_rsp_msg->data(), SizeToInt(sync_iter_rsp_msg->size())); iteration_num_ = sync_iter_rsp.iteration(); MS_LOG(INFO) << "After synchronizing, server " << rank << " current iteration number is " << sync_iter_rsp.iteration(); @@ -196,14 +196,14 @@ void Iteration::HandleSyncIterationRequest(const std::shared_ptrdata(), message->len()); + (void)sync_iter_req.ParseFromArray(message->data(), SizeToInt(message->len())); uint32_t rank = sync_iter_req.rank(); MS_LOG(INFO) << "Synchronizing iteration request from rank " << rank; SyncIterationResponse sync_iter_rsp; sync_iter_rsp.set_iteration(iteration_num_); std::string sync_iter_rsp_msg = sync_iter_rsp.SerializeAsString(); - communicator_->SendResponse(sync_iter_rsp_msg.data(), sync_iter_rsp_msg.size(), message); + (void)communicator_->SendResponse(sync_iter_rsp_msg.data(), sync_iter_rsp_msg.size(), message); } bool Iteration::IsMoveToNextIterRequestReentrant(uint64_t iteration_num) { @@ -238,11 +238,11 @@ void Iteration::HandleNotifyLeaderMoveToNextIterRequest(const std::shared_ptrSendResponse(notify_leader_to_next_iter_rsp.SerializeAsString().data(), - notify_leader_to_next_iter_rsp.SerializeAsString().size(), message); + 
(void)communicator_->SendResponse(notify_leader_to_next_iter_rsp.SerializeAsString().data(), + notify_leader_to_next_iter_rsp.SerializeAsString().size(), message); NotifyLeaderMoveToNextIterRequest notify_leader_to_next_iter_req; - notify_leader_to_next_iter_req.ParseFromArray(message->data(), SizeToInt(message->len())); + (void)notify_leader_to_next_iter_req.ParseFromArray(message->data(), SizeToInt(message->len())); const auto &rank = notify_leader_to_next_iter_req.rank(); const auto &is_last_iter_valid = notify_leader_to_next_iter_req.is_last_iter_valid(); const auto &iter_num = notify_leader_to_next_iter_req.iter_num(); @@ -296,7 +296,7 @@ bool Iteration::BroadcastPrepareForNextIterRequest(bool is_last_iter_valid, cons } MS_LOG(INFO) << "Offline server " << rank << " preparing for next iteration success."; }); - std::this_thread::sleep_for(std::chrono::milliseconds(1000)); + std::this_thread::sleep_for(std::chrono::milliseconds(kServerSleepTimeForNetworking)); return true; } @@ -306,15 +306,15 @@ void Iteration::HandlePrepareForNextIterRequest(const std::shared_ptrdata(), message->len()); + (void)prepare_next_iter_req.ParseFromArray(message->data(), SizeToInt(message->len())); const auto &reason = prepare_next_iter_req.reason(); MS_LOG(INFO) << "Prepare next iteration for this rank " << server_node_->rank_id() << ", reason: " << reason; PrepareForNextIter(); PrepareForNextIterResponse prepare_next_iter_rsp; prepare_next_iter_rsp.set_result("success"); - communicator_->SendResponse(prepare_next_iter_rsp.SerializeAsString().data(), - prepare_next_iter_rsp.SerializeAsString().size(), message); + (void)communicator_->SendResponse(prepare_next_iter_rsp.SerializeAsString().data(), + prepare_next_iter_rsp.SerializeAsString().size(), message); } void Iteration::PrepareForNextIter() { @@ -347,11 +347,11 @@ void Iteration::HandleMoveToNextIterRequest(const std::shared_ptrSendResponse(proceed_to_next_iter_rsp.SerializeAsString().data(), - 
proceed_to_next_iter_rsp.SerializeAsString().size(), message); + (void)communicator_->SendResponse(proceed_to_next_iter_rsp.SerializeAsString().data(), + proceed_to_next_iter_rsp.SerializeAsString().size(), message); MoveToNextIterRequest proceed_to_next_iter_req; - proceed_to_next_iter_req.ParseFromArray(message->data(), SizeToInt(message->len())); + (void)proceed_to_next_iter_req.ParseFromArray(message->data(), SizeToInt(message->len())); const auto &is_last_iter_valid = proceed_to_next_iter_req.is_last_iter_valid(); const auto &last_iter_num = proceed_to_next_iter_req.last_iter_num(); const auto &reason = proceed_to_next_iter_req.reason(); @@ -370,12 +370,12 @@ void Iteration::Next(bool is_iteration_valid, const std::string &reason) { if (is_iteration_valid) { // Store the model which is successfully aggregated for this iteration. const auto &model = Executor::GetInstance().GetModel(); - ModelStore::GetInstance().StoreModelByIterNum(iteration_num_, model); + (void)ModelStore::GetInstance().StoreModelByIterNum(iteration_num_, model); MS_LOG(INFO) << "Iteration " << iteration_num_ << " is successfully finished."; } else { // Store last iteration's model because this iteration is considered as invalid. const auto &model = ModelStore::GetInstance().GetModelByIterNum(iteration_num_ - 1); - ModelStore::GetInstance().StoreModelByIterNum(iteration_num_, model); + (void)ModelStore::GetInstance().StoreModelByIterNum(iteration_num_, model); MS_LOG(WARNING) << "Iteration " << iteration_num_ << " is invalid. Reason: " << reason; } @@ -405,7 +405,7 @@ void Iteration::HandleEndLastIterRequest(const std::shared_ptrdata(), SizeToInt(message->len())); + (void)end_last_iter_req.ParseFromArray(message->data(), SizeToInt(message->len())); const auto &last_iter_num = end_last_iter_req.last_iter_num(); // If the iteration number is not matched, return error. 
if (last_iter_num != iteration_num_) { @@ -413,8 +413,8 @@ void Iteration::HandleEndLastIterRequest(const std::shared_ptrSendResponse(end_last_iter_rsp.SerializeAsString().data(), - end_last_iter_rsp.SerializeAsString().size(), message); + (void)communicator_->SendResponse(end_last_iter_rsp.SerializeAsString().data(), + end_last_iter_rsp.SerializeAsString().size(), message); return; } @@ -422,8 +422,8 @@ void Iteration::HandleEndLastIterRequest(const std::shared_ptrSendResponse(end_last_iter_rsp.SerializeAsString().data(), - end_last_iter_rsp.SerializeAsString().size(), message); + (void)communicator_->SendResponse(end_last_iter_rsp.SerializeAsString().data(), + end_last_iter_rsp.SerializeAsString().size(), message); } void Iteration::EndLastIter() { diff --git a/mindspore/ccsrc/fl/server/iteration.h b/mindspore/ccsrc/fl/server/iteration.h index 6e8f89de713..28f0da69bf5 100644 --- a/mindspore/ccsrc/fl/server/iteration.h +++ b/mindspore/ccsrc/fl/server/iteration.h @@ -79,7 +79,7 @@ class Iteration { // The server number after scaling is required in some rounds. 
bool ReInitForScaling(uint32_t server_num, uint32_t server_rank); - const std::vector> &rounds(); + const std::vector> &rounds() const; bool is_last_iteration_valid() const; diff --git a/mindspore/ccsrc/fl/server/kernel/round/get_model_kernel.cc b/mindspore/ccsrc/fl/server/kernel/round/get_model_kernel.cc index 046390f40ae..4876780d451 100644 --- a/mindspore/ccsrc/fl/server/kernel/round/get_model_kernel.cc +++ b/mindspore/ccsrc/fl/server/kernel/round/get_model_kernel.cc @@ -70,10 +70,9 @@ void GetModelKernel::GetModel(const schema::RequestGetModel *get_model_req, cons auto next_req_time = LocalMetaStore::GetInstance().value(kCtxIterationNextRequestTimestamp); std::map feature_maps; size_t current_iter = LocalMetaStore::GetInstance().curr_iter_num(); - size_t get_model_iter = static_cast(get_model_req->iteration()); + size_t get_model_iter = IntToSize(get_model_req->iteration()); const auto &iter_to_model = ModelStore::GetInstance().iteration_to_model(); size_t latest_iter_num = iter_to_model.rbegin()->first; - // If this iteration is not finished yet, return ResponseCode_SucNotReady so that clients could get model later. 
if ((current_iter == get_model_iter && latest_iter_num != current_iter) || current_iter == get_model_iter - 1) { std::string reason = "The model is not ready yet for iteration " + std::to_string(get_model_iter) + diff --git a/mindspore/ccsrc/fl/server/kernel/round/pull_weight_kernel.cc b/mindspore/ccsrc/fl/server/kernel/round/pull_weight_kernel.cc index 53bf7041bf1..17aa2cfb640 100644 --- a/mindspore/ccsrc/fl/server/kernel/round/pull_weight_kernel.cc +++ b/mindspore/ccsrc/fl/server/kernel/round/pull_weight_kernel.cc @@ -63,10 +63,11 @@ bool PullWeightKernel::Reset() { return true; } -void PullWeightKernel::PullWeight(std::shared_ptr fbb, const schema::RequestPullWeight *pull_weight_req) { +void PullWeightKernel::PullWeight(const std::shared_ptr &fbb, + const schema::RequestPullWeight *pull_weight_req) { std::map feature_maps = {}; size_t current_iter = LocalMetaStore::GetInstance().curr_iter_num(); - size_t pull_weight_iter = static_cast(pull_weight_req->iteration()); + size_t pull_weight_iter = IntToSize(pull_weight_req->iteration()); // The iteration from worker should be the same as server's, otherwise return SucNotReady so that worker could retry. 
if (pull_weight_iter != current_iter) { std::string reason = "PullWeight iteration " + std::to_string(pull_weight_iter) + @@ -110,7 +111,7 @@ void PullWeightKernel::PullWeight(std::shared_ptr fbb, const schema:: return; } -void PullWeightKernel::BuildPullWeightRsp(std::shared_ptr fbb, const schema::ResponseCode retcode, +void PullWeightKernel::BuildPullWeightRsp(const std::shared_ptr &fbb, const schema::ResponseCode retcode, const std::string &reason, size_t iteration, const std::map &feature_maps) { auto fbs_reason = fbb->CreateString(reason); @@ -127,7 +128,7 @@ void PullWeightKernel::BuildPullWeightRsp(std::shared_ptr fbb, const schema::ResponsePullWeightBuilder rsp_pull_weight_builder(*(fbb.get())); rsp_pull_weight_builder.add_retcode(retcode); rsp_pull_weight_builder.add_reason(fbs_reason); - rsp_pull_weight_builder.add_iteration(iteration); + rsp_pull_weight_builder.add_iteration(SizeToInt(iteration)); rsp_pull_weight_builder.add_feature_map(fbs_feature_maps_vector); auto rsp_pull_weight = rsp_pull_weight_builder.Finish(); fbb->Finish(rsp_pull_weight); diff --git a/mindspore/ccsrc/fl/server/kernel/round/pull_weight_kernel.h b/mindspore/ccsrc/fl/server/kernel/round/pull_weight_kernel.h index 9f26a747fe6..ef4735ecb55 100644 --- a/mindspore/ccsrc/fl/server/kernel/round/pull_weight_kernel.h +++ b/mindspore/ccsrc/fl/server/kernel/round/pull_weight_kernel.h @@ -42,9 +42,10 @@ class PullWeightKernel : public RoundKernel { bool Reset() override; private: - void PullWeight(std::shared_ptr fbb, const schema::RequestPullWeight *pull_weight_req); - void BuildPullWeightRsp(std::shared_ptr fbb, const schema::ResponseCode retcode, const std::string &reason, - size_t iteration, const std::map &feature_maps); + void PullWeight(const std::shared_ptr &fbb, const schema::RequestPullWeight *pull_weight_req); + void BuildPullWeightRsp(const std::shared_ptr &fbb, const schema::ResponseCode retcode, + const std::string &reason, size_t iteration, + const std::map &feature_maps); 
Executor *executor_; diff --git a/mindspore/ccsrc/fl/server/kernel/round/push_weight_kernel.cc b/mindspore/ccsrc/fl/server/kernel/round/push_weight_kernel.cc index a93335f6862..98071a84dba 100644 --- a/mindspore/ccsrc/fl/server/kernel/round/push_weight_kernel.cc +++ b/mindspore/ccsrc/fl/server/kernel/round/push_weight_kernel.cc @@ -67,12 +67,12 @@ void PushWeightKernel::OnLastCountEvent(const std::shared_ptr fbb, +ResultCode PushWeightKernel::PushWeight(const std::shared_ptr &fbb, const schema::RequestPushWeight *push_weight_req) { if (fbb == nullptr || push_weight_req == nullptr) { return ResultCode::kSuccessAndReturn; } - size_t iteration = static_cast(push_weight_req->iteration()); + size_t iteration = IntToSize(push_weight_req->iteration()); size_t current_iter = LocalMetaStore::GetInstance().curr_iter_num(); if (iteration != current_iter) { std::string reason = "PushWeight iteration number is invalid:" + std::to_string(iteration) + @@ -123,13 +123,13 @@ std::map PushWeightKernel::ParseFeatureMap(const schema::R return upload_feature_map; } -void PushWeightKernel::BuildPushWeightRsp(std::shared_ptr fbb, const schema::ResponseCode retcode, +void PushWeightKernel::BuildPushWeightRsp(const std::shared_ptr &fbb, const schema::ResponseCode retcode, const std::string &reason, size_t iteration) { auto fbs_reason = fbb->CreateString(reason); schema::ResponsePushWeightBuilder rsp_push_weight_builder(*(fbb.get())); rsp_push_weight_builder.add_retcode(retcode); rsp_push_weight_builder.add_reason(fbs_reason); - rsp_push_weight_builder.add_iteration(iteration); + rsp_push_weight_builder.add_iteration(SizeToInt(iteration)); auto rsp_push_weight = rsp_push_weight_builder.Finish(); fbb->Finish(rsp_push_weight); return; diff --git a/mindspore/ccsrc/fl/server/kernel/round/push_weight_kernel.h b/mindspore/ccsrc/fl/server/kernel/round/push_weight_kernel.h index 7b09d3d8601..93088176a35 100644 --- a/mindspore/ccsrc/fl/server/kernel/round/push_weight_kernel.h +++ 
b/mindspore/ccsrc/fl/server/kernel/round/push_weight_kernel.h @@ -42,10 +42,10 @@ class PushWeightKernel : public RoundKernel { void OnLastCountEvent(const std::shared_ptr &message) override; private: - ResultCode PushWeight(std::shared_ptr fbb, const schema::RequestPushWeight *push_weight_req); + ResultCode PushWeight(const std::shared_ptr &fbb, const schema::RequestPushWeight *push_weight_req); std::map ParseFeatureMap(const schema::RequestPushWeight *push_weight_req); - void BuildPushWeightRsp(std::shared_ptr fbb, const schema::ResponseCode retcode, const std::string &reason, - size_t iteration); + void BuildPushWeightRsp(const std::shared_ptr &fbb, const schema::ResponseCode retcode, + const std::string &reason, size_t iteration); Executor *executor_; uint32_t local_rank_; diff --git a/mindspore/ccsrc/fl/server/kernel/round/round_kernel.cc b/mindspore/ccsrc/fl/server/kernel/round/round_kernel.cc index 1b7dd685467..2b0a317f2e8 100644 --- a/mindspore/ccsrc/fl/server/kernel/round/round_kernel.cc +++ b/mindspore/ccsrc/fl/server/kernel/round/round_kernel.cc @@ -34,7 +34,7 @@ RoundKernel::RoundKernel() : name_(""), current_count_(0), required_count_(0), e // Detect whether there's any data needs to be released every 100 milliseconds. 
if (heap_data_to_release_.empty()) { release_lock.unlock(); - std::this_thread::sleep_for(std::chrono::milliseconds(100)); + std::this_thread::sleep_for(std::chrono::milliseconds(kReleaseDuration)); continue; } @@ -61,9 +61,9 @@ RoundKernel::~RoundKernel() { } } -void RoundKernel::OnFirstCountEvent(const std::shared_ptr &message) { return; } +void RoundKernel::OnFirstCountEvent(const std::shared_ptr &) { return; } -void RoundKernel::OnLastCountEvent(const std::shared_ptr &message) { return; } +void RoundKernel::OnLastCountEvent(const std::shared_ptr &) { return; } void RoundKernel::StopTimer() const { if (stop_timer_cb_) { diff --git a/mindspore/ccsrc/fl/server/kernel/round/round_kernel.h b/mindspore/ccsrc/fl/server/kernel/round/round_kernel.h index 4f5cab58666..c7184a89eda 100644 --- a/mindspore/ccsrc/fl/server/kernel/round/round_kernel.h +++ b/mindspore/ccsrc/fl/server/kernel/round/round_kernel.h @@ -38,6 +38,7 @@ namespace mindspore { namespace fl { namespace server { namespace kernel { +constexpr uint64_t kReleaseDuration = 100; // RoundKernel contains the main logic of server handling messages from workers. One iteration has multiple round // kernels to represent the process. They receive and parse messages from the server communication module. After // handling these messages, round kernels allocate response data and send it back. 
diff --git a/mindspore/ccsrc/fl/server/kernel/round/start_fl_job_kernel.cc b/mindspore/ccsrc/fl/server/kernel/round/start_fl_job_kernel.cc index 324de765abc..d873a38271f 100644 --- a/mindspore/ccsrc/fl/server/kernel/round/start_fl_job_kernel.cc +++ b/mindspore/ccsrc/fl/server/kernel/round/start_fl_job_kernel.cc @@ -118,7 +118,7 @@ bool StartFLJobKernel::Reset() { } void StartFLJobKernel::OnFirstCountEvent(const std::shared_ptr &) { - iter_next_req_timestamp_ = CURRENT_TIME_MILLI.count() + iteration_time_window_; + iter_next_req_timestamp_ = LongToSize(CURRENT_TIME_MILLI.count()) + iteration_time_window_; LocalMetaStore::GetInstance().put_value(kCtxIterationNextRequestTimestamp, iter_next_req_timestamp_); // The first startFLJob request means a new iteration starts running. Iteration::GetInstance().SetIterationRunning(); @@ -220,9 +220,9 @@ void StartFLJobKernel::BuildStartFLJobRsp(const std::shared_ptr &fbb, schema::FLPlanBuilder fl_plan_builder(*(fbb.get())); fl_plan_builder.add_fl_name(fbs_fl_name); fl_plan_builder.add_server_mode(fbs_server_mode); - fl_plan_builder.add_iterations(ps::PSContext::instance()->fl_iteration_num()); - fl_plan_builder.add_epochs(ps::PSContext::instance()->client_epoch_num()); - fl_plan_builder.add_mini_batch(ps::PSContext::instance()->client_batch_size()); + fl_plan_builder.add_iterations(SizeToInt(ps::PSContext::instance()->fl_iteration_num())); + fl_plan_builder.add_epochs(SizeToInt(ps::PSContext::instance()->client_epoch_num())); + fl_plan_builder.add_mini_batch(SizeToInt(ps::PSContext::instance()->client_batch_size())); fl_plan_builder.add_lr(ps::PSContext::instance()->client_learning_rate()); #ifdef ENABLE_ARMOUR fl_plan_builder.add_cipher(cipher_public_params); diff --git a/mindspore/ccsrc/fl/server/kernel/round/update_model_kernel.cc b/mindspore/ccsrc/fl/server/kernel/round/update_model_kernel.cc index f5698d17471..58b58de8598 100644 --- a/mindspore/ccsrc/fl/server/kernel/round/update_model_kernel.cc +++ 
b/mindspore/ccsrc/fl/server/kernel/round/update_model_kernel.cc @@ -90,7 +90,7 @@ bool UpdateModelKernel::Reset() { return true; } -void UpdateModelKernel::OnLastCountEvent(const std::shared_ptr &message) { +void UpdateModelKernel::OnLastCountEvent(const std::shared_ptr &) { if (ps::PSContext::instance()->resetter_round() == ps::ResetterRound::kUpdateModel) { while (!executor_->IsAllWeightAggregationDone()) { std::this_thread::sleep_for(std::chrono::milliseconds(5)); @@ -120,7 +120,7 @@ ResultCode UpdateModelKernel::ReachThresholdForUpdateModel(const std::shared_ptr ResultCode UpdateModelKernel::UpdateModel(const schema::RequestUpdateModel *update_model_req, const std::shared_ptr &fbb) { RETURN_IF_NULL(update_model_req, ResultCode::kSuccessAndReturn); - size_t iteration = static_cast(update_model_req->iteration()); + size_t iteration = IntToSize(update_model_req->iteration()); if (iteration != LocalMetaStore::GetInstance().curr_iter_num()) { std::string reason = "UpdateModel iteration number is invalid:" + std::to_string(iteration) + ", current iteration:" + std::to_string(LocalMetaStore::GetInstance().curr_iter_num()) + diff --git a/mindspore/ccsrc/fl/server/parameter_aggregator.cc b/mindspore/ccsrc/fl/server/parameter_aggregator.cc index 72be46aaeee..d975ed538b1 100644 --- a/mindspore/ccsrc/fl/server/parameter_aggregator.cc +++ b/mindspore/ccsrc/fl/server/parameter_aggregator.cc @@ -281,16 +281,16 @@ bool ParameterAggregator::GenerateAggregationKernelParams(const std::shared_ptr< KernelParams aggr_params = {}; const std::vector &input_names = aggr_kernel->input_names(); - std::transform(input_names.begin(), input_names.end(), std::back_inserter(aggr_params.inputs), - [&](const std::string &name) { return memory_register->addresses()[name]; }); + (void)std::transform(input_names.begin(), input_names.end(), std::back_inserter(aggr_params.inputs), + [&](const std::string &name) { return memory_register->addresses()[name]; }); const std::vector &workspace_names = 
aggr_kernel->workspace_names(); - std::transform(workspace_names.begin(), workspace_names.end(), std::back_inserter(aggr_params.workspace), - [&](const std::string &name) { return memory_register->addresses()[name]; }); + (void)std::transform(workspace_names.begin(), workspace_names.end(), std::back_inserter(aggr_params.workspace), + [&](const std::string &name) { return memory_register->addresses()[name]; }); const std::vector &output_names = aggr_kernel->output_names(); - std::transform(output_names.begin(), output_names.end(), std::back_inserter(aggr_params.outputs), - [&](const std::string &name) { return memory_register->addresses()[name]; }); + (void)std::transform(output_names.begin(), output_names.end(), std::back_inserter(aggr_params.outputs), + [&](const std::string &name) { return memory_register->addresses()[name]; }); aggr_kernel->SetParameterAddress(aggr_params.inputs, aggr_params.workspace, aggr_params.outputs); aggregation_kernel_parameters_.push_back(std::make_pair(aggr_kernel, aggr_params)); @@ -304,16 +304,16 @@ bool ParameterAggregator::GenerateOptimizerKernelParams(const std::shared_ptr &input_names = optimizer_kernel->input_names(); - std::transform(input_names.begin(), input_names.end(), std::back_inserter(optimizer_params.inputs), - [&](const std::string &name) { return memory_register->addresses()[name]; }); + (void)std::transform(input_names.begin(), input_names.end(), std::back_inserter(optimizer_params.inputs), + [&](const std::string &name) { return memory_register->addresses()[name]; }); const std::vector &workspace_names = optimizer_kernel->workspace_names(); - std::transform(workspace_names.begin(), workspace_names.end(), std::back_inserter(optimizer_params.workspace), - [&](const std::string &name) { return memory_register->addresses()[name]; }); + (void)std::transform(workspace_names.begin(), workspace_names.end(), std::back_inserter(optimizer_params.workspace), + [&](const std::string &name) { return 
memory_register->addresses()[name]; }); const std::vector &output_names = optimizer_kernel->output_names(); - std::transform(output_names.begin(), output_names.end(), std::back_inserter(optimizer_params.outputs), - [&](const std::string &name) { return memory_register->addresses()[name]; }); + (void)std::transform(output_names.begin(), output_names.end(), std::back_inserter(optimizer_params.outputs), + [&](const std::string &name) { return memory_register->addresses()[name]; }); optimizer_kernel_parameters_.push_back(std::make_pair(optimizer_kernel, optimizer_params)); return true; diff --git a/mindspore/ccsrc/fl/server/round.cc b/mindspore/ccsrc/fl/server/round.cc index 4d25693fd25..df5ae32dc52 100644 --- a/mindspore/ccsrc/fl/server/round.cc +++ b/mindspore/ccsrc/fl/server/round.cc @@ -34,8 +34,8 @@ Round::Round(const std::string &name, bool check_timeout, size_t time_window, bo threshold_count_(threshold_count), server_num_as_threshold_(server_num_as_threshold) {} -void Round::Initialize(const std::shared_ptr &communicator, TimeOutCb timeout_cb, - FinishIterCb finish_iteration_cb) { +void Round::Initialize(const std::shared_ptr &communicator, const TimeOutCb &timeout_cb, + const FinishIterCb &finish_iteration_cb) { MS_EXCEPTION_IF_NULL(communicator); communicator_ = communicator; @@ -50,7 +50,7 @@ void Round::Initialize(const std::shared_ptr &commun }; // Callback for finalizing the server. This can only be called once. 
- finalize_cb_ = [&](void) -> void { communicator_->Stop(); }; + finalize_cb_ = [&](void) -> void { (void)communicator_->Stop(); }; if (check_timeout_) { iter_timer_ = std::make_shared(); @@ -116,7 +116,7 @@ void Round::LaunchRoundKernel(const std::shared_ptr &m if (Server::GetInstance().IsSafeMode()) { MS_LOG(WARNING) << "The cluster is still in process of scaling, please retry " << name_ << " later."; std::string reason = "The cluster is in safemode."; - communicator_->SendResponse(reason.c_str(), reason.size(), message); + (void)communicator_->SendResponse(reason.c_str(), reason.size(), message); return; } @@ -128,10 +128,10 @@ void Round::LaunchRoundKernel(const std::shared_ptr &m if (output->size == 0) { std::string reason = "The output of the round " + name_ + " is empty."; MS_LOG(WARNING) << reason; - communicator_->SendResponse(reason.c_str(), reason.size(), message); + (void)communicator_->SendResponse(reason.c_str(), reason.size(), message); return; } - communicator_->SendResponse(output->addr, output->size, message); + (void)communicator_->SendResponse(output->addr, output->size, message); kernel_->Release(output); // Must send response back no matter what value Launch method returns. 
@@ -142,7 +142,7 @@ void Round::LaunchRoundKernel(const std::shared_ptr &m return; } -void Round::Reset() { kernel_->Reset(); } +void Round::Reset() { (void)kernel_->Reset(); } const std::string &Round::name() const { return name_; } diff --git a/mindspore/ccsrc/fl/server/round.h b/mindspore/ccsrc/fl/server/round.h index e2ba6586904..1aae7b560d7 100644 --- a/mindspore/ccsrc/fl/server/round.h +++ b/mindspore/ccsrc/fl/server/round.h @@ -37,8 +37,8 @@ class Round { bool check_count = false, size_t threshold_count = 8, bool server_num_as_threshold = false); ~Round() = default; - void Initialize(const std::shared_ptr &communicator, TimeOutCb timeout_cb, - FinishIterCb finish_iteration_cb); + void Initialize(const std::shared_ptr &communicator, const TimeOutCb &timeout_cb, + const FinishIterCb &finish_iteration_cb); // Reinitialize count service and round kernel of this round after scaling operations are done. bool ReInitForScaling(uint32_t server_num); diff --git a/mindspore/ccsrc/fl/server/server.cc b/mindspore/ccsrc/fl/server/server.cc index b7009501373..09758a3509a 100644 --- a/mindspore/ccsrc/fl/server/server.cc +++ b/mindspore/ccsrc/fl/server/server.cc @@ -102,7 +102,7 @@ void Server::CancelSafeMode() { safemode_ = false; } -bool Server::IsSafeMode() { return safemode_.load(); } +bool Server::IsSafeMode() const { return safemode_.load(); } void Server::InitServerContext() { ps::PSContext::instance()->GenerateResetterRound(); @@ -121,7 +121,7 @@ void Server::InitServerContext() { void Server::InitCluster() { server_node_ = std::make_shared(); MS_EXCEPTION_IF_NULL(server_node_); - task_executor_ = std::make_shared(32); + task_executor_ = std::make_shared(kExecutorThreadPoolSize); MS_EXCEPTION_IF_NULL(task_executor_); if (!InitCommunicatorWithServer()) { MS_LOG(EXCEPTION) << "Initializing cross-server communicator failed."; @@ -235,9 +235,9 @@ void Server::InitCipher() { #ifdef ENABLE_ARMOUR cipher_init_ = &armour::CipherInit::GetInstance(); - int cipher_t = 
cipher_reconstruct_secrets_down_cnt_; + int cipher_t = SizeToInt(cipher_reconstruct_secrets_down_cnt_); unsigned char cipher_p[SECRET_MAX_LEN] = {0}; - int cipher_g = 1; + const int cipher_g = 1; unsigned char cipher_prime[PRIME_MAX_LEN] = {0}; float dp_eps = ps::PSContext::instance()->dp_eps(); float dp_delta = ps::PSContext::instance()->dp_delta(); @@ -304,8 +304,8 @@ void Server::RegisterExceptionEventCallback(const std::shared_ptr &communicator) { communicator->Stop(); }); - communicator_with_server_->Stop(); + [](const std::shared_ptr &communicator) { (void)communicator->Stop(); }); + (void)communicator_with_server_->Stop(); }); communicator->RegisterEventCallback(ps::core::ClusterEvent::NODE_TIMEOUT, [&]() { @@ -314,8 +314,8 @@ void Server::RegisterExceptionEventCallback(const std::shared_ptr &communicator) { communicator->Stop(); }); - communicator_with_server_->Stop(); + [](const std::shared_ptr &communicator) { (void)communicator->Stop(); }); + (void)communicator_with_server_->Stop(); }); } @@ -363,7 +363,10 @@ void Server::StartCommunicator() { } MS_LOG(INFO) << "Start communicator with server."; - communicator_with_server_->Start(); + if (!communicator_with_server_->Start()) { + MS_LOG(EXCEPTION) << "Starting communicator with server failed."; + return; + } DistributedMetadataStore::GetInstance().Initialize(server_node_); CollectiveOpsImpl::GetInstance().Initialize(server_node_); DistributedCountService::GetInstance().Initialize(server_node_, kLeaderServerRank); @@ -371,7 +374,11 @@ void Server::StartCommunicator() { MS_LOG(INFO) << "Start communicator with worker."; std::for_each(communicators_with_worker_.begin(), communicators_with_worker_.end(), - [](const std::shared_ptr &communicator) { communicator->Start(); }); + [](const std::shared_ptr &communicator) { + if (!communicator->Start()) { + MS_LOG(EXCEPTION) << "Starting communicator with worker failed."; + } + }); } void Server::ProcessBeforeScalingOut() { @@ -405,7 +412,7 @@ void 
Server::ProcessAfterScalingOut() { if (!Executor::GetInstance().ReInitForScaling()) { MS_LOG(WARNING) << "Executor reinitializing failed."; } - std::this_thread::sleep_for(std::chrono::milliseconds(1000)); + std::this_thread::sleep_for(std::chrono::milliseconds(kServerSleepTimeForNetworking)); safemode_ = false; } @@ -418,7 +425,7 @@ void Server::ProcessAfterScalingIn() { if (server_node_->rank_id() == UINT32_MAX) { MS_LOG(WARNING) << "This server the one to be scaled in. Server exiting."; std::for_each(communicators_with_worker_.begin(), communicators_with_worker_.end(), - [](const std::shared_ptr &communicator) { communicator->Stop(); }); + [](const std::shared_ptr &communicator) { (void)communicator->Stop(); }); communicator_with_server_->Stop(); return; } @@ -439,7 +446,7 @@ void Server::ProcessAfterScalingIn() { if (!Executor::GetInstance().ReInitForScaling()) { MS_LOG(WARNING) << "Executor reinitializing failed."; } - std::this_thread::sleep_for(std::chrono::milliseconds(1000)); + std::this_thread::sleep_for(std::chrono::milliseconds(kServerSleepTimeForNetworking)); safemode_ = false; } } // namespace server diff --git a/mindspore/ccsrc/fl/server/server.h b/mindspore/ccsrc/fl/server/server.h index 30c2491f5e9..8566d4f6f2d 100644 --- a/mindspore/ccsrc/fl/server/server.h +++ b/mindspore/ccsrc/fl/server/server.h @@ -33,6 +33,9 @@ namespace mindspore { namespace fl { namespace server { +// The sleeping time of the server thread before the networking is completed. +constexpr uint32_t kServerSleepTimeForNetworking = 1000; + // Class Server is the entrance of MindSpore's parameter server training mode and federated learning. 
class Server { public: @@ -51,7 +54,7 @@ class Server { void SwitchToSafeMode(); void CancelSafeMode(); - bool IsSafeMode(); + bool IsSafeMode() const; private: Server() @@ -162,8 +165,6 @@ class Server { uint32_t server_num_; uint32_t worker_num_; uint16_t fl_server_port_; - size_t start_fl_job_cnt_; - size_t update_model_cnt_; size_t cipher_initial_client_cnt_; size_t cipher_exchange_secrets_cnt_; size_t cipher_share_secrets_cnt_; @@ -171,9 +172,6 @@ class Server { size_t cipher_reconstruct_secrets_up_cnt_; size_t cipher_reconstruct_secrets_down_cnt_; uint64_t cipher_time_window_; - - float percent_for_update_model_; - float percent_for_get_model_; }; } // namespace server } // namespace fl diff --git a/mindspore/ccsrc/fl/worker/fl_worker.cc b/mindspore/ccsrc/fl/worker/fl_worker.cc index a004f53facf..2d2d0001ea7 100644 --- a/mindspore/ccsrc/fl/worker/fl_worker.cc +++ b/mindspore/ccsrc/fl/worker/fl_worker.cc @@ -70,8 +70,14 @@ void FLWorker::Run() { void FLWorker::Finalize() { MS_EXCEPTION_IF_NULL(worker_node_); - worker_node_->Finish(); - worker_node_->Stop(); + if (!worker_node_->Finish()) { + MS_LOG(ERROR) << "Worker node finishing failed."; + return; + } + if (!worker_node_->Stop()) { + MS_LOG(ERROR) << "Worker node stopping failed."; + return; + } } bool FLWorker::SendToServer(uint32_t server_rank, const void *data, size_t size, ps::core::TcpUserCommand command, @@ -201,8 +207,8 @@ void FLWorker::ProcessAfterScalingOut() { } MS_LOG(INFO) << "Cluster scaling out completed. Reinitialize for worker."; - server_num_ = worker_node_->server_num(); - worker_num_ = worker_node_->worker_num(); + server_num_ = IntToUint(worker_node_->server_num()); + worker_num_ = IntToUint(worker_node_->worker_num()); MS_LOG(INFO) << "After scheduler scaling out, worker number is " << worker_num_ << ", server number is " << server_num_ << ". 
Exit safemode."; std::this_thread::sleep_for(std::chrono::milliseconds(kWorkerSleepTimeForNetworking)); @@ -215,8 +221,8 @@ void FLWorker::ProcessAfterScalingIn() { } MS_LOG(INFO) << "Cluster scaling in completed. Reinitialize for worker."; - server_num_ = worker_node_->server_num(); - worker_num_ = worker_node_->worker_num(); + server_num_ = IntToUint(worker_node_->server_num()); + worker_num_ = IntToUint(worker_node_->worker_num()); MS_LOG(INFO) << "After scheduler scaling in, worker number is " << worker_num_ << ", server number is " << server_num_ << ". Exit safemode."; std::this_thread::sleep_for(std::chrono::milliseconds(kWorkerSleepTimeForNetworking)); diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_p_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_p_info.cc index 66f1f86808d..d3b59db92b7 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_p_info.cc +++ b/mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_p_info.cc @@ -25,7 +25,7 @@ #include "frontend/parallel/device_matrix.h" #include "frontend/parallel/graph_util/generate_graph.h" #include "frontend/parallel/context.h" -#if (ENABLE_CPU && !_WIN32) +#if ((defined ENABLE_CPU) && (!defined _WIN32)) #include "ps/ps_cache/ps_cache_manager.h" #include "utils/ms_context.h" #endif @@ -160,7 +160,7 @@ Status GatherPInfo::GetAttrs() { if (std::find(inputs_shape_[1].begin(), inputs_shape_[1].end(), -1) != inputs_shape_[1].end()) { dynamic_shape_indices_ = true; } -#if (ENABLE_CPU && !_WIN32) +#if ((defined ENABLE_CPU) && (!defined _WIN32)) MS_EXCEPTION_IF_NULL(MsContext::GetInstance()); bool enable_sparse = MsContext::GetInstance()->get_param(MS_CTX_ENABLE_SPARSE); if (ps::PsDataPrefetch::GetInstance().cache_enable() && enable_sparse) { @@ -637,7 +637,7 @@ Status GatherPInfo::InferBias() { rank = rank % (params_strategy[0] * params_strategy[1]); } } -#if (ENABLE_CPU && !_WIN32) +#if ((defined ENABLE_CPU) && (!defined _WIN32)) if 
(ps::PsDataPrefetch::GetInstance().cache_enable()) { bias_ = static_cast(ps::PsCacheManager::GetInstance().cache_indices_lower_bound()); return SUCCESS; diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/unique_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/unique_info.cc index 4e7099d943d..5ffaa46ca93 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/unique_info.cc +++ b/mindspore/ccsrc/frontend/parallel/ops_info/unique_info.cc @@ -28,7 +28,7 @@ #include "frontend/parallel/strategy.h" #include "frontend/parallel/context.h" #include "frontend/parallel/tensor_layout/tensor_redistribution.h" -#if (ENABLE_CPU && !_WIN32) +#if ((defined ENABLE_CPU) && (!defined _WIN32)) #include "ps/ps_cache/ps_cache_manager.h" #endif @@ -119,7 +119,7 @@ std::vector UniqueInfo::GenerateOpStrategies(int64_t stage_id) { return sp_vector; } -#if (ENABLE_CPU && !_WIN32) +#if ((defined ENABLE_CPU) && (!defined _WIN32)) Status UniqueInfo::ComputeReplaceGraph(const CNodePtr &cnode) { GenerateGraph gen_g = GenerateGraph(attrs_); if (gen_g.Init(cnode) != SUCCESS) { @@ -156,7 +156,7 @@ Status UniqueInfo::ComputeReplaceGraph(const CNodePtr &cnode) { #endif ReplaceGraphPtr UniqueInfo::replace_graph(const CNodePtr &cnode) { -#if (ENABLE_CPU && !_WIN32) +#if ((defined ENABLE_CPU) && (!defined _WIN32)) if (ps::PsDataPrefetch::GetInstance().cache_enable()) { auto inputs = cnode->inputs(); if (inputs.empty()) { diff --git a/mindspore/ccsrc/frontend/parallel/step_parallel.cc b/mindspore/ccsrc/frontend/parallel/step_parallel.cc index 4e285585da5..a59e71cd1cc 100644 --- a/mindspore/ccsrc/frontend/parallel/step_parallel.cc +++ b/mindspore/ccsrc/frontend/parallel/step_parallel.cc @@ -47,7 +47,7 @@ #include "utils/ms_context.h" #include "utils/symbolic.h" #include "mindspore/core/utils/parallel_node_check.h" -#if (ENABLE_CPU && !_WIN32) +#if ((defined ENABLE_CPU) && (!defined _WIN32)) #include "ps/util.h" #include "ps/ps_context.h" #endif diff --git a/mindspore/ccsrc/pipeline/jit/action.cc 
b/mindspore/ccsrc/pipeline/jit/action.cc index 7405a720ad1..1035e187ef7 100644 --- a/mindspore/ccsrc/pipeline/jit/action.cc +++ b/mindspore/ccsrc/pipeline/jit/action.cc @@ -44,7 +44,7 @@ #include "vm/transform.h" #include "parse/python_adapter.h" #include "frontend/optimizer/py_pass_manager.h" -#if (ENABLE_CPU && !_WIN32) +#if ((defined ENABLE_CPU) && (!defined _WIN32)) #include "ps/parameter_server.h" #include "ps/scheduler.h" #include "ps/worker.h" @@ -478,7 +478,7 @@ bool OptInlineAction(const ResourcePtr &res) { bool GeOptimizeAction(const ResourcePtr &res) { return OptimizeAction(res, kGePasses); } bool VmOptimizeAction(const ResourcePtr &res) { -#if (ENABLE_CPU && !_WIN32) +#if ((defined ENABLE_CPU) && (!defined _WIN32)) if (ps::PSContext::instance()->is_ps_mode()) { kVmPasses.push_back({"server_communication_op_fusion", ps::Util::FuseServerCommOps}); } @@ -633,8 +633,8 @@ bool ExecuteAction(const ResourcePtr &res) { return true; } -#if (ENABLE_CPU && !_WIN32) -bool StartPSWorkerAction(const ResourcePtr &res) { +#if ((defined ENABLE_CPU) && (!defined _WIN32)) +bool StartPSWorkerAction(const ResourcePtr &) { ps::Worker::GetInstance().Run(); return true; } @@ -695,7 +695,7 @@ bool StartServerAction(const ResourcePtr &res) { return true; } -bool StartPSSchedulerAction(const ResourcePtr &res) { +bool StartPSSchedulerAction(const ResourcePtr &) { ps::Scheduler::GetInstance().Run(); return true; } @@ -861,7 +861,7 @@ std::vector VmPipeline() { actions.emplace_back(std::make_pair("remove_monad_from_random_op", RemoveRandomOpMonadAction)); actions.emplace_back(std::make_pair("validate", ValidateAction)); -#if (ENABLE_CPU && !_WIN32) +#if ((defined ENABLE_CPU) && (!defined _WIN32)) if (ps::PSContext::instance()->is_worker()) { std::string server_mode = ps::PSContext::instance()->server_mode(); if (server_mode == ps::kServerModeFL || server_mode == ps::kServerModeHybrid) { @@ -889,7 +889,7 @@ std::vector BackendPipeline() { return actions; } -#if (ENABLE_CPU && 
!_WIN32) +#if ((defined ENABLE_CPU) && (!defined _WIN32)) std::vector ServerPipeline() { auto actions = CommonPipeline(); actions.emplace_back(std::make_pair("optimize", VmOptimizeAction)); diff --git a/mindspore/ccsrc/pipeline/jit/init.cc b/mindspore/ccsrc/pipeline/jit/init.cc index ea7e936b681..3f940664876 100644 --- a/mindspore/ccsrc/pipeline/jit/init.cc +++ b/mindspore/ccsrc/pipeline/jit/init.cc @@ -34,7 +34,7 @@ #else #include "runtime/device/gpu/distribution/collective_fake_init.h" #endif -#if (ENABLE_CPU && !_WIN32) +#if ((defined ENABLE_CPU) && (!defined _WIN32)) #include "ps/util.h" #endif #include "ps/ps_context.h" diff --git a/mindspore/ccsrc/pipeline/jit/pass.cc b/mindspore/ccsrc/pipeline/jit/pass.cc index 22484bd5ccb..f4e325b91b6 100644 --- a/mindspore/ccsrc/pipeline/jit/pass.cc +++ b/mindspore/ccsrc/pipeline/jit/pass.cc @@ -47,7 +47,7 @@ #include "frontend/optimizer/irpass/gradient_eliminate.h" #include "frontend/optimizer/irpass/parameter_eliminate.h" #include "frontend/optimizer/irpass/updatestate_eliminate.h" -#if (ENABLE_CPU && !_WIN32) +#if ((defined ENABLE_CPU) && (!defined _WIN32)) #include "ps/util.h" #include "ps/ps_context.h" #endif @@ -211,7 +211,7 @@ namespace { bool ReAutoMonadWrapper(const FuncGraphPtr &root, const opt::OptimizerPtr &) { return ReAutoMonad(root); } bool parallel_mode() { -#if (ENABLE_CPU && !_WIN32) +#if ((defined ENABLE_CPU) && (!defined _WIN32)) if (ps::PSContext::instance()->is_server() || ps::PSContext::instance()->is_scheduler()) { return false; } @@ -556,7 +556,7 @@ bool AddRecomputationPass(const ResourcePtr &res) { } bool AddCacheEmbeddingPass(const ResourcePtr &res) { -#if (ENABLE_CPU && !_WIN32) +#if ((defined ENABLE_CPU) && (!defined _WIN32)) if (ps::PSContext::instance()->is_ps_mode()) { return true; } diff --git a/mindspore/ccsrc/ps/core/server_node.cc b/mindspore/ccsrc/ps/core/server_node.cc index 9649b0f33bd..290aaa82dcc 100644 --- a/mindspore/ccsrc/ps/core/server_node.cc +++ 
b/mindspore/ccsrc/ps/core/server_node.cc @@ -148,7 +148,7 @@ std::shared_ptr ServerNode::GetOrCreateHttpComm(const std::str } std::shared_ptr ServerNode::GetOrCreateTcpComm(const std::string &scheduler_ip, - std::int16_t scheduler_port, uint32_t worker_num, + uint16_t scheduler_port, uint32_t worker_num, uint32_t server_num, const std::shared_ptr &task_executor) { std::lock_guard lock(communicator_mutex_); diff --git a/mindspore/ccsrc/ps/core/server_node.h b/mindspore/ccsrc/ps/core/server_node.h index 6abd5a2dcd0..812bf5d377b 100644 --- a/mindspore/ccsrc/ps/core/server_node.h +++ b/mindspore/ccsrc/ps/core/server_node.h @@ -61,7 +61,7 @@ class ServerNode : public AbstractNode { std::shared_ptr GetOrCreateHttpComm(const std::string &ip, uint16_t port, const std::shared_ptr &task_executor); - std::shared_ptr GetOrCreateTcpComm(const std::string &scheduler_ip, std::int16_t scheduler_port, + std::shared_ptr GetOrCreateTcpComm(const std::string &scheduler_ip, uint16_t scheduler_port, uint32_t worker_num, uint32_t server_num, const std::shared_ptr &task_executor); diff --git a/mindspore/ccsrc/ps/ps_context.cc b/mindspore/ccsrc/ps/ps_context.cc index b6dfd69789d..b1dce5f3bcd 100644 --- a/mindspore/ccsrc/ps/ps_context.cc +++ b/mindspore/ccsrc/ps/ps_context.cc @@ -18,7 +18,7 @@ #include "utils/log_adapter.h" #include "utils/ms_utils.h" #include "backend/kernel_compiler/kernel.h" -#if (ENABLE_CPU && !_WIN32) +#if ((defined ENABLE_CPU) && (!defined _WIN32)) #include "ps/ps_cache/ps_cache_manager.h" #include "ps/ps_cache/ps_data/ps_data_prefetch.h" #endif @@ -63,7 +63,7 @@ void PSContext::SetPSEnable(bool enabled) { } bool PSContext::is_ps_mode() const { - if (server_mode_ == kServerModeFL || server_mode_ == kServerModeHybrid) { + if ((server_mode_ == kServerModeFL || server_mode_ == kServerModeHybrid) && ps_enabled_) { return true; } return ps_enabled_; @@ -74,7 +74,7 @@ void PSContext::Reset() { is_worker_ = false; is_pserver_ = false; is_sched_ = false; -#if (ENABLE_CPU && 
!_WIN32) +#if ((defined ENABLE_CPU) && (!defined _WIN32)) if (ps::PsDataPrefetch::GetInstance().cache_enable()) { ps_cache_instance.Finalize(); set_cache_enable(false); @@ -83,7 +83,7 @@ void PSContext::Reset() { } std::string PSContext::ms_role() const { - if (server_mode_ == kServerModeFL || server_mode_ == kServerModeHybrid) { + if ((server_mode_ == kServerModeFL || server_mode_ == kServerModeHybrid) && ps_enabled_) { return role_; } if (is_worker_) { @@ -98,21 +98,21 @@ std::string PSContext::ms_role() const { } bool PSContext::is_worker() const { - if (server_mode_ == kServerModeFL || server_mode_ == kServerModeHybrid) { + if ((server_mode_ == kServerModeFL || server_mode_ == kServerModeHybrid) && ps_enabled_) { return role_ == kEnvRoleOfWorker; } return is_worker_; } bool PSContext::is_server() const { - if (server_mode_ == kServerModeFL || server_mode_ == kServerModeHybrid) { + if ((server_mode_ == kServerModeFL || server_mode_ == kServerModeHybrid) && ps_enabled_) { return role_ == kEnvRoleOfServer; } return is_pserver_; } bool PSContext::is_scheduler() const { - if (server_mode_ == kServerModeFL || server_mode_ == kServerModeHybrid) { + if ((server_mode_ == kServerModeFL || server_mode_ == kServerModeHybrid) && ps_enabled_) { return role_ == kEnvRoleOfScheduler; } return is_sched_; @@ -130,44 +130,44 @@ uint32_t PSContext::ps_rank_id() const { return rank_id_; } void PSContext::InsertHashTableSize(const std::string &param_name, size_t cache_vocab_size, size_t embedding_size, size_t vocab_size) const { -#if (ENABLE_CPU && !_WIN32) +#if ((defined ENABLE_CPU) && (!defined _WIN32)) ps_cache_instance.InsertHashTableSize(param_name, cache_vocab_size, embedding_size, vocab_size); #endif } void PSContext::ReInsertHashTableSize(const std::string &new_param_name, const std::string &cur_param_name, size_t cache_vocab_size, size_t embedding_size) const { -#if (ENABLE_CPU && !_WIN32) +#if ((defined ENABLE_CPU) && (!defined _WIN32))
ps_cache_instance.ReInsertHashTableSize(new_param_name, cur_param_name, cache_vocab_size, embedding_size); #endif } void PSContext::InsertWeightInitInfo(const std::string &param_name, size_t global_seed, size_t op_seed) const { -#if (ENABLE_CPU && !_WIN32) +#if ((defined ENABLE_CPU) && (!defined _WIN32)) ps_cache_instance.InsertWeightInitInfo(param_name, global_seed, op_seed); #endif } void PSContext::InsertAccumuInitInfo(const std::string &param_name, float init_val) const { -#if (ENABLE_CPU && !_WIN32) +#if ((defined ENABLE_CPU) && (!defined _WIN32)) ps_cache_instance.InsertAccumuInitInfo(param_name, init_val); #endif } void PSContext::CloneHashTable(const std::string &dest_param_name, const std::string &src_param_name) const { -#if (ENABLE_CPU && !_WIN32) +#if ((defined ENABLE_CPU) && (!defined _WIN32)) ps_cache_instance.CloneHashTable(dest_param_name, src_param_name); #endif } void PSContext::set_cache_enable(bool cache_enable) const { -#if (ENABLE_CPU && !_WIN32) +#if ((defined ENABLE_CPU) && (!defined _WIN32)) PsDataPrefetch::GetInstance().set_cache_enable(cache_enable); #endif } void PSContext::set_rank_id(uint32_t rank_id) const { -#if (ENABLE_CPU && !_WIN32) +#if ((defined ENABLE_CPU) && (!defined _WIN32)) ps_cache_instance.set_rank_id(rank_id); #endif } @@ -358,7 +358,7 @@ void PSContext::set_cipher_time_window(uint64_t cipher_time_window) { uint64_t PSContext::cipher_time_window() const { return cipher_time_window_; } void PSContext::set_reconstruct_secrets_threshold(uint64_t reconstruct_secrets_threshold) { - if (reconstruct_secrets_threshold <= 0) { + if (reconstruct_secrets_threshold == 0) { MS_LOG(EXCEPTION) << "reconstruct_secrets_threshold should be positive."; return; } diff --git a/mindspore/ccsrc/ps/util.cc b/mindspore/ccsrc/ps/util.cc index 1fd67c4a298..d6716f230be 100644 --- a/mindspore/ccsrc/ps/util.cc +++ b/mindspore/ccsrc/ps/util.cc @@ -136,7 +136,8 @@ bool Util::FuseServerCommOps(const pipeline::ResourcePtr &res) { return true; } -void
Util::DoFusion(FuncGraphPtr func_graph, const std::string &cnode_name, const std::string &fused_cnode_name) { +void Util::DoFusion(const FuncGraphPtr &func_graph, const std::string &cnode_name, + const std::string &fused_cnode_name) { MS_EXCEPTION_IF_NULL(func_graph); std::vector node_list = TopoSort(func_graph->get_return()); diff --git a/mindspore/ccsrc/ps/util.h b/mindspore/ccsrc/ps/util.h index 11d11cf233f..7875c075656 100644 --- a/mindspore/ccsrc/ps/util.h +++ b/mindspore/ccsrc/ps/util.h @@ -56,7 +56,8 @@ class Util { static bool FuseServerCommOps(const pipeline::ResourcePtr &res); private: - static void DoFusion(FuncGraphPtr func_graph, const std::string &cnode_name, const std::string &fused_cnode_name); + static void DoFusion(const FuncGraphPtr &func_graph, const std::string &cnode_name, + const std::string &fused_cnode_name); static kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(const std::vector &node_list); static std::unordered_map optimizer_to_ids; diff --git a/mindspore/ccsrc/runtime/device/kernel_runtime.cc b/mindspore/ccsrc/runtime/device/kernel_runtime.cc index 4a60590277b..53d58ecf080 100644 --- a/mindspore/ccsrc/runtime/device/kernel_runtime.cc +++ b/mindspore/ccsrc/runtime/device/kernel_runtime.cc @@ -32,7 +32,7 @@ #include "utils/utils.h" #include "frontend/parallel/context.h" #include "debug/env_config_parser.h" -#if (ENABLE_CPU && !_WIN32) +#if ((defined ENABLE_CPU) && (!defined _WIN32)) #include "ps/ps_cache/ps_cache_manager.h" #endif @@ -333,7 +333,7 @@ void KernelRuntime::AssignStaticMemoryInput(const session::KernelGraph *graph) { } add_need_alloc_nodes(input_node); } -#if (ENABLE_CPU && !_WIN32) +#if ((defined ENABLE_CPU) && (!defined _WIN32)) bool ps_cache_check = false; #endif for (auto &item : need_alloc_nodes) { @@ -346,7 +346,7 @@ void KernelRuntime::AssignStaticMemoryInput(const session::KernelGraph *graph) { continue; } DeviceAddressPtr device_address = nullptr; -#if (ENABLE_CPU && !_WIN32) +#if ((defined ENABLE_CPU) && 
(!defined _WIN32)) const std::string &param_name = item->fullname_with_scope(); if (ps::ps_cache_instance.IsHashTable(param_name)) { MS_LOG(INFO) << "Parameter(" << param_name << ")" @@ -1087,7 +1087,7 @@ void KernelRuntime::ClearOutputAddress(const std::vector &inputs, } } -#if (ENABLE_CPU && !_WIN32) +#if ((defined ENABLE_CPU) && (!defined _WIN32)) void KernelRuntime::GetFirstPSEmbeddingCache(const session::KernelGraph *graph, AnfNodePtr *const first_cache_input_index, size_t *const first_cache_size) { diff --git a/mindspore/context.py b/mindspore/context.py index 3af9b5f90a7..24b3ee5b475 100644 --- a/mindspore/context.py +++ b/mindspore/context.py @@ -817,6 +817,7 @@ def reset_ps_context(): """ _reset_ps_context() + def set_fl_context(**kwargs): """ Set federated learning training mode context. diff --git a/mindspore/ops/operations/other_ops.py b/mindspore/ops/operations/other_ops.py index 2e4fffcd486..c5ee2d600e0 100644 --- a/mindspore/ops/operations/other_ops.py +++ b/mindspore/ops/operations/other_ops.py @@ -726,6 +726,7 @@ class Pull(PrimitiveWithInfer): def infer_dtype(self, key_dtype, weight_dtype): return mstype.float32 + class PullWeight(PrimitiveWithInfer): """ Pull weight by its names from server. @@ -751,6 +752,7 @@ class PullWeight(PrimitiveWithInfer): def infer_dtype(self, weight, name, index): return mstype.float32 + class PushWeight(PrimitiveWithInfer): """ Upload weight by its names to server. @@ -776,6 +778,7 @@ class PushWeight(PrimitiveWithInfer): def infer_dtype(self, weight, ps_key, index): return mstype.float32 + class identity(Primitive): """ Makes a identify primitive, used for pynative mode.
diff --git a/mindspore/parallel/_ps_context.py b/mindspore/parallel/_ps_context.py index c0ba1690b28..c4b5c3cc943 100644 --- a/mindspore/parallel/_ps_context.py +++ b/mindspore/parallel/_ps_context.py @@ -31,6 +31,7 @@ _check_positive_float_keys = ["update_model_ratio", "client_learning_rate"] _check_port_keys = ["scheduler_port", "fl_server_port", "scheduler_manage_port"] + def ps_context(): """ Get the global _ps_context, if it is not created, create a new one. @@ -226,6 +227,7 @@ def _set_cache_enable(cache_enable): def _set_rank_id(rank_id): ps_context().set_rank_id(rank_id) + def _check_value(key, value): """ Validate the value for parameter server context keys.