forked from OSSInnovation/mindspore
!5999 [MSLITE] Fix bug of input kernel count.
Merge pull request !5999 from wangshaocong/lite_bugfix
This commit is contained in:
commit
217628a9b9
|
@ -16,7 +16,6 @@
|
||||||
|
|
||||||
#include "src/lite_kernel.h"
|
#include "src/lite_kernel.h"
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include "src/common/utils.h"
|
|
||||||
|
|
||||||
namespace mindspore::kernel {
|
namespace mindspore::kernel {
|
||||||
void LiteKernel::InitOutTensorRefCount() {
|
void LiteKernel::InitOutTensorRefCount() {
|
||||||
|
|
|
@ -18,6 +18,7 @@
|
||||||
#define MINDSPORE_LITE_SRC_LITE_KERNEL_H_
|
#define MINDSPORE_LITE_SRC_LITE_KERNEL_H_
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
#include "src/common/utils.h"
|
||||||
#ifdef ENABLE_ARM
|
#ifdef ENABLE_ARM
|
||||||
#include <arm_neon.h>
|
#include <arm_neon.h>
|
||||||
#endif
|
#endif
|
||||||
|
@ -113,9 +114,17 @@ class LiteKernel {
|
||||||
|
|
||||||
std::vector<lite::Tensor *> &out_tensors() { return this->out_tensors_; }
|
std::vector<lite::Tensor *> &out_tensors() { return this->out_tensors_; }
|
||||||
|
|
||||||
void AddInKernel(LiteKernel *kernel) { this->in_kernels_.emplace_back(kernel); }
|
void AddInKernel(LiteKernel *kernel) {
|
||||||
|
if (!lite::IsContain(this->in_kernels_, kernel)) {
|
||||||
|
this->in_kernels_.emplace_back(kernel);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
void AddOutKernel(LiteKernel *kernel) { this->out_kernels_.emplace_back(kernel); }
|
void AddOutKernel(LiteKernel *kernel) {
|
||||||
|
if (!lite::IsContain(this->out_kernels_, kernel)) {
|
||||||
|
this->out_kernels_.emplace_back(kernel);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
void SetInKernel(const std::vector<LiteKernel *> &kernel) { this->in_kernels_ = kernel; }
|
void SetInKernel(const std::vector<LiteKernel *> &kernel) { this->in_kernels_ = kernel; }
|
||||||
|
|
||||||
|
|
|
@ -104,14 +104,13 @@ ParameterPtr CreateNewParamter(const FuncGraphPtr &func_graph, Tensor *tensor) {
|
||||||
return parameter;
|
return parameter;
|
||||||
}
|
}
|
||||||
kernel::LiteKernel *GetLiteKernel(std::vector<Tensor *> inputs, std::vector<Tensor *> outputs, OpParameter *parameter,
|
kernel::LiteKernel *GetLiteKernel(std::vector<Tensor *> inputs, std::vector<Tensor *> outputs, OpParameter *parameter,
|
||||||
mindspore::lite::PrimitiveC *primitive) {
|
lite::Context *context, mindspore::lite::PrimitiveC *primitive) {
|
||||||
MS_ASSERT(nullptr != lite_primitive);
|
MS_ASSERT(nullptr != lite_primitive);
|
||||||
auto data_type = inputs.front()->data_type();
|
auto data_type = inputs.front()->data_type();
|
||||||
kernel::KernelKey desc{kernel::KERNEL_ARCH::kCPU, data_type, (schema::PrimitiveType)primitive->Type()};
|
kernel::KernelKey desc{kernel::KERNEL_ARCH::kCPU, data_type, (schema::PrimitiveType)primitive->Type()};
|
||||||
lite::Context context;
|
|
||||||
auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc);
|
auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc);
|
||||||
if (creator != nullptr) {
|
if (creator != nullptr) {
|
||||||
auto lite_kernel = creator(inputs, outputs, parameter, &context, desc, primitive);
|
auto lite_kernel = creator(inputs, outputs, parameter, context, desc, primitive);
|
||||||
return lite_kernel;
|
return lite_kernel;
|
||||||
}
|
}
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
@ -235,7 +234,8 @@ const AnfNodePtr ConstFoldPass::Process(const FuncGraphPtr &func_graph, const An
|
||||||
<< schema::EnumNamePrimitiveType((schema::PrimitiveType)(lite_primitive->Type()));
|
<< schema::EnumNamePrimitiveType((schema::PrimitiveType)(lite_primitive->Type()));
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
auto lite_kernel = GetLiteKernel(input_tensors, output_tensors, parameter, lite_primitive.get());
|
lite::Context context;
|
||||||
|
auto lite_kernel = GetLiteKernel(input_tensors, output_tensors, parameter, &context, lite_primitive.get());
|
||||||
if (lite_kernel == nullptr) {
|
if (lite_kernel == nullptr) {
|
||||||
MS_LOG(ERROR) << "constant_folding schedule node lite kernel nullptr";
|
MS_LOG(ERROR) << "constant_folding schedule node lite kernel nullptr";
|
||||||
FreeTensors(&input_tensors, &output_tensors);
|
FreeTensors(&input_tensors, &output_tensors);
|
||||||
|
|
Loading…
Reference in New Issue