!32365 fix issues I4Z7WC, I502P8, I502TN, I5031D, I503MS, I503SO, I502L2

Merge pull request !32365 from tan-wei-cheng-3260/r1.6-develop3

commit 613939605c

@@ -32,7 +32,7 @@ void CollectiveOpsImpl::Initialize(const std::shared_ptr<ps::core::ServerNode> &
   MS_EXCEPTION_IF_NULL(server_node);
   server_node_ = server_node;
   rank_id_ = server_node_->rank_id();
-  server_num_ = ps::PSContext::instance()->initial_server_num();
+  server_num_ = server_node->server_num();
   return;
 }
@@ -29,7 +29,7 @@ void DistributedCountService::Initialize(const std::shared_ptr<ps::core::ServerN
   MS_EXCEPTION_IF_NULL(server_node);
   server_node_ = server_node;
   local_rank_ = server_node_->rank_id();
-  server_num_ = ps::PSContext::instance()->initial_server_num();
+  server_num_ = server_node->server_num();
   counting_server_rank_ = counting_server_rank;
   return;
 }
@@ -109,6 +109,7 @@ bool DistributedCountService::Count(const std::string &name, const std::string &
     }
     if (!TriggerCounterEvent(name, reason)) {
       MS_LOG(WARNING) << "Leader server trigger count event failed.";
+      Iteration::GetInstance().NotifyNext(false, *reason);
       return false;
     }
   } else {
@@ -18,6 +18,7 @@
 #include <memory>
 #include <string>
 #include <vector>
+#include "fl/server/iteration.h"

 namespace mindspore {
 namespace fl {
@@ -26,7 +27,7 @@ void DistributedMetadataStore::Initialize(const std::shared_ptr<ps::core::Server
   MS_EXCEPTION_IF_NULL(server_node);
   server_node_ = server_node;
   local_rank_ = server_node_->rank_id();
-  server_num_ = ps::PSContext::instance()->initial_server_num();
+  server_num_ = server_node->server_num();
   InitHashRing();
   return;
 }
@@ -109,6 +110,7 @@ bool DistributedMetadataStore::UpdateMetadata(const std::string &name, const PBM
     if (reason != nullptr) {
       *reason = kNetworkError;
     }
+    Iteration::GetInstance().NotifyNext(false, *reason);
     return false;
   }

@@ -323,8 +323,9 @@ bool Iteration::NewInstance(const nlohmann::json &new_instance_json, std::string
   }

   if (iteration_num_ == 1) {
-    MS_LOG(INFO) << "This is just the first iteration.";
-    return true;
+    *result = "This is just the first iteration, do not need to new instance.";
+    MS_LOG(WARNING) << *result;
+    return false;
   }

   // Start new server instance.
@@ -25,7 +25,10 @@ namespace fl {
 namespace server {
 class Server;
 class Iteration;
-std::atomic<uint32_t> kPrintTimes = 0;
+std::atomic<uint32_t> kJobNotReadyPrintTimes = 0;
+std::atomic<uint32_t> kJobNotAvailablePrintTimes = 0;
+std::atomic<uint32_t> kClusterSafeModePrintTimes = 0;

 const uint32_t kPrintTimesThreshold = 3000;
 Round::Round(const std::string &name, bool check_timeout, size_t time_window, bool check_count, size_t threshold_count,
              bool server_num_as_threshold)
@@ -133,8 +136,6 @@ void Round::BindRoundKernel(const std::shared_ptr<kernel::RoundKernel> &kernel)

 void Round::LaunchRoundKernel(const std::shared_ptr<ps::core::MessageHandler> &message) {
   MS_ERROR_IF_NULL_WO_RET_VAL(message);
-  MS_ERROR_IF_NULL_WO_RET_VAL(kernel_);
-
   std::string reason = "";
   if (!IsServerAvailable(&reason)) {
     if (!message->SendResponse(reason.c_str(), reason.size())) {
@@ -143,6 +144,8 @@ void Round::LaunchRoundKernel(const std::shared_ptr<ps::core::MessageHandler> &m
     }
     return;
   }

+  MS_ERROR_IF_NULL_WO_RET_VAL(kernel_);
   (void)(Iteration::GetInstance().running_round_num_++);
   bool ret = kernel_->Launch(reinterpret_cast<const uint8_t *>(message->data()), message->len(), message);
   // Must send response back no matter what value Launch method returns.
@@ -201,25 +204,35 @@ bool Round::IsServerAvailable(std::string *reason) {
     return true;
   }

+  if (!Server::GetInstance().IsReady()) {
+    if (kJobNotReadyPrintTimes % kPrintTimesThreshold == 0) {
+      MS_LOG(WARNING) << "The server's training job is not ready, please retry " + name_ + " later.";
+      kJobNotReadyPrintTimes = 0;
+    }
+    kJobNotReadyPrintTimes += 1;
+    *reason = ps::kJobNotReady;
+    return false;
+  }
+
   // If the server state is Disable or Finish, refuse the request.
   if (Iteration::GetInstance().instance_state() == InstanceState::kDisable ||
       Iteration::GetInstance().instance_state() == InstanceState::kFinish) {
-    if (kPrintTimes % kPrintTimesThreshold == 0) {
+    if (kJobNotAvailablePrintTimes % kPrintTimesThreshold == 0) {
       MS_LOG(WARNING) << "The server's training job is disabled or finished, please retry " + name_ + " later.";
-      kPrintTimes = 0;
+      kJobNotAvailablePrintTimes = 0;
     }
-    kPrintTimes += 1;
+    kJobNotAvailablePrintTimes += 1;
     *reason = ps::kJobNotAvailable;
     return false;
   }

   // If the server is still in safemode, reject the request.
   if (Server::GetInstance().IsSafeMode()) {
-    if (kPrintTimes % kPrintTimesThreshold == 0) {
+    if (kClusterSafeModePrintTimes % kPrintTimesThreshold == 0) {
       MS_LOG(WARNING) << "The cluster is still in safemode, please retry " << name_ << " later.";
-      kPrintTimes = 0;
+      kClusterSafeModePrintTimes = 0;
     }
-    kPrintTimes += 1;
+    kClusterSafeModePrintTimes += 1;
     *reason = ps::kClusterSafeMode;
     return false;
   }
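
Note on the hunk above: the single shared kPrintTimes throttle becomes one atomic counter per rejection reason, so a flood of one rejection type no longer resets or starves the warnings for another. A minimal self-contained sketch of the pattern; function and variable names here are illustrative, not the MindSpore API:

#include <atomic>
#include <cstdint>
#include <iostream>
#include <string>

constexpr uint32_t kPrintTimesThreshold = 3000;

// One atomic counter per condition keeps the throttling independent.
std::atomic<uint32_t> g_not_ready_times{0};
std::atomic<uint32_t> g_safe_mode_times{0};

void WarnThrottled(std::atomic<uint32_t> *counter, const std::string &msg) {
  // Print only every kPrintTimesThreshold-th occurrence of this condition.
  if (*counter % kPrintTimesThreshold == 0) {
    std::cout << "[WARNING] " << msg << std::endl;
    *counter = 0;
  }
  *counter += 1;
}

int main() {
  for (int i = 0; i < 10000; ++i) {
    WarnThrottled(&g_not_ready_times, "The server's training job is not ready.");
    WarnThrottled(&g_safe_mode_times, "The cluster is still in safemode.");
  }
  return 0;
}
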
@@ -89,6 +89,7 @@ void Server::Run() {
   Recover();
   MS_LOG(INFO) << "Server started successfully.";
   safemode_ = false;
+  is_ready_ = true;
   lock.unlock();

   // Wait communicators to stop so the main thread is blocked.
@@ -461,6 +462,17 @@ void Server::StartCommunicator() {
     return;
   }

+  MS_LOG(INFO) << "Start communicator with worker.";
+  (void)std::for_each(communicators_with_worker_.begin(), communicators_with_worker_.end(),
+                      [](const std::shared_ptr<ps::core::CommunicatorBase> &communicator) {
+                        MS_ERROR_IF_NULL_WO_RET_VAL(communicator);
+                        if (typeid(*communicator.get()) != typeid(ps::core::TcpCommunicator)) {
+                          if (!communicator->Start()) {
+                            MS_LOG(EXCEPTION) << "Starting communicator with worker failed.";
+                          }
+                        }
+                      });
+
   MS_EXCEPTION_IF_NULL(server_node_);
   MS_EXCEPTION_IF_NULL(communicator_with_server_);
   MS_LOG(INFO) << "Start communicator with server.";
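
The worker communicators are now started before the communicator with the server, and the TCP communicator is skipped via a runtime type check because it is started through a separate path (the old copy of this loop is deleted in the next hunk). A standalone sketch of the typeid filter, with stand-in classes instead of the real ps::core types:

#include <iostream>
#include <memory>
#include <stdexcept>
#include <typeinfo>
#include <vector>

// Stand-ins for ps::core::CommunicatorBase and its subclasses.
struct CommunicatorBase {
  virtual ~CommunicatorBase() = default;
  virtual bool Start() = 0;
};
struct TcpCommunicator : CommunicatorBase {
  bool Start() override { return true; }  // started elsewhere in the real code
};
struct HttpCommunicator : CommunicatorBase {
  bool Start() override { std::cout << "http communicator started\n"; return true; }
};

void StartWorkerCommunicators(const std::vector<std::shared_ptr<CommunicatorBase>> &comms) {
  for (const auto &communicator : comms) {
    if (communicator == nullptr) continue;
    // Skip the TCP communicator: only the other channels are started here.
    if (typeid(*communicator) != typeid(TcpCommunicator)) {
      if (!communicator->Start()) {
        throw std::runtime_error("Starting communicator with worker failed.");
      }
    }
  }
}

int main() {
  StartWorkerCommunicators({std::make_shared<TcpCommunicator>(), std::make_shared<HttpCommunicator>()});
  return 0;
}
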
@@ -472,15 +484,6 @@ void Server::StartCommunicator() {
   CollectiveOpsImpl::GetInstance().Initialize(server_node_);
   DistributedCountService::GetInstance().Initialize(server_node_, kLeaderServerRank);
   MS_LOG(INFO) << "This server rank is " << server_node_->rank_id();
-
-  MS_LOG(INFO) << "Start communicator with worker.";
-  (void)std::for_each(communicators_with_worker_.begin(), communicators_with_worker_.end(),
-                      [](const std::shared_ptr<ps::core::CommunicatorBase> &communicator) {
-                        MS_ERROR_IF_NULL_WO_RET_VAL(communicator);
-                        if (!communicator->Start()) {
-                          MS_LOG(EXCEPTION) << "Starting communicator with worker failed.";
-                        }
-                      });
 }

 void Server::Recover() {
@@ -695,6 +698,8 @@ void Server::HandleSyncAfterRecoveryRequest(const std::shared_ptr<ps::core::Mess
     }
   }
 }

+bool Server::IsReady() const { return is_ready_.load(); }
+
 }  // namespace server
 }  // namespace fl
 }  // namespace mindspore
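
The new is_ready_ flag separates "out of safemode" from "fully started": Server::Run() sets it only after recovery completes, and Round::IsServerAvailable() rejects requests with kJobNotReady until it is set. A minimal sketch of that readiness gate, assuming simplified names rather than the full Server class:

#include <atomic>
#include <iostream>
#include <string>

class ServerSketch {
 public:
  void Run() {
    // ... initialize communicators, recover persisted state ...
    safemode_ = false;
    is_ready_ = true;  // flipped only once startup has fully succeeded
  }
  bool IsReady() const { return is_ready_.load(); }
  bool IsSafeMode() const { return safemode_.load(); }

  // Request-side check, mirroring what Round::IsServerAvailable does.
  bool Accept(std::string *reason) const {
    if (!IsReady()) {
      *reason = "The server's training job is not ready.";
      return false;
    }
    if (IsSafeMode()) {
      *reason = "The cluster is in safemode.";
      return false;
    }
    return true;
  }

 private:
  std::atomic_bool safemode_{true};
  std::atomic_bool is_ready_{false};
};

int main() {
  ServerSketch server;
  std::string reason;
  std::cout << server.Accept(&reason) << " " << reason << std::endl;  // 0, not ready
  server.Run();
  std::cout << server.Accept(&reason) << std::endl;  // 1
  return 0;
}
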
@@ -78,6 +78,8 @@ class Server {

   bool SubmitTask(std::function<void()> &&task);

+  bool IsReady() const;
+
  private:
   Server()
       : server_node_(nullptr),
@@ -111,7 +113,8 @@ class Server {
         cipher_get_list_sign_cnt_(0),
         minimum_clients_for_reconstruct(0),
         minimum_secret_shares_for_reconstruct(0),
-        cipher_time_window_(0) {}
+        cipher_time_window_(0),
+        is_ready_(false) {}
   ~Server() = default;
   Server(const Server &) = delete;
   Server &operator=(const Server &) = delete;
@@ -249,6 +252,9 @@ class Server {
   size_t minimum_clients_for_reconstruct;
   size_t minimum_secret_shares_for_reconstruct;
   uint64_t cipher_time_window_;
+
+  // The flag that represents whether the server has started successfully.
+  std::atomic_bool is_ready_;
 };
 }  // namespace server
 }  // namespace fl
@@ -258,6 +258,10 @@ using BarrierBeforeScaleIn = std::function<void(void)>;
 using HandlerAfterScaleOut = std::function<void(void)>;
 using HandlerAfterScaleIn = std::function<void(void)>;

+constexpr char kClusterNotReady[] =
+  "The Scheduler's connections are not equal with the total node num. Maybe some server nodes have dropped "
+  "out or scaled-in nodes have not been recycled.";
+constexpr char kJobNotReady[] = "The server's training job is not ready.";
 constexpr char kClusterSafeMode[] = "The cluster is in safemode.";
 constexpr char kJobNotAvailable[] = "The server's training job is disabled or finished.";

@@ -34,6 +34,7 @@ void AbstractNode::Register(const std::shared_ptr<TcpClient> &client) {
   register_message.set_role(node_info_.node_role_);
   register_message.set_ip(node_info_.ip_);
   register_message.set_port(node_info_.port_);
+  register_message.set_fl_iteration_num(PSContext::instance()->fl_iteration_num());

   MS_LOG(INFO) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
                << " the node id:" << node_info_.node_id_ << " begin to register to the scheduler!";
@@ -735,8 +736,6 @@ void AbstractNode::ProcessHeartbeatResp(const std::shared_ptr<MessageMeta> &meta
   if (heartbeat_resp_message.cluster_state() != current_cluster_state_ &&
       current_cluster_state_ != ClusterState::CLUSTER_SCALE_IN &&
       current_cluster_state_ != ClusterState::CLUSTER_SCALE_OUT) {
-    MS_LOG(INFO) << "cluster change state from:" << CommUtil::ClusterStateToString(current_cluster_state_) << " to "
-                 << CommUtil::ClusterStateToString(heartbeat_resp_message.cluster_state());
     UpdateClusterState(heartbeat_resp_message.cluster_state());
   }
   MS_LOG(DEBUG) << "The current cluster state from heartbeat:"
@@ -86,7 +86,7 @@ constexpr char kLibeventLogPrefix[] = "[libevent log]:";

 // Find the corresponding string style of cluster state through the subscript of the enum:ClusterState
 const std::vector<std::string> kClusterState = {
-  "ClUSTER_STARTING",  // Initialization state when the cluster is just started.
+  "CLUSTER_STARTING",  // Initialization state when the cluster is just started.
   "CLUSTER_READY",     // The state after all nodes are successfully registered.
   "CLUSTER_EXIT",      // The state after the cluster exits successfully.
   "NODE_TIMEOUT",      // When a node has a heartbeat timeout
@@ -54,7 +54,7 @@ class Node {
         is_already_finished_(false),
         next_request_id_(0),
         current_node_state_(NodeState::NODE_STARTING),
-        current_cluster_state_(ClusterState::ClUSTER_STARTING) {}
+        current_cluster_state_(ClusterState::CLUSTER_STARTING) {}
   virtual ~Node() = default;

   using MessageCallback = std::function<void()>;
@@ -49,9 +49,10 @@ struct NodeInfo {
   NodeRole node_role_;
   // the current Node rank id,the worker node range is:[0,numOfWorker-1], the server node range is:[0, numOfServer-1]
   uint32_t rank_id_;

   // After the node registration is successful, it is alive.If the node's heartbeat times out, then it is not alive
   bool is_alive;
+  // the number of the fl job iteration
+  size_t fl_iteration_num_;
 };
 }  // namespace core
 }  // namespace ps
@@ -34,13 +34,17 @@ uint32_t NodeManager::checkIfRankIdExist(const RegisterMessage &register_message
   if (registered_nodes_info_.find(node_id) != registered_nodes_info_.end()) {
     const std::string &new_ip = register_message.ip();
     uint32_t new_port = register_message.port();
+    uint32_t new_fl_iteration_num = register_message.fl_iteration_num();
+
     rank_id = registered_nodes_info_[node_id].rank_id_;
     registered_nodes_info_[node_id].is_alive = true;
     registered_nodes_info_[node_id].ip_ = new_ip;
     registered_nodes_info_[node_id].port_ = static_cast<uint16_t>(new_port);
+    registered_nodes_info_[node_id].fl_iteration_num_ = new_fl_iteration_num;
     MS_LOG(WARNING) << "The node id: " << node_id << " is already assigned!"
                     << ", ip: " << register_message.ip() << ", port: " << register_message.port()
                     << ", rank id: " << rank_id << ", alive: " << registered_nodes_info_[node_id].is_alive
+                    << ", fl iteration num: " << new_fl_iteration_num
                     << ", the node_role:" << CommUtil::NodeRoleToString(registered_nodes_info_[node_id].node_role_);
     return rank_id;
   }
@@ -51,14 +55,18 @@ uint32_t NodeManager::checkIfRankIdExist(const RegisterMessage &register_message
   if (recovery_node_infos.find(node_id) != recovery_node_infos.end()) {
     const std::string &new_ip = register_message.ip();
     uint32_t new_port = register_message.port();
+    uint32_t new_fl_iteration_num = register_message.fl_iteration_num();
+
     rank_id = recovery_node_infos[node_id].rank_id_;
     recovery_node_infos[node_id].is_alive = true;
     recovery_node_infos[node_id].ip_ = new_ip;
     recovery_node_infos[node_id].port_ = static_cast<uint16_t>(new_port);
     registered_nodes_info_[node_id] = recovery_node_infos[node_id];
+    registered_nodes_info_[node_id].fl_iteration_num_ = new_fl_iteration_num;
     MS_LOG(INFO) << "The node id: " << node_id << " is recovery successful!"
                  << ", ip: " << recovery_node_infos[node_id].ip_ << ", port: " << recovery_node_infos[node_id].port_
                  << ", rank id: " << rank_id << ", alive: " << recovery_node_infos[node_id].is_alive
+                 << ", fl iteration num: " << new_fl_iteration_num
                  << ", the node_role:" << CommUtil::NodeRoleToString(recovery_node_infos[node_id].node_role_);
     return rank_id;
   }
@@ -79,6 +87,7 @@ uint32_t NodeManager::NextRankId(const RegisterMessage &register_message, const
   }

   const std::string &node_id = register_message.node_id();
+  const size_t fl_iteration_num = register_message.fl_iteration_num();
   // create new rank id
   if (register_message.role() == NodeRole::SERVER) {
     const std::string &ip = register_message.ip();
@@ -105,10 +114,11 @@ uint32_t NodeManager::NextRankId(const RegisterMessage &register_message, const
     node_info.ip_ = ip;
     node_info.port_ = port;
     node_info.is_alive = true;
+    node_info.fl_iteration_num_ = fl_iteration_num;
     registered_nodes_info_[node_id] = node_info;
     MS_LOG(INFO) << "The server node id:" << node_id << ", node ip: " << node_info.ip_ << ", node port:" << port
-                 << " assign rank id:" << rank_id << ", " << (meta_data_->server_num - next_server_rank_id_)
-                 << " servers still need to be registered.";
+                 << ", fl iteration num:" << fl_iteration_num << " assign rank id:" << rank_id << ", "
+                 << (meta_data_->server_num - next_server_rank_id_) << " servers still need to be registered.";
   } else if (register_message.role() == NodeRole::WORKER) {
     const std::string &ip = register_message.ip();
     uint32_t port = register_message.port();
@@ -134,10 +144,11 @@ uint32_t NodeManager::NextRankId(const RegisterMessage &register_message, const
     node_info.ip_ = ip;
     node_info.port_ = port;
     node_info.is_alive = true;
+    node_info.fl_iteration_num_ = fl_iteration_num;
     registered_nodes_info_[node_id] = node_info;
     MS_LOG(INFO) << "The worker node id:" << node_id << ", node ip: " << node_info.ip_ << ", node port:" << port
-                 << " assign rank id:" << rank_id << ", " << (meta_data_->worker_num - next_worker_rank_id_)
-                 << " workers still need to be registered.";
+                 << ", fl iteration num:" << fl_iteration_num << " assign rank id:" << rank_id << ", "
+                 << (meta_data_->worker_num - next_worker_rank_id_) << " workers still need to be registered.";
   }
   return rank_id;
 }
@@ -178,7 +189,7 @@ std::vector<ServersMeta> NodeManager::FetchAllNodesMeta() {
   return servers_meta_list;
 }

-void NodeManager::UpdateCluster() {
+void NodeManager::UpdateCluster(bool is_cluster_ready) {
   // 1. update cluster timeout state
   struct timeval current_time {};
   (void)gettimeofday(&current_time, nullptr);
@@ -205,20 +216,26 @@ void NodeManager::UpdateCluster() {
   } else if (SizeToUint(heartbeats_.size()) == total_node_num_) {
     if (cluster_state_ == ClusterState::NODE_TIMEOUT) {
       for (auto it = registered_nodes_info_.begin(); it != registered_nodes_info_.end(); ++it) {
-        if (registered_nodes_info_.count(it->first)) {
-          registered_nodes_info_[it->first].is_alive = true;
+        if (registered_nodes_info_.count(it->first) && !it->second.is_alive) {
+          MS_LOG(WARNING) << it->second.node_id_ << " is alive.";
+          it->second.is_alive = true;
         }
       }
       if (onPersist_) {
         onPersist_();
       }
-      UpdateClusterState(ClusterState::CLUSTER_READY);
+      if (is_cluster_ready) {
+        UpdateClusterState(ClusterState::CLUSTER_READY);
+      } else {
+        UpdateClusterState(ClusterState::CLUSTER_STARTING);
+      }
     }
   }

   // 2. update cluster finish state
   if (SizeToUint(finish_nodes_id_.size()) == total_node_num_ &&
-      PSContext::instance()->server_mode() != kServerModeHybrid) {
+      PSContext::instance()->server_mode() != kServerModeHybrid &&
+      PSContext::instance()->server_mode() != kServerModeFL) {
     UpdateClusterState(ClusterState::CLUSTER_EXIT);
   }
 }
@@ -330,6 +347,7 @@ bool NodeManager::IsWorker() const {

 bool NodeManager::IsNodeRegistered(const std::string &node_id) {
   if (registered_nodes_info_.find(node_id) != registered_nodes_info_.end()) {
+    MS_LOG(WARNING) << "The node id " << node_id << " has been registered.";
     return true;
   }
   return false;
@@ -381,6 +399,18 @@ void NodeManager::set_next_server_rank_id(const uint32_t &next_server_rank_id) {
   this->next_server_rank_id_ = next_server_rank_id;
 }
 void NodeManager::setPersistCallback(const OnPersist &onPersist) { this->onPersist_ = onPersist; }

+bool NodeManager::VerifyClusterNodesParam() {
+  std::unordered_set<size_t> fl_iteration_num_set;
+  for (auto it = registered_nodes_info_.begin(); it != registered_nodes_info_.end(); ++it) {
+    fl_iteration_num_set.insert(it->second.fl_iteration_num_);
+  }
+  if (fl_iteration_num_set.size() != 1) {
+    MS_LOG(ERROR) << "The server node fl iteration num is inconsistent.";
+    return false;
+  }
+  return true;
+}
 }  // namespace core
 }  // namespace ps
 }  // namespace mindspore
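
VerifyClusterNodesParam() encodes a simple consistency rule: every registered node must report the same fl_iteration_num, which lets the scheduler detect a node that re-registered at a stale iteration. The same check reduced to a standalone function over illustrative data:

#include <cstddef>
#include <iostream>
#include <string>
#include <unordered_map>
#include <unordered_set>

// Each registered node reports the fl job iteration it is at; the cluster is
// only valid when all nodes agree on a single value.
bool VerifyClusterNodesParam(const std::unordered_map<std::string, size_t> &node_iterations) {
  std::unordered_set<size_t> distinct;
  for (const auto &kv : node_iterations) {
    distinct.insert(kv.second);
  }
  return distinct.size() == 1;
}

int main() {
  std::unordered_map<std::string, size_t> nodes = {{"server0", 12}, {"server1", 12}};
  std::cout << std::boolalpha << VerifyClusterNodesParam(nodes) << std::endl;  // true
  nodes["server2"] = 7;  // a node that came back at a stale iteration
  std::cout << VerifyClusterNodesParam(nodes) << std::endl;                    // false
  return 0;
}
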
@@ -50,7 +50,7 @@ class NodeManager {
         next_server_rank_id_(0),
         meta_data_(nullptr),
         node_state_(NodeState::NODE_STARTING),
-        cluster_state_(ClusterState::ClUSTER_STARTING) {}
+        cluster_state_(ClusterState::CLUSTER_STARTING) {}
   virtual ~NodeManager() = default;
   using OnPersist = std::function<void()>;
   // When initializing nodes, the initial number of nodes will be assigned to the total number of nodes.
@@ -63,7 +63,7 @@ class NodeManager {
   // Fetch metadata information of all nodes.
   std::vector<ServersMeta> FetchAllNodesMeta();

-  void UpdateCluster();
+  void UpdateCluster(bool is_cluster_ready);
   void AddFinishNode(const std::string &finish_message);

   // After the scheduler receives the scale_out_done node, it will save this node.
@@ -135,6 +135,8 @@ class NodeManager {

   bool IsAllNodesAlive() const;

+  bool VerifyClusterNodesParam();
+
  private:
   std::mutex node_mutex_;
   std::mutex cluster_mutex_;
@@ -100,6 +100,8 @@ message RegisterMessage {
   string node_id = 3;
   // the role of the node: worker,server,scheduler
   NodeRole role = 4;
+  // the number of the fl job iteration
+  uint64 fl_iteration_num = 5;
 }

 message RegisterRespMessage {
@@ -120,7 +122,7 @@ enum NodeState {
 }

 enum ClusterState {
-  ClUSTER_STARTING = 0;
+  CLUSTER_STARTING = 0;
   CLUSTER_READY = 1;
   CLUSTER_EXIT = 2;
   NODE_TIMEOUT = 3;
@@ -267,7 +267,6 @@ void SchedulerNode::ProcessRegister(const std::shared_ptr<TcpServer> &server,
     return;
   }

-  MS_LOG(INFO) << "The node id is registered.";
   if (connected_nodes_.count(node_id)) {
     (void)connected_nodes_.erase(node_id);
   }
@@ -300,12 +299,16 @@ void SchedulerNode::ProcessRegister(const std::shared_ptr<TcpServer> &server,

   if (node_manager_.IsAllNodesRegistered()) {
     if (!node_manager_.IsAllNodesAlive()) {
-      MS_LOG(ERROR) << "Do not broadcast nodes info because some server nodes are not alive.";
+      MS_LOG(ERROR)
+        << "Do not broadcast nodes info because some server nodes are not alive, and cluster will exit later.";
       return;
     }
-    is_ready_ = true;
-    MS_LOG(INFO) << "There are " << node_manager_.worker_num() << " workers and " << node_manager_.server_num()
-                 << " servers registered to scheduer, so the scheduler send meta data to worker/server.";
+    if (!node_manager_.VerifyClusterNodesParam()) {
+      MS_LOG(ERROR) << "Do not broadcast nodes info because some server nodes info are inconsistent, and cluster "
+                       "will exit later.";
+      return;
+    }

     if (node_manager_.GetClusterState() == ClusterState::CLUSTER_SCALE_IN) {
       auto nodes = node_manager_.nodes_info();
       for (const auto &id : scale_in_node_ids_) {
@@ -325,10 +328,14 @@ void SchedulerNode::ProcessRegister(const std::shared_ptr<TcpServer> &server,
     auto node_infos = node_manager_.nodes_info();
     bool res = SendPrepareBuildingNetwork(node_infos);
     if (!res) {
-      MS_LOG(ERROR) << "Prepare for building network failed!";
+      MS_LOG(ERROR) << "Prepare for building network failed! Cluster will exit later.";
       return;
     }
-    MS_LOG(INFO) << "Prepare for building network success.";
+    is_ready_ = true;
+    MS_LOG(INFO) << "Prepare for building network success. There are " << node_manager_.worker_num() << " workers and "
+                 << node_manager_.server_num()
+                 << " servers registered to scheduler, so the scheduler sends meta data to worker/server.";

     for (const auto &kvs : node_infos) {
       auto client = GetOrCreateClient(kvs.second);
       MS_EXCEPTION_IF_NULL(client);
@@ -609,7 +616,7 @@ void SchedulerNode::StartUpdateClusterStateTimer() {
       node_manager_.UpdateClusterState(ClusterState::CLUSTER_EXIT);
     }
     std::this_thread::sleep_for(std::chrono::seconds(PSContext::instance()->cluster_config().heartbeat_interval));
-    node_manager_.UpdateCluster();
+    node_manager_.UpdateCluster(is_ready_);

     if (node_manager_.GetClusterState() == ClusterState::CLUSTER_EXIT) {
       std::this_thread::sleep_for(
@@ -1020,15 +1027,27 @@ void SchedulerNode::ProcessNewInstance(const std::shared_ptr<HttpMessageHandler>
     return;
   }

-  node_manager_.UpdateClusterState(ClusterState::CLUSTER_READY);
-  nlohmann::json js;
-  js["message"] = "Start new instance successful.";
-  js["code"] = kSuccessCode;
   for (const auto &output : outputs) {
     std::string data = std::string(reinterpret_cast<char *>(output.second->data()), output.second->size());
-    js["result"][output.first] = data;
+    nlohmann::json dataJson = nlohmann::json::parse(data);
+    if (!dataJson["result"]) {
+      res = false;
+      break;
+    }
   }

+  nlohmann::json js;
+  if (res) {
+    js["message"] = "Start new instance successful.";
+    js["code"] = kSuccessCode;
+    js["result"] = true;
+  } else {
+    js["message"] = "Start new instance failed.";
+    js["code"] = kErrorCode;
+    js["result"] = false;
+  }
+
+  node_manager_.UpdateClusterState(ClusterState::CLUSTER_READY);
   resp->AddRespString(js.dump());
   resp->AddRespHeadParam("Content-Type", "application/json");
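
The handler above no longer echoes each server's raw reply; it parses every reply as JSON and reports overall success only if each one carries "result": true. A sketch of that aggregation with nlohmann::json over plain strings (the real code reads the payloads from message buffers):

#include <cstdint>
#include <iostream>
#include <string>
#include <unordered_map>
#include <nlohmann/json.hpp>

// Returns true only if every server reply parses to {"result": true, ...}.
bool AllServersSucceeded(const std::unordered_map<uint32_t, std::string> &replies) {
  for (const auto &kv : replies) {
    nlohmann::json data = nlohmann::json::parse(kv.second);
    if (!data["result"]) {
      return false;
    }
  }
  return true;
}

int main() {
  std::unordered_map<uint32_t, std::string> replies = {
    {0, R"({"result": true})"}, {1, R"({"result": false})"}};
  nlohmann::json js;
  if (AllServersSucceeded(replies)) {
    js["message"] = "Start new instance successful.";
    js["result"] = true;
  } else {
    js["message"] = "Start new instance failed.";
    js["result"] = false;
  }
  std::cout << js.dump() << std::endl;  // reports the failed reply from server 1
  return 0;
}
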
@@ -1073,7 +1092,6 @@ void SchedulerNode::ProcessQueryInstance(const std::shared_ptr<HttpMessageHandle
     resp->ErrorResponse(HTTP_BADREQUEST, status);
     return;
   }
-
   nlohmann::json js;
   js["message"] = "Query Instance successful.";
   js["code"] = kSuccessCode;
@@ -1095,9 +1113,15 @@ void SchedulerNode::ProcessEnableFLS(const std::shared_ptr<HttpMessageHandler> &
   MS_EXCEPTION_IF_NULL(resp);

   RequestProcessResult status(RequestProcessResultCode::kSuccess);
+  if (CheckIfNodeDisconnected()) {
+    ERROR_STATUS(status, RequestProcessResultCode::kSystemError, kClusterNotReady);
+    resp->ErrorResponse(HTTP_BADREQUEST, status);
+    return;
+  }

-  status = CheckIfClusterReady();
-  if (status != RequestProcessResultCode::kSuccess) {
+  if (node_manager_.GetClusterState() != ClusterState::CLUSTER_DISABLE_FLS) {
+    std::string message = "The cluster state is not CLUSTER_DISABLE_FLS, does not need to enable fls.";
+    ERROR_STATUS(status, RequestProcessResultCode::kSystemError, message);
     resp->ErrorResponse(HTTP_BADREQUEST, status);
     return;
   }
@@ -1132,15 +1156,26 @@ void SchedulerNode::ProcessEnableFLS(const std::shared_ptr<HttpMessageHandler> &
     return;
   }

-  node_manager_.UpdateClusterState(ClusterState::CLUSTER_READY);
-  nlohmann::json js;
-  js["message"] = "start enabling FL-Server successful.";
-  js["code"] = kSuccessCode;
   for (const auto &output : outputs) {
     std::string data = std::string(reinterpret_cast<char *>(output.second->data()), output.second->size());
-    js["result"][output.first] = data;
+    nlohmann::json dataJson = nlohmann::json::parse(data);
+    if (!dataJson["result"]) {
+      res = false;
+      break;
+    }
   }

+  nlohmann::json js;
+  if (res) {
+    js["message"] = "start enabling FL-Server successful.";
+    js["code"] = kSuccessCode;
+    js["result"] = true;
+    node_manager_.UpdateClusterState(ClusterState::CLUSTER_READY);
+  } else {
+    js["message"] = "start enabling FL-Server failed.";
+    js["code"] = kErrorCode;
+    js["result"] = false;
+  }
   resp->AddRespString(js.dump());
   resp->AddRespHeadParam("Content-Type", "application/json");
@@ -1152,6 +1187,12 @@ void SchedulerNode::ProcessDisableFLS(const std::shared_ptr<HttpMessageHandler>
   MS_EXCEPTION_IF_NULL(resp);

   RequestProcessResult status(RequestProcessResultCode::kSuccess);
+  if (node_manager_.GetClusterState() == ClusterState::CLUSTER_DISABLE_FLS) {
+    std::string message = "The cluster state is already in CLUSTER_DISABLE_FLS.";
+    ERROR_STATUS(status, RequestProcessResultCode::kSystemError, message);
+    resp->ErrorResponse(HTTP_BADREQUEST, status);
+    return;
+  }

   status = CheckIfClusterReady();
   if (status != RequestProcessResultCode::kSuccess) {
@@ -1159,10 +1200,7 @@ void SchedulerNode::ProcessDisableFLS(const std::shared_ptr<HttpMessageHandler>
     return;
   }

-  node_manager_.UpdateClusterState(ClusterState::CLUSTER_DISABLE_FLS);
-
   uint64_t request_id = AddMessageTrack(node_manager_.server_num());

   std::unordered_map<uint32_t, VectorPtr> outputs;

   set_message_callback(request_id, [&]() {
@@ -1185,19 +1223,29 @@ void SchedulerNode::ProcessDisableFLS(const std::shared_ptr<HttpMessageHandler>
   if (!res) {
     ERROR_STATUS(status, RequestProcessResultCode::kInvalidInputs, "The disable FLS is timeout.");
     resp->ErrorResponse(HTTP_BADREQUEST, status);
-    node_manager_.UpdateClusterState(ClusterState::CLUSTER_READY);
     return;
   }

-  node_manager_.UpdateClusterState(ClusterState::CLUSTER_READY);
-  nlohmann::json js;
-  js["message"] = "start disabling FL-Server successful.";
-  js["code"] = kSuccessCode;
   for (const auto &output : outputs) {
     std::string data = std::string(reinterpret_cast<char *>(output.second->data()), output.second->size());
-    js["result"][output.first] = data;
+    nlohmann::json dataJson = nlohmann::json::parse(data);
+    if (!dataJson["result"]) {
+      res = false;
+      break;
+    }
   }

+  nlohmann::json js;
+  if (res) {
+    js["message"] = "start disabling FL-Server successful.";
+    js["code"] = kSuccessCode;
+    js["result"] = true;
+    node_manager_.UpdateClusterState(ClusterState::CLUSTER_DISABLE_FLS);
+  } else {
+    js["message"] = "start disabling FL-Server failed.";
+    js["code"] = kErrorCode;
+    js["result"] = false;
+  }
   resp->AddRespString(js.dump());
   resp->AddRespHeadParam("Content-Type", "application/json");
@@ -1207,11 +1255,16 @@ void SchedulerNode::ProcessDisableFLS(const std::shared_ptr<HttpMessageHandler>

 RequestProcessResult SchedulerNode::CheckIfClusterReady() {
   RequestProcessResult result(RequestProcessResultCode::kSuccess);
-  if (node_manager_.GetClusterState() != ClusterState::CLUSTER_READY || CheckIfNodeDisconnected()) {
+  if (node_manager_.GetClusterState() != ClusterState::CLUSTER_READY) {
     std::string message = "The cluster is not ready.";
     ERROR_STATUS(result, RequestProcessResultCode::kSystemError, message);
     return result;
   }

+  if (CheckIfNodeDisconnected()) {
+    ERROR_STATUS(result, RequestProcessResultCode::kSystemError, kClusterNotReady);
+    return result;
+  }
   return result;
 }
@@ -76,6 +76,11 @@ public class Common {
     */
    public static final String SAFE_MOD = "The cluster is in safemode.";

+   /**
+    * The tag when server is not ready.
+    */
+   public static final String NOT_READY = "The server's training job is not ready.";
+
    /**
     * The tag when server is not ready.
     */
@@ -328,6 +333,9 @@ public class Common {
                 LOGGER.info(Common.addTag("[isSeverReady] the server does not return the current iteration."));
             }
             return false;
+        } else if (messageStr.contains(NOT_READY)) {
+            LOGGER.info(Common.addTag("[isSeverReady] " + NOT_READY + ", need wait some time and request again"));
+            return false;
         } else {
             return true;
         }