forked from mindspore-Ecosystem/mindspore
!9186 added node manager and node info
From: @anancds Reviewed-by: Signed-off-by:
This commit is contained in:
commit
7dce9f5f4e
|
@ -12,7 +12,9 @@ if (NOT (ENABLE_CPU AND (ENABLE_D OR ENABLE_GPU)))
|
|||
list(REMOVE_ITEM _PS_SRC_FILES "core/tcp_message_handler.cc")
|
||||
list(REMOVE_ITEM _PS_SRC_FILES "core/tcp_server.cc")
|
||||
list(REMOVE_ITEM _PS_SRC_FILES "core/cluster_config.cc")
|
||||
endif()
|
||||
list(REMOVE_ITEM _PS_SRC_FILES "core/node.cc")
|
||||
list(REMOVE_ITEM _PS_SRC_FILES "core/node_manager.cc")
|
||||
endif ()
|
||||
|
||||
set_property(SOURCE ${_PS_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_PS)
|
||||
add_library(_mindspore_ps_obj OBJECT ${_PS_SRC_FILES})
|
||||
|
|
|
@ -109,6 +109,19 @@ std::string CommUtil::GenerateUUID() {
|
|||
return ss.str();
|
||||
}
|
||||
|
||||
std::string CommUtil::NodeRoleToString(const NodeRole &role) {
|
||||
switch (role) {
|
||||
case NodeRole::SCHEDULER:
|
||||
return "SCHEDULER";
|
||||
case NodeRole::SERVER:
|
||||
return "SERVER";
|
||||
case NodeRole::WORKER:
|
||||
return "WORKER";
|
||||
default:
|
||||
MS_LOG(EXCEPTION) << "The node role:" << role << " is illegal!";
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace core
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -41,11 +41,13 @@
|
|||
#include <cstdlib>
|
||||
#include <cstring>
|
||||
#include <functional>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <random>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
|
||||
#include "proto/comm.pb.h"
|
||||
#include "proto/ps.pb.h"
|
||||
#include "utils/log_adapter.h"
|
||||
|
||||
namespace mindspore {
|
||||
|
@ -63,7 +65,9 @@ class CommUtil {
|
|||
static bool CheckIp(const std::string &ip);
|
||||
static void GetAvailableInterfaceAndIP(std::string *interface, std::string *ip);
|
||||
static std::string GenerateUUID();
|
||||
static std::string NodeRoleToString(const NodeRole &role);
|
||||
|
||||
private:
|
||||
static std::random_device rd;
|
||||
static std::mt19937_64 gen;
|
||||
static std::uniform_int_distribution<> dis;
|
||||
|
|
|
@ -0,0 +1,158 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "ps/core/node.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
namespace core {
|
||||
void Node::Heartbeat(const std::shared_ptr<TcpClient> &client) {
|
||||
MS_LOG(INFO) << "The node role: " << CommUtil::NodeRoleToString(node_info_.node_role_)
|
||||
<< ", the node id:" << node_info_.node_id_ << ", the node rank id:" << node_info_.rank_id_
|
||||
<< " begin send heartbeat to the scheduler!";
|
||||
heart_beat_thread_ = std::make_unique<std::thread>([&]() {
|
||||
while (!is_finish_.load()) {
|
||||
std::this_thread::sleep_for(std::chrono::seconds(ClusterConfig::heartbeat_interval()));
|
||||
MessageMeta meta;
|
||||
meta.set_cmd(NodeCommand::HEARTBEAT);
|
||||
|
||||
HeartbeatMessage heartbeat_message;
|
||||
heartbeat_message.set_node_id(node_info_.node_id_);
|
||||
|
||||
CommMessage message;
|
||||
*message.mutable_pb_meta() = {meta};
|
||||
message.set_data(heartbeat_message.SerializeAsString());
|
||||
SendMessageAsync(client, message);
|
||||
}
|
||||
});
|
||||
heart_beat_thread_->detach();
|
||||
}
|
||||
|
||||
void Node::ProcessHeartbeatResp(const CommMessage &message) {
|
||||
HeartbeatRespMessage heartbeat_resp_message;
|
||||
heartbeat_resp_message.ParseFromString(message.data());
|
||||
is_ready_ = heartbeat_resp_message.is_cluster_ready();
|
||||
if (is_ready_.load()) {
|
||||
wait_start_cond_.notify_all();
|
||||
}
|
||||
is_finish_ = heartbeat_resp_message.is_cluster_finish();
|
||||
if (is_finish_.load()) {
|
||||
wait_finish_cond_.notify_all();
|
||||
}
|
||||
is_timeout_ = heartbeat_resp_message.is_cluster_timeout();
|
||||
if (is_timeout_ && on_node_event_message_) {
|
||||
on_node_event_message_(NodeEvent::NODE_TIMEOUT);
|
||||
}
|
||||
}
|
||||
|
||||
void Node::FetchServers(const std::shared_ptr<TcpClient> &client) {
|
||||
MessageMeta meta;
|
||||
meta.set_cmd(NodeCommand::FETCH_SERVER);
|
||||
|
||||
CommMessage message;
|
||||
*message.mutable_pb_meta() = {meta};
|
||||
SendMessageSync(client, message);
|
||||
}
|
||||
|
||||
void Node::ProcessFetchServersResp(const CommMessage &message) {
|
||||
FetchServersRespMessage fetch_servers_resp_message;
|
||||
fetch_servers_resp_message.ParseFromString(message.data());
|
||||
|
||||
for (const auto &it : fetch_servers_resp_message.servers_meta()) {
|
||||
server_rank_ids_[it.rank_id()] = std::make_pair(it.ip(), it.port());
|
||||
}
|
||||
|
||||
MS_LOG(DEBUG) << "The all server host size is:" << server_rank_ids_.size();
|
||||
}
|
||||
|
||||
std::string Node::node_id() const { return node_info_.node_id_; }
|
||||
|
||||
uint32_t Node::rank_id() const { return node_info_.rank_id_; }
|
||||
|
||||
void Node::set_callback(const OnNodeEventMessage &on_node_event_message) {
|
||||
on_node_event_message_ = on_node_event_message;
|
||||
}
|
||||
|
||||
void Node::Wait(uint64_t request_id) {
|
||||
std::unique_lock<std::mutex> lock(message_mutex_);
|
||||
message_tracker_cond_.wait(lock, [&] {
|
||||
bool ret = message_tracker_[request_id].first == message_tracker_[request_id].second;
|
||||
if (ret) {
|
||||
MS_LOG(DEBUG) << "Message tracker remove request id:" << request_id;
|
||||
message_tracker_.erase(request_id);
|
||||
}
|
||||
return ret;
|
||||
});
|
||||
}
|
||||
|
||||
void Node::Disconnect(const std::shared_ptr<TcpClient> &client) {
|
||||
MessageMeta meta;
|
||||
meta.set_cmd(NodeCommand::FINISH);
|
||||
|
||||
FinishMessage finish_message;
|
||||
finish_message.set_node_id(node_info_.node_id_);
|
||||
|
||||
CommMessage message;
|
||||
*message.mutable_pb_meta() = {meta};
|
||||
message.set_data(finish_message.SerializeAsString());
|
||||
SendMessageSync(client, message);
|
||||
WaitForDisconnect();
|
||||
}
|
||||
|
||||
void Node::WaitForStart() {
|
||||
std::unique_lock<std::mutex> lock(wait_start_mutex_);
|
||||
wait_start_cond_.wait(lock, [&] {
|
||||
if (is_ready_.load()) {
|
||||
MS_LOG(DEBUG) << "The node id:" << node_info_.node_id_ << " is success start!";
|
||||
}
|
||||
return is_ready_.load();
|
||||
});
|
||||
}
|
||||
|
||||
void Node::WaitForDisconnect() {
|
||||
std::unique_lock<std::mutex> lock(wait_finish_mutex_);
|
||||
wait_finish_cond_.wait(lock, [&] {
|
||||
if (is_finish_.load()) {
|
||||
MS_LOG(DEBUG) << "The node id:" << node_info_.node_id_ << " is success finish!";
|
||||
}
|
||||
return is_finish_.load();
|
||||
});
|
||||
}
|
||||
|
||||
void Node::SendMessageSync(const std::shared_ptr<TcpClient> &client, const CommMessage &message) {
|
||||
uint64_t request_id = ++next_request_id_;
|
||||
message_tracker_[request_id] = std::make_pair(1, 0);
|
||||
const_cast<CommMessage &>(message).mutable_pb_meta()->set_request_id(request_id);
|
||||
client->SendMessage(message);
|
||||
Wait(request_id);
|
||||
}
|
||||
|
||||
void Node::SendMessageAsync(const std::shared_ptr<TcpClient> &client, const CommMessage &message) {
|
||||
uint64_t request_id = ++next_request_id_;
|
||||
const_cast<CommMessage &>(message).mutable_pb_meta()->set_request_id(request_id);
|
||||
client->SendMessage(message);
|
||||
}
|
||||
|
||||
void Node::NotifyMessageArrival(const CommMessage &message) {
|
||||
const MessageMeta &message_meta = message.pb_meta();
|
||||
uint64_t request_id = message_meta.request_id();
|
||||
|
||||
message_tracker_[request_id].second++;
|
||||
message_tracker_cond_.notify_all();
|
||||
}
|
||||
} // namespace core
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,102 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_PS_CORE_NODE_H_
|
||||
#define MINDSPORE_CCSRC_PS_CORE_NODE_H_
|
||||
|
||||
#include <atomic>
|
||||
#include <cstdlib>
|
||||
#include <functional>
|
||||
#include <iostream>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <set>
|
||||
#include <string>
|
||||
#include <thread>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
#include <condition_variable>
|
||||
#include <utility>
|
||||
|
||||
#include "proto/comm.pb.h"
|
||||
#include "proto/ps.pb.h"
|
||||
#include "ps/core/cluster_config.h"
|
||||
#include "ps/core/node_info.h"
|
||||
#include "ps/core/tcp_client.h"
|
||||
#include "ps/core/tcp_server.h"
|
||||
#include "utils/log_adapter.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
namespace core {
|
||||
class Node {
|
||||
public:
|
||||
Node()
|
||||
: is_ready_(false),
|
||||
is_finish_(false),
|
||||
is_timeout_(false),
|
||||
is_already_stopped_(true),
|
||||
next_request_id_(0),
|
||||
heart_beat_thread_(nullptr) {}
|
||||
virtual ~Node() = default;
|
||||
|
||||
using OnNodeEventMessage = std::function<void(const NodeEvent &event)>;
|
||||
void set_callback(const OnNodeEventMessage &on_node_event_message);
|
||||
|
||||
std::string node_id() const;
|
||||
uint32_t rank_id() const;
|
||||
|
||||
void Wait(uint64_t request_id);
|
||||
|
||||
protected:
|
||||
void Heartbeat(const std::shared_ptr<TcpClient> &client);
|
||||
void ProcessHeartbeatResp(const CommMessage &message);
|
||||
void FetchServers(const std::shared_ptr<TcpClient> &client);
|
||||
void ProcessFetchServersResp(const CommMessage &message);
|
||||
void Disconnect(const std::shared_ptr<TcpClient> &client);
|
||||
void WaitForStart();
|
||||
void WaitForDisconnect();
|
||||
void SendMessageSync(const std::shared_ptr<TcpClient> &client, const CommMessage &message);
|
||||
void SendMessageAsync(const std::shared_ptr<TcpClient> &client, const CommMessage &message);
|
||||
void NotifyMessageArrival(const CommMessage &message);
|
||||
|
||||
NodeInfo node_info_;
|
||||
std::atomic<bool> is_ready_;
|
||||
std::atomic<bool> is_finish_;
|
||||
std::atomic<bool> is_timeout_;
|
||||
std::atomic<bool> is_already_stopped_;
|
||||
std::atomic_uint64_t next_request_id_;
|
||||
std::unique_ptr<std::thread> heart_beat_thread_;
|
||||
|
||||
OnNodeEventMessage on_node_event_message_;
|
||||
|
||||
// rank_id-><ip, port>
|
||||
std::unordered_map<int, std::pair<std::string, uint16_t>> server_rank_ids_;
|
||||
|
||||
// timestamp-><expected responses, actual responses>
|
||||
std::unordered_map<uint64_t, std::pair<uint32_t, uint32_t>> message_tracker_;
|
||||
std::mutex message_mutex_;
|
||||
std::condition_variable message_tracker_cond_;
|
||||
std::mutex wait_finish_mutex_;
|
||||
std::condition_variable wait_finish_cond_;
|
||||
std::mutex wait_start_mutex_;
|
||||
std::condition_variable wait_start_cond_;
|
||||
};
|
||||
} // namespace core
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_PS_CORE_NODE_H_
|
|
@ -0,0 +1,47 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_PS_CORE_NODE_INFO_H_
|
||||
#define MINDSPORE_CCSRC_PS_CORE_NODE_INFO_H_
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "proto/comm.pb.h"
|
||||
#include "proto/ps.pb.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
namespace core {
|
||||
|
||||
// Events reported to the callback registered through Node::set_callback().
// NODE_TIMEOUT is raised when a heartbeat response reports that the cluster has timed out.
enum NodeEvent { NODE_TIMEOUT = 0 };
|
||||
|
||||
struct NodeInfo {
|
||||
NodeInfo() : port_(0), node_role_(NodeRole::SCHEDULER), rank_id_(0) {}
|
||||
// ip
|
||||
std::string ip_;
|
||||
// the port of this node
|
||||
uint16_t port_;
|
||||
// the current Node unique id:0,1,2...
|
||||
std::string node_id_;
|
||||
// the role of the node: worker,server,scheduler
|
||||
NodeRole node_role_;
|
||||
// the current Node rank id,the worker node range is:[0,numOfWorker-1], the server node range is:[0, numOfServer-1]
|
||||
uint32_t rank_id_;
|
||||
};
|
||||
} // namespace core
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_PS_CORE_NODE_INFO_H_
|
|
@ -0,0 +1,137 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "ps/core/node_manager.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
namespace core {
|
||||
// Sets the expected node count to the configured number of servers plus workers.
void NodeManager::InitNodeNum() {
  total_node_num_ = ClusterConfig::server_num() + ClusterConfig::worker_num();
}
|
||||
|
||||
int NodeManager::NextRankId(const RegisterMessage ®ister_message) {
|
||||
std::lock_guard<std::mutex> lock(assign_rank_id_mutex_);
|
||||
int rank_id = -1;
|
||||
|
||||
const std::string &node_id = register_message.node_id();
|
||||
if (nodes_info_.find(node_id) != nodes_info_.end()) {
|
||||
rank_id = nodes_info_[node_id].rank_id_;
|
||||
MS_LOG(INFO) << "The node id: " << node_id << " is already assigned!";
|
||||
return rank_id;
|
||||
}
|
||||
|
||||
if (register_message.role() == NodeRole::SERVER) {
|
||||
const std::string &ip = register_message.ip();
|
||||
uint32_t port = register_message.port();
|
||||
|
||||
rank_id = ++next_server_rank_id_;
|
||||
NodeInfo node_info;
|
||||
node_info.node_role_ = NodeRole::SERVER;
|
||||
node_info.node_id_ = node_id;
|
||||
node_info.rank_id_ = rank_id;
|
||||
node_info.ip_ = ip;
|
||||
node_info.port_ = port;
|
||||
nodes_info_[node_id] = node_info;
|
||||
MS_LOG(INFO) << "The server node id:" << node_id << ",node ip: " << node_info.ip_ << ",node port:" << port
|
||||
<< " assign rank id:" << rank_id;
|
||||
|
||||
} else if (register_message.role() == NodeRole::WORKER) {
|
||||
rank_id = ++next_worker_rank_id_;
|
||||
NodeInfo node_info;
|
||||
node_info.node_role_ = NodeRole::WORKER;
|
||||
node_info.node_id_ = node_id;
|
||||
node_info.rank_id_ = rank_id;
|
||||
nodes_info_[node_id] = node_info;
|
||||
MS_LOG(INFO) << "The worker node id:" << node_id << " assign rank id:" << rank_id;
|
||||
}
|
||||
return rank_id;
|
||||
}
|
||||
|
||||
void NodeManager::UpdateHeartbeat(const std::string &node_id) {
|
||||
std::lock_guard<std::mutex> lock(heartbeat_mutex_);
|
||||
NodeInfo node_info = nodes_info_[node_id];
|
||||
struct timeval current_time {};
|
||||
(void)gettimeofday(¤t_time, nullptr);
|
||||
heartbeats_[node_id] = current_time;
|
||||
MS_LOG(INFO) << "The node role: " << CommUtil::NodeRoleToString(node_info.node_role_) << ", the node id:" << node_id
|
||||
<< ", the node rank id:" << node_info.rank_id_ << " the current time is: " << current_time.tv_sec;
|
||||
}
|
||||
|
||||
std::vector<ServersMeta> NodeManager::FetchServersMeta() {
|
||||
std::vector<ServersMeta> servers_meta_list;
|
||||
for (auto it = nodes_info_.begin(); it != nodes_info_.end(); ++it) {
|
||||
if (it->second.node_role_ == NodeRole::SERVER) {
|
||||
ServersMeta servers_meta;
|
||||
servers_meta.set_rank_id(it->second.rank_id_);
|
||||
servers_meta.set_ip(it->second.ip_);
|
||||
servers_meta.set_port(it->second.port_);
|
||||
servers_meta_list.push_back(servers_meta);
|
||||
}
|
||||
}
|
||||
return servers_meta_list;
|
||||
}
|
||||
|
||||
void NodeManager::UpdateClusterState() {
|
||||
// 1. update cluster timeout state
|
||||
struct timeval current_time {};
|
||||
(void)gettimeofday(¤t_time, nullptr);
|
||||
timeout_nodes_info_.clear();
|
||||
for (auto it = heartbeats_.begin(); it != heartbeats_.end(); ++it) {
|
||||
if (it->second.tv_sec + ClusterConfig::heartbeat_timeout() < current_time.tv_sec) {
|
||||
MS_LOG(ERROR) << "The node id:" << it->first << " is timeout!";
|
||||
timeout_nodes_info_[it->first] = nodes_info_[it->first];
|
||||
}
|
||||
}
|
||||
if (!timeout_nodes_info_.empty()) {
|
||||
is_cluster_timeout_ = true;
|
||||
for (auto it = timeout_nodes_info_.begin(); it != timeout_nodes_info_.end(); ++it) {
|
||||
finish_nodes_id_.insert(it->first);
|
||||
}
|
||||
}
|
||||
|
||||
// 2. update cluster finish state
|
||||
if (finish_nodes_id_.size() == total_node_num_) {
|
||||
is_cluster_finish_ = true;
|
||||
is_cluster_ready_ = true;
|
||||
}
|
||||
|
||||
// 3. update cluster ready state
|
||||
if (nodes_info_.size() == total_node_num_) {
|
||||
is_cluster_ready_ = true;
|
||||
}
|
||||
}
|
||||
|
||||
void NodeManager::CheckClusterTimeout() {
|
||||
if (total_node_num_ != nodes_info_.size()) {
|
||||
MS_LOG(WARNING) << "The cluster is not ready after " << ClusterConfig::cluster_available_timeout()
|
||||
<< " seconds,so finish the cluster";
|
||||
is_cluster_timeout_ = true;
|
||||
}
|
||||
}
|
||||
|
||||
void NodeManager::AddFinishNode(const FinishMessage &finish_message) {
|
||||
finish_nodes_id_.insert(finish_message.node_id());
|
||||
}
|
||||
|
||||
std::unordered_map<std::string, NodeInfo> NodeManager::nodes_info() { return nodes_info_; }
|
||||
|
||||
bool NodeManager::is_cluster_finish() { return is_cluster_finish_.load(); }
|
||||
|
||||
bool NodeManager::is_cluster_ready() { return is_cluster_ready_.load(); }
|
||||
|
||||
bool NodeManager::is_cluster_timeout() { return is_cluster_timeout_; }
|
||||
} // namespace core
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,86 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef RPC_CLUSTER_MANAGER_H
|
||||
#define RPC_CLUSTER_MANAGER_H
|
||||
|
||||
#include <atomic>
|
||||
#include <cstdlib>
|
||||
#include <functional>
|
||||
#include <iostream>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <set>
|
||||
#include <string>
|
||||
#include <thread>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
#include <condition_variable>
|
||||
#include <unordered_set>
|
||||
|
||||
#include "proto/comm.pb.h"
|
||||
#include "proto/ps.pb.h"
|
||||
#include "ps/core/node.h"
|
||||
#include "utils/log_adapter.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
namespace core {
|
||||
class NodeManager {
|
||||
public:
|
||||
NodeManager()
|
||||
: is_cluster_ready_(false),
|
||||
is_cluster_finish_(false),
|
||||
is_cluster_timeout_(false),
|
||||
total_node_num_(0),
|
||||
next_worker_rank_id_(-1),
|
||||
next_server_rank_id_(-1) {}
|
||||
virtual ~NodeManager() = default;
|
||||
|
||||
enum ClusterState { STARTING, STARTED, FAILED, STOPPING, STOPPED };
|
||||
|
||||
void InitNodeNum();
|
||||
int NextRankId(const RegisterMessage ®ister_message);
|
||||
void UpdateHeartbeat(const std::string &node_id);
|
||||
std::vector<ServersMeta> FetchServersMeta();
|
||||
void UpdateClusterState();
|
||||
void CheckClusterTimeout();
|
||||
void AddFinishNode(const FinishMessage &finish_message);
|
||||
std::unordered_map<std::string, NodeInfo> nodes_info();
|
||||
bool is_cluster_ready();
|
||||
bool is_cluster_finish();
|
||||
bool is_cluster_timeout();
|
||||
|
||||
private:
|
||||
std::atomic<bool> is_cluster_ready_;
|
||||
std::atomic<bool> is_cluster_finish_;
|
||||
std::atomic<bool> is_cluster_timeout_;
|
||||
uint32_t total_node_num_;
|
||||
std::atomic<int> next_worker_rank_id_;
|
||||
std::atomic<int> next_server_rank_id_;
|
||||
// worker nodes and server nodes
|
||||
std::unordered_map<std::string, NodeInfo> nodes_info_;
|
||||
std::mutex assign_rank_id_mutex_;
|
||||
std::mutex heartbeat_mutex_;
|
||||
std::unordered_map<std::string, timeval> heartbeats_;
|
||||
// timeout nodes
|
||||
std::unordered_map<std::string, NodeInfo> timeout_nodes_info_;
|
||||
std::unordered_set<std::string> finish_nodes_id_;
|
||||
};
|
||||
} // namespace core
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
||||
#endif // RPC_CLUSTER_MANAGER_H
|
|
@ -25,6 +25,7 @@ enum NodeCommand {
|
|||
HEARTBEAT = 2;
|
||||
SEND_DATA = 3;
|
||||
FETCH_SERVER = 4;
|
||||
FINISH = 5;
|
||||
}
|
||||
|
||||
enum NodeRole {
|
||||
|
@ -65,6 +66,7 @@ message HeartbeatRespMessage {
|
|||
// Is the entire system ready to use.
|
||||
bool is_cluster_ready = 1;
|
||||
bool is_cluster_finish = 2;
|
||||
bool is_cluster_timeout = 3;
|
||||
}
|
||||
|
||||
message FetchServersRespMessage {
|
||||
|
@ -78,6 +80,11 @@ message ServersMeta {
|
|||
|
||||
}
|
||||
|
||||
// Payload carried by a FINISH command.
message FinishMessage {
  // The unique id of the node that sends the FINISH command.
  string node_id = 1;
}
|
||||
|
||||
message CommMessage {
|
||||
MessageMeta pb_meta = 1;
|
||||
bytes data = 2;
|
||||
|
|
|
@ -32,6 +32,7 @@
|
|||
#include <atomic>
|
||||
|
||||
#include "proto/comm.pb.h"
|
||||
#include "proto/ps.pb.h"
|
||||
#include "ps/core/cluster_config.h"
|
||||
|
||||
namespace mindspore {
|
||||
|
|
|
@ -85,6 +85,8 @@ void TcpServer::SetServerCallback(const OnConnected &client_conn, const OnDiscon
|
|||
this->client_accept_ = client_accept;
|
||||
}
|
||||
|
||||
void TcpServer::set_timer_once_callback(const OnTimerOnce &timer) { on_timer_once_callback_ = timer; }
|
||||
|
||||
void TcpServer::set_timer_callback(const OnTimer &timer) { on_timer_callback_ = timer; }
|
||||
|
||||
void TcpServer::Init() {
|
||||
|
@ -165,7 +167,21 @@ void TcpServer::StartTimerOnlyOnce(const uint32_t &time) {
|
|||
struct timeval timeout {};
|
||||
timeout.tv_sec = time;
|
||||
timeout.tv_usec = 0;
|
||||
ev = evtimer_new(base_, TimerCallback, this);
|
||||
ev = evtimer_new(base_, TimerOnceCallback, this);
|
||||
MS_EXCEPTION_IF_NULL(ev);
|
||||
evtimer_add(ev, &timeout);
|
||||
}
|
||||
|
||||
void TcpServer::StartTimer(const uint32_t &time) {
|
||||
MS_EXCEPTION_IF_NULL(base_);
|
||||
struct event *ev = nullptr;
|
||||
if (time == 0) {
|
||||
MS_LOG(EXCEPTION) << "The time should not be 0!";
|
||||
}
|
||||
struct timeval timeout {};
|
||||
timeout.tv_sec = time;
|
||||
timeout.tv_usec = 0;
|
||||
ev = event_new(base_, -1, EV_PERSIST, TimerCallback, this);
|
||||
MS_EXCEPTION_IF_NULL(ev);
|
||||
evtimer_add(ev, &timeout);
|
||||
}
|
||||
|
@ -321,7 +337,15 @@ void TcpServer::TimerCallback(evutil_socket_t, int16_t, void *arg) {
|
|||
MS_EXCEPTION_IF_NULL(arg);
|
||||
auto tcp_server = reinterpret_cast<TcpServer *>(arg);
|
||||
if (tcp_server->on_timer_callback_) {
|
||||
tcp_server->on_timer_callback_(*tcp_server);
|
||||
tcp_server->on_timer_callback_();
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Timer trampoline: forwards the event to the callback stored by set_timer_once_callback(),
 * if one is set.
 *
 * @param arg The TcpServer that registered the timer.
 */
void TcpServer::TimerOnceCallback(evutil_socket_t, int16_t, void *arg) {
  MS_EXCEPTION_IF_NULL(arg);
  auto *server = static_cast<TcpServer *>(arg);
  if (!server->on_timer_once_callback_) {
    return;
  }
  server->on_timer_once_callback_(*server);
}
|
||||
|
||||
|
@ -337,6 +361,8 @@ void TcpServer::SendMessage(const CommMessage &message) {
|
|||
|
||||
uint16_t TcpServer::BoundPort() const { return server_port_; }
|
||||
|
||||
std::string TcpServer::BoundIp() const { return server_address_; }
|
||||
|
||||
int TcpServer::ConnectionNum() const { return connections_.size(); }
|
||||
|
||||
const std::map<evutil_socket_t, const TcpConnection *> &TcpServer::Connections() const { return connections_; }
|
||||
|
|
|
@ -35,6 +35,7 @@
|
|||
#include <atomic>
|
||||
|
||||
#include "proto/comm.pb.h"
|
||||
#include "proto/ps.pb.h"
|
||||
#include "ps/core/tcp_message_handler.h"
|
||||
#include "ps/core/cluster_config.h"
|
||||
#include "utils/log_adapter.h"
|
||||
|
@ -71,18 +72,21 @@ class TcpServer {
|
|||
using OnConnected = std::function<void(const TcpServer &, const TcpConnection &)>;
|
||||
using OnDisconnected = std::function<void(const TcpServer &, const TcpConnection &)>;
|
||||
using OnAccepted = std::function<const TcpConnection *(const TcpServer &)>;
|
||||
using OnTimer = std::function<void(const TcpServer &)>;
|
||||
using OnTimerOnce = std::function<void(const TcpServer &)>;
|
||||
using OnTimer = std::function<void()>;
|
||||
|
||||
explicit TcpServer(const std::string &address, std::uint16_t port);
|
||||
virtual ~TcpServer();
|
||||
|
||||
void SetServerCallback(const OnConnected &client_conn, const OnDisconnected &client_disconn,
|
||||
const OnAccepted &client_accept);
|
||||
void set_timer_once_callback(const OnTimerOnce &timer);
|
||||
void set_timer_callback(const OnTimer &timer);
|
||||
void Init();
|
||||
void Start();
|
||||
void StartWithNoBlock();
|
||||
void StartTimerOnlyOnce(const uint32_t &time);
|
||||
void StartTimer(const uint32_t &time);
|
||||
void Stop();
|
||||
void SendToAllClients(const char *data, size_t len);
|
||||
void AddConnection(const evutil_socket_t &fd, const TcpConnection *connection);
|
||||
|
@ -92,6 +96,7 @@ class TcpServer {
|
|||
void SendMessage(const TcpConnection &conn, const CommMessage &message);
|
||||
void SendMessage(const CommMessage &message);
|
||||
uint16_t BoundPort() const;
|
||||
std::string BoundIp() const;
|
||||
int ConnectionNum() const;
|
||||
const std::map<evutil_socket_t, const TcpConnection *> &Connections() const;
|
||||
|
||||
|
@ -102,6 +107,7 @@ class TcpServer {
|
|||
static void ReadCallback(struct bufferevent *, void *connection);
|
||||
static void EventCallback(struct bufferevent *, std::int16_t events, void *server);
|
||||
static void TimerCallback(evutil_socket_t fd, int16_t event, void *arg);
|
||||
static void TimerOnceCallback(evutil_socket_t fd, int16_t event, void *arg);
|
||||
virtual TcpConnection *onCreateConnection(struct bufferevent *bev, const evutil_socket_t &fd);
|
||||
|
||||
struct event_base *base_;
|
||||
|
@ -117,6 +123,7 @@ class TcpServer {
|
|||
OnAccepted client_accept_;
|
||||
std::recursive_mutex connection_mutex_;
|
||||
OnServerReceiveMessage message_callback_;
|
||||
OnTimerOnce on_timer_once_callback_;
|
||||
OnTimer on_timer_callback_;
|
||||
};
|
||||
} // namespace core
|
||||
|
|
Loading…
Reference in New Issue