!5999 [MSLITE] Fix bug of input kernel count.

Merge pull request !5999 from wangshaocong/lite_bugfix
This commit is contained in:
mindspore-ci-bot 2020-09-11 09:48:00 +08:00 committed by Gitee
commit 217628a9b9
3 changed files with 15 additions and 7 deletions

View File

@ -16,7 +16,6 @@
#include "src/lite_kernel.h"
#include <algorithm>
#include "src/common/utils.h"
namespace mindspore::kernel {
void LiteKernel::InitOutTensorRefCount() {

View File

@ -18,6 +18,7 @@
#define MINDSPORE_LITE_SRC_LITE_KERNEL_H_
#include <vector>
#include <string>
#include "src/common/utils.h"
#ifdef ENABLE_ARM
#include <arm_neon.h>
#endif
@ -113,9 +114,17 @@ class LiteKernel {
std::vector<lite::Tensor *> &out_tensors() { return this->out_tensors_; }
void AddInKernel(LiteKernel *kernel) { this->in_kernels_.emplace_back(kernel); }
void AddInKernel(LiteKernel *kernel) {
if (!lite::IsContain(this->in_kernels_, kernel)) {
this->in_kernels_.emplace_back(kernel);
}
}
void AddOutKernel(LiteKernel *kernel) { this->out_kernels_.emplace_back(kernel); }
void AddOutKernel(LiteKernel *kernel) {
if (!lite::IsContain(this->out_kernels_, kernel)) {
this->out_kernels_.emplace_back(kernel);
}
}
void SetInKernel(const std::vector<LiteKernel *> &kernel) { this->in_kernels_ = kernel; }

View File

@ -104,14 +104,13 @@ ParameterPtr CreateNewParamter(const FuncGraphPtr &func_graph, Tensor *tensor) {
return parameter;
}
kernel::LiteKernel *GetLiteKernel(std::vector<Tensor *> inputs, std::vector<Tensor *> outputs, OpParameter *parameter,
mindspore::lite::PrimitiveC *primitive) {
lite::Context *context, mindspore::lite::PrimitiveC *primitive) {
MS_ASSERT(nullptr != lite_primitive);
auto data_type = inputs.front()->data_type();
kernel::KernelKey desc{kernel::KERNEL_ARCH::kCPU, data_type, (schema::PrimitiveType)primitive->Type()};
lite::Context context;
auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc);
if (creator != nullptr) {
auto lite_kernel = creator(inputs, outputs, parameter, &context, desc, primitive);
auto lite_kernel = creator(inputs, outputs, parameter, context, desc, primitive);
return lite_kernel;
}
return nullptr;
@ -235,7 +234,8 @@ const AnfNodePtr ConstFoldPass::Process(const FuncGraphPtr &func_graph, const An
<< schema::EnumNamePrimitiveType((schema::PrimitiveType)(lite_primitive->Type()));
return nullptr;
}
auto lite_kernel = GetLiteKernel(input_tensors, output_tensors, parameter, lite_primitive.get());
lite::Context context;
auto lite_kernel = GetLiteKernel(input_tensors, output_tensors, parameter, &context, lite_primitive.get());
if (lite_kernel == nullptr) {
MS_LOG(ERROR) << "constant_folding schedule node lite kernel nullptr";
FreeTensors(&input_tensors, &output_tensors);