!9971 added scheduler node

From: @anancds
Reviewed-by: 
Signed-off-by:
This commit is contained in:
mindspore-ci-bot 2020-12-16 16:40:50 +08:00 committed by Gitee
commit 41bd077e09
14 changed files with 609 additions and 288 deletions

View File

@ -19,6 +19,7 @@ if (NOT (ENABLE_CPU AND (ENABLE_D OR ENABLE_GPU)))
list(REMOVE_ITEM _PS_SRC_FILES "core/worker_node.cc")
list(REMOVE_ITEM _PS_SRC_FILES "core/server_node.cc")
list(REMOVE_ITEM _PS_SRC_FILES "core/abstract_node.cc")
list(REMOVE_ITEM _PS_SRC_FILES "core/scheduler_node.cc")
endif ()
if (NOT ENABLE_D)

View File

@ -74,30 +74,161 @@ void AbstractNode::set_event_callback(const OnNodeEventMessage &on_node_event_me
on_node_event_message_ = on_node_event_message;
}
void AbstractNode::Heartbeat(const std::shared_ptr<TcpClient> &client) {
bool AbstractNode::Send(const enum NodeRole &node_role, const uint32_t &rank_id, const std::string &message,
const uint32_t &timeout) {
if (!CommUtil::ValidateRankId(node_role, rank_id)) {
MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!";
}
MessageMeta message_meta;
message_meta.set_cmd(NodeCommand::SEND_DATA);
CommMessage comm_message;
*comm_message.mutable_pb_meta() = {message_meta};
comm_message.set_data(message);
auto client = GetOrCreateTcpClient(rank_id);
return SendMessageSync(client, comm_message);
}
bool AbstractNode::Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids,
const std::vector<std::string> &data, const uint32_t &timeout) {
uint64_t request_id = ++next_request_id_;
message_tracker_[request_id] = std::make_pair(data.size(), 0);
if (rank_ids.size() != data.size()) {
MS_LOG(EXCEPTION) << "The number of rank ids is not equal to the number of data!";
}
for (size_t it = 0; it < rank_ids.size(); ++it) {
if (!CommUtil::ValidateRankId(node_role, rank_ids.at(it))) {
MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!";
}
MessageMeta message_meta;
message_meta.set_cmd(NodeCommand::SEND_DATA);
message_meta.set_request_id(request_id);
CommMessage comm_message;
*comm_message.mutable_pb_meta() = {message_meta};
comm_message.set_data(data.at(it));
auto client = GetOrCreateTcpClient(rank_ids.at(it));
client->SendMessage(comm_message);
}
return Wait(request_id, timeout);
}
bool AbstractNode::Send(const enum NodeRole &node_role, const uint32_t &rank_id, const std::string &message,
CommMessage *comm_message_resp, const uint32_t &timeout) {
MS_EXCEPTION_IF_NULL(comm_message_resp);
if (!CommUtil::ValidateRankId(node_role, rank_id)) {
MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!";
}
uint64_t request_id = ++next_request_id_;
message_tracker_[request_id] = std::make_pair(1, 0);
set_message_callback(request_id, [&]() {
receive_messages_mutex_.lock();
auto res = receive_messages_[request_id];
*comm_message_resp = res[rank_id];
receive_messages_.erase(request_id);
receive_messages_mutex_.unlock();
});
MessageMeta message_meta;
message_meta.set_cmd(NodeCommand::SEND_DATA);
message_meta.set_request_id(request_id);
message_meta.set_rank_id(node_info_.rank_id_);
message_meta.set_role(node_info_.node_role_);
CommMessage comm_message;
*comm_message.mutable_pb_meta() = {message_meta};
comm_message.set_data(message);
auto client = GetOrCreateTcpClient(rank_id);
client->SendMessage(comm_message);
return Wait(request_id, timeout);
}
bool AbstractNode::Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids,
const std::vector<std::string> &data, std::vector<CommMessage *> *comm_message_resp,
const uint32_t &timeout) {
MS_EXCEPTION_IF_NULL(comm_message_resp);
uint64_t request_id = ++next_request_id_;
message_tracker_[request_id] = std::make_pair(data.size(), 0);
if (rank_ids.size() != data.size() || rank_ids.size() != (*comm_message_resp).size()) {
MS_LOG(EXCEPTION) << "The number of rank ids, data, comm_message_resp should be equal!";
}
size_t len = rank_ids.size();
set_message_callback(request_id, [&]() {
receive_messages_mutex_.lock();
auto res = receive_messages_[request_id];
for (size_t it = 0; it < len; ++it) {
comm_message_resp->at(it) = &res[rank_ids.at(it)];
}
receive_messages_.erase(request_id);
receive_messages_mutex_.unlock();
});
for (size_t it = 0; it < len; ++it) {
if (!CommUtil::ValidateRankId(node_role, rank_ids.at(it))) {
MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!";
}
MessageMeta message_meta;
message_meta.set_cmd(NodeCommand::SEND_DATA);
message_meta.set_request_id(request_id);
CommMessage comm_message;
*comm_message.mutable_pb_meta() = {message_meta};
comm_message.set_data(data.at(it));
auto client = GetOrCreateTcpClient(rank_ids.at(it));
client->SendMessage(comm_message);
}
return Wait(request_id, timeout);
}
bool AbstractNode::Wait(uint64_t request_id, const uint32_t &timeout) {
std::unique_lock<std::mutex> lock(message_tracker_mutex_);
bool res = message_tracker_cond_.wait_for(lock, std::chrono::seconds(timeout), [&] {
bool ret = message_tracker_[request_id].first == message_tracker_[request_id].second;
return ret;
});
message_tracker_.erase(request_id);
return res;
}
void AbstractNode::StartHeartbeatTimer(const std::shared_ptr<TcpClient> &client) {
MS_LOG(INFO) << "The node role: " << CommUtil::NodeRoleToString(node_info_.node_role_)
<< ", the node id:" << node_info_.node_id_ << ", the node rank id:" << node_info_.rank_id_
<< " begin send heartbeat to the scheduler!";
heart_beat_thread_ = std::make_unique<std::thread>([&]() {
while (!is_finish_.load()) {
Heartbeat(client);
std::this_thread::sleep_for(std::chrono::seconds(ClusterConfig::heartbeat_interval()));
MessageMeta meta;
meta.set_cmd(NodeCommand::HEARTBEAT);
HeartbeatMessage heartbeat_message;
heartbeat_message.set_node_id(node_info_.node_id_);
CommMessage message;
*message.mutable_pb_meta() = {meta};
message.set_data(heartbeat_message.SerializeAsString());
if (!SendMessageSync(client, message)) {
MS_LOG(ERROR) << "The node id:" << node_info_.node_id_ << " Send heartbeat timeout!";
}
}
});
heart_beat_thread_->detach();
}
void AbstractNode::Heartbeat(const std::shared_ptr<TcpClient> &client, bool is_node_finish) {
MessageMeta meta;
meta.set_cmd(NodeCommand::HEARTBEAT);
HeartbeatMessage heartbeat_message;
heartbeat_message.set_node_id(node_info_.node_id_);
heartbeat_message.set_is_node_finish(is_node_finish);
CommMessage message;
*message.mutable_pb_meta() = {meta};
message.set_data(heartbeat_message.SerializeAsString());
if (!SendMessageSync(client, message)) {
MS_LOG(ERROR) << "The node id:" << node_info_.node_id_ << " Send heartbeat timeout!";
}
}
void AbstractNode::ProcessHeartbeatResp(const CommMessage &message) {
HeartbeatRespMessage heartbeat_resp_message;
heartbeat_resp_message.ParseFromString(message.data());
@ -106,8 +237,9 @@ void AbstractNode::ProcessHeartbeatResp(const CommMessage &message) {
wait_start_cond_.notify_all();
MS_LOG(DEBUG) << "The node id:" << node_info_.node_id_ << " is ready!";
}
is_finish_ = heartbeat_resp_message.is_cluster_finish();
if (is_finish_.load()) {
if (heartbeat_resp_message.is_cluster_finish()) {
Heartbeat(client_to_scheduler_, true);
is_finish_ = true;
wait_finish_cond_.notify_all();
MS_LOG(DEBUG) << "The node id:" << node_info_.node_id_ << " is finish!";
}
@ -115,6 +247,10 @@ void AbstractNode::ProcessHeartbeatResp(const CommMessage &message) {
if (is_timeout_ && on_node_event_message_) {
is_ready_ = true;
wait_start_cond_.notify_all();
on_node_event_message_(NodeEvent::CLUSTER_TIMEOUT);
}
if (heartbeat_resp_message.is_node_timeout() && on_node_event_message_) {
on_node_event_message_(NodeEvent::NODE_TIMEOUT);
}
}
@ -207,6 +343,101 @@ bool AbstractNode::InitClientToScheduler() {
});
return client_to_scheduler_->WaitConnected();
}
const std::shared_ptr<TcpClient> &AbstractNode::GetOrCreateTcpClient(const int &rank_id) {
std::lock_guard<std::mutex> lock(client_mutex_);
if (connected_nodes_.find(rank_id) != connected_nodes_.end()) {
return connected_nodes_[rank_id];
} else {
if (nodes_address_.find(std::make_pair(NodeRole::SERVER, rank_id)) == nodes_address_.end()) {
MS_LOG(EXCEPTION) << "Worker node Fetch servers failed!";
}
std::string ip = nodes_address_[std::make_pair(NodeRole::SERVER, rank_id)].first;
uint16_t port = nodes_address_[std::make_pair(NodeRole::SERVER, rank_id)].second;
auto client = std::make_shared<TcpClient>(ip, port);
client->SetMessageCallback([&](const TcpClient &client, const CommMessage &message) {
switch (message.pb_meta().cmd()) {
case NodeCommand::SEND_DATA:
ProcessSendDataResp(message);
RunMessageCallback(message.pb_meta().request_id());
break;
default:
MS_LOG(EXCEPTION) << "The cmd:" << message.pb_meta().cmd() << " is not supported!";
}
NotifyMessageArrival(message);
});
client->Init();
connected_nodes_[rank_id] = client;
return connected_nodes_[rank_id];
}
}
bool AbstractNode::SendMessageSync(const std::shared_ptr<TcpClient> &client, const CommMessage &message,
const uint32_t &timeout) {
uint64_t request_id = ++next_request_id_;
message_tracker_[request_id] = std::make_pair(1, 0);
const_cast<CommMessage &>(message).mutable_pb_meta()->set_request_id(request_id);
client->SendMessage(message);
return Wait(request_id, timeout);
}
void AbstractNode::SendMessageAsync(const std::shared_ptr<TcpClient> &client, const CommMessage &message) {
uint64_t request_id = ++next_request_id_;
const_cast<CommMessage &>(message).mutable_pb_meta()->set_request_id(request_id);
client->SendMessage(message);
}
void AbstractNode::ProcessSendDataResp(const CommMessage &message) {
std::lock_guard<std::mutex> lock(receive_messages_mutex_);
const MessageMeta &message_meta = message.pb_meta();
const uint32_t &rank_id = message_meta.rank_id();
const uint64_t request_id = message_meta.request_id();
auto it = receive_messages_.find(request_id);
if (it != receive_messages_.end()) {
it->second.insert(std::make_pair(rank_id, message));
} else {
std::unordered_map<uint32_t, CommMessage> res;
res.insert(std::make_pair(rank_id, message));
receive_messages_[request_id] = res;
}
}
void AbstractNode::RunMessageCallback(const uint64_t &request_id) {
message_callbacks_mutex_.lock();
// When receiving a message's response, Then compare with the desired number of responses,
// If they are equal, then call the callback function
if (message_tracker_[request_id].first == message_tracker_[request_id].second + 1) {
auto it = message_callbacks_.find(request_id);
if (it != message_callbacks_.end()) {
message_callbacks_mutex_.unlock();
if (it->second) {
it->second();
}
message_callbacks_mutex_.lock();
message_callbacks_.erase(it);
}
}
message_callbacks_mutex_.unlock();
}
void AbstractNode::set_message_callback(const uint64_t &request_id, const MessageCallback &message_callback) {
if (!message_callback) {
return;
}
std::lock_guard<std::mutex> lock(message_callbacks_mutex_);
message_callbacks_[request_id] = message_callback;
}
void AbstractNode::NotifyMessageArrival(const CommMessage &message) {
std::lock_guard<std::mutex> lock(message_tracker_mutex_);
const MessageMeta &message_meta = message.pb_meta();
uint64_t request_id = message_meta.request_id();
message_tracker_[request_id].second++;
message_tracker_cond_.notify_all();
}
} // namespace core
} // namespace ps
} // namespace mindspore

View File

@ -20,6 +20,9 @@
#include <utility>
#include <string>
#include <memory>
#include <map>
#include <vector>
#include <unordered_map>
#include "ps/core/node.h"
@ -34,21 +37,60 @@ class AbstractNode : public Node {
bool BroadcastToServers(const std::string &message, const uint32_t &timeout = kCommTimeoutInSeconds);
void set_event_callback(const OnNodeEventMessage &on_node_event_message);
virtual bool Send(const enum NodeRole &node_role, const uint32_t &rank_id, const std::string &message,
const uint32_t &timeout = kCommTimeoutInSeconds);
virtual bool Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids,
const std::vector<std::string> &data, const uint32_t &timeout = kCommTimeoutInSeconds);
virtual bool Send(const enum NodeRole &node_role, const uint32_t &rank_id, const std::string &message,
CommMessage *comm_message_resp, const uint32_t &timeout = kCommTimeoutInSeconds);
virtual bool Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids,
const std::vector<std::string> &data, std::vector<CommMessage *> *comm_message_resp,
const uint32_t &timeout = kCommTimeoutInSeconds);
bool Wait(uint64_t request_id, const uint32_t &timeout = kCommTimeoutInSeconds);
protected:
void Register(const std::shared_ptr<TcpClient> &client);
void ProcessRegisterResp(const CommMessage &message);
void Heartbeat(const std::shared_ptr<TcpClient> &client);
void StartHeartbeatTimer(const std::shared_ptr<TcpClient> &client);
void Heartbeat(const std::shared_ptr<TcpClient> &client, bool is_node_finish = false);
void ProcessHeartbeatResp(const CommMessage &message);
void FetchServers(const std::shared_ptr<TcpClient> &client);
void ProcessFetchServersResp(const CommMessage &message);
bool Disconnect(const std::shared_ptr<TcpClient> &client, const uint32_t &timeout);
bool WaitForDisconnect(const uint32_t &timeout);
bool InitClientToScheduler();
const std::shared_ptr<TcpClient> &GetOrCreateTcpClient(const int &rank_id);
bool SendMessageSync(const std::shared_ptr<TcpClient> &client, const CommMessage &message,
const uint32_t &timeout = kCommTimeoutInSeconds);
void SendMessageAsync(const std::shared_ptr<TcpClient> &client, const CommMessage &message);
void ProcessSendDataResp(const CommMessage &message);
void RunMessageCallback(const uint64_t &request_id);
void set_message_callback(const uint64_t &request_id, const MessageCallback &message_callback);
void NotifyMessageArrival(const CommMessage &message);
std::unique_ptr<std::thread> heart_beat_thread_;
std::unique_ptr<std::thread> client_to_scheduler_thread_;
std::shared_ptr<TcpClient> client_to_scheduler_;
OnNodeEventMessage on_node_event_message_;
// the map's key is: <node_role,rank_id>, the map's value is: <ip, port>
std::map<std::pair<NodeRole, uint32_t>, std::pair<std::string, uint16_t>> nodes_address_;
std::mutex client_mutex_;
// the map's key is: rank_id
std::unordered_map<int, std::shared_ptr<TcpClient>> connected_nodes_;
// the map's key is: request_id, the map's value is: <expected responses, actual responses>
std::unordered_map<uint64_t, std::pair<uint32_t, uint32_t>> message_tracker_;
std::mutex message_tracker_mutex_;
std::condition_variable message_tracker_cond_;
// the map's key is: request_id, the map's value is:<rank_id, CommMessage>
std::unordered_map<uint64_t, std::unordered_map<uint32_t, CommMessage>> receive_messages_;
std::mutex receive_messages_mutex_;
// the map's key is: request_id
std::unordered_map<uint64_t, MessageCallback> message_callbacks_;
std::mutex message_callbacks_mutex_;
};
} // namespace core
} // namespace ps

View File

@ -25,131 +25,6 @@ uint32_t Node::rank_id() const { return node_info_.rank_id_; }
NodeRole Node::role() const { return node_info_.node_role_; }
bool Node::Wait(uint64_t request_id, const uint32_t &timeout) {
std::unique_lock<std::mutex> lock(message_tracker_mutex_);
bool res = message_tracker_cond_.wait_for(lock, std::chrono::seconds(timeout), [&] {
bool ret = message_tracker_[request_id].first == message_tracker_[request_id].second;
return ret;
});
message_tracker_.erase(request_id);
return res;
}
bool Node::Send(const enum NodeRole &node_role, const uint32_t &rank_id, const std::string &message,
const uint32_t &timeout) {
if (!CommUtil::ValidateRankId(node_role, rank_id)) {
MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!";
}
MessageMeta message_meta;
message_meta.set_cmd(NodeCommand::SEND_DATA);
CommMessage comm_message;
*comm_message.mutable_pb_meta() = {message_meta};
comm_message.set_data(message);
auto client = GetOrCreateTcpClient(rank_id);
return SendMessageSync(client, comm_message);
}
bool Node::Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids, const std::vector<std::string> &data,
const uint32_t &timeout) {
uint64_t request_id = ++next_request_id_;
message_tracker_[request_id] = std::make_pair(data.size(), 0);
if (rank_ids.size() != data.size()) {
MS_LOG(EXCEPTION) << "The number of rank ids is not equal to the number of data!";
}
for (size_t it = 0; it < rank_ids.size(); ++it) {
if (!CommUtil::ValidateRankId(node_role, rank_ids.at(it))) {
MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!";
}
MessageMeta message_meta;
message_meta.set_cmd(NodeCommand::SEND_DATA);
message_meta.set_request_id(request_id);
CommMessage comm_message;
*comm_message.mutable_pb_meta() = {message_meta};
comm_message.set_data(data.at(it));
auto client = GetOrCreateTcpClient(rank_ids.at(it));
client->SendMessage(comm_message);
}
return Wait(request_id, timeout);
}
bool Node::Send(const enum NodeRole &node_role, const uint32_t &rank_id, const std::string &message,
CommMessage *comm_message_resp, const uint32_t &timeout) {
MS_EXCEPTION_IF_NULL(comm_message_resp);
if (!CommUtil::ValidateRankId(node_role, rank_id)) {
MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!";
}
uint64_t request_id = ++next_request_id_;
message_tracker_[request_id] = std::make_pair(1, 0);
set_message_callback(request_id, [&]() {
receive_messages_mutex_.lock();
auto res = receive_messages_[request_id];
*comm_message_resp = res[rank_id];
receive_messages_.erase(request_id);
receive_messages_mutex_.unlock();
});
MessageMeta message_meta;
message_meta.set_cmd(NodeCommand::SEND_DATA);
message_meta.set_request_id(request_id);
message_meta.set_rank_id(node_info_.rank_id_);
message_meta.set_role(node_info_.node_role_);
CommMessage comm_message;
*comm_message.mutable_pb_meta() = {message_meta};
comm_message.set_data(message);
auto client = GetOrCreateTcpClient(rank_id);
client->SendMessage(comm_message);
return Wait(request_id, timeout);
}
bool Node::Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids, const std::vector<std::string> &data,
std::vector<CommMessage *> *comm_message_resp, const uint32_t &timeout) {
MS_EXCEPTION_IF_NULL(comm_message_resp);
uint64_t request_id = ++next_request_id_;
message_tracker_[request_id] = std::make_pair(data.size(), 0);
if (rank_ids.size() != data.size() || rank_ids.size() != (*comm_message_resp).size()) {
MS_LOG(EXCEPTION) << "The number of rank ids, data, comm_message_resp should be equal!";
}
size_t len = rank_ids.size();
set_message_callback(request_id, [&]() {
receive_messages_mutex_.lock();
auto res = receive_messages_[request_id];
for (size_t it = 0; it < len; ++it) {
comm_message_resp->at(it) = &res[rank_ids.at(it)];
}
receive_messages_.erase(request_id);
receive_messages_mutex_.unlock();
});
for (size_t it = 0; it < len; ++it) {
if (!CommUtil::ValidateRankId(node_role, rank_ids.at(it))) {
MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!";
}
MessageMeta message_meta;
message_meta.set_cmd(NodeCommand::SEND_DATA);
message_meta.set_request_id(request_id);
CommMessage comm_message;
*comm_message.mutable_pb_meta() = {message_meta};
comm_message.set_data(data.at(it));
auto client = GetOrCreateTcpClient(rank_ids.at(it));
client->SendMessage(comm_message);
}
return Wait(request_id, timeout);
}
bool Node::WaitForStart(const uint32_t &timeout) {
std::unique_lock<std::mutex> lock(wait_start_mutex_);
bool res = wait_start_cond_.wait_for(lock, std::chrono::seconds(timeout), [&] {
@ -161,101 +36,6 @@ bool Node::WaitForStart(const uint32_t &timeout) {
});
return res;
}
bool Node::SendMessageSync(const std::shared_ptr<TcpClient> &client, const CommMessage &message,
const uint32_t &timeout) {
uint64_t request_id = ++next_request_id_;
message_tracker_[request_id] = std::make_pair(1, 0);
const_cast<CommMessage &>(message).mutable_pb_meta()->set_request_id(request_id);
client->SendMessage(message);
return Wait(request_id, timeout);
}
void Node::SendMessageAsync(const std::shared_ptr<TcpClient> &client, const CommMessage &message) {
uint64_t request_id = ++next_request_id_;
const_cast<CommMessage &>(message).mutable_pb_meta()->set_request_id(request_id);
client->SendMessage(message);
}
const std::shared_ptr<TcpClient> &Node::GetOrCreateTcpClient(const int &rank_id) {
std::lock_guard<std::mutex> lock(client_mutex_);
if (connected_nodes_.find(rank_id) != connected_nodes_.end()) {
return connected_nodes_[rank_id];
} else {
if (nodes_address_.find(std::make_pair(NodeRole::SERVER, rank_id)) == nodes_address_.end()) {
MS_LOG(EXCEPTION) << "Worker node Fetch servers failed!";
}
std::string ip = nodes_address_[std::make_pair(NodeRole::SERVER, rank_id)].first;
uint16_t port = nodes_address_[std::make_pair(NodeRole::SERVER, rank_id)].second;
auto client = std::make_shared<TcpClient>(ip, port);
client->SetMessageCallback([&](const TcpClient &client, const CommMessage &message) {
switch (message.pb_meta().cmd()) {
case NodeCommand::SEND_DATA:
ProcessSendDataResp(message);
RunMessageCallback(message.pb_meta().request_id());
break;
default:
MS_LOG(EXCEPTION) << "The cmd:" << message.pb_meta().cmd() << " is not supported!";
}
NotifyMessageArrival(message);
});
client->Init();
connected_nodes_[rank_id] = client;
return connected_nodes_[rank_id];
}
}
void Node::ProcessSendDataResp(const CommMessage &message) {
std::lock_guard<std::mutex> lock(receive_messages_mutex_);
const MessageMeta &message_meta = message.pb_meta();
const uint32_t &rank_id = message_meta.rank_id();
const uint64_t request_id = message_meta.request_id();
auto it = receive_messages_.find(request_id);
if (it != receive_messages_.end()) {
it->second.insert(std::make_pair(rank_id, message));
} else {
std::unordered_map<uint32_t, CommMessage> res;
res.insert(std::make_pair(rank_id, message));
receive_messages_[request_id] = res;
}
}
void Node::RunMessageCallback(const uint64_t &request_id) {
message_callbacks_mutex_.lock();
// When receiving a message's response, Then compare with the desired number of responses,
// If they are equal, then call the callback function
if (message_tracker_[request_id].first == message_tracker_[request_id].second + 1) {
auto it = message_callbacks_.find(request_id);
if (it != message_callbacks_.end()) {
message_callbacks_mutex_.unlock();
if (it->second) {
it->second();
}
message_callbacks_mutex_.lock();
message_callbacks_.erase(it);
}
}
message_callbacks_mutex_.unlock();
}
void Node::set_message_callback(const uint64_t &request_id, const MessageCallback &message_callback) {
if (!message_callback) {
return;
}
std::lock_guard<std::mutex> lock(message_callbacks_mutex_);
message_callbacks_[request_id] = message_callback;
}
void Node::NotifyMessageArrival(const CommMessage &message) {
std::lock_guard<std::mutex> lock(message_tracker_mutex_);
const MessageMeta &message_meta = message.pb_meta();
uint64_t request_id = message_meta.request_id();
message_tracker_[request_id].second++;
message_tracker_cond_.notify_all();
}
} // namespace core
} // namespace ps
} // namespace mindspore

View File

@ -29,7 +29,6 @@
#include <condition_variable>
#include <utility>
#include <tuple>
#include <map>
#include "proto/comm.pb.h"
#include "proto/ps.pb.h"
@ -66,28 +65,8 @@ class Node {
uint32_t rank_id() const;
NodeRole role() const;
bool Wait(uint64_t request_id, const uint32_t &timeout = kCommTimeoutInSeconds);
virtual bool Send(const enum NodeRole &node_role, const uint32_t &rank_id, const std::string &message,
const uint32_t &timeout = kCommTimeoutInSeconds);
virtual bool Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids,
const std::vector<std::string> &data, const uint32_t &timeout = kCommTimeoutInSeconds);
virtual bool Send(const enum NodeRole &node_role, const uint32_t &rank_id, const std::string &message,
CommMessage *comm_message_resp, const uint32_t &timeout = kCommTimeoutInSeconds);
virtual bool Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids,
const std::vector<std::string> &data, std::vector<CommMessage *> *comm_message_resp,
const uint32_t &timeout = kCommTimeoutInSeconds);
protected:
bool WaitForStart(const uint32_t &timeout);
bool SendMessageSync(const std::shared_ptr<TcpClient> &client, const CommMessage &message,
const uint32_t &timeout = kCommTimeoutInSeconds);
void SendMessageAsync(const std::shared_ptr<TcpClient> &client, const CommMessage &message);
const std::shared_ptr<TcpClient> &GetOrCreateTcpClient(const int &rank_id);
void ProcessSendDataResp(const CommMessage &message);
void RunMessageCallback(const uint64_t &request_id);
void set_message_callback(const uint64_t &request_id, const MessageCallback &message_callback);
void NotifyMessageArrival(const CommMessage &message);
NodeInfo node_info_;
std::atomic<bool> is_ready_;
@ -97,28 +76,11 @@ class Node {
std::atomic<bool> is_already_finished_;
std::atomic_uint64_t next_request_id_;
// <NodeRole,rank_id>-><ip, port>
std::map<std::pair<NodeRole, uint32_t>, std::pair<std::string, uint16_t>> nodes_address_;
// rank_id->tcpclient
std::unordered_map<int, std::shared_ptr<TcpClient>> connected_nodes_;
// request_id-><expected responses, actual responses>
std::unordered_map<uint64_t, std::pair<uint32_t, uint32_t>> message_tracker_;
std::mutex message_tracker_mutex_;
std::condition_variable message_tracker_cond_;
std::mutex wait_finish_mutex_;
std::condition_variable wait_finish_cond_;
std::mutex wait_start_mutex_;
std::condition_variable wait_start_cond_;
std::mutex wait_finish_mutex_;
std::condition_variable wait_finish_cond_;
std::mutex finish_mutex_;
std::mutex client_mutex_;
// request_id -> <rank_id, CommMessage>
std::unordered_map<uint64_t, std::unordered_map<uint32_t, CommMessage>> receive_messages_;
std::mutex receive_messages_mutex_;
// request_id -> MessageCallback
std::unordered_map<uint64_t, MessageCallback> message_callbacks_;
std::mutex message_callbacks_mutex_;
};
} // namespace core
} // namespace ps

View File

@ -26,7 +26,7 @@ namespace mindspore {
namespace ps {
namespace core {
enum NodeEvent { NODE_TIMEOUT = 0 };
enum NodeEvent { CLUSTER_TIMEOUT = 0, NODE_TIMEOUT = 1 };
struct NodeInfo {
NodeInfo() : port_(0), node_role_(NodeRole::SCHEDULER), rank_id_(0) {}

View File

@ -69,6 +69,10 @@ void NodeManager::UpdateHeartbeat(const std::string &node_id) {
<< ", the node rank id:" << node_info.rank_id_ << " the current time is: " << current_time.tv_sec;
}
void NodeManager::UpdateNodeFinishState(const std::string &node_id) { heartbeats_finish_nodes_.insert(node_id); }
bool NodeManager::CheckNodesFinishState() { return heartbeats_finish_nodes_.size() == nodes_info_.size(); }
std::vector<ServersMeta> NodeManager::FetchServersMeta() {
std::vector<ServersMeta> servers_meta_list;
for (auto it = nodes_info_.begin(); it != nodes_info_.end(); ++it) {
@ -131,7 +135,11 @@ bool NodeManager::is_cluster_finish() { return is_cluster_finish_.load(); }
bool NodeManager::is_cluster_ready() { return is_cluster_ready_.load(); }
bool NodeManager::is_cluster_timeout() { return is_cluster_timeout_; }
bool NodeManager::is_cluster_timeout() { return is_cluster_timeout_.load(); }
bool NodeManager::is_node_timeout() { return is_node_timeout_.load(); }
void NodeManager::set_cluster_timeout(bool is_cluster_timeout) { is_cluster_timeout_ = is_cluster_timeout; }
} // namespace core
} // namespace ps
} // namespace mindspore

View File

@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef RPC_CLUSTER_MANAGER_H
#define RPC_CLUSTER_MANAGER_H
#ifndef MINDSPORE_CCSRC_PS_CORE_NODE_MANAGER_H_
#define MINDSPORE_CCSRC_PS_CORE_NODE_MANAGER_H_
#include <atomic>
#include <cstdlib>
@ -45,6 +45,7 @@ class NodeManager {
: is_cluster_ready_(false),
is_cluster_finish_(false),
is_cluster_timeout_(false),
is_node_timeout_(false),
total_node_num_(0),
next_worker_rank_id_(-1),
next_server_rank_id_(-1) {}
@ -55,6 +56,8 @@ class NodeManager {
void InitNodeNum();
int NextRankId(const RegisterMessage &register_message);
void UpdateHeartbeat(const std::string &node_id);
void UpdateNodeFinishState(const std::string &node_id);
bool CheckNodesFinishState();
std::vector<ServersMeta> FetchServersMeta();
void UpdateClusterState();
void CheckClusterTimeout();
@ -63,11 +66,14 @@ class NodeManager {
bool is_cluster_ready();
bool is_cluster_finish();
bool is_cluster_timeout();
bool is_node_timeout();
void set_cluster_timeout(bool is_cluster_timeout);
private:
std::atomic<bool> is_cluster_ready_;
std::atomic<bool> is_cluster_finish_;
std::atomic<bool> is_cluster_timeout_;
std::atomic<bool> is_node_timeout_;
uint32_t total_node_num_;
std::atomic<int> next_worker_rank_id_;
std::atomic<int> next_server_rank_id_;
@ -76,6 +82,7 @@ class NodeManager {
std::mutex assign_rank_id_mutex_;
std::mutex heartbeat_mutex_;
std::unordered_map<std::string, timeval> heartbeats_;
std::unordered_set<std::string> heartbeats_finish_nodes_;
// timeout nodes
std::unordered_map<std::string, NodeInfo> timeout_nodes_info_;
std::unordered_set<std::string> finish_nodes_id_;
@ -83,4 +90,4 @@ class NodeManager {
} // namespace core
} // namespace ps
} // namespace mindspore
#endif // RPC_CLUSTER_MANAGER_H
#endif // MINDSPORE_CCSRC_PS_CORE_NODE_MANAGER_H_

View File

@ -64,6 +64,7 @@ message RegisterRespMessage {
message HeartbeatMessage {
// the current Node unique id:0,1,2...
string node_id = 1;
bool is_node_finish = 2;
}
message HeartbeatRespMessage {
@ -71,6 +72,7 @@ message HeartbeatRespMessage {
bool is_cluster_ready = 1;
bool is_cluster_finish = 2;
bool is_cluster_timeout = 3;
bool is_node_timeout = 4;
}
message FetchServersRespMessage {

View File

@ -0,0 +1,222 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ps/core/scheduler_node.h"
namespace mindspore {
namespace ps {
namespace core {
SchedulerNode::~SchedulerNode() {
MS_LOG(INFO) << "Stop scheduler node!";
if (!is_already_stopped_) {
is_already_stopped_ = true;
server_->Stop();
if (scheduler_thread_->joinable()) {
scheduler_thread_->join();
}
if (update_state_thread_->joinable()) {
update_state_thread_->join();
}
is_ready_ = true;
}
}
bool SchedulerNode::Start(const uint32_t &timeout) {
MS_LOG(INFO) << "Start scheduler node!";
Initialize();
StartUpdateClusterStateTimer();
if (!WaitForStart(timeout)) {
MS_LOG(ERROR) << "Start Scheduler node timeout!";
return false;
}
MS_LOG(INFO) << "Start the scheduler node is successful!";
return true;
}
void SchedulerNode::ProcessHeartbeat(const TcpServer &server, const TcpConnection &conn, const CommMessage &message) {
HeartbeatMessage heartbeat_message;
heartbeat_message.ParseFromString(message.data());
node_manager_.UpdateHeartbeat(heartbeat_message.node_id());
if (heartbeat_message.is_node_finish()) {
node_manager_.UpdateNodeFinishState(heartbeat_message.node_id());
}
if (heartbeat_message.is_node_finish() && node_manager_.CheckNodesFinishState()) {
MS_LOG(INFO) << "The scheduler node receive all the finish cmd!";
is_finish_ = true;
wait_finish_cond_.notify_all();
}
HeartbeatRespMessage heartbeat_resp_message;
heartbeat_resp_message.set_is_cluster_ready(node_manager_.is_cluster_ready());
heartbeat_resp_message.set_is_cluster_finish(node_manager_.is_cluster_finish());
heartbeat_resp_message.set_is_cluster_timeout(node_manager_.is_cluster_timeout());
heartbeat_resp_message.set_is_node_timeout(node_manager_.is_node_timeout());
CommMessage comm_message;
*comm_message.mutable_pb_meta() = {message.pb_meta()};
comm_message.set_data(heartbeat_resp_message.SerializeAsString());
const_cast<TcpServer &>(server).SendMessage(conn, comm_message);
}
void SchedulerNode::Initialize() {
CreateTcpServer();
is_already_stopped_ = false;
node_info_.node_id_ = CommUtil::GenerateUUID();
node_info_.node_role_ = NodeRole::SCHEDULER;
MS_LOG(INFO) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_)
<< ", the node id is:" << node_info_.node_id_;
}
void SchedulerNode::CreateTcpServer() {
node_manager_.InitNodeNum();
std::string scheduler_host = ClusterConfig::scheduler_host();
uint32_t scheduler_port = ClusterConfig::scheduler_port();
server_ = std::make_unique<TcpServer>(scheduler_host, scheduler_port);
server_->SetMessageCallback([&](const TcpServer &server, const TcpConnection &conn, const CommMessage &message) {
switch (message.pb_meta().cmd()) {
case NodeCommand::HEARTBEAT:
ProcessHeartbeat(server, conn, message);
break;
case NodeCommand::REGISTER:
ProcessRegister(server, conn, message);
break;
case NodeCommand::FINISH:
ProcessFinish(server, conn, message);
break;
case NodeCommand::FETCH_SERVER:
ProcessFetchServers(server, conn, message);
break;
default:
MS_LOG(EXCEPTION) << "The cmd:" << message.pb_meta().cmd() << " is not supported!";
}
});
server_->Init();
scheduler_thread_ = std::make_unique<std::thread>([&]() {
MS_LOG(INFO) << "The scheduler node start a tcp server!";
server_->Start();
});
scheduler_thread_->detach();
}
void SchedulerNode::ProcessRegister(const TcpServer &server, const TcpConnection &conn, const CommMessage &message) {
MS_LOG(INFO) << "The scheduler process a register message!";
RegisterMessage register_message;
register_message.ParseFromString(message.data());
// assign worker node and server node rank id
int rank_id = node_manager_.NextRankId(register_message);
if (rank_id < 0) {
MS_LOG(EXCEPTION) << "The rank id is wrong!";
}
const std::string &node_id = register_message.node_id();
node_manager_.UpdateHeartbeat(node_id);
RegisterRespMessage register_resp_message;
register_resp_message.set_node_id(node_id);
register_resp_message.set_rank_id(rank_id);
CommMessage comm_message;
*comm_message.mutable_pb_meta() = {message.pb_meta()};
comm_message.set_data(register_resp_message.SerializeAsString());
const_cast<TcpServer &>(server).SendMessage(conn, comm_message);
}
void SchedulerNode::ProcessFinish(const TcpServer &server, const TcpConnection &conn, const CommMessage &message) {
FinishMessage finish_message;
finish_message.ParseFromString(message.data());
node_manager_.AddFinishNode(finish_message);
MS_LOG(INFO) << "Process finish message from node id:" << finish_message.node_id();
const_cast<TcpServer &>(server).SendMessage(conn, message);
}
void SchedulerNode::ProcessFetchServers(const TcpServer &server, const TcpConnection &conn,
const CommMessage &message) {
FetchServersRespMessage fetch_servers_message;
std::vector<ServersMeta> servers_meta_list = node_manager_.FetchServersMeta();
*fetch_servers_message.mutable_servers_meta() = {servers_meta_list.begin(), servers_meta_list.end()};
CommMessage comm_message;
*comm_message.mutable_pb_meta() = {message.pb_meta()};
comm_message.set_data(fetch_servers_message.SerializeAsString());
const_cast<TcpServer &>(server).SendMessage(conn, comm_message);
}
void SchedulerNode::StartUpdateClusterStateTimer() {
MS_LOG(WARNING) << "The scheduler start a heartbeat timer!";
update_state_thread_ = std::make_unique<std::thread>([&]() {
auto start_time = std::chrono::steady_clock::now();
while (!is_finish_.load()) {
// 1. update cluster timeout
if (!node_manager_.is_cluster_ready() && (std::chrono::steady_clock::now() - start_time >
std::chrono::seconds(ClusterConfig::cluster_available_timeout()))) {
node_manager_.CheckClusterTimeout();
}
// 2. update cluster state
std::this_thread::sleep_for(std::chrono::seconds(ClusterConfig::heartbeat_interval()));
node_manager_.UpdateClusterState();
if (node_manager_.is_cluster_ready()) {
is_ready_ = true;
wait_start_cond_.notify_all();
}
if (node_manager_.is_cluster_finish()) {
std::this_thread::sleep_for(std::chrono::seconds(ClusterConfig::heartbeat_interval() * 2));
is_finish_ = true;
wait_finish_cond_.notify_all();
}
}
});
update_state_thread_->detach();
}
bool SchedulerNode::Stop() {
MS_LOG(INFO) << "Stop scheduler node!";
if (!is_already_stopped_) {
is_already_stopped_ = true;
server_->Stop();
if (scheduler_thread_->joinable()) {
scheduler_thread_->join();
}
if (update_state_thread_->joinable()) {
update_state_thread_->join();
}
is_ready_ = true;
}
return true;
}
bool SchedulerNode::Finish(const uint32_t &timeout) {
MS_LOG(INFO) << "Finish scheduler node!";
std::unique_lock<std::mutex> lock(wait_finish_mutex_);
wait_finish_cond_.wait(lock, [&] {
if (is_finish_.load()) {
MS_LOG(INFO) << "The scheduler finish success!";
}
return is_finish_.load();
});
return true;
}
} // namespace core
} // namespace ps
} // namespace mindspore

View File

@ -0,0 +1,70 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PS_CORE_SCHEDULER_NODE_H_
#define MINDSPORE_CCSRC_PS_CORE_SCHEDULER_NODE_H_
#include <atomic>
#include <cstdlib>
#include <iostream>
#include <memory>
#include <string>
#include <vector>
#include <thread>
#include <mutex>
#include "proto/comm.pb.h"
#include "proto/ps.pb.h"
#include "ps/core/cluster_config.h"
#include "ps/core/tcp_client.h"
#include "ps/core/tcp_server.h"
#include "ps/core/node_manager.h"
#include "ps/core/node.h"
#include "utils/log_adapter.h"
namespace mindspore {
namespace ps {
namespace core {
class SchedulerNode : public Node {
public:
SchedulerNode() : server_(nullptr), scheduler_thread_(nullptr), update_state_thread_(nullptr) {}
~SchedulerNode() override;
bool Start(const uint32_t &timeout = kTimeoutInSeconds) override;
bool Stop() override;
bool Finish(const uint32_t &timeout = kTimeoutInSeconds) override;
private:
void Initialize();
void CreateTcpServer();
void ProcessHeartbeat(const TcpServer &server, const TcpConnection &conn, const CommMessage &message);
void ProcessRegister(const TcpServer &server, const TcpConnection &conn, const CommMessage &message);
void StartUpdateClusterStateTimer();
void ProcessFinish(const TcpServer &server, const TcpConnection &conn, const CommMessage &message);
void ProcessFetchServers(const TcpServer &server, const TcpConnection &conn, const CommMessage &message);
std::unique_ptr<TcpServer> server_;
std::unique_ptr<std::thread> scheduler_thread_;
std::unique_ptr<std::thread> update_state_thread_;
NodeManager node_manager_;
};
} // namespace core
} // namespace ps
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_CORE_SCHEDULER_NODE_H_

View File

@ -38,7 +38,7 @@ bool ServerNode::Start(const uint32_t &timeout) {
MS_LOG(INFO) << "Start server node!";
Initialize();
Register(client_to_scheduler_);
Heartbeat(client_to_scheduler_);
StartHeartbeatTimer(client_to_scheduler_);
if (!WaitForStart(timeout)) {
MS_LOG(EXCEPTION) << "Start Worker node timeout!";

View File

@ -146,11 +146,7 @@ void TcpClient::StopEventBase() {
MS_LOG(INFO) << "Stop tcp client event base!";
int ret = event_base_loopbreak(event_base_);
if (ret != 0) {
MS_LOG(EXCEPTION) << "Event base loop break failed!";
}
if (event_base_) {
event_base_free(event_base_);
event_base_ = nullptr;
MS_LOG(ERROR) << "Event base loop break failed!";
}
}

View File

@ -44,7 +44,7 @@ bool WorkerNode::Start(const uint32_t &timeout) {
MS_LOG(INFO) << "Starting worker node!";
Initialize();
Register(client_to_scheduler_);
Heartbeat(client_to_scheduler_);
StartHeartbeatTimer(client_to_scheduler_);
if (!WaitForStart(timeout)) {
MS_LOG(ERROR) << "Start Worker node timeout!";