!16020 [MS][LITE][STABLE]support register custom kernel

From: @jpc_chenjianping
Reviewed-by: @zhang_xue_tong,@zhanghaibo5
Signed-off-by: @zhang_xue_tong
This commit is contained in:
mindspore-ci-bot 2021-05-08 09:50:14 +08:00 committed by Gitee
commit b8d64f2ae9
16 changed files with 305 additions and 94 deletions

View File

@ -134,6 +134,8 @@ set(LITE_SRC
${LITE_DIR}/src/common/prim_util.cc
${LITE_DIR}/src/common/tensor_util.cc
${LITE_DIR}/src/runtime/infer_manager.cc
${LITE_DIR}/src/kernel_interface_registry.cc
${LITE_DIR}/src/kernel_registry.cc
${LITE_DIR}/src/lite_model.cc
${LITE_DIR}/src/tensorlist.cc
${LITE_DIR}/src/tensor.cc

View File

@ -23,8 +23,13 @@ RegisterKernelInterface *RegisterKernelInterface::Instance() {
return &instance;
}
int RegisterKernelInterface::Reg(const std::string &vendor, const int op_type, KernelInterfaceCreator creator) {
return lite::KernelInterfaceRegistry::Instance()->Reg(vendor, op_type, creator);
int RegisterKernelInterface::Reg(const std::string &provider, int op_type, KernelInterfaceCreator creator) {
return lite::KernelInterfaceRegistry::Instance()->Reg(provider, op_type, creator);
}
int RegisterKernelInterface::CustomReg(const std::string &provider, const std::string &op_type,
KernelInterfaceCreator creator) {
return lite::KernelInterfaceRegistry::Instance()->CustomReg(provider, op_type, creator);
}
} // namespace kernel
} // namespace mindspore

View File

@ -32,7 +32,7 @@ struct CapabilityParam {
class KernelInterface {
public:
virtual ~KernelInterface() = default;
virtual int Infer(const std::vector<tensor::MSTensor *> &tensor_in, std::vector<tensor::MSTensor *> *outputs,
virtual int Infer(const std::vector<tensor::MSTensor *> &inputs, const std::vector<tensor::MSTensor *> &outputs,
const schema::Primitive *primitive) {
return 0;
}
@ -47,7 +47,8 @@ typedef KernelInterface *(*KernelInterfaceCreator)();
class RegisterKernelInterface {
public:
static RegisterKernelInterface *Instance();
int Reg(const std::string &vendor, const int op_type, KernelInterfaceCreator creator);
int CustomReg(const std::string &provider, const std::string &op_type, KernelInterfaceCreator creator);
int Reg(const std::string &provider, int op_type, KernelInterfaceCreator creator);
virtual ~RegisterKernelInterface() = default;
private:
@ -56,14 +57,21 @@ class RegisterKernelInterface {
class KernelInterfaceReg {
public:
KernelInterfaceReg(const std::string &vendor, const int op_type, KernelInterfaceCreator creator) {
RegisterKernelInterface::Instance()->Reg(vendor, op_type, creator);
KernelInterfaceReg(const std::string &provider, int op_type, KernelInterfaceCreator creator) {
RegisterKernelInterface::Instance()->Reg(provider, op_type, creator);
}
KernelInterfaceReg(const std::string &provider, const std::string &op_type, KernelInterfaceCreator creator) {
RegisterKernelInterface::Instance()->CustomReg(provider, op_type, creator);
}
~KernelInterfaceReg() = default;
};
#define REGISTER_KERNEL_INTERFACE(vendor, op_type, creator) \
static KernelInterfaceReg g_##vendor##op_type##_inter_reg(vendor, op_type, creator);
#define REGISTER_KERNEL_INTERFACE(provider, op_type, creator) \
static KernelInterfaceReg g_##provider##op_type##_inter_reg(provider, op_type, creator);
#define REGISTER_CUSTOM_KERNEL_INTERFACE(provider, op_type, creator) \
static KernelInterfaceReg g_##provider##op_type##_custom_inter_reg(provider, op_type, creator);
} // namespace kernel
} // namespace mindspore

View File

@ -17,6 +17,8 @@
#include "src/kernel_interface.h"
#include "include/errorcode.h"
#include "src/common/log_adapter.h"
#include "src/common/version_manager.h"
#include "schema/model_generated.h"
using mindspore::kernel::KernelInterfaceCreator;
using mindspore::schema::PrimitiveType_MAX;
@ -24,27 +26,89 @@ using mindspore::schema::PrimitiveType_MIN;
namespace mindspore {
namespace lite {
namespace {
static const auto kMaxKernelNum = PrimitiveType_MAX - PrimitiveType_MIN + 1;
static const auto kMaxKernelNum = PrimitiveType_MAX - PrimitiveType_MIN;
}
int KernelInterfaceRegistry::Reg(const std::string &vendor, const int &op_type, KernelInterfaceCreator creator) {
auto vendor_hash = std::hash<std::string>{}(vendor);
auto iter = kernel_interfaces_.find(vendor_hash);
if (iter == kernel_interfaces_.end()) {
kernel_interfaces_[vendor_hash] =
reinterpret_cast<KernelInterfaceCreator *>(malloc(kMaxKernelNum * sizeof(KernelInterfaceCreator)));
if (kernel_interfaces_[vendor_hash] == nullptr) {
MS_LOG(ERROR) << "malloc kernel dev delegate creator fail!";
return RET_ERROR;
}
bool KernelInterfaceRegistry::CheckReg(const lite::Model::Node *node) {
if (VersionManager::GetInstance()->GetSchemaVersion() == SCHEMA_V0) {
return false;
}
auto primitive = static_cast<const schema::Primitive *>(node->primitive_);
if (primitive == nullptr) {
return false;
}
auto op_type = primitive->value_type();
if (op_type == schema::PrimitiveType_Custom) {
return std::any_of(custom_interfaces_.begin(), custom_interfaces_.end(), [node](auto &&item) {
if (item.second[node->name_] != nullptr) {
return true;
}
return false;
});
}
return std::any_of(kernel_interfaces_.begin(), kernel_interfaces_.end(),
[op_type, &mutex = this->mutex_](auto &&item) {
std::unique_lock<std::mutex> lock(mutex);
if (item.second[op_type] != nullptr) {
return true;
}
return false;
});
}
int KernelInterfaceRegistry::CustomReg(const std::string &provider, const std::string &op_type,
KernelInterfaceCreator creator) {
custom_interfaces_[provider][op_type] = creator;
return RET_OK;
}
kernel::KernelInterface *KernelInterfaceRegistry::GetKernelInterface(const std::string &provider, int op_type) {
if (op_type < PrimitiveType_MIN || op_type > kMaxKernelNum) {
MS_LOG(ERROR) << "reg op_type invalid!op_type: " << op_type << ", max value: " << kMaxKernelNum;
return nullptr;
}
std::unique_lock<std::mutex> lock(mutex_);
auto iter = kernel_interfaces_.find(provider);
if (iter == kernel_interfaces_.end()) {
return nullptr;
}
auto creator = iter->second[op_type];
if (creator != nullptr) {
return creator();
}
return nullptr;
}
int KernelInterfaceRegistry::Reg(const std::string &provider, int op_type, KernelInterfaceCreator creator) {
if (op_type < PrimitiveType_MIN || op_type > kMaxKernelNum) {
MS_LOG(ERROR) << "reg op_type invalid!op_type: " << op_type << ", max value: " << kMaxKernelNum;
return RET_ERROR;
}
kernel_interfaces_[vendor_hash][op_type] = creator;
std::unique_lock<std::mutex> lock(mutex_);
auto iter = kernel_interfaces_.find(provider);
if (iter == kernel_interfaces_.end()) {
kernel_interfaces_[provider] =
reinterpret_cast<KernelInterfaceCreator *>(malloc(kMaxKernelNum * sizeof(KernelInterfaceCreator)));
if (kernel_interfaces_[provider] == nullptr) {
MS_LOG(ERROR) << "malloc kernel dev delegate creator fail!";
return RET_ERROR;
}
}
kernel_interfaces_[provider][op_type] = creator;
return RET_OK;
}
KernelInterfaceRegistry::~KernelInterfaceRegistry() {
for (auto &&item : kernel_interfaces_) {
free(item.second);
item.second = nullptr;
}
}
} // namespace lite
} // namespace mindspore

View File

@ -18,8 +18,10 @@
#define MINDSPORE_LITE_SRC_KERNEL_DEV_DELEGATE_REGISTRY_H_
#include <string>
#include <unordered_map>
#include <map>
#include <mutex>
#include "src/kernel_interface.h"
#include "include/model.h"
namespace mindspore {
namespace lite {
@ -29,14 +31,21 @@ class KernelInterfaceRegistry {
static KernelInterfaceRegistry instance;
return &instance;
}
int Reg(const std::string &vendor, const int &op_type, kernel::KernelInterfaceCreator creator);
virtual ~KernelInterfaceRegistry() = default;
bool CheckReg(const lite::Model::Node *node);
kernel::KernelInterface *GetKernelInterface(const std::string &provider, int op_type);
const std::map<std::string, kernel::KernelInterfaceCreator *> &kernel_interfaces() { return kernel_interfaces_; }
int CustomReg(const std::string &provider, const std::string &op_type, kernel::KernelInterfaceCreator creator);
int Reg(const std::string &provider, int op_type, kernel::KernelInterfaceCreator creator);
virtual ~KernelInterfaceRegistry();
private:
KernelInterfaceRegistry() = default;
std::unordered_map<size_t, kernel::KernelInterfaceCreator *> kernel_interfaces_;
std::mutex mutex_;
// key: provider
std::map<std::string, kernel::KernelInterfaceCreator *> kernel_interfaces_;
// key: provider key: custom type
std::map<std::string, std::map<std::string, kernel::KernelInterfaceCreator>> custom_interfaces_;
};
} // namespace lite
} // namespace mindspore

View File

@ -38,7 +38,7 @@ using mindspore::kernel::KernelKey;
namespace mindspore::lite {
namespace {
static const int kKernelMaxNum = (kNumberTypeEnd - kNumberTypeBegin + 1) * (PrimitiveType_MAX - PrimitiveType_MIN + 1);
static const int kKernelMaxNum = (kNumberTypeEnd - kNumberTypeBegin - 1) * (PrimitiveType_MAX - PrimitiveType_MIN);
} // namespace
KernelRegistry *KernelRegistry::GetInstance() {
@ -56,50 +56,81 @@ KernelRegistry *KernelRegistry::GetInstance() {
}
int KernelRegistry::GetFuncIndex(const kernel::KernelKey &desc) {
int dType_index = static_cast<int>(desc.data_type) - kNumberTypeBegin;
return dType_index * op_type_length_ + desc.type;
if (desc.data_type >= kNumberTypeEnd) {
return -1;
}
int data_type_index = static_cast<int>(desc.data_type) - kNumberTypeBegin - 1;
if (data_type_index < 0) {
return -1;
}
return data_type_index * op_type_length_ + desc.type;
}
int KernelRegistry::RegKernel(const std::string &arch, const std::string &vendor, const TypeId data_type,
const int type, kernel::CreateKernel creator) {
auto vendor_hash = std::hash<std::string>{}(vendor);
auto arch_hash = std::hash<std::string>{}(arch);
auto iter = kernel_creators_.find(vendor_hash);
if (iter == kernel_creators_.end()) {
all_vendors_.insert(vendor);
kernel_creators_[vendor_hash][arch_hash] =
reinterpret_cast<CreateKernel *>(malloc(kKernelMaxNum * sizeof(CreateKernel)));
if (kernel_creators_[vendor_hash][arch_hash] == nullptr) {
MS_LOG(ERROR) << "malloc kernel creator buffer fail! vendor: " << vendor << ",arch:" << arch;
int KernelRegistry::RegCustomKernel(const std::string &arch, const std::string &provider, TypeId data_type,
const std::string &type, CreateKernel creator) {
if (data_type >= kNumberTypeEnd) {
MS_LOG(ERROR) << "invalid data_type: " << data_type << "!provider: " << provider;
return RET_ERROR;
}
std::unique_lock<std::mutex> lock(lock_);
auto iter = custom_kernel_creators_.find(provider);
if (iter == custom_kernel_creators_.end()) {
custom_kernel_creators_[provider][arch] =
reinterpret_cast<CreateKernel *>(malloc(data_type_length_ * sizeof(CreateKernel)));
if (custom_kernel_creators_[provider][arch] == nullptr) {
MS_LOG(ERROR) << "malloc custom kernel creator fail!provider: " << provider << ", arch: " << arch;
return RET_ERROR;
}
memset(kernel_creators_[vendor_hash][arch_hash], 0, kKernelMaxNum * sizeof(CreateKernel));
memset(custom_kernel_creators_[provider][arch], 0, data_type_length_ * sizeof(CreateKernel));
}
int data_type_index = data_type - kNumberTypeBegin - 1;
if (data_type_index < 0) {
MS_LOG(ERROR) << "invalid data_type: " << data_type << "!provider: " << provider;
return RET_ERROR;
}
custom_kernel_creators_[provider][arch][data_type_index] = creator;
return RET_OK;
}
int KernelRegistry::RegKernel(const std::string &arch, const std::string &provider, TypeId data_type, int type,
kernel::CreateKernel creator) {
std::unique_lock<std::mutex> lock(lock_);
auto iter = kernel_creators_.find(provider);
if (iter == kernel_creators_.end()) {
kernel_creators_[provider][arch] = reinterpret_cast<CreateKernel *>(malloc(kKernelMaxNum * sizeof(CreateKernel)));
if (kernel_creators_[provider][arch] == nullptr) {
MS_LOG(ERROR) << "malloc kernel creator buffer fail! provider: " << provider << ",arch:" << arch;
return RET_ERROR;
}
memset(kernel_creators_[provider][arch], 0, kKernelMaxNum * sizeof(CreateKernel));
} else {
auto iter_arch = iter->second.find(arch_hash);
auto iter_arch = iter->second.find(arch);
if (iter_arch == iter->second.end()) {
iter->second[arch_hash] = reinterpret_cast<CreateKernel *>(malloc(kKernelMaxNum * sizeof(CreateKernel)));
if (iter->second[arch_hash] == nullptr) {
MS_LOG(ERROR) << "malloc kernel creator buffer fail! vendor: " << vendor << ",arch:" << arch;
iter->second[arch] = reinterpret_cast<CreateKernel *>(malloc(kKernelMaxNum * sizeof(CreateKernel)));
if (iter->second[arch] == nullptr) {
MS_LOG(ERROR) << "malloc kernel creator buffer fail! provider: " << provider << ",arch:" << arch;
return RET_ERROR;
}
memset(iter->second[arch_hash], 0, kKernelMaxNum * sizeof(CreateKernel));
memset(iter->second[arch], 0, kKernelMaxNum * sizeof(CreateKernel));
}
}
KernelKey desc = {kCPU, data_type, type, arch, vendor};
KernelKey desc = {kCPU, data_type, type, arch, provider};
int index = GetFuncIndex(desc);
if (index >= kKernelMaxNum) {
if (index >= kKernelMaxNum || index < 0) {
MS_LOG(ERROR) << "invalid kernel key, arch " << arch << ", data_type" << data_type << ",op type " << type;
return RET_ERROR;
}
kernel_creators_[vendor_hash][arch_hash][index] = creator;
kernel_creators_[provider][arch][index] = creator;
return RET_OK;
}
int KernelRegistry::Init() { return RET_OK; }
kernel::KernelCreator KernelRegistry::GetCreator(const KernelKey &desc) {
if (desc.vendor == kBuiltin) {
if (desc.provider == kBuiltin) {
int index = GetCreatorFuncIndex(desc);
if (index >= array_size_ || index < 0) {
MS_LOG(ERROR) << "invalid kernel key, arch " << desc.arch << ", data_type" << desc.data_type << ",op type "
@ -108,29 +139,29 @@ kernel::KernelCreator KernelRegistry::GetCreator(const KernelKey &desc) {
}
return creator_arrays_[index];
}
MS_LOG(ERROR) << "Call wrong interface!vendor: " << desc.vendor;
MS_LOG(ERROR) << "Call wrong interface!provider: " << desc.provider;
return nullptr;
}
kernel::CreateKernel KernelRegistry::GetDelegateCreator(const kernel::KernelKey &desc) {
auto vendor_hash = std::hash<std::string>{}(desc.vendor);
auto it_by_vendor = kernel_creators_.find(vendor_hash);
if (it_by_vendor == kernel_creators_.end()) {
return nullptr;
}
auto arch_hash = std::hash<std::string>{}(desc.kernel_arch);
auto it_by_arch = it_by_vendor->second.find(arch_hash);
if (it_by_arch == it_by_vendor->second.end()) {
return nullptr;
}
kernel::CreateKernel creator = nullptr;
auto index = GetFuncIndex(desc);
if (index < 0 || index >= kKernelMaxNum) {
MS_LOG(ERROR) << "invalid kernel key, arch " << desc.kernel_arch << ", data_type" << desc.data_type << ",op type "
<< desc.type << ", vendor: " << desc.vendor;
if (index >= kKernelMaxNum || index < 0) {
return nullptr;
}
return it_by_arch->second[index];
std::unique_lock<std::mutex> lock(lock_);
for (auto &&item : kernel_creators_) {
for (auto &&arch_item : item.second) {
creator = arch_item.second[index];
if (creator != nullptr) {
break;
}
}
if (creator != nullptr) {
break;
}
}
return creator;
}
int KernelRegistry::GetCreatorFuncIndex(const kernel::KernelKey desc) {
@ -172,6 +203,19 @@ KernelRegistry::~KernelRegistry() {
free(instance->creator_arrays_);
instance->creator_arrays_ = nullptr;
}
for (auto &&item : kernel_creators_) {
for (auto &&creator : item.second) {
free(creator.second);
creator.second = nullptr;
}
}
for (auto &&item : custom_kernel_creators_) {
for (auto &&creator : item.second) {
free(creator.second);
creator.second = nullptr;
}
}
}
bool KernelRegistry::SupportKernel(const KernelKey &key) {
@ -184,7 +228,7 @@ int KernelRegistry::GetKernel(const std::vector<Tensor *> &in_tensors, const std
kernel::LiteKernel **kernel, const void *primitive) {
MS_ASSERT(ctx != nullptr);
MS_ASSERT(kernel != nullptr);
if (key.vendor == kBuiltin) {
if (key.provider == kBuiltin) {
auto creator = GetCreator(key);
if (creator != nullptr) {
*kernel = creator(in_tensors, out_tensors, parameter, ctx, key);
@ -196,17 +240,18 @@ int KernelRegistry::GetKernel(const std::vector<Tensor *> &in_tensors, const std
}
} else {
auto creator = GetDelegateCreator(key);
if (creator != nullptr) {
std::vector<tensor::MSTensor *> tensors_in;
Tensor2MSTensor(std::move(in_tensors), &tensors_in);
std::vector<tensor::MSTensor *> tensors_out;
Tensor2MSTensor(std::move(out_tensors), &tensors_out);
*kernel = creator(tensors_in, tensors_out, static_cast<const schema::Primitive *>(primitive), ctx);
if (*kernel != nullptr) {
return RET_OK;
}
return RET_ERROR;
if (creator == nullptr) {
return RET_NOT_SUPPORT;
}
std::vector<tensor::MSTensor *> tensors_in;
Tensor2MSTensor(std::move(in_tensors), &tensors_in);
std::vector<tensor::MSTensor *> tensors_out;
Tensor2MSTensor(std::move(out_tensors), &tensors_out);
*kernel = creator(tensors_in, tensors_out, static_cast<const schema::Primitive *>(primitive), ctx);
if (*kernel != nullptr) {
return RET_OK;
}
return RET_ERROR;
}
return RET_NOT_SUPPORT;
}

View File

@ -18,6 +18,7 @@
#define MINDSPORE_LITE_SRC_KERNEL_REGISTRY_H_
#include <string>
#include <map>
#include <unordered_map>
#include <vector>
#include <set>
@ -42,9 +43,14 @@ class KernelRegistry {
virtual kernel::CreateKernel GetDelegateCreator(const kernel::KernelKey &desc);
int GetCreatorFuncIndex(kernel::KernelKey desc);
int GetFuncIndex(const kernel::KernelKey &desc);
const std::map<std::string, std::unordered_map<std::string, kernel::CreateKernel *>> &kernel_creators() {
return kernel_creators_;
}
int RegCustomKernel(const std::string &arch, const std::string &vendor, TypeId data_type, const std::string &type,
kernel::CreateKernel creator);
void RegKernel(kernel::KernelKey desc, kernel::KernelCreator creator);
void RegKernel(kernel::KERNEL_ARCH arch, TypeId data_type, int type, kernel::KernelCreator creator);
int RegKernel(const std::string &arch, const std::string &vendor, const TypeId data_type, const int type,
int RegKernel(const std::string &arch, const std::string &vendor, TypeId data_type, int type,
kernel::CreateKernel creator);
bool Merge(const std::unordered_map<kernel::KernelKey, kernel::KernelCreator> &newCreators);
bool SupportKernel(const kernel::KernelKey &key);
@ -58,8 +64,8 @@ class KernelRegistry {
static const int op_type_length_{PrimitiveType_MAX - PrimitiveType_MIN + 1};
static const int array_size_{device_type_length_ * data_type_length_ * op_type_length_};
kernel::KernelCreator *creator_arrays_ = nullptr;
std::unordered_map<std::size_t, std::unordered_map<std::size_t, kernel::CreateKernel *>> kernel_creators_;
std::set<std::string> all_vendors_;
std::map<std::string, std::unordered_map<std::string, kernel::CreateKernel *>> kernel_creators_;
std::map<std::string, std::unordered_map<std::string, kernel::CreateKernel *>> custom_kernel_creators_;
private:
std::mutex lock_;

View File

@ -41,11 +41,11 @@ struct KernelKey {
TypeId data_type;
int type;
std::string kernel_arch;
std::string vendor{kBuiltin};
std::string provider{kBuiltin};
bool operator<(const KernelKey &dst) const {
if (vendor != dst.vendor) {
return vendor < dst.vendor;
if (provider != dst.provider) {
return provider < dst.provider;
} else if (kernel_arch != dst.kernel_arch) {
return kernel_arch < dst.kernel_arch;
} else if (arch != dst.arch) {

View File

@ -23,9 +23,14 @@ RegisterKernel *RegisterKernel::GetInstance() {
return &instance;
}
int RegisterKernel::RegKernel(const std::string &arch, const std::string &vendor, const TypeId data_type,
const int op_type, CreateKernel creator) {
return lite::KernelRegistry::GetInstance()->RegKernel(arch, vendor, data_type, op_type, creator);
int RegisterKernel::RegCustomKernel(const std::string &arch, const std::string &provider, TypeId data_type,
const std::string &type, CreateKernel creator) {
return lite::KernelRegistry::GetInstance()->RegCustomKernel(arch, provider, data_type, type, creator);
}
int RegisterKernel::RegKernel(const std::string &arch, const std::string &provider, TypeId data_type, int op_type,
CreateKernel creator) {
return lite::KernelRegistry::GetInstance()->RegKernel(arch, provider, data_type, op_type, creator);
}
} // namespace kernel
} // namespace mindspore

View File

@ -29,22 +29,30 @@ typedef kernel::LiteKernel *(*CreateKernel)(const std::vector<tensor::MSTensor *
class RegisterKernel {
public:
static RegisterKernel *GetInstance();
int RegKernel(const std::string &arch, const std::string &vendor, const TypeId data_type, const int type,
CreateKernel creator);
int RegKernel(const std::string &arch, const std::string &provider, TypeId data_type, int type, CreateKernel creator);
int RegCustomKernel(const std::string &arch, const std::string &provider, TypeId data_type, const std::string &type,
CreateKernel creator);
};
class KernelReg {
public:
~KernelReg() = default;
KernelReg(const std::string &arch, const std::string &vendor, const TypeId data_type, const int op_type,
KernelReg(const std::string &arch, const std::string &provider, TypeId data_type, int op_type, CreateKernel creator) {
RegisterKernel::GetInstance()->RegKernel(arch, provider, data_type, op_type, creator);
}
KernelReg(const std::string &arch, const std::string &provider, TypeId data_type, const std::string &op_type,
CreateKernel creator) {
RegisterKernel::GetInstance()->RegKernel(arch, vendor, data_type, op_type, creator);
RegisterKernel::GetInstance()->RegCustomKernel(arch, provider, data_type, op_type, creator);
}
};
#define REGISTER_KERNEL(arch, vendor, data_type, op_type, creator) \
static KernelReg g_##arch##vendor##data_type##op_type##kernelReg(arch, vendor, data_type, op_type, creator);
#define REGISTER_KERNEL(arch, provider, data_type, op_type, creator) \
static KernelReg g_##arch##provider##data_type##op_type##kernelReg(arch, provider, data_type, op_type, creator);
#define REGISTER_CUSTOM_KERNEL(arch, provider, data_type, op_type, creator) \
static KernelReg g_##arch##provider##data_type##op_type##kernelReg(arch, provider, data_type, op_type, creator);
} // namespace kernel
} // namespace mindspore

View File

@ -14,14 +14,43 @@
* limitations under the License.
*/
#include "src/runtime/infer_manager.h"
#include <algorithm>
#include "src/common/prim_util.h"
#include "src/common/tensor_util.h"
#include "schema/model_generated.h"
#include "include/errorcode.h"
#include "nnacl/errorcode.h"
#include "src/tensorlist.h"
#include "src/kernel_interface_registry.h"
#include "src/kernel_registry.h"
namespace mindspore {
namespace lite {
int KernelInferShape(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs,
const void *primitive) {
std::vector<tensor::MSTensor *> in_tensors;
std::copy(inputs.begin(), inputs.end(), std::back_inserter(in_tensors));
std::vector<tensor::MSTensor *> out_tensors;
std::copy(outputs.begin(), outputs.end(), std::back_inserter(out_tensors));
int op_type = GetPrimitiveType(primitive);
for (auto &&item : KernelInterfaceRegistry::Instance()->kernel_interfaces()) {
auto provider = item.first;
auto kernel_interface = KernelInterfaceRegistry::Instance()->GetKernelInterface(provider, op_type);
if (kernel_interface == nullptr) {
continue;
}
auto ret = kernel_interface->Infer(in_tensors, out_tensors, static_cast<const schema::Primitive *>(primitive));
if (ret != RET_OK) {
MS_LOG(ERROR) << "provider: " << provider << ", op_type: " << PrimitiveTypeName(GetPrimitiveType(primitive))
<< " infer fail!";
return ret;
}
return RET_OK;
}
return RET_ERROR;
}
int KernelInferShape(const std::vector<lite::Tensor *> &inputs, std::vector<lite::Tensor *> *outputs,
OpParameter *parameter) {
std::vector<TensorC *> in_tensors;

View File

@ -25,8 +25,10 @@
#include "nnacl/infer/infer.h"
namespace mindspore::lite {
int KernelInferShape(const std::vector<lite::Tensor *> &tensors_in, std::vector<lite::Tensor *> *outputs,
int KernelInferShape(const std::vector<lite::Tensor *> &inputs, std::vector<lite::Tensor *> *outputs,
OpParameter *parameter);
int KernelInferShape(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs,
const void *primitive);
class InferManager {
public:
static InferManager *GetInstance() {

View File

@ -48,6 +48,7 @@
#if defined(ENABLE_ARM) && defined(ENABLE_FP16)
#include "src/runtime/kernel/arm/fp16/fp16_op_handler.h"
#endif
#include "src/kernel_interface_registry.h"
namespace mindspore::lite {
using kernel::KERNEL_ARCH::kCPU;
@ -130,6 +131,10 @@ int Scheduler::InferNodeShape(const lite::Model::Node *node) {
std::vector<Tensor *> inputs;
std::vector<Tensor *> outputs;
FindNodeInoutTensors(*node, &inputs, &outputs);
if (KernelInterfaceRegistry::Instance()->CheckReg(node)) {
return KernelInferShape(inputs, outputs, node->primitive_);
}
int schema_version = VersionManager::GetInstance()->GetSchemaVersion();
auto parame_gen =
PopulateRegistry::GetInstance()->GetParameterCreator(GetPrimitiveType(node->primitive_), schema_version);
@ -432,6 +437,18 @@ int Scheduler::FindNpuKernel(const std::vector<Tensor *> &in_tensors, const std:
return RET_NOT_SUPPORT;
}
int Scheduler::FindProviderKernel(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors,
const Model::Node *node, TypeId data_type, kernel::LiteKernel **kernel) {
int ret = RET_ERROR;
if (KernelRegistry::GetInstance()->kernel_creators().size() != 0 &&
VersionManager::GetInstance()->GetSchemaVersion() != SCHEMA_V0) {
kernel::KernelKey desc{kCPU, data_type, GetPrimitiveType(node->primitive_), "", ""};
ret = KernelRegistry::GetInstance()->GetKernel(in_tensors, out_tensors, context_, desc, nullptr, kernel,
node->primitive_);
}
return ret;
}
kernel::LiteKernel *Scheduler::FindBackendKernel(const std::vector<Tensor *> &in_tensors,
const std::vector<Tensor *> &out_tensors, const Model::Node *node,
TypeId prefer_data_type) {
@ -439,14 +456,18 @@ kernel::LiteKernel *Scheduler::FindBackendKernel(const std::vector<Tensor *> &in
// why we need this
TypeId data_type =
(node->quant_type_ == schema::QuantType_QUANT_WEIGHT) ? kNumberTypeFloat32 : GetFirstFp32Fp16OrInt8Type(in_tensors);
kernel::LiteKernel *kernel = nullptr;
int status;
status = FindProviderKernel(in_tensors, out_tensors, node, data_type, &kernel);
if (status == RET_OK && kernel != nullptr) {
return kernel;
}
OpParameter *op_parameter = op_parameters_[node->output_indices_.at(0)];
if (op_parameter == nullptr) {
MS_LOG(ERROR) << "Can not find OpParameter!type: " << PrimitiveTypeName(GetPrimitiveType(node->primitive_));
return nullptr;
}
kernel::KernelKey desc{kCPU, data_type, static_cast<schema::PrimitiveType>(op_parameter->type_)};
kernel::LiteKernel *kernel = nullptr;
int status;
#ifdef SUPPORT_GPU
// if (node->device_type_ == DT_GPU || node->device_type_ == DEFAULT) {
status = FindGpuKernel(in_tensors, out_tensors, op_parameter, desc, &kernel);

View File

@ -60,6 +60,7 @@ class Scheduler {
kernel::LiteKernel *FindBackendKernel(const std::vector<Tensor *> &in_tensors,
const std::vector<Tensor *> &out_tensors, const Model::Node *node,
TypeId prefer_data_type = kTypeUnknown);
int FindCpuKernel(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors,
OpParameter *op_parameter, const kernel::KernelKey &desc, TypeId kernel_data_type,
kernel::LiteKernel **kernel);
@ -67,6 +68,9 @@ class Scheduler {
OpParameter *op_parameter, const kernel::KernelKey &desc, kernel::LiteKernel **kernel);
int FindNpuKernel(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors,
OpParameter *op_parameter, const kernel::KernelKey &desc, kernel::LiteKernel **kernel);
int FindProviderKernel(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors,
const Model::Node *node, TypeId data_type, kernel::LiteKernel **kernel);
// schedule a partial node to a subgraph_kernel
kernel::LiteKernel *SchedulePartialToKernel(const lite::Model::Node *src_node);
// schedule a node to a kernel

View File

@ -141,6 +141,8 @@ set(TEST_LITE_SRC
${LITE_DIR}/src/tensorlist.cc
${LITE_DIR}/src/executor.cc
${LITE_DIR}/src/inner_context.cc
${LITE_DIR}/src/kernel_interface.cc
${LITE_DIR}/src/kernel_interface_registry.cc
${LITE_DIR}/src/kernel_registry.cc
${LITE_DIR}/src/register_kernel.cc
${LITE_DIR}/src/lite_kernel.cc

View File

@ -123,6 +123,7 @@ set(LITE_SRC
${SRC_DIR}/tensor.cc
${SRC_DIR}/ms_tensor.cc
${SRC_DIR}/tensorlist.cc
${SRC_DIR}/kernel_interface_registry.cc
${SRC_DIR}/kernel_registry.cc
${SRC_DIR}/register_kernel.cc
${SRC_DIR}/lite_kernel.cc