!18803 Sync enter bugfix for fail over

Merge pull request !18803 from ZPaC/sync-from-enter
This commit is contained in:
i-robot 2021-06-24 12:25:52 +00:00 committed by Gitee
commit 26e34dec80
21 changed files with 537 additions and 292 deletions

View File

@ -73,8 +73,9 @@ class FusedPullWeightKernel : public CPUKernel {
while (retcode == schema::ResponseCode_SucNotReady) {
if (!ps::worker::FLWorker::GetInstance().SendToServer(
0, fbb->GetBufferPointer(), fbb->GetSize(), ps::core::TcpUserCommand::kPullWeight, &pull_weight_rsp_msg)) {
MS_LOG(EXCEPTION) << "Sending request for FusedPullWeight to server 0 failed.";
return false;
MS_LOG(WARNING) << "Sending request for FusedPullWeight to server 0 failed. This iteration is dropped.";
ps::worker::FLWorker::GetInstance().SetIterationRunning();
return true;
}
MS_EXCEPTION_IF_NULL(pull_weight_rsp_msg);
@ -82,6 +83,13 @@ class FusedPullWeightKernel : public CPUKernel {
retcode = pull_weight_rsp->retcode();
if (retcode == schema::ResponseCode_SucNotReady) {
std::this_thread::sleep_for(std::chrono::milliseconds(kRetryDurationOfPullWeights));
fl_iteration_ = pull_weight_rsp->iteration();
MS_LOG(DEBUG) << "Server is not ready for downloading yet. Reason: " << pull_weight_rsp->reason()->str()
<< ". Retry later.";
if (!BuildPullWeightReq(fbb)) {
MS_LOG(EXCEPTION) << "Building request for FusedDownloadWeightsByKeys failed.";
return false;
}
continue;
} else if (retcode != schema::ResponseCode_SUCCEED) {
MS_LOG(EXCEPTION) << "FusedPullWeight failed. Server return code: " << pull_weight_rsp->retcode()

View File

@ -28,6 +28,8 @@
namespace mindspore {
namespace kernel {
// The duration between two uploading requests when return code is ResponseCode_SucNotReady.
constexpr int kRetryDurationOfPushWeights = 200;
template <typename T>
class FusedPushWeightKernel : public CPUKernel {
public:
@ -66,22 +68,41 @@ class FusedPushWeightKernel : public CPUKernel {
// The server number may change after scaling in/out.
for (uint32_t i = 0; i < ps::worker::FLWorker::GetInstance().server_num(); i++) {
std::shared_ptr<std::vector<unsigned char>> push_weight_rsp_msg = nullptr;
if (!ps::worker::FLWorker::GetInstance().SendToServer(
i, fbb->GetBufferPointer(), fbb->GetSize(), ps::core::TcpUserCommand::kPushWeight, &push_weight_rsp_msg)) {
MS_LOG(ERROR) << "Sending request for FusedPushWeight to server " << i << " failed.";
continue;
}
MS_EXCEPTION_IF_NULL(push_weight_rsp_msg);
const schema::ResponsePushWeight *push_weight_rsp = nullptr;
int retcode = schema::ResponseCode_SucNotReady;
while (retcode == schema::ResponseCode_SucNotReady) {
if (!ps::worker::FLWorker::GetInstance().SendToServer(i, fbb->GetBufferPointer(), fbb->GetSize(),
ps::core::TcpUserCommand::kPushWeight,
&push_weight_rsp_msg)) {
MS_LOG(WARNING) << "Sending request for FusedPushWeight to server " << i
<< " failed. This iteration is dropped.";
ps::worker::FLWorker::GetInstance().SetIterationCompleted();
return true;
}
MS_EXCEPTION_IF_NULL(push_weight_rsp_msg);
const schema::ResponsePushWeight *push_weight_rsp =
flatbuffers::GetRoot<schema::ResponsePushWeight>(push_weight_rsp_msg->data());
auto retcode = push_weight_rsp->retcode();
if (retcode != schema::ResponseCode_SUCCEED) {
MS_LOG(EXCEPTION) << "FusedPushWeight failed. Server return code: " << push_weight_rsp->retcode()
<< ", reason: " << push_weight_rsp->reason()->str();
return false;
push_weight_rsp = flatbuffers::GetRoot<schema::ResponsePushWeight>(push_weight_rsp_msg->data());
retcode = push_weight_rsp->retcode();
if (retcode == schema::ResponseCode_SucNotReady) {
std::this_thread::sleep_for(std::chrono::milliseconds(kRetryDurationOfPushWeights));
fl_iteration_ = push_weight_rsp->iteration();
MS_LOG(DEBUG) << "Server is not ready for pushing weight yet. Reason: " << push_weight_rsp->reason()->str()
<< ". Retry later.";
if (!BuildPushWeightReq(fbb, inputs)) {
MS_LOG(EXCEPTION) << "Building request for FusedPushWeight failed.";
return false;
}
continue;
} else if (retcode != schema::ResponseCode_SUCCEED) {
MS_LOG(EXCEPTION) << "FusedPushWeight failed. Server return code: " << push_weight_rsp->retcode()
<< ", reason: " << push_weight_rsp->reason()->str();
return false;
} else {
MS_LOG(DEBUG) << "FusedPushWeight succeed.";
}
}
}
MS_LOG(INFO) << "Push weights for " << weight_full_names_ << " succeed. Iteration: " << fl_iteration_;
ps::worker::FLWorker::GetInstance().SetIterationCompleted();
return true;

View File

@ -60,6 +60,7 @@
#include "ps/constants.h"
#include "ps/util.h"
#include "ps/worker.h"
#include "ps/worker/fl_worker.h"
#include "ps/ps_cache/ps_data/ps_data_prefetch.h"
#include "ps/ps_cache/ps_cache_manager.h"
#include "ps/server/server.h"
@ -1254,8 +1255,13 @@ void ClearResAtexit() {
if (ps::PsDataPrefetch::GetInstance().cache_enable()) {
ps::ps_cache_instance.Finalize();
}
MS_LOG(INFO) << "ps::worker.Finalize";
ps::Worker::GetInstance().Finalize();
MS_LOG(INFO) << "Start finalizing worker.";
const std::string &server_mode = ps::PSContext::instance()->server_mode();
if ((server_mode == ps::kServerModeFL || server_mode == ps::kServerModeHybrid)) {
ps::worker::FLWorker::GetInstance().Finalize();
} else {
ps::Worker::GetInstance().Finalize();
}
}
#endif
ad::g_k_prims.clear();

View File

@ -41,6 +41,8 @@ void CommunicatorBase::Join() {
running_thread_.join();
return;
}
// Accessor for this communicator's running_ flag; returns its current value unchanged.
bool CommunicatorBase::running() const {
  return running_;
}
} // namespace core
} // namespace ps
} // namespace mindspore

View File

@ -56,6 +56,8 @@ class CommunicatorBase {
bool SendResponse(const void *rsp_data, size_t rsp_len, std::shared_ptr<MessageHandler> msg_handler);
bool running() const;
protected:
std::unordered_map<std::string, MessageCallback> msg_callbacks_;
std::thread running_thread_;

View File

@ -104,12 +104,12 @@ class TcpCommunicator : public CommunicatorBase {
if (output != nullptr) {
if (!server_node_->Send(NodeRole::SERVER, rank_id, msg, msg_str.size(), static_cast<int>(command), output)) {
MS_LOG(ERROR) << "Query leader server whether count is enough failed.";
MS_LOG(ERROR) << "Sending protobuffer message to server " << rank_id << " failed.";
return false;
}
} else {
if (!server_node_->Send(NodeRole::SERVER, rank_id, msg, msg_str.size(), static_cast<int>(command))) {
MS_LOG(ERROR) << "Query leader server whether count is enough failed.";
MS_LOG(ERROR) << "Sending protobuffer message to server " << rank_id << " failed.";
return false;
}
}

View File

@ -51,7 +51,7 @@ void DistributedCountService::RegisterCounter(const std::string &name, size_t gl
return;
}
if (global_threshold_count_.count(name) != 0) {
MS_LOG(WARNING) << "Counter for " << name << " is already set.";
MS_LOG(INFO) << "Counter for " << name << " is already set.";
return;
}

View File

@ -106,9 +106,10 @@ bool DistributedMetadataStore::UpdateMetadata(const std::string &name, const PBM
return false;
}
std::string update_meta_rsp = reinterpret_cast<const char *>(update_meta_rsp_msg->data());
std::string update_meta_rsp =
std::string(reinterpret_cast<char *>(update_meta_rsp_msg->data()), update_meta_rsp_msg->size());
if (update_meta_rsp != kSuccess) {
MS_LOG(ERROR) << "Updating metadata in server " << stored_rank << " failed.";
MS_LOG(ERROR) << "Updating metadata in server " << stored_rank << " failed. " << update_meta_rsp;
return false;
}
}

View File

@ -207,7 +207,7 @@ void Iteration::HandleSyncIterationRequest(const std::shared_ptr<core::MessageHa
bool Iteration::IsMoveToNextIterRequestReentrant(uint64_t iteration_num) {
std::unique_lock<std::mutex> lock(pinned_mtx_);
if (pinned_iter_num_ == iteration_num) {
if (pinned_iter_num_ >= iteration_num) {
MS_LOG(WARNING) << "MoveToNextIteration is not reentrant. Ignore this call.";
return true;
}
@ -286,7 +286,9 @@ bool Iteration::BroadcastPrepareForNextIterRequest(bool is_last_iter_valid, cons
// Retry sending to offline servers to notify them to prepare.
std::for_each(offline_servers.begin(), offline_servers.end(), [&](uint32_t rank) {
while (!communicator_->SendPbRequest(prepare_next_iter_req, rank, core::TcpUserCommand::kPrepareForNextIter)) {
// Should avoid endless loop if the server communicator is stopped.
while (communicator_->running() &&
!communicator_->SendPbRequest(prepare_next_iter_req, rank, core::TcpUserCommand::kPrepareForNextIter)) {
MS_LOG(WARNING) << "Retry sending prepare for next iteration request to server " << rank
<< " failed. The server has not recovered yet.";
std::this_thread::sleep_for(std::chrono::milliseconds(kRetryDurationForPrepareForNextIter));

View File

@ -68,7 +68,6 @@ class FedAvgKernel : public AggregationKernel {
AnfAlgo::VisitKernelWithReturnType(AnfAlgo::GetInputNode(kernel_node, cnode_weight_idx_), 0).first;
MS_EXCEPTION_IF_NULL(weight_node);
name_ = cnode_name + "." + weight_node->fullname_with_scope();
MS_LOG(INFO) << "Register counter for " << name_;
first_cnt_handler_ = [&](std::shared_ptr<core::MessageHandler>) {
std::unique_lock<std::mutex> lock(weight_mutex_);
if (!participated_) {
@ -123,9 +122,8 @@ class FedAvgKernel : public AggregationKernel {
accum_count_++;
participated_ = true;
DistributedCountService::GetInstance().Count(
return DistributedCountService::GetInstance().Count(
name_, std::to_string(DistributedCountService::GetInstance().local_rank()) + "_" + std::to_string(accum_count_));
return true;
}
void Reset() override {

View File

@ -89,14 +89,10 @@ void GetModelKernel::GetModel(const schema::RequestGetModel *get_model_req, cons
feature_maps = ModelStore::GetInstance().GetModelByIterNum(get_model_iter);
}
// If the iteration of this model is invalid, return ResponseCode_OutOfTime to the clients could startFLJob according
// to next_req_time.
bool last_iter_valid = Iteration::GetInstance().is_last_iteration_valid();
MS_LOG(INFO) << "GetModel last iteration is valid or not: " << last_iter_valid << ", next request time is "
<< next_req_time << ", current iteration is " << current_iter;
auto response_code = last_iter_valid ? schema::ResponseCode_SUCCEED : schema::ResponseCode_OutOfTime;
BuildGetModelRsp(fbb, response_code, "Get model for iteration " + std::to_string(get_model_iter), current_iter,
feature_maps, std::to_string(next_req_time));
MS_LOG(INFO) << "GetModel last iteration is valid or not: " << Iteration::GetInstance().is_last_iteration_valid()
<< ", next request time is " << next_req_time << ", current iteration is " << current_iter;
BuildGetModelRsp(fbb, schema::ResponseCode_SUCCEED, "Get model for iteration " + std::to_string(get_model_iter),
current_iter, feature_maps, std::to_string(next_req_time));
return;
}

View File

@ -59,13 +59,23 @@ bool StartFLJobKernel::Launch(const std::vector<AddressPtr> &inputs, const std::
return false;
}
if (ReachThresholdForStartFLJob(fbb)) {
flatbuffers::Verifier verifier(reinterpret_cast<uint8_t *>(req_data), inputs[0]->size);
if (!verifier.VerifyBuffer<schema::RequestFLJob>()) {
std::string reason = "The schema of startFLJob is invalid.";
BuildStartFLJobRsp(fbb, schema::ResponseCode_RequestError, reason, false, "");
MS_LOG(ERROR) << reason;
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}
const schema::RequestFLJob *start_fl_job_req = flatbuffers::GetRoot<schema::RequestFLJob>(req_data);
DeviceMeta device_meta = CreateDeviceMetadata(start_fl_job_req);
if (ReachThresholdForStartFLJob(fbb)) {
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}
if (!ReadyForStartFLJob(fbb, device_meta)) {
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
return false;
@ -140,7 +150,7 @@ bool StartFLJobKernel::ReadyForStartFLJob(const std::shared_ptr<FBBuilder> &fbb,
bool StartFLJobKernel::CountForStartFLJob(const std::shared_ptr<FBBuilder> &fbb,
const schema::RequestFLJob *start_fl_job_req) {
if (!DistributedCountService::GetInstance().Count(name_, start_fl_job_req->fl_id()->str())) {
std::string reason = "startFLJob counting failed.";
std::string reason = "Counting start fl job request failed. Please retry later.";
BuildStartFLJobRsp(
fbb, schema::ResponseCode_OutOfTime, reason, false,
std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)));

View File

@ -56,9 +56,9 @@ bool UpdateModelKernel::Launch(const std::vector<AddressPtr> &inputs, const std:
}
MS_LOG(INFO) << "Launching UpdateModelKernel kernel.";
if (!ReachThresholdForUpdateModel(fbb)) {
if (ReachThresholdForUpdateModel(fbb)) {
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
return false;
return true;
}
const schema::RequestUpdateModel *update_model_req = flatbuffers::GetRoot<schema::RequestUpdateModel>(req_data);
@ -103,14 +103,14 @@ void UpdateModelKernel::OnLastCountEvent(const std::shared_ptr<core::MessageHand
bool UpdateModelKernel::ReachThresholdForUpdateModel(const std::shared_ptr<FBBuilder> &fbb) {
if (DistributedCountService::GetInstance().CountReachThreshold(name_)) {
std::string reason = "Current amount for updateModel is enough.";
std::string reason = "Current amount for updateModel is enough. Please retry later.";
BuildUpdateModelRsp(
fbb, schema::ResponseCode_OutOfTime, reason,
std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)));
MS_LOG(ERROR) << reason;
return false;
MS_LOG(WARNING) << reason;
return true;
}
return true;
return false;
}
bool UpdateModelKernel::UpdateModel(const schema::RequestUpdateModel *update_model_req,
@ -118,19 +118,20 @@ bool UpdateModelKernel::UpdateModel(const schema::RequestUpdateModel *update_mod
size_t iteration = static_cast<size_t>(update_model_req->iteration());
if (iteration != LocalMetaStore::GetInstance().curr_iter_num()) {
std::string reason = "UpdateModel iteration number is invalid:" + std::to_string(iteration) +
", current iteration:" + std::to_string(LocalMetaStore::GetInstance().curr_iter_num());
", current iteration:" + std::to_string(LocalMetaStore::GetInstance().curr_iter_num()) +
". Retry later.";
BuildUpdateModelRsp(
fbb, schema::ResponseCode_OutOfTime, reason,
std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)));
MS_LOG(ERROR) << reason;
return false;
MS_LOG(WARNING) << reason;
return true;
}
PBMetadata device_metas = DistributedMetadataStore::GetInstance().GetMetadata(kCtxDeviceMetas);
FLIdToDeviceMeta fl_id_to_meta = device_metas.device_metas();
std::string update_model_fl_id = update_model_req->fl_id()->str();
if (fl_id_to_meta.fl_id_to_meta().count(update_model_fl_id) == 0) {
std::string reason = "devices_meta for " + update_model_fl_id + " is not set.";
std::string reason = "devices_meta for " + update_model_fl_id + " is not set. Please retry later.";
BuildUpdateModelRsp(
fbb, schema::ResponseCode_OutOfTime, reason,
std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)));
@ -193,7 +194,7 @@ std::map<std::string, UploadData> UpdateModelKernel::ParseFeatureMap(
bool UpdateModelKernel::CountForUpdateModel(const std::shared_ptr<FBBuilder> &fbb,
const schema::RequestUpdateModel *update_model_req) {
if (!DistributedCountService::GetInstance().Count(name_, update_model_req->fl_id()->str())) {
std::string reason = "UpdateModel counting failed.";
std::string reason = "Counting for update model request failed. Please retry later.";
BuildUpdateModelRsp(
fbb, schema::ResponseCode_OutOfTime, reason,
std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)));

View File

@ -85,7 +85,7 @@ bool ParameterAggregator::LaunchAggregators() {
bool ret = aggr_kernel->Launch(params.inputs, params.workspace, params.outputs);
if (!ret) {
MS_LOG(ERROR) << "Launching aggregation kernel " << typeid(aggr_kernel.get()).name() << " failed.";
continue;
return false;
}
}
return true;

View File

@ -222,6 +222,7 @@ void Server::RegisterExceptionEventCallback(const std::shared_ptr<core::TcpCommu
MS_EXCEPTION_IF_NULL(communicator);
communicator->RegisterEventCallback(core::ClusterEvent::SCHEDULER_TIMEOUT, [&]() {
MS_LOG(ERROR) << "Event SCHEDULER_TIMEOUT is captured. This is because scheduler node is finalized or crashed.";
safemode_ = true;
std::for_each(communicators_with_worker_.begin(), communicators_with_worker_.end(),
[](const std::shared_ptr<core::CommunicatorBase> &communicator) { communicator->Stop(); });
communicator_with_server_->Stop();
@ -231,6 +232,7 @@ void Server::RegisterExceptionEventCallback(const std::shared_ptr<core::TcpCommu
MS_LOG(ERROR)
<< "Event NODE_TIMEOUT is captured. This is because some server nodes are finalized or crashed after the "
"network building phase.";
safemode_ = true;
std::for_each(communicators_with_worker_.begin(), communicators_with_worker_.end(),
[](const std::shared_ptr<core::CommunicatorBase> &communicator) { communicator->Stop(); });
communicator_with_server_->Stop();
@ -285,6 +287,7 @@ void Server::StartCommunicator() {
DistributedMetadataStore::GetInstance().Initialize(server_node_);
CollectiveOpsImpl::GetInstance().Initialize(server_node_);
DistributedCountService::GetInstance().Initialize(server_node_, kLeaderServerRank);
MS_LOG(INFO) << "This server rank is " << server_node_->rank_id();
MS_LOG(INFO) << "Start communicator with worker.";
std::for_each(communicators_with_worker_.begin(), communicators_with_worker_.end(),

View File

@ -19,6 +19,7 @@
#include <vector>
#include <utility>
#include "ps/worker/fl_worker.h"
#include "utils/ms_exception.h"
namespace mindspore {
namespace ps {
@ -39,12 +40,38 @@ void FLWorker::Run() {
worker_node_ = std::make_shared<core::WorkerNode>();
MS_EXCEPTION_IF_NULL(worker_node_);
worker_node_->RegisterEventCallback(core::ClusterEvent::SCHEDULER_TIMEOUT, [this]() {
Finalize();
try {
MS_LOG(EXCEPTION)
<< "Event SCHEDULER_TIMEOUT is captured. This is because scheduler node is finalized or crashed.";
} catch (std::exception &e) {
MsException::Instance().SetException();
}
});
worker_node_->RegisterEventCallback(core::ClusterEvent::NODE_TIMEOUT, [this]() {
Finalize();
try {
MS_LOG(EXCEPTION)
<< "Event NODE_TIMEOUT is captured. This is because some server nodes are finalized or crashed after the "
"network building phase.";
} catch (std::exception &e) {
MsException::Instance().SetException();
}
});
InitializeFollowerScaler();
worker_node_->Start();
std::this_thread::sleep_for(std::chrono::milliseconds(kWorkerSleepTimeForNetworking));
return;
}
void FLWorker::Finalize() {
MS_EXCEPTION_IF_NULL(worker_node_);
worker_node_->Finish();
worker_node_->Stop();
}
bool FLWorker::SendToServer(uint32_t server_rank, void *data, size_t size, core::TcpUserCommand command,
std::shared_ptr<std::vector<unsigned char>> *output) {
// If the worker is in safemode, do not communicate with server.

View File

@ -60,6 +60,7 @@ class FLWorker {
return instance;
}
void Run();
void Finalize();
bool SendToServer(uint32_t server_rank, void *data, size_t size, core::TcpUserCommand command,
std::shared_ptr<std::vector<unsigned char>> *output = nullptr);

View File

@ -824,14 +824,14 @@ def set_fl_context(**kwargs):
Args:
enable_fl (bool): Whether to enable federated learning training mode.
Default: False.
server_mode (string): Describe the server mode, which must one of 'FEDERATED_LEARNING' and 'HYBRID_TRAINING'.
server_mode (str): Describe the server mode, which must one of 'FEDERATED_LEARNING' and 'HYBRID_TRAINING'.
Default: 'FEDERATED_LEARNING'.
ms_role (string): The process's role in the federated learning mode,
ms_role (str): The process's role in the federated learning mode,
which must be one of 'MS_SERVER', 'MS_WORKER' and 'MS_SCHED'.
Default: 'MS_SERVER'.
worker_num (int): The number of workers. For current version, this must be set to 1 or 0.
server_num (int): The number of federated learning servers. Default: 0.
scheduler_ip (string): The scheduler IP. Default: '0.0.0.0'.
scheduler_ip (str): The scheduler IP. Default: '0.0.0.0'.
scheduler_port (int): The scheduler port. Default: 6667.
fl_server_port (int): The http port of the federated learning server.
Normally for each server this should be set to the same value. Default: 6668.
@ -842,7 +842,7 @@ def set_fl_context(**kwargs):
which will be multiplied by start_fl_job_threshold.
Must be between 0 and 1.0.Default: 1.0.
update_model_time_window (int): The time window duration for updateModel in millisecond. Default: 3000.
fl_name (string): The federated learning job name. Default: ''.
fl_name (str): The federated learning job name. Default: ''.
fl_iteration_num (int): Iteration number of federated learning,
which is the number of interactions between client and server. Default: 20.
client_epoch_num (int): Client training epoch number. Default: 25.

View File

@ -0,0 +1,133 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
# The script runs the process of server's disaster recovery. It will kill the server process and launch it again.
import ast
import argparse
import subprocess
# Command-line interface of the disaster-recovery script.
# Every option is a plain "--name value" flag, so the parser is populated from a table of
# (flag, converter, default) rows. Rows are registered in the order they are listed.
_OPTION_TABLE = (
    ("--device_target", str, "CPU"),
    ("--server_mode", str, "HYBRID_TRAINING"),
    ("--worker_num", int, 1),
    ("--server_num", int, 2),
    ("--scheduler_ip", str, "127.0.0.1"),
    ("--scheduler_port", int, 8113),
    # The fl server port of the server which needs to be killed.
    ("--disaster_recovery_server_port", int, 10976),
    ("--start_fl_job_threshold", int, 1),
    ("--start_fl_job_time_window", int, 3000),
    ("--update_model_ratio", float, 1.0),
    ("--update_model_time_window", int, 3000),
    ("--fl_name", str, "Lenet"),
    ("--fl_iteration_num", int, 25),
    ("--client_epoch_num", int, 20),
    ("--client_batch_size", int, 32),
    ("--client_learning_rate", float, 0.1),
    ("--root_first_ca_path", str, ""),
    ("--root_second_ca_path", str, ""),
    # Boolean switches are parsed as Python literals ("True" / "False").
    ("--pki_verify", ast.literal_eval, False),
    ("--root_first_crl_path", str, ""),
    ("--root_second_crl_path", str, ""),
    ("--sts_jar_path", str, ""),
    ("--sts_properties_path", str, ""),
    ("--dp_eps", float, 50.0),
    ("--dp_delta", float, 0.01),  # usually equals 1/start_fl_job_threshold
    ("--dp_norm_clip", float, 1.0),
    ("--encrypt_type", str, "NotEncrypt"),
    ("--enable_ssl", ast.literal_eval, False),
)

parser = argparse.ArgumentParser(description="Run test_mobile_lenet.py case")
for _flag, _converter, _default in _OPTION_TABLE:
    parser.add_argument(_flag, type=_converter, default=_default)

# Unrecognised arguments are tolerated and discarded.
args, _ = parser.parse_known_args()

# Expose each parsed option as a module-level name for the steps that follow.
# Cluster layout.
device_target, server_mode = args.device_target, args.server_mode
worker_num, server_num = args.worker_num, args.server_num
scheduler_ip, scheduler_port = args.scheduler_ip, args.scheduler_port
disaster_recovery_server_port = args.disaster_recovery_server_port
# Federated-learning job settings.
start_fl_job_threshold, start_fl_job_time_window = args.start_fl_job_threshold, args.start_fl_job_time_window
update_model_ratio, update_model_time_window = args.update_model_ratio, args.update_model_time_window
fl_name, fl_iteration_num = args.fl_name, args.fl_iteration_num
client_epoch_num, client_batch_size = args.client_epoch_num, args.client_batch_size
client_learning_rate = args.client_learning_rate
# Certificate / PKI related paths and switches.
root_first_ca_path, root_second_ca_path = args.root_first_ca_path, args.root_second_ca_path
pki_verify = args.pki_verify
root_first_crl_path, root_second_crl_path = args.root_first_crl_path, args.root_second_crl_path
sts_jar_path, sts_properties_path = args.sts_jar_path, args.sts_properties_path
# Differential-privacy and encryption settings.
dp_eps, dp_delta, dp_norm_clip = args.dp_eps, args.dp_delta, args.dp_norm_clip
encrypt_type = args.encrypt_type
enable_ssl = args.enable_ssl
# Step 1: make the server offline.
# Find every process whose `ps -ef` line mentions the fl server port (excluding grep itself,
# `cd` lines and this script) and kill it with SIGKILL.
offline_cmd = "ps_demo_id=`ps -ef | grep " + str(disaster_recovery_server_port) \
              + "|grep -v cd | grep -v grep | grep -v run_server_disaster_recovery | awk '{print $2}'`"
offline_cmd += " && for id in $ps_demo_id; do kill -9 $id && echo \"Killed server process: $id\"; done"
subprocess.call(['bash', '-c', offline_cmd])

# Step 2: Wait 35 seconds for recovery.
wait_cmd = "echo \"Start to sleep for 35 seconds\" && sleep 35"
subprocess.call(['bash', '-c', wait_cmd])

# Step 3: Launch the server again with the same fl server port.
# The server runs in a fresh working directory named after its port; its output goes to
# server.log inside that directory and the process is left running in the background.
# NOTE(review): `script_self` is not assigned anywhere in this script; unless the environment
# provides it, `dirname ""` evaluates to "." - confirm this is the intended lookup path.
server_dir = "${execute_path}/disaster_recovery_server_" + str(disaster_recovery_server_port) + "/"
cmd_server = "execute_path=$(pwd) && self_path=$(dirname \"${script_self}\") && "
cmd_server += "rm -rf " + server_dir + " &&"
cmd_server += "mkdir " + server_dir + " &&"
cmd_server += "cd " + server_dir + " || exit && export GLOG_v=1 &&"
cmd_server += "python ${self_path}/../test_mobile_lenet.py"

# Flags forwarded to test_mobile_lenet.py, in the order they appear on the command line.
# Each option is listed exactly once (the original passed --root_second_crl_path twice).
server_flags = (
    ("device_target", device_target),
    ("server_mode", server_mode),
    ("ms_role", "MS_SERVER"),
    ("worker_num", worker_num),
    ("server_num", server_num),
    ("scheduler_ip", scheduler_ip),
    ("scheduler_port", scheduler_port),
    ("fl_server_port", disaster_recovery_server_port),
    ("start_fl_job_threshold", start_fl_job_threshold),
    ("enable_ssl", enable_ssl),
    ("start_fl_job_time_window", start_fl_job_time_window),
    ("update_model_ratio", update_model_ratio),
    ("update_model_time_window", update_model_time_window),
    ("fl_name", fl_name),
    ("fl_iteration_num", fl_iteration_num),
    ("client_epoch_num", client_epoch_num),
    ("client_batch_size", client_batch_size),
    ("client_learning_rate", client_learning_rate),
    ("dp_eps", dp_eps),
    ("dp_delta", dp_delta),
    ("dp_norm_clip", dp_norm_clip),
    ("encrypt_type", encrypt_type),
    ("root_first_ca_path", root_first_ca_path),
    ("root_second_ca_path", root_second_ca_path),
    ("pki_verify", pki_verify),
    ("root_first_crl_path", root_first_crl_path),
    ("root_second_crl_path", root_second_crl_path),
    ("sts_jar_path", sts_jar_path),
    ("sts_properties_path", sts_properties_path),
)
cmd_server += "".join(" --" + flag + "=" + str(value) for flag, value in server_flags)
cmd_server += " > server.log 2>&1 &"
subprocess.call(['bash', '-c', cmd_server])

View File

@ -1,238 +1,272 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import argparse
import time
import datetime
import random
import sys
import requests
import flatbuffers
import numpy as np
from mindspore.schema import (RequestFLJob, ResponseFLJob, ResponseCode,
RequestUpdateModel, FeatureMap, RequestGetModel, ResponseGetModel)
def _parse_bool_flag(text):
    """Convert a command-line string into a bool.

    argparse's ``type=bool`` turns every non-empty string into True, so ``--use_elb False``
    used to enable the option. Here the usual "false" spellings (and the empty string) map to
    False; any other value keeps mapping to True, as it did before.
    """
    return str(text).strip().lower() not in ("", "false", "0", "no", "n", "off")


# Command-line options of the simulated client.
parser = argparse.ArgumentParser()
parser.add_argument("--pid", type=int, default=0)
parser.add_argument("--http_ip", type=str, default="10.113.216.106")
parser.add_argument("--http_port", type=int, default=6666)
parser.add_argument("--use_elb", type=_parse_bool_flag, default=False)
parser.add_argument("--server_num", type=int, default=1)

# Unrecognised arguments are tolerated and discarded.
args, _ = parser.parse_known_args()
pid = args.pid
http_ip = args.http_ip
http_port = args.http_port
use_elb = args.use_elb
server_num = args.server_num

# Identifier sent as fl_id in the requests built below; derived from --pid.
str_fl_id = 'fl_lenet_' + str(pid)
def generate_port():
    """Pick the HTTP port for the next request.

    Returns ``http_port`` unchanged when ``use_elb`` is falsy. Otherwise returns ``http_port``
    plus a random offset in ``[0, server_num)``.
    """
    if use_elb:
        offset = random.randint(0, 100000) % server_num
        return offset + http_port
    return http_port
def build_start_fl_job(iteration):
    """Build a serialized RequestFLJob message for ``iteration``.

    The job name, timestamp and data size are fixed test values; the fl_id is the module-level
    ``str_fl_id``. Returns the finished flatbuffer output of the builder.
    """
    fbb = flatbuffers.Builder(1024)

    # String fields are created up front, ahead of RequestFLJobStart.
    job_name_offset = fbb.CreateString('fl_test_job')
    client_id_offset = fbb.CreateString(str_fl_id)
    sample_count = 32
    timestamp_offset = fbb.CreateString('2020/11/16/19/18')

    RequestFLJob.RequestFLJobStart(fbb)
    RequestFLJob.RequestFLJobAddFlName(fbb, job_name_offset)
    RequestFLJob.RequestFLJobAddFlId(fbb, client_id_offset)
    RequestFLJob.RequestFLJobAddIteration(fbb, iteration)
    RequestFLJob.RequestFLJobAddDataSize(fbb, sample_count)
    RequestFLJob.RequestFLJobAddTimestamp(fbb, timestamp_offset)
    fbb.Finish(RequestFLJob.RequestFLJobEnd(fbb))
    return fbb.Output()
def build_feature_map(builder, names, lengths):
    """Create one FeatureMap table per weight name, filled with random values.

    :param builder: flatbuffers builder the tables are written into.
    :param names: weight full names.
    :param lengths: number of float32 values to generate for each name, parallel to ``names``.
    :return: ``(feature_maps, np_data)`` - the FeatureMap offsets and the generated numpy
        arrays, both in the order of ``names``. Returns a bare ``None`` (not a tuple) when
        ``names`` and ``lengths`` differ in length, so callers that unpack must handle that.
    """
    if len(names) != len(lengths):
        return None

    table_offsets = []
    generated = []
    for weight_name, count in zip(names, lengths):
        name_offset = builder.CreateString(weight_name)
        FeatureMap.FeatureMapStartDataVector(builder, count)
        values = np.random.rand(count) * 32
        generated.append(values)
        # Values are prepended from the last element to the first.
        for value in values[::-1]:
            builder.PrependFloat32(value)
        data_offset = builder.EndVector(count)

        FeatureMap.FeatureMapStart(builder)
        FeatureMap.FeatureMapAddData(builder, data_offset)
        FeatureMap.FeatureMapAddWeightFullname(builder, name_offset)
        table_offsets.append(FeatureMap.FeatureMapEnd(builder))
    return table_offsets, generated
def build_update_model(iteration):
    """Build a serialized RequestUpdateModel message for ``iteration``.

    The request carries randomly generated weights for the eight LeNet parameters listed below.

    :param iteration: iteration number written into the request.
    :return: ``(buf, np_data)`` - the finished flatbuffer output and the numpy arrays that were
        generated for the weights, in the order of the weight names.
    """
    builder_update_model = flatbuffers.Builder(1)
    fl_name = builder_update_model.CreateString('fl_test_job')
    fl_id = builder_update_model.CreateString(str_fl_id)
    timestamp = builder_update_model.CreateString('2020/11/16/19/18')

    feature_maps, np_data = build_feature_map(builder_update_model,
                                              ["conv1.weight", "conv2.weight", "fc1.weight",
                                               "fc2.weight", "fc3.weight", "fc1.bias", "fc2.bias", "fc3.bias"],
                                              [450, 2400, 48000, 10080, 5208, 120, 84, 62])

    # Announce the real number of elements when starting the vector so that it agrees with the
    # count passed to EndVector (the original announced a hard-coded 1).
    RequestUpdateModel.RequestUpdateModelStartFeatureMapVector(builder_update_model, len(feature_maps))
    for single_feature_map in feature_maps:
        builder_update_model.PrependUOffsetTRelative(single_feature_map)
    feature_map = builder_update_model.EndVector(len(feature_maps))

    RequestUpdateModel.RequestUpdateModelStart(builder_update_model)
    RequestUpdateModel.RequestUpdateModelAddFlName(builder_update_model, fl_name)
    RequestUpdateModel.RequestUpdateModelAddFlId(builder_update_model, fl_id)
    RequestUpdateModel.RequestUpdateModelAddIteration(builder_update_model, iteration)
    RequestUpdateModel.RequestUpdateModelAddFeatureMap(builder_update_model, feature_map)
    RequestUpdateModel.RequestUpdateModelAddTimestamp(builder_update_model, timestamp)
    req_update_model = RequestUpdateModel.RequestUpdateModelEnd(builder_update_model)
    builder_update_model.Finish(req_update_model)
    buf = builder_update_model.Output()
    return buf, np_data
def build_get_model(iteration):
    """Serialize a RequestGetModel message for the given iteration.

    Args:
        iteration: iteration number written into the request.

    Returns:
        The finished output buffer of the flatbuffers builder.
    """
    fbb = flatbuffers.Builder(1)
    # Strings are created up front, before the table is started.
    job_name_off = fbb.CreateString('fl_test_job')
    ts_off = fbb.CreateString('2020/12/16/19/18')

    RequestGetModel.RequestGetModelStart(fbb)
    RequestGetModel.RequestGetModelAddFlName(fbb, job_name_off)
    RequestGetModel.RequestGetModelAddIteration(fbb, iteration)
    RequestGetModel.RequestGetModelAddTimestamp(fbb, ts_off)
    fbb.Finish(RequestGetModel.RequestGetModelEnd(fbb))
    return fbb.Output()
def datetime_to_timestamp(datetime_obj):
    """Convert a local datetime (with sub-second part) to a millisecond timestamp.

    :param datetime_obj: {datetime} e.g. 2016-02-25 20:21:04.242000
    :return: 13-digit millisecond timestamp, e.g. 1456402864242
    """
    # Whole seconds since the epoch, interpreted in the local time zone.
    epoch_seconds = time.mktime(datetime_obj.timetuple())
    # Microseconds truncated to whole milliseconds (kept as a float).
    millis_part = datetime_obj.microsecond // 1000.0
    return epoch_seconds * 1000.0 + millis_part
# Maps each weight name to its position in the name list that
# build_update_model passes to build_feature_map.
weight_to_idx = {
    weight_name: position
    for position, weight_name in enumerate((
        "conv1.weight",
        "conv2.weight",
        "fc1.weight",
        "fc2.weight",
        "fc3.weight",
        "fc1.bias",
        "fc2.bias",
        "fc3.bias",
    ))
}

# One HTTP session is shared by every request issued below.
session = requests.Session()
current_iteration = 1
url = "http://" + http_ip + ":" + str(generate_port())
# Fixed seed so the generated weights are reproducible between runs.
np.random.seed(0)
def post_until_out_of_safemode(target_url, payload, retry_interval=0.2):
    """POST payload to target_url, retrying while the cluster reports safemode.

    Every retry goes to the same URL as the first attempt and is preceded by a
    short sleep, so the client neither hits the wrong endpoint nor busy-spins.

    Args:
        target_url: endpoint the request is sent to.
        payload: request body.
        retry_interval: seconds to wait before each retry.

    Returns:
        The first response whose text is not the safemode message.
    """
    response = session.post(target_url, data=payload)
    while response.text == "The cluster is in safemode.":
        time.sleep(retry_interval)
        response = session.post(target_url, data=payload)
    return response


# Endless simulation loop: startFLJob -> updateModel -> getModel, then verify
# the downloaded weights or wait for the next request time.
while True:
    # ---- startFLJob: repeat until the server answers SUCCEED. ----
    url1 = "http://" + http_ip + ":" + str(generate_port()) + '/startFLJob'
    print("start url is ", url1)
    x = post_until_out_of_safemode(url1, build_start_fl_job(current_iteration))
    rsp_fl_job = ResponseFLJob.ResponseFLJob.GetRootAsResponseFLJob(x.content, 0)
    while rsp_fl_job.Retcode() != ResponseCode.ResponseCode.SUCCEED:
        x = post_until_out_of_safemode(url1, build_start_fl_job(current_iteration))
        rsp_fl_job = ResponseFLJob.ResponseFLJob.GetRootAsResponseFLJob(x.content, 0)
    print("epoch is", rsp_fl_job.FlPlanConfig().Epochs())
    print("iteration is", rsp_fl_job.Iteration())
    current_iteration = rsp_fl_job.Iteration()
    sys.stdout.flush()

    # ---- updateModel ----
    url2 = "http://" + http_ip + ":" + str(generate_port()) + '/updateModel'
    print("req update model iteration:", current_iteration, ", id:", args.pid)
    update_model_buf, update_model_np_data = build_update_model(current_iteration)
    # Safemode retries must go to the updateModel endpoint (url2); the original
    # retried against the startFLJob URL.
    x = post_until_out_of_safemode(url2, update_model_buf)
    print("rsp update model iteration:", current_iteration, ", id:", args.pid)
    sys.stdout.flush()

    # ---- getModel ----
    url3 = "http://" + http_ip + ":" + str(generate_port()) + '/getModel'
    print("req get model iteration:", current_iteration, ", id:", args.pid)
    x = post_until_out_of_safemode(url3, build_get_model(current_iteration))
    rsp_get_model = ResponseGetModel.ResponseGetModel.GetRootAsResponseGetModel(x.content, 0)
    print("rsp get model iteration:", current_iteration, ", id:", args.pid, rsp_get_model.Retcode())
    sys.stdout.flush()

    # 0 means "model available"; any other value is the time of the next request.
    next_req_timestamp = 0
    if rsp_get_model.Retcode() == ResponseCode.ResponseCode.OutOfTime:
        next_req_timestamp = int(rsp_get_model.Timestamp().decode('utf-8'))
        print("Last iteration is invalid, next request timestamp:", next_req_timestamp)
        sys.stdout.flush()
    elif rsp_get_model.Retcode() == ResponseCode.ResponseCode.SucNotReady:
        # Poll until the model is ready, the iteration times out, or the retry
        # budget is exhausted.
        repeat_time = 0
        while rsp_get_model.Retcode() == ResponseCode.ResponseCode.SucNotReady:
            time.sleep(0.2)
            x = post_until_out_of_safemode(url3, build_get_model(current_iteration))
            rsp_get_model = ResponseGetModel.ResponseGetModel.GetRootAsResponseGetModel(x.content, 0)
            if rsp_get_model.Retcode() == ResponseCode.ResponseCode.OutOfTime:
                next_req_timestamp = int(rsp_get_model.Timestamp().decode('utf-8'))
                print("Last iteration is invalid, next request timestamp:", next_req_timestamp)
                sys.stdout.flush()
                break
            repeat_time += 1
            if repeat_time > 1000:
                print("GetModel try timeout ", args.pid)
                sys.exit(0)

    if next_req_timestamp == 0:
        # Compare the first returned feature map with the data that was uploaded.
        for i in range(0, 1):
            print(rsp_get_model.FeatureMap(i).WeightFullname())
            origin = update_model_np_data[weight_to_idx[rsp_get_model.FeatureMap(i).WeightFullname().decode('utf-8')]]
            after = rsp_get_model.FeatureMap(i).DataAsNumpy() * 32
            print("Before update model", args.pid, origin[0:10])
            print("After get model", args.pid, after[0:10])
            sys.stdout.flush()
            assert np.allclose(origin, after, rtol=1e-05, atol=1e-05)
    else:
        # Sleep to the next request timestamp
        current_ts = datetime_to_timestamp(datetime.datetime.now())
        duration = next_req_timestamp - current_ts
        if duration > 0:
            time.sleep(duration / 1000)
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import argparse
import time
import datetime
import random
import sys
import requests
import flatbuffers
import numpy as np
from mindspore.schema import (RequestFLJob, ResponseFLJob, ResponseCode,
RequestUpdateModel, ResponseUpdateModel,
FeatureMap, RequestGetModel, ResponseGetModel)
def str_to_bool(text):
    """Parse a command-line boolean.

    argparse's type=bool calls bool() on the raw string, so every non-empty
    value (including "False") becomes True. This converter interprets the
    text instead.

    Args:
        text: raw option value.

    Returns:
        True or False.

    Raises:
        argparse.ArgumentTypeError: if the text is not a recognised boolean.
    """
    lowered = text.strip().lower()
    if lowered in ("true", "t", "yes", "y", "1", "on"):
        return True
    if lowered in ("false", "f", "no", "n", "0", "off", ""):
        return False
    raise argparse.ArgumentTypeError("expected a boolean value, got %r" % text)


# Command-line options of the simulator; unknown options are ignored.
parser = argparse.ArgumentParser()
parser.add_argument("--pid", type=int, default=0)
parser.add_argument("--http_ip", type=str, default="10.113.216.106")
parser.add_argument("--http_port", type=int, default=6666)
parser.add_argument("--use_elb", type=str_to_bool, default=False)
parser.add_argument("--server_num", type=int, default=1)

args, _ = parser.parse_known_args()
pid = args.pid
http_ip = args.http_ip
http_port = args.http_port
use_elb = args.use_elb
server_num = args.server_num

# Federated-learning client id derived from the simulator process id.
str_fl_id = 'fl_lenet_' + str(pid)
def generate_port():
    """Return the HTTP port the next request should target.

    Without use_elb this is always http_port. With use_elb a random offset in
    [0, server_num) is added to http_port.
    """
    if use_elb:
        offset = random.randint(0, 100000) % server_num
        return offset + http_port
    return http_port
def build_start_fl_job():
    """Serialize a RequestFLJob message for this simulated client.

    Returns:
        The finished output buffer of the flatbuffers builder.
    """
    fbb = flatbuffers.Builder(1024)
    # Strings are created up front, before the table is started.
    job_name_off = fbb.CreateString('fl_test_job')
    client_id_off = fbb.CreateString(str_fl_id)
    sample_count = 32
    ts_off = fbb.CreateString('2020/11/16/19/18')

    RequestFLJob.RequestFLJobStart(fbb)
    RequestFLJob.RequestFLJobAddFlName(fbb, job_name_off)
    RequestFLJob.RequestFLJobAddFlId(fbb, client_id_off)
    RequestFLJob.RequestFLJobAddDataSize(fbb, sample_count)
    RequestFLJob.RequestFLJobAddTimestamp(fbb, ts_off)
    fbb.Finish(RequestFLJob.RequestFLJobEnd(fbb))
    return fbb.Output()
def build_feature_map(builder, names, lengths):
    """Write one FeatureMap table per weight name, filled with random values.

    Args:
        builder: flatbuffers builder the tables are written into.
        names: weight names.
        lengths: number of values to generate for each name.

    Returns:
        None when names and lengths differ in size. Otherwise a tuple
        (feature_maps, np_data): the table offsets and the generated numpy
        arrays, both in the order of names.
    """
    if len(names) != len(lengths):
        return None

    table_offsets = []
    generated = []
    for weight_name, count in zip(names, lengths):
        name_off = builder.CreateString(weight_name)

        FeatureMap.FeatureMapStartDataVector(builder, count)
        values = np.random.rand(count) * 32
        generated.append(values)
        # Values are prepended starting from the last index.
        for pos in reversed(range(count)):
            builder.PrependFloat32(values[pos])
        data_off = builder.EndVector(count)

        FeatureMap.FeatureMapStart(builder)
        FeatureMap.FeatureMapAddData(builder, data_off)
        FeatureMap.FeatureMapAddWeightFullname(builder, name_off)
        table_offsets.append(FeatureMap.FeatureMapEnd(builder))
    return table_offsets, generated
def build_update_model(iteration):
    """Serialize a RequestUpdateModel message carrying randomly generated weights.

    Args:
        iteration: iteration number written into the request.

    Returns:
        Tuple (buf, np_data): the finished builder output and the list of numpy
        arrays returned by build_feature_map, in the order the names are listed.
    """
    builder_update_model = flatbuffers.Builder(1)
    # Strings are created up front, before any vector or table is started.
    fl_name = builder_update_model.CreateString('fl_test_job')
    fl_id = builder_update_model.CreateString(str_fl_id)
    timestamp = builder_update_model.CreateString('2020/11/16/19/18')

    feature_maps, np_data = build_feature_map(builder_update_model,
                                              ["conv1.weight", "conv2.weight", "fc1.weight",
                                               "fc2.weight", "fc3.weight", "fc1.bias", "fc2.bias", "fc3.bias"],
                                              [450, 2400, 48000, 10080, 5208, 120, 84, 62])

    # Declare the real number of offsets that will be prepended; the original
    # passed a hard-coded 1 although every feature map is added to the vector.
    feature_map_count = len(feature_maps)
    RequestUpdateModel.RequestUpdateModelStartFeatureMapVector(builder_update_model, feature_map_count)
    for single_feature_map in feature_maps:
        builder_update_model.PrependUOffsetTRelative(single_feature_map)
    feature_map = builder_update_model.EndVector(feature_map_count)

    RequestUpdateModel.RequestUpdateModelStart(builder_update_model)
    RequestUpdateModel.RequestUpdateModelAddFlName(builder_update_model, fl_name)
    RequestUpdateModel.RequestUpdateModelAddFlId(builder_update_model, fl_id)
    RequestUpdateModel.RequestUpdateModelAddIteration(builder_update_model, iteration)
    RequestUpdateModel.RequestUpdateModelAddFeatureMap(builder_update_model, feature_map)
    RequestUpdateModel.RequestUpdateModelAddTimestamp(builder_update_model, timestamp)
    req_update_model = RequestUpdateModel.RequestUpdateModelEnd(builder_update_model)
    builder_update_model.Finish(req_update_model)
    buf = builder_update_model.Output()
    return buf, np_data
def build_get_model(iteration):
    """Serialize a RequestGetModel message for the given iteration.

    Args:
        iteration: iteration number written into the request.

    Returns:
        The finished output buffer of the flatbuffers builder.
    """
    fbb = flatbuffers.Builder(1)
    # Strings are created up front, before the table is started.
    job_name_off = fbb.CreateString('fl_test_job')
    ts_off = fbb.CreateString('2020/12/16/19/18')

    RequestGetModel.RequestGetModelStart(fbb)
    RequestGetModel.RequestGetModelAddFlName(fbb, job_name_off)
    RequestGetModel.RequestGetModelAddIteration(fbb, iteration)
    RequestGetModel.RequestGetModelAddTimestamp(fbb, ts_off)
    fbb.Finish(RequestGetModel.RequestGetModelEnd(fbb))
    return fbb.Output()
def datetime_to_timestamp(datetime_obj):
    """Convert a local datetime to a millisecond timestamp (float).

    Args:
        datetime_obj: datetime interpreted in the local time zone.

    Returns:
        Milliseconds since the epoch; the sub-second part is truncated to
        whole milliseconds.
    """
    epoch_seconds = time.mktime(datetime_obj.timetuple())
    millis_part = datetime_obj.microsecond // 1000.0
    return epoch_seconds * 1000.0 + millis_part
# Maps each weight name to its position in the name list that
# build_update_model passes to build_feature_map.
weight_to_idx = {
    weight_name: position
    for position, weight_name in enumerate((
        "conv1.weight",
        "conv2.weight",
        "fc1.weight",
        "fc2.weight",
        "fc3.weight",
        "fc1.bias",
        "fc2.bias",
        "fc3.bias",
    ))
}

# One HTTP session is shared by every request issued below.
session = requests.Session()
current_iteration = 1
# Fixed seed so the generated weights are reproducible between runs.
np.random.seed(0)
def start_fl_job():
    """Send one startFLJob request and summarise the server's answer.

    Returns:
        Tuple (result, iteration). result is a dict with keys 'reason'
        ("Success" or "Restart iteration.") and 'next_ts' (0 on success,
        otherwise the millisecond timestamp of the next request). iteration is
        0 when the cluster is in safemode, else the iteration in the response.

    Calls sys.exit() when the return code is neither SUCCEED nor OutOfTime.
    """
    endpoint = "http://" + http_ip + ":" + str(generate_port()) + '/startFLJob'
    print("Start fl job url is ", endpoint)
    reply = session.post(endpoint, data=build_start_fl_job())

    if reply.text == "The cluster is in safemode.":
        # No parsable response in safemode: ask the caller to retry in 500 ms.
        retry_at = datetime_to_timestamp(datetime.datetime.now()) + 500
        print("Start fl job when safemode.")
        return {'reason': "Restart iteration.", 'next_ts': retry_at}, 0

    job_rsp = ResponseFLJob.ResponseFLJob.GetRootAsResponseFLJob(reply.content, 0)
    job_iteration = job_rsp.Iteration()
    ret_code = job_rsp.Retcode()

    if ret_code == ResponseCode.ResponseCode.SUCCEED:
        return {'reason': "Success", 'next_ts': 0}, job_iteration

    if ret_code != ResponseCode.ResponseCode.OutOfTime:
        print("Start fl job failed, return code is ", job_rsp.Retcode())
        sys.exit()

    # OutOfTime: the server tells us when the next request may be sent.
    retry_at = int(job_rsp.NextReqTime().decode('utf-8'))
    print("Start fl job out of time. Next request at ",
          retry_at, "reason:", job_rsp.Reason())
    return {'reason': "Restart iteration.", 'next_ts': retry_at}, job_iteration
def update_model(iteration):
    """Send one updateModel request and summarise the server's answer.

    Args:
        iteration: iteration number written into the request.

    Returns:
        Tuple (result, np_data). result is a dict with keys 'reason'
        ("Success" or "Restart iteration.") and 'next_ts' (0 on success,
        otherwise the millisecond timestamp of the next request). np_data is
        the weight data returned by build_update_model.

    Calls sys.exit() when the return code is neither SUCCEED nor OutOfTime.
    """
    endpoint = "http://" + http_ip + ":" + str(generate_port()) + '/updateModel'
    print("Update model url:", endpoint, ", iteration:", iteration)
    request_buf, uploaded_weights = build_update_model(iteration)
    reply = session.post(endpoint, data=request_buf)

    if reply.text == "The cluster is in safemode.":
        # No parsable response in safemode: ask the caller to retry in 500 ms.
        retry_at = datetime_to_timestamp(datetime.datetime.now()) + 500
        print("Update model when safemode.")
        return {'reason': "Restart iteration.", 'next_ts': retry_at}, uploaded_weights

    update_rsp = ResponseUpdateModel.ResponseUpdateModel.GetRootAsResponseUpdateModel(reply.content, 0)
    ret_code = update_rsp.Retcode()

    if ret_code == ResponseCode.ResponseCode.SUCCEED:
        return {'reason': "Success", 'next_ts': 0}, uploaded_weights

    if ret_code != ResponseCode.ResponseCode.OutOfTime:
        print("Update model failed, return code is ", update_rsp.Retcode())
        sys.exit()

    # OutOfTime: the server tells us when the next request may be sent.
    retry_at = int(update_rsp.NextReqTime().decode('utf-8'))
    print("Update model out of time. Next request at ",
          retry_at, "reason:", update_rsp.Reason())
    return {'reason': "Restart iteration.", 'next_ts': retry_at}, uploaded_weights
def get_model(iteration, update_model_data):
    """Poll getModel until the server returns SUCCEED, then print a comparison.

    Args:
        iteration: iteration number written into every request.
        update_model_data: weight arrays uploaded earlier, indexed through
            weight_to_idx.

    Returns:
        Dict {'reason': "Success", 'next_ts': 0}.

    Calls sys.exit() when the return code is neither SUCCEED nor SucNotReady.
    """
    endpoint = "http://" + http_ip + ":" + str(generate_port()) + '/getModel'
    print("Get model url:", endpoint, ", iteration:", iteration)

    while True:
        reply = session.post(endpoint, data=build_get_model(iteration))
        if reply.text == "The cluster is in safemode.":
            print("Get model when safemode.")
            time.sleep(0.5)
            continue

        model_rsp = ResponseGetModel.ResponseGetModel.GetRootAsResponseGetModel(reply.content, 0)
        ret_code = model_rsp.Retcode()
        if ret_code == ResponseCode.ResponseCode.SucNotReady:
            # Model not ready yet: wait and ask again.
            time.sleep(0.5)
            continue
        if ret_code != ResponseCode.ResponseCode.SUCCEED:
            print("Get model failed, return code is ", model_rsp.Retcode())
            sys.exit()
        break

    # Only the first returned feature map is printed next to the uploaded data.
    first_map = model_rsp.FeatureMap(0)
    print(first_map.WeightFullname())
    origin = update_model_data[weight_to_idx[first_map.WeightFullname().decode('utf-8')]]
    after = first_map.DataAsNumpy() * 32
    print("Before update model", args.pid, origin[0:10])
    print("After get model", args.pid, after[0:10])
    sys.stdout.flush()

    return {'reason': "Success", 'next_ts': 0}
def sleep_until_next_request(next_ts):
    """Sleep until the millisecond timestamp next_ts, if it is not in the past."""
    remaining_ms = next_ts - datetime_to_timestamp(datetime.datetime.now())
    if remaining_ms >= 0:
        time.sleep(remaining_ms / 1000)


# Endless simulation loop: startFLJob -> updateModel -> getModel. Whenever a
# stage asks to restart the iteration, wait until its next_ts and start over.
while True:
    result, current_iteration = start_fl_job()
    sys.stdout.flush()
    if result['reason'] == "Restart iteration.":
        sleep_until_next_request(result['next_ts'])
        continue

    result, update_data = update_model(current_iteration)
    sys.stdout.flush()
    if result['reason'] == "Restart iteration.":
        sleep_until_next_request(result['next_ts'])
        continue

    result = get_model(current_iteration, update_data)
    sys.stdout.flush()
    if result['reason'] == "Restart iteration.":
        sleep_until_next_request(result['next_ts'])
        continue

View File

@ -135,4 +135,4 @@ if __name__ == "__main__":
acc = model.eval(ds_eval, dataset_sink_mode=False)
print("Accuracy:", acc['Accuracy'])
assert acc['Accuracy'] > 0.90
assert acc['Accuracy'] > 0.83