!17534 fixed scale out/in

From: @anancds
Reviewed-by: @cristoval,@limingqi107,@cristoval
Signed-off-by: @cristoval
mindspore-ci-bot 2021-06-02 10:46:31 +08:00 committed by Gitee
commit 6a66766a0e
10 changed files with 303 additions and 30 deletions
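At a high level, the change introduces a scale_out_done / scale_in_done handshake: each worker and server reports completion to the scheduler, the scheduler aggregates those reports, and once every node has reported it broadcasts the corresponding done command and moves the cluster back to CLUSTER_READY. A minimal, self-contained sketch of the aggregation step (simplified types, not the real MindSpore classes):

// DoneBarrier models NodeManager::AddScaleOutDoneNode / IsAllNodesScaleOutDone:
// collect node ids until every node in the (new) cluster has reported done.
#include <iostream>
#include <set>
#include <string>

class DoneBarrier {
 public:
  explicit DoneBarrier(int total_node_num) : total_node_num_(total_node_num) {}
  // Called for each SCALE_OUT_DONE / SCALE_IN_DONE message from a worker/server;
  // returns true once all nodes have reported.
  bool AddDoneNode(const std::string &node_id) {
    done_nodes_.insert(node_id);
    return static_cast<int>(done_nodes_.size()) == total_node_num_;
  }

 private:
  int total_node_num_;
  std::set<std::string> done_nodes_;
};

int main() {
  DoneBarrier barrier(3);  // e.g. 2 workers + 1 server after scale out
  for (const std::string &id : {"worker_0", "worker_1", "server_0"}) {
    if (barrier.AddDoneNode(id)) {
      // At this point the scheduler sends the done command back to every node
      // and updates the cluster state to CLUSTER_READY.
      std::cout << "all nodes reported done, cluster is READY" << std::endl;
    }
  }
  return 0;
}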

View File

@ -96,15 +96,57 @@ void AbstractNode::set_ready_for_scale_out() {
}
void AbstractNode::set_ready_for_scale_in() {
Register(client_to_scheduler_);
connected_nodes_.clear();
if (!is_current_node_scale_in_) {
Register(client_to_scheduler_);
connected_nodes_.clear();
} else {
current_cluster_state_ = ClusterState::CLUSTER_SCALE_IN;
node_info_.rank_id_ = UINT_MAX;
MS_LOG(WARNING) << "Trigger cluster scale in done event.";
on_node_event_message_(ClusterEvent::CLUSTER_SCALE_IN_DONE);
}
}
void AbstractNode::set_scale_out_done() {
auto message_meta = std::make_shared<MessageMeta>();
message_meta->set_cmd(NodeCommand::SCALE_OUT_DONE);
ScaleOutDoneMessage scale_out_done_message;
scale_out_done_message.set_node_id(node_info_.node_id_);
if (!SendMessageSync(client_to_scheduler_, message_meta, Protos::PROTOBUF,
scale_out_done_message.SerializeAsString().data(), scale_out_done_message.ByteSizeLong())) {
MS_LOG(EXCEPTION) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
<< " the node id:" << node_info_.node_id_ << " scale_out_done timeout!";
}
MS_LOG(INFO) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
<< " the node id:" << node_info_.node_id_ << "is send scale_out_done to scheduler!";
}
void AbstractNode::set_scale_in_done() {
auto message_meta = std::make_shared<MessageMeta>();
message_meta->set_cmd(NodeCommand::SCALE_IN_DONE);
ScaleInDoneMessage scale_in_done_message;
scale_in_done_message.set_node_id(node_info_.node_id_);
if (!SendMessageSync(client_to_scheduler_, message_meta, Protos::PROTOBUF,
scale_in_done_message.SerializeAsString().data(), scale_in_done_message.ByteSizeLong())) {
MS_LOG(EXCEPTION) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
<< " the node id:" << node_info_.node_id_ << " scale_in_done timeout!";
}
MS_LOG(INFO) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
<< " the node id:" << node_info_.node_id_ << "is send scale_in_done to scheduler!";
}
bool AbstractNode::Send(const enum NodeRole &node_role, const uint32_t &rank_id, const DataPtr &data, size_t len,
int command, const uint32_t &timeout) {
MS_EXCEPTION_IF_NULL(data);
if (!CommUtil::ValidateRankId(node_role, rank_id, worker_num_, server_num_)) {
MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!";
MS_LOG(EXCEPTION) << "The node role or rank_id is illegal, the worker num:" << worker_num_
<< ", the server num:" << server_num_;
}
auto message_meta = std::make_shared<MessageMeta>();
@ -127,7 +169,8 @@ bool AbstractNode::Send(const NodeRole &node_role, const std::vector<uint32_t> &
}
for (size_t it = 0; it < rank_ids.size(); ++it) {
if (!CommUtil::ValidateRankId(node_role, rank_ids.at(it), worker_num_, server_num_)) {
MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!";
MS_LOG(EXCEPTION) << "The node role or rank_id is illegal, the worker num:" << worker_num_
<< ", the server num:" << server_num_;
}
auto message_meta = std::make_shared<MessageMeta>();
@ -152,7 +195,8 @@ bool AbstractNode::Send(const enum NodeRole &node_role, const uint32_t &rank_id,
MS_EXCEPTION_IF_NULL(message);
MS_EXCEPTION_IF_NULL(output);
if (!CommUtil::ValidateRankId(node_role, rank_id, worker_num_, server_num_)) {
MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!";
MS_LOG(EXCEPTION) << "The node role or rank_id is illegal, the worker num:" << worker_num_
<< ", the server num:" << server_num_;
}
uint64_t request_id = AddMessageTrack(1);
@ -202,7 +246,8 @@ bool AbstractNode::Send(const NodeRole &node_role, const std::vector<uint32_t> &
for (size_t it = 0; it < size; ++it) {
if (!CommUtil::ValidateRankId(node_role, rank_ids.at(it), worker_num_, server_num_)) {
MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!";
MS_LOG(EXCEPTION) << "The node role or rank_id is illegal, the worker num:" << worker_num_
<< ", the server num:" << server_num_;
}
auto message_meta = std::make_shared<MessageMeta>();
@ -227,7 +272,8 @@ uint64_t AbstractNode::CollectiveSendAsync(const enum NodeRole &node_role, const
size_t size) {
MS_EXCEPTION_IF_NULL(data);
if (!CommUtil::ValidateRankId(node_role, rank_id, worker_num_, server_num_)) {
MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!";
MS_LOG(EXCEPTION) << "The node role or rank_id is illegal, the worker num:" << worker_num_
<< ", the server num:" << server_num_;
}
std::shared_ptr<MessageMeta> message_meta = std::make_shared<MessageMeta>();
@ -243,7 +289,8 @@ std::pair<uint32_t, uint64_t> AbstractNode::CollectiveReceiveAsync(const enum No
const uint32_t &rank_id, VectorPtr *output) {
MS_EXCEPTION_IF_NULL(output);
if (!CommUtil::ValidateRankId(node_role, rank_id, worker_num_, server_num_)) {
MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!";
MS_LOG(EXCEPTION) << "The node role or rank_id is illegal, the worker num:" << worker_num_
<< ", the server num:" << server_num_;
}
receive_callbacks_mutex_.lock();
@ -347,10 +394,13 @@ void AbstractNode::ProcessHeartbeatResp(std::shared_ptr<MessageMeta> meta, const
heartbeat_resp_message.ParseFromArray(data, size);
current_cluster_state_ = heartbeat_resp_message.cluster_state();
MS_LOG(DEBUG) << "The current cluster state from heartbeat:" << current_cluster_state_;
if (current_cluster_state_ == ClusterState::CLUSTER_READY) {
is_ready_ = true;
wait_start_cond_.notify_all();
}
if (current_cluster_state_ == ClusterState::CLUSTER_TIMEOUT && on_node_event_message_) {
is_ready_ = true;
wait_start_cond_.notify_all();
@ -390,6 +440,10 @@ void AbstractNode::ProcessSendMetadata(std::shared_ptr<TcpConnection> conn, std:
MS_EXCEPTION_IF_NULL(data);
SendMetadataMessage send_meta_message;
send_meta_message.ParseFromArray(data, size);
worker_num_ = send_meta_message.worker_num();
server_num_ = send_meta_message.server_num();
MS_LOG(WARNING) << "The send metadata worker num:" << worker_num_ << ", server num:" << server_num_;
nodes_address_.clear();
for (const auto &it : send_meta_message.servers_meta()) {
nodes_address_[std::make_pair(NodeRole::SERVER, it.rank_id())] = std::make_pair(it.ip(), it.port());
@ -403,11 +457,13 @@ void AbstractNode::ProcessSendMetadata(std::shared_ptr<TcpConnection> conn, std:
MS_LOG(WARNING) << "Trigger cluster scale out done event.";
on_node_event_message_(ClusterEvent::CLUSTER_SCALE_OUT_DONE);
}
if (current_cluster_state_ == ClusterState::CLUSTER_SCALE_IN) {
MS_LOG(WARNING) << "Trigger cluster scale in done event.";
on_node_event_message_(ClusterEvent::CLUSTER_SCALE_IN_DONE);
}
current_cluster_state_ = ClusterState::CLUSTER_READY;
MS_LOG(INFO) << "The current cluster state:" << current_cluster_state_;
}
void AbstractNode::ProcessFinish(std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta,
@ -420,6 +476,26 @@ void AbstractNode::ProcessFinish(std::shared_ptr<TcpConnection> conn, std::share
wait_finish_cond_.notify_all();
}
void AbstractNode::ProcessScaleOutDone(std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta,
const Protos &protos, const void *data, size_t size) {
MS_EXCEPTION_IF_NULL(conn);
MS_EXCEPTION_IF_NULL(meta);
MS_EXCEPTION_IF_NULL(data);
server_->SendMessage(conn, meta, Protos::RAW, data, size);
is_ready_ = true;
current_cluster_state_ = ClusterState::CLUSTER_READY;
}
void AbstractNode::ProcessScaleInDone(std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta,
const Protos &protos, const void *data, size_t size) {
MS_EXCEPTION_IF_NULL(conn);
MS_EXCEPTION_IF_NULL(meta);
MS_EXCEPTION_IF_NULL(data);
server_->SendMessage(conn, meta, Protos::RAW, data, size);
is_ready_ = true;
current_cluster_state_ = ClusterState::CLUSTER_READY;
}
void AbstractNode::ProcessScaleOut(std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta,
const Protos &protos, const void *data, size_t size) {
MS_EXCEPTION_IF_NULL(conn);
@ -428,8 +504,9 @@ void AbstractNode::ProcessScaleOut(std::shared_ptr<TcpConnection> conn, std::sha
ScaleOutMessage scale_out_message;
scale_out_message.ParseFromArray(data, size);
worker_num_ = scale_out_message.worker_num();
server_num_ = scale_out_message.server_num();
int32_t worker_num = scale_out_message.worker_num();
int32_t server_num = scale_out_message.server_num();
MS_LOG(WARNING) << "The scale out worker num:" << worker_num << ", the server num:" << server_num;
server_->SendMessage(conn, meta, Protos::RAW, data, size);
on_node_event_message_(ClusterEvent::READY_FOR_SCALE_OUT);
@ -445,8 +522,19 @@ void AbstractNode::ProcessScaleIn(std::shared_ptr<TcpConnection> conn, std::shar
ScaleInMessage scale_in_message;
scale_in_message.ParseFromArray(data, size);
worker_num_ = scale_in_message.worker_num();
server_num_ = scale_in_message.server_num();
int32_t worker_num = scale_in_message.worker_num();
int32_t server_num = scale_in_message.server_num();
MS_LOG(WARNING) << "The scale in worker num:" << worker_num << ", the server num:" << server_num;
is_current_node_scale_in_ = scale_in_message.is_node_scale_in();
if (is_current_node_scale_in_) {
MS_LOG(WARNING) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
<< " the node id:" << node_info_.node_id_ << " is a scale in node!";
} else {
MS_LOG(WARNING) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
<< " the node id:" << node_info_.node_id_ << " is not a scale in node!";
}
server_->SendMessage(conn, meta, Protos::RAW, data, size);
on_node_event_message_(ClusterEvent::READY_FOR_SCALE_IN);
@ -665,6 +753,8 @@ void AbstractNode::InitCommandHandler() {
handlers_[NodeCommand::REGISTER] = &AbstractNode::ProcessRegisterResp;
handlers_[NodeCommand::FETCH_METADATA] = &AbstractNode::ProcessFetchServersResp;
handlers_[NodeCommand::FINISH] = nullptr;
handlers_[NodeCommand::SCALE_OUT_DONE] = nullptr;
handlers_[NodeCommand::SCALE_IN_DONE] = nullptr;
}
void AbstractNode::InitServerHandler() {
@ -673,6 +763,9 @@ void AbstractNode::InitServerHandler() {
server_handler_[NodeCommand::SEND_DATA] = nullptr;
server_handler_[NodeCommand::COLLECTIVE_SEND_DATA] = nullptr;
server_handler_[NodeCommand::SCALE_OUT] = &AbstractNode::ProcessScaleOut;
server_handler_[NodeCommand::SCALE_IN] = &AbstractNode::ProcessScaleIn;
server_handler_[NodeCommand::SCALE_OUT_DONE] = &AbstractNode::ProcessScaleOutDone;
server_handler_[NodeCommand::SCALE_IN_DONE] = &AbstractNode::ProcessScaleInDone;
}
void AbstractNode::InitNodeInfo(const NodeRole &role) {

View File

@ -40,7 +40,8 @@ class AbstractNode : public Node {
server_(nullptr),
server_thread_(nullptr),
worker_num_(-1),
server_num_(-1) {}
server_num_(-1),
is_current_node_scale_in_(false) {}
~AbstractNode() override = default;
typedef void (AbstractNode::*ResponseHandler)(std::shared_ptr<MessageMeta> meta, const void *data, size_t size);
@ -59,6 +60,12 @@ class AbstractNode : public Node {
// When the business layer finishes scale in, it should call this function
void set_ready_for_scale_in();
// Send scale_out_done instructions to the scheduler.
void set_scale_out_done();
// Send scale_in_done instructions to the scheduler.
void set_scale_in_done();
bool Send(const enum NodeRole &node_role, const uint32_t &rank_id, const DataPtr &data, size_t len, int command,
const uint32_t &timeout = kCommTimeoutInSeconds);
bool Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids, const std::vector<DataPtr> &data,
@ -99,6 +106,13 @@ class AbstractNode : public Node {
void ProcessScaleIn(std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta, const Protos &protos,
const void *data, size_t size);
// The worker/server processes the scale_out_done message from the scheduler
void ProcessScaleOutDone(std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta, const Protos &protos,
const void *data, size_t size);
// The worker/server processes the scale_in_done message from the scheduler
void ProcessScaleInDone(std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta, const Protos &protos,
const void *data, size_t size);
void StartHeartbeatTimer(const std::shared_ptr<TcpClient> &client);
void UpdateSchedulerTime();
bool CheckSchedulerTimeout() const;
@ -163,6 +177,9 @@ class AbstractNode : public Node {
int32_t worker_num_;
int32_t server_num_;
// Identify whether the current node is a scale in node.
std::atomic<bool> is_current_node_scale_in_;
};
} // namespace core
} // namespace ps
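The new public methods imply an event-driven contract with the business layer: on READY_FOR_SCALE_IN it prepares for the smaller cluster and calls set_ready_for_scale_in(); once the scale-in-done event fires it reports back with set_scale_in_done(). A rough sketch of that contract, with hypothetical callback wiring (the real registration mechanism is not shown in this diff):

// Illustrative only: the ClusterEvent values mirror the ones used above, but the
// callback plumbing is a stand-in, not the MindSpore interface.
#include <functional>

enum class ClusterEvent { READY_FOR_SCALE_IN, CLUSTER_SCALE_IN_DONE };

struct ScaleInDriver {
  // Stand-ins for AbstractNode::set_ready_for_scale_in() and
  // AbstractNode::set_scale_in_done().
  std::function<void()> set_ready_for_scale_in;
  std::function<void()> set_scale_in_done;

  void OnEvent(ClusterEvent event) {
    switch (event) {
      case ClusterEvent::READY_FOR_SCALE_IN:
        // The scheduler broadcast SCALE_IN: stop taking new work, migrate state,
        // then let the node re-register with the reduced cluster (or, if this
        // node is the one being removed, mark itself done).
        set_ready_for_scale_in();
        break;
      case ClusterEvent::CLUSTER_SCALE_IN_DONE:
        // The new metadata has arrived: report scale_in_done to the scheduler.
        set_scale_in_done();
        break;
    }
  }
};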

View File

@ -30,7 +30,7 @@ bool CommunicatorBase::SendResponse(const void *rsp_data, size_t rsp_len, std::s
}
void CommunicatorBase::Join() {
if (!running_thread_.joinable()) {
MS_LOG(EXCEPTION) << "The running thread of communicator is not joinable.";
MS_LOG(WARNING) << "The running thread of communicator is not joinable.";
return;
}
running_thread_.join();

View File

@ -36,7 +36,8 @@ void LeaderScaler::ScaleOutAsync(const std::shared_ptr<TcpClient> &client, const
MS_LOG(INFO) << "The scheduler is sending scale out to workers and servers!";
}
void LeaderScaler::ScaleInAsync(const std::shared_ptr<TcpClient> &client, const NodeManager &manager) {
void LeaderScaler::ScaleInAsync(const std::shared_ptr<TcpClient> &client, const NodeManager &manager,
bool is_node_scale_in) {
MS_EXCEPTION_IF_NULL(client);
auto message_meta = std::make_shared<MessageMeta>();
message_meta->set_cmd(NodeCommand::SCALE_IN);
@ -44,6 +45,7 @@ void LeaderScaler::ScaleInAsync(const std::shared_ptr<TcpClient> &client, const
ScaleInMessage scale_in_message;
scale_in_message.set_worker_num(manager.worker_num());
scale_in_message.set_server_num(manager.server_num());
scale_in_message.set_is_node_scale_in(is_node_scale_in);
if (!node_->SendMessageSync(client, message_meta, Protos::PROTOBUF, scale_in_message.SerializeAsString().data(),
scale_in_message.ByteSizeLong())) {

View File

@ -42,7 +42,7 @@ class LeaderScaler {
// When the scheduler receives the scale out message, it will send this message to the workers and servers.
void ScaleOutAsync(const std::shared_ptr<TcpClient> &client, const NodeManager &manager);
// When the scheduler receives the scale in message, it will send this message to the workers and servers.
void ScaleInAsync(const std::shared_ptr<TcpClient> &client, const NodeManager &manager);
void ScaleInAsync(const std::shared_ptr<TcpClient> &client, const NodeManager &manager, bool is_node_scale_in);
private:
// The node_ will only be instantiated with scheduler node.

View File

@ -112,8 +112,8 @@ void NodeManager::UpdateCluster() {
timeout_nodes_info_.clear();
for (auto it = heartbeats_.begin(); it != heartbeats_.end(); ++it) {
if (it->second.tv_sec + PSContext::instance()->cluster_config().heartbeat_timeout < current_time.tv_sec) {
MS_LOG(WARNING) << "The node id:" << it->first << " is timeout!";
if (nodes_info_.count(it->first)) {
MS_LOG(WARNING) << "The node id:" << it->first << " is timeout!";
timeout_nodes_info_[it->first] = nodes_info_[it->first];
}
}
@ -145,10 +145,18 @@ void NodeManager::CheckClusterTimeout() {
void NodeManager::AddFinishNode(const std::string &finish_message) { finish_nodes_id_.insert(finish_message); }
void NodeManager::AddScaleOutDoneNode(const std::string &node_id) { scale_out_done_nodes_id_.insert(node_id); }
void NodeManager::AddScaleInDoneNode(const std::string &node_id) { scale_in_done_nodes_id_.insert(node_id); }
bool NodeManager::IsAllNodesRegistered() { return SizeToInt(nodes_info_.size()) == total_node_num_; }
bool NodeManager::IsAllNodesFinished() { return SizeToInt(finish_nodes_id_.size()) == total_node_num_; }
bool NodeManager::IsAllNodesScaleOutDone() { return SizeToInt(scale_out_done_nodes_id_.size()) == total_node_num_; }
bool NodeManager::IsAllNodesScaleInDone() { return SizeToInt(scale_in_done_nodes_id_.size()) == total_node_num_; }
std::unordered_map<std::string, NodeInfo> &NodeManager::nodes_info() { return nodes_info_; }
void NodeManager::UpdateNodeState(const NodeState &state) {
@ -174,6 +182,7 @@ ClusterState NodeManager::GetClusterState() {
void NodeManager::ResetMetadata() {
MS_LOG(WARNING) << "Reset metadata.";
nodes_info_.clear();
heartbeats_.clear();
next_worker_rank_id_ = -1;
next_server_rank_id_ = -1;
}

View File

@ -66,6 +66,11 @@ class NodeManager {
void CheckClusterTimeout();
void AddFinishNode(const std::string &finish_message);
// After the scheduler receives a scale_out_done message from a node, it will save that node's id.
void AddScaleOutDoneNode(const std::string &node_id);
// After the scheduler receives a scale_in_done message from a node, it will save that node's id.
void AddScaleInDoneNode(const std::string &node_id);
// When workers and servers have registered with the scheduler, the scheduler will collect the number of registered
// nodes and determine whether the registered number of workers and servers is equal to total_node_num_.
bool IsAllNodesRegistered();
@ -73,6 +78,13 @@ class NodeManager {
// finish nodes and determine whether the number of finished nodes is equal to total_node_num_.
bool IsAllNodesFinished();
// When workers and servers send a scale_out_done message to the scheduler, the scheduler will collect the number of
// such nodes and determine whether it is equal to total_node_num_.
bool IsAllNodesScaleOutDone();
// When workers and servers send a scale_in_done message to the scheduler, the scheduler will collect the number of
// such nodes and determine whether it is equal to total_node_num_.
bool IsAllNodesScaleInDone();
std::unordered_map<std::string, NodeInfo> &nodes_info();
void set_total_node_num(const int32_t &node_num);
@ -115,6 +127,11 @@ class NodeManager {
std::unordered_map<std::string, NodeInfo> timeout_nodes_info_;
std::unordered_set<std::string> finish_nodes_id_;
// The scheduler aggregates scale_out_done messages from workers/servers
std::unordered_set<std::string> scale_out_done_nodes_id_;
// The scheduler aggregates scale_in_done messages from workers/servers
std::unordered_set<std::string> scale_in_done_nodes_id_;
// Cluster metadata information can be dynamically changed
std::unique_ptr<ClusterMetadata> meta_data_;

View File

@ -29,8 +29,14 @@ enum NodeCommand {
COLLECTIVE_SEND_DATA = 6;
// The scheduler actively sends metadata to the worker and server
SEND_METADATA = 7;
// This command is used to start scale out
SCALE_OUT = 8;
// This command is used to start scale in
SCALE_IN = 9;
// This command is used to synchronize the scale out status of the cluster
SCALE_OUT_DONE = 10;
// This command is used to synchronize the scale in status of the cluster
SCALE_IN_DONE = 11;
}
enum NodeRole {
@ -111,6 +117,10 @@ message ServersMeta {
message SendMetadataMessage {
repeated ServersMeta servers_meta = 1;
// the current worker number.
int32 worker_num = 2;
// the current server number.
int32 server_num = 3;
}
message FinishMessage {
@ -133,8 +143,20 @@ message ScaleOutMessage {
// The scheduler will broadcast the worker/server numbers after scale in to all nodes.
message ScaleInMessage {
// the worker number after scale in
// the worker number after scale in.
int32 worker_num = 1;
// the server number after scale in
// the server number after scale in.
int32 server_num = 2;
// Determine whether the current node is a scale in node.
bool is_node_scale_in = 3;
}
// This message is sent to the scheduler to notify the completion of scale out
message ScaleOutDoneMessage {
string node_id = 1;
}
// This message is sent to the scheduler to notify the completion of scale in
message ScaleInDoneMessage {
string node_id = 1;
}
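For reference, the extended messages would typically be used from C++ through the protobuf-generated accessors shown elsewhere in this diff (set_worker_num, set_is_node_scale_in, SerializeAsString, ParseFromArray). A small sketch; the include path is a placeholder and any proto package namespace is omitted:

#include <cstdint>
#include <string>
#include "comm.pb.h"  // placeholder: the protoc-generated header for this .proto file

// Scheduler side: tell one node the post-scale-in cluster size and whether it is
// the node being removed.
std::string BuildScaleIn(int32_t worker_num, int32_t server_num, bool is_node_scale_in) {
  ScaleInMessage msg;
  msg.set_worker_num(worker_num);
  msg.set_server_num(server_num);
  msg.set_is_node_scale_in(is_node_scale_in);
  return msg.SerializeAsString();
}

// Node side: parse the payload received with the SCALE_IN command.
bool IsCurrentNodeScaleIn(const void *data, size_t size) {
  ScaleInMessage msg;
  msg.ParseFromArray(data, static_cast<int>(size));
  return msg.is_node_scale_in();
}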

View File

@ -37,6 +37,7 @@ bool SchedulerNode::Start(const uint32_t &timeout) {
MS_LOG(ERROR) << "Start Scheduler node timeout!";
return false;
}
node_manager_.UpdateClusterState(ClusterState::CLUSTER_READY);
MS_LOG(INFO) << "Start the scheduler node is successful!";
return true;
@ -55,6 +56,7 @@ void SchedulerNode::ProcessHeartbeat(std::shared_ptr<TcpServer> server, std::sha
HeartbeatRespMessage heartbeat_resp_message;
MS_LOG(DEBUG) << "The cluster state:" << node_manager_.GetClusterState();
heartbeat_resp_message.set_cluster_state(node_manager_.GetClusterState());
server->SendMessage(conn, meta, Protos::PROTOBUF, heartbeat_resp_message.SerializeAsString().data(),
@ -77,6 +79,8 @@ void SchedulerNode::InitCommandHandler() {
handlers_[NodeCommand::REGISTER] = &SchedulerNode::ProcessRegister;
handlers_[NodeCommand::FINISH] = &SchedulerNode::ProcessFinish;
handlers_[NodeCommand::FETCH_METADATA] = &SchedulerNode::ProcessFetchMetadata;
handlers_[NodeCommand::SCALE_OUT_DONE] = &SchedulerNode::ProcessScaleOutDone;
handlers_[NodeCommand::SCALE_IN_DONE] = &SchedulerNode::ProcessScaleInDone;
}
void SchedulerNode::CreateTcpServer() {
@ -135,7 +139,6 @@ void SchedulerNode::ProcessRegister(std::shared_ptr<TcpServer> server, std::shar
SendMetadata(client);
MS_LOG(INFO) << "Send meta data to" << kvs.first;
}
node_manager_.UpdateClusterState(ClusterState::CLUSTER_READY);
wait_start_cond_.notify_all();
}
}
@ -177,6 +180,56 @@ void SchedulerNode::ProcessFetchMetadata(std::shared_ptr<TcpServer> server, std:
fetch_servers_message.ByteSizeLong());
}
void SchedulerNode::ProcessScaleOutDone(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn,
std::shared_ptr<MessageMeta> meta, const void *data, size_t size) {
MS_EXCEPTION_IF_NULL(server);
MS_EXCEPTION_IF_NULL(conn);
MS_EXCEPTION_IF_NULL(meta);
MS_EXCEPTION_IF_NULL(data);
ScaleOutDoneMessage scale_out_done_message;
scale_out_done_message.ParseFromArray(data, size);
std::string node_id = scale_out_done_message.node_id();
MS_LOG(INFO) << "The scheduler process a scale_out_done message from node id:" << node_id;
node_manager_.AddScaleOutDoneNode(node_id);
server->SendMessage(conn, meta, Protos::PROTOBUF, data, size);
if (node_manager_.IsAllNodesScaleOutDone()) {
auto node_infos = node_manager_.nodes_info();
for (const auto &kvs : node_infos) {
auto client = GetOrCreateClient(kvs.second);
SendScaleOutDone(client);
}
is_ready_ = true;
node_manager_.UpdateClusterState(ClusterState::CLUSTER_READY);
}
}
void SchedulerNode::ProcessScaleInDone(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn,
std::shared_ptr<MessageMeta> meta, const void *data, size_t size) {
MS_EXCEPTION_IF_NULL(server);
MS_EXCEPTION_IF_NULL(conn);
MS_EXCEPTION_IF_NULL(meta);
MS_EXCEPTION_IF_NULL(data);
ScaleInDoneMessage scale_in_done_message;
scale_in_done_message.ParseFromArray(data, size);
std::string node_id = scale_in_done_message.node_id();
MS_LOG(INFO) << "The scheduler process a scale_in_done message from node id:" << node_id;
node_manager_.AddScaleInDoneNode(node_id);
server->SendMessage(conn, meta, Protos::PROTOBUF, data, size);
if (node_manager_.IsAllNodesScaleInDone()) {
auto node_infos = node_manager_.nodes_info();
for (const auto &kvs : node_infos) {
auto client = GetOrCreateClient(kvs.second);
SendScaleInDone(client);
}
is_ready_ = true;
node_manager_.UpdateClusterState(ClusterState::CLUSTER_READY);
}
}
void SchedulerNode::SendMetadata(const std::shared_ptr<TcpClient> &client) {
MS_EXCEPTION_IF_NULL(client);
auto message_meta = std::make_shared<MessageMeta>();
@ -184,6 +237,8 @@ void SchedulerNode::SendMetadata(const std::shared_ptr<TcpClient> &client) {
SendMetadataMessage send_metadata_message;
std::vector<ServersMeta> servers_meta_list = node_manager_.FetchServersMeta();
send_metadata_message.set_worker_num(node_manager_.worker_num());
send_metadata_message.set_server_num(node_manager_.server_num());
*send_metadata_message.mutable_servers_meta() = {servers_meta_list.begin(), servers_meta_list.end()};
@ -214,6 +269,40 @@ void SchedulerNode::SendFinish(const std::shared_ptr<TcpClient> &client) {
<< " the node id:" << node_info_.node_id_ << "is sending finish to workers and servers!";
}
void SchedulerNode::SendScaleOutDone(const std::shared_ptr<TcpClient> &client) {
MS_EXCEPTION_IF_NULL(client);
auto message_meta = std::make_shared<MessageMeta>();
message_meta->set_cmd(NodeCommand::SCALE_OUT_DONE);
// The scheduler does not need to carry any data when sending the scale_out_done command
std::string resp_data;
if (!SendMessageSync(client, message_meta, Protos::PROTOBUF, resp_data.data(), resp_data.size())) {
MS_LOG(EXCEPTION) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
<< " the node id:" << node_info_.node_id_ << " send scale_out_done timeout!";
}
MS_LOG(INFO) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
<< " the node id:" << node_info_.node_id_ << "is sending scale_out_done to workers and servers!";
}
void SchedulerNode::SendScaleInDone(const std::shared_ptr<TcpClient> &client) {
MS_EXCEPTION_IF_NULL(client);
auto message_meta = std::make_shared<MessageMeta>();
message_meta->set_cmd(NodeCommand::SCALE_IN_DONE);
// The scheduler does not need to carry any data when sending the scale_in_done command
std::string resp_data;
if (!SendMessageSync(client, message_meta, Protos::PROTOBUF, resp_data.data(), resp_data.size())) {
MS_LOG(EXCEPTION) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
<< " the node id:" << node_info_.node_id_ << " send scale_in_done timeout!";
}
MS_LOG(INFO) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
<< " the node id:" << node_info_.node_id_ << "is sending scale_in_done to workers and servers!";
}
void SchedulerNode::StartUpdateClusterStateTimer() {
MS_LOG(WARNING) << "The scheduler start a heartbeat timer!";
update_state_thread_ = std::make_unique<std::thread>([&]() {
@ -330,11 +419,11 @@ void SchedulerNode::ProcessScaleOut(std::shared_ptr<HttpMessageHandler> resp) {
node_manager_.set_total_node_num(total_worker_num + total_server_num);
node_manager_.UpdateClusterState(ClusterState::CLUSTER_SCALE_OUT);
auto node_infos = node_manager_.nodes_info();
node_manager_.ResetMetadata();
for (const auto &kvs : node_infos) {
auto client = GetOrCreateClient(kvs.second);
leader_scaler_->ScaleOutAsync(client, node_manager_);
}
node_manager_.ResetMetadata();
MS_LOG(INFO) << "Scheduler send scale out successful.";
nlohmann::json js;
@ -375,18 +464,24 @@ void SchedulerNode::ProcessScaleIn(std::shared_ptr<HttpMessageHandler> resp) {
return;
}
std::vector<std::string> node_ids;
status = resp->ParseNodeIdsFromKey(kNodesIds, &node_ids);
std::vector<std::string> scale_in_node_ids;
status = resp->ParseNodeIdsFromKey(kNodesIds, &scale_in_node_ids);
if (status != RequestProcessResultCode::kSuccess) {
resp->ErrorResponse(HTTP_BADREQUEST, status);
return;
}
MS_LOG(WARNING) << "The scale in node ids:" << scale_in_node_ids;
std::unordered_map<std::string, bool> scale_in_nodes;
int32_t scale_worker_num = 0;
int32_t scale_server_num = 0;
auto node_infos = node_manager_.nodes_info();
for (auto const &val : node_ids) {
node_manager_.ResetMetadata();
for (auto const &val : scale_in_node_ids) {
if (node_infos.count(val)) {
scale_in_nodes[val] = true;
NodeInfo info = node_infos[val];
if (info.node_role_ == NodeRole::WORKER) {
scale_worker_num++;
@ -407,10 +502,13 @@ void SchedulerNode::ProcessScaleIn(std::shared_ptr<HttpMessageHandler> resp) {
node_manager_.UpdateClusterState(ClusterState::CLUSTER_SCALE_IN);
for (const auto &kvs : node_infos) {
auto client = GetOrCreateClient(kvs.second);
leader_scaler_->ScaleInAsync(client, node_manager_);
bool is_node_scale_in = false;
if (scale_in_nodes.count(kvs.first)) {
is_node_scale_in = true;
}
leader_scaler_->ScaleInAsync(client, node_manager_, is_node_scale_in);
}
node_manager_.ResetMetadata();
nlohmann::json js;
js["message"] = "Cluster begin to scale in.";
resp->AddRespString(js.dump());
@ -446,7 +544,7 @@ void SchedulerNode::ProcessGetNodesInfo(std::shared_ptr<HttpMessageHandler> resp
for (const auto &kvs : node_infos) {
std::unordered_map<std::string, std::string> res;
res["node_id"] = kvs.second.node_id_;
res["rank_id"] = kvs.second.rank_id_;
res["rank_id"] = std::to_string(kvs.second.rank_id_);
res["role"] = CommUtil::NodeRoleToString(kvs.second.node_role_);
js["node_ids"].push_back(res);
}

View File

@ -79,11 +79,26 @@ class SchedulerNode : public Node {
void ProcessFetchMetadata(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn,
std::shared_ptr<MessageMeta> meta, const void *data, size_t size);
// After scheduler collects all registered message, it actively sends metadata to workers and servers.
// Process scale_out_done messages from workers/servers
void ProcessScaleOutDone(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn,
std::shared_ptr<MessageMeta> meta, const void *data, size_t size);
// Process scale_in_done messages from workers/servers
void ProcessScaleInDone(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn,
std::shared_ptr<MessageMeta> meta, const void *data, size_t size);
// After the scheduler collects all register messages, it actively sends metadata to the node connected by the client.
void SendMetadata(const std::shared_ptr<TcpClient> &client);
// // After scheduler collects all finish message, it actively sends finish message to workers and servers.
// After the scheduler collects all finish messages, it actively sends finish to the node connected by the client.
void SendFinish(const std::shared_ptr<TcpClient> &client);
// After the scheduler collects all scale_out_done messages, it actively sends scale_out_done to the node connected by
// the client.
void SendScaleOutDone(const std::shared_ptr<TcpClient> &client);
// After the scheduler collects all scale_in_done messages, it actively sends scale_in_done to the node connected by
// the client.
void SendScaleInDone(const std::shared_ptr<TcpClient> &client);
// Handle the scale out http request, then delegate to the leader scaler to process scale out asynchronously.
void ProcessScaleOut(std::shared_ptr<HttpMessageHandler> resp);