!18639 fix bug of actor runtime of tests/st

Merge pull request !18639 from limingqi107/actor_runtime
This commit is contained in:
i-robot 2021-06-23 01:01:57 +00:00 committed by Gitee
commit 7a9b4d49f1
5 changed files with 48 additions and 25 deletions

View File

@ -339,6 +339,15 @@ std::vector<KernelWithIndex> AnfRuntimeAlgorithm::GetAllOutputWithIndex(const An
continue; continue;
} }
// Skip the empty value node.
if (output_with_index.first->isa<ValueNode>()) {
auto value = output_with_index.first->cast<ValueNodePtr>()->value();
MS_EXCEPTION_IF_NULL(value);
if (value->isa<ValueTuple>() && (value->cast<ValueTuplePtr>()->size() == 0)) {
continue;
}
}
// Ignore the output of front call node. // Ignore the output of front call node.
if (output_with_index.first->isa<CNode>()) { if (output_with_index.first->isa<CNode>()) {
auto cnode = output_with_index.first->cast<CNodePtr>(); auto cnode = output_with_index.first->cast<CNodePtr>();

View File

@ -140,17 +140,18 @@ void DeviceQueueDataSourceActor::OnMemoryAllocFinish(OpContext<DeviceTensor> *co
} }
// Copy data from device queue by data kernel launching. // Copy data from device queue by data kernel launching.
bool ret = true;
try { try {
ret = device_context_->LaunchKernel(data_kernel_, launch_info_.inputs_, launch_info_.workspaces_, auto ret = device_context_->LaunchKernel(data_kernel_, launch_info_.inputs_, launch_info_.workspaces_,
launch_info_.outputs_); launch_info_.outputs_);
} catch (const std::exception &e) {
MsException::Instance().SetException();
}
if (!ret) { if (!ret) {
std::string error_info = "Launch kernel failed: " + data_kernel_->ToString(); std::string error_info = "Launch kernel failed: " + data_kernel_->ToString();
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info); SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
} }
} catch (const std::exception &e) {
MsException::Instance().SetException();
std::string error_info = "Launch kernel exception: " + data_kernel_->ToString();
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
}
// Debug actor is blocked, must wait debug actor callback message to process continue. // Debug actor is blocked, must wait debug actor callback message to process continue.
if (debug_aid_ != nullptr) { if (debug_aid_ != nullptr) {

View File

@ -147,17 +147,18 @@ void KernelActor::OnMemoryAllocFinish(OpContext<DeviceTensor> *context) {
PreLaunchKernel(context); PreLaunchKernel(context);
bool ret = true;
try { try {
ret = device_context_->LaunchKernel(kernel_, launch_info_.inputs_, launch_info_.workspaces_, launch_info_.outputs_, auto ret = device_context_->LaunchKernel(kernel_, launch_info_.inputs_, launch_info_.workspaces_,
is_dynamic_shape_); launch_info_.outputs_, is_dynamic_shape_);
} catch (const std::exception &e) {
MsException::Instance().SetException();
}
if (!ret) { if (!ret) {
std::string error_info = "Launch kernel failed: " + kernel_->ToString(); std::string error_info = "Launch kernel failed: " + kernel_->ToString();
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info); SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
} }
} catch (const std::exception &e) {
MsException::Instance().SetException();
std::string error_info = "Launch kernel exception: " + kernel_->ToString();
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
}
// Debug actor is blocked, must wait debug actor callback message to process continue. // Debug actor is blocked, must wait debug actor callback message to process continue.
if (debug_aid_ != nullptr) { if (debug_aid_ != nullptr) {

View File

@ -582,6 +582,13 @@ bool GraphScheduler::Run(const ActorSet *actor_set, GraphExecutionStrategy strat
} }
} }
// Trigger output actor running when there are no data source actor and kernel actor.
if ((actor_set->data_source_actors_.size() == 0) && (actor_set->kernel_actors_.size() == 0)) {
MS_EXCEPTION_IF_NULL(actor_set->output_actor_);
Async(actor_set->output_actor_->GetAID(), &OutputActor::CollectLoopCount, actor_set->output_actor_->loop_count_,
&op_context);
}
// Get the run result. // Get the run result.
auto result_future = result[0].GetFuture(); auto result_future = result[0].GetFuture();
result_future.Wait(); result_future.Wait();
@ -2157,15 +2164,6 @@ bool GraphScheduler::CheckActorValid(const ActorSet *actor_set, GraphExecutionSt
} }
} }
// Check the loop count actor.
const auto &loop_count_actor = actor_set->loop_count_actor_;
if (loop_count_actor != nullptr) {
if (loop_count_actor->branch_id_to_input_controls_num_[kMainBranchID] == 0) {
MS_LOG(ERROR) << loop_count_actor->GetAID().Name() << " has no source.";
return false;
}
}
return true; return true;
} }

View File

@ -629,9 +629,21 @@ void MindRTBackend::ConstructOutputs(const AnfNodePtr &output_node,
return; return;
} }
// Judge the output whether tuple or not by the outputs number. // The empty value node return the empty VectorRef.
if (output_node->isa<ValueNode>()) {
auto value = output_node->cast<ValueNodePtr>()->value();
MS_EXCEPTION_IF_NULL(value);
if (value->isa<ValueTuple>() && (value->cast<ValueTuplePtr>()->size() == 0)) {
outputs->emplace_back(VectorRef());
return;
}
}
auto outputs_num = AnfAlgo::GetOutputTensorNum(output_node); auto outputs_num = AnfAlgo::GetOutputTensorNum(output_node);
if (outputs_num > 1) { auto &output_abstract = output_node->abstract();
MS_EXCEPTION_IF_NULL(output_abstract);
// Wrap output to VectorRef if the output is tuple.
if (output_abstract->isa<abstract::AbstractTuple>()) {
VectorRef output_tuple; VectorRef output_tuple;
for (size_t i = 0; i < outputs_num; ++i) { for (size_t i = 0; i < outputs_num; ++i) {
output_tuple.emplace_back(std::move(output_tensors[*output_position])); output_tuple.emplace_back(std::move(output_tensors[*output_position]));
@ -639,9 +651,11 @@ void MindRTBackend::ConstructOutputs(const AnfNodePtr &output_node,
} }
outputs->emplace_back(std::move(output_tuple)); outputs->emplace_back(std::move(output_tuple));
} else { } else {
for (size_t i = 0; i < outputs_num; ++i) {
outputs->emplace_back(std::move(output_tensors[*output_position])); outputs->emplace_back(std::move(output_tensors[*output_position]));
++(*output_position); ++(*output_position);
} }
}
} }
std::unique_ptr<GraphCompilerInfo> MindRTBackend::ConstructGraphCompilerInfo(const FuncGraphPtr &root_graph) { std::unique_ptr<GraphCompilerInfo> MindRTBackend::ConstructGraphCompilerInfo(const FuncGraphPtr &root_graph) {