forked from mindspore-Ecosystem/mindspore
!9061 Fix bug caused by function null pointer
From: @yeyunpeng2020 Reviewed-by: @hangangqiang,@HilbertDavid Signed-off-by: @hangangqiang
This commit is contained in:
commit
d6024f8e96
|
@ -102,8 +102,13 @@ kernel::LiteKernel *KernelRegistry::GetKernel(const std::vector<Tensor *> &in_te
|
|||
const InnerContext *ctx, const kernel::KernelKey &key) {
|
||||
MS_ASSERT(nullptr != primitive);
|
||||
MS_ASSERT(nullptr != ctx);
|
||||
auto parameter =
|
||||
PopulateRegistry::GetInstance()->getParameterCreator(schema::PrimitiveType(primitive->Type()))(primitive);
|
||||
auto func_pointer = PopulateRegistry::GetInstance()->GetParameterCreator(schema::PrimitiveType(primitive->Type()));
|
||||
if (func_pointer == nullptr) {
|
||||
MS_LOG(ERROR) << "ParameterCreator function pointer is nullptr, type: "
|
||||
<< schema::EnumNamePrimitiveType((schema::PrimitiveType)primitive->Type());
|
||||
return nullptr;
|
||||
}
|
||||
auto parameter = func_pointer(primitive);
|
||||
if (parameter == nullptr) {
|
||||
MS_LOG(ERROR) << "PopulateParameter return nullptr, type: "
|
||||
<< schema::EnumNamePrimitiveType((schema::PrimitiveType)primitive->Type());
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
#include <algorithm>
|
||||
#include <queue>
|
||||
#include "src/tensor.h"
|
||||
#include "src/common/utils.h"
|
||||
|
||||
namespace mindspore::kernel {
|
||||
using mindspore::lite::RET_ERROR;
|
||||
|
@ -196,7 +197,9 @@ std::vector<lite::Tensor *> LiteKernelUtil::SubgraphInputTensors(const std::vect
|
|||
if (outer_in_kernels.empty()) {
|
||||
for (auto &in_kernel_in_tensor : in_kernel_in_tensors) {
|
||||
if (!in_kernel_in_tensor->IsConst()) {
|
||||
input_tensors.push_back(in_kernel_in_tensor);
|
||||
if (!lite::IsContain(input_tensors, in_kernel_in_tensor)) {
|
||||
input_tensors.push_back(in_kernel_in_tensor);
|
||||
}
|
||||
}
|
||||
}
|
||||
continue;
|
||||
|
@ -211,7 +214,9 @@ std::vector<lite::Tensor *> LiteKernelUtil::SubgraphInputTensors(const std::vect
|
|||
auto outer_in_kernel_out_tensors_iter =
|
||||
std::find(outer_in_kernel_out_tensors.begin(), outer_in_kernel_out_tensors.end(), in_kernel_in_tensor);
|
||||
if (outer_in_kernel_out_tensors_iter != outer_in_kernel_out_tensors.end()) {
|
||||
input_tensors.emplace_back(in_kernel_in_tensor);
|
||||
if (!lite::IsContain(input_tensors, in_kernel_in_tensor)) {
|
||||
input_tensors.emplace_back(in_kernel_in_tensor);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -226,7 +231,11 @@ std::vector<lite::Tensor *> LiteKernelUtil::SubgraphOutputTensors(const std::vec
|
|||
auto &outer_out_kernels = output_kernel->out_kernels();
|
||||
auto &out_kernel_out_tensors = output_kernel->out_tensors();
|
||||
if (outer_out_kernels.empty()) {
|
||||
output_tensors.insert(output_tensors.end(), out_kernel_out_tensors.begin(), out_kernel_out_tensors.end());
|
||||
for (auto out_kernel_out_tensor : out_kernel_out_tensors) {
|
||||
if (!lite::IsContain(output_tensors, out_kernel_out_tensor)) {
|
||||
output_tensors.push_back(out_kernel_out_tensor);
|
||||
}
|
||||
}
|
||||
continue;
|
||||
}
|
||||
for (auto outer_out_kernel : outer_out_kernels) {
|
||||
|
@ -239,7 +248,9 @@ std::vector<lite::Tensor *> LiteKernelUtil::SubgraphOutputTensors(const std::vec
|
|||
auto outer_out_kernel_in_tensors_iter =
|
||||
std::find(outer_out_kernel_in_tensors.begin(), outer_out_kernel_in_tensors.end(), out_kernel_out_tensor);
|
||||
if (outer_out_kernel_in_tensors_iter != outer_out_kernel_in_tensors.end()) {
|
||||
output_tensors.emplace_back(out_kernel_out_tensor);
|
||||
if (!lite::IsContain(output_tensors, out_kernel_out_tensor)) {
|
||||
output_tensors.emplace_back(out_kernel_out_tensor);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -50,7 +50,14 @@ bool ConvertNodes(const T &meta_graph, Model *model, int schema_version = SCHEMA
|
|||
node->primitive_ = PrimitiveC::Create(const_cast<schema::Primitive *>(src_prim));
|
||||
#else
|
||||
auto primitive = const_cast<schema::Primitive *>(src_prim);
|
||||
node->primitive_ = OpsRegistry::GetInstance()->getPrimitiveCreator(primitive->value_type())(primitive);
|
||||
auto func_pointer = OpsRegistry::GetInstance()->GetPrimitiveCreator(primitive->value_type());
|
||||
if (func_pointer == nullptr) {
|
||||
MS_LOG(ERROR) << "PrimitiveCreator function pointer is nullptr, type: "
|
||||
<< schema::EnumNamePrimitiveType(primitive->value_type());
|
||||
delete node;
|
||||
return false;
|
||||
}
|
||||
node->primitive_ = func_pointer(primitive);
|
||||
#endif
|
||||
if (node->primitive_ == nullptr) {
|
||||
MS_LOG(ERROR) << "unpack primitive == nullptr!";
|
||||
|
|
|
@ -28,10 +28,10 @@ class OpsRegistry {
|
|||
return ®istry;
|
||||
}
|
||||
|
||||
void insertPrimitiveCMap(schema::PrimitiveType type, PrimitiveCCreator creator) {
|
||||
void InsertPrimitiveCMap(schema::PrimitiveType type, PrimitiveCCreator creator) {
|
||||
primitive_creators[type] = creator;
|
||||
}
|
||||
PrimitiveCCreator getPrimitiveCreator(schema::PrimitiveType type) {
|
||||
PrimitiveCCreator GetPrimitiveCreator(schema::PrimitiveType type) {
|
||||
if (primitive_creators.find(type) != primitive_creators.end()) {
|
||||
return primitive_creators[type];
|
||||
} else {
|
||||
|
@ -47,7 +47,7 @@ class OpsRegistry {
|
|||
class Registry {
|
||||
public:
|
||||
Registry(schema::PrimitiveType primitive_type, PrimitiveCCreator creator) {
|
||||
OpsRegistry::GetInstance()->insertPrimitiveCMap(primitive_type, creator);
|
||||
OpsRegistry::GetInstance()->InsertPrimitiveCMap(primitive_type, creator);
|
||||
}
|
||||
};
|
||||
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
#define LITE_MINDSPORE_LITE_C_OPS_OP_POPULATE_REGISTER_H
|
||||
|
||||
#include <map>
|
||||
#include "src/ops/primitive_c.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
|
@ -29,9 +30,9 @@ class PopulateRegistry {
|
|||
return ®istry;
|
||||
}
|
||||
|
||||
void insertParameterMap(schema::PrimitiveType type, ParameterCreator creator) { parameter_creators[type] = creator; }
|
||||
void InsertParameterMap(schema::PrimitiveType type, ParameterCreator creator) { parameter_creators[type] = creator; }
|
||||
|
||||
ParameterCreator getParameterCreator(schema::PrimitiveType type) {
|
||||
ParameterCreator GetParameterCreator(schema::PrimitiveType type) {
|
||||
if (parameter_creators.find(type) != parameter_creators.end()) {
|
||||
return parameter_creators[type];
|
||||
} else {
|
||||
|
@ -47,7 +48,7 @@ class PopulateRegistry {
|
|||
class Registry {
|
||||
public:
|
||||
Registry(schema::PrimitiveType primitive_type, ParameterCreator creator) {
|
||||
PopulateRegistry::GetInstance()->insertParameterMap(primitive_type, creator);
|
||||
PopulateRegistry::GetInstance()->InsertParameterMap(primitive_type, creator);
|
||||
}
|
||||
~Registry() = default;
|
||||
};
|
||||
|
|
|
@ -244,8 +244,14 @@ const AnfNodePtr ConstFoldPass::Process(const FuncGraphPtr &func_graph, const An
|
|||
auto primitive = lite_primitive.get();
|
||||
MS_ASSERT(primitive != nullptr);
|
||||
MS_ASSERT(primitive->Type() != nullptr);
|
||||
auto parameter =
|
||||
lite::PopulateRegistry::GetInstance()->getParameterCreator(schema::PrimitiveType(primitive->Type()))(primitive);
|
||||
auto func_pointer =
|
||||
lite::PopulateRegistry::GetInstance()->GetParameterCreator(schema::PrimitiveType(primitive->Type()));
|
||||
if (func_pointer == nullptr) {
|
||||
MS_LOG(ERROR) << "ParameterCreator function pointer is nullptr, type: "
|
||||
<< schema::EnumNamePrimitiveType((schema::PrimitiveType)primitive->Type());
|
||||
return nullptr;
|
||||
}
|
||||
auto parameter = func_pointer(primitive);
|
||||
|
||||
if (parameter == nullptr) {
|
||||
MS_LOG(ERROR) << "PopulateParameter return nullptr, type: "
|
||||
|
|
Loading…
Reference in New Issue