!33606 Replace the API of worker_num() & server_num()

Merge pull request !33606 from chengang/replace_node_num_api
This commit is contained in:
i-robot 2022-04-27 09:52:05 +00:00 committed by Gitee
commit 6412e38990
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
4 changed files with 15 additions and 3 deletions

View File

@ -140,6 +140,16 @@ uint32_t ClusterContext::node_num(const std::string &node_role) {
return node_num_each_role_[node_role];
}
// Returns the sum of the per-role node counts stored in node_num_each_role_,
// leaving out the entry whose role equals kEnvRoleOfScheduler.
uint32_t ClusterContext::node_num() const {
  uint32_t total = 0;
  for (const auto &role_and_count : node_num_each_role_) {
    // The scheduler role is not counted towards the total.
    if (role_and_count.first == kEnvRoleOfScheduler) {
      continue;
    }
    total += role_and_count.second;
  }
  return total;
}
// Reports the current value of the inited_ flag.
bool ClusterContext::initialized() const {
  return inited_;
}
// Gives read-only access to the stored actor route table proxy pointer.
const ActorRouteTableProxyPtr &ClusterContext::actor_route_table_proxy() const {
  return actor_route_table_proxy_;
}

View File

@ -79,6 +79,9 @@ class BACKEND_EXPORT ClusterContext {
// Returns the total number of nodes with the specified role. This is used as the group size of this node role.
uint32_t node_num(const std::string &node_role);
// Returns the total number of nodes across all roles, excluding the scheduler.
uint32_t node_num() const;
// Returns whether the cluster has been initialized.
bool initialized() const;

View File

@ -42,8 +42,7 @@ bool Initialize() {
std::dynamic_pointer_cast<ps::core::AbstractNode>(cluster::ClusterContext::instance()->node());
MS_EXCEPTION_IF_NULL(abstract_node);
collective::CollectiveManager::instance()->set_global_rank_id(abstract_node->rank_id());
auto global_rank_size =
(cluster_ctx->node_role() == kEnvRoleOfWorker) ? abstract_node->worker_num() : abstract_node->server_num();
auto global_rank_size = cluster::ClusterContext::instance()->node_num();
collective::CollectiveManager::instance()->set_global_rank_size(global_rank_size);
if (RecoveryContext::GetInstance()->enable_recovery()) {

View File

@ -34,7 +34,7 @@ AllReduceLauncher::AllReduceLauncher() {
MS_LOG(EXCEPTION) << "The abstract node is nullptr when init AllReduceLauncher.";
}
rank_id_ = abs_node_->rank_id();
rank_size_ = IntToSize(abs_node_->worker_num());
rank_size_ = IntToSize(distributed::cluster::ClusterContext::instance()->node_num());
const auto &cluster_ctx = distributed::cluster::ClusterContext::instance();
MS_EXCEPTION_IF_NULL(cluster_ctx);