forked from mindspore-Ecosystem/mindspore
added worker node
This commit is contained in:
parent
b82c4cba32
commit
ee4132889e
|
@ -15,6 +15,7 @@ if (NOT (ENABLE_CPU AND (ENABLE_D OR ENABLE_GPU)))
|
|||
list(REMOVE_ITEM _PS_SRC_FILES "core/node.cc")
|
||||
list(REMOVE_ITEM _PS_SRC_FILES "core/node_manager.cc")
|
||||
list(REMOVE_ITEM _PS_SRC_FILES "ps_cache/ps_cache_manager.cc")
|
||||
list(REMOVE_ITEM _PS_SRC_FILES "core/worker_node.cc")
|
||||
endif ()
|
||||
|
||||
if (NOT ENABLE_D)
|
||||
|
|
|
@ -94,16 +94,16 @@ std::string CommUtil::GenerateUUID() {
|
|||
ss << dis(gen);
|
||||
}
|
||||
ss << "-4";
|
||||
for (i = 0; i < kGroup2RandomLength - 1; i++) {
|
||||
ss << dis(gen);
|
||||
}
|
||||
ss << "-";
|
||||
ss << dis2(gen);
|
||||
for (i = 0; i < kGroup3RandomLength - 1; i++) {
|
||||
ss << dis(gen);
|
||||
}
|
||||
ss << "-";
|
||||
for (i = 0; i < kGroup4RandomLength; i++) {
|
||||
ss << dis2(gen);
|
||||
for (i = 0; i < kGroup4RandomLength - 1; i++) {
|
||||
ss << dis(gen);
|
||||
}
|
||||
ss << "-";
|
||||
for (i = 0; i < kGroup5RandomLength; i++) {
|
||||
ss << dis(gen);
|
||||
}
|
||||
return ss.str();
|
||||
|
@ -121,7 +121,14 @@ std::string CommUtil::NodeRoleToString(const NodeRole &role) {
|
|||
MS_LOG(EXCEPTION) << "The node role:" << role << " is illegal!";
|
||||
}
|
||||
}
|
||||
|
||||
// Checks whether rank_id is a legal rank for a node of the given role.
//
// For NodeRole::SERVER the rank must be smaller than ClusterConfig::server_num(); for NodeRole::WORKER it must be
// smaller than ClusterConfig::worker_num(). Any other role is accepted without a range check.
//
// Returns true when the rank id is acceptable, false otherwise.
bool CommUtil::ValidateRankId(const enum NodeRole &node_role, const uint32_t &rank_id) {
  // Compare with `>=` instead of `> num - 1`: when the configured count is 0, `num - 1` wraps around (or becomes -1
  // and is converted to a huge unsigned value), which made every rank id pass the check.
  if (node_role == NodeRole::SERVER && rank_id >= ClusterConfig::server_num()) {
    return false;
  }
  if (node_role == NodeRole::WORKER && rank_id >= ClusterConfig::worker_num()) {
    return false;
  }
  return true;
}
|
||||
} // namespace core
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -48,6 +48,7 @@
|
|||
|
||||
#include "proto/comm.pb.h"
|
||||
#include "proto/ps.pb.h"
|
||||
#include "ps/core/cluster_config.h"
|
||||
#include "utils/log_adapter.h"
|
||||
|
||||
namespace mindspore {
|
||||
|
@ -66,6 +67,7 @@ class CommUtil {
|
|||
static void GetAvailableInterfaceAndIP(std::string *interface, std::string *ip);
|
||||
static std::string GenerateUUID();
|
||||
static std::string NodeRoleToString(const NodeRole &role);
|
||||
static bool ValidateRankId(const enum NodeRole &node_role, const uint32_t &rank_id);
|
||||
|
||||
private:
|
||||
static std::random_device rd;
|
||||
|
|
|
@ -47,13 +47,17 @@ void Node::ProcessHeartbeatResp(const CommMessage &message) {
|
|||
is_ready_ = heartbeat_resp_message.is_cluster_ready();
|
||||
if (is_ready_.load()) {
|
||||
wait_start_cond_.notify_all();
|
||||
MS_LOG(DEBUG) << "The node id:" << node_info_.node_id_ << " is ready!";
|
||||
}
|
||||
is_finish_ = heartbeat_resp_message.is_cluster_finish();
|
||||
if (is_finish_.load()) {
|
||||
wait_finish_cond_.notify_all();
|
||||
MS_LOG(DEBUG) << "The node id:" << node_info_.node_id_ << " is finish!";
|
||||
}
|
||||
is_timeout_ = heartbeat_resp_message.is_cluster_timeout();
|
||||
if (is_timeout_ && on_node_event_message_) {
|
||||
is_ready_ = true;
|
||||
wait_start_cond_.notify_all();
|
||||
on_node_event_message_(NodeEvent::NODE_TIMEOUT);
|
||||
}
|
||||
}
|
||||
|
@ -64,7 +68,9 @@ void Node::FetchServers(const std::shared_ptr<TcpClient> &client) {
|
|||
|
||||
CommMessage message;
|
||||
*message.mutable_pb_meta() = {meta};
|
||||
SendMessageSync(client, message);
|
||||
if (!SendMessageSync(client, message)) {
|
||||
MS_LOG(EXCEPTION) << "Fetch servers address timeout!";
|
||||
}
|
||||
}
|
||||
|
||||
void Node::ProcessFetchServersResp(const CommMessage &message) {
|
||||
|
@ -72,10 +78,10 @@ void Node::ProcessFetchServersResp(const CommMessage &message) {
|
|||
fetch_servers_resp_message.ParseFromString(message.data());
|
||||
|
||||
for (const auto &it : fetch_servers_resp_message.servers_meta()) {
|
||||
server_rank_ids_[it.rank_id()] = std::make_pair(it.ip(), it.port());
|
||||
nodes_address_[std::make_pair(NodeRole::SERVER, it.rank_id())] = std::make_pair(it.ip(), it.port());
|
||||
}
|
||||
|
||||
MS_LOG(DEBUG) << "The all server host size is:" << server_rank_ids_.size();
|
||||
MS_LOG(DEBUG) << "The all server host size is:" << nodes_address_.size();
|
||||
}
|
||||
|
||||
std::string Node::node_id() const { return node_info_.node_id_; }
|
||||
|
@ -86,19 +92,128 @@ void Node::set_callback(const OnNodeEventMessage &on_node_event_message) {
|
|||
on_node_event_message_ = on_node_event_message;
|
||||
}
|
||||
|
||||
void Node::Wait(uint64_t request_id) {
|
||||
std::unique_lock<std::mutex> lock(message_mutex_);
|
||||
message_tracker_cond_.wait(lock, [&] {
|
||||
bool Node::Wait(uint64_t request_id, const uint32_t &timeout) {
|
||||
std::unique_lock<std::mutex> lock(message_tracker_mutex_);
|
||||
bool res = message_tracker_cond_.wait_for(lock, std::chrono::seconds(timeout), [&] {
|
||||
bool ret = message_tracker_[request_id].first == message_tracker_[request_id].second;
|
||||
if (ret) {
|
||||
MS_LOG(DEBUG) << "Message tracker remove request id:" << request_id;
|
||||
message_tracker_.erase(request_id);
|
||||
}
|
||||
return ret;
|
||||
});
|
||||
message_tracker_.erase(request_id);
|
||||
return res;
|
||||
}
|
||||
|
||||
void Node::Disconnect(const std::shared_ptr<TcpClient> &client) {
|
||||
bool Node::Send(const enum NodeRole &node_role, const uint32_t &rank_id, const std::string &message,
|
||||
const uint32_t &timeout) {
|
||||
if (!CommUtil::ValidateRankId(node_role, rank_id)) {
|
||||
MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!";
|
||||
}
|
||||
|
||||
MessageMeta message_meta;
|
||||
message_meta.set_cmd(NodeCommand::SEND_DATA);
|
||||
|
||||
CommMessage comm_message;
|
||||
*comm_message.mutable_pb_meta() = {message_meta};
|
||||
comm_message.set_data(message);
|
||||
auto client = GetOrCreateTcpClient(rank_id);
|
||||
return SendMessageSync(client, comm_message);
|
||||
}
|
||||
|
||||
bool Node::Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids, const std::vector<std::string> &data,
|
||||
const uint32_t &timeout) {
|
||||
uint64_t request_id = ++next_request_id_;
|
||||
message_tracker_[request_id] = std::make_pair(data.size(), 0);
|
||||
|
||||
if (rank_ids.size() != data.size()) {
|
||||
MS_LOG(EXCEPTION) << "The number of rank ids is not equal to the number of data!";
|
||||
}
|
||||
for (size_t it = 0; it < rank_ids.size(); ++it) {
|
||||
if (!CommUtil::ValidateRankId(node_role, rank_ids.at(it))) {
|
||||
MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!";
|
||||
}
|
||||
|
||||
MessageMeta message_meta;
|
||||
message_meta.set_cmd(NodeCommand::SEND_DATA);
|
||||
message_meta.set_request_id(request_id);
|
||||
|
||||
CommMessage comm_message;
|
||||
*comm_message.mutable_pb_meta() = {message_meta};
|
||||
comm_message.set_data(data.at(it));
|
||||
|
||||
auto client = GetOrCreateTcpClient(rank_ids.at(it));
|
||||
client->SendMessage(comm_message);
|
||||
}
|
||||
return Wait(request_id, timeout);
|
||||
}
|
||||
|
||||
bool Node::Send(const enum NodeRole &node_role, const uint32_t &rank_id, const std::string &message,
|
||||
CommMessage *comm_message_resp, const uint32_t &timeout) {
|
||||
if (!CommUtil::ValidateRankId(node_role, rank_id)) {
|
||||
MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!";
|
||||
}
|
||||
|
||||
uint64_t request_id = ++next_request_id_;
|
||||
message_tracker_[request_id] = std::make_pair(1, 0);
|
||||
set_message_callback(request_id, [&]() {
|
||||
receive_messages_mutex_.lock();
|
||||
auto res = receive_messages_[request_id];
|
||||
comm_message_resp = &res[rank_id];
|
||||
receive_messages_.erase(request_id);
|
||||
receive_messages_mutex_.unlock();
|
||||
});
|
||||
|
||||
MessageMeta message_meta;
|
||||
message_meta.set_cmd(NodeCommand::SEND_DATA);
|
||||
message_meta.set_request_id(request_id);
|
||||
|
||||
CommMessage comm_message;
|
||||
*comm_message.mutable_pb_meta() = {message_meta};
|
||||
comm_message.set_data(message);
|
||||
auto client = GetOrCreateTcpClient(rank_id);
|
||||
client->SendMessage(comm_message);
|
||||
return Wait(request_id, timeout);
|
||||
}
|
||||
|
||||
bool Node::Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids, const std::vector<std::string> &data,
|
||||
std::vector<CommMessage *> *comm_message_resp, const uint32_t &timeout) {
|
||||
uint64_t request_id = ++next_request_id_;
|
||||
message_tracker_[request_id] = std::make_pair(data.size(), 0);
|
||||
|
||||
if (rank_ids.size() != data.size() || rank_ids.size() != (*comm_message_resp).size()) {
|
||||
MS_LOG(EXCEPTION) << "The number of rank ids, data, comm_message_resp should be equal!";
|
||||
}
|
||||
|
||||
size_t len = rank_ids.size();
|
||||
|
||||
set_message_callback(request_id, [&]() {
|
||||
receive_messages_mutex_.lock();
|
||||
auto res = receive_messages_[request_id];
|
||||
for (size_t it = 0; it < len; ++it) {
|
||||
comm_message_resp->at(it) = &res[rank_ids.at(it)];
|
||||
}
|
||||
receive_messages_.erase(request_id);
|
||||
receive_messages_mutex_.unlock();
|
||||
});
|
||||
|
||||
for (size_t it = 0; it < len; ++it) {
|
||||
if (!CommUtil::ValidateRankId(node_role, rank_ids.at(it))) {
|
||||
MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!";
|
||||
}
|
||||
|
||||
MessageMeta message_meta;
|
||||
message_meta.set_cmd(NodeCommand::SEND_DATA);
|
||||
message_meta.set_request_id(request_id);
|
||||
|
||||
CommMessage comm_message;
|
||||
*comm_message.mutable_pb_meta() = {message_meta};
|
||||
comm_message.set_data(data.at(it));
|
||||
|
||||
auto client = GetOrCreateTcpClient(rank_ids.at(it));
|
||||
client->SendMessage(comm_message);
|
||||
}
|
||||
return Wait(request_id, timeout);
|
||||
}
|
||||
|
||||
bool Node::Disconnect(const std::shared_ptr<TcpClient> &client, const uint32_t &timeout) {
|
||||
MessageMeta meta;
|
||||
meta.set_cmd(NodeCommand::FINISH);
|
||||
|
||||
|
@ -108,36 +223,43 @@ void Node::Disconnect(const std::shared_ptr<TcpClient> &client) {
|
|||
CommMessage message;
|
||||
*message.mutable_pb_meta() = {meta};
|
||||
message.set_data(finish_message.SerializeAsString());
|
||||
SendMessageSync(client, message);
|
||||
WaitForDisconnect();
|
||||
if (!SendMessageSync(client, message)) {
|
||||
MS_LOG(EXCEPTION) << "Disconnect timeout!";
|
||||
}
|
||||
MS_LOG(INFO) << "The node id:" << node_info_.node_id_ << " send finish message!";
|
||||
return WaitForDisconnect(timeout);
|
||||
}
|
||||
|
||||
void Node::WaitForStart() {
|
||||
bool Node::WaitForStart(const uint32_t &timeout) {
|
||||
std::unique_lock<std::mutex> lock(wait_start_mutex_);
|
||||
wait_start_cond_.wait(lock, [&] {
|
||||
if (is_ready_.load()) {
|
||||
MS_LOG(DEBUG) << "The node id:" << node_info_.node_id_ << " is success start!";
|
||||
bool res = wait_start_cond_.wait_for(lock, std::chrono::seconds(timeout), [&] {
|
||||
bool res = is_ready_.load();
|
||||
if (res) {
|
||||
MS_LOG(INFO) << "The node id:" << node_info_.node_id_ << " is success start!";
|
||||
}
|
||||
return is_ready_.load();
|
||||
return res;
|
||||
});
|
||||
return res;
|
||||
}
|
||||
|
||||
void Node::WaitForDisconnect() {
|
||||
bool Node::WaitForDisconnect(const uint32_t &timeout) {
|
||||
std::unique_lock<std::mutex> lock(wait_finish_mutex_);
|
||||
wait_finish_cond_.wait(lock, [&] {
|
||||
bool res = wait_finish_cond_.wait_for(lock, std::chrono::seconds(timeout), [&] {
|
||||
if (is_finish_.load()) {
|
||||
MS_LOG(DEBUG) << "The node id:" << node_info_.node_id_ << " is success finish!";
|
||||
MS_LOG(INFO) << "The node id:" << node_info_.node_id_ << " is success finish!";
|
||||
}
|
||||
return is_finish_.load();
|
||||
});
|
||||
return res;
|
||||
}
|
||||
|
||||
void Node::SendMessageSync(const std::shared_ptr<TcpClient> &client, const CommMessage &message) {
|
||||
bool Node::SendMessageSync(const std::shared_ptr<TcpClient> &client, const CommMessage &message,
|
||||
const uint32_t &timeout) {
|
||||
uint64_t request_id = ++next_request_id_;
|
||||
message_tracker_[request_id] = std::make_pair(1, 0);
|
||||
const_cast<CommMessage &>(message).mutable_pb_meta()->set_request_id(request_id);
|
||||
client->SendMessage(message);
|
||||
Wait(request_id);
|
||||
return Wait(request_id, timeout);
|
||||
}
|
||||
|
||||
void Node::SendMessageAsync(const std::shared_ptr<TcpClient> &client, const CommMessage &message) {
|
||||
|
@ -147,12 +269,83 @@ void Node::SendMessageAsync(const std::shared_ptr<TcpClient> &client, const Comm
|
|||
}
|
||||
|
||||
void Node::NotifyMessageArrival(const CommMessage &message) {
|
||||
std::lock_guard<std::mutex> lock(message_tracker_mutex_);
|
||||
const MessageMeta &message_meta = message.pb_meta();
|
||||
uint64_t request_id = message_meta.request_id();
|
||||
|
||||
message_tracker_[request_id].second++;
|
||||
message_tracker_cond_.notify_all();
|
||||
}
|
||||
|
||||
const std::shared_ptr<TcpClient> &Node::GetOrCreateTcpClient(const int &rank_id) {
|
||||
std::lock_guard<std::mutex> lock(client_mutex_);
|
||||
if (connected_nodes_.find(rank_id) != connected_nodes_.end()) {
|
||||
return connected_nodes_[rank_id];
|
||||
} else {
|
||||
if (nodes_address_.find(std::make_pair(NodeRole::SERVER, rank_id)) == nodes_address_.end()) {
|
||||
MS_LOG(EXCEPTION) << "Worker node Fetch servers failed!";
|
||||
}
|
||||
std::string ip = nodes_address_[std::make_pair(NodeRole::SERVER, rank_id)].first;
|
||||
uint16_t port = nodes_address_[std::make_pair(NodeRole::SERVER, rank_id)].second;
|
||||
auto client = std::make_shared<TcpClient>(ip, port);
|
||||
client->SetMessageCallback([&](const TcpClient &client, const CommMessage &message) {
|
||||
switch (message.pb_meta().cmd()) {
|
||||
case NodeCommand::SEND_DATA:
|
||||
ProcessSendDataResp(message);
|
||||
break;
|
||||
default:
|
||||
MS_LOG(EXCEPTION) << "The cmd:" << message.pb_meta().cmd() << " is not supported!";
|
||||
}
|
||||
NotifyMessageArrival(message);
|
||||
});
|
||||
client->Init();
|
||||
connected_nodes_[rank_id] = client;
|
||||
return connected_nodes_[rank_id];
|
||||
}
|
||||
}
|
||||
|
||||
void Node::ProcessSendDataResp(const CommMessage &message) {
|
||||
std::lock_guard<std::mutex> lock(receive_messages_mutex_);
|
||||
const MessageMeta &message_meta = message.pb_meta();
|
||||
const uint32_t &rank_id = message_meta.rank_id();
|
||||
const uint64_t request_id = message_meta.request_id();
|
||||
auto it = receive_messages_.find(request_id);
|
||||
if (it != receive_messages_.end()) {
|
||||
it->second.insert(std::make_pair(rank_id, message));
|
||||
} else {
|
||||
std::unordered_map<uint32_t, CommMessage> res;
|
||||
res.insert(std::make_pair(rank_id, message));
|
||||
receive_messages_[request_id] = res;
|
||||
}
|
||||
|
||||
RunMessageCallback(request_id);
|
||||
}
|
||||
|
||||
void Node::RunMessageCallback(const uint64_t &request_id) {
|
||||
message_callbacks_mutex_.lock();
|
||||
if (message_tracker_[request_id].first == message_tracker_[request_id].second - 1) {
|
||||
auto it = message_callbacks_.find(request_id);
|
||||
if (it != message_callbacks_.end()) {
|
||||
message_callbacks_mutex_.unlock();
|
||||
|
||||
if (it->second) {
|
||||
it->second();
|
||||
}
|
||||
|
||||
message_callbacks_mutex_.lock();
|
||||
message_callbacks_.erase(it);
|
||||
}
|
||||
}
|
||||
message_callbacks_mutex_.unlock();
|
||||
}
|
||||
|
||||
void Node::set_message_callback(const uint64_t &request_id, const MessageCallback &message_callback) {
|
||||
if (!message_callback) {
|
||||
return;
|
||||
}
|
||||
std::lock_guard<std::mutex> lock(message_callbacks_mutex_);
|
||||
message_callbacks_[request_id] = message_callback;
|
||||
}
|
||||
} // namespace core
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -21,15 +21,15 @@
|
|||
#include <cstdlib>
|
||||
#include <functional>
|
||||
#include <iostream>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <set>
|
||||
#include <string>
|
||||
#include <thread>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
#include <condition_variable>
|
||||
#include <utility>
|
||||
#include <tuple>
|
||||
#include <map>
|
||||
|
||||
#include "proto/comm.pb.h"
|
||||
#include "proto/ps.pb.h"
|
||||
|
@ -42,6 +42,8 @@
|
|||
namespace mindspore {
|
||||
namespace ps {
|
||||
namespace core {
|
||||
constexpr int kTimeoutInSeconds = 30;
|
||||
constexpr int kCommTimeoutInSeconds = 3;
|
||||
class Node {
|
||||
public:
|
||||
Node()
|
||||
|
@ -49,51 +51,83 @@ class Node {
|
|||
is_finish_(false),
|
||||
is_timeout_(false),
|
||||
is_already_stopped_(true),
|
||||
is_already_finished_(false),
|
||||
next_request_id_(0),
|
||||
heart_beat_thread_(nullptr) {}
|
||||
virtual ~Node() = default;
|
||||
|
||||
using OnNodeEventMessage = std::function<void(const NodeEvent &event)>;
|
||||
void set_callback(const OnNodeEventMessage &on_node_event_message);
|
||||
using MessageCallback = std::function<void()>;
|
||||
|
||||
virtual bool Start(const uint32_t &timeout = kTimeoutInSeconds) = 0;
|
||||
virtual bool Stop() = 0;
|
||||
virtual bool Finish(const uint32_t &timeout = kTimeoutInSeconds) = 0;
|
||||
|
||||
void set_callback(const OnNodeEventMessage &on_node_event_message);
|
||||
std::string node_id() const;
|
||||
uint32_t rank_id() const;
|
||||
bool Wait(uint64_t request_id, const uint32_t &timeout = kCommTimeoutInSeconds);
|
||||
|
||||
void Wait(uint64_t request_id);
|
||||
virtual bool Send(const enum NodeRole &node_role, const uint32_t &rank_id, const std::string &message,
|
||||
const uint32_t &timeout = kCommTimeoutInSeconds);
|
||||
virtual bool Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids,
|
||||
const std::vector<std::string> &data, const uint32_t &timeout = kCommTimeoutInSeconds);
|
||||
virtual bool Send(const enum NodeRole &node_role, const uint32_t &rank_id, const std::string &message,
|
||||
CommMessage *const comm_message_resp, const uint32_t &timeout = kCommTimeoutInSeconds);
|
||||
virtual bool Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids,
|
||||
const std::vector<std::string> &data, std::vector<CommMessage *> *comm_message_resp,
|
||||
const uint32_t &timeout = kCommTimeoutInSeconds);
|
||||
|
||||
protected:
|
||||
void Heartbeat(const std::shared_ptr<TcpClient> &client);
|
||||
void ProcessHeartbeatResp(const CommMessage &message);
|
||||
void FetchServers(const std::shared_ptr<TcpClient> &client);
|
||||
void ProcessFetchServersResp(const CommMessage &message);
|
||||
void Disconnect(const std::shared_ptr<TcpClient> &client);
|
||||
void WaitForStart();
|
||||
void WaitForDisconnect();
|
||||
void SendMessageSync(const std::shared_ptr<TcpClient> &client, const CommMessage &message);
|
||||
bool Disconnect(const std::shared_ptr<TcpClient> &client, const uint32_t &timeout);
|
||||
bool WaitForStart(const uint32_t &timeout);
|
||||
bool WaitForDisconnect(const uint32_t &timeout);
|
||||
bool SendMessageSync(const std::shared_ptr<TcpClient> &client, const CommMessage &message,
|
||||
const uint32_t &timeout = kCommTimeoutInSeconds);
|
||||
void SendMessageAsync(const std::shared_ptr<TcpClient> &client, const CommMessage &message);
|
||||
void NotifyMessageArrival(const CommMessage &message);
|
||||
const std::shared_ptr<TcpClient> &GetOrCreateTcpClient(const int &rank_id);
|
||||
void ProcessSendDataResp(const CommMessage &message);
|
||||
void RunMessageCallback(const uint64_t &request_id);
|
||||
void set_message_callback(const uint64_t &request_id, const MessageCallback &message_callback);
|
||||
|
||||
NodeInfo node_info_;
|
||||
std::atomic<bool> is_ready_;
|
||||
std::atomic<bool> is_finish_;
|
||||
std::atomic<bool> is_timeout_;
|
||||
std::atomic<bool> is_already_stopped_;
|
||||
std::atomic<bool> is_already_finished_;
|
||||
std::atomic_uint64_t next_request_id_;
|
||||
std::unique_ptr<std::thread> heart_beat_thread_;
|
||||
|
||||
OnNodeEventMessage on_node_event_message_;
|
||||
|
||||
// rank_id-><ip, port>
|
||||
std::unordered_map<int, std::pair<std::string, uint16_t>> server_rank_ids_;
|
||||
// <NodeRole,rank_id>-><ip, port>
|
||||
std::map<std::pair<NodeRole, uint32_t>, std::pair<std::string, uint16_t>> nodes_address_;
|
||||
// rank_id->tcpclient
|
||||
std::unordered_map<int, std::shared_ptr<TcpClient>> connected_nodes_;
|
||||
|
||||
// timestamp-><expected responses, actual responses>
|
||||
// request_id-><expected responses, actual responses>
|
||||
std::unordered_map<uint64_t, std::pair<uint32_t, uint32_t>> message_tracker_;
|
||||
std::mutex message_mutex_;
|
||||
std::mutex message_tracker_mutex_;
|
||||
std::condition_variable message_tracker_cond_;
|
||||
std::mutex wait_finish_mutex_;
|
||||
std::condition_variable wait_finish_cond_;
|
||||
std::mutex wait_start_mutex_;
|
||||
std::condition_variable wait_start_cond_;
|
||||
std::mutex finish_mutex_;
|
||||
std::mutex client_mutex_;
|
||||
|
||||
// request_id -> <rank_id, CommMessage>
|
||||
std::unordered_map<uint64_t, std::unordered_map<uint32_t, CommMessage>> receive_messages_;
|
||||
std::mutex receive_messages_mutex_;
|
||||
// request_id -> MessageCallback
|
||||
std::unordered_map<uint64_t, MessageCallback> message_callbacks_;
|
||||
std::mutex message_callbacks_mutex_;
|
||||
};
|
||||
} // namespace core
|
||||
} // namespace ps
|
||||
|
|
|
@ -39,6 +39,10 @@ message MessageMeta {
|
|||
NodeCommand cmd = 1;
|
||||
// the request id of this message
|
||||
uint64 request_id = 2;
|
||||
// the role of the current node: worker,server,scheduler
|
||||
NodeRole role = 3;
|
||||
// the current Node rank id,the worker node range is:[0,numOfWorker-1], the server node range is:[0, numOfServer-1]
|
||||
int32 rank_id = 4;
|
||||
}
|
||||
|
||||
message RegisterMessage {
|
||||
|
|
|
@ -249,7 +249,7 @@ void TcpClient::SetMessageCallback(const OnMessage &cb) { message_callback_ = cb
|
|||
|
||||
void TcpClient::SendMessage(const CommMessage &message) const {
|
||||
MS_EXCEPTION_IF_NULL(buffer_event_);
|
||||
uint32_t buf_size = message.ByteSizeLong();
|
||||
size_t buf_size = message.ByteSizeLong();
|
||||
std::vector<unsigned char> serialized(buf_size);
|
||||
message.SerializeToArray(serialized.data(), static_cast<int>(buf_size));
|
||||
if (evbuffer_add(bufferevent_get_output(buffer_event_), &buf_size, sizeof(buf_size)) == -1) {
|
||||
|
|
|
@ -23,7 +23,6 @@
|
|||
namespace mindspore {
|
||||
namespace ps {
|
||||
namespace core {
|
||||
|
||||
// Stores the callback in message_callback_, replacing any previously stored one.
void TcpMessageHandler::SetCallback(const messageReceive &message_receive) { message_callback_ = message_receive; }
|
||||
|
||||
void TcpMessageHandler::ReceiveMessage(const void *buffer, size_t num) {
|
||||
|
@ -32,11 +31,11 @@ void TcpMessageHandler::ReceiveMessage(const void *buffer, size_t num) {
|
|||
|
||||
while (num > 0) {
|
||||
if (remaining_length_ == 0) {
|
||||
for (int i = 0; i < 4 && num > 0; ++i) {
|
||||
for (int i = 0; i < kHeaderLen && num > 0; ++i) {
|
||||
header_[++header_index_] = *(buffer_data + i);
|
||||
--num;
|
||||
if (header_index_ == 3) {
|
||||
message_length_ = *reinterpret_cast<const uint32_t *>(header_);
|
||||
if (header_index_ == kHeaderLen - 1) {
|
||||
message_length_ = *reinterpret_cast<const size_t *>(header_);
|
||||
remaining_length_ = message_length_;
|
||||
message_buffer_.reset(new unsigned char[remaining_length_]);
|
||||
buffer_data += (i + 1);
|
||||
|
@ -46,7 +45,7 @@ void TcpMessageHandler::ReceiveMessage(const void *buffer, size_t num) {
|
|||
}
|
||||
|
||||
if (remaining_length_ > 0 && num > 0) {
|
||||
uint32_t copy_len = remaining_length_ <= num ? remaining_length_ : num;
|
||||
size_t copy_len = remaining_length_ <= num ? remaining_length_ : num;
|
||||
remaining_length_ -= copy_len;
|
||||
num -= copy_len;
|
||||
|
||||
|
@ -71,7 +70,6 @@ void TcpMessageHandler::ReceiveMessage(const void *buffer, size_t num) {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace core
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -31,6 +31,7 @@ namespace mindspore {
|
|||
namespace ps {
|
||||
namespace core {
|
||||
using messageReceive = std::function<void(const CommMessage &message)>;
|
||||
constexpr int kHeaderLen = 8;
|
||||
|
||||
class TcpMessageHandler {
|
||||
public:
|
||||
|
@ -51,10 +52,10 @@ class TcpMessageHandler {
|
|||
bool is_parsed_;
|
||||
std::unique_ptr<unsigned char> message_buffer_;
|
||||
size_t message_length_;
|
||||
uint32_t remaining_length_;
|
||||
char header_[4];
|
||||
size_t remaining_length_;
|
||||
char header_[8];
|
||||
int header_index_;
|
||||
uint32_t last_copy_len_;
|
||||
size_t last_copy_len_;
|
||||
};
|
||||
} // namespace core
|
||||
} // namespace ps
|
||||
|
|
|
@ -55,7 +55,7 @@ const evutil_socket_t &TcpConnection::GetFd() const { return fd_; }
|
|||
|
||||
void TcpConnection::SendMessage(const CommMessage &message) const {
|
||||
MS_EXCEPTION_IF_NULL(buffer_event_);
|
||||
uint32_t buf_size = message.ByteSizeLong();
|
||||
size_t buf_size = message.ByteSizeLong();
|
||||
std::vector<unsigned char> serialized(buf_size);
|
||||
message.SerializeToArray(serialized.data(), static_cast<int>(buf_size));
|
||||
if (evbuffer_add(bufferevent_get_output(const_cast<struct bufferevent *>(buffer_event_)), &buf_size,
|
||||
|
|
|
@ -0,0 +1,187 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "ps/core/worker_node.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
namespace core {
|
||||
// Shuts the worker node down if Stop() has not already done so: stops the scheduler client and every cached
// server client, stops the scheduler client's event base and joins the worker and heartbeat threads when they
// are joinable.
//
// Every pointer is checked before use: is_already_stopped_ is cleared at the very beginning of Initialize(), so
// if start-up fails part-way (for example Register() raises before Heartbeat() runs) some of these members can
// still be null when the destructor executes.
WorkerNode::~WorkerNode() {
  MS_LOG(INFO) << "Stop worker node!";
  if (!is_already_stopped_.load()) {
    is_ready_ = true;
    is_timeout_ = true;
    if (client_to_scheduler_ != nullptr) {
      client_to_scheduler_->Stop();
    }
    for (auto &connected_node : connected_nodes_) {
      if (connected_node.second != nullptr) {
        connected_node.second->Stop();
      }
    }
    if (client_to_scheduler_ != nullptr) {
      client_to_scheduler_->StopEventBase();
    }
    if (worker_thread_ != nullptr && worker_thread_->joinable()) {
      worker_thread_->join();
    }
    if (heart_beat_thread_ != nullptr && heart_beat_thread_->joinable()) {
      heart_beat_thread_->join();
    }
    is_already_stopped_ = true;
  }
}
|
||||
bool WorkerNode::Start(const uint32_t &timeout) {
|
||||
MS_LOG(INFO) << "Starting worker node!";
|
||||
Initialize();
|
||||
Register();
|
||||
Heartbeat(client_to_scheduler_);
|
||||
|
||||
if (!WaitForStart(timeout)) {
|
||||
MS_LOG(ERROR) << "Start Worker node timeout!";
|
||||
return false;
|
||||
}
|
||||
MS_LOG(INFO) << "The node is ready to fetch servers!";
|
||||
|
||||
if (!is_timeout_.load()) {
|
||||
FetchServers(client_to_scheduler_);
|
||||
MS_LOG(INFO) << "Fetch servers successful!";
|
||||
}
|
||||
MS_LOG(INFO) << "The Worker node has successfully started.";
|
||||
return true;
|
||||
}
|
||||
|
||||
void WorkerNode::Register() {
|
||||
MessageMeta message_meta;
|
||||
message_meta.set_cmd(NodeCommand::REGISTER);
|
||||
|
||||
RegisterMessage register_message;
|
||||
register_message.set_node_id(node_info_.node_id_);
|
||||
register_message.set_role(node_info_.node_role_);
|
||||
|
||||
CommMessage comm_message;
|
||||
*comm_message.mutable_pb_meta() = {message_meta};
|
||||
comm_message.set_data(register_message.SerializeAsString());
|
||||
if (!SendMessageSync(client_to_scheduler_, comm_message)) {
|
||||
MS_LOG(EXCEPTION) << "Worker node register timeout!";
|
||||
}
|
||||
MS_LOG(INFO) << "The worker node id:" << node_info_.node_id_
|
||||
<< "is registering to scheduler, the request id is:" << message_meta.request_id();
|
||||
}
|
||||
|
||||
void WorkerNode::ProcessRegisterResp(const CommMessage &message) {
|
||||
RegisterRespMessage register_resp_message;
|
||||
register_resp_message.ParseFromString(message.data());
|
||||
if (register_resp_message.node_id() != node_info_.node_id_) {
|
||||
MS_LOG(EXCEPTION) << "The node id received:" << register_resp_message.node_id()
|
||||
<< " is not match the current node id:" << node_info_.node_id_;
|
||||
}
|
||||
|
||||
node_info_.rank_id_ = register_resp_message.rank_id();
|
||||
|
||||
MS_LOG(INFO) << "The client node id is:" << node_info_.node_id_ << ", and the rank id is:" << node_info_.rank_id_;
|
||||
}
|
||||
|
||||
// Prepares the node for start-up: clears the "already stopped" flag, assigns the WORKER role and a freshly
// generated UUID as node id, logs both, and creates the client connected to the scheduler.
void WorkerNode::Initialize() {
  is_already_stopped_ = false;

  node_info_.node_role_ = NodeRole::WORKER;
  node_info_.node_id_ = CommUtil::GenerateUUID();

  const std::string role_name = CommUtil::NodeRoleToString(node_info_.node_role_);
  MS_LOG(INFO) << "The node role is:" << role_name << ", the node id is:" << node_info_.node_id_;

  InitClientToScheduler();
}
|
||||
|
||||
void WorkerNode::InitClientToScheduler() {
|
||||
std::string scheduler_host = ClusterConfig::scheduler_host();
|
||||
uint16_t scheduler_port = ClusterConfig::scheduler_port();
|
||||
client_to_scheduler_ = std::make_shared<TcpClient>(scheduler_host, scheduler_port);
|
||||
client_to_scheduler_->SetMessageCallback([&](const TcpClient &client, const CommMessage &message) {
|
||||
switch (message.pb_meta().cmd()) {
|
||||
case NodeCommand::HEARTBEAT:
|
||||
ProcessHeartbeatResp(message);
|
||||
break;
|
||||
case NodeCommand::REGISTER:
|
||||
ProcessRegisterResp(message);
|
||||
break;
|
||||
case NodeCommand::FETCH_SERVER:
|
||||
ProcessFetchServersResp(message);
|
||||
break;
|
||||
case NodeCommand::FINISH:
|
||||
MS_LOG(INFO) << "The Node id:" << node_info_.node_id_ << " receive a finish message response!";
|
||||
break;
|
||||
default:
|
||||
MS_LOG(EXCEPTION) << "The cmd:" << message.pb_meta().cmd() << " is not supported!";
|
||||
}
|
||||
NotifyMessageArrival(message);
|
||||
});
|
||||
|
||||
client_to_scheduler_->Init();
|
||||
worker_thread_ = std::make_unique<std::thread>([&]() {
|
||||
MS_LOG(INFO) << "The worker node start a tcp client!";
|
||||
client_to_scheduler_->Start();
|
||||
});
|
||||
worker_thread_->detach();
|
||||
}
|
||||
|
||||
bool WorkerNode::Stop() {
|
||||
MS_LOG(INFO) << "Stop worker node!";
|
||||
if (!is_already_stopped_.load()) {
|
||||
is_ready_ = true;
|
||||
is_timeout_ = true;
|
||||
client_to_scheduler_->Stop();
|
||||
if (!connected_nodes_.empty()) {
|
||||
for (auto &connected_node : connected_nodes_) {
|
||||
connected_node.second->Stop();
|
||||
}
|
||||
}
|
||||
client_to_scheduler_->StopEventBase();
|
||||
if (worker_thread_->joinable()) {
|
||||
worker_thread_->join();
|
||||
}
|
||||
if (heart_beat_thread_->joinable()) {
|
||||
heart_beat_thread_->join();
|
||||
}
|
||||
is_already_stopped_ = true;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool WorkerNode::Finish(const uint32_t &timeout) {
|
||||
std::lock_guard<std::mutex> lock(finish_mutex_);
|
||||
if (is_already_finished_) {
|
||||
MS_LOG(INFO) << "Worker node already finish!";
|
||||
return true;
|
||||
}
|
||||
MS_LOG(INFO) << "Finish worker node!";
|
||||
is_already_finished_ = true;
|
||||
return Disconnect(client_to_scheduler_, timeout);
|
||||
}
|
||||
|
||||
bool WorkerNode::BroadcastToServers(const std::string &message) {
|
||||
uint64_t request_id = ++next_request_id_;
|
||||
message_tracker_[request_id] = std::make_pair(nodes_address_.size(), 0);
|
||||
for (auto it = nodes_address_.begin(); it != nodes_address_.end(); ++it) {
|
||||
MessageMeta message_meta;
|
||||
message_meta.set_cmd(NodeCommand::SEND_DATA);
|
||||
message_meta.set_request_id(request_id);
|
||||
|
||||
CommMessage comm_message;
|
||||
*comm_message.mutable_pb_meta() = {message_meta};
|
||||
comm_message.set_data(message);
|
||||
auto client = GetOrCreateTcpClient((*it).first.second);
|
||||
client->SendMessage(comm_message);
|
||||
}
|
||||
return Wait(request_id);
|
||||
}
|
||||
} // namespace core
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,69 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_PS_CORE_CLIENT_NODE_H_
|
||||
#define MINDSPORE_CCSRC_PS_CORE_CLIENT_NODE_H_
|
||||
|
||||
#include <atomic>
|
||||
#include <cstdlib>
|
||||
#include <iostream>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <thread>
|
||||
#include <unordered_map>
|
||||
#include <utility>
|
||||
#include <condition_variable>
|
||||
#include <algorithm>
|
||||
#include <tuple>
|
||||
|
||||
#include "proto/comm.pb.h"
|
||||
#include "proto/ps.pb.h"
|
||||
#include "ps/core/cluster_config.h"
|
||||
#include "ps/core/tcp_client.h"
|
||||
#include "ps/core/tcp_server.h"
|
||||
#include "ps/core/node.h"
|
||||
#include "utils/log_adapter.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
namespace core {
|
||||
class WorkerNode : public Node {
|
||||
public:
|
||||
WorkerNode() : client_to_scheduler_(nullptr), worker_thread_(nullptr) {}
|
||||
~WorkerNode() override;
|
||||
|
||||
bool Start(const uint32_t &timeout = kTimeoutInSeconds) override;
|
||||
bool Stop() override;
|
||||
bool Finish(const uint32_t &timeout = kTimeoutInSeconds) override;
|
||||
|
||||
bool BroadcastToServers(const std::string &message);
|
||||
|
||||
private:
|
||||
void Register();
|
||||
void ProcessRegisterResp(const CommMessage &message);
|
||||
|
||||
void Initialize();
|
||||
void InitClientToScheduler();
|
||||
|
||||
std::shared_ptr<TcpClient> client_to_scheduler_;
|
||||
std::unique_ptr<std::thread> worker_thread_;
|
||||
};
|
||||
} // namespace core
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_PS_CORE_CLIENT_NODE_H_
|
|
@ -39,6 +39,14 @@ TEST_F(TestCommUtil, GetAvailableInterfaceAndIP) {
|
|||
EXPECT_TRUE(!interface.empty());
|
||||
EXPECT_TRUE(!ip.empty());
|
||||
}
|
||||
|
||||
// ValidateRankId must accept the highest valid rank of each role (count - 1)
// and reject the first rank past it (count).
TEST_F(TestCommUtil, ValidateRankId) {
  const uint32_t worker_num = 3;
  const uint32_t server_num = 2;
  ClusterConfig::Init(worker_num, server_num, std::make_unique<std::string>("127.0.0.1"), 9999);

  // Worker ranks: the last valid id is worker_num - 1.
  EXPECT_TRUE(CommUtil::ValidateRankId(NodeRole::WORKER, worker_num - 1));
  EXPECT_FALSE(CommUtil::ValidateRankId(NodeRole::WORKER, worker_num));

  // Server ranks: the last valid id is server_num - 1.
  EXPECT_TRUE(CommUtil::ValidateRankId(NodeRole::SERVER, server_num - 1));
  EXPECT_FALSE(CommUtil::ValidateRankId(NodeRole::SERVER, server_num));
}
|
||||
} // namespace comm
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
|
@ -33,117 +33,118 @@ class TestTcpMessageHandler : public UT::Common {
|
|||
void TearDown() override {}
|
||||
};
|
||||
|
||||
TEST_F(TestTcpMessageHandler, 4_Header_1003_Data) {
|
||||
TEST_F(TestTcpMessageHandler, 8_Header_1003_Data) {
|
||||
TcpMessageHandler handler;
|
||||
handler.SetCallback([this](const CommMessage &message) { EXPECT_EQ(message.data().size(), 1000); });
|
||||
|
||||
std::string data(1000, 'a');
|
||||
CommMessage message;
|
||||
message.set_data(data);
|
||||
uint32_t buf_size = message.ByteSizeLong();
|
||||
char result[1007];
|
||||
int ret = memcpy_s(result, 4, &buf_size, 4);
|
||||
size_t buf_size = message.ByteSizeLong();
|
||||
char result[1011];
|
||||
int ret = memcpy_s(result, kHeaderLen, &buf_size, kHeaderLen);
|
||||
if (ret != 0) {
|
||||
MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
|
||||
}
|
||||
|
||||
std::vector<char> serialized(buf_size);
|
||||
message.SerializeToArray(serialized.data(), static_cast<int>(buf_size));
|
||||
memcpy_s(result + 4, buf_size, serialized.data(), buf_size);
|
||||
handler.ReceiveMessage(result, buf_size + 4);
|
||||
memcpy_s(result + kHeaderLen, buf_size, serialized.data(), buf_size);
|
||||
handler.ReceiveMessage(result, buf_size + kHeaderLen);
|
||||
}
|
||||
|
||||
TEST_F(TestTcpMessageHandler, 4_Header_1003_Data_4_Header_1003_Data) {
|
||||
TEST_F(TestTcpMessageHandler, 8_Header_1003_Data_8_Header_1003_Data) {
|
||||
TcpMessageHandler handler;
|
||||
handler.SetCallback([this](const CommMessage &message) { EXPECT_EQ(message.data().size(), 1000); });
|
||||
|
||||
std::string data(1000, 'a');
|
||||
CommMessage message;
|
||||
message.set_data(data);
|
||||
uint32_t buf_size = message.ByteSizeLong();
|
||||
char result[2014];
|
||||
int ret = memcpy_s(result, 4, &buf_size, 4);
|
||||
size_t buf_size = message.ByteSizeLong();
|
||||
char result[2022] = {0};
|
||||
int ret = memcpy_s(result, kHeaderLen, &buf_size, kHeaderLen);
|
||||
if (ret != 0) {
|
||||
MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
|
||||
}
|
||||
std::vector<char> serialized(buf_size);
|
||||
message.SerializeToArray(serialized.data(), static_cast<int>(buf_size));
|
||||
ret = memcpy_s(result + 4, buf_size, serialized.data(), buf_size);
|
||||
ret = memcpy_s(result + kHeaderLen, buf_size, serialized.data(), buf_size);
|
||||
if (ret != 0) {
|
||||
MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
|
||||
}
|
||||
ret = memcpy_s(result + 4 + buf_size, 4, &buf_size, 4);
|
||||
ret = memcpy_s(result + kHeaderLen + buf_size, kHeaderLen, &buf_size, kHeaderLen);
|
||||
if (ret != 0) {
|
||||
MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
|
||||
}
|
||||
ret = memcpy_s(result + 4 + buf_size + 4, buf_size, serialized.data(), buf_size);
|
||||
ret = memcpy_s(result + kHeaderLen + buf_size + kHeaderLen, buf_size, serialized.data(), buf_size);
|
||||
if (ret != 0) {
|
||||
MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
|
||||
}
|
||||
|
||||
handler.ReceiveMessage(result, 2 * buf_size + 4 * 2);
|
||||
handler.ReceiveMessage(result, 2 * buf_size + kHeaderLen * 2);
|
||||
}
|
||||
|
||||
TEST_F(TestTcpMessageHandler, 4_Header_4090_Data_2_Header_2_header_4090_data) {
|
||||
TEST_F(TestTcpMessageHandler, 8_Header_4084_Data_4_Header_4_header_4084_data) {
|
||||
TcpMessageHandler handler;
|
||||
handler.SetCallback([this](const CommMessage &message) { EXPECT_EQ(message.data().size(), 4087); });
|
||||
handler.SetCallback([this](const CommMessage &message) { EXPECT_EQ(message.data().size(), 4081); });
|
||||
|
||||
std::string data(4087, 'a');
|
||||
std::string data(4081, 'a');
|
||||
CommMessage message;
|
||||
message.set_data(data);
|
||||
uint32_t buf_size = message.ByteSizeLong();
|
||||
char result[4096];
|
||||
int ret = memcpy_s(result, 4, &buf_size, 4);
|
||||
size_t buf_size = message.ByteSizeLong();
|
||||
char result[4096] = {0};
|
||||
int ret = memcpy_s(result, kHeaderLen, &buf_size, kHeaderLen);
|
||||
if (ret != 0) {
|
||||
MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
|
||||
}
|
||||
std::vector<char> serialized(buf_size);
|
||||
message.SerializeToArray(serialized.data(), static_cast<int>(buf_size));
|
||||
ret = memcpy_s(result + 4, buf_size, serialized.data(), buf_size);
|
||||
ret = memcpy_s(result + kHeaderLen, buf_size, serialized.data(), buf_size);
|
||||
if (ret != 0) {
|
||||
MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
|
||||
}
|
||||
|
||||
ret = memcpy_s(result + 4 + buf_size, 2, &buf_size, 2);
|
||||
ret = memcpy_s(result + kHeaderLen + buf_size, 4, &buf_size, 4);
|
||||
if (ret != 0) {
|
||||
MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
|
||||
}
|
||||
|
||||
handler.ReceiveMessage(result, 4096);
|
||||
|
||||
ret = memcpy_s(result, 2, &buf_size + 2, 2);
|
||||
auto temp = reinterpret_cast<char *>(&buf_size);
|
||||
ret = memcpy_s(result, 4, temp + 4, 4);
|
||||
if (ret != 0) {
|
||||
MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
|
||||
}
|
||||
ret = memcpy_s(result + 2, buf_size, serialized.data(), buf_size);
|
||||
if (ret != 0) {
|
||||
MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
|
||||
}
|
||||
|
||||
handler.ReceiveMessage(result, 4092);
|
||||
}
|
||||
|
||||
TEST_F(TestTcpMessageHandler, 4_Header_4088_Data_4_Header_4088_data) {
|
||||
TcpMessageHandler handler;
|
||||
handler.SetCallback([this](const CommMessage &message) { EXPECT_EQ(message.data().size(), 4085); });
|
||||
|
||||
std::string data(4085, 'a');
|
||||
CommMessage message;
|
||||
message.set_data(data);
|
||||
uint32_t buf_size = message.ByteSizeLong();
|
||||
char result[4096];
|
||||
int ret = memcpy_s(result, 4, &buf_size, 4);
|
||||
if (ret != 0) {
|
||||
MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
|
||||
}
|
||||
std::vector<char> serialized(buf_size);
|
||||
message.SerializeToArray(serialized.data(), static_cast<int>(buf_size));
|
||||
ret = memcpy_s(result + 4, buf_size, serialized.data(), buf_size);
|
||||
if (ret != 0) {
|
||||
MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
|
||||
}
|
||||
|
||||
ret = memcpy_s(result + 4 + buf_size, 4, &buf_size, 4);
|
||||
handler.ReceiveMessage(result, 4088);
|
||||
}
|
||||
|
||||
TEST_F(TestTcpMessageHandler, 8_Header_4080_Data_8_Header_4080_data) {
|
||||
TcpMessageHandler handler;
|
||||
handler.SetCallback([this](const CommMessage &message) { EXPECT_EQ(message.data().size(), 4077); });
|
||||
|
||||
std::string data(4077, 'a');
|
||||
CommMessage message;
|
||||
message.set_data(data);
|
||||
size_t buf_size = message.ByteSizeLong();
|
||||
char result[4096] = {0};
|
||||
int ret = memcpy_s(result, kHeaderLen, &buf_size, kHeaderLen);
|
||||
if (ret != 0) {
|
||||
MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
|
||||
}
|
||||
std::vector<char> serialized(buf_size);
|
||||
message.SerializeToArray(serialized.data(), static_cast<int>(buf_size));
|
||||
ret = memcpy_s(result + kHeaderLen, buf_size, serialized.data(), buf_size);
|
||||
if (ret != 0) {
|
||||
MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
|
||||
}
|
||||
|
||||
ret = memcpy_s(result + kHeaderLen + buf_size, kHeaderLen, &buf_size, kHeaderLen);
|
||||
if (ret != 0) {
|
||||
MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
|
||||
}
|
||||
|
@ -155,9 +156,8 @@ TEST_F(TestTcpMessageHandler, 4_Header_4088_Data_4_Header_4088_data) {
|
|||
MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
|
||||
}
|
||||
|
||||
handler.ReceiveMessage(result, 4088);
|
||||
handler.ReceiveMessage(result, 4080);
|
||||
}
|
||||
|
||||
} // namespace comm
|
||||
} // namespace core
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
Loading…
Reference in New Issue