forked from mindspore-Ecosystem/mindspore
added get cluster state
parent fb5eea169b
commit ad22ca8de1

@@ -434,10 +434,10 @@ void AbstractNode::ProcessHeartbeatResp(std::shared_ptr<MessageMeta> meta, const
     wait_start_cond_.notify_all();
   }

-  if (current_cluster_state_ == ClusterState::CLUSTER_TIMEOUT) {
+  if (current_cluster_state_ == ClusterState::NODE_TIMEOUT) {
     is_ready_ = true;
     wait_start_cond_.notify_all();
-    OnEventCallback(ClusterEvent::CLUSTER_TIMEOUT);
+    OnEventCallback(ClusterEvent::NODE_TIMEOUT);
   }
 }

@@ -166,6 +166,11 @@ bool CommUtil::IsFileExists(const std::string &file) {
     return true;
   }
 }
+
+std::string CommUtil::ClusterStateToString(const ClusterState &state) {
+  MS_LOG(INFO) << "The cluster state:" << state;
+  return kClusterState.at(state);
+}
 }  // namespace core
 }  // namespace ps
 }  // namespace mindspore

@@ -48,6 +48,7 @@
 #include <thread>
 #include <fstream>
 #include <iostream>
+#include <vector>

 #include "proto/comm.pb.h"
 #include "proto/ps.pb.h"

@@ -72,6 +73,16 @@ constexpr int kMessageChunkLength = 4096;
 constexpr int kConnectionTimeout = 120;
 constexpr char kLibeventLogPrefix[] = "[libevent log]:";

+// Map a ClusterState enum value to its string representation, using the enum value as the subscript.
+const std::vector<std::string> kClusterState = {
+  "ClUSTER_STARTING",   // Initialization state when the cluster has just started.
+  "CLUSTER_READY",      // The state after all nodes are successfully registered.
+  "CLUSTER_EXIT",       // The state after the cluster exits successfully.
+  "NODE_TIMEOUT",       // The state when a node heartbeat times out.
+  "CLUSTER_SCALE_OUT",  // The state when the cluster is scaling out.
+  "CLUSTER_SCALE_IN"    // The state when the cluster is scaling in.
+};
+
 class CommUtil {
  public:
   static bool CheckIpWithRegex(const std::string &ip);

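For context, a minimal standalone sketch of the lookup this table supports; the enum below is a hypothetical mirror of the proto ClusterState (see the ps.proto hunk further down), not part of the commit. The enum's integer value is used as the subscript, so the table must keep exactly the enum's order, and kClusterState.at(state) throws std::out_of_range for any state without an entry:

#include <iostream>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical mirror of the proto ClusterState after this commit (values 0..5).
enum ClusterState {
  ClUSTER_STARTING = 0,
  CLUSTER_READY = 1,
  CLUSTER_EXIT = 2,
  NODE_TIMEOUT = 3,
  CLUSTER_SCALE_OUT = 4,
  CLUSTER_SCALE_IN = 5
};

const std::vector<std::string> kClusterState = {
    "ClUSTER_STARTING", "CLUSTER_READY",     "CLUSTER_EXIT",
    "NODE_TIMEOUT",     "CLUSTER_SCALE_OUT", "CLUSTER_SCALE_IN"};

int main() {
  // The enum value is the subscript; .at() bounds-checks the lookup.
  std::cout << kClusterState.at(CLUSTER_READY) << std::endl;  // prints CLUSTER_READY
  try {
    kClusterState.at(6);  // a state with no table entry (e.g. a future enum value)
  } catch (const std::out_of_range &) {
    std::cout << "no string entry for this state" << std::endl;
  }
  return 0;
}
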
@@ -86,6 +97,8 @@ class CommUtil {

   // Check if the file exists.
   static bool IsFileExists(const std::string &file);
+  // Convert the cluster state to a string when responding to the http request.
+  static std::string ClusterStateToString(const ClusterState &state);

  private:
   static std::random_device rd;

@@ -27,7 +27,6 @@ namespace ps {
 namespace core {
 // Events reported to the business layer, include cluster event and node event.
 enum class ClusterEvent {
-  CLUSTER_TIMEOUT = 0,
   NODE_TIMEOUT = 1,
   SCHEDULER_TIMEOUT = 2,
   READY_FOR_SCALE_OUT = 3,

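With ClusterEvent::CLUSTER_TIMEOUT removed, NODE_TIMEOUT is the event AbstractNode now raises on a heartbeat timeout (first hunk above). A toy, self-contained sketch of the register/fire pattern behind RegisterEventCallback and OnEventCallback; the names are mirrored from the diff, while the dispatch map itself is an assumption:

#include <functional>
#include <iostream>
#include <map>

// Toy stand-in (not from the commit) for the node-side event dispatch: the
// business layer registers one handler per ClusterEvent, and the node fires it
// via OnEventCallback, as ProcessHeartbeatResp does for NODE_TIMEOUT above.
enum class ClusterEvent { NODE_TIMEOUT = 1, SCHEDULER_TIMEOUT = 2 };

std::map<ClusterEvent, std::function<void()>> g_event_to_callback;

void RegisterEventCallback(ClusterEvent event, std::function<void()> callback) {
  g_event_to_callback[event] = std::move(callback);
}

void OnEventCallback(ClusterEvent event) {
  auto it = g_event_to_callback.find(event);
  if (it != g_event_to_callback.end()) {
    it->second();
  }
}

int main() {
  RegisterEventCallback(ClusterEvent::NODE_TIMEOUT,
                        [] { std::cout << "NODE_TIMEOUT: finalize the node" << std::endl; });
  OnEventCallback(ClusterEvent::NODE_TIMEOUT);  // prints the message above
  return 0;
}
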
@@ -119,7 +119,7 @@ void NodeManager::UpdateCluster() {
     }
   }
   if (!timeout_nodes_info_.empty()) {
-    UpdateClusterState(ClusterState::CLUSTER_TIMEOUT);
+    UpdateClusterState(ClusterState::NODE_TIMEOUT);
     for (auto it = timeout_nodes_info_.begin(); it != timeout_nodes_info_.end(); ++it) {
       finish_nodes_id_.insert(it->first);
     }

@@ -128,7 +128,7 @@ void NodeManager::UpdateCluster() {
   // 2. update cluster finish state
   if (SizeToInt(finish_nodes_id_.size()) == total_node_num_ ||
       SizeToInt(finish_nodes_id_.size()) == current_node_num_) {
-    UpdateClusterState(ClusterState::CLUSTER_FINISH);
+    UpdateClusterState(ClusterState::CLUSTER_EXIT);
   }
 }

@@ -139,7 +139,7 @@ void NodeManager::CheckClusterTimeout() {
           << " seconds,so finish the cluster, and change total node number from " << total_node_num_ << " to "
           << nodes_info_.size();
     current_node_num_ = nodes_info_.size();
-    UpdateClusterState(ClusterState::CLUSTER_TIMEOUT);
+    UpdateClusterState(ClusterState::NODE_TIMEOUT);
   }
 }

@@ -83,17 +83,15 @@ enum NodeState {
   NODE_STARTING = 0;
   NODE_FINISH = 1;
   NODE_READY = 2;
   NODE_TIMEOUT = 3;
 }

 enum ClusterState {
   ClUSTER_STARTING = 0;
   CLUSTER_READY = 1;
-  CLUSTER_FINISH = 2;
-  CLUSTER_TIMEOUT = 3;
+  CLUSTER_EXIT = 2;
+  NODE_TIMEOUT = 3;
   CLUSTER_SCALE_OUT = 4;
   CLUSTER_SCALE_IN = 5;
-  CLUSTER_FAILURE = 6;
 }

 message HeartbeatRespMessage {

@@ -116,7 +116,7 @@ message OneClientNoises {

 message ClientShareStr {
   string fl_id = 1;
-  bytes share = 2;  // todo: verify the correctness
+  bytes share = 2;
   int32 index = 3;
 }

@@ -160,7 +160,7 @@ void SchedulerNode::ProcessFinish(std::shared_ptr<TcpServer> server, std::shared
     SendFinish(client);
   }
   is_finish_ = true;
-  node_manager_.UpdateClusterState(ClusterState::CLUSTER_FINISH);
+  node_manager_.UpdateClusterState(ClusterState::CLUSTER_EXIT);
   wait_finish_cond_.notify_all();
 }
 }

@@ -316,7 +316,7 @@ void SchedulerNode::StartUpdateClusterStateTimer() {
   std::this_thread::sleep_for(std::chrono::seconds(PSContext::instance()->cluster_config().heartbeat_interval));
   node_manager_.UpdateCluster();

-  if (node_manager_.GetClusterState() == ClusterState::CLUSTER_FINISH) {
+  if (node_manager_.GetClusterState() == ClusterState::CLUSTER_EXIT) {
     std::this_thread::sleep_for(
         std::chrono::seconds(PSContext::instance()->cluster_config().heartbeat_interval * 2));
     is_finish_ = true;

@@ -435,20 +435,9 @@ void SchedulerNode::ProcessScaleOut(std::shared_ptr<HttpMessageHandler> resp) {
 }

 /*
- * The body format is:
+ * The response body format.
  * {
- *   "node_ids": [
- *     {
- *       "node_id": "423ljjfslkj5",
- *       "rank_id": "0",
- *       "role": "SERVER"
- *     },
- *     {
- *       "node_id": "jklj3424kljj",
- *       "rank_id": "1",
- *       "role": "WORKER"
- *     }
- *   ]
+ *   "node_ids": ["node_id1", "node_id2"]
  * }
  */
 void SchedulerNode::ProcessScaleIn(std::shared_ptr<HttpMessageHandler> resp) {

@@ -471,6 +460,12 @@ void SchedulerNode::ProcessScaleIn(std::shared_ptr<HttpMessageHandler> resp) {
     return;
   }

+  status = CheckIfNodeIdLegal(scale_in_node_ids);
+  if (status != RequestProcessResultCode::kSuccess) {
+    resp->ErrorResponse(HTTP_BADREQUEST, status);
+    return;
+  }
+
   MS_LOG(WARNING) << "The scale in node ids:" << scale_in_node_ids;

   std::unordered_map<std::string, bool> scale_in_nodes;

@@ -518,17 +513,17 @@ void SchedulerNode::ProcessScaleIn(std::shared_ptr<HttpMessageHandler> resp) {
 }

 /*
- * The return body format is:
+ * The response body format.
  * {
  *   "message": "Get nodes info successful.",
  *   "node_ids": [
  *     {
- *       "node_id": "423ljjfslkj5",
+ *       "node_id": "node_id1",
  *       "rank_id": "0",
  *       "role": "SERVER"
  *     },
  *     {
- *       "node_id": "jklj3424kljj",
+ *       "node_id": "node_id2",
  *       "rank_id": "1",
  *       "role": "WORKER"
  *     }

@@ -536,8 +531,6 @@ void SchedulerNode::ProcessScaleIn(std::shared_ptr<HttpMessageHandler> resp) {
  * }
  */
 void SchedulerNode::ProcessGetNodesInfo(std::shared_ptr<HttpMessageHandler> resp) {
-  RequestProcessResult status(RequestProcessResultCode::kSuccess);
-
   nlohmann::json js;
   js["message"] = "Get nodes info successful.";
   auto node_infos = node_manager_.nodes_info();

@@ -555,6 +548,25 @@ void SchedulerNode::ProcessGetNodesInfo(std::shared_ptr<HttpMessageHandler> resp
   resp->SendResponse();
 }

+/*
+ * The response body format.
+ * {
+ *   "message": "Get cluster state successful.",
+ *   "cluster_state": "CLUSTER_READY"
+ * }
+ */
+void SchedulerNode::ProcessGetClusterState(std::shared_ptr<HttpMessageHandler> resp) {
+  nlohmann::json js;
+  js["message"] = "Get cluster state successful.";
+  auto cluster_state = node_manager_.GetClusterState();
+  js["cluster_state"] = CommUtil::ClusterStateToString(cluster_state);
+
+  resp->AddRespString(js.dump());
+
+  resp->SetRespCode(HTTP_OK);
+  resp->SendResponse();
+}
+
 RequestProcessResult SchedulerNode::CheckIfClusterReady() {
   RequestProcessResult result(RequestProcessResultCode::kSuccess);
   if (node_manager_.GetClusterState() != ClusterState::CLUSTER_READY) {

@@ -565,6 +577,35 @@ RequestProcessResult SchedulerNode::CheckIfClusterReady() {
   return result;
 }

+RequestProcessResult SchedulerNode::CheckIfNodeIdLegal(const std::vector<std::string> &node_ids) {
+  RequestProcessResult result(RequestProcessResultCode::kSuccess);
+  if (node_ids.size() == 0) {
+    std::string message = "The node ids should not be empty.";
+    ERROR_STATUS(result, RequestProcessResultCode::kInvalidInputs, message);
+    return result;
+  }
+
+  auto node_infos = node_manager_.nodes_info();
+
+  for (auto val : node_ids) {
+    if (!node_infos.count(val)) {
+      std::string message = "The node id:" + val + " is illegal.";
+      MS_LOG(ERROR) << message;
+      ERROR_STATUS(result, RequestProcessResultCode::kInvalidInputs, message);
+      return result;
+    }
+
+    if (node_infos[val].node_role_ == NodeRole::SERVER && node_infos[val].rank_id_ == 0) {
+      std::string error_message = "The node id:" + val + " is rank 0 of server, should not be scale in.";
+      MS_LOG(ERROR) << error_message;
+      ERROR_STATUS(result, RequestProcessResultCode::kInvalidInputs, error_message);
+      return result;
+    }
+  }
+
+  return result;
+}
+
 void SchedulerNode::StartRestfulServer(const std::string &address, std::uint16_t port, size_t thread_num) {
   MS_LOG(INFO) << "Scheduler start https server.";
   http_server_ = std::make_shared<HttpServer>(address, port, thread_num);

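As a reading aid, a self-contained sketch of the three rejections CheckIfNodeIdLegal performs; NodeInfo, the node table, and the ids here are hypothetical stand-ins, and ERROR_STATUS plus logging are reduced to a boolean result:

#include <cstdint>
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

// Hypothetical stand-ins for NodeInfo and the registered-node table.
enum class NodeRole { SERVER, WORKER };
struct NodeInfo { NodeRole node_role_; uint32_t rank_id_; };

bool IsLegalScaleIn(const std::vector<std::string> &node_ids,
                    std::unordered_map<std::string, NodeInfo> node_infos) {
  if (node_ids.empty()) return false;                        // empty request
  for (const auto &id : node_ids) {
    if (!node_infos.count(id)) return false;                 // unknown node id
    if (node_infos[id].node_role_ == NodeRole::SERVER &&
        node_infos[id].rank_id_ == 0) return false;          // rank-0 server must stay
  }
  return true;
}

int main() {
  std::unordered_map<std::string, NodeInfo> nodes = {
      {"node_id1", {NodeRole::SERVER, 0}}, {"node_id2", {NodeRole::WORKER, 1}}};
  std::cout << IsLegalScaleIn({}, nodes) << std::endl;            // 0: empty
  std::cout << IsLegalScaleIn({"nope"}, nodes) << std::endl;      // 0: unknown id
  std::cout << IsLegalScaleIn({"node_id1"}, nodes) << std::endl;  // 0: rank-0 server
  std::cout << IsLegalScaleIn({"node_id2"}, nodes) << std::endl;  // 1: legal
  return 0;
}
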
@@ -581,6 +622,10 @@ void SchedulerNode::StartRestfulServer(const std::string &address, std::uint16_t
   callbacks_["/nodes"] = nodes;
   http_server_->RegisterRoute("/nodes", &callbacks_["/nodes"]);

+  OnRequestReceive cluster_state = std::bind(&SchedulerNode::ProcessGetClusterState, this, std::placeholders::_1);
+  callbacks_["/state"] = cluster_state;
+  http_server_->RegisterRoute("/state", &callbacks_["/state"]);
+
   http_server_->InitServer();

   http_server_->Start();

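With ProcessGetClusterState above wired to this route, the cluster state becomes queryable over the scheduler's RESTful server. A sketch of the exchange, where the address and port placeholders stand for whatever StartRestfulServer was given:

GET http://<scheduler_address>:<scheduler_http_port>/state

{
  "message": "Get cluster state successful.",
  "cluster_state": "CLUSTER_READY"
}
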
@@ -108,9 +108,15 @@ class SchedulerNode : public Node {
   // Handle the get nodes info http request Synchronously.
   void ProcessGetNodesInfo(std::shared_ptr<HttpMessageHandler> resp);

+  // Handle the get cluster state http request Synchronously.
+  void ProcessGetClusterState(std::shared_ptr<HttpMessageHandler> resp);
+
   // check whether the cluster is in the ready state.
   RequestProcessResult CheckIfClusterReady();

+  // check whether the node id is legal.
+  RequestProcessResult CheckIfNodeIdLegal(const std::vector<std::string> &node_ids);
+
   void StartRestfulServer(const std::string &address, std::uint16_t port, size_t thread_num = 10);
   void StopRestfulServer();

@@ -48,10 +48,6 @@ bool ParameterServer::Init(const FuncGraphPtr &func_graph) {

   InitOptimInfoBuilders();
   server_node_->set_handler(*handler_);
-  server_node_->RegisterEventCallback(core::ClusterEvent::CLUSTER_TIMEOUT, [this]() {
-    MS_LOG(ERROR) << "Trigger timeout event: CLUSTER_TIMEOUT begin to exit the system!";
-    this->Finalize();
-  });
   server_node_->RegisterEventCallback(core::ClusterEvent::SCHEDULER_TIMEOUT, [this]() {
     MS_LOG(ERROR) << "Trigger timeout event: SCHEDULER_TIMEOUT begin to exit the system!";
     this->Finalize();

@@ -202,14 +202,6 @@ void Server::RegisterCommCallbacks() {

 void Server::RegisterExceptionEventCallback(const std::shared_ptr<core::TcpCommunicator> &communicator) {
   MS_EXCEPTION_IF_NULL(communicator);
-  communicator->RegisterEventCallback(core::ClusterEvent::CLUSTER_TIMEOUT, [&]() {
-    MS_LOG(ERROR) << "Event CLUSTER_TIMEOUT is captured. This is because some nodes(Scheduler/Server/Worker) are not "
-                     "started during network building phase.";
-    std::for_each(communicators_with_worker_.begin(), communicators_with_worker_.end(),
-                  [](const std::shared_ptr<core::CommunicatorBase> &communicator) { communicator->Stop(); });
-    communicator_with_server_->Stop();
-  });
-
   communicator->RegisterEventCallback(core::ClusterEvent::SCHEDULER_TIMEOUT, [&]() {
     MS_LOG(ERROR) << "Event SCHEDULER_TIMEOUT is captured. This is because scheduler node is finalized or crashed.";
     std::for_each(communicators_with_worker_.begin(), communicators_with_worker_.end(),

@@ -33,11 +33,6 @@ void Worker::Run() {

   Initialize();

-  worker_node_.RegisterEventCallback(core::ClusterEvent::CLUSTER_TIMEOUT, [this]() {
-    MS_LOG(ERROR) << "Trigger timeout event: CLUSTER_TIMEOUT begin to exit the system!";
-    this->Finalize();
-    exit(0);
-  });
   worker_node_.RegisterEventCallback(core::ClusterEvent::SCHEDULER_TIMEOUT, [this]() {
     MS_LOG(ERROR) << "Trigger timeout event: SCHEDULER_TIMEOUT begin to exit the system!";
     this->Finalize();