forked from mindspore-Ecosystem/mindspore
realize scale out rollback
This commit is contained in:
parent
b554d27a2d
commit
2f068a4218
|
@ -400,6 +400,8 @@ void Server::RegisterMessageCallback(const std::shared_ptr<ps::core::TcpCommunic
|
|||
std::bind(&Server::HandleQueryInstanceRequest, this, std::placeholders::_1));
|
||||
communicator->RegisterMsgCallBack("syncAfterRecover",
|
||||
std::bind(&Server::HandleSyncAfterRecoveryRequest, this, std::placeholders::_1));
|
||||
communicator->RegisterMsgCallBack("queryNodeScaleState",
|
||||
std::bind(&Server::HandleQueryNodeScaleStateRequest, this, std::placeholders::_1));
|
||||
}
|
||||
|
||||
void Server::InitExecutor() {
|
||||
|
@ -700,6 +702,22 @@ void Server::HandleSyncAfterRecoveryRequest(const std::shared_ptr<ps::core::Mess
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Handles the "queryNodeScaleState" request. Replies with a JSON object whose
// "node_scale_state" field holds the string returned by server_node_->node_scale_state_str().
// Nothing is returned to the caller; failures are logged and the handler exits early.
void Server::HandleQueryNodeScaleStateRequest(const std::shared_ptr<ps::core::MessageHandler> &message) {
  MS_ERROR_IF_NULL_WO_RET_VAL(message);
  // The server node is dereferenced below, so bail out (with an error log) if it is missing.
  MS_ERROR_IF_NULL_WO_RET_VAL(server_node_);

  nlohmann::basic_json<std::map, std::vector, std::string> response;
  response["node_scale_state"] = server_node_->node_scale_state_str();
  // Serialize exactly once so that the pointer and the length handed to SendResponse describe the
  // same buffer, and so the JSON is not re-dumped for the send and again for the log line.
  const std::string response_str = response.dump();

  auto tcp_comm = std::dynamic_pointer_cast<ps::core::TcpCommunicator>(communicator_with_server_);
  MS_ERROR_IF_NULL_WO_RET_VAL(tcp_comm);
  if (!tcp_comm->SendResponse(response_str.c_str(), response_str.size(), message)) {
    MS_LOG(ERROR) << "Sending response failed.";
    return;
  }

  MS_LOG(INFO) << "Response query node scale state success, response data is " << response_str;
}
|
||||
} // namespace server
|
||||
} // namespace fl
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -176,6 +176,8 @@ class BACKEND_EXPORT Server {
|
|||
// Synchronize after recovery is completed to ensure consistency.
|
||||
void HandleSyncAfterRecoveryRequest(const std::shared_ptr<ps::core::MessageHandler> &message);
|
||||
|
||||
void HandleQueryNodeScaleStateRequest(const std::shared_ptr<ps::core::MessageHandler> &message);
|
||||
|
||||
// The server node is initialized in Server.
|
||||
std::shared_ptr<ps::core::ServerNode> server_node_;
|
||||
|
||||
|
|
|
@ -248,6 +248,7 @@ using BarrierBeforeScaleIn = std::function<void(void)>;
|
|||
// is done.
|
||||
using HandlerAfterScaleOut = std::function<void(void)>;
|
||||
using HandlerAfterScaleIn = std::function<void(void)>;
|
||||
using HandlerAfterScaleOutRollback = std::function<void(void)>;
|
||||
|
||||
constexpr char kClusterSafeMode[] = "The cluster is in safemode.";
|
||||
constexpr char kJobNotAvailable[] = "The server's training job is disabled or finished.";
|
||||
|
|
|
@ -870,6 +870,12 @@ void AbstractNode::ProcessSendMetadata(const std::shared_ptr<TcpConnection> &con
|
|||
OnEventCallback(ClusterEvent::CLUSTER_SCALE_IN_DONE);
|
||||
}
|
||||
|
||||
if (cancelSafeModeFn_ && current_cluster_state_ == ClusterState::CLUSTER_SCALE_OUT_ROLLBACK) {
|
||||
MS_LOG(WARNING) << "Trigger cluster scale out rollback done event.";
|
||||
OnEventCallback(ClusterEvent::CLUSTER_SCALE_OUT_ROLLBACK_DONE);
|
||||
cancelSafeModeFn_();
|
||||
}
|
||||
|
||||
std::lock_guard<std::mutex> lock(client_mutex_);
|
||||
connected_nodes_.clear();
|
||||
|
||||
|
@ -934,6 +940,27 @@ void AbstractNode::ProcessEvent(const std::shared_ptr<TcpConnection> &conn, cons
|
|||
}
|
||||
}
|
||||
|
||||
void AbstractNode::ProcessScaleOutRollback(const std::shared_ptr<TcpConnection> &conn,
|
||||
const std::shared_ptr<MessageMeta> &meta, const Protos &, const void *data,
|
||||
size_t size) {
|
||||
MS_EXCEPTION_IF_NULL(conn);
|
||||
MS_EXCEPTION_IF_NULL(meta);
|
||||
MS_EXCEPTION_IF_NULL(data);
|
||||
|
||||
if (!server_->SendMessage(conn, meta, Protos::RAW, data, size)) {
|
||||
MS_LOG(WARNING) << "Server response message failed.";
|
||||
}
|
||||
|
||||
UpdateClusterState(ClusterState::CLUSTER_SCALE_OUT_ROLLBACK);
|
||||
|
||||
MS_LOG(INFO) << "[Scale out rollback]: begin to set scale out rollback.";
|
||||
Register(client_to_scheduler_);
|
||||
std::lock_guard<std::mutex> lock(client_mutex_);
|
||||
connected_nodes_.clear();
|
||||
|
||||
MS_LOG(INFO) << "The node begin to start scale out rollback.";
|
||||
}
|
||||
|
||||
void AbstractNode::ProcessScaleOut(const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta,
|
||||
const Protos &, const void *data, size_t size) {
|
||||
MS_EXCEPTION_IF_NULL(conn);
|
||||
|
@ -1322,6 +1349,7 @@ void AbstractNode::InitServerHandler() {
|
|||
server_handler_[NodeCommand::SEND_EVENT] = &AbstractNode::ProcessEvent;
|
||||
server_handler_[NodeCommand::SCHEDULER_RECOVERY] = &AbstractNode::ProcessSchedulerRecovery;
|
||||
server_handler_[NodeCommand::PREPARE_BUILDING_NETWORK] = &AbstractNode::ProcessPrepareBuildingNetwork;
|
||||
server_handler_[NodeCommand::SCALE_OUT_ROLLBACK] = &AbstractNode::ProcessScaleOutRollback;
|
||||
}
|
||||
|
||||
void AbstractNode::InitNodeInfo(const NodeRole &role) {
|
||||
|
@ -1481,6 +1509,11 @@ void AbstractNode::ProcessPrepareBuildingNetwork(const std::shared_ptr<TcpConnec
|
|||
MS_LOG(INFO) << "prepare for building network success.";
|
||||
}
|
||||
}
|
||||
|
||||
// Returns the follower scaler's current node scale state as a string.
// Raises an exception (via MS_EXCEPTION_IF_NULL) when the follower scaler has not been created.
std::string AbstractNode::node_scale_state_str() {
  MS_EXCEPTION_IF_NULL(follower_scaler_);
  std::string scale_state = follower_scaler_->GetNodeScaleStateStr();
  return scale_state;
}
|
||||
} // namespace core
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -127,6 +127,8 @@ class BACKEND_EXPORT AbstractNode : public Node {
|
|||
void RegisterFollowerScalerHandlerAfterScaleOut(const std::string &module, const HandlerAfterScaleOut &handler);
|
||||
void RegisterFollowerScalerHandlerAfterScaleIn(const std::string &module, const HandlerAfterScaleIn &handler);
|
||||
|
||||
std::string node_scale_state_str();
|
||||
|
||||
PersistentState persistent_state() const;
|
||||
void set_persistent_state(PersistentState persistent_state);
|
||||
|
||||
|
@ -259,6 +261,9 @@ class BACKEND_EXPORT AbstractNode : public Node {
|
|||
void OnRecvCollectiveData(const MessageMeta &message_meta, const VectorPtr &data);
|
||||
void ConnectToScheduler();
|
||||
|
||||
void ProcessScaleOutRollback(const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta,
|
||||
const Protos &, const void *data, size_t size);
|
||||
|
||||
std::unique_ptr<std::thread> heart_beat_thread_;
|
||||
std::unique_ptr<std::thread> client_to_scheduler_thread_;
|
||||
std::shared_ptr<TcpClient> client_to_scheduler_;
|
||||
|
|
|
@ -86,16 +86,17 @@ constexpr char kLibeventLogPrefix[] = "[libevent log]:";
|
|||
|
||||
// Find the corresponding string style of cluster state through the subscript of the enum:ClusterState
|
||||
const std::vector<std::string> kClusterState = {
|
||||
"ClUSTER_STARTING", // Initialization state when the cluster is just started.
|
||||
"CLUSTER_READY", // The state after all nodes are successfully registered.
|
||||
"CLUSTER_EXIT", // The state after the cluster exits successfully.
|
||||
"NODE_TIMEOUT", // When a node has a heartbeat timeout
|
||||
"CLUSTER_SCALE_OUT", // When the cluster is scale out.
|
||||
"CLUSTER_SCALE_IN", // When the cluster is scale in.
|
||||
"CLUSTER_NEW_INSTANCE", // When the cluster is doing NEW_INSTANCE.
|
||||
"CLUSTER_ENABLE_FLS", // When the cluster is doing ENABLE_FLS.
|
||||
"CLUSTER_DISABLE_FLS", // When the cluster is doing DISABLE_FLS.
|
||||
"CLUSTER_SCHEDULER_RECOVERY" // When the cluster is doing SCHEDULER_RECOVERY.
|
||||
"ClUSTER_STARTING", // Initialization state when the cluster is just started.
|
||||
"CLUSTER_READY", // The state after all nodes are successfully registered.
|
||||
"CLUSTER_EXIT", // The state after the cluster exits successfully.
|
||||
"NODE_TIMEOUT", // When a node has a heartbeat timeout
|
||||
"CLUSTER_SCALE_OUT", // When the cluster is scale out.
|
||||
"CLUSTER_SCALE_IN", // When the cluster is scale in.
|
||||
"CLUSTER_NEW_INSTANCE", // When the cluster is doing NEW_INSTANCE.
|
||||
"CLUSTER_ENABLE_FLS", // When the cluster is doing ENABLE_FLS.
|
||||
"CLUSTER_DISABLE_FLS", // When the cluster is doing DISABLE_FLS.
|
||||
"CLUSTER_SCHEDULER_RECOVERY", // When the cluster is doing SCHEDULER_RECOVERY.
|
||||
"CLUSTER_SCALE_OUT_ROLLBACK", // When the cluster is scale out rollback.
|
||||
};
|
||||
|
||||
class CommUtil {
|
||||
|
|
|
@ -61,6 +61,7 @@ enum class TcpUserCommand {
|
|||
kExchangeKeys,
|
||||
kGetKeys,
|
||||
kGetOneDeviceMeta,
|
||||
kQueryNodeScaleState
|
||||
};
|
||||
|
||||
// CommunicatorBase is used to receive request and send response for server.
|
||||
|
|
|
@ -63,7 +63,8 @@ const std::unordered_map<TcpUserCommand, std::string> kUserCommandToMsgType = {
|
|||
{TcpUserCommand::kQueryInstance, "queryInstance"},
|
||||
{TcpUserCommand::kEnableFLS, "enableFLS"},
|
||||
{TcpUserCommand::kDisableFLS, "disableFLS"},
|
||||
{TcpUserCommand::kSyncAfterRecover, "syncAfterRecover"}};
|
||||
{TcpUserCommand::kSyncAfterRecover, "syncAfterRecover"},
|
||||
{TcpUserCommand::kQueryNodeScaleState, "queryNodeScaleState"}};
|
||||
|
||||
class TcpCommunicator : public CommunicatorBase {
|
||||
public:
|
||||
|
|
|
@ -72,6 +72,18 @@ FollowerScaler::FollowerScaler(AbstractNode *node)
|
|||
ProcessAfterScaleIn();
|
||||
}
|
||||
});
|
||||
|
||||
process_after_scale_out_rollback_thread_ = std::thread([&]() {
|
||||
while (running_.load()) {
|
||||
std::unique_lock<std::mutex> lock(scale_out_mtx_);
|
||||
scale_out_cv_.wait(
|
||||
lock, [&]() -> bool { return !running_.load() || scaling_state_.load() == NodeScaleState::kRollback; });
|
||||
if (!running_.load()) {
|
||||
break;
|
||||
}
|
||||
ProcessAfterScaleOutRollback();
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
FollowerScaler::~FollowerScaler() {
|
||||
|
@ -90,6 +102,9 @@ FollowerScaler::~FollowerScaler() {
|
|||
if (process_after_scale_in_thread_.joinable()) {
|
||||
process_after_scale_in_thread_.join();
|
||||
}
|
||||
if (process_after_scale_out_rollback_thread_.joinable()) {
|
||||
process_after_scale_out_rollback_thread_.join();
|
||||
}
|
||||
}
|
||||
|
||||
void FollowerScaler::RegisterScaleEventCallbacks() {
|
||||
|
@ -118,11 +133,19 @@ void FollowerScaler::RegisterScaleEventCallbacks() {
|
|||
scale_in_cv_.notify_all();
|
||||
};
|
||||
|
||||
scale_out_rollback_done_event_callback_ = [&]() -> void {
|
||||
std::unique_lock<std::mutex> lock(scale_out_mtx_);
|
||||
scaling_state_ = NodeScaleState::kRollback;
|
||||
scale_out_cv_.notify_all();
|
||||
};
|
||||
|
||||
MS_EXCEPTION_IF_NULL(node_);
|
||||
node_->RegisterEventCallback(core::ClusterEvent::READY_FOR_SCALE_OUT, ready_for_scale_out_event_callback_);
|
||||
node_->RegisterEventCallback(core::ClusterEvent::READY_FOR_SCALE_IN, ready_for_scale_in_event_callback_);
|
||||
node_->RegisterEventCallback(core::ClusterEvent::CLUSTER_SCALE_OUT_DONE, scale_out_done_event_callback_);
|
||||
node_->RegisterEventCallback(core::ClusterEvent::CLUSTER_SCALE_IN_DONE, scale_in_done_event_callback_);
|
||||
node_->RegisterEventCallback(core::ClusterEvent::CLUSTER_SCALE_OUT_ROLLBACK_DONE,
|
||||
scale_out_rollback_done_event_callback_);
|
||||
}
|
||||
|
||||
void FollowerScaler::ProcessBeforeScaleOut() {
|
||||
|
@ -171,6 +194,11 @@ void FollowerScaler::ProcessAfterScaleIn() {
|
|||
node_->set_scale_in_done();
|
||||
}
|
||||
|
||||
// Finishes the scale out rollback on this node: logs the event and puts the
// scaling state back to kNormal.
void FollowerScaler::ProcessAfterScaleOutRollback() {
  MS_LOG(INFO) << "Scaling out rollback operation is done. Do scaling out rollback for this node.";
  scaling_state_.store(NodeScaleState::kNormal);
}
|
||||
|
||||
void FollowerScaler::RegisterBarrierBeforeScaleOut(const std::string &module, const BarrierBeforeScaleOut &barrier) {
|
||||
(void)barriers_before_scale_out_.try_emplace(module, barrier);
|
||||
}
|
||||
|
@ -186,6 +214,21 @@ void FollowerScaler::RegisterHandlerAfterScaleOut(const std::string &module, con
|
|||
void FollowerScaler::RegisterHandlerAfterScaleIn(const std::string &module, const HandlerAfterScaleIn &handler) {
|
||||
(void)handlers_after_scale_in_.try_emplace(module, handler);
|
||||
}
|
||||
|
||||
std::string FollowerScaler::GetNodeScaleStateStr() {
|
||||
switch (scaling_state_) {
|
||||
case NodeScaleState::kNormal:
|
||||
return "kNormal";
|
||||
case NodeScaleState::kPreparing:
|
||||
return "kPreparing";
|
||||
case NodeScaleState::kWaiting:
|
||||
return "kWaiting";
|
||||
case NodeScaleState::kScaling:
|
||||
return "kScaling";
|
||||
default:
|
||||
MS_LOG(EXCEPTION) << "scale_state is not supported.";
|
||||
}
|
||||
}
|
||||
} // namespace core
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -44,7 +44,9 @@ enum class NodeScaleState {
|
|||
kWaiting,
|
||||
// Server/worker node will switch to this state after scheduler's scaling out/in operation is done.
|
||||
// When in this state, server/worker node can't send/receive messages.
|
||||
kScaling
|
||||
kScaling,
|
||||
// Server/worker node will switch to this state after the scheduler's scale out rollback operation is done.
|
||||
kRollback,
|
||||
};
|
||||
|
||||
// The class helps worker/server node to elastic scale while running a training job. In this class, the scaling events
|
||||
|
@ -66,6 +68,8 @@ class FollowerScaler {
|
|||
void ProcessAfterScaleOut();
|
||||
void ProcessAfterScaleIn();
|
||||
|
||||
void ProcessAfterScaleOutRollback();
|
||||
|
||||
void RegisterBarrierBeforeScaleOut(const std::string &module, const BarrierBeforeScaleOut &barrier);
|
||||
void RegisterBarrierBeforeScaleIn(const std::string &module, const BarrierBeforeScaleIn &barrier);
|
||||
void RegisterHandlerAfterScaleOut(const std::string &module, const HandlerAfterScaleOut &handler);
|
||||
|
@ -74,6 +78,8 @@ class FollowerScaler {
|
|||
// Register the scaling event callbacks to the node.
|
||||
void RegisterScaleEventCallbacks();
|
||||
|
||||
std::string GetNodeScaleStateStr();
|
||||
|
||||
private:
|
||||
AbstractNode *node_;
|
||||
|
||||
|
@ -87,6 +93,7 @@ class FollowerScaler {
|
|||
std::thread process_before_scale_in_thread_;
|
||||
std::thread process_after_scale_out_thread_;
|
||||
std::thread process_after_scale_in_thread_;
|
||||
std::thread process_after_scale_out_rollback_thread_;
|
||||
|
||||
// Variables for signals of scaling out/in operations.
|
||||
std::mutex scale_out_mtx_;
|
||||
|
@ -104,6 +111,7 @@ class FollowerScaler {
|
|||
std::function<void(void)> ready_for_scale_in_event_callback_;
|
||||
std::function<void(void)> scale_out_done_event_callback_;
|
||||
std::function<void(void)> scale_in_done_event_callback_;
|
||||
std::function<void(void)> scale_out_rollback_done_event_callback_;
|
||||
};
|
||||
} // namespace core
|
||||
} // namespace ps
|
||||
|
|
|
@ -97,6 +97,26 @@ void InstanceManager::DisableFLSAsync(const std::shared_ptr<TcpClient> &client,
|
|||
|
||||
MS_LOG(INFO) << "The scheduler is sending query instance to workers and servers!";
|
||||
}
|
||||
|
||||
void InstanceManager::QueryNodeScaleState(const std::shared_ptr<TcpClient> &client, const NodeManager &,
|
||||
const uint64_t &request_id, const NodeInfo &node_info) {
|
||||
MS_EXCEPTION_IF_NULL(client);
|
||||
MS_EXCEPTION_IF_NULL(node_);
|
||||
auto message_meta = std::make_shared<MessageMeta>();
|
||||
MS_EXCEPTION_IF_NULL(message_meta);
|
||||
message_meta->set_cmd(NodeCommand::SEND_DATA);
|
||||
message_meta->set_request_id(request_id);
|
||||
message_meta->set_rank_id(node_info.rank_id_);
|
||||
message_meta->set_role(node_info.node_role_);
|
||||
message_meta->set_user_cmd(static_cast<int32_t>(TcpUserCommand::kQueryNodeScaleState));
|
||||
|
||||
std::string res;
|
||||
if (!client->SendMessage(message_meta, Protos::RAW, res.data(), res.length())) {
|
||||
MS_LOG(WARNING) << "Send query node scale state timeout!";
|
||||
}
|
||||
|
||||
MS_LOG(INFO) << "The scheduler is sending query node scale state to workers and servers!";
|
||||
}
|
||||
} // namespace core
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -53,6 +53,9 @@ class InstanceManager {
|
|||
void DisableFLSAsync(const std::shared_ptr<TcpClient> &client, const NodeManager &manager, const uint64_t &request_id,
|
||||
const NodeInfo &node_info);
|
||||
|
||||
void QueryNodeScaleState(const std::shared_ptr<TcpClient> &client, const NodeManager &, const uint64_t &request_id,
|
||||
const NodeInfo &node_info);
|
||||
|
||||
private:
|
||||
// The node_ will only be instantiated with scheduler node.
|
||||
Node *const node_;
|
||||
|
|
|
@ -58,6 +58,21 @@ void LeaderScaler::ScaleInAsync(const std::shared_ptr<TcpClient> &client, const
|
|||
|
||||
MS_LOG(INFO) << "The scheduler is sending scale in to workers and servers!";
|
||||
}
|
||||
|
||||
void LeaderScaler::ScaleOutRollbackAsync(const std::shared_ptr<TcpClient> &client, const NodeManager &manager) {
|
||||
MS_EXCEPTION_IF_NULL(client);
|
||||
MS_EXCEPTION_IF_NULL(node_);
|
||||
auto message_meta = std::make_shared<MessageMeta>();
|
||||
MS_EXCEPTION_IF_NULL(message_meta);
|
||||
message_meta->set_cmd(NodeCommand::SCALE_OUT_ROLLBACK);
|
||||
|
||||
std::string data = "";
|
||||
if (!node_->SendMessageSync(client, message_meta, Protos::PROTOBUF, data.data(), data.size())) {
|
||||
MS_LOG(WARNING) << "Send scale out rollback timeout!";
|
||||
}
|
||||
|
||||
MS_LOG(INFO) << "The scheduler is sending scale out rollback to workers and servers!";
|
||||
}
|
||||
} // namespace core
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -44,6 +44,8 @@ class LeaderScaler {
|
|||
// When the scheduler receives the scale in message, it will send this message to the workers and servers.
|
||||
void ScaleInAsync(const std::shared_ptr<TcpClient> &client, const NodeManager &manager, bool is_node_scale_in);
|
||||
|
||||
void ScaleOutRollbackAsync(const std::shared_ptr<TcpClient> &client, const NodeManager &manager);
|
||||
|
||||
private:
|
||||
// The node_ will only be instantiated with scheduler node.
|
||||
Node *const node_;
|
||||
|
|
|
@ -36,6 +36,7 @@ enum class ClusterEvent {
|
|||
ON_PREPARE_PERSIST = 7,
|
||||
ON_BEGIN_PERSIST = 8,
|
||||
ON_SEND_META_DATA = 9,
|
||||
CLUSTER_SCALE_OUT_ROLLBACK_DONE = 10
|
||||
};
|
||||
|
||||
struct NodeInfo {
|
||||
|
|
|
@ -313,8 +313,12 @@ void NodeManager::UpdateNodeState(const NodeState &state) {
|
|||
|
||||
void NodeManager::UpdateClusterState(const ClusterState &state) {
|
||||
std::lock_guard<std::mutex> lk(cluster_mutex_);
|
||||
MS_LOG(INFO) << "[state]: Scheduler change state from:" << CommUtil::ClusterStateToString(cluster_state_) << " to "
|
||||
<< CommUtil::ClusterStateToString(state);
|
||||
std::string state_str = CommUtil::ClusterStateToString(state);
|
||||
if (state_str.empty() || state == cluster_state_) {
|
||||
return;
|
||||
}
|
||||
MS_LOG(INFO) << "[state]: Cluster state change from:" << CommUtil::ClusterStateToString(cluster_state_) << " to "
|
||||
<< state_str;
|
||||
cluster_state_ = state;
|
||||
}
|
||||
|
||||
|
|
|
@ -63,6 +63,8 @@ enum NodeCommand {
|
|||
// Query the ready status to finish transform graph of computed node,
|
||||
// used in disaster recovery mode to prevent timeout of waiting for graph transformation.
|
||||
QUERY_FINISH_TRANSFORM = 23;
|
||||
// This command is used to start scale out rollback
|
||||
SCALE_OUT_ROLLBACK = 24;
|
||||
}
|
||||
|
||||
enum NodeRole {
|
||||
|
@ -153,6 +155,7 @@ enum ClusterState {
|
|||
CLUSTER_ENABLE_FLS = 7;
|
||||
CLUSTER_DISABLE_FLS = 8;
|
||||
CLUSTER_SCHEDULER_RECOVERY = 9;
|
||||
CLUSTER_SCALE_OUT_ROLLBACK = 10;
|
||||
}
|
||||
|
||||
message HeartbeatRespMessage {
|
||||
|
|
|
@ -819,7 +819,7 @@ bool SchedulerNode::Finish(const uint32_t &) {
|
|||
return true;
|
||||
}
|
||||
|
||||
void SchedulerNode::ProcessScaleoutRollback(const std::shared_ptr<HttpMessageHandler> &resp) {
|
||||
void SchedulerNode::ProcessScaleOutRollback(const std::shared_ptr<HttpMessageHandler> &resp) {
|
||||
MS_EXCEPTION_IF_NULL(resp);
|
||||
RequestProcessResult status(RequestProcessResultCode::kSuccess);
|
||||
if (node_manager_.GetClusterState() != ClusterState::CLUSTER_SCALE_OUT) {
|
||||
|
@ -828,8 +828,21 @@ void SchedulerNode::ProcessScaleoutRollback(const std::shared_ptr<HttpMessageHan
|
|||
resp->ErrorResponse(HTTP_BADREQUEST, status);
|
||||
return;
|
||||
}
|
||||
// set the last worker num and last server num
|
||||
|
||||
if (node_manager_.GetClusterState() == ClusterState::CLUSTER_SCALE_OUT_ROLLBACK) {
|
||||
std::string message =
|
||||
"The cluster state is already in CLUSTER_SCALE_OUT_ROLLBACK, does not need to rollback again.";
|
||||
ERROR_STATUS(status, RequestProcessResultCode::kSystemError, message);
|
||||
resp->ErrorResponse(HTTP_BADREQUEST, status);
|
||||
return;
|
||||
}
|
||||
|
||||
if (!QueryNodeScaleState(resp)) {
|
||||
return;
|
||||
}
|
||||
|
||||
ClusterConfig &clusterConfig = PSContext::instance()->cluster_config();
|
||||
// set the last worker num and last server num and start cluster scale out rollback
|
||||
node_manager_.set_worker_num(clusterConfig.initial_worker_num);
|
||||
node_manager_.set_server_num(clusterConfig.initial_server_num);
|
||||
node_manager_.set_total_node_num(clusterConfig.initial_total_node_num);
|
||||
|
@ -837,17 +850,17 @@ void SchedulerNode::ProcessScaleoutRollback(const std::shared_ptr<HttpMessageHan
|
|||
MS_LOG(INFO) << "After scale out rollback, the last worker num:" << clusterConfig.initial_worker_num
|
||||
<< ", the last server num:" << clusterConfig.initial_server_num;
|
||||
|
||||
node_manager_.UpdateClusterState(ClusterState::CLUSTER_SCALE_OUT_ROLLBACK);
|
||||
auto node_infos = node_manager_.nodes_info();
|
||||
node_manager_.ResetMetadata();
|
||||
for (const auto &kvs : node_infos) {
|
||||
auto client = GetOrCreateClient(kvs.second);
|
||||
MS_EXCEPTION_IF_NULL(client);
|
||||
MS_EXCEPTION_IF_NULL(leader_scaler_);
|
||||
leader_scaler_->ScaleOutAsync(client, node_manager_);
|
||||
leader_scaler_->ScaleOutRollbackAsync(client, node_manager_);
|
||||
}
|
||||
|
||||
MS_LOG(INFO) << "Scheduler send scale out rollback successful.";
|
||||
node_manager_.UpdateClusterState(ClusterState::CLUSTER_SCALE_OUT);
|
||||
nlohmann::json js;
|
||||
js["message"] = "Cluster scale out rollback success.";
|
||||
js["code"] = kSuccessCode;
|
||||
|
@ -858,6 +871,57 @@ void SchedulerNode::ProcessScaleoutRollback(const std::shared_ptr<HttpMessageHan
|
|||
resp->SendResponse();
|
||||
}
|
||||
|
||||
bool SchedulerNode::QueryNodeScaleState(const std::shared_ptr<HttpMessageHandler> &resp) {
|
||||
ClusterConfig &clusterConfig = PSContext::instance()->cluster_config();
|
||||
uint64_t request_id = AddMessageTrack(clusterConfig.initial_server_num);
|
||||
std::unordered_map<uint32_t, VectorPtr> outputs;
|
||||
|
||||
set_message_callback(request_id, [&]() {
|
||||
receive_messages_mutex_.lock();
|
||||
outputs = receive_messages_[request_id];
|
||||
(void)receive_messages_.erase(request_id);
|
||||
receive_messages_mutex_.unlock();
|
||||
});
|
||||
|
||||
auto node_infos = node_manager_.nodes_info();
|
||||
for (const auto &kvs : node_infos) {
|
||||
if (kvs.second.node_role_ == NodeRole::SERVER) {
|
||||
auto client = GetOrCreateClient(kvs.second);
|
||||
MS_EXCEPTION_IF_NULL(client);
|
||||
MS_EXCEPTION_IF_NULL(instance_manager_);
|
||||
instance_manager_->QueryNodeScaleState(client, node_manager_, request_id, node_info_);
|
||||
}
|
||||
}
|
||||
|
||||
bool res = Wait(request_id);
|
||||
if (!res) {
|
||||
std::string message = "The query node scale state is timeout.";
|
||||
RequestProcessResult result(RequestProcessResultCode::kSystemError, message);
|
||||
MS_LOG(WARNING) << message;
|
||||
resp->ErrorResponse(HTTP_BADREQUEST, result);
|
||||
return false;
|
||||
}
|
||||
|
||||
for (const auto &output : outputs) {
|
||||
std::string data = std::string(reinterpret_cast<char *>(output.second->data()), output.second->size());
|
||||
nlohmann::json dataJson = nlohmann::json::parse(data);
|
||||
if (dataJson["node_scale_state"] != "kWaiting") {
|
||||
res = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (!res) {
|
||||
std::string message =
|
||||
"Cluster servers are ready for scaling rollback. Please process server scaling rollback later.";
|
||||
RequestProcessResult result(RequestProcessResultCode::kSystemError, message);
|
||||
MS_LOG(WARNING) << message;
|
||||
resp->ErrorResponse(HTTP_BADREQUEST, result);
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
void SchedulerNode::ProcessScaleOut(const std::shared_ptr<HttpMessageHandler> &resp) {
|
||||
MS_EXCEPTION_IF_NULL(resp);
|
||||
RequestProcessResult status(RequestProcessResultCode::kSuccess);
|
||||
|
@ -1386,7 +1450,7 @@ void SchedulerNode::StartRestfulServer(const std::string &address, std::uint16_t
|
|||
callbacks_["/disableFLS"] = disable_fls;
|
||||
(void)http_server_->RegisterRoute("/disableFLS", &callbacks_["/disableFLS"]);
|
||||
|
||||
OnRequestReceive scale_out_rollback = std::bind(&SchedulerNode::ProcessScaleoutRollback, this, std::placeholders::_1);
|
||||
OnRequestReceive scale_out_rollback = std::bind(&SchedulerNode::ProcessScaleOutRollback, this, std::placeholders::_1);
|
||||
callbacks_["/scaleoutRollback"] = scale_out_rollback;
|
||||
(void)http_server_->RegisterRoute("/scaleoutRollback", &callbacks_["/scaleoutRollback"]);
|
||||
|
||||
|
|
|
@ -180,7 +180,9 @@ class BACKEND_EXPORT SchedulerNode : public Node {
|
|||
|
||||
// Handle the scale out rollback http request, then delegate to the leader scaler to
|
||||
// process scale out rollback asynchronously.
|
||||
void ProcessScaleoutRollback(const std::shared_ptr<HttpMessageHandler> &resp);
|
||||
void ProcessScaleOutRollback(const std::shared_ptr<HttpMessageHandler> &resp);
|
||||
|
||||
bool QueryNodeScaleState(const std::shared_ptr<HttpMessageHandler> &resp);
|
||||
|
||||
// check whether the cluster is in the ready state.
|
||||
RequestProcessResult CheckIfClusterReady();
|
||||
|
|
|
@ -74,7 +74,7 @@ metrics_file_path = args.metrics_file_path
|
|||
|
||||
headers = {'Content-Type': 'application/json'}
|
||||
session = requests.Session()
|
||||
base_url = http_type + "://" + ip + ":" + str(port) + "/"
|
||||
BASE_URL = '{}://{}:{}/'.format(http_type, ip, str(port))
|
||||
|
||||
|
||||
def call_scale():
|
||||
|
@ -118,7 +118,7 @@ def call_scaleout(scale_out_server_num, scale_out_worker_num=0):
|
|||
"""
|
||||
call scaleout
|
||||
"""
|
||||
url = base_url + Restful.SCALE_OUT.value
|
||||
url = BASE_URL + Restful.SCALE_OUT.value
|
||||
data = {"server_num": scale_out_server_num, "worker_num": scale_out_worker_num}
|
||||
res = session.post(url, headers=headers, verify=False, data=json.dumps(data))
|
||||
res_json = json.loads(res.text)
|
||||
|
@ -133,7 +133,7 @@ def call_scaleout_rollback():
|
|||
"""
|
||||
call scaleout rollback
|
||||
"""
|
||||
url = base_url + Restful.SCALE_OUT_ROLLBACK.value
|
||||
url = BASE_URL + Restful.SCALE_OUT_ROLLBACK.value
|
||||
res = session.get(url, verify=False)
|
||||
res_json = json.loads(res.text)
|
||||
if res_json["code"] == Status.FAILED.value:
|
||||
|
@ -148,7 +148,7 @@ def call_scalein(scale_in_node_ids):
|
|||
if not scale_in_node_ids:
|
||||
return process_self_define_json(Status.FAILED.value, "error. node ids is empty.")
|
||||
|
||||
url = base_url + Restful.SCALE_IN.value
|
||||
url = BASE_URL + Restful.SCALE_IN.value
|
||||
data = {"node_ids": scale_in_node_ids}
|
||||
res = session.post(url, headers=headers, verify=False, data=json.dumps(data))
|
||||
res_json = json.loads(res.text)
|
||||
|
@ -162,7 +162,7 @@ def call_nodes():
|
|||
"""
|
||||
get nodes info
|
||||
"""
|
||||
url = base_url + Restful.NODES.value
|
||||
url = BASE_URL + Restful.NODES.value
|
||||
res = session.get(url, verify=False)
|
||||
res_json = json.loads(res.text)
|
||||
if res_json["code"] == Status.FAILED.value:
|
||||
|
@ -224,7 +224,7 @@ def call_new_instance():
|
|||
instance_param_list = instance_param.split(sep=",")
|
||||
instance_param_json_obj = {}
|
||||
|
||||
url = base_url + Restful.NEW_INSTANCE.value
|
||||
url = BASE_URL + Restful.NEW_INSTANCE.value
|
||||
for cur in instance_param_list:
|
||||
pair = cur.split(sep="=")
|
||||
instance_param_json_obj[pair[0]] = float(pair[1])
|
||||
|
@ -241,7 +241,7 @@ def call_query_instance():
|
|||
"""
|
||||
query cluster instance
|
||||
"""
|
||||
url = base_url + Restful.QUERY_INSTANCE.value
|
||||
url = BASE_URL + Restful.QUERY_INSTANCE.value
|
||||
res = session.post(url, verify=False)
|
||||
res_json = json.loads(res.text)
|
||||
if res_json["code"] == Status.FAILED.value:
|
||||
|
@ -253,7 +253,7 @@ def call_enable_fls():
|
|||
"""
|
||||
enable cluster fls
|
||||
"""
|
||||
url = base_url + Restful.ENABLE_FLS.value
|
||||
url = BASE_URL + Restful.ENABLE_FLS.value
|
||||
res = session.post(url, verify=False)
|
||||
res_json = json.loads(res.text)
|
||||
if res_json["code"] == Status.FAILED.value:
|
||||
|
@ -265,7 +265,7 @@ def call_disable_fls():
|
|||
"""
|
||||
disable cluster fls
|
||||
"""
|
||||
url = base_url + Restful.DISABLE_FLS.value
|
||||
url = BASE_URL + Restful.DISABLE_FLS.value
|
||||
res = session.post(url, verify=False)
|
||||
res_json = json.loads(res.text)
|
||||
if res_json["code"] == Status.FAILED.value:
|
||||
|
@ -277,7 +277,7 @@ def call_state():
|
|||
"""
|
||||
get cluster state
|
||||
"""
|
||||
url = base_url + Restful.STATE.value
|
||||
url = BASE_URL + Restful.STATE.value
|
||||
res = session.get(url, verify=False)
|
||||
res_json = json.loads(res.text)
|
||||
if res_json["code"] == Status.FAILED.value:
|
||||
|
|
Loading…
Reference in New Issue