adjust regist policies
This commit is contained in:
parent
85ba75565b
commit
a3b4ca694f
12
build.sh
12
build.sh
|
@ -726,6 +726,8 @@ build_lite_java_arm64() {
|
|||
if [[ "X$SUPPORT_TRAIN" = "Xon" ]]; then
|
||||
cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/lib/libmindspore-lite.so ${JAVA_PATH}/java/app/libs/arm64-v8a/
|
||||
cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/lib/libmindspore-lite.so ${JAVA_PATH}/native/libs/arm64-v8a/
|
||||
cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/lib/libmslite_kernel_reg.so ${JAVA_PATH}/java/app/libs/arm64-v8a/
|
||||
cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/lib/libmslite_kernel_reg.so ${JAVA_PATH}/native/libs/arm64-v8a/
|
||||
|
||||
cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/lib/libminddata-lite.so ${JAVA_PATH}/java/app/libs/arm64-v8a/
|
||||
cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/lib/libminddata-lite.so ${JAVA_PATH}/native/libs/arm64-v8a/
|
||||
|
@ -735,6 +737,8 @@ build_lite_java_arm64() {
|
|||
else
|
||||
cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/inference/lib/libmindspore-lite.so ${JAVA_PATH}/java/app/libs/arm64-v8a/
|
||||
cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/inference/lib/libmindspore-lite.so ${JAVA_PATH}/native/libs/arm64-v8a/
|
||||
cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/inference/lib/libmslite_kernel_reg.so ${JAVA_PATH}/java/app/libs/arm64-v8a/
|
||||
cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/inference/lib/libmslite_kernel_reg.so ${JAVA_PATH}/native/libs/arm64-v8a/
|
||||
fi
|
||||
[ -n "${VERSION_STR}" ] && rm -rf ${JTARBALL}
|
||||
}
|
||||
|
@ -758,6 +762,8 @@ build_lite_java_arm32() {
|
|||
if [[ "X$SUPPORT_TRAIN" = "Xon" ]]; then
|
||||
cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/lib/libmindspore-lite.so ${JAVA_PATH}/java/app/libs/armeabi-v7a/
|
||||
cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/lib/libmindspore-lite.so ${JAVA_PATH}/native/libs/armeabi-v7a/
|
||||
cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/lib/libmslite_kernel_reg.so ${JAVA_PATH}/java/app/libs/armeabi-v7a/
|
||||
cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/lib/libmslite_kernel_reg.so ${JAVA_PATH}/native/libs/armeabi-v7a/
|
||||
|
||||
cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/lib/libminddata-lite.so ${JAVA_PATH}/java/app/libs/armeabi-v7a/
|
||||
cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/lib/libminddata-lite.so ${JAVA_PATH}/native/libs/armeabi-v7a/
|
||||
|
@ -767,6 +773,8 @@ build_lite_java_arm32() {
|
|||
else
|
||||
cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/inference/lib/libmindspore-lite.so ${JAVA_PATH}/java/app/libs/armeabi-v7a/
|
||||
cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/inference/lib/libmindspore-lite.so ${JAVA_PATH}/native/libs/armeabi-v7a/
|
||||
cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/inference/lib/libmslite_kernel_reg.so ${JAVA_PATH}/java/app/libs/armeabi-v7a/
|
||||
cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/inference/lib/libmslite_kernel_reg.so ${JAVA_PATH}/native/libs/armeabi-v7a/
|
||||
fi
|
||||
[ -n "${VERSION_STR}" ] && rm -rf ${JTARBALL}
|
||||
}
|
||||
|
@ -791,6 +799,8 @@ build_lite_java_x86() {
|
|||
if [[ "X$SUPPORT_TRAIN" = "Xon" ]]; then
|
||||
cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/lib/libmindspore-lite.so ${JAVA_PATH}/java/linux_x86/libs/
|
||||
cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/lib/libmindspore-lite.so ${JAVA_PATH}/native/libs/linux_x86/
|
||||
cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/lib/libmslite_kernel_reg.so ${JAVA_PATH}/java/linux_x86/libs/
|
||||
cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/lib/libmslite_kernel_reg.so ${JAVA_PATH}/native/libs/linux_x86/
|
||||
|
||||
cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/lib/libminddata-lite.so ${JAVA_PATH}/java/linux_x86/libs/
|
||||
cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/lib/libminddata-lite.so ${JAVA_PATH}/native/libs/linux_x86/
|
||||
|
@ -800,6 +810,8 @@ build_lite_java_x86() {
|
|||
else
|
||||
cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/inference/lib/libmindspore-lite.so ${JAVA_PATH}/java/linux_x86/libs/
|
||||
cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/inference/lib/libmindspore-lite.so ${JAVA_PATH}/native/libs/linux_x86/
|
||||
cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/inference/lib/libmslite_kernel_reg.so ${JAVA_PATH}/java/linux_x86/libs/
|
||||
cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/inference/lib/libmslite_kernel_reg.so ${JAVA_PATH}/native/libs/linux_x86/
|
||||
fi
|
||||
[ -n "${VERSION_STR}" ] && rm -rf ${JTARBALL}
|
||||
}
|
||||
|
|
|
@ -169,6 +169,8 @@ if(PLATFORM_ARM64)
|
|||
COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||
install(FILES ${TOP_DIR}/mindspore/lite/build/src/${MINDSPORE_LITE_LIB_NAME}.a DESTINATION ${RUNTIME_LIB_DIR}
|
||||
COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||
install(FILES ${TOP_DIR}/mindspore/lite/build/src/registry/libmslite_kernel_reg.so
|
||||
DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||
install(FILES ${TOP_DIR}/mindspore/core/ir/dtype/type_id.h DESTINATION ${RUNTIME_INC_DIR}/ir/dtype
|
||||
COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||
install(DIRECTORY ${TOP_DIR}/include/api/ DESTINATION ${RUNTIME_INC_DIR}/api
|
||||
|
@ -201,6 +203,8 @@ elseif(PLATFORM_ARM32)
|
|||
COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||
install(FILES ${TOP_DIR}/mindspore/lite/build/src/${MINDSPORE_LITE_LIB_NAME}.a DESTINATION ${RUNTIME_LIB_DIR}
|
||||
COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||
install(FILES ${TOP_DIR}/mindspore/lite/build/src/registry/libmslite_kernel_reg.so
|
||||
DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||
install(FILES ${TOP_DIR}/mindspore/core/ir/dtype/type_id.h DESTINATION ${RUNTIME_INC_DIR}/ir/dtype
|
||||
COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||
install(DIRECTORY ${TOP_DIR}/include/api/ DESTINATION ${RUNTIME_INC_DIR}/api
|
||||
|
@ -231,6 +235,8 @@ elseif(WIN32)
|
|||
install(FILES ${LIB_LIST} DESTINATION ${CONVERTER_ROOT_DIR}/lib COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||
install(FILES ${TOP_DIR}/build/mindspore/tools/converter/registry/libmslite_converter_plugin_reg.dll
|
||||
DESTINATION ${CONVERTER_ROOT_DIR}/lib COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||
install(FILES ${TOP_DIR}/build/mindspore/src/registry/libmslite_kernel_reg.dll
|
||||
DESTINATION ${CONVERTER_ROOT_DIR}/lib COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||
install(FILES ${glog_LIBPATH}/../bin/libglog.dll DESTINATION ${CONVERTER_ROOT_DIR}/lib
|
||||
COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||
install(FILES ${NNACL_FILES} DESTINATION ${CODEGEN_ROOT_DIR}/include/nnacl COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||
|
@ -279,6 +285,8 @@ elseif(WIN32)
|
|||
COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||
install(FILES ${TOP_DIR}/build/mindspore/src/${MINDSPORE_LITE_LIB_NAME}.dll DESTINATION ${RUNTIME_LIB_DIR}
|
||||
COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||
install(FILES ${TOP_DIR}/build/mindspore/src/registry/libmslite_kernel_reg.dll DESTINATION ${RUNTIME_LIB_DIR}
|
||||
COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||
else()
|
||||
if(SUPPORT_TRAIN)
|
||||
install(DIRECTORY ${TOP_DIR}/mindspore/lite/include/ DESTINATION ${RUNTIME_INC_DIR}
|
||||
|
@ -295,11 +303,15 @@ else()
|
|||
COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||
install(FILES ${TOP_DIR}/mindspore/lite/build/src/${MINDSPORE_LITE_LIB_NAME}.a DESTINATION ${RUNTIME_LIB_DIR}
|
||||
COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||
install(FILES ${TOP_DIR}/mindspore/lite/build/src/registry/libmslite_kernel_reg.so DESTINATION ${RUNTIME_LIB_DIR}
|
||||
COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||
if(ENABLE_CONVERTER)
|
||||
install(TARGETS converter_lite RUNTIME DESTINATION ${CONVERTER_ROOT_DIR}/converter
|
||||
COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||
install(FILES ${TOP_DIR}/mindspore/lite/build/tools/converter/registry/libmslite_converter_plugin_reg.so
|
||||
DESTINATION ${CONVERTER_ROOT_DIR}/lib COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||
install(FILES ${TOP_DIR}/mindspore/lite/build/src/registry/libmslite_kernel_reg.so
|
||||
DESTINATION ${CONVERTER_ROOT_DIR}/lib COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||
install(FILES ${glog_LIBPATH}/libglog.so.0.4.0 DESTINATION ${CONVERTER_ROOT_DIR}/lib RENAME libglog.so.0
|
||||
COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||
file(GLOB NNACL_FILES GLOB ${NNACL_DIR}/*.h)
|
||||
|
|
|
@ -291,6 +291,7 @@ if(BUILD_MINDDATA STREQUAL "lite_cv")
|
|||
endif()
|
||||
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/src)
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/src/registry)
|
||||
add_subdirectory(${CCSRC_DIR}/backend/kernel_compiler/cpu/nnacl build)
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/micro/coder)
|
||||
if(NOT APPLE AND ENABLE_TOOLS)
|
||||
|
|
|
@ -133,7 +133,10 @@ set(LITE_SRC
|
|||
${LITE_DIR}/src/common/prim_util.cc
|
||||
${LITE_DIR}/src/common/tensor_util.cc
|
||||
${LITE_DIR}/src/runtime/infer_manager.cc
|
||||
${LITE_DIR}/src/registry/kernel_interface.cc
|
||||
${LITE_DIR}/src/registry/kernel_interface_registry.cc
|
||||
${LITE_DIR}/src/registry/register_kernel.cc
|
||||
${LITE_DIR}/src/registry/register_kernel_impl.cc
|
||||
${LITE_DIR}/src/lite_model.cc
|
||||
${LITE_DIR}/src/tensorlist.cc
|
||||
${LITE_DIR}/src/tensor.cc
|
||||
|
|
|
@ -47,7 +47,6 @@ set(LITE_SRC
|
|||
${CMAKE_CURRENT_SOURCE_DIR}/common/utils.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/common/graph_util.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/common/dynamic_library_loader.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/common/log_adapter.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/common/string_util.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/common/prim_util.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/common/tensor_util.cc
|
||||
|
@ -64,9 +63,6 @@ set(LITE_SRC
|
|||
${CMAKE_CURRENT_SOURCE_DIR}/inner_context.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/lite_model.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/kernel_registry.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/registry/register_kernel.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/registry/kernel_interface.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/registry/kernel_interface_registry.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/inner_kernel.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/lite_kernel.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/lite_kernel_util.cc
|
||||
|
@ -192,10 +188,6 @@ if(SUPPORT_NPU)
|
|||
target_link_libraries(mindspore-lite npu_kernel_mid)
|
||||
target_link_libraries(mindspore-lite_static npu_kernel_mid)
|
||||
endif()
|
||||
if(PLATFORM_ARM32 OR PLATFORM_ARM64)
|
||||
target_link_libraries(mindspore-lite log)
|
||||
target_link_libraries(mindspore-lite_static log)
|
||||
endif()
|
||||
if(BUILD_MINDDATA STREQUAL "lite")
|
||||
target_link_libraries(mindspore-lite minddata_eager_mid minddata-lite)
|
||||
target_link_libraries(mindspore-lite_static minddata_eager_mid)
|
||||
|
@ -217,7 +209,8 @@ if(NOT APPLE AND "${CMAKE_BUILD_TYPE}" STREQUAL "Release" AND PLATFORM_ARM)
|
|||
add_custom_command(TARGET mindspore-lite POST_BUILD COMMAND ${NDK_STRIP}
|
||||
${CMAKE_BINARY_DIR}/src/libmindspore-lite*.so)
|
||||
endif()
|
||||
|
||||
target_link_libraries(mindspore-lite mslite_kernel_reg)
|
||||
target_link_libraries(mindspore-lite_static mslite_kernel_reg)
|
||||
if("${CMAKE_BUILD_TYPE}" STREQUAL "Release")
|
||||
if(NOT APPLE AND PLATFORM_ARM)
|
||||
add_custom_command(TARGET mindspore-lite POST_BUILD COMMAND ${NDK_STRIP}
|
||||
|
|
|
@ -32,7 +32,6 @@ enum CHWK_SHAPE { CHWK_C = 0, CHWK_H = 1, CHWK_W = 2, CHWK_K = 3 };
|
|||
enum KHWC_SHAPE { KHWC_K = 0, KHWC_H = 1, KHWC_W = 2, KHWC_C = 3 };
|
||||
enum CHW_SHAPE { CHW_C = 0, CHW_H = 1, CHW_W = 2 };
|
||||
enum HWC_SHAPE { HWC_H = 0, HWC_W = 1, HWC_C = 2 };
|
||||
enum SCHEMA_VERSION : int { SCHEMA_INVALID = -1, SCHEMA_CUR = 0, SCHEMA_V0 = 1 };
|
||||
static constexpr int kNCHWDimNumber = 4;
|
||||
static constexpr int kNHWCDimNumber = 4;
|
||||
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
|
||||
#include "src/common/prim_util.h"
|
||||
#include "src/common/version_manager.h"
|
||||
#include "src/common/log_adapter.h"
|
||||
#include "schema/model_generated.h"
|
||||
#ifdef ENABLE_V0
|
||||
#include "schema/model_v0_generated.h"
|
||||
|
|
|
@ -234,9 +234,5 @@ int CheckTensorsInvalid(const std::vector<Tensor *> &tensors) {
|
|||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
void Tensor2MSTensor(const std::vector<Tensor *> &&tensors, std::vector<tensor::MSTensor *> *out_tensors) {
|
||||
std::copy(tensors.begin(), tensors.end(), std::back_inserter(*out_tensors));
|
||||
}
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -40,7 +40,6 @@ int GenerateOutTensorC(const OpParameter *const parameter, const std::vector<lit
|
|||
const std::vector<lite::Tensor *> &outputs, std::vector<TensorC *> *out_tensor_c);
|
||||
|
||||
int CheckTensorsInvalid(const std::vector<Tensor *> &tensors);
|
||||
void Tensor2MSTensor(const std::vector<Tensor *> &&tensors, std::vector<tensor::MSTensor *> *out_tensors);
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -18,9 +18,10 @@
|
|||
#define MINDSPORE_LITE_SRC_COMMON_VERSION_MANAGER_H_
|
||||
|
||||
#include <string>
|
||||
#include "src/common/common.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
enum SCHEMA_VERSION : int { SCHEMA_INVALID = -1, SCHEMA_CUR = 0, SCHEMA_V0 = 1 };
|
||||
class VersionManager {
|
||||
public:
|
||||
static VersionManager *GetInstance() {
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
#include "src/kernel_registry.h"
|
||||
#include <utility>
|
||||
#include "include/errorcode.h"
|
||||
#include "src/registry/register_kernel.h"
|
||||
#include "src/ops/populate/populate_register.h"
|
||||
#include "src/common/version_manager.h"
|
||||
#include "nnacl/pooling_parameter.h"
|
||||
|
@ -38,7 +39,13 @@ using mindspore::kernel::KernelKey;
|
|||
|
||||
namespace mindspore::lite {
|
||||
namespace {
|
||||
static const int kKernelMaxNum = (kNumberTypeEnd - kNumberTypeBegin - 1) * (PrimitiveType_MAX - PrimitiveType_MIN);
|
||||
void KernelKeyToKernelDesc(const KernelKey &key, kernel::KernelDesc *desc) {
|
||||
MS_ASSERT(desc != nullptr);
|
||||
desc->data_type = key.data_type;
|
||||
desc->type = key.type;
|
||||
desc->arch = key.kernel_arch;
|
||||
desc->provider = key.provider;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
KernelRegistry *KernelRegistry::GetInstance() {
|
||||
|
@ -55,77 +62,6 @@ KernelRegistry *KernelRegistry::GetInstance() {
|
|||
return &instance;
|
||||
}
|
||||
|
||||
int KernelRegistry::GetFuncIndex(const kernel::KernelKey &desc) {
|
||||
if (desc.data_type >= kNumberTypeEnd) {
|
||||
return -1;
|
||||
}
|
||||
int data_type_index = static_cast<int>(desc.data_type) - kNumberTypeBegin - 1;
|
||||
if (data_type_index < 0) {
|
||||
return -1;
|
||||
}
|
||||
return data_type_index * op_type_length_ + desc.type;
|
||||
}
|
||||
|
||||
int KernelRegistry::RegCustomKernel(const std::string &arch, const std::string &provider, TypeId data_type,
|
||||
const std::string &type, CreateKernel creator) {
|
||||
if (data_type >= kNumberTypeEnd) {
|
||||
MS_LOG(ERROR) << "invalid data_type: " << data_type << "!provider: " << provider;
|
||||
return RET_ERROR;
|
||||
}
|
||||
std::unique_lock<std::mutex> lock(lock_);
|
||||
if (custom_kernel_creators_[provider][arch][type] == nullptr) {
|
||||
custom_kernel_creators_[provider][arch][type] =
|
||||
reinterpret_cast<CreateKernel *>(malloc(data_type_length_ * sizeof(CreateKernel)));
|
||||
if (custom_kernel_creators_[provider][arch][type] == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc custom kernel creator fail!provider: " << provider << ", arch: " << arch;
|
||||
return RET_ERROR;
|
||||
}
|
||||
memset(custom_kernel_creators_[provider][arch][type], 0, data_type_length_ * sizeof(CreateKernel));
|
||||
}
|
||||
|
||||
int data_type_index = data_type - kNumberTypeBegin - 1;
|
||||
if (data_type_index < 0 || data_type_index >= data_type_length_) {
|
||||
MS_LOG(ERROR) << "invalid data_type: " << data_type << "!provider: " << provider;
|
||||
return RET_ERROR;
|
||||
}
|
||||
custom_kernel_creators_[provider][arch][type][data_type_index] = creator;
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int KernelRegistry::RegKernel(const std::string &arch, const std::string &provider, TypeId data_type, int type,
|
||||
kernel::CreateKernel creator) {
|
||||
std::unique_lock<std::mutex> lock(lock_);
|
||||
auto iter = kernel_creators_.find(provider);
|
||||
if (iter == kernel_creators_.end()) {
|
||||
kernel_creators_[provider][arch] = reinterpret_cast<CreateKernel *>(malloc(kKernelMaxNum * sizeof(CreateKernel)));
|
||||
if (kernel_creators_[provider][arch] == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc kernel creator buffer fail! provider: " << provider << ",arch:" << arch;
|
||||
return RET_ERROR;
|
||||
}
|
||||
memset(kernel_creators_[provider][arch], 0, kKernelMaxNum * sizeof(CreateKernel));
|
||||
} else {
|
||||
auto iter_arch = iter->second.find(arch);
|
||||
if (iter_arch == iter->second.end()) {
|
||||
iter->second[arch] = reinterpret_cast<CreateKernel *>(malloc(kKernelMaxNum * sizeof(CreateKernel)));
|
||||
if (iter->second[arch] == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc kernel creator buffer fail! provider: " << provider << ",arch:" << arch;
|
||||
return RET_ERROR;
|
||||
}
|
||||
memset(iter->second[arch], 0, kKernelMaxNum * sizeof(CreateKernel));
|
||||
}
|
||||
}
|
||||
|
||||
KernelKey desc = {kCPU, data_type, type, arch, provider};
|
||||
int index = GetFuncIndex(desc);
|
||||
if (index >= kKernelMaxNum || index < 0) {
|
||||
MS_LOG(ERROR) << "invalid kernel key, arch " << arch << ", data_type" << data_type << ",op type " << type;
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
kernel_creators_[provider][arch][index] = creator;
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int KernelRegistry::Init() { return RET_OK; }
|
||||
|
||||
kernel::KernelCreator KernelRegistry::GetCreator(const KernelKey &desc) {
|
||||
|
@ -142,53 +78,6 @@ kernel::KernelCreator KernelRegistry::GetCreator(const KernelKey &desc) {
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
kernel::CreateKernel KernelRegistry::GetProviderCreator(const kernel::KernelKey &desc,
|
||||
const schema::Primitive *primitive) {
|
||||
kernel::CreateKernel creator = nullptr;
|
||||
std::unique_lock<std::mutex> lock(lock_);
|
||||
if (desc.type == schema::PrimitiveType_Custom) {
|
||||
int data_type_index = static_cast<int>(desc.data_type) - kNumberTypeBegin - 1;
|
||||
if (data_type_index < 0) {
|
||||
return nullptr;
|
||||
}
|
||||
auto param = primitive->value_as_Custom();
|
||||
MS_ASSERT(param != nullptr);
|
||||
auto custom_type = param->type()->str();
|
||||
auto archs = custom_kernel_creators_[desc.provider];
|
||||
if (desc.kernel_arch.empty()) {
|
||||
auto archs_iter = std::find_if(archs.begin(), archs.end(), [custom_type, data_type_index](auto &&item) {
|
||||
return item.second[custom_type] != nullptr && item.second[custom_type][data_type_index] != nullptr;
|
||||
});
|
||||
if (archs_iter != archs.end()) {
|
||||
return archs_iter->second[custom_type][data_type_index];
|
||||
}
|
||||
} else {
|
||||
auto find_arch_it = archs.find(desc.kernel_arch);
|
||||
if (find_arch_it != archs.end()) {
|
||||
return find_arch_it->second[custom_type][data_type_index];
|
||||
}
|
||||
}
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
auto index = GetFuncIndex(desc);
|
||||
if (index >= kKernelMaxNum || index < 0) {
|
||||
return nullptr;
|
||||
}
|
||||
for (auto &&item : kernel_creators_) {
|
||||
for (auto &&arch_item : item.second) {
|
||||
creator = arch_item.second[index];
|
||||
if (creator != nullptr) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (creator != nullptr) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
return creator;
|
||||
}
|
||||
|
||||
int KernelRegistry::GetCreatorFuncIndex(const kernel::KernelKey desc) {
|
||||
int index;
|
||||
int device_index = static_cast<int>(desc.arch) - kKernelArch_MIN;
|
||||
|
@ -228,21 +117,6 @@ KernelRegistry::~KernelRegistry() {
|
|||
free(instance->creator_arrays_);
|
||||
instance->creator_arrays_ = nullptr;
|
||||
}
|
||||
|
||||
for (auto &&item : kernel_creators_) {
|
||||
for (auto &&creator : item.second) {
|
||||
free(creator.second);
|
||||
creator.second = nullptr;
|
||||
}
|
||||
}
|
||||
for (auto &&provider : custom_kernel_creators_) {
|
||||
for (auto &&arch : provider.second) {
|
||||
for (auto &&creator : arch.second) {
|
||||
free(creator.second);
|
||||
creator.second = nullptr;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool KernelRegistry::SupportKernel(const KernelKey &key) {
|
||||
|
@ -273,14 +147,15 @@ int KernelRegistry::GetKernel(const std::vector<Tensor *> &in_tensors, const std
|
|||
return RET_ERROR;
|
||||
}
|
||||
} else {
|
||||
auto creator = GetProviderCreator(key, static_cast<const schema::Primitive *>(primitive));
|
||||
kernel::KernelDesc desc;
|
||||
KernelKeyToKernelDesc(key, &desc);
|
||||
auto creator =
|
||||
kernel::RegisterKernel::GetInstance()->GetCreator(desc, static_cast<const schema::Primitive *>(primitive));
|
||||
if (creator == nullptr) {
|
||||
return RET_NOT_SUPPORT;
|
||||
}
|
||||
std::vector<tensor::MSTensor *> tensors_in;
|
||||
Tensor2MSTensor(std::move(in_tensors), &tensors_in);
|
||||
std::vector<tensor::MSTensor *> tensors_out;
|
||||
Tensor2MSTensor(std::move(out_tensors), &tensors_out);
|
||||
std::vector<tensor::MSTensor *> tensors_in(in_tensors.begin(), in_tensors.end());
|
||||
std::vector<tensor::MSTensor *> tensors_out(out_tensors.begin(), out_tensors.end());
|
||||
auto base_kernel = creator(tensors_in, tensors_out, static_cast<const schema::Primitive *>(primitive), ctx);
|
||||
if (base_kernel != nullptr) {
|
||||
auto *lite_kernel = new (std::nothrow) kernel::LiteKernel(base_kernel);
|
||||
|
|
|
@ -23,7 +23,6 @@
|
|||
#include <vector>
|
||||
#include <set>
|
||||
#include "src/lite_kernel.h"
|
||||
#include "src/registry/register_kernel.h"
|
||||
#include "schema/model_generated.h"
|
||||
|
||||
using mindspore::kernel::kKernelArch_MAX;
|
||||
|
@ -40,15 +39,9 @@ class KernelRegistry {
|
|||
static KernelRegistry *GetInstance();
|
||||
static int Init();
|
||||
virtual kernel::KernelCreator GetCreator(const kernel::KernelKey &desc);
|
||||
virtual kernel::CreateKernel GetProviderCreator(const kernel::KernelKey &desc, const schema::Primitive *prim);
|
||||
int GetCreatorFuncIndex(kernel::KernelKey desc);
|
||||
int GetFuncIndex(const kernel::KernelKey &desc);
|
||||
int RegCustomKernel(const std::string &arch, const std::string &vendor, TypeId data_type, const std::string &type,
|
||||
kernel::CreateKernel creator);
|
||||
void RegKernel(kernel::KernelKey desc, kernel::KernelCreator creator);
|
||||
void RegKernel(kernel::KERNEL_ARCH arch, TypeId data_type, int type, kernel::KernelCreator creator);
|
||||
int RegKernel(const std::string &arch, const std::string &vendor, TypeId data_type, int type,
|
||||
kernel::CreateKernel creator);
|
||||
bool Merge(const std::unordered_map<kernel::KernelKey, kernel::KernelCreator> &newCreators);
|
||||
bool SupportKernel(const kernel::KernelKey &key);
|
||||
int GetKernel(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors,
|
||||
|
@ -61,10 +54,6 @@ class KernelRegistry {
|
|||
static const int op_type_length_{PrimitiveType_MAX - PrimitiveType_MIN + 1};
|
||||
static const int array_size_{device_type_length_ * data_type_length_ * op_type_length_};
|
||||
kernel::KernelCreator *creator_arrays_ = nullptr;
|
||||
std::map<std::string, std::unordered_map<std::string, kernel::CreateKernel *>> kernel_creators_;
|
||||
// keys:provider, arch, type
|
||||
std::map<std::string, std::map<std::string, std::unordered_map<std::string, kernel::CreateKernel *>>>
|
||||
custom_kernel_creators_;
|
||||
|
||||
private:
|
||||
std::mutex lock_;
|
||||
|
|
|
@ -23,6 +23,7 @@
|
|||
#include "include/model.h"
|
||||
#include "schema/model_generated.h"
|
||||
#include "src/common/log_adapter.h"
|
||||
#include "src/common/version_manager.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
|
|
|
@ -22,6 +22,7 @@
|
|||
#include "nnacl/op_base.h"
|
||||
#include "src/common/common.h"
|
||||
#include "src/common/prim_util.h"
|
||||
#include "src/common/version_manager.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
|
|
|
@ -0,0 +1,12 @@
|
|||
file(GLOB RUNTIME_REG_SRC ${CMAKE_CURRENT_SOURCE_DIR}/*.cc)
|
||||
if(USE_GLOG)
|
||||
set(RUNTIME_REG_SRC ${RUNTIME_REG_SRC} ${CORE_DIR}/utils/log_adapter.cc)
|
||||
else()
|
||||
set(RUNTIME_REG_SRC ${RUNTIME_REG_SRC} ${CMAKE_CURRENT_SOURCE_DIR}/../common/log_adapter.cc)
|
||||
endif()
|
||||
set_property(SOURCE ${RUNTIME_REG_SRC} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_LITE)
|
||||
add_library(mslite_kernel_reg SHARED ${RUNTIME_REG_SRC})
|
||||
if(PLATFORM_ARM32 OR PLATFORM_ARM64)
|
||||
target_link_libraries(mslite_kernel_reg log)
|
||||
endif()
|
||||
add_dependencies(mslite_kernel_reg fbs_src)
|
|
@ -14,6 +14,8 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
#include "src/registry/kernel_interface.h"
|
||||
#include <set>
|
||||
#include <utility>
|
||||
#include "src/registry/kernel_interface_registry.h"
|
||||
|
||||
namespace mindspore {
|
||||
|
@ -31,5 +33,14 @@ int RegisterKernelInterface::CustomReg(const std::string &provider, const std::s
|
|||
KernelInterfaceCreator creator) {
|
||||
return lite::KernelInterfaceRegistry::Instance()->CustomReg(provider, op_type, creator);
|
||||
}
|
||||
|
||||
bool RegisterKernelInterface::CheckReg(const lite::Model::Node *node, std::set<std::string> &&providers) {
|
||||
return lite::KernelInterfaceRegistry::Instance()->CheckReg(node, std::forward<std::set<std::string>>(providers));
|
||||
}
|
||||
|
||||
KernelInterface *RegisterKernelInterface::GetKernelInterface(const std::string &provider,
|
||||
const schema::Primitive *primitive) {
|
||||
return lite::KernelInterfaceRegistry::Instance()->GetKernelInterface(provider, primitive);
|
||||
}
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -17,19 +17,21 @@
|
|||
#ifndef MINDSPORE_LITE_SRC_KERNEL_DEV_DELEGATE_H_
|
||||
#define MINDSPORE_LITE_SRC_KERNEL_DEV_DELEGATE_H_
|
||||
|
||||
#include <set>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "include/model.h"
|
||||
#include "include/ms_tensor.h"
|
||||
#include "schema/model_generated.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
struct CapabilityParam {
|
||||
struct MS_API CapabilityParam {
|
||||
float exec_time_;
|
||||
float power_usage_;
|
||||
};
|
||||
|
||||
class KernelInterface {
|
||||
class MS_API KernelInterface {
|
||||
public:
|
||||
virtual ~KernelInterface() = default;
|
||||
virtual int Infer(const std::vector<tensor::MSTensor *> &inputs, const std::vector<tensor::MSTensor *> &outputs,
|
||||
|
@ -44,18 +46,20 @@ class KernelInterface {
|
|||
};
|
||||
typedef KernelInterface *(*KernelInterfaceCreator)();
|
||||
|
||||
class RegisterKernelInterface {
|
||||
class MS_API RegisterKernelInterface {
|
||||
public:
|
||||
static RegisterKernelInterface *Instance();
|
||||
int CustomReg(const std::string &provider, const std::string &op_type, KernelInterfaceCreator creator);
|
||||
int Reg(const std::string &provider, int op_type, KernelInterfaceCreator creator);
|
||||
bool CheckReg(const lite::Model::Node *node, std::set<std::string> &&providers);
|
||||
KernelInterface *GetKernelInterface(const std::string &provider, const schema::Primitive *primitive);
|
||||
virtual ~RegisterKernelInterface() = default;
|
||||
|
||||
private:
|
||||
RegisterKernelInterface() = default;
|
||||
};
|
||||
|
||||
class KernelInterfaceReg {
|
||||
class MS_API KernelInterfaceReg {
|
||||
public:
|
||||
KernelInterfaceReg(const std::string &provider, int op_type, KernelInterfaceCreator creator) {
|
||||
RegisterKernelInterface::Instance()->Reg(provider, op_type, creator);
|
||||
|
|
|
@ -13,8 +13,10 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "src/registry/register_kernel.h"
|
||||
#include "src/kernel_registry.h"
|
||||
#include <set>
|
||||
#include "src/registry/register_kernel_impl.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
|
@ -25,12 +27,16 @@ RegisterKernel *RegisterKernel::GetInstance() {
|
|||
|
||||
int RegisterKernel::RegCustomKernel(const std::string &arch, const std::string &provider, TypeId data_type,
|
||||
const std::string &type, CreateKernel creator) {
|
||||
return lite::KernelRegistry::GetInstance()->RegCustomKernel(arch, provider, data_type, type, creator);
|
||||
return lite::RegistryKernelImpl::GetInstance()->RegCustomKernel(arch, provider, data_type, type, creator);
|
||||
}
|
||||
|
||||
int RegisterKernel::RegKernel(const std::string &arch, const std::string &provider, TypeId data_type, int op_type,
|
||||
CreateKernel creator) {
|
||||
return lite::KernelRegistry::GetInstance()->RegKernel(arch, provider, data_type, op_type, creator);
|
||||
return lite::RegistryKernelImpl::GetInstance()->RegKernel(arch, provider, data_type, op_type, creator);
|
||||
}
|
||||
|
||||
CreateKernel RegisterKernel::GetCreator(const kernel::KernelDesc &desc, const schema::Primitive *primitive) {
|
||||
return lite::RegistryKernelImpl::GetInstance()->GetProviderCreator(desc, primitive);
|
||||
}
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -14,28 +14,51 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_SRC_REGISTER_KERNEL_H_
|
||||
#define MINDSPORE_LITE_SRC_REGISTER_KERNEL_H_
|
||||
#ifndef MINDSPORE_LITE_SRC_REGISTRY_REGISTER_KERNEL_H_
|
||||
#define MINDSPORE_LITE_SRC_REGISTRY_REGISTER_KERNEL_H_
|
||||
|
||||
#include <set>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "schema/ops_generated.h"
|
||||
#include "src/lite_kernel.h"
|
||||
#include "schema/model_generated.h"
|
||||
#include "include/context.h"
|
||||
#include "include/ms_tensor.h"
|
||||
#include "src/kernel.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
struct MS_API KernelDesc {
|
||||
TypeId data_type;
|
||||
int type;
|
||||
std::string arch;
|
||||
std::string provider;
|
||||
|
||||
bool operator<(const KernelDesc &dst) const {
|
||||
if (provider != dst.provider) {
|
||||
return provider < dst.provider;
|
||||
} else if (arch != dst.arch) {
|
||||
return arch < dst.arch;
|
||||
} else if (data_type != dst.data_type) {
|
||||
return data_type < dst.data_type;
|
||||
} else {
|
||||
return type < dst.type;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
typedef kernel::Kernel *(*CreateKernel)(const std::vector<tensor::MSTensor *> &inputs,
|
||||
const std::vector<tensor::MSTensor *> &outputs,
|
||||
const schema::Primitive *primitive, const lite::Context *ctx);
|
||||
class RegisterKernel {
|
||||
class MS_API RegisterKernel {
|
||||
public:
|
||||
static RegisterKernel *GetInstance();
|
||||
int RegKernel(const std::string &arch, const std::string &provider, TypeId data_type, int type, CreateKernel creator);
|
||||
int RegCustomKernel(const std::string &arch, const std::string &provider, TypeId data_type, const std::string &type,
|
||||
CreateKernel creator);
|
||||
CreateKernel GetCreator(const kernel::KernelDesc &desc, const schema::Primitive *primitive);
|
||||
};
|
||||
|
||||
class KernelReg {
|
||||
class MS_API KernelReg {
|
||||
public:
|
||||
~KernelReg() = default;
|
||||
|
||||
|
@ -57,4 +80,4 @@ class KernelReg {
|
|||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_LITE_SRC_REGISTER_KERNEL_H_
|
||||
#endif // MINDSPORE_LITE_SRC_REGISTRY_REGISTER_KERNEL_H_
|
||||
|
|
|
@ -0,0 +1,165 @@
|
|||
/**
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "src/registry/register_kernel_impl.h"
|
||||
#include "src/registry/register_kernel.h"
|
||||
#include "include/errorcode.h"
|
||||
#include "src/common/version_manager.h"
|
||||
#include "src/common/log_adapter.h"
|
||||
|
||||
using mindspore::kernel::CreateKernel;
|
||||
using mindspore::kernel::KernelDesc;
|
||||
|
||||
namespace mindspore::lite {
|
||||
namespace {
|
||||
static const int kKernelMaxNum = (kNumberTypeEnd - kNumberTypeBegin - 1) * (PrimitiveType_MAX - PrimitiveType_MIN);
|
||||
} // namespace
|
||||
|
||||
int RegistryKernelImpl::GetFuncIndex(const kernel::KernelDesc &desc) {
|
||||
if (desc.data_type >= kNumberTypeEnd) {
|
||||
return -1;
|
||||
}
|
||||
int data_type_index = static_cast<int>(desc.data_type) - kNumberTypeBegin - 1;
|
||||
if (data_type_index < 0) {
|
||||
return -1;
|
||||
}
|
||||
return data_type_index * op_type_length_ + desc.type;
|
||||
}
|
||||
|
||||
int RegistryKernelImpl::RegCustomKernel(const std::string &arch, const std::string &provider, TypeId data_type,
|
||||
const std::string &type, CreateKernel creator) {
|
||||
if (data_type >= kNumberTypeEnd) {
|
||||
MS_LOG(ERROR) << "invalid data_type: " << data_type << "!provider: " << provider;
|
||||
return RET_ERROR;
|
||||
}
|
||||
std::unique_lock<std::mutex> lock(lock_);
|
||||
if (custom_kernel_creators_[provider][arch][type] == nullptr) {
|
||||
custom_kernel_creators_[provider][arch][type] =
|
||||
reinterpret_cast<CreateKernel *>(malloc(data_type_length_ * sizeof(CreateKernel)));
|
||||
if (custom_kernel_creators_[provider][arch][type] == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc custom kernel creator fail!provider: " << provider << ", arch: " << arch;
|
||||
return RET_ERROR;
|
||||
}
|
||||
memset(custom_kernel_creators_[provider][arch][type], 0, data_type_length_ * sizeof(CreateKernel));
|
||||
}
|
||||
|
||||
int data_type_index = data_type - kNumberTypeBegin - 1;
|
||||
if (data_type_index < 0 || data_type_index >= data_type_length_) {
|
||||
MS_LOG(ERROR) << "invalid data_type: " << data_type << "!provider: " << provider;
|
||||
return RET_ERROR;
|
||||
}
|
||||
custom_kernel_creators_[provider][arch][type][data_type_index] = creator;
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int RegistryKernelImpl::RegKernel(const std::string &arch, const std::string &provider, TypeId data_type, int type,
|
||||
kernel::CreateKernel creator) {
|
||||
std::unique_lock<std::mutex> lock(lock_);
|
||||
auto iter = kernel_creators_.find(provider);
|
||||
if (iter == kernel_creators_.end()) {
|
||||
kernel_creators_[provider][arch] = reinterpret_cast<CreateKernel *>(malloc(kKernelMaxNum * sizeof(CreateKernel)));
|
||||
if (kernel_creators_[provider][arch] == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc kernel creator buffer fail! provider: " << provider << ",arch:" << arch;
|
||||
return RET_ERROR;
|
||||
}
|
||||
memset(kernel_creators_[provider][arch], 0, kKernelMaxNum * sizeof(CreateKernel));
|
||||
} else {
|
||||
auto iter_arch = iter->second.find(arch);
|
||||
if (iter_arch == iter->second.end()) {
|
||||
iter->second[arch] = reinterpret_cast<CreateKernel *>(malloc(kKernelMaxNum * sizeof(CreateKernel)));
|
||||
if (iter->second[arch] == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc kernel creator buffer fail! provider: " << provider << ",arch:" << arch;
|
||||
return RET_ERROR;
|
||||
}
|
||||
memset(iter->second[arch], 0, kKernelMaxNum * sizeof(CreateKernel));
|
||||
}
|
||||
}
|
||||
|
||||
KernelDesc desc = {data_type, type, arch, provider};
|
||||
int index = GetFuncIndex(desc);
|
||||
if (index >= kKernelMaxNum || index < 0) {
|
||||
MS_LOG(ERROR) << "invalid kernel key, arch " << arch << ", data_type" << data_type << ",op type " << type;
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
kernel_creators_[provider][arch][index] = creator;
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
kernel::CreateKernel RegistryKernelImpl::GetProviderCreator(const KernelDesc &desc,
|
||||
const schema::Primitive *primitive) {
|
||||
MS_ASSERT(primitive != nullptr);
|
||||
kernel::CreateKernel creator = nullptr;
|
||||
std::unique_lock<std::mutex> lock(lock_);
|
||||
if (desc.type == schema::PrimitiveType_Custom) {
|
||||
int data_type_index = static_cast<int>(desc.data_type) - kNumberTypeBegin - 1;
|
||||
if (data_type_index < 0) {
|
||||
return nullptr;
|
||||
}
|
||||
auto param = primitive->value_as_Custom();
|
||||
MS_ASSERT(param != nullptr);
|
||||
auto custom_type = param->type()->str();
|
||||
auto archs = custom_kernel_creators_[desc.provider];
|
||||
if (desc.arch.empty()) {
|
||||
auto archs_iter = std::find_if(archs.begin(), archs.end(), [custom_type, data_type_index](auto &&item) {
|
||||
return item.second[custom_type] != nullptr && item.second[custom_type][data_type_index] != nullptr;
|
||||
});
|
||||
if (archs_iter != archs.end()) {
|
||||
return archs_iter->second[custom_type][data_type_index];
|
||||
}
|
||||
} else {
|
||||
auto find_arch_it = archs.find(desc.arch);
|
||||
if (find_arch_it != archs.end()) {
|
||||
return find_arch_it->second[custom_type][data_type_index];
|
||||
}
|
||||
}
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
auto index = GetFuncIndex(desc);
|
||||
if (index >= kKernelMaxNum || index < 0) {
|
||||
return nullptr;
|
||||
}
|
||||
for (auto &&item : kernel_creators_) {
|
||||
for (auto &&arch_item : item.second) {
|
||||
creator = arch_item.second[index];
|
||||
if (creator != nullptr) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (creator != nullptr) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
return creator;
|
||||
}
|
||||
|
||||
RegistryKernelImpl::~RegistryKernelImpl() {
|
||||
for (auto &&item : kernel_creators_) {
|
||||
for (auto &&creator : item.second) {
|
||||
free(creator.second);
|
||||
creator.second = nullptr;
|
||||
}
|
||||
}
|
||||
for (auto &&provider : custom_kernel_creators_) {
|
||||
for (auto &&arch : provider.second) {
|
||||
for (auto &&creator : arch.second) {
|
||||
free(creator.second);
|
||||
creator.second = nullptr;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} // namespace mindspore::lite
|
|
@ -0,0 +1,69 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_SRC_REGISTRY_REGISTER_KERNEL_IMPL_H_
|
||||
#define MINDSPORE_LITE_SRC_REGISTRY_REGISTER_KERNEL_IMPL_H_
|
||||
|
||||
#include <string>
|
||||
#include <map>
|
||||
#include <mutex>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
#include <set>
|
||||
#include "src/registry/register_kernel.h"
|
||||
|
||||
using mindspore::schema::PrimitiveType_MAX;
|
||||
using mindspore::schema::PrimitiveType_MIN;
|
||||
|
||||
namespace mindspore::lite {
|
||||
class RegistryKernelImpl {
|
||||
public:
|
||||
RegistryKernelImpl() = default;
|
||||
virtual ~RegistryKernelImpl();
|
||||
|
||||
static RegistryKernelImpl *GetInstance() {
|
||||
static RegistryKernelImpl instance;
|
||||
return &instance;
|
||||
}
|
||||
|
||||
int GetFuncIndex(const kernel::KernelDesc &desc);
|
||||
|
||||
int RegCustomKernel(const std::string &arch, const std::string &provider, TypeId data_type, const std::string &type,
|
||||
kernel::CreateKernel creator);
|
||||
|
||||
int RegKernel(const std::string &arch, const std::string &provider, TypeId data_type, int type,
|
||||
kernel::CreateKernel creator);
|
||||
|
||||
virtual kernel::CreateKernel GetProviderCreator(const kernel::KernelDesc &desc, const schema::Primitive *primitive);
|
||||
|
||||
const std::map<std::string, std::unordered_map<std::string, kernel::CreateKernel *>> &kernel_creators() {
|
||||
return kernel_creators_;
|
||||
}
|
||||
|
||||
protected:
|
||||
static const int data_type_length_{kNumberTypeEnd - kNumberTypeBegin + 1};
|
||||
static const int op_type_length_{PrimitiveType_MAX - PrimitiveType_MIN + 1};
|
||||
std::map<std::string, std::unordered_map<std::string, kernel::CreateKernel *>> kernel_creators_;
|
||||
// keys:provider, arch, type
|
||||
std::map<std::string, std::map<std::string, std::unordered_map<std::string, kernel::CreateKernel *>>>
|
||||
custom_kernel_creators_;
|
||||
|
||||
private:
|
||||
std::mutex lock_;
|
||||
};
|
||||
} // namespace mindspore::lite
|
||||
|
||||
#endif // MINDSPORE_LITE_SRC_REGISTRY_REGISTER_KERNEL_IMPL_H_
|
|
@ -23,7 +23,7 @@
|
|||
#include "include/errorcode.h"
|
||||
#include "nnacl/errorcode.h"
|
||||
#include "src/tensorlist.h"
|
||||
#include "src/registry/kernel_interface_registry.h"
|
||||
#include "src/registry/kernel_interface.h"
|
||||
#include "src/kernel_registry.h"
|
||||
|
||||
namespace mindspore {
|
||||
|
@ -35,7 +35,7 @@ int KernelInferShape(const std::vector<lite::Tensor *> &inputs, const std::vecto
|
|||
std::vector<tensor::MSTensor *> out_tensors;
|
||||
std::copy(outputs.begin(), outputs.end(), std::back_inserter(out_tensors));
|
||||
for (auto &&provider : providers) {
|
||||
auto kernel_interface = KernelInterfaceRegistry::Instance()->GetKernelInterface(
|
||||
auto kernel_interface = kernel::RegisterKernelInterface::Instance()->GetKernelInterface(
|
||||
provider, static_cast<const schema::Primitive *>(primitive));
|
||||
if (kernel_interface == nullptr) {
|
||||
continue;
|
||||
|
|
|
@ -25,6 +25,7 @@
|
|||
#include "src/common/graph_util.h"
|
||||
#include "src/common/utils.h"
|
||||
#include "src/kernel_registry.h"
|
||||
#include "src/registry/register_kernel.h"
|
||||
#include "src/lite_kernel_util.h"
|
||||
#include "src/sub_graph_kernel.h"
|
||||
#include "src/ops/populate/populate_register.h"
|
||||
|
@ -48,7 +49,7 @@
|
|||
#if defined(ENABLE_ARM) && defined(ENABLE_FP16)
|
||||
#include "src/runtime/kernel/arm/fp16/fp16_op_handler.h"
|
||||
#endif
|
||||
#include "src/registry/kernel_interface_registry.h"
|
||||
#include "src/registry/kernel_interface.h"
|
||||
|
||||
namespace mindspore::lite {
|
||||
using kernel::KERNEL_ARCH::kCPU;
|
||||
|
@ -132,7 +133,7 @@ int Scheduler::InferNodeShape(const lite::Model::Node *node) {
|
|||
std::vector<Tensor *> inputs;
|
||||
std::vector<Tensor *> outputs;
|
||||
FindNodeInoutTensors(*node, &inputs, &outputs);
|
||||
if (KernelInterfaceRegistry::Instance()->CheckReg(node, context_->GetProviders())) {
|
||||
if (kernel::RegisterKernelInterface::Instance()->CheckReg(node, context_->GetProviders())) {
|
||||
return KernelInferShape(inputs, outputs, node->primitive_, context_->GetProviders());
|
||||
}
|
||||
|
||||
|
|
|
@ -20,7 +20,6 @@ if(ENABLE_CONVERTER)
|
|||
${CCSRC_DIR}/backend/optimizer/common/optimizer.cc
|
||||
)
|
||||
else()
|
||||
set(TEST_LITE_SRC ${LITE_DIR}/src/common/log_adapter.cc)
|
||||
add_compile_definitions(USE_ANDROID_LOG)
|
||||
endif()
|
||||
|
||||
|
@ -141,10 +140,7 @@ set(TEST_LITE_SRC
|
|||
${LITE_DIR}/src/tensorlist.cc
|
||||
${LITE_DIR}/src/executor.cc
|
||||
${LITE_DIR}/src/inner_context.cc
|
||||
${LITE_DIR}/src/registry/kernel_interface.cc
|
||||
${LITE_DIR}/src/registry/kernel_interface_registry.cc
|
||||
${LITE_DIR}/src/kernel_registry.cc
|
||||
${LITE_DIR}/src/registry/register_kernel.cc
|
||||
${LITE_DIR}/src/inner_kernel.cc
|
||||
${LITE_DIR}/src/lite_kernel.cc
|
||||
${LITE_DIR}/src/lite_kernel_util.cc
|
||||
|
@ -455,3 +451,4 @@ if(ENABLE_CONVERTER)
|
|||
mindspore::glog
|
||||
)
|
||||
endif()
|
||||
target_link_libraries(lite-test mslite_kernel_reg)
|
||||
|
|
|
@ -1482,6 +1482,7 @@ function Run_arm64() {
|
|||
cp -a ${arm64_path}/mindspore-lite-${version}-inference-android-aarch64/inference/third_party/hiai_ddk/lib/libhiai_ir_build.so ${benchmark_test_path}/libhiai_ir_build.so || exit 1
|
||||
|
||||
cp -a ${arm64_path}/mindspore-lite-${version}-inference-android-aarch64/inference/lib/libmindspore-lite.so ${benchmark_test_path}/libmindspore-lite.so || exit 1
|
||||
cp -a ${arm64_path}/mindspore-lite-${version}-inference-android-aarch64/inference/lib/libmslite_kernel_reg.so ${benchmark_test_path}/libmslite_kernel_reg.so || exit 1
|
||||
cp -a ${arm64_path}/mindspore-lite-${version}-inference-android-aarch64/tools/benchmark/benchmark ${benchmark_test_path}/benchmark || exit 1
|
||||
|
||||
# adb push all needed files to the phone
|
||||
|
@ -1900,6 +1901,7 @@ function Run_arm32() {
|
|||
fi
|
||||
|
||||
cp -a ${arm32_path}/mindspore-lite-${version}-inference-android-aarch32/inference/lib/libmindspore-lite.so ${benchmark_test_path}/libmindspore-lite.so || exit 1
|
||||
cp -a ${arm32_path}/mindspore-lite-${version}-inference-android-aarch32/inference/lib/libmslite_kernel_reg.so ${benchmark_test_path}/libmslite_kernel_reg.so || exit 1
|
||||
cp -a ${arm32_path}/mindspore-lite-${version}-inference-android-aarch32/tools/benchmark/benchmark ${benchmark_test_path}/benchmark || exit 1
|
||||
|
||||
# adb push all needed files to the phone
|
||||
|
@ -1957,6 +1959,7 @@ function Run_arm64_fp16() {
|
|||
cp -a ${arm64_path}/mindspore-lite-${version}-inference-android-aarch64/inference/third_party/hiai_ddk/lib/libhiai_ir_build.so ${benchmark_test_path}/libhiai_ir_build.so || exit 1
|
||||
|
||||
cp -a ${arm64_path}/mindspore-lite-${version}-inference-android-aarch64/inference/lib/libmindspore-lite.so ${benchmark_test_path}/libmindspore-lite.so || exit 1
|
||||
cp -a ${arm64_path}/mindspore-lite-${version}-inference-android-aarch64/inference/lib/libmslite_kernel_reg.so ${benchmark_test_path}/libmslite_kernel_reg.so || exit 1
|
||||
cp -a ${arm64_path}/mindspore-lite-${version}-inference-android-aarch64/tools/benchmark/benchmark ${benchmark_test_path}/benchmark || exit 1
|
||||
|
||||
# adb push all needed files to the phone
|
||||
|
@ -2126,6 +2129,7 @@ function Run_armv82_a32_fp16() {
|
|||
fi
|
||||
|
||||
cp -a ${armv82_path}/mindspore-lite-${version}-inference-android-aarch32/inference/lib/libmindspore-lite.so ${benchmark_test_path}/libmindspore-lite.so || exit 1
|
||||
cp -a ${armv82_path}/mindspore-lite-${version}-inference-android-aarch32/inference/lib/libmslite_kernel_reg.so ${benchmark_test_path}/libmslite_kernel_reg.so || exit 1
|
||||
cp -a ${armv82_path}/mindspore-lite-${version}-inference-android-aarch32/tools/benchmark/benchmark ${benchmark_test_path}/benchmark || exit 1
|
||||
|
||||
# adb push all needed files to the phone
|
||||
|
@ -2303,6 +2307,7 @@ function Run_gpu() {
|
|||
cp -a ${arm64_path}/mindspore-lite-${version}-inference-android-aarch64/inference/third_party/hiai_ddk/lib/libhiai_ir_build.so ${benchmark_test_path}/libhiai_ir_build.so || exit 1
|
||||
|
||||
cp -a ${arm64_path}/mindspore-lite-${version}-inference-android-aarch64/inference/lib/libmindspore-lite.so ${benchmark_test_path}/libmindspore-lite.so || exit 1
|
||||
cp -a ${arm64_path}/mindspore-lite-${version}-inference-android-aarch64/inference/lib/libmslite_kernel_reg.so ${benchmark_test_path}/libmslite_kernel_reg.so || exit 1
|
||||
cp -a ${arm64_path}/mindspore-lite-${version}-inference-android-aarch64/tools/benchmark/benchmark ${benchmark_test_path}/benchmark || exit 1
|
||||
|
||||
# adb push all needed files to the phone
|
||||
|
@ -2455,6 +2460,7 @@ function Run_npu() {
|
|||
cp -a ${arm64_path}/mindspore-lite-${version}-inference-android-aarch64/inference/third_party/hiai_ddk/lib/libhiai_ir_build.so ${benchmark_test_path}/libhiai_ir_build.so || exit 1
|
||||
|
||||
cp -a ${arm64_path}/mindspore-lite-${version}-inference-android-aarch64/inference/lib/libmindspore-lite.so ${benchmark_test_path}/libmindspore-lite.so || exit 1
|
||||
cp -a ${arm64_path}/mindspore-lite-${version}-inference-android-aarch64/inference/lib/libmslite_kernel_reg.so ${benchmark_test_path}/libmslite_kernel_reg.so || exit 1
|
||||
cp -a ${arm64_path}/mindspore-lite-${version}-inference-android-aarch64/tools/benchmark/benchmark ${benchmark_test_path}/benchmark || exit 1
|
||||
|
||||
# adb push all needed files to the phone
|
||||
|
|
|
@ -31,6 +31,7 @@ function Run_cropper() {
|
|||
cp -a ./inference/third_party/hiai_ddk/lib/libhiai_ir_build.so "${cropper_test_path}"/libhiai_ir_build.so || exit 1
|
||||
|
||||
cp -a ./inference/lib/libmindspore-lite.a "${cropper_test_path}"/libmindspore-lite.a || exit 1
|
||||
cp -a ./inference/lib/libmslite_kernel_reg.so "${cropper_test_path}"/libmslite_kernel_reg.so || exit 1
|
||||
cp -a ./tools/benchmark/benchmark "${cropper_test_path}"/benchmark || exit 1
|
||||
|
||||
cp -r "${x86_path}"/mindspore-lite-${version}-inference-linux-x64/tools/cropper/ "${cropper_test_path}" || exit 1
|
||||
|
@ -74,7 +75,7 @@ function Run_cropper() {
|
|||
-fno-addrsig -Wa,--noexecstack -Wformat -Werror=format-security -O0 -std=c++17 -DARM64=1 -O2 -DNDEBUG -s \
|
||||
-Wl,--exclude-libs,libgcc.a -Wl,--exclude-libs,libatomic.a -Wl,--build-id -Wl,--warn-shared-textrel \
|
||||
-Wl,--fatal-warnings -Wl,--no-undefined -Qunused-arguments -Wl,-z,noexecstack -Wl,--no-whole-archive -llog -latomic -lm \
|
||||
-L "${cropper_test_path}" -lhiai -lhiai_ir -lhiai_ir_build \
|
||||
-L "${cropper_test_path}" -lmslite_kernel_reg -lhiai -lhiai_ir -lhiai_ir_build \
|
||||
-shared -o libmindspore-lite.so -Wl,-soname,libmindspore-lite.so
|
||||
|
||||
if [ $? = 0 ]; then
|
||||
|
|
|
@ -156,6 +156,7 @@ function Run_arm() {
|
|||
fi
|
||||
|
||||
cp -a ${arm_path}/mindspore-lite-${version_arm}-train-android-${process_unit}/train/lib/libmindspore-lite-train.so ${benchmark_train_test_path}/libmindspore-lite-train.so || exit 1
|
||||
cp -a ${arm_path}/mindspore-lite-${version_arm}-train-android-${process_unit}/train/lib/libmslite_kernel_reg.so ${benchmark_train_test_path}/libmslite_kernel_reg.so || exit 1
|
||||
cp -a ${arm_path}/mindspore-lite-${version_arm}-train-android-${process_unit}/tools/benchmark_train/benchmark_train ${benchmark_train_test_path}/benchmark_train || exit 1
|
||||
|
||||
# adb push all needed files to the phone
|
||||
|
|
|
@ -124,9 +124,7 @@ set(LITE_SRC
|
|||
${SRC_DIR}/tensor.cc
|
||||
${SRC_DIR}/ms_tensor.cc
|
||||
${SRC_DIR}/tensorlist.cc
|
||||
${SRC_DIR}/registry/kernel_interface_registry.cc
|
||||
${SRC_DIR}/kernel_registry.cc
|
||||
${SRC_DIR}/registry/register_kernel.cc
|
||||
${SRC_DIR}/inner_kernel.cc
|
||||
${SRC_DIR}/lite_kernel.cc
|
||||
${SRC_DIR}/lite_kernel_util.cc
|
||||
|
@ -209,6 +207,7 @@ target_link_libraries(converter_lite PRIVATE
|
|||
mindspore::json
|
||||
mindspore::eigen
|
||||
-Wl,--whole-archive mindspore_core -Wl,--no-whole-archive
|
||||
mslite_kernel_reg
|
||||
mindspore::glog
|
||||
mindspore::protobuf
|
||||
mindspore::flatbuffers
|
||||
|
|
Loading…
Reference in New Issue