!18594 Sync from enterprise

Merge pull request !18594 from ZPaC/sync-from-enter
i-robot 2021-06-22 02:44:08 +00:00 committed by Gitee
commit edc1c8bf58
25 changed files with 575 additions and 129 deletions

View File

@ -107,6 +107,7 @@ class FusedPullWeightKernel : public CPUKernel {
}
}
MS_LOG(INFO) << "Pull weights for " << weight_full_names_ << " succeed. Iteration: " << fl_iteration_;
ps::worker::FLWorker::GetInstance().SetIterationRunning();
return true;
}

View File

@ -68,8 +68,8 @@ class FusedPushWeightKernel : public CPUKernel {
std::shared_ptr<std::vector<unsigned char>> push_weight_rsp_msg = nullptr;
if (!ps::worker::FLWorker::GetInstance().SendToServer(
i, fbb->GetBufferPointer(), fbb->GetSize(), ps::core::TcpUserCommand::kPushWeight, &push_weight_rsp_msg)) {
MS_LOG(EXCEPTION) << "Sending request for FusedPushWeight to server " << i << " failed.";
return false;
MS_LOG(ERROR) << "Sending request for FusedPushWeight to server " << i << " failed.";
continue;
}
MS_EXCEPTION_IF_NULL(push_weight_rsp_msg);
@ -83,6 +83,7 @@ class FusedPushWeightKernel : public CPUKernel {
}
}
MS_LOG(INFO) << "Push weights for " << weight_full_names_ << " succeed. Iteration: " << fl_iteration_;
ps::worker::FLWorker::GetInstance().SetIterationCompleted();
return true;
}

View File

@ -47,7 +47,11 @@ enum class TcpUserCommand {
kCounterEvent,
kPullWeight,
kPushWeight,
kSyncIteration
kSyncIteration,
kNotifyLeaderToNextIter,
kPrepareForNextIter,
kProceedToNextIter,
kEndLastIter
};
const std::unordered_map<TcpUserCommand, std::string> kUserCommandToMsgType = {
@ -61,7 +65,11 @@ const std::unordered_map<TcpUserCommand, std::string> kUserCommandToMsgType = {
{TcpUserCommand::kCounterEvent, "counterEvent"},
{TcpUserCommand::kPullWeight, "pullWeight"},
{TcpUserCommand::kPushWeight, "pushWeight"},
{TcpUserCommand::kSyncIteration, "syncIteration"}};
{TcpUserCommand::kSyncIteration, "syncIteration"},
{TcpUserCommand::kNotifyLeaderToNextIter, "notifyLeaderToNextIter"},
{TcpUserCommand::kPrepareForNextIter, "prepareForNextIter"},
{TcpUserCommand::kProceedToNextIter, "proceedToNextIter"},
{TcpUserCommand::kEndLastIter, "endLastIter"}};
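// For orientation, a minimal sketch of how the two sides of this map line up (the calls below
// appear elsewhere in this change; the handler name is illustrative):
//   communicator_->SendPbRequest(req, rank, TcpUserCommand::kPrepareForNextIter);
//   // ...is dispatched on the receiving side to the callback registered under the mapped string:
//   communicator_->RegisterMsgCallBack("prepareForNextIter", handler);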
class TcpCommunicator : public CommunicatorBase {
public:

View File

@ -163,3 +163,41 @@ message SyncIterationResponse {
// The current iteration number.
uint64 iteration = 1;
}
message PrepareForNextIterRequest {
bool is_last_iter_valid = 1;
string reason = 2;
}
message PrepareForNextIterResponse {
string result = 1;
}
message NotifyLeaderMoveToNextIterRequest {
uint32 rank = 1;
bool is_last_iter_valid = 2;
uint64 iter_num = 3;
string reason = 4;
}
message NotifyLeaderMoveToNextIterResponse {
string result = 1;
}
message MoveToNextIterRequest {
bool is_last_iter_valid = 1;
uint64 last_iter_num = 2;
string reason = 3;
}
message MoveToNextIterResponse {
string result = 1;
}
message EndLastIterRequest {
uint64 last_iter_num = 1;
}
message EndLastIterResponse {
string result = 1;
}
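// On the C++ side these messages travel as serialized protobuf payloads; a minimal
// round-trip sketch, mirroring the usage in iteration.cc:
//   EndLastIterRequest req;
//   req.set_last_iter_num(42);
//   std::string payload = req.SerializeAsString();
//   EndLastIterRequest parsed;
//   parsed.ParseFromArray(payload.data(), payload.size());  // parsed.last_iter_num() == 42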

View File

@ -45,7 +45,7 @@ void PSContext::SetPSEnable(bool enabled) {
} else if (ms_role == kEnvRoleOfScheduler) {
is_sched_ = true;
} else {
MS_LOG(WARNING) << "MS_ROLE is " << ms_role << ", which is invalid.";
MS_LOG(INFO) << "MS_ROLE is " << ms_role;
}
worker_num_ = std::strtol(common::GetEnv(kEnvWorkerNum).c_str(), nullptr, 10);
@ -273,7 +273,13 @@ void PSContext::set_start_fl_job_time_window(uint64_t start_fl_job_time_window)
uint64_t PSContext::start_fl_job_time_window() const { return start_fl_job_time_window_; }
void PSContext::set_update_model_ratio(float update_model_ratio) { update_model_ratio_ = update_model_ratio; }
void PSContext::set_update_model_ratio(float update_model_ratio) {
if (update_model_ratio > 1.0) {
MS_LOG(EXCEPTION) << "update_model_ratio must be between 0 and 1.";
return;
}
update_model_ratio_ = update_model_ratio;
}
float PSContext::update_model_ratio() const { return update_model_ratio_; }

View File

@ -161,12 +161,12 @@ class PSContext {
rank_id_(0),
worker_num_(0),
server_num_(0),
scheduler_host_(""),
scheduler_port_(0),
scheduler_host_("0.0.0.0"),
scheduler_port_(6667),
role_(kEnvRoleOfNotPS),
server_mode_(""),
resetter_round_(ResetterRound::kNoNeedToReset),
fl_server_port_(0),
fl_server_port_(6668),
fl_client_enable_(false),
fl_name_(""),
start_fl_job_threshold_(0),
@ -179,7 +179,7 @@ class PSContext {
client_learning_rate_(0.001),
secure_aggregation_(false),
cluster_config_(nullptr),
scheduler_manage_port_(0),
scheduler_manage_port_(11202),
config_file_path_("") {}
bool ps_enabled_;
bool is_worker_;

View File

@ -63,9 +63,9 @@ using mindspore::kernel::Address;
using mindspore::kernel::AddressPtr;
using mindspore::kernel::CPUKernel;
using FBBuilder = flatbuffers::FlatBufferBuilder;
using TimeOutCb = std::function<void(bool)>;
using TimeOutCb = std::function<void(bool, const std::string &)>;
using StopTimerCb = std::function<void(void)>;
using FinishIterCb = std::function<void(bool)>;
using FinishIterCb = std::function<void(bool, const std::string &)>;
using FinalizeCb = std::function<void(void)>;
using MessageCallback = std::function<void(const std::shared_ptr<core::MessageHandler> &)>;

View File

@ -83,7 +83,10 @@ bool DistributedCountService::Count(const std::string &name, const std::string &
MS_LOG(INFO) << "Leader server increase count for " << name << " of " << id;
global_current_count_[name].insert(id);
TriggerCounterEvent(name);
if (!TriggerCounterEvent(name)) {
MS_LOG(ERROR) << "Leader server trigger count event failed.";
return false;
}
} else {
// If this server is a follower server, it needs to send CountRequest to the leader server.
CountRequest report_count_req;
@ -198,9 +201,14 @@ void DistributedCountService::HandleCountRequest(const std::shared_ptr<core::Mes
// Insert the id for the counter, which means the count for the name is increased.
MS_LOG(INFO) << "Leader server increase count for " << name << " of " << id;
global_current_count_[name].insert(id);
TriggerCounterEvent(name);
count_rsp.set_result(true);
count_rsp.set_reason("success");
if (!TriggerCounterEvent(name)) {
std::string reason = "Trigger count event for " + name + " of " + id + " failed.";
count_rsp.set_result(false);
count_rsp.set_reason(reason);
} else {
count_rsp.set_result(true);
count_rsp.set_reason("success");
}
communicator_->SendResponse(count_rsp.SerializeAsString().data(), count_rsp.SerializeAsString().size(), message);
return;
}
@ -256,20 +264,24 @@ void DistributedCountService::HandleCounterEvent(const std::shared_ptr<core::Mes
return;
}
void DistributedCountService::TriggerCounterEvent(const std::string &name) {
bool DistributedCountService::TriggerCounterEvent(const std::string &name) {
MS_LOG(INFO) << "Current count for " << name << " is " << global_current_count_[name].size()
<< ", threshold count is " << global_threshold_count_[name];
// The threshold count may be 1, so the first and the last count events should both be activated.
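// For example, with a threshold of 1, the single registered count is simultaneously the first and the last one, so both events must fire for it.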
if (global_current_count_[name].size() == 1) {
TriggerFirstCountEvent(name);
if (!TriggerFirstCountEvent(name)) {
return false;
}
}
if (global_current_count_[name].size() == global_threshold_count_[name]) {
TriggerLastCountEvent(name);
if (!TriggerLastCountEvent(name)) {
return false;
}
}
return;
return true;
}
void DistributedCountService::TriggerFirstCountEvent(const std::string &name) {
bool DistributedCountService::TriggerFirstCountEvent(const std::string &name) {
MS_LOG(DEBUG) << "Activating first count event for " << name;
CounterEvent first_count_event;
first_count_event.set_type(CounterEventType::FIRST_CNT);
@ -279,15 +291,15 @@ void DistributedCountService::TriggerFirstCountEvent(const std::string &name) {
for (uint32_t i = 1; i < server_num_; i++) {
if (!communicator_->SendPbRequest(first_count_event, i, core::TcpUserCommand::kCounterEvent)) {
MS_LOG(ERROR) << "Activating first count event to server " << i << " failed.";
return;
return false;
}
}
// Leader server directly calls the callback.
counter_handlers_[name].first_count_handler(nullptr);
return;
return true;
}
void DistributedCountService::TriggerLastCountEvent(const std::string &name) {
bool DistributedCountService::TriggerLastCountEvent(const std::string &name) {
MS_LOG(INFO) << "Activating last count event for " << name;
CounterEvent last_count_event;
last_count_event.set_type(CounterEventType::LAST_CNT);
@ -297,12 +309,12 @@ void DistributedCountService::TriggerLastCountEvent(const std::string &name) {
for (uint32_t i = 1; i < server_num_; i++) {
if (!communicator_->SendPbRequest(last_count_event, i, core::TcpUserCommand::kCounterEvent)) {
MS_LOG(ERROR) << "Activating last count event to server " << i << " failed.";
return;
return false;
}
}
// Leader server directly calls the callback.
counter_handlers_[name].last_count_handler(nullptr);
return;
return true;
}
} // namespace server
} // namespace ps

View File

@ -98,9 +98,9 @@ class DistributedCountService {
void HandleCounterEvent(const std::shared_ptr<core::MessageHandler> &message);
// Call the callbacks when the first/last count event is triggered.
void TriggerCounterEvent(const std::string &name);
void TriggerFirstCountEvent(const std::string &name);
void TriggerLastCountEvent(const std::string &name);
bool TriggerCounterEvent(const std::string &name);
bool TriggerFirstCountEvent(const std::string &name);
bool TriggerLastCountEvent(const std::string &name);
// Members for the communication between counting server and other servers.
std::shared_ptr<core::ServerNode> server_node_;

View File

@ -20,15 +20,26 @@
#include <vector>
#include <numeric>
#include "ps/server/model_store.h"
#include "ps/server/server.h"
namespace mindspore {
namespace ps {
namespace server {
class Server;
void Iteration::RegisterMessageCallback(const std::shared_ptr<core::TcpCommunicator> &communicator) {
MS_EXCEPTION_IF_NULL(communicator);
communicator_ = communicator;
communicator_->RegisterMsgCallBack("syncIteraion",
std::bind(&Iteration::HandleSyncIterationRequest, this, std::placeholders::_1));
communicator_->RegisterMsgCallBack(
"notifyLeaderToNextIter",
std::bind(&Iteration::HandleNotifyLeaderMoveToNextIterRequest, this, std::placeholders::_1));
communicator_->RegisterMsgCallBack(
"prepareForNextIter", std::bind(&Iteration::HandlePrepareForNextIterRequest, this, std::placeholders::_1));
communicator_->RegisterMsgCallBack("proceedToNextIter",
std::bind(&Iteration::HandleMoveToNextIterRequest, this, std::placeholders::_1));
communicator_->RegisterMsgCallBack("endLastIter",
std::bind(&Iteration::HandleEndLastIterRequest, this, std::placeholders::_1));
}
void Iteration::RegisterEventCallback(const std::shared_ptr<core::ServerNode> &server_node) {
@ -72,36 +83,33 @@ void Iteration::InitRounds(const std::vector<std::shared_ptr<core::CommunicatorB
return;
}
void Iteration::ProceedToNextIter(bool is_iteration_valid) {
iteration_num_ = LocalMetaStore::GetInstance().curr_iter_num();
is_last_iteration_valid_ = is_iteration_valid;
if (is_iteration_valid) {
// Store the model which is successfully aggregated for this iteration.
const auto &model = Executor::GetInstance().GetModel();
ModelStore::GetInstance().StoreModelByIterNum(iteration_num_, model);
MS_LOG(INFO) << "Iteration " << iteration_num_ << " is successfully finished.";
void Iteration::MoveToNextIteration(bool is_last_iter_valid, const std::string &reason) {
MS_LOG(INFO) << "Notify cluster starts to proceed to next iteration. Iteration is " << iteration_num_
<< " validation is " << is_last_iter_valid << ". Reason: " << reason;
if (IsMoveToNextIterRequestReentrant(iteration_num_)) {
return;
}
if (server_node_->rank_id() == kLeaderServerRank) {
if (!BroadcastPrepareForNextIterRequest(is_last_iter_valid, reason)) {
MS_LOG(ERROR) << "Broadcast prepare for next iteration request failed.";
return;
}
if (!BroadcastMoveToNextIterRequest(is_last_iter_valid, reason)) {
MS_LOG(ERROR) << "Broadcast proceed to next iteration request failed.";
return;
}
if (!BroadcastEndLastIterRequest(iteration_num_)) {
MS_LOG(ERROR) << "Broadcast end last iteration request failed.";
return;
}
} else {
// Store last iteration's model because this iteration is considered as invalid.
const auto &model = ModelStore::GetInstance().GetModelByIterNum(iteration_num_ - 1);
ModelStore::GetInstance().StoreModelByIterNum(iteration_num_, model);
MS_LOG(WARNING) << "Iteration " << iteration_num_ << " is invalid.";
// If this server is a follower server, notify the leader server to control the cluster to proceed to next iteration.
if (!NotifyLeaderMoveToNextIteration(is_last_iter_valid, reason)) {
MS_LOG(ERROR) << "Server " << server_node_->rank_id() << " notifying the leader server failed.";
return;
}
}
for (auto &round : rounds_) {
round->Reset();
}
iteration_num_++;
// After the job is done, reset the iteration to the initial number and reset ModelStore.
if (iteration_num_ > PSContext::instance()->fl_iteration_num()) {
MS_LOG(INFO) << PSContext::instance()->fl_iteration_num() << " iterations are completed.";
iteration_num_ = 1;
ModelStore::GetInstance().Reset();
}
SetIterationCompleted();
LocalMetaStore::GetInstance().set_curr_iter_num(iteration_num_);
MS_LOG(INFO) << "Proceed to next iteration:" << iteration_num_ << "\n";
}
void Iteration::SetIterationRunning() {
@ -147,7 +155,7 @@ bool Iteration::ReInitForScaling(uint32_t server_num, uint32_t server_rank) {
}
for (auto &round : rounds_) {
if (!round->ReInitForScaling(server_num)) {
MS_LOG(ERROR) << "Reinitializing round " << round->name() << " for scaling failed.";
MS_LOG(WARNING) << "Reinitializing round " << round->name() << " for scaling failed.";
return false;
}
}
@ -168,6 +176,10 @@ bool Iteration::SyncIteration(uint32_t rank) {
MS_LOG(ERROR) << "Sending synchronizing iteration message to leader server failed.";
return false;
}
if (sync_iter_rsp_msg == nullptr) {
MS_LOG(ERROR) << "Response from server 0 is empty.";
return false;
}
SyncIterationResponse sync_iter_rsp;
sync_iter_rsp.ParseFromArray(sync_iter_rsp_msg->data(), sync_iter_rsp_msg->size());
@ -192,6 +204,239 @@ void Iteration::HandleSyncIterationRequest(const std::shared_ptr<core::MessageHa
std::string sync_iter_rsp_msg = sync_iter_rsp.SerializeAsString();
communicator_->SendResponse(sync_iter_rsp_msg.data(), sync_iter_rsp_msg.size(), message);
}
bool Iteration::IsMoveToNextIterRequestReentrant(uint64_t iteration_num) {
std::unique_lock<std::mutex> lock(pinned_mtx_);
if (pinned_iter_num_ == iteration_num) {
MS_LOG(WARNING) << "MoveToNextIteration is not reentrant. Ignore this call.";
return true;
}
pinned_iter_num_ = iteration_num;
return false;
}
bool Iteration::NotifyLeaderMoveToNextIteration(bool is_last_iter_valid, const std::string &reason) {
MS_LOG(INFO) << "Notify leader server to control the cluster to proceed to next iteration.";
NotifyLeaderMoveToNextIterRequest notify_leader_to_next_iter_req;
notify_leader_to_next_iter_req.set_rank(server_node_->rank_id());
notify_leader_to_next_iter_req.set_is_last_iter_valid(is_last_iter_valid);
notify_leader_to_next_iter_req.set_iter_num(iteration_num_);
notify_leader_to_next_iter_req.set_reason(reason);
if (!communicator_->SendPbRequest(notify_leader_to_next_iter_req, kLeaderServerRank,
core::TcpUserCommand::kNotifyLeaderToNextIter)) {
MS_LOG(WARNING) << "Sending notify leader server to proceed next iteration request to leader server 0 failed.";
return false;
}
return true;
}
void Iteration::HandleNotifyLeaderMoveToNextIterRequest(const std::shared_ptr<core::MessageHandler> &message) {
if (message == nullptr) {
return;
}
NotifyLeaderMoveToNextIterResponse notify_leader_to_next_iter_rsp;
notify_leader_to_next_iter_rsp.set_result("success");
communicator_->SendResponse(notify_leader_to_next_iter_rsp.SerializeAsString().data(),
notify_leader_to_next_iter_rsp.SerializeAsString().size(), message);
NotifyLeaderMoveToNextIterRequest notify_leader_to_next_iter_req;
notify_leader_to_next_iter_req.ParseFromArray(message->data(), message->len());
const auto &rank = notify_leader_to_next_iter_req.rank();
const auto &is_last_iter_valid = notify_leader_to_next_iter_req.is_last_iter_valid();
const auto &iter_num = notify_leader_to_next_iter_req.iter_num();
const auto &reason = notify_leader_to_next_iter_req.reason();
MS_LOG(INFO) << "Leader server receives NotifyLeaderMoveToNextIterRequest from rank " << rank
<< ". Iteration number: " << iter_num << ". Reason: " << reason;
if (IsMoveToNextIterRequestReentrant(iter_num)) {
return;
}
if (!BroadcastPrepareForNextIterRequest(is_last_iter_valid, reason)) {
MS_LOG(ERROR) << "Broadcast prepare for next iteration request failed.";
return;
}
if (!BroadcastMoveToNextIterRequest(is_last_iter_valid, reason)) {
MS_LOG(ERROR) << "Broadcast proceed to next iteration request failed.";
return;
}
if (!BroadcastEndLastIterRequest(iteration_num_)) {
MS_LOG(ERROR) << "Broadcast end last iteration request failed.";
return;
}
}
bool Iteration::BroadcastPrepareForNextIterRequest(bool is_last_iter_valid, const std::string &reason) {
PrepareForNextIter();
MS_LOG(INFO) << "Notify all follower servers to prepare for next iteration.";
PrepareForNextIterRequest prepare_next_iter_req;
prepare_next_iter_req.set_is_last_iter_valid(is_last_iter_valid);
prepare_next_iter_req.set_reason(reason);
std::vector<uint32_t> offline_servers = {};
for (uint32_t i = 1; i < IntToUint(server_node_->server_num()); i++) {
if (!communicator_->SendPbRequest(prepare_next_iter_req, i, core::TcpUserCommand::kPrepareForNextIter)) {
MS_LOG(WARNING) << "Sending prepare for next iteration request to server " << i << " failed. Retry later.";
offline_servers.push_back(i);
continue;
}
}
// Retry sending to offline servers to notify them to prepare.
std::for_each(offline_servers.begin(), offline_servers.end(), [&](uint32_t rank) {
while (!communicator_->SendPbRequest(prepare_next_iter_req, rank, core::TcpUserCommand::kPrepareForNextIter)) {
MS_LOG(WARNING) << "Retry sending prepare for next iteration request to server " << rank
<< " failed. The server has not recovered yet.";
std::this_thread::sleep_for(std::chrono::milliseconds(kRetryDurationForPrepareForNextIter));
}
MS_LOG(INFO) << "Offline server " << rank << " preparing for next iteration success.";
});
return true;
}
void Iteration::HandlePrepareForNextIterRequest(const std::shared_ptr<core::MessageHandler> &message) {
if (message == nullptr) {
return;
}
PrepareForNextIterRequest prepare_next_iter_req;
prepare_next_iter_req.ParseFromArray(message->data(), message->len());
const auto &reason = prepare_next_iter_req.reason();
MS_LOG(INFO) << "Prepare next iteration for this rank " << server_node_->rank_id() << ", reason: " << reason;
PrepareForNextIter();
PrepareForNextIterResponse prepare_next_iter_rsp;
prepare_next_iter_rsp.set_result("success");
communicator_->SendResponse(prepare_next_iter_rsp.SerializeAsString().data(),
prepare_next_iter_rsp.SerializeAsString().size(), message);
}
void Iteration::PrepareForNextIter() {
MS_LOG(INFO) << "Prepare for next iteration. Switch the server to safemode.";
Server::GetInstance().SwitchToSafeMode();
}
bool Iteration::BroadcastMoveToNextIterRequest(bool is_last_iter_valid, const std::string &reason) {
MS_LOG(INFO) << "Notify all follower servers to proceed to next iteration. Set last iteration number "
<< iteration_num_;
MoveToNextIterRequest proceed_to_next_iter_req;
proceed_to_next_iter_req.set_is_last_iter_valid(is_last_iter_valid);
proceed_to_next_iter_req.set_last_iter_num(iteration_num_);
proceed_to_next_iter_req.set_reason(reason);
for (uint32_t i = 1; i < IntToUint(server_node_->server_num()); i++) {
if (!communicator_->SendPbRequest(proceed_to_next_iter_req, i, core::TcpUserCommand::kProceedToNextIter)) {
MS_LOG(WARNING) << "Sending proceed to next iteration request to server " << i << " failed.";
continue;
}
}
Next(is_last_iter_valid, reason);
return true;
}
void Iteration::HandleMoveToNextIterRequest(const std::shared_ptr<core::MessageHandler> &message) {
if (message == nullptr) {
return;
}
MoveToNextIterResponse proceed_to_next_iter_rsp;
proceed_to_next_iter_rsp.set_result("success");
communicator_->SendResponse(proceed_to_next_iter_rsp.SerializeAsString().data(),
proceed_to_next_iter_rsp.SerializeAsString().size(), message);
MoveToNextIterRequest proceed_to_next_iter_req;
proceed_to_next_iter_req.ParseFromArray(message->data(), message->len());
const auto &is_last_iter_valid = proceed_to_next_iter_req.is_last_iter_valid();
const auto &last_iter_num = proceed_to_next_iter_req.last_iter_num();
const auto &reason = proceed_to_next_iter_req.reason();
MS_LOG(INFO) << "Receive proceeding to next iteration request. This server current iteration is " << iteration_num_
<< ". The iteration number from leader server is " << last_iter_num
<< ". Last iteration is valid or not: " << is_last_iter_valid << ". Reason: " << reason;
// Synchronize the iteration number with leader server.
iteration_num_ = last_iter_num;
Next(is_last_iter_valid, reason);
}
void Iteration::Next(bool is_iteration_valid, const std::string &reason) {
MS_LOG(INFO) << "Prepare for next iteration.";
is_last_iteration_valid_ = is_iteration_valid;
if (is_iteration_valid) {
// Store the model which is successfully aggregated for this iteration.
const auto &model = Executor::GetInstance().GetModel();
ModelStore::GetInstance().StoreModelByIterNum(iteration_num_, model);
MS_LOG(INFO) << "Iteration " << iteration_num_ << " is successfully finished.";
} else {
// Store last iteration's model because this iteration is considered as invalid.
const auto &model = ModelStore::GetInstance().GetModelByIterNum(iteration_num_ - 1);
ModelStore::GetInstance().StoreModelByIterNum(iteration_num_, model);
MS_LOG(WARNING) << "Iteration " << iteration_num_ << " is invalid. Reason: " << reason;
}
for (auto &round : rounds_) {
round->Reset();
}
}
bool Iteration::BroadcastEndLastIterRequest(uint64_t last_iter_num) {
MS_LOG(INFO) << "Notify all follower servers to end last iteration.";
EndLastIterRequest end_last_iter_req;
end_last_iter_req.set_last_iter_num(last_iter_num);
for (uint32_t i = 1; i < IntToUint(server_node_->server_num()); i++) {
if (!communicator_->SendPbRequest(end_last_iter_req, i, core::TcpUserCommand::kEndLastIter)) {
MS_LOG(WARNING) << "Sending proceed to next iteration request to server " << i << " failed.";
continue;
}
}
EndLastIter();
return true;
}
void Iteration::HandleEndLastIterRequest(const std::shared_ptr<core::MessageHandler> &message) {
if (message == nullptr) {
return;
}
EndLastIterRequest end_last_iter_req;
end_last_iter_req.ParseFromArray(message->data(), message->len());
const auto &last_iter_num = end_last_iter_req.last_iter_num();
// If the iteration number does not match, return an error.
if (last_iter_num != iteration_num_) {
std::string reason = "The iteration of this server " + std::to_string(server_node_->rank_id()) + " is " +
std::to_string(iteration_num_) + ", iteration to be ended is " + std::to_string(last_iter_num);
EndLastIterResponse end_last_iter_rsp;
end_last_iter_rsp.set_result(reason);
communicator_->SendResponse(end_last_iter_rsp.SerializeAsString().data(),
end_last_iter_rsp.SerializeAsString().size(), message);
return;
}
EndLastIter();
EndLastIterResponse end_last_iter_rsp;
end_last_iter_rsp.set_result("success");
communicator_->SendResponse(end_last_iter_rsp.SerializeAsString().data(),
end_last_iter_rsp.SerializeAsString().size(), message);
}
void Iteration::EndLastIter() {
MS_LOG(INFO) << "End the last iteration " << iteration_num_;
iteration_num_++;
// After the job is done, reset the iteration to the initial number and reset ModelStore.
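// For example, with fl_iteration_num == 10: when iteration 10 ends, iteration_num_ becomes 11, which exceeds 10, so it wraps back to 1 and the stored models are cleared.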
if (iteration_num_ > PSContext::instance()->fl_iteration_num()) {
MS_LOG(INFO) << PSContext::instance()->fl_iteration_num() << " iterations are completed.";
iteration_num_ = 1;
ModelStore::GetInstance().Reset();
}
Server::GetInstance().CancelSafeMode();
SetIterationCompleted();
LocalMetaStore::GetInstance().set_curr_iter_num(iteration_num_);
MS_LOG(INFO) << "Move to next iteration:" << iteration_num_ << "\n";
}
} // namespace server
} // namespace ps
} // namespace mindspore

View File

@ -19,6 +19,7 @@
#include <memory>
#include <vector>
#include <string>
#include "ps/core/communicator/communicator_base.h"
#include "ps/server/common.h"
#include "ps/server/round.h"
@ -34,6 +35,9 @@ enum class IterationState {
kCompleted
};
// The retry interval (in milliseconds) when sending the prepare-for-next-iteration request fails.
constexpr uint32_t kRetryDurationForPrepareForNextIter = 500;
// In the server's logic, Iteration is the minimum execution unit. Each iteration consists of multiple kinds of
// Rounds; only after all the rounds are finished is this iteration considered complete.
class Iteration {
@ -56,9 +60,10 @@ class Iteration {
void InitRounds(const std::vector<std::shared_ptr<core::CommunicatorBase>> &communicators,
const TimeOutCb &timeout_cb, const FinishIterCb &finish_iteration_cb);
// The server proceeds to the next iteration only after the last round finishes or the timer expires.
// If the timer expires, we consider this iteration as invalid.
void ProceedToNextIter(bool is_iteration_valid);
// This method will control servers to proceed to next iteration.
// There's communication between leader and follower servers in this method.
// The server moves to the next iteration only after the last round finishes or the timer expires.
void MoveToNextIteration(bool is_last_iter_valid, const std::string &reason);
// Set current iteration state to running and trigger events about kIterationRunning.
void SetIterationRunning();
@ -84,7 +89,8 @@ class Iteration {
communicator_(nullptr),
iteration_state_(IterationState::kCompleted),
iteration_num_(1),
is_last_iteration_valid_(true) {
is_last_iteration_valid_(true),
pinned_iter_num_(0) {
LocalMetaStore::GetInstance().set_curr_iter_num(iteration_num_);
}
~Iteration() = default;
@ -99,6 +105,32 @@ class Iteration {
bool SyncIteration(uint32_t rank);
void HandleSyncIterationRequest(const std::shared_ptr<core::MessageHandler> &message);
// The request for moving to next iteration is not reentrant.
bool IsMoveToNextIterRequestReentrant(uint64_t iteration_num);
// The methods for moving to next iteration for all the servers.
// Step 1: follower servers notify leader server that they need to move to next iteration.
bool NotifyLeaderMoveToNextIteration(bool is_last_iter_valid, const std::string &reason);
void HandleNotifyLeaderMoveToNextIterRequest(const std::shared_ptr<core::MessageHandler> &message);
// Step 2: the leader server broadcasts to all follower servers to prepare for next iteration and switch to safemode.
bool BroadcastPrepareForNextIterRequest(bool is_last_iter_valid, const std::string &reason);
void HandlePrepareForNextIterRequest(const std::shared_ptr<core::MessageHandler> &message);
// The server prepares for the next iteration. This method will switch the server to safemode.
void PrepareForNextIter();
// Step 3: the leader server broadcasts to all follower servers to move to next iteration.
bool BroadcastMoveToNextIterRequest(bool is_last_iter_valid, const std::string &reason);
void HandleMoveToNextIterRequest(const std::shared_ptr<core::MessageHandler> &message);
// Move to next iteration. Store the last iteration's model and reset all the rounds.
void Next(bool is_iteration_valid, const std::string &reason);
// Step 4: leader server broadcasts to all follower servers to end last iteration and cancel the safemode.
bool BroadcastEndLastIterRequest(uint64_t iteration_num);
void HandleEndLastIterRequest(const std::shared_ptr<core::MessageHandler> &message);
// The server ends the last iteration. This method will increase the iteration number and cancel the safemode.
void EndLastIter();
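// Putting the four steps together, a condensed sketch of the flow implemented in iteration.cc:
//   follower: NotifyLeaderMoveToNextIteration  -> leader: HandleNotifyLeaderMoveToNextIterRequest
//   leader: BroadcastPrepareForNextIterRequest -> every server switches to safemode
//   leader: BroadcastMoveToNextIterRequest     -> every server stores the model and resets rounds
//   leader: BroadcastEndLastIterRequest        -> every server bumps the iteration and cancels safemode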
std::shared_ptr<core::ServerNode> server_node_;
std::shared_ptr<core::TcpCommunicator> communicator_;
@ -113,6 +145,10 @@ class Iteration {
// Last iteration is successfully finished.
bool is_last_iteration_valid_;
// To avoid the Next method being called multiple times in one iteration, we mark the iteration number.
uint64_t pinned_iter_num_;
std::mutex pinned_mtx_;
};
} // namespace server
} // namespace ps

View File

@ -29,7 +29,7 @@ void IterationTimer::Start(const std::chrono::milliseconds &duration) {
monitor_thread_ = std::thread([&]() {
while (running_.load()) {
if (CURRENT_TIME_MILLI > end_time_) {
timeout_callback_(false);
timeout_callback_(false, "");
running_ = false;
}
// The time tick is 1 millisecond.

View File

@ -62,6 +62,7 @@ bool GetModelKernel::Reset() {
}
void GetModelKernel::GetModel(const schema::RequestGetModel *get_model_req, const std::shared_ptr<FBBuilder> &fbb) {
auto next_req_time = LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp);
std::map<std::string, AddressPtr> feature_maps;
size_t current_iter = LocalMetaStore::GetInstance().curr_iter_num();
size_t get_model_iter = static_cast<size_t>(get_model_req->iteration());
@ -70,9 +71,11 @@ void GetModelKernel::GetModel(const schema::RequestGetModel *get_model_req, cons
// If this iteration is not finished yet, return ResponseCode_SucNotReady so that clients could get model later.
if ((current_iter == get_model_iter && latest_iter_num != current_iter) || current_iter == get_model_iter - 1) {
std::string reason = "The model is not ready yet for iteration " + std::to_string(get_model_iter);
std::string reason = "The model is not ready yet for iteration " + std::to_string(get_model_iter) +
". Maybe this is because\n" + "1.Client doesn't send enough update model requests.\n" +
"2. Worker has not push all the weights to servers.";
BuildGetModelRsp(fbb, schema::ResponseCode_SucNotReady, reason, current_iter, feature_maps,
std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)));
std::to_string(next_req_time));
MS_LOG(WARNING) << reason;
return;
}
@ -88,11 +91,12 @@ void GetModelKernel::GetModel(const schema::RequestGetModel *get_model_req, cons
// If the iteration of this model is invalid, return ResponseCode_OutOfTime so that the clients could startFLJob
// according to next_req_time.
auto response_code =
Iteration::GetInstance().is_last_iteration_valid() ? schema::ResponseCode_SUCCEED : schema::ResponseCode_OutOfTime;
bool last_iter_valid = Iteration::GetInstance().is_last_iteration_valid();
MS_LOG(INFO) << "GetModel last iteration is valid or not: " << last_iter_valid << ", next request time is "
<< next_req_time << ", current iteration is " << current_iter;
auto response_code = last_iter_valid ? schema::ResponseCode_SUCCEED : schema::ResponseCode_OutOfTime;
BuildGetModelRsp(fbb, response_code, "Get model for iteration " + std::to_string(get_model_iter), current_iter,
feature_maps,
std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)));
feature_maps, std::to_string(next_req_time));
return;
}

View File

@ -48,9 +48,9 @@ bool PushWeightKernel::Launch(const std::vector<AddressPtr> &inputs, const std::
return false;
}
PushWeight(fbb, push_weight_req);
bool ret = PushWeight(fbb, push_weight_req);
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
return true;
return ret;
}
bool PushWeightKernel::Reset() {
@ -67,9 +67,9 @@ void PushWeightKernel::OnLastCountEvent(const std::shared_ptr<core::MessageHandl
return;
}
void PushWeightKernel::PushWeight(std::shared_ptr<FBBuilder> fbb, const schema::RequestPushWeight *push_weight_req) {
bool PushWeightKernel::PushWeight(std::shared_ptr<FBBuilder> fbb, const schema::RequestPushWeight *push_weight_req) {
if (fbb == nullptr || push_weight_req == nullptr) {
return;
return false;
}
size_t iteration = static_cast<size_t>(push_weight_req->iteration());
size_t current_iter = LocalMetaStore::GetInstance().curr_iter_num();
@ -77,8 +77,8 @@ void PushWeightKernel::PushWeight(std::shared_ptr<FBBuilder> fbb, const schema::
std::string reason = "PushWeight iteration number is invalid:" + std::to_string(iteration) +
", current iteration:" + std::to_string(current_iter);
BuildPushWeightRsp(fbb, schema::ResponseCode_OutOfTime, reason, current_iter);
MS_LOG(ERROR) << reason;
return;
MS_LOG(WARNING) << reason;
return true;
}
std::map<std::string, Address> upload_feature_map = ParseFeatureMap(push_weight_req);
@ -86,20 +86,25 @@ void PushWeightKernel::PushWeight(std::shared_ptr<FBBuilder> fbb, const schema::
std::string reason = "PushWeight feature_map is empty.";
BuildPushWeightRsp(fbb, schema::ResponseCode_RequestError, reason, current_iter);
MS_LOG(ERROR) << reason;
return;
return false;
}
if (!executor_->HandlePushWeight(upload_feature_map)) {
std::string reason = "Pushing weight failed.";
BuildPushWeightRsp(fbb, schema::ResponseCode_SystemError, reason, current_iter);
MS_LOG(ERROR) << reason;
return;
return false;
}
MS_LOG(INFO) << "Pushing weight for iteration " << current_iter << " succeeds.";
DistributedCountService::GetInstance().Count(name_, std::to_string(local_rank_));
if (!DistributedCountService::GetInstance().Count(name_, std::to_string(local_rank_))) {
std::string reason = "Count for push weight request failed.";
BuildPushWeightRsp(fbb, schema::ResponseCode_SystemError, reason, current_iter);
MS_LOG(ERROR) << reason;
return false;
}
BuildPushWeightRsp(fbb, schema::ResponseCode_SUCCEED, "PushWeight succeeded.", current_iter);
return;
return true;
}
std::map<std::string, Address> PushWeightKernel::ParseFeatureMap(const schema::RequestPushWeight *push_weight_req) {

View File

@ -42,7 +42,7 @@ class PushWeightKernel : public RoundKernel {
void OnLastCountEvent(const std::shared_ptr<core::MessageHandler> &message) override;
private:
void PushWeight(std::shared_ptr<FBBuilder> fbb, const schema::RequestPushWeight *push_weight_req);
bool PushWeight(std::shared_ptr<FBBuilder> fbb, const schema::RequestPushWeight *push_weight_req);
std::map<std::string, Address> ParseFeatureMap(const schema::RequestPushWeight *push_weight_req);
void BuildPushWeightRsp(std::shared_ptr<FBBuilder> fbb, const schema::ResponseCode retcode, const std::string &reason,
size_t iteration);

View File

@ -68,7 +68,7 @@ void RoundKernel::StopTimer() {
void RoundKernel::FinishIteration() {
if (finish_iteration_cb_) {
finish_iteration_cb_(true);
finish_iteration_cb_(true, "");
}
return;
}

View File

@ -61,7 +61,7 @@ bool StartFLJobKernel::Launch(const std::vector<AddressPtr> &inputs, const std::
if (ReachThresholdForStartFLJob(fbb)) {
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
return false;
return true;
}
const schema::RequestFLJob *start_fl_job_req = flatbuffers::GetRoot<schema::RequestFLJob>(req_data);
@ -102,7 +102,7 @@ bool StartFLJobKernel::ReachThresholdForStartFLJob(const std::shared_ptr<FBBuild
BuildStartFLJobRsp(
fbb, schema::ResponseCode_OutOfTime, reason, false,
std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)));
MS_LOG(ERROR) << reason;
MS_LOG(WARNING) << reason;
return true;
}
return false;

View File

@ -18,11 +18,13 @@
#include <memory>
#include <string>
#include "ps/server/server.h"
#include "ps/server/iteration.h"
namespace mindspore {
namespace ps {
namespace server {
class Server;
class Iteration;
Round::Round(const std::string &name, bool check_timeout, size_t time_window, bool check_count, size_t threshold_count,
bool server_num_as_threshold)
: name_(name),
@ -42,9 +44,9 @@ void Round::Initialize(const std::shared_ptr<core::CommunicatorBase> &communicat
name_, [&](std::shared_ptr<core::MessageHandler> message) { LaunchRoundKernel(message); });
// Callback when the iteration is finished.
finish_iteration_cb_ = [this, finish_iteration_cb](bool is_iteration_valid) -> void {
MS_LOG(INFO) << "Round " << name_ << " finished! This iteration is valid. Proceed to next iteration.";
finish_iteration_cb(is_iteration_valid);
finish_iteration_cb_ = [this, finish_iteration_cb](bool is_iteration_valid, const std::string &) -> void {
std::string reason = "Round " + name_ + " finished! This iteration is valid. Proceed to next iteration.";
finish_iteration_cb(is_iteration_valid, reason);
};
// Callback for finalizing the server. This can only be called once.
@ -54,9 +56,9 @@ void Round::Initialize(const std::shared_ptr<core::CommunicatorBase> &communicat
iter_timer_ = std::make_shared<IterationTimer>();
// 1.Set the timeout callback for the timer.
iter_timer_->SetTimeOutCallBack([this, timeout_cb](bool is_iteration_valid) -> void {
MS_LOG(INFO) << "Round " << name_ << " timeout! This iteration is invalid. Proceed to next iteration.";
timeout_cb(is_iteration_valid);
iter_timer_->SetTimeOutCallBack([this, timeout_cb](bool is_iteration_valid, const std::string &) -> void {
std::string reason = "Round " + name_ + " timeout! This iteration is invalid. Proceed to next iteration.";
timeout_cb(is_iteration_valid, reason);
});
// 2.Stopping timer callback which will be set to the round kernel.
@ -89,7 +91,7 @@ bool Round::ReInitForScaling(uint32_t server_num) {
}
if (kernel_ == nullptr) {
MS_LOG(ERROR) << "Reinitializing for round " << name_ << " failed: round kernel is nullptr.";
MS_LOG(WARNING) << "Reinitializing for round " << name_ << " failed: round kernel is nullptr.";
return false;
}
kernel_->InitKernel(threshold_count_);
@ -129,13 +131,14 @@ void Round::LaunchRoundKernel(const std::shared_ptr<core::MessageHandler> &messa
communicator_->SendResponse(reason.c_str(), reason.size(), message);
return;
}
communicator_->SendResponse(output->addr, output->size, message);
kernel_->Release(output);
// Must send the response back no matter what value the Launch method returns.
if (!ret) {
MS_LOG(WARNING) << "Launching round kernel of round " << name_ << " failed.";
std::string reason = "Launching round kernel of round " + name_ + " failed.";
Iteration::GetInstance().MoveToNextIteration(false, reason);
}
communicator_->SendResponse(output->addr, output->size, message);
kernel_->Release(output);
return;
}

View File

@ -73,6 +73,7 @@ void Server::Initialize(bool use_tcp, bool use_http, uint16_t http_port, const s
// InitCipher---->InitExecutor
void Server::Run() {
signal(SIGINT, SignalHandler);
std::unique_lock<std::mutex> lock(scaling_mtx_);
InitServerContext();
InitCluster();
InitIteration();
@ -82,6 +83,7 @@ void Server::Run() {
RegisterRoundKernel();
MS_LOG(INFO) << "Server started successfully.";
safemode_ = false;
lock.unlock();
// Wait for the communicators to stop so that the main thread is blocked.
std::for_each(communicators_with_worker_.begin(), communicators_with_worker_.end(),
@ -91,6 +93,16 @@ void Server::Run() {
return;
}
void Server::SwitchToSafeMode() {
MS_LOG(INFO) << "Server switch to safemode.";
safemode_ = true;
}
void Server::CancelSafeMode() {
MS_LOG(INFO) << "Server cancel safemode.";
safemode_ = false;
}
bool Server::IsSafeMode() { return safemode_.load(); }
void Server::InitServerContext() {
@ -166,8 +178,10 @@ void Server::InitIteration() {
}
// 2.Initialize all the rounds.
TimeOutCb time_out_cb = std::bind(&Iteration::ProceedToNextIter, iteration_, std::placeholders::_1);
FinishIterCb finish_iter_cb = std::bind(&Iteration::ProceedToNextIter, iteration_, std::placeholders::_1);
TimeOutCb time_out_cb =
std::bind(&Iteration::MoveToNextIteration, iteration_, std::placeholders::_1, std::placeholders::_2);
FinishIterCb finish_iter_cb =
std::bind(&Iteration::MoveToNextIteration, iteration_, std::placeholders::_1, std::placeholders::_2);
iteration_->InitRounds(communicators_with_worker_, time_out_cb, finish_iter_cb);
return;
}
@ -288,28 +302,29 @@ void Server::ProcessBeforeScalingIn() {
}
void Server::ProcessAfterScalingOut() {
std::unique_lock<std::mutex> lock(scaling_mtx_);
if (server_node_ == nullptr) {
return;
}
if (!DistributedMetadataStore::GetInstance().ReInitForScaling()) {
MS_LOG(ERROR) << "DistributedMetadataStore reinitializing failed.";
MS_LOG(WARNING) << "DistributedMetadataStore reinitializing failed.";
return;
}
if (!CollectiveOpsImpl::GetInstance().ReInitForScaling()) {
MS_LOG(ERROR) << "DistributedMetadataStore reinitializing failed.";
MS_LOG(WARNING) << "DistributedMetadataStore reinitializing failed.";
return;
}
if (!DistributedCountService::GetInstance().ReInitForScaling()) {
MS_LOG(ERROR) << "DistributedCountService reinitializing failed.";
MS_LOG(WARNING) << "DistributedCountService reinitializing failed.";
return;
}
if (!iteration_->ReInitForScaling(IntToUint(server_node_->server_num()), server_node_->rank_id())) {
MS_LOG(ERROR) << "Iteration reinitializing failed.";
MS_LOG(WARNING) << "Iteration reinitializing failed.";
return;
}
if (!Executor::GetInstance().ReInitForScaling()) {
MS_LOG(ERROR) << "Executor reinitializing failed.";
MS_LOG(WARNING) << "Executor reinitializing failed.";
return;
}
std::this_thread::sleep_for(std::chrono::milliseconds(1000));
@ -317,6 +332,7 @@ void Server::ProcessAfterScalingOut() {
}
void Server::ProcessAfterScalingIn() {
std::unique_lock<std::mutex> lock(scaling_mtx_);
if (server_node_ == nullptr) {
return;
}
@ -331,23 +347,23 @@ void Server::ProcessAfterScalingIn() {
// If the server is not the one to be scaled in, reinitialize modules and recover service.
if (!DistributedMetadataStore::GetInstance().ReInitForScaling()) {
MS_LOG(ERROR) << "DistributedMetadataStore reinitializing failed.";
MS_LOG(WARNING) << "DistributedMetadataStore reinitializing failed.";
return;
}
if (!CollectiveOpsImpl::GetInstance().ReInitForScaling()) {
MS_LOG(ERROR) << "DistributedMetadataStore reinitializing failed.";
MS_LOG(WARNING) << "DistributedMetadataStore reinitializing failed.";
return;
}
if (!DistributedCountService::GetInstance().ReInitForScaling()) {
MS_LOG(ERROR) << "DistributedCountService reinitializing failed.";
MS_LOG(WARNING) << "DistributedCountService reinitializing failed.";
return;
}
if (!iteration_->ReInitForScaling(IntToUint(server_node_->server_num()), server_node_->rank_id())) {
MS_LOG(ERROR) << "Iteration reinitializing failed.";
MS_LOG(WARNING) << "Iteration reinitializing failed.";
return;
}
if (!Executor::GetInstance().ReInitForScaling()) {
MS_LOG(ERROR) << "Executor reinitializing failed.";
MS_LOG(WARNING) << "Executor reinitializing failed.";
return;
}
std::this_thread::sleep_for(std::chrono::milliseconds(1000));

View File

@ -46,6 +46,8 @@ class Server {
// func_graph is the frontend graph which will be parsed in server's executor and aggregator.
void Run();
void SwitchToSafeMode();
void CancelSafeMode();
bool IsSafeMode();
private:
@ -134,6 +136,9 @@ class Server {
// communicators.
std::vector<std::shared_ptr<core::CommunicatorBase>> communicators_with_worker_;
// Mutex for scaling operations. We must wait until the server's initialization is done before handling scaling events.
std::mutex scaling_mtx_;
// Iteration consists of multiple kinds of rounds.
Iteration *iteration_;

View File

@ -67,12 +67,23 @@ bool FLWorker::SendToServer(uint32_t server_rank, void *data, size_t size, core:
}
if (output != nullptr) {
do {
while (true) {
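// Keep resending while the server reports safemode; bail out on a send failure or an empty response.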
if (!worker_node_->Send(core::NodeRole::SERVER, server_rank, message, size, static_cast<int>(command), output)) {
MS_LOG(ERROR) << "Sending message to server " << server_rank << " failed.";
return false;
}
} while (std::string(reinterpret_cast<char *>((*output)->data()), (*output)->size()) == kClusterSafeMode);
if (*output == nullptr) {
MS_LOG(WARNING) << "Response from server " << server_rank << " is empty.";
return false;
}
if (std::string(reinterpret_cast<char *>((*output)->data()), (*output)->size()) == kClusterSafeMode) {
MS_LOG(INFO) << "The server " << server_rank << " is in safemode.";
std::this_thread::sleep_for(std::chrono::milliseconds(kWorkerRetryDurationForSafeMode));
} else {
break;
}
}
} else {
if (!worker_node_->Send(core::NodeRole::SERVER, server_rank, message, size, static_cast<int>(command))) {
MS_LOG(ERROR) << "Sending message to server " << server_rank << " failed.";
@ -88,6 +99,16 @@ uint32_t FLWorker::worker_num() const { return worker_num_; }
uint64_t FLWorker::worker_step_num_per_iteration() const { return worker_step_num_per_iteration_; }
void FLWorker::SetIterationRunning() {
MS_LOG(INFO) << "Worker iteration starts.";
worker_iteration_state_ = IterationState::kRunning;
}
void FLWorker::SetIterationCompleted() {
MS_LOG(INFO) << "Worker iteration completes.";
worker_iteration_state_ = IterationState::kCompleted;
}
void FLWorker::InitializeFollowerScaler() {
if (!worker_node_->InitFollowerScaler()) {
MS_LOG(EXCEPTION) << "Initializing follower elastic scaler failed.";
@ -112,21 +133,22 @@ void FLWorker::InitializeFollowerScaler() {
}
void FLWorker::HandleIterationRunningEvent() {
MS_LOG(INFO) << "Worker iteration starts, safemode is " << safemode_.load();
iteration_state_ = IterationState::kRunning;
MS_LOG(INFO) << "Server iteration starts, safemode is " << safemode_.load();
server_iteration_state_ = IterationState::kRunning;
if (safemode_.load() == true) {
safemode_ = false;
}
}
void FLWorker::HandleIterationCompletedEvent() {
MS_LOG(INFO) << "Worker iteration completes";
iteration_state_ = IterationState::kCompleted;
MS_LOG(INFO) << "Server iteration completes";
server_iteration_state_ = IterationState::kCompleted;
}
void FLWorker::ProcessBeforeScalingOut() {
MS_LOG(INFO) << "Starting Worker scaling out barrier.";
while (iteration_state_.load() != IterationState::kCompleted) {
while (server_iteration_state_.load() != IterationState::kCompleted ||
worker_iteration_state_.load() != IterationState::kCompleted) {
std::this_thread::yield();
}
MS_LOG(INFO) << "Ending Worker scaling out barrier. Switch to safemode.";
@ -135,7 +157,8 @@ void FLWorker::ProcessBeforeScalingOut() {
void FLWorker::ProcessBeforeScalingIn() {
MS_LOG(INFO) << "Starting Worker scaling in barrier.";
while (iteration_state_.load() != IterationState::kCompleted) {
while (server_iteration_state_.load() != IterationState::kCompleted ||
worker_iteration_state_.load() != IterationState::kCompleted) {
std::this_thread::yield();
}
MS_LOG(INFO) << "Ending Worker scaling in barrier. Switch to safemode.";
@ -148,9 +171,6 @@ void FLWorker::ProcessAfterScalingOut() {
}
MS_LOG(INFO) << "Cluster scaling out completed. Reinitialize for worker.";
while (iteration_state_.load() != IterationState::kCompleted) {
std::this_thread::yield();
}
server_num_ = worker_node_->server_num();
worker_num_ = worker_node_->worker_num();
MS_LOG(INFO) << "After scheduler scaling out, worker number is " << worker_num_ << ", server number is "
@ -165,9 +185,6 @@ void FLWorker::ProcessAfterScalingIn() {
}
MS_LOG(INFO) << "Cluster scaling in completed. Reinitialize for worker.";
while (iteration_state_.load() != IterationState::kCompleted) {
std::this_thread::yield();
}
server_num_ = worker_node_->server_num();
worker_num_ = worker_node_->worker_num();
MS_LOG(INFO) << "After scheduler scaling in, worker number is " << worker_num_ << ", server number is " << server_num_

View File

@ -40,6 +40,9 @@ constexpr uint32_t kTrainEndStepNum = 0;
// The worker has to sleep for a while until the cluster networking is completed.
constexpr uint32_t kWorkerSleepTimeForNetworking = 1000;
// The retry interval (in milliseconds) when the server is in safemode.
constexpr uint32_t kWorkerRetryDurationForSafeMode = 500;
enum class IterationState {
// This iteration is still in process.
kRunning,
@ -64,6 +67,10 @@ class FLWorker {
uint32_t worker_num() const;
uint64_t worker_step_num_per_iteration() const;
// These methods set the worker's iteration state.
void SetIterationRunning();
void SetIterationCompleted();
private:
FLWorker()
: server_num_(0),
@ -72,7 +79,8 @@ class FLWorker {
scheduler_port_(0),
worker_node_(nullptr),
worker_step_num_per_iteration_(1),
iteration_state_(IterationState::kCompleted),
server_iteration_state_(IterationState::kCompleted),
worker_iteration_state_(IterationState::kCompleted),
safemode_(false) {}
~FLWorker() = default;
FLWorker(const FLWorker &) = delete;
@ -104,9 +112,14 @@ class FLWorker {
uint64_t worker_step_num_per_iteration_;
// The iteration state is either running or completed.
std::atomic<IterationState> iteration_state_;
// This variable represents the server iteration state and should be changed by the
// kIterationRunning/kIterationCompleted events triggered by the server.
std::atomic<IterationState> server_iteration_state_;
// The flag that represents whether worker is in safemode.
// This variable represents the worker iteration state and should be changed by the worker training process.
std::atomic<IterationState> worker_iteration_state_;
// The flag that represents whether worker is in safemode, which is decided by both worker and server iteration state.
std::atomic_bool safemode_;
};
} // namespace worker

View File

@ -828,18 +828,19 @@ def set_fl_context(**kwargs):
Default: 'FEDERATED_LEARNING'.
ms_role (string): The process's role in the federated learning mode,
which must be one of 'MS_SERVER', 'MS_WORKER' and 'MS_SCHED'.
Default: 'MS_NOT_PS'.
worker_num (int): The number of workers. Default: 0.
Default: 'MS_SERVER'.
worker_num (int): The number of workers. For the current version, this must be set to 1 or 0.
server_num (int): The number of federated learning servers. Default: 0.
scheduler_ip (string): The scheduler IP. Default: ''.
scheduler_port (int): The scheduler port. Default: 0.
scheduler_ip (string): The scheduler IP. Default: '0.0.0.0'.
scheduler_port (int): The scheduler port. Default: 6667.
fl_server_port (int): The http port of the federated learning server.
Normally for each server this should be set to the same value. Default: 0.
Normally for each server this should be set to the same value. Default: 6668.
enable_fl_client (bool): Whether this process is federated learning client. Default: False.
start_fl_job_threshold (int): The threshold count of startFLJob. Default: 0.
start_fl_job_time_window (int): The time window duration for startFLJob in millisecond. Default: 3000.
update_model_ratio (float): The ratio for computing the threshold count of updateModel
which will be multiplied by start_fl_job_threshold. Default: 1.0.
which will be multiplied by start_fl_job_threshold.
Must be between 0 and 1.0. Default: 1.0.
update_model_time_window (int): The time window duration for updateModel in millisecond. Default: 3000.
fl_name (string): The federated learning job name. Default: ''.
fl_iteration_num (int): Iteration number of federated learning,
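A minimal usage sketch for a server process, assuming set_fl_context is exposed via
mindspore.context as in this file (values are illustrative placeholders, not recommendations):
from mindspore import context
context.set_fl_context(server_mode="FEDERATED_LEARNING", ms_role="MS_SERVER",
worker_num=1, server_num=1, scheduler_ip="0.0.0.0", scheduler_port=6667,
fl_server_port=6668, start_fl_job_threshold=1, update_model_ratio=1.0)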

View File

@ -15,10 +15,20 @@
"""Context for parameter server training mode"""
import os
from mindspore._checkparam import Validator
from mindspore._c_expression import PSContext
_ps_context = None
_check_positive_int_keys = ["server_num", "scheduler_port", "fl_server_port",
"start_fl_job_threshold", "start_fl_job_time_window", "update_model_time_window",
"fl_iteration_num", "client_epoch_num", "client_batch_size", "scheduler_manage_port"]
_check_non_negative_int_keys = ["worker_num"]
_check_positive_float_keys = ["update_model_ratio", "client_learning_rate"]
_check_port_keys = ["scheduler_port", "fl_server_port", "scheduler_manage_port"]
def ps_context():
"""
@ -181,3 +191,20 @@ def _set_cache_enable(cache_enable):
def _set_rank_id(rank_id):
ps_context().set_rank_id(rank_id)
def _check_value(key, value):
"""
Validate the value for parameter server context keys.
"""
if key in _check_positive_int_keys:
Validator.check_positive_int(value, key)
if key in _check_non_negative_int_keys:
Validator.check_non_negative_int(value, key)
if key in _check_positive_float_keys:
Validator.check_positive_float(value, key)
if key in _check_port_keys:
if value < 1 or value > 65535:
raise ValueError("The range of %s must be 1 to 65535, but got %d." % (key, value))

View File

@ -163,6 +163,9 @@ while True:
rsp_fl_job = ResponseFLJob.ResponseFLJob.GetRootAsResponseFLJob(x.content, 0)
while rsp_fl_job.Retcode() != ResponseCode.ResponseCode.SUCCEED:
x = session.post(url1, data=build_start_fl_job(current_iteration))
while x.text == "The cluster is in safemode.":
time.sleep(0.2)
x = session.post(url1, data=build_start_fl_job(current_iteration))
rsp_fl_job = ResponseFLJob.ResponseFLJob.GetRootAsResponseFLJob(x.content, 0)
print("epoch is", rsp_fl_job.FlPlanConfig().Epochs())
print("iteration is", rsp_fl_job.Iteration())
@ -173,6 +176,10 @@ while True:
print("req update model iteration:", current_iteration, ", id:", args.pid)
update_model_buf, update_model_np_data = build_update_model(current_iteration)
x = session.post(url2, data=update_model_buf)
while x.text == "The cluster is in safemode.":
time.sleep(0.2)
x = session.post(url2, data=update_model_buf)
print("rsp update model iteration:", current_iteration, ", id:", args.pid)
sys.stdout.flush()
@ -227,4 +234,5 @@ while True:
# Sleep to the next request timestamp
current_ts = datetime_to_timestamp(datetime.datetime.now())
duration = next_req_timestamp - current_ts
time.sleep(duration / 1000)
if duration > 0:
time.sleep(duration / 1000)