!9061 Fix bug caused by function null pointer

From: @yeyunpeng2020
Reviewed-by: @hangangqiang,@HilbertDavid
Signed-off-by: @hangangqiang
This commit is contained in:
mindspore-ci-bot 2020-11-26 19:12:54 +08:00 committed by Gitee
commit d6024f8e96
6 changed files with 45 additions and 15 deletions

View File

@ -102,8 +102,13 @@ kernel::LiteKernel *KernelRegistry::GetKernel(const std::vector<Tensor *> &in_te
const InnerContext *ctx, const kernel::KernelKey &key) {
MS_ASSERT(nullptr != primitive);
MS_ASSERT(nullptr != ctx);
auto parameter =
PopulateRegistry::GetInstance()->getParameterCreator(schema::PrimitiveType(primitive->Type()))(primitive);
auto func_pointer = PopulateRegistry::GetInstance()->GetParameterCreator(schema::PrimitiveType(primitive->Type()));
if (func_pointer == nullptr) {
MS_LOG(ERROR) << "ParameterCreator function pointer is nullptr, type: "
<< schema::EnumNamePrimitiveType((schema::PrimitiveType)primitive->Type());
return nullptr;
}
auto parameter = func_pointer(primitive);
if (parameter == nullptr) {
MS_LOG(ERROR) << "PopulateParameter return nullptr, type: "
<< schema::EnumNamePrimitiveType((schema::PrimitiveType)primitive->Type());

View File

@ -18,6 +18,7 @@
#include <algorithm>
#include <queue>
#include "src/tensor.h"
#include "src/common/utils.h"
namespace mindspore::kernel {
using mindspore::lite::RET_ERROR;
@ -196,9 +197,11 @@ std::vector<lite::Tensor *> LiteKernelUtil::SubgraphInputTensors(const std::vect
if (outer_in_kernels.empty()) {
for (auto &in_kernel_in_tensor : in_kernel_in_tensors) {
if (!in_kernel_in_tensor->IsConst()) {
if (!lite::IsContain(input_tensors, in_kernel_in_tensor)) {
input_tensors.push_back(in_kernel_in_tensor);
}
}
}
continue;
}
for (auto outer_in_kernel : outer_in_kernels) {
@ -211,11 +214,13 @@ std::vector<lite::Tensor *> LiteKernelUtil::SubgraphInputTensors(const std::vect
auto outer_in_kernel_out_tensors_iter =
std::find(outer_in_kernel_out_tensors.begin(), outer_in_kernel_out_tensors.end(), in_kernel_in_tensor);
if (outer_in_kernel_out_tensors_iter != outer_in_kernel_out_tensors.end()) {
if (!lite::IsContain(input_tensors, in_kernel_in_tensor)) {
input_tensors.emplace_back(in_kernel_in_tensor);
}
}
}
}
}
return input_tensors;
}
@ -226,7 +231,11 @@ std::vector<lite::Tensor *> LiteKernelUtil::SubgraphOutputTensors(const std::vec
auto &outer_out_kernels = output_kernel->out_kernels();
auto &out_kernel_out_tensors = output_kernel->out_tensors();
if (outer_out_kernels.empty()) {
output_tensors.insert(output_tensors.end(), out_kernel_out_tensors.begin(), out_kernel_out_tensors.end());
for (auto out_kernel_out_tensor : out_kernel_out_tensors) {
if (!lite::IsContain(output_tensors, out_kernel_out_tensor)) {
output_tensors.push_back(out_kernel_out_tensor);
}
}
continue;
}
for (auto outer_out_kernel : outer_out_kernels) {
@ -239,11 +248,13 @@ std::vector<lite::Tensor *> LiteKernelUtil::SubgraphOutputTensors(const std::vec
auto outer_out_kernel_in_tensors_iter =
std::find(outer_out_kernel_in_tensors.begin(), outer_out_kernel_in_tensors.end(), out_kernel_out_tensor);
if (outer_out_kernel_in_tensors_iter != outer_out_kernel_in_tensors.end()) {
if (!lite::IsContain(output_tensors, out_kernel_out_tensor)) {
output_tensors.emplace_back(out_kernel_out_tensor);
}
}
}
}
}
return output_tensors;
}

View File

@ -50,7 +50,14 @@ bool ConvertNodes(const T &meta_graph, Model *model, int schema_version = SCHEMA
node->primitive_ = PrimitiveC::Create(const_cast<schema::Primitive *>(src_prim));
#else
auto primitive = const_cast<schema::Primitive *>(src_prim);
node->primitive_ = OpsRegistry::GetInstance()->getPrimitiveCreator(primitive->value_type())(primitive);
auto func_pointer = OpsRegistry::GetInstance()->GetPrimitiveCreator(primitive->value_type());
if (func_pointer == nullptr) {
MS_LOG(ERROR) << "PrimitiveCreator function pointer is nullptr, type: "
<< schema::EnumNamePrimitiveType(primitive->value_type());
delete node;
return false;
}
node->primitive_ = func_pointer(primitive);
#endif
if (node->primitive_ == nullptr) {
MS_LOG(ERROR) << "unpack primitive == nullptr!";

View File

@ -28,10 +28,10 @@ class OpsRegistry {
return &registry;
}
void insertPrimitiveCMap(schema::PrimitiveType type, PrimitiveCCreator creator) {
void InsertPrimitiveCMap(schema::PrimitiveType type, PrimitiveCCreator creator) {
primitive_creators[type] = creator;
}
PrimitiveCCreator getPrimitiveCreator(schema::PrimitiveType type) {
PrimitiveCCreator GetPrimitiveCreator(schema::PrimitiveType type) {
if (primitive_creators.find(type) != primitive_creators.end()) {
return primitive_creators[type];
} else {
@ -47,7 +47,7 @@ class OpsRegistry {
class Registry {
public:
Registry(schema::PrimitiveType primitive_type, PrimitiveCCreator creator) {
OpsRegistry::GetInstance()->insertPrimitiveCMap(primitive_type, creator);
OpsRegistry::GetInstance()->InsertPrimitiveCMap(primitive_type, creator);
}
};

View File

@ -18,6 +18,7 @@
#define LITE_MINDSPORE_LITE_C_OPS_OP_POPULATE_REGISTER_H
#include <map>
#include "src/ops/primitive_c.h"
namespace mindspore {
namespace lite {
@ -29,9 +30,9 @@ class PopulateRegistry {
return &registry;
}
void insertParameterMap(schema::PrimitiveType type, ParameterCreator creator) { parameter_creators[type] = creator; }
void InsertParameterMap(schema::PrimitiveType type, ParameterCreator creator) { parameter_creators[type] = creator; }
ParameterCreator getParameterCreator(schema::PrimitiveType type) {
ParameterCreator GetParameterCreator(schema::PrimitiveType type) {
if (parameter_creators.find(type) != parameter_creators.end()) {
return parameter_creators[type];
} else {
@ -47,7 +48,7 @@ class PopulateRegistry {
class Registry {
public:
Registry(schema::PrimitiveType primitive_type, ParameterCreator creator) {
PopulateRegistry::GetInstance()->insertParameterMap(primitive_type, creator);
PopulateRegistry::GetInstance()->InsertParameterMap(primitive_type, creator);
}
~Registry() = default;
};

View File

@ -244,8 +244,14 @@ const AnfNodePtr ConstFoldPass::Process(const FuncGraphPtr &func_graph, const An
auto primitive = lite_primitive.get();
MS_ASSERT(primitive != nullptr);
MS_ASSERT(primitive->Type() != nullptr);
auto parameter =
lite::PopulateRegistry::GetInstance()->getParameterCreator(schema::PrimitiveType(primitive->Type()))(primitive);
auto func_pointer =
lite::PopulateRegistry::GetInstance()->GetParameterCreator(schema::PrimitiveType(primitive->Type()));
if (func_pointer == nullptr) {
MS_LOG(ERROR) << "ParameterCreator function pointer is nullptr, type: "
<< schema::EnumNamePrimitiveType((schema::PrimitiveType)primitive->Type());
return nullptr;
}
auto parameter = func_pointer(primitive);
if (parameter == nullptr) {
MS_LOG(ERROR) << "PopulateParameter return nullptr, type: "