Add tcp communicator for RPC

This commit is contained in:
cristoval 2021-12-12 18:09:04 +08:00
parent 4ac1ba29a1
commit bb5798c391
7 changed files with 1202 additions and 17 deletions

View File

@ -6,14 +6,15 @@ else()
list(REMOVE_ITEM _DISTRIBUTED_SRC_FILES "cluster/dummy_cluster_context.cc")
endif()
set(EXCLUDE_DIR "rpc/")
foreach(RPC_FILE ${_DISTRIBUTED_SRC_FILES})
string(FIND ${RPC_FILE} ${EXCLUDE_DIR} FOUND)
if(${FOUND} EQUAL 0)
list(REMOVE_ITEM _DISTRIBUTED_SRC_FILES ${RPC_FILE})
endif()
endforeach()
if(NOT ENABLE_CPU OR WIN32 OR APPLE)
set(EXCLUDE_DIR "rpc/")
foreach(RPC_FILE ${_DISTRIBUTED_SRC_FILES})
string(FIND ${RPC_FILE} ${EXCLUDE_DIR} FOUND)
if(${FOUND} EQUAL 0)
list(REMOVE_ITEM _DISTRIBUTED_SRC_FILES ${RPC_FILE})
endif()
endforeach()
endif()
set_property(SOURCE ${_DISTRIBUTED_SRC_FILES} PROPERTY COMPILE_DEFINITIONS
SUBMODULE_ID=mindspore::SubModuleId::SM_DISTRIBUTED)

View File

@ -0,0 +1,771 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "distributed/rpc/tcp/tcp_comm.h"
#include <mutex>
#include <utility>
#include <memory>
#include "actor/aid.h"
#include "actor/msg.h"
#include "distributed/rpc/tcp/constants.h"
#include "distributed/rpc/tcp/tcp_socket_operation.h"
#include "distributed/rpc/tcp/connection_pool.h"
namespace mindspore {
namespace distributed {
namespace rpc {
bool TCPComm::is_http_msg_ = false;
std::vector<char> TCPComm::advertise_url_;
uint64_t TCPComm::output_buf_size_ = 0;
IOMgr::MessageHandler TCPComm::message_handler_;
int DoConnect(const std::string &to, Connection *conn, ConnectionCallBack event_callback,
ConnectionCallBack write_callback, ConnectionCallBack read_callback) {
SocketAddress addr;
if (!SocketOperation::GetSockAddr(to, &addr)) {
return -1;
}
int sock_fd = SocketOperation::CreateServerSocket(addr.sa.sa_family);
if (sock_fd < 0) {
return -1;
}
conn->socket_fd = sock_fd;
conn->event_callback = event_callback;
conn->write_callback = write_callback;
conn->read_callback = read_callback;
int ret = TCPComm::Connect(conn, (struct sockaddr *)&addr, sizeof(addr));
if (ret < 0) {
close(sock_fd);
conn->socket_fd = -1;
return -1;
}
return 0;
}
void DoDisconnect(int fd, Connection *conn, uint32_t error, int soError) {
if (LOG_CHECK_EVERY_N()) {
MS_LOG(INFO) << "Failed to call connect, fd: " << fd << ", to: " << conn->destination.c_str()
<< ", events: " << error << ", errno: " << soError;
}
conn->state = ConnectionState::kDisconnecting;
conn->error_code = soError;
conn->event_callback(conn);
return;
}
void ConnectedEventHandler(int fd, uint32_t events, void *context) {
uint32_t error = events & (EPOLLERR | EPOLLHUP | EPOLLRDHUP);
int soError = 0;
Connection *conn = reinterpret_cast<Connection *>(context);
conn->socket_operation->ConnEstablishedEventHandler(fd, events, context);
if (conn->state == ConnectionState::kDisconnecting) {
DoDisconnect(fd, conn, error, soError);
return;
} else if (conn->state != ConnectionState::kConnected) {
return;
}
if (!conn->ReconnectSourceSocket(fd, events, &soError, error)) {
DoDisconnect(fd, conn, error, soError);
return;
}
if (conn->write_callback) {
conn->write_callback(conn);
}
return;
}
void OnAccept(int server, uint32_t events, void *arg) {
if (events & (EPOLLHUP | EPOLLERR)) {
MS_LOG(ERROR) << "Invalid error event, server fd: " << server << ", events: " << events;
return;
}
TCPComm *tcpmgr = reinterpret_cast<TCPComm *>(arg);
if (tcpmgr->recv_event_loop_ == nullptr) {
MS_LOG(ERROR) << "EventLoop is null, server fd: " << server << ", events: " << events;
return;
}
// accept connection
auto acceptFd = SocketOperation::Accept(server);
if (acceptFd < 0) {
MS_LOG(ERROR) << "Failed to call accept, server fd: " << server << ", events: " << events;
return;
}
Connection *conn = new (std::nothrow) Connection();
if (conn == nullptr) {
MS_LOG(ERROR) << "Failed to create new connection, server fd:" << server << ", events: " << events
<< ", accept fd: " << acceptFd;
close(acceptFd);
acceptFd = -1;
return;
}
// init metrics
conn->send_metrics = new (std::nothrow) SendMetrics();
if (conn->send_metrics == nullptr) {
MS_LOG(ERROR) << "Failed to create connection metrics, server fd: " << server << ", events: " << events
<< ", accept fd: " << acceptFd;
close(acceptFd);
acceptFd = -1;
delete conn;
return;
}
conn->socket_fd = acceptFd;
conn->source = TCPComm::advertise_url_.data();
conn->peer = SocketOperation::GetPeer(acceptFd);
conn->is_remote = true;
conn->recv_event_loop = tcpmgr->recv_event_loop_;
conn->send_event_loop = tcpmgr->send_event_loop_;
conn->event_callback = TCPComm::EventCallBack;
conn->write_callback = TCPComm::WriteCallBack;
conn->read_callback = TCPComm::ReadCallBack;
int retval = conn->Initialize();
if (retval != RPC_OK) {
MS_LOG(ERROR) << "Failed to add accept fd event, server fd: " << server << ", events: " << events
<< ", accept fd: " << acceptFd;
close(acceptFd);
acceptFd = -1;
delete conn->send_metrics;
delete conn;
return;
}
}
void DoSend(Connection *conn) {
while (!conn->send_message_queue.empty() || conn->total_send_len != 0) {
if (conn->total_send_len == 0) {
conn->FillSendMessage(conn->send_message_queue.front(), TCPComm::advertise_url_.data(), TCPComm::IsHttpMsg());
conn->send_message_queue.pop();
}
int sendLen = conn->socket_operation->SendMessage(conn, &conn->send_kernel_msg, &conn->total_send_len);
if (sendLen > 0) {
if (conn->total_send_len == 0) {
// update metrics
conn->send_metrics->UpdateError(false);
TCPComm::output_buf_size_ -= conn->send_message->body.size();
conn->output_buffer_size -= conn->send_message->body.size();
delete conn->send_message;
conn->send_message = nullptr;
}
} else if (sendLen == 0) {
// EAGAIN
(void)conn->recv_event_loop->UpdateEpollEvent(conn->socket_fd, EPOLLOUT | EPOLLIN | EPOLLHUP | EPOLLERR);
break;
} else {
// update metrics
conn->send_metrics->UpdateError(true, conn->error_code);
conn->state = ConnectionState::kDisconnecting;
break;
}
}
}
TCPComm::~TCPComm() {
try {
Finalize();
} catch (...) {
MS_LOG(ERROR) << "Failed to finalize tcp communicator.";
}
}
void TCPComm::SendExitMsg(const std::string &from, const std::string &to) {
if (message_handler_ != nullptr) {
std::unique_ptr<MessageBase> exit_msg = std::make_unique<MessageBase>(MessageBase::Type::KEXIT);
RPC_OOM_EXIT(exit_msg);
exit_msg->SetFrom(AID(from));
exit_msg->SetTo(AID(to));
message_handler_(std::move(exit_msg));
}
}
void TCPComm::SetMessageHandler(IOMgr::MessageHandler handler) { message_handler_ = handler; }
bool TCPComm::Initialize() {
if (ConnectionPool::GetConnectionPool() == nullptr) {
MS_LOG(ERROR) << "Failed to create connection pool.";
return false;
}
recv_event_loop_ = new (std::nothrow) EventLoop();
if (recv_event_loop_ == nullptr) {
MS_LOG(ERROR) << "Failed to create recv evLoop.";
return false;
}
bool ok = recv_event_loop_->Initialize(TCP_RECV_EVLOOP_THREADNAME);
if (!ok) {
MS_LOG(ERROR) << "Failed to init recv evLoop";
delete recv_event_loop_;
recv_event_loop_ = nullptr;
return false;
}
send_event_loop_ = new (std::nothrow) EventLoop();
if (send_event_loop_ == nullptr) {
MS_LOG(ERROR) << "Failed to create send evLoop.";
delete recv_event_loop_;
recv_event_loop_ = nullptr;
return false;
}
ok = send_event_loop_->Initialize(TCP_SEND_EVLOOP_THREADNAME);
if (!ok) {
MS_LOG(ERROR) << "Failed to init send evLoop";
delete recv_event_loop_;
recv_event_loop_ = nullptr;
delete send_event_loop_;
send_event_loop_ = nullptr;
return false;
}
if (g_httpKmsgEnable < 0) {
char *httpKmsgEnv = getenv("LITERPC_HTTPKMSG_ENABLED");
if (httpKmsgEnv != nullptr) {
if (std::string(httpKmsgEnv) == "true" || std::string(httpKmsgEnv) == "1") {
is_http_msg_ = true;
}
}
} else {
is_http_msg_ = (g_httpKmsgEnable == 0) ? false : true;
}
ConnectionPool::GetConnectionPool()->SetLinkPattern(is_http_msg_);
return true;
}
bool TCPComm::StartServerSocket(const std::string &url, const std::string &aAdvertiseUrl) {
server_fd_ = SocketOperation::Listen(url);
if (server_fd_ < 0) {
MS_LOG(ERROR) << "Failed to call socket listen, url: " << url.c_str()
<< ", advertise_url_: " << advertise_url_.data();
return false;
}
url_ = url;
std::string tmp_url;
if (aAdvertiseUrl.size() > 0) {
advertise_url_.resize(aAdvertiseUrl.size());
advertise_url_.assign(aAdvertiseUrl.begin(), aAdvertiseUrl.end());
tmp_url = aAdvertiseUrl;
} else {
advertise_url_.resize(url_.size());
advertise_url_.assign(url_.begin(), url_.end());
tmp_url = url_;
}
size_t index = url.find(URL_PROTOCOL_IP_SEPARATOR);
if (index != std::string::npos) {
url_ = url.substr(index + sizeof(URL_PROTOCOL_IP_SEPARATOR) - 1);
}
index = tmp_url.find(URL_PROTOCOL_IP_SEPARATOR);
if (index != std::string::npos) {
tmp_url = tmp_url.substr(index + sizeof(URL_PROTOCOL_IP_SEPARATOR) - 1);
advertise_url_.resize(tmp_url.size());
advertise_url_.assign(tmp_url.begin(), tmp_url.end());
}
// Register read event callback for server socket
int retval = recv_event_loop_->SetEventHandler(server_fd_, EPOLLIN | EPOLLHUP | EPOLLERR, OnAccept,
reinterpret_cast<void *>(this));
if (retval != RPC_OK) {
MS_LOG(ERROR) << "Failed to add server event, url: " << url.c_str()
<< ", advertise_url_: " << advertise_url_.data();
return false;
}
MS_LOG(INFO) << "Start server succ, fd: " << server_fd_ << ", url: " << url.c_str()
<< ", advertise_url_ :" << advertise_url_.data();
return true;
}
void TCPComm::ReadCallBack(void *context) {
const int max_recv_count = 3;
Connection *conn = reinterpret_cast<Connection *>(context);
int count = 0;
int retval = 0;
do {
retval = ReceiveMessage(conn);
++count;
} while (retval > 0 && count < max_recv_count);
return;
}
void TCPComm::EventCallBack(void *context) {
Connection *conn = reinterpret_cast<Connection *>(context);
if (conn->state == ConnectionState::kConnected) {
Connection::conn_mutex.lock();
DoSend(conn);
Connection::conn_mutex.unlock();
} else if (conn->state == ConnectionState::kDisconnecting) {
Connection::conn_mutex.lock();
output_buf_size_ -= conn->output_buffer_size;
ConnectionPool::GetConnectionPool()->CloseConnection(conn);
Connection::conn_mutex.unlock();
}
}
void TCPComm::WriteCallBack(void *context) {
Connection *conn = reinterpret_cast<Connection *>(context);
if (conn->state == ConnectionState::kConnected) {
Connection::conn_mutex.lock();
DoSend(conn);
Connection::conn_mutex.unlock();
}
}
int TCPComm::ReceiveMessage(Connection *conn) {
conn->CheckMessageType();
switch (conn->recv_message_type) {
case ParseType::kTcpMsg:
return conn->ReceiveMessage(message_handler_);
#ifdef HTTP_ENABLED
case ParseType::KHTTP_REQ:
if (httpReqCb) {
return httpReqCb(conn, message_handler_);
} else {
conn->state = ConnectionState::kDisconnecting;
return -1;
}
case ParseType::KHTTP_RSP:
if (httpRspCb) {
return httpRspCb(conn, message_handler_);
} else {
conn->state = ConnectionState::kDisconnecting;
return -1;
}
#endif
default:
return 0;
}
}
int TCPComm::SetConnectedHandler(Connection *conn) {
/* add to epoll */
return conn->recv_event_loop->SetEventHandler(conn->socket_fd,
(uint32_t)(EPOLLOUT | EPOLLHUP | EPOLLRDHUP | EPOLLERR),
ConnectedEventHandler, reinterpret_cast<void *>(conn));
}
int TCPComm::Connect(Connection *conn, const struct sockaddr *sa, socklen_t saLen) {
int retval = 0;
uint16_t localPort = -1;
retval = SocketOperation::Connect(conn->socket_fd, sa, saLen, &localPort);
if (retval != RPC_OK) {
return RPC_ERROR;
}
// Init connection metrics.
if (conn->send_metrics == nullptr) {
conn->send_metrics = new (std::nothrow) SendMetrics();
if (conn->send_metrics == nullptr) {
return RPC_ERROR;
}
}
// Add the socket of this connection to epoll.
retval = SetConnectedHandler(conn);
if (retval != RPC_OK) {
if (conn->send_metrics != nullptr) {
delete conn->send_metrics;
conn->send_metrics = nullptr;
}
return RPC_ERROR;
}
return RPC_OK;
}
void TCPComm::Send(MessageBase *msg, const TCPComm *tcpmgr, bool remoteLink, bool isExactNotRemote) {
std::lock_guard<std::mutex> lock(Connection::conn_mutex);
Connection *conn = ConnectionPool::GetConnectionPool()->FindConnection(msg->to.Url(), remoteLink, isExactNotRemote);
// Create a new connection if the connection to target of the message does not existed.
if (conn == nullptr) {
if (remoteLink && (!isExactNotRemote)) {
MS_LOG(ERROR) << "Could not found remote link and send fail name: " << msg->name.c_str()
<< ", from: " << advertise_url_.data() << ", to: " << msg->to.Url().c_str();
delete msg;
return;
}
conn = new (std::nothrow) Connection();
if (conn == nullptr) {
MS_LOG(ERROR) << "Failed to create new connection and send fail name: " << msg->name.c_str()
<< ", from: " << advertise_url_.data() << ", to: " << msg->to.Url().c_str();
delete msg;
return;
}
conn->source = advertise_url_.data();
conn->destination = msg->to.Url();
conn->recv_event_loop = tcpmgr->recv_event_loop_;
conn->send_event_loop = tcpmgr->send_event_loop_;
conn->InitSocketOperation();
int ret = DoConnect(msg->to.Url(), conn, TCPComm::EventCallBack, TCPComm::WriteCallBack, TCPComm::ReadCallBack);
if (ret < 0) {
MS_LOG(ERROR) << "Failed to do connect and send fail name: " << msg->name.c_str()
<< ", from: " << advertise_url_.data() << ", to: " << msg->to.Url().c_str();
if (conn->socket_operation != nullptr) {
delete conn->socket_operation;
conn->socket_operation = nullptr;
}
delete conn;
delete msg;
return;
}
ConnectionPool::GetConnectionPool()->AddConnection(conn);
}
if (!conn->is_remote && !isExactNotRemote && conn->priority == ConnectionPriority::kPriorityLow) {
Connection *remoteConn = ConnectionPool::GetConnectionPool()->ExactFindConnection(msg->to.Url(), true);
if (remoteConn != nullptr && remoteConn->state == ConnectionState::kConnected) {
conn = remoteConn;
}
}
// Prepare the message.
if (conn->total_send_len == 0) {
conn->FillSendMessage(msg, advertise_url_.data(), is_http_msg_);
} else {
conn->send_message_queue.emplace(msg);
}
// Send the message.
if (conn->state == ConnectionState::kConnected) {
DoSend(conn);
}
}
void TCPComm::SendByRecvLoop(MessageBase *msg, const TCPComm *tcpmgr, bool remoteLink, bool isExactNotRemote) {
recv_event_loop_->AddTask(
[msg, tcpmgr, remoteLink, isExactNotRemote] { TCPComm::Send(msg, tcpmgr, remoteLink, isExactNotRemote); });
}
int TCPComm::Send(MessageBase *msg, bool remoteLink, bool isExactNotRemote) {
return send_event_loop_->AddTask([msg, this, remoteLink, isExactNotRemote] {
std::lock_guard<std::mutex> lock(Connection::conn_mutex);
// Search connection by the target address
bool exactNotRemote = is_http_msg_ || isExactNotRemote;
Connection *conn = ConnectionPool::GetConnectionPool()->FindConnection(msg->to.Url(), remoteLink, exactNotRemote);
if (conn == nullptr) {
if (remoteLink && (!exactNotRemote)) {
MS_LOG(ERROR) << "Can not found remote link and send fail name: " << msg->name.c_str()
<< ", from: " << advertise_url_.data() << ", to: " << msg->to.Url().c_str();
auto *ptr = msg;
delete ptr;
ptr = nullptr;
return;
}
this->SendByRecvLoop(msg, this, remoteLink, exactNotRemote);
return;
}
if (conn->state != kConnected && conn->send_message_queue.size() >= SENDMSG_QUEUELEN) {
MS_LOG(WARNING) << "The name of dropped message is: " << msg->name.c_str() << ", fd: " << conn->socket_fd
<< ", to: " << conn->destination.c_str() << ", remote: " << conn->is_remote;
auto *ptr = msg;
delete ptr;
ptr = nullptr;
return;
}
if (conn->state == ConnectionState::kClose || conn->state == ConnectionState::kDisconnecting) {
this->SendByRecvLoop(msg, this, remoteLink, exactNotRemote);
return;
}
if (!conn->is_remote && !exactNotRemote && conn->priority == ConnectionPriority::kPriorityLow) {
Connection *remoteConn = ConnectionPool::GetConnectionPool()->ExactFindConnection(msg->to.Url(), true);
if (remoteConn != nullptr && remoteConn->state == ConnectionState::kConnected) {
conn = remoteConn;
}
}
output_buf_size_ += msg->body.size();
if (conn->total_send_len == 0) {
conn->FillSendMessage(msg, advertise_url_.data(), is_http_msg_);
} else {
conn->send_message_queue.emplace(msg);
}
if (conn->state == ConnectionState::kConnected) {
DoSend(conn);
}
});
}
void TCPComm::CollectMetrics() {
send_event_loop_->AddTask([this] {
Connection::conn_mutex.lock();
Connection *maxConn = ConnectionPool::GetConnectionPool()->FindMaxConnection();
Connection *fastConn = ConnectionPool::GetConnectionPool()->FindFastConnection();
if (message_handler_ != nullptr) {
IntTypeMetrics intMetrics;
StringTypeMetrics stringMetrics;
if (maxConn != nullptr) {
intMetrics.push(maxConn->socket_fd);
intMetrics.push(maxConn->error_code);
intMetrics.push(maxConn->send_metrics->accum_msg_count);
intMetrics.push(maxConn->send_metrics->max_msg_size);
stringMetrics.push(maxConn->destination);
stringMetrics.push(maxConn->send_metrics->last_succ_msg_name);
stringMetrics.push(maxConn->send_metrics->last_fail_msg_name);
}
if (fastConn != nullptr && fastConn->IsSame(maxConn)) {
intMetrics.push(fastConn->socket_fd);
intMetrics.push(fastConn->error_code);
intMetrics.push(fastConn->send_metrics->accum_msg_count);
intMetrics.push(fastConn->send_metrics->max_msg_size);
stringMetrics.push(fastConn->destination);
stringMetrics.push(fastConn->send_metrics->last_succ_msg_name);
stringMetrics.push(fastConn->send_metrics->last_fail_msg_name);
}
}
ConnectionPool::GetConnectionPool()->ResetAllConnMetrics();
Connection::conn_mutex.unlock();
});
}
int TCPComm::Send(std::unique_ptr<MessageBase> &&msg, bool remoteLink, bool isExactNotRemote) {
return Send(msg.release(), remoteLink, isExactNotRemote);
}
void TCPComm::Link(const AID &source, const AID &destination) {
recv_event_loop_->AddTask([source, destination, this] {
std::string to = destination.Url();
std::lock_guard<std::mutex> lock(Connection::conn_mutex);
// Search connection by the target address
Connection *conn = ConnectionPool::GetConnectionPool()->FindConnection(to, false, is_http_msg_);
if (conn == nullptr) {
MS_LOG(INFO) << "Can not found link source: " << std::string(source).c_str()
<< ", destination: " << std::string(destination).c_str();
conn = new (std::nothrow) Connection();
if (conn == nullptr) {
MS_LOG(ERROR) << "Failed to create new connection and link fail source: " << std::string(source).c_str()
<< ", destination: " << std::string(destination).c_str();
SendExitMsg(source, destination);
return;
}
conn->source = advertise_url_.data();
conn->destination = to;
conn->recv_event_loop = this->recv_event_loop_;
conn->send_event_loop = this->send_event_loop_;
conn->InitSocketOperation();
int ret = DoConnect(to, conn, TCPComm::EventCallBack, TCPComm::WriteCallBack, TCPComm::ReadCallBack);
if (ret < 0) {
MS_LOG(ERROR) << "Failed to do connect and link fail source: " << std::string(source).c_str()
<< ", destination: " << std::string(destination).c_str();
SendExitMsg(source, destination);
if (conn->socket_operation != nullptr) {
delete conn->socket_operation;
conn->socket_operation = nullptr;
}
delete conn;
return;
}
ConnectionPool::GetConnectionPool()->AddConnection(conn);
}
ConnectionPool::GetConnectionPool()->AddConnInfo(conn->socket_fd, source, destination, SendExitMsg);
MS_LOG(INFO) << "Link fd: " << conn->socket_fd << ", source: " << std::string(source).c_str()
<< ", destination: " << std::string(destination).c_str() << ", remote: " << conn->is_remote;
});
}
void TCPComm::UnLink(const AID &destination) {
recv_event_loop_->AddTask([destination] {
std::string to = destination.Url();
std::lock_guard<std::mutex> lock(Connection::conn_mutex);
if (is_http_msg_) {
// When application has set 'LITERPC_HTTPKMSG_ENABLED',it means sending-link is in links map
// while accepting-link is differently in remoteLinks map. So we only need to delete link in exact links.
ConnectionPool::GetConnectionPool()->ExactDeleteConnection(to, false);
} else {
// When application hasn't set 'LITERPC_HTTPKMSG_ENABLED',it means sending-link and accepting-link is
// shared
// So we need to delete link in both links map and remote-links map.
ConnectionPool::GetConnectionPool()->ExactDeleteConnection(to, false);
ConnectionPool::GetConnectionPool()->ExactDeleteConnection(to, true);
}
});
}
void TCPComm::DoReConnectConn(Connection *conn, std::string to, const AID &source, const AID &destination, int *oldFd) {
if (!is_http_msg_ && !conn->is_remote) {
Connection *remoteConn = ConnectionPool::GetConnectionPool()->ExactFindConnection(to, true);
// We will close remote link in rare cases where sending-link and accepting link coexists
// simultaneously.
if (remoteConn != nullptr) {
MS_LOG(INFO) << "Reconnect, close remote connect fd :" << remoteConn->socket_fd
<< ", source: " << std::string(source).c_str()
<< ", destination: " << std::string(destination).c_str() << ", remote: " << remoteConn->is_remote
<< ", state: " << remoteConn->state;
ConnectionPool::GetConnectionPool()->CloseConnection(remoteConn);
}
}
MS_LOG(INFO) << "Reconnect, close old connect fd: " << conn->socket_fd << ", source: " << std::string(source).c_str()
<< ", destination: " << std::string(destination).c_str() << ", remote: " << conn->is_remote
<< ", state: " << conn->state;
*oldFd = conn->socket_fd;
conn->recv_event_loop->DeleteEpollEvent(conn->socket_fd);
conn->socket_operation->Close(conn);
conn->socket_fd = -1;
conn->recv_len = 0;
conn->total_recv_len = 0;
conn->recv_message_type = kUnknown;
conn->state = kInit;
if (conn->total_send_len != 0 && conn->send_message != nullptr) {
delete conn->send_message;
}
conn->send_message = nullptr;
conn->total_send_len = 0;
if (conn->total_recv_len != 0 && conn->recv_message != nullptr) {
delete conn->recv_message;
}
conn->recv_message = nullptr;
conn->total_recv_len = 0;
conn->recv_state = State::kMsgHeader;
}
Connection *TCPComm::CreateDefaultConn(std::string to) {
Connection *conn = new (std::nothrow) Connection();
if (conn == nullptr) {
MS_LOG(ERROR) << "Failed to create new connection and reconnect fail to: " << to.c_str();
return conn;
}
conn->source = advertise_url_.data();
conn->destination = to;
conn->recv_event_loop = this->recv_event_loop_;
conn->send_event_loop = this->send_event_loop_;
conn->InitSocketOperation();
return conn;
}
void TCPComm::Reconnect(const AID &source, const AID &destination) {
send_event_loop_->AddTask([source, destination, this] {
std::string to = destination.Url();
std::lock_guard<std::mutex> lock(Connection::conn_mutex);
Connection *conn = ConnectionPool::GetConnectionPool()->FindConnection(to, false, is_http_msg_);
if (conn != nullptr) {
conn->state = ConnectionState::kClose;
}
recv_event_loop_->AddTask([source, destination, this] {
std::string to = destination.Url();
int oldFd = -1;
std::lock_guard<std::mutex> lock(Connection::conn_mutex);
Connection *conn = ConnectionPool::GetConnectionPool()->FindConnection(to, false, is_http_msg_);
if (conn != nullptr) {
// connection already exist
DoReConnectConn(conn, to, source, destination, &oldFd);
} else {
// create default connection
conn = CreateDefaultConn(to);
if (conn == nullptr) {
return;
}
}
int ret = DoConnect(to, conn, TCPComm::EventCallBack, TCPComm::WriteCallBack, TCPComm::ReadCallBack);
if (ret < 0) {
if (conn->socket_operation != nullptr) {
delete conn->socket_operation;
conn->socket_operation = nullptr;
}
if (oldFd != -1) {
conn->socket_fd = oldFd;
}
MS_LOG(ERROR) << "Failed to connect and reconnect fail source: " << std::string(source).c_str()
<< ", destination: " << std::string(destination).c_str();
ConnectionPool::GetConnectionPool()->CloseConnection(conn);
return;
}
if (oldFd != -1) {
if (!ConnectionPool::GetConnectionPool()->ReverseConnInfo(oldFd, conn->socket_fd)) {
MS_LOG(ERROR) << "Failed to swap socket for " << oldFd << " and " << conn->socket_fd;
}
} else {
ConnectionPool::GetConnectionPool()->AddConnection(conn);
}
ConnectionPool::GetConnectionPool()->AddConnInfo(conn->socket_fd, source, destination, SendExitMsg);
MS_LOG(INFO) << "Reconnect fd: " << conn->socket_fd << ", source: " << std::string(source).c_str()
<< ", destination: " << std::string(destination).c_str();
});
});
}
void TCPComm::Finalize() {
if (send_event_loop_ != nullptr) {
MS_LOG(INFO) << "Delete send event loop";
send_event_loop_->Finalize();
delete send_event_loop_;
send_event_loop_ = nullptr;
}
if (recv_event_loop_ != nullptr) {
MS_LOG(INFO) << "Delete recv event loop";
recv_event_loop_->Finalize();
delete recv_event_loop_;
recv_event_loop_ = nullptr;
}
if (server_fd_ > 0) {
close(server_fd_);
server_fd_ = -1;
}
}
// This value is not used for the sub-class TCPComm.
uint64_t TCPComm::GetInBufSize() { return 1; }
uint64_t TCPComm::GetOutBufSize() { return output_buf_size_; }
bool TCPComm::IsHttpMsg() { return is_http_msg_; }
} // namespace rpc
} // namespace distributed
} // namespace mindspore

View File

@ -0,0 +1,130 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_DISTRIBUTED_RPC_TCP_TCP_COMM_H_
#define MINDSPORE_CCSRC_DISTRIBUTED_RPC_TCP_TCP_COMM_H_
#include <string>
#include <memory>
#include <vector>
#include "actor/iomgr.h"
#include "distributed/rpc/tcp/connection.h"
#include "distributed/rpc/tcp/event_loop.h"
namespace mindspore {
namespace distributed {
namespace rpc {
// Event handler for new connecting request arrived.
void OnAccept(int server, uint32_t events, void *arg);
// Send messages buffered in the connection.
void DoSend(Connection *conn);
// Create a server socket and connect to it, this is a local connection..
int DoConnect(const std::string &to, Connection *conn, ConnectionCallBack eventCallBack,
ConnectionCallBack writeCallBack, ConnectionCallBack readCallBack);
void DoDisconnect(int fd, Connection *conn, uint32_t error, int soError);
void ConnectedEventHandler(int fd, uint32_t events, void *context);
class TCPComm : public IOMgr {
public:
TCPComm() : server_fd_(-1), recv_event_loop_(nullptr), send_event_loop_(nullptr) {}
TCPComm(const TCPComm &) = delete;
TCPComm &operator=(const TCPComm &) = delete;
~TCPComm();
// Init the event loop for reading and writing.
bool Initialize() override;
// Destroy all the resources.
void Finalize() override;
// Create the server socket represented by url.
bool StartServerSocket(const std::string &url, const std::string &aAdvertiseUrl) override;
// Build a connection between the source and destination.
void Link(const AID &source, const AID &destination) override;
void UnLink(const AID &destination) override;
// Send the message from the source to the destination.
int Send(std::unique_ptr<MessageBase> &&msg, bool remoteLink = false, bool isExactNotRemote = false) override;
uint64_t GetInBufSize() override;
uint64_t GetOutBufSize() override;
void CollectMetrics() override;
private:
// Build the connection.
Connection *CreateDefaultConn(std::string to);
void Reconnect(const AID &source, const AID &destination);
void DoReConnectConn(Connection *conn, std::string to, const AID &source, const AID &destination, int *oldFd);
// Send a message.
int Send(MessageBase *msg, bool remoteLink = false, bool isExactNotRemote = false);
static void Send(MessageBase *msg, const TCPComm *tcpmgr, bool remoteLink, bool isExactNotRemote);
void SendByRecvLoop(MessageBase *msg, const TCPComm *tcpmgr, bool remoteLink, bool isExactNotRemote);
static void SendExitMsg(const std::string &from, const std::string &to);
// Called by ReadCallBack when new message arrived.
static int ReceiveMessage(Connection *conn);
void SetMessageHandler(IOMgr::MessageHandler handler);
static int SetConnectedHandler(Connection *conn);
static int Connect(Connection *conn, const struct sockaddr *sa, socklen_t saLen);
static bool IsHttpMsg();
// Read and write events.
static void ReadCallBack(void *context);
static void WriteCallBack(void *context);
// Connected and Disconnected events.
static void EventCallBack(void *context);
// The server url.
std::string url_;
// The socket of server.
int server_fd_;
// The message size waiting to be sent.
static uint64_t output_buf_size_;
// User defined handler for Handling received messages.
static MessageHandler message_handler_;
// The source url of a message.
static std::vector<char> advertise_url_;
static bool is_http_msg_;
// All the connections share the same read and write event loop objects.
EventLoop *recv_event_loop_;
EventLoop *send_event_loop_;
friend void OnAccept(int server, uint32_t events, void *arg);
friend void DoSend(Connection *conn);
friend int DoConnect(const std::string &to, Connection *conn, ConnectionCallBack event_callback,
ConnectionCallBack write_callback, ConnectionCallBack read_callback);
};
} // namespace rpc
} // namespace distributed
} // namespace mindspore
#endif

View File

@ -153,7 +153,7 @@ void ActorMgr::Finalize() {
// stop iomgr thread
for (auto mgrIt = ioMgrs.begin(); mgrIt != ioMgrs.end(); ++mgrIt) {
MS_LOG(INFO) << "finalize IOMgr=" << mgrIt->first.c_str();
mgrIt->second->Finish();
mgrIt->second->Finalize();
}
// delete actor thread pool if use_inner_pool

View File

@ -56,7 +56,7 @@ static const char HTTP_CLIENT_EVLOOP_THREADNAME[] = "MINDRT_Htp";
class IOMgr {
public:
using MsgHandler = void (*)(std::unique_ptr<MessageBase> &&msg);
using MessageHandler = void (*)(std::unique_ptr<MessageBase> &&msg);
/**
* remoteLink and isExactNotRemote are flags to tell us which link should be used. There are several cases:
* 1. remoteLink is false and isExactNotRemote is false : callers can reuse remote link when threr are no links
@ -71,16 +71,15 @@ class IOMgr {
// close the socket,and send exitedEvent to all linkers.
virtual void UnLink(const AID &dAid) = 0;
virtual void Reconnect(const AID &sAid, const AID &dAid) = 0;
virtual void RegisterMsgHandle(MsgHandler handle) = 0;
virtual bool Init() = 0; // once
virtual void Finish() = 0; // once
virtual bool StartIOServer(const std::string &url, const std::string &advertiseUrl) = 0; // multicalledable
virtual void SetMessageHandler(MessageHandler handle) = 0;
virtual bool Initialize() = 0; // once
virtual void Finalize() = 0; // once
virtual bool StartServerSocket(const std::string &url, const std::string &advertiseUrl) = 0; // multicalledable
virtual uint64_t GetOutBufSize() = 0;
virtual uint64_t GetInBufSize() = 0;
virtual void CollectMetrics() = 0;
virtual int AddRuleUdp(std::string peer, int recordNum) = 0;
virtual void DelRuleUdp(std::string peer, bool outputLog) = 0;
virtual void LinkRecycleCheck(int recyclePeroid) = 0;
virtual int AddRuleUdp(std::string peer, int recordNum) { return 0; }
virtual void DelRuleUdp(std::string peer, bool outputLog) { return; }
IOMgr() {}
virtual ~IOMgr() {}
};

View File

@ -70,6 +70,7 @@ if(ENABLE_MINDDATA)
./ps/*.cc
./fl/*.cc
./distributed/persistent/*.cc
./distributed/rpc/tcp/*.cc
./cxx_api/*.cc
./tbe/*.cc
./mindapi/*.cc
@ -175,6 +176,7 @@ file(GLOB_RECURSE MINDSPORE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
"../../../mindspore/ccsrc/ps/*.cc"
"../../../mindspore/ccsrc/fl/*.cc"
"../../../mindspore/ccsrc/distributed/persistent/*.cc"
"../../../mindspore/ccsrc/distributed/rpc/tcp/*.cc"
"../../../mindspore/ccsrc/profiler/device/ascend/*.cc"
"../../../mindspore/ccsrc/profiler/device/profiling.cc"
"../../../mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/adam_fp32.c"

View File

@ -0,0 +1,282 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <sys/resource.h>
#include <sys/types.h>
#include <dirent.h>
#include <atomic>
#include <string>
#include <thread>
#include <csignal>
#include <gtest/gtest.h>
#define private public
#include "actor/iomgr.h"
#include "async/async.h"
#include "distributed/rpc/tcp/tcp_comm.h"
#include "common/common_test.h"
namespace mindspore {
namespace distributed {
namespace rpc {
int g_recv_num = 0;
int g_exit_msg_num = 0;
TCPComm *m_io = nullptr;
std::atomic<int> m_sendNum(0);
std::string m_localIP = "127.0.0.1";
bool m_notRemote = false;
void msgHandle(std::unique_ptr<MessageBase> &&msg) {
if (msg->GetType() == MessageBase::Type::KEXIT) {
g_exit_msg_num++;
} else {
g_recv_num++;
}
}
class TCPTest : public UT::Common {
public:
static void SendMsg(std::string &_localUrl, std::string &_remoteUrl, int msgsize, bool remoteLink = false,
std::string body = "");
protected:
char *args[4];
char *testServerPath;
static const size_t pid_num = 100;
pid_t pid1;
pid_t pid2;
pid_t pids[pid_num];
void SetUp() {
char *localpEnv = getenv("LITEBUS_IP");
if (localpEnv != nullptr) {
m_localIP = std::string(localpEnv);
}
char *locaNotRemoteEnv = getenv("LITEBUS_SEND_ON_REMOTE");
if (locaNotRemoteEnv != nullptr) {
m_notRemote = (std::string(locaNotRemoteEnv) == "true") ? true : false;
}
pid1 = 0;
pid2 = 0;
pids[pid_num] = {0};
size_t size = pid_num * sizeof(pid_t);
if (memset_s(&pids, size, 0, size)) {
MS_LOG(ERROR) << "Failed to init pid array";
}
g_recv_num = 0;
g_exit_msg_num = 0;
m_sendNum = 0;
m_io = new TCPComm();
m_io->Initialize();
m_io->SetMessageHandler(msgHandle);
m_io->StartServerSocket("tcp://" + m_localIP + ":2225", "tcp://" + m_localIP + ":2225");
}
void TearDown() {
shutdownTcpServer(pid1);
shutdownTcpServer(pid2);
pid1 = 0;
pid2 = 0;
int i = 0;
for (i = 0; i < pid_num; i++) {
shutdownTcpServer(pids[i]);
pids[i] = 0;
}
g_recv_num = 0;
g_exit_msg_num = 0;
m_sendNum = 0;
m_io->Finalize();
delete m_io;
m_io = nullptr;
}
bool CheckRecvNum(int expectedRecvNum, int _timeout);
bool CheckExitNum(int expectedExitNum, int _timeout);
pid_t startTcpServer(char **args);
void shutdownTcpServer(pid_t pid);
void KillTcpServer(pid_t pid);
void Link(std::string &_localUrl, std::string &_remoteUrl);
void Reconnect(std::string &_localUrl, std::string &_remoteUrl);
void Unlink(std::string &_remoteUrl);
};
// listening local url and sending msg to remote url,if start succ.
pid_t TCPTest::startTcpServer(char **args) {
pid_t pid = fork();
if (pid == 0) {
return -1;
} else {
return pid;
}
}
void TCPTest::shutdownTcpServer(pid_t pid) {
if (pid > 1) {
kill(pid, SIGALRM);
int status;
waitpid(pid, &status, 0);
}
}
void TCPTest::KillTcpServer(pid_t pid) {
if (pid > 1) {
kill(pid, SIGKILL);
int status;
waitpid(pid, &status, 0);
}
}
void TCPTest::SendMsg(std::string &_localUrl, std::string &_remoteUrl, int msgsize, bool remoteLink, std::string body) {
AID from("testserver", _localUrl);
AID to("testserver", _remoteUrl);
std::unique_ptr<MessageBase> message = std::make_unique<MessageBase>();
std::string data(msgsize, 'A');
message->name = "testname";
message->from = from;
message->to = to;
message->body = data;
if (body != "") {
message->body = body;
}
if (m_notRemote) {
m_io->Send(std::move(message), remoteLink, true);
} else {
m_io->Send(std::move(message), remoteLink);
}
}
void TCPTest::Link(std::string &_localUrl, std::string &_remoteUrl) {
AID from("testserver", _localUrl);
AID to("testserver", _remoteUrl);
m_io->Link(from, to);
}
void TCPTest::Reconnect(std::string &_localUrl, std::string &_remoteUrl) {
AID from("testserver", _localUrl);
AID to("testserver", _remoteUrl);
m_io->Reconnect(from, to);
}
void TCPTest::Unlink(std::string &_remoteUrl) {
AID to("testserver", _remoteUrl);
m_io->UnLink(to);
}
bool TCPTest::CheckRecvNum(int expectedRecvNum, int _timeout) {
int timeout = _timeout * 1000 * 1000; // us
int usleepCount = 100000;
while (timeout) {
usleep(usleepCount);
if (g_recv_num >= expectedRecvNum) {
return true;
}
timeout = timeout - usleepCount;
}
return false;
}
bool TCPTest::CheckExitNum(int expectedExitNum, int _timeout) {
int timeout = _timeout * 1000 * 1000;
int usleepCount = 100000;
while (timeout) {
usleep(usleepCount);
if (g_exit_msg_num >= expectedExitNum) {
return true;
}
timeout = timeout - usleepCount;
}
return false;
}
/// Feature: test failed to start a socket server.
/// Description: start a socket server with an invalid url.
/// Expectation: failed to start the server with invalid url.
TEST_F(TCPTest, StartServerFail) {
std::unique_ptr<TCPComm> io = std::make_unique<TCPComm>();
io->Initialize();
bool ret = io->StartServerSocket("tcp://0:2225", "tcp://0:2225");
ASSERT_FALSE(ret);
io->Finalize();
}
/// Feature: test start a socket server.
/// Description: start the socket server with a specified socket.
/// Expectation: the socket server is started successfully.
TEST_F(TCPTest, StartServer2) {
std::unique_ptr<TCPComm> io = std::make_unique<TCPComm>();
io->Initialize();
io->SetMessageHandler(msgHandle);
bool ret = io->StartServerSocket("tcp://" + m_localIP + ":2225", "tcp://" + m_localIP + ":2225");
ASSERT_FALSE(ret);
ret = io->StartServerSocket("tcp://" + m_localIP + ":2224", "tcp://" + m_localIP + ":2224");
io->Finalize();
ASSERT_TRUE(ret);
}
/// Feature: test normal tcp message sending.
/// Description: start a socket server and send a normal message to it.
/// Expectation: the server received the message sented from client.
TEST_F(TCPTest, send1Msg) {
g_recv_num = 0;
pid1 = startTcpServer(args);
bool ret = CheckRecvNum(1, 5);
ASSERT_FALSE(ret);
std::string from = "tcp://" + m_localIP + ":2223";
std::string to = "tcp://" + m_localIP + ":2225";
SendMsg(from, to, pid_num);
ret = CheckRecvNum(1, 5);
ASSERT_TRUE(ret);
Unlink(to);
shutdownTcpServer(pid1);
pid1 = 0;
}
/// Feature: test max message body check..
/// Description: send a message with body exceeding the limit of max message body.
/// Expectation: drop the invalid message.
TEST_F(TCPTest, sendInvalidMsg) {
g_recv_num = 0;
pid1 = startTcpServer(args);
bool ret = CheckRecvNum(1, 5);
ASSERT_FALSE(ret);
std::string from = "tcp://" + m_localIP + ":2223";
std::string to = "tcp://" + m_localIP + ":2225";
SendMsg(from, to, 1024 * 1024 * pid_num + 1);
ret = CheckRecvNum(1, 5);
ASSERT_FALSE(ret);
Unlink(to);
shutdownTcpServer(pid1);
pid1 = 0;
}
} // namespace rpc
} // namespace distributed
} // namespace mindspore