!32365 fix issue I4Z7WC、I502P8、I502TN、I5031D、I503MS、I503SO、I502L2

Merge pull request !32365 from tan-wei-cheng-3260/r1.6-develop3
This commit is contained in:
i-robot 2022-03-31 14:59:37 +00:00 committed by Gitee
commit 613939605c
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
17 changed files with 199 additions and 72 deletions

View File

@ -32,7 +32,7 @@ void CollectiveOpsImpl::Initialize(const std::shared_ptr<ps::core::ServerNode> &
MS_EXCEPTION_IF_NULL(server_node);
server_node_ = server_node;
rank_id_ = server_node_->rank_id();
server_num_ = ps::PSContext::instance()->initial_server_num();
server_num_ = server_node->server_num();
return;
}

View File

@ -29,7 +29,7 @@ void DistributedCountService::Initialize(const std::shared_ptr<ps::core::ServerN
MS_EXCEPTION_IF_NULL(server_node);
server_node_ = server_node;
local_rank_ = server_node_->rank_id();
server_num_ = ps::PSContext::instance()->initial_server_num();
server_num_ = server_node->server_num();
counting_server_rank_ = counting_server_rank;
return;
}
@ -109,6 +109,7 @@ bool DistributedCountService::Count(const std::string &name, const std::string &
}
if (!TriggerCounterEvent(name, reason)) {
MS_LOG(WARNING) << "Leader server trigger count event failed.";
Iteration::GetInstance().NotifyNext(false, *reason);
return false;
}
} else {

View File

@ -18,6 +18,7 @@
#include <memory>
#include <string>
#include <vector>
#include "fl/server/iteration.h"
namespace mindspore {
namespace fl {
@ -26,7 +27,7 @@ void DistributedMetadataStore::Initialize(const std::shared_ptr<ps::core::Server
MS_EXCEPTION_IF_NULL(server_node);
server_node_ = server_node;
local_rank_ = server_node_->rank_id();
server_num_ = ps::PSContext::instance()->initial_server_num();
server_num_ = server_node->server_num();
InitHashRing();
return;
}
@ -109,6 +110,7 @@ bool DistributedMetadataStore::UpdateMetadata(const std::string &name, const PBM
if (reason != nullptr) {
*reason = kNetworkError;
}
Iteration::GetInstance().NotifyNext(false, *reason);
return false;
}

View File

@ -323,8 +323,9 @@ bool Iteration::NewInstance(const nlohmann::json &new_instance_json, std::string
}
if (iteration_num_ == 1) {
MS_LOG(INFO) << "This is just the first iteration.";
return true;
*result = "This is just the first iteration, do not need to new instance.";
MS_LOG(WARNING) << *result;
return false;
}
// Start new server instance.

View File

@ -25,7 +25,10 @@ namespace fl {
namespace server {
class Server;
class Iteration;
std::atomic<uint32_t> kPrintTimes = 0;
// Per-rejection-reason counters used by Round::IsServerAvailable to throttle its warning logs:
// each counter is incremented on every rejected request and reset to 0 whenever its warning is printed.
std::atomic<uint32_t> kJobNotReadyPrintTimes = 0;
std::atomic<uint32_t> kJobNotAvailablePrintTimes = 0;
std::atomic<uint32_t> kClusterSafeModePrintTimes = 0;
// A warning is printed only when the matching counter is a multiple of this threshold.
const uint32_t kPrintTimesThreshold = 3000;
Round::Round(const std::string &name, bool check_timeout, size_t time_window, bool check_count, size_t threshold_count,
bool server_num_as_threshold)
@ -133,8 +136,6 @@ void Round::BindRoundKernel(const std::shared_ptr<kernel::RoundKernel> &kernel)
void Round::LaunchRoundKernel(const std::shared_ptr<ps::core::MessageHandler> &message) {
MS_ERROR_IF_NULL_WO_RET_VAL(message);
MS_ERROR_IF_NULL_WO_RET_VAL(kernel_);
std::string reason = "";
if (!IsServerAvailable(&reason)) {
if (!message->SendResponse(reason.c_str(), reason.size())) {
@ -143,6 +144,8 @@ void Round::LaunchRoundKernel(const std::shared_ptr<ps::core::MessageHandler> &m
}
return;
}
MS_ERROR_IF_NULL_WO_RET_VAL(kernel_);
(void)(Iteration::GetInstance().running_round_num_++);
bool ret = kernel_->Launch(reinterpret_cast<const uint8_t *>(message->data()), message->len(), message);
// Must send response back no matter what value Launch method returns.
@ -201,25 +204,35 @@ bool Round::IsServerAvailable(std::string *reason) {
return true;
}
if (!Server::GetInstance().IsReady()) {
if (kJobNotReadyPrintTimes % kPrintTimesThreshold == 0) {
MS_LOG(WARNING) << "The server's training job is not ready, please retry " + name_ + " later.";
kJobNotReadyPrintTimes = 0;
}
kJobNotReadyPrintTimes += 1;
*reason = ps::kJobNotReady;
return false;
}
// If the server state is Disable or Finish, refuse the request.
if (Iteration::GetInstance().instance_state() == InstanceState::kDisable ||
Iteration::GetInstance().instance_state() == InstanceState::kFinish) {
if (kPrintTimes % kPrintTimesThreshold == 0) {
if (kJobNotAvailablePrintTimes % kPrintTimesThreshold == 0) {
MS_LOG(WARNING) << "The server's training job is disabled or finished, please retry " + name_ + " later.";
kPrintTimes = 0;
kJobNotAvailablePrintTimes = 0;
}
kPrintTimes += 1;
kJobNotAvailablePrintTimes += 1;
*reason = ps::kJobNotAvailable;
return false;
}
// If the server is still in safemode, reject the request.
if (Server::GetInstance().IsSafeMode()) {
if (kPrintTimes % kPrintTimesThreshold == 0) {
if (kClusterSafeModePrintTimes % kPrintTimesThreshold == 0) {
MS_LOG(WARNING) << "The cluster is still in safemode, please retry " << name_ << " later.";
kPrintTimes = 0;
kClusterSafeModePrintTimes = 0;
}
kPrintTimes += 1;
kClusterSafeModePrintTimes += 1;
*reason = ps::kClusterSafeMode;
return false;
}

View File

@ -89,6 +89,7 @@ void Server::Run() {
Recover();
MS_LOG(INFO) << "Server started successfully.";
safemode_ = false;
is_ready_ = true;
lock.unlock();
// Wait communicators to stop so the main thread is blocked.
@ -461,6 +462,17 @@ void Server::StartCommunicator() {
return;
}
MS_LOG(INFO) << "Start communicator with worker.";
(void)std::for_each(communicators_with_worker_.begin(), communicators_with_worker_.end(),
[](const std::shared_ptr<ps::core::CommunicatorBase> &communicator) {
MS_ERROR_IF_NULL_WO_RET_VAL(communicator);
if (typeid(*communicator.get()) != typeid(ps::core::TcpCommunicator)) {
if (!communicator->Start()) {
MS_LOG(EXCEPTION) << "Starting communicator with worker failed.";
}
}
});
MS_EXCEPTION_IF_NULL(server_node_);
MS_EXCEPTION_IF_NULL(communicator_with_server_);
MS_LOG(INFO) << "Start communicator with server.";
@ -472,15 +484,6 @@ void Server::StartCommunicator() {
CollectiveOpsImpl::GetInstance().Initialize(server_node_);
DistributedCountService::GetInstance().Initialize(server_node_, kLeaderServerRank);
MS_LOG(INFO) << "This server rank is " << server_node_->rank_id();
MS_LOG(INFO) << "Start communicator with worker.";
(void)std::for_each(communicators_with_worker_.begin(), communicators_with_worker_.end(),
[](const std::shared_ptr<ps::core::CommunicatorBase> &communicator) {
MS_ERROR_IF_NULL_WO_RET_VAL(communicator);
if (!communicator->Start()) {
MS_LOG(EXCEPTION) << "Starting communicator with worker failed.";
}
});
}
void Server::Recover() {
@ -695,6 +698,8 @@ void Server::HandleSyncAfterRecoveryRequest(const std::shared_ptr<ps::core::Mess
}
}
}
// Returns the current value of the atomic is_ready_ flag.
bool Server::IsReady() const {
  const bool ready = is_ready_.load();
  return ready;
}
} // namespace server
} // namespace fl
} // namespace mindspore

View File

@ -78,6 +78,8 @@ class Server {
bool SubmitTask(std::function<void()> &&task);
bool IsReady() const;
private:
Server()
: server_node_(nullptr),
@ -111,7 +113,8 @@ class Server {
cipher_get_list_sign_cnt_(0),
minimum_clients_for_reconstruct(0),
minimum_secret_shares_for_reconstruct(0),
cipher_time_window_(0) {}
cipher_time_window_(0),
is_ready_(false) {}
~Server() = default;
Server(const Server &) = delete;
Server &operator=(const Server &) = delete;
@ -249,6 +252,9 @@ class Server {
size_t minimum_clients_for_reconstruct;
size_t minimum_secret_shares_for_reconstruct;
uint64_t cipher_time_window_;
// The flag that indicates whether the server has started successfully.
std::atomic_bool is_ready_;
};
} // namespace server
} // namespace fl

View File

@ -258,6 +258,10 @@ using BarrierBeforeScaleIn = std::function<void(void)>;
using HandlerAfterScaleOut = std::function<void(void)>;
using HandlerAfterScaleIn = std::function<void(void)>;
constexpr char kClusterNotReady[] =
"The Scheduler's connections are not equal with total node num, Maybe this is because some server nodes are drop "
"out or scale in nodes has not been recycled.";
constexpr char kJobNotReady[] = "The server's training job is not ready.";
constexpr char kClusterSafeMode[] = "The cluster is in safemode.";
constexpr char kJobNotAvailable[] = "The server's training job is disabled or finished.";

View File

@ -34,6 +34,7 @@ void AbstractNode::Register(const std::shared_ptr<TcpClient> &client) {
register_message.set_role(node_info_.node_role_);
register_message.set_ip(node_info_.ip_);
register_message.set_port(node_info_.port_);
register_message.set_fl_iteration_num(PSContext::instance()->fl_iteration_num());
MS_LOG(INFO) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
<< " the node id:" << node_info_.node_id_ << " begin to register to the scheduler!";
@ -735,8 +736,6 @@ void AbstractNode::ProcessHeartbeatResp(const std::shared_ptr<MessageMeta> &meta
if (heartbeat_resp_message.cluster_state() != current_cluster_state_ &&
current_cluster_state_ != ClusterState::CLUSTER_SCALE_IN &&
current_cluster_state_ != ClusterState::CLUSTER_SCALE_OUT) {
MS_LOG(INFO) << "cluster change state from:" << CommUtil::ClusterStateToString(current_cluster_state_) << " to "
<< CommUtil::ClusterStateToString(heartbeat_resp_message.cluster_state());
UpdateClusterState(heartbeat_resp_message.cluster_state());
}
MS_LOG(DEBUG) << "The current cluster state from heartbeat:"

View File

@ -86,7 +86,7 @@ constexpr char kLibeventLogPrefix[] = "[libevent log]:";
// Find the corresponding string style of cluster state through the subscript of the enum:ClusterState
const std::vector<std::string> kClusterState = {
"ClUSTER_STARTING", // Initialization state when the cluster is just started.
"CLUSTER_STARTING", // Initialization state when the cluster is just started.
"CLUSTER_READY", // The state after all nodes are successfully registered.
"CLUSTER_EXIT", // The state after the cluster exits successfully.
"NODE_TIMEOUT", // When a node has a heartbeat timeout

View File

@ -54,7 +54,7 @@ class Node {
is_already_finished_(false),
next_request_id_(0),
current_node_state_(NodeState::NODE_STARTING),
current_cluster_state_(ClusterState::ClUSTER_STARTING) {}
current_cluster_state_(ClusterState::CLUSTER_STARTING) {}
virtual ~Node() = default;
using MessageCallback = std::function<void()>;

View File

@ -49,9 +49,10 @@ struct NodeInfo {
NodeRole node_role_;
// The current node's rank id. The worker node range is [0, numOfWorker-1]; the server node range is [0, numOfServer-1].
uint32_t rank_id_;
// After the node registers successfully, it is alive. If the node's heartbeat times out, it is no longer alive.
bool is_alive;
// the number of the fl job iteration
size_t fl_iteration_num_;
};
} // namespace core
} // namespace ps

View File

@ -34,13 +34,17 @@ uint32_t NodeManager::checkIfRankIdExist(const RegisterMessage &register_message
if (registered_nodes_info_.find(node_id) != registered_nodes_info_.end()) {
const std::string &new_ip = register_message.ip();
uint32_t new_port = register_message.port();
uint32_t new_fl_iteration_num = register_message.fl_iteration_num();
rank_id = registered_nodes_info_[node_id].rank_id_;
registered_nodes_info_[node_id].is_alive = true;
registered_nodes_info_[node_id].ip_ = new_ip;
registered_nodes_info_[node_id].port_ = static_cast<uint16_t>(new_port);
registered_nodes_info_[node_id].fl_iteration_num_ = new_fl_iteration_num;
MS_LOG(WARNING) << "The node id: " << node_id << " is already assigned!"
<< ", ip: " << register_message.ip() << ", port: " << register_message.port()
<< ", rank id: " << rank_id << ", alive: " << registered_nodes_info_[node_id].is_alive
<< ", fl iteration num: " << new_fl_iteration_num
<< ", the node_role:" << CommUtil::NodeRoleToString(registered_nodes_info_[node_id].node_role_);
return rank_id;
}
@ -51,14 +55,18 @@ uint32_t NodeManager::checkIfRankIdExist(const RegisterMessage &register_message
if (recovery_node_infos.find(node_id) != recovery_node_infos.end()) {
const std::string &new_ip = register_message.ip();
uint32_t new_port = register_message.port();
uint32_t new_fl_iteration_num = register_message.fl_iteration_num();
rank_id = recovery_node_infos[node_id].rank_id_;
recovery_node_infos[node_id].is_alive = true;
recovery_node_infos[node_id].ip_ = new_ip;
recovery_node_infos[node_id].port_ = static_cast<uint16_t>(new_port);
registered_nodes_info_[node_id] = recovery_node_infos[node_id];
registered_nodes_info_[node_id].fl_iteration_num_ = new_fl_iteration_num;
MS_LOG(INFO) << "The node id: " << node_id << " is recovery successful!"
<< ", ip: " << recovery_node_infos[node_id].ip_ << ", port: " << recovery_node_infos[node_id].port_
<< ", rank id: " << rank_id << ", alive: " << recovery_node_infos[node_id].is_alive
<< ", fl iteration num: " << new_fl_iteration_num
<< ", the node_role:" << CommUtil::NodeRoleToString(recovery_node_infos[node_id].node_role_);
return rank_id;
}
@ -79,6 +87,7 @@ uint32_t NodeManager::NextRankId(const RegisterMessage &register_message, const
}
const std::string &node_id = register_message.node_id();
const size_t fl_iteration_num = register_message.fl_iteration_num();
// create new rank id
if (register_message.role() == NodeRole::SERVER) {
const std::string &ip = register_message.ip();
@ -105,10 +114,11 @@ uint32_t NodeManager::NextRankId(const RegisterMessage &register_message, const
node_info.ip_ = ip;
node_info.port_ = port;
node_info.is_alive = true;
node_info.fl_iteration_num_ = fl_iteration_num;
registered_nodes_info_[node_id] = node_info;
MS_LOG(INFO) << "The server node id:" << node_id << ", node ip: " << node_info.ip_ << ", node port:" << port
<< " assign rank id:" << rank_id << ", " << (meta_data_->server_num - next_server_rank_id_)
<< " servers still need to be registered.";
<< ", fl iteration num:" << fl_iteration_num << " assign rank id:" << rank_id << ", "
<< (meta_data_->server_num - next_server_rank_id_) << " servers still need to be registered.";
} else if (register_message.role() == NodeRole::WORKER) {
const std::string &ip = register_message.ip();
uint32_t port = register_message.port();
@ -134,10 +144,11 @@ uint32_t NodeManager::NextRankId(const RegisterMessage &register_message, const
node_info.ip_ = ip;
node_info.port_ = port;
node_info.is_alive = true;
node_info.fl_iteration_num_ = fl_iteration_num;
registered_nodes_info_[node_id] = node_info;
MS_LOG(INFO) << "The worker node id:" << node_id << ", node ip: " << node_info.ip_ << ", node port:" << port
<< " assign rank id:" << rank_id << ", " << (meta_data_->worker_num - next_worker_rank_id_)
<< " workers still need to be registered.";
<< ", fl iteration num:" << fl_iteration_num << " assign rank id:" << rank_id << ", "
<< (meta_data_->worker_num - next_worker_rank_id_) << " workers still need to be registered.";
}
return rank_id;
}
@ -178,7 +189,7 @@ std::vector<ServersMeta> NodeManager::FetchAllNodesMeta() {
return servers_meta_list;
}
void NodeManager::UpdateCluster() {
void NodeManager::UpdateCluster(bool is_cluster_ready) {
// 1. update cluster timeout state
struct timeval current_time {};
(void)gettimeofday(&current_time, nullptr);
@ -205,20 +216,26 @@ void NodeManager::UpdateCluster() {
} else if (SizeToUint(heartbeats_.size()) == total_node_num_) {
if (cluster_state_ == ClusterState::NODE_TIMEOUT) {
for (auto it = registered_nodes_info_.begin(); it != registered_nodes_info_.end(); ++it) {
if (registered_nodes_info_.count(it->first)) {
registered_nodes_info_[it->first].is_alive = true;
if (registered_nodes_info_.count(it->first) && !it->second.is_alive) {
MS_LOG(WARNING) << it->second.node_id_ << " is alive.";
it->second.is_alive = true;
}
}
if (onPersist_) {
onPersist_();
}
UpdateClusterState(ClusterState::CLUSTER_READY);
if (is_cluster_ready) {
UpdateClusterState(ClusterState::CLUSTER_READY);
} else {
UpdateClusterState(ClusterState::CLUSTER_STARTING);
}
}
}
// 2. update cluster finish state
if (SizeToUint(finish_nodes_id_.size()) == total_node_num_ &&
PSContext::instance()->server_mode() != kServerModeHybrid) {
PSContext::instance()->server_mode() != kServerModeHybrid &&
PSContext::instance()->server_mode() != kServerModeFL) {
UpdateClusterState(ClusterState::CLUSTER_EXIT);
}
}
@ -330,6 +347,7 @@ bool NodeManager::IsWorker() const {
bool NodeManager::IsNodeRegistered(const std::string &node_id) {
if (registered_nodes_info_.find(node_id) != registered_nodes_info_.end()) {
MS_LOG(WARNING) << "The node id " << node_id << " has been registered.";
return true;
}
return false;
@ -381,6 +399,18 @@ void NodeManager::set_next_server_rank_id(const uint32_t &next_server_rank_id) {
this->next_server_rank_id_ = next_server_rank_id;
}
void NodeManager::setPersistCallback(const OnPersist &onPersist) { this->onPersist_ = onPersist; }
bool NodeManager::VerifyClusterNodesParam() {
std::unordered_set<size_t> fl_iteration_num_set;
for (auto it = registered_nodes_info_.begin(); it != registered_nodes_info_.end(); ++it) {
fl_iteration_num_set.insert(it->second.fl_iteration_num_);
}
if (fl_iteration_num_set.size() != 1) {
MS_LOG(ERROR) << "The server node fl iteration num is not inconsistent.";
return false;
}
return true;
}
} // namespace core
} // namespace ps
} // namespace mindspore

View File

@ -50,7 +50,7 @@ class NodeManager {
next_server_rank_id_(0),
meta_data_(nullptr),
node_state_(NodeState::NODE_STARTING),
cluster_state_(ClusterState::ClUSTER_STARTING) {}
cluster_state_(ClusterState::CLUSTER_STARTING) {}
virtual ~NodeManager() = default;
using OnPersist = std::function<void()>;
// When initializing nodes, the initial number of nodes will be assigned to the total number of nodes.
@ -63,7 +63,7 @@ class NodeManager {
// Fetch metadata information of all nodes.
std::vector<ServersMeta> FetchAllNodesMeta();
void UpdateCluster();
void UpdateCluster(bool is_cluster_ready);
void AddFinishNode(const std::string &finish_message);
// After the scheduler receives the scale_out_done node, it will save this node.
@ -135,6 +135,8 @@ class NodeManager {
bool IsAllNodesAlive() const;
bool VerifyClusterNodesParam();
private:
std::mutex node_mutex_;
std::mutex cluster_mutex_;

View File

@ -100,6 +100,8 @@ message RegisterMessage {
string node_id = 3;
// the role of the node: worker,server,scheduler
NodeRole role = 4;
// the number of the fl job iteration
uint64 fl_iteration_num = 5;
}
message RegisterRespMessage {
@ -120,7 +122,7 @@ enum NodeState {
}
enum ClusterState {
ClUSTER_STARTING = 0;
CLUSTER_STARTING = 0;
CLUSTER_READY = 1;
CLUSTER_EXIT = 2;
NODE_TIMEOUT = 3;

View File

@ -267,7 +267,6 @@ void SchedulerNode::ProcessRegister(const std::shared_ptr<TcpServer> &server,
return;
}
MS_LOG(INFO) << "The node id is registered.";
if (connected_nodes_.count(node_id)) {
(void)connected_nodes_.erase(node_id);
}
@ -300,12 +299,16 @@ void SchedulerNode::ProcessRegister(const std::shared_ptr<TcpServer> &server,
if (node_manager_.IsAllNodesRegistered()) {
if (!node_manager_.IsAllNodesAlive()) {
MS_LOG(ERROR) << "Do not broadcast nodes info because some server nodes are not alive.";
MS_LOG(ERROR)
<< "Do not broadcast nodes info because some server nodes are not alive, and cluster will exit later.";
return;
}
is_ready_ = true;
MS_LOG(INFO) << "There are " << node_manager_.worker_num() << " workers and " << node_manager_.server_num()
<< " servers registered to scheduer, so the scheduler send meta data to worker/server.";
if (!node_manager_.VerifyClusterNodesParam()) {
MS_LOG(ERROR) << "Do not broadcast nodes info because some server nodes info are not inconsistent, and cluster "
"will exit later.";
return;
}
if (node_manager_.GetClusterState() == ClusterState::CLUSTER_SCALE_IN) {
auto nodes = node_manager_.nodes_info();
for (const auto &id : scale_in_node_ids_) {
@ -325,10 +328,14 @@ void SchedulerNode::ProcessRegister(const std::shared_ptr<TcpServer> &server,
auto node_infos = node_manager_.nodes_info();
bool res = SendPrepareBuildingNetwork(node_infos);
if (!res) {
MS_LOG(ERROR) << "Prepare for building network failed!";
MS_LOG(ERROR) << "Prepare for building network failed! Cluster will exit later.";
return;
}
MS_LOG(INFO) << "Prepare for building network success.";
is_ready_ = true;
MS_LOG(INFO) << "Prepare for building network success. There are " << node_manager_.worker_num() << " workers and "
<< node_manager_.server_num()
<< " servers registered to scheduer, so the scheduler send meta data to worker/server.";
for (const auto &kvs : node_infos) {
auto client = GetOrCreateClient(kvs.second);
MS_EXCEPTION_IF_NULL(client);
@ -609,7 +616,7 @@ void SchedulerNode::StartUpdateClusterStateTimer() {
node_manager_.UpdateClusterState(ClusterState::CLUSTER_EXIT);
}
std::this_thread::sleep_for(std::chrono::seconds(PSContext::instance()->cluster_config().heartbeat_interval));
node_manager_.UpdateCluster();
node_manager_.UpdateCluster(is_ready_);
if (node_manager_.GetClusterState() == ClusterState::CLUSTER_EXIT) {
std::this_thread::sleep_for(
@ -1020,15 +1027,27 @@ void SchedulerNode::ProcessNewInstance(const std::shared_ptr<HttpMessageHandler>
return;
}
node_manager_.UpdateClusterState(ClusterState::CLUSTER_READY);
nlohmann::json js;
js["message"] = "Start new instance successful.";
js["code"] = kSuccessCode;
for (const auto &output : outputs) {
std::string data = std::string(reinterpret_cast<char *>(output.second->data()), output.second->size());
js["result"][output.first] = data;
nlohmann::json dataJson = nlohmann::json::parse(data);
if (!dataJson["result"]) {
res = false;
break;
}
}
nlohmann::json js;
if (res) {
js["message"] = "Start new instance successful.";
js["code"] = kSuccessCode;
js["result"] = true;
} else {
js["message"] = "Start new instance failed.";
js["code"] = kErrorCode;
js["result"] = false;
}
node_manager_.UpdateClusterState(ClusterState::CLUSTER_READY);
resp->AddRespString(js.dump());
resp->AddRespHeadParam("Content-Type", "application/json");
@ -1073,7 +1092,6 @@ void SchedulerNode::ProcessQueryInstance(const std::shared_ptr<HttpMessageHandle
resp->ErrorResponse(HTTP_BADREQUEST, status);
return;
}
nlohmann::json js;
js["message"] = "Query Instance successful.";
js["code"] = kSuccessCode;
@ -1095,9 +1113,15 @@ void SchedulerNode::ProcessEnableFLS(const std::shared_ptr<HttpMessageHandler> &
MS_EXCEPTION_IF_NULL(resp);
RequestProcessResult status(RequestProcessResultCode::kSuccess);
if (CheckIfNodeDisconnected()) {
ERROR_STATUS(status, RequestProcessResultCode::kSystemError, kClusterNotReady);
resp->ErrorResponse(HTTP_BADREQUEST, status);
return;
}
status = CheckIfClusterReady();
if (status != RequestProcessResultCode::kSuccess) {
if (node_manager_.GetClusterState() != ClusterState::CLUSTER_DISABLE_FLS) {
std::string message = "The cluster state is not CLUSTER_DISABLE_FLS, does not need to enable fls.";
ERROR_STATUS(status, RequestProcessResultCode::kSystemError, message);
resp->ErrorResponse(HTTP_BADREQUEST, status);
return;
}
@ -1132,15 +1156,26 @@ void SchedulerNode::ProcessEnableFLS(const std::shared_ptr<HttpMessageHandler> &
return;
}
node_manager_.UpdateClusterState(ClusterState::CLUSTER_READY);
nlohmann::json js;
js["message"] = "start enabling FL-Server successful.";
js["code"] = kSuccessCode;
for (const auto &output : outputs) {
std::string data = std::string(reinterpret_cast<char *>(output.second->data()), output.second->size());
js["result"][output.first] = data;
nlohmann::json dataJson = nlohmann::json::parse(data);
if (!dataJson["result"]) {
res = false;
break;
}
}
nlohmann::json js;
if (res) {
js["message"] = "start enabling FL-Server successful.";
js["code"] = kSuccessCode;
js["result"] = true;
node_manager_.UpdateClusterState(ClusterState::CLUSTER_READY);
} else {
js["message"] = "start enabling FL-Server failed.";
js["code"] = kErrorCode;
js["result"] = false;
}
resp->AddRespString(js.dump());
resp->AddRespHeadParam("Content-Type", "application/json");
@ -1152,6 +1187,12 @@ void SchedulerNode::ProcessDisableFLS(const std::shared_ptr<HttpMessageHandler>
MS_EXCEPTION_IF_NULL(resp);
RequestProcessResult status(RequestProcessResultCode::kSuccess);
if (node_manager_.GetClusterState() == ClusterState::CLUSTER_DISABLE_FLS) {
std::string message = "The cluster state is already in CLUSTER_DISABLE_FLS.";
ERROR_STATUS(status, RequestProcessResultCode::kSystemError, message);
resp->ErrorResponse(HTTP_BADREQUEST, status);
return;
}
status = CheckIfClusterReady();
if (status != RequestProcessResultCode::kSuccess) {
@ -1159,10 +1200,7 @@ void SchedulerNode::ProcessDisableFLS(const std::shared_ptr<HttpMessageHandler>
return;
}
node_manager_.UpdateClusterState(ClusterState::CLUSTER_DISABLE_FLS);
uint64_t request_id = AddMessageTrack(node_manager_.server_num());
std::unordered_map<uint32_t, VectorPtr> outputs;
set_message_callback(request_id, [&]() {
@ -1185,19 +1223,29 @@ void SchedulerNode::ProcessDisableFLS(const std::shared_ptr<HttpMessageHandler>
if (!res) {
ERROR_STATUS(status, RequestProcessResultCode::kInvalidInputs, "The disable FLS is timeout.");
resp->ErrorResponse(HTTP_BADREQUEST, status);
node_manager_.UpdateClusterState(ClusterState::CLUSTER_READY);
return;
}
node_manager_.UpdateClusterState(ClusterState::CLUSTER_READY);
nlohmann::json js;
js["message"] = "start disabling FL-Server successful.";
js["code"] = kSuccessCode;
for (const auto &output : outputs) {
std::string data = std::string(reinterpret_cast<char *>(output.second->data()), output.second->size());
js["result"][output.first] = data;
nlohmann::json dataJson = nlohmann::json::parse(data);
if (!dataJson["result"]) {
res = false;
break;
}
}
nlohmann::json js;
if (res) {
js["message"] = "start disabling FL-Server successful.";
js["code"] = kSuccessCode;
js["result"] = true;
node_manager_.UpdateClusterState(ClusterState::CLUSTER_DISABLE_FLS);
} else {
js["message"] = "start disabling FL-Server failed.";
js["code"] = kErrorCode;
js["result"] = false;
}
resp->AddRespString(js.dump());
resp->AddRespHeadParam("Content-Type", "application/json");
@ -1207,11 +1255,16 @@ void SchedulerNode::ProcessDisableFLS(const std::shared_ptr<HttpMessageHandler>
RequestProcessResult SchedulerNode::CheckIfClusterReady() {
RequestProcessResult result(RequestProcessResultCode::kSuccess);
if (node_manager_.GetClusterState() != ClusterState::CLUSTER_READY || CheckIfNodeDisconnected()) {
if (node_manager_.GetClusterState() != ClusterState::CLUSTER_READY) {
std::string message = "The cluster is not ready.";
ERROR_STATUS(result, RequestProcessResultCode::kSystemError, message);
return result;
}
if (CheckIfNodeDisconnected()) {
ERROR_STATUS(result, RequestProcessResultCode::kSystemError, kClusterNotReady);
return result;
}
return result;
}

View File

@ -76,6 +76,11 @@ public class Common {
*/
public static final String SAFE_MOD = "The cluster is in safemode.";
/**
* The tag when server is not ready.
*/
public static final String NOT_READY = "The server's training job is not ready.";
/**
* The tag when server is not ready.
*/
@ -328,6 +333,9 @@ public class Common {
LOGGER.info(Common.addTag("[isSeverReady] the server does not return the current iteration."));
}
return false;
} else if (messageStr.contains(NOT_READY)) {
LOGGER.info(Common.addTag("[isSeverReady] " + NOT_READY + ", need wait some time and request again"));
return false;
} else {
return true;
}