From 3af80892beade720e95c4fe9bcb674bc37e79935 Mon Sep 17 00:00:00 2001 From: jianghui58 Date: Thu, 30 Dec 2021 17:36:22 +0800 Subject: [PATCH] fix register bug --- .../lite/include/registry/register_kernel.h | 4 ++- mindspore/lite/src/lite_session.cc | 28 +++++++++++++++++++ mindspore/lite/src/mindrt_executor.cc | 3 ++ .../lite/src/registry/register_kernel.cc | 4 ++- .../lite/src/registry/register_kernel_impl.h | 5 ++++ .../src/registry/registry_custom_op_test.cc | 9 ++++-- mindspore/lite/tools/converter/CMakeLists.txt | 1 + 7 files changed, 49 insertions(+), 5 deletions(-) diff --git a/mindspore/lite/include/registry/register_kernel.h b/mindspore/lite/include/registry/register_kernel.h index 38084cbc61a..6c73317c80c 100644 --- a/mindspore/lite/include/registry/register_kernel.h +++ b/mindspore/lite/include/registry/register_kernel.h @@ -147,7 +147,9 @@ CreateKernel RegisterKernel::GetCreator(const schema::Primitive *primitive, Kern return nullptr; } KernelDescHelper kernel_desc = {desc->data_type, desc->type, StringToChar(desc->arch), StringToChar(desc->provider)}; - return GetCreator(primitive, &kernel_desc); + auto ret = GetCreator(primitive, &kernel_desc); + desc->arch = CharToString(kernel_desc.arch); + return ret; } /// \brief Defined registering macro to register ordinary op kernel, which called by user directly. diff --git a/mindspore/lite/src/lite_session.cc b/mindspore/lite/src/lite_session.cc index d3497d8d592..e28e9773d5f 100644 --- a/mindspore/lite/src/lite_session.cc +++ b/mindspore/lite/src/lite_session.cc @@ -37,6 +37,9 @@ #include "src/weight_decoder.h" #include "src/runtime/runtime_allocator.h" #include "src/lite_kernel_util.h" +#ifndef CUSTOM_KERNEL_REGISTRY_CLIP +#include "src/registry/register_kernel_impl.h" +#endif #ifdef ENABLE_MINDRT #include "src/mindrt_executor.h" #endif @@ -61,6 +64,9 @@ extern void common_log_init(); #endif namespace lite { namespace { +#ifndef CUSTOM_KERNEL_REGISTRY_CLIP +const char *const kArchCPU = "CPU"; +#endif bool NeedBitUppackCheck(const SchemaTensorWrapper &src_tensor) { MS_ASSERT(src_tensor.handler() != nullptr); MS_ASSERT(src_tensor.data() != nullptr); @@ -106,6 +112,23 @@ int DecompressTensor(const SchemaTensorWrapper &src_tensor, Tensor *dst_tensor) #endif } } +#ifndef CUSTOM_KERNEL_REGISTRY_CLIP +bool ExistCustomCpuKernel() { + auto custom_kernel_creators = registry::RegistryKernelImpl::GetInstance()->GetCustomKernelCreators(); + for (const auto &custom_kernel_creator : custom_kernel_creators) { // >> + if (custom_kernel_creator.second.empty()) { + continue; + } + if (std::any_of(custom_kernel_creator.second.begin(), custom_kernel_creator.second.end(), + [](const std::pair> &pair) { + return pair.first == kArchCPU && !pair.second.empty(); + })) { + return true; + } + } + return false; +} +#endif } // namespace LiteSession::LiteSession() { @@ -1351,6 +1374,11 @@ int LiteSession::RuntimeAllocatorInit() { if (RuntimeAllocatorValid() != RET_OK) { return RET_OK; } +#ifndef CUSTOM_KERNEL_REGISTRY_CLIP + if (ExistCustomCpuKernel()) { + return RET_OK; + } +#endif if (runtime_allocator_ == nullptr) { runtime_allocator_ = std::shared_ptr(new (std::nothrow) RuntimeAllocator()); } else { diff --git a/mindspore/lite/src/mindrt_executor.cc b/mindspore/lite/src/mindrt_executor.cc index d7df288f89f..7c463386057 100644 --- a/mindspore/lite/src/mindrt_executor.cc +++ b/mindspore/lite/src/mindrt_executor.cc @@ -208,6 +208,9 @@ void MindrtExecutor::TransferGraphOutput() { reinterpret_cast(dst_tensor->data()), dst_tensor->ElementsNum()); } else { #endif + if (dst_tensor->allocator() != src_tensor->allocator()) { + dst_tensor->set_allocator(src_tensor->allocator()); + } dst_tensor->set_data(src_tensor->data()); if (IS_RUNTIME_ALLOCATOR(src_tensor->allocator()) == false) { src_tensor->set_data(nullptr); diff --git a/mindspore/lite/src/registry/register_kernel.cc b/mindspore/lite/src/registry/register_kernel.cc index bc5f9d883b7..8d018f0bc4f 100644 --- a/mindspore/lite/src/registry/register_kernel.cc +++ b/mindspore/lite/src/registry/register_kernel.cc @@ -50,7 +50,9 @@ CreateKernel RegisterKernel::GetCreator(const schema::Primitive *primitive, Kern return nullptr; } KernelDesc kernel_desc = {desc->data_type, desc->type, CharToString(desc->arch), CharToString(desc->provider)}; - return RegistryKernelImpl::GetInstance()->GetProviderCreator(primitive, &kernel_desc); + auto ret = RegistryKernelImpl::GetInstance()->GetProviderCreator(primitive, &kernel_desc); + desc->arch = StringToChar(kernel_desc.arch); + return ret; #else MS_LOG(ERROR) << unsupport_custom_kernel_register_log; return nullptr; diff --git a/mindspore/lite/src/registry/register_kernel_impl.h b/mindspore/lite/src/registry/register_kernel_impl.h index 666df603c1d..4041b85ac50 100644 --- a/mindspore/lite/src/registry/register_kernel_impl.h +++ b/mindspore/lite/src/registry/register_kernel_impl.h @@ -48,6 +48,11 @@ class RegistryKernelImpl { return kernel_creators_; } + const std::map>> + &GetCustomKernelCreators() const { + return custom_kernel_creators_; + } + protected: // keys:provider, arch std::map> kernel_creators_; diff --git a/mindspore/lite/test/ut/src/registry/registry_custom_op_test.cc b/mindspore/lite/test/ut/src/registry/registry_custom_op_test.cc index fe6a45f3610..c207dc90ecd 100644 --- a/mindspore/lite/test/ut/src/registry/registry_custom_op_test.cc +++ b/mindspore/lite/test/ut/src/registry/registry_custom_op_test.cc @@ -129,12 +129,15 @@ std::shared_ptr TestCustomAddCreator(const std::vector &inputs std::shared_ptr CustomAddInferCreator() { return std::make_shared(); } } // namespace -REGISTER_CUSTOM_KERNEL(CPU, BuiltInTest, kFloat32, Add, TestCustomAddCreator) -REGISTER_CUSTOM_KERNEL_INTERFACE(BuiltInTest, Add, CustomAddInferCreator) - class TestRegistryCustomOp : public mindspore::CommonTest { public: TestRegistryCustomOp() = default; + void SetUp() override { + static mindspore::registry::KernelReg g_CPUBuiltInTestkFloat32AddkernelReg("CPU", "BuiltInTest", kFloat32, "Add", + TestCustomAddCreator); + static mindspore::registry::KernelInterfaceReg g_BuiltInTestAdd_custom_inter_reg("BuiltInTest", "Add", + CustomAddInferCreator); + } }; TEST_F(TestRegistryCustomOp, TestCustomAdd) { diff --git a/mindspore/lite/tools/converter/CMakeLists.txt b/mindspore/lite/tools/converter/CMakeLists.txt index 434c9a04a77..0505ed94d56 100644 --- a/mindspore/lite/tools/converter/CMakeLists.txt +++ b/mindspore/lite/tools/converter/CMakeLists.txt @@ -103,6 +103,7 @@ set(LITE_SRC ${API_SRC} ${SRC_DIR}/ms_tensor.cc ${SRC_DIR}/tensorlist.cc ${SRC_DIR}/kernel_registry.cc + ${SRC_DIR}/registry/register_kernel_impl.cc ${SRC_DIR}/inner_kernel.cc ${SRC_DIR}/lite_kernel.cc ${SRC_DIR}/lite_kernel_util.cc