forked from mindspore-Ecosystem/mindspore
support sync and async message sending mode for tcp client
This commit is contained in:
parent
b3a30e2a68
commit
6f31f270e9
|
@ -62,7 +62,7 @@ bool ComputeGraphNode::Register() {
|
|||
auto message = CreateMessage(server_url, content);
|
||||
MS_EXCEPTION_IF_NULL(message);
|
||||
|
||||
tcp_client_->Send(std::move(message));
|
||||
tcp_client_->SendSync(std::move(message));
|
||||
return true;
|
||||
}
|
||||
|
||||
|
@ -77,7 +77,7 @@ bool ComputeGraphNode::Heartbeat() {
|
|||
auto message = CreateMessage(server_url, content);
|
||||
MS_EXCEPTION_IF_NULL(message);
|
||||
|
||||
tcp_client_->Send(std::move(message));
|
||||
tcp_client_->SendSync(std::move(message));
|
||||
return true;
|
||||
}
|
||||
} // namespace topology
|
||||
|
|
|
@ -326,12 +326,9 @@ void Connection::FillSendMessage(MessageBase *msg, const std::string &advertiseU
|
|||
if (msg->type == MessageBase::Type::KMSG) {
|
||||
if (!isHttpKmsg) {
|
||||
send_to = msg->to;
|
||||
send_from = msg->from.Name() + "@" + advertiseUrl;
|
||||
send_from = msg->from;
|
||||
|
||||
send_msg_header.name_len = htonl(static_cast<uint32_t>(msg->name.size()));
|
||||
send_msg_header.to_len = htonl(static_cast<uint32_t>(send_to.size()));
|
||||
send_msg_header.from_len = htonl(static_cast<uint32_t>(send_from.size()));
|
||||
send_msg_header.body_len = htonl(static_cast<uint32_t>(msg->body.size()));
|
||||
FillMessageHeader(*msg, &send_msg_header);
|
||||
|
||||
send_io_vec[index].iov_base = &send_msg_header;
|
||||
send_io_vec[index].iov_len = sizeof(send_msg_header);
|
||||
|
@ -431,7 +428,6 @@ int Connection::AddConnnectEventHandler() {
|
|||
}
|
||||
|
||||
bool Connection::ParseMessage() {
|
||||
std::string magic_id = "";
|
||||
int retval = 0;
|
||||
uint32_t recvLen = 0;
|
||||
char *recvBuf = nullptr;
|
||||
|
@ -454,7 +450,7 @@ bool Connection::ParseMessage() {
|
|||
|
||||
if (strncmp(recv_msg_header.magic, RPC_MAGICID, sizeof(RPC_MAGICID) - 1) != 0) {
|
||||
MS_LOG(ERROR) << "Failed to check magicid, RPC_MAGICID: " << RPC_MAGICID
|
||||
<< ", recv magic_id: " << magic_id.c_str();
|
||||
<< ", recv magic_id: " << recv_msg_header.magic;
|
||||
state = ConnectionState::kDisconnecting;
|
||||
return false;
|
||||
}
|
||||
|
|
|
@ -30,27 +30,6 @@
|
|||
namespace mindspore {
|
||||
namespace distributed {
|
||||
namespace rpc {
|
||||
/*
|
||||
* The MessageHeader contains the stats info about the message body.
|
||||
*/
|
||||
struct MessageHeader {
|
||||
MessageHeader() {
|
||||
for (unsigned int i = 0; i < BUSMAGIC_LEN; ++i) {
|
||||
if (i < sizeof(RPC_MAGICID) - 1) {
|
||||
magic[i] = RPC_MAGICID[i];
|
||||
} else {
|
||||
magic[i] = '\0';
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
char magic[BUSMAGIC_LEN];
|
||||
uint32_t name_len{0};
|
||||
uint32_t to_len{0};
|
||||
uint32_t from_len{0};
|
||||
uint32_t body_len{0};
|
||||
};
|
||||
|
||||
/*
|
||||
* The SendMetrics is responsible for collecting metrics when sending data through a connection.
|
||||
*/
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
#ifndef MINDSPORE_CCSRC_DISTRIBUTED_RPC_TCP_CONSTANTS_H_
|
||||
#define MINDSPORE_CCSRC_DISTRIBUTED_RPC_TCP_CONSTANTS_H_
|
||||
|
||||
#include <arpa/inet.h>
|
||||
#include <string>
|
||||
#include <csignal>
|
||||
#include <queue>
|
||||
|
@ -85,6 +86,45 @@ constexpr int IP_LEN_MAX = 128;
|
|||
// Kill the process for safe exiting by raising SIGKILL on the current process.
// The `ret` argument is accepted but not used; the result of raise() is discarded.
inline void KillProcess(const std::string &ret) {
  (void)ret;
  (void)raise(SIGKILL);
}
|
||||
|
||||
/*
|
||||
* The MessageHeader contains the stats info about the message body.
|
||||
*/
|
||||
struct MessageHeader {
|
||||
MessageHeader() {
|
||||
for (unsigned int i = 0; i < BUSMAGIC_LEN; ++i) {
|
||||
if (i < sizeof(RPC_MAGICID) - 1) {
|
||||
magic[i] = RPC_MAGICID[i];
|
||||
} else {
|
||||
magic[i] = '\0';
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
char magic[BUSMAGIC_LEN];
|
||||
uint32_t name_len{0};
|
||||
uint32_t to_len{0};
|
||||
uint32_t from_len{0};
|
||||
uint32_t body_len{0};
|
||||
};
|
||||
|
||||
// Fill the message header using the given message.
|
||||
__attribute__((unused)) static void FillMessageHeader(const MessageBase &message, MessageHeader *header) {
|
||||
std::string send_to = message.to;
|
||||
std::string send_from = message.from;
|
||||
header->name_len = htonl(static_cast<uint32_t>(message.name.size()));
|
||||
header->to_len = htonl(static_cast<uint32_t>(send_to.size()));
|
||||
header->from_len = htonl(static_cast<uint32_t>(send_from.size()));
|
||||
header->body_len = htonl(static_cast<uint32_t>(message.body.size()));
|
||||
}
|
||||
|
||||
// Compute and return the byte size of the whole message.
|
||||
__attribute__((unused)) static size_t GetMessageSize(const MessageBase &message) {
|
||||
std::string send_to = message.to;
|
||||
std::string send_from = message.from;
|
||||
size_t size = message.name.size() + send_to.size() + send_from.size() + message.body.size() + sizeof(MessageHeader);
|
||||
return size;
|
||||
}
|
||||
|
||||
#define RPC_ASSERT(expression) \
|
||||
do { \
|
||||
if (!(expression)) { \
|
||||
|
|
|
@ -129,7 +129,7 @@ void EventLoop::ReleaseResource() {
|
|||
}
|
||||
}
|
||||
|
||||
int EventLoop::AddTask(std::function<void()> &&task) {
|
||||
int EventLoop::AddTask(std::function<int()> &&task) {
|
||||
// put func to the queue
|
||||
task_queue_mutex_.lock();
|
||||
(void)task_queue_.emplace(std::move(task));
|
||||
|
|
|
@ -68,7 +68,7 @@ class EventLoop {
|
|||
|
||||
// Add task (eg. send message, reconnect etc.) to task queue of the event loop.
|
||||
// These tasks are executed asynchronously.
|
||||
int AddTask(std::function<void()> &&task);
|
||||
int AddTask(std::function<int()> &&task);
|
||||
|
||||
// Set event handler for events(read/write/..) occurred on the socket fd.
|
||||
int SetEventHandler(int sock_fd, uint32_t events, EventHandler handler, void *data);
|
||||
|
|
|
@ -69,11 +69,13 @@ bool TCPClient::Disconnect(const std::string &dst_url, size_t timeout_in_sec) {
|
|||
return rt;
|
||||
}
|
||||
|
||||
int TCPClient::Send(std::unique_ptr<MessageBase> &&msg) {
|
||||
int TCPClient::SendSync(std::unique_ptr<MessageBase> &&msg) {
|
||||
int rt = -1;
|
||||
rt = tcp_comm_->Send(msg.release());
|
||||
rt = tcp_comm_->Send(msg.release(), true);
|
||||
return rt;
|
||||
}
|
||||
|
||||
void TCPClient::SendAsync(std::unique_ptr<MessageBase> &&msg) { (void)tcp_comm_->Send(msg.release(), false); }
|
||||
} // namespace rpc
|
||||
} // namespace distributed
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -41,8 +41,11 @@ class TCPClient {
|
|||
// Disconnect from the specified server.
|
||||
bool Disconnect(const std::string &dst_url, size_t timeout_in_sec = 5);
|
||||
|
||||
// Send the message from the source to the destination.
|
||||
int Send(std::unique_ptr<MessageBase> &&msg);
|
||||
// Send the message from the source to the destination synchronously and return the number of bytes sent by this call.
|
||||
int SendSync(std::unique_ptr<MessageBase> &&msg);
|
||||
|
||||
// Send the message from the source to the destination asynchronously.
|
||||
void SendAsync(std::unique_ptr<MessageBase> &&msg);
|
||||
|
||||
private:
|
||||
// The basic TCP communication component used by the client.
|
||||
|
|
|
@ -130,7 +130,8 @@ void OnAccept(int server, uint32_t events, void *arg) {
|
|||
}
|
||||
}
|
||||
|
||||
void DoSend(Connection *conn) {
|
||||
int DoSend(Connection *conn) {
|
||||
int total_send_bytes = 0;
|
||||
while (!conn->send_message_queue.empty() || conn->total_send_len != 0) {
|
||||
if (conn->total_send_len == 0) {
|
||||
conn->FillSendMessage(conn->send_message_queue.front(), conn->source, false);
|
||||
|
@ -139,6 +140,7 @@ void DoSend(Connection *conn) {
|
|||
|
||||
int sendLen = conn->socket_operation->SendMessage(conn, &conn->send_kernel_msg, &conn->total_send_len);
|
||||
if (sendLen > 0) {
|
||||
total_send_bytes += sendLen;
|
||||
if (conn->total_send_len == 0) {
|
||||
// update metrics
|
||||
conn->send_metrics->UpdateError(false);
|
||||
|
@ -158,6 +160,7 @@ void DoSend(Connection *conn) {
|
|||
break;
|
||||
}
|
||||
}
|
||||
return total_send_bytes;
|
||||
}
|
||||
|
||||
TCPComm::~TCPComm() {
|
||||
|
@ -354,8 +357,8 @@ void TCPComm::DropMessage(MessageBase *msg) {
|
|||
ptr = nullptr;
|
||||
}
|
||||
|
||||
int TCPComm::Send(MessageBase *msg) {
|
||||
return send_event_loop_->AddTask([msg, this] {
|
||||
int TCPComm::Send(MessageBase *msg, bool sync) {
|
||||
auto task = [msg, this] {
|
||||
std::lock_guard<std::mutex> lock(*conn_mutex_);
|
||||
// Search connection by the target address
|
||||
Connection *conn = conn_pool_->FindConnection(msg->to.Url());
|
||||
|
@ -363,7 +366,8 @@ int TCPComm::Send(MessageBase *msg) {
|
|||
MS_LOG(ERROR) << "Can not found remote link and send fail name: " << msg->name.c_str()
|
||||
<< ", from: " << msg->from.Url().c_str() << ", to: " << msg->to.Url().c_str();
|
||||
DropMessage(msg);
|
||||
return;
|
||||
int error_no = -1;
|
||||
return error_no;
|
||||
}
|
||||
|
||||
if (conn->send_message_queue.size() >= SENDMSG_QUEUELEN) {
|
||||
|
@ -371,7 +375,8 @@ int TCPComm::Send(MessageBase *msg) {
|
|||
<< ") and the name of dropped message is: " << msg->name.c_str() << ", fd: " << conn->socket_fd
|
||||
<< ", to: " << conn->destination.c_str();
|
||||
DropMessage(msg);
|
||||
return;
|
||||
int error_no = -1;
|
||||
return error_no;
|
||||
}
|
||||
|
||||
if (conn->state != ConnectionState::kConnected) {
|
||||
|
@ -379,7 +384,8 @@ int TCPComm::Send(MessageBase *msg) {
|
|||
<< " and the name of dropped message is: " << msg->name.c_str() << ", fd: " << conn->socket_fd
|
||||
<< ", to: " << conn->destination.c_str();
|
||||
DropMessage(msg);
|
||||
return;
|
||||
int error_no = -1;
|
||||
return error_no;
|
||||
}
|
||||
|
||||
if (conn->total_send_len == 0) {
|
||||
|
@ -387,8 +393,13 @@ int TCPComm::Send(MessageBase *msg) {
|
|||
} else {
|
||||
(void)conn->send_message_queue.emplace(msg);
|
||||
}
|
||||
DoSend(conn);
|
||||
});
|
||||
return DoSend(conn);
|
||||
};
|
||||
if (sync) {
|
||||
return task();
|
||||
} else {
|
||||
return send_event_loop_->AddTask(task);
|
||||
}
|
||||
}
|
||||
|
||||
void TCPComm::Connect(const std::string &dst_url) {
|
||||
|
@ -403,7 +414,7 @@ void TCPComm::Connect(const std::string &dst_url) {
|
|||
conn = new (std::nothrow) Connection();
|
||||
if (conn == nullptr) {
|
||||
MS_LOG(ERROR) << "Failed to create new connection and link fail destination: " << dst_url;
|
||||
return;
|
||||
return false;
|
||||
}
|
||||
conn->source = url_;
|
||||
conn->destination = dst_url;
|
||||
|
@ -418,12 +429,12 @@ void TCPComm::Connect(const std::string &dst_url) {
|
|||
SocketAddress addr;
|
||||
if (!SocketOperation::GetSockAddr(dst_url, &addr)) {
|
||||
MS_LOG(ERROR) << "Failed to get socket address to dest url " << dst_url;
|
||||
return;
|
||||
return false;
|
||||
}
|
||||
int sock_fd = SocketOperation::CreateSocket(addr.sa.sa_family);
|
||||
if (sock_fd < 0) {
|
||||
MS_LOG(ERROR) << "Failed to create client tcp socket to dest url " << dst_url;
|
||||
return;
|
||||
return false;
|
||||
}
|
||||
|
||||
conn->socket_fd = sock_fd;
|
||||
|
@ -439,12 +450,13 @@ void TCPComm::Connect(const std::string &dst_url) {
|
|||
conn->socket_operation = nullptr;
|
||||
}
|
||||
delete conn;
|
||||
return;
|
||||
return false;
|
||||
}
|
||||
conn_pool_->AddConnection(conn);
|
||||
}
|
||||
conn_pool_->AddConnInfo(conn->socket_fd, dst_url, nullptr);
|
||||
MS_LOG(INFO) << "Connected to destination: " << dst_url;
|
||||
return true;
|
||||
});
|
||||
}
|
||||
|
||||
|
@ -460,6 +472,7 @@ void TCPComm::Disconnect(const std::string &dst_url) {
|
|||
(void)recv_event_loop_->AddTask([dst_url, this] {
|
||||
std::lock_guard<std::mutex> lock(*conn_mutex_);
|
||||
conn_pool_->DeleteConnection(dst_url);
|
||||
return true;
|
||||
});
|
||||
}
|
||||
|
||||
|
|
|
@ -34,7 +34,7 @@ namespace rpc {
|
|||
void OnAccept(int server, uint32_t events, void *arg);
|
||||
|
||||
// Send messages buffered in the connection.
|
||||
void DoSend(Connection *conn);
|
||||
int DoSend(Connection *conn);
|
||||
|
||||
void DoDisconnect(int fd, Connection *conn, uint32_t error, int soError);
|
||||
|
||||
|
@ -65,7 +65,8 @@ class TCPComm {
|
|||
void Disconnect(const std::string &dst_url);
|
||||
|
||||
// Send the message from the source to the destination.
|
||||
int Send(MessageBase *msg);
|
||||
// The sync flag selects whether the message is sent directly (true) or added to the task queue (false).
|
||||
int Send(MessageBase *msg, bool sync = false);
|
||||
|
||||
// Set the message processing handler.
|
||||
void SetMessageHandler(MessageHandler handler);
|
||||
|
@ -115,7 +116,7 @@ class TCPComm {
|
|||
std::shared_ptr<std::mutex> conn_mutex_;
|
||||
|
||||
friend void OnAccept(int server, uint32_t events, void *arg);
|
||||
friend void DoSend(Connection *conn);
|
||||
friend int DoSend(Connection *conn);
|
||||
friend int DoConnect(const std::string &to, Connection *conn, ConnectionCallBack event_callback,
|
||||
ConnectionCallBack write_callback, ConnectionCallBack read_callback);
|
||||
};
|
||||
|
|
|
@ -69,7 +69,7 @@ void SendActor::SendOutput(OpContext<DeviceTensor> *const context) {
|
|||
std::string peer_server_url = peer.second;
|
||||
auto message = BuildRpcMessage(send_output, peer_server_url);
|
||||
MS_ERROR_IF_NULL_WO_RET_VAL(message);
|
||||
client_->Send(std::move(message));
|
||||
client_->SendAsync(std::move(message));
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -26,6 +26,7 @@
|
|||
#define private public
|
||||
#include "distributed/rpc/tcp/tcp_server.h"
|
||||
#include "distributed/rpc/tcp/tcp_client.h"
|
||||
#include "distributed/rpc/tcp/constants.h"
|
||||
#include "common/common_test.h"
|
||||
|
||||
namespace mindspore {
|
||||
|
@ -165,7 +166,7 @@ TEST_F(TCPTest, SendOneMessage) {
|
|||
|
||||
// Send the message.
|
||||
client->Connect(server_url);
|
||||
client->Send(std::move(message));
|
||||
client->SendAsync(std::move(message));
|
||||
|
||||
// Wait timeout: 5s
|
||||
WaitForDataMsg(1, 5);
|
||||
|
@ -182,7 +183,7 @@ TEST_F(TCPTest, SendOneMessage) {
|
|||
/// Feature: test sending two message continuously.
|
||||
/// Description: start a socket server and send two normal message to it.
|
||||
/// Expectation: the server received the two messages sent from the client.
|
||||
TEST_F(TCPTest, sendTwoMessages) {
|
||||
TEST_F(TCPTest, SendTwoMessages) {
|
||||
Init();
|
||||
|
||||
// Start the tcp server.
|
||||
|
@ -205,8 +206,8 @@ TEST_F(TCPTest, sendTwoMessages) {
|
|||
|
||||
// Send messages.
|
||||
client->Connect(server_url);
|
||||
client->Send(std::move(message1));
|
||||
client->Send(std::move(message2));
|
||||
client->SendAsync(std::move(message1));
|
||||
client->SendAsync(std::move(message2));
|
||||
|
||||
// Wait timeout: 5s
|
||||
WaitForDataMsg(2, 5);
|
||||
|
@ -230,6 +231,45 @@ TEST_F(TCPTest, StartServerWithRandomPort) {
|
|||
EXPECT_LT(0, port);
|
||||
server->Finalize();
|
||||
}
|
||||
|
||||
/// Feature: test send the message synchronously.
|
||||
/// Description: start a socket server and send the message synchronously.
|
||||
/// Expectation: the number of bytes sent could be got synchronously.
|
||||
TEST_F(TCPTest, SendSyncMessage) {
|
||||
Init();
|
||||
|
||||
// Start the tcp server.
|
||||
auto server_url = "127.0.0.1:8081";
|
||||
std::unique_ptr<TCPServer> server = std::make_unique<TCPServer>();
|
||||
bool ret = server->Initialize(server_url);
|
||||
ASSERT_TRUE(ret);
|
||||
|
||||
server->SetMessageHandler([](const std::shared_ptr<MessageBase> &message) -> void { IncrDataMsgNum(1); });
|
||||
|
||||
// Start the tcp client.
|
||||
auto client_url = "127.0.0.1:1234";
|
||||
std::unique_ptr<TCPClient> client = std::make_unique<TCPClient>();
|
||||
ret = client->Initialize();
|
||||
ASSERT_TRUE(ret);
|
||||
|
||||
// Create the message.
|
||||
auto message = CreateMessage(server_url, client_url);
|
||||
auto msg_size = GetMessageSize(*message);
|
||||
|
||||
// Send the message.
|
||||
client->Connect(server_url);
|
||||
auto bytes_num = client->SendSync(std::move(message));
|
||||
|
||||
EXPECT_EQ(msg_size, bytes_num);
|
||||
|
||||
WaitForDataMsg(1, 5);
|
||||
EXPECT_EQ(1, GetDataMsgNum());
|
||||
|
||||
// Destroy
|
||||
client->Disconnect(server_url);
|
||||
client->Finalize();
|
||||
server->Finalize();
|
||||
}
|
||||
} // namespace rpc
|
||||
} // namespace distributed
|
||||
} // namespace mindspore
|
||||
|
|
Loading…
Reference in New Issue