!32413 Add return value for tcp message handler

Merge pull request !32413 from chengang/add_retrieve_api
This commit is contained in:
i-robot 2022-04-01 09:33:53 +00:00 committed by Gitee
commit 18d3e7c467
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
10 changed files with 88 additions and 64 deletions

View File

@ -18,6 +18,7 @@
#include <algorithm>
#include <string>
#include "proto/topology.pb.h"
#include "distributed/rpc/tcp/constants.h"
#include "distributed/cluster/topology/utils.h"
#include "distributed/cluster/topology/meta_server_node.h"
@ -69,7 +70,7 @@ bool MetaServerNode::InitTCPServer() {
return true;
}
void MetaServerNode::HandleMessage(const std::shared_ptr<MessageBase> &message) {
std::shared_ptr<MessageBase> MetaServerNode::HandleMessage(const std::shared_ptr<MessageBase> &message) {
MS_EXCEPTION_IF_NULL(message);
const auto &name = message->Name();
@ -79,7 +80,7 @@ void MetaServerNode::HandleMessage(const std::shared_ptr<MessageBase> &message)
const auto &handler = system_msg_handlers_.find(message_name);
if (handler == system_msg_handlers_.end()) {
MS_LOG(ERROR) << "Unknown system message name: " << message->Name();
return;
return rpc::NULL_MSG;
}
system_msg_handlers_[message_name](message);
@ -88,13 +89,14 @@ void MetaServerNode::HandleMessage(const std::shared_ptr<MessageBase> &message)
const auto &handler = message_handlers_.find(name);
if (handler == message_handlers_.end()) {
MS_LOG(ERROR) << "Unknown message name: " << name;
return;
return rpc::NULL_MSG;
}
(*message_handlers_[name])(message->Body());
}
return rpc::NULL_MSG;
}
void MetaServerNode::ProcessRegister(const std::shared_ptr<MessageBase> &message) {
std::shared_ptr<MessageBase> MetaServerNode::ProcessRegister(const std::shared_ptr<MessageBase> &message) {
MS_EXCEPTION_IF_NULL(message);
RegistrationMessage registration;
@ -111,9 +113,10 @@ void MetaServerNode::ProcessRegister(const std::shared_ptr<MessageBase> &message
} else {
MS_LOG(ERROR) << "The node: " << node_id << " have been registered before.";
}
return rpc::NULL_MSG;
}
void MetaServerNode::ProcessUnregister(const std::shared_ptr<MessageBase> &message) {
std::shared_ptr<MessageBase> MetaServerNode::ProcessUnregister(const std::shared_ptr<MessageBase> &message) {
MS_EXCEPTION_IF_NULL(message);
UnregistrationMessage unregistration;
@ -124,12 +127,13 @@ void MetaServerNode::ProcessUnregister(const std::shared_ptr<MessageBase> &messa
std::unique_lock<std::shared_mutex> lock(nodes_mutex_);
if (nodes_.find(node_id) == nodes_.end()) {
MS_LOG(ERROR) << "Received unregistration message from invalid compute graph node: " << node_id;
return;
return rpc::NULL_MSG;
}
nodes_.erase(node_id);
return rpc::NULL_MSG;
}
void MetaServerNode::ProcessHeartbeat(const std::shared_ptr<MessageBase> &message) {
std::shared_ptr<MessageBase> MetaServerNode::ProcessHeartbeat(const std::shared_ptr<MessageBase> &message) {
MS_EXCEPTION_IF_NULL(message);
HeartbeatMessage heartbeat;
@ -145,6 +149,7 @@ void MetaServerNode::ProcessHeartbeat(const std::shared_ptr<MessageBase> &messag
} else {
MS_LOG(ERROR) << "Invalid node: " << node_id << ".";
}
return rpc::NULL_MSG;
}
void MetaServerNode::UpdateTopoState() {

View File

@ -80,16 +80,16 @@ class MetaServerNode : public NodeBase {
bool InitTCPServer();
// Handle the message received by the tcp server.
void HandleMessage(const std::shared_ptr<MessageBase> &message);
std::shared_ptr<MessageBase> HandleMessage(const std::shared_ptr<MessageBase> &message);
// Process the received register message sent from compute graph nodes.
void ProcessRegister(const std::shared_ptr<MessageBase> &message);
std::shared_ptr<MessageBase> ProcessRegister(const std::shared_ptr<MessageBase> &message);
// Process the received unregister message sent from compute graph nodes.
void ProcessUnregister(const std::shared_ptr<MessageBase> &message);
std::shared_ptr<MessageBase> ProcessUnregister(const std::shared_ptr<MessageBase> &message);
// Process the received heartbeat message sent from compute graph nodes.
void ProcessHeartbeat(const std::shared_ptr<MessageBase> &message);
std::shared_ptr<MessageBase> ProcessHeartbeat(const std::shared_ptr<MessageBase> &message);
// Maintain the state which is type of `TopoState` of this cluster topology.
void UpdateTopoState();

View File

@ -422,6 +422,42 @@ void Connection::FillRecvMessage() {
recv_message = msg;
}
int Connection::Flush() {
int total_send_bytes = 0;
while (!send_message_queue.empty() || total_send_len != 0) {
if (total_send_len == 0) {
FillSendMessage(send_message_queue.front(), source, false);
send_message_queue.pop();
}
size_t sendLen = 0;
int retval = socket_operation->SendMessage(this, &send_kernel_msg, total_send_len, &sendLen);
if (retval == IO_RW_OK && sendLen > 0) {
total_send_len -= sendLen;
if (total_send_len == 0) {
// update metrics
send_metrics->UpdateError(false);
output_buffer_size -= send_message->body.size();
total_send_bytes += send_message->body.size();
delete send_message;
send_message = nullptr;
break;
}
} else if (retval == IO_RW_OK && sendLen == 0) {
// EAGAIN
MS_LOG(ERROR) << "Failed to send message and update the epoll event";
(void)recv_event_loop->UpdateEpollEvent(socket_fd, EPOLLOUT | EPOLLIN | EPOLLHUP | EPOLLERR);
continue;
} else {
// update metrics
send_metrics->UpdateError(true, error_code);
state = ConnectionState::kDisconnecting;
break;
}
}
return total_send_bytes;
}
int Connection::AddConnnectEventHandler() {
return recv_event_loop->SetEventHandler(socket_fd, EPOLLIN | EPOLLHUP | EPOLLERR, NewConnectEventHandler,
reinterpret_cast<void *>(this));

View File

@ -109,6 +109,9 @@ struct Connection {
return !(that != nullptr && that->destination == destination && that->is_remote == is_remote);
}
// Send all the messages in the message queue.
int Flush();
// The socket used by this connection.
int socket_fd;

View File

@ -30,7 +30,7 @@
namespace mindspore {
namespace distributed {
namespace rpc {
using MessageHandler = std::function<void(const std::shared_ptr<MessageBase> &)>;
using MessageHandler = std::function<std::shared_ptr<MessageBase>(const std::shared_ptr<MessageBase> &)>;
using DeleteCallBack = void (*)(const std::string &from, const std::string &to);
using ConnectionCallBack = void (*)(void *conn);
@ -57,6 +57,8 @@ static const int g_httpKmsgEnable = -1;
using IntTypeMetrics = std::queue<int>;
using StringTypeMetrics = std::queue<std::string>;
static const std::shared_ptr<MessageBase> NULL_MSG = nullptr;
// Server socket listen backlog.
static const int SOCKET_LISTEN_BACKLOG = 2048;

View File

@ -131,42 +131,6 @@ void OnAccept(int server, uint32_t events, void *arg) {
tcpmgr->conn_pool_->AddConnection(conn);
}
// Sends data buffered in the given connection through its socket.
//
// While the connection has a message in flight (conn->total_send_len != 0) or its send
// queue is not empty, the next queued message is staged with FillSendMessage() and handed
// to the socket operation. The loop exits as soon as one message has been sent completely,
// or when the socket operation reports a status other than IO_RW_OK (the connection is
// then marked as kDisconnecting).
//
// `conn` is dereferenced without a check, so it must not be null.
// Returns the body size of the message that was completely sent during this call, or 0
// when no message was completed.
int DoSend(Connection *conn) {
  int sent_body_bytes = 0;

  while (conn->total_send_len != 0 || !conn->send_message_queue.empty()) {
    // Nothing is in flight: stage the message at the head of the queue.
    if (conn->total_send_len == 0) {
      conn->FillSendMessage(conn->send_message_queue.front(), conn->source, false);
      conn->send_message_queue.pop();
    }

    size_t written = 0;
    const int rc = conn->socket_operation->SendMessage(conn, &conn->send_kernel_msg, conn->total_send_len, &written);

    if (rc != IO_RW_OK) {
      // The send failed: record the error in the metrics and start disconnecting.
      conn->send_metrics->UpdateError(true, conn->error_code);
      conn->state = ConnectionState::kDisconnecting;
      break;
    }

    if (written == 0) {
      // EAGAIN: nothing was written this round; update the epoll events (now including
      // EPOLLOUT) for the socket and go around the loop again.
      MS_LOG(ERROR) << "Failed to send message and update the epoll event";
      (void)conn->recv_event_loop->UpdateEpollEvent(conn->socket_fd, EPOLLOUT | EPOLLIN | EPOLLHUP | EPOLLERR);
      continue;
    }

    conn->total_send_len -= written;
    if (conn->total_send_len != 0) {
      // Partial write: keep sending the remainder of the current message.
      continue;
    }

    // The current message is fully sent: update the metrics and accounting, then release it.
    conn->send_metrics->UpdateError(false);
    const auto body_size = conn->send_message->body.size();
    conn->output_buffer_size -= body_size;
    sent_body_bytes += body_size;
    delete conn->send_message;
    conn->send_message = nullptr;
    break;
  }

  return sent_body_bytes;
}
void TCPComm::SetMessageHandler(const MessageHandler &handler) { message_handler_ = handler; }
bool TCPComm::Initialize() {
@ -260,7 +224,7 @@ void TCPComm::EventCallBack(void *connection) {
if (conn->state == ConnectionState::kConnected) {
conn->conn_mutex->lock();
(void)DoSend(conn);
(void)conn->Flush();
conn->conn_mutex->unlock();
} else if (conn->state == ConnectionState::kDisconnecting) {
conn->conn_mutex->lock();
@ -272,7 +236,7 @@ void TCPComm::WriteCallBack(void *connection) {
Connection *conn = reinterpret_cast<Connection *>(connection);
if (conn->state == ConnectionState::kConnected) {
conn->conn_mutex->lock();
(void)DoSend(conn);
(void)conn->Flush();
conn->conn_mutex->unlock();
}
}
@ -370,7 +334,7 @@ ssize_t TCPComm::Send(MessageBase *msg, bool sync) {
} else {
(void)conn->send_message_queue.emplace(msg);
}
return DoSend(conn);
return conn->Flush();
};
if (sync) {
return task();

View File

@ -33,9 +33,6 @@ namespace rpc {
// Event handler for new connecting request arrived.
void OnAccept(int server, uint32_t events, void *arg);
// Send messages buffered in the connection.
int DoSend(Connection *conn);
void DoDisconnect(int fd, Connection *conn, uint32_t error, int soError);
void ConnectedEventHandler(int fd, uint32_t events, void *context);
@ -116,7 +113,6 @@ class TCPComm {
std::shared_ptr<std::mutex> conn_mutex_;
friend void OnAccept(int server, uint32_t events, void *arg);
friend int DoSend(Connection *conn);
friend int DoConnect(const std::string &to, Connection *conn, ConnectionCallBack event_callback,
ConnectionCallBack write_callback, ConnectionCallBack read_callback);
};

View File

@ -20,6 +20,7 @@
#include <utility>
#include <functional>
#include <condition_variable>
#include "distributed/rpc/tcp/constants.h"
#include "plugin/device/cpu/kernel/rpc/rpc_recv_kernel.h"
namespace mindspore {
@ -130,7 +131,7 @@ void RecvActor::EraseInput(const OpContext<DeviceTensor> *context) {
}
}
void RecvActor::HandleMessage(const std::shared_ptr<MessageBase> &msg) {
std::shared_ptr<MessageBase> RecvActor::HandleMessage(const std::shared_ptr<MessageBase> &msg) {
// Block the message handler if the context is invalid.
std::unique_lock<std::mutex> lock(context_mtx_);
context_cv_.wait(lock, [this] { return is_context_valid_; });
@ -138,9 +139,11 @@ void RecvActor::HandleMessage(const std::shared_ptr<MessageBase> &msg) {
MS_LOG(INFO) << "Rpc actor recv message for inter-process edge: " << inter_process_edge_name_;
MS_ERROR_IF_NULL_WO_RET_VAL(msg);
MS_ERROR_IF_NULL_WO_RET_VAL(op_context_);
if (msg == nullptr || op_context_ == nullptr) {
return distributed::rpc::NULL_MSG;
}
ActorDispatcher::Send(GetAID(), &RecvActor::RunOpInterProcessData, msg, op_context_);
return distributed::rpc::NULL_MSG;
}
} // namespace runtime
} // namespace mindspore

View File

@ -66,7 +66,7 @@ class RecvActor : public RpcActor {
private:
// The message callback of the tcp server.
void HandleMessage(const std::shared_ptr<MessageBase> &msg);
std::shared_ptr<MessageBase> HandleMessage(const std::shared_ptr<MessageBase> &msg);
// The network address of this recv actor. It's generated automatically by rpc module.
std::string ip_;

View File

@ -154,7 +154,10 @@ TEST_F(TCPTest, SendOneMessage) {
bool ret = server->Initialize(server_url);
ASSERT_TRUE(ret);
server->SetMessageHandler([](const std::shared_ptr<MessageBase> &message) -> void { IncrDataMsgNum(1); });
server->SetMessageHandler([](const std::shared_ptr<MessageBase> &message) -> std::shared_ptr<MessageBase> {
IncrDataMsgNum(1);
return NULL_MSG;
});
// Start the tcp client.
auto client_url = "127.0.0.1:1234";
@ -193,7 +196,10 @@ TEST_F(TCPTest, SendTwoMessages) {
bool ret = server->Initialize(server_url);
ASSERT_TRUE(ret);
server->SetMessageHandler([](const std::shared_ptr<MessageBase> &message) -> void { IncrDataMsgNum(1); });
server->SetMessageHandler([](const std::shared_ptr<MessageBase> &message) -> std::shared_ptr<MessageBase> {
IncrDataMsgNum(1);
return NULL_MSG;
});
// Start the tcp client.
auto client_url = "127.0.0.1:1234";
@ -245,7 +251,10 @@ TEST_F(TCPTest, SendSyncMessage) {
bool ret = server->Initialize(server_url);
ASSERT_TRUE(ret);
server->SetMessageHandler([](const std::shared_ptr<MessageBase> &message) -> void { IncrDataMsgNum(1); });
server->SetMessageHandler([](const std::shared_ptr<MessageBase> &message) -> std::shared_ptr<MessageBase> {
IncrDataMsgNum(1);
return NULL_MSG;
});
// Start the tcp client.
auto client_url = "127.0.0.1:1234";
@ -283,7 +292,10 @@ TEST_F(TCPTest, SendLargeMessages) {
bool ret = server->Initialize();
ASSERT_TRUE(ret);
server->SetMessageHandler([](const std::shared_ptr<MessageBase> &message) -> void { IncrDataMsgNum(1); });
server->SetMessageHandler([](const std::shared_ptr<MessageBase> &message) -> std::shared_ptr<MessageBase> {
IncrDataMsgNum(1);
return NULL_MSG;
});
// Start the tcp client.
auto client_url = "127.0.0.1:1234";
@ -336,7 +348,10 @@ TEST_F(TCPTest, CreateManyConnectionPairs) {
auto port = server->GetPort();
ASSERT_TRUE(ret);
server->SetMessageHandler([](const std::shared_ptr<MessageBase> &message) -> void { IncrDataMsgNum(1); });
server->SetMessageHandler([](const std::shared_ptr<MessageBase> &message) -> std::shared_ptr<MessageBase> {
IncrDataMsgNum(1);
return NULL_MSG;
});
// Start the tcp client.
auto client_url = "127.0.0.1:1234";