From 51f6b77ab04779d4cf1d3fae078763be9122aa50 Mon Sep 17 00:00:00 2001 From: zhou_chao1993 Date: Fri, 22 Apr 2022 14:29:24 +0800 Subject: [PATCH] The maintainability function of FL --- mindspore/ccsrc/fl/server/common.h | 5 + mindspore/ccsrc/fl/server/iteration.cc | 167 +++++++++++++++++- mindspore/ccsrc/fl/server/iteration.h | 40 ++++- .../ccsrc/fl/server/iteration_metrics.cc | 62 +++---- mindspore/ccsrc/fl/server/iteration_metrics.h | 13 ++ .../fl/server/kernel/round/round_kernel.cc | 69 +++++++- .../fl/server/kernel/round/round_kernel.h | 53 +++++- .../kernel/round/start_fl_job_kernel.cc | 3 + .../kernel/round/update_model_kernel.cc | 90 +++++++++- .../server/kernel/round/update_model_kernel.h | 26 ++- mindspore/ccsrc/fl/server/round.cc | 33 +++- mindspore/ccsrc/fl/server/round.h | 18 +- mindspore/ccsrc/fl/server/server.cc | 9 + mindspore/ccsrc/pipeline/jit/init.cc | 8 +- mindspore/ccsrc/ps/core/abstract_node.cc | 32 +++- mindspore/ccsrc/ps/core/abstract_node.h | 27 +-- mindspore/ccsrc/ps/core/comm_util.cc | 54 +++++- mindspore/ccsrc/ps/core/comm_util.h | 49 +++-- .../ccsrc/ps/core/communicator/message.h | 3 +- mindspore/ccsrc/ps/core/protos/comm.proto | 10 ++ mindspore/ccsrc/ps/core/protos/fl.proto | 4 + mindspore/ccsrc/ps/core/ps_scheduler_node.h | 4 + mindspore/ccsrc/ps/core/scheduler_node.cc | 88 ++++++++- mindspore/ccsrc/ps/core/scheduler_node.h | 19 +- mindspore/ccsrc/ps/core/server_node.cc | 5 + mindspore/ccsrc/ps/core/worker_node.cc | 8 +- mindspore/ccsrc/ps/ps_context.cc | 21 ++- mindspore/ccsrc/ps/ps_context.h | 23 ++- .../python/mindspore/parallel/_ps_context.py | 6 + scripts/fl_restful_tool.py | 21 +-- tests/st/fl/albert/config.json | 8 + .../fl/cross_device_lenet/cloud/config.json | 8 + .../st/fl/cross_silo_faster_rcnn/config.json | 8 + tests/st/fl/cross_silo_femnist/config.json | 8 + tests/st/fl/cross_silo_lenet/config.json | 8 + tests/st/fl/hybrid_lenet/config.json | 8 + tests/st/fl/mobile/config.json | 8 + 37 files changed, 900 insertions(+), 126 
deletions(-) diff --git a/mindspore/ccsrc/fl/server/common.h b/mindspore/ccsrc/fl/server/common.h index b22d273fe85..a81a7b5f8d1 100644 --- a/mindspore/ccsrc/fl/server/common.h +++ b/mindspore/ccsrc/fl/server/common.h @@ -23,6 +23,7 @@ #include #include #include +#include #include "proto/ps.pb.h" #include "proto/fl.pb.h" #include "ir/anf.h" @@ -149,11 +150,15 @@ constexpr auto kUpdateModelRejectClientNum = "updateModelRejectClientNum"; constexpr auto kGetModelTotalClientNum = "getModelTotalClientNum"; constexpr auto kGetModelAcceptClientNum = "getModelAcceptClientNum"; constexpr auto kGetModelRejectClientNum = "getModelRejectClientNum"; +constexpr auto kParticipationTimeLevel1 = "participationTimeLevel1"; +constexpr auto kParticipationTimeLevel2 = "participationTimeLevel2"; +constexpr auto kParticipationTimeLevel3 = "participationTimeLevel3"; constexpr auto kMinVal = "min_val"; constexpr auto kMaxVal = "max_val"; constexpr auto kQuant = "QUANT"; constexpr auto kDiffSparseQuant = "DIFF_SPARSE_QUANT"; constexpr auto kNoCompress = "NO_COMPRESS"; +constexpr auto kUpdateModel = "updateModel"; // OptimParamNameToIndex represents every inputs/workspace/outputs parameter's offset when an optimizer kernel is // launched. 
diff --git a/mindspore/ccsrc/fl/server/iteration.cc b/mindspore/ccsrc/fl/server/iteration.cc index ffcec0a629e..bc9e13b4e53 100644 --- a/mindspore/ccsrc/fl/server/iteration.cc +++ b/mindspore/ccsrc/fl/server/iteration.cc @@ -15,27 +15,40 @@ */ #include "fl/server/iteration.h" + #include -#include -#include #include +#include #include +#include + #include "fl/server/model_store.h" #include "fl/server/server.h" +#include "ps/core/comm_util.h" namespace mindspore { namespace fl { namespace server { +namespace { +const size_t kParticipationTimeLevelNum = 3; +const size_t kIndexZero = 0; +const size_t kIndexOne = 1; +const size_t kIndexTwo = 2; +const size_t kLastSecond = 59; +} // namespace class Server; Iteration::~Iteration() { move_to_next_thread_running_ = false; + is_date_rate_thread_running_ = false; next_iteration_cv_.notify_all(); if (move_to_next_thread_.joinable()) { move_to_next_thread_.join(); } + if (data_rate_thread_.joinable()) { + data_rate_thread_.join(); + } } - void Iteration::RegisterMessageCallback(const std::shared_ptr &communicator) { MS_EXCEPTION_IF_NULL(communicator); communicator_ = communicator; @@ -160,8 +173,13 @@ void Iteration::SetIterationRunning() { std::unique_lock lock(iteration_state_mtx_); iteration_state_ = IterationState::kRunning; - start_timestamp_ = LongToUlong(CURRENT_TIME_MILLI.count()); + start_time_ = ps::core::CommUtil::GetNowTime(); MS_LOG(INFO) << "Iteratoin " << iteration_num_ << " start global timer."; + instance_name_ = ps::PSContext::instance()->instance_name(); + if (instance_name_.empty()) { + MS_LOG(WARNING) << "instance name is empty"; + instance_name_ = "instance_" + start_time_.time_str_mill; + } global_iter_timer_->Start(std::chrono::milliseconds(global_iteration_time_window_)); } @@ -175,7 +193,7 @@ void Iteration::SetIterationEnd() { std::unique_lock lock(iteration_state_mtx_); iteration_state_ = IterationState::kCompleted; - complete_timestamp_ = LongToUlong(CURRENT_TIME_MILLI.count()); + complete_time_ = 
ps::core::CommUtil::GetNowTime(); } void Iteration::ScalingBarrier() { @@ -631,6 +649,13 @@ void Iteration::Next(bool is_iteration_valid, const std::string &reason) { round_client_num_map_[kUpdateModelAcceptClientNum] += round->kernel_accept_client_num(); round_client_num_map_[kUpdateModelRejectClientNum] += round->kernel_reject_client_num(); set_loss(loss_ + round->kernel_upload_loss()); + auto update_model_complete_info = round->GetUpdateModelCompleteInfo(); + if (update_model_complete_info.size() != kParticipationTimeLevelNum) { + continue; + } + round_client_num_map_[kParticipationTimeLevel1] += update_model_complete_info[kIndexZero].second; + round_client_num_map_[kParticipationTimeLevel2] += update_model_complete_info[kIndexOne].second; + round_client_num_map_[kParticipationTimeLevel3] += update_model_complete_info[kIndexTwo].second; } else if (round->name() == "getModel") { round_client_num_map_[kGetModelTotalClientNum] += round->kernel_total_client_num(); round_client_num_map_[kGetModelAcceptClientNum] += round->kernel_accept_client_num(); @@ -656,6 +681,13 @@ bool Iteration::BroadcastEndLastIterRequest(uint64_t last_iter_num) { } EndLastIter(); + if (iteration_fail_num_ == ps::PSContext::instance()->continuous_failure_times()) { + std::string node_role = "SERVER"; + std::string event = "Iteration failed " + std::to_string(iteration_fail_num_) + " times continuously"; + server_node_->SendFailMessageToScheduler(node_role, event); + // Finish sending one message, reset cout num to 0 + iteration_fail_num_ = 0; + } return true; } @@ -695,6 +727,13 @@ void Iteration::HandleEndLastIterRequest(const std::shared_ptrkernel_accept_client_num()); end_last_iter_rsp.set_updatemodel_reject_client_num(round->kernel_reject_client_num()); end_last_iter_rsp.set_upload_loss(round->kernel_upload_loss()); + auto update_model_complete_info = round->GetUpdateModelCompleteInfo(); + if (update_model_complete_info.size() != kParticipationTimeLevelNum) { + MS_LOG(EXCEPTION) << 
"update_model_complete_info size is not equal 3"; + } + end_last_iter_rsp.set_participation_time_level1_num(update_model_complete_info[kIndexZero].second); + end_last_iter_rsp.set_participation_time_level2_num(update_model_complete_info[kIndexOne].second); + end_last_iter_rsp.set_participation_time_level3_num(update_model_complete_info[kIndexTwo].second); } else if (round->name() == "getModel") { end_last_iter_rsp.set_getmodel_total_client_num(round->kernel_total_client_num()); end_last_iter_rsp.set_getmodel_accept_client_num(round->kernel_accept_client_num()); @@ -744,6 +783,7 @@ void Iteration::EndLastIter() { MS_ERROR_IF_NULL_WO_RET_VAL(round); round->InitkernelClientVisitedNum(); round->InitkernelClientUploadLoss(); + round->ResetParticipationTimeAndNum(); } round_client_num_map_.clear(); set_loss(0.0f); @@ -754,6 +794,11 @@ void Iteration::EndLastIter() { } else { MS_LOG(INFO) << "Move to next iteration:" << iteration_num_ << "\n"; } + if (iteration_result_.load() == IterationResult::kFail) { + iteration_fail_num_++; + } else { + iteration_fail_num_ = 0; + } } bool Iteration::ForciblyMoveToNextIteration() { @@ -768,6 +813,9 @@ bool Iteration::SummarizeIteration() { return true; } + metrics_->set_instance_name(instance_name_); + metrics_->set_start_time(start_time_); + metrics_->set_end_time(complete_time_); metrics_->set_fl_name(ps::PSContext::instance()->fl_name()); metrics_->set_fl_iteration_num(ps::PSContext::instance()->fl_iteration_num()); metrics_->set_cur_iteration_num(iteration_num_); @@ -781,12 +829,12 @@ bool Iteration::SummarizeIteration() { metrics_->set_round_client_num_map(round_client_num_map_); metrics_->set_iteration_result(iteration_result_.load()); - if (complete_timestamp_ < start_timestamp_) { - MS_LOG(ERROR) << "The complete_timestamp_: " << complete_timestamp_ << ", start_timestamp_: " << start_timestamp_ - << ". 
One of them is invalid."; + if (complete_time_.time_stamp < start_time_.time_stamp) { + MS_LOG(ERROR) << "The complete_timestamp_: " << complete_time_.time_stamp + << ", start_timestamp_: " << start_time_.time_stamp << ". One of them is invalid."; metrics_->set_iteration_time_cost(UINT64_MAX); } else { - metrics_->set_iteration_time_cost(complete_timestamp_ - start_timestamp_); + metrics_->set_iteration_time_cost(complete_time_.time_stamp - start_time_.time_stamp); } if (!metrics_->Summarize()) { @@ -910,6 +958,10 @@ void Iteration::UpdateRoundClientNumMap(const std::shared_ptr> &client_info_rsp_msg) { @@ -924,6 +976,103 @@ void Iteration::set_instance_state(InstanceState state) { instance_state_ = state; MS_LOG(INFO) << "Server instance state is " << GetInstanceStateStr(instance_state_); } + +void Iteration::SetFileConfig(const std::shared_ptr &file_configuration) { + file_configuration_ = file_configuration; +} + +string Iteration::GetDataRateFilePath() { + ps::core::FileConfig data_rate_config; + if (!ps::core::CommUtil::ParseAndCheckConfigJson(file_configuration_.get(), kDataRate, &data_rate_config)) { + MS_LOG(EXCEPTION) << "Data rate parament in config is not correct"; + } + return data_rate_config.storage_file_path; +} + +void Iteration::StartThreadToRecordDataRate() { + MS_LOG(INFO) << "Start to create a thread to record data rate"; + data_rate_thread_ = std::thread([&]() { + std::fstream file_stream; + std::string data_rate_path = GetDataRateFilePath(); + MS_LOG(DEBUG) << "The data rate file path is " << data_rate_path; + uint32_t rank_id = server_node_->rank_id(); + while (is_date_rate_thread_running_) { + // record data every 60 seconds + std::this_thread::sleep_for(std::chrono::seconds(60)); + auto time_now = std::chrono::system_clock::now(); + std::time_t tt = std::chrono::system_clock::to_time_t(time_now); + struct tm ptm; + (void)localtime_r(&tt, &ptm); + std::ostringstream time_day_oss; + time_day_oss << std::put_time(&ptm, "%Y-%m-%d"); + 
std::string time_day = time_day_oss.str(); + std::string data_rate_file = data_rate_path + "/" + time_day + "_flow_server" + std::to_string(rank_id) + ".json"; + file_stream.open(data_rate_file, std::ios::out | std::ios::app); + if (!file_stream.is_open()) { + MS_LOG(WARNING) << data_rate_file << "is not open! Please check config file!"; + return; + } + std::map send_datas; + std::map receive_datas; + for (const auto &round : rounds_) { + if (round == nullptr) { + MS_LOG(WARNING) << "round is nullptr"; + continue; + } + auto send_data = round->GetSendData(); + for (const auto &it : send_data) { + if (send_datas.find(it.first) != send_datas.end()) { + send_datas[it.first] = send_datas[it.first] + it.second; + } else { + send_datas.emplace(it); + } + } + auto receive_data = round->GetReceiveData(); + for (const auto &it : receive_data) { + if (receive_datas.find(it.first) != receive_datas.end()) { + receive_datas[it.first] = receive_datas[it.first] + it.second; + } else { + receive_datas.emplace(it); + } + } + round->ClearData(); + } + + std::map> all_datas; + for (auto &it : send_datas) { + std::vector send_and_receive_data; + send_and_receive_data.emplace_back(it.second); + send_and_receive_data.emplace_back(0); + all_datas.emplace(it.first, send_and_receive_data); + } + for (auto &it : receive_datas) { + if (all_datas.find(it.first) != all_datas.end()) { + std::vector &temp = all_datas.at(it.first); + temp[1] = it.second; + } else { + std::vector send_and_receive_data; + send_and_receive_data.emplace_back(0); + send_and_receive_data.emplace_back(it.second); + all_datas.emplace(it.first, send_and_receive_data); + } + } + for (auto &it : all_datas) { + nlohmann::json js; + auto data_time = static_cast(it.first); + struct tm data_tm; + (void)localtime_r(&data_time, &data_tm); + std::ostringstream oss_second; + oss_second << std::put_time(&data_tm, "%Y-%m-%d %H:%M:%S"); + js["time"] = oss_second.str(); + js["send"] = it.second[0]; + js["receive"] = it.second[1]; + 
file_stream << js << "\n"; + } + (void)file_stream.close(); + } + }); + return; +} } // namespace server } // namespace fl } // namespace mindspore diff --git a/mindspore/ccsrc/fl/server/iteration.h b/mindspore/ccsrc/fl/server/iteration.h index 1693dacd94a..4061b4b5821 100644 --- a/mindspore/ccsrc/fl/server/iteration.h +++ b/mindspore/ccsrc/fl/server/iteration.h @@ -22,6 +22,7 @@ #include #include #include "ps/core/communicator/communicator_base.h" +#include "ps/core/file_configuration.h" #include "fl/server/common.h" #include "fl/server/round.h" #include "fl/server/local_meta_store.h" @@ -134,14 +135,18 @@ class Iteration { void set_instance_state(InstanceState staet); + // Create a thread to record date rate + void StartThreadToRecordDataRate(); + + // Set file_configuration + void SetFileConfig(const std::shared_ptr &file_configuration); + private: Iteration() : running_round_num_(0), server_node_(nullptr), communicator_(nullptr), iteration_state_(IterationState::kCompleted), - start_timestamp_(0), - complete_timestamp_(0), iteration_loop_count_(0), iteration_num_(1), is_last_iteration_valid_(true), @@ -164,7 +169,11 @@ class Iteration { {kStartFLJobRejectClientNum, 0}, {kUpdateModelRejectClientNum, 0}, {kGetModelRejectClientNum, 0}}), - iteration_result_(IterationResult::kSuccess) { + iteration_result_(IterationResult::kSuccess), + iteration_fail_num_(0), + is_date_rate_thread_running_(true), + file_configuration_(nullptr), + instance_name_("") { LocalMetaStore::GetInstance().set_curr_iter_num(iteration_num_); } ~Iteration(); @@ -223,6 +232,8 @@ class Iteration { void StartNewInstance(); + std::string GetDataRateFilePath(); + std::shared_ptr server_node_; std::shared_ptr communicator_; @@ -236,8 +247,12 @@ class Iteration { std::mutex iteration_state_mtx_; std::condition_variable iteration_state_cv_; std::atomic iteration_state_; - uint64_t start_timestamp_; - uint64_t complete_timestamp_; + + // Iteration start time + ps::core::Time start_time_; + + // 
Iteration complete time + ps::core::Time complete_time_; // The count of iteration loops which are completed. size_t iteration_loop_count_; @@ -295,6 +310,21 @@ class Iteration { // mutex for iter move to next, avoid core dump std::mutex iter_move_mtx_; + + // The number of iteration continuous failure + uint32_t iteration_fail_num_; + + // The thread to record data rate + std::thread data_rate_thread_; + + // The state of data rate thread + std::atomic_bool is_date_rate_thread_running_; + + // The ptr of config file + std::shared_ptr file_configuration_; + + // The instance name + std::string instance_name_; }; } // namespace server } // namespace fl diff --git a/mindspore/ccsrc/fl/server/iteration_metrics.cc b/mindspore/ccsrc/fl/server/iteration_metrics.cc index 58047eb1dbb..3e1075bb32f 100644 --- a/mindspore/ccsrc/fl/server/iteration_metrics.cc +++ b/mindspore/ccsrc/fl/server/iteration_metrics.cc @@ -15,11 +15,13 @@ */ #include "fl/server/iteration_metrics.h" -#include + #include -#include "utils/file_utils.h" +#include + #include "include/common/debug/common.h" #include "ps/constants.h" +#include "utils/file_utils.h" namespace mindspore { namespace fl { @@ -28,42 +30,23 @@ bool IterationMetrics::Initialize() { config_ = std::make_unique(config_file_path_); MS_EXCEPTION_IF_NULL(config_); if (!config_->Initialize()) { - MS_LOG(EXCEPTION) << "Initializing for metrics failed. Config file path " << config_file_path_ + MS_LOG(EXCEPTION) << "Initializing for Config file path failed!" << config_file_path_ << " may be invalid or not exist."; - } - - // Read the metrics file path. If file is not set or not exits, create one. - if (!config_->Exists(kMetrics)) { - MS_LOG(WARNING) << "Metrics config is not set. 
Don't write metrics."; return false; - } else { - std::string value = config_->Get(kMetrics, ""); - nlohmann::json value_json; - try { - value_json = nlohmann::json::parse(value); - } catch (const std::exception &e) { - MS_LOG(EXCEPTION) << "The hyper-parameter data is not in json format."; - return false; - } - - // Parse the storage type. - uint32_t storage_type = JsonGetKeyWithException(value_json, ps::kStoreType); - if (std::to_string(storage_type) != ps::kFileStorage) { - MS_LOG(EXCEPTION) << "Storage type " << storage_type << " is not supported."; - return false; - } - - // Parse storage file path. - metrics_file_path_ = JsonGetKeyWithException(value_json, ps::kStoreFilePath); - auto realpath = Common::CreatePrefixPath(metrics_file_path_.c_str()); - if (!realpath.has_value()) { - MS_LOG(EXCEPTION) << "Creating path for " << metrics_file_path_ << " failed."; - return false; - } - - metrics_file_.open(realpath.value(), std::ios::app | std::ios::out); - metrics_file_.close(); } + ps::core::FileConfig metrics_config; + if (!ps::core::CommUtil::ParseAndCheckConfigJson(config_.get(), kMetrics, &metrics_config)) { + MS_LOG(WARNING) << "Metrics parament in config is not correct"; + return false; + } + metrics_file_path_ = metrics_config.storage_file_path; + auto realpath = Common::CreatePrefixPath(metrics_file_path_.c_str()); + if (!realpath.has_value()) { + MS_LOG(EXCEPTION) << "Creating path for " << metrics_file_path_ << " failed."; + return false; + } + metrics_file_.open(realpath.value(), std::ios::app | std::ios::out); + metrics_file_.close(); return true; } @@ -74,6 +57,9 @@ bool IterationMetrics::Summarize() { return false; } + js_[kInstanceName] = instance_name_; + js_[kStartTime] = start_time_.time_str_mill; + js_[kEndTime] = end_time_.time_str_mill; js_[kFLName] = fl_name_; js_[kInstanceStatus] = kInstanceStateName.at(instance_state_); js_[kFLIterationNum] = fl_iteration_num_; @@ -120,6 +106,12 @@ void IterationMetrics::set_round_client_num_map(const 
std::map kInstanceStateName = { {InstanceState::kRunning, "running"}, {InstanceState::kDisable, "disable"}, {InstanceState::kFinish, "finish"}}; @@ -80,6 +85,9 @@ class IterationMetrics { void set_iteration_time_cost(uint64_t iteration_time_cost); void set_round_client_num_map(const std::map round_client_num_map); void set_iteration_result(IterationResult iteration_result); + void set_start_time(const ps::core::Time &start_time); + void set_end_time(const ps::core::Time &end_time); + void set_instance_name(const std::string &instance_name); private: // This is the main config file set by ps context. @@ -122,6 +130,11 @@ class IterationMetrics { // Current iteration running result. IterationResult iteration_result_; + + ps::core::Time start_time_; + ps::core::Time end_time_; + + std::string instance_name_; }; } // namespace server } // namespace fl diff --git a/mindspore/ccsrc/fl/server/kernel/round/round_kernel.cc b/mindspore/ccsrc/fl/server/kernel/round/round_kernel.cc index 09c84c7da87..694d30c39e4 100644 --- a/mindspore/ccsrc/fl/server/kernel/round/round_kernel.cc +++ b/mindspore/ccsrc/fl/server/kernel/round/round_kernel.cc @@ -15,13 +15,15 @@ */ #include "fl/server/kernel/round/round_kernel.h" + +#include #include #include -#include +#include #include #include -#include #include + #include "fl/server/iteration.h" namespace mindspore { @@ -65,6 +67,7 @@ void RoundKernel::SendResponseMsg(const std::shared_ptr &message, const void *data, @@ -77,6 +80,7 @@ void RoundKernel::SendResponseMsgInference(const std::shared_ptr &message, const void *data, @@ -128,6 +132,67 @@ void RoundKernel::InitClientUploadLoss() { upload_loss_ = 0.0f; } void RoundKernel::UpdateClientUploadLoss(const float upload_loss) { upload_loss_ = upload_loss_ + upload_loss; } float RoundKernel::upload_loss() const { return upload_loss_; } + +void RoundKernel::CalculateSendData(size_t send_len) { + uint64_t second_time_stamp = + 
std::chrono::duration_cast(std::chrono::system_clock::now().time_since_epoch()).count(); + if (send_data_time_ == 0) { + send_data_ = send_len; + send_data_time_ = second_time_stamp; + return; + } + if (second_time_stamp == send_data_time_) { + send_data_ += send_len; + } else { + RecordSendData(send_data_time_, send_data_); + send_data_time_ = second_time_stamp; + send_data_ = send_len; + } +} + +void RoundKernel::CalculateReceiveData(size_t receive_len) { + uint64_t second_time_stamp = + std::chrono::duration_cast(std::chrono::system_clock::now().time_since_epoch()).count(); + if (receive_data_time_ == 0) { + receive_data_time_ = second_time_stamp; + receive_data_ = receive_len; + return; + } + if (second_time_stamp == receive_data_time_) { + receive_data_ += receive_len; + } else { + RecordReceiveData(receive_data_time_, receive_data_); + receive_data_time_ = second_time_stamp; + receive_data_ = receive_len; + } +} + +void RoundKernel::RecordSendData(uint64_t time_stamp_second, size_t send_data) { + std::lock_guard lock(send_data_rate_mutex_); + send_data_and_time_.emplace(time_stamp_second, send_data); +} + +void RoundKernel::RecordReceiveData(uint64_t time_stamp_second, size_t receive_data) { + std::lock_guard lock(receive_data_rate_mutex_); + receive_data_and_time_.emplace(time_stamp_second, receive_data); +} + +std::map RoundKernel::GetSendData() { + std::lock_guard lock(send_data_rate_mutex_); + return send_data_and_time_; +} + +std::map RoundKernel::GetReceiveData() { + std::lock_guard lock(receive_data_rate_mutex_); + return receive_data_and_time_; +} + +void RoundKernel::ClearData() { + std::lock_guard lock(send_data_rate_mutex_); + std::lock_guard lock2(receive_data_rate_mutex_); + send_data_and_time_.clear(); + receive_data_and_time_.clear(); +} } // namespace kernel } // namespace server } // namespace fl diff --git a/mindspore/ccsrc/fl/server/kernel/round/round_kernel.h b/mindspore/ccsrc/fl/server/kernel/round/round_kernel.h index 
0dbd02bd0ac..efbe0ab64ef 100644 --- a/mindspore/ccsrc/fl/server/kernel/round/round_kernel.h +++ b/mindspore/ccsrc/fl/server/kernel/round/round_kernel.h @@ -17,22 +17,23 @@ #ifndef MINDSPORE_CCSRC_FL_SERVER_KERNEL_ROUND_ROUND_KERNEL_H_ #define MINDSPORE_CCSRC_FL_SERVER_KERNEL_ROUND_ROUND_KERNEL_H_ +#include #include #include -#include -#include #include #include -#include -#include +#include #include #include -#include "kernel/common_utils.h" -#include "plugin/device/cpu/kernel/cpu_kernel.h" +#include +#include + #include "fl/server/common.h" -#include "fl/server/local_meta_store.h" #include "fl/server/distributed_count_service.h" #include "fl/server/distributed_metadata_store.h" +#include "fl/server/local_meta_store.h" +#include "kernel/common_utils.h" +#include "plugin/device/cpu/kernel/cpu_kernel.h" namespace mindspore { namespace fl { @@ -102,6 +103,24 @@ class RoundKernel { bool verifyResponse(const std::shared_ptr &message, const void *data, size_t len); + void CalculateSendData(size_t send_len); + + void CalculateReceiveData(size_t receive_len); + // Record the size of send data and the time stamp + void RecordSendData(uint64_t time_stamp_second, size_t send_data); + + // Record the size of receive data and the time stamp + void RecordReceiveData(uint64_t time_stamp_second, size_t receive_data); + + // Get the info of send data + std::map GetSendData(); + + // Get the info of receive data + std::map GetReceiveData(); + + // Clear the send data info + void ClearData(); + protected: // Send response to client, and the data can be released after the call. 
void SendResponseMsg(const std::shared_ptr &message, const void *data, size_t len); @@ -127,6 +146,26 @@ class RoundKernel { std::atomic accept_client_num_; std::atomic upload_loss_; + + // The mutex for send_data_and_time_ + std::mutex send_data_rate_mutex_; + + // The size of send data ant time + std::map send_data_and_time_; + + // The mutex for receive_data_and_time_ + std::mutex receive_data_rate_mutex_; + + // The size of receive data and time + std::map receive_data_and_time_; + + std::atomic_size_t send_data_ = 0; + + std::atomic_uint64_t send_data_time_ = 0; + + std::atomic_size_t receive_data_ = 0; + + std::atomic_uint64_t receive_data_time_ = 0; }; } // namespace kernel } // namespace server diff --git a/mindspore/ccsrc/fl/server/kernel/round/start_fl_job_kernel.cc b/mindspore/ccsrc/fl/server/kernel/round/start_fl_job_kernel.cc index 0fc97bd5eb2..f58a4aa8b3c 100644 --- a/mindspore/ccsrc/fl/server/kernel/round/start_fl_job_kernel.cc +++ b/mindspore/ccsrc/fl/server/kernel/round/start_fl_job_kernel.cc @@ -115,6 +115,9 @@ bool StartFLJobKernel::Launch(const uint8_t *req_data, size_t len, } DeviceMeta device_meta = CreateDeviceMetadata(start_fl_job_req); + uint64_t start_fl_job_time = + std::chrono::duration_cast(std::chrono::system_clock::now().time_since_epoch()).count(); + device_meta.set_now_time(start_fl_job_time); result_code = ReadyForStartFLJob(fbb, device_meta); if (result_code != ResultCode::kSuccess) { SendResponseMsg(message, fbb->GetBufferPointer(), fbb->GetSize()); diff --git a/mindspore/ccsrc/fl/server/kernel/round/update_model_kernel.cc b/mindspore/ccsrc/fl/server/kernel/round/update_model_kernel.cc index 1008d65dac3..4388bf9e84c 100644 --- a/mindspore/ccsrc/fl/server/kernel/round/update_model_kernel.cc +++ b/mindspore/ccsrc/fl/server/kernel/round/update_model_kernel.cc @@ -14,17 +14,26 @@ * limitations under the License. 
*/ +#include "fl/server/kernel/round/update_model_kernel.h" + #include #include #include -#include -#include -#include "fl/server/kernel/round/update_model_kernel.h" namespace mindspore { namespace fl { namespace server { namespace kernel { +namespace { +const size_t kLevelNum = 2; +const uint64_t kMaxLevelNum = 2880; +const uint64_t kMinLevelNum = 0; +const int kBase = 10; +const uint64_t kMinuteToSecond = 60; +const uint64_t kSecondToMills = 1000; +const uint64_t kDefaultLevel1 = 5; +const uint64_t kDefaultLevel2 = 15; +} // namespace const char *kCountForAggregation = "count_for_aggregation"; void UpdateModelKernel::InitKernel(size_t threshold_count) { @@ -49,6 +58,8 @@ void UpdateModelKernel::InitKernel(size_t threshold_count) { auto last_cnt_handler = [this](std::shared_ptr) { RunAggregation(); }; DistributedCountService::GetInstance().RegisterCounter(kCountForAggregation, threshold_count, {first_cnt_handler, last_cnt_handler}); + std::string participation_time_level_str = ps::PSContext::instance()->participation_time_level(); + CheckAndTransPara(participation_time_level_str); } bool UpdateModelKernel::VerifyUpdateModelRequest(const schema::RequestUpdateModel *update_model_req) { @@ -123,6 +134,7 @@ bool UpdateModelKernel::Launch(const uint8_t *req_data, size_t len, } std::string update_model_fl_id = update_model_req->fl_id()->str(); IncreaseAcceptClientNum(); + RecordCompletePeriod(device_meta); SendResponseMsg(message, fbb->GetBufferPointer(), fbb->GetSize()); result_code = CountForAggregation(update_model_fl_id); @@ -146,6 +158,18 @@ bool UpdateModelKernel::Reset() { void UpdateModelKernel::OnLastCountEvent(const std::shared_ptr &) {} +const std::vector> &UpdateModelKernel::GetCompletePeriodRecord() { + std::lock_guard lock(participation_time_and_num_mtx_); + return participation_time_and_num_; +} + +void UpdateModelKernel::ResetParticipationTimeAndNum() { + std::lock_guard lock(participation_time_and_num_mtx_); + for (auto &it : 
participation_time_and_num_) { + it.second = 0; + } +} + void UpdateModelKernel::RunAggregation() { auto is_last_iter_valid = Executor::GetInstance().RunAllWeightAggregation(); auto curr_iter_num = LocalMetaStore::GetInstance().curr_iter_num(); @@ -619,6 +643,66 @@ void UpdateModelKernel::BuildUpdateModelRsp(const std::shared_ptr &fb return; } +void UpdateModelKernel::RecordCompletePeriod(const DeviceMeta &device_meta) { + std::lock_guard lock(participation_time_and_num_mtx_); + uint64_t start_fl_job_time = device_meta.now_time(); + uint64_t update_model_complete_time = + std::chrono::duration_cast(std::chrono::system_clock::now().time_since_epoch()).count(); + if (start_fl_job_time >= update_model_complete_time) { + MS_LOG(WARNING) << "start_fl_job_time " << start_fl_job_time << " is larger than update_model_complete_time " + << update_model_complete_time; + return; + } + uint64_t cost_time = update_model_complete_time - start_fl_job_time; + MS_LOG(DEBUG) << "start_fl_job time is " << start_fl_job_time << " update_model time is " + << update_model_complete_time; + for (auto &it : participation_time_and_num_) { + if (cost_time < it.first) { + it.second++; + } + } +} + +void UpdateModelKernel::CheckAndTransPara(const std::string &participation_time_level) { + std::lock_guard lock(participation_time_and_num_mtx_); + // The default time level is 5min and 15min, trans time to millisecond + participation_time_and_num_.emplace_back(std::make_pair(kDefaultLevel1 * kMinuteToSecond * kSecondToMills, 0)); + participation_time_and_num_.emplace_back(std::make_pair(kDefaultLevel2 * kMinuteToSecond * kSecondToMills, 0)); + participation_time_and_num_.emplace_back(std::make_pair(UINT64_MAX, 0)); + std::vector time_levels; + std::istringstream iss(participation_time_level); + std::string output; + while (std::getline(iss, output, ',')) { + if (!output.empty()) { + time_levels.emplace_back(std::move(output)); + } + } + if (time_levels.size() != kLevelNum) { + MS_LOG(WARNING) << 
"Parameter participation_time_level is not correct"; + return; + } + uint64_t level1 = std::strtoull(time_levels[0].c_str(), nullptr, kBase); + if (level1 > kMaxLevelNum || level1 <= kMinLevelNum) { + MS_LOG(WARNING) << "Level1 partmeter " << level1 << " is not legal"; + return; + } + + uint64_t level2 = std::strtoull(time_levels[1].c_str(), nullptr, kBase); + if (level2 > kMaxLevelNum || level2 <= kMinLevelNum) { + MS_LOG(WARNING) << "Level2 partmeter " << level2 << "is not legal"; + return; + } + if (level1 >= level2) { + MS_LOG(WARNING) << "Level1 parameter " << level1 << " is larger than level2 " << level2; + return; + } + // Save the the parament of user + participation_time_and_num_.clear(); + participation_time_and_num_.emplace_back(std::make_pair(level1 * kMinuteToSecond * kSecondToMills, 0)); + participation_time_and_num_.emplace_back(std::make_pair(level2 * kMinuteToSecond * kSecondToMills, 0)); + participation_time_and_num_.emplace_back(std::make_pair(UINT64_MAX, 0)); +} + REG_ROUND_KERNEL(updateModel, UpdateModelKernel) } // namespace kernel } // namespace server diff --git a/mindspore/ccsrc/fl/server/kernel/round/update_model_kernel.h b/mindspore/ccsrc/fl/server/kernel/round/update_model_kernel.h index 7a5e8a2b942..502ac5ec441 100644 --- a/mindspore/ccsrc/fl/server/kernel/round/update_model_kernel.h +++ b/mindspore/ccsrc/fl/server/kernel/round/update_model_kernel.h @@ -18,21 +18,23 @@ #define MINDSPORE_CCSRC_FL_SERVER_KERNEL_UPDATE_MODEL_KERNEL_H_ #include -#include #include #include +#include #include +#include + #include "fl/server/common.h" +#include "fl/server/executor.h" #include "fl/server/kernel/round/round_kernel.h" #include "fl/server/kernel/round/round_kernel_factory.h" -#include "fl/server/executor.h" #include "fl/server/model_store.h" #ifdef ENABLE_ARMOUR #include "fl/armour/cipher/cipher_meta_storage.h" #endif #include "fl/compression/decode_executor.h" -#include "schema/fl_job_generated.h" #include "schema/cipher_generated.h" +#include 
"schema/fl_job_generated.h" namespace mindspore { namespace fl { @@ -55,6 +57,12 @@ class UpdateModelKernel : public RoundKernel { // In some cases, the last updateModel message means this server iteration is finished. void OnLastCountEvent(const std::shared_ptr &message) override; + // Get participation_time_and_num_ + const std::vector> &GetCompletePeriodRecord(); + + // Reset participation_time_and_num_ + void ResetParticipationTimeAndNum(); + private: ResultCode ReachThresholdForUpdateModel(const std::shared_ptr &fbb, const schema::RequestUpdateModel *update_model_req); @@ -82,6 +90,12 @@ class UpdateModelKernel : public RoundKernel { const std::shared_ptr &fbb, DeviceMeta *device_meta); bool VerifyUpdateModelRequest(const schema::RequestUpdateModel *update_model_req); + // Record complete update model number according to participation_time_level + void RecordCompletePeriod(const DeviceMeta &device_meta); + + // Check and transform participation time level parameter + void CheckAndTransPara(const std::string &participation_time_level); + + // The executor is for updating the model for updateModel request. 
Executor *executor_{nullptr}; @@ -95,6 +109,12 @@ class UpdateModelKernel : public RoundKernel { // Check upload mode bool IsCompress(const schema::RequestUpdateModel *update_model_req); + + // From StartFlJob to UpdateModel complete time and number + std::vector> participation_time_and_num_{}; + + // The mutex for participation_time_and_num_ + std::mutex participation_time_and_num_mtx_; }; } // namespace kernel } // namespace server diff --git a/mindspore/ccsrc/fl/server/round.cc b/mindspore/ccsrc/fl/server/round.cc index 856ef2906d7..55853cbd9f2 100644 --- a/mindspore/ccsrc/fl/server/round.cc +++ b/mindspore/ccsrc/fl/server/round.cc @@ -15,10 +15,14 @@ */ #include "fl/server/round.h" + #include #include -#include "fl/server/server.h" + #include "fl/server/iteration.h" +#include "fl/server/kernel/round/update_model_kernel.h" +#include "fl/server/server.h" +#include "fl/server/common.h" namespace mindspore { namespace fl { @@ -143,6 +147,7 @@ void Round::LaunchRoundKernel(const std::shared_ptr &m MS_LOG(DEBUG) << "Launching round kernel of round " + name_ + " failed."; } (void)(Iteration::GetInstance().running_round_num_--); + kernel_->CalculateReceiveData(message->len()); return; } @@ -245,6 +250,32 @@ void Round::InitkernelClientVisitedNum() { kernel_->InitClientVisitedNum(); } void Round::InitkernelClientUploadLoss() { kernel_->InitClientUploadLoss(); } float Round::kernel_upload_loss() const { return kernel_->upload_loss(); } + +std::vector> Round::GetUpdateModelCompleteInfo() const { + if (name_ == kUpdateModel) { + auto update_model_model_ptr = std::dynamic_pointer_cast(kernel_); + MS_EXCEPTION_IF_NULL(update_model_model_ptr); + return update_model_model_ptr->GetCompletePeriodRecord(); + } else { + MS_LOG(EXCEPTION) << "The kernel is not updateModel"; + return {}; + } +} + +void Round::ResetParticipationTimeAndNum() { + if (name_ == kUpdateModel) { + auto update_model_kernel_ptr = std::dynamic_pointer_cast(kernel_); + 
MS_ERROR_IF_NULL_WO_RET_VAL(update_model_kernel_ptr); + update_model_kernel_ptr->ResetParticipationTimeAndNum(); + } + return; +} + +std::map Round::GetSendData() const { return kernel_->GetSendData(); } + +std::map Round::GetReceiveData() const { return kernel_->GetReceiveData(); } + +void Round::ClearData() { return kernel_->ClearData(); } } // namespace server } // namespace fl } // namespace mindspore diff --git a/mindspore/ccsrc/fl/server/round.h b/mindspore/ccsrc/fl/server/round.h index dac0d2316f7..4260b40e969 100644 --- a/mindspore/ccsrc/fl/server/round.h +++ b/mindspore/ccsrc/fl/server/round.h @@ -19,11 +19,15 @@ #include #include -#include "ps/core/communicator/communicator_base.h" +#include +#include +#include + #include "fl/server/common.h" -#include "fl/server/iteration_timer.h" #include "fl/server/distributed_count_service.h" +#include "fl/server/iteration_timer.h" #include "fl/server/kernel/round/round_kernel.h" +#include "ps/core/communicator/communicator_base.h" namespace mindspore { namespace fl { @@ -78,6 +82,16 @@ class Round { float kernel_upload_loss() const; + std::map GetSendData() const; + + std::map GetReceiveData() const; + + std::vector> GetUpdateModelCompleteInfo() const; + + void ResetParticipationTimeAndNum(); + + void ClearData(); + private: // The callbacks which will be set to DistributedCounterService. 
void OnFirstCountEvent(const std::shared_ptr &message); diff --git a/mindspore/ccsrc/fl/server/server.cc b/mindspore/ccsrc/fl/server/server.cc index 5ad9f40ba93..baf0569ef96 100644 --- a/mindspore/ccsrc/fl/server/server.cc +++ b/mindspore/ccsrc/fl/server/server.cc @@ -91,6 +91,7 @@ void Server::Run() { RegisterRoundKernel(); InitMetrics(); Recover(); + iteration_->StartThreadToRecordDataRate(); MS_LOG(INFO) << "Server started successfully."; safemode_ = false; is_ready_ = true; @@ -261,6 +262,14 @@ void Server::InitIteration() { iteration_->InitRounds(communicators_with_worker_, time_out_cb, finish_iter_cb); iteration_->InitGlobalIterTimer(time_out_cb); + auto file_config_ptr = std::make_shared(ps::PSContext::instance()->config_file_path()); + MS_EXCEPTION_IF_NULL(file_config_ptr); + if (!file_config_ptr->Initialize()) { + MS_LOG(WARNING) << "Initializing for Config file path failed!" << ps::PSContext::instance()->config_file_path() + << " may be invalid or not exist."; + return; + } + iteration_->SetFileConfig(file_config_ptr); return; } diff --git a/mindspore/ccsrc/pipeline/jit/init.cc b/mindspore/ccsrc/pipeline/jit/init.cc index f3dd6f96125..332c145eb86 100644 --- a/mindspore/ccsrc/pipeline/jit/init.cc +++ b/mindspore/ccsrc/pipeline/jit/init.cc @@ -551,7 +551,13 @@ PYBIND11_MODULE(_c_expression, m) { .def("set_download_compress_type", &PSContext::set_download_compress_type, "Set download compress type.") .def("download_compress_type", &PSContext::download_compress_type, "Get download compress type.") .def("set_checkpoint_dir", &PSContext::set_checkpoint_dir, "Set server checkpoint directory.") - .def("checkpoint_dir", &PSContext::checkpoint_dir, "Server checkpoint directory."); + .def("checkpoint_dir", &PSContext::checkpoint_dir, "Server checkpoint directory.") + .def("set_instance_name", &PSContext::set_instance_name, "Set instance name.") + .def("instance_name", &PSContext::instance_name, "Get instance name.") + .def("set_participation_time_level", 
&PSContext::set_participation_time_level, "Set participation time level.") + .def("participation_time_level", &PSContext::participation_time_level, "Get participation time level.") + .def("set_continuous_failure_times", &PSContext::set_continuous_failure_times, "Set continuous failure times") + .def("continuous_failure_times", &PSContext::continuous_failure_times, "Get continuous failure times."); (void)m.def("_encrypt", &mindspore::pipeline::PyEncrypt, "Encrypt the data."); (void)m.def("_decrypt", &mindspore::pipeline::PyDecrypt, "Decrypt the data."); (void)m.def("_is_cipher_file", &mindspore::pipeline::PyIsCipherFile, "Determine whether the file is encrypted"); diff --git a/mindspore/ccsrc/ps/core/abstract_node.cc b/mindspore/ccsrc/ps/core/abstract_node.cc index 941aafa568a..9242733835f 100644 --- a/mindspore/ccsrc/ps/core/abstract_node.cc +++ b/mindspore/ccsrc/ps/core/abstract_node.cc @@ -15,9 +15,11 @@ */ #include "ps/core/abstract_node.h" -#include "ps/core/node_recovery.h" -#include "ps/core/communicator/tcp_communicator.h" + +#include "include/common/debug/common.h" #include "ps/core/communicator/http_communicator.h" +#include "ps/core/communicator/tcp_communicator.h" +#include "ps/core/node_recovery.h" namespace mindspore { namespace ps { @@ -68,6 +70,32 @@ void AbstractNode::Register(const std::shared_ptr &client) { } } +void AbstractNode::SendFailMessageToScheduler(const std::string &node_role, const std::string &event_info) { + auto message_meta = std::make_shared(); + MS_EXCEPTION_IF_NULL(message_meta); + message_meta->set_cmd(NodeCommand::FAILURE_EVENT_INFO); + + std::string now_time = ps::core::CommUtil::GetNowTime().time_str_mill; + FailureEventMessage failure_event_message; + failure_event_message.set_node_role(node_role); + failure_event_message.set_ip(node_info_.ip_); + failure_event_message.set_port(node_info_.port_); + failure_event_message.set_time(now_time); + failure_event_message.set_event(event_info); + + MS_LOG(INFO) << "The node role:" << 
node_role << " the node id:" << node_info_.node_id_ + << " begin to send failure message to scheduler!"; + + if (!SendMessageAsync(client_to_scheduler_, message_meta, Protos::PROTOBUF, + failure_event_message.SerializeAsString().data(), failure_event_message.ByteSizeLong())) { + MS_LOG(ERROR) << "The node role:" << node_role << " the node id:" << node_info_.node_id_ + << " send failure message timeout!"; + } else { + MS_LOG(INFO) << "The node role:" << node_role << " the node id:" << node_info_.node_id_ << " send failure message " + << event_info << " success!"; + } +} + void AbstractNode::ProcessRegisterResp(const std::shared_ptr &meta, const void *data, size_t size) { MS_EXCEPTION_IF_NULL(meta); MS_EXCEPTION_IF_NULL(data); diff --git a/mindspore/ccsrc/ps/core/abstract_node.h b/mindspore/ccsrc/ps/core/abstract_node.h index ce2b128bd34..55466ab6dc7 100644 --- a/mindspore/ccsrc/ps/core/abstract_node.h +++ b/mindspore/ccsrc/ps/core/abstract_node.h @@ -17,25 +17,25 @@ #ifndef MINDSPORE_CCSRC_PS_CORE_ABSTRACT_NODE_H_ #define MINDSPORE_CCSRC_PS_CORE_ABSTRACT_NODE_H_ -#include -#include -#include +#include #include -#include #include #include -#include +#include +#include +#include +#include -#include "ps/core/node.h" -#include "ps/core/communicator/message.h" -#include "ps/core/follower_scaler.h" -#include "utils/ms_exception.h" +#include "include/backend/visible.h" #include "ps/constants.h" +#include "ps/core/communicator/communicator_base.h" +#include "ps/core/communicator/message.h" +#include "ps/core/communicator/task_executor.h" +#include "ps/core/follower_scaler.h" +#include "ps/core/node.h" #include "ps/core/node_info.h" #include "ps/core/recovery_base.h" -#include "ps/core/communicator/task_executor.h" -#include "ps/core/communicator/communicator_base.h" +#include "utils/ms_exception.h" namespace mindspore { namespace ps { @@ -163,6 +163,9 @@ class BACKEND_EXPORT AbstractNode : public Node { // register cancel SafeMode function 
to node void SetCancelSafeModeCallBack(const CancelSafeModeFn &fn) { cancelSafeModeFn_ = fn; } + // server node and worker node send exception message to scheduler + void SendFailMessageToScheduler(const std::string &node_role, const std::string &event_info); + protected: virtual void Register(const std::shared_ptr &client); bool Heartbeat(const std::shared_ptr &client); diff --git a/mindspore/ccsrc/ps/core/comm_util.cc b/mindspore/ccsrc/ps/core/comm_util.cc index 6ce0d7c025a..f1e00a12c86 100644 --- a/mindspore/ccsrc/ps/core/comm_util.cc +++ b/mindspore/ccsrc/ps/core/comm_util.cc @@ -17,13 +17,15 @@ #include "ps/core/comm_util.h" #include -#include #include +#include + +#include #include #include #include #include -#include +#include #include namespace mindspore { @@ -600,6 +602,54 @@ bool CommUtil::StringToBool(const std::string &alive) { } return false; } + +Time CommUtil::GetNowTime() { + ps::core::Time time; + auto time_now = std::chrono::system_clock::now(); + std::time_t tt = std::chrono::system_clock::to_time_t(time_now); + struct tm ptm; + (void)localtime_r(&tt, &ptm); + std::ostringstream time_mill_oss; + time_mill_oss << std::put_time(&ptm, "%Y-%m-%d %H:%M:%S"); + + // calculate millisecond, the format of time_str_mill is 2022-01-10 20:22:20.067 + auto second_time_stamp = std::chrono::duration_cast(time_now.time_since_epoch()); + auto mill_time_stamp = std::chrono::duration_cast(time_now.time_since_epoch()); + auto ms_stamp = mill_time_stamp - second_time_stamp; + time_mill_oss << "." 
<< std::setfill('0') << std::setw(kMillSecondLength) << ms_stamp.count(); + + time.time_stamp = mill_time_stamp.count(); + time.time_str_mill = time_mill_oss.str(); + return time; +} + +bool CommUtil::ParseAndCheckConfigJson(Configuration *file_configuration, const std::string &key, + FileConfig *file_config) { + MS_EXCEPTION_IF_NULL(file_configuration); + MS_EXCEPTION_IF_NULL(file_config); + if (!file_configuration->Exists(key)) { + MS_LOG(WARNING) << key << " config is not set. Don't write."; + return false; + } else { + std::string value = file_configuration->Get(key, ""); + nlohmann::json value_json; + try { + value_json = nlohmann::json::parse(value); + } catch (const std::exception &e) { + MS_LOG(EXCEPTION) << "The hyper-parameter data is not in json format."; + } + // Parse the storage type. + uint32_t storage_type = ps::core::CommUtil::JsonGetKeyWithException(value_json, ps::kStoreType); + if (std::to_string(storage_type) != ps::kFileStorage) { + MS_LOG(EXCEPTION) << "Storage type " << storage_type << " is not supported."; + } + // Parse storage file path. 
+ std::string file_path = ps::core::CommUtil::JsonGetKeyWithException(value_json, ps::kStoreFilePath); + file_config->storage_type = storage_type; + file_config->storage_file_path = file_path; + } + return true; +} } // namespace core } // namespace ps } // namespace mindspore diff --git a/mindspore/ccsrc/ps/core/comm_util.h b/mindspore/ccsrc/ps/core/comm_util.h index e30dc19b9e1..09d54fc9d2a 100644 --- a/mindspore/ccsrc/ps/core/comm_util.h +++ b/mindspore/ccsrc/ps/core/comm_util.h @@ -19,47 +19,46 @@ #include #ifdef _MSC_VER -#include -#include -#include #include +#include +#include +#include #else -#include #include #include +#include #include #endif +#include #include #include #include #include #include #include - -#include -#include +#include #include #include -#include #include -#include +#include +#include #include +#include #include #include #include +#include #include +#include +#include #include #include #include -#include #include -#include -#include +#include #include -#include -#include #include "proto/comm.pb.h" #include "proto/ps.pb.h" @@ -78,12 +77,14 @@ constexpr int kGroup2RandomLength = 4; constexpr int kGroup3RandomLength = 4; constexpr int kGroup4RandomLength = 4; constexpr int kGroup5RandomLength = 12; +constexpr int kMillSecondLength = 3; // The size of the buffer for sending and receiving data is 4096 bytes. constexpr int kMessageChunkLength = 4096; // The timeout period for the http client to connect to the http server is 120 seconds. 
constexpr int kConnectionTimeout = 120; constexpr char kLibeventLogPrefix[] = "[libevent log]:"; +constexpr char kFailureEvent[] = "failureEvent"; // Find the corresponding string style of cluster state through the subscript of the enum:ClusterState const std::vector kClusterState = { @@ -113,6 +114,16 @@ const std::map kClusterStateMap = { {"CLUSTER_SCHEDULER_RECOVERY", ClusterState::CLUSTER_SCHEDULER_RECOVERY}, {"CLUSTER_SCALE_OUT_ROLLBACK", ClusterState::CLUSTER_SCALE_OUT_ROLLBACK}}; +struct Time { + uint64_t time_stamp; + std::string time_str_mill; +}; + +struct FileConfig { + uint32_t storage_type; + std::string storage_file_path; +}; + class CommUtil { public: static bool CheckIpWithRegex(const std::string &ip); @@ -159,6 +170,16 @@ class CommUtil { static bool CreateDirectory(const std::string &directoryPath); static bool CheckHttpUrl(const std::string &http_url); static bool IsFileReadable(const std::string &file); + template + static T JsonGetKeyWithException(const nlohmann::json &json, const std::string &key) { + if (!json.contains(key)) { + MS_LOG(EXCEPTION) << "The key " << key << "does not exist in json " << json.dump(); + } + return json[key].get(); + } + static Time GetNowTime(); + static bool ParseAndCheckConfigJson(Configuration *file_configuration, const std::string &key, + FileConfig *file_config); private: static std::random_device rd; diff --git a/mindspore/ccsrc/ps/core/communicator/message.h b/mindspore/ccsrc/ps/core/communicator/message.h index 2f907d56777..7747b97cbc1 100644 --- a/mindspore/ccsrc/ps/core/communicator/message.h +++ b/mindspore/ccsrc/ps/core/communicator/message.h @@ -32,7 +32,8 @@ enum class Command { SEND_DATA = 3, FETCH_METADATA = 4, FINISH = 5, - COLLECTIVE_SEND_DATA = 6 + COLLECTIVE_SEND_DATA = 6, + FAILURE_EVENT = 7 }; enum class Role { SERVER = 0, WORKER = 1, SCHEDULER = 2 }; diff --git a/mindspore/ccsrc/ps/core/protos/comm.proto b/mindspore/ccsrc/ps/core/protos/comm.proto index 99855a842fe..0e18cf2b52b 100644 --- 
a/mindspore/ccsrc/ps/core/protos/comm.proto +++ b/mindspore/ccsrc/ps/core/protos/comm.proto @@ -65,6 +65,8 @@ enum NodeCommand { QUERY_FINISH_TRANSFORM = 23; // This command is used to start scale out rollback SCALE_OUT_ROLLBACK = 24; + // Record the failure information, such as node restart + FAILURE_EVENT_INFO = 25; } enum NodeRole { @@ -142,6 +144,14 @@ message HeartbeatMessage { uint32 port = 5; } +message FailureEventMessage { + string node_role = 1; + string ip = 2; + uint32 port = 3; + string time = 4; + string event = 5; +} + enum NodeState { NODE_STARTING = 0; NODE_FINISH = 1; diff --git a/mindspore/ccsrc/ps/core/protos/fl.proto b/mindspore/ccsrc/ps/core/protos/fl.proto index 7960d295f75..043c0b79ec2 100644 --- a/mindspore/ccsrc/ps/core/protos/fl.proto +++ b/mindspore/ccsrc/ps/core/protos/fl.proto @@ -87,6 +87,7 @@ message DeviceMeta { string fl_name = 1; string fl_id = 2; uint64 data_size = 3; + uint64 now_time = 4; } message FLIdToDeviceMeta { @@ -257,6 +258,9 @@ message EndLastIterResponse { uint64 getModel_accept_client_num = 9; uint64 getModel_reject_client_num = 10; float upload_loss = 11; + uint64 participation_time_level1_num = 12; + uint64 participation_time_level2_num = 13; + uint64 participation_time_level3_num = 14; } message SyncAfterRecover { diff --git a/mindspore/ccsrc/ps/core/ps_scheduler_node.h b/mindspore/ccsrc/ps/core/ps_scheduler_node.h index 7f1c0082597..b129a5eb7bc 100644 --- a/mindspore/ccsrc/ps/core/ps_scheduler_node.h +++ b/mindspore/ccsrc/ps/core/ps_scheduler_node.h @@ -94,6 +94,10 @@ class BACKEND_EXPORT PSSchedulerNode : public SchedulerNode { void RecoverFromPersistence() override; + void InitEventTxtFile() override {} + + void RecordSchedulerRestartInfo() override {} + // Record received host hash name from workers or servers. std::map> host_hash_names_; // Record rank id of the nodes which sended host name. 
diff --git a/mindspore/ccsrc/ps/core/scheduler_node.cc b/mindspore/ccsrc/ps/core/scheduler_node.cc index ca68679946f..f5e41e40746 100644 --- a/mindspore/ccsrc/ps/core/scheduler_node.cc +++ b/mindspore/ccsrc/ps/core/scheduler_node.cc @@ -15,6 +15,11 @@ */ #include "ps/core/scheduler_node.h" + +#include + +#include "include/common/debug/common.h" +#include "fl/server/common.h" #include "ps/core/scheduler_recovery.h" namespace mindspore { @@ -32,10 +37,13 @@ bool SchedulerNode::Start(const uint32_t &timeout) { config_ = std::make_unique(PSContext::instance()->config_file_path()); MS_EXCEPTION_IF_NULL(config_); InitNodeMetaData(); + bool is_recover = false; if (!config_->Initialize()) { MS_LOG(WARNING) << "The config file is empty."; } else { - if (!RecoverScheduler()) { + InitEventTxtFile(); + is_recover = RecoverScheduler(); + if (!is_recover) { MS_LOG(DEBUG) << "Recover the server node is failed."; } } @@ -50,6 +58,10 @@ bool SchedulerNode::Start(const uint32_t &timeout) { StartUpdateClusterStateTimer(); RunRecovery(); + if (is_recover) { + RecordSchedulerRestartInfo(); + } + if (is_worker_timeout_) { BroadcastTimeoutEvent(); } @@ -258,6 +270,7 @@ void SchedulerNode::InitCommandHandler() { handlers_[NodeCommand::SCALE_OUT_DONE] = &SchedulerNode::ProcessScaleOutDone; handlers_[NodeCommand::SCALE_IN_DONE] = &SchedulerNode::ProcessScaleInDone; handlers_[NodeCommand::SEND_EVENT] = &SchedulerNode::ProcessSendEvent; + handlers_[NodeCommand::FAILURE_EVENT_INFO] = &SchedulerNode::ProcessFailureEvent; RegisterActorRouteTableServiceHandler(); RegisterInitCollectCommServiceHandler(); RegisterRecoveryServiceHandler(); @@ -269,6 +282,26 @@ void SchedulerNode::RegisterActorRouteTableServiceHandler() { handlers_[NodeCommand::LOOKUP_ACTOR_ROUTE] = &SchedulerNode::ProcessLookupActorRoute; } +void SchedulerNode::InitEventTxtFile() { + MS_LOG(DEBUG) << "Start init event txt"; + ps::core::FileConfig event_file_config; + if (!ps::core::CommUtil::ParseAndCheckConfigJson(config_.get(), 
kFailureEvent, &event_file_config)) { + MS_LOG(EXCEPTION) << "Parse and check config json failed"; + return; + } + + std::lock_guard lock(event_txt_file_mtx_); + event_file_path_ = event_file_config.storage_file_path; + auto realpath = Common::CreatePrefixPath(event_file_path_.c_str()); + if (!realpath.has_value()) { + MS_LOG(EXCEPTION) << "Creating path for " << event_file_path_ << " failed."; + return; + } + event_txt_file_.open(realpath.value(), std::ios::out | std::ios::app); + event_txt_file_.close(); + MS_LOG(DEBUG) << "Init event txt success!"; +} + void SchedulerNode::CreateTcpServer() { node_manager_.InitNode(); @@ -628,6 +661,34 @@ void SchedulerNode::ProcessLookupActorRoute(const std::shared_ptr &se } } +void SchedulerNode::ProcessFailureEvent(const std::shared_ptr &server, + const std::shared_ptr &conn, + const std::shared_ptr &meta, const void *data, size_t size) { + MS_ERROR_IF_NULL_WO_RET_VAL(server); + MS_ERROR_IF_NULL_WO_RET_VAL(conn); + MS_ERROR_IF_NULL_WO_RET_VAL(meta); + MS_ERROR_IF_NULL_WO_RET_VAL(data); + FailureEventMessage failure_event_message; + failure_event_message.ParseFromArray(data, SizeToInt(size)); + std::string node_role = failure_event_message.node_role(); + std::string ip = failure_event_message.ip(); + uint32_t port = failure_event_message.port(); + std::string time = failure_event_message.time(); + std::string event = failure_event_message.event(); + std::lock_guard lock(event_txt_file_mtx_); + event_txt_file_.open(event_file_path_, std::ios::out | std::ios::app); + if (!event_txt_file_.is_open()) { + MS_LOG(EXCEPTION) << "The event txt file is not open"; + return; + } + std::string event_info = "nodeRole:" + node_role + "," + ip + ":" + std::to_string(port) + "," + + "currentTime:" + time + "," + "event:" + event + ";"; + event_txt_file_ << event_info << "\n"; + (void)event_txt_file_.flush(); + event_txt_file_.close(); + MS_LOG(INFO) << "Process failure event success!"; +} + bool 
SchedulerNode::SendPrepareBuildingNetwork(const std::unordered_map &node_infos) { uint64_t request_id = AddMessageTrack(node_infos.size()); for (const auto &kvs : node_infos) { @@ -1598,7 +1659,7 @@ bool SchedulerNode::RecoverScheduler() { MS_EXCEPTION_IF_NULL(config_); if (config_->Exists(kKeyRecovery)) { MS_LOG(INFO) << "The scheduler node is support recovery."; - scheduler_recovery_ = std::make_unique(); + scheduler_recovery_ = std::make_shared(); MS_EXCEPTION_IF_NULL(scheduler_recovery_); bool ret = scheduler_recovery_->Initialize(config_->Get(kKeyRecovery, "")); bool ret_node = scheduler_recovery_->InitializeNodes(config_->Get(kKeyRecovery, "")); @@ -1610,6 +1671,29 @@ bool SchedulerNode::RecoverScheduler() { return false; } +void SchedulerNode::RecordSchedulerRestartInfo() { + MS_LOG(DEBUG) << "Start to record scheduler restart error message"; + std::lock_guard lock(event_txt_file_mtx_); + event_txt_file_.open(event_file_path_, std::ios::out | std::ios::app); + if (!event_txt_file_.is_open()) { + MS_LOG(EXCEPTION) << "The event txt file is not open"; + return; + } + std::string node_role = CommUtil::NodeRoleToString(node_info_.node_role_); + auto scheduler_recovery_ptr = std::dynamic_pointer_cast(scheduler_recovery_); + MS_EXCEPTION_IF_NULL(scheduler_recovery_ptr); + auto ip = scheduler_recovery_ptr->GetMetadata(kRecoverySchedulerIp); + auto port = scheduler_recovery_ptr->GetMetadata(kRecoverySchedulerPort); + std::string time = ps::core::CommUtil::GetNowTime().time_str_mill; + std::string event = "Node restart"; + std::string event_info = + "nodeRole:" + node_role + "," + ip + ":" + port + "," + "currentTime:" + time + "," + "event:" + event + ";"; + event_txt_file_ << event_info << "\n"; + (void)event_txt_file_.flush(); + event_txt_file_.close(); + MS_LOG(DEBUG) << "Record scheduler node restart info " << event_info << " success!"; +} + void SchedulerNode::PersistMetaData() { if (scheduler_recovery_ == nullptr) { MS_LOG(WARNING) << "scheduler recovery is 
null, do not persist meta data"; diff --git a/mindspore/ccsrc/ps/core/scheduler_node.h b/mindspore/ccsrc/ps/core/scheduler_node.h index 757c6dc9107..ab7a110e1c0 100644 --- a/mindspore/ccsrc/ps/core/scheduler_node.h +++ b/mindspore/ccsrc/ps/core/scheduler_node.h @@ -88,6 +88,7 @@ class BACKEND_EXPORT SchedulerNode : public Node { // Register and initialize the actor route table service. void RegisterActorRouteTableServiceHandler(); void InitializeActorRouteTableService(); + virtual void InitEventTxtFile(); // Register collective communication initialization service. virtual void RegisterInitCollectCommServiceHandler() {} @@ -138,6 +139,10 @@ class BACKEND_EXPORT SchedulerNode : public Node { void ProcessLookupActorRoute(const std::shared_ptr &server, const std::shared_ptr &conn, const std::shared_ptr &meta, const void *data, size_t size); + // Process failure event message from other nodes. + void ProcessFailureEvent(const std::shared_ptr &server, const std::shared_ptr &conn, + const std::shared_ptr &meta, const void *data, size_t size); + // Determine whether the registration request of the node should be rejected, the registration of the // alive node should be rejected. virtual bool NeedRejectRegister(const NodeInfo &node_info) { return false; } @@ -201,6 +206,9 @@ class BACKEND_EXPORT SchedulerNode : public Node { bool RecoverScheduler(); + // Write scheduler restart error message + virtual void RecordSchedulerRestartInfo(); + void PersistMetaData(); bool CheckIfNodeDisconnected() const; @@ -245,7 +253,7 @@ class BACKEND_EXPORT SchedulerNode : public Node { std::unordered_map callbacks_; // Used to persist and obtain metadata information for scheduler. - std::unique_ptr scheduler_recovery_; + std::shared_ptr scheduler_recovery_; // persistent command need to be sent. 
std::atomic persistent_cmd_; @@ -259,6 +267,15 @@ class BACKEND_EXPORT SchedulerNode : public Node { std::unordered_map register_connection_fd_; std::unique_ptr actor_route_table_service_; + + // The event txt file path + std::string event_file_path_; + + // The mutex for event txt event_file_path_ + std::mutex event_txt_file_mtx_; + + // The fstream for event_file_path_ + std::fstream event_txt_file_; }; } // namespace core } // namespace ps diff --git a/mindspore/ccsrc/ps/core/server_node.cc b/mindspore/ccsrc/ps/core/server_node.cc index 9931725e48f..7fe3a47a984 100644 --- a/mindspore/ccsrc/ps/core/server_node.cc +++ b/mindspore/ccsrc/ps/core/server_node.cc @@ -42,6 +42,7 @@ void ServerNode::Initialize() { config_ = std::make_unique(PSContext::instance()->config_file_path()); MS_EXCEPTION_IF_NULL(config_); InitNodeNum(); + bool is_recover = false; if (!config_->Initialize()) { MS_LOG(WARNING) << "The config file is empty."; } else { @@ -62,6 +63,10 @@ void ServerNode::Initialize() { } InitClientToServer(); is_already_stopped_ = false; + if (is_recover) { + std::string node_role = CommUtil::NodeRoleToString(node_info_.node_role_); + SendFailMessageToScheduler(node_role, "Node restart"); + } MS_LOG(INFO) << "[Server start]: 3. 
Server node crete tcp client to scheduler successful!"; } diff --git a/mindspore/ccsrc/ps/core/worker_node.cc b/mindspore/ccsrc/ps/core/worker_node.cc index 442d1a13ca2..9f397bb39f3 100644 --- a/mindspore/ccsrc/ps/core/worker_node.cc +++ b/mindspore/ccsrc/ps/core/worker_node.cc @@ -41,10 +41,12 @@ void WorkerNode::Initialize() { config_ = std::make_unique(PSContext::instance()->config_file_path()); MS_EXCEPTION_IF_NULL(config_); InitNodeNum(); + bool is_recover = false; if (!config_->Initialize()) { MS_LOG(WARNING) << "The config file is empty."; } else { - if (!Recover()) { + is_recover = Recover(); + if (!is_recover) { MS_LOG(DEBUG) << "Recover the worker node is failed."; } } @@ -60,6 +62,10 @@ void WorkerNode::Initialize() { } InitClientToServer(); is_already_stopped_ = false; + if (is_recover) { + std::string node_role = CommUtil::NodeRoleToString(node_info_.node_role_); + SendFailMessageToScheduler(node_role, "Node restart"); + } MS_LOG(INFO) << "[Worker start]: 3. Worker node crete tcp client to scheduler successful!"; } diff --git a/mindspore/ccsrc/ps/ps_context.cc b/mindspore/ccsrc/ps/ps_context.cc index 8f8fc1d66f6..f32f8880653 100644 --- a/mindspore/ccsrc/ps/ps_context.cc +++ b/mindspore/ccsrc/ps/ps_context.cc @@ -15,13 +15,14 @@ */ #include "ps/ps_context.h" + +#include "kernel/kernel.h" #include "utils/log_adapter.h" #include "utils/ms_utils.h" -#include "kernel/kernel.h" #if ((defined ENABLE_CPU) && (!defined _WIN32)) +#include "distributed/cluster/cluster_context.h" #include "ps/ps_cache/ps_cache_manager.h" #include "ps/ps_cache/ps_data/ps_data_prefetch.h" -#include "distributed/cluster/cluster_context.h" #else #include "distributed/cluster/dummy_cluster_context.h" #endif @@ -566,5 +567,21 @@ std::string PSContext::download_compress_type() const { return download_compress std::string PSContext::checkpoint_dir() const { return checkpoint_dir_; } void PSContext::set_checkpoint_dir(const std::string &checkpoint_dir) { checkpoint_dir_ = checkpoint_dir; } 
+ +void PSContext::set_instance_name(const std::string &instance_name) { instance_name_ = instance_name; } + +const std::string &PSContext::instance_name() const { return instance_name_; } + +void PSContext::set_participation_time_level(const std::string &participation_time_level) { + participation_time_level_ = participation_time_level; +} + +const std::string &PSContext::participation_time_level() { return participation_time_level_; } + +void PSContext::set_continuous_failure_times(uint32_t continuous_failure_times) { + continuous_failure_times_ = continuous_failure_times; +} + +uint32_t PSContext::continuous_failure_times() { return continuous_failure_times_; } } // namespace ps } // namespace mindspore diff --git a/mindspore/ccsrc/ps/ps_context.h b/mindspore/ccsrc/ps/ps_context.h index 0245ef87492..9042461925b 100644 --- a/mindspore/ccsrc/ps/ps_context.h +++ b/mindspore/ccsrc/ps/ps_context.h @@ -244,6 +244,15 @@ class BACKEND_EXPORT PSContext { std::string checkpoint_dir() const; void set_checkpoint_dir(const std::string &checkpoint_dir); + void set_instance_name(const std::string &instance_name); + const std::string &instance_name() const; + + void set_participation_time_level(const std::string &participation_time_level); + const std::string &participation_time_level(); + + void set_continuous_failure_times(uint32_t continuous_failure_times); + uint32_t continuous_failure_times(); + private: PSContext() : ps_enabled_(false), @@ -300,7 +309,10 @@ class BACKEND_EXPORT PSContext { upload_compress_type_(kNoCompressType), upload_sparse_rate_(0.4f), download_compress_type_(kNoCompressType), - checkpoint_dir_("") {} + checkpoint_dir_(""), + instance_name_(""), + participation_time_level_("5,15"), + continuous_failure_times_(10) {} bool ps_enabled_; bool is_worker_; bool is_pserver_; @@ -442,6 +454,15 @@ class BACKEND_EXPORT PSContext { // directory of server checkpoint std::string checkpoint_dir_; + + // The name of instance + std::string instance_name_; + + // The 
participation time level + std::string participation_time_level_; + + // The times of iteration continuous failure + uint32_t continuous_failure_times_; }; } // namespace ps } // namespace mindspore diff --git a/mindspore/python/mindspore/parallel/_ps_context.py b/mindspore/python/mindspore/parallel/_ps_context.py index b3fad327aa7..117b62e6179 100644 --- a/mindspore/python/mindspore/parallel/_ps_context.py +++ b/mindspore/python/mindspore/parallel/_ps_context.py @@ -83,6 +83,9 @@ _set_ps_context_func_map = { "upload_compress_type": ps_context().set_upload_compress_type, "upload_sparse_rate": ps_context().set_upload_sparse_rate, "download_compress_type": ps_context().set_download_compress_type, + "instance_name": ps_context().set_instance_name, + "participation_time_level": ps_context().set_participation_time_level, + "continuous_failure_times": ps_context().set_continuous_failure_times, } _get_ps_context_func_map = { @@ -134,6 +137,9 @@ _get_ps_context_func_map = { "upload_compress_type": ps_context().upload_compress_type, "upload_sparse_rate": ps_context().upload_sparse_rate, "download_compress_type": ps_context().download_compress_type, + "instance_name": ps_context().instance_name, + "participation_time_level": ps_context().participation_time_level, + "continuous_failure_times": ps_context().continuous_failure_times, } _check_positive_int_keys = ["server_num", "scheduler_port", "fl_server_port", diff --git a/scripts/fl_restful_tool.py b/scripts/fl_restful_tool.py index c0dbd840f37..e3d15e937fd 100644 --- a/scripts/fl_restful_tool.py +++ b/scripts/fl_restful_tool.py @@ -178,10 +178,6 @@ def call_get_instance_detail(): return process_self_define_json(Status.FAILED.value, "error. 
metrics file is not existed.") ans_json_obj = {} - metrics_auc_list = [] - metrics_loss_list = [] - iteration_execution_time_list = [] - client_visited_info_list = [] with open(metrics_file_path, 'r') as f: metrics_list = f.readlines() @@ -189,28 +185,13 @@ def call_get_instance_detail(): if not metrics_list: return process_self_define_json(Status.FAILED.value, "error. metrics file has no content") - for metrics in metrics_list: - json_obj = json.loads(metrics) - iteration_execution_time_list.append(json_obj['iterationExecutionTime']) - client_visited_info_list.append(json_obj['clientVisitedInfo']) - metrics_auc_list.append(json_obj['metricsAuc']) - metrics_loss_list.append(json_obj['metricsLoss']) last_metrics = metrics_list[len(metrics_list) - 1] last_metrics_obj = json.loads(last_metrics) ans_json_obj["code"] = Status.SUCCESS.value ans_json_obj["describe"] = "get instance metrics detail successful." - ans_json_obj["result"] = {} - ans_json_result = ans_json_obj.get("result") - ans_json_result['currentIteration'] = last_metrics_obj['currentIteration'] - ans_json_result['flIterationNum'] = last_metrics_obj['flIterationNum'] - ans_json_result['flName'] = last_metrics_obj['flName'] - ans_json_result['instanceStatus'] = last_metrics_obj['instanceStatus'] - ans_json_result['iterationExecutionTime'] = iteration_execution_time_list - ans_json_result['clientVisitedInfo'] = client_visited_info_list - ans_json_result['metricsAuc'] = metrics_auc_list - ans_json_result['metricsLoss'] = metrics_loss_list + ans_json_obj["result"] = last_metrics_obj return json.dumps(ans_json_obj) diff --git a/tests/st/fl/albert/config.json b/tests/st/fl/albert/config.json index db2d3571180..75e7e4ee2f1 100644 --- a/tests/st/fl/albert/config.json +++ b/tests/st/fl/albert/config.json @@ -15,6 +15,14 @@ "storage_type": 1, "storage_file_path": "metrics.json" }, + "dataRate": { + "storage_type": 1, + "storage_file_path": ".." 
+ }, + "failureEvent": { + "storage_type": 1, + "storage_file_path": "event.txt" + }, "server_recovery": { "storage_type": 1, "storage_file_path": "../server_recovery.json" diff --git a/tests/st/fl/cross_device_lenet/cloud/config.json b/tests/st/fl/cross_device_lenet/cloud/config.json index 674338a2fdc..7473e792498 100644 --- a/tests/st/fl/cross_device_lenet/cloud/config.json +++ b/tests/st/fl/cross_device_lenet/cloud/config.json @@ -15,6 +15,14 @@ "storage_type": 1, "storage_file_path": "metrics.json" }, + "dataRate": { + "storage_type": 1, + "storage_file_path": ".." + }, + "failureEvent": { + "storage_type": 1, + "storage_file_path": "event.txt" + }, "server_recovery": { "storage_type": 1, "storage_file_path": "../server_recovery.json" diff --git a/tests/st/fl/cross_silo_faster_rcnn/config.json b/tests/st/fl/cross_silo_faster_rcnn/config.json index db2d3571180..75e7e4ee2f1 100644 --- a/tests/st/fl/cross_silo_faster_rcnn/config.json +++ b/tests/st/fl/cross_silo_faster_rcnn/config.json @@ -15,6 +15,14 @@ "storage_type": 1, "storage_file_path": "metrics.json" }, + "dataRate": { + "storage_type": 1, + "storage_file_path": ".." + }, + "failureEvent": { + "storage_type": 1, + "storage_file_path": "event.txt" + }, "server_recovery": { "storage_type": 1, "storage_file_path": "../server_recovery.json" diff --git a/tests/st/fl/cross_silo_femnist/config.json b/tests/st/fl/cross_silo_femnist/config.json index db2d3571180..75e7e4ee2f1 100644 --- a/tests/st/fl/cross_silo_femnist/config.json +++ b/tests/st/fl/cross_silo_femnist/config.json @@ -15,6 +15,14 @@ "storage_type": 1, "storage_file_path": "metrics.json" }, + "dataRate": { + "storage_type": 1, + "storage_file_path": ".." 
+ }, + "failureEvent": { + "storage_type": 1, + "storage_file_path": "event.txt" + }, "server_recovery": { "storage_type": 1, "storage_file_path": "../server_recovery.json" diff --git a/tests/st/fl/cross_silo_lenet/config.json b/tests/st/fl/cross_silo_lenet/config.json index db2d3571180..75e7e4ee2f1 100644 --- a/tests/st/fl/cross_silo_lenet/config.json +++ b/tests/st/fl/cross_silo_lenet/config.json @@ -15,6 +15,14 @@ "storage_type": 1, "storage_file_path": "metrics.json" }, + "dataRate": { + "storage_type": 1, + "storage_file_path": ".." + }, + "failureEvent": { + "storage_type": 1, + "storage_file_path": "event.txt" + }, "server_recovery": { "storage_type": 1, "storage_file_path": "../server_recovery.json" diff --git a/tests/st/fl/hybrid_lenet/config.json b/tests/st/fl/hybrid_lenet/config.json index db2d3571180..75e7e4ee2f1 100644 --- a/tests/st/fl/hybrid_lenet/config.json +++ b/tests/st/fl/hybrid_lenet/config.json @@ -15,6 +15,14 @@ "storage_type": 1, "storage_file_path": "metrics.json" }, + "dataRate": { + "storage_type": 1, + "storage_file_path": ".." + }, + "failureEvent": { + "storage_type": 1, + "storage_file_path": "event.txt" + }, "server_recovery": { "storage_type": 1, "storage_file_path": "../server_recovery.json" diff --git a/tests/st/fl/mobile/config.json b/tests/st/fl/mobile/config.json index db2d3571180..75e7e4ee2f1 100644 --- a/tests/st/fl/mobile/config.json +++ b/tests/st/fl/mobile/config.json @@ -15,6 +15,14 @@ "storage_type": 1, "storage_file_path": "metrics.json" }, + "dataRate": { + "storage_type": 1, + "storage_file_path": ".." + }, + "failureEvent": { + "storage_type": 1, + "storage_file_path": "event.txt" + }, "server_recovery": { "storage_type": 1, "storage_file_path": "../server_recovery.json"