forked from mindspore-Ecosystem/mindspore
!7518 Added tcp server based on libevent
Merge pull request !7518 from anancds/tcp-server
This commit is contained in:
commit
80725441b6
|
@ -7,6 +7,10 @@ if (NOT (ENABLE_CPU AND (ENABLE_D OR ENABLE_GPU)))
|
|||
list(REMOVE_ITEM _PS_SRC_FILES "util.cc")
|
||||
list(REMOVE_ITEM _PS_SRC_FILES "comm/http_message_handler.cc")
|
||||
list(REMOVE_ITEM _PS_SRC_FILES "comm/http_server.cc")
|
||||
list(REMOVE_ITEM _PS_SRC_FILES "comm/comm_util.cc")
|
||||
list(REMOVE_ITEM _PS_SRC_FILES "comm/tcp_client.cc")
|
||||
list(REMOVE_ITEM _PS_SRC_FILES "comm/tcp_message_handler.cc")
|
||||
list(REMOVE_ITEM _PS_SRC_FILES "comm/tcp_server.cc")
|
||||
endif()
|
||||
|
||||
set_property(SOURCE ${_PS_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_PS)
|
||||
|
|
|
@ -0,0 +1,50 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "ps/comm/comm_util.h"
|
||||
|
||||
#include <arpa/inet.h>
|
||||
#include <cstdio>
|
||||
#include <cstdlib>
|
||||
#include <cstring>
|
||||
#include <functional>
|
||||
#include <regex>
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
namespace comm {
|
||||
|
||||
bool CommUtil::CheckIpWithRegex(const std::string &ip) {
|
||||
std::regex pattern("((25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?).){3}(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)");
|
||||
std::smatch res;
|
||||
if (regex_match(ip, res, pattern)) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
void CommUtil::CheckIp(const std::string &ip) {
|
||||
if (!CheckIpWithRegex(ip)) {
|
||||
MS_LOG(EXCEPTION) << "Server address" << ip << " illegal!";
|
||||
}
|
||||
int64_t uAddr = inet_addr(ip.c_str());
|
||||
if (INADDR_NONE == uAddr) {
|
||||
MS_LOG(EXCEPTION) << "Server address illegal, inet_addr converting failed!";
|
||||
}
|
||||
}
|
||||
} // namespace comm
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,49 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_PS_COMM_COMM_UTIL_H_
|
||||
#define MINDSPORE_CCSRC_PS_COMM_COMM_UTIL_H_
|
||||
|
||||
#include <event2/buffer.h>
|
||||
#include <event2/event.h>
|
||||
#include <event2/http.h>
|
||||
#include <event2/keyvalq_struct.h>
|
||||
#include <event2/listener.h>
|
||||
#include <event2/util.h>
|
||||
|
||||
#include <cstdio>
|
||||
#include <cstdlib>
|
||||
#include <cstring>
|
||||
#include <functional>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
|
||||
#include "utils/log_adapter.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
namespace comm {
|
||||
|
||||
class CommUtil {
|
||||
public:
|
||||
static bool CheckIpWithRegex(const std::string &ip);
|
||||
static void CheckIp(const std::string &ip);
|
||||
};
|
||||
} // namespace comm
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_PS_COMM_COMM_UTIL_H_
|
|
@ -38,7 +38,7 @@ namespace mindspore {
|
|||
namespace ps {
|
||||
namespace comm {
|
||||
|
||||
typedef std::map<std::string, std::list<std::string>> HttpHeaders;
|
||||
using HttpHeaders = std::map<std::string, std::list<std::string>>;
|
||||
|
||||
class HttpMessageHandler {
|
||||
public:
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
|
||||
#include "ps/comm/http_server.h"
|
||||
#include "ps/comm/http_message_handler.h"
|
||||
#include "ps/comm/comm_util.h"
|
||||
|
||||
#ifdef WIN32
|
||||
#include <WinSock2.h>
|
||||
|
@ -41,28 +42,10 @@ namespace mindspore {
|
|||
namespace ps {
|
||||
namespace comm {
|
||||
|
||||
HttpServer::~HttpServer() {
|
||||
if (event_http_) {
|
||||
evhttp_free(event_http_);
|
||||
event_http_ = nullptr;
|
||||
}
|
||||
if (event_base_) {
|
||||
event_base_free(event_base_);
|
||||
event_base_ = nullptr;
|
||||
}
|
||||
}
|
||||
HttpServer::~HttpServer() { Stop(); }
|
||||
|
||||
bool HttpServer::InitServer() {
|
||||
if (!CheckIp(server_address_)) {
|
||||
MS_LOG(EXCEPTION) << "Server address" << server_address_ << " illegal!";
|
||||
}
|
||||
int64_t uAddr = inet_addr(server_address_.c_str());
|
||||
if (INADDR_NONE == uAddr) {
|
||||
MS_LOG(EXCEPTION) << "Server address illegal, inet_addr converting failed!";
|
||||
}
|
||||
if (server_port_ <= 0) {
|
||||
MS_LOG(EXCEPTION) << "Server port:" << server_port_ << " illegal!";
|
||||
}
|
||||
CommUtil::CheckIp(server_address_);
|
||||
|
||||
event_base_ = event_base_new();
|
||||
MS_EXCEPTION_IF_NULL(event_base_);
|
||||
|
@ -76,15 +59,6 @@ bool HttpServer::InitServer() {
|
|||
return true;
|
||||
}
|
||||
|
||||
bool HttpServer::CheckIp(const std::string &ip) {
|
||||
std::regex pattern("((25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?).){3}(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)");
|
||||
std::smatch res;
|
||||
if (regex_match(ip, res, pattern)) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
void HttpServer::SetTimeOut(int seconds) {
|
||||
MS_EXCEPTION_IF_NULL(event_http_);
|
||||
if (seconds < 0) {
|
||||
|
@ -93,7 +67,7 @@ void HttpServer::SetTimeOut(int seconds) {
|
|||
evhttp_set_timeout(event_http_, seconds);
|
||||
}
|
||||
|
||||
void HttpServer::SetAllowedMethod(HttpMethodsSet methods) {
|
||||
void HttpServer::SetAllowedMethod(u_int16_t methods) {
|
||||
MS_EXCEPTION_IF_NULL(event_http_);
|
||||
evhttp_set_allowed_methods(event_http_, methods);
|
||||
}
|
||||
|
@ -114,12 +88,11 @@ void HttpServer::SetMaxBodySize(size_t num) {
|
|||
evhttp_set_max_body_size(event_http_, num);
|
||||
}
|
||||
|
||||
bool HttpServer::RegisterRoute(const std::string &url, handle_t *function) {
|
||||
bool HttpServer::RegisterRoute(const std::string &url, OnRequestReceive *function) {
|
||||
if ((!is_init_) && (!InitServer())) {
|
||||
MS_LOG(EXCEPTION) << "Init http server failed!";
|
||||
}
|
||||
HandlerFunc func = function;
|
||||
if (!func) {
|
||||
if (!function) {
|
||||
return false;
|
||||
}
|
||||
|
||||
|
@ -128,15 +101,13 @@ bool HttpServer::RegisterRoute(const std::string &url, handle_t *function) {
|
|||
MS_EXCEPTION_IF_NULL(arg);
|
||||
HttpMessageHandler httpReq(req);
|
||||
httpReq.InitHttpMessage();
|
||||
handle_t *f = reinterpret_cast<handle_t *>(arg);
|
||||
f(&httpReq);
|
||||
OnRequestReceive *func = reinterpret_cast<OnRequestReceive *>(arg);
|
||||
(*func)(&httpReq);
|
||||
};
|
||||
handle_t **pph = func.target<handle_t *>();
|
||||
MS_EXCEPTION_IF_NULL(pph);
|
||||
MS_EXCEPTION_IF_NULL(event_http_);
|
||||
|
||||
// O SUCCESS,-1 ALREADY_EXIST,-2 FAILURE
|
||||
int ret = evhttp_set_cb(event_http_, url.c_str(), TransFunc, reinterpret_cast<void *>(*pph));
|
||||
int ret = evhttp_set_cb(event_http_, url.c_str(), TransFunc, reinterpret_cast<void *>(function));
|
||||
if (ret == 0) {
|
||||
MS_LOG(INFO) << "Ev http register handle of:" << url.c_str() << " success.";
|
||||
} else if (ret == -1) {
|
||||
|
|
|
@ -48,26 +48,21 @@ typedef enum eHttpMethod {
|
|||
HM_PATCH = 1 << 8
|
||||
} HttpMethod;
|
||||
|
||||
typedef u_int16_t HttpMethodsSet;
|
||||
|
||||
typedef void(handle_t)(HttpMessageHandler *);
|
||||
|
||||
class HttpServer {
|
||||
public:
|
||||
// Server address only support IPV4 now, and should be in format of "x.x.x.x"
|
||||
explicit HttpServer(const std::string &address, std::int16_t port)
|
||||
explicit HttpServer(const std::string &address, std::uint16_t port)
|
||||
: server_address_(address), server_port_(port), event_base_(nullptr), event_http_(nullptr), is_init_(false) {}
|
||||
|
||||
~HttpServer();
|
||||
|
||||
typedef std::function<handle_t> HandlerFunc;
|
||||
using OnRequestReceive = std::function<void(HttpMessageHandler *)>;
|
||||
|
||||
bool InitServer();
|
||||
static bool CheckIp(const std::string &ip);
|
||||
void SetTimeOut(int seconds = 5);
|
||||
|
||||
// Default allowed methods: GET, POST, HEAD, PUT, DELETE
|
||||
void SetAllowedMethod(HttpMethodsSet methods);
|
||||
void SetAllowedMethod(u_int16_t methods);
|
||||
|
||||
// Default to ((((unsigned long long)0xffffffffUL) << 32) | 0xffffffffUL)
|
||||
void SetMaxHeaderSize(std::size_t num);
|
||||
|
@ -76,7 +71,7 @@ class HttpServer {
|
|||
void SetMaxBodySize(std::size_t num);
|
||||
|
||||
// Return: true if success, false if failed, check log to find failure reason
|
||||
bool RegisterRoute(const std::string &url, handle_t *func);
|
||||
bool RegisterRoute(const std::string &url, OnRequestReceive *func);
|
||||
bool UnRegisterRoute(const std::string &url);
|
||||
|
||||
bool Start();
|
||||
|
@ -84,7 +79,7 @@ class HttpServer {
|
|||
|
||||
private:
|
||||
std::string server_address_;
|
||||
std::int16_t server_port_;
|
||||
std::uint16_t server_port_;
|
||||
struct event_base *event_base_;
|
||||
struct evhttp *event_http_;
|
||||
bool is_init_;
|
||||
|
|
|
@ -0,0 +1,220 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "ps/comm/tcp_client.h"
|
||||
|
||||
#include <arpa/inet.h>
|
||||
#include <event2/buffer.h>
|
||||
#include <event2/bufferevent.h>
|
||||
#include <event2/buffer_compat.h>
|
||||
#include <event2/event.h>
|
||||
#include <netinet/in.h>
|
||||
#include <netinet/tcp.h>
|
||||
#include <sys/socket.h>
|
||||
#include <cstdlib>
|
||||
#include <cstring>
|
||||
#include <iostream>
|
||||
#include <utility>
|
||||
#include <string>
|
||||
|
||||
#include "ps/comm/comm_util.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
namespace comm {
|
||||
|
||||
TcpClient::TcpClient(std::string address, std::uint16_t port)
|
||||
: event_base_(nullptr),
|
||||
event_timeout_(nullptr),
|
||||
buffer_event_(nullptr),
|
||||
server_address_(std::move(address)),
|
||||
server_port_(port) {
|
||||
message_handler_.SetCallback([this](const void *buf, size_t num) {
|
||||
if (buf == nullptr) {
|
||||
if (disconnected_callback_) disconnected_callback_(*this, 200);
|
||||
Stop();
|
||||
}
|
||||
if (message_callback_) message_callback_(*this, buf, num);
|
||||
});
|
||||
}
|
||||
|
||||
TcpClient::~TcpClient() { Stop(); }
|
||||
|
||||
std::string TcpClient::GetServerAddress() const { return server_address_; }
|
||||
|
||||
void TcpClient::SetCallback(const OnConnected &conn, const OnDisconnected &disconn, const OnRead &read,
|
||||
const OnTimeout &timeout) {
|
||||
connected_callback_ = conn;
|
||||
disconnected_callback_ = disconn;
|
||||
read_callback_ = read;
|
||||
timeout_callback_ = timeout;
|
||||
}
|
||||
|
||||
void TcpClient::InitTcpClient() {
|
||||
if (buffer_event_) {
|
||||
return;
|
||||
}
|
||||
CommUtil::CheckIp(server_address_);
|
||||
|
||||
event_base_ = event_base_new();
|
||||
MS_EXCEPTION_IF_NULL(event_base_);
|
||||
|
||||
sockaddr_in sin{};
|
||||
if (memset_s(&sin, sizeof(sin), 0, sizeof(sin)) != EOK) {
|
||||
MS_LOG(EXCEPTION) << "Initialize sockaddr_in failed!";
|
||||
}
|
||||
sin.sin_family = AF_INET;
|
||||
sin.sin_addr.s_addr = inet_addr(server_address_.c_str());
|
||||
sin.sin_port = htons(server_port_);
|
||||
|
||||
buffer_event_ = bufferevent_socket_new(event_base_, -1, BEV_OPT_CLOSE_ON_FREE);
|
||||
MS_EXCEPTION_IF_NULL(buffer_event_);
|
||||
|
||||
bufferevent_setcb(buffer_event_, ReadCallback, nullptr, EventCallback, this);
|
||||
if (bufferevent_enable(buffer_event_, EV_READ | EV_WRITE) == -1) {
|
||||
MS_LOG(EXCEPTION) << "buffer event enable read and write failed!";
|
||||
}
|
||||
|
||||
int result_code = bufferevent_socket_connect(buffer_event_, reinterpret_cast<struct sockaddr *>(&sin), sizeof(sin));
|
||||
if (result_code < 0) {
|
||||
MS_LOG(EXCEPTION) << "Connect server ip:" << server_address_ << " and port: " << server_port_ << " is failed!";
|
||||
}
|
||||
}
|
||||
|
||||
void TcpClient::StartWithDelay(int seconds) {
|
||||
if (buffer_event_) {
|
||||
return;
|
||||
}
|
||||
|
||||
event_base_ = event_base_new();
|
||||
|
||||
timeval timeout_value{};
|
||||
timeout_value.tv_sec = seconds;
|
||||
timeout_value.tv_usec = 0;
|
||||
|
||||
event_timeout_ = evtimer_new(event_base_, TimeoutCallback, this);
|
||||
if (evtimer_add(event_timeout_, &timeout_value) == -1) {
|
||||
MS_LOG(EXCEPTION) << "event timeout failed!";
|
||||
}
|
||||
}
|
||||
|
||||
void TcpClient::Stop() {
|
||||
if (buffer_event_) {
|
||||
bufferevent_free(buffer_event_);
|
||||
buffer_event_ = nullptr;
|
||||
}
|
||||
|
||||
if (event_timeout_) {
|
||||
event_free(event_timeout_);
|
||||
event_timeout_ = nullptr;
|
||||
}
|
||||
|
||||
if (event_base_) {
|
||||
event_base_free(event_base_);
|
||||
event_base_ = nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
void TcpClient::SetTcpNoDelay(const evutil_socket_t &fd) {
|
||||
const int one = 1;
|
||||
int ret = setsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &one, sizeof(int));
|
||||
if (ret < 0) {
|
||||
MS_LOG(EXCEPTION) << "Set socket no delay failed!";
|
||||
}
|
||||
}
|
||||
|
||||
void TcpClient::TimeoutCallback(evutil_socket_t, std::int16_t, void *arg) {
|
||||
MS_EXCEPTION_IF_NULL(arg);
|
||||
auto tcp_client = reinterpret_cast<TcpClient *>(arg);
|
||||
tcp_client->InitTcpClient();
|
||||
}
|
||||
|
||||
void TcpClient::ReadCallback(struct bufferevent *bev, void *ctx) {
|
||||
MS_EXCEPTION_IF_NULL(bev);
|
||||
MS_EXCEPTION_IF_NULL(ctx);
|
||||
auto tcp_client = reinterpret_cast<TcpClient *>(ctx);
|
||||
struct evbuffer *input = bufferevent_get_input(const_cast<struct bufferevent *>(bev));
|
||||
MS_EXCEPTION_IF_NULL(input);
|
||||
|
||||
char read_buffer[4096];
|
||||
int read = 0;
|
||||
|
||||
while ((read = EVBUFFER_LENGTH(input)) > 0) {
|
||||
if (evbuffer_remove(input, &read_buffer, sizeof(read_buffer)) == -1) {
|
||||
MS_LOG(EXCEPTION) << "Can not drain data from the event buffer!";
|
||||
}
|
||||
tcp_client->OnReadHandler(read_buffer, read);
|
||||
}
|
||||
}
|
||||
|
||||
void TcpClient::OnReadHandler(const void *buf, size_t num) {
|
||||
MS_EXCEPTION_IF_NULL(buf);
|
||||
if (read_callback_) {
|
||||
read_callback_(*this, buf, num);
|
||||
}
|
||||
message_handler_.ReceiveMessage(buf, num);
|
||||
}
|
||||
|
||||
void TcpClient::EventCallback(struct bufferevent *bev, std::int16_t events, void *ptr) {
|
||||
MS_EXCEPTION_IF_NULL(bev);
|
||||
MS_EXCEPTION_IF_NULL(ptr);
|
||||
auto tcp_client = reinterpret_cast<TcpClient *>(ptr);
|
||||
if (events & BEV_EVENT_CONNECTED) {
|
||||
// Connected
|
||||
if (tcp_client->connected_callback_) {
|
||||
tcp_client->connected_callback_(*tcp_client);
|
||||
}
|
||||
evutil_socket_t fd = bufferevent_getfd(const_cast<struct bufferevent *>(bev));
|
||||
SetTcpNoDelay(fd);
|
||||
MS_LOG(INFO) << "Client connected!";
|
||||
} else if (events & BEV_EVENT_ERROR) {
|
||||
MS_LOG(ERROR) << "Client connected error!";
|
||||
if (tcp_client->disconnected_callback_) {
|
||||
tcp_client->disconnected_callback_(*tcp_client, errno);
|
||||
}
|
||||
} else if (events & BEV_EVENT_EOF) {
|
||||
MS_LOG(ERROR) << "Client connected end of file";
|
||||
if (tcp_client->disconnected_callback_) {
|
||||
tcp_client->disconnected_callback_(*tcp_client, 0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void TcpClient::Start() {
|
||||
MS_EXCEPTION_IF_NULL(event_base_);
|
||||
int ret = event_base_dispatch(event_base_);
|
||||
if (ret == 0) {
|
||||
MS_LOG(INFO) << "Event base dispatch success!";
|
||||
} else if (ret == 1) {
|
||||
MS_LOG(ERROR) << "Event base dispatch failed with no events pending or active!";
|
||||
} else if (ret == -1) {
|
||||
MS_LOG(ERROR) << "Event base dispatch failed with error occurred!";
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Event base dispatch with unexpect error code!";
|
||||
}
|
||||
}
|
||||
|
||||
void TcpClient::ReceiveMessage(const OnMessage &cb) { message_callback_ = cb; }
|
||||
|
||||
void TcpClient::SendMessage(const void *buf, size_t num) const {
|
||||
MS_EXCEPTION_IF_NULL(buffer_event_);
|
||||
if (evbuffer_add(bufferevent_get_output(buffer_event_), buf, num) == -1) {
|
||||
MS_LOG(EXCEPTION) << "event buffer add failed!";
|
||||
}
|
||||
}
|
||||
} // namespace comm
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,77 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_PS_COMM_TCP_CLIENT_H_
|
||||
#define MINDSPORE_CCSRC_PS_COMM_TCP_CLIENT_H_
|
||||
|
||||
#include "ps/comm/tcp_message_handler.h"
|
||||
|
||||
#include <event2/event.h>
|
||||
#include <event2/bufferevent.h>
|
||||
#include <functional>
|
||||
#include <string>
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
namespace comm {
|
||||
|
||||
class TcpClient {
|
||||
public:
|
||||
using OnMessage = std::function<void(const TcpClient &, const void *, size_t)>;
|
||||
using OnConnected = std::function<void(const TcpClient &)>;
|
||||
using OnDisconnected = std::function<void(const TcpClient &, int)>;
|
||||
using OnRead = std::function<void(const TcpClient &, const void *, size_t)>;
|
||||
using OnTimeout = std::function<void(const TcpClient &)>;
|
||||
|
||||
explicit TcpClient(std::string address, std::uint16_t port);
|
||||
virtual ~TcpClient();
|
||||
|
||||
std::string GetServerAddress() const;
|
||||
void SetCallback(const OnConnected &conn, const OnDisconnected &disconn, const OnRead &read,
|
||||
const OnTimeout &timeout);
|
||||
void InitTcpClient();
|
||||
void StartWithDelay(int seconds);
|
||||
void Stop();
|
||||
void ReceiveMessage(const OnMessage &cb);
|
||||
void SendMessage(const void *buf, size_t num) const;
|
||||
void Start();
|
||||
|
||||
protected:
|
||||
static void SetTcpNoDelay(const evutil_socket_t &fd);
|
||||
static void TimeoutCallback(evutil_socket_t fd, std::int16_t what, void *arg);
|
||||
static void ReadCallback(struct bufferevent *bev, void *ctx);
|
||||
static void EventCallback(struct bufferevent *bev, std::int16_t events, void *ptr);
|
||||
virtual void OnReadHandler(const void *buf, size_t num);
|
||||
|
||||
private:
|
||||
TcpMessageHandler message_handler_;
|
||||
OnMessage message_callback_;
|
||||
OnConnected connected_callback_;
|
||||
OnDisconnected disconnected_callback_;
|
||||
OnRead read_callback_;
|
||||
OnTimeout timeout_callback_;
|
||||
|
||||
event_base *event_base_;
|
||||
event *event_timeout_;
|
||||
bufferevent *buffer_event_;
|
||||
|
||||
std::string server_address_;
|
||||
std::uint16_t server_port_;
|
||||
};
|
||||
} // namespace comm
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_PS_COMM_TCP_CLIENT_H_
|
|
@ -0,0 +1,36 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "ps/comm/tcp_message_handler.h"
|
||||
#include <iostream>
|
||||
#include <utility>
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
namespace comm {
|
||||
|
||||
void TcpMessageHandler::SetCallback(messageReceive message_receive) { message_callback_ = std::move(message_receive); }
|
||||
|
||||
void TcpMessageHandler::ReceiveMessage(const void *buffer, size_t num) {
|
||||
MS_EXCEPTION_IF_NULL(buffer);
|
||||
|
||||
if (message_callback_) {
|
||||
message_callback_(buffer, num);
|
||||
}
|
||||
}
|
||||
} // namespace comm
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,47 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_PS_COMM_TCP_MESSAGE_HANDLER_H_
|
||||
#define MINDSPORE_CCSRC_PS_COMM_TCP_MESSAGE_HANDLER_H_
|
||||
|
||||
#include <functional>
|
||||
#include <iostream>
|
||||
#include <memory>
|
||||
|
||||
#include "utils/log_adapter.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
namespace comm {
|
||||
|
||||
using messageReceive = std::function<void(const void *buffer, size_t len)>;
|
||||
|
||||
class TcpMessageHandler {
|
||||
public:
|
||||
TcpMessageHandler() = default;
|
||||
virtual ~TcpMessageHandler() = default;
|
||||
|
||||
void SetCallback(messageReceive cb);
|
||||
void ReceiveMessage(const void *buffer, size_t num);
|
||||
|
||||
private:
|
||||
messageReceive message_callback_;
|
||||
};
|
||||
} // namespace comm
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_PS_COMM_TCP_MESSAGE_HANDLER_H_
|
|
@ -0,0 +1,259 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "ps/comm/tcp_server.h"
|
||||
|
||||
#include <arpa/inet.h>
|
||||
#include <event2/buffer.h>
|
||||
#include <event2/bufferevent.h>
|
||||
#include <event2/event.h>
|
||||
#include <event2/listener.h>
|
||||
#include <event2/buffer_compat.h>
|
||||
#include <event2/util.h>
|
||||
#include <sys/socket.h>
|
||||
#include <csignal>
|
||||
#include <utility>
|
||||
|
||||
#include "ps/comm/comm_util.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
namespace comm {
|
||||
|
||||
void TcpConnection::InitConnection(const evutil_socket_t &fd, const struct bufferevent *bev, const TcpServer *server) {
|
||||
MS_EXCEPTION_IF_NULL(bev);
|
||||
MS_EXCEPTION_IF_NULL(server);
|
||||
buffer_event_ = const_cast<struct bufferevent *>(bev);
|
||||
fd_ = fd;
|
||||
server_ = const_cast<TcpServer *>(server);
|
||||
|
||||
tcp_message_handler_.SetCallback([this, server](const void *buf, size_t num) {
|
||||
OnServerReceiveMessage message_callback = server->GetServerReceiveMessage();
|
||||
if (message_callback) message_callback(*server, *this, buf, num);
|
||||
});
|
||||
}
|
||||
|
||||
void TcpConnection::OnReadHandler(const void *buffer, size_t num) { tcp_message_handler_.ReceiveMessage(buffer, num); }
|
||||
|
||||
void TcpConnection::SendMessage(const void *buffer, size_t num) const {
|
||||
if (bufferevent_write(buffer_event_, buffer, num) == -1) {
|
||||
MS_LOG(ERROR) << "Write message to buffer event failed!";
|
||||
}
|
||||
}
|
||||
|
||||
TcpServer *TcpConnection::GetServer() const { return server_; }
|
||||
|
||||
evutil_socket_t TcpConnection::GetFd() const { return fd_; }
|
||||
|
||||
TcpServer::TcpServer(std::string address, std::uint16_t port)
|
||||
: base_(nullptr),
|
||||
signal_event_(nullptr),
|
||||
listener_(nullptr),
|
||||
server_address_(std::move(address)),
|
||||
server_port_(port) {}
|
||||
|
||||
TcpServer::~TcpServer() { Stop(); }
|
||||
|
||||
void TcpServer::SetServerCallback(const OnConnected &client_conn, const OnDisconnected &client_disconn,
|
||||
const OnAccepted &client_accept) {
|
||||
this->client_connection_ = client_conn;
|
||||
this->client_disconnection_ = client_disconn;
|
||||
this->client_accept_ = client_accept;
|
||||
}
|
||||
|
||||
void TcpServer::InitServer() {
|
||||
base_ = event_base_new();
|
||||
MS_EXCEPTION_IF_NULL(base_);
|
||||
CommUtil::CheckIp(server_address_);
|
||||
|
||||
struct sockaddr_in sin {};
|
||||
if (memset_s(&sin, sizeof(sin), 0, sizeof(sin)) != EOK) {
|
||||
MS_LOG(EXCEPTION) << "Initialize sockaddr_in failed!";
|
||||
}
|
||||
sin.sin_family = AF_INET;
|
||||
sin.sin_port = htons(server_port_);
|
||||
sin.sin_addr.s_addr = inet_addr(server_address_.c_str());
|
||||
|
||||
listener_ = evconnlistener_new_bind(base_, ListenerCallback, reinterpret_cast<void *>(this),
|
||||
LEV_OPT_REUSEABLE | LEV_OPT_CLOSE_ON_FREE, -1,
|
||||
reinterpret_cast<struct sockaddr *>(&sin), sizeof(sin));
|
||||
|
||||
MS_EXCEPTION_IF_NULL(listener_);
|
||||
|
||||
signal_event_ = evsignal_new(base_, SIGINT, SignalCallback, reinterpret_cast<void *>(this));
|
||||
MS_EXCEPTION_IF_NULL(signal_event_);
|
||||
if (event_add(signal_event_, nullptr) < 0) {
|
||||
MS_LOG(EXCEPTION) << "Cannot create signal event.";
|
||||
}
|
||||
}
|
||||
|
||||
void TcpServer::Start() {
|
||||
std::unique_lock<std::recursive_mutex> l(connection_mutex_);
|
||||
MS_EXCEPTION_IF_NULL(base_);
|
||||
int ret = event_base_dispatch(base_);
|
||||
if (ret == 0) {
|
||||
MS_LOG(INFO) << "Event base dispatch success!";
|
||||
} else if (ret == 1) {
|
||||
MS_LOG(ERROR) << "Event base dispatch failed with no events pending or active!";
|
||||
} else if (ret == -1) {
|
||||
MS_LOG(ERROR) << "Event base dispatch failed with error occurred!";
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Event base dispatch with unexpect error code!";
|
||||
}
|
||||
}
|
||||
|
||||
void TcpServer::Stop() {
|
||||
if (signal_event_ != nullptr) {
|
||||
event_free(signal_event_);
|
||||
signal_event_ = nullptr;
|
||||
}
|
||||
|
||||
if (listener_ != nullptr) {
|
||||
evconnlistener_free(listener_);
|
||||
listener_ = nullptr;
|
||||
}
|
||||
|
||||
if (base_ != nullptr) {
|
||||
event_base_free(base_);
|
||||
base_ = nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
void TcpServer::SendToAllClients(const char *data, size_t len) {
|
||||
MS_EXCEPTION_IF_NULL(data);
|
||||
std::unique_lock<std::recursive_mutex> lock(connection_mutex_);
|
||||
for (auto it = connections_.begin(); it != connections_.end(); ++it) {
|
||||
it->second->SendMessage(data, len);
|
||||
}
|
||||
}
|
||||
|
||||
void TcpServer::AddConnection(const evutil_socket_t &fd, const TcpConnection *connection) {
|
||||
MS_EXCEPTION_IF_NULL(connection);
|
||||
std::unique_lock<std::recursive_mutex> lock(connection_mutex_);
|
||||
connections_.insert(std::make_pair(fd, connection));
|
||||
}
|
||||
|
||||
void TcpServer::RemoveConnection(const evutil_socket_t &fd) {
|
||||
std::unique_lock<std::recursive_mutex> lock(connection_mutex_);
|
||||
connections_.erase(fd);
|
||||
}
|
||||
|
||||
void TcpServer::ListenerCallback(struct evconnlistener *, evutil_socket_t fd, struct sockaddr *, int, void *data) {
|
||||
auto server = reinterpret_cast<class TcpServer *>(data);
|
||||
auto base = reinterpret_cast<struct event_base *>(server->base_);
|
||||
MS_EXCEPTION_IF_NULL(server);
|
||||
MS_EXCEPTION_IF_NULL(base);
|
||||
|
||||
struct bufferevent *bev = bufferevent_socket_new(base, fd, BEV_OPT_CLOSE_ON_FREE);
|
||||
if (!bev) {
|
||||
MS_LOG(ERROR) << "Error constructing buffer event!";
|
||||
event_base_loopbreak(base);
|
||||
return;
|
||||
}
|
||||
|
||||
TcpConnection *conn = server->onCreateConnection();
|
||||
MS_EXCEPTION_IF_NULL(conn);
|
||||
|
||||
conn->InitConnection(fd, bev, server);
|
||||
server->AddConnection(fd, conn);
|
||||
bufferevent_setcb(bev, TcpServer::ReadCallback, nullptr, TcpServer::EventCallback, reinterpret_cast<void *>(conn));
|
||||
if (bufferevent_enable(bev, EV_READ | EV_WRITE) == -1) {
|
||||
MS_LOG(EXCEPTION) << "buffer event enable read and write failed!";
|
||||
}
|
||||
}
|
||||
|
||||
TcpConnection *TcpServer::onCreateConnection() {
|
||||
TcpConnection *conn = nullptr;
|
||||
if (client_accept_)
|
||||
conn = const_cast<TcpConnection *>(client_accept_(this));
|
||||
else
|
||||
conn = new TcpConnection();
|
||||
|
||||
return conn;
|
||||
}
|
||||
|
||||
OnServerReceiveMessage TcpServer::GetServerReceiveMessage() const { return message_callback_; }
|
||||
|
||||
void TcpServer::SignalCallback(evutil_socket_t, std::int16_t, void *data) {
|
||||
auto server = reinterpret_cast<class TcpServer *>(data);
|
||||
MS_EXCEPTION_IF_NULL(server);
|
||||
struct event_base *base = server->base_;
|
||||
struct timeval delay = {0, 0};
|
||||
MS_LOG(ERROR) << "Caught an interrupt signal; exiting cleanly in 0 seconds.";
|
||||
if (event_base_loopexit(base, &delay) == -1) {
|
||||
MS_LOG(EXCEPTION) << "event base loop exit failed.";
|
||||
}
|
||||
}
|
||||
|
||||
void TcpServer::ReadCallback(struct bufferevent *bev, void *connection) {
|
||||
MS_EXCEPTION_IF_NULL(bev);
|
||||
MS_EXCEPTION_IF_NULL(connection);
|
||||
|
||||
auto conn = static_cast<class TcpConnection *>(connection);
|
||||
struct evbuffer *buf = bufferevent_get_input(bev);
|
||||
char read_buffer[4096];
|
||||
auto read = 0;
|
||||
while ((read = EVBUFFER_LENGTH(buf)) > 0) {
|
||||
if (evbuffer_remove(buf, &read_buffer, sizeof(read_buffer)) == -1) {
|
||||
MS_LOG(EXCEPTION) << "Can not drain data from the event buffer!";
|
||||
}
|
||||
conn->OnReadHandler(read_buffer, static_cast<size_t>(read));
|
||||
}
|
||||
}
|
||||
|
||||
void TcpServer::EventCallback(struct bufferevent *bev, std::int16_t events, void *data) {
|
||||
MS_EXCEPTION_IF_NULL(bev);
|
||||
MS_EXCEPTION_IF_NULL(data);
|
||||
auto conn = reinterpret_cast<TcpConnection *>(data);
|
||||
TcpServer *srv = conn->GetServer();
|
||||
|
||||
if (events & BEV_EVENT_EOF) {
|
||||
// Notify about disconnection
|
||||
if (srv->client_disconnection_) srv->client_disconnection_(srv, conn);
|
||||
// Free connection structures
|
||||
srv->RemoveConnection(conn->GetFd());
|
||||
bufferevent_free(bev);
|
||||
} else if (events & BEV_EVENT_ERROR) {
|
||||
// Free connection structures
|
||||
srv->RemoveConnection(conn->GetFd());
|
||||
bufferevent_free(bev);
|
||||
|
||||
// Notify about disconnection
|
||||
if (srv->client_disconnection_) srv->client_disconnection_(srv, conn);
|
||||
} else {
|
||||
MS_LOG(ERROR) << "unhandled event!";
|
||||
}
|
||||
}
|
||||
|
||||
void TcpServer::ReceiveMessage(const OnServerReceiveMessage &cb) { message_callback_ = cb; }
|
||||
|
||||
void TcpServer::SendMessage(const TcpConnection &conn, const void *data, size_t num) {
|
||||
MS_EXCEPTION_IF_NULL(data);
|
||||
auto mc = const_cast<TcpConnection &>(conn);
|
||||
mc.SendMessage(data, num);
|
||||
}
|
||||
|
||||
void TcpServer::SendMessage(const void *data, size_t num) {
|
||||
MS_EXCEPTION_IF_NULL(data);
|
||||
std::unique_lock<std::recursive_mutex> lock(connection_mutex_);
|
||||
|
||||
for (auto it = connections_.begin(); it != connections_.end(); ++it) {
|
||||
SendMessage(*it->second, data, num);
|
||||
}
|
||||
}
|
||||
} // namespace comm
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,107 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_PS_COMM_TCP_SERVER_H_
|
||||
#define MINDSPORE_CCSRC_PS_COMM_TCP_SERVER_H_
|
||||
|
||||
#include <event2/buffer.h>
|
||||
#include <event2/bufferevent.h>
|
||||
#include <event2/event.h>
|
||||
#include <event2/listener.h>
|
||||
#include <exception>
|
||||
#include <functional>
|
||||
#include <iostream>
|
||||
#include <map>
|
||||
#include <mutex>
|
||||
#include <string>
|
||||
|
||||
#include "utils/log_adapter.h"
|
||||
#include "ps/comm/tcp_message_handler.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
namespace comm {
|
||||
|
||||
class TcpServer;
|
||||
class TcpConnection {
|
||||
public:
|
||||
TcpConnection() : buffer_event_(nullptr), fd_(0), server_(nullptr) {}
|
||||
virtual ~TcpConnection() = default;
|
||||
|
||||
virtual void InitConnection(const evutil_socket_t &fd, const struct bufferevent *bev, const TcpServer *server);
|
||||
void SendMessage(const void *buffer, size_t num) const;
|
||||
virtual void OnReadHandler(const void *buffer, size_t numBytes);
|
||||
TcpServer *GetServer() const;
|
||||
evutil_socket_t GetFd() const;
|
||||
|
||||
protected:
|
||||
TcpMessageHandler tcp_message_handler_;
|
||||
struct bufferevent *buffer_event_;
|
||||
evutil_socket_t fd_;
|
||||
TcpServer *server_;
|
||||
};
|
||||
|
||||
using OnServerReceiveMessage =
|
||||
std::function<void(const TcpServer &tcp_server, const TcpConnection &conn, const void *buffer, size_t num)>;
|
||||
|
||||
class TcpServer {
|
||||
public:
|
||||
using OnConnected = std::function<void(const TcpServer *, const TcpConnection *)>;
|
||||
using OnDisconnected = std::function<void(const TcpServer *, const TcpConnection *)>;
|
||||
using OnAccepted = std::function<const TcpConnection *(const TcpServer *)>;
|
||||
|
||||
explicit TcpServer(std::string address, std::uint16_t port);
|
||||
virtual ~TcpServer();
|
||||
|
||||
void SetServerCallback(const OnConnected &client_conn, const OnDisconnected &client_disconn,
|
||||
const OnAccepted &client_accept);
|
||||
void InitServer();
|
||||
void Start();
|
||||
void Stop();
|
||||
void SendToAllClients(const char *data, size_t len);
|
||||
void AddConnection(const evutil_socket_t &fd, const TcpConnection *connection);
|
||||
void RemoveConnection(const evutil_socket_t &fd);
|
||||
void ReceiveMessage(const OnServerReceiveMessage &cb);
|
||||
static void SendMessage(const TcpConnection &conn, const void *data, size_t num);
|
||||
void SendMessage(const void *data, size_t num);
|
||||
OnServerReceiveMessage GetServerReceiveMessage() const;
|
||||
|
||||
protected:
|
||||
static void ListenerCallback(struct evconnlistener *listener, evutil_socket_t socket, struct sockaddr *saddr,
|
||||
int socklen, void *server);
|
||||
static void SignalCallback(evutil_socket_t sig, std::int16_t events, void *server);
|
||||
static void ReadCallback(struct bufferevent *, void *connection);
|
||||
static void EventCallback(struct bufferevent *, std::int16_t events, void *server);
|
||||
virtual TcpConnection *onCreateConnection();
|
||||
|
||||
private:
|
||||
struct event_base *base_;
|
||||
struct event *signal_event_;
|
||||
struct evconnlistener *listener_;
|
||||
std::string server_address_;
|
||||
std::uint16_t server_port_;
|
||||
|
||||
std::map<evutil_socket_t, const TcpConnection *> connections_;
|
||||
OnConnected client_connection_;
|
||||
OnDisconnected client_disconnection_;
|
||||
OnAccepted client_accept_;
|
||||
std::recursive_mutex connection_mutex_;
|
||||
OnServerReceiveMessage message_callback_;
|
||||
};
|
||||
} // namespace comm
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_PS_COMM_TCP_SERVER_H_
|
|
@ -31,7 +31,7 @@ namespace comm {
|
|||
|
||||
class TestHttpServer : public UT::Common {
|
||||
public:
|
||||
TestHttpServer() {}
|
||||
TestHttpServer() = default;
|
||||
|
||||
static void testGetHandler(HttpMessageHandler *resp) {
|
||||
std::string host = resp->GetRequestHost();
|
||||
|
@ -58,16 +58,44 @@ class TestHttpServer : public UT::Common {
|
|||
|
||||
void SetUp() override {
|
||||
server_ = new HttpServer("0.0.0.0", 9999);
|
||||
server_->RegisterRoute("/httpget", [](HttpMessageHandler *resp) {
|
||||
EXPECT_STREQ(resp->GetPathParam("key1").c_str(), "value1");
|
||||
EXPECT_STREQ(resp->GetUriQuery().c_str(), "key1=value1");
|
||||
EXPECT_STREQ(resp->GetRequestUri().c_str(), "/httpget?key1=value1");
|
||||
EXPECT_STREQ(resp->GetUriPath().c_str(), "/httpget");
|
||||
resp->QuickResponse(200, "get request success!\n");
|
||||
});
|
||||
server_->RegisterRoute("/handler", TestHttpServer::testGetHandler);
|
||||
std::function<void(HttpMessageHandler *)> http_get_func = std::bind(
|
||||
[](HttpMessageHandler *resp) {
|
||||
EXPECT_STREQ(resp->GetPathParam("key1").c_str(), "value1");
|
||||
EXPECT_STREQ(resp->GetUriQuery().c_str(), "key1=value1");
|
||||
EXPECT_STREQ(resp->GetRequestUri().c_str(), "/httpget?key1=value1");
|
||||
EXPECT_STREQ(resp->GetUriPath().c_str(), "/httpget");
|
||||
resp->QuickResponse(200, "get request success!\n");
|
||||
},
|
||||
std::placeholders::_1);
|
||||
|
||||
std::function<void(HttpMessageHandler *)> http_handler_func = std::bind(
|
||||
[](HttpMessageHandler *resp) {
|
||||
std::string host = resp->GetRequestHost();
|
||||
EXPECT_STREQ(host.c_str(), "127.0.0.1");
|
||||
|
||||
std::string path_param = resp->GetPathParam("key1");
|
||||
std::string header_param = resp->GetHeadParam("headerKey");
|
||||
std::string post_param = resp->GetPostParam("postKey");
|
||||
std::string post_message = resp->GetPostMsg();
|
||||
EXPECT_STREQ(path_param.c_str(), "value1");
|
||||
EXPECT_STREQ(header_param.c_str(), "headerValue");
|
||||
EXPECT_STREQ(post_param.c_str(), "postValue");
|
||||
EXPECT_STREQ(post_message.c_str(), "postKey=postValue");
|
||||
|
||||
const std::string rKey("headKey");
|
||||
const std::string rVal("headValue");
|
||||
const std::string rBody("post request success!\n");
|
||||
resp->AddRespHeadParam(rKey, rVal);
|
||||
resp->AddRespString(rBody);
|
||||
|
||||
resp->SetRespCode(200);
|
||||
resp->SendResponse();
|
||||
},
|
||||
std::placeholders::_1);
|
||||
server_->RegisterRoute("/httpget", &http_get_func);
|
||||
server_->RegisterRoute("/handler", &http_handler_func);
|
||||
std::unique_ptr<std::thread> http_server_thread_(nullptr);
|
||||
http_server_thread_.reset(new std::thread([&]() { server_->Start(); }));
|
||||
http_server_thread_ = std::make_unique<std::thread>([&]() { server_->Start(); });
|
||||
http_server_thread_->detach();
|
||||
}
|
||||
|
||||
|
@ -110,14 +138,18 @@ TEST_F(TestHttpServer, messageHandler) {
|
|||
pclose(file);
|
||||
}
|
||||
|
||||
TEST_F(TestHttpServer, portException) {
|
||||
TEST_F(TestHttpServer, portErrorNoException) {
|
||||
HttpServer *server_exception = new HttpServer("0.0.0.0", -1);
|
||||
ASSERT_THROW(server_exception->RegisterRoute("/handler", TestHttpServer::testGetHandler), std::exception);
|
||||
std::function<void(HttpMessageHandler *)> http_handler_func =
|
||||
std::bind(TestHttpServer::testGetHandler, std::placeholders::_1);
|
||||
EXPECT_NO_THROW(server_exception->RegisterRoute("/handler", &http_handler_func));
|
||||
}
|
||||
|
||||
TEST_F(TestHttpServer, addressException) {
|
||||
HttpServer *server_exception = new HttpServer("12344.0.0.0", 9998);
|
||||
ASSERT_THROW(server_exception->RegisterRoute("/handler", TestHttpServer::testGetHandler), std::exception);
|
||||
std::function<void(HttpMessageHandler *)> http_handler_func =
|
||||
std::bind(TestHttpServer::testGetHandler, std::placeholders::_1);
|
||||
ASSERT_THROW(server_exception->RegisterRoute("/handler", &http_handler_func), std::exception);
|
||||
}
|
||||
|
||||
} // namespace comm
|
||||
|
|
|
@ -0,0 +1,46 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "common/common_test.h"
|
||||
#include "ps/comm/tcp_client.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
namespace comm {
|
||||
class TestTcpClient : public UT::Common {
|
||||
public:
|
||||
TestTcpClient() = default;
|
||||
};
|
||||
|
||||
TEST_F(TestTcpClient, InitClientIPError) {
|
||||
auto client = new TcpClient("127.0.0.13543", 9000);
|
||||
client->ReceiveMessage(
|
||||
[](const TcpClient &client, const void *buffer, size_t num) { client.SendMessage(buffer, num); });
|
||||
|
||||
ASSERT_THROW(client->InitTcpClient(), std::exception);
|
||||
}
|
||||
|
||||
TEST_F(TestTcpClient, InitClientPortErrorNoException) {
|
||||
auto client = new TcpClient("127.0.0.1", -1);
|
||||
client->ReceiveMessage(
|
||||
[](const TcpClient &client, const void *buffer, size_t num) { client.SendMessage(buffer, num); });
|
||||
|
||||
EXPECT_NO_THROW(client->InitTcpClient());
|
||||
}
|
||||
|
||||
} // namespace comm
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,71 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "ps/comm/tcp_client.h"
|
||||
#include "ps/comm/tcp_server.h"
|
||||
#include "common/common_test.h"
|
||||
|
||||
#include <thread>
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
namespace comm {
|
||||
class TestTcpServer : public UT::Common {
|
||||
public:
|
||||
TestTcpServer() = default;
|
||||
void SetUp() override {
|
||||
server_ = new TcpServer("127.0.0.1", 9000);
|
||||
std::unique_ptr<std::thread> http_server_thread_(nullptr);
|
||||
http_server_thread_ = std::make_unique<std::thread>([&]() {
|
||||
server_->ReceiveMessage([](const TcpServer &server, const TcpConnection &conn, const void *buffer, size_t num) {
|
||||
EXPECT_STREQ(std::string(reinterpret_cast<const char *>(buffer), num).c_str(), "TCP_MESSAGE");
|
||||
server.SendMessage(conn, buffer, num);
|
||||
});
|
||||
server_->InitServer();
|
||||
server_->Start();
|
||||
});
|
||||
http_server_thread_->detach();
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(2000));
|
||||
}
|
||||
void TearDown() override {
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(2000));
|
||||
client_->Stop();
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(2000));
|
||||
server_->Stop();
|
||||
}
|
||||
|
||||
TcpClient *client_;
|
||||
TcpServer *server_;
|
||||
const std::string test_message_ = "TCP_MESSAGE";
|
||||
};
|
||||
|
||||
TEST_F(TestTcpServer, ServerSendMessage) {
|
||||
client_ = new TcpClient("127.0.0.1", 9000);
|
||||
std::unique_ptr<std::thread> http_client_thread(nullptr);
|
||||
http_client_thread = std::make_unique<std::thread>([&]() {
|
||||
client_->ReceiveMessage([](const TcpClient &client, const void *buffer, size_t num) {
|
||||
EXPECT_STREQ(std::string(reinterpret_cast<const char *>(buffer), num).c_str(), "TCP_MESSAGE");
|
||||
});
|
||||
|
||||
client_->InitTcpClient();
|
||||
client_->SendMessage(test_message_.c_str(), test_message_.size());
|
||||
client_->Start();
|
||||
});
|
||||
http_client_thread->detach();
|
||||
}
|
||||
} // namespace comm
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
Loading…
Reference in New Issue