!10946 Switch bare pointer to shared_ptr

From: @anancds
Reviewed-by: 
Signed-off-by:
This commit is contained in:
mindspore-ci-bot 2021-01-11 14:12:46 +08:00 committed by Gitee
commit 6eb4634ba2
26 changed files with 261 additions and 176 deletions

View File

@ -75,6 +75,8 @@ bool AbstractNode::Broadcast(const enum NodeRole &node_role, const std::string &
auto client = GetOrCreateTcpClient((*it).first.second);
client->SendMessage(comm_message);
}
MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_)
<< ", the node id is:" << node_info_.node_id_ << " send the request id is:" << request_id;
return Wait(request_id, timeout);
}
@ -126,11 +128,13 @@ bool AbstractNode::Send(const NodeRole &node_role, const std::vector<uint32_t> &
auto client = GetOrCreateTcpClient(rank_ids.at(it));
client->SendMessage(comm_message);
}
MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_)
<< ", the node id is:" << node_info_.node_id_ << " send the request id is:" << request_id;
return Wait(request_id, timeout);
}
bool AbstractNode::Send(const enum NodeRole &node_role, const uint32_t &rank_id, const std::string &message,
CommMessage *output, const uint32_t &timeout) {
std::string *output, const uint32_t &timeout) {
MS_EXCEPTION_IF_NULL(output);
if (!CommUtil::ValidateRankId(node_role, rank_id)) {
MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!";
@ -141,7 +145,7 @@ bool AbstractNode::Send(const enum NodeRole &node_role, const uint32_t &rank_id,
set_message_callback(request_id, [&]() {
receive_messages_mutex_.lock();
auto res = receive_messages_[request_id];
*output = res[rank_id];
*output = res[rank_id].data();
receive_messages_.erase(request_id);
receive_messages_mutex_.unlock();
});
@ -157,11 +161,13 @@ bool AbstractNode::Send(const enum NodeRole &node_role, const uint32_t &rank_id,
comm_message.set_data(message);
auto client = GetOrCreateTcpClient(rank_id);
client->SendMessage(comm_message);
MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_)
<< ", the node id is:" << node_info_.node_id_ << " send the request id is:" << request_id;
return Wait(request_id, timeout);
}
bool AbstractNode::Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids,
const std::vector<std::string> &data, std::vector<CommMessage> *output,
const std::vector<std::string> &data, std::vector<std::string> *output,
const uint32_t &timeout) {
MS_EXCEPTION_IF_NULL(output);
uint64_t request_id = ++next_request_id_;
@ -177,7 +183,7 @@ bool AbstractNode::Send(const NodeRole &node_role, const std::vector<uint32_t> &
receive_messages_mutex_.lock();
auto res = receive_messages_[request_id];
for (size_t it = 0; it < len; ++it) {
(*output).push_back(res[rank_ids.at(it)]);
(*output).push_back(res[rank_ids.at(it)].data());
}
receive_messages_.erase(request_id);
receive_messages_mutex_.unlock();
@ -201,6 +207,8 @@ bool AbstractNode::Send(const NodeRole &node_role, const std::vector<uint32_t> &
auto client = GetOrCreateTcpClient(rank_ids.at(it));
client->SendMessage(comm_message);
}
MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_)
<< ", the node id is:" << node_info_.node_id_ << " send the request id is:" << request_id;
return Wait(request_id, timeout);
}
@ -215,7 +223,7 @@ bool AbstractNode::Wait(uint64_t request_id, const uint32_t &timeout) {
}
uint64_t AbstractNode::CollectiveSendAsync(const enum NodeRole &node_role, const uint32_t &rank_id,
const std::string &message, const uint32_t &timeout) {
const std::string &message) {
if (!CommUtil::ValidateRankId(node_role, rank_id)) {
MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!";
}
@ -233,19 +241,19 @@ uint64_t AbstractNode::CollectiveSendAsync(const enum NodeRole &node_role, const
}
std::pair<uint32_t, uint64_t> AbstractNode::CollectiveReceiveAsync(const enum NodeRole &node_role,
const uint32_t &rank_id, CommMessage *output) {
const uint32_t &rank_id, std::string *output) {
if (!CommUtil::ValidateRankId(node_role, rank_id)) {
MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!";
}
uint64_t rank_request_id = NextExpectedRankRequestId(rank_id);
if (received_data_.count(std::make_pair(rank_id, rank_request_id)) > 0) {
*output = received_data_[std::make_pair(rank_id, rank_request_id)];
*output = received_data_[std::make_pair(rank_id, rank_request_id)].data();
received_data_.erase(std::make_pair(rank_id, rank_request_id));
} else {
set_receive_callback(rank_id, rank_request_id, [=]() {
receive_callbacks_mutex_.lock();
*output = received_data_[std::make_pair(rank_id, 1)];
*output = received_data_[std::make_pair(rank_id, 1)].data();
received_data_.erase(std::make_pair(rank_id, rank_request_id));
receive_callbacks_mutex_.unlock();
});
@ -272,13 +280,25 @@ void AbstractNode::StartHeartbeatTimer(const std::shared_ptr<TcpClient> &client)
<< " begin send heartbeat to the scheduler!";
heart_beat_thread_ = std::make_unique<std::thread>([&]() {
while (!is_finish_.load()) {
Heartbeat(client);
if (!Heartbeat(client)) {
MS_LOG(ERROR) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_)
<< ", the node id is:" << node_info_.node_id_ << " Send heartbeat timeout!";
if (!CheckSchedulerTimeout() && on_node_event_message_) {
MS_LOG(ERROR) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_)
<< ", the node id is:" << node_info_.node_id_ << " exited due to scheduler timeout!";
is_finish_ = true;
wait_finish_cond_.notify_all();
on_node_event_message_(NodeEvent::SCHEDULER_TIMEOUT);
}
} else {
UpdateSchedulerTime();
}
std::this_thread::sleep_for(std::chrono::seconds(ClusterConfig::heartbeat_interval()));
}
});
}
void AbstractNode::Heartbeat(const std::shared_ptr<TcpClient> &client, bool is_node_finish) {
bool AbstractNode::Heartbeat(const std::shared_ptr<TcpClient> &client, bool is_node_finish) {
MessageMeta meta;
meta.set_cmd(NodeCommand::HEARTBEAT);
@ -292,11 +312,31 @@ void AbstractNode::Heartbeat(const std::shared_ptr<TcpClient> &client, bool is_n
if (!SendMessageSync(client, message)) {
MS_LOG(ERROR) << "The node id:" << node_info_.node_id_ << " Send heartbeat timeout!";
}
return true;
}
void AbstractNode::UpdateSchedulerTime() {
struct timeval current_time {};
(void)gettimeofday(&current_time, nullptr);
scheduler_time_ = current_time;
MS_LOG(DEBUG) << "The node role: " << CommUtil::NodeRoleToString(node_info_.node_role_)
<< ", the node id:" << node_info_.node_id_ << ", the node rank id:" << node_info_.rank_id_
<< " update scheduler time, the current time is: " << current_time.tv_sec;
}
bool AbstractNode::CheckSchedulerTimeout() const {
struct timeval current_time {};
(void)gettimeofday(&current_time, nullptr);
if (scheduler_time_.tv_sec + ClusterConfig::scheduler_timeout() < current_time.tv_sec) {
return true;
}
return false;
}
void AbstractNode::ProcessHeartbeatResp(const CommMessage &message) {
HeartbeatRespMessage heartbeat_resp_message;
heartbeat_resp_message.ParseFromString(message.data());
is_ready_ = heartbeat_resp_message.is_cluster_ready();
if (is_ready_.load()) {
wait_start_cond_.notify_all();
@ -353,9 +393,9 @@ bool AbstractNode::Disconnect(const std::shared_ptr<TcpClient> &client, const ui
*message.mutable_pb_meta() = {meta};
message.set_data(finish_message.SerializeAsString());
if (!SendMessageSync(client, message)) {
MS_LOG(EXCEPTION) << "Disconnect timeout!";
MS_LOG(ERROR) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
<< " the node id:" << node_info_.node_id_ << " send Finish Message timeout!";
}
MS_LOG(INFO) << "The node id:" << node_info_.node_id_ << " send finish message!";
return WaitForDisconnect(timeout);
}
@ -444,6 +484,8 @@ bool AbstractNode::SendMessageSync(const std::shared_ptr<TcpClient> &client, con
message_tracker_[request_id] = std::make_pair(1, 0);
const_cast<CommMessage &>(message).mutable_pb_meta()->set_request_id(request_id);
client->SendMessage(message);
MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_)
<< ", the node id is:" << node_info_.node_id_ << " send the request id is:" << request_id;
return Wait(request_id, timeout);
}
@ -452,6 +494,8 @@ uint64_t AbstractNode::SendMessageAsync(const std::shared_ptr<TcpClient> &client
message_tracker_[request_id] = std::make_pair(1, 0);
const_cast<CommMessage &>(message).mutable_pb_meta()->set_request_id(request_id);
client->SendMessage(message);
MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_)
<< ", the node id is:" << node_info_.node_id_ << " send the request id is:" << request_id;
return request_id;
}
@ -460,6 +504,8 @@ void AbstractNode::ProcessSendDataResp(const CommMessage &message) {
const MessageMeta &message_meta = message.pb_meta();
const uint32_t &rank_id = message_meta.rank_id();
const uint64_t request_id = message_meta.request_id();
MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_)
<< ", the node id is:" << node_info_.node_id_ << " send the request id is:" << request_id;
auto it = receive_messages_.find(request_id);
if (it != receive_messages_.end()) {
it->second[rank_id] = message;

View File

@ -42,23 +42,24 @@ class AbstractNode : public Node {
const uint32_t &timeout = kCommTimeoutInSeconds);
bool Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids, const std::vector<std::string> &data,
const uint32_t &timeout = kCommTimeoutInSeconds);
bool Send(const enum NodeRole &node_role, const uint32_t &rank_id, const std::string &message, CommMessage *output,
bool Send(const enum NodeRole &node_role, const uint32_t &rank_id, const std::string &message, std::string *output,
const uint32_t &timeout = kCommTimeoutInSeconds);
bool Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids, const std::vector<std::string> &data,
std::vector<CommMessage> *output, const uint32_t &timeout = kCommTimeoutInSeconds);
std::vector<std::string> *output, const uint32_t &timeout = kCommTimeoutInSeconds);
bool Wait(uint64_t request_id, const uint32_t &timeout = kCommTimeoutInSeconds);
uint64_t CollectiveSendAsync(const enum NodeRole &node_role, const uint32_t &rank_id, const std::string &message,
const uint32_t &timeout = kCommTimeoutInSeconds);
uint64_t CollectiveSendAsync(const enum NodeRole &node_role, const uint32_t &rank_id, const std::string &message);
std::pair<uint32_t, uint64_t> CollectiveReceiveAsync(const enum NodeRole &node_role, const uint32_t &rank_id,
CommMessage *output);
std::string *output);
bool CollectiveWait(std::pair<uint32_t, uint64_t> request_id, const uint32_t &timeout = kCommTimeoutInSeconds);
protected:
void Register(const std::shared_ptr<TcpClient> &client);
void ProcessRegisterResp(const CommMessage &message);
void StartHeartbeatTimer(const std::shared_ptr<TcpClient> &client);
void Heartbeat(const std::shared_ptr<TcpClient> &client, bool is_node_finish = false);
bool Heartbeat(const std::shared_ptr<TcpClient> &client, bool is_node_finish = false);
void UpdateSchedulerTime();
bool CheckSchedulerTimeout() const;
void ProcessHeartbeatResp(const CommMessage &message);
void FetchServers(const std::shared_ptr<TcpClient> &client);
void ProcessFetchServersResp(const CommMessage &message);
@ -113,6 +114,7 @@ class AbstractNode : public Node {
// the key is rank_id, the value is rank_id's actual request_id
std::unordered_map<uint32_t, uint64_t> actual_rank_request_ids_;
std::mutex rank_request_ids_mutex;
timeval scheduler_time_;
};
} // namespace core
} // namespace ps

View File

@ -33,15 +33,17 @@ uint32_t ClusterConfig::heartbeat_timeout_ = 30;
uint32_t ClusterConfig::cluster_available_timeout_ = 300;
// The timeout period for the client to connect to the server is 100ms.
uint32_t ClusterConfig::connect_interval_ = 100;
// When the scheduler exits, the worker and server can continue to work for 5 hours
uint32_t ClusterConfig::scheduler_timeout_ = 3600 * 5;
void ClusterConfig::Init(const uint32_t &worker_num, const uint32_t &server_num,
std::unique_ptr<std::string> scheduler_host, const uint16_t &scheduler_port) {
void ClusterConfig::Init(const uint32_t &worker_num, const uint32_t &server_num, std::string scheduler_host,
const uint16_t &scheduler_port) {
worker_num_ = worker_num;
server_num_ = server_num;
if (!CommUtil::CheckIp(*scheduler_host.get())) {
MS_LOG(EXCEPTION) << "The scheduler_host:" << *scheduler_host.get() << " is illegal!";
if (!CommUtil::CheckIp(scheduler_host)) {
MS_LOG(EXCEPTION) << "The scheduler_host:" << scheduler_host << " is illegal!";
}
scheduler_host_ = std::move(scheduler_host);
scheduler_host_ = std::make_unique<std::string>(scheduler_host);
scheduler_port_ = scheduler_port;
}
@ -55,7 +57,7 @@ void ClusterConfig::set_heartbeat_interval(const uint32_t &heartbeat_interval) {
heartbeat_interval_ = heartbeat_interval;
}
std::string ClusterConfig::scheduler_host() { return *scheduler_host_.get(); }
std::string ClusterConfig::scheduler_host() { return *scheduler_host_; }
uint16_t ClusterConfig::scheduler_port() { return scheduler_port_; }
@ -74,6 +76,10 @@ void ClusterConfig::set_cluster_available_timeout(const uint32_t &cluster_availa
uint32_t ClusterConfig::connect_interval() { return connect_interval_; }
void ClusterConfig::set_connect_interval(const uint32_t &connect_interval) { connect_interval_ = connect_interval; }
uint32_t ClusterConfig::scheduler_timeout() { return scheduler_timeout_; }
void ClusterConfig::set_scheduler_timeout(const uint32_t &scheduler_timeout) { scheduler_timeout_ = scheduler_timeout; }
} // namespace core
} // namespace ps
} // namespace mindspore

View File

@ -30,7 +30,7 @@ namespace ps {
namespace core {
class ClusterConfig {
public:
static void Init(const uint32_t &worker_num, const uint32_t &server_num, std::unique_ptr<std::string> scheduler_host,
static void Init(const uint32_t &worker_num, const uint32_t &server_num, std::string scheduler_host,
const uint16_t &scheduler_port);
static uint32_t worker_num();
static uint32_t server_num();
@ -44,6 +44,8 @@ class ClusterConfig {
static void set_cluster_available_timeout(const uint32_t &cluster_available_timeout);
static uint32_t connect_interval();
static void set_connect_interval(const uint32_t &connect_interval);
static uint32_t scheduler_timeout();
static void set_scheduler_timeout(const uint32_t &scheduler_timeout);
private:
static uint32_t worker_num_;
@ -54,6 +56,7 @@ class ClusterConfig {
static uint32_t heartbeat_timeout_;
static uint32_t cluster_available_timeout_;
static uint32_t connect_interval_;
static uint32_t scheduler_timeout_;
};
} // namespace core
} // namespace ps

View File

@ -21,7 +21,12 @@ namespace ps {
namespace core {
std::string Node::node_id() const { return node_info_.node_id_; }
uint32_t Node::rank_id() const { return node_info_.rank_id_; }
uint32_t Node::rank_id() const {
if (!is_ready_.load()) {
MS_LOG(EXCEPTION) << "The cluster is not ready yet to get rank id!";
}
return node_info_.rank_id_;
}
NodeRole Node::role() const { return node_info_.node_role_; }

View File

@ -30,8 +30,6 @@
#include <utility>
#include <tuple>
#include "proto/comm.pb.h"
#include "proto/ps.pb.h"
#include "ps/core/cluster_config.h"
#include "ps/core/node_info.h"
#include "ps/core/tcp_client.h"

View File

@ -25,7 +25,7 @@
namespace mindspore {
namespace ps {
namespace core {
enum NodeEvent { CLUSTER_TIMEOUT = 0, NODE_TIMEOUT = 1 };
enum NodeEvent { CLUSTER_TIMEOUT = 0, NODE_TIMEOUT = 1, SCHEDULER_TIMEOUT };
struct NodeInfo {
NodeInfo() : port_(0), node_role_(NodeRole::SCHEDULER), rank_id_(0) {}

View File

@ -64,7 +64,7 @@ void NodeManager::UpdateHeartbeat(const std::string &node_id) {
struct timeval current_time {};
(void)gettimeofday(&current_time, nullptr);
heartbeats_[node_id] = current_time;
MS_LOG(INFO) << "The node role: " << CommUtil::NodeRoleToString(node_info.node_role_) << ", the node id:" << node_id
MS_LOG(DEBUG) << "The node role: " << CommUtil::NodeRoleToString(node_info.node_role_) << ", the node id:" << node_id
<< ", the node rank id:" << node_info.rank_id_ << " the current time is: " << current_time.tv_sec;
}

View File

@ -31,8 +31,6 @@
#include <condition_variable>
#include <unordered_set>
#include "proto/comm.pb.h"
#include "proto/ps.pb.h"
#include "ps/core/node.h"
#include "utils/log_adapter.h"
#include "utils/convert_utils_base.h"

View File

@ -20,6 +20,7 @@ option optimize_for = LITE_RUNTIME;
enum PSCommand {
PUSH = 0;
PULL = 1;
INIT_EMBEDDING_TABLE = 2;
}
message KVMessage {

View File

@ -37,9 +37,10 @@ bool SchedulerNode::Start(const uint32_t &timeout) {
return true;
}
void SchedulerNode::ProcessHeartbeat(const TcpServer &server, const TcpConnection &conn, const CommMessage &message) {
void SchedulerNode::ProcessHeartbeat(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn,
std::shared_ptr<CommMessage> message) {
HeartbeatMessage heartbeat_message;
heartbeat_message.ParseFromString(message.data());
heartbeat_message.ParseFromString(message->data());
node_manager_.UpdateHeartbeat(heartbeat_message.node_id());
@ -59,10 +60,10 @@ void SchedulerNode::ProcessHeartbeat(const TcpServer &server, const TcpConnectio
heartbeat_resp_message.set_is_cluster_timeout(node_manager_.is_cluster_timeout());
heartbeat_resp_message.set_is_node_timeout(node_manager_.is_node_timeout());
CommMessage comm_message;
*comm_message.mutable_pb_meta() = {message.pb_meta()};
comm_message.set_data(heartbeat_resp_message.SerializeAsString());
const_cast<TcpServer &>(server).SendMessage(conn, comm_message);
std::shared_ptr<CommMessage> comm_message = std::make_shared<CommMessage>();
*comm_message->mutable_pb_meta() = {message->pb_meta()};
comm_message->set_data(heartbeat_resp_message.SerializeAsString());
server->SendMessage(conn, comm_message);
}
void SchedulerNode::Initialize() {
@ -79,23 +80,23 @@ void SchedulerNode::CreateTcpServer() {
std::string scheduler_host = ClusterConfig::scheduler_host();
uint32_t scheduler_port = ClusterConfig::scheduler_port();
server_ = std::make_unique<TcpServer>(scheduler_host, scheduler_port);
server_->SetMessageCallback([&](const TcpServer &server, const TcpConnection &conn, const CommMessage &message) {
switch (message.pb_meta().cmd()) {
server_ = std::make_shared<TcpServer>(scheduler_host, scheduler_port);
server_->SetMessageCallback([&](std::shared_ptr<TcpConnection> conn, std::shared_ptr<CommMessage> message) {
switch (message->pb_meta().cmd()) {
case NodeCommand::HEARTBEAT:
ProcessHeartbeat(server, conn, message);
ProcessHeartbeat(server_, conn, message);
break;
case NodeCommand::REGISTER:
ProcessRegister(server, conn, message);
ProcessRegister(server_, conn, message);
break;
case NodeCommand::FINISH:
ProcessFinish(server, conn, message);
ProcessFinish(server_, conn, message);
break;
case NodeCommand::FETCH_SERVER:
ProcessFetchServers(server, conn, message);
ProcessFetchServers(server_, conn, message);
break;
default:
MS_LOG(EXCEPTION) << "The cmd:" << message.pb_meta().cmd() << " is not supported!";
MS_LOG(EXCEPTION) << "The cmd:" << message->pb_meta().cmd() << " is not supported!";
}
});
@ -107,10 +108,11 @@ void SchedulerNode::CreateTcpServer() {
});
}
void SchedulerNode::ProcessRegister(const TcpServer &server, const TcpConnection &conn, const CommMessage &message) {
void SchedulerNode::ProcessRegister(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn,
std::shared_ptr<CommMessage> message) {
MS_LOG(INFO) << "The scheduler process a register message!";
RegisterMessage register_message;
register_message.ParseFromString(message.data());
register_message.ParseFromString(message->data());
// assign worker node and server node rank id
int rank_id = node_manager_.NextRankId(register_message);
@ -124,31 +126,32 @@ void SchedulerNode::ProcessRegister(const TcpServer &server, const TcpConnection
register_resp_message.set_node_id(node_id);
register_resp_message.set_rank_id(rank_id);
CommMessage comm_message;
*comm_message.mutable_pb_meta() = {message.pb_meta()};
comm_message.set_data(register_resp_message.SerializeAsString());
const_cast<TcpServer &>(server).SendMessage(conn, comm_message);
std::shared_ptr<CommMessage> comm_message = std::make_shared<CommMessage>();
*comm_message->mutable_pb_meta() = {message->pb_meta()};
comm_message->set_data(register_resp_message.SerializeAsString());
server->SendMessage(conn, comm_message);
}
void SchedulerNode::ProcessFinish(const TcpServer &server, const TcpConnection &conn, const CommMessage &message) {
void SchedulerNode::ProcessFinish(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn,
std::shared_ptr<CommMessage> message) {
FinishMessage finish_message;
finish_message.ParseFromString(message.data());
finish_message.ParseFromString(message->data());
node_manager_.AddFinishNode(finish_message);
MS_LOG(INFO) << "Process finish message from node id:" << finish_message.node_id();
const_cast<TcpServer &>(server).SendMessage(conn, message);
server->SendMessage(conn, message);
}
void SchedulerNode::ProcessFetchServers(const TcpServer &server, const TcpConnection &conn,
const CommMessage &message) {
void SchedulerNode::ProcessFetchServers(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn,
std::shared_ptr<CommMessage> message) {
FetchServersRespMessage fetch_servers_message;
std::vector<ServersMeta> servers_meta_list = node_manager_.FetchServersMeta();
*fetch_servers_message.mutable_servers_meta() = {servers_meta_list.begin(), servers_meta_list.end()};
CommMessage comm_message;
*comm_message.mutable_pb_meta() = {message.pb_meta()};
comm_message.set_data(fetch_servers_message.SerializeAsString());
const_cast<TcpServer &>(server).SendMessage(conn, comm_message);
std::shared_ptr<CommMessage> comm_message = std::make_shared<CommMessage>();
*comm_message->mutable_pb_meta() = {message->pb_meta()};
comm_message->set_data(fetch_servers_message.SerializeAsString());
server->SendMessage(conn, comm_message);
}
void SchedulerNode::StartUpdateClusterStateTimer() {

View File

@ -26,8 +26,6 @@
#include <thread>
#include <mutex>
#include "proto/comm.pb.h"
#include "proto/ps.pb.h"
#include "ps/core/cluster_config.h"
#include "ps/core/tcp_client.h"
#include "ps/core/tcp_server.h"
@ -51,13 +49,17 @@ class SchedulerNode : public Node {
private:
void Initialize();
void CreateTcpServer();
void ProcessHeartbeat(const TcpServer &server, const TcpConnection &conn, const CommMessage &message);
void ProcessRegister(const TcpServer &server, const TcpConnection &conn, const CommMessage &message);
void ProcessHeartbeat(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn,
std::shared_ptr<CommMessage> message);
void ProcessRegister(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn,
std::shared_ptr<CommMessage> message);
void StartUpdateClusterStateTimer();
void ProcessFinish(const TcpServer &server, const TcpConnection &conn, const CommMessage &message);
void ProcessFetchServers(const TcpServer &server, const TcpConnection &conn, const CommMessage &message);
void ProcessFinish(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn,
std::shared_ptr<CommMessage> message);
void ProcessFetchServers(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn,
std::shared_ptr<CommMessage> message);
std::unique_ptr<TcpServer> server_;
std::shared_ptr<TcpServer> server_;
std::unique_ptr<std::thread> scheduler_thread_;
std::unique_ptr<std::thread> update_state_thread_;

View File

@ -30,7 +30,8 @@ bool ServerNode::Start(const uint32_t &timeout) {
StartHeartbeatTimer(client_to_scheduler_);
if (!WaitForStart(timeout)) {
MS_LOG(ERROR) << "Start Server node timeout!";
MS_LOG(ERROR) << "Start server node timeout!";
return false;
}
MS_LOG(INFO) << "The cluster is ready to use!";
@ -45,16 +46,16 @@ bool ServerNode::Start(const uint32_t &timeout) {
void ServerNode::set_handler(const RequestHandler &handler) { request_handler_ = handler; }
void ServerNode::Response(const TcpServer &server, const TcpConnection &conn, const MessageMeta &message_meta,
const std::string &message) {
auto &meta = const_cast<MessageMeta &>(message_meta);
meta.set_role(node_info_.node_role_);
meta.set_rank_id(node_info_.rank_id_);
CommMessage comm_message;
*comm_message.mutable_pb_meta() = {meta};
comm_message.set_data(message);
const_cast<TcpServer &>(server).SendMessage(conn, comm_message);
void ServerNode::Response(std::shared_ptr<TcpConnection> conn, std::shared_ptr<CommMessage> message) {
MS_EXCEPTION_IF_NULL(conn);
MS_EXCEPTION_IF_NULL(message);
message->mutable_pb_meta()->set_role(node_info_.node_role_);
message->mutable_pb_meta()->set_rank_id(node_info_.rank_id_);
const MessageMeta &message_meta = message->pb_meta();
const uint64_t request_id = message_meta.request_id();
MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_)
<< ", the node id is:" << node_info_.node_id_ << " send the request id is:" << request_id;
server_->SendMessage(conn, message);
}
void ServerNode::CreateTcpServer() {
@ -62,17 +63,17 @@ void ServerNode::CreateTcpServer() {
std::string server_ip;
CommUtil::GetAvailableInterfaceAndIP(&interface, &server_ip);
server_ = std::make_shared<TcpServer>(server_ip, 0);
server_->SetMessageCallback([&](const TcpServer &server, const TcpConnection &conn, const CommMessage &message) {
switch (message.pb_meta().cmd()) {
server_->SetMessageCallback([&](std::shared_ptr<TcpConnection> conn, std::shared_ptr<CommMessage> message) {
switch (message->pb_meta().cmd()) {
case NodeCommand::SEND_DATA:
ProcessSendData(server, conn, message);
ProcessSendData(conn, message);
break;
case NodeCommand::COLLECTIVE_SEND_DATA:
ProcessCollectiveSendData(server, conn, message);
RunReceiveCallback(message);
ProcessCollectiveSendData(conn, message);
RunReceiveCallback(*message);
break;
default:
MS_LOG(EXCEPTION) << "The cmd:" << message.pb_meta().cmd() << " is not supported!";
MS_LOG(EXCEPTION) << "The cmd:" << message->pb_meta().cmd() << " is not supported!";
}
});
server_->Init();
@ -97,15 +98,18 @@ void ServerNode::Initialize() {
MS_LOG(INFO) << "Server node init client successful!";
}
void ServerNode::ProcessSendData(const TcpServer &server, const TcpConnection &conn, const CommMessage &message) {
request_handler_(server, conn, message.pb_meta(), message.data());
void ServerNode::ProcessSendData(std::shared_ptr<TcpConnection> conn, std::shared_ptr<CommMessage> message) {
MS_EXCEPTION_IF_NULL(conn);
MS_EXCEPTION_IF_NULL(message);
request_handler_(conn, message);
}
void ServerNode::ProcessCollectiveSendData(const TcpServer &server, const TcpConnection &conn,
const CommMessage &message) {
CommMessage comm_message;
*comm_message.mutable_pb_meta() = {message.pb_meta()};
const_cast<TcpServer &>(server).SendMessage(conn, comm_message);
void ServerNode::ProcessCollectiveSendData(std::shared_ptr<TcpConnection> conn, std::shared_ptr<CommMessage> message) {
MS_EXCEPTION_IF_NULL(conn);
MS_EXCEPTION_IF_NULL(message);
std::shared_ptr<CommMessage> comm_message = std::make_shared<CommMessage>();
*comm_message->mutable_pb_meta() = {message->pb_meta()};
server_->SendMessage(conn, comm_message);
}
bool ServerNode::Stop() {

View File

@ -44,18 +44,16 @@ class ServerNode : public AbstractNode {
bool Stop() override;
bool Finish(const uint32_t &timeout = kTimeoutInSeconds) override;
using RequestHandler = std::function<void(const TcpServer &server, const TcpConnection &conn, const MessageMeta meta,
const std::string &message)>;
using RequestHandler = std::function<void(std::shared_ptr<TcpConnection> conn, std::shared_ptr<CommMessage> message)>;
void set_handler(const RequestHandler &handler);
void Response(const TcpServer &server, const TcpConnection &conn, const MessageMeta &message_meta,
const std::string &message);
void Response(std::shared_ptr<TcpConnection> conn, std::shared_ptr<CommMessage> message);
private:
void CreateTcpServer();
void Initialize();
void ProcessSendData(const TcpServer &server, const TcpConnection &conn, const CommMessage &message);
void ProcessCollectiveSendData(const TcpServer &server, const TcpConnection &conn, const CommMessage &message);
void ProcessSendData(std::shared_ptr<TcpConnection> conn, std::shared_ptr<CommMessage> message);
void ProcessCollectiveSendData(std::shared_ptr<TcpConnection> conn, std::shared_ptr<CommMessage> message);
std::shared_ptr<TcpServer> server_;
std::unique_ptr<std::thread> server_thread_;

View File

@ -46,9 +46,9 @@ TcpClient::TcpClient(const std::string &address, std::uint16_t port)
server_port_(port),
is_stop_(true),
is_connected_(false) {
message_handler_.SetCallback([this](const CommMessage &message) {
message_handler_.SetCallback([this](std::shared_ptr<CommMessage> message) {
if (message_callback_) {
message_callback_(*this, message);
message_callback_(*this, *message);
}
});
}
@ -105,7 +105,7 @@ void TcpClient::Init() {
sin.sin_addr.s_addr = inet_addr(server_address_.c_str());
sin.sin_port = htons(server_port_);
buffer_event_ = bufferevent_socket_new(event_base_, -1, BEV_OPT_CLOSE_ON_FREE);
buffer_event_ = bufferevent_socket_new(event_base_, -1, BEV_OPT_CLOSE_ON_FREE | BEV_OPT_THREADSAFE);
MS_EXCEPTION_IF_NULL(buffer_event_);
bufferevent_setcb(buffer_event_, ReadCallback, nullptr, EventCallback, this);
@ -261,17 +261,23 @@ void TcpClient::StartWithNoBlock() {
void TcpClient::SetMessageCallback(const OnMessage &cb) { message_callback_ = cb; }
void TcpClient::SendMessage(const CommMessage &message) const {
bool TcpClient::SendMessage(const CommMessage &message) const {
MS_EXCEPTION_IF_NULL(buffer_event_);
bufferevent_lock(buffer_event_);
bool res = true;
size_t buf_size = message.ByteSizeLong();
std::vector<unsigned char> serialized(buf_size);
message.SerializeToArray(serialized.data(), SizeToInt(buf_size));
if (evbuffer_add(bufferevent_get_output(buffer_event_), &buf_size, sizeof(buf_size)) == -1) {
MS_LOG(EXCEPTION) << "Event buffer add header failed!";
if (bufferevent_write(buffer_event_, &buf_size, sizeof(buf_size)) == -1) {
MS_LOG(ERROR) << "Event buffer add header failed!";
res = false;
}
if (evbuffer_add(bufferevent_get_output(buffer_event_), serialized.data(), buf_size) == -1) {
MS_LOG(EXCEPTION) << "Event buffer add protobuf data failed!";
if (bufferevent_write(buffer_event_, serialized.data(), buf_size) == -1) {
MS_LOG(ERROR) << "Event buffer add protobuf data failed!";
res = false;
}
bufferevent_unlock(buffer_event_);
return res;
}
void TcpClient::StartTimer(const uint32_t &time) {

View File

@ -33,8 +33,6 @@
#include <condition_variable>
#include "ps/core/cluster_config.h"
#include "proto/comm.pb.h"
#include "proto/ps.pb.h"
#include "utils/convert_utils_base.h"
namespace mindspore {
@ -62,7 +60,7 @@ class TcpClient {
void Start();
void StartWithNoBlock();
void SetMessageCallback(const OnMessage &cb);
void SendMessage(const CommMessage &message) const;
bool SendMessage(const CommMessage &message) const;
void StartTimer(const uint32_t &time);
void set_timer_callback(const OnTimer &timer);
const event_base &eventbase();

View File

@ -57,8 +57,8 @@ void TcpMessageHandler::ReceiveMessage(const void *buffer, size_t num) {
}
if (remaining_length_ == 0) {
CommMessage pb_message;
pb_message.ParseFromArray(message_buffer_.get(), message_length_);
std::shared_ptr<CommMessage> pb_message = std::make_shared<CommMessage>();
pb_message->ParseFromArray(message_buffer_.get(), message_length_);
if (message_callback_) {
message_callback_(pb_message);
}

View File

@ -30,7 +30,7 @@
namespace mindspore {
namespace ps {
namespace core {
using messageReceive = std::function<void(const CommMessage &message)>;
using messageReceive = std::function<void(std::shared_ptr<CommMessage>)>;
constexpr int kHeaderLen = 8;
class TcpMessageHandler {

View File

@ -32,14 +32,7 @@
namespace mindspore {
namespace ps {
namespace core {
void TcpConnection::InitConnection() {
tcp_message_handler_.SetCallback([&](const CommMessage &message) {
OnServerReceiveMessage on_server_receive = server_->GetServerReceive();
if (on_server_receive) {
on_server_receive(*server_, *this, message);
}
});
}
void TcpConnection::InitConnection(const messageReceive &callback) { tcp_message_handler_.SetCallback(callback); }
void TcpConnection::OnReadHandler(const void *buffer, size_t num) { tcp_message_handler_.ReceiveMessage(buffer, num); }
@ -49,23 +42,30 @@ void TcpConnection::SendMessage(const void *buffer, size_t num) const {
}
}
TcpServer *TcpConnection::GetServer() const { return const_cast<TcpServer *>(server_); }
TcpServer *TcpConnection::GetServer() const { return server_; }
const evutil_socket_t &TcpConnection::GetFd() const { return fd_; }
void TcpConnection::SendMessage(const CommMessage &message) const {
void TcpConnection::set_callback(const Callback &callback) { callback_ = callback; }
bool TcpConnection::SendMessage(std::shared_ptr<CommMessage> message) const {
MS_EXCEPTION_IF_NULL(buffer_event_);
size_t buf_size = message.ByteSizeLong();
MS_EXCEPTION_IF_NULL(message);
bufferevent_lock(buffer_event_);
bool res = true;
size_t buf_size = message->ByteSizeLong();
std::vector<unsigned char> serialized(buf_size);
message.SerializeToArray(serialized.data(), SizeToInt(buf_size));
if (evbuffer_add(bufferevent_get_output(const_cast<struct bufferevent *>(buffer_event_)), &buf_size,
sizeof(buf_size)) == -1) {
MS_LOG(EXCEPTION) << "Event buffer add header failed!";
message->SerializeToArray(serialized.data(), SizeToInt(buf_size));
if (bufferevent_write(buffer_event_, &buf_size, sizeof(buf_size)) == -1) {
MS_LOG(ERROR) << "Event buffer add header failed!";
res = false;
}
if (evbuffer_add(bufferevent_get_output(const_cast<struct bufferevent *>(buffer_event_)), serialized.data(),
buf_size) == -1) {
MS_LOG(EXCEPTION) << "Event buffer add protobuf data failed!";
if (bufferevent_write(buffer_event_, serialized.data(), buf_size) == -1) {
MS_LOG(ERROR) << "Event buffer add protobuf data failed!";
res = false;
}
bufferevent_unlock(buffer_event_);
return res;
}
TcpServer::TcpServer(const std::string &address, std::uint16_t port)
@ -225,7 +225,7 @@ void TcpServer::SendToAllClients(const char *data, size_t len) {
}
}
void TcpServer::AddConnection(const evutil_socket_t &fd, const TcpConnection *connection) {
void TcpServer::AddConnection(const evutil_socket_t &fd, std::shared_ptr<TcpConnection> connection) {
MS_EXCEPTION_IF_NULL(connection);
std::lock_guard<std::mutex> lock(connection_mutex_);
connections_.insert(std::make_pair(fd, connection));
@ -233,11 +233,11 @@ void TcpServer::AddConnection(const evutil_socket_t &fd, const TcpConnection *co
void TcpServer::RemoveConnection(const evutil_socket_t &fd) {
std::lock_guard<std::mutex> lock(connection_mutex_);
TcpConnection *connection = const_cast<TcpConnection *>(connections_.find(fd)->second);
delete connection;
connections_.erase(fd);
}
std::shared_ptr<TcpConnection> TcpServer::GetConnectionByFd(const evutil_socket_t &fd) { return connections_[fd]; }
void TcpServer::ListenerCallback(struct evconnlistener *, evutil_socket_t fd, struct sockaddr *sockaddr, int,
void *data) {
auto server = reinterpret_cast<class TcpServer *>(data);
@ -246,7 +246,7 @@ void TcpServer::ListenerCallback(struct evconnlistener *, evutil_socket_t fd, st
MS_EXCEPTION_IF_NULL(base);
MS_EXCEPTION_IF_NULL(sockaddr);
struct bufferevent *bev = bufferevent_socket_new(base, fd, BEV_OPT_CLOSE_ON_FREE);
struct bufferevent *bev = bufferevent_socket_new(base, fd, BEV_OPT_CLOSE_ON_FREE | BEV_OPT_THREADSAFE);
if (!bev) {
MS_LOG(ERROR) << "Error constructing buffer event!";
int ret = event_base_loopbreak(base);
@ -256,23 +256,29 @@ void TcpServer::ListenerCallback(struct evconnlistener *, evutil_socket_t fd, st
return;
}
TcpConnection *conn = server->onCreateConnection(bev, fd);
std::shared_ptr<TcpConnection> conn = server->onCreateConnection(bev, fd);
MS_EXCEPTION_IF_NULL(conn);
conn->InitConnection();
server->AddConnection(fd, conn);
bufferevent_setcb(bev, TcpServer::ReadCallback, nullptr, TcpServer::EventCallback, reinterpret_cast<void *>(conn));
conn->InitConnection([=](std::shared_ptr<CommMessage> message) {
OnServerReceiveMessage on_server_receive = server->GetServerReceive();
if (on_server_receive) {
on_server_receive(conn, message);
}
});
bufferevent_setcb(bev, TcpServer::ReadCallback, nullptr, TcpServer::EventCallback,
reinterpret_cast<void *>(conn.get()));
if (bufferevent_enable(bev, EV_READ | EV_WRITE) == -1) {
MS_LOG(EXCEPTION) << "Buffer event enable read and write failed!";
}
}
TcpConnection *TcpServer::onCreateConnection(struct bufferevent *bev, const evutil_socket_t &fd) {
TcpConnection *conn = nullptr;
std::shared_ptr<TcpConnection> TcpServer::onCreateConnection(struct bufferevent *bev, const evutil_socket_t &fd) {
std::shared_ptr<TcpConnection> conn = nullptr;
if (client_accept_) {
conn = const_cast<TcpConnection *>(client_accept_(*this));
conn = (client_accept_(*this));
} else {
conn = new TcpConnection(bev, fd, this);
conn = std::make_shared<TcpConnection>(bev, fd, this);
}
return conn;
@ -312,8 +318,8 @@ void TcpServer::EventCallback(struct bufferevent *bev, std::int16_t events, void
MS_EXCEPTION_IF_NULL(data);
struct evbuffer *output = bufferevent_get_output(bev);
size_t remain = evbuffer_get_length(output);
auto conn = reinterpret_cast<TcpConnection *>(data);
TcpServer *srv = conn->GetServer();
auto conn = static_cast<class TcpConnection *>(data);
auto srv = conn->GetServer();
if (events & BEV_EVENT_EOF) {
MS_LOG(INFO) << "Event buffer end of file!";
@ -355,13 +361,18 @@ void TcpServer::TimerOnceCallback(evutil_socket_t, int16_t, void *arg) {
}
}
void TcpServer::SendMessage(const TcpConnection &conn, const CommMessage &message) { conn.SendMessage(message); }
bool TcpServer::SendMessage(std::shared_ptr<TcpConnection> conn, std::shared_ptr<CommMessage> message) {
MS_EXCEPTION_IF_NULL(conn);
MS_EXCEPTION_IF_NULL(message);
return conn->SendMessage(message);
}
void TcpServer::SendMessage(const CommMessage &message) {
void TcpServer::SendMessage(std::shared_ptr<CommMessage> message) {
std::lock_guard<std::mutex> lock(connection_mutex_);
MS_EXCEPTION_IF_NULL(message);
for (auto it = connections_.begin(); it != connections_.end(); ++it) {
SendMessage(*it->second, message);
SendMessage(it->second, message);
}
}
@ -371,7 +382,7 @@ std::string TcpServer::BoundIp() const { return server_address_; }
int TcpServer::ConnectionNum() const { return connections_.size(); }
const std::map<evutil_socket_t, const TcpConnection *> &TcpServer::Connections() const { return connections_; }
const std::map<evutil_socket_t, std::shared_ptr<TcpConnection>> &TcpServer::Connections() const { return connections_; }
void TcpServer::SetMessageCallback(const OnServerReceiveMessage &cb) { message_callback_ = cb; }

View File

@ -34,8 +34,6 @@
#include <thread>
#include <atomic>
#include "proto/comm.pb.h"
#include "proto/ps.pb.h"
#include "ps/core/tcp_message_handler.h"
#include "ps/core/cluster_config.h"
#include "utils/log_adapter.h"
@ -47,36 +45,42 @@ namespace core {
class TcpServer;
class TcpConnection {
public:
explicit TcpConnection(struct bufferevent *bev, const evutil_socket_t &fd, const TcpServer *server)
explicit TcpConnection(struct bufferevent *bev, const evutil_socket_t &fd, TcpServer *server)
: buffer_event_(bev), fd_(fd), server_(server) {}
TcpConnection(const TcpConnection &);
virtual ~TcpConnection() = default;
virtual void InitConnection();
using Callback = std::function<void(const std::shared_ptr<CommMessage>)>;
virtual void InitConnection(const messageReceive &callback);
virtual void SendMessage(const void *buffer, size_t num) const;
void SendMessage(const CommMessage &message) const;
bool SendMessage(std::shared_ptr<CommMessage> message) const;
virtual void OnReadHandler(const void *buffer, size_t numBytes);
TcpServer *GetServer() const;
const evutil_socket_t &GetFd() const;
void set_callback(const Callback &callback);
protected:
struct bufferevent *buffer_event_;
evutil_socket_t fd_;
const TcpServer *server_;
TcpServer *server_;
TcpMessageHandler tcp_message_handler_;
Callback callback_;
};
using OnServerReceiveMessage =
std::function<void(const TcpServer &tcp_server, const TcpConnection &conn, const CommMessage &)>;
std::function<void(std::shared_ptr<TcpConnection> conn, std::shared_ptr<CommMessage> message)>;
class TcpServer {
public:
using OnConnected = std::function<void(const TcpServer &, const TcpConnection &)>;
using OnDisconnected = std::function<void(const TcpServer &, const TcpConnection &)>;
using OnAccepted = std::function<const TcpConnection *(const TcpServer &)>;
using OnAccepted = std::function<std::shared_ptr<TcpConnection>(const TcpServer &)>;
using OnTimerOnce = std::function<void(const TcpServer &)>;
using OnTimer = std::function<void()>;
explicit TcpServer(const std::string &address, std::uint16_t port);
TcpServer(const std::string &address, std::uint16_t port);
TcpServer(const TcpServer &server);
virtual ~TcpServer();
void SetServerCallback(const OnConnected &client_conn, const OnDisconnected &client_disconn,
@ -90,16 +94,17 @@ class TcpServer {
void StartTimer(const uint32_t &time);
void Stop();
void SendToAllClients(const char *data, size_t len);
void AddConnection(const evutil_socket_t &fd, const TcpConnection *connection);
void AddConnection(const evutil_socket_t &fd, std::shared_ptr<TcpConnection> connection);
void RemoveConnection(const evutil_socket_t &fd);
std::shared_ptr<TcpConnection> GetConnectionByFd(const evutil_socket_t &fd);
OnServerReceiveMessage GetServerReceive() const;
void SetMessageCallback(const OnServerReceiveMessage &cb);
void SendMessage(const TcpConnection &conn, const CommMessage &message);
void SendMessage(const CommMessage &message);
bool SendMessage(std::shared_ptr<TcpConnection> conn, std::shared_ptr<CommMessage> message);
void SendMessage(std::shared_ptr<CommMessage> message);
uint16_t BoundPort() const;
std::string BoundIp() const;
int ConnectionNum() const;
const std::map<evutil_socket_t, const TcpConnection *> &Connections() const;
const std::map<evutil_socket_t, std::shared_ptr<TcpConnection>> &Connections() const;
protected:
static void ListenerCallback(struct evconnlistener *listener, evutil_socket_t socket, struct sockaddr *saddr,
@ -109,7 +114,7 @@ class TcpServer {
static void EventCallback(struct bufferevent *, std::int16_t events, void *server);
static void TimerCallback(evutil_socket_t fd, int16_t event, void *arg);
static void TimerOnceCallback(evutil_socket_t fd, int16_t event, void *arg);
virtual TcpConnection *onCreateConnection(struct bufferevent *bev, const evutil_socket_t &fd);
std::shared_ptr<TcpConnection> onCreateConnection(struct bufferevent *bev, const evutil_socket_t &fd);
struct event_base *base_;
struct event *signal_event_;
@ -118,7 +123,7 @@ class TcpServer {
std::uint16_t server_port_;
std::atomic<bool> is_stop_;
std::map<evutil_socket_t, const TcpConnection *> connections_;
std::map<evutil_socket_t, std::shared_ptr<TcpConnection>> connections_;
OnConnected client_connection_;
OnDisconnected client_disconnection_;
OnAccepted client_accept_;

View File

@ -24,8 +24,6 @@
#include <utility>
#include <algorithm>
#include "proto/comm.pb.h"
#include "proto/ps.pb.h"
#include "ps/core/cluster_config.h"
#include "ps/core/tcp_client.h"
#include "ps/core/tcp_server.h"

View File

@ -31,7 +31,7 @@ class TestClusterAvailableTimeout : public UT::Common {
};
TEST_F(TestClusterAvailableTimeout, TestClusterAvailableTimeout) {
ClusterConfig::Init(1, 1, std::make_unique<std::string>("127.0.0.1"), 9999);
ClusterConfig::Init(1, 1, "127.0.0.1", 9999);
ClusterConfig::set_cluster_available_timeout(3);
SchedulerNode node;
node.Start();

View File

@ -33,7 +33,7 @@ class TestClusterConfig : public UT::Common {
};
TEST_F(TestClusterConfig, HeartbeatInterval) {
ClusterConfig::Init(2, 2, std::make_unique<std::string>("127.0.0.1"), 8080);
ClusterConfig::Init(2, 2, "127.0.0.1", 8080);
EXPECT_TRUE(ClusterConfig::heartbeat_interval() == 3);
ClusterConfig::set_heartbeat_interval(100);
EXPECT_TRUE(ClusterConfig::heartbeat_interval() == 100);

View File

@ -53,7 +53,7 @@ TEST_F(TestCommUtil, GetAvailableInterfaceAndIP) {
}
TEST_F(TestCommUtil, ValidateRankId) {
ClusterConfig::Init(3, 2, std::make_unique<std::string>("127.0.0.1"), 9999);
ClusterConfig::Init(3, 2, "127.0.0.1", 9999);
EXPECT_TRUE(CommUtil::ValidateRankId(NodeRole::WORKER, 2));
EXPECT_FALSE(CommUtil::ValidateRankId(NodeRole::WORKER, 3));
EXPECT_TRUE(CommUtil::ValidateRankId(NodeRole::SERVER, 1));

View File

@ -35,7 +35,7 @@ class TestTcpMessageHandler : public UT::Common {
TEST_F(TestTcpMessageHandler, 8_Header_1003_Data) {
TcpMessageHandler handler;
handler.SetCallback([this](const CommMessage &message) { EXPECT_EQ(message.data().size(), 1000); });
handler.SetCallback([this](std::shared_ptr<CommMessage> message) { EXPECT_EQ(message->data().size(), 1000); });
std::string data(1000, 'a');
CommMessage message;
@ -55,7 +55,7 @@ TEST_F(TestTcpMessageHandler, 8_Header_1003_Data) {
TEST_F(TestTcpMessageHandler, 8_Header_1003_Data_8_Header_1003_Data) {
TcpMessageHandler handler;
handler.SetCallback([this](const CommMessage &message) { EXPECT_EQ(message.data().size(), 1000); });
handler.SetCallback([this](std::shared_ptr<CommMessage> message) { EXPECT_EQ(message->data().size(), 1000); });
std::string data(1000, 'a');
CommMessage message;
@ -86,7 +86,7 @@ TEST_F(TestTcpMessageHandler, 8_Header_1003_Data_8_Header_1003_Data) {
TEST_F(TestTcpMessageHandler, 8_Header_4084_Data_4_Header_4_header_4084_data) {
TcpMessageHandler handler;
handler.SetCallback([this](const CommMessage &message) { EXPECT_EQ(message.data().size(), 4081); });
handler.SetCallback([this](std::shared_ptr<CommMessage> message) { EXPECT_EQ(message->data().size(), 4081); });
std::string data(4081, 'a');
CommMessage message;
@ -126,7 +126,7 @@ TEST_F(TestTcpMessageHandler, 8_Header_4084_Data_4_Header_4_header_4084_data) {
TEST_F(TestTcpMessageHandler, 8_Header_4080_Data_8_Header_4080_data) {
TcpMessageHandler handler;
handler.SetCallback([this](const CommMessage &message) { EXPECT_EQ(message.data().size(), 4077); });
handler.SetCallback([this](std::shared_ptr<CommMessage> message) { EXPECT_EQ(message->data().size(), 4077); });
std::string data(4077, 'a');
CommMessage message;

View File

@ -32,12 +32,12 @@ class TestTcpServer : public UT::Common {
void SetUp() override {
server_ = std::make_unique<TcpServer>("127.0.0.1", 0);
std::unique_ptr<std::thread> http_server_thread_(nullptr);
http_server_thread_ = std::make_unique<std::thread>([&]() {
server_->SetMessageCallback([](const TcpServer &server, const TcpConnection &conn, const CommMessage &message) {
http_server_thread_ = std::make_unique<std::thread>([=]() {
server_->SetMessageCallback([=](std::shared_ptr<TcpConnection> conn, std::shared_ptr<CommMessage> message) {
KVMessage kv_message;
kv_message.ParseFromString(message.data());
kv_message.ParseFromString(message->data());
EXPECT_EQ(2, kv_message.keys_size());
const_cast<TcpServer&>(server).SendMessage(conn, message);
server_->SendMessage(conn, message);
});
server_->Init();
server_->Start();
@ -58,6 +58,7 @@ class TestTcpServer : public UT::Common {
TEST_F(TestTcpServer, ServerSendMessage) {
client_ = std::make_unique<TcpClient>("127.0.0.1", server_->BoundPort());
std::cout << server_->BoundPort() << std::endl;
std::unique_ptr<std::thread> http_client_thread(nullptr);
http_client_thread = std::make_unique<std::thread>([&]() {
client_->SetMessageCallback([](const TcpClient &client, const CommMessage &message) {