forked from mindspore-Ecosystem/mindspore
Custom data transmission format
This commit is contained in:
parent
14a6713d08
commit
c7fe82b43d
|
@ -20,8 +20,9 @@ namespace mindspore {
|
|||
namespace ps {
|
||||
namespace core {
|
||||
void AbstractNode::Register(const std::shared_ptr<TcpClient> &client) {
|
||||
MessageMeta message_meta;
|
||||
message_meta.set_cmd(NodeCommand::REGISTER);
|
||||
MS_EXCEPTION_IF_NULL(client);
|
||||
auto message_meta = std::make_shared<MessageMeta>();
|
||||
message_meta->set_cmd(NodeCommand::REGISTER);
|
||||
|
||||
RegisterMessage register_message;
|
||||
register_message.set_node_id(node_info_.node_id_);
|
||||
|
@ -29,11 +30,8 @@ void AbstractNode::Register(const std::shared_ptr<TcpClient> &client) {
|
|||
register_message.set_ip(node_info_.ip_);
|
||||
register_message.set_port(node_info_.port_);
|
||||
|
||||
CommMessage comm_message;
|
||||
*comm_message.mutable_pb_meta() = {message_meta};
|
||||
comm_message.set_data(register_message.SerializeAsString());
|
||||
comm_message.set_user_cmd("");
|
||||
if (!SendMessageSync(client, comm_message)) {
|
||||
if (!SendMessageSync(client, message_meta, Protos::PROTOBUF, register_message.SerializeAsString().data(),
|
||||
register_message.ByteSizeLong())) {
|
||||
MS_LOG(EXCEPTION) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
|
||||
<< " the node id:" << node_info_.node_id_ << " register timeout!";
|
||||
}
|
||||
|
@ -42,9 +40,11 @@ void AbstractNode::Register(const std::shared_ptr<TcpClient> &client) {
|
|||
<< " the node id:" << node_info_.node_id_ << "is registering to scheduler!";
|
||||
}
|
||||
|
||||
void AbstractNode::ProcessRegisterResp(const CommMessage &message) {
|
||||
void AbstractNode::ProcessRegisterResp(std::shared_ptr<MessageMeta> meta, const void *data, size_t size) {
|
||||
MS_EXCEPTION_IF_NULL(meta);
|
||||
MS_EXCEPTION_IF_NULL(data);
|
||||
RegisterRespMessage register_resp_message;
|
||||
register_resp_message.ParseFromString(message.data());
|
||||
register_resp_message.ParseFromArray(data, size);
|
||||
if (register_resp_message.node_id() != node_info_.node_id_) {
|
||||
MS_LOG(EXCEPTION) << "The node id received:" << register_resp_message.node_id()
|
||||
<< " is not match the current node id:" << node_info_.node_id_;
|
||||
|
@ -52,28 +52,29 @@ void AbstractNode::ProcessRegisterResp(const CommMessage &message) {
|
|||
|
||||
node_info_.rank_id_ = register_resp_message.rank_id();
|
||||
|
||||
MS_LOG(INFO) << "The node id is:" << node_info_.node_id_ << ", and the rank id is:" << node_info_.rank_id_;
|
||||
MS_LOG(INFO) << "The node id is:" << node_info_.node_id_ << ", and the rank id is:" << node_info_.rank_id_
|
||||
<< " registered scheduler success!";
|
||||
}
|
||||
|
||||
bool AbstractNode::Broadcast(const enum NodeRole &node_role, const CommMessage &message, const uint32_t &timeout) {
|
||||
bool AbstractNode::Broadcast(const enum NodeRole &node_role, const DataPtr &message, size_t size, int command,
|
||||
const uint32_t &timeout) {
|
||||
MS_EXCEPTION_IF_NULL(message);
|
||||
if (node_role != NodeRole::SERVER) {
|
||||
MS_LOG(EXCEPTION) << "Currently only supports broadcast to server nodes";
|
||||
}
|
||||
|
||||
CommMessage &comm_message = const_cast<CommMessage &>(message);
|
||||
uint64_t request_id = ++next_request_id_;
|
||||
message_tracker_[request_id] = std::make_pair(nodes_address_.size(), 0);
|
||||
uint64_t request_id = AddMessageTrack(nodes_address_.size());
|
||||
|
||||
for (auto it = nodes_address_.begin(); it != nodes_address_.end(); ++it) {
|
||||
MessageMeta message_meta;
|
||||
message_meta.set_cmd(NodeCommand::SEND_DATA);
|
||||
message_meta.set_request_id(request_id);
|
||||
message_meta.set_rank_id(node_info_.rank_id_);
|
||||
message_meta.set_role(node_info_.node_role_);
|
||||
auto message_meta = std::make_shared<MessageMeta>();
|
||||
message_meta->set_cmd(NodeCommand::SEND_DATA);
|
||||
message_meta->set_request_id(request_id);
|
||||
message_meta->set_rank_id(node_info_.rank_id_);
|
||||
message_meta->set_role(node_info_.node_role_);
|
||||
message_meta->set_user_cmd(command);
|
||||
|
||||
*comm_message.mutable_pb_meta() = {message_meta};
|
||||
auto client = GetOrCreateTcpClient((*it).first.second);
|
||||
client->SendMessage(comm_message);
|
||||
client->SendMessage(message_meta, Protos::RAW, message.get(), size);
|
||||
}
|
||||
MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_)
|
||||
<< ", the node id is:" << node_info_.node_id_ << " send the request id is:" << request_id;
|
||||
|
@ -84,28 +85,27 @@ void AbstractNode::set_event_callback(const OnNodeEventMessage &on_node_event_me
|
|||
on_node_event_message_ = on_node_event_message;
|
||||
}
|
||||
|
||||
bool AbstractNode::Send(const enum NodeRole &node_role, const uint32_t &rank_id, const CommMessage &message,
|
||||
const uint32_t &timeout) {
|
||||
bool AbstractNode::Send(const enum NodeRole &node_role, const uint32_t &rank_id, const DataPtr &data, size_t len,
|
||||
int command, const uint32_t &timeout) {
|
||||
MS_EXCEPTION_IF_NULL(data);
|
||||
if (!CommUtil::ValidateRankId(node_role, rank_id)) {
|
||||
MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!";
|
||||
}
|
||||
|
||||
CommMessage &comm_message = const_cast<CommMessage &>(message);
|
||||
auto message_meta = std::make_shared<MessageMeta>();
|
||||
message_meta->set_cmd(NodeCommand::SEND_DATA);
|
||||
message_meta->set_rank_id(node_info_.rank_id_);
|
||||
message_meta->set_role(node_info_.node_role_);
|
||||
message_meta->set_user_cmd(command);
|
||||
|
||||
MessageMeta message_meta;
|
||||
message_meta.set_cmd(NodeCommand::SEND_DATA);
|
||||
message_meta.set_rank_id(node_info_.rank_id_);
|
||||
message_meta.set_role(node_info_.node_role_);
|
||||
|
||||
*comm_message.mutable_pb_meta() = {message_meta};
|
||||
auto client = GetOrCreateTcpClient(rank_id);
|
||||
return SendMessageSync(client, comm_message, timeout);
|
||||
return SendMessageSync(client, message_meta, Protos::RAW, data.get(), len, timeout);
|
||||
}
|
||||
|
||||
bool AbstractNode::Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids,
|
||||
const std::vector<CommMessage> &data, const uint32_t &timeout) {
|
||||
uint64_t request_id = ++next_request_id_;
|
||||
message_tracker_[request_id] = std::make_pair(data.size(), 0);
|
||||
const std::vector<DataPtr> &data, const std::vector<size_t> &lens, int command,
|
||||
const uint32_t &timeout) {
|
||||
uint64_t request_id = AddMessageTrack(data.size());
|
||||
|
||||
if (rank_ids.size() != data.size()) {
|
||||
MS_LOG(EXCEPTION) << "The number of rank ids is not equal to the number of data!";
|
||||
|
@ -115,34 +115,32 @@ bool AbstractNode::Send(const NodeRole &node_role, const std::vector<uint32_t> &
|
|||
MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!";
|
||||
}
|
||||
|
||||
MessageMeta message_meta;
|
||||
message_meta.set_cmd(NodeCommand::SEND_DATA);
|
||||
message_meta.set_request_id(request_id);
|
||||
message_meta.set_rank_id(node_info_.rank_id_);
|
||||
message_meta.set_role(node_info_.node_role_);
|
||||
|
||||
CommMessage &comm_message = const_cast<CommMessage &>(data.at(it));
|
||||
*comm_message.mutable_pb_meta() = {message_meta};
|
||||
auto message_meta = std::make_shared<MessageMeta>();
|
||||
message_meta->set_cmd(NodeCommand::SEND_DATA);
|
||||
message_meta->set_request_id(request_id);
|
||||
message_meta->set_rank_id(node_info_.rank_id_);
|
||||
message_meta->set_role(node_info_.node_role_);
|
||||
message_meta->set_user_cmd(command);
|
||||
|
||||
auto send = data.at(it);
|
||||
auto len = lens.at(it);
|
||||
auto client = GetOrCreateTcpClient(rank_ids.at(it));
|
||||
client->SendMessage(comm_message);
|
||||
client->SendMessage(message_meta, Protos::RAW, send.get(), len);
|
||||
}
|
||||
MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_)
|
||||
<< ", the node id is:" << node_info_.node_id_ << " send the request id is:" << request_id;
|
||||
return Wait(request_id, timeout);
|
||||
}
|
||||
|
||||
bool AbstractNode::Send(const enum NodeRole &node_role, const uint32_t &rank_id, const CommMessage &message,
|
||||
CommMessage *output, const uint32_t &timeout) {
|
||||
bool AbstractNode::Send(const enum NodeRole &node_role, const uint32_t &rank_id, const DataPtr &message, size_t len,
|
||||
int command, VectorPtr *output, const uint32_t &timeout) {
|
||||
MS_EXCEPTION_IF_NULL(message);
|
||||
MS_EXCEPTION_IF_NULL(output);
|
||||
if (!CommUtil::ValidateRankId(node_role, rank_id)) {
|
||||
MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!";
|
||||
}
|
||||
|
||||
CommMessage &comm_message = const_cast<CommMessage &>(message);
|
||||
|
||||
uint64_t request_id = ++next_request_id_;
|
||||
message_tracker_[request_id] = std::make_pair(1, 0);
|
||||
uint64_t request_id = AddMessageTrack(1);
|
||||
set_message_callback(request_id, [&]() {
|
||||
receive_messages_mutex_.lock();
|
||||
auto res = receive_messages_[request_id];
|
||||
|
@ -151,59 +149,59 @@ bool AbstractNode::Send(const enum NodeRole &node_role, const uint32_t &rank_id,
|
|||
receive_messages_mutex_.unlock();
|
||||
});
|
||||
|
||||
MessageMeta message_meta;
|
||||
message_meta.set_cmd(NodeCommand::SEND_DATA);
|
||||
message_meta.set_request_id(request_id);
|
||||
message_meta.set_rank_id(node_info_.rank_id_);
|
||||
message_meta.set_role(node_info_.node_role_);
|
||||
auto message_meta = std::make_shared<MessageMeta>();
|
||||
message_meta->set_cmd(NodeCommand::SEND_DATA);
|
||||
message_meta->set_request_id(request_id);
|
||||
message_meta->set_rank_id(node_info_.rank_id_);
|
||||
message_meta->set_role(node_info_.node_role_);
|
||||
message_meta->set_user_cmd(command);
|
||||
|
||||
*comm_message.mutable_pb_meta() = {message_meta};
|
||||
auto client = GetOrCreateTcpClient(rank_id);
|
||||
client->SendMessage(comm_message);
|
||||
client->SendMessage(message_meta, Protos::RAW, message.get(), len);
|
||||
MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_)
|
||||
<< ", the node id is:" << node_info_.node_id_ << " send the request id is:" << request_id;
|
||||
return Wait(request_id, timeout);
|
||||
}
|
||||
|
||||
bool AbstractNode::Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids,
|
||||
const std::vector<CommMessage> &data, std::vector<CommMessage> *output,
|
||||
const uint32_t &timeout) {
|
||||
const std::vector<DataPtr> &data, const std::vector<size_t> &data_lens, int command,
|
||||
std::vector<VectorPtr> *output, const uint32_t &timeout) {
|
||||
MS_EXCEPTION_IF_NULL(output);
|
||||
uint64_t request_id = ++next_request_id_;
|
||||
message_tracker_[request_id] = std::make_pair(data.size(), 0);
|
||||
uint64_t request_id = AddMessageTrack(data.size());
|
||||
|
||||
if (rank_ids.size() != data.size()) {
|
||||
MS_LOG(EXCEPTION) << "The number of rank ids, data, comm_message_resp should be equal!";
|
||||
}
|
||||
|
||||
size_t len = rank_ids.size();
|
||||
size_t size = rank_ids.size();
|
||||
|
||||
set_message_callback(request_id, [&]() {
|
||||
receive_messages_mutex_.lock();
|
||||
auto res = receive_messages_[request_id];
|
||||
for (size_t it = 0; it < len; ++it) {
|
||||
for (size_t it = 0; it < size; ++it) {
|
||||
(*output).push_back(res[rank_ids.at(it)]);
|
||||
}
|
||||
receive_messages_.erase(request_id);
|
||||
receive_messages_mutex_.unlock();
|
||||
});
|
||||
|
||||
for (size_t it = 0; it < len; ++it) {
|
||||
for (size_t it = 0; it < size; ++it) {
|
||||
if (!CommUtil::ValidateRankId(node_role, rank_ids.at(it))) {
|
||||
MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!";
|
||||
}
|
||||
|
||||
MessageMeta message_meta;
|
||||
message_meta.set_cmd(NodeCommand::SEND_DATA);
|
||||
message_meta.set_request_id(request_id);
|
||||
message_meta.set_rank_id(node_info_.rank_id_);
|
||||
message_meta.set_role(node_info_.node_role_);
|
||||
auto message_meta = std::make_shared<MessageMeta>();
|
||||
message_meta->set_cmd(NodeCommand::SEND_DATA);
|
||||
message_meta->set_request_id(request_id);
|
||||
message_meta->set_rank_id(node_info_.rank_id_);
|
||||
message_meta->set_role(node_info_.node_role_);
|
||||
message_meta->set_user_cmd(command);
|
||||
|
||||
CommMessage &comm_message = const_cast<CommMessage &>(data.at(it));
|
||||
*comm_message.mutable_pb_meta() = {message_meta};
|
||||
auto send = data.at(it);
|
||||
auto len = data_lens.at(it);
|
||||
|
||||
auto client = GetOrCreateTcpClient(rank_ids.at(it));
|
||||
client->SendMessage(comm_message);
|
||||
client->SendMessage(message_meta, Protos::RAW, send.get(), len);
|
||||
}
|
||||
MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_)
|
||||
<< ", the node id is:" << node_info_.node_id_ << " send the request id is:" << request_id;
|
||||
|
@ -220,55 +218,61 @@ bool AbstractNode::Wait(uint64_t request_id, const uint32_t &timeout) {
|
|||
return res;
|
||||
}
|
||||
|
||||
uint64_t AbstractNode::CollectiveSendAsync(const enum NodeRole &node_role, const uint32_t &rank_id,
|
||||
const CommMessage &message) {
|
||||
uint64_t AbstractNode::CollectiveSendAsync(const enum NodeRole &node_role, const uint32_t &rank_id, const void *data,
|
||||
size_t size) {
|
||||
MS_EXCEPTION_IF_NULL(data);
|
||||
if (!CommUtil::ValidateRankId(node_role, rank_id)) {
|
||||
MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!";
|
||||
}
|
||||
|
||||
CommMessage &comm_message = const_cast<CommMessage &>(message);
|
||||
std::shared_ptr<MessageMeta> message_meta = std::make_shared<MessageMeta>();
|
||||
message_meta->set_cmd(NodeCommand::COLLECTIVE_SEND_DATA);
|
||||
message_meta->set_rank_id(node_info_.rank_id_);
|
||||
message_meta->set_role(node_info_.node_role_);
|
||||
|
||||
MessageMeta message_meta;
|
||||
message_meta.set_cmd(NodeCommand::COLLECTIVE_SEND_DATA);
|
||||
message_meta.set_rank_id(node_info_.rank_id_);
|
||||
message_meta.set_role(node_info_.node_role_);
|
||||
|
||||
*comm_message.mutable_pb_meta() = {message_meta};
|
||||
auto client = GetOrCreateTcpClient(rank_id);
|
||||
return SendMessageAsync(client, comm_message);
|
||||
return SendMessageAsync(client, message_meta, Protos::RAW, data, size);
|
||||
}
|
||||
|
||||
std::pair<uint32_t, uint64_t> AbstractNode::CollectiveReceiveAsync(const enum NodeRole &node_role,
|
||||
const uint32_t &rank_id, CommMessage *output) {
|
||||
const uint32_t &rank_id, void **output,
|
||||
size_t *size) {
|
||||
MS_EXCEPTION_IF_NULL(output);
|
||||
MS_EXCEPTION_IF_NULL(size);
|
||||
if (!CommUtil::ValidateRankId(node_role, rank_id)) {
|
||||
MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!";
|
||||
}
|
||||
|
||||
uint64_t rank_request_id = NextExpectedRankRequestId(rank_id);
|
||||
if (received_data_.count(std::make_pair(rank_id, rank_request_id)) > 0) {
|
||||
*output = received_data_[std::make_pair(rank_id, rank_request_id)];
|
||||
received_data_.erase(std::make_pair(rank_id, rank_request_id));
|
||||
} else {
|
||||
set_receive_callback(rank_id, rank_request_id, [=]() {
|
||||
receive_callbacks_mutex_.lock();
|
||||
*output = received_data_[std::make_pair(rank_id, rank_request_id)];
|
||||
uint64_t rank_request_id = NextExpectedRankRequestId(rank_id);
|
||||
receive_messages_done_[std::make_pair(rank_id, rank_request_id)] = false;
|
||||
if (received_data_.count(std::make_pair(rank_id, rank_request_id)) > 0) {
|
||||
auto res = received_data_[std::make_pair(rank_id, rank_request_id)];
|
||||
*output = res->data();
|
||||
*size = res->size();
|
||||
received_data_.erase(std::make_pair(rank_id, rank_request_id));
|
||||
receive_messages_done_[std::make_pair(rank_id, rank_request_id)] = true;
|
||||
MS_LOG(DEBUG) << "Receive data from rank id:" << rank_id << ", the rank request id is:" << rank_request_id;
|
||||
} else {
|
||||
receive_callbacks_[std::make_pair(rank_id, rank_request_id)] = [=]() mutable {
|
||||
receive_callbacks_mutex_.lock();
|
||||
auto res = received_data_[std::make_pair(rank_id, rank_request_id)];
|
||||
*output = res->data();
|
||||
*size = res->size();
|
||||
received_data_.erase(std::make_pair(rank_id, rank_request_id));
|
||||
receive_messages_done_[std::make_pair(rank_id, rank_request_id)] = true;
|
||||
MS_LOG(DEBUG) << "Receive data from rank id:" << rank_id << ", the rank request id is:" << rank_request_id;
|
||||
receive_callbacks_mutex_.unlock();
|
||||
});
|
||||
};
|
||||
}
|
||||
receive_callbacks_mutex_.unlock();
|
||||
return std::make_pair(rank_id, rank_request_id);
|
||||
}
|
||||
|
||||
bool AbstractNode::CollectiveWait(std::pair<uint32_t, uint64_t> request_id, const uint32_t &timeout) {
|
||||
std::unique_lock<std::mutex> lock(receive_callbacks_mutex_);
|
||||
bool res = receive_cond_.wait_for(lock, std::chrono::seconds(timeout), [&] {
|
||||
if (actual_rank_request_ids_.count(request_id.first) &&
|
||||
(actual_rank_request_ids_[request_id.first] >= request_id.second)) {
|
||||
return true;
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
});
|
||||
bool res =
|
||||
receive_cond_.wait_for(lock, std::chrono::seconds(timeout), [&] { return receive_messages_done_[request_id]; });
|
||||
return res;
|
||||
}
|
||||
|
||||
|
@ -297,17 +301,15 @@ void AbstractNode::StartHeartbeatTimer(const std::shared_ptr<TcpClient> &client)
|
|||
}
|
||||
|
||||
bool AbstractNode::Heartbeat(const std::shared_ptr<TcpClient> &client, bool is_node_finish) {
|
||||
MessageMeta meta;
|
||||
meta.set_cmd(NodeCommand::HEARTBEAT);
|
||||
auto meta = std::make_shared<MessageMeta>();
|
||||
meta->set_cmd(NodeCommand::HEARTBEAT);
|
||||
|
||||
HeartbeatMessage heartbeat_message;
|
||||
heartbeat_message.set_node_id(node_info_.node_id_);
|
||||
heartbeat_message.set_is_node_finish(is_node_finish);
|
||||
|
||||
CommMessage message;
|
||||
*message.mutable_pb_meta() = {meta};
|
||||
message.set_data(heartbeat_message.SerializeAsString());
|
||||
if (!SendMessageSync(client, message)) {
|
||||
if (!SendMessageSync(client, meta, Protos::PROTOBUF, heartbeat_message.SerializeAsString().data(),
|
||||
heartbeat_message.ByteSizeLong())) {
|
||||
MS_LOG(ERROR) << "The node id:" << node_info_.node_id_ << " Send heartbeat timeout!";
|
||||
}
|
||||
return true;
|
||||
|
@ -331,9 +333,11 @@ bool AbstractNode::CheckSchedulerTimeout() const {
|
|||
return false;
|
||||
}
|
||||
|
||||
void AbstractNode::ProcessHeartbeatResp(const CommMessage &message) {
|
||||
void AbstractNode::ProcessHeartbeatResp(std::shared_ptr<MessageMeta> meta, const void *data, size_t size) {
|
||||
MS_EXCEPTION_IF_NULL(meta);
|
||||
MS_EXCEPTION_IF_NULL(data);
|
||||
HeartbeatRespMessage heartbeat_resp_message;
|
||||
heartbeat_resp_message.ParseFromString(message.data());
|
||||
heartbeat_resp_message.ParseFromArray(data, size);
|
||||
|
||||
is_ready_ = heartbeat_resp_message.is_cluster_ready();
|
||||
if (is_ready_.load()) {
|
||||
|
@ -359,19 +363,22 @@ void AbstractNode::ProcessHeartbeatResp(const CommMessage &message) {
|
|||
}
|
||||
|
||||
void AbstractNode::FetchServers(const std::shared_ptr<TcpClient> &client) {
|
||||
MessageMeta meta;
|
||||
meta.set_cmd(NodeCommand::FETCH_SERVER);
|
||||
auto meta = std::make_shared<MessageMeta>();
|
||||
meta->set_cmd(NodeCommand::FETCH_SERVER);
|
||||
|
||||
CommMessage message;
|
||||
*message.mutable_pb_meta() = {meta};
|
||||
if (!SendMessageSync(client, message)) {
|
||||
FetchServersMessage fetch_servers;
|
||||
fetch_servers.set_node_id(node_info_.node_id_);
|
||||
if (!SendMessageSync(client, meta, Protos::PROTOBUF, fetch_servers.SerializeAsString().data(),
|
||||
fetch_servers.ByteSizeLong())) {
|
||||
MS_LOG(EXCEPTION) << "Fetch servers address timeout!";
|
||||
}
|
||||
}
|
||||
|
||||
void AbstractNode::ProcessFetchServersResp(const CommMessage &message) {
|
||||
void AbstractNode::ProcessFetchServersResp(std::shared_ptr<MessageMeta> meta, const void *data, size_t size) {
|
||||
MS_EXCEPTION_IF_NULL(meta);
|
||||
MS_EXCEPTION_IF_NULL(data);
|
||||
FetchServersRespMessage fetch_servers_resp_message;
|
||||
fetch_servers_resp_message.ParseFromString(message.data());
|
||||
fetch_servers_resp_message.ParseFromArray(data, size);
|
||||
|
||||
for (const auto &it : fetch_servers_resp_message.servers_meta()) {
|
||||
nodes_address_[std::make_pair(NodeRole::SERVER, it.rank_id())] = std::make_pair(it.ip(), it.port());
|
||||
|
@ -381,16 +388,14 @@ void AbstractNode::ProcessFetchServersResp(const CommMessage &message) {
|
|||
}
|
||||
|
||||
bool AbstractNode::Disconnect(const std::shared_ptr<TcpClient> &client, const uint32_t &timeout) {
|
||||
MessageMeta meta;
|
||||
meta.set_cmd(NodeCommand::FINISH);
|
||||
auto meta = std::make_shared<MessageMeta>();
|
||||
meta->set_cmd(NodeCommand::FINISH);
|
||||
|
||||
FinishMessage finish_message;
|
||||
finish_message.set_node_id(node_info_.node_id_);
|
||||
|
||||
CommMessage message;
|
||||
*message.mutable_pb_meta() = {meta};
|
||||
message.set_data(finish_message.SerializeAsString());
|
||||
if (!SendMessageSync(client, message)) {
|
||||
if (!SendMessageSync(client, meta, Protos::PROTOBUF, finish_message.SerializeAsString().data(),
|
||||
finish_message.ByteSizeLong())) {
|
||||
MS_LOG(ERROR) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
|
||||
<< " the node id:" << node_info_.node_id_ << " send Finish Message timeout!";
|
||||
}
|
||||
|
@ -412,15 +417,16 @@ bool AbstractNode::InitClientToScheduler() {
|
|||
std::string scheduler_host = ClusterConfig::scheduler_host();
|
||||
uint16_t scheduler_port = ClusterConfig::scheduler_port();
|
||||
client_to_scheduler_ = std::make_shared<TcpClient>(scheduler_host, scheduler_port);
|
||||
client_to_scheduler_->SetMessageCallback([&](const TcpClient &client, const CommMessage &message) {
|
||||
if (handlers_.count(message.pb_meta().cmd()) == 0) {
|
||||
MS_LOG(EXCEPTION) << "The cmd:" << message.pb_meta().cmd() << " is not supported!";
|
||||
client_to_scheduler_->SetMessageCallback(
|
||||
[&](std::shared_ptr<MessageMeta> meta, const Protos &protos, const void *data, size_t size) {
|
||||
if (handlers_.count(meta->cmd()) == 0) {
|
||||
MS_LOG(EXCEPTION) << "The cmd:" << meta->cmd() << " is not supported!";
|
||||
}
|
||||
if (handlers_[message.pb_meta().cmd()] != nullptr) {
|
||||
const auto &handler_ptr = handlers_[message.pb_meta().cmd()];
|
||||
(this->*handler_ptr)(message);
|
||||
if (handlers_[meta->cmd()] != nullptr) {
|
||||
const auto &handler_ptr = handlers_[meta->cmd()];
|
||||
(this->*handler_ptr)(meta, data, size);
|
||||
}
|
||||
NotifyMessageArrival(message);
|
||||
NotifyMessageArrival(meta);
|
||||
});
|
||||
|
||||
client_to_scheduler_->Init();
|
||||
|
@ -447,19 +453,20 @@ const std::shared_ptr<TcpClient> &AbstractNode::GetOrCreateTcpClient(const int &
|
|||
std::string ip = nodes_address_[std::make_pair(NodeRole::SERVER, rank_id)].first;
|
||||
uint16_t port = nodes_address_[std::make_pair(NodeRole::SERVER, rank_id)].second;
|
||||
auto client = std::make_shared<TcpClient>(ip, port);
|
||||
client->SetMessageCallback([&](const TcpClient &client, const CommMessage &message) {
|
||||
switch (message.pb_meta().cmd()) {
|
||||
client->SetMessageCallback([&](std::shared_ptr<MessageMeta> meta, const Protos &protos, const void *data,
|
||||
size_t size) {
|
||||
switch (meta->cmd()) {
|
||||
case NodeCommand::SEND_DATA:
|
||||
ProcessSendDataResp(message);
|
||||
RunMessageCallback(message.pb_meta().request_id());
|
||||
ProcessSendDataResp(meta, protos, data, size);
|
||||
RunMessageCallback(meta->request_id());
|
||||
break;
|
||||
case NodeCommand::COLLECTIVE_SEND_DATA:
|
||||
MS_LOG(INFO) << "The Node id:" << node_info_.node_id_ << " receive a collective_send_data message response!";
|
||||
MS_LOG(DEBUG) << "The Node id:" << node_info_.node_id_ << " receive a collective_send_data message response!";
|
||||
break;
|
||||
default:
|
||||
MS_LOG(EXCEPTION) << "The cmd:" << message.pb_meta().cmd() << " is not supported!";
|
||||
MS_LOG(EXCEPTION) << "The cmd:" << meta->cmd() << " is not supported!";
|
||||
}
|
||||
NotifyMessageArrival(message);
|
||||
NotifyMessageArrival(meta);
|
||||
});
|
||||
client->Init();
|
||||
connected_nodes_[rank_id] = client;
|
||||
|
@ -469,8 +476,7 @@ const std::shared_ptr<TcpClient> &AbstractNode::GetOrCreateTcpClient(const int &
|
|||
|
||||
bool AbstractNode::SendMessageSync(const std::shared_ptr<TcpClient> &client, const CommMessage &message,
|
||||
const uint32_t &timeout) {
|
||||
uint64_t request_id = ++next_request_id_;
|
||||
message_tracker_[request_id] = std::make_pair(1, 0);
|
||||
uint64_t request_id = AddMessageTrack(1);
|
||||
const_cast<CommMessage &>(message).mutable_pb_meta()->set_request_id(request_id);
|
||||
client->SendMessage(message);
|
||||
MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_)
|
||||
|
@ -478,29 +484,55 @@ bool AbstractNode::SendMessageSync(const std::shared_ptr<TcpClient> &client, con
|
|||
return Wait(request_id, timeout);
|
||||
}
|
||||
|
||||
uint64_t AbstractNode::SendMessageAsync(const std::shared_ptr<TcpClient> &client, const CommMessage &message) {
|
||||
uint64_t request_id = ++next_request_id_;
|
||||
message_tracker_[request_id] = std::make_pair(1, 0);
|
||||
const_cast<CommMessage &>(message).mutable_pb_meta()->set_request_id(request_id);
|
||||
client->SendMessage(message);
|
||||
uint64_t AbstractNode::SendMessageAsync(const std::shared_ptr<TcpClient> &client, std::shared_ptr<MessageMeta> meta,
|
||||
const Protos &protos, const void *data, size_t size) {
|
||||
MS_EXCEPTION_IF_NULL(client);
|
||||
MS_EXCEPTION_IF_NULL(meta);
|
||||
MS_EXCEPTION_IF_NULL(data);
|
||||
uint64_t request_id = AddMessageTrack(1);
|
||||
meta->set_request_id(request_id);
|
||||
client->SendMessage(meta, protos, data, size);
|
||||
MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_)
|
||||
<< ", the node id is:" << node_info_.node_id_ << " send the request id is:" << request_id;
|
||||
return request_id;
|
||||
}
|
||||
|
||||
void AbstractNode::ProcessSendDataResp(const CommMessage &message) {
|
||||
bool AbstractNode::SendMessageSync(const std::shared_ptr<TcpClient> &client, std::shared_ptr<MessageMeta> meta,
|
||||
const Protos &protos, const void *data, size_t size, const uint32_t &timeout) {
|
||||
MS_EXCEPTION_IF_NULL(client);
|
||||
MS_EXCEPTION_IF_NULL(meta);
|
||||
MS_EXCEPTION_IF_NULL(data);
|
||||
uint64_t request_id = AddMessageTrack(1);
|
||||
meta->set_request_id(request_id);
|
||||
client->SendMessage(meta, protos, data, size);
|
||||
MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_)
|
||||
<< ", the node id is:" << node_info_.node_id_ << " send the request id is:" << request_id;
|
||||
bool res = Wait(request_id, timeout);
|
||||
return res;
|
||||
}
|
||||
|
||||
void AbstractNode::ProcessSendDataResp(std::shared_ptr<MessageMeta> meta, const Protos &protos, const void *data,
|
||||
size_t size) {
|
||||
MS_EXCEPTION_IF_NULL(meta);
|
||||
MS_EXCEPTION_IF_NULL(data);
|
||||
std::lock_guard<std::mutex> lock(receive_messages_mutex_);
|
||||
const MessageMeta &message_meta = message.pb_meta();
|
||||
const uint32_t &rank_id = message_meta.rank_id();
|
||||
const uint64_t request_id = message_meta.request_id();
|
||||
const uint32_t &rank_id = meta->rank_id();
|
||||
const uint64_t request_id = meta->request_id();
|
||||
MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_)
|
||||
<< ", the node id is:" << node_info_.node_id_ << " send the request id is:" << request_id;
|
||||
auto it = receive_messages_.find(request_id);
|
||||
VectorPtr received_data = std::make_shared<std::vector<unsigned char>>(size, 0);
|
||||
if (size > 0) {
|
||||
int ret = memcpy_s(received_data.get()->data(), size, data, size);
|
||||
if (ret != 0) {
|
||||
MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
|
||||
}
|
||||
}
|
||||
if (it != receive_messages_.end()) {
|
||||
it->second[rank_id] = message;
|
||||
it->second[rank_id] = received_data;
|
||||
} else {
|
||||
std::unordered_map<uint32_t, CommMessage> res;
|
||||
res.insert(std::make_pair(rank_id, message));
|
||||
std::unordered_map<uint32_t, VectorPtr> res;
|
||||
res.insert(std::make_pair(rank_id, received_data));
|
||||
receive_messages_[request_id] = res;
|
||||
}
|
||||
}
|
||||
|
@ -509,7 +541,7 @@ void AbstractNode::RunMessageCallback(const uint64_t &request_id) {
|
|||
message_callbacks_mutex_.lock();
|
||||
// When receiving a message's response, Then compare with the desired number of responses,
|
||||
// If they are equal, then call the callback function
|
||||
if (message_tracker_[request_id].first == message_tracker_[request_id].second + 1) {
|
||||
if (CheckMessageTrack(request_id)) {
|
||||
auto it = message_callbacks_.find(request_id);
|
||||
if (it != message_callbacks_.end()) {
|
||||
message_callbacks_mutex_.unlock();
|
||||
|
@ -533,31 +565,31 @@ void AbstractNode::set_message_callback(const uint64_t &request_id, const Messag
|
|||
message_callbacks_[request_id] = callback;
|
||||
}
|
||||
|
||||
void AbstractNode::NotifyMessageArrival(const CommMessage &message) {
|
||||
void AbstractNode::NotifyMessageArrival(std::shared_ptr<MessageMeta> meta) {
|
||||
std::lock_guard<std::mutex> lock(message_tracker_mutex_);
|
||||
const MessageMeta &message_meta = message.pb_meta();
|
||||
uint64_t request_id = message_meta.request_id();
|
||||
uint64_t request_id = meta->request_id();
|
||||
|
||||
message_tracker_[request_id].second++;
|
||||
message_tracker_cond_.notify_all();
|
||||
}
|
||||
|
||||
void AbstractNode::set_receive_callback(const uint32_t &rank_id, const uint64_t &request_id,
|
||||
const MessageCallback &callback) {
|
||||
if (!callback) {
|
||||
return;
|
||||
}
|
||||
std::lock_guard<std::mutex> lock(receive_callbacks_mutex_);
|
||||
receive_callbacks_[std::make_pair(rank_id, request_id)] = callback;
|
||||
}
|
||||
|
||||
void AbstractNode::RunReceiveCallback(const CommMessage &message) {
|
||||
void AbstractNode::RunReceiveCallback(std::shared_ptr<MessageMeta> meta, const Protos &protos, const void *data,
|
||||
size_t size) {
|
||||
MS_EXCEPTION_IF_NULL(meta);
|
||||
MS_EXCEPTION_IF_NULL(data);
|
||||
receive_callbacks_mutex_.lock();
|
||||
uint32_t rank_id = message.pb_meta().rank_id();
|
||||
uint32_t rank_id = meta->rank_id();
|
||||
// When receiving a collective message, Then generate rank request id,compare with the desired rank request id,
|
||||
// If they are equal, then call the callback function
|
||||
uint64_t rank_request_id = NextActualRankRequestId(rank_id);
|
||||
received_data_[std::make_pair(rank_id, rank_request_id)] = message;
|
||||
std::shared_ptr<std::vector<unsigned char>> received_data = std::make_shared<std::vector<unsigned char>>(size, 0);
|
||||
int ret = memcpy_s(received_data->data(), size, data, size);
|
||||
if (ret != 0) {
|
||||
MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
|
||||
}
|
||||
received_data_[std::make_pair(rank_id, rank_request_id)] = received_data;
|
||||
MS_LOG(DEBUG) << "Run Receive data callback,the rank id:" << rank_id << ", the rank request id is:" << rank_request_id
|
||||
<< ", the send request id is:" << meta->request_id();
|
||||
auto it = receive_callbacks_.find(std::make_pair(rank_id, rank_request_id));
|
||||
if (it != receive_callbacks_.end()) {
|
||||
receive_callbacks_mutex_.unlock();
|
||||
|
@ -603,6 +635,18 @@ void AbstractNode::InitCommandHandler() {
|
|||
handlers_[NodeCommand::FETCH_SERVER] = &AbstractNode::ProcessFetchServersResp;
|
||||
handlers_[NodeCommand::FINISH] = nullptr;
|
||||
}
|
||||
|
||||
uint64_t AbstractNode::AddMessageTrack(const uint32_t &expected_response) {
|
||||
std::lock_guard<std::mutex> lock(message_tracker_mutex_);
|
||||
uint64_t request_id = ++next_request_id_;
|
||||
message_tracker_[request_id] = std::make_pair(expected_response, 0);
|
||||
return request_id;
|
||||
}
|
||||
|
||||
bool AbstractNode::CheckMessageTrack(const uint64_t &request_id) {
|
||||
std::lock_guard<std::mutex> lock(message_tracker_mutex_);
|
||||
return message_tracker_[request_id].first == message_tracker_[request_id].second + 1;
|
||||
}
|
||||
} // namespace core
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -25,6 +25,7 @@
|
|||
#include <unordered_map>
|
||||
|
||||
#include "ps/core/node.h"
|
||||
#include "ps/core/message.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
|
@ -34,53 +35,63 @@ class AbstractNode : public Node {
|
|||
AbstractNode() : heart_beat_thread_(nullptr), client_to_scheduler_thread_(nullptr), client_to_scheduler_(nullptr) {}
|
||||
~AbstractNode() override = default;
|
||||
|
||||
typedef void (AbstractNode::*ResponseHandler)(const CommMessage &message);
|
||||
typedef void (AbstractNode::*ResponseHandler)(std::shared_ptr<MessageMeta> meta, const void *data, size_t size);
|
||||
|
||||
bool Broadcast(const enum NodeRole &node_role, const CommMessage &message,
|
||||
using DataPtr = std::shared_ptr<unsigned char>;
|
||||
using VectorPtr = std::shared_ptr<std::vector<unsigned char>>;
|
||||
|
||||
bool Broadcast(const enum NodeRole &node_role, const DataPtr &message, size_t size, int command,
|
||||
const uint32_t &timeout = kCommTimeoutInSeconds);
|
||||
void set_event_callback(const OnNodeEventMessage &on_node_event_message);
|
||||
|
||||
bool Send(const enum NodeRole &node_role, const uint32_t &rank_id, const CommMessage &message,
|
||||
bool Send(const enum NodeRole &node_role, const uint32_t &rank_id, const DataPtr &data, size_t len, int command,
|
||||
const uint32_t &timeout = kCommTimeoutInSeconds);
|
||||
bool Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids, const std::vector<CommMessage> &data,
|
||||
bool Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids, const std::vector<DataPtr> &data,
|
||||
const std::vector<size_t> &lens, int command, const uint32_t &timeout = kCommTimeoutInSeconds);
|
||||
bool Send(const enum NodeRole &node_role, const uint32_t &rank_id, const DataPtr &message, size_t len, int command,
|
||||
VectorPtr *output, const uint32_t &timeout = kCommTimeoutInSeconds);
|
||||
bool Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids, const std::vector<DataPtr> &data,
|
||||
const std::vector<size_t> &data_lens, int command, std::vector<VectorPtr> *output,
|
||||
const uint32_t &timeout = kCommTimeoutInSeconds);
|
||||
bool Send(const enum NodeRole &node_role, const uint32_t &rank_id, const CommMessage &message, CommMessage *output,
|
||||
const uint32_t &timeout = kCommTimeoutInSeconds);
|
||||
bool Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids, const std::vector<CommMessage> &data,
|
||||
std::vector<CommMessage> *output, const uint32_t &timeout = kCommTimeoutInSeconds);
|
||||
bool Wait(uint64_t request_id, const uint32_t &timeout = kCommTimeoutInSeconds);
|
||||
|
||||
uint64_t CollectiveSendAsync(const enum NodeRole &node_role, const uint32_t &rank_id, const CommMessage &message);
|
||||
uint64_t CollectiveSendAsync(const enum NodeRole &node_role, const uint32_t &rank_id, const void *data, size_t size);
|
||||
std::pair<uint32_t, uint64_t> CollectiveReceiveAsync(const enum NodeRole &node_role, const uint32_t &rank_id,
|
||||
CommMessage *output);
|
||||
void **output, size_t *size);
|
||||
bool CollectiveWait(std::pair<uint32_t, uint64_t> request_id, const uint32_t &timeout = kCommTimeoutInSeconds);
|
||||
|
||||
protected:
|
||||
void Register(const std::shared_ptr<TcpClient> &client);
|
||||
void ProcessRegisterResp(const CommMessage &message);
|
||||
void StartHeartbeatTimer(const std::shared_ptr<TcpClient> &client);
|
||||
bool Heartbeat(const std::shared_ptr<TcpClient> &client, bool is_node_finish = false);
|
||||
void FetchServers(const std::shared_ptr<TcpClient> &client);
|
||||
|
||||
void ProcessRegisterResp(std::shared_ptr<MessageMeta> meta, const void *data, size_t size);
|
||||
void ProcessHeartbeatResp(std::shared_ptr<MessageMeta> meta, const void *data, size_t size);
|
||||
void ProcessFetchServersResp(std::shared_ptr<MessageMeta> meta, const void *data, size_t size);
|
||||
|
||||
void StartHeartbeatTimer(const std::shared_ptr<TcpClient> &client);
|
||||
void UpdateSchedulerTime();
|
||||
bool CheckSchedulerTimeout() const;
|
||||
void ProcessHeartbeatResp(const CommMessage &message);
|
||||
void FetchServers(const std::shared_ptr<TcpClient> &client);
|
||||
void ProcessFetchServersResp(const CommMessage &message);
|
||||
bool Disconnect(const std::shared_ptr<TcpClient> &client, const uint32_t &timeout);
|
||||
bool WaitForDisconnect(const uint32_t &timeout);
|
||||
bool InitClientToScheduler();
|
||||
const std::shared_ptr<TcpClient> &GetOrCreateTcpClient(const int &rank_id);
|
||||
bool SendMessageSync(const std::shared_ptr<TcpClient> &client, const CommMessage &message,
|
||||
const uint32_t &timeout = kCommTimeoutInSeconds);
|
||||
uint64_t SendMessageAsync(const std::shared_ptr<TcpClient> &client, const CommMessage &message);
|
||||
void ProcessSendDataResp(const CommMessage &message);
|
||||
bool SendMessageSync(const std::shared_ptr<TcpClient> &client, std::shared_ptr<MessageMeta>, const Protos &,
|
||||
const void *, size_t size, const uint32_t &timeout = kCommTimeoutInSeconds);
|
||||
uint64_t SendMessageAsync(const std::shared_ptr<TcpClient> &client, std::shared_ptr<MessageMeta> meta,
|
||||
const Protos &protos, const void *data, size_t size);
|
||||
void ProcessSendDataResp(std::shared_ptr<MessageMeta> meta, const Protos &protos, const void *data, size_t size);
|
||||
void RunMessageCallback(const uint64_t &request_id);
|
||||
void set_message_callback(const uint64_t &request_id, const MessageCallback &callback);
|
||||
void NotifyMessageArrival(const CommMessage &message);
|
||||
void set_receive_callback(const uint32_t &rank_id, const uint64_t &request_id, const MessageCallback &callback);
|
||||
void RunReceiveCallback(const CommMessage &message);
|
||||
void NotifyMessageArrival(std::shared_ptr<MessageMeta> meta);
|
||||
void RunReceiveCallback(std::shared_ptr<MessageMeta> meta, const Protos &protos, const void *data, size_t size);
|
||||
uint64_t NextExpectedRankRequestId(const uint32_t &rank_id);
|
||||
uint64_t NextActualRankRequestId(const uint32_t &rank_id);
|
||||
void InitCommandHandler();
|
||||
uint64_t AddMessageTrack(const uint32_t &expected_response);
|
||||
bool CheckMessageTrack(const uint64_t &request_id);
|
||||
|
||||
std::unique_ptr<std::thread> heart_beat_thread_;
|
||||
std::unique_ptr<std::thread> client_to_scheduler_thread_;
|
||||
|
@ -98,15 +109,16 @@ class AbstractNode : public Node {
|
|||
std::mutex message_tracker_mutex_;
|
||||
std::condition_variable message_tracker_cond_;
|
||||
|
||||
// the key is: request_id, the value is:<rank_id, CommMessage>
|
||||
std::unordered_map<uint64_t, std::unordered_map<uint32_t, CommMessage>> receive_messages_;
|
||||
// the key is: request_id, the value is: <rank_id, RecvMessage>
|
||||
std::unordered_map<uint64_t, std::unordered_map<uint32_t, VectorPtr>> receive_messages_;
|
||||
std::map<std::pair<uint32_t, uint64_t>, bool> receive_messages_done_;
|
||||
std::mutex receive_messages_mutex_;
|
||||
// the key is: request_id
|
||||
std::unordered_map<uint64_t, MessageCallback> message_callbacks_;
|
||||
std::mutex message_callbacks_mutex_;
|
||||
|
||||
// the key is <rank_id, rank_request_id>
|
||||
std::map<std::pair<uint32_t, uint64_t>, CommMessage> received_data_;
|
||||
std::map<std::pair<uint32_t, uint64_t>, std::shared_ptr<std::vector<unsigned char>>> received_data_;
|
||||
std::mutex receive_callbacks_mutex_;
|
||||
// the key is <rank_id, rank_request_id>
|
||||
std::map<std::pair<uint32_t, uint64_t>, MessageCallback> receive_callbacks_;
|
||||
|
|
|
@ -0,0 +1,59 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_PS_CORE_MESSAGE_H_
|
||||
#define MINDSPORE_CCSRC_PS_CORE_MESSAGE_H_
|
||||
|
||||
#include <string>
|
||||
#include <memory>
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
namespace core {
|
||||
enum class Protos : uint32_t { RAW = 0, PROTOBUF = 1, FLATBUFFERS = 2 };
|
||||
|
||||
enum class Command {
|
||||
TERMINATE = 0,
|
||||
REGISTER = 1,
|
||||
HEARTBEAT = 2,
|
||||
SEND_DATA = 3,
|
||||
FETCH_SERVER = 4,
|
||||
FINISH = 5,
|
||||
COLLECTIVE_SEND_DATA = 6
|
||||
};
|
||||
|
||||
enum class Role { SERVER = 0, WORKER = 1, SCHEDULER = 2 };
|
||||
|
||||
struct MessageHeader {
|
||||
Protos message_proto_ = Protos::RAW;
|
||||
uint32_t message_meta_length_ = 0;
|
||||
uint64_t message_length_ = 0;
|
||||
};
|
||||
|
||||
struct CommandMeta {
|
||||
// the command of this message,for example: register,heartbeat,data
|
||||
Command cmd;
|
||||
// the request id of this message
|
||||
uint64_t request_id;
|
||||
// the role of the current node: worker,server,scheduler
|
||||
Role role;
|
||||
// the current Node rank id,the worker node range is:[0,numOfWorker-1], the server node range is:[0, numOfServer-1]
|
||||
int32_t rank_id = 4;
|
||||
};
|
||||
} // namespace core
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_PS_CORE_MESSAGE_H_
|
|
@ -15,7 +15,6 @@
|
|||
*/
|
||||
|
||||
syntax = "proto3";
|
||||
import "google/protobuf/any.proto";
|
||||
package mindspore.ps.core;
|
||||
option optimize_for = LITE_RUNTIME;
|
||||
|
||||
|
@ -44,6 +43,8 @@ message MessageMeta {
|
|||
NodeRole role = 3;
|
||||
// the current Node rank id,the worker node range is:[0,numOfWorker-1], the server node range is:[0, numOfServer-1]
|
||||
int32 rank_id = 4;
|
||||
// User-defined commands
|
||||
int32 user_cmd = 5;
|
||||
}
|
||||
|
||||
message RegisterMessage {
|
||||
|
@ -76,6 +77,10 @@ message HeartbeatRespMessage {
|
|||
bool is_node_timeout = 4;
|
||||
}
|
||||
|
||||
message FetchServersMessage {
|
||||
string node_id = 1;
|
||||
}
|
||||
|
||||
message FetchServersRespMessage {
|
||||
repeated ServersMeta servers_meta = 1;
|
||||
}
|
||||
|
@ -95,6 +100,4 @@ message FinishMessage {
|
|||
message CommMessage {
|
||||
MessageMeta pb_meta = 1;
|
||||
bytes data = 2;
|
||||
// User-defined commands
|
||||
bytes user_cmd = 3;
|
||||
}
|
||||
|
|
|
@ -38,9 +38,13 @@ bool SchedulerNode::Start(const uint32_t &timeout) {
|
|||
}
|
||||
|
||||
void SchedulerNode::ProcessHeartbeat(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn,
|
||||
std::shared_ptr<CommMessage> message) {
|
||||
std::shared_ptr<MessageMeta> meta, const void *data, size_t size) {
|
||||
MS_EXCEPTION_IF_NULL(server);
|
||||
MS_EXCEPTION_IF_NULL(conn);
|
||||
MS_EXCEPTION_IF_NULL(meta);
|
||||
MS_EXCEPTION_IF_NULL(data);
|
||||
HeartbeatMessage heartbeat_message;
|
||||
heartbeat_message.ParseFromString(message->data());
|
||||
heartbeat_message.ParseFromArray(data, size);
|
||||
|
||||
node_manager_.UpdateHeartbeat(heartbeat_message.node_id());
|
||||
|
||||
|
@ -60,10 +64,8 @@ void SchedulerNode::ProcessHeartbeat(std::shared_ptr<TcpServer> server, std::sha
|
|||
heartbeat_resp_message.set_is_cluster_timeout(node_manager_.is_cluster_timeout());
|
||||
heartbeat_resp_message.set_is_node_timeout(node_manager_.is_node_timeout());
|
||||
|
||||
std::shared_ptr<CommMessage> comm_message = std::make_shared<CommMessage>();
|
||||
*comm_message->mutable_pb_meta() = {message->pb_meta()};
|
||||
comm_message->set_data(heartbeat_resp_message.SerializeAsString());
|
||||
server->SendMessage(conn, comm_message);
|
||||
server->SendMessage(conn, meta, Protos::PROTOBUF, heartbeat_resp_message.SerializeAsString().data(),
|
||||
heartbeat_resp_message.ByteSizeLong());
|
||||
}
|
||||
|
||||
void SchedulerNode::Initialize() {
|
||||
|
@ -89,12 +91,13 @@ void SchedulerNode::CreateTcpServer() {
|
|||
std::string scheduler_host = ClusterConfig::scheduler_host();
|
||||
uint32_t scheduler_port = ClusterConfig::scheduler_port();
|
||||
server_ = std::make_shared<TcpServer>(scheduler_host, scheduler_port);
|
||||
server_->SetMessageCallback([&](std::shared_ptr<TcpConnection> conn, std::shared_ptr<CommMessage> message) {
|
||||
if (handlers_.count(message->pb_meta().cmd()) == 0) {
|
||||
MS_LOG(EXCEPTION) << "The cmd:" << message->pb_meta().cmd() << " is not supported!";
|
||||
server_->SetMessageCallback([&](std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta,
|
||||
const Protos &protos, const void *data, size_t size) {
|
||||
if (handlers_.count(meta->cmd()) == 0) {
|
||||
MS_LOG(EXCEPTION) << "The cmd:" << meta->cmd() << " is not supported!";
|
||||
}
|
||||
const auto &handler_ptr = handlers_[message->pb_meta().cmd()];
|
||||
(this->*handler_ptr)(server_, conn, message);
|
||||
const auto &handler_ptr = handlers_[meta->cmd()];
|
||||
(this->*handler_ptr)(server_, conn, meta, data, size);
|
||||
});
|
||||
|
||||
server_->Init();
|
||||
|
@ -106,10 +109,14 @@ void SchedulerNode::CreateTcpServer() {
|
|||
}
|
||||
|
||||
void SchedulerNode::ProcessRegister(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn,
|
||||
std::shared_ptr<CommMessage> message) {
|
||||
std::shared_ptr<MessageMeta> meta, const void *data, size_t size) {
|
||||
MS_EXCEPTION_IF_NULL(server);
|
||||
MS_EXCEPTION_IF_NULL(conn);
|
||||
MS_EXCEPTION_IF_NULL(meta);
|
||||
MS_EXCEPTION_IF_NULL(data);
|
||||
MS_LOG(INFO) << "The scheduler process a register message!";
|
||||
RegisterMessage register_message;
|
||||
register_message.ParseFromString(message->data());
|
||||
register_message.ParseFromArray(data, size);
|
||||
|
||||
// assign worker node and server node rank id
|
||||
int rank_id = node_manager_.NextRankId(register_message);
|
||||
|
@ -123,32 +130,32 @@ void SchedulerNode::ProcessRegister(std::shared_ptr<TcpServer> server, std::shar
|
|||
register_resp_message.set_node_id(node_id);
|
||||
register_resp_message.set_rank_id(rank_id);
|
||||
|
||||
std::shared_ptr<CommMessage> comm_message = std::make_shared<CommMessage>();
|
||||
*comm_message->mutable_pb_meta() = {message->pb_meta()};
|
||||
comm_message->set_data(register_resp_message.SerializeAsString());
|
||||
server->SendMessage(conn, comm_message);
|
||||
server->SendMessage(conn, meta, Protos::PROTOBUF, register_resp_message.SerializeAsString().data(),
|
||||
register_resp_message.ByteSizeLong());
|
||||
}
|
||||
|
||||
void SchedulerNode::ProcessFinish(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn,
|
||||
std::shared_ptr<CommMessage> message) {
|
||||
std::shared_ptr<MessageMeta> meta, const void *data, size_t size) {
|
||||
MS_EXCEPTION_IF_NULL(server);
|
||||
MS_EXCEPTION_IF_NULL(conn);
|
||||
MS_EXCEPTION_IF_NULL(meta);
|
||||
MS_EXCEPTION_IF_NULL(data);
|
||||
FinishMessage finish_message;
|
||||
finish_message.ParseFromString(message->data());
|
||||
finish_message.ParseFromArray(data, size);
|
||||
node_manager_.AddFinishNode(finish_message);
|
||||
MS_LOG(INFO) << "Process finish message from node id:" << finish_message.node_id();
|
||||
server->SendMessage(conn, message);
|
||||
server->SendMessage(conn, meta, Protos::PROTOBUF, data, size);
|
||||
}
|
||||
|
||||
void SchedulerNode::ProcessFetchServers(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn,
|
||||
std::shared_ptr<CommMessage> message) {
|
||||
std::shared_ptr<MessageMeta> meta, const void *data, size_t size) {
|
||||
FetchServersRespMessage fetch_servers_message;
|
||||
std::vector<ServersMeta> servers_meta_list = node_manager_.FetchServersMeta();
|
||||
|
||||
*fetch_servers_message.mutable_servers_meta() = {servers_meta_list.begin(), servers_meta_list.end()};
|
||||
|
||||
std::shared_ptr<CommMessage> comm_message = std::make_shared<CommMessage>();
|
||||
*comm_message->mutable_pb_meta() = {message->pb_meta()};
|
||||
comm_message->set_data(fetch_servers_message.SerializeAsString());
|
||||
server->SendMessage(conn, comm_message);
|
||||
server->SendMessage(conn, meta, Protos::PROTOBUF, fetch_servers_message.SerializeAsString().data(),
|
||||
fetch_servers_message.ByteSizeLong());
|
||||
}
|
||||
|
||||
void SchedulerNode::StartUpdateClusterStateTimer() {
|
||||
|
|
|
@ -36,13 +36,14 @@
|
|||
namespace mindspore {
|
||||
namespace ps {
|
||||
namespace core {
|
||||
|
||||
class SchedulerNode : public Node {
|
||||
public:
|
||||
SchedulerNode() : server_(nullptr), scheduler_thread_(nullptr), update_state_thread_(nullptr) {}
|
||||
~SchedulerNode() override;
|
||||
|
||||
typedef void (SchedulerNode::*ResponseHandler)(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn,
|
||||
std::shared_ptr<CommMessage> message);
|
||||
std::shared_ptr<MessageMeta> meta, const void *data, size_t size);
|
||||
|
||||
bool Start(const uint32_t &timeout = ClusterConfig::cluster_available_timeout()) override;
|
||||
bool Stop() override;
|
||||
|
@ -53,14 +54,14 @@ class SchedulerNode : public Node {
|
|||
void InitCommandHandler();
|
||||
void CreateTcpServer();
|
||||
void ProcessHeartbeat(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn,
|
||||
std::shared_ptr<CommMessage> message);
|
||||
std::shared_ptr<MessageMeta> meta, const void *data, size_t size);
|
||||
void ProcessRegister(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn,
|
||||
std::shared_ptr<CommMessage> message);
|
||||
std::shared_ptr<MessageMeta> meta, const void *data, size_t size);
|
||||
void StartUpdateClusterStateTimer();
|
||||
void ProcessFinish(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn,
|
||||
std::shared_ptr<CommMessage> message);
|
||||
std::shared_ptr<MessageMeta> meta, const void *data, size_t size);
|
||||
void ProcessFetchServers(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn,
|
||||
std::shared_ptr<CommMessage> message);
|
||||
std::shared_ptr<MessageMeta> meta, const void *data, size_t size);
|
||||
|
||||
std::shared_ptr<TcpServer> server_;
|
||||
std::unique_ptr<std::thread> scheduler_thread_;
|
||||
|
|
|
@ -46,16 +46,16 @@ bool ServerNode::Start(const uint32_t &timeout) {
|
|||
|
||||
void ServerNode::set_handler(const RequestHandler &handler) { request_handler_ = handler; }
|
||||
|
||||
void ServerNode::Response(std::shared_ptr<TcpConnection> conn, std::shared_ptr<CommMessage> message) {
|
||||
void ServerNode::Response(std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta, DataPtr data,
|
||||
size_t size) {
|
||||
MS_EXCEPTION_IF_NULL(conn);
|
||||
MS_EXCEPTION_IF_NULL(message);
|
||||
message->mutable_pb_meta()->set_role(node_info_.node_role_);
|
||||
message->mutable_pb_meta()->set_rank_id(node_info_.rank_id_);
|
||||
const MessageMeta &message_meta = message->pb_meta();
|
||||
const uint64_t request_id = message_meta.request_id();
|
||||
MS_EXCEPTION_IF_NULL(meta);
|
||||
MS_EXCEPTION_IF_NULL(data);
|
||||
meta->set_role(node_info_.node_role_);
|
||||
meta->set_rank_id(node_info_.rank_id_);
|
||||
MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_)
|
||||
<< ", the node id is:" << node_info_.node_id_ << " send the request id is:" << request_id;
|
||||
server_->SendMessage(conn, message);
|
||||
<< ", the node id is:" << node_info_.node_id_ << " send the request id is:" << meta->request_id();
|
||||
server_->SendMessage(conn, meta, Protos::RAW, data.get(), size);
|
||||
}
|
||||
|
||||
void ServerNode::CreateTcpServer() {
|
||||
|
@ -63,17 +63,18 @@ void ServerNode::CreateTcpServer() {
|
|||
std::string server_ip;
|
||||
CommUtil::GetAvailableInterfaceAndIP(&interface, &server_ip);
|
||||
server_ = std::make_shared<TcpServer>(server_ip, 0);
|
||||
server_->SetMessageCallback([&](std::shared_ptr<TcpConnection> conn, std::shared_ptr<CommMessage> message) {
|
||||
switch (message->pb_meta().cmd()) {
|
||||
server_->SetMessageCallback([&](std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta,
|
||||
const Protos &protos, const void *data, size_t size) {
|
||||
switch (meta->cmd()) {
|
||||
case NodeCommand::SEND_DATA:
|
||||
ProcessSendData(conn, message);
|
||||
ProcessSendData(conn, meta, protos, data, size);
|
||||
break;
|
||||
case NodeCommand::COLLECTIVE_SEND_DATA:
|
||||
ProcessCollectiveSendData(conn, message);
|
||||
RunReceiveCallback(*message);
|
||||
ProcessCollectiveSendData(conn, meta, data, size);
|
||||
RunReceiveCallback(meta, protos, data, size);
|
||||
break;
|
||||
default:
|
||||
MS_LOG(EXCEPTION) << "The cmd:" << message->pb_meta().cmd() << " is not supported!";
|
||||
MS_LOG(EXCEPTION) << "The cmd:" << meta->cmd() << " is not supported!";
|
||||
}
|
||||
});
|
||||
server_->Init();
|
||||
|
@ -99,18 +100,24 @@ void ServerNode::Initialize() {
|
|||
MS_LOG(INFO) << "Server node init client successful!";
|
||||
}
|
||||
|
||||
void ServerNode::ProcessSendData(std::shared_ptr<TcpConnection> conn, std::shared_ptr<CommMessage> message) {
|
||||
void ServerNode::ProcessSendData(std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta,
|
||||
const Protos &protos, const void *data, size_t size) {
|
||||
MS_EXCEPTION_IF_NULL(conn);
|
||||
MS_EXCEPTION_IF_NULL(message);
|
||||
request_handler_(conn, message);
|
||||
MS_EXCEPTION_IF_NULL(meta);
|
||||
MS_EXCEPTION_IF_NULL(data);
|
||||
std::shared_ptr<unsigned char> res(new unsigned char[size]);
|
||||
int ret = memcpy_s(res.get(), size, data, size);
|
||||
if (ret != 0) {
|
||||
MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
|
||||
}
|
||||
request_handler_(conn, meta, res, size);
|
||||
}
|
||||
|
||||
void ServerNode::ProcessCollectiveSendData(std::shared_ptr<TcpConnection> conn, std::shared_ptr<CommMessage> message) {
|
||||
void ServerNode::ProcessCollectiveSendData(std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta,
|
||||
const void *data, size_t size) {
|
||||
MS_EXCEPTION_IF_NULL(conn);
|
||||
MS_EXCEPTION_IF_NULL(message);
|
||||
std::shared_ptr<CommMessage> comm_message = std::make_shared<CommMessage>();
|
||||
*comm_message->mutable_pb_meta() = {message->pb_meta()};
|
||||
server_->SendMessage(conn, comm_message);
|
||||
MS_EXCEPTION_IF_NULL(meta);
|
||||
server_->SendMessage(conn, meta, Protos::RAW, data, size);
|
||||
}
|
||||
|
||||
bool ServerNode::Stop() {
|
||||
|
|
|
@ -23,6 +23,7 @@
|
|||
#include <string>
|
||||
#include <thread>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "ps/core/cluster_config.h"
|
||||
#include "ps/core/tcp_client.h"
|
||||
|
@ -41,16 +42,19 @@ class ServerNode : public AbstractNode {
|
|||
bool Stop() override;
|
||||
bool Finish(const uint32_t &timeout = kTimeoutInSeconds) override;
|
||||
|
||||
using RequestHandler = std::function<void(std::shared_ptr<TcpConnection> conn, std::shared_ptr<CommMessage> message)>;
|
||||
using RequestHandler = std::function<void(std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta,
|
||||
DataPtr data, size_t size)>;
|
||||
|
||||
void set_handler(const RequestHandler &handler);
|
||||
void Response(std::shared_ptr<TcpConnection> conn, std::shared_ptr<CommMessage> message);
|
||||
void Response(std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta, DataPtr data, size_t size);
|
||||
|
||||
private:
|
||||
void CreateTcpServer();
|
||||
void Initialize();
|
||||
void ProcessSendData(std::shared_ptr<TcpConnection> conn, std::shared_ptr<CommMessage> message);
|
||||
void ProcessCollectiveSendData(std::shared_ptr<TcpConnection> conn, std::shared_ptr<CommMessage> message);
|
||||
void ProcessSendData(std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta, const Protos &protos,
|
||||
const void *data, size_t size);
|
||||
void ProcessCollectiveSendData(std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta,
|
||||
const void *data, size_t size);
|
||||
|
||||
std::shared_ptr<TcpServer> server_;
|
||||
std::unique_ptr<std::thread> server_thread_;
|
||||
|
|
|
@ -46,9 +46,10 @@ TcpClient::TcpClient(const std::string &address, std::uint16_t port)
|
|||
server_port_(port),
|
||||
is_stop_(true),
|
||||
is_connected_(false) {
|
||||
message_handler_.SetCallback([this](std::shared_ptr<CommMessage> message) {
|
||||
message_handler_.SetCallback(
|
||||
[this](std::shared_ptr<MessageMeta> meta, const Protos &protos, const void *data, size_t size) {
|
||||
if (message_callback_) {
|
||||
message_callback_(*this, *message);
|
||||
message_callback_(meta, protos, data, size);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
@ -189,7 +190,7 @@ void TcpClient::ReadCallback(struct bufferevent *bev, void *ctx) {
|
|||
void TcpClient::OnReadHandler(const void *buf, size_t num) {
|
||||
MS_EXCEPTION_IF_NULL(buf);
|
||||
if (read_callback_) {
|
||||
read_callback_(*this, buf, num);
|
||||
read_callback_(buf, num);
|
||||
}
|
||||
message_handler_.ReceiveMessage(buf, num);
|
||||
}
|
||||
|
@ -198,7 +199,7 @@ void TcpClient::TimerCallback(evutil_socket_t, int16_t, void *arg) {
|
|||
MS_EXCEPTION_IF_NULL(arg);
|
||||
auto tcp_client = reinterpret_cast<TcpClient *>(arg);
|
||||
if (tcp_client->on_timer_callback_) {
|
||||
tcp_client->on_timer_callback_(*tcp_client);
|
||||
tcp_client->on_timer_callback_();
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -245,7 +246,7 @@ void TcpClient::Start() {
|
|||
MSLOG_IF(mindspore::ERROR, ret == 1, NoExceptionType)
|
||||
<< "Event base dispatch failed with no events pending or active!";
|
||||
MSLOG_IF(mindspore::ERROR, ret == -1, NoExceptionType) << "Event base dispatch failed with error occurred!";
|
||||
MSLOG_IF(mindspore::EXCEPTION, ret < -1, AbortedError) << "Event base dispatch with unexpect error code!";
|
||||
MSLOG_IF(mindspore::EXCEPTION, ret < -1, AbortedError) << "Event base dispatch with unexpected error code!";
|
||||
}
|
||||
|
||||
void TcpClient::StartWithNoBlock() {
|
||||
|
@ -256,7 +257,7 @@ void TcpClient::StartWithNoBlock() {
|
|||
MSLOG_IF(INFO, ret == 0, NoExceptionType) << "Event base loop success!";
|
||||
MSLOG_IF(mindspore::ERROR, ret == 1, NoExceptionType) << "Event base loop failed with no events pending or active!";
|
||||
MSLOG_IF(mindspore::ERROR, ret == -1, NoExceptionType) << "Event base loop failed with error occurred!";
|
||||
MSLOG_IF(mindspore::EXCEPTION, ret < -1, AbortedError) << "Event base loop with unexpect error code!";
|
||||
MSLOG_IF(mindspore::EXCEPTION, ret < -1, AbortedError) << "Event base loop with unexpected error code!";
|
||||
}
|
||||
|
||||
void TcpClient::SetMessageCallback(const OnMessage &cb) { message_callback_ = cb; }
|
||||
|
@ -265,14 +266,49 @@ bool TcpClient::SendMessage(const CommMessage &message) const {
|
|||
MS_EXCEPTION_IF_NULL(buffer_event_);
|
||||
bufferevent_lock(buffer_event_);
|
||||
bool res = true;
|
||||
size_t buf_size = message.ByteSizeLong();
|
||||
std::vector<unsigned char> serialized(buf_size);
|
||||
message.SerializeToArray(serialized.data(), SizeToInt(buf_size));
|
||||
if (bufferevent_write(buffer_event_, &buf_size, sizeof(buf_size)) == -1) {
|
||||
size_t buf_size = IntToUint(message.ByteSizeLong());
|
||||
uint32_t meta_size = SizeToUint(message.pb_meta().ByteSizeLong());
|
||||
Messageheader header;
|
||||
header.message_proto_ = Protos::PROTOBUF;
|
||||
header.message_length_ = buf_size;
|
||||
header.message_meta_length_ = meta_size;
|
||||
if (bufferevent_write(buffer_event_, &header, sizeof(header)) == -1) {
|
||||
MS_LOG(ERROR) << "Event buffer add header failed!";
|
||||
res = false;
|
||||
}
|
||||
if (bufferevent_write(buffer_event_, serialized.data(), buf_size) == -1) {
|
||||
if (bufferevent_write(buffer_event_, message.pb_meta().SerializeAsString().data(), meta_size) == -1) {
|
||||
MS_LOG(ERROR) << "Event buffer add protobuf data failed!";
|
||||
res = false;
|
||||
}
|
||||
if (bufferevent_write(buffer_event_, message.data().data(), message.data().length()) == -1) {
|
||||
MS_LOG(ERROR) << "Event buffer add protobuf data failed!";
|
||||
res = false;
|
||||
}
|
||||
bufferevent_unlock(buffer_event_);
|
||||
return res;
|
||||
}
|
||||
|
||||
bool TcpClient::SendMessage(std::shared_ptr<MessageMeta> meta, const Protos &protos, const void *data, size_t size) {
|
||||
MS_EXCEPTION_IF_NULL(buffer_event_);
|
||||
MS_EXCEPTION_IF_NULL(meta);
|
||||
MS_EXCEPTION_IF_NULL(data);
|
||||
bufferevent_lock(buffer_event_);
|
||||
bool res = true;
|
||||
|
||||
Messageheader header;
|
||||
header.message_proto_ = protos;
|
||||
header.message_meta_length_ = SizeToUint(meta->ByteSizeLong());
|
||||
header.message_length_ = size + header.message_meta_length_;
|
||||
|
||||
if (bufferevent_write(buffer_event_, &header, sizeof(header)) == -1) {
|
||||
MS_LOG(ERROR) << "Event buffer add header failed!";
|
||||
res = false;
|
||||
}
|
||||
if (bufferevent_write(buffer_event_, meta->SerializeAsString().data(), meta->ByteSizeLong()) == -1) {
|
||||
MS_LOG(ERROR) << "Event buffer add protobuf data failed!";
|
||||
res = false;
|
||||
}
|
||||
if (bufferevent_write(buffer_event_, data, size) == -1) {
|
||||
MS_LOG(ERROR) << "Event buffer add protobuf data failed!";
|
||||
res = false;
|
||||
}
|
||||
|
|
|
@ -42,10 +42,10 @@ class TcpClient {
|
|||
public:
|
||||
using OnConnected = std::function<void()>;
|
||||
using OnDisconnected = std::function<void()>;
|
||||
using OnRead = std::function<void(const TcpClient &, const void *, size_t)>;
|
||||
using OnTimeout = std::function<void(const TcpClient &)>;
|
||||
using OnMessage = std::function<void(const TcpClient &, const CommMessage &)>;
|
||||
using OnTimer = std::function<void(const TcpClient &)>;
|
||||
using OnRead = std::function<void(const void *, size_t)>;
|
||||
using OnTimeout = std::function<void()>;
|
||||
using OnMessage = std::function<void(std::shared_ptr<MessageMeta>, const Protos &, const void *, size_t size)>;
|
||||
using OnTimer = std::function<void()>;
|
||||
|
||||
explicit TcpClient(const std::string &address, std::uint16_t port);
|
||||
virtual ~TcpClient();
|
||||
|
@ -61,6 +61,7 @@ class TcpClient {
|
|||
void StartWithNoBlock();
|
||||
void SetMessageCallback(const OnMessage &cb);
|
||||
bool SendMessage(const CommMessage &message) const;
|
||||
bool SendMessage(std::shared_ptr<MessageMeta> meta, const Protos &protos, const void *data, size_t size);
|
||||
void StartTimer(const uint32_t &time);
|
||||
void set_timer_callback(const OnTimer &timer);
|
||||
const event_base &eventbase();
|
||||
|
|
|
@ -35,8 +35,12 @@ void TcpMessageHandler::ReceiveMessage(const void *buffer, size_t num) {
|
|||
header_[++header_index_] = *(buffer_data + i);
|
||||
--num;
|
||||
if (header_index_ == kHeaderLen - 1) {
|
||||
message_length_ = *reinterpret_cast<const size_t *>(header_);
|
||||
remaining_length_ = message_length_;
|
||||
message_header_.message_proto_ = *reinterpret_cast<const Protos *>(header_);
|
||||
message_header_.message_meta_length_ =
|
||||
*reinterpret_cast<const uint32_t *>(header_ + sizeof(message_header_.message_proto_));
|
||||
message_header_.message_length_ = *reinterpret_cast<const size_t *>(
|
||||
header_ + sizeof(message_header_.message_proto_) + sizeof(message_header_.message_meta_length_));
|
||||
remaining_length_ = message_header_.message_length_;
|
||||
message_buffer_.reset(new unsigned char[remaining_length_]);
|
||||
buffer_data += (i + 1);
|
||||
break;
|
||||
|
@ -57,10 +61,12 @@ void TcpMessageHandler::ReceiveMessage(const void *buffer, size_t num) {
|
|||
}
|
||||
|
||||
if (remaining_length_ == 0) {
|
||||
std::shared_ptr<CommMessage> pb_message = std::make_shared<CommMessage>();
|
||||
pb_message->ParseFromArray(message_buffer_.get(), message_length_);
|
||||
if (message_callback_) {
|
||||
message_callback_(pb_message);
|
||||
std::shared_ptr<MessageMeta> pb_message = std::make_shared<MessageMeta>();
|
||||
pb_message->ParseFromArray(message_buffer_.get(), message_header_.message_meta_length_);
|
||||
message_callback_(pb_message, message_header_.message_proto_,
|
||||
message_buffer_.get() + message_header_.message_meta_length_,
|
||||
message_header_.message_length_ - message_header_.message_meta_length_);
|
||||
}
|
||||
message_buffer_.reset();
|
||||
message_buffer_ = nullptr;
|
||||
|
|
|
@ -24,24 +24,20 @@
|
|||
#include <vector>
|
||||
|
||||
#include "utils/log_adapter.h"
|
||||
#include "ps/core/message.h"
|
||||
#include "proto/comm.pb.h"
|
||||
#include "proto/ps.pb.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
namespace core {
|
||||
using messageReceive = std::function<void(std::shared_ptr<CommMessage>)>;
|
||||
constexpr int kHeaderLen = 8;
|
||||
using messageReceive = std::function<void(std::shared_ptr<MessageMeta>, const Protos &, const void *, size_t size)>;
|
||||
constexpr int kHeaderLen = 16;
|
||||
|
||||
class TcpMessageHandler {
|
||||
public:
|
||||
TcpMessageHandler()
|
||||
: is_parsed_(false),
|
||||
message_buffer_(nullptr),
|
||||
message_length_(0),
|
||||
remaining_length_(0),
|
||||
header_index_(-1),
|
||||
last_copy_len_(0) {}
|
||||
: is_parsed_(false), message_buffer_(nullptr), remaining_length_(0), header_index_(-1), last_copy_len_(0) {}
|
||||
virtual ~TcpMessageHandler() = default;
|
||||
|
||||
void SetCallback(const messageReceive &cb);
|
||||
|
@ -51,11 +47,12 @@ class TcpMessageHandler {
|
|||
messageReceive message_callback_;
|
||||
bool is_parsed_;
|
||||
std::unique_ptr<unsigned char> message_buffer_;
|
||||
size_t message_length_;
|
||||
size_t remaining_length_;
|
||||
char header_[8];
|
||||
char header_[16];
|
||||
int header_index_;
|
||||
size_t last_copy_len_;
|
||||
MessageHeader message_header_;
|
||||
std::string mBuffer;
|
||||
};
|
||||
} // namespace core
|
||||
} // namespace ps
|
||||
|
|
|
@ -54,13 +54,39 @@ bool TcpConnection::SendMessage(std::shared_ptr<CommMessage> message) const {
|
|||
bufferevent_lock(buffer_event_);
|
||||
bool res = true;
|
||||
size_t buf_size = message->ByteSizeLong();
|
||||
std::vector<unsigned char> serialized(buf_size);
|
||||
message->SerializeToArray(serialized.data(), SizeToInt(buf_size));
|
||||
if (bufferevent_write(buffer_event_, &buf_size, sizeof(buf_size)) == -1) {
|
||||
MS_LOG(ERROR) << "Event buffer add header failed!";
|
||||
res = false;
|
||||
}
|
||||
if (bufferevent_write(buffer_event_, serialized.data(), buf_size) == -1) {
|
||||
if (bufferevent_write(buffer_event_, message->SerializeAsString().data(), buf_size) == -1) {
|
||||
MS_LOG(ERROR) << "Event buffer add protobuf data failed!";
|
||||
res = false;
|
||||
}
|
||||
bufferevent_unlock(buffer_event_);
|
||||
return res;
|
||||
}
|
||||
|
||||
bool TcpConnection::SendMessage(std::shared_ptr<MessageMeta> meta, const Protos &protos, const void *data,
|
||||
size_t size) const {
|
||||
MS_EXCEPTION_IF_NULL(buffer_event_);
|
||||
MS_EXCEPTION_IF_NULL(meta);
|
||||
MS_EXCEPTION_IF_NULL(data);
|
||||
bufferevent_lock(buffer_event_);
|
||||
bool res = true;
|
||||
Messageheader header;
|
||||
header.message_proto_ = protos;
|
||||
header.message_meta_length_ = SizeToUint(meta->ByteSizeLong());
|
||||
header.message_length_ = size + header.message_meta_length_;
|
||||
|
||||
if (bufferevent_write(buffer_event_, &header, sizeof(header)) == -1) {
|
||||
MS_LOG(ERROR) << "Event buffer add header failed!";
|
||||
res = false;
|
||||
}
|
||||
if (bufferevent_write(buffer_event_, meta->SerializeAsString().data(), meta->ByteSizeLong()) == -1) {
|
||||
MS_LOG(ERROR) << "Event buffer add protobuf data failed!";
|
||||
res = false;
|
||||
}
|
||||
if (bufferevent_write(buffer_event_, data, size) == -1) {
|
||||
MS_LOG(ERROR) << "Event buffer add protobuf data failed!";
|
||||
res = false;
|
||||
}
|
||||
|
@ -158,7 +184,7 @@ void TcpServer::Start() {
|
|||
MSLOG_IF(mindspore::ERROR, ret == 1, NoExceptionType)
|
||||
<< "Event base dispatch failed with no events pending or active!";
|
||||
MSLOG_IF(mindspore::ERROR, ret == -1, NoExceptionType) << "Event base dispatch failed with error occurred!";
|
||||
MSLOG_IF(mindspore::EXCEPTION, ret < -1, AbortedError) << "Event base dispatch with unexpect error code!";
|
||||
MSLOG_IF(mindspore::EXCEPTION, ret < -1, AbortedError) << "Event base dispatch with unexpected error code!";
|
||||
}
|
||||
|
||||
void TcpServer::StartWithNoBlock() {
|
||||
|
@ -169,7 +195,7 @@ void TcpServer::StartWithNoBlock() {
|
|||
MSLOG_IF(INFO, ret == 0, NoExceptionType) << "Event base loop success!";
|
||||
MSLOG_IF(mindspore::ERROR, ret == 1, NoExceptionType) << "Event base loop failed with no events pending or active!";
|
||||
MSLOG_IF(mindspore::ERROR, ret == -1, NoExceptionType) << "Event base loop failed with error occurred!";
|
||||
MSLOG_IF(mindspore::EXCEPTION, ret < -1, AbortedError) << "Event base loop with unexpect error code!";
|
||||
MSLOG_IF(mindspore::EXCEPTION, ret < -1, AbortedError) << "Event base loop with unexpected error code!";
|
||||
}
|
||||
|
||||
void TcpServer::StartTimerOnlyOnce(const uint32_t &time) {
|
||||
|
@ -260,10 +286,10 @@ void TcpServer::ListenerCallback(struct evconnlistener *, evutil_socket_t fd, st
|
|||
MS_EXCEPTION_IF_NULL(conn);
|
||||
|
||||
server->AddConnection(fd, conn);
|
||||
conn->InitConnection([=](std::shared_ptr<CommMessage> message) {
|
||||
conn->InitConnection([=](std::shared_ptr<MessageMeta> meta, const Protos &protos, const void *data, size_t size) {
|
||||
OnServerReceiveMessage on_server_receive = server->GetServerReceive();
|
||||
if (on_server_receive) {
|
||||
on_server_receive(conn, message);
|
||||
on_server_receive(conn, meta, protos, data, size);
|
||||
}
|
||||
});
|
||||
bufferevent_setcb(bev, TcpServer::ReadCallback, nullptr, TcpServer::EventCallback,
|
||||
|
@ -274,6 +300,7 @@ void TcpServer::ListenerCallback(struct evconnlistener *, evutil_socket_t fd, st
|
|||
}
|
||||
|
||||
std::shared_ptr<TcpConnection> TcpServer::onCreateConnection(struct bufferevent *bev, const evutil_socket_t &fd) {
|
||||
MS_EXCEPTION_IF_NULL(bev);
|
||||
std::shared_ptr<TcpConnection> conn = nullptr;
|
||||
if (client_accept_) {
|
||||
conn = (client_accept_(*this));
|
||||
|
@ -367,9 +394,17 @@ bool TcpServer::SendMessage(std::shared_ptr<TcpConnection> conn, std::shared_ptr
|
|||
return conn->SendMessage(message);
|
||||
}
|
||||
|
||||
bool TcpServer::SendMessage(std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta,
|
||||
const Protos &protos, const void *data, size_t size) {
|
||||
MS_EXCEPTION_IF_NULL(conn);
|
||||
MS_EXCEPTION_IF_NULL(meta);
|
||||
MS_EXCEPTION_IF_NULL(data);
|
||||
return conn->SendMessage(meta, protos, data, size);
|
||||
}
|
||||
|
||||
void TcpServer::SendMessage(std::shared_ptr<CommMessage> message) {
|
||||
std::lock_guard<std::mutex> lock(connection_mutex_);
|
||||
MS_EXCEPTION_IF_NULL(message);
|
||||
std::lock_guard<std::mutex> lock(connection_mutex_);
|
||||
|
||||
for (auto it = connections_.begin(); it != connections_.end(); ++it) {
|
||||
SendMessage(it->second, message);
|
||||
|
|
|
@ -36,7 +36,6 @@
|
|||
|
||||
#include "ps/core/tcp_message_handler.h"
|
||||
#include "ps/core/cluster_config.h"
|
||||
#include "utils/log_adapter.h"
|
||||
#include "utils/convert_utils_base.h"
|
||||
|
||||
namespace mindspore {
|
||||
|
@ -55,6 +54,7 @@ class TcpConnection {
|
|||
virtual void InitConnection(const messageReceive &callback);
|
||||
virtual void SendMessage(const void *buffer, size_t num) const;
|
||||
bool SendMessage(std::shared_ptr<CommMessage> message) const;
|
||||
bool SendMessage(std::shared_ptr<MessageMeta> meta, const Protos &protos, const void *data, size_t size) const;
|
||||
virtual void OnReadHandler(const void *buffer, size_t numBytes);
|
||||
TcpServer *GetServer() const;
|
||||
const evutil_socket_t &GetFd() const;
|
||||
|
@ -69,7 +69,8 @@ class TcpConnection {
|
|||
};
|
||||
|
||||
using OnServerReceiveMessage =
|
||||
std::function<void(std::shared_ptr<TcpConnection> conn, std::shared_ptr<CommMessage> message)>;
|
||||
std::function<void(std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta, const Protos &protos,
|
||||
const void *data, size_t size)>;
|
||||
|
||||
class TcpServer {
|
||||
public:
|
||||
|
@ -100,6 +101,8 @@ class TcpServer {
|
|||
OnServerReceiveMessage GetServerReceive() const;
|
||||
void SetMessageCallback(const OnServerReceiveMessage &cb);
|
||||
bool SendMessage(std::shared_ptr<TcpConnection> conn, std::shared_ptr<CommMessage> message);
|
||||
bool SendMessage(std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta, const Protos &protos,
|
||||
const void *data, size_t sizee);
|
||||
void SendMessage(std::shared_ptr<CommMessage> message);
|
||||
uint16_t BoundPort() const;
|
||||
std::string BoundIp() const;
|
||||
|
|
|
@ -30,7 +30,12 @@ class TestTcpClient : public UT::Common {
|
|||
TEST_F(TestTcpClient, InitClientIPError) {
|
||||
auto client = std::make_unique<TcpClient>("127.0.0.13543", 9000);
|
||||
|
||||
client->SetMessageCallback([](const TcpClient &client, const CommMessage &message) { client.SendMessage(message); });
|
||||
client->SetMessageCallback([&](std::shared_ptr<MessageMeta>, const Protos &, const void *data, size_t size) {
|
||||
CommMessage message;
|
||||
message.ParseFromArray(data, size);
|
||||
|
||||
client->SendMessage(message);
|
||||
});
|
||||
|
||||
ASSERT_THROW(client->Init(), std::exception);
|
||||
}
|
||||
|
@ -38,10 +43,15 @@ TEST_F(TestTcpClient, InitClientIPError) {
|
|||
TEST_F(TestTcpClient, InitClientPortErrorNoException) {
|
||||
auto client = std::make_unique<TcpClient>("127.0.0.1", -1);
|
||||
|
||||
client->SetMessageCallback([](const TcpClient &client, const CommMessage &message) { client.SendMessage(message); });
|
||||
client->SetMessageCallback([&](std::shared_ptr<MessageMeta>, const Protos &, const void *data, size_t size) {
|
||||
CommMessage message;
|
||||
message.ParseFromArray(data, size);
|
||||
client->SendMessage(message);
|
||||
});
|
||||
|
||||
EXPECT_NO_THROW(client->Init());
|
||||
}
|
||||
|
||||
} // namespace core
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
|
@ -33,131 +33,145 @@ class TestTcpMessageHandler : public UT::Common {
|
|||
void TearDown() override {}
|
||||
};
|
||||
|
||||
TEST_F(TestTcpMessageHandler, 8_Header_1003_Data) {
|
||||
TEST_F(TestTcpMessageHandler, 16Header_2meta_1000Data) {
|
||||
TcpMessageHandler handler;
|
||||
handler.SetCallback([this](std::shared_ptr<CommMessage> message) { EXPECT_EQ(message->data().size(), 1000); });
|
||||
handler.SetCallback([this](std::shared_ptr<MessageMeta> meta, const Protos &, const void *data, size_t size) {
|
||||
EXPECT_EQ(meta->ByteSizeLong(), 2);
|
||||
EXPECT_EQ(size, 1000);
|
||||
});
|
||||
|
||||
std::string data(1000, 'a');
|
||||
CommMessage message;
|
||||
message.set_data(data);
|
||||
size_t buf_size = message.ByteSizeLong();
|
||||
char result[1011];
|
||||
int ret = memcpy_s(result, kHeaderLen, &buf_size, kHeaderLen);
|
||||
|
||||
char result[1018];
|
||||
|
||||
MessageMeta meta;
|
||||
meta.set_request_id(1);
|
||||
EXPECT_EQ(meta.ByteSizeLong(), 2);
|
||||
|
||||
MessageHeader header;
|
||||
header.message_proto_ = Protos::RAW;
|
||||
header.message_meta_length_ = meta.ByteSizeLong();
|
||||
header.message_length_ = data.length() + meta.ByteSizeLong();
|
||||
int ret = memcpy_s(result, kHeaderLen, &header, kHeaderLen);
|
||||
if (ret != 0) {
|
||||
MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
|
||||
}
|
||||
|
||||
std::vector<char> serialized(buf_size);
|
||||
message.SerializeToArray(serialized.data(), static_cast<int>(buf_size));
|
||||
memcpy_s(result + kHeaderLen, buf_size, serialized.data(), buf_size);
|
||||
handler.ReceiveMessage(result, buf_size + kHeaderLen);
|
||||
memcpy_s(result + kHeaderLen, meta.ByteSizeLong(), meta.SerializeAsString().data(), meta.ByteSizeLong());
|
||||
memcpy_s(result + kHeaderLen + meta.ByteSizeLong(), data.length(), data.data(), data.length());
|
||||
|
||||
handler.ReceiveMessage(result, 1018);
|
||||
}
|
||||
|
||||
TEST_F(TestTcpMessageHandler, 8_Header_1003_Data_8_Header_1003_Data) {
|
||||
TEST_F(TestTcpMessageHandler, 16Header_2meta_1000Data_16Header_2meta_1000Data) {
|
||||
TcpMessageHandler handler;
|
||||
handler.SetCallback([this](std::shared_ptr<CommMessage> message) { EXPECT_EQ(message->data().size(), 1000); });
|
||||
handler.SetCallback([this](std::shared_ptr<MessageMeta> meta, const Protos &, const void *data, size_t size) {
|
||||
EXPECT_EQ(meta->ByteSizeLong(), 2);
|
||||
EXPECT_EQ(size, 1000);
|
||||
});
|
||||
|
||||
std::string data(1000, 'a');
|
||||
CommMessage message;
|
||||
message.set_data(data);
|
||||
size_t buf_size = message.ByteSizeLong();
|
||||
char result[2022] = {0};
|
||||
int ret = memcpy_s(result, kHeaderLen, &buf_size, kHeaderLen);
|
||||
if (ret != 0) {
|
||||
MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
|
||||
}
|
||||
std::vector<char> serialized(buf_size);
|
||||
message.SerializeToArray(serialized.data(), static_cast<int>(buf_size));
|
||||
ret = memcpy_s(result + kHeaderLen, buf_size, serialized.data(), buf_size);
|
||||
if (ret != 0) {
|
||||
MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
|
||||
}
|
||||
ret = memcpy_s(result + kHeaderLen + buf_size, kHeaderLen, &buf_size, kHeaderLen);
|
||||
if (ret != 0) {
|
||||
MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
|
||||
}
|
||||
ret = memcpy_s(result + kHeaderLen + buf_size + kHeaderLen, buf_size, serialized.data(), buf_size);
|
||||
|
||||
char result[2036];
|
||||
|
||||
MessageMeta meta;
|
||||
meta.set_request_id(1);
|
||||
EXPECT_EQ(meta.ByteSizeLong(), 2);
|
||||
|
||||
MessageHeader header;
|
||||
header.message_proto_ = Protos::RAW;
|
||||
header.message_meta_length_ = meta.ByteSizeLong();
|
||||
header.message_length_ = data.length() + meta.ByteSizeLong();
|
||||
int ret = memcpy_s(result, kHeaderLen, &header, kHeaderLen);
|
||||
if (ret != 0) {
|
||||
MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
|
||||
}
|
||||
|
||||
handler.ReceiveMessage(result, 2 * buf_size + kHeaderLen * 2);
|
||||
memcpy_s(result + kHeaderLen, meta.ByteSizeLong(), meta.SerializeAsString().data(), meta.ByteSizeLong());
|
||||
memcpy_s(result + kHeaderLen + meta.ByteSizeLong(), data.length(), data.data(), data.length());
|
||||
|
||||
memcpy_s(result + kHeaderLen + meta.ByteSizeLong() + data.length(), kHeaderLen, &header, kHeaderLen);
|
||||
memcpy_s(result + kHeaderLen * 2 + meta.ByteSizeLong() + data.length(), meta.ByteSizeLong(),
|
||||
meta.SerializeAsString().data(), meta.ByteSizeLong());
|
||||
memcpy_s(result + kHeaderLen * 2 + meta.ByteSizeLong() * 2 + data.length(), data.length(), data.data(),
|
||||
data.length());
|
||||
|
||||
handler.ReceiveMessage(result, 2036);
|
||||
}
|
||||
|
||||
TEST_F(TestTcpMessageHandler, 8_Header_4084_Data_4_Header_4_header_4084_data) {
|
||||
TEST_F(TestTcpMessageHandler, 16header_2meta_4070data_8header_8header_2meta_4070data) {
|
||||
TcpMessageHandler handler;
|
||||
handler.SetCallback([this](std::shared_ptr<CommMessage> message) { EXPECT_EQ(message->data().size(), 4081); });
|
||||
handler.SetCallback([this](std::shared_ptr<MessageMeta> meta, const Protos &, const void *data, size_t size) {
|
||||
EXPECT_EQ(meta->ByteSizeLong(), 2);
|
||||
EXPECT_EQ(size, 4070);
|
||||
});
|
||||
|
||||
std::string data(4070, 'a');
|
||||
|
||||
std::string data(4081, 'a');
|
||||
CommMessage message;
|
||||
message.set_data(data);
|
||||
size_t buf_size = message.ByteSizeLong();
|
||||
char result[4096] = {0};
|
||||
int ret = memcpy_s(result, kHeaderLen, &buf_size, kHeaderLen);
|
||||
if (ret != 0) {
|
||||
MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
|
||||
}
|
||||
std::vector<char> serialized(buf_size);
|
||||
message.SerializeToArray(serialized.data(), static_cast<int>(buf_size));
|
||||
ret = memcpy_s(result + kHeaderLen, buf_size, serialized.data(), buf_size);
|
||||
|
||||
MessageMeta meta;
|
||||
meta.set_request_id(1);
|
||||
EXPECT_EQ(meta.ByteSizeLong(), 2);
|
||||
|
||||
MessageHeader header;
|
||||
header.message_proto_ = Protos::RAW;
|
||||
header.message_meta_length_ = meta.ByteSizeLong();
|
||||
header.message_length_ = data.length() + meta.ByteSizeLong();
|
||||
int ret = memcpy_s(result, kHeaderLen, &header, kHeaderLen);
|
||||
if (ret != 0) {
|
||||
MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
|
||||
}
|
||||
|
||||
ret = memcpy_s(result + kHeaderLen + buf_size, 4, &buf_size, 4);
|
||||
if (ret != 0) {
|
||||
MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
|
||||
}
|
||||
memcpy_s(result + kHeaderLen, meta.ByteSizeLong(), meta.SerializeAsString().data(), meta.ByteSizeLong());
|
||||
memcpy_s(result + kHeaderLen + meta.ByteSizeLong(), data.length(), data.data(), data.length());
|
||||
|
||||
memcpy_s(result + kHeaderLen + meta.ByteSizeLong() + data.length(), 8, &header, 8);
|
||||
handler.ReceiveMessage(result, 4096);
|
||||
|
||||
auto temp = reinterpret_cast<char *>(&buf_size);
|
||||
ret = memcpy_s(result, 4, temp + 4, 4);
|
||||
if (ret != 0) {
|
||||
MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
|
||||
}
|
||||
ret = memcpy_s(result + 4, buf_size, serialized.data(), buf_size);
|
||||
if (ret != 0) {
|
||||
MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
|
||||
}
|
||||
|
||||
handler.ReceiveMessage(result, 4088);
|
||||
}
|
||||
|
||||
TEST_F(TestTcpMessageHandler, 8_Header_4080_Data_8_Header_4080_data) {
|
||||
TcpMessageHandler handler;
|
||||
handler.SetCallback([this](std::shared_ptr<CommMessage> message) { EXPECT_EQ(message->data().size(), 4077); });
|
||||
|
||||
std::string data(4077, 'a');
|
||||
CommMessage message;
|
||||
message.set_data(data);
|
||||
size_t buf_size = message.ByteSizeLong();
|
||||
char result[4096] = {0};
|
||||
int ret = memcpy_s(result, kHeaderLen, &buf_size, kHeaderLen);
|
||||
if (ret != 0) {
|
||||
MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
|
||||
}
|
||||
std::vector<char> serialized(buf_size);
|
||||
message.SerializeToArray(serialized.data(), static_cast<int>(buf_size));
|
||||
ret = memcpy_s(result + kHeaderLen, buf_size, serialized.data(), buf_size);
|
||||
if (ret != 0) {
|
||||
MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
|
||||
}
|
||||
|
||||
ret = memcpy_s(result + kHeaderLen + buf_size, kHeaderLen, &buf_size, kHeaderLen);
|
||||
if (ret != 0) {
|
||||
MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
|
||||
}
|
||||
|
||||
handler.ReceiveMessage(result, 4096);
|
||||
|
||||
ret = memcpy_s(result, buf_size, serialized.data(), buf_size);
|
||||
if (ret != 0) {
|
||||
MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
|
||||
}
|
||||
auto temp = reinterpret_cast<char *>(&header);
|
||||
memcpy_s(result, 8, temp + 8, 8);
|
||||
memcpy_s(result + 8, meta.ByteSizeLong(), meta.SerializeAsString().data(), meta.ByteSizeLong());
|
||||
memcpy_s(result + 8 + 2, data.length(), data.data(), data.length());
|
||||
|
||||
handler.ReceiveMessage(result, 4080);
|
||||
}
|
||||
|
||||
TEST_F(TestTcpMessageHandler, 16Header_2meta_4062Data_16Header_2meta_4062_data) {
|
||||
TcpMessageHandler handler;
|
||||
handler.SetCallback([this](std::shared_ptr<MessageMeta> meta, const Protos &, const void *data, size_t size) {
|
||||
EXPECT_EQ(meta->ByteSizeLong(), 2);
|
||||
EXPECT_EQ(size, 4062);
|
||||
});
|
||||
|
||||
std::string data(4062, 'a');
|
||||
|
||||
char result[4096] = {0};
|
||||
|
||||
MessageMeta meta;
|
||||
meta.set_request_id(1);
|
||||
EXPECT_EQ(meta.ByteSizeLong(), 2);
|
||||
|
||||
MessageHeader header;
|
||||
header.message_proto_ = Protos::RAW;
|
||||
header.message_meta_length_ = meta.ByteSizeLong();
|
||||
header.message_length_ = data.length() + meta.ByteSizeLong();
|
||||
int ret = memcpy_s(result, kHeaderLen, &header, kHeaderLen);
|
||||
if (ret != 0) {
|
||||
MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
|
||||
}
|
||||
|
||||
memcpy_s(result + kHeaderLen, meta.ByteSizeLong(), meta.SerializeAsString().data(), meta.ByteSizeLong());
|
||||
memcpy_s(result + kHeaderLen + meta.ByteSizeLong(), data.length(), data.data(), data.length());
|
||||
memcpy_s(result + kHeaderLen + meta.ByteSizeLong() + data.length(), kHeaderLen, &header, kHeaderLen);
|
||||
|
||||
handler.ReceiveMessage(result, 4096);
|
||||
|
||||
memcpy_s(result, meta.ByteSizeLong(), meta.SerializeAsString().data(), meta.ByteSizeLong());
|
||||
memcpy_s(result + meta.ByteSizeLong(), data.length(), data.data(), data.length());
|
||||
|
||||
handler.ReceiveMessage(result, 4064);
|
||||
}
|
||||
} // namespace core
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
|
@ -33,11 +33,12 @@ class TestTcpServer : public UT::Common {
|
|||
server_ = std::make_unique<TcpServer>("127.0.0.1", 0);
|
||||
std::unique_ptr<std::thread> http_server_thread_(nullptr);
|
||||
http_server_thread_ = std::make_unique<std::thread>([=]() {
|
||||
server_->SetMessageCallback([=](std::shared_ptr<TcpConnection> conn, std::shared_ptr<CommMessage> message) {
|
||||
server_->SetMessageCallback([=](std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta,
|
||||
const Protos &protos, const void *data, size_t size) {
|
||||
KVMessage kv_message;
|
||||
kv_message.ParseFromString(message->data());
|
||||
kv_message.ParseFromArray(data, size);
|
||||
EXPECT_EQ(2, kv_message.keys_size());
|
||||
server_->SendMessage(conn, message);
|
||||
server_->SendMessage(conn, meta, protos, data, size);
|
||||
});
|
||||
server_->Init();
|
||||
server_->Start();
|
||||
|
@ -61,23 +62,24 @@ TEST_F(TestTcpServer, ServerSendMessage) {
|
|||
std::cout << server_->BoundPort() << std::endl;
|
||||
std::unique_ptr<std::thread> http_client_thread(nullptr);
|
||||
http_client_thread = std::make_unique<std::thread>([&]() {
|
||||
client_->SetMessageCallback([](const TcpClient &client, const CommMessage &message) {
|
||||
KVMessage kv_message;
|
||||
kv_message.ParseFromString(message.data());
|
||||
EXPECT_EQ(2, kv_message.keys_size());
|
||||
client_->SetMessageCallback([&](std::shared_ptr<MessageMeta> meta, const Protos &, const void *data, size_t size) {
|
||||
KVMessage message;
|
||||
message.ParseFromArray(data, size);
|
||||
EXPECT_EQ(2, message.keys_size());
|
||||
});
|
||||
|
||||
client_->Init();
|
||||
|
||||
CommMessage comm_message;
|
||||
KVMessage kv_message;
|
||||
std::vector<int> keys{1, 2};
|
||||
std::vector<int> values{3, 4};
|
||||
*kv_message.mutable_keys() = {keys.begin(), keys.end()};
|
||||
*kv_message.mutable_values() = {values.begin(), values.end()};
|
||||
|
||||
comm_message.set_data(kv_message.SerializeAsString());
|
||||
client_->SendMessage(comm_message);
|
||||
auto message_meta = std::make_shared<MessageMeta>();
|
||||
message_meta->set_cmd(NodeCommand::SEND_DATA);
|
||||
|
||||
client_->SendMessage(message_meta, Protos::RAW, kv_message.SerializeAsString().data(), kv_message.ByteSizeLong());
|
||||
|
||||
client_->Start();
|
||||
});
|
||||
|
|
Loading…
Reference in New Issue