!22056 Add scheduler Restful interface.

Merge pull request !22056 from ZPaC/add-iteration-metrics
i-robot 2021-08-21 07:52:34 +00:00 committed by Gitee
commit b18ffcf221
19 changed files with 108 additions and 51 deletions

View File

@@ -61,11 +61,6 @@ class FusedPullWeightKernel : public CPUKernel {
}
fl_iteration_++;
- if (fl_iteration_ > ps::PSContext::instance()->fl_iteration_num()) {
- MS_LOG(INFO) << ps::PSContext::instance()->fl_iteration_num() << " iterations are completed.";
- fl_iteration_ = 1;
- }
MS_LOG(INFO) << "Launching pulling weight for federated learning iteration " << fl_iteration_;
if (!BuildPullWeightReq(fbb)) {
MS_LOG(EXCEPTION) << "Building request for FusedPullWeight failed.";
@@ -76,6 +71,10 @@ class FusedPullWeightKernel : public CPUKernel {
const schema::ResponsePullWeight *pull_weight_rsp = nullptr;
int retcode = schema::ResponseCode_SucNotReady;
while (retcode == schema::ResponseCode_SucNotReady) {
+ if (!fl::worker::FLWorker::GetInstance().running()) {
+ MS_LOG(WARNING) << "Worker has finished.";
+ return true;
+ }
if (!fl::worker::FLWorker::GetInstance().SendToServer(
0, fbb->GetBufferPointer(), fbb->GetSize(), ps::core::TcpUserCommand::kPullWeight, &pull_weight_rsp_msg)) {
MS_LOG(WARNING) << "Sending request for FusedPullWeight to server 0 failed. Retry later.";

View File

@@ -59,11 +59,6 @@ class FusedPushWeightKernel : public CPUKernel {
}
fl_iteration_++;
- if (fl_iteration_ > ps::PSContext::instance()->fl_iteration_num()) {
- MS_LOG(INFO) << ps::PSContext::instance()->fl_iteration_num() << " iterations are completed.";
- fl_iteration_ = 1;
- }
MS_LOG(INFO) << "Launching pushing weight for federated learning iteration " << fl_iteration_;
if (!BuildPushWeightReq(fbb, inputs)) {
MS_LOG(EXCEPTION) << "Building request for FusedPushWeight failed.";
@@ -76,6 +71,10 @@ class FusedPushWeightKernel : public CPUKernel {
const schema::ResponsePushWeight *push_weight_rsp = nullptr;
int retcode = schema::ResponseCode_SucNotReady;
while (retcode == schema::ResponseCode_SucNotReady) {
+ if (!fl::worker::FLWorker::GetInstance().running()) {
+ MS_LOG(WARNING) << "Worker has finished.";
+ return true;
+ }
if (!fl::worker::FLWorker::GetInstance().SendToServer(i, fbb->GetBufferPointer(), fbb->GetSize(),
ps::core::TcpUserCommand::kPushWeight,
&push_weight_rsp_msg)) {

View File

@@ -56,6 +56,10 @@ class PushMetricsKernel : public CPUKernel {
uint32_t retry_time = 0;
std::shared_ptr<std::vector<unsigned char>> push_metrics_rsp_msg = nullptr;
do {
+ if (!fl::worker::FLWorker::GetInstance().running()) {
+ MS_LOG(WARNING) << "Worker has finished.";
+ return true;
+ }
retry_time++;
if (!fl::worker::FLWorker::GetInstance().SendToServer(fl::kLeaderServerRank, fbb_->GetBufferPointer(),
fbb_->GetSize(), ps::core::TcpUserCommand::kPushMetrics,
@@ -66,7 +70,7 @@ class PushMetricsKernel : public CPUKernel {
} else {
break;
}
- } while (retry_time > kMaxRetryTime);
+ } while (retry_time < kMaxRetryTime);
flatbuffers::Verifier verifier(push_metrics_rsp_msg->data(), push_metrics_rsp_msg->size());
if (!verifier.VerifyBuffer<schema::ResponsePushMetrics>()) {
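The one-character change above is the real fix in this file: retry_time starts at zero and is incremented before the check, so the old condition (retry_time > kMaxRetryTime) was false on its first evaluation and the loop never retried. A hedged sketch of the corrected shape (the value of kMaxRetryTime is assumed; the excerpt does not show it):

#include <cstdint>
#include <functional>

constexpr uint32_t kMaxRetryTime = 10;  // assumed value

// Bounded retry: with '>' the do/while exited after one pass; '<' makes the cap effective.
bool PushWithRetry(const std::function<bool()> &push_once) {
  uint32_t retry_time = 0;
  bool ok = false;
  do {
    retry_time++;
    ok = push_once();
  } while (!ok && retry_time < kMaxRetryTime);
  return ok;
}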

View File

@@ -66,6 +66,20 @@ void DistributedCountService::RegisterCounter(const std::string &name, size_t gl
return;
}
+ bool DistributedCountService::ReInitCounter(const std::string &name, size_t global_threshold_count) {
+ MS_LOG(INFO) << "Rank " << local_rank_ << " reinitialize counter for " << name << " count:" << global_threshold_count;
+ if (local_rank_ == counting_server_rank_) {
+ std::unique_lock<std::mutex> lock(mutex_[name]);
+ if (global_threshold_count_.count(name) == 0) {
+ MS_LOG(INFO) << "Counter for " << name << " is not set.";
+ return false;
+ }
+ global_current_count_[name] = {};
+ global_threshold_count_[name] = global_threshold_count;
+ }
+ return true;
+ }
bool DistributedCountService::Count(const std::string &name, const std::string &id, std::string *reason) {
MS_LOG(INFO) << "Rank " << local_rank_ << " reports count for " << name << " of " << id;
if (local_rank_ == counting_server_rank_) {
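ReInitCounter exists so a round can change its threshold (for example after hyper-parameters are updated through the new RESTful interface) without re-registering its first/last-count handlers: progress is cleared, the new threshold is installed, and an unknown name is reported as failure. A simplified sketch of that contract, with illustrative types standing in for the service's bookkeeping:

#include <map>
#include <mutex>
#include <set>
#include <string>

// Simplified stand-in for DistributedCountService's per-name state.
class CounterService {
 public:
  void Register(const std::string &name, size_t threshold) {
    std::lock_guard<std::mutex> lock(mu_);
    threshold_[name] = threshold;
    current_[name] = {};
  }
  // Keep the registration (and its callbacks), but reset progress and swap thresholds.
  bool ReInit(const std::string &name, size_t new_threshold) {
    std::lock_guard<std::mutex> lock(mu_);
    if (threshold_.count(name) == 0) {
      return false;  // never registered: mirrors the MS_LOG + false path above
    }
    current_[name].clear();
    threshold_[name] = new_threshold;
    return true;
  }

 private:
  std::mutex mu_;
  std::map<std::string, size_t> threshold_;
  std::map<std::string, std::set<std::string>> current_;  // ids counted so far
};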

View File

@@ -63,6 +63,9 @@ class DistributedCountService {
// first/last count event callbacks.
void RegisterCounter(const std::string &name, size_t global_threshold_count, const CounterHandlers &counter_handlers);
+ // Reinitialize counter due to the change of threshold count.
+ bool ReInitCounter(const std::string &name, size_t global_threshold_count);
// Report a count to the counting server. Parameter 'id' is in case of repeated counting. Parameter 'reason' is the
// reason why counting failed.
bool Count(const std::string &name, const std::string &id, std::string *reason = nullptr);

View File

@@ -50,7 +50,7 @@ void DistributedMetadataStore::RegisterMetadata(const std::string &name, const P
uint32_t stored_rank = router_->Find(name);
if (local_rank_ == stored_rank) {
if (metadata_.count(name) != 0) {
MS_LOG(ERROR) << "The metadata for " << name << " is already registered.";
MS_LOG(WARNING) << "The metadata for " << name << " is already registered.";
return;
}

View File

@@ -340,10 +340,10 @@ bool Executor::InitParamAggregator(const FuncGraphPtr &func_graph) {
param_aggrs_[param_name] = param_aggr;
parameter_mutex_[param_name];
if (!param_aggr->Init(cnode, aggregation_count_)) {
MS_LOG(EXCEPTION) << "Initializing parameter aggregator failed for " << param_name;
MS_LOG(EXCEPTION) << "Initializing parameter aggregator for " << param_name << " failed.";
return false;
}
MS_LOG(DEBUG) << "Initializing control flow for param_name " << param_name << " success.";
MS_LOG(DEBUG) << "Initializing parameter aggregator for param_name " << param_name << " success.";
}
return true;
}

View File

@@ -29,6 +29,7 @@ class Server;
Iteration::~Iteration() {
move_to_next_thread_running_ = false;
next_iteration_cv_.notify_all();
if (move_to_next_thread_.joinable()) {
move_to_next_thread_.join();
}
@@ -93,6 +94,9 @@ void Iteration::InitRounds(const std::vector<std::shared_ptr<ps::core::Communica
while (move_to_next_thread_running_.load()) {
std::unique_lock<std::mutex> lock(next_iteration_mutex_);
next_iteration_cv_.wait(lock);
+ if (!move_to_next_thread_running_.load()) {
+ break;
+ }
MoveToNextIteration(is_last_iteration_valid_, move_to_next_reason_);
}
});
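Two lifecycle fixes land in the hunks above: the destructor joins move_to_next_thread_ only when it is joinable, and the loop re-checks move_to_next_thread_running_ immediately after waking, so the shutdown notify_all no longer triggers one final spurious MoveToNextIteration. A self-contained sketch of the pattern, tightened to flip the flag under the lock (which the excerpt does not show):

#include <condition_variable>
#include <mutex>
#include <thread>

class NextIterationPump {
 public:
  NextIterationPump() {
    worker_ = std::thread([this] {
      std::unique_lock<std::mutex> lock(mu_);
      while (running_) {
        cv_.wait(lock);
        if (!running_) {
          break;  // woken by the destructor, not by a real next-iteration event
        }
        // ... the MoveToNextIteration(...) equivalent would run here ...
      }
    });
  }
  ~NextIterationPump() {
    {
      std::lock_guard<std::mutex> lock(mu_);
      running_ = false;  // flipped under the lock so the wakeup cannot be lost
    }
    cv_.notify_all();
    if (worker_.joinable()) {
      worker_.join();  // guard keeps a never-started thread from throwing
    }
  }

 private:
  bool running_ = true;
  std::mutex mu_;
  std::condition_variable cv_;
  std::thread worker_;
};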
@@ -222,12 +226,12 @@ bool Iteration::EnableServerInstance(std::string *result) {
std::unique_lock<std::mutex> lock(instance_mtx_);
if (is_instance_being_updated_) {
*result = "The instance is being updated. Please retry enabling server later.";
- MS_LOG(WARNING) << result;
+ MS_LOG(WARNING) << *result;
return false;
}
if (instance_state_.load() == InstanceState::kFinish) {
*result = "The instance is completed. Please do not enabling server now.";
- MS_LOG(WARNING) << result;
+ MS_LOG(WARNING) << *result;
return false;
}
@@ -236,6 +240,7 @@ bool Iteration::EnableServerInstance(std::string *result) {
instance_state_ = InstanceState::kRunning;
*result = "Enabling FL-Server succeeded.";
MS_LOG(INFO) << *result;
// End enabling server instance.
is_instance_being_updated_ = false;
@@ -259,7 +264,7 @@ bool Iteration::DisableServerInstance(std::string *result) {
if (instance_state_.load() == InstanceState::kDisable) {
*result = "Disabling FL-Server succeeded.";
MS_LOG(INFO) << *result;
- return false;
+ return true;
}
// Start disabling server instance.
@@ -272,6 +277,8 @@ bool Iteration::DisableServerInstance(std::string *result) {
MS_LOG(ERROR) << result;
return false;
}
*result = "Disabling FL-Server succeeded.";
MS_LOG(INFO) << *result;
// End disabling server instance.
is_instance_being_updated_ = false;
@@ -293,8 +300,9 @@ bool Iteration::NewInstance(const nlohmann::json &new_instance_json, std::string
// Reset current instance.
instance_state_ = InstanceState::kFinish;
- MS_LOG(INFO) << "Proceed to a new instance.";
+ Server::GetInstance().WaitExitSafeMode();
WaitAllRoundsFinish();
+ MS_LOG(INFO) << "Proceed to a new instance.";
for (auto &round : rounds_) {
MS_ERROR_IF_NULL_W_RET_VAL(round, false);
round->Reset();
@@ -705,7 +713,7 @@ bool Iteration::ReInitRounds() {
std::vector<RoundConfig> new_round_config = {
{"startFLJob", true, start_fl_job_time_window, true, start_fl_job_threshold},
{"updateModel", true, update_model_time_window, true, update_model_threshold}};
- if (!Iteration::GetInstance().ReInitForUpdatingHyperParams(new_round_config)) {
+ if (!ReInitForUpdatingHyperParams(new_round_config)) {
MS_LOG(ERROR) << "Reinitializing for updating hyper-parameters failed.";
return false;
}
@@ -727,7 +735,6 @@ bool Iteration::ReInitRounds() {
}
return true;
}
} // namespace server
} // namespace fl
} // namespace mindspore
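One behavioral detail in this file is easy to miss: DisableServerInstance now returns true when the instance is already disabled, so a repeated disable request from the scheduler's RESTful interface reads as success rather than an error. A small sketch of that idempotent convention:

#include <atomic>
#include <string>

enum class InstanceState { kRunning, kDisable, kFinish };

// Already being in the target state counts as success (the old code returned false).
bool DisableInstance(std::atomic<InstanceState> &state, std::string *result) {
  if (state.load() == InstanceState::kDisable) {
    *result = "Disabling FL-Server succeeded.";
    return true;
  }
  state = InstanceState::kDisable;
  *result = "Disabling FL-Server succeeded.";
  return true;
}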

View File

@@ -63,8 +63,6 @@ bool IterationMetrics::Initialize() {
metrics_file_.open(metrics_file_path, std::ios::ate | std::ios::out);
}
initialized_ = true;
return true;
}

View File

@@ -89,8 +89,6 @@ class IterationMetrics {
void set_iteration_time_cost(uint64_t iteration_time_cost);
private:
bool initialized_;
// This is the main config file set by ps context.
std::string config_file_path_;
std::unique_ptr<ps::core::FileConfiguration> config_;
@@ -99,7 +97,7 @@ class IterationMetrics {
std::fstream metrics_file_;
// Json object of metrics data.
- nlohmann::json js_;
+ nlohmann::basic_json<std::map, std::vector, std::string, bool, int64_t, uint64_t, float> js_;
// The federated learning job name. Set by ps_context.
std::string fl_name_;
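The default nlohmann::json alias is basic_json<std::map, std::vector, std::string, bool, std::int64_t, std::uint64_t, double>; the specialization adopted here differs only in the floating-point type, float instead of double. The commit does not state the motivation, so this sketch only demonstrates the mechanics (the single-header include path and the keys are assumptions, not from this patch):

#include <cstdint>
#include <map>
#include <string>
#include <vector>

#include <nlohmann/json.hpp>

// Matches the specialization in the patch: float is the floating-point type.
using MetricsJson =
    nlohmann::basic_json<std::map, std::vector, std::string, bool, int64_t, uint64_t, float>;

int main() {
  MetricsJson js;
  js["time_cost"] = uint64_t{123456};  // placeholder key; unsigned 64-bit kept as-is
  js["loss"] = 0.25f;                  // placeholder key; stored with float precision
  return 0;
}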

View File

@@ -180,7 +180,10 @@ class FedAvgKernel : public AggregationKernel {
bool ReInitForUpdatingHyperParams(size_t aggr_threshold) override {
done_count_ = aggr_threshold;
- DistributedCountService::GetInstance().RegisterCounter(name_, done_count_, {first_cnt_handler_, last_cnt_handler_});
+ if (!DistributedCountService::GetInstance().ReInitCounter(name_, done_count_)) {
+ MS_LOG(ERROR) << "Reinitializing counter for " << name_ << " failed.";
+ return false;
+ }
return true;
}

View File

@@ -99,7 +99,7 @@ void GetModelKernel::GetModel(const schema::RequestGetModel *get_model_req, cons
const auto &iter_to_model = ModelStore::GetInstance().iteration_to_model();
size_t latest_iter_num = iter_to_model.rbegin()->first;
// If this iteration is not finished yet, return ResponseCode_SucNotReady so that clients could get model later.
- if ((current_iter == get_model_iter && latest_iter_num != current_iter) || current_iter == get_model_iter - 1) {
+ if ((current_iter == get_model_iter && latest_iter_num != current_iter)) {
std::string reason = "The model is not ready yet for iteration " + std::to_string(get_model_iter) +
". Maybe this is because\n" + "1.Client doesn't send enough update model requests.\n" +
"2. Worker has not push all the weights to servers.";

View File

@@ -32,7 +32,7 @@ namespace server {
namespace kernel {
class PushMetricsKernel : public RoundKernel {
public:
- PushMetricsKernel() = default;
+ PushMetricsKernel() : local_rank_(0) {}
~PushMetricsKernel() override = default;
void InitKernel(size_t threshold_count);

View File

@@ -106,10 +106,10 @@ bool Round::ReInitForUpdatingHyperParams(size_t updated_threshold_count, size_t
time_window_ = updated_time_window;
threshold_count_ = updated_threshold_count;
if (check_count_) {
- auto first_count_handler = std::bind(&Round::OnFirstCountEvent, this, std::placeholders::_1);
- auto last_count_handler = std::bind(&Round::OnLastCountEvent, this, std::placeholders::_1);
- DistributedCountService::GetInstance().RegisterCounter(name_, threshold_count_,
- {first_count_handler, last_count_handler});
+ if (!DistributedCountService::GetInstance().ReInitCounter(name_, threshold_count_)) {
+ MS_LOG(ERROR) << "Reinitializing count for " << name_ << " failed.";
+ return false;
+ }
}
MS_ERROR_IF_NULL_W_RET_VAL(kernel_, false);

View File

@@ -32,6 +32,22 @@
namespace mindspore {
namespace fl {
namespace server {
+ // The handler to capture the signal of SIGTERM. Normally this signal is triggered by cloud cluster managers like K8S.
+ std::shared_ptr<ps::core::CommunicatorBase> g_communicator_with_server = nullptr;
+ std::vector<std::shared_ptr<ps::core::CommunicatorBase>> g_communicators_with_worker = {};
+ void SignalHandler(int signal) {
+ MS_LOG(WARNING) << "SIGTERM captured: " << signal;
+ (void)std::for_each(g_communicators_with_worker.begin(), g_communicators_with_worker.end(),
+ [](const std::shared_ptr<ps::core::CommunicatorBase> &communicator) {
+ MS_ERROR_IF_NULL_WO_RET_VAL(communicator);
+ (void)communicator->Stop();
+ });
+ MS_ERROR_IF_NULL_WO_RET_VAL(g_communicator_with_server);
+ (void)g_communicator_with_server->Stop();
+ return;
+ }
void Server::Initialize(bool use_tcp, bool use_http, uint16_t http_port, const std::vector<RoundConfig> &rounds_config,
const CipherConfig &cipher_config, const FuncGraphPtr &func_graph, size_t executor_threshold) {
MS_EXCEPTION_IF_NULL(func_graph);
@@ -48,6 +64,7 @@ void Server::Initialize(bool use_tcp, bool use_http, uint16_t http_port, const s
use_http_ = use_http;
http_port_ = http_port;
executor_threshold_ = executor_threshold;
+ signal(SIGTERM, SignalHandler);
return;
}
@@ -109,6 +126,12 @@ void Server::CancelSafeMode() {
bool Server::IsSafeMode() const { return safemode_.load(); }
+ void Server::WaitExitSafeMode() const {
+ while (safemode_.load()) {
+ std::this_thread::sleep_for(std::chrono::milliseconds(kThreadSleepTime));
+ }
+ }
void Server::InitServerContext() {
ps::PSContext::instance()->GenerateResetterRound();
scheduler_ip_ = ps::PSContext::instance()->scheduler_host();
@@ -145,6 +168,7 @@ bool Server::InitCommunicatorWithServer() {
communicator_with_server_ =
server_node_->GetOrCreateTcpComm(scheduler_ip_, scheduler_port_, worker_num_, server_num_, task_executor_);
MS_EXCEPTION_IF_NULL(communicator_with_server_);
+ g_communicator_with_server = communicator_with_server_;
return true;
}
@@ -166,6 +190,7 @@ bool Server::InitCommunicatorWithWorker() {
MS_EXCEPTION_IF_NULL(http_comm);
communicators_with_worker_.push_back(http_comm);
}
+ g_communicators_with_worker = communicators_with_worker_;
return true;
}
@@ -239,10 +264,9 @@ void Server::InitIteration() {
#endif
// 2.Initialize all the rounds.
- TimeOutCb time_out_cb =
- std::bind(&Iteration::MoveToNextIteration, iteration_, std::placeholders::_1, std::placeholders::_2);
+ TimeOutCb time_out_cb = std::bind(&Iteration::NotifyNext, iteration_, std::placeholders::_1, std::placeholders::_2);
FinishIterCb finish_iter_cb =
- std::bind(&Iteration::MoveToNextIteration, iteration_, std::placeholders::_1, std::placeholders::_2);
+ std::bind(&Iteration::NotifyNext, iteration_, std::placeholders::_1, std::placeholders::_2);
iteration_->InitRounds(communicators_with_worker_, time_out_cb, finish_iter_cb);
return;
}
@@ -487,15 +511,7 @@ void Server::ProcessAfterScalingIn() {
std::unique_lock<std::mutex> lock(scaling_mtx_);
MS_ERROR_IF_NULL_WO_RET_VAL(server_node_);
if (server_node_->rank_id() == UINT32_MAX) {
- MS_LOG(WARNING) << "This server the one to be scaled in. Server exiting.";
- (void)std::for_each(communicators_with_worker_.begin(), communicators_with_worker_.end(),
- [](const std::shared_ptr<ps::core::CommunicatorBase> &communicator) {
- MS_ERROR_IF_NULL_WO_RET_VAL(communicator);
- (void)communicator->Stop();
- });
- MS_ERROR_IF_NULL_WO_RET_VAL(communicator_with_server_);
- (void)communicator_with_server_->Stop();
+ MS_LOG(WARNING) << "This server the one to be scaled in. Server need to wait SIGTERM to exit.";
return;
}
return;
}
@@ -588,7 +604,7 @@ void Server::HandleNewInstanceRequest(const std::shared_ptr<ps::core::MessageHan
void Server::HandleQueryInstanceRequest(const std::shared_ptr<ps::core::MessageHandler> &message) {
MS_ERROR_IF_NULL_WO_RET_VAL(message);
- nlohmann::json response;
+ nlohmann::basic_json<std::map, std::vector, std::string, bool, int64_t, uint64_t, float> response;
response["start_fl_job_threshold"] = ps::PSContext::instance()->start_fl_job_threshold();
response["start_fl_job_time_window"] = ps::PSContext::instance()->start_fl_job_time_window();
response["update_model_ratio"] = ps::PSContext::instance()->update_model_ratio();

View File

@@ -56,6 +56,7 @@ class Server {
void SwitchToSafeMode();
void CancelSafeMode();
bool IsSafeMode() const;
+ void WaitExitSafeMode() const;
// Whether the training job of the server is enabled.
InstanceState instance_state() const;

View File

@@ -25,7 +25,7 @@ namespace mindspore {
namespace fl {
namespace worker {
void FLWorker::Run() {
- if (running_) {
+ if (running_.load()) {
return;
}
running_ = true;
@@ -48,6 +48,7 @@ void FLWorker::Run() {
worker_node_->RegisterEventCallback(ps::core::ClusterEvent::SCHEDULER_TIMEOUT, [this]() {
Finalize();
+ running_ = false;
try {
MS_LOG(EXCEPTION)
<< "Event SCHEDULER_TIMEOUT is captured. This is because scheduler node is finalized or crashed.";
@@ -57,6 +58,7 @@ void FLWorker::Run() {
});
worker_node_->RegisterEventCallback(ps::core::ClusterEvent::NODE_TIMEOUT, [this]() {
Finalize();
+ running_ = false;
try {
MS_LOG(EXCEPTION)
<< "Event NODE_TIMEOUT is captured. This is because some server nodes are finalized or crashed after the "
@@ -148,6 +150,8 @@ uint32_t FLWorker::rank_id() const { return rank_id_; }
uint64_t FLWorker::worker_step_num_per_iteration() const { return worker_step_num_per_iteration_; }
bool FLWorker::running() const { return running_.load(); }
void FLWorker::SetIterationRunning() {
MS_LOG(INFO) << "Worker iteration starts.";
worker_iteration_state_ = IterationState::kRunning;
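running_ changes from a plain bool to std::atomic_bool because it is now read concurrently, via the new running() getter used by the kernels above, while cluster-event callbacks write it from other threads. A minimal sketch of the flag pattern:

#include <atomic>
#include <chrono>
#include <thread>

// A plain bool read and written from different threads is a data race;
// std::atomic_bool makes the flag safe without extra locking.
class Worker {
 public:
  void Run() {
    if (running_.load()) {
      return;  // .load() instead of reading a raw bool
    }
    running_ = true;
    // ... register callbacks that may set running_ = false on cluster events ...
  }
  void Finalize() { running_ = false; }
  bool running() const { return running_.load(); }

 private:
  std::atomic_bool running_{false};
};

int main() {
  Worker w;
  w.Run();
  std::thread watcher([&w] {
    while (w.running()) {  // safe concurrent read
      std::this_thread::sleep_for(std::chrono::milliseconds(10));
    }
  });
  w.Finalize();  // safe concurrent write
  watcher.join();
  return 0;
}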

View File

@@ -72,6 +72,9 @@ class FLWorker {
uint32_t rank_id() const;
uint64_t worker_step_num_per_iteration() const;
+ // Check whether worker has exited.
+ bool running() const;
// These methods set the worker's iteration state.
void SetIterationRunning();
void SetIterationCompleted();
@@ -116,7 +119,7 @@ class FLWorker {
void ProcessAfterScalingOut();
void ProcessAfterScalingIn();
- bool running_;
+ std::atomic_bool running_;
uint32_t server_num_;
uint32_t worker_num_;
std::string scheduler_ip_;

View File

@@ -78,10 +78,18 @@ FollowerScaler::~FollowerScaler() {
running_ = false;
scale_out_cv_.notify_all();
scale_in_cv_.notify_all();
- process_before_scale_out_thread_.join();
- process_before_scale_in_thread_.join();
- process_after_scale_out_thread_.join();
- process_after_scale_in_thread_.join();
+ if (process_before_scale_out_thread_.joinable()) {
+ process_before_scale_out_thread_.join();
+ }
+ if (process_before_scale_in_thread_.joinable()) {
+ process_before_scale_in_thread_.join();
+ }
+ if (process_after_scale_out_thread_.joinable()) {
+ process_after_scale_out_thread_.join();
+ }
+ if (process_after_scale_in_thread_.joinable()) {
+ process_after_scale_in_thread_.join();
+ }
}
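std::thread::join() throws std::system_error when the thread is not joinable (never started, already joined, or moved from), which can terminate the process while tearing down a partially started scaler; the joinable() guards above are the standard fix. A compact illustration:

#include <thread>

struct Scaler {
  std::thread worker_;
  ~Scaler() {
    if (worker_.joinable()) {
      worker_.join();  // only join threads that actually need joining
    }
  }
};

int main() {
  { Scaler idle; }  // worker_ never started: destructor is still safe
  {
    Scaler busy;
    busy.worker_ = std::thread([] {});
  }  // started thread is joined on destruction
  return 0;
}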
void FollowerScaler::RegisterScaleEventCallbacks() {