forked from mindspore-Ecosystem/mindspore
!7911 added key value message
Merge pull request !7911 from anancds/kv-patch
This commit is contained in:
commit
7123a2c1d1
|
@ -100,6 +100,11 @@ message("onnx proto path is :" ${ONNX_PROTO})
|
|||
ms_protobuf_generate(ONNX_PROTO_SRCS ONNX_PROTO_HDRS ${ONNX_PROTO})
|
||||
list(APPEND MINDSPORE_PROTO_LIST ${ONNX_PROTO_SRCS})
|
||||
|
||||
include_directories("${CMAKE_BINARY_DIR}/ps/comm")
|
||||
file(GLOB_RECURSE COMM_PROTO_IN RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "ps/comm/protos/*.proto")
|
||||
ms_protobuf_generate(COMM_PROTO_SRCS COMM_PROTO_HDRS ${COMM_PROTO_IN})
|
||||
list(APPEND MINDSPORE_PROTO_LIST ${COMM_PROTO_SRCS})
|
||||
|
||||
if (ENABLE_DEBUGGER)
|
||||
# debugger: compile proto files
|
||||
include_directories("${CMAKE_BINARY_DIR}/debug/debugger")
|
||||
|
@ -290,7 +295,7 @@ if (CMAKE_SYSTEM_NAME MATCHES "Windows")
|
|||
target_link_libraries(_c_expression PRIVATE -Wl,--whole-archive mindspore -Wl,--no-whole-archive)
|
||||
else ()
|
||||
if (ENABLE_CPU AND (ENABLE_D OR ENABLE_GPU))
|
||||
target_link_libraries(mindspore mindspore::pslite mindspore::protobuf mindspore::event mindspore::event_pthreads ${zeromq_DIRPATH}/zmq_install/lib/libzmq.a)
|
||||
target_link_libraries(mindspore mindspore::pslite proto_input mindspore::protobuf mindspore::event mindspore::event_pthreads ${zeromq_DIRPATH}/zmq_install/lib/libzmq.a)
|
||||
if (${ENABLE_IBVERBS} STREQUAL "ON")
|
||||
target_link_libraries(mindspore ibverbs rdmacm)
|
||||
endif()
|
||||
|
|
|
@ -0,0 +1,42 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
syntax = "proto3";
|
||||
import "google/protobuf/any.proto";
|
||||
package mindspore.ps;
|
||||
option optimize_for = LITE_RUNTIME;
|
||||
|
||||
message MessageMeta {
|
||||
// hostname or ip
|
||||
string hostname = 1;
|
||||
// the port of this node
|
||||
int32 port = 2;
|
||||
// the command of this message,for example: register、heartbeat、data
|
||||
int32 cmd = 3;
|
||||
// the timestamp of this message
|
||||
int32 timestamp = 4;
|
||||
// data type of message
|
||||
repeated int32 data_type = 5 [packed = true];
|
||||
// message.data_size
|
||||
int32 data_size = 6;
|
||||
}
|
||||
|
||||
|
||||
message CommMessage {
|
||||
MessageMeta pb_meta = 1;
|
||||
bytes data = 2;
|
||||
}
|
||||
|
|
@ -0,0 +1,25 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
message KVMessage {
|
||||
repeated int32 keys = 1;
|
||||
repeated float values = 2;
|
||||
}
|
||||
|
||||
message HeartBeatMessage {
|
||||
// *.*.*.*:port
|
||||
repeated string host_and_port = 1;
|
||||
}
|
|
@ -36,18 +36,16 @@ namespace mindspore {
|
|||
namespace ps {
|
||||
namespace comm {
|
||||
|
||||
TcpClient::TcpClient(std::string address, std::uint16_t port)
|
||||
TcpClient::TcpClient(const std::string &address, std::uint16_t port)
|
||||
: event_base_(nullptr),
|
||||
event_timeout_(nullptr),
|
||||
buffer_event_(nullptr),
|
||||
server_address_(std::move(address)),
|
||||
server_port_(port) {
|
||||
message_handler_.SetCallback([this](const void *buf, size_t num) {
|
||||
if (buf == nullptr) {
|
||||
if (disconnected_callback_) disconnected_callback_(*this, 200);
|
||||
Stop();
|
||||
message_handler_.SetCallback([this](const CommMessage &message) {
|
||||
if (message_callback_) {
|
||||
message_callback_(*this, message);
|
||||
}
|
||||
if (message_callback_) message_callback_(*this, buf, num);
|
||||
});
|
||||
}
|
||||
|
||||
|
@ -63,7 +61,7 @@ void TcpClient::SetCallback(const OnConnected &conn, const OnDisconnected &disco
|
|||
timeout_callback_ = timeout;
|
||||
}
|
||||
|
||||
void TcpClient::InitTcpClient() {
|
||||
void TcpClient::Init() {
|
||||
if (buffer_event_) {
|
||||
return;
|
||||
}
|
||||
|
@ -139,7 +137,7 @@ void TcpClient::SetTcpNoDelay(const evutil_socket_t &fd) {
|
|||
void TcpClient::TimeoutCallback(evutil_socket_t, std::int16_t, void *arg) {
|
||||
MS_EXCEPTION_IF_NULL(arg);
|
||||
auto tcp_client = reinterpret_cast<TcpClient *>(arg);
|
||||
tcp_client->InitTcpClient();
|
||||
tcp_client->Init();
|
||||
}
|
||||
|
||||
void TcpClient::ReadCallback(struct bufferevent *bev, void *ctx) {
|
||||
|
@ -150,10 +148,10 @@ void TcpClient::ReadCallback(struct bufferevent *bev, void *ctx) {
|
|||
MS_EXCEPTION_IF_NULL(input);
|
||||
|
||||
char read_buffer[4096];
|
||||
int read = 0;
|
||||
|
||||
while ((read = EVBUFFER_LENGTH(input)) > 0) {
|
||||
if (evbuffer_remove(input, &read_buffer, sizeof(read_buffer)) == -1) {
|
||||
while (EVBUFFER_LENGTH(input) > 0) {
|
||||
int read = evbuffer_remove(input, &read_buffer, sizeof(read_buffer));
|
||||
if (read == -1) {
|
||||
MS_LOG(EXCEPTION) << "Can not drain data from the event buffer!";
|
||||
}
|
||||
tcp_client->OnReadHandler(read_buffer, read);
|
||||
|
@ -196,25 +194,38 @@ void TcpClient::EventCallback(struct bufferevent *bev, std::int16_t events, void
|
|||
void TcpClient::Start() {
|
||||
MS_EXCEPTION_IF_NULL(event_base_);
|
||||
int ret = event_base_dispatch(event_base_);
|
||||
if (ret == 0) {
|
||||
MS_LOG(INFO) << "Event base dispatch success!";
|
||||
} else if (ret == 1) {
|
||||
MS_LOG(ERROR) << "Event base dispatch failed with no events pending or active!";
|
||||
} else if (ret == -1) {
|
||||
MS_LOG(ERROR) << "Event base dispatch failed with error occurred!";
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Event base dispatch with unexpect error code!";
|
||||
}
|
||||
MSLOG_IF(INFO, ret == 0, NoExceptionType) << "Event base dispatch success!";
|
||||
MSLOG_IF(mindspore::ERROR, ret == 1, NoExceptionType)
|
||||
<< "Event base dispatch failed with no events pending or active!";
|
||||
MSLOG_IF(mindspore::ERROR, ret == -1, NoExceptionType) << "Event base dispatch failed with error occurred!";
|
||||
MSLOG_IF(mindspore::EXCEPTION, ret < -1, AbortedError) << "Event base dispatch with unexpect error code!";
|
||||
}
|
||||
|
||||
void TcpClient::ReceiveMessage(const OnMessage &cb) { message_callback_ = cb; }
|
||||
void TcpClient::StartWithNoBlock() {
|
||||
MS_LOG(INFO) << "Start tcp client with no block!";
|
||||
MS_EXCEPTION_IF_NULL(event_base_);
|
||||
int ret = event_base_loop(event_base_, EVLOOP_NONBLOCK);
|
||||
MSLOG_IF(INFO, ret == 0, NoExceptionType) << "Event base loop success!";
|
||||
MSLOG_IF(mindspore::ERROR, ret == 1, NoExceptionType) << "Event base loop failed with no events pending or active!";
|
||||
MSLOG_IF(mindspore::ERROR, ret == -1, NoExceptionType) << "Event base loop failed with error occurred!";
|
||||
MSLOG_IF(mindspore::EXCEPTION, ret < -1, AbortedError) << "Event base loop with unexpect error code!";
|
||||
}
|
||||
|
||||
void TcpClient::SendMessage(const void *buf, size_t num) const {
|
||||
void TcpClient::SetMessageCallback(const OnMessage &cb) { message_callback_ = cb; }
|
||||
|
||||
void TcpClient::SendMessage(const CommMessage &message) const {
|
||||
MS_EXCEPTION_IF_NULL(buffer_event_);
|
||||
if (evbuffer_add(bufferevent_get_output(buffer_event_), buf, num) == -1) {
|
||||
MS_LOG(EXCEPTION) << "Event buffer add failed!";
|
||||
uint32_t buf_size = message.ByteSizeLong();
|
||||
std::vector<unsigned char> serialized(buf_size);
|
||||
message.SerializeToArray(serialized.data(), static_cast<int>(buf_size));
|
||||
if (evbuffer_add(bufferevent_get_output(buffer_event_), &buf_size, sizeof(buf_size)) == -1) {
|
||||
MS_LOG(EXCEPTION) << "Event buffer add header failed!";
|
||||
}
|
||||
if (evbuffer_add(bufferevent_get_output(buffer_event_), serialized.data(), buf_size) == -1) {
|
||||
MS_LOG(EXCEPTION) << "Event buffer add protobuf data failed!";
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace comm
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -23,6 +23,10 @@
|
|||
#include <event2/bufferevent.h>
|
||||
#include <functional>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "proto/comm.pb.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
|
@ -30,24 +34,25 @@ namespace comm {
|
|||
|
||||
class TcpClient {
|
||||
public:
|
||||
using OnMessage = std::function<void(const TcpClient &, const void *, size_t)>;
|
||||
using OnConnected = std::function<void(const TcpClient &)>;
|
||||
using OnDisconnected = std::function<void(const TcpClient &, int)>;
|
||||
using OnRead = std::function<void(const TcpClient &, const void *, size_t)>;
|
||||
using OnTimeout = std::function<void(const TcpClient &)>;
|
||||
using OnMessage = std::function<void(const TcpClient &, const CommMessage &)>;
|
||||
|
||||
explicit TcpClient(std::string address, std::uint16_t port);
|
||||
explicit TcpClient(const std::string &address, std::uint16_t port);
|
||||
virtual ~TcpClient();
|
||||
|
||||
std::string GetServerAddress() const;
|
||||
void SetCallback(const OnConnected &conn, const OnDisconnected &disconn, const OnRead &read,
|
||||
const OnTimeout &timeout);
|
||||
void InitTcpClient();
|
||||
void Init();
|
||||
void StartWithDelay(int seconds);
|
||||
void Stop();
|
||||
void ReceiveMessage(const OnMessage &cb);
|
||||
void SendMessage(const void *buf, size_t num) const;
|
||||
void Start();
|
||||
void StartWithNoBlock();
|
||||
void SetMessageCallback(const OnMessage &cb);
|
||||
void SendMessage(const CommMessage &message) const;
|
||||
|
||||
protected:
|
||||
static void SetTcpNoDelay(const evutil_socket_t &fd);
|
||||
|
@ -57,8 +62,9 @@ class TcpClient {
|
|||
virtual void OnReadHandler(const void *buf, size_t num);
|
||||
|
||||
private:
|
||||
TcpMessageHandler message_handler_;
|
||||
OnMessage message_callback_;
|
||||
TcpMessageHandler message_handler_;
|
||||
|
||||
OnConnected connected_callback_;
|
||||
OnDisconnected disconnected_callback_;
|
||||
OnRead read_callback_;
|
||||
|
@ -71,6 +77,7 @@ class TcpClient {
|
|||
std::string server_address_;
|
||||
std::uint16_t server_port_;
|
||||
};
|
||||
|
||||
} // namespace comm
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -15,6 +15,8 @@
|
|||
*/
|
||||
|
||||
#include "ps/comm/tcp_message_handler.h"
|
||||
|
||||
#include <arpa/inet.h>
|
||||
#include <iostream>
|
||||
#include <utility>
|
||||
|
||||
|
@ -22,15 +24,55 @@ namespace mindspore {
|
|||
namespace ps {
|
||||
namespace comm {
|
||||
|
||||
void TcpMessageHandler::SetCallback(messageReceive message_receive) { message_callback_ = std::move(message_receive); }
|
||||
void TcpMessageHandler::SetCallback(const messageReceive &message_receive) { message_callback_ = message_receive; }
|
||||
|
||||
void TcpMessageHandler::ReceiveMessage(const void *buffer, size_t num) {
|
||||
MS_EXCEPTION_IF_NULL(buffer);
|
||||
auto buffer_data = reinterpret_cast<const unsigned char *>(buffer);
|
||||
|
||||
if (message_callback_) {
|
||||
message_callback_(buffer, num);
|
||||
while (num > 0) {
|
||||
if (remaining_length_ == 0) {
|
||||
for (int i = 0; i < 4 && num > 0; ++i) {
|
||||
header_[++header_index_] = *(buffer_data + i);
|
||||
--num;
|
||||
if (header_index_ == 3) {
|
||||
message_length_ = *reinterpret_cast<const uint32_t *>(header_);
|
||||
message_length_ = ntohl(message_length_);
|
||||
remaining_length_ = message_length_;
|
||||
message_buffer_.reset(new unsigned char[remaining_length_]);
|
||||
buffer_data += i;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (remaining_length_ > 0) {
|
||||
uint32_t copy_len = remaining_length_ <= num ? remaining_length_ : num;
|
||||
remaining_length_ -= copy_len;
|
||||
num -= copy_len;
|
||||
|
||||
int ret = memcpy_s(message_buffer_.get() + last_copy_len_, copy_len, buffer_data, copy_len);
|
||||
last_copy_len_ += copy_len;
|
||||
buffer_data += copy_len;
|
||||
if (ret != 0) {
|
||||
MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
|
||||
}
|
||||
|
||||
if (remaining_length_ == 0) {
|
||||
CommMessage pb_message;
|
||||
pb_message.ParseFromArray(reinterpret_cast<const void *>(message_buffer_.get()), message_length_);
|
||||
if (message_callback_) {
|
||||
message_callback_(pb_message);
|
||||
}
|
||||
message_buffer_.reset();
|
||||
message_buffer_ = nullptr;
|
||||
header_index_ = 0;
|
||||
last_copy_len_ = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace comm
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -19,26 +19,43 @@
|
|||
|
||||
#include <functional>
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "utils/log_adapter.h"
|
||||
#include "proto/comm.pb.h"
|
||||
#include "proto/ps.pb.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
namespace comm {
|
||||
|
||||
using messageReceive = std::function<void(const void *buffer, size_t len)>;
|
||||
using messageReceive = std::function<void(const CommMessage &message)>;
|
||||
|
||||
class TcpMessageHandler {
|
||||
public:
|
||||
TcpMessageHandler() = default;
|
||||
TcpMessageHandler()
|
||||
: is_parsed_(false),
|
||||
message_buffer_(nullptr),
|
||||
message_length_(0),
|
||||
remaining_length_(0),
|
||||
header_index_(-1),
|
||||
last_copy_len_(0) {}
|
||||
virtual ~TcpMessageHandler() = default;
|
||||
|
||||
void SetCallback(messageReceive cb);
|
||||
void SetCallback(const messageReceive &cb);
|
||||
void ReceiveMessage(const void *buffer, size_t num);
|
||||
|
||||
private:
|
||||
messageReceive message_callback_;
|
||||
bool is_parsed_;
|
||||
std::unique_ptr<unsigned char> message_buffer_;
|
||||
size_t message_length_;
|
||||
uint32_t remaining_length_;
|
||||
char header_[4];
|
||||
int header_index_;
|
||||
uint32_t last_copy_len_;
|
||||
};
|
||||
} // namespace comm
|
||||
} // namespace ps
|
||||
|
|
|
@ -33,16 +33,12 @@ namespace mindspore {
|
|||
namespace ps {
|
||||
namespace comm {
|
||||
|
||||
void TcpConnection::InitConnection(const evutil_socket_t &fd, const struct bufferevent *bev, const TcpServer *server) {
|
||||
MS_EXCEPTION_IF_NULL(bev);
|
||||
MS_EXCEPTION_IF_NULL(server);
|
||||
buffer_event_ = const_cast<struct bufferevent *>(bev);
|
||||
fd_ = fd;
|
||||
server_ = const_cast<TcpServer *>(server);
|
||||
|
||||
tcp_message_handler_.SetCallback([this, server](const void *buf, size_t num) {
|
||||
OnServerReceiveMessage message_callback = server->GetServerReceiveMessage();
|
||||
if (message_callback) message_callback(*server, *this, buf, num);
|
||||
void TcpConnection::InitConnection() {
|
||||
tcp_message_handler_.SetCallback([&](const CommMessage &message) {
|
||||
OnServerReceiveMessage on_server_receive = server_->GetServerReceive();
|
||||
if (on_server_receive) {
|
||||
on_server_receive(*server_, *this, message);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
|
@ -54,11 +50,26 @@ void TcpConnection::SendMessage(const void *buffer, size_t num) const {
|
|||
}
|
||||
}
|
||||
|
||||
TcpServer *TcpConnection::GetServer() const { return server_; }
|
||||
TcpServer *TcpConnection::GetServer() const { return const_cast<TcpServer *>(server_); }
|
||||
|
||||
evutil_socket_t TcpConnection::GetFd() const { return fd_; }
|
||||
const evutil_socket_t &TcpConnection::GetFd() const { return fd_; }
|
||||
|
||||
TcpServer::TcpServer(std::string address, std::uint16_t port)
|
||||
void TcpConnection::SendMessage(const CommMessage &message) const {
|
||||
MS_EXCEPTION_IF_NULL(buffer_event_);
|
||||
uint32_t buf_size = message.ByteSizeLong();
|
||||
std::vector<unsigned char> serialized(buf_size);
|
||||
message.SerializeToArray(serialized.data(), static_cast<int>(buf_size));
|
||||
if (evbuffer_add(bufferevent_get_output(const_cast<struct bufferevent *>(buffer_event_)), &buf_size,
|
||||
sizeof(buf_size)) == -1) {
|
||||
MS_LOG(EXCEPTION) << "Event buffer add header failed!";
|
||||
}
|
||||
if (evbuffer_add(bufferevent_get_output(const_cast<struct bufferevent *>(buffer_event_)), serialized.data(),
|
||||
buf_size) == -1) {
|
||||
MS_LOG(EXCEPTION) << "Event buffer add protobuf data failed!";
|
||||
}
|
||||
}
|
||||
|
||||
TcpServer::TcpServer(const std::string &address, std::uint16_t port)
|
||||
: base_(nullptr),
|
||||
signal_event_(nullptr),
|
||||
listener_(nullptr),
|
||||
|
@ -74,7 +85,7 @@ void TcpServer::SetServerCallback(const OnConnected &client_conn, const OnDiscon
|
|||
this->client_accept_ = client_accept;
|
||||
}
|
||||
|
||||
void TcpServer::InitServer() {
|
||||
void TcpServer::Init() {
|
||||
base_ = event_base_new();
|
||||
MS_EXCEPTION_IF_NULL(base_);
|
||||
CommUtil::CheckIp(server_address_);
|
||||
|
@ -101,19 +112,26 @@ void TcpServer::InitServer() {
|
|||
}
|
||||
|
||||
void TcpServer::Start() {
|
||||
std::unique_lock<std::recursive_mutex> l(connection_mutex_);
|
||||
std::unique_lock<std::recursive_mutex> lock(connection_mutex_);
|
||||
MS_LOG(INFO) << "Start tcp server!";
|
||||
MS_EXCEPTION_IF_NULL(base_);
|
||||
int ret = event_base_dispatch(base_);
|
||||
if (ret == 0) {
|
||||
MS_LOG(INFO) << "Event base dispatch success!";
|
||||
} else if (ret == 1) {
|
||||
MS_LOG(ERROR) << "Event base dispatch failed with no events pending or active!";
|
||||
} else if (ret == -1) {
|
||||
MS_LOG(ERROR) << "Event base dispatch failed with error occurred!";
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Event base dispatch with unexpect error code!";
|
||||
}
|
||||
MSLOG_IF(INFO, ret == 0, NoExceptionType) << "Event base dispatch success!";
|
||||
MSLOG_IF(mindspore::ERROR, ret == 1, NoExceptionType)
|
||||
<< "Event base dispatch failed with no events pending or active!";
|
||||
MSLOG_IF(mindspore::ERROR, ret == -1, NoExceptionType) << "Event base dispatch failed with error occurred!";
|
||||
MSLOG_IF(mindspore::EXCEPTION, ret < -1, AbortedError) << "Event base dispatch with unexpect error code!";
|
||||
}
|
||||
|
||||
void TcpServer::StartWithNoBlock() {
|
||||
std::unique_lock<std::recursive_mutex> lock(connection_mutex_);
|
||||
MS_LOG(INFO) << "Start tcp server with no block!";
|
||||
MS_EXCEPTION_IF_NULL(base_);
|
||||
int ret = event_base_loop(base_, EVLOOP_NONBLOCK);
|
||||
MSLOG_IF(INFO, ret == 0, NoExceptionType) << "Event base loop success!";
|
||||
MSLOG_IF(mindspore::ERROR, ret == 1, NoExceptionType) << "Event base loop failed with no events pending or active!";
|
||||
MSLOG_IF(mindspore::ERROR, ret == -1, NoExceptionType) << "Event base loop failed with error occurred!";
|
||||
MSLOG_IF(mindspore::EXCEPTION, ret < -1, AbortedError) << "Event base loop with unexpect error code!";
|
||||
}
|
||||
|
||||
void TcpServer::Stop() {
|
||||
|
@ -150,6 +168,8 @@ void TcpServer::AddConnection(const evutil_socket_t &fd, const TcpConnection *co
|
|||
|
||||
void TcpServer::RemoveConnection(const evutil_socket_t &fd) {
|
||||
std::unique_lock<std::recursive_mutex> lock(connection_mutex_);
|
||||
TcpConnection *connection = const_cast<TcpConnection *>(connections_.find(fd)->second);
|
||||
delete connection;
|
||||
connections_.erase(fd);
|
||||
}
|
||||
|
||||
|
@ -166,10 +186,10 @@ void TcpServer::ListenerCallback(struct evconnlistener *, evutil_socket_t fd, st
|
|||
return;
|
||||
}
|
||||
|
||||
TcpConnection *conn = server->onCreateConnection();
|
||||
TcpConnection *conn = server->onCreateConnection(bev, fd);
|
||||
MS_EXCEPTION_IF_NULL(conn);
|
||||
|
||||
conn->InitConnection(fd, bev, server);
|
||||
conn->InitConnection();
|
||||
server->AddConnection(fd, conn);
|
||||
bufferevent_setcb(bev, TcpServer::ReadCallback, nullptr, TcpServer::EventCallback, reinterpret_cast<void *>(conn));
|
||||
if (bufferevent_enable(bev, EV_READ | EV_WRITE) == -1) {
|
||||
|
@ -177,17 +197,18 @@ void TcpServer::ListenerCallback(struct evconnlistener *, evutil_socket_t fd, st
|
|||
}
|
||||
}
|
||||
|
||||
TcpConnection *TcpServer::onCreateConnection() {
|
||||
TcpConnection *TcpServer::onCreateConnection(struct bufferevent *bev, const evutil_socket_t &fd) {
|
||||
TcpConnection *conn = nullptr;
|
||||
if (client_accept_)
|
||||
conn = const_cast<TcpConnection *>(client_accept_(this));
|
||||
else
|
||||
conn = new TcpConnection();
|
||||
if (client_accept_) {
|
||||
conn = const_cast<TcpConnection *>(client_accept_(*this));
|
||||
} else {
|
||||
conn = new TcpConnection(bev, fd, this);
|
||||
}
|
||||
|
||||
return conn;
|
||||
}
|
||||
|
||||
OnServerReceiveMessage TcpServer::GetServerReceiveMessage() const { return message_callback_; }
|
||||
OnServerReceiveMessage TcpServer::GetServerReceive() const { return message_callback_; }
|
||||
|
||||
void TcpServer::SignalCallback(evutil_socket_t, std::int16_t, void *data) {
|
||||
auto server = reinterpret_cast<class TcpServer *>(data);
|
||||
|
@ -207,9 +228,9 @@ void TcpServer::ReadCallback(struct bufferevent *bev, void *connection) {
|
|||
auto conn = static_cast<class TcpConnection *>(connection);
|
||||
struct evbuffer *buf = bufferevent_get_input(bev);
|
||||
char read_buffer[4096];
|
||||
auto read = 0;
|
||||
while ((read = EVBUFFER_LENGTH(buf)) > 0) {
|
||||
if (evbuffer_remove(buf, &read_buffer, sizeof(read_buffer)) == -1) {
|
||||
while (EVBUFFER_LENGTH(buf) > 0) {
|
||||
int read = evbuffer_remove(buf, &read_buffer, sizeof(read_buffer));
|
||||
if (read == -1) {
|
||||
MS_LOG(EXCEPTION) << "Can not drain data from the event buffer!";
|
||||
}
|
||||
conn->OnReadHandler(read_buffer, static_cast<size_t>(read));
|
||||
|
@ -219,43 +240,47 @@ void TcpServer::ReadCallback(struct bufferevent *bev, void *connection) {
|
|||
void TcpServer::EventCallback(struct bufferevent *bev, std::int16_t events, void *data) {
|
||||
MS_EXCEPTION_IF_NULL(bev);
|
||||
MS_EXCEPTION_IF_NULL(data);
|
||||
struct evbuffer *output = bufferevent_get_output(bev);
|
||||
size_t remain = evbuffer_get_length(output);
|
||||
auto conn = reinterpret_cast<TcpConnection *>(data);
|
||||
TcpServer *srv = conn->GetServer();
|
||||
|
||||
if (events & BEV_EVENT_EOF) {
|
||||
MS_LOG(INFO) << "Event buffer end of file!";
|
||||
// Notify about disconnection
|
||||
if (srv->client_disconnection_) srv->client_disconnection_(srv, conn);
|
||||
if (srv->client_disconnection_) {
|
||||
srv->client_disconnection_(*srv, *conn);
|
||||
}
|
||||
// Free connection structures
|
||||
srv->RemoveConnection(conn->GetFd());
|
||||
bufferevent_free(bev);
|
||||
} else if (events & BEV_EVENT_ERROR) {
|
||||
MS_LOG(ERROR) << "Event buffer remain data: " << remain;
|
||||
// Free connection structures
|
||||
srv->RemoveConnection(conn->GetFd());
|
||||
bufferevent_free(bev);
|
||||
|
||||
// Notify about disconnection
|
||||
if (srv->client_disconnection_) srv->client_disconnection_(srv, conn);
|
||||
if (srv->client_disconnection_) {
|
||||
srv->client_disconnection_(*srv, *conn);
|
||||
}
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Unhandled event!";
|
||||
}
|
||||
}
|
||||
|
||||
void TcpServer::ReceiveMessage(const OnServerReceiveMessage &cb) { message_callback_ = cb; }
|
||||
void TcpServer::SendMessage(const TcpConnection &conn, const CommMessage &message) { conn.SendMessage(message); }
|
||||
|
||||
void TcpServer::SendMessage(const TcpConnection &conn, const void *data, size_t num) {
|
||||
MS_EXCEPTION_IF_NULL(data);
|
||||
auto mc = const_cast<TcpConnection &>(conn);
|
||||
mc.SendMessage(data, num);
|
||||
}
|
||||
|
||||
void TcpServer::SendMessage(const void *data, size_t num) {
|
||||
MS_EXCEPTION_IF_NULL(data);
|
||||
void TcpServer::SendMessage(const CommMessage &message) {
|
||||
std::unique_lock<std::recursive_mutex> lock(connection_mutex_);
|
||||
|
||||
for (auto it = connections_.begin(); it != connections_.end(); ++it) {
|
||||
SendMessage(*it->second, data, num);
|
||||
SendMessage(*it->second, message);
|
||||
}
|
||||
}
|
||||
|
||||
void TcpServer::SetMessageCallback(const OnServerReceiveMessage &cb) { message_callback_ = cb; }
|
||||
|
||||
} // namespace comm
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -27,6 +27,8 @@
|
|||
#include <map>
|
||||
#include <mutex>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "utils/log_adapter.h"
|
||||
#include "ps/comm/tcp_message_handler.h"
|
||||
|
@ -38,46 +40,49 @@ namespace comm {
|
|||
class TcpServer;
|
||||
class TcpConnection {
|
||||
public:
|
||||
TcpConnection() : buffer_event_(nullptr), fd_(0), server_(nullptr) {}
|
||||
explicit TcpConnection(struct bufferevent *bev, const evutil_socket_t &fd, const TcpServer *server)
|
||||
: buffer_event_(bev), fd_(0), server_(server) {}
|
||||
virtual ~TcpConnection() = default;
|
||||
|
||||
virtual void InitConnection(const evutil_socket_t &fd, const struct bufferevent *bev, const TcpServer *server);
|
||||
void SendMessage(const void *buffer, size_t num) const;
|
||||
virtual void InitConnection();
|
||||
virtual void SendMessage(const void *buffer, size_t num) const;
|
||||
void SendMessage(const CommMessage &message) const;
|
||||
virtual void OnReadHandler(const void *buffer, size_t numBytes);
|
||||
TcpServer *GetServer() const;
|
||||
evutil_socket_t GetFd() const;
|
||||
const evutil_socket_t &GetFd() const;
|
||||
|
||||
protected:
|
||||
TcpMessageHandler tcp_message_handler_;
|
||||
struct bufferevent *buffer_event_;
|
||||
evutil_socket_t fd_;
|
||||
TcpServer *server_;
|
||||
const TcpServer *server_;
|
||||
TcpMessageHandler tcp_message_handler_;
|
||||
};
|
||||
|
||||
using OnServerReceiveMessage =
|
||||
std::function<void(const TcpServer &tcp_server, const TcpConnection &conn, const void *buffer, size_t num)>;
|
||||
std::function<void(const TcpServer &tcp_server, const TcpConnection &conn, const CommMessage &)>;
|
||||
|
||||
class TcpServer {
|
||||
public:
|
||||
using OnConnected = std::function<void(const TcpServer *, const TcpConnection *)>;
|
||||
using OnDisconnected = std::function<void(const TcpServer *, const TcpConnection *)>;
|
||||
using OnAccepted = std::function<const TcpConnection *(const TcpServer *)>;
|
||||
using OnConnected = std::function<void(const TcpServer &, const TcpConnection &)>;
|
||||
using OnDisconnected = std::function<void(const TcpServer &, const TcpConnection &)>;
|
||||
using OnAccepted = std::function<const TcpConnection *(const TcpServer &)>;
|
||||
|
||||
explicit TcpServer(std::string address, std::uint16_t port);
|
||||
explicit TcpServer(const std::string &address, std::uint16_t port);
|
||||
virtual ~TcpServer();
|
||||
|
||||
void SetServerCallback(const OnConnected &client_conn, const OnDisconnected &client_disconn,
|
||||
const OnAccepted &client_accept);
|
||||
void InitServer();
|
||||
void Init();
|
||||
void Start();
|
||||
void StartWithNoBlock();
|
||||
void Stop();
|
||||
void SendToAllClients(const char *data, size_t len);
|
||||
void AddConnection(const evutil_socket_t &fd, const TcpConnection *connection);
|
||||
void RemoveConnection(const evutil_socket_t &fd);
|
||||
void ReceiveMessage(const OnServerReceiveMessage &cb);
|
||||
static void SendMessage(const TcpConnection &conn, const void *data, size_t num);
|
||||
void SendMessage(const void *data, size_t num);
|
||||
OnServerReceiveMessage GetServerReceiveMessage() const;
|
||||
OnServerReceiveMessage GetServerReceive() const;
|
||||
void SetMessageCallback(const OnServerReceiveMessage &cb);
|
||||
static void SendMessage(const TcpConnection &conn, const CommMessage &message);
|
||||
void SendMessage(const CommMessage &message);
|
||||
|
||||
protected:
|
||||
static void ListenerCallback(struct evconnlistener *listener, evutil_socket_t socket, struct sockaddr *saddr,
|
||||
|
@ -85,9 +90,8 @@ class TcpServer {
|
|||
static void SignalCallback(evutil_socket_t sig, std::int16_t events, void *server);
|
||||
static void ReadCallback(struct bufferevent *, void *connection);
|
||||
static void EventCallback(struct bufferevent *, std::int16_t events, void *server);
|
||||
virtual TcpConnection *onCreateConnection();
|
||||
virtual TcpConnection *onCreateConnection(struct bufferevent *bev, const evutil_socket_t &fd);
|
||||
|
||||
private:
|
||||
struct event_base *base_;
|
||||
struct event *signal_event_;
|
||||
struct evconnlistener *listener_;
|
||||
|
@ -101,6 +105,7 @@ class TcpServer {
|
|||
std::recursive_mutex connection_mutex_;
|
||||
OnServerReceiveMessage message_callback_;
|
||||
};
|
||||
|
||||
} // namespace comm
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -24,6 +24,7 @@
|
|||
#include <iostream>
|
||||
#include <string>
|
||||
#include <thread>
|
||||
#include <memory>
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
|
@ -31,7 +32,9 @@ namespace comm {
|
|||
|
||||
class TestHttpServer : public UT::Common {
|
||||
public:
|
||||
TestHttpServer() = default;
|
||||
TestHttpServer() : server_(nullptr) {}
|
||||
|
||||
virtual ~TestHttpServer() = default;
|
||||
|
||||
static void testGetHandler(std::shared_ptr<HttpMessageHandler> resp) {
|
||||
std::string host = resp->GetRequestHost();
|
||||
|
@ -57,7 +60,7 @@ class TestHttpServer : public UT::Common {
|
|||
}
|
||||
|
||||
void SetUp() override {
|
||||
server_ = new HttpServer("0.0.0.0", 9999);
|
||||
server_ = std::make_unique<HttpServer>("0.0.0.0", 9999);
|
||||
OnRequestReceive http_get_func = std::bind(
|
||||
[](std::shared_ptr<HttpMessageHandler> resp) {
|
||||
EXPECT_STREQ(resp->GetPathParam("key1").c_str(), "value1");
|
||||
|
@ -106,7 +109,7 @@ class TestHttpServer : public UT::Common {
|
|||
}
|
||||
|
||||
private:
|
||||
HttpServer *server_;
|
||||
std::unique_ptr<HttpServer> server_;
|
||||
};
|
||||
|
||||
TEST_F(TestHttpServer, httpGetQequest) {
|
||||
|
@ -143,13 +146,13 @@ TEST_F(TestHttpServer, messageHandler) {
|
|||
}
|
||||
|
||||
TEST_F(TestHttpServer, portErrorNoException) {
|
||||
HttpServer *server_exception = new HttpServer("0.0.0.0", -1);
|
||||
auto server_exception = std::make_unique<HttpServer>("0.0.0.0", -1);
|
||||
OnRequestReceive http_handler_func = std::bind(TestHttpServer::testGetHandler, std::placeholders::_1);
|
||||
EXPECT_NO_THROW(server_exception->RegisterRoute("/handler", &http_handler_func));
|
||||
}
|
||||
|
||||
TEST_F(TestHttpServer, addressException) {
|
||||
HttpServer *server_exception = new HttpServer("12344.0.0.0", 9998);
|
||||
auto server_exception = std::make_unique<HttpServer>("12344.0.0.0", 9998);
|
||||
OnRequestReceive http_handler_func = std::bind(TestHttpServer::testGetHandler, std::placeholders::_1);
|
||||
ASSERT_THROW(server_exception->RegisterRoute("/handler", &http_handler_func), std::exception);
|
||||
}
|
||||
|
|
|
@ -14,6 +14,8 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "common/common_test.h"
|
||||
#include "ps/comm/tcp_client.h"
|
||||
|
||||
|
@ -26,19 +28,19 @@ class TestTcpClient : public UT::Common {
|
|||
};
|
||||
|
||||
TEST_F(TestTcpClient, InitClientIPError) {
|
||||
auto client = new TcpClient("127.0.0.13543", 9000);
|
||||
client->ReceiveMessage(
|
||||
[](const TcpClient &client, const void *buffer, size_t num) { client.SendMessage(buffer, num); });
|
||||
auto client = std::make_unique<TcpClient>("127.0.0.13543", 9000);
|
||||
|
||||
ASSERT_THROW(client->InitTcpClient(), std::exception);
|
||||
client->SetMessageCallback([](const TcpClient &client, const CommMessage &message) { client.SendMessage(message); });
|
||||
|
||||
ASSERT_THROW(client->Init(), std::exception);
|
||||
}
|
||||
|
||||
TEST_F(TestTcpClient, InitClientPortErrorNoException) {
|
||||
auto client = new TcpClient("127.0.0.1", -1);
|
||||
client->ReceiveMessage(
|
||||
[](const TcpClient &client, const void *buffer, size_t num) { client.SendMessage(buffer, num); });
|
||||
auto client = std::make_unique<TcpClient>("127.0.0.1", -1);
|
||||
|
||||
EXPECT_NO_THROW(client->InitTcpClient());
|
||||
client->SetMessageCallback([](const TcpClient &client, const CommMessage &message) { client.SendMessage(message); });
|
||||
|
||||
EXPECT_NO_THROW(client->Init());
|
||||
}
|
||||
|
||||
} // namespace comm
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
#include "ps/comm/tcp_server.h"
|
||||
#include "common/common_test.h"
|
||||
|
||||
#include <memory>
|
||||
#include <thread>
|
||||
|
||||
namespace mindspore {
|
||||
|
@ -25,16 +26,20 @@ namespace ps {
|
|||
namespace comm {
|
||||
class TestTcpServer : public UT::Common {
|
||||
public:
|
||||
TestTcpServer() = default;
|
||||
TestTcpServer() : client_(nullptr), server_(nullptr) {}
|
||||
virtual ~TestTcpServer() = default;
|
||||
|
||||
void SetUp() override {
|
||||
server_ = new TcpServer("127.0.0.1", 9000);
|
||||
server_ = std::make_unique<TcpServer>("127.0.0.1", 9998);
|
||||
std::unique_ptr<std::thread> http_server_thread_(nullptr);
|
||||
http_server_thread_ = std::make_unique<std::thread>([&]() {
|
||||
server_->ReceiveMessage([](const TcpServer &server, const TcpConnection &conn, const void *buffer, size_t num) {
|
||||
EXPECT_STREQ(std::string(reinterpret_cast<const char *>(buffer), num).c_str(), "TCP_MESSAGE");
|
||||
server.SendMessage(conn, buffer, num);
|
||||
server_->SetMessageCallback([](const TcpServer &server, const TcpConnection &conn, const CommMessage &message) {
|
||||
KVMessage kv_message;
|
||||
kv_message.ParseFromString(message.data());
|
||||
EXPECT_EQ(2, kv_message.keys_size());
|
||||
server.SendMessage(conn, message);
|
||||
});
|
||||
server_->InitServer();
|
||||
server_->Init();
|
||||
server_->Start();
|
||||
});
|
||||
http_server_thread_->detach();
|
||||
|
@ -47,21 +52,32 @@ class TestTcpServer : public UT::Common {
|
|||
server_->Stop();
|
||||
}
|
||||
|
||||
TcpClient *client_;
|
||||
TcpServer *server_;
|
||||
const std::string test_message_ = "TCP_MESSAGE";
|
||||
std::unique_ptr<TcpClient> client_;
|
||||
std::unique_ptr<TcpServer> server_;
|
||||
};
|
||||
|
||||
TEST_F(TestTcpServer, ServerSendMessage) {
|
||||
client_ = new TcpClient("127.0.0.1", 9000);
|
||||
client_ = std::make_unique<TcpClient>("127.0.0.1", 9998);
|
||||
std::unique_ptr<std::thread> http_client_thread(nullptr);
|
||||
http_client_thread = std::make_unique<std::thread>([&]() {
|
||||
client_->ReceiveMessage([](const TcpClient &client, const void *buffer, size_t num) {
|
||||
EXPECT_STREQ(std::string(reinterpret_cast<const char *>(buffer), num).c_str(), "TCP_MESSAGE");
|
||||
client_->SetMessageCallback([](const TcpClient &client, const CommMessage &message) {
|
||||
KVMessage kv_message;
|
||||
kv_message.ParseFromString(message.data());
|
||||
EXPECT_EQ(2, kv_message.keys_size());
|
||||
});
|
||||
|
||||
client_->InitTcpClient();
|
||||
client_->SendMessage(test_message_.c_str(), test_message_.size());
|
||||
client_->Init();
|
||||
|
||||
CommMessage comm_message;
|
||||
KVMessage kv_message;
|
||||
std::vector<int> keys{1, 2};
|
||||
std::vector<int> values{3, 4};
|
||||
*kv_message.mutable_keys() = {keys.begin(), keys.end()};
|
||||
*kv_message.mutable_values() = {values.begin(), values.end()};
|
||||
|
||||
comm_message.set_data(kv_message.SerializeAsString());
|
||||
client_->SendMessage(comm_message);
|
||||
|
||||
client_->Start();
|
||||
});
|
||||
http_client_thread->detach();
|
Loading…
Reference in New Issue