forked from mindspore-Ecosystem/mindspore
!24029 [MSLITE] runtime pass opt
Merge pull request !24029 from ling/bug
This commit is contained in:
commit
8dbdcb8cdf
|
@ -135,12 +135,14 @@ bool Nc4hw4PassMatch(std::vector<kernel::LiteKernel *> *kernels, size_t index) {
|
|||
return true;
|
||||
}
|
||||
|
||||
bool RuntimePassValid(const InnerContext *context, std::vector<kernel::LiteKernel *> *kernels) {
|
||||
if (context->IsGpuEnabled() || context->IsNpuEnabled()) {
|
||||
bool RuntimePassValid(kernel::SubGraphKernel *subgraph) {
|
||||
if (subgraph->desc().arch != kernel::KERNEL_ARCH::kCPU) {
|
||||
return false;
|
||||
}
|
||||
|
||||
for (auto kernel : *kernels) {
|
||||
auto kernels = subgraph->nodes();
|
||||
|
||||
for (auto kernel : kernels) {
|
||||
if (kernel->op_parameter() != nullptr) {
|
||||
if (kernel->op_parameter()->quant_type_ == schema::QuantType_AwareTraining ||
|
||||
kernel->op_parameter()->quant_type_ == schema::QuantType_PostTraining) {
|
||||
|
@ -148,7 +150,6 @@ bool RuntimePassValid(const InnerContext *context, std::vector<kernel::LiteKerne
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
|
@ -226,15 +227,17 @@ void ConvNormC4PassAct(std::vector<kernel::LiteKernel *> *kernels) {
|
|||
return;
|
||||
}
|
||||
|
||||
void RuntimePass(const InnerContext *context, std::vector<kernel::LiteKernel *> *kernels,
|
||||
std::vector<Tensor *> *tensors) {
|
||||
if (!RuntimePassValid(context, kernels)) {
|
||||
return;
|
||||
void RuntimePass(std::vector<kernel::LiteKernel *> *subgraphs, std::vector<Tensor *> *tensors) {
|
||||
for (auto subgraph : *subgraphs) {
|
||||
auto sub = reinterpret_cast<kernel::SubGraphKernel *>(subgraph);
|
||||
if (RuntimePassValid(sub) == false) {
|
||||
continue;
|
||||
}
|
||||
|
||||
int i = 0;
|
||||
auto &kernels = sub->nodes();
|
||||
Nc4hw4PassAct(&kernels, tensors, i);
|
||||
ConvNormC4PassAct(&kernels);
|
||||
}
|
||||
|
||||
int i = 0;
|
||||
Nc4hw4PassAct(kernels, tensors, i);
|
||||
|
||||
ConvNormC4PassAct(kernels);
|
||||
}
|
||||
} // namespace mindspore::lite
|
||||
|
|
|
@ -26,8 +26,7 @@
|
|||
|
||||
namespace mindspore::lite {
|
||||
|
||||
void RuntimePass(const InnerContext *context, std::vector<kernel::LiteKernel *> *kernels,
|
||||
std::vector<Tensor *> *tensors);
|
||||
void RuntimePass(std::vector<kernel::LiteKernel *> *subgraphs, std::vector<Tensor *> *tensors);
|
||||
|
||||
/* Nc4hw4 PASS
|
||||
* before : --(nhwc)-- CONV --(nhwc)-- TRANSPOSE --(nchw)-- IN --(nchw)-- TRANSPOSE --(nhwc)--
|
||||
|
|
|
@ -332,16 +332,16 @@ int Scheduler::Schedule(std::vector<kernel::LiteKernel *> *dst_kernels) {
|
|||
|
||||
FindAllInoutKernels(*dst_kernels);
|
||||
|
||||
#ifndef RUNTIME_PASS_CLIP
|
||||
RuntimePass(context_, dst_kernels, src_tensors_);
|
||||
#endif
|
||||
|
||||
ret = ConstructSubGraphs(dst_kernels);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "ConstructSubGraphs failed.";
|
||||
return ret;
|
||||
}
|
||||
|
||||
#ifndef RUNTIME_PASS_CLIP
|
||||
RuntimePass(dst_kernels, src_tensors_);
|
||||
#endif
|
||||
|
||||
ret = InitKernels(*dst_kernels);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "InitKernels failed.";
|
||||
|
|
Loading…
Reference in New Issue