!16495 [MS][LITE][DEVELOP]optimize custom kernel with shared_ptr

From: @jpc_chenjianping
Reviewed-by: @zhanghaibo5,@hangangqiang
Signed-off-by: @hangangqiang
This commit is contained in:
mindspore-ci-bot 2021-05-19 17:13:39 +08:00 committed by Gitee
commit 612f6e0ffd
9 changed files with 101 additions and 85 deletions

View File

@ -158,13 +158,11 @@ int KernelRegistry::GetKernel(const std::vector<Tensor *> &in_tensors, const std
std::vector<tensor::MSTensor *> tensors_out(out_tensors.begin(), out_tensors.end());
auto base_kernel = creator(tensors_in, tensors_out, static_cast<const schema::Primitive *>(primitive), ctx);
if (base_kernel != nullptr) {
auto *lite_kernel = new (std::nothrow) kernel::LiteKernel(base_kernel);
auto *lite_kernel = new (std::nothrow) kernel::LiteKernel(base_kernel.get());
if (lite_kernel != nullptr) {
lite_kernel->set_desc(key);
*kernel = lite_kernel;
return RET_OK;
} else {
delete base_kernel;
}
}
return RET_ERROR;

View File

@ -38,8 +38,8 @@ bool RegisterKernelInterface::CheckReg(const lite::Model::Node *node, std::set<s
return lite::KernelInterfaceRegistry::Instance()->CheckReg(node, std::forward<std::set<std::string>>(providers));
}
KernelInterface *RegisterKernelInterface::GetKernelInterface(const std::string &provider,
const schema::Primitive *primitive) {
std::shared_ptr<kernel::KernelInterface> RegisterKernelInterface::GetKernelInterface(
const std::string &provider, const schema::Primitive *primitive) {
return lite::KernelInterfaceRegistry::Instance()->GetKernelInterface(provider, primitive);
}
} // namespace kernel

View File

@ -14,12 +14,13 @@
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_KERNEL_DEV_DELEGATE_H_
#define MINDSPORE_LITE_SRC_KERNEL_DEV_DELEGATE_H_
#ifndef MINDSPORE_LITE_SRC_REGISTRY_KERNEL_INTERFACE_H_
#define MINDSPORE_LITE_SRC_REGISTRY_KERNEL_INTERFACE_H_
#include <set>
#include <string>
#include <vector>
#include <memory>
#include "include/model.h"
#include "include/ms_tensor.h"
#include "schema/model_generated.h"
@ -44,7 +45,7 @@ class MS_API KernelInterface {
return 0;
}
};
typedef KernelInterface *(*KernelInterfaceCreator)();
MS_API typedef std::shared_ptr<KernelInterface> (*KernelInterfaceCreator)();
class MS_API RegisterKernelInterface {
public:
@ -52,7 +53,8 @@ class MS_API RegisterKernelInterface {
int CustomReg(const std::string &provider, const std::string &op_type, KernelInterfaceCreator creator);
int Reg(const std::string &provider, int op_type, KernelInterfaceCreator creator);
bool CheckReg(const lite::Model::Node *node, std::set<std::string> &&providers);
KernelInterface *GetKernelInterface(const std::string &provider, const schema::Primitive *primitive);
std::shared_ptr<kernel::KernelInterface> GetKernelInterface(const std::string &provider,
const schema::Primitive *primitive);
virtual ~RegisterKernelInterface() = default;
private:
@ -72,11 +74,11 @@ class MS_API KernelInterfaceReg {
};
#define REGISTER_KERNEL_INTERFACE(provider, op_type, creator) \
static KernelInterfaceReg g_##provider##op_type##_inter_reg(provider, op_type, creator);
static KernelInterfaceReg g_##provider##op_type##_inter_reg(#provider, op_type, creator);
#define REGISTER_CUSTOM_KERNEL_INTERFACE(provider, op_type, creator) \
static KernelInterfaceReg g_##provider##op_type##_custom_inter_reg(provider, op_type, creator);
static KernelInterfaceReg g_##provider##op_type##_custom_inter_reg(#provider, #op_type, creator);
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_LITE_SRC_KERNEL_DEV_DELEGATE_H_
#endif // MINDSPORE_LITE_SRC_REGISTRY_KERNEL_INTERFACE_H_

View File

@ -14,6 +14,7 @@
* limitations under the License.
*/
#include "src/registry/kernel_interface_registry.h"
#include <memory>
#include "src/registry/kernel_interface.h"
#include "include/errorcode.h"
#include "src/common/log_adapter.h"
@ -46,10 +47,7 @@ bool KernelInterfaceRegistry::CheckReg(const lite::Model::Node *node, std::set<s
auto op_type = primitive->value_type();
if (op_type == schema::PrimitiveType_Custom) {
auto &&custom_type = GetCustomType(primitive);
return std::any_of(custom_creators_.begin(), custom_creators_.end(), [&custom_type, &providers](auto &&item) {
if (providers.find(item.first) == providers.end()) {
return false;
}
return std::any_of(custom_creators_.begin(), custom_creators_.end(), [&custom_type](auto &&item) {
if (item.second[custom_type] != nullptr) {
return true;
}
@ -57,13 +55,17 @@ bool KernelInterfaceRegistry::CheckReg(const lite::Model::Node *node, std::set<s
});
}
return std::any_of(kernel_creators_.begin(), kernel_creators_.end(), [op_type, &mutex = this->mutex_](auto &&item) {
std::unique_lock<std::mutex> lock(mutex);
if (item.second[op_type] != nullptr) {
return true;
}
return false;
});
return std::any_of(kernel_creators_.begin(), kernel_creators_.end(),
[op_type, &providers, &mutex = this->mutex_](auto &&item) {
std::unique_lock<std::mutex> lock(mutex);
if (providers.find(item.first) == providers.end()) {
return false;
}
if (item.second[op_type] != nullptr) {
return true;
}
return false;
});
}
int KernelInterfaceRegistry::CustomReg(const std::string &provider, const std::string &type,
@ -72,7 +74,8 @@ int KernelInterfaceRegistry::CustomReg(const std::string &provider, const std::s
return RET_OK;
}
kernel::KernelInterface *KernelInterfaceRegistry::GetCacheInterface(const std::string &provider, int op_type) {
std::shared_ptr<kernel::KernelInterface> KernelInterfaceRegistry::GetCacheInterface(const std::string &provider,
int op_type) {
auto provider_iter = kernel_interfaces_.find(provider);
if (provider_iter != kernel_interfaces_.end()) {
auto kernel_iter = provider_iter->second.find(op_type);
@ -83,8 +86,8 @@ kernel::KernelInterface *KernelInterfaceRegistry::GetCacheInterface(const std::s
return nullptr;
}
kernel::KernelInterface *KernelInterfaceRegistry::GetCustomCacheInterface(const std::string &provider,
const std::string &type) {
std::shared_ptr<kernel::KernelInterface> KernelInterfaceRegistry::GetCustomCacheInterface(const std::string &provider,
const std::string &type) {
auto provider_iter = custom_kernels_.find(provider);
if (provider_iter == custom_kernels_.end()) {
return nullptr;
@ -96,13 +99,13 @@ kernel::KernelInterface *KernelInterfaceRegistry::GetCustomCacheInterface(const
return nullptr;
}
kernel::KernelInterface *KernelInterfaceRegistry::GetKernelInterface(const std::string &provider,
const schema::Primitive *primitive) {
std::shared_ptr<kernel::KernelInterface> KernelInterfaceRegistry::GetCustomKernelInterface(
const schema::Primitive *primitive) {
MS_ASSERT(primitive != nullptr);
int op_type = primitive->value_type();
std::unique_lock<std::mutex> lock(mutex_);
if (op_type == schema::PrimitiveType_Custom) {
auto &&type = GetCustomType(primitive);
auto &&type = GetCustomType(primitive);
for (auto &&item : custom_creators_) {
auto &&provider = item.first;
auto kernel = GetCustomCacheInterface(provider, type);
if (kernel != nullptr) {
return kernel;
@ -117,8 +120,19 @@ kernel::KernelInterface *KernelInterfaceRegistry::GetKernelInterface(const std::
custom_kernels_[provider][type] = kernel;
return kernel;
}
return nullptr;
}
return nullptr;
}
std::shared_ptr<kernel::KernelInterface> KernelInterfaceRegistry::GetKernelInterface(
const std::string &provider, const schema::Primitive *primitive) {
MS_ASSERT(primitive != nullptr);
int op_type = primitive->value_type();
if (op_type == schema::PrimitiveType_Custom) {
return GetCustomKernelInterface(primitive);
}
std::unique_lock<std::mutex> lock(mutex_);
auto kernel = GetCacheInterface(provider, op_type);
if (kernel != nullptr) {
return kernel;
@ -164,18 +178,6 @@ KernelInterfaceRegistry::~KernelInterfaceRegistry() {
free(item.second);
item.second = nullptr;
}
for (auto &&i : kernel_interfaces_) {
for (auto &&j : i.second) {
delete (j.second);
j.second = nullptr;
}
}
for (auto &&i : custom_kernels_) {
for (auto &&j : i.second) {
delete (j.second);
j.second = nullptr;
}
}
}
} // namespace lite
} // namespace mindspore

View File

@ -14,11 +14,12 @@
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_KERNEL_DEV_DELEGATE_REGISTRY_H_
#define MINDSPORE_LITE_SRC_KERNEL_DEV_DELEGATE_REGISTRY_H_
#ifndef MINDSPORE_LITE_SRC_REGISTRY_KERNEL_INTERFACE_REGISTRY_H_
#define MINDSPORE_LITE_SRC_REGISTRY_KERNEL_INTERFACE_REGISTRY_H_
#include <string>
#include <map>
#include <memory>
#include <mutex>
#include <set>
#include "src/registry/kernel_interface.h"
@ -33,25 +34,28 @@ class KernelInterfaceRegistry {
return &instance;
}
bool CheckReg(const lite::Model::Node *node, std::set<std::string> &&providers);
kernel::KernelInterface *GetKernelInterface(const std::string &provider, const schema::Primitive *primitive);
std::shared_ptr<kernel::KernelInterface> GetKernelInterface(const std::string &provider,
const schema::Primitive *primitive);
int CustomReg(const std::string &provider, const std::string &op_type, kernel::KernelInterfaceCreator creator);
int Reg(const std::string &provider, int op_type, kernel::KernelInterfaceCreator creator);
virtual ~KernelInterfaceRegistry();
private:
KernelInterfaceRegistry() = default;
kernel::KernelInterface *GetCacheInterface(const std::string &provider, int op_type);
kernel::KernelInterface *GetCustomCacheInterface(const std::string &provider, const std::string &type);
std::shared_ptr<kernel::KernelInterface> GetCacheInterface(const std::string &provider, int op_type);
std::shared_ptr<kernel::KernelInterface> GetCustomCacheInterface(const std::string &provider,
const std::string &type);
std::shared_ptr<kernel::KernelInterface> GetCustomKernelInterface(const schema::Primitive *primitive);
std::mutex mutex_;
// key: provider
std::map<std::string, kernel::KernelInterfaceCreator *> kernel_creators_;
std::map<std::string, std::map<int, kernel::KernelInterface *>> kernel_interfaces_;
std::map<std::string, std::map<int, std::shared_ptr<kernel::KernelInterface>>> kernel_interfaces_;
// key: provider key: custom type
std::map<std::string, std::map<std::string, kernel::KernelInterfaceCreator>> custom_creators_;
std::map<std::string, std::map<std::string, kernel::KernelInterface *>> custom_kernels_;
std::map<std::string, std::map<std::string, std::shared_ptr<kernel::KernelInterface>>> custom_kernels_;
};
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_SRC_KERNEL_DEV_DELEGATE_REGISTRY_H_
#endif // MINDSPORE_LITE_SRC_REGISTRY_KERNEL_INTERFACE_REGISTRY_H_

View File

@ -20,6 +20,7 @@
#include <set>
#include <string>
#include <vector>
#include <memory>
#include "schema/model_generated.h"
#include "include/context.h"
#include "include/ms_tensor.h"
@ -46,9 +47,9 @@ struct MS_API KernelDesc {
}
};
typedef kernel::Kernel *(*CreateKernel)(const std::vector<tensor::MSTensor *> &inputs,
const std::vector<tensor::MSTensor *> &outputs,
const schema::Primitive *primitive, const lite::Context *ctx);
typedef std::shared_ptr<kernel::Kernel> (*CreateKernel)(const std::vector<tensor::MSTensor *> &inputs,
const std::vector<tensor::MSTensor *> &outputs,
const schema::Primitive *primitive, const lite::Context *ctx);
class MS_API RegisterKernel {
public:
static RegisterKernel *GetInstance();
@ -73,10 +74,10 @@ class MS_API KernelReg {
};
#define REGISTER_KERNEL(arch, provider, data_type, op_type, creator) \
static KernelReg g_##arch##provider##data_type##op_type##kernelReg(arch, provider, data_type, op_type, creator);
static KernelReg g_##arch##provider##data_type##op_type##kernelReg(#arch, #provider, data_type, op_type, creator);
#define REGISTER_CUSTOM_KERNEL(arch, provider, data_type, op_type, creator) \
static KernelReg g_##arch##provider##data_type##op_type##kernelReg(arch, provider, data_type, op_type, creator);
static KernelReg g_##arch##provider##data_type##op_type##kernelReg(#arch, #provider, data_type, #op_type, creator);
} // namespace kernel
} // namespace mindspore

View File

@ -98,9 +98,8 @@ int RegistryKernelImpl::RegKernel(const std::string &arch, const std::string &pr
return RET_OK;
}
kernel::CreateKernel RegistryKernelImpl::GetProviderCreator(const KernelDesc &desc,
kernel::CreateKernel RegistryKernelImpl::GetProviderCreator(const kernel::KernelDesc &desc,
const schema::Primitive *primitive) {
MS_ASSERT(primitive != nullptr);
kernel::CreateKernel creator = nullptr;
std::unique_lock<std::mutex> lock(lock_);
if (desc.type == schema::PrimitiveType_Custom) {
@ -111,19 +110,14 @@ kernel::CreateKernel RegistryKernelImpl::GetProviderCreator(const KernelDesc &de
auto param = primitive->value_as_Custom();
MS_ASSERT(param != nullptr);
auto custom_type = param->type()->str();
auto archs = custom_kernel_creators_[desc.provider];
if (desc.arch.empty()) {
for (auto &&providers : custom_kernel_creators_) {
auto archs = providers.second;
auto archs_iter = std::find_if(archs.begin(), archs.end(), [custom_type, data_type_index](auto &&item) {
return item.second[custom_type] != nullptr && item.second[custom_type][data_type_index] != nullptr;
});
if (archs_iter != archs.end()) {
return archs_iter->second[custom_type][data_type_index];
}
} else {
auto find_arch_it = archs.find(desc.arch);
if (find_arch_it != archs.end()) {
return find_arch_it->second[custom_type][data_type_index];
}
}
return nullptr;

View File

@ -30,25 +30,31 @@ namespace mindspore {
namespace lite {
int KernelInferShape(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs,
const void *primitive, std::set<std::string> &&providers) {
std::vector<tensor::MSTensor *> in_tensors;
std::copy(inputs.begin(), inputs.end(), std::back_inserter(in_tensors));
std::vector<tensor::MSTensor *> out_tensors;
std::copy(outputs.begin(), outputs.end(), std::back_inserter(out_tensors));
for (auto &&provider : providers) {
auto kernel_interface = kernel::RegisterKernelInterface::Instance()->GetKernelInterface(
provider, static_cast<const schema::Primitive *>(primitive));
if (kernel_interface == nullptr) {
continue;
std::vector<tensor::MSTensor *> in_tensors(inputs.begin(), inputs.end());
std::vector<tensor::MSTensor *> out_tensors(outputs.begin(), outputs.end());
auto prim_type = GetPrimitiveType(primitive);
std::shared_ptr<kernel::KernelInterface> kernel_interface = nullptr;
if (prim_type == schema::PrimitiveType_Custom) {
kernel_interface = kernel::RegisterKernelInterface::Instance()->GetKernelInterface(
"", static_cast<const schema::Primitive *>(primitive));
} else {
for (auto &&provider : providers) {
kernel_interface = kernel::RegisterKernelInterface::Instance()->GetKernelInterface(
provider, static_cast<const schema::Primitive *>(primitive));
if (kernel_interface != nullptr) {
break;
}
}
auto ret = kernel_interface->Infer(in_tensors, out_tensors, static_cast<const schema::Primitive *>(primitive));
if (ret != RET_OK) {
MS_LOG(ERROR) << "provider: " << provider << ", op_type: " << PrimitiveTypeName(GetPrimitiveType(primitive))
<< " infer fail!";
return ret;
}
return RET_OK;
}
if (kernel_interface == nullptr) {
MS_LOG(ERROR) << "Can't find kernel_interface!op_type: " << PrimitiveTypeName(prim_type);
return RET_ERROR;
}
auto ret = kernel_interface->Infer(in_tensors, out_tensors, static_cast<const schema::Primitive *>(primitive));
if (ret != RET_OK) {
MS_LOG(ERROR) << "op_type: " << PrimitiveTypeName(prim_type) << " infer fail!";
return ret;
}
return RET_ERROR;
}

View File

@ -445,6 +445,16 @@ int Scheduler::FindProviderKernel(const std::vector<Tensor *> &in_tensors, const
const Model::Node *node, TypeId data_type, kernel::LiteKernel **kernel) {
MS_ASSERT(kernel != nullptr);
int ret = RET_NOT_SUPPORT;
auto prim_type = GetPrimitiveType(node->primitive_);
if (prim_type == schema::PrimitiveType_Custom) {
kernel::KernelKey desc{kCPU, data_type, prim_type, "", ""};
ret = KernelRegistry::GetInstance()->GetKernel(in_tensors, out_tensors, context_, desc, nullptr, kernel,
node->primitive_);
if (ret == RET_OK && *kernel != nullptr) {
return ret;
}
return RET_NOT_SUPPORT;
}
if (!context_->IsProviderEnabled()) {
return ret;
}
@ -453,8 +463,7 @@ int Scheduler::FindProviderKernel(const std::vector<Tensor *> &in_tensors, const
}
for (auto &&device : context_->device_list_) {
if (!device.provider_.empty()) {
kernel::KernelKey desc{kCPU, data_type, GetPrimitiveType(node->primitive_), device.provider_device_,
device.provider_};
kernel::KernelKey desc{kCPU, data_type, prim_type, device.provider_device_, device.provider_};
ret = KernelRegistry::GetInstance()->GetKernel(in_tensors, out_tensors, context_, desc, nullptr, kernel,
node->primitive_);
if (ret == RET_OK && *kernel != nullptr) {