!24029 [MSLITE] runtime pass opt

Merge pull request !24029 from ling/bug
This commit is contained in:
i-robot 2021-09-26 07:04:22 +00:00 committed by Gitee
commit 8dbdcb8cdf
3 changed files with 21 additions and 19 deletions

View File

@@ -135,12 +135,14 @@ bool Nc4hw4PassMatch(std::vector<kernel::LiteKernel *> *kernels, size_t index) {
return true;
}
bool RuntimePassValid(const InnerContext *context, std::vector<kernel::LiteKernel *> *kernels) {
if (context->IsGpuEnabled() || context->IsNpuEnabled()) {
bool RuntimePassValid(kernel::SubGraphKernel *subgraph) {
if (subgraph->desc().arch != kernel::KERNEL_ARCH::kCPU) {
return false;
}
for (auto kernel : *kernels) {
auto kernels = subgraph->nodes();
for (auto kernel : kernels) {
if (kernel->op_parameter() != nullptr) {
if (kernel->op_parameter()->quant_type_ == schema::QuantType_AwareTraining ||
kernel->op_parameter()->quant_type_ == schema::QuantType_PostTraining) {
@@ -148,7 +150,6 @@ bool RuntimePassValid(const InnerContext *context, std::vector<kernel::LiteKerne
}
}
}
return true;
}
@@ -226,15 +227,17 @@ void ConvNormC4PassAct(std::vector<kernel::LiteKernel *> *kernels) {
return;
}
void RuntimePass(const InnerContext *context, std::vector<kernel::LiteKernel *> *kernels,
std::vector<Tensor *> *tensors) {
if (!RuntimePassValid(context, kernels)) {
return;
void RuntimePass(std::vector<kernel::LiteKernel *> *subgraphs, std::vector<Tensor *> *tensors) {
for (auto subgraph : *subgraphs) {
auto sub = reinterpret_cast<kernel::SubGraphKernel *>(subgraph);
if (RuntimePassValid(sub) == false) {
continue;
}
int i = 0;
auto &kernels = sub->nodes();
Nc4hw4PassAct(&kernels, tensors, i);
ConvNormC4PassAct(&kernels);
}
int i = 0;
Nc4hw4PassAct(kernels, tensors, i);
ConvNormC4PassAct(kernels);
}
} // namespace mindspore::lite

View File

@@ -26,8 +26,7 @@
namespace mindspore::lite {
void RuntimePass(const InnerContext *context, std::vector<kernel::LiteKernel *> *kernels,
std::vector<Tensor *> *tensors);
void RuntimePass(std::vector<kernel::LiteKernel *> *subgraphs, std::vector<Tensor *> *tensors);
/* Nc4hw4 PASS
* before : --(nhwc)-- CONV --(nhwc)-- TRANSPOSE --(nchw)-- IN --(nchw)-- TRANSPOSE --(nhwc)--

View File

@@ -332,16 +332,16 @@ int Scheduler::Schedule(std::vector<kernel::LiteKernel *> *dst_kernels) {
FindAllInoutKernels(*dst_kernels);
#ifndef RUNTIME_PASS_CLIP
RuntimePass(context_, dst_kernels, src_tensors_);
#endif
ret = ConstructSubGraphs(dst_kernels);
if (ret != RET_OK) {
MS_LOG(ERROR) << "ConstructSubGraphs failed.";
return ret;
}
#ifndef RUNTIME_PASS_CLIP
RuntimePass(dst_kernels, src_tensors_);
#endif
ret = InitKernels(*dst_kernels);
if (ret != RET_OK) {
MS_LOG(ERROR) << "InitKernels failed.";