forked from mindspore-Ecosystem/mindspore
!18594 Sync from enterprise
Merge pull request !18594 from ZPaC/sync-from-enter
Commit: edc1c8bf58
@@ -107,6 +107,7 @@ class FusedPullWeightKernel : public CPUKernel {
}
}
MS_LOG(INFO) << "Pull weights for " << weight_full_names_ << " succeed. Iteration: " << fl_iteration_;
ps::worker::FLWorker::GetInstance().SetIterationRunning();
return true;
}

@@ -68,8 +68,8 @@ class FusedPushWeightKernel : public CPUKernel {
std::shared_ptr<std::vector<unsigned char>> push_weight_rsp_msg = nullptr;
if (!ps::worker::FLWorker::GetInstance().SendToServer(
i, fbb->GetBufferPointer(), fbb->GetSize(), ps::core::TcpUserCommand::kPushWeight, &push_weight_rsp_msg)) {
MS_LOG(EXCEPTION) << "Sending request for FusedPushWeight to server " << i << " failed.";
return false;
MS_LOG(ERROR) << "Sending request for FusedPushWeight to server " << i << " failed.";
continue;
}
MS_EXCEPTION_IF_NULL(push_weight_rsp_msg);

@@ -83,6 +83,7 @@ class FusedPushWeightKernel : public CPUKernel {
}
}
MS_LOG(INFO) << "Push weights for " << weight_full_names_ << " succeed. Iteration: " << fl_iteration_;
ps::worker::FLWorker::GetInstance().SetIterationCompleted();
return true;
}

@@ -47,7 +47,11 @@ enum class TcpUserCommand {
kCounterEvent,
kPullWeight,
kPushWeight,
kSyncIteration
kSyncIteration,
kNotifyLeaderToNextIter,
kPrepareForNextIter,
kProceedToNextIter,
kEndLastIter
};

const std::unordered_map<TcpUserCommand, std::string> kUserCommandToMsgType = {

@@ -61,7 +65,11 @@ const std::unordered_map<TcpUserCommand, std::string> kUserCommandToMsgType = {
{TcpUserCommand::kCounterEvent, "counterEvent"},
{TcpUserCommand::kPullWeight, "pullWeight"},
{TcpUserCommand::kPushWeight, "pushWeight"},
{TcpUserCommand::kSyncIteration, "syncIteration"}};
{TcpUserCommand::kSyncIteration, "syncIteration"},
{TcpUserCommand::kNotifyLeaderToNextIter, "notifyLeaderToNextIter"},
{TcpUserCommand::kPrepareForNextIter, "prepareForNextIter"},
{TcpUserCommand::kProceedToNextIter, "proceedToNextIter"},
{TcpUserCommand::kEndLastIter, "endLastIter"}};

class TcpCommunicator : public CommunicatorBase {
public:

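The new commands above are resolved to wire-level message type strings through kUserCommandToMsgType before a handler is looked up or a request is sent. The following is a minimal, self-contained sketch of that lookup-and-dispatch pattern using only the standard library; the enum values and strings mirror this diff, but the Dispatcher type and handler names are illustrative, not the actual TcpCommunicator API.

#include <functional>
#include <iostream>
#include <string>
#include <unordered_map>

enum class TcpUserCommand { kSyncIteration, kNotifyLeaderToNextIter, kPrepareForNextIter, kProceedToNextIter, kEndLastIter };

// Command -> message type string, mirroring kUserCommandToMsgType in this diff.
const std::unordered_map<TcpUserCommand, std::string> kCommandToMsgType = {
  {TcpUserCommand::kSyncIteration, "syncIteration"},
  {TcpUserCommand::kNotifyLeaderToNextIter, "notifyLeaderToNextIter"},
  {TcpUserCommand::kPrepareForNextIter, "prepareForNextIter"},
  {TcpUserCommand::kProceedToNextIter, "proceedToNextIter"},
  {TcpUserCommand::kEndLastIter, "endLastIter"}};

// Hypothetical dispatcher: handlers are registered by message type string and
// an incoming command is routed through the same map.
class Dispatcher {
 public:
  void Register(const std::string &msg_type, std::function<void()> handler) { handlers_[msg_type] = std::move(handler); }
  void Dispatch(TcpUserCommand cmd) {
    const std::string &type = kCommandToMsgType.at(cmd);
    auto it = handlers_.find(type);
    if (it != handlers_.end()) {
      it->second();
    } else {
      std::cout << "No handler registered for " << type << "\n";
    }
  }

 private:
  std::unordered_map<std::string, std::function<void()>> handlers_;
};

int main() {
  Dispatcher dispatcher;
  dispatcher.Register("prepareForNextIter", [] { std::cout << "prepare for next iteration\n"; });
  dispatcher.Dispatch(TcpUserCommand::kPrepareForNextIter);  // runs the registered handler
  return 0;
}
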
@@ -163,3 +163,41 @@ message SyncIterationResponse {
// The current iteration number.
uint64 iteration = 1;
}

message PrepareForNextIterRequest {
bool is_last_iter_valid = 1;
string reason = 2;
}

message PrepareForNextIterResponse {
string result = 1;
}

message NotifyLeaderMoveToNextIterRequest {
uint32 rank = 1;
bool is_last_iter_valid = 2;
uint64 iter_num = 3;
string reason = 4;
}

message NotifyLeaderMoveToNextIterResponse {
string result = 1;
}

message MoveToNextIterRequest {
bool is_last_iter_valid = 1;
uint64 last_iter_num = 2;
string reason = 3;
}

message MoveToNextIterResponse {
string result = 1;
}

message EndLastIterRequest {
uint64 last_iter_num = 1;
}

message EndLastIterResponse {
string result = 1;
}

@@ -45,7 +45,7 @@ void PSContext::SetPSEnable(bool enabled) {
} else if (ms_role == kEnvRoleOfScheduler) {
is_sched_ = true;
} else {
MS_LOG(WARNING) << "MS_ROLE is " << ms_role << ", which is invalid.";
MS_LOG(INFO) << "MS_ROLE is " << ms_role;
}

worker_num_ = std::strtol(common::GetEnv(kEnvWorkerNum).c_str(), nullptr, 10);

@@ -273,7 +273,13 @@ void PSContext::set_start_fl_job_time_window(uint64_t start_fl_job_time_window)

uint64_t PSContext::start_fl_job_time_window() const { return start_fl_job_time_window_; }

void PSContext::set_update_model_ratio(float update_model_ratio) { update_model_ratio_ = update_model_ratio; }
void PSContext::set_update_model_ratio(float update_model_ratio) {
if (update_model_ratio > 1.0) {
MS_LOG(EXCEPTION) << "update_model_ratio must be between 0 and 1.";
return;
}
update_model_ratio_ = update_model_ratio;
}

float PSContext::update_model_ratio() const { return update_model_ratio_; }

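The new setter rejects ratios above 1.0; MS_LOG(EXCEPTION) raises an exception in MindSpore, so the return that follows is only a defensive fallback. A minimal standalone sketch of the same upper-bound guard, using a standard exception instead of the MindSpore logging macro (the FLConfig class is illustrative):

#include <stdexcept>

// Only the upper bound is checked here, mirroring the C++ setter in this diff.
// Positivity is validated separately on the Python side, where
// _check_positive_float_keys includes update_model_ratio.
class FLConfig {
 public:
  void set_update_model_ratio(float ratio) {
    if (ratio > 1.0f) {
      throw std::invalid_argument("update_model_ratio must be between 0 and 1.");
    }
    update_model_ratio_ = ratio;
  }
  float update_model_ratio() const { return update_model_ratio_; }

 private:
  float update_model_ratio_ = 1.0f;
};

int main() {
  FLConfig config;
  config.set_update_model_ratio(0.8f);  // accepted; 1.5f would throw
  return config.update_model_ratio() > 0.0f ? 0 : 1;
}
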
@@ -161,12 +161,12 @@ class PSContext {
rank_id_(0),
worker_num_(0),
server_num_(0),
scheduler_host_(""),
scheduler_port_(0),
scheduler_host_("0.0.0.0"),
scheduler_port_(6667),
role_(kEnvRoleOfNotPS),
server_mode_(""),
resetter_round_(ResetterRound::kNoNeedToReset),
fl_server_port_(0),
fl_server_port_(6668),
fl_client_enable_(false),
fl_name_(""),
start_fl_job_threshold_(0),

@@ -179,7 +179,7 @@ class PSContext {
client_learning_rate_(0.001),
secure_aggregation_(false),
cluster_config_(nullptr),
scheduler_manage_port_(0),
scheduler_manage_port_(11202),
config_file_path_("") {}
bool ps_enabled_;
bool is_worker_;

@@ -63,9 +63,9 @@ using mindspore::kernel::Address;
using mindspore::kernel::AddressPtr;
using mindspore::kernel::CPUKernel;
using FBBuilder = flatbuffers::FlatBufferBuilder;
using TimeOutCb = std::function<void(bool)>;
using TimeOutCb = std::function<void(bool, const std::string &)>;
using StopTimerCb = std::function<void(void)>;
using FinishIterCb = std::function<void(bool)>;
using FinishIterCb = std::function<void(bool, const std::string &)>;
using FinalizeCb = std::function<void(void)>;
using MessageCallback = std::function<void(const std::shared_ptr<core::MessageHandler> &)>;

@@ -83,7 +83,10 @@ bool DistributedCountService::Count(const std::string &name, const std::string &

MS_LOG(INFO) << "Leader server increase count for " << name << " of " << id;
global_current_count_[name].insert(id);
TriggerCounterEvent(name);
if (!TriggerCounterEvent(name)) {
MS_LOG(ERROR) << "Leader server trigger count event failed.";
return false;
}
} else {
// If this server is a follower server, it needs to send CountRequest to the leader server.
CountRequest report_count_req;

@@ -198,9 +201,14 @@ void DistributedCountService::HandleCountRequest(const std::shared_ptr<core::Mes
// Insert the id for the counter, which means the count for the name is increased.
MS_LOG(INFO) << "Leader server increase count for " << name << " of " << id;
global_current_count_[name].insert(id);
TriggerCounterEvent(name);
count_rsp.set_result(true);
count_rsp.set_reason("success");
if (!TriggerCounterEvent(name)) {
std::string reason = "Trigger count event for " + name + " of " + id + " failed.";
count_rsp.set_result(false);
count_rsp.set_reason(reason);
} else {
count_rsp.set_result(true);
count_rsp.set_reason("success");
}
communicator_->SendResponse(count_rsp.SerializeAsString().data(), count_rsp.SerializeAsString().size(), message);
return;
}

@@ -256,20 +264,24 @@ void DistributedCountService::HandleCounterEvent(const std::shared_ptr<core::Mes
return;
}

void DistributedCountService::TriggerCounterEvent(const std::string &name) {
bool DistributedCountService::TriggerCounterEvent(const std::string &name) {
MS_LOG(INFO) << "Current count for " << name << " is " << global_current_count_[name].size()
<< ", threshold count is " << global_threshold_count_[name];
// The threshold count may be 1, so the first and last count events should both be activated.
if (global_current_count_[name].size() == 1) {
TriggerFirstCountEvent(name);
if (!TriggerFirstCountEvent(name)) {
return false;
}
}
if (global_current_count_[name].size() == global_threshold_count_[name]) {
TriggerLastCountEvent(name);
if (!TriggerLastCountEvent(name)) {
return false;
}
}
return;
return true;
}

void DistributedCountService::TriggerFirstCountEvent(const std::string &name) {
bool DistributedCountService::TriggerFirstCountEvent(const std::string &name) {
MS_LOG(DEBUG) << "Activating first count event for " << name;
CounterEvent first_count_event;
first_count_event.set_type(CounterEventType::FIRST_CNT);

@@ -279,15 +291,15 @@ void DistributedCountService::TriggerFirstCountEvent(const std::string &name) {
for (uint32_t i = 1; i < server_num_; i++) {
if (!communicator_->SendPbRequest(first_count_event, i, core::TcpUserCommand::kCounterEvent)) {
MS_LOG(ERROR) << "Activating first count event to server " << i << " failed.";
return;
return false;
}
}
// Leader server directly calls the callback.
counter_handlers_[name].first_count_handler(nullptr);
return;
return true;
}

void DistributedCountService::TriggerLastCountEvent(const std::string &name) {
bool DistributedCountService::TriggerLastCountEvent(const std::string &name) {
MS_LOG(INFO) << "Activating last count event for " << name;
CounterEvent last_count_event;
last_count_event.set_type(CounterEventType::LAST_CNT);

@@ -297,12 +309,12 @@ void DistributedCountService::TriggerLastCountEvent(const std::string &name) {
for (uint32_t i = 1; i < server_num_; i++) {
if (!communicator_->SendPbRequest(last_count_event, i, core::TcpUserCommand::kCounterEvent)) {
MS_LOG(ERROR) << "Activating last count event to server " << i << " failed.";
return;
return false;
}
}
// Leader server directly calls the callback.
counter_handlers_[name].last_count_handler(nullptr);
return;
return true;
}
} // namespace server
} // namespace ps

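The refactor above turns the Trigger*CountEvent helpers from void into bool so a failed broadcast propagates to the caller, which then reports the failure in its count response instead of silently continuing. A compact, self-contained model of that error-propagation pattern (standard library only; the Trigger stubs and the response struct are illustrative, not the MindSpore classes):

#include <cstddef>
#include <iostream>
#include <string>

struct CountResponse {
  bool result = false;
  std::string reason;
};

// Stubs standing in for the broadcasts to follower servers; a real failure
// would come from a TCP send that did not succeed.
bool TriggerFirst(const std::string &name) { std::cout << "first count event for " << name << "\n"; return true; }
bool TriggerLast(const std::string &name) { std::cout << "last count event for " << name << "\n"; return true; }

bool TriggerCounterEvent(const std::string &name, size_t current, size_t threshold) {
  // The threshold may be 1, so both events can fire for the same count.
  if (current == 1 && !TriggerFirst(name)) return false;
  if (current == threshold && !TriggerLast(name)) return false;
  return true;
}

CountResponse HandleCount(const std::string &name, size_t current, size_t threshold) {
  CountResponse rsp;
  if (!TriggerCounterEvent(name, current, threshold)) {
    rsp.result = false;
    rsp.reason = "Trigger count event for " + name + " failed.";
  } else {
    rsp.result = true;
    rsp.reason = "success";
  }
  return rsp;
}

int main() {
  CountResponse rsp = HandleCount("updateModel", 1, 1);
  std::cout << rsp.reason << "\n";
  return 0;
}
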
@@ -98,9 +98,9 @@ class DistributedCountService {
void HandleCounterEvent(const std::shared_ptr<core::MessageHandler> &message);

// Call the callbacks when the first/last count event is triggered.
void TriggerCounterEvent(const std::string &name);
void TriggerFirstCountEvent(const std::string &name);
void TriggerLastCountEvent(const std::string &name);
bool TriggerCounterEvent(const std::string &name);
bool TriggerFirstCountEvent(const std::string &name);
bool TriggerLastCountEvent(const std::string &name);

// Members for the communication between counting server and other servers.
std::shared_ptr<core::ServerNode> server_node_;

@@ -20,15 +20,26 @@
#include <vector>
#include <numeric>
#include "ps/server/model_store.h"
#include "ps/server/server.h"

namespace mindspore {
namespace ps {
namespace server {
class Server;
void Iteration::RegisterMessageCallback(const std::shared_ptr<core::TcpCommunicator> &communicator) {
MS_EXCEPTION_IF_NULL(communicator);
communicator_ = communicator;
communicator_->RegisterMsgCallBack("syncIteration",
std::bind(&Iteration::HandleSyncIterationRequest, this, std::placeholders::_1));
communicator_->RegisterMsgCallBack(
"notifyLeaderToNextIter",
std::bind(&Iteration::HandleNotifyLeaderMoveToNextIterRequest, this, std::placeholders::_1));
communicator_->RegisterMsgCallBack(
"prepareForNextIter", std::bind(&Iteration::HandlePrepareForNextIterRequest, this, std::placeholders::_1));
communicator_->RegisterMsgCallBack("proceedToNextIter",
std::bind(&Iteration::HandleMoveToNextIterRequest, this, std::placeholders::_1));
communicator_->RegisterMsgCallBack("endLastIter",
std::bind(&Iteration::HandleEndLastIterRequest, this, std::placeholders::_1));
}

void Iteration::RegisterEventCallback(const std::shared_ptr<core::ServerNode> &server_node) {

@@ -72,36 +83,33 @@ void Iteration::InitRounds(const std::vector<std::shared_ptr<core::CommunicatorB
return;
}

void Iteration::ProceedToNextIter(bool is_iteration_valid) {
iteration_num_ = LocalMetaStore::GetInstance().curr_iter_num();
is_last_iteration_valid_ = is_iteration_valid;
if (is_iteration_valid) {
// Store the model which is successfully aggregated for this iteration.
const auto &model = Executor::GetInstance().GetModel();
ModelStore::GetInstance().StoreModelByIterNum(iteration_num_, model);
MS_LOG(INFO) << "Iteration " << iteration_num_ << " is successfully finished.";
void Iteration::MoveToNextIteration(bool is_last_iter_valid, const std::string &reason) {
MS_LOG(INFO) << "Notify cluster starts to proceed to next iteration. Iteration is " << iteration_num_
<< " validation is " << is_last_iter_valid << ". Reason: " << reason;
if (IsMoveToNextIterRequestReentrant(iteration_num_)) {
return;
}

if (server_node_->rank_id() == kLeaderServerRank) {
if (!BroadcastPrepareForNextIterRequest(is_last_iter_valid, reason)) {
MS_LOG(ERROR) << "Broadcast prepare for next iteration request failed.";
return;
}
if (!BroadcastMoveToNextIterRequest(is_last_iter_valid, reason)) {
MS_LOG(ERROR) << "Broadcast proceed to next iteration request failed.";
return;
}
if (!BroadcastEndLastIterRequest(iteration_num_)) {
MS_LOG(ERROR) << "Broadcast end last iteration request failed.";
return;
}
} else {
// Store last iteration's model because this iteration is considered as invalid.
const auto &model = ModelStore::GetInstance().GetModelByIterNum(iteration_num_ - 1);
ModelStore::GetInstance().StoreModelByIterNum(iteration_num_, model);
MS_LOG(WARNING) << "Iteration " << iteration_num_ << " is invalid.";
// If this server is the follower server, notify leader server to control the cluster to proceed to next iteration.
if (!NotifyLeaderMoveToNextIteration(is_last_iter_valid, reason)) {
MS_LOG(ERROR) << "Server " << server_node_->rank_id() << " notifying the leader server failed.";
return;
}
}

for (auto &round : rounds_) {
round->Reset();
}

iteration_num_++;
// After the job is done, reset the iteration to the initial number and reset ModelStore.
if (iteration_num_ > PSContext::instance()->fl_iteration_num()) {
MS_LOG(INFO) << PSContext::instance()->fl_iteration_num() << " iterations are completed.";
iteration_num_ = 1;
ModelStore::GetInstance().Reset();
}

SetIterationCompleted();
LocalMetaStore::GetInstance().set_curr_iter_num(iteration_num_);
MS_LOG(INFO) << "Proceed to next iteration:" << iteration_num_ << "\n";
}

void Iteration::SetIterationRunning() {

@@ -147,7 +155,7 @@ bool Iteration::ReInitForScaling(uint32_t server_num, uint32_t server_rank) {
}
for (auto &round : rounds_) {
if (!round->ReInitForScaling(server_num)) {
MS_LOG(ERROR) << "Reinitializing round " << round->name() << " for scaling failed.";
MS_LOG(WARNING) << "Reinitializing round " << round->name() << " for scaling failed.";
return false;
}
}

@@ -168,6 +176,10 @@ bool Iteration::SyncIteration(uint32_t rank) {
MS_LOG(ERROR) << "Sending synchronizing iteration message to leader server failed.";
return false;
}
if (sync_iter_rsp_msg == nullptr) {
MS_LOG(ERROR) << "Response from server 0 is empty.";
return false;
}

SyncIterationResponse sync_iter_rsp;
sync_iter_rsp.ParseFromArray(sync_iter_rsp_msg->data(), sync_iter_rsp_msg->size());

@@ -192,6 +204,239 @@ void Iteration::HandleSyncIterationRequest(const std::shared_ptr<core::MessageHa
std::string sync_iter_rsp_msg = sync_iter_rsp.SerializeAsString();
communicator_->SendResponse(sync_iter_rsp_msg.data(), sync_iter_rsp_msg.size(), message);
}

bool Iteration::IsMoveToNextIterRequestReentrant(uint64_t iteration_num) {
std::unique_lock<std::mutex> lock(pinned_mtx_);
if (pinned_iter_num_ == iteration_num) {
MS_LOG(WARNING) << "MoveToNextIteration is not reentrant. Ignore this call.";
return true;
}
pinned_iter_num_ = iteration_num;
return false;
}

bool Iteration::NotifyLeaderMoveToNextIteration(bool is_last_iter_valid, const std::string &reason) {
MS_LOG(INFO) << "Notify leader server to control the cluster to proceed to next iteration.";
NotifyLeaderMoveToNextIterRequest notify_leader_to_next_iter_req;
notify_leader_to_next_iter_req.set_rank(server_node_->rank_id());
notify_leader_to_next_iter_req.set_is_last_iter_valid(is_last_iter_valid);
notify_leader_to_next_iter_req.set_iter_num(iteration_num_);
notify_leader_to_next_iter_req.set_reason(reason);
if (!communicator_->SendPbRequest(notify_leader_to_next_iter_req, kLeaderServerRank,
core::TcpUserCommand::kNotifyLeaderToNextIter)) {
MS_LOG(WARNING) << "Sending notify leader server to proceed next iteration request to leader server 0 failed.";
return false;
}
return true;
}

void Iteration::HandleNotifyLeaderMoveToNextIterRequest(const std::shared_ptr<core::MessageHandler> &message) {
if (message == nullptr) {
return;
}

NotifyLeaderMoveToNextIterResponse notify_leader_to_next_iter_rsp;
notify_leader_to_next_iter_rsp.set_result("success");
communicator_->SendResponse(notify_leader_to_next_iter_rsp.SerializeAsString().data(),
notify_leader_to_next_iter_rsp.SerializeAsString().size(), message);

NotifyLeaderMoveToNextIterRequest notify_leader_to_next_iter_req;
notify_leader_to_next_iter_req.ParseFromArray(message->data(), message->len());
const auto &rank = notify_leader_to_next_iter_req.rank();
const auto &is_last_iter_valid = notify_leader_to_next_iter_req.is_last_iter_valid();
const auto &iter_num = notify_leader_to_next_iter_req.iter_num();
const auto &reason = notify_leader_to_next_iter_req.reason();
MS_LOG(INFO) << "Leader server receives NotifyLeaderMoveToNextIterRequest from rank " << rank
<< ". Iteration number: " << iter_num << ". Reason: " << reason;

if (IsMoveToNextIterRequestReentrant(iter_num)) {
return;
}

if (!BroadcastPrepareForNextIterRequest(is_last_iter_valid, reason)) {
MS_LOG(ERROR) << "Broadcast prepare for next iteration request failed.";
return;
}
if (!BroadcastMoveToNextIterRequest(is_last_iter_valid, reason)) {
MS_LOG(ERROR) << "Broadcast proceed to next iteration request failed.";
return;
}
if (!BroadcastEndLastIterRequest(iteration_num_)) {
MS_LOG(ERROR) << "Broadcast end last iteration request failed.";
return;
}
}

bool Iteration::BroadcastPrepareForNextIterRequest(bool is_last_iter_valid, const std::string &reason) {
PrepareForNextIter();

MS_LOG(INFO) << "Notify all follower servers to prepare for next iteration.";
PrepareForNextIterRequest prepare_next_iter_req;
prepare_next_iter_req.set_is_last_iter_valid(is_last_iter_valid);
prepare_next_iter_req.set_reason(reason);

std::vector<uint32_t> offline_servers = {};
for (uint32_t i = 1; i < IntToUint(server_node_->server_num()); i++) {
if (!communicator_->SendPbRequest(prepare_next_iter_req, i, core::TcpUserCommand::kPrepareForNextIter)) {
MS_LOG(WARNING) << "Sending prepare for next iteration request to server " << i << " failed. Retry later.";
offline_servers.push_back(i);
continue;
}
}

// Retry sending to offline servers to notify them to prepare.
std::for_each(offline_servers.begin(), offline_servers.end(), [&](uint32_t rank) {
while (!communicator_->SendPbRequest(prepare_next_iter_req, rank, core::TcpUserCommand::kPrepareForNextIter)) {
MS_LOG(WARNING) << "Retry sending prepare for next iteration request to server " << rank
<< " failed. The server has not recovered yet.";
std::this_thread::sleep_for(std::chrono::milliseconds(kRetryDurationForPrepareForNextIter));
}
MS_LOG(INFO) << "Offline server " << rank << " preparing for next iteration success.";
});
return true;
}

void Iteration::HandlePrepareForNextIterRequest(const std::shared_ptr<core::MessageHandler> &message) {
if (message == nullptr) {
return;
}

PrepareForNextIterRequest prepare_next_iter_req;
prepare_next_iter_req.ParseFromArray(message->data(), message->len());
const auto &reason = prepare_next_iter_req.reason();
MS_LOG(INFO) << "Prepare next iteration for this rank " << server_node_->rank_id() << ", reason: " << reason;
PrepareForNextIter();

PrepareForNextIterResponse prepare_next_iter_rsp;
prepare_next_iter_rsp.set_result("success");
communicator_->SendResponse(prepare_next_iter_rsp.SerializeAsString().data(),
prepare_next_iter_rsp.SerializeAsString().size(), message);
}

void Iteration::PrepareForNextIter() {
MS_LOG(INFO) << "Prepare for next iteration. Switch the server to safemode.";
Server::GetInstance().SwitchToSafeMode();
}

bool Iteration::BroadcastMoveToNextIterRequest(bool is_last_iter_valid, const std::string &reason) {
MS_LOG(INFO) << "Notify all follower servers to proceed to next iteration. Set last iteration number "
<< iteration_num_;
MoveToNextIterRequest proceed_to_next_iter_req;
proceed_to_next_iter_req.set_is_last_iter_valid(is_last_iter_valid);
proceed_to_next_iter_req.set_last_iter_num(iteration_num_);
proceed_to_next_iter_req.set_reason(reason);
for (uint32_t i = 1; i < IntToUint(server_node_->server_num()); i++) {
if (!communicator_->SendPbRequest(proceed_to_next_iter_req, i, core::TcpUserCommand::kProceedToNextIter)) {
MS_LOG(WARNING) << "Sending proceed to next iteration request to server " << i << " failed.";
continue;
}
}

Next(is_last_iter_valid, reason);
return true;
}

void Iteration::HandleMoveToNextIterRequest(const std::shared_ptr<core::MessageHandler> &message) {
if (message == nullptr) {
return;
}

MoveToNextIterResponse proceed_to_next_iter_rsp;
proceed_to_next_iter_rsp.set_result("success");
communicator_->SendResponse(proceed_to_next_iter_rsp.SerializeAsString().data(),
proceed_to_next_iter_rsp.SerializeAsString().size(), message);

MoveToNextIterRequest proceed_to_next_iter_req;
proceed_to_next_iter_req.ParseFromArray(message->data(), message->len());
const auto &is_last_iter_valid = proceed_to_next_iter_req.is_last_iter_valid();
const auto &last_iter_num = proceed_to_next_iter_req.last_iter_num();
const auto &reason = proceed_to_next_iter_req.reason();

MS_LOG(INFO) << "Receive proceeding to next iteration request. This server current iteration is " << iteration_num_
<< ". The iteration number from leader server is " << last_iter_num
<< ". Last iteration is valid or not: " << is_last_iter_valid << ". Reason: " << reason;
// Synchronize the iteration number with leader server.
iteration_num_ = last_iter_num;
Next(is_last_iter_valid, reason);
}

void Iteration::Next(bool is_iteration_valid, const std::string &reason) {
MS_LOG(INFO) << "Prepare for next iteration.";
is_last_iteration_valid_ = is_iteration_valid;
if (is_iteration_valid) {
// Store the model which is successfully aggregated for this iteration.
const auto &model = Executor::GetInstance().GetModel();
ModelStore::GetInstance().StoreModelByIterNum(iteration_num_, model);
MS_LOG(INFO) << "Iteration " << iteration_num_ << " is successfully finished.";
} else {
// Store last iteration's model because this iteration is considered as invalid.
const auto &model = ModelStore::GetInstance().GetModelByIterNum(iteration_num_ - 1);
ModelStore::GetInstance().StoreModelByIterNum(iteration_num_, model);
MS_LOG(WARNING) << "Iteration " << iteration_num_ << " is invalid. Reason: " << reason;
}

for (auto &round : rounds_) {
round->Reset();
}
}

bool Iteration::BroadcastEndLastIterRequest(uint64_t last_iter_num) {
MS_LOG(INFO) << "Notify all follower servers to end last iteration.";
EndLastIterRequest end_last_iter_req;
end_last_iter_req.set_last_iter_num(last_iter_num);
for (uint32_t i = 1; i < IntToUint(server_node_->server_num()); i++) {
if (!communicator_->SendPbRequest(end_last_iter_req, i, core::TcpUserCommand::kEndLastIter)) {
MS_LOG(WARNING) << "Sending end last iteration request to server " << i << " failed.";
continue;
}
}

EndLastIter();
return true;
}

void Iteration::HandleEndLastIterRequest(const std::shared_ptr<core::MessageHandler> &message) {
if (message == nullptr) {
return;
}

EndLastIterRequest end_last_iter_req;
end_last_iter_req.ParseFromArray(message->data(), message->len());
const auto &last_iter_num = end_last_iter_req.last_iter_num();
// If the iteration number does not match, return an error.
if (last_iter_num != iteration_num_) {
std::string reason = "The iteration of this server " + std::to_string(server_node_->rank_id()) + " is " +
std::to_string(iteration_num_) + ", iteration to be ended is " + std::to_string(last_iter_num);
EndLastIterResponse end_last_iter_rsp;
end_last_iter_rsp.set_result(reason);
communicator_->SendResponse(end_last_iter_rsp.SerializeAsString().data(),
end_last_iter_rsp.SerializeAsString().size(), message);
return;
}

EndLastIter();

EndLastIterResponse end_last_iter_rsp;
end_last_iter_rsp.set_result("success");
communicator_->SendResponse(end_last_iter_rsp.SerializeAsString().data(),
end_last_iter_rsp.SerializeAsString().size(), message);
}

void Iteration::EndLastIter() {
MS_LOG(INFO) << "End the last iteration " << iteration_num_;
iteration_num_++;
// After the job is done, reset the iteration to the initial number and reset ModelStore.
if (iteration_num_ > PSContext::instance()->fl_iteration_num()) {
MS_LOG(INFO) << PSContext::instance()->fl_iteration_num() << " iterations are completed.";
iteration_num_ = 1;
ModelStore::GetInstance().Reset();
}

Server::GetInstance().CancelSafeMode();
SetIterationCompleted();
LocalMetaStore::GetInstance().set_curr_iter_num(iteration_num_);
MS_LOG(INFO) << "Move to next iteration:" << iteration_num_ << "\n";
}
} // namespace server
} // namespace ps
} // namespace mindspore

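MoveToNextIteration above branches on the server's role: the leader broadcasts the prepare, move, and end requests itself, while a follower only notifies the leader and is then driven by those broadcasts. A condensed, standalone model of that control flow; the Broadcast*/NotifyLeader functions are stubs, not the MindSpore implementations:

#include <cstdint>
#include <iostream>
#include <string>

constexpr uint32_t kLeaderRank = 0;

// Stubs standing in for the protobuf requests sent over TCP in the real code.
bool BroadcastPrepareForNextIter(bool, const std::string &) { return true; }
bool BroadcastMoveToNextIter(bool, const std::string &) { return true; }
bool BroadcastEndLastIter(uint64_t) { return true; }
bool NotifyLeader(bool, const std::string &) { return true; }

void MoveToNextIteration(uint32_t rank, uint64_t iter, bool last_iter_valid, const std::string &reason) {
  if (rank == kLeaderRank) {
    // The leader drives the three-phase transition for the whole cluster.
    if (!BroadcastPrepareForNextIter(last_iter_valid, reason)) return;
    if (!BroadcastMoveToNextIter(last_iter_valid, reason)) return;
    if (!BroadcastEndLastIter(iter)) return;
  } else {
    // Followers only ask the leader to start the transition.
    if (!NotifyLeader(last_iter_valid, reason)) {
      std::cout << "notifying the leader failed\n";
    }
  }
}

int main() {
  MoveToNextIteration(1, 5, false, "round updateModel timed out");  // follower path
  MoveToNextIteration(0, 5, true, "all rounds finished");           // leader path
  return 0;
}
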
@@ -19,6 +19,7 @@

#include <memory>
#include <vector>
#include <string>
#include "ps/core/communicator/communicator_base.h"
#include "ps/server/common.h"
#include "ps/server/round.h"

@@ -34,6 +35,9 @@ enum class IterationState {
kCompleted
};

// The duration between retries when sending the prepare-for-next-iteration request fails.
constexpr uint32_t kRetryDurationForPrepareForNextIter = 500;

// In the server's logic, Iteration is the minimum execution unit. Each execution consists of multiple kinds of
// Rounds; only after all the rounds are finished is the iteration considered completed.
class Iteration {

@@ -56,9 +60,10 @@ class Iteration {
void InitRounds(const std::vector<std::shared_ptr<core::CommunicatorBase>> &communicators,
const TimeOutCb &timeout_cb, const FinishIterCb &finish_iteration_cb);

// The server proceeds to the next iteration only after the last round finishes or the timer expires.
// If the timer expires, we consider this iteration as invalid.
void ProceedToNextIter(bool is_iteration_valid);
// This method controls the servers to proceed to the next iteration.
// There is communication between leader and follower servers in this method.
// The server moves to the next iteration only after the last round finishes or the timer expires.
void MoveToNextIteration(bool is_last_iter_valid, const std::string &reason);

// Set current iteration state to running and trigger events about kIterationRunning.
void SetIterationRunning();

@@ -84,7 +89,8 @@ class Iteration {
communicator_(nullptr),
iteration_state_(IterationState::kCompleted),
iteration_num_(1),
is_last_iteration_valid_(true) {
is_last_iteration_valid_(true),
pinned_iter_num_(0) {
LocalMetaStore::GetInstance().set_curr_iter_num(iteration_num_);
}
~Iteration() = default;

@@ -99,6 +105,32 @@ class Iteration {
bool SyncIteration(uint32_t rank);
void HandleSyncIterationRequest(const std::shared_ptr<core::MessageHandler> &message);

// The request for moving to the next iteration is not reentrant.
bool IsMoveToNextIterRequestReentrant(uint64_t iteration_num);

// The methods for moving all the servers to the next iteration.
// Step 1: follower servers notify the leader server that they need to move to the next iteration.
bool NotifyLeaderMoveToNextIteration(bool is_last_iter_valid, const std::string &reason);
void HandleNotifyLeaderMoveToNextIterRequest(const std::shared_ptr<core::MessageHandler> &message);

// Step 2: the leader server broadcasts to all follower servers to prepare for the next iteration and switch to safemode.
bool BroadcastPrepareForNextIterRequest(bool is_last_iter_valid, const std::string &reason);
void HandlePrepareForNextIterRequest(const std::shared_ptr<core::MessageHandler> &message);
// The server prepares for the next iteration. This method switches the server to safemode.
void PrepareForNextIter();

// Step 3: the leader server broadcasts to all follower servers to move to the next iteration.
bool BroadcastMoveToNextIterRequest(bool is_last_iter_valid, const std::string &reason);
void HandleMoveToNextIterRequest(const std::shared_ptr<core::MessageHandler> &message);
// Move to the next iteration. Store the last iteration's model and reset all the rounds.
void Next(bool is_iteration_valid, const std::string &reason);

// Step 4: the leader server broadcasts to all follower servers to end the last iteration and cancel safemode.
bool BroadcastEndLastIterRequest(uint64_t iteration_num);
void HandleEndLastIterRequest(const std::shared_ptr<core::MessageHandler> &message);
// The server ends the last iteration. This method increases the iteration number and cancels safemode.
void EndLastIter();

std::shared_ptr<core::ServerNode> server_node_;
std::shared_ptr<core::TcpCommunicator> communicator_;

@@ -113,6 +145,10 @@ class Iteration {

// Last iteration is successfully finished.
bool is_last_iteration_valid_;

// To avoid the Next method being called multiple times in one iteration, we pin the iteration number.
uint64_t pinned_iter_num_;
std::mutex pinned_mtx_;
};
} // namespace server
} // namespace ps

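The pinned_iter_num_/pinned_mtx_ pair above guards MoveToNextIteration against being triggered twice for the same iteration, for example by a timer expiring while a round-finished callback is already moving the cluster forward. A small standalone model of that guard (the IterationGuard class is illustrative):

#include <cstdint>
#include <iostream>
#include <mutex>

class IterationGuard {
 public:
  // Returns true if this iteration is already being moved forward, so the caller should bail out.
  bool IsReentrant(uint64_t iteration_num) {
    std::unique_lock<std::mutex> lock(mtx_);
    if (pinned_iter_num_ == iteration_num) {
      return true;
    }
    pinned_iter_num_ = iteration_num;  // pin this iteration for the first caller
    return false;
  }

 private:
  uint64_t pinned_iter_num_ = 0;
  std::mutex mtx_;
};

int main() {
  IterationGuard guard;
  std::cout << guard.IsReentrant(5) << "\n";  // 0: the first call wins
  std::cout << guard.IsReentrant(5) << "\n";  // 1: a second call for the same iteration is ignored
  return 0;
}
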
@@ -29,7 +29,7 @@ void IterationTimer::Start(const std::chrono::milliseconds &duration) {
monitor_thread_ = std::thread([&]() {
while (running_.load()) {
if (CURRENT_TIME_MILLI > end_time_) {
timeout_callback_(false);
timeout_callback_(false, "");
running_ = false;
}
// The time tick is 1 millisecond.

@@ -62,6 +62,7 @@ bool GetModelKernel::Reset() {
}

void GetModelKernel::GetModel(const schema::RequestGetModel *get_model_req, const std::shared_ptr<FBBuilder> &fbb) {
auto next_req_time = LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp);
std::map<std::string, AddressPtr> feature_maps;
size_t current_iter = LocalMetaStore::GetInstance().curr_iter_num();
size_t get_model_iter = static_cast<size_t>(get_model_req->iteration());

@@ -70,9 +71,11 @@ void GetModelKernel::GetModel(const schema::RequestGetModel *get_model_req, cons

// If this iteration is not finished yet, return ResponseCode_SucNotReady so that clients could get model later.
if ((current_iter == get_model_iter && latest_iter_num != current_iter) || current_iter == get_model_iter - 1) {
std::string reason = "The model is not ready yet for iteration " + std::to_string(get_model_iter);
std::string reason = "The model is not ready yet for iteration " + std::to_string(get_model_iter) +
". Maybe this is because\n" + "1. Client doesn't send enough update model requests.\n" +
"2. Worker has not pushed all the weights to servers.";
BuildGetModelRsp(fbb, schema::ResponseCode_SucNotReady, reason, current_iter, feature_maps,
std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)));
std::to_string(next_req_time));
MS_LOG(WARNING) << reason;
return;
}

@@ -88,11 +91,12 @@ void GetModelKernel::GetModel(const schema::RequestGetModel *get_model_req, cons

// If the iteration of this model is invalid, return ResponseCode_OutOfTime so the clients can call startFLJob again
// according to next_req_time.
auto response_code =
Iteration::GetInstance().is_last_iteration_valid() ? schema::ResponseCode_SUCCEED : schema::ResponseCode_OutOfTime;
bool last_iter_valid = Iteration::GetInstance().is_last_iteration_valid();
MS_LOG(INFO) << "GetModel last iteration is valid or not: " << last_iter_valid << ", next request time is "
<< next_req_time << ", current iteration is " << current_iter;
auto response_code = last_iter_valid ? schema::ResponseCode_SUCCEED : schema::ResponseCode_OutOfTime;
BuildGetModelRsp(fbb, response_code, "Get model for iteration " + std::to_string(get_model_iter), current_iter,
feature_maps,
std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)));
feature_maps, std::to_string(next_req_time));
return;
}

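The readiness condition above returns SucNotReady either when the requested iteration equals the current one but its model has not been stored yet, or when the client is already asking for the next iteration. A tiny standalone model of that check with a couple of example evaluations (the function name is illustrative):

#include <cstddef>
#include <iostream>

// Model of the readiness test in GetModelKernel::GetModel: the model for
// get_model_iter can be served only once it has been stored (latest_iter_num
// catches up with current_iter) and the client is not ahead of the server.
bool ModelNotReady(size_t current_iter, size_t get_model_iter, size_t latest_iter_num) {
  return (current_iter == get_model_iter && latest_iter_num != current_iter) ||
         current_iter == get_model_iter - 1;
}

int main() {
  std::cout << ModelNotReady(5, 5, 4) << "\n";  // 1: iteration 5 is still aggregating
  std::cout << ModelNotReady(5, 5, 5) << "\n";  // 0: model for iteration 5 is stored
  std::cout << ModelNotReady(5, 6, 5) << "\n";  // 1: client asks for iteration 6 too early
  return 0;
}
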
@@ -48,9 +48,9 @@ bool PushWeightKernel::Launch(const std::vector<AddressPtr> &inputs, const std::
return false;
}

PushWeight(fbb, push_weight_req);
bool ret = PushWeight(fbb, push_weight_req);
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
return true;
return ret;
}

bool PushWeightKernel::Reset() {

@@ -67,9 +67,9 @@ void PushWeightKernel::OnLastCountEvent(const std::shared_ptr<core::MessageHandl
return;
}

void PushWeightKernel::PushWeight(std::shared_ptr<FBBuilder> fbb, const schema::RequestPushWeight *push_weight_req) {
bool PushWeightKernel::PushWeight(std::shared_ptr<FBBuilder> fbb, const schema::RequestPushWeight *push_weight_req) {
if (fbb == nullptr || push_weight_req == nullptr) {
return;
return false;
}
size_t iteration = static_cast<size_t>(push_weight_req->iteration());
size_t current_iter = LocalMetaStore::GetInstance().curr_iter_num();

@@ -77,8 +77,8 @@ void PushWeightKernel::PushWeight(std::shared_ptr<FBBuilder> fbb, const schema::
std::string reason = "PushWeight iteration number is invalid:" + std::to_string(iteration) +
", current iteration:" + std::to_string(current_iter);
BuildPushWeightRsp(fbb, schema::ResponseCode_OutOfTime, reason, current_iter);
MS_LOG(ERROR) << reason;
return;
MS_LOG(WARNING) << reason;
return true;
}

std::map<std::string, Address> upload_feature_map = ParseFeatureMap(push_weight_req);

@@ -86,20 +86,25 @@ void PushWeightKernel::PushWeight(std::shared_ptr<FBBuilder> fbb, const schema::
std::string reason = "PushWeight feature_map is empty.";
BuildPushWeightRsp(fbb, schema::ResponseCode_RequestError, reason, current_iter);
MS_LOG(ERROR) << reason;
return;
return false;
}

if (!executor_->HandlePushWeight(upload_feature_map)) {
std::string reason = "Pushing weight failed.";
BuildPushWeightRsp(fbb, schema::ResponseCode_SystemError, reason, current_iter);
MS_LOG(ERROR) << reason;
return;
return false;
}
MS_LOG(INFO) << "Pushing weight for iteration " << current_iter << " succeeds.";

DistributedCountService::GetInstance().Count(name_, std::to_string(local_rank_));
if (!DistributedCountService::GetInstance().Count(name_, std::to_string(local_rank_))) {
std::string reason = "Count for push weight request failed.";
BuildPushWeightRsp(fbb, schema::ResponseCode_SystemError, reason, current_iter);
MS_LOG(ERROR) << reason;
return false;
}
BuildPushWeightRsp(fbb, schema::ResponseCode_SUCCEED, "PushWeight succeed.", current_iter);
return;
return true;
}

std::map<std::string, Address> PushWeightKernel::ParseFeatureMap(const schema::RequestPushWeight *push_weight_req) {

@@ -42,7 +42,7 @@ class PushWeightKernel : public RoundKernel {
void OnLastCountEvent(const std::shared_ptr<core::MessageHandler> &message) override;

private:
void PushWeight(std::shared_ptr<FBBuilder> fbb, const schema::RequestPushWeight *push_weight_req);
bool PushWeight(std::shared_ptr<FBBuilder> fbb, const schema::RequestPushWeight *push_weight_req);
std::map<std::string, Address> ParseFeatureMap(const schema::RequestPushWeight *push_weight_req);
void BuildPushWeightRsp(std::shared_ptr<FBBuilder> fbb, const schema::ResponseCode retcode, const std::string &reason,
size_t iteration);

@@ -68,7 +68,7 @@ void RoundKernel::StopTimer() {

void RoundKernel::FinishIteration() {
if (finish_iteration_cb_) {
finish_iteration_cb_(true);
finish_iteration_cb_(true, "");
}
return;
}

@@ -61,7 +61,7 @@ bool StartFLJobKernel::Launch(const std::vector<AddressPtr> &inputs, const std::

if (ReachThresholdForStartFLJob(fbb)) {
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
return false;
return true;
}

const schema::RequestFLJob *start_fl_job_req = flatbuffers::GetRoot<schema::RequestFLJob>(req_data);

@@ -102,7 +102,7 @@ bool StartFLJobKernel::ReachThresholdForStartFLJob(const std::shared_ptr<FBBuild
BuildStartFLJobRsp(
fbb, schema::ResponseCode_OutOfTime, reason, false,
std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)));
MS_LOG(ERROR) << reason;
MS_LOG(WARNING) << reason;
return true;
}
return false;

@@ -18,11 +18,13 @@
#include <memory>
#include <string>
#include "ps/server/server.h"
#include "ps/server/iteration.h"

namespace mindspore {
namespace ps {
namespace server {
class Server;
class Iteration;
Round::Round(const std::string &name, bool check_timeout, size_t time_window, bool check_count, size_t threshold_count,
bool server_num_as_threshold)
: name_(name),

@@ -42,9 +44,9 @@ void Round::Initialize(const std::shared_ptr<core::CommunicatorBase> &communicat
name_, [&](std::shared_ptr<core::MessageHandler> message) { LaunchRoundKernel(message); });

// Callback when the iteration is finished.
finish_iteration_cb_ = [this, finish_iteration_cb](bool is_iteration_valid) -> void {
MS_LOG(INFO) << "Round " << name_ << " finished! This iteration is valid. Proceed to next iteration.";
finish_iteration_cb(is_iteration_valid);
finish_iteration_cb_ = [this, finish_iteration_cb](bool is_iteration_valid, const std::string &) -> void {
std::string reason = "Round " + name_ + " finished! This iteration is valid. Proceed to next iteration.";
finish_iteration_cb(is_iteration_valid, reason);
};

// Callback for finalizing the server. This can only be called once.

@@ -54,9 +56,9 @@ void Round::Initialize(const std::shared_ptr<core::CommunicatorBase> &communicat
iter_timer_ = std::make_shared<IterationTimer>();

// 1.Set the timeout callback for the timer.
iter_timer_->SetTimeOutCallBack([this, timeout_cb](bool is_iteration_valid) -> void {
MS_LOG(INFO) << "Round " << name_ << " timeout! This iteration is invalid. Proceed to next iteration.";
timeout_cb(is_iteration_valid);
iter_timer_->SetTimeOutCallBack([this, timeout_cb](bool is_iteration_valid, const std::string &) -> void {
std::string reason = "Round " + name_ + " timeout! This iteration is invalid. Proceed to next iteration.";
timeout_cb(is_iteration_valid, reason);
});

// 2.Stopping timer callback which will be set to the round kernel.

@@ -89,7 +91,7 @@ bool Round::ReInitForScaling(uint32_t server_num) {
}

if (kernel_ == nullptr) {
MS_LOG(ERROR) << "Reinitializing for round " << name_ << " failed: round kernel is nullptr.";
MS_LOG(WARNING) << "Reinitializing for round " << name_ << " failed: round kernel is nullptr.";
return false;
}
kernel_->InitKernel(threshold_count_);

@@ -129,13 +131,14 @@ void Round::LaunchRoundKernel(const std::shared_ptr<core::MessageHandler> &messa
communicator_->SendResponse(reason.c_str(), reason.size(), message);
return;
}
communicator_->SendResponse(output->addr, output->size, message);
kernel_->Release(output);

// Must send response back no matter what value Launch method returns.
if (!ret) {
MS_LOG(WARNING) << "Launching round kernel of round " << name_ << " failed.";
std::string reason = "Launching round kernel of round " + name_ + " failed.";
Iteration::GetInstance().MoveToNextIteration(false, reason);
}
communicator_->SendResponse(output->addr, output->size, message);
kernel_->Release(output);
return;
}

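With TimeOutCb and FinishIterCb now carrying a reason string (see the common.h hunk above), Round builds a human-readable reason and forwards it to the iteration logic. A minimal standalone sketch of that callback shape and how a round might wire it; the DemoRound class and its members are illustrative, not the MindSpore Round:

#include <functional>
#include <iostream>
#include <string>

// Mirrors the widened callback signatures in this diff: validity flag plus reason.
using FinishIterCb = std::function<void(bool, const std::string &)>;

class DemoRound {
 public:
  DemoRound(std::string name, FinishIterCb cb) : name_(std::move(name)) {
    // Wrap the outer callback so every invocation carries a descriptive reason.
    finish_iteration_cb_ = [this, cb](bool valid, const std::string &) {
      std::string reason = "Round " + name_ + " finished! This iteration is valid. Proceed to next iteration.";
      cb(valid, reason);
    };
  }
  void Finish() { finish_iteration_cb_(true, ""); }

 private:
  std::string name_;
  FinishIterCb finish_iteration_cb_;
};

int main() {
  DemoRound round("updateModel", [](bool valid, const std::string &reason) {
    std::cout << "move to next iteration, valid=" << valid << ", reason=" << reason << "\n";
  });
  round.Finish();
  return 0;
}
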
@@ -73,6 +73,7 @@ void Server::Initialize(bool use_tcp, bool use_http, uint16_t http_port, const s
// InitCipher---->InitExecutor
void Server::Run() {
signal(SIGINT, SignalHandler);
std::unique_lock<std::mutex> lock(scaling_mtx_);
InitServerContext();
InitCluster();
InitIteration();

@@ -82,6 +83,7 @@ void Server::Run() {
RegisterRoundKernel();
MS_LOG(INFO) << "Server started successfully.";
safemode_ = false;
lock.unlock();

// Wait for the communicators to stop so the main thread is blocked.
std::for_each(communicators_with_worker_.begin(), communicators_with_worker_.end(),

@@ -91,6 +93,16 @@ void Server::Run() {
return;
}

void Server::SwitchToSafeMode() {
MS_LOG(INFO) << "Server switch to safemode.";
safemode_ = true;
}

void Server::CancelSafeMode() {
MS_LOG(INFO) << "Server cancel safemode.";
safemode_ = false;
}

bool Server::IsSafeMode() { return safemode_.load(); }

void Server::InitServerContext() {

@@ -166,8 +178,10 @@ void Server::InitIteration() {
}

// 2.Initialize all the rounds.
TimeOutCb time_out_cb = std::bind(&Iteration::ProceedToNextIter, iteration_, std::placeholders::_1);
FinishIterCb finish_iter_cb = std::bind(&Iteration::ProceedToNextIter, iteration_, std::placeholders::_1);
TimeOutCb time_out_cb =
std::bind(&Iteration::MoveToNextIteration, iteration_, std::placeholders::_1, std::placeholders::_2);
FinishIterCb finish_iter_cb =
std::bind(&Iteration::MoveToNextIteration, iteration_, std::placeholders::_1, std::placeholders::_2);
iteration_->InitRounds(communicators_with_worker_, time_out_cb, finish_iter_cb);
return;
}

@@ -288,28 +302,29 @@ void Server::ProcessBeforeScalingIn() {
}

void Server::ProcessAfterScalingOut() {
std::unique_lock<std::mutex> lock(scaling_mtx_);
if (server_node_ == nullptr) {
return;
}

if (!DistributedMetadataStore::GetInstance().ReInitForScaling()) {
MS_LOG(ERROR) << "DistributedMetadataStore reinitializing failed.";
MS_LOG(WARNING) << "DistributedMetadataStore reinitializing failed.";
return;
}
if (!CollectiveOpsImpl::GetInstance().ReInitForScaling()) {
MS_LOG(ERROR) << "DistributedMetadataStore reinitializing failed.";
MS_LOG(WARNING) << "DistributedMetadataStore reinitializing failed.";
return;
}
if (!DistributedCountService::GetInstance().ReInitForScaling()) {
MS_LOG(ERROR) << "DistributedCountService reinitializing failed.";
MS_LOG(WARNING) << "DistributedCountService reinitializing failed.";
return;
}
if (!iteration_->ReInitForScaling(IntToUint(server_node_->server_num()), server_node_->rank_id())) {
MS_LOG(ERROR) << "Iteration reinitializing failed.";
MS_LOG(WARNING) << "Iteration reinitializing failed.";
return;
}
if (!Executor::GetInstance().ReInitForScaling()) {
MS_LOG(ERROR) << "Executor reinitializing failed.";
MS_LOG(WARNING) << "Executor reinitializing failed.";
return;
}
std::this_thread::sleep_for(std::chrono::milliseconds(1000));

@@ -317,6 +332,7 @@ void Server::ProcessAfterScalingOut() {
}

void Server::ProcessAfterScalingIn() {
std::unique_lock<std::mutex> lock(scaling_mtx_);
if (server_node_ == nullptr) {
return;
}

@@ -331,23 +347,23 @@ void Server::ProcessAfterScalingIn() {

// If the server is not the one to be scaled in, reinitialize modules and recover the service.
if (!DistributedMetadataStore::GetInstance().ReInitForScaling()) {
MS_LOG(ERROR) << "DistributedMetadataStore reinitializing failed.";
MS_LOG(WARNING) << "DistributedMetadataStore reinitializing failed.";
return;
}
if (!CollectiveOpsImpl::GetInstance().ReInitForScaling()) {
MS_LOG(ERROR) << "DistributedMetadataStore reinitializing failed.";
MS_LOG(WARNING) << "DistributedMetadataStore reinitializing failed.";
return;
}
if (!DistributedCountService::GetInstance().ReInitForScaling()) {
MS_LOG(ERROR) << "DistributedCountService reinitializing failed.";
MS_LOG(WARNING) << "DistributedCountService reinitializing failed.";
return;
}
if (!iteration_->ReInitForScaling(IntToUint(server_node_->server_num()), server_node_->rank_id())) {
MS_LOG(ERROR) << "Iteration reinitializing failed.";
MS_LOG(WARNING) << "Iteration reinitializing failed.";
return;
}
if (!Executor::GetInstance().ReInitForScaling()) {
MS_LOG(ERROR) << "Executor reinitializing failed.";
MS_LOG(WARNING) << "Executor reinitializing failed.";
return;
}
std::this_thread::sleep_for(std::chrono::milliseconds(1000));

@@ -46,6 +46,8 @@ class Server {
// func_graph is the frontend graph which will be parsed in the server's executor and aggregator.
void Run();

void SwitchToSafeMode();
void CancelSafeMode();
bool IsSafeMode();

private:

@@ -134,6 +136,9 @@ class Server {
// communicators.
std::vector<std::shared_ptr<core::CommunicatorBase>> communicators_with_worker_;

// Mutex for scaling operations. We must wait for the server's initialization to finish before handling scaling events.
std::mutex scaling_mtx_;

// Iteration consists of multiple kinds of rounds.
Iteration *iteration_;

@@ -67,12 +67,23 @@ bool FLWorker::SendToServer(uint32_t server_rank, void *data, size_t size, core:
}

if (output != nullptr) {
do {
while (true) {
if (!worker_node_->Send(core::NodeRole::SERVER, server_rank, message, size, static_cast<int>(command), output)) {
MS_LOG(ERROR) << "Sending message to server " << server_rank << " failed.";
return false;
}
} while (std::string(reinterpret_cast<char *>((*output)->data()), (*output)->size()) == kClusterSafeMode);
if (*output == nullptr) {
MS_LOG(WARNING) << "Response from server " << server_rank << " is empty.";
return false;
}

if (std::string(reinterpret_cast<char *>((*output)->data()), (*output)->size()) == kClusterSafeMode) {
MS_LOG(INFO) << "The server " << server_rank << " is in safemode.";
std::this_thread::sleep_for(std::chrono::milliseconds(kWorkerRetryDurationForSafeMode));
} else {
break;
}
}
} else {
if (!worker_node_->Send(core::NodeRole::SERVER, server_rank, message, size, static_cast<int>(command))) {
MS_LOG(ERROR) << "Sending message to server " << server_rank << " failed.";

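The loop above keeps re-sending a request while the server replies with the cluster-safemode sentinel, sleeping kWorkerRetryDurationForSafeMode between attempts. A standalone model of that retry pattern; the FakeSend function is a stand-in for the real transport and simulates a server that stays in safemode for the first two attempts:

#include <chrono>
#include <iostream>
#include <string>
#include <thread>

constexpr int kRetryDurationMs = 500;
const std::string kClusterSafeMode = "The cluster is in safemode.";

// Simulated transport: the first two replies report safemode, then a real answer arrives.
std::string FakeSend(int attempt) { return attempt < 3 ? kClusterSafeMode : "ok"; }

bool SendWithSafeModeRetry(std::string *response) {
  int attempt = 0;
  while (true) {
    ++attempt;
    *response = FakeSend(attempt);
    if (*response == kClusterSafeMode) {
      std::cout << "server is in safemode, retrying (attempt " << attempt << ")\n";
      std::this_thread::sleep_for(std::chrono::milliseconds(kRetryDurationMs));
    } else {
      break;
    }
  }
  return true;
}

int main() {
  std::string response;
  SendWithSafeModeRetry(&response);
  std::cout << "final response: " << response << "\n";
  return 0;
}
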
@@ -88,6 +99,16 @@ uint32_t FLWorker::worker_num() const { return worker_num_; }

uint64_t FLWorker::worker_step_num_per_iteration() const { return worker_step_num_per_iteration_; }

void FLWorker::SetIterationRunning() {
MS_LOG(INFO) << "Worker iteration starts.";
worker_iteration_state_ = IterationState::kRunning;
}

void FLWorker::SetIterationCompleted() {
MS_LOG(INFO) << "Worker iteration completes.";
worker_iteration_state_ = IterationState::kCompleted;
}

void FLWorker::InitializeFollowerScaler() {
if (!worker_node_->InitFollowerScaler()) {
MS_LOG(EXCEPTION) << "Initializing follower elastic scaler failed.";

@@ -112,21 +133,22 @@ void FLWorker::InitializeFollowerScaler() {
}

void FLWorker::HandleIterationRunningEvent() {
MS_LOG(INFO) << "Worker iteration starts, safemode is " << safemode_.load();
iteration_state_ = IterationState::kRunning;
MS_LOG(INFO) << "Server iteration starts, safemode is " << safemode_.load();
server_iteration_state_ = IterationState::kRunning;
if (safemode_.load() == true) {
safemode_ = false;
}
}

void FLWorker::HandleIterationCompletedEvent() {
MS_LOG(INFO) << "Worker iteration completes";
iteration_state_ = IterationState::kCompleted;
MS_LOG(INFO) << "Server iteration completes";
server_iteration_state_ = IterationState::kCompleted;
}

void FLWorker::ProcessBeforeScalingOut() {
MS_LOG(INFO) << "Starting Worker scaling out barrier.";
while (iteration_state_.load() != IterationState::kCompleted) {
while (server_iteration_state_.load() != IterationState::kCompleted ||
worker_iteration_state_.load() != IterationState::kCompleted) {
std::this_thread::yield();
}
MS_LOG(INFO) << "Ending Worker scaling out barrier. Switch to safemode.";

@@ -135,7 +157,8 @@ void FLWorker::ProcessBeforeScalingOut() {

void FLWorker::ProcessBeforeScalingIn() {
MS_LOG(INFO) << "Starting Worker scaling in barrier.";
while (iteration_state_.load() != IterationState::kCompleted) {
while (server_iteration_state_.load() != IterationState::kCompleted ||
worker_iteration_state_.load() != IterationState::kCompleted) {
std::this_thread::yield();
}
MS_LOG(INFO) << "Ending Worker scaling in barrier. Switch to safemode.";

@@ -148,9 +171,6 @@ void FLWorker::ProcessAfterScalingOut() {
}

MS_LOG(INFO) << "Cluster scaling out completed. Reinitialize for worker.";
while (iteration_state_.load() != IterationState::kCompleted) {
std::this_thread::yield();
}
server_num_ = worker_node_->server_num();
worker_num_ = worker_node_->worker_num();
MS_LOG(INFO) << "After scheduler scaling out, worker number is " << worker_num_ << ", server number is "

@@ -165,9 +185,6 @@ void FLWorker::ProcessAfterScalingIn() {
}

MS_LOG(INFO) << "Cluster scaling in completed. Reinitialize for worker.";
while (iteration_state_.load() != IterationState::kCompleted) {
std::this_thread::yield();
}
server_num_ = worker_node_->server_num();
worker_num_ = worker_node_->worker_num();
MS_LOG(INFO) << "After scheduler scaling in, worker number is " << worker_num_ << ", server number is " << server_num_

@@ -40,6 +40,9 @@ constexpr uint32_t kTrainEndStepNum = 0;
// The worker has to sleep for a while before the networking is completed.
constexpr uint32_t kWorkerSleepTimeForNetworking = 1000;

// The duration between retries when the server is in safemode.
constexpr uint32_t kWorkerRetryDurationForSafeMode = 500;

enum class IterationState {
// This iteration is still in process.
kRunning,

@@ -64,6 +67,10 @@ class FLWorker {
uint32_t worker_num() const;
uint64_t worker_step_num_per_iteration() const;

// These methods set the worker's iteration state.
void SetIterationRunning();
void SetIterationCompleted();

private:
FLWorker()
: server_num_(0),

@@ -72,7 +79,8 @@ class FLWorker {
scheduler_port_(0),
worker_node_(nullptr),
worker_step_num_per_iteration_(1),
iteration_state_(IterationState::kCompleted),
server_iteration_state_(IterationState::kCompleted),
worker_iteration_state_(IterationState::kCompleted),
safemode_(false) {}
~FLWorker() = default;
FLWorker(const FLWorker &) = delete;

@@ -104,9 +112,14 @@ class FLWorker {
uint64_t worker_step_num_per_iteration_;

// The iteration state is either running or completed.
std::atomic<IterationState> iteration_state_;
// This variable represents the server iteration state and should be changed by the events
// kIterationRunning/kIterationCompleted triggered by the server.
std::atomic<IterationState> server_iteration_state_;

// The flag that represents whether worker is in safemode.
// This variable represents the worker iteration state and should be changed by the worker training process.
std::atomic<IterationState> worker_iteration_state_;

// The flag that represents whether the worker is in safemode, which is decided by both the worker and server iteration states.
std::atomic_bool safemode_;
};
} // namespace worker

@@ -828,18 +828,19 @@ def set_fl_context(**kwargs):
Default: 'FEDERATED_LEARNING'.
ms_role (string): The process's role in the federated learning mode,
which must be one of 'MS_SERVER', 'MS_WORKER' and 'MS_SCHED'.
Default: 'MS_NOT_PS'.
worker_num (int): The number of workers. Default: 0.
Default: 'MS_SERVER'.
worker_num (int): The number of workers. For the current version, this must be set to 1 or 0.
server_num (int): The number of federated learning servers. Default: 0.
scheduler_ip (string): The scheduler IP. Default: ''.
scheduler_port (int): The scheduler port. Default: 0.
scheduler_ip (string): The scheduler IP. Default: '0.0.0.0'.
scheduler_port (int): The scheduler port. Default: 6667.
fl_server_port (int): The http port of the federated learning server.
Normally for each server this should be set to the same value. Default: 0.
Normally for each server this should be set to the same value. Default: 6668.
enable_fl_client (bool): Whether this process is federated learning client. Default: False.
start_fl_job_threshold (int): The threshold count of startFLJob. Default: 0.
start_fl_job_time_window (int): The time window duration for startFLJob in millisecond. Default: 3000.
update_model_ratio (float): The ratio for computing the threshold count of updateModel
which will be multiplied by start_fl_job_threshold. Default: 1.0.
which will be multiplied by start_fl_job_threshold.
Must be between 0 and 1.0. Default: 1.0.
update_model_time_window (int): The time window duration for updateModel in millisecond. Default: 3000.
fl_name (string): The federated learning job name. Default: ''.
fl_iteration_num (int): Iteration number of federated learning,

@@ -15,10 +15,20 @@
"""Context for parameter server training mode"""

import os
from mindspore._checkparam import Validator
from mindspore._c_expression import PSContext

_ps_context = None

_check_positive_int_keys = ["server_num", "scheduler_port", "fl_server_port",
                            "start_fl_job_threshold", "start_fl_job_time_window", "update_model_time_window",
                            "fl_iteration_num", "client_epoch_num", "client_batch_size", "scheduler_manage_port"]

_check_non_negative_int_keys = ["worker_num"]

_check_positive_float_keys = ["update_model_ratio", "client_learning_rate"]

_check_port_keys = ["scheduler_port", "fl_server_port", "scheduler_manage_port"]

def ps_context():
    """

@@ -181,3 +191,20 @@ def _set_cache_enable(cache_enable):

def _set_rank_id(rank_id):
    ps_context().set_rank_id(rank_id)

def _check_value(key, value):
    """
    Validate the value for parameter server context keys.
    """
    if key in _check_positive_int_keys:
        Validator.check_positive_int(value, key)

    if key in _check_non_negative_int_keys:
        Validator.check_non_negative_int(value, key)

    if key in _check_positive_float_keys:
        Validator.check_positive_float(value, key)

    if key in _check_port_keys:
        if value < 1 or value > 65535:
            raise ValueError("The range of %s must be 1 to 65535, but got %d." % (key, value))

@@ -163,6 +163,9 @@ while True:
    rsp_fl_job = ResponseFLJob.ResponseFLJob.GetRootAsResponseFLJob(x.content, 0)
    while rsp_fl_job.Retcode() != ResponseCode.ResponseCode.SUCCEED:
        x = session.post(url1, data=build_start_fl_job(current_iteration))
        while x.text == "The cluster is in safemode.":
            time.sleep(0.2)
            x = session.post(url1, data=build_start_fl_job(current_iteration))
        rsp_fl_job = ResponseFLJob.ResponseFLJob.GetRootAsResponseFLJob(x.content, 0)
    print("epoch is", rsp_fl_job.FlPlanConfig().Epochs())
    print("iteration is", rsp_fl_job.Iteration())

@@ -173,6 +176,10 @@ while True:
    print("req update model iteration:", current_iteration, ", id:", args.pid)
    update_model_buf, update_model_np_data = build_update_model(current_iteration)
    x = session.post(url2, data=update_model_buf)
    while x.text == "The cluster is in safemode.":
        time.sleep(0.2)
        x = session.post(url2, data=update_model_buf)

    print("rsp update model iteration:", current_iteration, ", id:", args.pid)
    sys.stdout.flush()

@@ -227,4 +234,5 @@ while True:
    # Sleep to the next request timestamp
    current_ts = datetime_to_timestamp(datetime.datetime.now())
    duration = next_req_timestamp - current_ts
    time.sleep(duration / 1000)
    if duration > 0:
        time.sleep(duration / 1000)