commit
41bd077e09
|
@ -19,6 +19,7 @@ if (NOT (ENABLE_CPU AND (ENABLE_D OR ENABLE_GPU)))
|
|||
list(REMOVE_ITEM _PS_SRC_FILES "core/worker_node.cc")
|
||||
list(REMOVE_ITEM _PS_SRC_FILES "core/server_node.cc")
|
||||
list(REMOVE_ITEM _PS_SRC_FILES "core/abstract_node.cc")
|
||||
list(REMOVE_ITEM _PS_SRC_FILES "core/scheduler_node.cc")
|
||||
endif ()
|
||||
|
||||
if (NOT ENABLE_D)
|
||||
|
|
|
@ -74,30 +74,161 @@ void AbstractNode::set_event_callback(const OnNodeEventMessage &on_node_event_me
|
|||
on_node_event_message_ = on_node_event_message;
|
||||
}
|
||||
|
||||
void AbstractNode::Heartbeat(const std::shared_ptr<TcpClient> &client) {
|
||||
bool AbstractNode::Send(const enum NodeRole &node_role, const uint32_t &rank_id, const std::string &message,
|
||||
const uint32_t &timeout) {
|
||||
if (!CommUtil::ValidateRankId(node_role, rank_id)) {
|
||||
MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!";
|
||||
}
|
||||
|
||||
MessageMeta message_meta;
|
||||
message_meta.set_cmd(NodeCommand::SEND_DATA);
|
||||
|
||||
CommMessage comm_message;
|
||||
*comm_message.mutable_pb_meta() = {message_meta};
|
||||
comm_message.set_data(message);
|
||||
auto client = GetOrCreateTcpClient(rank_id);
|
||||
return SendMessageSync(client, comm_message);
|
||||
}
|
||||
|
||||
bool AbstractNode::Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids,
|
||||
const std::vector<std::string> &data, const uint32_t &timeout) {
|
||||
uint64_t request_id = ++next_request_id_;
|
||||
message_tracker_[request_id] = std::make_pair(data.size(), 0);
|
||||
|
||||
if (rank_ids.size() != data.size()) {
|
||||
MS_LOG(EXCEPTION) << "The number of rank ids is not equal to the number of data!";
|
||||
}
|
||||
for (size_t it = 0; it < rank_ids.size(); ++it) {
|
||||
if (!CommUtil::ValidateRankId(node_role, rank_ids.at(it))) {
|
||||
MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!";
|
||||
}
|
||||
|
||||
MessageMeta message_meta;
|
||||
message_meta.set_cmd(NodeCommand::SEND_DATA);
|
||||
message_meta.set_request_id(request_id);
|
||||
|
||||
CommMessage comm_message;
|
||||
*comm_message.mutable_pb_meta() = {message_meta};
|
||||
comm_message.set_data(data.at(it));
|
||||
|
||||
auto client = GetOrCreateTcpClient(rank_ids.at(it));
|
||||
client->SendMessage(comm_message);
|
||||
}
|
||||
return Wait(request_id, timeout);
|
||||
}
|
||||
|
||||
bool AbstractNode::Send(const enum NodeRole &node_role, const uint32_t &rank_id, const std::string &message,
|
||||
CommMessage *comm_message_resp, const uint32_t &timeout) {
|
||||
MS_EXCEPTION_IF_NULL(comm_message_resp);
|
||||
if (!CommUtil::ValidateRankId(node_role, rank_id)) {
|
||||
MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!";
|
||||
}
|
||||
|
||||
uint64_t request_id = ++next_request_id_;
|
||||
message_tracker_[request_id] = std::make_pair(1, 0);
|
||||
set_message_callback(request_id, [&]() {
|
||||
receive_messages_mutex_.lock();
|
||||
auto res = receive_messages_[request_id];
|
||||
*comm_message_resp = res[rank_id];
|
||||
receive_messages_.erase(request_id);
|
||||
receive_messages_mutex_.unlock();
|
||||
});
|
||||
|
||||
MessageMeta message_meta;
|
||||
message_meta.set_cmd(NodeCommand::SEND_DATA);
|
||||
message_meta.set_request_id(request_id);
|
||||
message_meta.set_rank_id(node_info_.rank_id_);
|
||||
message_meta.set_role(node_info_.node_role_);
|
||||
|
||||
CommMessage comm_message;
|
||||
*comm_message.mutable_pb_meta() = {message_meta};
|
||||
comm_message.set_data(message);
|
||||
auto client = GetOrCreateTcpClient(rank_id);
|
||||
client->SendMessage(comm_message);
|
||||
return Wait(request_id, timeout);
|
||||
}
|
||||
|
||||
bool AbstractNode::Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids,
|
||||
const std::vector<std::string> &data, std::vector<CommMessage *> *comm_message_resp,
|
||||
const uint32_t &timeout) {
|
||||
MS_EXCEPTION_IF_NULL(comm_message_resp);
|
||||
uint64_t request_id = ++next_request_id_;
|
||||
message_tracker_[request_id] = std::make_pair(data.size(), 0);
|
||||
|
||||
if (rank_ids.size() != data.size() || rank_ids.size() != (*comm_message_resp).size()) {
|
||||
MS_LOG(EXCEPTION) << "The number of rank ids, data, comm_message_resp should be equal!";
|
||||
}
|
||||
|
||||
size_t len = rank_ids.size();
|
||||
|
||||
set_message_callback(request_id, [&]() {
|
||||
receive_messages_mutex_.lock();
|
||||
auto res = receive_messages_[request_id];
|
||||
for (size_t it = 0; it < len; ++it) {
|
||||
comm_message_resp->at(it) = &res[rank_ids.at(it)];
|
||||
}
|
||||
receive_messages_.erase(request_id);
|
||||
receive_messages_mutex_.unlock();
|
||||
});
|
||||
|
||||
for (size_t it = 0; it < len; ++it) {
|
||||
if (!CommUtil::ValidateRankId(node_role, rank_ids.at(it))) {
|
||||
MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!";
|
||||
}
|
||||
|
||||
MessageMeta message_meta;
|
||||
message_meta.set_cmd(NodeCommand::SEND_DATA);
|
||||
message_meta.set_request_id(request_id);
|
||||
|
||||
CommMessage comm_message;
|
||||
*comm_message.mutable_pb_meta() = {message_meta};
|
||||
comm_message.set_data(data.at(it));
|
||||
|
||||
auto client = GetOrCreateTcpClient(rank_ids.at(it));
|
||||
client->SendMessage(comm_message);
|
||||
}
|
||||
return Wait(request_id, timeout);
|
||||
}
|
||||
|
||||
bool AbstractNode::Wait(uint64_t request_id, const uint32_t &timeout) {
|
||||
std::unique_lock<std::mutex> lock(message_tracker_mutex_);
|
||||
bool res = message_tracker_cond_.wait_for(lock, std::chrono::seconds(timeout), [&] {
|
||||
bool ret = message_tracker_[request_id].first == message_tracker_[request_id].second;
|
||||
return ret;
|
||||
});
|
||||
message_tracker_.erase(request_id);
|
||||
return res;
|
||||
}
|
||||
|
||||
void AbstractNode::StartHeartbeatTimer(const std::shared_ptr<TcpClient> &client) {
|
||||
MS_LOG(INFO) << "The node role: " << CommUtil::NodeRoleToString(node_info_.node_role_)
|
||||
<< ", the node id:" << node_info_.node_id_ << ", the node rank id:" << node_info_.rank_id_
|
||||
<< " begin send heartbeat to the scheduler!";
|
||||
heart_beat_thread_ = std::make_unique<std::thread>([&]() {
|
||||
while (!is_finish_.load()) {
|
||||
Heartbeat(client);
|
||||
std::this_thread::sleep_for(std::chrono::seconds(ClusterConfig::heartbeat_interval()));
|
||||
MessageMeta meta;
|
||||
meta.set_cmd(NodeCommand::HEARTBEAT);
|
||||
|
||||
HeartbeatMessage heartbeat_message;
|
||||
heartbeat_message.set_node_id(node_info_.node_id_);
|
||||
|
||||
CommMessage message;
|
||||
*message.mutable_pb_meta() = {meta};
|
||||
message.set_data(heartbeat_message.SerializeAsString());
|
||||
if (!SendMessageSync(client, message)) {
|
||||
MS_LOG(ERROR) << "The node id:" << node_info_.node_id_ << " Send heartbeat timeout!";
|
||||
}
|
||||
}
|
||||
});
|
||||
heart_beat_thread_->detach();
|
||||
}
|
||||
|
||||
void AbstractNode::Heartbeat(const std::shared_ptr<TcpClient> &client, bool is_node_finish) {
|
||||
MessageMeta meta;
|
||||
meta.set_cmd(NodeCommand::HEARTBEAT);
|
||||
|
||||
HeartbeatMessage heartbeat_message;
|
||||
heartbeat_message.set_node_id(node_info_.node_id_);
|
||||
heartbeat_message.set_is_node_finish(is_node_finish);
|
||||
|
||||
CommMessage message;
|
||||
*message.mutable_pb_meta() = {meta};
|
||||
message.set_data(heartbeat_message.SerializeAsString());
|
||||
if (!SendMessageSync(client, message)) {
|
||||
MS_LOG(ERROR) << "The node id:" << node_info_.node_id_ << " Send heartbeat timeout!";
|
||||
}
|
||||
}
|
||||
|
||||
void AbstractNode::ProcessHeartbeatResp(const CommMessage &message) {
|
||||
HeartbeatRespMessage heartbeat_resp_message;
|
||||
heartbeat_resp_message.ParseFromString(message.data());
|
||||
|
@ -106,8 +237,9 @@ void AbstractNode::ProcessHeartbeatResp(const CommMessage &message) {
|
|||
wait_start_cond_.notify_all();
|
||||
MS_LOG(DEBUG) << "The node id:" << node_info_.node_id_ << " is ready!";
|
||||
}
|
||||
is_finish_ = heartbeat_resp_message.is_cluster_finish();
|
||||
if (is_finish_.load()) {
|
||||
if (heartbeat_resp_message.is_cluster_finish()) {
|
||||
Heartbeat(client_to_scheduler_, true);
|
||||
is_finish_ = true;
|
||||
wait_finish_cond_.notify_all();
|
||||
MS_LOG(DEBUG) << "The node id:" << node_info_.node_id_ << " is finish!";
|
||||
}
|
||||
|
@ -115,6 +247,10 @@ void AbstractNode::ProcessHeartbeatResp(const CommMessage &message) {
|
|||
if (is_timeout_ && on_node_event_message_) {
|
||||
is_ready_ = true;
|
||||
wait_start_cond_.notify_all();
|
||||
on_node_event_message_(NodeEvent::CLUSTER_TIMEOUT);
|
||||
}
|
||||
|
||||
if (heartbeat_resp_message.is_node_timeout() && on_node_event_message_) {
|
||||
on_node_event_message_(NodeEvent::NODE_TIMEOUT);
|
||||
}
|
||||
}
|
||||
|
@ -207,6 +343,101 @@ bool AbstractNode::InitClientToScheduler() {
|
|||
});
|
||||
return client_to_scheduler_->WaitConnected();
|
||||
}
|
||||
|
||||
const std::shared_ptr<TcpClient> &AbstractNode::GetOrCreateTcpClient(const int &rank_id) {
|
||||
std::lock_guard<std::mutex> lock(client_mutex_);
|
||||
if (connected_nodes_.find(rank_id) != connected_nodes_.end()) {
|
||||
return connected_nodes_[rank_id];
|
||||
} else {
|
||||
if (nodes_address_.find(std::make_pair(NodeRole::SERVER, rank_id)) == nodes_address_.end()) {
|
||||
MS_LOG(EXCEPTION) << "Worker node Fetch servers failed!";
|
||||
}
|
||||
std::string ip = nodes_address_[std::make_pair(NodeRole::SERVER, rank_id)].first;
|
||||
uint16_t port = nodes_address_[std::make_pair(NodeRole::SERVER, rank_id)].second;
|
||||
auto client = std::make_shared<TcpClient>(ip, port);
|
||||
client->SetMessageCallback([&](const TcpClient &client, const CommMessage &message) {
|
||||
switch (message.pb_meta().cmd()) {
|
||||
case NodeCommand::SEND_DATA:
|
||||
ProcessSendDataResp(message);
|
||||
RunMessageCallback(message.pb_meta().request_id());
|
||||
break;
|
||||
default:
|
||||
MS_LOG(EXCEPTION) << "The cmd:" << message.pb_meta().cmd() << " is not supported!";
|
||||
}
|
||||
NotifyMessageArrival(message);
|
||||
});
|
||||
client->Init();
|
||||
connected_nodes_[rank_id] = client;
|
||||
return connected_nodes_[rank_id];
|
||||
}
|
||||
}
|
||||
|
||||
bool AbstractNode::SendMessageSync(const std::shared_ptr<TcpClient> &client, const CommMessage &message,
|
||||
const uint32_t &timeout) {
|
||||
uint64_t request_id = ++next_request_id_;
|
||||
message_tracker_[request_id] = std::make_pair(1, 0);
|
||||
const_cast<CommMessage &>(message).mutable_pb_meta()->set_request_id(request_id);
|
||||
client->SendMessage(message);
|
||||
return Wait(request_id, timeout);
|
||||
}
|
||||
|
||||
void AbstractNode::SendMessageAsync(const std::shared_ptr<TcpClient> &client, const CommMessage &message) {
|
||||
uint64_t request_id = ++next_request_id_;
|
||||
const_cast<CommMessage &>(message).mutable_pb_meta()->set_request_id(request_id);
|
||||
client->SendMessage(message);
|
||||
}
|
||||
|
||||
void AbstractNode::ProcessSendDataResp(const CommMessage &message) {
|
||||
std::lock_guard<std::mutex> lock(receive_messages_mutex_);
|
||||
const MessageMeta &message_meta = message.pb_meta();
|
||||
const uint32_t &rank_id = message_meta.rank_id();
|
||||
const uint64_t request_id = message_meta.request_id();
|
||||
auto it = receive_messages_.find(request_id);
|
||||
if (it != receive_messages_.end()) {
|
||||
it->second.insert(std::make_pair(rank_id, message));
|
||||
} else {
|
||||
std::unordered_map<uint32_t, CommMessage> res;
|
||||
res.insert(std::make_pair(rank_id, message));
|
||||
receive_messages_[request_id] = res;
|
||||
}
|
||||
}
|
||||
|
||||
void AbstractNode::RunMessageCallback(const uint64_t &request_id) {
|
||||
message_callbacks_mutex_.lock();
|
||||
// When receiving a message's response, Then compare with the desired number of responses,
|
||||
// If they are equal, then call the callback function
|
||||
if (message_tracker_[request_id].first == message_tracker_[request_id].second + 1) {
|
||||
auto it = message_callbacks_.find(request_id);
|
||||
if (it != message_callbacks_.end()) {
|
||||
message_callbacks_mutex_.unlock();
|
||||
|
||||
if (it->second) {
|
||||
it->second();
|
||||
}
|
||||
|
||||
message_callbacks_mutex_.lock();
|
||||
message_callbacks_.erase(it);
|
||||
}
|
||||
}
|
||||
message_callbacks_mutex_.unlock();
|
||||
}
|
||||
|
||||
void AbstractNode::set_message_callback(const uint64_t &request_id, const MessageCallback &message_callback) {
|
||||
if (!message_callback) {
|
||||
return;
|
||||
}
|
||||
std::lock_guard<std::mutex> lock(message_callbacks_mutex_);
|
||||
message_callbacks_[request_id] = message_callback;
|
||||
}
|
||||
|
||||
void AbstractNode::NotifyMessageArrival(const CommMessage &message) {
|
||||
std::lock_guard<std::mutex> lock(message_tracker_mutex_);
|
||||
const MessageMeta &message_meta = message.pb_meta();
|
||||
uint64_t request_id = message_meta.request_id();
|
||||
|
||||
message_tracker_[request_id].second++;
|
||||
message_tracker_cond_.notify_all();
|
||||
}
|
||||
} // namespace core
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -20,6 +20,9 @@
|
|||
#include <utility>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <map>
|
||||
#include <vector>
|
||||
#include <unordered_map>
|
||||
|
||||
#include "ps/core/node.h"
|
||||
|
||||
|
@ -34,21 +37,60 @@ class AbstractNode : public Node {
|
|||
bool BroadcastToServers(const std::string &message, const uint32_t &timeout = kCommTimeoutInSeconds);
|
||||
void set_event_callback(const OnNodeEventMessage &on_node_event_message);
|
||||
|
||||
virtual bool Send(const enum NodeRole &node_role, const uint32_t &rank_id, const std::string &message,
|
||||
const uint32_t &timeout = kCommTimeoutInSeconds);
|
||||
virtual bool Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids,
|
||||
const std::vector<std::string> &data, const uint32_t &timeout = kCommTimeoutInSeconds);
|
||||
virtual bool Send(const enum NodeRole &node_role, const uint32_t &rank_id, const std::string &message,
|
||||
CommMessage *comm_message_resp, const uint32_t &timeout = kCommTimeoutInSeconds);
|
||||
virtual bool Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids,
|
||||
const std::vector<std::string> &data, std::vector<CommMessage *> *comm_message_resp,
|
||||
const uint32_t &timeout = kCommTimeoutInSeconds);
|
||||
|
||||
bool Wait(uint64_t request_id, const uint32_t &timeout = kCommTimeoutInSeconds);
|
||||
|
||||
protected:
|
||||
void Register(const std::shared_ptr<TcpClient> &client);
|
||||
void ProcessRegisterResp(const CommMessage &message);
|
||||
void Heartbeat(const std::shared_ptr<TcpClient> &client);
|
||||
void StartHeartbeatTimer(const std::shared_ptr<TcpClient> &client);
|
||||
void Heartbeat(const std::shared_ptr<TcpClient> &client, bool is_node_finish = false);
|
||||
void ProcessHeartbeatResp(const CommMessage &message);
|
||||
void FetchServers(const std::shared_ptr<TcpClient> &client);
|
||||
void ProcessFetchServersResp(const CommMessage &message);
|
||||
bool Disconnect(const std::shared_ptr<TcpClient> &client, const uint32_t &timeout);
|
||||
bool WaitForDisconnect(const uint32_t &timeout);
|
||||
bool InitClientToScheduler();
|
||||
const std::shared_ptr<TcpClient> &GetOrCreateTcpClient(const int &rank_id);
|
||||
bool SendMessageSync(const std::shared_ptr<TcpClient> &client, const CommMessage &message,
|
||||
const uint32_t &timeout = kCommTimeoutInSeconds);
|
||||
void SendMessageAsync(const std::shared_ptr<TcpClient> &client, const CommMessage &message);
|
||||
void ProcessSendDataResp(const CommMessage &message);
|
||||
void RunMessageCallback(const uint64_t &request_id);
|
||||
void set_message_callback(const uint64_t &request_id, const MessageCallback &message_callback);
|
||||
void NotifyMessageArrival(const CommMessage &message);
|
||||
|
||||
std::unique_ptr<std::thread> heart_beat_thread_;
|
||||
std::unique_ptr<std::thread> client_to_scheduler_thread_;
|
||||
std::shared_ptr<TcpClient> client_to_scheduler_;
|
||||
|
||||
OnNodeEventMessage on_node_event_message_;
|
||||
// the map's key is: <node_role,rank_id>, the map's value is: <ip, port>
|
||||
std::map<std::pair<NodeRole, uint32_t>, std::pair<std::string, uint16_t>> nodes_address_;
|
||||
std::mutex client_mutex_;
|
||||
// the map's key is: rank_id
|
||||
std::unordered_map<int, std::shared_ptr<TcpClient>> connected_nodes_;
|
||||
|
||||
// the map's key is: request_id, the map's value is: <expected responses, actual responses>
|
||||
std::unordered_map<uint64_t, std::pair<uint32_t, uint32_t>> message_tracker_;
|
||||
std::mutex message_tracker_mutex_;
|
||||
std::condition_variable message_tracker_cond_;
|
||||
|
||||
// the map's key is: request_id, the map's value is:<rank_id, CommMessage>
|
||||
std::unordered_map<uint64_t, std::unordered_map<uint32_t, CommMessage>> receive_messages_;
|
||||
std::mutex receive_messages_mutex_;
|
||||
// the map's key is: request_id
|
||||
std::unordered_map<uint64_t, MessageCallback> message_callbacks_;
|
||||
std::mutex message_callbacks_mutex_;
|
||||
};
|
||||
} // namespace core
|
||||
} // namespace ps
|
||||
|
|
|
@ -25,131 +25,6 @@ uint32_t Node::rank_id() const { return node_info_.rank_id_; }
|
|||
|
||||
NodeRole Node::role() const { return node_info_.node_role_; }
|
||||
|
||||
bool Node::Wait(uint64_t request_id, const uint32_t &timeout) {
|
||||
std::unique_lock<std::mutex> lock(message_tracker_mutex_);
|
||||
bool res = message_tracker_cond_.wait_for(lock, std::chrono::seconds(timeout), [&] {
|
||||
bool ret = message_tracker_[request_id].first == message_tracker_[request_id].second;
|
||||
return ret;
|
||||
});
|
||||
message_tracker_.erase(request_id);
|
||||
return res;
|
||||
}
|
||||
|
||||
bool Node::Send(const enum NodeRole &node_role, const uint32_t &rank_id, const std::string &message,
|
||||
const uint32_t &timeout) {
|
||||
if (!CommUtil::ValidateRankId(node_role, rank_id)) {
|
||||
MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!";
|
||||
}
|
||||
|
||||
MessageMeta message_meta;
|
||||
message_meta.set_cmd(NodeCommand::SEND_DATA);
|
||||
|
||||
CommMessage comm_message;
|
||||
*comm_message.mutable_pb_meta() = {message_meta};
|
||||
comm_message.set_data(message);
|
||||
auto client = GetOrCreateTcpClient(rank_id);
|
||||
return SendMessageSync(client, comm_message);
|
||||
}
|
||||
|
||||
bool Node::Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids, const std::vector<std::string> &data,
|
||||
const uint32_t &timeout) {
|
||||
uint64_t request_id = ++next_request_id_;
|
||||
message_tracker_[request_id] = std::make_pair(data.size(), 0);
|
||||
|
||||
if (rank_ids.size() != data.size()) {
|
||||
MS_LOG(EXCEPTION) << "The number of rank ids is not equal to the number of data!";
|
||||
}
|
||||
for (size_t it = 0; it < rank_ids.size(); ++it) {
|
||||
if (!CommUtil::ValidateRankId(node_role, rank_ids.at(it))) {
|
||||
MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!";
|
||||
}
|
||||
|
||||
MessageMeta message_meta;
|
||||
message_meta.set_cmd(NodeCommand::SEND_DATA);
|
||||
message_meta.set_request_id(request_id);
|
||||
|
||||
CommMessage comm_message;
|
||||
*comm_message.mutable_pb_meta() = {message_meta};
|
||||
comm_message.set_data(data.at(it));
|
||||
|
||||
auto client = GetOrCreateTcpClient(rank_ids.at(it));
|
||||
client->SendMessage(comm_message);
|
||||
}
|
||||
return Wait(request_id, timeout);
|
||||
}
|
||||
|
||||
bool Node::Send(const enum NodeRole &node_role, const uint32_t &rank_id, const std::string &message,
|
||||
CommMessage *comm_message_resp, const uint32_t &timeout) {
|
||||
MS_EXCEPTION_IF_NULL(comm_message_resp);
|
||||
if (!CommUtil::ValidateRankId(node_role, rank_id)) {
|
||||
MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!";
|
||||
}
|
||||
|
||||
uint64_t request_id = ++next_request_id_;
|
||||
message_tracker_[request_id] = std::make_pair(1, 0);
|
||||
set_message_callback(request_id, [&]() {
|
||||
receive_messages_mutex_.lock();
|
||||
auto res = receive_messages_[request_id];
|
||||
*comm_message_resp = res[rank_id];
|
||||
receive_messages_.erase(request_id);
|
||||
receive_messages_mutex_.unlock();
|
||||
});
|
||||
|
||||
MessageMeta message_meta;
|
||||
message_meta.set_cmd(NodeCommand::SEND_DATA);
|
||||
message_meta.set_request_id(request_id);
|
||||
message_meta.set_rank_id(node_info_.rank_id_);
|
||||
message_meta.set_role(node_info_.node_role_);
|
||||
|
||||
CommMessage comm_message;
|
||||
*comm_message.mutable_pb_meta() = {message_meta};
|
||||
comm_message.set_data(message);
|
||||
auto client = GetOrCreateTcpClient(rank_id);
|
||||
client->SendMessage(comm_message);
|
||||
return Wait(request_id, timeout);
|
||||
}
|
||||
|
||||
bool Node::Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids, const std::vector<std::string> &data,
|
||||
std::vector<CommMessage *> *comm_message_resp, const uint32_t &timeout) {
|
||||
MS_EXCEPTION_IF_NULL(comm_message_resp);
|
||||
uint64_t request_id = ++next_request_id_;
|
||||
message_tracker_[request_id] = std::make_pair(data.size(), 0);
|
||||
|
||||
if (rank_ids.size() != data.size() || rank_ids.size() != (*comm_message_resp).size()) {
|
||||
MS_LOG(EXCEPTION) << "The number of rank ids, data, comm_message_resp should be equal!";
|
||||
}
|
||||
|
||||
size_t len = rank_ids.size();
|
||||
|
||||
set_message_callback(request_id, [&]() {
|
||||
receive_messages_mutex_.lock();
|
||||
auto res = receive_messages_[request_id];
|
||||
for (size_t it = 0; it < len; ++it) {
|
||||
comm_message_resp->at(it) = &res[rank_ids.at(it)];
|
||||
}
|
||||
receive_messages_.erase(request_id);
|
||||
receive_messages_mutex_.unlock();
|
||||
});
|
||||
|
||||
for (size_t it = 0; it < len; ++it) {
|
||||
if (!CommUtil::ValidateRankId(node_role, rank_ids.at(it))) {
|
||||
MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!";
|
||||
}
|
||||
|
||||
MessageMeta message_meta;
|
||||
message_meta.set_cmd(NodeCommand::SEND_DATA);
|
||||
message_meta.set_request_id(request_id);
|
||||
|
||||
CommMessage comm_message;
|
||||
*comm_message.mutable_pb_meta() = {message_meta};
|
||||
comm_message.set_data(data.at(it));
|
||||
|
||||
auto client = GetOrCreateTcpClient(rank_ids.at(it));
|
||||
client->SendMessage(comm_message);
|
||||
}
|
||||
return Wait(request_id, timeout);
|
||||
}
|
||||
|
||||
bool Node::WaitForStart(const uint32_t &timeout) {
|
||||
std::unique_lock<std::mutex> lock(wait_start_mutex_);
|
||||
bool res = wait_start_cond_.wait_for(lock, std::chrono::seconds(timeout), [&] {
|
||||
|
@ -161,101 +36,6 @@ bool Node::WaitForStart(const uint32_t &timeout) {
|
|||
});
|
||||
return res;
|
||||
}
|
||||
|
||||
bool Node::SendMessageSync(const std::shared_ptr<TcpClient> &client, const CommMessage &message,
|
||||
const uint32_t &timeout) {
|
||||
uint64_t request_id = ++next_request_id_;
|
||||
message_tracker_[request_id] = std::make_pair(1, 0);
|
||||
const_cast<CommMessage &>(message).mutable_pb_meta()->set_request_id(request_id);
|
||||
client->SendMessage(message);
|
||||
return Wait(request_id, timeout);
|
||||
}
|
||||
|
||||
void Node::SendMessageAsync(const std::shared_ptr<TcpClient> &client, const CommMessage &message) {
|
||||
uint64_t request_id = ++next_request_id_;
|
||||
const_cast<CommMessage &>(message).mutable_pb_meta()->set_request_id(request_id);
|
||||
client->SendMessage(message);
|
||||
}
|
||||
|
||||
const std::shared_ptr<TcpClient> &Node::GetOrCreateTcpClient(const int &rank_id) {
|
||||
std::lock_guard<std::mutex> lock(client_mutex_);
|
||||
if (connected_nodes_.find(rank_id) != connected_nodes_.end()) {
|
||||
return connected_nodes_[rank_id];
|
||||
} else {
|
||||
if (nodes_address_.find(std::make_pair(NodeRole::SERVER, rank_id)) == nodes_address_.end()) {
|
||||
MS_LOG(EXCEPTION) << "Worker node Fetch servers failed!";
|
||||
}
|
||||
std::string ip = nodes_address_[std::make_pair(NodeRole::SERVER, rank_id)].first;
|
||||
uint16_t port = nodes_address_[std::make_pair(NodeRole::SERVER, rank_id)].second;
|
||||
auto client = std::make_shared<TcpClient>(ip, port);
|
||||
client->SetMessageCallback([&](const TcpClient &client, const CommMessage &message) {
|
||||
switch (message.pb_meta().cmd()) {
|
||||
case NodeCommand::SEND_DATA:
|
||||
ProcessSendDataResp(message);
|
||||
RunMessageCallback(message.pb_meta().request_id());
|
||||
break;
|
||||
default:
|
||||
MS_LOG(EXCEPTION) << "The cmd:" << message.pb_meta().cmd() << " is not supported!";
|
||||
}
|
||||
NotifyMessageArrival(message);
|
||||
});
|
||||
client->Init();
|
||||
connected_nodes_[rank_id] = client;
|
||||
return connected_nodes_[rank_id];
|
||||
}
|
||||
}
|
||||
|
||||
void Node::ProcessSendDataResp(const CommMessage &message) {
|
||||
std::lock_guard<std::mutex> lock(receive_messages_mutex_);
|
||||
const MessageMeta &message_meta = message.pb_meta();
|
||||
const uint32_t &rank_id = message_meta.rank_id();
|
||||
const uint64_t request_id = message_meta.request_id();
|
||||
auto it = receive_messages_.find(request_id);
|
||||
if (it != receive_messages_.end()) {
|
||||
it->second.insert(std::make_pair(rank_id, message));
|
||||
} else {
|
||||
std::unordered_map<uint32_t, CommMessage> res;
|
||||
res.insert(std::make_pair(rank_id, message));
|
||||
receive_messages_[request_id] = res;
|
||||
}
|
||||
}
|
||||
|
||||
void Node::RunMessageCallback(const uint64_t &request_id) {
|
||||
message_callbacks_mutex_.lock();
|
||||
// When receiving a message's response, Then compare with the desired number of responses,
|
||||
// If they are equal, then call the callback function
|
||||
if (message_tracker_[request_id].first == message_tracker_[request_id].second + 1) {
|
||||
auto it = message_callbacks_.find(request_id);
|
||||
if (it != message_callbacks_.end()) {
|
||||
message_callbacks_mutex_.unlock();
|
||||
|
||||
if (it->second) {
|
||||
it->second();
|
||||
}
|
||||
|
||||
message_callbacks_mutex_.lock();
|
||||
message_callbacks_.erase(it);
|
||||
}
|
||||
}
|
||||
message_callbacks_mutex_.unlock();
|
||||
}
|
||||
|
||||
void Node::set_message_callback(const uint64_t &request_id, const MessageCallback &message_callback) {
|
||||
if (!message_callback) {
|
||||
return;
|
||||
}
|
||||
std::lock_guard<std::mutex> lock(message_callbacks_mutex_);
|
||||
message_callbacks_[request_id] = message_callback;
|
||||
}
|
||||
|
||||
void Node::NotifyMessageArrival(const CommMessage &message) {
|
||||
std::lock_guard<std::mutex> lock(message_tracker_mutex_);
|
||||
const MessageMeta &message_meta = message.pb_meta();
|
||||
uint64_t request_id = message_meta.request_id();
|
||||
|
||||
message_tracker_[request_id].second++;
|
||||
message_tracker_cond_.notify_all();
|
||||
}
|
||||
} // namespace core
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -29,7 +29,6 @@
|
|||
#include <condition_variable>
|
||||
#include <utility>
|
||||
#include <tuple>
|
||||
#include <map>
|
||||
|
||||
#include "proto/comm.pb.h"
|
||||
#include "proto/ps.pb.h"
|
||||
|
@ -66,28 +65,8 @@ class Node {
|
|||
uint32_t rank_id() const;
|
||||
NodeRole role() const;
|
||||
|
||||
bool Wait(uint64_t request_id, const uint32_t &timeout = kCommTimeoutInSeconds);
|
||||
|
||||
virtual bool Send(const enum NodeRole &node_role, const uint32_t &rank_id, const std::string &message,
|
||||
const uint32_t &timeout = kCommTimeoutInSeconds);
|
||||
virtual bool Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids,
|
||||
const std::vector<std::string> &data, const uint32_t &timeout = kCommTimeoutInSeconds);
|
||||
virtual bool Send(const enum NodeRole &node_role, const uint32_t &rank_id, const std::string &message,
|
||||
CommMessage *comm_message_resp, const uint32_t &timeout = kCommTimeoutInSeconds);
|
||||
virtual bool Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids,
|
||||
const std::vector<std::string> &data, std::vector<CommMessage *> *comm_message_resp,
|
||||
const uint32_t &timeout = kCommTimeoutInSeconds);
|
||||
|
||||
protected:
|
||||
bool WaitForStart(const uint32_t &timeout);
|
||||
bool SendMessageSync(const std::shared_ptr<TcpClient> &client, const CommMessage &message,
|
||||
const uint32_t &timeout = kCommTimeoutInSeconds);
|
||||
void SendMessageAsync(const std::shared_ptr<TcpClient> &client, const CommMessage &message);
|
||||
const std::shared_ptr<TcpClient> &GetOrCreateTcpClient(const int &rank_id);
|
||||
void ProcessSendDataResp(const CommMessage &message);
|
||||
void RunMessageCallback(const uint64_t &request_id);
|
||||
void set_message_callback(const uint64_t &request_id, const MessageCallback &message_callback);
|
||||
void NotifyMessageArrival(const CommMessage &message);
|
||||
|
||||
NodeInfo node_info_;
|
||||
std::atomic<bool> is_ready_;
|
||||
|
@ -97,28 +76,11 @@ class Node {
|
|||
std::atomic<bool> is_already_finished_;
|
||||
std::atomic_uint64_t next_request_id_;
|
||||
|
||||
// <NodeRole,rank_id>-><ip, port>
|
||||
std::map<std::pair<NodeRole, uint32_t>, std::pair<std::string, uint16_t>> nodes_address_;
|
||||
// rank_id->tcpclient
|
||||
std::unordered_map<int, std::shared_ptr<TcpClient>> connected_nodes_;
|
||||
|
||||
// request_id-><expected responses, actual responses>
|
||||
std::unordered_map<uint64_t, std::pair<uint32_t, uint32_t>> message_tracker_;
|
||||
std::mutex message_tracker_mutex_;
|
||||
std::condition_variable message_tracker_cond_;
|
||||
std::mutex wait_finish_mutex_;
|
||||
std::condition_variable wait_finish_cond_;
|
||||
std::mutex wait_start_mutex_;
|
||||
std::condition_variable wait_start_cond_;
|
||||
std::mutex wait_finish_mutex_;
|
||||
std::condition_variable wait_finish_cond_;
|
||||
std::mutex finish_mutex_;
|
||||
std::mutex client_mutex_;
|
||||
|
||||
// request_id -> <rank_id, CommMessage>
|
||||
std::unordered_map<uint64_t, std::unordered_map<uint32_t, CommMessage>> receive_messages_;
|
||||
std::mutex receive_messages_mutex_;
|
||||
// request_id -> MessageCallback
|
||||
std::unordered_map<uint64_t, MessageCallback> message_callbacks_;
|
||||
std::mutex message_callbacks_mutex_;
|
||||
};
|
||||
} // namespace core
|
||||
} // namespace ps
|
||||
|
|
|
@ -26,7 +26,7 @@ namespace mindspore {
|
|||
namespace ps {
|
||||
namespace core {
|
||||
|
||||
enum NodeEvent { NODE_TIMEOUT = 0 };
|
||||
enum NodeEvent { CLUSTER_TIMEOUT = 0, NODE_TIMEOUT = 1 };
|
||||
|
||||
struct NodeInfo {
|
||||
NodeInfo() : port_(0), node_role_(NodeRole::SCHEDULER), rank_id_(0) {}
|
||||
|
|
|
@ -69,6 +69,10 @@ void NodeManager::UpdateHeartbeat(const std::string &node_id) {
|
|||
<< ", the node rank id:" << node_info.rank_id_ << " the current time is: " << current_time.tv_sec;
|
||||
}
|
||||
|
||||
void NodeManager::UpdateNodeFinishState(const std::string &node_id) { heartbeats_finish_nodes_.insert(node_id); }
|
||||
|
||||
bool NodeManager::CheckNodesFinishState() { return heartbeats_finish_nodes_.size() == nodes_info_.size(); }
|
||||
|
||||
std::vector<ServersMeta> NodeManager::FetchServersMeta() {
|
||||
std::vector<ServersMeta> servers_meta_list;
|
||||
for (auto it = nodes_info_.begin(); it != nodes_info_.end(); ++it) {
|
||||
|
@ -131,7 +135,11 @@ bool NodeManager::is_cluster_finish() { return is_cluster_finish_.load(); }
|
|||
|
||||
bool NodeManager::is_cluster_ready() { return is_cluster_ready_.load(); }
|
||||
|
||||
bool NodeManager::is_cluster_timeout() { return is_cluster_timeout_; }
|
||||
bool NodeManager::is_cluster_timeout() { return is_cluster_timeout_.load(); }
|
||||
|
||||
bool NodeManager::is_node_timeout() { return is_node_timeout_.load(); }
|
||||
|
||||
void NodeManager::set_cluster_timeout(bool is_cluster_timeout) { is_cluster_timeout_ = is_cluster_timeout; }
|
||||
} // namespace core
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -14,8 +14,8 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef RPC_CLUSTER_MANAGER_H
|
||||
#define RPC_CLUSTER_MANAGER_H
|
||||
#ifndef MINDSPORE_CCSRC_PS_CORE_NODE_MANAGER_H_
|
||||
#define MINDSPORE_CCSRC_PS_CORE_NODE_MANAGER_H_
|
||||
|
||||
#include <atomic>
|
||||
#include <cstdlib>
|
||||
|
@ -45,6 +45,7 @@ class NodeManager {
|
|||
: is_cluster_ready_(false),
|
||||
is_cluster_finish_(false),
|
||||
is_cluster_timeout_(false),
|
||||
is_node_timeout_(false),
|
||||
total_node_num_(0),
|
||||
next_worker_rank_id_(-1),
|
||||
next_server_rank_id_(-1) {}
|
||||
|
@ -55,6 +56,8 @@ class NodeManager {
|
|||
void InitNodeNum();
|
||||
int NextRankId(const RegisterMessage ®ister_message);
|
||||
void UpdateHeartbeat(const std::string &node_id);
|
||||
void UpdateNodeFinishState(const std::string &node_id);
|
||||
bool CheckNodesFinishState();
|
||||
std::vector<ServersMeta> FetchServersMeta();
|
||||
void UpdateClusterState();
|
||||
void CheckClusterTimeout();
|
||||
|
@ -63,11 +66,14 @@ class NodeManager {
|
|||
bool is_cluster_ready();
|
||||
bool is_cluster_finish();
|
||||
bool is_cluster_timeout();
|
||||
bool is_node_timeout();
|
||||
void set_cluster_timeout(bool is_cluster_timeout);
|
||||
|
||||
private:
|
||||
std::atomic<bool> is_cluster_ready_;
|
||||
std::atomic<bool> is_cluster_finish_;
|
||||
std::atomic<bool> is_cluster_timeout_;
|
||||
std::atomic<bool> is_node_timeout_;
|
||||
uint32_t total_node_num_;
|
||||
std::atomic<int> next_worker_rank_id_;
|
||||
std::atomic<int> next_server_rank_id_;
|
||||
|
@ -76,6 +82,7 @@ class NodeManager {
|
|||
std::mutex assign_rank_id_mutex_;
|
||||
std::mutex heartbeat_mutex_;
|
||||
std::unordered_map<std::string, timeval> heartbeats_;
|
||||
std::unordered_set<std::string> heartbeats_finish_nodes_;
|
||||
// timeout nodes
|
||||
std::unordered_map<std::string, NodeInfo> timeout_nodes_info_;
|
||||
std::unordered_set<std::string> finish_nodes_id_;
|
||||
|
@ -83,4 +90,4 @@ class NodeManager {
|
|||
} // namespace core
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
||||
#endif // RPC_CLUSTER_MANAGER_H
|
||||
#endif // MINDSPORE_CCSRC_PS_CORE_NODE_MANAGER_H_
|
||||
|
|
|
@ -64,6 +64,7 @@ message RegisterRespMessage {
|
|||
message HeartbeatMessage {
|
||||
// the current Node unique id:0,1,2...
|
||||
string node_id = 1;
|
||||
bool is_node_finish = 2;
|
||||
}
|
||||
|
||||
message HeartbeatRespMessage {
|
||||
|
@ -71,6 +72,7 @@ message HeartbeatRespMessage {
|
|||
bool is_cluster_ready = 1;
|
||||
bool is_cluster_finish = 2;
|
||||
bool is_cluster_timeout = 3;
|
||||
bool is_node_timeout = 4;
|
||||
}
|
||||
|
||||
message FetchServersRespMessage {
|
||||
|
|
|
@ -0,0 +1,222 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "ps/core/scheduler_node.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
namespace core {
|
||||
|
||||
SchedulerNode::~SchedulerNode() {
|
||||
MS_LOG(INFO) << "Stop scheduler node!";
|
||||
if (!is_already_stopped_) {
|
||||
is_already_stopped_ = true;
|
||||
server_->Stop();
|
||||
if (scheduler_thread_->joinable()) {
|
||||
scheduler_thread_->join();
|
||||
}
|
||||
if (update_state_thread_->joinable()) {
|
||||
update_state_thread_->join();
|
||||
}
|
||||
is_ready_ = true;
|
||||
}
|
||||
}
|
||||
|
||||
bool SchedulerNode::Start(const uint32_t &timeout) {
|
||||
MS_LOG(INFO) << "Start scheduler node!";
|
||||
Initialize();
|
||||
StartUpdateClusterStateTimer();
|
||||
if (!WaitForStart(timeout)) {
|
||||
MS_LOG(ERROR) << "Start Scheduler node timeout!";
|
||||
return false;
|
||||
}
|
||||
MS_LOG(INFO) << "Start the scheduler node is successful!";
|
||||
return true;
|
||||
}
|
||||
|
||||
void SchedulerNode::ProcessHeartbeat(const TcpServer &server, const TcpConnection &conn, const CommMessage &message) {
|
||||
HeartbeatMessage heartbeat_message;
|
||||
heartbeat_message.ParseFromString(message.data());
|
||||
|
||||
node_manager_.UpdateHeartbeat(heartbeat_message.node_id());
|
||||
|
||||
if (heartbeat_message.is_node_finish()) {
|
||||
node_manager_.UpdateNodeFinishState(heartbeat_message.node_id());
|
||||
}
|
||||
|
||||
if (heartbeat_message.is_node_finish() && node_manager_.CheckNodesFinishState()) {
|
||||
MS_LOG(INFO) << "The scheduler node receive all the finish cmd!";
|
||||
is_finish_ = true;
|
||||
wait_finish_cond_.notify_all();
|
||||
}
|
||||
|
||||
HeartbeatRespMessage heartbeat_resp_message;
|
||||
heartbeat_resp_message.set_is_cluster_ready(node_manager_.is_cluster_ready());
|
||||
heartbeat_resp_message.set_is_cluster_finish(node_manager_.is_cluster_finish());
|
||||
heartbeat_resp_message.set_is_cluster_timeout(node_manager_.is_cluster_timeout());
|
||||
heartbeat_resp_message.set_is_node_timeout(node_manager_.is_node_timeout());
|
||||
|
||||
CommMessage comm_message;
|
||||
*comm_message.mutable_pb_meta() = {message.pb_meta()};
|
||||
comm_message.set_data(heartbeat_resp_message.SerializeAsString());
|
||||
const_cast<TcpServer &>(server).SendMessage(conn, comm_message);
|
||||
}
|
||||
|
||||
void SchedulerNode::Initialize() {
|
||||
CreateTcpServer();
|
||||
is_already_stopped_ = false;
|
||||
node_info_.node_id_ = CommUtil::GenerateUUID();
|
||||
node_info_.node_role_ = NodeRole::SCHEDULER;
|
||||
MS_LOG(INFO) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_)
|
||||
<< ", the node id is:" << node_info_.node_id_;
|
||||
}
|
||||
|
||||
void SchedulerNode::CreateTcpServer() {
|
||||
node_manager_.InitNodeNum();
|
||||
|
||||
std::string scheduler_host = ClusterConfig::scheduler_host();
|
||||
uint32_t scheduler_port = ClusterConfig::scheduler_port();
|
||||
server_ = std::make_unique<TcpServer>(scheduler_host, scheduler_port);
|
||||
server_->SetMessageCallback([&](const TcpServer &server, const TcpConnection &conn, const CommMessage &message) {
|
||||
switch (message.pb_meta().cmd()) {
|
||||
case NodeCommand::HEARTBEAT:
|
||||
ProcessHeartbeat(server, conn, message);
|
||||
break;
|
||||
case NodeCommand::REGISTER:
|
||||
ProcessRegister(server, conn, message);
|
||||
break;
|
||||
case NodeCommand::FINISH:
|
||||
ProcessFinish(server, conn, message);
|
||||
break;
|
||||
case NodeCommand::FETCH_SERVER:
|
||||
ProcessFetchServers(server, conn, message);
|
||||
break;
|
||||
default:
|
||||
MS_LOG(EXCEPTION) << "The cmd:" << message.pb_meta().cmd() << " is not supported!";
|
||||
}
|
||||
});
|
||||
|
||||
server_->Init();
|
||||
|
||||
scheduler_thread_ = std::make_unique<std::thread>([&]() {
|
||||
MS_LOG(INFO) << "The scheduler node start a tcp server!";
|
||||
server_->Start();
|
||||
});
|
||||
scheduler_thread_->detach();
|
||||
}
|
||||
|
||||
void SchedulerNode::ProcessRegister(const TcpServer &server, const TcpConnection &conn, const CommMessage &message) {
|
||||
MS_LOG(INFO) << "The scheduler process a register message!";
|
||||
RegisterMessage register_message;
|
||||
register_message.ParseFromString(message.data());
|
||||
|
||||
// assign worker node and server node rank id
|
||||
int rank_id = node_manager_.NextRankId(register_message);
|
||||
if (rank_id < 0) {
|
||||
MS_LOG(EXCEPTION) << "The rank id is wrong!";
|
||||
}
|
||||
const std::string &node_id = register_message.node_id();
|
||||
node_manager_.UpdateHeartbeat(node_id);
|
||||
|
||||
RegisterRespMessage register_resp_message;
|
||||
register_resp_message.set_node_id(node_id);
|
||||
register_resp_message.set_rank_id(rank_id);
|
||||
|
||||
CommMessage comm_message;
|
||||
*comm_message.mutable_pb_meta() = {message.pb_meta()};
|
||||
comm_message.set_data(register_resp_message.SerializeAsString());
|
||||
const_cast<TcpServer &>(server).SendMessage(conn, comm_message);
|
||||
}
|
||||
|
||||
void SchedulerNode::ProcessFinish(const TcpServer &server, const TcpConnection &conn, const CommMessage &message) {
|
||||
FinishMessage finish_message;
|
||||
finish_message.ParseFromString(message.data());
|
||||
node_manager_.AddFinishNode(finish_message);
|
||||
MS_LOG(INFO) << "Process finish message from node id:" << finish_message.node_id();
|
||||
const_cast<TcpServer &>(server).SendMessage(conn, message);
|
||||
}
|
||||
|
||||
void SchedulerNode::ProcessFetchServers(const TcpServer &server, const TcpConnection &conn,
|
||||
const CommMessage &message) {
|
||||
FetchServersRespMessage fetch_servers_message;
|
||||
std::vector<ServersMeta> servers_meta_list = node_manager_.FetchServersMeta();
|
||||
|
||||
*fetch_servers_message.mutable_servers_meta() = {servers_meta_list.begin(), servers_meta_list.end()};
|
||||
|
||||
CommMessage comm_message;
|
||||
*comm_message.mutable_pb_meta() = {message.pb_meta()};
|
||||
comm_message.set_data(fetch_servers_message.SerializeAsString());
|
||||
const_cast<TcpServer &>(server).SendMessage(conn, comm_message);
|
||||
}
|
||||
|
||||
void SchedulerNode::StartUpdateClusterStateTimer() {
|
||||
MS_LOG(WARNING) << "The scheduler start a heartbeat timer!";
|
||||
update_state_thread_ = std::make_unique<std::thread>([&]() {
|
||||
auto start_time = std::chrono::steady_clock::now();
|
||||
while (!is_finish_.load()) {
|
||||
// 1. update cluster timeout
|
||||
if (!node_manager_.is_cluster_ready() && (std::chrono::steady_clock::now() - start_time >
|
||||
std::chrono::seconds(ClusterConfig::cluster_available_timeout()))) {
|
||||
node_manager_.CheckClusterTimeout();
|
||||
}
|
||||
|
||||
// 2. update cluster state
|
||||
std::this_thread::sleep_for(std::chrono::seconds(ClusterConfig::heartbeat_interval()));
|
||||
node_manager_.UpdateClusterState();
|
||||
if (node_manager_.is_cluster_ready()) {
|
||||
is_ready_ = true;
|
||||
wait_start_cond_.notify_all();
|
||||
}
|
||||
if (node_manager_.is_cluster_finish()) {
|
||||
std::this_thread::sleep_for(std::chrono::seconds(ClusterConfig::heartbeat_interval() * 2));
|
||||
is_finish_ = true;
|
||||
wait_finish_cond_.notify_all();
|
||||
}
|
||||
}
|
||||
});
|
||||
update_state_thread_->detach();
|
||||
}
|
||||
|
||||
bool SchedulerNode::Stop() {
|
||||
MS_LOG(INFO) << "Stop scheduler node!";
|
||||
if (!is_already_stopped_) {
|
||||
is_already_stopped_ = true;
|
||||
server_->Stop();
|
||||
if (scheduler_thread_->joinable()) {
|
||||
scheduler_thread_->join();
|
||||
}
|
||||
if (update_state_thread_->joinable()) {
|
||||
update_state_thread_->join();
|
||||
}
|
||||
is_ready_ = true;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool SchedulerNode::Finish(const uint32_t &timeout) {
|
||||
MS_LOG(INFO) << "Finish scheduler node!";
|
||||
std::unique_lock<std::mutex> lock(wait_finish_mutex_);
|
||||
wait_finish_cond_.wait(lock, [&] {
|
||||
if (is_finish_.load()) {
|
||||
MS_LOG(INFO) << "The scheduler finish success!";
|
||||
}
|
||||
return is_finish_.load();
|
||||
});
|
||||
return true;
|
||||
}
|
||||
} // namespace core
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,70 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_PS_CORE_SCHEDULER_NODE_H_
|
||||
#define MINDSPORE_CCSRC_PS_CORE_SCHEDULER_NODE_H_
|
||||
|
||||
#include <atomic>
|
||||
#include <cstdlib>
|
||||
#include <iostream>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <thread>
|
||||
#include <mutex>
|
||||
|
||||
#include "proto/comm.pb.h"
|
||||
#include "proto/ps.pb.h"
|
||||
#include "ps/core/cluster_config.h"
|
||||
#include "ps/core/tcp_client.h"
|
||||
#include "ps/core/tcp_server.h"
|
||||
#include "ps/core/node_manager.h"
|
||||
#include "ps/core/node.h"
|
||||
#include "utils/log_adapter.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
namespace core {
|
||||
|
||||
class SchedulerNode : public Node {
|
||||
public:
|
||||
SchedulerNode() : server_(nullptr), scheduler_thread_(nullptr), update_state_thread_(nullptr) {}
|
||||
~SchedulerNode() override;
|
||||
|
||||
bool Start(const uint32_t &timeout = kTimeoutInSeconds) override;
|
||||
bool Stop() override;
|
||||
bool Finish(const uint32_t &timeout = kTimeoutInSeconds) override;
|
||||
|
||||
private:
|
||||
void Initialize();
|
||||
void CreateTcpServer();
|
||||
void ProcessHeartbeat(const TcpServer &server, const TcpConnection &conn, const CommMessage &message);
|
||||
void ProcessRegister(const TcpServer &server, const TcpConnection &conn, const CommMessage &message);
|
||||
void StartUpdateClusterStateTimer();
|
||||
void ProcessFinish(const TcpServer &server, const TcpConnection &conn, const CommMessage &message);
|
||||
void ProcessFetchServers(const TcpServer &server, const TcpConnection &conn, const CommMessage &message);
|
||||
|
||||
std::unique_ptr<TcpServer> server_;
|
||||
std::unique_ptr<std::thread> scheduler_thread_;
|
||||
std::unique_ptr<std::thread> update_state_thread_;
|
||||
|
||||
NodeManager node_manager_;
|
||||
};
|
||||
} // namespace core
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_PS_CORE_SCHEDULER_NODE_H_
|
|
@ -38,7 +38,7 @@ bool ServerNode::Start(const uint32_t &timeout) {
|
|||
MS_LOG(INFO) << "Start server node!";
|
||||
Initialize();
|
||||
Register(client_to_scheduler_);
|
||||
Heartbeat(client_to_scheduler_);
|
||||
StartHeartbeatTimer(client_to_scheduler_);
|
||||
|
||||
if (!WaitForStart(timeout)) {
|
||||
MS_LOG(EXCEPTION) << "Start Worker node timeout!";
|
||||
|
|
|
@ -146,11 +146,7 @@ void TcpClient::StopEventBase() {
|
|||
MS_LOG(INFO) << "Stop tcp client event base!";
|
||||
int ret = event_base_loopbreak(event_base_);
|
||||
if (ret != 0) {
|
||||
MS_LOG(EXCEPTION) << "Event base loop break failed!";
|
||||
}
|
||||
if (event_base_) {
|
||||
event_base_free(event_base_);
|
||||
event_base_ = nullptr;
|
||||
MS_LOG(ERROR) << "Event base loop break failed!";
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -44,7 +44,7 @@ bool WorkerNode::Start(const uint32_t &timeout) {
|
|||
MS_LOG(INFO) << "Starting worker node!";
|
||||
Initialize();
|
||||
Register(client_to_scheduler_);
|
||||
Heartbeat(client_to_scheduler_);
|
||||
StartHeartbeatTimer(client_to_scheduler_);
|
||||
|
||||
if (!WaitForStart(timeout)) {
|
||||
MS_LOG(ERROR) << "Start Worker node timeout!";
|
||||
|
|
Loading…
Reference in New Issue