From a4b0c29a6fd79ee17543618946bb260c48cf9a2c Mon Sep 17 00:00:00 2001
From: chendongsheng
Date: Thu, 8 Jul 2021 10:57:34 +0800
Subject: [PATCH] Reuse rank id

---
 mindspore/ccsrc/ps/core/abstract_node.cc  |  1 +
 mindspore/ccsrc/ps/core/node_info.h       |  2 +-
 mindspore/ccsrc/ps/core/node_manager.cc   | 31 ++++++++++++++++++-----
 mindspore/ccsrc/ps/core/node_manager.h    |  4 +--
 mindspore/ccsrc/ps/core/scheduler_node.cc |  6 ++---
 5 files changed, 32 insertions(+), 12 deletions(-)

diff --git a/mindspore/ccsrc/ps/core/abstract_node.cc b/mindspore/ccsrc/ps/core/abstract_node.cc
index a254507011e..4749ca46dc7 100644
--- a/mindspore/ccsrc/ps/core/abstract_node.cc
+++ b/mindspore/ccsrc/ps/core/abstract_node.cc
@@ -24,6 +24,7 @@ void AbstractNode::Register(const std::shared_ptr<TcpClient> &client) {
   MS_EXCEPTION_IF_NULL(client);
   auto message_meta = std::make_shared<MessageMeta>();
   message_meta->set_cmd(NodeCommand::REGISTER);
+  message_meta->set_rank_id(node_info_.rank_id_);
 
   RegisterMessage register_message;
   register_message.set_node_id(node_info_.node_id_);
diff --git a/mindspore/ccsrc/ps/core/node_info.h b/mindspore/ccsrc/ps/core/node_info.h
index 22c73525f22..ceafb120b04 100644
--- a/mindspore/ccsrc/ps/core/node_info.h
+++ b/mindspore/ccsrc/ps/core/node_info.h
@@ -36,7 +36,7 @@ enum class ClusterEvent {
 };
 
 struct NodeInfo {
-  NodeInfo() : ip_(""), port_(0), node_role_(NodeRole::SCHEDULER), rank_id_(0), is_alive(false) {}
+  NodeInfo() : ip_(""), port_(0), node_role_(NodeRole::SCHEDULER), rank_id_(UINT32_MAX), is_alive(false) {}
   // ip
   std::string ip_;
   // the port of this node
diff --git a/mindspore/ccsrc/ps/core/node_manager.cc b/mindspore/ccsrc/ps/core/node_manager.cc
index 48d4b5722fd..d13a5f873f9 100644
--- a/mindspore/ccsrc/ps/core/node_manager.cc
+++ b/mindspore/ccsrc/ps/core/node_manager.cc
@@ -27,7 +27,7 @@ void NodeManager::InitNode() {
   total_node_num_ = initial_total_node_num_;
 }
 
-uint32_t NodeManager::NextRankId(const RegisterMessage &register_message) {
+uint32_t NodeManager::NextRankId(const RegisterMessage &register_message, const std::shared_ptr<MessageMeta> &meta) {
   std::lock_guard<std::mutex> lock(assign_rank_id_mutex_);
   uint32_t rank_id = UINT_MAX;
 
@@ -51,7 +51,12 @@ uint32_t NodeManager::NextRankId(const RegisterMessage &register_message) {
       return res;
     });
     if (rank_it == registered_nodes_info_.end()) {
-      rank_id = ++next_server_rank_id_;
+      if (meta->rank_id() != UINT32_MAX && UintToInt(meta->rank_id()) <= next_server_rank_id_) {
+        rank_id = meta->rank_id();
+        MS_LOG(INFO) << "Use the old rank id:" << rank_id;
+      } else {
+        rank_id = ++next_server_rank_id_;
+      }
     } else {
       registered_nodes_info_.erase((*rank_it).first);
     }
@@ -85,7 +90,12 @@ uint32_t NodeManager::NextRankId(const RegisterMessage &register_message) {
       return res;
     });
     if (worker_rank_it == registered_nodes_info_.end()) {
-      rank_id = ++next_worker_rank_id_;
+      if (meta->rank_id() != UINT32_MAX && UintToInt(meta->rank_id()) <= next_worker_rank_id_) {
+        rank_id = meta->rank_id();
+        MS_LOG(INFO) << "Use the old rank id:" << rank_id;
+      } else {
+        rank_id = ++next_worker_rank_id_;
+      }
     } else {
       registered_nodes_info_.erase((*worker_rank_it).first);
     }
@@ -235,12 +245,21 @@ ClusterState NodeManager::GetClusterState() {
   return cluster_state_;
 }
 
-void NodeManager::ResetMetadata() {
+void NodeManager::ResetMetadata(const std::vector<std::string> &scale_in_nodes) {
   MS_LOG(WARNING) << "Reset metadata.";
+  std::vector<uint32_t> server_rank_ids;
+  if (GetClusterState() == ClusterState::CLUSTER_SCALE_IN) {
+    for (const auto &item : scale_in_nodes) {
+      if (registered_nodes_info_.count(item)) {
+        server_rank_ids.push_back(registered_nodes_info_[item].rank_id_);
+      }
+    }
+    auto min_rank_id = std::min_element(server_rank_ids.begin(), server_rank_ids.end());
+    next_server_rank_id_ = *min_rank_id - 1;
+    MS_LOG(INFO) << "The next server rank id:" << next_server_rank_id_;
+  }
   registered_nodes_info_.clear();
   heartbeats_.clear();
-  next_worker_rank_id_ = -1;
-  next_server_rank_id_ = -1;
 }
 
 bool NodeManager::IsWorkerOrServer0() {
diff --git a/mindspore/ccsrc/ps/core/node_manager.h b/mindspore/ccsrc/ps/core/node_manager.h
index c8f8474b506..9c5cce3ae65 100644
--- a/mindspore/ccsrc/ps/core/node_manager.h
+++ b/mindspore/ccsrc/ps/core/node_manager.h
@@ -56,7 +56,7 @@ class NodeManager {
   // When initializing nodes, the initial number of nodes will be assigned to the total number of nodes.
   void InitNode();
 
-  uint32_t NextRankId(const RegisterMessage &register_message);
+  uint32_t NextRankId(const RegisterMessage &register_message, const std::shared_ptr<MessageMeta> &meta);
 
   void UpdateHeartbeat(const std::string &node_id);
 
@@ -106,7 +106,7 @@ class NodeManager {
 
   // When the scheduler receives the scale out or scale in message, the metadata needs to be reset, because all nodes
   // will re-register.
-  void ResetMetadata();
+  void ResetMetadata(const std::vector<std::string> &scale_in_nodes = {});
 
   // Recovery currently does not support worker or server0 node downtime.
   bool IsWorkerOrServer0();
diff --git a/mindspore/ccsrc/ps/core/scheduler_node.cc b/mindspore/ccsrc/ps/core/scheduler_node.cc
index 337c197d2b8..4285a9a65c1 100644
--- a/mindspore/ccsrc/ps/core/scheduler_node.cc
+++ b/mindspore/ccsrc/ps/core/scheduler_node.cc
@@ -131,7 +131,7 @@ void SchedulerNode::ProcessRegister(std::shared_ptr<TcpServer> server, std::shar
   }
 
   // assign worker node and server node rank id
-  uint32_t rank_id = node_manager_.NextRankId(register_message);
+  uint32_t rank_id = node_manager_.NextRankId(register_message, meta);
   if (rank_id == UINT32_MAX) {
     MS_LOG(WARNING) << "The rank id is wrong!";
   }
@@ -559,7 +559,8 @@ void SchedulerNode::ProcessScaleIn(std::shared_ptr<HttpMessageHandler> resp) {
   int32_t scale_worker_num = 0;
   int32_t scale_server_num = 0;
   auto node_infos = node_manager_.nodes_info();
-  node_manager_.ResetMetadata();
+  node_manager_.UpdateClusterState(ClusterState::CLUSTER_SCALE_IN);
+  node_manager_.ResetMetadata(scale_in_node_ids_);
   for (auto const &val : scale_in_node_ids_) {
     if (node_infos.count(val)) {
       scale_in_nodes[val] = true;
@@ -580,7 +581,6 @@ void SchedulerNode::ProcessScaleIn(std::shared_ptr<HttpMessageHandler> resp) {
   node_manager_.set_worker_num(total_worker_num);
   node_manager_.set_server_num(total_server_num);
   node_manager_.set_total_node_num(total_worker_num + total_server_num);
-  node_manager_.UpdateClusterState(ClusterState::CLUSTER_SCALE_IN);
   for (const auto &kvs : node_infos) {
     auto client = GetOrCreateClient(kvs.second);
     bool is_node_scale_in = false;
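
Note (not part of the patch): the change lets a node that re-registers after the scheduler resets its metadata keep the rank id it already reports in MessageMeta instead of receiving a fresh one, and rank_id_ now defaults to UINT32_MAX so an unassigned id is distinguishable from rank 0. Below is a minimal standalone sketch of that reuse rule only; NextRank, kInvalidRankId and next_assigned are hypothetical stand-ins for NextRankId, the UINT32_MAX sentinel and next_server_rank_id_/next_worker_rank_id_, not code from the repository.

// Illustrative sketch of the rank-id reuse decision (assumed simplification).
#include <cstdint>
#include <iostream>

constexpr uint32_t kInvalidRankId = UINT32_MAX;

// A re-registering node that reports a rank id it already owned (and that is
// not beyond what has been assigned so far) keeps it; otherwise hand out the
// next fresh id.
uint32_t NextRank(uint32_t reported_rank_id, int32_t *next_assigned) {
  if (reported_rank_id != kInvalidRankId &&
      static_cast<int32_t>(reported_rank_id) <= *next_assigned) {
    return reported_rank_id;  // reuse the old rank id
  }
  return static_cast<uint32_t>(++(*next_assigned));  // assign a new one
}

int main() {
  int32_t next_assigned = -1;  // mirrors next_server_rank_id_'s initial value
  std::cout << NextRank(kInvalidRankId, &next_assigned) << "\n";  // 0: first registration
  std::cout << NextRank(kInvalidRankId, &next_assigned) << "\n";  // 1: second node
  std::cout << NextRank(0, &next_assigned) << "\n";               // 0: node 0 re-registers, id reused
  return 0;
}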