From c7fe82b43d6b6150c957271ec7861e8ad67d3c84 Mon Sep 17 00:00:00 2001 From: chendongsheng Date: Thu, 21 Jan 2021 14:51:17 +0800 Subject: [PATCH] Custom data transmission format --- mindspore/ccsrc/ps/core/abstract_node.cc | 368 ++++++++++-------- mindspore/ccsrc/ps/core/abstract_node.h | 58 +-- mindspore/ccsrc/ps/core/message.h | 59 +++ mindspore/ccsrc/ps/core/protos/comm.proto | 9 +- mindspore/ccsrc/ps/core/scheduler_node.cc | 57 +-- mindspore/ccsrc/ps/core/scheduler_node.h | 11 +- mindspore/ccsrc/ps/core/server_node.cc | 51 +-- mindspore/ccsrc/ps/core/server_node.h | 12 +- mindspore/ccsrc/ps/core/tcp_client.cc | 64 ++- mindspore/ccsrc/ps/core/tcp_client.h | 9 +- .../ccsrc/ps/core/tcp_message_handler.cc | 16 +- mindspore/ccsrc/ps/core/tcp_message_handler.h | 17 +- mindspore/ccsrc/ps/core/tcp_server.cc | 51 ++- mindspore/ccsrc/ps/core/tcp_server.h | 7 +- tests/ut/cpp/ps/core/tcp_client_tests.cc | 14 +- .../cpp/ps/core/tcp_message_handler_test.cc | 202 +++++----- tests/ut/cpp/ps/core/tcp_pb_server_test.cc | 22 +- 17 files changed, 634 insertions(+), 393 deletions(-) create mode 100644 mindspore/ccsrc/ps/core/message.h diff --git a/mindspore/ccsrc/ps/core/abstract_node.cc b/mindspore/ccsrc/ps/core/abstract_node.cc index daac34850e0..fe608f88e88 100644 --- a/mindspore/ccsrc/ps/core/abstract_node.cc +++ b/mindspore/ccsrc/ps/core/abstract_node.cc @@ -20,8 +20,9 @@ namespace mindspore { namespace ps { namespace core { void AbstractNode::Register(const std::shared_ptr &client) { - MessageMeta message_meta; - message_meta.set_cmd(NodeCommand::REGISTER); + MS_EXCEPTION_IF_NULL(client); + auto message_meta = std::make_shared(); + message_meta->set_cmd(NodeCommand::REGISTER); RegisterMessage register_message; register_message.set_node_id(node_info_.node_id_); @@ -29,11 +30,8 @@ void AbstractNode::Register(const std::shared_ptr &client) { register_message.set_ip(node_info_.ip_); register_message.set_port(node_info_.port_); - CommMessage comm_message; - 
*comm_message.mutable_pb_meta() = {message_meta}; - comm_message.set_data(register_message.SerializeAsString()); - comm_message.set_user_cmd(""); - if (!SendMessageSync(client, comm_message)) { + if (!SendMessageSync(client, message_meta, Protos::PROTOBUF, register_message.SerializeAsString().data(), + register_message.ByteSizeLong())) { MS_LOG(EXCEPTION) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_) << " the node id:" << node_info_.node_id_ << " register timeout!"; } @@ -42,9 +40,11 @@ void AbstractNode::Register(const std::shared_ptr &client) { << " the node id:" << node_info_.node_id_ << "is registering to scheduler!"; } -void AbstractNode::ProcessRegisterResp(const CommMessage &message) { +void AbstractNode::ProcessRegisterResp(std::shared_ptr meta, const void *data, size_t size) { + MS_EXCEPTION_IF_NULL(meta); + MS_EXCEPTION_IF_NULL(data); RegisterRespMessage register_resp_message; - register_resp_message.ParseFromString(message.data()); + register_resp_message.ParseFromArray(data, size); if (register_resp_message.node_id() != node_info_.node_id_) { MS_LOG(EXCEPTION) << "The node id received:" << register_resp_message.node_id() << " is not match the current node id:" << node_info_.node_id_; @@ -52,28 +52,29 @@ void AbstractNode::ProcessRegisterResp(const CommMessage &message) { node_info_.rank_id_ = register_resp_message.rank_id(); - MS_LOG(INFO) << "The node id is:" << node_info_.node_id_ << ", and the rank id is:" << node_info_.rank_id_; + MS_LOG(INFO) << "The node id is:" << node_info_.node_id_ << ", and the rank id is:" << node_info_.rank_id_ + << " registered scheduler success!"; } -bool AbstractNode::Broadcast(const enum NodeRole &node_role, const CommMessage &message, const uint32_t &timeout) { +bool AbstractNode::Broadcast(const enum NodeRole &node_role, const DataPtr &message, size_t size, int command, + const uint32_t &timeout) { + MS_EXCEPTION_IF_NULL(message); if (node_role != NodeRole::SERVER) { MS_LOG(EXCEPTION) << 
"Currently only supports broadcast to server nodes"; } - CommMessage &comm_message = const_cast(message); - uint64_t request_id = ++next_request_id_; - message_tracker_[request_id] = std::make_pair(nodes_address_.size(), 0); + uint64_t request_id = AddMessageTrack(nodes_address_.size()); for (auto it = nodes_address_.begin(); it != nodes_address_.end(); ++it) { - MessageMeta message_meta; - message_meta.set_cmd(NodeCommand::SEND_DATA); - message_meta.set_request_id(request_id); - message_meta.set_rank_id(node_info_.rank_id_); - message_meta.set_role(node_info_.node_role_); + auto message_meta = std::make_shared(); + message_meta->set_cmd(NodeCommand::SEND_DATA); + message_meta->set_request_id(request_id); + message_meta->set_rank_id(node_info_.rank_id_); + message_meta->set_role(node_info_.node_role_); + message_meta->set_user_cmd(command); - *comm_message.mutable_pb_meta() = {message_meta}; auto client = GetOrCreateTcpClient((*it).first.second); - client->SendMessage(comm_message); + client->SendMessage(message_meta, Protos::RAW, message.get(), size); } MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_) << ", the node id is:" << node_info_.node_id_ << " send the request id is:" << request_id; @@ -84,28 +85,27 @@ void AbstractNode::set_event_callback(const OnNodeEventMessage &on_node_event_me on_node_event_message_ = on_node_event_message; } -bool AbstractNode::Send(const enum NodeRole &node_role, const uint32_t &rank_id, const CommMessage &message, - const uint32_t &timeout) { +bool AbstractNode::Send(const enum NodeRole &node_role, const uint32_t &rank_id, const DataPtr &data, size_t len, + int command, const uint32_t &timeout) { + MS_EXCEPTION_IF_NULL(data); if (!CommUtil::ValidateRankId(node_role, rank_id)) { MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!"; } - CommMessage &comm_message = const_cast(message); + auto message_meta = std::make_shared(); + message_meta->set_cmd(NodeCommand::SEND_DATA); + 
message_meta->set_rank_id(node_info_.rank_id_); + message_meta->set_role(node_info_.node_role_); + message_meta->set_user_cmd(command); - MessageMeta message_meta; - message_meta.set_cmd(NodeCommand::SEND_DATA); - message_meta.set_rank_id(node_info_.rank_id_); - message_meta.set_role(node_info_.node_role_); - - *comm_message.mutable_pb_meta() = {message_meta}; auto client = GetOrCreateTcpClient(rank_id); - return SendMessageSync(client, comm_message, timeout); + return SendMessageSync(client, message_meta, Protos::RAW, data.get(), len, timeout); } bool AbstractNode::Send(const NodeRole &node_role, const std::vector &rank_ids, - const std::vector &data, const uint32_t &timeout) { - uint64_t request_id = ++next_request_id_; - message_tracker_[request_id] = std::make_pair(data.size(), 0); + const std::vector &data, const std::vector &lens, int command, + const uint32_t &timeout) { + uint64_t request_id = AddMessageTrack(data.size()); if (rank_ids.size() != data.size()) { MS_LOG(EXCEPTION) << "The number of rank ids is not equal to the number of data!"; @@ -115,34 +115,32 @@ bool AbstractNode::Send(const NodeRole &node_role, const std::vector & MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!"; } - MessageMeta message_meta; - message_meta.set_cmd(NodeCommand::SEND_DATA); - message_meta.set_request_id(request_id); - message_meta.set_rank_id(node_info_.rank_id_); - message_meta.set_role(node_info_.node_role_); - - CommMessage &comm_message = const_cast(data.at(it)); - *comm_message.mutable_pb_meta() = {message_meta}; + auto message_meta = std::make_shared(); + message_meta->set_cmd(NodeCommand::SEND_DATA); + message_meta->set_request_id(request_id); + message_meta->set_rank_id(node_info_.rank_id_); + message_meta->set_role(node_info_.node_role_); + message_meta->set_user_cmd(command); + auto send = data.at(it); + auto len = lens.at(it); auto client = GetOrCreateTcpClient(rank_ids.at(it)); - client->SendMessage(comm_message); + client->SendMessage(message_meta, 
Protos::RAW, send.get(), len); } MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_) << ", the node id is:" << node_info_.node_id_ << " send the request id is:" << request_id; return Wait(request_id, timeout); } -bool AbstractNode::Send(const enum NodeRole &node_role, const uint32_t &rank_id, const CommMessage &message, - CommMessage *output, const uint32_t &timeout) { +bool AbstractNode::Send(const enum NodeRole &node_role, const uint32_t &rank_id, const DataPtr &message, size_t len, + int command, VectorPtr *output, const uint32_t &timeout) { + MS_EXCEPTION_IF_NULL(message); MS_EXCEPTION_IF_NULL(output); if (!CommUtil::ValidateRankId(node_role, rank_id)) { MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!"; } - CommMessage &comm_message = const_cast(message); - - uint64_t request_id = ++next_request_id_; - message_tracker_[request_id] = std::make_pair(1, 0); + uint64_t request_id = AddMessageTrack(1); set_message_callback(request_id, [&]() { receive_messages_mutex_.lock(); auto res = receive_messages_[request_id]; @@ -151,59 +149,59 @@ bool AbstractNode::Send(const enum NodeRole &node_role, const uint32_t &rank_id, receive_messages_mutex_.unlock(); }); - MessageMeta message_meta; - message_meta.set_cmd(NodeCommand::SEND_DATA); - message_meta.set_request_id(request_id); - message_meta.set_rank_id(node_info_.rank_id_); - message_meta.set_role(node_info_.node_role_); + auto message_meta = std::make_shared(); + message_meta->set_cmd(NodeCommand::SEND_DATA); + message_meta->set_request_id(request_id); + message_meta->set_rank_id(node_info_.rank_id_); + message_meta->set_role(node_info_.node_role_); + message_meta->set_user_cmd(command); - *comm_message.mutable_pb_meta() = {message_meta}; auto client = GetOrCreateTcpClient(rank_id); - client->SendMessage(comm_message); + client->SendMessage(message_meta, Protos::RAW, message.get(), len); MS_LOG(DEBUG) << "The node role is:" << 
CommUtil::NodeRoleToString(node_info_.node_role_) << ", the node id is:" << node_info_.node_id_ << " send the request id is:" << request_id; return Wait(request_id, timeout); } bool AbstractNode::Send(const NodeRole &node_role, const std::vector &rank_ids, - const std::vector &data, std::vector *output, - const uint32_t &timeout) { + const std::vector &data, const std::vector &data_lens, int command, + std::vector *output, const uint32_t &timeout) { MS_EXCEPTION_IF_NULL(output); - uint64_t request_id = ++next_request_id_; - message_tracker_[request_id] = std::make_pair(data.size(), 0); + uint64_t request_id = AddMessageTrack(data.size()); if (rank_ids.size() != data.size()) { MS_LOG(EXCEPTION) << "The number of rank ids, data, comm_message_resp should be equal!"; } - size_t len = rank_ids.size(); + size_t size = rank_ids.size(); set_message_callback(request_id, [&]() { receive_messages_mutex_.lock(); auto res = receive_messages_[request_id]; - for (size_t it = 0; it < len; ++it) { + for (size_t it = 0; it < size; ++it) { (*output).push_back(res[rank_ids.at(it)]); } receive_messages_.erase(request_id); receive_messages_mutex_.unlock(); }); - for (size_t it = 0; it < len; ++it) { + for (size_t it = 0; it < size; ++it) { if (!CommUtil::ValidateRankId(node_role, rank_ids.at(it))) { MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!"; } - MessageMeta message_meta; - message_meta.set_cmd(NodeCommand::SEND_DATA); - message_meta.set_request_id(request_id); - message_meta.set_rank_id(node_info_.rank_id_); - message_meta.set_role(node_info_.node_role_); + auto message_meta = std::make_shared(); + message_meta->set_cmd(NodeCommand::SEND_DATA); + message_meta->set_request_id(request_id); + message_meta->set_rank_id(node_info_.rank_id_); + message_meta->set_role(node_info_.node_role_); + message_meta->set_user_cmd(command); - CommMessage &comm_message = const_cast(data.at(it)); - *comm_message.mutable_pb_meta() = {message_meta}; + auto send = data.at(it); + auto len = 
data_lens.at(it); auto client = GetOrCreateTcpClient(rank_ids.at(it)); - client->SendMessage(comm_message); + client->SendMessage(message_meta, Protos::RAW, send.get(), len); } MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_) << ", the node id is:" << node_info_.node_id_ << " send the request id is:" << request_id; @@ -220,55 +218,61 @@ bool AbstractNode::Wait(uint64_t request_id, const uint32_t &timeout) { return res; } -uint64_t AbstractNode::CollectiveSendAsync(const enum NodeRole &node_role, const uint32_t &rank_id, - const CommMessage &message) { +uint64_t AbstractNode::CollectiveSendAsync(const enum NodeRole &node_role, const uint32_t &rank_id, const void *data, + size_t size) { + MS_EXCEPTION_IF_NULL(data); if (!CommUtil::ValidateRankId(node_role, rank_id)) { MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!"; } - CommMessage &comm_message = const_cast(message); + std::shared_ptr message_meta = std::make_shared(); + message_meta->set_cmd(NodeCommand::COLLECTIVE_SEND_DATA); + message_meta->set_rank_id(node_info_.rank_id_); + message_meta->set_role(node_info_.node_role_); - MessageMeta message_meta; - message_meta.set_cmd(NodeCommand::COLLECTIVE_SEND_DATA); - message_meta.set_rank_id(node_info_.rank_id_); - message_meta.set_role(node_info_.node_role_); - - *comm_message.mutable_pb_meta() = {message_meta}; auto client = GetOrCreateTcpClient(rank_id); - return SendMessageAsync(client, comm_message); + return SendMessageAsync(client, message_meta, Protos::RAW, data, size); } std::pair AbstractNode::CollectiveReceiveAsync(const enum NodeRole &node_role, - const uint32_t &rank_id, CommMessage *output) { + const uint32_t &rank_id, void **output, + size_t *size) { + MS_EXCEPTION_IF_NULL(output); + MS_EXCEPTION_IF_NULL(size); if (!CommUtil::ValidateRankId(node_role, rank_id)) { MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!"; } + receive_callbacks_mutex_.lock(); uint64_t rank_request_id = 
NextExpectedRankRequestId(rank_id); + receive_messages_done_[std::make_pair(rank_id, rank_request_id)] = false; if (received_data_.count(std::make_pair(rank_id, rank_request_id)) > 0) { - *output = received_data_[std::make_pair(rank_id, rank_request_id)]; + auto res = received_data_[std::make_pair(rank_id, rank_request_id)]; + *output = res->data(); + *size = res->size(); received_data_.erase(std::make_pair(rank_id, rank_request_id)); + receive_messages_done_[std::make_pair(rank_id, rank_request_id)] = true; + MS_LOG(DEBUG) << "Receive data from rank id:" << rank_id << ", the rank request id is:" << rank_request_id; } else { - set_receive_callback(rank_id, rank_request_id, [=]() { + receive_callbacks_[std::make_pair(rank_id, rank_request_id)] = [=]() mutable { receive_callbacks_mutex_.lock(); - *output = received_data_[std::make_pair(rank_id, rank_request_id)]; + auto res = received_data_[std::make_pair(rank_id, rank_request_id)]; + *output = res->data(); + *size = res->size(); received_data_.erase(std::make_pair(rank_id, rank_request_id)); + receive_messages_done_[std::make_pair(rank_id, rank_request_id)] = true; + MS_LOG(DEBUG) << "Receive data from rank id:" << rank_id << ", the rank request id is:" << rank_request_id; receive_callbacks_mutex_.unlock(); - }); + }; } + receive_callbacks_mutex_.unlock(); return std::make_pair(rank_id, rank_request_id); } bool AbstractNode::CollectiveWait(std::pair request_id, const uint32_t &timeout) { std::unique_lock lock(receive_callbacks_mutex_); - bool res = receive_cond_.wait_for(lock, std::chrono::seconds(timeout), [&] { - if (actual_rank_request_ids_.count(request_id.first) && - (actual_rank_request_ids_[request_id.first] >= request_id.second)) { - return true; - } else { - return false; - } - }); + bool res = + receive_cond_.wait_for(lock, std::chrono::seconds(timeout), [&] { return receive_messages_done_[request_id]; }); return res; } @@ -297,17 +301,15 @@ void AbstractNode::StartHeartbeatTimer(const std::shared_ptr 
&client) } bool AbstractNode::Heartbeat(const std::shared_ptr &client, bool is_node_finish) { - MessageMeta meta; - meta.set_cmd(NodeCommand::HEARTBEAT); + auto meta = std::make_shared(); + meta->set_cmd(NodeCommand::HEARTBEAT); HeartbeatMessage heartbeat_message; heartbeat_message.set_node_id(node_info_.node_id_); heartbeat_message.set_is_node_finish(is_node_finish); - CommMessage message; - *message.mutable_pb_meta() = {meta}; - message.set_data(heartbeat_message.SerializeAsString()); - if (!SendMessageSync(client, message)) { + if (!SendMessageSync(client, meta, Protos::PROTOBUF, heartbeat_message.SerializeAsString().data(), + heartbeat_message.ByteSizeLong())) { MS_LOG(ERROR) << "The node id:" << node_info_.node_id_ << " Send heartbeat timeout!"; } return true; @@ -331,9 +333,11 @@ bool AbstractNode::CheckSchedulerTimeout() const { return false; } -void AbstractNode::ProcessHeartbeatResp(const CommMessage &message) { +void AbstractNode::ProcessHeartbeatResp(std::shared_ptr meta, const void *data, size_t size) { + MS_EXCEPTION_IF_NULL(meta); + MS_EXCEPTION_IF_NULL(data); HeartbeatRespMessage heartbeat_resp_message; - heartbeat_resp_message.ParseFromString(message.data()); + heartbeat_resp_message.ParseFromArray(data, size); is_ready_ = heartbeat_resp_message.is_cluster_ready(); if (is_ready_.load()) { @@ -359,19 +363,22 @@ void AbstractNode::ProcessHeartbeatResp(const CommMessage &message) { } void AbstractNode::FetchServers(const std::shared_ptr &client) { - MessageMeta meta; - meta.set_cmd(NodeCommand::FETCH_SERVER); + auto meta = std::make_shared(); + meta->set_cmd(NodeCommand::FETCH_SERVER); - CommMessage message; - *message.mutable_pb_meta() = {meta}; - if (!SendMessageSync(client, message)) { + FetchServersMessage fetch_servers; + fetch_servers.set_node_id(node_info_.node_id_); + if (!SendMessageSync(client, meta, Protos::PROTOBUF, fetch_servers.SerializeAsString().data(), + fetch_servers.ByteSizeLong())) { MS_LOG(EXCEPTION) << "Fetch servers address 
timeout!"; } } -void AbstractNode::ProcessFetchServersResp(const CommMessage &message) { +void AbstractNode::ProcessFetchServersResp(std::shared_ptr meta, const void *data, size_t size) { + MS_EXCEPTION_IF_NULL(meta); + MS_EXCEPTION_IF_NULL(data); FetchServersRespMessage fetch_servers_resp_message; - fetch_servers_resp_message.ParseFromString(message.data()); + fetch_servers_resp_message.ParseFromArray(data, size); for (const auto &it : fetch_servers_resp_message.servers_meta()) { nodes_address_[std::make_pair(NodeRole::SERVER, it.rank_id())] = std::make_pair(it.ip(), it.port()); @@ -381,16 +388,14 @@ void AbstractNode::ProcessFetchServersResp(const CommMessage &message) { } bool AbstractNode::Disconnect(const std::shared_ptr &client, const uint32_t &timeout) { - MessageMeta meta; - meta.set_cmd(NodeCommand::FINISH); + auto meta = std::make_shared(); + meta->set_cmd(NodeCommand::FINISH); FinishMessage finish_message; finish_message.set_node_id(node_info_.node_id_); - CommMessage message; - *message.mutable_pb_meta() = {meta}; - message.set_data(finish_message.SerializeAsString()); - if (!SendMessageSync(client, message)) { + if (!SendMessageSync(client, meta, Protos::PROTOBUF, finish_message.SerializeAsString().data(), + finish_message.ByteSizeLong())) { MS_LOG(ERROR) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_) << " the node id:" << node_info_.node_id_ << " send Finish Message timeout!"; } @@ -412,16 +417,17 @@ bool AbstractNode::InitClientToScheduler() { std::string scheduler_host = ClusterConfig::scheduler_host(); uint16_t scheduler_port = ClusterConfig::scheduler_port(); client_to_scheduler_ = std::make_shared(scheduler_host, scheduler_port); - client_to_scheduler_->SetMessageCallback([&](const TcpClient &client, const CommMessage &message) { - if (handlers_.count(message.pb_meta().cmd()) == 0) { - MS_LOG(EXCEPTION) << "The cmd:" << message.pb_meta().cmd() << " is not supported!"; - } - if (handlers_[message.pb_meta().cmd()] != 
nullptr) { - const auto &handler_ptr = handlers_[message.pb_meta().cmd()]; - (this->*handler_ptr)(message); - } - NotifyMessageArrival(message); - }); + client_to_scheduler_->SetMessageCallback( + [&](std::shared_ptr meta, const Protos &protos, const void *data, size_t size) { + if (handlers_.count(meta->cmd()) == 0) { + MS_LOG(EXCEPTION) << "The cmd:" << meta->cmd() << " is not supported!"; + } + if (handlers_[meta->cmd()] != nullptr) { + const auto &handler_ptr = handlers_[meta->cmd()]; + (this->*handler_ptr)(meta, data, size); + } + NotifyMessageArrival(meta); + }); client_to_scheduler_->Init(); client_to_scheduler_thread_ = std::make_unique([&]() { @@ -447,19 +453,20 @@ const std::shared_ptr &AbstractNode::GetOrCreateTcpClient(const int & std::string ip = nodes_address_[std::make_pair(NodeRole::SERVER, rank_id)].first; uint16_t port = nodes_address_[std::make_pair(NodeRole::SERVER, rank_id)].second; auto client = std::make_shared(ip, port); - client->SetMessageCallback([&](const TcpClient &client, const CommMessage &message) { - switch (message.pb_meta().cmd()) { + client->SetMessageCallback([&](std::shared_ptr meta, const Protos &protos, const void *data, + size_t size) { + switch (meta->cmd()) { case NodeCommand::SEND_DATA: - ProcessSendDataResp(message); - RunMessageCallback(message.pb_meta().request_id()); + ProcessSendDataResp(meta, protos, data, size); + RunMessageCallback(meta->request_id()); break; case NodeCommand::COLLECTIVE_SEND_DATA: - MS_LOG(INFO) << "The Node id:" << node_info_.node_id_ << " receive a collective_send_data message response!"; + MS_LOG(DEBUG) << "The Node id:" << node_info_.node_id_ << " receive a collective_send_data message response!"; break; default: - MS_LOG(EXCEPTION) << "The cmd:" << message.pb_meta().cmd() << " is not supported!"; + MS_LOG(EXCEPTION) << "The cmd:" << meta->cmd() << " is not supported!"; } - NotifyMessageArrival(message); + NotifyMessageArrival(meta); }); client->Init(); connected_nodes_[rank_id] = client; @@ 
-469,8 +476,7 @@ const std::shared_ptr &AbstractNode::GetOrCreateTcpClient(const int & bool AbstractNode::SendMessageSync(const std::shared_ptr &client, const CommMessage &message, const uint32_t &timeout) { - uint64_t request_id = ++next_request_id_; - message_tracker_[request_id] = std::make_pair(1, 0); + uint64_t request_id = AddMessageTrack(1); const_cast(message).mutable_pb_meta()->set_request_id(request_id); client->SendMessage(message); MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_) @@ -478,29 +484,55 @@ bool AbstractNode::SendMessageSync(const std::shared_ptr &client, con return Wait(request_id, timeout); } -uint64_t AbstractNode::SendMessageAsync(const std::shared_ptr &client, const CommMessage &message) { - uint64_t request_id = ++next_request_id_; - message_tracker_[request_id] = std::make_pair(1, 0); - const_cast(message).mutable_pb_meta()->set_request_id(request_id); - client->SendMessage(message); +uint64_t AbstractNode::SendMessageAsync(const std::shared_ptr &client, std::shared_ptr meta, + const Protos &protos, const void *data, size_t size) { + MS_EXCEPTION_IF_NULL(client); + MS_EXCEPTION_IF_NULL(meta); + MS_EXCEPTION_IF_NULL(data); + uint64_t request_id = AddMessageTrack(1); + meta->set_request_id(request_id); + client->SendMessage(meta, protos, data, size); MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_) << ", the node id is:" << node_info_.node_id_ << " send the request id is:" << request_id; return request_id; } -void AbstractNode::ProcessSendDataResp(const CommMessage &message) { +bool AbstractNode::SendMessageSync(const std::shared_ptr &client, std::shared_ptr meta, + const Protos &protos, const void *data, size_t size, const uint32_t &timeout) { + MS_EXCEPTION_IF_NULL(client); + MS_EXCEPTION_IF_NULL(meta); + MS_EXCEPTION_IF_NULL(data); + uint64_t request_id = AddMessageTrack(1); + meta->set_request_id(request_id); + client->SendMessage(meta, protos, data, 
size); + MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_) + << ", the node id is:" << node_info_.node_id_ << " send the request id is:" << request_id; + bool res = Wait(request_id, timeout); + return res; +} + +void AbstractNode::ProcessSendDataResp(std::shared_ptr meta, const Protos &protos, const void *data, + size_t size) { + MS_EXCEPTION_IF_NULL(meta); + MS_EXCEPTION_IF_NULL(data); std::lock_guard lock(receive_messages_mutex_); - const MessageMeta &message_meta = message.pb_meta(); - const uint32_t &rank_id = message_meta.rank_id(); - const uint64_t request_id = message_meta.request_id(); + const uint32_t &rank_id = meta->rank_id(); + const uint64_t request_id = meta->request_id(); MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_) << ", the node id is:" << node_info_.node_id_ << " send the request id is:" << request_id; auto it = receive_messages_.find(request_id); + VectorPtr received_data = std::make_shared>(size, 0); + if (size > 0) { + int ret = memcpy_s(received_data.get()->data(), size, data, size); + if (ret != 0) { + MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; + } + } if (it != receive_messages_.end()) { - it->second[rank_id] = message; + it->second[rank_id] = received_data; } else { - std::unordered_map res; - res.insert(std::make_pair(rank_id, message)); + std::unordered_map res; + res.insert(std::make_pair(rank_id, received_data)); receive_messages_[request_id] = res; } } @@ -509,7 +541,7 @@ void AbstractNode::RunMessageCallback(const uint64_t &request_id) { message_callbacks_mutex_.lock(); // When receiving a message's response, Then compare with the desired number of responses, // If they are equal, then call the callback function - if (message_tracker_[request_id].first == message_tracker_[request_id].second + 1) { + if (CheckMessageTrack(request_id)) { auto it = message_callbacks_.find(request_id); if (it != message_callbacks_.end()) { 
message_callbacks_mutex_.unlock(); @@ -533,31 +565,31 @@ void AbstractNode::set_message_callback(const uint64_t &request_id, const Messag message_callbacks_[request_id] = callback; } -void AbstractNode::NotifyMessageArrival(const CommMessage &message) { +void AbstractNode::NotifyMessageArrival(std::shared_ptr meta) { std::lock_guard lock(message_tracker_mutex_); - const MessageMeta &message_meta = message.pb_meta(); - uint64_t request_id = message_meta.request_id(); + uint64_t request_id = meta->request_id(); message_tracker_[request_id].second++; message_tracker_cond_.notify_all(); } -void AbstractNode::set_receive_callback(const uint32_t &rank_id, const uint64_t &request_id, - const MessageCallback &callback) { - if (!callback) { - return; - } - std::lock_guard lock(receive_callbacks_mutex_); - receive_callbacks_[std::make_pair(rank_id, request_id)] = callback; -} - -void AbstractNode::RunReceiveCallback(const CommMessage &message) { +void AbstractNode::RunReceiveCallback(std::shared_ptr meta, const Protos &protos, const void *data, + size_t size) { + MS_EXCEPTION_IF_NULL(meta); + MS_EXCEPTION_IF_NULL(data); receive_callbacks_mutex_.lock(); - uint32_t rank_id = message.pb_meta().rank_id(); + uint32_t rank_id = meta->rank_id(); // When receiving a collective message, Then generate rank request id,compare with the desired rank request id, // If they are equal, then call the callback function uint64_t rank_request_id = NextActualRankRequestId(rank_id); - received_data_[std::make_pair(rank_id, rank_request_id)] = message; + std::shared_ptr> received_data = std::make_shared>(size, 0); + int ret = memcpy_s(received_data->data(), size, data, size); + if (ret != 0) { + MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; + } + received_data_[std::make_pair(rank_id, rank_request_id)] = received_data; + MS_LOG(DEBUG) << "Run Receive data callback,the rank id:" << rank_id << ", the rank request id is:" << rank_request_id + << ", the send request id is:" << 
meta->request_id(); auto it = receive_callbacks_.find(std::make_pair(rank_id, rank_request_id)); if (it != receive_callbacks_.end()) { receive_callbacks_mutex_.unlock(); @@ -603,6 +635,18 @@ void AbstractNode::InitCommandHandler() { handlers_[NodeCommand::FETCH_SERVER] = &AbstractNode::ProcessFetchServersResp; handlers_[NodeCommand::FINISH] = nullptr; } + +uint64_t AbstractNode::AddMessageTrack(const uint32_t &expected_response) { + std::lock_guard lock(message_tracker_mutex_); + uint64_t request_id = ++next_request_id_; + message_tracker_[request_id] = std::make_pair(expected_response, 0); + return request_id; +} + +bool AbstractNode::CheckMessageTrack(const uint64_t &request_id) { + std::lock_guard lock(message_tracker_mutex_); + return message_tracker_[request_id].first == message_tracker_[request_id].second + 1; +} } // namespace core } // namespace ps } // namespace mindspore diff --git a/mindspore/ccsrc/ps/core/abstract_node.h b/mindspore/ccsrc/ps/core/abstract_node.h index 448e36489ac..3fd0271de6b 100644 --- a/mindspore/ccsrc/ps/core/abstract_node.h +++ b/mindspore/ccsrc/ps/core/abstract_node.h @@ -25,6 +25,7 @@ #include #include "ps/core/node.h" +#include "ps/core/message.h" namespace mindspore { namespace ps { @@ -34,53 +35,63 @@ class AbstractNode : public Node { AbstractNode() : heart_beat_thread_(nullptr), client_to_scheduler_thread_(nullptr), client_to_scheduler_(nullptr) {} ~AbstractNode() override = default; - typedef void (AbstractNode::*ResponseHandler)(const CommMessage &message); + typedef void (AbstractNode::*ResponseHandler)(std::shared_ptr meta, const void *data, size_t size); - bool Broadcast(const enum NodeRole &node_role, const CommMessage &message, + using DataPtr = std::shared_ptr; + using VectorPtr = std::shared_ptr>; + + bool Broadcast(const enum NodeRole &node_role, const DataPtr &message, size_t size, int command, const uint32_t &timeout = kCommTimeoutInSeconds); void set_event_callback(const OnNodeEventMessage 
&on_node_event_message); - bool Send(const enum NodeRole &node_role, const uint32_t &rank_id, const CommMessage &message, + bool Send(const enum NodeRole &node_role, const uint32_t &rank_id, const DataPtr &data, size_t len, int command, const uint32_t &timeout = kCommTimeoutInSeconds); - bool Send(const NodeRole &node_role, const std::vector &rank_ids, const std::vector &data, + bool Send(const NodeRole &node_role, const std::vector &rank_ids, const std::vector &data, + const std::vector &lens, int command, const uint32_t &timeout = kCommTimeoutInSeconds); + bool Send(const enum NodeRole &node_role, const uint32_t &rank_id, const DataPtr &message, size_t len, int command, + VectorPtr *output, const uint32_t &timeout = kCommTimeoutInSeconds); + bool Send(const NodeRole &node_role, const std::vector &rank_ids, const std::vector &data, + const std::vector &data_lens, int command, std::vector *output, const uint32_t &timeout = kCommTimeoutInSeconds); - bool Send(const enum NodeRole &node_role, const uint32_t &rank_id, const CommMessage &message, CommMessage *output, - const uint32_t &timeout = kCommTimeoutInSeconds); - bool Send(const NodeRole &node_role, const std::vector &rank_ids, const std::vector &data, - std::vector *output, const uint32_t &timeout = kCommTimeoutInSeconds); bool Wait(uint64_t request_id, const uint32_t &timeout = kCommTimeoutInSeconds); - uint64_t CollectiveSendAsync(const enum NodeRole &node_role, const uint32_t &rank_id, const CommMessage &message); + uint64_t CollectiveSendAsync(const enum NodeRole &node_role, const uint32_t &rank_id, const void *data, size_t size); std::pair CollectiveReceiveAsync(const enum NodeRole &node_role, const uint32_t &rank_id, - CommMessage *output); + void **output, size_t *size); bool CollectiveWait(std::pair request_id, const uint32_t &timeout = kCommTimeoutInSeconds); protected: void Register(const std::shared_ptr &client); - void ProcessRegisterResp(const CommMessage &message); - void StartHeartbeatTimer(const 
std::shared_ptr &client); bool Heartbeat(const std::shared_ptr &client, bool is_node_finish = false); + void FetchServers(const std::shared_ptr &client); + + void ProcessRegisterResp(std::shared_ptr meta, const void *data, size_t size); + void ProcessHeartbeatResp(std::shared_ptr meta, const void *data, size_t size); + void ProcessFetchServersResp(std::shared_ptr meta, const void *data, size_t size); + + void StartHeartbeatTimer(const std::shared_ptr &client); void UpdateSchedulerTime(); bool CheckSchedulerTimeout() const; - void ProcessHeartbeatResp(const CommMessage &message); - void FetchServers(const std::shared_ptr &client); - void ProcessFetchServersResp(const CommMessage &message); bool Disconnect(const std::shared_ptr &client, const uint32_t &timeout); bool WaitForDisconnect(const uint32_t &timeout); bool InitClientToScheduler(); const std::shared_ptr &GetOrCreateTcpClient(const int &rank_id); bool SendMessageSync(const std::shared_ptr &client, const CommMessage &message, const uint32_t &timeout = kCommTimeoutInSeconds); - uint64_t SendMessageAsync(const std::shared_ptr &client, const CommMessage &message); - void ProcessSendDataResp(const CommMessage &message); + bool SendMessageSync(const std::shared_ptr &client, std::shared_ptr, const Protos &, + const void *, size_t size, const uint32_t &timeout = kCommTimeoutInSeconds); + uint64_t SendMessageAsync(const std::shared_ptr &client, std::shared_ptr meta, + const Protos &protos, const void *data, size_t size); + void ProcessSendDataResp(std::shared_ptr meta, const Protos &protos, const void *data, size_t size); void RunMessageCallback(const uint64_t &request_id); void set_message_callback(const uint64_t &request_id, const MessageCallback &callback); - void NotifyMessageArrival(const CommMessage &message); - void set_receive_callback(const uint32_t &rank_id, const uint64_t &request_id, const MessageCallback &callback); - void RunReceiveCallback(const CommMessage &message); + void 
NotifyMessageArrival(std::shared_ptr meta); + void RunReceiveCallback(std::shared_ptr meta, const Protos &protos, const void *data, size_t size); uint64_t NextExpectedRankRequestId(const uint32_t &rank_id); uint64_t NextActualRankRequestId(const uint32_t &rank_id); void InitCommandHandler(); + uint64_t AddMessageTrack(const uint32_t &expected_response); + bool CheckMessageTrack(const uint64_t &request_id); std::unique_ptr heart_beat_thread_; std::unique_ptr client_to_scheduler_thread_; @@ -98,15 +109,16 @@ class AbstractNode : public Node { std::mutex message_tracker_mutex_; std::condition_variable message_tracker_cond_; - // the key is: request_id, the value is: - std::unordered_map> receive_messages_; + // the key is: request_id, the value is: + std::unordered_map> receive_messages_; + std::map, bool> receive_messages_done_; std::mutex receive_messages_mutex_; // the key is: request_id std::unordered_map message_callbacks_; std::mutex message_callbacks_mutex_; // the key is - std::map, CommMessage> received_data_; + std::map, std::shared_ptr>> received_data_; std::mutex receive_callbacks_mutex_; // the key is std::map, MessageCallback> receive_callbacks_; diff --git a/mindspore/ccsrc/ps/core/message.h b/mindspore/ccsrc/ps/core/message.h new file mode 100644 index 00000000000..a3a28ba3cff --- /dev/null +++ b/mindspore/ccsrc/ps/core/message.h @@ -0,0 +1,59 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PS_CORE_MESSAGE_H_ +#define MINDSPORE_CCSRC_PS_CORE_MESSAGE_H_ + +#include +#include + +namespace mindspore { +namespace ps { +namespace core { +enum class Protos : uint32_t { RAW = 0, PROTOBUF = 1, FLATBUFFERS = 2 }; + +enum class Command { + TERMINATE = 0, + REGISTER = 1, + HEARTBEAT = 2, + SEND_DATA = 3, + FETCH_SERVER = 4, + FINISH = 5, + COLLECTIVE_SEND_DATA = 6 +}; + +enum class Role { SERVER = 0, WORKER = 1, SCHEDULER = 2 }; + +struct MessageHeader { + Protos message_proto_ = Protos::RAW; + uint32_t message_meta_length_ = 0; + uint64_t message_length_ = 0; +}; + +struct CommandMeta { + // the command of this message,for example: register,heartbeat,data + Command cmd; + // the request id of this message + uint64_t request_id; + // the role of the current node: worker,server,scheduler + Role role; + // the current Node rank id,the worker node range is:[0,numOfWorker-1], the server node range is:[0, numOfServer-1] + int32_t rank_id = 4; +}; +} // namespace core +} // namespace ps +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PS_CORE_MESSAGE_H_ diff --git a/mindspore/ccsrc/ps/core/protos/comm.proto b/mindspore/ccsrc/ps/core/protos/comm.proto index 81d10137120..4a9932dcbca 100644 --- a/mindspore/ccsrc/ps/core/protos/comm.proto +++ b/mindspore/ccsrc/ps/core/protos/comm.proto @@ -15,7 +15,6 @@ */ syntax = "proto3"; -import "google/protobuf/any.proto"; package mindspore.ps.core; option optimize_for = LITE_RUNTIME; @@ -44,6 +43,8 @@ message MessageMeta { NodeRole role = 3; // the current Node rank id,the worker node range is:[0,numOfWorker-1], the server node range is:[0, numOfServer-1] int32 rank_id = 4; + // User-defined commands + int32 user_cmd = 5; } message RegisterMessage { @@ -76,6 +77,10 @@ message HeartbeatRespMessage { bool is_node_timeout = 4; } +message FetchServersMessage { + string node_id = 1; +} + message 
FetchServersRespMessage { repeated ServersMeta servers_meta = 1; } @@ -95,6 +100,4 @@ message FinishMessage { message CommMessage { MessageMeta pb_meta = 1; bytes data = 2; - // User-defined commands - bytes user_cmd = 3; } diff --git a/mindspore/ccsrc/ps/core/scheduler_node.cc b/mindspore/ccsrc/ps/core/scheduler_node.cc index a3a38519fbd..3561487b63d 100644 --- a/mindspore/ccsrc/ps/core/scheduler_node.cc +++ b/mindspore/ccsrc/ps/core/scheduler_node.cc @@ -38,9 +38,13 @@ bool SchedulerNode::Start(const uint32_t &timeout) { } void SchedulerNode::ProcessHeartbeat(std::shared_ptr server, std::shared_ptr conn, - std::shared_ptr message) { + std::shared_ptr meta, const void *data, size_t size) { + MS_EXCEPTION_IF_NULL(server); + MS_EXCEPTION_IF_NULL(conn); + MS_EXCEPTION_IF_NULL(meta); + MS_EXCEPTION_IF_NULL(data); HeartbeatMessage heartbeat_message; - heartbeat_message.ParseFromString(message->data()); + heartbeat_message.ParseFromArray(data, size); node_manager_.UpdateHeartbeat(heartbeat_message.node_id()); @@ -60,10 +64,8 @@ void SchedulerNode::ProcessHeartbeat(std::shared_ptr server, std::sha heartbeat_resp_message.set_is_cluster_timeout(node_manager_.is_cluster_timeout()); heartbeat_resp_message.set_is_node_timeout(node_manager_.is_node_timeout()); - std::shared_ptr comm_message = std::make_shared(); - *comm_message->mutable_pb_meta() = {message->pb_meta()}; - comm_message->set_data(heartbeat_resp_message.SerializeAsString()); - server->SendMessage(conn, comm_message); + server->SendMessage(conn, meta, Protos::PROTOBUF, heartbeat_resp_message.SerializeAsString().data(), + heartbeat_resp_message.ByteSizeLong()); } void SchedulerNode::Initialize() { @@ -89,12 +91,13 @@ void SchedulerNode::CreateTcpServer() { std::string scheduler_host = ClusterConfig::scheduler_host(); uint32_t scheduler_port = ClusterConfig::scheduler_port(); server_ = std::make_shared(scheduler_host, scheduler_port); - server_->SetMessageCallback([&](std::shared_ptr conn, std::shared_ptr message) { 
- if (handlers_.count(message->pb_meta().cmd()) == 0) { - MS_LOG(EXCEPTION) << "The cmd:" << message->pb_meta().cmd() << " is not supported!"; + server_->SetMessageCallback([&](std::shared_ptr conn, std::shared_ptr meta, + const Protos &protos, const void *data, size_t size) { + if (handlers_.count(meta->cmd()) == 0) { + MS_LOG(EXCEPTION) << "The cmd:" << meta->cmd() << " is not supported!"; } - const auto &handler_ptr = handlers_[message->pb_meta().cmd()]; - (this->*handler_ptr)(server_, conn, message); + const auto &handler_ptr = handlers_[meta->cmd()]; + (this->*handler_ptr)(server_, conn, meta, data, size); }); server_->Init(); @@ -106,10 +109,14 @@ void SchedulerNode::CreateTcpServer() { } void SchedulerNode::ProcessRegister(std::shared_ptr server, std::shared_ptr conn, - std::shared_ptr message) { + std::shared_ptr meta, const void *data, size_t size) { + MS_EXCEPTION_IF_NULL(server); + MS_EXCEPTION_IF_NULL(conn); + MS_EXCEPTION_IF_NULL(meta); + MS_EXCEPTION_IF_NULL(data); MS_LOG(INFO) << "The scheduler process a register message!"; RegisterMessage register_message; - register_message.ParseFromString(message->data()); + register_message.ParseFromArray(data, size); // assign worker node and server node rank id int rank_id = node_manager_.NextRankId(register_message); @@ -123,32 +130,32 @@ void SchedulerNode::ProcessRegister(std::shared_ptr server, std::shar register_resp_message.set_node_id(node_id); register_resp_message.set_rank_id(rank_id); - std::shared_ptr comm_message = std::make_shared(); - *comm_message->mutable_pb_meta() = {message->pb_meta()}; - comm_message->set_data(register_resp_message.SerializeAsString()); - server->SendMessage(conn, comm_message); + server->SendMessage(conn, meta, Protos::PROTOBUF, register_resp_message.SerializeAsString().data(), + register_resp_message.ByteSizeLong()); } void SchedulerNode::ProcessFinish(std::shared_ptr server, std::shared_ptr conn, - std::shared_ptr message) { + std::shared_ptr meta, const void *data, size_t 
size) { + MS_EXCEPTION_IF_NULL(server); + MS_EXCEPTION_IF_NULL(conn); + MS_EXCEPTION_IF_NULL(meta); + MS_EXCEPTION_IF_NULL(data); FinishMessage finish_message; - finish_message.ParseFromString(message->data()); + finish_message.ParseFromArray(data, size); node_manager_.AddFinishNode(finish_message); MS_LOG(INFO) << "Process finish message from node id:" << finish_message.node_id(); - server->SendMessage(conn, message); + server->SendMessage(conn, meta, Protos::PROTOBUF, data, size); } void SchedulerNode::ProcessFetchServers(std::shared_ptr server, std::shared_ptr conn, - std::shared_ptr message) { + std::shared_ptr meta, const void *data, size_t size) { FetchServersRespMessage fetch_servers_message; std::vector servers_meta_list = node_manager_.FetchServersMeta(); *fetch_servers_message.mutable_servers_meta() = {servers_meta_list.begin(), servers_meta_list.end()}; - std::shared_ptr comm_message = std::make_shared(); - *comm_message->mutable_pb_meta() = {message->pb_meta()}; - comm_message->set_data(fetch_servers_message.SerializeAsString()); - server->SendMessage(conn, comm_message); + server->SendMessage(conn, meta, Protos::PROTOBUF, fetch_servers_message.SerializeAsString().data(), + fetch_servers_message.ByteSizeLong()); } void SchedulerNode::StartUpdateClusterStateTimer() { diff --git a/mindspore/ccsrc/ps/core/scheduler_node.h b/mindspore/ccsrc/ps/core/scheduler_node.h index 1c89d2398dd..ebbdaf52e39 100644 --- a/mindspore/ccsrc/ps/core/scheduler_node.h +++ b/mindspore/ccsrc/ps/core/scheduler_node.h @@ -36,13 +36,14 @@ namespace mindspore { namespace ps { namespace core { + class SchedulerNode : public Node { public: SchedulerNode() : server_(nullptr), scheduler_thread_(nullptr), update_state_thread_(nullptr) {} ~SchedulerNode() override; typedef void (SchedulerNode::*ResponseHandler)(std::shared_ptr server, std::shared_ptr conn, - std::shared_ptr message); + std::shared_ptr meta, const void *data, size_t size); bool Start(const uint32_t &timeout = 
ClusterConfig::cluster_available_timeout()) override; bool Stop() override; @@ -53,14 +54,14 @@ class SchedulerNode : public Node { void InitCommandHandler(); void CreateTcpServer(); void ProcessHeartbeat(std::shared_ptr server, std::shared_ptr conn, - std::shared_ptr message); + std::shared_ptr meta, const void *data, size_t size); void ProcessRegister(std::shared_ptr server, std::shared_ptr conn, - std::shared_ptr message); + std::shared_ptr meta, const void *data, size_t size); void StartUpdateClusterStateTimer(); void ProcessFinish(std::shared_ptr server, std::shared_ptr conn, - std::shared_ptr message); + std::shared_ptr meta, const void *data, size_t size); void ProcessFetchServers(std::shared_ptr server, std::shared_ptr conn, - std::shared_ptr message); + std::shared_ptr meta, const void *data, size_t size); std::shared_ptr server_; std::unique_ptr scheduler_thread_; diff --git a/mindspore/ccsrc/ps/core/server_node.cc b/mindspore/ccsrc/ps/core/server_node.cc index 28d09570678..2ba7ca794b2 100644 --- a/mindspore/ccsrc/ps/core/server_node.cc +++ b/mindspore/ccsrc/ps/core/server_node.cc @@ -46,16 +46,16 @@ bool ServerNode::Start(const uint32_t &timeout) { void ServerNode::set_handler(const RequestHandler &handler) { request_handler_ = handler; } -void ServerNode::Response(std::shared_ptr conn, std::shared_ptr message) { +void ServerNode::Response(std::shared_ptr conn, std::shared_ptr meta, DataPtr data, + size_t size) { MS_EXCEPTION_IF_NULL(conn); - MS_EXCEPTION_IF_NULL(message); - message->mutable_pb_meta()->set_role(node_info_.node_role_); - message->mutable_pb_meta()->set_rank_id(node_info_.rank_id_); - const MessageMeta &message_meta = message->pb_meta(); - const uint64_t request_id = message_meta.request_id(); + MS_EXCEPTION_IF_NULL(meta); + MS_EXCEPTION_IF_NULL(data); + meta->set_role(node_info_.node_role_); + meta->set_rank_id(node_info_.rank_id_); MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_) - << ", the node 
id is:" << node_info_.node_id_ << " send the request id is:" << request_id; - server_->SendMessage(conn, message); + << ", the node id is:" << node_info_.node_id_ << " send the request id is:" << meta->request_id(); + server_->SendMessage(conn, meta, Protos::RAW, data.get(), size); } void ServerNode::CreateTcpServer() { @@ -63,17 +63,18 @@ void ServerNode::CreateTcpServer() { std::string server_ip; CommUtil::GetAvailableInterfaceAndIP(&interface, &server_ip); server_ = std::make_shared(server_ip, 0); - server_->SetMessageCallback([&](std::shared_ptr conn, std::shared_ptr message) { - switch (message->pb_meta().cmd()) { + server_->SetMessageCallback([&](std::shared_ptr conn, std::shared_ptr meta, + const Protos &protos, const void *data, size_t size) { + switch (meta->cmd()) { case NodeCommand::SEND_DATA: - ProcessSendData(conn, message); + ProcessSendData(conn, meta, protos, data, size); break; case NodeCommand::COLLECTIVE_SEND_DATA: - ProcessCollectiveSendData(conn, message); - RunReceiveCallback(*message); + ProcessCollectiveSendData(conn, meta, data, size); + RunReceiveCallback(meta, protos, data, size); break; default: - MS_LOG(EXCEPTION) << "The cmd:" << message->pb_meta().cmd() << " is not supported!"; + MS_LOG(EXCEPTION) << "The cmd:" << meta->cmd() << " is not supported!"; } }); server_->Init(); @@ -99,18 +100,24 @@ void ServerNode::Initialize() { MS_LOG(INFO) << "Server node init client successful!"; } -void ServerNode::ProcessSendData(std::shared_ptr conn, std::shared_ptr message) { +void ServerNode::ProcessSendData(std::shared_ptr conn, std::shared_ptr meta, + const Protos &protos, const void *data, size_t size) { MS_EXCEPTION_IF_NULL(conn); - MS_EXCEPTION_IF_NULL(message); - request_handler_(conn, message); + MS_EXCEPTION_IF_NULL(meta); + MS_EXCEPTION_IF_NULL(data); + std::shared_ptr res(new unsigned char[size]); + int ret = memcpy_s(res.get(), size, data, size); + if (ret != 0) { + MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; + } 
+ request_handler_(conn, meta, res, size); } -void ServerNode::ProcessCollectiveSendData(std::shared_ptr conn, std::shared_ptr message) { +void ServerNode::ProcessCollectiveSendData(std::shared_ptr conn, std::shared_ptr meta, + const void *data, size_t size) { MS_EXCEPTION_IF_NULL(conn); - MS_EXCEPTION_IF_NULL(message); - std::shared_ptr comm_message = std::make_shared(); - *comm_message->mutable_pb_meta() = {message->pb_meta()}; - server_->SendMessage(conn, comm_message); + MS_EXCEPTION_IF_NULL(meta); + server_->SendMessage(conn, meta, Protos::RAW, data, size); } bool ServerNode::Stop() { diff --git a/mindspore/ccsrc/ps/core/server_node.h b/mindspore/ccsrc/ps/core/server_node.h index 086358f56e5..df109e2a5c5 100644 --- a/mindspore/ccsrc/ps/core/server_node.h +++ b/mindspore/ccsrc/ps/core/server_node.h @@ -23,6 +23,7 @@ #include #include #include +#include #include "ps/core/cluster_config.h" #include "ps/core/tcp_client.h" @@ -41,16 +42,19 @@ class ServerNode : public AbstractNode { bool Stop() override; bool Finish(const uint32_t &timeout = kTimeoutInSeconds) override; - using RequestHandler = std::function conn, std::shared_ptr message)>; + using RequestHandler = std::function conn, std::shared_ptr meta, + DataPtr data, size_t size)>; void set_handler(const RequestHandler &handler); - void Response(std::shared_ptr conn, std::shared_ptr message); + void Response(std::shared_ptr conn, std::shared_ptr meta, DataPtr data, size_t size); private: void CreateTcpServer(); void Initialize(); - void ProcessSendData(std::shared_ptr conn, std::shared_ptr message); - void ProcessCollectiveSendData(std::shared_ptr conn, std::shared_ptr message); + void ProcessSendData(std::shared_ptr conn, std::shared_ptr meta, const Protos &protos, + const void *data, size_t size); + void ProcessCollectiveSendData(std::shared_ptr conn, std::shared_ptr meta, + const void *data, size_t size); std::shared_ptr server_; std::unique_ptr server_thread_; diff --git 
a/mindspore/ccsrc/ps/core/tcp_client.cc b/mindspore/ccsrc/ps/core/tcp_client.cc index 14d9a965f6b..ce0412bdc38 100644 --- a/mindspore/ccsrc/ps/core/tcp_client.cc +++ b/mindspore/ccsrc/ps/core/tcp_client.cc @@ -46,11 +46,12 @@ TcpClient::TcpClient(const std::string &address, std::uint16_t port) server_port_(port), is_stop_(true), is_connected_(false) { - message_handler_.SetCallback([this](std::shared_ptr message) { - if (message_callback_) { - message_callback_(*this, *message); - } - }); + message_handler_.SetCallback( + [this](std::shared_ptr meta, const Protos &protos, const void *data, size_t size) { + if (message_callback_) { + message_callback_(meta, protos, data, size); + } + }); } TcpClient::~TcpClient() { @@ -189,7 +190,7 @@ void TcpClient::ReadCallback(struct bufferevent *bev, void *ctx) { void TcpClient::OnReadHandler(const void *buf, size_t num) { MS_EXCEPTION_IF_NULL(buf); if (read_callback_) { - read_callback_(*this, buf, num); + read_callback_(buf, num); } message_handler_.ReceiveMessage(buf, num); } @@ -198,7 +199,7 @@ void TcpClient::TimerCallback(evutil_socket_t, int16_t, void *arg) { MS_EXCEPTION_IF_NULL(arg); auto tcp_client = reinterpret_cast(arg); if (tcp_client->on_timer_callback_) { - tcp_client->on_timer_callback_(*tcp_client); + tcp_client->on_timer_callback_(); } } @@ -245,7 +246,7 @@ void TcpClient::Start() { MSLOG_IF(mindspore::ERROR, ret == 1, NoExceptionType) << "Event base dispatch failed with no events pending or active!"; MSLOG_IF(mindspore::ERROR, ret == -1, NoExceptionType) << "Event base dispatch failed with error occurred!"; - MSLOG_IF(mindspore::EXCEPTION, ret < -1, AbortedError) << "Event base dispatch with unexpect error code!"; + MSLOG_IF(mindspore::EXCEPTION, ret < -1, AbortedError) << "Event base dispatch with unexpected error code!"; } void TcpClient::StartWithNoBlock() { @@ -256,7 +257,7 @@ void TcpClient::StartWithNoBlock() { MSLOG_IF(INFO, ret == 0, NoExceptionType) << "Event base loop success!"; 
MSLOG_IF(mindspore::ERROR, ret == 1, NoExceptionType) << "Event base loop failed with no events pending or active!"; MSLOG_IF(mindspore::ERROR, ret == -1, NoExceptionType) << "Event base loop failed with error occurred!"; - MSLOG_IF(mindspore::EXCEPTION, ret < -1, AbortedError) << "Event base loop with unexpect error code!"; + MSLOG_IF(mindspore::EXCEPTION, ret < -1, AbortedError) << "Event base loop with unexpected error code!"; } void TcpClient::SetMessageCallback(const OnMessage &cb) { message_callback_ = cb; } @@ -265,14 +266,49 @@ bool TcpClient::SendMessage(const CommMessage &message) const { MS_EXCEPTION_IF_NULL(buffer_event_); bufferevent_lock(buffer_event_); bool res = true; - size_t buf_size = message.ByteSizeLong(); - std::vector serialized(buf_size); - message.SerializeToArray(serialized.data(), SizeToInt(buf_size)); - if (bufferevent_write(buffer_event_, &buf_size, sizeof(buf_size)) == -1) { + size_t buf_size = IntToUint(message.ByteSizeLong()); + uint32_t meta_size = SizeToUint(message.pb_meta().ByteSizeLong()); + MessageHeader header; + header.message_proto_ = Protos::PROTOBUF; + header.message_length_ = buf_size; + header.message_meta_length_ = meta_size; + if (bufferevent_write(buffer_event_, &header, sizeof(header)) == -1) { MS_LOG(ERROR) << "Event buffer add header failed!"; res = false; } - if (bufferevent_write(buffer_event_, serialized.data(), buf_size) == -1) { + if (bufferevent_write(buffer_event_, message.pb_meta().SerializeAsString().data(), meta_size) == -1) { + MS_LOG(ERROR) << "Event buffer add protobuf data failed!"; + res = false; + } + if (bufferevent_write(buffer_event_, message.data().data(), message.data().length()) == -1) { + MS_LOG(ERROR) << "Event buffer add protobuf data failed!"; + res = false; + } + bufferevent_unlock(buffer_event_); + return res; +} + +bool TcpClient::SendMessage(std::shared_ptr meta, const Protos &protos, const void *data, size_t size) { + MS_EXCEPTION_IF_NULL(buffer_event_); + MS_EXCEPTION_IF_NULL(meta); 
+ MS_EXCEPTION_IF_NULL(data); + bufferevent_lock(buffer_event_); + bool res = true; + + MessageHeader header; + header.message_proto_ = protos; + header.message_meta_length_ = SizeToUint(meta->ByteSizeLong()); + header.message_length_ = size + header.message_meta_length_; + + if (bufferevent_write(buffer_event_, &header, sizeof(header)) == -1) { + MS_LOG(ERROR) << "Event buffer add header failed!"; + res = false; + } + if (bufferevent_write(buffer_event_, meta->SerializeAsString().data(), meta->ByteSizeLong()) == -1) { + MS_LOG(ERROR) << "Event buffer add protobuf data failed!"; + res = false; + } + if (bufferevent_write(buffer_event_, data, size) == -1) { MS_LOG(ERROR) << "Event buffer add protobuf data failed!"; res = false; } diff --git a/mindspore/ccsrc/ps/core/tcp_client.h b/mindspore/ccsrc/ps/core/tcp_client.h index cdf3add7080..fa0ef8df156 100644 --- a/mindspore/ccsrc/ps/core/tcp_client.h +++ b/mindspore/ccsrc/ps/core/tcp_client.h @@ -42,10 +42,10 @@ class TcpClient { public: using OnConnected = std::function; using OnDisconnected = std::function; - using OnRead = std::function; - using OnTimeout = std::function; - using OnMessage = std::function; - using OnTimer = std::function; + using OnRead = std::function; + using OnTimeout = std::function; + using OnMessage = std::function, const Protos &, const void *, size_t size)>; + using OnTimer = std::function; explicit TcpClient(const std::string &address, std::uint16_t port); virtual ~TcpClient(); @@ -61,6 +61,7 @@ class TcpClient { void StartWithNoBlock(); void SetMessageCallback(const OnMessage &cb); bool SendMessage(const CommMessage &message) const; + bool SendMessage(std::shared_ptr meta, const Protos &protos, const void *data, size_t size); void StartTimer(const uint32_t &time); void set_timer_callback(const OnTimer &timer); const event_base &eventbase(); diff --git a/mindspore/ccsrc/ps/core/tcp_message_handler.cc b/mindspore/ccsrc/ps/core/tcp_message_handler.cc index c63fd1ab50b..e3c74ffc985 100644 --- 
a/mindspore/ccsrc/ps/core/tcp_message_handler.cc +++ b/mindspore/ccsrc/ps/core/tcp_message_handler.cc @@ -35,8 +35,12 @@ void TcpMessageHandler::ReceiveMessage(const void *buffer, size_t num) { header_[++header_index_] = *(buffer_data + i); --num; if (header_index_ == kHeaderLen - 1) { - message_length_ = *reinterpret_cast(header_); - remaining_length_ = message_length_; + message_header_.message_proto_ = *reinterpret_cast(header_); + message_header_.message_meta_length_ = + *reinterpret_cast(header_ + sizeof(message_header_.message_proto_)); + message_header_.message_length_ = *reinterpret_cast( + header_ + sizeof(message_header_.message_proto_) + sizeof(message_header_.message_meta_length_)); + remaining_length_ = message_header_.message_length_; message_buffer_.reset(new unsigned char[remaining_length_]); buffer_data += (i + 1); break; @@ -57,10 +61,12 @@ void TcpMessageHandler::ReceiveMessage(const void *buffer, size_t num) { } if (remaining_length_ == 0) { - std::shared_ptr pb_message = std::make_shared(); - pb_message->ParseFromArray(message_buffer_.get(), message_length_); if (message_callback_) { - message_callback_(pb_message); + std::shared_ptr pb_message = std::make_shared(); + pb_message->ParseFromArray(message_buffer_.get(), message_header_.message_meta_length_); + message_callback_(pb_message, message_header_.message_proto_, + message_buffer_.get() + message_header_.message_meta_length_, + message_header_.message_length_ - message_header_.message_meta_length_); } message_buffer_.reset(); message_buffer_ = nullptr; diff --git a/mindspore/ccsrc/ps/core/tcp_message_handler.h b/mindspore/ccsrc/ps/core/tcp_message_handler.h index 2caa5112bd6..2912ae1c723 100644 --- a/mindspore/ccsrc/ps/core/tcp_message_handler.h +++ b/mindspore/ccsrc/ps/core/tcp_message_handler.h @@ -24,24 +24,20 @@ #include #include "utils/log_adapter.h" +#include "ps/core/message.h" #include "proto/comm.pb.h" #include "proto/ps.pb.h" namespace mindspore { namespace ps { namespace core { 
-using messageReceive = std::function)>; -constexpr int kHeaderLen = 8; +using messageReceive = std::function, const Protos &, const void *, size_t size)>; +constexpr int kHeaderLen = 16; class TcpMessageHandler { public: TcpMessageHandler() - : is_parsed_(false), - message_buffer_(nullptr), - message_length_(0), - remaining_length_(0), - header_index_(-1), - last_copy_len_(0) {} + : is_parsed_(false), message_buffer_(nullptr), remaining_length_(0), header_index_(-1), last_copy_len_(0) {} virtual ~TcpMessageHandler() = default; void SetCallback(const messageReceive &cb); @@ -51,11 +47,12 @@ class TcpMessageHandler { messageReceive message_callback_; bool is_parsed_; std::unique_ptr message_buffer_; - size_t message_length_; size_t remaining_length_; - char header_[8]; + char header_[16]; int header_index_; size_t last_copy_len_; + MessageHeader message_header_; + std::string mBuffer; }; } // namespace core } // namespace ps diff --git a/mindspore/ccsrc/ps/core/tcp_server.cc b/mindspore/ccsrc/ps/core/tcp_server.cc index 4751a6a10c2..4ff22d0d8f9 100644 --- a/mindspore/ccsrc/ps/core/tcp_server.cc +++ b/mindspore/ccsrc/ps/core/tcp_server.cc @@ -54,13 +54,39 @@ bool TcpConnection::SendMessage(std::shared_ptr message) const { bufferevent_lock(buffer_event_); bool res = true; size_t buf_size = message->ByteSizeLong(); - std::vector serialized(buf_size); - message->SerializeToArray(serialized.data(), SizeToInt(buf_size)); if (bufferevent_write(buffer_event_, &buf_size, sizeof(buf_size)) == -1) { MS_LOG(ERROR) << "Event buffer add header failed!"; res = false; } - if (bufferevent_write(buffer_event_, serialized.data(), buf_size) == -1) { + if (bufferevent_write(buffer_event_, message->SerializeAsString().data(), buf_size) == -1) { + MS_LOG(ERROR) << "Event buffer add protobuf data failed!"; + res = false; + } + bufferevent_unlock(buffer_event_); + return res; +} + +bool TcpConnection::SendMessage(std::shared_ptr meta, const Protos &protos, const void *data, + size_t size) 
const { + MS_EXCEPTION_IF_NULL(buffer_event_); + MS_EXCEPTION_IF_NULL(meta); + MS_EXCEPTION_IF_NULL(data); + bufferevent_lock(buffer_event_); + bool res = true; + MessageHeader header; + header.message_proto_ = protos; + header.message_meta_length_ = SizeToUint(meta->ByteSizeLong()); + header.message_length_ = size + header.message_meta_length_; + + if (bufferevent_write(buffer_event_, &header, sizeof(header)) == -1) { + MS_LOG(ERROR) << "Event buffer add header failed!"; + res = false; + } + if (bufferevent_write(buffer_event_, meta->SerializeAsString().data(), meta->ByteSizeLong()) == -1) { + MS_LOG(ERROR) << "Event buffer add protobuf data failed!"; + res = false; + } + if (bufferevent_write(buffer_event_, data, size) == -1) { MS_LOG(ERROR) << "Event buffer add protobuf data failed!"; res = false; } @@ -158,7 +184,7 @@ void TcpServer::Start() { MSLOG_IF(mindspore::ERROR, ret == 1, NoExceptionType) << "Event base dispatch failed with no events pending or active!"; MSLOG_IF(mindspore::ERROR, ret == -1, NoExceptionType) << "Event base dispatch failed with error occurred!"; - MSLOG_IF(mindspore::EXCEPTION, ret < -1, AbortedError) << "Event base dispatch with unexpect error code!"; + MSLOG_IF(mindspore::EXCEPTION, ret < -1, AbortedError) << "Event base dispatch with unexpected error code!"; } void TcpServer::StartWithNoBlock() { @@ -169,7 +195,7 @@ void TcpServer::StartWithNoBlock() { MSLOG_IF(INFO, ret == 0, NoExceptionType) << "Event base loop success!"; MSLOG_IF(mindspore::ERROR, ret == 1, NoExceptionType) << "Event base loop failed with no events pending or active!"; MSLOG_IF(mindspore::ERROR, ret == -1, NoExceptionType) << "Event base loop failed with error occurred!"; - MSLOG_IF(mindspore::EXCEPTION, ret < -1, AbortedError) << "Event base loop with unexpect error code!"; + MSLOG_IF(mindspore::EXCEPTION, ret < -1, AbortedError) << "Event base loop with unexpected error code!"; } void TcpServer::StartTimerOnlyOnce(const uint32_t &time) { @@ -260,10 +286,10 @@ 
void TcpServer::ListenerCallback(struct evconnlistener *, evutil_socket_t fd, st MS_EXCEPTION_IF_NULL(conn); server->AddConnection(fd, conn); - conn->InitConnection([=](std::shared_ptr message) { + conn->InitConnection([=](std::shared_ptr meta, const Protos &protos, const void *data, size_t size) { OnServerReceiveMessage on_server_receive = server->GetServerReceive(); if (on_server_receive) { - on_server_receive(conn, message); + on_server_receive(conn, meta, protos, data, size); } }); bufferevent_setcb(bev, TcpServer::ReadCallback, nullptr, TcpServer::EventCallback, @@ -274,6 +300,7 @@ void TcpServer::ListenerCallback(struct evconnlistener *, evutil_socket_t fd, st } std::shared_ptr TcpServer::onCreateConnection(struct bufferevent *bev, const evutil_socket_t &fd) { + MS_EXCEPTION_IF_NULL(bev); std::shared_ptr conn = nullptr; if (client_accept_) { conn = (client_accept_(*this)); @@ -367,9 +394,17 @@ bool TcpServer::SendMessage(std::shared_ptr conn, std::shared_ptr return conn->SendMessage(message); } +bool TcpServer::SendMessage(std::shared_ptr conn, std::shared_ptr meta, + const Protos &protos, const void *data, size_t size) { + MS_EXCEPTION_IF_NULL(conn); + MS_EXCEPTION_IF_NULL(meta); + MS_EXCEPTION_IF_NULL(data); + return conn->SendMessage(meta, protos, data, size); +} + void TcpServer::SendMessage(std::shared_ptr message) { - std::lock_guard lock(connection_mutex_); MS_EXCEPTION_IF_NULL(message); + std::lock_guard lock(connection_mutex_); for (auto it = connections_.begin(); it != connections_.end(); ++it) { SendMessage(it->second, message); diff --git a/mindspore/ccsrc/ps/core/tcp_server.h b/mindspore/ccsrc/ps/core/tcp_server.h index 84dbffaec4c..9ec5d0f2bad 100644 --- a/mindspore/ccsrc/ps/core/tcp_server.h +++ b/mindspore/ccsrc/ps/core/tcp_server.h @@ -36,7 +36,6 @@ #include "ps/core/tcp_message_handler.h" #include "ps/core/cluster_config.h" -#include "utils/log_adapter.h" #include "utils/convert_utils_base.h" namespace mindspore { @@ -55,6 +54,7 @@ class 
TcpConnection { virtual void InitConnection(const messageReceive &callback); virtual void SendMessage(const void *buffer, size_t num) const; bool SendMessage(std::shared_ptr message) const; + bool SendMessage(std::shared_ptr meta, const Protos &protos, const void *data, size_t size) const; virtual void OnReadHandler(const void *buffer, size_t numBytes); TcpServer *GetServer() const; const evutil_socket_t &GetFd() const; @@ -69,7 +69,8 @@ class TcpConnection { }; using OnServerReceiveMessage = - std::function conn, std::shared_ptr message)>; + std::function conn, std::shared_ptr meta, const Protos &protos, + const void *data, size_t size)>; class TcpServer { public: @@ -100,6 +101,8 @@ class TcpServer { OnServerReceiveMessage GetServerReceive() const; void SetMessageCallback(const OnServerReceiveMessage &cb); bool SendMessage(std::shared_ptr conn, std::shared_ptr message); + bool SendMessage(std::shared_ptr conn, std::shared_ptr meta, const Protos &protos, + const void *data, size_t sizee); void SendMessage(std::shared_ptr message); uint16_t BoundPort() const; std::string BoundIp() const; diff --git a/tests/ut/cpp/ps/core/tcp_client_tests.cc b/tests/ut/cpp/ps/core/tcp_client_tests.cc index dcbfddc9228..26e67037900 100644 --- a/tests/ut/cpp/ps/core/tcp_client_tests.cc +++ b/tests/ut/cpp/ps/core/tcp_client_tests.cc @@ -30,7 +30,12 @@ class TestTcpClient : public UT::Common { TEST_F(TestTcpClient, InitClientIPError) { auto client = std::make_unique("127.0.0.13543", 9000); - client->SetMessageCallback([](const TcpClient &client, const CommMessage &message) { client.SendMessage(message); }); + client->SetMessageCallback([&](std::shared_ptr, const Protos &, const void *data, size_t size) { + CommMessage message; + message.ParseFromArray(data, size); + + client->SendMessage(message); + }); ASSERT_THROW(client->Init(), std::exception); } @@ -38,10 +43,15 @@ TEST_F(TestTcpClient, InitClientIPError) { TEST_F(TestTcpClient, InitClientPortErrorNoException) { auto client = 
std::make_unique("127.0.0.1", -1); - client->SetMessageCallback([](const TcpClient &client, const CommMessage &message) { client.SendMessage(message); }); + client->SetMessageCallback([&](std::shared_ptr, const Protos &, const void *data, size_t size) { + CommMessage message; + message.ParseFromArray(data, size); + client->SendMessage(message); + }); EXPECT_NO_THROW(client->Init()); } + } // namespace core } // namespace ps } // namespace mindspore \ No newline at end of file diff --git a/tests/ut/cpp/ps/core/tcp_message_handler_test.cc b/tests/ut/cpp/ps/core/tcp_message_handler_test.cc index ffe6d9ab2b2..39c63d9bb0b 100644 --- a/tests/ut/cpp/ps/core/tcp_message_handler_test.cc +++ b/tests/ut/cpp/ps/core/tcp_message_handler_test.cc @@ -33,131 +33,145 @@ class TestTcpMessageHandler : public UT::Common { void TearDown() override {} }; -TEST_F(TestTcpMessageHandler, 8_Header_1003_Data) { +TEST_F(TestTcpMessageHandler, 16Header_2meta_1000Data) { TcpMessageHandler handler; - handler.SetCallback([this](std::shared_ptr message) { EXPECT_EQ(message->data().size(), 1000); }); + handler.SetCallback([this](std::shared_ptr meta, const Protos &, const void *data, size_t size) { + EXPECT_EQ(meta->ByteSizeLong(), 2); + EXPECT_EQ(size, 1000); + }); std::string data(1000, 'a'); - CommMessage message; - message.set_data(data); - size_t buf_size = message.ByteSizeLong(); - char result[1011]; - int ret = memcpy_s(result, kHeaderLen, &buf_size, kHeaderLen); + + char result[1018]; + + MessageMeta meta; + meta.set_request_id(1); + EXPECT_EQ(meta.ByteSizeLong(), 2); + + MessageHeader header; + header.message_proto_ = Protos::RAW; + header.message_meta_length_ = meta.ByteSizeLong(); + header.message_length_ = data.length() + meta.ByteSizeLong(); + int ret = memcpy_s(result, kHeaderLen, &header, kHeaderLen); if (ret != 0) { MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; } - std::vector serialized(buf_size); - message.SerializeToArray(serialized.data(), 
static_cast(buf_size)); - memcpy_s(result + kHeaderLen, buf_size, serialized.data(), buf_size); - handler.ReceiveMessage(result, buf_size + kHeaderLen); + memcpy_s(result + kHeaderLen, meta.ByteSizeLong(), meta.SerializeAsString().data(), meta.ByteSizeLong()); + memcpy_s(result + kHeaderLen + meta.ByteSizeLong(), data.length(), data.data(), data.length()); + + handler.ReceiveMessage(result, 1018); } -TEST_F(TestTcpMessageHandler, 8_Header_1003_Data_8_Header_1003_Data) { +TEST_F(TestTcpMessageHandler, 16Header_2meta_1000Data_16Header_2meta_1000Data) { TcpMessageHandler handler; - handler.SetCallback([this](std::shared_ptr message) { EXPECT_EQ(message->data().size(), 1000); }); + handler.SetCallback([this](std::shared_ptr meta, const Protos &, const void *data, size_t size) { + EXPECT_EQ(meta->ByteSizeLong(), 2); + EXPECT_EQ(size, 1000); + }); std::string data(1000, 'a'); - CommMessage message; - message.set_data(data); - size_t buf_size = message.ByteSizeLong(); - char result[2022] = {0}; - int ret = memcpy_s(result, kHeaderLen, &buf_size, kHeaderLen); - if (ret != 0) { - MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; - } - std::vector serialized(buf_size); - message.SerializeToArray(serialized.data(), static_cast(buf_size)); - ret = memcpy_s(result + kHeaderLen, buf_size, serialized.data(), buf_size); - if (ret != 0) { - MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; - } - ret = memcpy_s(result + kHeaderLen + buf_size, kHeaderLen, &buf_size, kHeaderLen); - if (ret != 0) { - MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; - } - ret = memcpy_s(result + kHeaderLen + buf_size + kHeaderLen, buf_size, serialized.data(), buf_size); + + char result[2036]; + + MessageMeta meta; + meta.set_request_id(1); + EXPECT_EQ(meta.ByteSizeLong(), 2); + + MessageHeader header; + header.message_proto_ = Protos::RAW; + header.message_meta_length_ = meta.ByteSizeLong(); + header.message_length_ = data.length() + 
meta.ByteSizeLong(); + int ret = memcpy_s(result, kHeaderLen, &header, kHeaderLen); if (ret != 0) { MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; } - handler.ReceiveMessage(result, 2 * buf_size + kHeaderLen * 2); + memcpy_s(result + kHeaderLen, meta.ByteSizeLong(), meta.SerializeAsString().data(), meta.ByteSizeLong()); + memcpy_s(result + kHeaderLen + meta.ByteSizeLong(), data.length(), data.data(), data.length()); + + memcpy_s(result + kHeaderLen + meta.ByteSizeLong() + data.length(), kHeaderLen, &header, kHeaderLen); + memcpy_s(result + kHeaderLen * 2 + meta.ByteSizeLong() + data.length(), meta.ByteSizeLong(), + meta.SerializeAsString().data(), meta.ByteSizeLong()); + memcpy_s(result + kHeaderLen * 2 + meta.ByteSizeLong() * 2 + data.length(), data.length(), data.data(), + data.length()); + + handler.ReceiveMessage(result, 2036); } -TEST_F(TestTcpMessageHandler, 8_Header_4084_Data_4_Header_4_header_4084_data) { +TEST_F(TestTcpMessageHandler, 16header_2meta_4070data_8header_8header_2meta_4070data) { TcpMessageHandler handler; - handler.SetCallback([this](std::shared_ptr message) { EXPECT_EQ(message->data().size(), 4081); }); + handler.SetCallback([this](std::shared_ptr meta, const Protos &, const void *data, size_t size) { + EXPECT_EQ(meta->ByteSizeLong(), 2); + EXPECT_EQ(size, 4070); + }); + + std::string data(4070, 'a'); - std::string data(4081, 'a'); - CommMessage message; - message.set_data(data); - size_t buf_size = message.ByteSizeLong(); char result[4096] = {0}; - int ret = memcpy_s(result, kHeaderLen, &buf_size, kHeaderLen); - if (ret != 0) { - MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; - } - std::vector serialized(buf_size); - message.SerializeToArray(serialized.data(), static_cast(buf_size)); - ret = memcpy_s(result + kHeaderLen, buf_size, serialized.data(), buf_size); + + MessageMeta meta; + meta.set_request_id(1); + EXPECT_EQ(meta.ByteSizeLong(), 2); + + MessageHeader header; + header.message_proto_ = 
Protos::RAW; + header.message_meta_length_ = meta.ByteSizeLong(); + header.message_length_ = data.length() + meta.ByteSizeLong(); + int ret = memcpy_s(result, kHeaderLen, &header, kHeaderLen); if (ret != 0) { MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; } - ret = memcpy_s(result + kHeaderLen + buf_size, 4, &buf_size, 4); - if (ret != 0) { - MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; - } + memcpy_s(result + kHeaderLen, meta.ByteSizeLong(), meta.SerializeAsString().data(), meta.ByteSizeLong()); + memcpy_s(result + kHeaderLen + meta.ByteSizeLong(), data.length(), data.data(), data.length()); + memcpy_s(result + kHeaderLen + meta.ByteSizeLong() + data.length(), 8, &header, 8); handler.ReceiveMessage(result, 4096); - auto temp = reinterpret_cast(&buf_size); - ret = memcpy_s(result, 4, temp + 4, 4); - if (ret != 0) { - MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; - } - ret = memcpy_s(result + 4, buf_size, serialized.data(), buf_size); - if (ret != 0) { - MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; - } - - handler.ReceiveMessage(result, 4088); -} - -TEST_F(TestTcpMessageHandler, 8_Header_4080_Data_8_Header_4080_data) { - TcpMessageHandler handler; - handler.SetCallback([this](std::shared_ptr message) { EXPECT_EQ(message->data().size(), 4077); }); - - std::string data(4077, 'a'); - CommMessage message; - message.set_data(data); - size_t buf_size = message.ByteSizeLong(); - char result[4096] = {0}; - int ret = memcpy_s(result, kHeaderLen, &buf_size, kHeaderLen); - if (ret != 0) { - MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; - } - std::vector serialized(buf_size); - message.SerializeToArray(serialized.data(), static_cast(buf_size)); - ret = memcpy_s(result + kHeaderLen, buf_size, serialized.data(), buf_size); - if (ret != 0) { - MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; - } - - ret = memcpy_s(result + kHeaderLen + buf_size, kHeaderLen, 
&buf_size, kHeaderLen); - if (ret != 0) { - MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; - } - - handler.ReceiveMessage(result, 4096); - - ret = memcpy_s(result, buf_size, serialized.data(), buf_size); - if (ret != 0) { - MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; - } + auto temp = reinterpret_cast(&header); + memcpy_s(result, 8, temp + 8, 8); + memcpy_s(result + 8, meta.ByteSizeLong(), meta.SerializeAsString().data(), meta.ByteSizeLong()); + memcpy_s(result + 8 + 2, data.length(), data.data(), data.length()); handler.ReceiveMessage(result, 4080); } + +TEST_F(TestTcpMessageHandler, 16Header_2meta_4062Data_16Header_2meta_4062_data) { + TcpMessageHandler handler; + handler.SetCallback([this](std::shared_ptr meta, const Protos &, const void *data, size_t size) { + EXPECT_EQ(meta->ByteSizeLong(), 2); + EXPECT_EQ(size, 4062); + }); + + std::string data(4062, 'a'); + + char result[4096] = {0}; + + MessageMeta meta; + meta.set_request_id(1); + EXPECT_EQ(meta.ByteSizeLong(), 2); + + MessageHeader header; + header.message_proto_ = Protos::RAW; + header.message_meta_length_ = meta.ByteSizeLong(); + header.message_length_ = data.length() + meta.ByteSizeLong(); + int ret = memcpy_s(result, kHeaderLen, &header, kHeaderLen); + if (ret != 0) { + MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; + } + + memcpy_s(result + kHeaderLen, meta.ByteSizeLong(), meta.SerializeAsString().data(), meta.ByteSizeLong()); + memcpy_s(result + kHeaderLen + meta.ByteSizeLong(), data.length(), data.data(), data.length()); + memcpy_s(result + kHeaderLen + meta.ByteSizeLong() + data.length(), kHeaderLen, &header, kHeaderLen); + + handler.ReceiveMessage(result, 4096); + + memcpy_s(result, meta.ByteSizeLong(), meta.SerializeAsString().data(), meta.ByteSizeLong()); + memcpy_s(result + meta.ByteSizeLong(), data.length(), data.data(), data.length()); + + handler.ReceiveMessage(result, 4064); +} } // namespace core } // namespace ps } // 
namespace mindspore \ No newline at end of file diff --git a/tests/ut/cpp/ps/core/tcp_pb_server_test.cc b/tests/ut/cpp/ps/core/tcp_pb_server_test.cc index df5f70ee956..3afa7f90133 100644 --- a/tests/ut/cpp/ps/core/tcp_pb_server_test.cc +++ b/tests/ut/cpp/ps/core/tcp_pb_server_test.cc @@ -33,11 +33,12 @@ class TestTcpServer : public UT::Common { server_ = std::make_unique("127.0.0.1", 0); std::unique_ptr http_server_thread_(nullptr); http_server_thread_ = std::make_unique([=]() { - server_->SetMessageCallback([=](std::shared_ptr conn, std::shared_ptr message) { + server_->SetMessageCallback([=](std::shared_ptr conn, std::shared_ptr meta, + const Protos &protos, const void *data, size_t size) { KVMessage kv_message; - kv_message.ParseFromString(message->data()); + kv_message.ParseFromArray(data, size); EXPECT_EQ(2, kv_message.keys_size()); - server_->SendMessage(conn, message); + server_->SendMessage(conn, meta, protos, data, size); }); server_->Init(); server_->Start(); @@ -61,23 +62,24 @@ TEST_F(TestTcpServer, ServerSendMessage) { std::cout << server_->BoundPort() << std::endl; std::unique_ptr http_client_thread(nullptr); http_client_thread = std::make_unique([&]() { - client_->SetMessageCallback([](const TcpClient &client, const CommMessage &message) { - KVMessage kv_message; - kv_message.ParseFromString(message.data()); - EXPECT_EQ(2, kv_message.keys_size()); + client_->SetMessageCallback([&](std::shared_ptr meta, const Protos &, const void *data, size_t size) { + KVMessage message; + message.ParseFromArray(data, size); + EXPECT_EQ(2, message.keys_size()); }); client_->Init(); - CommMessage comm_message; KVMessage kv_message; std::vector keys{1, 2}; std::vector values{3, 4}; *kv_message.mutable_keys() = {keys.begin(), keys.end()}; *kv_message.mutable_values() = {values.begin(), values.end()}; - comm_message.set_data(kv_message.SerializeAsString()); - client_->SendMessage(comm_message); + auto message_meta = std::make_shared(); + 
message_meta->set_cmd(NodeCommand::SEND_DATA); + + client_->SendMessage(message_meta, Protos::RAW, kv_message.SerializeAsString().data(), kv_message.ByteSizeLong()); client_->Start(); });