Fix comm helper method

ZPaC 2021-12-07 15:14:10 +08:00
parent bd1c1772ea
commit 78a79a9b5e
14 changed files with 60 additions and 17 deletions

View File

@@ -84,12 +84,12 @@ bool ClusterContext::Initialize() {
return true;
}
-bool ClusterContext::Finalize() {
+bool ClusterContext::Finalize(uint32_t timeout) {
if (finalized_) {
return true;
}
-if (!node_->Finish()) {
+// In some cases, one node calls the Finish function while other nodes don't. So timeout is acceptable.
+if (!node_->Finish(timeout)) {
MS_LOG(WARNING) << "Finishing node " << node_role_ << " timeout.";
}
if (!node_->Stop()) {
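
Note: Finalize now takes a timeout in seconds, defaulting to kDefaultFinishTimeout (30, added in constants.h below), with UINT32_MAX acting as a "block forever" sentinel. A minimal caller-side sketch of the two modes, assuming the usual ClusterContext setup around this hunk:

// Bounded shutdown: wait up to kDefaultFinishTimeout (30 s) for peers to call Finish.
distributed::cluster::ClusterContext::instance()->Finalize();
// Unbounded shutdown: block until every node has reported Finish.
distributed::cluster::ClusterContext::instance()->Finalize(UINT32_MAX);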

View File

@@ -47,8 +47,8 @@ class ClusterContext {
// Initialize the cluster configuration and build network.
bool Initialize();
-// Finalize the cluster and process exits.
-bool Finalize();
+// Finalize the cluster and process exits. If timeout is set to UINT32_MAX, this method will block without timeout.
+bool Finalize(uint32_t timeout = kDefaultFinishTimeout);
// Return node object of this process.
const std::shared_ptr<ps::core::Node> &node() const;

View File

@@ -31,7 +31,7 @@ std::shared_ptr<ClusterContext> ClusterContext::instance() {
bool ClusterContext::Initialize() const { return true; }
-bool ClusterContext::Finalize() const { return true; }
+bool ClusterContext::Finalize(uint32_t) const { return true; }
bool ClusterContext::initialized() const { return false; }
} // namespace cluster

View File

@@ -38,7 +38,7 @@ class ClusterContext {
static std::shared_ptr<ClusterContext> instance();
bool Initialize() const;
-bool Finalize() const;
+bool Finalize(uint32_t timeout = kDefaultFinishTimeout) const;
bool initialized() const;
private:

View File

@@ -38,6 +38,7 @@ constexpr char kLocalHost[] = "127.0.0.1";
constexpr int MAX_HOSTNAME_LEN = 1024;
const uint16_t kDefaultSchedPort = 6667;
const uint16_t kMaxPort = 65535;
+constexpr uint32_t kDefaultFinishTimeout = 30;
} // namespace distributed
} // namespace mindspore
#endif // MINDSPORE_CCSRC_DISTRIBUTED_CONSTANTS_H_

View File

@@ -853,7 +853,7 @@ bool StartFLWorkerAction(const ResourcePtr &) {
bool StartPSServerAction(const ResourcePtr &res) {
if (distributed::cluster::ClusterContext::instance()->initialized()) {
MS_LOG(INFO) << "This node is server. Start wait for finalizing.";
-if (!distributed::cluster::ClusterContext::instance()->Finalize()) {
+if (!distributed::cluster::ClusterContext::instance()->Finalize(UINT32_MAX)) {
MS_LOG(ERROR) << "Failed to finalize server.";
return false;
}
@@ -955,7 +955,7 @@ bool StartServerAction(const ResourcePtr &res) {
bool StartPSSchedulerAction(const ResourcePtr &) {
if (distributed::cluster::ClusterContext::instance()->initialized()) {
MS_LOG(INFO) << "This node is scheduler. Start wait for finalizing.";
-if (!distributed::cluster::ClusterContext::instance()->Finalize()) {
+if (!distributed::cluster::ClusterContext::instance()->Finalize(UINT32_MAX)) {
MS_LOG(ERROR) << "Failed to finalize server.";
return false;
}
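
Note: both the server and scheduler paths pass UINT32_MAX deliberately; these long-lived roles must not exit before the rest of the cluster, so they take the unbounded branch of WaitForDisconnect (see the abstract_node.cc hunk below). A hedged sketch of that role-based choice; the role strings are assumptions, not taken from this diff:

// Hypothetical helper: long-lived roles block without timeout, workers use the default.
uint32_t FinalizeTimeoutForRole(const std::string &role) {
  if (role == "MS_SCHED" || role == "MS_PSERVER") {  // assumed role names
    return UINT32_MAX;  // wait until the whole cluster reports Finish
  }
  return kDefaultFinishTimeout;  // bounded 30 s wait
}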

View File

@@ -351,11 +351,19 @@ PYBIND11_MODULE(_c_expression, m) {
"Init gpu collective communication mode.");
(void)m.def("finalize_gpu_collective", &mindspore::device::gpu::CollectiveInitializer::FinalizeCollective,
"Finalize gpu collective communication mode.");
(void)m.def("get_rank_id", &mindspore::device::gpu::CollectiveInitializer::GetRankID,
"Finalize gpu collective communication mode.");
(void)m.def("get_rank_size", &mindspore::device::gpu::CollectiveInitializer::GetRankSize,
"Finalize gpu collective communication mode.");
#else
(void)m.def("init_gpu_collective", &mindspore::device::gpu::CollectiveFakeInitializer::InitCollective,
"Init gpu collective communication mode.");
(void)m.def("finalize_gpu_collective", &mindspore::device::gpu::CollectiveFakeInitializer::FinalizeCollective,
"Finalize gpu collective communication mode.");
(void)m.def("get_rank_id", &mindspore::device::gpu::CollectiveFakeInitializer::GetRankID,
"Finalize gpu collective communication mode.");
(void)m.def("get_rank_size", &mindspore::device::gpu::CollectiveFakeInitializer::GetRankSize,
"Finalize gpu collective communication mode.");
#endif
(void)py::class_<PSContext, std::shared_ptr<PSContext>>(m, "PSContext")
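
Note: the four new bindings follow the standard pybind11 pattern of Python name, pointer to a static function, and docstring. A self-contained sketch of the same shape, with illustrative stand-ins rather than MindSpore's real classes:

#include <cstdint>
#include <string>
#include <pybind11/pybind11.h>

// Illustrative stand-ins for CollectiveInitializer's static members.
static uint32_t GetRankID(const std::string &group) { (void)group; return 0; }
static uint32_t GetRankSize(const std::string &group) { (void)group; return 1; }

PYBIND11_MODULE(example, m) {
  (void)m.def("get_rank_id", &GetRankID, "Get rank id of the given group.");
  (void)m.def("get_rank_size", &GetRankSize, "Get rank size of the given group.");
}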

View File

@@ -1795,9 +1795,9 @@ void ClearResAtexit() {
g_args_cache.clear();
// clean static variable to prevent from crash. As static variable is released after
// Python threads is released.
MS_LOG(INFO) << "Start clear data_converter...";
MS_LOG(INFO) << "Start clear ClearObjectCache...";
parse::data_converter::ClearObjectCache();
MS_LOG(INFO) << "End clear data_converter...";
MS_LOG(INFO) << "End clear ClearObjectCache...";
MS_LOG(INFO) << "Start clear Parser...";
parse::Parser::CleanParserResource();

View File

@@ -883,20 +883,29 @@ bool AbstractNode::Disconnect(const std::shared_ptr<TcpClient> &client, const ui
return WaitForDisconnect(timeout);
}
-bool AbstractNode::WaitForDisconnect(const uint32_t &) {
+bool AbstractNode::WaitForDisconnect(const uint32_t &timeout) {
// If the cluster state is NODE_TIMEOUT, this node is already disconnected.
if (current_cluster_state_ == ClusterState::NODE_TIMEOUT) {
return true;
}
std::unique_lock<std::mutex> lock(wait_finish_mutex_);
-// Caller should use this method to help block the thread.
-wait_finish_cond_.wait(lock, [&] {
+auto condition_func = [&] {
if (is_finish_.load()) {
MS_LOG(INFO) << "The node id:" << node_info_.node_id_ << " is success finish!";
}
return is_finish_.load();
-});
-return true;
+};
+bool res;
+if (timeout == UINT32_MAX) {
+// Caller should use this method to help block the thread.
+wait_finish_cond_.wait(lock, condition_func);
+res = true;
+} else {
+res = wait_finish_cond_.wait_for(lock, std::chrono::seconds(timeout), condition_func);
+}
+return res;
}
void AbstractNode::InitClientToServer() {
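
Note: the refactor extracts the wait predicate into condition_func so it can feed both std::condition_variable::wait (unbounded, when timeout == UINT32_MAX) and wait_for (bounded; with a predicate it returns false if the predicate is still false when the time elapses). A self-contained sketch of that pattern, independent of the AbstractNode types:

#include <atomic>
#include <chrono>
#include <condition_variable>
#include <cstdint>
#include <mutex>

std::mutex mtx;
std::condition_variable cv;
std::atomic<bool> finished{false};

// Returns true if `finished` became true before the (optional) deadline.
bool WaitFinished(uint32_t timeout) {
  std::unique_lock<std::mutex> lock(mtx);
  auto done = [&] { return finished.load(); };
  if (timeout == UINT32_MAX) {
    cv.wait(lock, done);  // block until notified with the predicate true
    return true;
  }
  return cv.wait_for(lock, std::chrono::seconds(timeout), done);
}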

View File

@@ -29,6 +29,16 @@ void CollectiveFakeInitializer::FinalizeCollective() {
MS_LOG(EXCEPTION) << "You are trying to call 'init('nccl')', Please check "
"this MindSpore package is GPU version and built with NCCL.";
}
+uint32_t CollectiveFakeInitializer::GetRankID(const std::string &group_name) {
+MS_LOG(EXCEPTION) << "You are trying to call 'GetRankID', please check that "
+"this MindSpore package is the GPU version and built with NCCL.";
+}
+uint32_t CollectiveFakeInitializer::GetRankSize(const std::string &group_name) {
+MS_LOG(EXCEPTION) << "You are trying to call 'GetRankSize', please check that "
+"this MindSpore package is the GPU version and built with NCCL.";
+}
} // namespace gpu
} // namespace device
} // namespace mindspore

View File

@@ -17,6 +17,8 @@
#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_GPU_DISTRIBUTION_COLLECTIVE_FAKE_INIT_H_
#define MINDSPORE_CCSRC_RUNTIME_DEVICE_GPU_DISTRIBUTION_COLLECTIVE_FAKE_INIT_H_
#include "string"
namespace mindspore {
namespace device {
namespace gpu {
@@ -28,6 +30,8 @@ class CollectiveFakeInitializer {
CollectiveFakeInitializer &operator=(const CollectiveFakeInitializer &) = delete;
static void InitCollective();
static void FinalizeCollective();
+static uint32_t GetRankID(const std::string &group_name);
+static uint32_t GetRankSize(const std::string &group_name);
};
} // namespace gpu
} // namespace device
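
Note: the fake initializer mirrors the real class's static interface so CPU-only builds still compile the pybind bindings and fail with a clear message only if the functions are actually called. A minimal sketch of that stub pattern; the names and exception type are illustrative:

#include <cstdint>
#include <stdexcept>
#include <string>

// Stub matching the real backend's interface but always throwing.
class FakeCollective {
 public:
  static uint32_t GetRankID(const std::string &) {
    throw std::runtime_error("this package was built without NCCL support");
  }
  static uint32_t GetRankSize(const std::string &) {
    throw std::runtime_error("this package was built without NCCL support");
  }
};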

View File

@@ -68,6 +68,14 @@ void CollectiveInitializer::FinalizeCollective() {
}
}
+uint32_t CollectiveInitializer::GetRankID(const std::string &group_name) {
+return CollectiveInitializer::instance().GetRankIDByGroup(group_name);
+}
+uint32_t CollectiveInitializer::GetRankSize(const std::string &group_name) {
+return CollectiveInitializer::instance().GetGroupSize(group_name);
+}
uint32_t CollectiveInitializer::local_rank_id() {
uint32_t local_rank_id;
if (common::CheckUseMPI()) {
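
Note: GetRankID and GetRankSize are thin static wrappers over the singleton, giving pybind11 a plain function pointer to bind while the singleton keeps the per-process collective state. A sketch of that wrapper-over-singleton shape, with assumed member names:

#include <cstdint>
#include <string>

class Collective {
 public:
  static Collective &instance() {
    static Collective inst;  // one instance per process
    return inst;
  }
  // Static wrapper: bindable as a plain function pointer.
  static uint32_t GetRankID(const std::string &group) {
    return instance().RankByGroup(group);
  }

 private:
  uint32_t RankByGroup(const std::string &) { return 0; }  // placeholder lookup
};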

View File

@@ -41,6 +41,8 @@ class CollectiveInitializer {
const void *collective_handle();
static void InitCollective();
static void FinalizeCollective();
+static uint32_t GetRankID(const std::string &group_name);
+static uint32_t GetRankSize(const std::string &group_name);
// The capsulation of the collective communication APIs for compatibility.
uint32_t local_rank_id();

View File

@@ -17,6 +17,7 @@
from mindspore.parallel._ps_context import _is_role_pserver, _is_role_sched
from mindspore import log as logger
from ._hccl_management import load_lib as hccl_load_lib
+from .._c_expression import get_rank_id, get_rank_size
_HCCL_AVAILABLE = False
_HCCL_TEST_AVAILABLE = False
@@ -210,7 +211,7 @@ def _get_rank_helper(group, backend):
else:
rank_id = hccl.get_rank_id(group)
elif backend == Backend.NCCL:
-rank_id = mpi.get_rank_id(group)
+rank_id = get_rank_id(group)
else:
raise ValueError("For '_get_rank_helper', the argument 'backend' {} is not supported, "
"please use hccl_mpi, hccl or nccl.".format(backend))
@@ -275,7 +276,7 @@ def _get_size_helper(group, backend):
else:
size = hccl.get_rank_size(group)
elif backend == Backend.NCCL:
-size = mpi.get_rank_size(group)
+size = get_rank_size(group)
else:
raise ValueError("For '_get_size_helper', the argument 'backend' {} is not supported, "
"please use hccl or nccl.".format(backend))