Optimize launch kernel in kernel actor
This commit is contained in:
parent
3170c0a90b
commit
0fb999ae82
|
@ -230,7 +230,26 @@ void KernelActor::OnMemoryAllocFinish(OpContext<DeviceTensor> *const context) {
|
|||
MS_EXCEPTION_IF_NULL(device_contexts_[0]);
|
||||
PreLaunchKernel(context);
|
||||
|
||||
LaunchKernel(context);
|
||||
try {
|
||||
if (RecoveryContext::GetInstance()->enable_recovery() && CollectiveManager::instance()->need_reinit()) {
|
||||
// In disaster recovery scenarios, run dag in this step failed, the rest operators of graph do not need launch,
|
||||
// especially the collective communication operators.
|
||||
MS_LOG(WARNING) << "Collective communication need reinitialize, skip launch kernel: "
|
||||
<< kernel_->fullname_with_scope();
|
||||
} else {
|
||||
auto ret = LaunchKernel();
|
||||
if (!ret) {
|
||||
std::string error_info = "Launch kernel failed: " + kernel_->fullname_with_scope();
|
||||
SET_OPCONTEXT_FAIL_RET_WITH_ERROR_BY_STRATEGY(strategy_, (*context), error_info);
|
||||
}
|
||||
}
|
||||
} catch (const std::exception &e) {
|
||||
if (strategy_ == GraphExecutionStrategy::kPipeline) {
|
||||
MsException::Instance().SetException();
|
||||
}
|
||||
std::string error_info = "Launch kernel exception: " + kernel_->fullname_with_scope();
|
||||
SET_OPCONTEXT_FAIL_RET_WITH_ERROR_BY_STRATEGY(strategy_, (*context), error_info);
|
||||
}
|
||||
|
||||
// Debug actor is blocked, must wait debug actor callback message to process continue.
|
||||
if (debug_aid_ != nullptr && strategy_ == GraphExecutionStrategy::kPipeline) {
|
||||
|
@ -430,28 +449,10 @@ void KernelActor::PreLaunchKernel(OpContext<DeviceTensor> *) {
|
|||
}
|
||||
}
|
||||
|
||||
void KernelActor::LaunchKernel(OpContext<DeviceTensor> *const context) {
|
||||
try {
|
||||
if (RecoveryContext::GetInstance()->enable_recovery() && CollectiveManager::instance()->need_reinit()) {
|
||||
// In disaster recovery scenarios, run dag in this step failed, the rest operators of graph do not need launch,
|
||||
// especially the collective communication operators.
|
||||
MS_LOG(WARNING) << "Collective communication need reinitialize, skip launch kernel: "
|
||||
<< kernel_->fullname_with_scope();
|
||||
} else {
|
||||
auto ret = device_contexts_[0]->LaunchKernel(kernel_, launch_info_.inputs_, launch_info_.workspaces_,
|
||||
launch_info_.outputs_);
|
||||
if (!ret) {
|
||||
std::string error_info = "Launch kernel failed: " + kernel_->fullname_with_scope();
|
||||
SET_OPCONTEXT_FAIL_RET_WITH_ERROR_BY_STRATEGY(strategy_, (*context), error_info);
|
||||
}
|
||||
}
|
||||
} catch (const std::exception &e) {
|
||||
if (strategy_ == GraphExecutionStrategy::kPipeline) {
|
||||
MsException::Instance().SetException();
|
||||
}
|
||||
std::string error_info = "Launch kernel exception: " + kernel_->fullname_with_scope();
|
||||
SET_OPCONTEXT_FAIL_RET_WITH_ERROR_BY_STRATEGY(strategy_, (*context), error_info);
|
||||
}
|
||||
bool KernelActor::LaunchKernel() {
|
||||
MS_EXCEPTION_IF_NULL(device_contexts_[0]);
|
||||
return device_contexts_[0]->LaunchKernel(kernel_, launch_info_.inputs_, launch_info_.workspaces_,
|
||||
launch_info_.outputs_);
|
||||
}
|
||||
|
||||
void KernelActor::PostLaunchKernel(OpContext<DeviceTensor> *const context) {
|
||||
|
|
|
@ -92,7 +92,7 @@ class KernelActor : public DebugAwareActor {
|
|||
void SendRecorderInfo(OpContext<DeviceTensor> *const context) const override;
|
||||
|
||||
// Do kernel launching in this method after 'PreLaunchKernel' and 'PostLaunchKernel'.
|
||||
virtual void LaunchKernel(OpContext<DeviceTensor> *const context);
|
||||
virtual bool LaunchKernel();
|
||||
|
||||
// The info of kernel.
|
||||
CNodePtr kernel_;
|
||||
|
|
|
@ -54,23 +54,26 @@ bool SendActor::ConnectServer() {
|
|||
return true;
|
||||
}
|
||||
|
||||
void SendActor::LaunchKernel(OpContext<DeviceTensor> *const context) {
|
||||
MS_EXCEPTION_IF_NULL(context);
|
||||
KernelActor::LaunchKernel(context);
|
||||
bool SendActor::LaunchKernel() {
|
||||
if (!KernelActor::LaunchKernel()) {
|
||||
MS_LOG(ERROR) << "Launching kernel for send actor failed.";
|
||||
return false;
|
||||
}
|
||||
|
||||
// Send input data(inter-process data is the input of the Send kernel) to peers.
|
||||
if (launch_info_.inputs_.empty()) {
|
||||
MS_LOG(ERROR) << "Send kernel has no output tensor.";
|
||||
return;
|
||||
return false;
|
||||
}
|
||||
auto send_output = launch_info_.inputs_;
|
||||
for (const auto &peer : peer_actor_urls_) {
|
||||
std::string peer_server_url = peer.second;
|
||||
auto message = BuildRpcMessage(send_output, peer_server_url);
|
||||
MS_ERROR_IF_NULL_WO_RET_VAL(message);
|
||||
MS_ERROR_IF_NULL_W_RET_VAL(message, false);
|
||||
MS_LOG(INFO) << "Rpc actor send message for inter-process edge: " << peer.first;
|
||||
client_->SendAsync(std::move(message));
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
void SendActor::EraseInput(const OpContext<DeviceTensor> *context) {
|
||||
|
|
|
@ -47,7 +47,7 @@ class SendActor : public RpcActor {
|
|||
|
||||
protected:
|
||||
// Do real send operation in this method.
|
||||
void LaunchKernel(OpContext<DeviceTensor> *const context) override;
|
||||
bool LaunchKernel() override;
|
||||
|
||||
// Erase inter-process inputs for this sequential number.
|
||||
void EraseInput(const OpContext<DeviceTensor> *context) override;
|
||||
|
|
Loading…
Reference in New Issue