diff --git a/mindspore/ccsrc/ps/core/abstract_node.cc b/mindspore/ccsrc/ps/core/abstract_node.cc index 95cc15d9290..293a1fe1490 100644 --- a/mindspore/ccsrc/ps/core/abstract_node.cc +++ b/mindspore/ccsrc/ps/core/abstract_node.cc @@ -75,6 +75,8 @@ bool AbstractNode::Broadcast(const enum NodeRole &node_role, const std::string & auto client = GetOrCreateTcpClient((*it).first.second); client->SendMessage(comm_message); } + MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_) + << ", the node id is:" << node_info_.node_id_ << " send the request id is:" << request_id; return Wait(request_id, timeout); } @@ -126,11 +128,13 @@ bool AbstractNode::Send(const NodeRole &node_role, const std::vector & auto client = GetOrCreateTcpClient(rank_ids.at(it)); client->SendMessage(comm_message); } + MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_) + << ", the node id is:" << node_info_.node_id_ << " send the request id is:" << request_id; return Wait(request_id, timeout); } bool AbstractNode::Send(const enum NodeRole &node_role, const uint32_t &rank_id, const std::string &message, - CommMessage *output, const uint32_t &timeout) { + std::string *output, const uint32_t &timeout) { MS_EXCEPTION_IF_NULL(output); if (!CommUtil::ValidateRankId(node_role, rank_id)) { MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!"; @@ -141,7 +145,7 @@ bool AbstractNode::Send(const enum NodeRole &node_role, const uint32_t &rank_id, set_message_callback(request_id, [&]() { receive_messages_mutex_.lock(); auto res = receive_messages_[request_id]; - *output = res[rank_id]; + *output = res[rank_id].data(); receive_messages_.erase(request_id); receive_messages_mutex_.unlock(); }); @@ -157,11 +161,13 @@ bool AbstractNode::Send(const enum NodeRole &node_role, const uint32_t &rank_id, comm_message.set_data(message); auto client = GetOrCreateTcpClient(rank_id); client->SendMessage(comm_message); + MS_LOG(DEBUG) << "The node role 
is:" << CommUtil::NodeRoleToString(node_info_.node_role_) + << ", the node id is:" << node_info_.node_id_ << " send the request id is:" << request_id; return Wait(request_id, timeout); } bool AbstractNode::Send(const NodeRole &node_role, const std::vector &rank_ids, - const std::vector &data, std::vector *output, + const std::vector &data, std::vector *output, const uint32_t &timeout) { MS_EXCEPTION_IF_NULL(output); uint64_t request_id = ++next_request_id_; @@ -177,7 +183,7 @@ bool AbstractNode::Send(const NodeRole &node_role, const std::vector & receive_messages_mutex_.lock(); auto res = receive_messages_[request_id]; for (size_t it = 0; it < len; ++it) { - (*output).push_back(res[rank_ids.at(it)]); + (*output).push_back(res[rank_ids.at(it)].data()); } receive_messages_.erase(request_id); receive_messages_mutex_.unlock(); @@ -201,6 +207,8 @@ bool AbstractNode::Send(const NodeRole &node_role, const std::vector & auto client = GetOrCreateTcpClient(rank_ids.at(it)); client->SendMessage(comm_message); } + MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_) + << ", the node id is:" << node_info_.node_id_ << " send the request id is:" << request_id; return Wait(request_id, timeout); } @@ -215,7 +223,7 @@ bool AbstractNode::Wait(uint64_t request_id, const uint32_t &timeout) { } uint64_t AbstractNode::CollectiveSendAsync(const enum NodeRole &node_role, const uint32_t &rank_id, - const std::string &message, const uint32_t &timeout) { + const std::string &message) { if (!CommUtil::ValidateRankId(node_role, rank_id)) { MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!"; } @@ -233,19 +241,19 @@ uint64_t AbstractNode::CollectiveSendAsync(const enum NodeRole &node_role, const } std::pair AbstractNode::CollectiveReceiveAsync(const enum NodeRole &node_role, - const uint32_t &rank_id, CommMessage *output) { + const uint32_t &rank_id, std::string *output) { if (!CommUtil::ValidateRankId(node_role, rank_id)) { MS_LOG(EXCEPTION) << "The 
node role or rank_id is illegal!"; } uint64_t rank_request_id = NextExpectedRankRequestId(rank_id); if (received_data_.count(std::make_pair(rank_id, rank_request_id)) > 0) { - *output = received_data_[std::make_pair(rank_id, rank_request_id)]; + *output = received_data_[std::make_pair(rank_id, rank_request_id)].data(); received_data_.erase(std::make_pair(rank_id, rank_request_id)); } else { set_receive_callback(rank_id, rank_request_id, [=]() { receive_callbacks_mutex_.lock(); - *output = received_data_[std::make_pair(rank_id, 1)]; + *output = received_data_[std::make_pair(rank_id, 1)].data(); received_data_.erase(std::make_pair(rank_id, rank_request_id)); receive_callbacks_mutex_.unlock(); }); @@ -272,13 +280,25 @@ void AbstractNode::StartHeartbeatTimer(const std::shared_ptr &client) << " begin send heartbeat to the scheduler!"; heart_beat_thread_ = std::make_unique([&]() { while (!is_finish_.load()) { - Heartbeat(client); + if (!Heartbeat(client)) { + MS_LOG(ERROR) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_) + << ", the node id is:" << node_info_.node_id_ << " Send heartbeat timeout!"; + if (!CheckSchedulerTimeout() && on_node_event_message_) { + MS_LOG(ERROR) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_) + << ", the node id is:" << node_info_.node_id_ << " exited due to scheduler timeout!"; + is_finish_ = true; + wait_finish_cond_.notify_all(); + on_node_event_message_(NodeEvent::SCHEDULER_TIMEOUT); + } + } else { + UpdateSchedulerTime(); + } std::this_thread::sleep_for(std::chrono::seconds(ClusterConfig::heartbeat_interval())); } }); } -void AbstractNode::Heartbeat(const std::shared_ptr &client, bool is_node_finish) { +bool AbstractNode::Heartbeat(const std::shared_ptr &client, bool is_node_finish) { MessageMeta meta; meta.set_cmd(NodeCommand::HEARTBEAT); @@ -292,11 +312,31 @@ void AbstractNode::Heartbeat(const std::shared_ptr &client, bool is_n if (!SendMessageSync(client, message)) { MS_LOG(ERROR) 
<< "The node id:" << node_info_.node_id_ << " Send heartbeat timeout!"; } + return true; +} + +void AbstractNode::UpdateSchedulerTime() { + struct timeval current_time {}; + (void)gettimeofday(&current_time, nullptr); + scheduler_time_ = current_time; + MS_LOG(DEBUG) << "The node role: " << CommUtil::NodeRoleToString(node_info_.node_role_) + << ", the node id:" << node_info_.node_id_ << ", the node rank id:" << node_info_.rank_id_ + << " update scheduler time, the current time is: " << current_time.tv_sec; +} + +bool AbstractNode::CheckSchedulerTimeout() const { + struct timeval current_time {}; + (void)gettimeofday(&current_time, nullptr); + if (scheduler_time_.tv_sec + ClusterConfig::scheduler_timeout() < current_time.tv_sec) { + return true; + } + return false; } void AbstractNode::ProcessHeartbeatResp(const CommMessage &message) { HeartbeatRespMessage heartbeat_resp_message; heartbeat_resp_message.ParseFromString(message.data()); + is_ready_ = heartbeat_resp_message.is_cluster_ready(); if (is_ready_.load()) { wait_start_cond_.notify_all(); @@ -353,9 +393,9 @@ bool AbstractNode::Disconnect(const std::shared_ptr &client, const ui *message.mutable_pb_meta() = {meta}; message.set_data(finish_message.SerializeAsString()); if (!SendMessageSync(client, message)) { - MS_LOG(EXCEPTION) << "Disconnect timeout!"; + MS_LOG(ERROR) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_) + << " the node id:" << node_info_.node_id_ << " send Finish Message timeout!"; } - MS_LOG(INFO) << "The node id:" << node_info_.node_id_ << " send finish message!"; return WaitForDisconnect(timeout); } @@ -444,6 +484,8 @@ bool AbstractNode::SendMessageSync(const std::shared_ptr &client, con message_tracker_[request_id] = std::make_pair(1, 0); const_cast(message).mutable_pb_meta()->set_request_id(request_id); client->SendMessage(message); + MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_) + << ", the node id is:" << node_info_.node_id_ << " send 
the request id is:" << request_id; return Wait(request_id, timeout); } @@ -452,6 +494,8 @@ uint64_t AbstractNode::SendMessageAsync(const std::shared_ptr &client message_tracker_[request_id] = std::make_pair(1, 0); const_cast(message).mutable_pb_meta()->set_request_id(request_id); client->SendMessage(message); + MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_) + << ", the node id is:" << node_info_.node_id_ << " send the request id is:" << request_id; return request_id; } @@ -460,6 +504,8 @@ void AbstractNode::ProcessSendDataResp(const CommMessage &message) { const MessageMeta &message_meta = message.pb_meta(); const uint32_t &rank_id = message_meta.rank_id(); const uint64_t request_id = message_meta.request_id(); + MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_) + << ", the node id is:" << node_info_.node_id_ << " send the request id is:" << request_id; auto it = receive_messages_.find(request_id); if (it != receive_messages_.end()) { it->second[rank_id] = message; diff --git a/mindspore/ccsrc/ps/core/abstract_node.h b/mindspore/ccsrc/ps/core/abstract_node.h index dff77346e1e..eea8eb773da 100644 --- a/mindspore/ccsrc/ps/core/abstract_node.h +++ b/mindspore/ccsrc/ps/core/abstract_node.h @@ -42,23 +42,24 @@ class AbstractNode : public Node { const uint32_t &timeout = kCommTimeoutInSeconds); bool Send(const NodeRole &node_role, const std::vector &rank_ids, const std::vector &data, const uint32_t &timeout = kCommTimeoutInSeconds); - bool Send(const enum NodeRole &node_role, const uint32_t &rank_id, const std::string &message, CommMessage *output, + bool Send(const enum NodeRole &node_role, const uint32_t &rank_id, const std::string &message, std::string *output, const uint32_t &timeout = kCommTimeoutInSeconds); bool Send(const NodeRole &node_role, const std::vector &rank_ids, const std::vector &data, - std::vector *output, const uint32_t &timeout = kCommTimeoutInSeconds); + std::vector 
*output, const uint32_t &timeout = kCommTimeoutInSeconds); bool Wait(uint64_t request_id, const uint32_t &timeout = kCommTimeoutInSeconds); - uint64_t CollectiveSendAsync(const enum NodeRole &node_role, const uint32_t &rank_id, const std::string &message, - const uint32_t &timeout = kCommTimeoutInSeconds); + uint64_t CollectiveSendAsync(const enum NodeRole &node_role, const uint32_t &rank_id, const std::string &message); std::pair CollectiveReceiveAsync(const enum NodeRole &node_role, const uint32_t &rank_id, - CommMessage *output); + std::string *output); bool CollectiveWait(std::pair request_id, const uint32_t &timeout = kCommTimeoutInSeconds); protected: void Register(const std::shared_ptr &client); void ProcessRegisterResp(const CommMessage &message); void StartHeartbeatTimer(const std::shared_ptr &client); - void Heartbeat(const std::shared_ptr &client, bool is_node_finish = false); + bool Heartbeat(const std::shared_ptr &client, bool is_node_finish = false); + void UpdateSchedulerTime(); + bool CheckSchedulerTimeout() const; void ProcessHeartbeatResp(const CommMessage &message); void FetchServers(const std::shared_ptr &client); void ProcessFetchServersResp(const CommMessage &message); @@ -113,6 +114,7 @@ class AbstractNode : public Node { // the key is rank_id, the value is rank_id's actual request_id std::unordered_map actual_rank_request_ids_; std::mutex rank_request_ids_mutex; + timeval scheduler_time_; }; } // namespace core } // namespace ps diff --git a/mindspore/ccsrc/ps/core/cluster_config.cc b/mindspore/ccsrc/ps/core/cluster_config.cc index 23f1635da71..33bd658c5e6 100644 --- a/mindspore/ccsrc/ps/core/cluster_config.cc +++ b/mindspore/ccsrc/ps/core/cluster_config.cc @@ -33,15 +33,17 @@ uint32_t ClusterConfig::heartbeat_timeout_ = 30; uint32_t ClusterConfig::cluster_available_timeout_ = 300; // The timeout period for the client to connect to the server is 100ms. 
uint32_t ClusterConfig::connect_interval_ = 100; +// When the scheduler exits, the worker and server can continue to work for 5 hours +uint32_t ClusterConfig::scheduler_timeout_ = 3600 * 5; -void ClusterConfig::Init(const uint32_t &worker_num, const uint32_t &server_num, - std::unique_ptr scheduler_host, const uint16_t &scheduler_port) { +void ClusterConfig::Init(const uint32_t &worker_num, const uint32_t &server_num, std::string scheduler_host, + const uint16_t &scheduler_port) { worker_num_ = worker_num; server_num_ = server_num; - if (!CommUtil::CheckIp(*scheduler_host.get())) { - MS_LOG(EXCEPTION) << "The scheduler_host:" << *scheduler_host.get() << " is illegal!"; + if (!CommUtil::CheckIp(scheduler_host)) { + MS_LOG(EXCEPTION) << "The scheduler_host:" << scheduler_host << " is illegal!"; } - scheduler_host_ = std::move(scheduler_host); + scheduler_host_ = std::make_unique(scheduler_host); scheduler_port_ = scheduler_port; } @@ -55,7 +57,7 @@ void ClusterConfig::set_heartbeat_interval(const uint32_t &heartbeat_interval) { heartbeat_interval_ = heartbeat_interval; } -std::string ClusterConfig::scheduler_host() { return *scheduler_host_.get(); } +std::string ClusterConfig::scheduler_host() { return *scheduler_host_; } uint16_t ClusterConfig::scheduler_port() { return scheduler_port_; } @@ -74,6 +76,10 @@ void ClusterConfig::set_cluster_available_timeout(const uint32_t &cluster_availa uint32_t ClusterConfig::connect_interval() { return connect_interval_; } void ClusterConfig::set_connect_interval(const uint32_t &connect_interval) { connect_interval_ = connect_interval; } + +uint32_t ClusterConfig::scheduler_timeout() { return scheduler_timeout_; } + +void ClusterConfig::set_scheduler_timeout(const uint32_t &scheduler_timeout) { scheduler_timeout_ = scheduler_timeout; } } // namespace core } // namespace ps } // namespace mindspore diff --git a/mindspore/ccsrc/ps/core/cluster_config.h b/mindspore/ccsrc/ps/core/cluster_config.h index 20104949a70..c13c6d0192a 100644 
--- a/mindspore/ccsrc/ps/core/cluster_config.h +++ b/mindspore/ccsrc/ps/core/cluster_config.h @@ -30,7 +30,7 @@ namespace ps { namespace core { class ClusterConfig { public: - static void Init(const uint32_t &worker_num, const uint32_t &server_num, std::unique_ptr scheduler_host, + static void Init(const uint32_t &worker_num, const uint32_t &server_num, std::string scheduler_host, const uint16_t &scheduler_port); static uint32_t worker_num(); static uint32_t server_num(); @@ -44,6 +44,8 @@ class ClusterConfig { static void set_cluster_available_timeout(const uint32_t &cluster_available_timeout); static uint32_t connect_interval(); static void set_connect_interval(const uint32_t &connect_interval); + static uint32_t scheduler_timeout(); + static void set_scheduler_timeout(const uint32_t &scheduler_timeout); private: static uint32_t worker_num_; @@ -54,6 +56,7 @@ class ClusterConfig { static uint32_t heartbeat_timeout_; static uint32_t cluster_available_timeout_; static uint32_t connect_interval_; + static uint32_t scheduler_timeout_; }; } // namespace core } // namespace ps diff --git a/mindspore/ccsrc/ps/core/node.cc b/mindspore/ccsrc/ps/core/node.cc index 31666eec9df..ee42d798244 100644 --- a/mindspore/ccsrc/ps/core/node.cc +++ b/mindspore/ccsrc/ps/core/node.cc @@ -21,7 +21,12 @@ namespace ps { namespace core { std::string Node::node_id() const { return node_info_.node_id_; } -uint32_t Node::rank_id() const { return node_info_.rank_id_; } +uint32_t Node::rank_id() const { + if (!is_ready_.load()) { + MS_LOG(EXCEPTION) << "The cluster is not ready yet to get rank id!"; + } + return node_info_.rank_id_; +} NodeRole Node::role() const { return node_info_.node_role_; } diff --git a/mindspore/ccsrc/ps/core/node.h b/mindspore/ccsrc/ps/core/node.h index bf7e3cbee1c..8a9216c658e 100644 --- a/mindspore/ccsrc/ps/core/node.h +++ b/mindspore/ccsrc/ps/core/node.h @@ -30,8 +30,6 @@ #include #include -#include "proto/comm.pb.h" -#include "proto/ps.pb.h" #include 
"ps/core/cluster_config.h" #include "ps/core/node_info.h" #include "ps/core/tcp_client.h" diff --git a/mindspore/ccsrc/ps/core/node_info.h b/mindspore/ccsrc/ps/core/node_info.h index 6c4076b75fb..b421cf2ad63 100644 --- a/mindspore/ccsrc/ps/core/node_info.h +++ b/mindspore/ccsrc/ps/core/node_info.h @@ -25,7 +25,7 @@ namespace mindspore { namespace ps { namespace core { -enum NodeEvent { CLUSTER_TIMEOUT = 0, NODE_TIMEOUT = 1 }; +enum NodeEvent { CLUSTER_TIMEOUT = 0, NODE_TIMEOUT = 1, SCHEDULER_TIMEOUT }; struct NodeInfo { NodeInfo() : port_(0), node_role_(NodeRole::SCHEDULER), rank_id_(0) {} diff --git a/mindspore/ccsrc/ps/core/node_manager.cc b/mindspore/ccsrc/ps/core/node_manager.cc index eb1d0d609bb..d6eac42cc37 100644 --- a/mindspore/ccsrc/ps/core/node_manager.cc +++ b/mindspore/ccsrc/ps/core/node_manager.cc @@ -64,8 +64,8 @@ void NodeManager::UpdateHeartbeat(const std::string &node_id) { struct timeval current_time {}; (void)gettimeofday(&current_time, nullptr); heartbeats_[node_id] = current_time; - MS_LOG(INFO) << "The node role: " << CommUtil::NodeRoleToString(node_info.node_role_) << ", the node id:" << node_id - << ", the node rank id:" << node_info.rank_id_ << " the current time is: " << current_time.tv_sec; + MS_LOG(DEBUG) << "The node role: " << CommUtil::NodeRoleToString(node_info.node_role_) << ", the node id:" << node_id + << ", the node rank id:" << node_info.rank_id_ << " the current time is: " << current_time.tv_sec; } void NodeManager::UpdateNodeFinishState(const std::string &node_id) { heartbeats_finish_nodes_.insert(node_id); } diff --git a/mindspore/ccsrc/ps/core/node_manager.h b/mindspore/ccsrc/ps/core/node_manager.h index 27ee8c41f56..615f9ef3572 100644 --- a/mindspore/ccsrc/ps/core/node_manager.h +++ b/mindspore/ccsrc/ps/core/node_manager.h @@ -31,8 +31,6 @@ #include #include -#include "proto/comm.pb.h" -#include "proto/ps.pb.h" #include "ps/core/node.h" #include "utils/log_adapter.h" #include "utils/convert_utils_base.h" diff --git 
a/mindspore/ccsrc/ps/core/protos/ps.proto b/mindspore/ccsrc/ps/core/protos/ps.proto index 1516af4b087..9ae31a94c13 100644 --- a/mindspore/ccsrc/ps/core/protos/ps.proto +++ b/mindspore/ccsrc/ps/core/protos/ps.proto @@ -20,6 +20,7 @@ option optimize_for = LITE_RUNTIME; enum PSCommand { PUSH = 0; PULL = 1; + INIT_EMBEDDING_TABLE = 2; } message KVMessage { diff --git a/mindspore/ccsrc/ps/core/scheduler_node.cc b/mindspore/ccsrc/ps/core/scheduler_node.cc index fb593b0da1c..d84fc77dc47 100644 --- a/mindspore/ccsrc/ps/core/scheduler_node.cc +++ b/mindspore/ccsrc/ps/core/scheduler_node.cc @@ -37,9 +37,10 @@ bool SchedulerNode::Start(const uint32_t &timeout) { return true; } -void SchedulerNode::ProcessHeartbeat(const TcpServer &server, const TcpConnection &conn, const CommMessage &message) { +void SchedulerNode::ProcessHeartbeat(std::shared_ptr server, std::shared_ptr conn, + std::shared_ptr message) { HeartbeatMessage heartbeat_message; - heartbeat_message.ParseFromString(message.data()); + heartbeat_message.ParseFromString(message->data()); node_manager_.UpdateHeartbeat(heartbeat_message.node_id()); @@ -59,10 +60,10 @@ void SchedulerNode::ProcessHeartbeat(const TcpServer &server, const TcpConnectio heartbeat_resp_message.set_is_cluster_timeout(node_manager_.is_cluster_timeout()); heartbeat_resp_message.set_is_node_timeout(node_manager_.is_node_timeout()); - CommMessage comm_message; - *comm_message.mutable_pb_meta() = {message.pb_meta()}; - comm_message.set_data(heartbeat_resp_message.SerializeAsString()); - const_cast(server).SendMessage(conn, comm_message); + std::shared_ptr comm_message = std::make_shared(); + *comm_message->mutable_pb_meta() = {message->pb_meta()}; + comm_message->set_data(heartbeat_resp_message.SerializeAsString()); + server->SendMessage(conn, comm_message); } void SchedulerNode::Initialize() { @@ -79,23 +80,23 @@ void SchedulerNode::CreateTcpServer() { std::string scheduler_host = ClusterConfig::scheduler_host(); uint32_t scheduler_port = 
ClusterConfig::scheduler_port(); - server_ = std::make_unique(scheduler_host, scheduler_port); - server_->SetMessageCallback([&](const TcpServer &server, const TcpConnection &conn, const CommMessage &message) { - switch (message.pb_meta().cmd()) { + server_ = std::make_shared(scheduler_host, scheduler_port); + server_->SetMessageCallback([&](std::shared_ptr conn, std::shared_ptr message) { + switch (message->pb_meta().cmd()) { case NodeCommand::HEARTBEAT: - ProcessHeartbeat(server, conn, message); + ProcessHeartbeat(server_, conn, message); break; case NodeCommand::REGISTER: - ProcessRegister(server, conn, message); + ProcessRegister(server_, conn, message); break; case NodeCommand::FINISH: - ProcessFinish(server, conn, message); + ProcessFinish(server_, conn, message); break; case NodeCommand::FETCH_SERVER: - ProcessFetchServers(server, conn, message); + ProcessFetchServers(server_, conn, message); break; default: - MS_LOG(EXCEPTION) << "The cmd:" << message.pb_meta().cmd() << " is not supported!"; + MS_LOG(EXCEPTION) << "The cmd:" << message->pb_meta().cmd() << " is not supported!"; } }); @@ -107,10 +108,11 @@ void SchedulerNode::CreateTcpServer() { }); } -void SchedulerNode::ProcessRegister(const TcpServer &server, const TcpConnection &conn, const CommMessage &message) { +void SchedulerNode::ProcessRegister(std::shared_ptr server, std::shared_ptr conn, + std::shared_ptr message) { MS_LOG(INFO) << "The scheduler process a register message!"; RegisterMessage register_message; - register_message.ParseFromString(message.data()); + register_message.ParseFromString(message->data()); // assign worker node and server node rank id int rank_id = node_manager_.NextRankId(register_message); @@ -124,31 +126,32 @@ void SchedulerNode::ProcessRegister(const TcpServer &server, const TcpConnection register_resp_message.set_node_id(node_id); register_resp_message.set_rank_id(rank_id); - CommMessage comm_message; - *comm_message.mutable_pb_meta() = {message.pb_meta()}; - 
comm_message.set_data(register_resp_message.SerializeAsString()); - const_cast(server).SendMessage(conn, comm_message); + std::shared_ptr comm_message = std::make_shared(); + *comm_message->mutable_pb_meta() = {message->pb_meta()}; + comm_message->set_data(register_resp_message.SerializeAsString()); + server->SendMessage(conn, comm_message); } -void SchedulerNode::ProcessFinish(const TcpServer &server, const TcpConnection &conn, const CommMessage &message) { +void SchedulerNode::ProcessFinish(std::shared_ptr server, std::shared_ptr conn, + std::shared_ptr message) { FinishMessage finish_message; - finish_message.ParseFromString(message.data()); + finish_message.ParseFromString(message->data()); node_manager_.AddFinishNode(finish_message); MS_LOG(INFO) << "Process finish message from node id:" << finish_message.node_id(); - const_cast(server).SendMessage(conn, message); + server->SendMessage(conn, message); } -void SchedulerNode::ProcessFetchServers(const TcpServer &server, const TcpConnection &conn, - const CommMessage &message) { +void SchedulerNode::ProcessFetchServers(std::shared_ptr server, std::shared_ptr conn, + std::shared_ptr message) { FetchServersRespMessage fetch_servers_message; std::vector servers_meta_list = node_manager_.FetchServersMeta(); *fetch_servers_message.mutable_servers_meta() = {servers_meta_list.begin(), servers_meta_list.end()}; - CommMessage comm_message; - *comm_message.mutable_pb_meta() = {message.pb_meta()}; - comm_message.set_data(fetch_servers_message.SerializeAsString()); - const_cast(server).SendMessage(conn, comm_message); + std::shared_ptr comm_message = std::make_shared(); + *comm_message->mutable_pb_meta() = {message->pb_meta()}; + comm_message->set_data(fetch_servers_message.SerializeAsString()); + server->SendMessage(conn, comm_message); } void SchedulerNode::StartUpdateClusterStateTimer() { diff --git a/mindspore/ccsrc/ps/core/scheduler_node.h b/mindspore/ccsrc/ps/core/scheduler_node.h index 6f132367977..a476caae53e 100644 
--- a/mindspore/ccsrc/ps/core/scheduler_node.h +++ b/mindspore/ccsrc/ps/core/scheduler_node.h @@ -26,8 +26,6 @@ #include #include -#include "proto/comm.pb.h" -#include "proto/ps.pb.h" #include "ps/core/cluster_config.h" #include "ps/core/tcp_client.h" #include "ps/core/tcp_server.h" @@ -51,13 +49,17 @@ class SchedulerNode : public Node { private: void Initialize(); void CreateTcpServer(); - void ProcessHeartbeat(const TcpServer &server, const TcpConnection &conn, const CommMessage &message); - void ProcessRegister(const TcpServer &server, const TcpConnection &conn, const CommMessage &message); + void ProcessHeartbeat(std::shared_ptr server, std::shared_ptr conn, + std::shared_ptr message); + void ProcessRegister(std::shared_ptr server, std::shared_ptr conn, + std::shared_ptr message); void StartUpdateClusterStateTimer(); - void ProcessFinish(const TcpServer &server, const TcpConnection &conn, const CommMessage &message); - void ProcessFetchServers(const TcpServer &server, const TcpConnection &conn, const CommMessage &message); + void ProcessFinish(std::shared_ptr server, std::shared_ptr conn, + std::shared_ptr message); + void ProcessFetchServers(std::shared_ptr server, std::shared_ptr conn, + std::shared_ptr message); - std::unique_ptr server_; + std::shared_ptr server_; std::unique_ptr scheduler_thread_; std::unique_ptr update_state_thread_; diff --git a/mindspore/ccsrc/ps/core/server_node.cc b/mindspore/ccsrc/ps/core/server_node.cc index e6658749ab4..08d0b280b80 100644 --- a/mindspore/ccsrc/ps/core/server_node.cc +++ b/mindspore/ccsrc/ps/core/server_node.cc @@ -30,7 +30,8 @@ bool ServerNode::Start(const uint32_t &timeout) { StartHeartbeatTimer(client_to_scheduler_); if (!WaitForStart(timeout)) { - MS_LOG(ERROR) << "Start Server node timeout!"; + MS_LOG(ERROR) << "Start server node timeout!"; + return false; } MS_LOG(INFO) << "The cluster is ready to use!"; @@ -45,16 +46,16 @@ bool ServerNode::Start(const uint32_t &timeout) { void ServerNode::set_handler(const 
RequestHandler &handler) { request_handler_ = handler; } -void ServerNode::Response(const TcpServer &server, const TcpConnection &conn, const MessageMeta &message_meta, - const std::string &message) { - auto &meta = const_cast(message_meta); - meta.set_role(node_info_.node_role_); - meta.set_rank_id(node_info_.rank_id_); - CommMessage comm_message; - *comm_message.mutable_pb_meta() = {meta}; - comm_message.set_data(message); - - const_cast(server).SendMessage(conn, comm_message); +void ServerNode::Response(std::shared_ptr conn, std::shared_ptr message) { + MS_EXCEPTION_IF_NULL(conn); + MS_EXCEPTION_IF_NULL(message); + message->mutable_pb_meta()->set_role(node_info_.node_role_); + message->mutable_pb_meta()->set_rank_id(node_info_.rank_id_); + const MessageMeta &message_meta = message->pb_meta(); + const uint64_t request_id = message_meta.request_id(); + MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_) + << ", the node id is:" << node_info_.node_id_ << " send the request id is:" << request_id; + server_->SendMessage(conn, message); } void ServerNode::CreateTcpServer() { @@ -62,17 +63,17 @@ void ServerNode::CreateTcpServer() { std::string server_ip; CommUtil::GetAvailableInterfaceAndIP(&interface, &server_ip); server_ = std::make_shared(server_ip, 0); - server_->SetMessageCallback([&](const TcpServer &server, const TcpConnection &conn, const CommMessage &message) { - switch (message.pb_meta().cmd()) { + server_->SetMessageCallback([&](std::shared_ptr conn, std::shared_ptr message) { + switch (message->pb_meta().cmd()) { case NodeCommand::SEND_DATA: - ProcessSendData(server, conn, message); + ProcessSendData(conn, message); break; case NodeCommand::COLLECTIVE_SEND_DATA: - ProcessCollectiveSendData(server, conn, message); - RunReceiveCallback(message); + ProcessCollectiveSendData(conn, message); + RunReceiveCallback(*message); break; default: - MS_LOG(EXCEPTION) << "The cmd:" << message.pb_meta().cmd() << " is not supported!"; + 
MS_LOG(EXCEPTION) << "The cmd:" << message->pb_meta().cmd() << " is not supported!"; } }); server_->Init(); @@ -97,15 +98,18 @@ void ServerNode::Initialize() { MS_LOG(INFO) << "Server node init client successful!"; } -void ServerNode::ProcessSendData(const TcpServer &server, const TcpConnection &conn, const CommMessage &message) { - request_handler_(server, conn, message.pb_meta(), message.data()); +void ServerNode::ProcessSendData(std::shared_ptr conn, std::shared_ptr message) { + MS_EXCEPTION_IF_NULL(conn); + MS_EXCEPTION_IF_NULL(message); + request_handler_(conn, message); } -void ServerNode::ProcessCollectiveSendData(const TcpServer &server, const TcpConnection &conn, - const CommMessage &message) { - CommMessage comm_message; - *comm_message.mutable_pb_meta() = {message.pb_meta()}; - const_cast(server).SendMessage(conn, comm_message); +void ServerNode::ProcessCollectiveSendData(std::shared_ptr conn, std::shared_ptr message) { + MS_EXCEPTION_IF_NULL(conn); + MS_EXCEPTION_IF_NULL(message); + std::shared_ptr comm_message = std::make_shared(); + *comm_message->mutable_pb_meta() = {message->pb_meta()}; + server_->SendMessage(conn, comm_message); } bool ServerNode::Stop() { diff --git a/mindspore/ccsrc/ps/core/server_node.h b/mindspore/ccsrc/ps/core/server_node.h index 73d103840ee..2a0d70e82b6 100644 --- a/mindspore/ccsrc/ps/core/server_node.h +++ b/mindspore/ccsrc/ps/core/server_node.h @@ -44,18 +44,16 @@ class ServerNode : public AbstractNode { bool Stop() override; bool Finish(const uint32_t &timeout = kTimeoutInSeconds) override; - using RequestHandler = std::function; + using RequestHandler = std::function conn, std::shared_ptr message)>; void set_handler(const RequestHandler &handler); - void Response(const TcpServer &server, const TcpConnection &conn, const MessageMeta &message_meta, - const std::string &message); + void Response(std::shared_ptr conn, std::shared_ptr message); private: void CreateTcpServer(); void Initialize(); - void ProcessSendData(const 
TcpServer &server, const TcpConnection &conn, const CommMessage &message); - void ProcessCollectiveSendData(const TcpServer &server, const TcpConnection &conn, const CommMessage &message); + void ProcessSendData(std::shared_ptr conn, std::shared_ptr message); + void ProcessCollectiveSendData(std::shared_ptr conn, std::shared_ptr message); std::shared_ptr server_; std::unique_ptr server_thread_; diff --git a/mindspore/ccsrc/ps/core/tcp_client.cc b/mindspore/ccsrc/ps/core/tcp_client.cc index b6528567ba2..14d9a965f6b 100644 --- a/mindspore/ccsrc/ps/core/tcp_client.cc +++ b/mindspore/ccsrc/ps/core/tcp_client.cc @@ -46,9 +46,9 @@ TcpClient::TcpClient(const std::string &address, std::uint16_t port) server_port_(port), is_stop_(true), is_connected_(false) { - message_handler_.SetCallback([this](const CommMessage &message) { + message_handler_.SetCallback([this](std::shared_ptr message) { if (message_callback_) { - message_callback_(*this, message); + message_callback_(*this, *message); } }); } @@ -105,7 +105,7 @@ void TcpClient::Init() { sin.sin_addr.s_addr = inet_addr(server_address_.c_str()); sin.sin_port = htons(server_port_); - buffer_event_ = bufferevent_socket_new(event_base_, -1, BEV_OPT_CLOSE_ON_FREE); + buffer_event_ = bufferevent_socket_new(event_base_, -1, BEV_OPT_CLOSE_ON_FREE | BEV_OPT_THREADSAFE); MS_EXCEPTION_IF_NULL(buffer_event_); bufferevent_setcb(buffer_event_, ReadCallback, nullptr, EventCallback, this); @@ -261,17 +261,23 @@ void TcpClient::StartWithNoBlock() { void TcpClient::SetMessageCallback(const OnMessage &cb) { message_callback_ = cb; } -void TcpClient::SendMessage(const CommMessage &message) const { +bool TcpClient::SendMessage(const CommMessage &message) const { MS_EXCEPTION_IF_NULL(buffer_event_); + bufferevent_lock(buffer_event_); + bool res = true; size_t buf_size = message.ByteSizeLong(); std::vector serialized(buf_size); message.SerializeToArray(serialized.data(), SizeToInt(buf_size)); - if 
(evbuffer_add(bufferevent_get_output(buffer_event_), &buf_size, sizeof(buf_size)) == -1) { - MS_LOG(EXCEPTION) << "Event buffer add header failed!"; + if (bufferevent_write(buffer_event_, &buf_size, sizeof(buf_size)) == -1) { + MS_LOG(ERROR) << "Event buffer add header failed!"; + res = false; } - if (evbuffer_add(bufferevent_get_output(buffer_event_), serialized.data(), buf_size) == -1) { - MS_LOG(EXCEPTION) << "Event buffer add protobuf data failed!"; + if (bufferevent_write(buffer_event_, serialized.data(), buf_size) == -1) { + MS_LOG(ERROR) << "Event buffer add protobuf data failed!"; + res = false; } + bufferevent_unlock(buffer_event_); + return res; } void TcpClient::StartTimer(const uint32_t &time) { diff --git a/mindspore/ccsrc/ps/core/tcp_client.h b/mindspore/ccsrc/ps/core/tcp_client.h index f34982a2bb2..cdf3add7080 100644 --- a/mindspore/ccsrc/ps/core/tcp_client.h +++ b/mindspore/ccsrc/ps/core/tcp_client.h @@ -33,8 +33,6 @@ #include #include "ps/core/cluster_config.h" -#include "proto/comm.pb.h" -#include "proto/ps.pb.h" #include "utils/convert_utils_base.h" namespace mindspore { @@ -62,7 +60,7 @@ class TcpClient { void Start(); void StartWithNoBlock(); void SetMessageCallback(const OnMessage &cb); - void SendMessage(const CommMessage &message) const; + bool SendMessage(const CommMessage &message) const; void StartTimer(const uint32_t &time); void set_timer_callback(const OnTimer &timer); const event_base &eventbase(); diff --git a/mindspore/ccsrc/ps/core/tcp_message_handler.cc b/mindspore/ccsrc/ps/core/tcp_message_handler.cc index c64b36a306a..c63fd1ab50b 100644 --- a/mindspore/ccsrc/ps/core/tcp_message_handler.cc +++ b/mindspore/ccsrc/ps/core/tcp_message_handler.cc @@ -57,8 +57,8 @@ void TcpMessageHandler::ReceiveMessage(const void *buffer, size_t num) { } if (remaining_length_ == 0) { - CommMessage pb_message; - pb_message.ParseFromArray(message_buffer_.get(), message_length_); + std::shared_ptr pb_message = std::make_shared(); + 
pb_message->ParseFromArray(message_buffer_.get(), message_length_); if (message_callback_) { message_callback_(pb_message); } diff --git a/mindspore/ccsrc/ps/core/tcp_message_handler.h b/mindspore/ccsrc/ps/core/tcp_message_handler.h index b728d8a3fc8..2caa5112bd6 100644 --- a/mindspore/ccsrc/ps/core/tcp_message_handler.h +++ b/mindspore/ccsrc/ps/core/tcp_message_handler.h @@ -30,7 +30,7 @@ namespace mindspore { namespace ps { namespace core { -using messageReceive = std::function; +using messageReceive = std::function)>; constexpr int kHeaderLen = 8; class TcpMessageHandler { diff --git a/mindspore/ccsrc/ps/core/tcp_server.cc b/mindspore/ccsrc/ps/core/tcp_server.cc index 4d4466fd2c3..4751a6a10c2 100644 --- a/mindspore/ccsrc/ps/core/tcp_server.cc +++ b/mindspore/ccsrc/ps/core/tcp_server.cc @@ -32,14 +32,7 @@ namespace mindspore { namespace ps { namespace core { -void TcpConnection::InitConnection() { - tcp_message_handler_.SetCallback([&](const CommMessage &message) { - OnServerReceiveMessage on_server_receive = server_->GetServerReceive(); - if (on_server_receive) { - on_server_receive(*server_, *this, message); - } - }); -} +void TcpConnection::InitConnection(const messageReceive &callback) { tcp_message_handler_.SetCallback(callback); } void TcpConnection::OnReadHandler(const void *buffer, size_t num) { tcp_message_handler_.ReceiveMessage(buffer, num); } @@ -49,23 +42,30 @@ void TcpConnection::SendMessage(const void *buffer, size_t num) const { } } -TcpServer *TcpConnection::GetServer() const { return const_cast(server_); } +TcpServer *TcpConnection::GetServer() const { return server_; } const evutil_socket_t &TcpConnection::GetFd() const { return fd_; } -void TcpConnection::SendMessage(const CommMessage &message) const { +void TcpConnection::set_callback(const Callback &callback) { callback_ = callback; } + +bool TcpConnection::SendMessage(std::shared_ptr message) const { MS_EXCEPTION_IF_NULL(buffer_event_); - size_t buf_size = message.ByteSizeLong(); + 
MS_EXCEPTION_IF_NULL(message); + bufferevent_lock(buffer_event_); + bool res = true; + size_t buf_size = message->ByteSizeLong(); std::vector serialized(buf_size); - message.SerializeToArray(serialized.data(), SizeToInt(buf_size)); - if (evbuffer_add(bufferevent_get_output(const_cast(buffer_event_)), &buf_size, - sizeof(buf_size)) == -1) { - MS_LOG(EXCEPTION) << "Event buffer add header failed!"; + message->SerializeToArray(serialized.data(), SizeToInt(buf_size)); + if (bufferevent_write(buffer_event_, &buf_size, sizeof(buf_size)) == -1) { + MS_LOG(ERROR) << "Event buffer add header failed!"; + res = false; } - if (evbuffer_add(bufferevent_get_output(const_cast(buffer_event_)), serialized.data(), - buf_size) == -1) { - MS_LOG(EXCEPTION) << "Event buffer add protobuf data failed!"; + if (bufferevent_write(buffer_event_, serialized.data(), buf_size) == -1) { + MS_LOG(ERROR) << "Event buffer add protobuf data failed!"; + res = false; } + bufferevent_unlock(buffer_event_); + return res; } TcpServer::TcpServer(const std::string &address, std::uint16_t port) @@ -225,7 +225,7 @@ void TcpServer::SendToAllClients(const char *data, size_t len) { } } -void TcpServer::AddConnection(const evutil_socket_t &fd, const TcpConnection *connection) { +void TcpServer::AddConnection(const evutil_socket_t &fd, std::shared_ptr connection) { MS_EXCEPTION_IF_NULL(connection); std::lock_guard lock(connection_mutex_); connections_.insert(std::make_pair(fd, connection)); @@ -233,11 +233,11 @@ void TcpServer::AddConnection(const evutil_socket_t &fd, const TcpConnection *co void TcpServer::RemoveConnection(const evutil_socket_t &fd) { std::lock_guard lock(connection_mutex_); - TcpConnection *connection = const_cast(connections_.find(fd)->second); - delete connection; connections_.erase(fd); } +std::shared_ptr TcpServer::GetConnectionByFd(const evutil_socket_t &fd) { return connections_[fd]; } + void TcpServer::ListenerCallback(struct evconnlistener *, evutil_socket_t fd, struct sockaddr 
*sockaddr, int, void *data) { auto server = reinterpret_cast(data); @@ -246,7 +246,7 @@ void TcpServer::ListenerCallback(struct evconnlistener *, evutil_socket_t fd, st MS_EXCEPTION_IF_NULL(base); MS_EXCEPTION_IF_NULL(sockaddr); - struct bufferevent *bev = bufferevent_socket_new(base, fd, BEV_OPT_CLOSE_ON_FREE); + struct bufferevent *bev = bufferevent_socket_new(base, fd, BEV_OPT_CLOSE_ON_FREE | BEV_OPT_THREADSAFE); if (!bev) { MS_LOG(ERROR) << "Error constructing buffer event!"; int ret = event_base_loopbreak(base); @@ -256,23 +256,29 @@ void TcpServer::ListenerCallback(struct evconnlistener *, evutil_socket_t fd, st return; } - TcpConnection *conn = server->onCreateConnection(bev, fd); + std::shared_ptr conn = server->onCreateConnection(bev, fd); MS_EXCEPTION_IF_NULL(conn); - conn->InitConnection(); server->AddConnection(fd, conn); - bufferevent_setcb(bev, TcpServer::ReadCallback, nullptr, TcpServer::EventCallback, reinterpret_cast(conn)); + conn->InitConnection([=](std::shared_ptr message) { + OnServerReceiveMessage on_server_receive = server->GetServerReceive(); + if (on_server_receive) { + on_server_receive(conn, message); + } + }); + bufferevent_setcb(bev, TcpServer::ReadCallback, nullptr, TcpServer::EventCallback, + reinterpret_cast(conn.get())); if (bufferevent_enable(bev, EV_READ | EV_WRITE) == -1) { MS_LOG(EXCEPTION) << "Buffer event enable read and write failed!"; } } -TcpConnection *TcpServer::onCreateConnection(struct bufferevent *bev, const evutil_socket_t &fd) { - TcpConnection *conn = nullptr; +std::shared_ptr TcpServer::onCreateConnection(struct bufferevent *bev, const evutil_socket_t &fd) { + std::shared_ptr conn = nullptr; if (client_accept_) { - conn = const_cast(client_accept_(*this)); + conn = (client_accept_(*this)); } else { - conn = new TcpConnection(bev, fd, this); + conn = std::make_shared(bev, fd, this); } return conn; @@ -312,8 +318,8 @@ void TcpServer::EventCallback(struct bufferevent *bev, std::int16_t events, void 
MS_EXCEPTION_IF_NULL(data); struct evbuffer *output = bufferevent_get_output(bev); size_t remain = evbuffer_get_length(output); - auto conn = reinterpret_cast(data); - TcpServer *srv = conn->GetServer(); + auto conn = static_cast(data); + auto srv = conn->GetServer(); if (events & BEV_EVENT_EOF) { MS_LOG(INFO) << "Event buffer end of file!"; @@ -355,13 +361,18 @@ void TcpServer::TimerOnceCallback(evutil_socket_t, int16_t, void *arg) { } } -void TcpServer::SendMessage(const TcpConnection &conn, const CommMessage &message) { conn.SendMessage(message); } +bool TcpServer::SendMessage(std::shared_ptr conn, std::shared_ptr message) { + MS_EXCEPTION_IF_NULL(conn); + MS_EXCEPTION_IF_NULL(message); + return conn->SendMessage(message); +} -void TcpServer::SendMessage(const CommMessage &message) { +void TcpServer::SendMessage(std::shared_ptr message) { std::lock_guard lock(connection_mutex_); + MS_EXCEPTION_IF_NULL(message); for (auto it = connections_.begin(); it != connections_.end(); ++it) { - SendMessage(*it->second, message); + SendMessage(it->second, message); } } @@ -371,7 +382,7 @@ std::string TcpServer::BoundIp() const { return server_address_; } int TcpServer::ConnectionNum() const { return connections_.size(); } -const std::map &TcpServer::Connections() const { return connections_; } +const std::map> &TcpServer::Connections() const { return connections_; } void TcpServer::SetMessageCallback(const OnServerReceiveMessage &cb) { message_callback_ = cb; } diff --git a/mindspore/ccsrc/ps/core/tcp_server.h b/mindspore/ccsrc/ps/core/tcp_server.h index fcb51ff8de3..84dbffaec4c 100644 --- a/mindspore/ccsrc/ps/core/tcp_server.h +++ b/mindspore/ccsrc/ps/core/tcp_server.h @@ -34,8 +34,6 @@ #include #include -#include "proto/comm.pb.h" -#include "proto/ps.pb.h" #include "ps/core/tcp_message_handler.h" #include "ps/core/cluster_config.h" #include "utils/log_adapter.h" @@ -47,36 +45,42 @@ namespace core { class TcpServer; class TcpConnection { public: - explicit 
TcpConnection(struct bufferevent *bev, const evutil_socket_t &fd, const TcpServer *server) + explicit TcpConnection(struct bufferevent *bev, const evutil_socket_t &fd, TcpServer *server) : buffer_event_(bev), fd_(fd), server_(server) {} + TcpConnection(const TcpConnection &); virtual ~TcpConnection() = default; - virtual void InitConnection(); + using Callback = std::function)>; + + virtual void InitConnection(const messageReceive &callback); virtual void SendMessage(const void *buffer, size_t num) const; - void SendMessage(const CommMessage &message) const; + bool SendMessage(std::shared_ptr message) const; virtual void OnReadHandler(const void *buffer, size_t numBytes); TcpServer *GetServer() const; const evutil_socket_t &GetFd() const; + void set_callback(const Callback &callback); protected: struct bufferevent *buffer_event_; evutil_socket_t fd_; - const TcpServer *server_; + TcpServer *server_; TcpMessageHandler tcp_message_handler_; + Callback callback_; }; using OnServerReceiveMessage = - std::function; + std::function conn, std::shared_ptr message)>; class TcpServer { public: using OnConnected = std::function; using OnDisconnected = std::function; - using OnAccepted = std::function; + using OnAccepted = std::function(const TcpServer &)>; using OnTimerOnce = std::function; using OnTimer = std::function; - explicit TcpServer(const std::string &address, std::uint16_t port); + TcpServer(const std::string &address, std::uint16_t port); + TcpServer(const TcpServer &server); virtual ~TcpServer(); void SetServerCallback(const OnConnected &client_conn, const OnDisconnected &client_disconn, @@ -90,16 +94,17 @@ class TcpServer { void StartTimer(const uint32_t &time); void Stop(); void SendToAllClients(const char *data, size_t len); - void AddConnection(const evutil_socket_t &fd, const TcpConnection *connection); + void AddConnection(const evutil_socket_t &fd, std::shared_ptr connection); void RemoveConnection(const evutil_socket_t &fd); + std::shared_ptr 
GetConnectionByFd(const evutil_socket_t &fd); OnServerReceiveMessage GetServerReceive() const; void SetMessageCallback(const OnServerReceiveMessage &cb); - void SendMessage(const TcpConnection &conn, const CommMessage &message); - void SendMessage(const CommMessage &message); + bool SendMessage(std::shared_ptr conn, std::shared_ptr message); + void SendMessage(std::shared_ptr message); uint16_t BoundPort() const; std::string BoundIp() const; int ConnectionNum() const; - const std::map &Connections() const; + const std::map> &Connections() const; protected: static void ListenerCallback(struct evconnlistener *listener, evutil_socket_t socket, struct sockaddr *saddr, @@ -109,7 +114,7 @@ class TcpServer { static void EventCallback(struct bufferevent *, std::int16_t events, void *server); static void TimerCallback(evutil_socket_t fd, int16_t event, void *arg); static void TimerOnceCallback(evutil_socket_t fd, int16_t event, void *arg); - virtual TcpConnection *onCreateConnection(struct bufferevent *bev, const evutil_socket_t &fd); + std::shared_ptr onCreateConnection(struct bufferevent *bev, const evutil_socket_t &fd); struct event_base *base_; struct event *signal_event_; @@ -118,7 +123,7 @@ class TcpServer { std::uint16_t server_port_; std::atomic is_stop_; - std::map connections_; + std::map> connections_; OnConnected client_connection_; OnDisconnected client_disconnection_; OnAccepted client_accept_; diff --git a/mindspore/ccsrc/ps/core/worker_node.h b/mindspore/ccsrc/ps/core/worker_node.h index 9d2713d81e0..a1343aa3623 100644 --- a/mindspore/ccsrc/ps/core/worker_node.h +++ b/mindspore/ccsrc/ps/core/worker_node.h @@ -24,8 +24,6 @@ #include #include -#include "proto/comm.pb.h" -#include "proto/ps.pb.h" #include "ps/core/cluster_config.h" #include "ps/core/tcp_client.h" #include "ps/core/tcp_server.h" diff --git a/tests/ut/cpp/ps/core/cluster_available_timeout_test.cc b/tests/ut/cpp/ps/core/cluster_available_timeout_test.cc index 9f724460067..1efa1b15e7f 100644 --- 
a/tests/ut/cpp/ps/core/cluster_available_timeout_test.cc +++ b/tests/ut/cpp/ps/core/cluster_available_timeout_test.cc @@ -31,7 +31,7 @@ class TestClusterAvailableTimeout : public UT::Common { }; TEST_F(TestClusterAvailableTimeout, TestClusterAvailableTimeout) { - ClusterConfig::Init(1, 1, std::make_unique("127.0.0.1"), 9999); + ClusterConfig::Init(1, 1, "127.0.0.1", 9999); ClusterConfig::set_cluster_available_timeout(3); SchedulerNode node; node.Start(); diff --git a/tests/ut/cpp/ps/core/cluster_config_test.cc b/tests/ut/cpp/ps/core/cluster_config_test.cc index 0136d359f8e..904034b5c70 100644 --- a/tests/ut/cpp/ps/core/cluster_config_test.cc +++ b/tests/ut/cpp/ps/core/cluster_config_test.cc @@ -33,7 +33,7 @@ class TestClusterConfig : public UT::Common { }; TEST_F(TestClusterConfig, HeartbeatInterval) { - ClusterConfig::Init(2, 2, std::make_unique("127.0.0.1"), 8080); + ClusterConfig::Init(2, 2, "127.0.0.1", 8080); EXPECT_TRUE(ClusterConfig::heartbeat_interval() == 3); ClusterConfig::set_heartbeat_interval(100); EXPECT_TRUE(ClusterConfig::heartbeat_interval() == 100); diff --git a/tests/ut/cpp/ps/core/common_util_test.cc b/tests/ut/cpp/ps/core/common_util_test.cc index f2b3bf2e605..3af77ded519 100644 --- a/tests/ut/cpp/ps/core/common_util_test.cc +++ b/tests/ut/cpp/ps/core/common_util_test.cc @@ -53,7 +53,7 @@ TEST_F(TestCommUtil, GetAvailableInterfaceAndIP) { } TEST_F(TestCommUtil, ValidateRankId) { - ClusterConfig::Init(3, 2, std::make_unique("127.0.0.1"), 9999); + ClusterConfig::Init(3, 2, "127.0.0.1", 9999); EXPECT_TRUE(CommUtil::ValidateRankId(NodeRole::WORKER, 2)); EXPECT_FALSE(CommUtil::ValidateRankId(NodeRole::WORKER, 3)); EXPECT_TRUE(CommUtil::ValidateRankId(NodeRole::SERVER, 1)); diff --git a/tests/ut/cpp/ps/core/tcp_message_handler_test.cc b/tests/ut/cpp/ps/core/tcp_message_handler_test.cc index f1382ad5a32..ffe6d9ab2b2 100644 --- a/tests/ut/cpp/ps/core/tcp_message_handler_test.cc +++ b/tests/ut/cpp/ps/core/tcp_message_handler_test.cc @@ -35,7 +35,7 @@ 
class TestTcpMessageHandler : public UT::Common { TEST_F(TestTcpMessageHandler, 8_Header_1003_Data) { TcpMessageHandler handler; - handler.SetCallback([this](const CommMessage &message) { EXPECT_EQ(message.data().size(), 1000); }); + handler.SetCallback([this](std::shared_ptr message) { EXPECT_EQ(message->data().size(), 1000); }); std::string data(1000, 'a'); CommMessage message; @@ -55,7 +55,7 @@ TEST_F(TestTcpMessageHandler, 8_Header_1003_Data) { TEST_F(TestTcpMessageHandler, 8_Header_1003_Data_8_Header_1003_Data) { TcpMessageHandler handler; - handler.SetCallback([this](const CommMessage &message) { EXPECT_EQ(message.data().size(), 1000); }); + handler.SetCallback([this](std::shared_ptr message) { EXPECT_EQ(message->data().size(), 1000); }); std::string data(1000, 'a'); CommMessage message; @@ -86,7 +86,7 @@ TEST_F(TestTcpMessageHandler, 8_Header_1003_Data_8_Header_1003_Data) { TEST_F(TestTcpMessageHandler, 8_Header_4084_Data_4_Header_4_header_4084_data) { TcpMessageHandler handler; - handler.SetCallback([this](const CommMessage &message) { EXPECT_EQ(message.data().size(), 4081); }); + handler.SetCallback([this](std::shared_ptr message) { EXPECT_EQ(message->data().size(), 4081); }); std::string data(4081, 'a'); CommMessage message; @@ -126,7 +126,7 @@ TEST_F(TestTcpMessageHandler, 8_Header_4084_Data_4_Header_4_header_4084_data) { TEST_F(TestTcpMessageHandler, 8_Header_4080_Data_8_Header_4080_data) { TcpMessageHandler handler; - handler.SetCallback([this](const CommMessage &message) { EXPECT_EQ(message.data().size(), 4077); }); + handler.SetCallback([this](std::shared_ptr message) { EXPECT_EQ(message->data().size(), 4077); }); std::string data(4077, 'a'); CommMessage message; diff --git a/tests/ut/cpp/ps/core/tcp_pb_server_test.cc b/tests/ut/cpp/ps/core/tcp_pb_server_test.cc index 8752cfe0d39..df5f70ee956 100644 --- a/tests/ut/cpp/ps/core/tcp_pb_server_test.cc +++ b/tests/ut/cpp/ps/core/tcp_pb_server_test.cc @@ -32,12 +32,12 @@ class TestTcpServer : public 
UT::Common { void SetUp() override { server_ = std::make_unique("127.0.0.1", 0); std::unique_ptr http_server_thread_(nullptr); - http_server_thread_ = std::make_unique([&]() { - server_->SetMessageCallback([](const TcpServer &server, const TcpConnection &conn, const CommMessage &message) { + http_server_thread_ = std::make_unique([=]() { + server_->SetMessageCallback([=](std::shared_ptr conn, std::shared_ptr message) { KVMessage kv_message; - kv_message.ParseFromString(message.data()); + kv_message.ParseFromString(message->data()); EXPECT_EQ(2, kv_message.keys_size()); - const_cast(server).SendMessage(conn, message); + server_->SendMessage(conn, message); }); server_->Init(); server_->Start(); @@ -58,6 +58,7 @@ class TestTcpServer : public UT::Common { TEST_F(TestTcpServer, ServerSendMessage) { client_ = std::make_unique("127.0.0.1", server_->BoundPort()); + std::cout << server_->BoundPort() << std::endl; std::unique_ptr http_client_thread(nullptr); http_client_thread = std::make_unique([&]() { client_->SetMessageCallback([](const TcpClient &client, const CommMessage &message) {