forked from mindspore-Ecosystem/mindspore
updated cluster config and proto
This commit is contained in:
parent
715ca637e1
commit
fc380b8071
|
@ -21,12 +21,16 @@
|
|||
namespace mindspore {
|
||||
namespace ps {
|
||||
namespace core {
|
||||
|
||||
uint32_t ClusterConfig::worker_num_ = 0;
|
||||
uint32_t ClusterConfig::server_num_ = 0;
|
||||
uint32_t ClusterConfig::heartbeat_interval_ = kHeartbeatInterval;
|
||||
std::unique_ptr<std::string> ClusterConfig::scheduler_host_ = nullptr;
|
||||
uint16_t ClusterConfig::scheduler_port_ = 0;
|
||||
// The interval for sending heartbeat packets between worker node,server node and scheduler node is 3 seconds.
|
||||
uint32_t ClusterConfig::heartbeat_interval_ = 3;
|
||||
// The timeout for worker node and server node sending heartbeat packets to scheduler node is 30 seconds.
|
||||
uint32_t ClusterConfig::heartbeat_timeout_ = 30;
|
||||
// Timeout period for cluster preparation is 300 seconds.
|
||||
uint32_t ClusterConfig::cluster_available_timeout_ = 300;
|
||||
|
||||
void ClusterConfig::Init(const uint32_t &worker_num, const uint32_t &server_num,
|
||||
std::unique_ptr<std::string> scheduler_host, const uint16_t &scheduler_port) {
|
||||
|
@ -53,6 +57,18 @@ std::string ClusterConfig::scheduler_host() { return *scheduler_host_.get(); }
|
|||
|
||||
uint16_t ClusterConfig::scheduler_port() { return scheduler_port_; }
|
||||
|
||||
uint32_t ClusterConfig::heartbeat_timeout() { return heartbeat_timeout_; }
|
||||
|
||||
void ClusterConfig::set_heartbeat_timeout(const uint32_t &heartbeat_timeout) {
|
||||
heartbeat_interval_ = heartbeat_timeout;
|
||||
}
|
||||
|
||||
uint32_t ClusterConfig::cluster_available_timeout() { return cluster_available_timeout_; }
|
||||
|
||||
void ClusterConfig::set_cluster_available_timeout(const uint32_t &cluster_available_timeout) {
|
||||
cluster_available_timeout_ = cluster_available_timeout;
|
||||
}
|
||||
|
||||
} // namespace core
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -28,8 +28,6 @@
|
|||
namespace mindspore {
|
||||
namespace ps {
|
||||
namespace core {
|
||||
constexpr uint32_t kHeartbeatInterval = 3;
|
||||
|
||||
class ClusterConfig {
|
||||
public:
|
||||
static void Init(const uint32_t &worker_num, const uint32_t &server_num, std::unique_ptr<std::string> scheduler_host,
|
||||
|
@ -40,6 +38,10 @@ class ClusterConfig {
|
|||
static void set_heartbeat_interval(const uint32_t &heartbeat_interval);
|
||||
static std::string scheduler_host();
|
||||
static uint16_t scheduler_port();
|
||||
static uint32_t heartbeat_timeout();
|
||||
static void set_heartbeat_timeout(const uint32_t &heartbeat_timeout);
|
||||
static uint32_t cluster_available_timeout();
|
||||
static void set_cluster_available_timeout(const uint32_t &cluster_available_timeout);
|
||||
|
||||
private:
|
||||
static uint32_t worker_num_;
|
||||
|
@ -47,6 +49,8 @@ class ClusterConfig {
|
|||
static uint32_t heartbeat_interval_;
|
||||
static std::unique_ptr<std::string> scheduler_host_;
|
||||
static uint16_t scheduler_port_;
|
||||
static uint32_t heartbeat_timeout_;
|
||||
static uint32_t cluster_available_timeout_;
|
||||
};
|
||||
} // namespace core
|
||||
} // namespace ps
|
||||
|
|
|
@ -21,11 +21,17 @@
|
|||
#include <cstdlib>
|
||||
#include <cstring>
|
||||
#include <functional>
|
||||
#include <algorithm>
|
||||
#include <regex>
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
namespace core {
|
||||
std::random_device CommUtil::rd;
|
||||
std::mt19937_64 CommUtil::gen(rd());
|
||||
std::uniform_int_distribution<> CommUtil::dis = std::uniform_int_distribution<>{0, 15};
|
||||
std::uniform_int_distribution<> CommUtil::dis2 = std::uniform_int_distribution<>{8, 11};
|
||||
|
||||
bool CommUtil::CheckIpWithRegex(const std::string &ip) {
|
||||
std::regex pattern("((25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?).){3}(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)");
|
||||
std::smatch res;
|
||||
|
@ -75,6 +81,34 @@ void CommUtil::GetAvailableInterfaceAndIP(std::string *interface, std::string *i
|
|||
MS_EXCEPTION_IF_NULL(if_address);
|
||||
freeifaddrs(if_address);
|
||||
}
|
||||
|
||||
std::string CommUtil::GenerateUUID() {
|
||||
std::stringstream ss;
|
||||
int i;
|
||||
ss << std::hex;
|
||||
for (i = 0; i < kGroup1RandomLength; i++) {
|
||||
ss << dis(gen);
|
||||
}
|
||||
ss << "-";
|
||||
for (i = 0; i < kGroup2RandomLength; i++) {
|
||||
ss << dis(gen);
|
||||
}
|
||||
ss << "-4";
|
||||
for (i = 0; i < kGroup2RandomLength - 1; i++) {
|
||||
ss << dis(gen);
|
||||
}
|
||||
ss << "-";
|
||||
ss << dis2(gen);
|
||||
for (i = 0; i < kGroup3RandomLength - 1; i++) {
|
||||
ss << dis(gen);
|
||||
}
|
||||
ss << "-";
|
||||
for (i = 0; i < kGroup4RandomLength; i++) {
|
||||
ss << dis(gen);
|
||||
}
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
} // namespace core
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -43,17 +43,31 @@
|
|||
#include <functional>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <random>
|
||||
#include <sstream>
|
||||
|
||||
#include "utils/log_adapter.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
namespace core {
|
||||
constexpr int kGroup1RandomLength = 8;
|
||||
constexpr int kGroup2RandomLength = 4;
|
||||
constexpr int kGroup3RandomLength = 4;
|
||||
constexpr int kGroup4RandomLength = 4;
|
||||
constexpr int kGroup5RandomLength = 12;
|
||||
|
||||
class CommUtil {
|
||||
public:
|
||||
static bool CheckIpWithRegex(const std::string &ip);
|
||||
static bool CheckIp(const std::string &ip);
|
||||
static void GetAvailableInterfaceAndIP(std::string *interface, std::string *ip);
|
||||
static std::string GenerateUUID();
|
||||
|
||||
static std::random_device rd;
|
||||
static std::mt19937_64 gen;
|
||||
static std::uniform_int_distribution<> dis;
|
||||
static std::uniform_int_distribution<> dis2;
|
||||
};
|
||||
} // namespace core
|
||||
} // namespace ps
|
||||
|
|
|
@ -19,36 +19,64 @@ import "google/protobuf/any.proto";
|
|||
package mindspore.ps.core;
|
||||
option optimize_for = LITE_RUNTIME;
|
||||
|
||||
enum ClusterCommand {
|
||||
enum NodeCommand {
|
||||
TERMINATE = 0;
|
||||
REGISTER = 1;
|
||||
ACK = 2;
|
||||
HEARTBEAT = 3;
|
||||
FETCH_WORKERS = 4;
|
||||
FETCH_SERVERS = 5;
|
||||
HEARTBEAT = 2;
|
||||
SEND_DATA = 3;
|
||||
FETCH_SERVER = 4;
|
||||
}
|
||||
|
||||
enum Role {
|
||||
enum NodeRole {
|
||||
SERVER = 0;
|
||||
WORKER = 1;
|
||||
SCHEDULER = 2;
|
||||
}
|
||||
|
||||
message MessageMeta {
|
||||
// hostname or ip
|
||||
string hostname = 1;
|
||||
// the port of this node
|
||||
int32 port = 2;
|
||||
// the command of this message,for example: register、heartbeat、data
|
||||
int32 cmd = 3;
|
||||
// the timestamp of this message
|
||||
int32 timestamp = 4;
|
||||
// data type of message
|
||||
repeated int32 data_type = 5 [packed = true];
|
||||
// message.data_size
|
||||
int32 data_size = 6;
|
||||
// the command of this message,for example: register,heartbeat,data
|
||||
NodeCommand cmd = 1;
|
||||
// the request id of this message
|
||||
uint64 request_id = 2;
|
||||
}
|
||||
|
||||
message RegisterMessage {
|
||||
// ip
|
||||
string ip = 1;
|
||||
// the port of this node
|
||||
int32 port = 2;
|
||||
// the current Node unique id:0,1,2...
|
||||
string node_id = 3;
|
||||
// the role of the node: worker,server,scheduler
|
||||
NodeRole role = 4;
|
||||
}
|
||||
|
||||
message RegisterRespMessage {
|
||||
string node_id = 1;
|
||||
int32 rank_id = 2;
|
||||
}
|
||||
|
||||
message HeartbeatMessage {
|
||||
// the current Node unique id:0,1,2...
|
||||
string node_id = 1;
|
||||
}
|
||||
|
||||
message HeartbeatRespMessage {
|
||||
// Is the entire system ready to use.
|
||||
bool is_cluster_ready = 1;
|
||||
bool is_cluster_finish = 2;
|
||||
}
|
||||
|
||||
message FetchServersRespMessage {
|
||||
repeated ServersMeta servers_meta = 1;
|
||||
}
|
||||
|
||||
message ServersMeta {
|
||||
int32 rank_id = 1;
|
||||
string ip = 2;
|
||||
int32 port = 3;
|
||||
|
||||
}
|
||||
|
||||
message CommMessage {
|
||||
MessageMeta pb_meta = 1;
|
||||
|
|
|
@ -13,17 +13,17 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
syntax = "proto3";
|
||||
package mindspore.ps.core;
|
||||
option optimize_for = LITE_RUNTIME;
|
||||
|
||||
message KVMessage {
|
||||
repeated int32 keys = 1;
|
||||
repeated float values = 2;
|
||||
enum PSCommand {
|
||||
PUSH = 0;
|
||||
PULL = 1;
|
||||
}
|
||||
|
||||
message HeartBeatMessage {
|
||||
// *.*.*.*:port
|
||||
repeated string host_and_port = 1;
|
||||
message KVMessage {
|
||||
PSCommand command = 1;
|
||||
repeated int32 keys = 2;
|
||||
repeated float values = 3;
|
||||
}
|
|
@ -18,8 +18,8 @@
|
|||
|
||||
#include <arpa/inet.h>
|
||||
#include <event2/buffer.h>
|
||||
#include <event2/bufferevent.h>
|
||||
#include <event2/buffer_compat.h>
|
||||
#include <event2/bufferevent.h>
|
||||
#include <event2/event.h>
|
||||
#include <netinet/in.h>
|
||||
#include <netinet/tcp.h>
|
||||
|
@ -27,20 +27,23 @@
|
|||
#include <cstdlib>
|
||||
#include <cstring>
|
||||
#include <iostream>
|
||||
#include <utility>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
|
||||
#include "ps/core/comm_util.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
namespace core {
|
||||
|
||||
event_base *TcpClient::event_base_ = nullptr;
|
||||
|
||||
TcpClient::TcpClient(const std::string &address, std::uint16_t port)
|
||||
: event_base_(nullptr),
|
||||
event_timeout_(nullptr),
|
||||
: event_timeout_(nullptr),
|
||||
buffer_event_(nullptr),
|
||||
server_address_(std::move(address)),
|
||||
server_port_(port) {
|
||||
server_port_(port),
|
||||
is_stop_(true) {
|
||||
message_handler_.SetCallback([this](const CommMessage &message) {
|
||||
if (message_callback_) {
|
||||
message_callback_(*this, message);
|
||||
|
@ -61,6 +64,7 @@ void TcpClient::SetCallback(const OnConnected &conn, const OnDisconnected &disco
|
|||
}
|
||||
|
||||
void TcpClient::Init() {
|
||||
std::lock_guard<std::mutex> lock(connection_mutex_);
|
||||
if (buffer_event_) {
|
||||
return;
|
||||
}
|
||||
|
@ -68,7 +72,13 @@ void TcpClient::Init() {
|
|||
MS_LOG(EXCEPTION) << "The tcp client ip:" << server_address_ << " is illegal!";
|
||||
}
|
||||
|
||||
event_base_ = event_base_new();
|
||||
int result = evthread_use_pthreads();
|
||||
if (result != 0) {
|
||||
MS_LOG(EXCEPTION) << "Use event pthread failed!";
|
||||
}
|
||||
if (event_base_ == nullptr) {
|
||||
event_base_ = event_base_new();
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(event_base_);
|
||||
|
||||
sockaddr_in sin{};
|
||||
|
@ -94,6 +104,7 @@ void TcpClient::Init() {
|
|||
}
|
||||
|
||||
void TcpClient::StartWithDelay(int seconds) {
|
||||
std::lock_guard<std::mutex> lock(connection_mutex_);
|
||||
if (buffer_event_) {
|
||||
return;
|
||||
}
|
||||
|
@ -111,16 +122,28 @@ void TcpClient::StartWithDelay(int seconds) {
|
|||
}
|
||||
|
||||
void TcpClient::Stop() {
|
||||
if (buffer_event_) {
|
||||
bufferevent_free(buffer_event_);
|
||||
buffer_event_ = nullptr;
|
||||
}
|
||||
std::lock_guard<std::mutex> lock(connection_mutex_);
|
||||
MS_LOG(INFO) << "Stop tcp client event buffer!";
|
||||
if (!is_stop_.load()) {
|
||||
if (buffer_event_) {
|
||||
bufferevent_free(buffer_event_);
|
||||
buffer_event_ = nullptr;
|
||||
}
|
||||
|
||||
if (event_timeout_) {
|
||||
event_free(event_timeout_);
|
||||
event_timeout_ = nullptr;
|
||||
if (event_timeout_) {
|
||||
event_free(event_timeout_);
|
||||
event_timeout_ = nullptr;
|
||||
}
|
||||
is_stop_ = true;
|
||||
}
|
||||
}
|
||||
|
||||
void TcpClient::StopEventBase() {
|
||||
MS_LOG(INFO) << "Stop tcp client event base!";
|
||||
int ret = event_base_loopbreak(event_base_);
|
||||
if (ret != 0) {
|
||||
MS_LOG(EXCEPTION) << "Event base loop break failed!";
|
||||
}
|
||||
if (event_base_) {
|
||||
event_base_free(event_base_);
|
||||
event_base_ = nullptr;
|
||||
|
@ -167,21 +190,12 @@ void TcpClient::OnReadHandler(const void *buf, size_t num) {
|
|||
message_handler_.ReceiveMessage(buf, num);
|
||||
}
|
||||
|
||||
void TcpClient::SendHeartBeatCallback(evutil_socket_t, int16_t, void *arg) {
|
||||
void TcpClient::TimerCallback(evutil_socket_t, int16_t, void *arg) {
|
||||
MS_EXCEPTION_IF_NULL(arg);
|
||||
auto tcp_client = reinterpret_cast<TcpClient *>(arg);
|
||||
MessageMeta meta;
|
||||
meta.set_cmd(ClusterCommand::HEARTBEAT);
|
||||
CommMessage message;
|
||||
message.set_allocated_pb_meta(&meta);
|
||||
tcp_client->SendMessage(message);
|
||||
|
||||
struct event *ev;
|
||||
struct timeval timeout {};
|
||||
timeout.tv_sec = ClusterConfig::heartbeat_interval();
|
||||
timeout.tv_usec = 0;
|
||||
ev = evtimer_new(tcp_client->event_base_, SendHeartBeatCallback, arg);
|
||||
evtimer_add(ev, &timeout);
|
||||
if (tcp_client->on_timer_callback_) {
|
||||
tcp_client->on_timer_callback_(*tcp_client);
|
||||
}
|
||||
}
|
||||
|
||||
void TcpClient::EventCallback(struct bufferevent *bev, std::int16_t events, void *ptr) {
|
||||
|
@ -211,6 +225,7 @@ void TcpClient::EventCallback(struct bufferevent *bev, std::int16_t events, void
|
|||
|
||||
void TcpClient::Start() {
|
||||
MS_EXCEPTION_IF_NULL(event_base_);
|
||||
is_stop_ = false;
|
||||
int ret = event_base_dispatch(event_base_);
|
||||
MSLOG_IF(INFO, ret == 0, NoExceptionType) << "Event base dispatch success!";
|
||||
MSLOG_IF(mindspore::ERROR, ret == 1, NoExceptionType)
|
||||
|
@ -220,6 +235,7 @@ void TcpClient::Start() {
|
|||
}
|
||||
|
||||
void TcpClient::StartWithNoBlock() {
|
||||
std::lock_guard<std::mutex> lock(connection_mutex_);
|
||||
MS_LOG(INFO) << "Start tcp client with no block!";
|
||||
MS_EXCEPTION_IF_NULL(event_base_);
|
||||
int ret = event_base_loop(event_base_, EVLOOP_NONBLOCK);
|
||||
|
@ -244,15 +260,24 @@ void TcpClient::SendMessage(const CommMessage &message) const {
|
|||
}
|
||||
}
|
||||
|
||||
void TcpClient::SendMessageWithTimer() {
|
||||
MS_EXCEPTION_IF_NULL(buffer_event_);
|
||||
void TcpClient::StartTimer(const uint32_t &time) {
|
||||
MS_EXCEPTION_IF_NULL(event_base_);
|
||||
struct event *ev = nullptr;
|
||||
if (time == 0) {
|
||||
MS_LOG(EXCEPTION) << "The time should not be 0!";
|
||||
}
|
||||
struct timeval timeout {};
|
||||
timeout.tv_sec = 0;
|
||||
timeout.tv_sec = time;
|
||||
timeout.tv_usec = 0;
|
||||
ev = evtimer_new(event_base_, SendHeartBeatCallback, this);
|
||||
ev = event_new(event_base_, -1, EV_PERSIST, TimerCallback, this);
|
||||
MS_EXCEPTION_IF_NULL(ev);
|
||||
evtimer_add(ev, &timeout);
|
||||
}
|
||||
|
||||
void TcpClient::set_timer_callback(const OnTimer &timer) { on_timer_callback_ = timer; }
|
||||
|
||||
const event_base &TcpClient::eventbase() { return *event_base_; }
|
||||
|
||||
} // namespace core
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -21,10 +21,15 @@
|
|||
|
||||
#include <event2/event.h>
|
||||
#include <event2/bufferevent.h>
|
||||
#include <event2/thread.h>
|
||||
|
||||
#include <functional>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include <thread>
|
||||
#include <mutex>
|
||||
#include <atomic>
|
||||
|
||||
#include "proto/comm.pb.h"
|
||||
#include "ps/core/cluster_config.h"
|
||||
|
@ -40,6 +45,7 @@ class TcpClient {
|
|||
using OnRead = std::function<void(const TcpClient &, const void *, size_t)>;
|
||||
using OnTimeout = std::function<void(const TcpClient &)>;
|
||||
using OnMessage = std::function<void(const TcpClient &, const CommMessage &)>;
|
||||
using OnTimer = std::function<void(const TcpClient &)>;
|
||||
|
||||
explicit TcpClient(const std::string &address, std::uint16_t port);
|
||||
virtual ~TcpClient();
|
||||
|
@ -50,11 +56,14 @@ class TcpClient {
|
|||
void Init();
|
||||
void StartWithDelay(int seconds);
|
||||
void Stop();
|
||||
static void StopEventBase();
|
||||
void Start();
|
||||
void StartWithNoBlock();
|
||||
void SetMessageCallback(const OnMessage &cb);
|
||||
void SendMessage(const CommMessage &message) const;
|
||||
void SendMessageWithTimer();
|
||||
void StartTimer(const uint32_t &time);
|
||||
void set_timer_callback(const OnTimer &timer);
|
||||
const event_base &eventbase();
|
||||
|
||||
protected:
|
||||
static void SetTcpNoDelay(const evutil_socket_t &fd);
|
||||
|
@ -62,7 +71,7 @@ class TcpClient {
|
|||
static void ReadCallback(struct bufferevent *bev, void *ctx);
|
||||
static void EventCallback(struct bufferevent *bev, std::int16_t events, void *ptr);
|
||||
virtual void OnReadHandler(const void *buf, size_t num);
|
||||
static void SendHeartBeatCallback(evutil_socket_t fd, int16_t event, void *arg);
|
||||
static void TimerCallback(evutil_socket_t fd, int16_t event, void *arg);
|
||||
|
||||
private:
|
||||
OnMessage message_callback_;
|
||||
|
@ -72,13 +81,16 @@ class TcpClient {
|
|||
OnDisconnected disconnected_callback_;
|
||||
OnRead read_callback_;
|
||||
OnTimeout timeout_callback_;
|
||||
OnTimer on_timer_callback_;
|
||||
|
||||
event_base *event_base_;
|
||||
static event_base *event_base_;
|
||||
std::mutex connection_mutex_;
|
||||
event *event_timeout_;
|
||||
bufferevent *buffer_event_;
|
||||
|
||||
std::string server_address_;
|
||||
std::uint16_t server_port_;
|
||||
std::atomic<bool> is_stop_;
|
||||
};
|
||||
|
||||
} // namespace core
|
||||
|
|
|
@ -18,10 +18,10 @@
|
|||
|
||||
#include <arpa/inet.h>
|
||||
#include <event2/buffer.h>
|
||||
#include <event2/buffer_compat.h>
|
||||
#include <event2/bufferevent.h>
|
||||
#include <event2/event.h>
|
||||
#include <event2/listener.h>
|
||||
#include <event2/buffer_compat.h>
|
||||
#include <event2/util.h>
|
||||
#include <sys/socket.h>
|
||||
#include <csignal>
|
||||
|
@ -73,7 +73,8 @@ TcpServer::TcpServer(const std::string &address, std::uint16_t port)
|
|||
signal_event_(nullptr),
|
||||
listener_(nullptr),
|
||||
server_address_(std::move(address)),
|
||||
server_port_(port) {}
|
||||
server_port_(port),
|
||||
is_stop_(true) {}
|
||||
|
||||
TcpServer::~TcpServer() { Stop(); }
|
||||
|
||||
|
@ -84,7 +85,14 @@ void TcpServer::SetServerCallback(const OnConnected &client_conn, const OnDiscon
|
|||
this->client_accept_ = client_accept;
|
||||
}
|
||||
|
||||
void TcpServer::set_timer_callback(const OnTimer &timer) { on_timer_callback_ = timer; }
|
||||
|
||||
void TcpServer::Init() {
|
||||
int result = evthread_use_pthreads();
|
||||
if (result != 0) {
|
||||
MS_LOG(EXCEPTION) << "Use event pthread failed!";
|
||||
}
|
||||
|
||||
base_ = event_base_new();
|
||||
MS_EXCEPTION_IF_NULL(base_);
|
||||
if (!CommUtil::CheckIp(server_address_)) {
|
||||
|
@ -128,6 +136,7 @@ void TcpServer::Start() {
|
|||
std::unique_lock<std::recursive_mutex> lock(connection_mutex_);
|
||||
MS_LOG(INFO) << "Start tcp server!";
|
||||
MS_EXCEPTION_IF_NULL(base_);
|
||||
is_stop_ = false;
|
||||
int ret = event_base_dispatch(base_);
|
||||
MSLOG_IF(INFO, ret == 0, NoExceptionType) << "Event base dispatch success!";
|
||||
MSLOG_IF(mindspore::ERROR, ret == 1, NoExceptionType)
|
||||
|
@ -147,21 +156,42 @@ void TcpServer::StartWithNoBlock() {
|
|||
MSLOG_IF(mindspore::EXCEPTION, ret < -1, AbortedError) << "Event base loop with unexpect error code!";
|
||||
}
|
||||
|
||||
void TcpServer::StartTimerOnlyOnce(const uint32_t &time) {
|
||||
MS_EXCEPTION_IF_NULL(base_);
|
||||
if (time == 0) {
|
||||
MS_LOG(EXCEPTION) << "The time should not be 0!";
|
||||
}
|
||||
struct event *ev = nullptr;
|
||||
struct timeval timeout {};
|
||||
timeout.tv_sec = time;
|
||||
timeout.tv_usec = 0;
|
||||
ev = evtimer_new(base_, TimerCallback, this);
|
||||
MS_EXCEPTION_IF_NULL(ev);
|
||||
evtimer_add(ev, &timeout);
|
||||
}
|
||||
|
||||
void TcpServer::Stop() {
|
||||
MS_LOG(INFO) << "Stop tcp server!";
|
||||
if (signal_event_ != nullptr) {
|
||||
event_free(signal_event_);
|
||||
signal_event_ = nullptr;
|
||||
}
|
||||
if (!is_stop_.load()) {
|
||||
int ret = event_base_loopbreak(base_);
|
||||
if (ret != 0) {
|
||||
MS_LOG(EXCEPTION) << "event base loop break failed!";
|
||||
}
|
||||
if (signal_event_ != nullptr) {
|
||||
event_free(signal_event_);
|
||||
signal_event_ = nullptr;
|
||||
}
|
||||
|
||||
if (listener_ != nullptr) {
|
||||
evconnlistener_free(listener_);
|
||||
listener_ = nullptr;
|
||||
}
|
||||
if (listener_ != nullptr) {
|
||||
evconnlistener_free(listener_);
|
||||
listener_ = nullptr;
|
||||
}
|
||||
|
||||
if (base_ != nullptr) {
|
||||
event_base_free(base_);
|
||||
base_ = nullptr;
|
||||
if (base_ != nullptr) {
|
||||
event_base_free(base_);
|
||||
base_ = nullptr;
|
||||
}
|
||||
is_stop_ = true;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -287,6 +317,14 @@ void TcpServer::EventCallback(struct bufferevent *bev, std::int16_t events, void
|
|||
}
|
||||
}
|
||||
|
||||
void TcpServer::TimerCallback(evutil_socket_t, int16_t, void *arg) {
|
||||
MS_EXCEPTION_IF_NULL(arg);
|
||||
auto tcp_server = reinterpret_cast<TcpServer *>(arg);
|
||||
if (tcp_server->on_timer_callback_) {
|
||||
tcp_server->on_timer_callback_(*tcp_server);
|
||||
}
|
||||
}
|
||||
|
||||
void TcpServer::SendMessage(const TcpConnection &conn, const CommMessage &message) { conn.SendMessage(message); }
|
||||
|
||||
void TcpServer::SendMessage(const CommMessage &message) {
|
||||
|
@ -299,6 +337,10 @@ void TcpServer::SendMessage(const CommMessage &message) {
|
|||
|
||||
uint16_t TcpServer::BoundPort() const { return server_port_; }
|
||||
|
||||
int TcpServer::ConnectionNum() const { return connections_.size(); }
|
||||
|
||||
const std::map<evutil_socket_t, const TcpConnection *> &TcpServer::Connections() const { return connections_; }
|
||||
|
||||
void TcpServer::SetMessageCallback(const OnServerReceiveMessage &cb) { message_callback_ = cb; }
|
||||
} // namespace core
|
||||
} // namespace ps
|
||||
|
|
|
@ -21,17 +21,23 @@
|
|||
#include <event2/bufferevent.h>
|
||||
#include <event2/event.h>
|
||||
#include <event2/listener.h>
|
||||
#include <event2/thread.h>
|
||||
|
||||
#include <exception>
|
||||
#include <functional>
|
||||
#include <iostream>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <mutex>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include <thread>
|
||||
#include <atomic>
|
||||
|
||||
#include "utils/log_adapter.h"
|
||||
#include "proto/comm.pb.h"
|
||||
#include "ps/core/tcp_message_handler.h"
|
||||
#include "ps/core/cluster_config.h"
|
||||
#include "utils/log_adapter.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
|
@ -40,7 +46,7 @@ class TcpServer;
|
|||
class TcpConnection {
|
||||
public:
|
||||
explicit TcpConnection(struct bufferevent *bev, const evutil_socket_t &fd, const TcpServer *server)
|
||||
: buffer_event_(bev), fd_(0), server_(server) {}
|
||||
: buffer_event_(bev), fd_(fd), server_(server) {}
|
||||
virtual ~TcpConnection() = default;
|
||||
|
||||
virtual void InitConnection();
|
||||
|
@ -65,24 +71,29 @@ class TcpServer {
|
|||
using OnConnected = std::function<void(const TcpServer &, const TcpConnection &)>;
|
||||
using OnDisconnected = std::function<void(const TcpServer &, const TcpConnection &)>;
|
||||
using OnAccepted = std::function<const TcpConnection *(const TcpServer &)>;
|
||||
using OnTimer = std::function<void(const TcpServer &)>;
|
||||
|
||||
explicit TcpServer(const std::string &address, std::uint16_t port);
|
||||
virtual ~TcpServer();
|
||||
|
||||
void SetServerCallback(const OnConnected &client_conn, const OnDisconnected &client_disconn,
|
||||
const OnAccepted &client_accept);
|
||||
void set_timer_callback(const OnTimer &timer);
|
||||
void Init();
|
||||
void Start();
|
||||
void StartWithNoBlock();
|
||||
void StartTimerOnlyOnce(const uint32_t &time);
|
||||
void Stop();
|
||||
void SendToAllClients(const char *data, size_t len);
|
||||
void AddConnection(const evutil_socket_t &fd, const TcpConnection *connection);
|
||||
void RemoveConnection(const evutil_socket_t &fd);
|
||||
OnServerReceiveMessage GetServerReceive() const;
|
||||
void SetMessageCallback(const OnServerReceiveMessage &cb);
|
||||
static void SendMessage(const TcpConnection &conn, const CommMessage &message);
|
||||
void SendMessage(const TcpConnection &conn, const CommMessage &message);
|
||||
void SendMessage(const CommMessage &message);
|
||||
uint16_t BoundPort() const;
|
||||
int ConnectionNum() const;
|
||||
const std::map<evutil_socket_t, const TcpConnection *> &Connections() const;
|
||||
|
||||
protected:
|
||||
static void ListenerCallback(struct evconnlistener *listener, evutil_socket_t socket, struct sockaddr *saddr,
|
||||
|
@ -90,6 +101,7 @@ class TcpServer {
|
|||
static void SignalCallback(evutil_socket_t sig, std::int16_t events, void *server);
|
||||
static void ReadCallback(struct bufferevent *, void *connection);
|
||||
static void EventCallback(struct bufferevent *, std::int16_t events, void *server);
|
||||
static void TimerCallback(evutil_socket_t fd, int16_t event, void *arg);
|
||||
virtual TcpConnection *onCreateConnection(struct bufferevent *bev, const evutil_socket_t &fd);
|
||||
|
||||
struct event_base *base_;
|
||||
|
@ -97,6 +109,7 @@ class TcpServer {
|
|||
struct evconnlistener *listener_;
|
||||
std::string server_address_;
|
||||
std::uint16_t server_port_;
|
||||
std::atomic<bool> is_stop_;
|
||||
|
||||
std::map<evutil_socket_t, const TcpConnection *> connections_;
|
||||
OnConnected client_connection_;
|
||||
|
@ -104,6 +117,7 @@ class TcpServer {
|
|||
OnAccepted client_accept_;
|
||||
std::recursive_mutex connection_mutex_;
|
||||
OnServerReceiveMessage message_callback_;
|
||||
OnTimer on_timer_callback_;
|
||||
};
|
||||
} // namespace core
|
||||
} // namespace ps
|
||||
|
|
|
@ -37,7 +37,7 @@ class TestTcpServer : public UT::Common {
|
|||
KVMessage kv_message;
|
||||
kv_message.ParseFromString(message.data());
|
||||
EXPECT_EQ(2, kv_message.keys_size());
|
||||
server.SendMessage(conn, message);
|
||||
const_cast<TcpServer&>(server).SendMessage(conn, message);
|
||||
});
|
||||
server_->Init();
|
||||
server_->Start();
|
||||
|
|
Loading…
Reference in New Issue