forked from mindspore-Ecosystem/mindspore
!31325 Training supports failover
Merge pull request !31325 from zyli2020/worker_failover_bp
This commit is contained in:
commit
c3d15a079e
|
@ -37,6 +37,7 @@
|
||||||
#include "runtime/hardware/device_context_manager.h"
|
#include "runtime/hardware/device_context_manager.h"
|
||||||
#include "runtime/graph_scheduler/graph_compiler.h"
|
#include "runtime/graph_scheduler/graph_compiler.h"
|
||||||
#include "runtime/pynative/run_op_helper.h"
|
#include "runtime/pynative/run_op_helper.h"
|
||||||
|
#include "runtime/recovery/recovery_context.h"
|
||||||
#include "include/common/utils/scoped_long_running.h"
|
#include "include/common/utils/scoped_long_running.h"
|
||||||
#ifdef ENABLE_D
|
#ifdef ENABLE_D
|
||||||
#include "include/common/utils/callbacks_ge.h"
|
#include "include/common/utils/callbacks_ge.h"
|
||||||
|
@ -998,17 +999,22 @@ void MindRTBackend::RunGraph(const ActorInfo &actor_info, const VectorRef &args,
|
||||||
MS_EXCEPTION_IF_NULL(graph_compiler_);
|
MS_EXCEPTION_IF_NULL(graph_compiler_);
|
||||||
graph_compiler_->Summary(graph_compiler_info.graphs_);
|
graph_compiler_->Summary(graph_compiler_info.graphs_);
|
||||||
|
|
||||||
// Update device address for output node of graph.
|
bool need_contruct_output = !(runtime::recovery::RecoveryContext::GetInstance()->enable_recovery() &&
|
||||||
// Summary processing will use the output device address, so must be after the summary processing.
|
runtime::recovery::RecoveryContext::GetInstance()->need_reset());
|
||||||
actor_set->output_actor_->UpdateOutputDeviceAddress();
|
if (need_contruct_output) {
|
||||||
|
// Update device address for output node of graph.
|
||||||
|
// Summary processing will use the output device address, so must be after the summary processing.
|
||||||
|
actor_set->output_actor_->UpdateOutputDeviceAddress();
|
||||||
|
|
||||||
// Fetch outputs.
|
// Fetch outputs.
|
||||||
MS_EXCEPTION_IF_NULL(actor_set->output_actor_);
|
MS_EXCEPTION_IF_NULL(actor_set->output_actor_);
|
||||||
auto &output_tensors = actor_set->output_actor_->outputs();
|
auto &output_tensors = actor_set->output_actor_->outputs();
|
||||||
if (output_tensors.size() > 0) {
|
if (output_tensors.size() > 0) {
|
||||||
size_t output_position = 0;
|
size_t output_position = 0;
|
||||||
ConstructOutputs(root_graph_->output(), output_tensors, &output_position, outputs);
|
ConstructOutputs(root_graph_->output(), output_tensors, &output_position, outputs);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
runtime::GraphScheduler::GetInstance().ClearActorData(actor_set);
|
runtime::GraphScheduler::GetInstance().ClearActorData(actor_set);
|
||||||
// Close abstract_lock for dynamic_shape
|
// Close abstract_lock for dynamic_shape
|
||||||
AnfUtils::CloseAbstractLock();
|
AnfUtils::CloseAbstractLock();
|
||||||
|
|
|
@ -30,6 +30,7 @@ namespace cluster {
|
||||||
ClusterContext::ClusterContext()
|
ClusterContext::ClusterContext()
|
||||||
: inited_(false),
|
: inited_(false),
|
||||||
finalized_(true),
|
finalized_(true),
|
||||||
|
cluster_ready_(false),
|
||||||
node_num_each_role_({}),
|
node_num_each_role_({}),
|
||||||
scheduler_host_(kLocalHost),
|
scheduler_host_(kLocalHost),
|
||||||
scheduler_port_(kDefaultSchedPort),
|
scheduler_port_(kDefaultSchedPort),
|
||||||
|
@ -151,11 +152,11 @@ void ClusterContext::InitClusterConfig() {
|
||||||
bool ClusterContext::BuildCluster() {
|
bool ClusterContext::BuildCluster() {
|
||||||
// Create node according to different role.
|
// Create node according to different role.
|
||||||
if (node_role_ == kEnvRoleOfWorker) {
|
if (node_role_ == kEnvRoleOfWorker) {
|
||||||
node_ = std::make_shared<ps::core::WorkerNode>();
|
node_ = std::make_shared<ps::core::PSWorkerNode>();
|
||||||
} else if (node_role_ == kEnvRoleOfServer) {
|
} else if (node_role_ == kEnvRoleOfServer) {
|
||||||
node_ = std::make_shared<ps::core::ServerNode>();
|
node_ = std::make_shared<ps::core::PSServerNode>();
|
||||||
} else if (node_role_ == kEnvRoleOfScheduler) {
|
} else if (node_role_ == kEnvRoleOfScheduler) {
|
||||||
node_ = std::make_shared<ps::core::SchedulerNode>();
|
node_ = std::make_shared<ps::core::PSSchedulerNode>();
|
||||||
} else {
|
} else {
|
||||||
MS_LOG(EXCEPTION) << "The role " << node_role_ << " is invalid.";
|
MS_LOG(EXCEPTION) << "The role " << node_role_ << " is invalid.";
|
||||||
return false;
|
return false;
|
||||||
|
@ -258,8 +259,20 @@ void ClusterContext::RegisterEventCallback() {
|
||||||
MsException::Instance().SetException();
|
MsException::Instance().SetException();
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
|
abstract_node->RegisterEventCallback(ps::core::ClusterEvent::ON_SEND_META_DATA,
|
||||||
|
[this]() { cluster_ready_ = true; });
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Block the calling thread until the ON_SEND_META_DATA callback has raised 'cluster_ready_', polling every 200 ms.
// The flag is lowered again before returning, so each call waits for a fresh notification.
void ClusterContext::WaitForClusterReady() {
  const auto poll_interval = std::chrono::milliseconds(200);
  for (;;) {
    if (cluster_ready_) {
      break;
    }
    std::this_thread::sleep_for(poll_interval);
  }

  // Re-arm the flag for the next networking / re-networking round.
  cluster_ready_ = false;
}
|
||||||
} // namespace cluster
|
} // namespace cluster
|
||||||
} // namespace distributed
|
} // namespace distributed
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -33,6 +33,9 @@
|
||||||
#include "ps/core/worker_node.h"
|
#include "ps/core/worker_node.h"
|
||||||
#include "ps/core/server_node.h"
|
#include "ps/core/server_node.h"
|
||||||
#include "ps/core/scheduler_node.h"
|
#include "ps/core/scheduler_node.h"
|
||||||
|
#include "ps/core/ps_worker_node.h"
|
||||||
|
#include "ps/core/ps_server_node.h"
|
||||||
|
#include "ps/core/ps_scheduler_node.h"
|
||||||
#include "distributed/cluster/actor_route_table_proxy.h"
|
#include "distributed/cluster/actor_route_table_proxy.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
|
@ -77,6 +80,10 @@ class ClusterContext {
|
||||||
// Return actor route proxy for AbstractNode.
|
// Return actor route proxy for AbstractNode.
|
||||||
const ActorRouteTableProxyPtr &actor_route_table_proxy() const;
|
const ActorRouteTableProxyPtr &actor_route_table_proxy() const;
|
||||||
|
|
||||||
|
// Wait until cluster (re-)networking succeeds; used in disaster recovery to prevent collective communication
|
||||||
|
// ops flapping.
|
||||||
|
void WaitForClusterReady();
|
||||||
|
|
||||||
private:
|
private:
|
||||||
ClusterContext();
|
ClusterContext();
|
||||||
|
|
||||||
|
@ -101,6 +108,9 @@ class ClusterContext {
|
||||||
// The flag that whether this cluster context instance is already finalized.
|
// The flag that whether this cluster context instance is already finalized.
|
||||||
std::atomic_bool finalized_;
|
std::atomic_bool finalized_;
|
||||||
|
|
||||||
|
// Whether the cluster networking or re-networking has completed successfully.
|
||||||
|
std::atomic_bool cluster_ready_;
|
||||||
|
|
||||||
// The mutex about exiting status of this node.
|
// The mutex about exiting status of this node.
|
||||||
std::mutex finish_mutex_;
|
std::mutex finish_mutex_;
|
||||||
|
|
||||||
|
|
|
@ -19,6 +19,7 @@
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include "utils/ms_context.h"
|
#include "utils/ms_context.h"
|
||||||
|
#include "runtime/recovery/recovery_context.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace distributed {
|
namespace distributed {
|
||||||
|
@ -60,7 +61,7 @@ std::shared_ptr<CollectiveManager> CollectiveManager::instance() {
|
||||||
}
|
}
|
||||||
|
|
||||||
bool CollectiveManager::Initialize() {
|
bool CollectiveManager::Initialize() {
|
||||||
if (inited_) {
|
if (inited_ && !runtime::recovery::RecoveryContext::GetInstance()->need_reinit_collective()) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -127,15 +128,51 @@ bool CollectiveManager::CreateCommunicationGroup(const std::string &group_name,
|
||||||
if (!host_comm_lib_instance_->Broadcast(root_info, root_info, root_info_size, TypeId::kNumberTypeInt8, 0,
|
if (!host_comm_lib_instance_->Broadcast(root_info, root_info, root_info_size, TypeId::kNumberTypeInt8, 0,
|
||||||
group_name)) {
|
group_name)) {
|
||||||
MS_LOG(ERROR) << "Broadcast for device root info failed on the host side.";
|
MS_LOG(ERROR) << "Broadcast for device root info failed on the host side.";
|
||||||
|
if (runtime::recovery::RecoveryContext::GetInstance()->enable_recovery()) {
|
||||||
|
runtime::recovery::RecoveryContext::GetInstance()->set_recovery_status(
|
||||||
|
runtime::recovery::RecoveryErrCode::kBroadcastUniqueIDFailed);
|
||||||
|
}
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Step 5: Initialize communication group on the device side.
|
// Step 5: Initialize communication group on the device side.
|
||||||
if (!group->Initialize(root_info)) {
|
return InitDeviceCommGroup(group, root_info);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool CollectiveManager::InitDeviceCommGroup(const CommunicationGroupPtr &group, void *root_info) {
|
||||||
|
bool init_group_success = false;
|
||||||
|
bool init_group_fail = false;
|
||||||
|
std::condition_variable thread_blocker;
|
||||||
|
init_group_thread_ = std::make_unique<std::thread>([&] {
|
||||||
|
device_ctx_->Initialize();
|
||||||
|
if (!group->Initialize(root_info)) {
|
||||||
|
MS_LOG(ERROR) << "Initialize group on the device side failed.";
|
||||||
|
std::unique_lock<std::mutex> lock(init_group_mutex_);
|
||||||
|
init_group_fail = true;
|
||||||
|
thread_blocker.notify_one();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
std::unique_lock<std::mutex> lock(init_group_mutex_);
|
||||||
|
init_group_success = true;
|
||||||
|
thread_blocker.notify_one();
|
||||||
|
}
|
||||||
|
});
|
||||||
|
init_group_thread_->detach();
|
||||||
|
|
||||||
|
// Timeout limit 180 seconds to wait finishing init device communication group.
|
||||||
|
const int64_t kTimeToWait = 180;
|
||||||
|
std::unique_lock<std::mutex> locker(init_group_mutex_);
|
||||||
|
(void)thread_blocker.wait_for(locker, std::chrono::seconds(kTimeToWait),
|
||||||
|
[&] { return init_group_success || init_group_fail; });
|
||||||
|
|
||||||
|
if (!init_group_success && runtime::recovery::RecoveryContext::GetInstance()->enable_recovery()) {
|
||||||
|
runtime::recovery::RecoveryContext::GetInstance()->set_recovery_status(
|
||||||
|
runtime::recovery::RecoveryErrCode::kInitNcclFailed);
|
||||||
MS_LOG(ERROR) << "Initialize group on the device side failed.";
|
MS_LOG(ERROR) << "Initialize group on the device side failed.";
|
||||||
return false;
|
|
||||||
}
|
}
|
||||||
return true;
|
return init_group_success;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool CollectiveManager::DestroyCommunicationGroup(const std::string &group_name) {
|
bool CollectiveManager::DestroyCommunicationGroup(const std::string &group_name) {
|
||||||
|
@ -208,6 +245,10 @@ bool CollectiveManager::InitHostCommlib() {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (!global_group_ranks_.empty()) {
|
||||||
|
global_group_ranks_.clear();
|
||||||
|
}
|
||||||
|
|
||||||
// Reassign 'global_rank_id_' and 'global_rank_size_'. Generate global communication group ranks.
|
// Reassign 'global_rank_id_' and 'global_rank_size_'. Generate global communication group ranks.
|
||||||
global_rank_id_ = host_comm_lib_instance_->global_rank_id();
|
global_rank_id_ = host_comm_lib_instance_->global_rank_id();
|
||||||
global_rank_size_ = host_comm_lib_instance_->global_rank_size();
|
global_rank_size_ = host_comm_lib_instance_->global_rank_size();
|
||||||
|
@ -276,11 +317,18 @@ bool CollectiveManager::AssignLocalRank() {
|
||||||
// AllGather host names across the global communication group.
|
// AllGather host names across the global communication group.
|
||||||
if (!host_comm_lib_instance_->AllGather(&host_hash, all_host_hashs, 1, TypeId::kNumberTypeUInt64,
|
if (!host_comm_lib_instance_->AllGather(&host_hash, all_host_hashs, 1, TypeId::kNumberTypeUInt64,
|
||||||
host_global_group_name_)) {
|
host_global_group_name_)) {
|
||||||
|
if (runtime::recovery::RecoveryContext::GetInstance()->enable_recovery()) {
|
||||||
|
runtime::recovery::RecoveryContext::GetInstance()->set_recovery_status(
|
||||||
|
runtime::recovery::RecoveryErrCode::kAllGatherHostNameFailed);
|
||||||
|
}
|
||||||
MS_LOG(ERROR) << "AllGather for host names failed.";
|
MS_LOG(ERROR) << "AllGather for host names failed.";
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Accumulate rank id.
|
// Accumulate rank id.
|
||||||
|
// In disaster recovery scenario, this function will enter multiple times when the network is reconfigured, so old
|
||||||
|
// local rank id need to be cleaned.
|
||||||
|
local_rank_id_ = 0;
|
||||||
for (uint32_t rank = 0; rank < global_rank_size_; rank++) {
|
for (uint32_t rank = 0; rank < global_rank_size_; rank++) {
|
||||||
if (rank == global_rank_id_) {
|
if (rank == global_rank_id_) {
|
||||||
break;
|
break;
|
||||||
|
|
|
@ -80,6 +80,13 @@ class CollectiveManager {
|
||||||
// Assign the local rank id for this process.
|
// Assign the local rank id for this process.
|
||||||
bool AssignLocalRank();
|
bool AssignLocalRank();
|
||||||
|
|
||||||
|
// Initialize communication group on the device side.
|
||||||
|
bool InitDeviceCommGroup(const CommunicationGroupPtr &group, void *root_info);
|
||||||
|
|
||||||
|
// Initialize communication group on the device side in thread with timeout limit.
|
||||||
|
std::unique_ptr<std::thread> init_group_thread_;
|
||||||
|
std::mutex init_group_mutex_;
|
||||||
|
|
||||||
std::atomic_bool inited_;
|
std::atomic_bool inited_;
|
||||||
std::atomic_bool finalized_;
|
std::atomic_bool finalized_;
|
||||||
|
|
||||||
|
|
|
@ -17,9 +17,12 @@
|
||||||
#include "distributed/init.h"
|
#include "distributed/init.h"
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
#include "runtime/recovery/recovery_context.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace distributed {
|
namespace distributed {
|
||||||
|
using runtime::recovery::RecoveryContext;
|
||||||
|
|
||||||
bool Initialize() {
|
bool Initialize() {
|
||||||
if (!InitializeCluster()) {
|
if (!InitializeCluster()) {
|
||||||
MS_LOG(ERROR) << "Failed to initialize cluster.";
|
MS_LOG(ERROR) << "Failed to initialize cluster.";
|
||||||
|
@ -39,10 +42,20 @@ bool Initialize() {
|
||||||
collective::CollectiveManager::instance()->set_global_rank_id(abstract_node->rank_id());
|
collective::CollectiveManager::instance()->set_global_rank_id(abstract_node->rank_id());
|
||||||
collective::CollectiveManager::instance()->set_global_rank_size(abstract_node->worker_num());
|
collective::CollectiveManager::instance()->set_global_rank_size(abstract_node->worker_num());
|
||||||
|
|
||||||
|
if (RecoveryContext::GetInstance()->enable_recovery()) {
|
||||||
|
cluster::ClusterContext::instance()->WaitForClusterReady();
|
||||||
|
}
|
||||||
|
|
||||||
if (!InitializeCollective()) {
|
if (!InitializeCollective()) {
|
||||||
MS_LOG(ERROR) << "Failed to initialize collective communication.";
|
MS_LOG(ERROR) << "Failed to initialize collective communication.";
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (RecoveryContext::GetInstance()->enable_recovery()) {
|
||||||
|
RecoveryContext::GetInstance()->set_global_rank_id(abstract_node->rank_id());
|
||||||
|
RecoveryContext::GetInstance()->set_global_rank_size(abstract_node->worker_num());
|
||||||
|
RecoveryContext::GetInstance()->ObtainGlobalLatestCkptInfo();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
|
@ -155,9 +155,13 @@ void FileIOUtils::CreateDir(const std::string &dir_path, mode_t mode) {
|
||||||
}
|
}
|
||||||
|
|
||||||
void FileIOUtils::CreateDirRecursive(const std::string &dir_path, mode_t mode) {
|
void FileIOUtils::CreateDirRecursive(const std::string &dir_path, mode_t mode) {
|
||||||
|
if (dir_path.empty()) {
|
||||||
|
MS_LOG(EXCEPTION) << "The directory path need to be create is empty";
|
||||||
|
}
|
||||||
size_t dir_path_len = dir_path.length();
|
size_t dir_path_len = dir_path.length();
|
||||||
if (dir_path_len > PATH_MAX) {
|
if (dir_path_len > PATH_MAX) {
|
||||||
MS_LOG(EXCEPTION) << "Directory path is too long: " << dir_path;
|
MS_LOG(EXCEPTION) << "Directory path is too long to exceed max length limit: " << PATH_MAX
|
||||||
|
<< ", the path: " << dir_path;
|
||||||
}
|
}
|
||||||
|
|
||||||
char tmp_dir_path[PATH_MAX] = {0};
|
char tmp_dir_path[PATH_MAX] = {0};
|
||||||
|
|
|
@ -17,6 +17,7 @@
|
||||||
#include "fl/server/collective_ops_impl.h"
|
#include "fl/server/collective_ops_impl.h"
|
||||||
#include "fl/server/local_meta_store.h"
|
#include "fl/server/local_meta_store.h"
|
||||||
#include "fl/server/iteration.h"
|
#include "fl/server/iteration.h"
|
||||||
|
#include "utils/ms_context.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace fl {
|
namespace fl {
|
||||||
|
@ -323,6 +324,12 @@ bool CollectiveOpsImpl::RingAllGather(const void *sendbuff, void *const recvbuff
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
auto context_ptr = MsContext::GetInstance();
|
||||||
|
MS_EXCEPTION_IF_NULL(context_ptr);
|
||||||
|
// If enable recovery, set timeout 300s to prevent networking flapping.
|
||||||
|
uint32_t collective_comm_timeout =
|
||||||
|
context_ptr->get_param<bool>(MS_CTX_ENABLE_RECOVERY) ? kCollectiveCommMaxTimeout : kCollectiveCommTimeout;
|
||||||
|
|
||||||
// Ring AllGather.
|
// Ring AllGather.
|
||||||
for (size_t i = 0; i < rank_size_ - 1; i++) {
|
for (size_t i = 0; i < rank_size_ - 1; i++) {
|
||||||
size_t send_chunk_index = (rank_id_ - i + rank_size_) % rank_size_;
|
size_t send_chunk_index = (rank_id_ - i + rank_size_) % rank_size_;
|
||||||
|
@ -337,7 +344,7 @@ bool CollectiveOpsImpl::RingAllGather(const void *sendbuff, void *const recvbuff
|
||||||
|
|
||||||
std::shared_ptr<std::vector<unsigned char>> recv_str;
|
std::shared_ptr<std::vector<unsigned char>> recv_str;
|
||||||
auto recv_req_id = node_->CollectiveReceiveAsync(node_role_, recv_from_rank, &recv_str);
|
auto recv_req_id = node_->CollectiveReceiveAsync(node_role_, recv_from_rank, &recv_str);
|
||||||
if (!node_->CollectiveWait(recv_req_id, kCollectiveCommTimeout)) {
|
if (!node_->CollectiveWait(recv_req_id, collective_comm_timeout)) {
|
||||||
MS_LOG(ERROR) << "CollectiveWait " << recv_req_id << " failed.";
|
MS_LOG(ERROR) << "CollectiveWait " << recv_req_id << " failed.";
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
@ -348,7 +355,7 @@ bool CollectiveOpsImpl::RingAllGather(const void *sendbuff, void *const recvbuff
|
||||||
<< recv_str->size();
|
<< recv_str->size();
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
if (!node_->Wait(send_req_id, kCollectiveCommTimeout)) {
|
if (!node_->Wait(send_req_id, collective_comm_timeout)) {
|
||||||
MS_LOG(ERROR) << "CollectiveWait " << send_req_id << " failed.";
|
MS_LOG(ERROR) << "CollectiveWait " << send_req_id << " failed.";
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
|
@ -32,6 +32,8 @@ namespace fl {
|
||||||
namespace server {
|
namespace server {
|
||||||
// The timeout for server collective communication in case of network jitter.
|
// The timeout for server collective communication in case of network jitter.
|
||||||
constexpr uint32_t kCollectiveCommTimeout = 30;
|
constexpr uint32_t kCollectiveCommTimeout = 30;
|
||||||
|
// The max timeout for server collective communication, used in disaster recovery to prevent networking flapping.
|
||||||
|
constexpr uint32_t kCollectiveCommMaxTimeout = 300;
|
||||||
|
|
||||||
// The collective communication groups which are composed of multiple processes. Refer to MPI_Group.
|
// The collective communication groups which are composed of multiple processes. Refer to MPI_Group.
|
||||||
struct CommunicationGroupInfo {
|
struct CommunicationGroupInfo {
|
||||||
|
|
|
@ -39,6 +39,14 @@ namespace gpu {
|
||||||
} \
|
} \
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Evaluates 'expression' once; when it is false, raises through MS_LOG(EXCEPTION) with 'message' followed by the
// source lines of 'node'. 'node' must provide lock() (presumably a weak pointer to the kernel node - TODO confirm).
#define CHECK_OP_RET_WITH_EXCEPT_TRANCE(node, expression, message) \
|
||||||
|
{ \
|
||||||
|
bool success = (expression); \
|
||||||
|
if (!success) { \
|
||||||
|
MS_LOG(EXCEPTION) << "Op Error: " << message << " | " << trace::DumpSourceLines(node.lock()); \
|
||||||
|
} \
|
||||||
|
}
|
||||||
|
|
||||||
#define CHECK_OP_RET_WITH_ERROR(expression, message) \
|
#define CHECK_OP_RET_WITH_ERROR(expression, message) \
|
||||||
{ \
|
{ \
|
||||||
bool success = (expression); \
|
bool success = (expression); \
|
||||||
|
|
|
@ -48,7 +48,6 @@ bool NvidiaCommunicationGroup::Finalize() {
|
||||||
// because 'ncclCommAbort' will abort any uncompleted operations before destroying the communicator, e.g.,
|
// because 'ncclCommAbort' will abort any uncompleted operations before destroying the communicator, e.g.,
|
||||||
// ncclAllReduce.
|
// ncclAllReduce.
|
||||||
CHECK_RET(ncclCommAbort(comm_), ncclSuccess, "Failed to abort NCCL communicator.");
|
CHECK_RET(ncclCommAbort(comm_), ncclSuccess, "Failed to abort NCCL communicator.");
|
||||||
CHECK_RET(ncclCommDestroy(comm_), ncclSuccess, "Failed to destroy NCCL communicator.");
|
|
||||||
initialized_ = false;
|
initialized_ = false;
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
|
@ -724,9 +724,8 @@ void AbstractNode::ProcessHeartbeatResp(const std::shared_ptr<MessageMeta> &meta
|
||||||
}
|
}
|
||||||
|
|
||||||
bool is_worker = heartbeat_resp_message.is_worker();
|
bool is_worker = heartbeat_resp_message.is_worker();
|
||||||
bool is_fl_mode = PSContext::instance()->server_mode() == ps::kServerModeFL ||
|
bool is_ps_mode = PSContext::instance()->server_mode() == ps::kServerModePS;
|
||||||
PSContext::instance()->server_mode() == ps::kServerModeHybrid;
|
bool not_enable_recover_node_timeout = (is_worker && is_ps_mode);
|
||||||
bool not_enable_recover_node_timeout = (is_worker && !is_fl_mode);
|
|
||||||
|
|
||||||
if (current_cluster_state_ == ClusterState::NODE_TIMEOUT) {
|
if (current_cluster_state_ == ClusterState::NODE_TIMEOUT) {
|
||||||
if (node_recovery_ == nullptr || not_enable_recover_node_timeout) {
|
if (node_recovery_ == nullptr || not_enable_recover_node_timeout) {
|
||||||
|
@ -853,6 +852,8 @@ void AbstractNode::ProcessSendMetadata(const std::shared_ptr<TcpConnection> &con
|
||||||
|
|
||||||
std::lock_guard<std::mutex> lock(client_mutex_);
|
std::lock_guard<std::mutex> lock(client_mutex_);
|
||||||
connected_nodes_.clear();
|
connected_nodes_.clear();
|
||||||
|
|
||||||
|
OnEventCallback(ClusterEvent::ON_SEND_META_DATA);
|
||||||
}
|
}
|
||||||
|
|
||||||
void AbstractNode::ProcessFinish(const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta,
|
void AbstractNode::ProcessFinish(const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta,
|
||||||
|
|
|
@ -29,6 +29,11 @@ void AbstractPSNode::StartHeartbeatTimer() {
|
||||||
if (!DoHeartbeat()) {
|
if (!DoHeartbeat()) {
|
||||||
MS_LOG(WARNING)
|
MS_LOG(WARNING)
|
||||||
<< "Heartbeat timeout, the tcp connection to scheduler is lost, please check the status of scheduler.";
|
<< "Heartbeat timeout, the tcp connection to scheduler is lost, please check the status of scheduler.";
|
||||||
|
|
||||||
|
if (CheckSchedulerTimeout()) {
|
||||||
|
MS_LOG(WARNING) << "Scheduler is Timeout, please recovery.";
|
||||||
|
}
|
||||||
|
|
||||||
HandleHeartbeatTimeout();
|
HandleHeartbeatTimeout();
|
||||||
} else {
|
} else {
|
||||||
UpdateSchedulerTime();
|
UpdateSchedulerTime();
|
||||||
|
@ -115,8 +120,11 @@ bool AbstractPSNode::InitClientToScheduler() {
|
||||||
});
|
});
|
||||||
client_to_scheduler_thread_->detach();
|
client_to_scheduler_thread_->detach();
|
||||||
|
|
||||||
// Timeout for waiting for the tcp connection to the scheduler.
|
// Timeout for waiting for the tcp connection to the scheduler, 10 seconds in recovery mode, or 900 seconds for first
|
||||||
uint32_t timeout = 10;
|
// build connection to scheduler.
|
||||||
|
const uint32_t timeout_for_reinit_in_recovery = 10;
|
||||||
|
uint32_t timeout = heartbeat_stopped_ ? timeout_for_reinit_in_recovery
|
||||||
|
: PSContext::instance()->cluster_config().cluster_available_timeout;
|
||||||
bool wait_res = client_to_scheduler_->WaitConnected(timeout);
|
bool wait_res = client_to_scheduler_->WaitConnected(timeout);
|
||||||
if (!wait_res) {
|
if (!wait_res) {
|
||||||
is_ready_ = true;
|
is_ready_ = true;
|
||||||
|
|
|
@ -35,6 +35,7 @@ enum class ClusterEvent {
|
||||||
CLUSTER_SCALE_IN_DONE = 6,
|
CLUSTER_SCALE_IN_DONE = 6,
|
||||||
ON_PREPARE_PERSIST = 7,
|
ON_PREPARE_PERSIST = 7,
|
||||||
ON_BEGIN_PERSIST = 8,
|
ON_BEGIN_PERSIST = 8,
|
||||||
|
ON_SEND_META_DATA = 9,
|
||||||
};
|
};
|
||||||
|
|
||||||
struct NodeInfo {
|
struct NodeInfo {
|
||||||
|
|
|
@ -15,6 +15,7 @@
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#include "ps/core/node_manager.h"
|
#include "ps/core/node_manager.h"
|
||||||
|
#include "utils/ms_context.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ps {
|
namespace ps {
|
||||||
|
@ -45,23 +46,29 @@ uint32_t NodeManager::checkIfRankIdExist(const RegisterMessage ®ister_message
|
||||||
return rank_id;
|
return rank_id;
|
||||||
}
|
}
|
||||||
// This is for scheduler recovery
|
// This is for scheduler recovery
|
||||||
ReAddNodeIfNotExists(node_id, register_message.ip(), register_message.port());
|
(void)ReAddNodeIfNotExists(node_id, register_message.ip(), register_message.port(), &rank_id);
|
||||||
return rank_id;
|
return rank_id;
|
||||||
}
|
}
|
||||||
|
|
||||||
void NodeManager::ReAddNodeIfNotExists(const std::string &node_id, const std::string &ip, uint32_t port) {
|
bool NodeManager::ReAddNodeIfNotExists(const std::string &node_id, const std::string &ip, uint32_t port,
|
||||||
|
uint32_t *rank_id) {
|
||||||
core::ClusterConfig &clusterConfig = PSContext::instance()->cluster_config();
|
core::ClusterConfig &clusterConfig = PSContext::instance()->cluster_config();
|
||||||
std::unordered_map<std::string, NodeInfo> recovery_node_infos = clusterConfig.initial_registered_nodes_infos;
|
std::unordered_map<std::string, NodeInfo> recovery_node_infos = clusterConfig.initial_registered_nodes_infos;
|
||||||
|
|
||||||
if (registered_nodes_info_.find(node_id) == registered_nodes_info_.end() &&
|
if (registered_nodes_info_.find(node_id) == registered_nodes_info_.end() &&
|
||||||
recovery_node_infos.find(node_id) != recovery_node_infos.end()) {
|
recovery_node_infos.find(node_id) != recovery_node_infos.end()) {
|
||||||
|
if (rank_id != nullptr) {
|
||||||
|
*rank_id = recovery_node_infos[node_id].rank_id_;
|
||||||
|
}
|
||||||
recovery_node_infos[node_id].is_alive = true;
|
recovery_node_infos[node_id].is_alive = true;
|
||||||
recovery_node_infos[node_id].ip_ = ip;
|
recovery_node_infos[node_id].ip_ = ip;
|
||||||
recovery_node_infos[node_id].port_ = static_cast<uint16_t>(port);
|
recovery_node_infos[node_id].port_ = static_cast<uint16_t>(port);
|
||||||
registered_nodes_info_[node_id] = recovery_node_infos[node_id];
|
registered_nodes_info_[node_id] = recovery_node_infos[node_id];
|
||||||
MS_LOG(INFO) << "The node id: " << node_id << " is recovery successful!"
|
MS_LOG(INFO) << "The node id: " << node_id << " is recovery successful!"
|
||||||
<< ", ip: " << ip << ", port: " << port;
|
<< ", ip: " << ip << ", port: " << port;
|
||||||
|
return true;
|
||||||
}
|
}
|
||||||
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
uint32_t NodeManager::NextRankId(const RegisterMessage ®ister_message, const std::shared_ptr<MessageMeta> &meta) {
|
uint32_t NodeManager::NextRankId(const RegisterMessage ®ister_message, const std::shared_ptr<MessageMeta> &meta) {
|
||||||
|
@ -219,9 +226,14 @@ void NodeManager::UpdateCluster() {
|
||||||
|
|
||||||
if (!timeout_nodes_info_.empty()) {
|
if (!timeout_nodes_info_.empty()) {
|
||||||
UpdateClusterState(ClusterState::NODE_TIMEOUT);
|
UpdateClusterState(ClusterState::NODE_TIMEOUT);
|
||||||
for (auto iter = timeout_nodes_info_.begin(); iter != timeout_nodes_info_.end(); ++iter) {
|
|
||||||
(void)heartbeats_.erase(iter->first);
|
auto context_ptr = MsContext::GetInstance();
|
||||||
finish_nodes_id_.insert(iter->first);
|
MS_EXCEPTION_IF_NULL(context_ptr);
|
||||||
|
if (!context_ptr->get_param<bool>(MS_CTX_ENABLE_RECOVERY)) {
|
||||||
|
for (auto iter = timeout_nodes_info_.begin(); iter != timeout_nodes_info_.end(); ++iter) {
|
||||||
|
(void)heartbeats_.erase(iter->first);
|
||||||
|
finish_nodes_id_.insert(iter->first);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if (onPersist_) {
|
if (onPersist_) {
|
||||||
onPersist_();
|
onPersist_();
|
||||||
|
|
|
@ -59,7 +59,8 @@ class NodeManager {
|
||||||
uint32_t checkIfRankIdExist(const RegisterMessage ®ister_message);
|
uint32_t checkIfRankIdExist(const RegisterMessage ®ister_message);
|
||||||
|
|
||||||
// Re-Add the server or worker node into the registered node list if the node do not existed in the scheduler.
|
// Re-Add the server or worker node into the registered node list if the node do not existed in the scheduler.
|
||||||
void ReAddNodeIfNotExists(const std::string &node_id, const std::string &ip, uint32_t port);
|
bool ReAddNodeIfNotExists(const std::string &node_id, const std::string &ip, uint32_t port,
|
||||||
|
uint32_t *rank_id = nullptr);
|
||||||
|
|
||||||
void UpdateHeartbeat(const std::string &node_id);
|
void UpdateHeartbeat(const std::string &node_id);
|
||||||
std::vector<ServersMeta> FetchServersMeta();
|
std::vector<ServersMeta> FetchServersMeta();
|
||||||
|
|
|
@ -172,14 +172,29 @@ void SchedulerNode::ProcessHeartbeat(const std::shared_ptr<TcpServer> &server,
|
||||||
MS_LOG(WARNING) << "Send heart beat failed.";
|
MS_LOG(WARNING) << "Send heart beat failed.";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (node_manager_.IsAllNodesRegistered()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
// Re-Add the missing node into node manager.
|
// Re-Add the missing node into node manager.
|
||||||
if (heartbeat_message.has_address()) {
|
if (heartbeat_message.has_address() &&
|
||||||
node_manager_.ReAddNodeIfNotExists(node_id, heartbeat_message.ip(), heartbeat_message.port());
|
node_manager_.ReAddNodeIfNotExists(node_id, heartbeat_message.ip(), heartbeat_message.port())) {
|
||||||
|
SetRegisterConnectionFd(conn, node_id);
|
||||||
|
|
||||||
if (node_manager_.IsAllNodesRegistered()) {
|
if (node_manager_.IsAllNodesRegistered()) {
|
||||||
is_ready_ = true;
|
is_ready_ = true;
|
||||||
MS_LOG(INFO) << "There are " << node_manager_.worker_num() << " workers and " << node_manager_.server_num()
|
MS_LOG(INFO) << "There are " << node_manager_.worker_num() << " workers and " << node_manager_.server_num()
|
||||||
<< " servers registered to scheduer, so the scheduler send meta data to worker/server.";
|
<< " servers registered to scheduer, so the scheduler send meta data to worker/server.";
|
||||||
node_manager_.UpdateNodesInfo();
|
node_manager_.UpdateNodesInfo();
|
||||||
|
|
||||||
|
auto node_infos = node_manager_.nodes_info();
|
||||||
|
for (const auto &kvs : node_infos) {
|
||||||
|
auto client = GetOrCreateClient(kvs.second);
|
||||||
|
MS_EXCEPTION_IF_NULL(client);
|
||||||
|
SendMetadata(client, kvs.second.rank_id_);
|
||||||
|
node_manager_.UpdateHeartbeat(kvs.first);
|
||||||
|
}
|
||||||
|
|
||||||
node_manager_.UpdateClusterState(ClusterState::CLUSTER_READY);
|
node_manager_.UpdateClusterState(ClusterState::CLUSTER_READY);
|
||||||
PersistMetaData();
|
PersistMetaData();
|
||||||
wait_start_cond_.notify_all();
|
wait_start_cond_.notify_all();
|
||||||
|
|
|
@ -2,7 +2,7 @@ file(GLOB_RECURSE DEVICE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "common/*
|
||||||
"kernel_info.cc" "executor/dynamic_kernel.cc" "executor/executor_callback.cc" "kernel_runtime.cc"
|
"kernel_info.cc" "executor/dynamic_kernel.cc" "executor/executor_callback.cc" "kernel_runtime.cc"
|
||||||
"memory_manager.cc" "kernel_runtime_manager.cc" "convert_tensor_utils.cc" "memory_scheduler.cc"
|
"memory_manager.cc" "kernel_runtime_manager.cc" "convert_tensor_utils.cc" "memory_scheduler.cc"
|
||||||
"memory_offload_strategy.cc" "bucket.cc" "launch_kernel.cc" "launch_mul.cc" "tensor_array.cc"
|
"memory_offload_strategy.cc" "bucket.cc" "launch_kernel.cc" "launch_mul.cc" "tensor_array.cc"
|
||||||
"ms_device_shape_transfer.cc" "context_extends.cc"
|
"ms_device_shape_transfer.cc" "context_extends.cc" "stream_synchronizer.cc"
|
||||||
)
|
)
|
||||||
|
|
||||||
if("${ENABLE_HIDDEN}" STREQUAL "OFF")
|
if("${ENABLE_HIDDEN}" STREQUAL "OFF")
|
||||||
|
|
|
@ -0,0 +1,102 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "runtime/device/stream_synchronizer.h"
|
||||||
|
#include "utils/ms_context.h"
|
||||||
|
#include "distributed/collective/collective_manager.h"
|
||||||
|
#include "runtime/recovery/recovery_context.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace device {
|
||||||
|
std::mutex StreamSynchronizer::instance_lock_;
|
||||||
|
std::shared_ptr<StreamSynchronizer> StreamSynchronizer::instance_ = nullptr;
|
||||||
|
|
||||||
|
// Stop the worker thread and wait for it to exit before the object is destroyed.
StreamSynchronizer::~StreamSynchronizer() {
  {
    // Raise the stop flag under task_mutex_ so the worker's wait predicate observes it.
    std::lock_guard<std::mutex> guard(task_mutex_);
    stop_ = true;
  }
  // Wake the worker in case it is blocked waiting for a synchronization task.
  do_sync_stream_cv_.notify_all();
  worker_thread_.join();
}
|
||||||
|
|
||||||
|
bool StreamSynchronizer::SyncStream(const std::string &device_name, uint32_t timeout) {
|
||||||
|
std::unique_lock<std::mutex> reentrant_lock(reentrant_mutex_);
|
||||||
|
auto ms_context = MsContext::GetInstance();
|
||||||
|
MS_EXCEPTION_IF_NULL(ms_context);
|
||||||
|
uint32_t device_id = ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID);
|
||||||
|
|
||||||
|
const auto &device_context =
|
||||||
|
device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext({device_name, device_id});
|
||||||
|
MS_EXCEPTION_IF_NULL(device_context);
|
||||||
|
|
||||||
|
// If disable recovery or timeout==0, sync stream directly to improve performance.
|
||||||
|
if (!runtime::recovery::RecoveryContext::GetInstance()->enable_recovery() || timeout == 0) {
|
||||||
|
device_context->Initialize();
|
||||||
|
return device_context->SyncStream();
|
||||||
|
}
|
||||||
|
|
||||||
|
std::unique_lock<std::mutex> lock(task_mutex_);
|
||||||
|
if (stop_) {
|
||||||
|
MS_LOG(EXCEPTION) << "The synchronization stream task has stopped";
|
||||||
|
}
|
||||||
|
device_context_ = device_context;
|
||||||
|
do_sync_stream_cv_.notify_one();
|
||||||
|
|
||||||
|
if (sync_stream_time_out_) {
|
||||||
|
// If sync stream timeout has happened, increase the timeout by 4 times.
|
||||||
|
const uint32_t kTimeOutScaleFactor = 4;
|
||||||
|
timeout *= kTimeOutScaleFactor;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (time_out_cv_.wait_for(lock, std::chrono::seconds(timeout)) == std::cv_status::no_timeout) {
|
||||||
|
if (!sync_stream_ret_) {
|
||||||
|
MS_LOG(ERROR) << "Synchronize stream failed.";
|
||||||
|
}
|
||||||
|
return sync_stream_ret_;
|
||||||
|
} else {
|
||||||
|
sync_stream_time_out_ = true;
|
||||||
|
runtime::recovery::RecoveryContext::GetInstance()->set_need_reinit_collective(true);
|
||||||
|
distributed::collective::CollectiveManager::instance()->Finalize();
|
||||||
|
time_out_cv_.wait(lock, [this]() { return device_context_ == nullptr; });
|
||||||
|
MS_LOG(WARNING) << "Synchronize stream time out.";
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void StreamSynchronizer::DoSyncStreamTask() {
|
||||||
|
for (;;) {
|
||||||
|
{
|
||||||
|
std::unique_lock<std::mutex> lock(task_mutex_);
|
||||||
|
do_sync_stream_cv_.wait(lock, [this]() { return stop_ || device_context_ != nullptr; });
|
||||||
|
if (stop_) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
device_context_->Initialize();
|
||||||
|
// Really sync stream.
|
||||||
|
sync_stream_ret_ = device_context_->SyncStream();
|
||||||
|
|
||||||
|
{
|
||||||
|
std::unique_lock<std::mutex> lock(task_mutex_);
|
||||||
|
device_context_ = nullptr;
|
||||||
|
}
|
||||||
|
time_out_cv_.notify_one();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} // namespace device
|
||||||
|
} // namespace mindspore
|
|
@ -0,0 +1,93 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_SYNC_STREAM_H_
|
||||||
|
#define MINDSPORE_CCSRC_RUNTIME_DEVICE_SYNC_STREAM_H_
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
#include <thread>
|
||||||
|
#include <mutex>
|
||||||
|
#include <condition_variable>
|
||||||
|
|
||||||
|
#include "runtime/hardware/device_context.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace device {
|
||||||
|
constexpr uint32_t kTimeoutInSeconds = 30;
|
||||||
|
|
||||||
|
// Execute synchronization stream with timeout mechanism. Typical application scenarios: it is used to monitor
|
||||||
|
// distributed data parallel training scenarios and whether a process exits unexpectedly.
|
||||||
|
class StreamSynchronizer {
|
||||||
|
public:
|
||||||
|
static std::shared_ptr<StreamSynchronizer> &GetInstance() {
|
||||||
|
std::lock_guard<std::mutex> lock(instance_lock_);
|
||||||
|
if (instance_ == nullptr) {
|
||||||
|
instance_.reset(new (std::nothrow) StreamSynchronizer());
|
||||||
|
}
|
||||||
|
return instance_;
|
||||||
|
}
|
||||||
|
|
||||||
|
~StreamSynchronizer();
|
||||||
|
|
||||||
|
// Execute synchronization stream with timeout mechanism.
|
||||||
|
bool SyncStream(const std::string &device_name, uint32_t timeout = kTimeoutInSeconds);
|
||||||
|
|
||||||
|
private:
|
||||||
|
// Create a thread to actually execute the synchronization stream task.
|
||||||
|
StreamSynchronizer() { worker_thread_ = std::thread(&StreamSynchronizer::DoSyncStreamTask, this); }
|
||||||
|
|
||||||
|
DISABLE_COPY_AND_ASSIGN(StreamSynchronizer);
|
||||||
|
|
||||||
|
// Monitor whether there are synchronization stream tasks, and actually execute the synchronization stream
|
||||||
|
// tasks.
|
||||||
|
void DoSyncStreamTask();
|
||||||
|
|
||||||
|
// Used for multi-thread safety of singleton creation.
|
||||||
|
static std::mutex instance_lock_;
|
||||||
|
|
||||||
|
// The singleton pointer.
|
||||||
|
static std::shared_ptr<StreamSynchronizer> instance_;
|
||||||
|
|
||||||
|
// Record whether the synchronization stream task has timed out.
|
||||||
|
bool sync_stream_time_out_{false};
|
||||||
|
|
||||||
|
// Return value of synchronization stream.
|
||||||
|
bool sync_stream_ret_{false};
|
||||||
|
|
||||||
|
// Whether synchronization stream thread need to stop.
|
||||||
|
bool stop_{false};
|
||||||
|
|
||||||
|
DeviceContext *device_context_{nullptr};
|
||||||
|
|
||||||
|
// The method SyncStream is not multiple threads safe, so use this lock to prevent simultaneous access by
|
||||||
|
// multiple threads.
|
||||||
|
std::mutex reentrant_mutex_;
|
||||||
|
|
||||||
|
// Use this lock to ensure the safety of external calls to SyncStream and the execution of DoSyncStreamTask
|
||||||
|
// in worker_thread_;
|
||||||
|
std::mutex task_mutex_;
|
||||||
|
|
||||||
|
// The thread to actually execute the synchronization stream task.
|
||||||
|
std::thread worker_thread_;
|
||||||
|
|
||||||
|
std::condition_variable time_out_cv_;
|
||||||
|
std::condition_variable do_sync_stream_cv_;
|
||||||
|
};
|
||||||
|
} // namespace device
|
||||||
|
} // namespace mindspore
|
||||||
|
#endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_SYNC_STREAM_H_
|
|
@ -25,9 +25,11 @@
|
||||||
#include "mindrt/include/async/async.h"
|
#include "mindrt/include/async/async.h"
|
||||||
#include "utils/log_adapter.h"
|
#include "utils/log_adapter.h"
|
||||||
#include "include/common/utils/convert_utils.h"
|
#include "include/common/utils/convert_utils.h"
|
||||||
|
#include "runtime/recovery/recovery_context.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace runtime {
|
namespace runtime {
|
||||||
|
using recovery::RecoveryContext;
|
||||||
namespace {
|
namespace {
|
||||||
void SyncTensorData(const TensorPtr &host_tensor, const DeviceTensorPtr &device_tensor, const AnfNodePtr &node,
|
void SyncTensorData(const TensorPtr &host_tensor, const DeviceTensorPtr &device_tensor, const AnfNodePtr &node,
|
||||||
const DeviceContext *device_context, OpContext<DeviceTensor> *const context,
|
const DeviceContext *device_context, OpContext<DeviceTensor> *const context,
|
||||||
|
@ -354,6 +356,10 @@ void DataPrepareActor::PrepareDataForDeviceTensorStore(const std::vector<std::ve
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if (RecoveryContext::GetInstance()->enable_recovery() &&
|
||||||
|
RecoveryContext::GetInstance()->need_sync_weight_to_device()) {
|
||||||
|
RecoveryContext::GetInstance()->set_need_sync_weight_to_device(false);
|
||||||
|
}
|
||||||
|
|
||||||
PrepareDeviceTensorStoreForControlNode(graph_compiler_info_->control_node_parser_, input_tensors.back(), context);
|
PrepareDeviceTensorStoreForControlNode(graph_compiler_info_->control_node_parser_, input_tensors.back(), context);
|
||||||
}
|
}
|
||||||
|
@ -722,6 +728,11 @@ void DataPrepareActor::PrepareDataForWeightNode(const AnfNodePtr &backend_node,
|
||||||
host_tensor_address->SetNodeIndex(backend_node, 0);
|
host_tensor_address->SetNodeIndex(backend_node, 0);
|
||||||
DeviceTensorStore::GetInstance().Insert(front_node.get(), host_tensor_address);
|
DeviceTensorStore::GetInstance().Insert(front_node.get(), host_tensor_address);
|
||||||
|
|
||||||
|
if (RecoveryContext::GetInstance()->enable_recovery() &&
|
||||||
|
RecoveryContext::GetInstance()->need_sync_weight_to_device()) {
|
||||||
|
is_need_sync = true;
|
||||||
|
}
|
||||||
|
|
||||||
// If the ptr of device tensor is not nullptr, it indicates that the device data has been prepared.
|
// If the ptr of device tensor is not nullptr, it indicates that the device data has been prepared.
|
||||||
MS_EXCEPTION_IF_NULL(host_tensor_address);
|
MS_EXCEPTION_IF_NULL(host_tensor_address);
|
||||||
if (is_need_sync || (host_tensor_address->GetPtr() == nullptr)) {
|
if (is_need_sync || (host_tensor_address->GetPtr() == nullptr)) {
|
||||||
|
|
|
@ -21,9 +21,12 @@
|
||||||
#include "runtime/graph_scheduler/actor/debug_actor.h"
|
#include "runtime/graph_scheduler/actor/debug_actor.h"
|
||||||
#include "mindrt/include/async/async.h"
|
#include "mindrt/include/async/async.h"
|
||||||
#include "utils/log_adapter.h"
|
#include "utils/log_adapter.h"
|
||||||
|
#include "runtime/recovery/recovery_context.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace runtime {
|
namespace runtime {
|
||||||
|
using recovery::RecoveryContext;
|
||||||
|
|
||||||
void KernelActor::Init() {
|
void KernelActor::Init() {
|
||||||
// Check device contexts number.
|
// Check device contexts number.
|
||||||
if (device_contexts_.size() != device::kDeviceContextsNumOne) {
|
if (device_contexts_.size() != device::kDeviceContextsNumOne) {
|
||||||
|
@ -240,11 +243,18 @@ void KernelActor::OnMemoryAllocFinish(OpContext<DeviceTensor> *const context) {
|
||||||
PreLaunchKernel(context);
|
PreLaunchKernel(context);
|
||||||
|
|
||||||
try {
|
try {
|
||||||
auto ret = device_contexts_[0]->LaunchKernel(kernel_, launch_info_.inputs_, launch_info_.workspaces_,
|
if (RecoveryContext::GetInstance()->enable_recovery() && RecoveryContext::GetInstance()->need_reinit_collective()) {
|
||||||
launch_info_.outputs_, is_dynamic_shape_);
|
// In disaster recovery scenarios, if this step's DAG run failed, the remaining operators need not be launched,
|
||||||
if (!ret) {
|
// especially the collective communication operators.
|
||||||
std::string error_info = "Launch kernel failed: " + kernel_->fullname_with_scope();
|
MS_LOG(WARNING) << "Collective communication need reinitialize, skip launch kernel: "
|
||||||
SET_OPCONTEXT_FAIL_RET_WITH_ERROR_BY_STRATEGY(strategy_, (*context), error_info);
|
<< kernel_->fullname_with_scope();
|
||||||
|
} else {
|
||||||
|
auto ret = device_contexts_[0]->LaunchKernel(kernel_, launch_info_.inputs_, launch_info_.workspaces_,
|
||||||
|
launch_info_.outputs_, is_dynamic_shape_);
|
||||||
|
if (!ret) {
|
||||||
|
std::string error_info = "Launch kernel failed: " + kernel_->fullname_with_scope();
|
||||||
|
SET_OPCONTEXT_FAIL_RET_WITH_ERROR_BY_STRATEGY(strategy_, (*context), error_info);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
} catch (const std::exception &e) {
|
} catch (const std::exception &e) {
|
||||||
if (strategy_ == GraphExecutionStrategy::kPipeline) {
|
if (strategy_ == GraphExecutionStrategy::kPipeline) {
|
||||||
|
|
|
@ -15,6 +15,7 @@
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#include "runtime/graph_scheduler/actor/loop_count_actor.h"
|
#include "runtime/graph_scheduler/actor/loop_count_actor.h"
|
||||||
|
#include <set>
|
||||||
#include "runtime/graph_scheduler/actor/data_prepare_actor.h"
|
#include "runtime/graph_scheduler/actor/data_prepare_actor.h"
|
||||||
#include "runtime/graph_scheduler/actor/output_actor.h"
|
#include "runtime/graph_scheduler/actor/output_actor.h"
|
||||||
#include "runtime/graph_scheduler/actor/memory_manager_actor.h"
|
#include "runtime/graph_scheduler/actor/memory_manager_actor.h"
|
||||||
|
@ -23,9 +24,13 @@
|
||||||
#include "runtime/graph_scheduler/actor/control_flow/entrance_actor.h"
|
#include "runtime/graph_scheduler/actor/control_flow/entrance_actor.h"
|
||||||
#include "mindrt/include/async/async.h"
|
#include "mindrt/include/async/async.h"
|
||||||
#include "utils/log_adapter.h"
|
#include "utils/log_adapter.h"
|
||||||
|
#include "runtime/device/stream_synchronizer.h"
|
||||||
|
#include "runtime/recovery/recovery_context.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace runtime {
|
namespace runtime {
|
||||||
|
using recovery::RecoveryContext;
|
||||||
|
|
||||||
void LoopCountActor::Run(OpContext<DeviceTensor> *const context) {
|
void LoopCountActor::Run(OpContext<DeviceTensor> *const context) {
|
||||||
MS_EXCEPTION_IF_NULL(context);
|
MS_EXCEPTION_IF_NULL(context);
|
||||||
// Need wait MemoryManagerActor running finished to avoid the illegal memory timing problem before
|
// Need wait MemoryManagerActor running finished to avoid the illegal memory timing problem before
|
||||||
|
@ -52,6 +57,26 @@ void LoopCountActor::IncreaseLoopCount(OpContext<DeviceTensor> *const context) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Sync device stream.
|
||||||
|
if (strategy_ == GraphExecutionStrategy::kPipeline) {
|
||||||
|
std::set<const DeviceContext *> sync_stream_device_contexts;
|
||||||
|
for (auto &device_context : device_contexts_) {
|
||||||
|
MS_EXCEPTION_IF_NULL(device_context);
|
||||||
|
if ((sync_stream_device_contexts.count(device_context) == 0) &&
|
||||||
|
(!device::StreamSynchronizer::GetInstance()->SyncStream(device_context->device_context_key().device_name_))) {
|
||||||
|
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context),
|
||||||
|
("Sync stream failed:" + device_context->device_context_key().ToString()));
|
||||||
|
}
|
||||||
|
(void)sync_stream_device_contexts.insert(device_context);
|
||||||
|
|
||||||
|
// Trigger disaster recovery and exit loop early.
|
||||||
|
if (RecoveryContext::GetInstance()->enable_recovery() &&
|
||||||
|
RecoveryContext::GetInstance()->need_reinit_collective()) {
|
||||||
|
current_count_ = loop_count_;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
PostRun(context);
|
PostRun(context);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -17,6 +17,7 @@
|
||||||
#ifndef MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_LOOP_COUNT_ACTOR_H_
|
#ifndef MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_LOOP_COUNT_ACTOR_H_
|
||||||
#define MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_LOOP_COUNT_ACTOR_H_
|
#define MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_LOOP_COUNT_ACTOR_H_
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
@ -35,11 +36,17 @@ namespace runtime {
|
||||||
class LoopCountActor : public DebugAwareActor {
|
class LoopCountActor : public DebugAwareActor {
|
||||||
public:
|
public:
|
||||||
LoopCountActor(const std::string &name, size_t loop_count, const AID &memory_manager_aid, const AID *debug_aid,
|
LoopCountActor(const std::string &name, size_t loop_count, const AID &memory_manager_aid, const AID *debug_aid,
|
||||||
const AID *recorder_aid)
|
const AID *recorder_aid, GraphExecutionStrategy strategy,
|
||||||
|
const std::vector<DeviceContext *> &device_contexts)
|
||||||
: DebugAwareActor(name, KernelTransformType::kLoopCountActor, recorder_aid, memory_manager_aid, debug_aid),
|
: DebugAwareActor(name, KernelTransformType::kLoopCountActor, recorder_aid, memory_manager_aid, debug_aid),
|
||||||
loop_count_(loop_count),
|
loop_count_(loop_count),
|
||||||
current_count_(0),
|
current_count_(0),
|
||||||
total_running_count_(0) {}
|
total_running_count_(0),
|
||||||
|
strategy_(strategy) {
|
||||||
|
(void)std::transform(
|
||||||
|
device_contexts.begin(), device_contexts.end(), std::back_inserter(device_contexts_),
|
||||||
|
[](DeviceContext *device_context) { return static_cast<const DeviceContext *>(device_context); });
|
||||||
|
}
|
||||||
|
|
||||||
~LoopCountActor() override = default;
|
~LoopCountActor() override = default;
|
||||||
|
|
||||||
|
@ -73,6 +80,10 @@ class LoopCountActor : public DebugAwareActor {
|
||||||
// The actors which need be handled separately by loop count actor.
|
// The actors which need be handled separately by loop count actor.
|
||||||
AID data_prepare_aid_;
|
AID data_prepare_aid_;
|
||||||
std::vector<AID> entrance_aids_;
|
std::vector<AID> entrance_aids_;
|
||||||
|
|
||||||
|
// The execution strategy for executing actor.
|
||||||
|
// In pipeline mode, sync stream for every step.
|
||||||
|
GraphExecutionStrategy strategy_{GraphExecutionStrategy::kPipeline};
|
||||||
};
|
};
|
||||||
|
|
||||||
using LoopCountActorPtr = std::shared_ptr<LoopCountActor>;
|
using LoopCountActorPtr = std::shared_ptr<LoopCountActor>;
|
||||||
|
|
|
@ -17,9 +17,12 @@
|
||||||
#include "runtime/graph_scheduler/actor/output_actor.h"
|
#include "runtime/graph_scheduler/actor/output_actor.h"
|
||||||
#include "runtime/hardware/device_context_manager.h"
|
#include "runtime/hardware/device_context_manager.h"
|
||||||
#include "utils/log_adapter.h"
|
#include "utils/log_adapter.h"
|
||||||
|
#include "runtime/recovery/recovery_context.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace runtime {
|
namespace runtime {
|
||||||
|
using recovery::RecoveryContext;
|
||||||
|
|
||||||
bool IsOutputAddressPersisted(const DeviceTensor *output_device_tensor, const AnfNodePtr &output_node) {
|
bool IsOutputAddressPersisted(const DeviceTensor *output_device_tensor, const AnfNodePtr &output_node) {
|
||||||
MS_EXCEPTION_IF_NULL(output_node);
|
MS_EXCEPTION_IF_NULL(output_node);
|
||||||
MS_EXCEPTION_IF_NULL(output_device_tensor);
|
MS_EXCEPTION_IF_NULL(output_device_tensor);
|
||||||
|
@ -70,10 +73,44 @@ void OutputActor::Init() {
|
||||||
running_dependent_msg_num_ = SizeToInt(outputs_num_ - device_tensor_store_keys_.size());
|
running_dependent_msg_num_ = SizeToInt(outputs_num_ - device_tensor_store_keys_.size());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void OutputActor::FreeOutputNodeMem() {
|
||||||
|
for (size_t i = 0; i < output_nodes_.size(); ++i) {
|
||||||
|
auto &output_node = output_nodes_[i].first;
|
||||||
|
auto &output_device_tensor = output_device_tensors_[i];
|
||||||
|
if ((output_node == nullptr) || (output_device_tensor == nullptr)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (!IsOutputAddressPersisted(output_device_tensor, output_node)) {
|
||||||
|
FreeMemory(output_device_tensor, device_contexts_[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Drop every cached output node, tensor and device tensor and reset the counters, leaving the caches
// sized for outputs_num_ default-initialized entries.
void OutputActor::ClearOutputCache() {
  // Reset the progress counters.
  current_count_ = 0;
  current_outputs_num_ = 0;

  output_node_to_tensor_device_address_.clear();

  // Empty each cache, then restore it to outputs_num_ default-initialized slots.
  output_device_tensors_.clear();
  output_device_tensors_.resize(outputs_num_);

  output_nodes_.clear();
  output_nodes_.resize(outputs_num_);

  outputs_.clear();
  outputs_.resize(outputs_num_);
}
|
||||||
|
|
||||||
void OutputActor::RunOpControl(AID *const, OpContext<DeviceTensor> *const context) {
|
void OutputActor::RunOpControl(AID *const, OpContext<DeviceTensor> *const context) {
|
||||||
MS_EXCEPTION_IF_NULL(context);
|
MS_EXCEPTION_IF_NULL(context);
|
||||||
|
|
||||||
++current_count_;
|
++current_count_;
|
||||||
|
|
||||||
|
// Trigger disaster recovery and return empty output.
|
||||||
|
if (RecoveryContext::GetInstance()->enable_recovery() && RecoveryContext::GetInstance()->need_reinit_collective()) {
|
||||||
|
FreeOutputNodeMem();
|
||||||
|
ClearOutputCache();
|
||||||
|
SET_OPCONTEXT_SUCCESS_RET((*context));
|
||||||
|
}
|
||||||
|
|
||||||
// The last loop.
|
// The last loop.
|
||||||
if (loop_count_ == current_count_) {
|
if (loop_count_ == current_count_) {
|
||||||
if (current_outputs_num_ + device_tensor_store_keys_.size() != outputs_num_) {
|
if (current_outputs_num_ + device_tensor_store_keys_.size() != outputs_num_) {
|
||||||
|
@ -106,19 +143,7 @@ void OutputActor::RunOpControl(AID *const, OpContext<DeviceTensor> *const contex
|
||||||
SET_OPCONTEXT_SUCCESS_RET((*context));
|
SET_OPCONTEXT_SUCCESS_RET((*context));
|
||||||
}
|
}
|
||||||
|
|
||||||
// The output device memory will be taken over by tensor in the last loop, otherwise needs to free the memory.
|
FreeOutputNodeMem();
|
||||||
// 1.Avoid the memory leak when memory used by dynamic ref count in the control flow scene.
|
|
||||||
// 2.Alloc the new memory in the next step using the new shape size in the dynamic shape scene.
|
|
||||||
for (size_t i = 0; i < output_nodes_.size(); ++i) {
|
|
||||||
auto &output_node = output_nodes_[i].first;
|
|
||||||
auto &output_device_tensor = output_device_tensors_[i];
|
|
||||||
if ((output_node == nullptr) || (output_device_tensor == nullptr)) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
if (!IsOutputAddressPersisted(output_device_tensor, output_node)) {
|
|
||||||
FreeMemory(output_device_tensor, device_contexts_[i]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Send control arrow to trigger next step running.
|
// Send control arrow to trigger next step running.
|
||||||
auto from_aid = const_cast<AID *>(&GetAID());
|
auto from_aid = const_cast<AID *>(&GetAID());
|
||||||
|
|
|
@ -79,6 +79,14 @@ class OutputActor : public AbstractActor {
|
||||||
|
|
||||||
TensorPtr CreateOutputTensor(const AnfNodePtr &output_node, size_t output_index, size_t output_position);
|
TensorPtr CreateOutputTensor(const AnfNodePtr &output_node, size_t output_index, size_t output_position);
|
||||||
|
|
||||||
|
// The output device memory will be taken over by tensor in the last loop, otherwise needs to free the memory.
|
||||||
|
// 1.Avoid the memory leak when memory used by dynamic ref count in the control flow scene.
|
||||||
|
// 2.Alloc the new memory in the next step using the new shape size in the dynamic shape scene.
|
||||||
|
void FreeOutputNodeMem();
|
||||||
|
|
||||||
|
// Clear output nodes and tensors in cache.
|
||||||
|
void ClearOutputCache();
|
||||||
|
|
||||||
// The loop count is constant, the current count is increased after each step running finished.
|
// The loop count is constant, the current count is increased after each step running finished.
|
||||||
// Collect the output result in the last loop which is represented by "loop_count_ - current_count_ == 1".
|
// Collect the output result in the last loop which is represented by "loop_count_ - current_count_ == 1".
|
||||||
size_t loop_count_;
|
size_t loop_count_;
|
||||||
|
|
|
@ -46,9 +46,11 @@
|
||||||
#endif
|
#endif
|
||||||
#include "profiler/device/profiling.h"
|
#include "profiler/device/profiling.h"
|
||||||
#include "debug/common.h"
|
#include "debug/common.h"
|
||||||
|
#include "runtime/recovery/recovery_context.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace runtime {
|
namespace runtime {
|
||||||
|
using recovery::RecoveryContext;
|
||||||
namespace {
|
namespace {
|
||||||
bool IsNeedInsertCopyActor(const DeviceContext *from_device_context, const DeviceContext *to_device_context) {
|
bool IsNeedInsertCopyActor(const DeviceContext *from_device_context, const DeviceContext *to_device_context) {
|
||||||
MS_EXCEPTION_IF_NULL(from_device_context);
|
MS_EXCEPTION_IF_NULL(from_device_context);
|
||||||
|
@ -464,21 +466,19 @@ void GraphScheduler::Run(ActorSet *const actor_set, const std::vector<DeviceCont
|
||||||
MS_LOG(EXCEPTION) << op_context.error_info_;
|
MS_LOG(EXCEPTION) << op_context.error_info_;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Sync device stream.
|
|
||||||
if (strategy == GraphExecutionStrategy::kPipeline) {
|
|
||||||
std::set<DeviceContext *> sync_stream_device_contexts;
|
|
||||||
for (auto &device_context : device_contexts) {
|
|
||||||
MS_EXCEPTION_IF_NULL(device_context);
|
|
||||||
if ((sync_stream_device_contexts.count(device_context) == 0) && (!device_context->SyncStream())) {
|
|
||||||
MS_LOG(EXCEPTION) << "Sync stream failed:" << device_context->device_context_key().ToString();
|
|
||||||
}
|
|
||||||
(void)sync_stream_device_contexts.insert(device_context);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
double end_time = GetTime();
|
double end_time = GetTime();
|
||||||
const size_t kSecondsToMilliseconds = 1000;
|
const size_t kSecondsToMilliseconds = 1000;
|
||||||
SetActorExecutionStrategy(actor_set, strategy, (end_time - start_time) * kSecondsToMilliseconds);
|
SetActorExecutionStrategy(actor_set, strategy, (end_time - start_time) * kSecondsToMilliseconds);
|
||||||
|
|
||||||
|
if (RecoveryContext::GetInstance()->enable_recovery() && RecoveryContext::GetInstance()->need_reinit_collective()) {
|
||||||
|
MS_LOG(INFO) << "Begin reinitialize collective communication for recovery.";
|
||||||
|
if (!RecoveryContext::GetInstance()->ReInitializeCollective()) {
|
||||||
|
MS_LOG(EXCEPTION) << "Reinitialize collective communication failed.";
|
||||||
|
}
|
||||||
|
MS_LOG(INFO) << "Finish reinitialize collective communication for recovery.";
|
||||||
|
|
||||||
|
RecoveryContext::GetInstance()->set_need_reinit_collective(false);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void GraphScheduler::SetActorExecutionStrategy(ActorSet *const actor_set, GraphExecutionStrategy strategy,
|
void GraphScheduler::SetActorExecutionStrategy(ActorSet *const actor_set, GraphExecutionStrategy strategy,
|
||||||
|
@ -843,7 +843,8 @@ LoopCountActorPtr GraphScheduler::BuildLoopCountActor(const GraphCompilerInfo &g
|
||||||
|
|
||||||
auto actor_name = graph_compiler_info.name_ + kLoopCountActorNameSuffix;
|
auto actor_name = graph_compiler_info.name_ + kLoopCountActorNameSuffix;
|
||||||
auto loop_count_actor =
|
auto loop_count_actor =
|
||||||
std::make_shared<LoopCountActor>(actor_name, loop_count, memory_manager_aid_, debug_aid_, recorder_aid_);
|
std::make_shared<LoopCountActor>(actor_name, loop_count, memory_manager_aid_, debug_aid_, recorder_aid_,
|
||||||
|
graph_compiler_info.strategy_, graph_compiler_info.device_contexts_);
|
||||||
MS_LOG(INFO) << "Create loop count actor: " << actor_name;
|
MS_LOG(INFO) << "Create loop count actor: " << actor_name;
|
||||||
MS_EXCEPTION_IF_NULL(loop_count_actor);
|
MS_EXCEPTION_IF_NULL(loop_count_actor);
|
||||||
|
|
||||||
|
|
|
@ -29,6 +29,7 @@
|
||||||
#include "distributed/init.h"
|
#include "distributed/init.h"
|
||||||
#include "runtime/hardware/device_context.h"
|
#include "runtime/hardware/device_context.h"
|
||||||
#include "utils/convert_utils_base.h"
|
#include "utils/convert_utils_base.h"
|
||||||
|
#include "utils/ms_context.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace runtime {
|
namespace runtime {
|
||||||
|
@ -82,6 +83,10 @@ void RecoveryContext::Initialize() {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
auto context_ptr = MsContext::GetInstance();
|
||||||
|
MS_EXCEPTION_IF_NULL(context_ptr);
|
||||||
|
context_ptr->set_param<bool>(MS_CTX_ENABLE_RECOVERY, true);
|
||||||
|
|
||||||
recovery_path_ = common::GetEnv(kEnvRecoveryPath);
|
recovery_path_ = common::GetEnv(kEnvRecoveryPath);
|
||||||
if (recovery_path_.empty()) {
|
if (recovery_path_.empty()) {
|
||||||
MS_LOG(EXCEPTION) << "The recovery path is empty, please export MS_RECOVERY_PATH correctly.";
|
MS_LOG(EXCEPTION) << "The recovery path is empty, please export MS_RECOVERY_PATH correctly.";
|
||||||
|
@ -140,20 +145,20 @@ void RecoveryContext::Initialize() {
|
||||||
bool RecoveryContext::ReInitializeCollective() {
|
bool RecoveryContext::ReInitializeCollective() {
|
||||||
auto ret = distributed::Initialize();
|
auto ret = distributed::Initialize();
|
||||||
if (ret) {
|
if (ret) {
|
||||||
recovery_status_ = RecoveryStatus::kUnKnownError;
|
recovery_status_ = RecoveryErrCode::kUnKnownError;
|
||||||
set_need_reset(true);
|
set_need_reset(true);
|
||||||
set_need_sync_weight_to_device(true);
|
set_need_sync_weight_to_device(true);
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (recovery_status_ == RecoveryStatus::kBroadcastUniqueIDFailed ||
|
if (recovery_status_ == RecoveryErrCode::kBroadcastUniqueIDFailed ||
|
||||||
recovery_status_ == RecoveryStatus::kAllGatherHostNameFailed) {
|
recovery_status_ == RecoveryErrCode::kAllGatherHostNameFailed) {
|
||||||
MS_LOG(WARNING) << "Prepare to initialize NCCL failed, retrying.";
|
MS_LOG(WARNING) << "Prepare to initialize NCCL failed, retrying.";
|
||||||
// Retry duration: 30s.
|
// Retry duration: 30s.
|
||||||
const int kRetryDuration = 30;
|
const int kRetryDuration = 30;
|
||||||
std::this_thread::sleep_for(std::chrono::seconds(kRetryDuration));
|
std::this_thread::sleep_for(std::chrono::seconds(kRetryDuration));
|
||||||
return ReInitializeCollective();
|
return ReInitializeCollective();
|
||||||
} else if (recovery_status_ == RecoveryStatus::kInitNcclFailed) {
|
} else if (recovery_status_ == RecoveryErrCode::kInitNcclFailed) {
|
||||||
MS_LOG(EXCEPTION) << "Initialize NCCL failed.";
|
MS_LOG(EXCEPTION) << "Initialize NCCL failed.";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -31,7 +31,7 @@ namespace recovery {
|
||||||
using distributed::storage::FileIOUtils;
|
using distributed::storage::FileIOUtils;
|
||||||
using distributed::storage::JsonUtils;
|
using distributed::storage::JsonUtils;
|
||||||
|
|
||||||
enum class RecoveryStatus { kUnKnownError, kAllGatherHostNameFailed, kBroadcastUniqueIDFailed, kInitNcclFailed };
|
enum class RecoveryErrCode { kUnKnownError, kAllGatherHostNameFailed, kBroadcastUniqueIDFailed, kInitNcclFailed };
|
||||||
|
|
||||||
// Used to save disaster recovery-related state quantities and provide disaster recovery-related
|
// Used to save disaster recovery-related state quantities and provide disaster recovery-related
|
||||||
// functions, such as reinitializing collective communication, etc.
|
// functions, such as reinitializing collective communication, etc.
|
||||||
|
@ -65,9 +65,9 @@ class RecoveryContext {
|
||||||
int recovery_interval() const { return recovery_interval_; }
|
int recovery_interval() const { return recovery_interval_; }
|
||||||
|
|
||||||
// Get the error status of recovery.
|
// Get the error status of recovery.
|
||||||
RecoveryStatus recovery_status() const { return recovery_status_; }
|
RecoveryErrCode recovery_status() const { return recovery_status_; }
|
||||||
// Set the error status of recovery.
|
// Set the error status of recovery.
|
||||||
void set_recovery_status(RecoveryStatus recovery_status) { recovery_status_ = recovery_status; }
|
void set_recovery_status(RecoveryErrCode recovery_status) { recovery_status_ = recovery_status; }
|
||||||
|
|
||||||
// Set the path used to save checkpoint.
|
// Set the path used to save checkpoint.
|
||||||
void SetCkptPath(const std::string &path);
|
void SetCkptPath(const std::string &path);
|
||||||
|
@ -161,7 +161,7 @@ class RecoveryContext {
|
||||||
bool initialized_{false};
|
bool initialized_{false};
|
||||||
|
|
||||||
// The error status of recovery.
|
// The error status of recovery.
|
||||||
RecoveryStatus recovery_status_{RecoveryStatus::kUnKnownError};
|
RecoveryErrCode recovery_status_{RecoveryErrCode::kUnKnownError};
|
||||||
|
|
||||||
// The persistent json file util, used to persist recovery config.
|
// The persistent json file util, used to persist recovery config.
|
||||||
std::unique_ptr<JsonUtils> persistent_json_;
|
std::unique_ptr<JsonUtils> persistent_json_;
|
||||||
|
|
|
@ -100,6 +100,7 @@ MsContext::MsContext(const std::string &policy, const std::string &target) {
|
||||||
set_param<bool>(MS_CTX_ENABLE_PYNATIVE_SYNCHRONIZE, false);
|
set_param<bool>(MS_CTX_ENABLE_PYNATIVE_SYNCHRONIZE, false);
|
||||||
set_param<bool>(MS_CTX_ENABLE_PYNATIVE_OP_GRAPH_CACHE, true);
|
set_param<bool>(MS_CTX_ENABLE_PYNATIVE_OP_GRAPH_CACHE, true);
|
||||||
set_param<bool>(MS_CTX_ENABLE_MEM_SCHEDULER, false);
|
set_param<bool>(MS_CTX_ENABLE_MEM_SCHEDULER, false);
|
||||||
|
set_param<bool>(MS_CTX_ENABLE_RECOVERY, false);
|
||||||
|
|
||||||
uint32_t kDefaultRuntimeNumThreads = 30;
|
uint32_t kDefaultRuntimeNumThreads = 30;
|
||||||
set_param<uint32_t>(MS_CTX_RUNTIME_NUM_THREADS, kDefaultRuntimeNumThreads);
|
set_param<uint32_t>(MS_CTX_RUNTIME_NUM_THREADS, kDefaultRuntimeNumThreads);
|
||||||
|
|
|
@ -92,6 +92,7 @@ enum MsCtxParam : unsigned {
|
||||||
MS_CTX_ENABLE_PYNATIVE_SYNCHRONIZE,
|
MS_CTX_ENABLE_PYNATIVE_SYNCHRONIZE,
|
||||||
MS_CTX_ENABLE_PYNATIVE_OP_GRAPH_CACHE,
|
MS_CTX_ENABLE_PYNATIVE_OP_GRAPH_CACHE,
|
||||||
MS_CTX_ENABLE_MEM_SCHEDULER,
|
MS_CTX_ENABLE_MEM_SCHEDULER,
|
||||||
|
MS_CTX_ENABLE_RECOVERY,
|
||||||
MS_CTX_TYPE_BOOL_END,
|
MS_CTX_TYPE_BOOL_END,
|
||||||
|
|
||||||
// parameter of type int
|
// parameter of type int
|
||||||
|
|
|
@ -181,7 +181,6 @@ file(GLOB_RECURSE MINDSPORE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
|
||||||
"../../../mindspore/ccsrc/fl/*.cc"
|
"../../../mindspore/ccsrc/fl/*.cc"
|
||||||
"../../../mindspore/ccsrc/distributed/cluster/actor_route_table_service.cc"
|
"../../../mindspore/ccsrc/distributed/cluster/actor_route_table_service.cc"
|
||||||
"../../../mindspore/ccsrc/distributed/cluster/actor_route_table_proxy.cc"
|
"../../../mindspore/ccsrc/distributed/cluster/actor_route_table_proxy.cc"
|
||||||
"../../../mindspore/ccsrc/distributed/cluster/cluster_context.cc"
|
|
||||||
"../../../mindspore/ccsrc/distributed/persistent/*.cc"
|
"../../../mindspore/ccsrc/distributed/persistent/*.cc"
|
||||||
"../../../mindspore/ccsrc/distributed/rpc/tcp/*.cc"
|
"../../../mindspore/ccsrc/distributed/rpc/tcp/*.cc"
|
||||||
"../../../mindspore/ccsrc/distributed/cluster/topology/*.cc"
|
"../../../mindspore/ccsrc/distributed/cluster/topology/*.cc"
|
||||||
|
|
Loading…
Reference in New Issue