diff --git a/mindspore/ccsrc/ps/CMakeLists.txt b/mindspore/ccsrc/ps/CMakeLists.txt index 6d05481207e..dfddf14bc63 100644 --- a/mindspore/ccsrc/ps/CMakeLists.txt +++ b/mindspore/ccsrc/ps/CMakeLists.txt @@ -19,6 +19,7 @@ if (NOT (ENABLE_CPU AND (ENABLE_D OR ENABLE_GPU))) list(REMOVE_ITEM _PS_SRC_FILES "core/worker_node.cc") list(REMOVE_ITEM _PS_SRC_FILES "core/server_node.cc") list(REMOVE_ITEM _PS_SRC_FILES "core/abstract_node.cc") + list(REMOVE_ITEM _PS_SRC_FILES "core/scheduler_node.cc") endif () if (NOT ENABLE_D) diff --git a/mindspore/ccsrc/ps/core/abstract_node.cc b/mindspore/ccsrc/ps/core/abstract_node.cc index b1eaf9b4d09..ea26b10edf7 100644 --- a/mindspore/ccsrc/ps/core/abstract_node.cc +++ b/mindspore/ccsrc/ps/core/abstract_node.cc @@ -74,30 +74,161 @@ void AbstractNode::set_event_callback(const OnNodeEventMessage &on_node_event_me on_node_event_message_ = on_node_event_message; } -void AbstractNode::Heartbeat(const std::shared_ptr &client) { +bool AbstractNode::Send(const enum NodeRole &node_role, const uint32_t &rank_id, const std::string &message, + const uint32_t &timeout) { + if (!CommUtil::ValidateRankId(node_role, rank_id)) { + MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!"; + } + + MessageMeta message_meta; + message_meta.set_cmd(NodeCommand::SEND_DATA); + + CommMessage comm_message; + *comm_message.mutable_pb_meta() = {message_meta}; + comm_message.set_data(message); + auto client = GetOrCreateTcpClient(rank_id); + return SendMessageSync(client, comm_message); +} + +bool AbstractNode::Send(const NodeRole &node_role, const std::vector &rank_ids, + const std::vector &data, const uint32_t &timeout) { + uint64_t request_id = ++next_request_id_; + message_tracker_[request_id] = std::make_pair(data.size(), 0); + + if (rank_ids.size() != data.size()) { + MS_LOG(EXCEPTION) << "The number of rank ids is not equal to the number of data!"; + } + for (size_t it = 0; it < rank_ids.size(); ++it) { + if (!CommUtil::ValidateRankId(node_role, rank_ids.at(it))) { + MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!"; + } + + MessageMeta message_meta; + message_meta.set_cmd(NodeCommand::SEND_DATA); + message_meta.set_request_id(request_id); + + CommMessage comm_message; + *comm_message.mutable_pb_meta() = {message_meta}; + comm_message.set_data(data.at(it)); + + auto client = GetOrCreateTcpClient(rank_ids.at(it)); + client->SendMessage(comm_message); + } + return Wait(request_id, timeout); +} + +bool AbstractNode::Send(const enum NodeRole &node_role, const uint32_t &rank_id, const std::string &message, + CommMessage *comm_message_resp, const uint32_t &timeout) { + MS_EXCEPTION_IF_NULL(comm_message_resp); + if (!CommUtil::ValidateRankId(node_role, rank_id)) { + MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!"; + } + + uint64_t request_id = ++next_request_id_; + message_tracker_[request_id] = std::make_pair(1, 0); + set_message_callback(request_id, [&]() { + receive_messages_mutex_.lock(); + auto res = receive_messages_[request_id]; + *comm_message_resp = res[rank_id]; + receive_messages_.erase(request_id); + receive_messages_mutex_.unlock(); + }); + + MessageMeta message_meta; + message_meta.set_cmd(NodeCommand::SEND_DATA); + message_meta.set_request_id(request_id); + message_meta.set_rank_id(node_info_.rank_id_); + message_meta.set_role(node_info_.node_role_); + + CommMessage comm_message; + *comm_message.mutable_pb_meta() = {message_meta}; + comm_message.set_data(message); + auto client = GetOrCreateTcpClient(rank_id); + client->SendMessage(comm_message); + return Wait(request_id, timeout); +} + +bool AbstractNode::Send(const NodeRole &node_role, const std::vector &rank_ids, + const std::vector &data, std::vector *comm_message_resp, + const uint32_t &timeout) { + MS_EXCEPTION_IF_NULL(comm_message_resp); + uint64_t request_id = ++next_request_id_; + message_tracker_[request_id] = std::make_pair(data.size(), 0); + + if (rank_ids.size() != data.size() || rank_ids.size() != (*comm_message_resp).size()) { + MS_LOG(EXCEPTION) << "The number of rank ids, data, comm_message_resp should be equal!"; + } + + size_t len = rank_ids.size(); + + set_message_callback(request_id, [&]() { + receive_messages_mutex_.lock(); + auto res = receive_messages_[request_id]; + for (size_t it = 0; it < len; ++it) { + comm_message_resp->at(it) = &res[rank_ids.at(it)]; + } + receive_messages_.erase(request_id); + receive_messages_mutex_.unlock(); + }); + + for (size_t it = 0; it < len; ++it) { + if (!CommUtil::ValidateRankId(node_role, rank_ids.at(it))) { + MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!"; + } + + MessageMeta message_meta; + message_meta.set_cmd(NodeCommand::SEND_DATA); + message_meta.set_request_id(request_id); + + CommMessage comm_message; + *comm_message.mutable_pb_meta() = {message_meta}; + comm_message.set_data(data.at(it)); + + auto client = GetOrCreateTcpClient(rank_ids.at(it)); + client->SendMessage(comm_message); + } + return Wait(request_id, timeout); +} + +bool AbstractNode::Wait(uint64_t request_id, const uint32_t &timeout) { + std::unique_lock lock(message_tracker_mutex_); + bool res = message_tracker_cond_.wait_for(lock, std::chrono::seconds(timeout), [&] { + bool ret = message_tracker_[request_id].first == message_tracker_[request_id].second; + return ret; + }); + message_tracker_.erase(request_id); + return res; +} + +void AbstractNode::StartHeartbeatTimer(const std::shared_ptr &client) { MS_LOG(INFO) << "The node role: " << CommUtil::NodeRoleToString(node_info_.node_role_) << ", the node id:" << node_info_.node_id_ << ", the node rank id:" << node_info_.rank_id_ << " begin send heartbeat to the scheduler!"; heart_beat_thread_ = std::make_unique([&]() { while (!is_finish_.load()) { + Heartbeat(client); std::this_thread::sleep_for(std::chrono::seconds(ClusterConfig::heartbeat_interval())); - MessageMeta meta; - meta.set_cmd(NodeCommand::HEARTBEAT); - - HeartbeatMessage heartbeat_message; - heartbeat_message.set_node_id(node_info_.node_id_); - - CommMessage message; - *message.mutable_pb_meta() = {meta}; - message.set_data(heartbeat_message.SerializeAsString()); - if (!SendMessageSync(client, message)) { - MS_LOG(ERROR) << "The node id:" << node_info_.node_id_ << " Send heartbeat timeout!"; - } } }); heart_beat_thread_->detach(); } +void AbstractNode::Heartbeat(const std::shared_ptr &client, bool is_node_finish) { + MessageMeta meta; + meta.set_cmd(NodeCommand::HEARTBEAT); + + HeartbeatMessage heartbeat_message; + heartbeat_message.set_node_id(node_info_.node_id_); + heartbeat_message.set_is_node_finish(is_node_finish); + + CommMessage message; + *message.mutable_pb_meta() = {meta}; + message.set_data(heartbeat_message.SerializeAsString()); + if (!SendMessageSync(client, message)) { + MS_LOG(ERROR) << "The node id:" << node_info_.node_id_ << " Send heartbeat timeout!"; + } +} + void AbstractNode::ProcessHeartbeatResp(const CommMessage &message) { HeartbeatRespMessage heartbeat_resp_message; heartbeat_resp_message.ParseFromString(message.data()); @@ -106,8 +237,9 @@ void AbstractNode::ProcessHeartbeatResp(const CommMessage &message) { wait_start_cond_.notify_all(); MS_LOG(DEBUG) << "The node id:" << node_info_.node_id_ << " is ready!"; } - is_finish_ = heartbeat_resp_message.is_cluster_finish(); - if (is_finish_.load()) { + if (heartbeat_resp_message.is_cluster_finish()) { + Heartbeat(client_to_scheduler_, true); + is_finish_ = true; wait_finish_cond_.notify_all(); MS_LOG(DEBUG) << "The node id:" << node_info_.node_id_ << " is finish!"; } @@ -115,6 +247,10 @@ void AbstractNode::ProcessHeartbeatResp(const CommMessage &message) { if (is_timeout_ && on_node_event_message_) { is_ready_ = true; wait_start_cond_.notify_all(); + on_node_event_message_(NodeEvent::CLUSTER_TIMEOUT); + } + + if (heartbeat_resp_message.is_node_timeout() && on_node_event_message_) { on_node_event_message_(NodeEvent::NODE_TIMEOUT); } } @@ -207,6 +343,101 @@ bool AbstractNode::InitClientToScheduler() { }); return client_to_scheduler_->WaitConnected(); } + +const std::shared_ptr &AbstractNode::GetOrCreateTcpClient(const int &rank_id) { + std::lock_guard lock(client_mutex_); + if (connected_nodes_.find(rank_id) != connected_nodes_.end()) { + return connected_nodes_[rank_id]; + } else { + if (nodes_address_.find(std::make_pair(NodeRole::SERVER, rank_id)) == nodes_address_.end()) { + MS_LOG(EXCEPTION) << "Worker node Fetch servers failed!"; + } + std::string ip = nodes_address_[std::make_pair(NodeRole::SERVER, rank_id)].first; + uint16_t port = nodes_address_[std::make_pair(NodeRole::SERVER, rank_id)].second; + auto client = std::make_shared(ip, port); + client->SetMessageCallback([&](const TcpClient &client, const CommMessage &message) { + switch (message.pb_meta().cmd()) { + case NodeCommand::SEND_DATA: + ProcessSendDataResp(message); + RunMessageCallback(message.pb_meta().request_id()); + break; + default: + MS_LOG(EXCEPTION) << "The cmd:" << message.pb_meta().cmd() << " is not supported!"; + } + NotifyMessageArrival(message); + }); + client->Init(); + connected_nodes_[rank_id] = client; + return connected_nodes_[rank_id]; + } +} + +bool AbstractNode::SendMessageSync(const std::shared_ptr &client, const CommMessage &message, + const uint32_t &timeout) { + uint64_t request_id = ++next_request_id_; + message_tracker_[request_id] = std::make_pair(1, 0); + const_cast(message).mutable_pb_meta()->set_request_id(request_id); + client->SendMessage(message); + return Wait(request_id, timeout); +} + +void AbstractNode::SendMessageAsync(const std::shared_ptr &client, const CommMessage &message) { + uint64_t request_id = ++next_request_id_; + const_cast(message).mutable_pb_meta()->set_request_id(request_id); + client->SendMessage(message); +} + +void AbstractNode::ProcessSendDataResp(const CommMessage &message) { + std::lock_guard lock(receive_messages_mutex_); + const MessageMeta &message_meta = message.pb_meta(); + const uint32_t &rank_id = message_meta.rank_id(); + const uint64_t request_id = message_meta.request_id(); + auto it = receive_messages_.find(request_id); + if (it != receive_messages_.end()) { + it->second.insert(std::make_pair(rank_id, message)); + } else { + std::unordered_map res; + res.insert(std::make_pair(rank_id, message)); + receive_messages_[request_id] = res; + } +} + +void AbstractNode::RunMessageCallback(const uint64_t &request_id) { + message_callbacks_mutex_.lock(); + // When receiving a message's response, Then compare with the desired number of responses, + // If they are equal, then call the callback function + if (message_tracker_[request_id].first == message_tracker_[request_id].second + 1) { + auto it = message_callbacks_.find(request_id); + if (it != message_callbacks_.end()) { + message_callbacks_mutex_.unlock(); + + if (it->second) { + it->second(); + } + + message_callbacks_mutex_.lock(); + message_callbacks_.erase(it); + } + } + message_callbacks_mutex_.unlock(); +} + +void AbstractNode::set_message_callback(const uint64_t &request_id, const MessageCallback &message_callback) { + if (!message_callback) { + return; + } + std::lock_guard lock(message_callbacks_mutex_); + message_callbacks_[request_id] = message_callback; +} + +void AbstractNode::NotifyMessageArrival(const CommMessage &message) { + std::lock_guard lock(message_tracker_mutex_); + const MessageMeta &message_meta = message.pb_meta(); + uint64_t request_id = message_meta.request_id(); + + message_tracker_[request_id].second++; + message_tracker_cond_.notify_all(); +} } // namespace core } // namespace ps } // namespace mindspore diff --git a/mindspore/ccsrc/ps/core/abstract_node.h b/mindspore/ccsrc/ps/core/abstract_node.h index e1fe6a3d7fb..725508dff38 100644 --- a/mindspore/ccsrc/ps/core/abstract_node.h +++ b/mindspore/ccsrc/ps/core/abstract_node.h @@ -20,6 +20,9 @@ #include #include #include +#include +#include +#include #include "ps/core/node.h" @@ -34,21 +37,60 @@ class AbstractNode : public Node { bool BroadcastToServers(const std::string &message, const uint32_t &timeout = kCommTimeoutInSeconds); void set_event_callback(const OnNodeEventMessage &on_node_event_message); + virtual bool Send(const enum NodeRole &node_role, const uint32_t &rank_id, const std::string &message, + const uint32_t &timeout = kCommTimeoutInSeconds); + virtual bool Send(const NodeRole &node_role, const std::vector &rank_ids, + const std::vector &data, const uint32_t &timeout = kCommTimeoutInSeconds); + virtual bool Send(const enum NodeRole &node_role, const uint32_t &rank_id, const std::string &message, + CommMessage *comm_message_resp, const uint32_t &timeout = kCommTimeoutInSeconds); + virtual bool Send(const NodeRole &node_role, const std::vector &rank_ids, + const std::vector &data, std::vector *comm_message_resp, + const uint32_t &timeout = kCommTimeoutInSeconds); + + bool Wait(uint64_t request_id, const uint32_t &timeout = kCommTimeoutInSeconds); + protected: void Register(const std::shared_ptr &client); void ProcessRegisterResp(const CommMessage &message); - void Heartbeat(const std::shared_ptr &client); + void StartHeartbeatTimer(const std::shared_ptr &client); + void Heartbeat(const std::shared_ptr &client, bool is_node_finish = false); void ProcessHeartbeatResp(const CommMessage &message); void FetchServers(const std::shared_ptr &client); void ProcessFetchServersResp(const CommMessage &message); bool Disconnect(const std::shared_ptr &client, const uint32_t &timeout); bool WaitForDisconnect(const uint32_t &timeout); bool InitClientToScheduler(); + const std::shared_ptr &GetOrCreateTcpClient(const int &rank_id); + bool SendMessageSync(const std::shared_ptr &client, const CommMessage &message, + const uint32_t &timeout = kCommTimeoutInSeconds); + void SendMessageAsync(const std::shared_ptr &client, const CommMessage &message); + void ProcessSendDataResp(const CommMessage &message); + void RunMessageCallback(const uint64_t &request_id); + void set_message_callback(const uint64_t &request_id, const MessageCallback &message_callback); + void NotifyMessageArrival(const CommMessage &message); std::unique_ptr heart_beat_thread_; std::unique_ptr client_to_scheduler_thread_; std::shared_ptr client_to_scheduler_; + OnNodeEventMessage on_node_event_message_; + // the map's key is: , the map's value is: + std::map, std::pair> nodes_address_; + std::mutex client_mutex_; + // the map's key is: rank_id + std::unordered_map> connected_nodes_; + + // the map's key is: request_id, the map's value is: + std::unordered_map> message_tracker_; + std::mutex message_tracker_mutex_; + std::condition_variable message_tracker_cond_; + + // the map's key is: request_id, the map's value is: + std::unordered_map> receive_messages_; + std::mutex receive_messages_mutex_; + // the map's key is: request_id + std::unordered_map message_callbacks_; + std::mutex message_callbacks_mutex_; }; } // namespace core } // namespace ps diff --git a/mindspore/ccsrc/ps/core/node.cc b/mindspore/ccsrc/ps/core/node.cc index 2ae02f6a393..31666eec9df 100644 --- a/mindspore/ccsrc/ps/core/node.cc +++ b/mindspore/ccsrc/ps/core/node.cc @@ -25,131 +25,6 @@ uint32_t Node::rank_id() const { return node_info_.rank_id_; } NodeRole Node::role() const { return node_info_.node_role_; } -bool Node::Wait(uint64_t request_id, const uint32_t &timeout) { - std::unique_lock lock(message_tracker_mutex_); - bool res = message_tracker_cond_.wait_for(lock, std::chrono::seconds(timeout), [&] { - bool ret = message_tracker_[request_id].first == message_tracker_[request_id].second; - return ret; - }); - message_tracker_.erase(request_id); - return res; -} - -bool Node::Send(const enum NodeRole &node_role, const uint32_t &rank_id, const std::string &message, - const uint32_t &timeout) { - if (!CommUtil::ValidateRankId(node_role, rank_id)) { - MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!"; - } - - MessageMeta message_meta; - message_meta.set_cmd(NodeCommand::SEND_DATA); - - CommMessage comm_message; - *comm_message.mutable_pb_meta() = {message_meta}; - comm_message.set_data(message); - auto client = GetOrCreateTcpClient(rank_id); - return SendMessageSync(client, comm_message); -} - -bool Node::Send(const NodeRole &node_role, const std::vector &rank_ids, const std::vector &data, - const uint32_t &timeout) { - uint64_t request_id = ++next_request_id_; - message_tracker_[request_id] = std::make_pair(data.size(), 0); - - if (rank_ids.size() != data.size()) { - MS_LOG(EXCEPTION) << "The number of rank ids is not equal to the number of data!"; - } - for (size_t it = 0; it < rank_ids.size(); ++it) { - if (!CommUtil::ValidateRankId(node_role, rank_ids.at(it))) { - MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!"; - } - - MessageMeta message_meta; - message_meta.set_cmd(NodeCommand::SEND_DATA); - message_meta.set_request_id(request_id); - - CommMessage comm_message; - *comm_message.mutable_pb_meta() = {message_meta}; - comm_message.set_data(data.at(it)); - - auto client = GetOrCreateTcpClient(rank_ids.at(it)); - client->SendMessage(comm_message); - } - return Wait(request_id, timeout); -} - -bool Node::Send(const enum NodeRole &node_role, const uint32_t &rank_id, const std::string &message, - CommMessage *comm_message_resp, const uint32_t &timeout) { - MS_EXCEPTION_IF_NULL(comm_message_resp); - if (!CommUtil::ValidateRankId(node_role, rank_id)) { - MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!"; - } - - uint64_t request_id = ++next_request_id_; - message_tracker_[request_id] = std::make_pair(1, 0); - set_message_callback(request_id, [&]() { - receive_messages_mutex_.lock(); - auto res = receive_messages_[request_id]; - *comm_message_resp = res[rank_id]; - receive_messages_.erase(request_id); - receive_messages_mutex_.unlock(); - }); - - MessageMeta message_meta; - message_meta.set_cmd(NodeCommand::SEND_DATA); - message_meta.set_request_id(request_id); - message_meta.set_rank_id(node_info_.rank_id_); - message_meta.set_role(node_info_.node_role_); - - CommMessage comm_message; - *comm_message.mutable_pb_meta() = {message_meta}; - comm_message.set_data(message); - auto client = GetOrCreateTcpClient(rank_id); - client->SendMessage(comm_message); - return Wait(request_id, timeout); -} - -bool Node::Send(const NodeRole &node_role, const std::vector &rank_ids, const std::vector &data, - std::vector *comm_message_resp, const uint32_t &timeout) { - MS_EXCEPTION_IF_NULL(comm_message_resp); - uint64_t request_id = ++next_request_id_; - message_tracker_[request_id] = std::make_pair(data.size(), 0); - - if (rank_ids.size() != data.size() || rank_ids.size() != (*comm_message_resp).size()) { - MS_LOG(EXCEPTION) << "The number of rank ids, data, comm_message_resp should be equal!"; - } - - size_t len = rank_ids.size(); - - set_message_callback(request_id, [&]() { - receive_messages_mutex_.lock(); - auto res = receive_messages_[request_id]; - for (size_t it = 0; it < len; ++it) { - comm_message_resp->at(it) = &res[rank_ids.at(it)]; - } - receive_messages_.erase(request_id); - receive_messages_mutex_.unlock(); - }); - - for (size_t it = 0; it < len; ++it) { - if (!CommUtil::ValidateRankId(node_role, rank_ids.at(it))) { - MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!"; - } - - MessageMeta message_meta; - message_meta.set_cmd(NodeCommand::SEND_DATA); - message_meta.set_request_id(request_id); - - CommMessage comm_message; - *comm_message.mutable_pb_meta() = {message_meta}; - comm_message.set_data(data.at(it)); - - auto client = GetOrCreateTcpClient(rank_ids.at(it)); - client->SendMessage(comm_message); - } - return Wait(request_id, timeout); -} - bool Node::WaitForStart(const uint32_t &timeout) { std::unique_lock lock(wait_start_mutex_); bool res = wait_start_cond_.wait_for(lock, std::chrono::seconds(timeout), [&] { @@ -161,101 +36,6 @@ bool Node::WaitForStart(const uint32_t &timeout) { }); return res; } - -bool Node::SendMessageSync(const std::shared_ptr &client, const CommMessage &message, - const uint32_t &timeout) { - uint64_t request_id = ++next_request_id_; - message_tracker_[request_id] = std::make_pair(1, 0); - const_cast(message).mutable_pb_meta()->set_request_id(request_id); - client->SendMessage(message); - return Wait(request_id, timeout); -} - -void Node::SendMessageAsync(const std::shared_ptr &client, const CommMessage &message) { - uint64_t request_id = ++next_request_id_; - const_cast(message).mutable_pb_meta()->set_request_id(request_id); - client->SendMessage(message); -} - -const std::shared_ptr &Node::GetOrCreateTcpClient(const int &rank_id) { - std::lock_guard lock(client_mutex_); - if (connected_nodes_.find(rank_id) != connected_nodes_.end()) { - return connected_nodes_[rank_id]; - } else { - if (nodes_address_.find(std::make_pair(NodeRole::SERVER, rank_id)) == nodes_address_.end()) { - MS_LOG(EXCEPTION) << "Worker node Fetch servers failed!"; - } - std::string ip = nodes_address_[std::make_pair(NodeRole::SERVER, rank_id)].first; - uint16_t port = nodes_address_[std::make_pair(NodeRole::SERVER, rank_id)].second; - auto client = std::make_shared(ip, port); - client->SetMessageCallback([&](const TcpClient &client, const CommMessage &message) { - switch (message.pb_meta().cmd()) { - case NodeCommand::SEND_DATA: - ProcessSendDataResp(message); - RunMessageCallback(message.pb_meta().request_id()); - break; - default: - MS_LOG(EXCEPTION) << "The cmd:" << message.pb_meta().cmd() << " is not supported!"; - } - NotifyMessageArrival(message); - }); - client->Init(); - connected_nodes_[rank_id] = client; - return connected_nodes_[rank_id]; - } -} - -void Node::ProcessSendDataResp(const CommMessage &message) { - std::lock_guard lock(receive_messages_mutex_); - const MessageMeta &message_meta = message.pb_meta(); - const uint32_t &rank_id = message_meta.rank_id(); - const uint64_t request_id = message_meta.request_id(); - auto it = receive_messages_.find(request_id); - if (it != receive_messages_.end()) { - it->second.insert(std::make_pair(rank_id, message)); - } else { - std::unordered_map res; - res.insert(std::make_pair(rank_id, message)); - receive_messages_[request_id] = res; - } -} - -void Node::RunMessageCallback(const uint64_t &request_id) { - message_callbacks_mutex_.lock(); - // When receiving a message's response, Then compare with the desired number of responses, - // If they are equal, then call the callback function - if (message_tracker_[request_id].first == message_tracker_[request_id].second + 1) { - auto it = message_callbacks_.find(request_id); - if (it != message_callbacks_.end()) { - message_callbacks_mutex_.unlock(); - - if (it->second) { - it->second(); - } - - message_callbacks_mutex_.lock(); - message_callbacks_.erase(it); - } - } - message_callbacks_mutex_.unlock(); -} - -void Node::set_message_callback(const uint64_t &request_id, const MessageCallback &message_callback) { - if (!message_callback) { - return; - } - std::lock_guard lock(message_callbacks_mutex_); - message_callbacks_[request_id] = message_callback; -} - -void Node::NotifyMessageArrival(const CommMessage &message) { - std::lock_guard lock(message_tracker_mutex_); - const MessageMeta &message_meta = message.pb_meta(); - uint64_t request_id = message_meta.request_id(); - - message_tracker_[request_id].second++; - message_tracker_cond_.notify_all(); -} } // namespace core } // namespace ps } // namespace mindspore diff --git a/mindspore/ccsrc/ps/core/node.h b/mindspore/ccsrc/ps/core/node.h index 2f8190f2d3a..89f006a42a8 100644 --- a/mindspore/ccsrc/ps/core/node.h +++ b/mindspore/ccsrc/ps/core/node.h @@ -29,7 +29,6 @@ #include #include #include -#include #include "proto/comm.pb.h" #include "proto/ps.pb.h" @@ -66,28 +65,8 @@ class Node { uint32_t rank_id() const; NodeRole role() const; - bool Wait(uint64_t request_id, const uint32_t &timeout = kCommTimeoutInSeconds); - - virtual bool Send(const enum NodeRole &node_role, const uint32_t &rank_id, const std::string &message, - const uint32_t &timeout = kCommTimeoutInSeconds); - virtual bool Send(const NodeRole &node_role, const std::vector &rank_ids, - const std::vector &data, const uint32_t &timeout = kCommTimeoutInSeconds); - virtual bool Send(const enum NodeRole &node_role, const uint32_t &rank_id, const std::string &message, - CommMessage *comm_message_resp, const uint32_t &timeout = kCommTimeoutInSeconds); - virtual bool Send(const NodeRole &node_role, const std::vector &rank_ids, - const std::vector &data, std::vector *comm_message_resp, - const uint32_t &timeout = kCommTimeoutInSeconds); - protected: bool WaitForStart(const uint32_t &timeout); - bool SendMessageSync(const std::shared_ptr &client, const CommMessage &message, - const uint32_t &timeout = kCommTimeoutInSeconds); - void SendMessageAsync(const std::shared_ptr &client, const CommMessage &message); - const std::shared_ptr &GetOrCreateTcpClient(const int &rank_id); - void ProcessSendDataResp(const CommMessage &message); - void RunMessageCallback(const uint64_t &request_id); - void set_message_callback(const uint64_t &request_id, const MessageCallback &message_callback); - void NotifyMessageArrival(const CommMessage &message); NodeInfo node_info_; std::atomic is_ready_; @@ -97,28 +76,11 @@ class Node { std::atomic is_already_finished_; std::atomic_uint64_t next_request_id_; - // -> - std::map, std::pair> nodes_address_; - // rank_id->tcpclient - std::unordered_map> connected_nodes_; - - // request_id-> - std::unordered_map> message_tracker_; - std::mutex message_tracker_mutex_; - std::condition_variable message_tracker_cond_; - std::mutex wait_finish_mutex_; - std::condition_variable wait_finish_cond_; std::mutex wait_start_mutex_; std::condition_variable wait_start_cond_; + std::mutex wait_finish_mutex_; + std::condition_variable wait_finish_cond_; std::mutex finish_mutex_; - std::mutex client_mutex_; - - // request_id -> - std::unordered_map> receive_messages_; - std::mutex receive_messages_mutex_; - // request_id -> MessageCallback - std::unordered_map message_callbacks_; - std::mutex message_callbacks_mutex_; }; } // namespace core } // namespace ps diff --git a/mindspore/ccsrc/ps/core/node_info.h b/mindspore/ccsrc/ps/core/node_info.h index 0ab39ff24fa..8d9609ce53a 100644 --- a/mindspore/ccsrc/ps/core/node_info.h +++ b/mindspore/ccsrc/ps/core/node_info.h @@ -26,7 +26,7 @@ namespace mindspore { namespace ps { namespace core { -enum NodeEvent { NODE_TIMEOUT = 0 }; +enum NodeEvent { CLUSTER_TIMEOUT = 0, NODE_TIMEOUT = 1 }; struct NodeInfo { NodeInfo() : port_(0), node_role_(NodeRole::SCHEDULER), rank_id_(0) {} diff --git a/mindspore/ccsrc/ps/core/node_manager.cc b/mindspore/ccsrc/ps/core/node_manager.cc index 10796859efb..89362246bd7 100644 --- a/mindspore/ccsrc/ps/core/node_manager.cc +++ b/mindspore/ccsrc/ps/core/node_manager.cc @@ -69,6 +69,10 @@ void NodeManager::UpdateHeartbeat(const std::string &node_id) { << ", the node rank id:" << node_info.rank_id_ << " the current time is: " << current_time.tv_sec; } +void NodeManager::UpdateNodeFinishState(const std::string &node_id) { heartbeats_finish_nodes_.insert(node_id); } + +bool NodeManager::CheckNodesFinishState() { return heartbeats_finish_nodes_.size() == nodes_info_.size(); } + std::vector NodeManager::FetchServersMeta() { std::vector servers_meta_list; for (auto it = nodes_info_.begin(); it != nodes_info_.end(); ++it) { @@ -131,7 +135,11 @@ bool NodeManager::is_cluster_finish() { return is_cluster_finish_.load(); } bool NodeManager::is_cluster_ready() { return is_cluster_ready_.load(); } -bool NodeManager::is_cluster_timeout() { return is_cluster_timeout_; } +bool NodeManager::is_cluster_timeout() { return is_cluster_timeout_.load(); } + +bool NodeManager::is_node_timeout() { return is_node_timeout_.load(); } + +void NodeManager::set_cluster_timeout(bool is_cluster_timeout) { is_cluster_timeout_ = is_cluster_timeout; } } // namespace core } // namespace ps } // namespace mindspore diff --git a/mindspore/ccsrc/ps/core/node_manager.h b/mindspore/ccsrc/ps/core/node_manager.h index ec070d33327..3cc8eeb60fa 100644 --- a/mindspore/ccsrc/ps/core/node_manager.h +++ b/mindspore/ccsrc/ps/core/node_manager.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef RPC_CLUSTER_MANAGER_H -#define RPC_CLUSTER_MANAGER_H +#ifndef MINDSPORE_CCSRC_PS_CORE_NODE_MANAGER_H_ +#define MINDSPORE_CCSRC_PS_CORE_NODE_MANAGER_H_ #include #include @@ -45,6 +45,7 @@ class NodeManager { : is_cluster_ready_(false), is_cluster_finish_(false), is_cluster_timeout_(false), + is_node_timeout_(false), total_node_num_(0), next_worker_rank_id_(-1), next_server_rank_id_(-1) {} @@ -55,6 +56,8 @@ class NodeManager { void InitNodeNum(); int NextRankId(const RegisterMessage ®ister_message); void UpdateHeartbeat(const std::string &node_id); + void UpdateNodeFinishState(const std::string &node_id); + bool CheckNodesFinishState(); std::vector FetchServersMeta(); void UpdateClusterState(); void CheckClusterTimeout(); @@ -63,11 +66,14 @@ class NodeManager { bool is_cluster_ready(); bool is_cluster_finish(); bool is_cluster_timeout(); + bool is_node_timeout(); + void set_cluster_timeout(bool is_cluster_timeout); private: std::atomic is_cluster_ready_; std::atomic is_cluster_finish_; std::atomic is_cluster_timeout_; + std::atomic is_node_timeout_; uint32_t total_node_num_; std::atomic next_worker_rank_id_; std::atomic next_server_rank_id_; @@ -76,6 +82,7 @@ class NodeManager { std::mutex assign_rank_id_mutex_; std::mutex heartbeat_mutex_; std::unordered_map heartbeats_; + std::unordered_set heartbeats_finish_nodes_; // timeout nodes std::unordered_map timeout_nodes_info_; std::unordered_set finish_nodes_id_; @@ -83,4 +90,4 @@ class NodeManager { } // namespace core } // namespace ps } // namespace mindspore -#endif // RPC_CLUSTER_MANAGER_H +#endif // MINDSPORE_CCSRC_PS_CORE_NODE_MANAGER_H_ diff --git a/mindspore/ccsrc/ps/core/protos/comm.proto b/mindspore/ccsrc/ps/core/protos/comm.proto index 4e47afeed0f..0d4aa67c59d 100644 --- a/mindspore/ccsrc/ps/core/protos/comm.proto +++ b/mindspore/ccsrc/ps/core/protos/comm.proto @@ -64,6 +64,7 @@ message RegisterRespMessage { message HeartbeatMessage { // the current Node unique id:0,1,2... string node_id = 1; + bool is_node_finish = 2; } message HeartbeatRespMessage { @@ -71,6 +72,7 @@ message HeartbeatRespMessage { bool is_cluster_ready = 1; bool is_cluster_finish = 2; bool is_cluster_timeout = 3; + bool is_node_timeout = 4; } message FetchServersRespMessage { diff --git a/mindspore/ccsrc/ps/core/scheduler_node.cc b/mindspore/ccsrc/ps/core/scheduler_node.cc new file mode 100644 index 00000000000..8aa0be94dad --- /dev/null +++ b/mindspore/ccsrc/ps/core/scheduler_node.cc @@ -0,0 +1,222 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ps/core/scheduler_node.h" + +namespace mindspore { +namespace ps { +namespace core { + +SchedulerNode::~SchedulerNode() { + MS_LOG(INFO) << "Stop scheduler node!"; + if (!is_already_stopped_) { + is_already_stopped_ = true; + server_->Stop(); + if (scheduler_thread_->joinable()) { + scheduler_thread_->join(); + } + if (update_state_thread_->joinable()) { + update_state_thread_->join(); + } + is_ready_ = true; + } +} + +bool SchedulerNode::Start(const uint32_t &timeout) { + MS_LOG(INFO) << "Start scheduler node!"; + Initialize(); + StartUpdateClusterStateTimer(); + if (!WaitForStart(timeout)) { + MS_LOG(ERROR) << "Start Scheduler node timeout!"; + return false; + } + MS_LOG(INFO) << "Start the scheduler node is successful!"; + return true; +} + +void SchedulerNode::ProcessHeartbeat(const TcpServer &server, const TcpConnection &conn, const CommMessage &message) { + HeartbeatMessage heartbeat_message; + heartbeat_message.ParseFromString(message.data()); + + node_manager_.UpdateHeartbeat(heartbeat_message.node_id()); + + if (heartbeat_message.is_node_finish()) { + node_manager_.UpdateNodeFinishState(heartbeat_message.node_id()); + } + + if (heartbeat_message.is_node_finish() && node_manager_.CheckNodesFinishState()) { + MS_LOG(INFO) << "The scheduler node receive all the finish cmd!"; + is_finish_ = true; + wait_finish_cond_.notify_all(); + } + + HeartbeatRespMessage heartbeat_resp_message; + heartbeat_resp_message.set_is_cluster_ready(node_manager_.is_cluster_ready()); + heartbeat_resp_message.set_is_cluster_finish(node_manager_.is_cluster_finish()); + heartbeat_resp_message.set_is_cluster_timeout(node_manager_.is_cluster_timeout()); + heartbeat_resp_message.set_is_node_timeout(node_manager_.is_node_timeout()); + + CommMessage comm_message; + *comm_message.mutable_pb_meta() = {message.pb_meta()}; + comm_message.set_data(heartbeat_resp_message.SerializeAsString()); + const_cast(server).SendMessage(conn, comm_message); +} + +void SchedulerNode::Initialize() { + CreateTcpServer(); + is_already_stopped_ = false; + node_info_.node_id_ = CommUtil::GenerateUUID(); + node_info_.node_role_ = NodeRole::SCHEDULER; + MS_LOG(INFO) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_) + << ", the node id is:" << node_info_.node_id_; +} + +void SchedulerNode::CreateTcpServer() { + node_manager_.InitNodeNum(); + + std::string scheduler_host = ClusterConfig::scheduler_host(); + uint32_t scheduler_port = ClusterConfig::scheduler_port(); + server_ = std::make_unique(scheduler_host, scheduler_port); + server_->SetMessageCallback([&](const TcpServer &server, const TcpConnection &conn, const CommMessage &message) { + switch (message.pb_meta().cmd()) { + case NodeCommand::HEARTBEAT: + ProcessHeartbeat(server, conn, message); + break; + case NodeCommand::REGISTER: + ProcessRegister(server, conn, message); + break; + case NodeCommand::FINISH: + ProcessFinish(server, conn, message); + break; + case NodeCommand::FETCH_SERVER: + ProcessFetchServers(server, conn, message); + break; + default: + MS_LOG(EXCEPTION) << "The cmd:" << message.pb_meta().cmd() << " is not supported!"; + } + }); + + server_->Init(); + + scheduler_thread_ = std::make_unique([&]() { + MS_LOG(INFO) << "The scheduler node start a tcp server!"; + server_->Start(); + }); + scheduler_thread_->detach(); +} + +void SchedulerNode::ProcessRegister(const TcpServer &server, const TcpConnection &conn, const CommMessage &message) { + MS_LOG(INFO) << "The scheduler process a register message!"; + RegisterMessage register_message; + register_message.ParseFromString(message.data()); + + // assign worker node and server node rank id + int rank_id = node_manager_.NextRankId(register_message); + if (rank_id < 0) { + MS_LOG(EXCEPTION) << "The rank id is wrong!"; + } + const std::string &node_id = register_message.node_id(); + node_manager_.UpdateHeartbeat(node_id); + + RegisterRespMessage register_resp_message; + register_resp_message.set_node_id(node_id); + register_resp_message.set_rank_id(rank_id); + + CommMessage comm_message; + *comm_message.mutable_pb_meta() = {message.pb_meta()}; + comm_message.set_data(register_resp_message.SerializeAsString()); + const_cast(server).SendMessage(conn, comm_message); +} + +void SchedulerNode::ProcessFinish(const TcpServer &server, const TcpConnection &conn, const CommMessage &message) { + FinishMessage finish_message; + finish_message.ParseFromString(message.data()); + node_manager_.AddFinishNode(finish_message); + MS_LOG(INFO) << "Process finish message from node id:" << finish_message.node_id(); + const_cast(server).SendMessage(conn, message); +} + +void SchedulerNode::ProcessFetchServers(const TcpServer &server, const TcpConnection &conn, + const CommMessage &message) { + FetchServersRespMessage fetch_servers_message; + std::vector servers_meta_list = node_manager_.FetchServersMeta(); + + *fetch_servers_message.mutable_servers_meta() = {servers_meta_list.begin(), servers_meta_list.end()}; + + CommMessage comm_message; + *comm_message.mutable_pb_meta() = {message.pb_meta()}; + comm_message.set_data(fetch_servers_message.SerializeAsString()); + const_cast(server).SendMessage(conn, comm_message); +} + +void SchedulerNode::StartUpdateClusterStateTimer() { + MS_LOG(WARNING) << "The scheduler start a heartbeat timer!"; + update_state_thread_ = std::make_unique([&]() { + auto start_time = std::chrono::steady_clock::now(); + while (!is_finish_.load()) { + // 1. update cluster timeout + if (!node_manager_.is_cluster_ready() && (std::chrono::steady_clock::now() - start_time > + std::chrono::seconds(ClusterConfig::cluster_available_timeout()))) { + node_manager_.CheckClusterTimeout(); + } + + // 2. update cluster state + std::this_thread::sleep_for(std::chrono::seconds(ClusterConfig::heartbeat_interval())); + node_manager_.UpdateClusterState(); + if (node_manager_.is_cluster_ready()) { + is_ready_ = true; + wait_start_cond_.notify_all(); + } + if (node_manager_.is_cluster_finish()) { + std::this_thread::sleep_for(std::chrono::seconds(ClusterConfig::heartbeat_interval() * 2)); + is_finish_ = true; + wait_finish_cond_.notify_all(); + } + } + }); + update_state_thread_->detach(); +} + +bool SchedulerNode::Stop() { + MS_LOG(INFO) << "Stop scheduler node!"; + if (!is_already_stopped_) { + is_already_stopped_ = true; + server_->Stop(); + if (scheduler_thread_->joinable()) { + scheduler_thread_->join(); + } + if (update_state_thread_->joinable()) { + update_state_thread_->join(); + } + is_ready_ = true; + } + return true; +} + +bool SchedulerNode::Finish(const uint32_t &timeout) { + MS_LOG(INFO) << "Finish scheduler node!"; + std::unique_lock lock(wait_finish_mutex_); + wait_finish_cond_.wait(lock, [&] { + if (is_finish_.load()) { + MS_LOG(INFO) << "The scheduler finish success!"; + } + return is_finish_.load(); + }); + return true; +} +} // namespace core +} // namespace ps +} // namespace mindspore diff --git a/mindspore/ccsrc/ps/core/scheduler_node.h b/mindspore/ccsrc/ps/core/scheduler_node.h new file mode 100644 index 00000000000..86488ea9ac1 --- /dev/null +++ b/mindspore/ccsrc/ps/core/scheduler_node.h @@ -0,0 +1,70 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PS_CORE_SCHEDULER_NODE_H_ +#define MINDSPORE_CCSRC_PS_CORE_SCHEDULER_NODE_H_ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "proto/comm.pb.h" +#include "proto/ps.pb.h" +#include "ps/core/cluster_config.h" +#include "ps/core/tcp_client.h" +#include "ps/core/tcp_server.h" +#include "ps/core/node_manager.h" +#include "ps/core/node.h" +#include "utils/log_adapter.h" + +namespace mindspore { +namespace ps { +namespace core { + +class SchedulerNode : public Node { + public: + SchedulerNode() : server_(nullptr), scheduler_thread_(nullptr), update_state_thread_(nullptr) {} + ~SchedulerNode() override; + + bool Start(const uint32_t &timeout = kTimeoutInSeconds) override; + bool Stop() override; + bool Finish(const uint32_t &timeout = kTimeoutInSeconds) override; + + private: + void Initialize(); + void CreateTcpServer(); + void ProcessHeartbeat(const TcpServer &server, const TcpConnection &conn, const CommMessage &message); + void ProcessRegister(const TcpServer &server, const TcpConnection &conn, const CommMessage &message); + void StartUpdateClusterStateTimer(); + void ProcessFinish(const TcpServer &server, const TcpConnection &conn, const CommMessage &message); + void ProcessFetchServers(const TcpServer &server, const TcpConnection &conn, const CommMessage &message); + + std::unique_ptr server_; + std::unique_ptr scheduler_thread_; + std::unique_ptr update_state_thread_; + + NodeManager node_manager_; +}; +} // namespace core +} // namespace ps +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PS_CORE_SCHEDULER_NODE_H_ diff --git a/mindspore/ccsrc/ps/core/server_node.cc b/mindspore/ccsrc/ps/core/server_node.cc index 987451666e6..2ac8861b24d 100644 --- a/mindspore/ccsrc/ps/core/server_node.cc +++ b/mindspore/ccsrc/ps/core/server_node.cc @@ -38,7 +38,7 @@ bool ServerNode::Start(const uint32_t &timeout) { MS_LOG(INFO) << "Start server node!"; Initialize(); Register(client_to_scheduler_); - Heartbeat(client_to_scheduler_); + StartHeartbeatTimer(client_to_scheduler_); if (!WaitForStart(timeout)) { MS_LOG(EXCEPTION) << "Start Worker node timeout!"; diff --git a/mindspore/ccsrc/ps/core/tcp_client.cc b/mindspore/ccsrc/ps/core/tcp_client.cc index a59431c257c..d6a6e560b0e 100644 --- a/mindspore/ccsrc/ps/core/tcp_client.cc +++ b/mindspore/ccsrc/ps/core/tcp_client.cc @@ -146,11 +146,7 @@ void TcpClient::StopEventBase() { MS_LOG(INFO) << "Stop tcp client event base!"; int ret = event_base_loopbreak(event_base_); if (ret != 0) { - MS_LOG(EXCEPTION) << "Event base loop break failed!"; - } - if (event_base_) { - event_base_free(event_base_); - event_base_ = nullptr; + MS_LOG(ERROR) << "Event base loop break failed!"; } } diff --git a/mindspore/ccsrc/ps/core/worker_node.cc b/mindspore/ccsrc/ps/core/worker_node.cc index 9e14bb7e656..3a2f40f92e9 100644 --- a/mindspore/ccsrc/ps/core/worker_node.cc +++ b/mindspore/ccsrc/ps/core/worker_node.cc @@ -44,7 +44,7 @@ bool WorkerNode::Start(const uint32_t &timeout) { MS_LOG(INFO) << "Starting worker node!"; Initialize(); Register(client_to_scheduler_); - Heartbeat(client_to_scheduler_); + StartHeartbeatTimer(client_to_scheduler_); if (!WaitForStart(timeout)) { MS_LOG(ERROR) << "Start Worker node timeout!";