diff --git a/mindspore/ccsrc/distributed/CMakeLists.txt b/mindspore/ccsrc/distributed/CMakeLists.txt index fa71f3652f5..efde26e31d4 100644 --- a/mindspore/ccsrc/distributed/CMakeLists.txt +++ b/mindspore/ccsrc/distributed/CMakeLists.txt @@ -6,14 +6,15 @@ else() list(REMOVE_ITEM _DISTRIBUTED_SRC_FILES "cluster/dummy_cluster_context.cc") endif() -set(EXCLUDE_DIR "rpc/") -foreach(RPC_FILE ${_DISTRIBUTED_SRC_FILES}) - string(FIND ${RPC_FILE} ${EXCLUDE_DIR} FOUND) - if(${FOUND} EQUAL 0) - list(REMOVE_ITEM _DISTRIBUTED_SRC_FILES ${RPC_FILE}) - endif() -endforeach() - +if(NOT ENABLE_CPU OR WIN32 OR APPLE) + set(EXCLUDE_DIR "rpc/") + foreach(RPC_FILE ${_DISTRIBUTED_SRC_FILES}) + string(FIND ${RPC_FILE} ${EXCLUDE_DIR} FOUND) + if(${FOUND} EQUAL 0) + list(REMOVE_ITEM _DISTRIBUTED_SRC_FILES ${RPC_FILE}) + endif() + endforeach() +endif() set_property(SOURCE ${_DISTRIBUTED_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_DISTRIBUTED) diff --git a/mindspore/ccsrc/distributed/rpc/tcp/tcp_comm.cc b/mindspore/ccsrc/distributed/rpc/tcp/tcp_comm.cc new file mode 100644 index 00000000000..c1c75c9d0c3 --- /dev/null +++ b/mindspore/ccsrc/distributed/rpc/tcp/tcp_comm.cc @@ -0,0 +1,771 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "distributed/rpc/tcp/tcp_comm.h" + +#include +#include +#include + +#include "actor/aid.h" +#include "actor/msg.h" +#include "distributed/rpc/tcp/constants.h" +#include "distributed/rpc/tcp/tcp_socket_operation.h" +#include "distributed/rpc/tcp/connection_pool.h" + +namespace mindspore { +namespace distributed { +namespace rpc { +bool TCPComm::is_http_msg_ = false; +std::vector TCPComm::advertise_url_; +uint64_t TCPComm::output_buf_size_ = 0; + +IOMgr::MessageHandler TCPComm::message_handler_; + +int DoConnect(const std::string &to, Connection *conn, ConnectionCallBack event_callback, + ConnectionCallBack write_callback, ConnectionCallBack read_callback) { + SocketAddress addr; + if (!SocketOperation::GetSockAddr(to, &addr)) { + return -1; + } + int sock_fd = SocketOperation::CreateServerSocket(addr.sa.sa_family); + if (sock_fd < 0) { + return -1; + } + + conn->socket_fd = sock_fd; + conn->event_callback = event_callback; + conn->write_callback = write_callback; + conn->read_callback = read_callback; + + int ret = TCPComm::Connect(conn, (struct sockaddr *)&addr, sizeof(addr)); + if (ret < 0) { + close(sock_fd); + conn->socket_fd = -1; + return -1; + } + return 0; +} + +void DoDisconnect(int fd, Connection *conn, uint32_t error, int soError) { + if (LOG_CHECK_EVERY_N()) { + MS_LOG(INFO) << "Failed to call connect, fd: " << fd << ", to: " << conn->destination.c_str() + << ", events: " << error << ", errno: " << soError; + } + + conn->state = ConnectionState::kDisconnecting; + conn->error_code = soError; + conn->event_callback(conn); + return; +} + +void ConnectedEventHandler(int fd, uint32_t events, void *context) { + uint32_t error = events & (EPOLLERR | EPOLLHUP | EPOLLRDHUP); + int soError = 0; + Connection *conn = reinterpret_cast(context); + conn->socket_operation->ConnEstablishedEventHandler(fd, events, context); + if (conn->state == ConnectionState::kDisconnecting) { + DoDisconnect(fd, conn, error, soError); + return; + } else if (conn->state != ConnectionState::kConnected) { + return; + } + + if (!conn->ReconnectSourceSocket(fd, events, &soError, error)) { + DoDisconnect(fd, conn, error, soError); + return; + } + if (conn->write_callback) { + conn->write_callback(conn); + } + return; +} + +void OnAccept(int server, uint32_t events, void *arg) { + if (events & (EPOLLHUP | EPOLLERR)) { + MS_LOG(ERROR) << "Invalid error event, server fd: " << server << ", events: " << events; + return; + } + TCPComm *tcpmgr = reinterpret_cast(arg); + if (tcpmgr->recv_event_loop_ == nullptr) { + MS_LOG(ERROR) << "EventLoop is null, server fd: " << server << ", events: " << events; + return; + } + + // accept connection + auto acceptFd = SocketOperation::Accept(server); + if (acceptFd < 0) { + MS_LOG(ERROR) << "Failed to call accept, server fd: " << server << ", events: " << events; + return; + } + + Connection *conn = new (std::nothrow) Connection(); + if (conn == nullptr) { + MS_LOG(ERROR) << "Failed to create new connection, server fd:" << server << ", events: " << events + << ", accept fd: " << acceptFd; + close(acceptFd); + acceptFd = -1; + return; + } + + // init metrics + conn->send_metrics = new (std::nothrow) SendMetrics(); + if (conn->send_metrics == nullptr) { + MS_LOG(ERROR) << "Failed to create connection metrics, server fd: " << server << ", events: " << events + << ", accept fd: " << acceptFd; + close(acceptFd); + acceptFd = -1; + delete conn; + return; + } + + conn->socket_fd = acceptFd; + conn->source = TCPComm::advertise_url_.data(); + conn->peer = SocketOperation::GetPeer(acceptFd); + + conn->is_remote = true; + conn->recv_event_loop = tcpmgr->recv_event_loop_; + conn->send_event_loop = tcpmgr->send_event_loop_; + + conn->event_callback = TCPComm::EventCallBack; + conn->write_callback = TCPComm::WriteCallBack; + conn->read_callback = TCPComm::ReadCallBack; + + int retval = conn->Initialize(); + if (retval != RPC_OK) { + MS_LOG(ERROR) << "Failed to add accept fd event, server fd: " << server << ", events: " << events + << ", accept fd: " << acceptFd; + close(acceptFd); + acceptFd = -1; + delete conn->send_metrics; + delete conn; + return; + } +} + +void DoSend(Connection *conn) { + while (!conn->send_message_queue.empty() || conn->total_send_len != 0) { + if (conn->total_send_len == 0) { + conn->FillSendMessage(conn->send_message_queue.front(), TCPComm::advertise_url_.data(), TCPComm::IsHttpMsg()); + conn->send_message_queue.pop(); + } + + int sendLen = conn->socket_operation->SendMessage(conn, &conn->send_kernel_msg, &conn->total_send_len); + if (sendLen > 0) { + if (conn->total_send_len == 0) { + // update metrics + conn->send_metrics->UpdateError(false); + + TCPComm::output_buf_size_ -= conn->send_message->body.size(); + conn->output_buffer_size -= conn->send_message->body.size(); + delete conn->send_message; + conn->send_message = nullptr; + } + } else if (sendLen == 0) { + // EAGAIN + (void)conn->recv_event_loop->UpdateEpollEvent(conn->socket_fd, EPOLLOUT | EPOLLIN | EPOLLHUP | EPOLLERR); + break; + } else { + // update metrics + conn->send_metrics->UpdateError(true, conn->error_code); + conn->state = ConnectionState::kDisconnecting; + break; + } + } +} + +TCPComm::~TCPComm() { + try { + Finalize(); + } catch (...) { + MS_LOG(ERROR) << "Failed to finalize tcp communicator."; + } +} + +void TCPComm::SendExitMsg(const std::string &from, const std::string &to) { + if (message_handler_ != nullptr) { + std::unique_ptr exit_msg = std::make_unique(MessageBase::Type::KEXIT); + RPC_OOM_EXIT(exit_msg); + + exit_msg->SetFrom(AID(from)); + exit_msg->SetTo(AID(to)); + + message_handler_(std::move(exit_msg)); + } +} + +void TCPComm::SetMessageHandler(IOMgr::MessageHandler handler) { message_handler_ = handler; } + +bool TCPComm::Initialize() { + if (ConnectionPool::GetConnectionPool() == nullptr) { + MS_LOG(ERROR) << "Failed to create connection pool."; + return false; + } + recv_event_loop_ = new (std::nothrow) EventLoop(); + if (recv_event_loop_ == nullptr) { + MS_LOG(ERROR) << "Failed to create recv evLoop."; + return false; + } + + bool ok = recv_event_loop_->Initialize(TCP_RECV_EVLOOP_THREADNAME); + if (!ok) { + MS_LOG(ERROR) << "Failed to init recv evLoop"; + delete recv_event_loop_; + recv_event_loop_ = nullptr; + return false; + } + + send_event_loop_ = new (std::nothrow) EventLoop(); + if (send_event_loop_ == nullptr) { + MS_LOG(ERROR) << "Failed to create send evLoop."; + delete recv_event_loop_; + recv_event_loop_ = nullptr; + return false; + } + ok = send_event_loop_->Initialize(TCP_SEND_EVLOOP_THREADNAME); + if (!ok) { + MS_LOG(ERROR) << "Failed to init send evLoop"; + delete recv_event_loop_; + recv_event_loop_ = nullptr; + delete send_event_loop_; + send_event_loop_ = nullptr; + return false; + } + + if (g_httpKmsgEnable < 0) { + char *httpKmsgEnv = getenv("LITERPC_HTTPKMSG_ENABLED"); + if (httpKmsgEnv != nullptr) { + if (std::string(httpKmsgEnv) == "true" || std::string(httpKmsgEnv) == "1") { + is_http_msg_ = true; + } + } + } else { + is_http_msg_ = (g_httpKmsgEnable == 0) ? false : true; + } + + ConnectionPool::GetConnectionPool()->SetLinkPattern(is_http_msg_); + return true; +} + +bool TCPComm::StartServerSocket(const std::string &url, const std::string &aAdvertiseUrl) { + server_fd_ = SocketOperation::Listen(url); + if (server_fd_ < 0) { + MS_LOG(ERROR) << "Failed to call socket listen, url: " << url.c_str() + << ", advertise_url_: " << advertise_url_.data(); + return false; + } + url_ = url; + std::string tmp_url; + + if (aAdvertiseUrl.size() > 0) { + advertise_url_.resize(aAdvertiseUrl.size()); + advertise_url_.assign(aAdvertiseUrl.begin(), aAdvertiseUrl.end()); + tmp_url = aAdvertiseUrl; + } else { + advertise_url_.resize(url_.size()); + advertise_url_.assign(url_.begin(), url_.end()); + tmp_url = url_; + } + + size_t index = url.find(URL_PROTOCOL_IP_SEPARATOR); + if (index != std::string::npos) { + url_ = url.substr(index + sizeof(URL_PROTOCOL_IP_SEPARATOR) - 1); + } + + index = tmp_url.find(URL_PROTOCOL_IP_SEPARATOR); + if (index != std::string::npos) { + tmp_url = tmp_url.substr(index + sizeof(URL_PROTOCOL_IP_SEPARATOR) - 1); + advertise_url_.resize(tmp_url.size()); + advertise_url_.assign(tmp_url.begin(), tmp_url.end()); + } + + // Register read event callback for server socket + int retval = recv_event_loop_->SetEventHandler(server_fd_, EPOLLIN | EPOLLHUP | EPOLLERR, OnAccept, + reinterpret_cast(this)); + if (retval != RPC_OK) { + MS_LOG(ERROR) << "Failed to add server event, url: " << url.c_str() + << ", advertise_url_: " << advertise_url_.data(); + return false; + } + MS_LOG(INFO) << "Start server succ, fd: " << server_fd_ << ", url: " << url.c_str() + << ", advertise_url_ :" << advertise_url_.data(); + return true; +} + +void TCPComm::ReadCallBack(void *context) { + const int max_recv_count = 3; + Connection *conn = reinterpret_cast(context); + int count = 0; + int retval = 0; + do { + retval = ReceiveMessage(conn); + ++count; + } while (retval > 0 && count < max_recv_count); + + return; +} + +void TCPComm::EventCallBack(void *context) { + Connection *conn = reinterpret_cast(context); + + if (conn->state == ConnectionState::kConnected) { + Connection::conn_mutex.lock(); + DoSend(conn); + Connection::conn_mutex.unlock(); + } else if (conn->state == ConnectionState::kDisconnecting) { + Connection::conn_mutex.lock(); + output_buf_size_ -= conn->output_buffer_size; + ConnectionPool::GetConnectionPool()->CloseConnection(conn); + Connection::conn_mutex.unlock(); + } +} + +void TCPComm::WriteCallBack(void *context) { + Connection *conn = reinterpret_cast(context); + if (conn->state == ConnectionState::kConnected) { + Connection::conn_mutex.lock(); + DoSend(conn); + Connection::conn_mutex.unlock(); + } +} + +int TCPComm::ReceiveMessage(Connection *conn) { + conn->CheckMessageType(); + switch (conn->recv_message_type) { + case ParseType::kTcpMsg: + return conn->ReceiveMessage(message_handler_); + +#ifdef HTTP_ENABLED + case ParseType::KHTTP_REQ: + if (httpReqCb) { + return httpReqCb(conn, message_handler_); + } else { + conn->state = ConnectionState::kDisconnecting; + return -1; + } + + case ParseType::KHTTP_RSP: + if (httpRspCb) { + return httpRspCb(conn, message_handler_); + } else { + conn->state = ConnectionState::kDisconnecting; + return -1; + } +#endif + + default: + return 0; + } +} + +int TCPComm::SetConnectedHandler(Connection *conn) { + /* add to epoll */ + return conn->recv_event_loop->SetEventHandler(conn->socket_fd, + (uint32_t)(EPOLLOUT | EPOLLHUP | EPOLLRDHUP | EPOLLERR), + ConnectedEventHandler, reinterpret_cast(conn)); +} + +int TCPComm::Connect(Connection *conn, const struct sockaddr *sa, socklen_t saLen) { + int retval = 0; + uint16_t localPort = -1; + + retval = SocketOperation::Connect(conn->socket_fd, sa, saLen, &localPort); + if (retval != RPC_OK) { + return RPC_ERROR; + } + + // Init connection metrics. + if (conn->send_metrics == nullptr) { + conn->send_metrics = new (std::nothrow) SendMetrics(); + if (conn->send_metrics == nullptr) { + return RPC_ERROR; + } + } + + // Add the socket of this connection to epoll. + retval = SetConnectedHandler(conn); + if (retval != RPC_OK) { + if (conn->send_metrics != nullptr) { + delete conn->send_metrics; + conn->send_metrics = nullptr; + } + return RPC_ERROR; + } + return RPC_OK; +} + +void TCPComm::Send(MessageBase *msg, const TCPComm *tcpmgr, bool remoteLink, bool isExactNotRemote) { + std::lock_guard lock(Connection::conn_mutex); + Connection *conn = ConnectionPool::GetConnectionPool()->FindConnection(msg->to.Url(), remoteLink, isExactNotRemote); + + // Create a new connection if the connection to target of the message does not existed. + if (conn == nullptr) { + if (remoteLink && (!isExactNotRemote)) { + MS_LOG(ERROR) << "Could not found remote link and send fail name: " << msg->name.c_str() + << ", from: " << advertise_url_.data() << ", to: " << msg->to.Url().c_str(); + delete msg; + return; + } + conn = new (std::nothrow) Connection(); + if (conn == nullptr) { + MS_LOG(ERROR) << "Failed to create new connection and send fail name: " << msg->name.c_str() + << ", from: " << advertise_url_.data() << ", to: " << msg->to.Url().c_str(); + delete msg; + return; + } + conn->source = advertise_url_.data(); + conn->destination = msg->to.Url(); + conn->recv_event_loop = tcpmgr->recv_event_loop_; + conn->send_event_loop = tcpmgr->send_event_loop_; + conn->InitSocketOperation(); + + int ret = DoConnect(msg->to.Url(), conn, TCPComm::EventCallBack, TCPComm::WriteCallBack, TCPComm::ReadCallBack); + if (ret < 0) { + MS_LOG(ERROR) << "Failed to do connect and send fail name: " << msg->name.c_str() + << ", from: " << advertise_url_.data() << ", to: " << msg->to.Url().c_str(); + if (conn->socket_operation != nullptr) { + delete conn->socket_operation; + conn->socket_operation = nullptr; + } + delete conn; + delete msg; + return; + } + ConnectionPool::GetConnectionPool()->AddConnection(conn); + } + + if (!conn->is_remote && !isExactNotRemote && conn->priority == ConnectionPriority::kPriorityLow) { + Connection *remoteConn = ConnectionPool::GetConnectionPool()->ExactFindConnection(msg->to.Url(), true); + if (remoteConn != nullptr && remoteConn->state == ConnectionState::kConnected) { + conn = remoteConn; + } + } + + // Prepare the message. + if (conn->total_send_len == 0) { + conn->FillSendMessage(msg, advertise_url_.data(), is_http_msg_); + } else { + conn->send_message_queue.emplace(msg); + } + + // Send the message. + if (conn->state == ConnectionState::kConnected) { + DoSend(conn); + } +} + +void TCPComm::SendByRecvLoop(MessageBase *msg, const TCPComm *tcpmgr, bool remoteLink, bool isExactNotRemote) { + recv_event_loop_->AddTask( + [msg, tcpmgr, remoteLink, isExactNotRemote] { TCPComm::Send(msg, tcpmgr, remoteLink, isExactNotRemote); }); +} + +int TCPComm::Send(MessageBase *msg, bool remoteLink, bool isExactNotRemote) { + return send_event_loop_->AddTask([msg, this, remoteLink, isExactNotRemote] { + std::lock_guard lock(Connection::conn_mutex); + // Search connection by the target address + bool exactNotRemote = is_http_msg_ || isExactNotRemote; + Connection *conn = ConnectionPool::GetConnectionPool()->FindConnection(msg->to.Url(), remoteLink, exactNotRemote); + if (conn == nullptr) { + if (remoteLink && (!exactNotRemote)) { + MS_LOG(ERROR) << "Can not found remote link and send fail name: " << msg->name.c_str() + << ", from: " << advertise_url_.data() << ", to: " << msg->to.Url().c_str(); + auto *ptr = msg; + delete ptr; + ptr = nullptr; + + return; + } + this->SendByRecvLoop(msg, this, remoteLink, exactNotRemote); + return; + } + + if (conn->state != kConnected && conn->send_message_queue.size() >= SENDMSG_QUEUELEN) { + MS_LOG(WARNING) << "The name of dropped message is: " << msg->name.c_str() << ", fd: " << conn->socket_fd + << ", to: " << conn->destination.c_str() << ", remote: " << conn->is_remote; + auto *ptr = msg; + delete ptr; + ptr = nullptr; + + return; + } + + if (conn->state == ConnectionState::kClose || conn->state == ConnectionState::kDisconnecting) { + this->SendByRecvLoop(msg, this, remoteLink, exactNotRemote); + return; + } + + if (!conn->is_remote && !exactNotRemote && conn->priority == ConnectionPriority::kPriorityLow) { + Connection *remoteConn = ConnectionPool::GetConnectionPool()->ExactFindConnection(msg->to.Url(), true); + if (remoteConn != nullptr && remoteConn->state == ConnectionState::kConnected) { + conn = remoteConn; + } + } + + output_buf_size_ += msg->body.size(); + if (conn->total_send_len == 0) { + conn->FillSendMessage(msg, advertise_url_.data(), is_http_msg_); + } else { + conn->send_message_queue.emplace(msg); + } + + if (conn->state == ConnectionState::kConnected) { + DoSend(conn); + } + }); +} + +void TCPComm::CollectMetrics() { + send_event_loop_->AddTask([this] { + Connection::conn_mutex.lock(); + Connection *maxConn = ConnectionPool::GetConnectionPool()->FindMaxConnection(); + Connection *fastConn = ConnectionPool::GetConnectionPool()->FindFastConnection(); + + if (message_handler_ != nullptr) { + IntTypeMetrics intMetrics; + StringTypeMetrics stringMetrics; + + if (maxConn != nullptr) { + intMetrics.push(maxConn->socket_fd); + intMetrics.push(maxConn->error_code); + intMetrics.push(maxConn->send_metrics->accum_msg_count); + intMetrics.push(maxConn->send_metrics->max_msg_size); + stringMetrics.push(maxConn->destination); + stringMetrics.push(maxConn->send_metrics->last_succ_msg_name); + stringMetrics.push(maxConn->send_metrics->last_fail_msg_name); + } + if (fastConn != nullptr && fastConn->IsSame(maxConn)) { + intMetrics.push(fastConn->socket_fd); + intMetrics.push(fastConn->error_code); + intMetrics.push(fastConn->send_metrics->accum_msg_count); + intMetrics.push(fastConn->send_metrics->max_msg_size); + stringMetrics.push(fastConn->destination); + stringMetrics.push(fastConn->send_metrics->last_succ_msg_name); + stringMetrics.push(fastConn->send_metrics->last_fail_msg_name); + } + } + + ConnectionPool::GetConnectionPool()->ResetAllConnMetrics(); + Connection::conn_mutex.unlock(); + }); +} + +int TCPComm::Send(std::unique_ptr &&msg, bool remoteLink, bool isExactNotRemote) { + return Send(msg.release(), remoteLink, isExactNotRemote); +} + +void TCPComm::Link(const AID &source, const AID &destination) { + recv_event_loop_->AddTask([source, destination, this] { + std::string to = destination.Url(); + std::lock_guard lock(Connection::conn_mutex); + + // Search connection by the target address + Connection *conn = ConnectionPool::GetConnectionPool()->FindConnection(to, false, is_http_msg_); + + if (conn == nullptr) { + MS_LOG(INFO) << "Can not found link source: " << std::string(source).c_str() + << ", destination: " << std::string(destination).c_str(); + conn = new (std::nothrow) Connection(); + if (conn == nullptr) { + MS_LOG(ERROR) << "Failed to create new connection and link fail source: " << std::string(source).c_str() + << ", destination: " << std::string(destination).c_str(); + SendExitMsg(source, destination); + return; + } + conn->source = advertise_url_.data(); + conn->destination = to; + + conn->recv_event_loop = this->recv_event_loop_; + conn->send_event_loop = this->send_event_loop_; + conn->InitSocketOperation(); + + int ret = DoConnect(to, conn, TCPComm::EventCallBack, TCPComm::WriteCallBack, TCPComm::ReadCallBack); + if (ret < 0) { + MS_LOG(ERROR) << "Failed to do connect and link fail source: " << std::string(source).c_str() + << ", destination: " << std::string(destination).c_str(); + SendExitMsg(source, destination); + if (conn->socket_operation != nullptr) { + delete conn->socket_operation; + conn->socket_operation = nullptr; + } + delete conn; + return; + } + ConnectionPool::GetConnectionPool()->AddConnection(conn); + } + ConnectionPool::GetConnectionPool()->AddConnInfo(conn->socket_fd, source, destination, SendExitMsg); + MS_LOG(INFO) << "Link fd: " << conn->socket_fd << ", source: " << std::string(source).c_str() + << ", destination: " << std::string(destination).c_str() << ", remote: " << conn->is_remote; + }); +} + +void TCPComm::UnLink(const AID &destination) { + recv_event_loop_->AddTask([destination] { + std::string to = destination.Url(); + std::lock_guard lock(Connection::conn_mutex); + if (is_http_msg_) { + // When application has set 'LITERPC_HTTPKMSG_ENABLED',it means sending-link is in links map + // while accepting-link is differently in remoteLinks map. So we only need to delete link in exact links. + ConnectionPool::GetConnectionPool()->ExactDeleteConnection(to, false); + } else { + // When application hasn't set 'LITERPC_HTTPKMSG_ENABLED',it means sending-link and accepting-link is + // shared + // So we need to delete link in both links map and remote-links map. + ConnectionPool::GetConnectionPool()->ExactDeleteConnection(to, false); + ConnectionPool::GetConnectionPool()->ExactDeleteConnection(to, true); + } + }); +} + +void TCPComm::DoReConnectConn(Connection *conn, std::string to, const AID &source, const AID &destination, int *oldFd) { + if (!is_http_msg_ && !conn->is_remote) { + Connection *remoteConn = ConnectionPool::GetConnectionPool()->ExactFindConnection(to, true); + // We will close remote link in rare cases where sending-link and accepting link coexists + // simultaneously. + if (remoteConn != nullptr) { + MS_LOG(INFO) << "Reconnect, close remote connect fd :" << remoteConn->socket_fd + << ", source: " << std::string(source).c_str() + << ", destination: " << std::string(destination).c_str() << ", remote: " << remoteConn->is_remote + << ", state: " << remoteConn->state; + ConnectionPool::GetConnectionPool()->CloseConnection(remoteConn); + } + } + + MS_LOG(INFO) << "Reconnect, close old connect fd: " << conn->socket_fd << ", source: " << std::string(source).c_str() + << ", destination: " << std::string(destination).c_str() << ", remote: " << conn->is_remote + << ", state: " << conn->state; + + *oldFd = conn->socket_fd; + + conn->recv_event_loop->DeleteEpollEvent(conn->socket_fd); + conn->socket_operation->Close(conn); + + conn->socket_fd = -1; + conn->recv_len = 0; + + conn->total_recv_len = 0; + conn->recv_message_type = kUnknown; + conn->state = kInit; + if (conn->total_send_len != 0 && conn->send_message != nullptr) { + delete conn->send_message; + } + conn->send_message = nullptr; + conn->total_send_len = 0; + + if (conn->total_recv_len != 0 && conn->recv_message != nullptr) { + delete conn->recv_message; + } + conn->recv_message = nullptr; + conn->total_recv_len = 0; + + conn->recv_state = State::kMsgHeader; +} + +Connection *TCPComm::CreateDefaultConn(std::string to) { + Connection *conn = new (std::nothrow) Connection(); + if (conn == nullptr) { + MS_LOG(ERROR) << "Failed to create new connection and reconnect fail to: " << to.c_str(); + return conn; + } + conn->source = advertise_url_.data(); + conn->destination = to; + conn->recv_event_loop = this->recv_event_loop_; + conn->send_event_loop = this->send_event_loop_; + conn->InitSocketOperation(); + return conn; +} + +void TCPComm::Reconnect(const AID &source, const AID &destination) { + send_event_loop_->AddTask([source, destination, this] { + std::string to = destination.Url(); + std::lock_guard lock(Connection::conn_mutex); + Connection *conn = ConnectionPool::GetConnectionPool()->FindConnection(to, false, is_http_msg_); + if (conn != nullptr) { + conn->state = ConnectionState::kClose; + } + + recv_event_loop_->AddTask([source, destination, this] { + std::string to = destination.Url(); + int oldFd = -1; + std::lock_guard lock(Connection::conn_mutex); + Connection *conn = ConnectionPool::GetConnectionPool()->FindConnection(to, false, is_http_msg_); + if (conn != nullptr) { + // connection already exist + DoReConnectConn(conn, to, source, destination, &oldFd); + } else { + // create default connection + conn = CreateDefaultConn(to); + if (conn == nullptr) { + return; + } + } + int ret = DoConnect(to, conn, TCPComm::EventCallBack, TCPComm::WriteCallBack, TCPComm::ReadCallBack); + if (ret < 0) { + if (conn->socket_operation != nullptr) { + delete conn->socket_operation; + conn->socket_operation = nullptr; + } + if (oldFd != -1) { + conn->socket_fd = oldFd; + } + MS_LOG(ERROR) << "Failed to connect and reconnect fail source: " << std::string(source).c_str() + << ", destination: " << std::string(destination).c_str(); + ConnectionPool::GetConnectionPool()->CloseConnection(conn); + return; + } + if (oldFd != -1) { + if (!ConnectionPool::GetConnectionPool()->ReverseConnInfo(oldFd, conn->socket_fd)) { + MS_LOG(ERROR) << "Failed to swap socket for " << oldFd << " and " << conn->socket_fd; + } + } else { + ConnectionPool::GetConnectionPool()->AddConnection(conn); + } + ConnectionPool::GetConnectionPool()->AddConnInfo(conn->socket_fd, source, destination, SendExitMsg); + MS_LOG(INFO) << "Reconnect fd: " << conn->socket_fd << ", source: " << std::string(source).c_str() + << ", destination: " << std::string(destination).c_str(); + }); + }); +} + +void TCPComm::Finalize() { + if (send_event_loop_ != nullptr) { + MS_LOG(INFO) << "Delete send event loop"; + send_event_loop_->Finalize(); + delete send_event_loop_; + send_event_loop_ = nullptr; + } + + if (recv_event_loop_ != nullptr) { + MS_LOG(INFO) << "Delete recv event loop"; + recv_event_loop_->Finalize(); + delete recv_event_loop_; + recv_event_loop_ = nullptr; + } + + if (server_fd_ > 0) { + close(server_fd_); + server_fd_ = -1; + } +} + +// This value is not used for the sub-class TCPComm. +uint64_t TCPComm::GetInBufSize() { return 1; } + +uint64_t TCPComm::GetOutBufSize() { return output_buf_size_; } + +bool TCPComm::IsHttpMsg() { return is_http_msg_; } +} // namespace rpc +} // namespace distributed +} // namespace mindspore diff --git a/mindspore/ccsrc/distributed/rpc/tcp/tcp_comm.h b/mindspore/ccsrc/distributed/rpc/tcp/tcp_comm.h new file mode 100644 index 00000000000..72bcfe34064 --- /dev/null +++ b/mindspore/ccsrc/distributed/rpc/tcp/tcp_comm.h @@ -0,0 +1,130 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_DISTRIBUTED_RPC_TCP_TCP_COMM_H_ +#define MINDSPORE_CCSRC_DISTRIBUTED_RPC_TCP_TCP_COMM_H_ + +#include +#include +#include + +#include "actor/iomgr.h" +#include "distributed/rpc/tcp/connection.h" +#include "distributed/rpc/tcp/event_loop.h" + +namespace mindspore { +namespace distributed { +namespace rpc { +// Event handler for new connecting request arrived. +void OnAccept(int server, uint32_t events, void *arg); + +// Send messages buffered in the connection. +void DoSend(Connection *conn); + +// Create a server socket and connect to it, this is a local connection.. +int DoConnect(const std::string &to, Connection *conn, ConnectionCallBack eventCallBack, + ConnectionCallBack writeCallBack, ConnectionCallBack readCallBack); + +void DoDisconnect(int fd, Connection *conn, uint32_t error, int soError); + +void ConnectedEventHandler(int fd, uint32_t events, void *context); + +class TCPComm : public IOMgr { + public: + TCPComm() : server_fd_(-1), recv_event_loop_(nullptr), send_event_loop_(nullptr) {} + TCPComm(const TCPComm &) = delete; + TCPComm &operator=(const TCPComm &) = delete; + ~TCPComm(); + + // Init the event loop for reading and writing. + bool Initialize() override; + + // Destroy all the resources. + void Finalize() override; + + // Create the server socket represented by url. + bool StartServerSocket(const std::string &url, const std::string &aAdvertiseUrl) override; + + // Build a connection between the source and destination. + void Link(const AID &source, const AID &destination) override; + void UnLink(const AID &destination) override; + + // Send the message from the source to the destination. + int Send(std::unique_ptr &&msg, bool remoteLink = false, bool isExactNotRemote = false) override; + + uint64_t GetInBufSize() override; + uint64_t GetOutBufSize() override; + void CollectMetrics() override; + + private: + // Build the connection. + Connection *CreateDefaultConn(std::string to); + void Reconnect(const AID &source, const AID &destination); + void DoReConnectConn(Connection *conn, std::string to, const AID &source, const AID &destination, int *oldFd); + + // Send a message. + int Send(MessageBase *msg, bool remoteLink = false, bool isExactNotRemote = false); + static void Send(MessageBase *msg, const TCPComm *tcpmgr, bool remoteLink, bool isExactNotRemote); + void SendByRecvLoop(MessageBase *msg, const TCPComm *tcpmgr, bool remoteLink, bool isExactNotRemote); + static void SendExitMsg(const std::string &from, const std::string &to); + + // Called by ReadCallBack when new message arrived. + static int ReceiveMessage(Connection *conn); + + void SetMessageHandler(IOMgr::MessageHandler handler); + static int SetConnectedHandler(Connection *conn); + + static int Connect(Connection *conn, const struct sockaddr *sa, socklen_t saLen); + + static bool IsHttpMsg(); + + // Read and write events. + static void ReadCallBack(void *context); + static void WriteCallBack(void *context); + // Connected and Disconnected events. + static void EventCallBack(void *context); + + // The server url. + std::string url_; + + // The socket of server. + int server_fd_; + + // The message size waiting to be sent. + static uint64_t output_buf_size_; + + // User defined handler for Handling received messages. + static MessageHandler message_handler_; + + // The source url of a message. + static std::vector advertise_url_; + + static bool is_http_msg_; + + // All the connections share the same read and write event loop objects. + EventLoop *recv_event_loop_; + EventLoop *send_event_loop_; + + friend void OnAccept(int server, uint32_t events, void *arg); + friend void DoSend(Connection *conn); + friend int DoConnect(const std::string &to, Connection *conn, ConnectionCallBack event_callback, + ConnectionCallBack write_callback, ConnectionCallBack read_callback); +}; +} // namespace rpc +} // namespace distributed +} // namespace mindspore + +#endif diff --git a/mindspore/core/mindrt/src/actor/actormgr.cc b/mindspore/core/mindrt/src/actor/actormgr.cc index 5fde13d4b80..b265e13a5a0 100644 --- a/mindspore/core/mindrt/src/actor/actormgr.cc +++ b/mindspore/core/mindrt/src/actor/actormgr.cc @@ -153,7 +153,7 @@ void ActorMgr::Finalize() { // stop iomgr thread for (auto mgrIt = ioMgrs.begin(); mgrIt != ioMgrs.end(); ++mgrIt) { MS_LOG(INFO) << "finalize IOMgr=" << mgrIt->first.c_str(); - mgrIt->second->Finish(); + mgrIt->second->Finalize(); } // delete actor thread pool if use_inner_pool diff --git a/mindspore/core/mindrt/src/actor/iomgr.h b/mindspore/core/mindrt/src/actor/iomgr.h index e0dd8df395d..4276dc3adb6 100644 --- a/mindspore/core/mindrt/src/actor/iomgr.h +++ b/mindspore/core/mindrt/src/actor/iomgr.h @@ -56,7 +56,7 @@ static const char HTTP_CLIENT_EVLOOP_THREADNAME[] = "MINDRT_Htp"; class IOMgr { public: - using MsgHandler = void (*)(std::unique_ptr &&msg); + using MessageHandler = void (*)(std::unique_ptr &&msg); /** * remoteLink and isExactNotRemote are flags to tell us which link should be used. There are several cases: * 1. remoteLink is false and isExactNotRemote is false : callers can reuse remote link when threr are no links @@ -71,16 +71,15 @@ class IOMgr { // close the socket,and send exitedEvent to all linkers. virtual void UnLink(const AID &dAid) = 0; virtual void Reconnect(const AID &sAid, const AID &dAid) = 0; - virtual void RegisterMsgHandle(MsgHandler handle) = 0; - virtual bool Init() = 0; // once - virtual void Finish() = 0; // once - virtual bool StartIOServer(const std::string &url, const std::string &advertiseUrl) = 0; // multicalledable + virtual void SetMessageHandler(MessageHandler handle) = 0; + virtual bool Initialize() = 0; // once + virtual void Finalize() = 0; // once + virtual bool StartServerSocket(const std::string &url, const std::string &advertiseUrl) = 0; // multicalledable virtual uint64_t GetOutBufSize() = 0; virtual uint64_t GetInBufSize() = 0; virtual void CollectMetrics() = 0; - virtual int AddRuleUdp(std::string peer, int recordNum) = 0; - virtual void DelRuleUdp(std::string peer, bool outputLog) = 0; - virtual void LinkRecycleCheck(int recyclePeroid) = 0; + virtual int AddRuleUdp(std::string peer, int recordNum) { return 0; } + virtual void DelRuleUdp(std::string peer, bool outputLog) { return; } IOMgr() {} virtual ~IOMgr() {} }; diff --git a/tests/ut/cpp/CMakeLists.txt b/tests/ut/cpp/CMakeLists.txt index d00db33ce6c..cf39cd3e511 100644 --- a/tests/ut/cpp/CMakeLists.txt +++ b/tests/ut/cpp/CMakeLists.txt @@ -70,6 +70,7 @@ if(ENABLE_MINDDATA) ./ps/*.cc ./fl/*.cc ./distributed/persistent/*.cc + ./distributed/rpc/tcp/*.cc ./cxx_api/*.cc ./tbe/*.cc ./mindapi/*.cc @@ -175,6 +176,7 @@ file(GLOB_RECURSE MINDSPORE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "../../../mindspore/ccsrc/ps/*.cc" "../../../mindspore/ccsrc/fl/*.cc" "../../../mindspore/ccsrc/distributed/persistent/*.cc" + "../../../mindspore/ccsrc/distributed/rpc/tcp/*.cc" "../../../mindspore/ccsrc/profiler/device/ascend/*.cc" "../../../mindspore/ccsrc/profiler/device/profiling.cc" "../../../mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/adam_fp32.c" diff --git a/tests/ut/cpp/distributed/rpc/tcp/tcp_test.cc b/tests/ut/cpp/distributed/rpc/tcp/tcp_test.cc new file mode 100644 index 00000000000..0549dc066e8 --- /dev/null +++ b/tests/ut/cpp/distributed/rpc/tcp/tcp_test.cc @@ -0,0 +1,282 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include +#include +#include + +#include +#define private public +#include "actor/iomgr.h" +#include "async/async.h" +#include "distributed/rpc/tcp/tcp_comm.h" +#include "common/common_test.h" + +namespace mindspore { +namespace distributed { +namespace rpc { +int g_recv_num = 0; +int g_exit_msg_num = 0; + +TCPComm *m_io = nullptr; +std::atomic m_sendNum(0); +std::string m_localIP = "127.0.0.1"; +bool m_notRemote = false; + +void msgHandle(std::unique_ptr &&msg) { + if (msg->GetType() == MessageBase::Type::KEXIT) { + g_exit_msg_num++; + } else { + g_recv_num++; + } +} + +class TCPTest : public UT::Common { + public: + static void SendMsg(std::string &_localUrl, std::string &_remoteUrl, int msgsize, bool remoteLink = false, + std::string body = ""); + + protected: + char *args[4]; + char *testServerPath; + static const size_t pid_num = 100; + pid_t pid1; + pid_t pid2; + + pid_t pids[pid_num]; + + void SetUp() { + char *localpEnv = getenv("LITEBUS_IP"); + if (localpEnv != nullptr) { + m_localIP = std::string(localpEnv); + } + + char *locaNotRemoteEnv = getenv("LITEBUS_SEND_ON_REMOTE"); + if (locaNotRemoteEnv != nullptr) { + m_notRemote = (std::string(locaNotRemoteEnv) == "true") ? true : false; + } + + pid1 = 0; + pid2 = 0; + pids[pid_num] = {0}; + size_t size = pid_num * sizeof(pid_t); + if (memset_s(&pids, size, 0, size)) { + MS_LOG(ERROR) << "Failed to init pid array"; + } + g_recv_num = 0; + g_exit_msg_num = 0; + m_sendNum = 0; + + m_io = new TCPComm(); + m_io->Initialize(); + m_io->SetMessageHandler(msgHandle); + m_io->StartServerSocket("tcp://" + m_localIP + ":2225", "tcp://" + m_localIP + ":2225"); + } + + void TearDown() { + shutdownTcpServer(pid1); + shutdownTcpServer(pid2); + pid1 = 0; + pid2 = 0; + int i = 0; + for (i = 0; i < pid_num; i++) { + shutdownTcpServer(pids[i]); + pids[i] = 0; + } + g_recv_num = 0; + g_exit_msg_num = 0; + m_sendNum = 0; + m_io->Finalize(); + delete m_io; + m_io = nullptr; + } + + bool CheckRecvNum(int expectedRecvNum, int _timeout); + bool CheckExitNum(int expectedExitNum, int _timeout); + pid_t startTcpServer(char **args); + void shutdownTcpServer(pid_t pid); + void KillTcpServer(pid_t pid); + + void Link(std::string &_localUrl, std::string &_remoteUrl); + void Reconnect(std::string &_localUrl, std::string &_remoteUrl); + void Unlink(std::string &_remoteUrl); +}; + +// listening local url and sending msg to remote url,if start succ. +pid_t TCPTest::startTcpServer(char **args) { + pid_t pid = fork(); + if (pid == 0) { + return -1; + } else { + return pid; + } +} +void TCPTest::shutdownTcpServer(pid_t pid) { + if (pid > 1) { + kill(pid, SIGALRM); + int status; + waitpid(pid, &status, 0); + } +} + +void TCPTest::KillTcpServer(pid_t pid) { + if (pid > 1) { + kill(pid, SIGKILL); + int status; + waitpid(pid, &status, 0); + } +} + +void TCPTest::SendMsg(std::string &_localUrl, std::string &_remoteUrl, int msgsize, bool remoteLink, std::string body) { + AID from("testserver", _localUrl); + AID to("testserver", _remoteUrl); + + std::unique_ptr message = std::make_unique(); + std::string data(msgsize, 'A'); + message->name = "testname"; + message->from = from; + message->to = to; + message->body = data; + if (body != "") { + message->body = body; + } + + if (m_notRemote) { + m_io->Send(std::move(message), remoteLink, true); + } else { + m_io->Send(std::move(message), remoteLink); + } +} + +void TCPTest::Link(std::string &_localUrl, std::string &_remoteUrl) { + AID from("testserver", _localUrl); + AID to("testserver", _remoteUrl); + m_io->Link(from, to); +} + +void TCPTest::Reconnect(std::string &_localUrl, std::string &_remoteUrl) { + AID from("testserver", _localUrl); + AID to("testserver", _remoteUrl); + m_io->Reconnect(from, to); +} + +void TCPTest::Unlink(std::string &_remoteUrl) { + AID to("testserver", _remoteUrl); + m_io->UnLink(to); +} + +bool TCPTest::CheckRecvNum(int expectedRecvNum, int _timeout) { + int timeout = _timeout * 1000 * 1000; // us + int usleepCount = 100000; + + while (timeout) { + usleep(usleepCount); + if (g_recv_num >= expectedRecvNum) { + return true; + } + timeout = timeout - usleepCount; + } + return false; +} + +bool TCPTest::CheckExitNum(int expectedExitNum, int _timeout) { + int timeout = _timeout * 1000 * 1000; + int usleepCount = 100000; + + while (timeout) { + usleep(usleepCount); + if (g_exit_msg_num >= expectedExitNum) { + return true; + } + timeout = timeout - usleepCount; + } + + return false; +} + +/// Feature: test failed to start a socket server. +/// Description: start a socket server with an invalid url. +/// Expectation: failed to start the server with invalid url. +TEST_F(TCPTest, StartServerFail) { + std::unique_ptr io = std::make_unique(); + io->Initialize(); + + bool ret = io->StartServerSocket("tcp://0:2225", "tcp://0:2225"); + ASSERT_FALSE(ret); + io->Finalize(); +} + +/// Feature: test start a socket server. +/// Description: start the socket server with a specified socket. +/// Expectation: the socket server is started successfully. +TEST_F(TCPTest, StartServer2) { + std::unique_ptr io = std::make_unique(); + io->Initialize(); + io->SetMessageHandler(msgHandle); + bool ret = io->StartServerSocket("tcp://" + m_localIP + ":2225", "tcp://" + m_localIP + ":2225"); + ASSERT_FALSE(ret); + ret = io->StartServerSocket("tcp://" + m_localIP + ":2224", "tcp://" + m_localIP + ":2224"); + io->Finalize(); + ASSERT_TRUE(ret); +} + +/// Feature: test normal tcp message sending. +/// Description: start a socket server and send a normal message to it. +/// Expectation: the server received the message sented from client. +TEST_F(TCPTest, send1Msg) { + g_recv_num = 0; + pid1 = startTcpServer(args); + bool ret = CheckRecvNum(1, 5); + ASSERT_FALSE(ret); + + std::string from = "tcp://" + m_localIP + ":2223"; + std::string to = "tcp://" + m_localIP + ":2225"; + SendMsg(from, to, pid_num); + + ret = CheckRecvNum(1, 5); + ASSERT_TRUE(ret); + + Unlink(to); + shutdownTcpServer(pid1); + pid1 = 0; +} + +/// Feature: test max message body check.. +/// Description: send a message with body exceeding the limit of max message body. +/// Expectation: drop the invalid message. +TEST_F(TCPTest, sendInvalidMsg) { + g_recv_num = 0; + pid1 = startTcpServer(args); + bool ret = CheckRecvNum(1, 5); + ASSERT_FALSE(ret); + + std::string from = "tcp://" + m_localIP + ":2223"; + std::string to = "tcp://" + m_localIP + ":2225"; + SendMsg(from, to, 1024 * 1024 * pid_num + 1); + ret = CheckRecvNum(1, 5); + ASSERT_FALSE(ret); + + Unlink(to); + shutdownTcpServer(pid1); + pid1 = 0; +} +} // namespace rpc +} // namespace distributed +} // namespace mindspore