added worker node

This commit is contained in:
chendongsheng 2020-12-08 21:30:32 +08:00
parent b82c4cba32
commit ee4132889e
14 changed files with 607 additions and 103 deletions

View File

@ -15,6 +15,7 @@ if (NOT (ENABLE_CPU AND (ENABLE_D OR ENABLE_GPU)))
list(REMOVE_ITEM _PS_SRC_FILES "core/node.cc")
list(REMOVE_ITEM _PS_SRC_FILES "core/node_manager.cc")
list(REMOVE_ITEM _PS_SRC_FILES "ps_cache/ps_cache_manager.cc")
list(REMOVE_ITEM _PS_SRC_FILES "core/worker_node.cc")
endif ()
if (NOT ENABLE_D)

View File

@ -94,16 +94,16 @@ std::string CommUtil::GenerateUUID() {
ss << dis(gen);
}
ss << "-4";
for (i = 0; i < kGroup2RandomLength - 1; i++) {
ss << dis(gen);
}
ss << "-";
ss << dis2(gen);
for (i = 0; i < kGroup3RandomLength - 1; i++) {
ss << dis(gen);
}
ss << "-";
for (i = 0; i < kGroup4RandomLength; i++) {
ss << dis2(gen);
for (i = 0; i < kGroup4RandomLength - 1; i++) {
ss << dis(gen);
}
ss << "-";
for (i = 0; i < kGroup5RandomLength; i++) {
ss << dis(gen);
}
return ss.str();
@ -121,7 +121,14 @@ std::string CommUtil::NodeRoleToString(const NodeRole &role) {
MS_LOG(EXCEPTION) << "The node role:" << role << " is illegal!";
}
}
// Checks whether `rank_id` is a legal rank for a node of role `node_role`.
//
// A SERVER rank must be smaller than ClusterConfig::server_num() and a WORKER rank must be
// smaller than ClusterConfig::worker_num(). Any other role is accepted without a range check.
//
// Returns true when the rank id is acceptable, false otherwise.
bool CommUtil::ValidateRankId(const enum NodeRole &node_role, const uint32_t &rank_id) {
  // Compare with `<` instead of `> num - 1`: with a configured count of 0 the subtraction
  // wraps around in unsigned arithmetic and every rank id would be reported as valid.
  if (node_role == NodeRole::SERVER) {
    return rank_id < ClusterConfig::server_num();
  }
  if (node_role == NodeRole::WORKER) {
    return rank_id < ClusterConfig::worker_num();
  }
  return true;
}
} // namespace core
} // namespace ps
} // namespace mindspore

View File

@ -48,6 +48,7 @@
#include "proto/comm.pb.h"
#include "proto/ps.pb.h"
#include "ps/core/cluster_config.h"
#include "utils/log_adapter.h"
namespace mindspore {
@ -66,6 +67,7 @@ class CommUtil {
static void GetAvailableInterfaceAndIP(std::string *interface, std::string *ip);
static std::string GenerateUUID();
static std::string NodeRoleToString(const NodeRole &role);
static bool ValidateRankId(const enum NodeRole &node_role, const uint32_t &rank_id);
private:
static std::random_device rd;

View File

@ -47,13 +47,17 @@ void Node::ProcessHeartbeatResp(const CommMessage &message) {
is_ready_ = heartbeat_resp_message.is_cluster_ready();
if (is_ready_.load()) {
wait_start_cond_.notify_all();
MS_LOG(DEBUG) << "The node id:" << node_info_.node_id_ << " is ready!";
}
is_finish_ = heartbeat_resp_message.is_cluster_finish();
if (is_finish_.load()) {
wait_finish_cond_.notify_all();
MS_LOG(DEBUG) << "The node id:" << node_info_.node_id_ << " is finish!";
}
is_timeout_ = heartbeat_resp_message.is_cluster_timeout();
if (is_timeout_ && on_node_event_message_) {
is_ready_ = true;
wait_start_cond_.notify_all();
on_node_event_message_(NodeEvent::NODE_TIMEOUT);
}
}
@ -64,7 +68,9 @@ void Node::FetchServers(const std::shared_ptr<TcpClient> &client) {
CommMessage message;
*message.mutable_pb_meta() = {meta};
SendMessageSync(client, message);
if (!SendMessageSync(client, message)) {
MS_LOG(EXCEPTION) << "Fetch servers address timeout!";
}
}
void Node::ProcessFetchServersResp(const CommMessage &message) {
@ -72,10 +78,10 @@ void Node::ProcessFetchServersResp(const CommMessage &message) {
fetch_servers_resp_message.ParseFromString(message.data());
for (const auto &it : fetch_servers_resp_message.servers_meta()) {
server_rank_ids_[it.rank_id()] = std::make_pair(it.ip(), it.port());
nodes_address_[std::make_pair(NodeRole::SERVER, it.rank_id())] = std::make_pair(it.ip(), it.port());
}
MS_LOG(DEBUG) << "The all server host size is:" << server_rank_ids_.size();
MS_LOG(DEBUG) << "The all server host size is:" << nodes_address_.size();
}
std::string Node::node_id() const { return node_info_.node_id_; }
@ -86,19 +92,128 @@ void Node::set_callback(const OnNodeEventMessage &on_node_event_message) {
on_node_event_message_ = on_node_event_message;
}
void Node::Wait(uint64_t request_id) {
std::unique_lock<std::mutex> lock(message_mutex_);
message_tracker_cond_.wait(lock, [&] {
bool Node::Wait(uint64_t request_id, const uint32_t &timeout) {
std::unique_lock<std::mutex> lock(message_tracker_mutex_);
bool res = message_tracker_cond_.wait_for(lock, std::chrono::seconds(timeout), [&] {
bool ret = message_tracker_[request_id].first == message_tracker_[request_id].second;
if (ret) {
MS_LOG(DEBUG) << "Message tracker remove request id:" << request_id;
message_tracker_.erase(request_id);
}
return ret;
});
message_tracker_.erase(request_id);
return res;
}
void Node::Disconnect(const std::shared_ptr<TcpClient> &client) {
// Sends `message` as a SEND_DATA command to the node with `rank_id` and blocks until the
// message is acknowledged or `timeout` seconds have elapsed.
//
// node_role/rank_id: target node; checked with CommUtil::ValidateRankId and reported through
//                    MS_LOG(EXCEPTION) when illegal.
// message:           payload stored in the data field of the CommMessage.
// timeout:           maximum time to wait, in seconds.
//
// Returns the result of SendMessageSync (false when the wait timed out).
bool Node::Send(const enum NodeRole &node_role, const uint32_t &rank_id, const std::string &message,
                const uint32_t &timeout) {
  if (!CommUtil::ValidateRankId(node_role, rank_id)) {
    MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!";
  }

  MessageMeta message_meta;
  message_meta.set_cmd(NodeCommand::SEND_DATA);

  CommMessage comm_message;
  *comm_message.mutable_pb_meta() = {message_meta};
  comm_message.set_data(message);

  auto client = GetOrCreateTcpClient(rank_id);
  // Forward the caller's timeout; it used to be dropped, so the default timeout of
  // SendMessageSync was always applied regardless of what the caller asked for.
  return SendMessageSync(client, comm_message, timeout);
}
bool Node::Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids, const std::vector<std::string> &data,
const uint32_t &timeout) {
uint64_t request_id = ++next_request_id_;
message_tracker_[request_id] = std::make_pair(data.size(), 0);
if (rank_ids.size() != data.size()) {
MS_LOG(EXCEPTION) << "The number of rank ids is not equal to the number of data!";
}
for (size_t it = 0; it < rank_ids.size(); ++it) {
if (!CommUtil::ValidateRankId(node_role, rank_ids.at(it))) {
MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!";
}
MessageMeta message_meta;
message_meta.set_cmd(NodeCommand::SEND_DATA);
message_meta.set_request_id(request_id);
CommMessage comm_message;
*comm_message.mutable_pb_meta() = {message_meta};
comm_message.set_data(data.at(it));
auto client = GetOrCreateTcpClient(rank_ids.at(it));
client->SendMessage(comm_message);
}
return Wait(request_id, timeout);
}
bool Node::Send(const enum NodeRole &node_role, const uint32_t &rank_id, const std::string &message,
CommMessage *comm_message_resp, const uint32_t &timeout) {
if (!CommUtil::ValidateRankId(node_role, rank_id)) {
MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!";
}
uint64_t request_id = ++next_request_id_;
message_tracker_[request_id] = std::make_pair(1, 0);
set_message_callback(request_id, [&]() {
receive_messages_mutex_.lock();
auto res = receive_messages_[request_id];
comm_message_resp = &res[rank_id];
receive_messages_.erase(request_id);
receive_messages_mutex_.unlock();
});
MessageMeta message_meta;
message_meta.set_cmd(NodeCommand::SEND_DATA);
message_meta.set_request_id(request_id);
CommMessage comm_message;
*comm_message.mutable_pb_meta() = {message_meta};
comm_message.set_data(message);
auto client = GetOrCreateTcpClient(rank_id);
client->SendMessage(comm_message);
return Wait(request_id, timeout);
}
// Sends data.at(i) as a SEND_DATA command to rank_ids.at(i) for every i under one request id,
// waits for all responses and copies the response of rank_ids.at(i) into the CommMessage that
// comm_message_resp->at(i) points to.
//
// node_role/rank_ids: target nodes; each rank is checked with CommUtil::ValidateRankId.
// data:               payloads, one per rank id.
// comm_message_resp:  output; must not be null and must have as many slots as rank_ids.
//                     Each non-null slot must point to a caller-owned CommMessage that stays
//                     alive while the callback may still run; null slots are skipped.
//                     NOTE(review): the previous code overwrote the slots with pointers into a
//                     local copy that was destroyed when the callback returned (dangling).
//                     Copying into caller-owned objects is the closest safe contract that keeps
//                     the signature - confirm against callers.
// timeout:            maximum time to wait, in seconds.
//
// Reports mismatched sizes, a null output or an illegal rank through MS_LOG(EXCEPTION).
// Returns the result of Wait (false when the wait timed out).
bool Node::Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids, const std::vector<std::string> &data,
                std::vector<CommMessage *> *comm_message_resp, const uint32_t &timeout) {
  // Validate everything before registering the request so that a rejected call leaves no
  // tracker entry or callback behind.
  if (comm_message_resp == nullptr) {
    MS_LOG(EXCEPTION) << "The comm_message_resp should not be null!";
  }
  if (rank_ids.size() != data.size() || rank_ids.size() != comm_message_resp->size()) {
    MS_LOG(EXCEPTION) << "The number of rank ids, data, comm_message_resp should be equal!";
  }
  for (const auto &rank_id : rank_ids) {
    if (!CommUtil::ValidateRankId(node_role, rank_id)) {
      MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!";
    }
  }

  const uint64_t request_id = ++next_request_id_;
  {
    // The tracker is also touched by NotifyMessageArrival and Wait under this mutex.
    std::lock_guard<std::mutex> lock(message_tracker_mutex_);
    message_tracker_[request_id] = std::make_pair(data.size(), 0);
  }

  // Capture by value (rank_ids is copied): the callback runs on the receiving thread and must
  // not refer to this function's stack or to the caller's rank id vector.
  // The callback locks receive_messages_mutex_ itself, so whoever invokes it must not hold it.
  set_message_callback(request_id, [this, request_id, rank_ids, comm_message_resp]() {
    std::lock_guard<std::mutex> lock(receive_messages_mutex_);
    auto received = receive_messages_.find(request_id);
    if (received == receive_messages_.end()) {
      return;
    }
    for (size_t it = 0; it < rank_ids.size(); ++it) {
      CommMessage *slot = comm_message_resp->at(it);
      if (slot != nullptr) {
        *slot = received->second[rank_ids.at(it)];
      }
    }
    receive_messages_.erase(received);
  });

  for (size_t it = 0; it < rank_ids.size(); ++it) {
    MessageMeta message_meta;
    message_meta.set_cmd(NodeCommand::SEND_DATA);
    message_meta.set_request_id(request_id);

    CommMessage comm_message;
    *comm_message.mutable_pb_meta() = {message_meta};
    comm_message.set_data(data.at(it));

    auto client = GetOrCreateTcpClient(rank_ids.at(it));
    client->SendMessage(comm_message);
  }
  return Wait(request_id, timeout);
}
bool Node::Disconnect(const std::shared_ptr<TcpClient> &client, const uint32_t &timeout) {
MessageMeta meta;
meta.set_cmd(NodeCommand::FINISH);
@ -108,36 +223,43 @@ void Node::Disconnect(const std::shared_ptr<TcpClient> &client) {
CommMessage message;
*message.mutable_pb_meta() = {meta};
message.set_data(finish_message.SerializeAsString());
SendMessageSync(client, message);
WaitForDisconnect();
if (!SendMessageSync(client, message)) {
MS_LOG(EXCEPTION) << "Disconnect timeout!";
}
MS_LOG(INFO) << "The node id:" << node_info_.node_id_ << " send finish message!";
return WaitForDisconnect(timeout);
}
void Node::WaitForStart() {
bool Node::WaitForStart(const uint32_t &timeout) {
std::unique_lock<std::mutex> lock(wait_start_mutex_);
wait_start_cond_.wait(lock, [&] {
if (is_ready_.load()) {
MS_LOG(DEBUG) << "The node id:" << node_info_.node_id_ << " is success start!";
bool res = wait_start_cond_.wait_for(lock, std::chrono::seconds(timeout), [&] {
bool res = is_ready_.load();
if (res) {
MS_LOG(INFO) << "The node id:" << node_info_.node_id_ << " is success start!";
}
return is_ready_.load();
return res;
});
return res;
}
void Node::WaitForDisconnect() {
bool Node::WaitForDisconnect(const uint32_t &timeout) {
std::unique_lock<std::mutex> lock(wait_finish_mutex_);
wait_finish_cond_.wait(lock, [&] {
bool res = wait_finish_cond_.wait_for(lock, std::chrono::seconds(timeout), [&] {
if (is_finish_.load()) {
MS_LOG(DEBUG) << "The node id:" << node_info_.node_id_ << " is success finish!";
MS_LOG(INFO) << "The node id:" << node_info_.node_id_ << " is success finish!";
}
return is_finish_.load();
});
return res;
}
void Node::SendMessageSync(const std::shared_ptr<TcpClient> &client, const CommMessage &message) {
bool Node::SendMessageSync(const std::shared_ptr<TcpClient> &client, const CommMessage &message,
const uint32_t &timeout) {
uint64_t request_id = ++next_request_id_;
message_tracker_[request_id] = std::make_pair(1, 0);
const_cast<CommMessage &>(message).mutable_pb_meta()->set_request_id(request_id);
client->SendMessage(message);
Wait(request_id);
return Wait(request_id, timeout);
}
void Node::SendMessageAsync(const std::shared_ptr<TcpClient> &client, const CommMessage &message) {
@ -147,12 +269,83 @@ void Node::SendMessageAsync(const std::shared_ptr<TcpClient> &client, const Comm
}
void Node::NotifyMessageArrival(const CommMessage &message) {
std::lock_guard<std::mutex> lock(message_tracker_mutex_);
const MessageMeta &message_meta = message.pb_meta();
uint64_t request_id = message_meta.request_id();
message_tracker_[request_id].second++;
message_tracker_cond_.notify_all();
}
const std::shared_ptr<TcpClient> &Node::GetOrCreateTcpClient(const int &rank_id) {
std::lock_guard<std::mutex> lock(client_mutex_);
if (connected_nodes_.find(rank_id) != connected_nodes_.end()) {
return connected_nodes_[rank_id];
} else {
if (nodes_address_.find(std::make_pair(NodeRole::SERVER, rank_id)) == nodes_address_.end()) {
MS_LOG(EXCEPTION) << "Worker node Fetch servers failed!";
}
std::string ip = nodes_address_[std::make_pair(NodeRole::SERVER, rank_id)].first;
uint16_t port = nodes_address_[std::make_pair(NodeRole::SERVER, rank_id)].second;
auto client = std::make_shared<TcpClient>(ip, port);
client->SetMessageCallback([&](const TcpClient &client, const CommMessage &message) {
switch (message.pb_meta().cmd()) {
case NodeCommand::SEND_DATA:
ProcessSendDataResp(message);
break;
default:
MS_LOG(EXCEPTION) << "The cmd:" << message.pb_meta().cmd() << " is not supported!";
}
NotifyMessageArrival(message);
});
client->Init();
connected_nodes_[rank_id] = client;
return connected_nodes_[rank_id];
}
}
void Node::ProcessSendDataResp(const CommMessage &message) {
std::lock_guard<std::mutex> lock(receive_messages_mutex_);
const MessageMeta &message_meta = message.pb_meta();
const uint32_t &rank_id = message_meta.rank_id();
const uint64_t request_id = message_meta.request_id();
auto it = receive_messages_.find(request_id);
if (it != receive_messages_.end()) {
it->second.insert(std::make_pair(rank_id, message));
} else {
std::unordered_map<uint32_t, CommMessage> res;
res.insert(std::make_pair(rank_id, message));
receive_messages_[request_id] = res;
}
RunMessageCallback(request_id);
}
void Node::RunMessageCallback(const uint64_t &request_id) {
message_callbacks_mutex_.lock();
if (message_tracker_[request_id].first == message_tracker_[request_id].second - 1) {
auto it = message_callbacks_.find(request_id);
if (it != message_callbacks_.end()) {
message_callbacks_mutex_.unlock();
if (it->second) {
it->second();
}
message_callbacks_mutex_.lock();
message_callbacks_.erase(it);
}
}
message_callbacks_mutex_.unlock();
}
// Registers `message_callback` for `request_id`, replacing any callback already stored for
// that id. An empty callback is ignored.
void Node::set_message_callback(const uint64_t &request_id, const MessageCallback &message_callback) {
  const bool has_target = static_cast<bool>(message_callback);
  if (has_target) {
    std::lock_guard<std::mutex> lock(message_callbacks_mutex_);
    message_callbacks_[request_id] = message_callback;
  }
}
} // namespace core
} // namespace ps
} // namespace mindspore

View File

@ -21,15 +21,15 @@
#include <cstdlib>
#include <functional>
#include <iostream>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <thread>
#include <unordered_map>
#include <vector>
#include <condition_variable>
#include <utility>
#include <tuple>
#include <map>
#include "proto/comm.pb.h"
#include "proto/ps.pb.h"
@ -42,6 +42,8 @@
namespace mindspore {
namespace ps {
namespace core {
constexpr int kTimeoutInSeconds = 30;
constexpr int kCommTimeoutInSeconds = 3;
class Node {
public:
Node()
@ -49,51 +51,83 @@ class Node {
is_finish_(false),
is_timeout_(false),
is_already_stopped_(true),
is_already_finished_(false),
next_request_id_(0),
heart_beat_thread_(nullptr) {}
virtual ~Node() = default;
using OnNodeEventMessage = std::function<void(const NodeEvent &event)>;
void set_callback(const OnNodeEventMessage &on_node_event_message);
using MessageCallback = std::function<void()>;
virtual bool Start(const uint32_t &timeout = kTimeoutInSeconds) = 0;
virtual bool Stop() = 0;
virtual bool Finish(const uint32_t &timeout = kTimeoutInSeconds) = 0;
void set_callback(const OnNodeEventMessage &on_node_event_message);
std::string node_id() const;
uint32_t rank_id() const;
bool Wait(uint64_t request_id, const uint32_t &timeout = kCommTimeoutInSeconds);
void Wait(uint64_t request_id);
virtual bool Send(const enum NodeRole &node_role, const uint32_t &rank_id, const std::string &message,
const uint32_t &timeout = kCommTimeoutInSeconds);
virtual bool Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids,
const std::vector<std::string> &data, const uint32_t &timeout = kCommTimeoutInSeconds);
virtual bool Send(const enum NodeRole &node_role, const uint32_t &rank_id, const std::string &message,
CommMessage *const comm_message_resp, const uint32_t &timeout = kCommTimeoutInSeconds);
virtual bool Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids,
const std::vector<std::string> &data, std::vector<CommMessage *> *comm_message_resp,
const uint32_t &timeout = kCommTimeoutInSeconds);
protected:
void Heartbeat(const std::shared_ptr<TcpClient> &client);
void ProcessHeartbeatResp(const CommMessage &message);
void FetchServers(const std::shared_ptr<TcpClient> &client);
void ProcessFetchServersResp(const CommMessage &message);
void Disconnect(const std::shared_ptr<TcpClient> &client);
void WaitForStart();
void WaitForDisconnect();
void SendMessageSync(const std::shared_ptr<TcpClient> &client, const CommMessage &message);
bool Disconnect(const std::shared_ptr<TcpClient> &client, const uint32_t &timeout);
bool WaitForStart(const uint32_t &timeout);
bool WaitForDisconnect(const uint32_t &timeout);
bool SendMessageSync(const std::shared_ptr<TcpClient> &client, const CommMessage &message,
const uint32_t &timeout = kCommTimeoutInSeconds);
void SendMessageAsync(const std::shared_ptr<TcpClient> &client, const CommMessage &message);
void NotifyMessageArrival(const CommMessage &message);
const std::shared_ptr<TcpClient> &GetOrCreateTcpClient(const int &rank_id);
void ProcessSendDataResp(const CommMessage &message);
void RunMessageCallback(const uint64_t &request_id);
void set_message_callback(const uint64_t &request_id, const MessageCallback &message_callback);
NodeInfo node_info_;
std::atomic<bool> is_ready_;
std::atomic<bool> is_finish_;
std::atomic<bool> is_timeout_;
std::atomic<bool> is_already_stopped_;
std::atomic<bool> is_already_finished_;
std::atomic_uint64_t next_request_id_;
std::unique_ptr<std::thread> heart_beat_thread_;
OnNodeEventMessage on_node_event_message_;
// rank_id-><ip, port>
std::unordered_map<int, std::pair<std::string, uint16_t>> server_rank_ids_;
// <NodeRole,rank_id>-><ip, port>
std::map<std::pair<NodeRole, uint32_t>, std::pair<std::string, uint16_t>> nodes_address_;
// rank_id->tcpclient
std::unordered_map<int, std::shared_ptr<TcpClient>> connected_nodes_;
// timestamp-><expected responses, actual responses>
// request_id-><expected responses, actual responses>
std::unordered_map<uint64_t, std::pair<uint32_t, uint32_t>> message_tracker_;
std::mutex message_mutex_;
std::mutex message_tracker_mutex_;
std::condition_variable message_tracker_cond_;
std::mutex wait_finish_mutex_;
std::condition_variable wait_finish_cond_;
std::mutex wait_start_mutex_;
std::condition_variable wait_start_cond_;
std::mutex finish_mutex_;
std::mutex client_mutex_;
// request_id -> <rank_id, CommMessage>
std::unordered_map<uint64_t, std::unordered_map<uint32_t, CommMessage>> receive_messages_;
std::mutex receive_messages_mutex_;
// request_id -> MessageCallback
std::unordered_map<uint64_t, MessageCallback> message_callbacks_;
std::mutex message_callbacks_mutex_;
};
} // namespace core
} // namespace ps

View File

@ -39,6 +39,10 @@ message MessageMeta {
NodeCommand cmd = 1;
// the request id of this message
uint64 request_id = 2;
// the role of the current node: worker,server,scheduler
NodeRole role = 3;
// the current Node rank id,the worker node range is:[0,numOfWorker-1], the server node range is:[0, numOfServer-1]
int32 rank_id = 4;
}
message RegisterMessage {

View File

@ -249,7 +249,7 @@ void TcpClient::SetMessageCallback(const OnMessage &cb) { message_callback_ = cb
void TcpClient::SendMessage(const CommMessage &message) const {
MS_EXCEPTION_IF_NULL(buffer_event_);
uint32_t buf_size = message.ByteSizeLong();
size_t buf_size = message.ByteSizeLong();
std::vector<unsigned char> serialized(buf_size);
message.SerializeToArray(serialized.data(), static_cast<int>(buf_size));
if (evbuffer_add(bufferevent_get_output(buffer_event_), &buf_size, sizeof(buf_size)) == -1) {

View File

@ -23,7 +23,6 @@
namespace mindspore {
namespace ps {
namespace core {
void TcpMessageHandler::SetCallback(const messageReceive &message_receive) { message_callback_ = message_receive; }
void TcpMessageHandler::ReceiveMessage(const void *buffer, size_t num) {
@ -32,11 +31,11 @@ void TcpMessageHandler::ReceiveMessage(const void *buffer, size_t num) {
while (num > 0) {
if (remaining_length_ == 0) {
for (int i = 0; i < 4 && num > 0; ++i) {
for (int i = 0; i < kHeaderLen && num > 0; ++i) {
header_[++header_index_] = *(buffer_data + i);
--num;
if (header_index_ == 3) {
message_length_ = *reinterpret_cast<const uint32_t *>(header_);
if (header_index_ == kHeaderLen - 1) {
message_length_ = *reinterpret_cast<const size_t *>(header_);
remaining_length_ = message_length_;
message_buffer_.reset(new unsigned char[remaining_length_]);
buffer_data += (i + 1);
@ -46,7 +45,7 @@ void TcpMessageHandler::ReceiveMessage(const void *buffer, size_t num) {
}
if (remaining_length_ > 0 && num > 0) {
uint32_t copy_len = remaining_length_ <= num ? remaining_length_ : num;
size_t copy_len = remaining_length_ <= num ? remaining_length_ : num;
remaining_length_ -= copy_len;
num -= copy_len;
@ -71,7 +70,6 @@ void TcpMessageHandler::ReceiveMessage(const void *buffer, size_t num) {
}
}
}
} // namespace core
} // namespace ps
} // namespace mindspore

View File

@ -31,6 +31,7 @@ namespace mindspore {
namespace ps {
namespace core {
using messageReceive = std::function<void(const CommMessage &message)>;
constexpr int kHeaderLen = 8;
class TcpMessageHandler {
public:
@ -51,10 +52,10 @@ class TcpMessageHandler {
bool is_parsed_;
std::unique_ptr<unsigned char> message_buffer_;
size_t message_length_;
uint32_t remaining_length_;
char header_[4];
size_t remaining_length_;
char header_[8];
int header_index_;
uint32_t last_copy_len_;
size_t last_copy_len_;
};
} // namespace core
} // namespace ps

View File

@ -55,7 +55,7 @@ const evutil_socket_t &TcpConnection::GetFd() const { return fd_; }
void TcpConnection::SendMessage(const CommMessage &message) const {
MS_EXCEPTION_IF_NULL(buffer_event_);
uint32_t buf_size = message.ByteSizeLong();
size_t buf_size = message.ByteSizeLong();
std::vector<unsigned char> serialized(buf_size);
message.SerializeToArray(serialized.data(), static_cast<int>(buf_size));
if (evbuffer_add(bufferevent_get_output(const_cast<struct bufferevent *>(buffer_event_)), &buf_size,

View File

@ -0,0 +1,187 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ps/core/worker_node.h"
namespace mindspore {
namespace ps {
namespace core {
// Shuts the worker node down if it is still running.
//
// The teardown sequence (and its "Stop worker node!" log line) lives in Stop(); the destructor
// used to carry a verbatim copy of it. The call is explicitly qualified because virtual
// dispatch is not wanted from a destructor. Stop() always returns true, so the result is
// intentionally ignored.
WorkerNode::~WorkerNode() { (void)WorkerNode::Stop(); }
// Brings the worker node up: initialises it, registers with the scheduler, starts the
// heartbeat and waits up to `timeout` seconds for the start condition. Unless the timeout flag
// is set, the server addresses are fetched afterwards.
//
// Returns false when waiting for the start timed out, true otherwise.
bool WorkerNode::Start(const uint32_t &timeout) {
  MS_LOG(INFO) << "Starting worker node!";
  Initialize();
  Register();
  Heartbeat(client_to_scheduler_);

  const bool started_in_time = WaitForStart(timeout);
  if (started_in_time) {
    MS_LOG(INFO) << "The node is ready to fetch servers!";

    const bool timed_out = is_timeout_.load();
    if (!timed_out) {
      FetchServers(client_to_scheduler_);
      MS_LOG(INFO) << "Fetch servers successful!";
    }
    MS_LOG(INFO) << "The Worker node has successfully started.";
    return true;
  }

  MS_LOG(ERROR) << "Start Worker node timeout!";
  return false;
}
// Sends a REGISTER command carrying this node's id and role to the scheduler and waits for the
// acknowledgement. A timeout is reported through MS_LOG(EXCEPTION).
void WorkerNode::Register() {
  MessageMeta message_meta;
  message_meta.set_cmd(NodeCommand::REGISTER);

  RegisterMessage register_message;
  register_message.set_node_id(node_info_.node_id_);
  register_message.set_role(node_info_.node_role_);

  CommMessage comm_message;
  *comm_message.mutable_pb_meta() = {message_meta};
  comm_message.set_data(register_message.SerializeAsString());
  if (!SendMessageSync(client_to_scheduler_, comm_message)) {
    MS_LOG(EXCEPTION) << "Worker node register timeout!";
  }

  // SendMessageSync stamps the request id on comm_message's own copy of the meta, so it has to
  // be read from there; the local message_meta never receives it and would always log 0.
  MS_LOG(INFO) << "The worker node id:" << node_info_.node_id_
               << " is registering to scheduler, the request id is:" << comm_message.pb_meta().request_id();
}
void WorkerNode::ProcessRegisterResp(const CommMessage &message) {
RegisterRespMessage register_resp_message;
register_resp_message.ParseFromString(message.data());
if (register_resp_message.node_id() != node_info_.node_id_) {
MS_LOG(EXCEPTION) << "The node id received:" << register_resp_message.node_id()
<< " is not match the current node id:" << node_info_.node_id_;
}
node_info_.rank_id_ = register_resp_message.rank_id();
MS_LOG(INFO) << "The client node id is:" << node_info_.node_id_ << ", and the rank id is:" << node_info_.rank_id_;
}
// Marks the node as running, assigns it a fresh UUID and the WORKER role, then sets up the
// client connection to the scheduler.
void WorkerNode::Initialize() {
  is_already_stopped_ = false;

  node_info_.node_id_ = CommUtil::GenerateUUID();
  node_info_.node_role_ = NodeRole::WORKER;
  const std::string role_name = CommUtil::NodeRoleToString(node_info_.node_role_);
  MS_LOG(INFO) << "The node role is:" << role_name << ", the node id is:" << node_info_.node_id_;

  InitClientToScheduler();
}
void WorkerNode::InitClientToScheduler() {
std::string scheduler_host = ClusterConfig::scheduler_host();
uint16_t scheduler_port = ClusterConfig::scheduler_port();
client_to_scheduler_ = std::make_shared<TcpClient>(scheduler_host, scheduler_port);
client_to_scheduler_->SetMessageCallback([&](const TcpClient &client, const CommMessage &message) {
switch (message.pb_meta().cmd()) {
case NodeCommand::HEARTBEAT:
ProcessHeartbeatResp(message);
break;
case NodeCommand::REGISTER:
ProcessRegisterResp(message);
break;
case NodeCommand::FETCH_SERVER:
ProcessFetchServersResp(message);
break;
case NodeCommand::FINISH:
MS_LOG(INFO) << "The Node id:" << node_info_.node_id_ << " receive a finish message response!";
break;
default:
MS_LOG(EXCEPTION) << "The cmd:" << message.pb_meta().cmd() << " is not supported!";
}
NotifyMessageArrival(message);
});
client_to_scheduler_->Init();
worker_thread_ = std::make_unique<std::thread>([&]() {
MS_LOG(INFO) << "The worker node start a tcp client!";
client_to_scheduler_->Start();
});
worker_thread_->detach();
}
bool WorkerNode::Stop() {
MS_LOG(INFO) << "Stop worker node!";
if (!is_already_stopped_.load()) {
is_ready_ = true;
is_timeout_ = true;
client_to_scheduler_->Stop();
if (!connected_nodes_.empty()) {
for (auto &connected_node : connected_nodes_) {
connected_node.second->Stop();
}
}
client_to_scheduler_->StopEventBase();
if (worker_thread_->joinable()) {
worker_thread_->join();
}
if (heart_beat_thread_->joinable()) {
heart_beat_thread_->join();
}
is_already_stopped_ = true;
}
return true;
}
// Tells the scheduler that this worker has finished. Only the first call disconnects; later
// calls just log and return true. Calls are serialised by finish_mutex_.
//
// Returns the result of Disconnect on the first call, true afterwards.
bool WorkerNode::Finish(const uint32_t &timeout) {
  std::lock_guard<std::mutex> lock(finish_mutex_);
  if (!is_already_finished_) {
    MS_LOG(INFO) << "Finish worker node!";
    // Marked as finished before disconnecting, so the call is not repeated even when the
    // disconnect does not succeed.
    is_already_finished_ = true;
    return Disconnect(client_to_scheduler_, timeout);
  }
  MS_LOG(INFO) << "Worker node already finish!";
  return true;
}
// Sends `message` as a SEND_DATA command to every node listed in nodes_address_ under a single
// request id and waits (with the default Wait timeout) until all of them have answered.
//
// Returns the result of Wait (false when the wait timed out).
bool WorkerNode::BroadcastToServers(const std::string &message) {
  uint64_t request_id = ++next_request_id_;
  // One response is expected per known address.
  message_tracker_[request_id] = std::make_pair(nodes_address_.size(), 0);

  for (const auto &address_entry : nodes_address_) {
    MessageMeta message_meta;
    message_meta.set_cmd(NodeCommand::SEND_DATA);
    message_meta.set_request_id(request_id);

    CommMessage comm_message;
    *comm_message.mutable_pb_meta() = {message_meta};
    comm_message.set_data(message);

    // The key is <NodeRole, rank_id>; the rank id selects the client.
    auto client = GetOrCreateTcpClient(address_entry.first.second);
    client->SendMessage(comm_message);
  }
  return Wait(request_id);
}
} // namespace core
} // namespace ps
} // namespace mindspore

View File

@ -0,0 +1,69 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PS_CORE_CLIENT_NODE_H_
#define MINDSPORE_CCSRC_PS_CORE_CLIENT_NODE_H_
#include <atomic>
#include <cstdlib>
#include <iostream>
#include <memory>
#include <string>
#include <vector>
#include <thread>
#include <unordered_map>
#include <utility>
#include <condition_variable>
#include <algorithm>
#include <tuple>
#include "proto/comm.pb.h"
#include "proto/ps.pb.h"
#include "ps/core/cluster_config.h"
#include "ps/core/tcp_client.h"
#include "ps/core/tcp_server.h"
#include "ps/core/node.h"
#include "utils/log_adapter.h"
namespace mindspore {
namespace ps {
namespace core {
// A cluster node with the WORKER role. It talks to the scheduler through a dedicated TcpClient
// and can broadcast data to the servers whose addresses it has fetched.
class WorkerNode : public Node {
public:
// Creates a node with no scheduler client and no worker thread yet.
WorkerNode() : client_to_scheduler_(nullptr), worker_thread_(nullptr) {}
~WorkerNode() override;
// Initialises, registers with the scheduler and waits up to `timeout` seconds for the start;
// returns false when that wait timed out.
bool Start(const uint32_t &timeout = kTimeoutInSeconds) override;
// Stops the clients and joins the threads; always returns true.
bool Stop() override;
// Disconnects from the scheduler on the first call; later calls return true immediately.
bool Finish(const uint32_t &timeout = kTimeoutInSeconds) override;
// Sends `message` to every known node address and waits for all responses.
bool BroadcastToServers(const std::string &message);
private:
// Sends the REGISTER command to the scheduler and waits for the acknowledgement.
void Register();
// Validates the register response and stores the assigned rank id.
void ProcessRegisterResp(const CommMessage &message);
// Assigns node id and role, then sets up the scheduler connection.
void Initialize();
// Creates, wires up and starts the TcpClient towards the scheduler.
void InitClientToScheduler();
// Connection to the scheduler.
std::shared_ptr<TcpClient> client_to_scheduler_;
// Thread running the scheduler client.
std::unique_ptr<std::thread> worker_thread_;
};
} // namespace core
} // namespace ps
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_CORE_CLIENT_NODE_H_

View File

@ -39,6 +39,14 @@ TEST_F(TestCommUtil, GetAvailableInterfaceAndIP) {
EXPECT_TRUE(!interface.empty());
EXPECT_TRUE(!ip.empty());
}
// Rank ids are valid up to (count - 1) for each role and rejected from `count` on.
TEST_F(TestCommUtil, ValidateRankId) {
  // NOTE(review): per the expectations below the first two arguments act as worker count (3)
  // and server count (2) - confirm against ClusterConfig::Init.
  ClusterConfig::Init(3, 2, std::make_unique<std::string>("127.0.0.1"), 9999);

  // Worker ranks: 2 is the last valid one, 3 is out of range.
  const bool last_worker_ok = CommUtil::ValidateRankId(NodeRole::WORKER, 2);
  EXPECT_TRUE(last_worker_ok);
  const bool worker_overflow_ok = CommUtil::ValidateRankId(NodeRole::WORKER, 3);
  EXPECT_FALSE(worker_overflow_ok);

  // Server ranks: 1 is the last valid one, 2 is out of range.
  const bool last_server_ok = CommUtil::ValidateRankId(NodeRole::SERVER, 1);
  EXPECT_TRUE(last_server_ok);
  const bool server_overflow_ok = CommUtil::ValidateRankId(NodeRole::SERVER, 2);
  EXPECT_FALSE(server_overflow_ok);
}
} // namespace comm
} // namespace ps
} // namespace mindspore

View File

@ -33,117 +33,118 @@ class TestTcpMessageHandler : public UT::Common {
void TearDown() override {}
};
TEST_F(TestTcpMessageHandler, 4_Header_1003_Data) {
TEST_F(TestTcpMessageHandler, 8_Header_1003_Data) {
TcpMessageHandler handler;
handler.SetCallback([this](const CommMessage &message) { EXPECT_EQ(message.data().size(), 1000); });
std::string data(1000, 'a');
CommMessage message;
message.set_data(data);
uint32_t buf_size = message.ByteSizeLong();
char result[1007];
int ret = memcpy_s(result, 4, &buf_size, 4);
size_t buf_size = message.ByteSizeLong();
char result[1011];
int ret = memcpy_s(result, kHeaderLen, &buf_size, kHeaderLen);
if (ret != 0) {
MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
}
std::vector<char> serialized(buf_size);
message.SerializeToArray(serialized.data(), static_cast<int>(buf_size));
memcpy_s(result + 4, buf_size, serialized.data(), buf_size);
handler.ReceiveMessage(result, buf_size + 4);
memcpy_s(result + kHeaderLen, buf_size, serialized.data(), buf_size);
handler.ReceiveMessage(result, buf_size + kHeaderLen);
}
TEST_F(TestTcpMessageHandler, 4_Header_1003_Data_4_Header_1003_Data) {
TEST_F(TestTcpMessageHandler, 8_Header_1003_Data_8_Header_1003_Data) {
TcpMessageHandler handler;
handler.SetCallback([this](const CommMessage &message) { EXPECT_EQ(message.data().size(), 1000); });
std::string data(1000, 'a');
CommMessage message;
message.set_data(data);
uint32_t buf_size = message.ByteSizeLong();
char result[2014];
int ret = memcpy_s(result, 4, &buf_size, 4);
size_t buf_size = message.ByteSizeLong();
char result[2022] = {0};
int ret = memcpy_s(result, kHeaderLen, &buf_size, kHeaderLen);
if (ret != 0) {
MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
}
std::vector<char> serialized(buf_size);
message.SerializeToArray(serialized.data(), static_cast<int>(buf_size));
ret = memcpy_s(result + 4, buf_size, serialized.data(), buf_size);
ret = memcpy_s(result + kHeaderLen, buf_size, serialized.data(), buf_size);
if (ret != 0) {
MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
}
ret = memcpy_s(result + 4 + buf_size, 4, &buf_size, 4);
ret = memcpy_s(result + kHeaderLen + buf_size, kHeaderLen, &buf_size, kHeaderLen);
if (ret != 0) {
MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
}
ret = memcpy_s(result + 4 + buf_size + 4, buf_size, serialized.data(), buf_size);
ret = memcpy_s(result + kHeaderLen + buf_size + kHeaderLen, buf_size, serialized.data(), buf_size);
if (ret != 0) {
MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
}
handler.ReceiveMessage(result, 2 * buf_size + 4 * 2);
handler.ReceiveMessage(result, 2 * buf_size + kHeaderLen * 2);
}
TEST_F(TestTcpMessageHandler, 4_Header_4090_Data_2_Header_2_header_4090_data) {
TEST_F(TestTcpMessageHandler, 8_Header_4084_Data_4_Header_4_header_4084_data) {
TcpMessageHandler handler;
handler.SetCallback([this](const CommMessage &message) { EXPECT_EQ(message.data().size(), 4087); });
handler.SetCallback([this](const CommMessage &message) { EXPECT_EQ(message.data().size(), 4081); });
std::string data(4087, 'a');
std::string data(4081, 'a');
CommMessage message;
message.set_data(data);
uint32_t buf_size = message.ByteSizeLong();
char result[4096];
int ret = memcpy_s(result, 4, &buf_size, 4);
size_t buf_size = message.ByteSizeLong();
char result[4096] = {0};
int ret = memcpy_s(result, kHeaderLen, &buf_size, kHeaderLen);
if (ret != 0) {
MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
}
std::vector<char> serialized(buf_size);
message.SerializeToArray(serialized.data(), static_cast<int>(buf_size));
ret = memcpy_s(result + 4, buf_size, serialized.data(), buf_size);
ret = memcpy_s(result + kHeaderLen, buf_size, serialized.data(), buf_size);
if (ret != 0) {
MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
}
ret = memcpy_s(result + 4 + buf_size, 2, &buf_size, 2);
ret = memcpy_s(result + kHeaderLen + buf_size, 4, &buf_size, 4);
if (ret != 0) {
MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
}
handler.ReceiveMessage(result, 4096);
ret = memcpy_s(result, 2, &buf_size + 2, 2);
auto temp = reinterpret_cast<char *>(&buf_size);
ret = memcpy_s(result, 4, temp + 4, 4);
if (ret != 0) {
MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
}
ret = memcpy_s(result + 2, buf_size, serialized.data(), buf_size);
if (ret != 0) {
MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
}
handler.ReceiveMessage(result, 4092);
}
TEST_F(TestTcpMessageHandler, 4_Header_4088_Data_4_Header_4088_data) {
TcpMessageHandler handler;
handler.SetCallback([this](const CommMessage &message) { EXPECT_EQ(message.data().size(), 4085); });
std::string data(4085, 'a');
CommMessage message;
message.set_data(data);
uint32_t buf_size = message.ByteSizeLong();
char result[4096];
int ret = memcpy_s(result, 4, &buf_size, 4);
if (ret != 0) {
MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
}
std::vector<char> serialized(buf_size);
message.SerializeToArray(serialized.data(), static_cast<int>(buf_size));
ret = memcpy_s(result + 4, buf_size, serialized.data(), buf_size);
if (ret != 0) {
MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
}
ret = memcpy_s(result + 4 + buf_size, 4, &buf_size, 4);
handler.ReceiveMessage(result, 4088);
}
TEST_F(TestTcpMessageHandler, 8_Header_4080_Data_8_Header_4080_data) {
TcpMessageHandler handler;
handler.SetCallback([this](const CommMessage &message) { EXPECT_EQ(message.data().size(), 4077); });
std::string data(4077, 'a');
CommMessage message;
message.set_data(data);
size_t buf_size = message.ByteSizeLong();
char result[4096] = {0};
int ret = memcpy_s(result, kHeaderLen, &buf_size, kHeaderLen);
if (ret != 0) {
MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
}
std::vector<char> serialized(buf_size);
message.SerializeToArray(serialized.data(), static_cast<int>(buf_size));
ret = memcpy_s(result + kHeaderLen, buf_size, serialized.data(), buf_size);
if (ret != 0) {
MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
}
ret = memcpy_s(result + kHeaderLen + buf_size, kHeaderLen, &buf_size, kHeaderLen);
if (ret != 0) {
MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
}
@ -155,9 +156,8 @@ TEST_F(TestTcpMessageHandler, 4_Header_4088_Data_4_Header_4088_data) {
MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
}
handler.ReceiveMessage(result, 4088);
handler.ReceiveMessage(result, 4080);
}
} // namespace comm
} // namespace core
} // namespace ps
} // namespace mindspore