!32365 fix issue I4Z7WC、I502P8、I502TN、I5031D、I503MS、I503SO、I502L2

Merge pull request !32365 from tan-wei-cheng-3260/r1.6-develop3
This commit is contained in:
i-robot 2022-03-31 14:59:37 +00:00 committed by Gitee
commit 613939605c
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
17 changed files with 199 additions and 72 deletions

View File

@ -32,7 +32,7 @@ void CollectiveOpsImpl::Initialize(const std::shared_ptr<ps::core::ServerNode> &
MS_EXCEPTION_IF_NULL(server_node); MS_EXCEPTION_IF_NULL(server_node);
server_node_ = server_node; server_node_ = server_node;
rank_id_ = server_node_->rank_id(); rank_id_ = server_node_->rank_id();
server_num_ = ps::PSContext::instance()->initial_server_num(); server_num_ = server_node->server_num();
return; return;
} }

View File

@ -29,7 +29,7 @@ void DistributedCountService::Initialize(const std::shared_ptr<ps::core::ServerN
MS_EXCEPTION_IF_NULL(server_node); MS_EXCEPTION_IF_NULL(server_node);
server_node_ = server_node; server_node_ = server_node;
local_rank_ = server_node_->rank_id(); local_rank_ = server_node_->rank_id();
server_num_ = ps::PSContext::instance()->initial_server_num(); server_num_ = server_node->server_num();
counting_server_rank_ = counting_server_rank; counting_server_rank_ = counting_server_rank;
return; return;
} }
@ -109,6 +109,7 @@ bool DistributedCountService::Count(const std::string &name, const std::string &
} }
if (!TriggerCounterEvent(name, reason)) { if (!TriggerCounterEvent(name, reason)) {
MS_LOG(WARNING) << "Leader server trigger count event failed."; MS_LOG(WARNING) << "Leader server trigger count event failed.";
Iteration::GetInstance().NotifyNext(false, *reason);
return false; return false;
} }
} else { } else {

View File

@ -18,6 +18,7 @@
#include <memory> #include <memory>
#include <string> #include <string>
#include <vector> #include <vector>
#include "fl/server/iteration.h"
namespace mindspore { namespace mindspore {
namespace fl { namespace fl {
@ -26,7 +27,7 @@ void DistributedMetadataStore::Initialize(const std::shared_ptr<ps::core::Server
MS_EXCEPTION_IF_NULL(server_node); MS_EXCEPTION_IF_NULL(server_node);
server_node_ = server_node; server_node_ = server_node;
local_rank_ = server_node_->rank_id(); local_rank_ = server_node_->rank_id();
server_num_ = ps::PSContext::instance()->initial_server_num(); server_num_ = server_node->server_num();
InitHashRing(); InitHashRing();
return; return;
} }
@ -109,6 +110,7 @@ bool DistributedMetadataStore::UpdateMetadata(const std::string &name, const PBM
if (reason != nullptr) { if (reason != nullptr) {
*reason = kNetworkError; *reason = kNetworkError;
} }
Iteration::GetInstance().NotifyNext(false, *reason);
return false; return false;
} }

View File

@ -323,8 +323,9 @@ bool Iteration::NewInstance(const nlohmann::json &new_instance_json, std::string
} }
if (iteration_num_ == 1) { if (iteration_num_ == 1) {
MS_LOG(INFO) << "This is just the first iteration."; *result = "This is just the first iteration, do not need to new instance.";
return true; MS_LOG(WARNING) << *result;
return false;
} }
// Start new server instance. // Start new server instance.

View File

@ -25,7 +25,10 @@ namespace fl {
namespace server { namespace server {
class Server; class Server;
class Iteration; class Iteration;
std::atomic<uint32_t> kPrintTimes = 0; std::atomic<uint32_t> kJobNotReadyPrintTimes = 0;
std::atomic<uint32_t> kJobNotAvailablePrintTimes = 0;
std::atomic<uint32_t> kClusterSafeModePrintTimes = 0;
const uint32_t kPrintTimesThreshold = 3000; const uint32_t kPrintTimesThreshold = 3000;
Round::Round(const std::string &name, bool check_timeout, size_t time_window, bool check_count, size_t threshold_count, Round::Round(const std::string &name, bool check_timeout, size_t time_window, bool check_count, size_t threshold_count,
bool server_num_as_threshold) bool server_num_as_threshold)
@ -133,8 +136,6 @@ void Round::BindRoundKernel(const std::shared_ptr<kernel::RoundKernel> &kernel)
void Round::LaunchRoundKernel(const std::shared_ptr<ps::core::MessageHandler> &message) { void Round::LaunchRoundKernel(const std::shared_ptr<ps::core::MessageHandler> &message) {
MS_ERROR_IF_NULL_WO_RET_VAL(message); MS_ERROR_IF_NULL_WO_RET_VAL(message);
MS_ERROR_IF_NULL_WO_RET_VAL(kernel_);
std::string reason = ""; std::string reason = "";
if (!IsServerAvailable(&reason)) { if (!IsServerAvailable(&reason)) {
if (!message->SendResponse(reason.c_str(), reason.size())) { if (!message->SendResponse(reason.c_str(), reason.size())) {
@ -143,6 +144,8 @@ void Round::LaunchRoundKernel(const std::shared_ptr<ps::core::MessageHandler> &m
} }
return; return;
} }
MS_ERROR_IF_NULL_WO_RET_VAL(kernel_);
(void)(Iteration::GetInstance().running_round_num_++); (void)(Iteration::GetInstance().running_round_num_++);
bool ret = kernel_->Launch(reinterpret_cast<const uint8_t *>(message->data()), message->len(), message); bool ret = kernel_->Launch(reinterpret_cast<const uint8_t *>(message->data()), message->len(), message);
// Must send response back no matter what value Launch method returns. // Must send response back no matter what value Launch method returns.
@ -201,25 +204,35 @@ bool Round::IsServerAvailable(std::string *reason) {
return true; return true;
} }
if (!Server::GetInstance().IsReady()) {
if (kJobNotReadyPrintTimes % kPrintTimesThreshold == 0) {
MS_LOG(WARNING) << "The server's training job is not ready, please retry " + name_ + " later.";
kJobNotReadyPrintTimes = 0;
}
kJobNotReadyPrintTimes += 1;
*reason = ps::kJobNotReady;
return false;
}
// If the server state is Disable or Finish, refuse the request. // If the server state is Disable or Finish, refuse the request.
if (Iteration::GetInstance().instance_state() == InstanceState::kDisable || if (Iteration::GetInstance().instance_state() == InstanceState::kDisable ||
Iteration::GetInstance().instance_state() == InstanceState::kFinish) { Iteration::GetInstance().instance_state() == InstanceState::kFinish) {
if (kPrintTimes % kPrintTimesThreshold == 0) { if (kJobNotAvailablePrintTimes % kPrintTimesThreshold == 0) {
MS_LOG(WARNING) << "The server's training job is disabled or finished, please retry " + name_ + " later."; MS_LOG(WARNING) << "The server's training job is disabled or finished, please retry " + name_ + " later.";
kPrintTimes = 0; kJobNotAvailablePrintTimes = 0;
} }
kPrintTimes += 1; kJobNotAvailablePrintTimes += 1;
*reason = ps::kJobNotAvailable; *reason = ps::kJobNotAvailable;
return false; return false;
} }
// If the server is still in safemode, reject the request. // If the server is still in safemode, reject the request.
if (Server::GetInstance().IsSafeMode()) { if (Server::GetInstance().IsSafeMode()) {
if (kPrintTimes % kPrintTimesThreshold == 0) { if (kClusterSafeModePrintTimes % kPrintTimesThreshold == 0) {
MS_LOG(WARNING) << "The cluster is still in safemode, please retry " << name_ << " later."; MS_LOG(WARNING) << "The cluster is still in safemode, please retry " << name_ << " later.";
kPrintTimes = 0; kClusterSafeModePrintTimes = 0;
} }
kPrintTimes += 1; kClusterSafeModePrintTimes += 1;
*reason = ps::kClusterSafeMode; *reason = ps::kClusterSafeMode;
return false; return false;
} }

View File

@ -89,6 +89,7 @@ void Server::Run() {
Recover(); Recover();
MS_LOG(INFO) << "Server started successfully."; MS_LOG(INFO) << "Server started successfully.";
safemode_ = false; safemode_ = false;
is_ready_ = true;
lock.unlock(); lock.unlock();
// Wait communicators to stop so the main thread is blocked. // Wait communicators to stop so the main thread is blocked.
@ -461,6 +462,17 @@ void Server::StartCommunicator() {
return; return;
} }
MS_LOG(INFO) << "Start communicator with worker.";
(void)std::for_each(communicators_with_worker_.begin(), communicators_with_worker_.end(),
[](const std::shared_ptr<ps::core::CommunicatorBase> &communicator) {
MS_ERROR_IF_NULL_WO_RET_VAL(communicator);
if (typeid(*communicator.get()) != typeid(ps::core::TcpCommunicator)) {
if (!communicator->Start()) {
MS_LOG(EXCEPTION) << "Starting communicator with worker failed.";
}
}
});
MS_EXCEPTION_IF_NULL(server_node_); MS_EXCEPTION_IF_NULL(server_node_);
MS_EXCEPTION_IF_NULL(communicator_with_server_); MS_EXCEPTION_IF_NULL(communicator_with_server_);
MS_LOG(INFO) << "Start communicator with server."; MS_LOG(INFO) << "Start communicator with server.";
@ -472,15 +484,6 @@ void Server::StartCommunicator() {
CollectiveOpsImpl::GetInstance().Initialize(server_node_); CollectiveOpsImpl::GetInstance().Initialize(server_node_);
DistributedCountService::GetInstance().Initialize(server_node_, kLeaderServerRank); DistributedCountService::GetInstance().Initialize(server_node_, kLeaderServerRank);
MS_LOG(INFO) << "This server rank is " << server_node_->rank_id(); MS_LOG(INFO) << "This server rank is " << server_node_->rank_id();
MS_LOG(INFO) << "Start communicator with worker.";
(void)std::for_each(communicators_with_worker_.begin(), communicators_with_worker_.end(),
[](const std::shared_ptr<ps::core::CommunicatorBase> &communicator) {
MS_ERROR_IF_NULL_WO_RET_VAL(communicator);
if (!communicator->Start()) {
MS_LOG(EXCEPTION) << "Starting communicator with worker failed.";
}
});
} }
void Server::Recover() { void Server::Recover() {
@ -695,6 +698,8 @@ void Server::HandleSyncAfterRecoveryRequest(const std::shared_ptr<ps::core::Mess
} }
} }
} }
bool Server::IsReady() const { return is_ready_.load(); }
} // namespace server } // namespace server
} // namespace fl } // namespace fl
} // namespace mindspore } // namespace mindspore

View File

@ -78,6 +78,8 @@ class Server {
bool SubmitTask(std::function<void()> &&task); bool SubmitTask(std::function<void()> &&task);
bool IsReady() const;
private: private:
Server() Server()
: server_node_(nullptr), : server_node_(nullptr),
@ -111,7 +113,8 @@ class Server {
cipher_get_list_sign_cnt_(0), cipher_get_list_sign_cnt_(0),
minimum_clients_for_reconstruct(0), minimum_clients_for_reconstruct(0),
minimum_secret_shares_for_reconstruct(0), minimum_secret_shares_for_reconstruct(0),
cipher_time_window_(0) {} cipher_time_window_(0),
is_ready_(false) {}
~Server() = default; ~Server() = default;
Server(const Server &) = delete; Server(const Server &) = delete;
Server &operator=(const Server &) = delete; Server &operator=(const Server &) = delete;
@ -249,6 +252,9 @@ class Server {
size_t minimum_clients_for_reconstruct; size_t minimum_clients_for_reconstruct;
size_t minimum_secret_shares_for_reconstruct; size_t minimum_secret_shares_for_reconstruct;
uint64_t cipher_time_window_; uint64_t cipher_time_window_;
// The flag that represents whether server is starting successful.
std::atomic_bool is_ready_;
}; };
} // namespace server } // namespace server
} // namespace fl } // namespace fl

View File

@ -258,6 +258,10 @@ using BarrierBeforeScaleIn = std::function<void(void)>;
using HandlerAfterScaleOut = std::function<void(void)>; using HandlerAfterScaleOut = std::function<void(void)>;
using HandlerAfterScaleIn = std::function<void(void)>; using HandlerAfterScaleIn = std::function<void(void)>;
constexpr char kClusterNotReady[] =
"The Scheduler's connections are not equal with total node num, Maybe this is because some server nodes are drop "
"out or scale in nodes has not been recycled.";
constexpr char kJobNotReady[] = "The server's training job is not ready.";
constexpr char kClusterSafeMode[] = "The cluster is in safemode."; constexpr char kClusterSafeMode[] = "The cluster is in safemode.";
constexpr char kJobNotAvailable[] = "The server's training job is disabled or finished."; constexpr char kJobNotAvailable[] = "The server's training job is disabled or finished.";

View File

@ -34,6 +34,7 @@ void AbstractNode::Register(const std::shared_ptr<TcpClient> &client) {
register_message.set_role(node_info_.node_role_); register_message.set_role(node_info_.node_role_);
register_message.set_ip(node_info_.ip_); register_message.set_ip(node_info_.ip_);
register_message.set_port(node_info_.port_); register_message.set_port(node_info_.port_);
register_message.set_fl_iteration_num(PSContext::instance()->fl_iteration_num());
MS_LOG(INFO) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_) MS_LOG(INFO) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
<< " the node id:" << node_info_.node_id_ << " begin to register to the scheduler!"; << " the node id:" << node_info_.node_id_ << " begin to register to the scheduler!";
@ -735,8 +736,6 @@ void AbstractNode::ProcessHeartbeatResp(const std::shared_ptr<MessageMeta> &meta
if (heartbeat_resp_message.cluster_state() != current_cluster_state_ && if (heartbeat_resp_message.cluster_state() != current_cluster_state_ &&
current_cluster_state_ != ClusterState::CLUSTER_SCALE_IN && current_cluster_state_ != ClusterState::CLUSTER_SCALE_IN &&
current_cluster_state_ != ClusterState::CLUSTER_SCALE_OUT) { current_cluster_state_ != ClusterState::CLUSTER_SCALE_OUT) {
MS_LOG(INFO) << "cluster change state from:" << CommUtil::ClusterStateToString(current_cluster_state_) << " to "
<< CommUtil::ClusterStateToString(heartbeat_resp_message.cluster_state());
UpdateClusterState(heartbeat_resp_message.cluster_state()); UpdateClusterState(heartbeat_resp_message.cluster_state());
} }
MS_LOG(DEBUG) << "The current cluster state from heartbeat:" MS_LOG(DEBUG) << "The current cluster state from heartbeat:"

View File

@ -86,7 +86,7 @@ constexpr char kLibeventLogPrefix[] = "[libevent log]:";
// Find the corresponding string style of cluster state through the subscript of the enum:ClusterState // Find the corresponding string style of cluster state through the subscript of the enum:ClusterState
const std::vector<std::string> kClusterState = { const std::vector<std::string> kClusterState = {
"ClUSTER_STARTING", // Initialization state when the cluster is just started. "CLUSTER_STARTING", // Initialization state when the cluster is just started.
"CLUSTER_READY", // The state after all nodes are successfully registered. "CLUSTER_READY", // The state after all nodes are successfully registered.
"CLUSTER_EXIT", // The state after the cluster exits successfully. "CLUSTER_EXIT", // The state after the cluster exits successfully.
"NODE_TIMEOUT", // When a node has a heartbeat timeout "NODE_TIMEOUT", // When a node has a heartbeat timeout

View File

@ -54,7 +54,7 @@ class Node {
is_already_finished_(false), is_already_finished_(false),
next_request_id_(0), next_request_id_(0),
current_node_state_(NodeState::NODE_STARTING), current_node_state_(NodeState::NODE_STARTING),
current_cluster_state_(ClusterState::ClUSTER_STARTING) {} current_cluster_state_(ClusterState::CLUSTER_STARTING) {}
virtual ~Node() = default; virtual ~Node() = default;
using MessageCallback = std::function<void()>; using MessageCallback = std::function<void()>;

View File

@ -49,9 +49,10 @@ struct NodeInfo {
NodeRole node_role_; NodeRole node_role_;
// the current Node rank id,the worker node range is:[0,numOfWorker-1], the server node range is:[0, numOfServer-1] // the current Node rank id,the worker node range is:[0,numOfWorker-1], the server node range is:[0, numOfServer-1]
uint32_t rank_id_; uint32_t rank_id_;
// After the node registration is successful, it is alive.If the node's heartbeat times out, then it is not alive // After the node registration is successful, it is alive.If the node's heartbeat times out, then it is not alive
bool is_alive; bool is_alive;
// the number of the fl job iteration
size_t fl_iteration_num_;
}; };
} // namespace core } // namespace core
} // namespace ps } // namespace ps

View File

@ -34,13 +34,17 @@ uint32_t NodeManager::checkIfRankIdExist(const RegisterMessage &register_message
if (registered_nodes_info_.find(node_id) != registered_nodes_info_.end()) { if (registered_nodes_info_.find(node_id) != registered_nodes_info_.end()) {
const std::string &new_ip = register_message.ip(); const std::string &new_ip = register_message.ip();
uint32_t new_port = register_message.port(); uint32_t new_port = register_message.port();
uint32_t new_fl_iteration_num = register_message.fl_iteration_num();
rank_id = registered_nodes_info_[node_id].rank_id_; rank_id = registered_nodes_info_[node_id].rank_id_;
registered_nodes_info_[node_id].is_alive = true; registered_nodes_info_[node_id].is_alive = true;
registered_nodes_info_[node_id].ip_ = new_ip; registered_nodes_info_[node_id].ip_ = new_ip;
registered_nodes_info_[node_id].port_ = static_cast<uint16_t>(new_port); registered_nodes_info_[node_id].port_ = static_cast<uint16_t>(new_port);
registered_nodes_info_[node_id].fl_iteration_num_ = new_fl_iteration_num;
MS_LOG(WARNING) << "The node id: " << node_id << " is already assigned!" MS_LOG(WARNING) << "The node id: " << node_id << " is already assigned!"
<< ", ip: " << register_message.ip() << ", port: " << register_message.port() << ", ip: " << register_message.ip() << ", port: " << register_message.port()
<< ", rank id: " << rank_id << ", alive: " << registered_nodes_info_[node_id].is_alive << ", rank id: " << rank_id << ", alive: " << registered_nodes_info_[node_id].is_alive
<< ", fl iteration num: " << new_fl_iteration_num
<< ", the node_role:" << CommUtil::NodeRoleToString(registered_nodes_info_[node_id].node_role_); << ", the node_role:" << CommUtil::NodeRoleToString(registered_nodes_info_[node_id].node_role_);
return rank_id; return rank_id;
} }
@ -51,14 +55,18 @@ uint32_t NodeManager::checkIfRankIdExist(const RegisterMessage &register_message
if (recovery_node_infos.find(node_id) != recovery_node_infos.end()) { if (recovery_node_infos.find(node_id) != recovery_node_infos.end()) {
const std::string &new_ip = register_message.ip(); const std::string &new_ip = register_message.ip();
uint32_t new_port = register_message.port(); uint32_t new_port = register_message.port();
uint32_t new_fl_iteration_num = register_message.fl_iteration_num();
rank_id = recovery_node_infos[node_id].rank_id_; rank_id = recovery_node_infos[node_id].rank_id_;
recovery_node_infos[node_id].is_alive = true; recovery_node_infos[node_id].is_alive = true;
recovery_node_infos[node_id].ip_ = new_ip; recovery_node_infos[node_id].ip_ = new_ip;
recovery_node_infos[node_id].port_ = static_cast<uint16_t>(new_port); recovery_node_infos[node_id].port_ = static_cast<uint16_t>(new_port);
registered_nodes_info_[node_id] = recovery_node_infos[node_id]; registered_nodes_info_[node_id] = recovery_node_infos[node_id];
registered_nodes_info_[node_id].fl_iteration_num_ = new_fl_iteration_num;
MS_LOG(INFO) << "The node id: " << node_id << " is recovery successful!" MS_LOG(INFO) << "The node id: " << node_id << " is recovery successful!"
<< ", ip: " << recovery_node_infos[node_id].ip_ << ", port: " << recovery_node_infos[node_id].port_ << ", ip: " << recovery_node_infos[node_id].ip_ << ", port: " << recovery_node_infos[node_id].port_
<< ", rank id: " << rank_id << ", alive: " << recovery_node_infos[node_id].is_alive << ", rank id: " << rank_id << ", alive: " << recovery_node_infos[node_id].is_alive
<< ", fl iteration num: " << new_fl_iteration_num
<< ", the node_role:" << CommUtil::NodeRoleToString(recovery_node_infos[node_id].node_role_); << ", the node_role:" << CommUtil::NodeRoleToString(recovery_node_infos[node_id].node_role_);
return rank_id; return rank_id;
} }
@ -79,6 +87,7 @@ uint32_t NodeManager::NextRankId(const RegisterMessage &register_message, const
} }
const std::string &node_id = register_message.node_id(); const std::string &node_id = register_message.node_id();
const size_t fl_iteration_num = register_message.fl_iteration_num();
// create new rank id // create new rank id
if (register_message.role() == NodeRole::SERVER) { if (register_message.role() == NodeRole::SERVER) {
const std::string &ip = register_message.ip(); const std::string &ip = register_message.ip();
@ -105,10 +114,11 @@ uint32_t NodeManager::NextRankId(const RegisterMessage &register_message, const
node_info.ip_ = ip; node_info.ip_ = ip;
node_info.port_ = port; node_info.port_ = port;
node_info.is_alive = true; node_info.is_alive = true;
node_info.fl_iteration_num_ = fl_iteration_num;
registered_nodes_info_[node_id] = node_info; registered_nodes_info_[node_id] = node_info;
MS_LOG(INFO) << "The server node id:" << node_id << ", node ip: " << node_info.ip_ << ", node port:" << port MS_LOG(INFO) << "The server node id:" << node_id << ", node ip: " << node_info.ip_ << ", node port:" << port
<< " assign rank id:" << rank_id << ", " << (meta_data_->server_num - next_server_rank_id_) << ", fl iteration num:" << fl_iteration_num << " assign rank id:" << rank_id << ", "
<< " servers still need to be registered."; << (meta_data_->server_num - next_server_rank_id_) << " servers still need to be registered.";
} else if (register_message.role() == NodeRole::WORKER) { } else if (register_message.role() == NodeRole::WORKER) {
const std::string &ip = register_message.ip(); const std::string &ip = register_message.ip();
uint32_t port = register_message.port(); uint32_t port = register_message.port();
@ -134,10 +144,11 @@ uint32_t NodeManager::NextRankId(const RegisterMessage &register_message, const
node_info.ip_ = ip; node_info.ip_ = ip;
node_info.port_ = port; node_info.port_ = port;
node_info.is_alive = true; node_info.is_alive = true;
node_info.fl_iteration_num_ = fl_iteration_num;
registered_nodes_info_[node_id] = node_info; registered_nodes_info_[node_id] = node_info;
MS_LOG(INFO) << "The worker node id:" << node_id << ", node ip: " << node_info.ip_ << ", node port:" << port MS_LOG(INFO) << "The worker node id:" << node_id << ", node ip: " << node_info.ip_ << ", node port:" << port
<< " assign rank id:" << rank_id << ", " << (meta_data_->worker_num - next_worker_rank_id_) << ", fl iteration num:" << fl_iteration_num << " assign rank id:" << rank_id << ", "
<< " workers still need to be registered."; << (meta_data_->worker_num - next_worker_rank_id_) << " workers still need to be registered.";
} }
return rank_id; return rank_id;
} }
@ -178,7 +189,7 @@ std::vector<ServersMeta> NodeManager::FetchAllNodesMeta() {
return servers_meta_list; return servers_meta_list;
} }
void NodeManager::UpdateCluster() { void NodeManager::UpdateCluster(bool is_cluster_ready) {
// 1. update cluster timeout state // 1. update cluster timeout state
struct timeval current_time {}; struct timeval current_time {};
(void)gettimeofday(&current_time, nullptr); (void)gettimeofday(&current_time, nullptr);
@ -205,20 +216,26 @@ void NodeManager::UpdateCluster() {
} else if (SizeToUint(heartbeats_.size()) == total_node_num_) { } else if (SizeToUint(heartbeats_.size()) == total_node_num_) {
if (cluster_state_ == ClusterState::NODE_TIMEOUT) { if (cluster_state_ == ClusterState::NODE_TIMEOUT) {
for (auto it = registered_nodes_info_.begin(); it != registered_nodes_info_.end(); ++it) { for (auto it = registered_nodes_info_.begin(); it != registered_nodes_info_.end(); ++it) {
if (registered_nodes_info_.count(it->first)) { if (registered_nodes_info_.count(it->first) && !it->second.is_alive) {
registered_nodes_info_[it->first].is_alive = true; MS_LOG(WARNING) << it->second.node_id_ << " is alive.";
it->second.is_alive = true;
} }
} }
if (onPersist_) { if (onPersist_) {
onPersist_(); onPersist_();
} }
UpdateClusterState(ClusterState::CLUSTER_READY); if (is_cluster_ready) {
UpdateClusterState(ClusterState::CLUSTER_READY);
} else {
UpdateClusterState(ClusterState::CLUSTER_STARTING);
}
} }
} }
// 2. update cluster finish state // 2. update cluster finish state
if (SizeToUint(finish_nodes_id_.size()) == total_node_num_ && if (SizeToUint(finish_nodes_id_.size()) == total_node_num_ &&
PSContext::instance()->server_mode() != kServerModeHybrid) { PSContext::instance()->server_mode() != kServerModeHybrid &&
PSContext::instance()->server_mode() != kServerModeFL) {
UpdateClusterState(ClusterState::CLUSTER_EXIT); UpdateClusterState(ClusterState::CLUSTER_EXIT);
} }
} }
@ -330,6 +347,7 @@ bool NodeManager::IsWorker() const {
bool NodeManager::IsNodeRegistered(const std::string &node_id) { bool NodeManager::IsNodeRegistered(const std::string &node_id) {
if (registered_nodes_info_.find(node_id) != registered_nodes_info_.end()) { if (registered_nodes_info_.find(node_id) != registered_nodes_info_.end()) {
MS_LOG(WARNING) << "The node id " << node_id << " has been registered.";
return true; return true;
} }
return false; return false;
@ -381,6 +399,18 @@ void NodeManager::set_next_server_rank_id(const uint32_t &next_server_rank_id) {
this->next_server_rank_id_ = next_server_rank_id; this->next_server_rank_id_ = next_server_rank_id;
} }
void NodeManager::setPersistCallback(const OnPersist &onPersist) { this->onPersist_ = onPersist; } void NodeManager::setPersistCallback(const OnPersist &onPersist) { this->onPersist_ = onPersist; }
bool NodeManager::VerifyClusterNodesParam() {
std::unordered_set<size_t> fl_iteration_num_set;
for (auto it = registered_nodes_info_.begin(); it != registered_nodes_info_.end(); ++it) {
fl_iteration_num_set.insert(it->second.fl_iteration_num_);
}
if (fl_iteration_num_set.size() != 1) {
MS_LOG(ERROR) << "The server node fl iteration num is not inconsistent.";
return false;
}
return true;
}
} // namespace core } // namespace core
} // namespace ps } // namespace ps
} // namespace mindspore } // namespace mindspore

View File

@ -50,7 +50,7 @@ class NodeManager {
next_server_rank_id_(0), next_server_rank_id_(0),
meta_data_(nullptr), meta_data_(nullptr),
node_state_(NodeState::NODE_STARTING), node_state_(NodeState::NODE_STARTING),
cluster_state_(ClusterState::ClUSTER_STARTING) {} cluster_state_(ClusterState::CLUSTER_STARTING) {}
virtual ~NodeManager() = default; virtual ~NodeManager() = default;
using OnPersist = std::function<void()>; using OnPersist = std::function<void()>;
// When initializing nodes, the initial number of nodes will be assigned to the total number of nodes. // When initializing nodes, the initial number of nodes will be assigned to the total number of nodes.
@ -63,7 +63,7 @@ class NodeManager {
// Fetch metadata information of all nodes. // Fetch metadata information of all nodes.
std::vector<ServersMeta> FetchAllNodesMeta(); std::vector<ServersMeta> FetchAllNodesMeta();
void UpdateCluster(); void UpdateCluster(bool is_cluster_ready);
void AddFinishNode(const std::string &finish_message); void AddFinishNode(const std::string &finish_message);
// After the scheduler receives the scale_out_done node, it will save this node. // After the scheduler receives the scale_out_done node, it will save this node.
@ -135,6 +135,8 @@ class NodeManager {
bool IsAllNodesAlive() const; bool IsAllNodesAlive() const;
bool VerifyClusterNodesParam();
private: private:
std::mutex node_mutex_; std::mutex node_mutex_;
std::mutex cluster_mutex_; std::mutex cluster_mutex_;

View File

@ -100,6 +100,8 @@ message RegisterMessage {
string node_id = 3; string node_id = 3;
// the role of the node: worker,server,scheduler // the role of the node: worker,server,scheduler
NodeRole role = 4; NodeRole role = 4;
// the number of the fl job iteration
uint64 fl_iteration_num = 5;
} }
message RegisterRespMessage { message RegisterRespMessage {
@ -120,7 +122,7 @@ enum NodeState {
} }
enum ClusterState { enum ClusterState {
ClUSTER_STARTING = 0; CLUSTER_STARTING = 0;
CLUSTER_READY = 1; CLUSTER_READY = 1;
CLUSTER_EXIT = 2; CLUSTER_EXIT = 2;
NODE_TIMEOUT = 3; NODE_TIMEOUT = 3;

View File

@ -267,7 +267,6 @@ void SchedulerNode::ProcessRegister(const std::shared_ptr<TcpServer> &server,
return; return;
} }
MS_LOG(INFO) << "The node id is registered.";
if (connected_nodes_.count(node_id)) { if (connected_nodes_.count(node_id)) {
(void)connected_nodes_.erase(node_id); (void)connected_nodes_.erase(node_id);
} }
@ -300,12 +299,16 @@ void SchedulerNode::ProcessRegister(const std::shared_ptr<TcpServer> &server,
if (node_manager_.IsAllNodesRegistered()) { if (node_manager_.IsAllNodesRegistered()) {
if (!node_manager_.IsAllNodesAlive()) { if (!node_manager_.IsAllNodesAlive()) {
MS_LOG(ERROR) << "Do not broadcast nodes info because some server nodes are not alive."; MS_LOG(ERROR)
<< "Do not broadcast nodes info because some server nodes are not alive, and cluster will exit later.";
return; return;
} }
is_ready_ = true; if (!node_manager_.VerifyClusterNodesParam()) {
MS_LOG(INFO) << "There are " << node_manager_.worker_num() << " workers and " << node_manager_.server_num() MS_LOG(ERROR) << "Do not broadcast nodes info because some server nodes info are not inconsistent, and cluster "
<< " servers registered to scheduer, so the scheduler send meta data to worker/server."; "will exit later.";
return;
}
if (node_manager_.GetClusterState() == ClusterState::CLUSTER_SCALE_IN) { if (node_manager_.GetClusterState() == ClusterState::CLUSTER_SCALE_IN) {
auto nodes = node_manager_.nodes_info(); auto nodes = node_manager_.nodes_info();
for (const auto &id : scale_in_node_ids_) { for (const auto &id : scale_in_node_ids_) {
@ -325,10 +328,14 @@ void SchedulerNode::ProcessRegister(const std::shared_ptr<TcpServer> &server,
auto node_infos = node_manager_.nodes_info(); auto node_infos = node_manager_.nodes_info();
bool res = SendPrepareBuildingNetwork(node_infos); bool res = SendPrepareBuildingNetwork(node_infos);
if (!res) { if (!res) {
MS_LOG(ERROR) << "Prepare for building network failed!"; MS_LOG(ERROR) << "Prepare for building network failed! Cluster will exit later.";
return; return;
} }
MS_LOG(INFO) << "Prepare for building network success."; is_ready_ = true;
MS_LOG(INFO) << "Prepare for building network success. There are " << node_manager_.worker_num() << " workers and "
<< node_manager_.server_num()
<< " servers registered to scheduer, so the scheduler send meta data to worker/server.";
for (const auto &kvs : node_infos) { for (const auto &kvs : node_infos) {
auto client = GetOrCreateClient(kvs.second); auto client = GetOrCreateClient(kvs.second);
MS_EXCEPTION_IF_NULL(client); MS_EXCEPTION_IF_NULL(client);
@ -609,7 +616,7 @@ void SchedulerNode::StartUpdateClusterStateTimer() {
node_manager_.UpdateClusterState(ClusterState::CLUSTER_EXIT); node_manager_.UpdateClusterState(ClusterState::CLUSTER_EXIT);
} }
std::this_thread::sleep_for(std::chrono::seconds(PSContext::instance()->cluster_config().heartbeat_interval)); std::this_thread::sleep_for(std::chrono::seconds(PSContext::instance()->cluster_config().heartbeat_interval));
node_manager_.UpdateCluster(); node_manager_.UpdateCluster(is_ready_);
if (node_manager_.GetClusterState() == ClusterState::CLUSTER_EXIT) { if (node_manager_.GetClusterState() == ClusterState::CLUSTER_EXIT) {
std::this_thread::sleep_for( std::this_thread::sleep_for(
@ -1020,15 +1027,27 @@ void SchedulerNode::ProcessNewInstance(const std::shared_ptr<HttpMessageHandler>
return; return;
} }
node_manager_.UpdateClusterState(ClusterState::CLUSTER_READY);
nlohmann::json js;
js["message"] = "Start new instance successful.";
js["code"] = kSuccessCode;
for (const auto &output : outputs) { for (const auto &output : outputs) {
std::string data = std::string(reinterpret_cast<char *>(output.second->data()), output.second->size()); std::string data = std::string(reinterpret_cast<char *>(output.second->data()), output.second->size());
js["result"][output.first] = data; nlohmann::json dataJson = nlohmann::json::parse(data);
if (!dataJson["result"]) {
res = false;
break;
}
} }
nlohmann::json js;
if (res) {
js["message"] = "Start new instance successful.";
js["code"] = kSuccessCode;
js["result"] = true;
} else {
js["message"] = "Start new instance failed.";
js["code"] = kErrorCode;
js["result"] = false;
}
node_manager_.UpdateClusterState(ClusterState::CLUSTER_READY);
resp->AddRespString(js.dump()); resp->AddRespString(js.dump());
resp->AddRespHeadParam("Content-Type", "application/json"); resp->AddRespHeadParam("Content-Type", "application/json");
@ -1073,7 +1092,6 @@ void SchedulerNode::ProcessQueryInstance(const std::shared_ptr<HttpMessageHandle
resp->ErrorResponse(HTTP_BADREQUEST, status); resp->ErrorResponse(HTTP_BADREQUEST, status);
return; return;
} }
nlohmann::json js; nlohmann::json js;
js["message"] = "Query Instance successful."; js["message"] = "Query Instance successful.";
js["code"] = kSuccessCode; js["code"] = kSuccessCode;
@ -1095,9 +1113,15 @@ void SchedulerNode::ProcessEnableFLS(const std::shared_ptr<HttpMessageHandler> &
MS_EXCEPTION_IF_NULL(resp); MS_EXCEPTION_IF_NULL(resp);
RequestProcessResult status(RequestProcessResultCode::kSuccess); RequestProcessResult status(RequestProcessResultCode::kSuccess);
if (CheckIfNodeDisconnected()) {
ERROR_STATUS(status, RequestProcessResultCode::kSystemError, kClusterNotReady);
resp->ErrorResponse(HTTP_BADREQUEST, status);
return;
}
status = CheckIfClusterReady(); if (node_manager_.GetClusterState() != ClusterState::CLUSTER_DISABLE_FLS) {
if (status != RequestProcessResultCode::kSuccess) { std::string message = "The cluster state is not CLUSTER_DISABLE_FLS, does not need to enable fls.";
ERROR_STATUS(status, RequestProcessResultCode::kSystemError, message);
resp->ErrorResponse(HTTP_BADREQUEST, status); resp->ErrorResponse(HTTP_BADREQUEST, status);
return; return;
} }
@ -1132,15 +1156,26 @@ void SchedulerNode::ProcessEnableFLS(const std::shared_ptr<HttpMessageHandler> &
return; return;
} }
node_manager_.UpdateClusterState(ClusterState::CLUSTER_READY);
nlohmann::json js;
js["message"] = "start enabling FL-Server successful.";
js["code"] = kSuccessCode;
for (const auto &output : outputs) { for (const auto &output : outputs) {
std::string data = std::string(reinterpret_cast<char *>(output.second->data()), output.second->size()); std::string data = std::string(reinterpret_cast<char *>(output.second->data()), output.second->size());
js["result"][output.first] = data; nlohmann::json dataJson = nlohmann::json::parse(data);
if (!dataJson["result"]) {
res = false;
break;
}
} }
nlohmann::json js;
if (res) {
js["message"] = "start enabling FL-Server successful.";
js["code"] = kSuccessCode;
js["result"] = true;
node_manager_.UpdateClusterState(ClusterState::CLUSTER_READY);
} else {
js["message"] = "start enabling FL-Server failed.";
js["code"] = kErrorCode;
js["result"] = false;
}
resp->AddRespString(js.dump()); resp->AddRespString(js.dump());
resp->AddRespHeadParam("Content-Type", "application/json"); resp->AddRespHeadParam("Content-Type", "application/json");
@ -1152,6 +1187,12 @@ void SchedulerNode::ProcessDisableFLS(const std::shared_ptr<HttpMessageHandler>
MS_EXCEPTION_IF_NULL(resp); MS_EXCEPTION_IF_NULL(resp);
RequestProcessResult status(RequestProcessResultCode::kSuccess); RequestProcessResult status(RequestProcessResultCode::kSuccess);
if (node_manager_.GetClusterState() == ClusterState::CLUSTER_DISABLE_FLS) {
std::string message = "The cluster state is already in CLUSTER_DISABLE_FLS.";
ERROR_STATUS(status, RequestProcessResultCode::kSystemError, message);
resp->ErrorResponse(HTTP_BADREQUEST, status);
return;
}
status = CheckIfClusterReady(); status = CheckIfClusterReady();
if (status != RequestProcessResultCode::kSuccess) { if (status != RequestProcessResultCode::kSuccess) {
@ -1159,10 +1200,7 @@ void SchedulerNode::ProcessDisableFLS(const std::shared_ptr<HttpMessageHandler>
return; return;
} }
node_manager_.UpdateClusterState(ClusterState::CLUSTER_DISABLE_FLS);
uint64_t request_id = AddMessageTrack(node_manager_.server_num()); uint64_t request_id = AddMessageTrack(node_manager_.server_num());
std::unordered_map<uint32_t, VectorPtr> outputs; std::unordered_map<uint32_t, VectorPtr> outputs;
set_message_callback(request_id, [&]() { set_message_callback(request_id, [&]() {
@ -1185,19 +1223,29 @@ void SchedulerNode::ProcessDisableFLS(const std::shared_ptr<HttpMessageHandler>
if (!res) { if (!res) {
ERROR_STATUS(status, RequestProcessResultCode::kInvalidInputs, "The disable FLS is timeout."); ERROR_STATUS(status, RequestProcessResultCode::kInvalidInputs, "The disable FLS is timeout.");
resp->ErrorResponse(HTTP_BADREQUEST, status); resp->ErrorResponse(HTTP_BADREQUEST, status);
node_manager_.UpdateClusterState(ClusterState::CLUSTER_READY);
return; return;
} }
node_manager_.UpdateClusterState(ClusterState::CLUSTER_READY);
nlohmann::json js;
js["message"] = "start disabling FL-Server successful.";
js["code"] = kSuccessCode;
for (const auto &output : outputs) { for (const auto &output : outputs) {
std::string data = std::string(reinterpret_cast<char *>(output.second->data()), output.second->size()); std::string data = std::string(reinterpret_cast<char *>(output.second->data()), output.second->size());
js["result"][output.first] = data; nlohmann::json dataJson = nlohmann::json::parse(data);
if (!dataJson["result"]) {
res = false;
break;
}
} }
nlohmann::json js;
if (res) {
js["message"] = "start disabling FL-Server successful.";
js["code"] = kSuccessCode;
js["result"] = true;
node_manager_.UpdateClusterState(ClusterState::CLUSTER_DISABLE_FLS);
} else {
js["message"] = "start disabling FL-Server failed.";
js["code"] = kErrorCode;
js["result"] = false;
}
resp->AddRespString(js.dump()); resp->AddRespString(js.dump());
resp->AddRespHeadParam("Content-Type", "application/json"); resp->AddRespHeadParam("Content-Type", "application/json");
@ -1207,11 +1255,16 @@ void SchedulerNode::ProcessDisableFLS(const std::shared_ptr<HttpMessageHandler>
RequestProcessResult SchedulerNode::CheckIfClusterReady() { RequestProcessResult SchedulerNode::CheckIfClusterReady() {
RequestProcessResult result(RequestProcessResultCode::kSuccess); RequestProcessResult result(RequestProcessResultCode::kSuccess);
if (node_manager_.GetClusterState() != ClusterState::CLUSTER_READY || CheckIfNodeDisconnected()) { if (node_manager_.GetClusterState() != ClusterState::CLUSTER_READY) {
std::string message = "The cluster is not ready."; std::string message = "The cluster is not ready.";
ERROR_STATUS(result, RequestProcessResultCode::kSystemError, message); ERROR_STATUS(result, RequestProcessResultCode::kSystemError, message);
return result; return result;
} }
if (CheckIfNodeDisconnected()) {
ERROR_STATUS(result, RequestProcessResultCode::kSystemError, kClusterNotReady);
return result;
}
return result; return result;
} }

View File

@ -76,6 +76,11 @@ public class Common {
*/ */
public static final String SAFE_MOD = "The cluster is in safemode."; public static final String SAFE_MOD = "The cluster is in safemode.";
/**
* The tag when server is not ready.
*/
public static final String NOT_READY = "The server's training job is not ready.";
/** /**
* The tag when server is not ready. * The tag when server is not ready.
*/ */
@ -328,6 +333,9 @@ public class Common {
LOGGER.info(Common.addTag("[isSeverReady] the server does not return the current iteration.")); LOGGER.info(Common.addTag("[isSeverReady] the server does not return the current iteration."));
} }
return false; return false;
} else if (messageStr.contains(NOT_READY)) {
LOGGER.info(Common.addTag("[isSeverReady] " + NOT_READY + ", need wait some time and request again"));
return false;
} else { } else {
return true; return true;
} }