Fix exception when sending large messages

This commit is contained in:
Parallels 2022-03-26 08:21:41 +00:00
parent cd3cfc3320
commit 7bcf3f7b03
6 changed files with 167 additions and 26 deletions

View File

@ -149,6 +149,13 @@ size_t EventLoop::AddTask(std::function<int()> &&task) {
return result;
}
size_t EventLoop::RemainingTaskNum() {
task_queue_mutex_.lock();
auto task_num = task_queue_.size();
task_queue_mutex_.unlock();
return task_num;
}
bool EventLoop::Initialize(const std::string &threadName) {
int retval = InitResource();
if (retval != RPC_OK) {

View File

@ -70,6 +70,9 @@ class EventLoop {
// These tasks are executed asynchronously.
size_t AddTask(std::function<int()> &&task);
// The number of tasks in the pending task queue.
size_t RemainingTaskNum();
// Set event handler for events(read/write/..) occurred on the socket fd.
int SetEventHandler(int sock_fd, uint32_t events, EventHandler handler, void *data);

View File

@ -138,29 +138,40 @@ int DoSend(Connection *conn) {
conn->FillSendMessage(conn->send_message_queue.front(), conn->source, false);
conn->send_message_queue.pop();
}
size_t retryCount = 10;
size_t sendLen = 0;
int retval = conn->socket_operation->SendMessage(conn, &conn->send_kernel_msg, conn->total_send_len, &sendLen);
if (retval == IO_RW_OK && sendLen > 0) {
total_send_bytes += sendLen;
conn->total_send_len -= sendLen;
if (conn->total_send_len == 0) {
// update metrics
conn->send_metrics->UpdateError(false);
while (retryCount > 0 && sendLen != conn->total_send_len) {
int retval = conn->socket_operation->SendMessage(conn, &conn->send_kernel_msg, conn->total_send_len, &sendLen);
if (retval == IO_RW_OK && sendLen > 0) {
conn->total_send_len -= sendLen;
if (conn->total_send_len == 0) {
// update metrics
conn->send_metrics->UpdateError(false);
conn->output_buffer_size -= conn->send_message->body.size();
delete conn->send_message;
conn->send_message = nullptr;
conn->output_buffer_size -= conn->send_message->body.size();
total_send_bytes += conn->send_message->body.size();
delete conn->send_message;
conn->send_message = nullptr;
break;
}
} else if (retval == IO_RW_OK && sendLen == 0) {
// EAGAIN
MS_LOG(ERROR) << "Failed to send message and update the epoll event";
(void)conn->recv_event_loop->UpdateEpollEvent(conn->socket_fd, EPOLLOUT | EPOLLIN | EPOLLHUP | EPOLLERR);
continue;
} else {
if (--retryCount > 0) {
MS_LOG(ERROR) << "Failed to send message and retry(" + std::to_string(retryCount) + ")...";
unsigned int time = 1;
sleep(time);
continue;
} else {
// update metrics
conn->send_metrics->UpdateError(true, conn->error_code);
conn->state = ConnectionState::kDisconnecting;
break;
}
}
} else if (retval == IO_RW_OK && sendLen == 0) {
// EAGAIN
(void)conn->recv_event_loop->UpdateEpollEvent(conn->socket_fd, EPOLLOUT | EPOLLIN | EPOLLHUP | EPOLLERR);
break;
} else {
// update metrics
conn->send_metrics->UpdateError(true, conn->error_code);
conn->state = ConnectionState::kDisconnecting;
break;
}
}
return total_send_bytes;
@ -445,12 +456,24 @@ bool TCPComm::IsConnected(const std::string &dst_url) {
return false;
}
void TCPComm::Disconnect(const std::string &dst_url) {
bool TCPComm::Disconnect(const std::string &dst_url) {
int interval = 100000;
size_t retry = 30;
while (recv_event_loop_->RemainingTaskNum() != 0 && send_event_loop_->RemainingTaskNum() != 0 && retry > 0) {
usleep(interval);
retry--;
}
if (recv_event_loop_->RemainingTaskNum() > 0 || send_event_loop_->RemainingTaskNum() > 0) {
MS_LOG(ERROR) << "Failed to disconnect from url " << dst_url
<< ", because there are still pending tasks to be executed, please try later.";
return false;
}
(void)recv_event_loop_->AddTask([dst_url, this] {
std::lock_guard<std::mutex> lock(*conn_mutex_);
conn_pool_->DeleteConnection(dst_url);
return true;
});
return true;
}
Connection *TCPComm::CreateDefaultConn(const std::string &to) {

View File

@ -62,7 +62,7 @@ class TCPComm {
// Connection operation for a specified destination.
void Connect(const std::string &dst_url);
bool IsConnected(const std::string &dst_url);
void Disconnect(const std::string &dst_url);
bool Disconnect(const std::string &dst_url);
// Send the message from the source to the destination.
// The flag sync means if the message is sent directly or added to the task queue.

View File

@ -19,7 +19,7 @@
namespace mindspore {
namespace distributed {
namespace rpc {
constexpr int EAGAIN_RETRY = 2;
constexpr int EAGAIN_RETRY = 100;
ssize_t TCPSocketOperation::ReceivePeek(Connection *connection, char *recvBuf, uint32_t recvLen) {
return recv(connection->socket_fd, recvBuf, recvLen, MSG_PEEK);
@ -107,18 +107,23 @@ int TCPSocketOperation::ReceiveMessage(Connection *connection, struct msghdr *re
int TCPSocketOperation::SendMessage(Connection *connection, struct msghdr *sendMsg, size_t totalSendLen,
size_t *sendLen) {
int eagainCount = EAGAIN_RETRY;
*sendLen = 0;
while (*sendLen != totalSendLen) {
auto retval = sendmsg(connection->socket_fd, sendMsg, MSG_NOSIGNAL);
if (retval < 0) {
--eagainCount;
if (errno != EAGAIN) {
MS_LOG(ERROR) << "Failed to call sendmsg and errno is: " << errno;
connection->error_code = errno;
return IO_RW_ERROR;
} else if (eagainCount == 0) {
MS_LOG(ERROR) << "Failed to call sendmsg after retry " + std::to_string(EAGAIN_RETRY) + " times and errno is: "
<< errno;
*sendLen = 0;
break;
return IO_RW_OK;
}
MS_LOG(ERROR) << "retry(" + std::to_string(eagainCount) + "/" + std::to_string(EAGAIN_RETRY) + ") sending ...";
} else {
*sendLen += retval;
@ -137,7 +142,7 @@ int TCPSocketOperation::SendMessage(Connection *connection, struct msghdr *sendM
reinterpret_cast<char *>(sendMsg->msg_iov[i].iov_base) + static_cast<unsigned int>(retval) - tmpBytes;
sendMsg->msg_iov = &sendMsg->msg_iov[i];
sendMsg->msg_iovlen -= (i + 1);
sendMsg->msg_iovlen -= i;
break;
}
}

View File

@ -255,7 +255,7 @@ TEST_F(TCPTest, SendSyncMessage) {
// Create the message.
auto message = CreateMessage(server_url, client_url);
auto msg_size = GetMessageSize(*message);
auto msg_size = message->body.size();
// Send the message.
client->Connect(server_url);
@ -271,6 +271,109 @@ TEST_F(TCPTest, SendSyncMessage) {
client->Finalize();
server->Finalize();
}
/// Feature: test sending large messages.
/// Description: start a socket server and send several large messages to it.
/// Expectation: the server received these large messages sented from client.
TEST_F(TCPTest, SendLargeMessages) {
Init();
// Start the tcp server.
std::unique_ptr<TCPServer> server = std::make_unique<TCPServer>();
bool ret = server->Initialize();
ASSERT_TRUE(ret);
server->SetMessageHandler([](const std::shared_ptr<MessageBase> &message) -> void { IncrDataMsgNum(1); });
// Start the tcp client.
auto client_url = "127.0.0.1:1234";
std::unique_ptr<TCPClient> client = std::make_unique<TCPClient>();
ret = client->Initialize();
ASSERT_TRUE(ret);
// Send the message.
auto ip = server->GetIP();
auto port = server->GetPort();
auto server_url = ip + ":" + std::to_string(port);
client->Connect(server_url);
size_t msg_cnt = 5;
size_t large_msg_size = 1024000;
for (int i = 0; i < msg_cnt; ++i) {
auto message = CreateMessage(server_url, client_url, large_msg_size);
client->SendAsync(std::move(message));
}
// Wait timeout: 15s
WaitForDataMsg(msg_cnt, 15);
// Check result
EXPECT_EQ(msg_cnt, GetDataMsgNum());
// Destroy
client->Disconnect(server_url);
client->Finalize();
server->Finalize();
}
/// Feature: test creating many TCP connections.
/// Description: create many servers and clients, then connect each client to a server.
/// Expectation: all the servers and clients are created successfully.
TEST_F(TCPTest, CreateManyConnectionPairs) {
Init();
std::vector<std::shared_ptr<TCPServer>> servers;
std::vector<std::shared_ptr<TCPClient>> clients;
std::vector<std::string> server_urls;
size_t total_connection_num = 10;
for (size_t i = 0; i < total_connection_num; ++i) {
// Start the tcp server.
std::shared_ptr<TCPServer> server = std::make_shared<TCPServer>();
bool ret = server->Initialize();
auto ip = server->GetIP();
auto port = server->GetPort();
ASSERT_TRUE(ret);
server->SetMessageHandler([](const std::shared_ptr<MessageBase> &message) -> void { IncrDataMsgNum(1); });
// Start the tcp client.
auto client_url = "127.0.0.1:1234";
std::shared_ptr<TCPClient> client = std::make_shared<TCPClient>();
ret = client->Initialize();
ASSERT_TRUE(ret);
// Send the message.
auto server_url = ip + ":" + std::to_string(port);
server_urls.push_back(server_url);
auto success = client->Connect(server_url);
EXPECT_EQ(true, success);
size_t msg_cnt = 100;
size_t large_msg_size = 10240;
for (int i = 0; i < msg_cnt; ++i) {
auto message = CreateMessage(server_url, client_url, large_msg_size);
client->SendAsync(std::move(message));
}
// Check result
servers.push_back(server);
clients.push_back(client);
}
// Check result
EXPECT_EQ(total_connection_num, servers.size());
EXPECT_EQ(total_connection_num, clients.size());
// Destroy
for (size_t i = 0; i < total_connection_num; ++i) {
while (!clients[i]->Disconnect(server_urls[i]))
;
clients[i]->Finalize();
servers[i]->Finalize();
}
}
} // namespace rpc
} // namespace distributed
} // namespace mindspore