From 5d40613d52c70c15b0acfae849123fe95cf75352 Mon Sep 17 00:00:00 2001 From: twc Date: Thu, 13 Jan 2022 10:12:35 +0800 Subject: [PATCH] fix ISSUE I4QCJM --- mindspore/ccsrc/fl/server/common.h | 16 ++++ mindspore/ccsrc/fl/server/iteration.cc | 96 +++++++++++++++++-- mindspore/ccsrc/fl/server/iteration.h | 43 +++++++-- .../ccsrc/fl/server/iteration_metrics.cc | 17 ++-- mindspore/ccsrc/fl/server/iteration_metrics.h | 26 ++--- .../server/kernel/round/get_model_kernel.cc | 4 +- .../fl/server/kernel/round/round_kernel.cc | 24 +++++ .../fl/server/kernel/round/round_kernel.h | 17 ++++ .../kernel/round/start_fl_job_kernel.cc | 4 +- .../kernel/round/update_model_kernel.cc | 3 +- mindspore/ccsrc/fl/server/round.cc | 17 +++- mindspore/ccsrc/fl/server/round.h | 13 +++ mindspore/ccsrc/fl/server/server.cc | 2 + mindspore/ccsrc/pipeline/jit/init.cc | 6 +- mindspore/ccsrc/ps/core/protos/fl.proto | 9 ++ mindspore/ccsrc/ps/ps_context.cc | 6 ++ mindspore/ccsrc/ps/ps_context.h | 13 ++- mindspore/python/mindspore/context.py | 5 +- .../python/mindspore/parallel/_ps_context.py | 6 +- scripts/fl_restful_tool.py | 9 +- 20 files changed, 276 insertions(+), 60 deletions(-) diff --git a/mindspore/ccsrc/fl/server/common.h b/mindspore/ccsrc/fl/server/common.h index 70901add4e0..8590c00d08a 100644 --- a/mindspore/ccsrc/fl/server/common.h +++ b/mindspore/ccsrc/fl/server/common.h @@ -86,6 +86,13 @@ enum class InstanceState { kFinish }; +enum class IterationResult { + // The iteration is timeout because of startfljob or updatemodel timeout. + kTimeout, + // The iteration is successful aggregation. 
+ kSuccess +}; + using mindspore::kernel::Address; using mindspore::kernel::AddressPtr; using mindspore::kernel::CPUKernel; @@ -133,6 +140,15 @@ constexpr auto kFtrlLinear = "linear"; constexpr auto kDataSize = "data_size"; constexpr auto kNewDataSize = "new_data_size"; constexpr auto kStat = "stat"; +constexpr auto kStartFLJobTotalClientNum = "startFLJobTotalClientNum"; +constexpr auto kStartFLJobAcceptClientNum = "startFLJobAcceptClientNum"; +constexpr auto kStartFLJobRejectClientNum = "startFLJobRejectClientNum"; +constexpr auto kUpdateModelTotalClientNum = "updateModelTotalClientNum"; +constexpr auto kUpdateModelAcceptClientNum = "updateModelAcceptClientNum"; +constexpr auto kUpdateModelRejectClientNum = "updateModelRejectClientNum"; +constexpr auto kGetModelTotalClientNum = "getModelTotalClientNum"; +constexpr auto kGetModelAcceptClientNum = "getModelAcceptClientNum"; +constexpr auto kGetModelRejectClientNum = "getModelRejectClientNum"; // OptimParamNameToIndex represents every inputs/workspace/outputs parameter's offset when an optimizer kernel is // launched. diff --git a/mindspore/ccsrc/fl/server/iteration.cc b/mindspore/ccsrc/fl/server/iteration.cc index 3a81d9dc4eb..5d0032de0ab 100644 --- a/mindspore/ccsrc/fl/server/iteration.cc +++ b/mindspore/ccsrc/fl/server/iteration.cc @@ -162,6 +162,8 @@ void Iteration::SetIterationRunning() { std::unique_lock lock(iteration_state_mtx_); iteration_state_ = IterationState::kRunning; start_timestamp_ = LongToUlong(CURRENT_TIME_MILLI.count()); + MS_LOG(INFO) << "Iteration " << iteration_num_ << " start global timer."; + global_iter_timer_->Start(std::chrono::milliseconds(global_iteration_time_window_)); } void Iteration::SetIterationEnd() { @@ -577,6 +579,7 @@ void Iteration::Next(bool is_iteration_valid, const std::string &reason) { // Store the model which is successfully aggregated for this iteration. 
const auto &model = Executor::GetInstance().GetModel(); ModelStore::GetInstance().StoreModelByIterNum(iteration_num_, model); + iteration_result_ = IterationResult::kSuccess; MS_LOG(INFO) << "Iteration " << iteration_num_ << " is successfully finished."; } else { // Store last iteration's model because this iteration is considered as invalid. @@ -584,6 +587,7 @@ void Iteration::Next(bool is_iteration_valid, const std::string &reason) { size_t latest_iter_num = iter_to_model.rbegin()->first; const auto &model = ModelStore::GetInstance().GetModelByIterNum(latest_iter_num); ModelStore::GetInstance().StoreModelByIterNum(iteration_num_, model); + iteration_result_ = IterationResult::kTimeout; MS_LOG(WARNING) << "Iteration " << iteration_num_ << " is invalid. Reason: " << reason; } @@ -591,6 +595,13 @@ void Iteration::Next(bool is_iteration_valid, const std::string &reason) { MS_ERROR_IF_NULL_WO_RET_VAL(round); round->Reset(); } + MS_LOG(INFO) << "Iteration " << iteration_num_ << " stop global timer."; + global_iter_timer_->Stop(); + + for (const auto &round : rounds_) { + MS_ERROR_IF_NULL_WO_RET_VAL(round); + round->KernelSummarize(); + } } bool Iteration::BroadcastEndLastIterRequest(uint64_t last_iter_num) { MS_LOG(INFO) << "Notify all follower servers to end last iteration."; EndLastIterRequest end_last_iter_req; end_last_iter_req.set_last_iter_num(last_iter_num); + std::shared_ptr> client_info_rsp_msg = nullptr; for (uint32_t i = 1; i < IntToUint(server_node_->server_num()); i++) { - if (!communicator_->SendPbRequest(end_last_iter_req, i, ps::core::TcpUserCommand::kEndLastIter)) { + if (!communicator_->SendPbRequest(end_last_iter_req, i, ps::core::TcpUserCommand::kEndLastIter, + &client_info_rsp_msg)) { MS_LOG(WARNING) << "Sending ending last iteration request to server " << i << " failed."; continue; } + UpdateRoundClientNumMap(client_info_rsp_msg); } EndLastIter(); @@ 
-629,10 +643,29 @@ void Iteration::HandleEndLastIterRequest(const std::shared_ptrname() == "startFLJob") { + end_last_iter_rsp.set_startfljob_total_client_num(round->kernel_total_client_num()); + end_last_iter_rsp.set_startfljob_accept_client_num(round->kernel_accept_client_num()); + end_last_iter_rsp.set_startfljob_reject_client_num(round->kernel_reject_client_num()); + } else if (round->name() == "updateModel") { + end_last_iter_rsp.set_updatemodel_total_client_num(round->kernel_total_client_num()); + end_last_iter_rsp.set_updatemodel_accept_client_num(round->kernel_accept_client_num()); + end_last_iter_rsp.set_updatemodel_reject_client_num(round->kernel_reject_client_num()); + } else if (round->name() == "getModel") { + end_last_iter_rsp.set_getmodel_total_client_num(round->kernel_total_client_num()); + end_last_iter_rsp.set_getmodel_accept_client_num(round->kernel_accept_client_num()); + end_last_iter_rsp.set_getmodel_reject_client_num(round->kernel_reject_client_num()); + } + } + + EndLastIter(); if (!communicator_->SendResponse(end_last_iter_rsp.SerializeAsString().data(), end_last_iter_rsp.SerializeAsString().size(), message)) { MS_LOG(ERROR) << "Sending response failed."; @@ -668,7 +701,10 @@ void Iteration::EndLastIter() { MS_LOG(WARNING) << "Can't save current iteration number into persistent storage."; } } - + for (const auto &round : rounds_) { + round->InitkernelClientVisitedNum(); + } + round_client_num_map_.clear(); Server::GetInstance().CancelSafeMode(); iteration_state_cv_.notify_all(); MS_LOG(INFO) << "Move to next iteration:" << iteration_num_ << "\n"; @@ -692,12 +728,8 @@ bool Iteration::SummarizeIteration() { metrics_->set_instance_state(instance_state_.load()); metrics_->set_loss(loss_); metrics_->set_accuracy(accuracy_); - // The joined client number is equal to the threshold of updateModel. 
- size_t update_model_threshold = static_cast( - std::ceil(ps::PSContext::instance()->start_fl_job_threshold() * ps::PSContext::instance()->update_model_ratio())); - metrics_->set_joined_client_num(update_model_threshold); - // The rejected client number is equal to threshold of startFLJob minus threshold of updateModel. - metrics_->set_rejected_client_num(ps::PSContext::instance()->start_fl_job_threshold() - update_model_threshold); + metrics_->set_round_client_num_map(round_client_num_map_); + metrics_->set_iteration_result(iteration_result_.load()); if (complete_timestamp_ < start_timestamp_) { MS_LOG(ERROR) << "The complete_timestamp_: " << complete_timestamp_ << ", start_timestamp_: " << start_timestamp_ @@ -749,7 +781,21 @@ bool Iteration::UpdateHyperParams(const nlohmann::json &json) { ps::PSContext::instance()->set_client_learning_rate(item.value().get()); continue; } + if (key == "global_iteration_time_window") { + ps::PSContext::instance()->set_global_iteration_time_window(item.value().get()); + continue; + } } + + MS_LOG(INFO) << "start_fl_job_threshold: " << ps::PSContext::instance()->start_fl_job_threshold(); + MS_LOG(INFO) << "start_fl_job_time_window: " << ps::PSContext::instance()->start_fl_job_time_window(); + MS_LOG(INFO) << "update_model_ratio: " << ps::PSContext::instance()->update_model_ratio(); + MS_LOG(INFO) << "update_model_time_window: " << ps::PSContext::instance()->update_model_time_window(); + MS_LOG(INFO) << "fl_iteration_num: " << ps::PSContext::instance()->fl_iteration_num(); + MS_LOG(INFO) << "client_epoch_num: " << ps::PSContext::instance()->client_epoch_num(); + MS_LOG(INFO) << "client_batch_size: " << ps::PSContext::instance()->client_batch_size(); + MS_LOG(INFO) << "client_learning_rate: " << ps::PSContext::instance()->client_learning_rate(); + MS_LOG(INFO) << "global_iteration_time_window: " << ps::PSContext::instance()->global_iteration_time_window(); return true; } @@ -784,6 +830,36 @@ bool Iteration::ReInitRounds() { } return 
true; } + +void Iteration::InitGlobalIterTimer(const TimeOutCb &timeout_cb) { + global_iteration_time_window_ = ps::PSContext::instance()->global_iteration_time_window(); + global_iter_timer_ = std::make_shared(); + + // Set the timeout callback for the timer. + global_iter_timer_->SetTimeOutCallBack([this, timeout_cb](bool, const std::string &) -> void { + std::string reason = "Global Iteration " + std::to_string(iteration_num_) + + " timeout! This iteration is invalid. Proceed to next iteration."; + timeout_cb(false, reason); + }); +} + +void Iteration::UpdateRoundClientNumMap(const std::shared_ptr> &client_info_rsp_msg) { + MS_ERROR_IF_NULL_WO_RET_VAL(client_info_rsp_msg); + EndLastIterResponse end_last_iter_rsp; + (void)end_last_iter_rsp.ParseFromArray(client_info_rsp_msg->data(), SizeToInt(client_info_rsp_msg->size())); + + round_client_num_map_[kStartFLJobTotalClientNum] += end_last_iter_rsp.startfljob_total_client_num(); + round_client_num_map_[kStartFLJobAcceptClientNum] += end_last_iter_rsp.startfljob_accept_client_num(); + round_client_num_map_[kStartFLJobRejectClientNum] += end_last_iter_rsp.startfljob_reject_client_num(); + + round_client_num_map_[kUpdateModelTotalClientNum] += end_last_iter_rsp.updatemodel_total_client_num(); + round_client_num_map_[kUpdateModelAcceptClientNum] += end_last_iter_rsp.updatemodel_accept_client_num(); + round_client_num_map_[kUpdateModelRejectClientNum] += end_last_iter_rsp.updatemodel_reject_client_num(); + + round_client_num_map_[kGetModelTotalClientNum] += end_last_iter_rsp.getmodel_total_client_num(); + round_client_num_map_[kGetModelAcceptClientNum] += end_last_iter_rsp.getmodel_accept_client_num(); + round_client_num_map_[kGetModelRejectClientNum] += end_last_iter_rsp.getmodel_reject_client_num(); +} } // namespace server } // namespace fl } // namespace mindspore diff --git a/mindspore/ccsrc/fl/server/iteration.h b/mindspore/ccsrc/fl/server/iteration.h index c1961669144..c66db88159b 100644 --- 
a/mindspore/ccsrc/fl/server/iteration.h +++ b/mindspore/ccsrc/fl/server/iteration.h @@ -20,6 +20,7 @@ #include #include #include +#include #include "ps/core/communicator/communicator_base.h" #include "fl/server/common.h" #include "fl/server/round.h" @@ -125,9 +126,15 @@ class Iteration { // Synchronize server iteration after another server's recovery is completed. bool SyncAfterRecovery(uint64_t iteration_num); + // Initialize global iteration timer. + void InitGlobalIterTimer(const TimeOutCb &timeout_cb); + // The round kernels whose Launch method has not returned yet. std::atomic_uint32_t running_round_num_; + // Update count with client visited num in round + void UpdateRoundClientNumMap(const std::string &name, const size_t num); + private: Iteration() : running_round_num_(0), @@ -147,9 +154,18 @@ class Iteration { is_instance_being_updated_(false), loss_(0.0), accuracy_(0.0), - joined_client_num_(0), - rejected_client_num_(0), - time_cost_(0) { + time_cost_(0), + global_iteration_time_window_(0), + round_client_num_map_({{kStartFLJobTotalClientNum, 0}, + {kUpdateModelTotalClientNum, 0}, + {kGetModelTotalClientNum, 0}, + {kStartFLJobAcceptClientNum, 0}, + {kUpdateModelAcceptClientNum, 0}, + {kGetModelAcceptClientNum, 0}, + {kStartFLJobRejectClientNum, 0}, + {kUpdateModelRejectClientNum, 0}, + {kGetModelRejectClientNum, 0}}), + iteration_result_(IterationResult::kSuccess) { LocalMetaStore::GetInstance().set_curr_iter_num(iteration_num_); } ~Iteration(); @@ -202,6 +218,8 @@ class Iteration { // Reinitialize rounds and round kernels. bool ReInitRounds(); + void UpdateRoundClientNumMap(const std::shared_ptr> &client_info_rsp_msg); + std::shared_ptr server_node_; std::shared_ptr communicator_; @@ -246,6 +264,7 @@ class Iteration { // Every instance is not reentrant. // This flag represents whether the instance is being updated. std::mutex instance_mtx_; + bool is_instance_being_updated_; // The training loss after this federated learning iteration, passed by worker. 
@@ -254,14 +273,20 @@ class Iteration { // The evaluation result after this federated learning iteration, passed by worker. float accuracy_; - // The number of clients which join the federated aggregation. - size_t joined_client_num_; - - // The number of clients which are not involved in federated aggregation. - size_t rejected_client_num_; - // The time cost in millisecond for this completed iteration. uint64_t time_cost_; + + // global iteration time window + uint64_t global_iteration_time_window_; + + // for example: "startFLJobTotalClientNum" -> startFLJob total client num + std::map round_client_num_map_; + + // Iteration global timer. + std::shared_ptr global_iter_timer_; + + // The result of the current iteration. + std::atomic iteration_result_; }; } // namespace server } // namespace fl diff --git a/mindspore/ccsrc/fl/server/iteration_metrics.cc b/mindspore/ccsrc/fl/server/iteration_metrics.cc index 28c165670f4..8a31729dd05 100644 --- a/mindspore/ccsrc/fl/server/iteration_metrics.cc +++ b/mindspore/ccsrc/fl/server/iteration_metrics.cc @@ -79,11 +79,12 @@ bool IterationMetrics::Summarize() { js_[kInstanceStatus] = kInstanceStateName.at(instance_state_); js_[kFLIterationNum] = fl_iteration_num_; js_[kCurIteration] = cur_iteration_num_; - js_[kJoinedClientNum] = joined_client_num_; - js_[kRejectedClientNum] = rejected_client_num_; js_[kMetricsAuc] = accuracy_; js_[kMetricsLoss] = loss_; js_[kIterExecutionTime] = iteration_time_cost_; + js_[kClientVisitedInfo] = round_client_num_map_; + js_[kIterationResult] = kIterationResultName.at(iteration_result_); + metrics_file_ << js_ << "\n"; (void)metrics_file_.flush(); metrics_file_.close(); @@ -111,15 +112,15 @@ void IterationMetrics::set_loss(float loss) { loss_ = loss; } void IterationMetrics::set_accuracy(float acc) { accuracy_ = acc; } -void IterationMetrics::set_joined_client_num(size_t joined_client_num) { joined_client_num_ = joined_client_num; } - -void IterationMetrics::set_rejected_client_num(size_t 
rejected_client_num) { - rejected_client_num_ = rejected_client_num; -} - void IterationMetrics::set_iteration_time_cost(uint64_t iteration_time_cost) { iteration_time_cost_ = iteration_time_cost; } + +void IterationMetrics::set_round_client_num_map(const std::map round_client_num_map) { + round_client_num_map_ = round_client_num_map; +} + +void IterationMetrics::set_iteration_result(IterationResult iteration_result) { iteration_result_ = iteration_result; } } // namespace server } // namespace fl } // namespace mindspore diff --git a/mindspore/ccsrc/fl/server/iteration_metrics.h b/mindspore/ccsrc/fl/server/iteration_metrics.h index 96c9d0d938b..8de1520ced8 100644 --- a/mindspore/ccsrc/fl/server/iteration_metrics.h +++ b/mindspore/ccsrc/fl/server/iteration_metrics.h @@ -34,16 +34,19 @@ constexpr auto kFLName = "flName"; constexpr auto kInstanceStatus = "instanceStatus"; constexpr auto kFLIterationNum = "flIterationNum"; constexpr auto kCurIteration = "currentIteration"; -constexpr auto kJoinedClientNum = "joinedClientNum"; -constexpr auto kRejectedClientNum = "rejectedClientNum"; constexpr auto kMetricsAuc = "metricsAuc"; constexpr auto kMetricsLoss = "metricsLoss"; constexpr auto kIterExecutionTime = "iterationExecutionTime"; constexpr auto kMetrics = "metrics"; +constexpr auto kClientVisitedInfo = "clientVisitedInfo"; +constexpr auto kIterationResult = "iterationResult"; const std::map kInstanceStateName = { {InstanceState::kRunning, "running"}, {InstanceState::kDisable, "disable"}, {InstanceState::kFinish, "finish"}}; +const std::map kIterationResultName = {{IterationResult::kSuccess, "success"}, + {IterationResult::kTimeout, "timeout"}}; + class IterationMetrics { public: explicit IterationMetrics(const std::string &config_file) @@ -55,9 +58,8 @@ class IterationMetrics { instance_state_(InstanceState::kFinish), loss_(0.0), accuracy_(0.0), - joined_client_num_(0), - rejected_client_num_(0), - iteration_time_cost_(0) {} + iteration_time_cost_(0), + 
iteration_result_(IterationResult::kSuccess) {} ~IterationMetrics() = default; bool Initialize(); @@ -75,9 +77,9 @@ class IterationMetrics { void set_instance_state(InstanceState state); void set_loss(float loss); void set_accuracy(float acc); - void set_joined_client_num(size_t joined_client_num); - void set_rejected_client_num(size_t rejected_client_num); void set_iteration_time_cost(uint64_t iteration_time_cost); + void set_round_client_num_map(const std::map round_client_num_map); + void set_iteration_result(IterationResult iteration_result); private: // This is the main config file set by ps context. @@ -112,14 +114,14 @@ class IterationMetrics { // The evaluation result after this federated learning iteration, passed by worker. float accuracy_; - // The number of clients which join the federated aggregation. - size_t joined_client_num_; - - // The number of clients which are not involved in federated aggregation. - size_t rejected_client_num_; + // for example: "startFLJobTotalClientNum" -> startFLJob total client num + std::map round_client_num_map_; // The time cost in millisecond for this completed iteration. uint64_t iteration_time_cost_; + + // Current iteration running result. 
+ IterationResult iteration_result_; }; } // namespace server } // namespace fl diff --git a/mindspore/ccsrc/fl/server/kernel/round/get_model_kernel.cc b/mindspore/ccsrc/fl/server/kernel/round/get_model_kernel.cc index 9752cd9bece..1fe2ffb0717 100644 --- a/mindspore/ccsrc/fl/server/kernel/round/get_model_kernel.cc +++ b/mindspore/ccsrc/fl/server/kernel/round/get_model_kernel.cc @@ -30,7 +30,7 @@ void GetModelKernel::InitKernel(size_t) { if (LocalMetaStore::GetInstance().has_value(kCtxTotalTimeoutDuration)) { iteration_time_window_ = LocalMetaStore::GetInstance().value(kCtxTotalTimeoutDuration); } - + InitClientVisitedNum(); executor_ = &Executor::GetInstance(); MS_EXCEPTION_IF_NULL(executor_); if (!executor_->initialized()) { @@ -119,7 +119,7 @@ void GetModelKernel::GetModel(const schema::RequestGetModel *get_model_req, cons } else { feature_maps = ModelStore::GetInstance().GetModelByIterNum(get_model_iter); } - + IncreaseAcceptClientNum(); MS_LOG(INFO) << "GetModel last iteratin is valid or not: " << Iteration::GetInstance().is_last_iteration_valid() << ", next request time is " << next_req_time << ", current iteration is " << current_iter; BuildGetModelRsp(fbb, schema::ResponseCode_SUCCEED, "Get model for iteration " + std::to_string(get_model_iter), diff --git a/mindspore/ccsrc/fl/server/kernel/round/round_kernel.cc b/mindspore/ccsrc/fl/server/kernel/round/round_kernel.cc index 042323ab281..13b1980f536 100644 --- a/mindspore/ccsrc/fl/server/kernel/round/round_kernel.cc +++ b/mindspore/ccsrc/fl/server/kernel/round/round_kernel.cc @@ -125,8 +125,32 @@ void RoundKernel::GenerateOutput(const std::vector &outputs, const v std::unique_lock lock(heap_data_mtx_); (void)heap_data_.insert(std::make_pair(outputs[0], std::move(output_data))); + IncreaseTotalClientNum(); return; } + +void RoundKernel::IncreaseTotalClientNum() { total_client_num_ += 1; } + +void RoundKernel::IncreaseAcceptClientNum() { accept_client_num_ += 1; } + +void RoundKernel::Summarize() { + if (name_ 
== "startFLJob" || name_ == "updateModel" || name_ == "getModel") { + MS_LOG(INFO) << "Round kernel " << name_ << " total client num is: " << total_client_num_ + << ", accept client num is: " << accept_client_num_ + << ", reject client num is: " << (total_client_num_ - accept_client_num_); + } +} + +size_t RoundKernel::total_client_num() const { return total_client_num_; } + +size_t RoundKernel::accept_client_num() const { return accept_client_num_; } + +size_t RoundKernel::reject_client_num() const { return total_client_num_ - accept_client_num_; } + +void RoundKernel::InitClientVisitedNum() { + total_client_num_ = 0; + accept_client_num_ = 0; +} } // namespace kernel } // namespace server } // namespace fl diff --git a/mindspore/ccsrc/fl/server/kernel/round/round_kernel.h b/mindspore/ccsrc/fl/server/kernel/round/round_kernel.h index c7184a89eda..94953b1e72c 100644 --- a/mindspore/ccsrc/fl/server/kernel/round/round_kernel.h +++ b/mindspore/ccsrc/fl/server/kernel/round/round_kernel.h @@ -89,6 +89,20 @@ class RoundKernel : virtual public CPUKernel { void set_stop_timer_cb(const StopTimerCb &timer_stopper); void set_finish_iteration_cb(const FinishIterCb &finish_iteration_cb); + void Summarize(); + + void IncreaseTotalClientNum(); + + void IncreaseAcceptClientNum(); + + size_t total_client_num() const; + + size_t accept_client_num() const; + + size_t reject_client_num() const; + + void InitClientVisitedNum(); + protected: // Generating response data of this round. The data is allocated on the heap to ensure it's not released before sent // back to worker. 
@@ -121,6 +135,9 @@ class RoundKernel : virtual public CPUKernel { std::queue heap_data_to_release_; std::mutex heap_data_mtx_; std::unordered_map> heap_data_; + + std::atomic total_client_num_; + std::atomic accept_client_num_; }; } // namespace kernel } // namespace server diff --git a/mindspore/ccsrc/fl/server/kernel/round/start_fl_job_kernel.cc b/mindspore/ccsrc/fl/server/kernel/round/start_fl_job_kernel.cc index 9b1cf7e8935..ba071778038 100644 --- a/mindspore/ccsrc/fl/server/kernel/round/start_fl_job_kernel.cc +++ b/mindspore/ccsrc/fl/server/kernel/round/start_fl_job_kernel.cc @@ -36,14 +36,13 @@ void StartFLJobKernel::InitKernel(size_t) { } iter_next_req_timestamp_ = LongToUlong(CURRENT_TIME_MILLI.count()) + iteration_time_window_; LocalMetaStore::GetInstance().put_value(kCtxIterationNextRequestTimestamp, iter_next_req_timestamp_); - + InitClientVisitedNum(); executor_ = &Executor::GetInstance(); MS_EXCEPTION_IF_NULL(executor_); if (!executor_->initialized()) { MS_LOG(EXCEPTION) << "Executor must be initialized in server pipeline."; return; } - PBMetadata devices_metas; DistributedMetadataStore::GetInstance().RegisterMetadata(kCtxDeviceMetas, devices_metas); @@ -132,6 +131,7 @@ bool StartFLJobKernel::Launch(const std::vector &inputs, const std:: GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize()); return ConvertResultCode(result_code); } + IncreaseAcceptClientNum(); GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize()); return true; } diff --git a/mindspore/ccsrc/fl/server/kernel/round/update_model_kernel.cc b/mindspore/ccsrc/fl/server/kernel/round/update_model_kernel.cc index 252167ea0e3..3262f26aaa0 100644 --- a/mindspore/ccsrc/fl/server/kernel/round/update_model_kernel.cc +++ b/mindspore/ccsrc/fl/server/kernel/round/update_model_kernel.cc @@ -29,7 +29,7 @@ void UpdateModelKernel::InitKernel(size_t threshold_count) { if (LocalMetaStore::GetInstance().has_value(kCtxTotalTimeoutDuration)) { iteration_time_window_ = 
LocalMetaStore::GetInstance().value(kCtxTotalTimeoutDuration); } - + InitClientVisitedNum(); executor_ = &Executor::GetInstance(); MS_EXCEPTION_IF_NULL(executor_); if (!executor_->initialized()) { @@ -121,6 +121,7 @@ bool UpdateModelKernel::Launch(const std::vector &inputs, const std: GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize()); return ConvertResultCode(result_code); } + IncreaseAcceptClientNum(); GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize()); return true; } diff --git a/mindspore/ccsrc/fl/server/round.cc b/mindspore/ccsrc/fl/server/round.cc index 261c62f7999..c0933ccff5c 100644 --- a/mindspore/ccsrc/fl/server/round.cc +++ b/mindspore/ccsrc/fl/server/round.cc @@ -226,14 +226,27 @@ bool Round::IsServerAvailable(std::string *reason) { return false; } - // If the server is still in the process of scaling, reject the request. + // If the server is still in safemode, reject the request. if (Server::GetInstance().IsSafeMode()) { - MS_LOG(WARNING) << "The cluster is still in process of scaling, please retry " << name_ << " later."; + MS_LOG(WARNING) << "The cluster is still in safemode, please retry " << name_ << " later."; *reason = ps::kClusterSafeMode; return false; } return true; } + +void Round::KernelSummarize() { + MS_ERROR_IF_NULL_WO_RET_VAL(kernel_); + (void)kernel_->Summarize(); +} + +size_t Round::kernel_total_client_num() const { return kernel_->total_client_num(); } + +size_t Round::kernel_accept_client_num() const { return kernel_->accept_client_num(); } + +size_t Round::kernel_reject_client_num() const { return kernel_->reject_client_num(); } + +void Round::InitkernelClientVisitedNum() { kernel_->InitClientVisitedNum(); } } // namespace server } // namespace fl } // namespace mindspore diff --git a/mindspore/ccsrc/fl/server/round.h b/mindspore/ccsrc/fl/server/round.h index cbd868b1f43..fe08638bd41 100644 --- a/mindspore/ccsrc/fl/server/round.h +++ b/mindspore/ccsrc/fl/server/round.h @@ -56,11 +56,24 @@ class Round { 
// Round needs to be reset after each iteration is finished or its timer expires. void Reset(); + void KernelSummarize(); + const std::string &name() const; + size_t threshold_count() const; + bool check_timeout() const; + size_t time_window() const; + size_t kernel_total_client_num() const; + + size_t kernel_accept_client_num() const; + + size_t kernel_reject_client_num() const; + + void InitkernelClientVisitedNum(); + private: // The callbacks which will be set to DistributedCounterService. void OnFirstCountEvent(const std::shared_ptr &message); diff --git a/mindspore/ccsrc/fl/server/server.cc b/mindspore/ccsrc/fl/server/server.cc index 1020267a5ff..deb53a930d0 100644 --- a/mindspore/ccsrc/fl/server/server.cc +++ b/mindspore/ccsrc/fl/server/server.cc @@ -242,6 +242,8 @@ void Server::InitIteration() { FinishIterCb finish_iter_cb = std::bind(&Iteration::NotifyNext, iteration_, std::placeholders::_1, std::placeholders::_2); iteration_->InitRounds(communicators_with_worker_, time_out_cb, finish_iter_cb); + + iteration_->InitGlobalIterTimer(time_out_cb); return; } diff --git a/mindspore/ccsrc/pipeline/jit/init.cc b/mindspore/ccsrc/pipeline/jit/init.cc index 7a65ac314c5..b3faf404d1f 100644 --- a/mindspore/ccsrc/pipeline/jit/init.cc +++ b/mindspore/ccsrc/pipeline/jit/init.cc @@ -475,8 +475,10 @@ PYBIND11_MODULE(_c_expression, m) { .def("set_encrypt_type", &PSContext::set_encrypt_type, "Set encrypt type for federated learning secure aggregation.") .def("set_http_url_prefix", &PSContext::set_http_url_prefix, "Set http url prefix for http communication.") - .def("http_url_prefix", &PSContext::http_url_prefix, "http url prefix for http communication."); - + .def("http_url_prefix", &PSContext::http_url_prefix, "http url prefix for http communication.") + .def("set_global_iteration_time_window", &PSContext::set_global_iteration_time_window, + "Set global iteration time window.") + .def("global_iteration_time_window", &PSContext::global_iteration_time_window, "Get global 
iteration time window."); (void)m.def("_encrypt", &mindspore::pipeline::PyEncrypt, "Encrypt the data."); (void)m.def("_decrypt", &mindspore::pipeline::PyDecrypt, "Decrypt the data."); (void)m.def("_is_cipher_file", &mindspore::pipeline::PyIsCipherFile, "Determine whether the file is encrypted"); diff --git a/mindspore/ccsrc/ps/core/protos/fl.proto b/mindspore/ccsrc/ps/core/protos/fl.proto index 93c46e12661..98338abd71d 100644 --- a/mindspore/ccsrc/ps/core/protos/fl.proto +++ b/mindspore/ccsrc/ps/core/protos/fl.proto @@ -237,6 +237,15 @@ message EndLastIterRequest { message EndLastIterResponse { string result = 1; + uint64 startFLJob_total_client_num = 2; + uint64 startFLJob_accept_client_num = 3; + uint64 startFLJob_reject_client_num = 4; + uint64 updateModel_total_client_num = 5; + uint64 updateModel_accept_client_num = 6; + uint64 updateModel_reject_client_num = 7; + uint64 getModel_total_client_num = 8; + uint64 getModel_accept_client_num = 9; + uint64 getModel_reject_client_num = 10; } message SyncAfterRecover { diff --git a/mindspore/ccsrc/ps/ps_context.cc b/mindspore/ccsrc/ps/ps_context.cc index 0fc2c74744b..3085a298ea7 100644 --- a/mindspore/ccsrc/ps/ps_context.cc +++ b/mindspore/ccsrc/ps/ps_context.cc @@ -476,5 +476,11 @@ void PSContext::set_server_password(const std::string &password) { server_passwo std::string PSContext::http_url_prefix() const { return http_url_prefix_; } void PSContext::set_http_url_prefix(const std::string &http_url_prefix) { http_url_prefix_ = http_url_prefix; } + +void PSContext::set_global_iteration_time_window(const uint64_t &global_iteration_time_window) { + global_iteration_time_window_ = global_iteration_time_window; +} + +uint64_t PSContext::global_iteration_time_window() const { return global_iteration_time_window_; } } // namespace ps } // namespace mindspore diff --git a/mindspore/ccsrc/ps/ps_context.h b/mindspore/ccsrc/ps/ps_context.h index 067634f530f..a0729054666 100644 --- a/mindspore/ccsrc/ps/ps_context.h +++ 
b/mindspore/ccsrc/ps/ps_context.h @@ -211,6 +211,9 @@ class PSContext { std::string http_url_prefix() const; void set_http_url_prefix(const std::string &http_url_prefix); + void set_global_iteration_time_window(const uint64_t &global_iteration_time_window); + uint64_t global_iteration_time_window() const; + private: PSContext() : ps_enabled_(false), @@ -229,9 +232,9 @@ class PSContext { fl_client_enable_(false), fl_name_(""), start_fl_job_threshold_(0), - start_fl_job_time_window_(3000), + start_fl_job_time_window_(300000), update_model_ratio_(1.0), - update_model_time_window_(3000), + update_model_time_window_(300000), share_secrets_ratio_(1.0), cipher_time_window_(300000), reconstruct_secrets_threshold_(2000), @@ -257,7 +260,8 @@ class PSContext { enable_ssl_(false), client_password_(""), server_password_(""), - http_url_prefix_("") {} + http_url_prefix_(""), + global_iteration_time_window_(21600000) {} bool ps_enabled_; bool is_worker_; bool is_pserver_; @@ -372,6 +376,9 @@ class PSContext { std::string server_password_; // http url prefix for http communication std::string http_url_prefix_; + + // The time window of one global iteration in millisecond. + uint64_t global_iteration_time_window_; }; } // namespace ps } // namespace mindspore diff --git a/mindspore/python/mindspore/context.py b/mindspore/python/mindspore/context.py index 7ff3d037ea4..fe371696ffb 100644 --- a/mindspore/python/mindspore/context.py +++ b/mindspore/python/mindspore/context.py @@ -1100,7 +1100,10 @@ def set_fl_context(**kwargs): pki_verify is True. Default: "". replay_attack_time_diff (int): The maximum tolerable error of certificate timestamp verification (ms). Default: 600000. - + http_url_prefix (string): The http url prefix for http server. + Default: "". + global_iteration_time_window (unsigned long): The global iteration time window for one iteration + with rounds(ms). Default: 21600000. Raises: ValueError: If input key is not the attribute in federated learning mode context. 
diff --git a/mindspore/python/mindspore/parallel/_ps_context.py b/mindspore/python/mindspore/parallel/_ps_context.py index 7406042c83a..e49c1fb3d1d 100644 --- a/mindspore/python/mindspore/parallel/_ps_context.py +++ b/mindspore/python/mindspore/parallel/_ps_context.py @@ -71,7 +71,8 @@ _set_ps_context_func_map = { "dp_delta": ps_context().set_dp_delta, "dp_norm_clip": ps_context().set_dp_norm_clip, "encrypt_type": ps_context().set_encrypt_type, - "http_url_prefix": ps_context().set_http_url_prefix + "http_url_prefix": ps_context().set_http_url_prefix, + "global_iteration_time_window": ps_context().set_global_iteration_time_window } _get_ps_context_func_map = { @@ -112,7 +113,8 @@ _get_ps_context_func_map = { "server_password": ps_context().server_password, "scheduler_manage_port": ps_context().scheduler_manage_port, "config_file_path": ps_context().config_file_path, - "http_url_prefix": ps_context().http_url_prefix + "http_url_prefix": ps_context().http_url_prefix, + "global_iteration_time_window": ps_context().global_iteration_time_window } _check_positive_int_keys = ["server_num", "scheduler_port", "fl_server_port", diff --git a/scripts/fl_restful_tool.py b/scripts/fl_restful_tool.py index c519d57380c..1eabaffb796 100644 --- a/scripts/fl_restful_tool.py +++ b/scripts/fl_restful_tool.py @@ -178,11 +178,10 @@ def call_get_instance_detail(): return process_self_define_json(Status.FAILED.value, "error. 
metrics file is not existed.") ans_json_obj = {} - joined_client_num_list = [] - rejected_client_num_list = [] metrics_auc_list = [] metrics_loss_list = [] iteration_execution_time_list = [] + client_visited_info_list = [] with open(metrics_file_path, 'r') as f: metrics_list = f.readlines() @@ -193,8 +192,7 @@ def call_get_instance_detail(): for metrics in metrics_list: json_obj = json.loads(metrics) iteration_execution_time_list.append(json_obj['iterationExecutionTime']) - joined_client_num_list.append(json_obj['joinedClientNum']) - rejected_client_num_list.append(json_obj['rejectedClientNum']) + client_visited_info_list.append(json_obj['clientVisitedInfo']) metrics_auc_list.append(json_obj['metricsAuc']) metrics_loss_list.append(json_obj['metricsLoss']) @@ -210,8 +208,7 @@ def call_get_instance_detail(): ans_json_result['flName'] = last_metrics_obj['flName'] ans_json_result['instanceStatus'] = last_metrics_obj['instanceStatus'] ans_json_result['iterationExecutionTime'] = iteration_execution_time_list - ans_json_result['joinedClientNum'] = joined_client_num_list - ans_json_result['rejectedClientNum'] = rejected_client_num_list + ans_json_result['clientVisitedInfo'] = client_visited_info_list ans_json_result['metricsAuc'] = metrics_auc_list ans_json_result['metricsLoss'] = metrics_loss_list