Custom data transmission format

This commit is contained in:
chendongsheng 2021-01-21 14:51:17 +08:00
parent 14a6713d08
commit c7fe82b43d
17 changed files with 634 additions and 393 deletions

View File

@ -20,8 +20,9 @@ namespace mindspore {
namespace ps {
namespace core {
void AbstractNode::Register(const std::shared_ptr<TcpClient> &client) {
MessageMeta message_meta;
message_meta.set_cmd(NodeCommand::REGISTER);
MS_EXCEPTION_IF_NULL(client);
auto message_meta = std::make_shared<MessageMeta>();
message_meta->set_cmd(NodeCommand::REGISTER);
RegisterMessage register_message;
register_message.set_node_id(node_info_.node_id_);
@ -29,11 +30,8 @@ void AbstractNode::Register(const std::shared_ptr<TcpClient> &client) {
register_message.set_ip(node_info_.ip_);
register_message.set_port(node_info_.port_);
CommMessage comm_message;
*comm_message.mutable_pb_meta() = {message_meta};
comm_message.set_data(register_message.SerializeAsString());
comm_message.set_user_cmd("");
if (!SendMessageSync(client, comm_message)) {
if (!SendMessageSync(client, message_meta, Protos::PROTOBUF, register_message.SerializeAsString().data(),
register_message.ByteSizeLong())) {
MS_LOG(EXCEPTION) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
<< " the node id:" << node_info_.node_id_ << " register timeout!";
}
@ -42,9 +40,11 @@ void AbstractNode::Register(const std::shared_ptr<TcpClient> &client) {
<< " the node id:" << node_info_.node_id_ << "is registering to scheduler!";
}
void AbstractNode::ProcessRegisterResp(const CommMessage &message) {
void AbstractNode::ProcessRegisterResp(std::shared_ptr<MessageMeta> meta, const void *data, size_t size) {
MS_EXCEPTION_IF_NULL(meta);
MS_EXCEPTION_IF_NULL(data);
RegisterRespMessage register_resp_message;
register_resp_message.ParseFromString(message.data());
register_resp_message.ParseFromArray(data, size);
if (register_resp_message.node_id() != node_info_.node_id_) {
MS_LOG(EXCEPTION) << "The node id received:" << register_resp_message.node_id()
<< " is not match the current node id:" << node_info_.node_id_;
@ -52,28 +52,29 @@ void AbstractNode::ProcessRegisterResp(const CommMessage &message) {
node_info_.rank_id_ = register_resp_message.rank_id();
MS_LOG(INFO) << "The node id is:" << node_info_.node_id_ << ", and the rank id is:" << node_info_.rank_id_;
MS_LOG(INFO) << "The node id is:" << node_info_.node_id_ << ", and the rank id is:" << node_info_.rank_id_
<< " registered scheduler success!";
}
bool AbstractNode::Broadcast(const enum NodeRole &node_role, const CommMessage &message, const uint32_t &timeout) {
bool AbstractNode::Broadcast(const enum NodeRole &node_role, const DataPtr &message, size_t size, int command,
const uint32_t &timeout) {
MS_EXCEPTION_IF_NULL(message);
if (node_role != NodeRole::SERVER) {
MS_LOG(EXCEPTION) << "Currently only supports broadcast to server nodes";
}
CommMessage &comm_message = const_cast<CommMessage &>(message);
uint64_t request_id = ++next_request_id_;
message_tracker_[request_id] = std::make_pair(nodes_address_.size(), 0);
uint64_t request_id = AddMessageTrack(nodes_address_.size());
for (auto it = nodes_address_.begin(); it != nodes_address_.end(); ++it) {
MessageMeta message_meta;
message_meta.set_cmd(NodeCommand::SEND_DATA);
message_meta.set_request_id(request_id);
message_meta.set_rank_id(node_info_.rank_id_);
message_meta.set_role(node_info_.node_role_);
auto message_meta = std::make_shared<MessageMeta>();
message_meta->set_cmd(NodeCommand::SEND_DATA);
message_meta->set_request_id(request_id);
message_meta->set_rank_id(node_info_.rank_id_);
message_meta->set_role(node_info_.node_role_);
message_meta->set_user_cmd(command);
*comm_message.mutable_pb_meta() = {message_meta};
auto client = GetOrCreateTcpClient((*it).first.second);
client->SendMessage(comm_message);
client->SendMessage(message_meta, Protos::RAW, message.get(), size);
}
MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_)
<< ", the node id is:" << node_info_.node_id_ << " send the request id is:" << request_id;
@ -84,28 +85,27 @@ void AbstractNode::set_event_callback(const OnNodeEventMessage &on_node_event_me
on_node_event_message_ = on_node_event_message;
}
bool AbstractNode::Send(const enum NodeRole &node_role, const uint32_t &rank_id, const CommMessage &message,
const uint32_t &timeout) {
bool AbstractNode::Send(const enum NodeRole &node_role, const uint32_t &rank_id, const DataPtr &data, size_t len,
int command, const uint32_t &timeout) {
MS_EXCEPTION_IF_NULL(data);
if (!CommUtil::ValidateRankId(node_role, rank_id)) {
MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!";
}
CommMessage &comm_message = const_cast<CommMessage &>(message);
auto message_meta = std::make_shared<MessageMeta>();
message_meta->set_cmd(NodeCommand::SEND_DATA);
message_meta->set_rank_id(node_info_.rank_id_);
message_meta->set_role(node_info_.node_role_);
message_meta->set_user_cmd(command);
MessageMeta message_meta;
message_meta.set_cmd(NodeCommand::SEND_DATA);
message_meta.set_rank_id(node_info_.rank_id_);
message_meta.set_role(node_info_.node_role_);
*comm_message.mutable_pb_meta() = {message_meta};
auto client = GetOrCreateTcpClient(rank_id);
return SendMessageSync(client, comm_message, timeout);
return SendMessageSync(client, message_meta, Protos::RAW, data.get(), len, timeout);
}
bool AbstractNode::Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids,
const std::vector<CommMessage> &data, const uint32_t &timeout) {
uint64_t request_id = ++next_request_id_;
message_tracker_[request_id] = std::make_pair(data.size(), 0);
const std::vector<DataPtr> &data, const std::vector<size_t> &lens, int command,
const uint32_t &timeout) {
uint64_t request_id = AddMessageTrack(data.size());
if (rank_ids.size() != data.size()) {
MS_LOG(EXCEPTION) << "The number of rank ids is not equal to the number of data!";
@ -115,34 +115,32 @@ bool AbstractNode::Send(const NodeRole &node_role, const std::vector<uint32_t> &
MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!";
}
MessageMeta message_meta;
message_meta.set_cmd(NodeCommand::SEND_DATA);
message_meta.set_request_id(request_id);
message_meta.set_rank_id(node_info_.rank_id_);
message_meta.set_role(node_info_.node_role_);
CommMessage &comm_message = const_cast<CommMessage &>(data.at(it));
*comm_message.mutable_pb_meta() = {message_meta};
auto message_meta = std::make_shared<MessageMeta>();
message_meta->set_cmd(NodeCommand::SEND_DATA);
message_meta->set_request_id(request_id);
message_meta->set_rank_id(node_info_.rank_id_);
message_meta->set_role(node_info_.node_role_);
message_meta->set_user_cmd(command);
auto send = data.at(it);
auto len = lens.at(it);
auto client = GetOrCreateTcpClient(rank_ids.at(it));
client->SendMessage(comm_message);
client->SendMessage(message_meta, Protos::RAW, send.get(), len);
}
MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_)
<< ", the node id is:" << node_info_.node_id_ << " send the request id is:" << request_id;
return Wait(request_id, timeout);
}
bool AbstractNode::Send(const enum NodeRole &node_role, const uint32_t &rank_id, const CommMessage &message,
CommMessage *output, const uint32_t &timeout) {
bool AbstractNode::Send(const enum NodeRole &node_role, const uint32_t &rank_id, const DataPtr &message, size_t len,
int command, VectorPtr *output, const uint32_t &timeout) {
MS_EXCEPTION_IF_NULL(message);
MS_EXCEPTION_IF_NULL(output);
if (!CommUtil::ValidateRankId(node_role, rank_id)) {
MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!";
}
CommMessage &comm_message = const_cast<CommMessage &>(message);
uint64_t request_id = ++next_request_id_;
message_tracker_[request_id] = std::make_pair(1, 0);
uint64_t request_id = AddMessageTrack(1);
set_message_callback(request_id, [&]() {
receive_messages_mutex_.lock();
auto res = receive_messages_[request_id];
@ -151,59 +149,59 @@ bool AbstractNode::Send(const enum NodeRole &node_role, const uint32_t &rank_id,
receive_messages_mutex_.unlock();
});
MessageMeta message_meta;
message_meta.set_cmd(NodeCommand::SEND_DATA);
message_meta.set_request_id(request_id);
message_meta.set_rank_id(node_info_.rank_id_);
message_meta.set_role(node_info_.node_role_);
auto message_meta = std::make_shared<MessageMeta>();
message_meta->set_cmd(NodeCommand::SEND_DATA);
message_meta->set_request_id(request_id);
message_meta->set_rank_id(node_info_.rank_id_);
message_meta->set_role(node_info_.node_role_);
message_meta->set_user_cmd(command);
*comm_message.mutable_pb_meta() = {message_meta};
auto client = GetOrCreateTcpClient(rank_id);
client->SendMessage(comm_message);
client->SendMessage(message_meta, Protos::RAW, message.get(), len);
MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_)
<< ", the node id is:" << node_info_.node_id_ << " send the request id is:" << request_id;
return Wait(request_id, timeout);
}
bool AbstractNode::Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids,
const std::vector<CommMessage> &data, std::vector<CommMessage> *output,
const uint32_t &timeout) {
const std::vector<DataPtr> &data, const std::vector<size_t> &data_lens, int command,
std::vector<VectorPtr> *output, const uint32_t &timeout) {
MS_EXCEPTION_IF_NULL(output);
uint64_t request_id = ++next_request_id_;
message_tracker_[request_id] = std::make_pair(data.size(), 0);
uint64_t request_id = AddMessageTrack(data.size());
if (rank_ids.size() != data.size()) {
MS_LOG(EXCEPTION) << "The number of rank ids, data, comm_message_resp should be equal!";
}
size_t len = rank_ids.size();
size_t size = rank_ids.size();
set_message_callback(request_id, [&]() {
receive_messages_mutex_.lock();
auto res = receive_messages_[request_id];
for (size_t it = 0; it < len; ++it) {
for (size_t it = 0; it < size; ++it) {
(*output).push_back(res[rank_ids.at(it)]);
}
receive_messages_.erase(request_id);
receive_messages_mutex_.unlock();
});
for (size_t it = 0; it < len; ++it) {
for (size_t it = 0; it < size; ++it) {
if (!CommUtil::ValidateRankId(node_role, rank_ids.at(it))) {
MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!";
}
MessageMeta message_meta;
message_meta.set_cmd(NodeCommand::SEND_DATA);
message_meta.set_request_id(request_id);
message_meta.set_rank_id(node_info_.rank_id_);
message_meta.set_role(node_info_.node_role_);
auto message_meta = std::make_shared<MessageMeta>();
message_meta->set_cmd(NodeCommand::SEND_DATA);
message_meta->set_request_id(request_id);
message_meta->set_rank_id(node_info_.rank_id_);
message_meta->set_role(node_info_.node_role_);
message_meta->set_user_cmd(command);
CommMessage &comm_message = const_cast<CommMessage &>(data.at(it));
*comm_message.mutable_pb_meta() = {message_meta};
auto send = data.at(it);
auto len = data_lens.at(it);
auto client = GetOrCreateTcpClient(rank_ids.at(it));
client->SendMessage(comm_message);
client->SendMessage(message_meta, Protos::RAW, send.get(), len);
}
MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_)
<< ", the node id is:" << node_info_.node_id_ << " send the request id is:" << request_id;
@ -220,55 +218,61 @@ bool AbstractNode::Wait(uint64_t request_id, const uint32_t &timeout) {
return res;
}
uint64_t AbstractNode::CollectiveSendAsync(const enum NodeRole &node_role, const uint32_t &rank_id,
const CommMessage &message) {
uint64_t AbstractNode::CollectiveSendAsync(const enum NodeRole &node_role, const uint32_t &rank_id, const void *data,
size_t size) {
MS_EXCEPTION_IF_NULL(data);
if (!CommUtil::ValidateRankId(node_role, rank_id)) {
MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!";
}
CommMessage &comm_message = const_cast<CommMessage &>(message);
std::shared_ptr<MessageMeta> message_meta = std::make_shared<MessageMeta>();
message_meta->set_cmd(NodeCommand::COLLECTIVE_SEND_DATA);
message_meta->set_rank_id(node_info_.rank_id_);
message_meta->set_role(node_info_.node_role_);
MessageMeta message_meta;
message_meta.set_cmd(NodeCommand::COLLECTIVE_SEND_DATA);
message_meta.set_rank_id(node_info_.rank_id_);
message_meta.set_role(node_info_.node_role_);
*comm_message.mutable_pb_meta() = {message_meta};
auto client = GetOrCreateTcpClient(rank_id);
return SendMessageAsync(client, comm_message);
return SendMessageAsync(client, message_meta, Protos::RAW, data, size);
}
std::pair<uint32_t, uint64_t> AbstractNode::CollectiveReceiveAsync(const enum NodeRole &node_role,
const uint32_t &rank_id, CommMessage *output) {
const uint32_t &rank_id, void **output,
size_t *size) {
MS_EXCEPTION_IF_NULL(output);
MS_EXCEPTION_IF_NULL(size);
if (!CommUtil::ValidateRankId(node_role, rank_id)) {
MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!";
}
receive_callbacks_mutex_.lock();
uint64_t rank_request_id = NextExpectedRankRequestId(rank_id);
receive_messages_done_[std::make_pair(rank_id, rank_request_id)] = false;
if (received_data_.count(std::make_pair(rank_id, rank_request_id)) > 0) {
*output = received_data_[std::make_pair(rank_id, rank_request_id)];
auto res = received_data_[std::make_pair(rank_id, rank_request_id)];
*output = res->data();
*size = res->size();
received_data_.erase(std::make_pair(rank_id, rank_request_id));
receive_messages_done_[std::make_pair(rank_id, rank_request_id)] = true;
MS_LOG(DEBUG) << "Receive data from rank id:" << rank_id << ", the rank request id is:" << rank_request_id;
} else {
set_receive_callback(rank_id, rank_request_id, [=]() {
receive_callbacks_[std::make_pair(rank_id, rank_request_id)] = [=]() mutable {
receive_callbacks_mutex_.lock();
*output = received_data_[std::make_pair(rank_id, rank_request_id)];
auto res = received_data_[std::make_pair(rank_id, rank_request_id)];
*output = res->data();
*size = res->size();
received_data_.erase(std::make_pair(rank_id, rank_request_id));
receive_messages_done_[std::make_pair(rank_id, rank_request_id)] = true;
MS_LOG(DEBUG) << "Receive data from rank id:" << rank_id << ", the rank request id is:" << rank_request_id;
receive_callbacks_mutex_.unlock();
});
};
}
receive_callbacks_mutex_.unlock();
return std::make_pair(rank_id, rank_request_id);
}
bool AbstractNode::CollectiveWait(std::pair<uint32_t, uint64_t> request_id, const uint32_t &timeout) {
std::unique_lock<std::mutex> lock(receive_callbacks_mutex_);
bool res = receive_cond_.wait_for(lock, std::chrono::seconds(timeout), [&] {
if (actual_rank_request_ids_.count(request_id.first) &&
(actual_rank_request_ids_[request_id.first] >= request_id.second)) {
return true;
} else {
return false;
}
});
bool res =
receive_cond_.wait_for(lock, std::chrono::seconds(timeout), [&] { return receive_messages_done_[request_id]; });
return res;
}
@ -297,17 +301,15 @@ void AbstractNode::StartHeartbeatTimer(const std::shared_ptr<TcpClient> &client)
}
bool AbstractNode::Heartbeat(const std::shared_ptr<TcpClient> &client, bool is_node_finish) {
MessageMeta meta;
meta.set_cmd(NodeCommand::HEARTBEAT);
auto meta = std::make_shared<MessageMeta>();
meta->set_cmd(NodeCommand::HEARTBEAT);
HeartbeatMessage heartbeat_message;
heartbeat_message.set_node_id(node_info_.node_id_);
heartbeat_message.set_is_node_finish(is_node_finish);
CommMessage message;
*message.mutable_pb_meta() = {meta};
message.set_data(heartbeat_message.SerializeAsString());
if (!SendMessageSync(client, message)) {
if (!SendMessageSync(client, meta, Protos::PROTOBUF, heartbeat_message.SerializeAsString().data(),
heartbeat_message.ByteSizeLong())) {
MS_LOG(ERROR) << "The node id:" << node_info_.node_id_ << " Send heartbeat timeout!";
}
return true;
@ -331,9 +333,11 @@ bool AbstractNode::CheckSchedulerTimeout() const {
return false;
}
void AbstractNode::ProcessHeartbeatResp(const CommMessage &message) {
void AbstractNode::ProcessHeartbeatResp(std::shared_ptr<MessageMeta> meta, const void *data, size_t size) {
MS_EXCEPTION_IF_NULL(meta);
MS_EXCEPTION_IF_NULL(data);
HeartbeatRespMessage heartbeat_resp_message;
heartbeat_resp_message.ParseFromString(message.data());
heartbeat_resp_message.ParseFromArray(data, size);
is_ready_ = heartbeat_resp_message.is_cluster_ready();
if (is_ready_.load()) {
@ -359,19 +363,22 @@ void AbstractNode::ProcessHeartbeatResp(const CommMessage &message) {
}
void AbstractNode::FetchServers(const std::shared_ptr<TcpClient> &client) {
MessageMeta meta;
meta.set_cmd(NodeCommand::FETCH_SERVER);
auto meta = std::make_shared<MessageMeta>();
meta->set_cmd(NodeCommand::FETCH_SERVER);
CommMessage message;
*message.mutable_pb_meta() = {meta};
if (!SendMessageSync(client, message)) {
FetchServersMessage fetch_servers;
fetch_servers.set_node_id(node_info_.node_id_);
if (!SendMessageSync(client, meta, Protos::PROTOBUF, fetch_servers.SerializeAsString().data(),
fetch_servers.ByteSizeLong())) {
MS_LOG(EXCEPTION) << "Fetch servers address timeout!";
}
}
void AbstractNode::ProcessFetchServersResp(const CommMessage &message) {
void AbstractNode::ProcessFetchServersResp(std::shared_ptr<MessageMeta> meta, const void *data, size_t size) {
MS_EXCEPTION_IF_NULL(meta);
MS_EXCEPTION_IF_NULL(data);
FetchServersRespMessage fetch_servers_resp_message;
fetch_servers_resp_message.ParseFromString(message.data());
fetch_servers_resp_message.ParseFromArray(data, size);
for (const auto &it : fetch_servers_resp_message.servers_meta()) {
nodes_address_[std::make_pair(NodeRole::SERVER, it.rank_id())] = std::make_pair(it.ip(), it.port());
@ -381,16 +388,14 @@ void AbstractNode::ProcessFetchServersResp(const CommMessage &message) {
}
bool AbstractNode::Disconnect(const std::shared_ptr<TcpClient> &client, const uint32_t &timeout) {
MessageMeta meta;
meta.set_cmd(NodeCommand::FINISH);
auto meta = std::make_shared<MessageMeta>();
meta->set_cmd(NodeCommand::FINISH);
FinishMessage finish_message;
finish_message.set_node_id(node_info_.node_id_);
CommMessage message;
*message.mutable_pb_meta() = {meta};
message.set_data(finish_message.SerializeAsString());
if (!SendMessageSync(client, message)) {
if (!SendMessageSync(client, meta, Protos::PROTOBUF, finish_message.SerializeAsString().data(),
finish_message.ByteSizeLong())) {
MS_LOG(ERROR) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
<< " the node id:" << node_info_.node_id_ << " send Finish Message timeout!";
}
@ -412,16 +417,17 @@ bool AbstractNode::InitClientToScheduler() {
std::string scheduler_host = ClusterConfig::scheduler_host();
uint16_t scheduler_port = ClusterConfig::scheduler_port();
client_to_scheduler_ = std::make_shared<TcpClient>(scheduler_host, scheduler_port);
client_to_scheduler_->SetMessageCallback([&](const TcpClient &client, const CommMessage &message) {
if (handlers_.count(message.pb_meta().cmd()) == 0) {
MS_LOG(EXCEPTION) << "The cmd:" << message.pb_meta().cmd() << " is not supported!";
}
if (handlers_[message.pb_meta().cmd()] != nullptr) {
const auto &handler_ptr = handlers_[message.pb_meta().cmd()];
(this->*handler_ptr)(message);
}
NotifyMessageArrival(message);
});
client_to_scheduler_->SetMessageCallback(
[&](std::shared_ptr<MessageMeta> meta, const Protos &protos, const void *data, size_t size) {
if (handlers_.count(meta->cmd()) == 0) {
MS_LOG(EXCEPTION) << "The cmd:" << meta->cmd() << " is not supported!";
}
if (handlers_[meta->cmd()] != nullptr) {
const auto &handler_ptr = handlers_[meta->cmd()];
(this->*handler_ptr)(meta, data, size);
}
NotifyMessageArrival(meta);
});
client_to_scheduler_->Init();
client_to_scheduler_thread_ = std::make_unique<std::thread>([&]() {
@ -447,19 +453,20 @@ const std::shared_ptr<TcpClient> &AbstractNode::GetOrCreateTcpClient(const int &
std::string ip = nodes_address_[std::make_pair(NodeRole::SERVER, rank_id)].first;
uint16_t port = nodes_address_[std::make_pair(NodeRole::SERVER, rank_id)].second;
auto client = std::make_shared<TcpClient>(ip, port);
client->SetMessageCallback([&](const TcpClient &client, const CommMessage &message) {
switch (message.pb_meta().cmd()) {
client->SetMessageCallback([&](std::shared_ptr<MessageMeta> meta, const Protos &protos, const void *data,
size_t size) {
switch (meta->cmd()) {
case NodeCommand::SEND_DATA:
ProcessSendDataResp(message);
RunMessageCallback(message.pb_meta().request_id());
ProcessSendDataResp(meta, protos, data, size);
RunMessageCallback(meta->request_id());
break;
case NodeCommand::COLLECTIVE_SEND_DATA:
MS_LOG(INFO) << "The Node id:" << node_info_.node_id_ << " receive a collective_send_data message response!";
MS_LOG(DEBUG) << "The Node id:" << node_info_.node_id_ << " receive a collective_send_data message response!";
break;
default:
MS_LOG(EXCEPTION) << "The cmd:" << message.pb_meta().cmd() << " is not supported!";
MS_LOG(EXCEPTION) << "The cmd:" << meta->cmd() << " is not supported!";
}
NotifyMessageArrival(message);
NotifyMessageArrival(meta);
});
client->Init();
connected_nodes_[rank_id] = client;
@ -469,8 +476,7 @@ const std::shared_ptr<TcpClient> &AbstractNode::GetOrCreateTcpClient(const int &
bool AbstractNode::SendMessageSync(const std::shared_ptr<TcpClient> &client, const CommMessage &message,
const uint32_t &timeout) {
uint64_t request_id = ++next_request_id_;
message_tracker_[request_id] = std::make_pair(1, 0);
uint64_t request_id = AddMessageTrack(1);
const_cast<CommMessage &>(message).mutable_pb_meta()->set_request_id(request_id);
client->SendMessage(message);
MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_)
@ -478,29 +484,55 @@ bool AbstractNode::SendMessageSync(const std::shared_ptr<TcpClient> &client, con
return Wait(request_id, timeout);
}
uint64_t AbstractNode::SendMessageAsync(const std::shared_ptr<TcpClient> &client, const CommMessage &message) {
uint64_t request_id = ++next_request_id_;
message_tracker_[request_id] = std::make_pair(1, 0);
const_cast<CommMessage &>(message).mutable_pb_meta()->set_request_id(request_id);
client->SendMessage(message);
uint64_t AbstractNode::SendMessageAsync(const std::shared_ptr<TcpClient> &client, std::shared_ptr<MessageMeta> meta,
const Protos &protos, const void *data, size_t size) {
MS_EXCEPTION_IF_NULL(client);
MS_EXCEPTION_IF_NULL(meta);
MS_EXCEPTION_IF_NULL(data);
uint64_t request_id = AddMessageTrack(1);
meta->set_request_id(request_id);
client->SendMessage(meta, protos, data, size);
MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_)
<< ", the node id is:" << node_info_.node_id_ << " send the request id is:" << request_id;
return request_id;
}
void AbstractNode::ProcessSendDataResp(const CommMessage &message) {
bool AbstractNode::SendMessageSync(const std::shared_ptr<TcpClient> &client, std::shared_ptr<MessageMeta> meta,
const Protos &protos, const void *data, size_t size, const uint32_t &timeout) {
MS_EXCEPTION_IF_NULL(client);
MS_EXCEPTION_IF_NULL(meta);
MS_EXCEPTION_IF_NULL(data);
uint64_t request_id = AddMessageTrack(1);
meta->set_request_id(request_id);
client->SendMessage(meta, protos, data, size);
MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_)
<< ", the node id is:" << node_info_.node_id_ << " send the request id is:" << request_id;
bool res = Wait(request_id, timeout);
return res;
}
void AbstractNode::ProcessSendDataResp(std::shared_ptr<MessageMeta> meta, const Protos &protos, const void *data,
size_t size) {
MS_EXCEPTION_IF_NULL(meta);
MS_EXCEPTION_IF_NULL(data);
std::lock_guard<std::mutex> lock(receive_messages_mutex_);
const MessageMeta &message_meta = message.pb_meta();
const uint32_t &rank_id = message_meta.rank_id();
const uint64_t request_id = message_meta.request_id();
const uint32_t &rank_id = meta->rank_id();
const uint64_t request_id = meta->request_id();
MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_)
<< ", the node id is:" << node_info_.node_id_ << " send the request id is:" << request_id;
auto it = receive_messages_.find(request_id);
VectorPtr received_data = std::make_shared<std::vector<unsigned char>>(size, 0);
if (size > 0) {
int ret = memcpy_s(received_data.get()->data(), size, data, size);
if (ret != 0) {
MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
}
}
if (it != receive_messages_.end()) {
it->second[rank_id] = message;
it->second[rank_id] = received_data;
} else {
std::unordered_map<uint32_t, CommMessage> res;
res.insert(std::make_pair(rank_id, message));
std::unordered_map<uint32_t, VectorPtr> res;
res.insert(std::make_pair(rank_id, received_data));
receive_messages_[request_id] = res;
}
}
@ -509,7 +541,7 @@ void AbstractNode::RunMessageCallback(const uint64_t &request_id) {
message_callbacks_mutex_.lock();
// When receiving a message's response, Then compare with the desired number of responses,
// If they are equal, then call the callback function
if (message_tracker_[request_id].first == message_tracker_[request_id].second + 1) {
if (CheckMessageTrack(request_id)) {
auto it = message_callbacks_.find(request_id);
if (it != message_callbacks_.end()) {
message_callbacks_mutex_.unlock();
@ -533,31 +565,31 @@ void AbstractNode::set_message_callback(const uint64_t &request_id, const Messag
message_callbacks_[request_id] = callback;
}
void AbstractNode::NotifyMessageArrival(const CommMessage &message) {
void AbstractNode::NotifyMessageArrival(std::shared_ptr<MessageMeta> meta) {
std::lock_guard<std::mutex> lock(message_tracker_mutex_);
const MessageMeta &message_meta = message.pb_meta();
uint64_t request_id = message_meta.request_id();
uint64_t request_id = meta->request_id();
message_tracker_[request_id].second++;
message_tracker_cond_.notify_all();
}
void AbstractNode::set_receive_callback(const uint32_t &rank_id, const uint64_t &request_id,
const MessageCallback &callback) {
if (!callback) {
return;
}
std::lock_guard<std::mutex> lock(receive_callbacks_mutex_);
receive_callbacks_[std::make_pair(rank_id, request_id)] = callback;
}
void AbstractNode::RunReceiveCallback(const CommMessage &message) {
void AbstractNode::RunReceiveCallback(std::shared_ptr<MessageMeta> meta, const Protos &protos, const void *data,
size_t size) {
MS_EXCEPTION_IF_NULL(meta);
MS_EXCEPTION_IF_NULL(data);
receive_callbacks_mutex_.lock();
uint32_t rank_id = message.pb_meta().rank_id();
uint32_t rank_id = meta->rank_id();
// When receiving a collective message, Then generate rank request id,compare with the desired rank request id,
// If they are equal, then call the callback function
uint64_t rank_request_id = NextActualRankRequestId(rank_id);
received_data_[std::make_pair(rank_id, rank_request_id)] = message;
std::shared_ptr<std::vector<unsigned char>> received_data = std::make_shared<std::vector<unsigned char>>(size, 0);
int ret = memcpy_s(received_data->data(), size, data, size);
if (ret != 0) {
MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
}
received_data_[std::make_pair(rank_id, rank_request_id)] = received_data;
MS_LOG(DEBUG) << "Run Receive data callback,the rank id:" << rank_id << ", the rank request id is:" << rank_request_id
<< ", the send request id is:" << meta->request_id();
auto it = receive_callbacks_.find(std::make_pair(rank_id, rank_request_id));
if (it != receive_callbacks_.end()) {
receive_callbacks_mutex_.unlock();
@ -603,6 +635,18 @@ void AbstractNode::InitCommandHandler() {
handlers_[NodeCommand::FETCH_SERVER] = &AbstractNode::ProcessFetchServersResp;
handlers_[NodeCommand::FINISH] = nullptr;
}
uint64_t AbstractNode::AddMessageTrack(const uint32_t &expected_response) {
std::lock_guard<std::mutex> lock(message_tracker_mutex_);
uint64_t request_id = ++next_request_id_;
message_tracker_[request_id] = std::make_pair(expected_response, 0);
return request_id;
}
bool AbstractNode::CheckMessageTrack(const uint64_t &request_id) {
std::lock_guard<std::mutex> lock(message_tracker_mutex_);
return message_tracker_[request_id].first == message_tracker_[request_id].second + 1;
}
} // namespace core
} // namespace ps
} // namespace mindspore

View File

@ -25,6 +25,7 @@
#include <unordered_map>
#include "ps/core/node.h"
#include "ps/core/message.h"
namespace mindspore {
namespace ps {
@ -34,53 +35,63 @@ class AbstractNode : public Node {
AbstractNode() : heart_beat_thread_(nullptr), client_to_scheduler_thread_(nullptr), client_to_scheduler_(nullptr) {}
~AbstractNode() override = default;
typedef void (AbstractNode::*ResponseHandler)(const CommMessage &message);
typedef void (AbstractNode::*ResponseHandler)(std::shared_ptr<MessageMeta> meta, const void *data, size_t size);
bool Broadcast(const enum NodeRole &node_role, const CommMessage &message,
using DataPtr = std::shared_ptr<unsigned char>;
using VectorPtr = std::shared_ptr<std::vector<unsigned char>>;
bool Broadcast(const enum NodeRole &node_role, const DataPtr &message, size_t size, int command,
const uint32_t &timeout = kCommTimeoutInSeconds);
void set_event_callback(const OnNodeEventMessage &on_node_event_message);
bool Send(const enum NodeRole &node_role, const uint32_t &rank_id, const CommMessage &message,
bool Send(const enum NodeRole &node_role, const uint32_t &rank_id, const DataPtr &data, size_t len, int command,
const uint32_t &timeout = kCommTimeoutInSeconds);
bool Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids, const std::vector<CommMessage> &data,
bool Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids, const std::vector<DataPtr> &data,
const std::vector<size_t> &lens, int command, const uint32_t &timeout = kCommTimeoutInSeconds);
bool Send(const enum NodeRole &node_role, const uint32_t &rank_id, const DataPtr &message, size_t len, int command,
VectorPtr *output, const uint32_t &timeout = kCommTimeoutInSeconds);
bool Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids, const std::vector<DataPtr> &data,
const std::vector<size_t> &data_lens, int command, std::vector<VectorPtr> *output,
const uint32_t &timeout = kCommTimeoutInSeconds);
bool Send(const enum NodeRole &node_role, const uint32_t &rank_id, const CommMessage &message, CommMessage *output,
const uint32_t &timeout = kCommTimeoutInSeconds);
bool Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids, const std::vector<CommMessage> &data,
std::vector<CommMessage> *output, const uint32_t &timeout = kCommTimeoutInSeconds);
bool Wait(uint64_t request_id, const uint32_t &timeout = kCommTimeoutInSeconds);
uint64_t CollectiveSendAsync(const enum NodeRole &node_role, const uint32_t &rank_id, const CommMessage &message);
uint64_t CollectiveSendAsync(const enum NodeRole &node_role, const uint32_t &rank_id, const void *data, size_t size);
std::pair<uint32_t, uint64_t> CollectiveReceiveAsync(const enum NodeRole &node_role, const uint32_t &rank_id,
CommMessage *output);
void **output, size_t *size);
bool CollectiveWait(std::pair<uint32_t, uint64_t> request_id, const uint32_t &timeout = kCommTimeoutInSeconds);
protected:
void Register(const std::shared_ptr<TcpClient> &client);
void ProcessRegisterResp(const CommMessage &message);
void StartHeartbeatTimer(const std::shared_ptr<TcpClient> &client);
bool Heartbeat(const std::shared_ptr<TcpClient> &client, bool is_node_finish = false);
void FetchServers(const std::shared_ptr<TcpClient> &client);
void ProcessRegisterResp(std::shared_ptr<MessageMeta> meta, const void *data, size_t size);
void ProcessHeartbeatResp(std::shared_ptr<MessageMeta> meta, const void *data, size_t size);
void ProcessFetchServersResp(std::shared_ptr<MessageMeta> meta, const void *data, size_t size);
void StartHeartbeatTimer(const std::shared_ptr<TcpClient> &client);
void UpdateSchedulerTime();
bool CheckSchedulerTimeout() const;
void ProcessHeartbeatResp(const CommMessage &message);
void FetchServers(const std::shared_ptr<TcpClient> &client);
void ProcessFetchServersResp(const CommMessage &message);
bool Disconnect(const std::shared_ptr<TcpClient> &client, const uint32_t &timeout);
bool WaitForDisconnect(const uint32_t &timeout);
bool InitClientToScheduler();
const std::shared_ptr<TcpClient> &GetOrCreateTcpClient(const int &rank_id);
bool SendMessageSync(const std::shared_ptr<TcpClient> &client, const CommMessage &message,
const uint32_t &timeout = kCommTimeoutInSeconds);
uint64_t SendMessageAsync(const std::shared_ptr<TcpClient> &client, const CommMessage &message);
void ProcessSendDataResp(const CommMessage &message);
bool SendMessageSync(const std::shared_ptr<TcpClient> &client, std::shared_ptr<MessageMeta>, const Protos &,
const void *, size_t size, const uint32_t &timeout = kCommTimeoutInSeconds);
uint64_t SendMessageAsync(const std::shared_ptr<TcpClient> &client, std::shared_ptr<MessageMeta> meta,
const Protos &protos, const void *data, size_t size);
void ProcessSendDataResp(std::shared_ptr<MessageMeta> meta, const Protos &protos, const void *data, size_t size);
void RunMessageCallback(const uint64_t &request_id);
void set_message_callback(const uint64_t &request_id, const MessageCallback &callback);
void NotifyMessageArrival(const CommMessage &message);
void set_receive_callback(const uint32_t &rank_id, const uint64_t &request_id, const MessageCallback &callback);
void RunReceiveCallback(const CommMessage &message);
void NotifyMessageArrival(std::shared_ptr<MessageMeta> meta);
void RunReceiveCallback(std::shared_ptr<MessageMeta> meta, const Protos &protos, const void *data, size_t size);
uint64_t NextExpectedRankRequestId(const uint32_t &rank_id);
uint64_t NextActualRankRequestId(const uint32_t &rank_id);
void InitCommandHandler();
uint64_t AddMessageTrack(const uint32_t &expected_response);
bool CheckMessageTrack(const uint64_t &request_id);
std::unique_ptr<std::thread> heart_beat_thread_;
std::unique_ptr<std::thread> client_to_scheduler_thread_;
@ -98,15 +109,16 @@ class AbstractNode : public Node {
std::mutex message_tracker_mutex_;
std::condition_variable message_tracker_cond_;
// the key is: request_id, the value is:<rank_id, CommMessage>
std::unordered_map<uint64_t, std::unordered_map<uint32_t, CommMessage>> receive_messages_;
// the key is: request_id, the value is: <rank_id, RecvMessage>
std::unordered_map<uint64_t, std::unordered_map<uint32_t, VectorPtr>> receive_messages_;
std::map<std::pair<uint32_t, uint64_t>, bool> receive_messages_done_;
std::mutex receive_messages_mutex_;
// the key is: request_id
std::unordered_map<uint64_t, MessageCallback> message_callbacks_;
std::mutex message_callbacks_mutex_;
// the key is <rank_id, rank_request_id>
std::map<std::pair<uint32_t, uint64_t>, CommMessage> received_data_;
std::map<std::pair<uint32_t, uint64_t>, std::shared_ptr<std::vector<unsigned char>>> received_data_;
std::mutex receive_callbacks_mutex_;
// the key is <rank_id, rank_request_id>
std::map<std::pair<uint32_t, uint64_t>, MessageCallback> receive_callbacks_;

View File

@ -0,0 +1,59 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PS_CORE_MESSAGE_H_
#define MINDSPORE_CCSRC_PS_CORE_MESSAGE_H_
#include <string>
#include <memory>
namespace mindspore {
namespace ps {
namespace core {
enum class Protos : uint32_t { RAW = 0, PROTOBUF = 1, FLATBUFFERS = 2 };
enum class Command {
TERMINATE = 0,
REGISTER = 1,
HEARTBEAT = 2,
SEND_DATA = 3,
FETCH_SERVER = 4,
FINISH = 5,
COLLECTIVE_SEND_DATA = 6
};
enum class Role { SERVER = 0, WORKER = 1, SCHEDULER = 2 };
struct MessageHeader {
Protos message_proto_ = Protos::RAW;
uint32_t message_meta_length_ = 0;
uint64_t message_length_ = 0;
};
struct CommandMeta {
// the command of this message,for example: register,heartbeat,data
Command cmd;
// the request id of this message
uint64_t request_id;
// the role of the current node: worker,server,scheduler
Role role;
// the current Node rank id,the worker node range is:[0,numOfWorker-1], the server node range is:[0, numOfServer-1]
int32_t rank_id = 4;
};
} // namespace core
} // namespace ps
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_CORE_MESSAGE_H_

View File

@ -15,7 +15,6 @@
*/
syntax = "proto3";
import "google/protobuf/any.proto";
package mindspore.ps.core;
option optimize_for = LITE_RUNTIME;
@ -44,6 +43,8 @@ message MessageMeta {
NodeRole role = 3;
// the current Node rank id,the worker node range is:[0,numOfWorker-1], the server node range is:[0, numOfServer-1]
int32 rank_id = 4;
// User-defined commands
int32 user_cmd = 5;
}
message RegisterMessage {
@ -76,6 +77,10 @@ message HeartbeatRespMessage {
bool is_node_timeout = 4;
}
message FetchServersMessage {
string node_id = 1;
}
message FetchServersRespMessage {
repeated ServersMeta servers_meta = 1;
}
@ -95,6 +100,4 @@ message FinishMessage {
message CommMessage {
MessageMeta pb_meta = 1;
bytes data = 2;
// User-defined commands
bytes user_cmd = 3;
}

View File

@ -38,9 +38,13 @@ bool SchedulerNode::Start(const uint32_t &timeout) {
}
void SchedulerNode::ProcessHeartbeat(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn,
std::shared_ptr<CommMessage> message) {
std::shared_ptr<MessageMeta> meta, const void *data, size_t size) {
MS_EXCEPTION_IF_NULL(server);
MS_EXCEPTION_IF_NULL(conn);
MS_EXCEPTION_IF_NULL(meta);
MS_EXCEPTION_IF_NULL(data);
HeartbeatMessage heartbeat_message;
heartbeat_message.ParseFromString(message->data());
heartbeat_message.ParseFromArray(data, size);
node_manager_.UpdateHeartbeat(heartbeat_message.node_id());
@ -60,10 +64,8 @@ void SchedulerNode::ProcessHeartbeat(std::shared_ptr<TcpServer> server, std::sha
heartbeat_resp_message.set_is_cluster_timeout(node_manager_.is_cluster_timeout());
heartbeat_resp_message.set_is_node_timeout(node_manager_.is_node_timeout());
std::shared_ptr<CommMessage> comm_message = std::make_shared<CommMessage>();
*comm_message->mutable_pb_meta() = {message->pb_meta()};
comm_message->set_data(heartbeat_resp_message.SerializeAsString());
server->SendMessage(conn, comm_message);
server->SendMessage(conn, meta, Protos::PROTOBUF, heartbeat_resp_message.SerializeAsString().data(),
heartbeat_resp_message.ByteSizeLong());
}
void SchedulerNode::Initialize() {
@ -89,12 +91,13 @@ void SchedulerNode::CreateTcpServer() {
std::string scheduler_host = ClusterConfig::scheduler_host();
uint32_t scheduler_port = ClusterConfig::scheduler_port();
server_ = std::make_shared<TcpServer>(scheduler_host, scheduler_port);
server_->SetMessageCallback([&](std::shared_ptr<TcpConnection> conn, std::shared_ptr<CommMessage> message) {
if (handlers_.count(message->pb_meta().cmd()) == 0) {
MS_LOG(EXCEPTION) << "The cmd:" << message->pb_meta().cmd() << " is not supported!";
server_->SetMessageCallback([&](std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta,
const Protos &protos, const void *data, size_t size) {
if (handlers_.count(meta->cmd()) == 0) {
MS_LOG(EXCEPTION) << "The cmd:" << meta->cmd() << " is not supported!";
}
const auto &handler_ptr = handlers_[message->pb_meta().cmd()];
(this->*handler_ptr)(server_, conn, message);
const auto &handler_ptr = handlers_[meta->cmd()];
(this->*handler_ptr)(server_, conn, meta, data, size);
});
server_->Init();
@ -106,10 +109,14 @@ void SchedulerNode::CreateTcpServer() {
}
void SchedulerNode::ProcessRegister(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn,
std::shared_ptr<CommMessage> message) {
std::shared_ptr<MessageMeta> meta, const void *data, size_t size) {
MS_EXCEPTION_IF_NULL(server);
MS_EXCEPTION_IF_NULL(conn);
MS_EXCEPTION_IF_NULL(meta);
MS_EXCEPTION_IF_NULL(data);
MS_LOG(INFO) << "The scheduler process a register message!";
RegisterMessage register_message;
register_message.ParseFromString(message->data());
register_message.ParseFromArray(data, size);
// assign worker node and server node rank id
int rank_id = node_manager_.NextRankId(register_message);
@ -123,32 +130,32 @@ void SchedulerNode::ProcessRegister(std::shared_ptr<TcpServer> server, std::shar
register_resp_message.set_node_id(node_id);
register_resp_message.set_rank_id(rank_id);
std::shared_ptr<CommMessage> comm_message = std::make_shared<CommMessage>();
*comm_message->mutable_pb_meta() = {message->pb_meta()};
comm_message->set_data(register_resp_message.SerializeAsString());
server->SendMessage(conn, comm_message);
server->SendMessage(conn, meta, Protos::PROTOBUF, register_resp_message.SerializeAsString().data(),
register_resp_message.ByteSizeLong());
}
void SchedulerNode::ProcessFinish(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn,
std::shared_ptr<CommMessage> message) {
std::shared_ptr<MessageMeta> meta, const void *data, size_t size) {
MS_EXCEPTION_IF_NULL(server);
MS_EXCEPTION_IF_NULL(conn);
MS_EXCEPTION_IF_NULL(meta);
MS_EXCEPTION_IF_NULL(data);
FinishMessage finish_message;
finish_message.ParseFromString(message->data());
finish_message.ParseFromArray(data, size);
node_manager_.AddFinishNode(finish_message);
MS_LOG(INFO) << "Process finish message from node id:" << finish_message.node_id();
server->SendMessage(conn, message);
server->SendMessage(conn, meta, Protos::PROTOBUF, data, size);
}
void SchedulerNode::ProcessFetchServers(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn,
std::shared_ptr<CommMessage> message) {
std::shared_ptr<MessageMeta> meta, const void *data, size_t size) {
FetchServersRespMessage fetch_servers_message;
std::vector<ServersMeta> servers_meta_list = node_manager_.FetchServersMeta();
*fetch_servers_message.mutable_servers_meta() = {servers_meta_list.begin(), servers_meta_list.end()};
std::shared_ptr<CommMessage> comm_message = std::make_shared<CommMessage>();
*comm_message->mutable_pb_meta() = {message->pb_meta()};
comm_message->set_data(fetch_servers_message.SerializeAsString());
server->SendMessage(conn, comm_message);
server->SendMessage(conn, meta, Protos::PROTOBUF, fetch_servers_message.SerializeAsString().data(),
fetch_servers_message.ByteSizeLong());
}
void SchedulerNode::StartUpdateClusterStateTimer() {

View File

@ -36,13 +36,14 @@
namespace mindspore {
namespace ps {
namespace core {
class SchedulerNode : public Node {
public:
SchedulerNode() : server_(nullptr), scheduler_thread_(nullptr), update_state_thread_(nullptr) {}
~SchedulerNode() override;
typedef void (SchedulerNode::*ResponseHandler)(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn,
std::shared_ptr<CommMessage> message);
std::shared_ptr<MessageMeta> meta, const void *data, size_t size);
bool Start(const uint32_t &timeout = ClusterConfig::cluster_available_timeout()) override;
bool Stop() override;
@ -53,14 +54,14 @@ class SchedulerNode : public Node {
void InitCommandHandler();
void CreateTcpServer();
void ProcessHeartbeat(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn,
std::shared_ptr<CommMessage> message);
std::shared_ptr<MessageMeta> meta, const void *data, size_t size);
void ProcessRegister(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn,
std::shared_ptr<CommMessage> message);
std::shared_ptr<MessageMeta> meta, const void *data, size_t size);
void StartUpdateClusterStateTimer();
void ProcessFinish(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn,
std::shared_ptr<CommMessage> message);
std::shared_ptr<MessageMeta> meta, const void *data, size_t size);
void ProcessFetchServers(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn,
std::shared_ptr<CommMessage> message);
std::shared_ptr<MessageMeta> meta, const void *data, size_t size);
std::shared_ptr<TcpServer> server_;
std::unique_ptr<std::thread> scheduler_thread_;

View File

@ -46,16 +46,16 @@ bool ServerNode::Start(const uint32_t &timeout) {
void ServerNode::set_handler(const RequestHandler &handler) { request_handler_ = handler; }
void ServerNode::Response(std::shared_ptr<TcpConnection> conn, std::shared_ptr<CommMessage> message) {
void ServerNode::Response(std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta, DataPtr data,
size_t size) {
MS_EXCEPTION_IF_NULL(conn);
MS_EXCEPTION_IF_NULL(message);
message->mutable_pb_meta()->set_role(node_info_.node_role_);
message->mutable_pb_meta()->set_rank_id(node_info_.rank_id_);
const MessageMeta &message_meta = message->pb_meta();
const uint64_t request_id = message_meta.request_id();
MS_EXCEPTION_IF_NULL(meta);
MS_EXCEPTION_IF_NULL(data);
meta->set_role(node_info_.node_role_);
meta->set_rank_id(node_info_.rank_id_);
MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_)
<< ", the node id is:" << node_info_.node_id_ << " send the request id is:" << request_id;
server_->SendMessage(conn, message);
<< ", the node id is:" << node_info_.node_id_ << " send the request id is:" << meta->request_id();
server_->SendMessage(conn, meta, Protos::RAW, data.get(), size);
}
void ServerNode::CreateTcpServer() {
@ -63,17 +63,18 @@ void ServerNode::CreateTcpServer() {
std::string server_ip;
CommUtil::GetAvailableInterfaceAndIP(&interface, &server_ip);
server_ = std::make_shared<TcpServer>(server_ip, 0);
server_->SetMessageCallback([&](std::shared_ptr<TcpConnection> conn, std::shared_ptr<CommMessage> message) {
switch (message->pb_meta().cmd()) {
server_->SetMessageCallback([&](std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta,
const Protos &protos, const void *data, size_t size) {
switch (meta->cmd()) {
case NodeCommand::SEND_DATA:
ProcessSendData(conn, message);
ProcessSendData(conn, meta, protos, data, size);
break;
case NodeCommand::COLLECTIVE_SEND_DATA:
ProcessCollectiveSendData(conn, message);
RunReceiveCallback(*message);
ProcessCollectiveSendData(conn, meta, data, size);
RunReceiveCallback(meta, protos, data, size);
break;
default:
MS_LOG(EXCEPTION) << "The cmd:" << message->pb_meta().cmd() << " is not supported!";
MS_LOG(EXCEPTION) << "The cmd:" << meta->cmd() << " is not supported!";
}
});
server_->Init();
@ -99,18 +100,24 @@ void ServerNode::Initialize() {
MS_LOG(INFO) << "Server node init client successful!";
}
void ServerNode::ProcessSendData(std::shared_ptr<TcpConnection> conn, std::shared_ptr<CommMessage> message) {
void ServerNode::ProcessSendData(std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta,
const Protos &protos, const void *data, size_t size) {
MS_EXCEPTION_IF_NULL(conn);
MS_EXCEPTION_IF_NULL(message);
request_handler_(conn, message);
MS_EXCEPTION_IF_NULL(meta);
MS_EXCEPTION_IF_NULL(data);
std::shared_ptr<unsigned char> res(new unsigned char[size]);
int ret = memcpy_s(res.get(), size, data, size);
if (ret != 0) {
MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
}
request_handler_(conn, meta, res, size);
}
void ServerNode::ProcessCollectiveSendData(std::shared_ptr<TcpConnection> conn, std::shared_ptr<CommMessage> message) {
void ServerNode::ProcessCollectiveSendData(std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta,
const void *data, size_t size) {
MS_EXCEPTION_IF_NULL(conn);
MS_EXCEPTION_IF_NULL(message);
std::shared_ptr<CommMessage> comm_message = std::make_shared<CommMessage>();
*comm_message->mutable_pb_meta() = {message->pb_meta()};
server_->SendMessage(conn, comm_message);
MS_EXCEPTION_IF_NULL(meta);
server_->SendMessage(conn, meta, Protos::RAW, data, size);
}
bool ServerNode::Stop() {

View File

@ -23,6 +23,7 @@
#include <string>
#include <thread>
#include <utility>
#include <vector>
#include "ps/core/cluster_config.h"
#include "ps/core/tcp_client.h"
@ -41,16 +42,19 @@ class ServerNode : public AbstractNode {
bool Stop() override;
bool Finish(const uint32_t &timeout = kTimeoutInSeconds) override;
using RequestHandler = std::function<void(std::shared_ptr<TcpConnection> conn, std::shared_ptr<CommMessage> message)>;
using RequestHandler = std::function<void(std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta,
DataPtr data, size_t size)>;
void set_handler(const RequestHandler &handler);
void Response(std::shared_ptr<TcpConnection> conn, std::shared_ptr<CommMessage> message);
void Response(std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta, DataPtr data, size_t size);
private:
void CreateTcpServer();
void Initialize();
void ProcessSendData(std::shared_ptr<TcpConnection> conn, std::shared_ptr<CommMessage> message);
void ProcessCollectiveSendData(std::shared_ptr<TcpConnection> conn, std::shared_ptr<CommMessage> message);
void ProcessSendData(std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta, const Protos &protos,
const void *data, size_t size);
void ProcessCollectiveSendData(std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta,
const void *data, size_t size);
std::shared_ptr<TcpServer> server_;
std::unique_ptr<std::thread> server_thread_;

View File

@ -46,11 +46,12 @@ TcpClient::TcpClient(const std::string &address, std::uint16_t port)
server_port_(port),
is_stop_(true),
is_connected_(false) {
message_handler_.SetCallback([this](std::shared_ptr<CommMessage> message) {
if (message_callback_) {
message_callback_(*this, *message);
}
});
message_handler_.SetCallback(
[this](std::shared_ptr<MessageMeta> meta, const Protos &protos, const void *data, size_t size) {
if (message_callback_) {
message_callback_(meta, protos, data, size);
}
});
}
TcpClient::~TcpClient() {
@ -189,7 +190,7 @@ void TcpClient::ReadCallback(struct bufferevent *bev, void *ctx) {
void TcpClient::OnReadHandler(const void *buf, size_t num) {
MS_EXCEPTION_IF_NULL(buf);
if (read_callback_) {
read_callback_(*this, buf, num);
read_callback_(buf, num);
}
message_handler_.ReceiveMessage(buf, num);
}
@ -198,7 +199,7 @@ void TcpClient::TimerCallback(evutil_socket_t, int16_t, void *arg) {
MS_EXCEPTION_IF_NULL(arg);
auto tcp_client = reinterpret_cast<TcpClient *>(arg);
if (tcp_client->on_timer_callback_) {
tcp_client->on_timer_callback_(*tcp_client);
tcp_client->on_timer_callback_();
}
}
@ -245,7 +246,7 @@ void TcpClient::Start() {
MSLOG_IF(mindspore::ERROR, ret == 1, NoExceptionType)
<< "Event base dispatch failed with no events pending or active!";
MSLOG_IF(mindspore::ERROR, ret == -1, NoExceptionType) << "Event base dispatch failed with error occurred!";
MSLOG_IF(mindspore::EXCEPTION, ret < -1, AbortedError) << "Event base dispatch with unexpect error code!";
MSLOG_IF(mindspore::EXCEPTION, ret < -1, AbortedError) << "Event base dispatch with unexpected error code!";
}
void TcpClient::StartWithNoBlock() {
@ -256,7 +257,7 @@ void TcpClient::StartWithNoBlock() {
MSLOG_IF(INFO, ret == 0, NoExceptionType) << "Event base loop success!";
MSLOG_IF(mindspore::ERROR, ret == 1, NoExceptionType) << "Event base loop failed with no events pending or active!";
MSLOG_IF(mindspore::ERROR, ret == -1, NoExceptionType) << "Event base loop failed with error occurred!";
MSLOG_IF(mindspore::EXCEPTION, ret < -1, AbortedError) << "Event base loop with unexpect error code!";
MSLOG_IF(mindspore::EXCEPTION, ret < -1, AbortedError) << "Event base loop with unexpected error code!";
}
void TcpClient::SetMessageCallback(const OnMessage &cb) { message_callback_ = cb; }
@ -265,14 +266,49 @@ bool TcpClient::SendMessage(const CommMessage &message) const {
MS_EXCEPTION_IF_NULL(buffer_event_);
bufferevent_lock(buffer_event_);
bool res = true;
size_t buf_size = message.ByteSizeLong();
std::vector<unsigned char> serialized(buf_size);
message.SerializeToArray(serialized.data(), SizeToInt(buf_size));
if (bufferevent_write(buffer_event_, &buf_size, sizeof(buf_size)) == -1) {
size_t buf_size = IntToUint(message.ByteSizeLong());
uint32_t meta_size = SizeToUint(message.pb_meta().ByteSizeLong());
Messageheader header;
header.message_proto_ = Protos::PROTOBUF;
header.message_length_ = buf_size;
header.message_meta_length_ = meta_size;
if (bufferevent_write(buffer_event_, &header, sizeof(header)) == -1) {
MS_LOG(ERROR) << "Event buffer add header failed!";
res = false;
}
if (bufferevent_write(buffer_event_, serialized.data(), buf_size) == -1) {
if (bufferevent_write(buffer_event_, message.pb_meta().SerializeAsString().data(), meta_size) == -1) {
MS_LOG(ERROR) << "Event buffer add protobuf data failed!";
res = false;
}
if (bufferevent_write(buffer_event_, message.data().data(), message.data().length()) == -1) {
MS_LOG(ERROR) << "Event buffer add protobuf data failed!";
res = false;
}
bufferevent_unlock(buffer_event_);
return res;
}
bool TcpClient::SendMessage(std::shared_ptr<MessageMeta> meta, const Protos &protos, const void *data, size_t size) {
MS_EXCEPTION_IF_NULL(buffer_event_);
MS_EXCEPTION_IF_NULL(meta);
MS_EXCEPTION_IF_NULL(data);
bufferevent_lock(buffer_event_);
bool res = true;
Messageheader header;
header.message_proto_ = protos;
header.message_meta_length_ = SizeToUint(meta->ByteSizeLong());
header.message_length_ = size + header.message_meta_length_;
if (bufferevent_write(buffer_event_, &header, sizeof(header)) == -1) {
MS_LOG(ERROR) << "Event buffer add header failed!";
res = false;
}
if (bufferevent_write(buffer_event_, meta->SerializeAsString().data(), meta->ByteSizeLong()) == -1) {
MS_LOG(ERROR) << "Event buffer add protobuf data failed!";
res = false;
}
if (bufferevent_write(buffer_event_, data, size) == -1) {
MS_LOG(ERROR) << "Event buffer add protobuf data failed!";
res = false;
}

View File

@ -42,10 +42,10 @@ class TcpClient {
public:
using OnConnected = std::function<void()>;
using OnDisconnected = std::function<void()>;
using OnRead = std::function<void(const TcpClient &, const void *, size_t)>;
using OnTimeout = std::function<void(const TcpClient &)>;
using OnMessage = std::function<void(const TcpClient &, const CommMessage &)>;
using OnTimer = std::function<void(const TcpClient &)>;
using OnRead = std::function<void(const void *, size_t)>;
using OnTimeout = std::function<void()>;
using OnMessage = std::function<void(std::shared_ptr<MessageMeta>, const Protos &, const void *, size_t size)>;
using OnTimer = std::function<void()>;
explicit TcpClient(const std::string &address, std::uint16_t port);
virtual ~TcpClient();
@ -61,6 +61,7 @@ class TcpClient {
void StartWithNoBlock();
void SetMessageCallback(const OnMessage &cb);
bool SendMessage(const CommMessage &message) const;
bool SendMessage(std::shared_ptr<MessageMeta> meta, const Protos &protos, const void *data, size_t size);
void StartTimer(const uint32_t &time);
void set_timer_callback(const OnTimer &timer);
const event_base &eventbase();

View File

@ -35,8 +35,12 @@ void TcpMessageHandler::ReceiveMessage(const void *buffer, size_t num) {
header_[++header_index_] = *(buffer_data + i);
--num;
if (header_index_ == kHeaderLen - 1) {
message_length_ = *reinterpret_cast<const size_t *>(header_);
remaining_length_ = message_length_;
message_header_.message_proto_ = *reinterpret_cast<const Protos *>(header_);
message_header_.message_meta_length_ =
*reinterpret_cast<const uint32_t *>(header_ + sizeof(message_header_.message_proto_));
message_header_.message_length_ = *reinterpret_cast<const size_t *>(
header_ + sizeof(message_header_.message_proto_) + sizeof(message_header_.message_meta_length_));
remaining_length_ = message_header_.message_length_;
message_buffer_.reset(new unsigned char[remaining_length_]);
buffer_data += (i + 1);
break;
@ -57,10 +61,12 @@ void TcpMessageHandler::ReceiveMessage(const void *buffer, size_t num) {
}
if (remaining_length_ == 0) {
std::shared_ptr<CommMessage> pb_message = std::make_shared<CommMessage>();
pb_message->ParseFromArray(message_buffer_.get(), message_length_);
if (message_callback_) {
message_callback_(pb_message);
std::shared_ptr<MessageMeta> pb_message = std::make_shared<MessageMeta>();
pb_message->ParseFromArray(message_buffer_.get(), message_header_.message_meta_length_);
message_callback_(pb_message, message_header_.message_proto_,
message_buffer_.get() + message_header_.message_meta_length_,
message_header_.message_length_ - message_header_.message_meta_length_);
}
message_buffer_.reset();
message_buffer_ = nullptr;

View File

@ -24,24 +24,20 @@
#include <vector>
#include "utils/log_adapter.h"
#include "ps/core/message.h"
#include "proto/comm.pb.h"
#include "proto/ps.pb.h"
namespace mindspore {
namespace ps {
namespace core {
using messageReceive = std::function<void(std::shared_ptr<CommMessage>)>;
constexpr int kHeaderLen = 8;
using messageReceive = std::function<void(std::shared_ptr<MessageMeta>, const Protos &, const void *, size_t size)>;
constexpr int kHeaderLen = 16;
class TcpMessageHandler {
public:
TcpMessageHandler()
: is_parsed_(false),
message_buffer_(nullptr),
message_length_(0),
remaining_length_(0),
header_index_(-1),
last_copy_len_(0) {}
: is_parsed_(false), message_buffer_(nullptr), remaining_length_(0), header_index_(-1), last_copy_len_(0) {}
virtual ~TcpMessageHandler() = default;
void SetCallback(const messageReceive &cb);
@ -51,11 +47,12 @@ class TcpMessageHandler {
messageReceive message_callback_;
bool is_parsed_;
std::unique_ptr<unsigned char> message_buffer_;
size_t message_length_;
size_t remaining_length_;
char header_[8];
char header_[16];
int header_index_;
size_t last_copy_len_;
MessageHeader message_header_;
std::string mBuffer;
};
} // namespace core
} // namespace ps

View File

@ -54,13 +54,39 @@ bool TcpConnection::SendMessage(std::shared_ptr<CommMessage> message) const {
bufferevent_lock(buffer_event_);
bool res = true;
size_t buf_size = message->ByteSizeLong();
std::vector<unsigned char> serialized(buf_size);
message->SerializeToArray(serialized.data(), SizeToInt(buf_size));
if (bufferevent_write(buffer_event_, &buf_size, sizeof(buf_size)) == -1) {
MS_LOG(ERROR) << "Event buffer add header failed!";
res = false;
}
if (bufferevent_write(buffer_event_, serialized.data(), buf_size) == -1) {
if (bufferevent_write(buffer_event_, message->SerializeAsString().data(), buf_size) == -1) {
MS_LOG(ERROR) << "Event buffer add protobuf data failed!";
res = false;
}
bufferevent_unlock(buffer_event_);
return res;
}
bool TcpConnection::SendMessage(std::shared_ptr<MessageMeta> meta, const Protos &protos, const void *data,
size_t size) const {
MS_EXCEPTION_IF_NULL(buffer_event_);
MS_EXCEPTION_IF_NULL(meta);
MS_EXCEPTION_IF_NULL(data);
bufferevent_lock(buffer_event_);
bool res = true;
Messageheader header;
header.message_proto_ = protos;
header.message_meta_length_ = SizeToUint(meta->ByteSizeLong());
header.message_length_ = size + header.message_meta_length_;
if (bufferevent_write(buffer_event_, &header, sizeof(header)) == -1) {
MS_LOG(ERROR) << "Event buffer add header failed!";
res = false;
}
if (bufferevent_write(buffer_event_, meta->SerializeAsString().data(), meta->ByteSizeLong()) == -1) {
MS_LOG(ERROR) << "Event buffer add protobuf data failed!";
res = false;
}
if (bufferevent_write(buffer_event_, data, size) == -1) {
MS_LOG(ERROR) << "Event buffer add protobuf data failed!";
res = false;
}
@ -158,7 +184,7 @@ void TcpServer::Start() {
MSLOG_IF(mindspore::ERROR, ret == 1, NoExceptionType)
<< "Event base dispatch failed with no events pending or active!";
MSLOG_IF(mindspore::ERROR, ret == -1, NoExceptionType) << "Event base dispatch failed with error occurred!";
MSLOG_IF(mindspore::EXCEPTION, ret < -1, AbortedError) << "Event base dispatch with unexpect error code!";
MSLOG_IF(mindspore::EXCEPTION, ret < -1, AbortedError) << "Event base dispatch with unexpected error code!";
}
void TcpServer::StartWithNoBlock() {
@ -169,7 +195,7 @@ void TcpServer::StartWithNoBlock() {
MSLOG_IF(INFO, ret == 0, NoExceptionType) << "Event base loop success!";
MSLOG_IF(mindspore::ERROR, ret == 1, NoExceptionType) << "Event base loop failed with no events pending or active!";
MSLOG_IF(mindspore::ERROR, ret == -1, NoExceptionType) << "Event base loop failed with error occurred!";
MSLOG_IF(mindspore::EXCEPTION, ret < -1, AbortedError) << "Event base loop with unexpect error code!";
MSLOG_IF(mindspore::EXCEPTION, ret < -1, AbortedError) << "Event base loop with unexpected error code!";
}
void TcpServer::StartTimerOnlyOnce(const uint32_t &time) {
@ -260,10 +286,10 @@ void TcpServer::ListenerCallback(struct evconnlistener *, evutil_socket_t fd, st
MS_EXCEPTION_IF_NULL(conn);
server->AddConnection(fd, conn);
conn->InitConnection([=](std::shared_ptr<CommMessage> message) {
conn->InitConnection([=](std::shared_ptr<MessageMeta> meta, const Protos &protos, const void *data, size_t size) {
OnServerReceiveMessage on_server_receive = server->GetServerReceive();
if (on_server_receive) {
on_server_receive(conn, message);
on_server_receive(conn, meta, protos, data, size);
}
});
bufferevent_setcb(bev, TcpServer::ReadCallback, nullptr, TcpServer::EventCallback,
@ -274,6 +300,7 @@ void TcpServer::ListenerCallback(struct evconnlistener *, evutil_socket_t fd, st
}
std::shared_ptr<TcpConnection> TcpServer::onCreateConnection(struct bufferevent *bev, const evutil_socket_t &fd) {
MS_EXCEPTION_IF_NULL(bev);
std::shared_ptr<TcpConnection> conn = nullptr;
if (client_accept_) {
conn = (client_accept_(*this));
@ -367,9 +394,17 @@ bool TcpServer::SendMessage(std::shared_ptr<TcpConnection> conn, std::shared_ptr
return conn->SendMessage(message);
}
bool TcpServer::SendMessage(std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta,
const Protos &protos, const void *data, size_t size) {
MS_EXCEPTION_IF_NULL(conn);
MS_EXCEPTION_IF_NULL(meta);
MS_EXCEPTION_IF_NULL(data);
return conn->SendMessage(meta, protos, data, size);
}
void TcpServer::SendMessage(std::shared_ptr<CommMessage> message) {
std::lock_guard<std::mutex> lock(connection_mutex_);
MS_EXCEPTION_IF_NULL(message);
std::lock_guard<std::mutex> lock(connection_mutex_);
for (auto it = connections_.begin(); it != connections_.end(); ++it) {
SendMessage(it->second, message);

View File

@ -36,7 +36,6 @@
#include "ps/core/tcp_message_handler.h"
#include "ps/core/cluster_config.h"
#include "utils/log_adapter.h"
#include "utils/convert_utils_base.h"
namespace mindspore {
@ -55,6 +54,7 @@ class TcpConnection {
virtual void InitConnection(const messageReceive &callback);
virtual void SendMessage(const void *buffer, size_t num) const;
bool SendMessage(std::shared_ptr<CommMessage> message) const;
bool SendMessage(std::shared_ptr<MessageMeta> meta, const Protos &protos, const void *data, size_t size) const;
virtual void OnReadHandler(const void *buffer, size_t numBytes);
TcpServer *GetServer() const;
const evutil_socket_t &GetFd() const;
@ -69,7 +69,8 @@ class TcpConnection {
};
using OnServerReceiveMessage =
std::function<void(std::shared_ptr<TcpConnection> conn, std::shared_ptr<CommMessage> message)>;
std::function<void(std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta, const Protos &protos,
const void *data, size_t size)>;
class TcpServer {
public:
@ -100,6 +101,8 @@ class TcpServer {
OnServerReceiveMessage GetServerReceive() const;
void SetMessageCallback(const OnServerReceiveMessage &cb);
bool SendMessage(std::shared_ptr<TcpConnection> conn, std::shared_ptr<CommMessage> message);
bool SendMessage(std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta, const Protos &protos,
const void *data, size_t sizee);
void SendMessage(std::shared_ptr<CommMessage> message);
uint16_t BoundPort() const;
std::string BoundIp() const;

View File

@ -30,7 +30,12 @@ class TestTcpClient : public UT::Common {
TEST_F(TestTcpClient, InitClientIPError) {
auto client = std::make_unique<TcpClient>("127.0.0.13543", 9000);
client->SetMessageCallback([](const TcpClient &client, const CommMessage &message) { client.SendMessage(message); });
client->SetMessageCallback([&](std::shared_ptr<MessageMeta>, const Protos &, const void *data, size_t size) {
CommMessage message;
message.ParseFromArray(data, size);
client->SendMessage(message);
});
ASSERT_THROW(client->Init(), std::exception);
}
@ -38,10 +43,15 @@ TEST_F(TestTcpClient, InitClientIPError) {
TEST_F(TestTcpClient, InitClientPortErrorNoException) {
auto client = std::make_unique<TcpClient>("127.0.0.1", -1);
client->SetMessageCallback([](const TcpClient &client, const CommMessage &message) { client.SendMessage(message); });
client->SetMessageCallback([&](std::shared_ptr<MessageMeta>, const Protos &, const void *data, size_t size) {
CommMessage message;
message.ParseFromArray(data, size);
client->SendMessage(message);
});
EXPECT_NO_THROW(client->Init());
}
} // namespace core
} // namespace ps
} // namespace mindspore

View File

@ -33,131 +33,145 @@ class TestTcpMessageHandler : public UT::Common {
void TearDown() override {}
};
TEST_F(TestTcpMessageHandler, 8_Header_1003_Data) {
TEST_F(TestTcpMessageHandler, 16Header_2meta_1000Data) {
TcpMessageHandler handler;
handler.SetCallback([this](std::shared_ptr<CommMessage> message) { EXPECT_EQ(message->data().size(), 1000); });
handler.SetCallback([this](std::shared_ptr<MessageMeta> meta, const Protos &, const void *data, size_t size) {
EXPECT_EQ(meta->ByteSizeLong(), 2);
EXPECT_EQ(size, 1000);
});
std::string data(1000, 'a');
CommMessage message;
message.set_data(data);
size_t buf_size = message.ByteSizeLong();
char result[1011];
int ret = memcpy_s(result, kHeaderLen, &buf_size, kHeaderLen);
char result[1018];
MessageMeta meta;
meta.set_request_id(1);
EXPECT_EQ(meta.ByteSizeLong(), 2);
MessageHeader header;
header.message_proto_ = Protos::RAW;
header.message_meta_length_ = meta.ByteSizeLong();
header.message_length_ = data.length() + meta.ByteSizeLong();
int ret = memcpy_s(result, kHeaderLen, &header, kHeaderLen);
if (ret != 0) {
MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
}
std::vector<char> serialized(buf_size);
message.SerializeToArray(serialized.data(), static_cast<int>(buf_size));
memcpy_s(result + kHeaderLen, buf_size, serialized.data(), buf_size);
handler.ReceiveMessage(result, buf_size + kHeaderLen);
memcpy_s(result + kHeaderLen, meta.ByteSizeLong(), meta.SerializeAsString().data(), meta.ByteSizeLong());
memcpy_s(result + kHeaderLen + meta.ByteSizeLong(), data.length(), data.data(), data.length());
handler.ReceiveMessage(result, 1018);
}
TEST_F(TestTcpMessageHandler, 8_Header_1003_Data_8_Header_1003_Data) {
TEST_F(TestTcpMessageHandler, 16Header_2meta_1000Data_16Header_2meta_1000Data) {
TcpMessageHandler handler;
handler.SetCallback([this](std::shared_ptr<CommMessage> message) { EXPECT_EQ(message->data().size(), 1000); });
handler.SetCallback([this](std::shared_ptr<MessageMeta> meta, const Protos &, const void *data, size_t size) {
EXPECT_EQ(meta->ByteSizeLong(), 2);
EXPECT_EQ(size, 1000);
});
std::string data(1000, 'a');
CommMessage message;
message.set_data(data);
size_t buf_size = message.ByteSizeLong();
char result[2022] = {0};
int ret = memcpy_s(result, kHeaderLen, &buf_size, kHeaderLen);
if (ret != 0) {
MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
}
std::vector<char> serialized(buf_size);
message.SerializeToArray(serialized.data(), static_cast<int>(buf_size));
ret = memcpy_s(result + kHeaderLen, buf_size, serialized.data(), buf_size);
if (ret != 0) {
MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
}
ret = memcpy_s(result + kHeaderLen + buf_size, kHeaderLen, &buf_size, kHeaderLen);
if (ret != 0) {
MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
}
ret = memcpy_s(result + kHeaderLen + buf_size + kHeaderLen, buf_size, serialized.data(), buf_size);
char result[2036];
MessageMeta meta;
meta.set_request_id(1);
EXPECT_EQ(meta.ByteSizeLong(), 2);
MessageHeader header;
header.message_proto_ = Protos::RAW;
header.message_meta_length_ = meta.ByteSizeLong();
header.message_length_ = data.length() + meta.ByteSizeLong();
int ret = memcpy_s(result, kHeaderLen, &header, kHeaderLen);
if (ret != 0) {
MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
}
handler.ReceiveMessage(result, 2 * buf_size + kHeaderLen * 2);
memcpy_s(result + kHeaderLen, meta.ByteSizeLong(), meta.SerializeAsString().data(), meta.ByteSizeLong());
memcpy_s(result + kHeaderLen + meta.ByteSizeLong(), data.length(), data.data(), data.length());
memcpy_s(result + kHeaderLen + meta.ByteSizeLong() + data.length(), kHeaderLen, &header, kHeaderLen);
memcpy_s(result + kHeaderLen * 2 + meta.ByteSizeLong() + data.length(), meta.ByteSizeLong(),
meta.SerializeAsString().data(), meta.ByteSizeLong());
memcpy_s(result + kHeaderLen * 2 + meta.ByteSizeLong() * 2 + data.length(), data.length(), data.data(),
data.length());
handler.ReceiveMessage(result, 2036);
}
TEST_F(TestTcpMessageHandler, 8_Header_4084_Data_4_Header_4_header_4084_data) {
TEST_F(TestTcpMessageHandler, 16header_2meta_4070data_8header_8header_2meta_4070data) {
TcpMessageHandler handler;
handler.SetCallback([this](std::shared_ptr<CommMessage> message) { EXPECT_EQ(message->data().size(), 4081); });
handler.SetCallback([this](std::shared_ptr<MessageMeta> meta, const Protos &, const void *data, size_t size) {
EXPECT_EQ(meta->ByteSizeLong(), 2);
EXPECT_EQ(size, 4070);
});
std::string data(4070, 'a');
std::string data(4081, 'a');
CommMessage message;
message.set_data(data);
size_t buf_size = message.ByteSizeLong();
char result[4096] = {0};
int ret = memcpy_s(result, kHeaderLen, &buf_size, kHeaderLen);
if (ret != 0) {
MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
}
std::vector<char> serialized(buf_size);
message.SerializeToArray(serialized.data(), static_cast<int>(buf_size));
ret = memcpy_s(result + kHeaderLen, buf_size, serialized.data(), buf_size);
MessageMeta meta;
meta.set_request_id(1);
EXPECT_EQ(meta.ByteSizeLong(), 2);
MessageHeader header;
header.message_proto_ = Protos::RAW;
header.message_meta_length_ = meta.ByteSizeLong();
header.message_length_ = data.length() + meta.ByteSizeLong();
int ret = memcpy_s(result, kHeaderLen, &header, kHeaderLen);
if (ret != 0) {
MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
}
ret = memcpy_s(result + kHeaderLen + buf_size, 4, &buf_size, 4);
if (ret != 0) {
MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
}
memcpy_s(result + kHeaderLen, meta.ByteSizeLong(), meta.SerializeAsString().data(), meta.ByteSizeLong());
memcpy_s(result + kHeaderLen + meta.ByteSizeLong(), data.length(), data.data(), data.length());
memcpy_s(result + kHeaderLen + meta.ByteSizeLong() + data.length(), 8, &header, 8);
handler.ReceiveMessage(result, 4096);
auto temp = reinterpret_cast<char *>(&buf_size);
ret = memcpy_s(result, 4, temp + 4, 4);
if (ret != 0) {
MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
}
ret = memcpy_s(result + 4, buf_size, serialized.data(), buf_size);
if (ret != 0) {
MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
}
handler.ReceiveMessage(result, 4088);
}
TEST_F(TestTcpMessageHandler, 8_Header_4080_Data_8_Header_4080_data) {
TcpMessageHandler handler;
handler.SetCallback([this](std::shared_ptr<CommMessage> message) { EXPECT_EQ(message->data().size(), 4077); });
std::string data(4077, 'a');
CommMessage message;
message.set_data(data);
size_t buf_size = message.ByteSizeLong();
char result[4096] = {0};
int ret = memcpy_s(result, kHeaderLen, &buf_size, kHeaderLen);
if (ret != 0) {
MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
}
std::vector<char> serialized(buf_size);
message.SerializeToArray(serialized.data(), static_cast<int>(buf_size));
ret = memcpy_s(result + kHeaderLen, buf_size, serialized.data(), buf_size);
if (ret != 0) {
MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
}
ret = memcpy_s(result + kHeaderLen + buf_size, kHeaderLen, &buf_size, kHeaderLen);
if (ret != 0) {
MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
}
handler.ReceiveMessage(result, 4096);
ret = memcpy_s(result, buf_size, serialized.data(), buf_size);
if (ret != 0) {
MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
}
auto temp = reinterpret_cast<char *>(&header);
memcpy_s(result, 8, temp + 8, 8);
memcpy_s(result + 8, meta.ByteSizeLong(), meta.SerializeAsString().data(), meta.ByteSizeLong());
memcpy_s(result + 8 + 2, data.length(), data.data(), data.length());
handler.ReceiveMessage(result, 4080);
}
TEST_F(TestTcpMessageHandler, 16Header_2meta_4062Data_16Header_2meta_4062_data) {
TcpMessageHandler handler;
handler.SetCallback([this](std::shared_ptr<MessageMeta> meta, const Protos &, const void *data, size_t size) {
EXPECT_EQ(meta->ByteSizeLong(), 2);
EXPECT_EQ(size, 4062);
});
std::string data(4062, 'a');
char result[4096] = {0};
MessageMeta meta;
meta.set_request_id(1);
EXPECT_EQ(meta.ByteSizeLong(), 2);
MessageHeader header;
header.message_proto_ = Protos::RAW;
header.message_meta_length_ = meta.ByteSizeLong();
header.message_length_ = data.length() + meta.ByteSizeLong();
int ret = memcpy_s(result, kHeaderLen, &header, kHeaderLen);
if (ret != 0) {
MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
}
memcpy_s(result + kHeaderLen, meta.ByteSizeLong(), meta.SerializeAsString().data(), meta.ByteSizeLong());
memcpy_s(result + kHeaderLen + meta.ByteSizeLong(), data.length(), data.data(), data.length());
memcpy_s(result + kHeaderLen + meta.ByteSizeLong() + data.length(), kHeaderLen, &header, kHeaderLen);
handler.ReceiveMessage(result, 4096);
memcpy_s(result, meta.ByteSizeLong(), meta.SerializeAsString().data(), meta.ByteSizeLong());
memcpy_s(result + meta.ByteSizeLong(), data.length(), data.data(), data.length());
handler.ReceiveMessage(result, 4064);
}
} // namespace core
} // namespace ps
} // namespace mindspore

View File

@ -33,11 +33,12 @@ class TestTcpServer : public UT::Common {
server_ = std::make_unique<TcpServer>("127.0.0.1", 0);
std::unique_ptr<std::thread> http_server_thread_(nullptr);
http_server_thread_ = std::make_unique<std::thread>([=]() {
server_->SetMessageCallback([=](std::shared_ptr<TcpConnection> conn, std::shared_ptr<CommMessage> message) {
server_->SetMessageCallback([=](std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta,
const Protos &protos, const void *data, size_t size) {
KVMessage kv_message;
kv_message.ParseFromString(message->data());
kv_message.ParseFromArray(data, size);
EXPECT_EQ(2, kv_message.keys_size());
server_->SendMessage(conn, message);
server_->SendMessage(conn, meta, protos, data, size);
});
server_->Init();
server_->Start();
@ -61,23 +62,24 @@ TEST_F(TestTcpServer, ServerSendMessage) {
std::cout << server_->BoundPort() << std::endl;
std::unique_ptr<std::thread> http_client_thread(nullptr);
http_client_thread = std::make_unique<std::thread>([&]() {
client_->SetMessageCallback([](const TcpClient &client, const CommMessage &message) {
KVMessage kv_message;
kv_message.ParseFromString(message.data());
EXPECT_EQ(2, kv_message.keys_size());
client_->SetMessageCallback([&](std::shared_ptr<MessageMeta> meta, const Protos &, const void *data, size_t size) {
KVMessage message;
message.ParseFromArray(data, size);
EXPECT_EQ(2, message.keys_size());
});
client_->Init();
CommMessage comm_message;
KVMessage kv_message;
std::vector<int> keys{1, 2};
std::vector<int> values{3, 4};
*kv_message.mutable_keys() = {keys.begin(), keys.end()};
*kv_message.mutable_values() = {values.begin(), values.end()};
comm_message.set_data(kv_message.SerializeAsString());
client_->SendMessage(comm_message);
auto message_meta = std::make_shared<MessageMeta>();
message_meta->set_cmd(NodeCommand::SEND_DATA);
client_->SendMessage(message_meta, Protos::RAW, kv_message.SerializeAsString().data(), kv_message.ByteSizeLong());
client_->Start();
});