!49050 [MS][LITE]Fix control model parallel run

Merge pull request !49050 from gongdaguo1/fix_control
This commit is contained in:
i-robot 2023-02-20 12:45:56 +00:00 committed by Gitee
commit f984afe7e1
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
2 changed files with 9 additions and 4 deletions

View File

@ -31,9 +31,14 @@ std::shared_ptr<LiteOpActor> CreateActor(kernel::KernelExec *kernel, lite::Inner
actor = std::make_shared<LiteEntranceOpActor>(kernel, ctx);
} else if (kernel->subgraph_type() == kernel::kExitSubGraph) {
actor = std::make_shared<LiteExitOpActor>(kernel, ctx);
} else if (ctx->inter_op_parallel_num_ > 1 && (kernel->subgraph_type() == kernel::kCpuFP32SubGraph ||
kernel->subgraph_type() == kernel::kCpuFP16SubGraph)) {
actor = std::make_shared<ParallelLiteActor>(kernel, ctx);
} else if (kernel->subgraph_type() != kernel::kNotSubGraph) {
auto subgraph_kernel = reinterpret_cast<kernel::SubGraphKernel *>(kernel);
if (subgraph_kernel->nodes().size() > 1 && ctx->inter_op_parallel_num_ > 1 &&
(kernel->subgraph_type() == kernel::kCpuFP32SubGraph || kernel->subgraph_type() == kernel::kCpuFP16SubGraph)) {
actor = std::make_shared<ParallelLiteActor>(kernel, ctx);
} else {
actor = std::make_shared<LiteOpActor>(kernel, ctx);
}
} else {
actor = std::make_shared<LiteOpActor>(kernel, ctx);
}

View File

@ -369,7 +369,7 @@ int SubGraphKernel::SubGraphSplitByOperator(KernelsArray *kernels_array) {
if (kernel == nullptr) {
continue;
}
MS_ASSERT(kernel->subgraph_type() != kernel::kNotSubGraph);
MS_CHECK_TRUE_MSG(kernel->subgraph_type() == kernel::kNotSubGraph, RET_ERROR, "node cannot be a subgraph.");
kernels_array->units.push_back({});
size_t now_index = kernels_array->units.size() - 1;
kernels_array->units.at(now_index).kernels.push_back(kernel);