diff --git a/mindspore/ccsrc/ps/core/cluster_config.cc b/mindspore/ccsrc/ps/core/cluster_config.cc index 0b8a00c89a9..2fb9052cd9d 100644 --- a/mindspore/ccsrc/ps/core/cluster_config.cc +++ b/mindspore/ccsrc/ps/core/cluster_config.cc @@ -21,12 +21,16 @@ namespace mindspore { namespace ps { namespace core { - uint32_t ClusterConfig::worker_num_ = 0; uint32_t ClusterConfig::server_num_ = 0; -uint32_t ClusterConfig::heartbeat_interval_ = kHeartbeatInterval; std::unique_ptr ClusterConfig::scheduler_host_ = nullptr; uint16_t ClusterConfig::scheduler_port_ = 0; +// The interval for sending heartbeat packets between worker node,server node and scheduler node is 3 seconds. +uint32_t ClusterConfig::heartbeat_interval_ = 3; +// The timeout for worker node and server node sending heartbeat packets to scheduler node is 30 seconds. +uint32_t ClusterConfig::heartbeat_timeout_ = 30; +// Timeout period for cluster preparation is 300 seconds. +uint32_t ClusterConfig::cluster_available_timeout_ = 300; void ClusterConfig::Init(const uint32_t &worker_num, const uint32_t &server_num, std::unique_ptr scheduler_host, const uint16_t &scheduler_port) { @@ -53,6 +57,18 @@ std::string ClusterConfig::scheduler_host() { return *scheduler_host_.get(); } uint16_t ClusterConfig::scheduler_port() { return scheduler_port_; } +uint32_t ClusterConfig::heartbeat_timeout() { return heartbeat_timeout_; } + +void ClusterConfig::set_heartbeat_timeout(const uint32_t &heartbeat_timeout) { + heartbeat_interval_ = heartbeat_timeout; +} + +uint32_t ClusterConfig::cluster_available_timeout() { return cluster_available_timeout_; } + +void ClusterConfig::set_cluster_available_timeout(const uint32_t &cluster_available_timeout) { + cluster_available_timeout_ = cluster_available_timeout; +} + } // namespace core } // namespace ps } // namespace mindspore diff --git a/mindspore/ccsrc/ps/core/cluster_config.h b/mindspore/ccsrc/ps/core/cluster_config.h index ea7bd68b355..7cf379c6844 100644 --- a/mindspore/ccsrc/ps/core/cluster_config.h +++ b/mindspore/ccsrc/ps/core/cluster_config.h @@ -28,8 +28,6 @@ namespace mindspore { namespace ps { namespace core { -constexpr uint32_t kHeartbeatInterval = 3; - class ClusterConfig { public: static void Init(const uint32_t &worker_num, const uint32_t &server_num, std::unique_ptr scheduler_host, @@ -40,6 +38,10 @@ class ClusterConfig { static void set_heartbeat_interval(const uint32_t &heartbeat_interval); static std::string scheduler_host(); static uint16_t scheduler_port(); + static uint32_t heartbeat_timeout(); + static void set_heartbeat_timeout(const uint32_t &heartbeat_timeout); + static uint32_t cluster_available_timeout(); + static void set_cluster_available_timeout(const uint32_t &cluster_available_timeout); private: static uint32_t worker_num_; @@ -47,6 +49,8 @@ class ClusterConfig { static uint32_t heartbeat_interval_; static std::unique_ptr scheduler_host_; static uint16_t scheduler_port_; + static uint32_t heartbeat_timeout_; + static uint32_t cluster_available_timeout_; }; } // namespace core } // namespace ps diff --git a/mindspore/ccsrc/ps/core/comm_util.cc b/mindspore/ccsrc/ps/core/comm_util.cc index 71bc9d0d599..28fb5ed658f 100644 --- a/mindspore/ccsrc/ps/core/comm_util.cc +++ b/mindspore/ccsrc/ps/core/comm_util.cc @@ -21,11 +21,17 @@ #include #include #include +#include #include namespace mindspore { namespace ps { namespace core { +std::random_device CommUtil::rd; +std::mt19937_64 CommUtil::gen(rd()); +std::uniform_int_distribution<> CommUtil::dis = std::uniform_int_distribution<>{0, 15}; +std::uniform_int_distribution<> CommUtil::dis2 = std::uniform_int_distribution<>{8, 11}; + bool CommUtil::CheckIpWithRegex(const std::string &ip) { std::regex pattern("((25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?).){3}(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)"); std::smatch res; @@ -75,6 +81,34 @@ void CommUtil::GetAvailableInterfaceAndIP(std::string *interface, std::string *i MS_EXCEPTION_IF_NULL(if_address); freeifaddrs(if_address); } + +std::string CommUtil::GenerateUUID() { + std::stringstream ss; + int i; + ss << std::hex; + for (i = 0; i < kGroup1RandomLength; i++) { + ss << dis(gen); + } + ss << "-"; + for (i = 0; i < kGroup2RandomLength; i++) { + ss << dis(gen); + } + ss << "-4"; + for (i = 0; i < kGroup2RandomLength - 1; i++) { + ss << dis(gen); + } + ss << "-"; + ss << dis2(gen); + for (i = 0; i < kGroup3RandomLength - 1; i++) { + ss << dis(gen); + } + ss << "-"; + for (i = 0; i < kGroup4RandomLength; i++) { + ss << dis(gen); + } + return ss.str(); +} + } // namespace core } // namespace ps } // namespace mindspore diff --git a/mindspore/ccsrc/ps/core/comm_util.h b/mindspore/ccsrc/ps/core/comm_util.h index 651ba1c3e8f..62b8b76b254 100644 --- a/mindspore/ccsrc/ps/core/comm_util.h +++ b/mindspore/ccsrc/ps/core/comm_util.h @@ -43,17 +43,31 @@ #include #include #include +#include +#include #include "utils/log_adapter.h" namespace mindspore { namespace ps { namespace core { +constexpr int kGroup1RandomLength = 8; +constexpr int kGroup2RandomLength = 4; +constexpr int kGroup3RandomLength = 4; +constexpr int kGroup4RandomLength = 4; +constexpr int kGroup5RandomLength = 12; + class CommUtil { public: static bool CheckIpWithRegex(const std::string &ip); static bool CheckIp(const std::string &ip); static void GetAvailableInterfaceAndIP(std::string *interface, std::string *ip); + static std::string GenerateUUID(); + + static std::random_device rd; + static std::mt19937_64 gen; + static std::uniform_int_distribution<> dis; + static std::uniform_int_distribution<> dis2; }; } // namespace core } // namespace ps diff --git a/mindspore/ccsrc/ps/core/protos/comm.proto b/mindspore/ccsrc/ps/core/protos/comm.proto index 9862ab998b4..2b76a8814d7 100644 --- a/mindspore/ccsrc/ps/core/protos/comm.proto +++ b/mindspore/ccsrc/ps/core/protos/comm.proto @@ -19,36 +19,64 @@ import "google/protobuf/any.proto"; package mindspore.ps.core; option optimize_for = LITE_RUNTIME; -enum ClusterCommand { +enum NodeCommand { TERMINATE = 0; REGISTER = 1; - ACK = 2; - HEARTBEAT = 3; - FETCH_WORKERS = 4; - FETCH_SERVERS = 5; + HEARTBEAT = 2; + SEND_DATA = 3; + FETCH_SERVER = 4; } -enum Role { +enum NodeRole { SERVER = 0; WORKER = 1; SCHEDULER = 2; } message MessageMeta { - // hostname or ip - string hostname = 1; - // the port of this node - int32 port = 2; - // the command of this message,for example: register、heartbeat、data - int32 cmd = 3; - // the timestamp of this message - int32 timestamp = 4; - // data type of message - repeated int32 data_type = 5 [packed = true]; - // message.data_size - int32 data_size = 6; + // the command of this message,for example: register,heartbeat,data + NodeCommand cmd = 1; + // the request id of this message + uint64 request_id = 2; } +message RegisterMessage { + // ip + string ip = 1; + // the port of this node + int32 port = 2; + // the current Node unique id:0,1,2... + string node_id = 3; + // the role of the node: worker,server,scheduler + NodeRole role = 4; +} + +message RegisterRespMessage { + string node_id = 1; + int32 rank_id = 2; +} + +message HeartbeatMessage { + // the current Node unique id:0,1,2... + string node_id = 1; +} + +message HeartbeatRespMessage { + // Is the entire system ready to use. + bool is_cluster_ready = 1; + bool is_cluster_finish = 2; +} + +message FetchServersRespMessage { + repeated ServersMeta servers_meta = 1; +} + +message ServersMeta { + int32 rank_id = 1; + string ip = 2; + int32 port = 3; + +} message CommMessage { MessageMeta pb_meta = 1; diff --git a/mindspore/ccsrc/ps/core/protos/ps.proto b/mindspore/ccsrc/ps/core/protos/ps.proto index cd5835ed14e..1516af4b087 100644 --- a/mindspore/ccsrc/ps/core/protos/ps.proto +++ b/mindspore/ccsrc/ps/core/protos/ps.proto @@ -13,17 +13,17 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - syntax = "proto3"; package mindspore.ps.core; option optimize_for = LITE_RUNTIME; -message KVMessage { - repeated int32 keys = 1; - repeated float values = 2; +enum PSCommand { + PUSH = 0; + PULL = 1; } -message HeartBeatMessage { - // *.*.*.*:port - repeated string host_and_port = 1; +message KVMessage { + PSCommand command = 1; + repeated int32 keys = 2; + repeated float values = 3; } \ No newline at end of file diff --git a/mindspore/ccsrc/ps/core/tcp_client.cc b/mindspore/ccsrc/ps/core/tcp_client.cc index ece0f3c1d23..fcc0ec44b9d 100644 --- a/mindspore/ccsrc/ps/core/tcp_client.cc +++ b/mindspore/ccsrc/ps/core/tcp_client.cc @@ -18,8 +18,8 @@ #include #include -#include #include +#include #include #include #include @@ -27,20 +27,23 @@ #include #include #include -#include #include +#include #include "ps/core/comm_util.h" namespace mindspore { namespace ps { namespace core { + +event_base *TcpClient::event_base_ = nullptr; + TcpClient::TcpClient(const std::string &address, std::uint16_t port) - : event_base_(nullptr), - event_timeout_(nullptr), + : event_timeout_(nullptr), buffer_event_(nullptr), server_address_(std::move(address)), - server_port_(port) { + server_port_(port), + is_stop_(true) { message_handler_.SetCallback([this](const CommMessage &message) { if (message_callback_) { message_callback_(*this, message); @@ -61,6 +64,7 @@ void TcpClient::SetCallback(const OnConnected &conn, const OnDisconnected &disco } void TcpClient::Init() { + std::lock_guard lock(connection_mutex_); if (buffer_event_) { return; } @@ -68,7 +72,13 @@ void TcpClient::Init() { MS_LOG(EXCEPTION) << "The tcp client ip:" << server_address_ << " is illegal!"; } - event_base_ = event_base_new(); + int result = evthread_use_pthreads(); + if (result != 0) { + MS_LOG(EXCEPTION) << "Use event pthread failed!"; + } + if (event_base_ == nullptr) { + event_base_ = event_base_new(); + } MS_EXCEPTION_IF_NULL(event_base_); sockaddr_in sin{}; @@ -94,6 +104,7 @@ void TcpClient::Init() { } void TcpClient::StartWithDelay(int seconds) { + std::lock_guard lock(connection_mutex_); if (buffer_event_) { return; } @@ -111,16 +122,28 @@ void TcpClient::StartWithDelay(int seconds) { } void TcpClient::Stop() { - if (buffer_event_) { - bufferevent_free(buffer_event_); - buffer_event_ = nullptr; - } + std::lock_guard lock(connection_mutex_); + MS_LOG(INFO) << "Stop tcp client event buffer!"; + if (!is_stop_.load()) { + if (buffer_event_) { + bufferevent_free(buffer_event_); + buffer_event_ = nullptr; + } - if (event_timeout_) { - event_free(event_timeout_); - event_timeout_ = nullptr; + if (event_timeout_) { + event_free(event_timeout_); + event_timeout_ = nullptr; + } + is_stop_ = true; } +} +void TcpClient::StopEventBase() { + MS_LOG(INFO) << "Stop tcp client event base!"; + int ret = event_base_loopbreak(event_base_); + if (ret != 0) { + MS_LOG(EXCEPTION) << "Event base loop break failed!"; + } if (event_base_) { event_base_free(event_base_); event_base_ = nullptr; @@ -167,21 +190,12 @@ void TcpClient::OnReadHandler(const void *buf, size_t num) { message_handler_.ReceiveMessage(buf, num); } -void TcpClient::SendHeartBeatCallback(evutil_socket_t, int16_t, void *arg) { +void TcpClient::TimerCallback(evutil_socket_t, int16_t, void *arg) { MS_EXCEPTION_IF_NULL(arg); auto tcp_client = reinterpret_cast(arg); - MessageMeta meta; - meta.set_cmd(ClusterCommand::HEARTBEAT); - CommMessage message; - message.set_allocated_pb_meta(&meta); - tcp_client->SendMessage(message); - - struct event *ev; - struct timeval timeout {}; - timeout.tv_sec = ClusterConfig::heartbeat_interval(); - timeout.tv_usec = 0; - ev = evtimer_new(tcp_client->event_base_, SendHeartBeatCallback, arg); - evtimer_add(ev, &timeout); + if (tcp_client->on_timer_callback_) { + tcp_client->on_timer_callback_(*tcp_client); + } } void TcpClient::EventCallback(struct bufferevent *bev, std::int16_t events, void *ptr) { @@ -211,6 +225,7 @@ void TcpClient::EventCallback(struct bufferevent *bev, std::int16_t events, void void TcpClient::Start() { MS_EXCEPTION_IF_NULL(event_base_); + is_stop_ = false; int ret = event_base_dispatch(event_base_); MSLOG_IF(INFO, ret == 0, NoExceptionType) << "Event base dispatch success!"; MSLOG_IF(mindspore::ERROR, ret == 1, NoExceptionType) @@ -220,6 +235,7 @@ void TcpClient::Start() { } void TcpClient::StartWithNoBlock() { + std::lock_guard lock(connection_mutex_); MS_LOG(INFO) << "Start tcp client with no block!"; MS_EXCEPTION_IF_NULL(event_base_); int ret = event_base_loop(event_base_, EVLOOP_NONBLOCK); @@ -244,15 +260,24 @@ void TcpClient::SendMessage(const CommMessage &message) const { } } -void TcpClient::SendMessageWithTimer() { - MS_EXCEPTION_IF_NULL(buffer_event_); +void TcpClient::StartTimer(const uint32_t &time) { + MS_EXCEPTION_IF_NULL(event_base_); struct event *ev = nullptr; + if (time == 0) { + MS_LOG(EXCEPTION) << "The time should not be 0!"; + } struct timeval timeout {}; - timeout.tv_sec = 0; + timeout.tv_sec = time; timeout.tv_usec = 0; - ev = evtimer_new(event_base_, SendHeartBeatCallback, this); + ev = event_new(event_base_, -1, EV_PERSIST, TimerCallback, this); + MS_EXCEPTION_IF_NULL(ev); evtimer_add(ev, &timeout); } + +void TcpClient::set_timer_callback(const OnTimer &timer) { on_timer_callback_ = timer; } + +const event_base &TcpClient::eventbase() { return *event_base_; } + } // namespace core } // namespace ps } // namespace mindspore diff --git a/mindspore/ccsrc/ps/core/tcp_client.h b/mindspore/ccsrc/ps/core/tcp_client.h index 734e9cdbdb6..10c84460a9c 100644 --- a/mindspore/ccsrc/ps/core/tcp_client.h +++ b/mindspore/ccsrc/ps/core/tcp_client.h @@ -21,10 +21,15 @@ #include #include +#include + #include #include #include #include +#include +#include +#include #include "proto/comm.pb.h" #include "ps/core/cluster_config.h" @@ -40,6 +45,7 @@ class TcpClient { using OnRead = std::function; using OnTimeout = std::function; using OnMessage = std::function; + using OnTimer = std::function; explicit TcpClient(const std::string &address, std::uint16_t port); virtual ~TcpClient(); @@ -50,11 +56,14 @@ class TcpClient { void Init(); void StartWithDelay(int seconds); void Stop(); + static void StopEventBase(); void Start(); void StartWithNoBlock(); void SetMessageCallback(const OnMessage &cb); void SendMessage(const CommMessage &message) const; - void SendMessageWithTimer(); + void StartTimer(const uint32_t &time); + void set_timer_callback(const OnTimer &timer); + const event_base &eventbase(); protected: static void SetTcpNoDelay(const evutil_socket_t &fd); @@ -62,7 +71,7 @@ class TcpClient { static void ReadCallback(struct bufferevent *bev, void *ctx); static void EventCallback(struct bufferevent *bev, std::int16_t events, void *ptr); virtual void OnReadHandler(const void *buf, size_t num); - static void SendHeartBeatCallback(evutil_socket_t fd, int16_t event, void *arg); + static void TimerCallback(evutil_socket_t fd, int16_t event, void *arg); private: OnMessage message_callback_; @@ -72,13 +81,16 @@ class TcpClient { OnDisconnected disconnected_callback_; OnRead read_callback_; OnTimeout timeout_callback_; + OnTimer on_timer_callback_; - event_base *event_base_; + static event_base *event_base_; + std::mutex connection_mutex_; event *event_timeout_; bufferevent *buffer_event_; std::string server_address_; std::uint16_t server_port_; + std::atomic is_stop_; }; } // namespace core diff --git a/mindspore/ccsrc/ps/core/tcp_server.cc b/mindspore/ccsrc/ps/core/tcp_server.cc index 004142b6bf5..1dcd0048faa 100644 --- a/mindspore/ccsrc/ps/core/tcp_server.cc +++ b/mindspore/ccsrc/ps/core/tcp_server.cc @@ -18,10 +18,10 @@ #include #include +#include #include #include #include -#include #include #include #include @@ -73,7 +73,8 @@ TcpServer::TcpServer(const std::string &address, std::uint16_t port) signal_event_(nullptr), listener_(nullptr), server_address_(std::move(address)), - server_port_(port) {} + server_port_(port), + is_stop_(true) {} TcpServer::~TcpServer() { Stop(); } @@ -84,7 +85,14 @@ void TcpServer::SetServerCallback(const OnConnected &client_conn, const OnDiscon this->client_accept_ = client_accept; } +void TcpServer::set_timer_callback(const OnTimer &timer) { on_timer_callback_ = timer; } + void TcpServer::Init() { + int result = evthread_use_pthreads(); + if (result != 0) { + MS_LOG(EXCEPTION) << "Use event pthread failed!"; + } + base_ = event_base_new(); MS_EXCEPTION_IF_NULL(base_); if (!CommUtil::CheckIp(server_address_)) { @@ -128,6 +136,7 @@ void TcpServer::Start() { std::unique_lock lock(connection_mutex_); MS_LOG(INFO) << "Start tcp server!"; MS_EXCEPTION_IF_NULL(base_); + is_stop_ = false; int ret = event_base_dispatch(base_); MSLOG_IF(INFO, ret == 0, NoExceptionType) << "Event base dispatch success!"; MSLOG_IF(mindspore::ERROR, ret == 1, NoExceptionType) @@ -147,21 +156,42 @@ void TcpServer::StartWithNoBlock() { MSLOG_IF(mindspore::EXCEPTION, ret < -1, AbortedError) << "Event base loop with unexpect error code!"; } +void TcpServer::StartTimerOnlyOnce(const uint32_t &time) { + MS_EXCEPTION_IF_NULL(base_); + if (time == 0) { + MS_LOG(EXCEPTION) << "The time should not be 0!"; + } + struct event *ev = nullptr; + struct timeval timeout {}; + timeout.tv_sec = time; + timeout.tv_usec = 0; + ev = evtimer_new(base_, TimerCallback, this); + MS_EXCEPTION_IF_NULL(ev); + evtimer_add(ev, &timeout); +} + void TcpServer::Stop() { MS_LOG(INFO) << "Stop tcp server!"; - if (signal_event_ != nullptr) { - event_free(signal_event_); - signal_event_ = nullptr; - } + if (!is_stop_.load()) { + int ret = event_base_loopbreak(base_); + if (ret != 0) { + MS_LOG(EXCEPTION) << "event base loop break failed!"; + } + if (signal_event_ != nullptr) { + event_free(signal_event_); + signal_event_ = nullptr; + } - if (listener_ != nullptr) { - evconnlistener_free(listener_); - listener_ = nullptr; - } + if (listener_ != nullptr) { + evconnlistener_free(listener_); + listener_ = nullptr; + } - if (base_ != nullptr) { - event_base_free(base_); - base_ = nullptr; + if (base_ != nullptr) { + event_base_free(base_); + base_ = nullptr; + } + is_stop_ = true; } } @@ -287,6 +317,14 @@ void TcpServer::EventCallback(struct bufferevent *bev, std::int16_t events, void } } +void TcpServer::TimerCallback(evutil_socket_t, int16_t, void *arg) { + MS_EXCEPTION_IF_NULL(arg); + auto tcp_server = reinterpret_cast(arg); + if (tcp_server->on_timer_callback_) { + tcp_server->on_timer_callback_(*tcp_server); + } +} + void TcpServer::SendMessage(const TcpConnection &conn, const CommMessage &message) { conn.SendMessage(message); } void TcpServer::SendMessage(const CommMessage &message) { @@ -299,6 +337,10 @@ void TcpServer::SendMessage(const CommMessage &message) { uint16_t TcpServer::BoundPort() const { return server_port_; } +int TcpServer::ConnectionNum() const { return connections_.size(); } + +const std::map &TcpServer::Connections() const { return connections_; } + void TcpServer::SetMessageCallback(const OnServerReceiveMessage &cb) { message_callback_ = cb; } } // namespace core } // namespace ps diff --git a/mindspore/ccsrc/ps/core/tcp_server.h b/mindspore/ccsrc/ps/core/tcp_server.h index e74af03796d..ed986ac2d4f 100644 --- a/mindspore/ccsrc/ps/core/tcp_server.h +++ b/mindspore/ccsrc/ps/core/tcp_server.h @@ -21,17 +21,23 @@ #include #include #include +#include + #include #include #include #include +#include #include #include -#include #include +#include +#include -#include "utils/log_adapter.h" +#include "proto/comm.pb.h" #include "ps/core/tcp_message_handler.h" +#include "ps/core/cluster_config.h" +#include "utils/log_adapter.h" namespace mindspore { namespace ps { @@ -40,7 +46,7 @@ class TcpServer; class TcpConnection { public: explicit TcpConnection(struct bufferevent *bev, const evutil_socket_t &fd, const TcpServer *server) - : buffer_event_(bev), fd_(0), server_(server) {} + : buffer_event_(bev), fd_(fd), server_(server) {} virtual ~TcpConnection() = default; virtual void InitConnection(); @@ -65,24 +71,29 @@ class TcpServer { using OnConnected = std::function; using OnDisconnected = std::function; using OnAccepted = std::function; + using OnTimer = std::function; explicit TcpServer(const std::string &address, std::uint16_t port); virtual ~TcpServer(); void SetServerCallback(const OnConnected &client_conn, const OnDisconnected &client_disconn, const OnAccepted &client_accept); + void set_timer_callback(const OnTimer &timer); void Init(); void Start(); void StartWithNoBlock(); + void StartTimerOnlyOnce(const uint32_t &time); void Stop(); void SendToAllClients(const char *data, size_t len); void AddConnection(const evutil_socket_t &fd, const TcpConnection *connection); void RemoveConnection(const evutil_socket_t &fd); OnServerReceiveMessage GetServerReceive() const; void SetMessageCallback(const OnServerReceiveMessage &cb); - static void SendMessage(const TcpConnection &conn, const CommMessage &message); + void SendMessage(const TcpConnection &conn, const CommMessage &message); void SendMessage(const CommMessage &message); uint16_t BoundPort() const; + int ConnectionNum() const; + const std::map &Connections() const; protected: static void ListenerCallback(struct evconnlistener *listener, evutil_socket_t socket, struct sockaddr *saddr, @@ -90,6 +101,7 @@ class TcpServer { static void SignalCallback(evutil_socket_t sig, std::int16_t events, void *server); static void ReadCallback(struct bufferevent *, void *connection); static void EventCallback(struct bufferevent *, std::int16_t events, void *server); + static void TimerCallback(evutil_socket_t fd, int16_t event, void *arg); virtual TcpConnection *onCreateConnection(struct bufferevent *bev, const evutil_socket_t &fd); struct event_base *base_; @@ -97,6 +109,7 @@ class TcpServer { struct evconnlistener *listener_; std::string server_address_; std::uint16_t server_port_; + std::atomic is_stop_; std::map connections_; OnConnected client_connection_; @@ -104,6 +117,7 @@ class TcpServer { OnAccepted client_accept_; std::recursive_mutex connection_mutex_; OnServerReceiveMessage message_callback_; + OnTimer on_timer_callback_; }; } // namespace core } // namespace ps diff --git a/tests/ut/cpp/ps/core/tcp_pb_server_test.cc b/tests/ut/cpp/ps/core/tcp_pb_server_test.cc index 360c722abe8..8752cfe0d39 100644 --- a/tests/ut/cpp/ps/core/tcp_pb_server_test.cc +++ b/tests/ut/cpp/ps/core/tcp_pb_server_test.cc @@ -37,7 +37,7 @@ class TestTcpServer : public UT::Common { KVMessage kv_message; kv_message.ParseFromString(message.data()); EXPECT_EQ(2, kv_message.keys_size()); - server.SendMessage(conn, message); + const_cast(server).SendMessage(conn, message); }); server_->Init(); server_->Start();