diff --git a/mindspore/lite/include/context.h b/mindspore/lite/include/context.h index 7d626beef65..0ecdc94f7d5 100644 --- a/mindspore/lite/include/context.h +++ b/mindspore/lite/include/context.h @@ -16,7 +16,7 @@ #ifndef MINDSPORE_LITE_INCLUDE_CONTEXT_H_ #define MINDSPORE_LITE_INCLUDE_CONTEXT_H_ - +#include #include "include/ms_tensor.h" #include "include/lite_utils.h" #include "include/lite_types.h" @@ -57,6 +57,8 @@ union DeviceInfo { struct DeviceContext { DeviceType device_type_ = DT_CPU; DeviceInfo device_info_; + std::string provider_{}; + std::string provider_device_{}; }; /// \brief Context defined for holding environment variables during runtime. diff --git a/mindspore/lite/micro/cmake/file_list.cmake b/mindspore/lite/micro/cmake/file_list.cmake index b0a3bccd4cc..9472c83ded6 100644 --- a/mindspore/lite/micro/cmake/file_list.cmake +++ b/mindspore/lite/micro/cmake/file_list.cmake @@ -134,7 +134,7 @@ set(LITE_SRC ${LITE_DIR}/src/common/prim_util.cc ${LITE_DIR}/src/common/tensor_util.cc ${LITE_DIR}/src/runtime/infer_manager.cc - ${LITE_DIR}/src/kernel_interface_registry.cc + ${LITE_DIR}/src/registry/kernel_interface_registry.cc ${LITE_DIR}/src/lite_model.cc ${LITE_DIR}/src/tensorlist.cc ${LITE_DIR}/src/tensor.cc diff --git a/mindspore/lite/src/CMakeLists.txt b/mindspore/lite/src/CMakeLists.txt index 58c0afbf881..cc6f6706e54 100644 --- a/mindspore/lite/src/CMakeLists.txt +++ b/mindspore/lite/src/CMakeLists.txt @@ -64,9 +64,9 @@ set(LITE_SRC ${CMAKE_CURRENT_SOURCE_DIR}/inner_context.cc ${CMAKE_CURRENT_SOURCE_DIR}/lite_model.cc ${CMAKE_CURRENT_SOURCE_DIR}/kernel_registry.cc - ${CMAKE_CURRENT_SOURCE_DIR}/register_kernel.cc - ${CMAKE_CURRENT_SOURCE_DIR}/kernel_interface.cc - ${CMAKE_CURRENT_SOURCE_DIR}/kernel_interface_registry.cc + ${CMAKE_CURRENT_SOURCE_DIR}/registry/register_kernel.cc + ${CMAKE_CURRENT_SOURCE_DIR}/registry/kernel_interface.cc + ${CMAKE_CURRENT_SOURCE_DIR}/registry/kernel_interface_registry.cc ${CMAKE_CURRENT_SOURCE_DIR}/inner_kernel.cc ${CMAKE_CURRENT_SOURCE_DIR}/lite_kernel.cc ${CMAKE_CURRENT_SOURCE_DIR}/lite_kernel_util.cc diff --git a/mindspore/lite/src/inner_context.cc b/mindspore/lite/src/inner_context.cc index 4e2c036dde0..a5d82dd611b 100644 --- a/mindspore/lite/src/inner_context.cc +++ b/mindspore/lite/src/inner_context.cc @@ -171,6 +171,12 @@ bool InnerContext::IsNpuEnabled() const { #endif } +bool InnerContext::IsProviderEnabled() const { + return this->device_list_.end() != + std::find_if(this->device_list_.begin(), this->device_list_.end(), + [](const DeviceContext &device) { return !device.provider_.empty(); }); +} + bool InnerContext::IsUserSetCpu() const { return this->device_list_.end() != std::find_if(this->device_list_.begin(), this->device_list_.end(), @@ -189,6 +195,16 @@ bool InnerContext::IsUserSetNpu() const { [](const DeviceContext &device) { return device.device_type_ == DT_NPU; }); } +std::set InnerContext::GetProviders() const { + std::set providers; + for (auto &&device : device_list_) { + if (!device.provider_.empty()) { + providers.insert(device.provider_); + } + } + return providers; +} + CpuDeviceInfo InnerContext::GetCpuInfo() const { auto iter = std::find_if(this->device_list_.begin(), this->device_list_.end(), [](const DeviceContext &device) { return device.device_type_ == DT_CPU; }); diff --git a/mindspore/lite/src/inner_context.h b/mindspore/lite/src/inner_context.h index 8a84ae6fccb..94fa0216f0d 100644 --- a/mindspore/lite/src/inner_context.h +++ b/mindspore/lite/src/inner_context.h @@ -16,7 +16,8 @@ #ifndef MINDSPORE_LITE_SRC_INNER_CONTEXT_H #define MINDSPORE_LITE_SRC_INNER_CONTEXT_H - +#include +#include #include "include/context.h" #include "src/runtime/runtime_api.h" #include "src/runtime/allocator.h" @@ -48,6 +49,10 @@ struct InnerContext : public Context { bool IsNpuEnabled() const; + bool IsProviderEnabled() const; + + std::set GetProviders() const; + CpuDeviceInfo GetCpuInfo() const; GpuDeviceInfo GetGpuInfo() const; diff --git a/mindspore/lite/src/kernel_registry.cc b/mindspore/lite/src/kernel_registry.cc index 19230ce0b79..9b9b2f1471d 100644 --- a/mindspore/lite/src/kernel_registry.cc +++ b/mindspore/lite/src/kernel_registry.cc @@ -55,17 +55,6 @@ KernelRegistry *KernelRegistry::GetInstance() { return &instance; } -std::set KernelRegistry::AllProviders() { - std::set providers; - for (auto &&item : kernel_creators_) { - providers.insert(item.first); - } - for (auto &&item : custom_kernel_creators_) { - providers.insert(item.first); - } - return providers; -} - int KernelRegistry::GetFuncIndex(const kernel::KernelKey &desc) { if (desc.data_type >= kNumberTypeEnd) { return -1; @@ -166,12 +155,20 @@ kernel::CreateKernel KernelRegistry::GetProviderCreator(const kernel::KernelKey MS_ASSERT(param != nullptr); auto custom_type = param->type()->str(); auto archs = custom_kernel_creators_[desc.provider]; - auto archs_iter = std::find_if(archs.begin(), archs.end(), [custom_type, data_type_index](auto &&item) { - return item.second[custom_type] != nullptr && item.second[custom_type][data_type_index] != nullptr; - }); - if (archs_iter != archs.end()) { - return archs_iter->second[custom_type][data_type_index]; + if (desc.kernel_arch.empty()) { + auto archs_iter = std::find_if(archs.begin(), archs.end(), [custom_type, data_type_index](auto &&item) { + return item.second[custom_type] != nullptr && item.second[custom_type][data_type_index] != nullptr; + }); + if (archs_iter != archs.end()) { + return archs_iter->second[custom_type][data_type_index]; + } + } else { + auto find_arch_it = archs.find(desc.kernel_arch); + if (find_arch_it != archs.end()) { + return find_arch_it->second[custom_type][data_type_index]; + } } + return nullptr; } auto index = GetFuncIndex(desc); diff --git a/mindspore/lite/src/kernel_registry.h b/mindspore/lite/src/kernel_registry.h index 66f94d52dfa..b485c1c0b9d 100644 --- a/mindspore/lite/src/kernel_registry.h +++ b/mindspore/lite/src/kernel_registry.h @@ -23,7 +23,7 @@ #include #include #include "src/lite_kernel.h" -#include "src/register_kernel.h" +#include "src/registry/register_kernel.h" #include "schema/model_generated.h" using mindspore::kernel::kKernelArch_MAX; @@ -43,7 +43,6 @@ class KernelRegistry { virtual kernel::CreateKernel GetProviderCreator(const kernel::KernelKey &desc, const schema::Primitive *prim); int GetCreatorFuncIndex(kernel::KernelKey desc); int GetFuncIndex(const kernel::KernelKey &desc); - std::set AllProviders(); int RegCustomKernel(const std::string &arch, const std::string &vendor, TypeId data_type, const std::string &type, kernel::CreateKernel creator); void RegKernel(kernel::KernelKey desc, kernel::KernelCreator creator); diff --git a/mindspore/lite/src/kernel_interface.cc b/mindspore/lite/src/registry/kernel_interface.cc similarity index 93% rename from mindspore/lite/src/kernel_interface.cc rename to mindspore/lite/src/registry/kernel_interface.cc index 9c92cc771b5..6a964440309 100644 --- a/mindspore/lite/src/kernel_interface.cc +++ b/mindspore/lite/src/registry/kernel_interface.cc @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "src/kernel_interface.h" -#include "src/kernel_interface_registry.h" +#include "src/registry/kernel_interface.h" +#include "src/registry/kernel_interface_registry.h" namespace mindspore { namespace kernel { diff --git a/mindspore/lite/src/kernel_interface.h b/mindspore/lite/src/registry/kernel_interface.h similarity index 100% rename from mindspore/lite/src/kernel_interface.h rename to mindspore/lite/src/registry/kernel_interface.h diff --git a/mindspore/lite/src/kernel_interface_registry.cc b/mindspore/lite/src/registry/kernel_interface_registry.cc similarity index 93% rename from mindspore/lite/src/kernel_interface_registry.cc rename to mindspore/lite/src/registry/kernel_interface_registry.cc index 7aac3261355..db7eebfa959 100644 --- a/mindspore/lite/src/kernel_interface_registry.cc +++ b/mindspore/lite/src/registry/kernel_interface_registry.cc @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "src/kernel_interface_registry.h" -#include "src/kernel_interface.h" +#include "src/registry/kernel_interface_registry.h" +#include "src/registry/kernel_interface.h" #include "include/errorcode.h" #include "src/common/log_adapter.h" #include "src/common/version_manager.h" @@ -34,7 +34,7 @@ std::string GetCustomType(const schema::Primitive *primitive) { } } // namespace -bool KernelInterfaceRegistry::CheckReg(const lite::Model::Node *node) { +bool KernelInterfaceRegistry::CheckReg(const lite::Model::Node *node, std::set &&providers) { if (VersionManager::GetInstance()->GetSchemaVersion() == SCHEMA_V0) { return false; } @@ -46,7 +46,10 @@ bool KernelInterfaceRegistry::CheckReg(const lite::Model::Node *node) { auto op_type = primitive->value_type(); if (op_type == schema::PrimitiveType_Custom) { auto &&custom_type = GetCustomType(primitive); - return std::any_of(custom_creators_.begin(), custom_creators_.end(), [&custom_type](auto &&item) { + return std::any_of(custom_creators_.begin(), custom_creators_.end(), [&custom_type, &providers](auto &&item) { + if (providers.find(item.first) == providers.end()) { + return false; + } if (item.second[custom_type] != nullptr) { return true; } @@ -156,17 +159,6 @@ int KernelInterfaceRegistry::Reg(const std::string &provider, int op_type, Kerne return RET_OK; } -std::set KernelInterfaceRegistry::AllProviders() { - std::set providers; - for (auto &&item : kernel_creators_) { - providers.insert(item.first); - } - for (auto &&item : custom_creators_) { - providers.insert(item.first); - } - return providers; -} - KernelInterfaceRegistry::~KernelInterfaceRegistry() { for (auto &&item : kernel_creators_) { free(item.second); diff --git a/mindspore/lite/src/kernel_interface_registry.h b/mindspore/lite/src/registry/kernel_interface_registry.h similarity index 94% rename from mindspore/lite/src/kernel_interface_registry.h rename to mindspore/lite/src/registry/kernel_interface_registry.h index 2191f448e8c..bc9489b7850 100644 --- a/mindspore/lite/src/kernel_interface_registry.h +++ b/mindspore/lite/src/registry/kernel_interface_registry.h @@ -21,7 +21,7 @@ #include #include #include -#include "src/kernel_interface.h" +#include "src/registry/kernel_interface.h" #include "include/model.h" namespace mindspore { @@ -32,11 +32,10 @@ class KernelInterfaceRegistry { static KernelInterfaceRegistry instance; return &instance; } - bool CheckReg(const lite::Model::Node *node); + bool CheckReg(const lite::Model::Node *node, std::set &&providers); kernel::KernelInterface *GetKernelInterface(const std::string &provider, const schema::Primitive *primitive); int CustomReg(const std::string &provider, const std::string &op_type, kernel::KernelInterfaceCreator creator); int Reg(const std::string &provider, int op_type, kernel::KernelInterfaceCreator creator); - std::set AllProviders(); virtual ~KernelInterfaceRegistry(); private: diff --git a/mindspore/lite/src/register_kernel.cc b/mindspore/lite/src/registry/register_kernel.cc similarity index 97% rename from mindspore/lite/src/register_kernel.cc rename to mindspore/lite/src/registry/register_kernel.cc index 97f331ae48f..855f264402c 100644 --- a/mindspore/lite/src/register_kernel.cc +++ b/mindspore/lite/src/registry/register_kernel.cc @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "src/register_kernel.h" +#include "src/registry/register_kernel.h" #include "src/kernel_registry.h" namespace mindspore { diff --git a/mindspore/lite/src/register_kernel.h b/mindspore/lite/src/registry/register_kernel.h similarity index 100% rename from mindspore/lite/src/register_kernel.h rename to mindspore/lite/src/registry/register_kernel.h diff --git a/mindspore/lite/src/runtime/infer_manager.cc b/mindspore/lite/src/runtime/infer_manager.cc index d6ee66503c0..825fb2b2165 100644 --- a/mindspore/lite/src/runtime/infer_manager.cc +++ b/mindspore/lite/src/runtime/infer_manager.cc @@ -15,24 +15,26 @@ */ #include "src/runtime/infer_manager.h" #include +#include +#include #include "src/common/prim_util.h" #include "src/common/tensor_util.h" #include "schema/model_generated.h" #include "include/errorcode.h" #include "nnacl/errorcode.h" #include "src/tensorlist.h" -#include "src/kernel_interface_registry.h" +#include "src/registry/kernel_interface_registry.h" #include "src/kernel_registry.h" namespace mindspore { namespace lite { int KernelInferShape(const std::vector &inputs, const std::vector &outputs, - const void *primitive) { + const void *primitive, std::set &&providers) { std::vector in_tensors; std::copy(inputs.begin(), inputs.end(), std::back_inserter(in_tensors)); std::vector out_tensors; std::copy(outputs.begin(), outputs.end(), std::back_inserter(out_tensors)); - for (auto &&provider : KernelInterfaceRegistry::Instance()->AllProviders()) { + for (auto &&provider : providers) { auto kernel_interface = KernelInterfaceRegistry::Instance()->GetKernelInterface( provider, static_cast(primitive)); if (kernel_interface == nullptr) { diff --git a/mindspore/lite/src/runtime/infer_manager.h b/mindspore/lite/src/runtime/infer_manager.h index c132aecc722..cc21fbd9b82 100644 --- a/mindspore/lite/src/runtime/infer_manager.h +++ b/mindspore/lite/src/runtime/infer_manager.h @@ -19,6 +19,8 @@ #include #include +#include +#include #include "src/common/prim_util.h" #include "src/common/common.h" #include "nnacl/tensor_c.h" @@ -28,7 +30,7 @@ namespace mindspore::lite { int KernelInferShape(const std::vector &tensors_in, const std::vector &outputs, OpParameter *parameter); int KernelInferShape(const std::vector &inputs, const std::vector &outputs, - const void *primitive); + const void *primitive, std::set &&providers); class InferManager { public: static InferManager *GetInstance() { diff --git a/mindspore/lite/src/scheduler.cc b/mindspore/lite/src/scheduler.cc index bf9afe84fec..63778cc9821 100644 --- a/mindspore/lite/src/scheduler.cc +++ b/mindspore/lite/src/scheduler.cc @@ -48,7 +48,7 @@ #if defined(ENABLE_ARM) && defined(ENABLE_FP16) #include "src/runtime/kernel/arm/fp16/fp16_op_handler.h" #endif -#include "src/kernel_interface_registry.h" +#include "src/registry/kernel_interface_registry.h" namespace mindspore::lite { using kernel::KERNEL_ARCH::kCPU; @@ -131,8 +131,8 @@ int Scheduler::InferNodeShape(const lite::Model::Node *node) { std::vector inputs; std::vector outputs; FindNodeInoutTensors(*node, &inputs, &outputs); - if (KernelInterfaceRegistry::Instance()->CheckReg(node)) { - return KernelInferShape(inputs, outputs, node->primitive_); + if (KernelInterfaceRegistry::Instance()->CheckReg(node, context_->GetProviders())) { + return KernelInferShape(inputs, outputs, node->primitive_, context_->GetProviders()); } int schema_version = VersionManager::GetInstance()->GetSchemaVersion(); @@ -442,20 +442,26 @@ int Scheduler::FindNpuKernel(const std::vector &in_tensors, const std: int Scheduler::FindProviderKernel(const std::vector &in_tensors, const std::vector &out_tensors, const Model::Node *node, TypeId data_type, kernel::LiteKernel **kernel) { MS_ASSERT(kernel != nullptr); - int ret = RET_ERROR; - auto &&providers = KernelRegistry::GetInstance()->AllProviders(); + int ret = RET_NOT_SUPPORT; + if (!context_->IsProviderEnabled()) { + return ret; + } if (VersionManager::GetInstance()->GetSchemaVersion() == SCHEMA_V0) { return ret; } - for (auto &&provider : providers) { - kernel::KernelKey desc{kCPU, data_type, GetPrimitiveType(node->primitive_), "", provider}; - ret = KernelRegistry::GetInstance()->GetKernel(in_tensors, out_tensors, context_, desc, nullptr, kernel, - node->primitive_); - if (ret == RET_OK && *kernel != nullptr) { - return ret; + for (auto &&device : context_->device_list_) { + if (!device.provider_.empty()) { + kernel::KernelKey desc{kCPU, data_type, GetPrimitiveType(node->primitive_), device.provider_device_, + device.provider_}; + ret = KernelRegistry::GetInstance()->GetKernel(in_tensors, out_tensors, context_, desc, nullptr, kernel, + node->primitive_); + if (ret == RET_OK && *kernel != nullptr) { + return ret; + } } } - return ret; + + return RET_NOT_SUPPORT; } kernel::LiteKernel *Scheduler::FindBackendKernel(const std::vector &in_tensors, diff --git a/mindspore/lite/test/CMakeLists.txt b/mindspore/lite/test/CMakeLists.txt index 0a5ac862769..19214ed588a 100644 --- a/mindspore/lite/test/CMakeLists.txt +++ b/mindspore/lite/test/CMakeLists.txt @@ -141,10 +141,10 @@ set(TEST_LITE_SRC ${LITE_DIR}/src/tensorlist.cc ${LITE_DIR}/src/executor.cc ${LITE_DIR}/src/inner_context.cc - ${LITE_DIR}/src/kernel_interface.cc - ${LITE_DIR}/src/kernel_interface_registry.cc + ${LITE_DIR}/src/registry/kernel_interface.cc + ${LITE_DIR}/src/registry/kernel_interface_registry.cc ${LITE_DIR}/src/kernel_registry.cc - ${LITE_DIR}/src/register_kernel.cc + ${LITE_DIR}/src/registry/register_kernel.cc ${LITE_DIR}/src/inner_kernel.cc ${LITE_DIR}/src/lite_kernel.cc ${LITE_DIR}/src/lite_kernel_util.cc diff --git a/mindspore/lite/tools/converter/CMakeLists.txt b/mindspore/lite/tools/converter/CMakeLists.txt index a84371bae7c..9e0af408c8f 100644 --- a/mindspore/lite/tools/converter/CMakeLists.txt +++ b/mindspore/lite/tools/converter/CMakeLists.txt @@ -124,9 +124,9 @@ set(LITE_SRC ${SRC_DIR}/tensor.cc ${SRC_DIR}/ms_tensor.cc ${SRC_DIR}/tensorlist.cc - ${SRC_DIR}/kernel_interface_registry.cc + ${SRC_DIR}/registry/kernel_interface_registry.cc ${SRC_DIR}/kernel_registry.cc - ${SRC_DIR}/register_kernel.cc + ${SRC_DIR}/registry/register_kernel.cc ${SRC_DIR}/inner_kernel.cc ${SRC_DIR}/lite_kernel.cc ${SRC_DIR}/lite_kernel_util.cc