diff --git a/mindspore/lite/src/lite_kernel.cc b/mindspore/lite/src/lite_kernel.cc index efa70a61c7..7fa5420e96 100644 --- a/mindspore/lite/src/lite_kernel.cc +++ b/mindspore/lite/src/lite_kernel.cc @@ -16,7 +16,6 @@ #include "src/lite_kernel.h" #include -#include "src/common/utils.h" namespace mindspore::kernel { void LiteKernel::InitOutTensorRefCount() { diff --git a/mindspore/lite/src/lite_kernel.h b/mindspore/lite/src/lite_kernel.h index 35bb4615bb..e1196b8261 100644 --- a/mindspore/lite/src/lite_kernel.h +++ b/mindspore/lite/src/lite_kernel.h @@ -18,6 +18,7 @@ #define MINDSPORE_LITE_SRC_LITE_KERNEL_H_ #include #include +#include "src/common/utils.h" #ifdef ENABLE_ARM #include #endif @@ -113,9 +114,17 @@ class LiteKernel { std::vector &out_tensors() { return this->out_tensors_; } - void AddInKernel(LiteKernel *kernel) { this->in_kernels_.emplace_back(kernel); } + void AddInKernel(LiteKernel *kernel) { + if (!lite::IsContain(this->in_kernels_, kernel)) { + this->in_kernels_.emplace_back(kernel); + } + } - void AddOutKernel(LiteKernel *kernel) { this->out_kernels_.emplace_back(kernel); } + void AddOutKernel(LiteKernel *kernel) { + if (!lite::IsContain(this->out_kernels_, kernel)) { + this->out_kernels_.emplace_back(kernel); + } + } void SetInKernel(const std::vector &kernel) { this->in_kernels_ = kernel; } diff --git a/mindspore/lite/tools/optimizer/fusion/constant_folding_fusion.cc b/mindspore/lite/tools/optimizer/fusion/constant_folding_fusion.cc index 3f37eb7da3..adb2a7f2e7 100644 --- a/mindspore/lite/tools/optimizer/fusion/constant_folding_fusion.cc +++ b/mindspore/lite/tools/optimizer/fusion/constant_folding_fusion.cc @@ -104,14 +104,13 @@ ParameterPtr CreateNewParamter(const FuncGraphPtr &func_graph, Tensor *tensor) { return parameter; } kernel::LiteKernel *GetLiteKernel(std::vector inputs, std::vector outputs, OpParameter *parameter, - mindspore::lite::PrimitiveC *primitive) { + lite::Context *context, mindspore::lite::PrimitiveC *primitive) { MS_ASSERT(nullptr != lite_primitive); auto data_type = inputs.front()->data_type(); kernel::KernelKey desc{kernel::KERNEL_ARCH::kCPU, data_type, (schema::PrimitiveType)primitive->Type()}; - lite::Context context; auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); if (creator != nullptr) { - auto lite_kernel = creator(inputs, outputs, parameter, &context, desc, primitive); + auto lite_kernel = creator(inputs, outputs, parameter, context, desc, primitive); return lite_kernel; } return nullptr; @@ -235,7 +234,8 @@ const AnfNodePtr ConstFoldPass::Process(const FuncGraphPtr &func_graph, const An << schema::EnumNamePrimitiveType((schema::PrimitiveType)(lite_primitive->Type())); return nullptr; } - auto lite_kernel = GetLiteKernel(input_tensors, output_tensors, parameter, lite_primitive.get()); + lite::Context context; + auto lite_kernel = GetLiteKernel(input_tensors, output_tensors, parameter, &context, lite_primitive.get()); if (lite_kernel == nullptr) { MS_LOG(ERROR) << "constant_folding schedule node lite kernel nullptr"; FreeTensors(&input_tensors, &output_tensors);