diff --git a/mindspore/ccsrc/fl/server/iteration.cc b/mindspore/ccsrc/fl/server/iteration.cc index ec26f1578bb..0fe9577df8d 100644 --- a/mindspore/ccsrc/fl/server/iteration.cc +++ b/mindspore/ccsrc/fl/server/iteration.cc @@ -122,6 +122,7 @@ void Iteration::MoveToNextIteration(bool is_last_iter_valid, const std::string & MS_ERROR_IF_NULL_WO_RET_VAL(server_node_); if (server_node_->rank_id() == kLeaderServerRank) { + std::unique_lock lock(iter_move_mtx_); if (!BroadcastPrepareForNextIterRequest(iteration_num_, is_last_iter_valid, reason)) { MS_LOG(ERROR) << "Broadcast prepare for next iteration request failed."; return; @@ -432,7 +433,6 @@ void Iteration::HandleNotifyLeaderMoveToNextIterRequest(const std::shared_ptr lock(iter_move_mtx_); NotifyLeaderMoveToNextIterRequest notify_leader_to_next_iter_req; (void)notify_leader_to_next_iter_req.ParseFromArray(message->data(), SizeToInt(message->len())); const auto &rank = notify_leader_to_next_iter_req.rank(); @@ -446,6 +446,7 @@ void Iteration::HandleNotifyLeaderMoveToNextIterRequest(const std::shared_ptr lock(iter_move_mtx_); if (!BroadcastPrepareForNextIterRequest(iter_num, is_last_iter_valid, reason)) { MS_LOG(ERROR) << "Broadcast prepare for next iteration request failed."; return; @@ -717,10 +718,6 @@ void Iteration::EndLastIter() { instance_state_ = InstanceState::kFinish; } - std::unique_lock lock(pinned_mtx_); - pinned_iter_num_ = 0; - lock.unlock(); - SetIterationEnd(); if (!SummarizeIteration()) { MS_LOG(WARNING) << "Summarizing iteration data failed."; diff --git a/mindspore/ccsrc/fl/server/server.cc b/mindspore/ccsrc/fl/server/server.cc index 9a37acff9dd..cb38e11cc3e 100644 --- a/mindspore/ccsrc/fl/server/server.cc +++ b/mindspore/ccsrc/fl/server/server.cc @@ -153,6 +153,7 @@ void Server::InitServerContext() { void Server::InitCluster() { server_node_ = std::make_shared(); MS_EXCEPTION_IF_NULL(server_node_); + server_node_->SetCancelSafeModeCallBack([this]() -> void { CancelSafeMode(); }); task_executor_ = std::make_shared(kExecutorThreadPoolSize); MS_EXCEPTION_IF_NULL(task_executor_); if (!InitCommunicatorWithServer()) { diff --git a/mindspore/ccsrc/fl/worker/fl_worker.cc b/mindspore/ccsrc/fl/worker/fl_worker.cc index d61914c4d15..74da23497cd 100644 --- a/mindspore/ccsrc/fl/worker/fl_worker.cc +++ b/mindspore/ccsrc/fl/worker/fl_worker.cc @@ -52,6 +52,7 @@ void FLWorker::Run() { worker_node_ = std::make_shared(); MS_EXCEPTION_IF_NULL(worker_node_); + worker_node_->SetCancelSafeModeCallBack([this]() -> void { safemode_ = false; }); worker_node_->RegisterEventCallback(ps::core::ClusterEvent::SCHEDULER_TIMEOUT, [this]() { Finalize(); running_ = false; diff --git a/mindspore/ccsrc/ps/core/abstract_node.cc b/mindspore/ccsrc/ps/core/abstract_node.cc index 2f7d063dba4..3ffbe30647a 100644 --- a/mindspore/ccsrc/ps/core/abstract_node.cc +++ b/mindspore/ccsrc/ps/core/abstract_node.cc @@ -629,10 +629,17 @@ void AbstractNode::StartHeartbeatTimer(const std::shared_ptr &client) << ", the node id:" << node_info_.node_id_ << ", the node rank id:" << node_info_.rank_id_ << " begin send heartbeat to the scheduler!"; heart_beat_thread_ = std::make_unique([&]() { + uint32_t connect_interval = PSContext::instance()->cluster_config().connect_interval; + uint32_t heartbeat_interval = PSContext::instance()->cluster_config().heartbeat_interval * 1000; + uint32_t reconnect_interval = 0; + if (heartbeat_interval > connect_interval) { + MS_LOG(WARNING) << "heartbeat_interval [" << heartbeat_interval << "] is larger than connect_interval [" + << connect_interval << "], reset connect_interval to " << heartbeat_interval; + } while (!is_finish_.load()) { if (!Heartbeat(client)) { MS_LOG(WARNING) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_) - << ", the node id is:" << node_info_.node_id_ << " Send heartbeat timeout!"; + << ", the node id is:" << node_info_.node_id_ << " Send heartbeat failed!"; if (CheckSchedulerTimeout()) { MS_LOG(WARNING) << "Scheduler is Timeout, please recovery."; } @@ -640,7 +647,17 @@ void AbstractNode::StartHeartbeatTimer(const std::shared_ptr &client) UpdateSchedulerTime(); } - std::this_thread::sleep_for(std::chrono::seconds(PSContext::instance()->cluster_config().heartbeat_interval)); + if (!is_already_finished_ && (client->connection_status() == -1)) { + if (reconnect_interval > connect_interval) { + MS_LOG(WARNING) << "Connection to Scheduler is disconnected, try to reconnect."; + reconnect_interval = 0; + ConnectToScheduler(); + } else { + reconnect_interval += heartbeat_interval; + } + } + + std::this_thread::sleep_for(std::chrono::milliseconds(heartbeat_interval)); } }); MS_EXCEPTION_IF_NULL(heart_beat_thread_); @@ -649,6 +666,9 @@ void AbstractNode::StartHeartbeatTimer(const std::shared_ptr &client) bool AbstractNode::Heartbeat(const std::shared_ptr &client) { MS_EXCEPTION_IF_NULL(client); + if (client->connection_status() != 1) { + return false; + } auto meta = std::make_shared(); MS_EXCEPTION_IF_NULL(meta); meta->set_cmd(NodeCommand::HEARTBEAT); @@ -693,7 +713,13 @@ void AbstractNode::ProcessHeartbeatResp(const std::shared_ptr &meta HeartbeatRespMessage heartbeat_resp_message; CHECK_RETURN_TYPE(heartbeat_resp_message.ParseFromArray(data, SizeToInt(size))); - UpdateClusterState(heartbeat_resp_message.cluster_state()); + if (heartbeat_resp_message.cluster_state() != current_cluster_state_ && + current_cluster_state_ != ClusterState::CLUSTER_SCALE_IN && + current_cluster_state_ != ClusterState::CLUSTER_SCALE_OUT) { + MS_LOG(INFO) << "cluster change state from:" << CommUtil::ClusterStateToString(current_cluster_state_) << " to " + << CommUtil::ClusterStateToString(heartbeat_resp_message.cluster_state()); + UpdateClusterState(heartbeat_resp_message.cluster_state()); + } MS_LOG(DEBUG) << "The current cluster state from heartbeat:" << CommUtil::ClusterStateToString(current_cluster_state_); @@ -811,7 +837,6 @@ void AbstractNode::ProcessSendMetadata(const std::shared_ptr &con send_meta_message.ParseFromArray(data, SizeToInt(size)); worker_num_ = send_meta_message.worker_num(); server_num_ = send_meta_message.server_num(); - if (send_meta_message.rank_id() < 0) { MS_LOG(EXCEPTION) << "The rank id is wrong."; } @@ -964,10 +989,6 @@ void AbstractNode::ProcessSchedulerRecovery(const std::shared_ptr MS_EXCEPTION_IF_NULL(conn); MS_EXCEPTION_IF_NULL(meta); MS_EXCEPTION_IF_NULL(data); - if (is_connected_to_scheduler_.load()) { - MS_LOG(WARNING) << "This node has been connected to scheduler."; - return; - } SendMetadataMessage scheduler_recovery_message; (void)scheduler_recovery_message.ParseFromArray(data, SizeToInt(size)); worker_num_ = scheduler_recovery_message.worker_num(); @@ -982,7 +1003,9 @@ void AbstractNode::ProcessSchedulerRecovery(const std::shared_ptr } MS_LOG(INFO) << "[Scheduler Recovery]: Server response message success!."; - if (!InitClientToScheduler()) { + ConnectToScheduler(); + bool connected = client_to_scheduler_->WaitConnected(); + if (!connected) { MS_LOG(WARNING) << "[Scheduler Recovery]: Server node connect to scheduler timedout!"; } @@ -991,6 +1014,12 @@ void AbstractNode::ProcessSchedulerRecovery(const std::shared_ptr connected_nodes_.clear(); MS_LOG(INFO) << "[Scheduler Recovery]: This node connect to scheduler successful!"; + if (cancelSafeModeFn_ && (current_cluster_state_ == ClusterState::CLUSTER_SCALE_IN || + current_cluster_state_ == ClusterState::CLUSTER_SCALE_OUT)) { + MS_LOG(INFO) << "[Scheduler Recovery]: Cancel Safe mode for " << kClusterState.at(current_cluster_state_); + cancelSafeModeFn_(); + } + UpdateClusterState(ClusterState::CLUSTER_SCHEDULER_RECOVERY); is_ready_ = false; } @@ -1067,27 +1096,23 @@ bool AbstractNode::InitClientToScheduler() { MsException::Instance().SetException(); } }); + ConnectToScheduler(); + StartHeartbeatTimer(client_to_scheduler_); + MS_LOG(INFO) << "Start heartbeat timer!"; + + bool wait_res = client_to_scheduler_->WaitConnected(); + if (!wait_res) { + is_ready_ = true; + } + return wait_res; +} +void AbstractNode::ConnectToScheduler() { client_to_scheduler_->Init(); client_to_scheduler_thread_ = std::make_unique([this]() { MS_LOG(INFO) << "The node start a tcp client!"; client_to_scheduler_->Start(); }); client_to_scheduler_thread_->detach(); - - client_to_scheduler_->set_connected_callback([&]() { is_connected_to_scheduler_ = true; }); - - client_to_scheduler_->set_disconnected_callback([&]() { - is_connected_to_scheduler_ = false; - std::this_thread::sleep_for(std::chrono::milliseconds(PSContext::instance()->cluster_config().connect_interval)); - if (is_ready_.load() == false) { - client_to_scheduler_->Init(); - } - }); - bool wait_res = client_to_scheduler_->WaitConnected(); - if (!wait_res) { - is_ready_ = true; - } - return wait_res; } const std::shared_ptr &AbstractNode::GetOrCreateTcpClient(const uint32_t &rank_id, const NodeRole &role) { diff --git a/mindspore/ccsrc/ps/core/abstract_node.h b/mindspore/ccsrc/ps/core/abstract_node.h index 95bf95fdf14..849fb321935 100644 --- a/mindspore/ccsrc/ps/core/abstract_node.h +++ b/mindspore/ccsrc/ps/core/abstract_node.h @@ -69,6 +69,7 @@ class BACKEND_EXPORT AbstractNode : public Node { using VectorPtr = std::shared_ptr>; using RequestHandler = std::function &conn, const std::shared_ptr &meta, const void *data, size_t size)>; + using CancelSafeModeFn = std::function; bool Broadcast(const NodeRole &node_role, const std::string &message, int command, const uint32_t &timeout = kCommTimeoutInSeconds); @@ -155,6 +156,8 @@ class BACKEND_EXPORT AbstractNode : public Node { void SetIterationResult(size_t last_iteration, bool is_iteration_valid); bool HasIterationFailed(uint32_t iteration_num) const; + // register cancel SafeMode function to node + void SetCancelSafeModeCallBack(const CancelSafeModeFn &fn) { cancelSafeModeFn_ = fn; } protected: virtual void Register(const std::shared_ptr &client); @@ -248,6 +251,7 @@ class BACKEND_EXPORT AbstractNode : public Node { bool FlCollectiveWaitInner(const CollectiveMessageMeta &expect_meta, VectorPtr *output, const uint32_t &timeout); void OnRecvCollectiveData(const MessageMeta &message_meta, const VectorPtr &data); + void ConnectToScheduler(); std::unique_ptr heart_beat_thread_; std::unique_ptr client_to_scheduler_thread_; @@ -325,6 +329,7 @@ class BACKEND_EXPORT AbstractNode : public Node { size_t failed_iteration_num_ = 0; bool iteration_failed_ = false; + CancelSafeModeFn cancelSafeModeFn_; }; } // namespace core } // namespace ps diff --git a/mindspore/ccsrc/ps/core/comm_util.cc b/mindspore/ccsrc/ps/core/comm_util.cc index 0e46a19e45d..afeccd9346b 100644 --- a/mindspore/ccsrc/ps/core/comm_util.cc +++ b/mindspore/ccsrc/ps/core/comm_util.cc @@ -260,7 +260,7 @@ std::string CommUtil::ClusterStateToString(const ClusterState &state) { if (state < SizeToInt(kClusterState.size())) { return kClusterState.at(state); } else { - return ""; + return std::to_string(state); } } diff --git a/mindspore/ccsrc/ps/core/comm_util.h b/mindspore/ccsrc/ps/core/comm_util.h index 8d37a924e77..ebd155ccc09 100644 --- a/mindspore/ccsrc/ps/core/comm_util.h +++ b/mindspore/ccsrc/ps/core/comm_util.h @@ -86,12 +86,16 @@ constexpr char kLibeventLogPrefix[] = "[libevent log]:"; // Find the corresponding string style of cluster state through the subscript of the enum:ClusterState const std::vector kClusterState = { - "ClUSTER_STARTING", // Initialization state when the cluster is just started. - "CLUSTER_READY", // The state after all nodes are successfully registered. - "CLUSTER_EXIT", // The state after the cluster exits successfully. - "NODE_TIMEOUT", // When a node has a heartbeat timeout - "CLUSTER_SCALE_OUT", // When the cluster is scale out. - "CLUSTER_SCALE_IN" // When the cluster is scale in. + "ClUSTER_STARTING", // Initialization state when the cluster is just started. + "CLUSTER_READY", // The state after all nodes are successfully registered. + "CLUSTER_EXIT", // The state after the cluster exits successfully. + "NODE_TIMEOUT", // When a node has a heartbeat timeout + "CLUSTER_SCALE_OUT", // When the cluster is scale out. + "CLUSTER_SCALE_IN", // When the cluster is scale in. + "CLUSTER_NEW_INSTANCE", // When the cluster is doing NEW_INSTANCE. + "CLUSTER_ENABLE_FLS", // When the cluster is doing ENABLE_FLS. + "CLUSTER_DISABLE_FLS", // When the cluster is doing DISABLE_FLS. + "CLUSTER_SCHEDULER_RECOVERY" // When the cluster is doing SCHEDULER_RECOVERY. }; class CommUtil { diff --git a/mindspore/ccsrc/ps/core/communicator/tcp_client.cc b/mindspore/ccsrc/ps/core/communicator/tcp_client.cc index 4e402e08b22..9f85f15ab9b 100644 --- a/mindspore/ccsrc/ps/core/communicator/tcp_client.cc +++ b/mindspore/ccsrc/ps/core/communicator/tcp_client.cc @@ -43,8 +43,7 @@ TcpClient::TcpClient(const std::string &address, std::uint16_t port, NodeRole pe server_address_(std::move(address)), server_port_(port), peer_role_(peer_role), - disconnected_(false), - connected_(false) { + connection_status_(-1) { message_handler_.SetCallback( [this](const std::shared_ptr &meta, const Protos &protos, const void *data, size_t size) { if (message_callback_) { @@ -86,16 +85,20 @@ std::string TcpClient::PeerRoleName() const { bool TcpClient::WaitConnected(const uint32_t &connected_timeout) { std::unique_lock lock(connection_mutex_); bool res = connection_cond_.wait_for(lock, std::chrono::seconds(connected_timeout), - [this] { return this->connected_.load(); }); + [this] { return this->connection_status_ == 1; }); return res; } void TcpClient::Init() { - if (disconnected_) { + std::lock_guard lock(connection_mutex_); + if (connection_status_ != -1) { return; } - - std::lock_guard lock(connection_mutex_); + connection_status_ = 0; + if (buffer_event_) { + bufferevent_free(buffer_event_); + buffer_event_ = nullptr; + } if (!CommUtil::CheckIp(server_address_)) { MS_LOG(EXCEPTION) << "The tcp client ip:" << server_address_ << " is illegal!"; } @@ -117,10 +120,10 @@ void TcpClient::Init() { sin.sin_addr.s_addr = inet_addr(server_address_.c_str()); sin.sin_port = htons(server_port_); - if (!PSContext::instance()->enable_ssl() && buffer_event_ == nullptr) { + if (!PSContext::instance()->enable_ssl()) { MS_LOG(INFO) << "SSL is disable."; buffer_event_ = bufferevent_socket_new(event_base_, -1, BEV_OPT_CLOSE_ON_FREE | BEV_OPT_THREADSAFE); - } else if (buffer_event_ == nullptr) { + } else { if (!EstablishSSL()) { MS_LOG(WARNING) << "Establish SSL failed."; return; @@ -228,7 +231,7 @@ void TcpClient::TimerCallback(evutil_socket_t, int16_t, void *arg) { void TcpClient::NotifyConnected() { MS_LOG(INFO) << "Client connected to the server! Peer " << PeerRoleName() << " ip: " << server_address_ << ", port: " << server_port_; - connected_ = true; + connection_status_ = 1; connection_cond_.notify_all(); } @@ -269,12 +272,14 @@ void TcpClient::EventCallbackInner(struct bufferevent *bev, std::int16_t events) } else if (events & BEV_EVENT_ERROR) { MS_LOG(WARNING) << "The client will retry to connect to the server! Peer " << PeerRoleName() << " ip: " << server_address_ << ", port: " << server_port_; + connection_status_ = -1; if (disconnected_callback_) { disconnected_callback_(); } } else if (events & BEV_EVENT_EOF) { MS_LOG(WARNING) << "Client connected end of file! Peer " << PeerRoleName() << " ip: " << server_address_ << ", port: " << server_port_; + connection_status_ = -1; if (disconnected_callback_) { disconnected_callback_(); } diff --git a/mindspore/ccsrc/ps/core/communicator/tcp_client.h b/mindspore/ccsrc/ps/core/communicator/tcp_client.h index 8484bb60cf3..48e325f9fd8 100644 --- a/mindspore/ccsrc/ps/core/communicator/tcp_client.h +++ b/mindspore/ccsrc/ps/core/communicator/tcp_client.h @@ -72,7 +72,7 @@ class TcpClient { bool SendMessage(const std::shared_ptr &meta, const Protos &protos, const void *data, size_t size); void set_timer_callback(const OnTimer &timer); const event_base &eventbase() const; - void set_disconnected() { disconnected_ = true; } + int connection_status() { return connection_status_; } protected: static void SetTcpNoDelay(const evutil_socket_t &fd); @@ -110,8 +110,8 @@ class TcpClient { std::string server_address_; std::uint16_t server_port_; NodeRole peer_role_; - std::atomic disconnected_; - std::atomic connected_; + // -1:disconnected, 0:connecting, 1:connected + std::atomic connection_status_; // The Configuration file Configuration *config_; }; diff --git a/mindspore/ccsrc/ps/core/server_node.cc b/mindspore/ccsrc/ps/core/server_node.cc index c9b7a2f3f98..3b85da2f956 100644 --- a/mindspore/ccsrc/ps/core/server_node.cc +++ b/mindspore/ccsrc/ps/core/server_node.cc @@ -27,15 +27,12 @@ bool ServerNode::Start(const uint32_t &timeout) { MS_LOG(INFO) << "[Server start]: 4. The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_) << " the node id:" << node_info_.node_id_ << " successfully registered to the scheduler!"; - StartHeartbeatTimer(client_to_scheduler_); - MS_LOG(INFO) << "[Server start]: 5. Server start heartbeat timer!"; - if (!WaitForStart(timeout)) { MS_LOG(ERROR) << "Start server node timeout!"; return false; } MsException::Instance().CheckException(); - MS_LOG(INFO) << "[Server start]: 6. Successfully start server node!"; + MS_LOG(INFO) << "[Server start]: 5. Successfully start server node!"; return true; } @@ -94,11 +91,10 @@ bool ServerNode::Finish(const uint32_t &timeout) { return true; } - if (!is_connected_to_scheduler_) { + if (client_to_scheduler_->connection_status() != 1) { MS_LOG(INFO) << "[Server finish]: Not connect to scheduler, no need to disconnect!"; return true; } - client_to_scheduler_->set_disconnected(); MS_LOG(INFO) << "[Server finish]: 1. Begin to finish server node!"; bool res = Disconnect(client_to_scheduler_, timeout); diff --git a/mindspore/ccsrc/ps/core/worker_node.cc b/mindspore/ccsrc/ps/core/worker_node.cc index cbe82c12876..9bbf1faffdf 100644 --- a/mindspore/ccsrc/ps/core/worker_node.cc +++ b/mindspore/ccsrc/ps/core/worker_node.cc @@ -27,15 +27,12 @@ bool WorkerNode::Start(const uint32_t &timeout) { MS_LOG(INFO) << "[Worker start]: 4. The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_) << " the node id:" << node_info_.node_id_ << " successfully registered to the scheduler!"; - StartHeartbeatTimer(client_to_scheduler_); - MS_LOG(INFO) << "[Worker start]: 5. Worker start heartbeat timer!"; - if (!WaitForStart(timeout)) { MS_LOG(ERROR) << "Start Worker node timeout!"; return false; } MsException::Instance().CheckException(); - MS_LOG(INFO) << "[Worker start]: 6. Successfully start worker node!"; + MS_LOG(INFO) << "[Worker start]: 5. Successfully start worker node!"; return true; } @@ -96,11 +93,10 @@ bool WorkerNode::Finish(const uint32_t &timeout) { return true; } - if (!is_connected_to_scheduler_) { + if (client_to_scheduler_->connection_status() != 1) { MS_LOG(INFO) << "[Worker finish]: Not connect to scheduler, no need to disconnect!"; return true; } - client_to_scheduler_->set_disconnected(); bool res = Disconnect(client_to_scheduler_, timeout); if (res) {