forked from OSSInnovation/mindspore
!5999 [MSLITE] Fix bug of input kernel count.
Merge pull request !5999 from wangshaocong/lite_bugfix
This commit is contained in:
commit
217628a9b9
|
@ -16,7 +16,6 @@
|
|||
|
||||
#include "src/lite_kernel.h"
|
||||
#include <algorithm>
|
||||
#include "src/common/utils.h"
|
||||
|
||||
namespace mindspore::kernel {
|
||||
void LiteKernel::InitOutTensorRefCount() {
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
#define MINDSPORE_LITE_SRC_LITE_KERNEL_H_
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include "src/common/utils.h"
|
||||
#ifdef ENABLE_ARM
|
||||
#include <arm_neon.h>
|
||||
#endif
|
||||
|
@ -113,9 +114,17 @@ class LiteKernel {
|
|||
|
||||
std::vector<lite::Tensor *> &out_tensors() { return this->out_tensors_; }
|
||||
|
||||
void AddInKernel(LiteKernel *kernel) { this->in_kernels_.emplace_back(kernel); }
|
||||
void AddInKernel(LiteKernel *kernel) {
|
||||
if (!lite::IsContain(this->in_kernels_, kernel)) {
|
||||
this->in_kernels_.emplace_back(kernel);
|
||||
}
|
||||
}
|
||||
|
||||
void AddOutKernel(LiteKernel *kernel) { this->out_kernels_.emplace_back(kernel); }
|
||||
void AddOutKernel(LiteKernel *kernel) {
|
||||
if (!lite::IsContain(this->out_kernels_, kernel)) {
|
||||
this->out_kernels_.emplace_back(kernel);
|
||||
}
|
||||
}
|
||||
|
||||
void SetInKernel(const std::vector<LiteKernel *> &kernel) { this->in_kernels_ = kernel; }
|
||||
|
||||
|
|
|
@ -104,14 +104,13 @@ ParameterPtr CreateNewParamter(const FuncGraphPtr &func_graph, Tensor *tensor) {
|
|||
return parameter;
|
||||
}
|
||||
kernel::LiteKernel *GetLiteKernel(std::vector<Tensor *> inputs, std::vector<Tensor *> outputs, OpParameter *parameter,
|
||||
mindspore::lite::PrimitiveC *primitive) {
|
||||
lite::Context *context, mindspore::lite::PrimitiveC *primitive) {
|
||||
MS_ASSERT(nullptr != lite_primitive);
|
||||
auto data_type = inputs.front()->data_type();
|
||||
kernel::KernelKey desc{kernel::KERNEL_ARCH::kCPU, data_type, (schema::PrimitiveType)primitive->Type()};
|
||||
lite::Context context;
|
||||
auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc);
|
||||
if (creator != nullptr) {
|
||||
auto lite_kernel = creator(inputs, outputs, parameter, &context, desc, primitive);
|
||||
auto lite_kernel = creator(inputs, outputs, parameter, context, desc, primitive);
|
||||
return lite_kernel;
|
||||
}
|
||||
return nullptr;
|
||||
|
@ -235,7 +234,8 @@ const AnfNodePtr ConstFoldPass::Process(const FuncGraphPtr &func_graph, const An
|
|||
<< schema::EnumNamePrimitiveType((schema::PrimitiveType)(lite_primitive->Type()));
|
||||
return nullptr;
|
||||
}
|
||||
auto lite_kernel = GetLiteKernel(input_tensors, output_tensors, parameter, lite_primitive.get());
|
||||
lite::Context context;
|
||||
auto lite_kernel = GetLiteKernel(input_tensors, output_tensors, parameter, &context, lite_primitive.get());
|
||||
if (lite_kernel == nullptr) {
|
||||
MS_LOG(ERROR) << "constant_folding schedule node lite kernel nullptr";
|
||||
FreeTensors(&input_tensors, &output_tensors);
|
||||
|
|
Loading…
Reference in New Issue