Optimize launch kernel in kernel actor

This commit is contained in:
ZPaC 2022-06-13 17:14:14 +08:00
parent 3170c0a90b
commit 0fb999ae82
4 changed files with 34 additions and 30 deletions

View File

@ -230,7 +230,26 @@ void KernelActor::OnMemoryAllocFinish(OpContext<DeviceTensor> *const context) {
MS_EXCEPTION_IF_NULL(device_contexts_[0]);
PreLaunchKernel(context);
LaunchKernel(context);
try {
if (RecoveryContext::GetInstance()->enable_recovery() && CollectiveManager::instance()->need_reinit()) {
// In disaster recovery scenarios, if running the DAG failed in this step, the remaining operators of the graph do
// not need to be launched, especially the collective communication operators.
MS_LOG(WARNING) << "Collective communication need reinitialize, skip launch kernel: "
<< kernel_->fullname_with_scope();
} else {
auto ret = LaunchKernel();
if (!ret) {
std::string error_info = "Launch kernel failed: " + kernel_->fullname_with_scope();
SET_OPCONTEXT_FAIL_RET_WITH_ERROR_BY_STRATEGY(strategy_, (*context), error_info);
}
}
} catch (const std::exception &e) {
if (strategy_ == GraphExecutionStrategy::kPipeline) {
MsException::Instance().SetException();
}
std::string error_info = "Launch kernel exception: " + kernel_->fullname_with_scope();
SET_OPCONTEXT_FAIL_RET_WITH_ERROR_BY_STRATEGY(strategy_, (*context), error_info);
}
// Debug actor is blocked, must wait debug actor callback message to process continue.
if (debug_aid_ != nullptr && strategy_ == GraphExecutionStrategy::kPipeline) {
@ -430,28 +449,10 @@ void KernelActor::PreLaunchKernel(OpContext<DeviceTensor> *) {
}
}
void KernelActor::LaunchKernel(OpContext<DeviceTensor> *const context) {
try {
if (RecoveryContext::GetInstance()->enable_recovery() && CollectiveManager::instance()->need_reinit()) {
// In disaster recovery scenarios, if running the DAG failed in this step, the remaining operators of the graph do
// not need to be launched, especially the collective communication operators.
MS_LOG(WARNING) << "Collective communication need reinitialize, skip launch kernel: "
<< kernel_->fullname_with_scope();
} else {
auto ret = device_contexts_[0]->LaunchKernel(kernel_, launch_info_.inputs_, launch_info_.workspaces_,
launch_info_.outputs_);
if (!ret) {
std::string error_info = "Launch kernel failed: " + kernel_->fullname_with_scope();
SET_OPCONTEXT_FAIL_RET_WITH_ERROR_BY_STRATEGY(strategy_, (*context), error_info);
}
}
} catch (const std::exception &e) {
if (strategy_ == GraphExecutionStrategy::kPipeline) {
MsException::Instance().SetException();
}
std::string error_info = "Launch kernel exception: " + kernel_->fullname_with_scope();
SET_OPCONTEXT_FAIL_RET_WITH_ERROR_BY_STRATEGY(strategy_, (*context), error_info);
}
// Launches kernel_ on the first device context, passing the inputs, workspaces and outputs held in launch_info_.
// Returns the device context's LaunchKernel result unchanged; this method does no error reporting of its own and
// has no try/catch, so handling a false result or an exception is left to the caller.
bool KernelActor::LaunchKernel() {
// Guard the device context pointer before it is dereferenced below (the macro presumably raises on null).
MS_EXCEPTION_IF_NULL(device_contexts_[0]);
return device_contexts_[0]->LaunchKernel(kernel_, launch_info_.inputs_, launch_info_.workspaces_,
launch_info_.outputs_);
}
void KernelActor::PostLaunchKernel(OpContext<DeviceTensor> *const context) {

View File

@ -92,7 +92,7 @@ class KernelActor : public DebugAwareActor {
void SendRecorderInfo(OpContext<DeviceTensor> *const context) const override;
// Do kernel launching in this method, after 'PreLaunchKernel' and before 'PostLaunchKernel'.
virtual void LaunchKernel(OpContext<DeviceTensor> *const context);
virtual bool LaunchKernel();
// The info of kernel.
CNodePtr kernel_;

View File

@ -54,23 +54,26 @@ bool SendActor::ConnectServer() {
return true;
}
void SendActor::LaunchKernel(OpContext<DeviceTensor> *const context) {
MS_EXCEPTION_IF_NULL(context);
KernelActor::LaunchKernel(context);
bool SendActor::LaunchKernel() {
if (!KernelActor::LaunchKernel()) {
MS_LOG(ERROR) << "Launching kernel for send actor failed.";
return false;
}
// Send input data (inter-process data is the input of the Send kernel) to peers.
if (launch_info_.inputs_.empty()) {
MS_LOG(ERROR) << "Send kernel has no output tensor.";
return;
return false;
}
auto send_output = launch_info_.inputs_;
for (const auto &peer : peer_actor_urls_) {
std::string peer_server_url = peer.second;
auto message = BuildRpcMessage(send_output, peer_server_url);
MS_ERROR_IF_NULL_WO_RET_VAL(message);
MS_ERROR_IF_NULL_W_RET_VAL(message, false);
MS_LOG(INFO) << "Rpc actor send message for inter-process edge: " << peer.first;
client_->SendAsync(std::move(message));
}
return true;
}
void SendActor::EraseInput(const OpContext<DeviceTensor> *context) {

View File

@ -47,7 +47,7 @@ class SendActor : public RpcActor {
protected:
// Do real send operation in this method.
void LaunchKernel(OpContext<DeviceTensor> *const context) override;
bool LaunchKernel() override;
// Erase inter-process inputs for this sequential number.
void EraseInput(const OpContext<DeviceTensor> *context) override;