From 72980a8a48a346dae45e51df693a5769a4620666 Mon Sep 17 00:00:00 2001 From: chendongsheng Date: Mon, 30 Nov 2020 21:36:35 +0800 Subject: [PATCH] added node --- mindspore/ccsrc/ps/CMakeLists.txt | 4 +- mindspore/ccsrc/ps/core/comm_util.cc | 13 ++ mindspore/ccsrc/ps/core/comm_util.h | 8 +- mindspore/ccsrc/ps/core/node.cc | 158 ++++++++++++++++++++++ mindspore/ccsrc/ps/core/node.h | 102 ++++++++++++++ mindspore/ccsrc/ps/core/node_info.h | 47 +++++++ mindspore/ccsrc/ps/core/node_manager.cc | 137 +++++++++++++++++++ mindspore/ccsrc/ps/core/node_manager.h | 86 ++++++++++++ mindspore/ccsrc/ps/core/protos/comm.proto | 7 + mindspore/ccsrc/ps/core/tcp_client.h | 1 + mindspore/ccsrc/ps/core/tcp_server.cc | 30 +++- mindspore/ccsrc/ps/core/tcp_server.h | 9 +- 12 files changed, 596 insertions(+), 6 deletions(-) create mode 100644 mindspore/ccsrc/ps/core/node.cc create mode 100644 mindspore/ccsrc/ps/core/node.h create mode 100644 mindspore/ccsrc/ps/core/node_info.h create mode 100644 mindspore/ccsrc/ps/core/node_manager.cc create mode 100644 mindspore/ccsrc/ps/core/node_manager.h diff --git a/mindspore/ccsrc/ps/CMakeLists.txt b/mindspore/ccsrc/ps/CMakeLists.txt index 658546465a8..6c7e313c0a9 100644 --- a/mindspore/ccsrc/ps/CMakeLists.txt +++ b/mindspore/ccsrc/ps/CMakeLists.txt @@ -12,7 +12,9 @@ if (NOT (ENABLE_CPU AND (ENABLE_D OR ENABLE_GPU))) list(REMOVE_ITEM _PS_SRC_FILES "core/tcp_message_handler.cc") list(REMOVE_ITEM _PS_SRC_FILES "core/tcp_server.cc") list(REMOVE_ITEM _PS_SRC_FILES "core/cluster_config.cc") -endif() + list(REMOVE_ITEM _PS_SRC_FILES "core/node.cc") + list(REMOVE_ITEM _PS_SRC_FILES "core/node_manager.cc") +endif () set_property(SOURCE ${_PS_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_PS) add_library(_mindspore_ps_obj OBJECT ${_PS_SRC_FILES}) diff --git a/mindspore/ccsrc/ps/core/comm_util.cc b/mindspore/ccsrc/ps/core/comm_util.cc index 28fb5ed658f..5fc35df0747 100644 --- a/mindspore/ccsrc/ps/core/comm_util.cc +++ b/mindspore/ccsrc/ps/core/comm_util.cc @@ -109,6 +109,19 @@ std::string CommUtil::GenerateUUID() { return ss.str(); } +std::string CommUtil::NodeRoleToString(const NodeRole &role) { + switch (role) { + case NodeRole::SCHEDULER: + return "SCHEDULER"; + case NodeRole::SERVER: + return "SERVER"; + case NodeRole::WORKER: + return "WORKER"; + default: + MS_LOG(EXCEPTION) << "The node role:" << role << " is illegal!"; + } +} + } // namespace core } // namespace ps } // namespace mindspore diff --git a/mindspore/ccsrc/ps/core/comm_util.h b/mindspore/ccsrc/ps/core/comm_util.h index 62b8b76b254..d48ef49891c 100644 --- a/mindspore/ccsrc/ps/core/comm_util.h +++ b/mindspore/ccsrc/ps/core/comm_util.h @@ -41,11 +41,13 @@ #include #include #include -#include -#include #include #include +#include +#include +#include "proto/comm.pb.h" +#include "proto/ps.pb.h" #include "utils/log_adapter.h" namespace mindspore { @@ -63,7 +65,9 @@ class CommUtil { static bool CheckIp(const std::string &ip); static void GetAvailableInterfaceAndIP(std::string *interface, std::string *ip); static std::string GenerateUUID(); + static std::string NodeRoleToString(const NodeRole &role); + private: static std::random_device rd; static std::mt19937_64 gen; static std::uniform_int_distribution<> dis; diff --git a/mindspore/ccsrc/ps/core/node.cc b/mindspore/ccsrc/ps/core/node.cc new file mode 100644 index 00000000000..bbca86d302f --- /dev/null +++ b/mindspore/ccsrc/ps/core/node.cc @@ -0,0 +1,158 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ps/core/node.h" + +namespace mindspore { +namespace ps { +namespace core { +void Node::Heartbeat(const std::shared_ptr &client) { + MS_LOG(INFO) << "The node role: " << CommUtil::NodeRoleToString(node_info_.node_role_) + << ", the node id:" << node_info_.node_id_ << ", the node rank id:" << node_info_.rank_id_ + << " begin send heartbeat to the scheduler!"; + heart_beat_thread_ = std::make_unique([&]() { + while (!is_finish_.load()) { + std::this_thread::sleep_for(std::chrono::seconds(ClusterConfig::heartbeat_interval())); + MessageMeta meta; + meta.set_cmd(NodeCommand::HEARTBEAT); + + HeartbeatMessage heartbeat_message; + heartbeat_message.set_node_id(node_info_.node_id_); + + CommMessage message; + *message.mutable_pb_meta() = {meta}; + message.set_data(heartbeat_message.SerializeAsString()); + SendMessageAsync(client, message); + } + }); + heart_beat_thread_->detach(); +} + +void Node::ProcessHeartbeatResp(const CommMessage &message) { + HeartbeatRespMessage heartbeat_resp_message; + heartbeat_resp_message.ParseFromString(message.data()); + is_ready_ = heartbeat_resp_message.is_cluster_ready(); + if (is_ready_.load()) { + wait_start_cond_.notify_all(); + } + is_finish_ = heartbeat_resp_message.is_cluster_finish(); + if (is_finish_.load()) { + wait_finish_cond_.notify_all(); + } + is_timeout_ = heartbeat_resp_message.is_cluster_timeout(); + if (is_timeout_ && on_node_event_message_) { + on_node_event_message_(NodeEvent::NODE_TIMEOUT); + } +} + +void Node::FetchServers(const std::shared_ptr &client) { + MessageMeta meta; + meta.set_cmd(NodeCommand::FETCH_SERVER); + + CommMessage message; + *message.mutable_pb_meta() = {meta}; + SendMessageSync(client, message); +} + +void Node::ProcessFetchServersResp(const CommMessage &message) { + FetchServersRespMessage fetch_servers_resp_message; + fetch_servers_resp_message.ParseFromString(message.data()); + + for (const auto &it : fetch_servers_resp_message.servers_meta()) { + server_rank_ids_[it.rank_id()] = std::make_pair(it.ip(), it.port()); + } + + MS_LOG(DEBUG) << "The all server host size is:" << server_rank_ids_.size(); +} + +std::string Node::node_id() const { return node_info_.node_id_; } + +uint32_t Node::rank_id() const { return node_info_.rank_id_; } + +void Node::set_callback(const OnNodeEventMessage &on_node_event_message) { + on_node_event_message_ = on_node_event_message; +} + +void Node::Wait(uint64_t request_id) { + std::unique_lock lock(message_mutex_); + message_tracker_cond_.wait(lock, [&] { + bool ret = message_tracker_[request_id].first == message_tracker_[request_id].second; + if (ret) { + MS_LOG(DEBUG) << "Message tracker remove request id:" << request_id; + message_tracker_.erase(request_id); + } + return ret; + }); +} + +void Node::Disconnect(const std::shared_ptr &client) { + MessageMeta meta; + meta.set_cmd(NodeCommand::FINISH); + + FinishMessage finish_message; + finish_message.set_node_id(node_info_.node_id_); + + CommMessage message; + *message.mutable_pb_meta() = {meta}; + message.set_data(finish_message.SerializeAsString()); + SendMessageSync(client, message); + WaitForDisconnect(); +} + +void Node::WaitForStart() { + std::unique_lock lock(wait_start_mutex_); + wait_start_cond_.wait(lock, [&] { + if (is_ready_.load()) { + MS_LOG(DEBUG) << "The node id:" << node_info_.node_id_ << " is success start!"; + } + return is_ready_.load(); + }); +} + +void Node::WaitForDisconnect() { + std::unique_lock lock(wait_finish_mutex_); + wait_finish_cond_.wait(lock, [&] { + if (is_finish_.load()) { + MS_LOG(DEBUG) << "The node id:" << node_info_.node_id_ << " is success finish!"; + } + return is_finish_.load(); + }); +} + +void Node::SendMessageSync(const std::shared_ptr &client, const CommMessage &message) { + uint64_t request_id = ++next_request_id_; + message_tracker_[request_id] = std::make_pair(1, 0); + const_cast(message).mutable_pb_meta()->set_request_id(request_id); + client->SendMessage(message); + Wait(request_id); +} + +void Node::SendMessageAsync(const std::shared_ptr &client, const CommMessage &message) { + uint64_t request_id = ++next_request_id_; + const_cast(message).mutable_pb_meta()->set_request_id(request_id); + client->SendMessage(message); +} + +void Node::NotifyMessageArrival(const CommMessage &message) { + const MessageMeta &message_meta = message.pb_meta(); + uint64_t request_id = message_meta.request_id(); + + message_tracker_[request_id].second++; + message_tracker_cond_.notify_all(); +} +} // namespace core +} // namespace ps +} // namespace mindspore diff --git a/mindspore/ccsrc/ps/core/node.h b/mindspore/ccsrc/ps/core/node.h new file mode 100644 index 00000000000..5ff490f0084 --- /dev/null +++ b/mindspore/ccsrc/ps/core/node.h @@ -0,0 +1,102 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PS_CORE_NODE_H_ +#define MINDSPORE_CCSRC_PS_CORE_NODE_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "proto/comm.pb.h" +#include "proto/ps.pb.h" +#include "ps/core/cluster_config.h" +#include "ps/core/node_info.h" +#include "ps/core/tcp_client.h" +#include "ps/core/tcp_server.h" +#include "utils/log_adapter.h" + +namespace mindspore { +namespace ps { +namespace core { +class Node { + public: + Node() + : is_ready_(false), + is_finish_(false), + is_timeout_(false), + is_already_stopped_(true), + next_request_id_(0), + heart_beat_thread_(nullptr) {} + virtual ~Node() = default; + + using OnNodeEventMessage = std::function; + void set_callback(const OnNodeEventMessage &on_node_event_message); + + std::string node_id() const; + uint32_t rank_id() const; + + void Wait(uint64_t request_id); + + protected: + void Heartbeat(const std::shared_ptr &client); + void ProcessHeartbeatResp(const CommMessage &message); + void FetchServers(const std::shared_ptr &client); + void ProcessFetchServersResp(const CommMessage &message); + void Disconnect(const std::shared_ptr &client); + void WaitForStart(); + void WaitForDisconnect(); + void SendMessageSync(const std::shared_ptr &client, const CommMessage &message); + void SendMessageAsync(const std::shared_ptr &client, const CommMessage &message); + void NotifyMessageArrival(const CommMessage &message); + + NodeInfo node_info_; + std::atomic is_ready_; + std::atomic is_finish_; + std::atomic is_timeout_; + std::atomic is_already_stopped_; + std::atomic_uint64_t next_request_id_; + std::unique_ptr heart_beat_thread_; + + OnNodeEventMessage on_node_event_message_; + + // rank_id-> + std::unordered_map> server_rank_ids_; + + // timestamp-> + std::unordered_map> message_tracker_; + std::mutex message_mutex_; + std::condition_variable message_tracker_cond_; + std::mutex wait_finish_mutex_; + std::condition_variable wait_finish_cond_; + std::mutex wait_start_mutex_; + std::condition_variable wait_start_cond_; +}; +} // namespace core +} // namespace ps +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PS_CORE_NODE_H_ diff --git a/mindspore/ccsrc/ps/core/node_info.h b/mindspore/ccsrc/ps/core/node_info.h new file mode 100644 index 00000000000..0ab39ff24fa --- /dev/null +++ b/mindspore/ccsrc/ps/core/node_info.h @@ -0,0 +1,47 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PS_CORE_NODE_INFO_H_ +#define MINDSPORE_CCSRC_PS_CORE_NODE_INFO_H_ + +#include + +#include "proto/comm.pb.h" +#include "proto/ps.pb.h" + +namespace mindspore { +namespace ps { +namespace core { + +enum NodeEvent { NODE_TIMEOUT = 0 }; + +struct NodeInfo { + NodeInfo() : port_(0), node_role_(NodeRole::SCHEDULER), rank_id_(0) {} + // ip + std::string ip_; + // the port of this node + uint16_t port_; + // the current Node unique id:0,1,2... + std::string node_id_; + // the role of the node: worker,server,scheduler + NodeRole node_role_; + // the current Node rank id,the worker node range is:[0,numOfWorker-1], the server node range is:[0, numOfServer-1] + uint32_t rank_id_; +}; +} // namespace core +} // namespace ps +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PS_CORE_NODE_INFO_H_ diff --git a/mindspore/ccsrc/ps/core/node_manager.cc b/mindspore/ccsrc/ps/core/node_manager.cc new file mode 100644 index 00000000000..10796859efb --- /dev/null +++ b/mindspore/ccsrc/ps/core/node_manager.cc @@ -0,0 +1,137 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ps/core/node_manager.h" + +namespace mindspore { +namespace ps { +namespace core { +void NodeManager::InitNodeNum() { total_node_num_ = ClusterConfig::server_num() + ClusterConfig::worker_num(); } + +int NodeManager::NextRankId(const RegisterMessage ®ister_message) { + std::lock_guard lock(assign_rank_id_mutex_); + int rank_id = -1; + + const std::string &node_id = register_message.node_id(); + if (nodes_info_.find(node_id) != nodes_info_.end()) { + rank_id = nodes_info_[node_id].rank_id_; + MS_LOG(INFO) << "The node id: " << node_id << " is already assigned!"; + return rank_id; + } + + if (register_message.role() == NodeRole::SERVER) { + const std::string &ip = register_message.ip(); + uint32_t port = register_message.port(); + + rank_id = ++next_server_rank_id_; + NodeInfo node_info; + node_info.node_role_ = NodeRole::SERVER; + node_info.node_id_ = node_id; + node_info.rank_id_ = rank_id; + node_info.ip_ = ip; + node_info.port_ = port; + nodes_info_[node_id] = node_info; + MS_LOG(INFO) << "The server node id:" << node_id << ",node ip: " << node_info.ip_ << ",node port:" << port + << " assign rank id:" << rank_id; + + } else if (register_message.role() == NodeRole::WORKER) { + rank_id = ++next_worker_rank_id_; + NodeInfo node_info; + node_info.node_role_ = NodeRole::WORKER; + node_info.node_id_ = node_id; + node_info.rank_id_ = rank_id; + nodes_info_[node_id] = node_info; + MS_LOG(INFO) << "The worker node id:" << node_id << " assign rank id:" << rank_id; + } + return rank_id; +} + +void NodeManager::UpdateHeartbeat(const std::string &node_id) { + std::lock_guard lock(heartbeat_mutex_); + NodeInfo node_info = nodes_info_[node_id]; + struct timeval current_time {}; + (void)gettimeofday(¤t_time, nullptr); + heartbeats_[node_id] = current_time; + MS_LOG(INFO) << "The node role: " << CommUtil::NodeRoleToString(node_info.node_role_) << ", the node id:" << node_id + << ", the node rank id:" << node_info.rank_id_ << " the current time is: " << current_time.tv_sec; +} + +std::vector NodeManager::FetchServersMeta() { + std::vector servers_meta_list; + for (auto it = nodes_info_.begin(); it != nodes_info_.end(); ++it) { + if (it->second.node_role_ == NodeRole::SERVER) { + ServersMeta servers_meta; + servers_meta.set_rank_id(it->second.rank_id_); + servers_meta.set_ip(it->second.ip_); + servers_meta.set_port(it->second.port_); + servers_meta_list.push_back(servers_meta); + } + } + return servers_meta_list; +} + +void NodeManager::UpdateClusterState() { + // 1. update cluster timeout state + struct timeval current_time {}; + (void)gettimeofday(¤t_time, nullptr); + timeout_nodes_info_.clear(); + for (auto it = heartbeats_.begin(); it != heartbeats_.end(); ++it) { + if (it->second.tv_sec + ClusterConfig::heartbeat_timeout() < current_time.tv_sec) { + MS_LOG(ERROR) << "The node id:" << it->first << " is timeout!"; + timeout_nodes_info_[it->first] = nodes_info_[it->first]; + } + } + if (!timeout_nodes_info_.empty()) { + is_cluster_timeout_ = true; + for (auto it = timeout_nodes_info_.begin(); it != timeout_nodes_info_.end(); ++it) { + finish_nodes_id_.insert(it->first); + } + } + + // 2. update cluster finish state + if (finish_nodes_id_.size() == total_node_num_) { + is_cluster_finish_ = true; + is_cluster_ready_ = true; + } + + // 3. update cluster ready state + if (nodes_info_.size() == total_node_num_) { + is_cluster_ready_ = true; + } +} + +void NodeManager::CheckClusterTimeout() { + if (total_node_num_ != nodes_info_.size()) { + MS_LOG(WARNING) << "The cluster is not ready after " << ClusterConfig::cluster_available_timeout() + << " seconds,so finish the cluster"; + is_cluster_timeout_ = true; + } +} + +void NodeManager::AddFinishNode(const FinishMessage &finish_message) { + finish_nodes_id_.insert(finish_message.node_id()); +} + +std::unordered_map NodeManager::nodes_info() { return nodes_info_; } + +bool NodeManager::is_cluster_finish() { return is_cluster_finish_.load(); } + +bool NodeManager::is_cluster_ready() { return is_cluster_ready_.load(); } + +bool NodeManager::is_cluster_timeout() { return is_cluster_timeout_; } +} // namespace core +} // namespace ps +} // namespace mindspore diff --git a/mindspore/ccsrc/ps/core/node_manager.h b/mindspore/ccsrc/ps/core/node_manager.h new file mode 100644 index 00000000000..ec070d33327 --- /dev/null +++ b/mindspore/ccsrc/ps/core/node_manager.h @@ -0,0 +1,86 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef RPC_CLUSTER_MANAGER_H +#define RPC_CLUSTER_MANAGER_H + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "proto/comm.pb.h" +#include "proto/ps.pb.h" +#include "ps/core/node.h" +#include "utils/log_adapter.h" + +namespace mindspore { +namespace ps { +namespace core { +class NodeManager { + public: + NodeManager() + : is_cluster_ready_(false), + is_cluster_finish_(false), + is_cluster_timeout_(false), + total_node_num_(0), + next_worker_rank_id_(-1), + next_server_rank_id_(-1) {} + virtual ~NodeManager() = default; + + enum ClusterState { STARTING, STARTED, FAILED, STOPPING, STOPPED }; + + void InitNodeNum(); + int NextRankId(const RegisterMessage ®ister_message); + void UpdateHeartbeat(const std::string &node_id); + std::vector FetchServersMeta(); + void UpdateClusterState(); + void CheckClusterTimeout(); + void AddFinishNode(const FinishMessage &finish_message); + std::unordered_map nodes_info(); + bool is_cluster_ready(); + bool is_cluster_finish(); + bool is_cluster_timeout(); + + private: + std::atomic is_cluster_ready_; + std::atomic is_cluster_finish_; + std::atomic is_cluster_timeout_; + uint32_t total_node_num_; + std::atomic next_worker_rank_id_; + std::atomic next_server_rank_id_; + // worker nodes and server nodes + std::unordered_map nodes_info_; + std::mutex assign_rank_id_mutex_; + std::mutex heartbeat_mutex_; + std::unordered_map heartbeats_; + // timeout nodes + std::unordered_map timeout_nodes_info_; + std::unordered_set finish_nodes_id_; +}; +} // namespace core +} // namespace ps +} // namespace mindspore +#endif // RPC_CLUSTER_MANAGER_H diff --git a/mindspore/ccsrc/ps/core/protos/comm.proto b/mindspore/ccsrc/ps/core/protos/comm.proto index 2b76a8814d7..f45fe583f56 100644 --- a/mindspore/ccsrc/ps/core/protos/comm.proto +++ b/mindspore/ccsrc/ps/core/protos/comm.proto @@ -25,6 +25,7 @@ enum NodeCommand { HEARTBEAT = 2; SEND_DATA = 3; FETCH_SERVER = 4; + FINISH = 5; } enum NodeRole { @@ -65,6 +66,7 @@ message HeartbeatRespMessage { // Is the entire system ready to use. bool is_cluster_ready = 1; bool is_cluster_finish = 2; + bool is_cluster_timeout = 3; } message FetchServersRespMessage { @@ -78,6 +80,11 @@ message ServersMeta { } +message FinishMessage { + // the current Node unique id:0,1,2... + string node_id = 1; +} + message CommMessage { MessageMeta pb_meta = 1; bytes data = 2; diff --git a/mindspore/ccsrc/ps/core/tcp_client.h b/mindspore/ccsrc/ps/core/tcp_client.h index 10c84460a9c..d98738b532d 100644 --- a/mindspore/ccsrc/ps/core/tcp_client.h +++ b/mindspore/ccsrc/ps/core/tcp_client.h @@ -32,6 +32,7 @@ #include #include "proto/comm.pb.h" +#include "proto/ps.pb.h" #include "ps/core/cluster_config.h" namespace mindspore { diff --git a/mindspore/ccsrc/ps/core/tcp_server.cc b/mindspore/ccsrc/ps/core/tcp_server.cc index 1dcd0048faa..1276ea63b02 100644 --- a/mindspore/ccsrc/ps/core/tcp_server.cc +++ b/mindspore/ccsrc/ps/core/tcp_server.cc @@ -85,6 +85,8 @@ void TcpServer::SetServerCallback(const OnConnected &client_conn, const OnDiscon this->client_accept_ = client_accept; } +void TcpServer::set_timer_once_callback(const OnTimerOnce &timer) { on_timer_once_callback_ = timer; } + void TcpServer::set_timer_callback(const OnTimer &timer) { on_timer_callback_ = timer; } void TcpServer::Init() { @@ -165,7 +167,21 @@ void TcpServer::StartTimerOnlyOnce(const uint32_t &time) { struct timeval timeout {}; timeout.tv_sec = time; timeout.tv_usec = 0; - ev = evtimer_new(base_, TimerCallback, this); + ev = evtimer_new(base_, TimerOnceCallback, this); + MS_EXCEPTION_IF_NULL(ev); + evtimer_add(ev, &timeout); +} + +void TcpServer::StartTimer(const uint32_t &time) { + MS_EXCEPTION_IF_NULL(base_); + struct event *ev = nullptr; + if (time == 0) { + MS_LOG(EXCEPTION) << "The time should not be 0!"; + } + struct timeval timeout {}; + timeout.tv_sec = time; + timeout.tv_usec = 0; + ev = event_new(base_, -1, EV_PERSIST, TimerCallback, this); MS_EXCEPTION_IF_NULL(ev); evtimer_add(ev, &timeout); } @@ -321,7 +337,15 @@ void TcpServer::TimerCallback(evutil_socket_t, int16_t, void *arg) { MS_EXCEPTION_IF_NULL(arg); auto tcp_server = reinterpret_cast(arg); if (tcp_server->on_timer_callback_) { - tcp_server->on_timer_callback_(*tcp_server); + tcp_server->on_timer_callback_(); + } +} + +void TcpServer::TimerOnceCallback(evutil_socket_t, int16_t, void *arg) { + MS_EXCEPTION_IF_NULL(arg); + auto tcp_server = reinterpret_cast(arg); + if (tcp_server->on_timer_once_callback_) { + tcp_server->on_timer_once_callback_(*tcp_server); } } @@ -337,6 +361,8 @@ void TcpServer::SendMessage(const CommMessage &message) { uint16_t TcpServer::BoundPort() const { return server_port_; } +std::string TcpServer::BoundIp() const { return server_address_; } + int TcpServer::ConnectionNum() const { return connections_.size(); } const std::map &TcpServer::Connections() const { return connections_; } diff --git a/mindspore/ccsrc/ps/core/tcp_server.h b/mindspore/ccsrc/ps/core/tcp_server.h index ed986ac2d4f..43294aa5e92 100644 --- a/mindspore/ccsrc/ps/core/tcp_server.h +++ b/mindspore/ccsrc/ps/core/tcp_server.h @@ -35,6 +35,7 @@ #include #include "proto/comm.pb.h" +#include "proto/ps.pb.h" #include "ps/core/tcp_message_handler.h" #include "ps/core/cluster_config.h" #include "utils/log_adapter.h" @@ -71,18 +72,21 @@ class TcpServer { using OnConnected = std::function; using OnDisconnected = std::function; using OnAccepted = std::function; - using OnTimer = std::function; + using OnTimerOnce = std::function; + using OnTimer = std::function; explicit TcpServer(const std::string &address, std::uint16_t port); virtual ~TcpServer(); void SetServerCallback(const OnConnected &client_conn, const OnDisconnected &client_disconn, const OnAccepted &client_accept); + void set_timer_once_callback(const OnTimerOnce &timer); void set_timer_callback(const OnTimer &timer); void Init(); void Start(); void StartWithNoBlock(); void StartTimerOnlyOnce(const uint32_t &time); + void StartTimer(const uint32_t &time); void Stop(); void SendToAllClients(const char *data, size_t len); void AddConnection(const evutil_socket_t &fd, const TcpConnection *connection); @@ -92,6 +96,7 @@ class TcpServer { void SendMessage(const TcpConnection &conn, const CommMessage &message); void SendMessage(const CommMessage &message); uint16_t BoundPort() const; + std::string BoundIp() const; int ConnectionNum() const; const std::map &Connections() const; @@ -102,6 +107,7 @@ class TcpServer { static void ReadCallback(struct bufferevent *, void *connection); static void EventCallback(struct bufferevent *, std::int16_t events, void *server); static void TimerCallback(evutil_socket_t fd, int16_t event, void *arg); + static void TimerOnceCallback(evutil_socket_t fd, int16_t event, void *arg); virtual TcpConnection *onCreateConnection(struct bufferevent *bev, const evutil_socket_t &fd); struct event_base *base_; @@ -117,6 +123,7 @@ class TcpServer { OnAccepted client_accept_; std::recursive_mutex connection_mutex_; OnServerReceiveMessage message_callback_; + OnTimerOnce on_timer_once_callback_; OnTimer on_timer_callback_; }; } // namespace core