forked from mindspore-Ecosystem/mindspore
!31325 Training supports failover
Merge pull request !31325 from zyli2020/worker_failover_bp
This commit is contained in:
commit
c3d15a079e
|
@ -37,6 +37,7 @@
|
|||
#include "runtime/hardware/device_context_manager.h"
|
||||
#include "runtime/graph_scheduler/graph_compiler.h"
|
||||
#include "runtime/pynative/run_op_helper.h"
|
||||
#include "runtime/recovery/recovery_context.h"
|
||||
#include "include/common/utils/scoped_long_running.h"
|
||||
#ifdef ENABLE_D
|
||||
#include "include/common/utils/callbacks_ge.h"
|
||||
|
@ -998,17 +999,22 @@ void MindRTBackend::RunGraph(const ActorInfo &actor_info, const VectorRef &args,
|
|||
MS_EXCEPTION_IF_NULL(graph_compiler_);
|
||||
graph_compiler_->Summary(graph_compiler_info.graphs_);
|
||||
|
||||
// Update device address for output node of graph.
|
||||
// Summary processing will use the output device address, so must be after the summary processing.
|
||||
actor_set->output_actor_->UpdateOutputDeviceAddress();
|
||||
bool need_contruct_output = !(runtime::recovery::RecoveryContext::GetInstance()->enable_recovery() &&
|
||||
runtime::recovery::RecoveryContext::GetInstance()->need_reset());
|
||||
if (need_contruct_output) {
|
||||
// Update device address for output node of graph.
|
||||
// Summary processing will use the output device address, so must be after the summary processing.
|
||||
actor_set->output_actor_->UpdateOutputDeviceAddress();
|
||||
|
||||
// Fetch outputs.
|
||||
MS_EXCEPTION_IF_NULL(actor_set->output_actor_);
|
||||
auto &output_tensors = actor_set->output_actor_->outputs();
|
||||
if (output_tensors.size() > 0) {
|
||||
size_t output_position = 0;
|
||||
ConstructOutputs(root_graph_->output(), output_tensors, &output_position, outputs);
|
||||
// Fetch outputs.
|
||||
MS_EXCEPTION_IF_NULL(actor_set->output_actor_);
|
||||
auto &output_tensors = actor_set->output_actor_->outputs();
|
||||
if (output_tensors.size() > 0) {
|
||||
size_t output_position = 0;
|
||||
ConstructOutputs(root_graph_->output(), output_tensors, &output_position, outputs);
|
||||
}
|
||||
}
|
||||
|
||||
runtime::GraphScheduler::GetInstance().ClearActorData(actor_set);
|
||||
// Close abstract_lock for dynamic_shape
|
||||
AnfUtils::CloseAbstractLock();
|
||||
|
|
|
@ -30,6 +30,7 @@ namespace cluster {
|
|||
ClusterContext::ClusterContext()
|
||||
: inited_(false),
|
||||
finalized_(true),
|
||||
cluster_ready_(false),
|
||||
node_num_each_role_({}),
|
||||
scheduler_host_(kLocalHost),
|
||||
scheduler_port_(kDefaultSchedPort),
|
||||
|
@ -151,11 +152,11 @@ void ClusterContext::InitClusterConfig() {
|
|||
bool ClusterContext::BuildCluster() {
|
||||
// Create node according to different role.
|
||||
if (node_role_ == kEnvRoleOfWorker) {
|
||||
node_ = std::make_shared<ps::core::WorkerNode>();
|
||||
node_ = std::make_shared<ps::core::PSWorkerNode>();
|
||||
} else if (node_role_ == kEnvRoleOfServer) {
|
||||
node_ = std::make_shared<ps::core::ServerNode>();
|
||||
node_ = std::make_shared<ps::core::PSServerNode>();
|
||||
} else if (node_role_ == kEnvRoleOfScheduler) {
|
||||
node_ = std::make_shared<ps::core::SchedulerNode>();
|
||||
node_ = std::make_shared<ps::core::PSSchedulerNode>();
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "The role " << node_role_ << " is invalid.";
|
||||
return false;
|
||||
|
@ -258,8 +259,20 @@ void ClusterContext::RegisterEventCallback() {
|
|||
MsException::Instance().SetException();
|
||||
}
|
||||
});
|
||||
|
||||
abstract_node->RegisterEventCallback(ps::core::ClusterEvent::ON_SEND_META_DATA,
|
||||
[this]() { cluster_ready_ = true; });
|
||||
}
|
||||
}
|
||||
|
||||
void ClusterContext::WaitForClusterReady() {
|
||||
while (!cluster_ready_) {
|
||||
const int kWaitDuration = 200;
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(kWaitDuration));
|
||||
}
|
||||
|
||||
cluster_ready_ = false;
|
||||
}
|
||||
} // namespace cluster
|
||||
} // namespace distributed
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -33,6 +33,9 @@
|
|||
#include "ps/core/worker_node.h"
|
||||
#include "ps/core/server_node.h"
|
||||
#include "ps/core/scheduler_node.h"
|
||||
#include "ps/core/ps_worker_node.h"
|
||||
#include "ps/core/ps_server_node.h"
|
||||
#include "ps/core/ps_scheduler_node.h"
|
||||
#include "distributed/cluster/actor_route_table_proxy.h"
|
||||
|
||||
namespace mindspore {
|
||||
|
@ -77,6 +80,10 @@ class ClusterContext {
|
|||
// Return actor route proxy for AbstractNode.
|
||||
const ActorRouteTableProxyPtr &actor_route_table_proxy() const;
|
||||
|
||||
// Wait cluster networking or re-networking successly, using in disaster recovery to prevent collective communication
|
||||
// ops flapping.
|
||||
void WaitForClusterReady();
|
||||
|
||||
private:
|
||||
ClusterContext();
|
||||
|
||||
|
@ -101,6 +108,9 @@ class ClusterContext {
|
|||
// The flag that whether this cluster context instance is already finalized.
|
||||
std::atomic_bool finalized_;
|
||||
|
||||
// The cluster networking or re-networking successly.
|
||||
std::atomic_bool cluster_ready_;
|
||||
|
||||
// The mutex about exiting status of this node.
|
||||
std::mutex finish_mutex_;
|
||||
|
||||
|
|
|
@ -19,6 +19,7 @@
|
|||
#include <vector>
|
||||
#include <memory>
|
||||
#include "utils/ms_context.h"
|
||||
#include "runtime/recovery/recovery_context.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace distributed {
|
||||
|
@ -60,7 +61,7 @@ std::shared_ptr<CollectiveManager> CollectiveManager::instance() {
|
|||
}
|
||||
|
||||
bool CollectiveManager::Initialize() {
|
||||
if (inited_) {
|
||||
if (inited_ && !runtime::recovery::RecoveryContext::GetInstance()->need_reinit_collective()) {
|
||||
return true;
|
||||
}
|
||||
|
||||
|
@ -127,15 +128,51 @@ bool CollectiveManager::CreateCommunicationGroup(const std::string &group_name,
|
|||
if (!host_comm_lib_instance_->Broadcast(root_info, root_info, root_info_size, TypeId::kNumberTypeInt8, 0,
|
||||
group_name)) {
|
||||
MS_LOG(ERROR) << "Broadcast for device root info failed on the host side.";
|
||||
if (runtime::recovery::RecoveryContext::GetInstance()->enable_recovery()) {
|
||||
runtime::recovery::RecoveryContext::GetInstance()->set_recovery_status(
|
||||
runtime::recovery::RecoveryErrCode::kBroadcastUniqueIDFailed);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// Step 5: Initialize communication group on the device side.
|
||||
if (!group->Initialize(root_info)) {
|
||||
return InitDeviceCommGroup(group, root_info);
|
||||
}
|
||||
|
||||
bool CollectiveManager::InitDeviceCommGroup(const CommunicationGroupPtr &group, void *root_info) {
|
||||
bool init_group_success = false;
|
||||
bool init_group_fail = false;
|
||||
std::condition_variable thread_blocker;
|
||||
init_group_thread_ = std::make_unique<std::thread>([&] {
|
||||
device_ctx_->Initialize();
|
||||
if (!group->Initialize(root_info)) {
|
||||
MS_LOG(ERROR) << "Initialize group on the device side failed.";
|
||||
std::unique_lock<std::mutex> lock(init_group_mutex_);
|
||||
init_group_fail = true;
|
||||
thread_blocker.notify_one();
|
||||
return;
|
||||
}
|
||||
|
||||
{
|
||||
std::unique_lock<std::mutex> lock(init_group_mutex_);
|
||||
init_group_success = true;
|
||||
thread_blocker.notify_one();
|
||||
}
|
||||
});
|
||||
init_group_thread_->detach();
|
||||
|
||||
// Timeout limit 180 seconds to wait finishing init device communication group.
|
||||
const int64_t kTimeToWait = 180;
|
||||
std::unique_lock<std::mutex> locker(init_group_mutex_);
|
||||
(void)thread_blocker.wait_for(locker, std::chrono::seconds(kTimeToWait),
|
||||
[&] { return init_group_success || init_group_fail; });
|
||||
|
||||
if (!init_group_success && runtime::recovery::RecoveryContext::GetInstance()->enable_recovery()) {
|
||||
runtime::recovery::RecoveryContext::GetInstance()->set_recovery_status(
|
||||
runtime::recovery::RecoveryErrCode::kInitNcclFailed);
|
||||
MS_LOG(ERROR) << "Initialize group on the device side failed.";
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
return init_group_success;
|
||||
}
|
||||
|
||||
bool CollectiveManager::DestroyCommunicationGroup(const std::string &group_name) {
|
||||
|
@ -208,6 +245,10 @@ bool CollectiveManager::InitHostCommlib() {
|
|||
return false;
|
||||
}
|
||||
|
||||
if (!global_group_ranks_.empty()) {
|
||||
global_group_ranks_.clear();
|
||||
}
|
||||
|
||||
// Reassign 'global_rank_id_' and 'global_rank_size_'. Generate global communication group ranks.
|
||||
global_rank_id_ = host_comm_lib_instance_->global_rank_id();
|
||||
global_rank_size_ = host_comm_lib_instance_->global_rank_size();
|
||||
|
@ -276,11 +317,18 @@ bool CollectiveManager::AssignLocalRank() {
|
|||
// AllGather host names across the global communication group.
|
||||
if (!host_comm_lib_instance_->AllGather(&host_hash, all_host_hashs, 1, TypeId::kNumberTypeUInt64,
|
||||
host_global_group_name_)) {
|
||||
if (runtime::recovery::RecoveryContext::GetInstance()->enable_recovery()) {
|
||||
runtime::recovery::RecoveryContext::GetInstance()->set_recovery_status(
|
||||
runtime::recovery::RecoveryErrCode::kAllGatherHostNameFailed);
|
||||
}
|
||||
MS_LOG(ERROR) << "AllGather for host names failed.";
|
||||
return false;
|
||||
}
|
||||
|
||||
// Accumulate rank id.
|
||||
// In disaster recovery scenario, this function will enter multiple times when the network is reconfigured, so old
|
||||
// local rank id need to be cleaned.
|
||||
local_rank_id_ = 0;
|
||||
for (uint32_t rank = 0; rank < global_rank_size_; rank++) {
|
||||
if (rank == global_rank_id_) {
|
||||
break;
|
||||
|
|
|
@ -80,6 +80,13 @@ class CollectiveManager {
|
|||
// Assign the local rank id for this process.
|
||||
bool AssignLocalRank();
|
||||
|
||||
// Initialize communication group on the device side.
|
||||
bool InitDeviceCommGroup(const CommunicationGroupPtr &group, void *root_info);
|
||||
|
||||
// Initialize communication group on the device side in thread with timeout limit.
|
||||
std::unique_ptr<std::thread> init_group_thread_;
|
||||
std::mutex init_group_mutex_;
|
||||
|
||||
std::atomic_bool inited_;
|
||||
std::atomic_bool finalized_;
|
||||
|
||||
|
|
|
@ -17,9 +17,12 @@
|
|||
#include "distributed/init.h"
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include "runtime/recovery/recovery_context.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace distributed {
|
||||
using runtime::recovery::RecoveryContext;
|
||||
|
||||
bool Initialize() {
|
||||
if (!InitializeCluster()) {
|
||||
MS_LOG(ERROR) << "Failed to initialize cluster.";
|
||||
|
@ -39,10 +42,20 @@ bool Initialize() {
|
|||
collective::CollectiveManager::instance()->set_global_rank_id(abstract_node->rank_id());
|
||||
collective::CollectiveManager::instance()->set_global_rank_size(abstract_node->worker_num());
|
||||
|
||||
if (RecoveryContext::GetInstance()->enable_recovery()) {
|
||||
cluster::ClusterContext::instance()->WaitForClusterReady();
|
||||
}
|
||||
|
||||
if (!InitializeCollective()) {
|
||||
MS_LOG(ERROR) << "Failed to initialize collective communication.";
|
||||
return false;
|
||||
}
|
||||
|
||||
if (RecoveryContext::GetInstance()->enable_recovery()) {
|
||||
RecoveryContext::GetInstance()->set_global_rank_id(abstract_node->rank_id());
|
||||
RecoveryContext::GetInstance()->set_global_rank_size(abstract_node->worker_num());
|
||||
RecoveryContext::GetInstance()->ObtainGlobalLatestCkptInfo();
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
|
|
@ -155,9 +155,13 @@ void FileIOUtils::CreateDir(const std::string &dir_path, mode_t mode) {
|
|||
}
|
||||
|
||||
void FileIOUtils::CreateDirRecursive(const std::string &dir_path, mode_t mode) {
|
||||
if (dir_path.empty()) {
|
||||
MS_LOG(EXCEPTION) << "The directory path need to be create is empty";
|
||||
}
|
||||
size_t dir_path_len = dir_path.length();
|
||||
if (dir_path_len > PATH_MAX) {
|
||||
MS_LOG(EXCEPTION) << "Directory path is too long: " << dir_path;
|
||||
MS_LOG(EXCEPTION) << "Directory path is too long to exceed max length limit: " << PATH_MAX
|
||||
<< ", the path: " << dir_path;
|
||||
}
|
||||
|
||||
char tmp_dir_path[PATH_MAX] = {0};
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
#include "fl/server/collective_ops_impl.h"
|
||||
#include "fl/server/local_meta_store.h"
|
||||
#include "fl/server/iteration.h"
|
||||
#include "utils/ms_context.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace fl {
|
||||
|
@ -323,6 +324,12 @@ bool CollectiveOpsImpl::RingAllGather(const void *sendbuff, void *const recvbuff
|
|||
return false;
|
||||
}
|
||||
|
||||
auto context_ptr = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context_ptr);
|
||||
// If enable recovery, set timeout 300s to prevent networking flapping.
|
||||
uint32_t collective_comm_timeout =
|
||||
context_ptr->get_param<bool>(MS_CTX_ENABLE_RECOVERY) ? kCollectiveCommMaxTimeout : kCollectiveCommTimeout;
|
||||
|
||||
// Ring AllGather.
|
||||
for (size_t i = 0; i < rank_size_ - 1; i++) {
|
||||
size_t send_chunk_index = (rank_id_ - i + rank_size_) % rank_size_;
|
||||
|
@ -337,7 +344,7 @@ bool CollectiveOpsImpl::RingAllGather(const void *sendbuff, void *const recvbuff
|
|||
|
||||
std::shared_ptr<std::vector<unsigned char>> recv_str;
|
||||
auto recv_req_id = node_->CollectiveReceiveAsync(node_role_, recv_from_rank, &recv_str);
|
||||
if (!node_->CollectiveWait(recv_req_id, kCollectiveCommTimeout)) {
|
||||
if (!node_->CollectiveWait(recv_req_id, collective_comm_timeout)) {
|
||||
MS_LOG(ERROR) << "CollectiveWait " << recv_req_id << " failed.";
|
||||
return false;
|
||||
}
|
||||
|
@ -348,7 +355,7 @@ bool CollectiveOpsImpl::RingAllGather(const void *sendbuff, void *const recvbuff
|
|||
<< recv_str->size();
|
||||
return false;
|
||||
}
|
||||
if (!node_->Wait(send_req_id, kCollectiveCommTimeout)) {
|
||||
if (!node_->Wait(send_req_id, collective_comm_timeout)) {
|
||||
MS_LOG(ERROR) << "CollectiveWait " << send_req_id << " failed.";
|
||||
return false;
|
||||
}
|
||||
|
|
|
@ -32,6 +32,8 @@ namespace fl {
|
|||
namespace server {
|
||||
// The timeout for server collective communication in case of network jitter.
|
||||
constexpr uint32_t kCollectiveCommTimeout = 30;
|
||||
// The max timeout for server collective communication, used in disaster recovery to prevent networking flapping.
|
||||
constexpr uint32_t kCollectiveCommMaxTimeout = 300;
|
||||
|
||||
// The collective communication groups which are composed of multiple processes. Refer to MPI_Group.
|
||||
struct CommunicationGroupInfo {
|
||||
|
|
|
@ -39,6 +39,14 @@ namespace gpu {
|
|||
} \
|
||||
}
|
||||
|
||||
#define CHECK_OP_RET_WITH_EXCEPT_TRANCE(node, expression, message) \
|
||||
{ \
|
||||
bool success = (expression); \
|
||||
if (!success) { \
|
||||
MS_LOG(EXCEPTION) << "Op Error: " << message << " | " << trace::DumpSourceLines(node.lock()); \
|
||||
} \
|
||||
}
|
||||
|
||||
#define CHECK_OP_RET_WITH_ERROR(expression, message) \
|
||||
{ \
|
||||
bool success = (expression); \
|
||||
|
|
|
@ -48,7 +48,6 @@ bool NvidiaCommunicationGroup::Finalize() {
|
|||
// because 'ncclCommAbort' will abort any uncompleted operations before destroying the communicator, e.g.,
|
||||
// ncclAllReduce.
|
||||
CHECK_RET(ncclCommAbort(comm_), ncclSuccess, "Failed to abort NCCL communicator.");
|
||||
CHECK_RET(ncclCommDestroy(comm_), ncclSuccess, "Failed to destroy NCCL communicator.");
|
||||
initialized_ = false;
|
||||
return true;
|
||||
}
|
||||
|
|
|
@ -724,9 +724,8 @@ void AbstractNode::ProcessHeartbeatResp(const std::shared_ptr<MessageMeta> &meta
|
|||
}
|
||||
|
||||
bool is_worker = heartbeat_resp_message.is_worker();
|
||||
bool is_fl_mode = PSContext::instance()->server_mode() == ps::kServerModeFL ||
|
||||
PSContext::instance()->server_mode() == ps::kServerModeHybrid;
|
||||
bool not_enable_recover_node_timeout = (is_worker && !is_fl_mode);
|
||||
bool is_ps_mode = PSContext::instance()->server_mode() == ps::kServerModePS;
|
||||
bool not_enable_recover_node_timeout = (is_worker && is_ps_mode);
|
||||
|
||||
if (current_cluster_state_ == ClusterState::NODE_TIMEOUT) {
|
||||
if (node_recovery_ == nullptr || not_enable_recover_node_timeout) {
|
||||
|
@ -853,6 +852,8 @@ void AbstractNode::ProcessSendMetadata(const std::shared_ptr<TcpConnection> &con
|
|||
|
||||
std::lock_guard<std::mutex> lock(client_mutex_);
|
||||
connected_nodes_.clear();
|
||||
|
||||
OnEventCallback(ClusterEvent::ON_SEND_META_DATA);
|
||||
}
|
||||
|
||||
void AbstractNode::ProcessFinish(const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta,
|
||||
|
|
|
@ -29,6 +29,11 @@ void AbstractPSNode::StartHeartbeatTimer() {
|
|||
if (!DoHeartbeat()) {
|
||||
MS_LOG(WARNING)
|
||||
<< "Heartbeat timeout, the tcp connection to scheduler is lost, please check the status of scheduler.";
|
||||
|
||||
if (CheckSchedulerTimeout()) {
|
||||
MS_LOG(WARNING) << "Scheduler is Timeout, please recovery.";
|
||||
}
|
||||
|
||||
HandleHeartbeatTimeout();
|
||||
} else {
|
||||
UpdateSchedulerTime();
|
||||
|
@ -115,8 +120,11 @@ bool AbstractPSNode::InitClientToScheduler() {
|
|||
});
|
||||
client_to_scheduler_thread_->detach();
|
||||
|
||||
// Timeout for waiting for the tcp connection to the scheduler.
|
||||
uint32_t timeout = 10;
|
||||
// Timeout for waiting for the tcp connection to the scheduler, 10 seconds in recovery mode, or 900 seconds for first
|
||||
// build connection to scheduler.
|
||||
const uint32_t timeout_for_reinit_in_recovery = 10;
|
||||
uint32_t timeout = heartbeat_stopped_ ? timeout_for_reinit_in_recovery
|
||||
: PSContext::instance()->cluster_config().cluster_available_timeout;
|
||||
bool wait_res = client_to_scheduler_->WaitConnected(timeout);
|
||||
if (!wait_res) {
|
||||
is_ready_ = true;
|
||||
|
|
|
@ -35,6 +35,7 @@ enum class ClusterEvent {
|
|||
CLUSTER_SCALE_IN_DONE = 6,
|
||||
ON_PREPARE_PERSIST = 7,
|
||||
ON_BEGIN_PERSIST = 8,
|
||||
ON_SEND_META_DATA = 9,
|
||||
};
|
||||
|
||||
struct NodeInfo {
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
*/
|
||||
|
||||
#include "ps/core/node_manager.h"
|
||||
#include "utils/ms_context.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
|
@ -45,23 +46,29 @@ uint32_t NodeManager::checkIfRankIdExist(const RegisterMessage ®ister_message
|
|||
return rank_id;
|
||||
}
|
||||
// This is for scheduler recovery
|
||||
ReAddNodeIfNotExists(node_id, register_message.ip(), register_message.port());
|
||||
(void)ReAddNodeIfNotExists(node_id, register_message.ip(), register_message.port(), &rank_id);
|
||||
return rank_id;
|
||||
}
|
||||
|
||||
void NodeManager::ReAddNodeIfNotExists(const std::string &node_id, const std::string &ip, uint32_t port) {
|
||||
bool NodeManager::ReAddNodeIfNotExists(const std::string &node_id, const std::string &ip, uint32_t port,
|
||||
uint32_t *rank_id) {
|
||||
core::ClusterConfig &clusterConfig = PSContext::instance()->cluster_config();
|
||||
std::unordered_map<std::string, NodeInfo> recovery_node_infos = clusterConfig.initial_registered_nodes_infos;
|
||||
|
||||
if (registered_nodes_info_.find(node_id) == registered_nodes_info_.end() &&
|
||||
recovery_node_infos.find(node_id) != recovery_node_infos.end()) {
|
||||
if (rank_id != nullptr) {
|
||||
*rank_id = recovery_node_infos[node_id].rank_id_;
|
||||
}
|
||||
recovery_node_infos[node_id].is_alive = true;
|
||||
recovery_node_infos[node_id].ip_ = ip;
|
||||
recovery_node_infos[node_id].port_ = static_cast<uint16_t>(port);
|
||||
registered_nodes_info_[node_id] = recovery_node_infos[node_id];
|
||||
MS_LOG(INFO) << "The node id: " << node_id << " is recovery successful!"
|
||||
<< ", ip: " << ip << ", port: " << port;
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
uint32_t NodeManager::NextRankId(const RegisterMessage ®ister_message, const std::shared_ptr<MessageMeta> &meta) {
|
||||
|
@ -219,9 +226,14 @@ void NodeManager::UpdateCluster() {
|
|||
|
||||
if (!timeout_nodes_info_.empty()) {
|
||||
UpdateClusterState(ClusterState::NODE_TIMEOUT);
|
||||
for (auto iter = timeout_nodes_info_.begin(); iter != timeout_nodes_info_.end(); ++iter) {
|
||||
(void)heartbeats_.erase(iter->first);
|
||||
finish_nodes_id_.insert(iter->first);
|
||||
|
||||
auto context_ptr = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context_ptr);
|
||||
if (!context_ptr->get_param<bool>(MS_CTX_ENABLE_RECOVERY)) {
|
||||
for (auto iter = timeout_nodes_info_.begin(); iter != timeout_nodes_info_.end(); ++iter) {
|
||||
(void)heartbeats_.erase(iter->first);
|
||||
finish_nodes_id_.insert(iter->first);
|
||||
}
|
||||
}
|
||||
if (onPersist_) {
|
||||
onPersist_();
|
||||
|
|
|
@ -59,7 +59,8 @@ class NodeManager {
|
|||
uint32_t checkIfRankIdExist(const RegisterMessage ®ister_message);
|
||||
|
||||
// Re-Add the server or worker node into the registered node list if the node do not existed in the scheduler.
|
||||
void ReAddNodeIfNotExists(const std::string &node_id, const std::string &ip, uint32_t port);
|
||||
bool ReAddNodeIfNotExists(const std::string &node_id, const std::string &ip, uint32_t port,
|
||||
uint32_t *rank_id = nullptr);
|
||||
|
||||
void UpdateHeartbeat(const std::string &node_id);
|
||||
std::vector<ServersMeta> FetchServersMeta();
|
||||
|
|
|
@ -172,14 +172,29 @@ void SchedulerNode::ProcessHeartbeat(const std::shared_ptr<TcpServer> &server,
|
|||
MS_LOG(WARNING) << "Send heart beat failed.";
|
||||
}
|
||||
|
||||
if (node_manager_.IsAllNodesRegistered()) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Re-Add the missing node into node manager.
|
||||
if (heartbeat_message.has_address()) {
|
||||
node_manager_.ReAddNodeIfNotExists(node_id, heartbeat_message.ip(), heartbeat_message.port());
|
||||
if (heartbeat_message.has_address() &&
|
||||
node_manager_.ReAddNodeIfNotExists(node_id, heartbeat_message.ip(), heartbeat_message.port())) {
|
||||
SetRegisterConnectionFd(conn, node_id);
|
||||
|
||||
if (node_manager_.IsAllNodesRegistered()) {
|
||||
is_ready_ = true;
|
||||
MS_LOG(INFO) << "There are " << node_manager_.worker_num() << " workers and " << node_manager_.server_num()
|
||||
<< " servers registered to scheduer, so the scheduler send meta data to worker/server.";
|
||||
node_manager_.UpdateNodesInfo();
|
||||
|
||||
auto node_infos = node_manager_.nodes_info();
|
||||
for (const auto &kvs : node_infos) {
|
||||
auto client = GetOrCreateClient(kvs.second);
|
||||
MS_EXCEPTION_IF_NULL(client);
|
||||
SendMetadata(client, kvs.second.rank_id_);
|
||||
node_manager_.UpdateHeartbeat(kvs.first);
|
||||
}
|
||||
|
||||
node_manager_.UpdateClusterState(ClusterState::CLUSTER_READY);
|
||||
PersistMetaData();
|
||||
wait_start_cond_.notify_all();
|
||||
|
|
|
@ -2,7 +2,7 @@ file(GLOB_RECURSE DEVICE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "common/*
|
|||
"kernel_info.cc" "executor/dynamic_kernel.cc" "executor/executor_callback.cc" "kernel_runtime.cc"
|
||||
"memory_manager.cc" "kernel_runtime_manager.cc" "convert_tensor_utils.cc" "memory_scheduler.cc"
|
||||
"memory_offload_strategy.cc" "bucket.cc" "launch_kernel.cc" "launch_mul.cc" "tensor_array.cc"
|
||||
"ms_device_shape_transfer.cc" "context_extends.cc"
|
||||
"ms_device_shape_transfer.cc" "context_extends.cc" "stream_synchronizer.cc"
|
||||
)
|
||||
|
||||
if("${ENABLE_HIDDEN}" STREQUAL "OFF")
|
||||
|
|
|
@ -0,0 +1,102 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "runtime/device/stream_synchronizer.h"
|
||||
#include "utils/ms_context.h"
|
||||
#include "distributed/collective/collective_manager.h"
|
||||
#include "runtime/recovery/recovery_context.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace device {
|
||||
std::mutex StreamSynchronizer::instance_lock_;
|
||||
std::shared_ptr<StreamSynchronizer> StreamSynchronizer::instance_ = nullptr;
|
||||
|
||||
StreamSynchronizer::~StreamSynchronizer() {
|
||||
{
|
||||
std::unique_lock<std::mutex> lock(task_mutex_);
|
||||
stop_ = true;
|
||||
}
|
||||
do_sync_stream_cv_.notify_all();
|
||||
worker_thread_.join();
|
||||
}
|
||||
|
||||
bool StreamSynchronizer::SyncStream(const std::string &device_name, uint32_t timeout) {
|
||||
std::unique_lock<std::mutex> reentrant_lock(reentrant_mutex_);
|
||||
auto ms_context = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(ms_context);
|
||||
uint32_t device_id = ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID);
|
||||
|
||||
const auto &device_context =
|
||||
device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext({device_name, device_id});
|
||||
MS_EXCEPTION_IF_NULL(device_context);
|
||||
|
||||
// If disable recovery or timeout==0, sync stream directly to improve performance.
|
||||
if (!runtime::recovery::RecoveryContext::GetInstance()->enable_recovery() || timeout == 0) {
|
||||
device_context->Initialize();
|
||||
return device_context->SyncStream();
|
||||
}
|
||||
|
||||
std::unique_lock<std::mutex> lock(task_mutex_);
|
||||
if (stop_) {
|
||||
MS_LOG(EXCEPTION) << "The synchronization stream task has stopped";
|
||||
}
|
||||
device_context_ = device_context;
|
||||
do_sync_stream_cv_.notify_one();
|
||||
|
||||
if (sync_stream_time_out_) {
|
||||
// If sync stream timeout has happened, increase the timeout by 4 times.
|
||||
const uint32_t kTimeOutScaleFactor = 4;
|
||||
timeout *= kTimeOutScaleFactor;
|
||||
}
|
||||
|
||||
if (time_out_cv_.wait_for(lock, std::chrono::seconds(timeout)) == std::cv_status::no_timeout) {
|
||||
if (!sync_stream_ret_) {
|
||||
MS_LOG(ERROR) << "Synchronize stream failed.";
|
||||
}
|
||||
return sync_stream_ret_;
|
||||
} else {
|
||||
sync_stream_time_out_ = true;
|
||||
runtime::recovery::RecoveryContext::GetInstance()->set_need_reinit_collective(true);
|
||||
distributed::collective::CollectiveManager::instance()->Finalize();
|
||||
time_out_cv_.wait(lock, [this]() { return device_context_ == nullptr; });
|
||||
MS_LOG(WARNING) << "Synchronize stream time out.";
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
void StreamSynchronizer::DoSyncStreamTask() {
|
||||
for (;;) {
|
||||
{
|
||||
std::unique_lock<std::mutex> lock(task_mutex_);
|
||||
do_sync_stream_cv_.wait(lock, [this]() { return stop_ || device_context_ != nullptr; });
|
||||
if (stop_) {
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
device_context_->Initialize();
|
||||
// Really sync stream.
|
||||
sync_stream_ret_ = device_context_->SyncStream();
|
||||
|
||||
{
|
||||
std::unique_lock<std::mutex> lock(task_mutex_);
|
||||
device_context_ = nullptr;
|
||||
}
|
||||
time_out_cv_.notify_one();
|
||||
}
|
||||
}
|
||||
} // namespace device
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,93 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_SYNC_STREAM_H_
|
||||
#define MINDSPORE_CCSRC_RUNTIME_DEVICE_SYNC_STREAM_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <thread>
|
||||
#include <mutex>
|
||||
#include <condition_variable>
|
||||
|
||||
#include "runtime/hardware/device_context.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace device {
|
||||
constexpr uint32_t kTimeoutInSeconds = 30;
|
||||
|
||||
// Execute synchronization stream with timeout mechanism. Typical application scenarios: it is used to monitor
|
||||
// distributed data parallel training scenarios and whether a process exits unexpectedly.
|
||||
class StreamSynchronizer {
|
||||
public:
|
||||
static std::shared_ptr<StreamSynchronizer> &GetInstance() {
|
||||
std::lock_guard<std::mutex> lock(instance_lock_);
|
||||
if (instance_ == nullptr) {
|
||||
instance_.reset(new (std::nothrow) StreamSynchronizer());
|
||||
}
|
||||
return instance_;
|
||||
}
|
||||
|
||||
~StreamSynchronizer();
|
||||
|
||||
// Execute synchronization stream with timeout mechanism.
|
||||
bool SyncStream(const std::string &device_name, uint32_t timeout = kTimeoutInSeconds);
|
||||
|
||||
private:
|
||||
// Create a thread to actually execute the synchronization stream task.
|
||||
StreamSynchronizer() { worker_thread_ = std::thread(&StreamSynchronizer::DoSyncStreamTask, this); }
|
||||
|
||||
DISABLE_COPY_AND_ASSIGN(StreamSynchronizer);
|
||||
|
||||
// Monitor whether there are synchronization stream tasks, and actually execute the synchronization stream
|
||||
// tasks.
|
||||
void DoSyncStreamTask();
|
||||
|
||||
// Used for multi-thread safety of singleton creation.
|
||||
static std::mutex instance_lock_;
|
||||
|
||||
// The singleton pointer.
|
||||
static std::shared_ptr<StreamSynchronizer> instance_;
|
||||
|
||||
// Record whether the synchronization stream task has timed out.
|
||||
bool sync_stream_time_out_{false};
|
||||
|
||||
// Return value of synchronization stream.
|
||||
bool sync_stream_ret_{false};
|
||||
|
||||
// Whether synchronization stream thread need to stop.
|
||||
bool stop_{false};
|
||||
|
||||
DeviceContext *device_context_{nullptr};
|
||||
|
||||
// The method SyncStream is not multiple threads safe, so use this lock to prevent simultaneous access by
|
||||
// multiple threads.
|
||||
std::mutex reentrant_mutex_;
|
||||
|
||||
// Use this lock to ensure the safety of external calls to SyncStream and the execution of DoSyncStreamTask
|
||||
// in worker_thread_;
|
||||
std::mutex task_mutex_;
|
||||
|
||||
// The thread to actually execute the synchronization stream task.
|
||||
std::thread worker_thread_;
|
||||
|
||||
std::condition_variable time_out_cv_;
|
||||
std::condition_variable do_sync_stream_cv_;
|
||||
};
|
||||
} // namespace device
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_SYNC_STREAM_H_
|
|
@ -25,9 +25,11 @@
|
|||
#include "mindrt/include/async/async.h"
|
||||
#include "utils/log_adapter.h"
|
||||
#include "include/common/utils/convert_utils.h"
|
||||
#include "runtime/recovery/recovery_context.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace runtime {
|
||||
using recovery::RecoveryContext;
|
||||
namespace {
|
||||
void SyncTensorData(const TensorPtr &host_tensor, const DeviceTensorPtr &device_tensor, const AnfNodePtr &node,
|
||||
const DeviceContext *device_context, OpContext<DeviceTensor> *const context,
|
||||
|
@ -354,6 +356,10 @@ void DataPrepareActor::PrepareDataForDeviceTensorStore(const std::vector<std::ve
|
|||
}
|
||||
}
|
||||
}
|
||||
if (RecoveryContext::GetInstance()->enable_recovery() &&
|
||||
RecoveryContext::GetInstance()->need_sync_weight_to_device()) {
|
||||
RecoveryContext::GetInstance()->set_need_sync_weight_to_device(false);
|
||||
}
|
||||
|
||||
PrepareDeviceTensorStoreForControlNode(graph_compiler_info_->control_node_parser_, input_tensors.back(), context);
|
||||
}
|
||||
|
@ -722,6 +728,11 @@ void DataPrepareActor::PrepareDataForWeightNode(const AnfNodePtr &backend_node,
|
|||
host_tensor_address->SetNodeIndex(backend_node, 0);
|
||||
DeviceTensorStore::GetInstance().Insert(front_node.get(), host_tensor_address);
|
||||
|
||||
if (RecoveryContext::GetInstance()->enable_recovery() &&
|
||||
RecoveryContext::GetInstance()->need_sync_weight_to_device()) {
|
||||
is_need_sync = true;
|
||||
}
|
||||
|
||||
// If the ptr of device tensor is not nullptr, it indicates that the device data has been prepared.
|
||||
MS_EXCEPTION_IF_NULL(host_tensor_address);
|
||||
if (is_need_sync || (host_tensor_address->GetPtr() == nullptr)) {
|
||||
|
|
|
@ -21,9 +21,12 @@
|
|||
#include "runtime/graph_scheduler/actor/debug_actor.h"
|
||||
#include "mindrt/include/async/async.h"
|
||||
#include "utils/log_adapter.h"
|
||||
#include "runtime/recovery/recovery_context.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace runtime {
|
||||
using recovery::RecoveryContext;
|
||||
|
||||
void KernelActor::Init() {
|
||||
// Check device contexts number.
|
||||
if (device_contexts_.size() != device::kDeviceContextsNumOne) {
|
||||
|
@ -240,11 +243,18 @@ void KernelActor::OnMemoryAllocFinish(OpContext<DeviceTensor> *const context) {
|
|||
PreLaunchKernel(context);
|
||||
|
||||
try {
|
||||
auto ret = device_contexts_[0]->LaunchKernel(kernel_, launch_info_.inputs_, launch_info_.workspaces_,
|
||||
launch_info_.outputs_, is_dynamic_shape_);
|
||||
if (!ret) {
|
||||
std::string error_info = "Launch kernel failed: " + kernel_->fullname_with_scope();
|
||||
SET_OPCONTEXT_FAIL_RET_WITH_ERROR_BY_STRATEGY(strategy_, (*context), error_info);
|
||||
if (RecoveryContext::GetInstance()->enable_recovery() && RecoveryContext::GetInstance()->need_reinit_collective()) {
|
||||
// In disaster recovery scenarios, run dag in this step failed, the rest operators of graph do not need launch,
|
||||
// especially the collective communication operators.
|
||||
MS_LOG(WARNING) << "Collective communication need reinitialize, skip launch kernel: "
|
||||
<< kernel_->fullname_with_scope();
|
||||
} else {
|
||||
auto ret = device_contexts_[0]->LaunchKernel(kernel_, launch_info_.inputs_, launch_info_.workspaces_,
|
||||
launch_info_.outputs_, is_dynamic_shape_);
|
||||
if (!ret) {
|
||||
std::string error_info = "Launch kernel failed: " + kernel_->fullname_with_scope();
|
||||
SET_OPCONTEXT_FAIL_RET_WITH_ERROR_BY_STRATEGY(strategy_, (*context), error_info);
|
||||
}
|
||||
}
|
||||
} catch (const std::exception &e) {
|
||||
if (strategy_ == GraphExecutionStrategy::kPipeline) {
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
*/
|
||||
|
||||
#include "runtime/graph_scheduler/actor/loop_count_actor.h"
|
||||
#include <set>
|
||||
#include "runtime/graph_scheduler/actor/data_prepare_actor.h"
|
||||
#include "runtime/graph_scheduler/actor/output_actor.h"
|
||||
#include "runtime/graph_scheduler/actor/memory_manager_actor.h"
|
||||
|
@ -23,9 +24,13 @@
|
|||
#include "runtime/graph_scheduler/actor/control_flow/entrance_actor.h"
|
||||
#include "mindrt/include/async/async.h"
|
||||
#include "utils/log_adapter.h"
|
||||
#include "runtime/device/stream_synchronizer.h"
|
||||
#include "runtime/recovery/recovery_context.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace runtime {
|
||||
using recovery::RecoveryContext;
|
||||
|
||||
void LoopCountActor::Run(OpContext<DeviceTensor> *const context) {
|
||||
MS_EXCEPTION_IF_NULL(context);
|
||||
// Need wait MemoryManagerActor running finished to avoid the illegal memory timing problem before
|
||||
|
@ -52,6 +57,26 @@ void LoopCountActor::IncreaseLoopCount(OpContext<DeviceTensor> *const context) {
|
|||
return;
|
||||
}
|
||||
|
||||
// Sync device stream.
|
||||
if (strategy_ == GraphExecutionStrategy::kPipeline) {
|
||||
std::set<const DeviceContext *> sync_stream_device_contexts;
|
||||
for (auto &device_context : device_contexts_) {
|
||||
MS_EXCEPTION_IF_NULL(device_context);
|
||||
if ((sync_stream_device_contexts.count(device_context) == 0) &&
|
||||
(!device::StreamSynchronizer::GetInstance()->SyncStream(device_context->device_context_key().device_name_))) {
|
||||
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context),
|
||||
("Sync stream failed:" + device_context->device_context_key().ToString()));
|
||||
}
|
||||
(void)sync_stream_device_contexts.insert(device_context);
|
||||
|
||||
// Trigger disaster recovery and exit loop early.
|
||||
if (RecoveryContext::GetInstance()->enable_recovery() &&
|
||||
RecoveryContext::GetInstance()->need_reinit_collective()) {
|
||||
current_count_ = loop_count_;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
PostRun(context);
|
||||
}
|
||||
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
#ifndef MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_LOOP_COUNT_ACTOR_H_
|
||||
#define MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_LOOP_COUNT_ACTOR_H_
|
||||
|
||||
#include <algorithm>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
|
@ -35,11 +36,17 @@ namespace runtime {
|
|||
class LoopCountActor : public DebugAwareActor {
|
||||
public:
|
||||
LoopCountActor(const std::string &name, size_t loop_count, const AID &memory_manager_aid, const AID *debug_aid,
|
||||
const AID *recorder_aid)
|
||||
const AID *recorder_aid, GraphExecutionStrategy strategy,
|
||||
const std::vector<DeviceContext *> &device_contexts)
|
||||
: DebugAwareActor(name, KernelTransformType::kLoopCountActor, recorder_aid, memory_manager_aid, debug_aid),
|
||||
loop_count_(loop_count),
|
||||
current_count_(0),
|
||||
total_running_count_(0) {}
|
||||
total_running_count_(0),
|
||||
strategy_(strategy) {
|
||||
(void)std::transform(
|
||||
device_contexts.begin(), device_contexts.end(), std::back_inserter(device_contexts_),
|
||||
[](DeviceContext *device_context) { return static_cast<const DeviceContext *>(device_context); });
|
||||
}
|
||||
|
||||
~LoopCountActor() override = default;
|
||||
|
||||
|
@ -73,6 +80,10 @@ class LoopCountActor : public DebugAwareActor {
|
|||
// The actors which need be handled separately by loop count actor.
|
||||
AID data_prepare_aid_;
|
||||
std::vector<AID> entrance_aids_;
|
||||
|
||||
// The execution strategy for executing actor.
|
||||
// In pipeline mode, sync stream for every step.
|
||||
GraphExecutionStrategy strategy_{GraphExecutionStrategy::kPipeline};
|
||||
};
|
||||
|
||||
using LoopCountActorPtr = std::shared_ptr<LoopCountActor>;
|
||||
|
|
|
@ -17,9 +17,12 @@
|
|||
#include "runtime/graph_scheduler/actor/output_actor.h"
|
||||
#include "runtime/hardware/device_context_manager.h"
|
||||
#include "utils/log_adapter.h"
|
||||
#include "runtime/recovery/recovery_context.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace runtime {
|
||||
using recovery::RecoveryContext;
|
||||
|
||||
bool IsOutputAddressPersisted(const DeviceTensor *output_device_tensor, const AnfNodePtr &output_node) {
|
||||
MS_EXCEPTION_IF_NULL(output_node);
|
||||
MS_EXCEPTION_IF_NULL(output_device_tensor);
|
||||
|
@ -70,10 +73,44 @@ void OutputActor::Init() {
|
|||
running_dependent_msg_num_ = SizeToInt(outputs_num_ - device_tensor_store_keys_.size());
|
||||
}
|
||||
|
||||
void OutputActor::FreeOutputNodeMem() {
|
||||
for (size_t i = 0; i < output_nodes_.size(); ++i) {
|
||||
auto &output_node = output_nodes_[i].first;
|
||||
auto &output_device_tensor = output_device_tensors_[i];
|
||||
if ((output_node == nullptr) || (output_device_tensor == nullptr)) {
|
||||
return;
|
||||
}
|
||||
if (!IsOutputAddressPersisted(output_device_tensor, output_node)) {
|
||||
FreeMemory(output_device_tensor, device_contexts_[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void OutputActor::ClearOutputCache() {
|
||||
output_node_to_tensor_device_address_.clear();
|
||||
outputs_.clear();
|
||||
outputs_.resize(outputs_num_);
|
||||
output_nodes_.clear();
|
||||
output_nodes_.resize(outputs_num_);
|
||||
output_device_tensors_.clear();
|
||||
output_device_tensors_.resize(outputs_num_);
|
||||
|
||||
current_outputs_num_ = 0;
|
||||
current_count_ = 0;
|
||||
}
|
||||
|
||||
void OutputActor::RunOpControl(AID *const, OpContext<DeviceTensor> *const context) {
|
||||
MS_EXCEPTION_IF_NULL(context);
|
||||
|
||||
++current_count_;
|
||||
|
||||
// Trigger disaster recovery and return empty output.
|
||||
if (RecoveryContext::GetInstance()->enable_recovery() && RecoveryContext::GetInstance()->need_reinit_collective()) {
|
||||
FreeOutputNodeMem();
|
||||
ClearOutputCache();
|
||||
SET_OPCONTEXT_SUCCESS_RET((*context));
|
||||
}
|
||||
|
||||
// The last loop.
|
||||
if (loop_count_ == current_count_) {
|
||||
if (current_outputs_num_ + device_tensor_store_keys_.size() != outputs_num_) {
|
||||
|
@ -106,19 +143,7 @@ void OutputActor::RunOpControl(AID *const, OpContext<DeviceTensor> *const contex
|
|||
SET_OPCONTEXT_SUCCESS_RET((*context));
|
||||
}
|
||||
|
||||
// The output device memory will be taken over by tensor in the last loop, otherwise needs to free the memory.
|
||||
// 1.Avoid the memory leak when memory used by dynamic ref count in the control flow scene.
|
||||
// 2.Alloc the new memory in the next step using the new shape size in the dynamic shape scene.
|
||||
for (size_t i = 0; i < output_nodes_.size(); ++i) {
|
||||
auto &output_node = output_nodes_[i].first;
|
||||
auto &output_device_tensor = output_device_tensors_[i];
|
||||
if ((output_node == nullptr) || (output_device_tensor == nullptr)) {
|
||||
return;
|
||||
}
|
||||
if (!IsOutputAddressPersisted(output_device_tensor, output_node)) {
|
||||
FreeMemory(output_device_tensor, device_contexts_[i]);
|
||||
}
|
||||
}
|
||||
FreeOutputNodeMem();
|
||||
|
||||
// Send control arrow to trigger next step running.
|
||||
auto from_aid = const_cast<AID *>(&GetAID());
|
||||
|
|
|
@ -79,6 +79,14 @@ class OutputActor : public AbstractActor {
|
|||
|
||||
TensorPtr CreateOutputTensor(const AnfNodePtr &output_node, size_t output_index, size_t output_position);
|
||||
|
||||
// The output device memory will be taken over by tensor in the last loop, otherwise needs to free the memory.
|
||||
// 1.Avoid the memory leak when memory used by dynamic ref count in the control flow scene.
|
||||
// 2.Alloc the new memory in the next step using the new shape size in the dynamic shape scene.
|
||||
void FreeOutputNodeMem();
|
||||
|
||||
// Clear output nodes and tensors in cache.
|
||||
void ClearOutputCache();
|
||||
|
||||
// The loop count is constant, the current count is increased after each step running finished.
|
||||
// Collect the output result in the last loop which is represented by "loop_count_ - current_count_ == 1".
|
||||
size_t loop_count_;
|
||||
|
|
|
@ -46,9 +46,11 @@
|
|||
#endif
|
||||
#include "profiler/device/profiling.h"
|
||||
#include "debug/common.h"
|
||||
#include "runtime/recovery/recovery_context.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace runtime {
|
||||
using recovery::RecoveryContext;
|
||||
namespace {
|
||||
bool IsNeedInsertCopyActor(const DeviceContext *from_device_context, const DeviceContext *to_device_context) {
|
||||
MS_EXCEPTION_IF_NULL(from_device_context);
|
||||
|
@ -464,21 +466,19 @@ void GraphScheduler::Run(ActorSet *const actor_set, const std::vector<DeviceCont
|
|||
MS_LOG(EXCEPTION) << op_context.error_info_;
|
||||
}
|
||||
|
||||
// Sync device stream.
|
||||
if (strategy == GraphExecutionStrategy::kPipeline) {
|
||||
std::set<DeviceContext *> sync_stream_device_contexts;
|
||||
for (auto &device_context : device_contexts) {
|
||||
MS_EXCEPTION_IF_NULL(device_context);
|
||||
if ((sync_stream_device_contexts.count(device_context) == 0) && (!device_context->SyncStream())) {
|
||||
MS_LOG(EXCEPTION) << "Sync stream failed:" << device_context->device_context_key().ToString();
|
||||
}
|
||||
(void)sync_stream_device_contexts.insert(device_context);
|
||||
}
|
||||
}
|
||||
|
||||
double end_time = GetTime();
|
||||
const size_t kSecondsToMilliseconds = 1000;
|
||||
SetActorExecutionStrategy(actor_set, strategy, (end_time - start_time) * kSecondsToMilliseconds);
|
||||
|
||||
if (RecoveryContext::GetInstance()->enable_recovery() && RecoveryContext::GetInstance()->need_reinit_collective()) {
|
||||
MS_LOG(INFO) << "Begin reinitialize collective communication for recovery.";
|
||||
if (!RecoveryContext::GetInstance()->ReInitializeCollective()) {
|
||||
MS_LOG(EXCEPTION) << "Reinitialize collective communication failed.";
|
||||
}
|
||||
MS_LOG(INFO) << "Finish reinitialize collective communication for recovery.";
|
||||
|
||||
RecoveryContext::GetInstance()->set_need_reinit_collective(false);
|
||||
}
|
||||
}
|
||||
|
||||
void GraphScheduler::SetActorExecutionStrategy(ActorSet *const actor_set, GraphExecutionStrategy strategy,
|
||||
|
@ -843,7 +843,8 @@ LoopCountActorPtr GraphScheduler::BuildLoopCountActor(const GraphCompilerInfo &g
|
|||
|
||||
auto actor_name = graph_compiler_info.name_ + kLoopCountActorNameSuffix;
|
||||
auto loop_count_actor =
|
||||
std::make_shared<LoopCountActor>(actor_name, loop_count, memory_manager_aid_, debug_aid_, recorder_aid_);
|
||||
std::make_shared<LoopCountActor>(actor_name, loop_count, memory_manager_aid_, debug_aid_, recorder_aid_,
|
||||
graph_compiler_info.strategy_, graph_compiler_info.device_contexts_);
|
||||
MS_LOG(INFO) << "Create loop count actor: " << actor_name;
|
||||
MS_EXCEPTION_IF_NULL(loop_count_actor);
|
||||
|
||||
|
|
|
@ -29,6 +29,7 @@
|
|||
#include "distributed/init.h"
|
||||
#include "runtime/hardware/device_context.h"
|
||||
#include "utils/convert_utils_base.h"
|
||||
#include "utils/ms_context.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace runtime {
|
||||
|
@ -82,6 +83,10 @@ void RecoveryContext::Initialize() {
|
|||
return;
|
||||
}
|
||||
|
||||
auto context_ptr = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context_ptr);
|
||||
context_ptr->set_param<bool>(MS_CTX_ENABLE_RECOVERY, true);
|
||||
|
||||
recovery_path_ = common::GetEnv(kEnvRecoveryPath);
|
||||
if (recovery_path_.empty()) {
|
||||
MS_LOG(EXCEPTION) << "The recovery path is empty, please export MS_RECOVERY_PATH correctly.";
|
||||
|
@ -140,20 +145,20 @@ void RecoveryContext::Initialize() {
|
|||
bool RecoveryContext::ReInitializeCollective() {
|
||||
auto ret = distributed::Initialize();
|
||||
if (ret) {
|
||||
recovery_status_ = RecoveryStatus::kUnKnownError;
|
||||
recovery_status_ = RecoveryErrCode::kUnKnownError;
|
||||
set_need_reset(true);
|
||||
set_need_sync_weight_to_device(true);
|
||||
return true;
|
||||
}
|
||||
|
||||
if (recovery_status_ == RecoveryStatus::kBroadcastUniqueIDFailed ||
|
||||
recovery_status_ == RecoveryStatus::kAllGatherHostNameFailed) {
|
||||
if (recovery_status_ == RecoveryErrCode::kBroadcastUniqueIDFailed ||
|
||||
recovery_status_ == RecoveryErrCode::kAllGatherHostNameFailed) {
|
||||
MS_LOG(WARNING) << "Prepare to initialize NCCL failed, retrying.";
|
||||
// Retry duration: 30s.
|
||||
const int kRetryDuration = 30;
|
||||
std::this_thread::sleep_for(std::chrono::seconds(kRetryDuration));
|
||||
return ReInitializeCollective();
|
||||
} else if (recovery_status_ == RecoveryStatus::kInitNcclFailed) {
|
||||
} else if (recovery_status_ == RecoveryErrCode::kInitNcclFailed) {
|
||||
MS_LOG(EXCEPTION) << "Initialize NCCL failed.";
|
||||
}
|
||||
|
||||
|
|
|
@ -31,7 +31,7 @@ namespace recovery {
|
|||
using distributed::storage::FileIOUtils;
|
||||
using distributed::storage::JsonUtils;
|
||||
|
||||
enum class RecoveryStatus { kUnKnownError, kAllGatherHostNameFailed, kBroadcastUniqueIDFailed, kInitNcclFailed };
|
||||
enum class RecoveryErrCode { kUnKnownError, kAllGatherHostNameFailed, kBroadcastUniqueIDFailed, kInitNcclFailed };
|
||||
|
||||
// Used to save disaster recovery-related state quantities and provide disaster recovery-related
|
||||
// functions, such as reinitializing collective communication, etc.
|
||||
|
@ -65,9 +65,9 @@ class RecoveryContext {
|
|||
int recovery_interval() const { return recovery_interval_; }
|
||||
|
||||
// Get the error status of recovery.
|
||||
RecoveryStatus recovery_status() const { return recovery_status_; }
|
||||
RecoveryErrCode recovery_status() const { return recovery_status_; }
|
||||
// Set the error status of recovery.
|
||||
void set_recovery_status(RecoveryStatus recovery_status) { recovery_status_ = recovery_status; }
|
||||
void set_recovery_status(RecoveryErrCode recovery_status) { recovery_status_ = recovery_status; }
|
||||
|
||||
// Set the path used to save checkpoint.
|
||||
void SetCkptPath(const std::string &path);
|
||||
|
@ -161,7 +161,7 @@ class RecoveryContext {
|
|||
bool initialized_{false};
|
||||
|
||||
// The error status of recovery.
|
||||
RecoveryStatus recovery_status_{RecoveryStatus::kUnKnownError};
|
||||
RecoveryErrCode recovery_status_{RecoveryErrCode::kUnKnownError};
|
||||
|
||||
// The persitent json file util, used to persist recovery config.
|
||||
std::unique_ptr<JsonUtils> persistent_json_;
|
||||
|
|
|
@ -100,6 +100,7 @@ MsContext::MsContext(const std::string &policy, const std::string &target) {
|
|||
set_param<bool>(MS_CTX_ENABLE_PYNATIVE_SYNCHRONIZE, false);
|
||||
set_param<bool>(MS_CTX_ENABLE_PYNATIVE_OP_GRAPH_CACHE, true);
|
||||
set_param<bool>(MS_CTX_ENABLE_MEM_SCHEDULER, false);
|
||||
set_param<bool>(MS_CTX_ENABLE_RECOVERY, false);
|
||||
|
||||
uint32_t kDefaultRuntimeNumThreads = 30;
|
||||
set_param<uint32_t>(MS_CTX_RUNTIME_NUM_THREADS, kDefaultRuntimeNumThreads);
|
||||
|
|
|
@ -92,6 +92,7 @@ enum MsCtxParam : unsigned {
|
|||
MS_CTX_ENABLE_PYNATIVE_SYNCHRONIZE,
|
||||
MS_CTX_ENABLE_PYNATIVE_OP_GRAPH_CACHE,
|
||||
MS_CTX_ENABLE_MEM_SCHEDULER,
|
||||
MS_CTX_ENABLE_RECOVERY,
|
||||
MS_CTX_TYPE_BOOL_END,
|
||||
|
||||
// parameter of type int
|
||||
|
|
|
@ -181,7 +181,6 @@ file(GLOB_RECURSE MINDSPORE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
|
|||
"../../../mindspore/ccsrc/fl/*.cc"
|
||||
"../../../mindspore/ccsrc/distributed/cluster/actor_route_table_service.cc"
|
||||
"../../../mindspore/ccsrc/distributed/cluster/actor_route_table_proxy.cc"
|
||||
"../../../mindspore/ccsrc/distributed/cluster/cluster_context.cc"
|
||||
"../../../mindspore/ccsrc/distributed/persistent/*.cc"
|
||||
"../../../mindspore/ccsrc/distributed/rpc/tcp/*.cc"
|
||||
"../../../mindspore/ccsrc/distributed/cluster/topology/*.cc"
|
||||
|
|
Loading…
Reference in New Issue