add random port support for tcp server

This commit is contained in:
Parallels 2022-03-02 10:24:22 +08:00
parent bd0d719fdf
commit 5e09380eb3
7 changed files with 110 additions and 15 deletions

View File

@ -16,6 +16,9 @@
#include "distributed/rpc/tcp/socket_operation.h"
#include <sys/ioctl.h>
#include <net/if.h>
#include <ifaddrs.h>
#include <arpa/inet.h>
#include <securec.h>
#include <netinet/tcp.h>
@ -111,6 +114,38 @@ int SocketOperation::CreateSocket(sa_family_t family) {
return fd;
}
std::string SocketOperation::GetLocalIP() {
// Lookup all the network interfaces on the local machine.
struct ifaddrs *if_addrs;
if (getifaddrs(&if_addrs) != 0) {
MS_LOG(ERROR) << "Failed to lookup local network interfaces.";
freeifaddrs(if_addrs);
return "";
}
// Find the first physical network interface.
struct ifaddrs *if_addr = if_addrs;
MS_EXCEPTION_IF_NULL(if_addr);
while (if_addr != nullptr) {
if (if_addr->ifa_addr == nullptr) continue;
if (if_addr->ifa_addr->sa_family == AF_INET && !(if_addr->ifa_flags & IFF_LOOPBACK)) {
auto sock_addr = reinterpret_cast<struct sockaddr_in *>(if_addr->ifa_addr);
MS_EXCEPTION_IF_NULL(sock_addr);
auto ip_addr = inet_ntoa(sock_addr->sin_addr);
MS_EXCEPTION_IF_NULL(ip_addr);
std::string ip(ip_addr, ip_addr + strlen(ip_addr));
freeifaddrs(if_addrs);
return ip;
} else {
if_addr = if_addr->ifa_next;
}
}
freeifaddrs(if_addrs);
return "";
}
std::string SocketOperation::GetIP(const std::string &url) {
size_t index1 = url.find("[");
if (index1 == std::string::npos) {

View File

@ -39,6 +39,9 @@ class SocketOperation {
SocketOperation() = default;
virtual ~SocketOperation() {}
// Lookup the local IP address of the first available network interface.
static std::string GetLocalIP();
static std::string GetIP(const std::string &url);
static uint16_t GetPort(int sock_fd);

View File

@ -234,6 +234,15 @@ bool TCPComm::StartServerSocket(const std::string &url) {
return true;
}
bool TCPComm::StartServerSocket() {
auto ip = SocketOperation::GetLocalIP();
// The port 0 means that the port will be allocated randomly by the os system.
auto url = ip + ":0";
return StartServerSocket(url);
}
int TCPComm::GetServerFd() { return server_fd_; }
void TCPComm::ReadCallBack(void *connection) {
const int max_recv_count = 3;
Connection *conn = reinterpret_cast<Connection *>(connection);

View File

@ -56,6 +56,9 @@ class TCPComm {
// Create the server socket represented by url.
bool StartServerSocket(const std::string &url);
// Create the server socket with local IP and random port.
bool StartServerSocket();
// Connection operation for a specified destination.
void Connect(const std::string &dst_url);
bool IsConnected(const std::string &dst_url);
@ -67,6 +70,9 @@ class TCPComm {
// Set the message processing handler.
void SetMessageHandler(MessageHandler handler);
// Get the file descriptor of server socket.
int GetServerFd();
private:
// Build the connection.
Connection *CreateDefaultConn(std::string to);

View File

@ -19,20 +19,9 @@
namespace mindspore {
namespace distributed {
namespace rpc {
bool TCPServer::Initialize(const std::string &url) {
if (tcp_comm_ == nullptr) {
tcp_comm_ = std::make_unique<TCPComm>();
MS_EXCEPTION_IF_NULL(tcp_comm_);
bool rt = tcp_comm_->Initialize();
if (!rt) {
MS_LOG(EXCEPTION) << "Failed to initialize tcp comm";
}
rt = tcp_comm_->StartServerSocket(url);
return rt;
} else {
return true;
}
}
bool TCPServer::Initialize(const std::string &url) { return InitializeImpl(url); }
bool TCPServer::Initialize() { return InitializeImpl(""); }
void TCPServer::Finalize() {
if (tcp_comm_ != nullptr) {
@ -42,6 +31,34 @@ void TCPServer::Finalize() {
}
void TCPServer::SetMessageHandler(MessageHandler handler) { tcp_comm_->SetMessageHandler(handler); }
std::string TCPServer::GetIP() { return ip_; }
uint32_t TCPServer::GetPort() { return port_; }
bool TCPServer::InitializeImpl(const std::string &url) {
if (tcp_comm_ == nullptr) {
tcp_comm_ = std::make_unique<TCPComm>();
MS_EXCEPTION_IF_NULL(tcp_comm_);
bool rt = tcp_comm_->Initialize();
if (!rt) {
MS_LOG(EXCEPTION) << "Failed to initialize tcp comm";
}
if (url != "") {
rt = tcp_comm_->StartServerSocket(url);
ip_ = SocketOperation::GetIP(url);
} else {
rt = tcp_comm_->StartServerSocket();
ip_ = SocketOperation::GetLocalIP();
}
auto server_fd = tcp_comm_->GetServerFd();
port_ = SocketOperation::GetPort(server_fd);
return rt;
} else {
return true;
}
}
} // namespace rpc
} // namespace distributed
} // namespace mindspore

View File

@ -31,19 +31,31 @@ class TCPServer {
TCPServer() = default;
~TCPServer() = default;
// Init the tcp server.
// Init the tcp server using the specified url.
bool Initialize(const std::string &url);
// Init the tcp server using local IP and random port.
bool Initialize();
// Destroy the tcp server.
void Finalize();
// Set the message processing handler.
void SetMessageHandler(MessageHandler handler);
// Return the IP and port binded by this server.
std::string GetIP();
uint32_t GetPort();
private:
bool InitializeImpl(const std::string &url);
// The basic TCP communication component used by the server.
std::unique_ptr<TCPComm> tcp_comm_;
std::string ip_{""};
uint32_t port_{0};
DISABLE_COPY_AND_ASSIGN(TCPServer);
};
} // namespace rpc

View File

@ -217,6 +217,19 @@ TEST_F(TCPTest, sendTwoMessages) {
client->Finalize();
server->Finalize();
}
/// Feature: test start the tcp server with random port.
/// Description: start a socket server without specified fixed port.
/// Expectation: the server started successfully.
TEST_F(TCPTest, StartServerWithRandomPort) {
std::unique_ptr<TCPServer> server = std::make_unique<TCPServer>();
bool ret = server->Initialize();
ASSERT_TRUE(ret);
auto port = server->GetPort();
EXPECT_LT(0, port);
server->Finalize();
}
} // namespace rpc
} // namespace distributed
} // namespace mindspore