forked from mindspore-Ecosystem/mindspore
!16020 [MS][LITE][STABLE]support register custom kernel
From: @jpc_chenjianping Reviewed-by: @zhang_xue_tong,@zhanghaibo5 Signed-off-by: @zhang_xue_tong
This commit is contained in:
commit
b8d64f2ae9
|
@ -134,6 +134,8 @@ set(LITE_SRC
|
|||
${LITE_DIR}/src/common/prim_util.cc
|
||||
${LITE_DIR}/src/common/tensor_util.cc
|
||||
${LITE_DIR}/src/runtime/infer_manager.cc
|
||||
${LITE_DIR}/src/kernel_interface_registry.cc
|
||||
${LITE_DIR}/src/kernel_registry.cc
|
||||
${LITE_DIR}/src/lite_model.cc
|
||||
${LITE_DIR}/src/tensorlist.cc
|
||||
${LITE_DIR}/src/tensor.cc
|
||||
|
|
|
@ -23,8 +23,13 @@ RegisterKernelInterface *RegisterKernelInterface::Instance() {
|
|||
return &instance;
|
||||
}
|
||||
|
||||
int RegisterKernelInterface::Reg(const std::string &vendor, const int op_type, KernelInterfaceCreator creator) {
|
||||
return lite::KernelInterfaceRegistry::Instance()->Reg(vendor, op_type, creator);
|
||||
int RegisterKernelInterface::Reg(const std::string &provider, int op_type, KernelInterfaceCreator creator) {
|
||||
return lite::KernelInterfaceRegistry::Instance()->Reg(provider, op_type, creator);
|
||||
}
|
||||
|
||||
int RegisterKernelInterface::CustomReg(const std::string &provider, const std::string &op_type,
|
||||
KernelInterfaceCreator creator) {
|
||||
return lite::KernelInterfaceRegistry::Instance()->CustomReg(provider, op_type, creator);
|
||||
}
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -32,7 +32,7 @@ struct CapabilityParam {
|
|||
class KernelInterface {
|
||||
public:
|
||||
virtual ~KernelInterface() = default;
|
||||
virtual int Infer(const std::vector<tensor::MSTensor *> &tensor_in, std::vector<tensor::MSTensor *> *outputs,
|
||||
virtual int Infer(const std::vector<tensor::MSTensor *> &inputs, const std::vector<tensor::MSTensor *> &outputs,
|
||||
const schema::Primitive *primitive) {
|
||||
return 0;
|
||||
}
|
||||
|
@ -47,7 +47,8 @@ typedef KernelInterface *(*KernelInterfaceCreator)();
|
|||
class RegisterKernelInterface {
|
||||
public:
|
||||
static RegisterKernelInterface *Instance();
|
||||
int Reg(const std::string &vendor, const int op_type, KernelInterfaceCreator creator);
|
||||
int CustomReg(const std::string &provider, const std::string &op_type, KernelInterfaceCreator creator);
|
||||
int Reg(const std::string &provider, int op_type, KernelInterfaceCreator creator);
|
||||
virtual ~RegisterKernelInterface() = default;
|
||||
|
||||
private:
|
||||
|
@ -56,14 +57,21 @@ class RegisterKernelInterface {
|
|||
|
||||
class KernelInterfaceReg {
|
||||
public:
|
||||
KernelInterfaceReg(const std::string &vendor, const int op_type, KernelInterfaceCreator creator) {
|
||||
RegisterKernelInterface::Instance()->Reg(vendor, op_type, creator);
|
||||
KernelInterfaceReg(const std::string &provider, int op_type, KernelInterfaceCreator creator) {
|
||||
RegisterKernelInterface::Instance()->Reg(provider, op_type, creator);
|
||||
}
|
||||
|
||||
KernelInterfaceReg(const std::string &provider, const std::string &op_type, KernelInterfaceCreator creator) {
|
||||
RegisterKernelInterface::Instance()->CustomReg(provider, op_type, creator);
|
||||
}
|
||||
~KernelInterfaceReg() = default;
|
||||
};
|
||||
|
||||
#define REGISTER_KERNEL_INTERFACE(vendor, op_type, creator) \
|
||||
static KernelInterfaceReg g_##vendor##op_type##_inter_reg(vendor, op_type, creator);
|
||||
#define REGISTER_KERNEL_INTERFACE(provider, op_type, creator) \
|
||||
static KernelInterfaceReg g_##provider##op_type##_inter_reg(provider, op_type, creator);
|
||||
|
||||
#define REGISTER_CUSTOM_KERNEL_INTERFACE(provider, op_type, creator) \
|
||||
static KernelInterfaceReg g_##provider##op_type##_custom_inter_reg(provider, op_type, creator);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -17,6 +17,8 @@
|
|||
#include "src/kernel_interface.h"
|
||||
#include "include/errorcode.h"
|
||||
#include "src/common/log_adapter.h"
|
||||
#include "src/common/version_manager.h"
|
||||
#include "schema/model_generated.h"
|
||||
|
||||
using mindspore::kernel::KernelInterfaceCreator;
|
||||
using mindspore::schema::PrimitiveType_MAX;
|
||||
|
@ -24,27 +26,89 @@ using mindspore::schema::PrimitiveType_MIN;
|
|||
namespace mindspore {
|
||||
namespace lite {
|
||||
namespace {
|
||||
static const auto kMaxKernelNum = PrimitiveType_MAX - PrimitiveType_MIN + 1;
|
||||
static const auto kMaxKernelNum = PrimitiveType_MAX - PrimitiveType_MIN;
|
||||
}
|
||||
|
||||
int KernelInterfaceRegistry::Reg(const std::string &vendor, const int &op_type, KernelInterfaceCreator creator) {
|
||||
auto vendor_hash = std::hash<std::string>{}(vendor);
|
||||
auto iter = kernel_interfaces_.find(vendor_hash);
|
||||
if (iter == kernel_interfaces_.end()) {
|
||||
kernel_interfaces_[vendor_hash] =
|
||||
reinterpret_cast<KernelInterfaceCreator *>(malloc(kMaxKernelNum * sizeof(KernelInterfaceCreator)));
|
||||
if (kernel_interfaces_[vendor_hash] == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc kernel dev delegate creator fail!";
|
||||
return RET_ERROR;
|
||||
}
|
||||
bool KernelInterfaceRegistry::CheckReg(const lite::Model::Node *node) {
|
||||
if (VersionManager::GetInstance()->GetSchemaVersion() == SCHEMA_V0) {
|
||||
return false;
|
||||
}
|
||||
auto primitive = static_cast<const schema::Primitive *>(node->primitive_);
|
||||
if (primitive == nullptr) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto op_type = primitive->value_type();
|
||||
if (op_type == schema::PrimitiveType_Custom) {
|
||||
return std::any_of(custom_interfaces_.begin(), custom_interfaces_.end(), [node](auto &&item) {
|
||||
if (item.second[node->name_] != nullptr) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
});
|
||||
}
|
||||
|
||||
return std::any_of(kernel_interfaces_.begin(), kernel_interfaces_.end(),
|
||||
[op_type, &mutex = this->mutex_](auto &&item) {
|
||||
std::unique_lock<std::mutex> lock(mutex);
|
||||
if (item.second[op_type] != nullptr) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
});
|
||||
}
|
||||
|
||||
int KernelInterfaceRegistry::CustomReg(const std::string &provider, const std::string &op_type,
|
||||
KernelInterfaceCreator creator) {
|
||||
custom_interfaces_[provider][op_type] = creator;
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
kernel::KernelInterface *KernelInterfaceRegistry::GetKernelInterface(const std::string &provider, int op_type) {
|
||||
if (op_type < PrimitiveType_MIN || op_type > kMaxKernelNum) {
|
||||
MS_LOG(ERROR) << "reg op_type invalid!op_type: " << op_type << ", max value: " << kMaxKernelNum;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
std::unique_lock<std::mutex> lock(mutex_);
|
||||
auto iter = kernel_interfaces_.find(provider);
|
||||
if (iter == kernel_interfaces_.end()) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto creator = iter->second[op_type];
|
||||
if (creator != nullptr) {
|
||||
return creator();
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
int KernelInterfaceRegistry::Reg(const std::string &provider, int op_type, KernelInterfaceCreator creator) {
|
||||
if (op_type < PrimitiveType_MIN || op_type > kMaxKernelNum) {
|
||||
MS_LOG(ERROR) << "reg op_type invalid!op_type: " << op_type << ", max value: " << kMaxKernelNum;
|
||||
return RET_ERROR;
|
||||
}
|
||||
kernel_interfaces_[vendor_hash][op_type] = creator;
|
||||
|
||||
std::unique_lock<std::mutex> lock(mutex_);
|
||||
auto iter = kernel_interfaces_.find(provider);
|
||||
if (iter == kernel_interfaces_.end()) {
|
||||
kernel_interfaces_[provider] =
|
||||
reinterpret_cast<KernelInterfaceCreator *>(malloc(kMaxKernelNum * sizeof(KernelInterfaceCreator)));
|
||||
if (kernel_interfaces_[provider] == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc kernel dev delegate creator fail!";
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
|
||||
kernel_interfaces_[provider][op_type] = creator;
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
KernelInterfaceRegistry::~KernelInterfaceRegistry() {
|
||||
for (auto &&item : kernel_interfaces_) {
|
||||
free(item.second);
|
||||
item.second = nullptr;
|
||||
}
|
||||
}
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -18,8 +18,10 @@
|
|||
#define MINDSPORE_LITE_SRC_KERNEL_DEV_DELEGATE_REGISTRY_H_
|
||||
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <map>
|
||||
#include <mutex>
|
||||
#include "src/kernel_interface.h"
|
||||
#include "include/model.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
|
@ -29,14 +31,21 @@ class KernelInterfaceRegistry {
|
|||
static KernelInterfaceRegistry instance;
|
||||
return &instance;
|
||||
}
|
||||
|
||||
int Reg(const std::string &vendor, const int &op_type, kernel::KernelInterfaceCreator creator);
|
||||
virtual ~KernelInterfaceRegistry() = default;
|
||||
bool CheckReg(const lite::Model::Node *node);
|
||||
kernel::KernelInterface *GetKernelInterface(const std::string &provider, int op_type);
|
||||
const std::map<std::string, kernel::KernelInterfaceCreator *> &kernel_interfaces() { return kernel_interfaces_; }
|
||||
int CustomReg(const std::string &provider, const std::string &op_type, kernel::KernelInterfaceCreator creator);
|
||||
int Reg(const std::string &provider, int op_type, kernel::KernelInterfaceCreator creator);
|
||||
virtual ~KernelInterfaceRegistry();
|
||||
|
||||
private:
|
||||
KernelInterfaceRegistry() = default;
|
||||
|
||||
std::unordered_map<size_t, kernel::KernelInterfaceCreator *> kernel_interfaces_;
|
||||
std::mutex mutex_;
|
||||
// key: provider
|
||||
std::map<std::string, kernel::KernelInterfaceCreator *> kernel_interfaces_;
|
||||
// key: provider key: custom type
|
||||
std::map<std::string, std::map<std::string, kernel::KernelInterfaceCreator>> custom_interfaces_;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -38,7 +38,7 @@ using mindspore::kernel::KernelKey;
|
|||
|
||||
namespace mindspore::lite {
|
||||
namespace {
|
||||
static const int kKernelMaxNum = (kNumberTypeEnd - kNumberTypeBegin + 1) * (PrimitiveType_MAX - PrimitiveType_MIN + 1);
|
||||
static const int kKernelMaxNum = (kNumberTypeEnd - kNumberTypeBegin - 1) * (PrimitiveType_MAX - PrimitiveType_MIN);
|
||||
} // namespace
|
||||
|
||||
KernelRegistry *KernelRegistry::GetInstance() {
|
||||
|
@ -56,50 +56,81 @@ KernelRegistry *KernelRegistry::GetInstance() {
|
|||
}
|
||||
|
||||
int KernelRegistry::GetFuncIndex(const kernel::KernelKey &desc) {
|
||||
int dType_index = static_cast<int>(desc.data_type) - kNumberTypeBegin;
|
||||
return dType_index * op_type_length_ + desc.type;
|
||||
if (desc.data_type >= kNumberTypeEnd) {
|
||||
return -1;
|
||||
}
|
||||
int data_type_index = static_cast<int>(desc.data_type) - kNumberTypeBegin - 1;
|
||||
if (data_type_index < 0) {
|
||||
return -1;
|
||||
}
|
||||
return data_type_index * op_type_length_ + desc.type;
|
||||
}
|
||||
|
||||
int KernelRegistry::RegKernel(const std::string &arch, const std::string &vendor, const TypeId data_type,
|
||||
const int type, kernel::CreateKernel creator) {
|
||||
auto vendor_hash = std::hash<std::string>{}(vendor);
|
||||
auto arch_hash = std::hash<std::string>{}(arch);
|
||||
auto iter = kernel_creators_.find(vendor_hash);
|
||||
if (iter == kernel_creators_.end()) {
|
||||
all_vendors_.insert(vendor);
|
||||
kernel_creators_[vendor_hash][arch_hash] =
|
||||
reinterpret_cast<CreateKernel *>(malloc(kKernelMaxNum * sizeof(CreateKernel)));
|
||||
if (kernel_creators_[vendor_hash][arch_hash] == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc kernel creator buffer fail! vendor: " << vendor << ",arch:" << arch;
|
||||
int KernelRegistry::RegCustomKernel(const std::string &arch, const std::string &provider, TypeId data_type,
|
||||
const std::string &type, CreateKernel creator) {
|
||||
if (data_type >= kNumberTypeEnd) {
|
||||
MS_LOG(ERROR) << "invalid data_type: " << data_type << "!provider: " << provider;
|
||||
return RET_ERROR;
|
||||
}
|
||||
std::unique_lock<std::mutex> lock(lock_);
|
||||
auto iter = custom_kernel_creators_.find(provider);
|
||||
if (iter == custom_kernel_creators_.end()) {
|
||||
custom_kernel_creators_[provider][arch] =
|
||||
reinterpret_cast<CreateKernel *>(malloc(data_type_length_ * sizeof(CreateKernel)));
|
||||
if (custom_kernel_creators_[provider][arch] == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc custom kernel creator fail!provider: " << provider << ", arch: " << arch;
|
||||
return RET_ERROR;
|
||||
}
|
||||
memset(kernel_creators_[vendor_hash][arch_hash], 0, kKernelMaxNum * sizeof(CreateKernel));
|
||||
memset(custom_kernel_creators_[provider][arch], 0, data_type_length_ * sizeof(CreateKernel));
|
||||
}
|
||||
|
||||
int data_type_index = data_type - kNumberTypeBegin - 1;
|
||||
if (data_type_index < 0) {
|
||||
MS_LOG(ERROR) << "invalid data_type: " << data_type << "!provider: " << provider;
|
||||
return RET_ERROR;
|
||||
}
|
||||
custom_kernel_creators_[provider][arch][data_type_index] = creator;
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int KernelRegistry::RegKernel(const std::string &arch, const std::string &provider, TypeId data_type, int type,
|
||||
kernel::CreateKernel creator) {
|
||||
std::unique_lock<std::mutex> lock(lock_);
|
||||
auto iter = kernel_creators_.find(provider);
|
||||
if (iter == kernel_creators_.end()) {
|
||||
kernel_creators_[provider][arch] = reinterpret_cast<CreateKernel *>(malloc(kKernelMaxNum * sizeof(CreateKernel)));
|
||||
if (kernel_creators_[provider][arch] == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc kernel creator buffer fail! provider: " << provider << ",arch:" << arch;
|
||||
return RET_ERROR;
|
||||
}
|
||||
memset(kernel_creators_[provider][arch], 0, kKernelMaxNum * sizeof(CreateKernel));
|
||||
} else {
|
||||
auto iter_arch = iter->second.find(arch_hash);
|
||||
auto iter_arch = iter->second.find(arch);
|
||||
if (iter_arch == iter->second.end()) {
|
||||
iter->second[arch_hash] = reinterpret_cast<CreateKernel *>(malloc(kKernelMaxNum * sizeof(CreateKernel)));
|
||||
if (iter->second[arch_hash] == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc kernel creator buffer fail! vendor: " << vendor << ",arch:" << arch;
|
||||
iter->second[arch] = reinterpret_cast<CreateKernel *>(malloc(kKernelMaxNum * sizeof(CreateKernel)));
|
||||
if (iter->second[arch] == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc kernel creator buffer fail! provider: " << provider << ",arch:" << arch;
|
||||
return RET_ERROR;
|
||||
}
|
||||
memset(iter->second[arch_hash], 0, kKernelMaxNum * sizeof(CreateKernel));
|
||||
memset(iter->second[arch], 0, kKernelMaxNum * sizeof(CreateKernel));
|
||||
}
|
||||
}
|
||||
|
||||
KernelKey desc = {kCPU, data_type, type, arch, vendor};
|
||||
KernelKey desc = {kCPU, data_type, type, arch, provider};
|
||||
int index = GetFuncIndex(desc);
|
||||
if (index >= kKernelMaxNum) {
|
||||
if (index >= kKernelMaxNum || index < 0) {
|
||||
MS_LOG(ERROR) << "invalid kernel key, arch " << arch << ", data_type" << data_type << ",op type " << type;
|
||||
return RET_ERROR;
|
||||
}
|
||||
kernel_creators_[vendor_hash][arch_hash][index] = creator;
|
||||
|
||||
kernel_creators_[provider][arch][index] = creator;
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int KernelRegistry::Init() { return RET_OK; }
|
||||
|
||||
kernel::KernelCreator KernelRegistry::GetCreator(const KernelKey &desc) {
|
||||
if (desc.vendor == kBuiltin) {
|
||||
if (desc.provider == kBuiltin) {
|
||||
int index = GetCreatorFuncIndex(desc);
|
||||
if (index >= array_size_ || index < 0) {
|
||||
MS_LOG(ERROR) << "invalid kernel key, arch " << desc.arch << ", data_type" << desc.data_type << ",op type "
|
||||
|
@ -108,29 +139,29 @@ kernel::KernelCreator KernelRegistry::GetCreator(const KernelKey &desc) {
|
|||
}
|
||||
return creator_arrays_[index];
|
||||
}
|
||||
MS_LOG(ERROR) << "Call wrong interface!vendor: " << desc.vendor;
|
||||
MS_LOG(ERROR) << "Call wrong interface!provider: " << desc.provider;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
kernel::CreateKernel KernelRegistry::GetDelegateCreator(const kernel::KernelKey &desc) {
|
||||
auto vendor_hash = std::hash<std::string>{}(desc.vendor);
|
||||
auto it_by_vendor = kernel_creators_.find(vendor_hash);
|
||||
if (it_by_vendor == kernel_creators_.end()) {
|
||||
return nullptr;
|
||||
}
|
||||
auto arch_hash = std::hash<std::string>{}(desc.kernel_arch);
|
||||
auto it_by_arch = it_by_vendor->second.find(arch_hash);
|
||||
if (it_by_arch == it_by_vendor->second.end()) {
|
||||
return nullptr;
|
||||
}
|
||||
kernel::CreateKernel creator = nullptr;
|
||||
auto index = GetFuncIndex(desc);
|
||||
if (index < 0 || index >= kKernelMaxNum) {
|
||||
MS_LOG(ERROR) << "invalid kernel key, arch " << desc.kernel_arch << ", data_type" << desc.data_type << ",op type "
|
||||
<< desc.type << ", vendor: " << desc.vendor;
|
||||
if (index >= kKernelMaxNum || index < 0) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
return it_by_arch->second[index];
|
||||
std::unique_lock<std::mutex> lock(lock_);
|
||||
for (auto &&item : kernel_creators_) {
|
||||
for (auto &&arch_item : item.second) {
|
||||
creator = arch_item.second[index];
|
||||
if (creator != nullptr) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (creator != nullptr) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
return creator;
|
||||
}
|
||||
|
||||
int KernelRegistry::GetCreatorFuncIndex(const kernel::KernelKey desc) {
|
||||
|
@ -172,6 +203,19 @@ KernelRegistry::~KernelRegistry() {
|
|||
free(instance->creator_arrays_);
|
||||
instance->creator_arrays_ = nullptr;
|
||||
}
|
||||
|
||||
for (auto &&item : kernel_creators_) {
|
||||
for (auto &&creator : item.second) {
|
||||
free(creator.second);
|
||||
creator.second = nullptr;
|
||||
}
|
||||
}
|
||||
for (auto &&item : custom_kernel_creators_) {
|
||||
for (auto &&creator : item.second) {
|
||||
free(creator.second);
|
||||
creator.second = nullptr;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool KernelRegistry::SupportKernel(const KernelKey &key) {
|
||||
|
@ -184,7 +228,7 @@ int KernelRegistry::GetKernel(const std::vector<Tensor *> &in_tensors, const std
|
|||
kernel::LiteKernel **kernel, const void *primitive) {
|
||||
MS_ASSERT(ctx != nullptr);
|
||||
MS_ASSERT(kernel != nullptr);
|
||||
if (key.vendor == kBuiltin) {
|
||||
if (key.provider == kBuiltin) {
|
||||
auto creator = GetCreator(key);
|
||||
if (creator != nullptr) {
|
||||
*kernel = creator(in_tensors, out_tensors, parameter, ctx, key);
|
||||
|
@ -196,17 +240,18 @@ int KernelRegistry::GetKernel(const std::vector<Tensor *> &in_tensors, const std
|
|||
}
|
||||
} else {
|
||||
auto creator = GetDelegateCreator(key);
|
||||
if (creator != nullptr) {
|
||||
std::vector<tensor::MSTensor *> tensors_in;
|
||||
Tensor2MSTensor(std::move(in_tensors), &tensors_in);
|
||||
std::vector<tensor::MSTensor *> tensors_out;
|
||||
Tensor2MSTensor(std::move(out_tensors), &tensors_out);
|
||||
*kernel = creator(tensors_in, tensors_out, static_cast<const schema::Primitive *>(primitive), ctx);
|
||||
if (*kernel != nullptr) {
|
||||
return RET_OK;
|
||||
}
|
||||
return RET_ERROR;
|
||||
if (creator == nullptr) {
|
||||
return RET_NOT_SUPPORT;
|
||||
}
|
||||
std::vector<tensor::MSTensor *> tensors_in;
|
||||
Tensor2MSTensor(std::move(in_tensors), &tensors_in);
|
||||
std::vector<tensor::MSTensor *> tensors_out;
|
||||
Tensor2MSTensor(std::move(out_tensors), &tensors_out);
|
||||
*kernel = creator(tensors_in, tensors_out, static_cast<const schema::Primitive *>(primitive), ctx);
|
||||
if (*kernel != nullptr) {
|
||||
return RET_OK;
|
||||
}
|
||||
return RET_ERROR;
|
||||
}
|
||||
return RET_NOT_SUPPORT;
|
||||
}
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
#define MINDSPORE_LITE_SRC_KERNEL_REGISTRY_H_
|
||||
|
||||
#include <string>
|
||||
#include <map>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
#include <set>
|
||||
|
@ -42,9 +43,14 @@ class KernelRegistry {
|
|||
virtual kernel::CreateKernel GetDelegateCreator(const kernel::KernelKey &desc);
|
||||
int GetCreatorFuncIndex(kernel::KernelKey desc);
|
||||
int GetFuncIndex(const kernel::KernelKey &desc);
|
||||
const std::map<std::string, std::unordered_map<std::string, kernel::CreateKernel *>> &kernel_creators() {
|
||||
return kernel_creators_;
|
||||
}
|
||||
int RegCustomKernel(const std::string &arch, const std::string &vendor, TypeId data_type, const std::string &type,
|
||||
kernel::CreateKernel creator);
|
||||
void RegKernel(kernel::KernelKey desc, kernel::KernelCreator creator);
|
||||
void RegKernel(kernel::KERNEL_ARCH arch, TypeId data_type, int type, kernel::KernelCreator creator);
|
||||
int RegKernel(const std::string &arch, const std::string &vendor, const TypeId data_type, const int type,
|
||||
int RegKernel(const std::string &arch, const std::string &vendor, TypeId data_type, int type,
|
||||
kernel::CreateKernel creator);
|
||||
bool Merge(const std::unordered_map<kernel::KernelKey, kernel::KernelCreator> &newCreators);
|
||||
bool SupportKernel(const kernel::KernelKey &key);
|
||||
|
@ -58,8 +64,8 @@ class KernelRegistry {
|
|||
static const int op_type_length_{PrimitiveType_MAX - PrimitiveType_MIN + 1};
|
||||
static const int array_size_{device_type_length_ * data_type_length_ * op_type_length_};
|
||||
kernel::KernelCreator *creator_arrays_ = nullptr;
|
||||
std::unordered_map<std::size_t, std::unordered_map<std::size_t, kernel::CreateKernel *>> kernel_creators_;
|
||||
std::set<std::string> all_vendors_;
|
||||
std::map<std::string, std::unordered_map<std::string, kernel::CreateKernel *>> kernel_creators_;
|
||||
std::map<std::string, std::unordered_map<std::string, kernel::CreateKernel *>> custom_kernel_creators_;
|
||||
|
||||
private:
|
||||
std::mutex lock_;
|
||||
|
|
|
@ -41,11 +41,11 @@ struct KernelKey {
|
|||
TypeId data_type;
|
||||
int type;
|
||||
std::string kernel_arch;
|
||||
std::string vendor{kBuiltin};
|
||||
std::string provider{kBuiltin};
|
||||
|
||||
bool operator<(const KernelKey &dst) const {
|
||||
if (vendor != dst.vendor) {
|
||||
return vendor < dst.vendor;
|
||||
if (provider != dst.provider) {
|
||||
return provider < dst.provider;
|
||||
} else if (kernel_arch != dst.kernel_arch) {
|
||||
return kernel_arch < dst.kernel_arch;
|
||||
} else if (arch != dst.arch) {
|
||||
|
|
|
@ -23,9 +23,14 @@ RegisterKernel *RegisterKernel::GetInstance() {
|
|||
return &instance;
|
||||
}
|
||||
|
||||
int RegisterKernel::RegKernel(const std::string &arch, const std::string &vendor, const TypeId data_type,
|
||||
const int op_type, CreateKernel creator) {
|
||||
return lite::KernelRegistry::GetInstance()->RegKernel(arch, vendor, data_type, op_type, creator);
|
||||
int RegisterKernel::RegCustomKernel(const std::string &arch, const std::string &provider, TypeId data_type,
|
||||
const std::string &type, CreateKernel creator) {
|
||||
return lite::KernelRegistry::GetInstance()->RegCustomKernel(arch, provider, data_type, type, creator);
|
||||
}
|
||||
|
||||
int RegisterKernel::RegKernel(const std::string &arch, const std::string &provider, TypeId data_type, int op_type,
|
||||
CreateKernel creator) {
|
||||
return lite::KernelRegistry::GetInstance()->RegKernel(arch, provider, data_type, op_type, creator);
|
||||
}
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -29,22 +29,30 @@ typedef kernel::LiteKernel *(*CreateKernel)(const std::vector<tensor::MSTensor *
|
|||
class RegisterKernel {
|
||||
public:
|
||||
static RegisterKernel *GetInstance();
|
||||
int RegKernel(const std::string &arch, const std::string &vendor, const TypeId data_type, const int type,
|
||||
CreateKernel creator);
|
||||
int RegKernel(const std::string &arch, const std::string &provider, TypeId data_type, int type, CreateKernel creator);
|
||||
int RegCustomKernel(const std::string &arch, const std::string &provider, TypeId data_type, const std::string &type,
|
||||
CreateKernel creator);
|
||||
};
|
||||
|
||||
class KernelReg {
|
||||
public:
|
||||
~KernelReg() = default;
|
||||
|
||||
KernelReg(const std::string &arch, const std::string &vendor, const TypeId data_type, const int op_type,
|
||||
KernelReg(const std::string &arch, const std::string &provider, TypeId data_type, int op_type, CreateKernel creator) {
|
||||
RegisterKernel::GetInstance()->RegKernel(arch, provider, data_type, op_type, creator);
|
||||
}
|
||||
|
||||
KernelReg(const std::string &arch, const std::string &provider, TypeId data_type, const std::string &op_type,
|
||||
CreateKernel creator) {
|
||||
RegisterKernel::GetInstance()->RegKernel(arch, vendor, data_type, op_type, creator);
|
||||
RegisterKernel::GetInstance()->RegCustomKernel(arch, provider, data_type, op_type, creator);
|
||||
}
|
||||
};
|
||||
|
||||
#define REGISTER_KERNEL(arch, vendor, data_type, op_type, creator) \
|
||||
static KernelReg g_##arch##vendor##data_type##op_type##kernelReg(arch, vendor, data_type, op_type, creator);
|
||||
#define REGISTER_KERNEL(arch, provider, data_type, op_type, creator) \
|
||||
static KernelReg g_##arch##provider##data_type##op_type##kernelReg(arch, provider, data_type, op_type, creator);
|
||||
|
||||
#define REGISTER_CUSTOM_KERNEL(arch, provider, data_type, op_type, creator) \
|
||||
static KernelReg g_##arch##provider##data_type##op_type##kernelReg(arch, provider, data_type, op_type, creator);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -14,14 +14,43 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
#include "src/runtime/infer_manager.h"
|
||||
#include <algorithm>
|
||||
#include "src/common/prim_util.h"
|
||||
#include "src/common/tensor_util.h"
|
||||
#include "schema/model_generated.h"
|
||||
#include "include/errorcode.h"
|
||||
#include "nnacl/errorcode.h"
|
||||
#include "src/tensorlist.h"
|
||||
#include "src/kernel_interface_registry.h"
|
||||
#include "src/kernel_registry.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
int KernelInferShape(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs,
|
||||
const void *primitive) {
|
||||
std::vector<tensor::MSTensor *> in_tensors;
|
||||
std::copy(inputs.begin(), inputs.end(), std::back_inserter(in_tensors));
|
||||
std::vector<tensor::MSTensor *> out_tensors;
|
||||
std::copy(outputs.begin(), outputs.end(), std::back_inserter(out_tensors));
|
||||
int op_type = GetPrimitiveType(primitive);
|
||||
for (auto &&item : KernelInterfaceRegistry::Instance()->kernel_interfaces()) {
|
||||
auto provider = item.first;
|
||||
auto kernel_interface = KernelInterfaceRegistry::Instance()->GetKernelInterface(provider, op_type);
|
||||
if (kernel_interface == nullptr) {
|
||||
continue;
|
||||
}
|
||||
auto ret = kernel_interface->Infer(in_tensors, out_tensors, static_cast<const schema::Primitive *>(primitive));
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "provider: " << provider << ", op_type: " << PrimitiveTypeName(GetPrimitiveType(primitive))
|
||||
<< " infer fail!";
|
||||
return ret;
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
int KernelInferShape(const std::vector<lite::Tensor *> &inputs, std::vector<lite::Tensor *> *outputs,
|
||||
OpParameter *parameter) {
|
||||
std::vector<TensorC *> in_tensors;
|
||||
|
|
|
@ -25,8 +25,10 @@
|
|||
#include "nnacl/infer/infer.h"
|
||||
|
||||
namespace mindspore::lite {
|
||||
int KernelInferShape(const std::vector<lite::Tensor *> &tensors_in, std::vector<lite::Tensor *> *outputs,
|
||||
int KernelInferShape(const std::vector<lite::Tensor *> &inputs, std::vector<lite::Tensor *> *outputs,
|
||||
OpParameter *parameter);
|
||||
int KernelInferShape(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs,
|
||||
const void *primitive);
|
||||
class InferManager {
|
||||
public:
|
||||
static InferManager *GetInstance() {
|
||||
|
|
|
@ -48,6 +48,7 @@
|
|||
#if defined(ENABLE_ARM) && defined(ENABLE_FP16)
|
||||
#include "src/runtime/kernel/arm/fp16/fp16_op_handler.h"
|
||||
#endif
|
||||
#include "src/kernel_interface_registry.h"
|
||||
|
||||
namespace mindspore::lite {
|
||||
using kernel::KERNEL_ARCH::kCPU;
|
||||
|
@ -130,6 +131,10 @@ int Scheduler::InferNodeShape(const lite::Model::Node *node) {
|
|||
std::vector<Tensor *> inputs;
|
||||
std::vector<Tensor *> outputs;
|
||||
FindNodeInoutTensors(*node, &inputs, &outputs);
|
||||
if (KernelInterfaceRegistry::Instance()->CheckReg(node)) {
|
||||
return KernelInferShape(inputs, outputs, node->primitive_);
|
||||
}
|
||||
|
||||
int schema_version = VersionManager::GetInstance()->GetSchemaVersion();
|
||||
auto parame_gen =
|
||||
PopulateRegistry::GetInstance()->GetParameterCreator(GetPrimitiveType(node->primitive_), schema_version);
|
||||
|
@ -432,6 +437,18 @@ int Scheduler::FindNpuKernel(const std::vector<Tensor *> &in_tensors, const std:
|
|||
return RET_NOT_SUPPORT;
|
||||
}
|
||||
|
||||
int Scheduler::FindProviderKernel(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors,
|
||||
const Model::Node *node, TypeId data_type, kernel::LiteKernel **kernel) {
|
||||
int ret = RET_ERROR;
|
||||
if (KernelRegistry::GetInstance()->kernel_creators().size() != 0 &&
|
||||
VersionManager::GetInstance()->GetSchemaVersion() != SCHEMA_V0) {
|
||||
kernel::KernelKey desc{kCPU, data_type, GetPrimitiveType(node->primitive_), "", ""};
|
||||
ret = KernelRegistry::GetInstance()->GetKernel(in_tensors, out_tensors, context_, desc, nullptr, kernel,
|
||||
node->primitive_);
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
kernel::LiteKernel *Scheduler::FindBackendKernel(const std::vector<Tensor *> &in_tensors,
|
||||
const std::vector<Tensor *> &out_tensors, const Model::Node *node,
|
||||
TypeId prefer_data_type) {
|
||||
|
@ -439,14 +456,18 @@ kernel::LiteKernel *Scheduler::FindBackendKernel(const std::vector<Tensor *> &in
|
|||
// why we need this
|
||||
TypeId data_type =
|
||||
(node->quant_type_ == schema::QuantType_QUANT_WEIGHT) ? kNumberTypeFloat32 : GetFirstFp32Fp16OrInt8Type(in_tensors);
|
||||
kernel::LiteKernel *kernel = nullptr;
|
||||
int status;
|
||||
status = FindProviderKernel(in_tensors, out_tensors, node, data_type, &kernel);
|
||||
if (status == RET_OK && kernel != nullptr) {
|
||||
return kernel;
|
||||
}
|
||||
OpParameter *op_parameter = op_parameters_[node->output_indices_.at(0)];
|
||||
if (op_parameter == nullptr) {
|
||||
MS_LOG(ERROR) << "Can not find OpParameter!type: " << PrimitiveTypeName(GetPrimitiveType(node->primitive_));
|
||||
return nullptr;
|
||||
}
|
||||
kernel::KernelKey desc{kCPU, data_type, static_cast<schema::PrimitiveType>(op_parameter->type_)};
|
||||
kernel::LiteKernel *kernel = nullptr;
|
||||
int status;
|
||||
#ifdef SUPPORT_GPU
|
||||
// if (node->device_type_ == DT_GPU || node->device_type_ == DEFAULT) {
|
||||
status = FindGpuKernel(in_tensors, out_tensors, op_parameter, desc, &kernel);
|
||||
|
|
|
@ -60,6 +60,7 @@ class Scheduler {
|
|||
kernel::LiteKernel *FindBackendKernel(const std::vector<Tensor *> &in_tensors,
|
||||
const std::vector<Tensor *> &out_tensors, const Model::Node *node,
|
||||
TypeId prefer_data_type = kTypeUnknown);
|
||||
|
||||
int FindCpuKernel(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors,
|
||||
OpParameter *op_parameter, const kernel::KernelKey &desc, TypeId kernel_data_type,
|
||||
kernel::LiteKernel **kernel);
|
||||
|
@ -67,6 +68,9 @@ class Scheduler {
|
|||
OpParameter *op_parameter, const kernel::KernelKey &desc, kernel::LiteKernel **kernel);
|
||||
int FindNpuKernel(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors,
|
||||
OpParameter *op_parameter, const kernel::KernelKey &desc, kernel::LiteKernel **kernel);
|
||||
|
||||
int FindProviderKernel(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors,
|
||||
const Model::Node *node, TypeId data_type, kernel::LiteKernel **kernel);
|
||||
// schedule a partial node to a subgraph_kernel
|
||||
kernel::LiteKernel *SchedulePartialToKernel(const lite::Model::Node *src_node);
|
||||
// schedule a node to a kernel
|
||||
|
|
|
@ -141,6 +141,8 @@ set(TEST_LITE_SRC
|
|||
${LITE_DIR}/src/tensorlist.cc
|
||||
${LITE_DIR}/src/executor.cc
|
||||
${LITE_DIR}/src/inner_context.cc
|
||||
${LITE_DIR}/src/kernel_interface.cc
|
||||
${LITE_DIR}/src/kernel_interface_registry.cc
|
||||
${LITE_DIR}/src/kernel_registry.cc
|
||||
${LITE_DIR}/src/register_kernel.cc
|
||||
${LITE_DIR}/src/lite_kernel.cc
|
||||
|
|
|
@ -123,6 +123,7 @@ set(LITE_SRC
|
|||
${SRC_DIR}/tensor.cc
|
||||
${SRC_DIR}/ms_tensor.cc
|
||||
${SRC_DIR}/tensorlist.cc
|
||||
${SRC_DIR}/kernel_interface_registry.cc
|
||||
${SRC_DIR}/kernel_registry.cc
|
||||
${SRC_DIR}/register_kernel.cc
|
||||
${SRC_DIR}/lite_kernel.cc
|
||||
|
|
Loading…
Reference in New Issue