forked from mindspore-Ecosystem/mindspore
Synchronize with enter.
This commit is contained in:
parent
ce0dbcdbf5
commit
85b9ee02c0
|
@ -80,7 +80,7 @@ bool CollectiveOpsImpl::RingAllReduce(const void *sendbuff, void *recvbuff, size
|
|||
|
||||
std::shared_ptr<std::vector<unsigned char>> recv_str;
|
||||
auto recv_req_id = server_node_->CollectiveReceiveAsync(ps::core::NodeRole::SERVER, recv_from_rank, &recv_str);
|
||||
if (!server_node_->CollectiveWait(recv_req_id)) {
|
||||
if (!server_node_->CollectiveWait(recv_req_id, kCollectiveCommTimeout)) {
|
||||
MS_LOG(ERROR) << "CollectiveWait " << recv_req_id << " failed.";
|
||||
return false;
|
||||
}
|
||||
|
@ -95,7 +95,7 @@ bool CollectiveOpsImpl::RingAllReduce(const void *sendbuff, void *recvbuff, size
|
|||
recv_chunk[j] += tmp_recv_chunk[j];
|
||||
}
|
||||
// Step 4: Wait until send is done.
|
||||
if (!server_node_->Wait(send_req_id, 1)) {
|
||||
if (!server_node_->Wait(send_req_id, kCollectiveCommTimeout)) {
|
||||
MS_LOG(ERROR) << "CollectiveWait " << send_req_id << " failed.";
|
||||
return false;
|
||||
}
|
||||
|
@ -117,7 +117,7 @@ bool CollectiveOpsImpl::RingAllReduce(const void *sendbuff, void *recvbuff, size
|
|||
|
||||
std::shared_ptr<std::vector<unsigned char>> recv_str;
|
||||
auto recv_req_id = server_node_->CollectiveReceiveAsync(ps::core::NodeRole::SERVER, recv_from_rank, &recv_str);
|
||||
if (!server_node_->CollectiveWait(recv_req_id)) {
|
||||
if (!server_node_->CollectiveWait(recv_req_id, kCollectiveCommTimeout)) {
|
||||
MS_LOG(ERROR) << "CollectiveWait " << recv_req_id << " failed.";
|
||||
return false;
|
||||
}
|
||||
|
@ -126,7 +126,7 @@ bool CollectiveOpsImpl::RingAllReduce(const void *sendbuff, void *recvbuff, size
|
|||
MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")";
|
||||
return false;
|
||||
}
|
||||
if (!server_node_->Wait(send_req_id, 1)) {
|
||||
if (!server_node_->Wait(send_req_id, kCollectiveCommTimeout)) {
|
||||
MS_LOG(ERROR) << "CollectiveWait " << send_req_id << " failed.";
|
||||
return false;
|
||||
}
|
||||
|
@ -157,7 +157,7 @@ bool CollectiveOpsImpl::ReduceBroadcastAllReduce(const void *sendbuff, void *rec
|
|||
std::shared_ptr<std::vector<unsigned char>> recv_str;
|
||||
MS_LOG(DEBUG) << "Reduce rank 0 receive from rank " << i;
|
||||
auto recv_req_id = server_node_->CollectiveReceiveAsync(ps::core::NodeRole::SERVER, i, &recv_str);
|
||||
if (!server_node_->CollectiveWait(recv_req_id)) {
|
||||
if (!server_node_->CollectiveWait(recv_req_id, kCollectiveCommTimeout)) {
|
||||
MS_LOG(ERROR) << "CollectiveWait " << recv_req_id << " failed.";
|
||||
return false;
|
||||
}
|
||||
|
@ -173,7 +173,7 @@ bool CollectiveOpsImpl::ReduceBroadcastAllReduce(const void *sendbuff, void *rec
|
|||
} else {
|
||||
MS_LOG(DEBUG) << "Reduce send data to rank 0 process.";
|
||||
auto send_req_id = server_node_->CollectiveSendAsync(ps::core::NodeRole::SERVER, 0, sendbuff, count * sizeof(T));
|
||||
if (!server_node_->Wait(send_req_id)) {
|
||||
if (!server_node_->Wait(send_req_id, kCollectiveCommTimeout)) {
|
||||
MS_LOG(ERROR) << "CollectiveWait " << send_req_id << " failed.";
|
||||
return false;
|
||||
}
|
||||
|
@ -187,7 +187,7 @@ bool CollectiveOpsImpl::ReduceBroadcastAllReduce(const void *sendbuff, void *rec
|
|||
MS_LOG(DEBUG) << "Broadcast data to process " << i;
|
||||
auto send_req_id =
|
||||
server_node_->CollectiveSendAsync(ps::core::NodeRole::SERVER, i, output_buff, count * sizeof(T));
|
||||
if (!server_node_->Wait(send_req_id)) {
|
||||
if (!server_node_->Wait(send_req_id, kCollectiveCommTimeout)) {
|
||||
MS_LOG(ERROR) << "CollectiveWait " << send_req_id << " failed.";
|
||||
return false;
|
||||
}
|
||||
|
@ -196,7 +196,7 @@ bool CollectiveOpsImpl::ReduceBroadcastAllReduce(const void *sendbuff, void *rec
|
|||
MS_LOG(DEBUG) << "Broadcast receive from rank 0.";
|
||||
std::shared_ptr<std::vector<unsigned char>> recv_str;
|
||||
auto recv_req_id = server_node_->CollectiveReceiveAsync(ps::core::NodeRole::SERVER, 0, &recv_str);
|
||||
if (!server_node_->CollectiveWait(recv_req_id)) {
|
||||
if (!server_node_->CollectiveWait(recv_req_id, kCollectiveCommTimeout)) {
|
||||
MS_LOG(ERROR) << "CollectiveWait " << recv_req_id << " failed.";
|
||||
return false;
|
||||
}
|
||||
|
|
|
@ -29,6 +29,9 @@
|
|||
namespace mindspore {
|
||||
namespace fl {
|
||||
namespace server {
|
||||
// The timeout for server collective communication in case of network jitter.
|
||||
constexpr uint32_t kCollectiveCommTimeout = 30;
|
||||
|
||||
// CollectiveOpsImpl is the collective communication API of the server.
|
||||
// For now, it implements two AllReduce algorithms: RingAllReduce and BroadcastAllReduce. Elastic AllReduce is also
|
||||
// supported for the elastic scaling feature of the server.
|
||||
|
|
|
@ -274,7 +274,7 @@ bool Iteration::DisableServerInstance(std::string *result) {
|
|||
instance_state_ = InstanceState::kDisable;
|
||||
if (!ForciblyMoveToNextIteration()) {
|
||||
*result = "Disabling instance failed. Can't drop current iteration and move to the next.";
|
||||
MS_LOG(ERROR) << result;
|
||||
MS_LOG(ERROR) << *result;
|
||||
return false;
|
||||
}
|
||||
*result = "Disabling FL-Server succeeded.";
|
||||
|
@ -310,6 +310,9 @@ bool Iteration::NewInstance(const nlohmann::json &new_instance_json, std::string
|
|||
iteration_num_ = 1;
|
||||
LocalMetaStore::GetInstance().set_curr_iter_num(iteration_num_);
|
||||
ModelStore::GetInstance().Reset();
|
||||
if (metrics_ != nullptr) {
|
||||
metrics_->Clear();
|
||||
}
|
||||
|
||||
// Update the hyper-parameters on server and reinitialize rounds.
|
||||
if (!UpdateHyperParams(new_instance_json)) {
|
||||
|
@ -516,13 +519,6 @@ bool Iteration::BroadcastMoveToNextIterRequest(bool is_last_iter_valid, const st
|
|||
void Iteration::HandleMoveToNextIterRequest(const std::shared_ptr<ps::core::MessageHandler> &message) {
|
||||
MS_ERROR_IF_NULL_WO_RET_VAL(message);
|
||||
MS_ERROR_IF_NULL_WO_RET_VAL(communicator_);
|
||||
MoveToNextIterResponse proceed_to_next_iter_rsp;
|
||||
proceed_to_next_iter_rsp.set_result("success");
|
||||
if (!communicator_->SendResponse(proceed_to_next_iter_rsp.SerializeAsString().data(),
|
||||
proceed_to_next_iter_rsp.SerializeAsString().size(), message)) {
|
||||
MS_LOG(ERROR) << "Sending response failed.";
|
||||
return;
|
||||
}
|
||||
|
||||
MoveToNextIterRequest proceed_to_next_iter_req;
|
||||
(void)proceed_to_next_iter_req.ParseFromArray(message->data(), SizeToInt(message->len()));
|
||||
|
@ -536,6 +532,14 @@ void Iteration::HandleMoveToNextIterRequest(const std::shared_ptr<ps::core::Mess
|
|||
// Synchronize the iteration number with leader server.
|
||||
iteration_num_ = last_iter_num;
|
||||
Next(is_last_iter_valid, reason);
|
||||
|
||||
MoveToNextIterResponse proceed_to_next_iter_rsp;
|
||||
proceed_to_next_iter_rsp.set_result("success");
|
||||
if (!communicator_->SendResponse(proceed_to_next_iter_rsp.SerializeAsString().data(),
|
||||
proceed_to_next_iter_rsp.SerializeAsString().size(), message)) {
|
||||
MS_LOG(ERROR) << "Sending response failed.";
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
void Iteration::Next(bool is_iteration_valid, const std::string &reason) {
|
||||
|
@ -548,7 +552,9 @@ void Iteration::Next(bool is_iteration_valid, const std::string &reason) {
|
|||
MS_LOG(INFO) << "Iteration " << iteration_num_ << " is successfully finished.";
|
||||
} else {
|
||||
// Store last iteration's model because this iteration is considered as invalid.
|
||||
const auto &model = ModelStore::GetInstance().GetModelByIterNum(iteration_num_ - 1);
|
||||
const auto &iter_to_model = ModelStore::GetInstance().iteration_to_model();
|
||||
size_t latest_iter_num = iter_to_model.rbegin()->first;
|
||||
const auto &model = ModelStore::GetInstance().GetModelByIterNum(latest_iter_num);
|
||||
ModelStore::GetInstance().StoreModelByIterNum(iteration_num_, model);
|
||||
MS_LOG(WARNING) << "Iteration " << iteration_num_ << " is invalid. Reason: " << reason;
|
||||
}
|
||||
|
@ -642,7 +648,7 @@ bool Iteration::SummarizeIteration() {
|
|||
|
||||
metrics_->set_fl_name(ps::PSContext::instance()->fl_name());
|
||||
metrics_->set_fl_iteration_num(ps::PSContext::instance()->fl_iteration_num());
|
||||
metrics_->set_cur_iteration_num(iteration_num_ - 1);
|
||||
metrics_->set_cur_iteration_num(iteration_num_);
|
||||
metrics_->set_instance_state(instance_state_.load());
|
||||
metrics_->set_loss(loss_);
|
||||
metrics_->set_accuracy(accuracy_);
|
||||
|
|
|
@ -54,21 +54,20 @@ bool IterationMetrics::Initialize() {
|
|||
}
|
||||
|
||||
// Parse storage file path.
|
||||
std::string metrics_file_path = JsonGetKeyWithException<std::string>(value_json, ps::kStoreFilePath);
|
||||
auto realpath = Common::GetRealPath(metrics_file_path);
|
||||
metrics_file_path_ = JsonGetKeyWithException<std::string>(value_json, ps::kStoreFilePath);
|
||||
auto realpath = Common::GetRealPath(metrics_file_path_);
|
||||
if (!realpath.has_value()) {
|
||||
MS_LOG(EXCEPTION) << "Get real path for " << metrics_file_path << " failed.";
|
||||
MS_LOG(EXCEPTION) << "Get real path for " << metrics_file_path_ << " failed.";
|
||||
return false;
|
||||
}
|
||||
|
||||
metrics_file_.open(metrics_file_path, std::ios::ate | std::ios::out);
|
||||
metrics_file_.open(metrics_file_path_, std::ios::ate | std::ios::out);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool IterationMetrics::Summarize() {
|
||||
if (!metrics_file_.is_open()) {
|
||||
metrics_file_.clear();
|
||||
MS_LOG(ERROR) << "The metrics file is not opened.";
|
||||
return false;
|
||||
}
|
||||
|
@ -87,7 +86,14 @@ bool IterationMetrics::Summarize() {
|
|||
return true;
|
||||
}
|
||||
|
||||
bool IterationMetrics::Clear() { return true; }
|
||||
bool IterationMetrics::Clear() {
|
||||
if (metrics_file_.is_open()) {
|
||||
MS_LOG(INFO) << "Clear the old metrics file " << metrics_file_path_;
|
||||
metrics_file_.close();
|
||||
metrics_file_.open(metrics_file_path_, std::ios::ate | std::ios::out);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
void IterationMetrics::set_fl_name(const std::string &fl_name) { fl_name_ = fl_name; }
|
||||
|
||||
|
|
|
@ -96,6 +96,9 @@ class IterationMetrics {
|
|||
// The metrics file object.
|
||||
std::fstream metrics_file_;
|
||||
|
||||
// The metrics file path.
|
||||
std::string metrics_file_path_;
|
||||
|
||||
// Json object of metrics data.
|
||||
nlohmann::basic_json<std::map, std::vector, std::string, bool, int64_t, uint64_t, float> js_;
|
||||
|
||||
|
|
|
@ -115,8 +115,8 @@ bool FLWorker::SendToServer(uint32_t server_rank, const void *data, size_t size,
|
|||
|
||||
if (output != nullptr) {
|
||||
while (true) {
|
||||
if (!worker_node_->Send(ps::core::NodeRole::SERVER, server_rank, message, size, static_cast<int>(command),
|
||||
output)) {
|
||||
if (!worker_node_->Send(ps::core::NodeRole::SERVER, server_rank, message, size, static_cast<int>(command), output,
|
||||
kWorkerTimeout)) {
|
||||
MS_LOG(ERROR) << "Sending message to server " << server_rank << " failed.";
|
||||
return false;
|
||||
}
|
||||
|
@ -134,7 +134,8 @@ bool FLWorker::SendToServer(uint32_t server_rank, const void *data, size_t size,
|
|||
}
|
||||
}
|
||||
} else {
|
||||
if (!worker_node_->Send(ps::core::NodeRole::SERVER, server_rank, message, size, static_cast<int>(command))) {
|
||||
if (!worker_node_->Send(ps::core::NodeRole::SERVER, server_rank, message, size, static_cast<int>(command),
|
||||
kWorkerTimeout)) {
|
||||
MS_LOG(ERROR) << "Sending message to server " << server_rank << " failed.";
|
||||
return false;
|
||||
}
|
||||
|
|
|
@ -46,6 +46,9 @@ constexpr uint32_t kWorkerRetryDurationForSafeMode = 500;
|
|||
// The rank of the leader server.
|
||||
constexpr uint32_t kLeaderServerRank = 0;
|
||||
|
||||
// The timeout for worker sending message to server in case of network jitter.
|
||||
constexpr uint32_t kWorkerTimeout = 30;
|
||||
|
||||
enum class IterationState {
|
||||
// This iteration is still in process.
|
||||
kRunning,
|
||||
|
|
|
@ -872,6 +872,9 @@ def set_fl_context(**kwargs):
|
|||
Default: 'NOT_ENCRYPT'.
|
||||
config_file_path (string): Configuration file path used by recovery. Default: ''.
|
||||
scheduler_manage_port (int): scheduler manage port used to scale out/in. Default: 11202.
|
||||
enable_ssl (bool): Set PS SSL mode enabled or disabled. Default: true.
|
||||
client_password (str): Password to decrypt the secret key stored in the client certificate.
|
||||
server_password (str): Password to decrypt the secret key stored in the server certificate.
|
||||
|
||||
Raises:
|
||||
ValueError: If input key is not the attribute in federated learning mode context.
|
||||
|
|
|
@ -92,7 +92,7 @@ from ._quant_ops import *
|
|||
from .other_ops import (Assign, InplaceAssign, IOU, BoundingBoxDecode, BoundingBoxEncode,
|
||||
ConfusionMatrix, PopulationCount, UpdateState, Load,
|
||||
CheckValid, Partial, Depend, identity, CheckBprop, Push, Pull, PullWeight, PushWeight,
|
||||
StartFLJob, UpdateModel, GetModel, PyFunc)
|
||||
PushMetrics, StartFLJob, UpdateModel, GetModel, PyFunc)
|
||||
from ._thor_ops import (CusBatchMatMul, CusCholeskyTrsm, CusFusedAbsMax1, CusImg2Col, CusMatMulCubeDenseLeft,
|
||||
CusMatMulCubeFraczRightMul, CusMatMulCube, CusMatrixCombine, CusTranspose02314,
|
||||
CusMatMulCubeDenseRight, CusMatMulCubeFraczLeftCast, Im2Col, NewIm2Col,
|
||||
|
|
|
@ -778,6 +778,32 @@ class PushWeight(PrimitiveWithInfer):
|
|||
return mstype.float32
|
||||
|
||||
|
||||
class PushMetrics(PrimitiveWithInfer):
|
||||
"""
|
||||
Push metrics like loss and accuracy for federated learning worker.
|
||||
|
||||
Inputs:
|
||||
- **loss** (Tensor) - The loss.
|
||||
- **accuracy** (Tensor) - The accuracy.
|
||||
|
||||
Outputs:
|
||||
None.
|
||||
"""
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self):
|
||||
"""Initialize PushMetrics"""
|
||||
self.add_prim_attr("primitive_target", "CPU")
|
||||
self.add_prim_attr("side_effect_mem", True)
|
||||
self.init_prim_io_names(inputs=["loss", "accuracy"], outputs=["result"])
|
||||
|
||||
def infer_shape(self, loss, accuracy):
|
||||
return [1]
|
||||
|
||||
def infer_dtype(self, loss, accuracy):
|
||||
return mstype.float32
|
||||
|
||||
|
||||
class StartFLJob(PrimitiveWithInfer):
|
||||
"""
|
||||
StartFLJob for federated learning worker.
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
|
||||
# The script runs the process of server's disaster recovery. It will kill the server process and launch it again.
|
||||
|
||||
import os
|
||||
import ast
|
||||
import argparse
|
||||
import subprocess
|
||||
|
@ -28,6 +29,7 @@ parser.add_argument("--scheduler_ip", type=str, default="127.0.0.1")
|
|||
parser.add_argument("--scheduler_port", type=int, default=8113)
|
||||
#The fl server port of the server which needs to be killed.
|
||||
parser.add_argument("--disaster_recovery_server_port", type=int, default=10976)
|
||||
parser.add_argument("--node_id", type=str, default="")
|
||||
parser.add_argument("--start_fl_job_threshold", type=int, default=1)
|
||||
parser.add_argument("--start_fl_job_time_window", type=int, default=3000)
|
||||
parser.add_argument("--update_model_ratio", type=float, default=1.0)
|
||||
|
@ -49,6 +51,7 @@ parser.add_argument("--dp_eps", type=float, default=50.0)
|
|||
parser.add_argument("--dp_delta", type=float, default=0.01) # usually equals 1/start_fl_job_threshold
|
||||
parser.add_argument("--dp_norm_clip", type=float, default=1.0)
|
||||
parser.add_argument("--encrypt_type", type=str, default="NOT_ENCRYPT")
|
||||
parser.add_argument("--config_file_path", type=str, default="")
|
||||
parser.add_argument("--client_password", type=str, default="")
|
||||
parser.add_argument("--server_password", type=str, default="")
|
||||
parser.add_argument("--enable_ssl", type=ast.literal_eval, default=False)
|
||||
|
@ -62,6 +65,7 @@ server_num = args.server_num
|
|||
scheduler_ip = args.scheduler_ip
|
||||
scheduler_port = args.scheduler_port
|
||||
disaster_recovery_server_port = args.disaster_recovery_server_port
|
||||
node_id = args.node_id
|
||||
start_fl_job_threshold = args.start_fl_job_threshold
|
||||
start_fl_job_time_window = args.start_fl_job_time_window
|
||||
update_model_ratio = args.update_model_ratio
|
||||
|
@ -72,6 +76,7 @@ client_epoch_num = args.client_epoch_num
|
|||
client_batch_size = args.client_batch_size
|
||||
client_learning_rate = args.client_learning_rate
|
||||
local_server_num = args.local_server_num
|
||||
config_file_path = args.config_file_path
|
||||
root_first_ca_path = args.root_first_ca_path
|
||||
root_second_ca_path = args.root_second_ca_path
|
||||
pki_verify = args.pki_verify
|
||||
|
@ -94,11 +99,12 @@ offline_cmd = "ps_demo_id=`ps -ef | grep " + str(disaster_recovery_server_port)
|
|||
offline_cmd += " && for id in $ps_demo_id; do kill -9 $id && echo \"Killed server process: $id\"; done"
|
||||
subprocess.call(['bash', '-c', offline_cmd])
|
||||
|
||||
#Step 2: Wait 35 seconds for recovery.
|
||||
wait_cmd = "echo \"Start to sleep for 35 seconds\" && sleep 35"
|
||||
#Step 2: Wait 3 seconds for recovery.
|
||||
wait_cmd = "echo \"Start to sleep for 3 seconds\" && sleep 3"
|
||||
subprocess.call(['bash', '-c', wait_cmd])
|
||||
|
||||
#Step 3: Launch the server again with the same fl server port.
|
||||
os.environ['MS_NODE_ID'] = str(node_id)
|
||||
cmd_server = "execute_path=$(pwd) && self_path=$(dirname \"${script_self}\") && "
|
||||
cmd_server += "rm -rf ${execute_path}/disaster_recovery_server_" + str(disaster_recovery_server_port) + "/ &&"
|
||||
cmd_server += "mkdir ${execute_path}/disaster_recovery_server_" + str(disaster_recovery_server_port) + "/ &&"
|
||||
|
@ -126,6 +132,7 @@ cmd_server += " --dp_eps=" + str(dp_eps)
|
|||
cmd_server += " --dp_delta=" + str(dp_delta)
|
||||
cmd_server += " --dp_norm_clip=" + str(dp_norm_clip)
|
||||
cmd_server += " --encrypt_type=" + str(encrypt_type)
|
||||
cmd_server += " --config_file_path=" + str(config_file_path)
|
||||
cmd_server += " --root_first_ca_path=" + str(root_first_ca_path)
|
||||
cmd_server += " --root_second_ca_path=" + str(root_second_ca_path)
|
||||
cmd_server += " --pki_verify=" + str(pki_verify)
|
||||
|
|
|
@ -41,6 +41,9 @@ server_num = args.server_num
|
|||
|
||||
str_fl_id = 'fl_lenet_' + str(pid)
|
||||
|
||||
server_not_available_rsp = ["The cluster is in safemode.",
|
||||
"The server's training job is disabled or finished."]
|
||||
|
||||
def generate_port():
|
||||
if not use_elb:
|
||||
return http_port
|
||||
|
@ -156,7 +159,7 @@ def start_fl_job():
|
|||
print("Start fl job url is ", url)
|
||||
|
||||
x = session.post(url, data=build_start_fl_job())
|
||||
if x.text == "The cluster is in safemode.":
|
||||
if x.text in server_not_available_rsp:
|
||||
start_fl_job_result['reason'] = "Restart iteration."
|
||||
start_fl_job_result['next_ts'] = datetime_to_timestamp(datetime.datetime.now()) + 500
|
||||
print("Start fl job when safemode.")
|
||||
|
@ -185,7 +188,7 @@ def update_model(iteration):
|
|||
print("Update model url:", url, ", iteration:", iteration)
|
||||
update_model_buf, update_model_np_data = build_update_model(iteration)
|
||||
x = session.post(url, data=update_model_buf)
|
||||
if x.text == "The cluster is in safemode.":
|
||||
if x.text in server_not_available_rsp:
|
||||
update_model_result['reason'] = "Restart iteration."
|
||||
update_model_result['next_ts'] = datetime_to_timestamp(datetime.datetime.now()) + 500
|
||||
print("Update model when safemode.")
|
||||
|
@ -214,7 +217,7 @@ def get_model(iteration, update_model_data):
|
|||
|
||||
while True:
|
||||
x = session.post(url, data=build_get_model(iteration))
|
||||
if x.text == "The cluster is in safemode.":
|
||||
if x.text in server_not_available_rsp:
|
||||
print("Get model when safemode.")
|
||||
time.sleep(0.5)
|
||||
continue
|
||||
|
@ -271,8 +274,5 @@ while True:
|
|||
time.sleep(duration / 1000)
|
||||
continue
|
||||
|
||||
if current_iteration == 1:
|
||||
time.sleep(2)
|
||||
|
||||
print("")
|
||||
sys.stdout.flush()
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
# ============================================================================
|
||||
|
||||
import mindspore.nn as nn
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.common.initializer import TruncatedNormal
|
||||
|
||||
def conv(in_channels, out_channels, kernel_size, stride=1, padding=0):
|
||||
|
@ -70,3 +71,12 @@ class LeNet5(nn.Cell):
|
|||
x = self.relu(x)
|
||||
x = self.fc3(x)
|
||||
return x
|
||||
|
||||
class PushMetrics(nn.Cell):
|
||||
def __init__(self):
|
||||
super(PushMetrics, self).__init__()
|
||||
self.push_metrics = P.PushMetrics()
|
||||
|
||||
def construct(self, loss, acc):
|
||||
x = self.push_metrics(loss, acc)
|
||||
return x
|
||||
|
|
|
@ -20,9 +20,10 @@ import numpy as np
|
|||
import mindspore.context as context
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.nn import WithLossCell
|
||||
from src.cell_wrapper import TrainOneStepCellWithServerCommunicator
|
||||
from src.model import LeNet5
|
||||
from src.model import LeNet5, PushMetrics
|
||||
# from src.adam import AdamWeightDecayOp
|
||||
|
||||
parser = argparse.ArgumentParser(description="test_hybrid_train_lenet")
|
||||
|
@ -131,6 +132,7 @@ if __name__ == "__main__":
|
|||
epoch = 50000
|
||||
np.random.seed(0)
|
||||
network = LeNet5(62)
|
||||
push_metrics = PushMetrics()
|
||||
if context.get_fl_context("ms_role") == "MS_WORKER":
|
||||
# Please do not freeze layers if you want to both get and overwrite these layers to servers, which is meaningless.
|
||||
network.conv1.weight.requires_grad = False
|
||||
|
@ -153,9 +155,12 @@ if __name__ == "__main__":
|
|||
train_network.set_train()
|
||||
losses = []
|
||||
|
||||
for _ in range(epoch):
|
||||
for i in range(epoch):
|
||||
data = Tensor(np.random.rand(32, 3, 32, 32).astype(np.float32))
|
||||
label = Tensor(np.random.randint(0, 61, (32)).astype(np.int32))
|
||||
loss = train_network(data, label).asnumpy()
|
||||
if context.get_fl_context("ms_role") == "MS_WORKER":
|
||||
if (i + 1) % worker_step_num_per_iteration == 0:
|
||||
push_metrics(Tensor(loss, mstype.float32), Tensor(loss, mstype.float32))
|
||||
losses.append(loss)
|
||||
print(losses)
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
|
||||
# The script runs the process of server's disaster recovery. It will kill the server process and launch it again.
|
||||
|
||||
import os
|
||||
import ast
|
||||
import argparse
|
||||
import subprocess
|
||||
|
@ -28,6 +29,7 @@ parser.add_argument("--scheduler_ip", type=str, default="127.0.0.1")
|
|||
parser.add_argument("--scheduler_port", type=int, default=8113)
|
||||
#The fl server port of the server which needs to be killed.
|
||||
parser.add_argument("--disaster_recovery_server_port", type=int, default=10976)
|
||||
parser.add_argument("--node_id", type=str, default="")
|
||||
parser.add_argument("--start_fl_job_threshold", type=int, default=1)
|
||||
parser.add_argument("--start_fl_job_time_window", type=int, default=3000)
|
||||
parser.add_argument("--update_model_ratio", type=float, default=1.0)
|
||||
|
@ -61,6 +63,7 @@ server_num = args.server_num
|
|||
scheduler_ip = args.scheduler_ip
|
||||
scheduler_port = args.scheduler_port
|
||||
disaster_recovery_server_port = args.disaster_recovery_server_port
|
||||
node_id = args.node_id
|
||||
start_fl_job_threshold = args.start_fl_job_threshold
|
||||
start_fl_job_time_window = args.start_fl_job_time_window
|
||||
update_model_ratio = args.update_model_ratio
|
||||
|
@ -91,11 +94,12 @@ offline_cmd = "ps_demo_id=`ps -ef | grep " + str(disaster_recovery_server_port)
|
|||
offline_cmd += " && for id in $ps_demo_id; do kill -9 $id && echo \"Killed server process: $id\"; done"
|
||||
subprocess.call(['bash', '-c', offline_cmd])
|
||||
|
||||
#Step 2: Wait 35 seconds for recovery.
|
||||
wait_cmd = "echo \"Start to sleep for 35 seconds\" && sleep 35"
|
||||
#Step 2: Wait 3 seconds for recovery.
|
||||
wait_cmd = "echo \"Start to sleep for 3 seconds\" && sleep 3"
|
||||
subprocess.call(['bash', '-c', wait_cmd])
|
||||
|
||||
#Step 3: Launch the server again with the same fl server port.
|
||||
os.environ['MS_NODE_ID'] = str(node_id)
|
||||
cmd_server = "execute_path=$(pwd) && self_path=$(dirname \"${script_self}\") && "
|
||||
cmd_server += "rm -rf ${execute_path}/disaster_recovery_server_" + str(disaster_recovery_server_port) + "/ &&"
|
||||
cmd_server += "mkdir ${execute_path}/disaster_recovery_server_" + str(disaster_recovery_server_port) + "/ &&"
|
||||
|
|
|
@ -41,6 +41,9 @@ server_num = args.server_num
|
|||
|
||||
str_fl_id = 'fl_lenet_' + str(pid)
|
||||
|
||||
server_not_available_rsp = ["The cluster is in safemode.",
|
||||
"The server's training job is disabled or finished."]
|
||||
|
||||
def generate_port():
|
||||
if not use_elb:
|
||||
return http_port
|
||||
|
@ -156,7 +159,7 @@ def start_fl_job():
|
|||
print("Start fl job url is ", url)
|
||||
|
||||
x = session.post(url, data=build_start_fl_job())
|
||||
if x.text == "The cluster is in safemode.":
|
||||
if x.text in server_not_available_rsp:
|
||||
start_fl_job_result['reason'] = "Restart iteration."
|
||||
start_fl_job_result['next_ts'] = datetime_to_timestamp(datetime.datetime.now()) + 500
|
||||
print("Start fl job when safemode.")
|
||||
|
@ -185,7 +188,7 @@ def update_model(iteration):
|
|||
print("Update model url:", url, ", iteration:", iteration)
|
||||
update_model_buf, update_model_np_data = build_update_model(iteration)
|
||||
x = session.post(url, data=update_model_buf)
|
||||
if x.text == "The cluster is in safemode.":
|
||||
if x.text in server_not_available_rsp:
|
||||
update_model_result['reason'] = "Restart iteration."
|
||||
update_model_result['next_ts'] = datetime_to_timestamp(datetime.datetime.now()) + 500
|
||||
print("Update model when safemode.")
|
||||
|
@ -214,7 +217,7 @@ def get_model(iteration, update_model_data):
|
|||
|
||||
while True:
|
||||
x = session.post(url, data=build_get_model(iteration))
|
||||
if x.text == "The cluster is in safemode.":
|
||||
if x.text in server_not_available_rsp:
|
||||
print("Get model when safemode.")
|
||||
time.sleep(0.5)
|
||||
continue
|
||||
|
@ -271,8 +274,5 @@ while True:
|
|||
time.sleep(duration / 1000)
|
||||
continue
|
||||
|
||||
if current_iteration == 1:
|
||||
time.sleep(2)
|
||||
|
||||
print("")
|
||||
sys.stdout.flush()
|
||||
|
|
Loading…
Reference in New Issue