forked from mindspore-Ecosystem/mindspore

Fix comm helper method

This commit is contained in:
parent bd1c1772ea
commit 78a79a9b5e
@@ -84,12 +84,12 @@ bool ClusterContext::Initialize() {
   return true;
 }
 
-bool ClusterContext::Finalize() {
+bool ClusterContext::Finalize(uint32_t timeout) {
   if (finalized_) {
     return true;
   }
-  if (!node_->Finish()) {
+  // In some cases, one node calls the Finish function while other nodes don't, so a timeout is acceptable.
+  if (!node_->Finish(timeout)) {
     MS_LOG(WARNING) << "Finishing node " << node_role_ << " timeout.";
   }
   if (!node_->Stop()) {
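The hunk above threads a timeout (in seconds) through ClusterContext::Finalize. Below is a minimal standalone sketch of the resulting calling convention, with a stub Finalize that is not MindSpore code: ordinary shutdown paths pass a bounded timeout and tolerate failure, while long-lived roles pass UINT32_MAX to wait indefinitely (as the scheduler/server actions further down do).

#include <cstdint>
#include <iostream>

// Stub standing in for ClusterContext::Finalize(uint32_t timeout).
// Pretend the peers never report in, so any bounded wait "times out".
bool Finalize(uint32_t timeout) { return timeout == UINT32_MAX; }

int main() {
  // Worker-style shutdown: a timeout is acceptable, mirroring the WARNING path.
  if (!Finalize(30)) {
    std::cout << "Finishing node timed out; continuing shutdown.\n";
  }
  // Scheduler/server-style shutdown: block until every node has finished.
  if (Finalize(UINT32_MAX)) {
    std::cout << "All nodes finalized.\n";
  }
  return 0;
}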
@@ -47,8 +47,8 @@ class ClusterContext {
   // Initialize the cluster configuration and build network.
   bool Initialize();
 
-  // Finalize the cluster and process exits.
-  bool Finalize();
+  // Finalize the cluster and process exits. If timeout is set to UINT32_MAX, this method will block without timeout.
+  bool Finalize(uint32_t timeout = kDefaultFinishTimeout);
 
   // Return node object of this process.
   const std::shared_ptr<ps::core::Node> &node() const;
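The default argument lives in the declaration, so existing call sites that invoke Finalize() keep compiling and now mean Finalize(kDefaultFinishTimeout). A short sketch of this C++ rule with invented names; the out-of-line definition repeats only the plain parameter:

#include <cstdint>

constexpr uint32_t kDefaultFinishTimeout = 30;  // seconds, as in constants.h

class Ctx {
 public:
  // The default argument appears once, in the declaration.
  bool Finalize(uint32_t timeout = kDefaultFinishTimeout);
};

// The definition must not repeat the default.
bool Ctx::Finalize(uint32_t timeout) { return timeout > 0; }

int main() {
  Ctx ctx;
  return ctx.Finalize() ? 0 : 1;  // same as ctx.Finalize(kDefaultFinishTimeout)
}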
@@ -31,7 +31,7 @@ std::shared_ptr<ClusterContext> ClusterContext::instance() {
 
 bool ClusterContext::Initialize() const { return true; }
 
-bool ClusterContext::Finalize() const { return true; }
+bool ClusterContext::Finalize(uint32_t) const { return true; }
 
 bool ClusterContext::initialized() const { return false; }
 }  // namespace cluster
@@ -38,7 +38,7 @@ class ClusterContext {
   static std::shared_ptr<ClusterContext> instance();
 
   bool Initialize() const;
-  bool Finalize() const;
+  bool Finalize(uint32_t timeout = kDefaultFinishTimeout) const;
   bool initialized() const;
 
  private:
@@ -38,6 +38,7 @@ constexpr char kLocalHost[] = "127.0.0.1";
 constexpr int MAX_HOSTNAME_LEN = 1024;
 const uint16_t kDefaultSchedPort = 6667;
 const uint16_t kMaxPort = 65535;
+constexpr uint32_t kDefaultFinishTimeout = 30;
 }  // namespace distributed
 }  // namespace mindspore
 #endif  // MINDSPORE_CCSRC_DISTRIBUTED_CONSTANTS_H_
@@ -853,7 +853,7 @@ bool StartFLWorkerAction(const ResourcePtr &) {
 bool StartPSServerAction(const ResourcePtr &res) {
   if (distributed::cluster::ClusterContext::instance()->initialized()) {
     MS_LOG(INFO) << "This node is server. Start wait for finalizing.";
-    if (!distributed::cluster::ClusterContext::instance()->Finalize()) {
+    if (!distributed::cluster::ClusterContext::instance()->Finalize(UINT32_MAX)) {
       MS_LOG(ERROR) << "Failed to finalize server.";
       return false;
     }
@@ -955,7 +955,7 @@ bool StartServerAction(const ResourcePtr &res) {
 bool StartPSSchedulerAction(const ResourcePtr &) {
   if (distributed::cluster::ClusterContext::instance()->initialized()) {
     MS_LOG(INFO) << "This node is scheduler. Start wait for finalizing.";
-    if (!distributed::cluster::ClusterContext::instance()->Finalize()) {
+    if (!distributed::cluster::ClusterContext::instance()->Finalize(UINT32_MAX)) {
       MS_LOG(ERROR) << "Failed to finalize server.";
       return false;
     }
@@ -351,11 +351,19 @@ PYBIND11_MODULE(_c_expression, m) {
               "Init gpu collective communication mode.");
   (void)m.def("finalize_gpu_collective", &mindspore::device::gpu::CollectiveInitializer::FinalizeCollective,
               "Finalize gpu collective communication mode.");
+  (void)m.def("get_rank_id", &mindspore::device::gpu::CollectiveInitializer::GetRankID,
+              "Get rank id of gpu collective communication mode.");
+  (void)m.def("get_rank_size", &mindspore::device::gpu::CollectiveInitializer::GetRankSize,
+              "Get rank size of gpu collective communication mode.");
 #else
   (void)m.def("init_gpu_collective", &mindspore::device::gpu::CollectiveFakeInitializer::InitCollective,
               "Init gpu collective communication mode.");
   (void)m.def("finalize_gpu_collective", &mindspore::device::gpu::CollectiveFakeInitializer::FinalizeCollective,
               "Finalize gpu collective communication mode.");
+  (void)m.def("get_rank_id", &mindspore::device::gpu::CollectiveFakeInitializer::GetRankID,
+              "Get rank id of gpu collective communication mode.");
+  (void)m.def("get_rank_size", &mindspore::device::gpu::CollectiveFakeInitializer::GetRankSize,
+              "Get rank size of gpu collective communication mode.");
 #endif
 
   (void)py::class_<PSContext, std::shared_ptr<PSContext>>(m, "PSContext")
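Both #ifdef branches expose the same Python names, so importing code never needs to know which backend was compiled in. Here is a self-contained pybind11 sketch of the binding pattern used above, with an invented module and class: a static member function decays to an ordinary function pointer, so it binds exactly like a free function.

#include <pybind11/pybind11.h>

#include <cstdint>
#include <string>

namespace py = pybind11;

struct Collective {  // invented stand-in for CollectiveInitializer
  static uint32_t GetRankID(const std::string &group) { return 0; }
  static uint32_t GetRankSize(const std::string &group) { return 1; }
};

PYBIND11_MODULE(_sketch, m) {
  // (void) mirrors the codebase style of discarding m.def's return value.
  (void)m.def("get_rank_id", &Collective::GetRankID, "Get rank id of the group.");
  (void)m.def("get_rank_size", &Collective::GetRankSize, "Get rank size of the group.");
}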
@@ -1795,9 +1795,9 @@ void ClearResAtexit() {
   g_args_cache.clear();
   // clean static variable to prevent from crash. As static variable is released after
   // Python threads is released.
-  MS_LOG(INFO) << "Start clear data_converter...";
+  MS_LOG(INFO) << "Start clear ClearObjectCache...";
   parse::data_converter::ClearObjectCache();
-  MS_LOG(INFO) << "End clear data_converter...";
+  MS_LOG(INFO) << "End clear ClearObjectCache...";
 
   MS_LOG(INFO) << "Start clear Parser...";
   parse::Parser::CleanParserResource();
@@ -883,20 +883,29 @@ bool AbstractNode::Disconnect(const std::shared_ptr<TcpClient> &client, const ui
   return WaitForDisconnect(timeout);
 }
 
-bool AbstractNode::WaitForDisconnect(const uint32_t &) {
+bool AbstractNode::WaitForDisconnect(const uint32_t &timeout) {
   // If the cluster state is NODE_TIMEOUT, this node is already disconnected.
   if (current_cluster_state_ == ClusterState::NODE_TIMEOUT) {
     return true;
   }
   std::unique_lock<std::mutex> lock(wait_finish_mutex_);
-  // Caller should use this method to help block the thread.
-  wait_finish_cond_.wait(lock, [&] {
+  auto condition_func = [&] {
     if (is_finish_.load()) {
       MS_LOG(INFO) << "The node id:" << node_info_.node_id_ << " is success finish!";
     }
     return is_finish_.load();
-  });
-  return true;
+  };
+
+  bool res;
+  if (timeout == UINT32_MAX) {
+    // Caller should use this method to help block the thread.
+    wait_finish_cond_.wait(lock, condition_func);
+    res = true;
+  } else {
+    res = wait_finish_cond_.wait_for(lock, std::chrono::seconds(timeout), condition_func);
+  }
+
+  return res;
 }
 
 void AbstractNode::InitClientToServer() {
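The key change above is choosing between condition_variable::wait and wait_for based on the sentinel UINT32_MAX. A standalone sketch of just that split, using only the standard library and no MindSpore types: with a predicate, wait_for returns false exactly when the timeout elapses while the predicate is still false, which is the boolean the new WaitForDisconnect propagates.

#include <atomic>
#include <chrono>
#include <condition_variable>
#include <cstdint>
#include <mutex>

std::mutex mu;
std::condition_variable cv;
std::atomic<bool> finished{false};

bool WaitFinished(uint32_t timeout) {
  std::unique_lock<std::mutex> lock(mu);
  auto condition_func = [] { return finished.load(); };
  if (timeout == UINT32_MAX) {
    cv.wait(lock, condition_func);  // block until the predicate holds
    return true;
  }
  // false if the timeout elapses with the predicate still false.
  return cv.wait_for(lock, std::chrono::seconds(timeout), condition_func);
}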
@@ -29,6 +29,16 @@ void CollectiveFakeInitializer::FinalizeCollective() {
   MS_LOG(EXCEPTION) << "You are trying to call 'init('nccl')', Please check "
                        "this MindSpore package is GPU version and built with NCCL.";
 }
+
+uint32_t CollectiveFakeInitializer::GetRankID(const std::string &group_name) {
+  MS_LOG(EXCEPTION) << "You are trying to call 'GetRankID', Please check "
+                       "this MindSpore package is GPU version and built with NCCL.";
+}
+
+uint32_t CollectiveFakeInitializer::GetRankSize(const std::string &group_name) {
+  MS_LOG(EXCEPTION) << "You are trying to call 'GetRankSize', Please check "
+                       "this MindSpore package is GPU version and built with NCCL.";
+}
 }  // namespace gpu
 }  // namespace device
 }  // namespace mindspore
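CollectiveFakeInitializer keeps the link-time surface identical in builds without NCCL while failing loudly at call time, so the error surfaces when a user actually calls the API rather than at import or link time. A tiny sketch of this stub pattern with invented names, throwing where the real code raises MS_LOG(EXCEPTION):

#include <cstdint>
#include <stdexcept>
#include <string>

struct FakeCollective {  // invented stand-in for CollectiveFakeInitializer
  static uint32_t GetRankID(const std::string &) {
    // Same signature as the real backend; the error appears only if called.
    throw std::runtime_error("GetRankID requires a GPU build with NCCL.");
  }
  static uint32_t GetRankSize(const std::string &) {
    throw std::runtime_error("GetRankSize requires a GPU build with NCCL.");
  }
};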
@@ -17,6 +17,8 @@
 #ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_GPU_DISTRIBUTION_COLLECTIVE_FAKE_INIT_H_
 #define MINDSPORE_CCSRC_RUNTIME_DEVICE_GPU_DISTRIBUTION_COLLECTIVE_FAKE_INIT_H_
 
+#include "string"
+
 namespace mindspore {
 namespace device {
 namespace gpu {
@@ -28,6 +30,8 @@ class CollectiveFakeInitializer {
   CollectiveFakeInitializer &operator=(const CollectiveFakeInitializer &) = delete;
   static void InitCollective();
   static void FinalizeCollective();
+  static uint32_t GetRankID(const std::string &group_name);
+  static uint32_t GetRankSize(const std::string &group_name);
 };
 }  // namespace gpu
 }  // namespace device
@@ -68,6 +68,14 @@ void CollectiveInitializer::FinalizeCollective() {
   }
 }
 
+uint32_t CollectiveInitializer::GetRankID(const std::string &group_name) {
+  return CollectiveInitializer::instance().GetRankIDByGroup(group_name);
+}
+
+uint32_t CollectiveInitializer::GetRankSize(const std::string &group_name) {
+  return CollectiveInitializer::instance().GetGroupSize(group_name);
+}
+
 uint32_t CollectiveInitializer::local_rank_id() {
   uint32_t local_rank_id;
   if (common::CheckUseMPI()) {
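The new static GetRankID/GetRankSize simply forward to the process-wide singleton: that gives pybind11 a plain function pointer to bind while the state stays on the instance. A compact sketch of the pattern with invented names:

#include <cstdint>
#include <string>

class Initializer {  // invented stand-in for CollectiveInitializer
 public:
  static Initializer &instance() {
    static Initializer inst;  // one per process, constructed on first use
    return inst;
  }
  uint32_t GetRankIDByGroup(const std::string &) { return 0; }

  // Static wrapper: bindable as a plain function, forwards to the singleton.
  static uint32_t GetRankID(const std::string &group) {
    return instance().GetRankIDByGroup(group);
  }
};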
@@ -41,6 +41,8 @@ class CollectiveInitializer {
   const void *collective_handle();
   static void InitCollective();
   static void FinalizeCollective();
+  static uint32_t GetRankID(const std::string &group_name);
+  static uint32_t GetRankSize(const std::string &group_name);
 
   // The capsulation of the collective communication APIs for compatibility.
   uint32_t local_rank_id();
@@ -17,6 +17,7 @@
 from mindspore.parallel._ps_context import _is_role_pserver, _is_role_sched
 from mindspore import log as logger
 from ._hccl_management import load_lib as hccl_load_lib
+from .._c_expression import get_rank_id, get_rank_size
 
 _HCCL_AVAILABLE = False
 _HCCL_TEST_AVAILABLE = False
@@ -210,7 +211,7 @@ def _get_rank_helper(group, backend):
         else:
             rank_id = hccl.get_rank_id(group)
     elif backend == Backend.NCCL:
-        rank_id = mpi.get_rank_id(group)
+        rank_id = get_rank_id(group)
     else:
         raise ValueError("For '_get_rank_helper', the argument 'backend' {} is not supported, "
                          "please use hccl_mpi, hccl or nccl.".format(backend))
@@ -275,7 +276,7 @@ def _get_size_helper(group, backend):
         else:
             size = hccl.get_rank_size(group)
     elif backend == Backend.NCCL:
-        size = mpi.get_rank_size(group)
+        size = get_rank_size(group)
     else:
         raise ValueError("For '_get_size_helper', the argument 'backend' {} is not supported, "
                          "please use hccl or nccl.".format(backend))