diff --git a/mindspore/ccsrc/fl/server/collective_ops_impl.cc b/mindspore/ccsrc/fl/server/collective_ops_impl.cc
index f9f5991216e..20527521aff 100644
--- a/mindspore/ccsrc/fl/server/collective_ops_impl.cc
+++ b/mindspore/ccsrc/fl/server/collective_ops_impl.cc
@@ -80,7 +80,7 @@ bool CollectiveOpsImpl::RingAllReduce(const void *sendbuff, void *recvbuff, size_t count) {
     std::shared_ptr<std::vector<unsigned char>> recv_str;
     auto recv_req_id = server_node_->CollectiveReceiveAsync(ps::core::NodeRole::SERVER, recv_from_rank, &recv_str);
-    if (!server_node_->CollectiveWait(recv_req_id)) {
+    if (!server_node_->CollectiveWait(recv_req_id, kCollectiveCommTimeout)) {
       MS_LOG(ERROR) << "CollectiveWait " << recv_req_id << " failed.";
       return false;
     }
@@ -95,7 +95,7 @@ bool CollectiveOpsImpl::RingAllReduce(const void *sendbuff, void *recvbuff, size_t count) {
       recv_chunk[j] += tmp_recv_chunk[j];
     }
     // Step 4: Wait until send is done.
-    if (!server_node_->Wait(send_req_id, 1)) {
+    if (!server_node_->Wait(send_req_id, kCollectiveCommTimeout)) {
       MS_LOG(ERROR) << "CollectiveWait " << send_req_id << " failed.";
       return false;
     }
@@ -117,7 +117,7 @@ bool CollectiveOpsImpl::RingAllReduce(const void *sendbuff, void *recvbuff, size_t count) {
     std::shared_ptr<std::vector<unsigned char>> recv_str;
     auto recv_req_id = server_node_->CollectiveReceiveAsync(ps::core::NodeRole::SERVER, recv_from_rank, &recv_str);
-    if (!server_node_->CollectiveWait(recv_req_id)) {
+    if (!server_node_->CollectiveWait(recv_req_id, kCollectiveCommTimeout)) {
       MS_LOG(ERROR) << "CollectiveWait " << recv_req_id << " failed.";
       return false;
     }
@@ -126,7 +126,7 @@ bool CollectiveOpsImpl::RingAllReduce(const void *sendbuff, void *recvbuff, size_t count) {
       MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")";
       return false;
     }
-    if (!server_node_->Wait(send_req_id, 1)) {
+    if (!server_node_->Wait(send_req_id, kCollectiveCommTimeout)) {
       MS_LOG(ERROR) << "CollectiveWait " << send_req_id << " failed.";
       return false;
     }
@@ -157,7 +157,7 @@ bool CollectiveOpsImpl::ReduceBroadcastAllReduce(const void *sendbuff, void *recvbuff, size_t count) {
       std::shared_ptr<std::vector<unsigned char>> recv_str;
       MS_LOG(DEBUG) << "Reduce rank 0 receive from rank " << i;
       auto recv_req_id = server_node_->CollectiveReceiveAsync(ps::core::NodeRole::SERVER, i, &recv_str);
-      if (!server_node_->CollectiveWait(recv_req_id)) {
+      if (!server_node_->CollectiveWait(recv_req_id, kCollectiveCommTimeout)) {
        MS_LOG(ERROR) << "CollectiveWait " << recv_req_id << " failed.";
         return false;
       }
@@ -173,7 +173,7 @@ bool CollectiveOpsImpl::ReduceBroadcastAllReduce(const void *sendbuff, void *recvbuff, size_t count) {
   } else {
     MS_LOG(DEBUG) << "Reduce send data to rank 0 process.";
     auto send_req_id = server_node_->CollectiveSendAsync(ps::core::NodeRole::SERVER, 0, sendbuff, count * sizeof(T));
-    if (!server_node_->Wait(send_req_id)) {
+    if (!server_node_->Wait(send_req_id, kCollectiveCommTimeout)) {
       MS_LOG(ERROR) << "CollectiveWait " << send_req_id << " failed.";
       return false;
     }
@@ -187,7 +187,7 @@ bool CollectiveOpsImpl::ReduceBroadcastAllReduce(const void *sendbuff, void *recvbuff, size_t count) {
       MS_LOG(DEBUG) << "Broadcast data to process " << i;
       auto send_req_id = server_node_->CollectiveSendAsync(ps::core::NodeRole::SERVER, i, output_buff, count * sizeof(T));
-      if (!server_node_->Wait(send_req_id)) {
+      if (!server_node_->Wait(send_req_id, kCollectiveCommTimeout)) {
        MS_LOG(ERROR) << "CollectiveWait " << send_req_id << " failed.";
         return false;
       }
@@ -196,7 +196,7 @@ bool CollectiveOpsImpl::ReduceBroadcastAllReduce(const void *sendbuff, void *recvbuff, size_t count) {
     MS_LOG(DEBUG) << "Broadcast receive from rank 0.";
     std::shared_ptr<std::vector<unsigned char>> recv_str;
     auto recv_req_id = server_node_->CollectiveReceiveAsync(ps::core::NodeRole::SERVER, 0, &recv_str);
-    if (!server_node_->CollectiveWait(recv_req_id)) {
+    if (!server_node_->CollectiveWait(recv_req_id, kCollectiveCommTimeout)) {
      MS_LOG(ERROR) << "CollectiveWait " << recv_req_id << " failed.";
       return false;
     }
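The hunks above replace both the unbounded CollectiveWait and the hard-coded 1-second Wait with the shared kCollectiveCommTimeout introduced in the header below, so a peer stalled by network jitter fails the AllReduce instead of hanging it. A minimal Python sketch of the bounded-wait pattern (names are illustrative; the real API is the C++ server_node_ interface):

```python
import threading

# kCollectiveCommTimeout in the patch is 30 seconds: a bounded wait either
# observes completion in time or reports failure to the caller.
COLLECTIVE_COMM_TIMEOUT = 30

def collective_wait(done_event, timeout=COLLECTIVE_COMM_TIMEOUT):
    # Event.wait returns False on timeout instead of blocking forever.
    return done_event.wait(timeout)

done = threading.Event()  # would be set by the async send/receive callback
if not collective_wait(done, timeout=0.1):  # short timeout for the demo
    print("CollectiveWait timed out; abort this AllReduce step.")
```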
diff --git a/mindspore/ccsrc/fl/server/collective_ops_impl.h b/mindspore/ccsrc/fl/server/collective_ops_impl.h
index 1e77d9da018..4d8f2d99585 100644
--- a/mindspore/ccsrc/fl/server/collective_ops_impl.h
+++ b/mindspore/ccsrc/fl/server/collective_ops_impl.h
@@ -29,6 +29,9 @@ namespace mindspore {
 namespace fl {
 namespace server {
+// The timeout for server collective communication in case of network jitter.
+constexpr uint32_t kCollectiveCommTimeout = 30;
+
 // CollectiveOpsImpl is the collective communication API of the server.
 // For now, it implements two AllReduce algorithms: RingAllReduce and BroadcastAllReduce. Elastic AllReduce is also
 // supported for the elastic scaling feature of the server.
diff --git a/mindspore/ccsrc/fl/server/iteration.cc b/mindspore/ccsrc/fl/server/iteration.cc
index 0ac831f3d22..da99f82dcff 100644
--- a/mindspore/ccsrc/fl/server/iteration.cc
+++ b/mindspore/ccsrc/fl/server/iteration.cc
@@ -274,7 +274,7 @@ bool Iteration::DisableServerInstance(std::string *result) {
   instance_state_ = InstanceState::kDisable;
   if (!ForciblyMoveToNextIteration()) {
     *result = "Disabling instance failed. Can't drop current iteration and move to the next.";
-    MS_LOG(ERROR) << result;
+    MS_LOG(ERROR) << *result;
     return false;
   }
   *result = "Disabling FL-Server succeeded.";
@@ -310,6 +310,9 @@ bool Iteration::NewInstance(const nlohmann::json &new_instance_json, std::string *result) {
   iteration_num_ = 1;
   LocalMetaStore::GetInstance().set_curr_iter_num(iteration_num_);
   ModelStore::GetInstance().Reset();
+  if (metrics_ != nullptr) {
+    metrics_->Clear();
+  }
 
   // Update the hyper-parameters on server and reinitialize rounds.
   if (!UpdateHyperParams(new_instance_json)) {
@@ -516,13 +519,6 @@ bool Iteration::BroadcastMoveToNextIterRequest(bool is_last_iter_valid, const std::string &reason) {
 void Iteration::HandleMoveToNextIterRequest(const std::shared_ptr<ps::core::MessageHandler> &message) {
   MS_ERROR_IF_NULL_WO_RET_VAL(message);
   MS_ERROR_IF_NULL_WO_RET_VAL(communicator_);
-  MoveToNextIterResponse proceed_to_next_iter_rsp;
-  proceed_to_next_iter_rsp.set_result("success");
-  if (!communicator_->SendResponse(proceed_to_next_iter_rsp.SerializeAsString().data(),
-                                   proceed_to_next_iter_rsp.SerializeAsString().size(), message)) {
-    MS_LOG(ERROR) << "Sending response failed.";
-    return;
-  }

   MoveToNextIterRequest proceed_to_next_iter_req;
   (void)proceed_to_next_iter_req.ParseFromArray(message->data(), SizeToInt(message->len()));
@@ -536,6 +532,14 @@ void Iteration::HandleMoveToNextIterRequest(const std::shared_ptr<ps::core::MessageHandler> &message) {
+
+  MoveToNextIterResponse proceed_to_next_iter_rsp;
+  proceed_to_next_iter_rsp.set_result("success");
+  if (!communicator_->SendResponse(proceed_to_next_iter_rsp.SerializeAsString().data(),
+                                   proceed_to_next_iter_rsp.SerializeAsString().size(), message)) {
+    MS_LOG(ERROR) << "Sending response failed.";
+    return;
+  }
 }

 void Iteration::Next(bool is_iteration_valid, const std::string &reason) {
@@ -548,7 +552,9 @@ void Iteration::Next(bool is_iteration_valid, const std::string &reason) {
     MS_LOG(INFO) << "Iteration " << iteration_num_ << " is successfully finished.";
   } else {
     // Store last iteration's model because this iteration is considered as invalid.
-    const auto &model = ModelStore::GetInstance().GetModelByIterNum(iteration_num_ - 1);
+    const auto &iter_to_model = ModelStore::GetInstance().iteration_to_model();
+    size_t latest_iter_num = iter_to_model.rbegin()->first;
+    const auto &model = ModelStore::GetInstance().GetModelByIterNum(latest_iter_num);
     ModelStore::GetInstance().StoreModelByIterNum(iteration_num_, model);
     MS_LOG(WARNING) << "Iteration " << iteration_num_ << " is invalid. Reason: " << reason;
   }
@@ -642,7 +648,7 @@ bool Iteration::SummarizeIteration() {
   metrics_->set_fl_name(ps::PSContext::instance()->fl_name());
   metrics_->set_fl_iteration_num(ps::PSContext::instance()->fl_iteration_num());
-  metrics_->set_cur_iteration_num(iteration_num_ - 1);
+  metrics_->set_cur_iteration_num(iteration_num_);
   metrics_->set_instance_state(instance_state_.load());
   metrics_->set_loss(loss_);
   metrics_->set_accuracy(accuracy_);
diff --git a/mindspore/ccsrc/fl/server/iteration_metrics.cc b/mindspore/ccsrc/fl/server/iteration_metrics.cc
index b77d196477d..33c8fda5180 100644
--- a/mindspore/ccsrc/fl/server/iteration_metrics.cc
+++ b/mindspore/ccsrc/fl/server/iteration_metrics.cc
@@ -54,21 +54,20 @@ bool IterationMetrics::Initialize() {
     }

     // Parse storage file path.
-    std::string metrics_file_path = JsonGetKeyWithException(value_json, ps::kStoreFilePath);
-    auto realpath = Common::GetRealPath(metrics_file_path);
+    metrics_file_path_ = JsonGetKeyWithException(value_json, ps::kStoreFilePath);
+    auto realpath = Common::GetRealPath(metrics_file_path_);
     if (!realpath.has_value()) {
-      MS_LOG(EXCEPTION) << "Get real path for " << metrics_file_path << " failed.";
+      MS_LOG(EXCEPTION) << "Get real path for " << metrics_file_path_ << " failed.";
       return false;
     }
-    metrics_file_.open(metrics_file_path, std::ios::ate | std::ios::out);
+    metrics_file_.open(metrics_file_path_, std::ios::ate | std::ios::out);
   }
   return true;
 }

 bool IterationMetrics::Summarize() {
   if (!metrics_file_.is_open()) {
-    metrics_file_.clear();
     MS_LOG(ERROR) << "The metrics file is not opened.";
     return false;
   }
@@ -87,7 +86,14 @@ bool IterationMetrics::Summarize() {
   return true;
 }

-bool IterationMetrics::Clear() { return true; }
+bool IterationMetrics::Clear() {
+  if (metrics_file_.is_open()) {
+    MS_LOG(INFO) << "Clear the old metrics file " << metrics_file_path_;
+    metrics_file_.close();
+    metrics_file_.open(metrics_file_path_, std::ios::ate | std::ios::out);
+  }
+  return true;
+}

 void IterationMetrics::set_fl_name(const std::string &fl_name) { fl_name_ = fl_name; }
diff --git a/mindspore/ccsrc/fl/server/iteration_metrics.h b/mindspore/ccsrc/fl/server/iteration_metrics.h
index 596f352771e..de011ed4ae5 100644
--- a/mindspore/ccsrc/fl/server/iteration_metrics.h
+++ b/mindspore/ccsrc/fl/server/iteration_metrics.h
@@ -96,6 +96,9 @@ class IterationMetrics {
   // The metrics file object.
   std::fstream metrics_file_;

+  // The metrics file path.
+  std::string metrics_file_path_;
+
   // Json object of metrics data.
   nlohmann::basic_json js_;
diff --git a/mindspore/ccsrc/fl/worker/fl_worker.cc b/mindspore/ccsrc/fl/worker/fl_worker.cc
index 8acdf15b455..412d88133c4 100644
--- a/mindspore/ccsrc/fl/worker/fl_worker.cc
+++ b/mindspore/ccsrc/fl/worker/fl_worker.cc
@@ -115,8 +115,8 @@ bool FLWorker::SendToServer(uint32_t server_rank, const void *data, size_t size, ps::core::TcpUserCommand command, std::shared_ptr<std::vector<unsigned char>> *output) {
   if (output != nullptr) {
     while (true) {
-      if (!worker_node_->Send(ps::core::NodeRole::SERVER, server_rank, message, size, static_cast<int32_t>(command),
-                              output)) {
+      if (!worker_node_->Send(ps::core::NodeRole::SERVER, server_rank, message, size, static_cast<int32_t>(command), output,
+                              kWorkerTimeout)) {
         MS_LOG(ERROR) << "Sending message to server " << server_rank << " failed.";
         return false;
       }
@@ -134,7 +134,8 @@ bool FLWorker::SendToServer(uint32_t server_rank, const void *data, size_t size, ps::core::TcpUserCommand command, std::shared_ptr<std::vector<unsigned char>> *output) {
       }
     }
   } else {
-    if (!worker_node_->Send(ps::core::NodeRole::SERVER, server_rank, message, size, static_cast<int32_t>(command))) {
+    if (!worker_node_->Send(ps::core::NodeRole::SERVER, server_rank, message, size, static_cast<int32_t>(command),
+                            kWorkerTimeout)) {
       MS_LOG(ERROR) << "Sending message to server " << server_rank << " failed.";
       return false;
     }
diff --git a/mindspore/ccsrc/fl/worker/fl_worker.h b/mindspore/ccsrc/fl/worker/fl_worker.h
index 4b0fc9e2fde..04f7d12a7dc 100644
--- a/mindspore/ccsrc/fl/worker/fl_worker.h
+++ b/mindspore/ccsrc/fl/worker/fl_worker.h
@@ -46,6 +46,9 @@ constexpr uint32_t kWorkerRetryDurationForSafeMode = 500;
 // The rank of the leader server.
 constexpr uint32_t kLeaderServerRank = 0;

+// The timeout for worker sending message to server in case of network jitter.
+constexpr uint32_t kWorkerTimeout = 30;
+
 enum class IterationState {
   // This iteration is still in process.
   kRunning,
diff --git a/mindspore/context.py b/mindspore/context.py
index c37cead87bf..698b4379855 100644
--- a/mindspore/context.py
+++ b/mindspore/context.py
@@ -872,6 +872,9 @@ def set_fl_context(**kwargs):
             Default: 'NOT_ENCRYPT'.
         config_file_path (string): Configuration file path used by recovery. Default: ''.
         scheduler_manage_port (int): scheduler manage port used to scale out/in. Default: 11202.
+        enable_ssl (bool): Set PS SSL mode enabled or disabled. Default: true.
+        client_password (str): Password to decrypt the secret key stored in the client certificate.
+        server_password (str): Password to decrypt the secret key stored in the server certificate.

     Raises:
         ValueError: If input key is not the attribute in federated learning mode context.
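The three new docstring entries correspond to keyword arguments accepted by set_fl_context. A hedged usage sketch follows; the passwords and path are placeholders, and the call only takes effect in a real federated deployment:

```python
import mindspore.context as context

# Enable SSL between FL workers and servers; credentials below are
# placeholders, not working values.
context.set_fl_context(enable_ssl=True,
                       client_password="placeholder-client-pass",
                       server_password="placeholder-server-pass",
                       config_file_path="/path/to/config.json")
```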
diff --git a/mindspore/ops/operations/__init__.py b/mindspore/ops/operations/__init__.py
index 2da62433ea3..afbb3aa1f96 100644
--- a/mindspore/ops/operations/__init__.py
+++ b/mindspore/ops/operations/__init__.py
@@ -92,7 +92,7 @@ from ._quant_ops import *
 from .other_ops import (Assign, InplaceAssign, IOU, BoundingBoxDecode, BoundingBoxEncode,
                         ConfusionMatrix, PopulationCount, UpdateState, Load,
                         CheckValid, Partial, Depend, identity, CheckBprop, Push, Pull, PullWeight, PushWeight,
-                        StartFLJob, UpdateModel, GetModel, PyFunc)
+                        PushMetrics, StartFLJob, UpdateModel, GetModel, PyFunc)
 from ._thor_ops import (CusBatchMatMul, CusCholeskyTrsm, CusFusedAbsMax1, CusImg2Col, CusMatMulCubeDenseLeft,
                         CusMatMulCubeFraczRightMul, CusMatMulCube, CusMatrixCombine, CusTranspose02314,
                         CusMatMulCubeDenseRight, CusMatMulCubeFraczLeftCast, Im2Col, NewIm2Col,
diff --git a/mindspore/ops/operations/other_ops.py b/mindspore/ops/operations/other_ops.py
index a5aa95bf458..1c639c163c0 100644
--- a/mindspore/ops/operations/other_ops.py
+++ b/mindspore/ops/operations/other_ops.py
@@ -778,6 +778,32 @@ class PushWeight(PrimitiveWithInfer):
         return mstype.float32


+class PushMetrics(PrimitiveWithInfer):
+    """
+    Push metrics like loss and accuracy for federated learning worker.
+
+    Inputs:
+        - **loss** (Tensor) - The loss.
+        - **accuracy** (Tensor) - The accuracy.
+
+    Outputs:
+        None.
+    """
+
+    @prim_attr_register
+    def __init__(self):
+        """Initialize PushMetrics"""
+        self.add_prim_attr("primitive_target", "CPU")
+        self.add_prim_attr("side_effect_mem", True)
+        self.init_prim_io_names(inputs=["loss", "accuracy"], outputs=["result"])
+
+    def infer_shape(self, loss, accuracy):
+        return [1]
+
+    def infer_dtype(self, loss, accuracy):
+        return mstype.float32
+
+
 class StartFLJob(PrimitiveWithInfer):
     """
     StartFLJob for federated learning worker.
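PushMetrics is registered as a CPU primitive with a memory side effect, so the executor treats the push as stateful and keeps it in the graph. A minimal usage sketch, mirroring the Cell wrapper the tests add later in this patch; it only does useful work on an MS_WORKER process:

```python
import numpy as np
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common import dtype as mstype
from mindspore.ops import operations as P

class PushMetricsCell(nn.Cell):
    """Wraps the PushMetrics primitive so it can be called like a Cell."""
    def __init__(self):
        super(PushMetricsCell, self).__init__()
        self.push_metrics = P.PushMetrics()

    def construct(self, loss, acc):
        return self.push_metrics(loss, acc)

push = PushMetricsCell()
# Scalar tensors for loss and accuracy; values are illustrative.
push(Tensor(np.array(0.1), mstype.float32), Tensor(np.array(0.9), mstype.float32))
```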
parser.add_argument("--disaster_recovery_server_port", type=int, default=10976) +parser.add_argument("--node_id", type=str, default="") parser.add_argument("--start_fl_job_threshold", type=int, default=1) parser.add_argument("--start_fl_job_time_window", type=int, default=3000) parser.add_argument("--update_model_ratio", type=float, default=1.0) @@ -49,6 +51,7 @@ parser.add_argument("--dp_eps", type=float, default=50.0) parser.add_argument("--dp_delta", type=float, default=0.01) # usually equals 1/start_fl_job_threshold parser.add_argument("--dp_norm_clip", type=float, default=1.0) parser.add_argument("--encrypt_type", type=str, default="NOT_ENCRYPT") +parser.add_argument("--config_file_path", type=str, default="") parser.add_argument("--client_password", type=str, default="") parser.add_argument("--server_password", type=str, default="") parser.add_argument("--enable_ssl", type=ast.literal_eval, default=False) @@ -62,6 +65,7 @@ server_num = args.server_num scheduler_ip = args.scheduler_ip scheduler_port = args.scheduler_port disaster_recovery_server_port = args.disaster_recovery_server_port +node_id = args.node_id start_fl_job_threshold = args.start_fl_job_threshold start_fl_job_time_window = args.start_fl_job_time_window update_model_ratio = args.update_model_ratio @@ -72,6 +76,7 @@ client_epoch_num = args.client_epoch_num client_batch_size = args.client_batch_size client_learning_rate = args.client_learning_rate local_server_num = args.local_server_num +config_file_path = args.config_file_path root_first_ca_path = args.root_first_ca_path root_second_ca_path = args.root_second_ca_path pki_verify = args.pki_verify @@ -94,11 +99,12 @@ offline_cmd = "ps_demo_id=`ps -ef | grep " + str(disaster_recovery_server_port) offline_cmd += " && for id in $ps_demo_id; do kill -9 $id && echo \"Killed server process: $id\"; done" subprocess.call(['bash', '-c', offline_cmd]) -#Step 2: Wait 35 seconds for recovery. -wait_cmd = "echo \"Start to sleep for 35 seconds\" && sleep 35" +#Step 2: Wait 3 seconds for recovery. +wait_cmd = "echo \"Start to sleep for 3 seconds\" && sleep 3" subprocess.call(['bash', '-c', wait_cmd]) #Step 3: Launch the server again with the same fl server port. 
+os.environ['MS_NODE_ID'] = str(node_id)
 cmd_server = "execute_path=$(pwd) && self_path=$(dirname \"${script_self}\") && "
 cmd_server += "rm -rf ${execute_path}/disaster_recovery_server_" + str(disaster_recovery_server_port) + "/ &&"
 cmd_server += "mkdir ${execute_path}/disaster_recovery_server_" + str(disaster_recovery_server_port) + "/ &&"
@@ -126,6 +132,7 @@ cmd_server += " --dp_eps=" + str(dp_eps)
 cmd_server += " --dp_delta=" + str(dp_delta)
 cmd_server += " --dp_norm_clip=" + str(dp_norm_clip)
 cmd_server += " --encrypt_type=" + str(encrypt_type)
+cmd_server += " --config_file_path=" + str(config_file_path)
 cmd_server += " --root_first_ca_path=" + str(root_first_ca_path)
 cmd_server += " --root_second_ca_path=" + str(root_second_ca_path)
 cmd_server += " --pki_verify=" + str(pki_verify)
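Step 3 above pins the restarted server's identity through the MS_NODE_ID environment variable so it rejoins the cluster as the node that was killed. The same pattern in isolation, with a placeholder node id and an echo command standing in for the real launch:

```python
import os
import subprocess

node_id = "server_0"  # placeholder; the script takes this from --node_id
env = dict(os.environ, MS_NODE_ID=str(node_id))
# The relaunched process inherits MS_NODE_ID and can recover its former role.
subprocess.call(['bash', '-c', 'echo "MS_NODE_ID=$MS_NODE_ID"'], env=env)
```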
diff --git a/tests/st/fl/hybrid_lenet/simulator.py b/tests/st/fl/hybrid_lenet/simulator.py
index 976bf1361db..30817e48706 100644
--- a/tests/st/fl/hybrid_lenet/simulator.py
+++ b/tests/st/fl/hybrid_lenet/simulator.py
@@ -41,6 +41,9 @@ server_num = args.server_num

 str_fl_id = 'fl_lenet_' + str(pid)

+server_not_available_rsp = ["The cluster is in safemode.",
+                            "The server's training job is disabled or finished."]
+
 def generate_port():
     if not use_elb:
         return http_port
@@ -156,7 +159,7 @@ def start_fl_job():
     print("Start fl job url is ", url)
     x = session.post(url, data=build_start_fl_job())
-    if x.text == "The cluster is in safemode.":
+    if x.text in server_not_available_rsp:
         start_fl_job_result['reason'] = "Restart iteration."
         start_fl_job_result['next_ts'] = datetime_to_timestamp(datetime.datetime.now()) + 500
         print("Start fl job when safemode.")
@@ -185,7 +188,7 @@ def update_model(iteration):
     print("Update model url:", url, ", iteration:", iteration)
     update_model_buf, update_model_np_data = build_update_model(iteration)
     x = session.post(url, data=update_model_buf)
-    if x.text == "The cluster is in safemode.":
+    if x.text in server_not_available_rsp:
         update_model_result['reason'] = "Restart iteration."
         update_model_result['next_ts'] = datetime_to_timestamp(datetime.datetime.now()) + 500
         print("Update model when safemode.")
@@ -214,7 +217,7 @@ def get_model(iteration, update_model_data):
     while True:
         x = session.post(url, data=build_get_model(iteration))
-        if x.text == "The cluster is in safemode.":
+        if x.text in server_not_available_rsp:
             print("Get model when safemode.")
             time.sleep(0.5)
             continue
@@ -271,8 +274,5 @@ while True:
         time.sleep(duration / 1000)
         continue

-    if current_iteration == 1:
-        time.sleep(2)
-
     print("")
     sys.stdout.flush()
diff --git a/tests/st/fl/hybrid_lenet/src/model.py b/tests/st/fl/hybrid_lenet/src/model.py
index aba4940499e..352118736d4 100644
--- a/tests/st/fl/hybrid_lenet/src/model.py
+++ b/tests/st/fl/hybrid_lenet/src/model.py
@@ -14,6 +14,7 @@
 # ============================================================================

 import mindspore.nn as nn
+from mindspore.ops import operations as P
 from mindspore.common.initializer import TruncatedNormal

 def conv(in_channels, out_channels, kernel_size, stride=1, padding=0):
@@ -70,3 +71,12 @@ class LeNet5(nn.Cell):
         x = self.relu(x)
         x = self.fc3(x)
         return x
+
+class PushMetrics(nn.Cell):
+    def __init__(self):
+        super(PushMetrics, self).__init__()
+        self.push_metrics = P.PushMetrics()
+
+    def construct(self, loss, acc):
+        x = self.push_metrics(loss, acc)
+        return x
diff --git a/tests/st/fl/hybrid_lenet/test_hybrid_train_lenet.py b/tests/st/fl/hybrid_lenet/test_hybrid_train_lenet.py
index ab5c8bf6fa6..744c36efa61 100644
--- a/tests/st/fl/hybrid_lenet/test_hybrid_train_lenet.py
+++ b/tests/st/fl/hybrid_lenet/test_hybrid_train_lenet.py
@@ -20,9 +20,10 @@ import numpy as np
 import mindspore.context as context
 import mindspore.nn as nn
 from mindspore import Tensor
+from mindspore.common import dtype as mstype
 from mindspore.nn import WithLossCell
 from src.cell_wrapper import TrainOneStepCellWithServerCommunicator
-from src.model import LeNet5
+from src.model import LeNet5, PushMetrics
 # from src.adam import AdamWeightDecayOp

 parser = argparse.ArgumentParser(description="test_hybrid_train_lenet")
@@ -131,6 +132,7 @@ if __name__ == "__main__":
     epoch = 50000
     np.random.seed(0)
     network = LeNet5(62)
+    push_metrics = PushMetrics()
     if context.get_fl_context("ms_role") == "MS_WORKER":
         # Please do not freeze layers if you want to both get and overwrite these layers to servers, which is meaningless.
         network.conv1.weight.requires_grad = False
@@ -153,9 +155,12 @@ if __name__ == "__main__":
     train_network.set_train()
     losses = []

-    for _ in range(epoch):
+    for i in range(epoch):
         data = Tensor(np.random.rand(32, 3, 32, 32).astype(np.float32))
         label = Tensor(np.random.randint(0, 61, (32)).astype(np.int32))
         loss = train_network(data, label).asnumpy()
+        if context.get_fl_context("ms_role") == "MS_WORKER":
+            if (i + 1) % worker_step_num_per_iteration == 0:
+                push_metrics(Tensor(loss, mstype.float32), Tensor(loss, mstype.float32))
         losses.append(loss)
     print(losses)
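The test above pushes metrics once per federated iteration rather than per local step: the worker runs worker_step_num_per_iteration steps and reports at each iteration boundary, reusing the loss value for both metric slots. The timing logic in isolation, with stand-in functions and a placeholder step count:

```python
worker_step_num_per_iteration = 65  # placeholder; the test takes this from its args

def train_step():
    return 0.42  # stand-in for train_network(data, label).asnumpy()

def push_metrics(loss, acc):
    print("pushed metrics:", loss, acc)

for i in range(2 * worker_step_num_per_iteration):
    loss = train_step()
    if (i + 1) % worker_step_num_per_iteration == 0:  # iteration boundary
        push_metrics(loss, loss)
```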
diff --git a/tests/st/fl/mobile/run_server_disaster_recovery.py b/tests/st/fl/mobile/run_server_disaster_recovery.py
index 2bf0f80ad20..09664276884 100644
--- a/tests/st/fl/mobile/run_server_disaster_recovery.py
+++ b/tests/st/fl/mobile/run_server_disaster_recovery.py
@@ -15,6 +15,7 @@
 # The script runs the process of server's disaster recovery. It will kill the server process and launch it again.

+import os
 import ast
 import argparse
 import subprocess
@@ -28,6 +29,7 @@ parser.add_argument("--scheduler_ip", type=str, default="127.0.0.1")
 parser.add_argument("--scheduler_port", type=int, default=8113)
 #The fl server port of the server which needs to be killed.
 parser.add_argument("--disaster_recovery_server_port", type=int, default=10976)
+parser.add_argument("--node_id", type=str, default="")
 parser.add_argument("--start_fl_job_threshold", type=int, default=1)
 parser.add_argument("--start_fl_job_time_window", type=int, default=3000)
 parser.add_argument("--update_model_ratio", type=float, default=1.0)
@@ -61,6 +63,7 @@ server_num = args.server_num
 scheduler_ip = args.scheduler_ip
 scheduler_port = args.scheduler_port
 disaster_recovery_server_port = args.disaster_recovery_server_port
+node_id = args.node_id
 start_fl_job_threshold = args.start_fl_job_threshold
 start_fl_job_time_window = args.start_fl_job_time_window
 update_model_ratio = args.update_model_ratio
@@ -91,11 +94,12 @@ offline_cmd = "ps_demo_id=`ps -ef | grep " + str(disaster_recovery_server_port)
 offline_cmd += " && for id in $ps_demo_id; do kill -9 $id && echo \"Killed server process: $id\"; done"
 subprocess.call(['bash', '-c', offline_cmd])

-#Step 2: Wait 35 seconds for recovery.
-wait_cmd = "echo \"Start to sleep for 35 seconds\" && sleep 35"
+#Step 2: Wait 3 seconds for recovery.
+wait_cmd = "echo \"Start to sleep for 3 seconds\" && sleep 3"
 subprocess.call(['bash', '-c', wait_cmd])

 #Step 3: Launch the server again with the same fl server port.
+os.environ['MS_NODE_ID'] = str(node_id)
 cmd_server = "execute_path=$(pwd) && self_path=$(dirname \"${script_self}\") && "
 cmd_server += "rm -rf ${execute_path}/disaster_recovery_server_" + str(disaster_recovery_server_port) + "/ &&"
 cmd_server += "mkdir ${execute_path}/disaster_recovery_server_" + str(disaster_recovery_server_port) + "/ &&"
diff --git a/tests/st/fl/mobile/simulator.py b/tests/st/fl/mobile/simulator.py
index 976bf1361db..30817e48706 100644
--- a/tests/st/fl/mobile/simulator.py
+++ b/tests/st/fl/mobile/simulator.py
@@ -41,6 +41,9 @@ server_num = args.server_num

 str_fl_id = 'fl_lenet_' + str(pid)

+server_not_available_rsp = ["The cluster is in safemode.",
+                            "The server's training job is disabled or finished."]
+
 def generate_port():
     if not use_elb:
         return http_port
@@ -156,7 +159,7 @@ def start_fl_job():
     print("Start fl job url is ", url)
     x = session.post(url, data=build_start_fl_job())
-    if x.text == "The cluster is in safemode.":
+    if x.text in server_not_available_rsp:
         start_fl_job_result['reason'] = "Restart iteration."
         start_fl_job_result['next_ts'] = datetime_to_timestamp(datetime.datetime.now()) + 500
         print("Start fl job when safemode.")
@@ -185,7 +188,7 @@ def update_model(iteration):
     print("Update model url:", url, ", iteration:", iteration)
     update_model_buf, update_model_np_data = build_update_model(iteration)
     x = session.post(url, data=update_model_buf)
-    if x.text == "The cluster is in safemode.":
+    if x.text in server_not_available_rsp:
         update_model_result['reason'] = "Restart iteration."
         update_model_result['next_ts'] = datetime_to_timestamp(datetime.datetime.now()) + 500
         print("Update model when safemode.")
@@ -214,7 +217,7 @@ def get_model(iteration, update_model_data):
     while True:
         x = session.post(url, data=build_get_model(iteration))
-        if x.text == "The cluster is in safemode.":
+        if x.text in server_not_available_rsp:
             print("Get model when safemode.")
             time.sleep(0.5)
             continue
@@ -271,8 +274,5 @@ while True:
         time.sleep(duration / 1000)
         continue

-    if current_iteration == 1:
-        time.sleep(2)
-
     print("")
     sys.stdout.flush()
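Both simulators now treat two responses as retryable instead of only the safemode message, since a disabled or finished training job is equally transient from the client's point of view. The shared retry shape, with a stand-in for session.post:

```python
import time

server_not_available_rsp = ["The cluster is in safemode.",
                            "The server's training job is disabled or finished."]

def post_to_server():
    return "The cluster is in safemode."  # stand-in for session.post(url, ...).text

for attempt in range(3):
    text = post_to_server()
    if text in server_not_available_rsp:
        time.sleep(0.5)  # back off briefly, then retry the request
        continue
    break
```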