failover support exit process when timeout occur

This commit is contained in:
lizhenyu 2022-03-18 11:04:53 +08:00
parent 201a89ddb8
commit 275f81c47b
32 changed files with 537 additions and 206 deletions

View File

@ -310,7 +310,6 @@ set(BACKEND_SUB_COMP
runtime/graph_scheduler
runtime/hardware
runtime/pynative
runtime/recovery
plugin/device/ascend/hal/device
plugin/device/ascend/hal/hardware
plugin/device/ascend/hal/hccl_adapter

View File

@ -37,7 +37,7 @@
#include "runtime/hardware/device_context_manager.h"
#include "runtime/graph_scheduler/graph_compiler.h"
#include "runtime/pynative/run_op_helper.h"
#include "runtime/recovery/recovery_context.h"
#include "distributed/recovery/recovery_context.h"
#include "include/common/utils/scoped_long_running.h"
#ifdef ENABLE_D
#include "include/common/utils/callbacks_ge.h"
@ -953,8 +953,8 @@ void MindRTBackend::RunGraph(const ActorInfo &actor_info, const VectorRef &args,
MS_EXCEPTION_IF_NULL(graph_compiler_);
graph_compiler_->Summary(graph_compiler_info.graphs_);
bool need_contruct_output = !(runtime::recovery::RecoveryContext::GetInstance()->enable_recovery() &&
runtime::recovery::RecoveryContext::GetInstance()->need_reset());
bool need_contruct_output = !(distributed::recovery::RecoveryContext::GetInstance()->enable_recovery() &&
distributed::recovery::RecoveryContext::GetInstance()->need_reset());
if (need_contruct_output) {
// Update device address for output node of graph.
// Summary processing will use the output device address, so must be after the summary processing.

View File

@ -15,18 +15,24 @@
*/
#include "distributed/collective/collective_manager.h"
#include <algorithm>
#include <string>
#include <vector>
#include <functional>
#include <csignal>
#include <memory>
#include "utils/ms_context.h"
#include "runtime/recovery/recovery_context.h"
#include "distributed/recovery/recovery_context.h"
namespace mindspore {
namespace distributed {
namespace collective {
using recovery::RecoveryContext;
CollectiveManager::CollectiveManager()
: inited_(false),
finalized_(true),
need_reinit_(false),
host_ctx_(nullptr),
device_ctx_(nullptr),
host_comm_lib_instance_(nullptr),
@ -60,8 +66,75 @@ std::shared_ptr<CollectiveManager> CollectiveManager::instance() {
return instance;
}
namespace {
// The wrapper to provide a timeout mechanism for executing functions.
bool ExecuteFuncInThread(const std::function<bool()> &func, const int64_t timeout) {
bool execute_success = false;
bool execute_fail = false;
std::mutex exec_ret_mutex;
std::condition_variable thread_blocker;
std::unique_ptr<std::thread> executive_thread = std::make_unique<std::thread>([&] {
if (!func()) {
MS_LOG(ERROR) << "Failed to execute function asynchronously";
std::unique_lock<std::mutex> lock(exec_ret_mutex);
execute_fail = true;
thread_blocker.notify_one();
return;
}
{
std::unique_lock<std::mutex> lock(exec_ret_mutex);
execute_success = true;
thread_blocker.notify_one();
}
});
executive_thread->detach();
std::unique_lock<std::mutex> locker(exec_ret_mutex);
(void)thread_blocker.wait_for(locker, std::chrono::seconds(timeout), [&] { return execute_success || execute_fail; });
if (!execute_success && !execute_fail) {
std::string node_id = common::GetEnv("MS_NODE_ID");
#if !defined(_WIN32) && !defined(_WIN64)
MS_LOG(ERROR) << "Execute function asynchronously timeout, node id: " << node_id << " exit process";
(void)kill(getpid(), SIGTERM);
#endif
}
return execute_success;
}
// In a disaster recovery scenario, the comparison between the current unique id and the last generated unique id
// ensures that the acquired unique id is newly generated, and the latest unique id will be persisted.
bool CheckUniqueIDLatest(const std::string &group_name, size_t root_info_size, const void *root_info) {
MS_EXCEPTION_IF_NULL(root_info);
auto persistent_json = RecoveryContext::GetInstance()->persistent_json();
MS_EXCEPTION_IF_NULL(persistent_json);
std::string new_unique_id(static_cast<const char *>(root_info), root_info_size);
std::vector<int> new_unique_id_integer_seq;
(void)std::transform(new_unique_id.begin(), new_unique_id.end(), std::back_inserter(new_unique_id_integer_seq),
[](char c) { return static_cast<int>(c); });
const char unique_id_str[] = "_unique_id";
std::string unique_id_key = group_name + unique_id_str;
if (!persistent_json->Exists(unique_id_key)) {
persistent_json->Insert(unique_id_key, new_unique_id_integer_seq);
return true;
}
std::vector<int> old_unique_id_integer_seq = persistent_json->Get<std::vector<int>>(unique_id_key);
if (new_unique_id_integer_seq == old_unique_id_integer_seq) {
return false;
}
persistent_json->Insert(unique_id_key, new_unique_id_integer_seq);
return true;
}
} // namespace
bool CollectiveManager::Initialize() {
if (inited_ && !runtime::recovery::RecoveryContext::GetInstance()->need_reinit_collective()) {
if (inited_ && !need_reinit_) {
return true;
}
@ -98,6 +171,8 @@ bool CollectiveManager::Initialize() {
MS_LOG(INFO) << "End initializing collective communication for backend: " << device_type_;
inited_ = true;
finalized_ = false;
need_reinit_ = false;
return true;
}
@ -125,50 +200,36 @@ bool CollectiveManager::CreateCommunicationGroup(const std::string &group_name,
void *root_info = group->GenerateRootInfo(&root_info_size);
MS_EXCEPTION_IF_NULL(root_info);
bool ret = false;
// Step 4: Broadcast the device root information to all nodes on host side.
if (!host_comm_lib_instance_->BroadcastUniqueID(group_name, is_root_node, root_info_size, root_info)) {
MS_LOG(ERROR) << "Broadcast for device root info failed on the host side.";
return false;
while (!ret) {
ret = host_comm_lib_instance_->BroadcastUniqueID(group_name, is_root_node, root_info_size, root_info);
if (!ret) {
MS_LOG(ERROR) << "Broadcast for device root info failed on the host side.";
return false;
}
// In disaster recovery scenarios, it is necessary to ensure that the unique id obtained from the Scheduler is a
// newly generated one.
if (RecoveryContext::GetInstance()->enable_recovery()) {
ret = CheckUniqueIDLatest(group_name, root_info_size, root_info);
}
}
// Step 5: Initialize communication group on the device side.
return InitDeviceCommGroup(group, root_info);
}
bool CollectiveManager::InitDeviceCommGroup(const CommunicationGroupPtr &group, void *root_info) {
bool init_group_success = false;
bool init_group_fail = false;
std::condition_variable thread_blocker;
init_group_thread_ = std::make_unique<std::thread>([&, this] {
std::function<bool()> init_device_comm_group_func = [&, this]() {
device_ctx_->Initialize();
if (!group->Initialize(root_info)) {
MS_LOG(ERROR) << "Initialize group on the device side failed.";
std::unique_lock<std::mutex> lock(init_group_mutex_);
init_group_fail = true;
thread_blocker.notify_one();
return;
}
return group->Initialize(root_info);
};
MS_LOG(INFO) << "Begin initialize communication group on the device side.";
{
std::unique_lock<std::mutex> lock(init_group_mutex_);
init_group_success = true;
thread_blocker.notify_one();
}
});
init_group_thread_->detach();
// Timeout limit 180 seconds to wait finishing init device communication group.
// Timeout limit 180 seconds to wait finish initializing device communication group.
const int64_t kTimeToWait = 180;
std::unique_lock<std::mutex> locker(init_group_mutex_);
(void)thread_blocker.wait_for(locker, std::chrono::seconds(kTimeToWait),
[&] { return init_group_success || init_group_fail; });
// Initialize communication group on the device side in thread with timeout limit.
ret = ExecuteFuncInThread(init_device_comm_group_func, kTimeToWait);
if (!init_group_success && runtime::recovery::RecoveryContext::GetInstance()->enable_recovery()) {
runtime::recovery::RecoveryContext::GetInstance()->set_recovery_status(
runtime::recovery::RecoveryErrCode::kInitNcclFailed);
MS_LOG(ERROR) << "Initialize group on the device side failed.";
}
return init_group_success;
MS_LOG(INFO) << "End initialize communication group on the device side.";
return ret;
}
bool CollectiveManager::DestroyCommunicationGroup(const std::string &group_name) {
@ -197,22 +258,34 @@ uint32_t CollectiveManager::GetGroupSize(const std::string &group_name) {
}
bool CollectiveManager::Finalize() {
if (finalized_) {
if (!inited_.load() || finalized_.load()) {
return true;
}
MS_EXCEPTION_IF_NULL(host_comm_lib_instance_);
if (!host_comm_lib_instance_->Finalize()) {
MS_LOG(WARNING) << "Failed to finalize host communication library.";
}
std::function<bool()> finalize_func = [&, this]() {
MS_EXCEPTION_IF_NULL(host_comm_lib_instance_);
if (!host_comm_lib_instance_->Finalize()) {
MS_LOG(WARNING) << "Failed to finalize host communication library.";
}
MS_EXCEPTION_IF_NULL(device_comm_lib_instance_);
if (!device_comm_lib_instance_->Finalize()) {
MS_LOG(WARNING) << "Failed to finalize device communication library.";
}
MS_EXCEPTION_IF_NULL(device_comm_lib_instance_);
if (!device_comm_lib_instance_->Finalize()) {
MS_LOG(WARNING) << "Failed to finalize device communication library.";
}
finalized_ = true;
return true;
finalized_ = true;
return true;
};
MS_LOG(INFO) << "Begin finalize collective manager.";
// Timeout limit 5 seconds to wait to finish finalizing device communication group.
const int64_t kTimeToWait = 5;
// Finalize collective manager in thread with timeout limit.
bool ret = ExecuteFuncInThread(finalize_func, kTimeToWait);
MS_LOG(INFO) << "End finalize collective manager.";
return ret;
}
void CollectiveManager::set_global_rank_id(uint32_t global_rank_id) { global_rank_id_ = global_rank_id; }

View File

@ -69,6 +69,11 @@ class BACKEND_EXPORT CollectiveManager {
uint32_t local_rank_id() const;
// Set whether need reinitialize collective communication.
void set_need_reinit(bool need_reinit) { need_reinit_ = need_reinit; }
// Get whether need reinitialize collective communication.
bool need_reinit() const { return need_reinit_.load(); }
private:
CollectiveManager();
@ -81,16 +86,13 @@ class BACKEND_EXPORT CollectiveManager {
// Assign the local rank id for this process.
bool AssignLocalRank();
// Initialize communication group on the device side.
bool InitDeviceCommGroup(const CommunicationGroupPtr &group, void *root_info);
// Initialize communication group on the device side in thread with timeout limit.
std::unique_ptr<std::thread> init_group_thread_;
std::mutex init_group_mutex_;
std::atomic_bool inited_;
std::atomic_bool finalized_;
// Whether need reinitialize collective communication, this value should be set to true once a training process
// exits unexpectedly is detected.
std::atomic_bool need_reinit_;
// The device type read from MindSpore context.
std::string device_type_;
@ -119,8 +121,8 @@ class BACKEND_EXPORT CollectiveManager {
// Global group ranks.
std::vector<uint32_t> global_group_ranks_;
// The global group name on the host side. This is used for Creating global group on host side for AllGather operation
// of host name while assigning local rank.
// The global group name on the host side. This is used for Creating global group on host side for AllGather
// operation of host name while assigning local rank.
std::string host_global_group_name_;
};
} // namespace collective

View File

@ -17,11 +17,11 @@
#include "distributed/init.h"
#include <vector>
#include <string>
#include "runtime/recovery/recovery_context.h"
#include "distributed/recovery/recovery_context.h"
namespace mindspore {
namespace distributed {
using runtime::recovery::RecoveryContext;
using distributed::recovery::RecoveryContext;
bool Initialize() {
if (!InitializeCluster()) {
@ -43,7 +43,8 @@ bool Initialize() {
collective::CollectiveManager::instance()->set_global_rank_size(abstract_node->worker_num());
if (RecoveryContext::GetInstance()->enable_recovery()) {
cluster::ClusterContext::instance()->WaitForClusterReady();
RecoveryContext::GetInstance()->set_global_rank_id(abstract_node->rank_id());
RecoveryContext::GetInstance()->set_global_rank_size(abstract_node->worker_num());
}
if (!InitializeCollective()) {
@ -52,8 +53,6 @@ bool Initialize() {
}
if (RecoveryContext::GetInstance()->enable_recovery()) {
RecoveryContext::GetInstance()->set_global_rank_id(abstract_node->rank_id());
RecoveryContext::GetInstance()->set_global_rank_size(abstract_node->worker_num());
RecoveryContext::GetInstance()->ObtainGlobalLatestCkptInfo();
}
}

View File

@ -14,7 +14,7 @@
* limitations under the License.
*/
#include "runtime/recovery/recovery_context.h"
#include "distributed/recovery/recovery_context.h"
#include <dirent.h>
#include <algorithm>
@ -26,13 +26,12 @@
#include "utils/file_utils.h"
#include "distributed/constants.h"
#include "distributed/cluster/topology/common.h"
#include "distributed/init.h"
#include "runtime/hardware/device_context.h"
#include "runtime/hardware/device_context_manager.h"
#include "utils/convert_utils_base.h"
#include "utils/ms_context.h"
namespace mindspore {
namespace runtime {
namespace distributed {
namespace recovery {
constexpr char kEnvEnableRecovery[] = "MS_ENABLE_RECOVERY";
constexpr char kEnvRecoveryPath[] = "MS_RECOVERY_PATH";
@ -142,30 +141,6 @@ void RecoveryContext::Initialize() {
initialized_ = true;
}
bool RecoveryContext::ReInitializeCollective() {
auto ret = distributed::Initialize();
if (ret) {
recovery_status_ = RecoveryErrCode::kUnKnownError;
set_need_reset(true);
set_need_sync_weight_to_device(true);
return true;
}
if (recovery_status_ == RecoveryErrCode::kBroadcastUniqueIDFailed ||
recovery_status_ == RecoveryErrCode::kAllGatherHostNameFailed) {
MS_LOG(WARNING) << "Prepare to initialize NCCL failed, retrying.";
// Retry duration: 30s.
const int kRetryDuration = 30;
std::this_thread::sleep_for(std::chrono::seconds(kRetryDuration));
return ReInitializeCollective();
} else if (recovery_status_ == RecoveryErrCode::kInitNcclFailed) {
MS_LOG(EXCEPTION) << "Initialize NCCL failed.";
}
MS_LOG(EXCEPTION) << "ReInitialize collective failed.";
return false;
}
void RecoveryContext::ObtainGlobalLatestCkptInfo() {
// 1. Obtain the step corresponding to the local latest checkpoint.
ObtainLocalLatestCkptInfo();
@ -326,6 +301,7 @@ void RecoveryContext::ParseLatestCkptInfo(const int *recv_buffer, const uint32_t
}
void RecoveryContext::CreatePersistentFile() {
std::unique_lock<std::mutex> lock(create_persist_json_mtx_);
if (node_role_ == distributed::kEnvRoleOfScheduler) {
return;
}
@ -344,7 +320,7 @@ void RecoveryContext::CreatePersistentFile() {
// The directory used to save ckpt is persisted to json file.
std::string persistent_file_path =
recovery_path_ + "/" + node_role_ + "_" + std::to_string(global_rank_id_) + "_persistent.json";
persistent_json_ = std::make_unique<JsonUtils>(persistent_file_path);
persistent_json_ = std::make_shared<JsonUtils>(persistent_file_path);
if (!persistent_json_->Initialize()) {
MS_LOG(EXCEPTION) << "Initialize json failed, file path: " << persistent_file_path;
}
@ -388,6 +364,24 @@ std::string RecoveryContext::GetCkptPath() {
return persistent_json_->Get<std::string>(kCkptPath);
}
const std::shared_ptr<JsonUtils> &RecoveryContext::persistent_json() {
if (persistent_json_ == nullptr) {
CreatePersistentFile();
}
MS_EXCEPTION_IF_NULL(persistent_json_);
return persistent_json_;
}
std::string RecoveryContext::latest_ckpt_file() {
// For standalone training.
if (enable_recovery_ && global_rank_size_ == 0 && latest_ckpt_file_.empty()) {
ObtainLocalLatestCkptInfo();
}
return latest_ckpt_file_;
}
} // namespace recovery
} // namespace runtime
} // namespace distributed
} // namespace mindspore

View File

@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_RUNTIME_RECOVERY_RECOVERY_H_
#define MINDSPORE_CCSRC_RUNTIME_RECOVERY_RECOVERY_H_
#ifndef MINDSPORE_CCSRC_DISTRIBUTED_RECOVERY_RECOVERY_H_
#define MINDSPORE_CCSRC_DISTRIBUTED_RECOVERY_RECOVERY_H_
#include <vector>
#include <string>
@ -27,13 +27,11 @@
#include "include/backend/visible.h"
namespace mindspore {
namespace runtime {
namespace distributed {
namespace recovery {
using distributed::storage::FileIOUtils;
using distributed::storage::JsonUtils;
enum class RecoveryErrCode { kUnKnownError, kAllGatherHostNameFailed, kBroadcastUniqueIDFailed, kInitNcclFailed };
// Used to save disaster recovery-related state quantities and provide disaster recovery-related
// functions, such as reinitializing collective communication, etc.
class BACKEND_EXPORT RecoveryContext {
@ -47,14 +45,6 @@ class BACKEND_EXPORT RecoveryContext {
}
~RecoveryContext() = default;
// Reinitializing collective communication.
bool ReInitializeCollective();
// Obtain the global step corresponding to the global latest checkpoint in each training process. Since there may be
// some processes that fails to save the checkpoint, it is necessary for AllGather to save the latest step of the
// successful checkpoint in each training process, and then take the minimum value as the final reset position.
void ObtainGlobalLatestCkptInfo();
// Get whether enable recovery or not.
bool enable_recovery() const { return enable_recovery_; }
@ -64,18 +54,14 @@ class BACKEND_EXPORT RecoveryContext {
// Get interval to persist model.
int recovery_interval() const { return recovery_interval_; }
// Get the error status of recovery.
RecoveryErrCode recovery_status() const { return recovery_status_; }
// Set the error status of recovery.
void set_recovery_status(RecoveryErrCode recovery_status) { recovery_status_ = recovery_status; }
// Set the path used to save checkpoint.
void SetCkptPath(const std::string &path);
// Get the path used to save checkpoint.
std::string GetCkptPath();
// Get the latest checkpoint in this node.
std::string latest_ckpt_file() const { return latest_ckpt_file_; }
std::string latest_ckpt_file();
// Get the epoch of latest checkpoint in this node.
int latest_ckpt_epoch() const { return latest_ckpt_epoch_; }
// Get the step of latest checkpoint in this node.
@ -99,10 +85,13 @@ class BACKEND_EXPORT RecoveryContext {
// Set global rank size.
void set_global_rank_size(uint32_t global_rank_size) { global_rank_size_ = global_rank_size; }
// Set whether need reinitialize collective communication.
void set_need_reinit_collective(bool need_reinit_collective) { need_reinit_collective_ = need_reinit_collective; }
// Get whether need reinitialize collective communication.
bool need_reinit_collective() const { return need_reinit_collective_.load(); }
// Obtain the global step corresponding to the global latest checkpoint in each training process. Since there may be
// some processes that fails to save the checkpoint, it is necessary for AllGather to save the latest step of the
// successful checkpoint in each training process, and then take the minimum value as the final reset position.
void ObtainGlobalLatestCkptInfo();
// Get the persistent json file pointer.
const std::shared_ptr<JsonUtils> &persistent_json();
private:
inline static std::shared_ptr<RecoveryContext> instance_{};
@ -155,20 +144,14 @@ class BACKEND_EXPORT RecoveryContext {
// performs load checkpoint.
bool need_sync_weight_to_device_{false};
// Whether need reinitialize collective communication, this value should be set to true once a training process
// exits unexpectedly is detected.
std::atomic_bool need_reinit_collective_{false};
// Whether the recovery context is already initialized.
bool initialized_{false};
// The error status of recovery.
RecoveryErrCode recovery_status_{RecoveryErrCode::kUnKnownError};
std::mutex create_persist_json_mtx_;
// The persitent json file util, used to persist recovery config.
std::unique_ptr<JsonUtils> persistent_json_;
std::shared_ptr<JsonUtils> persistent_json_;
};
} // namespace recovery
} // namespace runtime
} // namespace distributed
} // namespace mindspore
#endif // MINDSPORE_CCSRC_RUNTIME_RECOVERY_RECOVERY_H_
#endif // MINDSPORE_CCSRC_DISTRIBUTED_RECOVERY_RECOVERY_H_

View File

@ -40,7 +40,7 @@
#include "ps/util.h"
#endif
#include "ps/ps_context.h"
#include "runtime/recovery/recovery_context.h"
#include "distributed/recovery/recovery_context.h"
#include "pybind_api/gil_scoped_long_running.h"
@ -58,7 +58,7 @@ using ParallelContext = mindspore::parallel::ParallelContext;
using CostModelContext = mindspore::parallel::CostModelContext;
using mindspore::MsCtxParam;
using PSContext = mindspore::ps::PSContext;
using RecoveryContext = mindspore::runtime::recovery::RecoveryContext;
using RecoveryContext = mindspore::distributed::recovery::RecoveryContext;
// Interface with python
PYBIND11_MODULE(_c_expression, m) {

View File

@ -61,6 +61,7 @@
#include "runtime/device/kernel_runtime_manager.h"
#include "runtime/pynative/op_executor.h"
#include "runtime/device/stream_synchronizer.h"
#include "distributed/collective/collective_manager.h"
#ifndef ENABLE_SECURITY
#ifdef ENABLE_D
@ -1608,6 +1609,8 @@ void ClearResAtexit() {
device::StreamSynchronizer::GetInstance()->Finalize();
MS_LOG(INFO) << "End Finalize StreamSynchronizer...";
(void)distributed::collective::CollectiveManager::instance()->Finalize();
PrimitivePy::ClearHookRes();
ad::g_k_prims.clear();
ad::ClearKPynativeCellStaticRes();

View File

@ -35,7 +35,7 @@ using CommunicationGroupInfo = mindspore::fl::server::CommunicationGroupInfo;
using ps::core::NodeCommand;
// The time interval for send info or query info between worker and scheduler.
constexpr uint32_t kWaitDuration = 3;
constexpr uint32_t kWaitDuration = 5;
// The collective communication library for MindSpore self developed communication framework.
class MsCollectiveCommLib : public CollectiveCommunicationLib {

View File

@ -1301,6 +1301,7 @@ void AbstractNode::InitCommandHandler() {
handlers_[NodeCommand::SEND_EVENT] = nullptr;
RegisterActorRouteTableRspHandler();
RegisterInitCollectCommResphandler();
RegisterRecoveryRespHandler();
}
void AbstractNode::RegisterActorRouteTableRspHandler() {

View File

@ -228,6 +228,9 @@ class BACKEND_EXPORT AbstractNode : public Node {
// Register collective communication initialization response methods.
virtual void RegisterInitCollectCommResphandler() {}
// Register recovery response methods.
virtual void RegisterRecoveryRespHandler() {}
// when initializing the node, should initializing the node info.
void InitNodeInfo(const NodeRole &role);
// Initialize worker num and server num by cluster config.

View File

@ -139,6 +139,9 @@ bool AbstractPSNode::HandleHeartbeatTimeout() {
if (!stop_heartbeat_.load()) {
stop_heartbeat_ = true;
while (!heartbeat_stopped_.load()) {
if (is_finish_.load()) {
return;
}
MS_LOG(INFO) << "Waiting for heartbeat to stop...";
// Time interval for waiting the heartbeat to stop.
@ -152,6 +155,9 @@ bool AbstractPSNode::HandleHeartbeatTimeout() {
bool success = false;
while (!success) {
if (is_finish_.load()) {
return;
}
MS_LOG(WARNING) << "Trying to reconnect to the scheduler...";
success = InitClientToScheduler();
if (success) {
@ -176,6 +182,11 @@ void AbstractPSNode::RegisterInitCollectCommResphandler() {
handlers_[NodeCommand::SEND_UNIQUE_ID] = &AbstractPSNode::ProcessReceiveSchedulerResp;
handlers_[NodeCommand::QUERY_UNIQUE_ID] = &AbstractPSNode::ProcessReceiveSchedulerResp;
}
void AbstractPSNode::RegisterRecoveryRespHandler() {
handlers_[NodeCommand::SEND_FINISH_TRANSFORM] = &AbstractPSNode::ProcessReceiveSchedulerResp;
handlers_[NodeCommand::QUERY_FINISH_TRANSFORM] = &AbstractPSNode::ProcessReceiveSchedulerResp;
}
} // namespace core
} // namespace ps
} // namespace mindspore

View File

@ -38,6 +38,9 @@ class AbstractPSNode : public AbstractNode {
// Register collective communication initialization response methods.
void RegisterInitCollectCommResphandler() override;
// Register recovery response methods.
void RegisterRecoveryRespHandler() override;
// Indicate whether the heartbeat thread should be stopped.
std::atomic<bool> stop_heartbeat_{false};

View File

@ -67,7 +67,7 @@ uint32_t NodeManager::checkIfRankIdExist(const RegisterMessage &register_message
return rank_id;
}
} else {
ReAddNodeIfNotExists(node_id, register_message.ip(), register_message.port());
(void)ReAddNodeIfNotExists(node_id, register_message.ip(), register_message.port(), &rank_id);
}
return rank_id;
}
@ -206,11 +206,16 @@ std::vector<ServersMeta> NodeManager::FetchAllNodesMeta() {
return servers_meta_list;
}
const std::unordered_map<std::string, NodeInfo> &NodeManager::QueryTimeOutNodesInfo() const {
return timeout_nodes_info_;
}
void NodeManager::UpdateCluster() {
// 1. update cluster timeout state
struct timeval current_time {};
(void)gettimeofday(&current_time, nullptr);
timeout_nodes_info_.clear();
std::lock_guard<std::mutex> lock(heartbeat_mutex_);
for (auto it = heartbeats_.begin(); it != heartbeats_.end(); ++it) {
if (it->second.tv_sec + PSContext::instance()->cluster_config().heartbeat_timeout < current_time.tv_sec) {
if (registered_nodes_info_.count(it->first)) {

View File

@ -139,6 +139,8 @@ class NodeManager {
bool IsAllNodesAlive() const;
const std::unordered_map<std::string, NodeInfo> &QueryTimeOutNodesInfo() const;
private:
std::mutex node_mutex_;
std::mutex cluster_mutex_;

View File

@ -57,6 +57,12 @@ enum NodeCommand {
SEND_UNIQUE_ID = 20;
// Query unique id used to initialize collective communication.
QUERY_UNIQUE_ID = 21;
// Send the ready status to finish transform graph of computed node,
// used in disaster recovery mode to prevent timeout of waiting for graph transformation.
SEND_FINISH_TRANSFORM = 22;
// Query the ready status to finish transform graph of computed node,
// used in disaster recovery mode to prevent timeout of waiting for graph transformation.
QUERY_FINISH_TRANSFORM = 23;
}
enum NodeRole {
@ -298,3 +304,19 @@ message QueryUniqueIDRespMessage {
// The unique id used to initialize collective communication.
bytes unique_id = 2;
}
message SendFinishTransformMessage {
// the current Node unique id:0,1,2...
string node_id = 1;
// The rank id of the node in the cluster.
uint32 rank_id = 2;
// Whether finish transform graph.
bool is_ready = 3;
}
message QueryFinishTransformRespMessage {
// Whether all computed nodes are ready to run dag.
bool is_ready = 1;
// Whether there is any worker timeout.
bool is_worker_timeout = 2;
}

View File

@ -16,6 +16,7 @@
#include <memory>
#include "ps/core/ps_scheduler_node.h"
#include "utils/ms_context.h"
namespace mindspore {
namespace ps {
@ -58,6 +59,13 @@ void PSSchedulerNode::RegisterInitCollectCommServiceHandler() {
handlers_[NodeCommand::QUERY_UNIQUE_ID] = static_cast<ResponseHandler>(&PSSchedulerNode::ProcessQueryUniqueID);
}
void PSSchedulerNode::RegisterRecoveryServiceHandler() {
handlers_[NodeCommand::SEND_FINISH_TRANSFORM] =
static_cast<ResponseHandler>(&PSSchedulerNode::ProcessSendFinishTransform);
handlers_[NodeCommand::QUERY_FINISH_TRANSFORM] =
static_cast<ResponseHandler>(&PSSchedulerNode::ProcessQueryFinishTransform);
}
void PSSchedulerNode::ProcessSendHostName(const std::shared_ptr<TcpServer> &server,
const std::shared_ptr<TcpConnection> &conn,
const std::shared_ptr<MessageMeta> &meta, const void *data, size_t size) {
@ -71,7 +79,7 @@ void PSSchedulerNode::ProcessSendHostName(const std::shared_ptr<TcpServer> &serv
std::string node_id = send_host_name_msg.node_id();
uint32_t rank_id = send_host_name_msg.rank_id();
size_t host_hash_name = send_host_name_msg.host_hash_name();
MS_LOG(INFO) << "Received send host name request, node id: " << node_id << ", rank id: " << rank_id;
MS_LOG(INFO) << "Receive send host name request, node id: " << node_id << ", rank id: " << rank_id;
bool ret = false;
std::string error = "";
@ -100,7 +108,7 @@ void PSSchedulerNode::ProcessQueryHostNames(const std::shared_ptr<TcpServer> &se
query_msg.ParseFromArray(data, SizeToInt(size));
std::string node_id = query_msg.node_id();
uint32_t rank_id = query_msg.rank_id();
MS_LOG(INFO) << "Received query host name request, node id: " << node_id << ", rank id: " << rank_id;
MS_LOG(INFO) << "Receive query host name request, node id: " << node_id << ", rank id: " << rank_id;
bool is_success = recv_rank_id_send_host_name_.size() == host_hash_names_.size();
QueryHostHashNameRespMessage resp_msg;
@ -121,6 +129,7 @@ void PSSchedulerNode::ProcessQueryHostNames(const std::shared_ptr<TcpServer> &se
if (recv_rank_id_query_host_name_.size() == recv_rank_id_send_host_name_.size()) {
recv_rank_id_send_host_name_.clear();
recv_rank_id_query_host_name_.clear();
node_timeout_ = false;
}
}
}
@ -138,7 +147,7 @@ void PSSchedulerNode::ProcessSendUniqueID(const std::shared_ptr<TcpServer> &serv
std::string node_id = send_unique_id_msg.node_id();
uint32_t rank_id = send_unique_id_msg.rank_id();
std::string group_name = send_unique_id_msg.group_name();
MS_LOG(INFO) << "Received send unique id request, group name: " << group_name << ", node id: " << node_id
MS_LOG(INFO) << "Receive send unique id request, group name: " << group_name << ", node id: " << node_id
<< ", rank id: " << rank_id;
bool ret = false;
@ -169,7 +178,7 @@ void PSSchedulerNode::ProcessQueryUniqueID(const std::shared_ptr<TcpServer> &ser
std::string node_id = query_msg.node_id();
uint32_t rank_id = query_msg.rank_id();
std::string group_name = query_msg.group_name();
MS_LOG(INFO) << "Received query unique id request, group name: " << group_name << ", node id: " << node_id
MS_LOG(INFO) << "Receive query unique id request, group name: " << group_name << ", node id: " << node_id
<< ", rank id: " << rank_id;
auto iter = unique_id_group_.find(group_name);
@ -190,6 +199,89 @@ void PSSchedulerNode::ProcessQueryUniqueID(const std::shared_ptr<TcpServer> &ser
MS_LOG(INFO) << "Respond query unique id request, group name: " << group_name << ", node id: " << node_id
<< ", rank id: " << rank_id;
}
void PSSchedulerNode::ProcessSendFinishTransform(const std::shared_ptr<TcpServer> &server,
const std::shared_ptr<TcpConnection> &conn,
const std::shared_ptr<MessageMeta> &meta, const void *data,
size_t size) {
MS_ERROR_IF_NULL_WO_RET_VAL(server);
MS_ERROR_IF_NULL_WO_RET_VAL(conn);
MS_ERROR_IF_NULL_WO_RET_VAL(meta);
MS_ERROR_IF_NULL_WO_RET_VAL(data);
SendFinishTransformMessage send_ready_to_run_msg;
send_ready_to_run_msg.ParseFromArray(data, SizeToInt(size));
std::string node_id = send_ready_to_run_msg.node_id();
uint32_t rank_id = send_ready_to_run_msg.rank_id();
MS_LOG(INFO) << "Receive send finish transform request, node id: " << node_id << ", rank id: " << rank_id;
bool is_ready = send_ready_to_run_msg.is_ready();
if (is_ready) {
std::unique_lock<std::mutex> lock(nodes_finish_trans_mutex_);
(void)nodes_finish_trans_.insert(rank_id);
}
GeneralResponse(server, conn, meta, true, "");
MS_LOG(INFO) << "Respond send finish transform request, node id: " << node_id << ", rank id: " << rank_id;
}
void PSSchedulerNode::ProcessQueryFinishTransform(const std::shared_ptr<TcpServer> &server,
const std::shared_ptr<TcpConnection> &conn,
const std::shared_ptr<MessageMeta> &meta, const void *data,
size_t size) {
MS_ERROR_IF_NULL_WO_RET_VAL(server);
MS_ERROR_IF_NULL_WO_RET_VAL(conn);
MS_ERROR_IF_NULL_WO_RET_VAL(meta);
MS_ERROR_IF_NULL_WO_RET_VAL(data);
GeneralQueryMessage query_msg;
query_msg.ParseFromArray(data, SizeToInt(size));
std::string node_id = query_msg.node_id();
uint32_t rank_id = query_msg.rank_id();
MS_LOG(INFO) << "Receive query finish transform request, node id: " << node_id << ", rank id: " << rank_id;
std::unique_lock<std::mutex> lock(nodes_finish_trans_mutex_);
bool is_ready = nodes_finish_trans_.size() == worker_num_;
QueryFinishTransformRespMessage resp_msg;
resp_msg.set_is_ready(is_ready);
if (node_timeout_) {
(void)resp_msg.set_is_worker_timeout(true);
} else {
resp_msg.set_is_worker_timeout(false);
}
if (!server->SendMessage(conn, meta, Protos::PROTOBUF, resp_msg.SerializeAsString().data(),
resp_msg.ByteSizeLong())) {
MS_LOG(ERROR) << "Scheduler failed to respond message.";
return;
}
MS_LOG(INFO) << "Respond query finish transform request, node id: " << node_id << ", rank id: " << rank_id;
}
void PSSchedulerNode::HandleNodeTimeoutForRecovery(
const std::unordered_map<std::string, NodeInfo> &timeout_nodes_infos) {
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
if (!context_ptr->get_param<bool>(MS_CTX_ENABLE_RECOVERY)) {
return;
}
if (timeout_nodes_infos.empty()) {
return;
}
std::unique_lock<std::mutex> lock(nodes_finish_trans_mutex_);
node_timeout_ = true;
for (const auto &item : timeout_nodes_infos) {
(void)nodes_finish_trans_.erase(item.second.rank_id_);
}
}
void PSSchedulerNode::HandleNodeRecoverByHeartBeat(uint32_t rank_id) {
std::unique_lock<std::mutex> lock(nodes_finish_trans_mutex_);
(void)nodes_finish_trans_.insert(rank_id);
}
} // namespace core
} // namespace ps
} // namespace mindspore

View File

@ -18,6 +18,7 @@
#define MINDSPORE_CCSRC_PS_CORE_PS_SCHEDULER_NODE_H_
#include <map>
#include <unordered_map>
#include <memory>
#include <vector>
#include <set>
@ -47,9 +48,16 @@ class BACKEND_EXPORT PSSchedulerNode : public SchedulerNode {
// alive node should be rejected.
bool NeedRejectRegister(const NodeInfo &node_info) override { return node_info.is_alive; }
bool SendPrepareBuildingNetwork(const std::unordered_map<std::string, NodeInfo> &node_infos) override {
return true;
};
// Register collective communication initialization service.
void RegisterInitCollectCommServiceHandler() override;
// Register recovery service.
void RegisterRecoveryServiceHandler() override;
// Process message for sending node's host name.
void ProcessSendHostName(const std::shared_ptr<TcpServer> &server, const std::shared_ptr<TcpConnection> &conn,
const std::shared_ptr<MessageMeta> &meta, const void *data, size_t size);
@ -66,6 +74,20 @@ class BACKEND_EXPORT PSSchedulerNode : public SchedulerNode {
void ProcessQueryUniqueID(const std::shared_ptr<TcpServer> &server, const std::shared_ptr<TcpConnection> &conn,
const std::shared_ptr<MessageMeta> &meta, const void *data, size_t size);
// Process message for sending the ready status to finish transform graph of computed node,
void ProcessSendFinishTransform(const std::shared_ptr<TcpServer> &server, const std::shared_ptr<TcpConnection> &conn,
const std::shared_ptr<MessageMeta> &meta, const void *data, size_t size);
// Process message for querying the ready status to finish transform graph of computed node,
void ProcessQueryFinishTransform(const std::shared_ptr<TcpServer> &server, const std::shared_ptr<TcpConnection> &conn,
const std::shared_ptr<MessageMeta> &meta, const void *data, size_t size);
// Handle node timeout info and update nodes which finish transform graph.
void HandleNodeTimeoutForRecovery(const std::unordered_map<std::string, NodeInfo> &timeout_nodes_infos) override;
// Recover finish transform nodes info when nodes recover heartbeat.
void HandleNodeRecoverByHeartBeat(uint32_t rank_id) override;
// Record received host hash name from workers.
std::vector<size_t> host_hash_names_;
// Record rank id of the nodes which sended host name.
@ -77,6 +99,11 @@ class BACKEND_EXPORT PSSchedulerNode : public SchedulerNode {
std::map<std::string, std::string> unique_id_group_;
uint32_t worker_num_;
std::mutex nodes_finish_trans_mutex_;
// Record the rank ids of nodes who finish transform graph.
std::set<uint32_t> nodes_finish_trans_;
std::atomic_bool node_timeout_{false};
};
} // namespace core
} // namespace ps

View File

@ -39,16 +39,13 @@ bool PSServerNode::Start(const uint32_t &timeout) {
}
void PSServerNode::Initialize() {
InitNodeNum();
config_ = std::make_unique<FileConfiguration>(PSContext::instance()->config_file_path());
MS_EXCEPTION_IF_NULL(config_);
if (!config_->Initialize()) {
MS_LOG(INFO) << "The config file is empty, then init node by context.";
InitNodeNum();
} else {
if (!Recover()) {
MS_LOG(WARNING) << "Recover the server node is failed.";
}
if (config_->Initialize() && !Recover()) {
MS_LOG(INFO) << "Recover the server node is failed.";
}
InitServerHandler();
CreateTcpServer();
InitNodeInfo(NodeRole::SERVER);
@ -119,8 +116,9 @@ void PSServerNode::Register(const std::shared_ptr<TcpClient> &client) {
MS_LOG(INFO) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
<< " the node id:" << node_info_.node_id_ << " begin to register to the scheduler!";
const int kCommTimeoutInSeconds = 20;
if (!SendMessageSync(client, message_meta, Protos::PROTOBUF, register_message.SerializeAsString().data(),
register_message.ByteSizeLong())) {
register_message.ByteSizeLong(), kCommTimeoutInSeconds)) {
MS_LOG(EXCEPTION) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
<< " the node id:" << node_info_.node_id_ << " register timeout!";
} else {

View File

@ -79,16 +79,13 @@ bool PSWorkerNode::Finish(const uint32_t &timeout) {
}
void PSWorkerNode::Initialize() {
InitNodeNum();
config_ = std::make_unique<FileConfiguration>(PSContext::instance()->config_file_path());
MS_EXCEPTION_IF_NULL(config_);
if (!config_->Initialize()) {
MS_LOG(INFO) << "The config file is empty, then init node by context.";
InitNodeNum();
} else {
if (!Recover()) {
MS_LOG(WARNING) << "Recover the worker node is failed.";
}
if (config_->Initialize() && !Recover()) {
MS_LOG(INFO) << "Recover the worker node is failed.";
}
InitServerHandler();
CreateTcpServer();
InitNodeInfo(NodeRole::WORKER);
@ -120,8 +117,9 @@ void PSWorkerNode::Register(const std::shared_ptr<TcpClient> &client) {
MS_LOG(INFO) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
<< " the node id:" << node_info_.node_id_ << " begin to register to the scheduler!";
const int kCommTimeoutInSeconds = 20;
if (!SendMessageSync(client, message_meta, Protos::PROTOBUF, register_message.SerializeAsString().data(),
register_message.ByteSizeLong())) {
register_message.ByteSizeLong(), kCommTimeoutInSeconds)) {
MS_LOG(EXCEPTION) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
<< " the node id:" << node_info_.node_id_ << " register timeout!";
} else {

View File

@ -176,10 +176,12 @@ void SchedulerNode::ProcessHeartbeat(const std::shared_ptr<TcpServer> &server,
return;
}
uint32_t rank_id = UINT32_MAX;
// Re-Add the missing node into node manager.
if (heartbeat_message.has_address() &&
node_manager_.ReAddNodeIfNotExists(node_id, heartbeat_message.ip(), heartbeat_message.port())) {
node_manager_.ReAddNodeIfNotExists(node_id, heartbeat_message.ip(), heartbeat_message.port(), &rank_id)) {
SetRegisterConnectionFd(conn, node_id);
HandleNodeRecoverByHeartBeat(rank_id);
if (node_manager_.IsAllNodesRegistered()) {
is_ready_ = true;
@ -234,6 +236,7 @@ void SchedulerNode::InitCommandHandler() {
handlers_[NodeCommand::SEND_EVENT] = &SchedulerNode::ProcessSendEvent;
RegisterActorRouteTableServiceHandler();
RegisterInitCollectCommServiceHandler();
RegisterRecoveryServiceHandler();
}
void SchedulerNode::RegisterActorRouteTableServiceHandler() {
@ -699,6 +702,7 @@ void SchedulerNode::StartUpdateClusterStateTimer() {
}
std::this_thread::sleep_for(std::chrono::seconds(PSContext::instance()->cluster_config().heartbeat_interval));
node_manager_.UpdateCluster();
HandleNodeTimeoutForRecovery(node_manager_.QueryTimeOutNodesInfo());
if (node_manager_.GetClusterState() == ClusterState::CLUSTER_EXIT) {
std::this_thread::sleep_for(

View File

@ -92,6 +92,15 @@ class BACKEND_EXPORT SchedulerNode : public Node {
// Register collective communication initialization service.
virtual void RegisterInitCollectCommServiceHandler() {}
// Register recovery service.
virtual void RegisterRecoveryServiceHandler() {}
// Handle node timeout info and update nodes which finish transform graph.
virtual void HandleNodeTimeoutForRecovery(const std::unordered_map<std::string, NodeInfo> &timeout_nodes_infos) {}
// Recover finish transform nodes info when nodes recover heartbeat.
virtual void HandleNodeRecoverByHeartBeat(uint32_t rank_id) {}
const std::shared_ptr<TcpClient> &GetOrCreateClient(const NodeInfo &node_info);
void ProcessHeartbeat(const std::shared_ptr<TcpServer> &server, const std::shared_ptr<TcpConnection> &conn,
@ -197,7 +206,7 @@ class BACKEND_EXPORT SchedulerNode : public Node {
void SetRegisterConnectionFd(const std::shared_ptr<TcpConnection> &conn, const std::string &node_id);
bool SendPrepareBuildingNetwork(const std::unordered_map<std::string, NodeInfo> &node_infos);
virtual bool SendPrepareBuildingNetwork(const std::unordered_map<std::string, NodeInfo> &node_infos);
// Responding peer with the general response message.
void GeneralResponse(const std::shared_ptr<TcpServer> &server, const std::shared_ptr<TcpConnection> &conn,

View File

@ -17,16 +17,19 @@
#include "runtime/device/stream_synchronizer.h"
#include "utils/ms_context.h"
#include "distributed/collective/collective_manager.h"
#include "runtime/recovery/recovery_context.h"
#include "distributed/recovery/recovery_context.h"
namespace mindspore {
namespace device {
using distributed::collective::CollectiveManager;
using distributed::recovery::RecoveryContext;
std::mutex StreamSynchronizer::instance_lock_;
std::shared_ptr<StreamSynchronizer> StreamSynchronizer::instance_ = nullptr;
void StreamSynchronizer::Initialize() {
// Non disaster recovery mode does not need to start thread and timeout mechanisms.
if (!runtime::recovery::RecoveryContext::GetInstance()->enable_recovery()) {
if (!RecoveryContext::GetInstance()->enable_recovery()) {
return;
}
@ -56,7 +59,7 @@ bool StreamSynchronizer::SyncStream(const std::string &device_name, uint32_t tim
MS_EXCEPTION_IF_NULL(device_context);
// If disable recovery or timeout==0, sync stream directly to improve performance.
if (!runtime::recovery::RecoveryContext::GetInstance()->enable_recovery() || timeout == 0) {
if (!RecoveryContext::GetInstance()->enable_recovery() || timeout == 0) {
device_context->Initialize();
return device_context->SyncStream();
}
@ -68,26 +71,19 @@ bool StreamSynchronizer::SyncStream(const std::string &device_name, uint32_t tim
device_context_ = device_context;
do_sync_stream_cv_.notify_one();
if (sync_stream_time_out_) {
// If sync stream timeout has happened, increase the timeout by 4 times.
const uint32_t kTimeOutScaleFactor = 4;
timeout *= kTimeOutScaleFactor;
}
if (time_out_cv_.wait_for(lock, std::chrono::seconds(timeout)) == std::cv_status::no_timeout) {
if (!sync_stream_ret_) {
MS_LOG(ERROR) << "Synchronize stream failed.";
}
return sync_stream_ret_;
} else {
sync_stream_time_out_ = true;
runtime::recovery::RecoveryContext::GetInstance()->set_need_reinit_collective(true);
if (!distributed::collective::CollectiveManager::instance()->Finalize()) {
CollectiveManager::instance()->set_need_reinit(true);
if (!CollectiveManager::instance()->Finalize()) {
MS_LOG(ERROR) << "Finalize collective manager failed.";
return false;
}
time_out_cv_.wait(lock, [this]() { return device_context_ == nullptr; });
MS_LOG(WARNING) << "Synchronize stream time out.";
MS_LOG(WARNING) << "Synchronize stream timeout.";
return true;
}
}

View File

@ -71,9 +71,6 @@ class BACKEND_EXPORT StreamSynchronizer {
// The singleton pointer.
static std::shared_ptr<StreamSynchronizer> instance_;
// Record whether the synchronization stream task has timed out.
bool sync_stream_time_out_{false};
// Return value of synchronization stream.
bool sync_stream_ret_{false};

View File

@ -25,11 +25,11 @@
#include "mindrt/include/async/async.h"
#include "utils/log_adapter.h"
#include "include/common/utils/convert_utils.h"
#include "runtime/recovery/recovery_context.h"
#include "distributed/recovery/recovery_context.h"
namespace mindspore {
namespace runtime {
using recovery::RecoveryContext;
using distributed::recovery::RecoveryContext;
namespace {
void SyncTensorData(const TensorPtr &host_tensor, const DeviceTensorPtr &device_tensor, const AnfNodePtr &node,
const DeviceContext *device_context, OpContext<DeviceTensor> *const context,

View File

@ -21,11 +21,13 @@
#include "runtime/graph_scheduler/actor/debug_actor.h"
#include "mindrt/include/async/async.h"
#include "utils/log_adapter.h"
#include "runtime/recovery/recovery_context.h"
#include "distributed/recovery/recovery_context.h"
#include "distributed/collective/collective_manager.h"
namespace mindspore {
namespace runtime {
using recovery::RecoveryContext;
using distributed::collective::CollectiveManager;
using distributed::recovery::RecoveryContext;
void KernelActor::Init() {
// Check device contexts number.
@ -243,7 +245,7 @@ void KernelActor::OnMemoryAllocFinish(OpContext<DeviceTensor> *const context) {
PreLaunchKernel(context);
try {
if (RecoveryContext::GetInstance()->enable_recovery() && RecoveryContext::GetInstance()->need_reinit_collective()) {
if (RecoveryContext::GetInstance()->enable_recovery() && CollectiveManager::instance()->need_reinit()) {
// In disaster recovery scenarios, run dag in this step failed, the rest operators of graph do not need launch,
// especially the collective communication operators.
MS_LOG(WARNING) << "Collective communication need reinitialize, skip launch kernel: "

View File

@ -25,11 +25,13 @@
#include "mindrt/include/async/async.h"
#include "utils/log_adapter.h"
#include "runtime/device/stream_synchronizer.h"
#include "runtime/recovery/recovery_context.h"
#include "distributed/recovery/recovery_context.h"
#include "distributed/collective/collective_manager.h"
namespace mindspore {
namespace runtime {
using recovery::RecoveryContext;
using distributed::collective::CollectiveManager;
using distributed::recovery::RecoveryContext;
void LoopCountActor::Run(OpContext<DeviceTensor> *const context) {
MS_EXCEPTION_IF_NULL(context);
@ -70,8 +72,7 @@ void LoopCountActor::IncreaseLoopCount(OpContext<DeviceTensor> *const context) {
(void)sync_stream_device_contexts.insert(device_context);
// Trigger disaster recovery and exit loop early.
if (RecoveryContext::GetInstance()->enable_recovery() &&
RecoveryContext::GetInstance()->need_reinit_collective()) {
if (RecoveryContext::GetInstance()->enable_recovery() && CollectiveManager::instance()->need_reinit()) {
current_count_ = loop_count_;
}
}

View File

@ -17,11 +17,13 @@
#include "runtime/graph_scheduler/actor/output_actor.h"
#include "runtime/hardware/device_context_manager.h"
#include "utils/log_adapter.h"
#include "runtime/recovery/recovery_context.h"
#include "distributed/recovery/recovery_context.h"
#include "distributed/collective/collective_manager.h"
namespace mindspore {
namespace runtime {
using recovery::RecoveryContext;
using distributed::collective::CollectiveManager;
using distributed::recovery::RecoveryContext;
bool IsOutputAddressPersisted(const DeviceTensor *output_device_tensor, const AnfNodePtr &output_node) {
MS_EXCEPTION_IF_NULL(output_node);
@ -105,7 +107,7 @@ void OutputActor::RunOpControl(AID *const, OpContext<DeviceTensor> *const contex
++current_count_;
// Trigger disaster recovery and return empty output.
if (RecoveryContext::GetInstance()->enable_recovery() && RecoveryContext::GetInstance()->need_reinit_collective()) {
if (RecoveryContext::GetInstance()->enable_recovery() && CollectiveManager::instance()->need_reinit()) {
FreeOutputNodeMem();
ClearOutputCache();
SET_OPCONTEXT_SUCCESS_RET((*context));

View File

@ -45,11 +45,19 @@
#endif
#include "profiler/device/profiling.h"
#include "include/common/debug/common.h"
#include "runtime/recovery/recovery_context.h"
#include "distributed/recovery/recovery_context.h"
#include "distributed/collective/collective_manager.h"
#if ((defined ENABLE_CPU) && (!defined _WIN32) && (!defined _WIN64))
#include "distributed/cluster/cluster_context.h"
#else
#include "distributed/cluster/dummy_cluster_context.h"
#endif
namespace mindspore {
namespace runtime {
using recovery::RecoveryContext;
using distributed::cluster::ClusterContext;
using distributed::collective::CollectiveManager;
using distributed::recovery::RecoveryContext;
namespace {
bool IsNeedInsertCopyActor(const DeviceContext *from_device_context, const DeviceContext *to_device_context) {
MS_EXCEPTION_IF_NULL(from_device_context);
@ -196,6 +204,108 @@ void IntHandler(int, siginfo_t *, void *) {
(void)kill(this_pid, SIGTERM);
}
#endif
#if ((defined ENABLE_CPU) && (!defined _WIN32) && (!defined _WIN64))
bool SendFinishTransform() {
auto node = ClusterContext::instance()->node();
MS_EXCEPTION_IF_NULL(node);
if (node->role() != ps::core::NodeRole::WORKER) {
return true;
}
auto abstract_node = std::dynamic_pointer_cast<ps::core::AbstractNode>(ClusterContext::instance()->node());
MS_EXCEPTION_IF_NULL(abstract_node);
ps::core::SendFinishTransformMessage send_ready_to_run_msg;
send_ready_to_run_msg.set_node_id(abstract_node->node_id());
send_ready_to_run_msg.set_rank_id(abstract_node->rank_id());
send_ready_to_run_msg.set_is_ready(true);
std::shared_ptr<std::vector<unsigned char>> output = nullptr;
if (!abstract_node->SendToScheduler(send_ready_to_run_msg.SerializeAsString().data(),
send_ready_to_run_msg.SerializeAsString().size(),
ps::core::NodeCommand::SEND_FINISH_TRANSFORM, &output)) {
MS_LOG(WARNING) << "Failed to send finish transform request to scheduler.";
return false;
}
ps::core::GeneralResponseMsg resp_msg;
MS_EXCEPTION_IF_NULL(output);
(void)resp_msg.ParseFromArray(output->data(), SizeToInt(output->size()));
if (!resp_msg.is_success()) {
MS_LOG(ERROR) << "Send finish transform to scheduler failed.";
return false;
}
return true;
}
bool QueryFinishTransform() {
auto node = ClusterContext::instance()->node();
MS_EXCEPTION_IF_NULL(node);
if (node->role() != ps::core::NodeRole::WORKER) {
return true;
}
auto abstract_node = std::dynamic_pointer_cast<ps::core::AbstractNode>(ClusterContext::instance()->node());
MS_EXCEPTION_IF_NULL(abstract_node);
ps::core::GeneralQueryMessage general_query_msg;
general_query_msg.set_node_id(abstract_node->node_id());
general_query_msg.set_rank_id(abstract_node->rank_id());
std::shared_ptr<std::vector<unsigned char>> output = nullptr;
bool ret = false;
while (!ret) {
if (!abstract_node->SendToScheduler(general_query_msg.SerializeAsString().data(),
general_query_msg.SerializeAsString().size(),
ps::core::NodeCommand::QUERY_FINISH_TRANSFORM, &output)) {
MS_LOG(WARNING) << "Failed to send query finish transform request to scheduler.";
ret = false;
continue;
}
ps::core::QueryFinishTransformRespMessage resp_msg;
MS_EXCEPTION_IF_NULL(output);
(void)resp_msg.ParseFromArray(output->data(), SizeToInt(output->size()));
ret = resp_msg.is_ready();
if (!ret) {
MS_LOG(INFO) << "There is worker which has not finished transform graph";
}
if (resp_msg.is_worker_timeout()) {
MS_LOG(WARNING) << "There is worker timeout";
return false;
}
// The time interval for querying the all worker finish transform graphs status to scheduler: 10 seconds.
const uint32_t kWaitDuration = 10;
std::this_thread::sleep_for(std::chrono::seconds(kWaitDuration));
}
return ret;
}
void DoDisasterRecovery() {
if (RecoveryContext::GetInstance()->enable_recovery() && CollectiveManager::instance()->need_reinit()) {
MS_LOG(INFO) << "Begin reinitialize collective communication for recovery.";
bool ret = false;
while (!ret) {
while (!CollectiveManager::instance()->Initialize()) {
MS_LOG(WARNING) << "ReInitialize collective communication failed, retrying...";
}
MS_LOG(INFO) << "Finish reinitialize collective communication for recovery.";
RecoveryContext::GetInstance()->ObtainGlobalLatestCkptInfo();
ret = QueryFinishTransform();
if (!ret) {
CollectiveManager::instance()->set_need_reinit(true);
(void)CollectiveManager::instance()->Finalize();
}
}
RecoveryContext::GetInstance()->set_need_reset(true);
RecoveryContext::GetInstance()->set_need_sync_weight_to_device(true);
}
}
#endif
} // namespace
GraphScheduler &GraphScheduler::GetInstance() noexcept {
@ -407,6 +517,17 @@ ActorSet *GraphScheduler::Transform(const GraphCompilerInfo &graph_compiler_info
}
MS_LOG(INFO) << "Graph(" << graph_compiler_info.name_ << ") transforms actor end.";
#if ((defined ENABLE_CPU) && (!defined _WIN32) && (!defined _WIN64))
if (ClusterContext::instance()->initialized() && RecoveryContext::GetInstance()->enable_recovery()) {
while (!SendFinishTransform()) {
MS_LOG(WARNING) << "Send finish transform graph failed.";
// The time interval for sending finish transform graph to scheduler.
constexpr uint32_t kWaitDuration = 10;
std::this_thread::sleep_for(std::chrono::seconds(kWaitDuration));
}
}
#endif
return actor_set.get();
}
@ -489,15 +610,9 @@ void GraphScheduler::Run(ActorSet *const actor_set, const std::vector<DeviceCont
const size_t kSecondsToMilliseconds = 1000;
SetActorExecutionStrategy(actor_set, strategy, (end_time - start_time) * kSecondsToMilliseconds);
if (RecoveryContext::GetInstance()->enable_recovery() && RecoveryContext::GetInstance()->need_reinit_collective()) {
MS_LOG(INFO) << "Begin reinitialize collective communication for recovery.";
if (!RecoveryContext::GetInstance()->ReInitializeCollective()) {
MS_LOG(EXCEPTION) << "Reinitialize collective communication failed.";
}
MS_LOG(INFO) << "Finish reinitialize collective communication for recovery.";
RecoveryContext::GetInstance()->set_need_reinit_collective(false);
}
#if ((defined ENABLE_CPU) && (!defined _WIN32) && (!defined _WIN64))
DoDisasterRecovery();
#endif
}
void GraphScheduler::SetActorExecutionStrategy(ActorSet *const actor_set, GraphExecutionStrategy strategy,

View File

@ -1,9 +0,0 @@
file(GLOB_RECURSE RECOVERY_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
"recovery_context.cc")
if(${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-delete-abstract-non-virtual-dtor")
endif()
set_property(SOURCE ${RECOVERY_SRC_LIST} PROPERTY SUBMODULE_ID=mindspore::SubModuleId::SM_DEVICE)
add_library(_mindspore_runtime_recovery_obj OBJECT ${RECOVERY_SRC_LIST})

View File

@ -306,8 +306,7 @@ add_library(backend_static STATIC
$<TARGET_OBJECTS:_mindspore_runtime_device_obj>
$<TARGET_OBJECTS:_mindspore_runtime_graph_scheduler_obj>
$<TARGET_OBJECTS:_mindspore_runtime_hardware_obj>
$<TARGET_OBJECTS:_mindspore_runtime_pynative_obj>
$<TARGET_OBJECTS:_mindspore_runtime_recovery_obj>)
$<TARGET_OBJECTS:_mindspore_runtime_pynative_obj>)
target_link_libraries(ut_tests PRIVATE mindspore securec -Wl,--start-group proto_input mindspore::protobuf
backend_static -Wl,--end-group)
target_link_libraries(ut_tests PRIVATE mindspore::grpc++)