!7911 added key value message

Merge pull request !7911 from anancds/kv-patch
This commit is contained in:
mindspore-ci-bot 2020-11-05 09:53:00 +08:00 committed by Gitee
commit 7123a2c1d1
12 changed files with 329 additions and 129 deletions

View File

@ -100,6 +100,11 @@ message("onnx proto path is :" ${ONNX_PROTO})
ms_protobuf_generate(ONNX_PROTO_SRCS ONNX_PROTO_HDRS ${ONNX_PROTO})
list(APPEND MINDSPORE_PROTO_LIST ${ONNX_PROTO_SRCS})
include_directories("${CMAKE_BINARY_DIR}/ps/comm")
file(GLOB_RECURSE COMM_PROTO_IN RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "ps/comm/protos/*.proto")
ms_protobuf_generate(COMM_PROTO_SRCS COMM_PROTO_HDRS ${COMM_PROTO_IN})
list(APPEND MINDSPORE_PROTO_LIST ${COMM_PROTO_SRCS})
if (ENABLE_DEBUGGER)
# debugger: compile proto files
include_directories("${CMAKE_BINARY_DIR}/debug/debugger")
@ -290,7 +295,7 @@ if (CMAKE_SYSTEM_NAME MATCHES "Windows")
target_link_libraries(_c_expression PRIVATE -Wl,--whole-archive mindspore -Wl,--no-whole-archive)
else ()
if (ENABLE_CPU AND (ENABLE_D OR ENABLE_GPU))
target_link_libraries(mindspore mindspore::pslite mindspore::protobuf mindspore::event mindspore::event_pthreads ${zeromq_DIRPATH}/zmq_install/lib/libzmq.a)
target_link_libraries(mindspore mindspore::pslite proto_input mindspore::protobuf mindspore::event mindspore::event_pthreads ${zeromq_DIRPATH}/zmq_install/lib/libzmq.a)
if (${ENABLE_IBVERBS} STREQUAL "ON")
target_link_libraries(mindspore ibverbs rdmacm)
endif()

View File

@ -0,0 +1,42 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
syntax = "proto3";
import "google/protobuf/any.proto";
package mindspore.ps;
option optimize_for = LITE_RUNTIME;
message MessageMeta {
// hostname or ip
string hostname = 1;
// the port of this node
int32 port = 2;
// the command of this message,for example: registerheartbeatdata
int32 cmd = 3;
// the timestamp of this message
int32 timestamp = 4;
// data type of message
repeated int32 data_type = 5 [packed = true];
// message.data_size
int32 data_size = 6;
}
message CommMessage {
MessageMeta pb_meta = 1;
bytes data = 2;
}

View File

@ -0,0 +1,25 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
message KVMessage {
repeated int32 keys = 1;
repeated float values = 2;
}
message HeartBeatMessage {
// *.*.*.*:port
repeated string host_and_port = 1;
}

View File

@ -36,18 +36,16 @@ namespace mindspore {
namespace ps {
namespace comm {
TcpClient::TcpClient(std::string address, std::uint16_t port)
TcpClient::TcpClient(const std::string &address, std::uint16_t port)
: event_base_(nullptr),
event_timeout_(nullptr),
buffer_event_(nullptr),
server_address_(std::move(address)),
server_port_(port) {
message_handler_.SetCallback([this](const void *buf, size_t num) {
if (buf == nullptr) {
if (disconnected_callback_) disconnected_callback_(*this, 200);
Stop();
message_handler_.SetCallback([this](const CommMessage &message) {
if (message_callback_) {
message_callback_(*this, message);
}
if (message_callback_) message_callback_(*this, buf, num);
});
}
@ -63,7 +61,7 @@ void TcpClient::SetCallback(const OnConnected &conn, const OnDisconnected &disco
timeout_callback_ = timeout;
}
void TcpClient::InitTcpClient() {
void TcpClient::Init() {
if (buffer_event_) {
return;
}
@ -139,7 +137,7 @@ void TcpClient::SetTcpNoDelay(const evutil_socket_t &fd) {
void TcpClient::TimeoutCallback(evutil_socket_t, std::int16_t, void *arg) {
MS_EXCEPTION_IF_NULL(arg);
auto tcp_client = reinterpret_cast<TcpClient *>(arg);
tcp_client->InitTcpClient();
tcp_client->Init();
}
void TcpClient::ReadCallback(struct bufferevent *bev, void *ctx) {
@ -150,10 +148,10 @@ void TcpClient::ReadCallback(struct bufferevent *bev, void *ctx) {
MS_EXCEPTION_IF_NULL(input);
char read_buffer[4096];
int read = 0;
while ((read = EVBUFFER_LENGTH(input)) > 0) {
if (evbuffer_remove(input, &read_buffer, sizeof(read_buffer)) == -1) {
while (EVBUFFER_LENGTH(input) > 0) {
int read = evbuffer_remove(input, &read_buffer, sizeof(read_buffer));
if (read == -1) {
MS_LOG(EXCEPTION) << "Can not drain data from the event buffer!";
}
tcp_client->OnReadHandler(read_buffer, read);
@ -196,25 +194,38 @@ void TcpClient::EventCallback(struct bufferevent *bev, std::int16_t events, void
void TcpClient::Start() {
MS_EXCEPTION_IF_NULL(event_base_);
int ret = event_base_dispatch(event_base_);
if (ret == 0) {
MS_LOG(INFO) << "Event base dispatch success!";
} else if (ret == 1) {
MS_LOG(ERROR) << "Event base dispatch failed with no events pending or active!";
} else if (ret == -1) {
MS_LOG(ERROR) << "Event base dispatch failed with error occurred!";
} else {
MS_LOG(EXCEPTION) << "Event base dispatch with unexpect error code!";
}
MSLOG_IF(INFO, ret == 0, NoExceptionType) << "Event base dispatch success!";
MSLOG_IF(mindspore::ERROR, ret == 1, NoExceptionType)
<< "Event base dispatch failed with no events pending or active!";
MSLOG_IF(mindspore::ERROR, ret == -1, NoExceptionType) << "Event base dispatch failed with error occurred!";
MSLOG_IF(mindspore::EXCEPTION, ret < -1, AbortedError) << "Event base dispatch with unexpect error code!";
}
void TcpClient::ReceiveMessage(const OnMessage &cb) { message_callback_ = cb; }
void TcpClient::StartWithNoBlock() {
MS_LOG(INFO) << "Start tcp client with no block!";
MS_EXCEPTION_IF_NULL(event_base_);
int ret = event_base_loop(event_base_, EVLOOP_NONBLOCK);
MSLOG_IF(INFO, ret == 0, NoExceptionType) << "Event base loop success!";
MSLOG_IF(mindspore::ERROR, ret == 1, NoExceptionType) << "Event base loop failed with no events pending or active!";
MSLOG_IF(mindspore::ERROR, ret == -1, NoExceptionType) << "Event base loop failed with error occurred!";
MSLOG_IF(mindspore::EXCEPTION, ret < -1, AbortedError) << "Event base loop with unexpect error code!";
}
void TcpClient::SendMessage(const void *buf, size_t num) const {
void TcpClient::SetMessageCallback(const OnMessage &cb) { message_callback_ = cb; }
void TcpClient::SendMessage(const CommMessage &message) const {
MS_EXCEPTION_IF_NULL(buffer_event_);
if (evbuffer_add(bufferevent_get_output(buffer_event_), buf, num) == -1) {
MS_LOG(EXCEPTION) << "Event buffer add failed!";
uint32_t buf_size = message.ByteSizeLong();
std::vector<unsigned char> serialized(buf_size);
message.SerializeToArray(serialized.data(), static_cast<int>(buf_size));
if (evbuffer_add(bufferevent_get_output(buffer_event_), &buf_size, sizeof(buf_size)) == -1) {
MS_LOG(EXCEPTION) << "Event buffer add header failed!";
}
if (evbuffer_add(bufferevent_get_output(buffer_event_), serialized.data(), buf_size) == -1) {
MS_LOG(EXCEPTION) << "Event buffer add protobuf data failed!";
}
}
} // namespace comm
} // namespace ps
} // namespace mindspore

View File

@ -23,6 +23,10 @@
#include <event2/bufferevent.h>
#include <functional>
#include <string>
#include <memory>
#include <vector>
#include "proto/comm.pb.h"
namespace mindspore {
namespace ps {
@ -30,24 +34,25 @@ namespace comm {
class TcpClient {
public:
using OnMessage = std::function<void(const TcpClient &, const void *, size_t)>;
using OnConnected = std::function<void(const TcpClient &)>;
using OnDisconnected = std::function<void(const TcpClient &, int)>;
using OnRead = std::function<void(const TcpClient &, const void *, size_t)>;
using OnTimeout = std::function<void(const TcpClient &)>;
using OnMessage = std::function<void(const TcpClient &, const CommMessage &)>;
explicit TcpClient(std::string address, std::uint16_t port);
explicit TcpClient(const std::string &address, std::uint16_t port);
virtual ~TcpClient();
std::string GetServerAddress() const;
void SetCallback(const OnConnected &conn, const OnDisconnected &disconn, const OnRead &read,
const OnTimeout &timeout);
void InitTcpClient();
void Init();
void StartWithDelay(int seconds);
void Stop();
void ReceiveMessage(const OnMessage &cb);
void SendMessage(const void *buf, size_t num) const;
void Start();
void StartWithNoBlock();
void SetMessageCallback(const OnMessage &cb);
void SendMessage(const CommMessage &message) const;
protected:
static void SetTcpNoDelay(const evutil_socket_t &fd);
@ -57,8 +62,9 @@ class TcpClient {
virtual void OnReadHandler(const void *buf, size_t num);
private:
TcpMessageHandler message_handler_;
OnMessage message_callback_;
TcpMessageHandler message_handler_;
OnConnected connected_callback_;
OnDisconnected disconnected_callback_;
OnRead read_callback_;
@ -71,6 +77,7 @@ class TcpClient {
std::string server_address_;
std::uint16_t server_port_;
};
} // namespace comm
} // namespace ps
} // namespace mindspore

View File

@ -15,6 +15,8 @@
*/
#include "ps/comm/tcp_message_handler.h"
#include <arpa/inet.h>
#include <iostream>
#include <utility>
@ -22,15 +24,55 @@ namespace mindspore {
namespace ps {
namespace comm {
void TcpMessageHandler::SetCallback(messageReceive message_receive) { message_callback_ = std::move(message_receive); }
void TcpMessageHandler::SetCallback(const messageReceive &message_receive) { message_callback_ = message_receive; }
void TcpMessageHandler::ReceiveMessage(const void *buffer, size_t num) {
MS_EXCEPTION_IF_NULL(buffer);
auto buffer_data = reinterpret_cast<const unsigned char *>(buffer);
if (message_callback_) {
message_callback_(buffer, num);
while (num > 0) {
if (remaining_length_ == 0) {
for (int i = 0; i < 4 && num > 0; ++i) {
header_[++header_index_] = *(buffer_data + i);
--num;
if (header_index_ == 3) {
message_length_ = *reinterpret_cast<const uint32_t *>(header_);
message_length_ = ntohl(message_length_);
remaining_length_ = message_length_;
message_buffer_.reset(new unsigned char[remaining_length_]);
buffer_data += i;
break;
}
}
}
if (remaining_length_ > 0) {
uint32_t copy_len = remaining_length_ <= num ? remaining_length_ : num;
remaining_length_ -= copy_len;
num -= copy_len;
int ret = memcpy_s(message_buffer_.get() + last_copy_len_, copy_len, buffer_data, copy_len);
last_copy_len_ += copy_len;
buffer_data += copy_len;
if (ret != 0) {
MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
}
if (remaining_length_ == 0) {
CommMessage pb_message;
pb_message.ParseFromArray(reinterpret_cast<const void *>(message_buffer_.get()), message_length_);
if (message_callback_) {
message_callback_(pb_message);
}
message_buffer_.reset();
message_buffer_ = nullptr;
header_index_ = 0;
last_copy_len_ = 0;
}
}
}
}
} // namespace comm
} // namespace ps
} // namespace mindspore

View File

@ -19,26 +19,43 @@
#include <functional>
#include <iostream>
#include <string>
#include <memory>
#include <vector>
#include "utils/log_adapter.h"
#include "proto/comm.pb.h"
#include "proto/ps.pb.h"
namespace mindspore {
namespace ps {
namespace comm {
using messageReceive = std::function<void(const void *buffer, size_t len)>;
using messageReceive = std::function<void(const CommMessage &message)>;
class TcpMessageHandler {
public:
TcpMessageHandler() = default;
TcpMessageHandler()
: is_parsed_(false),
message_buffer_(nullptr),
message_length_(0),
remaining_length_(0),
header_index_(-1),
last_copy_len_(0) {}
virtual ~TcpMessageHandler() = default;
void SetCallback(messageReceive cb);
void SetCallback(const messageReceive &cb);
void ReceiveMessage(const void *buffer, size_t num);
private:
messageReceive message_callback_;
bool is_parsed_;
std::unique_ptr<unsigned char> message_buffer_;
size_t message_length_;
uint32_t remaining_length_;
char header_[4];
int header_index_;
uint32_t last_copy_len_;
};
} // namespace comm
} // namespace ps

View File

@ -33,16 +33,12 @@ namespace mindspore {
namespace ps {
namespace comm {
void TcpConnection::InitConnection(const evutil_socket_t &fd, const struct bufferevent *bev, const TcpServer *server) {
MS_EXCEPTION_IF_NULL(bev);
MS_EXCEPTION_IF_NULL(server);
buffer_event_ = const_cast<struct bufferevent *>(bev);
fd_ = fd;
server_ = const_cast<TcpServer *>(server);
tcp_message_handler_.SetCallback([this, server](const void *buf, size_t num) {
OnServerReceiveMessage message_callback = server->GetServerReceiveMessage();
if (message_callback) message_callback(*server, *this, buf, num);
void TcpConnection::InitConnection() {
tcp_message_handler_.SetCallback([&](const CommMessage &message) {
OnServerReceiveMessage on_server_receive = server_->GetServerReceive();
if (on_server_receive) {
on_server_receive(*server_, *this, message);
}
});
}
@ -54,11 +50,26 @@ void TcpConnection::SendMessage(const void *buffer, size_t num) const {
}
}
TcpServer *TcpConnection::GetServer() const { return server_; }
TcpServer *TcpConnection::GetServer() const { return const_cast<TcpServer *>(server_); }
evutil_socket_t TcpConnection::GetFd() const { return fd_; }
const evutil_socket_t &TcpConnection::GetFd() const { return fd_; }
TcpServer::TcpServer(std::string address, std::uint16_t port)
void TcpConnection::SendMessage(const CommMessage &message) const {
MS_EXCEPTION_IF_NULL(buffer_event_);
uint32_t buf_size = message.ByteSizeLong();
std::vector<unsigned char> serialized(buf_size);
message.SerializeToArray(serialized.data(), static_cast<int>(buf_size));
if (evbuffer_add(bufferevent_get_output(const_cast<struct bufferevent *>(buffer_event_)), &buf_size,
sizeof(buf_size)) == -1) {
MS_LOG(EXCEPTION) << "Event buffer add header failed!";
}
if (evbuffer_add(bufferevent_get_output(const_cast<struct bufferevent *>(buffer_event_)), serialized.data(),
buf_size) == -1) {
MS_LOG(EXCEPTION) << "Event buffer add protobuf data failed!";
}
}
TcpServer::TcpServer(const std::string &address, std::uint16_t port)
: base_(nullptr),
signal_event_(nullptr),
listener_(nullptr),
@ -74,7 +85,7 @@ void TcpServer::SetServerCallback(const OnConnected &client_conn, const OnDiscon
this->client_accept_ = client_accept;
}
void TcpServer::InitServer() {
void TcpServer::Init() {
base_ = event_base_new();
MS_EXCEPTION_IF_NULL(base_);
CommUtil::CheckIp(server_address_);
@ -101,19 +112,26 @@ void TcpServer::InitServer() {
}
void TcpServer::Start() {
std::unique_lock<std::recursive_mutex> l(connection_mutex_);
std::unique_lock<std::recursive_mutex> lock(connection_mutex_);
MS_LOG(INFO) << "Start tcp server!";
MS_EXCEPTION_IF_NULL(base_);
int ret = event_base_dispatch(base_);
if (ret == 0) {
MS_LOG(INFO) << "Event base dispatch success!";
} else if (ret == 1) {
MS_LOG(ERROR) << "Event base dispatch failed with no events pending or active!";
} else if (ret == -1) {
MS_LOG(ERROR) << "Event base dispatch failed with error occurred!";
} else {
MS_LOG(EXCEPTION) << "Event base dispatch with unexpect error code!";
}
MSLOG_IF(INFO, ret == 0, NoExceptionType) << "Event base dispatch success!";
MSLOG_IF(mindspore::ERROR, ret == 1, NoExceptionType)
<< "Event base dispatch failed with no events pending or active!";
MSLOG_IF(mindspore::ERROR, ret == -1, NoExceptionType) << "Event base dispatch failed with error occurred!";
MSLOG_IF(mindspore::EXCEPTION, ret < -1, AbortedError) << "Event base dispatch with unexpect error code!";
}
void TcpServer::StartWithNoBlock() {
std::unique_lock<std::recursive_mutex> lock(connection_mutex_);
MS_LOG(INFO) << "Start tcp server with no block!";
MS_EXCEPTION_IF_NULL(base_);
int ret = event_base_loop(base_, EVLOOP_NONBLOCK);
MSLOG_IF(INFO, ret == 0, NoExceptionType) << "Event base loop success!";
MSLOG_IF(mindspore::ERROR, ret == 1, NoExceptionType) << "Event base loop failed with no events pending or active!";
MSLOG_IF(mindspore::ERROR, ret == -1, NoExceptionType) << "Event base loop failed with error occurred!";
MSLOG_IF(mindspore::EXCEPTION, ret < -1, AbortedError) << "Event base loop with unexpect error code!";
}
void TcpServer::Stop() {
@ -150,6 +168,8 @@ void TcpServer::AddConnection(const evutil_socket_t &fd, const TcpConnection *co
void TcpServer::RemoveConnection(const evutil_socket_t &fd) {
std::unique_lock<std::recursive_mutex> lock(connection_mutex_);
TcpConnection *connection = const_cast<TcpConnection *>(connections_.find(fd)->second);
delete connection;
connections_.erase(fd);
}
@ -166,10 +186,10 @@ void TcpServer::ListenerCallback(struct evconnlistener *, evutil_socket_t fd, st
return;
}
TcpConnection *conn = server->onCreateConnection();
TcpConnection *conn = server->onCreateConnection(bev, fd);
MS_EXCEPTION_IF_NULL(conn);
conn->InitConnection(fd, bev, server);
conn->InitConnection();
server->AddConnection(fd, conn);
bufferevent_setcb(bev, TcpServer::ReadCallback, nullptr, TcpServer::EventCallback, reinterpret_cast<void *>(conn));
if (bufferevent_enable(bev, EV_READ | EV_WRITE) == -1) {
@ -177,17 +197,18 @@ void TcpServer::ListenerCallback(struct evconnlistener *, evutil_socket_t fd, st
}
}
TcpConnection *TcpServer::onCreateConnection() {
TcpConnection *TcpServer::onCreateConnection(struct bufferevent *bev, const evutil_socket_t &fd) {
TcpConnection *conn = nullptr;
if (client_accept_)
conn = const_cast<TcpConnection *>(client_accept_(this));
else
conn = new TcpConnection();
if (client_accept_) {
conn = const_cast<TcpConnection *>(client_accept_(*this));
} else {
conn = new TcpConnection(bev, fd, this);
}
return conn;
}
OnServerReceiveMessage TcpServer::GetServerReceiveMessage() const { return message_callback_; }
OnServerReceiveMessage TcpServer::GetServerReceive() const { return message_callback_; }
void TcpServer::SignalCallback(evutil_socket_t, std::int16_t, void *data) {
auto server = reinterpret_cast<class TcpServer *>(data);
@ -207,9 +228,9 @@ void TcpServer::ReadCallback(struct bufferevent *bev, void *connection) {
auto conn = static_cast<class TcpConnection *>(connection);
struct evbuffer *buf = bufferevent_get_input(bev);
char read_buffer[4096];
auto read = 0;
while ((read = EVBUFFER_LENGTH(buf)) > 0) {
if (evbuffer_remove(buf, &read_buffer, sizeof(read_buffer)) == -1) {
while (EVBUFFER_LENGTH(buf) > 0) {
int read = evbuffer_remove(buf, &read_buffer, sizeof(read_buffer));
if (read == -1) {
MS_LOG(EXCEPTION) << "Can not drain data from the event buffer!";
}
conn->OnReadHandler(read_buffer, static_cast<size_t>(read));
@ -219,43 +240,47 @@ void TcpServer::ReadCallback(struct bufferevent *bev, void *connection) {
void TcpServer::EventCallback(struct bufferevent *bev, std::int16_t events, void *data) {
MS_EXCEPTION_IF_NULL(bev);
MS_EXCEPTION_IF_NULL(data);
struct evbuffer *output = bufferevent_get_output(bev);
size_t remain = evbuffer_get_length(output);
auto conn = reinterpret_cast<TcpConnection *>(data);
TcpServer *srv = conn->GetServer();
if (events & BEV_EVENT_EOF) {
MS_LOG(INFO) << "Event buffer end of file!";
// Notify about disconnection
if (srv->client_disconnection_) srv->client_disconnection_(srv, conn);
if (srv->client_disconnection_) {
srv->client_disconnection_(*srv, *conn);
}
// Free connection structures
srv->RemoveConnection(conn->GetFd());
bufferevent_free(bev);
} else if (events & BEV_EVENT_ERROR) {
MS_LOG(ERROR) << "Event buffer remain data: " << remain;
// Free connection structures
srv->RemoveConnection(conn->GetFd());
bufferevent_free(bev);
// Notify about disconnection
if (srv->client_disconnection_) srv->client_disconnection_(srv, conn);
if (srv->client_disconnection_) {
srv->client_disconnection_(*srv, *conn);
}
} else {
MS_LOG(ERROR) << "Unhandled event!";
}
}
void TcpServer::ReceiveMessage(const OnServerReceiveMessage &cb) { message_callback_ = cb; }
void TcpServer::SendMessage(const TcpConnection &conn, const CommMessage &message) { conn.SendMessage(message); }
void TcpServer::SendMessage(const TcpConnection &conn, const void *data, size_t num) {
MS_EXCEPTION_IF_NULL(data);
auto mc = const_cast<TcpConnection &>(conn);
mc.SendMessage(data, num);
}
void TcpServer::SendMessage(const void *data, size_t num) {
MS_EXCEPTION_IF_NULL(data);
void TcpServer::SendMessage(const CommMessage &message) {
std::unique_lock<std::recursive_mutex> lock(connection_mutex_);
for (auto it = connections_.begin(); it != connections_.end(); ++it) {
SendMessage(*it->second, data, num);
SendMessage(*it->second, message);
}
}
void TcpServer::SetMessageCallback(const OnServerReceiveMessage &cb) { message_callback_ = cb; }
} // namespace comm
} // namespace ps
} // namespace mindspore

View File

@ -27,6 +27,8 @@
#include <map>
#include <mutex>
#include <string>
#include <memory>
#include <vector>
#include "utils/log_adapter.h"
#include "ps/comm/tcp_message_handler.h"
@ -38,46 +40,49 @@ namespace comm {
class TcpServer;
class TcpConnection {
public:
TcpConnection() : buffer_event_(nullptr), fd_(0), server_(nullptr) {}
explicit TcpConnection(struct bufferevent *bev, const evutil_socket_t &fd, const TcpServer *server)
: buffer_event_(bev), fd_(0), server_(server) {}
virtual ~TcpConnection() = default;
virtual void InitConnection(const evutil_socket_t &fd, const struct bufferevent *bev, const TcpServer *server);
void SendMessage(const void *buffer, size_t num) const;
virtual void InitConnection();
virtual void SendMessage(const void *buffer, size_t num) const;
void SendMessage(const CommMessage &message) const;
virtual void OnReadHandler(const void *buffer, size_t numBytes);
TcpServer *GetServer() const;
evutil_socket_t GetFd() const;
const evutil_socket_t &GetFd() const;
protected:
TcpMessageHandler tcp_message_handler_;
struct bufferevent *buffer_event_;
evutil_socket_t fd_;
TcpServer *server_;
const TcpServer *server_;
TcpMessageHandler tcp_message_handler_;
};
using OnServerReceiveMessage =
std::function<void(const TcpServer &tcp_server, const TcpConnection &conn, const void *buffer, size_t num)>;
std::function<void(const TcpServer &tcp_server, const TcpConnection &conn, const CommMessage &)>;
class TcpServer {
public:
using OnConnected = std::function<void(const TcpServer *, const TcpConnection *)>;
using OnDisconnected = std::function<void(const TcpServer *, const TcpConnection *)>;
using OnAccepted = std::function<const TcpConnection *(const TcpServer *)>;
using OnConnected = std::function<void(const TcpServer &, const TcpConnection &)>;
using OnDisconnected = std::function<void(const TcpServer &, const TcpConnection &)>;
using OnAccepted = std::function<const TcpConnection *(const TcpServer &)>;
explicit TcpServer(std::string address, std::uint16_t port);
explicit TcpServer(const std::string &address, std::uint16_t port);
virtual ~TcpServer();
void SetServerCallback(const OnConnected &client_conn, const OnDisconnected &client_disconn,
const OnAccepted &client_accept);
void InitServer();
void Init();
void Start();
void StartWithNoBlock();
void Stop();
void SendToAllClients(const char *data, size_t len);
void AddConnection(const evutil_socket_t &fd, const TcpConnection *connection);
void RemoveConnection(const evutil_socket_t &fd);
void ReceiveMessage(const OnServerReceiveMessage &cb);
static void SendMessage(const TcpConnection &conn, const void *data, size_t num);
void SendMessage(const void *data, size_t num);
OnServerReceiveMessage GetServerReceiveMessage() const;
OnServerReceiveMessage GetServerReceive() const;
void SetMessageCallback(const OnServerReceiveMessage &cb);
static void SendMessage(const TcpConnection &conn, const CommMessage &message);
void SendMessage(const CommMessage &message);
protected:
static void ListenerCallback(struct evconnlistener *listener, evutil_socket_t socket, struct sockaddr *saddr,
@ -85,9 +90,8 @@ class TcpServer {
static void SignalCallback(evutil_socket_t sig, std::int16_t events, void *server);
static void ReadCallback(struct bufferevent *, void *connection);
static void EventCallback(struct bufferevent *, std::int16_t events, void *server);
virtual TcpConnection *onCreateConnection();
virtual TcpConnection *onCreateConnection(struct bufferevent *bev, const evutil_socket_t &fd);
private:
struct event_base *base_;
struct event *signal_event_;
struct evconnlistener *listener_;
@ -101,6 +105,7 @@ class TcpServer {
std::recursive_mutex connection_mutex_;
OnServerReceiveMessage message_callback_;
};
} // namespace comm
} // namespace ps
} // namespace mindspore

View File

@ -24,6 +24,7 @@
#include <iostream>
#include <string>
#include <thread>
#include <memory>
namespace mindspore {
namespace ps {
@ -31,7 +32,9 @@ namespace comm {
class TestHttpServer : public UT::Common {
public:
TestHttpServer() = default;
TestHttpServer() : server_(nullptr) {}
virtual ~TestHttpServer() = default;
static void testGetHandler(std::shared_ptr<HttpMessageHandler> resp) {
std::string host = resp->GetRequestHost();
@ -57,7 +60,7 @@ class TestHttpServer : public UT::Common {
}
void SetUp() override {
server_ = new HttpServer("0.0.0.0", 9999);
server_ = std::make_unique<HttpServer>("0.0.0.0", 9999);
OnRequestReceive http_get_func = std::bind(
[](std::shared_ptr<HttpMessageHandler> resp) {
EXPECT_STREQ(resp->GetPathParam("key1").c_str(), "value1");
@ -106,7 +109,7 @@ class TestHttpServer : public UT::Common {
}
private:
HttpServer *server_;
std::unique_ptr<HttpServer> server_;
};
TEST_F(TestHttpServer, httpGetQequest) {
@ -143,13 +146,13 @@ TEST_F(TestHttpServer, messageHandler) {
}
TEST_F(TestHttpServer, portErrorNoException) {
HttpServer *server_exception = new HttpServer("0.0.0.0", -1);
auto server_exception = std::make_unique<HttpServer>("0.0.0.0", -1);
OnRequestReceive http_handler_func = std::bind(TestHttpServer::testGetHandler, std::placeholders::_1);
EXPECT_NO_THROW(server_exception->RegisterRoute("/handler", &http_handler_func));
}
TEST_F(TestHttpServer, addressException) {
HttpServer *server_exception = new HttpServer("12344.0.0.0", 9998);
auto server_exception = std::make_unique<HttpServer>("12344.0.0.0", 9998);
OnRequestReceive http_handler_func = std::bind(TestHttpServer::testGetHandler, std::placeholders::_1);
ASSERT_THROW(server_exception->RegisterRoute("/handler", &http_handler_func), std::exception);
}

View File

@ -14,6 +14,8 @@
* limitations under the License.
*/
#include <memory>
#include "common/common_test.h"
#include "ps/comm/tcp_client.h"
@ -26,19 +28,19 @@ class TestTcpClient : public UT::Common {
};
TEST_F(TestTcpClient, InitClientIPError) {
auto client = new TcpClient("127.0.0.13543", 9000);
client->ReceiveMessage(
[](const TcpClient &client, const void *buffer, size_t num) { client.SendMessage(buffer, num); });
auto client = std::make_unique<TcpClient>("127.0.0.13543", 9000);
ASSERT_THROW(client->InitTcpClient(), std::exception);
client->SetMessageCallback([](const TcpClient &client, const CommMessage &message) { client.SendMessage(message); });
ASSERT_THROW(client->Init(), std::exception);
}
TEST_F(TestTcpClient, InitClientPortErrorNoException) {
auto client = new TcpClient("127.0.0.1", -1);
client->ReceiveMessage(
[](const TcpClient &client, const void *buffer, size_t num) { client.SendMessage(buffer, num); });
auto client = std::make_unique<TcpClient>("127.0.0.1", -1);
EXPECT_NO_THROW(client->InitTcpClient());
client->SetMessageCallback([](const TcpClient &client, const CommMessage &message) { client.SendMessage(message); });
EXPECT_NO_THROW(client->Init());
}
} // namespace comm

View File

@ -18,6 +18,7 @@
#include "ps/comm/tcp_server.h"
#include "common/common_test.h"
#include <memory>
#include <thread>
namespace mindspore {
@ -25,16 +26,20 @@ namespace ps {
namespace comm {
class TestTcpServer : public UT::Common {
public:
TestTcpServer() = default;
TestTcpServer() : client_(nullptr), server_(nullptr) {}
virtual ~TestTcpServer() = default;
void SetUp() override {
server_ = new TcpServer("127.0.0.1", 9000);
server_ = std::make_unique<TcpServer>("127.0.0.1", 9998);
std::unique_ptr<std::thread> http_server_thread_(nullptr);
http_server_thread_ = std::make_unique<std::thread>([&]() {
server_->ReceiveMessage([](const TcpServer &server, const TcpConnection &conn, const void *buffer, size_t num) {
EXPECT_STREQ(std::string(reinterpret_cast<const char *>(buffer), num).c_str(), "TCP_MESSAGE");
server.SendMessage(conn, buffer, num);
server_->SetMessageCallback([](const TcpServer &server, const TcpConnection &conn, const CommMessage &message) {
KVMessage kv_message;
kv_message.ParseFromString(message.data());
EXPECT_EQ(2, kv_message.keys_size());
server.SendMessage(conn, message);
});
server_->InitServer();
server_->Init();
server_->Start();
});
http_server_thread_->detach();
@ -47,21 +52,32 @@ class TestTcpServer : public UT::Common {
server_->Stop();
}
TcpClient *client_;
TcpServer *server_;
const std::string test_message_ = "TCP_MESSAGE";
std::unique_ptr<TcpClient> client_;
std::unique_ptr<TcpServer> server_;
};
TEST_F(TestTcpServer, ServerSendMessage) {
client_ = new TcpClient("127.0.0.1", 9000);
client_ = std::make_unique<TcpClient>("127.0.0.1", 9998);
std::unique_ptr<std::thread> http_client_thread(nullptr);
http_client_thread = std::make_unique<std::thread>([&]() {
client_->ReceiveMessage([](const TcpClient &client, const void *buffer, size_t num) {
EXPECT_STREQ(std::string(reinterpret_cast<const char *>(buffer), num).c_str(), "TCP_MESSAGE");
client_->SetMessageCallback([](const TcpClient &client, const CommMessage &message) {
KVMessage kv_message;
kv_message.ParseFromString(message.data());
EXPECT_EQ(2, kv_message.keys_size());
});
client_->InitTcpClient();
client_->SendMessage(test_message_.c_str(), test_message_.size());
client_->Init();
CommMessage comm_message;
KVMessage kv_message;
std::vector<int> keys{1, 2};
std::vector<int> values{3, 4};
*kv_message.mutable_keys() = {keys.begin(), keys.end()};
*kv_message.mutable_values() = {values.begin(), values.end()};
comm_message.set_data(kv_message.SerializeAsString());
client_->SendMessage(comm_message);
client_->Start();
});
http_client_thread->detach();