diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/fl/fused_pull_weight_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/fl/fused_pull_weight_kernel.h index 1ade8b4649e..15a8b6bb33c 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/fl/fused_pull_weight_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/fl/fused_pull_weight_kernel.h
@@ -107,6 +107,7 @@ class FusedPullWeightKernel : public CPUKernel { } } MS_LOG(INFO) << "Pull weights for " << weight_full_names_ << " succeed. Iteration: " << fl_iteration_; + ps::worker::FLWorker::GetInstance().SetIterationRunning(); return true; }
diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/fl/fused_push_weight_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/fl/fused_push_weight_kernel.h index 4d53568020c..a7431ec7566 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/fl/fused_push_weight_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/fl/fused_push_weight_kernel.h
@@ -68,8 +68,8 @@ class FusedPushWeightKernel : public CPUKernel { std::shared_ptr<std::vector<unsigned char>> push_weight_rsp_msg = nullptr; if (!ps::worker::FLWorker::GetInstance().SendToServer( i, fbb->GetBufferPointer(), fbb->GetSize(), ps::core::TcpUserCommand::kPushWeight, &push_weight_rsp_msg)) { - MS_LOG(EXCEPTION) << "Sending request for FusedPushWeight to server " << i << " failed."; - return false; + MS_LOG(ERROR) << "Sending request for FusedPushWeight to server " << i << " failed."; + continue; } MS_EXCEPTION_IF_NULL(push_weight_rsp_msg);
@@ -83,6 +83,7 @@ class FusedPushWeightKernel : public CPUKernel { } } MS_LOG(INFO) << "Push weights for " << weight_full_names_ << " succeed. Iteration: " << fl_iteration_; + ps::worker::FLWorker::GetInstance().SetIterationCompleted(); return true; }
diff --git a/mindspore/ccsrc/ps/core/communicator/tcp_communicator.h b/mindspore/ccsrc/ps/core/communicator/tcp_communicator.h index c524b02f52a..206a489a16c 100644 --- a/mindspore/ccsrc/ps/core/communicator/tcp_communicator.h +++ b/mindspore/ccsrc/ps/core/communicator/tcp_communicator.h
@@ -47,7 +47,11 @@ enum class TcpUserCommand { kCounterEvent, kPullWeight, kPushWeight, - kSyncIteration + kSyncIteration, + kNotifyLeaderToNextIter, + kPrepareForNextIter, + kProceedToNextIter, + kEndLastIter }; const std::unordered_map<TcpUserCommand, std::string> kUserCommandToMsgType = {
@@ -61,7 +65,11 @@ const std::unordered_map<TcpUserCommand, std::string> kUserCommandToMsgType = { {TcpUserCommand::kCounterEvent, "counterEvent"}, {TcpUserCommand::kPullWeight, "pullWeight"}, {TcpUserCommand::kPushWeight, "pushWeight"}, - {TcpUserCommand::kSyncIteration, "syncIteration"}}; + {TcpUserCommand::kSyncIteration, "syncIteration"}, + {TcpUserCommand::kNotifyLeaderToNextIter, "notifyLeaderToNextIter"}, + {TcpUserCommand::kPrepareForNextIter, "prepareForNextIter"}, + {TcpUserCommand::kProceedToNextIter, "proceedToNextIter"}, + {TcpUserCommand::kEndLastIter, "endLastIter"}}; class TcpCommunicator : public CommunicatorBase { public:
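Between the enum and the map above, a new TCP command becomes addressable by a stable string name, which the server code later uses as the key when registering message callbacks (see Iteration::RegisterMessageCallback further down in this patch). The following self-contained sketch shows that lookup-and-dispatch pattern in isolation; the Command enum, Handler type, and payload here are illustrative assumptions, not the MindSpore API.

```cpp
#include <functional>
#include <iostream>
#include <string>
#include <unordered_map>

enum class Command { kPullWeight, kPushWeight, kEndLastIter };

// Enum-to-name table, mirroring kUserCommandToMsgType above.
const std::unordered_map<Command, std::string> kCommandToName = {
    {Command::kPullWeight, "pullWeight"},
    {Command::kPushWeight, "pushWeight"},
    {Command::kEndLastIter, "endLastIter"}};

using Handler = std::function<void(const std::string &)>;

int main() {
  std::unordered_map<std::string, Handler> handlers;
  // Register a handler under the command's string name (cf. RegisterMsgCallBack).
  handlers["endLastIter"] = [](const std::string &payload) {
    std::cout << "endLastIter handled, payload: " << payload << "\n";
  };
  // Dispatch: translate the enum value to its name, then look up the handler.
  Command cmd = Command::kEndLastIter;
  auto it = handlers.find(kCommandToName.at(cmd));
  if (it != handlers.end()) {
    it->second("iteration=3");
  }
  return 0;
}
```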
diff --git a/mindspore/ccsrc/ps/core/protos/fl.proto b/mindspore/ccsrc/ps/core/protos/fl.proto index 9d46df9a24c..5dd0c7ffd80 100644 --- a/mindspore/ccsrc/ps/core/protos/fl.proto +++ b/mindspore/ccsrc/ps/core/protos/fl.proto
@@ -163,3 +163,41 @@ message SyncIterationResponse { // The current iteration number. uint64 iteration = 1; } + +message PrepareForNextIterRequest { + bool is_last_iter_valid = 1; + string reason = 2; +} + +message PrepareForNextIterResponse { + string result = 1; +} + +message NotifyLeaderMoveToNextIterRequest { + uint32 rank = 1; + bool is_last_iter_valid = 2; + uint64 iter_num = 3; + string reason = 4; +} + +message NotifyLeaderMoveToNextIterResponse { + string result = 1; +} + +message MoveToNextIterRequest { + bool is_last_iter_valid = 1; + uint64 last_iter_num = 2; + string reason = 3; +} + +message MoveToNextIterResponse { + string result = 1; +} + +message EndLastIterRequest { + uint64 last_iter_num = 1; +} + +message EndLastIterResponse { + string result = 1; +}
diff --git a/mindspore/ccsrc/ps/ps_context.cc b/mindspore/ccsrc/ps/ps_context.cc index 2eafebd1787..8cca46067e8 100644 --- a/mindspore/ccsrc/ps/ps_context.cc +++ b/mindspore/ccsrc/ps/ps_context.cc
@@ -45,7 +45,7 @@ void PSContext::SetPSEnable(bool enabled) { } else if (ms_role == kEnvRoleOfScheduler) { is_sched_ = true; } else { - MS_LOG(WARNING) << "MS_ROLE is " << ms_role << ", which is invalid."; + MS_LOG(INFO) << "MS_ROLE is " << ms_role; } worker_num_ = std::strtol(common::GetEnv(kEnvWorkerNum).c_str(), nullptr, 10);
@@ -273,7 +273,13 @@ void PSContext::set_start_fl_job_time_window(uint64_t start_fl_job_time_window) uint64_t PSContext::start_fl_job_time_window() const { return start_fl_job_time_window_; } -void PSContext::set_update_model_ratio(float update_model_ratio) { update_model_ratio_ = update_model_ratio; } +void PSContext::set_update_model_ratio(float update_model_ratio) { + if (update_model_ratio > 1.0) { + MS_LOG(EXCEPTION) << "update_model_ratio must be between 0 and 1."; + return; + } + update_model_ratio_ = update_model_ratio; +} float PSContext::update_model_ratio() const { return update_model_ratio_; }
diff --git a/mindspore/ccsrc/ps/ps_context.h b/mindspore/ccsrc/ps/ps_context.h index e407c37fdcd..8e91bf13bb0 100644 --- a/mindspore/ccsrc/ps/ps_context.h +++ b/mindspore/ccsrc/ps/ps_context.h
@@ -161,12 +161,12 @@ class PSContext { rank_id_(0), worker_num_(0), server_num_(0), - scheduler_host_(""), - scheduler_port_(0), + scheduler_host_("0.0.0.0"), + scheduler_port_(6667), role_(kEnvRoleOfNotPS), server_mode_(""), resetter_round_(ResetterRound::kNoNeedToReset), - fl_server_port_(0), + fl_server_port_(6668), fl_client_enable_(false), fl_name_(""), start_fl_job_threshold_(0),
@@ -179,7 +179,7 @@ client_learning_rate_(0.001), secure_aggregation_(false), cluster_config_(nullptr), - scheduler_manage_port_(0), + scheduler_manage_port_(11202), config_file_path_("") {} bool ps_enabled_; bool is_worker_;
diff --git a/mindspore/ccsrc/ps/server/common.h b/mindspore/ccsrc/ps/server/common.h index 558794c5981..08772175ea5 100644 --- a/mindspore/ccsrc/ps/server/common.h +++ b/mindspore/ccsrc/ps/server/common.h
@@ -63,9 +63,9 @@ using mindspore::kernel::Address; using mindspore::kernel::AddressPtr; using mindspore::kernel::CPUKernel; using FBBuilder = flatbuffers::FlatBufferBuilder; -using TimeOutCb = std::function<void(bool)>; +using TimeOutCb = std::function<void(bool, const std::string &)>; using StopTimerCb = std::function<void(void)>; -using FinishIterCb = std::function<void(bool)>; +using FinishIterCb = std::function<void(bool, const std::string &)>; using FinalizeCb = std::function<void(void)>; using MessageCallback = std::function<void(const std::shared_ptr<core::MessageHandler> &)>;
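The two `using` changes above are the backbone of this patch: every iteration-ending callback now carries a human-readable reason next to the validity flag, so the caller can report why an iteration ends. A minimal sketch of how such a callback can be layered, loosely mirroring the wrapper that Round::Initialize builds further down; all names here are illustrative.

```cpp
#include <functional>
#include <iostream>
#include <string>

// The widened callback type from common.h: validity flag plus a reason.
using FinishIterCb = std::function<void(bool, const std::string &)>;

int main() {
  // The "terminal" callback, standing in for Iteration::MoveToNextIteration.
  FinishIterCb move_to_next_iteration = [](bool is_valid, const std::string &reason) {
    std::cout << (is_valid ? "valid" : "invalid") << " iteration: " << reason << "\n";
  };

  // A round-level wrapper fills in its own reason before forwarding, the way
  // Round::Initialize builds finish_iteration_cb_ later in this patch.
  std::string round_name = "updateModel";
  FinishIterCb finish_iteration_cb = [&](bool is_valid, const std::string &) {
    std::string reason = "Round " + round_name + " finished! This iteration is valid.";
    move_to_next_iteration(is_valid, reason);
  };

  finish_iteration_cb(true, "");
  move_to_next_iteration(false, "Round " + round_name + " timeout!");
  return 0;
}
```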
diff --git a/mindspore/ccsrc/ps/server/distributed_count_service.cc b/mindspore/ccsrc/ps/server/distributed_count_service.cc index 0f3cb9df75a..5f6d7ff7000 100644 --- a/mindspore/ccsrc/ps/server/distributed_count_service.cc +++ b/mindspore/ccsrc/ps/server/distributed_count_service.cc
@@ -83,7 +83,10 @@ bool DistributedCountService::Count(const std::string &name, const std::string &id) { MS_LOG(INFO) << "Leader server increase count for " << name << " of " << id; global_current_count_[name].insert(id); - TriggerCounterEvent(name); + if (!TriggerCounterEvent(name)) { + MS_LOG(ERROR) << "Leader server trigger count event failed."; + return false; + } } else { // If this server is a follower server, it needs to send CountRequest to the leader server. CountRequest report_count_req;
@@ -198,9 +201,14 @@ void DistributedCountService::HandleCountRequest(const std::shared_ptr<core::MessageHandler> &message) { communicator_->SendResponse(count_rsp.SerializeAsString().data(), count_rsp.SerializeAsString().size(), message); return; }
@@ -256,20 +264,24 @@ void DistributedCountService::HandleCounterEvent(const std::shared_ptr<core::MessageHandler> &message) { -void DistributedCountService::TriggerFirstCountEvent(const std::string &name) { +bool DistributedCountService::TriggerFirstCountEvent(const std::string &name) { if (!communicator_->SendPbRequest(first_count_event, i, core::TcpUserCommand::kCounterEvent)) { MS_LOG(ERROR) << "Activating first count event to server " << i << " failed."; - return; + return false; } } // Leader server directly calls the callback. counter_handlers_[name].first_count_handler(nullptr); - return; + return true; } -void DistributedCountService::TriggerLastCountEvent(const std::string &name) { +bool DistributedCountService::TriggerLastCountEvent(const std::string &name) { MS_LOG(INFO) << "Activating last count event for " << name; CounterEvent last_count_event; last_count_event.set_type(CounterEventType::LAST_CNT);
@@ -297,12 +309,12 @@ void DistributedCountService::TriggerLastCountEvent(const std::string &name) { for (uint32_t i = 1; i < server_num_; i++) { if (!communicator_->SendPbRequest(last_count_event, i, core::TcpUserCommand::kCounterEvent)) { MS_LOG(ERROR) << "Activating last count event to server " << i << " failed."; - return; + return false; } } // Leader server directly calls the callback. counter_handlers_[name].last_count_handler(nullptr); - return; + return true; } } // namespace server } // namespace ps
diff --git a/mindspore/ccsrc/ps/server/distributed_count_service.h b/mindspore/ccsrc/ps/server/distributed_count_service.h index 1e18dd33c1f..be387718cc3 100644 --- a/mindspore/ccsrc/ps/server/distributed_count_service.h +++ b/mindspore/ccsrc/ps/server/distributed_count_service.h
@@ -98,9 +98,9 @@ class DistributedCountService { void HandleCounterEvent(const std::shared_ptr<core::MessageHandler> &message); // Call the callbacks when the first/last count event is triggered. - void TriggerCounterEvent(const std::string &name); - void TriggerFirstCountEvent(const std::string &name); - void TriggerLastCountEvent(const std::string &name); + bool TriggerCounterEvent(const std::string &name); + bool TriggerFirstCountEvent(const std::string &name); + bool TriggerLastCountEvent(const std::string &name); // Members for the communication between counting server and other servers. std::shared_ptr<core::ServerNode> server_node_;
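The hunks above convert fire-and-forget broadcast helpers into bool-returning ones, so a failed send reaches the caller instead of being swallowed by an early `return;`. A self-contained sketch of that conversion; SendTo and BroadcastLastCountEvent are assumed stand-ins, not the real communicator API.

```cpp
#include <cstdint>
#include <iostream>

// Stand-in for communicator_->SendPbRequest; pretend rank 2 is unreachable.
bool SendTo(uint32_t rank) { return rank != 2; }

// The void -> bool shape of TriggerLastCountEvent: stop at the first failed
// send and report the failure to the caller.
bool BroadcastLastCountEvent(uint32_t server_num) {
  for (uint32_t i = 1; i < server_num; i++) {
    if (!SendTo(i)) {
      std::cout << "Activating last count event to server " << i << " failed.\n";
      return false;  // Was a bare 'return;' in the old void version.
    }
  }
  return true;  // All follower servers were reached.
}

int main() {
  // Mirrors the new checked call site in DistributedCountService::Count.
  if (!BroadcastLastCountEvent(4)) {
    std::cout << "Leader server trigger count event failed.\n";
    return 1;
  }
  return 0;
}
```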
diff --git a/mindspore/ccsrc/ps/server/iteration.cc b/mindspore/ccsrc/ps/server/iteration.cc index e405663c065..43a1d93e958 100644 --- a/mindspore/ccsrc/ps/server/iteration.cc +++ b/mindspore/ccsrc/ps/server/iteration.cc
@@ -20,15 +20,26 @@ #include <memory> #include <vector> #include "ps/server/model_store.h" +#include "ps/server/server.h" namespace mindspore { namespace ps { namespace server { +class Server; void Iteration::RegisterMessageCallback(const std::shared_ptr<core::TcpCommunicator> &communicator) { MS_EXCEPTION_IF_NULL(communicator); communicator_ = communicator; communicator_->RegisterMsgCallBack("syncIteraion", std::bind(&Iteration::HandleSyncIterationRequest, this, std::placeholders::_1)); + communicator_->RegisterMsgCallBack( + "notifyLeaderToNextIter", + std::bind(&Iteration::HandleNotifyLeaderMoveToNextIterRequest, this, std::placeholders::_1)); + communicator_->RegisterMsgCallBack( + "prepareForNextIter", std::bind(&Iteration::HandlePrepareForNextIterRequest, this, std::placeholders::_1)); + communicator_->RegisterMsgCallBack("proceedToNextIter", + std::bind(&Iteration::HandleMoveToNextIterRequest, this, std::placeholders::_1)); + communicator_->RegisterMsgCallBack("endLastIter", + std::bind(&Iteration::HandleEndLastIterRequest, this, std::placeholders::_1)); } void Iteration::RegisterEventCallback(const std::shared_ptr<core::ServerNode> &server_node) {
@@ -72,36 +83,33 @@ void Iteration::InitRounds(const std::vector<std::shared_ptr<core::CommunicatorBase>> &communicators, -void Iteration::ProceedToNextIter(bool is_iteration_valid) { +void Iteration::MoveToNextIteration(bool is_last_iter_valid, const std::string &reason) { + if (server_node_->rank_id() == kLeaderServerRank) { + if (!BroadcastPrepareForNextIterRequest(is_last_iter_valid, reason)) { + MS_LOG(ERROR) << "Broadcast prepare for next iteration request failed."; + return; + } + if (!BroadcastMoveToNextIterRequest(is_last_iter_valid, reason)) { + MS_LOG(ERROR) << "Broadcast proceed to next iteration request failed."; + return; + } + if (!BroadcastEndLastIterRequest(iteration_num_)) { + MS_LOG(ERROR) << "Broadcast end last iteration request failed."; + return; + } } else { - // Store last iteration's model because this iteration is considered as invalid. - const auto &model = ModelStore::GetInstance().GetModelByIterNum(iteration_num_ - 1); - ModelStore::GetInstance().StoreModelByIterNum(iteration_num_, model); - MS_LOG(WARNING) << "Iteration " << iteration_num_ << " is invalid."; + // If this server is the follower server, notify leader server to control the cluster to proceed to next iteration. + if (!NotifyLeaderMoveToNextIteration(is_last_iter_valid, reason)) { + MS_LOG(ERROR) << "Server " << server_node_->rank_id() << " notifying the leader server failed."; + return; + } } - - for (auto &round : rounds_) { - round->Reset(); - } - - iteration_num_++; - // After the job is done, reset the iteration to the initial number and reset ModelStore.
- if (iteration_num_ > PSContext::instance()->fl_iteration_num()) { - MS_LOG(INFO) << PSContext::instance()->fl_iteration_num() << " iterations are completed."; - iteration_num_ = 1; - ModelStore::GetInstance().Reset(); - } - - SetIterationCompleted(); - LocalMetaStore::GetInstance().set_curr_iter_num(iteration_num_); - MS_LOG(INFO) << "Proceed to next iteration:" << iteration_num_ << "\n"; } void Iteration::SetIterationRunning() {
@@ -147,7 +155,7 @@ bool Iteration::ReInitForScaling(uint32_t server_num, uint32_t server_rank) { } for (auto &round : rounds_) { if (!round->ReInitForScaling(server_num)) { - MS_LOG(ERROR) << "Reinitializing round " << round->name() << " for scaling failed."; + MS_LOG(WARNING) << "Reinitializing round " << round->name() << " for scaling failed."; return false; } }
@@ -168,6 +176,10 @@ bool Iteration::SyncIteration(uint32_t rank) { MS_LOG(ERROR) << "Sending synchronizing iteration message to leader server failed."; return false; } + if (sync_iter_rsp_msg == nullptr) { + MS_LOG(ERROR) << "Response from server 0 is empty."; + return false; + } SyncIterationResponse sync_iter_rsp; sync_iter_rsp.ParseFromArray(sync_iter_rsp_msg->data(), sync_iter_rsp_msg->size());
@@ -192,6 +204,239 @@ void Iteration::HandleSyncIterationRequest(const std::shared_ptr<core::MessageHandler> &message) { communicator_->SendResponse(sync_iter_rsp_msg.data(), sync_iter_rsp_msg.size(), message); } + +bool Iteration::IsMoveToNextIterRequestReentrant(uint64_t iteration_num) { + std::unique_lock<std::mutex> lock(pinned_mtx_); + if (pinned_iter_num_ == iteration_num) { + MS_LOG(WARNING) << "MoveToNextIteration is not reentrant. Ignore this call."; + return true; + } + pinned_iter_num_ = iteration_num; + return false; +} + +bool Iteration::NotifyLeaderMoveToNextIteration(bool is_last_iter_valid, const std::string &reason) { + MS_LOG(INFO) << "Notify leader server to control the cluster to proceed to next iteration."; + NotifyLeaderMoveToNextIterRequest notify_leader_to_next_iter_req; + notify_leader_to_next_iter_req.set_rank(server_node_->rank_id()); + notify_leader_to_next_iter_req.set_is_last_iter_valid(is_last_iter_valid); + notify_leader_to_next_iter_req.set_iter_num(iteration_num_); + notify_leader_to_next_iter_req.set_reason(reason); + if (!communicator_->SendPbRequest(notify_leader_to_next_iter_req, kLeaderServerRank, + core::TcpUserCommand::kNotifyLeaderToNextIter)) { + MS_LOG(WARNING) << "Sending NotifyLeaderMoveToNextIterRequest to leader server 0 failed."; + return false; + } + return true; +} + +void Iteration::HandleNotifyLeaderMoveToNextIterRequest(const std::shared_ptr<core::MessageHandler> &message) { + if (message == nullptr) { + return; + } + + NotifyLeaderMoveToNextIterResponse notify_leader_to_next_iter_rsp; + notify_leader_to_next_iter_rsp.set_result("success"); + communicator_->SendResponse(notify_leader_to_next_iter_rsp.SerializeAsString().data(), + notify_leader_to_next_iter_rsp.SerializeAsString().size(), message); + + NotifyLeaderMoveToNextIterRequest notify_leader_to_next_iter_req; + notify_leader_to_next_iter_req.ParseFromArray(message->data(), message->len()); + const auto &rank = notify_leader_to_next_iter_req.rank(); + const auto &is_last_iter_valid = notify_leader_to_next_iter_req.is_last_iter_valid(); + const auto &iter_num = notify_leader_to_next_iter_req.iter_num(); + const auto &reason = notify_leader_to_next_iter_req.reason(); + MS_LOG(INFO) << "Leader server receives NotifyLeaderMoveToNextIterRequest from rank " << rank + << ". Iteration number: " << iter_num << ". Reason: " << reason;
Reason: " << reason; + + if (IsMoveToNextIterRequestReentrant(iter_num)) { + return; + } + + if (!BroadcastPrepareForNextIterRequest(is_last_iter_valid, reason)) { + MS_LOG(ERROR) << "Broadcast prepare for next iteration request failed."; + return; + } + if (!BroadcastMoveToNextIterRequest(is_last_iter_valid, reason)) { + MS_LOG(ERROR) << "Broadcast proceed to next iteration request failed."; + return; + } + if (!BroadcastEndLastIterRequest(iteration_num_)) { + MS_LOG(ERROR) << "Broadcast end last iteration request failed."; + return; + } +} + +bool Iteration::BroadcastPrepareForNextIterRequest(bool is_last_iter_valid, const std::string &reason) { + PrepareForNextIter(); + + MS_LOG(INFO) << "Notify all follower servers to prepare for next iteration."; + PrepareForNextIterRequest prepare_next_iter_req; + prepare_next_iter_req.set_is_last_iter_valid(is_last_iter_valid); + prepare_next_iter_req.set_reason(reason); + + std::vector offline_servers = {}; + for (uint32_t i = 1; i < IntToUint(server_node_->server_num()); i++) { + if (!communicator_->SendPbRequest(prepare_next_iter_req, i, core::TcpUserCommand::kPrepareForNextIter)) { + MS_LOG(WARNING) << "Sending prepare for next iteration request to server " << i << " failed. Retry later."; + offline_servers.push_back(i); + continue; + } + } + + // Retry sending to offline servers to notify them to prepare. + std::for_each(offline_servers.begin(), offline_servers.end(), [&](uint32_t rank) { + while (!communicator_->SendPbRequest(prepare_next_iter_req, rank, core::TcpUserCommand::kPrepareForNextIter)) { + MS_LOG(WARNING) << "Retry sending prepare for next iteration request to server " << rank + << " failed. The server has not recovered yet."; + std::this_thread::sleep_for(std::chrono::milliseconds(kRetryDurationForPrepareForNextIter)); + } + MS_LOG(INFO) << "Offline server " << rank << " preparing for next iteration success."; + }); + return true; +} + +void Iteration::HandlePrepareForNextIterRequest(const std::shared_ptr &message) { + if (message == nullptr) { + return; + } + + PrepareForNextIterRequest prepare_next_iter_req; + prepare_next_iter_req.ParseFromArray(message->data(), message->len()); + const auto &reason = prepare_next_iter_req.reason(); + MS_LOG(INFO) << "Prepare next iteration for this rank " << server_node_->rank_id() << ", reason: " << reason; + PrepareForNextIter(); + + PrepareForNextIterResponse prepare_next_iter_rsp; + prepare_next_iter_rsp.set_result("success"); + communicator_->SendResponse(prepare_next_iter_rsp.SerializeAsString().data(), + prepare_next_iter_rsp.SerializeAsString().size(), message); +} + +void Iteration::PrepareForNextIter() { + MS_LOG(INFO) << "Prepare for next iteration. Switch the server to safemode."; + Server::GetInstance().SwitchToSafeMode(); +} + +bool Iteration::BroadcastMoveToNextIterRequest(bool is_last_iter_valid, const std::string &reason) { + MS_LOG(INFO) << "Notify all follower servers to proceed to next iteration. 
+ MoveToNextIterRequest proceed_to_next_iter_req; + proceed_to_next_iter_req.set_is_last_iter_valid(is_last_iter_valid); + proceed_to_next_iter_req.set_last_iter_num(iteration_num_); + proceed_to_next_iter_req.set_reason(reason); + for (uint32_t i = 1; i < IntToUint(server_node_->server_num()); i++) { + if (!communicator_->SendPbRequest(proceed_to_next_iter_req, i, core::TcpUserCommand::kProceedToNextIter)) { + MS_LOG(WARNING) << "Sending proceed to next iteration request to server " << i << " failed."; + continue; + } + } + + Next(is_last_iter_valid, reason); + return true; +} + +void Iteration::HandleMoveToNextIterRequest(const std::shared_ptr<core::MessageHandler> &message) { + if (message == nullptr) { + return; + } + + MoveToNextIterResponse proceed_to_next_iter_rsp; + proceed_to_next_iter_rsp.set_result("success"); + communicator_->SendResponse(proceed_to_next_iter_rsp.SerializeAsString().data(), + proceed_to_next_iter_rsp.SerializeAsString().size(), message); + + MoveToNextIterRequest proceed_to_next_iter_req; + proceed_to_next_iter_req.ParseFromArray(message->data(), message->len()); + const auto &is_last_iter_valid = proceed_to_next_iter_req.is_last_iter_valid(); + const auto &last_iter_num = proceed_to_next_iter_req.last_iter_num(); + const auto &reason = proceed_to_next_iter_req.reason(); + + MS_LOG(INFO) << "Receive proceeding to next iteration request. This server current iteration is " << iteration_num_ + << ". The iteration number from leader server is " << last_iter_num + << ". Last iteration is valid or not: " << is_last_iter_valid << ". Reason: " << reason; + // Synchronize the iteration number with leader server. + iteration_num_ = last_iter_num; + Next(is_last_iter_valid, reason); +} + +void Iteration::Next(bool is_iteration_valid, const std::string &reason) { + MS_LOG(INFO) << "Prepare for next iteration."; + is_last_iteration_valid_ = is_iteration_valid; + if (is_iteration_valid) { + // Store the model which is successfully aggregated for this iteration. + const auto &model = Executor::GetInstance().GetModel(); + ModelStore::GetInstance().StoreModelByIterNum(iteration_num_, model); + MS_LOG(INFO) << "Iteration " << iteration_num_ << " is successfully finished."; + } else { + // Store last iteration's model because this iteration is considered as invalid. + const auto &model = ModelStore::GetInstance().GetModelByIterNum(iteration_num_ - 1); + ModelStore::GetInstance().StoreModelByIterNum(iteration_num_, model); + MS_LOG(WARNING) << "Iteration " << iteration_num_ << " is invalid. Reason: " << reason; + }
Reason: " << reason; + } + + for (auto &round : rounds_) { + round->Reset(); + } +} + +bool Iteration::BroadcastEndLastIterRequest(uint64_t last_iter_num) { + MS_LOG(INFO) << "Notify all follower servers to end last iteration."; + EndLastIterRequest end_last_iter_req; + end_last_iter_req.set_last_iter_num(last_iter_num); + for (uint32_t i = 1; i < IntToUint(server_node_->server_num()); i++) { + if (!communicator_->SendPbRequest(end_last_iter_req, i, core::TcpUserCommand::kEndLastIter)) { + MS_LOG(WARNING) << "Sending proceed to next iteration request to server " << i << " failed."; + continue; + } + } + + EndLastIter(); + return true; +} + +void Iteration::HandleEndLastIterRequest(const std::shared_ptr &message) { + if (message == nullptr) { + return; + } + + EndLastIterRequest end_last_iter_req; + end_last_iter_req.ParseFromArray(message->data(), message->len()); + const auto &last_iter_num = end_last_iter_req.last_iter_num(); + // If the iteration number is not matched, return error. + if (last_iter_num != iteration_num_) { + std::string reason = "The iteration of this server " + std::to_string(server_node_->rank_id()) + " is " + + std::to_string(iteration_num_) + ", iteration to be ended is " + std::to_string(last_iter_num); + EndLastIterResponse end_last_iter_rsp; + end_last_iter_rsp.set_result(reason); + communicator_->SendResponse(end_last_iter_rsp.SerializeAsString().data(), + end_last_iter_rsp.SerializeAsString().size(), message); + return; + } + + EndLastIter(); + + EndLastIterResponse end_last_iter_rsp; + end_last_iter_rsp.set_result("success"); + communicator_->SendResponse(end_last_iter_rsp.SerializeAsString().data(), + end_last_iter_rsp.SerializeAsString().size(), message); +} + +void Iteration::EndLastIter() { + MS_LOG(INFO) << "End the last iteration " << iteration_num_; + iteration_num_++; + // After the job is done, reset the iteration to the initial number and reset ModelStore. + if (iteration_num_ > PSContext::instance()->fl_iteration_num()) { + MS_LOG(INFO) << PSContext::instance()->fl_iteration_num() << " iterations are completed."; + iteration_num_ = 1; + ModelStore::GetInstance().Reset(); + } + + Server::GetInstance().CancelSafeMode(); + SetIterationCompleted(); + LocalMetaStore::GetInstance().set_curr_iter_num(iteration_num_); + MS_LOG(INFO) << "Move to next iteration:" << iteration_num_ << "\n"; +} } // namespace server } // namespace ps } // namespace mindspore diff --git a/mindspore/ccsrc/ps/server/iteration.h b/mindspore/ccsrc/ps/server/iteration.h index e466383aeb2..99e51385044 100644 --- a/mindspore/ccsrc/ps/server/iteration.h +++ b/mindspore/ccsrc/ps/server/iteration.h @@ -19,6 +19,7 @@ #include #include +#include #include "ps/core/communicator/communicator_base.h" #include "ps/server/common.h" #include "ps/server/round.h" @@ -34,6 +35,9 @@ enum class IterationState { kCompleted }; +// The time duration between retrying when sending prepare for next iteration request failed. +constexpr uint32_t kRetryDurationForPrepareForNextIter = 500; + // In server's logic, Iteration is the minimum execution unit. For each execution, it consists of multiple kinds of // Rounds, only after all the rounds are finished, this iteration is considered as completed. class Iteration { @@ -56,9 +60,10 @@ class Iteration { void InitRounds(const std::vector> &communicators, const TimeOutCb &timeout_cb, const FinishIterCb &finish_iteration_cb); - // The server proceeds to the next iteration only after the last round finishes or the timer expires. 
diff --git a/mindspore/ccsrc/ps/server/iteration.h b/mindspore/ccsrc/ps/server/iteration.h index e466383aeb2..99e51385044 100644 --- a/mindspore/ccsrc/ps/server/iteration.h +++ b/mindspore/ccsrc/ps/server/iteration.h
@@ -19,6 +19,7 @@ #include <memory> #include <vector> +#include <mutex> #include "ps/core/communicator/communicator_base.h" #include "ps/server/common.h" #include "ps/server/round.h"
@@ -34,6 +35,9 @@ enum class IterationState { kCompleted }; +// The time duration between retries when sending the prepare-for-next-iteration request fails. +constexpr uint32_t kRetryDurationForPrepareForNextIter = 500; + // In server's logic, Iteration is the minimum execution unit. For each execution, it consists of multiple kinds of // Rounds, only after all the rounds are finished, this iteration is considered as completed. class Iteration {
@@ -56,9 +60,10 @@ class Iteration { void InitRounds(const std::vector<std::shared_ptr<core::CommunicatorBase>> &communicators, const TimeOutCb &timeout_cb, const FinishIterCb &finish_iteration_cb); - // The server proceeds to the next iteration only after the last round finishes or the timer expires. - // If the timer expires, we consider this iteration as invalid. - void ProceedToNextIter(bool is_iteration_valid); + // This method controls how the servers proceed to the next iteration. + // There is communication between leader and follower servers in this method. + // The server moves to the next iteration only after the last round finishes or the time expires. + void MoveToNextIteration(bool is_last_iter_valid, const std::string &reason); // Set current iteration state to running and trigger events about kIterationRunning. void SetIterationRunning();
@@ -84,7 +89,8 @@ class Iteration { communicator_(nullptr), iteration_state_(IterationState::kCompleted), iteration_num_(1), - is_last_iteration_valid_(true) { + is_last_iteration_valid_(true), + pinned_iter_num_(0) { LocalMetaStore::GetInstance().set_curr_iter_num(iteration_num_); } ~Iteration() = default;
@@ -99,6 +105,32 @@ class Iteration { bool SyncIteration(uint32_t rank); void HandleSyncIterationRequest(const std::shared_ptr<core::MessageHandler> &message); + // The request for moving to next iteration is not reentrant. + bool IsMoveToNextIterRequestReentrant(uint64_t iteration_num); + + // The methods for moving to next iteration for all the servers. + // Step 1: follower servers notify the leader server that they need to move to next iteration. + bool NotifyLeaderMoveToNextIteration(bool is_last_iter_valid, const std::string &reason); + void HandleNotifyLeaderMoveToNextIterRequest(const std::shared_ptr<core::MessageHandler> &message); + + // Step 2: the leader server broadcasts to all follower servers to prepare for next iteration and switch to safemode. + bool BroadcastPrepareForNextIterRequest(bool is_last_iter_valid, const std::string &reason); + void HandlePrepareForNextIterRequest(const std::shared_ptr<core::MessageHandler> &message); + // The server prepares for the next iteration. This method switches the server to safemode. + void PrepareForNextIter(); + + // Step 3: the leader server broadcasts to all follower servers to move to next iteration. + bool BroadcastMoveToNextIterRequest(bool is_last_iter_valid, const std::string &reason); + void HandleMoveToNextIterRequest(const std::shared_ptr<core::MessageHandler> &message); + // Move to next iteration. Store last iteration's model and reset all the rounds. + void Next(bool is_iteration_valid, const std::string &reason); + + // Step 4: the leader server broadcasts to all follower servers to end last iteration and cancel the safemode. + bool BroadcastEndLastIterRequest(uint64_t iteration_num); + void HandleEndLastIterRequest(const std::shared_ptr<core::MessageHandler> &message); + // The server ends the last iteration. This method increases the iteration number and cancels the safemode. + void EndLastIter(); + std::shared_ptr<core::ServerNode> server_node_; std::shared_ptr<core::TcpCommunicator> communicator_;
@@ -113,6 +145,10 @@ class Iteration { // Last iteration is successfully finished. bool is_last_iteration_valid_; + + // To avoid the Next method being called multiple times in one iteration, the iteration number is pinned here. + uint64_t pinned_iter_num_; + std::mutex pinned_mtx_; }; } // namespace server } // namespace ps
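The pinned_iter_num_/pinned_mtx_ pair declared above backs IsMoveToNextIterRequestReentrant: the first request for a given iteration pins its number under the mutex, and later requests for the same iteration are treated as duplicates, so Next runs at most once per iteration. A simplified, self-contained sketch of the same guard:

```cpp
#include <cstdint>
#include <iostream>
#include <mutex>

class ReentrancyPin {
 public:
  // Returns true if this iteration was already pinned, i.e. the call is a duplicate.
  bool IsReentrant(uint64_t iteration_num) {
    std::unique_lock<std::mutex> lock(mtx_);
    if (pinned_iter_num_ == iteration_num) {
      return true;
    }
    pinned_iter_num_ = iteration_num;
    return false;
  }

 private:
  uint64_t pinned_iter_num_ = 0;
  std::mutex mtx_;
};

int main() {
  ReentrancyPin pin;
  std::cout << pin.IsReentrant(5) << "\n";  // 0: the first call pins iteration 5.
  std::cout << pin.IsReentrant(5) << "\n";  // 1: a duplicate call is ignored.
  std::cout << pin.IsReentrant(6) << "\n";  // 0: a new iteration pins again.
  return 0;
}
```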
diff --git a/mindspore/ccsrc/ps/server/iteration_timer.cc b/mindspore/ccsrc/ps/server/iteration_timer.cc index a3291b56183..19233339575 100644 --- a/mindspore/ccsrc/ps/server/iteration_timer.cc +++ b/mindspore/ccsrc/ps/server/iteration_timer.cc
@@ -29,7 +29,7 @@ void IterationTimer::Start(const std::chrono::milliseconds &duration) { monitor_thread_ = std::thread([&]() { while (running_.load()) { if (CURRENT_TIME_MILLI > end_time_) { - timeout_callback_(false); + timeout_callback_(false, ""); running_ = false; } // The time tick is 1 millisecond.
diff --git a/mindspore/ccsrc/ps/server/kernel/round/get_model_kernel.cc b/mindspore/ccsrc/ps/server/kernel/round/get_model_kernel.cc index 405f73e9233..565cb2515de 100644 --- a/mindspore/ccsrc/ps/server/kernel/round/get_model_kernel.cc +++ b/mindspore/ccsrc/ps/server/kernel/round/get_model_kernel.cc
@@ -62,6 +62,7 @@ bool GetModelKernel::Reset() { } void GetModelKernel::GetModel(const schema::RequestGetModel *get_model_req, const std::shared_ptr<FBBuilder> &fbb) { + auto next_req_time = LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp); std::map<std::string, AddressPtr> feature_maps; size_t current_iter = LocalMetaStore::GetInstance().curr_iter_num(); size_t get_model_iter = static_cast<size_t>(get_model_req->iteration());
@@ -70,9 +71,11 @@ // If this iteration is not finished yet, return ResponseCode_SucNotReady so that clients could get model later. if ((current_iter == get_model_iter && latest_iter_num != current_iter) || current_iter == get_model_iter - 1) { - std::string reason = "The model is not ready yet for iteration " + std::to_string(get_model_iter); + std::string reason = "The model is not ready yet for iteration " + std::to_string(get_model_iter) + + ". Maybe this is because\n" "1. Client doesn't send enough update model requests.\n" + + "2. Worker has not pushed all the weights to servers."; BuildGetModelRsp(fbb, schema::ResponseCode_SucNotReady, reason, current_iter, feature_maps, - std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp))); + std::to_string(next_req_time)); MS_LOG(WARNING) << reason; return; }
@@ -88,11 +91,12 @@ // If the iteration of this model is invalid, return ResponseCode_OutOfTime to the clients could startFLJob according // to next_req_time. - auto response_code = - Iteration::GetInstance().is_last_iteration_valid() ? schema::ResponseCode_SUCCEED : schema::ResponseCode_OutOfTime; + bool last_iter_valid = Iteration::GetInstance().is_last_iteration_valid(); + MS_LOG(INFO) << "GetModel last iteration is valid or not: " << last_iter_valid << ", next request time is " + << next_req_time << ", current iteration is " << current_iter; + auto response_code = last_iter_valid ? schema::ResponseCode_SUCCEED : schema::ResponseCode_OutOfTime;
BuildGetModelRsp(fbb, response_code, "Get model for iteration " + std::to_string(get_model_iter), current_iter, - feature_maps, - std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp))); + feature_maps, std::to_string(next_req_time)); return; }
diff --git a/mindspore/ccsrc/ps/server/kernel/round/push_weight_kernel.cc b/mindspore/ccsrc/ps/server/kernel/round/push_weight_kernel.cc index af43b5ac56f..8fe72994900 100644 --- a/mindspore/ccsrc/ps/server/kernel/round/push_weight_kernel.cc +++ b/mindspore/ccsrc/ps/server/kernel/round/push_weight_kernel.cc
@@ -48,9 +48,9 @@ bool PushWeightKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, const std::vector<AddressPtr> &outputs) { return false; } - PushWeight(fbb, push_weight_req); + bool ret = PushWeight(fbb, push_weight_req); GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize()); - return true; + return ret; } bool PushWeightKernel::Reset() {
@@ -67,9 +67,9 @@ void PushWeightKernel::OnLastCountEvent(const std::shared_ptr<core::MessageHandler> &message) { -void PushWeightKernel::PushWeight(std::shared_ptr<FBBuilder> fbb, const schema::RequestPushWeight *push_weight_req) { +bool PushWeightKernel::PushWeight(std::shared_ptr<FBBuilder> fbb, const schema::RequestPushWeight *push_weight_req) { if (fbb == nullptr || push_weight_req == nullptr) { - return; + return false; } size_t iteration = static_cast<size_t>(push_weight_req->iteration()); size_t current_iter = LocalMetaStore::GetInstance().curr_iter_num();
@@ -77,8 +77,8 @@ void PushWeightKernel::PushWeight(std::shared_ptr<FBBuilder> fbb, const schema::RequestPushWeight *push_weight_req) { std::string reason = "PushWeight iteration number is invalid:" + std::to_string(iteration) + ", current iteration:" + std::to_string(current_iter); BuildPushWeightRsp(fbb, schema::ResponseCode_OutOfTime, reason, current_iter); - MS_LOG(ERROR) << reason; - return; + MS_LOG(WARNING) << reason; + return true; } std::map<std::string, Address> upload_feature_map = ParseFeatureMap(push_weight_req);
@@ -86,20 +86,25 @@ void PushWeightKernel::PushWeight(std::shared_ptr<FBBuilder> fbb, const schema::RequestPushWeight *push_weight_req) { if (upload_feature_map.empty()) { std::string reason = "PushWeight feature_map is empty."; BuildPushWeightRsp(fbb, schema::ResponseCode_RequestError, reason, current_iter); MS_LOG(ERROR) << reason; - return; + return false; } if (!executor_->HandlePushWeight(upload_feature_map)) { std::string reason = "Pushing weight failed."; BuildPushWeightRsp(fbb, schema::ResponseCode_SystemError, reason, current_iter); MS_LOG(ERROR) << reason; - return; + return false; } MS_LOG(INFO) << "Pushing weight for iteration " << current_iter << " succeeds."; - DistributedCountService::GetInstance().Count(name_, std::to_string(local_rank_)); + if (!DistributedCountService::GetInstance().Count(name_, std::to_string(local_rank_))) { + std::string reason = "Count for push weight request failed."; + BuildPushWeightRsp(fbb, schema::ResponseCode_SystemError, reason, current_iter); + MS_LOG(ERROR) << reason; + return false; + } BuildPushWeightRsp(fbb, schema::ResponseCode_SUCCEED, "PushWeight succeed.", current_iter); - return; + return true; } std::map<std::string, Address> PushWeightKernel::ParseFeatureMap(const schema::RequestPushWeight *push_weight_req) {
diff --git a/mindspore/ccsrc/ps/server/kernel/round/push_weight_kernel.h b/mindspore/ccsrc/ps/server/kernel/round/push_weight_kernel.h index 0bfff05c104..49c577457ef 100644 --- a/mindspore/ccsrc/ps/server/kernel/round/push_weight_kernel.h +++ b/mindspore/ccsrc/ps/server/kernel/round/push_weight_kernel.h
@@ -42,7 +42,7 @@ class PushWeightKernel : public RoundKernel { void OnLastCountEvent(const std::shared_ptr<core::MessageHandler> &message) override; private: - void PushWeight(std::shared_ptr<FBBuilder> fbb, const schema::RequestPushWeight *push_weight_req);
+ bool PushWeight(std::shared_ptr<FBBuilder> fbb, const schema::RequestPushWeight *push_weight_req); std::map<std::string, Address> ParseFeatureMap(const schema::RequestPushWeight *push_weight_req); void BuildPushWeightRsp(std::shared_ptr<FBBuilder> fbb, const schema::ResponseCode retcode, const std::string &reason, size_t iteration);
diff --git a/mindspore/ccsrc/ps/server/kernel/round/round_kernel.cc b/mindspore/ccsrc/ps/server/kernel/round/round_kernel.cc index 11a83d86494..b8560a9db85 100644 --- a/mindspore/ccsrc/ps/server/kernel/round/round_kernel.cc +++ b/mindspore/ccsrc/ps/server/kernel/round/round_kernel.cc
@@ -68,7 +68,7 @@ void RoundKernel::StopTimer() { void RoundKernel::FinishIteration() { if (finish_iteration_cb_) { - finish_iteration_cb_(true); + finish_iteration_cb_(true, ""); } return; }
diff --git a/mindspore/ccsrc/ps/server/kernel/round/start_fl_job_kernel.cc b/mindspore/ccsrc/ps/server/kernel/round/start_fl_job_kernel.cc index af3e1eecc19..ec1754334db 100644 --- a/mindspore/ccsrc/ps/server/kernel/round/start_fl_job_kernel.cc +++ b/mindspore/ccsrc/ps/server/kernel/round/start_fl_job_kernel.cc
@@ -61,7 +61,7 @@ bool StartFLJobKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, const std::vector<AddressPtr> &outputs) { if (ReachThresholdForStartFLJob(fbb)) { GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize()); - return false; + return true; } const schema::RequestFLJob *start_fl_job_req = flatbuffers::GetRoot<schema::RequestFLJob>(req_data);
@@ -102,7 +102,7 @@ bool StartFLJobKernel::ReachThresholdForStartFLJob(const std::shared_ptr<FBBuilder> &fbb) { std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp))); - MS_LOG(ERROR) << reason; + MS_LOG(WARNING) << reason; return true; } return false; }
diff --git a/mindspore/ccsrc/ps/server/round.cc b/mindspore/ccsrc/ps/server/round.cc index b02bf07bc9a..a3cc8a33ce6 100644 --- a/mindspore/ccsrc/ps/server/round.cc +++ b/mindspore/ccsrc/ps/server/round.cc
@@ -18,11 +18,13 @@ #include <memory> #include <string> #include "ps/server/server.h" +#include "ps/server/iteration.h" namespace mindspore { namespace ps { namespace server { class Server; +class Iteration; Round::Round(const std::string &name, bool check_timeout, size_t time_window, bool check_count, size_t threshold_count, bool server_num_as_threshold) : name_(name),
@@ -42,9 +44,9 @@ void Round::Initialize(const std::shared_ptr<core::CommunicatorBase> &communicator, const TimeOutCb &timeout_cb, const FinishIterCb &finish_iteration_cb) { name_, [&](std::shared_ptr<core::MessageHandler> message) { LaunchRoundKernel(message); }); // Callback when the iteration is finished. - finish_iteration_cb_ = [this, finish_iteration_cb](bool is_iteration_valid) -> void { - MS_LOG(INFO) << "Round " << name_ << " finished! This iteration is valid. Proceed to next iteration."; - finish_iteration_cb(is_iteration_valid); + finish_iteration_cb_ = [this, finish_iteration_cb](bool is_iteration_valid, const std::string &) -> void { + std::string reason = "Round " + name_ + " finished! This iteration is valid. Proceed to next iteration."; + finish_iteration_cb(is_iteration_valid, reason); }; // Callback for finalizing the server. This can only be called once.
@@ -54,9 +56,9 @@ void Round::Initialize(const std::shared_ptr<core::CommunicatorBase> &communicator, const TimeOutCb &timeout_cb, const FinishIterCb &finish_iteration_cb) { iter_timer_ = std::make_shared<IterationTimer>(); // 1.Set the timeout callback for the timer. - iter_timer_->SetTimeOutCallBack([this, timeout_cb](bool is_iteration_valid) -> void { - MS_LOG(INFO) << "Round " << name_ << " timeout! This iteration is invalid. Proceed to next iteration."; - timeout_cb(is_iteration_valid); + iter_timer_->SetTimeOutCallBack([this, timeout_cb](bool is_iteration_valid, const std::string &) -> void { + std::string reason = "Round " + name_ + " timeout! This iteration is invalid. Proceed to next iteration.";
+ timeout_cb(is_iteration_valid, reason); }); // 2.Stopping timer callback which will be set to the round kernel.
@@ -89,7 +91,7 @@ bool Round::ReInitForScaling(uint32_t server_num) { } if (kernel_ == nullptr) { - MS_LOG(ERROR) << "Reinitializing for round " << name_ << " failed: round kernel is nullptr."; + MS_LOG(WARNING) << "Reinitializing for round " << name_ << " failed: round kernel is nullptr."; return false; } kernel_->InitKernel(threshold_count_);
@@ -129,13 +131,14 @@ void Round::LaunchRoundKernel(const std::shared_ptr<core::MessageHandler> &message) { communicator_->SendResponse(reason.c_str(), reason.size(), message); return; } + communicator_->SendResponse(output->addr, output->size, message); + kernel_->Release(output); // Must send response back no matter what value Launch method returns. if (!ret) { - MS_LOG(WARNING) << "Launching round kernel of round " << name_ << " failed."; + std::string reason = "Launching round kernel of round " + name_ + " failed."; + Iteration::GetInstance().MoveToNextIteration(false, reason); } - communicator_->SendResponse(output->addr, output->size, message); - kernel_->Release(output); return; }
diff --git a/mindspore/ccsrc/ps/server/server.cc b/mindspore/ccsrc/ps/server/server.cc index 79a6c785475..10179281cf8 100644 --- a/mindspore/ccsrc/ps/server/server.cc +++ b/mindspore/ccsrc/ps/server/server.cc
@@ -73,6 +73,7 @@ // InitCipher---->InitExecutor void Server::Run() { signal(SIGINT, SignalHandler); + std::unique_lock<std::mutex> lock(scaling_mtx_); InitServerContext(); InitCluster(); InitIteration();
@@ -82,6 +83,7 @@ RegisterRoundKernel(); MS_LOG(INFO) << "Server started successfully."; safemode_ = false; + lock.unlock(); // Wait communicators to stop so the main thread is blocked. std::for_each(communicators_with_worker_.begin(), communicators_with_worker_.end(), [](const std::shared_ptr<core::CommunicatorBase> &communicator) { communicator->Join(); });
@@ -91,6 +93,16 @@ return; } +void Server::SwitchToSafeMode() { + MS_LOG(INFO) << "Server switch to safemode."; + safemode_ = true; +} + +void Server::CancelSafeMode() { + MS_LOG(INFO) << "Server cancel safemode."; + safemode_ = false; +} + bool Server::IsSafeMode() { return safemode_.load(); } void Server::InitServerContext() {
@@ -166,8 +178,10 @@ void Server::InitIteration() { } // 2.Initialize all the rounds.
- TimeOutCb time_out_cb = std::bind(&Iteration::ProceedToNextIter, iteration_, std::placeholders::_1); - FinishIterCb finish_iter_cb = std::bind(&Iteration::ProceedToNextIter, iteration_, std::placeholders::_1); + TimeOutCb time_out_cb = + std::bind(&Iteration::MoveToNextIteration, iteration_, std::placeholders::_1, std::placeholders::_2); + FinishIterCb finish_iter_cb = + std::bind(&Iteration::MoveToNextIteration, iteration_, std::placeholders::_1, std::placeholders::_2); iteration_->InitRounds(communicators_with_worker_, time_out_cb, finish_iter_cb); return; }
@@ -288,28 +302,29 @@ void Server::ProcessBeforeScalingIn() { } void Server::ProcessAfterScalingOut() { + std::unique_lock<std::mutex> lock(scaling_mtx_); if (server_node_ == nullptr) { return; } if (!DistributedMetadataStore::GetInstance().ReInitForScaling()) { - MS_LOG(ERROR) << "DistributedMetadataStore reinitializing failed."; + MS_LOG(WARNING) << "DistributedMetadataStore reinitializing failed."; return; } if (!CollectiveOpsImpl::GetInstance().ReInitForScaling()) { - MS_LOG(ERROR) << "DistributedMetadataStore reinitializing failed."; + MS_LOG(WARNING) << "DistributedMetadataStore reinitializing failed."; return; } if (!DistributedCountService::GetInstance().ReInitForScaling()) { - MS_LOG(ERROR) << "DistributedCountService reinitializing failed."; + MS_LOG(WARNING) << "DistributedCountService reinitializing failed."; return; } if (!iteration_->ReInitForScaling(IntToUint(server_node_->server_num()), server_node_->rank_id())) { - MS_LOG(ERROR) << "Iteration reinitializing failed."; + MS_LOG(WARNING) << "Iteration reinitializing failed."; return; } if (!Executor::GetInstance().ReInitForScaling()) { - MS_LOG(ERROR) << "Executor reinitializing failed."; + MS_LOG(WARNING) << "Executor reinitializing failed."; return; } std::this_thread::sleep_for(std::chrono::milliseconds(1000));
@@ -317,6 +332,7 @@ void Server::ProcessAfterScalingIn() { + std::unique_lock<std::mutex> lock(scaling_mtx_); if (server_node_ == nullptr) { return; }
@@ -331,23 +347,23 @@ // If the server is not the one to be scaled in, reinitialize modules and recover service.
if (!DistributedMetadataStore::GetInstance().ReInitForScaling()) { - MS_LOG(ERROR) << "DistributedMetadataStore reinitializing failed."; + MS_LOG(WARNING) << "DistributedMetadataStore reinitializing failed."; return; } if (!CollectiveOpsImpl::GetInstance().ReInitForScaling()) { - MS_LOG(ERROR) << "DistributedMetadataStore reinitializing failed."; + MS_LOG(WARNING) << "DistributedMetadataStore reinitializing failed."; return; } if (!DistributedCountService::GetInstance().ReInitForScaling()) { - MS_LOG(ERROR) << "DistributedCountService reinitializing failed."; + MS_LOG(WARNING) << "DistributedCountService reinitializing failed."; return; } if (!iteration_->ReInitForScaling(IntToUint(server_node_->server_num()), server_node_->rank_id())) { - MS_LOG(ERROR) << "Iteration reinitializing failed."; + MS_LOG(WARNING) << "Iteration reinitializing failed."; return; } if (!Executor::GetInstance().ReInitForScaling()) { - MS_LOG(ERROR) << "Executor reinitializing failed."; + MS_LOG(WARNING) << "Executor reinitializing failed."; return; } std::this_thread::sleep_for(std::chrono::milliseconds(1000));
diff --git a/mindspore/ccsrc/ps/server/server.h b/mindspore/ccsrc/ps/server/server.h index 005f7f8c623..c6708868888 100644 --- a/mindspore/ccsrc/ps/server/server.h +++ b/mindspore/ccsrc/ps/server/server.h
@@ -46,6 +46,8 @@ class Server { // func_graph is the frontend graph which will be parse in server's exector and aggregator. void Run(); + void SwitchToSafeMode(); + void CancelSafeMode(); bool IsSafeMode(); private:
@@ -134,6 +136,9 @@ class Server { // communicators. std::vector<std::shared_ptr<core::CommunicatorBase>> communicators_with_worker_; + // Mutex for scaling operations. We must wait until the server's initialization is done before handling scaling events. + std::mutex scaling_mtx_; + // Iteration consists of multiple kinds of rounds. Iteration *iteration_;
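scaling_mtx_ above serializes initialization against elastic-scaling events: Server::Run holds the lock until the server is fully initialized, so ProcessAfterScalingOut/ProcessAfterScalingIn block instead of reinitializing half-built modules. A minimal sketch of that handshake with a simplified stand-in class; the sleeps only make the interleaving likely in this toy and are not part of the real code.

```cpp
#include <chrono>
#include <iostream>
#include <mutex>
#include <thread>

class MiniServer {
 public:
  void Run() {
    std::unique_lock<std::mutex> lock(scaling_mtx_);
    // Stand-ins for InitCluster(), InitIteration(), ...
    std::this_thread::sleep_for(std::chrono::milliseconds(100));
    initialized_ = true;
    lock.unlock();  // Only now may scaling handlers proceed.
  }

  void ProcessAfterScalingOut() {
    std::unique_lock<std::mutex> lock(scaling_mtx_);
    std::cout << "Scaling handled, initialized_ = " << initialized_ << "\n";
  }

 private:
  std::mutex scaling_mtx_;
  bool initialized_ = false;
};

int main() {
  MiniServer server;
  std::thread init_thread([&] { server.Run(); });
  // Give Run() a head start so the scaling event arrives mid-initialization.
  std::this_thread::sleep_for(std::chrono::milliseconds(10));
  server.ProcessAfterScalingOut();  // Blocks on scaling_mtx_ until Run() unlocks.
  init_thread.join();
  return 0;
}
```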
diff --git a/mindspore/ccsrc/ps/worker/fl_worker.cc b/mindspore/ccsrc/ps/worker/fl_worker.cc index d9a076edb6d..70566162f36 100644 --- a/mindspore/ccsrc/ps/worker/fl_worker.cc +++ b/mindspore/ccsrc/ps/worker/fl_worker.cc
@@ -67,12 +67,23 @@ bool FLWorker::SendToServer(uint32_t server_rank, void *data, size_t size, core::TcpUserCommand command, std::shared_ptr<std::vector<unsigned char>> *output) { } if (output != nullptr) { - do { + while (true) { if (!worker_node_->Send(core::NodeRole::SERVER, server_rank, message, size, static_cast<int>(command), output)) { MS_LOG(ERROR) << "Sending message to server " << server_rank << " failed."; return false; } - } while (std::string(reinterpret_cast<char *>((*output)->data()), (*output)->size()) == kClusterSafeMode); + if (*output == nullptr) { + MS_LOG(WARNING) << "Response from server " << server_rank << " is empty."; + return false; + } + + if (std::string(reinterpret_cast<char *>((*output)->data()), (*output)->size()) == kClusterSafeMode) { + MS_LOG(INFO) << "The server " << server_rank << " is in safemode."; + std::this_thread::sleep_for(std::chrono::milliseconds(kWorkerRetryDurationForSafeMode)); + } else { + break; + } + } } else { if (!worker_node_->Send(core::NodeRole::SERVER, server_rank, message, size, static_cast<int>(command))) { MS_LOG(ERROR) << "Sending message to server " << server_rank << " failed.";
@@ -88,6 +99,16 @@ uint32_t FLWorker::worker_num() const { return worker_num_; } uint64_t FLWorker::worker_step_num_per_iteration() const { return worker_step_num_per_iteration_; } +void FLWorker::SetIterationRunning() { + MS_LOG(INFO) << "Worker iteration starts."; + worker_iteration_state_ = IterationState::kRunning; +} + +void FLWorker::SetIterationCompleted() { + MS_LOG(INFO) << "Worker iteration completes."; + worker_iteration_state_ = IterationState::kCompleted; +} + void FLWorker::InitializeFollowerScaler() { if (!worker_node_->InitFollowerScaler()) { MS_LOG(EXCEPTION) << "Initializing follower elastic scaler failed."; }
@@ -112,21 +133,22 @@ void FLWorker::InitializeFollowerScaler() { } void FLWorker::HandleIterationRunningEvent() { - MS_LOG(INFO) << "Worker iteration starts, safemode is " << safemode_.load(); - iteration_state_ = IterationState::kRunning; + MS_LOG(INFO) << "Server iteration starts, safemode is " << safemode_.load(); + server_iteration_state_ = IterationState::kRunning; if (safemode_.load() == true) { safemode_ = false; } } void FLWorker::HandleIterationCompletedEvent() { - MS_LOG(INFO) << "Worker iteration completes"; - iteration_state_ = IterationState::kCompleted; + MS_LOG(INFO) << "Server iteration completes"; + server_iteration_state_ = IterationState::kCompleted; } void FLWorker::ProcessBeforeScalingOut() { MS_LOG(INFO) << "Starting Worker scaling out barrier."; - while (iteration_state_.load() != IterationState::kCompleted) { + while (server_iteration_state_.load() != IterationState::kCompleted || + worker_iteration_state_.load() != IterationState::kCompleted) { std::this_thread::yield(); } MS_LOG(INFO) << "Ending Worker scaling out barrier. Switch to safemode.";
@@ -135,7 +157,8 @@ void FLWorker::ProcessBeforeScalingIn() { MS_LOG(INFO) << "Starting Worker scaling in barrier."; - while (iteration_state_.load() != IterationState::kCompleted) { + while (server_iteration_state_.load() != IterationState::kCompleted || + worker_iteration_state_.load() != IterationState::kCompleted) { std::this_thread::yield(); } MS_LOG(INFO) << "Ending Worker scaling in barrier. Switch to safemode.";
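The `while (true)` loop added to SendToServer above turns the old busy do/while into a polite retry: while the server answers with the safemode marker, the worker sleeps kWorkerRetryDurationForSafeMode milliseconds and tries again, giving up only on a real send failure. A self-contained sketch of that retry shape; TrySend and the constants are illustrative stand-ins.

```cpp
#include <chrono>
#include <cstdint>
#include <iostream>
#include <string>
#include <thread>

constexpr uint32_t kRetryDurationMs = 500;  // cf. kWorkerRetryDurationForSafeMode.
const std::string kClusterSafeMode = "The cluster is in safemode.";

int attempts = 0;
// Stand-in for worker_node_->Send: the server answers with the safemode
// marker for the first two attempts, then with a real response.
bool TrySend(std::string *response) {
  *response = (++attempts < 3) ? kClusterSafeMode : "ok";
  return true;
}

int main() {
  std::string response;
  while (true) {
    if (!TrySend(&response)) {
      std::cout << "Sending message failed.\n";
      return 1;  // A transport failure ends the loop immediately.
    }
    if (response == kClusterSafeMode) {
      std::cout << "The server is in safemode.\n";
      std::this_thread::sleep_for(std::chrono::milliseconds(kRetryDurationMs));
    } else {
      break;  // A real response ends the retry loop.
    }
  }
  std::cout << "Response: " << response << "\n";
  return 0;
}
```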
@@ -148,9 +171,6 @@ void FLWorker::ProcessAfterScalingOut() { } MS_LOG(INFO) << "Cluster scaling out completed. Reinitialize for worker."; - while (iteration_state_.load() != IterationState::kCompleted) { - std::this_thread::yield(); - } server_num_ = worker_node_->server_num(); worker_num_ = worker_node_->worker_num(); MS_LOG(INFO) << "After scheduler scaling out, worker number is " << worker_num_ << ", server number is "
@@ -165,9 +185,6 @@ void FLWorker::ProcessAfterScalingIn() { } MS_LOG(INFO) << "Cluster scaling in completed. Reinitialize for worker."; - while (iteration_state_.load() != IterationState::kCompleted) { - std::this_thread::yield(); - } server_num_ = worker_node_->server_num(); worker_num_ = worker_node_->worker_num(); MS_LOG(INFO) << "After scheduler scaling in, worker number is " << worker_num_ << ", server number is " << server_num_
diff --git a/mindspore/ccsrc/ps/worker/fl_worker.h b/mindspore/ccsrc/ps/worker/fl_worker.h index 17bcf02c352..44890f87d1a 100644 --- a/mindspore/ccsrc/ps/worker/fl_worker.h +++ b/mindspore/ccsrc/ps/worker/fl_worker.h
@@ -40,6 +40,9 @@ constexpr uint32_t kTrainEndStepNum = 0; // The worker has to sleep for a while before the networking is completed. constexpr uint32_t kWorkerSleepTimeForNetworking = 1000; +// The time duration between retries when the server is in safemode. +constexpr uint32_t kWorkerRetryDurationForSafeMode = 500; + enum class IterationState { // This iteration is still in process. kRunning,
@@ -64,6 +67,10 @@ class FLWorker { uint32_t worker_num() const; uint64_t worker_step_num_per_iteration() const; + // These methods set the worker's iteration state. + void SetIterationRunning(); + void SetIterationCompleted(); + private: FLWorker() : server_num_(0),
@@ -72,7 +79,8 @@ class FLWorker { scheduler_port_(0), worker_node_(nullptr), worker_step_num_per_iteration_(1), - iteration_state_(IterationState::kCompleted), + server_iteration_state_(IterationState::kCompleted), + worker_iteration_state_(IterationState::kCompleted), safemode_(false) {} ~FLWorker() = default; FLWorker(const FLWorker &) = delete;
@@ -104,9 +112,14 @@ class FLWorker { uint64_t worker_step_num_per_iteration_; // The iteration state is either running or completed. - std::atomic<IterationState> iteration_state_; + // This variable represents the server iteration state and should be changed by the events + // kIterationRunning/kIterationCompleted triggered by the server. + std::atomic<IterationState> server_iteration_state_; - // The flag that represents whether worker is in safemode. + // This variable represents the worker iteration state and should be changed by the worker training process. + std::atomic<IterationState> worker_iteration_state_; + + // The flag that represents whether the worker is in safemode, which is decided by both the worker and server iteration states. std::atomic_bool safemode_; }; } // namespace worker
diff --git a/mindspore/context.py b/mindspore/context.py index 369e09d074f..ebada9f8396 100644 --- a/mindspore/context.py +++ b/mindspore/context.py
@@ -828,18 +828,19 @@ def set_fl_context(**kwargs): Default: 'FEDERATED_LEARNING'. ms_role (string): The process's role in the federated learning mode, which must be one of 'MS_SERVER', 'MS_WORKER' and 'MS_SCHED'. - Default: 'MS_NOT_PS'. - worker_num (int): The number of workers. Default: 0. + Default: 'MS_SERVER'. + worker_num (int): The number of workers. In the current version, this must be set to 0 or 1. server_num (int): The number of federated learning servers. Default: 0. - scheduler_ip (string): The scheduler IP. Default: ''.
- scheduler_port (int): The scheduler port. Default: 0. + scheduler_ip (string): The scheduler IP. Default: '0.0.0.0'. + scheduler_port (int): The scheduler port. Default: 6667. fl_server_port (int): The http port of the federated learning server. - Normally for each server this should be set to the same value. Default: 0. + Normally for each server this should be set to the same value. Default: 6668. enable_fl_client (bool): Whether this process is federated learning client. Default: False. start_fl_job_threshold (int): The threshold count of startFLJob. Default: 0. start_fl_job_time_window (int): The time window duration for startFLJob in millisecond. Default: 3000. update_model_ratio (float): The ratio for computing the threshold count of updateModel - which will be multiplied by start_fl_job_threshold. Default: 1.0. + which will be multiplied by start_fl_job_threshold. + Must be between 0 and 1.0. Default: 1.0. update_model_time_window (int): The time window duration for updateModel in millisecond. Default: 3000. fl_name (string): The federated learning job name. Default: ''. fl_iteration_num (int): Iteration number of federeated learning,
diff --git a/mindspore/parallel/_ps_context.py b/mindspore/parallel/_ps_context.py index 58bb611c688..a590f057091 100644 --- a/mindspore/parallel/_ps_context.py +++ b/mindspore/parallel/_ps_context.py
@@ -15,10 +15,20 @@ """Context for parameter server training mode""" import os +from mindspore._checkparam import Validator from mindspore._c_expression import PSContext _ps_context = None +_check_positive_int_keys = ["server_num", "scheduler_port", "fl_server_port", + "start_fl_job_threshold", "start_fl_job_time_window", "update_model_time_window", + "fl_iteration_num", "client_epoch_num", "client_batch_size", "scheduler_manage_port"] + +_check_non_negative_int_keys = ["worker_num"] + +_check_positive_float_keys = ["update_model_ratio", "client_learning_rate"] + +_check_port_keys = ["scheduler_port", "fl_server_port", "scheduler_manage_port"] def ps_context(): """
@@ -181,3 +191,20 @@ def _set_rank_id(rank_id): ps_context().set_rank_id(rank_id) + +def _check_value(key, value): + """ + Validate the value for parameter server context keys. + """ + if key in _check_positive_int_keys: + Validator.check_positive_int(value, key) + + if key in _check_non_negative_int_keys: + Validator.check_non_negative_int(value, key) + + if key in _check_positive_float_keys: + Validator.check_positive_float(value, key) + + if key in _check_port_keys: + if value < 1 or value > 65535: + raise ValueError("The range of %s must be 1 to 65535, but got %d."
% (key, value))
diff --git a/tests/st/fl/mobile/simulator.py b/tests/st/fl/mobile/simulator.py index 6ab9c32e248..f36fa1ebe1a 100644 --- a/tests/st/fl/mobile/simulator.py +++ b/tests/st/fl/mobile/simulator.py
@@ -163,6 +163,9 @@ while True: rsp_fl_job = ResponseFLJob.ResponseFLJob.GetRootAsResponseFLJob(x.content, 0) while rsp_fl_job.Retcode() != ResponseCode.ResponseCode.SUCCEED: x = session.post(url1, data=build_start_fl_job(current_iteration)) + while x.text == "The cluster is in safemode.": + time.sleep(0.2) + x = session.post(url1, data=build_start_fl_job(current_iteration)) rsp_fl_job = ResponseFLJob.ResponseFLJob.GetRootAsResponseFLJob(x.content, 0) print("epoch is", rsp_fl_job.FlPlanConfig().Epochs()) print("iteration is", rsp_fl_job.Iteration())
@@ -173,6 +176,10 @@ while True: print("req update model iteration:", current_iteration, ", id:", args.pid) update_model_buf, update_model_np_data = build_update_model(current_iteration) x = session.post(url2, data=update_model_buf) + while x.text == "The cluster is in safemode.": + time.sleep(0.2) + x = session.post(url2, data=update_model_buf) + print("rsp update model iteration:", current_iteration, ", id:", args.pid) sys.stdout.flush()
@@ -227,4 +234,5 @@ while True: # Sleep to the next request timestamp current_ts = datetime_to_timestamp(datetime.datetime.now()) duration = next_req_timestamp - current_ts - time.sleep(duration / 1000) + if duration > 0: + time.sleep(duration / 1000)