!31325 Training supports failover

Merge pull request !31325 from zyli2020/worker_failover_bp
This commit is contained in:
i-robot 2022-03-17 01:37:30 +00:00 committed by Gitee
commit c3d15a079e
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
32 changed files with 523 additions and 76 deletions

View File

@ -37,6 +37,7 @@
#include "runtime/hardware/device_context_manager.h"
#include "runtime/graph_scheduler/graph_compiler.h"
#include "runtime/pynative/run_op_helper.h"
#include "runtime/recovery/recovery_context.h"
#include "include/common/utils/scoped_long_running.h"
#ifdef ENABLE_D
#include "include/common/utils/callbacks_ge.h"
@ -998,17 +999,22 @@ void MindRTBackend::RunGraph(const ActorInfo &actor_info, const VectorRef &args,
MS_EXCEPTION_IF_NULL(graph_compiler_);
graph_compiler_->Summary(graph_compiler_info.graphs_);
// Update device address for output node of graph.
// Summary processing will use the output device address, so must be after the summary processing.
actor_set->output_actor_->UpdateOutputDeviceAddress();
bool need_contruct_output = !(runtime::recovery::RecoveryContext::GetInstance()->enable_recovery() &&
runtime::recovery::RecoveryContext::GetInstance()->need_reset());
if (need_contruct_output) {
// Update device address for output node of graph.
// Summary processing will use the output device address, so must be after the summary processing.
actor_set->output_actor_->UpdateOutputDeviceAddress();
// Fetch outputs.
MS_EXCEPTION_IF_NULL(actor_set->output_actor_);
auto &output_tensors = actor_set->output_actor_->outputs();
if (output_tensors.size() > 0) {
size_t output_position = 0;
ConstructOutputs(root_graph_->output(), output_tensors, &output_position, outputs);
// Fetch outputs.
MS_EXCEPTION_IF_NULL(actor_set->output_actor_);
auto &output_tensors = actor_set->output_actor_->outputs();
if (output_tensors.size() > 0) {
size_t output_position = 0;
ConstructOutputs(root_graph_->output(), output_tensors, &output_position, outputs);
}
}
runtime::GraphScheduler::GetInstance().ClearActorData(actor_set);
// Close abstract_lock for dynamic_shape
AnfUtils::CloseAbstractLock();

View File

@ -30,6 +30,7 @@ namespace cluster {
ClusterContext::ClusterContext()
: inited_(false),
finalized_(true),
cluster_ready_(false),
node_num_each_role_({}),
scheduler_host_(kLocalHost),
scheduler_port_(kDefaultSchedPort),
@ -151,11 +152,11 @@ void ClusterContext::InitClusterConfig() {
bool ClusterContext::BuildCluster() {
// Create node according to different role.
if (node_role_ == kEnvRoleOfWorker) {
node_ = std::make_shared<ps::core::WorkerNode>();
node_ = std::make_shared<ps::core::PSWorkerNode>();
} else if (node_role_ == kEnvRoleOfServer) {
node_ = std::make_shared<ps::core::ServerNode>();
node_ = std::make_shared<ps::core::PSServerNode>();
} else if (node_role_ == kEnvRoleOfScheduler) {
node_ = std::make_shared<ps::core::SchedulerNode>();
node_ = std::make_shared<ps::core::PSSchedulerNode>();
} else {
MS_LOG(EXCEPTION) << "The role " << node_role_ << " is invalid.";
return false;
@ -258,8 +259,20 @@ void ClusterContext::RegisterEventCallback() {
MsException::Instance().SetException();
}
});
abstract_node->RegisterEventCallback(ps::core::ClusterEvent::ON_SEND_META_DATA,
[this]() { cluster_ready_ = true; });
}
}
void ClusterContext::WaitForClusterReady() {
while (!cluster_ready_) {
const int kWaitDuration = 200;
std::this_thread::sleep_for(std::chrono::milliseconds(kWaitDuration));
}
cluster_ready_ = false;
}
} // namespace cluster
} // namespace distributed
} // namespace mindspore

View File

@ -33,6 +33,9 @@
#include "ps/core/worker_node.h"
#include "ps/core/server_node.h"
#include "ps/core/scheduler_node.h"
#include "ps/core/ps_worker_node.h"
#include "ps/core/ps_server_node.h"
#include "ps/core/ps_scheduler_node.h"
#include "distributed/cluster/actor_route_table_proxy.h"
namespace mindspore {
@ -77,6 +80,10 @@ class ClusterContext {
// Return actor route proxy for AbstractNode.
const ActorRouteTableProxyPtr &actor_route_table_proxy() const;
// Wait cluster networking or re-networking successly, using in disaster recovery to prevent collective communication
// ops flapping.
void WaitForClusterReady();
private:
ClusterContext();
@ -101,6 +108,9 @@ class ClusterContext {
// The flag that whether this cluster context instance is already finalized.
std::atomic_bool finalized_;
// The cluster networking or re-networking successly.
std::atomic_bool cluster_ready_;
// The mutex about exiting status of this node.
std::mutex finish_mutex_;

View File

@ -19,6 +19,7 @@
#include <vector>
#include <memory>
#include "utils/ms_context.h"
#include "runtime/recovery/recovery_context.h"
namespace mindspore {
namespace distributed {
@ -60,7 +61,7 @@ std::shared_ptr<CollectiveManager> CollectiveManager::instance() {
}
bool CollectiveManager::Initialize() {
if (inited_) {
if (inited_ && !runtime::recovery::RecoveryContext::GetInstance()->need_reinit_collective()) {
return true;
}
@ -127,15 +128,51 @@ bool CollectiveManager::CreateCommunicationGroup(const std::string &group_name,
if (!host_comm_lib_instance_->Broadcast(root_info, root_info, root_info_size, TypeId::kNumberTypeInt8, 0,
group_name)) {
MS_LOG(ERROR) << "Broadcast for device root info failed on the host side.";
if (runtime::recovery::RecoveryContext::GetInstance()->enable_recovery()) {
runtime::recovery::RecoveryContext::GetInstance()->set_recovery_status(
runtime::recovery::RecoveryErrCode::kBroadcastUniqueIDFailed);
}
return false;
}
// Step 5: Initialize communication group on the device side.
if (!group->Initialize(root_info)) {
return InitDeviceCommGroup(group, root_info);
}
bool CollectiveManager::InitDeviceCommGroup(const CommunicationGroupPtr &group, void *root_info) {
bool init_group_success = false;
bool init_group_fail = false;
std::condition_variable thread_blocker;
init_group_thread_ = std::make_unique<std::thread>([&] {
device_ctx_->Initialize();
if (!group->Initialize(root_info)) {
MS_LOG(ERROR) << "Initialize group on the device side failed.";
std::unique_lock<std::mutex> lock(init_group_mutex_);
init_group_fail = true;
thread_blocker.notify_one();
return;
}
{
std::unique_lock<std::mutex> lock(init_group_mutex_);
init_group_success = true;
thread_blocker.notify_one();
}
});
init_group_thread_->detach();
// Timeout limit 180 seconds to wait finishing init device communication group.
const int64_t kTimeToWait = 180;
std::unique_lock<std::mutex> locker(init_group_mutex_);
(void)thread_blocker.wait_for(locker, std::chrono::seconds(kTimeToWait),
[&] { return init_group_success || init_group_fail; });
if (!init_group_success && runtime::recovery::RecoveryContext::GetInstance()->enable_recovery()) {
runtime::recovery::RecoveryContext::GetInstance()->set_recovery_status(
runtime::recovery::RecoveryErrCode::kInitNcclFailed);
MS_LOG(ERROR) << "Initialize group on the device side failed.";
return false;
}
return true;
return init_group_success;
}
bool CollectiveManager::DestroyCommunicationGroup(const std::string &group_name) {
@ -208,6 +245,10 @@ bool CollectiveManager::InitHostCommlib() {
return false;
}
if (!global_group_ranks_.empty()) {
global_group_ranks_.clear();
}
// Reassign 'global_rank_id_' and 'global_rank_size_'. Generate global communication group ranks.
global_rank_id_ = host_comm_lib_instance_->global_rank_id();
global_rank_size_ = host_comm_lib_instance_->global_rank_size();
@ -276,11 +317,18 @@ bool CollectiveManager::AssignLocalRank() {
// AllGather host names across the global communication group.
if (!host_comm_lib_instance_->AllGather(&host_hash, all_host_hashs, 1, TypeId::kNumberTypeUInt64,
host_global_group_name_)) {
if (runtime::recovery::RecoveryContext::GetInstance()->enable_recovery()) {
runtime::recovery::RecoveryContext::GetInstance()->set_recovery_status(
runtime::recovery::RecoveryErrCode::kAllGatherHostNameFailed);
}
MS_LOG(ERROR) << "AllGather for host names failed.";
return false;
}
// Accumulate rank id.
// In disaster recovery scenario, this function will enter multiple times when the network is reconfigured, so old
// local rank id need to be cleaned.
local_rank_id_ = 0;
for (uint32_t rank = 0; rank < global_rank_size_; rank++) {
if (rank == global_rank_id_) {
break;

View File

@ -80,6 +80,13 @@ class CollectiveManager {
// Assign the local rank id for this process.
bool AssignLocalRank();
// Initialize communication group on the device side.
bool InitDeviceCommGroup(const CommunicationGroupPtr &group, void *root_info);
// Initialize communication group on the device side in thread with timeout limit.
std::unique_ptr<std::thread> init_group_thread_;
std::mutex init_group_mutex_;
std::atomic_bool inited_;
std::atomic_bool finalized_;

View File

@ -17,9 +17,12 @@
#include "distributed/init.h"
#include <vector>
#include <string>
#include "runtime/recovery/recovery_context.h"
namespace mindspore {
namespace distributed {
using runtime::recovery::RecoveryContext;
bool Initialize() {
if (!InitializeCluster()) {
MS_LOG(ERROR) << "Failed to initialize cluster.";
@ -39,10 +42,20 @@ bool Initialize() {
collective::CollectiveManager::instance()->set_global_rank_id(abstract_node->rank_id());
collective::CollectiveManager::instance()->set_global_rank_size(abstract_node->worker_num());
if (RecoveryContext::GetInstance()->enable_recovery()) {
cluster::ClusterContext::instance()->WaitForClusterReady();
}
if (!InitializeCollective()) {
MS_LOG(ERROR) << "Failed to initialize collective communication.";
return false;
}
if (RecoveryContext::GetInstance()->enable_recovery()) {
RecoveryContext::GetInstance()->set_global_rank_id(abstract_node->rank_id());
RecoveryContext::GetInstance()->set_global_rank_size(abstract_node->worker_num());
RecoveryContext::GetInstance()->ObtainGlobalLatestCkptInfo();
}
}
}
#endif

View File

@ -155,9 +155,13 @@ void FileIOUtils::CreateDir(const std::string &dir_path, mode_t mode) {
}
void FileIOUtils::CreateDirRecursive(const std::string &dir_path, mode_t mode) {
if (dir_path.empty()) {
MS_LOG(EXCEPTION) << "The directory path need to be create is empty";
}
size_t dir_path_len = dir_path.length();
if (dir_path_len > PATH_MAX) {
MS_LOG(EXCEPTION) << "Directory path is too long: " << dir_path;
MS_LOG(EXCEPTION) << "Directory path is too long to exceed max length limit: " << PATH_MAX
<< ", the path: " << dir_path;
}
char tmp_dir_path[PATH_MAX] = {0};

View File

@ -17,6 +17,7 @@
#include "fl/server/collective_ops_impl.h"
#include "fl/server/local_meta_store.h"
#include "fl/server/iteration.h"
#include "utils/ms_context.h"
namespace mindspore {
namespace fl {
@ -323,6 +324,12 @@ bool CollectiveOpsImpl::RingAllGather(const void *sendbuff, void *const recvbuff
return false;
}
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
// If enable recovery, set timeout 300s to prevent networking flapping.
uint32_t collective_comm_timeout =
context_ptr->get_param<bool>(MS_CTX_ENABLE_RECOVERY) ? kCollectiveCommMaxTimeout : kCollectiveCommTimeout;
// Ring AllGather.
for (size_t i = 0; i < rank_size_ - 1; i++) {
size_t send_chunk_index = (rank_id_ - i + rank_size_) % rank_size_;
@ -337,7 +344,7 @@ bool CollectiveOpsImpl::RingAllGather(const void *sendbuff, void *const recvbuff
std::shared_ptr<std::vector<unsigned char>> recv_str;
auto recv_req_id = node_->CollectiveReceiveAsync(node_role_, recv_from_rank, &recv_str);
if (!node_->CollectiveWait(recv_req_id, kCollectiveCommTimeout)) {
if (!node_->CollectiveWait(recv_req_id, collective_comm_timeout)) {
MS_LOG(ERROR) << "CollectiveWait " << recv_req_id << " failed.";
return false;
}
@ -348,7 +355,7 @@ bool CollectiveOpsImpl::RingAllGather(const void *sendbuff, void *const recvbuff
<< recv_str->size();
return false;
}
if (!node_->Wait(send_req_id, kCollectiveCommTimeout)) {
if (!node_->Wait(send_req_id, collective_comm_timeout)) {
MS_LOG(ERROR) << "CollectiveWait " << send_req_id << " failed.";
return false;
}

View File

@ -32,6 +32,8 @@ namespace fl {
namespace server {
// The timeout for server collective communication in case of network jitter.
constexpr uint32_t kCollectiveCommTimeout = 30;
// The max timeout for server collective communication, used in disaster recovery to prevent networking flapping.
constexpr uint32_t kCollectiveCommMaxTimeout = 300;
// The collective communication groups which are composed of multiple processes. Refer to MPI_Group.
struct CommunicationGroupInfo {

View File

@ -39,6 +39,14 @@ namespace gpu {
} \
}
#define CHECK_OP_RET_WITH_EXCEPT_TRANCE(node, expression, message) \
{ \
bool success = (expression); \
if (!success) { \
MS_LOG(EXCEPTION) << "Op Error: " << message << " | " << trace::DumpSourceLines(node.lock()); \
} \
}
#define CHECK_OP_RET_WITH_ERROR(expression, message) \
{ \
bool success = (expression); \

View File

@ -48,7 +48,6 @@ bool NvidiaCommunicationGroup::Finalize() {
// because 'ncclCommAbort' will abort any uncompleted operations before destroying the communicator, e.g.,
// ncclAllReduce.
CHECK_RET(ncclCommAbort(comm_), ncclSuccess, "Failed to abort NCCL communicator.");
CHECK_RET(ncclCommDestroy(comm_), ncclSuccess, "Failed to destroy NCCL communicator.");
initialized_ = false;
return true;
}

View File

@ -724,9 +724,8 @@ void AbstractNode::ProcessHeartbeatResp(const std::shared_ptr<MessageMeta> &meta
}
bool is_worker = heartbeat_resp_message.is_worker();
bool is_fl_mode = PSContext::instance()->server_mode() == ps::kServerModeFL ||
PSContext::instance()->server_mode() == ps::kServerModeHybrid;
bool not_enable_recover_node_timeout = (is_worker && !is_fl_mode);
bool is_ps_mode = PSContext::instance()->server_mode() == ps::kServerModePS;
bool not_enable_recover_node_timeout = (is_worker && is_ps_mode);
if (current_cluster_state_ == ClusterState::NODE_TIMEOUT) {
if (node_recovery_ == nullptr || not_enable_recover_node_timeout) {
@ -853,6 +852,8 @@ void AbstractNode::ProcessSendMetadata(const std::shared_ptr<TcpConnection> &con
std::lock_guard<std::mutex> lock(client_mutex_);
connected_nodes_.clear();
OnEventCallback(ClusterEvent::ON_SEND_META_DATA);
}
void AbstractNode::ProcessFinish(const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta,

View File

@ -29,6 +29,11 @@ void AbstractPSNode::StartHeartbeatTimer() {
if (!DoHeartbeat()) {
MS_LOG(WARNING)
<< "Heartbeat timeout, the tcp connection to scheduler is lost, please check the status of scheduler.";
if (CheckSchedulerTimeout()) {
MS_LOG(WARNING) << "Scheduler is Timeout, please recovery.";
}
HandleHeartbeatTimeout();
} else {
UpdateSchedulerTime();
@ -115,8 +120,11 @@ bool AbstractPSNode::InitClientToScheduler() {
});
client_to_scheduler_thread_->detach();
// Timeout for waiting for the tcp connection to the scheduler.
uint32_t timeout = 10;
// Timeout for waiting for the tcp connection to the scheduler, 10 seconds in recovery mode, or 900 seconds for first
// build connection to scheduler.
const uint32_t timeout_for_reinit_in_recovery = 10;
uint32_t timeout = heartbeat_stopped_ ? timeout_for_reinit_in_recovery
: PSContext::instance()->cluster_config().cluster_available_timeout;
bool wait_res = client_to_scheduler_->WaitConnected(timeout);
if (!wait_res) {
is_ready_ = true;

View File

@ -35,6 +35,7 @@ enum class ClusterEvent {
CLUSTER_SCALE_IN_DONE = 6,
ON_PREPARE_PERSIST = 7,
ON_BEGIN_PERSIST = 8,
ON_SEND_META_DATA = 9,
};
struct NodeInfo {

View File

@ -15,6 +15,7 @@
*/
#include "ps/core/node_manager.h"
#include "utils/ms_context.h"
namespace mindspore {
namespace ps {
@ -45,23 +46,29 @@ uint32_t NodeManager::checkIfRankIdExist(const RegisterMessage &register_message
return rank_id;
}
// This is for scheduler recovery
ReAddNodeIfNotExists(node_id, register_message.ip(), register_message.port());
(void)ReAddNodeIfNotExists(node_id, register_message.ip(), register_message.port(), &rank_id);
return rank_id;
}
void NodeManager::ReAddNodeIfNotExists(const std::string &node_id, const std::string &ip, uint32_t port) {
bool NodeManager::ReAddNodeIfNotExists(const std::string &node_id, const std::string &ip, uint32_t port,
uint32_t *rank_id) {
core::ClusterConfig &clusterConfig = PSContext::instance()->cluster_config();
std::unordered_map<std::string, NodeInfo> recovery_node_infos = clusterConfig.initial_registered_nodes_infos;
if (registered_nodes_info_.find(node_id) == registered_nodes_info_.end() &&
recovery_node_infos.find(node_id) != recovery_node_infos.end()) {
if (rank_id != nullptr) {
*rank_id = recovery_node_infos[node_id].rank_id_;
}
recovery_node_infos[node_id].is_alive = true;
recovery_node_infos[node_id].ip_ = ip;
recovery_node_infos[node_id].port_ = static_cast<uint16_t>(port);
registered_nodes_info_[node_id] = recovery_node_infos[node_id];
MS_LOG(INFO) << "The node id: " << node_id << " is recovery successful!"
<< ", ip: " << ip << ", port: " << port;
return true;
}
return false;
}
uint32_t NodeManager::NextRankId(const RegisterMessage &register_message, const std::shared_ptr<MessageMeta> &meta) {
@ -219,9 +226,14 @@ void NodeManager::UpdateCluster() {
if (!timeout_nodes_info_.empty()) {
UpdateClusterState(ClusterState::NODE_TIMEOUT);
for (auto iter = timeout_nodes_info_.begin(); iter != timeout_nodes_info_.end(); ++iter) {
(void)heartbeats_.erase(iter->first);
finish_nodes_id_.insert(iter->first);
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
if (!context_ptr->get_param<bool>(MS_CTX_ENABLE_RECOVERY)) {
for (auto iter = timeout_nodes_info_.begin(); iter != timeout_nodes_info_.end(); ++iter) {
(void)heartbeats_.erase(iter->first);
finish_nodes_id_.insert(iter->first);
}
}
if (onPersist_) {
onPersist_();

View File

@ -59,7 +59,8 @@ class NodeManager {
uint32_t checkIfRankIdExist(const RegisterMessage &register_message);
// Re-Add the server or worker node into the registered node list if the node do not existed in the scheduler.
void ReAddNodeIfNotExists(const std::string &node_id, const std::string &ip, uint32_t port);
bool ReAddNodeIfNotExists(const std::string &node_id, const std::string &ip, uint32_t port,
uint32_t *rank_id = nullptr);
void UpdateHeartbeat(const std::string &node_id);
std::vector<ServersMeta> FetchServersMeta();

View File

@ -172,14 +172,29 @@ void SchedulerNode::ProcessHeartbeat(const std::shared_ptr<TcpServer> &server,
MS_LOG(WARNING) << "Send heart beat failed.";
}
if (node_manager_.IsAllNodesRegistered()) {
return;
}
// Re-Add the missing node into node manager.
if (heartbeat_message.has_address()) {
node_manager_.ReAddNodeIfNotExists(node_id, heartbeat_message.ip(), heartbeat_message.port());
if (heartbeat_message.has_address() &&
node_manager_.ReAddNodeIfNotExists(node_id, heartbeat_message.ip(), heartbeat_message.port())) {
SetRegisterConnectionFd(conn, node_id);
if (node_manager_.IsAllNodesRegistered()) {
is_ready_ = true;
MS_LOG(INFO) << "There are " << node_manager_.worker_num() << " workers and " << node_manager_.server_num()
<< " servers registered to scheduer, so the scheduler send meta data to worker/server.";
node_manager_.UpdateNodesInfo();
auto node_infos = node_manager_.nodes_info();
for (const auto &kvs : node_infos) {
auto client = GetOrCreateClient(kvs.second);
MS_EXCEPTION_IF_NULL(client);
SendMetadata(client, kvs.second.rank_id_);
node_manager_.UpdateHeartbeat(kvs.first);
}
node_manager_.UpdateClusterState(ClusterState::CLUSTER_READY);
PersistMetaData();
wait_start_cond_.notify_all();

View File

@ -2,7 +2,7 @@ file(GLOB_RECURSE DEVICE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "common/*
"kernel_info.cc" "executor/dynamic_kernel.cc" "executor/executor_callback.cc" "kernel_runtime.cc"
"memory_manager.cc" "kernel_runtime_manager.cc" "convert_tensor_utils.cc" "memory_scheduler.cc"
"memory_offload_strategy.cc" "bucket.cc" "launch_kernel.cc" "launch_mul.cc" "tensor_array.cc"
"ms_device_shape_transfer.cc" "context_extends.cc"
"ms_device_shape_transfer.cc" "context_extends.cc" "stream_synchronizer.cc"
)
if("${ENABLE_HIDDEN}" STREQUAL "OFF")

View File

@ -0,0 +1,102 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "runtime/device/stream_synchronizer.h"
#include "utils/ms_context.h"
#include "distributed/collective/collective_manager.h"
#include "runtime/recovery/recovery_context.h"
namespace mindspore {
namespace device {
std::mutex StreamSynchronizer::instance_lock_;
std::shared_ptr<StreamSynchronizer> StreamSynchronizer::instance_ = nullptr;
StreamSynchronizer::~StreamSynchronizer() {
{
std::unique_lock<std::mutex> lock(task_mutex_);
stop_ = true;
}
do_sync_stream_cv_.notify_all();
worker_thread_.join();
}
bool StreamSynchronizer::SyncStream(const std::string &device_name, uint32_t timeout) {
std::unique_lock<std::mutex> reentrant_lock(reentrant_mutex_);
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
uint32_t device_id = ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID);
const auto &device_context =
device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext({device_name, device_id});
MS_EXCEPTION_IF_NULL(device_context);
// If disable recovery or timeout==0, sync stream directly to improve performance.
if (!runtime::recovery::RecoveryContext::GetInstance()->enable_recovery() || timeout == 0) {
device_context->Initialize();
return device_context->SyncStream();
}
std::unique_lock<std::mutex> lock(task_mutex_);
if (stop_) {
MS_LOG(EXCEPTION) << "The synchronization stream task has stopped";
}
device_context_ = device_context;
do_sync_stream_cv_.notify_one();
if (sync_stream_time_out_) {
// If sync stream timeout has happened, increase the timeout by 4 times.
const uint32_t kTimeOutScaleFactor = 4;
timeout *= kTimeOutScaleFactor;
}
if (time_out_cv_.wait_for(lock, std::chrono::seconds(timeout)) == std::cv_status::no_timeout) {
if (!sync_stream_ret_) {
MS_LOG(ERROR) << "Synchronize stream failed.";
}
return sync_stream_ret_;
} else {
sync_stream_time_out_ = true;
runtime::recovery::RecoveryContext::GetInstance()->set_need_reinit_collective(true);
distributed::collective::CollectiveManager::instance()->Finalize();
time_out_cv_.wait(lock, [this]() { return device_context_ == nullptr; });
MS_LOG(WARNING) << "Synchronize stream time out.";
return true;
}
}
void StreamSynchronizer::DoSyncStreamTask() {
for (;;) {
{
std::unique_lock<std::mutex> lock(task_mutex_);
do_sync_stream_cv_.wait(lock, [this]() { return stop_ || device_context_ != nullptr; });
if (stop_) {
return;
}
}
device_context_->Initialize();
// Really sync stream.
sync_stream_ret_ = device_context_->SyncStream();
{
std::unique_lock<std::mutex> lock(task_mutex_);
device_context_ = nullptr;
}
time_out_cv_.notify_one();
}
}
} // namespace device
} // namespace mindspore

View File

@ -0,0 +1,93 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_SYNC_STREAM_H_
#define MINDSPORE_CCSRC_RUNTIME_DEVICE_SYNC_STREAM_H_
#include <memory>
#include <string>
#include <vector>
#include <thread>
#include <mutex>
#include <condition_variable>
#include "runtime/hardware/device_context.h"
namespace mindspore {
namespace device {
constexpr uint32_t kTimeoutInSeconds = 30;
// Execute synchronization stream with timeout mechanism. Typical application scenarios: it is used to monitor
// distributed data parallel training scenarios and whether a process exits unexpectedly.
class StreamSynchronizer {
public:
static std::shared_ptr<StreamSynchronizer> &GetInstance() {
std::lock_guard<std::mutex> lock(instance_lock_);
if (instance_ == nullptr) {
instance_.reset(new (std::nothrow) StreamSynchronizer());
}
return instance_;
}
~StreamSynchronizer();
// Execute synchronization stream with timeout mechanism.
bool SyncStream(const std::string &device_name, uint32_t timeout = kTimeoutInSeconds);
private:
// Create a thread to actually execute the synchronization stream task.
StreamSynchronizer() { worker_thread_ = std::thread(&StreamSynchronizer::DoSyncStreamTask, this); }
DISABLE_COPY_AND_ASSIGN(StreamSynchronizer);
// Monitor whether there are synchronization stream tasks, and actually execute the synchronization stream
// tasks.
void DoSyncStreamTask();
// Used for multi-thread safety of singleton creation.
static std::mutex instance_lock_;
// The singleton pointer.
static std::shared_ptr<StreamSynchronizer> instance_;
// Record whether the synchronization stream task has timed out.
bool sync_stream_time_out_{false};
// Return value of synchronization stream.
bool sync_stream_ret_{false};
// Whether synchronization stream thread need to stop.
bool stop_{false};
DeviceContext *device_context_{nullptr};
// The method SyncStream is not multiple threads safe, so use this lock to prevent simultaneous access by
// multiple threads.
std::mutex reentrant_mutex_;
// Use this lock to ensure the safety of external calls to SyncStream and the execution of DoSyncStreamTask
// in worker_thread_;
std::mutex task_mutex_;
// The thread to actually execute the synchronization stream task.
std::thread worker_thread_;
std::condition_variable time_out_cv_;
std::condition_variable do_sync_stream_cv_;
};
} // namespace device
} // namespace mindspore
#endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_SYNC_STREAM_H_

View File

@ -25,9 +25,11 @@
#include "mindrt/include/async/async.h"
#include "utils/log_adapter.h"
#include "include/common/utils/convert_utils.h"
#include "runtime/recovery/recovery_context.h"
namespace mindspore {
namespace runtime {
using recovery::RecoveryContext;
namespace {
void SyncTensorData(const TensorPtr &host_tensor, const DeviceTensorPtr &device_tensor, const AnfNodePtr &node,
const DeviceContext *device_context, OpContext<DeviceTensor> *const context,
@ -354,6 +356,10 @@ void DataPrepareActor::PrepareDataForDeviceTensorStore(const std::vector<std::ve
}
}
}
if (RecoveryContext::GetInstance()->enable_recovery() &&
RecoveryContext::GetInstance()->need_sync_weight_to_device()) {
RecoveryContext::GetInstance()->set_need_sync_weight_to_device(false);
}
PrepareDeviceTensorStoreForControlNode(graph_compiler_info_->control_node_parser_, input_tensors.back(), context);
}
@ -722,6 +728,11 @@ void DataPrepareActor::PrepareDataForWeightNode(const AnfNodePtr &backend_node,
host_tensor_address->SetNodeIndex(backend_node, 0);
DeviceTensorStore::GetInstance().Insert(front_node.get(), host_tensor_address);
if (RecoveryContext::GetInstance()->enable_recovery() &&
RecoveryContext::GetInstance()->need_sync_weight_to_device()) {
is_need_sync = true;
}
// If the ptr of device tensor is not nullptr, it indicates that the device data has been prepared.
MS_EXCEPTION_IF_NULL(host_tensor_address);
if (is_need_sync || (host_tensor_address->GetPtr() == nullptr)) {

View File

@ -21,9 +21,12 @@
#include "runtime/graph_scheduler/actor/debug_actor.h"
#include "mindrt/include/async/async.h"
#include "utils/log_adapter.h"
#include "runtime/recovery/recovery_context.h"
namespace mindspore {
namespace runtime {
using recovery::RecoveryContext;
void KernelActor::Init() {
// Check device contexts number.
if (device_contexts_.size() != device::kDeviceContextsNumOne) {
@ -240,11 +243,18 @@ void KernelActor::OnMemoryAllocFinish(OpContext<DeviceTensor> *const context) {
PreLaunchKernel(context);
try {
auto ret = device_contexts_[0]->LaunchKernel(kernel_, launch_info_.inputs_, launch_info_.workspaces_,
launch_info_.outputs_, is_dynamic_shape_);
if (!ret) {
std::string error_info = "Launch kernel failed: " + kernel_->fullname_with_scope();
SET_OPCONTEXT_FAIL_RET_WITH_ERROR_BY_STRATEGY(strategy_, (*context), error_info);
if (RecoveryContext::GetInstance()->enable_recovery() && RecoveryContext::GetInstance()->need_reinit_collective()) {
// In disaster recovery scenarios, run dag in this step failed, the rest operators of graph do not need launch,
// especially the collective communication operators.
MS_LOG(WARNING) << "Collective communication need reinitialize, skip launch kernel: "
<< kernel_->fullname_with_scope();
} else {
auto ret = device_contexts_[0]->LaunchKernel(kernel_, launch_info_.inputs_, launch_info_.workspaces_,
launch_info_.outputs_, is_dynamic_shape_);
if (!ret) {
std::string error_info = "Launch kernel failed: " + kernel_->fullname_with_scope();
SET_OPCONTEXT_FAIL_RET_WITH_ERROR_BY_STRATEGY(strategy_, (*context), error_info);
}
}
} catch (const std::exception &e) {
if (strategy_ == GraphExecutionStrategy::kPipeline) {

View File

@ -15,6 +15,7 @@
*/
#include "runtime/graph_scheduler/actor/loop_count_actor.h"
#include <set>
#include "runtime/graph_scheduler/actor/data_prepare_actor.h"
#include "runtime/graph_scheduler/actor/output_actor.h"
#include "runtime/graph_scheduler/actor/memory_manager_actor.h"
@ -23,9 +24,13 @@
#include "runtime/graph_scheduler/actor/control_flow/entrance_actor.h"
#include "mindrt/include/async/async.h"
#include "utils/log_adapter.h"
#include "runtime/device/stream_synchronizer.h"
#include "runtime/recovery/recovery_context.h"
namespace mindspore {
namespace runtime {
using recovery::RecoveryContext;
void LoopCountActor::Run(OpContext<DeviceTensor> *const context) {
MS_EXCEPTION_IF_NULL(context);
// Need wait MemoryManagerActor running finished to avoid the illegal memory timing problem before
@ -52,6 +57,26 @@ void LoopCountActor::IncreaseLoopCount(OpContext<DeviceTensor> *const context) {
return;
}
// Sync device stream.
if (strategy_ == GraphExecutionStrategy::kPipeline) {
std::set<const DeviceContext *> sync_stream_device_contexts;
for (auto &device_context : device_contexts_) {
MS_EXCEPTION_IF_NULL(device_context);
if ((sync_stream_device_contexts.count(device_context) == 0) &&
(!device::StreamSynchronizer::GetInstance()->SyncStream(device_context->device_context_key().device_name_))) {
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context),
("Sync stream failed:" + device_context->device_context_key().ToString()));
}
(void)sync_stream_device_contexts.insert(device_context);
// Trigger disaster recovery and exit loop early.
if (RecoveryContext::GetInstance()->enable_recovery() &&
RecoveryContext::GetInstance()->need_reinit_collective()) {
current_count_ = loop_count_;
}
}
}
PostRun(context);
}

View File

@ -17,6 +17,7 @@
#ifndef MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_LOOP_COUNT_ACTOR_H_
#define MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_LOOP_COUNT_ACTOR_H_
#include <algorithm>
#include <vector>
#include <string>
#include <memory>
@ -35,11 +36,17 @@ namespace runtime {
class LoopCountActor : public DebugAwareActor {
public:
LoopCountActor(const std::string &name, size_t loop_count, const AID &memory_manager_aid, const AID *debug_aid,
const AID *recorder_aid)
const AID *recorder_aid, GraphExecutionStrategy strategy,
const std::vector<DeviceContext *> &device_contexts)
: DebugAwareActor(name, KernelTransformType::kLoopCountActor, recorder_aid, memory_manager_aid, debug_aid),
loop_count_(loop_count),
current_count_(0),
total_running_count_(0) {}
total_running_count_(0),
strategy_(strategy) {
(void)std::transform(
device_contexts.begin(), device_contexts.end(), std::back_inserter(device_contexts_),
[](DeviceContext *device_context) { return static_cast<const DeviceContext *>(device_context); });
}
~LoopCountActor() override = default;
@ -73,6 +80,10 @@ class LoopCountActor : public DebugAwareActor {
// The actors which need be handled separately by loop count actor.
AID data_prepare_aid_;
std::vector<AID> entrance_aids_;
// The execution strategy for executing actor.
// In pipeline mode, sync stream for every step.
GraphExecutionStrategy strategy_{GraphExecutionStrategy::kPipeline};
};
using LoopCountActorPtr = std::shared_ptr<LoopCountActor>;

View File

@ -17,9 +17,12 @@
#include "runtime/graph_scheduler/actor/output_actor.h"
#include "runtime/hardware/device_context_manager.h"
#include "utils/log_adapter.h"
#include "runtime/recovery/recovery_context.h"
namespace mindspore {
namespace runtime {
using recovery::RecoveryContext;
bool IsOutputAddressPersisted(const DeviceTensor *output_device_tensor, const AnfNodePtr &output_node) {
MS_EXCEPTION_IF_NULL(output_node);
MS_EXCEPTION_IF_NULL(output_device_tensor);
@ -70,10 +73,44 @@ void OutputActor::Init() {
running_dependent_msg_num_ = SizeToInt(outputs_num_ - device_tensor_store_keys_.size());
}
void OutputActor::FreeOutputNodeMem() {
for (size_t i = 0; i < output_nodes_.size(); ++i) {
auto &output_node = output_nodes_[i].first;
auto &output_device_tensor = output_device_tensors_[i];
if ((output_node == nullptr) || (output_device_tensor == nullptr)) {
return;
}
if (!IsOutputAddressPersisted(output_device_tensor, output_node)) {
FreeMemory(output_device_tensor, device_contexts_[i]);
}
}
}
void OutputActor::ClearOutputCache() {
output_node_to_tensor_device_address_.clear();
outputs_.clear();
outputs_.resize(outputs_num_);
output_nodes_.clear();
output_nodes_.resize(outputs_num_);
output_device_tensors_.clear();
output_device_tensors_.resize(outputs_num_);
current_outputs_num_ = 0;
current_count_ = 0;
}
void OutputActor::RunOpControl(AID *const, OpContext<DeviceTensor> *const context) {
MS_EXCEPTION_IF_NULL(context);
++current_count_;
// Trigger disaster recovery and return empty output.
if (RecoveryContext::GetInstance()->enable_recovery() && RecoveryContext::GetInstance()->need_reinit_collective()) {
FreeOutputNodeMem();
ClearOutputCache();
SET_OPCONTEXT_SUCCESS_RET((*context));
}
// The last loop.
if (loop_count_ == current_count_) {
if (current_outputs_num_ + device_tensor_store_keys_.size() != outputs_num_) {
@ -106,19 +143,7 @@ void OutputActor::RunOpControl(AID *const, OpContext<DeviceTensor> *const contex
SET_OPCONTEXT_SUCCESS_RET((*context));
}
// The output device memory will be taken over by tensor in the last loop, otherwise needs to free the memory.
// 1.Avoid the memory leak when memory used by dynamic ref count in the control flow scene.
// 2.Alloc the new memory in the next step using the new shape size in the dynamic shape scene.
for (size_t i = 0; i < output_nodes_.size(); ++i) {
auto &output_node = output_nodes_[i].first;
auto &output_device_tensor = output_device_tensors_[i];
if ((output_node == nullptr) || (output_device_tensor == nullptr)) {
return;
}
if (!IsOutputAddressPersisted(output_device_tensor, output_node)) {
FreeMemory(output_device_tensor, device_contexts_[i]);
}
}
FreeOutputNodeMem();
// Send control arrow to trigger next step running.
auto from_aid = const_cast<AID *>(&GetAID());

View File

@ -79,6 +79,14 @@ class OutputActor : public AbstractActor {
TensorPtr CreateOutputTensor(const AnfNodePtr &output_node, size_t output_index, size_t output_position);
// The output device memory will be taken over by tensor in the last loop, otherwise needs to free the memory.
// 1.Avoid the memory leak when memory used by dynamic ref count in the control flow scene.
// 2.Alloc the new memory in the next step using the new shape size in the dynamic shape scene.
void FreeOutputNodeMem();
// Clear output nodes and tensors in cache.
void ClearOutputCache();
// The loop count is constant, the current count is increased after each step running finished.
// Collect the output result in the last loop which is represented by "loop_count_ - current_count_ == 1".
size_t loop_count_;

View File

@ -46,9 +46,11 @@
#endif
#include "profiler/device/profiling.h"
#include "debug/common.h"
#include "runtime/recovery/recovery_context.h"
namespace mindspore {
namespace runtime {
using recovery::RecoveryContext;
namespace {
bool IsNeedInsertCopyActor(const DeviceContext *from_device_context, const DeviceContext *to_device_context) {
MS_EXCEPTION_IF_NULL(from_device_context);
@ -464,21 +466,19 @@ void GraphScheduler::Run(ActorSet *const actor_set, const std::vector<DeviceCont
MS_LOG(EXCEPTION) << op_context.error_info_;
}
// Sync device stream.
if (strategy == GraphExecutionStrategy::kPipeline) {
std::set<DeviceContext *> sync_stream_device_contexts;
for (auto &device_context : device_contexts) {
MS_EXCEPTION_IF_NULL(device_context);
if ((sync_stream_device_contexts.count(device_context) == 0) && (!device_context->SyncStream())) {
MS_LOG(EXCEPTION) << "Sync stream failed:" << device_context->device_context_key().ToString();
}
(void)sync_stream_device_contexts.insert(device_context);
}
}
double end_time = GetTime();
const size_t kSecondsToMilliseconds = 1000;
SetActorExecutionStrategy(actor_set, strategy, (end_time - start_time) * kSecondsToMilliseconds);
if (RecoveryContext::GetInstance()->enable_recovery() && RecoveryContext::GetInstance()->need_reinit_collective()) {
MS_LOG(INFO) << "Begin reinitialize collective communication for recovery.";
if (!RecoveryContext::GetInstance()->ReInitializeCollective()) {
MS_LOG(EXCEPTION) << "Reinitialize collective communication failed.";
}
MS_LOG(INFO) << "Finish reinitialize collective communication for recovery.";
RecoveryContext::GetInstance()->set_need_reinit_collective(false);
}
}
void GraphScheduler::SetActorExecutionStrategy(ActorSet *const actor_set, GraphExecutionStrategy strategy,
@ -843,7 +843,8 @@ LoopCountActorPtr GraphScheduler::BuildLoopCountActor(const GraphCompilerInfo &g
auto actor_name = graph_compiler_info.name_ + kLoopCountActorNameSuffix;
auto loop_count_actor =
std::make_shared<LoopCountActor>(actor_name, loop_count, memory_manager_aid_, debug_aid_, recorder_aid_);
std::make_shared<LoopCountActor>(actor_name, loop_count, memory_manager_aid_, debug_aid_, recorder_aid_,
graph_compiler_info.strategy_, graph_compiler_info.device_contexts_);
MS_LOG(INFO) << "Create loop count actor: " << actor_name;
MS_EXCEPTION_IF_NULL(loop_count_actor);

View File

@ -29,6 +29,7 @@
#include "distributed/init.h"
#include "runtime/hardware/device_context.h"
#include "utils/convert_utils_base.h"
#include "utils/ms_context.h"
namespace mindspore {
namespace runtime {
@ -82,6 +83,10 @@ void RecoveryContext::Initialize() {
return;
}
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
context_ptr->set_param<bool>(MS_CTX_ENABLE_RECOVERY, true);
recovery_path_ = common::GetEnv(kEnvRecoveryPath);
if (recovery_path_.empty()) {
MS_LOG(EXCEPTION) << "The recovery path is empty, please export MS_RECOVERY_PATH correctly.";
@ -140,20 +145,20 @@ void RecoveryContext::Initialize() {
bool RecoveryContext::ReInitializeCollective() {
auto ret = distributed::Initialize();
if (ret) {
recovery_status_ = RecoveryStatus::kUnKnownError;
recovery_status_ = RecoveryErrCode::kUnKnownError;
set_need_reset(true);
set_need_sync_weight_to_device(true);
return true;
}
if (recovery_status_ == RecoveryStatus::kBroadcastUniqueIDFailed ||
recovery_status_ == RecoveryStatus::kAllGatherHostNameFailed) {
if (recovery_status_ == RecoveryErrCode::kBroadcastUniqueIDFailed ||
recovery_status_ == RecoveryErrCode::kAllGatherHostNameFailed) {
MS_LOG(WARNING) << "Prepare to initialize NCCL failed, retrying.";
// Retry duration: 30s.
const int kRetryDuration = 30;
std::this_thread::sleep_for(std::chrono::seconds(kRetryDuration));
return ReInitializeCollective();
} else if (recovery_status_ == RecoveryStatus::kInitNcclFailed) {
} else if (recovery_status_ == RecoveryErrCode::kInitNcclFailed) {
MS_LOG(EXCEPTION) << "Initialize NCCL failed.";
}

View File

@ -31,7 +31,7 @@ namespace recovery {
using distributed::storage::FileIOUtils;
using distributed::storage::JsonUtils;
enum class RecoveryStatus { kUnKnownError, kAllGatherHostNameFailed, kBroadcastUniqueIDFailed, kInitNcclFailed };
enum class RecoveryErrCode { kUnKnownError, kAllGatherHostNameFailed, kBroadcastUniqueIDFailed, kInitNcclFailed };
// Used to save disaster recovery-related state quantities and provide disaster recovery-related
// functions, such as reinitializing collective communication, etc.
@ -65,9 +65,9 @@ class RecoveryContext {
int recovery_interval() const { return recovery_interval_; }
// Get the error status of recovery.
RecoveryStatus recovery_status() const { return recovery_status_; }
RecoveryErrCode recovery_status() const { return recovery_status_; }
// Set the error status of recovery.
void set_recovery_status(RecoveryStatus recovery_status) { recovery_status_ = recovery_status; }
void set_recovery_status(RecoveryErrCode recovery_status) { recovery_status_ = recovery_status; }
// Set the path used to save checkpoint.
void SetCkptPath(const std::string &path);
@ -161,7 +161,7 @@ class RecoveryContext {
bool initialized_{false};
// The error status of recovery.
RecoveryStatus recovery_status_{RecoveryStatus::kUnKnownError};
RecoveryErrCode recovery_status_{RecoveryErrCode::kUnKnownError};
// The persitent json file util, used to persist recovery config.
std::unique_ptr<JsonUtils> persistent_json_;

View File

@ -100,6 +100,7 @@ MsContext::MsContext(const std::string &policy, const std::string &target) {
set_param<bool>(MS_CTX_ENABLE_PYNATIVE_SYNCHRONIZE, false);
set_param<bool>(MS_CTX_ENABLE_PYNATIVE_OP_GRAPH_CACHE, true);
set_param<bool>(MS_CTX_ENABLE_MEM_SCHEDULER, false);
set_param<bool>(MS_CTX_ENABLE_RECOVERY, false);
uint32_t kDefaultRuntimeNumThreads = 30;
set_param<uint32_t>(MS_CTX_RUNTIME_NUM_THREADS, kDefaultRuntimeNumThreads);

View File

@ -92,6 +92,7 @@ enum MsCtxParam : unsigned {
MS_CTX_ENABLE_PYNATIVE_SYNCHRONIZE,
MS_CTX_ENABLE_PYNATIVE_OP_GRAPH_CACHE,
MS_CTX_ENABLE_MEM_SCHEDULER,
MS_CTX_ENABLE_RECOVERY,
MS_CTX_TYPE_BOOL_END,
// parameter of type int

View File

@ -181,7 +181,6 @@ file(GLOB_RECURSE MINDSPORE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
"../../../mindspore/ccsrc/fl/*.cc"
"../../../mindspore/ccsrc/distributed/cluster/actor_route_table_service.cc"
"../../../mindspore/ccsrc/distributed/cluster/actor_route_table_proxy.cc"
"../../../mindspore/ccsrc/distributed/cluster/cluster_context.cc"
"../../../mindspore/ccsrc/distributed/persistent/*.cc"
"../../../mindspore/ccsrc/distributed/rpc/tcp/*.cc"
"../../../mindspore/ccsrc/distributed/cluster/topology/*.cc"