failover support exit process when timeout occur
This commit is contained in:
parent
201a89ddb8
commit
275f81c47b
|
@ -310,7 +310,6 @@ set(BACKEND_SUB_COMP
|
|||
runtime/graph_scheduler
|
||||
runtime/hardware
|
||||
runtime/pynative
|
||||
runtime/recovery
|
||||
plugin/device/ascend/hal/device
|
||||
plugin/device/ascend/hal/hardware
|
||||
plugin/device/ascend/hal/hccl_adapter
|
||||
|
|
|
@ -37,7 +37,7 @@
|
|||
#include "runtime/hardware/device_context_manager.h"
|
||||
#include "runtime/graph_scheduler/graph_compiler.h"
|
||||
#include "runtime/pynative/run_op_helper.h"
|
||||
#include "runtime/recovery/recovery_context.h"
|
||||
#include "distributed/recovery/recovery_context.h"
|
||||
#include "include/common/utils/scoped_long_running.h"
|
||||
#ifdef ENABLE_D
|
||||
#include "include/common/utils/callbacks_ge.h"
|
||||
|
@ -953,8 +953,8 @@ void MindRTBackend::RunGraph(const ActorInfo &actor_info, const VectorRef &args,
|
|||
MS_EXCEPTION_IF_NULL(graph_compiler_);
|
||||
graph_compiler_->Summary(graph_compiler_info.graphs_);
|
||||
|
||||
bool need_contruct_output = !(runtime::recovery::RecoveryContext::GetInstance()->enable_recovery() &&
|
||||
runtime::recovery::RecoveryContext::GetInstance()->need_reset());
|
||||
bool need_contruct_output = !(distributed::recovery::RecoveryContext::GetInstance()->enable_recovery() &&
|
||||
distributed::recovery::RecoveryContext::GetInstance()->need_reset());
|
||||
if (need_contruct_output) {
|
||||
// Update device address for output node of graph.
|
||||
// Summary processing will use the output device address, so must be after the summary processing.
|
||||
|
|
|
@ -15,18 +15,24 @@
|
|||
*/
|
||||
|
||||
#include "distributed/collective/collective_manager.h"
|
||||
#include <algorithm>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <functional>
|
||||
#include <csignal>
|
||||
#include <memory>
|
||||
#include "utils/ms_context.h"
|
||||
#include "runtime/recovery/recovery_context.h"
|
||||
#include "distributed/recovery/recovery_context.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace distributed {
|
||||
namespace collective {
|
||||
using recovery::RecoveryContext;
|
||||
|
||||
CollectiveManager::CollectiveManager()
|
||||
: inited_(false),
|
||||
finalized_(true),
|
||||
need_reinit_(false),
|
||||
host_ctx_(nullptr),
|
||||
device_ctx_(nullptr),
|
||||
host_comm_lib_instance_(nullptr),
|
||||
|
@ -60,8 +66,75 @@ std::shared_ptr<CollectiveManager> CollectiveManager::instance() {
|
|||
return instance;
|
||||
}
|
||||
|
||||
namespace {
|
||||
// The wrapper to provide a timeout mechanism for executing functions.
|
||||
bool ExecuteFuncInThread(const std::function<bool()> &func, const int64_t timeout) {
|
||||
bool execute_success = false;
|
||||
bool execute_fail = false;
|
||||
std::mutex exec_ret_mutex;
|
||||
std::condition_variable thread_blocker;
|
||||
|
||||
std::unique_ptr<std::thread> executive_thread = std::make_unique<std::thread>([&] {
|
||||
if (!func()) {
|
||||
MS_LOG(ERROR) << "Failed to execute function asynchronously";
|
||||
std::unique_lock<std::mutex> lock(exec_ret_mutex);
|
||||
execute_fail = true;
|
||||
thread_blocker.notify_one();
|
||||
return;
|
||||
}
|
||||
|
||||
{
|
||||
std::unique_lock<std::mutex> lock(exec_ret_mutex);
|
||||
execute_success = true;
|
||||
thread_blocker.notify_one();
|
||||
}
|
||||
});
|
||||
executive_thread->detach();
|
||||
|
||||
std::unique_lock<std::mutex> locker(exec_ret_mutex);
|
||||
(void)thread_blocker.wait_for(locker, std::chrono::seconds(timeout), [&] { return execute_success || execute_fail; });
|
||||
|
||||
if (!execute_success && !execute_fail) {
|
||||
std::string node_id = common::GetEnv("MS_NODE_ID");
|
||||
#if !defined(_WIN32) && !defined(_WIN64)
|
||||
MS_LOG(ERROR) << "Execute function asynchronously timeout, node id: " << node_id << " exit process";
|
||||
(void)kill(getpid(), SIGTERM);
|
||||
#endif
|
||||
}
|
||||
return execute_success;
|
||||
}
|
||||
|
||||
// In a disaster recovery scenario, the comparison between the current unique id and the last generated unique id
|
||||
// ensures that the acquired unique id is newly generated, and the latest unique id will be persisted.
|
||||
bool CheckUniqueIDLatest(const std::string &group_name, size_t root_info_size, const void *root_info) {
|
||||
MS_EXCEPTION_IF_NULL(root_info);
|
||||
auto persistent_json = RecoveryContext::GetInstance()->persistent_json();
|
||||
MS_EXCEPTION_IF_NULL(persistent_json);
|
||||
|
||||
std::string new_unique_id(static_cast<const char *>(root_info), root_info_size);
|
||||
std::vector<int> new_unique_id_integer_seq;
|
||||
(void)std::transform(new_unique_id.begin(), new_unique_id.end(), std::back_inserter(new_unique_id_integer_seq),
|
||||
[](char c) { return static_cast<int>(c); });
|
||||
|
||||
const char unique_id_str[] = "_unique_id";
|
||||
std::string unique_id_key = group_name + unique_id_str;
|
||||
if (!persistent_json->Exists(unique_id_key)) {
|
||||
persistent_json->Insert(unique_id_key, new_unique_id_integer_seq);
|
||||
return true;
|
||||
}
|
||||
|
||||
std::vector<int> old_unique_id_integer_seq = persistent_json->Get<std::vector<int>>(unique_id_key);
|
||||
if (new_unique_id_integer_seq == old_unique_id_integer_seq) {
|
||||
return false;
|
||||
}
|
||||
|
||||
persistent_json->Insert(unique_id_key, new_unique_id_integer_seq);
|
||||
return true;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
bool CollectiveManager::Initialize() {
|
||||
if (inited_ && !runtime::recovery::RecoveryContext::GetInstance()->need_reinit_collective()) {
|
||||
if (inited_ && !need_reinit_) {
|
||||
return true;
|
||||
}
|
||||
|
||||
|
@ -98,6 +171,8 @@ bool CollectiveManager::Initialize() {
|
|||
MS_LOG(INFO) << "End initializing collective communication for backend: " << device_type_;
|
||||
inited_ = true;
|
||||
finalized_ = false;
|
||||
need_reinit_ = false;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
|
@ -125,50 +200,36 @@ bool CollectiveManager::CreateCommunicationGroup(const std::string &group_name,
|
|||
void *root_info = group->GenerateRootInfo(&root_info_size);
|
||||
MS_EXCEPTION_IF_NULL(root_info);
|
||||
|
||||
bool ret = false;
|
||||
// Step 4: Broadcast the device root information to all nodes on host side.
|
||||
if (!host_comm_lib_instance_->BroadcastUniqueID(group_name, is_root_node, root_info_size, root_info)) {
|
||||
MS_LOG(ERROR) << "Broadcast for device root info failed on the host side.";
|
||||
return false;
|
||||
while (!ret) {
|
||||
ret = host_comm_lib_instance_->BroadcastUniqueID(group_name, is_root_node, root_info_size, root_info);
|
||||
if (!ret) {
|
||||
MS_LOG(ERROR) << "Broadcast for device root info failed on the host side.";
|
||||
return false;
|
||||
}
|
||||
|
||||
// In disaster recovery scenarios, it is necessary to ensure that the unique id obtained from the Scheduler is a
|
||||
// newly generated one.
|
||||
if (RecoveryContext::GetInstance()->enable_recovery()) {
|
||||
ret = CheckUniqueIDLatest(group_name, root_info_size, root_info);
|
||||
}
|
||||
}
|
||||
|
||||
// Step 5: Initialize communication group on the device side.
|
||||
return InitDeviceCommGroup(group, root_info);
|
||||
}
|
||||
|
||||
bool CollectiveManager::InitDeviceCommGroup(const CommunicationGroupPtr &group, void *root_info) {
|
||||
bool init_group_success = false;
|
||||
bool init_group_fail = false;
|
||||
std::condition_variable thread_blocker;
|
||||
init_group_thread_ = std::make_unique<std::thread>([&, this] {
|
||||
std::function<bool()> init_device_comm_group_func = [&, this]() {
|
||||
device_ctx_->Initialize();
|
||||
if (!group->Initialize(root_info)) {
|
||||
MS_LOG(ERROR) << "Initialize group on the device side failed.";
|
||||
std::unique_lock<std::mutex> lock(init_group_mutex_);
|
||||
init_group_fail = true;
|
||||
thread_blocker.notify_one();
|
||||
return;
|
||||
}
|
||||
return group->Initialize(root_info);
|
||||
};
|
||||
MS_LOG(INFO) << "Begin initialize communication group on the device side.";
|
||||
|
||||
{
|
||||
std::unique_lock<std::mutex> lock(init_group_mutex_);
|
||||
init_group_success = true;
|
||||
thread_blocker.notify_one();
|
||||
}
|
||||
});
|
||||
init_group_thread_->detach();
|
||||
|
||||
// Timeout limit 180 seconds to wait finishing init device communication group.
|
||||
// Timeout limit 180 seconds to wait finish initializing device communication group.
|
||||
const int64_t kTimeToWait = 180;
|
||||
std::unique_lock<std::mutex> locker(init_group_mutex_);
|
||||
(void)thread_blocker.wait_for(locker, std::chrono::seconds(kTimeToWait),
|
||||
[&] { return init_group_success || init_group_fail; });
|
||||
// Initialize communication group on the device side in thread with timeout limit.
|
||||
ret = ExecuteFuncInThread(init_device_comm_group_func, kTimeToWait);
|
||||
|
||||
if (!init_group_success && runtime::recovery::RecoveryContext::GetInstance()->enable_recovery()) {
|
||||
runtime::recovery::RecoveryContext::GetInstance()->set_recovery_status(
|
||||
runtime::recovery::RecoveryErrCode::kInitNcclFailed);
|
||||
MS_LOG(ERROR) << "Initialize group on the device side failed.";
|
||||
}
|
||||
return init_group_success;
|
||||
MS_LOG(INFO) << "End initialize communication group on the device side.";
|
||||
return ret;
|
||||
}
|
||||
|
||||
bool CollectiveManager::DestroyCommunicationGroup(const std::string &group_name) {
|
||||
|
@ -197,22 +258,34 @@ uint32_t CollectiveManager::GetGroupSize(const std::string &group_name) {
|
|||
}
|
||||
|
||||
bool CollectiveManager::Finalize() {
|
||||
if (finalized_) {
|
||||
if (!inited_.load() || finalized_.load()) {
|
||||
return true;
|
||||
}
|
||||
|
||||
MS_EXCEPTION_IF_NULL(host_comm_lib_instance_);
|
||||
if (!host_comm_lib_instance_->Finalize()) {
|
||||
MS_LOG(WARNING) << "Failed to finalize host communication library.";
|
||||
}
|
||||
std::function<bool()> finalize_func = [&, this]() {
|
||||
MS_EXCEPTION_IF_NULL(host_comm_lib_instance_);
|
||||
if (!host_comm_lib_instance_->Finalize()) {
|
||||
MS_LOG(WARNING) << "Failed to finalize host communication library.";
|
||||
}
|
||||
|
||||
MS_EXCEPTION_IF_NULL(device_comm_lib_instance_);
|
||||
if (!device_comm_lib_instance_->Finalize()) {
|
||||
MS_LOG(WARNING) << "Failed to finalize device communication library.";
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(device_comm_lib_instance_);
|
||||
if (!device_comm_lib_instance_->Finalize()) {
|
||||
MS_LOG(WARNING) << "Failed to finalize device communication library.";
|
||||
}
|
||||
|
||||
finalized_ = true;
|
||||
return true;
|
||||
finalized_ = true;
|
||||
return true;
|
||||
};
|
||||
|
||||
MS_LOG(INFO) << "Begin finalize collective manager.";
|
||||
|
||||
// Timeout limit 5 seconds to wait to finish finalizing device communication group.
|
||||
const int64_t kTimeToWait = 5;
|
||||
// Finalize collective manager in thread with timeout limit.
|
||||
bool ret = ExecuteFuncInThread(finalize_func, kTimeToWait);
|
||||
|
||||
MS_LOG(INFO) << "End finalize collective manager.";
|
||||
return ret;
|
||||
}
|
||||
|
||||
void CollectiveManager::set_global_rank_id(uint32_t global_rank_id) { global_rank_id_ = global_rank_id; }
|
||||
|
|
|
@ -69,6 +69,11 @@ class BACKEND_EXPORT CollectiveManager {
|
|||
|
||||
uint32_t local_rank_id() const;
|
||||
|
||||
// Set whether need reinitialize collective communication.
|
||||
void set_need_reinit(bool need_reinit) { need_reinit_ = need_reinit; }
|
||||
// Get whether need reinitialize collective communication.
|
||||
bool need_reinit() const { return need_reinit_.load(); }
|
||||
|
||||
private:
|
||||
CollectiveManager();
|
||||
|
||||
|
@ -81,16 +86,13 @@ class BACKEND_EXPORT CollectiveManager {
|
|||
// Assign the local rank id for this process.
|
||||
bool AssignLocalRank();
|
||||
|
||||
// Initialize communication group on the device side.
|
||||
bool InitDeviceCommGroup(const CommunicationGroupPtr &group, void *root_info);
|
||||
|
||||
// Initialize communication group on the device side in thread with timeout limit.
|
||||
std::unique_ptr<std::thread> init_group_thread_;
|
||||
std::mutex init_group_mutex_;
|
||||
|
||||
std::atomic_bool inited_;
|
||||
std::atomic_bool finalized_;
|
||||
|
||||
// Whether need reinitialize collective communication, this value should be set to true once a training process
|
||||
// exits unexpectedly is detected.
|
||||
std::atomic_bool need_reinit_;
|
||||
|
||||
// The device type read from MindSpore context.
|
||||
std::string device_type_;
|
||||
|
||||
|
@ -119,8 +121,8 @@ class BACKEND_EXPORT CollectiveManager {
|
|||
// Global group ranks.
|
||||
std::vector<uint32_t> global_group_ranks_;
|
||||
|
||||
// The global group name on the host side. This is used for Creating global group on host side for AllGather operation
|
||||
// of host name while assigning local rank.
|
||||
// The global group name on the host side. This is used for Creating global group on host side for AllGather
|
||||
// operation of host name while assigning local rank.
|
||||
std::string host_global_group_name_;
|
||||
};
|
||||
} // namespace collective
|
||||
|
|
|
@ -17,11 +17,11 @@
|
|||
#include "distributed/init.h"
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include "runtime/recovery/recovery_context.h"
|
||||
#include "distributed/recovery/recovery_context.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace distributed {
|
||||
using runtime::recovery::RecoveryContext;
|
||||
using distributed::recovery::RecoveryContext;
|
||||
|
||||
bool Initialize() {
|
||||
if (!InitializeCluster()) {
|
||||
|
@ -43,7 +43,8 @@ bool Initialize() {
|
|||
collective::CollectiveManager::instance()->set_global_rank_size(abstract_node->worker_num());
|
||||
|
||||
if (RecoveryContext::GetInstance()->enable_recovery()) {
|
||||
cluster::ClusterContext::instance()->WaitForClusterReady();
|
||||
RecoveryContext::GetInstance()->set_global_rank_id(abstract_node->rank_id());
|
||||
RecoveryContext::GetInstance()->set_global_rank_size(abstract_node->worker_num());
|
||||
}
|
||||
|
||||
if (!InitializeCollective()) {
|
||||
|
@ -52,8 +53,6 @@ bool Initialize() {
|
|||
}
|
||||
|
||||
if (RecoveryContext::GetInstance()->enable_recovery()) {
|
||||
RecoveryContext::GetInstance()->set_global_rank_id(abstract_node->rank_id());
|
||||
RecoveryContext::GetInstance()->set_global_rank_size(abstract_node->worker_num());
|
||||
RecoveryContext::GetInstance()->ObtainGlobalLatestCkptInfo();
|
||||
}
|
||||
}
|
||||
|
|
|
@ -14,7 +14,7 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "runtime/recovery/recovery_context.h"
|
||||
#include "distributed/recovery/recovery_context.h"
|
||||
|
||||
#include <dirent.h>
|
||||
#include <algorithm>
|
||||
|
@ -26,13 +26,12 @@
|
|||
#include "utils/file_utils.h"
|
||||
#include "distributed/constants.h"
|
||||
#include "distributed/cluster/topology/common.h"
|
||||
#include "distributed/init.h"
|
||||
#include "runtime/hardware/device_context.h"
|
||||
#include "runtime/hardware/device_context_manager.h"
|
||||
#include "utils/convert_utils_base.h"
|
||||
#include "utils/ms_context.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace runtime {
|
||||
namespace distributed {
|
||||
namespace recovery {
|
||||
constexpr char kEnvEnableRecovery[] = "MS_ENABLE_RECOVERY";
|
||||
constexpr char kEnvRecoveryPath[] = "MS_RECOVERY_PATH";
|
||||
|
@ -142,30 +141,6 @@ void RecoveryContext::Initialize() {
|
|||
initialized_ = true;
|
||||
}
|
||||
|
||||
bool RecoveryContext::ReInitializeCollective() {
|
||||
auto ret = distributed::Initialize();
|
||||
if (ret) {
|
||||
recovery_status_ = RecoveryErrCode::kUnKnownError;
|
||||
set_need_reset(true);
|
||||
set_need_sync_weight_to_device(true);
|
||||
return true;
|
||||
}
|
||||
|
||||
if (recovery_status_ == RecoveryErrCode::kBroadcastUniqueIDFailed ||
|
||||
recovery_status_ == RecoveryErrCode::kAllGatherHostNameFailed) {
|
||||
MS_LOG(WARNING) << "Prepare to initialize NCCL failed, retrying.";
|
||||
// Retry duration: 30s.
|
||||
const int kRetryDuration = 30;
|
||||
std::this_thread::sleep_for(std::chrono::seconds(kRetryDuration));
|
||||
return ReInitializeCollective();
|
||||
} else if (recovery_status_ == RecoveryErrCode::kInitNcclFailed) {
|
||||
MS_LOG(EXCEPTION) << "Initialize NCCL failed.";
|
||||
}
|
||||
|
||||
MS_LOG(EXCEPTION) << "ReInitialize collective failed.";
|
||||
return false;
|
||||
}
|
||||
|
||||
void RecoveryContext::ObtainGlobalLatestCkptInfo() {
|
||||
// 1. Obtain the step corresponding to the local latest checkpoint.
|
||||
ObtainLocalLatestCkptInfo();
|
||||
|
@ -326,6 +301,7 @@ void RecoveryContext::ParseLatestCkptInfo(const int *recv_buffer, const uint32_t
|
|||
}
|
||||
|
||||
void RecoveryContext::CreatePersistentFile() {
|
||||
std::unique_lock<std::mutex> lock(create_persist_json_mtx_);
|
||||
if (node_role_ == distributed::kEnvRoleOfScheduler) {
|
||||
return;
|
||||
}
|
||||
|
@ -344,7 +320,7 @@ void RecoveryContext::CreatePersistentFile() {
|
|||
// The directory used to save ckpt is persisted to json file.
|
||||
std::string persistent_file_path =
|
||||
recovery_path_ + "/" + node_role_ + "_" + std::to_string(global_rank_id_) + "_persistent.json";
|
||||
persistent_json_ = std::make_unique<JsonUtils>(persistent_file_path);
|
||||
persistent_json_ = std::make_shared<JsonUtils>(persistent_file_path);
|
||||
if (!persistent_json_->Initialize()) {
|
||||
MS_LOG(EXCEPTION) << "Initialize json failed, file path: " << persistent_file_path;
|
||||
}
|
||||
|
@ -388,6 +364,24 @@ std::string RecoveryContext::GetCkptPath() {
|
|||
|
||||
return persistent_json_->Get<std::string>(kCkptPath);
|
||||
}
|
||||
|
||||
const std::shared_ptr<JsonUtils> &RecoveryContext::persistent_json() {
|
||||
if (persistent_json_ == nullptr) {
|
||||
CreatePersistentFile();
|
||||
}
|
||||
|
||||
MS_EXCEPTION_IF_NULL(persistent_json_);
|
||||
return persistent_json_;
|
||||
}
|
||||
|
||||
std::string RecoveryContext::latest_ckpt_file() {
|
||||
// For standalone training.
|
||||
if (enable_recovery_ && global_rank_size_ == 0 && latest_ckpt_file_.empty()) {
|
||||
ObtainLocalLatestCkptInfo();
|
||||
}
|
||||
|
||||
return latest_ckpt_file_;
|
||||
}
|
||||
} // namespace recovery
|
||||
} // namespace runtime
|
||||
} // namespace distributed
|
||||
} // namespace mindspore
|
|
@ -14,8 +14,8 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_RUNTIME_RECOVERY_RECOVERY_H_
|
||||
#define MINDSPORE_CCSRC_RUNTIME_RECOVERY_RECOVERY_H_
|
||||
#ifndef MINDSPORE_CCSRC_DISTRIBUTED_RECOVERY_RECOVERY_H_
|
||||
#define MINDSPORE_CCSRC_DISTRIBUTED_RECOVERY_RECOVERY_H_
|
||||
|
||||
#include <vector>
|
||||
#include <string>
|
||||
|
@ -27,13 +27,11 @@
|
|||
#include "include/backend/visible.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace runtime {
|
||||
namespace distributed {
|
||||
namespace recovery {
|
||||
using distributed::storage::FileIOUtils;
|
||||
using distributed::storage::JsonUtils;
|
||||
|
||||
enum class RecoveryErrCode { kUnKnownError, kAllGatherHostNameFailed, kBroadcastUniqueIDFailed, kInitNcclFailed };
|
||||
|
||||
// Used to save disaster recovery-related state quantities and provide disaster recovery-related
|
||||
// functions, such as reinitializing collective communication, etc.
|
||||
class BACKEND_EXPORT RecoveryContext {
|
||||
|
@ -47,14 +45,6 @@ class BACKEND_EXPORT RecoveryContext {
|
|||
}
|
||||
~RecoveryContext() = default;
|
||||
|
||||
// Reinitializing collective communication.
|
||||
bool ReInitializeCollective();
|
||||
|
||||
// Obtain the global step corresponding to the global latest checkpoint in each training process. Since there may be
|
||||
// some processes that fails to save the checkpoint, it is necessary for AllGather to save the latest step of the
|
||||
// successful checkpoint in each training process, and then take the minimum value as the final reset position.
|
||||
void ObtainGlobalLatestCkptInfo();
|
||||
|
||||
// Get whether enable recovery or not.
|
||||
bool enable_recovery() const { return enable_recovery_; }
|
||||
|
||||
|
@ -64,18 +54,14 @@ class BACKEND_EXPORT RecoveryContext {
|
|||
// Get interval to persist model.
|
||||
int recovery_interval() const { return recovery_interval_; }
|
||||
|
||||
// Get the error status of recovery.
|
||||
RecoveryErrCode recovery_status() const { return recovery_status_; }
|
||||
// Set the error status of recovery.
|
||||
void set_recovery_status(RecoveryErrCode recovery_status) { recovery_status_ = recovery_status; }
|
||||
|
||||
// Set the path used to save checkpoint.
|
||||
void SetCkptPath(const std::string &path);
|
||||
// Get the path used to save checkpoint.
|
||||
std::string GetCkptPath();
|
||||
|
||||
// Get the latest checkpoint in this node.
|
||||
std::string latest_ckpt_file() const { return latest_ckpt_file_; }
|
||||
std::string latest_ckpt_file();
|
||||
|
||||
// Get the epoch of latest checkpoint in this node.
|
||||
int latest_ckpt_epoch() const { return latest_ckpt_epoch_; }
|
||||
// Get the step of latest checkpoint in this node.
|
||||
|
@ -99,10 +85,13 @@ class BACKEND_EXPORT RecoveryContext {
|
|||
// Set global rank size.
|
||||
void set_global_rank_size(uint32_t global_rank_size) { global_rank_size_ = global_rank_size; }
|
||||
|
||||
// Set whether need reinitialize collective communication.
|
||||
void set_need_reinit_collective(bool need_reinit_collective) { need_reinit_collective_ = need_reinit_collective; }
|
||||
// Get whether need reinitialize collective communication.
|
||||
bool need_reinit_collective() const { return need_reinit_collective_.load(); }
|
||||
// Obtain the global step corresponding to the global latest checkpoint in each training process. Since there may be
|
||||
// some processes that fails to save the checkpoint, it is necessary for AllGather to save the latest step of the
|
||||
// successful checkpoint in each training process, and then take the minimum value as the final reset position.
|
||||
void ObtainGlobalLatestCkptInfo();
|
||||
|
||||
// Get the persistent json file pointer.
|
||||
const std::shared_ptr<JsonUtils> &persistent_json();
|
||||
|
||||
private:
|
||||
inline static std::shared_ptr<RecoveryContext> instance_{};
|
||||
|
@ -155,20 +144,14 @@ class BACKEND_EXPORT RecoveryContext {
|
|||
// performs load checkpoint.
|
||||
bool need_sync_weight_to_device_{false};
|
||||
|
||||
// Whether need reinitialize collective communication, this value should be set to true once a training process
|
||||
// exits unexpectedly is detected.
|
||||
std::atomic_bool need_reinit_collective_{false};
|
||||
|
||||
// Whether the recovery context is already initialized.
|
||||
bool initialized_{false};
|
||||
|
||||
// The error status of recovery.
|
||||
RecoveryErrCode recovery_status_{RecoveryErrCode::kUnKnownError};
|
||||
|
||||
std::mutex create_persist_json_mtx_;
|
||||
// The persitent json file util, used to persist recovery config.
|
||||
std::unique_ptr<JsonUtils> persistent_json_;
|
||||
std::shared_ptr<JsonUtils> persistent_json_;
|
||||
};
|
||||
} // namespace recovery
|
||||
} // namespace runtime
|
||||
} // namespace distributed
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_RUNTIME_RECOVERY_RECOVERY_H_
|
||||
#endif // MINDSPORE_CCSRC_DISTRIBUTED_RECOVERY_RECOVERY_H_
|
|
@ -40,7 +40,7 @@
|
|||
#include "ps/util.h"
|
||||
#endif
|
||||
#include "ps/ps_context.h"
|
||||
#include "runtime/recovery/recovery_context.h"
|
||||
#include "distributed/recovery/recovery_context.h"
|
||||
|
||||
#include "pybind_api/gil_scoped_long_running.h"
|
||||
|
||||
|
@ -58,7 +58,7 @@ using ParallelContext = mindspore::parallel::ParallelContext;
|
|||
using CostModelContext = mindspore::parallel::CostModelContext;
|
||||
using mindspore::MsCtxParam;
|
||||
using PSContext = mindspore::ps::PSContext;
|
||||
using RecoveryContext = mindspore::runtime::recovery::RecoveryContext;
|
||||
using RecoveryContext = mindspore::distributed::recovery::RecoveryContext;
|
||||
|
||||
// Interface with python
|
||||
PYBIND11_MODULE(_c_expression, m) {
|
||||
|
|
|
@ -61,6 +61,7 @@
|
|||
#include "runtime/device/kernel_runtime_manager.h"
|
||||
#include "runtime/pynative/op_executor.h"
|
||||
#include "runtime/device/stream_synchronizer.h"
|
||||
#include "distributed/collective/collective_manager.h"
|
||||
|
||||
#ifndef ENABLE_SECURITY
|
||||
#ifdef ENABLE_D
|
||||
|
@ -1608,6 +1609,8 @@ void ClearResAtexit() {
|
|||
device::StreamSynchronizer::GetInstance()->Finalize();
|
||||
MS_LOG(INFO) << "End Finalize StreamSynchronizer...";
|
||||
|
||||
(void)distributed::collective::CollectiveManager::instance()->Finalize();
|
||||
|
||||
PrimitivePy::ClearHookRes();
|
||||
ad::g_k_prims.clear();
|
||||
ad::ClearKPynativeCellStaticRes();
|
||||
|
|
|
@ -35,7 +35,7 @@ using CommunicationGroupInfo = mindspore::fl::server::CommunicationGroupInfo;
|
|||
using ps::core::NodeCommand;
|
||||
|
||||
// The time interval for send info or query info between worker and scheduler.
|
||||
constexpr uint32_t kWaitDuration = 3;
|
||||
constexpr uint32_t kWaitDuration = 5;
|
||||
|
||||
// The collective communication library for MindSpore self developed communication framework.
|
||||
class MsCollectiveCommLib : public CollectiveCommunicationLib {
|
||||
|
|
|
@ -1301,6 +1301,7 @@ void AbstractNode::InitCommandHandler() {
|
|||
handlers_[NodeCommand::SEND_EVENT] = nullptr;
|
||||
RegisterActorRouteTableRspHandler();
|
||||
RegisterInitCollectCommResphandler();
|
||||
RegisterRecoveryRespHandler();
|
||||
}
|
||||
|
||||
void AbstractNode::RegisterActorRouteTableRspHandler() {
|
||||
|
|
|
@ -228,6 +228,9 @@ class BACKEND_EXPORT AbstractNode : public Node {
|
|||
// Register collective communication initialization response methods.
|
||||
virtual void RegisterInitCollectCommResphandler() {}
|
||||
|
||||
// Register recovery response methods.
|
||||
virtual void RegisterRecoveryRespHandler() {}
|
||||
|
||||
// when initializing the node, should initializing the node info.
|
||||
void InitNodeInfo(const NodeRole &role);
|
||||
// Initialize worker num and server num by cluster config.
|
||||
|
|
|
@ -139,6 +139,9 @@ bool AbstractPSNode::HandleHeartbeatTimeout() {
|
|||
if (!stop_heartbeat_.load()) {
|
||||
stop_heartbeat_ = true;
|
||||
while (!heartbeat_stopped_.load()) {
|
||||
if (is_finish_.load()) {
|
||||
return;
|
||||
}
|
||||
MS_LOG(INFO) << "Waiting for heartbeat to stop...";
|
||||
|
||||
// Time interval for waiting the heartbeat to stop.
|
||||
|
@ -152,6 +155,9 @@ bool AbstractPSNode::HandleHeartbeatTimeout() {
|
|||
|
||||
bool success = false;
|
||||
while (!success) {
|
||||
if (is_finish_.load()) {
|
||||
return;
|
||||
}
|
||||
MS_LOG(WARNING) << "Trying to reconnect to the scheduler...";
|
||||
success = InitClientToScheduler();
|
||||
if (success) {
|
||||
|
@ -176,6 +182,11 @@ void AbstractPSNode::RegisterInitCollectCommResphandler() {
|
|||
handlers_[NodeCommand::SEND_UNIQUE_ID] = &AbstractPSNode::ProcessReceiveSchedulerResp;
|
||||
handlers_[NodeCommand::QUERY_UNIQUE_ID] = &AbstractPSNode::ProcessReceiveSchedulerResp;
|
||||
}
|
||||
|
||||
void AbstractPSNode::RegisterRecoveryRespHandler() {
|
||||
handlers_[NodeCommand::SEND_FINISH_TRANSFORM] = &AbstractPSNode::ProcessReceiveSchedulerResp;
|
||||
handlers_[NodeCommand::QUERY_FINISH_TRANSFORM] = &AbstractPSNode::ProcessReceiveSchedulerResp;
|
||||
}
|
||||
} // namespace core
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -38,6 +38,9 @@ class AbstractPSNode : public AbstractNode {
|
|||
// Register collective communication initialization response methods.
|
||||
void RegisterInitCollectCommResphandler() override;
|
||||
|
||||
// Register recovery response methods.
|
||||
void RegisterRecoveryRespHandler() override;
|
||||
|
||||
// Indicate whether the heartbeat thread should be stopped.
|
||||
std::atomic<bool> stop_heartbeat_{false};
|
||||
|
||||
|
|
|
@ -67,7 +67,7 @@ uint32_t NodeManager::checkIfRankIdExist(const RegisterMessage ®ister_message
|
|||
return rank_id;
|
||||
}
|
||||
} else {
|
||||
ReAddNodeIfNotExists(node_id, register_message.ip(), register_message.port());
|
||||
(void)ReAddNodeIfNotExists(node_id, register_message.ip(), register_message.port(), &rank_id);
|
||||
}
|
||||
return rank_id;
|
||||
}
|
||||
|
@ -206,11 +206,16 @@ std::vector<ServersMeta> NodeManager::FetchAllNodesMeta() {
|
|||
return servers_meta_list;
|
||||
}
|
||||
|
||||
const std::unordered_map<std::string, NodeInfo> &NodeManager::QueryTimeOutNodesInfo() const {
|
||||
return timeout_nodes_info_;
|
||||
}
|
||||
|
||||
void NodeManager::UpdateCluster() {
|
||||
// 1. update cluster timeout state
|
||||
struct timeval current_time {};
|
||||
(void)gettimeofday(¤t_time, nullptr);
|
||||
timeout_nodes_info_.clear();
|
||||
std::lock_guard<std::mutex> lock(heartbeat_mutex_);
|
||||
for (auto it = heartbeats_.begin(); it != heartbeats_.end(); ++it) {
|
||||
if (it->second.tv_sec + PSContext::instance()->cluster_config().heartbeat_timeout < current_time.tv_sec) {
|
||||
if (registered_nodes_info_.count(it->first)) {
|
||||
|
|
|
@ -139,6 +139,8 @@ class NodeManager {
|
|||
|
||||
bool IsAllNodesAlive() const;
|
||||
|
||||
const std::unordered_map<std::string, NodeInfo> &QueryTimeOutNodesInfo() const;
|
||||
|
||||
private:
|
||||
std::mutex node_mutex_;
|
||||
std::mutex cluster_mutex_;
|
||||
|
|
|
@ -57,6 +57,12 @@ enum NodeCommand {
|
|||
SEND_UNIQUE_ID = 20;
|
||||
// Query unique id used to initialize collective communication.
|
||||
QUERY_UNIQUE_ID = 21;
|
||||
// Send the ready status to finish transform graph of computed node,
|
||||
// used in disaster recovery mode to prevent timeout of waiting for graph transformation.
|
||||
SEND_FINISH_TRANSFORM = 22;
|
||||
// Query the ready status to finish transform graph of computed node,
|
||||
// used in disaster recovery mode to prevent timeout of waiting for graph transformation.
|
||||
QUERY_FINISH_TRANSFORM = 23;
|
||||
}
|
||||
|
||||
enum NodeRole {
|
||||
|
@ -298,3 +304,19 @@ message QueryUniqueIDRespMessage {
|
|||
// The unique id used to initialize collective communication.
|
||||
bytes unique_id = 2;
|
||||
}
|
||||
|
||||
message SendFinishTransformMessage {
|
||||
// the current Node unique id:0,1,2...
|
||||
string node_id = 1;
|
||||
// The rank id of the node in the cluster.
|
||||
uint32 rank_id = 2;
|
||||
// Whether finish transform graph.
|
||||
bool is_ready = 3;
|
||||
}
|
||||
|
||||
message QueryFinishTransformRespMessage {
|
||||
// Whether all computed nodes are ready to run dag.
|
||||
bool is_ready = 1;
|
||||
// Whether there is any worker timeout.
|
||||
bool is_worker_timeout = 2;
|
||||
}
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
|
||||
#include <memory>
|
||||
#include "ps/core/ps_scheduler_node.h"
|
||||
#include "utils/ms_context.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
|
@ -58,6 +59,13 @@ void PSSchedulerNode::RegisterInitCollectCommServiceHandler() {
|
|||
handlers_[NodeCommand::QUERY_UNIQUE_ID] = static_cast<ResponseHandler>(&PSSchedulerNode::ProcessQueryUniqueID);
|
||||
}
|
||||
|
||||
void PSSchedulerNode::RegisterRecoveryServiceHandler() {
|
||||
handlers_[NodeCommand::SEND_FINISH_TRANSFORM] =
|
||||
static_cast<ResponseHandler>(&PSSchedulerNode::ProcessSendFinishTransform);
|
||||
handlers_[NodeCommand::QUERY_FINISH_TRANSFORM] =
|
||||
static_cast<ResponseHandler>(&PSSchedulerNode::ProcessQueryFinishTransform);
|
||||
}
|
||||
|
||||
void PSSchedulerNode::ProcessSendHostName(const std::shared_ptr<TcpServer> &server,
|
||||
const std::shared_ptr<TcpConnection> &conn,
|
||||
const std::shared_ptr<MessageMeta> &meta, const void *data, size_t size) {
|
||||
|
@ -71,7 +79,7 @@ void PSSchedulerNode::ProcessSendHostName(const std::shared_ptr<TcpServer> &serv
|
|||
std::string node_id = send_host_name_msg.node_id();
|
||||
uint32_t rank_id = send_host_name_msg.rank_id();
|
||||
size_t host_hash_name = send_host_name_msg.host_hash_name();
|
||||
MS_LOG(INFO) << "Received send host name request, node id: " << node_id << ", rank id: " << rank_id;
|
||||
MS_LOG(INFO) << "Receive send host name request, node id: " << node_id << ", rank id: " << rank_id;
|
||||
|
||||
bool ret = false;
|
||||
std::string error = "";
|
||||
|
@ -100,7 +108,7 @@ void PSSchedulerNode::ProcessQueryHostNames(const std::shared_ptr<TcpServer> &se
|
|||
query_msg.ParseFromArray(data, SizeToInt(size));
|
||||
std::string node_id = query_msg.node_id();
|
||||
uint32_t rank_id = query_msg.rank_id();
|
||||
MS_LOG(INFO) << "Received query host name request, node id: " << node_id << ", rank id: " << rank_id;
|
||||
MS_LOG(INFO) << "Receive query host name request, node id: " << node_id << ", rank id: " << rank_id;
|
||||
|
||||
bool is_success = recv_rank_id_send_host_name_.size() == host_hash_names_.size();
|
||||
QueryHostHashNameRespMessage resp_msg;
|
||||
|
@ -121,6 +129,7 @@ void PSSchedulerNode::ProcessQueryHostNames(const std::shared_ptr<TcpServer> &se
|
|||
if (recv_rank_id_query_host_name_.size() == recv_rank_id_send_host_name_.size()) {
|
||||
recv_rank_id_send_host_name_.clear();
|
||||
recv_rank_id_query_host_name_.clear();
|
||||
node_timeout_ = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -138,7 +147,7 @@ void PSSchedulerNode::ProcessSendUniqueID(const std::shared_ptr<TcpServer> &serv
|
|||
std::string node_id = send_unique_id_msg.node_id();
|
||||
uint32_t rank_id = send_unique_id_msg.rank_id();
|
||||
std::string group_name = send_unique_id_msg.group_name();
|
||||
MS_LOG(INFO) << "Received send unique id request, group name: " << group_name << ", node id: " << node_id
|
||||
MS_LOG(INFO) << "Receive send unique id request, group name: " << group_name << ", node id: " << node_id
|
||||
<< ", rank id: " << rank_id;
|
||||
|
||||
bool ret = false;
|
||||
|
@ -169,7 +178,7 @@ void PSSchedulerNode::ProcessQueryUniqueID(const std::shared_ptr<TcpServer> &ser
|
|||
std::string node_id = query_msg.node_id();
|
||||
uint32_t rank_id = query_msg.rank_id();
|
||||
std::string group_name = query_msg.group_name();
|
||||
MS_LOG(INFO) << "Received query unique id request, group name: " << group_name << ", node id: " << node_id
|
||||
MS_LOG(INFO) << "Receive query unique id request, group name: " << group_name << ", node id: " << node_id
|
||||
<< ", rank id: " << rank_id;
|
||||
|
||||
auto iter = unique_id_group_.find(group_name);
|
||||
|
@ -190,6 +199,89 @@ void PSSchedulerNode::ProcessQueryUniqueID(const std::shared_ptr<TcpServer> &ser
|
|||
MS_LOG(INFO) << "Respond query unique id request, group name: " << group_name << ", node id: " << node_id
|
||||
<< ", rank id: " << rank_id;
|
||||
}
|
||||
|
||||
void PSSchedulerNode::ProcessSendFinishTransform(const std::shared_ptr<TcpServer> &server,
|
||||
const std::shared_ptr<TcpConnection> &conn,
|
||||
const std::shared_ptr<MessageMeta> &meta, const void *data,
|
||||
size_t size) {
|
||||
MS_ERROR_IF_NULL_WO_RET_VAL(server);
|
||||
MS_ERROR_IF_NULL_WO_RET_VAL(conn);
|
||||
MS_ERROR_IF_NULL_WO_RET_VAL(meta);
|
||||
MS_ERROR_IF_NULL_WO_RET_VAL(data);
|
||||
|
||||
SendFinishTransformMessage send_ready_to_run_msg;
|
||||
send_ready_to_run_msg.ParseFromArray(data, SizeToInt(size));
|
||||
std::string node_id = send_ready_to_run_msg.node_id();
|
||||
uint32_t rank_id = send_ready_to_run_msg.rank_id();
|
||||
MS_LOG(INFO) << "Receive send finish transform request, node id: " << node_id << ", rank id: " << rank_id;
|
||||
bool is_ready = send_ready_to_run_msg.is_ready();
|
||||
if (is_ready) {
|
||||
std::unique_lock<std::mutex> lock(nodes_finish_trans_mutex_);
|
||||
(void)nodes_finish_trans_.insert(rank_id);
|
||||
}
|
||||
|
||||
GeneralResponse(server, conn, meta, true, "");
|
||||
MS_LOG(INFO) << "Respond send finish transform request, node id: " << node_id << ", rank id: " << rank_id;
|
||||
}
|
||||
|
||||
void PSSchedulerNode::ProcessQueryFinishTransform(const std::shared_ptr<TcpServer> &server,
|
||||
const std::shared_ptr<TcpConnection> &conn,
|
||||
const std::shared_ptr<MessageMeta> &meta, const void *data,
|
||||
size_t size) {
|
||||
MS_ERROR_IF_NULL_WO_RET_VAL(server);
|
||||
MS_ERROR_IF_NULL_WO_RET_VAL(conn);
|
||||
MS_ERROR_IF_NULL_WO_RET_VAL(meta);
|
||||
MS_ERROR_IF_NULL_WO_RET_VAL(data);
|
||||
|
||||
GeneralQueryMessage query_msg;
|
||||
query_msg.ParseFromArray(data, SizeToInt(size));
|
||||
std::string node_id = query_msg.node_id();
|
||||
uint32_t rank_id = query_msg.rank_id();
|
||||
MS_LOG(INFO) << "Receive query finish transform request, node id: " << node_id << ", rank id: " << rank_id;
|
||||
|
||||
std::unique_lock<std::mutex> lock(nodes_finish_trans_mutex_);
|
||||
bool is_ready = nodes_finish_trans_.size() == worker_num_;
|
||||
|
||||
QueryFinishTransformRespMessage resp_msg;
|
||||
resp_msg.set_is_ready(is_ready);
|
||||
|
||||
if (node_timeout_) {
|
||||
(void)resp_msg.set_is_worker_timeout(true);
|
||||
} else {
|
||||
resp_msg.set_is_worker_timeout(false);
|
||||
}
|
||||
|
||||
if (!server->SendMessage(conn, meta, Protos::PROTOBUF, resp_msg.SerializeAsString().data(),
|
||||
resp_msg.ByteSizeLong())) {
|
||||
MS_LOG(ERROR) << "Scheduler failed to respond message.";
|
||||
return;
|
||||
}
|
||||
MS_LOG(INFO) << "Respond query finish transform request, node id: " << node_id << ", rank id: " << rank_id;
|
||||
}
|
||||
|
||||
void PSSchedulerNode::HandleNodeTimeoutForRecovery(
|
||||
const std::unordered_map<std::string, NodeInfo> &timeout_nodes_infos) {
|
||||
auto context_ptr = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context_ptr);
|
||||
if (!context_ptr->get_param<bool>(MS_CTX_ENABLE_RECOVERY)) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (timeout_nodes_infos.empty()) {
|
||||
return;
|
||||
}
|
||||
|
||||
std::unique_lock<std::mutex> lock(nodes_finish_trans_mutex_);
|
||||
node_timeout_ = true;
|
||||
for (const auto &item : timeout_nodes_infos) {
|
||||
(void)nodes_finish_trans_.erase(item.second.rank_id_);
|
||||
}
|
||||
}
|
||||
|
||||
void PSSchedulerNode::HandleNodeRecoverByHeartBeat(uint32_t rank_id) {
|
||||
std::unique_lock<std::mutex> lock(nodes_finish_trans_mutex_);
|
||||
(void)nodes_finish_trans_.insert(rank_id);
|
||||
}
|
||||
} // namespace core
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
#define MINDSPORE_CCSRC_PS_CORE_PS_SCHEDULER_NODE_H_
|
||||
|
||||
#include <map>
|
||||
#include <unordered_map>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include <set>
|
||||
|
@ -47,9 +48,16 @@ class BACKEND_EXPORT PSSchedulerNode : public SchedulerNode {
|
|||
// alive node should be rejected.
|
||||
bool NeedRejectRegister(const NodeInfo &node_info) override { return node_info.is_alive; }
|
||||
|
||||
bool SendPrepareBuildingNetwork(const std::unordered_map<std::string, NodeInfo> &node_infos) override {
|
||||
return true;
|
||||
};
|
||||
|
||||
// Register collective communication initialization service.
|
||||
void RegisterInitCollectCommServiceHandler() override;
|
||||
|
||||
// Register recovery service.
|
||||
void RegisterRecoveryServiceHandler() override;
|
||||
|
||||
// Process message for sending node's host name.
|
||||
void ProcessSendHostName(const std::shared_ptr<TcpServer> &server, const std::shared_ptr<TcpConnection> &conn,
|
||||
const std::shared_ptr<MessageMeta> &meta, const void *data, size_t size);
|
||||
|
@ -66,6 +74,20 @@ class BACKEND_EXPORT PSSchedulerNode : public SchedulerNode {
|
|||
void ProcessQueryUniqueID(const std::shared_ptr<TcpServer> &server, const std::shared_ptr<TcpConnection> &conn,
|
||||
const std::shared_ptr<MessageMeta> &meta, const void *data, size_t size);
|
||||
|
||||
// Process message for sending the ready status to finish transform graph of computed node,
|
||||
void ProcessSendFinishTransform(const std::shared_ptr<TcpServer> &server, const std::shared_ptr<TcpConnection> &conn,
|
||||
const std::shared_ptr<MessageMeta> &meta, const void *data, size_t size);
|
||||
|
||||
// Process message for querying the ready status to finish transform graph of computed node,
|
||||
void ProcessQueryFinishTransform(const std::shared_ptr<TcpServer> &server, const std::shared_ptr<TcpConnection> &conn,
|
||||
const std::shared_ptr<MessageMeta> &meta, const void *data, size_t size);
|
||||
|
||||
// Handle node timeout info and update nodes which finish transform graph.
|
||||
void HandleNodeTimeoutForRecovery(const std::unordered_map<std::string, NodeInfo> &timeout_nodes_infos) override;
|
||||
|
||||
// Recover finish transform nodes info when nodes recover heartbeat.
|
||||
void HandleNodeRecoverByHeartBeat(uint32_t rank_id) override;
|
||||
|
||||
// Record received host hash name from workers.
|
||||
std::vector<size_t> host_hash_names_;
|
||||
// Record rank id of the nodes which sended host name.
|
||||
|
@ -77,6 +99,11 @@ class BACKEND_EXPORT PSSchedulerNode : public SchedulerNode {
|
|||
std::map<std::string, std::string> unique_id_group_;
|
||||
|
||||
uint32_t worker_num_;
|
||||
|
||||
std::mutex nodes_finish_trans_mutex_;
|
||||
// Record the rank ids of nodes who finish transform graph.
|
||||
std::set<uint32_t> nodes_finish_trans_;
|
||||
std::atomic_bool node_timeout_{false};
|
||||
};
|
||||
} // namespace core
|
||||
} // namespace ps
|
||||
|
|
|
@ -39,16 +39,13 @@ bool PSServerNode::Start(const uint32_t &timeout) {
|
|||
}
|
||||
|
||||
void PSServerNode::Initialize() {
|
||||
InitNodeNum();
|
||||
config_ = std::make_unique<FileConfiguration>(PSContext::instance()->config_file_path());
|
||||
MS_EXCEPTION_IF_NULL(config_);
|
||||
if (!config_->Initialize()) {
|
||||
MS_LOG(INFO) << "The config file is empty, then init node by context.";
|
||||
InitNodeNum();
|
||||
} else {
|
||||
if (!Recover()) {
|
||||
MS_LOG(WARNING) << "Recover the server node is failed.";
|
||||
}
|
||||
if (config_->Initialize() && !Recover()) {
|
||||
MS_LOG(INFO) << "Recover the server node is failed.";
|
||||
}
|
||||
|
||||
InitServerHandler();
|
||||
CreateTcpServer();
|
||||
InitNodeInfo(NodeRole::SERVER);
|
||||
|
@ -119,8 +116,9 @@ void PSServerNode::Register(const std::shared_ptr<TcpClient> &client) {
|
|||
MS_LOG(INFO) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
|
||||
<< " the node id:" << node_info_.node_id_ << " begin to register to the scheduler!";
|
||||
|
||||
const int kCommTimeoutInSeconds = 20;
|
||||
if (!SendMessageSync(client, message_meta, Protos::PROTOBUF, register_message.SerializeAsString().data(),
|
||||
register_message.ByteSizeLong())) {
|
||||
register_message.ByteSizeLong(), kCommTimeoutInSeconds)) {
|
||||
MS_LOG(EXCEPTION) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
|
||||
<< " the node id:" << node_info_.node_id_ << " register timeout!";
|
||||
} else {
|
||||
|
|
|
@ -79,16 +79,13 @@ bool PSWorkerNode::Finish(const uint32_t &timeout) {
|
|||
}
|
||||
|
||||
void PSWorkerNode::Initialize() {
|
||||
InitNodeNum();
|
||||
config_ = std::make_unique<FileConfiguration>(PSContext::instance()->config_file_path());
|
||||
MS_EXCEPTION_IF_NULL(config_);
|
||||
if (!config_->Initialize()) {
|
||||
MS_LOG(INFO) << "The config file is empty, then init node by context.";
|
||||
InitNodeNum();
|
||||
} else {
|
||||
if (!Recover()) {
|
||||
MS_LOG(WARNING) << "Recover the worker node is failed.";
|
||||
}
|
||||
if (config_->Initialize() && !Recover()) {
|
||||
MS_LOG(INFO) << "Recover the worker node is failed.";
|
||||
}
|
||||
|
||||
InitServerHandler();
|
||||
CreateTcpServer();
|
||||
InitNodeInfo(NodeRole::WORKER);
|
||||
|
@ -120,8 +117,9 @@ void PSWorkerNode::Register(const std::shared_ptr<TcpClient> &client) {
|
|||
MS_LOG(INFO) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
|
||||
<< " the node id:" << node_info_.node_id_ << " begin to register to the scheduler!";
|
||||
|
||||
const int kCommTimeoutInSeconds = 20;
|
||||
if (!SendMessageSync(client, message_meta, Protos::PROTOBUF, register_message.SerializeAsString().data(),
|
||||
register_message.ByteSizeLong())) {
|
||||
register_message.ByteSizeLong(), kCommTimeoutInSeconds)) {
|
||||
MS_LOG(EXCEPTION) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
|
||||
<< " the node id:" << node_info_.node_id_ << " register timeout!";
|
||||
} else {
|
||||
|
|
|
@ -176,10 +176,12 @@ void SchedulerNode::ProcessHeartbeat(const std::shared_ptr<TcpServer> &server,
|
|||
return;
|
||||
}
|
||||
|
||||
uint32_t rank_id = UINT32_MAX;
|
||||
// Re-Add the missing node into node manager.
|
||||
if (heartbeat_message.has_address() &&
|
||||
node_manager_.ReAddNodeIfNotExists(node_id, heartbeat_message.ip(), heartbeat_message.port())) {
|
||||
node_manager_.ReAddNodeIfNotExists(node_id, heartbeat_message.ip(), heartbeat_message.port(), &rank_id)) {
|
||||
SetRegisterConnectionFd(conn, node_id);
|
||||
HandleNodeRecoverByHeartBeat(rank_id);
|
||||
|
||||
if (node_manager_.IsAllNodesRegistered()) {
|
||||
is_ready_ = true;
|
||||
|
@ -234,6 +236,7 @@ void SchedulerNode::InitCommandHandler() {
|
|||
handlers_[NodeCommand::SEND_EVENT] = &SchedulerNode::ProcessSendEvent;
|
||||
RegisterActorRouteTableServiceHandler();
|
||||
RegisterInitCollectCommServiceHandler();
|
||||
RegisterRecoveryServiceHandler();
|
||||
}
|
||||
|
||||
void SchedulerNode::RegisterActorRouteTableServiceHandler() {
|
||||
|
@ -699,6 +702,7 @@ void SchedulerNode::StartUpdateClusterStateTimer() {
|
|||
}
|
||||
std::this_thread::sleep_for(std::chrono::seconds(PSContext::instance()->cluster_config().heartbeat_interval));
|
||||
node_manager_.UpdateCluster();
|
||||
HandleNodeTimeoutForRecovery(node_manager_.QueryTimeOutNodesInfo());
|
||||
|
||||
if (node_manager_.GetClusterState() == ClusterState::CLUSTER_EXIT) {
|
||||
std::this_thread::sleep_for(
|
||||
|
|
|
@ -92,6 +92,15 @@ class BACKEND_EXPORT SchedulerNode : public Node {
|
|||
// Register collective communication initialization service.
|
||||
virtual void RegisterInitCollectCommServiceHandler() {}
|
||||
|
||||
// Register recovery service.
|
||||
virtual void RegisterRecoveryServiceHandler() {}
|
||||
|
||||
// Handle node timeout info and update nodes which finish transform graph.
|
||||
virtual void HandleNodeTimeoutForRecovery(const std::unordered_map<std::string, NodeInfo> &timeout_nodes_infos) {}
|
||||
|
||||
// Recover finish transform nodes info when nodes recover heartbeat.
|
||||
virtual void HandleNodeRecoverByHeartBeat(uint32_t rank_id) {}
|
||||
|
||||
const std::shared_ptr<TcpClient> &GetOrCreateClient(const NodeInfo &node_info);
|
||||
|
||||
void ProcessHeartbeat(const std::shared_ptr<TcpServer> &server, const std::shared_ptr<TcpConnection> &conn,
|
||||
|
@ -197,7 +206,7 @@ class BACKEND_EXPORT SchedulerNode : public Node {
|
|||
|
||||
void SetRegisterConnectionFd(const std::shared_ptr<TcpConnection> &conn, const std::string &node_id);
|
||||
|
||||
bool SendPrepareBuildingNetwork(const std::unordered_map<std::string, NodeInfo> &node_infos);
|
||||
virtual bool SendPrepareBuildingNetwork(const std::unordered_map<std::string, NodeInfo> &node_infos);
|
||||
|
||||
// Responding peer with the general response message.
|
||||
void GeneralResponse(const std::shared_ptr<TcpServer> &server, const std::shared_ptr<TcpConnection> &conn,
|
||||
|
|
|
@ -17,16 +17,19 @@
|
|||
#include "runtime/device/stream_synchronizer.h"
|
||||
#include "utils/ms_context.h"
|
||||
#include "distributed/collective/collective_manager.h"
|
||||
#include "runtime/recovery/recovery_context.h"
|
||||
#include "distributed/recovery/recovery_context.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace device {
|
||||
using distributed::collective::CollectiveManager;
|
||||
using distributed::recovery::RecoveryContext;
|
||||
|
||||
std::mutex StreamSynchronizer::instance_lock_;
|
||||
std::shared_ptr<StreamSynchronizer> StreamSynchronizer::instance_ = nullptr;
|
||||
|
||||
void StreamSynchronizer::Initialize() {
|
||||
// Non disaster recovery mode does not need to start thread and timeout mechanisms.
|
||||
if (!runtime::recovery::RecoveryContext::GetInstance()->enable_recovery()) {
|
||||
if (!RecoveryContext::GetInstance()->enable_recovery()) {
|
||||
return;
|
||||
}
|
||||
|
||||
|
@ -56,7 +59,7 @@ bool StreamSynchronizer::SyncStream(const std::string &device_name, uint32_t tim
|
|||
MS_EXCEPTION_IF_NULL(device_context);
|
||||
|
||||
// If disable recovery or timeout==0, sync stream directly to improve performance.
|
||||
if (!runtime::recovery::RecoveryContext::GetInstance()->enable_recovery() || timeout == 0) {
|
||||
if (!RecoveryContext::GetInstance()->enable_recovery() || timeout == 0) {
|
||||
device_context->Initialize();
|
||||
return device_context->SyncStream();
|
||||
}
|
||||
|
@ -68,26 +71,19 @@ bool StreamSynchronizer::SyncStream(const std::string &device_name, uint32_t tim
|
|||
device_context_ = device_context;
|
||||
do_sync_stream_cv_.notify_one();
|
||||
|
||||
if (sync_stream_time_out_) {
|
||||
// If sync stream timeout has happened, increase the timeout by 4 times.
|
||||
const uint32_t kTimeOutScaleFactor = 4;
|
||||
timeout *= kTimeOutScaleFactor;
|
||||
}
|
||||
|
||||
if (time_out_cv_.wait_for(lock, std::chrono::seconds(timeout)) == std::cv_status::no_timeout) {
|
||||
if (!sync_stream_ret_) {
|
||||
MS_LOG(ERROR) << "Synchronize stream failed.";
|
||||
}
|
||||
return sync_stream_ret_;
|
||||
} else {
|
||||
sync_stream_time_out_ = true;
|
||||
runtime::recovery::RecoveryContext::GetInstance()->set_need_reinit_collective(true);
|
||||
if (!distributed::collective::CollectiveManager::instance()->Finalize()) {
|
||||
CollectiveManager::instance()->set_need_reinit(true);
|
||||
if (!CollectiveManager::instance()->Finalize()) {
|
||||
MS_LOG(ERROR) << "Finalize collective manager failed.";
|
||||
return false;
|
||||
}
|
||||
time_out_cv_.wait(lock, [this]() { return device_context_ == nullptr; });
|
||||
MS_LOG(WARNING) << "Synchronize stream time out.";
|
||||
MS_LOG(WARNING) << "Synchronize stream timeout.";
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -71,9 +71,6 @@ class BACKEND_EXPORT StreamSynchronizer {
|
|||
// The singleton pointer.
|
||||
static std::shared_ptr<StreamSynchronizer> instance_;
|
||||
|
||||
// Record whether the synchronization stream task has timed out.
|
||||
bool sync_stream_time_out_{false};
|
||||
|
||||
// Return value of synchronization stream.
|
||||
bool sync_stream_ret_{false};
|
||||
|
||||
|
|
|
@ -25,11 +25,11 @@
|
|||
#include "mindrt/include/async/async.h"
|
||||
#include "utils/log_adapter.h"
|
||||
#include "include/common/utils/convert_utils.h"
|
||||
#include "runtime/recovery/recovery_context.h"
|
||||
#include "distributed/recovery/recovery_context.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace runtime {
|
||||
using recovery::RecoveryContext;
|
||||
using distributed::recovery::RecoveryContext;
|
||||
namespace {
|
||||
void SyncTensorData(const TensorPtr &host_tensor, const DeviceTensorPtr &device_tensor, const AnfNodePtr &node,
|
||||
const DeviceContext *device_context, OpContext<DeviceTensor> *const context,
|
||||
|
|
|
@ -21,11 +21,13 @@
|
|||
#include "runtime/graph_scheduler/actor/debug_actor.h"
|
||||
#include "mindrt/include/async/async.h"
|
||||
#include "utils/log_adapter.h"
|
||||
#include "runtime/recovery/recovery_context.h"
|
||||
#include "distributed/recovery/recovery_context.h"
|
||||
#include "distributed/collective/collective_manager.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace runtime {
|
||||
using recovery::RecoveryContext;
|
||||
using distributed::collective::CollectiveManager;
|
||||
using distributed::recovery::RecoveryContext;
|
||||
|
||||
void KernelActor::Init() {
|
||||
// Check device contexts number.
|
||||
|
@ -243,7 +245,7 @@ void KernelActor::OnMemoryAllocFinish(OpContext<DeviceTensor> *const context) {
|
|||
PreLaunchKernel(context);
|
||||
|
||||
try {
|
||||
if (RecoveryContext::GetInstance()->enable_recovery() && RecoveryContext::GetInstance()->need_reinit_collective()) {
|
||||
if (RecoveryContext::GetInstance()->enable_recovery() && CollectiveManager::instance()->need_reinit()) {
|
||||
// In disaster recovery scenarios, run dag in this step failed, the rest operators of graph do not need launch,
|
||||
// especially the collective communication operators.
|
||||
MS_LOG(WARNING) << "Collective communication need reinitialize, skip launch kernel: "
|
||||
|
|
|
@ -25,11 +25,13 @@
|
|||
#include "mindrt/include/async/async.h"
|
||||
#include "utils/log_adapter.h"
|
||||
#include "runtime/device/stream_synchronizer.h"
|
||||
#include "runtime/recovery/recovery_context.h"
|
||||
#include "distributed/recovery/recovery_context.h"
|
||||
#include "distributed/collective/collective_manager.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace runtime {
|
||||
using recovery::RecoveryContext;
|
||||
using distributed::collective::CollectiveManager;
|
||||
using distributed::recovery::RecoveryContext;
|
||||
|
||||
void LoopCountActor::Run(OpContext<DeviceTensor> *const context) {
|
||||
MS_EXCEPTION_IF_NULL(context);
|
||||
|
@ -70,8 +72,7 @@ void LoopCountActor::IncreaseLoopCount(OpContext<DeviceTensor> *const context) {
|
|||
(void)sync_stream_device_contexts.insert(device_context);
|
||||
|
||||
// Trigger disaster recovery and exit loop early.
|
||||
if (RecoveryContext::GetInstance()->enable_recovery() &&
|
||||
RecoveryContext::GetInstance()->need_reinit_collective()) {
|
||||
if (RecoveryContext::GetInstance()->enable_recovery() && CollectiveManager::instance()->need_reinit()) {
|
||||
current_count_ = loop_count_;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -17,11 +17,13 @@
|
|||
#include "runtime/graph_scheduler/actor/output_actor.h"
|
||||
#include "runtime/hardware/device_context_manager.h"
|
||||
#include "utils/log_adapter.h"
|
||||
#include "runtime/recovery/recovery_context.h"
|
||||
#include "distributed/recovery/recovery_context.h"
|
||||
#include "distributed/collective/collective_manager.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace runtime {
|
||||
using recovery::RecoveryContext;
|
||||
using distributed::collective::CollectiveManager;
|
||||
using distributed::recovery::RecoveryContext;
|
||||
|
||||
bool IsOutputAddressPersisted(const DeviceTensor *output_device_tensor, const AnfNodePtr &output_node) {
|
||||
MS_EXCEPTION_IF_NULL(output_node);
|
||||
|
@ -105,7 +107,7 @@ void OutputActor::RunOpControl(AID *const, OpContext<DeviceTensor> *const contex
|
|||
++current_count_;
|
||||
|
||||
// Trigger disaster recovery and return empty output.
|
||||
if (RecoveryContext::GetInstance()->enable_recovery() && RecoveryContext::GetInstance()->need_reinit_collective()) {
|
||||
if (RecoveryContext::GetInstance()->enable_recovery() && CollectiveManager::instance()->need_reinit()) {
|
||||
FreeOutputNodeMem();
|
||||
ClearOutputCache();
|
||||
SET_OPCONTEXT_SUCCESS_RET((*context));
|
||||
|
|
|
@ -45,11 +45,19 @@
|
|||
#endif
|
||||
#include "profiler/device/profiling.h"
|
||||
#include "include/common/debug/common.h"
|
||||
#include "runtime/recovery/recovery_context.h"
|
||||
#include "distributed/recovery/recovery_context.h"
|
||||
#include "distributed/collective/collective_manager.h"
|
||||
#if ((defined ENABLE_CPU) && (!defined _WIN32) && (!defined _WIN64))
|
||||
#include "distributed/cluster/cluster_context.h"
|
||||
#else
|
||||
#include "distributed/cluster/dummy_cluster_context.h"
|
||||
#endif
|
||||
|
||||
namespace mindspore {
|
||||
namespace runtime {
|
||||
using recovery::RecoveryContext;
|
||||
using distributed::cluster::ClusterContext;
|
||||
using distributed::collective::CollectiveManager;
|
||||
using distributed::recovery::RecoveryContext;
|
||||
namespace {
|
||||
bool IsNeedInsertCopyActor(const DeviceContext *from_device_context, const DeviceContext *to_device_context) {
|
||||
MS_EXCEPTION_IF_NULL(from_device_context);
|
||||
|
@ -196,6 +204,108 @@ void IntHandler(int, siginfo_t *, void *) {
|
|||
(void)kill(this_pid, SIGTERM);
|
||||
}
|
||||
#endif
|
||||
|
||||
#if ((defined ENABLE_CPU) && (!defined _WIN32) && (!defined _WIN64))
|
||||
bool SendFinishTransform() {
|
||||
auto node = ClusterContext::instance()->node();
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
if (node->role() != ps::core::NodeRole::WORKER) {
|
||||
return true;
|
||||
}
|
||||
|
||||
auto abstract_node = std::dynamic_pointer_cast<ps::core::AbstractNode>(ClusterContext::instance()->node());
|
||||
MS_EXCEPTION_IF_NULL(abstract_node);
|
||||
|
||||
ps::core::SendFinishTransformMessage send_ready_to_run_msg;
|
||||
send_ready_to_run_msg.set_node_id(abstract_node->node_id());
|
||||
send_ready_to_run_msg.set_rank_id(abstract_node->rank_id());
|
||||
send_ready_to_run_msg.set_is_ready(true);
|
||||
std::shared_ptr<std::vector<unsigned char>> output = nullptr;
|
||||
if (!abstract_node->SendToScheduler(send_ready_to_run_msg.SerializeAsString().data(),
|
||||
send_ready_to_run_msg.SerializeAsString().size(),
|
||||
ps::core::NodeCommand::SEND_FINISH_TRANSFORM, &output)) {
|
||||
MS_LOG(WARNING) << "Failed to send finish transform request to scheduler.";
|
||||
return false;
|
||||
}
|
||||
|
||||
ps::core::GeneralResponseMsg resp_msg;
|
||||
MS_EXCEPTION_IF_NULL(output);
|
||||
(void)resp_msg.ParseFromArray(output->data(), SizeToInt(output->size()));
|
||||
if (!resp_msg.is_success()) {
|
||||
MS_LOG(ERROR) << "Send finish transform to scheduler failed.";
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool QueryFinishTransform() {
|
||||
auto node = ClusterContext::instance()->node();
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
if (node->role() != ps::core::NodeRole::WORKER) {
|
||||
return true;
|
||||
}
|
||||
|
||||
auto abstract_node = std::dynamic_pointer_cast<ps::core::AbstractNode>(ClusterContext::instance()->node());
|
||||
MS_EXCEPTION_IF_NULL(abstract_node);
|
||||
|
||||
ps::core::GeneralQueryMessage general_query_msg;
|
||||
general_query_msg.set_node_id(abstract_node->node_id());
|
||||
general_query_msg.set_rank_id(abstract_node->rank_id());
|
||||
std::shared_ptr<std::vector<unsigned char>> output = nullptr;
|
||||
bool ret = false;
|
||||
while (!ret) {
|
||||
if (!abstract_node->SendToScheduler(general_query_msg.SerializeAsString().data(),
|
||||
general_query_msg.SerializeAsString().size(),
|
||||
ps::core::NodeCommand::QUERY_FINISH_TRANSFORM, &output)) {
|
||||
MS_LOG(WARNING) << "Failed to send query finish transform request to scheduler.";
|
||||
ret = false;
|
||||
continue;
|
||||
}
|
||||
|
||||
ps::core::QueryFinishTransformRespMessage resp_msg;
|
||||
MS_EXCEPTION_IF_NULL(output);
|
||||
(void)resp_msg.ParseFromArray(output->data(), SizeToInt(output->size()));
|
||||
ret = resp_msg.is_ready();
|
||||
if (!ret) {
|
||||
MS_LOG(INFO) << "There is worker which has not finished transform graph";
|
||||
}
|
||||
|
||||
if (resp_msg.is_worker_timeout()) {
|
||||
MS_LOG(WARNING) << "There is worker timeout";
|
||||
return false;
|
||||
}
|
||||
// The time interval for querying the all worker finish transform graphs status to scheduler: 10 seconds.
|
||||
const uint32_t kWaitDuration = 10;
|
||||
std::this_thread::sleep_for(std::chrono::seconds(kWaitDuration));
|
||||
}
|
||||
|
||||
return ret;
|
||||
}
|
||||
|
||||
void DoDisasterRecovery() {
|
||||
if (RecoveryContext::GetInstance()->enable_recovery() && CollectiveManager::instance()->need_reinit()) {
|
||||
MS_LOG(INFO) << "Begin reinitialize collective communication for recovery.";
|
||||
bool ret = false;
|
||||
while (!ret) {
|
||||
while (!CollectiveManager::instance()->Initialize()) {
|
||||
MS_LOG(WARNING) << "ReInitialize collective communication failed, retrying...";
|
||||
}
|
||||
MS_LOG(INFO) << "Finish reinitialize collective communication for recovery.";
|
||||
|
||||
RecoveryContext::GetInstance()->ObtainGlobalLatestCkptInfo();
|
||||
|
||||
ret = QueryFinishTransform();
|
||||
if (!ret) {
|
||||
CollectiveManager::instance()->set_need_reinit(true);
|
||||
(void)CollectiveManager::instance()->Finalize();
|
||||
}
|
||||
}
|
||||
|
||||
RecoveryContext::GetInstance()->set_need_reset(true);
|
||||
RecoveryContext::GetInstance()->set_need_sync_weight_to_device(true);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
} // namespace
|
||||
|
||||
GraphScheduler &GraphScheduler::GetInstance() noexcept {
|
||||
|
@ -407,6 +517,17 @@ ActorSet *GraphScheduler::Transform(const GraphCompilerInfo &graph_compiler_info
|
|||
}
|
||||
MS_LOG(INFO) << "Graph(" << graph_compiler_info.name_ << ") transforms actor end.";
|
||||
|
||||
#if ((defined ENABLE_CPU) && (!defined _WIN32) && (!defined _WIN64))
|
||||
if (ClusterContext::instance()->initialized() && RecoveryContext::GetInstance()->enable_recovery()) {
|
||||
while (!SendFinishTransform()) {
|
||||
MS_LOG(WARNING) << "Send finish transform graph failed.";
|
||||
// The time interval for sending finish transform graph to scheduler.
|
||||
constexpr uint32_t kWaitDuration = 10;
|
||||
std::this_thread::sleep_for(std::chrono::seconds(kWaitDuration));
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
return actor_set.get();
|
||||
}
|
||||
|
||||
|
@ -489,15 +610,9 @@ void GraphScheduler::Run(ActorSet *const actor_set, const std::vector<DeviceCont
|
|||
const size_t kSecondsToMilliseconds = 1000;
|
||||
SetActorExecutionStrategy(actor_set, strategy, (end_time - start_time) * kSecondsToMilliseconds);
|
||||
|
||||
if (RecoveryContext::GetInstance()->enable_recovery() && RecoveryContext::GetInstance()->need_reinit_collective()) {
|
||||
MS_LOG(INFO) << "Begin reinitialize collective communication for recovery.";
|
||||
if (!RecoveryContext::GetInstance()->ReInitializeCollective()) {
|
||||
MS_LOG(EXCEPTION) << "Reinitialize collective communication failed.";
|
||||
}
|
||||
MS_LOG(INFO) << "Finish reinitialize collective communication for recovery.";
|
||||
|
||||
RecoveryContext::GetInstance()->set_need_reinit_collective(false);
|
||||
}
|
||||
#if ((defined ENABLE_CPU) && (!defined _WIN32) && (!defined _WIN64))
|
||||
DoDisasterRecovery();
|
||||
#endif
|
||||
}
|
||||
|
||||
void GraphScheduler::SetActorExecutionStrategy(ActorSet *const actor_set, GraphExecutionStrategy strategy,
|
||||
|
|
|
@ -1,9 +0,0 @@
|
|||
file(GLOB_RECURSE RECOVERY_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
|
||||
"recovery_context.cc")
|
||||
|
||||
if(${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-delete-abstract-non-virtual-dtor")
|
||||
endif()
|
||||
|
||||
set_property(SOURCE ${RECOVERY_SRC_LIST} PROPERTY SUBMODULE_ID=mindspore::SubModuleId::SM_DEVICE)
|
||||
add_library(_mindspore_runtime_recovery_obj OBJECT ${RECOVERY_SRC_LIST})
|
|
@ -306,8 +306,7 @@ add_library(backend_static STATIC
|
|||
$<TARGET_OBJECTS:_mindspore_runtime_device_obj>
|
||||
$<TARGET_OBJECTS:_mindspore_runtime_graph_scheduler_obj>
|
||||
$<TARGET_OBJECTS:_mindspore_runtime_hardware_obj>
|
||||
$<TARGET_OBJECTS:_mindspore_runtime_pynative_obj>
|
||||
$<TARGET_OBJECTS:_mindspore_runtime_recovery_obj>)
|
||||
$<TARGET_OBJECTS:_mindspore_runtime_pynative_obj>)
|
||||
target_link_libraries(ut_tests PRIVATE mindspore securec -Wl,--start-group proto_input mindspore::protobuf
|
||||
backend_static -Wl,--end-group)
|
||||
target_link_libraries(ut_tests PRIVATE mindspore::grpc++)
|
Loading…
Reference in New Issue