support sync and async message sending mode for tcp client

This commit is contained in:
Parallels 2022-03-12 00:40:27 +08:00
parent b3a30e2a68
commit 6f31f270e9
12 changed files with 130 additions and 56 deletions

View File

@ -62,7 +62,7 @@ bool ComputeGraphNode::Register() {
auto message = CreateMessage(server_url, content);
MS_EXCEPTION_IF_NULL(message);
tcp_client_->Send(std::move(message));
tcp_client_->SendSync(std::move(message));
return true;
}
@ -77,7 +77,7 @@ bool ComputeGraphNode::Heartbeat() {
auto message = CreateMessage(server_url, content);
MS_EXCEPTION_IF_NULL(message);
tcp_client_->Send(std::move(message));
tcp_client_->SendSync(std::move(message));
return true;
}
} // namespace topology

View File

@ -326,12 +326,9 @@ void Connection::FillSendMessage(MessageBase *msg, const std::string &advertiseU
if (msg->type == MessageBase::Type::KMSG) {
if (!isHttpKmsg) {
send_to = msg->to;
send_from = msg->from.Name() + "@" + advertiseUrl;
send_from = msg->from;
send_msg_header.name_len = htonl(static_cast<uint32_t>(msg->name.size()));
send_msg_header.to_len = htonl(static_cast<uint32_t>(send_to.size()));
send_msg_header.from_len = htonl(static_cast<uint32_t>(send_from.size()));
send_msg_header.body_len = htonl(static_cast<uint32_t>(msg->body.size()));
FillMessageHeader(*msg, &send_msg_header);
send_io_vec[index].iov_base = &send_msg_header;
send_io_vec[index].iov_len = sizeof(send_msg_header);
@ -431,7 +428,6 @@ int Connection::AddConnnectEventHandler() {
}
bool Connection::ParseMessage() {
std::string magic_id = "";
int retval = 0;
uint32_t recvLen = 0;
char *recvBuf = nullptr;
@ -454,7 +450,7 @@ bool Connection::ParseMessage() {
if (strncmp(recv_msg_header.magic, RPC_MAGICID, sizeof(RPC_MAGICID) - 1) != 0) {
MS_LOG(ERROR) << "Failed to check magicid, RPC_MAGICID: " << RPC_MAGICID
<< ", recv magic_id: " << magic_id.c_str();
<< ", recv magic_id: " << recv_msg_header.magic;
state = ConnectionState::kDisconnecting;
return false;
}

View File

@ -30,27 +30,6 @@
namespace mindspore {
namespace distributed {
namespace rpc {
/*
* The MessageHeader contains the stats info about the message body.
*/
struct MessageHeader {
MessageHeader() {
for (unsigned int i = 0; i < BUSMAGIC_LEN; ++i) {
if (i < sizeof(RPC_MAGICID) - 1) {
magic[i] = RPC_MAGICID[i];
} else {
magic[i] = '\0';
}
}
}
char magic[BUSMAGIC_LEN];
uint32_t name_len{0};
uint32_t to_len{0};
uint32_t from_len{0};
uint32_t body_len{0};
};
/*
* The SendMetrics is responsible for collecting metrics when sending data through a connection.
*/

View File

@ -17,6 +17,7 @@
#ifndef MINDSPORE_CCSRC_DISTRIBUTED_RPC_TCP_CONSTANTS_H_
#define MINDSPORE_CCSRC_DISTRIBUTED_RPC_TCP_CONSTANTS_H_
#include <arpa/inet.h>
#include <string>
#include <csignal>
#include <queue>
@ -85,6 +86,45 @@ constexpr int IP_LEN_MAX = 128;
// Kill the process for safe exiting.
inline void KillProcess(const std::string &ret) { raise(SIGKILL); }
/*
* The MessageHeader contains the stats info about the message body.
*/
struct MessageHeader {
MessageHeader() {
for (unsigned int i = 0; i < BUSMAGIC_LEN; ++i) {
if (i < sizeof(RPC_MAGICID) - 1) {
magic[i] = RPC_MAGICID[i];
} else {
magic[i] = '\0';
}
}
}
char magic[BUSMAGIC_LEN];
uint32_t name_len{0};
uint32_t to_len{0};
uint32_t from_len{0};
uint32_t body_len{0};
};
// Fill the message header using the given message.
__attribute__((unused)) static void FillMessageHeader(const MessageBase &message, MessageHeader *header) {
std::string send_to = message.to;
std::string send_from = message.from;
header->name_len = htonl(static_cast<uint32_t>(message.name.size()));
header->to_len = htonl(static_cast<uint32_t>(send_to.size()));
header->from_len = htonl(static_cast<uint32_t>(send_from.size()));
header->body_len = htonl(static_cast<uint32_t>(message.body.size()));
}
// Compute and return the byte size of the whole message.
__attribute__((unused)) static size_t GetMessageSize(const MessageBase &message) {
std::string send_to = message.to;
std::string send_from = message.from;
size_t size = message.name.size() + send_to.size() + send_from.size() + message.body.size() + sizeof(MessageHeader);
return size;
}
#define RPC_ASSERT(expression) \
do { \
if (!(expression)) { \

View File

@ -129,7 +129,7 @@ void EventLoop::ReleaseResource() {
}
}
int EventLoop::AddTask(std::function<void()> &&task) {
int EventLoop::AddTask(std::function<int()> &&task) {
// put func to the queue
task_queue_mutex_.lock();
(void)task_queue_.emplace(std::move(task));

View File

@ -68,7 +68,7 @@ class EventLoop {
// Add task (eg. send message, reconnect etc.) to task queue of the event loop.
// These tasks are executed asynchronously.
int AddTask(std::function<void()> &&task);
int AddTask(std::function<int()> &&task);
// Set event handler for events(read/write/..) occurred on the socket fd.
int SetEventHandler(int sock_fd, uint32_t events, EventHandler handler, void *data);

View File

@ -69,11 +69,13 @@ bool TCPClient::Disconnect(const std::string &dst_url, size_t timeout_in_sec) {
return rt;
}
int TCPClient::Send(std::unique_ptr<MessageBase> &&msg) {
int TCPClient::SendSync(std::unique_ptr<MessageBase> &&msg) {
int rt = -1;
rt = tcp_comm_->Send(msg.release());
rt = tcp_comm_->Send(msg.release(), true);
return rt;
}
void TCPClient::SendAsync(std::unique_ptr<MessageBase> &&msg) { (void)tcp_comm_->Send(msg.release(), false); }
} // namespace rpc
} // namespace distributed
} // namespace mindspore

View File

@ -41,8 +41,11 @@ class TCPClient {
// Disconnect from the specified server.
bool Disconnect(const std::string &dst_url, size_t timeout_in_sec = 5);
// Send the message from the source to the destination.
int Send(std::unique_ptr<MessageBase> &&msg);
// Send the message from the source to the destination synchronously and return the byte size by this method call.
int SendSync(std::unique_ptr<MessageBase> &&msg);
// Send the message from the source to the destination asynchronously.
void SendAsync(std::unique_ptr<MessageBase> &&msg);
private:
// The basic TCP communication component used by the client.

View File

@ -130,7 +130,8 @@ void OnAccept(int server, uint32_t events, void *arg) {
}
}
void DoSend(Connection *conn) {
int DoSend(Connection *conn) {
int total_send_bytes = 0;
while (!conn->send_message_queue.empty() || conn->total_send_len != 0) {
if (conn->total_send_len == 0) {
conn->FillSendMessage(conn->send_message_queue.front(), conn->source, false);
@ -139,6 +140,7 @@ void DoSend(Connection *conn) {
int sendLen = conn->socket_operation->SendMessage(conn, &conn->send_kernel_msg, &conn->total_send_len);
if (sendLen > 0) {
total_send_bytes += sendLen;
if (conn->total_send_len == 0) {
// update metrics
conn->send_metrics->UpdateError(false);
@ -158,6 +160,7 @@ void DoSend(Connection *conn) {
break;
}
}
return total_send_bytes;
}
TCPComm::~TCPComm() {
@ -354,8 +357,8 @@ void TCPComm::DropMessage(MessageBase *msg) {
ptr = nullptr;
}
int TCPComm::Send(MessageBase *msg) {
return send_event_loop_->AddTask([msg, this] {
int TCPComm::Send(MessageBase *msg, bool sync) {
auto task = [msg, this] {
std::lock_guard<std::mutex> lock(*conn_mutex_);
// Search connection by the target address
Connection *conn = conn_pool_->FindConnection(msg->to.Url());
@ -363,7 +366,8 @@ int TCPComm::Send(MessageBase *msg) {
MS_LOG(ERROR) << "Can not found remote link and send fail name: " << msg->name.c_str()
<< ", from: " << msg->from.Url().c_str() << ", to: " << msg->to.Url().c_str();
DropMessage(msg);
return;
int error_no = -1;
return error_no;
}
if (conn->send_message_queue.size() >= SENDMSG_QUEUELEN) {
@ -371,7 +375,8 @@ int TCPComm::Send(MessageBase *msg) {
<< ") and the name of dropped message is: " << msg->name.c_str() << ", fd: " << conn->socket_fd
<< ", to: " << conn->destination.c_str();
DropMessage(msg);
return;
int error_no = -1;
return error_no;
}
if (conn->state != ConnectionState::kConnected) {
@ -379,7 +384,8 @@ int TCPComm::Send(MessageBase *msg) {
<< " and the name of dropped message is: " << msg->name.c_str() << ", fd: " << conn->socket_fd
<< ", to: " << conn->destination.c_str();
DropMessage(msg);
return;
int error_no = -1;
return error_no;
}
if (conn->total_send_len == 0) {
@ -387,8 +393,13 @@ int TCPComm::Send(MessageBase *msg) {
} else {
(void)conn->send_message_queue.emplace(msg);
}
DoSend(conn);
});
return DoSend(conn);
};
if (sync) {
return task();
} else {
return send_event_loop_->AddTask(task);
}
}
void TCPComm::Connect(const std::string &dst_url) {
@ -403,7 +414,7 @@ void TCPComm::Connect(const std::string &dst_url) {
conn = new (std::nothrow) Connection();
if (conn == nullptr) {
MS_LOG(ERROR) << "Failed to create new connection and link fail destination: " << dst_url;
return;
return false;
}
conn->source = url_;
conn->destination = dst_url;
@ -418,12 +429,12 @@ void TCPComm::Connect(const std::string &dst_url) {
SocketAddress addr;
if (!SocketOperation::GetSockAddr(dst_url, &addr)) {
MS_LOG(ERROR) << "Failed to get socket address to dest url " << dst_url;
return;
return false;
}
int sock_fd = SocketOperation::CreateSocket(addr.sa.sa_family);
if (sock_fd < 0) {
MS_LOG(ERROR) << "Failed to create client tcp socket to dest url " << dst_url;
return;
return false;
}
conn->socket_fd = sock_fd;
@ -439,12 +450,13 @@ void TCPComm::Connect(const std::string &dst_url) {
conn->socket_operation = nullptr;
}
delete conn;
return;
return false;
}
conn_pool_->AddConnection(conn);
}
conn_pool_->AddConnInfo(conn->socket_fd, dst_url, nullptr);
MS_LOG(INFO) << "Connected to destination: " << dst_url;
return true;
});
}
@ -460,6 +472,7 @@ void TCPComm::Disconnect(const std::string &dst_url) {
(void)recv_event_loop_->AddTask([dst_url, this] {
std::lock_guard<std::mutex> lock(*conn_mutex_);
conn_pool_->DeleteConnection(dst_url);
return true;
});
}

View File

@ -34,7 +34,7 @@ namespace rpc {
void OnAccept(int server, uint32_t events, void *arg);
// Send messages buffered in the connection.
void DoSend(Connection *conn);
int DoSend(Connection *conn);
void DoDisconnect(int fd, Connection *conn, uint32_t error, int soError);
@ -65,7 +65,8 @@ class TCPComm {
void Disconnect(const std::string &dst_url);
// Send the message from the source to the destination.
int Send(MessageBase *msg);
// The flag sync means if the message is sent directly or added to the task queue.
int Send(MessageBase *msg, bool sync = false);
// Set the message processing handler.
void SetMessageHandler(MessageHandler handler);
@ -115,7 +116,7 @@ class TCPComm {
std::shared_ptr<std::mutex> conn_mutex_;
friend void OnAccept(int server, uint32_t events, void *arg);
friend void DoSend(Connection *conn);
friend int DoSend(Connection *conn);
friend int DoConnect(const std::string &to, Connection *conn, ConnectionCallBack event_callback,
ConnectionCallBack write_callback, ConnectionCallBack read_callback);
};

View File

@ -69,7 +69,7 @@ void SendActor::SendOutput(OpContext<DeviceTensor> *const context) {
std::string peer_server_url = peer.second;
auto message = BuildRpcMessage(send_output, peer_server_url);
MS_ERROR_IF_NULL_WO_RET_VAL(message);
client_->Send(std::move(message));
client_->SendAsync(std::move(message));
}
}

View File

@ -26,6 +26,7 @@
#define private public
#include "distributed/rpc/tcp/tcp_server.h"
#include "distributed/rpc/tcp/tcp_client.h"
#include "distributed/rpc/tcp/constants.h"
#include "common/common_test.h"
namespace mindspore {
@ -165,7 +166,7 @@ TEST_F(TCPTest, SendOneMessage) {
// Send the message.
client->Connect(server_url);
client->Send(std::move(message));
client->SendAsync(std::move(message));
// Wait timeout: 5s
WaitForDataMsg(1, 5);
@ -182,7 +183,7 @@ TEST_F(TCPTest, SendOneMessage) {
/// Feature: test sending two message continuously.
/// Description: start a socket server and send two normal message to it.
/// Expectation: the server received the two messages sented from client.
TEST_F(TCPTest, sendTwoMessages) {
TEST_F(TCPTest, SendTwoMessages) {
Init();
// Start the tcp server.
@ -205,8 +206,8 @@ TEST_F(TCPTest, sendTwoMessages) {
// Send messages.
client->Connect(server_url);
client->Send(std::move(message1));
client->Send(std::move(message2));
client->SendAsync(std::move(message1));
client->SendAsync(std::move(message2));
// Wait timeout: 5s
WaitForDataMsg(2, 5);
@ -230,6 +231,45 @@ TEST_F(TCPTest, StartServerWithRandomPort) {
EXPECT_LT(0, port);
server->Finalize();
}
/// Feature: test send the message synchronously.
/// Description: start a socket server and send the message synchronously.
/// Expectation: the number of bytes sent could be got synchronously.
TEST_F(TCPTest, SendSyncMessage) {
Init();
// Start the tcp server.
auto server_url = "127.0.0.1:8081";
std::unique_ptr<TCPServer> server = std::make_unique<TCPServer>();
bool ret = server->Initialize(server_url);
ASSERT_TRUE(ret);
server->SetMessageHandler([](const std::shared_ptr<MessageBase> &message) -> void { IncrDataMsgNum(1); });
// Start the tcp client.
auto client_url = "127.0.0.1:1234";
std::unique_ptr<TCPClient> client = std::make_unique<TCPClient>();
ret = client->Initialize();
ASSERT_TRUE(ret);
// Create the message.
auto message = CreateMessage(server_url, client_url);
auto msg_size = GetMessageSize(*message);
// Send the message.
client->Connect(server_url);
auto bytes_num = client->SendSync(std::move(message));
EXPECT_EQ(msg_size, bytes_num);
WaitForDataMsg(1, 5);
EXPECT_EQ(1, GetDataMsgNum());
// Destroy
client->Disconnect(server_url);
client->Finalize();
server->Finalize();
}
} // namespace rpc
} // namespace distributed
} // namespace mindspore