diff --git a/mindspore/lite/src/kernel_registry.cc b/mindspore/lite/src/kernel_registry.cc index ebdc27a739c..5e536c9b00b 100644 --- a/mindspore/lite/src/kernel_registry.cc +++ b/mindspore/lite/src/kernel_registry.cc @@ -102,8 +102,13 @@ kernel::LiteKernel *KernelRegistry::GetKernel(const std::vector &in_te const InnerContext *ctx, const kernel::KernelKey &key) { MS_ASSERT(nullptr != primitive); MS_ASSERT(nullptr != ctx); - auto parameter = - PopulateRegistry::GetInstance()->getParameterCreator(schema::PrimitiveType(primitive->Type()))(primitive); + auto func_pointer = PopulateRegistry::GetInstance()->GetParameterCreator(schema::PrimitiveType(primitive->Type())); + if (func_pointer == nullptr) { + MS_LOG(ERROR) << "ParameterCreator function pointer is nullptr, type: " + << schema::EnumNamePrimitiveType((schema::PrimitiveType)primitive->Type()); + return nullptr; + } + auto parameter = func_pointer(primitive); if (parameter == nullptr) { MS_LOG(ERROR) << "PopulateParameter return nullptr, type: " << schema::EnumNamePrimitiveType((schema::PrimitiveType)primitive->Type()); diff --git a/mindspore/lite/src/lite_kernel.cc b/mindspore/lite/src/lite_kernel.cc index 26648adc0cc..f0ea27d8bf9 100644 --- a/mindspore/lite/src/lite_kernel.cc +++ b/mindspore/lite/src/lite_kernel.cc @@ -18,6 +18,7 @@ #include #include #include "src/tensor.h" +#include "src/common/utils.h" namespace mindspore::kernel { using mindspore::lite::RET_ERROR; @@ -196,7 +197,9 @@ std::vector LiteKernelUtil::SubgraphInputTensors(const std::vect if (outer_in_kernels.empty()) { for (auto &in_kernel_in_tensor : in_kernel_in_tensors) { if (!in_kernel_in_tensor->IsConst()) { - input_tensors.push_back(in_kernel_in_tensor); + if (!lite::IsContain(input_tensors, in_kernel_in_tensor)) { + input_tensors.push_back(in_kernel_in_tensor); + } } } continue; @@ -211,7 +214,9 @@ std::vector LiteKernelUtil::SubgraphInputTensors(const std::vect auto outer_in_kernel_out_tensors_iter = std::find(outer_in_kernel_out_tensors.begin(), outer_in_kernel_out_tensors.end(), in_kernel_in_tensor); if (outer_in_kernel_out_tensors_iter != outer_in_kernel_out_tensors.end()) { - input_tensors.emplace_back(in_kernel_in_tensor); + if (!lite::IsContain(input_tensors, in_kernel_in_tensor)) { + input_tensors.emplace_back(in_kernel_in_tensor); + } } } } @@ -226,7 +231,11 @@ std::vector LiteKernelUtil::SubgraphOutputTensors(const std::vec auto &outer_out_kernels = output_kernel->out_kernels(); auto &out_kernel_out_tensors = output_kernel->out_tensors(); if (outer_out_kernels.empty()) { - output_tensors.insert(output_tensors.end(), out_kernel_out_tensors.begin(), out_kernel_out_tensors.end()); + for (auto out_kernel_out_tensor : out_kernel_out_tensors) { + if (!lite::IsContain(output_tensors, out_kernel_out_tensor)) { + output_tensors.push_back(out_kernel_out_tensor); + } + } continue; } for (auto outer_out_kernel : outer_out_kernels) { @@ -239,7 +248,9 @@ std::vector LiteKernelUtil::SubgraphOutputTensors(const std::vec auto outer_out_kernel_in_tensors_iter = std::find(outer_out_kernel_in_tensors.begin(), outer_out_kernel_in_tensors.end(), out_kernel_out_tensor); if (outer_out_kernel_in_tensors_iter != outer_out_kernel_in_tensors.end()) { - output_tensors.emplace_back(out_kernel_out_tensor); + if (!lite::IsContain(output_tensors, out_kernel_out_tensor)) { + output_tensors.emplace_back(out_kernel_out_tensor); + } } } } diff --git a/mindspore/lite/src/model_common.h b/mindspore/lite/src/model_common.h index 041ec0b7a6d..e2a68a49304 100644 --- a/mindspore/lite/src/model_common.h +++ b/mindspore/lite/src/model_common.h @@ -50,7 +50,14 @@ bool ConvertNodes(const T &meta_graph, Model *model, int schema_version = SCHEMA node->primitive_ = PrimitiveC::Create(const_cast(src_prim)); #else auto primitive = const_cast(src_prim); - node->primitive_ = OpsRegistry::GetInstance()->getPrimitiveCreator(primitive->value_type())(primitive); + auto func_pointer = OpsRegistry::GetInstance()->GetPrimitiveCreator(primitive->value_type()); + if (func_pointer == nullptr) { + MS_LOG(ERROR) << "PrimitiveCreator function pointer is nullptr, type: " + << schema::EnumNamePrimitiveType(primitive->value_type()); + delete node; + return false; + } + node->primitive_ = func_pointer(primitive); #endif if (node->primitive_ == nullptr) { MS_LOG(ERROR) << "unpack primitive == nullptr!"; diff --git a/mindspore/lite/src/ops/ops_register.h b/mindspore/lite/src/ops/ops_register.h index e86f00dde46..969f925f008 100644 --- a/mindspore/lite/src/ops/ops_register.h +++ b/mindspore/lite/src/ops/ops_register.h @@ -28,10 +28,10 @@ class OpsRegistry { return ®istry; } - void insertPrimitiveCMap(schema::PrimitiveType type, PrimitiveCCreator creator) { + void InsertPrimitiveCMap(schema::PrimitiveType type, PrimitiveCCreator creator) { primitive_creators[type] = creator; } - PrimitiveCCreator getPrimitiveCreator(schema::PrimitiveType type) { + PrimitiveCCreator GetPrimitiveCreator(schema::PrimitiveType type) { if (primitive_creators.find(type) != primitive_creators.end()) { return primitive_creators[type]; } else { @@ -47,7 +47,7 @@ class OpsRegistry { class Registry { public: Registry(schema::PrimitiveType primitive_type, PrimitiveCCreator creator) { - OpsRegistry::GetInstance()->insertPrimitiveCMap(primitive_type, creator); + OpsRegistry::GetInstance()->InsertPrimitiveCMap(primitive_type, creator); } }; diff --git a/mindspore/lite/src/ops/populate/populate_register.h b/mindspore/lite/src/ops/populate/populate_register.h index 704be5c4e3c..9e80d30c417 100644 --- a/mindspore/lite/src/ops/populate/populate_register.h +++ b/mindspore/lite/src/ops/populate/populate_register.h @@ -18,6 +18,7 @@ #define LITE_MINDSPORE_LITE_C_OPS_OP_POPULATE_REGISTER_H #include +#include "src/ops/primitive_c.h" namespace mindspore { namespace lite { @@ -29,9 +30,9 @@ class PopulateRegistry { return ®istry; } - void insertParameterMap(schema::PrimitiveType type, ParameterCreator creator) { parameter_creators[type] = creator; } + void InsertParameterMap(schema::PrimitiveType type, ParameterCreator creator) { parameter_creators[type] = creator; } - ParameterCreator getParameterCreator(schema::PrimitiveType type) { + ParameterCreator GetParameterCreator(schema::PrimitiveType type) { if (parameter_creators.find(type) != parameter_creators.end()) { return parameter_creators[type]; } else { @@ -47,7 +48,7 @@ class PopulateRegistry { class Registry { public: Registry(schema::PrimitiveType primitive_type, ParameterCreator creator) { - PopulateRegistry::GetInstance()->insertParameterMap(primitive_type, creator); + PopulateRegistry::GetInstance()->InsertParameterMap(primitive_type, creator); } ~Registry() = default; }; diff --git a/mindspore/lite/tools/optimizer/fusion/constant_folding_fusion.cc b/mindspore/lite/tools/optimizer/fusion/constant_folding_fusion.cc index 566ddb0d936..1c6d1beada3 100644 --- a/mindspore/lite/tools/optimizer/fusion/constant_folding_fusion.cc +++ b/mindspore/lite/tools/optimizer/fusion/constant_folding_fusion.cc @@ -244,8 +244,14 @@ const AnfNodePtr ConstFoldPass::Process(const FuncGraphPtr &func_graph, const An auto primitive = lite_primitive.get(); MS_ASSERT(primitive != nullptr); MS_ASSERT(primitive->Type() != nullptr); - auto parameter = - lite::PopulateRegistry::GetInstance()->getParameterCreator(schema::PrimitiveType(primitive->Type()))(primitive); + auto func_pointer = + lite::PopulateRegistry::GetInstance()->GetParameterCreator(schema::PrimitiveType(primitive->Type())); + if (func_pointer == nullptr) { + MS_LOG(ERROR) << "ParameterCreator function pointer is nullptr, type: " + << schema::EnumNamePrimitiveType((schema::PrimitiveType)primitive->Type()); + return nullptr; + } + auto parameter = func_pointer(primitive); if (parameter == nullptr) { MS_LOG(ERROR) << "PopulateParameter return nullptr, type: "