!10946 Switch bare pointer to shared_ptr
From: @anancds Reviewed-by: Signed-off-by:
commit 6eb4634ba2
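
The change reads most easily end to end: every place a TcpConnection or CommMessage used to cross a callback boundary as a bare pointer or const reference now travels as a std::shared_ptr, SendMessage reports failure through a bool return instead of MS_LOG(EXCEPTION), and the heartbeat path gains scheduler-timeout handling (SCHEDULER_TIMEOUT). The sketch below shows the shape of a server-side handler under the new signatures; it is illustrative only: the host/port values and the echo-style reply are made up, and the includes and namespace usage follow the headers touched in this diff rather than a verified build.

#include <memory>

#include "proto/comm.pb.h"       // CommMessage
#include "ps/core/tcp_server.h"  // TcpServer, TcpConnection

using namespace mindspore::ps::core;

int main() {
  // Hypothetical endpoint; the real nodes take these from ClusterConfig::scheduler_host()/scheduler_port().
  auto server = std::make_shared<TcpServer>("127.0.0.1", 0);

  // New callback shape: shared_ptr<TcpConnection> and shared_ptr<CommMessage>
  // replace (const TcpServer &, const TcpConnection &, const CommMessage &).
  server->SetMessageCallback([&](std::shared_ptr<TcpConnection> conn, std::shared_ptr<CommMessage> message) {
    // Replies are also built as shared_ptr, mirroring ProcessHeartbeat/ProcessRegister below.
    std::shared_ptr<CommMessage> reply = std::make_shared<CommMessage>();
    *reply->mutable_pb_meta() = {message->pb_meta()};
    reply->set_data(message->data());
    // SendMessage(conn, message) now returns bool on a failed write instead of throwing.
    (void)server->SendMessage(conn, reply);
  });

  server->Init();
  server->Start();  // blocks; the nodes run this on a dedicated thread
  return 0;
}
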
@@ -75,6 +75,8 @@ bool AbstractNode::Broadcast(const enum NodeRole &node_role, const std::string &
     auto client = GetOrCreateTcpClient((*it).first.second);
     client->SendMessage(comm_message);
   }
+  MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_)
+                << ", the node id is:" << node_info_.node_id_ << " send the request id is:" << request_id;
   return Wait(request_id, timeout);
 }
 
@@ -126,11 +128,13 @@ bool AbstractNode::Send(const NodeRole &node_role, const std::vector<uint32_t> &
     auto client = GetOrCreateTcpClient(rank_ids.at(it));
     client->SendMessage(comm_message);
   }
+  MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_)
+                << ", the node id is:" << node_info_.node_id_ << " send the request id is:" << request_id;
   return Wait(request_id, timeout);
 }
 
 bool AbstractNode::Send(const enum NodeRole &node_role, const uint32_t &rank_id, const std::string &message,
-                        CommMessage *output, const uint32_t &timeout) {
+                        std::string *output, const uint32_t &timeout) {
   MS_EXCEPTION_IF_NULL(output);
   if (!CommUtil::ValidateRankId(node_role, rank_id)) {
     MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!";
@@ -141,7 +145,7 @@ bool AbstractNode::Send(const enum NodeRole &node_role, const uint32_t &rank_id,
   set_message_callback(request_id, [&]() {
     receive_messages_mutex_.lock();
     auto res = receive_messages_[request_id];
-    *output = res[rank_id];
+    *output = res[rank_id].data();
     receive_messages_.erase(request_id);
     receive_messages_mutex_.unlock();
   });
@@ -157,11 +161,13 @@ bool AbstractNode::Send(const enum NodeRole &node_role, const uint32_t &rank_id,
   comm_message.set_data(message);
   auto client = GetOrCreateTcpClient(rank_id);
   client->SendMessage(comm_message);
+  MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_)
+                << ", the node id is:" << node_info_.node_id_ << " send the request id is:" << request_id;
   return Wait(request_id, timeout);
 }
 
 bool AbstractNode::Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids,
-                        const std::vector<std::string> &data, std::vector<CommMessage> *output,
+                        const std::vector<std::string> &data, std::vector<std::string> *output,
                         const uint32_t &timeout) {
   MS_EXCEPTION_IF_NULL(output);
   uint64_t request_id = ++next_request_id_;
@@ -177,7 +183,7 @@ bool AbstractNode::Send(const NodeRole &node_role, const std::vector<uint32_t> &
     receive_messages_mutex_.lock();
     auto res = receive_messages_[request_id];
     for (size_t it = 0; it < len; ++it) {
-      (*output).push_back(res[rank_ids.at(it)]);
+      (*output).push_back(res[rank_ids.at(it)].data());
     }
     receive_messages_.erase(request_id);
     receive_messages_mutex_.unlock();
@@ -201,6 +207,8 @@ bool AbstractNode::Send(const NodeRole &node_role, const std::vector<uint32_t> &
     auto client = GetOrCreateTcpClient(rank_ids.at(it));
     client->SendMessage(comm_message);
   }
+  MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_)
+                << ", the node id is:" << node_info_.node_id_ << " send the request id is:" << request_id;
   return Wait(request_id, timeout);
 }
 
@@ -215,7 +223,7 @@ bool AbstractNode::Wait(uint64_t request_id, const uint32_t &timeout) {
 }
 
 uint64_t AbstractNode::CollectiveSendAsync(const enum NodeRole &node_role, const uint32_t &rank_id,
-                                           const std::string &message, const uint32_t &timeout) {
+                                           const std::string &message) {
   if (!CommUtil::ValidateRankId(node_role, rank_id)) {
     MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!";
   }
@@ -233,19 +241,19 @@ uint64_t AbstractNode::CollectiveSendAsync(const enum NodeRole &node_role, const
 }
 
 std::pair<uint32_t, uint64_t> AbstractNode::CollectiveReceiveAsync(const enum NodeRole &node_role,
-                                                                   const uint32_t &rank_id, CommMessage *output) {
+                                                                   const uint32_t &rank_id, std::string *output) {
   if (!CommUtil::ValidateRankId(node_role, rank_id)) {
     MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!";
   }
 
   uint64_t rank_request_id = NextExpectedRankRequestId(rank_id);
   if (received_data_.count(std::make_pair(rank_id, rank_request_id)) > 0) {
-    *output = received_data_[std::make_pair(rank_id, rank_request_id)];
+    *output = received_data_[std::make_pair(rank_id, rank_request_id)].data();
     received_data_.erase(std::make_pair(rank_id, rank_request_id));
   } else {
     set_receive_callback(rank_id, rank_request_id, [=]() {
       receive_callbacks_mutex_.lock();
-      *output = received_data_[std::make_pair(rank_id, 1)];
+      *output = received_data_[std::make_pair(rank_id, 1)].data();
       received_data_.erase(std::make_pair(rank_id, rank_request_id));
       receive_callbacks_mutex_.unlock();
     });
@@ -272,13 +280,25 @@ void AbstractNode::StartHeartbeatTimer(const std::shared_ptr<TcpClient> &client)
                << " begin send heartbeat to the scheduler!";
   heart_beat_thread_ = std::make_unique<std::thread>([&]() {
     while (!is_finish_.load()) {
-      Heartbeat(client);
+      if (!Heartbeat(client)) {
+        MS_LOG(ERROR) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_)
+                      << ", the node id is:" << node_info_.node_id_ << " Send heartbeat timeout!";
+        if (!CheckSchedulerTimeout() && on_node_event_message_) {
+          MS_LOG(ERROR) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_)
+                        << ", the node id is:" << node_info_.node_id_ << " exited due to scheduler timeout!";
+          is_finish_ = true;
+          wait_finish_cond_.notify_all();
+          on_node_event_message_(NodeEvent::SCHEDULER_TIMEOUT);
+        }
+      } else {
+        UpdateSchedulerTime();
+      }
       std::this_thread::sleep_for(std::chrono::seconds(ClusterConfig::heartbeat_interval()));
     }
   });
 }
 
-void AbstractNode::Heartbeat(const std::shared_ptr<TcpClient> &client, bool is_node_finish) {
+bool AbstractNode::Heartbeat(const std::shared_ptr<TcpClient> &client, bool is_node_finish) {
   MessageMeta meta;
   meta.set_cmd(NodeCommand::HEARTBEAT);
 
@@ -292,11 +312,31 @@ void AbstractNode::Heartbeat(const std::shared_ptr<TcpClient> &client, bool is_n
   if (!SendMessageSync(client, message)) {
     MS_LOG(ERROR) << "The node id:" << node_info_.node_id_ << " Send heartbeat timeout!";
   }
+  return true;
 }
 
+void AbstractNode::UpdateSchedulerTime() {
+  struct timeval current_time {};
+  (void)gettimeofday(&current_time, nullptr);
+  scheduler_time_ = current_time;
+  MS_LOG(DEBUG) << "The node role: " << CommUtil::NodeRoleToString(node_info_.node_role_)
+                << ", the node id:" << node_info_.node_id_ << ", the node rank id:" << node_info_.rank_id_
+                << " update scheduler time, the current time is: " << current_time.tv_sec;
+}
+
+bool AbstractNode::CheckSchedulerTimeout() const {
+  struct timeval current_time {};
+  (void)gettimeofday(&current_time, nullptr);
+  if (scheduler_time_.tv_sec + ClusterConfig::scheduler_timeout() < current_time.tv_sec) {
+    return true;
+  }
+  return false;
+}
+
 void AbstractNode::ProcessHeartbeatResp(const CommMessage &message) {
   HeartbeatRespMessage heartbeat_resp_message;
   heartbeat_resp_message.ParseFromString(message.data());
 
   is_ready_ = heartbeat_resp_message.is_cluster_ready();
   if (is_ready_.load()) {
     wait_start_cond_.notify_all();
@@ -353,9 +393,9 @@ bool AbstractNode::Disconnect(const std::shared_ptr<TcpClient> &client, const ui
   *message.mutable_pb_meta() = {meta};
   message.set_data(finish_message.SerializeAsString());
   if (!SendMessageSync(client, message)) {
-    MS_LOG(EXCEPTION) << "Disconnect timeout!";
+    MS_LOG(ERROR) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
+                  << " the node id:" << node_info_.node_id_ << " send Finish Message timeout!";
   }
   MS_LOG(INFO) << "The node id:" << node_info_.node_id_ << " send finish message!";
   return WaitForDisconnect(timeout);
 }
 
@@ -444,6 +484,8 @@ bool AbstractNode::SendMessageSync(const std::shared_ptr<TcpClient> &client, con
   message_tracker_[request_id] = std::make_pair(1, 0);
   const_cast<CommMessage &>(message).mutable_pb_meta()->set_request_id(request_id);
   client->SendMessage(message);
+  MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_)
+                << ", the node id is:" << node_info_.node_id_ << " send the request id is:" << request_id;
   return Wait(request_id, timeout);
 }
 
@@ -452,6 +494,8 @@ uint64_t AbstractNode::SendMessageAsync(const std::shared_ptr<TcpClient> &client
   message_tracker_[request_id] = std::make_pair(1, 0);
   const_cast<CommMessage &>(message).mutable_pb_meta()->set_request_id(request_id);
   client->SendMessage(message);
+  MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_)
+                << ", the node id is:" << node_info_.node_id_ << " send the request id is:" << request_id;
   return request_id;
 }
 
@@ -460,6 +504,8 @@ void AbstractNode::ProcessSendDataResp(const CommMessage &message) {
   const MessageMeta &message_meta = message.pb_meta();
   const uint32_t &rank_id = message_meta.rank_id();
   const uint64_t request_id = message_meta.request_id();
+  MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_)
+                << ", the node id is:" << node_info_.node_id_ << " send the request id is:" << request_id;
   auto it = receive_messages_.find(request_id);
   if (it != receive_messages_.end()) {
     it->second[rank_id] = message;

@@ -42,23 +42,24 @@ class AbstractNode : public Node {
             const uint32_t &timeout = kCommTimeoutInSeconds);
   bool Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids, const std::vector<std::string> &data,
             const uint32_t &timeout = kCommTimeoutInSeconds);
-  bool Send(const enum NodeRole &node_role, const uint32_t &rank_id, const std::string &message, CommMessage *output,
+  bool Send(const enum NodeRole &node_role, const uint32_t &rank_id, const std::string &message, std::string *output,
             const uint32_t &timeout = kCommTimeoutInSeconds);
   bool Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids, const std::vector<std::string> &data,
-            std::vector<CommMessage> *output, const uint32_t &timeout = kCommTimeoutInSeconds);
+            std::vector<std::string> *output, const uint32_t &timeout = kCommTimeoutInSeconds);
   bool Wait(uint64_t request_id, const uint32_t &timeout = kCommTimeoutInSeconds);
 
-  uint64_t CollectiveSendAsync(const enum NodeRole &node_role, const uint32_t &rank_id, const std::string &message,
-                               const uint32_t &timeout = kCommTimeoutInSeconds);
+  uint64_t CollectiveSendAsync(const enum NodeRole &node_role, const uint32_t &rank_id, const std::string &message);
   std::pair<uint32_t, uint64_t> CollectiveReceiveAsync(const enum NodeRole &node_role, const uint32_t &rank_id,
-                                                       CommMessage *output);
+                                                       std::string *output);
   bool CollectiveWait(std::pair<uint32_t, uint64_t> request_id, const uint32_t &timeout = kCommTimeoutInSeconds);
 
  protected:
   void Register(const std::shared_ptr<TcpClient> &client);
   void ProcessRegisterResp(const CommMessage &message);
   void StartHeartbeatTimer(const std::shared_ptr<TcpClient> &client);
-  void Heartbeat(const std::shared_ptr<TcpClient> &client, bool is_node_finish = false);
+  bool Heartbeat(const std::shared_ptr<TcpClient> &client, bool is_node_finish = false);
+  void UpdateSchedulerTime();
+  bool CheckSchedulerTimeout() const;
   void ProcessHeartbeatResp(const CommMessage &message);
   void FetchServers(const std::shared_ptr<TcpClient> &client);
   void ProcessFetchServersResp(const CommMessage &message);
@@ -113,6 +114,7 @@ class AbstractNode : public Node {
   // the key is rank_id, the value is rank_id's actual request_id
   std::unordered_map<uint32_t, uint64_t> actual_rank_request_ids_;
   std::mutex rank_request_ids_mutex;
+  timeval scheduler_time_;
 };
 }  // namespace core
 }  // namespace ps

@@ -33,15 +33,17 @@ uint32_t ClusterConfig::heartbeat_timeout_ = 30;
 uint32_t ClusterConfig::cluster_available_timeout_ = 300;
 // The timeout period for the client to connect to the server is 100ms.
 uint32_t ClusterConfig::connect_interval_ = 100;
+// When the scheduler exits, the worker and server can continue to work for 5 hours
+uint32_t ClusterConfig::scheduler_timeout_ = 3600 * 5;
 
-void ClusterConfig::Init(const uint32_t &worker_num, const uint32_t &server_num,
-                         std::unique_ptr<std::string> scheduler_host, const uint16_t &scheduler_port) {
+void ClusterConfig::Init(const uint32_t &worker_num, const uint32_t &server_num, std::string scheduler_host,
+                         const uint16_t &scheduler_port) {
   worker_num_ = worker_num;
   server_num_ = server_num;
-  if (!CommUtil::CheckIp(*scheduler_host.get())) {
-    MS_LOG(EXCEPTION) << "The scheduler_host:" << *scheduler_host.get() << " is illegal!";
+  if (!CommUtil::CheckIp(scheduler_host)) {
+    MS_LOG(EXCEPTION) << "The scheduler_host:" << scheduler_host << " is illegal!";
   }
-  scheduler_host_ = std::move(scheduler_host);
+  scheduler_host_ = std::make_unique<std::string>(scheduler_host);
   scheduler_port_ = scheduler_port;
 }
 
@@ -55,7 +57,7 @@ void ClusterConfig::set_heartbeat_interval(const uint32_t &heartbeat_interval) {
   heartbeat_interval_ = heartbeat_interval;
 }
 
-std::string ClusterConfig::scheduler_host() { return *scheduler_host_.get(); }
+std::string ClusterConfig::scheduler_host() { return *scheduler_host_; }
 
 uint16_t ClusterConfig::scheduler_port() { return scheduler_port_; }
 
@@ -74,6 +76,10 @@ void ClusterConfig::set_cluster_available_timeout(const uint32_t &cluster_availa
 uint32_t ClusterConfig::connect_interval() { return connect_interval_; }
 
 void ClusterConfig::set_connect_interval(const uint32_t &connect_interval) { connect_interval_ = connect_interval; }
+
+uint32_t ClusterConfig::scheduler_timeout() { return scheduler_timeout_; }
+
+void ClusterConfig::set_scheduler_timeout(const uint32_t &scheduler_timeout) { scheduler_timeout_ = scheduler_timeout; }
 }  // namespace core
 }  // namespace ps
 }  // namespace mindspore

@@ -30,7 +30,7 @@ namespace ps {
 namespace core {
 class ClusterConfig {
  public:
-  static void Init(const uint32_t &worker_num, const uint32_t &server_num, std::unique_ptr<std::string> scheduler_host,
+  static void Init(const uint32_t &worker_num, const uint32_t &server_num, std::string scheduler_host,
                    const uint16_t &scheduler_port);
   static uint32_t worker_num();
   static uint32_t server_num();
@@ -44,6 +44,8 @@ class ClusterConfig {
   static void set_cluster_available_timeout(const uint32_t &cluster_available_timeout);
   static uint32_t connect_interval();
   static void set_connect_interval(const uint32_t &connect_interval);
+  static uint32_t scheduler_timeout();
+  static void set_scheduler_timeout(const uint32_t &scheduler_timeout);
 
  private:
   static uint32_t worker_num_;
@@ -54,6 +56,7 @@ class ClusterConfig {
   static uint32_t heartbeat_timeout_;
   static uint32_t cluster_available_timeout_;
   static uint32_t connect_interval_;
+  static uint32_t scheduler_timeout_;
 };
 }  // namespace core
 }  // namespace ps

@@ -21,7 +21,12 @@ namespace ps {
 namespace core {
 std::string Node::node_id() const { return node_info_.node_id_; }
 
-uint32_t Node::rank_id() const { return node_info_.rank_id_; }
+uint32_t Node::rank_id() const {
+  if (!is_ready_.load()) {
+    MS_LOG(EXCEPTION) << "The cluster is not ready yet to get rank id!";
+  }
+  return node_info_.rank_id_;
+}
 
 NodeRole Node::role() const { return node_info_.node_role_; }
 

@@ -30,8 +30,6 @@
 #include <utility>
 #include <tuple>
 
-#include "proto/comm.pb.h"
-#include "proto/ps.pb.h"
 #include "ps/core/cluster_config.h"
 #include "ps/core/node_info.h"
 #include "ps/core/tcp_client.h"

@@ -25,7 +25,7 @@
 namespace mindspore {
 namespace ps {
 namespace core {
-enum NodeEvent { CLUSTER_TIMEOUT = 0, NODE_TIMEOUT = 1 };
+enum NodeEvent { CLUSTER_TIMEOUT = 0, NODE_TIMEOUT = 1, SCHEDULER_TIMEOUT };
 
 struct NodeInfo {
   NodeInfo() : port_(0), node_role_(NodeRole::SCHEDULER), rank_id_(0) {}

@@ -64,8 +64,8 @@ void NodeManager::UpdateHeartbeat(const std::string &node_id) {
   struct timeval current_time {};
   (void)gettimeofday(&current_time, nullptr);
   heartbeats_[node_id] = current_time;
-  MS_LOG(INFO) << "The node role: " << CommUtil::NodeRoleToString(node_info.node_role_) << ", the node id:" << node_id
-               << ", the node rank id:" << node_info.rank_id_ << " the current time is: " << current_time.tv_sec;
+  MS_LOG(DEBUG) << "The node role: " << CommUtil::NodeRoleToString(node_info.node_role_) << ", the node id:" << node_id
+                << ", the node rank id:" << node_info.rank_id_ << " the current time is: " << current_time.tv_sec;
 }
 
 void NodeManager::UpdateNodeFinishState(const std::string &node_id) { heartbeats_finish_nodes_.insert(node_id); }

@@ -31,8 +31,6 @@
 #include <condition_variable>
 #include <unordered_set>
 
-#include "proto/comm.pb.h"
-#include "proto/ps.pb.h"
 #include "ps/core/node.h"
 #include "utils/log_adapter.h"
 #include "utils/convert_utils_base.h"

@@ -20,6 +20,7 @@ option optimize_for = LITE_RUNTIME;
 enum PSCommand {
   PUSH = 0;
   PULL = 1;
+  INIT_EMBEDDING_TABLE = 2;
 }
 
 message KVMessage {

@@ -37,9 +37,10 @@ bool SchedulerNode::Start(const uint32_t &timeout) {
   return true;
 }
 
-void SchedulerNode::ProcessHeartbeat(const TcpServer &server, const TcpConnection &conn, const CommMessage &message) {
+void SchedulerNode::ProcessHeartbeat(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn,
+                                     std::shared_ptr<CommMessage> message) {
   HeartbeatMessage heartbeat_message;
-  heartbeat_message.ParseFromString(message.data());
+  heartbeat_message.ParseFromString(message->data());
 
   node_manager_.UpdateHeartbeat(heartbeat_message.node_id());
 
@@ -59,10 +60,10 @@ void SchedulerNode::ProcessHeartbeat(const TcpServer &server, const TcpConnectio
   heartbeat_resp_message.set_is_cluster_timeout(node_manager_.is_cluster_timeout());
   heartbeat_resp_message.set_is_node_timeout(node_manager_.is_node_timeout());
 
-  CommMessage comm_message;
-  *comm_message.mutable_pb_meta() = {message.pb_meta()};
-  comm_message.set_data(heartbeat_resp_message.SerializeAsString());
-  const_cast<TcpServer &>(server).SendMessage(conn, comm_message);
+  std::shared_ptr<CommMessage> comm_message = std::make_shared<CommMessage>();
+  *comm_message->mutable_pb_meta() = {message->pb_meta()};
+  comm_message->set_data(heartbeat_resp_message.SerializeAsString());
+  server->SendMessage(conn, comm_message);
 }
 
 void SchedulerNode::Initialize() {
@@ -79,23 +80,23 @@ void SchedulerNode::CreateTcpServer() {
 
   std::string scheduler_host = ClusterConfig::scheduler_host();
   uint32_t scheduler_port = ClusterConfig::scheduler_port();
-  server_ = std::make_unique<TcpServer>(scheduler_host, scheduler_port);
-  server_->SetMessageCallback([&](const TcpServer &server, const TcpConnection &conn, const CommMessage &message) {
-    switch (message.pb_meta().cmd()) {
+  server_ = std::make_shared<TcpServer>(scheduler_host, scheduler_port);
+  server_->SetMessageCallback([&](std::shared_ptr<TcpConnection> conn, std::shared_ptr<CommMessage> message) {
+    switch (message->pb_meta().cmd()) {
       case NodeCommand::HEARTBEAT:
-        ProcessHeartbeat(server, conn, message);
+        ProcessHeartbeat(server_, conn, message);
         break;
       case NodeCommand::REGISTER:
-        ProcessRegister(server, conn, message);
+        ProcessRegister(server_, conn, message);
         break;
       case NodeCommand::FINISH:
-        ProcessFinish(server, conn, message);
+        ProcessFinish(server_, conn, message);
         break;
       case NodeCommand::FETCH_SERVER:
-        ProcessFetchServers(server, conn, message);
+        ProcessFetchServers(server_, conn, message);
         break;
       default:
-        MS_LOG(EXCEPTION) << "The cmd:" << message.pb_meta().cmd() << " is not supported!";
+        MS_LOG(EXCEPTION) << "The cmd:" << message->pb_meta().cmd() << " is not supported!";
     }
   });
 
@@ -107,10 +108,11 @@ void SchedulerNode::CreateTcpServer() {
   });
 }
 
-void SchedulerNode::ProcessRegister(const TcpServer &server, const TcpConnection &conn, const CommMessage &message) {
+void SchedulerNode::ProcessRegister(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn,
+                                    std::shared_ptr<CommMessage> message) {
   MS_LOG(INFO) << "The scheduler process a register message!";
   RegisterMessage register_message;
-  register_message.ParseFromString(message.data());
+  register_message.ParseFromString(message->data());
 
   // assign worker node and server node rank id
   int rank_id = node_manager_.NextRankId(register_message);
@@ -124,31 +126,32 @@ void SchedulerNode::ProcessRegister(const TcpServer &server, const TcpConnection
   register_resp_message.set_node_id(node_id);
   register_resp_message.set_rank_id(rank_id);
 
-  CommMessage comm_message;
-  *comm_message.mutable_pb_meta() = {message.pb_meta()};
-  comm_message.set_data(register_resp_message.SerializeAsString());
-  const_cast<TcpServer &>(server).SendMessage(conn, comm_message);
+  std::shared_ptr<CommMessage> comm_message = std::make_shared<CommMessage>();
+  *comm_message->mutable_pb_meta() = {message->pb_meta()};
+  comm_message->set_data(register_resp_message.SerializeAsString());
+  server->SendMessage(conn, comm_message);
 }
 
-void SchedulerNode::ProcessFinish(const TcpServer &server, const TcpConnection &conn, const CommMessage &message) {
+void SchedulerNode::ProcessFinish(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn,
+                                  std::shared_ptr<CommMessage> message) {
   FinishMessage finish_message;
-  finish_message.ParseFromString(message.data());
+  finish_message.ParseFromString(message->data());
   node_manager_.AddFinishNode(finish_message);
   MS_LOG(INFO) << "Process finish message from node id:" << finish_message.node_id();
-  const_cast<TcpServer &>(server).SendMessage(conn, message);
+  server->SendMessage(conn, message);
 }
 
-void SchedulerNode::ProcessFetchServers(const TcpServer &server, const TcpConnection &conn,
-                                        const CommMessage &message) {
+void SchedulerNode::ProcessFetchServers(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn,
                                        std::shared_ptr<CommMessage> message) {
   FetchServersRespMessage fetch_servers_message;
   std::vector<ServersMeta> servers_meta_list = node_manager_.FetchServersMeta();
 
   *fetch_servers_message.mutable_servers_meta() = {servers_meta_list.begin(), servers_meta_list.end()};
 
-  CommMessage comm_message;
-  *comm_message.mutable_pb_meta() = {message.pb_meta()};
-  comm_message.set_data(fetch_servers_message.SerializeAsString());
-  const_cast<TcpServer &>(server).SendMessage(conn, comm_message);
+  std::shared_ptr<CommMessage> comm_message = std::make_shared<CommMessage>();
+  *comm_message->mutable_pb_meta() = {message->pb_meta()};
+  comm_message->set_data(fetch_servers_message.SerializeAsString());
+  server->SendMessage(conn, comm_message);
 }
 
 void SchedulerNode::StartUpdateClusterStateTimer() {

@@ -26,8 +26,6 @@
 #include <thread>
 #include <mutex>
 
-#include "proto/comm.pb.h"
-#include "proto/ps.pb.h"
 #include "ps/core/cluster_config.h"
 #include "ps/core/tcp_client.h"
 #include "ps/core/tcp_server.h"
@@ -51,13 +49,17 @@ class SchedulerNode : public Node {
  private:
   void Initialize();
   void CreateTcpServer();
-  void ProcessHeartbeat(const TcpServer &server, const TcpConnection &conn, const CommMessage &message);
-  void ProcessRegister(const TcpServer &server, const TcpConnection &conn, const CommMessage &message);
+  void ProcessHeartbeat(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn,
+                        std::shared_ptr<CommMessage> message);
+  void ProcessRegister(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn,
+                       std::shared_ptr<CommMessage> message);
   void StartUpdateClusterStateTimer();
-  void ProcessFinish(const TcpServer &server, const TcpConnection &conn, const CommMessage &message);
-  void ProcessFetchServers(const TcpServer &server, const TcpConnection &conn, const CommMessage &message);
+  void ProcessFinish(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn,
+                     std::shared_ptr<CommMessage> message);
+  void ProcessFetchServers(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn,
+                           std::shared_ptr<CommMessage> message);
 
-  std::unique_ptr<TcpServer> server_;
+  std::shared_ptr<TcpServer> server_;
   std::unique_ptr<std::thread> scheduler_thread_;
   std::unique_ptr<std::thread> update_state_thread_;
 

@@ -30,7 +30,8 @@ bool ServerNode::Start(const uint32_t &timeout) {
   StartHeartbeatTimer(client_to_scheduler_);
 
   if (!WaitForStart(timeout)) {
-    MS_LOG(ERROR) << "Start Server node timeout!";
+    MS_LOG(ERROR) << "Start server node timeout!";
     return false;
   }
   MS_LOG(INFO) << "The cluster is ready to use!";
 
@@ -45,16 +46,16 @@ bool ServerNode::Start(const uint32_t &timeout) {
 
 void ServerNode::set_handler(const RequestHandler &handler) { request_handler_ = handler; }
 
-void ServerNode::Response(const TcpServer &server, const TcpConnection &conn, const MessageMeta &message_meta,
-                          const std::string &message) {
-  auto &meta = const_cast<MessageMeta &>(message_meta);
-  meta.set_role(node_info_.node_role_);
-  meta.set_rank_id(node_info_.rank_id_);
-  CommMessage comm_message;
-  *comm_message.mutable_pb_meta() = {meta};
-  comm_message.set_data(message);
-
-  const_cast<TcpServer &>(server).SendMessage(conn, comm_message);
+void ServerNode::Response(std::shared_ptr<TcpConnection> conn, std::shared_ptr<CommMessage> message) {
+  MS_EXCEPTION_IF_NULL(conn);
+  MS_EXCEPTION_IF_NULL(message);
+  message->mutable_pb_meta()->set_role(node_info_.node_role_);
+  message->mutable_pb_meta()->set_rank_id(node_info_.rank_id_);
+  const MessageMeta &message_meta = message->pb_meta();
+  const uint64_t request_id = message_meta.request_id();
+  MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_)
+                << ", the node id is:" << node_info_.node_id_ << " send the request id is:" << request_id;
+  server_->SendMessage(conn, message);
 }
 
 void ServerNode::CreateTcpServer() {
@@ -62,17 +63,17 @@ void ServerNode::CreateTcpServer() {
   std::string server_ip;
   CommUtil::GetAvailableInterfaceAndIP(&interface, &server_ip);
   server_ = std::make_shared<TcpServer>(server_ip, 0);
-  server_->SetMessageCallback([&](const TcpServer &server, const TcpConnection &conn, const CommMessage &message) {
-    switch (message.pb_meta().cmd()) {
+  server_->SetMessageCallback([&](std::shared_ptr<TcpConnection> conn, std::shared_ptr<CommMessage> message) {
+    switch (message->pb_meta().cmd()) {
       case NodeCommand::SEND_DATA:
-        ProcessSendData(server, conn, message);
+        ProcessSendData(conn, message);
         break;
       case NodeCommand::COLLECTIVE_SEND_DATA:
-        ProcessCollectiveSendData(server, conn, message);
-        RunReceiveCallback(message);
+        ProcessCollectiveSendData(conn, message);
+        RunReceiveCallback(*message);
        break;
       default:
-        MS_LOG(EXCEPTION) << "The cmd:" << message.pb_meta().cmd() << " is not supported!";
+        MS_LOG(EXCEPTION) << "The cmd:" << message->pb_meta().cmd() << " is not supported!";
     }
   });
   server_->Init();
@@ -97,15 +98,18 @@ void ServerNode::Initialize() {
   MS_LOG(INFO) << "Server node init client successful!";
 }
 
-void ServerNode::ProcessSendData(const TcpServer &server, const TcpConnection &conn, const CommMessage &message) {
-  request_handler_(server, conn, message.pb_meta(), message.data());
+void ServerNode::ProcessSendData(std::shared_ptr<TcpConnection> conn, std::shared_ptr<CommMessage> message) {
+  MS_EXCEPTION_IF_NULL(conn);
+  MS_EXCEPTION_IF_NULL(message);
+  request_handler_(conn, message);
 }
 
-void ServerNode::ProcessCollectiveSendData(const TcpServer &server, const TcpConnection &conn,
-                                           const CommMessage &message) {
-  CommMessage comm_message;
-  *comm_message.mutable_pb_meta() = {message.pb_meta()};
-  const_cast<TcpServer &>(server).SendMessage(conn, comm_message);
+void ServerNode::ProcessCollectiveSendData(std::shared_ptr<TcpConnection> conn, std::shared_ptr<CommMessage> message) {
+  MS_EXCEPTION_IF_NULL(conn);
+  MS_EXCEPTION_IF_NULL(message);
+  std::shared_ptr<CommMessage> comm_message = std::make_shared<CommMessage>();
+  *comm_message->mutable_pb_meta() = {message->pb_meta()};
+  server_->SendMessage(conn, comm_message);
 }
 
 bool ServerNode::Stop() {

@@ -44,18 +44,16 @@ class ServerNode : public AbstractNode {
   bool Stop() override;
   bool Finish(const uint32_t &timeout = kTimeoutInSeconds) override;
 
-  using RequestHandler = std::function<void(const TcpServer &server, const TcpConnection &conn, const MessageMeta meta,
-                                            const std::string &message)>;
+  using RequestHandler = std::function<void(std::shared_ptr<TcpConnection> conn, std::shared_ptr<CommMessage> message)>;
 
   void set_handler(const RequestHandler &handler);
-  void Response(const TcpServer &server, const TcpConnection &conn, const MessageMeta &message_meta,
-                const std::string &message);
+  void Response(std::shared_ptr<TcpConnection> conn, std::shared_ptr<CommMessage> message);
 
  private:
   void CreateTcpServer();
   void Initialize();
-  void ProcessSendData(const TcpServer &server, const TcpConnection &conn, const CommMessage &message);
-  void ProcessCollectiveSendData(const TcpServer &server, const TcpConnection &conn, const CommMessage &message);
+  void ProcessSendData(std::shared_ptr<TcpConnection> conn, std::shared_ptr<CommMessage> message);
+  void ProcessCollectiveSendData(std::shared_ptr<TcpConnection> conn, std::shared_ptr<CommMessage> message);
 
   std::shared_ptr<TcpServer> server_;
   std::unique_ptr<std::thread> server_thread_;

@@ -46,9 +46,9 @@ TcpClient::TcpClient(const std::string &address, std::uint16_t port)
       server_port_(port),
       is_stop_(true),
       is_connected_(false) {
-  message_handler_.SetCallback([this](const CommMessage &message) {
+  message_handler_.SetCallback([this](std::shared_ptr<CommMessage> message) {
     if (message_callback_) {
-      message_callback_(*this, message);
+      message_callback_(*this, *message);
     }
   });
 }
@@ -105,7 +105,7 @@ void TcpClient::Init() {
   sin.sin_addr.s_addr = inet_addr(server_address_.c_str());
   sin.sin_port = htons(server_port_);
 
-  buffer_event_ = bufferevent_socket_new(event_base_, -1, BEV_OPT_CLOSE_ON_FREE);
+  buffer_event_ = bufferevent_socket_new(event_base_, -1, BEV_OPT_CLOSE_ON_FREE | BEV_OPT_THREADSAFE);
   MS_EXCEPTION_IF_NULL(buffer_event_);
 
   bufferevent_setcb(buffer_event_, ReadCallback, nullptr, EventCallback, this);
@@ -261,17 +261,23 @@ void TcpClient::StartWithNoBlock() {
 
 void TcpClient::SetMessageCallback(const OnMessage &cb) { message_callback_ = cb; }
 
-void TcpClient::SendMessage(const CommMessage &message) const {
+bool TcpClient::SendMessage(const CommMessage &message) const {
   MS_EXCEPTION_IF_NULL(buffer_event_);
+  bufferevent_lock(buffer_event_);
+  bool res = true;
   size_t buf_size = message.ByteSizeLong();
   std::vector<unsigned char> serialized(buf_size);
   message.SerializeToArray(serialized.data(), SizeToInt(buf_size));
-  if (evbuffer_add(bufferevent_get_output(buffer_event_), &buf_size, sizeof(buf_size)) == -1) {
-    MS_LOG(EXCEPTION) << "Event buffer add header failed!";
+  if (bufferevent_write(buffer_event_, &buf_size, sizeof(buf_size)) == -1) {
+    MS_LOG(ERROR) << "Event buffer add header failed!";
+    res = false;
   }
-  if (evbuffer_add(bufferevent_get_output(buffer_event_), serialized.data(), buf_size) == -1) {
-    MS_LOG(EXCEPTION) << "Event buffer add protobuf data failed!";
+  if (bufferevent_write(buffer_event_, serialized.data(), buf_size) == -1) {
+    MS_LOG(ERROR) << "Event buffer add protobuf data failed!";
+    res = false;
   }
+  bufferevent_unlock(buffer_event_);
+  return res;
 }
 
 void TcpClient::StartTimer(const uint32_t &time) {

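The TcpClient and TcpConnection write paths above move from evbuffer_add on the output buffer (with MS_LOG(EXCEPTION) on failure) to bufferevent_write bracketed by bufferevent_lock/bufferevent_unlock, returning bool instead. A reduced sketch of that libevent pattern, stripped of the MindSpore types (the function name, bev and payload are placeholders, not part of the patch):

#include <cstddef>
#include <string>

#include <event2/bufferevent.h>

// Frames a payload as [size][bytes] and reports failure instead of throwing,
// matching the new bool SendMessage contract.
bool WriteFramed(struct bufferevent *bev, const std::string &payload) {
  size_t size = payload.size();
  // The lock is only meaningful when the bufferevent was created with BEV_OPT_THREADSAFE,
  // which is exactly what the diff adds alongside BEV_OPT_CLOSE_ON_FREE.
  bufferevent_lock(bev);
  bool ok = true;
  if (bufferevent_write(bev, &size, sizeof(size)) == -1) {
    ok = false;
  }
  if (bufferevent_write(bev, payload.data(), payload.size()) == -1) {
    ok = false;
  }
  bufferevent_unlock(bev);
  return ok;
}
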
@@ -33,8 +33,6 @@
 #include <condition_variable>
 
 #include "ps/core/cluster_config.h"
-#include "proto/comm.pb.h"
-#include "proto/ps.pb.h"
 #include "utils/convert_utils_base.h"
 
 namespace mindspore {
@@ -62,7 +60,7 @@ class TcpClient {
   void Start();
   void StartWithNoBlock();
   void SetMessageCallback(const OnMessage &cb);
-  void SendMessage(const CommMessage &message) const;
+  bool SendMessage(const CommMessage &message) const;
   void StartTimer(const uint32_t &time);
   void set_timer_callback(const OnTimer &timer);
   const event_base &eventbase();

|
|||
}
|
||||
|
||||
if (remaining_length_ == 0) {
|
||||
CommMessage pb_message;
|
||||
pb_message.ParseFromArray(message_buffer_.get(), message_length_);
|
||||
std::shared_ptr<CommMessage> pb_message = std::make_shared<CommMessage>();
|
||||
pb_message->ParseFromArray(message_buffer_.get(), message_length_);
|
||||
if (message_callback_) {
|
||||
message_callback_(pb_message);
|
||||
}
|
||||
|
|
|
@@ -30,7 +30,7 @@
 namespace mindspore {
 namespace ps {
 namespace core {
-using messageReceive = std::function<void(const CommMessage &message)>;
+using messageReceive = std::function<void(std::shared_ptr<CommMessage>)>;
 constexpr int kHeaderLen = 8;
 
 class TcpMessageHandler {

@@ -32,14 +32,7 @@
 namespace mindspore {
 namespace ps {
 namespace core {
-void TcpConnection::InitConnection() {
-  tcp_message_handler_.SetCallback([&](const CommMessage &message) {
-    OnServerReceiveMessage on_server_receive = server_->GetServerReceive();
-    if (on_server_receive) {
-      on_server_receive(*server_, *this, message);
-    }
-  });
-}
+void TcpConnection::InitConnection(const messageReceive &callback) { tcp_message_handler_.SetCallback(callback); }
 
 void TcpConnection::OnReadHandler(const void *buffer, size_t num) { tcp_message_handler_.ReceiveMessage(buffer, num); }
 
@@ -49,23 +42,30 @@ void TcpConnection::SendMessage(const void *buffer, size_t num) const {
   }
 }
 
-TcpServer *TcpConnection::GetServer() const { return const_cast<TcpServer *>(server_); }
+TcpServer *TcpConnection::GetServer() const { return server_; }
 
 const evutil_socket_t &TcpConnection::GetFd() const { return fd_; }
 
-void TcpConnection::SendMessage(const CommMessage &message) const {
+void TcpConnection::set_callback(const Callback &callback) { callback_ = callback; }
+
+bool TcpConnection::SendMessage(std::shared_ptr<CommMessage> message) const {
   MS_EXCEPTION_IF_NULL(buffer_event_);
-  size_t buf_size = message.ByteSizeLong();
+  MS_EXCEPTION_IF_NULL(message);
+  bufferevent_lock(buffer_event_);
+  bool res = true;
+  size_t buf_size = message->ByteSizeLong();
   std::vector<unsigned char> serialized(buf_size);
-  message.SerializeToArray(serialized.data(), SizeToInt(buf_size));
-  if (evbuffer_add(bufferevent_get_output(const_cast<struct bufferevent *>(buffer_event_)), &buf_size,
-                   sizeof(buf_size)) == -1) {
-    MS_LOG(EXCEPTION) << "Event buffer add header failed!";
+  message->SerializeToArray(serialized.data(), SizeToInt(buf_size));
+  if (bufferevent_write(buffer_event_, &buf_size, sizeof(buf_size)) == -1) {
+    MS_LOG(ERROR) << "Event buffer add header failed!";
+    res = false;
   }
-  if (evbuffer_add(bufferevent_get_output(const_cast<struct bufferevent *>(buffer_event_)), serialized.data(),
-                   buf_size) == -1) {
-    MS_LOG(EXCEPTION) << "Event buffer add protobuf data failed!";
+  if (bufferevent_write(buffer_event_, serialized.data(), buf_size) == -1) {
+    MS_LOG(ERROR) << "Event buffer add protobuf data failed!";
+    res = false;
   }
+  bufferevent_unlock(buffer_event_);
+  return res;
 }
 
 TcpServer::TcpServer(const std::string &address, std::uint16_t port)
@@ -225,7 +225,7 @@ void TcpServer::SendToAllClients(const char *data, size_t len) {
   }
 }
 
-void TcpServer::AddConnection(const evutil_socket_t &fd, const TcpConnection *connection) {
+void TcpServer::AddConnection(const evutil_socket_t &fd, std::shared_ptr<TcpConnection> connection) {
   MS_EXCEPTION_IF_NULL(connection);
   std::lock_guard<std::mutex> lock(connection_mutex_);
   connections_.insert(std::make_pair(fd, connection));
@@ -233,11 +233,11 @@ void TcpServer::AddConnection(const evutil_socket_t &fd, const TcpConnection *co
 
 void TcpServer::RemoveConnection(const evutil_socket_t &fd) {
   std::lock_guard<std::mutex> lock(connection_mutex_);
-  TcpConnection *connection = const_cast<TcpConnection *>(connections_.find(fd)->second);
-  delete connection;
   connections_.erase(fd);
 }
 
+std::shared_ptr<TcpConnection> TcpServer::GetConnectionByFd(const evutil_socket_t &fd) { return connections_[fd]; }
+
 void TcpServer::ListenerCallback(struct evconnlistener *, evutil_socket_t fd, struct sockaddr *sockaddr, int,
                                  void *data) {
   auto server = reinterpret_cast<class TcpServer *>(data);
@@ -246,7 +246,7 @@ void TcpServer::ListenerCallback(struct evconnlistener *, evutil_socket_t fd, st
   MS_EXCEPTION_IF_NULL(base);
   MS_EXCEPTION_IF_NULL(sockaddr);
 
-  struct bufferevent *bev = bufferevent_socket_new(base, fd, BEV_OPT_CLOSE_ON_FREE);
+  struct bufferevent *bev = bufferevent_socket_new(base, fd, BEV_OPT_CLOSE_ON_FREE | BEV_OPT_THREADSAFE);
   if (!bev) {
     MS_LOG(ERROR) << "Error constructing buffer event!";
     int ret = event_base_loopbreak(base);
@@ -256,23 +256,29 @@ void TcpServer::ListenerCallback(struct evconnlistener *, evutil_socket_t fd, st
     return;
   }
 
-  TcpConnection *conn = server->onCreateConnection(bev, fd);
+  std::shared_ptr<TcpConnection> conn = server->onCreateConnection(bev, fd);
   MS_EXCEPTION_IF_NULL(conn);
 
-  conn->InitConnection();
   server->AddConnection(fd, conn);
-  bufferevent_setcb(bev, TcpServer::ReadCallback, nullptr, TcpServer::EventCallback, reinterpret_cast<void *>(conn));
+  conn->InitConnection([=](std::shared_ptr<CommMessage> message) {
+    OnServerReceiveMessage on_server_receive = server->GetServerReceive();
+    if (on_server_receive) {
+      on_server_receive(conn, message);
+    }
+  });
+  bufferevent_setcb(bev, TcpServer::ReadCallback, nullptr, TcpServer::EventCallback,
+                    reinterpret_cast<void *>(conn.get()));
   if (bufferevent_enable(bev, EV_READ | EV_WRITE) == -1) {
     MS_LOG(EXCEPTION) << "Buffer event enable read and write failed!";
   }
 }
 
-TcpConnection *TcpServer::onCreateConnection(struct bufferevent *bev, const evutil_socket_t &fd) {
-  TcpConnection *conn = nullptr;
+std::shared_ptr<TcpConnection> TcpServer::onCreateConnection(struct bufferevent *bev, const evutil_socket_t &fd) {
+  std::shared_ptr<TcpConnection> conn = nullptr;
   if (client_accept_) {
-    conn = const_cast<TcpConnection *>(client_accept_(*this));
+    conn = (client_accept_(*this));
   } else {
-    conn = new TcpConnection(bev, fd, this);
+    conn = std::make_shared<TcpConnection>(bev, fd, this);
   }
 
   return conn;
@@ -312,8 +318,8 @@ void TcpServer::EventCallback(struct bufferevent *bev, std::int16_t events, void
   MS_EXCEPTION_IF_NULL(data);
   struct evbuffer *output = bufferevent_get_output(bev);
   size_t remain = evbuffer_get_length(output);
-  auto conn = reinterpret_cast<TcpConnection *>(data);
-  TcpServer *srv = conn->GetServer();
+  auto conn = static_cast<class TcpConnection *>(data);
+  auto srv = conn->GetServer();
 
   if (events & BEV_EVENT_EOF) {
     MS_LOG(INFO) << "Event buffer end of file!";
@@ -355,13 +361,18 @@ void TcpServer::TimerOnceCallback(evutil_socket_t, int16_t, void *arg) {
   }
 }
 
-void TcpServer::SendMessage(const TcpConnection &conn, const CommMessage &message) { conn.SendMessage(message); }
+bool TcpServer::SendMessage(std::shared_ptr<TcpConnection> conn, std::shared_ptr<CommMessage> message) {
+  MS_EXCEPTION_IF_NULL(conn);
+  MS_EXCEPTION_IF_NULL(message);
+  return conn->SendMessage(message);
+}
 
-void TcpServer::SendMessage(const CommMessage &message) {
+void TcpServer::SendMessage(std::shared_ptr<CommMessage> message) {
   std::lock_guard<std::mutex> lock(connection_mutex_);
+  MS_EXCEPTION_IF_NULL(message);
 
   for (auto it = connections_.begin(); it != connections_.end(); ++it) {
-    SendMessage(*it->second, message);
+    SendMessage(it->second, message);
   }
 }
 
@@ -371,7 +382,7 @@ std::string TcpServer::BoundIp() const { return server_address_; }
 
 int TcpServer::ConnectionNum() const { return connections_.size(); }
 
-const std::map<evutil_socket_t, const TcpConnection *> &TcpServer::Connections() const { return connections_; }
+const std::map<evutil_socket_t, std::shared_ptr<TcpConnection>> &TcpServer::Connections() const { return connections_; }
 
 void TcpServer::SetMessageCallback(const OnServerReceiveMessage &cb) { message_callback_ = cb; }
 

@@ -34,8 +34,6 @@
 #include <thread>
 #include <atomic>
 
-#include "proto/comm.pb.h"
-#include "proto/ps.pb.h"
 #include "ps/core/tcp_message_handler.h"
 #include "ps/core/cluster_config.h"
 #include "utils/log_adapter.h"
@@ -47,36 +45,42 @@ namespace core {
 class TcpServer;
 class TcpConnection {
  public:
-  explicit TcpConnection(struct bufferevent *bev, const evutil_socket_t &fd, const TcpServer *server)
+  explicit TcpConnection(struct bufferevent *bev, const evutil_socket_t &fd, TcpServer *server)
       : buffer_event_(bev), fd_(fd), server_(server) {}
   TcpConnection(const TcpConnection &);
   virtual ~TcpConnection() = default;
 
-  virtual void InitConnection();
+  using Callback = std::function<void(const std::shared_ptr<CommMessage>)>;
+
+  virtual void InitConnection(const messageReceive &callback);
   virtual void SendMessage(const void *buffer, size_t num) const;
-  void SendMessage(const CommMessage &message) const;
+  bool SendMessage(std::shared_ptr<CommMessage> message) const;
   virtual void OnReadHandler(const void *buffer, size_t numBytes);
   TcpServer *GetServer() const;
   const evutil_socket_t &GetFd() const;
+  void set_callback(const Callback &callback);
 
  protected:
   struct bufferevent *buffer_event_;
   evutil_socket_t fd_;
-  const TcpServer *server_;
+  TcpServer *server_;
   TcpMessageHandler tcp_message_handler_;
+  Callback callback_;
 };
 
 using OnServerReceiveMessage =
-  std::function<void(const TcpServer &tcp_server, const TcpConnection &conn, const CommMessage &)>;
+  std::function<void(std::shared_ptr<TcpConnection> conn, std::shared_ptr<CommMessage> message)>;
 
 class TcpServer {
  public:
   using OnConnected = std::function<void(const TcpServer &, const TcpConnection &)>;
   using OnDisconnected = std::function<void(const TcpServer &, const TcpConnection &)>;
-  using OnAccepted = std::function<const TcpConnection *(const TcpServer &)>;
+  using OnAccepted = std::function<std::shared_ptr<TcpConnection>(const TcpServer &)>;
   using OnTimerOnce = std::function<void(const TcpServer &)>;
   using OnTimer = std::function<void()>;
 
-  explicit TcpServer(const std::string &address, std::uint16_t port);
+  TcpServer(const std::string &address, std::uint16_t port);
   TcpServer(const TcpServer &server);
   virtual ~TcpServer();
 
   void SetServerCallback(const OnConnected &client_conn, const OnDisconnected &client_disconn,
@@ -90,16 +94,17 @@ class TcpServer {
   void StartTimer(const uint32_t &time);
   void Stop();
   void SendToAllClients(const char *data, size_t len);
-  void AddConnection(const evutil_socket_t &fd, const TcpConnection *connection);
+  void AddConnection(const evutil_socket_t &fd, std::shared_ptr<TcpConnection> connection);
   void RemoveConnection(const evutil_socket_t &fd);
+  std::shared_ptr<TcpConnection> GetConnectionByFd(const evutil_socket_t &fd);
   OnServerReceiveMessage GetServerReceive() const;
   void SetMessageCallback(const OnServerReceiveMessage &cb);
-  void SendMessage(const TcpConnection &conn, const CommMessage &message);
-  void SendMessage(const CommMessage &message);
+  bool SendMessage(std::shared_ptr<TcpConnection> conn, std::shared_ptr<CommMessage> message);
+  void SendMessage(std::shared_ptr<CommMessage> message);
   uint16_t BoundPort() const;
   std::string BoundIp() const;
   int ConnectionNum() const;
-  const std::map<evutil_socket_t, const TcpConnection *> &Connections() const;
+  const std::map<evutil_socket_t, std::shared_ptr<TcpConnection>> &Connections() const;
 
  protected:
   static void ListenerCallback(struct evconnlistener *listener, evutil_socket_t socket, struct sockaddr *saddr,
@@ -109,7 +114,7 @@ class TcpServer {
   static void EventCallback(struct bufferevent *, std::int16_t events, void *server);
   static void TimerCallback(evutil_socket_t fd, int16_t event, void *arg);
   static void TimerOnceCallback(evutil_socket_t fd, int16_t event, void *arg);
-  virtual TcpConnection *onCreateConnection(struct bufferevent *bev, const evutil_socket_t &fd);
+  std::shared_ptr<TcpConnection> onCreateConnection(struct bufferevent *bev, const evutil_socket_t &fd);
 
   struct event_base *base_;
   struct event *signal_event_;
@@ -118,7 +123,7 @@ class TcpServer {
   std::uint16_t server_port_;
   std::atomic<bool> is_stop_;
 
-  std::map<evutil_socket_t, const TcpConnection *> connections_;
+  std::map<evutil_socket_t, std::shared_ptr<TcpConnection>> connections_;
   OnConnected client_connection_;
   OnDisconnected client_disconnection_;
   OnAccepted client_accept_;

@@ -24,8 +24,6 @@
 #include <utility>
 #include <algorithm>
 
-#include "proto/comm.pb.h"
-#include "proto/ps.pb.h"
 #include "ps/core/cluster_config.h"
 #include "ps/core/tcp_client.h"
 #include "ps/core/tcp_server.h"

|
@ -31,7 +31,7 @@ class TestClusterAvailableTimeout : public UT::Common {
|
|||
};
|
||||
|
||||
TEST_F(TestClusterAvailableTimeout, TestClusterAvailableTimeout) {
|
||||
ClusterConfig::Init(1, 1, std::make_unique<std::string>("127.0.0.1"), 9999);
|
||||
ClusterConfig::Init(1, 1, "127.0.0.1", 9999);
|
||||
ClusterConfig::set_cluster_available_timeout(3);
|
||||
SchedulerNode node;
|
||||
node.Start();
|
||||
|
|
|
@ -33,7 +33,7 @@ class TestClusterConfig : public UT::Common {
|
|||
};
|
||||
|
||||
TEST_F(TestClusterConfig, HeartbeatInterval) {
|
||||
ClusterConfig::Init(2, 2, std::make_unique<std::string>("127.0.0.1"), 8080);
|
||||
ClusterConfig::Init(2, 2, "127.0.0.1", 8080);
|
||||
EXPECT_TRUE(ClusterConfig::heartbeat_interval() == 3);
|
||||
ClusterConfig::set_heartbeat_interval(100);
|
||||
EXPECT_TRUE(ClusterConfig::heartbeat_interval() == 100);
|
||||
|
|
|
@ -53,7 +53,7 @@ TEST_F(TestCommUtil, GetAvailableInterfaceAndIP) {
|
|||
}
|
||||
|
||||
TEST_F(TestCommUtil, ValidateRankId) {
|
||||
ClusterConfig::Init(3, 2, std::make_unique<std::string>("127.0.0.1"), 9999);
|
||||
ClusterConfig::Init(3, 2, "127.0.0.1", 9999);
|
||||
EXPECT_TRUE(CommUtil::ValidateRankId(NodeRole::WORKER, 2));
|
||||
EXPECT_FALSE(CommUtil::ValidateRankId(NodeRole::WORKER, 3));
|
||||
EXPECT_TRUE(CommUtil::ValidateRankId(NodeRole::SERVER, 1));
|
||||
|
|
|
@ -35,7 +35,7 @@ class TestTcpMessageHandler : public UT::Common {
|
|||
|
||||
TEST_F(TestTcpMessageHandler, 8_Header_1003_Data) {
|
||||
TcpMessageHandler handler;
|
||||
handler.SetCallback([this](const CommMessage &message) { EXPECT_EQ(message.data().size(), 1000); });
|
||||
handler.SetCallback([this](std::shared_ptr<CommMessage> message) { EXPECT_EQ(message->data().size(), 1000); });
|
||||
|
||||
std::string data(1000, 'a');
|
||||
CommMessage message;
|
||||
|
@ -55,7 +55,7 @@ TEST_F(TestTcpMessageHandler, 8_Header_1003_Data) {
|
|||
|
||||
TEST_F(TestTcpMessageHandler, 8_Header_1003_Data_8_Header_1003_Data) {
|
||||
TcpMessageHandler handler;
|
||||
handler.SetCallback([this](const CommMessage &message) { EXPECT_EQ(message.data().size(), 1000); });
|
||||
handler.SetCallback([this](std::shared_ptr<CommMessage> message) { EXPECT_EQ(message->data().size(), 1000); });
|
||||
|
||||
std::string data(1000, 'a');
|
||||
CommMessage message;
|
||||
|
@ -86,7 +86,7 @@ TEST_F(TestTcpMessageHandler, 8_Header_1003_Data_8_Header_1003_Data) {
|
|||
|
||||
TEST_F(TestTcpMessageHandler, 8_Header_4084_Data_4_Header_4_header_4084_data) {
|
||||
TcpMessageHandler handler;
|
||||
handler.SetCallback([this](const CommMessage &message) { EXPECT_EQ(message.data().size(), 4081); });
|
||||
handler.SetCallback([this](std::shared_ptr<CommMessage> message) { EXPECT_EQ(message->data().size(), 4081); });
|
||||
|
||||
std::string data(4081, 'a');
|
||||
CommMessage message;
|
||||
|
@ -126,7 +126,7 @@ TEST_F(TestTcpMessageHandler, 8_Header_4084_Data_4_Header_4_header_4084_data) {
|
|||
|
||||
TEST_F(TestTcpMessageHandler, 8_Header_4080_Data_8_Header_4080_data) {
|
||||
TcpMessageHandler handler;
|
||||
handler.SetCallback([this](const CommMessage &message) { EXPECT_EQ(message.data().size(), 4077); });
|
||||
handler.SetCallback([this](std::shared_ptr<CommMessage> message) { EXPECT_EQ(message->data().size(), 4077); });
|
||||
|
||||
std::string data(4077, 'a');
|
||||
CommMessage message;
|
||||
|
|
|
@@ -32,12 +32,12 @@ class TestTcpServer : public UT::Common {
   void SetUp() override {
     server_ = std::make_unique<TcpServer>("127.0.0.1", 0);
     std::unique_ptr<std::thread> http_server_thread_(nullptr);
-    http_server_thread_ = std::make_unique<std::thread>([&]() {
-      server_->SetMessageCallback([](const TcpServer &server, const TcpConnection &conn, const CommMessage &message) {
+    http_server_thread_ = std::make_unique<std::thread>([=]() {
+      server_->SetMessageCallback([=](std::shared_ptr<TcpConnection> conn, std::shared_ptr<CommMessage> message) {
         KVMessage kv_message;
-        kv_message.ParseFromString(message.data());
+        kv_message.ParseFromString(message->data());
         EXPECT_EQ(2, kv_message.keys_size());
-        const_cast<TcpServer&>(server).SendMessage(conn, message);
+        server_->SendMessage(conn, message);
       });
       server_->Init();
       server_->Start();
@@ -58,6 +58,7 @@ class TestTcpServer : public UT::Common {
 
 TEST_F(TestTcpServer, ServerSendMessage) {
   client_ = std::make_unique<TcpClient>("127.0.0.1", server_->BoundPort());
+  std::cout << server_->BoundPort() << std::endl;
   std::unique_ptr<std::thread> http_client_thread(nullptr);
   http_client_thread = std::make_unique<std::thread>([&]() {
     client_->SetMessageCallback([](const TcpClient &client, const CommMessage &message) {