added get cluster state

This commit is contained in:
chendongsheng 2021-06-02 16:38:30 +08:00
parent fb5eea169b
commit ad22ca8de1
12 changed files with 97 additions and 48 deletions

View File

@ -434,10 +434,10 @@ void AbstractNode::ProcessHeartbeatResp(std::shared_ptr<MessageMeta> meta, const
wait_start_cond_.notify_all();
}
if (current_cluster_state_ == ClusterState::CLUSTER_TIMEOUT) {
if (current_cluster_state_ == ClusterState::NODE_TIMEOUT) {
is_ready_ = true;
wait_start_cond_.notify_all();
OnEventCallback(ClusterEvent::CLUSTER_TIMEOUT);
OnEventCallback(ClusterEvent::NODE_TIMEOUT);
}
}

View File

@ -166,6 +166,11 @@ bool CommUtil::IsFileExists(const std::string &file) {
return true;
}
}
std::string CommUtil::ClusterStateToString(const ClusterState &state) {
MS_LOG(INFO) << "The cluster state:" << state;
return kClusterState.at(state);
}
} // namespace core
} // namespace ps
} // namespace mindspore

View File

@ -48,6 +48,7 @@
#include <thread>
#include <fstream>
#include <iostream>
#include <vector>
#include "proto/comm.pb.h"
#include "proto/ps.pb.h"
@ -72,6 +73,16 @@ constexpr int kMessageChunkLength = 4096;
constexpr int kConnectionTimeout = 120;
constexpr char kLibeventLogPrefix[] = "[libevent log]:";
// Find the corresponding string style of cluster state through the subscript of the enum:ClusterState
const std::vector<std::string> kClusterState = {
"ClUSTER_STARTING", // Initialization state when the cluster is just started.
"CLUSTER_READY", // The state after all nodes are successfully registered.
"CLUSTER_EXIT", // The state after the cluster exits successfully.
"NODE_TIMEOUT", // When a node has a heartbeat timeout
"CLUSTER_SCALE_OUT", // When the cluster is scale out.
"CLUSTER_SCALE_IN" // When the cluster is scale in.
};
class CommUtil {
public:
static bool CheckIpWithRegex(const std::string &ip);
@ -86,6 +97,8 @@ class CommUtil {
// Check if the file exists.
static bool IsFileExists(const std::string &file);
// Convert cluster state to string when response the http request.
static std::string ClusterStateToString(const ClusterState &state);
private:
static std::random_device rd;

View File

@ -27,7 +27,6 @@ namespace ps {
namespace core {
// Events reported to the business layer, include cluster event and node event.
enum class ClusterEvent {
CLUSTER_TIMEOUT = 0,
NODE_TIMEOUT = 1,
SCHEDULER_TIMEOUT = 2,
READY_FOR_SCALE_OUT = 3,

View File

@ -119,7 +119,7 @@ void NodeManager::UpdateCluster() {
}
}
if (!timeout_nodes_info_.empty()) {
UpdateClusterState(ClusterState::CLUSTER_TIMEOUT);
UpdateClusterState(ClusterState::NODE_TIMEOUT);
for (auto it = timeout_nodes_info_.begin(); it != timeout_nodes_info_.end(); ++it) {
finish_nodes_id_.insert(it->first);
}
@ -128,7 +128,7 @@ void NodeManager::UpdateCluster() {
// 2. update cluster finish state
if (SizeToInt(finish_nodes_id_.size()) == total_node_num_ ||
SizeToInt(finish_nodes_id_.size()) == current_node_num_) {
UpdateClusterState(ClusterState::CLUSTER_FINISH);
UpdateClusterState(ClusterState::CLUSTER_EXIT);
}
}
@ -139,7 +139,7 @@ void NodeManager::CheckClusterTimeout() {
<< " seconds,so finish the cluster, and change total node number from " << total_node_num_ << " to "
<< nodes_info_.size();
current_node_num_ = nodes_info_.size();
UpdateClusterState(ClusterState::CLUSTER_TIMEOUT);
UpdateClusterState(ClusterState::NODE_TIMEOUT);
}
}

View File

@ -83,17 +83,15 @@ enum NodeState {
NODE_STARTING = 0;
NODE_FINISH = 1;
NODE_READY = 2;
NODE_TIMEOUT = 3;
}
enum ClusterState {
ClUSTER_STARTING = 0;
CLUSTER_READY = 1;
CLUSTER_FINISH = 2;
CLUSTER_TIMEOUT = 3;
CLUSTER_EXIT = 2;
NODE_TIMEOUT = 3;
CLUSTER_SCALE_OUT = 4;
CLUSTER_SCALE_IN = 5;
CLUSTER_FAILURE = 6;
}
message HeartbeatRespMessage {

View File

@ -116,7 +116,7 @@ message OneClientNoises {
message ClientShareStr {
string fl_id = 1;
bytes share = 2; // todo: verify the correctness
bytes share = 2;
int32 index = 3;
}

View File

@ -160,7 +160,7 @@ void SchedulerNode::ProcessFinish(std::shared_ptr<TcpServer> server, std::shared
SendFinish(client);
}
is_finish_ = true;
node_manager_.UpdateClusterState(ClusterState::CLUSTER_FINISH);
node_manager_.UpdateClusterState(ClusterState::CLUSTER_EXIT);
wait_finish_cond_.notify_all();
}
}
@ -316,7 +316,7 @@ void SchedulerNode::StartUpdateClusterStateTimer() {
std::this_thread::sleep_for(std::chrono::seconds(PSContext::instance()->cluster_config().heartbeat_interval));
node_manager_.UpdateCluster();
if (node_manager_.GetClusterState() == ClusterState::CLUSTER_FINISH) {
if (node_manager_.GetClusterState() == ClusterState::CLUSTER_EXIT) {
std::this_thread::sleep_for(
std::chrono::seconds(PSContext::instance()->cluster_config().heartbeat_interval * 2));
is_finish_ = true;
@ -435,20 +435,9 @@ void SchedulerNode::ProcessScaleOut(std::shared_ptr<HttpMessageHandler> resp) {
}
/*
* The body format is:
* The response body format.
* {
* "node_ids": [
* {
* "node_id": "423ljjfslkj5",
* "rank_id": "0",
* "role": "SERVER"
* },
* {
* "node_id": "jklj3424kljj",
* "rank_id": "1",
* "role": "WORKER"
* }
* ]
* "node_ids": ["node_id1", "node_id2"]
* }
*/
void SchedulerNode::ProcessScaleIn(std::shared_ptr<HttpMessageHandler> resp) {
@ -471,6 +460,12 @@ void SchedulerNode::ProcessScaleIn(std::shared_ptr<HttpMessageHandler> resp) {
return;
}
status = CheckIfNodeIdLegal(scale_in_node_ids);
if (status != RequestProcessResultCode::kSuccess) {
resp->ErrorResponse(HTTP_BADREQUEST, status);
return;
}
MS_LOG(WARNING) << "The scale in node ids:" << scale_in_node_ids;
std::unordered_map<std::string, bool> scale_in_nodes;
@ -518,17 +513,17 @@ void SchedulerNode::ProcessScaleIn(std::shared_ptr<HttpMessageHandler> resp) {
}
/*
* The return body format is:
* The response body format.
* {
* "message": "Get nodes info successful.",
* "node_ids": [
* {
* "node_id": "423ljjfslkj5",
* "node_id": "node_id1",
* "rank_id": "0",
* "role": "SERVER"
* },
* {
* "node_id": "jklj3424kljj",
* "node_id": "node_id2",
* "rank_id": "1",
* "role": "WORKER"
* }
@ -536,8 +531,6 @@ void SchedulerNode::ProcessScaleIn(std::shared_ptr<HttpMessageHandler> resp) {
* }
*/
void SchedulerNode::ProcessGetNodesInfo(std::shared_ptr<HttpMessageHandler> resp) {
RequestProcessResult status(RequestProcessResultCode::kSuccess);
nlohmann::json js;
js["message"] = "Get nodes info successful.";
auto node_infos = node_manager_.nodes_info();
@ -555,6 +548,25 @@ void SchedulerNode::ProcessGetNodesInfo(std::shared_ptr<HttpMessageHandler> resp
resp->SendResponse();
}
/*
* The response body format.
* {
* "message": "Get cluster state successful.",
* "cluster_state": "CLUSTER_READY"
* }
*/
void SchedulerNode::ProcessGetClusterState(std::shared_ptr<HttpMessageHandler> resp) {
nlohmann::json js;
js["message"] = "Get cluster state successful.";
auto cluster_state = node_manager_.GetClusterState();
js["cluster_state"] = CommUtil::ClusterStateToString(cluster_state);
resp->AddRespString(js.dump());
resp->SetRespCode(HTTP_OK);
resp->SendResponse();
}
RequestProcessResult SchedulerNode::CheckIfClusterReady() {
RequestProcessResult result(RequestProcessResultCode::kSuccess);
if (node_manager_.GetClusterState() != ClusterState::CLUSTER_READY) {
@ -565,6 +577,35 @@ RequestProcessResult SchedulerNode::CheckIfClusterReady() {
return result;
}
RequestProcessResult SchedulerNode::CheckIfNodeIdLegal(const std::vector<std::string> &node_ids) {
RequestProcessResult result(RequestProcessResultCode::kSuccess);
if (node_ids.size() == 0) {
std::string message = "The node ids should not be empty.";
ERROR_STATUS(result, RequestProcessResultCode::kInvalidInputs, message);
return result;
}
auto node_infos = node_manager_.nodes_info();
for (auto val : node_ids) {
if (!node_infos.count(val)) {
std::string message = "The node id:" + val + " is illegal.";
MS_LOG(ERROR) << message;
ERROR_STATUS(result, RequestProcessResultCode::kInvalidInputs, message);
return result;
}
if (node_infos[val].node_role_ == NodeRole::SERVER && node_infos[val].rank_id_ == 0) {
std::string error_message = "The node id:" + val + " is rank 0 of server, should not be scale in.";
MS_LOG(ERROR) << error_message;
ERROR_STATUS(result, RequestProcessResultCode::kInvalidInputs, error_message);
return result;
}
}
return result;
}
void SchedulerNode::StartRestfulServer(const std::string &address, std::uint16_t port, size_t thread_num) {
MS_LOG(INFO) << "Scheduler start https server.";
http_server_ = std::make_shared<HttpServer>(address, port, thread_num);
@ -581,6 +622,10 @@ void SchedulerNode::StartRestfulServer(const std::string &address, std::uint16_t
callbacks_["/nodes"] = nodes;
http_server_->RegisterRoute("/nodes", &callbacks_["/nodes"]);
OnRequestReceive cluster_state = std::bind(&SchedulerNode::ProcessGetClusterState, this, std::placeholders::_1);
callbacks_["/state"] = cluster_state;
http_server_->RegisterRoute("/state", &callbacks_["/state"]);
http_server_->InitServer();
http_server_->Start();

View File

@ -108,9 +108,15 @@ class SchedulerNode : public Node {
// Handle the get nodes info http request Synchronously.
void ProcessGetNodesInfo(std::shared_ptr<HttpMessageHandler> resp);
// Handle the get cluster state http request Synchronously.
void ProcessGetClusterState(std::shared_ptr<HttpMessageHandler> resp);
// check whether the cluster is in the ready state.
RequestProcessResult CheckIfClusterReady();
// check whether the node id is legal.
RequestProcessResult CheckIfNodeIdLegal(const std::vector<std::string> &node_ids);
void StartRestfulServer(const std::string &address, std::uint16_t port, size_t thread_num = 10);
void StopRestfulServer();

View File

@ -48,10 +48,6 @@ bool ParameterServer::Init(const FuncGraphPtr &func_graph) {
InitOptimInfoBuilders();
server_node_->set_handler(*handler_);
server_node_->RegisterEventCallback(core::ClusterEvent::CLUSTER_TIMEOUT, [this]() {
MS_LOG(ERROR) << "Trigger timeout event: CLUSTER_TIMEOUT begin to exit the system!";
this->Finalize();
});
server_node_->RegisterEventCallback(core::ClusterEvent::SCHEDULER_TIMEOUT, [this]() {
MS_LOG(ERROR) << "Trigger timeout event: SCHEDULER_TIMEOUT begin to exit the system!";
this->Finalize();

View File

@ -202,14 +202,6 @@ void Server::RegisterCommCallbacks() {
void Server::RegisterExceptionEventCallback(const std::shared_ptr<core::TcpCommunicator> &communicator) {
MS_EXCEPTION_IF_NULL(communicator);
communicator->RegisterEventCallback(core::ClusterEvent::CLUSTER_TIMEOUT, [&]() {
MS_LOG(ERROR) << "Event CLUSTER_TIMEOUT is captured. This is because some nodes(Scheduler/Server/Worker) are not "
"started during network building phase.";
std::for_each(communicators_with_worker_.begin(), communicators_with_worker_.end(),
[](const std::shared_ptr<core::CommunicatorBase> &communicator) { communicator->Stop(); });
communicator_with_server_->Stop();
});
communicator->RegisterEventCallback(core::ClusterEvent::SCHEDULER_TIMEOUT, [&]() {
MS_LOG(ERROR) << "Event SCHEDULER_TIMEOUT is captured. This is because scheduler node is finalized or crashed.";
std::for_each(communicators_with_worker_.begin(), communicators_with_worker_.end(),

View File

@ -33,11 +33,6 @@ void Worker::Run() {
Initialize();
worker_node_.RegisterEventCallback(core::ClusterEvent::CLUSTER_TIMEOUT, [this]() {
MS_LOG(ERROR) << "Trigger timeout event: CLUSTER_TIMEOUT begin to exit the system!";
this->Finalize();
exit(0);
});
worker_node_.RegisterEventCallback(core::ClusterEvent::SCHEDULER_TIMEOUT, [this]() {
MS_LOG(ERROR) << "Trigger timeout event: SCHEDULER_TIMEOUT begin to exit the system!";
this->Finalize();