forked from mindspore-Ecosystem/mindspore
The maintainability function of FL
This commit is contained in:
parent
123e1b97de
commit
51f6b77ab0
|
@ -23,6 +23,7 @@
|
|||
#include <climits>
|
||||
#include <memory>
|
||||
#include <functional>
|
||||
#include <iomanip>
|
||||
#include "proto/ps.pb.h"
|
||||
#include "proto/fl.pb.h"
|
||||
#include "ir/anf.h"
|
||||
|
@ -149,11 +150,15 @@ constexpr auto kUpdateModelRejectClientNum = "updateModelRejectClientNum";
|
|||
constexpr auto kGetModelTotalClientNum = "getModelTotalClientNum";
|
||||
constexpr auto kGetModelAcceptClientNum = "getModelAcceptClientNum";
|
||||
constexpr auto kGetModelRejectClientNum = "getModelRejectClientNum";
|
||||
constexpr auto kParticipationTimeLevel1 = "participationTimeLevel1";
|
||||
constexpr auto kParticipationTimeLevel2 = "participationTimeLevel2";
|
||||
constexpr auto kParticipationTimeLevel3 = "participationTimeLevel3";
|
||||
constexpr auto kMinVal = "min_val";
|
||||
constexpr auto kMaxVal = "max_val";
|
||||
constexpr auto kQuant = "QUANT";
|
||||
constexpr auto kDiffSparseQuant = "DIFF_SPARSE_QUANT";
|
||||
constexpr auto kNoCompress = "NO_COMPRESS";
|
||||
constexpr auto kUpdateModel = "updateModel";
|
||||
|
||||
// OptimParamNameToIndex represents every inputs/workspace/outputs parameter's offset when an optimizer kernel is
|
||||
// launched.
|
||||
|
|
|
@ -15,27 +15,40 @@
|
|||
*/
|
||||
|
||||
#include "fl/server/iteration.h"
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <numeric>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
#include "fl/server/model_store.h"
|
||||
#include "fl/server/server.h"
|
||||
#include "ps/core/comm_util.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace fl {
|
||||
namespace server {
|
||||
namespace {
|
||||
const size_t kParticipationTimeLevelNum = 3;
|
||||
const size_t kIndexZero = 0;
|
||||
const size_t kIndexOne = 1;
|
||||
const size_t kIndexTwo = 2;
|
||||
const size_t kLastSecond = 59;
|
||||
} // namespace
|
||||
class Server;
|
||||
|
||||
Iteration::~Iteration() {
|
||||
move_to_next_thread_running_ = false;
|
||||
is_date_rate_thread_running_ = false;
|
||||
next_iteration_cv_.notify_all();
|
||||
if (move_to_next_thread_.joinable()) {
|
||||
move_to_next_thread_.join();
|
||||
}
|
||||
if (data_rate_thread_.joinable()) {
|
||||
data_rate_thread_.join();
|
||||
}
|
||||
}
|
||||
|
||||
void Iteration::RegisterMessageCallback(const std::shared_ptr<ps::core::TcpCommunicator> &communicator) {
|
||||
MS_EXCEPTION_IF_NULL(communicator);
|
||||
communicator_ = communicator;
|
||||
|
@ -160,8 +173,13 @@ void Iteration::SetIterationRunning() {
|
|||
|
||||
std::unique_lock<std::mutex> lock(iteration_state_mtx_);
|
||||
iteration_state_ = IterationState::kRunning;
|
||||
start_timestamp_ = LongToUlong(CURRENT_TIME_MILLI.count());
|
||||
start_time_ = ps::core::CommUtil::GetNowTime();
|
||||
MS_LOG(INFO) << "Iteratoin " << iteration_num_ << " start global timer.";
|
||||
instance_name_ = ps::PSContext::instance()->instance_name();
|
||||
if (instance_name_.empty()) {
|
||||
MS_LOG(WARNING) << "instance name is empty";
|
||||
instance_name_ = "instance_" + start_time_.time_str_mill;
|
||||
}
|
||||
global_iter_timer_->Start(std::chrono::milliseconds(global_iteration_time_window_));
|
||||
}
|
||||
|
||||
|
@ -175,7 +193,7 @@ void Iteration::SetIterationEnd() {
|
|||
|
||||
std::unique_lock<std::mutex> lock(iteration_state_mtx_);
|
||||
iteration_state_ = IterationState::kCompleted;
|
||||
complete_timestamp_ = LongToUlong(CURRENT_TIME_MILLI.count());
|
||||
complete_time_ = ps::core::CommUtil::GetNowTime();
|
||||
}
|
||||
|
||||
void Iteration::ScalingBarrier() {
|
||||
|
@ -631,6 +649,13 @@ void Iteration::Next(bool is_iteration_valid, const std::string &reason) {
|
|||
round_client_num_map_[kUpdateModelAcceptClientNum] += round->kernel_accept_client_num();
|
||||
round_client_num_map_[kUpdateModelRejectClientNum] += round->kernel_reject_client_num();
|
||||
set_loss(loss_ + round->kernel_upload_loss());
|
||||
auto update_model_complete_info = round->GetUpdateModelCompleteInfo();
|
||||
if (update_model_complete_info.size() != kParticipationTimeLevelNum) {
|
||||
continue;
|
||||
}
|
||||
round_client_num_map_[kParticipationTimeLevel1] += update_model_complete_info[kIndexZero].second;
|
||||
round_client_num_map_[kParticipationTimeLevel2] += update_model_complete_info[kIndexOne].second;
|
||||
round_client_num_map_[kParticipationTimeLevel3] += update_model_complete_info[kIndexTwo].second;
|
||||
} else if (round->name() == "getModel") {
|
||||
round_client_num_map_[kGetModelTotalClientNum] += round->kernel_total_client_num();
|
||||
round_client_num_map_[kGetModelAcceptClientNum] += round->kernel_accept_client_num();
|
||||
|
@ -656,6 +681,13 @@ bool Iteration::BroadcastEndLastIterRequest(uint64_t last_iter_num) {
|
|||
}
|
||||
|
||||
EndLastIter();
|
||||
if (iteration_fail_num_ == ps::PSContext::instance()->continuous_failure_times()) {
|
||||
std::string node_role = "SERVER";
|
||||
std::string event = "Iteration failed " + std::to_string(iteration_fail_num_) + " times continuously";
|
||||
server_node_->SendFailMessageToScheduler(node_role, event);
|
||||
// Finish sending one message, reset cout num to 0
|
||||
iteration_fail_num_ = 0;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
|
@ -695,6 +727,13 @@ void Iteration::HandleEndLastIterRequest(const std::shared_ptr<ps::core::Message
|
|||
end_last_iter_rsp.set_updatemodel_accept_client_num(round->kernel_accept_client_num());
|
||||
end_last_iter_rsp.set_updatemodel_reject_client_num(round->kernel_reject_client_num());
|
||||
end_last_iter_rsp.set_upload_loss(round->kernel_upload_loss());
|
||||
auto update_model_complete_info = round->GetUpdateModelCompleteInfo();
|
||||
if (update_model_complete_info.size() != kParticipationTimeLevelNum) {
|
||||
MS_LOG(EXCEPTION) << "update_model_complete_info size is not equal 3";
|
||||
}
|
||||
end_last_iter_rsp.set_participation_time_level1_num(update_model_complete_info[kIndexZero].second);
|
||||
end_last_iter_rsp.set_participation_time_level2_num(update_model_complete_info[kIndexOne].second);
|
||||
end_last_iter_rsp.set_participation_time_level3_num(update_model_complete_info[kIndexTwo].second);
|
||||
} else if (round->name() == "getModel") {
|
||||
end_last_iter_rsp.set_getmodel_total_client_num(round->kernel_total_client_num());
|
||||
end_last_iter_rsp.set_getmodel_accept_client_num(round->kernel_accept_client_num());
|
||||
|
@ -744,6 +783,7 @@ void Iteration::EndLastIter() {
|
|||
MS_ERROR_IF_NULL_WO_RET_VAL(round);
|
||||
round->InitkernelClientVisitedNum();
|
||||
round->InitkernelClientUploadLoss();
|
||||
round->ResetParticipationTimeAndNum();
|
||||
}
|
||||
round_client_num_map_.clear();
|
||||
set_loss(0.0f);
|
||||
|
@ -754,6 +794,11 @@ void Iteration::EndLastIter() {
|
|||
} else {
|
||||
MS_LOG(INFO) << "Move to next iteration:" << iteration_num_ << "\n";
|
||||
}
|
||||
if (iteration_result_.load() == IterationResult::kFail) {
|
||||
iteration_fail_num_++;
|
||||
} else {
|
||||
iteration_fail_num_ = 0;
|
||||
}
|
||||
}
|
||||
|
||||
bool Iteration::ForciblyMoveToNextIteration() {
|
||||
|
@ -768,6 +813,9 @@ bool Iteration::SummarizeIteration() {
|
|||
return true;
|
||||
}
|
||||
|
||||
metrics_->set_instance_name(instance_name_);
|
||||
metrics_->set_start_time(start_time_);
|
||||
metrics_->set_end_time(complete_time_);
|
||||
metrics_->set_fl_name(ps::PSContext::instance()->fl_name());
|
||||
metrics_->set_fl_iteration_num(ps::PSContext::instance()->fl_iteration_num());
|
||||
metrics_->set_cur_iteration_num(iteration_num_);
|
||||
|
@ -781,12 +829,12 @@ bool Iteration::SummarizeIteration() {
|
|||
metrics_->set_round_client_num_map(round_client_num_map_);
|
||||
metrics_->set_iteration_result(iteration_result_.load());
|
||||
|
||||
if (complete_timestamp_ < start_timestamp_) {
|
||||
MS_LOG(ERROR) << "The complete_timestamp_: " << complete_timestamp_ << ", start_timestamp_: " << start_timestamp_
|
||||
<< ". One of them is invalid.";
|
||||
if (complete_time_.time_stamp < start_time_.time_stamp) {
|
||||
MS_LOG(ERROR) << "The complete_timestamp_: " << complete_time_.time_stamp
|
||||
<< ", start_timestamp_: " << start_time_.time_stamp << ". One of them is invalid.";
|
||||
metrics_->set_iteration_time_cost(UINT64_MAX);
|
||||
} else {
|
||||
metrics_->set_iteration_time_cost(complete_timestamp_ - start_timestamp_);
|
||||
metrics_->set_iteration_time_cost(complete_time_.time_stamp - start_time_.time_stamp);
|
||||
}
|
||||
|
||||
if (!metrics_->Summarize()) {
|
||||
|
@ -910,6 +958,10 @@ void Iteration::UpdateRoundClientNumMap(const std::shared_ptr<std::vector<unsign
|
|||
round_client_num_map_[kGetModelTotalClientNum] += end_last_iter_rsp.getmodel_total_client_num();
|
||||
round_client_num_map_[kGetModelAcceptClientNum] += end_last_iter_rsp.getmodel_accept_client_num();
|
||||
round_client_num_map_[kGetModelRejectClientNum] += end_last_iter_rsp.getmodel_reject_client_num();
|
||||
|
||||
round_client_num_map_[kParticipationTimeLevel1] += end_last_iter_rsp.participation_time_level1_num();
|
||||
round_client_num_map_[kParticipationTimeLevel2] += end_last_iter_rsp.participation_time_level2_num();
|
||||
round_client_num_map_[kParticipationTimeLevel3] += end_last_iter_rsp.participation_time_level3_num();
|
||||
}
|
||||
|
||||
void Iteration::UpdateRoundClientUploadLoss(const std::shared_ptr<std::vector<unsigned char>> &client_info_rsp_msg) {
|
||||
|
@ -924,6 +976,103 @@ void Iteration::set_instance_state(InstanceState state) {
|
|||
instance_state_ = state;
|
||||
MS_LOG(INFO) << "Server instance state is " << GetInstanceStateStr(instance_state_);
|
||||
}
|
||||
|
||||
void Iteration::SetFileConfig(const std::shared_ptr<ps::core::FileConfiguration> &file_configuration) {
|
||||
file_configuration_ = file_configuration;
|
||||
}
|
||||
|
||||
string Iteration::GetDataRateFilePath() {
|
||||
ps::core::FileConfig data_rate_config;
|
||||
if (!ps::core::CommUtil::ParseAndCheckConfigJson(file_configuration_.get(), kDataRate, &data_rate_config)) {
|
||||
MS_LOG(EXCEPTION) << "Data rate parament in config is not correct";
|
||||
}
|
||||
return data_rate_config.storage_file_path;
|
||||
}
|
||||
|
||||
void Iteration::StartThreadToRecordDataRate() {
|
||||
MS_LOG(INFO) << "Start to create a thread to record data rate";
|
||||
data_rate_thread_ = std::thread([&]() {
|
||||
std::fstream file_stream;
|
||||
std::string data_rate_path = GetDataRateFilePath();
|
||||
MS_LOG(DEBUG) << "The data rate file path is " << data_rate_path;
|
||||
uint32_t rank_id = server_node_->rank_id();
|
||||
while (is_date_rate_thread_running_) {
|
||||
// record data every 60 seconds
|
||||
std::this_thread::sleep_for(std::chrono::seconds(60));
|
||||
auto time_now = std::chrono::system_clock::now();
|
||||
std::time_t tt = std::chrono::system_clock::to_time_t(time_now);
|
||||
struct tm ptm;
|
||||
(void)localtime_r(&tt, &ptm);
|
||||
std::ostringstream time_day_oss;
|
||||
time_day_oss << std::put_time(&ptm, "%Y-%m-%d");
|
||||
std::string time_day = time_day_oss.str();
|
||||
std::string data_rate_file = data_rate_path + "/" + time_day + "_flow_server" + std::to_string(rank_id) + ".json";
|
||||
file_stream.open(data_rate_file, std::ios::out | std::ios::app);
|
||||
if (!file_stream.is_open()) {
|
||||
MS_LOG(WARNING) << data_rate_file << "is not open! Please check config file!";
|
||||
return;
|
||||
}
|
||||
std::map<uint64_t, size_t> send_datas;
|
||||
std::map<uint64_t, size_t> receive_datas;
|
||||
for (const auto &round : rounds_) {
|
||||
if (round == nullptr) {
|
||||
MS_LOG(WARNING) << "round is nullptr";
|
||||
continue;
|
||||
}
|
||||
auto send_data = round->GetSendData();
|
||||
for (const auto &it : send_data) {
|
||||
if (send_datas.find(it.first) != send_datas.end()) {
|
||||
send_datas[it.first] = send_datas[it.first] + it.second;
|
||||
} else {
|
||||
send_datas.emplace(it);
|
||||
}
|
||||
}
|
||||
auto receive_data = round->GetReceiveData();
|
||||
for (const auto &it : receive_data) {
|
||||
if (receive_datas.find(it.first) != receive_datas.end()) {
|
||||
receive_datas[it.first] = receive_datas[it.first] + it.second;
|
||||
} else {
|
||||
receive_datas.emplace(it);
|
||||
}
|
||||
}
|
||||
round->ClearData();
|
||||
}
|
||||
|
||||
std::map<uint64_t, std::vector<size_t>> all_datas;
|
||||
for (auto &it : send_datas) {
|
||||
std::vector<size_t> send_and_receive_data;
|
||||
send_and_receive_data.emplace_back(it.second);
|
||||
send_and_receive_data.emplace_back(0);
|
||||
all_datas.emplace(it.first, send_and_receive_data);
|
||||
}
|
||||
for (auto &it : receive_datas) {
|
||||
if (all_datas.find(it.first) != all_datas.end()) {
|
||||
std::vector<size_t> &temp = all_datas.at(it.first);
|
||||
temp[1] = it.second;
|
||||
} else {
|
||||
std::vector<size_t> send_and_receive_data;
|
||||
send_and_receive_data.emplace_back(0);
|
||||
send_and_receive_data.emplace_back(it.second);
|
||||
all_datas.emplace(it.first, send_and_receive_data);
|
||||
}
|
||||
}
|
||||
for (auto &it : all_datas) {
|
||||
nlohmann::json js;
|
||||
auto data_time = static_cast<time_t>(it.first);
|
||||
struct tm data_tm;
|
||||
(void)localtime_r(&data_time, &data_tm);
|
||||
std::ostringstream oss_second;
|
||||
oss_second << std::put_time(&data_tm, "%Y-%m-%d %H:%M:%S");
|
||||
js["time"] = oss_second.str();
|
||||
js["send"] = it.second[0];
|
||||
js["receive"] = it.second[1];
|
||||
file_stream << js << "\n";
|
||||
}
|
||||
(void)file_stream.close();
|
||||
}
|
||||
});
|
||||
return;
|
||||
}
|
||||
} // namespace server
|
||||
} // namespace fl
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -22,6 +22,7 @@
|
|||
#include <string>
|
||||
#include <map>
|
||||
#include "ps/core/communicator/communicator_base.h"
|
||||
#include "ps/core/file_configuration.h"
|
||||
#include "fl/server/common.h"
|
||||
#include "fl/server/round.h"
|
||||
#include "fl/server/local_meta_store.h"
|
||||
|
@ -134,14 +135,18 @@ class Iteration {
|
|||
|
||||
void set_instance_state(InstanceState staet);
|
||||
|
||||
// Create a thread to record date rate
|
||||
void StartThreadToRecordDataRate();
|
||||
|
||||
// Set file_configuration
|
||||
void SetFileConfig(const std::shared_ptr<ps::core::FileConfiguration> &file_configuration);
|
||||
|
||||
private:
|
||||
Iteration()
|
||||
: running_round_num_(0),
|
||||
server_node_(nullptr),
|
||||
communicator_(nullptr),
|
||||
iteration_state_(IterationState::kCompleted),
|
||||
start_timestamp_(0),
|
||||
complete_timestamp_(0),
|
||||
iteration_loop_count_(0),
|
||||
iteration_num_(1),
|
||||
is_last_iteration_valid_(true),
|
||||
|
@ -164,7 +169,11 @@ class Iteration {
|
|||
{kStartFLJobRejectClientNum, 0},
|
||||
{kUpdateModelRejectClientNum, 0},
|
||||
{kGetModelRejectClientNum, 0}}),
|
||||
iteration_result_(IterationResult::kSuccess) {
|
||||
iteration_result_(IterationResult::kSuccess),
|
||||
iteration_fail_num_(0),
|
||||
is_date_rate_thread_running_(true),
|
||||
file_configuration_(nullptr),
|
||||
instance_name_("") {
|
||||
LocalMetaStore::GetInstance().set_curr_iter_num(iteration_num_);
|
||||
}
|
||||
~Iteration();
|
||||
|
@ -223,6 +232,8 @@ class Iteration {
|
|||
|
||||
void StartNewInstance();
|
||||
|
||||
std::string GetDataRateFilePath();
|
||||
|
||||
std::shared_ptr<ps::core::ServerNode> server_node_;
|
||||
std::shared_ptr<ps::core::TcpCommunicator> communicator_;
|
||||
|
||||
|
@ -236,8 +247,12 @@ class Iteration {
|
|||
std::mutex iteration_state_mtx_;
|
||||
std::condition_variable iteration_state_cv_;
|
||||
std::atomic<IterationState> iteration_state_;
|
||||
uint64_t start_timestamp_;
|
||||
uint64_t complete_timestamp_;
|
||||
|
||||
// Iteration start time
|
||||
ps::core::Time start_time_;
|
||||
|
||||
// Iteration complete time
|
||||
ps::core::Time complete_time_;
|
||||
|
||||
// The count of iteration loops which are completed.
|
||||
size_t iteration_loop_count_;
|
||||
|
@ -295,6 +310,21 @@ class Iteration {
|
|||
|
||||
// mutex for iter move to next, avoid core dump
|
||||
std::mutex iter_move_mtx_;
|
||||
|
||||
// The number of iteration continuous failure
|
||||
uint32_t iteration_fail_num_;
|
||||
|
||||
// The thread to record data rate
|
||||
std::thread data_rate_thread_;
|
||||
|
||||
// The state of data rate thread
|
||||
std::atomic_bool is_date_rate_thread_running_;
|
||||
|
||||
// The ptr of config file
|
||||
std::shared_ptr<ps::core::FileConfiguration> file_configuration_;
|
||||
|
||||
// The instance name
|
||||
std::string instance_name_;
|
||||
};
|
||||
} // namespace server
|
||||
} // namespace fl
|
||||
|
|
|
@ -15,11 +15,13 @@
|
|||
*/
|
||||
|
||||
#include "fl/server/iteration_metrics.h"
|
||||
#include <string>
|
||||
|
||||
#include <fstream>
|
||||
#include "utils/file_utils.h"
|
||||
#include <string>
|
||||
|
||||
#include "include/common/debug/common.h"
|
||||
#include "ps/constants.h"
|
||||
#include "utils/file_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace fl {
|
||||
|
@ -28,42 +30,23 @@ bool IterationMetrics::Initialize() {
|
|||
config_ = std::make_unique<ps::core::FileConfiguration>(config_file_path_);
|
||||
MS_EXCEPTION_IF_NULL(config_);
|
||||
if (!config_->Initialize()) {
|
||||
MS_LOG(EXCEPTION) << "Initializing for metrics failed. Config file path " << config_file_path_
|
||||
MS_LOG(EXCEPTION) << "Initializing for Config file path failed!" << config_file_path_
|
||||
<< " may be invalid or not exist.";
|
||||
}
|
||||
|
||||
// Read the metrics file path. If file is not set or not exits, create one.
|
||||
if (!config_->Exists(kMetrics)) {
|
||||
MS_LOG(WARNING) << "Metrics config is not set. Don't write metrics.";
|
||||
return false;
|
||||
} else {
|
||||
std::string value = config_->Get(kMetrics, "");
|
||||
nlohmann::json value_json;
|
||||
try {
|
||||
value_json = nlohmann::json::parse(value);
|
||||
} catch (const std::exception &e) {
|
||||
MS_LOG(EXCEPTION) << "The hyper-parameter data is not in json format.";
|
||||
return false;
|
||||
}
|
||||
|
||||
// Parse the storage type.
|
||||
uint32_t storage_type = JsonGetKeyWithException<uint32_t>(value_json, ps::kStoreType);
|
||||
if (std::to_string(storage_type) != ps::kFileStorage) {
|
||||
MS_LOG(EXCEPTION) << "Storage type " << storage_type << " is not supported.";
|
||||
return false;
|
||||
}
|
||||
|
||||
// Parse storage file path.
|
||||
metrics_file_path_ = JsonGetKeyWithException<std::string>(value_json, ps::kStoreFilePath);
|
||||
auto realpath = Common::CreatePrefixPath(metrics_file_path_.c_str());
|
||||
if (!realpath.has_value()) {
|
||||
MS_LOG(EXCEPTION) << "Creating path for " << metrics_file_path_ << " failed.";
|
||||
return false;
|
||||
}
|
||||
|
||||
metrics_file_.open(realpath.value(), std::ios::app | std::ios::out);
|
||||
metrics_file_.close();
|
||||
}
|
||||
ps::core::FileConfig metrics_config;
|
||||
if (!ps::core::CommUtil::ParseAndCheckConfigJson(config_.get(), kMetrics, &metrics_config)) {
|
||||
MS_LOG(WARNING) << "Metrics parament in config is not correct";
|
||||
return false;
|
||||
}
|
||||
metrics_file_path_ = metrics_config.storage_file_path;
|
||||
auto realpath = Common::CreatePrefixPath(metrics_file_path_.c_str());
|
||||
if (!realpath.has_value()) {
|
||||
MS_LOG(EXCEPTION) << "Creating path for " << metrics_file_path_ << " failed.";
|
||||
return false;
|
||||
}
|
||||
metrics_file_.open(realpath.value(), std::ios::app | std::ios::out);
|
||||
metrics_file_.close();
|
||||
return true;
|
||||
}
|
||||
|
||||
|
@ -74,6 +57,9 @@ bool IterationMetrics::Summarize() {
|
|||
return false;
|
||||
}
|
||||
|
||||
js_[kInstanceName] = instance_name_;
|
||||
js_[kStartTime] = start_time_.time_str_mill;
|
||||
js_[kEndTime] = end_time_.time_str_mill;
|
||||
js_[kFLName] = fl_name_;
|
||||
js_[kInstanceStatus] = kInstanceStateName.at(instance_state_);
|
||||
js_[kFLIterationNum] = fl_iteration_num_;
|
||||
|
@ -120,6 +106,12 @@ void IterationMetrics::set_round_client_num_map(const std::map<std::string, size
|
|||
}
|
||||
|
||||
void IterationMetrics::set_iteration_result(IterationResult iteration_result) { iteration_result_ = iteration_result; }
|
||||
|
||||
void IterationMetrics::set_start_time(const ps::core::Time &start_time) { start_time_ = start_time; }
|
||||
|
||||
void IterationMetrics::set_end_time(const ps::core::Time &end_time) { end_time_ = end_time; }
|
||||
|
||||
void IterationMetrics::set_instance_name(const std::string &instance_name) { instance_name_ = instance_name; }
|
||||
} // namespace server
|
||||
} // namespace fl
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -26,10 +26,12 @@
|
|||
#include "ps/core/file_configuration.h"
|
||||
#include "fl/server/local_meta_store.h"
|
||||
#include "fl/server/iteration.h"
|
||||
#include "ps/core/comm_util.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace fl {
|
||||
namespace server {
|
||||
constexpr auto kInstanceName = "instanceName";
|
||||
constexpr auto kFLName = "flName";
|
||||
constexpr auto kInstanceStatus = "instanceStatus";
|
||||
constexpr auto kFLIterationNum = "flIterationNum";
|
||||
|
@ -40,6 +42,9 @@ constexpr auto kIterExecutionTime = "iterationExecutionTime";
|
|||
constexpr auto kMetrics = "metrics";
|
||||
constexpr auto kClientVisitedInfo = "clientVisitedInfo";
|
||||
constexpr auto kIterationResult = "iterationResult";
|
||||
constexpr auto kStartTime = "startTime";
|
||||
constexpr auto kEndTime = "endTime";
|
||||
constexpr auto kDataRate = "dataRate";
|
||||
|
||||
const std::map<InstanceState, std::string> kInstanceStateName = {
|
||||
{InstanceState::kRunning, "running"}, {InstanceState::kDisable, "disable"}, {InstanceState::kFinish, "finish"}};
|
||||
|
@ -80,6 +85,9 @@ class IterationMetrics {
|
|||
void set_iteration_time_cost(uint64_t iteration_time_cost);
|
||||
void set_round_client_num_map(const std::map<std::string, size_t> round_client_num_map);
|
||||
void set_iteration_result(IterationResult iteration_result);
|
||||
void set_start_time(const ps::core::Time &start_time);
|
||||
void set_end_time(const ps::core::Time &end_time);
|
||||
void set_instance_name(const std::string &instance_name);
|
||||
|
||||
private:
|
||||
// This is the main config file set by ps context.
|
||||
|
@ -122,6 +130,11 @@ class IterationMetrics {
|
|||
|
||||
// Current iteration running result.
|
||||
IterationResult iteration_result_;
|
||||
|
||||
ps::core::Time start_time_;
|
||||
ps::core::Time end_time_;
|
||||
|
||||
std::string instance_name_;
|
||||
};
|
||||
} // namespace server
|
||||
} // namespace fl
|
||||
|
|
|
@ -15,13 +15,15 @@
|
|||
*/
|
||||
|
||||
#include "fl/server/kernel/round/round_kernel.h"
|
||||
|
||||
#include <chrono>
|
||||
#include <mutex>
|
||||
#include <queue>
|
||||
#include <chrono>
|
||||
#include <string>
|
||||
#include <thread>
|
||||
#include <utility>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "fl/server/iteration.h"
|
||||
|
||||
namespace mindspore {
|
||||
|
@ -65,6 +67,7 @@ void RoundKernel::SendResponseMsg(const std::shared_ptr<ps::core::MessageHandler
|
|||
MS_LOG(WARNING) << "Sending response failed.";
|
||||
return;
|
||||
}
|
||||
CalculateSendData(len);
|
||||
}
|
||||
|
||||
void RoundKernel::SendResponseMsgInference(const std::shared_ptr<ps::core::MessageHandler> &message, const void *data,
|
||||
|
@ -77,6 +80,7 @@ void RoundKernel::SendResponseMsgInference(const std::shared_ptr<ps::core::Messa
|
|||
MS_LOG(WARNING) << "Sending response failed.";
|
||||
return;
|
||||
}
|
||||
CalculateSendData(len);
|
||||
}
|
||||
|
||||
bool RoundKernel::verifyResponse(const std::shared_ptr<ps::core::MessageHandler> &message, const void *data,
|
||||
|
@ -128,6 +132,67 @@ void RoundKernel::InitClientUploadLoss() { upload_loss_ = 0.0f; }
|
|||
void RoundKernel::UpdateClientUploadLoss(const float upload_loss) { upload_loss_ = upload_loss_ + upload_loss; }
|
||||
|
||||
float RoundKernel::upload_loss() const { return upload_loss_; }
|
||||
|
||||
void RoundKernel::CalculateSendData(size_t send_len) {
|
||||
uint64_t second_time_stamp =
|
||||
std::chrono::duration_cast<std::chrono::seconds>(std::chrono::system_clock::now().time_since_epoch()).count();
|
||||
if (send_data_time_ == 0) {
|
||||
send_data_ = send_len;
|
||||
send_data_time_ = second_time_stamp;
|
||||
return;
|
||||
}
|
||||
if (second_time_stamp == send_data_time_) {
|
||||
send_data_ += send_len;
|
||||
} else {
|
||||
RecordSendData(send_data_time_, send_data_);
|
||||
send_data_time_ = second_time_stamp;
|
||||
send_data_ = send_len;
|
||||
}
|
||||
}
|
||||
|
||||
void RoundKernel::CalculateReceiveData(size_t receive_len) {
|
||||
uint64_t second_time_stamp =
|
||||
std::chrono::duration_cast<std::chrono::seconds>(std::chrono::system_clock::now().time_since_epoch()).count();
|
||||
if (receive_data_time_ == 0) {
|
||||
receive_data_time_ = second_time_stamp;
|
||||
receive_data_ = receive_len;
|
||||
return;
|
||||
}
|
||||
if (second_time_stamp == receive_data_time_) {
|
||||
receive_data_ += receive_len;
|
||||
} else {
|
||||
RecordReceiveData(receive_data_time_, receive_data_);
|
||||
receive_data_time_ = second_time_stamp;
|
||||
receive_data_ = receive_len;
|
||||
}
|
||||
}
|
||||
|
||||
void RoundKernel::RecordSendData(uint64_t time_stamp_second, size_t send_data) {
|
||||
std::lock_guard<std::mutex> lock(send_data_rate_mutex_);
|
||||
send_data_and_time_.emplace(time_stamp_second, send_data);
|
||||
}
|
||||
|
||||
void RoundKernel::RecordReceiveData(uint64_t time_stamp_second, size_t receive_data) {
|
||||
std::lock_guard<std::mutex> lock(receive_data_rate_mutex_);
|
||||
receive_data_and_time_.emplace(time_stamp_second, receive_data);
|
||||
}
|
||||
|
||||
std::map<uint64_t, size_t> RoundKernel::GetSendData() {
|
||||
std::lock_guard<std::mutex> lock(send_data_rate_mutex_);
|
||||
return send_data_and_time_;
|
||||
}
|
||||
|
||||
std::map<uint64_t, size_t> RoundKernel::GetReceiveData() {
|
||||
std::lock_guard<std::mutex> lock(receive_data_rate_mutex_);
|
||||
return receive_data_and_time_;
|
||||
}
|
||||
|
||||
void RoundKernel::ClearData() {
|
||||
std::lock_guard<std::mutex> lock(send_data_rate_mutex_);
|
||||
std::lock_guard<std::mutex> lock2(receive_data_rate_mutex_);
|
||||
send_data_and_time_.clear();
|
||||
receive_data_and_time_.clear();
|
||||
}
|
||||
} // namespace kernel
|
||||
} // namespace server
|
||||
} // namespace fl
|
||||
|
|
|
@ -17,22 +17,23 @@
|
|||
#ifndef MINDSPORE_CCSRC_FL_SERVER_KERNEL_ROUND_ROUND_KERNEL_H_
|
||||
#define MINDSPORE_CCSRC_FL_SERVER_KERNEL_ROUND_ROUND_KERNEL_H_
|
||||
|
||||
#include <chrono>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <mutex>
|
||||
#include <queue>
|
||||
#include <utility>
|
||||
#include <chrono>
|
||||
#include <string>
|
||||
#include <thread>
|
||||
#include <unordered_map>
|
||||
#include "kernel/common_utils.h"
|
||||
#include "plugin/device/cpu/kernel/cpu_kernel.h"
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "fl/server/common.h"
|
||||
#include "fl/server/local_meta_store.h"
|
||||
#include "fl/server/distributed_count_service.h"
|
||||
#include "fl/server/distributed_metadata_store.h"
|
||||
#include "fl/server/local_meta_store.h"
|
||||
#include "kernel/common_utils.h"
|
||||
#include "plugin/device/cpu/kernel/cpu_kernel.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace fl {
|
||||
|
@ -102,6 +103,24 @@ class RoundKernel {
|
|||
|
||||
bool verifyResponse(const std::shared_ptr<ps::core::MessageHandler> &message, const void *data, size_t len);
|
||||
|
||||
void CalculateSendData(size_t send_len);
|
||||
|
||||
void CalculateReceiveData(size_t receive_len);
|
||||
// Record the size of send data and the time stamp
|
||||
void RecordSendData(uint64_t time_stamp_second, size_t send_data);
|
||||
|
||||
// Record the size of receive data and the time stamp
|
||||
void RecordReceiveData(uint64_t time_stamp_second, size_t receive_data);
|
||||
|
||||
// Get the info of send data
|
||||
std::map<uint64_t, size_t> GetSendData();
|
||||
|
||||
// Get the info of receive data
|
||||
std::map<uint64_t, size_t> GetReceiveData();
|
||||
|
||||
// Clear the send data info
|
||||
void ClearData();
|
||||
|
||||
protected:
|
||||
// Send response to client, and the data can be released after the call.
|
||||
void SendResponseMsg(const std::shared_ptr<ps::core::MessageHandler> &message, const void *data, size_t len);
|
||||
|
@ -127,6 +146,26 @@ class RoundKernel {
|
|||
std::atomic<size_t> accept_client_num_;
|
||||
|
||||
std::atomic<float> upload_loss_;
|
||||
|
||||
// The mutex for send_data_and_time_
|
||||
std::mutex send_data_rate_mutex_;
|
||||
|
||||
// The size of send data ant time
|
||||
std::map<uint64_t, size_t> send_data_and_time_;
|
||||
|
||||
// The mutex for receive_data_and_time_
|
||||
std::mutex receive_data_rate_mutex_;
|
||||
|
||||
// The size of receive data and time
|
||||
std::map<uint64_t, size_t> receive_data_and_time_;
|
||||
|
||||
std::atomic_size_t send_data_ = 0;
|
||||
|
||||
std::atomic_uint64_t send_data_time_ = 0;
|
||||
|
||||
std::atomic_size_t receive_data_ = 0;
|
||||
|
||||
std::atomic_uint64_t receive_data_time_ = 0;
|
||||
};
|
||||
} // namespace kernel
|
||||
} // namespace server
|
||||
|
|
|
@ -115,6 +115,9 @@ bool StartFLJobKernel::Launch(const uint8_t *req_data, size_t len,
|
|||
}
|
||||
|
||||
DeviceMeta device_meta = CreateDeviceMetadata(start_fl_job_req);
|
||||
uint64_t start_fl_job_time =
|
||||
std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::system_clock::now().time_since_epoch()).count();
|
||||
device_meta.set_now_time(start_fl_job_time);
|
||||
result_code = ReadyForStartFLJob(fbb, device_meta);
|
||||
if (result_code != ResultCode::kSuccess) {
|
||||
SendResponseMsg(message, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
|
|
|
@ -14,17 +14,26 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "fl/server/kernel/round/update_model_kernel.h"
|
||||
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <utility>
|
||||
#include "fl/server/kernel/round/update_model_kernel.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace fl {
|
||||
namespace server {
|
||||
namespace kernel {
|
||||
namespace {
|
||||
const size_t kLevelNum = 2;
|
||||
const uint64_t kMaxLevelNum = 2880;
|
||||
const uint64_t kMinLevelNum = 0;
|
||||
const int kBase = 10;
|
||||
const uint64_t kMinuteToSecond = 60;
|
||||
const uint64_t kSecondToMills = 1000;
|
||||
const uint64_t kDefaultLevel1 = 5;
|
||||
const uint64_t kDefaultLevel2 = 15;
|
||||
} // namespace
|
||||
const char *kCountForAggregation = "count_for_aggregation";
|
||||
|
||||
void UpdateModelKernel::InitKernel(size_t threshold_count) {
|
||||
|
@ -49,6 +58,8 @@ void UpdateModelKernel::InitKernel(size_t threshold_count) {
|
|||
auto last_cnt_handler = [this](std::shared_ptr<ps::core::MessageHandler>) { RunAggregation(); };
|
||||
DistributedCountService::GetInstance().RegisterCounter(kCountForAggregation, threshold_count,
|
||||
{first_cnt_handler, last_cnt_handler});
|
||||
std::string participation_time_level_str = ps::PSContext::instance()->participation_time_level();
|
||||
CheckAndTransPara(participation_time_level_str);
|
||||
}
|
||||
|
||||
bool UpdateModelKernel::VerifyUpdateModelRequest(const schema::RequestUpdateModel *update_model_req) {
|
||||
|
@ -123,6 +134,7 @@ bool UpdateModelKernel::Launch(const uint8_t *req_data, size_t len,
|
|||
}
|
||||
std::string update_model_fl_id = update_model_req->fl_id()->str();
|
||||
IncreaseAcceptClientNum();
|
||||
RecordCompletePeriod(device_meta);
|
||||
SendResponseMsg(message, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
|
||||
result_code = CountForAggregation(update_model_fl_id);
|
||||
|
@ -146,6 +158,18 @@ bool UpdateModelKernel::Reset() {
|
|||
|
||||
void UpdateModelKernel::OnLastCountEvent(const std::shared_ptr<ps::core::MessageHandler> &) {}
|
||||
|
||||
const std::vector<std::pair<uint64_t, uint32_t>> &UpdateModelKernel::GetCompletePeriodRecord() {
|
||||
std::lock_guard<std::mutex> lock(participation_time_and_num_mtx_);
|
||||
return participation_time_and_num_;
|
||||
}
|
||||
|
||||
void UpdateModelKernel::ResetParticipationTimeAndNum() {
|
||||
std::lock_guard<std::mutex> lock(participation_time_and_num_mtx_);
|
||||
for (auto &it : participation_time_and_num_) {
|
||||
it.second = 0;
|
||||
}
|
||||
}
|
||||
|
||||
void UpdateModelKernel::RunAggregation() {
|
||||
auto is_last_iter_valid = Executor::GetInstance().RunAllWeightAggregation();
|
||||
auto curr_iter_num = LocalMetaStore::GetInstance().curr_iter_num();
|
||||
|
@ -619,6 +643,66 @@ void UpdateModelKernel::BuildUpdateModelRsp(const std::shared_ptr<FBBuilder> &fb
|
|||
return;
|
||||
}
|
||||
|
||||
void UpdateModelKernel::RecordCompletePeriod(const DeviceMeta &device_meta) {
|
||||
std::lock_guard<std::mutex> lock(participation_time_and_num_mtx_);
|
||||
uint64_t start_fl_job_time = device_meta.now_time();
|
||||
uint64_t update_model_complete_time =
|
||||
std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::system_clock::now().time_since_epoch()).count();
|
||||
if (start_fl_job_time >= update_model_complete_time) {
|
||||
MS_LOG(WARNING) << "start_fl_job_time " << start_fl_job_time << " is larger than update_model_complete_time "
|
||||
<< update_model_complete_time;
|
||||
return;
|
||||
}
|
||||
uint64_t cost_time = update_model_complete_time - start_fl_job_time;
|
||||
MS_LOG(DEBUG) << "start_fl_job time is " << start_fl_job_time << " update_model time is "
|
||||
<< update_model_complete_time;
|
||||
for (auto &it : participation_time_and_num_) {
|
||||
if (cost_time < it.first) {
|
||||
it.second++;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void UpdateModelKernel::CheckAndTransPara(const std::string &participation_time_level) {
|
||||
std::lock_guard<std::mutex> lock(participation_time_and_num_mtx_);
|
||||
// The default time level is 5min and 15min, trans time to millisecond
|
||||
participation_time_and_num_.emplace_back(std::make_pair(kDefaultLevel1 * kMinuteToSecond * kSecondToMills, 0));
|
||||
participation_time_and_num_.emplace_back(std::make_pair(kDefaultLevel2 * kMinuteToSecond * kSecondToMills, 0));
|
||||
participation_time_and_num_.emplace_back(std::make_pair(UINT64_MAX, 0));
|
||||
std::vector<std::string> time_levels;
|
||||
std::istringstream iss(participation_time_level);
|
||||
std::string output;
|
||||
while (std::getline(iss, output, ',')) {
|
||||
if (!output.empty()) {
|
||||
time_levels.emplace_back(std::move(output));
|
||||
}
|
||||
}
|
||||
if (time_levels.size() != kLevelNum) {
|
||||
MS_LOG(WARNING) << "Parameter participation_time_level is not correct";
|
||||
return;
|
||||
}
|
||||
uint64_t level1 = std::strtoull(time_levels[0].c_str(), nullptr, kBase);
|
||||
if (level1 > kMaxLevelNum || level1 <= kMinLevelNum) {
|
||||
MS_LOG(WARNING) << "Level1 partmeter " << level1 << " is not legal";
|
||||
return;
|
||||
}
|
||||
|
||||
uint64_t level2 = std::strtoull(time_levels[1].c_str(), nullptr, kBase);
|
||||
if (level2 > kMaxLevelNum || level2 <= kMinLevelNum) {
|
||||
MS_LOG(WARNING) << "Level2 partmeter " << level2 << "is not legal";
|
||||
return;
|
||||
}
|
||||
if (level1 >= level2) {
|
||||
MS_LOG(WARNING) << "Level1 parameter " << level1 << " is larger than level2 " << level2;
|
||||
return;
|
||||
}
|
||||
// Save the the parament of user
|
||||
participation_time_and_num_.clear();
|
||||
participation_time_and_num_.emplace_back(std::make_pair(level1 * kMinuteToSecond * kSecondToMills, 0));
|
||||
participation_time_and_num_.emplace_back(std::make_pair(level2 * kMinuteToSecond * kSecondToMills, 0));
|
||||
participation_time_and_num_.emplace_back(std::make_pair(UINT64_MAX, 0));
|
||||
}
|
||||
|
||||
REG_ROUND_KERNEL(updateModel, UpdateModelKernel)
|
||||
} // namespace kernel
|
||||
} // namespace server
|
||||
|
|
|
@ -18,21 +18,23 @@
|
|||
#define MINDSPORE_CCSRC_FL_SERVER_KERNEL_UPDATE_MODEL_KERNEL_H_
|
||||
|
||||
#include <map>
|
||||
#include <unordered_map>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
#include <utility>
|
||||
|
||||
#include "fl/server/common.h"
|
||||
#include "fl/server/executor.h"
|
||||
#include "fl/server/kernel/round/round_kernel.h"
|
||||
#include "fl/server/kernel/round/round_kernel_factory.h"
|
||||
#include "fl/server/executor.h"
|
||||
#include "fl/server/model_store.h"
|
||||
#ifdef ENABLE_ARMOUR
|
||||
#include "fl/armour/cipher/cipher_meta_storage.h"
|
||||
#endif
|
||||
#include "fl/compression/decode_executor.h"
|
||||
#include "schema/fl_job_generated.h"
|
||||
#include "schema/cipher_generated.h"
|
||||
#include "schema/fl_job_generated.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace fl {
|
||||
|
@ -55,6 +57,12 @@ class UpdateModelKernel : public RoundKernel {
|
|||
// In some cases, the last updateModel message means this server iteration is finished.
|
||||
void OnLastCountEvent(const std::shared_ptr<ps::core::MessageHandler> &message) override;
|
||||
|
||||
// Get participation_time_and_num_
|
||||
const std::vector<std::pair<uint64_t, uint32_t>> &GetCompletePeriodRecord();
|
||||
|
||||
// Reset participation_time_and_num_
|
||||
void ResetParticipationTimeAndNum();
|
||||
|
||||
private:
|
||||
ResultCode ReachThresholdForUpdateModel(const std::shared_ptr<FBBuilder> &fbb,
|
||||
const schema::RequestUpdateModel *update_model_req);
|
||||
|
@ -82,6 +90,12 @@ class UpdateModelKernel : public RoundKernel {
|
|||
const std::shared_ptr<FBBuilder> &fbb, DeviceMeta *device_meta);
|
||||
bool VerifyUpdateModelRequest(const schema::RequestUpdateModel *update_model_req);
|
||||
|
||||
// Record complete update model number according to participation_time_level
|
||||
void RecordCompletePeriod(const DeviceMeta &device_meta);
|
||||
|
||||
// Check and transform participation time level parament
|
||||
void CheckAndTransPara(const std::string &participation_time_level);
|
||||
|
||||
// The executor is for updating the model for updateModel request.
|
||||
Executor *executor_{nullptr};
|
||||
|
||||
|
@ -95,6 +109,12 @@ class UpdateModelKernel : public RoundKernel {
|
|||
|
||||
// Check upload mode
|
||||
bool IsCompress(const schema::RequestUpdateModel *update_model_req);
|
||||
|
||||
// From StartFlJob to UpdateModel complete time and number
|
||||
std::vector<std::pair<uint64_t, uint32_t>> participation_time_and_num_{};
|
||||
|
||||
// The mutex for participation_time_and_num_
|
||||
std::mutex participation_time_and_num_mtx_;
|
||||
};
|
||||
} // namespace kernel
|
||||
} // namespace server
|
||||
|
|
|
@ -15,10 +15,14 @@
|
|||
*/
|
||||
|
||||
#include "fl/server/round.h"
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include "fl/server/server.h"
|
||||
|
||||
#include "fl/server/iteration.h"
|
||||
#include "fl/server/kernel/round/update_model_kernel.h"
|
||||
#include "fl/server/server.h"
|
||||
#include "fl/server/common.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace fl {
|
||||
|
@ -143,6 +147,7 @@ void Round::LaunchRoundKernel(const std::shared_ptr<ps::core::MessageHandler> &m
|
|||
MS_LOG(DEBUG) << "Launching round kernel of round " + name_ + " failed.";
|
||||
}
|
||||
(void)(Iteration::GetInstance().running_round_num_--);
|
||||
kernel_->CalculateReceiveData(message->len());
|
||||
return;
|
||||
}
|
||||
|
||||
|
@ -245,6 +250,32 @@ void Round::InitkernelClientVisitedNum() { kernel_->InitClientVisitedNum(); }
|
|||
void Round::InitkernelClientUploadLoss() { kernel_->InitClientUploadLoss(); }
|
||||
|
||||
float Round::kernel_upload_loss() const { return kernel_->upload_loss(); }
|
||||
|
||||
std::vector<std::pair<uint64_t, uint32_t>> Round::GetUpdateModelCompleteInfo() const {
|
||||
if (name_ == kUpdateModel) {
|
||||
auto update_model_model_ptr = std::dynamic_pointer_cast<fl::server::kernel::UpdateModelKernel>(kernel_);
|
||||
MS_EXCEPTION_IF_NULL(update_model_model_ptr);
|
||||
return update_model_model_ptr->GetCompletePeriodRecord();
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "The kernel is not updateModel";
|
||||
return {};
|
||||
}
|
||||
}
|
||||
|
||||
void Round::ResetParticipationTimeAndNum() {
|
||||
if (name_ == kUpdateModel) {
|
||||
auto update_model_kernel_ptr = std::dynamic_pointer_cast<fl::server::kernel::UpdateModelKernel>(kernel_);
|
||||
MS_ERROR_IF_NULL_WO_RET_VAL(update_model_kernel_ptr);
|
||||
update_model_kernel_ptr->ResetParticipationTimeAndNum();
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
std::map<uint64_t, size_t> Round::GetSendData() const { return kernel_->GetSendData(); }
|
||||
|
||||
std::map<uint64_t, size_t> Round::GetReceiveData() const { return kernel_->GetReceiveData(); }
|
||||
|
||||
void Round::ClearData() { return kernel_->ClearData(); }
|
||||
} // namespace server
|
||||
} // namespace fl
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -19,11 +19,15 @@
|
|||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include "ps/core/communicator/communicator_base.h"
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
#include <map>
|
||||
|
||||
#include "fl/server/common.h"
|
||||
#include "fl/server/iteration_timer.h"
|
||||
#include "fl/server/distributed_count_service.h"
|
||||
#include "fl/server/iteration_timer.h"
|
||||
#include "fl/server/kernel/round/round_kernel.h"
|
||||
#include "ps/core/communicator/communicator_base.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace fl {
|
||||
|
@ -78,6 +82,16 @@ class Round {
|
|||
|
||||
float kernel_upload_loss() const;
|
||||
|
||||
std::map<uint64_t, size_t> GetSendData() const;
|
||||
|
||||
std::map<uint64_t, size_t> GetReceiveData() const;
|
||||
|
||||
std::vector<std::pair<uint64_t, uint32_t>> GetUpdateModelCompleteInfo() const;
|
||||
|
||||
void ResetParticipationTimeAndNum();
|
||||
|
||||
void ClearData();
|
||||
|
||||
private:
|
||||
// The callbacks which will be set to DistributedCounterService.
|
||||
void OnFirstCountEvent(const std::shared_ptr<ps::core::MessageHandler> &message);
|
||||
|
|
|
@ -91,6 +91,7 @@ void Server::Run() {
|
|||
RegisterRoundKernel();
|
||||
InitMetrics();
|
||||
Recover();
|
||||
iteration_->StartThreadToRecordDataRate();
|
||||
MS_LOG(INFO) << "Server started successfully.";
|
||||
safemode_ = false;
|
||||
is_ready_ = true;
|
||||
|
@ -261,6 +262,14 @@ void Server::InitIteration() {
|
|||
iteration_->InitRounds(communicators_with_worker_, time_out_cb, finish_iter_cb);
|
||||
|
||||
iteration_->InitGlobalIterTimer(time_out_cb);
|
||||
auto file_config_ptr = std::make_shared<ps::core::FileConfiguration>(ps::PSContext::instance()->config_file_path());
|
||||
MS_EXCEPTION_IF_NULL(file_config_ptr);
|
||||
if (!file_config_ptr->Initialize()) {
|
||||
MS_LOG(WARNING) << "Initializing for Config file path failed!" << ps::PSContext::instance()->config_file_path()
|
||||
<< " may be invalid or not exist.";
|
||||
return;
|
||||
}
|
||||
iteration_->SetFileConfig(file_config_ptr);
|
||||
return;
|
||||
}
|
||||
|
||||
|
|
|
@ -551,7 +551,13 @@ PYBIND11_MODULE(_c_expression, m) {
|
|||
.def("set_download_compress_type", &PSContext::set_download_compress_type, "Set download compress type.")
|
||||
.def("download_compress_type", &PSContext::download_compress_type, "Get download compress type.")
|
||||
.def("set_checkpoint_dir", &PSContext::set_checkpoint_dir, "Set server checkpoint directory.")
|
||||
.def("checkpoint_dir", &PSContext::checkpoint_dir, "Server checkpoint directory.");
|
||||
.def("checkpoint_dir", &PSContext::checkpoint_dir, "Server checkpoint directory.")
|
||||
.def("set_instance_name", &PSContext::set_instance_name, "Set instance name.")
|
||||
.def("instance_name", &PSContext::instance_name, "Get instance name.")
|
||||
.def("set_participation_time_level", &PSContext::set_participation_time_level, "Set participation time level.")
|
||||
.def("participation_time_level", &PSContext::participation_time_level, "Get participation time level.")
|
||||
.def("set_continuous_failure_times", &PSContext::set_continuous_failure_times, "Set continuous failure times")
|
||||
.def("continuous_failure_times", &PSContext::continuous_failure_times, "Get continuous failure times.");
|
||||
(void)m.def("_encrypt", &mindspore::pipeline::PyEncrypt, "Encrypt the data.");
|
||||
(void)m.def("_decrypt", &mindspore::pipeline::PyDecrypt, "Decrypt the data.");
|
||||
(void)m.def("_is_cipher_file", &mindspore::pipeline::PyIsCipherFile, "Determine whether the file is encrypted");
|
||||
|
|
|
@ -15,9 +15,11 @@
|
|||
*/
|
||||
|
||||
#include "ps/core/abstract_node.h"
|
||||
#include "ps/core/node_recovery.h"
|
||||
#include "ps/core/communicator/tcp_communicator.h"
|
||||
|
||||
#include "include/common/debug/common.h"
|
||||
#include "ps/core/communicator/http_communicator.h"
|
||||
#include "ps/core/communicator/tcp_communicator.h"
|
||||
#include "ps/core/node_recovery.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
|
@ -68,6 +70,32 @@ void AbstractNode::Register(const std::shared_ptr<TcpClient> &client) {
|
|||
}
|
||||
}
|
||||
|
||||
void AbstractNode::SendFailMessageToScheduler(const std::string &node_role, const std::string &event_info) {
|
||||
auto message_meta = std::make_shared<MessageMeta>();
|
||||
MS_EXCEPTION_IF_NULL(message_meta);
|
||||
message_meta->set_cmd(NodeCommand::FAILURE_EVENT_INFO);
|
||||
|
||||
std::string now_time = ps::core::CommUtil::GetNowTime().time_str_mill;
|
||||
FailureEventMessage failure_event_message;
|
||||
failure_event_message.set_node_role(node_role);
|
||||
failure_event_message.set_ip(node_info_.ip_);
|
||||
failure_event_message.set_port(node_info_.port_);
|
||||
failure_event_message.set_time(now_time);
|
||||
failure_event_message.set_event(event_info);
|
||||
|
||||
MS_LOG(INFO) << "The node role:" << node_role << "The node id:" << node_info_.node_id_
|
||||
<< "begin to send failure message to scheduler!";
|
||||
|
||||
if (!SendMessageAsync(client_to_scheduler_, message_meta, Protos::PROTOBUF,
|
||||
failure_event_message.SerializeAsString().data(), failure_event_message.ByteSizeLong())) {
|
||||
MS_LOG(ERROR) << "The node role:" << node_role << " the node id:" << node_info_.node_id_
|
||||
<< " send failure message timeout!";
|
||||
} else {
|
||||
MS_LOG(INFO) << "The node role:" << node_role << " the node id:" << node_info_.node_id_ << " send failure message "
|
||||
<< event_info << "success!";
|
||||
}
|
||||
}
|
||||
|
||||
void AbstractNode::ProcessRegisterResp(const std::shared_ptr<MessageMeta> &meta, const void *data, size_t size) {
|
||||
MS_EXCEPTION_IF_NULL(meta);
|
||||
MS_EXCEPTION_IF_NULL(data);
|
||||
|
|
|
@ -17,25 +17,25 @@
|
|||
#ifndef MINDSPORE_CCSRC_PS_CORE_ABSTRACT_NODE_H_
|
||||
#define MINDSPORE_CCSRC_PS_CORE_ABSTRACT_NODE_H_
|
||||
|
||||
#include <utility>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <functional>
|
||||
#include <map>
|
||||
#include <vector>
|
||||
#include <queue>
|
||||
#include <unordered_map>
|
||||
#include <functional>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
#include "ps/core/node.h"
|
||||
#include "ps/core/communicator/message.h"
|
||||
#include "ps/core/follower_scaler.h"
|
||||
#include "utils/ms_exception.h"
|
||||
#include "include/backend/visible.h"
|
||||
#include "ps/constants.h"
|
||||
#include "ps/core/communicator/communicator_base.h"
|
||||
#include "ps/core/communicator/message.h"
|
||||
#include "ps/core/communicator/task_executor.h"
|
||||
#include "ps/core/follower_scaler.h"
|
||||
#include "ps/core/node.h"
|
||||
#include "ps/core/node_info.h"
|
||||
#include "ps/core/recovery_base.h"
|
||||
#include "ps/core/communicator/task_executor.h"
|
||||
#include "ps/core/communicator/communicator_base.h"
|
||||
#include "include/backend/visible.h"
|
||||
#include "utils/ms_exception.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
|
@ -163,6 +163,9 @@ class BACKEND_EXPORT AbstractNode : public Node {
|
|||
// register cancel SafeMode function to node
|
||||
void SetCancelSafeModeCallBack(const CancelSafeModeFn &fn) { cancelSafeModeFn_ = fn; }
|
||||
|
||||
// server node and worker node send exception message to scheduler
|
||||
void SendFailMessageToScheduler(const std::string &node_role, const std::string &event_info);
|
||||
|
||||
protected:
|
||||
virtual void Register(const std::shared_ptr<TcpClient> &client);
|
||||
bool Heartbeat(const std::shared_ptr<TcpClient> &client);
|
||||
|
|
|
@ -17,13 +17,15 @@
|
|||
#include "ps/core/comm_util.h"
|
||||
|
||||
#include <arpa/inet.h>
|
||||
#include <unistd.h>
|
||||
#include <sys/stat.h>
|
||||
#include <unistd.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstdio>
|
||||
#include <cstdlib>
|
||||
#include <cstring>
|
||||
#include <functional>
|
||||
#include <algorithm>
|
||||
#include <iomanip>
|
||||
#include <regex>
|
||||
|
||||
namespace mindspore {
|
||||
|
@ -600,6 +602,54 @@ bool CommUtil::StringToBool(const std::string &alive) {
|
|||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// Snapshot the wall clock as both a millisecond epoch timestamp and a
// formatted string.
Time CommUtil::GetNowTime() {
  ps::core::Time time;
  const auto now = std::chrono::system_clock::now();
  const std::time_t tt = std::chrono::system_clock::to_time_t(now);
  struct tm local_tm;
  (void)localtime_r(&tt, &local_tm);
  std::ostringstream formatted;
  formatted << std::put_time(&local_tm, "%Y-%m-%d %H:%M:%S");

  // calculate millisecond, the format of time_str_mill is 2022-01-10 20:22:20.067
  const auto whole_seconds = std::chrono::duration_cast<std::chrono::seconds>(now.time_since_epoch());
  const auto whole_millis = std::chrono::duration_cast<std::chrono::milliseconds>(now.time_since_epoch());
  const auto frac_millis = whole_millis - whole_seconds;
  formatted << "." << std::setfill('0') << std::setw(kMillSecondLength) << frac_millis.count();

  time.time_stamp = whole_millis.count();
  time.time_str_mill = formatted.str();
  return time;
}
|
||||
|
||||
// Look up `key` in the file configuration, validate its value as a
// file-storage json entry, and fill `file_config` with the storage type and
// file path. Returns false when the key is absent; raises on malformed json
// or an unsupported storage type.
bool CommUtil::ParseAndCheckConfigJson(Configuration *file_configuration, const std::string &key,
                                       FileConfig *file_config) {
  MS_EXCEPTION_IF_NULL(file_configuration);
  MS_EXCEPTION_IF_NULL(file_config);
  if (!file_configuration->Exists(key)) {
    MS_LOG(WARNING) << key << " config is not set. Don't write.";
    return false;
  }
  const std::string value = file_configuration->Get(key, "");
  nlohmann::json value_json;
  try {
    value_json = nlohmann::json::parse(value);
  } catch (const std::exception &e) {
    MS_LOG(EXCEPTION) << "The hyper-parameter data is not in json format.";
  }
  // Parse the storage type.
  uint32_t storage_type = ps::core::CommUtil::JsonGetKeyWithException<uint32_t>(value_json, ps::kStoreType);
  if (std::to_string(storage_type) != ps::kFileStorage) {
    MS_LOG(EXCEPTION) << "Storage type " << storage_type << " is not supported.";
  }
  // Parse storage file path.
  std::string file_path = ps::core::CommUtil::JsonGetKeyWithException<std::string>(value_json, ps::kStoreFilePath);
  file_config->storage_type = storage_type;
  file_config->storage_file_path = file_path;
  return true;
}
|
||||
} // namespace core
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -19,47 +19,46 @@
|
|||
|
||||
#include <unistd.h>
|
||||
#ifdef _MSC_VER
|
||||
#include <tchar.h>
|
||||
#include <winsock2.h>
|
||||
#include <windows.h>
|
||||
#include <iphlpapi.h>
|
||||
#include <tchar.h>
|
||||
#include <windows.h>
|
||||
#include <winsock2.h>
|
||||
#else
|
||||
#include <net/if.h>
|
||||
#include <arpa/inet.h>
|
||||
#include <ifaddrs.h>
|
||||
#include <net/if.h>
|
||||
#include <netinet/in.h>
|
||||
#endif
|
||||
|
||||
#include <assert.h>
|
||||
#include <event2/buffer.h>
|
||||
#include <event2/event.h>
|
||||
#include <event2/http.h>
|
||||
#include <event2/keyvalq_struct.h>
|
||||
#include <event2/listener.h>
|
||||
#include <event2/util.h>
|
||||
|
||||
#include <openssl/ssl.h>
|
||||
#include <openssl/rand.h>
|
||||
#include <openssl/bio.h>
|
||||
#include <openssl/err.h>
|
||||
#include <openssl/evp.h>
|
||||
#include <assert.h>
|
||||
#include <openssl/pkcs12.h>
|
||||
#include <openssl/bio.h>
|
||||
#include <openssl/rand.h>
|
||||
#include <openssl/ssl.h>
|
||||
#include <openssl/x509v3.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstdio>
|
||||
#include <cstdlib>
|
||||
#include <cstring>
|
||||
#include <fstream>
|
||||
#include <functional>
|
||||
#include <iostream>
|
||||
#include <map>
|
||||
#include <random>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <thread>
|
||||
#include <fstream>
|
||||
#include <iostream>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
#include <map>
|
||||
#include <algorithm>
|
||||
|
||||
#include "proto/comm.pb.h"
|
||||
#include "proto/ps.pb.h"
|
||||
|
@ -78,12 +77,14 @@ constexpr int kGroup2RandomLength = 4;
|
|||
constexpr int kGroup3RandomLength = 4;
|
||||
constexpr int kGroup4RandomLength = 4;
|
||||
constexpr int kGroup5RandomLength = 12;
|
||||
constexpr int kMillSecondLength = 3;
|
||||
|
||||
// The size of the buffer for sending and receiving data is 4096 bytes.
|
||||
constexpr int kMessageChunkLength = 4096;
|
||||
// The timeout period for the http client to connect to the http server is 120 seconds.
|
||||
constexpr int kConnectionTimeout = 120;
|
||||
constexpr char kLibeventLogPrefix[] = "[libevent log]:";
|
||||
constexpr char kFailureEvent[] = "failureEvent";
|
||||
|
||||
// Find the corresponding string style of cluster state through the subscript of the enum:ClusterState
|
||||
const std::vector<std::string> kClusterState = {
|
||||
|
@ -113,6 +114,16 @@ const std::map<std::string, ClusterState> kClusterStateMap = {
|
|||
{"CLUSTER_SCHEDULER_RECOVERY", ClusterState::CLUSTER_SCHEDULER_RECOVERY},
|
||||
{"CLUSTER_SCALE_OUT_ROLLBACK", ClusterState::CLUSTER_SCALE_OUT_ROLLBACK}};
|
||||
|
||||
struct Time {
|
||||
uint64_t time_stamp;
|
||||
std::string time_str_mill;
|
||||
};
|
||||
|
||||
struct FileConfig {
|
||||
uint32_t storage_type;
|
||||
std::string storage_file_path;
|
||||
};
|
||||
|
||||
class CommUtil {
|
||||
public:
|
||||
static bool CheckIpWithRegex(const std::string &ip);
|
||||
|
@ -159,6 +170,16 @@ class CommUtil {
|
|||
static bool CreateDirectory(const std::string &directoryPath);
|
||||
static bool CheckHttpUrl(const std::string &http_url);
|
||||
static bool IsFileReadable(const std::string &file);
|
||||
template <typename T>
|
||||
static T JsonGetKeyWithException(const nlohmann::json &json, const std::string &key) {
|
||||
if (!json.contains(key)) {
|
||||
MS_LOG(EXCEPTION) << "The key " << key << "does not exist in json " << json.dump();
|
||||
}
|
||||
return json[key].get<T>();
|
||||
}
|
||||
static Time GetNowTime();
|
||||
static bool ParseAndCheckConfigJson(Configuration *file_configuration, const std::string &key,
|
||||
FileConfig *file_config);
|
||||
|
||||
private:
|
||||
static std::random_device rd;
|
||||
|
|
|
@ -32,7 +32,8 @@ enum class Command {
|
|||
SEND_DATA = 3,
|
||||
FETCH_METADATA = 4,
|
||||
FINISH = 5,
|
||||
COLLECTIVE_SEND_DATA = 6
|
||||
COLLECTIVE_SEND_DATA = 6,
|
||||
FAILURE_EVENT = 7
|
||||
};
|
||||
|
||||
enum class Role { SERVER = 0, WORKER = 1, SCHEDULER = 2 };
|
||||
|
|
|
@ -65,6 +65,8 @@ enum NodeCommand {
|
|||
QUERY_FINISH_TRANSFORM = 23;
|
||||
// This command is used to start scale out rollback
|
||||
SCALE_OUT_ROLLBACK = 24;
|
||||
// Record the failure information, such as node restart
|
||||
FAILURE_EVENT_INFO = 25;
|
||||
}
|
||||
|
||||
enum NodeRole {
|
||||
|
@ -142,6 +144,14 @@ message HeartbeatMessage {
|
|||
uint32 port = 5;
|
||||
}
|
||||
|
||||
message FailureEventMessage {
|
||||
string node_role = 1;
|
||||
string ip = 2;
|
||||
uint32 port = 3;
|
||||
string time = 4;
|
||||
string event = 5;
|
||||
}
|
||||
|
||||
enum NodeState {
|
||||
NODE_STARTING = 0;
|
||||
NODE_FINISH = 1;
|
||||
|
|
|
@ -87,6 +87,7 @@ message DeviceMeta {
|
|||
string fl_name = 1;
|
||||
string fl_id = 2;
|
||||
uint64 data_size = 3;
|
||||
uint64 now_time = 4;
|
||||
}
|
||||
|
||||
message FLIdToDeviceMeta {
|
||||
|
@ -257,6 +258,9 @@ message EndLastIterResponse {
|
|||
uint64 getModel_accept_client_num = 9;
|
||||
uint64 getModel_reject_client_num = 10;
|
||||
float upload_loss = 11;
|
||||
uint64 participation_time_level1_num = 12;
|
||||
uint64 participation_time_level2_num = 13;
|
||||
uint64 participation_time_level3_num = 14;
|
||||
}
|
||||
|
||||
message SyncAfterRecover {
|
||||
|
|
|
@ -94,6 +94,10 @@ class BACKEND_EXPORT PSSchedulerNode : public SchedulerNode {
|
|||
|
||||
void RecoverFromPersistence() override;
|
||||
|
||||
void InitEventTxtFile() override {}
|
||||
|
||||
void RecordSchedulerRestartInfo() override {}
|
||||
|
||||
// Record received host hash name from workers or servers.
|
||||
std::map<NodeRole, std::vector<size_t>> host_hash_names_;
|
||||
// Record rank id of the nodes which sended host name.
|
||||
|
|
|
@ -15,6 +15,11 @@
|
|||
*/
|
||||
|
||||
#include "ps/core/scheduler_node.h"
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "include/common/debug/common.h"
|
||||
#include "fl/server/common.h"
|
||||
#include "ps/core/scheduler_recovery.h"
|
||||
|
||||
namespace mindspore {
|
||||
|
@ -32,10 +37,13 @@ bool SchedulerNode::Start(const uint32_t &timeout) {
|
|||
config_ = std::make_unique<FileConfiguration>(PSContext::instance()->config_file_path());
|
||||
MS_EXCEPTION_IF_NULL(config_);
|
||||
InitNodeMetaData();
|
||||
bool is_recover = false;
|
||||
if (!config_->Initialize()) {
|
||||
MS_LOG(WARNING) << "The config file is empty.";
|
||||
} else {
|
||||
if (!RecoverScheduler()) {
|
||||
InitEventTxtFile();
|
||||
is_recover = RecoverScheduler();
|
||||
if (!is_recover) {
|
||||
MS_LOG(DEBUG) << "Recover the server node is failed.";
|
||||
}
|
||||
}
|
||||
|
@ -50,6 +58,10 @@ bool SchedulerNode::Start(const uint32_t &timeout) {
|
|||
StartUpdateClusterStateTimer();
|
||||
RunRecovery();
|
||||
|
||||
if (is_recover) {
|
||||
RecordSchedulerRestartInfo();
|
||||
}
|
||||
|
||||
if (is_worker_timeout_) {
|
||||
BroadcastTimeoutEvent();
|
||||
}
|
||||
|
@ -258,6 +270,7 @@ void SchedulerNode::InitCommandHandler() {
|
|||
handlers_[NodeCommand::SCALE_OUT_DONE] = &SchedulerNode::ProcessScaleOutDone;
|
||||
handlers_[NodeCommand::SCALE_IN_DONE] = &SchedulerNode::ProcessScaleInDone;
|
||||
handlers_[NodeCommand::SEND_EVENT] = &SchedulerNode::ProcessSendEvent;
|
||||
handlers_[NodeCommand::FAILURE_EVENT_INFO] = &SchedulerNode::ProcessFailureEvent;
|
||||
RegisterActorRouteTableServiceHandler();
|
||||
RegisterInitCollectCommServiceHandler();
|
||||
RegisterRecoveryServiceHandler();
|
||||
|
@ -269,6 +282,26 @@ void SchedulerNode::RegisterActorRouteTableServiceHandler() {
|
|||
handlers_[NodeCommand::LOOKUP_ACTOR_ROUTE] = &SchedulerNode::ProcessLookupActorRoute;
|
||||
}
|
||||
|
||||
void SchedulerNode::InitEventTxtFile() {
|
||||
MS_LOG(DEBUG) << "Start init event txt";
|
||||
ps::core::FileConfig event_file_config;
|
||||
if (!ps::core::CommUtil::ParseAndCheckConfigJson(config_.get(), kFailureEvent, &event_file_config)) {
|
||||
MS_LOG(EXCEPTION) << "Parse and checkout config json failed";
|
||||
return;
|
||||
}
|
||||
|
||||
std::lock_guard<std::mutex> lock(event_txt_file_mtx_);
|
||||
event_file_path_ = event_file_config.storage_file_path;
|
||||
auto realpath = Common::CreatePrefixPath(event_file_path_.c_str());
|
||||
if (!realpath.has_value()) {
|
||||
MS_LOG(EXCEPTION) << "Creating path for " << event_file_path_ << " failed.";
|
||||
return;
|
||||
}
|
||||
event_txt_file_.open(realpath.value(), std::ios::out | std::ios::app);
|
||||
event_txt_file_.close();
|
||||
MS_LOG(DEBUG) << "Init event txt success!";
|
||||
}
|
||||
|
||||
void SchedulerNode::CreateTcpServer() {
|
||||
node_manager_.InitNode();
|
||||
|
||||
|
@ -628,6 +661,34 @@ void SchedulerNode::ProcessLookupActorRoute(const std::shared_ptr<TcpServer> &se
|
|||
}
|
||||
}
|
||||
|
||||
void SchedulerNode::ProcessFailureEvent(const std::shared_ptr<TcpServer> &server,
|
||||
const std::shared_ptr<TcpConnection> &conn,
|
||||
const std::shared_ptr<MessageMeta> &meta, const void *data, size_t size) {
|
||||
MS_ERROR_IF_NULL_WO_RET_VAL(server);
|
||||
MS_ERROR_IF_NULL_WO_RET_VAL(conn);
|
||||
MS_ERROR_IF_NULL_WO_RET_VAL(meta);
|
||||
MS_ERROR_IF_NULL_WO_RET_VAL(data);
|
||||
FailureEventMessage failure_event_message;
|
||||
failure_event_message.ParseFromArray(data, SizeToInt(size));
|
||||
std::string node_role = failure_event_message.node_role();
|
||||
std::string ip = failure_event_message.ip();
|
||||
uint32_t port = failure_event_message.port();
|
||||
std::string time = failure_event_message.time();
|
||||
std::string event = failure_event_message.event();
|
||||
std::lock_guard<std::mutex> lock(event_txt_file_mtx_);
|
||||
event_txt_file_.open(event_file_path_, std::ios::out | std::ios::app);
|
||||
if (!event_txt_file_.is_open()) {
|
||||
MS_LOG(EXCEPTION) << "The event txt file is not open";
|
||||
return;
|
||||
}
|
||||
std::string event_info = "nodeRole:" + node_role + "," + ip + ":" + std::to_string(port) + "," +
|
||||
"currentTime:" + time + "," + "event:" + event + ";";
|
||||
event_txt_file_ << event_info << "\n";
|
||||
(void)event_txt_file_.flush();
|
||||
event_txt_file_.close();
|
||||
MS_LOG(INFO) << "Process failure event success!";
|
||||
}
|
||||
|
||||
bool SchedulerNode::SendPrepareBuildingNetwork(const std::unordered_map<std::string, NodeInfo> &node_infos) {
|
||||
uint64_t request_id = AddMessageTrack(node_infos.size());
|
||||
for (const auto &kvs : node_infos) {
|
||||
|
@ -1598,7 +1659,7 @@ bool SchedulerNode::RecoverScheduler() {
|
|||
MS_EXCEPTION_IF_NULL(config_);
|
||||
if (config_->Exists(kKeyRecovery)) {
|
||||
MS_LOG(INFO) << "The scheduler node is support recovery.";
|
||||
scheduler_recovery_ = std::make_unique<SchedulerRecovery>();
|
||||
scheduler_recovery_ = std::make_shared<SchedulerRecovery>();
|
||||
MS_EXCEPTION_IF_NULL(scheduler_recovery_);
|
||||
bool ret = scheduler_recovery_->Initialize(config_->Get(kKeyRecovery, ""));
|
||||
bool ret_node = scheduler_recovery_->InitializeNodes(config_->Get(kKeyRecovery, ""));
|
||||
|
@ -1610,6 +1671,29 @@ bool SchedulerNode::RecoverScheduler() {
|
|||
return false;
|
||||
}
|
||||
|
||||
void SchedulerNode::RecordSchedulerRestartInfo() {
|
||||
MS_LOG(DEBUG) << "Start to record scheduler restart error message";
|
||||
std::lock_guard<std::mutex> lock(event_txt_file_mtx_);
|
||||
event_txt_file_.open(event_file_path_, std::ios::out | std::ios::app);
|
||||
if (!event_txt_file_.is_open()) {
|
||||
MS_LOG(EXCEPTION) << "The event txt file is not open";
|
||||
return;
|
||||
}
|
||||
std::string node_role = CommUtil::NodeRoleToString(node_info_.node_role_);
|
||||
auto scheduler_recovery_ptr = std::dynamic_pointer_cast<SchedulerRecovery>(scheduler_recovery_);
|
||||
MS_EXCEPTION_IF_NULL(scheduler_recovery_ptr);
|
||||
auto ip = scheduler_recovery_ptr->GetMetadata(kRecoverySchedulerIp);
|
||||
auto port = scheduler_recovery_ptr->GetMetadata(kRecoverySchedulerPort);
|
||||
std::string time = ps::core::CommUtil::GetNowTime().time_str_mill;
|
||||
std::string event = "Node restart";
|
||||
std::string event_info =
|
||||
"nodeRole:" + node_role + "," + ip + ":" + port + "," + "currentTime:" + time + "," + "event:" + event + ";";
|
||||
event_txt_file_ << event_info << "\n";
|
||||
(void)event_txt_file_.flush();
|
||||
event_txt_file_.close();
|
||||
MS_LOG(DEBUG) << "Record scheduler node restart info " << event_info << " success!";
|
||||
}
|
||||
|
||||
void SchedulerNode::PersistMetaData() {
|
||||
if (scheduler_recovery_ == nullptr) {
|
||||
MS_LOG(WARNING) << "scheduler recovery is null, do not persist meta data";
|
||||
|
|
|
@ -88,6 +88,7 @@ class BACKEND_EXPORT SchedulerNode : public Node {
|
|||
// Register and initialize the actor route table service.
|
||||
void RegisterActorRouteTableServiceHandler();
|
||||
void InitializeActorRouteTableService();
|
||||
virtual void InitEventTxtFile();
|
||||
|
||||
// Register collective communication initialization service.
|
||||
virtual void RegisterInitCollectCommServiceHandler() {}
|
||||
|
@ -138,6 +139,10 @@ class BACKEND_EXPORT SchedulerNode : public Node {
|
|||
void ProcessLookupActorRoute(const std::shared_ptr<TcpServer> &server, const std::shared_ptr<TcpConnection> &conn,
|
||||
const std::shared_ptr<MessageMeta> &meta, const void *data, size_t size);
|
||||
|
||||
// Process failure event message from other nodes.
|
||||
void ProcessFailureEvent(const std::shared_ptr<TcpServer> &server, const std::shared_ptr<TcpConnection> &conn,
|
||||
const std::shared_ptr<MessageMeta> &meta, const void *data, size_t size);
|
||||
|
||||
// Determine whether the registration request of the node should be rejected, the registration of the
|
||||
// alive node should be rejected.
|
||||
virtual bool NeedRejectRegister(const NodeInfo &node_info) { return false; }
|
||||
|
@ -201,6 +206,9 @@ class BACKEND_EXPORT SchedulerNode : public Node {
|
|||
|
||||
bool RecoverScheduler();
|
||||
|
||||
// Write scheduler restart error message
|
||||
virtual void RecordSchedulerRestartInfo();
|
||||
|
||||
void PersistMetaData();
|
||||
|
||||
bool CheckIfNodeDisconnected() const;
|
||||
|
@ -245,7 +253,7 @@ class BACKEND_EXPORT SchedulerNode : public Node {
|
|||
std::unordered_map<std::string, OnRequestReceive> callbacks_;
|
||||
|
||||
// Used to persist and obtain metadata information for scheduler.
|
||||
std::unique_ptr<RecoveryBase> scheduler_recovery_;
|
||||
std::shared_ptr<RecoveryBase> scheduler_recovery_;
|
||||
// persistent command need to be sent.
|
||||
std::atomic<PersistentCommand> persistent_cmd_;
|
||||
|
||||
|
@ -259,6 +267,15 @@ class BACKEND_EXPORT SchedulerNode : public Node {
|
|||
std::unordered_map<int, std::string> register_connection_fd_;
|
||||
|
||||
std::unique_ptr<ActorRouteTableService> actor_route_table_service_;
|
||||
|
||||
// The event txt file path
|
||||
std::string event_file_path_;
|
||||
|
||||
// The mutex for event txt event_file_path_
|
||||
std::mutex event_txt_file_mtx_;
|
||||
|
||||
// The fstream for event_file_path_
|
||||
std::fstream event_txt_file_;
|
||||
};
|
||||
} // namespace core
|
||||
} // namespace ps
|
||||
|
|
|
@ -42,6 +42,7 @@ void ServerNode::Initialize() {
|
|||
config_ = std::make_unique<FileConfiguration>(PSContext::instance()->config_file_path());
|
||||
MS_EXCEPTION_IF_NULL(config_);
|
||||
InitNodeNum();
|
||||
bool is_recover = false;
|
||||
if (!config_->Initialize()) {
|
||||
MS_LOG(WARNING) << "The config file is empty.";
|
||||
} else {
|
||||
|
@ -62,6 +63,10 @@ void ServerNode::Initialize() {
|
|||
}
|
||||
InitClientToServer();
|
||||
is_already_stopped_ = false;
|
||||
if (is_recover) {
|
||||
std::string node_role = CommUtil::NodeRoleToString(node_info_.node_role_);
|
||||
SendFailMessageToScheduler(node_role, "Node restart");
|
||||
}
|
||||
MS_LOG(INFO) << "[Server start]: 3. Server node crete tcp client to scheduler successful!";
|
||||
}
|
||||
|
||||
|
|
|
@ -41,10 +41,12 @@ void WorkerNode::Initialize() {
|
|||
config_ = std::make_unique<FileConfiguration>(PSContext::instance()->config_file_path());
|
||||
MS_EXCEPTION_IF_NULL(config_);
|
||||
InitNodeNum();
|
||||
bool is_recover = false;
|
||||
if (!config_->Initialize()) {
|
||||
MS_LOG(WARNING) << "The config file is empty.";
|
||||
} else {
|
||||
if (!Recover()) {
|
||||
is_recover = Recover();
|
||||
if (!is_recover) {
|
||||
MS_LOG(DEBUG) << "Recover the worker node is failed.";
|
||||
}
|
||||
}
|
||||
|
@ -60,6 +62,10 @@ void WorkerNode::Initialize() {
|
|||
}
|
||||
InitClientToServer();
|
||||
is_already_stopped_ = false;
|
||||
if (is_recover) {
|
||||
std::string node_role = CommUtil::NodeRoleToString(node_info_.node_role_);
|
||||
SendFailMessageToScheduler(node_role, "Node restart");
|
||||
}
|
||||
MS_LOG(INFO) << "[Worker start]: 3. Worker node crete tcp client to scheduler successful!";
|
||||
}
|
||||
|
||||
|
|
|
@ -15,13 +15,14 @@
|
|||
*/
|
||||
|
||||
#include "ps/ps_context.h"
|
||||
|
||||
#include "kernel/kernel.h"
|
||||
#include "utils/log_adapter.h"
|
||||
#include "utils/ms_utils.h"
|
||||
#include "kernel/kernel.h"
|
||||
#if ((defined ENABLE_CPU) && (!defined _WIN32))
|
||||
#include "distributed/cluster/cluster_context.h"
|
||||
#include "ps/ps_cache/ps_cache_manager.h"
|
||||
#include "ps/ps_cache/ps_data/ps_data_prefetch.h"
|
||||
#include "distributed/cluster/cluster_context.h"
|
||||
#else
|
||||
#include "distributed/cluster/dummy_cluster_context.h"
|
||||
#endif
|
||||
|
@ -566,5 +567,21 @@ std::string PSContext::download_compress_type() const { return download_compress
|
|||
std::string PSContext::checkpoint_dir() const { return checkpoint_dir_; }
|
||||
|
||||
void PSContext::set_checkpoint_dir(const std::string &checkpoint_dir) { checkpoint_dir_ = checkpoint_dir; }
|
||||
|
||||
void PSContext::set_instance_name(const std::string &instance_name) { instance_name_ = instance_name; }
|
||||
|
||||
const std::string &PSContext::instance_name() const { return instance_name_; }
|
||||
|
||||
void PSContext::set_participation_time_level(const std::string &participation_time_level) {
|
||||
participation_time_level_ = participation_time_level;
|
||||
}
|
||||
|
||||
const std::string &PSContext::participation_time_level() { return participation_time_level_; }
|
||||
|
||||
void PSContext::set_continuous_failure_times(uint32_t continuous_failure_times) {
|
||||
continuous_failure_times_ = continuous_failure_times;
|
||||
}
|
||||
|
||||
uint32_t PSContext::continuous_failure_times() { return continuous_failure_times_; }
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -244,6 +244,15 @@ class BACKEND_EXPORT PSContext {
|
|||
std::string checkpoint_dir() const;
|
||||
void set_checkpoint_dir(const std::string &checkpoint_dir);
|
||||
|
||||
void set_instance_name(const std::string &instance_name);
|
||||
const std::string &instance_name() const;
|
||||
|
||||
void set_participation_time_level(const std::string &participation_time_level);
|
||||
const std::string &participation_time_level();
|
||||
|
||||
void set_continuous_failure_times(uint32_t continuous_failure_times);
|
||||
uint32_t continuous_failure_times();
|
||||
|
||||
private:
|
||||
PSContext()
|
||||
: ps_enabled_(false),
|
||||
|
@ -300,7 +309,10 @@ class BACKEND_EXPORT PSContext {
|
|||
upload_compress_type_(kNoCompressType),
|
||||
upload_sparse_rate_(0.4f),
|
||||
download_compress_type_(kNoCompressType),
|
||||
checkpoint_dir_("") {}
|
||||
checkpoint_dir_(""),
|
||||
instance_name_(""),
|
||||
participation_time_level_("5,15"),
|
||||
continuous_failure_times_(10) {}
|
||||
bool ps_enabled_;
|
||||
bool is_worker_;
|
||||
bool is_pserver_;
|
||||
|
@ -442,6 +454,15 @@ class BACKEND_EXPORT PSContext {
|
|||
|
||||
// directory of server checkpoint
|
||||
std::string checkpoint_dir_;
|
||||
|
||||
// The name of instance
|
||||
std::string instance_name_;
|
||||
|
||||
// The participation time level
|
||||
std::string participation_time_level_;
|
||||
|
||||
// The times of iteration continuous failure
|
||||
uint32_t continuous_failure_times_;
|
||||
};
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -83,6 +83,9 @@ _set_ps_context_func_map = {
|
|||
"upload_compress_type": ps_context().set_upload_compress_type,
|
||||
"upload_sparse_rate": ps_context().set_upload_sparse_rate,
|
||||
"download_compress_type": ps_context().set_download_compress_type,
|
||||
"instance_name": ps_context().set_instance_name,
|
||||
"participation_time_level": ps_context().set_participation_time_level,
|
||||
"continuous_failure_times": ps_context().set_continuous_failure_times,
|
||||
}
|
||||
|
||||
_get_ps_context_func_map = {
|
||||
|
@ -134,6 +137,9 @@ _get_ps_context_func_map = {
|
|||
"upload_compress_type": ps_context().upload_compress_type,
|
||||
"upload_sparse_rate": ps_context().upload_sparse_rate,
|
||||
"download_compress_type": ps_context().download_compress_type,
|
||||
"instance_name": ps_context().instance_name,
|
||||
"participation_time_level": ps_context().participation_time_level,
|
||||
"continuous_failure_times": ps_context().continuous_failure_times,
|
||||
}
|
||||
|
||||
_check_positive_int_keys = ["server_num", "scheduler_port", "fl_server_port",
|
||||
|
|
|
@ -178,10 +178,6 @@ def call_get_instance_detail():
|
|||
return process_self_define_json(Status.FAILED.value, "error. metrics file is not existed.")
|
||||
|
||||
ans_json_obj = {}
|
||||
metrics_auc_list = []
|
||||
metrics_loss_list = []
|
||||
iteration_execution_time_list = []
|
||||
client_visited_info_list = []
|
||||
|
||||
with open(metrics_file_path, 'r') as f:
|
||||
metrics_list = f.readlines()
|
||||
|
@ -189,28 +185,13 @@ def call_get_instance_detail():
|
|||
if not metrics_list:
|
||||
return process_self_define_json(Status.FAILED.value, "error. metrics file has no content")
|
||||
|
||||
for metrics in metrics_list:
|
||||
json_obj = json.loads(metrics)
|
||||
iteration_execution_time_list.append(json_obj['iterationExecutionTime'])
|
||||
client_visited_info_list.append(json_obj['clientVisitedInfo'])
|
||||
metrics_auc_list.append(json_obj['metricsAuc'])
|
||||
metrics_loss_list.append(json_obj['metricsLoss'])
|
||||
|
||||
last_metrics = metrics_list[len(metrics_list) - 1]
|
||||
last_metrics_obj = json.loads(last_metrics)
|
||||
|
||||
ans_json_obj["code"] = Status.SUCCESS.value
|
||||
ans_json_obj["describe"] = "get instance metrics detail successful."
|
||||
ans_json_obj["result"] = {}
|
||||
ans_json_result = ans_json_obj.get("result")
|
||||
ans_json_result['currentIteration'] = last_metrics_obj['currentIteration']
|
||||
ans_json_result['flIterationNum'] = last_metrics_obj['flIterationNum']
|
||||
ans_json_result['flName'] = last_metrics_obj['flName']
|
||||
ans_json_result['instanceStatus'] = last_metrics_obj['instanceStatus']
|
||||
ans_json_result['iterationExecutionTime'] = iteration_execution_time_list
|
||||
ans_json_result['clientVisitedInfo'] = client_visited_info_list
|
||||
ans_json_result['metricsAuc'] = metrics_auc_list
|
||||
ans_json_result['metricsLoss'] = metrics_loss_list
|
||||
ans_json_obj["result"] = last_metrics_obj
|
||||
|
||||
return json.dumps(ans_json_obj)
|
||||
|
||||
|
|
|
@ -15,6 +15,14 @@
|
|||
"storage_type": 1,
|
||||
"storage_file_path": "metrics.json"
|
||||
},
|
||||
"dataRate": {
|
||||
"storage_type": 1,
|
||||
"storage_file_path": ".."
|
||||
},
|
||||
"failureEvent": {
|
||||
"storage_type": 1,
|
||||
"storage_file_path": "event.txt"
|
||||
},
|
||||
"server_recovery": {
|
||||
"storage_type": 1,
|
||||
"storage_file_path": "../server_recovery.json"
|
||||
|
|
|
@ -15,6 +15,14 @@
|
|||
"storage_type": 1,
|
||||
"storage_file_path": "metrics.json"
|
||||
},
|
||||
"dataRate": {
|
||||
"storage_type": 1,
|
||||
"storage_file_path": ".."
|
||||
},
|
||||
"failureEvent": {
|
||||
"storage_type": 1,
|
||||
"storage_file_path": "event.txt"
|
||||
},
|
||||
"server_recovery": {
|
||||
"storage_type": 1,
|
||||
"storage_file_path": "../server_recovery.json"
|
||||
|
|
|
@ -15,6 +15,14 @@
|
|||
"storage_type": 1,
|
||||
"storage_file_path": "metrics.json"
|
||||
},
|
||||
"dataRate": {
|
||||
"storage_type": 1,
|
||||
"storage_file_path": ".."
|
||||
},
|
||||
"failureEvent": {
|
||||
"storage_type": 1,
|
||||
"storage_file_path": "event.txt"
|
||||
},
|
||||
"server_recovery": {
|
||||
"storage_type": 1,
|
||||
"storage_file_path": "../server_recovery.json"
|
||||
|
|
|
@ -15,6 +15,14 @@
|
|||
"storage_type": 1,
|
||||
"storage_file_path": "metrics.json"
|
||||
},
|
||||
"dataRate": {
|
||||
"storage_type": 1,
|
||||
"storage_file_path": ".."
|
||||
},
|
||||
"failureEvent": {
|
||||
"storage_type": 1,
|
||||
"storage_file_path": "event.txt"
|
||||
},
|
||||
"server_recovery": {
|
||||
"storage_type": 1,
|
||||
"storage_file_path": "../server_recovery.json"
|
||||
|
|
|
@ -15,6 +15,14 @@
|
|||
"storage_type": 1,
|
||||
"storage_file_path": "metrics.json"
|
||||
},
|
||||
"dataRate": {
|
||||
"storage_type": 1,
|
||||
"storage_file_path": ".."
|
||||
},
|
||||
"failureEvent": {
|
||||
"storage_type": 1,
|
||||
"storage_file_path": "event.txt"
|
||||
},
|
||||
"server_recovery": {
|
||||
"storage_type": 1,
|
||||
"storage_file_path": "../server_recovery.json"
|
||||
|
|
|
@ -15,6 +15,14 @@
|
|||
"storage_type": 1,
|
||||
"storage_file_path": "metrics.json"
|
||||
},
|
||||
"dataRate": {
|
||||
"storage_type": 1,
|
||||
"storage_file_path": ".."
|
||||
},
|
||||
"failureEvent": {
|
||||
"storage_type": 1,
|
||||
"storage_file_path": "event.txt"
|
||||
},
|
||||
"server_recovery": {
|
||||
"storage_type": 1,
|
||||
"storage_file_path": "../server_recovery.json"
|
||||
|
|
|
@ -15,6 +15,14 @@
|
|||
"storage_type": 1,
|
||||
"storage_file_path": "metrics.json"
|
||||
},
|
||||
"dataRate": {
|
||||
"storage_type": 1,
|
||||
"storage_file_path": ".."
|
||||
},
|
||||
"failureEvent": {
|
||||
"storage_type": 1,
|
||||
"storage_file_path": "event.txt"
|
||||
},
|
||||
"server_recovery": {
|
||||
"storage_type": 1,
|
||||
"storage_file_path": "../server_recovery.json"
|
||||
|
|
Loading…
Reference in New Issue