!40875 Add null pointer checking for the rpc module

Merge pull request !40875 from chengang/code_review_ms_1
This commit is contained in:
i-robot 2022-08-26 07:48:07 +00:00 committed by Gitee
commit 47b9fd0a42
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
10 changed files with 146 additions and 18 deletions

View File

@ -1,5 +1,5 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
* Copyright 2021-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -35,6 +35,9 @@ const int kPrintTimeInterval = 50000;
// Handle socket events like read/write.
void SocketEventHandler(int fd, uint32_t events, void *context) {
Connection *conn = reinterpret_cast<Connection *>(context);
if (conn == nullptr) {
return;
}
if (fd != conn->socket_fd) {
MS_LOG(ERROR) << "Failed to reuse connection, delete and close fd: " << fd << ", connfd: " << conn->socket_fd
@ -90,6 +93,9 @@ void SocketEventHandler(int fd, uint32_t events, void *context) {
void NewConnectEventHandler(int fd, uint32_t events, void *context) {
int retval = 0;
Connection *conn = reinterpret_cast<Connection *>(context);
if (conn == nullptr) {
return;
}
conn->socket_operation->NewConnEventHandler(fd, events, context);
if (conn->state == ConnectionState::kDisconnecting) {
@ -191,6 +197,10 @@ void Connection::InitSocketOperation() {
}
bool Connection::ReconnectSourceSocket(int fd, uint32_t events, int *soError, uint32_t error) {
if (soError == nullptr) {
return false;
}
MS_EXCEPTION_IF_NULL(recv_event_loop);
socklen_t len = sizeof(*soError);
int retval = recv_event_loop->DeleteEpollEvent(fd);
@ -309,6 +319,9 @@ void Connection::CheckMessageType() {
}
std::string Connection::GenerateHttpMessage(MessageBase *msg) {
if (msg == nullptr) {
return "";
}
static const std::string postLineBegin = std::string() + "POST /";
static const std::string postLineEnd = std::string() + " HTTP/1.1\r\n";
static const std::string userAgentLineBegin = std::string() + "User-Agent: libprocess/";
@ -340,6 +353,9 @@ std::string Connection::GenerateHttpMessage(MessageBase *msg) {
}
void Connection::FillSendMessage(MessageBase *msg, const std::string &advertiseUrl, bool isHttpKmsg) {
if (msg == nullptr || send_metrics == nullptr) {
return;
}
if (msg->type == MessageBase::Type::KMSG) {
size_t index = 0;
if (!isHttpKmsg) {

View File

@ -1,5 +1,5 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
* Copyright 2021-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -65,9 +65,15 @@ void ConnectionPool::DeleteConnection(const std::string &dst_url) {
}
void ConnectionPool::DeleteAllConnections(std::map<std::string, Connection *> *links) const {
if (links == nullptr) {
return;
}
auto iter = links->begin();
while (iter != links->end()) {
Connection *conn = iter->second;
if (conn == nullptr) {
continue;
}
// erase link
if (conn->recv_message != nullptr) {
delete conn->recv_message;
@ -102,6 +108,9 @@ void ConnectionPool::DeleteConnInfo(int fd) {
while (iter2 != conn_infos.end()) {
auto linkInfo = *iter2;
if (linkInfo == nullptr) {
continue;
}
if (linkInfo->delete_callback) {
linkInfo->delete_callback(linkInfo->to, linkInfo->from);
}
@ -112,6 +121,9 @@ void ConnectionPool::DeleteConnInfo(int fd) {
}
void ConnectionPool::DeleteConnInfo(Connection *conn) {
if (conn == nullptr) {
return;
}
int fd = conn->socket_fd;
// If run in double link pattern, link fd and send fd must be the same, send Exit message bind on this fd
if (double_link_) {
@ -157,6 +169,9 @@ ConnectionInfo *ConnectionPool::FindConnInfo(int fd, const std::string &dst_url)
while (iter2 != conn_infos.end()) {
auto linkInfo = *iter2;
if (linkInfo == nullptr) {
continue;
}
if (linkInfo->to == dst_url) {
return linkInfo;
}

View File

@ -1,5 +1,5 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
* Copyright 2021-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -35,6 +35,9 @@ namespace mindspore {
namespace distributed {
namespace rpc {
int EventLoopRun(EventLoop *evloop, int timeout) {
if (evloop == nullptr) {
return RPC_ERROR;
}
struct epoll_event *events = nullptr;
(void)sem_post(&evloop->sem_id_);
@ -298,7 +301,7 @@ int EventLoop::SetEventHandler(int fd, uint32_t events, EventHandler handler, vo
}
void EventLoop::AddEvent(Event *event) {
if (!event) {
if (event == nullptr) {
return;
}
DeleteEvent(event->fd);
@ -361,6 +364,9 @@ int EventLoop::UpdateEpollEvent(int fd, uint32_t events) {
}
void EventLoop::AddDeletedEvent(Event *event) {
if (event == nullptr) {
return;
}
// caller need check eventData is not nullptr
std::list<Event *> delete_event_list;
@ -411,6 +417,9 @@ void EventLoop::RemoveDeletedEvents() {
}
int EventLoop::FindDeletedEvent(const Event *tev) {
if (tev == nullptr) {
return 0;
}
std::map<int, std::list<Event *>>::iterator fdIter = deleted_events_.find(tev->fd);
if (fdIter == deleted_events_.end()) {
return 0;
@ -429,6 +438,9 @@ int EventLoop::FindDeletedEvent(const Event *tev) {
}
void EventLoop::HandleEvent(struct epoll_event *events, size_t nevent) {
if (events == nullptr) {
return;
}
int found;
Event *tev = nullptr;

View File

@ -1,5 +1,5 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
* Copyright 2021-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -193,11 +193,17 @@ std::string SocketOperation::GetIP(const std::string &url) {
}
bool SocketOperation::GetSockAddr(const std::string &url, SocketAddress *addr) {
if (addr == nullptr) {
return false;
}
std::string ip;
uint16_t port = 0;
size_t len = sizeof(*addr);
(void)memset_s(addr, len, 0, len);
if (memset_s(addr, len, 0, len) > 0) {
MS_LOG(ERROR) << "Failed to call memset_s.";
return false;
}
size_t index1 = url.find(URL_PROTOCOL_IP_SEPARATOR);
if (index1 == std::string::npos) {
@ -300,6 +306,9 @@ std::string SocketOperation::GetPeer(int sock_fd) {
}
int SocketOperation::Connect(int sock_fd, const struct sockaddr *sa, socklen_t saLen, uint16_t *boundPort) {
if (sa == nullptr || boundPort == nullptr) {
return RPC_ERROR;
}
int retval = 0;
retval = connect(sock_fd, sa, saLen);

View File

@ -31,6 +31,9 @@ bool SSLSocketOperation::Initialize() {
ssize_t SSLSocketOperation::ReceivePeek(Connection *, char *, uint32_t) { return 0; }
int SSLSocketOperation::Receive(Connection *connection, char *recvBuf, size_t totalRecvLen, size_t *recvLen) {
if (connection == nullptr || recvBuf == nullptr || recvLen == nullptr) {
return IO_RW_ERROR;
}
char *curRecvBuf = recvBuf;
*recvLen = 0;
@ -67,6 +70,9 @@ int SSLSocketOperation::Receive(Connection *connection, char *recvBuf, size_t to
int SSLSocketOperation::ReceiveMessage(Connection *connection, struct msghdr *recvMsg, size_t totalRecvLen,
size_t *recvLen) {
if (connection == nullptr || recvMsg == nullptr || recvLen == nullptr) {
return IO_RW_ERROR;
}
if (totalRecvLen == 0) {
return IO_RW_OK;
}
@ -113,6 +119,9 @@ int SSLSocketOperation::ReceiveMessage(Connection *connection, struct msghdr *re
int SSLSocketOperation::SendMessage(Connection *connection, struct msghdr *sendMsg, size_t totalSendLen,
size_t *sendLen) {
if (connection == nullptr || sendMsg == nullptr || sendLen == nullptr) {
return IO_RW_ERROR;
}
*sendLen = 0;
const size_t msg_idx = 0;
@ -161,13 +170,18 @@ void SSLSocketOperation::Close(Connection *connection) {
SSL_free(ssl_);
ssl_ = nullptr;
}
// Close the socket.
(void)close(connection->socket_fd);
connection->socket_fd = -1;
if (connection != nullptr) {
// Close the socket.
(void)close(connection->socket_fd);
connection->socket_fd = -1;
}
}
void SSLSocketOperation::NewConnEventHandler(int fd, uint32_t events, void *context) {
Connection *conn = reinterpret_cast<Connection *>(context);
if (conn == nullptr) {
return;
}
uint32_t error = events & (EPOLLERR | EPOLLHUP | EPOLLRDHUP);
if (error > 0) {
conn->state = ConnectionState::kDisconnecting;
@ -190,6 +204,9 @@ void SSLSocketOperation::NewConnEventHandler(int fd, uint32_t events, void *cont
void SSLSocketOperation::ConnEstablishedEventHandler(int fd, uint32_t events, void *context) {
Connection *conn = reinterpret_cast<Connection *>(context);
if (conn == nullptr) {
return;
}
uint32_t error = events & (EPOLLERR | EPOLLHUP | EPOLLRDHUP);
if (error > 0) {
conn->state = ConnectionState::kDisconnecting;
@ -211,7 +228,8 @@ void SSLSocketOperation::ConnEstablishedEventHandler(int fd, uint32_t events, vo
}
void SSLSocketOperation::Handshake(int fd, Connection *conn) const {
if (conn->state == ConnectionState::kConnected) {
if (conn == nullptr || conn->recv_event_loop == nullptr || ssl_ == nullptr ||
conn->state == ConnectionState::kConnected) {
return;
}

View File

@ -66,7 +66,10 @@ bool TCPClient::Connect(const std::string &dst_url, size_t retry_count, const Me
} else {
MS_LOG(WARNING) << "Failed to connect to the tcp server : " << dst_url << ", retry to reconnect(" << (i + 1)
<< "/" << retry_count << ")...";
(void)tcp_comm_->Disconnect(dst_url);
if (!tcp_comm_->Disconnect(dst_url)) {
MS_LOG(ERROR) << "Can not disconnect from the server: " << dst_url;
return false;
}
(void)sleep(interval);
}
}
@ -77,7 +80,10 @@ bool TCPClient::IsConnected(const std::string &dst_url) { return tcp_comm_->IsCo
bool TCPClient::Disconnect(const std::string &dst_url, size_t timeout_in_sec) {
bool rt = false;
(void)tcp_comm_->Disconnect(dst_url);
if (!tcp_comm_->Disconnect(dst_url)) {
MS_LOG(ERROR) << "Can not disconnect from the server: " << dst_url;
return false;
}
size_t timeout_in_ms = timeout_in_sec * 1000;
size_t sleep_in_ms = 100;

View File

@ -67,7 +67,7 @@ class TCPClient {
private:
// The basic TCP communication component used by the client.
std::unique_ptr<TCPComm> tcp_comm_;
std::unique_ptr<TCPComm> tcp_comm_{nullptr};
// The mutex and condition variable used to synchronize the write and read of the received message returned by calling
// the `ReceiveSync` method.

View File

@ -1,5 +1,5 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
* Copyright 2021-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -28,6 +28,9 @@ namespace mindspore {
namespace distributed {
namespace rpc {
void DoDisconnect(int fd, Connection *conn, uint32_t error, int soError) {
if (conn == nullptr) {
return;
}
if (LOG_CHECK_EVERY_N()) {
MS_LOG(INFO) << "Failed to call connect, fd: " << fd << ", to: " << conn->destination.c_str()
<< ", events: " << error << ", errno: " << soError;
@ -43,6 +46,9 @@ void ConnectedEventHandler(int fd, uint32_t events, void *context) {
uint32_t error = events & (EPOLLERR | EPOLLHUP | EPOLLRDHUP);
int soError = 0;
Connection *conn = reinterpret_cast<Connection *>(context);
if (conn == nullptr || conn->socket_operation == nullptr) {
return;
}
conn->socket_operation->ConnEstablishedEventHandler(fd, events, context);
if (conn->state == ConnectionState::kDisconnecting) {
DoDisconnect(fd, conn, error, soError);
@ -67,6 +73,9 @@ void OnAccept(int server, uint32_t events, void *arg) {
return;
}
TCPComm *tcpmgr = reinterpret_cast<TCPComm *>(arg);
if (tcpmgr == nullptr || tcpmgr->conn_pool_ == nullptr) {
return;
}
if (tcpmgr->recv_event_loop_ == nullptr) {
MS_LOG(ERROR) << "EventLoop is null, server fd: " << server << ", events: " << events;
return;
@ -130,6 +139,7 @@ void OnAccept(int server, uint32_t events, void *arg) {
acceptFd = -1;
delete conn->send_metrics;
delete conn;
conn = nullptr;
return;
}
tcpmgr->conn_pool_->AddConnection(conn);
@ -188,7 +198,10 @@ bool TCPComm::StartServerSocket(const std::string &url, const MemAllocateCallbac
allocate_cb_ = allocate_cb;
size_t index = url.find(URL_PROTOCOL_IP_SEPARATOR);
if (index != std::string::npos) {
url_ = url.substr(index + sizeof(URL_PROTOCOL_IP_SEPARATOR) - 1);
index = index + sizeof(URL_PROTOCOL_IP_SEPARATOR) - 1;
if (index < url.length()) {
url_ = url.substr(index);
}
}
// Register read event callback for server socket
@ -225,6 +238,9 @@ void TCPComm::SetMessageFreeCallback(const std::string &dst_url, const MemFreeCa
void TCPComm::ReadCallBack(void *connection) {
const int max_recv_count = 3;
Connection *conn = reinterpret_cast<Connection *>(connection);
if (conn == nullptr) {
return;
}
int count = 0;
int retval = 0;
do {
@ -237,7 +253,9 @@ void TCPComm::ReadCallBack(void *connection) {
void TCPComm::EventCallBack(void *connection) {
Connection *conn = reinterpret_cast<Connection *>(connection);
if (conn == nullptr) {
return;
}
if (conn->state == ConnectionState::kConnected) {
conn->conn_mutex->lock();
(void)conn->Flush();
@ -250,6 +268,9 @@ void TCPComm::EventCallBack(void *connection) {
void TCPComm::WriteCallBack(void *connection) {
Connection *conn = reinterpret_cast<Connection *>(connection);
if (conn == nullptr) {
return;
}
if (conn->state == ConnectionState::kConnected) {
conn->conn_mutex->lock();
(void)conn->Flush();
@ -279,6 +300,9 @@ int TCPComm::SetConnectedHandler(Connection *conn) {
/* static method */
int TCPComm::DoConnect(Connection *conn, const struct sockaddr *sa, socklen_t saLen) {
if (conn == nullptr || conn->recv_event_loop == nullptr || sa == nullptr) {
return RPC_ERROR;
}
int retval = 0;
uint16_t localPort = 0;
@ -315,6 +339,9 @@ void TCPComm::DropMessage(MessageBase *msg) {
}
ssize_t TCPComm::Send(MessageBase *msg, bool sync) {
if (msg == nullptr) {
return -1;
}
auto task = [msg, this] {
std::lock_guard<std::mutex> lock(*conn_mutex_);
// Search connection by the target address
@ -375,6 +402,9 @@ bool TCPComm::Flush(const std::string &dst_url) {
}
bool TCPComm::Connect(const std::string &dst_url) {
MS_EXCEPTION_IF_NULL(conn_mutex_);
MS_EXCEPTION_IF_NULL(conn_pool_);
std::lock_guard<std::mutex> lock(*conn_mutex_);
// Search connection by the target address
@ -419,6 +449,7 @@ bool TCPComm::Connect(const std::string &dst_url) {
conn->socket_operation = nullptr;
}
delete conn;
conn = nullptr;
return false;
}
conn->source = SocketOperation::GetLocalIP() + ":" + std::to_string(SocketOperation::GetPort(sock_fd));
@ -442,6 +473,7 @@ bool TCPComm::Connect(const std::string &dst_url) {
}
bool TCPComm::IsConnected(const std::string &dst_url) {
MS_EXCEPTION_IF_NULL(conn_pool_);
Connection *conn = conn_pool_->FindConnection(dst_url);
if (conn != nullptr && conn->state == ConnectionState::kConnected) {
return true;
@ -450,6 +482,11 @@ bool TCPComm::IsConnected(const std::string &dst_url) {
}
bool TCPComm::Disconnect(const std::string &dst_url) {
MS_EXCEPTION_IF_NULL(conn_mutex_);
MS_EXCEPTION_IF_NULL(conn_pool_);
MS_EXCEPTION_IF_NULL(recv_event_loop_);
MS_EXCEPTION_IF_NULL(send_event_loop_);
unsigned int interval = 100000;
size_t retry = 30;
while (recv_event_loop_->RemainingTaskNum() != 0 && send_event_loop_->RemainingTaskNum() != 0 && retry > 0) {

View File

@ -51,7 +51,7 @@ class TCPServer {
bool InitializeImpl(const std::string &url, const MemAllocateCallback &allocate_cb);
// The basic TCP communication component used by the server.
std::unique_ptr<TCPComm> tcp_comm_;
std::unique_ptr<TCPComm> tcp_comm_{nullptr};
std::string ip_{""};
uint32_t port_{0};

View File

@ -1,5 +1,5 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
* Copyright 2021-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -26,6 +26,9 @@ ssize_t TCPSocketOperation::ReceivePeek(Connection *connection, char *recvBuf, u
}
int TCPSocketOperation::Receive(Connection *connection, char *recvBuf, size_t totalRecvLen, size_t *recvLen) {
if (connection == nullptr || recvBuf == nullptr || recvLen == nullptr) {
return IO_RW_ERROR;
}
char *curRecvBuf = recvBuf;
int fd = connection->socket_fd;
@ -106,6 +109,9 @@ int TCPSocketOperation::ReceiveMessage(Connection *connection, struct msghdr *re
int TCPSocketOperation::SendMessage(Connection *connection, struct msghdr *sendMsg, size_t totalSendLen,
size_t *sendLen) {
if (connection == nullptr || sendMsg == nullptr || sendLen == nullptr) {
return IO_RW_ERROR;
}
int eagainCount = 0;
// Print retry log interval.
const int print_interval = 10000;
@ -160,6 +166,9 @@ int TCPSocketOperation::SendMessage(Connection *connection, struct msghdr *sendM
}
void TCPSocketOperation::Close(Connection *connection) {
if (connection == nullptr) {
return;
}
(void)close(connection->socket_fd);
connection->socket_fd = -1;
}
@ -167,12 +176,18 @@ void TCPSocketOperation::Close(Connection *connection) {
// accept new conn event handle
void TCPSocketOperation::NewConnEventHandler(int fd, uint32_t events, void *context) {
Connection *conn = reinterpret_cast<Connection *>(context);
if (conn == nullptr) {
return;
}
conn->state = ConnectionState::kConnected;
return;
}
void TCPSocketOperation::ConnEstablishedEventHandler(int fd, uint32_t events, void *context) {
Connection *conn = reinterpret_cast<Connection *>(context);
if (conn == nullptr) {
return;
}
conn->state = ConnectionState::kConnected;
return;
}