!7518 Added tcp server based on libevent

Merge pull request !7518 from anancds/tcp-server
This commit is contained in:
mindspore-ci-bot 2020-10-23 09:20:21 +08:00 committed by Gitee
commit 80725441b6
15 changed files with 1026 additions and 62 deletions

View File

@ -7,6 +7,10 @@ if (NOT (ENABLE_CPU AND (ENABLE_D OR ENABLE_GPU)))
list(REMOVE_ITEM _PS_SRC_FILES "util.cc")
list(REMOVE_ITEM _PS_SRC_FILES "comm/http_message_handler.cc")
list(REMOVE_ITEM _PS_SRC_FILES "comm/http_server.cc")
list(REMOVE_ITEM _PS_SRC_FILES "comm/comm_util.cc")
list(REMOVE_ITEM _PS_SRC_FILES "comm/tcp_client.cc")
list(REMOVE_ITEM _PS_SRC_FILES "comm/tcp_message_handler.cc")
list(REMOVE_ITEM _PS_SRC_FILES "comm/tcp_server.cc")
endif()
set_property(SOURCE ${_PS_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_PS)

View File

@ -0,0 +1,50 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ps/comm/comm_util.h"
#include <arpa/inet.h>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <functional>
#include <regex>
namespace mindspore {
namespace ps {
namespace comm {
bool CommUtil::CheckIpWithRegex(const std::string &ip) {
std::regex pattern("((25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?).){3}(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)");
std::smatch res;
if (regex_match(ip, res, pattern)) {
return true;
}
return false;
}
void CommUtil::CheckIp(const std::string &ip) {
if (!CheckIpWithRegex(ip)) {
MS_LOG(EXCEPTION) << "Server address" << ip << " illegal!";
}
int64_t uAddr = inet_addr(ip.c_str());
if (INADDR_NONE == uAddr) {
MS_LOG(EXCEPTION) << "Server address illegal, inet_addr converting failed!";
}
}
} // namespace comm
} // namespace ps
} // namespace mindspore

View File

@ -0,0 +1,49 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PS_COMM_COMM_UTIL_H_
#define MINDSPORE_CCSRC_PS_COMM_COMM_UTIL_H_
#include <event2/buffer.h>
#include <event2/event.h>
#include <event2/http.h>
#include <event2/keyvalq_struct.h>
#include <event2/listener.h>
#include <event2/util.h>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <functional>
#include <string>
#include <utility>
#include "utils/log_adapter.h"
namespace mindspore {
namespace ps {
namespace comm {
class CommUtil {
public:
static bool CheckIpWithRegex(const std::string &ip);
static void CheckIp(const std::string &ip);
};
} // namespace comm
} // namespace ps
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_COMM_COMM_UTIL_H_

View File

@ -38,7 +38,7 @@ namespace mindspore {
namespace ps {
namespace comm {
typedef std::map<std::string, std::list<std::string>> HttpHeaders;
using HttpHeaders = std::map<std::string, std::list<std::string>>;
class HttpMessageHandler {
public:

View File

@ -16,6 +16,7 @@
#include "ps/comm/http_server.h"
#include "ps/comm/http_message_handler.h"
#include "ps/comm/comm_util.h"
#ifdef WIN32
#include <WinSock2.h>
@ -41,28 +42,10 @@ namespace mindspore {
namespace ps {
namespace comm {
HttpServer::~HttpServer() {
if (event_http_) {
evhttp_free(event_http_);
event_http_ = nullptr;
}
if (event_base_) {
event_base_free(event_base_);
event_base_ = nullptr;
}
}
HttpServer::~HttpServer() { Stop(); }
bool HttpServer::InitServer() {
if (!CheckIp(server_address_)) {
MS_LOG(EXCEPTION) << "Server address" << server_address_ << " illegal!";
}
int64_t uAddr = inet_addr(server_address_.c_str());
if (INADDR_NONE == uAddr) {
MS_LOG(EXCEPTION) << "Server address illegal, inet_addr converting failed!";
}
if (server_port_ <= 0) {
MS_LOG(EXCEPTION) << "Server port:" << server_port_ << " illegal!";
}
CommUtil::CheckIp(server_address_);
event_base_ = event_base_new();
MS_EXCEPTION_IF_NULL(event_base_);
@ -76,15 +59,6 @@ bool HttpServer::InitServer() {
return true;
}
bool HttpServer::CheckIp(const std::string &ip) {
std::regex pattern("((25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?).){3}(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)");
std::smatch res;
if (regex_match(ip, res, pattern)) {
return true;
}
return false;
}
void HttpServer::SetTimeOut(int seconds) {
MS_EXCEPTION_IF_NULL(event_http_);
if (seconds < 0) {
@ -93,7 +67,7 @@ void HttpServer::SetTimeOut(int seconds) {
evhttp_set_timeout(event_http_, seconds);
}
void HttpServer::SetAllowedMethod(HttpMethodsSet methods) {
void HttpServer::SetAllowedMethod(u_int16_t methods) {
MS_EXCEPTION_IF_NULL(event_http_);
evhttp_set_allowed_methods(event_http_, methods);
}
@ -114,12 +88,11 @@ void HttpServer::SetMaxBodySize(size_t num) {
evhttp_set_max_body_size(event_http_, num);
}
bool HttpServer::RegisterRoute(const std::string &url, handle_t *function) {
bool HttpServer::RegisterRoute(const std::string &url, OnRequestReceive *function) {
if ((!is_init_) && (!InitServer())) {
MS_LOG(EXCEPTION) << "Init http server failed!";
}
HandlerFunc func = function;
if (!func) {
if (!function) {
return false;
}
@ -128,15 +101,13 @@ bool HttpServer::RegisterRoute(const std::string &url, handle_t *function) {
MS_EXCEPTION_IF_NULL(arg);
HttpMessageHandler httpReq(req);
httpReq.InitHttpMessage();
handle_t *f = reinterpret_cast<handle_t *>(arg);
f(&httpReq);
OnRequestReceive *func = reinterpret_cast<OnRequestReceive *>(arg);
(*func)(&httpReq);
};
handle_t **pph = func.target<handle_t *>();
MS_EXCEPTION_IF_NULL(pph);
MS_EXCEPTION_IF_NULL(event_http_);
// O SUCCESS,-1 ALREADY_EXIST,-2 FAILURE
int ret = evhttp_set_cb(event_http_, url.c_str(), TransFunc, reinterpret_cast<void *>(*pph));
int ret = evhttp_set_cb(event_http_, url.c_str(), TransFunc, reinterpret_cast<void *>(function));
if (ret == 0) {
MS_LOG(INFO) << "Ev http register handle of:" << url.c_str() << " success.";
} else if (ret == -1) {

View File

@ -48,26 +48,21 @@ typedef enum eHttpMethod {
HM_PATCH = 1 << 8
} HttpMethod;
typedef u_int16_t HttpMethodsSet;
typedef void(handle_t)(HttpMessageHandler *);
class HttpServer {
public:
// Server address only support IPV4 now, and should be in format of "x.x.x.x"
explicit HttpServer(const std::string &address, std::int16_t port)
explicit HttpServer(const std::string &address, std::uint16_t port)
: server_address_(address), server_port_(port), event_base_(nullptr), event_http_(nullptr), is_init_(false) {}
~HttpServer();
typedef std::function<handle_t> HandlerFunc;
using OnRequestReceive = std::function<void(HttpMessageHandler *)>;
bool InitServer();
static bool CheckIp(const std::string &ip);
void SetTimeOut(int seconds = 5);
// Default allowed methods: GET, POST, HEAD, PUT, DELETE
void SetAllowedMethod(HttpMethodsSet methods);
void SetAllowedMethod(u_int16_t methods);
// Default to ((((unsigned long long)0xffffffffUL) << 32) | 0xffffffffUL)
void SetMaxHeaderSize(std::size_t num);
@ -76,7 +71,7 @@ class HttpServer {
void SetMaxBodySize(std::size_t num);
// Return: true if success, false if failed, check log to find failure reason
bool RegisterRoute(const std::string &url, handle_t *func);
bool RegisterRoute(const std::string &url, OnRequestReceive *func);
bool UnRegisterRoute(const std::string &url);
bool Start();
@ -84,7 +79,7 @@ class HttpServer {
private:
std::string server_address_;
std::int16_t server_port_;
std::uint16_t server_port_;
struct event_base *event_base_;
struct evhttp *event_http_;
bool is_init_;

View File

@ -0,0 +1,220 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ps/comm/tcp_client.h"
#include <arpa/inet.h>
#include <event2/buffer.h>
#include <event2/bufferevent.h>
#include <event2/buffer_compat.h>
#include <event2/event.h>
#include <netinet/in.h>
#include <netinet/tcp.h>
#include <sys/socket.h>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <utility>
#include <string>
#include "ps/comm/comm_util.h"
namespace mindspore {
namespace ps {
namespace comm {
TcpClient::TcpClient(std::string address, std::uint16_t port)
: event_base_(nullptr),
event_timeout_(nullptr),
buffer_event_(nullptr),
server_address_(std::move(address)),
server_port_(port) {
message_handler_.SetCallback([this](const void *buf, size_t num) {
if (buf == nullptr) {
if (disconnected_callback_) disconnected_callback_(*this, 200);
Stop();
}
if (message_callback_) message_callback_(*this, buf, num);
});
}
TcpClient::~TcpClient() { Stop(); }
std::string TcpClient::GetServerAddress() const { return server_address_; }
void TcpClient::SetCallback(const OnConnected &conn, const OnDisconnected &disconn, const OnRead &read,
const OnTimeout &timeout) {
connected_callback_ = conn;
disconnected_callback_ = disconn;
read_callback_ = read;
timeout_callback_ = timeout;
}
void TcpClient::InitTcpClient() {
if (buffer_event_) {
return;
}
CommUtil::CheckIp(server_address_);
event_base_ = event_base_new();
MS_EXCEPTION_IF_NULL(event_base_);
sockaddr_in sin{};
if (memset_s(&sin, sizeof(sin), 0, sizeof(sin)) != EOK) {
MS_LOG(EXCEPTION) << "Initialize sockaddr_in failed!";
}
sin.sin_family = AF_INET;
sin.sin_addr.s_addr = inet_addr(server_address_.c_str());
sin.sin_port = htons(server_port_);
buffer_event_ = bufferevent_socket_new(event_base_, -1, BEV_OPT_CLOSE_ON_FREE);
MS_EXCEPTION_IF_NULL(buffer_event_);
bufferevent_setcb(buffer_event_, ReadCallback, nullptr, EventCallback, this);
if (bufferevent_enable(buffer_event_, EV_READ | EV_WRITE) == -1) {
MS_LOG(EXCEPTION) << "buffer event enable read and write failed!";
}
int result_code = bufferevent_socket_connect(buffer_event_, reinterpret_cast<struct sockaddr *>(&sin), sizeof(sin));
if (result_code < 0) {
MS_LOG(EXCEPTION) << "Connect server ip:" << server_address_ << " and port: " << server_port_ << " is failed!";
}
}
void TcpClient::StartWithDelay(int seconds) {
if (buffer_event_) {
return;
}
event_base_ = event_base_new();
timeval timeout_value{};
timeout_value.tv_sec = seconds;
timeout_value.tv_usec = 0;
event_timeout_ = evtimer_new(event_base_, TimeoutCallback, this);
if (evtimer_add(event_timeout_, &timeout_value) == -1) {
MS_LOG(EXCEPTION) << "event timeout failed!";
}
}
void TcpClient::Stop() {
if (buffer_event_) {
bufferevent_free(buffer_event_);
buffer_event_ = nullptr;
}
if (event_timeout_) {
event_free(event_timeout_);
event_timeout_ = nullptr;
}
if (event_base_) {
event_base_free(event_base_);
event_base_ = nullptr;
}
}
void TcpClient::SetTcpNoDelay(const evutil_socket_t &fd) {
const int one = 1;
int ret = setsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &one, sizeof(int));
if (ret < 0) {
MS_LOG(EXCEPTION) << "Set socket no delay failed!";
}
}
void TcpClient::TimeoutCallback(evutil_socket_t, std::int16_t, void *arg) {
MS_EXCEPTION_IF_NULL(arg);
auto tcp_client = reinterpret_cast<TcpClient *>(arg);
tcp_client->InitTcpClient();
}
void TcpClient::ReadCallback(struct bufferevent *bev, void *ctx) {
MS_EXCEPTION_IF_NULL(bev);
MS_EXCEPTION_IF_NULL(ctx);
auto tcp_client = reinterpret_cast<TcpClient *>(ctx);
struct evbuffer *input = bufferevent_get_input(const_cast<struct bufferevent *>(bev));
MS_EXCEPTION_IF_NULL(input);
char read_buffer[4096];
int read = 0;
while ((read = EVBUFFER_LENGTH(input)) > 0) {
if (evbuffer_remove(input, &read_buffer, sizeof(read_buffer)) == -1) {
MS_LOG(EXCEPTION) << "Can not drain data from the event buffer!";
}
tcp_client->OnReadHandler(read_buffer, read);
}
}
void TcpClient::OnReadHandler(const void *buf, size_t num) {
MS_EXCEPTION_IF_NULL(buf);
if (read_callback_) {
read_callback_(*this, buf, num);
}
message_handler_.ReceiveMessage(buf, num);
}
void TcpClient::EventCallback(struct bufferevent *bev, std::int16_t events, void *ptr) {
MS_EXCEPTION_IF_NULL(bev);
MS_EXCEPTION_IF_NULL(ptr);
auto tcp_client = reinterpret_cast<TcpClient *>(ptr);
if (events & BEV_EVENT_CONNECTED) {
// Connected
if (tcp_client->connected_callback_) {
tcp_client->connected_callback_(*tcp_client);
}
evutil_socket_t fd = bufferevent_getfd(const_cast<struct bufferevent *>(bev));
SetTcpNoDelay(fd);
MS_LOG(INFO) << "Client connected!";
} else if (events & BEV_EVENT_ERROR) {
MS_LOG(ERROR) << "Client connected error!";
if (tcp_client->disconnected_callback_) {
tcp_client->disconnected_callback_(*tcp_client, errno);
}
} else if (events & BEV_EVENT_EOF) {
MS_LOG(ERROR) << "Client connected end of file";
if (tcp_client->disconnected_callback_) {
tcp_client->disconnected_callback_(*tcp_client, 0);
}
}
}
void TcpClient::Start() {
MS_EXCEPTION_IF_NULL(event_base_);
int ret = event_base_dispatch(event_base_);
if (ret == 0) {
MS_LOG(INFO) << "Event base dispatch success!";
} else if (ret == 1) {
MS_LOG(ERROR) << "Event base dispatch failed with no events pending or active!";
} else if (ret == -1) {
MS_LOG(ERROR) << "Event base dispatch failed with error occurred!";
} else {
MS_LOG(EXCEPTION) << "Event base dispatch with unexpect error code!";
}
}
void TcpClient::ReceiveMessage(const OnMessage &cb) { message_callback_ = cb; }
void TcpClient::SendMessage(const void *buf, size_t num) const {
MS_EXCEPTION_IF_NULL(buffer_event_);
if (evbuffer_add(bufferevent_get_output(buffer_event_), buf, num) == -1) {
MS_LOG(EXCEPTION) << "event buffer add failed!";
}
}
} // namespace comm
} // namespace ps
} // namespace mindspore

View File

@ -0,0 +1,77 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PS_COMM_TCP_CLIENT_H_
#define MINDSPORE_CCSRC_PS_COMM_TCP_CLIENT_H_
#include "ps/comm/tcp_message_handler.h"
#include <event2/event.h>
#include <event2/bufferevent.h>
#include <functional>
#include <string>
namespace mindspore {
namespace ps {
namespace comm {
class TcpClient {
public:
using OnMessage = std::function<void(const TcpClient &, const void *, size_t)>;
using OnConnected = std::function<void(const TcpClient &)>;
using OnDisconnected = std::function<void(const TcpClient &, int)>;
using OnRead = std::function<void(const TcpClient &, const void *, size_t)>;
using OnTimeout = std::function<void(const TcpClient &)>;
explicit TcpClient(std::string address, std::uint16_t port);
virtual ~TcpClient();
std::string GetServerAddress() const;
void SetCallback(const OnConnected &conn, const OnDisconnected &disconn, const OnRead &read,
const OnTimeout &timeout);
void InitTcpClient();
void StartWithDelay(int seconds);
void Stop();
void ReceiveMessage(const OnMessage &cb);
void SendMessage(const void *buf, size_t num) const;
void Start();
protected:
static void SetTcpNoDelay(const evutil_socket_t &fd);
static void TimeoutCallback(evutil_socket_t fd, std::int16_t what, void *arg);
static void ReadCallback(struct bufferevent *bev, void *ctx);
static void EventCallback(struct bufferevent *bev, std::int16_t events, void *ptr);
virtual void OnReadHandler(const void *buf, size_t num);
private:
TcpMessageHandler message_handler_;
OnMessage message_callback_;
OnConnected connected_callback_;
OnDisconnected disconnected_callback_;
OnRead read_callback_;
OnTimeout timeout_callback_;
event_base *event_base_;
event *event_timeout_;
bufferevent *buffer_event_;
std::string server_address_;
std::uint16_t server_port_;
};
} // namespace comm
} // namespace ps
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_COMM_TCP_CLIENT_H_

View File

@ -0,0 +1,36 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ps/comm/tcp_message_handler.h"
#include <iostream>
#include <utility>
namespace mindspore {
namespace ps {
namespace comm {
void TcpMessageHandler::SetCallback(messageReceive message_receive) { message_callback_ = std::move(message_receive); }
void TcpMessageHandler::ReceiveMessage(const void *buffer, size_t num) {
MS_EXCEPTION_IF_NULL(buffer);
if (message_callback_) {
message_callback_(buffer, num);
}
}
} // namespace comm
} // namespace ps
} // namespace mindspore

View File

@ -0,0 +1,47 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PS_COMM_TCP_MESSAGE_HANDLER_H_
#define MINDSPORE_CCSRC_PS_COMM_TCP_MESSAGE_HANDLER_H_
#include <functional>
#include <iostream>
#include <memory>
#include "utils/log_adapter.h"
namespace mindspore {
namespace ps {
namespace comm {
using messageReceive = std::function<void(const void *buffer, size_t len)>;
class TcpMessageHandler {
public:
TcpMessageHandler() = default;
virtual ~TcpMessageHandler() = default;
void SetCallback(messageReceive cb);
void ReceiveMessage(const void *buffer, size_t num);
private:
messageReceive message_callback_;
};
} // namespace comm
} // namespace ps
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_COMM_TCP_MESSAGE_HANDLER_H_

View File

@ -0,0 +1,259 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ps/comm/tcp_server.h"
#include <arpa/inet.h>
#include <event2/buffer.h>
#include <event2/bufferevent.h>
#include <event2/event.h>
#include <event2/listener.h>
#include <event2/buffer_compat.h>
#include <event2/util.h>
#include <sys/socket.h>
#include <csignal>
#include <utility>
#include "ps/comm/comm_util.h"
namespace mindspore {
namespace ps {
namespace comm {
void TcpConnection::InitConnection(const evutil_socket_t &fd, const struct bufferevent *bev, const TcpServer *server) {
MS_EXCEPTION_IF_NULL(bev);
MS_EXCEPTION_IF_NULL(server);
buffer_event_ = const_cast<struct bufferevent *>(bev);
fd_ = fd;
server_ = const_cast<TcpServer *>(server);
tcp_message_handler_.SetCallback([this, server](const void *buf, size_t num) {
OnServerReceiveMessage message_callback = server->GetServerReceiveMessage();
if (message_callback) message_callback(*server, *this, buf, num);
});
}
void TcpConnection::OnReadHandler(const void *buffer, size_t num) { tcp_message_handler_.ReceiveMessage(buffer, num); }
void TcpConnection::SendMessage(const void *buffer, size_t num) const {
if (bufferevent_write(buffer_event_, buffer, num) == -1) {
MS_LOG(ERROR) << "Write message to buffer event failed!";
}
}
TcpServer *TcpConnection::GetServer() const { return server_; }
evutil_socket_t TcpConnection::GetFd() const { return fd_; }
TcpServer::TcpServer(std::string address, std::uint16_t port)
: base_(nullptr),
signal_event_(nullptr),
listener_(nullptr),
server_address_(std::move(address)),
server_port_(port) {}
TcpServer::~TcpServer() { Stop(); }
void TcpServer::SetServerCallback(const OnConnected &client_conn, const OnDisconnected &client_disconn,
const OnAccepted &client_accept) {
this->client_connection_ = client_conn;
this->client_disconnection_ = client_disconn;
this->client_accept_ = client_accept;
}
void TcpServer::InitServer() {
base_ = event_base_new();
MS_EXCEPTION_IF_NULL(base_);
CommUtil::CheckIp(server_address_);
struct sockaddr_in sin {};
if (memset_s(&sin, sizeof(sin), 0, sizeof(sin)) != EOK) {
MS_LOG(EXCEPTION) << "Initialize sockaddr_in failed!";
}
sin.sin_family = AF_INET;
sin.sin_port = htons(server_port_);
sin.sin_addr.s_addr = inet_addr(server_address_.c_str());
listener_ = evconnlistener_new_bind(base_, ListenerCallback, reinterpret_cast<void *>(this),
LEV_OPT_REUSEABLE | LEV_OPT_CLOSE_ON_FREE, -1,
reinterpret_cast<struct sockaddr *>(&sin), sizeof(sin));
MS_EXCEPTION_IF_NULL(listener_);
signal_event_ = evsignal_new(base_, SIGINT, SignalCallback, reinterpret_cast<void *>(this));
MS_EXCEPTION_IF_NULL(signal_event_);
if (event_add(signal_event_, nullptr) < 0) {
MS_LOG(EXCEPTION) << "Cannot create signal event.";
}
}
void TcpServer::Start() {
std::unique_lock<std::recursive_mutex> l(connection_mutex_);
MS_EXCEPTION_IF_NULL(base_);
int ret = event_base_dispatch(base_);
if (ret == 0) {
MS_LOG(INFO) << "Event base dispatch success!";
} else if (ret == 1) {
MS_LOG(ERROR) << "Event base dispatch failed with no events pending or active!";
} else if (ret == -1) {
MS_LOG(ERROR) << "Event base dispatch failed with error occurred!";
} else {
MS_LOG(EXCEPTION) << "Event base dispatch with unexpect error code!";
}
}
void TcpServer::Stop() {
if (signal_event_ != nullptr) {
event_free(signal_event_);
signal_event_ = nullptr;
}
if (listener_ != nullptr) {
evconnlistener_free(listener_);
listener_ = nullptr;
}
if (base_ != nullptr) {
event_base_free(base_);
base_ = nullptr;
}
}
void TcpServer::SendToAllClients(const char *data, size_t len) {
MS_EXCEPTION_IF_NULL(data);
std::unique_lock<std::recursive_mutex> lock(connection_mutex_);
for (auto it = connections_.begin(); it != connections_.end(); ++it) {
it->second->SendMessage(data, len);
}
}
void TcpServer::AddConnection(const evutil_socket_t &fd, const TcpConnection *connection) {
MS_EXCEPTION_IF_NULL(connection);
std::unique_lock<std::recursive_mutex> lock(connection_mutex_);
connections_.insert(std::make_pair(fd, connection));
}
void TcpServer::RemoveConnection(const evutil_socket_t &fd) {
std::unique_lock<std::recursive_mutex> lock(connection_mutex_);
connections_.erase(fd);
}
void TcpServer::ListenerCallback(struct evconnlistener *, evutil_socket_t fd, struct sockaddr *, int, void *data) {
auto server = reinterpret_cast<class TcpServer *>(data);
auto base = reinterpret_cast<struct event_base *>(server->base_);
MS_EXCEPTION_IF_NULL(server);
MS_EXCEPTION_IF_NULL(base);
struct bufferevent *bev = bufferevent_socket_new(base, fd, BEV_OPT_CLOSE_ON_FREE);
if (!bev) {
MS_LOG(ERROR) << "Error constructing buffer event!";
event_base_loopbreak(base);
return;
}
TcpConnection *conn = server->onCreateConnection();
MS_EXCEPTION_IF_NULL(conn);
conn->InitConnection(fd, bev, server);
server->AddConnection(fd, conn);
bufferevent_setcb(bev, TcpServer::ReadCallback, nullptr, TcpServer::EventCallback, reinterpret_cast<void *>(conn));
if (bufferevent_enable(bev, EV_READ | EV_WRITE) == -1) {
MS_LOG(EXCEPTION) << "buffer event enable read and write failed!";
}
}
TcpConnection *TcpServer::onCreateConnection() {
TcpConnection *conn = nullptr;
if (client_accept_)
conn = const_cast<TcpConnection *>(client_accept_(this));
else
conn = new TcpConnection();
return conn;
}
OnServerReceiveMessage TcpServer::GetServerReceiveMessage() const { return message_callback_; }
void TcpServer::SignalCallback(evutil_socket_t, std::int16_t, void *data) {
auto server = reinterpret_cast<class TcpServer *>(data);
MS_EXCEPTION_IF_NULL(server);
struct event_base *base = server->base_;
struct timeval delay = {0, 0};
MS_LOG(ERROR) << "Caught an interrupt signal; exiting cleanly in 0 seconds.";
if (event_base_loopexit(base, &delay) == -1) {
MS_LOG(EXCEPTION) << "event base loop exit failed.";
}
}
void TcpServer::ReadCallback(struct bufferevent *bev, void *connection) {
MS_EXCEPTION_IF_NULL(bev);
MS_EXCEPTION_IF_NULL(connection);
auto conn = static_cast<class TcpConnection *>(connection);
struct evbuffer *buf = bufferevent_get_input(bev);
char read_buffer[4096];
auto read = 0;
while ((read = EVBUFFER_LENGTH(buf)) > 0) {
if (evbuffer_remove(buf, &read_buffer, sizeof(read_buffer)) == -1) {
MS_LOG(EXCEPTION) << "Can not drain data from the event buffer!";
}
conn->OnReadHandler(read_buffer, static_cast<size_t>(read));
}
}
void TcpServer::EventCallback(struct bufferevent *bev, std::int16_t events, void *data) {
MS_EXCEPTION_IF_NULL(bev);
MS_EXCEPTION_IF_NULL(data);
auto conn = reinterpret_cast<TcpConnection *>(data);
TcpServer *srv = conn->GetServer();
if (events & BEV_EVENT_EOF) {
// Notify about disconnection
if (srv->client_disconnection_) srv->client_disconnection_(srv, conn);
// Free connection structures
srv->RemoveConnection(conn->GetFd());
bufferevent_free(bev);
} else if (events & BEV_EVENT_ERROR) {
// Free connection structures
srv->RemoveConnection(conn->GetFd());
bufferevent_free(bev);
// Notify about disconnection
if (srv->client_disconnection_) srv->client_disconnection_(srv, conn);
} else {
MS_LOG(ERROR) << "unhandled event!";
}
}
void TcpServer::ReceiveMessage(const OnServerReceiveMessage &cb) { message_callback_ = cb; }
void TcpServer::SendMessage(const TcpConnection &conn, const void *data, size_t num) {
MS_EXCEPTION_IF_NULL(data);
auto mc = const_cast<TcpConnection &>(conn);
mc.SendMessage(data, num);
}
void TcpServer::SendMessage(const void *data, size_t num) {
MS_EXCEPTION_IF_NULL(data);
std::unique_lock<std::recursive_mutex> lock(connection_mutex_);
for (auto it = connections_.begin(); it != connections_.end(); ++it) {
SendMessage(*it->second, data, num);
}
}
} // namespace comm
} // namespace ps
} // namespace mindspore

View File

@ -0,0 +1,107 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PS_COMM_TCP_SERVER_H_
#define MINDSPORE_CCSRC_PS_COMM_TCP_SERVER_H_
#include <event2/buffer.h>
#include <event2/bufferevent.h>
#include <event2/event.h>
#include <event2/listener.h>
#include <exception>
#include <functional>
#include <iostream>
#include <map>
#include <mutex>
#include <string>
#include "utils/log_adapter.h"
#include "ps/comm/tcp_message_handler.h"
namespace mindspore {
namespace ps {
namespace comm {
class TcpServer;
class TcpConnection {
public:
TcpConnection() : buffer_event_(nullptr), fd_(0), server_(nullptr) {}
virtual ~TcpConnection() = default;
virtual void InitConnection(const evutil_socket_t &fd, const struct bufferevent *bev, const TcpServer *server);
void SendMessage(const void *buffer, size_t num) const;
virtual void OnReadHandler(const void *buffer, size_t numBytes);
TcpServer *GetServer() const;
evutil_socket_t GetFd() const;
protected:
TcpMessageHandler tcp_message_handler_;
struct bufferevent *buffer_event_;
evutil_socket_t fd_;
TcpServer *server_;
};
using OnServerReceiveMessage =
std::function<void(const TcpServer &tcp_server, const TcpConnection &conn, const void *buffer, size_t num)>;
class TcpServer {
public:
using OnConnected = std::function<void(const TcpServer *, const TcpConnection *)>;
using OnDisconnected = std::function<void(const TcpServer *, const TcpConnection *)>;
using OnAccepted = std::function<const TcpConnection *(const TcpServer *)>;
explicit TcpServer(std::string address, std::uint16_t port);
virtual ~TcpServer();
void SetServerCallback(const OnConnected &client_conn, const OnDisconnected &client_disconn,
const OnAccepted &client_accept);
void InitServer();
void Start();
void Stop();
void SendToAllClients(const char *data, size_t len);
void AddConnection(const evutil_socket_t &fd, const TcpConnection *connection);
void RemoveConnection(const evutil_socket_t &fd);
void ReceiveMessage(const OnServerReceiveMessage &cb);
static void SendMessage(const TcpConnection &conn, const void *data, size_t num);
void SendMessage(const void *data, size_t num);
OnServerReceiveMessage GetServerReceiveMessage() const;
protected:
static void ListenerCallback(struct evconnlistener *listener, evutil_socket_t socket, struct sockaddr *saddr,
int socklen, void *server);
static void SignalCallback(evutil_socket_t sig, std::int16_t events, void *server);
static void ReadCallback(struct bufferevent *, void *connection);
static void EventCallback(struct bufferevent *, std::int16_t events, void *server);
virtual TcpConnection *onCreateConnection();
private:
struct event_base *base_;
struct event *signal_event_;
struct evconnlistener *listener_;
std::string server_address_;
std::uint16_t server_port_;
std::map<evutil_socket_t, const TcpConnection *> connections_;
OnConnected client_connection_;
OnDisconnected client_disconnection_;
OnAccepted client_accept_;
std::recursive_mutex connection_mutex_;
OnServerReceiveMessage message_callback_;
};
} // namespace comm
} // namespace ps
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_COMM_TCP_SERVER_H_

View File

@ -31,7 +31,7 @@ namespace comm {
class TestHttpServer : public UT::Common {
public:
TestHttpServer() {}
TestHttpServer() = default;
static void testGetHandler(HttpMessageHandler *resp) {
std::string host = resp->GetRequestHost();
@ -58,16 +58,44 @@ class TestHttpServer : public UT::Common {
void SetUp() override {
server_ = new HttpServer("0.0.0.0", 9999);
server_->RegisterRoute("/httpget", [](HttpMessageHandler *resp) {
EXPECT_STREQ(resp->GetPathParam("key1").c_str(), "value1");
EXPECT_STREQ(resp->GetUriQuery().c_str(), "key1=value1");
EXPECT_STREQ(resp->GetRequestUri().c_str(), "/httpget?key1=value1");
EXPECT_STREQ(resp->GetUriPath().c_str(), "/httpget");
resp->QuickResponse(200, "get request success!\n");
});
server_->RegisterRoute("/handler", TestHttpServer::testGetHandler);
std::function<void(HttpMessageHandler *)> http_get_func = std::bind(
[](HttpMessageHandler *resp) {
EXPECT_STREQ(resp->GetPathParam("key1").c_str(), "value1");
EXPECT_STREQ(resp->GetUriQuery().c_str(), "key1=value1");
EXPECT_STREQ(resp->GetRequestUri().c_str(), "/httpget?key1=value1");
EXPECT_STREQ(resp->GetUriPath().c_str(), "/httpget");
resp->QuickResponse(200, "get request success!\n");
},
std::placeholders::_1);
std::function<void(HttpMessageHandler *)> http_handler_func = std::bind(
[](HttpMessageHandler *resp) {
std::string host = resp->GetRequestHost();
EXPECT_STREQ(host.c_str(), "127.0.0.1");
std::string path_param = resp->GetPathParam("key1");
std::string header_param = resp->GetHeadParam("headerKey");
std::string post_param = resp->GetPostParam("postKey");
std::string post_message = resp->GetPostMsg();
EXPECT_STREQ(path_param.c_str(), "value1");
EXPECT_STREQ(header_param.c_str(), "headerValue");
EXPECT_STREQ(post_param.c_str(), "postValue");
EXPECT_STREQ(post_message.c_str(), "postKey=postValue");
const std::string rKey("headKey");
const std::string rVal("headValue");
const std::string rBody("post request success!\n");
resp->AddRespHeadParam(rKey, rVal);
resp->AddRespString(rBody);
resp->SetRespCode(200);
resp->SendResponse();
},
std::placeholders::_1);
server_->RegisterRoute("/httpget", &http_get_func);
server_->RegisterRoute("/handler", &http_handler_func);
std::unique_ptr<std::thread> http_server_thread_(nullptr);
http_server_thread_.reset(new std::thread([&]() { server_->Start(); }));
http_server_thread_ = std::make_unique<std::thread>([&]() { server_->Start(); });
http_server_thread_->detach();
}
@ -110,14 +138,18 @@ TEST_F(TestHttpServer, messageHandler) {
pclose(file);
}
TEST_F(TestHttpServer, portException) {
TEST_F(TestHttpServer, portErrorNoException) {
HttpServer *server_exception = new HttpServer("0.0.0.0", -1);
ASSERT_THROW(server_exception->RegisterRoute("/handler", TestHttpServer::testGetHandler), std::exception);
std::function<void(HttpMessageHandler *)> http_handler_func =
std::bind(TestHttpServer::testGetHandler, std::placeholders::_1);
EXPECT_NO_THROW(server_exception->RegisterRoute("/handler", &http_handler_func));
}
TEST_F(TestHttpServer, addressException) {
HttpServer *server_exception = new HttpServer("12344.0.0.0", 9998);
ASSERT_THROW(server_exception->RegisterRoute("/handler", TestHttpServer::testGetHandler), std::exception);
std::function<void(HttpMessageHandler *)> http_handler_func =
std::bind(TestHttpServer::testGetHandler, std::placeholders::_1);
ASSERT_THROW(server_exception->RegisterRoute("/handler", &http_handler_func), std::exception);
}
} // namespace comm

View File

@ -0,0 +1,46 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "common/common_test.h"
#include "ps/comm/tcp_client.h"
namespace mindspore {
namespace ps {
namespace comm {
class TestTcpClient : public UT::Common {
public:
TestTcpClient() = default;
};
TEST_F(TestTcpClient, InitClientIPError) {
auto client = new TcpClient("127.0.0.13543", 9000);
client->ReceiveMessage(
[](const TcpClient &client, const void *buffer, size_t num) { client.SendMessage(buffer, num); });
ASSERT_THROW(client->InitTcpClient(), std::exception);
}
TEST_F(TestTcpClient, InitClientPortErrorNoException) {
auto client = new TcpClient("127.0.0.1", -1);
client->ReceiveMessage(
[](const TcpClient &client, const void *buffer, size_t num) { client.SendMessage(buffer, num); });
EXPECT_NO_THROW(client->InitTcpClient());
}
} // namespace comm
} // namespace ps
} // namespace mindspore

View File

@ -0,0 +1,71 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ps/comm/tcp_client.h"
#include "ps/comm/tcp_server.h"
#include "common/common_test.h"
#include <thread>
namespace mindspore {
namespace ps {
namespace comm {
class TestTcpServer : public UT::Common {
public:
TestTcpServer() = default;
void SetUp() override {
server_ = new TcpServer("127.0.0.1", 9000);
std::unique_ptr<std::thread> http_server_thread_(nullptr);
http_server_thread_ = std::make_unique<std::thread>([&]() {
server_->ReceiveMessage([](const TcpServer &server, const TcpConnection &conn, const void *buffer, size_t num) {
EXPECT_STREQ(std::string(reinterpret_cast<const char *>(buffer), num).c_str(), "TCP_MESSAGE");
server.SendMessage(conn, buffer, num);
});
server_->InitServer();
server_->Start();
});
http_server_thread_->detach();
std::this_thread::sleep_for(std::chrono::milliseconds(2000));
}
void TearDown() override {
std::this_thread::sleep_for(std::chrono::milliseconds(2000));
client_->Stop();
std::this_thread::sleep_for(std::chrono::milliseconds(2000));
server_->Stop();
}
TcpClient *client_;
TcpServer *server_;
const std::string test_message_ = "TCP_MESSAGE";
};
TEST_F(TestTcpServer, ServerSendMessage) {
client_ = new TcpClient("127.0.0.1", 9000);
std::unique_ptr<std::thread> http_client_thread(nullptr);
http_client_thread = std::make_unique<std::thread>([&]() {
client_->ReceiveMessage([](const TcpClient &client, const void *buffer, size_t num) {
EXPECT_STREQ(std::string(reinterpret_cast<const char *>(buffer), num).c_str(), "TCP_MESSAGE");
});
client_->InitTcpClient();
client_->SendMessage(test_message_.c_str(), test_message_.size());
client_->Start();
});
http_client_thread->detach();
}
} // namespace comm
} // namespace ps
} // namespace mindspore