forked from mindspore-Ecosystem/mindspore
!33606 Replace the API of worker_num() & server_num()
Merge pull request !33606 from chengang/replace_node_num_api
This commit is contained in:
commit
6412e38990
|
@ -140,6 +140,16 @@ uint32_t ClusterContext::node_num(const std::string &node_role) {
|
|||
return node_num_each_role_[node_role];
|
||||
}
|
||||
|
||||
uint32_t ClusterContext::node_num() const {
|
||||
uint32_t node_num = 0;
|
||||
for (auto iter = node_num_each_role_.begin(); iter != node_num_each_role_.end(); iter++) {
|
||||
if (iter->first != kEnvRoleOfScheduler) {
|
||||
node_num += iter->second;
|
||||
}
|
||||
}
|
||||
return node_num;
|
||||
}
|
||||
|
||||
// Returns the current value of the inited_ flag.
bool ClusterContext::initialized() const {
  return inited_;
}
|
||||
|
||||
// Read-only accessor: hands back a const reference to the actor_route_table_proxy_ member.
const ActorRouteTableProxyPtr &ClusterContext::actor_route_table_proxy() const {
  return actor_route_table_proxy_;
}
|
||||
|
|
|
@ -79,6 +79,9 @@ class BACKEND_EXPORT ClusterContext {
|
|||
// Returns the total number of nodes of the specified role. This is used as the group size of this node role.
|
||||
uint32_t node_num(const std::string &node_role);
|
||||
|
||||
// Returns the total number of nodes across all roles, excluding the scheduler.
|
||||
uint32_t node_num() const;
|
||||
|
||||
// Returns whether the cluster is initialized.
|
||||
bool initialized() const;
|
||||
|
||||
|
|
|
@ -42,8 +42,7 @@ bool Initialize() {
|
|||
std::dynamic_pointer_cast<ps::core::AbstractNode>(cluster::ClusterContext::instance()->node());
|
||||
MS_EXCEPTION_IF_NULL(abstract_node);
|
||||
collective::CollectiveManager::instance()->set_global_rank_id(abstract_node->rank_id());
|
||||
auto global_rank_size =
|
||||
(cluster_ctx->node_role() == kEnvRoleOfWorker) ? abstract_node->worker_num() : abstract_node->server_num();
|
||||
auto global_rank_size = cluster::ClusterContext::instance()->node_num();
|
||||
collective::CollectiveManager::instance()->set_global_rank_size(global_rank_size);
|
||||
|
||||
if (RecoveryContext::GetInstance()->enable_recovery()) {
|
||||
|
|
|
@ -34,7 +34,7 @@ AllReduceLauncher::AllReduceLauncher() {
|
|||
MS_LOG(EXCEPTION) << "The abstract node is nullptr when init AllReduceLauncher.";
|
||||
}
|
||||
rank_id_ = abs_node_->rank_id();
|
||||
rank_size_ = IntToSize(abs_node_->worker_num());
|
||||
rank_size_ = IntToSize(distributed::cluster::ClusterContext::instance()->node_num());
|
||||
|
||||
const auto &cluster_ctx = distributed::cluster::ClusterContext::instance();
|
||||
MS_EXCEPTION_IF_NULL(cluster_ctx);
|
||||
|
|
Loading…
Reference in New Issue