fix kernel registry bug
This commit is contained in:
parent
d1b3096418
commit
1eb86b7df2
|
@ -33,7 +33,21 @@ using mindspore::schema::PrimitiveType_MAX;
|
||||||
using mindspore::schema::PrimitiveType_MIN;
|
using mindspore::schema::PrimitiveType_MIN;
|
||||||
|
|
||||||
namespace mindspore::lite {
|
namespace mindspore::lite {
|
||||||
KernelRegistry::KernelRegistry() {}
|
KernelRegistry::KernelRegistry() {
|
||||||
|
device_type_length_ = kKernelArch_MAX - kKernelArch_MIN + 1;
|
||||||
|
data_type_length_ = kNumberTypeEnd - kNumberTypeBegin + 1;
|
||||||
|
op_type_length_ = PrimitiveType_MAX - PrimitiveType_MIN + 1;
|
||||||
|
// Allocate a flat array that holds the kernel creator functions.
|
||||||
|
auto total_len = device_type_length_ * data_type_length_ * op_type_length_;
|
||||||
|
creator_arrays_ = (kernel::KernelCreator *)malloc(total_len * sizeof(kernel::KernelCreator));
|
||||||
|
if (creator_arrays_ == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "malloc creator_arrays_ failed.";
|
||||||
|
} else {
|
||||||
|
for (int i = 0; i < total_len; ++i) {
|
||||||
|
creator_arrays_[i] = nullptr;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
KernelRegistry::~KernelRegistry() { FreeCreatorArray(); }
|
KernelRegistry::~KernelRegistry() { FreeCreatorArray(); }
|
||||||
|
|
||||||
|
@ -43,25 +57,6 @@ KernelRegistry *KernelRegistry::GetInstance() {
|
||||||
}
|
}
|
||||||
|
|
||||||
int KernelRegistry::Init() {
|
int KernelRegistry::Init() {
|
||||||
lock_.lock();
|
|
||||||
if (creator_arrays_ != nullptr) {
|
|
||||||
lock_.unlock();
|
|
||||||
return RET_OK;
|
|
||||||
}
|
|
||||||
device_type_length_ = kKernelArch_MAX - kKernelArch_MIN;
|
|
||||||
data_type_length_ = kNumberTypeEnd - kNumberTypeBegin;
|
|
||||||
op_type_length_ = PrimitiveType_MAX - PrimitiveType_MIN;
|
|
||||||
// Allocate a flat array that holds the kernel creator functions.
|
|
||||||
auto total_len = device_type_length_ * data_type_length_ * op_type_length_;
|
|
||||||
creator_arrays_ = (kernel::KernelCreator *)malloc(total_len * sizeof(kernel::KernelCreator));
|
|
||||||
if (creator_arrays_ == nullptr) {
|
|
||||||
MS_LOG(ERROR) << "malloc creator_arrays_ failed.";
|
|
||||||
lock_.unlock();
|
|
||||||
return RET_ERROR;
|
|
||||||
}
|
|
||||||
for (int i = 0; i < total_len; ++i) {
|
|
||||||
creator_arrays_[i] = nullptr;
|
|
||||||
}
|
|
||||||
#ifdef ENABLE_ARM64
|
#ifdef ENABLE_ARM64
|
||||||
void *optimized_lib_handler = OptimizeModule::GetInstance()->optimized_op_handler_;
|
void *optimized_lib_handler = OptimizeModule::GetInstance()->optimized_op_handler_;
|
||||||
if (optimized_lib_handler != nullptr) {
|
if (optimized_lib_handler != nullptr) {
|
||||||
|
@ -70,7 +65,6 @@ int KernelRegistry::Init() {
|
||||||
MS_LOG(INFO) << "load optimize lib failed.";
|
MS_LOG(INFO) << "load optimize lib failed.";
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
lock_.unlock();
|
|
||||||
return RET_OK;
|
return RET_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -82,6 +76,10 @@ void KernelRegistry::FreeCreatorArray() {
|
||||||
}
|
}
|
||||||
|
|
||||||
kernel::KernelCreator KernelRegistry::GetCreator(const KernelKey &desc) {
|
kernel::KernelCreator KernelRegistry::GetCreator(const KernelKey &desc) {
|
||||||
|
if (creator_arrays_ == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Creator func array is null.";
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
int index = GetCreatorFuncIndex(desc);
|
int index = GetCreatorFuncIndex(desc);
|
||||||
auto it = creator_arrays_[index];
|
auto it = creator_arrays_[index];
|
||||||
if (it != nullptr) {
|
if (it != nullptr) {
|
||||||
|
@ -100,12 +98,20 @@ int KernelRegistry::GetCreatorFuncIndex(const kernel::KernelKey desc) {
|
||||||
}
|
}
|
||||||
|
|
||||||
void KernelRegistry::RegKernel(const KernelKey desc, kernel::KernelCreator creator) {
|
void KernelRegistry::RegKernel(const KernelKey desc, kernel::KernelCreator creator) {
|
||||||
|
if (creator_arrays_ == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Creator func array is null.";
|
||||||
|
return;
|
||||||
|
}
|
||||||
int index = GetCreatorFuncIndex(desc);
|
int index = GetCreatorFuncIndex(desc);
|
||||||
creator_arrays_[index] = creator;
|
creator_arrays_[index] = creator;
|
||||||
}
|
}
|
||||||
|
|
||||||
void KernelRegistry::RegKernel(const KERNEL_ARCH arch, const TypeId data_type, const schema::PrimitiveType op_type,
|
void KernelRegistry::RegKernel(const KERNEL_ARCH arch, const TypeId data_type, const schema::PrimitiveType op_type,
|
||||||
kernel::KernelCreator creator) {
|
kernel::KernelCreator creator) {
|
||||||
|
if (creator_arrays_ == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Creator func array is null.";
|
||||||
|
return;
|
||||||
|
}
|
||||||
KernelKey desc = {arch, data_type, op_type};
|
KernelKey desc = {arch, data_type, op_type};
|
||||||
int index = GetCreatorFuncIndex(desc);
|
int index = GetCreatorFuncIndex(desc);
|
||||||
creator_arrays_[index] = creator;
|
creator_arrays_[index] = creator;
|
||||||
|
|
|
@ -58,10 +58,10 @@ class OptimizeModule {
|
||||||
if ((!support_optimize_ops) && (!support_fp16)) {
|
if ((!support_optimize_ops) && (!support_fp16)) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
optimized_op_handler_ = dlopen(OPTIMIZE_SHARED_LIBRARY_PATH, RTLD_LAZY);
|
// optimized_op_handler_ = dlopen(OPTIMIZE_SHARED_LIBRARY_PATH, RTLD_LAZY);
|
||||||
if (optimized_op_handler_ == nullptr) {
|
// if (optimized_op_handler_ == nullptr) {
|
||||||
printf("Open optimize shared library failed.\n");
|
// printf("Open optimize shared library failed.\n");
|
||||||
}
|
// }
|
||||||
}
|
}
|
||||||
|
|
||||||
~OptimizeModule() = default;
|
~OptimizeModule() = default;
|
||||||
|
|
Loading…
Reference in New Issue