forked from mindspore-Ecosystem/mindspore
!29786 Add TCP client and TCP server for RPC
Merge pull request !29786 from chengang/add_tcp_server
This commit is contained in:
commit
7b6f7c9c87
|
@ -25,8 +25,6 @@
|
|||
namespace mindspore {
|
||||
namespace distributed {
|
||||
namespace rpc {
|
||||
std::mutex Connection::conn_mutex;
|
||||
|
||||
// Handle socket events like read/write.
|
||||
void SocketEventHandler(int fd, uint32_t events, void *context) {
|
||||
Connection *conn = reinterpret_cast<Connection *>(context);
|
||||
|
@ -245,7 +243,7 @@ void Connection::Close() {
|
|||
}
|
||||
}
|
||||
|
||||
int Connection::ReceiveMessage(IOMgr::MessageHandler msgHandler) {
|
||||
int Connection::ReceiveMessage() {
|
||||
bool ok = ParseMessage();
|
||||
// If no message parsed, wait for next read
|
||||
if (!ok) {
|
||||
|
@ -255,26 +253,12 @@ int Connection::ReceiveMessage(IOMgr::MessageHandler msgHandler) {
|
|||
return 0;
|
||||
}
|
||||
|
||||
if (destination.empty()) {
|
||||
std::string fromUrl = recv_message->from;
|
||||
size_t index = fromUrl.find("@");
|
||||
if (index != std::string::npos) {
|
||||
destination = fromUrl.substr(index + 1);
|
||||
MS_LOG(INFO) << "Create new connection fd: " << socket_fd << " to: " << destination.c_str();
|
||||
|
||||
Connection::conn_mutex.lock();
|
||||
ConnectionPool::GetConnectionPool()->SetConnPriority(destination, false, ConnectionPriority::kPriorityLow);
|
||||
state = ConnectionState::kConnected;
|
||||
ConnectionPool::GetConnectionPool()->AddConnection(this);
|
||||
Connection::conn_mutex.unlock();
|
||||
}
|
||||
}
|
||||
std::unique_ptr<MessageBase> msg(recv_message);
|
||||
recv_message = nullptr;
|
||||
|
||||
// Call msg handler if set
|
||||
if (msgHandler != nullptr) {
|
||||
msgHandler(std::move(msg));
|
||||
if (message_handler != nullptr) {
|
||||
message_handler(std::move(msg));
|
||||
} else {
|
||||
MS_LOG(INFO) << "Message handler was not found";
|
||||
}
|
||||
|
@ -282,7 +266,6 @@ int Connection::ReceiveMessage(IOMgr::MessageHandler msgHandler) {
|
|||
}
|
||||
|
||||
void Connection::CheckMessageType() {
|
||||
std::lock_guard<std::mutex> lock(Connection::conn_mutex);
|
||||
if (recv_message_type != ParseType::kUnknown) {
|
||||
return;
|
||||
}
|
||||
|
|
|
@ -20,9 +20,9 @@
|
|||
#include <queue>
|
||||
#include <string>
|
||||
#include <mutex>
|
||||
#include <memory>
|
||||
|
||||
#include "actor/msg.h"
|
||||
#include "actor/iomgr.h"
|
||||
#include "distributed/rpc/tcp/constants.h"
|
||||
#include "distributed/rpc/tcp/event_loop.h"
|
||||
#include "distributed/rpc/tcp/socket_operation.h"
|
||||
|
@ -118,7 +118,7 @@ struct Connection {
|
|||
// Close this connection.
|
||||
void Close();
|
||||
|
||||
int ReceiveMessage(IOMgr::MessageHandler msgHandler);
|
||||
int ReceiveMessage();
|
||||
void CheckMessageType();
|
||||
|
||||
// Fill the message to be sent based on the input message.
|
||||
|
@ -170,6 +170,9 @@ struct Connection {
|
|||
MessageBase *send_message;
|
||||
MessageBase *recv_message;
|
||||
|
||||
// Owned by the tcp_comm.
|
||||
std::shared_ptr<std::mutex> conn_mutex;
|
||||
|
||||
State recv_state;
|
||||
|
||||
// Total length of received and sent messages.
|
||||
|
@ -201,6 +204,9 @@ struct Connection {
|
|||
ConnectionCallBack write_callback;
|
||||
ConnectionCallBack read_callback;
|
||||
|
||||
// Function for handling received messages.
|
||||
MessageHandler message_handler;
|
||||
|
||||
// Buffer for messages to be sent.
|
||||
std::queue<MessageBase *> send_message_queue;
|
||||
|
||||
|
@ -209,8 +215,6 @@ struct Connection {
|
|||
// The error code when sending or receiving messages.
|
||||
int error_code;
|
||||
|
||||
static std::mutex conn_mutex;
|
||||
|
||||
private:
|
||||
// Add handler for socket connect event.
|
||||
int AddConnnectEventHandler();
|
||||
|
|
|
@ -20,8 +20,6 @@
|
|||
namespace mindspore {
|
||||
namespace distributed {
|
||||
namespace rpc {
|
||||
ConnectionPool *ConnectionPool::conn_pool = new ConnectionPool();
|
||||
|
||||
void ConnectionPool::SetLinkPattern(bool linkPattern) { double_link_ = linkPattern; }
|
||||
|
||||
void ConnectionPool::CloseConnection(Connection *conn) {
|
||||
|
@ -36,57 +34,22 @@ void ConnectionPool::CloseConnection(Connection *conn) {
|
|||
}
|
||||
|
||||
if (!conn->destination.empty()) {
|
||||
if (conn->is_remote) {
|
||||
(void)remote_conns_.erase(conn->destination);
|
||||
} else {
|
||||
(void)local_conns_.erase(conn->destination);
|
||||
}
|
||||
(void)connections_.erase(conn->destination);
|
||||
}
|
||||
conn->Close();
|
||||
delete conn;
|
||||
conn = nullptr;
|
||||
}
|
||||
|
||||
Connection *ConnectionPool::FindConnection(const std::string &to, bool remoteLink) {
|
||||
Connection *ConnectionPool::FindConnection(const std::string &dst_url) {
|
||||
Connection *conn = nullptr;
|
||||
if (!remoteLink) {
|
||||
auto iter = local_conns_.find(to);
|
||||
if (iter != local_conns_.end()) {
|
||||
conn = iter->second;
|
||||
return conn;
|
||||
}
|
||||
}
|
||||
auto iter = remote_conns_.find(to);
|
||||
if (iter != remote_conns_.end()) {
|
||||
auto iter = connections_.find(dst_url);
|
||||
if (iter != connections_.end()) {
|
||||
conn = iter->second;
|
||||
}
|
||||
return conn;
|
||||
}
|
||||
|
||||
Connection *ConnectionPool::ExactFindConnection(const std::string &to, bool remoteLink) {
|
||||
Connection *conn = nullptr;
|
||||
if (!remoteLink) {
|
||||
auto iter = local_conns_.find(to);
|
||||
if (iter != local_conns_.end()) {
|
||||
conn = iter->second;
|
||||
}
|
||||
} else {
|
||||
auto iter = remote_conns_.find(to);
|
||||
if (iter != remote_conns_.end()) {
|
||||
conn = iter->second;
|
||||
}
|
||||
}
|
||||
return conn;
|
||||
}
|
||||
|
||||
Connection *ConnectionPool::FindConnection(const std::string &to, bool remoteLink, bool exactNotRemote) {
|
||||
if (exactNotRemote) {
|
||||
return ExactFindConnection(to, false);
|
||||
} else {
|
||||
return FindConnection(to, remoteLink);
|
||||
}
|
||||
}
|
||||
|
||||
void ConnectionPool::ResetAllConnMetrics() {
|
||||
for (const auto &iter : local_conns_) {
|
||||
iter.second->send_metrics->Reset();
|
||||
|
@ -96,46 +59,10 @@ void ConnectionPool::ResetAllConnMetrics() {
|
|||
}
|
||||
}
|
||||
|
||||
Connection *ConnectionPool::FindMaxConnection() {
|
||||
Connection *conn = nullptr;
|
||||
size_t count = 0;
|
||||
for (const auto &iter : local_conns_) {
|
||||
if (iter.second->send_metrics->accum_msg_count > count) {
|
||||
count = iter.second->send_metrics->accum_msg_count;
|
||||
conn = iter.second;
|
||||
}
|
||||
}
|
||||
for (const auto &iter : remote_conns_) {
|
||||
if (iter.second->send_metrics->accum_msg_count > count) {
|
||||
count = iter.second->send_metrics->accum_msg_count;
|
||||
conn = iter.second;
|
||||
}
|
||||
}
|
||||
return conn;
|
||||
}
|
||||
|
||||
Connection *ConnectionPool::FindFastConnection() {
|
||||
Connection *conn = nullptr;
|
||||
size_t size = 0;
|
||||
for (const auto &iter : local_conns_) {
|
||||
if (iter.second->send_metrics->max_msg_size > size) {
|
||||
size = iter.second->send_metrics->max_msg_size;
|
||||
conn = iter.second;
|
||||
}
|
||||
}
|
||||
for (const auto &iter : remote_conns_) {
|
||||
if (iter.second->send_metrics->max_msg_size > size) {
|
||||
size = iter.second->send_metrics->max_msg_size;
|
||||
conn = iter.second;
|
||||
}
|
||||
}
|
||||
return conn;
|
||||
}
|
||||
|
||||
void ConnectionPool::ExactDeleteConnection(const std::string &to, bool remoteLink) {
|
||||
Connection *conn = ExactFindConnection(to, remoteLink);
|
||||
void ConnectionPool::DeleteConnection(const std::string &dst_url) {
|
||||
Connection *conn = FindConnection(dst_url);
|
||||
if (conn != nullptr) {
|
||||
MS_LOG(INFO) << "unLink fd:" << conn->socket_fd << ",to:" << to.c_str() << ",remote:" << remoteLink;
|
||||
MS_LOG(INFO) << "unLink fd:" << conn->socket_fd << ",to:" << dst_url;
|
||||
CloseConnection(conn);
|
||||
}
|
||||
}
|
||||
|
@ -156,26 +83,15 @@ void ConnectionPool::DeleteAllConnections(std::map<std::string, Connection *> *l
|
|||
|
||||
void ConnectionPool::AddConnection(Connection *conn) {
|
||||
if (conn == nullptr) {
|
||||
MS_LOG(ERROR) << "The connection is null";
|
||||
return;
|
||||
}
|
||||
Connection *tmpConn = ExactFindConnection(conn->destination, conn->is_remote);
|
||||
if (tmpConn != nullptr && tmpConn->is_remote == conn->is_remote) {
|
||||
Connection *tmpConn = FindConnection(conn->destination);
|
||||
if (tmpConn != nullptr) {
|
||||
MS_LOG(INFO) << "unLink fd:" << tmpConn->socket_fd << ",to:" << tmpConn->destination.c_str();
|
||||
CloseConnection(tmpConn);
|
||||
}
|
||||
|
||||
if (conn->is_remote) {
|
||||
(void)remote_conns_.emplace(conn->destination, conn);
|
||||
} else {
|
||||
(void)local_conns_.emplace(conn->destination, conn);
|
||||
}
|
||||
}
|
||||
|
||||
void ConnectionPool::SetConnPriority(const std::string &to, bool remoteLink, ConnectionPriority pri) {
|
||||
Connection *conn = ExactFindConnection(to, remoteLink);
|
||||
if (conn != nullptr && conn->is_remote == remoteLink) {
|
||||
conn->priority = pri;
|
||||
}
|
||||
(void)connections_.emplace(conn->destination, conn);
|
||||
}
|
||||
|
||||
void ConnectionPool::DeleteConnInfo(int fd) {
|
||||
|
@ -207,23 +123,13 @@ void ConnectionPool::DeleteConnInfo(const std::string &to, int fd) {
|
|||
// If run in single link pattern, link fd and send fd may not be the same, we should send Exit message bind
|
||||
// on link fd and remote link fd. Here 'deleted' flag should be set true to avoid duplicate Exit message with
|
||||
// same aid.
|
||||
Connection *nonRemoteConn = ConnectionPool::ExactFindConnection(to, false);
|
||||
if (nonRemoteConn != nullptr) {
|
||||
nonRemoteConn->deleted = true;
|
||||
DeleteConnInfo(nonRemoteConn->socket_fd);
|
||||
Connection *conn = FindConnection(to);
|
||||
if (conn != nullptr) {
|
||||
conn->deleted = true;
|
||||
DeleteConnInfo(conn->socket_fd);
|
||||
|
||||
if (nonRemoteConn->socket_fd != fd) {
|
||||
MS_LOG(INFO) << "delete linker bind on link fd:" << nonRemoteConn->socket_fd << ",delete fd:" << fd;
|
||||
}
|
||||
}
|
||||
|
||||
Connection *remoteConn = ConnectionPool::ExactFindConnection(to, true);
|
||||
if (remoteConn != nullptr) {
|
||||
remoteConn->deleted = true;
|
||||
DeleteConnInfo(remoteConn->socket_fd);
|
||||
|
||||
if (remoteConn->socket_fd != fd) {
|
||||
MS_LOG(INFO) << "delete linker bind on remote link fd:" << remoteConn->socket_fd << ",delete fd:" << fd;
|
||||
if (conn->socket_fd != fd) {
|
||||
MS_LOG(INFO) << "delete linker bind on link fd:" << conn->socket_fd << ",delete fd:" << fd;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -243,7 +149,7 @@ void ConnectionPool::DeleteAllConnInfos() {
|
|||
}
|
||||
}
|
||||
|
||||
ConnectionInfo *ConnectionPool::FindConnInfo(int fd, const AID &sAid, const AID &dAid) {
|
||||
ConnectionInfo *ConnectionPool::FindConnInfo(int fd, const std::string &dst_url) {
|
||||
auto iter = conn_infos_.find(fd);
|
||||
if (iter == conn_infos_.end()) {
|
||||
return nullptr;
|
||||
|
@ -253,7 +159,7 @@ ConnectionInfo *ConnectionPool::FindConnInfo(int fd, const AID &sAid, const AID
|
|||
|
||||
while (iter2 != conn_infos.end()) {
|
||||
auto linkInfo = *iter2;
|
||||
if (AID(linkInfo->from) == sAid && AID(linkInfo->to) == dAid) {
|
||||
if (linkInfo->to == dst_url) {
|
||||
return linkInfo;
|
||||
}
|
||||
++iter2;
|
||||
|
@ -261,19 +167,18 @@ ConnectionInfo *ConnectionPool::FindConnInfo(int fd, const AID &sAid, const AID
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
void ConnectionPool::AddConnInfo(int fd, const AID &sAid, const AID &dAid, DeleteCallBack callback) {
|
||||
ConnectionInfo *linker = FindConnInfo(fd, sAid, dAid);
|
||||
void ConnectionPool::AddConnInfo(int fd, const std::string &dst_url, DeleteCallBack callback) {
|
||||
ConnectionInfo *linker = FindConnInfo(fd, dst_url);
|
||||
if (linker != nullptr) {
|
||||
return;
|
||||
}
|
||||
linker = new (std::nothrow) ConnectionInfo();
|
||||
if (linker == nullptr) {
|
||||
MS_LOG(ERROR) << "new ConnectionInfo fail sAid:" << std::string(sAid).c_str()
|
||||
<< ",dAid:" << std::string(dAid).c_str();
|
||||
MS_LOG(ERROR) << "new ConnectionInfo fail dAid:" << dst_url;
|
||||
return;
|
||||
}
|
||||
linker->from = sAid;
|
||||
linker->to = dAid;
|
||||
linker->from = "";
|
||||
linker->to = dst_url;
|
||||
linker->socket_fd = fd;
|
||||
linker->delete_callback = callback;
|
||||
(void)conn_infos_[fd].insert(linker);
|
||||
|
@ -299,8 +204,6 @@ ConnectionPool::~ConnectionPool() {
|
|||
MS_LOG(ERROR) << "Failed to release resource for connection pool.";
|
||||
}
|
||||
}
|
||||
|
||||
ConnectionPool *ConnectionPool::GetConnectionPool() { return ConnectionPool::conn_pool; }
|
||||
} // namespace rpc
|
||||
} // namespace distributed
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -42,13 +42,10 @@ class ConnectionPool {
|
|||
ConnectionPool() : double_link_(false) {}
|
||||
~ConnectionPool();
|
||||
|
||||
// Get the singleton instance of ConnectionPool.
|
||||
static ConnectionPool *GetConnectionPool();
|
||||
|
||||
/*
|
||||
* Operations for ConnectionInfo.
|
||||
*/
|
||||
void AddConnInfo(int socket_fd, const AID &sAid, const AID &dAid, DeleteCallBack delcb);
|
||||
void AddConnInfo(int socket_fd, const std::string &dst_url, DeleteCallBack delcb);
|
||||
bool ReverseConnInfo(int from_socket_fd, int to_socket_fd);
|
||||
|
||||
/*
|
||||
|
@ -58,19 +55,14 @@ class ConnectionPool {
|
|||
void AddConnection(Connection *conn);
|
||||
|
||||
// Find connection.
|
||||
Connection *ExactFindConnection(const std::string &to, bool remoteLink);
|
||||
Connection *FindConnection(const std::string &to, bool remoteLink);
|
||||
Connection *FindConnection(const std::string &to, bool remoteLink, bool exactNotRemote);
|
||||
Connection *FindMaxConnection();
|
||||
Connection *FindFastConnection();
|
||||
Connection *FindConnection(const std::string &dst_url);
|
||||
|
||||
// Delete connection.
|
||||
void ExactDeleteConnection(const std::string &to, bool remoteLink);
|
||||
void DeleteConnection(const std::string &dst_url);
|
||||
void DeleteAllConnections(std::map<std::string, Connection *> *alllinks);
|
||||
|
||||
// Close connection.
|
||||
void CloseConnection(Connection *conn);
|
||||
void SetConnPriority(const std::string &to, bool remoteLink, ConnectionPriority pri);
|
||||
|
||||
// Single link or double link.
|
||||
void SetLinkPattern(bool linkPattern);
|
||||
|
@ -78,7 +70,7 @@ class ConnectionPool {
|
|||
void ResetAllConnMetrics();
|
||||
|
||||
private:
|
||||
ConnectionInfo *FindConnInfo(int socket_fd, const AID &sAid, const AID &dAid);
|
||||
ConnectionInfo *FindConnInfo(int socket_fd, const std::string &dst_url);
|
||||
|
||||
void DeleteConnInfo(int socket_fd);
|
||||
void DeleteConnInfo(const std::string &to, int socket_fd);
|
||||
|
@ -92,11 +84,12 @@ class ConnectionPool {
|
|||
// Maintains the remote connections by remote server addresses.
|
||||
std::map<std::string, Connection *> remote_conns_;
|
||||
|
||||
// Maintains the connections by remote server addresses.
|
||||
std::map<std::string, Connection *> connections_;
|
||||
|
||||
// each to_url has two fds at most, and each fd has multiple linkinfos
|
||||
std::map<int, std::set<ConnectionInfo *>> conn_infos_;
|
||||
|
||||
static ConnectionPool *conn_pool;
|
||||
|
||||
friend class Connection;
|
||||
friend class TCPComm;
|
||||
};
|
||||
|
|
|
@ -20,12 +20,15 @@
|
|||
#include <string>
|
||||
#include <csignal>
|
||||
#include <queue>
|
||||
#include <memory>
|
||||
|
||||
#include "actor/log.h"
|
||||
#include "actor/msg.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace distributed {
|
||||
namespace rpc {
|
||||
using MessageHandler = void (*)(std::unique_ptr<MessageBase> &&msg);
|
||||
using DeleteCallBack = void (*)(const std::string &from, const std::string &to);
|
||||
using ConnectionCallBack = void (*)(void *conn);
|
||||
|
||||
|
|
|
@ -66,7 +66,7 @@ class EventLoop {
|
|||
bool Initialize(const std::string &threadName);
|
||||
void Finalize();
|
||||
|
||||
// Add task (eg. send message, reconnect etc.) to task queue of the event loop by user.
|
||||
// Add task (eg. send message, reconnect etc.) to task queue of the event loop.
|
||||
// These tasks are executed asynchronously.
|
||||
int AddTask(std::function<void()> &&task);
|
||||
|
||||
|
|
|
@ -90,7 +90,7 @@ int SocketOperation::SetSocketOptions(int sock_fd) {
|
|||
return 0;
|
||||
}
|
||||
|
||||
int SocketOperation::CreateServerSocket(sa_family_t family) {
|
||||
int SocketOperation::CreateSocket(sa_family_t family) {
|
||||
int ret = 0;
|
||||
int fd = 0;
|
||||
|
||||
|
@ -292,7 +292,7 @@ int SocketOperation::Listen(const std::string &url) {
|
|||
}
|
||||
|
||||
// create server socket
|
||||
listenFd = CreateServerSocket(addr.sa.sa_family);
|
||||
listenFd = CreateSocket(addr.sa.sa_family);
|
||||
if (listenFd < 0) {
|
||||
MS_LOG(ERROR) << "Failed to create socket, url: " << url.c_str();
|
||||
return -1;
|
||||
|
|
|
@ -48,8 +48,8 @@ class SocketOperation {
|
|||
// Get socket address of the url.
|
||||
static bool GetSockAddr(const std::string &url, SocketAddress *addr);
|
||||
|
||||
// Create a server socket.
|
||||
static int CreateServerSocket(sa_family_t family);
|
||||
// Create a socket.
|
||||
static int CreateSocket(sa_family_t family);
|
||||
|
||||
// Set socket options.
|
||||
static int SetSocketOptions(int sock_fd);
|
||||
|
|
|
@ -0,0 +1,79 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "distributed/rpc/tcp/tcp_client.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace distributed {
|
||||
namespace rpc {
|
||||
bool TCPClient::Initialize() {
|
||||
bool rt = false;
|
||||
if (tcp_comm_ == nullptr) {
|
||||
tcp_comm_ = std::make_unique<TCPComm>();
|
||||
MS_EXCEPTION_IF_NULL(tcp_comm_);
|
||||
rt = tcp_comm_->Initialize();
|
||||
} else {
|
||||
rt = true;
|
||||
}
|
||||
return rt;
|
||||
}
|
||||
|
||||
// Finalize the underlying TCP communication component.
// This is a no-op when the component was never created (Initialize() not called).
void TCPClient::Finalize() {
  if (tcp_comm_ == nullptr) {
    return;
  }
  tcp_comm_->Finalize();
}
|
||||
|
||||
// Start connecting to the server at `dst_url` and wait until the connection is established.
//
// dst_url:        url of the destination server.
// timeout_in_sec: upper bound, in seconds, of the time spent waiting for the connection.
//
// Returns true if the connection is reported as established within the timeout, false
// otherwise (including when the client has not been initialized).
bool TCPClient::Connect(const std::string &dst_url, size_t timeout_in_sec) {
  if (tcp_comm_ == nullptr) {
    // Initialize() has not been called (or failed): there is nothing to connect with.
    return false;
  }
  tcp_comm_->Connect(dst_url);

  // The connection state is polled every 100ms. The number of polls is counted in size_t
  // rather than counting microseconds in an int, so large timeouts can not overflow and
  // the loop always terminates.
  constexpr size_t kPollIntervalInUs = 100000;
  constexpr size_t kPollsPerSecond = 1000000 / kPollIntervalInUs;
  const size_t max_polls = timeout_in_sec * kPollsPerSecond;

  for (size_t poll = 0; poll < max_polls; ++poll) {
    if (tcp_comm_->IsConnected(dst_url)) {
      return true;
    }
    usleep(kPollIntervalInUs);
  }
  // Check once more after the last sleep (this is also the only check for a zero timeout).
  return tcp_comm_->IsConnected(dst_url);
}
|
||||
|
||||
// Start disconnecting from the server at `dst_url` and wait until the connection is gone.
//
// dst_url:        url of the destination server.
// timeout_in_sec: upper bound, in seconds, of the time spent waiting for the disconnection.
//
// Returns true if the connection is reported as closed within the timeout, false
// otherwise (including when the client has not been initialized).
bool TCPClient::Disconnect(const std::string &dst_url, size_t timeout_in_sec) {
  if (tcp_comm_ == nullptr) {
    // Initialize() has not been called (or failed): the request can not be carried out.
    return false;
  }
  tcp_comm_->Disconnect(dst_url);

  // The connection state is polled every 100ms. The number of polls is counted in size_t
  // rather than counting microseconds in an int, so large timeouts can not overflow and
  // the loop always terminates.
  constexpr size_t kPollIntervalInUs = 100000;
  constexpr size_t kPollsPerSecond = 1000000 / kPollIntervalInUs;
  const size_t max_polls = timeout_in_sec * kPollsPerSecond;

  for (size_t poll = 0; poll < max_polls; ++poll) {
    if (!tcp_comm_->IsConnected(dst_url)) {
      return true;
    }
    usleep(kPollIntervalInUs);
  }
  // Check once more after the last sleep (this is also the only check for a zero timeout).
  return !tcp_comm_->IsConnected(dst_url);
}
|
||||
|
||||
int TCPClient::Send(std::unique_ptr<MessageBase> &&msg) {
|
||||
int rt = -1;
|
||||
rt = tcp_comm_->Send(msg.release());
|
||||
return rt;
|
||||
}
|
||||
} // namespace rpc
|
||||
} // namespace distributed
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,57 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_DISTRIBUTED_RPC_TCP_TCP_CLIENT_H_
|
||||
#define MINDSPORE_CCSRC_DISTRIBUTED_RPC_TCP_TCP_CLIENT_H_
|
||||
|
||||
#include <string>
|
||||
#include <memory>
|
||||
|
||||
#include "distributed/rpc/tcp/tcp_comm.h"
|
||||
#include "utils/ms_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace distributed {
|
||||
namespace rpc {
|
||||
class TCPClient {
|
||||
public:
|
||||
TCPClient() = default;
|
||||
~TCPClient() = default;
|
||||
|
||||
// Build or destroy the TCP client.
|
||||
bool Initialize();
|
||||
void Finalize();
|
||||
|
||||
// Connect to the specified server.
|
||||
bool Connect(const std::string &dst_url, size_t timeout_in_sec = 5);
|
||||
|
||||
// Disconnect from the specified server.
|
||||
bool Disconnect(const std::string &dst_url, size_t timeout_in_sec = 5);
|
||||
|
||||
// Send the message from the source to the destination.
|
||||
int Send(std::unique_ptr<MessageBase> &&msg);
|
||||
|
||||
private:
|
||||
// The basic TCP communication component used by the client.
|
||||
std::unique_ptr<TCPComm> tcp_comm_;
|
||||
|
||||
DISABLE_COPY_AND_ASSIGN(TCPClient);
|
||||
};
|
||||
} // namespace rpc
|
||||
} // namespace distributed
|
||||
} // namespace mindspore
|
||||
|
||||
#endif
|
|
@ -21,47 +21,12 @@
|
|||
#include <memory>
|
||||
|
||||
#include "actor/aid.h"
|
||||
#include "actor/msg.h"
|
||||
#include "distributed/rpc/tcp/constants.h"
|
||||
#include "distributed/rpc/tcp/tcp_socket_operation.h"
|
||||
#include "distributed/rpc/tcp/connection_pool.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace distributed {
|
||||
namespace rpc {
|
||||
bool TCPComm::is_http_msg_ = false;
|
||||
std::vector<char> TCPComm::advertise_url_;
|
||||
uint64_t TCPComm::output_buf_size_ = 0;
|
||||
|
||||
IOMgr::MessageHandler TCPComm::message_handler_;
|
||||
|
||||
int DoConnect(const std::string &to, Connection *conn, ConnectionCallBack event_callback,
|
||||
ConnectionCallBack write_callback, ConnectionCallBack read_callback) {
|
||||
SocketAddress addr;
|
||||
if (!SocketOperation::GetSockAddr(to, &addr)) {
|
||||
return -1;
|
||||
}
|
||||
int sock_fd = SocketOperation::CreateServerSocket(addr.sa.sa_family);
|
||||
if (sock_fd < 0) {
|
||||
return -1;
|
||||
}
|
||||
|
||||
conn->socket_fd = sock_fd;
|
||||
conn->event_callback = event_callback;
|
||||
conn->write_callback = write_callback;
|
||||
conn->read_callback = read_callback;
|
||||
|
||||
int ret = TCPComm::Connect(conn, (struct sockaddr *)&addr, sizeof(addr));
|
||||
if (ret < 0) {
|
||||
if (close(sock_fd) != 0) {
|
||||
MS_LOG(ERROR) << "Failed to close fd:" << sock_fd;
|
||||
}
|
||||
conn->socket_fd = -1;
|
||||
return -1;
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
void DoDisconnect(int fd, Connection *conn, uint32_t error, int soError) {
|
||||
if (LOG_CHECK_EVERY_N()) {
|
||||
MS_LOG(INFO) << "Failed to call connect, fd: " << fd << ", to: " << conn->destination.c_str()
|
||||
|
@ -137,13 +102,16 @@ void OnAccept(int server, uint32_t events, void *arg) {
|
|||
}
|
||||
|
||||
conn->socket_fd = acceptFd;
|
||||
conn->source = TCPComm::advertise_url_.data();
|
||||
conn->source = tcpmgr->url_;
|
||||
conn->peer = SocketOperation::GetPeer(acceptFd);
|
||||
|
||||
conn->is_remote = true;
|
||||
conn->recv_event_loop = tcpmgr->recv_event_loop_;
|
||||
conn->send_event_loop = tcpmgr->send_event_loop_;
|
||||
|
||||
conn->conn_mutex = tcpmgr->conn_mutex_;
|
||||
conn->message_handler = tcpmgr->message_handler_;
|
||||
|
||||
conn->event_callback = TCPComm::EventCallBack;
|
||||
conn->write_callback = TCPComm::WriteCallBack;
|
||||
conn->read_callback = TCPComm::ReadCallBack;
|
||||
|
@ -165,7 +133,7 @@ void OnAccept(int server, uint32_t events, void *arg) {
|
|||
void DoSend(Connection *conn) {
|
||||
while (!conn->send_message_queue.empty() || conn->total_send_len != 0) {
|
||||
if (conn->total_send_len == 0) {
|
||||
conn->FillSendMessage(conn->send_message_queue.front(), TCPComm::advertise_url_.data(), TCPComm::IsHttpMsg());
|
||||
conn->FillSendMessage(conn->send_message_queue.front(), conn->source, false);
|
||||
conn->send_message_queue.pop();
|
||||
}
|
||||
|
||||
|
@ -175,7 +143,6 @@ void DoSend(Connection *conn) {
|
|||
// update metrics
|
||||
conn->send_metrics->UpdateError(false);
|
||||
|
||||
TCPComm::output_buf_size_ -= conn->send_message->body.size();
|
||||
conn->output_buffer_size -= conn->send_message->body.size();
|
||||
delete conn->send_message;
|
||||
conn->send_message = nullptr;
|
||||
|
@ -201,25 +168,15 @@ TCPComm::~TCPComm() {
|
|||
}
|
||||
}
|
||||
|
||||
void TCPComm::SendExitMsg(const std::string &from, const std::string &to) {
|
||||
if (message_handler_ != nullptr) {
|
||||
std::unique_ptr<MessageBase> exit_msg = std::make_unique<MessageBase>(MessageBase::Type::KEXIT);
|
||||
MS_EXCEPTION_IF_NULL(exit_msg);
|
||||
|
||||
exit_msg->SetFrom(AID(from));
|
||||
exit_msg->SetTo(AID(to));
|
||||
|
||||
message_handler_(std::move(exit_msg));
|
||||
}
|
||||
}
|
||||
|
||||
void TCPComm::SetMessageHandler(IOMgr::MessageHandler handler) { message_handler_ = handler; }
|
||||
void TCPComm::SetMessageHandler(MessageHandler handler) { message_handler_ = handler; }
|
||||
|
||||
bool TCPComm::Initialize() {
|
||||
if (ConnectionPool::GetConnectionPool() == nullptr) {
|
||||
MS_LOG(ERROR) << "Failed to create connection pool.";
|
||||
return false;
|
||||
}
|
||||
conn_pool_ = std::make_shared<ConnectionPool>();
|
||||
MS_EXCEPTION_IF_NULL(conn_pool_);
|
||||
|
||||
conn_mutex_ = std::make_shared<std::mutex>();
|
||||
MS_EXCEPTION_IF_NULL(conn_mutex_);
|
||||
|
||||
recv_event_loop_ = new (std::nothrow) EventLoop();
|
||||
if (recv_event_loop_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Failed to create recv evLoop.";
|
||||
|
@ -251,62 +208,35 @@ bool TCPComm::Initialize() {
|
|||
return false;
|
||||
}
|
||||
|
||||
if (g_httpKmsgEnable >= 0) {
|
||||
is_http_msg_ = (g_httpKmsgEnable == 0) ? false : true;
|
||||
}
|
||||
|
||||
ConnectionPool::GetConnectionPool()->SetLinkPattern(is_http_msg_);
|
||||
return true;
|
||||
}
|
||||
|
||||
bool TCPComm::StartServerSocket(const std::string &url, const std::string &aAdvertiseUrl) {
|
||||
bool TCPComm::StartServerSocket(const std::string &url) {
|
||||
server_fd_ = SocketOperation::Listen(url);
|
||||
if (server_fd_ < 0) {
|
||||
MS_LOG(ERROR) << "Failed to call socket listen, url: " << url.c_str()
|
||||
<< ", advertise_url_: " << advertise_url_.data();
|
||||
MS_LOG(ERROR) << "Failed to call socket listen, url: " << url.c_str();
|
||||
return false;
|
||||
}
|
||||
url_ = url;
|
||||
std::string tmp_url;
|
||||
|
||||
if (aAdvertiseUrl.size() > 0) {
|
||||
advertise_url_.resize(aAdvertiseUrl.size());
|
||||
advertise_url_.assign(aAdvertiseUrl.begin(), aAdvertiseUrl.end());
|
||||
tmp_url = aAdvertiseUrl;
|
||||
} else {
|
||||
advertise_url_.resize(url_.size());
|
||||
advertise_url_.assign(url_.begin(), url_.end());
|
||||
tmp_url = url_;
|
||||
}
|
||||
|
||||
size_t index = url.find(URL_PROTOCOL_IP_SEPARATOR);
|
||||
if (index != std::string::npos) {
|
||||
url_ = url.substr(index + sizeof(URL_PROTOCOL_IP_SEPARATOR) - 1);
|
||||
}
|
||||
|
||||
index = tmp_url.find(URL_PROTOCOL_IP_SEPARATOR);
|
||||
if (index != std::string::npos) {
|
||||
tmp_url = tmp_url.substr(index + sizeof(URL_PROTOCOL_IP_SEPARATOR) - 1);
|
||||
advertise_url_.resize(tmp_url.size());
|
||||
advertise_url_.assign(tmp_url.begin(), tmp_url.end());
|
||||
}
|
||||
|
||||
// Register read event callback for server socket
|
||||
int retval = recv_event_loop_->SetEventHandler(server_fd_, EPOLLIN | EPOLLHUP | EPOLLERR, OnAccept,
|
||||
reinterpret_cast<void *>(this));
|
||||
if (retval != RPC_OK) {
|
||||
MS_LOG(ERROR) << "Failed to add server event, url: " << url.c_str()
|
||||
<< ", advertise_url_: " << advertise_url_.data();
|
||||
MS_LOG(ERROR) << "Failed to add server event, url: " << url.c_str();
|
||||
return false;
|
||||
}
|
||||
MS_LOG(INFO) << "Start server succ, fd: " << server_fd_ << ", url: " << url.c_str()
|
||||
<< ", advertise_url_ :" << advertise_url_.data();
|
||||
MS_LOG(INFO) << "Start server succ, fd: " << server_fd_ << ", url: " << url.c_str();
|
||||
return true;
|
||||
}
|
||||
|
||||
void TCPComm::ReadCallBack(void *context) {
|
||||
void TCPComm::ReadCallBack(void *connection) {
|
||||
const int max_recv_count = 3;
|
||||
Connection *conn = reinterpret_cast<Connection *>(context);
|
||||
Connection *conn = reinterpret_cast<Connection *>(connection);
|
||||
int count = 0;
|
||||
int retval = 0;
|
||||
do {
|
||||
|
@ -317,35 +247,35 @@ void TCPComm::ReadCallBack(void *context) {
|
|||
return;
|
||||
}
|
||||
|
||||
void TCPComm::EventCallBack(void *context) {
|
||||
Connection *conn = reinterpret_cast<Connection *>(context);
|
||||
void TCPComm::EventCallBack(void *connection) {
|
||||
Connection *conn = reinterpret_cast<Connection *>(connection);
|
||||
|
||||
if (conn->state == ConnectionState::kConnected) {
|
||||
Connection::conn_mutex.lock();
|
||||
conn->conn_mutex->lock();
|
||||
DoSend(conn);
|
||||
Connection::conn_mutex.unlock();
|
||||
conn->conn_mutex->unlock();
|
||||
} else if (conn->state == ConnectionState::kDisconnecting) {
|
||||
Connection::conn_mutex.lock();
|
||||
output_buf_size_ -= conn->output_buffer_size;
|
||||
ConnectionPool::GetConnectionPool()->CloseConnection(conn);
|
||||
Connection::conn_mutex.unlock();
|
||||
conn->conn_mutex->lock();
|
||||
conn->conn_mutex->unlock();
|
||||
}
|
||||
}
|
||||
|
||||
void TCPComm::WriteCallBack(void *context) {
|
||||
Connection *conn = reinterpret_cast<Connection *>(context);
|
||||
void TCPComm::WriteCallBack(void *connection) {
|
||||
Connection *conn = reinterpret_cast<Connection *>(connection);
|
||||
if (conn->state == ConnectionState::kConnected) {
|
||||
Connection::conn_mutex.lock();
|
||||
conn->conn_mutex->lock();
|
||||
DoSend(conn);
|
||||
Connection::conn_mutex.unlock();
|
||||
conn->conn_mutex->unlock();
|
||||
}
|
||||
}
|
||||
|
||||
/* static method */
|
||||
int TCPComm::ReceiveMessage(Connection *conn) {
|
||||
std::lock_guard<std::mutex> lock(*conn->conn_mutex);
|
||||
conn->CheckMessageType();
|
||||
switch (conn->recv_message_type) {
|
||||
case ParseType::kTcpMsg:
|
||||
return conn->ReceiveMessage(message_handler_);
|
||||
return conn->ReceiveMessage();
|
||||
|
||||
#ifdef HTTP_ENABLED
|
||||
case ParseType::KHTTP_REQ:
|
||||
|
@ -370,6 +300,7 @@ int TCPComm::ReceiveMessage(Connection *conn) {
|
|||
}
|
||||
}
|
||||
|
||||
/* static method */
|
||||
int TCPComm::SetConnectedHandler(Connection *conn) {
|
||||
/* add to epoll */
|
||||
return conn->recv_event_loop->SetEventHandler(conn->socket_fd,
|
||||
|
@ -377,7 +308,8 @@ int TCPComm::SetConnectedHandler(Connection *conn) {
|
|||
ConnectedEventHandler, reinterpret_cast<void *>(conn));
|
||||
}
|
||||
|
||||
int TCPComm::Connect(Connection *conn, const struct sockaddr *sa, socklen_t saLen) {
|
||||
/* static method */
|
||||
int TCPComm::DoConnect(Connection *conn, const struct sockaddr *sa, socklen_t saLen) {
|
||||
int retval = 0;
|
||||
uint16_t localPort = 0;
|
||||
|
||||
|
@ -406,195 +338,93 @@ int TCPComm::Connect(Connection *conn, const struct sockaddr *sa, socklen_t saLe
|
|||
return RPC_OK;
|
||||
}
|
||||
|
||||
void TCPComm::Send(MessageBase *msg, const TCPComm *tcpmgr, bool remoteLink, bool isExactNotRemote) {
|
||||
std::lock_guard<std::mutex> lock(Connection::conn_mutex);
|
||||
Connection *conn = ConnectionPool::GetConnectionPool()->FindConnection(msg->to.Url(), remoteLink, isExactNotRemote);
|
||||
|
||||
// Create a new connection if the connection to target of the message does not existed.
|
||||
if (conn == nullptr) {
|
||||
if (remoteLink && (!isExactNotRemote)) {
|
||||
MS_LOG(ERROR) << "Could not found remote link and send fail name: " << msg->name.c_str()
|
||||
<< ", from: " << advertise_url_.data() << ", to: " << msg->to.Url().c_str();
|
||||
delete msg;
|
||||
return;
|
||||
}
|
||||
conn = new (std::nothrow) Connection();
|
||||
if (conn == nullptr) {
|
||||
MS_LOG(ERROR) << "Failed to create new connection and send fail name: " << msg->name.c_str()
|
||||
<< ", from: " << advertise_url_.data() << ", to: " << msg->to.Url().c_str();
|
||||
delete msg;
|
||||
return;
|
||||
}
|
||||
conn->source = advertise_url_.data();
|
||||
conn->destination = msg->to.Url();
|
||||
conn->recv_event_loop = tcpmgr->recv_event_loop_;
|
||||
conn->send_event_loop = tcpmgr->send_event_loop_;
|
||||
conn->InitSocketOperation();
|
||||
|
||||
int ret = DoConnect(msg->to.Url(), conn, TCPComm::EventCallBack, TCPComm::WriteCallBack, TCPComm::ReadCallBack);
|
||||
if (ret < 0) {
|
||||
MS_LOG(ERROR) << "Failed to do connect and send fail name: " << msg->name.c_str()
|
||||
<< ", from: " << advertise_url_.data() << ", to: " << msg->to.Url().c_str();
|
||||
if (conn->socket_operation != nullptr) {
|
||||
delete conn->socket_operation;
|
||||
conn->socket_operation = nullptr;
|
||||
}
|
||||
delete conn;
|
||||
delete msg;
|
||||
return;
|
||||
}
|
||||
ConnectionPool::GetConnectionPool()->AddConnection(conn);
|
||||
}
|
||||
|
||||
if (!conn->is_remote && !isExactNotRemote && conn->priority == ConnectionPriority::kPriorityLow) {
|
||||
Connection *remoteConn = ConnectionPool::GetConnectionPool()->ExactFindConnection(msg->to.Url(), true);
|
||||
if (remoteConn != nullptr && remoteConn->state == ConnectionState::kConnected) {
|
||||
conn = remoteConn;
|
||||
}
|
||||
}
|
||||
|
||||
// Prepare the message.
|
||||
if (conn->total_send_len == 0) {
|
||||
conn->FillSendMessage(msg, advertise_url_.data(), is_http_msg_);
|
||||
} else {
|
||||
(void)conn->send_message_queue.emplace(msg);
|
||||
}
|
||||
|
||||
// Send the message.
|
||||
if (conn->state == ConnectionState::kConnected) {
|
||||
DoSend(conn);
|
||||
}
|
||||
/* static method */
|
||||
void TCPComm::DropMessage(MessageBase *msg) {
  // Frees a message that will not be sent. Deleting a null pointer is a no-op,
  // so no guard is required here.
  delete msg;
}
|
||||
|
||||
void TCPComm::SendByRecvLoop(MessageBase *msg, const TCPComm *tcpmgr, bool remoteLink, bool isExactNotRemote) {
|
||||
(void)recv_event_loop_->AddTask(
|
||||
[msg, tcpmgr, remoteLink, isExactNotRemote] { TCPComm::Send(msg, tcpmgr, remoteLink, isExactNotRemote); });
|
||||
}
|
||||
|
||||
int TCPComm::Send(MessageBase *msg, bool remoteLink, bool isExactNotRemote) {
|
||||
return send_event_loop_->AddTask([msg, this, remoteLink, isExactNotRemote] {
|
||||
std::lock_guard<std::mutex> lock(Connection::conn_mutex);
|
||||
int TCPComm::Send(MessageBase *msg) {
|
||||
return send_event_loop_->AddTask([msg, this] {
|
||||
std::lock_guard<std::mutex> lock(*conn_mutex_);
|
||||
// Search connection by the target address
|
||||
bool exactNotRemote = is_http_msg_ || isExactNotRemote;
|
||||
Connection *conn = ConnectionPool::GetConnectionPool()->FindConnection(msg->to.Url(), remoteLink, exactNotRemote);
|
||||
Connection *conn = conn_pool_->FindConnection(msg->to.Url());
|
||||
if (conn == nullptr) {
|
||||
if (remoteLink && (!exactNotRemote)) {
|
||||
MS_LOG(ERROR) << "Can not found remote link and send fail name: " << msg->name.c_str()
|
||||
<< ", from: " << advertise_url_.data() << ", to: " << msg->to.Url().c_str();
|
||||
auto *ptr = msg;
|
||||
delete ptr;
|
||||
ptr = nullptr;
|
||||
|
||||
return;
|
||||
}
|
||||
this->SendByRecvLoop(msg, this, remoteLink, exactNotRemote);
|
||||
MS_LOG(ERROR) << "Can not found remote link and send fail name: " << msg->name.c_str()
|
||||
<< ", from: " << msg->from.Url().c_str() << ", to: " << msg->to.Url().c_str();
|
||||
DropMessage(msg);
|
||||
return;
|
||||
}
|
||||
|
||||
if (conn->state != kConnected && conn->send_message_queue.size() >= SENDMSG_QUEUELEN) {
|
||||
MS_LOG(WARNING) << "The name of dropped message is: " << msg->name.c_str() << ", fd: " << conn->socket_fd
|
||||
<< ", to: " << conn->destination.c_str() << ", remote: " << conn->is_remote;
|
||||
auto *ptr = msg;
|
||||
delete ptr;
|
||||
ptr = nullptr;
|
||||
|
||||
if (conn->send_message_queue.size() >= SENDMSG_QUEUELEN) {
|
||||
MS_LOG(WARNING) << "The message queue is full(max len:" << SENDMSG_QUEUELEN
|
||||
<< ") and the name of dropped message is: " << msg->name.c_str() << ", fd: " << conn->socket_fd
|
||||
<< ", to: " << conn->destination.c_str();
|
||||
DropMessage(msg);
|
||||
return;
|
||||
}
|
||||
|
||||
if (conn->state == ConnectionState::kClose || conn->state == ConnectionState::kDisconnecting) {
|
||||
this->SendByRecvLoop(msg, this, remoteLink, exactNotRemote);
|
||||
if (conn->state != ConnectionState::kConnected) {
|
||||
MS_LOG(WARNING) << "Invalid connection state " << conn->state
|
||||
<< " and the name of dropped message is: " << msg->name.c_str() << ", fd: " << conn->socket_fd
|
||||
<< ", to: " << conn->destination.c_str();
|
||||
DropMessage(msg);
|
||||
return;
|
||||
}
|
||||
|
||||
if (!conn->is_remote && !exactNotRemote && conn->priority == ConnectionPriority::kPriorityLow) {
|
||||
Connection *remoteConn = ConnectionPool::GetConnectionPool()->ExactFindConnection(msg->to.Url(), true);
|
||||
if (remoteConn != nullptr && remoteConn->state == ConnectionState::kConnected) {
|
||||
conn = remoteConn;
|
||||
}
|
||||
}
|
||||
|
||||
output_buf_size_ += msg->body.size();
|
||||
if (conn->total_send_len == 0) {
|
||||
conn->FillSendMessage(msg, advertise_url_.data(), is_http_msg_);
|
||||
conn->FillSendMessage(msg, url_, false);
|
||||
} else {
|
||||
(void)conn->send_message_queue.emplace(msg);
|
||||
}
|
||||
|
||||
if (conn->state == ConnectionState::kConnected) {
|
||||
DoSend(conn);
|
||||
}
|
||||
DoSend(conn);
|
||||
});
|
||||
}
|
||||
|
||||
void TCPComm::CollectMetrics() {
|
||||
(void)send_event_loop_->AddTask([this] {
|
||||
Connection::conn_mutex.lock();
|
||||
Connection *maxConn = ConnectionPool::GetConnectionPool()->FindMaxConnection();
|
||||
Connection *fastConn = ConnectionPool::GetConnectionPool()->FindFastConnection();
|
||||
|
||||
if (message_handler_ != nullptr) {
|
||||
IntTypeMetrics intMetrics;
|
||||
StringTypeMetrics stringMetrics;
|
||||
|
||||
if (maxConn != nullptr) {
|
||||
intMetrics.push(maxConn->socket_fd);
|
||||
intMetrics.push(maxConn->error_code);
|
||||
intMetrics.push(maxConn->send_metrics->accum_msg_count);
|
||||
intMetrics.push(maxConn->send_metrics->max_msg_size);
|
||||
stringMetrics.push(maxConn->destination);
|
||||
stringMetrics.push(maxConn->send_metrics->last_succ_msg_name);
|
||||
stringMetrics.push(maxConn->send_metrics->last_fail_msg_name);
|
||||
}
|
||||
if (fastConn != nullptr && fastConn->IsSame(maxConn)) {
|
||||
intMetrics.push(fastConn->socket_fd);
|
||||
intMetrics.push(fastConn->error_code);
|
||||
intMetrics.push(fastConn->send_metrics->accum_msg_count);
|
||||
intMetrics.push(fastConn->send_metrics->max_msg_size);
|
||||
stringMetrics.push(fastConn->destination);
|
||||
stringMetrics.push(fastConn->send_metrics->last_succ_msg_name);
|
||||
stringMetrics.push(fastConn->send_metrics->last_fail_msg_name);
|
||||
}
|
||||
}
|
||||
|
||||
ConnectionPool::GetConnectionPool()->ResetAllConnMetrics();
|
||||
Connection::conn_mutex.unlock();
|
||||
});
|
||||
}
|
||||
|
||||
int TCPComm::Send(std::unique_ptr<MessageBase> &&msg, bool remoteLink, bool isExactNotRemote) {
|
||||
return Send(msg.release(), remoteLink, isExactNotRemote);
|
||||
}
|
||||
|
||||
void TCPComm::Link(const AID &source, const AID &destination) {
|
||||
(void)recv_event_loop_->AddTask([source, destination, this] {
|
||||
std::string to = destination.Url();
|
||||
std::lock_guard<std::mutex> lock(Connection::conn_mutex);
|
||||
void TCPComm::Connect(const std::string &dst_url) {
|
||||
(void)recv_event_loop_->AddTask([dst_url, this] {
|
||||
std::lock_guard<std::mutex> lock(*conn_mutex_);
|
||||
|
||||
// Search connection by the target address
|
||||
Connection *conn = ConnectionPool::GetConnectionPool()->FindConnection(to, false, is_http_msg_);
|
||||
Connection *conn = conn_pool_->FindConnection(dst_url);
|
||||
|
||||
if (conn == nullptr) {
|
||||
MS_LOG(INFO) << "Can not found link source: " << std::string(source).c_str()
|
||||
<< ", destination: " << std::string(destination).c_str();
|
||||
MS_LOG(INFO) << "Can not found link destination: " << dst_url;
|
||||
conn = new (std::nothrow) Connection();
|
||||
if (conn == nullptr) {
|
||||
MS_LOG(ERROR) << "Failed to create new connection and link fail source: " << std::string(source).c_str()
|
||||
<< ", destination: " << std::string(destination).c_str();
|
||||
SendExitMsg(source, destination);
|
||||
MS_LOG(ERROR) << "Failed to create new connection and link fail destination: " << dst_url;
|
||||
return;
|
||||
}
|
||||
conn->source = advertise_url_.data();
|
||||
conn->destination = to;
|
||||
conn->source = url_;
|
||||
conn->destination = dst_url;
|
||||
|
||||
conn->recv_event_loop = this->recv_event_loop_;
|
||||
conn->send_event_loop = this->send_event_loop_;
|
||||
conn->conn_mutex = conn_mutex_;
|
||||
conn->message_handler = message_handler_;
|
||||
conn->InitSocketOperation();
|
||||
|
||||
int ret = DoConnect(to, conn, TCPComm::EventCallBack, TCPComm::WriteCallBack, TCPComm::ReadCallBack);
|
||||
// Create the client socket.
|
||||
SocketAddress addr;
|
||||
if (!SocketOperation::GetSockAddr(dst_url, &addr)) {
|
||||
MS_LOG(ERROR) << "Failed to get socket address to dest url " << dst_url;
|
||||
return;
|
||||
}
|
||||
int sock_fd = SocketOperation::CreateSocket(addr.sa.sa_family);
|
||||
if (sock_fd < 0) {
|
||||
MS_LOG(ERROR) << "Failed to create client tcp socket to dest url " << dst_url;
|
||||
return;
|
||||
}
|
||||
|
||||
conn->socket_fd = sock_fd;
|
||||
conn->event_callback = TCPComm::EventCallBack;
|
||||
conn->write_callback = TCPComm::WriteCallBack;
|
||||
conn->read_callback = TCPComm::ReadCallBack;
|
||||
|
||||
int ret = TCPComm::DoConnect(conn, (struct sockaddr *)&addr, sizeof(addr));
|
||||
if (ret < 0) {
|
||||
MS_LOG(ERROR) << "Failed to do connect and link fail source: " << std::string(source).c_str()
|
||||
<< ", destination: " << std::string(destination).c_str();
|
||||
SendExitMsg(source, destination);
|
||||
MS_LOG(ERROR) << "Failed to do connect and link fail destination: " << dst_url;
|
||||
if (conn->socket_operation != nullptr) {
|
||||
delete conn->socket_operation;
|
||||
conn->socket_operation = nullptr;
|
||||
|
@ -602,76 +432,26 @@ void TCPComm::Link(const AID &source, const AID &destination) {
|
|||
delete conn;
|
||||
return;
|
||||
}
|
||||
ConnectionPool::GetConnectionPool()->AddConnection(conn);
|
||||
conn_pool_->AddConnection(conn);
|
||||
}
|
||||
ConnectionPool::GetConnectionPool()->AddConnInfo(conn->socket_fd, source, destination, SendExitMsg);
|
||||
MS_LOG(INFO) << "Link fd: " << conn->socket_fd << ", source: " << std::string(source).c_str()
|
||||
<< ", destination: " << std::string(destination).c_str() << ", remote: " << conn->is_remote;
|
||||
conn_pool_->AddConnInfo(conn->socket_fd, dst_url, nullptr);
|
||||
MS_LOG(INFO) << "Connected to destination: " << dst_url;
|
||||
});
|
||||
}
|
||||
|
||||
void TCPComm::UnLink(const AID &destination) {
|
||||
(void)recv_event_loop_->AddTask([destination] {
|
||||
std::string to = destination.Url();
|
||||
std::lock_guard<std::mutex> lock(Connection::conn_mutex);
|
||||
if (is_http_msg_) {
|
||||
// When application has set 'LITERPC_HTTPKMSG_ENABLED',it means sending-link is in links map
|
||||
// while accepting-link is differently in remoteLinks map. So we only need to delete link in exact links.
|
||||
ConnectionPool::GetConnectionPool()->ExactDeleteConnection(to, false);
|
||||
} else {
|
||||
// When application hasn't set 'LITERPC_HTTPKMSG_ENABLED',it means sending-link and accepting-link is
|
||||
// shared
|
||||
// So we need to delete link in both links map and remote-links map.
|
||||
ConnectionPool::GetConnectionPool()->ExactDeleteConnection(to, false);
|
||||
ConnectionPool::GetConnectionPool()->ExactDeleteConnection(to, true);
|
||||
}
|
||||
});
|
||||
bool TCPComm::IsConnected(const std::string &dst_url) {
|
||||
Connection *conn = conn_pool_->FindConnection(dst_url);
|
||||
if (conn != nullptr && conn->state == ConnectionState::kConnected) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
void TCPComm::DoReConnectConn(Connection *conn, std::string to, const AID &source, const AID &destination, int *oldFd) {
|
||||
if (!is_http_msg_ && !conn->is_remote) {
|
||||
Connection *remoteConn = ConnectionPool::GetConnectionPool()->ExactFindConnection(to, true);
|
||||
// We will close remote link in rare cases where sending-link and accepting link coexists
|
||||
// simultaneously.
|
||||
if (remoteConn != nullptr) {
|
||||
MS_LOG(INFO) << "Reconnect, close remote connect fd :" << remoteConn->socket_fd
|
||||
<< ", source: " << std::string(source).c_str()
|
||||
<< ", destination: " << std::string(destination).c_str() << ", remote: " << remoteConn->is_remote
|
||||
<< ", state: " << remoteConn->state;
|
||||
ConnectionPool::GetConnectionPool()->CloseConnection(remoteConn);
|
||||
}
|
||||
}
|
||||
|
||||
MS_LOG(INFO) << "Reconnect, close old connect fd: " << conn->socket_fd << ", source: " << std::string(source).c_str()
|
||||
<< ", destination: " << std::string(destination).c_str() << ", remote: " << conn->is_remote
|
||||
<< ", state: " << conn->state;
|
||||
|
||||
*oldFd = conn->socket_fd;
|
||||
|
||||
if (conn->recv_event_loop->DeleteEpollEvent(conn->socket_fd) == RPC_ERROR) {
|
||||
MS_LOG(ERROR) << "Failed to delete epoll event: " << conn->socket_fd;
|
||||
}
|
||||
conn->socket_operation->Close(conn);
|
||||
|
||||
conn->socket_fd = -1;
|
||||
conn->recv_len = 0;
|
||||
|
||||
conn->total_recv_len = 0;
|
||||
conn->recv_message_type = kUnknown;
|
||||
conn->state = kInit;
|
||||
if (conn->total_send_len != 0 && conn->send_message != nullptr) {
|
||||
delete conn->send_message;
|
||||
}
|
||||
conn->send_message = nullptr;
|
||||
conn->total_send_len = 0;
|
||||
|
||||
if (conn->total_recv_len != 0 && conn->recv_message != nullptr) {
|
||||
delete conn->recv_message;
|
||||
}
|
||||
conn->recv_message = nullptr;
|
||||
conn->total_recv_len = 0;
|
||||
|
||||
conn->recv_state = State::kMsgHeader;
|
||||
void TCPComm::Disconnect(const std::string &dst_url) {
|
||||
(void)recv_event_loop_->AddTask([dst_url, this] {
|
||||
std::lock_guard<std::mutex> lock(*conn_mutex_);
|
||||
conn_pool_->DeleteConnection(dst_url);
|
||||
});
|
||||
}
|
||||
|
||||
Connection *TCPComm::CreateDefaultConn(std::string to) {
|
||||
|
@ -680,66 +460,16 @@ Connection *TCPComm::CreateDefaultConn(std::string to) {
|
|||
MS_LOG(ERROR) << "Failed to create new connection and reconnect fail to: " << to.c_str();
|
||||
return conn;
|
||||
}
|
||||
conn->source = advertise_url_.data();
|
||||
conn->source = url_.data();
|
||||
conn->destination = to;
|
||||
conn->recv_event_loop = this->recv_event_loop_;
|
||||
conn->send_event_loop = this->send_event_loop_;
|
||||
conn->conn_mutex = conn_mutex_;
|
||||
conn->message_handler = message_handler_;
|
||||
conn->InitSocketOperation();
|
||||
return conn;
|
||||
}
|
||||
|
||||
void TCPComm::Reconnect(const AID &source, const AID &destination) {
|
||||
(void)send_event_loop_->AddTask([source, destination, this] {
|
||||
std::string to = destination.Url();
|
||||
std::lock_guard<std::mutex> lock(Connection::conn_mutex);
|
||||
Connection *conn = ConnectionPool::GetConnectionPool()->FindConnection(to, false, is_http_msg_);
|
||||
if (conn != nullptr) {
|
||||
conn->state = ConnectionState::kClose;
|
||||
}
|
||||
|
||||
(void)recv_event_loop_->AddTask([source, destination, this] {
|
||||
std::string to = destination.Url();
|
||||
int oldFd = -1;
|
||||
std::lock_guard<std::mutex> lock(Connection::conn_mutex);
|
||||
Connection *conn = ConnectionPool::GetConnectionPool()->FindConnection(to, false, is_http_msg_);
|
||||
if (conn != nullptr) {
|
||||
// connection already exist
|
||||
DoReConnectConn(conn, to, source, destination, &oldFd);
|
||||
} else {
|
||||
// create default connection
|
||||
conn = CreateDefaultConn(to);
|
||||
if (conn == nullptr) {
|
||||
return;
|
||||
}
|
||||
}
|
||||
int ret = DoConnect(to, conn, TCPComm::EventCallBack, TCPComm::WriteCallBack, TCPComm::ReadCallBack);
|
||||
if (ret < 0) {
|
||||
if (conn->socket_operation != nullptr) {
|
||||
delete conn->socket_operation;
|
||||
conn->socket_operation = nullptr;
|
||||
}
|
||||
if (oldFd != -1) {
|
||||
conn->socket_fd = oldFd;
|
||||
}
|
||||
MS_LOG(ERROR) << "Failed to connect and reconnect fail source: " << std::string(source).c_str()
|
||||
<< ", destination: " << std::string(destination).c_str();
|
||||
ConnectionPool::GetConnectionPool()->CloseConnection(conn);
|
||||
return;
|
||||
}
|
||||
if (oldFd != -1) {
|
||||
if (!ConnectionPool::GetConnectionPool()->ReverseConnInfo(oldFd, conn->socket_fd)) {
|
||||
MS_LOG(ERROR) << "Failed to swap socket for " << oldFd << " and " << conn->socket_fd;
|
||||
}
|
||||
} else {
|
||||
ConnectionPool::GetConnectionPool()->AddConnection(conn);
|
||||
}
|
||||
ConnectionPool::GetConnectionPool()->AddConnInfo(conn->socket_fd, source, destination, SendExitMsg);
|
||||
MS_LOG(INFO) << "Reconnect fd: " << conn->socket_fd << ", source: " << std::string(source).c_str()
|
||||
<< ", destination: " << std::string(destination).c_str();
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
void TCPComm::Finalize() {
|
||||
if (send_event_loop_ != nullptr) {
|
||||
MS_LOG(INFO) << "Delete send event loop";
|
||||
|
@ -762,13 +492,6 @@ void TCPComm::Finalize() {
|
|||
server_fd_ = -1;
|
||||
}
|
||||
}
|
||||
|
||||
// This value is not used for the sub-class TCPComm.
|
||||
uint64_t TCPComm::GetInBufSize() { return 1; }
|
||||
|
||||
uint64_t TCPComm::GetOutBufSize() { return output_buf_size_; }
|
||||
|
||||
bool TCPComm::IsHttpMsg() { return is_http_msg_; }
|
||||
} // namespace rpc
|
||||
} // namespace distributed
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -20,9 +20,11 @@
|
|||
#include <string>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include <mutex>
|
||||
|
||||
#include "actor/iomgr.h"
|
||||
#include "actor/msg.h"
|
||||
#include "distributed/rpc/tcp/connection.h"
|
||||
#include "distributed/rpc/tcp/connection_pool.h"
|
||||
#include "distributed/rpc/tcp/event_loop.h"
|
||||
|
||||
namespace mindspore {
|
||||
|
@ -34,15 +36,11 @@ void OnAccept(int server, uint32_t events, void *arg);
|
|||
// Send messages buffered in the connection.
|
||||
void DoSend(Connection *conn);
|
||||
|
||||
// Create a server socket and connect to it, this is a local connection..
|
||||
int DoConnect(const std::string &to, Connection *conn, ConnectionCallBack eventCallBack,
|
||||
ConnectionCallBack writeCallBack, ConnectionCallBack readCallBack);
|
||||
|
||||
void DoDisconnect(int fd, Connection *conn, uint32_t error, int soError);
|
||||
|
||||
void ConnectedEventHandler(int fd, uint32_t events, void *context);
|
||||
|
||||
class TCPComm : public IOMgr {
|
||||
class TCPComm {
|
||||
public:
|
||||
TCPComm() : server_fd_(-1), recv_event_loop_(nullptr), send_event_loop_(nullptr) {}
|
||||
TCPComm(const TCPComm &) = delete;
|
||||
|
@ -50,52 +48,46 @@ class TCPComm : public IOMgr {
|
|||
~TCPComm();
|
||||
|
||||
// Init the event loop for reading and writing.
|
||||
bool Initialize() override;
|
||||
bool Initialize();
|
||||
|
||||
// Destroy all the resources.
|
||||
void Finalize() override;
|
||||
void Finalize();
|
||||
|
||||
// Create the server socket represented by url.
|
||||
bool StartServerSocket(const std::string &url, const std::string &aAdvertiseUrl) override;
|
||||
bool StartServerSocket(const std::string &url);
|
||||
|
||||
// Build a connection between the source and destination.
|
||||
void Link(const AID &source, const AID &destination) override;
|
||||
void UnLink(const AID &destination) override;
|
||||
// Connection operation for a specified destination.
|
||||
void Connect(const std::string &dst_url);
|
||||
bool IsConnected(const std::string &dst_url);
|
||||
void Disconnect(const std::string &dst_url);
|
||||
|
||||
// Send the message from the source to the destination.
|
||||
int Send(std::unique_ptr<MessageBase> &&msg, bool remoteLink = false, bool isExactNotRemote = false) override;
|
||||
int Send(MessageBase *msg);
|
||||
|
||||
uint64_t GetInBufSize() override;
|
||||
uint64_t GetOutBufSize() override;
|
||||
void CollectMetrics() override;
|
||||
// Set the message processing handler.
|
||||
void SetMessageHandler(MessageHandler handler);
|
||||
|
||||
private:
|
||||
// Build the connection.
|
||||
Connection *CreateDefaultConn(std::string to);
|
||||
void Reconnect(const AID &source, const AID &destination);
|
||||
void DoReConnectConn(Connection *conn, std::string to, const AID &source, const AID &destination, int *oldFd);
|
||||
|
||||
// Send a message.
|
||||
int Send(MessageBase *msg, bool remoteLink = false, bool isExactNotRemote = false);
|
||||
static void Send(MessageBase *msg, const TCPComm *tcpmgr, bool remoteLink, bool isExactNotRemote);
|
||||
void SendByRecvLoop(MessageBase *msg, const TCPComm *tcpmgr, bool remoteLink, bool isExactNotRemote);
|
||||
static void SendExitMsg(const std::string &from, const std::string &to);
|
||||
|
||||
// Called by ReadCallBack when new message arrived.
|
||||
static int ReceiveMessage(Connection *conn);
|
||||
|
||||
void SetMessageHandler(IOMgr::MessageHandler handler);
|
||||
static int SetConnectedHandler(Connection *conn);
|
||||
|
||||
static int Connect(Connection *conn, const struct sockaddr *sa, socklen_t saLen);
|
||||
static int DoConnect(Connection *conn, const struct sockaddr *sa, socklen_t saLen);
|
||||
|
||||
static bool IsHttpMsg();
|
||||
static void DropMessage(MessageBase *msg);
|
||||
|
||||
// Read and write events.
|
||||
static void ReadCallBack(void *context);
|
||||
static void WriteCallBack(void *context);
|
||||
static void ReadCallBack(void *conn);
|
||||
static void WriteCallBack(void *conn);
|
||||
// Connected and Disconnected events.
|
||||
static void EventCallBack(void *context);
|
||||
static void EventCallBack(void *conn);
|
||||
|
||||
// The server url.
|
||||
std::string url_;
|
||||
|
@ -103,21 +95,19 @@ class TCPComm : public IOMgr {
|
|||
// The socket of server.
|
||||
int server_fd_;
|
||||
|
||||
// The message size waiting to be sent.
|
||||
static uint64_t output_buf_size_;
|
||||
|
||||
// User defined handler for Handling received messages.
|
||||
static MessageHandler message_handler_;
|
||||
|
||||
// The source url of a message.
|
||||
static std::vector<char> advertise_url_;
|
||||
|
||||
static bool is_http_msg_;
|
||||
MessageHandler message_handler_;
|
||||
|
||||
// All the connections share the same read and write event loop objects.
|
||||
EventLoop *recv_event_loop_;
|
||||
EventLoop *send_event_loop_;
|
||||
|
||||
// The connection pool used to store new connections.
|
||||
std::shared_ptr<ConnectionPool> conn_pool_;
|
||||
|
||||
// The mutex for connection operations.
|
||||
std::shared_ptr<std::mutex> conn_mutex_;
|
||||
|
||||
friend void OnAccept(int server, uint32_t events, void *arg);
|
||||
friend void DoSend(Connection *conn);
|
||||
friend int DoConnect(const std::string &to, Connection *conn, ConnectionCallBack event_callback,
|
||||
|
|
|
@ -0,0 +1,47 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "distributed/rpc/tcp/tcp_server.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace distributed {
|
||||
namespace rpc {
|
||||
bool TCPServer::Initialize(const std::string &url) {
|
||||
if (tcp_comm_ == nullptr) {
|
||||
tcp_comm_ = std::make_unique<TCPComm>();
|
||||
MS_EXCEPTION_IF_NULL(tcp_comm_);
|
||||
bool rt = tcp_comm_->Initialize();
|
||||
if (!rt) {
|
||||
MS_LOG(EXCEPTION) << "Failed to initialize tcp comm";
|
||||
}
|
||||
rt = tcp_comm_->StartServerSocket(url);
|
||||
return rt;
|
||||
} else {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
void TCPServer::Finalize() {
|
||||
if (tcp_comm_ != nullptr) {
|
||||
tcp_comm_.reset();
|
||||
tcp_comm_ = nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
// Forwards the message processing handler to the underlying TCPComm.
// The communicator only exists between Initialize() and Finalize(); outside that window the call
// is logged and ignored instead of dereferencing a null pointer.
void TCPServer::SetMessageHandler(MessageHandler handler) {
  if (tcp_comm_ == nullptr) {
    MS_LOG(ERROR) << "The tcp server is not initialized, failed to set the message handler";
    return;
  }
  tcp_comm_->SetMessageHandler(handler);
}
|
||||
} // namespace rpc
|
||||
} // namespace distributed
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,53 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_DISTRIBUTED_RPC_TCP_TCP_SERVER_H_
|
||||
#define MINDSPORE_CCSRC_DISTRIBUTED_RPC_TCP_TCP_SERVER_H_
|
||||
|
||||
#include <string>
|
||||
#include <memory>
|
||||
|
||||
#include "distributed/rpc/tcp/tcp_comm.h"
|
||||
#include "utils/ms_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace distributed {
|
||||
namespace rpc {
|
||||
// A TCP server for RPC. It owns a single TCPComm instance and delegates the socket work and the
// message handling to it.
class TCPServer {
|
||||
public:
|
||||
TCPServer() = default;
|
||||
~TCPServer() = default;
|
||||
|
||||
// Init the tcp server: create the TCPComm, initialize it and start the server socket on `url`.
// Returns true when the server socket was started or the server was already initialized, and
// false when the server socket could not be started.
|
||||
bool Initialize(const std::string &url);
|
||||
|
||||
// Destroy the tcp server by releasing the underlying TCPComm.
|
||||
void Finalize();
|
||||
|
||||
// Set the message processing handler. The handler is forwarded to the underlying TCPComm, so
// call this after Initialize().
|
||||
void SetMessageHandler(MessageHandler handler);
|
||||
|
||||
private:
|
||||
// The basic TCP communication component used by the server. It is null until Initialize() is
// called and null again after Finalize().
|
||||
std::unique_ptr<TCPComm> tcp_comm_;
|
||||
|
||||
DISABLE_COPY_AND_ASSIGN(TCPServer);
|
||||
};
|
||||
} // namespace rpc
|
||||
} // namespace distributed
|
||||
} // namespace mindspore
|
||||
|
||||
#endif
|
|
@ -24,9 +24,8 @@
|
|||
|
||||
#include <gtest/gtest.h>
|
||||
#define private public
|
||||
#include "actor/iomgr.h"
|
||||
#include "async/async.h"
|
||||
#include "distributed/rpc/tcp/tcp_comm.h"
|
||||
#include "distributed/rpc/tcp/tcp_server.h"
|
||||
#include "distributed/rpc/tcp/tcp_client.h"
|
||||
#include "common/common_test.h"
|
||||
|
||||
namespace mindspore {
|
||||
|
@ -35,26 +34,36 @@ namespace rpc {
|
|||
int g_recv_num = 0;
|
||||
int g_exit_msg_num = 0;
|
||||
|
||||
TCPComm *m_io = nullptr;
|
||||
// Number of data messages counted so far by the tests in this file.
// NOTE(review): this is a plain size_t. If it is incremented from an event loop thread while the
// test thread polls it in WaitForDataMsg, it presumably needs to be atomic — confirm against the
// message handlers that call IncrDataMsgNum.
static size_t g_data_msg_num = 0;

// Resets the data message counter at the start of a test case.
static void Init() { g_data_msg_num = 0; }

// Polls the data message counter every 100 ms until it equals `expected_msg_num` or until
// `timeout_in_sec` seconds have elapsed.
//
// Returns true when the expected count was observed, false on timeout. A zero or negative timeout
// performs a single check and returns immediately instead of looping.
static bool WaitForDataMsg(size_t expected_msg_num, int timeout_in_sec) {
  const int poll_interval_us = 100000;
  // Widened so that large timeouts cannot overflow the microsecond budget.
  long long remaining_us = static_cast<long long>(timeout_in_sec) * 1000 * 1000;

  while (remaining_us > 0) {
    if (g_data_msg_num == expected_msg_num) {
      return true;
    }
    usleep(poll_interval_us);
    remaining_us -= poll_interval_us;
  }
  // Final check so that a message arriving during the last sleep is not reported as a timeout.
  return g_data_msg_num == expected_msg_num;
}

// Adds `number` to the data message counter.
static void IncrDataMsgNum(size_t number) { g_data_msg_num += number; }

// Returns the current value of the data message counter.
static size_t GetDataMsgNum() { return g_data_msg_num; }
|
||||
|
||||
std::atomic<int> m_sendNum(0);
|
||||
std::string m_localIP = "127.0.0.1";
|
||||
bool m_notRemote = false;
|
||||
|
||||
void msgHandle(std::unique_ptr<MessageBase> &&msg) {
|
||||
if (msg->GetType() == MessageBase::Type::KEXIT) {
|
||||
g_exit_msg_num++;
|
||||
} else {
|
||||
g_recv_num++;
|
||||
}
|
||||
}
|
||||
|
||||
class TCPTest : public UT::Common {
|
||||
public:
|
||||
static void SendMsg(std::string &_localUrl, std::string &_remoteUrl, int msgsize, bool remoteLink = false,
|
||||
std::string body = "");
|
||||
|
||||
protected:
|
||||
char *args[4];
|
||||
char *testServerPath;
|
||||
static const size_t pid_num = 100;
|
||||
pid_t pid1;
|
||||
|
@ -62,124 +71,24 @@ class TCPTest : public UT::Common {
|
|||
|
||||
pid_t pids[pid_num];
|
||||
|
||||
void SetUp() {
|
||||
char *localpEnv = getenv("LITEBUS_IP");
|
||||
if (localpEnv != nullptr) {
|
||||
m_localIP = std::string(localpEnv);
|
||||
}
|
||||
void SetUp() {}
|
||||
void TearDown() {}
|
||||
|
||||
char *locaNotRemoteEnv = getenv("LITEBUS_SEND_ON_REMOTE");
|
||||
if (locaNotRemoteEnv != nullptr) {
|
||||
m_notRemote = (std::string(locaNotRemoteEnv) == "true") ? true : false;
|
||||
}
|
||||
|
||||
pid1 = 0;
|
||||
pid2 = 0;
|
||||
pids[pid_num] = {0};
|
||||
size_t size = pid_num * sizeof(pid_t);
|
||||
if (memset_s(&pids, size, 0, size)) {
|
||||
MS_LOG(ERROR) << "Failed to init pid array";
|
||||
}
|
||||
g_recv_num = 0;
|
||||
g_exit_msg_num = 0;
|
||||
m_sendNum = 0;
|
||||
|
||||
m_io = new TCPComm();
|
||||
m_io->Initialize();
|
||||
m_io->SetMessageHandler(msgHandle);
|
||||
m_io->StartServerSocket("tcp://" + m_localIP + ":2225", "tcp://" + m_localIP + ":2225");
|
||||
}
|
||||
|
||||
void TearDown() {
|
||||
shutdownTcpServer(pid1);
|
||||
shutdownTcpServer(pid2);
|
||||
pid1 = 0;
|
||||
pid2 = 0;
|
||||
int i = 0;
|
||||
for (i = 0; i < pid_num; i++) {
|
||||
shutdownTcpServer(pids[i]);
|
||||
pids[i] = 0;
|
||||
}
|
||||
g_recv_num = 0;
|
||||
g_exit_msg_num = 0;
|
||||
m_sendNum = 0;
|
||||
m_io->Finalize();
|
||||
delete m_io;
|
||||
m_io = nullptr;
|
||||
}
|
||||
std::unique_ptr<MessageBase> CreateMessage(const std::string &serverUrl, const std::string &client_url);
|
||||
|
||||
bool CheckRecvNum(int expectedRecvNum, int _timeout);
|
||||
bool CheckExitNum(int expectedExitNum, int _timeout);
|
||||
pid_t startTcpServer(char **args);
|
||||
void shutdownTcpServer(pid_t pid);
|
||||
void KillTcpServer(pid_t pid);
|
||||
|
||||
void Link(std::string &_localUrl, std::string &_remoteUrl);
|
||||
void Reconnect(std::string &_localUrl, std::string &_remoteUrl);
|
||||
void Unlink(std::string &_remoteUrl);
|
||||
};
|
||||
|
||||
// listening local url and sending msg to remote url,if start succ.
|
||||
pid_t TCPTest::startTcpServer(char **args) {
|
||||
pid_t pid = fork();
|
||||
if (pid == 0) {
|
||||
return -1;
|
||||
} else {
|
||||
return pid;
|
||||
}
|
||||
}
|
||||
void TCPTest::shutdownTcpServer(pid_t pid) {
|
||||
if (pid > 1) {
|
||||
kill(pid, SIGALRM);
|
||||
int status;
|
||||
waitpid(pid, &status, 0);
|
||||
}
|
||||
}
|
||||
|
||||
void TCPTest::KillTcpServer(pid_t pid) {
|
||||
if (pid > 1) {
|
||||
kill(pid, SIGKILL);
|
||||
int status;
|
||||
waitpid(pid, &status, 0);
|
||||
}
|
||||
}
|
||||
|
||||
void TCPTest::SendMsg(std::string &_localUrl, std::string &_remoteUrl, int msgsize, bool remoteLink, std::string body) {
|
||||
AID from("testserver", _localUrl);
|
||||
AID to("testserver", _remoteUrl);
|
||||
|
||||
std::unique_ptr<MessageBase> TCPTest::CreateMessage(const std::string &serverUrl, const std::string &clientUrl) {
|
||||
std::unique_ptr<MessageBase> message = std::make_unique<MessageBase>();
|
||||
std::string data(msgsize, 'A');
|
||||
size_t len = 100;
|
||||
std::string data(len, 'A');
|
||||
message->name = "testname";
|
||||
message->from = from;
|
||||
message->to = to;
|
||||
message->from = AID("client", clientUrl);
|
||||
message->to = AID("server", serverUrl);
|
||||
message->body = data;
|
||||
if (body != "") {
|
||||
message->body = body;
|
||||
}
|
||||
|
||||
if (m_notRemote) {
|
||||
m_io->Send(std::move(message), remoteLink, true);
|
||||
} else {
|
||||
m_io->Send(std::move(message), remoteLink);
|
||||
}
|
||||
}
|
||||
|
||||
void TCPTest::Link(std::string &_localUrl, std::string &_remoteUrl) {
|
||||
AID from("testserver", _localUrl);
|
||||
AID to("testserver", _remoteUrl);
|
||||
m_io->Link(from, to);
|
||||
}
|
||||
|
||||
void TCPTest::Reconnect(std::string &_localUrl, std::string &_remoteUrl) {
|
||||
AID from("testserver", _localUrl);
|
||||
AID to("testserver", _remoteUrl);
|
||||
m_io->Reconnect(from, to);
|
||||
}
|
||||
|
||||
void TCPTest::Unlink(std::string &_remoteUrl) {
|
||||
AID to("testserver", _remoteUrl);
|
||||
m_io->UnLink(to);
|
||||
return message;
|
||||
}
|
||||
|
||||
bool TCPTest::CheckRecvNum(int expectedRecvNum, int _timeout) {
|
||||
|
@ -215,47 +124,98 @@ bool TCPTest::CheckExitNum(int expectedExitNum, int _timeout) {
|
|||
/// Description: start a socket server with an invalid url.
|
||||
/// Expectation: failed to start the server with invalid url.
|
||||
TEST_F(TCPTest, StartServerFail) {
|
||||
std::unique_ptr<TCPComm> io = std::make_unique<TCPComm>();
|
||||
io->Initialize();
|
||||
|
||||
bool ret = io->StartServerSocket("tcp://0:2225", "tcp://0:2225");
|
||||
std::unique_ptr<TCPServer> server = std::make_unique<TCPServer>();
|
||||
bool ret = server->Initialize("0:2225");
|
||||
ASSERT_FALSE(ret);
|
||||
io->Finalize();
|
||||
server->Finalize();
|
||||
}
|
||||
|
||||
/// Feature: test start a socket server.
|
||||
/// Description: start the socket server with a specified socket.
|
||||
/// Expectation: the socket server is started successfully.
|
||||
TEST_F(TCPTest, StartServer2) {
|
||||
std::unique_ptr<TCPComm> io = std::make_unique<TCPComm>();
|
||||
io->Initialize();
|
||||
io->SetMessageHandler(msgHandle);
|
||||
bool ret = io->StartServerSocket("tcp://" + m_localIP + ":2225", "tcp://" + m_localIP + ":2225");
|
||||
ASSERT_FALSE(ret);
|
||||
ret = io->StartServerSocket("tcp://" + m_localIP + ":2224", "tcp://" + m_localIP + ":2224");
|
||||
io->Finalize();
|
||||
TEST_F(TCPTest, StartServerSucc) {
|
||||
std::unique_ptr<TCPServer> server = std::make_unique<TCPServer>();
|
||||
bool ret = server->Initialize("127.0.0.1:8081");
|
||||
ASSERT_TRUE(ret);
|
||||
server->Finalize();
|
||||
}
|
||||
|
||||
/// Feature: test normal tcp message sending.
|
||||
/// Description: start a socket server and send a normal message to it.
|
||||
/// Expectation: the server received the message sent from client.
|
||||
TEST_F(TCPTest, send1Msg) {
|
||||
g_recv_num = 0;
|
||||
pid1 = startTcpServer(args);
|
||||
bool ret = CheckRecvNum(1, 5);
|
||||
ASSERT_FALSE(ret);
|
||||
TEST_F(TCPTest, SendOneMessage) {
|
||||
Init();
|
||||
|
||||
std::string from = "tcp://" + m_localIP + ":2223";
|
||||
std::string to = "tcp://" + m_localIP + ":2225";
|
||||
SendMsg(from, to, pid_num);
|
||||
|
||||
ret = CheckRecvNum(1, 5);
|
||||
// Start the tcp server.
|
||||
auto server_url = "127.0.0.1:8081";
|
||||
std::unique_ptr<TCPServer> server = std::make_unique<TCPServer>();
|
||||
bool ret = server->Initialize(server_url);
|
||||
ASSERT_TRUE(ret);
|
||||
|
||||
Unlink(to);
|
||||
shutdownTcpServer(pid1);
|
||||
pid1 = 0;
|
||||
server->SetMessageHandler([](std::unique_ptr<MessageBase> &&message) -> void { IncrDataMsgNum(1); });
|
||||
|
||||
// Start the tcp client.
|
||||
auto client_url = "127.0.0.1:1234";
|
||||
std::unique_ptr<TCPClient> client = std::make_unique<TCPClient>();
|
||||
ret = client->Initialize();
|
||||
ASSERT_TRUE(ret);
|
||||
|
||||
// Create the message.
|
||||
auto message = CreateMessage(server_url, client_url);
|
||||
|
||||
// Send the message.
|
||||
client->Connect(server_url);
|
||||
client->Send(std::move(message));
|
||||
|
||||
// Wait timeout: 5s
|
||||
WaitForDataMsg(1, 5);
|
||||
|
||||
// Check result
|
||||
EXPECT_EQ(1, GetDataMsgNum());
|
||||
|
||||
// Destroy
|
||||
client->Disconnect(server_url);
|
||||
client->Finalize();
|
||||
server->Finalize();
|
||||
}
|
||||
|
||||
/// Feature: test sending two message continuously.
|
||||
/// Description: start a socket server and send two normal message to it.
|
||||
/// Expectation: the server received the two messages sented from client.
|
||||
TEST_F(TCPTest, sendTwoMessages) {
|
||||
Init();
|
||||
|
||||
// Start the tcp server.
|
||||
auto server_url = "127.0.0.1:8081";
|
||||
std::unique_ptr<TCPServer> server = std::make_unique<TCPServer>();
|
||||
bool ret = server->Initialize(server_url);
|
||||
ASSERT_TRUE(ret);
|
||||
|
||||
server->SetMessageHandler([](std::unique_ptr<MessageBase> &&message) -> void { IncrDataMsgNum(1); });
|
||||
|
||||
// Start the tcp client.
|
||||
auto client_url = "127.0.0.1:1234";
|
||||
std::unique_ptr<TCPClient> client = std::make_unique<TCPClient>();
|
||||
ret = client->Initialize();
|
||||
ASSERT_TRUE(ret);
|
||||
|
||||
// Create messages.
|
||||
auto message1 = CreateMessage(server_url, client_url);
|
||||
auto message2 = CreateMessage(server_url, client_url);
|
||||
|
||||
// Send messages.
|
||||
client->Connect(server_url);
|
||||
client->Send(std::move(message1));
|
||||
client->Send(std::move(message2));
|
||||
|
||||
// Wait timeout: 5s
|
||||
WaitForDataMsg(2, 5);
|
||||
|
||||
// Check result
|
||||
EXPECT_EQ(2, GetDataMsgNum());
|
||||
client->Disconnect(server_url);
|
||||
client->Finalize();
|
||||
server->Finalize();
|
||||
}
|
||||
} // namespace rpc
|
||||
} // namespace distributed
|
||||
|
|
Loading…
Reference in New Issue