forked from mindspore-Ecosystem/mindspore

!17534 fixed scale out/in

From: @anancds
Reviewed-by: @cristoval, @limingqi107
Signed-off-by: @cristoval

commit 6a66766a0e
@@ -96,15 +96,57 @@ void AbstractNode::set_ready_for_scale_out() {
}

void AbstractNode::set_ready_for_scale_in() {
  Register(client_to_scheduler_);
  connected_nodes_.clear();
  if (!is_current_node_scale_in_) {
    Register(client_to_scheduler_);
    connected_nodes_.clear();
  } else {
    current_cluster_state_ = ClusterState::CLUSTER_SCALE_IN;
    node_info_.rank_id_ = UINT_MAX;
    MS_LOG(WARNING) << "Trigger cluster scale in done event.";
    on_node_event_message_(ClusterEvent::CLUSTER_SCALE_IN_DONE);
  }
}

void AbstractNode::set_scale_out_done() {
  auto message_meta = std::make_shared<MessageMeta>();
  message_meta->set_cmd(NodeCommand::SCALE_OUT_DONE);

  ScaleOutDoneMessage scale_out_done_message;
  scale_out_done_message.set_node_id(node_info_.node_id_);

  if (!SendMessageSync(client_to_scheduler_, message_meta, Protos::PROTOBUF,
                       scale_out_done_message.SerializeAsString().data(), scale_out_done_message.ByteSizeLong())) {
    MS_LOG(EXCEPTION) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
                      << " the node id:" << node_info_.node_id_ << " scale_out_done timeout!";
  }

  MS_LOG(INFO) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
               << " the node id:" << node_info_.node_id_ << "is send scale_out_done to scheduler!";
}

void AbstractNode::set_scale_in_done() {
  auto message_meta = std::make_shared<MessageMeta>();
  message_meta->set_cmd(NodeCommand::SCALE_IN_DONE);

  ScaleInDoneMessage scale_in_done_message;
  scale_in_done_message.set_node_id(node_info_.node_id_);

  if (!SendMessageSync(client_to_scheduler_, message_meta, Protos::PROTOBUF,
                       scale_in_done_message.SerializeAsString().data(), scale_in_done_message.ByteSizeLong())) {
    MS_LOG(EXCEPTION) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
                      << " the node id:" << node_info_.node_id_ << " scale_in_done timeout!";
  }

  MS_LOG(INFO) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
               << " the node id:" << node_info_.node_id_ << "is send scale_in_done to scheduler!";
}

bool AbstractNode::Send(const enum NodeRole &node_role, const uint32_t &rank_id, const DataPtr &data, size_t len,
                        int command, const uint32_t &timeout) {
  MS_EXCEPTION_IF_NULL(data);
  if (!CommUtil::ValidateRankId(node_role, rank_id, worker_num_, server_num_)) {
    MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!";
    MS_LOG(EXCEPTION) << "The node role or rank_id is illegal, the worker num:" << worker_num_
                      << ", the server num:" << server_num_;
  }

  auto message_meta = std::make_shared<MessageMeta>();

@@ -127,7 +169,8 @@ bool AbstractNode::Send(const NodeRole &node_role, const std::vector<uint32_t> &
  }
  for (size_t it = 0; it < rank_ids.size(); ++it) {
    if (!CommUtil::ValidateRankId(node_role, rank_ids.at(it), worker_num_, server_num_)) {
      MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!";
      MS_LOG(EXCEPTION) << "The node role or rank_id is illegal, the worker num:" << worker_num_
                        << ", the server num:" << server_num_;
    }

    auto message_meta = std::make_shared<MessageMeta>();

@@ -152,7 +195,8 @@ bool AbstractNode::Send(const enum NodeRole &node_role, const uint32_t &rank_id,
  MS_EXCEPTION_IF_NULL(message);
  MS_EXCEPTION_IF_NULL(output);
  if (!CommUtil::ValidateRankId(node_role, rank_id, worker_num_, server_num_)) {
    MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!";
    MS_LOG(EXCEPTION) << "The node role or rank_id is illegal, the worker num:" << worker_num_
                      << ", the server num:" << server_num_;
  }

  uint64_t request_id = AddMessageTrack(1);

@@ -202,7 +246,8 @@ bool AbstractNode::Send(const NodeRole &node_role, const std::vector<uint32_t> &

  for (size_t it = 0; it < size; ++it) {
    if (!CommUtil::ValidateRankId(node_role, rank_ids.at(it), worker_num_, server_num_)) {
      MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!";
      MS_LOG(EXCEPTION) << "The node role or rank_id is illegal, the worker num:" << worker_num_
                        << ", the server num:" << server_num_;
    }

    auto message_meta = std::make_shared<MessageMeta>();

@@ -227,7 +272,8 @@ uint64_t AbstractNode::CollectiveSendAsync(const enum NodeRole &node_role, const
                                           size_t size) {
  MS_EXCEPTION_IF_NULL(data);
  if (!CommUtil::ValidateRankId(node_role, rank_id, worker_num_, server_num_)) {
    MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!";
    MS_LOG(EXCEPTION) << "The node role or rank_id is illegal, the worker num:" << worker_num_
                      << ", the server num:" << server_num_;
  }

  std::shared_ptr<MessageMeta> message_meta = std::make_shared<MessageMeta>();

@@ -243,7 +289,8 @@ std::pair<uint32_t, uint64_t> AbstractNode::CollectiveReceiveAsync(const enum No
                                                                   const uint32_t &rank_id, VectorPtr *output) {
  MS_EXCEPTION_IF_NULL(output);
  if (!CommUtil::ValidateRankId(node_role, rank_id, worker_num_, server_num_)) {
    MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!";
    MS_LOG(EXCEPTION) << "The node role or rank_id is illegal, the worker num:" << worker_num_
                      << ", the server num:" << server_num_;
  }

  receive_callbacks_mutex_.lock();

@@ -347,10 +394,13 @@ void AbstractNode::ProcessHeartbeatResp(std::shared_ptr<MessageMeta> meta, const
  heartbeat_resp_message.ParseFromArray(data, size);

  current_cluster_state_ = heartbeat_resp_message.cluster_state();
  MS_LOG(DEBUG) << "The current cluster state from heartbeat:" << current_cluster_state_;

  if (current_cluster_state_ == ClusterState::CLUSTER_READY) {
    is_ready_ = true;
    wait_start_cond_.notify_all();
  }

  if (current_cluster_state_ == ClusterState::CLUSTER_TIMEOUT && on_node_event_message_) {
    is_ready_ = true;
    wait_start_cond_.notify_all();

@@ -390,6 +440,10 @@ void AbstractNode::ProcessSendMetadata(std::shared_ptr<TcpConnection> conn, std:
  MS_EXCEPTION_IF_NULL(data);
  SendMetadataMessage send_meta_message;
  send_meta_message.ParseFromArray(data, size);
  worker_num_ = send_meta_message.worker_num();
  server_num_ = send_meta_message.server_num();
  MS_LOG(WARNING) << "The send metadata worker num:" << worker_num_ << ", server num:" << server_num_;

  nodes_address_.clear();
  for (const auto &it : send_meta_message.servers_meta()) {
    nodes_address_[std::make_pair(NodeRole::SERVER, it.rank_id())] = std::make_pair(it.ip(), it.port());

@@ -403,11 +457,13 @@ void AbstractNode::ProcessSendMetadata(std::shared_ptr<TcpConnection> conn, std:
    MS_LOG(WARNING) << "Trigger cluster scale out done event.";
    on_node_event_message_(ClusterEvent::CLUSTER_SCALE_OUT_DONE);
  }

  if (current_cluster_state_ == ClusterState::CLUSTER_SCALE_IN) {
    MS_LOG(WARNING) << "Trigger cluster scale in done event.";
    on_node_event_message_(ClusterEvent::CLUSTER_SCALE_IN_DONE);
  }
  current_cluster_state_ = ClusterState::CLUSTER_READY;

  MS_LOG(INFO) << "The current cluster state:" << current_cluster_state_;
}

void AbstractNode::ProcessFinish(std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta,

@@ -420,6 +476,26 @@ void AbstractNode::ProcessFinish(std::shared_ptr<TcpConnection> conn, std::share
  wait_finish_cond_.notify_all();
}

void AbstractNode::ProcessScaleOutDone(std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta,
                                       const Protos &protos, const void *data, size_t size) {
  MS_EXCEPTION_IF_NULL(conn);
  MS_EXCEPTION_IF_NULL(meta);
  MS_EXCEPTION_IF_NULL(data);
  server_->SendMessage(conn, meta, Protos::RAW, data, size);
  is_ready_ = true;
  current_cluster_state_ = ClusterState::CLUSTER_READY;
}

void AbstractNode::ProcessScaleInDone(std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta,
                                      const Protos &protos, const void *data, size_t size) {
  MS_EXCEPTION_IF_NULL(conn);
  MS_EXCEPTION_IF_NULL(meta);
  MS_EXCEPTION_IF_NULL(data);
  server_->SendMessage(conn, meta, Protos::RAW, data, size);
  is_ready_ = true;
  current_cluster_state_ = ClusterState::CLUSTER_READY;
}

void AbstractNode::ProcessScaleOut(std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta,
                                   const Protos &protos, const void *data, size_t size) {
  MS_EXCEPTION_IF_NULL(conn);

@@ -428,8 +504,9 @@ void AbstractNode::ProcessScaleOut(std::shared_ptr<TcpConnection> conn, std::sha

  ScaleOutMessage scale_out_message;
  scale_out_message.ParseFromArray(data, size);
  worker_num_ = scale_out_message.worker_num();
  server_num_ = scale_out_message.server_num();
  int32_t worker_num = scale_out_message.worker_num();
  int32_t server_num = scale_out_message.server_num();
  MS_LOG(WARNING) << "The scale out worker num:" << worker_num << ", the server num:" << server_num;

  server_->SendMessage(conn, meta, Protos::RAW, data, size);
  on_node_event_message_(ClusterEvent::READY_FOR_SCALE_OUT);

@@ -445,8 +522,19 @@ void AbstractNode::ProcessScaleIn(std::shared_ptr<TcpConnection> conn, std::shar

  ScaleInMessage scale_in_message;
  scale_in_message.ParseFromArray(data, size);
  worker_num_ = scale_in_message.worker_num();
  server_num_ = scale_in_message.server_num();
  int32_t worker_num = scale_in_message.worker_num();
  int32_t server_num = scale_in_message.server_num();
  MS_LOG(WARNING) << "The scale in worker num:" << worker_num << ", the server num:" << server_num;

  is_current_node_scale_in_ = scale_in_message.is_node_scale_in();

  if (is_current_node_scale_in_) {
    MS_LOG(WARNING) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
                    << " the node id:" << node_info_.node_id_ << " is a scale in node!";
  } else {
    MS_LOG(WARNING) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
                    << " the node id:" << node_info_.node_id_ << " is not a scale in node!";
  }

  server_->SendMessage(conn, meta, Protos::RAW, data, size);
  on_node_event_message_(ClusterEvent::READY_FOR_SCALE_IN);

@@ -665,6 +753,8 @@ void AbstractNode::InitCommandHandler() {
  handlers_[NodeCommand::REGISTER] = &AbstractNode::ProcessRegisterResp;
  handlers_[NodeCommand::FETCH_METADATA] = &AbstractNode::ProcessFetchServersResp;
  handlers_[NodeCommand::FINISH] = nullptr;
  handlers_[NodeCommand::SCALE_OUT_DONE] = nullptr;
  handlers_[NodeCommand::SCALE_IN_DONE] = nullptr;
}

void AbstractNode::InitServerHandler() {

@@ -673,6 +763,9 @@ void AbstractNode::InitServerHandler() {
  server_handler_[NodeCommand::SEND_DATA] = nullptr;
  server_handler_[NodeCommand::COLLECTIVE_SEND_DATA] = nullptr;
  server_handler_[NodeCommand::SCALE_OUT] = &AbstractNode::ProcessScaleOut;
  server_handler_[NodeCommand::SCALE_IN] = &AbstractNode::ProcessScaleIn;
  server_handler_[NodeCommand::SCALE_OUT_DONE] = &AbstractNode::ProcessScaleOutDone;
  server_handler_[NodeCommand::SCALE_IN_DONE] = &AbstractNode::ProcessScaleInDone;
}

void AbstractNode::InitNodeInfo(const NodeRole &role) {
@@ -40,7 +40,8 @@ class AbstractNode : public Node {
        server_(nullptr),
        server_thread_(nullptr),
        worker_num_(-1),
        server_num_(-1) {}
        server_num_(-1),
        is_current_node_scale_in_(false) {}
  ~AbstractNode() override = default;

  typedef void (AbstractNode::*ResponseHandler)(std::shared_ptr<MessageMeta> meta, const void *data, size_t size);

@@ -59,6 +60,12 @@ class AbstractNode : public Node {
  // When the business layer finish scale in, it should call this function
  void set_ready_for_scale_in();

  // Send scale_out_done instructions to the scheduler.
  void set_scale_out_done();

  // Send scale_in_done instructions to the scheduler.
  void set_scale_in_done();

  bool Send(const enum NodeRole &node_role, const uint32_t &rank_id, const DataPtr &data, size_t len, int command,
            const uint32_t &timeout = kCommTimeoutInSeconds);
  bool Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids, const std::vector<DataPtr> &data,

@@ -99,6 +106,13 @@ class AbstractNode : public Node {
  void ProcessScaleIn(std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta, const Protos &protos,
                      const void *data, size_t size);

  // The worker/server processes the scale_out_done message from scheduelr
  void ProcessScaleOutDone(std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta, const Protos &protos,
                           const void *data, size_t size);
  // The worker/server processes the scale_in_done message from scheduelr
  void ProcessScaleInDone(std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta, const Protos &protos,
                          const void *data, size_t size);

  void StartHeartbeatTimer(const std::shared_ptr<TcpClient> &client);
  void UpdateSchedulerTime();
  bool CheckSchedulerTimeout() const;

@@ -163,6 +177,9 @@ class AbstractNode : public Node {

  int32_t worker_num_;
  int32_t server_num_;

  // Identify whether the current node is a scale in node.
  std::atomic<bool> is_current_node_scale_in_;
};
}  // namespace core
}  // namespace ps
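Taken together, the new AbstractNode API splits elastic scaling into two phases per node: set_ready_for_scale_out()/set_ready_for_scale_in() update the node's own membership (or, for a node being removed, fire CLUSTER_SCALE_IN_DONE locally), while set_scale_out_done()/set_scale_in_done() tell the scheduler that the business layer has finished its own migration. A minimal sketch of how a caller might wire these calls to the cluster events follows; the RegisterEventCallback hook and the migration comments are assumptions about the business layer and are not part of this commit:

// Sketch only: wiring on a worker/server built on top of AbstractNode.
// "RegisterEventCallback" is an assumed registration hook for the callback that
// AbstractNode invokes through on_node_event_message_; the real hook may differ.
void WireScaleEvents(AbstractNode *node) {
  node->RegisterEventCallback([node](const ClusterEvent &event) {
    if (event == ClusterEvent::READY_FOR_SCALE_OUT) {
      // The business layer would resize/migrate its own state here ...
      node->set_ready_for_scale_out();
      // ... and report completion to the scheduler once the migration has finished.
      node->set_scale_out_done();
    } else if (event == ClusterEvent::READY_FOR_SCALE_IN) {
      node->set_ready_for_scale_in();
      // Nodes that stay in the cluster would also call node->set_scale_in_done()
      // after their migration; the exact policy is up to the business layer.
    } else if (event == ClusterEvent::CLUSTER_SCALE_OUT_DONE ||
               event == ClusterEvent::CLUSTER_SCALE_IN_DONE) {
      // The scheduler has collected *_done from every node; normal traffic can resume.
    }
  });
}

The two-phase split is what lets the scheduler changes below hold the cluster in CLUSTER_SCALE_OUT or CLUSTER_SCALE_IN until every node has reported its *_done message.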
@@ -30,7 +30,8 @@ bool CommunicatorBase::SendResponse(const void *rsp_data, size_t rsp_len, std::s
}
void CommunicatorBase::Join() {
  if (!running_thread_.joinable()) {
    MS_LOG(EXCEPTION) << "The running thread of communicator is not joinable.";
    MS_LOG(WARNING) << "The running thread of communicator is not joinable.";
    return;
  }
  running_thread_.join();
@@ -36,7 +36,8 @@ void LeaderScaler::ScaleOutAsync(const std::shared_ptr<TcpClient> &client, const
  MS_LOG(INFO) << "The scheduler is sending scale out to workers and servers!";
}

void LeaderScaler::ScaleInAsync(const std::shared_ptr<TcpClient> &client, const NodeManager &manager) {
void LeaderScaler::ScaleInAsync(const std::shared_ptr<TcpClient> &client, const NodeManager &manager,
                                bool is_node_scale_in) {
  MS_EXCEPTION_IF_NULL(client);
  auto message_meta = std::make_shared<MessageMeta>();
  message_meta->set_cmd(NodeCommand::SCALE_IN);

@@ -44,6 +45,7 @@ void LeaderScaler::ScaleInAsync(const std::shared_ptr<TcpClient> &client, const
  ScaleInMessage scale_in_message;
  scale_in_message.set_worker_num(manager.worker_num());
  scale_in_message.set_server_num(manager.server_num());
  scale_in_message.set_is_node_scale_in(is_node_scale_in);

  if (!node_->SendMessageSync(client, message_meta, Protos::PROTOBUF, scale_in_message.SerializeAsString().data(),
                              scale_in_message.ByteSizeLong())) {
@@ -42,7 +42,7 @@ class LeaderScaler {
  // When the scheduler receives the scale out message, it will send this message to the workers and servers.
  void ScaleOutAsync(const std::shared_ptr<TcpClient> &client, const NodeManager &manager);
  // When the scheduler receives the scale in message, it will send this message to the workers and servers.
  void ScaleInAsync(const std::shared_ptr<TcpClient> &client, const NodeManager &manager);
  void ScaleInAsync(const std::shared_ptr<TcpClient> &client, const NodeManager &manager, bool is_node_scale_in);

 private:
  // The node_ will only be instantiated with scheduler node.
@@ -112,8 +112,8 @@ void NodeManager::UpdateCluster() {
  timeout_nodes_info_.clear();
  for (auto it = heartbeats_.begin(); it != heartbeats_.end(); ++it) {
    if (it->second.tv_sec + PSContext::instance()->cluster_config().heartbeat_timeout < current_time.tv_sec) {
      MS_LOG(WARNING) << "The node id:" << it->first << " is timeout!";
      if (nodes_info_.count(it->first)) {
        MS_LOG(WARNING) << "The node id:" << it->first << " is timeout!";
        timeout_nodes_info_[it->first] = nodes_info_[it->first];
      }
    }

@@ -145,10 +145,18 @@ void NodeManager::CheckClusterTimeout() {

void NodeManager::AddFinishNode(const std::string &finish_message) { finish_nodes_id_.insert(finish_message); }

void NodeManager::AddScaleOutDoneNode(const std::string &node_id) { scale_out_done_nodes_id_.insert(node_id); }

void NodeManager::AddScaleInDoneNode(const std::string &node_id) { scale_in_done_nodes_id_.insert(node_id); }

bool NodeManager::IsAllNodesRegistered() { return SizeToInt(nodes_info_.size()) == total_node_num_; }

bool NodeManager::IsAllNodesFinished() { return SizeToInt(finish_nodes_id_.size()) == total_node_num_; }

bool NodeManager::IsAllNodesScaleOutDone() { return SizeToInt(scale_out_done_nodes_id_.size()) == total_node_num_; }

bool NodeManager::IsAllNodesScaleInDone() { return SizeToInt(scale_in_done_nodes_id_.size()) == total_node_num_; }

std::unordered_map<std::string, NodeInfo> &NodeManager::nodes_info() { return nodes_info_; }

void NodeManager::UpdateNodeState(const NodeState &state) {

@@ -174,6 +182,7 @@ ClusterState NodeManager::GetClusterState() {
void NodeManager::ResetMetadata() {
  MS_LOG(WARNING) << "Reset metadata.";
  nodes_info_.clear();
  heartbeats_.clear();
  next_worker_rank_id_ = -1;
  next_server_rank_id_ = -1;
}
@@ -66,6 +66,11 @@ class NodeManager {
  void CheckClusterTimeout();
  void AddFinishNode(const std::string &finish_message);

  // After the scheduler receives the scale_out_done node, it will save this node.
  void AddScaleOutDoneNode(const std::string &node_id);
  // After the scheduler receives the scale_in_done node, it will save this node.
  void AddScaleInDoneNode(const std::string &node_id);

  // When workers and servers registered to scheduler, the scheduler will collect the number of registered
  // nodes and Determine whether the registered number of worker and server is equal to total_node_num_.
  bool IsAllNodesRegistered();

@@ -73,6 +78,13 @@ class NodeManager {
  // finish nodes and Determine whether the finished nodes are equal to total_node_num_.
  bool IsAllNodesFinished();

  // When workers and servers send a scale_out_done message to the scheduler, the scheduler will collect the number of
  // nodes and Determine whether the nodes are equal to total_node_num_.
  bool IsAllNodesScaleOutDone();
  // When workers and servers send a scale_in_done message to the scheduler, the scheduler will collect the number of
  // nodes and Determine whether the nodes are equal to total_node_num_.
  bool IsAllNodesScaleInDone();

  std::unordered_map<std::string, NodeInfo> &nodes_info();

  void set_total_node_num(const int32_t &node_num);

@@ -115,6 +127,11 @@ class NodeManager {
  std::unordered_map<std::string, NodeInfo> timeout_nodes_info_;
  std::unordered_set<std::string> finish_nodes_id_;

  // The scheduler aggregates scale_out_done messages from workers/servers
  std::unordered_set<std::string> scale_out_done_nodes_id_;
  // The scheduler aggregates scale_in_done messages from workers/servers
  std::unordered_set<std::string> scale_in_done_nodes_id_;

  // Cluster metadata information can be dynamically changed
  std::unique_ptr<ClusterMetadata> meta_data_;
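The new *_done bookkeeping is what the scheduler-side handlers later in this diff rely on: each scale_out_done/scale_in_done message adds one node id, and the cluster is reopened only when the count reaches total_node_num_. Distilled to a sketch (the real code is ProcessScaleOutDone/ProcessScaleInDone in the scheduler changes below; the free function here is only for illustration):

// Illustration only: the aggregation pattern behind IsAllNodesScaleOutDone().
void OnScaleOutDoneReceived(NodeManager *manager, const std::string &node_id) {
  manager->AddScaleOutDoneNode(node_id);                       // remember which node has reported
  if (manager->IsAllNodesScaleOutDone()) {                     // every registered node has reported
    manager->UpdateClusterState(ClusterState::CLUSTER_READY);  // reopen the cluster
  }
}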
@@ -29,8 +29,14 @@ enum NodeCommand {
  COLLECTIVE_SEND_DATA = 6;
  // The scheduler actively sends metadata to the worker and server
  SEND_METADATA = 7;
  // This command is used to start scale out
  SCALE_OUT = 8;
  // This command is used to start scale in
  SCALE_IN = 9;
  // This command is used to synchronize the scale out status of the cluster
  SCALE_OUT_DONE = 10;
  // This command is used to synchronize the scale in status of the cluster
  SCALE_IN_DONE = 11;
}

enum NodeRole {

@@ -111,6 +117,10 @@ message ServersMeta {

message SendMetadataMessage {
  repeated ServersMeta servers_meta = 1;
  // the current worker number.
  int32 worker_num = 2;
  // the current server number.
  int32 server_num = 3;
}

message FinishMessage {

@@ -133,8 +143,20 @@ message ScaleOutMessage {

// The scheduler will broadcast the worker/server numbers after scale in to all nodes.
message ScaleInMessage {
  // the worker number after scale in
  // the worker number after scale in.
  int32 worker_num = 1;
  // the server number after scale in
  // the server number after scale in.
  int32 server_num = 2;
  // Determine whether the current node is a scale in node.
  bool is_node_scale_in = 3;
}

// This message is sent to the scheduler to notify the completion of scale out
message ScaleOutDoneMessage {
  string node_id = 1;
}

// This message is sent to the scheduler to notify the completion of scale out
message ScaleInDoneMessage {
  string node_id = 1;
}
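The new is_node_scale_in flag is how a single SCALE_IN broadcast carries a per-node decision: the scheduler sets it only for the nodes being removed, and each node stores it in is_current_node_scale_in_ to pick the right branch in set_ready_for_scale_in(). A small sketch of both ends of the field, using only the protobuf-generated accessors (transport and error handling omitted; the real call sites are LeaderScaler::ScaleInAsync above and AbstractNode::ProcessScaleIn earlier in this diff):

// Scheduler side: one message per target node, mirroring LeaderScaler::ScaleInAsync.
ScaleInMessage BuildScaleInMessage(int32_t worker_num, int32_t server_num, bool is_node_scale_in) {
  ScaleInMessage message;
  message.set_worker_num(worker_num);
  message.set_server_num(server_num);
  message.set_is_node_scale_in(is_node_scale_in);
  return message;
}

// Worker/server side: parse and branch, as AbstractNode::ProcessScaleIn does.
bool IsScaleInNode(const void *data, size_t size) {
  ScaleInMessage message;
  message.ParseFromArray(data, size);
  return message.is_node_scale_in();
}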
@@ -37,6 +37,7 @@ bool SchedulerNode::Start(const uint32_t &timeout) {
    MS_LOG(ERROR) << "Start Scheduler node timeout!";
    return false;
  }
  node_manager_.UpdateClusterState(ClusterState::CLUSTER_READY);
  MS_LOG(INFO) << "Start the scheduler node is successful!";

  return true;

@@ -55,6 +56,7 @@ void SchedulerNode::ProcessHeartbeat(std::shared_ptr<TcpServer> server, std::sha

  HeartbeatRespMessage heartbeat_resp_message;

  MS_LOG(DEBUG) << "The cluster state:" << node_manager_.GetClusterState();
  heartbeat_resp_message.set_cluster_state(node_manager_.GetClusterState());

  server->SendMessage(conn, meta, Protos::PROTOBUF, heartbeat_resp_message.SerializeAsString().data(),

@@ -77,6 +79,8 @@ void SchedulerNode::InitCommandHandler() {
  handlers_[NodeCommand::REGISTER] = &SchedulerNode::ProcessRegister;
  handlers_[NodeCommand::FINISH] = &SchedulerNode::ProcessFinish;
  handlers_[NodeCommand::FETCH_METADATA] = &SchedulerNode::ProcessFetchMetadata;
  handlers_[NodeCommand::SCALE_OUT_DONE] = &SchedulerNode::ProcessScaleOutDone;
  handlers_[NodeCommand::SCALE_IN_DONE] = &SchedulerNode::ProcessScaleInDone;
}

void SchedulerNode::CreateTcpServer() {

@@ -135,7 +139,6 @@ void SchedulerNode::ProcessRegister(std::shared_ptr<TcpServer> server, std::shar
      SendMetadata(client);
      MS_LOG(INFO) << "Send meta data to" << kvs.first;
    }
    node_manager_.UpdateClusterState(ClusterState::CLUSTER_READY);
    wait_start_cond_.notify_all();
  }
}

@@ -177,6 +180,56 @@ void SchedulerNode::ProcessFetchMetadata(std::shared_ptr<TcpServer> server, std:
                      fetch_servers_message.ByteSizeLong());
}

void SchedulerNode::ProcessScaleOutDone(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn,
                                        std::shared_ptr<MessageMeta> meta, const void *data, size_t size) {
  MS_EXCEPTION_IF_NULL(server);
  MS_EXCEPTION_IF_NULL(conn);
  MS_EXCEPTION_IF_NULL(meta);
  MS_EXCEPTION_IF_NULL(data);
  ScaleOutDoneMessage scale_out_done_message;
  scale_out_done_message.ParseFromArray(data, size);
  std::string node_id = scale_out_done_message.node_id();
  MS_LOG(INFO) << "The scheduler process a scale_out_done message from node id:" << node_id;
  node_manager_.AddScaleOutDoneNode(node_id);

  server->SendMessage(conn, meta, Protos::PROTOBUF, data, size);

  if (node_manager_.IsAllNodesScaleOutDone()) {
    auto node_infos = node_manager_.nodes_info();
    for (const auto &kvs : node_infos) {
      auto client = GetOrCreateClient(kvs.second);
      SendScaleOutDone(client);
    }
    is_ready_ = true;
    node_manager_.UpdateClusterState(ClusterState::CLUSTER_READY);
  }
}

void SchedulerNode::ProcessScaleInDone(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn,
                                       std::shared_ptr<MessageMeta> meta, const void *data, size_t size) {
  MS_EXCEPTION_IF_NULL(server);
  MS_EXCEPTION_IF_NULL(conn);
  MS_EXCEPTION_IF_NULL(meta);
  MS_EXCEPTION_IF_NULL(data);
  ScaleInDoneMessage scale_in_done_message;
  scale_in_done_message.ParseFromArray(data, size);
  std::string node_id = scale_in_done_message.node_id();
  MS_LOG(INFO) << "The scheduler process a scale_in_done message from node id:" << node_id;
  node_manager_.AddScaleInDoneNode(node_id);

  server->SendMessage(conn, meta, Protos::PROTOBUF, data, size);

  if (node_manager_.IsAllNodesScaleInDone()) {
    auto node_infos = node_manager_.nodes_info();
    for (const auto &kvs : node_infos) {
      auto client = GetOrCreateClient(kvs.second);
      SendScaleInDone(client);
    }
    is_ready_ = true;
    node_manager_.UpdateClusterState(ClusterState::CLUSTER_READY);
  }
}

void SchedulerNode::SendMetadata(const std::shared_ptr<TcpClient> &client) {
  MS_EXCEPTION_IF_NULL(client);
  auto message_meta = std::make_shared<MessageMeta>();

@@ -184,6 +237,8 @@ void SchedulerNode::SendMetadata(const std::shared_ptr<TcpClient> &client) {

  SendMetadataMessage send_metadata_message;
  std::vector<ServersMeta> servers_meta_list = node_manager_.FetchServersMeta();
  send_metadata_message.set_worker_num(node_manager_.worker_num());
  send_metadata_message.set_server_num(node_manager_.server_num());

  *send_metadata_message.mutable_servers_meta() = {servers_meta_list.begin(), servers_meta_list.end()};

@@ -214,6 +269,40 @@ void SchedulerNode::SendFinish(const std::shared_ptr<TcpClient> &client) {
               << " the node id:" << node_info_.node_id_ << "is sending finish to workers and servers!";
}

void SchedulerNode::SendScaleOutDone(const std::shared_ptr<TcpClient> &client) {
  MS_EXCEPTION_IF_NULL(client);
  auto message_meta = std::make_shared<MessageMeta>();
  message_meta->set_cmd(NodeCommand::SCALE_OUT_DONE);

  // The scheduler does not need to bring any data when sending the scale_out_done command
  std::string resp_data;

  if (!SendMessageSync(client, message_meta, Protos::PROTOBUF, resp_data.data(), resp_data.size())) {
    MS_LOG(EXCEPTION) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
                      << " the node id:" << node_info_.node_id_ << " send scale_out_done timeout!";
  }

  MS_LOG(INFO) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
               << " the node id:" << node_info_.node_id_ << "is sending scale_out_done to workers and servers!";
}

void SchedulerNode::SendScaleInDone(const std::shared_ptr<TcpClient> &client) {
  MS_EXCEPTION_IF_NULL(client);
  auto message_meta = std::make_shared<MessageMeta>();
  message_meta->set_cmd(NodeCommand::SCALE_IN_DONE);

  // The scheduler does not need to bring any data when sending the scale_in_done command
  std::string resp_data;

  if (!SendMessageSync(client, message_meta, Protos::PROTOBUF, resp_data.data(), resp_data.size())) {
    MS_LOG(EXCEPTION) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
                      << " the node id:" << node_info_.node_id_ << " send scale_in_done timeout!";
  }

  MS_LOG(INFO) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
               << " the node id:" << node_info_.node_id_ << "is sending scale_in_done to workers and servers!";
}

void SchedulerNode::StartUpdateClusterStateTimer() {
  MS_LOG(WARNING) << "The scheduler start a heartbeat timer!";
  update_state_thread_ = std::make_unique<std::thread>([&]() {

@@ -330,11 +419,11 @@ void SchedulerNode::ProcessScaleOut(std::shared_ptr<HttpMessageHandler> resp) {
  node_manager_.set_total_node_num(total_worker_num + total_server_num);
  node_manager_.UpdateClusterState(ClusterState::CLUSTER_SCALE_OUT);
  auto node_infos = node_manager_.nodes_info();
  node_manager_.ResetMetadata();
  for (const auto &kvs : node_infos) {
    auto client = GetOrCreateClient(kvs.second);
    leader_scaler_->ScaleOutAsync(client, node_manager_);
  }
  node_manager_.ResetMetadata();
  MS_LOG(INFO) << "Scheduler send scale out successful.";

  nlohmann::json js;

@@ -375,18 +464,24 @@ void SchedulerNode::ProcessScaleIn(std::shared_ptr<HttpMessageHandler> resp) {
    return;
  }

  std::vector<std::string> node_ids;
  status = resp->ParseNodeIdsFromKey(kNodesIds, &node_ids);
  std::vector<std::string> scale_in_node_ids;
  status = resp->ParseNodeIdsFromKey(kNodesIds, &scale_in_node_ids);
  if (status != RequestProcessResultCode::kSuccess) {
    resp->ErrorResponse(HTTP_BADREQUEST, status);
    return;
  }

  MS_LOG(WARNING) << "The scale in node ids:" << scale_in_node_ids;

  std::unordered_map<std::string, bool> scale_in_nodes;

  int32_t scale_worker_num = 0;
  int32_t scale_server_num = 0;
  auto node_infos = node_manager_.nodes_info();
  for (auto const &val : node_ids) {
  node_manager_.ResetMetadata();
  for (auto const &val : scale_in_node_ids) {
    if (node_infos.count(val)) {
      scale_in_nodes[val] = true;
      NodeInfo info = node_infos[val];
      if (info.node_role_ == NodeRole::WORKER) {
        scale_worker_num++;

@@ -407,10 +502,13 @@ void SchedulerNode::ProcessScaleIn(std::shared_ptr<HttpMessageHandler> resp) {
  node_manager_.UpdateClusterState(ClusterState::CLUSTER_SCALE_IN);
  for (const auto &kvs : node_infos) {
    auto client = GetOrCreateClient(kvs.second);
    leader_scaler_->ScaleInAsync(client, node_manager_);
    bool is_node_scale_in = false;
    if (scale_in_nodes.count(kvs.first)) {
      is_node_scale_in = true;
    }
    leader_scaler_->ScaleInAsync(client, node_manager_, is_node_scale_in);
  }

  node_manager_.ResetMetadata();
  nlohmann::json js;
  js["message"] = "Cluster begin to scale in.";
  resp->AddRespString(js.dump());

@@ -446,7 +544,7 @@ void SchedulerNode::ProcessGetNodesInfo(std::shared_ptr<HttpMessageHandler> resp
  for (const auto &kvs : node_infos) {
    std::unordered_map<std::string, std::string> res;
    res["node_id"] = kvs.second.node_id_;
    res["rank_id"] = kvs.second.rank_id_;
    res["rank_id"] = std::to_string(kvs.second.rank_id_);
    res["role"] = CommUtil::NodeRoleToString(kvs.second.node_role_);
    js["node_ids"].push_back(res);
  }
@@ -79,11 +79,26 @@ class SchedulerNode : public Node {
  void ProcessFetchMetadata(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn,
                            std::shared_ptr<MessageMeta> meta, const void *data, size_t size);

  // After scheduler collects all registered message, it actively sends metadata to workers and servers.
  // Process scale_out_done messages from workers/servers
  void ProcessScaleOutDone(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn,
                           std::shared_ptr<MessageMeta> meta, const void *data, size_t size);
  // Process scale_in_done messages from workers/servers
  void ProcessScaleInDone(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn,
                          std::shared_ptr<MessageMeta> meta, const void *data, size_t size);

  // After scheduler collects all registered message, it actively sends finish to the node connected by the client.
  void SendMetadata(const std::shared_ptr<TcpClient> &client);
  // // After scheduler collects all finish message, it actively sends finish message to workers and servers.
  // After scheduler collects all finish message, it actively sends finish to the node connected by the client.
  void SendFinish(const std::shared_ptr<TcpClient> &client);

  // After scheduler collects all scale_out_done message, it actively sends scale_out_done to the node connected by the
  // client.
  void SendScaleOutDone(const std::shared_ptr<TcpClient> &client);

  // After scheduler collects all scale_in_done message, it actively sends scale_out_done to the node connected by the
  // client.
  void SendScaleInDone(const std::shared_ptr<TcpClient> &client);

  // Handle the scale out http request, then delegate to the leader scaler to process scale out asynchronously.
  void ProcessScaleOut(std::shared_ptr<HttpMessageHandler> resp);