!10946 Switch bare pointer to shared_ptr

From: @anancds
Reviewed-by: 
Signed-off-by:
This commit is contained in:
mindspore-ci-bot 2021-01-11 14:12:46 +08:00 committed by Gitee
commit 6eb4634ba2
26 changed files with 261 additions and 176 deletions

View File

@ -75,6 +75,8 @@ bool AbstractNode::Broadcast(const enum NodeRole &node_role, const std::string &
auto client = GetOrCreateTcpClient((*it).first.second); auto client = GetOrCreateTcpClient((*it).first.second);
client->SendMessage(comm_message); client->SendMessage(comm_message);
} }
MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_)
<< ", the node id is:" << node_info_.node_id_ << " send the request id is:" << request_id;
return Wait(request_id, timeout); return Wait(request_id, timeout);
} }
@ -126,11 +128,13 @@ bool AbstractNode::Send(const NodeRole &node_role, const std::vector<uint32_t> &
auto client = GetOrCreateTcpClient(rank_ids.at(it)); auto client = GetOrCreateTcpClient(rank_ids.at(it));
client->SendMessage(comm_message); client->SendMessage(comm_message);
} }
MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_)
<< ", the node id is:" << node_info_.node_id_ << " send the request id is:" << request_id;
return Wait(request_id, timeout); return Wait(request_id, timeout);
} }
bool AbstractNode::Send(const enum NodeRole &node_role, const uint32_t &rank_id, const std::string &message, bool AbstractNode::Send(const enum NodeRole &node_role, const uint32_t &rank_id, const std::string &message,
CommMessage *output, const uint32_t &timeout) { std::string *output, const uint32_t &timeout) {
MS_EXCEPTION_IF_NULL(output); MS_EXCEPTION_IF_NULL(output);
if (!CommUtil::ValidateRankId(node_role, rank_id)) { if (!CommUtil::ValidateRankId(node_role, rank_id)) {
MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!"; MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!";
@ -141,7 +145,7 @@ bool AbstractNode::Send(const enum NodeRole &node_role, const uint32_t &rank_id,
set_message_callback(request_id, [&]() { set_message_callback(request_id, [&]() {
receive_messages_mutex_.lock(); receive_messages_mutex_.lock();
auto res = receive_messages_[request_id]; auto res = receive_messages_[request_id];
*output = res[rank_id]; *output = res[rank_id].data();
receive_messages_.erase(request_id); receive_messages_.erase(request_id);
receive_messages_mutex_.unlock(); receive_messages_mutex_.unlock();
}); });
@ -157,11 +161,13 @@ bool AbstractNode::Send(const enum NodeRole &node_role, const uint32_t &rank_id,
comm_message.set_data(message); comm_message.set_data(message);
auto client = GetOrCreateTcpClient(rank_id); auto client = GetOrCreateTcpClient(rank_id);
client->SendMessage(comm_message); client->SendMessage(comm_message);
MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_)
<< ", the node id is:" << node_info_.node_id_ << " send the request id is:" << request_id;
return Wait(request_id, timeout); return Wait(request_id, timeout);
} }
bool AbstractNode::Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids, bool AbstractNode::Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids,
const std::vector<std::string> &data, std::vector<CommMessage> *output, const std::vector<std::string> &data, std::vector<std::string> *output,
const uint32_t &timeout) { const uint32_t &timeout) {
MS_EXCEPTION_IF_NULL(output); MS_EXCEPTION_IF_NULL(output);
uint64_t request_id = ++next_request_id_; uint64_t request_id = ++next_request_id_;
@ -177,7 +183,7 @@ bool AbstractNode::Send(const NodeRole &node_role, const std::vector<uint32_t> &
receive_messages_mutex_.lock(); receive_messages_mutex_.lock();
auto res = receive_messages_[request_id]; auto res = receive_messages_[request_id];
for (size_t it = 0; it < len; ++it) { for (size_t it = 0; it < len; ++it) {
(*output).push_back(res[rank_ids.at(it)]); (*output).push_back(res[rank_ids.at(it)].data());
} }
receive_messages_.erase(request_id); receive_messages_.erase(request_id);
receive_messages_mutex_.unlock(); receive_messages_mutex_.unlock();
@ -201,6 +207,8 @@ bool AbstractNode::Send(const NodeRole &node_role, const std::vector<uint32_t> &
auto client = GetOrCreateTcpClient(rank_ids.at(it)); auto client = GetOrCreateTcpClient(rank_ids.at(it));
client->SendMessage(comm_message); client->SendMessage(comm_message);
} }
MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_)
<< ", the node id is:" << node_info_.node_id_ << " send the request id is:" << request_id;
return Wait(request_id, timeout); return Wait(request_id, timeout);
} }
@ -215,7 +223,7 @@ bool AbstractNode::Wait(uint64_t request_id, const uint32_t &timeout) {
} }
uint64_t AbstractNode::CollectiveSendAsync(const enum NodeRole &node_role, const uint32_t &rank_id, uint64_t AbstractNode::CollectiveSendAsync(const enum NodeRole &node_role, const uint32_t &rank_id,
const std::string &message, const uint32_t &timeout) { const std::string &message) {
if (!CommUtil::ValidateRankId(node_role, rank_id)) { if (!CommUtil::ValidateRankId(node_role, rank_id)) {
MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!"; MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!";
} }
@ -233,19 +241,19 @@ uint64_t AbstractNode::CollectiveSendAsync(const enum NodeRole &node_role, const
} }
std::pair<uint32_t, uint64_t> AbstractNode::CollectiveReceiveAsync(const enum NodeRole &node_role, std::pair<uint32_t, uint64_t> AbstractNode::CollectiveReceiveAsync(const enum NodeRole &node_role,
const uint32_t &rank_id, CommMessage *output) { const uint32_t &rank_id, std::string *output) {
if (!CommUtil::ValidateRankId(node_role, rank_id)) { if (!CommUtil::ValidateRankId(node_role, rank_id)) {
MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!"; MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!";
} }
uint64_t rank_request_id = NextExpectedRankRequestId(rank_id); uint64_t rank_request_id = NextExpectedRankRequestId(rank_id);
if (received_data_.count(std::make_pair(rank_id, rank_request_id)) > 0) { if (received_data_.count(std::make_pair(rank_id, rank_request_id)) > 0) {
*output = received_data_[std::make_pair(rank_id, rank_request_id)]; *output = received_data_[std::make_pair(rank_id, rank_request_id)].data();
received_data_.erase(std::make_pair(rank_id, rank_request_id)); received_data_.erase(std::make_pair(rank_id, rank_request_id));
} else { } else {
set_receive_callback(rank_id, rank_request_id, [=]() { set_receive_callback(rank_id, rank_request_id, [=]() {
receive_callbacks_mutex_.lock(); receive_callbacks_mutex_.lock();
*output = received_data_[std::make_pair(rank_id, 1)]; *output = received_data_[std::make_pair(rank_id, 1)].data();
received_data_.erase(std::make_pair(rank_id, rank_request_id)); received_data_.erase(std::make_pair(rank_id, rank_request_id));
receive_callbacks_mutex_.unlock(); receive_callbacks_mutex_.unlock();
}); });
@ -272,13 +280,25 @@ void AbstractNode::StartHeartbeatTimer(const std::shared_ptr<TcpClient> &client)
<< " begin send heartbeat to the scheduler!"; << " begin send heartbeat to the scheduler!";
heart_beat_thread_ = std::make_unique<std::thread>([&]() { heart_beat_thread_ = std::make_unique<std::thread>([&]() {
while (!is_finish_.load()) { while (!is_finish_.load()) {
Heartbeat(client); if (!Heartbeat(client)) {
MS_LOG(ERROR) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_)
<< ", the node id is:" << node_info_.node_id_ << " Send heartbeat timeout!";
if (!CheckSchedulerTimeout() && on_node_event_message_) {
MS_LOG(ERROR) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_)
<< ", the node id is:" << node_info_.node_id_ << " exited due to scheduler timeout!";
is_finish_ = true;
wait_finish_cond_.notify_all();
on_node_event_message_(NodeEvent::SCHEDULER_TIMEOUT);
}
} else {
UpdateSchedulerTime();
}
std::this_thread::sleep_for(std::chrono::seconds(ClusterConfig::heartbeat_interval())); std::this_thread::sleep_for(std::chrono::seconds(ClusterConfig::heartbeat_interval()));
} }
}); });
} }
void AbstractNode::Heartbeat(const std::shared_ptr<TcpClient> &client, bool is_node_finish) { bool AbstractNode::Heartbeat(const std::shared_ptr<TcpClient> &client, bool is_node_finish) {
MessageMeta meta; MessageMeta meta;
meta.set_cmd(NodeCommand::HEARTBEAT); meta.set_cmd(NodeCommand::HEARTBEAT);
@ -292,11 +312,31 @@ void AbstractNode::Heartbeat(const std::shared_ptr<TcpClient> &client, bool is_n
if (!SendMessageSync(client, message)) { if (!SendMessageSync(client, message)) {
MS_LOG(ERROR) << "The node id:" << node_info_.node_id_ << " Send heartbeat timeout!"; MS_LOG(ERROR) << "The node id:" << node_info_.node_id_ << " Send heartbeat timeout!";
} }
return true;
}
void AbstractNode::UpdateSchedulerTime() {
struct timeval current_time {};
(void)gettimeofday(&current_time, nullptr);
scheduler_time_ = current_time;
MS_LOG(DEBUG) << "The node role: " << CommUtil::NodeRoleToString(node_info_.node_role_)
<< ", the node id:" << node_info_.node_id_ << ", the node rank id:" << node_info_.rank_id_
<< " update scheduler time, the current time is: " << current_time.tv_sec;
}
bool AbstractNode::CheckSchedulerTimeout() const {
struct timeval current_time {};
(void)gettimeofday(&current_time, nullptr);
if (scheduler_time_.tv_sec + ClusterConfig::scheduler_timeout() < current_time.tv_sec) {
return true;
}
return false;
} }
void AbstractNode::ProcessHeartbeatResp(const CommMessage &message) { void AbstractNode::ProcessHeartbeatResp(const CommMessage &message) {
HeartbeatRespMessage heartbeat_resp_message; HeartbeatRespMessage heartbeat_resp_message;
heartbeat_resp_message.ParseFromString(message.data()); heartbeat_resp_message.ParseFromString(message.data());
is_ready_ = heartbeat_resp_message.is_cluster_ready(); is_ready_ = heartbeat_resp_message.is_cluster_ready();
if (is_ready_.load()) { if (is_ready_.load()) {
wait_start_cond_.notify_all(); wait_start_cond_.notify_all();
@ -353,9 +393,9 @@ bool AbstractNode::Disconnect(const std::shared_ptr<TcpClient> &client, const ui
*message.mutable_pb_meta() = {meta}; *message.mutable_pb_meta() = {meta};
message.set_data(finish_message.SerializeAsString()); message.set_data(finish_message.SerializeAsString());
if (!SendMessageSync(client, message)) { if (!SendMessageSync(client, message)) {
MS_LOG(EXCEPTION) << "Disconnect timeout!"; MS_LOG(ERROR) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
<< " the node id:" << node_info_.node_id_ << " send Finish Message timeout!";
} }
MS_LOG(INFO) << "The node id:" << node_info_.node_id_ << " send finish message!";
return WaitForDisconnect(timeout); return WaitForDisconnect(timeout);
} }
@ -444,6 +484,8 @@ bool AbstractNode::SendMessageSync(const std::shared_ptr<TcpClient> &client, con
message_tracker_[request_id] = std::make_pair(1, 0); message_tracker_[request_id] = std::make_pair(1, 0);
const_cast<CommMessage &>(message).mutable_pb_meta()->set_request_id(request_id); const_cast<CommMessage &>(message).mutable_pb_meta()->set_request_id(request_id);
client->SendMessage(message); client->SendMessage(message);
MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_)
<< ", the node id is:" << node_info_.node_id_ << " send the request id is:" << request_id;
return Wait(request_id, timeout); return Wait(request_id, timeout);
} }
@ -452,6 +494,8 @@ uint64_t AbstractNode::SendMessageAsync(const std::shared_ptr<TcpClient> &client
message_tracker_[request_id] = std::make_pair(1, 0); message_tracker_[request_id] = std::make_pair(1, 0);
const_cast<CommMessage &>(message).mutable_pb_meta()->set_request_id(request_id); const_cast<CommMessage &>(message).mutable_pb_meta()->set_request_id(request_id);
client->SendMessage(message); client->SendMessage(message);
MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_)
<< ", the node id is:" << node_info_.node_id_ << " send the request id is:" << request_id;
return request_id; return request_id;
} }
@ -460,6 +504,8 @@ void AbstractNode::ProcessSendDataResp(const CommMessage &message) {
const MessageMeta &message_meta = message.pb_meta(); const MessageMeta &message_meta = message.pb_meta();
const uint32_t &rank_id = message_meta.rank_id(); const uint32_t &rank_id = message_meta.rank_id();
const uint64_t request_id = message_meta.request_id(); const uint64_t request_id = message_meta.request_id();
MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_)
<< ", the node id is:" << node_info_.node_id_ << " send the request id is:" << request_id;
auto it = receive_messages_.find(request_id); auto it = receive_messages_.find(request_id);
if (it != receive_messages_.end()) { if (it != receive_messages_.end()) {
it->second[rank_id] = message; it->second[rank_id] = message;

View File

@ -42,23 +42,24 @@ class AbstractNode : public Node {
const uint32_t &timeout = kCommTimeoutInSeconds); const uint32_t &timeout = kCommTimeoutInSeconds);
bool Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids, const std::vector<std::string> &data, bool Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids, const std::vector<std::string> &data,
const uint32_t &timeout = kCommTimeoutInSeconds); const uint32_t &timeout = kCommTimeoutInSeconds);
bool Send(const enum NodeRole &node_role, const uint32_t &rank_id, const std::string &message, CommMessage *output, bool Send(const enum NodeRole &node_role, const uint32_t &rank_id, const std::string &message, std::string *output,
const uint32_t &timeout = kCommTimeoutInSeconds); const uint32_t &timeout = kCommTimeoutInSeconds);
bool Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids, const std::vector<std::string> &data, bool Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids, const std::vector<std::string> &data,
std::vector<CommMessage> *output, const uint32_t &timeout = kCommTimeoutInSeconds); std::vector<std::string> *output, const uint32_t &timeout = kCommTimeoutInSeconds);
bool Wait(uint64_t request_id, const uint32_t &timeout = kCommTimeoutInSeconds); bool Wait(uint64_t request_id, const uint32_t &timeout = kCommTimeoutInSeconds);
uint64_t CollectiveSendAsync(const enum NodeRole &node_role, const uint32_t &rank_id, const std::string &message, uint64_t CollectiveSendAsync(const enum NodeRole &node_role, const uint32_t &rank_id, const std::string &message);
const uint32_t &timeout = kCommTimeoutInSeconds);
std::pair<uint32_t, uint64_t> CollectiveReceiveAsync(const enum NodeRole &node_role, const uint32_t &rank_id, std::pair<uint32_t, uint64_t> CollectiveReceiveAsync(const enum NodeRole &node_role, const uint32_t &rank_id,
CommMessage *output); std::string *output);
bool CollectiveWait(std::pair<uint32_t, uint64_t> request_id, const uint32_t &timeout = kCommTimeoutInSeconds); bool CollectiveWait(std::pair<uint32_t, uint64_t> request_id, const uint32_t &timeout = kCommTimeoutInSeconds);
protected: protected:
void Register(const std::shared_ptr<TcpClient> &client); void Register(const std::shared_ptr<TcpClient> &client);
void ProcessRegisterResp(const CommMessage &message); void ProcessRegisterResp(const CommMessage &message);
void StartHeartbeatTimer(const std::shared_ptr<TcpClient> &client); void StartHeartbeatTimer(const std::shared_ptr<TcpClient> &client);
void Heartbeat(const std::shared_ptr<TcpClient> &client, bool is_node_finish = false); bool Heartbeat(const std::shared_ptr<TcpClient> &client, bool is_node_finish = false);
void UpdateSchedulerTime();
bool CheckSchedulerTimeout() const;
void ProcessHeartbeatResp(const CommMessage &message); void ProcessHeartbeatResp(const CommMessage &message);
void FetchServers(const std::shared_ptr<TcpClient> &client); void FetchServers(const std::shared_ptr<TcpClient> &client);
void ProcessFetchServersResp(const CommMessage &message); void ProcessFetchServersResp(const CommMessage &message);
@ -113,6 +114,7 @@ class AbstractNode : public Node {
// the key is rank_id, the value is rank_id's actual request_id // the key is rank_id, the value is rank_id's actual request_id
std::unordered_map<uint32_t, uint64_t> actual_rank_request_ids_; std::unordered_map<uint32_t, uint64_t> actual_rank_request_ids_;
std::mutex rank_request_ids_mutex; std::mutex rank_request_ids_mutex;
timeval scheduler_time_;
}; };
} // namespace core } // namespace core
} // namespace ps } // namespace ps

View File

@ -33,15 +33,17 @@ uint32_t ClusterConfig::heartbeat_timeout_ = 30;
uint32_t ClusterConfig::cluster_available_timeout_ = 300; uint32_t ClusterConfig::cluster_available_timeout_ = 300;
// The timeout period for the client to connect to the server is 100ms. // The timeout period for the client to connect to the server is 100ms.
uint32_t ClusterConfig::connect_interval_ = 100; uint32_t ClusterConfig::connect_interval_ = 100;
// When the scheduler exits, the worker and server can continue to work for 5 hours
uint32_t ClusterConfig::scheduler_timeout_ = 3600 * 5;
void ClusterConfig::Init(const uint32_t &worker_num, const uint32_t &server_num, void ClusterConfig::Init(const uint32_t &worker_num, const uint32_t &server_num, std::string scheduler_host,
std::unique_ptr<std::string> scheduler_host, const uint16_t &scheduler_port) { const uint16_t &scheduler_port) {
worker_num_ = worker_num; worker_num_ = worker_num;
server_num_ = server_num; server_num_ = server_num;
if (!CommUtil::CheckIp(*scheduler_host.get())) { if (!CommUtil::CheckIp(scheduler_host)) {
MS_LOG(EXCEPTION) << "The scheduler_host:" << *scheduler_host.get() << " is illegal!"; MS_LOG(EXCEPTION) << "The scheduler_host:" << scheduler_host << " is illegal!";
} }
scheduler_host_ = std::move(scheduler_host); scheduler_host_ = std::make_unique<std::string>(scheduler_host);
scheduler_port_ = scheduler_port; scheduler_port_ = scheduler_port;
} }
@ -55,7 +57,7 @@ void ClusterConfig::set_heartbeat_interval(const uint32_t &heartbeat_interval) {
heartbeat_interval_ = heartbeat_interval; heartbeat_interval_ = heartbeat_interval;
} }
std::string ClusterConfig::scheduler_host() { return *scheduler_host_.get(); } std::string ClusterConfig::scheduler_host() { return *scheduler_host_; }
uint16_t ClusterConfig::scheduler_port() { return scheduler_port_; } uint16_t ClusterConfig::scheduler_port() { return scheduler_port_; }
@ -74,6 +76,10 @@ void ClusterConfig::set_cluster_available_timeout(const uint32_t &cluster_availa
uint32_t ClusterConfig::connect_interval() { return connect_interval_; } uint32_t ClusterConfig::connect_interval() { return connect_interval_; }
void ClusterConfig::set_connect_interval(const uint32_t &connect_interval) { connect_interval_ = connect_interval; } void ClusterConfig::set_connect_interval(const uint32_t &connect_interval) { connect_interval_ = connect_interval; }
uint32_t ClusterConfig::scheduler_timeout() { return scheduler_timeout_; }
void ClusterConfig::set_scheduler_timeout(const uint32_t &scheduler_timeout) { scheduler_timeout_ = scheduler_timeout; }
} // namespace core } // namespace core
} // namespace ps } // namespace ps
} // namespace mindspore } // namespace mindspore

View File

@ -30,7 +30,7 @@ namespace ps {
namespace core { namespace core {
class ClusterConfig { class ClusterConfig {
public: public:
static void Init(const uint32_t &worker_num, const uint32_t &server_num, std::unique_ptr<std::string> scheduler_host, static void Init(const uint32_t &worker_num, const uint32_t &server_num, std::string scheduler_host,
const uint16_t &scheduler_port); const uint16_t &scheduler_port);
static uint32_t worker_num(); static uint32_t worker_num();
static uint32_t server_num(); static uint32_t server_num();
@ -44,6 +44,8 @@ class ClusterConfig {
static void set_cluster_available_timeout(const uint32_t &cluster_available_timeout); static void set_cluster_available_timeout(const uint32_t &cluster_available_timeout);
static uint32_t connect_interval(); static uint32_t connect_interval();
static void set_connect_interval(const uint32_t &connect_interval); static void set_connect_interval(const uint32_t &connect_interval);
static uint32_t scheduler_timeout();
static void set_scheduler_timeout(const uint32_t &scheduler_timeout);
private: private:
static uint32_t worker_num_; static uint32_t worker_num_;
@ -54,6 +56,7 @@ class ClusterConfig {
static uint32_t heartbeat_timeout_; static uint32_t heartbeat_timeout_;
static uint32_t cluster_available_timeout_; static uint32_t cluster_available_timeout_;
static uint32_t connect_interval_; static uint32_t connect_interval_;
static uint32_t scheduler_timeout_;
}; };
} // namespace core } // namespace core
} // namespace ps } // namespace ps

View File

@ -21,7 +21,12 @@ namespace ps {
namespace core { namespace core {
std::string Node::node_id() const { return node_info_.node_id_; } std::string Node::node_id() const { return node_info_.node_id_; }
uint32_t Node::rank_id() const { return node_info_.rank_id_; } uint32_t Node::rank_id() const {
if (!is_ready_.load()) {
MS_LOG(EXCEPTION) << "The cluster is not ready yet to get rank id!";
}
return node_info_.rank_id_;
}
NodeRole Node::role() const { return node_info_.node_role_; } NodeRole Node::role() const { return node_info_.node_role_; }

View File

@ -30,8 +30,6 @@
#include <utility> #include <utility>
#include <tuple> #include <tuple>
#include "proto/comm.pb.h"
#include "proto/ps.pb.h"
#include "ps/core/cluster_config.h" #include "ps/core/cluster_config.h"
#include "ps/core/node_info.h" #include "ps/core/node_info.h"
#include "ps/core/tcp_client.h" #include "ps/core/tcp_client.h"

View File

@ -25,7 +25,7 @@
namespace mindspore { namespace mindspore {
namespace ps { namespace ps {
namespace core { namespace core {
enum NodeEvent { CLUSTER_TIMEOUT = 0, NODE_TIMEOUT = 1 }; enum NodeEvent { CLUSTER_TIMEOUT = 0, NODE_TIMEOUT = 1, SCHEDULER_TIMEOUT };
struct NodeInfo { struct NodeInfo {
NodeInfo() : port_(0), node_role_(NodeRole::SCHEDULER), rank_id_(0) {} NodeInfo() : port_(0), node_role_(NodeRole::SCHEDULER), rank_id_(0) {}

View File

@ -64,8 +64,8 @@ void NodeManager::UpdateHeartbeat(const std::string &node_id) {
struct timeval current_time {}; struct timeval current_time {};
(void)gettimeofday(&current_time, nullptr); (void)gettimeofday(&current_time, nullptr);
heartbeats_[node_id] = current_time; heartbeats_[node_id] = current_time;
MS_LOG(INFO) << "The node role: " << CommUtil::NodeRoleToString(node_info.node_role_) << ", the node id:" << node_id MS_LOG(DEBUG) << "The node role: " << CommUtil::NodeRoleToString(node_info.node_role_) << ", the node id:" << node_id
<< ", the node rank id:" << node_info.rank_id_ << " the current time is: " << current_time.tv_sec; << ", the node rank id:" << node_info.rank_id_ << " the current time is: " << current_time.tv_sec;
} }
void NodeManager::UpdateNodeFinishState(const std::string &node_id) { heartbeats_finish_nodes_.insert(node_id); } void NodeManager::UpdateNodeFinishState(const std::string &node_id) { heartbeats_finish_nodes_.insert(node_id); }

View File

@ -31,8 +31,6 @@
#include <condition_variable> #include <condition_variable>
#include <unordered_set> #include <unordered_set>
#include "proto/comm.pb.h"
#include "proto/ps.pb.h"
#include "ps/core/node.h" #include "ps/core/node.h"
#include "utils/log_adapter.h" #include "utils/log_adapter.h"
#include "utils/convert_utils_base.h" #include "utils/convert_utils_base.h"

View File

@ -20,6 +20,7 @@ option optimize_for = LITE_RUNTIME;
enum PSCommand { enum PSCommand {
PUSH = 0; PUSH = 0;
PULL = 1; PULL = 1;
INIT_EMBEDDING_TABLE = 2;
} }
message KVMessage { message KVMessage {

View File

@ -37,9 +37,10 @@ bool SchedulerNode::Start(const uint32_t &timeout) {
return true; return true;
} }
void SchedulerNode::ProcessHeartbeat(const TcpServer &server, const TcpConnection &conn, const CommMessage &message) { void SchedulerNode::ProcessHeartbeat(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn,
std::shared_ptr<CommMessage> message) {
HeartbeatMessage heartbeat_message; HeartbeatMessage heartbeat_message;
heartbeat_message.ParseFromString(message.data()); heartbeat_message.ParseFromString(message->data());
node_manager_.UpdateHeartbeat(heartbeat_message.node_id()); node_manager_.UpdateHeartbeat(heartbeat_message.node_id());
@ -59,10 +60,10 @@ void SchedulerNode::ProcessHeartbeat(const TcpServer &server, const TcpConnectio
heartbeat_resp_message.set_is_cluster_timeout(node_manager_.is_cluster_timeout()); heartbeat_resp_message.set_is_cluster_timeout(node_manager_.is_cluster_timeout());
heartbeat_resp_message.set_is_node_timeout(node_manager_.is_node_timeout()); heartbeat_resp_message.set_is_node_timeout(node_manager_.is_node_timeout());
CommMessage comm_message; std::shared_ptr<CommMessage> comm_message = std::make_shared<CommMessage>();
*comm_message.mutable_pb_meta() = {message.pb_meta()}; *comm_message->mutable_pb_meta() = {message->pb_meta()};
comm_message.set_data(heartbeat_resp_message.SerializeAsString()); comm_message->set_data(heartbeat_resp_message.SerializeAsString());
const_cast<TcpServer &>(server).SendMessage(conn, comm_message); server->SendMessage(conn, comm_message);
} }
void SchedulerNode::Initialize() { void SchedulerNode::Initialize() {
@ -79,23 +80,23 @@ void SchedulerNode::CreateTcpServer() {
std::string scheduler_host = ClusterConfig::scheduler_host(); std::string scheduler_host = ClusterConfig::scheduler_host();
uint32_t scheduler_port = ClusterConfig::scheduler_port(); uint32_t scheduler_port = ClusterConfig::scheduler_port();
server_ = std::make_unique<TcpServer>(scheduler_host, scheduler_port); server_ = std::make_shared<TcpServer>(scheduler_host, scheduler_port);
server_->SetMessageCallback([&](const TcpServer &server, const TcpConnection &conn, const CommMessage &message) { server_->SetMessageCallback([&](std::shared_ptr<TcpConnection> conn, std::shared_ptr<CommMessage> message) {
switch (message.pb_meta().cmd()) { switch (message->pb_meta().cmd()) {
case NodeCommand::HEARTBEAT: case NodeCommand::HEARTBEAT:
ProcessHeartbeat(server, conn, message); ProcessHeartbeat(server_, conn, message);
break; break;
case NodeCommand::REGISTER: case NodeCommand::REGISTER:
ProcessRegister(server, conn, message); ProcessRegister(server_, conn, message);
break; break;
case NodeCommand::FINISH: case NodeCommand::FINISH:
ProcessFinish(server, conn, message); ProcessFinish(server_, conn, message);
break; break;
case NodeCommand::FETCH_SERVER: case NodeCommand::FETCH_SERVER:
ProcessFetchServers(server, conn, message); ProcessFetchServers(server_, conn, message);
break; break;
default: default:
MS_LOG(EXCEPTION) << "The cmd:" << message.pb_meta().cmd() << " is not supported!"; MS_LOG(EXCEPTION) << "The cmd:" << message->pb_meta().cmd() << " is not supported!";
} }
}); });
@ -107,10 +108,11 @@ void SchedulerNode::CreateTcpServer() {
}); });
} }
void SchedulerNode::ProcessRegister(const TcpServer &server, const TcpConnection &conn, const CommMessage &message) { void SchedulerNode::ProcessRegister(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn,
std::shared_ptr<CommMessage> message) {
MS_LOG(INFO) << "The scheduler process a register message!"; MS_LOG(INFO) << "The scheduler process a register message!";
RegisterMessage register_message; RegisterMessage register_message;
register_message.ParseFromString(message.data()); register_message.ParseFromString(message->data());
// assign worker node and server node rank id // assign worker node and server node rank id
int rank_id = node_manager_.NextRankId(register_message); int rank_id = node_manager_.NextRankId(register_message);
@ -124,31 +126,32 @@ void SchedulerNode::ProcessRegister(const TcpServer &server, const TcpConnection
register_resp_message.set_node_id(node_id); register_resp_message.set_node_id(node_id);
register_resp_message.set_rank_id(rank_id); register_resp_message.set_rank_id(rank_id);
CommMessage comm_message; std::shared_ptr<CommMessage> comm_message = std::make_shared<CommMessage>();
*comm_message.mutable_pb_meta() = {message.pb_meta()}; *comm_message->mutable_pb_meta() = {message->pb_meta()};
comm_message.set_data(register_resp_message.SerializeAsString()); comm_message->set_data(register_resp_message.SerializeAsString());
const_cast<TcpServer &>(server).SendMessage(conn, comm_message); server->SendMessage(conn, comm_message);
} }
void SchedulerNode::ProcessFinish(const TcpServer &server, const TcpConnection &conn, const CommMessage &message) { void SchedulerNode::ProcessFinish(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn,
std::shared_ptr<CommMessage> message) {
FinishMessage finish_message; FinishMessage finish_message;
finish_message.ParseFromString(message.data()); finish_message.ParseFromString(message->data());
node_manager_.AddFinishNode(finish_message); node_manager_.AddFinishNode(finish_message);
MS_LOG(INFO) << "Process finish message from node id:" << finish_message.node_id(); MS_LOG(INFO) << "Process finish message from node id:" << finish_message.node_id();
const_cast<TcpServer &>(server).SendMessage(conn, message); server->SendMessage(conn, message);
} }
void SchedulerNode::ProcessFetchServers(const TcpServer &server, const TcpConnection &conn, void SchedulerNode::ProcessFetchServers(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn,
const CommMessage &message) { std::shared_ptr<CommMessage> message) {
FetchServersRespMessage fetch_servers_message; FetchServersRespMessage fetch_servers_message;
std::vector<ServersMeta> servers_meta_list = node_manager_.FetchServersMeta(); std::vector<ServersMeta> servers_meta_list = node_manager_.FetchServersMeta();
*fetch_servers_message.mutable_servers_meta() = {servers_meta_list.begin(), servers_meta_list.end()}; *fetch_servers_message.mutable_servers_meta() = {servers_meta_list.begin(), servers_meta_list.end()};
CommMessage comm_message; std::shared_ptr<CommMessage> comm_message = std::make_shared<CommMessage>();
*comm_message.mutable_pb_meta() = {message.pb_meta()}; *comm_message->mutable_pb_meta() = {message->pb_meta()};
comm_message.set_data(fetch_servers_message.SerializeAsString()); comm_message->set_data(fetch_servers_message.SerializeAsString());
const_cast<TcpServer &>(server).SendMessage(conn, comm_message); server->SendMessage(conn, comm_message);
} }
void SchedulerNode::StartUpdateClusterStateTimer() { void SchedulerNode::StartUpdateClusterStateTimer() {

View File

@ -26,8 +26,6 @@
#include <thread> #include <thread>
#include <mutex> #include <mutex>
#include "proto/comm.pb.h"
#include "proto/ps.pb.h"
#include "ps/core/cluster_config.h" #include "ps/core/cluster_config.h"
#include "ps/core/tcp_client.h" #include "ps/core/tcp_client.h"
#include "ps/core/tcp_server.h" #include "ps/core/tcp_server.h"
@ -51,13 +49,17 @@ class SchedulerNode : public Node {
private: private:
void Initialize(); void Initialize();
void CreateTcpServer(); void CreateTcpServer();
void ProcessHeartbeat(const TcpServer &server, const TcpConnection &conn, const CommMessage &message); void ProcessHeartbeat(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn,
void ProcessRegister(const TcpServer &server, const TcpConnection &conn, const CommMessage &message); std::shared_ptr<CommMessage> message);
void ProcessRegister(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn,
std::shared_ptr<CommMessage> message);
void StartUpdateClusterStateTimer(); void StartUpdateClusterStateTimer();
void ProcessFinish(const TcpServer &server, const TcpConnection &conn, const CommMessage &message); void ProcessFinish(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn,
void ProcessFetchServers(const TcpServer &server, const TcpConnection &conn, const CommMessage &message); std::shared_ptr<CommMessage> message);
void ProcessFetchServers(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn,
std::shared_ptr<CommMessage> message);
std::unique_ptr<TcpServer> server_; std::shared_ptr<TcpServer> server_;
std::unique_ptr<std::thread> scheduler_thread_; std::unique_ptr<std::thread> scheduler_thread_;
std::unique_ptr<std::thread> update_state_thread_; std::unique_ptr<std::thread> update_state_thread_;

View File

@ -30,7 +30,8 @@ bool ServerNode::Start(const uint32_t &timeout) {
StartHeartbeatTimer(client_to_scheduler_); StartHeartbeatTimer(client_to_scheduler_);
if (!WaitForStart(timeout)) { if (!WaitForStart(timeout)) {
MS_LOG(ERROR) << "Start Server node timeout!"; MS_LOG(ERROR) << "Start server node timeout!";
return false;
} }
MS_LOG(INFO) << "The cluster is ready to use!"; MS_LOG(INFO) << "The cluster is ready to use!";
@ -45,16 +46,16 @@ bool ServerNode::Start(const uint32_t &timeout) {
void ServerNode::set_handler(const RequestHandler &handler) { request_handler_ = handler; } void ServerNode::set_handler(const RequestHandler &handler) { request_handler_ = handler; }
void ServerNode::Response(const TcpServer &server, const TcpConnection &conn, const MessageMeta &message_meta, void ServerNode::Response(std::shared_ptr<TcpConnection> conn, std::shared_ptr<CommMessage> message) {
const std::string &message) { MS_EXCEPTION_IF_NULL(conn);
auto &meta = const_cast<MessageMeta &>(message_meta); MS_EXCEPTION_IF_NULL(message);
meta.set_role(node_info_.node_role_); message->mutable_pb_meta()->set_role(node_info_.node_role_);
meta.set_rank_id(node_info_.rank_id_); message->mutable_pb_meta()->set_rank_id(node_info_.rank_id_);
CommMessage comm_message; const MessageMeta &message_meta = message->pb_meta();
*comm_message.mutable_pb_meta() = {meta}; const uint64_t request_id = message_meta.request_id();
comm_message.set_data(message); MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_)
<< ", the node id is:" << node_info_.node_id_ << " send the request id is:" << request_id;
const_cast<TcpServer &>(server).SendMessage(conn, comm_message); server_->SendMessage(conn, message);
} }
void ServerNode::CreateTcpServer() { void ServerNode::CreateTcpServer() {
@ -62,17 +63,17 @@ void ServerNode::CreateTcpServer() {
std::string server_ip; std::string server_ip;
CommUtil::GetAvailableInterfaceAndIP(&interface, &server_ip); CommUtil::GetAvailableInterfaceAndIP(&interface, &server_ip);
server_ = std::make_shared<TcpServer>(server_ip, 0); server_ = std::make_shared<TcpServer>(server_ip, 0);
server_->SetMessageCallback([&](const TcpServer &server, const TcpConnection &conn, const CommMessage &message) { server_->SetMessageCallback([&](std::shared_ptr<TcpConnection> conn, std::shared_ptr<CommMessage> message) {
switch (message.pb_meta().cmd()) { switch (message->pb_meta().cmd()) {
case NodeCommand::SEND_DATA: case NodeCommand::SEND_DATA:
ProcessSendData(server, conn, message); ProcessSendData(conn, message);
break; break;
case NodeCommand::COLLECTIVE_SEND_DATA: case NodeCommand::COLLECTIVE_SEND_DATA:
ProcessCollectiveSendData(server, conn, message); ProcessCollectiveSendData(conn, message);
RunReceiveCallback(message); RunReceiveCallback(*message);
break; break;
default: default:
MS_LOG(EXCEPTION) << "The cmd:" << message.pb_meta().cmd() << " is not supported!"; MS_LOG(EXCEPTION) << "The cmd:" << message->pb_meta().cmd() << " is not supported!";
} }
}); });
server_->Init(); server_->Init();
@ -97,15 +98,18 @@ void ServerNode::Initialize() {
MS_LOG(INFO) << "Server node init client successful!"; MS_LOG(INFO) << "Server node init client successful!";
} }
void ServerNode::ProcessSendData(const TcpServer &server, const TcpConnection &conn, const CommMessage &message) { void ServerNode::ProcessSendData(std::shared_ptr<TcpConnection> conn, std::shared_ptr<CommMessage> message) {
request_handler_(server, conn, message.pb_meta(), message.data()); MS_EXCEPTION_IF_NULL(conn);
MS_EXCEPTION_IF_NULL(message);
request_handler_(conn, message);
} }
void ServerNode::ProcessCollectiveSendData(const TcpServer &server, const TcpConnection &conn, void ServerNode::ProcessCollectiveSendData(std::shared_ptr<TcpConnection> conn, std::shared_ptr<CommMessage> message) {
const CommMessage &message) { MS_EXCEPTION_IF_NULL(conn);
CommMessage comm_message; MS_EXCEPTION_IF_NULL(message);
*comm_message.mutable_pb_meta() = {message.pb_meta()}; std::shared_ptr<CommMessage> comm_message = std::make_shared<CommMessage>();
const_cast<TcpServer &>(server).SendMessage(conn, comm_message); *comm_message->mutable_pb_meta() = {message->pb_meta()};
server_->SendMessage(conn, comm_message);
} }
bool ServerNode::Stop() { bool ServerNode::Stop() {

View File

@ -44,18 +44,16 @@ class ServerNode : public AbstractNode {
bool Stop() override; bool Stop() override;
bool Finish(const uint32_t &timeout = kTimeoutInSeconds) override; bool Finish(const uint32_t &timeout = kTimeoutInSeconds) override;
using RequestHandler = std::function<void(const TcpServer &server, const TcpConnection &conn, const MessageMeta meta, using RequestHandler = std::function<void(std::shared_ptr<TcpConnection> conn, std::shared_ptr<CommMessage> message)>;
const std::string &message)>;
void set_handler(const RequestHandler &handler); void set_handler(const RequestHandler &handler);
void Response(const TcpServer &server, const TcpConnection &conn, const MessageMeta &message_meta, void Response(std::shared_ptr<TcpConnection> conn, std::shared_ptr<CommMessage> message);
const std::string &message);
private: private:
void CreateTcpServer(); void CreateTcpServer();
void Initialize(); void Initialize();
void ProcessSendData(const TcpServer &server, const TcpConnection &conn, const CommMessage &message); void ProcessSendData(std::shared_ptr<TcpConnection> conn, std::shared_ptr<CommMessage> message);
void ProcessCollectiveSendData(const TcpServer &server, const TcpConnection &conn, const CommMessage &message); void ProcessCollectiveSendData(std::shared_ptr<TcpConnection> conn, std::shared_ptr<CommMessage> message);
std::shared_ptr<TcpServer> server_; std::shared_ptr<TcpServer> server_;
std::unique_ptr<std::thread> server_thread_; std::unique_ptr<std::thread> server_thread_;

View File

@ -46,9 +46,9 @@ TcpClient::TcpClient(const std::string &address, std::uint16_t port)
server_port_(port), server_port_(port),
is_stop_(true), is_stop_(true),
is_connected_(false) { is_connected_(false) {
message_handler_.SetCallback([this](const CommMessage &message) { message_handler_.SetCallback([this](std::shared_ptr<CommMessage> message) {
if (message_callback_) { if (message_callback_) {
message_callback_(*this, message); message_callback_(*this, *message);
} }
}); });
} }
@ -105,7 +105,7 @@ void TcpClient::Init() {
sin.sin_addr.s_addr = inet_addr(server_address_.c_str()); sin.sin_addr.s_addr = inet_addr(server_address_.c_str());
sin.sin_port = htons(server_port_); sin.sin_port = htons(server_port_);
buffer_event_ = bufferevent_socket_new(event_base_, -1, BEV_OPT_CLOSE_ON_FREE); buffer_event_ = bufferevent_socket_new(event_base_, -1, BEV_OPT_CLOSE_ON_FREE | BEV_OPT_THREADSAFE);
MS_EXCEPTION_IF_NULL(buffer_event_); MS_EXCEPTION_IF_NULL(buffer_event_);
bufferevent_setcb(buffer_event_, ReadCallback, nullptr, EventCallback, this); bufferevent_setcb(buffer_event_, ReadCallback, nullptr, EventCallback, this);
@ -261,17 +261,23 @@ void TcpClient::StartWithNoBlock() {
void TcpClient::SetMessageCallback(const OnMessage &cb) { message_callback_ = cb; } void TcpClient::SetMessageCallback(const OnMessage &cb) { message_callback_ = cb; }
void TcpClient::SendMessage(const CommMessage &message) const { bool TcpClient::SendMessage(const CommMessage &message) const {
MS_EXCEPTION_IF_NULL(buffer_event_); MS_EXCEPTION_IF_NULL(buffer_event_);
bufferevent_lock(buffer_event_);
bool res = true;
size_t buf_size = message.ByteSizeLong(); size_t buf_size = message.ByteSizeLong();
std::vector<unsigned char> serialized(buf_size); std::vector<unsigned char> serialized(buf_size);
message.SerializeToArray(serialized.data(), SizeToInt(buf_size)); message.SerializeToArray(serialized.data(), SizeToInt(buf_size));
if (evbuffer_add(bufferevent_get_output(buffer_event_), &buf_size, sizeof(buf_size)) == -1) { if (bufferevent_write(buffer_event_, &buf_size, sizeof(buf_size)) == -1) {
MS_LOG(EXCEPTION) << "Event buffer add header failed!"; MS_LOG(ERROR) << "Event buffer add header failed!";
res = false;
} }
if (evbuffer_add(bufferevent_get_output(buffer_event_), serialized.data(), buf_size) == -1) { if (bufferevent_write(buffer_event_, serialized.data(), buf_size) == -1) {
MS_LOG(EXCEPTION) << "Event buffer add protobuf data failed!"; MS_LOG(ERROR) << "Event buffer add protobuf data failed!";
res = false;
} }
bufferevent_unlock(buffer_event_);
return res;
} }
void TcpClient::StartTimer(const uint32_t &time) { void TcpClient::StartTimer(const uint32_t &time) {

View File

@ -33,8 +33,6 @@
#include <condition_variable> #include <condition_variable>
#include "ps/core/cluster_config.h" #include "ps/core/cluster_config.h"
#include "proto/comm.pb.h"
#include "proto/ps.pb.h"
#include "utils/convert_utils_base.h" #include "utils/convert_utils_base.h"
namespace mindspore { namespace mindspore {
@ -62,7 +60,7 @@ class TcpClient {
void Start(); void Start();
void StartWithNoBlock(); void StartWithNoBlock();
void SetMessageCallback(const OnMessage &cb); void SetMessageCallback(const OnMessage &cb);
void SendMessage(const CommMessage &message) const; bool SendMessage(const CommMessage &message) const;
void StartTimer(const uint32_t &time); void StartTimer(const uint32_t &time);
void set_timer_callback(const OnTimer &timer); void set_timer_callback(const OnTimer &timer);
const event_base &eventbase(); const event_base &eventbase();

View File

@ -57,8 +57,8 @@ void TcpMessageHandler::ReceiveMessage(const void *buffer, size_t num) {
} }
if (remaining_length_ == 0) { if (remaining_length_ == 0) {
CommMessage pb_message; std::shared_ptr<CommMessage> pb_message = std::make_shared<CommMessage>();
pb_message.ParseFromArray(message_buffer_.get(), message_length_); pb_message->ParseFromArray(message_buffer_.get(), message_length_);
if (message_callback_) { if (message_callback_) {
message_callback_(pb_message); message_callback_(pb_message);
} }

View File

@ -30,7 +30,7 @@
namespace mindspore { namespace mindspore {
namespace ps { namespace ps {
namespace core { namespace core {
using messageReceive = std::function<void(const CommMessage &message)>; using messageReceive = std::function<void(std::shared_ptr<CommMessage>)>;
constexpr int kHeaderLen = 8; constexpr int kHeaderLen = 8;
class TcpMessageHandler { class TcpMessageHandler {

View File

@ -32,14 +32,7 @@
namespace mindspore { namespace mindspore {
namespace ps { namespace ps {
namespace core { namespace core {
void TcpConnection::InitConnection() { void TcpConnection::InitConnection(const messageReceive &callback) { tcp_message_handler_.SetCallback(callback); }
tcp_message_handler_.SetCallback([&](const CommMessage &message) {
OnServerReceiveMessage on_server_receive = server_->GetServerReceive();
if (on_server_receive) {
on_server_receive(*server_, *this, message);
}
});
}
void TcpConnection::OnReadHandler(const void *buffer, size_t num) { tcp_message_handler_.ReceiveMessage(buffer, num); } void TcpConnection::OnReadHandler(const void *buffer, size_t num) { tcp_message_handler_.ReceiveMessage(buffer, num); }
@ -49,23 +42,30 @@ void TcpConnection::SendMessage(const void *buffer, size_t num) const {
} }
} }
TcpServer *TcpConnection::GetServer() const { return const_cast<TcpServer *>(server_); } TcpServer *TcpConnection::GetServer() const { return server_; }
const evutil_socket_t &TcpConnection::GetFd() const { return fd_; } const evutil_socket_t &TcpConnection::GetFd() const { return fd_; }
void TcpConnection::SendMessage(const CommMessage &message) const { void TcpConnection::set_callback(const Callback &callback) { callback_ = callback; }
bool TcpConnection::SendMessage(std::shared_ptr<CommMessage> message) const {
MS_EXCEPTION_IF_NULL(buffer_event_); MS_EXCEPTION_IF_NULL(buffer_event_);
size_t buf_size = message.ByteSizeLong(); MS_EXCEPTION_IF_NULL(message);
bufferevent_lock(buffer_event_);
bool res = true;
size_t buf_size = message->ByteSizeLong();
std::vector<unsigned char> serialized(buf_size); std::vector<unsigned char> serialized(buf_size);
message.SerializeToArray(serialized.data(), SizeToInt(buf_size)); message->SerializeToArray(serialized.data(), SizeToInt(buf_size));
if (evbuffer_add(bufferevent_get_output(const_cast<struct bufferevent *>(buffer_event_)), &buf_size, if (bufferevent_write(buffer_event_, &buf_size, sizeof(buf_size)) == -1) {
sizeof(buf_size)) == -1) { MS_LOG(ERROR) << "Event buffer add header failed!";
MS_LOG(EXCEPTION) << "Event buffer add header failed!"; res = false;
} }
if (evbuffer_add(bufferevent_get_output(const_cast<struct bufferevent *>(buffer_event_)), serialized.data(), if (bufferevent_write(buffer_event_, serialized.data(), buf_size) == -1) {
buf_size) == -1) { MS_LOG(ERROR) << "Event buffer add protobuf data failed!";
MS_LOG(EXCEPTION) << "Event buffer add protobuf data failed!"; res = false;
} }
bufferevent_unlock(buffer_event_);
return res;
} }
TcpServer::TcpServer(const std::string &address, std::uint16_t port) TcpServer::TcpServer(const std::string &address, std::uint16_t port)
@ -225,7 +225,7 @@ void TcpServer::SendToAllClients(const char *data, size_t len) {
} }
} }
void TcpServer::AddConnection(const evutil_socket_t &fd, const TcpConnection *connection) { void TcpServer::AddConnection(const evutil_socket_t &fd, std::shared_ptr<TcpConnection> connection) {
MS_EXCEPTION_IF_NULL(connection); MS_EXCEPTION_IF_NULL(connection);
std::lock_guard<std::mutex> lock(connection_mutex_); std::lock_guard<std::mutex> lock(connection_mutex_);
connections_.insert(std::make_pair(fd, connection)); connections_.insert(std::make_pair(fd, connection));
@ -233,11 +233,11 @@ void TcpServer::AddConnection(const evutil_socket_t &fd, const TcpConnection *co
void TcpServer::RemoveConnection(const evutil_socket_t &fd) { void TcpServer::RemoveConnection(const evutil_socket_t &fd) {
std::lock_guard<std::mutex> lock(connection_mutex_); std::lock_guard<std::mutex> lock(connection_mutex_);
TcpConnection *connection = const_cast<TcpConnection *>(connections_.find(fd)->second);
delete connection;
connections_.erase(fd); connections_.erase(fd);
} }
std::shared_ptr<TcpConnection> TcpServer::GetConnectionByFd(const evutil_socket_t &fd) { return connections_[fd]; }
void TcpServer::ListenerCallback(struct evconnlistener *, evutil_socket_t fd, struct sockaddr *sockaddr, int, void TcpServer::ListenerCallback(struct evconnlistener *, evutil_socket_t fd, struct sockaddr *sockaddr, int,
void *data) { void *data) {
auto server = reinterpret_cast<class TcpServer *>(data); auto server = reinterpret_cast<class TcpServer *>(data);
@ -246,7 +246,7 @@ void TcpServer::ListenerCallback(struct evconnlistener *, evutil_socket_t fd, st
MS_EXCEPTION_IF_NULL(base); MS_EXCEPTION_IF_NULL(base);
MS_EXCEPTION_IF_NULL(sockaddr); MS_EXCEPTION_IF_NULL(sockaddr);
struct bufferevent *bev = bufferevent_socket_new(base, fd, BEV_OPT_CLOSE_ON_FREE); struct bufferevent *bev = bufferevent_socket_new(base, fd, BEV_OPT_CLOSE_ON_FREE | BEV_OPT_THREADSAFE);
if (!bev) { if (!bev) {
MS_LOG(ERROR) << "Error constructing buffer event!"; MS_LOG(ERROR) << "Error constructing buffer event!";
int ret = event_base_loopbreak(base); int ret = event_base_loopbreak(base);
@ -256,23 +256,29 @@ void TcpServer::ListenerCallback(struct evconnlistener *, evutil_socket_t fd, st
return; return;
} }
TcpConnection *conn = server->onCreateConnection(bev, fd); std::shared_ptr<TcpConnection> conn = server->onCreateConnection(bev, fd);
MS_EXCEPTION_IF_NULL(conn); MS_EXCEPTION_IF_NULL(conn);
conn->InitConnection();
server->AddConnection(fd, conn); server->AddConnection(fd, conn);
bufferevent_setcb(bev, TcpServer::ReadCallback, nullptr, TcpServer::EventCallback, reinterpret_cast<void *>(conn)); conn->InitConnection([=](std::shared_ptr<CommMessage> message) {
OnServerReceiveMessage on_server_receive = server->GetServerReceive();
if (on_server_receive) {
on_server_receive(conn, message);
}
});
bufferevent_setcb(bev, TcpServer::ReadCallback, nullptr, TcpServer::EventCallback,
reinterpret_cast<void *>(conn.get()));
if (bufferevent_enable(bev, EV_READ | EV_WRITE) == -1) { if (bufferevent_enable(bev, EV_READ | EV_WRITE) == -1) {
MS_LOG(EXCEPTION) << "Buffer event enable read and write failed!"; MS_LOG(EXCEPTION) << "Buffer event enable read and write failed!";
} }
} }
TcpConnection *TcpServer::onCreateConnection(struct bufferevent *bev, const evutil_socket_t &fd) { std::shared_ptr<TcpConnection> TcpServer::onCreateConnection(struct bufferevent *bev, const evutil_socket_t &fd) {
TcpConnection *conn = nullptr; std::shared_ptr<TcpConnection> conn = nullptr;
if (client_accept_) { if (client_accept_) {
conn = const_cast<TcpConnection *>(client_accept_(*this)); conn = (client_accept_(*this));
} else { } else {
conn = new TcpConnection(bev, fd, this); conn = std::make_shared<TcpConnection>(bev, fd, this);
} }
return conn; return conn;
@ -312,8 +318,8 @@ void TcpServer::EventCallback(struct bufferevent *bev, std::int16_t events, void
MS_EXCEPTION_IF_NULL(data); MS_EXCEPTION_IF_NULL(data);
struct evbuffer *output = bufferevent_get_output(bev); struct evbuffer *output = bufferevent_get_output(bev);
size_t remain = evbuffer_get_length(output); size_t remain = evbuffer_get_length(output);
auto conn = reinterpret_cast<TcpConnection *>(data); auto conn = static_cast<class TcpConnection *>(data);
TcpServer *srv = conn->GetServer(); auto srv = conn->GetServer();
if (events & BEV_EVENT_EOF) { if (events & BEV_EVENT_EOF) {
MS_LOG(INFO) << "Event buffer end of file!"; MS_LOG(INFO) << "Event buffer end of file!";
@ -355,13 +361,18 @@ void TcpServer::TimerOnceCallback(evutil_socket_t, int16_t, void *arg) {
} }
} }
void TcpServer::SendMessage(const TcpConnection &conn, const CommMessage &message) { conn.SendMessage(message); } bool TcpServer::SendMessage(std::shared_ptr<TcpConnection> conn, std::shared_ptr<CommMessage> message) {
MS_EXCEPTION_IF_NULL(conn);
MS_EXCEPTION_IF_NULL(message);
return conn->SendMessage(message);
}
void TcpServer::SendMessage(const CommMessage &message) { void TcpServer::SendMessage(std::shared_ptr<CommMessage> message) {
std::lock_guard<std::mutex> lock(connection_mutex_); std::lock_guard<std::mutex> lock(connection_mutex_);
MS_EXCEPTION_IF_NULL(message);
for (auto it = connections_.begin(); it != connections_.end(); ++it) { for (auto it = connections_.begin(); it != connections_.end(); ++it) {
SendMessage(*it->second, message); SendMessage(it->second, message);
} }
} }
@ -371,7 +382,7 @@ std::string TcpServer::BoundIp() const { return server_address_; }
int TcpServer::ConnectionNum() const { return connections_.size(); } int TcpServer::ConnectionNum() const { return connections_.size(); }
const std::map<evutil_socket_t, const TcpConnection *> &TcpServer::Connections() const { return connections_; } const std::map<evutil_socket_t, std::shared_ptr<TcpConnection>> &TcpServer::Connections() const { return connections_; }
void TcpServer::SetMessageCallback(const OnServerReceiveMessage &cb) { message_callback_ = cb; } void TcpServer::SetMessageCallback(const OnServerReceiveMessage &cb) { message_callback_ = cb; }

View File

@ -34,8 +34,6 @@
#include <thread> #include <thread>
#include <atomic> #include <atomic>
#include "proto/comm.pb.h"
#include "proto/ps.pb.h"
#include "ps/core/tcp_message_handler.h" #include "ps/core/tcp_message_handler.h"
#include "ps/core/cluster_config.h" #include "ps/core/cluster_config.h"
#include "utils/log_adapter.h" #include "utils/log_adapter.h"
@ -47,36 +45,42 @@ namespace core {
class TcpServer; class TcpServer;
class TcpConnection { class TcpConnection {
public: public:
explicit TcpConnection(struct bufferevent *bev, const evutil_socket_t &fd, const TcpServer *server) explicit TcpConnection(struct bufferevent *bev, const evutil_socket_t &fd, TcpServer *server)
: buffer_event_(bev), fd_(fd), server_(server) {} : buffer_event_(bev), fd_(fd), server_(server) {}
TcpConnection(const TcpConnection &);
virtual ~TcpConnection() = default; virtual ~TcpConnection() = default;
virtual void InitConnection(); using Callback = std::function<void(const std::shared_ptr<CommMessage>)>;
virtual void InitConnection(const messageReceive &callback);
virtual void SendMessage(const void *buffer, size_t num) const; virtual void SendMessage(const void *buffer, size_t num) const;
void SendMessage(const CommMessage &message) const; bool SendMessage(std::shared_ptr<CommMessage> message) const;
virtual void OnReadHandler(const void *buffer, size_t numBytes); virtual void OnReadHandler(const void *buffer, size_t numBytes);
TcpServer *GetServer() const; TcpServer *GetServer() const;
const evutil_socket_t &GetFd() const; const evutil_socket_t &GetFd() const;
void set_callback(const Callback &callback);
protected: protected:
struct bufferevent *buffer_event_; struct bufferevent *buffer_event_;
evutil_socket_t fd_; evutil_socket_t fd_;
const TcpServer *server_; TcpServer *server_;
TcpMessageHandler tcp_message_handler_; TcpMessageHandler tcp_message_handler_;
Callback callback_;
}; };
using OnServerReceiveMessage = using OnServerReceiveMessage =
std::function<void(const TcpServer &tcp_server, const TcpConnection &conn, const CommMessage &)>; std::function<void(std::shared_ptr<TcpConnection> conn, std::shared_ptr<CommMessage> message)>;
class TcpServer { class TcpServer {
public: public:
using OnConnected = std::function<void(const TcpServer &, const TcpConnection &)>; using OnConnected = std::function<void(const TcpServer &, const TcpConnection &)>;
using OnDisconnected = std::function<void(const TcpServer &, const TcpConnection &)>; using OnDisconnected = std::function<void(const TcpServer &, const TcpConnection &)>;
using OnAccepted = std::function<const TcpConnection *(const TcpServer &)>; using OnAccepted = std::function<std::shared_ptr<TcpConnection>(const TcpServer &)>;
using OnTimerOnce = std::function<void(const TcpServer &)>; using OnTimerOnce = std::function<void(const TcpServer &)>;
using OnTimer = std::function<void()>; using OnTimer = std::function<void()>;
explicit TcpServer(const std::string &address, std::uint16_t port); TcpServer(const std::string &address, std::uint16_t port);
TcpServer(const TcpServer &server);
virtual ~TcpServer(); virtual ~TcpServer();
void SetServerCallback(const OnConnected &client_conn, const OnDisconnected &client_disconn, void SetServerCallback(const OnConnected &client_conn, const OnDisconnected &client_disconn,
@ -90,16 +94,17 @@ class TcpServer {
void StartTimer(const uint32_t &time); void StartTimer(const uint32_t &time);
void Stop(); void Stop();
void SendToAllClients(const char *data, size_t len); void SendToAllClients(const char *data, size_t len);
void AddConnection(const evutil_socket_t &fd, const TcpConnection *connection); void AddConnection(const evutil_socket_t &fd, std::shared_ptr<TcpConnection> connection);
void RemoveConnection(const evutil_socket_t &fd); void RemoveConnection(const evutil_socket_t &fd);
std::shared_ptr<TcpConnection> GetConnectionByFd(const evutil_socket_t &fd);
OnServerReceiveMessage GetServerReceive() const; OnServerReceiveMessage GetServerReceive() const;
void SetMessageCallback(const OnServerReceiveMessage &cb); void SetMessageCallback(const OnServerReceiveMessage &cb);
void SendMessage(const TcpConnection &conn, const CommMessage &message); bool SendMessage(std::shared_ptr<TcpConnection> conn, std::shared_ptr<CommMessage> message);
void SendMessage(const CommMessage &message); void SendMessage(std::shared_ptr<CommMessage> message);
uint16_t BoundPort() const; uint16_t BoundPort() const;
std::string BoundIp() const; std::string BoundIp() const;
int ConnectionNum() const; int ConnectionNum() const;
const std::map<evutil_socket_t, const TcpConnection *> &Connections() const; const std::map<evutil_socket_t, std::shared_ptr<TcpConnection>> &Connections() const;
protected: protected:
static void ListenerCallback(struct evconnlistener *listener, evutil_socket_t socket, struct sockaddr *saddr, static void ListenerCallback(struct evconnlistener *listener, evutil_socket_t socket, struct sockaddr *saddr,
@ -109,7 +114,7 @@ class TcpServer {
static void EventCallback(struct bufferevent *, std::int16_t events, void *server); static void EventCallback(struct bufferevent *, std::int16_t events, void *server);
static void TimerCallback(evutil_socket_t fd, int16_t event, void *arg); static void TimerCallback(evutil_socket_t fd, int16_t event, void *arg);
static void TimerOnceCallback(evutil_socket_t fd, int16_t event, void *arg); static void TimerOnceCallback(evutil_socket_t fd, int16_t event, void *arg);
virtual TcpConnection *onCreateConnection(struct bufferevent *bev, const evutil_socket_t &fd); std::shared_ptr<TcpConnection> onCreateConnection(struct bufferevent *bev, const evutil_socket_t &fd);
struct event_base *base_; struct event_base *base_;
struct event *signal_event_; struct event *signal_event_;
@ -118,7 +123,7 @@ class TcpServer {
std::uint16_t server_port_; std::uint16_t server_port_;
std::atomic<bool> is_stop_; std::atomic<bool> is_stop_;
std::map<evutil_socket_t, const TcpConnection *> connections_; std::map<evutil_socket_t, std::shared_ptr<TcpConnection>> connections_;
OnConnected client_connection_; OnConnected client_connection_;
OnDisconnected client_disconnection_; OnDisconnected client_disconnection_;
OnAccepted client_accept_; OnAccepted client_accept_;

View File

@ -24,8 +24,6 @@
#include <utility> #include <utility>
#include <algorithm> #include <algorithm>
#include "proto/comm.pb.h"
#include "proto/ps.pb.h"
#include "ps/core/cluster_config.h" #include "ps/core/cluster_config.h"
#include "ps/core/tcp_client.h" #include "ps/core/tcp_client.h"
#include "ps/core/tcp_server.h" #include "ps/core/tcp_server.h"

View File

@ -31,7 +31,7 @@ class TestClusterAvailableTimeout : public UT::Common {
}; };
TEST_F(TestClusterAvailableTimeout, TestClusterAvailableTimeout) { TEST_F(TestClusterAvailableTimeout, TestClusterAvailableTimeout) {
ClusterConfig::Init(1, 1, std::make_unique<std::string>("127.0.0.1"), 9999); ClusterConfig::Init(1, 1, "127.0.0.1", 9999);
ClusterConfig::set_cluster_available_timeout(3); ClusterConfig::set_cluster_available_timeout(3);
SchedulerNode node; SchedulerNode node;
node.Start(); node.Start();

View File

@ -33,7 +33,7 @@ class TestClusterConfig : public UT::Common {
}; };
TEST_F(TestClusterConfig, HeartbeatInterval) { TEST_F(TestClusterConfig, HeartbeatInterval) {
ClusterConfig::Init(2, 2, std::make_unique<std::string>("127.0.0.1"), 8080); ClusterConfig::Init(2, 2, "127.0.0.1", 8080);
EXPECT_TRUE(ClusterConfig::heartbeat_interval() == 3); EXPECT_TRUE(ClusterConfig::heartbeat_interval() == 3);
ClusterConfig::set_heartbeat_interval(100); ClusterConfig::set_heartbeat_interval(100);
EXPECT_TRUE(ClusterConfig::heartbeat_interval() == 100); EXPECT_TRUE(ClusterConfig::heartbeat_interval() == 100);

View File

@ -53,7 +53,7 @@ TEST_F(TestCommUtil, GetAvailableInterfaceAndIP) {
} }
TEST_F(TestCommUtil, ValidateRankId) { TEST_F(TestCommUtil, ValidateRankId) {
ClusterConfig::Init(3, 2, std::make_unique<std::string>("127.0.0.1"), 9999); ClusterConfig::Init(3, 2, "127.0.0.1", 9999);
EXPECT_TRUE(CommUtil::ValidateRankId(NodeRole::WORKER, 2)); EXPECT_TRUE(CommUtil::ValidateRankId(NodeRole::WORKER, 2));
EXPECT_FALSE(CommUtil::ValidateRankId(NodeRole::WORKER, 3)); EXPECT_FALSE(CommUtil::ValidateRankId(NodeRole::WORKER, 3));
EXPECT_TRUE(CommUtil::ValidateRankId(NodeRole::SERVER, 1)); EXPECT_TRUE(CommUtil::ValidateRankId(NodeRole::SERVER, 1));

View File

@ -35,7 +35,7 @@ class TestTcpMessageHandler : public UT::Common {
TEST_F(TestTcpMessageHandler, 8_Header_1003_Data) { TEST_F(TestTcpMessageHandler, 8_Header_1003_Data) {
TcpMessageHandler handler; TcpMessageHandler handler;
handler.SetCallback([this](const CommMessage &message) { EXPECT_EQ(message.data().size(), 1000); }); handler.SetCallback([this](std::shared_ptr<CommMessage> message) { EXPECT_EQ(message->data().size(), 1000); });
std::string data(1000, 'a'); std::string data(1000, 'a');
CommMessage message; CommMessage message;
@ -55,7 +55,7 @@ TEST_F(TestTcpMessageHandler, 8_Header_1003_Data) {
TEST_F(TestTcpMessageHandler, 8_Header_1003_Data_8_Header_1003_Data) { TEST_F(TestTcpMessageHandler, 8_Header_1003_Data_8_Header_1003_Data) {
TcpMessageHandler handler; TcpMessageHandler handler;
handler.SetCallback([this](const CommMessage &message) { EXPECT_EQ(message.data().size(), 1000); }); handler.SetCallback([this](std::shared_ptr<CommMessage> message) { EXPECT_EQ(message->data().size(), 1000); });
std::string data(1000, 'a'); std::string data(1000, 'a');
CommMessage message; CommMessage message;
@ -86,7 +86,7 @@ TEST_F(TestTcpMessageHandler, 8_Header_1003_Data_8_Header_1003_Data) {
TEST_F(TestTcpMessageHandler, 8_Header_4084_Data_4_Header_4_header_4084_data) { TEST_F(TestTcpMessageHandler, 8_Header_4084_Data_4_Header_4_header_4084_data) {
TcpMessageHandler handler; TcpMessageHandler handler;
handler.SetCallback([this](const CommMessage &message) { EXPECT_EQ(message.data().size(), 4081); }); handler.SetCallback([this](std::shared_ptr<CommMessage> message) { EXPECT_EQ(message->data().size(), 4081); });
std::string data(4081, 'a'); std::string data(4081, 'a');
CommMessage message; CommMessage message;
@ -126,7 +126,7 @@ TEST_F(TestTcpMessageHandler, 8_Header_4084_Data_4_Header_4_header_4084_data) {
TEST_F(TestTcpMessageHandler, 8_Header_4080_Data_8_Header_4080_data) { TEST_F(TestTcpMessageHandler, 8_Header_4080_Data_8_Header_4080_data) {
TcpMessageHandler handler; TcpMessageHandler handler;
handler.SetCallback([this](const CommMessage &message) { EXPECT_EQ(message.data().size(), 4077); }); handler.SetCallback([this](std::shared_ptr<CommMessage> message) { EXPECT_EQ(message->data().size(), 4077); });
std::string data(4077, 'a'); std::string data(4077, 'a');
CommMessage message; CommMessage message;

View File

@ -32,12 +32,12 @@ class TestTcpServer : public UT::Common {
void SetUp() override { void SetUp() override {
server_ = std::make_unique<TcpServer>("127.0.0.1", 0); server_ = std::make_unique<TcpServer>("127.0.0.1", 0);
std::unique_ptr<std::thread> http_server_thread_(nullptr); std::unique_ptr<std::thread> http_server_thread_(nullptr);
http_server_thread_ = std::make_unique<std::thread>([&]() { http_server_thread_ = std::make_unique<std::thread>([=]() {
server_->SetMessageCallback([](const TcpServer &server, const TcpConnection &conn, const CommMessage &message) { server_->SetMessageCallback([=](std::shared_ptr<TcpConnection> conn, std::shared_ptr<CommMessage> message) {
KVMessage kv_message; KVMessage kv_message;
kv_message.ParseFromString(message.data()); kv_message.ParseFromString(message->data());
EXPECT_EQ(2, kv_message.keys_size()); EXPECT_EQ(2, kv_message.keys_size());
const_cast<TcpServer&>(server).SendMessage(conn, message); server_->SendMessage(conn, message);
}); });
server_->Init(); server_->Init();
server_->Start(); server_->Start();
@ -58,6 +58,7 @@ class TestTcpServer : public UT::Common {
TEST_F(TestTcpServer, ServerSendMessage) { TEST_F(TestTcpServer, ServerSendMessage) {
client_ = std::make_unique<TcpClient>("127.0.0.1", server_->BoundPort()); client_ = std::make_unique<TcpClient>("127.0.0.1", server_->BoundPort());
std::cout << server_->BoundPort() << std::endl;
std::unique_ptr<std::thread> http_client_thread(nullptr); std::unique_ptr<std::thread> http_client_thread(nullptr);
http_client_thread = std::make_unique<std::thread>([&]() { http_client_thread = std::make_unique<std::thread>([&]() {
client_->SetMessageCallback([](const TcpClient &client, const CommMessage &message) { client_->SetMessageCallback([](const TcpClient &client, const CommMessage &message) {