forked from mindspore-Ecosystem/mindspore
!32413 Add return value for tcp message handler
Merge pull request !32413 from chengang/add_retrieve_api
This commit is contained in:
commit
18d3e7c467
|
@ -18,6 +18,7 @@
|
|||
#include <algorithm>
|
||||
#include <string>
|
||||
#include "proto/topology.pb.h"
|
||||
#include "distributed/rpc/tcp/constants.h"
|
||||
#include "distributed/cluster/topology/utils.h"
|
||||
#include "distributed/cluster/topology/meta_server_node.h"
|
||||
|
||||
|
@ -69,7 +70,7 @@ bool MetaServerNode::InitTCPServer() {
|
|||
return true;
|
||||
}
|
||||
|
||||
void MetaServerNode::HandleMessage(const std::shared_ptr<MessageBase> &message) {
|
||||
std::shared_ptr<MessageBase> MetaServerNode::HandleMessage(const std::shared_ptr<MessageBase> &message) {
|
||||
MS_EXCEPTION_IF_NULL(message);
|
||||
const auto &name = message->Name();
|
||||
|
||||
|
@ -79,7 +80,7 @@ void MetaServerNode::HandleMessage(const std::shared_ptr<MessageBase> &message)
|
|||
const auto &handler = system_msg_handlers_.find(message_name);
|
||||
if (handler == system_msg_handlers_.end()) {
|
||||
MS_LOG(ERROR) << "Unknown system message name: " << message->Name();
|
||||
return;
|
||||
return rpc::NULL_MSG;
|
||||
}
|
||||
system_msg_handlers_[message_name](message);
|
||||
|
||||
|
@ -88,13 +89,14 @@ void MetaServerNode::HandleMessage(const std::shared_ptr<MessageBase> &message)
|
|||
const auto &handler = message_handlers_.find(name);
|
||||
if (handler == message_handlers_.end()) {
|
||||
MS_LOG(ERROR) << "Unknown message name: " << name;
|
||||
return;
|
||||
return rpc::NULL_MSG;
|
||||
}
|
||||
(*message_handlers_[name])(message->Body());
|
||||
}
|
||||
return rpc::NULL_MSG;
|
||||
}
|
||||
|
||||
void MetaServerNode::ProcessRegister(const std::shared_ptr<MessageBase> &message) {
|
||||
std::shared_ptr<MessageBase> MetaServerNode::ProcessRegister(const std::shared_ptr<MessageBase> &message) {
|
||||
MS_EXCEPTION_IF_NULL(message);
|
||||
|
||||
RegistrationMessage registration;
|
||||
|
@ -111,9 +113,10 @@ void MetaServerNode::ProcessRegister(const std::shared_ptr<MessageBase> &message
|
|||
} else {
|
||||
MS_LOG(ERROR) << "The node: " << node_id << " have been registered before.";
|
||||
}
|
||||
return rpc::NULL_MSG;
|
||||
}
|
||||
|
||||
void MetaServerNode::ProcessUnregister(const std::shared_ptr<MessageBase> &message) {
|
||||
std::shared_ptr<MessageBase> MetaServerNode::ProcessUnregister(const std::shared_ptr<MessageBase> &message) {
|
||||
MS_EXCEPTION_IF_NULL(message);
|
||||
|
||||
UnregistrationMessage unregistration;
|
||||
|
@ -124,12 +127,13 @@ void MetaServerNode::ProcessUnregister(const std::shared_ptr<MessageBase> &messa
|
|||
std::unique_lock<std::shared_mutex> lock(nodes_mutex_);
|
||||
if (nodes_.find(node_id) == nodes_.end()) {
|
||||
MS_LOG(ERROR) << "Received unregistration message from invalid compute graph node: " << node_id;
|
||||
return;
|
||||
return rpc::NULL_MSG;
|
||||
}
|
||||
nodes_.erase(node_id);
|
||||
return rpc::NULL_MSG;
|
||||
}
|
||||
|
||||
void MetaServerNode::ProcessHeartbeat(const std::shared_ptr<MessageBase> &message) {
|
||||
std::shared_ptr<MessageBase> MetaServerNode::ProcessHeartbeat(const std::shared_ptr<MessageBase> &message) {
|
||||
MS_EXCEPTION_IF_NULL(message);
|
||||
|
||||
HeartbeatMessage heartbeat;
|
||||
|
@ -145,6 +149,7 @@ void MetaServerNode::ProcessHeartbeat(const std::shared_ptr<MessageBase> &messag
|
|||
} else {
|
||||
MS_LOG(ERROR) << "Invalid node: " << node_id << ".";
|
||||
}
|
||||
return rpc::NULL_MSG;
|
||||
}
|
||||
|
||||
void MetaServerNode::UpdateTopoState() {
|
||||
|
|
|
@ -80,16 +80,16 @@ class MetaServerNode : public NodeBase {
|
|||
bool InitTCPServer();
|
||||
|
||||
// Handle the message received by the tcp server.
|
||||
void HandleMessage(const std::shared_ptr<MessageBase> &message);
|
||||
std::shared_ptr<MessageBase> HandleMessage(const std::shared_ptr<MessageBase> &message);
|
||||
|
||||
// Process the received register message sent from compute graph nodes.
|
||||
void ProcessRegister(const std::shared_ptr<MessageBase> &message);
|
||||
std::shared_ptr<MessageBase> ProcessRegister(const std::shared_ptr<MessageBase> &message);
|
||||
|
||||
// Process the received unregister message sent from compute graph nodes.
|
||||
void ProcessUnregister(const std::shared_ptr<MessageBase> &message);
|
||||
std::shared_ptr<MessageBase> ProcessUnregister(const std::shared_ptr<MessageBase> &message);
|
||||
|
||||
// Process the received heartbeat message sent from compute graph nodes.
|
||||
void ProcessHeartbeat(const std::shared_ptr<MessageBase> &message);
|
||||
std::shared_ptr<MessageBase> ProcessHeartbeat(const std::shared_ptr<MessageBase> &message);
|
||||
|
||||
// Maintain the state of this cluster topology, which is of type `TopoState`.
|
||||
void UpdateTopoState();
|
||||
|
|
|
@ -422,6 +422,42 @@ void Connection::FillRecvMessage() {
|
|||
recv_message = msg;
|
||||
}
|
||||
|
||||
// Tries to send pending outbound data on this connection.
//
// Returns the accumulated body size of the message(s) fully sent during this call; 0 if nothing was
// completed (including the error path, which also switches `state` to kDisconnecting).
//
// NOTE(review): the loop `break`s as soon as one message has been sent completely, so a single call
// finishes at most one message even if more are still queued — confirm this is intended.
// NOTE(review): the callbacks visible in this change lock conn_mutex around this call; presumably every
// caller is expected to do so — confirm.
int Connection::Flush() {
|
||||
int total_send_bytes = 0;
|
||||
// Keep going while there is a queued message or an in-flight message with unsent bytes.
while (!send_message_queue.empty() || total_send_len != 0) {
|
||||
// Nothing in flight: take the next queued message.
// Presumably FillSendMessage prepares send_message / send_kernel_msg / total_send_len for it
// (its definition is not shown here).
if (total_send_len == 0) {
|
||||
FillSendMessage(send_message_queue.front(), source, false);
|
||||
send_message_queue.pop();
|
||||
}
|
||||
size_t sendLen = 0;
|
||||
// sendLen receives the number of bytes written by this attempt.
int retval = socket_operation->SendMessage(this, &send_kernel_msg, total_send_len, &sendLen);
|
||||
if (retval == IO_RW_OK && sendLen > 0) {
|
||||
// Progress was made; account for the bytes written. A partial write falls through and loops again.
total_send_len -= sendLen;
|
||||
if (total_send_len == 0) {
|
||||
// update metrics
|
||||
send_metrics->UpdateError(false);
|
||||
|
||||
// The in-flight message is complete: update the buffer accounting and the return value,
// then release the message object.
output_buffer_size -= send_message->body.size();
|
||||
total_send_bytes += send_message->body.size();
|
||||
delete send_message;
|
||||
send_message = nullptr;
|
||||
// Stop after one completed message (see the note in the function header).
break;
|
||||
}
|
||||
} else if (retval == IO_RW_OK && sendLen == 0) {
|
||||
// EAGAIN
|
||||
MS_LOG(ERROR) << "Failed to send message and update the epoll event";
|
||||
// Subscribe to EPOLLOUT in addition to the read/hang-up/error events for this socket.
(void)recv_event_loop->UpdateEpollEvent(socket_fd, EPOLLOUT | EPOLLIN | EPOLLHUP | EPOLLERR);
|
||||
// NOTE(review): `continue` retries the send right away rather than returning and waiting for
// the writable event — looks like this can spin while the socket stays unwritable; confirm.
continue;
|
||||
} else {
|
||||
// update metrics
|
||||
// Any other return value is treated as a send failure: record the error and mark the
// connection as disconnecting.
send_metrics->UpdateError(true, error_code);
|
||||
state = ConnectionState::kDisconnecting;
|
||||
break;
|
||||
}
|
||||
}
|
||||
return total_send_bytes;
|
||||
}
|
||||
|
||||
int Connection::AddConnnectEventHandler() {
|
||||
return recv_event_loop->SetEventHandler(socket_fd, EPOLLIN | EPOLLHUP | EPOLLERR, NewConnectEventHandler,
|
||||
reinterpret_cast<void *>(this));
|
||||
|
|
|
@ -109,6 +109,9 @@ struct Connection {
|
|||
return !(that != nullptr && that->destination == destination && that->is_remote == is_remote);
|
||||
}
|
||||
|
||||
// Send all the messages in the message queue.
|
||||
int Flush();
|
||||
|
||||
// The socket used by this connection.
|
||||
int socket_fd;
|
||||
|
||||
|
|
|
@ -30,7 +30,7 @@
|
|||
namespace mindspore {
|
||||
namespace distributed {
|
||||
namespace rpc {
|
||||
using MessageHandler = std::function<void(const std::shared_ptr<MessageBase> &)>;
|
||||
using MessageHandler = std::function<std::shared_ptr<MessageBase>(const std::shared_ptr<MessageBase> &)>;
|
||||
using DeleteCallBack = void (*)(const std::string &from, const std::string &to);
|
||||
using ConnectionCallBack = void (*)(void *conn);
|
||||
|
||||
|
@ -57,6 +57,8 @@ static const int g_httpKmsgEnable = -1;
|
|||
using IntTypeMetrics = std::queue<int>;
|
||||
using StringTypeMetrics = std::queue<std::string>;
|
||||
|
||||
static const std::shared_ptr<MessageBase> NULL_MSG = nullptr;
|
||||
|
||||
// Server socket listen backlog.
|
||||
static const int SOCKET_LISTEN_BACKLOG = 2048;
|
||||
|
||||
|
|
|
@ -131,42 +131,6 @@ void OnAccept(int server, uint32_t events, void *arg) {
|
|||
tcpmgr->conn_pool_->AddConnection(conn);
|
||||
}
|
||||
|
||||
// Tries to send pending outbound data buffered in `conn`.
//
// conn: the connection whose send queue / in-flight message is processed; dereferenced without a
//       null check.
// Returns the accumulated body size of the message(s) fully sent during this call; 0 if nothing was
// completed (including the error path, which also switches conn->state to kDisconnecting).
//
// NOTE(review): the loop `break`s after the first fully sent message, so one call completes at most
// one message even if more are queued — confirm this is intended.
int DoSend(Connection *conn) {
|
||||
int total_send_bytes = 0;
|
||||
// Keep going while there is a queued message or an in-flight message with unsent bytes.
while (!conn->send_message_queue.empty() || conn->total_send_len != 0) {
|
||||
// Nothing in flight: take the next queued message.
// Presumably FillSendMessage prepares the in-flight send state (its definition is not shown here).
if (conn->total_send_len == 0) {
|
||||
conn->FillSendMessage(conn->send_message_queue.front(), conn->source, false);
|
||||
conn->send_message_queue.pop();
|
||||
}
|
||||
size_t sendLen = 0;
|
||||
// sendLen receives the number of bytes written by this attempt.
int retval = conn->socket_operation->SendMessage(conn, &conn->send_kernel_msg, conn->total_send_len, &sendLen);
|
||||
if (retval == IO_RW_OK && sendLen > 0) {
|
||||
// Progress was made; a partial write falls through and loops again.
conn->total_send_len -= sendLen;
|
||||
if (conn->total_send_len == 0) {
|
||||
// update metrics
|
||||
conn->send_metrics->UpdateError(false);
|
||||
|
||||
// The in-flight message is complete: update the buffer accounting and the return value,
// then release the message object.
conn->output_buffer_size -= conn->send_message->body.size();
|
||||
total_send_bytes += conn->send_message->body.size();
|
||||
delete conn->send_message;
|
||||
conn->send_message = nullptr;
|
||||
break;
|
||||
}
|
||||
} else if (retval == IO_RW_OK && sendLen == 0) {
|
||||
// EAGAIN
|
||||
MS_LOG(ERROR) << "Failed to send message and update the epoll event";
|
||||
// Subscribe to EPOLLOUT in addition to the read/hang-up/error events for this socket.
(void)conn->recv_event_loop->UpdateEpollEvent(conn->socket_fd, EPOLLOUT | EPOLLIN | EPOLLHUP | EPOLLERR);
|
||||
// NOTE(review): `continue` retries the send right away rather than waiting for the writable
// event — confirm this cannot spin while the socket stays unwritable.
continue;
|
||||
} else {
|
||||
// update metrics
|
||||
// Any other return value is treated as a send failure: record the error and mark the
// connection as disconnecting.
conn->send_metrics->UpdateError(true, conn->error_code);
|
||||
conn->state = ConnectionState::kDisconnecting;
|
||||
break;
|
||||
}
|
||||
}
|
||||
return total_send_bytes;
|
||||
}
|
||||
|
||||
// Stores a copy of `handler` in message_handler_, replacing any handler set previously.
void TCPComm::SetMessageHandler(const MessageHandler &handler) { message_handler_ = handler; }
|
||||
|
||||
bool TCPComm::Initialize() {
|
||||
|
@ -260,7 +224,7 @@ void TCPComm::EventCallBack(void *connection) {
|
|||
|
||||
if (conn->state == ConnectionState::kConnected) {
|
||||
conn->conn_mutex->lock();
|
||||
(void)DoSend(conn);
|
||||
(void)conn->Flush();
|
||||
conn->conn_mutex->unlock();
|
||||
} else if (conn->state == ConnectionState::kDisconnecting) {
|
||||
conn->conn_mutex->lock();
|
||||
|
@ -272,7 +236,7 @@ void TCPComm::WriteCallBack(void *connection) {
|
|||
Connection *conn = reinterpret_cast<Connection *>(connection);
|
||||
if (conn->state == ConnectionState::kConnected) {
|
||||
conn->conn_mutex->lock();
|
||||
(void)DoSend(conn);
|
||||
(void)conn->Flush();
|
||||
conn->conn_mutex->unlock();
|
||||
}
|
||||
}
|
||||
|
@ -370,7 +334,7 @@ ssize_t TCPComm::Send(MessageBase *msg, bool sync) {
|
|||
} else {
|
||||
(void)conn->send_message_queue.emplace(msg);
|
||||
}
|
||||
return DoSend(conn);
|
||||
return conn->Flush();
|
||||
};
|
||||
if (sync) {
|
||||
return task();
|
||||
|
|
|
@ -33,9 +33,6 @@ namespace rpc {
|
|||
// Event handler for new connecting request arrived.
|
||||
void OnAccept(int server, uint32_t events, void *arg);
|
||||
|
||||
// Send messages buffered in the connection.
|
||||
int DoSend(Connection *conn);
|
||||
|
||||
void DoDisconnect(int fd, Connection *conn, uint32_t error, int soError);
|
||||
|
||||
void ConnectedEventHandler(int fd, uint32_t events, void *context);
|
||||
|
@ -116,7 +113,6 @@ class TCPComm {
|
|||
std::shared_ptr<std::mutex> conn_mutex_;
|
||||
|
||||
friend void OnAccept(int server, uint32_t events, void *arg);
|
||||
friend int DoSend(Connection *conn);
|
||||
friend int DoConnect(const std::string &to, Connection *conn, ConnectionCallBack event_callback,
|
||||
ConnectionCallBack write_callback, ConnectionCallBack read_callback);
|
||||
};
|
||||
|
|
|
@ -20,6 +20,7 @@
|
|||
#include <utility>
|
||||
#include <functional>
|
||||
#include <condition_variable>
|
||||
#include "distributed/rpc/tcp/constants.h"
|
||||
#include "plugin/device/cpu/kernel/rpc/rpc_recv_kernel.h"
|
||||
|
||||
namespace mindspore {
|
||||
|
@ -130,7 +131,7 @@ void RecvActor::EraseInput(const OpContext<DeviceTensor> *context) {
|
|||
}
|
||||
}
|
||||
|
||||
void RecvActor::HandleMessage(const std::shared_ptr<MessageBase> &msg) {
|
||||
std::shared_ptr<MessageBase> RecvActor::HandleMessage(const std::shared_ptr<MessageBase> &msg) {
|
||||
// Block the message handler if the context is invalid.
|
||||
std::unique_lock<std::mutex> lock(context_mtx_);
|
||||
context_cv_.wait(lock, [this] { return is_context_valid_; });
|
||||
|
@ -138,9 +139,11 @@ void RecvActor::HandleMessage(const std::shared_ptr<MessageBase> &msg) {
|
|||
|
||||
MS_LOG(INFO) << "Rpc actor recv message for inter-process edge: " << inter_process_edge_name_;
|
||||
|
||||
MS_ERROR_IF_NULL_WO_RET_VAL(msg);
|
||||
MS_ERROR_IF_NULL_WO_RET_VAL(op_context_);
|
||||
if (msg == nullptr || op_context_ == nullptr) {
|
||||
return distributed::rpc::NULL_MSG;
|
||||
}
|
||||
ActorDispatcher::Send(GetAID(), &RecvActor::RunOpInterProcessData, msg, op_context_);
|
||||
return distributed::rpc::NULL_MSG;
|
||||
}
|
||||
} // namespace runtime
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -66,7 +66,7 @@ class RecvActor : public RpcActor {
|
|||
|
||||
private:
|
||||
// The message callback of the tcp server.
|
||||
void HandleMessage(const std::shared_ptr<MessageBase> &msg);
|
||||
std::shared_ptr<MessageBase> HandleMessage(const std::shared_ptr<MessageBase> &msg);
|
||||
|
||||
// The network address of this recv actor. It's generated automatically by rpc module.
|
||||
std::string ip_;
|
||||
|
|
|
@ -154,7 +154,10 @@ TEST_F(TCPTest, SendOneMessage) {
|
|||
bool ret = server->Initialize(server_url);
|
||||
ASSERT_TRUE(ret);
|
||||
|
||||
server->SetMessageHandler([](const std::shared_ptr<MessageBase> &message) -> void { IncrDataMsgNum(1); });
|
||||
server->SetMessageHandler([](const std::shared_ptr<MessageBase> &message) -> std::shared_ptr<MessageBase> {
|
||||
IncrDataMsgNum(1);
|
||||
return NULL_MSG;
|
||||
});
|
||||
|
||||
// Start the tcp client.
|
||||
auto client_url = "127.0.0.1:1234";
|
||||
|
@ -193,7 +196,10 @@ TEST_F(TCPTest, SendTwoMessages) {
|
|||
bool ret = server->Initialize(server_url);
|
||||
ASSERT_TRUE(ret);
|
||||
|
||||
server->SetMessageHandler([](const std::shared_ptr<MessageBase> &message) -> void { IncrDataMsgNum(1); });
|
||||
server->SetMessageHandler([](const std::shared_ptr<MessageBase> &message) -> std::shared_ptr<MessageBase> {
|
||||
IncrDataMsgNum(1);
|
||||
return NULL_MSG;
|
||||
});
|
||||
|
||||
// Start the tcp client.
|
||||
auto client_url = "127.0.0.1:1234";
|
||||
|
@ -245,7 +251,10 @@ TEST_F(TCPTest, SendSyncMessage) {
|
|||
bool ret = server->Initialize(server_url);
|
||||
ASSERT_TRUE(ret);
|
||||
|
||||
server->SetMessageHandler([](const std::shared_ptr<MessageBase> &message) -> void { IncrDataMsgNum(1); });
|
||||
server->SetMessageHandler([](const std::shared_ptr<MessageBase> &message) -> std::shared_ptr<MessageBase> {
|
||||
IncrDataMsgNum(1);
|
||||
return NULL_MSG;
|
||||
});
|
||||
|
||||
// Start the tcp client.
|
||||
auto client_url = "127.0.0.1:1234";
|
||||
|
@ -283,7 +292,10 @@ TEST_F(TCPTest, SendLargeMessages) {
|
|||
bool ret = server->Initialize();
|
||||
ASSERT_TRUE(ret);
|
||||
|
||||
server->SetMessageHandler([](const std::shared_ptr<MessageBase> &message) -> void { IncrDataMsgNum(1); });
|
||||
server->SetMessageHandler([](const std::shared_ptr<MessageBase> &message) -> std::shared_ptr<MessageBase> {
|
||||
IncrDataMsgNum(1);
|
||||
return NULL_MSG;
|
||||
});
|
||||
|
||||
// Start the tcp client.
|
||||
auto client_url = "127.0.0.1:1234";
|
||||
|
@ -336,7 +348,10 @@ TEST_F(TCPTest, CreateManyConnectionPairs) {
|
|||
auto port = server->GetPort();
|
||||
ASSERT_TRUE(ret);
|
||||
|
||||
server->SetMessageHandler([](const std::shared_ptr<MessageBase> &message) -> void { IncrDataMsgNum(1); });
|
||||
server->SetMessageHandler([](const std::shared_ptr<MessageBase> &message) -> std::shared_ptr<MessageBase> {
|
||||
IncrDataMsgNum(1);
|
||||
return NULL_MSG;
|
||||
});
|
||||
|
||||
// Start the tcp client.
|
||||
auto client_url = "127.0.0.1:1234";
|
||||
|
|
Loading…
Reference in New Issue