!27566 Add TCP connection pool for distributed runtime

Merge pull request !27566 from chengang/add_rpc_connection
This commit is contained in:
i-robot 2021-12-16 01:28:21 +00:00 committed by Gitee
commit 3a17c70deb
4 changed files with 1149 additions and 0 deletions

View File

@ -0,0 +1,504 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "distributed/rpc/tcp/connection.h"
#include <memory>
#include <utility>
#include "distributed/rpc/tcp/tcp_socket_operation.h"
#include "distributed/rpc/tcp/connection_pool.h"
namespace mindspore {
namespace distributed {
namespace rpc {
std::mutex Connection::conn_mutex;
// Handle socket events like read/write.
void SocketEventHandler(int fd, uint32_t events, void *context) {
Connection *conn = reinterpret_cast<Connection *>(context);
if (fd != conn->socket_fd) {
MS_LOG(ERROR) << "Failed to reuse connection, delete and close fd: " << fd << ", connfd: " << conn->socket_fd
<< ", event: " << events;
conn->recv_event_loop->DeleteEpollEvent(fd);
conn->state = ConnectionState::kDisconnecting;
if (conn->event_callback != nullptr) {
conn->event_callback(conn);
} else {
MS_LOG(ERROR) << "No event_callback found for fd: " << fd << ", events: " << events;
}
return;
}
// Handle write event.
if (events & EPOLLOUT) {
(void)conn->recv_event_loop->UpdateEpollEvent(fd, EPOLLIN | EPOLLHUP | EPOLLERR);
if (conn->write_callback != nullptr) {
conn->write_callback(conn);
}
}
// Handle read event.
if (events & EPOLLIN) {
if (conn->read_callback != nullptr) {
conn->read_callback(conn);
}
}
// Handle disconnect event.
if (conn->state == ConnectionState::kDisconnecting ||
(conn->recv_message_type != ParseType::kHttpReq && conn->recv_message_type != ParseType::kHttpRsp &&
(events & (uint32_t)(EPOLLHUP | EPOLLRDHUP | EPOLLERR)))) {
if (conn->recv_message_type == ParseType::kTcpMsg) {
MS_LOG(INFO) << "Event value fd: " << fd << ", events: " << events << ", state: " << conn->state
<< ", errcode: " << conn->error_code << ", errno: " << errno << ", to: " << conn->destination.c_str()
<< ", type:" << conn->recv_message_type << ", remote: " << conn->is_remote;
}
conn->state = ConnectionState::kDisconnecting;
if (conn->event_callback != nullptr) {
conn->event_callback(conn);
} else {
MS_LOG(ERROR) << "No event_callback found for fd: " << fd << ", events: " << events;
}
}
}
// Handle new connect event.
void NewConnectEventHandler(int fd, uint32_t events, void *context) {
int retval = 0;
Connection *conn = reinterpret_cast<Connection *>(context);
conn->socket_operation->NewConnEventHandler(fd, events, context);
if (conn->state == ConnectionState::kDisconnecting) {
conn->Disconnect(fd);
return;
} else if (conn->state != ConnectionState::kConnected) {
// The handshake is not complete
return;
}
retval = conn->recv_event_loop->DeleteEpollEvent(fd);
if (retval) {
MS_LOG(ERROR) << "Failed to remove epoll remove connect handler for fd: " << fd;
return;
}
retval = conn->recv_event_loop->SetEventHandler(conn->socket_fd, EPOLLIN | EPOLLHUP | EPOLLRDHUP | EPOLLERR,
SocketEventHandler, reinterpret_cast<void *>(conn));
if (retval != RPC_OK) {
MS_LOG(ERROR) << "Failed to add socket event handler for fd: " << fd << ", events: " << events;
conn->Disconnect(fd);
return;
}
conn->write_callback(conn);
SocketEventHandler(fd, events, context);
return;
}
Connection::Connection()
: socket_fd(-1),
deleted(false),
is_remote(false),
type(kTcp),
socket_operation(nullptr),
state(kInit),
send_event_loop(nullptr),
recv_event_loop(nullptr),
send_metrics(new SendMetrics()),
send_message(nullptr),
recv_message(nullptr),
recv_state(kMsgHeader),
total_recv_len(0),
total_send_len(0),
recv_len(0),
event_callback(nullptr),
succ_callback(nullptr),
write_callback(nullptr),
read_callback(nullptr),
output_buffer_size(0),
error_code(0) {
// Initialize the recv kernel message structure.
recv_kernel_msg.msg_control = nullptr;
recv_kernel_msg.msg_controllen = 0;
recv_kernel_msg.msg_flags = 0;
recv_kernel_msg.msg_name = nullptr;
recv_kernel_msg.msg_namelen = 0;
recv_kernel_msg.msg_iov = recv_io_vec;
recv_kernel_msg.msg_iovlen = RECV_MSG_IO_VEC_LEN;
// Initialize the send message header.
for (unsigned int i = 0; i < BUSMAGIC_LEN; i++) {
if (i < sizeof(RPC_MAGICID) - 1) {
send_msg_header.magic[i] = RPC_MAGICID[i];
} else {
send_msg_header.magic[i] = '\0';
}
}
// Initialize the send kernel message structure.
send_kernel_msg.msg_control = nullptr;
send_kernel_msg.msg_controllen = 0;
send_kernel_msg.msg_flags = 0;
send_kernel_msg.msg_name = nullptr;
send_kernel_msg.msg_namelen = 0;
send_kernel_msg.msg_iov = send_io_vec;
send_kernel_msg.msg_iovlen = SEND_MSG_IO_VEC_LEN;
}
int Connection::Initialize() {
InitSocketOperation();
return AddConnnectEventHandler();
}
void Connection::InitSocketOperation() {
if (socket_operation != nullptr) {
return;
}
socket_operation = new (std::nothrow) TCPSocketOperation();
RPC_OOM_EXIT(socket_operation);
}
bool Connection::ReconnectSourceSocket(int fd, uint32_t events, int *soError, uint32_t error) {
int retval = 0;
socklen_t len = sizeof(*soError);
retval = recv_event_loop->DeleteEpollEvent(fd);
if (retval) {
MS_LOG(ERROR) << "Failed to delete event for fd: " << fd << ", event: " << events;
return false;
}
retval = getsockopt(fd, SOL_SOCKET, SO_ERROR, soError, &len);
if (retval) {
*soError = errno;
}
if (*soError || error) {
return false;
}
retval = recv_event_loop->SetEventHandler(socket_fd, EPOLLIN | EPOLLHUP | EPOLLRDHUP | EPOLLERR, SocketEventHandler,
reinterpret_cast<void *>(this));
if (retval != RPC_OK) {
MS_LOG(ERROR) << "Failed to add socket event handler for fd: " << fd << ", events: " << events;
return false;
}
return true;
}
void Connection::Disconnect(int fd) {
if (LOG_CHECK_EVERY_N()) {
MS_LOG(INFO) << "New connection fail fd: " << fd << ", state: " << state << ", errno: " << errno
<< ", to: " << destination.c_str() << ", type: " << recv_message_type;
}
state = ConnectionState::kDisconnecting;
event_callback(this);
return;
}
void Connection::Close() {
if (recv_event_loop != nullptr) {
recv_event_loop->DeleteEpollEvent(socket_fd);
}
if (!destination.empty()) {
if (recv_message != nullptr) {
delete recv_message;
}
}
if (total_send_len != 0 && send_message != nullptr) {
delete send_message;
}
MessageBase *tmpMsg = nullptr;
while (!send_message_queue.empty()) {
tmpMsg = send_message_queue.front();
send_message_queue.pop();
delete tmpMsg;
}
if (socket_operation != nullptr) {
socket_operation->Close(this);
delete socket_operation;
}
if (send_metrics != nullptr) {
delete send_metrics;
}
}
int Connection::ReceiveMessage(IOMgr::MessageHandler msgHandler) {
bool ok = ParseMessage();
// If no message parsed, wait for next read
if (!ok) {
if (state == ConnectionState::kDisconnecting) {
return -1;
}
return 0;
}
if (destination.empty()) {
std::string fromUrl = recv_message->from;
size_t index = fromUrl.find("@");
if (index != std::string::npos) {
destination = fromUrl.substr(index + 1);
MS_LOG(INFO) << "Create new connection fd: " << socket_fd << " to: " << destination.c_str();
Connection::conn_mutex.lock();
ConnectionPool::GetConnectionPool()->SetConnPriority(destination, false, ConnectionPriority::kPriorityLow);
state = ConnectionState::kConnected;
ConnectionPool::GetConnectionPool()->AddConnection(this);
Connection::conn_mutex.unlock();
}
}
std::unique_ptr<MessageBase> msg(recv_message);
recv_message = nullptr;
// Call msg handler if set
if (msgHandler != nullptr) {
msgHandler(std::move(msg));
} else {
MS_LOG(INFO) << "Message handler was not found";
}
return 1;
}
void Connection::CheckMessageType() {
std::lock_guard<std::mutex> lock(Connection::conn_mutex);
if (recv_message_type != ParseType::kUnknown) {
return;
}
std::string magic_id = "";
magic_id.resize(sizeof(RPC_MAGICID) - 1);
char *buf = const_cast<char *>(magic_id.data());
int size = socket_operation->ReceivePeek(this, buf, sizeof(RPC_MAGICID) - 1);
if (size < static_cast<int>(sizeof(RPC_MAGICID) - 1)) {
if (size == 0) {
MS_LOG(INFO) << "Set connection disconnecting for fd: " << socket_fd << ", size: " << size
<< ", magic size: " << static_cast<int>(sizeof(RPC_MAGICID) - 1) << ", errno: " << errno;
state = ConnectionState::kDisconnecting;
}
return;
}
if (strncmp(RPC_MAGICID, magic_id.c_str(), sizeof(RPC_MAGICID) - 1) == 0) {
recv_state = State::kMsgHeader;
recv_message_type = ParseType::kTcpMsg;
}
return;
}
std::string Connection::GenerateHttpMessage(MessageBase *msg) {
static const std::string postLineBegin = std::string() + "POST /";
static const std::string postLineEnd = std::string() + " HTTP/1.1\r\n";
static const std::string userAgentLineBegin = std::string() + "User-Agent: libprocess/";
static const std::string fromLineBegin = std::string() + "Libprocess-From: ";
static const std::string connectLine = std::string() + "Connection: Keep-Alive\r\n";
static const std::string hostLine = std::string() + "Host: \r\n";
static const std::string chunkedBeginLine = std::string() + "Transfer-Encoding: chunked\r\n\r\n";
static const std::string chunkedEndLine = std::string() + "\r\n" + "0\r\n" + "\r\n";
static const std::string commonEndLine = std::string() + "\r\n";
std::string postLine;
if (msg->To().Name() != "") {
postLine = postLineBegin + msg->To().Name() + "/" + msg->Name() + postLineEnd;
} else {
postLine = postLineBegin + msg->Name() + postLineEnd;
}
std::string userAgentLine = userAgentLineBegin + msg->From().Name() + "@" + advertise_addr_ + commonEndLine;
std::string fromLine = fromLineBegin + msg->From().Name() + "@" + advertise_addr_ + commonEndLine;
if (msg->Body().size() > 0) {
std::ostringstream bodyLine;
bodyLine << std::hex << msg->Body().size() << "\r\n";
bodyLine.write(msg->Body().data(), msg->Body().size());
return postLine + userAgentLine + fromLine + connectLine + hostLine + chunkedBeginLine + bodyLine.str() +
chunkedEndLine;
}
return postLine + userAgentLine + fromLine + connectLine + hostLine + commonEndLine;
}
void Connection::FillSendMessage(MessageBase *msg, const std::string &advertiseUrl, bool isHttpKmsg, int index) {
index = 0;
if (msg->type == MessageBase::Type::KMSG) {
if (!isHttpKmsg) {
send_to = msg->to;
send_from = msg->from.Name() + "@" + advertiseUrl;
send_msg_header.name_len = htonl(static_cast<uint32_t>(msg->name.size()));
send_msg_header.to_len = htonl(static_cast<uint32_t>(send_to.size()));
send_msg_header.from_len = htonl(static_cast<uint32_t>(send_from.size()));
send_msg_header.body_len = htonl(static_cast<uint32_t>(msg->body.size()));
send_io_vec[index].iov_base = &send_msg_header;
send_io_vec[index].iov_len = sizeof(send_msg_header);
++index;
send_io_vec[index].iov_base = const_cast<char *>(msg->name.data());
send_io_vec[index].iov_len = msg->name.size();
++index;
send_io_vec[index].iov_base = const_cast<char *>(send_to.data());
send_io_vec[index].iov_len = send_to.size();
++index;
send_io_vec[index].iov_base = const_cast<char *>(send_from.data());
send_io_vec[index].iov_len = send_from.size();
++index;
send_io_vec[index].iov_base = const_cast<char *>(msg->body.data());
send_io_vec[index].iov_len = msg->body.size();
++index;
send_kernel_msg.msg_iov = send_io_vec;
send_kernel_msg.msg_iovlen = index;
total_send_len =
sizeof(send_msg_header) + msg->name.size() + send_to.size() + send_from.size() + msg->body.size();
send_message = msg;
// update metrics
send_metrics->UpdateMax(msg->body.size());
send_metrics->last_send_msg_name = msg->name;
return;
} else {
if (advertise_addr_.empty()) {
size_t index = advertiseUrl.find(URL_PROTOCOL_IP_SEPARATOR);
if (index == std::string::npos) {
advertise_addr_ = advertiseUrl;
} else {
advertise_addr_ = advertiseUrl.substr(index + sizeof(URL_PROTOCOL_IP_SEPARATOR) - 1);
}
}
msg->body = GenerateHttpMessage(msg);
}
send_io_vec[index].iov_base = const_cast<char *>(msg->body.data());
send_io_vec[index].iov_len = msg->body.size();
++index;
send_kernel_msg.msg_iov = send_io_vec;
send_kernel_msg.msg_iovlen = index;
total_send_len = msg->body.size();
send_message = msg;
// update metrics
send_metrics->UpdateMax(msg->body.size());
send_metrics->last_send_msg_name = msg->name;
}
}
void Connection::FillRecvMessage() {
size_t recvNameLen = static_cast<size_t>(recv_msg_header.name_len);
size_t recvToLen = static_cast<size_t>(recv_msg_header.to_len);
size_t recvFromLen = static_cast<size_t>(recv_msg_header.from_len);
size_t recvBodyLen = static_cast<size_t>(recv_msg_header.body_len);
if (recvNameLen > MAX_KMSG_NAME_LEN || recvToLen > MAX_KMSG_TO_LEN || recvFromLen > MAX_KMSG_FROM_LEN ||
recvBodyLen > MAX_KMSG_BODY_LEN) {
MS_LOG(ERROR) << "Drop invalid tcp data.";
state = ConnectionState::kDisconnecting;
return;
}
int i = 0;
MessageBase *msg = new (std::nothrow) MessageBase();
RPC_OOM_EXIT(msg);
msg->name.resize(recvNameLen);
recv_to.resize(recvToLen);
recv_from.resize(recvFromLen);
msg->body.resize(recvBodyLen);
recv_io_vec[i].iov_base = const_cast<char *>(msg->name.data());
recv_io_vec[i].iov_len = msg->name.size();
++i;
recv_io_vec[i].iov_base = const_cast<char *>(recv_to.data());
recv_io_vec[i].iov_len = recv_to.size();
++i;
recv_io_vec[i].iov_base = const_cast<char *>(recv_from.data());
recv_io_vec[i].iov_len = recv_from.size();
++i;
recv_io_vec[i].iov_base = const_cast<char *>(msg->body.data());
recv_io_vec[i].iov_len = msg->body.size();
++i;
recv_kernel_msg.msg_iov = recv_io_vec;
recv_kernel_msg.msg_iovlen = i;
total_recv_len = msg->name.size() + recv_to.size() + recv_from.size() + msg->body.size();
recv_message = msg;
}
int Connection::AddConnnectEventHandler() {
return recv_event_loop->SetEventHandler(socket_fd, EPOLLIN | EPOLLHUP | EPOLLERR, NewConnectEventHandler,
reinterpret_cast<void *>(this));
}
bool Connection::ParseMessage() {
std::string magic_id = "";
int retval = 0;
uint32_t recvLen = 0;
char *recvBuf = nullptr;
switch (recv_state) {
// Parse message header.
case State::kMsgHeader:
recvBuf = reinterpret_cast<char *>(&recv_msg_header) + recv_len;
retval = socket_operation->Receive(this, recvBuf, sizeof(MessageHeader) - recv_len, &recvLen);
if (retval < 0) {
state = ConnectionState::kDisconnecting;
recv_len += recvLen;
return false;
}
if ((recvLen + recv_len) != sizeof(MessageHeader)) {
recv_len += recvLen;
return false;
}
recv_len = 0;
if (strncmp(recv_msg_header.magic, RPC_MAGICID, sizeof(RPC_MAGICID) - 1) != 0) {
MS_LOG(ERROR) << "Failed to check magicid, RPC_MAGICID: " << RPC_MAGICID
<< ", recv magic_id: " << magic_id.c_str();
state = ConnectionState::kDisconnecting;
return false;
}
ReorderHeader(&recv_msg_header);
FillRecvMessage();
if (state == ConnectionState::kDisconnecting) {
return false;
}
recv_state = State::kBody;
// Parse message body.
case State::kBody:
retval = socket_operation->ReceiveMessage(this, &recv_kernel_msg, total_recv_len);
if (retval != static_cast<int>(total_recv_len)) {
if (retval < 0) {
state = ConnectionState::kDisconnecting;
return false;
}
total_recv_len -= retval;
return false;
}
recv_state = State::kMsgHeader;
break;
default:
return false;
}
return true;
}
void Connection::ReorderHeader(MessageHeader *header) {
header->name_len = ntohl(header->name_len);
header->to_len = ntohl(header->to_len);
header->from_len = ntohl(header->from_len);
header->body_len = ntohl(header->body_len);
}
} // namespace rpc
} // namespace distributed
} // namespace mindspore

View File

@ -0,0 +1,233 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_DISTRIBUTED_RPC_TCP_CONNECTION_H_
#define MINDSPORE_CCSRC_DISTRIBUTED_RPC_TCP_CONNECTION_H_
#include <queue>
#include <string>
#include <mutex>
#include "actor/msg.h"
#include "actor/iomgr.h"
#include "distributed/rpc/tcp/constants.h"
#include "distributed/rpc/tcp/event_loop.h"
#include "distributed/rpc/tcp/socket_operation.h"
namespace mindspore {
namespace distributed {
namespace rpc {
/*
* The MessageHeader contains the stats info about the message body.
*/
struct MessageHeader {
MessageHeader() {
for (unsigned int i = 0; i < BUSMAGIC_LEN; ++i) {
if (i < sizeof(RPC_MAGICID) - 1) {
magic[i] = RPC_MAGICID[i];
} else {
magic[i] = '\0';
}
}
}
char magic[BUSMAGIC_LEN];
uint32_t name_len{0};
uint32_t to_len{0};
uint32_t from_len{0};
uint32_t body_len{0};
};
/*
* The SendMetrics is responsible for collecting metrics when sending data through a connection.
*/
struct SendMetrics {
// Records the message number and max body size.
void UpdateMax(int size) {
accum_msg_count++;
if (size > max_msg_size) {
max_msg_size = size;
}
}
// Records the latest error message.
void UpdateError(bool fail, int err = 0) {
if (fail) {
last_fail_msg_name = last_send_msg_name;
error_code = err;
} else {
last_succ_msg_name = last_send_msg_name;
}
}
// Reset all the metrics info.
void Reset() {
accum_msg_count = 0;
max_msg_size = 0;
error_code = 0;
last_succ_msg_name = "";
last_fail_msg_name = "";
last_send_msg_name = "";
}
// The total number of bytes sent already.
int accum_msg_count{0};
// The max message body size sent in bytes.
int max_msg_size{0};
int error_code{0};
std::string last_succ_msg_name;
std::string last_fail_msg_name;
std::string last_send_msg_name;
};
/*
* Represents a TCP or SSL connection.
*/
struct Connection {
public:
Connection();
~Connection() = default;
// Initialize the connection(eg. add some socket event handlers).
int Initialize();
// Create a new socket operation if needed.
void InitSocketOperation();
// Delete this socket fd(source client socket) and add back to the connection.
bool ReconnectSourceSocket(int fd, uint32_t events, int *soError, uint32_t error);
// Disconnect the socket fd from source.
void Disconnect(int fd);
// Close this connection.
void Close();
int ReceiveMessage(IOMgr::MessageHandler msgHandler);
void CheckMessageType();
// Fill the message to be sent based on the input message.
void FillSendMessage(MessageBase *msg, const std::string &advertiseUrl, bool isHttpKmsg, int index = 0);
void FillRecvMessage();
bool IsSame(const Connection *that) {
return !(that != nullptr && that->destination == destination && that->is_remote == is_remote);
}
// The socket used by this connection.
int socket_fd;
// Indicates whether this connection is deleted from link manager.
bool deleted;
// Indicates the priority of this connection.
ConnectionPriority priority{ConnectionPriority::kPriorityHigh};
// Indicates whether this connection is connected from remote client.
// A connection is remote only when the connection is created by the `OnAccept` callback.
bool is_remote;
// TCP or SSL.
ConnectionType type;
// The socket address(ip:port) of client and server of this connection.
std::string source;
std::string destination;
// Peer address.
std::string peer;
// Specific operations for the socket in this connection.
SocketOperation *socket_operation;
// The state of this connection(eg. kInit/kConnecting/..)
ConnectionState state{kInit};
// The threads for handling the receive and send requsets on this connection.
EventLoop *send_event_loop;
EventLoop *recv_event_loop;
// Collects data sending metrics.
SendMetrics *send_metrics;
// The message data waiting to be sent and receive through this connection..
MessageBase *send_message;
MessageBase *recv_message;
State recv_state;
// Total length of received and sent messages.
uint32_t total_recv_len;
uint32_t total_send_len;
uint32_t recv_len;
std::string send_to;
std::string send_from;
std::string recv_to;
std::string recv_from;
// Message header.
MessageHeader send_msg_header;
MessageHeader recv_msg_header;
// The message structure of kernel.
struct msghdr send_kernel_msg;
struct msghdr recv_kernel_msg;
struct iovec recv_io_vec[RECV_MSG_IO_VEC_LEN];
struct iovec send_io_vec[SEND_MSG_IO_VEC_LEN];
ParseType recv_message_type{kUnknown};
// Callbacks for io events
ConnectionCallBack event_callback;
ConnectionCallBack succ_callback;
ConnectionCallBack write_callback;
ConnectionCallBack read_callback;
// Buffer for messages to be sent.
std::queue<MessageBase *> send_message_queue;
uint64_t output_buffer_size;
// The error code when sending or receiving messages.
int error_code;
static std::mutex conn_mutex;
private:
// Add handler for socket connect event.
int AddConnnectEventHandler();
// Parse message from socket recv buffer.
bool ParseMessage();
// Make a http message based on given input message.
std::string GenerateHttpMessage(MessageBase *msg);
// Change the header body from network byte order to host byte order.
void ReorderHeader(MessageHeader *header);
std::string advertise_addr_;
};
} // namespace rpc
} // namespace distributed
} // namespace mindspore
#endif

View File

@ -0,0 +1,306 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <mutex>
#include "distributed/rpc/tcp/connection_pool.h"
namespace mindspore {
namespace distributed {
namespace rpc {
ConnectionPool *ConnectionPool::conn_pool = new ConnectionPool();
void ConnectionPool::SetLinkPattern(bool linkPattern) { double_link_ = linkPattern; }
void ConnectionPool::CloseConnection(Connection *conn) {
if (conn == nullptr) {
return;
}
// Trigger Exit message note that this should be called before erasing link. Because we may chang deleted flag
// by to in this fun. And if deleted has been set to true, it means Exit message has been send before, do nothing.
if (!conn->deleted) {
DeleteConnInfo(conn->destination, conn->socket_fd);
}
if (!conn->destination.empty()) {
if (conn->is_remote) {
remote_conns_.erase(conn->destination);
} else {
local_conns_.erase(conn->destination);
}
}
conn->Close();
delete conn;
conn = nullptr;
}
Connection *ConnectionPool::FindConnection(const std::string &to, bool remoteLink) {
Connection *conn = nullptr;
if (!remoteLink) {
auto iter = local_conns_.find(to);
if (iter != local_conns_.end()) {
conn = iter->second;
return conn;
}
}
auto iter = remote_conns_.find(to);
if (iter != remote_conns_.end()) {
conn = iter->second;
}
return conn;
}
Connection *ConnectionPool::ExactFindConnection(const std::string &to, bool remoteLink) {
Connection *conn = nullptr;
if (!remoteLink) {
auto iter = local_conns_.find(to);
if (iter != local_conns_.end()) {
conn = iter->second;
}
} else {
auto iter = remote_conns_.find(to);
if (iter != remote_conns_.end()) {
conn = iter->second;
}
}
return conn;
}
Connection *ConnectionPool::FindConnection(const std::string &to, bool remoteLink, bool exactNotRemote) {
if (exactNotRemote) {
return ExactFindConnection(to, false);
} else {
return FindConnection(to, remoteLink);
}
}
void ConnectionPool::ResetAllConnMetrics() {
for (const auto &iter : local_conns_) {
iter.second->send_metrics->Reset();
}
for (const auto &iter : remote_conns_) {
iter.second->send_metrics->Reset();
}
}
Connection *ConnectionPool::FindMaxConnection() {
Connection *conn = nullptr;
int count = 0;
for (const auto &iter : local_conns_) {
if (iter.second->send_metrics->accum_msg_count > count) {
count = iter.second->send_metrics->accum_msg_count;
conn = iter.second;
}
}
for (const auto &iter : remote_conns_) {
if (iter.second->send_metrics->accum_msg_count > count) {
count = iter.second->send_metrics->accum_msg_count;
conn = iter.second;
}
}
return conn;
}
Connection *ConnectionPool::FindFastConnection() {
Connection *conn = nullptr;
int size = 0;
for (const auto &iter : local_conns_) {
if (iter.second->send_metrics->max_msg_size > size) {
size = iter.second->send_metrics->max_msg_size;
conn = iter.second;
}
}
for (const auto &iter : remote_conns_) {
if (iter.second->send_metrics->max_msg_size > size) {
size = iter.second->send_metrics->max_msg_size;
conn = iter.second;
}
}
return conn;
}
void ConnectionPool::ExactDeleteConnection(const std::string &to, bool remoteLink) {
Connection *conn = ExactFindConnection(to, remoteLink);
if (conn != nullptr) {
MS_LOG(INFO) << "unLink fd:" << conn->socket_fd << ",to:" << to.c_str() << ",remote:" << remoteLink;
CloseConnection(conn);
}
}
void ConnectionPool::DeleteAllConnections(std::map<std::string, Connection *> *links) {
auto iter = links->begin();
while (iter != links->end()) {
Connection *conn = iter->second;
// erase link
if (conn->recv_message != nullptr) {
delete conn->recv_message;
}
iter = links->erase(iter);
delete conn;
conn = nullptr;
}
}
void ConnectionPool::AddConnection(Connection *conn) {
if (conn == nullptr) {
return;
}
Connection *tmpConn = ExactFindConnection(conn->destination, conn->is_remote);
if (tmpConn != nullptr && tmpConn->is_remote == conn->is_remote) {
MS_LOG(INFO) << "unLink fd:" << tmpConn->socket_fd << ",to:" << tmpConn->destination.c_str();
CloseConnection(tmpConn);
}
if (conn->is_remote) {
remote_conns_.emplace(conn->destination, conn);
} else {
local_conns_.emplace(conn->destination, conn);
}
}
void ConnectionPool::SetConnPriority(const std::string &to, bool remoteLink, ConnectionPriority pri) {
Connection *conn = ExactFindConnection(to, remoteLink);
if (conn != nullptr && conn->is_remote == remoteLink) {
conn->priority = pri;
}
}
void ConnectionPool::DeleteConnInfo(int fd) {
auto iter = conn_infos_.find(fd);
if (iter == conn_infos_.end()) {
return;
}
auto conn_infos = iter->second;
auto iter2 = conn_infos.begin();
while (iter2 != conn_infos.end()) {
auto linkInfo = *iter2;
if (linkInfo->delete_callback) {
linkInfo->delete_callback(linkInfo->to, linkInfo->from);
}
iter2 = conn_infos.erase(iter2);
delete linkInfo;
}
conn_infos_.erase(fd);
}
void ConnectionPool::DeleteConnInfo(const std::string &to, int fd) {
// If run in double link pattern, link fd and send fd must be the same, send Exit message bind on this fd
if (double_link_) {
DeleteConnInfo(fd);
return;
}
// If run in single link pattern, link fd and send fd may not be the same, we should send Exit message bind
// on link fd and remote link fd. Here 'deleted' flag should be set true to avoid duplicate Exit message with
// same aid.
Connection *nonRemoteConn = ConnectionPool::ExactFindConnection(to, false);
if (nonRemoteConn != nullptr) {
nonRemoteConn->deleted = true;
DeleteConnInfo(nonRemoteConn->socket_fd);
if (nonRemoteConn->socket_fd != fd) {
MS_LOG(INFO) << "delete linker bind on link fd:" << nonRemoteConn->socket_fd << ",delete fd:" << fd;
}
}
Connection *remoteConn = ConnectionPool::ExactFindConnection(to, true);
if (remoteConn != nullptr) {
remoteConn->deleted = true;
DeleteConnInfo(remoteConn->socket_fd);
if (remoteConn->socket_fd != fd) {
MS_LOG(INFO) << "delete linker bind on remote link fd:" << remoteConn->socket_fd << ",delete fd:" << fd;
}
}
}
void ConnectionPool::DeleteAllConnInfos() {
auto iter = conn_infos_.begin();
while (iter != conn_infos_.end()) {
auto conn_infos = iter->second;
auto iter2 = conn_infos.begin();
while (iter2 != conn_infos.end()) {
auto linkInfo = *iter2;
iter2 = conn_infos.erase(iter2);
delete linkInfo;
}
iter = conn_infos_.erase(iter);
}
}
ConnectionInfo *ConnectionPool::FindConnInfo(int fd, const AID &sAid, const AID &dAid) {
auto iter = conn_infos_.find(fd);
if (iter == conn_infos_.end()) {
return nullptr;
}
auto conn_infos = iter->second;
auto iter2 = conn_infos.begin();
while (iter2 != conn_infos.end()) {
auto linkInfo = *iter2;
if (AID(linkInfo->from) == sAid && AID(linkInfo->to) == dAid) {
return linkInfo;
}
++iter2;
}
return nullptr;
}
void ConnectionPool::AddConnInfo(int fd, const AID &sAid, const AID &dAid, DeleteCallBack callback) {
ConnectionInfo *linker = FindConnInfo(fd, sAid, dAid);
if (linker != nullptr) {
return;
}
linker = new (std::nothrow) ConnectionInfo();
if (linker == nullptr) {
MS_LOG(ERROR) << "new ConnectionInfo fail sAid:" << std::string(sAid).c_str()
<< ",dAid:" << std::string(dAid).c_str();
return;
}
linker->from = sAid;
linker->to = dAid;
linker->socket_fd = fd;
linker->delete_callback = callback;
conn_infos_[fd].insert(linker);
}
bool ConnectionPool::ReverseConnInfo(int fromFd, int toFd) {
auto iter = conn_infos_.find(fromFd);
if (iter == conn_infos_.end()) {
return false;
}
auto conn_infos = iter->second;
conn_infos_.erase(fromFd);
conn_infos_[toFd] = conn_infos;
return true;
}
ConnectionPool::~ConnectionPool() {
try {
DeleteAllConnections(&local_conns_);
DeleteAllConnections(&remote_conns_);
DeleteAllConnInfos();
} catch (...) {
MS_LOG(ERROR) << "Failed to release resource for connection pool.";
}
}
ConnectionPool *ConnectionPool::GetConnectionPool() { return ConnectionPool::conn_pool; }
} // namespace rpc
} // namespace distributed
} // namespace mindspore

View File

@ -0,0 +1,106 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_DISTRIBUTED_RPC_TCP_CONNECTION_POOL_H_
#define MINDSPORE_CCSRC_DISTRIBUTED_RPC_TCP_CONNECTION_POOL_H_
#include <map>
#include <set>
#include <string>
#include "distributed/rpc/tcp/constants.h"
#include "distributed/rpc/tcp/connection.h"
namespace mindspore {
namespace distributed {
namespace rpc {
struct ConnectionInfo {
int socket_fd;
std::string from;
std::string to;
DeleteCallBack delete_callback;
};
/*
* Maintains a collection of reusable connections.
*/
class ConnectionPool {
public:
ConnectionPool() : double_link_(false) {}
~ConnectionPool();
// Get the singleton instance of ConnectionPool.
static ConnectionPool *GetConnectionPool();
/*
* Operations for ConnectionInfo.
*/
void AddConnInfo(int socket_fd, const AID &sAid, const AID &dAid, DeleteCallBack delcb);
bool ReverseConnInfo(int from_socket_fd, int to_socket_fd);
/*
* Operations for Connection.
*/
// Add a connection.
void AddConnection(Connection *conn);
// Find connection.
Connection *ExactFindConnection(const std::string &to, bool remoteLink);
Connection *FindConnection(const std::string &to, bool remoteLink);
Connection *FindConnection(const std::string &to, bool remoteLink, bool exactNotRemote);
Connection *FindMaxConnection();
Connection *FindFastConnection();
// Delete connection.
void ExactDeleteConnection(const std::string &to, bool remoteLink);
void DeleteAllConnections(std::map<std::string, Connection *> *alllinks);
// Close connection.
void CloseConnection(Connection *conn);
void SetConnPriority(const std::string &to, bool remoteLink, ConnectionPriority pri);
// Single link or double link.
void SetLinkPattern(bool linkPattern);
void ResetAllConnMetrics();
private:
ConnectionInfo *FindConnInfo(int socket_fd, const AID &sAid, const AID &dAid);
void DeleteConnInfo(int socket_fd);
void DeleteConnInfo(const std::string &to, int socket_fd);
void DeleteAllConnInfos();
bool double_link_;
// to_url=tcp@ip:port, event struct
std::map<std::string, Connection *> local_conns_;
// Maintains the remote connections by remote server addresses.
std::map<std::string, Connection *> remote_conns_;
// each to_url has two fds at most, and each fd has multiple linkinfos
std::map<int, std::set<ConnectionInfo *>> conn_infos_;
static ConnectionPool *conn_pool;
friend class Connection;
friend class TCPComm;
};
} // namespace rpc
} // namespace distributed
} // namespace mindspore
#endif