Synchronize with enter.

This commit is contained in:
ZPaC 2021-09-02 15:47:07 +08:00
parent ce0dbcdbf5
commit 85b9ee02c0
16 changed files with 123 additions and 46 deletions

View File

@ -80,7 +80,7 @@ bool CollectiveOpsImpl::RingAllReduce(const void *sendbuff, void *recvbuff, size
std::shared_ptr<std::vector<unsigned char>> recv_str;
auto recv_req_id = server_node_->CollectiveReceiveAsync(ps::core::NodeRole::SERVER, recv_from_rank, &recv_str);
if (!server_node_->CollectiveWait(recv_req_id)) {
if (!server_node_->CollectiveWait(recv_req_id, kCollectiveCommTimeout)) {
MS_LOG(ERROR) << "CollectiveWait " << recv_req_id << " failed.";
return false;
}
@ -95,7 +95,7 @@ bool CollectiveOpsImpl::RingAllReduce(const void *sendbuff, void *recvbuff, size
recv_chunk[j] += tmp_recv_chunk[j];
}
// Step 4: Wait until send is done.
if (!server_node_->Wait(send_req_id, 1)) {
if (!server_node_->Wait(send_req_id, kCollectiveCommTimeout)) {
MS_LOG(ERROR) << "CollectiveWait " << send_req_id << " failed.";
return false;
}
@ -117,7 +117,7 @@ bool CollectiveOpsImpl::RingAllReduce(const void *sendbuff, void *recvbuff, size
std::shared_ptr<std::vector<unsigned char>> recv_str;
auto recv_req_id = server_node_->CollectiveReceiveAsync(ps::core::NodeRole::SERVER, recv_from_rank, &recv_str);
if (!server_node_->CollectiveWait(recv_req_id)) {
if (!server_node_->CollectiveWait(recv_req_id, kCollectiveCommTimeout)) {
MS_LOG(ERROR) << "CollectiveWait " << recv_req_id << " failed.";
return false;
}
@ -126,7 +126,7 @@ bool CollectiveOpsImpl::RingAllReduce(const void *sendbuff, void *recvbuff, size
MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")";
return false;
}
if (!server_node_->Wait(send_req_id, 1)) {
if (!server_node_->Wait(send_req_id, kCollectiveCommTimeout)) {
MS_LOG(ERROR) << "CollectiveWait " << send_req_id << " failed.";
return false;
}
@ -157,7 +157,7 @@ bool CollectiveOpsImpl::ReduceBroadcastAllReduce(const void *sendbuff, void *rec
std::shared_ptr<std::vector<unsigned char>> recv_str;
MS_LOG(DEBUG) << "Reduce rank 0 receive from rank " << i;
auto recv_req_id = server_node_->CollectiveReceiveAsync(ps::core::NodeRole::SERVER, i, &recv_str);
if (!server_node_->CollectiveWait(recv_req_id)) {
if (!server_node_->CollectiveWait(recv_req_id, kCollectiveCommTimeout)) {
MS_LOG(ERROR) << "CollectiveWait " << recv_req_id << " failed.";
return false;
}
@ -173,7 +173,7 @@ bool CollectiveOpsImpl::ReduceBroadcastAllReduce(const void *sendbuff, void *rec
} else {
MS_LOG(DEBUG) << "Reduce send data to rank 0 process.";
auto send_req_id = server_node_->CollectiveSendAsync(ps::core::NodeRole::SERVER, 0, sendbuff, count * sizeof(T));
if (!server_node_->Wait(send_req_id)) {
if (!server_node_->Wait(send_req_id, kCollectiveCommTimeout)) {
MS_LOG(ERROR) << "CollectiveWait " << send_req_id << " failed.";
return false;
}
@ -187,7 +187,7 @@ bool CollectiveOpsImpl::ReduceBroadcastAllReduce(const void *sendbuff, void *rec
MS_LOG(DEBUG) << "Broadcast data to process " << i;
auto send_req_id =
server_node_->CollectiveSendAsync(ps::core::NodeRole::SERVER, i, output_buff, count * sizeof(T));
if (!server_node_->Wait(send_req_id)) {
if (!server_node_->Wait(send_req_id, kCollectiveCommTimeout)) {
MS_LOG(ERROR) << "CollectiveWait " << send_req_id << " failed.";
return false;
}
@ -196,7 +196,7 @@ bool CollectiveOpsImpl::ReduceBroadcastAllReduce(const void *sendbuff, void *rec
MS_LOG(DEBUG) << "Broadcast receive from rank 0.";
std::shared_ptr<std::vector<unsigned char>> recv_str;
auto recv_req_id = server_node_->CollectiveReceiveAsync(ps::core::NodeRole::SERVER, 0, &recv_str);
if (!server_node_->CollectiveWait(recv_req_id)) {
if (!server_node_->CollectiveWait(recv_req_id, kCollectiveCommTimeout)) {
MS_LOG(ERROR) << "CollectiveWait " << recv_req_id << " failed.";
return false;
}

View File

@ -29,6 +29,9 @@
namespace mindspore {
namespace fl {
namespace server {
// The timeout for server collective communication in case of network jitter.
constexpr uint32_t kCollectiveCommTimeout = 30;
// CollectiveOpsImpl is the collective communication API of the server.
// For now, it implements two AllReduce algorithms: RingAllReduce and BroadcastAllReduce. Elastic AllReduce is also
// supported for the elastic scaling feature of the server.

View File

@ -274,7 +274,7 @@ bool Iteration::DisableServerInstance(std::string *result) {
instance_state_ = InstanceState::kDisable;
if (!ForciblyMoveToNextIteration()) {
*result = "Disabling instance failed. Can't drop current iteration and move to the next.";
MS_LOG(ERROR) << result;
MS_LOG(ERROR) << *result;
return false;
}
*result = "Disabling FL-Server succeeded.";
@ -310,6 +310,9 @@ bool Iteration::NewInstance(const nlohmann::json &new_instance_json, std::string
iteration_num_ = 1;
LocalMetaStore::GetInstance().set_curr_iter_num(iteration_num_);
ModelStore::GetInstance().Reset();
if (metrics_ != nullptr) {
metrics_->Clear();
}
// Update the hyper-parameters on server and reinitialize rounds.
if (!UpdateHyperParams(new_instance_json)) {
@ -516,13 +519,6 @@ bool Iteration::BroadcastMoveToNextIterRequest(bool is_last_iter_valid, const st
void Iteration::HandleMoveToNextIterRequest(const std::shared_ptr<ps::core::MessageHandler> &message) {
MS_ERROR_IF_NULL_WO_RET_VAL(message);
MS_ERROR_IF_NULL_WO_RET_VAL(communicator_);
MoveToNextIterResponse proceed_to_next_iter_rsp;
proceed_to_next_iter_rsp.set_result("success");
if (!communicator_->SendResponse(proceed_to_next_iter_rsp.SerializeAsString().data(),
proceed_to_next_iter_rsp.SerializeAsString().size(), message)) {
MS_LOG(ERROR) << "Sending response failed.";
return;
}
MoveToNextIterRequest proceed_to_next_iter_req;
(void)proceed_to_next_iter_req.ParseFromArray(message->data(), SizeToInt(message->len()));
@ -536,6 +532,14 @@ void Iteration::HandleMoveToNextIterRequest(const std::shared_ptr<ps::core::Mess
// Synchronize the iteration number with leader server.
iteration_num_ = last_iter_num;
Next(is_last_iter_valid, reason);
MoveToNextIterResponse proceed_to_next_iter_rsp;
proceed_to_next_iter_rsp.set_result("success");
if (!communicator_->SendResponse(proceed_to_next_iter_rsp.SerializeAsString().data(),
proceed_to_next_iter_rsp.SerializeAsString().size(), message)) {
MS_LOG(ERROR) << "Sending response failed.";
return;
}
}
void Iteration::Next(bool is_iteration_valid, const std::string &reason) {
@ -548,7 +552,9 @@ void Iteration::Next(bool is_iteration_valid, const std::string &reason) {
MS_LOG(INFO) << "Iteration " << iteration_num_ << " is successfully finished.";
} else {
// Store last iteration's model because this iteration is considered as invalid.
const auto &model = ModelStore::GetInstance().GetModelByIterNum(iteration_num_ - 1);
const auto &iter_to_model = ModelStore::GetInstance().iteration_to_model();
size_t latest_iter_num = iter_to_model.rbegin()->first;
const auto &model = ModelStore::GetInstance().GetModelByIterNum(latest_iter_num);
ModelStore::GetInstance().StoreModelByIterNum(iteration_num_, model);
MS_LOG(WARNING) << "Iteration " << iteration_num_ << " is invalid. Reason: " << reason;
}
@ -642,7 +648,7 @@ bool Iteration::SummarizeIteration() {
metrics_->set_fl_name(ps::PSContext::instance()->fl_name());
metrics_->set_fl_iteration_num(ps::PSContext::instance()->fl_iteration_num());
metrics_->set_cur_iteration_num(iteration_num_ - 1);
metrics_->set_cur_iteration_num(iteration_num_);
metrics_->set_instance_state(instance_state_.load());
metrics_->set_loss(loss_);
metrics_->set_accuracy(accuracy_);

View File

@ -54,21 +54,20 @@ bool IterationMetrics::Initialize() {
}
// Parse storage file path.
std::string metrics_file_path = JsonGetKeyWithException<std::string>(value_json, ps::kStoreFilePath);
auto realpath = Common::GetRealPath(metrics_file_path);
metrics_file_path_ = JsonGetKeyWithException<std::string>(value_json, ps::kStoreFilePath);
auto realpath = Common::GetRealPath(metrics_file_path_);
if (!realpath.has_value()) {
MS_LOG(EXCEPTION) << "Get real path for " << metrics_file_path << " failed.";
MS_LOG(EXCEPTION) << "Get real path for " << metrics_file_path_ << " failed.";
return false;
}
metrics_file_.open(metrics_file_path, std::ios::ate | std::ios::out);
metrics_file_.open(metrics_file_path_, std::ios::ate | std::ios::out);
}
return true;
}
bool IterationMetrics::Summarize() {
if (!metrics_file_.is_open()) {
metrics_file_.clear();
MS_LOG(ERROR) << "The metrics file is not opened.";
return false;
}
@ -87,7 +86,14 @@ bool IterationMetrics::Summarize() {
return true;
}
bool IterationMetrics::Clear() { return true; }
bool IterationMetrics::Clear() {
if (metrics_file_.is_open()) {
MS_LOG(INFO) << "Clear the old metrics file " << metrics_file_path_;
metrics_file_.close();
metrics_file_.open(metrics_file_path_, std::ios::ate | std::ios::out);
}
return true;
}
void IterationMetrics::set_fl_name(const std::string &fl_name) { fl_name_ = fl_name; }

View File

@ -96,6 +96,9 @@ class IterationMetrics {
// The metrics file object.
std::fstream metrics_file_;
// The metrics file path.
std::string metrics_file_path_;
// Json object of metrics data.
nlohmann::basic_json<std::map, std::vector, std::string, bool, int64_t, uint64_t, float> js_;

View File

@ -115,8 +115,8 @@ bool FLWorker::SendToServer(uint32_t server_rank, const void *data, size_t size,
if (output != nullptr) {
while (true) {
if (!worker_node_->Send(ps::core::NodeRole::SERVER, server_rank, message, size, static_cast<int>(command),
output)) {
if (!worker_node_->Send(ps::core::NodeRole::SERVER, server_rank, message, size, static_cast<int>(command), output,
kWorkerTimeout)) {
MS_LOG(ERROR) << "Sending message to server " << server_rank << " failed.";
return false;
}
@ -134,7 +134,8 @@ bool FLWorker::SendToServer(uint32_t server_rank, const void *data, size_t size,
}
}
} else {
if (!worker_node_->Send(ps::core::NodeRole::SERVER, server_rank, message, size, static_cast<int>(command))) {
if (!worker_node_->Send(ps::core::NodeRole::SERVER, server_rank, message, size, static_cast<int>(command),
kWorkerTimeout)) {
MS_LOG(ERROR) << "Sending message to server " << server_rank << " failed.";
return false;
}

View File

@ -46,6 +46,9 @@ constexpr uint32_t kWorkerRetryDurationForSafeMode = 500;
// The rank of the leader server.
constexpr uint32_t kLeaderServerRank = 0;
// The timeout for worker sending message to server in case of network jitter.
constexpr uint32_t kWorkerTimeout = 30;
enum class IterationState {
// This iteration is still in process.
kRunning,

View File

@ -872,6 +872,9 @@ def set_fl_context(**kwargs):
Default: 'NOT_ENCRYPT'.
config_file_path (string): Configuration file path used by recovery. Default: ''.
scheduler_manage_port (int): scheduler manage port used to scale out/in. Default: 11202.
enable_ssl (bool): Set PS SSL mode enabled or disabled. Default: True.
client_password (str): Password to decrypt the secret key stored in the client certificate.
server_password (str): Password to decrypt the secret key stored in the server certificate.
Raises:
ValueError: If input key is not the attribute in federated learning mode context.

View File

@ -92,7 +92,7 @@ from ._quant_ops import *
from .other_ops import (Assign, InplaceAssign, IOU, BoundingBoxDecode, BoundingBoxEncode,
ConfusionMatrix, PopulationCount, UpdateState, Load,
CheckValid, Partial, Depend, identity, CheckBprop, Push, Pull, PullWeight, PushWeight,
StartFLJob, UpdateModel, GetModel, PyFunc)
PushMetrics, StartFLJob, UpdateModel, GetModel, PyFunc)
from ._thor_ops import (CusBatchMatMul, CusCholeskyTrsm, CusFusedAbsMax1, CusImg2Col, CusMatMulCubeDenseLeft,
CusMatMulCubeFraczRightMul, CusMatMulCube, CusMatrixCombine, CusTranspose02314,
CusMatMulCubeDenseRight, CusMatMulCubeFraczLeftCast, Im2Col, NewIm2Col,

View File

@ -778,6 +778,32 @@ class PushWeight(PrimitiveWithInfer):
return mstype.float32
class PushMetrics(PrimitiveWithInfer):
    """
    Push metrics like loss and accuracy for federated learning worker.

    Inputs:
        - **loss** (Tensor) - The loss.
        - **accuracy** (Tensor) - The accuracy.

    Outputs:
        Tensor named `result`; inference reports shape [1] and dtype float32.
    """

    @prim_attr_register
    def __init__(self):
        """Initialize PushMetrics"""
        # Sets the `primitive_target` attribute to "CPU" and `side_effect_mem` to True.
        self.add_prim_attr("primitive_target", "CPU")
        self.add_prim_attr("side_effect_mem", True)
        self.init_prim_io_names(inputs=["loss", "accuracy"], outputs=["result"])

    def infer_shape(self, loss, accuracy):
        # The output shape is fixed to [1]; the input shapes are not consulted.
        return [1]

    def infer_dtype(self, loss, accuracy):
        # The output dtype is fixed to float32; the input dtypes are not consulted.
        return mstype.float32
class StartFLJob(PrimitiveWithInfer):
"""
StartFLJob for federated learning worker.

View File

@ -15,6 +15,7 @@
# The script runs the process of server's disaster recovery. It will kill the server process and launch it again.
import os
import ast
import argparse
import subprocess
@ -28,6 +29,7 @@ parser.add_argument("--scheduler_ip", type=str, default="127.0.0.1")
parser.add_argument("--scheduler_port", type=int, default=8113)
#The fl server port of the server which needs to be killed.
parser.add_argument("--disaster_recovery_server_port", type=int, default=10976)
parser.add_argument("--node_id", type=str, default="")
parser.add_argument("--start_fl_job_threshold", type=int, default=1)
parser.add_argument("--start_fl_job_time_window", type=int, default=3000)
parser.add_argument("--update_model_ratio", type=float, default=1.0)
@ -49,6 +51,7 @@ parser.add_argument("--dp_eps", type=float, default=50.0)
parser.add_argument("--dp_delta", type=float, default=0.01) # usually equals 1/start_fl_job_threshold
parser.add_argument("--dp_norm_clip", type=float, default=1.0)
parser.add_argument("--encrypt_type", type=str, default="NOT_ENCRYPT")
parser.add_argument("--config_file_path", type=str, default="")
parser.add_argument("--client_password", type=str, default="")
parser.add_argument("--server_password", type=str, default="")
parser.add_argument("--enable_ssl", type=ast.literal_eval, default=False)
@ -62,6 +65,7 @@ server_num = args.server_num
scheduler_ip = args.scheduler_ip
scheduler_port = args.scheduler_port
disaster_recovery_server_port = args.disaster_recovery_server_port
node_id = args.node_id
start_fl_job_threshold = args.start_fl_job_threshold
start_fl_job_time_window = args.start_fl_job_time_window
update_model_ratio = args.update_model_ratio
@ -72,6 +76,7 @@ client_epoch_num = args.client_epoch_num
client_batch_size = args.client_batch_size
client_learning_rate = args.client_learning_rate
local_server_num = args.local_server_num
config_file_path = args.config_file_path
root_first_ca_path = args.root_first_ca_path
root_second_ca_path = args.root_second_ca_path
pki_verify = args.pki_verify
@ -94,11 +99,12 @@ offline_cmd = "ps_demo_id=`ps -ef | grep " + str(disaster_recovery_server_port)
offline_cmd += " && for id in $ps_demo_id; do kill -9 $id && echo \"Killed server process: $id\"; done"
subprocess.call(['bash', '-c', offline_cmd])
#Step 2: Wait 35 seconds for recovery.
wait_cmd = "echo \"Start to sleep for 35 seconds\" && sleep 35"
#Step 2: Wait 3 seconds for recovery.
wait_cmd = "echo \"Start to sleep for 3 seconds\" && sleep 3"
subprocess.call(['bash', '-c', wait_cmd])
#Step 3: Launch the server again with the same fl server port.
os.environ['MS_NODE_ID'] = str(node_id)
cmd_server = "execute_path=$(pwd) && self_path=$(dirname \"${script_self}\") && "
cmd_server += "rm -rf ${execute_path}/disaster_recovery_server_" + str(disaster_recovery_server_port) + "/ &&"
cmd_server += "mkdir ${execute_path}/disaster_recovery_server_" + str(disaster_recovery_server_port) + "/ &&"
@ -126,6 +132,7 @@ cmd_server += " --dp_eps=" + str(dp_eps)
cmd_server += " --dp_delta=" + str(dp_delta)
cmd_server += " --dp_norm_clip=" + str(dp_norm_clip)
cmd_server += " --encrypt_type=" + str(encrypt_type)
cmd_server += " --config_file_path=" + str(config_file_path)
cmd_server += " --root_first_ca_path=" + str(root_first_ca_path)
cmd_server += " --root_second_ca_path=" + str(root_second_ca_path)
cmd_server += " --pki_verify=" + str(pki_verify)

View File

@ -41,6 +41,9 @@ server_num = args.server_num
str_fl_id = 'fl_lenet_' + str(pid)
server_not_available_rsp = ["The cluster is in safemode.",
"The server's training job is disabled or finished."]
def generate_port():
if not use_elb:
return http_port
@ -156,7 +159,7 @@ def start_fl_job():
print("Start fl job url is ", url)
x = session.post(url, data=build_start_fl_job())
if x.text == "The cluster is in safemode.":
if x.text in server_not_available_rsp:
start_fl_job_result['reason'] = "Restart iteration."
start_fl_job_result['next_ts'] = datetime_to_timestamp(datetime.datetime.now()) + 500
print("Start fl job when safemode.")
@ -185,7 +188,7 @@ def update_model(iteration):
print("Update model url:", url, ", iteration:", iteration)
update_model_buf, update_model_np_data = build_update_model(iteration)
x = session.post(url, data=update_model_buf)
if x.text == "The cluster is in safemode.":
if x.text in server_not_available_rsp:
update_model_result['reason'] = "Restart iteration."
update_model_result['next_ts'] = datetime_to_timestamp(datetime.datetime.now()) + 500
print("Update model when safemode.")
@ -214,7 +217,7 @@ def get_model(iteration, update_model_data):
while True:
x = session.post(url, data=build_get_model(iteration))
if x.text == "The cluster is in safemode.":
if x.text in server_not_available_rsp:
print("Get model when safemode.")
time.sleep(0.5)
continue
@ -271,8 +274,5 @@ while True:
time.sleep(duration / 1000)
continue
if current_iteration == 1:
time.sleep(2)
print("")
sys.stdout.flush()

View File

@ -14,6 +14,7 @@
# ============================================================================
import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore.common.initializer import TruncatedNormal
def conv(in_channels, out_channels, kernel_size, stride=1, padding=0):
@ -70,3 +71,12 @@ class LeNet5(nn.Cell):
x = self.relu(x)
x = self.fc3(x)
return x
class PushMetrics(nn.Cell):
    """Cell that forwards a loss and an accuracy value to the PushMetrics primitive."""

    def __init__(self):
        super().__init__()
        self.push_metrics = P.PushMetrics()

    def construct(self, loss, acc):
        """Call the PushMetrics primitive with `loss` and `acc` and return its output."""
        return self.push_metrics(loss, acc)

View File

@ -20,9 +20,10 @@ import numpy as np
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common import dtype as mstype
from mindspore.nn import WithLossCell
from src.cell_wrapper import TrainOneStepCellWithServerCommunicator
from src.model import LeNet5
from src.model import LeNet5, PushMetrics
# from src.adam import AdamWeightDecayOp
parser = argparse.ArgumentParser(description="test_hybrid_train_lenet")
@ -131,6 +132,7 @@ if __name__ == "__main__":
epoch = 50000
np.random.seed(0)
network = LeNet5(62)
push_metrics = PushMetrics()
if context.get_fl_context("ms_role") == "MS_WORKER":
# Please do not freeze layers if you want to both get and overwrite these layers to servers, which is meaningless.
network.conv1.weight.requires_grad = False
@ -153,9 +155,12 @@ if __name__ == "__main__":
train_network.set_train()
losses = []
for _ in range(epoch):
for i in range(epoch):
data = Tensor(np.random.rand(32, 3, 32, 32).astype(np.float32))
label = Tensor(np.random.randint(0, 61, (32)).astype(np.int32))
loss = train_network(data, label).asnumpy()
if context.get_fl_context("ms_role") == "MS_WORKER":
if (i + 1) % worker_step_num_per_iteration == 0:
push_metrics(Tensor(loss, mstype.float32), Tensor(loss, mstype.float32))
losses.append(loss)
print(losses)

View File

@ -15,6 +15,7 @@
# The script runs the process of server's disaster recovery. It will kill the server process and launch it again.
import os
import ast
import argparse
import subprocess
@ -28,6 +29,7 @@ parser.add_argument("--scheduler_ip", type=str, default="127.0.0.1")
parser.add_argument("--scheduler_port", type=int, default=8113)
#The fl server port of the server which needs to be killed.
parser.add_argument("--disaster_recovery_server_port", type=int, default=10976)
parser.add_argument("--node_id", type=str, default="")
parser.add_argument("--start_fl_job_threshold", type=int, default=1)
parser.add_argument("--start_fl_job_time_window", type=int, default=3000)
parser.add_argument("--update_model_ratio", type=float, default=1.0)
@ -61,6 +63,7 @@ server_num = args.server_num
scheduler_ip = args.scheduler_ip
scheduler_port = args.scheduler_port
disaster_recovery_server_port = args.disaster_recovery_server_port
node_id = args.node_id
start_fl_job_threshold = args.start_fl_job_threshold
start_fl_job_time_window = args.start_fl_job_time_window
update_model_ratio = args.update_model_ratio
@ -91,11 +94,12 @@ offline_cmd = "ps_demo_id=`ps -ef | grep " + str(disaster_recovery_server_port)
offline_cmd += " && for id in $ps_demo_id; do kill -9 $id && echo \"Killed server process: $id\"; done"
subprocess.call(['bash', '-c', offline_cmd])
#Step 2: Wait 35 seconds for recovery.
wait_cmd = "echo \"Start to sleep for 35 seconds\" && sleep 35"
#Step 2: Wait 3 seconds for recovery.
wait_cmd = "echo \"Start to sleep for 3 seconds\" && sleep 3"
subprocess.call(['bash', '-c', wait_cmd])
#Step 3: Launch the server again with the same fl server port.
os.environ['MS_NODE_ID'] = str(node_id)
cmd_server = "execute_path=$(pwd) && self_path=$(dirname \"${script_self}\") && "
cmd_server += "rm -rf ${execute_path}/disaster_recovery_server_" + str(disaster_recovery_server_port) + "/ &&"
cmd_server += "mkdir ${execute_path}/disaster_recovery_server_" + str(disaster_recovery_server_port) + "/ &&"

View File

@ -41,6 +41,9 @@ server_num = args.server_num
str_fl_id = 'fl_lenet_' + str(pid)
server_not_available_rsp = ["The cluster is in safemode.",
"The server's training job is disabled or finished."]
def generate_port():
if not use_elb:
return http_port
@ -156,7 +159,7 @@ def start_fl_job():
print("Start fl job url is ", url)
x = session.post(url, data=build_start_fl_job())
if x.text == "The cluster is in safemode.":
if x.text in server_not_available_rsp:
start_fl_job_result['reason'] = "Restart iteration."
start_fl_job_result['next_ts'] = datetime_to_timestamp(datetime.datetime.now()) + 500
print("Start fl job when safemode.")
@ -185,7 +188,7 @@ def update_model(iteration):
print("Update model url:", url, ", iteration:", iteration)
update_model_buf, update_model_np_data = build_update_model(iteration)
x = session.post(url, data=update_model_buf)
if x.text == "The cluster is in safemode.":
if x.text in server_not_available_rsp:
update_model_result['reason'] = "Restart iteration."
update_model_result['next_ts'] = datetime_to_timestamp(datetime.datetime.now()) + 500
print("Update model when safemode.")
@ -214,7 +217,7 @@ def get_model(iteration, update_model_data):
while True:
x = session.post(url, data=build_get_model(iteration))
if x.text == "The cluster is in safemode.":
if x.text in server_not_available_rsp:
print("Get model when safemode.")
time.sleep(0.5)
continue
@ -271,8 +274,5 @@ while True:
time.sleep(duration / 1000)
continue
if current_iteration == 1:
time.sleep(2)
print("")
sys.stdout.flush()