Add maintainability functions for FL

This commit is contained in:
zhou_chao1993 2022-04-22 14:29:24 +08:00
parent 123e1b97de
commit 51f6b77ab0
37 changed files with 900 additions and 126 deletions

View File

@ -23,6 +23,7 @@
#include <climits>
#include <memory>
#include <functional>
#include <iomanip>
#include "proto/ps.pb.h"
#include "proto/fl.pb.h"
#include "ir/anf.h"
@ -149,11 +150,15 @@ constexpr auto kUpdateModelRejectClientNum = "updateModelRejectClientNum";
constexpr auto kGetModelTotalClientNum = "getModelTotalClientNum";
constexpr auto kGetModelAcceptClientNum = "getModelAcceptClientNum";
constexpr auto kGetModelRejectClientNum = "getModelRejectClientNum";
constexpr auto kParticipationTimeLevel1 = "participationTimeLevel1";
constexpr auto kParticipationTimeLevel2 = "participationTimeLevel2";
constexpr auto kParticipationTimeLevel3 = "participationTimeLevel3";
constexpr auto kMinVal = "min_val";
constexpr auto kMaxVal = "max_val";
constexpr auto kQuant = "QUANT";
constexpr auto kDiffSparseQuant = "DIFF_SPARSE_QUANT";
constexpr auto kNoCompress = "NO_COMPRESS";
constexpr auto kUpdateModel = "updateModel";
// OptimParamNameToIndex represents each input/workspace/output parameter's offset when an optimizer kernel is
// launched.

View File

@ -15,27 +15,40 @@
*/
#include "fl/server/iteration.h"
#include <memory>
#include <vector>
#include <string>
#include <numeric>
#include <string>
#include <unordered_map>
#include <vector>
#include "fl/server/model_store.h"
#include "fl/server/server.h"
#include "ps/core/comm_util.h"
namespace mindspore {
namespace fl {
namespace server {
namespace {
const size_t kParticipationTimeLevelNum = 3;
const size_t kIndexZero = 0;
const size_t kIndexOne = 1;
const size_t kIndexTwo = 2;
const size_t kLastSecond = 59;
} // namespace
class Server;
Iteration::~Iteration() {
move_to_next_thread_running_ = false;
is_date_rate_thread_running_ = false;
next_iteration_cv_.notify_all();
if (move_to_next_thread_.joinable()) {
move_to_next_thread_.join();
}
if (data_rate_thread_.joinable()) {
data_rate_thread_.join();
}
}
void Iteration::RegisterMessageCallback(const std::shared_ptr<ps::core::TcpCommunicator> &communicator) {
MS_EXCEPTION_IF_NULL(communicator);
communicator_ = communicator;
@ -160,8 +173,13 @@ void Iteration::SetIterationRunning() {
std::unique_lock<std::mutex> lock(iteration_state_mtx_);
iteration_state_ = IterationState::kRunning;
start_timestamp_ = LongToUlong(CURRENT_TIME_MILLI.count());
start_time_ = ps::core::CommUtil::GetNowTime();
MS_LOG(INFO) << "Iteration " << iteration_num_ << " starts the global timer.";
instance_name_ = ps::PSContext::instance()->instance_name();
if (instance_name_.empty()) {
MS_LOG(WARNING) << "instance name is empty";
instance_name_ = "instance_" + start_time_.time_str_mill;
}
global_iter_timer_->Start(std::chrono::milliseconds(global_iteration_time_window_));
}
@ -175,7 +193,7 @@ void Iteration::SetIterationEnd() {
std::unique_lock<std::mutex> lock(iteration_state_mtx_);
iteration_state_ = IterationState::kCompleted;
complete_timestamp_ = LongToUlong(CURRENT_TIME_MILLI.count());
complete_time_ = ps::core::CommUtil::GetNowTime();
}
void Iteration::ScalingBarrier() {
@ -631,6 +649,13 @@ void Iteration::Next(bool is_iteration_valid, const std::string &reason) {
round_client_num_map_[kUpdateModelAcceptClientNum] += round->kernel_accept_client_num();
round_client_num_map_[kUpdateModelRejectClientNum] += round->kernel_reject_client_num();
set_loss(loss_ + round->kernel_upload_loss());
auto update_model_complete_info = round->GetUpdateModelCompleteInfo();
if (update_model_complete_info.size() != kParticipationTimeLevelNum) {
continue;
}
round_client_num_map_[kParticipationTimeLevel1] += update_model_complete_info[kIndexZero].second;
round_client_num_map_[kParticipationTimeLevel2] += update_model_complete_info[kIndexOne].second;
round_client_num_map_[kParticipationTimeLevel3] += update_model_complete_info[kIndexTwo].second;
} else if (round->name() == "getModel") {
round_client_num_map_[kGetModelTotalClientNum] += round->kernel_total_client_num();
round_client_num_map_[kGetModelAcceptClientNum] += round->kernel_accept_client_num();
@ -656,6 +681,13 @@ bool Iteration::BroadcastEndLastIterRequest(uint64_t last_iter_num) {
}
EndLastIter();
if (iteration_fail_num_ == ps::PSContext::instance()->continuous_failure_times()) {
std::string node_role = "SERVER";
std::string event = "Iteration failed " + std::to_string(iteration_fail_num_) + " times continuously";
server_node_->SendFailMessageToScheduler(node_role, event);
// After sending the message, reset the continuous failure count to 0
iteration_fail_num_ = 0;
}
return true;
}
@ -695,6 +727,13 @@ void Iteration::HandleEndLastIterRequest(const std::shared_ptr<ps::core::Message
end_last_iter_rsp.set_updatemodel_accept_client_num(round->kernel_accept_client_num());
end_last_iter_rsp.set_updatemodel_reject_client_num(round->kernel_reject_client_num());
end_last_iter_rsp.set_upload_loss(round->kernel_upload_loss());
auto update_model_complete_info = round->GetUpdateModelCompleteInfo();
if (update_model_complete_info.size() != kParticipationTimeLevelNum) {
MS_LOG(EXCEPTION) << "The size of update_model_complete_info is not equal to 3";
}
end_last_iter_rsp.set_participation_time_level1_num(update_model_complete_info[kIndexZero].second);
end_last_iter_rsp.set_participation_time_level2_num(update_model_complete_info[kIndexOne].second);
end_last_iter_rsp.set_participation_time_level3_num(update_model_complete_info[kIndexTwo].second);
} else if (round->name() == "getModel") {
end_last_iter_rsp.set_getmodel_total_client_num(round->kernel_total_client_num());
end_last_iter_rsp.set_getmodel_accept_client_num(round->kernel_accept_client_num());
@ -744,6 +783,7 @@ void Iteration::EndLastIter() {
MS_ERROR_IF_NULL_WO_RET_VAL(round);
round->InitkernelClientVisitedNum();
round->InitkernelClientUploadLoss();
round->ResetParticipationTimeAndNum();
}
round_client_num_map_.clear();
set_loss(0.0f);
@ -754,6 +794,11 @@ void Iteration::EndLastIter() {
} else {
MS_LOG(INFO) << "Move to next iteration:" << iteration_num_ << "\n";
}
if (iteration_result_.load() == IterationResult::kFail) {
iteration_fail_num_++;
} else {
iteration_fail_num_ = 0;
}
}
bool Iteration::ForciblyMoveToNextIteration() {
@ -768,6 +813,9 @@ bool Iteration::SummarizeIteration() {
return true;
}
metrics_->set_instance_name(instance_name_);
metrics_->set_start_time(start_time_);
metrics_->set_end_time(complete_time_);
metrics_->set_fl_name(ps::PSContext::instance()->fl_name());
metrics_->set_fl_iteration_num(ps::PSContext::instance()->fl_iteration_num());
metrics_->set_cur_iteration_num(iteration_num_);
@ -781,12 +829,12 @@ bool Iteration::SummarizeIteration() {
metrics_->set_round_client_num_map(round_client_num_map_);
metrics_->set_iteration_result(iteration_result_.load());
if (complete_timestamp_ < start_timestamp_) {
MS_LOG(ERROR) << "The complete_timestamp_: " << complete_timestamp_ << ", start_timestamp_: " << start_timestamp_
<< ". One of them is invalid.";
if (complete_time_.time_stamp < start_time_.time_stamp) {
MS_LOG(ERROR) << "The complete_timestamp_: " << complete_time_.time_stamp
<< ", start_timestamp_: " << start_time_.time_stamp << ". One of them is invalid.";
metrics_->set_iteration_time_cost(UINT64_MAX);
} else {
metrics_->set_iteration_time_cost(complete_timestamp_ - start_timestamp_);
metrics_->set_iteration_time_cost(complete_time_.time_stamp - start_time_.time_stamp);
}
if (!metrics_->Summarize()) {
@ -910,6 +958,10 @@ void Iteration::UpdateRoundClientNumMap(const std::shared_ptr<std::vector<unsign
round_client_num_map_[kGetModelTotalClientNum] += end_last_iter_rsp.getmodel_total_client_num();
round_client_num_map_[kGetModelAcceptClientNum] += end_last_iter_rsp.getmodel_accept_client_num();
round_client_num_map_[kGetModelRejectClientNum] += end_last_iter_rsp.getmodel_reject_client_num();
round_client_num_map_[kParticipationTimeLevel1] += end_last_iter_rsp.participation_time_level1_num();
round_client_num_map_[kParticipationTimeLevel2] += end_last_iter_rsp.participation_time_level2_num();
round_client_num_map_[kParticipationTimeLevel3] += end_last_iter_rsp.participation_time_level3_num();
}
void Iteration::UpdateRoundClientUploadLoss(const std::shared_ptr<std::vector<unsigned char>> &client_info_rsp_msg) {
@ -924,6 +976,103 @@ void Iteration::set_instance_state(InstanceState state) {
instance_state_ = state;
MS_LOG(INFO) << "Server instance state is " << GetInstanceStateStr(instance_state_);
}
void Iteration::SetFileConfig(const std::shared_ptr<ps::core::FileConfiguration> &file_configuration) {
file_configuration_ = file_configuration;
}
std::string Iteration::GetDataRateFilePath() {
ps::core::FileConfig data_rate_config;
if (!ps::core::CommUtil::ParseAndCheckConfigJson(file_configuration_.get(), kDataRate, &data_rate_config)) {
MS_LOG(EXCEPTION) << "Data rate parameter in config is not correct";
}
return data_rate_config.storage_file_path;
}
void Iteration::StartThreadToRecordDataRate() {
MS_LOG(INFO) << "Start to create a thread to record data rate";
data_rate_thread_ = std::thread([&]() {
std::fstream file_stream;
std::string data_rate_path = GetDataRateFilePath();
MS_LOG(DEBUG) << "The data rate file path is " << data_rate_path;
uint32_t rank_id = server_node_->rank_id();
while (is_date_rate_thread_running_) {
// record data every 60 seconds
std::this_thread::sleep_for(std::chrono::seconds(60));
auto time_now = std::chrono::system_clock::now();
std::time_t tt = std::chrono::system_clock::to_time_t(time_now);
struct tm ptm;
(void)localtime_r(&tt, &ptm);
std::ostringstream time_day_oss;
time_day_oss << std::put_time(&ptm, "%Y-%m-%d");
std::string time_day = time_day_oss.str();
std::string data_rate_file = data_rate_path + "/" + time_day + "_flow_server" + std::to_string(rank_id) + ".json";
file_stream.open(data_rate_file, std::ios::out | std::ios::app);
if (!file_stream.is_open()) {
MS_LOG(WARNING) << data_rate_file << " is not open! Please check the config file!";
return;
}
std::map<uint64_t, size_t> send_datas;
std::map<uint64_t, size_t> receive_datas;
for (const auto &round : rounds_) {
if (round == nullptr) {
MS_LOG(WARNING) << "round is nullptr";
continue;
}
auto send_data = round->GetSendData();
for (const auto &it : send_data) {
if (send_datas.find(it.first) != send_datas.end()) {
send_datas[it.first] = send_datas[it.first] + it.second;
} else {
send_datas.emplace(it);
}
}
auto receive_data = round->GetReceiveData();
for (const auto &it : receive_data) {
if (receive_datas.find(it.first) != receive_datas.end()) {
receive_datas[it.first] = receive_datas[it.first] + it.second;
} else {
receive_datas.emplace(it);
}
}
round->ClearData();
}
std::map<uint64_t, std::vector<size_t>> all_datas;
for (auto &it : send_datas) {
std::vector<size_t> send_and_receive_data;
send_and_receive_data.emplace_back(it.second);
send_and_receive_data.emplace_back(0);
all_datas.emplace(it.first, send_and_receive_data);
}
for (auto &it : receive_datas) {
if (all_datas.find(it.first) != all_datas.end()) {
std::vector<size_t> &temp = all_datas.at(it.first);
temp[1] = it.second;
} else {
std::vector<size_t> send_and_receive_data;
send_and_receive_data.emplace_back(0);
send_and_receive_data.emplace_back(it.second);
all_datas.emplace(it.first, send_and_receive_data);
}
}
for (auto &it : all_datas) {
nlohmann::json js;
auto data_time = static_cast<time_t>(it.first);
struct tm data_tm;
(void)localtime_r(&data_time, &data_tm);
std::ostringstream oss_second;
oss_second << std::put_time(&data_tm, "%Y-%m-%d %H:%M:%S");
js["time"] = oss_second.str();
js["send"] = it.second[0];
js["receive"] = it.second[1];
file_stream << js << "\n";
}
(void)file_stream.close();
}
});
return;
}
} // namespace server
} // namespace fl
} // namespace mindspore
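For context, a minimal standalone sketch (not part of this commit; it assumes POSIX localtime_r and the nlohmann/json single header) of the per-second flow record that the data-rate thread above appends, one JSON object per line, to the <date>_flow_server<rank>.json file:
#include <cstddef>
#include <cstdint>
#include <ctime>
#include <iomanip>
#include <iostream>
#include <map>
#include <sstream>
#include <vector>
#include <nlohmann/json.hpp>
int main() {
  // Second-level timestamp -> {bytes sent, bytes received}, as aggregated from
  // Round::GetSendData()/GetReceiveData() in the real thread; values are illustrative.
  std::map<uint64_t, std::vector<size_t>> all_datas = {{1650608964, {4096, 2048}}};
  for (const auto &it : all_datas) {
    nlohmann::json js;
    auto data_time = static_cast<time_t>(it.first);
    struct tm data_tm;
    (void)localtime_r(&data_time, &data_tm);
    std::ostringstream oss_second;
    oss_second << std::put_time(&data_tm, "%Y-%m-%d %H:%M:%S");
    js["time"] = oss_second.str();
    js["send"] = it.second[0];
    js["receive"] = it.second[1];
    // Prints something like: {"receive":2048,"send":4096,"time":"2022-04-22 14:29:24"}
    std::cout << js << "\n";
  }
  return 0;
}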

View File

@ -22,6 +22,7 @@
#include <string>
#include <map>
#include "ps/core/communicator/communicator_base.h"
#include "ps/core/file_configuration.h"
#include "fl/server/common.h"
#include "fl/server/round.h"
#include "fl/server/local_meta_store.h"
@ -134,14 +135,18 @@ class Iteration {
void set_instance_state(InstanceState state);
// Create a thread to record the data rate
void StartThreadToRecordDataRate();
// Set file_configuration
void SetFileConfig(const std::shared_ptr<ps::core::FileConfiguration> &file_configuration);
private:
Iteration()
: running_round_num_(0),
server_node_(nullptr),
communicator_(nullptr),
iteration_state_(IterationState::kCompleted),
start_timestamp_(0),
complete_timestamp_(0),
iteration_loop_count_(0),
iteration_num_(1),
is_last_iteration_valid_(true),
@ -164,7 +169,11 @@ class Iteration {
{kStartFLJobRejectClientNum, 0},
{kUpdateModelRejectClientNum, 0},
{kGetModelRejectClientNum, 0}}),
iteration_result_(IterationResult::kSuccess) {
iteration_result_(IterationResult::kSuccess),
iteration_fail_num_(0),
is_date_rate_thread_running_(true),
file_configuration_(nullptr),
instance_name_("") {
LocalMetaStore::GetInstance().set_curr_iter_num(iteration_num_);
}
~Iteration();
@ -223,6 +232,8 @@ class Iteration {
void StartNewInstance();
std::string GetDataRateFilePath();
std::shared_ptr<ps::core::ServerNode> server_node_;
std::shared_ptr<ps::core::TcpCommunicator> communicator_;
@ -236,8 +247,12 @@ class Iteration {
std::mutex iteration_state_mtx_;
std::condition_variable iteration_state_cv_;
std::atomic<IterationState> iteration_state_;
uint64_t start_timestamp_;
uint64_t complete_timestamp_;
// Iteration start time
ps::core::Time start_time_;
// Iteration complete time
ps::core::Time complete_time_;
// The count of iteration loops which are completed.
size_t iteration_loop_count_;
@ -295,6 +310,21 @@ class Iteration {
// mutex for iter move to next, avoid core dump
std::mutex iter_move_mtx_;
// The count of continuous iteration failures
uint32_t iteration_fail_num_;
// The thread to record data rate
std::thread data_rate_thread_;
// The state of data rate thread
std::atomic_bool is_date_rate_thread_running_;
// The pointer to the config file
std::shared_ptr<ps::core::FileConfiguration> file_configuration_;
// The instance name
std::string instance_name_;
};
} // namespace server
} // namespace fl

View File

@ -15,11 +15,13 @@
*/
#include "fl/server/iteration_metrics.h"
#include <string>
#include <fstream>
#include "utils/file_utils.h"
#include <string>
#include "include/common/debug/common.h"
#include "ps/constants.h"
#include "utils/file_utils.h"
namespace mindspore {
namespace fl {
@ -28,42 +30,23 @@ bool IterationMetrics::Initialize() {
config_ = std::make_unique<ps::core::FileConfiguration>(config_file_path_);
MS_EXCEPTION_IF_NULL(config_);
if (!config_->Initialize()) {
MS_LOG(EXCEPTION) << "Initializing for metrics failed. Config file path " << config_file_path_
MS_LOG(EXCEPTION) << "Initializing the config file failed! Config file path " << config_file_path_
<< " may be invalid or not exist.";
}
// Read the metrics file path. If the file is not set or does not exist, create one.
if (!config_->Exists(kMetrics)) {
MS_LOG(WARNING) << "Metrics config is not set. Don't write metrics.";
return false;
} else {
std::string value = config_->Get(kMetrics, "");
nlohmann::json value_json;
try {
value_json = nlohmann::json::parse(value);
} catch (const std::exception &e) {
MS_LOG(EXCEPTION) << "The hyper-parameter data is not in json format.";
return false;
}
// Parse the storage type.
uint32_t storage_type = JsonGetKeyWithException<uint32_t>(value_json, ps::kStoreType);
if (std::to_string(storage_type) != ps::kFileStorage) {
MS_LOG(EXCEPTION) << "Storage type " << storage_type << " is not supported.";
return false;
}
// Parse storage file path.
metrics_file_path_ = JsonGetKeyWithException<std::string>(value_json, ps::kStoreFilePath);
auto realpath = Common::CreatePrefixPath(metrics_file_path_.c_str());
if (!realpath.has_value()) {
MS_LOG(EXCEPTION) << "Creating path for " << metrics_file_path_ << " failed.";
return false;
}
metrics_file_.open(realpath.value(), std::ios::app | std::ios::out);
metrics_file_.close();
}
ps::core::FileConfig metrics_config;
if (!ps::core::CommUtil::ParseAndCheckConfigJson(config_.get(), kMetrics, &metrics_config)) {
MS_LOG(WARNING) << "Metrics parameter in config is not correct";
return false;
}
metrics_file_path_ = metrics_config.storage_file_path;
auto realpath = Common::CreatePrefixPath(metrics_file_path_.c_str());
if (!realpath.has_value()) {
MS_LOG(EXCEPTION) << "Creating path for " << metrics_file_path_ << " failed.";
return false;
}
metrics_file_.open(realpath.value(), std::ios::app | std::ios::out);
metrics_file_.close();
return true;
}
@ -74,6 +57,9 @@ bool IterationMetrics::Summarize() {
return false;
}
js_[kInstanceName] = instance_name_;
js_[kStartTime] = start_time_.time_str_mill;
js_[kEndTime] = end_time_.time_str_mill;
js_[kFLName] = fl_name_;
js_[kInstanceStatus] = kInstanceStateName.at(instance_state_);
js_[kFLIterationNum] = fl_iteration_num_;
@ -120,6 +106,12 @@ void IterationMetrics::set_round_client_num_map(const std::map<std::string, size
}
void IterationMetrics::set_iteration_result(IterationResult iteration_result) { iteration_result_ = iteration_result; }
void IterationMetrics::set_start_time(const ps::core::Time &start_time) { start_time_ = start_time; }
void IterationMetrics::set_end_time(const ps::core::Time &end_time) { end_time_ = end_time; }
void IterationMetrics::set_instance_name(const std::string &instance_name) { instance_name_ = instance_name; }
} // namespace server
} // namespace fl
} // namespace mindspore

View File

@ -26,10 +26,12 @@
#include "ps/core/file_configuration.h"
#include "fl/server/local_meta_store.h"
#include "fl/server/iteration.h"
#include "ps/core/comm_util.h"
namespace mindspore {
namespace fl {
namespace server {
constexpr auto kInstanceName = "instanceName";
constexpr auto kFLName = "flName";
constexpr auto kInstanceStatus = "instanceStatus";
constexpr auto kFLIterationNum = "flIterationNum";
@ -40,6 +42,9 @@ constexpr auto kIterExecutionTime = "iterationExecutionTime";
constexpr auto kMetrics = "metrics";
constexpr auto kClientVisitedInfo = "clientVisitedInfo";
constexpr auto kIterationResult = "iterationResult";
constexpr auto kStartTime = "startTime";
constexpr auto kEndTime = "endTime";
constexpr auto kDataRate = "dataRate";
const std::map<InstanceState, std::string> kInstanceStateName = {
{InstanceState::kRunning, "running"}, {InstanceState::kDisable, "disable"}, {InstanceState::kFinish, "finish"}};
@ -80,6 +85,9 @@ class IterationMetrics {
void set_iteration_time_cost(uint64_t iteration_time_cost);
void set_round_client_num_map(const std::map<std::string, size_t> round_client_num_map);
void set_iteration_result(IterationResult iteration_result);
void set_start_time(const ps::core::Time &start_time);
void set_end_time(const ps::core::Time &end_time);
void set_instance_name(const std::string &instance_name);
private:
// This is the main config file set by ps context.
@ -122,6 +130,11 @@ class IterationMetrics {
// Current iteration running result.
IterationResult iteration_result_;
ps::core::Time start_time_;
ps::core::Time end_time_;
std::string instance_name_;
};
} // namespace server
} // namespace fl

View File

@ -15,13 +15,15 @@
*/
#include "fl/server/kernel/round/round_kernel.h"
#include <chrono>
#include <mutex>
#include <queue>
#include <chrono>
#include <string>
#include <thread>
#include <utility>
#include <string>
#include <vector>
#include "fl/server/iteration.h"
namespace mindspore {
@ -65,6 +67,7 @@ void RoundKernel::SendResponseMsg(const std::shared_ptr<ps::core::MessageHandler
MS_LOG(WARNING) << "Sending response failed.";
return;
}
CalculateSendData(len);
}
void RoundKernel::SendResponseMsgInference(const std::shared_ptr<ps::core::MessageHandler> &message, const void *data,
@ -77,6 +80,7 @@ void RoundKernel::SendResponseMsgInference(const std::shared_ptr<ps::core::Messa
MS_LOG(WARNING) << "Sending response failed.";
return;
}
CalculateSendData(len);
}
bool RoundKernel::verifyResponse(const std::shared_ptr<ps::core::MessageHandler> &message, const void *data,
@ -128,6 +132,67 @@ void RoundKernel::InitClientUploadLoss() { upload_loss_ = 0.0f; }
void RoundKernel::UpdateClientUploadLoss(const float upload_loss) { upload_loss_ = upload_loss_ + upload_loss; }
float RoundKernel::upload_loss() const { return upload_loss_; }
void RoundKernel::CalculateSendData(size_t send_len) {
uint64_t second_time_stamp =
std::chrono::duration_cast<std::chrono::seconds>(std::chrono::system_clock::now().time_since_epoch()).count();
if (send_data_time_ == 0) {
send_data_ = send_len;
send_data_time_ = second_time_stamp;
return;
}
if (second_time_stamp == send_data_time_) {
send_data_ += send_len;
} else {
RecordSendData(send_data_time_, send_data_);
send_data_time_ = second_time_stamp;
send_data_ = send_len;
}
}
void RoundKernel::CalculateReceiveData(size_t receive_len) {
uint64_t second_time_stamp =
std::chrono::duration_cast<std::chrono::seconds>(std::chrono::system_clock::now().time_since_epoch()).count();
if (receive_data_time_ == 0) {
receive_data_time_ = second_time_stamp;
receive_data_ = receive_len;
return;
}
if (second_time_stamp == receive_data_time_) {
receive_data_ += receive_len;
} else {
RecordReceiveData(receive_data_time_, receive_data_);
receive_data_time_ = second_time_stamp;
receive_data_ = receive_len;
}
}
void RoundKernel::RecordSendData(uint64_t time_stamp_second, size_t send_data) {
std::lock_guard<std::mutex> lock(send_data_rate_mutex_);
send_data_and_time_.emplace(time_stamp_second, send_data);
}
void RoundKernel::RecordReceiveData(uint64_t time_stamp_second, size_t receive_data) {
std::lock_guard<std::mutex> lock(receive_data_rate_mutex_);
receive_data_and_time_.emplace(time_stamp_second, receive_data);
}
std::map<uint64_t, size_t> RoundKernel::GetSendData() {
std::lock_guard<std::mutex> lock(send_data_rate_mutex_);
return send_data_and_time_;
}
std::map<uint64_t, size_t> RoundKernel::GetReceiveData() {
std::lock_guard<std::mutex> lock(receive_data_rate_mutex_);
return receive_data_and_time_;
}
void RoundKernel::ClearData() {
std::lock_guard<std::mutex> lock(send_data_rate_mutex_);
std::lock_guard<std::mutex> lock2(receive_data_rate_mutex_);
send_data_and_time_.clear();
receive_data_and_time_.clear();
}
} // namespace kernel
} // namespace server
} // namespace fl
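A minimal standalone sketch (illustrative only, not MindSpore code) of the per-second accounting that CalculateSendData/CalculateReceiveData perform above: bytes observed within the same epoch second are summed, and when the second rolls over the finished bucket is flushed into the time-stamped map later returned by GetSendData()/GetReceiveData():
#include <chrono>
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <map>
#include <thread>
int main() {
  std::map<uint64_t, size_t> send_data_and_time;  // epoch second -> accumulated bytes
  uint64_t current_second = 0;
  size_t current_bytes = 0;
  auto record = [&](size_t send_len) {
    uint64_t now =
      std::chrono::duration_cast<std::chrono::seconds>(std::chrono::system_clock::now().time_since_epoch()).count();
    if (current_second == 0) {  // first sample
      current_second = now;
      current_bytes = send_len;
    } else if (now == current_second) {  // same second: accumulate
      current_bytes += send_len;
    } else {  // new second: flush the finished bucket, start a new one
      send_data_and_time.emplace(current_second, current_bytes);
      current_second = now;
      current_bytes = send_len;
    }
  };
  record(1024);
  record(2048);  // most likely the same second, so the bucket grows to 3072 bytes
  std::this_thread::sleep_for(std::chrono::seconds(1));
  record(512);   // a new second: the 3072-byte bucket is flushed into the map
  for (const auto &it : send_data_and_time) {
    std::cout << it.first << " -> " << it.second << " bytes\n";
  }
  return 0;
}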

View File

@ -17,22 +17,23 @@
#ifndef MINDSPORE_CCSRC_FL_SERVER_KERNEL_ROUND_ROUND_KERNEL_H_
#define MINDSPORE_CCSRC_FL_SERVER_KERNEL_ROUND_ROUND_KERNEL_H_
#include <chrono>
#include <map>
#include <memory>
#include <string>
#include <vector>
#include <mutex>
#include <queue>
#include <utility>
#include <chrono>
#include <string>
#include <thread>
#include <unordered_map>
#include "kernel/common_utils.h"
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include <utility>
#include <vector>
#include "fl/server/common.h"
#include "fl/server/local_meta_store.h"
#include "fl/server/distributed_count_service.h"
#include "fl/server/distributed_metadata_store.h"
#include "fl/server/local_meta_store.h"
#include "kernel/common_utils.h"
#include "plugin/device/cpu/kernel/cpu_kernel.h"
namespace mindspore {
namespace fl {
@ -102,6 +103,24 @@ class RoundKernel {
bool verifyResponse(const std::shared_ptr<ps::core::MessageHandler> &message, const void *data, size_t len);
void CalculateSendData(size_t send_len);
void CalculateReceiveData(size_t receive_len);
// Record the size of send data and the time stamp
void RecordSendData(uint64_t time_stamp_second, size_t send_data);
// Record the size of receive data and the time stamp
void RecordReceiveData(uint64_t time_stamp_second, size_t receive_data);
// Get the info of send data
std::map<uint64_t, size_t> GetSendData();
// Get the info of receive data
std::map<uint64_t, size_t> GetReceiveData();
// Clear both the send and receive data info
void ClearData();
protected:
// Send response to client, and the data can be released after the call.
void SendResponseMsg(const std::shared_ptr<ps::core::MessageHandler> &message, const void *data, size_t len);
@ -127,6 +146,26 @@ class RoundKernel {
std::atomic<size_t> accept_client_num_;
std::atomic<float> upload_loss_;
// The mutex for send_data_and_time_
std::mutex send_data_rate_mutex_;
// The size of sent data and its timestamp
std::map<uint64_t, size_t> send_data_and_time_;
// The mutex for receive_data_and_time_
std::mutex receive_data_rate_mutex_;
// The size of received data and its timestamp
std::map<uint64_t, size_t> receive_data_and_time_;
std::atomic_size_t send_data_ = 0;
std::atomic_uint64_t send_data_time_ = 0;
std::atomic_size_t receive_data_ = 0;
std::atomic_uint64_t receive_data_time_ = 0;
};
} // namespace kernel
} // namespace server

View File

@ -115,6 +115,9 @@ bool StartFLJobKernel::Launch(const uint8_t *req_data, size_t len,
}
DeviceMeta device_meta = CreateDeviceMetadata(start_fl_job_req);
uint64_t start_fl_job_time =
std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::system_clock::now().time_since_epoch()).count();
device_meta.set_now_time(start_fl_job_time);
result_code = ReadyForStartFLJob(fbb, device_meta);
if (result_code != ResultCode::kSuccess) {
SendResponseMsg(message, fbb->GetBufferPointer(), fbb->GetSize());

View File

@ -14,17 +14,26 @@
* limitations under the License.
*/
#include "fl/server/kernel/round/update_model_kernel.h"
#include <map>
#include <memory>
#include <string>
#include <vector>
#include <utility>
#include "fl/server/kernel/round/update_model_kernel.h"
namespace mindspore {
namespace fl {
namespace server {
namespace kernel {
namespace {
const size_t kLevelNum = 2;
const uint64_t kMaxLevelNum = 2880;
const uint64_t kMinLevelNum = 0;
const int kBase = 10;
const uint64_t kMinuteToSecond = 60;
const uint64_t kSecondToMills = 1000;
const uint64_t kDefaultLevel1 = 5;
const uint64_t kDefaultLevel2 = 15;
} // namespace
const char *kCountForAggregation = "count_for_aggregation";
void UpdateModelKernel::InitKernel(size_t threshold_count) {
@ -49,6 +58,8 @@ void UpdateModelKernel::InitKernel(size_t threshold_count) {
auto last_cnt_handler = [this](std::shared_ptr<ps::core::MessageHandler>) { RunAggregation(); };
DistributedCountService::GetInstance().RegisterCounter(kCountForAggregation, threshold_count,
{first_cnt_handler, last_cnt_handler});
std::string participation_time_level_str = ps::PSContext::instance()->participation_time_level();
CheckAndTransPara(participation_time_level_str);
}
bool UpdateModelKernel::VerifyUpdateModelRequest(const schema::RequestUpdateModel *update_model_req) {
@ -123,6 +134,7 @@ bool UpdateModelKernel::Launch(const uint8_t *req_data, size_t len,
}
std::string update_model_fl_id = update_model_req->fl_id()->str();
IncreaseAcceptClientNum();
RecordCompletePeriod(device_meta);
SendResponseMsg(message, fbb->GetBufferPointer(), fbb->GetSize());
result_code = CountForAggregation(update_model_fl_id);
@ -146,6 +158,18 @@ bool UpdateModelKernel::Reset() {
void UpdateModelKernel::OnLastCountEvent(const std::shared_ptr<ps::core::MessageHandler> &) {}
const std::vector<std::pair<uint64_t, uint32_t>> &UpdateModelKernel::GetCompletePeriodRecord() {
std::lock_guard<std::mutex> lock(participation_time_and_num_mtx_);
return participation_time_and_num_;
}
void UpdateModelKernel::ResetParticipationTimeAndNum() {
std::lock_guard<std::mutex> lock(participation_time_and_num_mtx_);
for (auto &it : participation_time_and_num_) {
it.second = 0;
}
}
void UpdateModelKernel::RunAggregation() {
auto is_last_iter_valid = Executor::GetInstance().RunAllWeightAggregation();
auto curr_iter_num = LocalMetaStore::GetInstance().curr_iter_num();
@ -619,6 +643,66 @@ void UpdateModelKernel::BuildUpdateModelRsp(const std::shared_ptr<FBBuilder> &fb
return;
}
void UpdateModelKernel::RecordCompletePeriod(const DeviceMeta &device_meta) {
std::lock_guard<std::mutex> lock(participation_time_and_num_mtx_);
uint64_t start_fl_job_time = device_meta.now_time();
uint64_t update_model_complete_time =
std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::system_clock::now().time_since_epoch()).count();
if (start_fl_job_time >= update_model_complete_time) {
MS_LOG(WARNING) << "start_fl_job_time " << start_fl_job_time << " is not earlier than update_model_complete_time "
<< update_model_complete_time;
return;
}
uint64_t cost_time = update_model_complete_time - start_fl_job_time;
MS_LOG(DEBUG) << "start_fl_job time is " << start_fl_job_time << " update_model time is "
<< update_model_complete_time;
for (auto &it : participation_time_and_num_) {
if (cost_time < it.first) {
it.second++;
}
}
}
void UpdateModelKernel::CheckAndTransPara(const std::string &participation_time_level) {
std::lock_guard<std::mutex> lock(participation_time_and_num_mtx_);
// The default time levels are 5 min and 15 min; convert them to milliseconds
participation_time_and_num_.emplace_back(std::make_pair(kDefaultLevel1 * kMinuteToSecond * kSecondToMills, 0));
participation_time_and_num_.emplace_back(std::make_pair(kDefaultLevel2 * kMinuteToSecond * kSecondToMills, 0));
participation_time_and_num_.emplace_back(std::make_pair(UINT64_MAX, 0));
std::vector<std::string> time_levels;
std::istringstream iss(participation_time_level);
std::string output;
while (std::getline(iss, output, ',')) {
if (!output.empty()) {
time_levels.emplace_back(std::move(output));
}
}
if (time_levels.size() != kLevelNum) {
MS_LOG(WARNING) << "Parameter participation_time_level is not correct";
return;
}
uint64_t level1 = std::strtoull(time_levels[0].c_str(), nullptr, kBase);
if (level1 > kMaxLevelNum || level1 <= kMinLevelNum) {
MS_LOG(WARNING) << "Level1 parameter " << level1 << " is not legal";
return;
}
uint64_t level2 = std::strtoull(time_levels[1].c_str(), nullptr, kBase);
if (level2 > kMaxLevelNum || level2 <= kMinLevelNum) {
MS_LOG(WARNING) << "Level2 parameter " << level2 << " is not legal";
return;
}
if (level1 >= level2) {
MS_LOG(WARNING) << "Level1 parameter " << level1 << " is not less than level2 " << level2;
return;
}
// Save the user-specified parameters
participation_time_and_num_.clear();
participation_time_and_num_.emplace_back(std::make_pair(level1 * kMinuteToSecond * kSecondToMills, 0));
participation_time_and_num_.emplace_back(std::make_pair(level2 * kMinuteToSecond * kSecondToMills, 0));
participation_time_and_num_.emplace_back(std::make_pair(UINT64_MAX, 0));
}
REG_ROUND_KERNEL(updateModel, UpdateModelKernel)
} // namespace kernel
} // namespace server
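A minimal standalone sketch (hypothetical values, not MindSpore code) of how a participation_time_level string such as the default "5,15" becomes the three buckets used by CheckAndTransPara and RecordCompletePeriod above; a client is counted in every bucket whose threshold its startFLJob-to-updateModel time beats, so the counts are cumulative:
#include <cstdint>
#include <iostream>
#include <utility>
#include <vector>
int main() {
  const uint64_t kMinuteToSecond = 60;
  const uint64_t kSecondToMills = 1000;
  // "5,15" -> level1 = 5 minutes, level2 = 15 minutes, plus a catch-all third bucket.
  std::vector<std::pair<uint64_t, uint32_t>> participation_time_and_num = {
    {5 * kMinuteToSecond * kSecondToMills, 0},   // completed within 5 minutes
    {15 * kMinuteToSecond * kSecondToMills, 0},  // completed within 15 minutes
    {UINT64_MAX, 0}};                            // everything slower
  // A client that took 7 minutes (420000 ms) from startFLJob to updateModel completion.
  uint64_t cost_time = 7 * kMinuteToSecond * kSecondToMills;
  for (auto &it : participation_time_and_num) {
    if (cost_time < it.first) {
      it.second++;  // mirrors RecordCompletePeriod
    }
  }
  // Prints: 0 1 1 (counted in the 15-minute and catch-all buckets, not the 5-minute one)
  std::cout << participation_time_and_num[0].second << " " << participation_time_and_num[1].second << " "
            << participation_time_and_num[2].second << std::endl;
  return 0;
}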

View File

@ -18,21 +18,23 @@
#define MINDSPORE_CCSRC_FL_SERVER_KERNEL_UPDATE_MODEL_KERNEL_H_
#include <map>
#include <unordered_map>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include <utility>
#include "fl/server/common.h"
#include "fl/server/executor.h"
#include "fl/server/kernel/round/round_kernel.h"
#include "fl/server/kernel/round/round_kernel_factory.h"
#include "fl/server/executor.h"
#include "fl/server/model_store.h"
#ifdef ENABLE_ARMOUR
#include "fl/armour/cipher/cipher_meta_storage.h"
#endif
#include "fl/compression/decode_executor.h"
#include "schema/fl_job_generated.h"
#include "schema/cipher_generated.h"
#include "schema/fl_job_generated.h"
namespace mindspore {
namespace fl {
@ -55,6 +57,12 @@ class UpdateModelKernel : public RoundKernel {
// In some cases, the last updateModel message means this server iteration is finished.
void OnLastCountEvent(const std::shared_ptr<ps::core::MessageHandler> &message) override;
// Get participation_time_and_num_
const std::vector<std::pair<uint64_t, uint32_t>> &GetCompletePeriodRecord();
// Reset participation_time_and_num_
void ResetParticipationTimeAndNum();
private:
ResultCode ReachThresholdForUpdateModel(const std::shared_ptr<FBBuilder> &fbb,
const schema::RequestUpdateModel *update_model_req);
@ -82,6 +90,12 @@ class UpdateModelKernel : public RoundKernel {
const std::shared_ptr<FBBuilder> &fbb, DeviceMeta *device_meta);
bool VerifyUpdateModelRequest(const schema::RequestUpdateModel *update_model_req);
// Record the number of clients that complete updateModel within each participation time level
void RecordCompletePeriod(const DeviceMeta &device_meta);
// Check and transform the participation time level parameter
void CheckAndTransPara(const std::string &participation_time_level);
// The executor is for updating the model for updateModel request.
Executor *executor_{nullptr};
@ -95,6 +109,12 @@ class UpdateModelKernel : public RoundKernel {
// Check upload mode
bool IsCompress(const schema::RequestUpdateModel *update_model_req);
// Time levels (from startFLJob to updateModel completion) and the client count for each level
std::vector<std::pair<uint64_t, uint32_t>> participation_time_and_num_{};
// The mutex for participation_time_and_num_
std::mutex participation_time_and_num_mtx_;
};
} // namespace kernel
} // namespace server

View File

@ -15,10 +15,14 @@
*/
#include "fl/server/round.h"
#include <memory>
#include <string>
#include "fl/server/server.h"
#include "fl/server/iteration.h"
#include "fl/server/kernel/round/update_model_kernel.h"
#include "fl/server/server.h"
#include "fl/server/common.h"
namespace mindspore {
namespace fl {
@ -143,6 +147,7 @@ void Round::LaunchRoundKernel(const std::shared_ptr<ps::core::MessageHandler> &m
MS_LOG(DEBUG) << "Launching round kernel of round " + name_ + " failed.";
}
(void)(Iteration::GetInstance().running_round_num_--);
kernel_->CalculateReceiveData(message->len());
return;
}
@ -245,6 +250,32 @@ void Round::InitkernelClientVisitedNum() { kernel_->InitClientVisitedNum(); }
void Round::InitkernelClientUploadLoss() { kernel_->InitClientUploadLoss(); }
float Round::kernel_upload_loss() const { return kernel_->upload_loss(); }
std::vector<std::pair<uint64_t, uint32_t>> Round::GetUpdateModelCompleteInfo() const {
if (name_ == kUpdateModel) {
auto update_model_model_ptr = std::dynamic_pointer_cast<fl::server::kernel::UpdateModelKernel>(kernel_);
MS_EXCEPTION_IF_NULL(update_model_model_ptr);
return update_model_model_ptr->GetCompletePeriodRecord();
} else {
MS_LOG(EXCEPTION) << "The kernel is not updateModel";
return {};
}
}
void Round::ResetParticipationTimeAndNum() {
if (name_ == kUpdateModel) {
auto update_model_kernel_ptr = std::dynamic_pointer_cast<fl::server::kernel::UpdateModelKernel>(kernel_);
MS_ERROR_IF_NULL_WO_RET_VAL(update_model_kernel_ptr);
update_model_kernel_ptr->ResetParticipationTimeAndNum();
}
return;
}
std::map<uint64_t, size_t> Round::GetSendData() const { return kernel_->GetSendData(); }
std::map<uint64_t, size_t> Round::GetReceiveData() const { return kernel_->GetReceiveData(); }
void Round::ClearData() { return kernel_->ClearData(); }
} // namespace server
} // namespace fl
} // namespace mindspore

View File

@ -19,11 +19,15 @@
#include <memory>
#include <string>
#include "ps/core/communicator/communicator_base.h"
#include <utility>
#include <vector>
#include <map>
#include "fl/server/common.h"
#include "fl/server/iteration_timer.h"
#include "fl/server/distributed_count_service.h"
#include "fl/server/iteration_timer.h"
#include "fl/server/kernel/round/round_kernel.h"
#include "ps/core/communicator/communicator_base.h"
namespace mindspore {
namespace fl {
@ -78,6 +82,16 @@ class Round {
float kernel_upload_loss() const;
std::map<uint64_t, size_t> GetSendData() const;
std::map<uint64_t, size_t> GetReceiveData() const;
std::vector<std::pair<uint64_t, uint32_t>> GetUpdateModelCompleteInfo() const;
void ResetParticipationTimeAndNum();
void ClearData();
private:
// The callbacks which will be set to DistributedCounterService.
void OnFirstCountEvent(const std::shared_ptr<ps::core::MessageHandler> &message);

View File

@ -91,6 +91,7 @@ void Server::Run() {
RegisterRoundKernel();
InitMetrics();
Recover();
iteration_->StartThreadToRecordDataRate();
MS_LOG(INFO) << "Server started successfully.";
safemode_ = false;
is_ready_ = true;
@ -261,6 +262,14 @@ void Server::InitIteration() {
iteration_->InitRounds(communicators_with_worker_, time_out_cb, finish_iter_cb);
iteration_->InitGlobalIterTimer(time_out_cb);
auto file_config_ptr = std::make_shared<ps::core::FileConfiguration>(ps::PSContext::instance()->config_file_path());
MS_EXCEPTION_IF_NULL(file_config_ptr);
if (!file_config_ptr->Initialize()) {
MS_LOG(WARNING) << "Initializing the config file failed! Config file path " << ps::PSContext::instance()->config_file_path()
<< " may be invalid or not exist.";
return;
}
iteration_->SetFileConfig(file_config_ptr);
return;
}

View File

@ -551,7 +551,13 @@ PYBIND11_MODULE(_c_expression, m) {
.def("set_download_compress_type", &PSContext::set_download_compress_type, "Set download compress type.")
.def("download_compress_type", &PSContext::download_compress_type, "Get download compress type.")
.def("set_checkpoint_dir", &PSContext::set_checkpoint_dir, "Set server checkpoint directory.")
.def("checkpoint_dir", &PSContext::checkpoint_dir, "Server checkpoint directory.");
.def("checkpoint_dir", &PSContext::checkpoint_dir, "Server checkpoint directory.")
.def("set_instance_name", &PSContext::set_instance_name, "Set instance name.")
.def("instance_name", &PSContext::instance_name, "Get instance name.")
.def("set_participation_time_level", &PSContext::set_participation_time_level, "Set participation time level.")
.def("participation_time_level", &PSContext::participation_time_level, "Get participation time level.")
.def("set_continuous_failure_times", &PSContext::set_continuous_failure_times, "Set continuous failure times.")
.def("continuous_failure_times", &PSContext::continuous_failure_times, "Get continuous failure times.");
(void)m.def("_encrypt", &mindspore::pipeline::PyEncrypt, "Encrypt the data.");
(void)m.def("_decrypt", &mindspore::pipeline::PyDecrypt, "Decrypt the data.");
(void)m.def("_is_cipher_file", &mindspore::pipeline::PyIsCipherFile, "Determine whether the file is encrypted");

View File

@ -15,9 +15,11 @@
*/
#include "ps/core/abstract_node.h"
#include "ps/core/node_recovery.h"
#include "ps/core/communicator/tcp_communicator.h"
#include "include/common/debug/common.h"
#include "ps/core/communicator/http_communicator.h"
#include "ps/core/communicator/tcp_communicator.h"
#include "ps/core/node_recovery.h"
namespace mindspore {
namespace ps {
@ -68,6 +70,32 @@ void AbstractNode::Register(const std::shared_ptr<TcpClient> &client) {
}
}
void AbstractNode::SendFailMessageToScheduler(const std::string &node_role, const std::string &event_info) {
auto message_meta = std::make_shared<MessageMeta>();
MS_EXCEPTION_IF_NULL(message_meta);
message_meta->set_cmd(NodeCommand::FAILURE_EVENT_INFO);
std::string now_time = ps::core::CommUtil::GetNowTime().time_str_mill;
FailureEventMessage failure_event_message;
failure_event_message.set_node_role(node_role);
failure_event_message.set_ip(node_info_.ip_);
failure_event_message.set_port(node_info_.port_);
failure_event_message.set_time(now_time);
failure_event_message.set_event(event_info);
MS_LOG(INFO) << "The node role:" << node_role << ", the node id:" << node_info_.node_id_
<< " begins to send a failure message to the scheduler!";
if (!SendMessageAsync(client_to_scheduler_, message_meta, Protos::PROTOBUF,
failure_event_message.SerializeAsString().data(), failure_event_message.ByteSizeLong())) {
MS_LOG(ERROR) << "The node role:" << node_role << " the node id:" << node_info_.node_id_
<< " sending the failure message timed out!";
} else {
MS_LOG(INFO) << "The node role:" << node_role << " the node id:" << node_info_.node_id_
<< " sent the failure message " << event_info << " successfully!";
}
}
void AbstractNode::ProcessRegisterResp(const std::shared_ptr<MessageMeta> &meta, const void *data, size_t size) {
MS_EXCEPTION_IF_NULL(meta);
MS_EXCEPTION_IF_NULL(data);

View File

@ -17,25 +17,25 @@
#ifndef MINDSPORE_CCSRC_PS_CORE_ABSTRACT_NODE_H_
#define MINDSPORE_CCSRC_PS_CORE_ABSTRACT_NODE_H_
#include <utility>
#include <string>
#include <memory>
#include <functional>
#include <map>
#include <vector>
#include <queue>
#include <unordered_map>
#include <functional>
#include <utility>
#include <vector>
#include <memory>
#include <string>
#include "ps/core/node.h"
#include "ps/core/communicator/message.h"
#include "ps/core/follower_scaler.h"
#include "utils/ms_exception.h"
#include "include/backend/visible.h"
#include "ps/constants.h"
#include "ps/core/communicator/communicator_base.h"
#include "ps/core/communicator/message.h"
#include "ps/core/communicator/task_executor.h"
#include "ps/core/follower_scaler.h"
#include "ps/core/node.h"
#include "ps/core/node_info.h"
#include "ps/core/recovery_base.h"
#include "ps/core/communicator/task_executor.h"
#include "ps/core/communicator/communicator_base.h"
#include "include/backend/visible.h"
#include "utils/ms_exception.h"
namespace mindspore {
namespace ps {
@ -163,6 +163,9 @@ class BACKEND_EXPORT AbstractNode : public Node {
// register cancel SafeMode function to node
void SetCancelSafeModeCallBack(const CancelSafeModeFn &fn) { cancelSafeModeFn_ = fn; }
// Server and worker nodes send exception messages to the scheduler
void SendFailMessageToScheduler(const std::string &node_role, const std::string &event_info);
protected:
virtual void Register(const std::shared_ptr<TcpClient> &client);
bool Heartbeat(const std::shared_ptr<TcpClient> &client);

View File

@ -17,13 +17,15 @@
#include "ps/core/comm_util.h"
#include <arpa/inet.h>
#include <unistd.h>
#include <sys/stat.h>
#include <unistd.h>
#include <algorithm>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <functional>
#include <algorithm>
#include <iomanip>
#include <regex>
namespace mindspore {
@ -600,6 +602,54 @@ bool CommUtil::StringToBool(const std::string &alive) {
}
return false;
}
Time CommUtil::GetNowTime() {
ps::core::Time time;
auto time_now = std::chrono::system_clock::now();
std::time_t tt = std::chrono::system_clock::to_time_t(time_now);
struct tm ptm;
(void)localtime_r(&tt, &ptm);
std::ostringstream time_mill_oss;
time_mill_oss << std::put_time(&ptm, "%Y-%m-%d %H:%M:%S");
// Calculate the millisecond part; the format of time_str_mill is 2022-01-10 20:22:20.067
auto second_time_stamp = std::chrono::duration_cast<std::chrono::seconds>(time_now.time_since_epoch());
auto mill_time_stamp = std::chrono::duration_cast<std::chrono::milliseconds>(time_now.time_since_epoch());
auto ms_stamp = mill_time_stamp - second_time_stamp;
time_mill_oss << "." << std::setfill('0') << std::setw(kMillSecondLength) << ms_stamp.count();
time.time_stamp = mill_time_stamp.count();
time.time_str_mill = time_mill_oss.str();
return time;
}
bool CommUtil::ParseAndCheckConfigJson(Configuration *file_configuration, const std::string &key,
FileConfig *file_config) {
MS_EXCEPTION_IF_NULL(file_configuration);
MS_EXCEPTION_IF_NULL(file_config);
if (!file_configuration->Exists(key)) {
MS_LOG(WARNING) << key << " config is not set. Don't write.";
return false;
} else {
std::string value = file_configuration->Get(key, "");
nlohmann::json value_json;
try {
value_json = nlohmann::json::parse(value);
} catch (const std::exception &e) {
MS_LOG(EXCEPTION) << "The hyper-parameter data is not in json format.";
}
// Parse the storage type.
uint32_t storage_type = ps::core::CommUtil::JsonGetKeyWithException<uint32_t>(value_json, ps::kStoreType);
if (std::to_string(storage_type) != ps::kFileStorage) {
MS_LOG(EXCEPTION) << "Storage type " << storage_type << " is not supported.";
}
// Parse storage file path.
std::string file_path = ps::core::CommUtil::JsonGetKeyWithException<std::string>(value_json, ps::kStoreFilePath);
file_config->storage_type = storage_type;
file_config->storage_file_path = file_path;
}
return true;
}
} // namespace core
} // namespace ps
} // namespace mindspore
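For reference, a rough standalone sketch (the JSON key names here are assumptions; the real strings come from ps::kStoreType and ps::kStoreFilePath in ps/constants.h) of the value that ParseAndCheckConfigJson above expects under config keys such as "metrics", "dataRate" or "failureEvent", together with the JsonGetKeyWithException lookup pattern:
#include <cstdint>
#include <iostream>
#include <stdexcept>
#include <string>
#include <nlohmann/json.hpp>
template <typename T>
T JsonGetKeyWithException(const nlohmann::json &json, const std::string &key) {
  if (!json.contains(key)) {
    throw std::runtime_error("The key " + key + " does not exist in json " + json.dump());
  }
  return json[key].get<T>();
}
int main() {
  // Stand-in for the value stored under e.g. the "dataRate" key of the server config file;
  // "type" and "file_path" are illustrative key names only.
  nlohmann::json value_json = {{"type", 1}, {"file_path", "/tmp/flow"}};
  auto storage_type = JsonGetKeyWithException<uint32_t>(value_json, "type");
  auto storage_file_path = JsonGetKeyWithException<std::string>(value_json, "file_path");
  std::cout << storage_type << " " << storage_file_path << std::endl;  // 1 /tmp/flow
  return 0;
}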

View File

@ -19,47 +19,46 @@
#include <unistd.h>
#ifdef _MSC_VER
#include <tchar.h>
#include <winsock2.h>
#include <windows.h>
#include <iphlpapi.h>
#include <tchar.h>
#include <windows.h>
#include <winsock2.h>
#else
#include <net/if.h>
#include <arpa/inet.h>
#include <ifaddrs.h>
#include <net/if.h>
#include <netinet/in.h>
#endif
#include <assert.h>
#include <event2/buffer.h>
#include <event2/event.h>
#include <event2/http.h>
#include <event2/keyvalq_struct.h>
#include <event2/listener.h>
#include <event2/util.h>
#include <openssl/ssl.h>
#include <openssl/rand.h>
#include <openssl/bio.h>
#include <openssl/err.h>
#include <openssl/evp.h>
#include <assert.h>
#include <openssl/pkcs12.h>
#include <openssl/bio.h>
#include <openssl/rand.h>
#include <openssl/ssl.h>
#include <openssl/x509v3.h>
#include <algorithm>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <fstream>
#include <functional>
#include <iostream>
#include <map>
#include <random>
#include <sstream>
#include <string>
#include <utility>
#include <thread>
#include <fstream>
#include <iostream>
#include <utility>
#include <vector>
#include <map>
#include <algorithm>
#include "proto/comm.pb.h"
#include "proto/ps.pb.h"
@ -78,12 +77,14 @@ constexpr int kGroup2RandomLength = 4;
constexpr int kGroup3RandomLength = 4;
constexpr int kGroup4RandomLength = 4;
constexpr int kGroup5RandomLength = 12;
constexpr int kMillSecondLength = 3;
// The size of the buffer for sending and receiving data is 4096 bytes.
constexpr int kMessageChunkLength = 4096;
// The timeout period for the http client to connect to the http server is 120 seconds.
constexpr int kConnectionTimeout = 120;
constexpr char kLibeventLogPrefix[] = "[libevent log]:";
constexpr char kFailureEvent[] = "failureEvent";
// Find the corresponding string style of cluster state through the subscript of the enum:ClusterState
const std::vector<std::string> kClusterState = {
@ -113,6 +114,16 @@ const std::map<std::string, ClusterState> kClusterStateMap = {
{"CLUSTER_SCHEDULER_RECOVERY", ClusterState::CLUSTER_SCHEDULER_RECOVERY},
{"CLUSTER_SCALE_OUT_ROLLBACK", ClusterState::CLUSTER_SCALE_OUT_ROLLBACK}};
struct Time {
uint64_t time_stamp;
std::string time_str_mill;
};
struct FileConfig {
uint32_t storage_type;
std::string storage_file_path;
};
class CommUtil {
public:
static bool CheckIpWithRegex(const std::string &ip);
@ -159,6 +170,16 @@ class CommUtil {
static bool CreateDirectory(const std::string &directoryPath);
static bool CheckHttpUrl(const std::string &http_url);
static bool IsFileReadable(const std::string &file);
template <typename T>
static T JsonGetKeyWithException(const nlohmann::json &json, const std::string &key) {
if (!json.contains(key)) {
MS_LOG(EXCEPTION) << "The key " << key << " does not exist in json " << json.dump();
}
return json[key].get<T>();
}
static Time GetNowTime();
static bool ParseAndCheckConfigJson(Configuration *file_configuration, const std::string &key,
FileConfig *file_config);
private:
static std::random_device rd;

View File

@ -32,7 +32,8 @@ enum class Command {
SEND_DATA = 3,
FETCH_METADATA = 4,
FINISH = 5,
COLLECTIVE_SEND_DATA = 6
COLLECTIVE_SEND_DATA = 6,
FAILURE_EVENT = 7
};
enum class Role { SERVER = 0, WORKER = 1, SCHEDULER = 2 };

View File

@ -65,6 +65,8 @@ enum NodeCommand {
QUERY_FINISH_TRANSFORM = 23;
// This command is used to start scale out rollback
SCALE_OUT_ROLLBACK = 24;
// Record the failure information, such as node restart
FAILURE_EVENT_INFO = 25;
}
enum NodeRole {
@ -142,6 +144,14 @@ message HeartbeatMessage {
uint32 port = 5;
}
message FailureEventMessage {
string node_role = 1;
string ip = 2;
uint32 port = 3;
string time = 4;
string event = 5;
}
enum NodeState {
NODE_STARTING = 0;
NODE_FINISH = 1;

View File

@ -87,6 +87,7 @@ message DeviceMeta {
string fl_name = 1;
string fl_id = 2;
uint64 data_size = 3;
uint64 now_time = 4;
}
message FLIdToDeviceMeta {
@ -257,6 +258,9 @@ message EndLastIterResponse {
uint64 getModel_accept_client_num = 9;
uint64 getModel_reject_client_num = 10;
float upload_loss = 11;
uint64 participation_time_level1_num = 12;
uint64 participation_time_level2_num = 13;
uint64 participation_time_level3_num = 14;
}
message SyncAfterRecover {

View File

@ -94,6 +94,10 @@ class BACKEND_EXPORT PSSchedulerNode : public SchedulerNode {
void RecoverFromPersistence() override;
void InitEventTxtFile() override {}
void RecordSchedulerRestartInfo() override {}
// Record received host hash name from workers or servers.
std::map<NodeRole, std::vector<size_t>> host_hash_names_;
// Record the rank ids of the nodes which sent host names.

View File

@ -15,6 +15,11 @@
*/
#include "ps/core/scheduler_node.h"
#include <string>
#include "include/common/debug/common.h"
#include "fl/server/common.h"
#include "ps/core/scheduler_recovery.h"
namespace mindspore {
@ -32,10 +37,13 @@ bool SchedulerNode::Start(const uint32_t &timeout) {
config_ = std::make_unique<FileConfiguration>(PSContext::instance()->config_file_path());
MS_EXCEPTION_IF_NULL(config_);
InitNodeMetaData();
bool is_recover = false;
if (!config_->Initialize()) {
MS_LOG(WARNING) << "The config file is empty.";
} else {
if (!RecoverScheduler()) {
InitEventTxtFile();
is_recover = RecoverScheduler();
if (!is_recover) {
MS_LOG(DEBUG) << "Recovering the scheduler node failed.";
}
}
@ -50,6 +58,10 @@ bool SchedulerNode::Start(const uint32_t &timeout) {
StartUpdateClusterStateTimer();
RunRecovery();
if (is_recover) {
RecordSchedulerRestartInfo();
}
if (is_worker_timeout_) {
BroadcastTimeoutEvent();
}
@ -258,6 +270,7 @@ void SchedulerNode::InitCommandHandler() {
handlers_[NodeCommand::SCALE_OUT_DONE] = &SchedulerNode::ProcessScaleOutDone;
handlers_[NodeCommand::SCALE_IN_DONE] = &SchedulerNode::ProcessScaleInDone;
handlers_[NodeCommand::SEND_EVENT] = &SchedulerNode::ProcessSendEvent;
handlers_[NodeCommand::FAILURE_EVENT_INFO] = &SchedulerNode::ProcessFailureEvent;
RegisterActorRouteTableServiceHandler();
RegisterInitCollectCommServiceHandler();
RegisterRecoveryServiceHandler();
@ -269,6 +282,26 @@ void SchedulerNode::RegisterActorRouteTableServiceHandler() {
handlers_[NodeCommand::LOOKUP_ACTOR_ROUTE] = &SchedulerNode::ProcessLookupActorRoute;
}
void SchedulerNode::InitEventTxtFile() {
MS_LOG(DEBUG) << "Start init event txt";
ps::core::FileConfig event_file_config;
if (!ps::core::CommUtil::ParseAndCheckConfigJson(config_.get(), kFailureEvent, &event_file_config)) {
MS_LOG(EXCEPTION) << "Parsing and checking the config json failed";
return;
}
std::lock_guard<std::mutex> lock(event_txt_file_mtx_);
event_file_path_ = event_file_config.storage_file_path;
auto realpath = Common::CreatePrefixPath(event_file_path_.c_str());
if (!realpath.has_value()) {
MS_LOG(EXCEPTION) << "Creating path for " << event_file_path_ << " failed.";
return;
}
event_txt_file_.open(realpath.value(), std::ios::out | std::ios::app);
event_txt_file_.close();
MS_LOG(DEBUG) << "Init event txt success!";
}
void SchedulerNode::CreateTcpServer() {
node_manager_.InitNode();
@ -628,6 +661,34 @@ void SchedulerNode::ProcessLookupActorRoute(const std::shared_ptr<TcpServer> &se
}
}
void SchedulerNode::ProcessFailureEvent(const std::shared_ptr<TcpServer> &server,
const std::shared_ptr<TcpConnection> &conn,
const std::shared_ptr<MessageMeta> &meta, const void *data, size_t size) {
MS_ERROR_IF_NULL_WO_RET_VAL(server);
MS_ERROR_IF_NULL_WO_RET_VAL(conn);
MS_ERROR_IF_NULL_WO_RET_VAL(meta);
MS_ERROR_IF_NULL_WO_RET_VAL(data);
FailureEventMessage failure_event_message;
failure_event_message.ParseFromArray(data, SizeToInt(size));
std::string node_role = failure_event_message.node_role();
std::string ip = failure_event_message.ip();
uint32_t port = failure_event_message.port();
std::string time = failure_event_message.time();
std::string event = failure_event_message.event();
std::lock_guard<std::mutex> lock(event_txt_file_mtx_);
event_txt_file_.open(event_file_path_, std::ios::out | std::ios::app);
if (!event_txt_file_.is_open()) {
MS_LOG(EXCEPTION) << "The event txt file is not open";
return;
}
std::string event_info = "nodeRole:" + node_role + "," + ip + ":" + std::to_string(port) + "," +
"currentTime:" + time + "," + "event:" + event + ";";
event_txt_file_ << event_info << "\n";
(void)event_txt_file_.flush();
event_txt_file_.close();
MS_LOG(INFO) << "Process failure event success!";
}
bool SchedulerNode::SendPrepareBuildingNetwork(const std::unordered_map<std::string, NodeInfo> &node_infos) {
uint64_t request_id = AddMessageTrack(node_infos.size());
for (const auto &kvs : node_infos) {
@ -1598,7 +1659,7 @@ bool SchedulerNode::RecoverScheduler() {
MS_EXCEPTION_IF_NULL(config_);
if (config_->Exists(kKeyRecovery)) {
MS_LOG(INFO) << "The scheduler node supports recovery.";
scheduler_recovery_ = std::make_unique<SchedulerRecovery>();
scheduler_recovery_ = std::make_shared<SchedulerRecovery>();
MS_EXCEPTION_IF_NULL(scheduler_recovery_);
bool ret = scheduler_recovery_->Initialize(config_->Get(kKeyRecovery, ""));
bool ret_node = scheduler_recovery_->InitializeNodes(config_->Get(kKeyRecovery, ""));
@ -1610,6 +1671,29 @@ bool SchedulerNode::RecoverScheduler() {
return false;
}
void SchedulerNode::RecordSchedulerRestartInfo() {
MS_LOG(DEBUG) << "Start to record scheduler restart error message";
std::lock_guard<std::mutex> lock(event_txt_file_mtx_);
event_txt_file_.open(event_file_path_, std::ios::out | std::ios::app);
if (!event_txt_file_.is_open()) {
MS_LOG(EXCEPTION) << "The event txt file is not open";
return;
}
std::string node_role = CommUtil::NodeRoleToString(node_info_.node_role_);
auto scheduler_recovery_ptr = std::dynamic_pointer_cast<SchedulerRecovery>(scheduler_recovery_);
MS_EXCEPTION_IF_NULL(scheduler_recovery_ptr);
auto ip = scheduler_recovery_ptr->GetMetadata(kRecoverySchedulerIp);
auto port = scheduler_recovery_ptr->GetMetadata(kRecoverySchedulerPort);
std::string time = ps::core::CommUtil::GetNowTime().time_str_mill;
std::string event = "Node restart";
std::string event_info =
"nodeRole:" + node_role + "," + ip + ":" + port + "," + "currentTime:" + time + "," + "event:" + event + ";";
event_txt_file_ << event_info << "\n";
(void)event_txt_file_.flush();
event_txt_file_.close();
MS_LOG(DEBUG) << "Record scheduler node restart info " << event_info << " successfully!";
}
void SchedulerNode::PersistMetaData() {
if (scheduler_recovery_ == nullptr) {
MS_LOG(WARNING) << "scheduler recovery is null, do not persist meta data";

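For reference, a tiny standalone sketch (illustrative values only) of the one-line record that ProcessFailureEvent and RecordSchedulerRestartInfo above append to the failureEvent txt file:
#include <iostream>
#include <string>
int main() {
  std::string node_role = "SCHEDULER";
  std::string ip = "10.1.2.3";                   // illustrative address
  std::string port = "8080";                     // illustrative port
  std::string time = "2022-04-22 14:29:24.123";  // CommUtil::GetNowTime().time_str_mill format
  std::string event = "Node restart";
  std::string event_info =
    "nodeRole:" + node_role + "," + ip + ":" + port + "," + "currentTime:" + time + "," + "event:" + event + ";";
  // Prints: nodeRole:SCHEDULER,10.1.2.3:8080,currentTime:2022-04-22 14:29:24.123,event:Node restart;
  std::cout << event_info << std::endl;
  return 0;
}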
View File

@ -88,6 +88,7 @@ class BACKEND_EXPORT SchedulerNode : public Node {
// Register and initialize the actor route table service.
void RegisterActorRouteTableServiceHandler();
void InitializeActorRouteTableService();
virtual void InitEventTxtFile();
// Register collective communication initialization service.
virtual void RegisterInitCollectCommServiceHandler() {}
@ -138,6 +139,10 @@ class BACKEND_EXPORT SchedulerNode : public Node {
void ProcessLookupActorRoute(const std::shared_ptr<TcpServer> &server, const std::shared_ptr<TcpConnection> &conn,
const std::shared_ptr<MessageMeta> &meta, const void *data, size_t size);
// Process failure event message from other nodes.
void ProcessFailureEvent(const std::shared_ptr<TcpServer> &server, const std::shared_ptr<TcpConnection> &conn,
const std::shared_ptr<MessageMeta> &meta, const void *data, size_t size);
// Determine whether the registration request of the node should be rejected, the registration of the
// alive node should be rejected.
virtual bool NeedRejectRegister(const NodeInfo &node_info) { return false; }
@ -201,6 +206,9 @@ class BACKEND_EXPORT SchedulerNode : public Node {
bool RecoverScheduler();
// Write scheduler restart error message
virtual void RecordSchedulerRestartInfo();
void PersistMetaData();
bool CheckIfNodeDisconnected() const;
@ -245,7 +253,7 @@ class BACKEND_EXPORT SchedulerNode : public Node {
std::unordered_map<std::string, OnRequestReceive> callbacks_;
// Used to persist and obtain metadata information for scheduler.
std::unique_ptr<RecoveryBase> scheduler_recovery_;
std::shared_ptr<RecoveryBase> scheduler_recovery_;
// persistent command need to be sent.
std::atomic<PersistentCommand> persistent_cmd_;
@ -259,6 +267,15 @@ class BACKEND_EXPORT SchedulerNode : public Node {
std::unordered_map<int, std::string> register_connection_fd_;
std::unique_ptr<ActorRouteTableService> actor_route_table_service_;
// The event txt file path
std::string event_file_path_;
// The mutex for the event txt file event_file_path_
std::mutex event_txt_file_mtx_;
// The fstream for event_file_path_
std::fstream event_txt_file_;
};
} // namespace core
} // namespace ps

View File

@ -42,6 +42,7 @@ void ServerNode::Initialize() {
config_ = std::make_unique<FileConfiguration>(PSContext::instance()->config_file_path());
MS_EXCEPTION_IF_NULL(config_);
InitNodeNum();
bool is_recover = false;
if (!config_->Initialize()) {
MS_LOG(WARNING) << "The config file is empty.";
} else {
@ -62,6 +63,10 @@ void ServerNode::Initialize() {
}
InitClientToServer();
is_already_stopped_ = false;
if (is_recover) {
std::string node_role = CommUtil::NodeRoleToString(node_info_.node_role_);
SendFailMessageToScheduler(node_role, "Node restart");
}
MS_LOG(INFO) << "[Server start]: 3. Server node creates tcp client to scheduler successfully!";
}

View File

@ -41,10 +41,12 @@ void WorkerNode::Initialize() {
config_ = std::make_unique<FileConfiguration>(PSContext::instance()->config_file_path());
MS_EXCEPTION_IF_NULL(config_);
InitNodeNum();
bool is_recover = false;
if (!config_->Initialize()) {
MS_LOG(WARNING) << "The config file is empty.";
} else {
if (!Recover()) {
is_recover = Recover();
if (!is_recover) {
MS_LOG(DEBUG) << "Recovering the worker node failed.";
}
}
@ -60,6 +62,10 @@ void WorkerNode::Initialize() {
}
InitClientToServer();
is_already_stopped_ = false;
if (is_recover) {
std::string node_role = CommUtil::NodeRoleToString(node_info_.node_role_);
SendFailMessageToScheduler(node_role, "Node restart");
}
MS_LOG(INFO) << "[Worker start]: 3. Worker node crete tcp client to scheduler successful!";
}

View File

@ -15,13 +15,14 @@
*/
#include "ps/ps_context.h"
#include "kernel/kernel.h"
#include "utils/log_adapter.h"
#include "utils/ms_utils.h"
#include "kernel/kernel.h"
#if ((defined ENABLE_CPU) && (!defined _WIN32))
#include "distributed/cluster/cluster_context.h"
#include "ps/ps_cache/ps_cache_manager.h"
#include "ps/ps_cache/ps_data/ps_data_prefetch.h"
#include "distributed/cluster/cluster_context.h"
#else
#include "distributed/cluster/dummy_cluster_context.h"
#endif
@ -566,5 +567,21 @@ std::string PSContext::download_compress_type() const { return download_compress
std::string PSContext::checkpoint_dir() const { return checkpoint_dir_; }
void PSContext::set_checkpoint_dir(const std::string &checkpoint_dir) { checkpoint_dir_ = checkpoint_dir; }
void PSContext::set_instance_name(const std::string &instance_name) { instance_name_ = instance_name; }
const std::string &PSContext::instance_name() const { return instance_name_; }
void PSContext::set_participation_time_level(const std::string &participation_time_level) {
participation_time_level_ = participation_time_level;
}
const std::string &PSContext::participation_time_level() { return participation_time_level_; }
void PSContext::set_continuous_failure_times(uint32_t continuous_failure_times) {
continuous_failure_times_ = continuous_failure_times;
}
uint32_t PSContext::continuous_failure_times() { return continuous_failure_times_; }
} // namespace ps
} // namespace mindspore
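
The default participation_time_level_ value "5,15" (see the initializer list in the header hunk below) presumably encodes two comma-separated minute thresholds that split clients into three participation-time buckets, matching the three participationTimeLevel counters added elsewhere in this commit. A hedged sketch of parsing such a string; ParseParticipationTimeLevel is a hypothetical helper, not the actual MindSpore implementation:

#include <cstdint>
#include <sstream>
#include <string>
#include <vector>

// Hypothetical parser: "5,15" -> {5, 15}. Under the assumed semantics, a client
// finishing within 5 minutes counts toward level 1, within 15 minutes toward
// level 2, and anything slower toward level 3.
std::vector<uint64_t> ParseParticipationTimeLevel(const std::string &levels) {
  std::vector<uint64_t> thresholds;
  std::stringstream ss(levels);
  std::string item;
  while (std::getline(ss, item, ',')) {
    thresholds.push_back(static_cast<uint64_t>(std::stoull(item)));
  }
  return thresholds;
}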

View File

@ -244,6 +244,15 @@ class BACKEND_EXPORT PSContext {
std::string checkpoint_dir() const;
void set_checkpoint_dir(const std::string &checkpoint_dir);
void set_instance_name(const std::string &instance_name);
const std::string &instance_name() const;
void set_participation_time_level(const std::string &participation_time_level);
const std::string &participation_time_level();
void set_continuous_failure_times(uint32_t continuous_failure_times);
uint32_t continuous_failure_times();
private:
PSContext()
: ps_enabled_(false),
@ -300,7 +309,10 @@ class BACKEND_EXPORT PSContext {
upload_compress_type_(kNoCompressType),
upload_sparse_rate_(0.4f),
download_compress_type_(kNoCompressType),
checkpoint_dir_("") {}
checkpoint_dir_(""),
instance_name_(""),
participation_time_level_("5,15"),
continuous_failure_times_(10) {}
bool ps_enabled_;
bool is_worker_;
bool is_pserver_;
@ -442,6 +454,15 @@ class BACKEND_EXPORT PSContext {
// directory of server checkpoint
std::string checkpoint_dir_;
// The name of the instance
std::string instance_name_;
// The participation time level
std::string participation_time_level_;
// The number of continuous iteration failures
uint32_t continuous_failure_times_;
};
} // namespace ps
} // namespace mindspore
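
continuous_failure_times_ (default 10 in the initializer list above) reads as a threshold on consecutive failed iterations; the reaction when it is reached is not visible in this hunk. A small hypothetical counter illustrating how such a threshold could be applied:

#include <cstdint>

// Hypothetical helper: tracks consecutive iteration failures and reports when
// a configured limit (e.g. continuous_failure_times_) is reached.
class ContinuousFailureCounter {
 public:
  explicit ContinuousFailureCounter(uint32_t limit) : limit_(limit) {}

  // Returns true once the number of consecutive failures reaches the limit.
  bool Update(bool iteration_valid) {
    if (iteration_valid) {
      consecutive_failures_ = 0;
      return false;
    }
    return ++consecutive_failures_ >= limit_;
  }

 private:
  uint32_t limit_;
  uint32_t consecutive_failures_ = 0;
};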

View File

@ -83,6 +83,9 @@ _set_ps_context_func_map = {
"upload_compress_type": ps_context().set_upload_compress_type,
"upload_sparse_rate": ps_context().set_upload_sparse_rate,
"download_compress_type": ps_context().set_download_compress_type,
"instance_name": ps_context().set_instance_name,
"participation_time_level": ps_context().set_participation_time_level,
"continuous_failure_times": ps_context().set_continuous_failure_times,
}
_get_ps_context_func_map = {
@ -134,6 +137,9 @@ _get_ps_context_func_map = {
"upload_compress_type": ps_context().upload_compress_type,
"upload_sparse_rate": ps_context().upload_sparse_rate,
"download_compress_type": ps_context().download_compress_type,
"instance_name": ps_context().instance_name,
"participation_time_level": ps_context().participation_time_level,
"continuous_failure_times": ps_context().continuous_failure_times,
}
_check_positive_int_keys = ["server_num", "scheduler_port", "fl_server_port",

View File

@ -178,10 +178,6 @@ def call_get_instance_detail():
return process_self_define_json(Status.FAILED.value, "error. metrics file does not exist.")
ans_json_obj = {}
metrics_auc_list = []
metrics_loss_list = []
iteration_execution_time_list = []
client_visited_info_list = []
with open(metrics_file_path, 'r') as f:
metrics_list = f.readlines()
@ -189,28 +185,13 @@ def call_get_instance_detail():
if not metrics_list:
return process_self_define_json(Status.FAILED.value, "error. metrics file has no content")
for metrics in metrics_list:
json_obj = json.loads(metrics)
iteration_execution_time_list.append(json_obj['iterationExecutionTime'])
client_visited_info_list.append(json_obj['clientVisitedInfo'])
metrics_auc_list.append(json_obj['metricsAuc'])
metrics_loss_list.append(json_obj['metricsLoss'])
last_metrics = metrics_list[len(metrics_list) - 1]
last_metrics_obj = json.loads(last_metrics)
ans_json_obj["code"] = Status.SUCCESS.value
ans_json_obj["describe"] = "get instance metrics detail successful."
ans_json_obj["result"] = {}
ans_json_result = ans_json_obj.get("result")
ans_json_result['currentIteration'] = last_metrics_obj['currentIteration']
ans_json_result['flIterationNum'] = last_metrics_obj['flIterationNum']
ans_json_result['flName'] = last_metrics_obj['flName']
ans_json_result['instanceStatus'] = last_metrics_obj['instanceStatus']
ans_json_result['iterationExecutionTime'] = iteration_execution_time_list
ans_json_result['clientVisitedInfo'] = client_visited_info_list
ans_json_result['metricsAuc'] = metrics_auc_list
ans_json_result['metricsLoss'] = metrics_loss_list
ans_json_obj["result"] = last_metrics_obj
return json.dumps(ans_json_obj)

View File

@ -15,6 +15,14 @@
"storage_type": 1,
"storage_file_path": "metrics.json"
},
"dataRate": {
"storage_type": 1,
"storage_file_path": ".."
},
"failureEvent": {
"storage_type": 1,
"storage_file_path": "event.txt"
},
"server_recovery": {
"storage_type": 1,
"storage_file_path": "../server_recovery.json"

View File

@ -15,6 +15,14 @@
"storage_type": 1,
"storage_file_path": "metrics.json"
},
"dataRate": {
"storage_type": 1,
"storage_file_path": ".."
},
"failureEvent": {
"storage_type": 1,
"storage_file_path": "event.txt"
},
"server_recovery": {
"storage_type": 1,
"storage_file_path": "../server_recovery.json"

View File

@ -15,6 +15,14 @@
"storage_type": 1,
"storage_file_path": "metrics.json"
},
"dataRate": {
"storage_type": 1,
"storage_file_path": ".."
},
"failureEvent": {
"storage_type": 1,
"storage_file_path": "event.txt"
},
"server_recovery": {
"storage_type": 1,
"storage_file_path": "../server_recovery.json"

View File

@ -15,6 +15,14 @@
"storage_type": 1,
"storage_file_path": "metrics.json"
},
"dataRate": {
"storage_type": 1,
"storage_file_path": ".."
},
"failureEvent": {
"storage_type": 1,
"storage_file_path": "event.txt"
},
"server_recovery": {
"storage_type": 1,
"storage_file_path": "../server_recovery.json"

View File

@ -15,6 +15,14 @@
"storage_type": 1,
"storage_file_path": "metrics.json"
},
"dataRate": {
"storage_type": 1,
"storage_file_path": ".."
},
"failureEvent": {
"storage_type": 1,
"storage_file_path": "event.txt"
},
"server_recovery": {
"storage_type": 1,
"storage_file_path": "../server_recovery.json"

View File

@ -15,6 +15,14 @@
"storage_type": 1,
"storage_file_path": "metrics.json"
},
"dataRate": {
"storage_type": 1,
"storage_file_path": ".."
},
"failureEvent": {
"storage_type": 1,
"storage_file_path": "event.txt"
},
"server_recovery": {
"storage_type": 1,
"storage_file_path": "../server_recovery.json"

View File

@ -15,6 +15,14 @@
"storage_type": 1,
"storage_file_path": "metrics.json"
},
"dataRate": {
"storage_type": 1,
"storage_file_path": ".."
},
"failureEvent": {
"storage_type": 1,
"storage_file_path": "event.txt"
},
"server_recovery": {
"storage_type": 1,
"storage_file_path": "../server_recovery.json"