From 0fb999ae82f0ca28137c1931359e410be1e5d151 Mon Sep 17 00:00:00 2001
From: ZPaC
Date: Mon, 13 Jun 2022 17:14:14 +0800
Subject: [PATCH] Optimize launch kernel in kernel actor

---
 .../graph_scheduler/actor/kernel_actor.cc   | 47 ++++++++++---------
 .../graph_scheduler/actor/kernel_actor.h    |  2 +-
 .../graph_scheduler/actor/rpc/send_actor.cc | 13 +++--
 .../graph_scheduler/actor/rpc/send_actor.h  |  2 +-
 4 files changed, 34 insertions(+), 30 deletions(-)

diff --git a/mindspore/ccsrc/runtime/graph_scheduler/actor/kernel_actor.cc b/mindspore/ccsrc/runtime/graph_scheduler/actor/kernel_actor.cc
index 6b92701ce21..85abfc44f3d 100644
--- a/mindspore/ccsrc/runtime/graph_scheduler/actor/kernel_actor.cc
+++ b/mindspore/ccsrc/runtime/graph_scheduler/actor/kernel_actor.cc
@@ -230,7 +230,26 @@ void KernelActor::OnMemoryAllocFinish(OpContext<DeviceTensor> *const context) {
   MS_EXCEPTION_IF_NULL(device_contexts_[0]);
   PreLaunchKernel(context);
 
-  LaunchKernel(context);
+  try {
+    if (RecoveryContext::GetInstance()->enable_recovery() && CollectiveManager::instance()->need_reinit()) {
+      // In disaster recovery scenarios, running the DAG failed in this step, so the remaining operators of the
+      // graph need not be launched, especially the collective communication operators.
+      MS_LOG(WARNING) << "Collective communication needs to be reinitialized, skip launching kernel: "
+                      << kernel_->fullname_with_scope();
+    } else {
+      auto ret = LaunchKernel();
+      if (!ret) {
+        std::string error_info = "Launch kernel failed: " + kernel_->fullname_with_scope();
+        SET_OPCONTEXT_FAIL_RET_WITH_ERROR_BY_STRATEGY(strategy_, (*context), error_info);
+      }
+    }
+  } catch (const std::exception &e) {
+    if (strategy_ == GraphExecutionStrategy::kPipeline) {
+      MsException::Instance().SetException();
+    }
+    std::string error_info = "Launch kernel exception: " + kernel_->fullname_with_scope();
+    SET_OPCONTEXT_FAIL_RET_WITH_ERROR_BY_STRATEGY(strategy_, (*context), error_info);
+  }
 
   // Debug actor is blocked, must wait debug actor callback message to process continue.
   if (debug_aid_ != nullptr && strategy_ == GraphExecutionStrategy::kPipeline) {
@@ -430,28 +449,10 @@ void KernelActor::PreLaunchKernel(OpContext<DeviceTensor> *) {
   }
 }
 
-void KernelActor::LaunchKernel(OpContext<DeviceTensor> *const context) {
-  try {
-    if (RecoveryContext::GetInstance()->enable_recovery() && CollectiveManager::instance()->need_reinit()) {
-      // In disaster recovery scenarios, run dag in this step failed, the rest operators of graph do not need launch,
-      // especially the collective communication operators.
-      MS_LOG(WARNING) << "Collective communication need reinitialize, skip launch kernel: "
-                      << kernel_->fullname_with_scope();
-    } else {
-      auto ret = device_contexts_[0]->LaunchKernel(kernel_, launch_info_.inputs_, launch_info_.workspaces_,
-                                                   launch_info_.outputs_);
-      if (!ret) {
-        std::string error_info = "Launch kernel failed: " + kernel_->fullname_with_scope();
-        SET_OPCONTEXT_FAIL_RET_WITH_ERROR_BY_STRATEGY(strategy_, (*context), error_info);
-      }
-    }
-  } catch (const std::exception &e) {
-    if (strategy_ == GraphExecutionStrategy::kPipeline) {
-      MsException::Instance().SetException();
-    }
-    std::string error_info = "Launch kernel exception: " + kernel_->fullname_with_scope();
-    SET_OPCONTEXT_FAIL_RET_WITH_ERROR_BY_STRATEGY(strategy_, (*context), error_info);
-  }
+bool KernelActor::LaunchKernel() {
+  MS_EXCEPTION_IF_NULL(device_contexts_[0]);
+  return device_contexts_[0]->LaunchKernel(kernel_, launch_info_.inputs_, launch_info_.workspaces_,
+                                           launch_info_.outputs_);
 }
 
 void KernelActor::PostLaunchKernel(OpContext<DeviceTensor> *const context) {
diff --git a/mindspore/ccsrc/runtime/graph_scheduler/actor/kernel_actor.h b/mindspore/ccsrc/runtime/graph_scheduler/actor/kernel_actor.h
index 14b459e4958..831acf372c7 100644
--- a/mindspore/ccsrc/runtime/graph_scheduler/actor/kernel_actor.h
+++ b/mindspore/ccsrc/runtime/graph_scheduler/actor/kernel_actor.h
@@ -92,7 +92,7 @@ class KernelActor : public DebugAwareActor {
   void SendRecorderInfo(OpContext<DeviceTensor> *const context) const override;
 
   // Do kernel launching in this method after 'PreLaunchKernel' and 'PostLaunchKernel'.
-  virtual void LaunchKernel(OpContext<DeviceTensor> *const context);
+  virtual bool LaunchKernel();
 
   // The info of kernel.
   CNodePtr kernel_;
diff --git a/mindspore/ccsrc/runtime/graph_scheduler/actor/rpc/send_actor.cc b/mindspore/ccsrc/runtime/graph_scheduler/actor/rpc/send_actor.cc
index 35798d8edbb..fa7ea8be1b9 100644
--- a/mindspore/ccsrc/runtime/graph_scheduler/actor/rpc/send_actor.cc
+++ b/mindspore/ccsrc/runtime/graph_scheduler/actor/rpc/send_actor.cc
@@ -54,23 +54,26 @@ bool SendActor::ConnectServer() {
   return true;
 }
 
-void SendActor::LaunchKernel(OpContext<DeviceTensor> *const context) {
-  MS_EXCEPTION_IF_NULL(context);
-  KernelActor::LaunchKernel(context);
+bool SendActor::LaunchKernel() {
+  if (!KernelActor::LaunchKernel()) {
+    MS_LOG(ERROR) << "Launching kernel for send actor failed.";
+    return false;
+  }
   // Send input data(inter-process data is the input of the Send kernel) to peers.
   if (launch_info_.inputs_.empty()) {
     MS_LOG(ERROR) << "Send kernel has no output tensor.";
-    return;
+    return false;
   }
   auto send_output = launch_info_.inputs_;
   for (const auto &peer : peer_actor_urls_) {
     std::string peer_server_url = peer.second;
     auto message = BuildRpcMessage(send_output, peer_server_url);
-    MS_ERROR_IF_NULL_WO_RET_VAL(message);
+    MS_ERROR_IF_NULL_W_RET_VAL(message, false);
     MS_LOG(INFO) << "Rpc actor send message for inter-process edge: " << peer.first;
     client_->SendAsync(std::move(message));
   }
+  return true;
 }
 
 void SendActor::EraseInput(const OpContext<DeviceTensor> *context) {
diff --git a/mindspore/ccsrc/runtime/graph_scheduler/actor/rpc/send_actor.h b/mindspore/ccsrc/runtime/graph_scheduler/actor/rpc/send_actor.h
index bf13a103904..e8a91435ca8 100644
--- a/mindspore/ccsrc/runtime/graph_scheduler/actor/rpc/send_actor.h
+++ b/mindspore/ccsrc/runtime/graph_scheduler/actor/rpc/send_actor.h
@@ -47,7 +47,7 @@ class SendActor : public RpcActor {
 
  protected:
   // Do real send operation in this method.
-  void LaunchKernel(OpContext<DeviceTensor> *const context) override;
+  bool LaunchKernel() override;
 
   // Erase inter-process inputs for this sequential number.
   void EraseInput(const OpContext<DeviceTensor> *context) override;
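
For readers who want the shape of the refactor outside the diff context, below is a minimal, self-contained C++ sketch of the new control flow. The names used (Kernel, KernelActorSketch, SendActorSketch) are hypothetical stand-ins, not MindSpore types; the sketch only illustrates the pattern this patch adopts: the try/catch and error reporting move into the caller, while LaunchKernel() becomes an argument-free, bool-returning virtual hook that subclasses such as the send actor override.

#include <iostream>
#include <stdexcept>
#include <string>

// Hypothetical stand-in for a compiled kernel; not a MindSpore type.
struct Kernel {
  std::string name;
  bool Launch() const { return !name.empty(); }  // pretend device launch
};

class KernelActorSketch {
 public:
  virtual ~KernelActorSketch() = default;

  // Mirrors the patched OnMemoryAllocFinish: the caller owns the try/catch
  // and the error reporting; LaunchKernel() only signals success via bool.
  void OnMemoryAllocFinish() {
    try {
      if (!LaunchKernel()) {
        std::cerr << "Launch kernel failed: " << kernel_.name << "\n";
      }
    } catch (const std::exception &e) {
      std::cerr << "Launch kernel exception: " << e.what() << "\n";
    }
  }

 protected:
  // The virtual hook the patch makes bool-returning and argument-free.
  virtual bool LaunchKernel() { return kernel_.Launch(); }

  Kernel kernel_{"MatMul"};
};

class SendActorSketch : public KernelActorSketch {
 protected:
  // The derived actor runs the base launch first, then its own send step;
  // any failure propagates upward as `false` instead of being swallowed.
  bool LaunchKernel() override {
    if (!KernelActorSketch::LaunchKernel()) {
      return false;
    }
    std::cout << "sending outputs of " << kernel_.name << " to peers\n";
    return true;
  }
};

int main() {
  SendActorSketch actor;
  actor.OnMemoryAllocFinish();  // launch, then send; errors reported centrally
  return 0;
}

Centralizing the exception handling in the caller means each override only has to signal failure through its return value, so derived actors cannot accidentally swallow a launch error the way a void-returning override could.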