!29786 Add TCP client and TCP server for RPC

Merge pull request !29786 from chengang/add_tcp_server
This commit is contained in:
i-robot 2022-02-10 07:47:12 +00:00 committed by Gitee
commit 7b6f7c9c87
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
15 changed files with 528 additions and 733 deletions

View File

@ -25,8 +25,6 @@
namespace mindspore {
namespace distributed {
namespace rpc {
std::mutex Connection::conn_mutex;
// Handle socket events like read/write.
void SocketEventHandler(int fd, uint32_t events, void *context) {
Connection *conn = reinterpret_cast<Connection *>(context);
@ -245,7 +243,7 @@ void Connection::Close() {
}
}
int Connection::ReceiveMessage(IOMgr::MessageHandler msgHandler) {
int Connection::ReceiveMessage() {
bool ok = ParseMessage();
// If no message parsed, wait for next read
if (!ok) {
@ -255,26 +253,12 @@ int Connection::ReceiveMessage(IOMgr::MessageHandler msgHandler) {
return 0;
}
if (destination.empty()) {
std::string fromUrl = recv_message->from;
size_t index = fromUrl.find("@");
if (index != std::string::npos) {
destination = fromUrl.substr(index + 1);
MS_LOG(INFO) << "Create new connection fd: " << socket_fd << " to: " << destination.c_str();
Connection::conn_mutex.lock();
ConnectionPool::GetConnectionPool()->SetConnPriority(destination, false, ConnectionPriority::kPriorityLow);
state = ConnectionState::kConnected;
ConnectionPool::GetConnectionPool()->AddConnection(this);
Connection::conn_mutex.unlock();
}
}
std::unique_ptr<MessageBase> msg(recv_message);
recv_message = nullptr;
// Call msg handler if set
if (msgHandler != nullptr) {
msgHandler(std::move(msg));
if (message_handler != nullptr) {
message_handler(std::move(msg));
} else {
MS_LOG(INFO) << "Message handler was not found";
}
@ -282,7 +266,6 @@ int Connection::ReceiveMessage(IOMgr::MessageHandler msgHandler) {
}
void Connection::CheckMessageType() {
std::lock_guard<std::mutex> lock(Connection::conn_mutex);
if (recv_message_type != ParseType::kUnknown) {
return;
}

View File

@ -20,9 +20,9 @@
#include <queue>
#include <string>
#include <mutex>
#include <memory>
#include "actor/msg.h"
#include "actor/iomgr.h"
#include "distributed/rpc/tcp/constants.h"
#include "distributed/rpc/tcp/event_loop.h"
#include "distributed/rpc/tcp/socket_operation.h"
@ -118,7 +118,7 @@ struct Connection {
// Close this connection.
void Close();
int ReceiveMessage(IOMgr::MessageHandler msgHandler);
int ReceiveMessage();
void CheckMessageType();
// Fill the message to be sent based on the input message.
@ -170,6 +170,9 @@ struct Connection {
MessageBase *send_message;
MessageBase *recv_message;
// Owned by the tcp_comm.
std::shared_ptr<std::mutex> conn_mutex;
State recv_state;
// Total length of received and sent messages.
@ -201,6 +204,9 @@ struct Connection {
ConnectionCallBack write_callback;
ConnectionCallBack read_callback;
// Function for handling received messages.
MessageHandler message_handler;
// Buffer for messages to be sent.
std::queue<MessageBase *> send_message_queue;
@ -209,8 +215,6 @@ struct Connection {
// The error code when sending or receiving messages.
int error_code;
static std::mutex conn_mutex;
private:
// Add handler for socket connect event.
int AddConnnectEventHandler();

View File

@ -20,8 +20,6 @@
namespace mindspore {
namespace distributed {
namespace rpc {
ConnectionPool *ConnectionPool::conn_pool = new ConnectionPool();
void ConnectionPool::SetLinkPattern(bool linkPattern) { double_link_ = linkPattern; }
void ConnectionPool::CloseConnection(Connection *conn) {
@ -36,57 +34,22 @@ void ConnectionPool::CloseConnection(Connection *conn) {
}
if (!conn->destination.empty()) {
if (conn->is_remote) {
(void)remote_conns_.erase(conn->destination);
} else {
(void)local_conns_.erase(conn->destination);
}
(void)connections_.erase(conn->destination);
}
conn->Close();
delete conn;
conn = nullptr;
}
Connection *ConnectionPool::FindConnection(const std::string &to, bool remoteLink) {
Connection *ConnectionPool::FindConnection(const std::string &dst_url) {
Connection *conn = nullptr;
if (!remoteLink) {
auto iter = local_conns_.find(to);
if (iter != local_conns_.end()) {
conn = iter->second;
return conn;
}
}
auto iter = remote_conns_.find(to);
if (iter != remote_conns_.end()) {
auto iter = connections_.find(dst_url);
if (iter != connections_.end()) {
conn = iter->second;
}
return conn;
}
Connection *ConnectionPool::ExactFindConnection(const std::string &to, bool remoteLink) {
Connection *conn = nullptr;
if (!remoteLink) {
auto iter = local_conns_.find(to);
if (iter != local_conns_.end()) {
conn = iter->second;
}
} else {
auto iter = remote_conns_.find(to);
if (iter != remote_conns_.end()) {
conn = iter->second;
}
}
return conn;
}
Connection *ConnectionPool::FindConnection(const std::string &to, bool remoteLink, bool exactNotRemote) {
if (exactNotRemote) {
return ExactFindConnection(to, false);
} else {
return FindConnection(to, remoteLink);
}
}
void ConnectionPool::ResetAllConnMetrics() {
for (const auto &iter : local_conns_) {
iter.second->send_metrics->Reset();
@ -96,46 +59,10 @@ void ConnectionPool::ResetAllConnMetrics() {
}
}
Connection *ConnectionPool::FindMaxConnection() {
Connection *conn = nullptr;
size_t count = 0;
for (const auto &iter : local_conns_) {
if (iter.second->send_metrics->accum_msg_count > count) {
count = iter.second->send_metrics->accum_msg_count;
conn = iter.second;
}
}
for (const auto &iter : remote_conns_) {
if (iter.second->send_metrics->accum_msg_count > count) {
count = iter.second->send_metrics->accum_msg_count;
conn = iter.second;
}
}
return conn;
}
Connection *ConnectionPool::FindFastConnection() {
Connection *conn = nullptr;
size_t size = 0;
for (const auto &iter : local_conns_) {
if (iter.second->send_metrics->max_msg_size > size) {
size = iter.second->send_metrics->max_msg_size;
conn = iter.second;
}
}
for (const auto &iter : remote_conns_) {
if (iter.second->send_metrics->max_msg_size > size) {
size = iter.second->send_metrics->max_msg_size;
conn = iter.second;
}
}
return conn;
}
void ConnectionPool::ExactDeleteConnection(const std::string &to, bool remoteLink) {
Connection *conn = ExactFindConnection(to, remoteLink);
void ConnectionPool::DeleteConnection(const std::string &dst_url) {
Connection *conn = FindConnection(dst_url);
if (conn != nullptr) {
MS_LOG(INFO) << "unLink fd:" << conn->socket_fd << ",to:" << to.c_str() << ",remote:" << remoteLink;
MS_LOG(INFO) << "unLink fd:" << conn->socket_fd << ",to:" << dst_url;
CloseConnection(conn);
}
}
@ -156,26 +83,15 @@ void ConnectionPool::DeleteAllConnections(std::map<std::string, Connection *> *l
void ConnectionPool::AddConnection(Connection *conn) {
if (conn == nullptr) {
MS_LOG(ERROR) << "The connection is null";
return;
}
Connection *tmpConn = ExactFindConnection(conn->destination, conn->is_remote);
if (tmpConn != nullptr && tmpConn->is_remote == conn->is_remote) {
Connection *tmpConn = FindConnection(conn->destination);
if (tmpConn != nullptr) {
MS_LOG(INFO) << "unLink fd:" << tmpConn->socket_fd << ",to:" << tmpConn->destination.c_str();
CloseConnection(tmpConn);
}
if (conn->is_remote) {
(void)remote_conns_.emplace(conn->destination, conn);
} else {
(void)local_conns_.emplace(conn->destination, conn);
}
}
void ConnectionPool::SetConnPriority(const std::string &to, bool remoteLink, ConnectionPriority pri) {
Connection *conn = ExactFindConnection(to, remoteLink);
if (conn != nullptr && conn->is_remote == remoteLink) {
conn->priority = pri;
}
(void)connections_.emplace(conn->destination, conn);
}
void ConnectionPool::DeleteConnInfo(int fd) {
@ -207,23 +123,13 @@ void ConnectionPool::DeleteConnInfo(const std::string &to, int fd) {
// If run in single link pattern, link fd and send fd may not be the same, we should send Exit message bind
// on link fd and remote link fd. Here 'deleted' flag should be set true to avoid duplicate Exit message with
// same aid.
Connection *nonRemoteConn = ConnectionPool::ExactFindConnection(to, false);
if (nonRemoteConn != nullptr) {
nonRemoteConn->deleted = true;
DeleteConnInfo(nonRemoteConn->socket_fd);
Connection *conn = FindConnection(to);
if (conn != nullptr) {
conn->deleted = true;
DeleteConnInfo(conn->socket_fd);
if (nonRemoteConn->socket_fd != fd) {
MS_LOG(INFO) << "delete linker bind on link fd:" << nonRemoteConn->socket_fd << ",delete fd:" << fd;
}
}
Connection *remoteConn = ConnectionPool::ExactFindConnection(to, true);
if (remoteConn != nullptr) {
remoteConn->deleted = true;
DeleteConnInfo(remoteConn->socket_fd);
if (remoteConn->socket_fd != fd) {
MS_LOG(INFO) << "delete linker bind on remote link fd:" << remoteConn->socket_fd << ",delete fd:" << fd;
if (conn->socket_fd != fd) {
MS_LOG(INFO) << "delete linker bind on link fd:" << conn->socket_fd << ",delete fd:" << fd;
}
}
}
@ -243,7 +149,7 @@ void ConnectionPool::DeleteAllConnInfos() {
}
}
ConnectionInfo *ConnectionPool::FindConnInfo(int fd, const AID &sAid, const AID &dAid) {
ConnectionInfo *ConnectionPool::FindConnInfo(int fd, const std::string &dst_url) {
auto iter = conn_infos_.find(fd);
if (iter == conn_infos_.end()) {
return nullptr;
@ -253,7 +159,7 @@ ConnectionInfo *ConnectionPool::FindConnInfo(int fd, const AID &sAid, const AID
while (iter2 != conn_infos.end()) {
auto linkInfo = *iter2;
if (AID(linkInfo->from) == sAid && AID(linkInfo->to) == dAid) {
if (linkInfo->to == dst_url) {
return linkInfo;
}
++iter2;
@ -261,19 +167,18 @@ ConnectionInfo *ConnectionPool::FindConnInfo(int fd, const AID &sAid, const AID
return nullptr;
}
void ConnectionPool::AddConnInfo(int fd, const AID &sAid, const AID &dAid, DeleteCallBack callback) {
ConnectionInfo *linker = FindConnInfo(fd, sAid, dAid);
void ConnectionPool::AddConnInfo(int fd, const std::string &dst_url, DeleteCallBack callback) {
ConnectionInfo *linker = FindConnInfo(fd, dst_url);
if (linker != nullptr) {
return;
}
linker = new (std::nothrow) ConnectionInfo();
if (linker == nullptr) {
MS_LOG(ERROR) << "new ConnectionInfo fail sAid:" << std::string(sAid).c_str()
<< ",dAid:" << std::string(dAid).c_str();
MS_LOG(ERROR) << "new ConnectionInfo fail dAid:" << dst_url;
return;
}
linker->from = sAid;
linker->to = dAid;
linker->from = "";
linker->to = dst_url;
linker->socket_fd = fd;
linker->delete_callback = callback;
(void)conn_infos_[fd].insert(linker);
@ -299,8 +204,6 @@ ConnectionPool::~ConnectionPool() {
MS_LOG(ERROR) << "Failed to release resource for connection pool.";
}
}
ConnectionPool *ConnectionPool::GetConnectionPool() { return ConnectionPool::conn_pool; }
} // namespace rpc
} // namespace distributed
} // namespace mindspore

View File

@ -42,13 +42,10 @@ class ConnectionPool {
ConnectionPool() : double_link_(false) {}
~ConnectionPool();
// Get the singleton instance of ConnectionPool.
static ConnectionPool *GetConnectionPool();
/*
* Operations for ConnectionInfo.
*/
void AddConnInfo(int socket_fd, const AID &sAid, const AID &dAid, DeleteCallBack delcb);
void AddConnInfo(int socket_fd, const std::string &dst_url, DeleteCallBack delcb);
bool ReverseConnInfo(int from_socket_fd, int to_socket_fd);
/*
@ -58,19 +55,14 @@ class ConnectionPool {
void AddConnection(Connection *conn);
// Find connection.
Connection *ExactFindConnection(const std::string &to, bool remoteLink);
Connection *FindConnection(const std::string &to, bool remoteLink);
Connection *FindConnection(const std::string &to, bool remoteLink, bool exactNotRemote);
Connection *FindMaxConnection();
Connection *FindFastConnection();
Connection *FindConnection(const std::string &dst_url);
// Delete connection.
void ExactDeleteConnection(const std::string &to, bool remoteLink);
void DeleteConnection(const std::string &dst_url);
void DeleteAllConnections(std::map<std::string, Connection *> *alllinks);
// Close connection.
void CloseConnection(Connection *conn);
void SetConnPriority(const std::string &to, bool remoteLink, ConnectionPriority pri);
// Single link or double link.
void SetLinkPattern(bool linkPattern);
@ -78,7 +70,7 @@ class ConnectionPool {
void ResetAllConnMetrics();
private:
ConnectionInfo *FindConnInfo(int socket_fd, const AID &sAid, const AID &dAid);
ConnectionInfo *FindConnInfo(int socket_fd, const std::string &dst_url);
void DeleteConnInfo(int socket_fd);
void DeleteConnInfo(const std::string &to, int socket_fd);
@ -92,11 +84,12 @@ class ConnectionPool {
// Maintains the remote connections by remote server addresses.
std::map<std::string, Connection *> remote_conns_;
// Maintains the connections by remote server addresses.
std::map<std::string, Connection *> connections_;
// each to_url has two fds at most, and each fd has multiple linkinfos
std::map<int, std::set<ConnectionInfo *>> conn_infos_;
static ConnectionPool *conn_pool;
friend class Connection;
friend class TCPComm;
};

View File

@ -20,12 +20,15 @@
#include <string>
#include <csignal>
#include <queue>
#include <memory>
#include "actor/log.h"
#include "actor/msg.h"
namespace mindspore {
namespace distributed {
namespace rpc {
using MessageHandler = void (*)(std::unique_ptr<MessageBase> &&msg);
using DeleteCallBack = void (*)(const std::string &from, const std::string &to);
using ConnectionCallBack = void (*)(void *conn);

View File

@ -66,7 +66,7 @@ class EventLoop {
bool Initialize(const std::string &threadName);
void Finalize();
// Add task (eg. send message, reconnect etc.) to task queue of the event loop by user.
// Add task (eg. send message, reconnect etc.) to task queue of the event loop.
// These tasks are executed asynchronously.
int AddTask(std::function<void()> &&task);

View File

@ -90,7 +90,7 @@ int SocketOperation::SetSocketOptions(int sock_fd) {
return 0;
}
int SocketOperation::CreateServerSocket(sa_family_t family) {
int SocketOperation::CreateSocket(sa_family_t family) {
int ret = 0;
int fd = 0;
@ -292,7 +292,7 @@ int SocketOperation::Listen(const std::string &url) {
}
// create server socket
listenFd = CreateServerSocket(addr.sa.sa_family);
listenFd = CreateSocket(addr.sa.sa_family);
if (listenFd < 0) {
MS_LOG(ERROR) << "Failed to create socket, url: " << url.c_str();
return -1;

View File

@ -48,8 +48,8 @@ class SocketOperation {
// Get socket address of the url.
static bool GetSockAddr(const std::string &url, SocketAddress *addr);
// Create a server socket.
static int CreateServerSocket(sa_family_t family);
// Create a socket.
static int CreateSocket(sa_family_t family);
// Set socket options.
static int SetSocketOptions(int sock_fd);

View File

@ -0,0 +1,79 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "distributed/rpc/tcp/tcp_client.h"
namespace mindspore {
namespace distributed {
namespace rpc {
// Create and initialize the underlying TCP communicator on first use.
//
// Returns true when the communicator is ready: either it was already set up by an
// earlier successful call, or it was created and initialized now. Returns false when
// TCPComm::Initialize() reports failure.
bool TCPClient::Initialize() {
  if (tcp_comm_ != nullptr) {
    // Already initialized by a previous successful call.
    return true;
  }

  tcp_comm_ = std::make_unique<TCPComm>();
  MS_EXCEPTION_IF_NULL(tcp_comm_);
  if (!tcp_comm_->Initialize()) {
    // Do not keep a communicator whose initialization failed: otherwise the next call
    // to Initialize() would see a non-null pointer and wrongly report success.
    // NOTE(review): this relies on ~TCPComm releasing whatever was partially set up;
    // the destructor would run on this object at TCPClient destruction anyway.
    tcp_comm_.reset();
    return false;
  }
  return true;
}

// Shut down the underlying TCP communicator.
// Does nothing if the client was never (successfully) initialized.
void TCPClient::Finalize() {
  if (tcp_comm_ != nullptr) {
    tcp_comm_->Finalize();
  }
}
// Ask the communicator to connect to dst_url and wait until the connection is reported
// as established.
//
// dst_url:        address of the server to connect to.
// timeout_in_sec: maximum time to wait, in seconds.
//
// Returns true if the connection is established before the timeout expires, false on
// timeout or if the client has not been initialized.
bool TCPClient::Connect(const std::string &dst_url, size_t timeout_in_sec) {
  if (tcp_comm_ == nullptr) {
    return false;
  }
  tcp_comm_->Connect(dst_url);

  // Poll every 100ms. Counting polls in size_t (instead of subtracting microseconds from
  // an int) keeps large timeouts from being truncated into a negative or misaligned
  // value that the loop would never count down to zero.
  const size_t poll_interval_us = 100000;
  const size_t polls_per_sec = 1000 * 1000 / poll_interval_us;
  const size_t max_polls = timeout_in_sec * polls_per_sec;
  for (size_t i = 0; i < max_polls; ++i) {
    if (tcp_comm_->IsConnected(dst_url)) {
      return true;
    }
    (void)usleep(poll_interval_us);
  }
  // One last look so that a connection completed during the final sleep is not reported
  // as a timeout.
  return tcp_comm_->IsConnected(dst_url);
}
// Ask the communicator to disconnect from dst_url and wait until the connection is no
// longer reported as connected.
//
// dst_url:        address of the server to disconnect from.
// timeout_in_sec: maximum time to wait, in seconds.
//
// Returns true if the connection is gone before the timeout expires, false on timeout
// or if the client has not been initialized.
bool TCPClient::Disconnect(const std::string &dst_url, size_t timeout_in_sec) {
  if (tcp_comm_ == nullptr) {
    return false;
  }
  tcp_comm_->Disconnect(dst_url);

  // Poll every 100ms. Counting polls in size_t (instead of subtracting microseconds from
  // an int) keeps large timeouts from being truncated into a negative or misaligned
  // value that the loop would never count down to zero.
  const size_t poll_interval_us = 100000;
  const size_t polls_per_sec = 1000 * 1000 / poll_interval_us;
  const size_t max_polls = timeout_in_sec * polls_per_sec;
  for (size_t i = 0; i < max_polls; ++i) {
    if (!tcp_comm_->IsConnected(dst_url)) {
      return true;
    }
    (void)usleep(poll_interval_us);
  }
  // One last look so that a disconnect completed during the final sleep is not reported
  // as a timeout.
  return !tcp_comm_->IsConnected(dst_url);
}
int TCPClient::Send(std::unique_ptr<MessageBase> &&msg) {
int rt = -1;
rt = tcp_comm_->Send(msg.release());
return rt;
}
} // namespace rpc
} // namespace distributed
} // namespace mindspore

View File

@ -0,0 +1,57 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_DISTRIBUTED_RPC_TCP_TCP_CLIENT_H_
#define MINDSPORE_CCSRC_DISTRIBUTED_RPC_TCP_TCP_CLIENT_H_
#include <string>
#include <memory>
#include "distributed/rpc/tcp/tcp_comm.h"
#include "utils/ms_utils.h"
namespace mindspore {
namespace distributed {
namespace rpc {
// Client side of the TCP based RPC transport. It owns a TCPComm instance and forwards
// connect, disconnect and send requests to it.
class TCPClient {
public:
TCPClient() = default;
~TCPClient() = default;
// Build or destroy the TCP client.
// Initialize() creates the underlying TCPComm on first use and returns whether it is ready;
// the other methods operate on that TCPComm, so call Initialize() before using them.
bool Initialize();
// Finalize() shuts down the underlying TCPComm.
void Finalize();
// Connect to the specified server.
// Issues the connect request for dst_url and then polls the connection state; returns true
// once dst_url is reported as connected, or false if that has not happened after about
// timeout_in_sec seconds.
bool Connect(const std::string &dst_url, size_t timeout_in_sec = 5);
// Disconnect from the specified server.
// Issues the disconnect request for dst_url and then polls the connection state; returns true
// once dst_url is no longer reported as connected, or false if it still is after about
// timeout_in_sec seconds.
bool Disconnect(const std::string &dst_url, size_t timeout_in_sec = 5);
// Send the message from the source to the destination.
// The message is released from the unique_ptr and passed to TCPComm::Send(); the returned int
// is the result of that call.
int Send(std::unique_ptr<MessageBase> &&msg);
private:
// The basic TCP communication component used by the client.
// Null until Initialize() has created it.
std::unique_ptr<TCPComm> tcp_comm_;
// TCPClient is not meant to be copied (it exclusively owns tcp_comm_).
DISABLE_COPY_AND_ASSIGN(TCPClient);
};
} // namespace rpc
} // namespace distributed
} // namespace mindspore
#endif

View File

@ -21,47 +21,12 @@
#include <memory>
#include "actor/aid.h"
#include "actor/msg.h"
#include "distributed/rpc/tcp/constants.h"
#include "distributed/rpc/tcp/tcp_socket_operation.h"
#include "distributed/rpc/tcp/connection_pool.h"
namespace mindspore {
namespace distributed {
namespace rpc {
bool TCPComm::is_http_msg_ = false;
std::vector<char> TCPComm::advertise_url_;
uint64_t TCPComm::output_buf_size_ = 0;
IOMgr::MessageHandler TCPComm::message_handler_;
int DoConnect(const std::string &to, Connection *conn, ConnectionCallBack event_callback,
ConnectionCallBack write_callback, ConnectionCallBack read_callback) {
SocketAddress addr;
if (!SocketOperation::GetSockAddr(to, &addr)) {
return -1;
}
int sock_fd = SocketOperation::CreateServerSocket(addr.sa.sa_family);
if (sock_fd < 0) {
return -1;
}
conn->socket_fd = sock_fd;
conn->event_callback = event_callback;
conn->write_callback = write_callback;
conn->read_callback = read_callback;
int ret = TCPComm::Connect(conn, (struct sockaddr *)&addr, sizeof(addr));
if (ret < 0) {
if (close(sock_fd) != 0) {
MS_LOG(ERROR) << "Failed to close fd:" << sock_fd;
}
conn->socket_fd = -1;
return -1;
}
return 0;
}
void DoDisconnect(int fd, Connection *conn, uint32_t error, int soError) {
if (LOG_CHECK_EVERY_N()) {
MS_LOG(INFO) << "Failed to call connect, fd: " << fd << ", to: " << conn->destination.c_str()
@ -137,13 +102,16 @@ void OnAccept(int server, uint32_t events, void *arg) {
}
conn->socket_fd = acceptFd;
conn->source = TCPComm::advertise_url_.data();
conn->source = tcpmgr->url_;
conn->peer = SocketOperation::GetPeer(acceptFd);
conn->is_remote = true;
conn->recv_event_loop = tcpmgr->recv_event_loop_;
conn->send_event_loop = tcpmgr->send_event_loop_;
conn->conn_mutex = tcpmgr->conn_mutex_;
conn->message_handler = tcpmgr->message_handler_;
conn->event_callback = TCPComm::EventCallBack;
conn->write_callback = TCPComm::WriteCallBack;
conn->read_callback = TCPComm::ReadCallBack;
@ -165,7 +133,7 @@ void OnAccept(int server, uint32_t events, void *arg) {
void DoSend(Connection *conn) {
while (!conn->send_message_queue.empty() || conn->total_send_len != 0) {
if (conn->total_send_len == 0) {
conn->FillSendMessage(conn->send_message_queue.front(), TCPComm::advertise_url_.data(), TCPComm::IsHttpMsg());
conn->FillSendMessage(conn->send_message_queue.front(), conn->source, false);
conn->send_message_queue.pop();
}
@ -175,7 +143,6 @@ void DoSend(Connection *conn) {
// update metrics
conn->send_metrics->UpdateError(false);
TCPComm::output_buf_size_ -= conn->send_message->body.size();
conn->output_buffer_size -= conn->send_message->body.size();
delete conn->send_message;
conn->send_message = nullptr;
@ -201,25 +168,15 @@ TCPComm::~TCPComm() {
}
}
void TCPComm::SendExitMsg(const std::string &from, const std::string &to) {
if (message_handler_ != nullptr) {
std::unique_ptr<MessageBase> exit_msg = std::make_unique<MessageBase>(MessageBase::Type::KEXIT);
MS_EXCEPTION_IF_NULL(exit_msg);
exit_msg->SetFrom(AID(from));
exit_msg->SetTo(AID(to));
message_handler_(std::move(exit_msg));
}
}
void TCPComm::SetMessageHandler(IOMgr::MessageHandler handler) { message_handler_ = handler; }
void TCPComm::SetMessageHandler(MessageHandler handler) { message_handler_ = handler; }
bool TCPComm::Initialize() {
if (ConnectionPool::GetConnectionPool() == nullptr) {
MS_LOG(ERROR) << "Failed to create connection pool.";
return false;
}
conn_pool_ = std::make_shared<ConnectionPool>();
MS_EXCEPTION_IF_NULL(conn_pool_);
conn_mutex_ = std::make_shared<std::mutex>();
MS_EXCEPTION_IF_NULL(conn_mutex_);
recv_event_loop_ = new (std::nothrow) EventLoop();
if (recv_event_loop_ == nullptr) {
MS_LOG(ERROR) << "Failed to create recv evLoop.";
@ -251,62 +208,35 @@ bool TCPComm::Initialize() {
return false;
}
if (g_httpKmsgEnable >= 0) {
is_http_msg_ = (g_httpKmsgEnable == 0) ? false : true;
}
ConnectionPool::GetConnectionPool()->SetLinkPattern(is_http_msg_);
return true;
}
bool TCPComm::StartServerSocket(const std::string &url, const std::string &aAdvertiseUrl) {
bool TCPComm::StartServerSocket(const std::string &url) {
server_fd_ = SocketOperation::Listen(url);
if (server_fd_ < 0) {
MS_LOG(ERROR) << "Failed to call socket listen, url: " << url.c_str()
<< ", advertise_url_: " << advertise_url_.data();
MS_LOG(ERROR) << "Failed to call socket listen, url: " << url.c_str();
return false;
}
url_ = url;
std::string tmp_url;
if (aAdvertiseUrl.size() > 0) {
advertise_url_.resize(aAdvertiseUrl.size());
advertise_url_.assign(aAdvertiseUrl.begin(), aAdvertiseUrl.end());
tmp_url = aAdvertiseUrl;
} else {
advertise_url_.resize(url_.size());
advertise_url_.assign(url_.begin(), url_.end());
tmp_url = url_;
}
size_t index = url.find(URL_PROTOCOL_IP_SEPARATOR);
if (index != std::string::npos) {
url_ = url.substr(index + sizeof(URL_PROTOCOL_IP_SEPARATOR) - 1);
}
index = tmp_url.find(URL_PROTOCOL_IP_SEPARATOR);
if (index != std::string::npos) {
tmp_url = tmp_url.substr(index + sizeof(URL_PROTOCOL_IP_SEPARATOR) - 1);
advertise_url_.resize(tmp_url.size());
advertise_url_.assign(tmp_url.begin(), tmp_url.end());
}
// Register read event callback for server socket
int retval = recv_event_loop_->SetEventHandler(server_fd_, EPOLLIN | EPOLLHUP | EPOLLERR, OnAccept,
reinterpret_cast<void *>(this));
if (retval != RPC_OK) {
MS_LOG(ERROR) << "Failed to add server event, url: " << url.c_str()
<< ", advertise_url_: " << advertise_url_.data();
MS_LOG(ERROR) << "Failed to add server event, url: " << url.c_str();
return false;
}
MS_LOG(INFO) << "Start server succ, fd: " << server_fd_ << ", url: " << url.c_str()
<< ", advertise_url_ :" << advertise_url_.data();
MS_LOG(INFO) << "Start server succ, fd: " << server_fd_ << ", url: " << url.c_str();
return true;
}
void TCPComm::ReadCallBack(void *context) {
void TCPComm::ReadCallBack(void *connection) {
const int max_recv_count = 3;
Connection *conn = reinterpret_cast<Connection *>(context);
Connection *conn = reinterpret_cast<Connection *>(connection);
int count = 0;
int retval = 0;
do {
@ -317,35 +247,35 @@ void TCPComm::ReadCallBack(void *context) {
return;
}
void TCPComm::EventCallBack(void *context) {
Connection *conn = reinterpret_cast<Connection *>(context);
void TCPComm::EventCallBack(void *connection) {
Connection *conn = reinterpret_cast<Connection *>(connection);
if (conn->state == ConnectionState::kConnected) {
Connection::conn_mutex.lock();
conn->conn_mutex->lock();
DoSend(conn);
Connection::conn_mutex.unlock();
conn->conn_mutex->unlock();
} else if (conn->state == ConnectionState::kDisconnecting) {
Connection::conn_mutex.lock();
output_buf_size_ -= conn->output_buffer_size;
ConnectionPool::GetConnectionPool()->CloseConnection(conn);
Connection::conn_mutex.unlock();
conn->conn_mutex->lock();
conn->conn_mutex->unlock();
}
}
void TCPComm::WriteCallBack(void *context) {
Connection *conn = reinterpret_cast<Connection *>(context);
void TCPComm::WriteCallBack(void *connection) {
Connection *conn = reinterpret_cast<Connection *>(connection);
if (conn->state == ConnectionState::kConnected) {
Connection::conn_mutex.lock();
conn->conn_mutex->lock();
DoSend(conn);
Connection::conn_mutex.unlock();
conn->conn_mutex->unlock();
}
}
/* static method */
int TCPComm::ReceiveMessage(Connection *conn) {
std::lock_guard<std::mutex> lock(*conn->conn_mutex);
conn->CheckMessageType();
switch (conn->recv_message_type) {
case ParseType::kTcpMsg:
return conn->ReceiveMessage(message_handler_);
return conn->ReceiveMessage();
#ifdef HTTP_ENABLED
case ParseType::KHTTP_REQ:
@ -370,6 +300,7 @@ int TCPComm::ReceiveMessage(Connection *conn) {
}
}
/* static method */
int TCPComm::SetConnectedHandler(Connection *conn) {
/* add to epoll */
return conn->recv_event_loop->SetEventHandler(conn->socket_fd,
@ -377,7 +308,8 @@ int TCPComm::SetConnectedHandler(Connection *conn) {
ConnectedEventHandler, reinterpret_cast<void *>(conn));
}
int TCPComm::Connect(Connection *conn, const struct sockaddr *sa, socklen_t saLen) {
/* static method */
int TCPComm::DoConnect(Connection *conn, const struct sockaddr *sa, socklen_t saLen) {
int retval = 0;
uint16_t localPort = 0;
@ -406,195 +338,93 @@ int TCPComm::Connect(Connection *conn, const struct sockaddr *sa, socklen_t saLe
return RPC_OK;
}
void TCPComm::Send(MessageBase *msg, const TCPComm *tcpmgr, bool remoteLink, bool isExactNotRemote) {
std::lock_guard<std::mutex> lock(Connection::conn_mutex);
Connection *conn = ConnectionPool::GetConnectionPool()->FindConnection(msg->to.Url(), remoteLink, isExactNotRemote);
// Create a new connection if the connection to target of the message does not existed.
if (conn == nullptr) {
if (remoteLink && (!isExactNotRemote)) {
MS_LOG(ERROR) << "Could not found remote link and send fail name: " << msg->name.c_str()
<< ", from: " << advertise_url_.data() << ", to: " << msg->to.Url().c_str();
delete msg;
return;
}
conn = new (std::nothrow) Connection();
if (conn == nullptr) {
MS_LOG(ERROR) << "Failed to create new connection and send fail name: " << msg->name.c_str()
<< ", from: " << advertise_url_.data() << ", to: " << msg->to.Url().c_str();
delete msg;
return;
}
conn->source = advertise_url_.data();
conn->destination = msg->to.Url();
conn->recv_event_loop = tcpmgr->recv_event_loop_;
conn->send_event_loop = tcpmgr->send_event_loop_;
conn->InitSocketOperation();
int ret = DoConnect(msg->to.Url(), conn, TCPComm::EventCallBack, TCPComm::WriteCallBack, TCPComm::ReadCallBack);
if (ret < 0) {
MS_LOG(ERROR) << "Failed to do connect and send fail name: " << msg->name.c_str()
<< ", from: " << advertise_url_.data() << ", to: " << msg->to.Url().c_str();
if (conn->socket_operation != nullptr) {
delete conn->socket_operation;
conn->socket_operation = nullptr;
}
delete conn;
delete msg;
return;
}
ConnectionPool::GetConnectionPool()->AddConnection(conn);
}
if (!conn->is_remote && !isExactNotRemote && conn->priority == ConnectionPriority::kPriorityLow) {
Connection *remoteConn = ConnectionPool::GetConnectionPool()->ExactFindConnection(msg->to.Url(), true);
if (remoteConn != nullptr && remoteConn->state == ConnectionState::kConnected) {
conn = remoteConn;
}
}
// Prepare the message.
if (conn->total_send_len == 0) {
conn->FillSendMessage(msg, advertise_url_.data(), is_http_msg_);
} else {
(void)conn->send_message_queue.emplace(msg);
}
// Send the message.
if (conn->state == ConnectionState::kConnected) {
DoSend(conn);
}
/* static method */
// Discard a message that will not be sent: takes ownership of msg and frees it.
// Passing a null pointer is harmless, since deleting nullptr is a no-op.
void TCPComm::DropMessage(MessageBase *msg) {
  delete msg;
}
void TCPComm::SendByRecvLoop(MessageBase *msg, const TCPComm *tcpmgr, bool remoteLink, bool isExactNotRemote) {
(void)recv_event_loop_->AddTask(
[msg, tcpmgr, remoteLink, isExactNotRemote] { TCPComm::Send(msg, tcpmgr, remoteLink, isExactNotRemote); });
}
int TCPComm::Send(MessageBase *msg, bool remoteLink, bool isExactNotRemote) {
return send_event_loop_->AddTask([msg, this, remoteLink, isExactNotRemote] {
std::lock_guard<std::mutex> lock(Connection::conn_mutex);
int TCPComm::Send(MessageBase *msg) {
return send_event_loop_->AddTask([msg, this] {
std::lock_guard<std::mutex> lock(*conn_mutex_);
// Search connection by the target address
bool exactNotRemote = is_http_msg_ || isExactNotRemote;
Connection *conn = ConnectionPool::GetConnectionPool()->FindConnection(msg->to.Url(), remoteLink, exactNotRemote);
Connection *conn = conn_pool_->FindConnection(msg->to.Url());
if (conn == nullptr) {
if (remoteLink && (!exactNotRemote)) {
MS_LOG(ERROR) << "Can not found remote link and send fail name: " << msg->name.c_str()
<< ", from: " << advertise_url_.data() << ", to: " << msg->to.Url().c_str();
auto *ptr = msg;
delete ptr;
ptr = nullptr;
return;
}
this->SendByRecvLoop(msg, this, remoteLink, exactNotRemote);
MS_LOG(ERROR) << "Can not found remote link and send fail name: " << msg->name.c_str()
<< ", from: " << msg->from.Url().c_str() << ", to: " << msg->to.Url().c_str();
DropMessage(msg);
return;
}
if (conn->state != kConnected && conn->send_message_queue.size() >= SENDMSG_QUEUELEN) {
MS_LOG(WARNING) << "The name of dropped message is: " << msg->name.c_str() << ", fd: " << conn->socket_fd
<< ", to: " << conn->destination.c_str() << ", remote: " << conn->is_remote;
auto *ptr = msg;
delete ptr;
ptr = nullptr;
if (conn->send_message_queue.size() >= SENDMSG_QUEUELEN) {
MS_LOG(WARNING) << "The message queue is full(max len:" << SENDMSG_QUEUELEN
<< ") and the name of dropped message is: " << msg->name.c_str() << ", fd: " << conn->socket_fd
<< ", to: " << conn->destination.c_str();
DropMessage(msg);
return;
}
if (conn->state == ConnectionState::kClose || conn->state == ConnectionState::kDisconnecting) {
this->SendByRecvLoop(msg, this, remoteLink, exactNotRemote);
if (conn->state != ConnectionState::kConnected) {
MS_LOG(WARNING) << "Invalid connection state " << conn->state
<< " and the name of dropped message is: " << msg->name.c_str() << ", fd: " << conn->socket_fd
<< ", to: " << conn->destination.c_str();
DropMessage(msg);
return;
}
if (!conn->is_remote && !exactNotRemote && conn->priority == ConnectionPriority::kPriorityLow) {
Connection *remoteConn = ConnectionPool::GetConnectionPool()->ExactFindConnection(msg->to.Url(), true);
if (remoteConn != nullptr && remoteConn->state == ConnectionState::kConnected) {
conn = remoteConn;
}
}
output_buf_size_ += msg->body.size();
if (conn->total_send_len == 0) {
conn->FillSendMessage(msg, advertise_url_.data(), is_http_msg_);
conn->FillSendMessage(msg, url_, false);
} else {
(void)conn->send_message_queue.emplace(msg);
}
if (conn->state == ConnectionState::kConnected) {
DoSend(conn);
}
DoSend(conn);
});
}
// Queue a task on the send event loop that samples the metrics of the connections returned by
// FindMaxConnection() and FindFastConnection() and then resets the metrics of every pooled connection.
void TCPComm::CollectMetrics() {
  (void)send_event_loop_->AddTask([this] {
    // Scoped guard instead of manual lock()/unlock(): the mutex is released on every exit path,
    // including an exception raised while the metrics are being collected.
    std::lock_guard<std::mutex> lock(Connection::conn_mutex);
    Connection *maxConn = ConnectionPool::GetConnectionPool()->FindMaxConnection();
    Connection *fastConn = ConnectionPool::GetConnectionPool()->FindFastConnection();
    if (message_handler_ != nullptr) {
      IntTypeMetrics intMetrics;
      StringTypeMetrics stringMetrics;
      if (maxConn != nullptr) {
        intMetrics.push(maxConn->socket_fd);
        intMetrics.push(maxConn->error_code);
        intMetrics.push(maxConn->send_metrics->accum_msg_count);
        intMetrics.push(maxConn->send_metrics->max_msg_size);
        stringMetrics.push(maxConn->destination);
        stringMetrics.push(maxConn->send_metrics->last_succ_msg_name);
        stringMetrics.push(maxConn->send_metrics->last_fail_msg_name);
      }
      // NOTE(review): the fast connection is sampled only when IsSame(maxConn) holds, which looks like it
      // records the same connection twice — confirm whether `!IsSame` was intended.
      if (fastConn != nullptr && fastConn->IsSame(maxConn)) {
        intMetrics.push(fastConn->socket_fd);
        intMetrics.push(fastConn->error_code);
        intMetrics.push(fastConn->send_metrics->accum_msg_count);
        intMetrics.push(fastConn->send_metrics->max_msg_size);
        stringMetrics.push(fastConn->destination);
        stringMetrics.push(fastConn->send_metrics->last_succ_msg_name);
        stringMetrics.push(fastConn->send_metrics->last_fail_msg_name);
      }
      // NOTE(review): intMetrics/stringMetrics are not handed to anything in this function — confirm
      // whether they were meant to be delivered through message_handler_.
    }
    ConnectionPool::GetConnectionPool()->ResetAllConnMetrics();
  });
}
int TCPComm::Send(std::unique_ptr<MessageBase> &&msg, bool remoteLink, bool isExactNotRemote) {
return Send(msg.release(), remoteLink, isExactNotRemote);
}
void TCPComm::Link(const AID &source, const AID &destination) {
(void)recv_event_loop_->AddTask([source, destination, this] {
std::string to = destination.Url();
std::lock_guard<std::mutex> lock(Connection::conn_mutex);
void TCPComm::Connect(const std::string &dst_url) {
(void)recv_event_loop_->AddTask([dst_url, this] {
std::lock_guard<std::mutex> lock(*conn_mutex_);
// Search connection by the target address
Connection *conn = ConnectionPool::GetConnectionPool()->FindConnection(to, false, is_http_msg_);
Connection *conn = conn_pool_->FindConnection(dst_url);
if (conn == nullptr) {
MS_LOG(INFO) << "Can not found link source: " << std::string(source).c_str()
<< ", destination: " << std::string(destination).c_str();
MS_LOG(INFO) << "Can not found link destination: " << dst_url;
conn = new (std::nothrow) Connection();
if (conn == nullptr) {
MS_LOG(ERROR) << "Failed to create new connection and link fail source: " << std::string(source).c_str()
<< ", destination: " << std::string(destination).c_str();
SendExitMsg(source, destination);
MS_LOG(ERROR) << "Failed to create new connection and link fail destination: " << dst_url;
return;
}
conn->source = advertise_url_.data();
conn->destination = to;
conn->source = url_;
conn->destination = dst_url;
conn->recv_event_loop = this->recv_event_loop_;
conn->send_event_loop = this->send_event_loop_;
conn->conn_mutex = conn_mutex_;
conn->message_handler = message_handler_;
conn->InitSocketOperation();
int ret = DoConnect(to, conn, TCPComm::EventCallBack, TCPComm::WriteCallBack, TCPComm::ReadCallBack);
// Create the client socket.
SocketAddress addr;
if (!SocketOperation::GetSockAddr(dst_url, &addr)) {
MS_LOG(ERROR) << "Failed to get socket address to dest url " << dst_url;
return;
}
int sock_fd = SocketOperation::CreateSocket(addr.sa.sa_family);
if (sock_fd < 0) {
MS_LOG(ERROR) << "Failed to create client tcp socket to dest url " << dst_url;
return;
}
conn->socket_fd = sock_fd;
conn->event_callback = TCPComm::EventCallBack;
conn->write_callback = TCPComm::WriteCallBack;
conn->read_callback = TCPComm::ReadCallBack;
int ret = TCPComm::DoConnect(conn, (struct sockaddr *)&addr, sizeof(addr));
if (ret < 0) {
MS_LOG(ERROR) << "Failed to do connect and link fail source: " << std::string(source).c_str()
<< ", destination: " << std::string(destination).c_str();
SendExitMsg(source, destination);
MS_LOG(ERROR) << "Failed to do connect and link fail destination: " << dst_url;
if (conn->socket_operation != nullptr) {
delete conn->socket_operation;
conn->socket_operation = nullptr;
@ -602,76 +432,26 @@ void TCPComm::Link(const AID &source, const AID &destination) {
delete conn;
return;
}
ConnectionPool::GetConnectionPool()->AddConnection(conn);
conn_pool_->AddConnection(conn);
}
ConnectionPool::GetConnectionPool()->AddConnInfo(conn->socket_fd, source, destination, SendExitMsg);
MS_LOG(INFO) << "Link fd: " << conn->socket_fd << ", source: " << std::string(source).c_str()
<< ", destination: " << std::string(destination).c_str() << ", remote: " << conn->is_remote;
conn_pool_->AddConnInfo(conn->socket_fd, dst_url, nullptr);
MS_LOG(INFO) << "Connected to destination: " << dst_url;
});
}
void TCPComm::UnLink(const AID &destination) {
(void)recv_event_loop_->AddTask([destination] {
std::string to = destination.Url();
std::lock_guard<std::mutex> lock(Connection::conn_mutex);
if (is_http_msg_) {
// When application has set 'LITERPC_HTTPKMSG_ENABLED',it means sending-link is in links map
// while accepting-link is differently in remoteLinks map. So we only need to delete link in exact links.
ConnectionPool::GetConnectionPool()->ExactDeleteConnection(to, false);
} else {
// When application hasn't set 'LITERPC_HTTPKMSG_ENABLED',it means sending-link and accepting-link is
// shared
// So we need to delete link in both links map and remote-links map.
ConnectionPool::GetConnectionPool()->ExactDeleteConnection(to, false);
ConnectionPool::GetConnectionPool()->ExactDeleteConnection(to, true);
}
});
bool TCPComm::IsConnected(const std::string &dst_url) {
Connection *conn = conn_pool_->FindConnection(dst_url);
if (conn != nullptr && conn->state == ConnectionState::kConnected) {
return true;
}
return false;
}
void TCPComm::DoReConnectConn(Connection *conn, std::string to, const AID &source, const AID &destination, int *oldFd) {
if (!is_http_msg_ && !conn->is_remote) {
Connection *remoteConn = ConnectionPool::GetConnectionPool()->ExactFindConnection(to, true);
// We will close remote link in rare cases where sending-link and accepting link coexists
// simultaneously.
if (remoteConn != nullptr) {
MS_LOG(INFO) << "Reconnect, close remote connect fd :" << remoteConn->socket_fd
<< ", source: " << std::string(source).c_str()
<< ", destination: " << std::string(destination).c_str() << ", remote: " << remoteConn->is_remote
<< ", state: " << remoteConn->state;
ConnectionPool::GetConnectionPool()->CloseConnection(remoteConn);
}
}
MS_LOG(INFO) << "Reconnect, close old connect fd: " << conn->socket_fd << ", source: " << std::string(source).c_str()
<< ", destination: " << std::string(destination).c_str() << ", remote: " << conn->is_remote
<< ", state: " << conn->state;
*oldFd = conn->socket_fd;
if (conn->recv_event_loop->DeleteEpollEvent(conn->socket_fd) == RPC_ERROR) {
MS_LOG(ERROR) << "Failed to delete epoll event: " << conn->socket_fd;
}
conn->socket_operation->Close(conn);
conn->socket_fd = -1;
conn->recv_len = 0;
conn->total_recv_len = 0;
conn->recv_message_type = kUnknown;
conn->state = kInit;
if (conn->total_send_len != 0 && conn->send_message != nullptr) {
delete conn->send_message;
}
conn->send_message = nullptr;
conn->total_send_len = 0;
if (conn->total_recv_len != 0 && conn->recv_message != nullptr) {
delete conn->recv_message;
}
conn->recv_message = nullptr;
conn->total_recv_len = 0;
conn->recv_state = State::kMsgHeader;
// Queue a task on the receive event loop that deletes the connection for dst_url from the pool.
void TCPComm::Disconnect(const std::string &dst_url) {
  auto remove_connection = [this, dst_url] {
    // The pool is modified only while the shared connection mutex is held.
    std::lock_guard<std::mutex> guard(*conn_mutex_);
    conn_pool_->DeleteConnection(dst_url);
  };
  (void)recv_event_loop_->AddTask(remove_connection);
}
Connection *TCPComm::CreateDefaultConn(std::string to) {
@ -680,66 +460,16 @@ Connection *TCPComm::CreateDefaultConn(std::string to) {
MS_LOG(ERROR) << "Failed to create new connection and reconnect fail to: " << to.c_str();
return conn;
}
conn->source = advertise_url_.data();
conn->source = url_.data();
conn->destination = to;
conn->recv_event_loop = this->recv_event_loop_;
conn->send_event_loop = this->send_event_loop_;
conn->conn_mutex = conn_mutex_;
conn->message_handler = message_handler_;
conn->InitSocketOperation();
return conn;
}
// Re-establish the connection towards `destination` in two queued steps: a task on the send event loop marks
// the existing connection (if any) as closed and then queues a task on the receive event loop, which resets
// or creates the connection object, connects again and updates the connection pool.
// `source` and `destination` are captured by value and used for the link info and the log messages.
void TCPComm::Reconnect(const AID &source, const AID &destination) {
(void)send_event_loop_->AddTask([source, destination, this] {
std::string to = destination.Url();
std::lock_guard<std::mutex> lock(Connection::conn_mutex);
Connection *conn = ConnectionPool::GetConnectionPool()->FindConnection(to, false, is_http_msg_);
// Flag the current connection as closed before the reconnect task is queued.
if (conn != nullptr) {
conn->state = ConnectionState::kClose;
}
// NOTE(review): this task is queued while conn_mutex is still held by the enclosing lambda; presumably
// AddTask only enqueues and never runs the task inline — confirm, since the task locks conn_mutex again.
(void)recv_event_loop_->AddTask([source, destination, this] {
std::string to = destination.Url();
// Receives the previous socket fd from DoReConnectConn; stays -1 when a new connection object is created.
int oldFd = -1;
std::lock_guard<std::mutex> lock(Connection::conn_mutex);
Connection *conn = ConnectionPool::GetConnectionPool()->FindConnection(to, false, is_http_msg_);
if (conn != nullptr) {
// The connection already exists: reset it for reuse and remember its old fd.
DoReConnectConn(conn, to, source, destination, &oldFd);
} else {
// No connection yet: create a default one; give up if the creation fails.
conn = CreateDefaultConn(to);
if (conn == nullptr) {
return;
}
}
int ret = DoConnect(to, conn, TCPComm::EventCallBack, TCPComm::WriteCallBack, TCPComm::ReadCallBack);
if (ret < 0) {
// Connect failed: free the socket operation object, restore the old fd (if there was one) and let the
// pool close the connection.
if (conn->socket_operation != nullptr) {
delete conn->socket_operation;
conn->socket_operation = nullptr;
}
if (oldFd != -1) {
conn->socket_fd = oldFd;
}
MS_LOG(ERROR) << "Failed to connect and reconnect fail source: " << std::string(source).c_str()
<< ", destination: " << std::string(destination).c_str();
ConnectionPool::GetConnectionPool()->CloseConnection(conn);
return;
}
if (oldFd != -1) {
// A reused connection got a new socket: ask the pool to switch its info from the old fd to the new fd
// (presumably a swap, judging by the error message below).
if (!ConnectionPool::GetConnectionPool()->ReverseConnInfo(oldFd, conn->socket_fd)) {
MS_LOG(ERROR) << "Failed to swap socket for " << oldFd << " and " << conn->socket_fd;
}
} else {
// A newly created connection object: register it in the pool.
ConnectionPool::GetConnectionPool()->AddConnection(conn);
}
// Record the link info for the current socket fd together with the SendExitMsg callback.
ConnectionPool::GetConnectionPool()->AddConnInfo(conn->socket_fd, source, destination, SendExitMsg);
MS_LOG(INFO) << "Reconnect fd: " << conn->socket_fd << ", source: " << std::string(source).c_str()
<< ", destination: " << std::string(destination).c_str();
});
});
}
void TCPComm::Finalize() {
if (send_event_loop_ != nullptr) {
MS_LOG(INFO) << "Delete send event loop";
@ -762,13 +492,6 @@ void TCPComm::Finalize() {
server_fd_ = -1;
}
}
// The input buffer size is not used for the sub-class TCPComm, so a constant placeholder of 1 is returned.
uint64_t TCPComm::GetInBufSize() { return 1; }
// Returns the current value of the output_buf_size_ counter.
uint64_t TCPComm::GetOutBufSize() { return output_buf_size_; }
// Returns the is_http_msg_ flag.
bool TCPComm::IsHttpMsg() { return is_http_msg_; }
} // namespace rpc
} // namespace distributed
} // namespace mindspore

View File

@ -20,9 +20,11 @@
#include <string>
#include <memory>
#include <vector>
#include <mutex>
#include "actor/iomgr.h"
#include "actor/msg.h"
#include "distributed/rpc/tcp/connection.h"
#include "distributed/rpc/tcp/connection_pool.h"
#include "distributed/rpc/tcp/event_loop.h"
namespace mindspore {
@ -34,15 +36,11 @@ void OnAccept(int server, uint32_t events, void *arg);
// Send messages buffered in the connection.
void DoSend(Connection *conn);
// Create a server socket and connect to it; this is a local connection.
int DoConnect(const std::string &to, Connection *conn, ConnectionCallBack eventCallBack,
ConnectionCallBack writeCallBack, ConnectionCallBack readCallBack);
void DoDisconnect(int fd, Connection *conn, uint32_t error, int soError);
void ConnectedEventHandler(int fd, uint32_t events, void *context);
class TCPComm : public IOMgr {
class TCPComm {
public:
TCPComm() : server_fd_(-1), recv_event_loop_(nullptr), send_event_loop_(nullptr) {}
TCPComm(const TCPComm &) = delete;
@ -50,52 +48,46 @@ class TCPComm : public IOMgr {
~TCPComm();
// Init the event loop for reading and writing.
bool Initialize() override;
bool Initialize();
// Destroy all the resources.
void Finalize() override;
void Finalize();
// Create the server socket represented by url.
bool StartServerSocket(const std::string &url, const std::string &aAdvertiseUrl) override;
bool StartServerSocket(const std::string &url);
// Build a connection between the source and destination.
void Link(const AID &source, const AID &destination) override;
void UnLink(const AID &destination) override;
// Connection operation for a specified destination.
void Connect(const std::string &dst_url);
bool IsConnected(const std::string &dst_url);
void Disconnect(const std::string &dst_url);
// Send the message from the source to the destination.
int Send(std::unique_ptr<MessageBase> &&msg, bool remoteLink = false, bool isExactNotRemote = false) override;
int Send(MessageBase *msg);
uint64_t GetInBufSize() override;
uint64_t GetOutBufSize() override;
void CollectMetrics() override;
// Set the message processing handler.
void SetMessageHandler(MessageHandler handler);
private:
// Build the connection.
Connection *CreateDefaultConn(std::string to);
void Reconnect(const AID &source, const AID &destination);
void DoReConnectConn(Connection *conn, std::string to, const AID &source, const AID &destination, int *oldFd);
// Send a message.
int Send(MessageBase *msg, bool remoteLink = false, bool isExactNotRemote = false);
static void Send(MessageBase *msg, const TCPComm *tcpmgr, bool remoteLink, bool isExactNotRemote);
void SendByRecvLoop(MessageBase *msg, const TCPComm *tcpmgr, bool remoteLink, bool isExactNotRemote);
static void SendExitMsg(const std::string &from, const std::string &to);
// Called by ReadCallBack when new message arrived.
static int ReceiveMessage(Connection *conn);
void SetMessageHandler(IOMgr::MessageHandler handler);
static int SetConnectedHandler(Connection *conn);
static int Connect(Connection *conn, const struct sockaddr *sa, socklen_t saLen);
static int DoConnect(Connection *conn, const struct sockaddr *sa, socklen_t saLen);
static bool IsHttpMsg();
static void DropMessage(MessageBase *msg);
// Read and write events.
static void ReadCallBack(void *context);
static void WriteCallBack(void *context);
static void ReadCallBack(void *conn);
static void WriteCallBack(void *conn);
// Connected and Disconnected events.
static void EventCallBack(void *context);
static void EventCallBack(void *conn);
// The server url.
std::string url_;
@ -103,21 +95,19 @@ class TCPComm : public IOMgr {
// The socket of server.
int server_fd_;
// The message size waiting to be sent.
static uint64_t output_buf_size_;
// User defined handler for Handling received messages.
static MessageHandler message_handler_;
// The source url of a message.
static std::vector<char> advertise_url_;
static bool is_http_msg_;
MessageHandler message_handler_;
// All the connections share the same read and write event loop objects.
EventLoop *recv_event_loop_;
EventLoop *send_event_loop_;
// The connection pool used to store new connections.
std::shared_ptr<ConnectionPool> conn_pool_;
// The mutex for connection operations.
std::shared_ptr<std::mutex> conn_mutex_;
friend void OnAccept(int server, uint32_t events, void *arg);
friend void DoSend(Connection *conn);
friend int DoConnect(const std::string &to, Connection *conn, ConnectionCallBack event_callback,

View File

@ -0,0 +1,47 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "distributed/rpc/tcp/tcp_server.h"
namespace mindspore {
namespace distributed {
namespace rpc {
// Create the underlying TCPComm, initialize it and start the server socket on `url`.
// Returns true when the server socket was started, or when the component already exists from an earlier
// successful call; returns false when the server socket could not be started.
// A failure to initialize the TCPComm itself is reported through MS_LOG(EXCEPTION).
bool TCPServer::Initialize(const std::string &url) {
  // Already set up by a previous successful call: nothing to do.
  if (tcp_comm_ != nullptr) {
    return true;
  }

  // Build the component in a local owner and publish it only after every step succeeded. Otherwise a
  // failed attempt would leave tcp_comm_ non-null and every later Initialize() call would report success
  // without a listening server socket.
  auto comm = std::make_unique<TCPComm>();
  MS_EXCEPTION_IF_NULL(comm);
  if (!comm->Initialize()) {
    MS_LOG(EXCEPTION) << "Failed to initialize tcp comm";
  }
  if (!comm->StartServerSocket(url)) {
    // `comm` is destroyed on return, so the caller may retry with another url.
    return false;
  }
  tcp_comm_.swap(comm);
  return true;
}
// Destroy the TCP communication component if one exists; afterwards tcp_comm_ is null.
void TCPServer::Finalize() {
  if (tcp_comm_ == nullptr) {
    return;
  }
  // reset() destroys the owned TCPComm and leaves the pointer empty.
  tcp_comm_.reset();
}
// Forward the message processing handler to the underlying TCPComm.
// tcp_comm_ only exists after Initialize(); a missing component is reported through MS_EXCEPTION_IF_NULL
// instead of dereferencing a null pointer.
void TCPServer::SetMessageHandler(MessageHandler handler) {
  MS_EXCEPTION_IF_NULL(tcp_comm_);
  tcp_comm_->SetMessageHandler(handler);
}
} // namespace rpc
} // namespace distributed
} // namespace mindspore

View File

@ -0,0 +1,53 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_DISTRIBUTED_RPC_TCP_TCP_SERVER_H_
#define MINDSPORE_CCSRC_DISTRIBUTED_RPC_TCP_TCP_SERVER_H_
#include <string>
#include <memory>
#include "distributed/rpc/tcp/tcp_comm.h"
#include "utils/ms_utils.h"
namespace mindspore {
namespace distributed {
namespace rpc {
// A TCP server built on top of TCPComm: it owns one TCPComm instance, starts its server socket on a given
// url and forwards the user's message handler to it. The class is not copyable.
class TCPServer {
public:
TCPServer() = default;
~TCPServer() = default;
// Init the tcp server: create and initialize the TCPComm and start the server socket on `url`.
// Returns the result of starting the server socket, or true if the TCPComm already exists.
bool Initialize(const std::string &url);
// Destroy the tcp server by releasing the owned TCPComm.
void Finalize();
// Set the message processing handler; the handler is forwarded to the owned TCPComm.
void SetMessageHandler(MessageHandler handler);
private:
// The basic TCP communication component used by the server; null until Initialize() creates it.
std::unique_ptr<TCPComm> tcp_comm_;
DISABLE_COPY_AND_ASSIGN(TCPServer);
};
} // namespace rpc
} // namespace distributed
} // namespace mindspore
#endif

View File

@ -24,9 +24,8 @@
#include <atomic>
#include <gtest/gtest.h>
#define private public
#include "actor/iomgr.h"
#include "async/async.h"
#include "distributed/rpc/tcp/tcp_comm.h"
#include "distributed/rpc/tcp/tcp_server.h"
#include "distributed/rpc/tcp/tcp_client.h"
#include "common/common_test.h"
namespace mindspore {
@ -35,26 +34,36 @@ namespace rpc {
int g_recv_num = 0;
int g_exit_msg_num = 0;
TCPComm *m_io = nullptr;
// Number of data messages counted so far. It is incremented by the tests' message handlers and polled
// from the test bodies; it is atomic because the handler presumably runs on the TCP event-loop thread
// rather than on the test thread.
static std::atomic<size_t> g_data_msg_num{0};

// Reset the data message counter; called at the start of each test that counts messages.
static void Init() { g_data_msg_num = 0; }

// Poll the data message counter every 100 ms until it equals `expected_msg_num` or `timeout_in_sec`
// seconds have elapsed. Returns true if the expected count was observed, false otherwise.
// A zero or negative timeout performs a single check without sleeping.
static bool WaitForDataMsg(size_t expected_msg_num, int timeout_in_sec) {
  constexpr int kPollIntervalUs = 100000;
  // Widened so that large timeouts cannot overflow the microsecond budget.
  long long remaining_us = static_cast<long long>(timeout_in_sec) * 1000 * 1000;
  // `> 0` rather than `!= 0`: a negative budget must end the loop instead of spinning forever.
  while (remaining_us > 0) {
    if (g_data_msg_num.load() == expected_msg_num) {
      return true;
    }
    remaining_us -= kPollIntervalUs;
    (void)usleep(kPollIntervalUs);
  }
  // Final check so that a message arriving during the last sleep is not missed.
  return g_data_msg_num.load() == expected_msg_num;
}

// Add `number` to the data message counter.
static void IncrDataMsgNum(size_t number) { g_data_msg_num += number; }

// Return the current value of the data message counter.
static size_t GetDataMsgNum() { return g_data_msg_num.load(); }
std::atomic<int> m_sendNum(0);
std::string m_localIP = "127.0.0.1";
bool m_notRemote = false;
// Count a received message: messages of type KEXIT bump g_exit_msg_num, all others bump g_recv_num.
void msgHandle(std::unique_ptr<MessageBase> &&msg) {
  const bool is_exit_msg = (msg->GetType() == MessageBase::Type::KEXIT);
  if (!is_exit_msg) {
    g_recv_num++;
    return;
  }
  g_exit_msg_num++;
}
class TCPTest : public UT::Common {
public:
static void SendMsg(std::string &_localUrl, std::string &_remoteUrl, int msgsize, bool remoteLink = false,
std::string body = "");
protected:
char *args[4];
char *testServerPath;
static const size_t pid_num = 100;
pid_t pid1;
@ -62,124 +71,24 @@ class TCPTest : public UT::Common {
pid_t pids[pid_num];
void SetUp() {
char *localpEnv = getenv("LITEBUS_IP");
if (localpEnv != nullptr) {
m_localIP = std::string(localpEnv);
}
void SetUp() {}
void TearDown() {}
char *locaNotRemoteEnv = getenv("LITEBUS_SEND_ON_REMOTE");
if (locaNotRemoteEnv != nullptr) {
m_notRemote = (std::string(locaNotRemoteEnv) == "true") ? true : false;
}
pid1 = 0;
pid2 = 0;
pids[pid_num] = {0};
size_t size = pid_num * sizeof(pid_t);
if (memset_s(&pids, size, 0, size)) {
MS_LOG(ERROR) << "Failed to init pid array";
}
g_recv_num = 0;
g_exit_msg_num = 0;
m_sendNum = 0;
m_io = new TCPComm();
m_io->Initialize();
m_io->SetMessageHandler(msgHandle);
m_io->StartServerSocket("tcp://" + m_localIP + ":2225", "tcp://" + m_localIP + ":2225");
}
void TearDown() {
shutdownTcpServer(pid1);
shutdownTcpServer(pid2);
pid1 = 0;
pid2 = 0;
int i = 0;
for (i = 0; i < pid_num; i++) {
shutdownTcpServer(pids[i]);
pids[i] = 0;
}
g_recv_num = 0;
g_exit_msg_num = 0;
m_sendNum = 0;
m_io->Finalize();
delete m_io;
m_io = nullptr;
}
std::unique_ptr<MessageBase> CreateMessage(const std::string &serverUrl, const std::string &client_url);
bool CheckRecvNum(int expectedRecvNum, int _timeout);
bool CheckExitNum(int expectedExitNum, int _timeout);
pid_t startTcpServer(char **args);
void shutdownTcpServer(pid_t pid);
void KillTcpServer(pid_t pid);
void Link(std::string &_localUrl, std::string &_remoteUrl);
void Reconnect(std::string &_localUrl, std::string &_remoteUrl);
void Unlink(std::string &_remoteUrl);
};
// Fork the current process. The parent receives the value returned by fork() (the child's pid, or -1 if
// the fork failed); the child receives -1. `args` is not used.
pid_t TCPTest::startTcpServer(char **args) {
  const pid_t forked = fork();
  return (forked == 0) ? -1 : forked;
}
// Send SIGALRM to process `pid` and wait for it with waitpid(); pids not greater than 1 are ignored.
void TCPTest::shutdownTcpServer(pid_t pid) {
  if (pid <= 1) {
    return;
  }
  kill(pid, SIGALRM);
  int exit_status;
  waitpid(pid, &exit_status, 0);
}
// Send SIGKILL to process `pid` and wait for it with waitpid(); pids not greater than 1 are ignored.
void TCPTest::KillTcpServer(pid_t pid) {
  if (pid <= 1) {
    return;
  }
  kill(pid, SIGKILL);
  int exit_status;
  waitpid(pid, &exit_status, 0);
}
void TCPTest::SendMsg(std::string &_localUrl, std::string &_remoteUrl, int msgsize, bool remoteLink, std::string body) {
AID from("testserver", _localUrl);
AID to("testserver", _remoteUrl);
std::unique_ptr<MessageBase> TCPTest::CreateMessage(const std::string &serverUrl, const std::string &clientUrl) {
std::unique_ptr<MessageBase> message = std::make_unique<MessageBase>();
std::string data(msgsize, 'A');
size_t len = 100;
std::string data(len, 'A');
message->name = "testname";
message->from = from;
message->to = to;
message->from = AID("client", clientUrl);
message->to = AID("server", serverUrl);
message->body = data;
if (body != "") {
message->body = body;
}
if (m_notRemote) {
m_io->Send(std::move(message), remoteLink, true);
} else {
m_io->Send(std::move(message), remoteLink);
}
}
void TCPTest::Link(std::string &_localUrl, std::string &_remoteUrl) {
AID from("testserver", _localUrl);
AID to("testserver", _remoteUrl);
m_io->Link(from, to);
}
void TCPTest::Reconnect(std::string &_localUrl, std::string &_remoteUrl) {
AID from("testserver", _localUrl);
AID to("testserver", _remoteUrl);
m_io->Reconnect(from, to);
}
void TCPTest::Unlink(std::string &_remoteUrl) {
AID to("testserver", _remoteUrl);
m_io->UnLink(to);
return message;
}
bool TCPTest::CheckRecvNum(int expectedRecvNum, int _timeout) {
@ -215,47 +124,98 @@ bool TCPTest::CheckExitNum(int expectedExitNum, int _timeout) {
/// Description: start a socket server with an invalid url.
/// Expectation: failed to start the server with invalid url.
TEST_F(TCPTest, StartServerFail) {
std::unique_ptr<TCPComm> io = std::make_unique<TCPComm>();
io->Initialize();
bool ret = io->StartServerSocket("tcp://0:2225", "tcp://0:2225");
std::unique_ptr<TCPServer> server = std::make_unique<TCPServer>();
bool ret = server->Initialize("0:2225");
ASSERT_FALSE(ret);
io->Finalize();
server->Finalize();
}
/// Feature: test start a socket server.
/// Description: start the socket server with a specified socket.
/// Expectation: the socket server is started successfully.
TEST_F(TCPTest, StartServer2) {
std::unique_ptr<TCPComm> io = std::make_unique<TCPComm>();
io->Initialize();
io->SetMessageHandler(msgHandle);
bool ret = io->StartServerSocket("tcp://" + m_localIP + ":2225", "tcp://" + m_localIP + ":2225");
ASSERT_FALSE(ret);
ret = io->StartServerSocket("tcp://" + m_localIP + ":2224", "tcp://" + m_localIP + ":2224");
io->Finalize();
TEST_F(TCPTest, StartServerSucc) {
std::unique_ptr<TCPServer> server = std::make_unique<TCPServer>();
bool ret = server->Initialize("127.0.0.1:8081");
ASSERT_TRUE(ret);
server->Finalize();
}
/// Feature: test normal tcp message sending.
/// Description: start a socket server and send a normal message to it.
/// Expectation: the server received the message sent from the client.
TEST_F(TCPTest, send1Msg) {
g_recv_num = 0;
pid1 = startTcpServer(args);
bool ret = CheckRecvNum(1, 5);
ASSERT_FALSE(ret);
TEST_F(TCPTest, SendOneMessage) {
Init();
std::string from = "tcp://" + m_localIP + ":2223";
std::string to = "tcp://" + m_localIP + ":2225";
SendMsg(from, to, pid_num);
ret = CheckRecvNum(1, 5);
// Start the tcp server.
auto server_url = "127.0.0.1:8081";
std::unique_ptr<TCPServer> server = std::make_unique<TCPServer>();
bool ret = server->Initialize(server_url);
ASSERT_TRUE(ret);
Unlink(to);
shutdownTcpServer(pid1);
pid1 = 0;
server->SetMessageHandler([](std::unique_ptr<MessageBase> &&message) -> void { IncrDataMsgNum(1); });
// Start the tcp client.
auto client_url = "127.0.0.1:1234";
std::unique_ptr<TCPClient> client = std::make_unique<TCPClient>();
ret = client->Initialize();
ASSERT_TRUE(ret);
// Create the message.
auto message = CreateMessage(server_url, client_url);
// Send the message.
client->Connect(server_url);
client->Send(std::move(message));
// Wait timeout: 5s
WaitForDataMsg(1, 5);
// Check result
EXPECT_EQ(1, GetDataMsgNum());
// Destroy
client->Disconnect(server_url);
client->Finalize();
server->Finalize();
}
/// Feature: test sending two messages consecutively.
/// Description: start a socket server and send two normal messages to it.
/// Expectation: the server received the two messages sent from the client.
TEST_F(TCPTest, sendTwoMessages) {
// Reset the data message counter before counting.
Init();
// Start the tcp server.
auto server_url = "127.0.0.1:8081";
std::unique_ptr<TCPServer> server = std::make_unique<TCPServer>();
bool ret = server->Initialize(server_url);
ASSERT_TRUE(ret);
// Every message handled by the server increments the data message counter by one.
server->SetMessageHandler([](std::unique_ptr<MessageBase> &&message) -> void { IncrDataMsgNum(1); });
// Start the tcp client.
auto client_url = "127.0.0.1:1234";
std::unique_ptr<TCPClient> client = std::make_unique<TCPClient>();
ret = client->Initialize();
ASSERT_TRUE(ret);
// Create two messages addressed from the client url to the server url.
auto message1 = CreateMessage(server_url, client_url);
auto message2 = CreateMessage(server_url, client_url);
// Connect to the server and send both messages back to back.
client->Connect(server_url);
client->Send(std::move(message1));
client->Send(std::move(message2));
// Wait until two messages were counted or the 5s timeout expires; the result is verified below.
WaitForDataMsg(2, 5);
// Check result: exactly two messages must have been counted.
EXPECT_EQ(2, GetDataMsgNum());
// Destroy the client connection, the client and the server.
client->Disconnect(server_url);
client->Finalize();
server->Finalize();
}
} // namespace rpc
} // namespace distributed