!16495 [MS][LITE][DEVELOP]optimize custom kernel with shared_ptr
From: @jpc_chenjianping Reviewed-by: @zhanghaibo5,@hangangqiang Signed-off-by: @hangangqiang
This commit is contained in:
commit
612f6e0ffd
|
@ -158,13 +158,11 @@ int KernelRegistry::GetKernel(const std::vector<Tensor *> &in_tensors, const std
|
|||
std::vector<tensor::MSTensor *> tensors_out(out_tensors.begin(), out_tensors.end());
|
||||
auto base_kernel = creator(tensors_in, tensors_out, static_cast<const schema::Primitive *>(primitive), ctx);
|
||||
if (base_kernel != nullptr) {
|
||||
auto *lite_kernel = new (std::nothrow) kernel::LiteKernel(base_kernel);
|
||||
auto *lite_kernel = new (std::nothrow) kernel::LiteKernel(base_kernel.get());
|
||||
if (lite_kernel != nullptr) {
|
||||
lite_kernel->set_desc(key);
|
||||
*kernel = lite_kernel;
|
||||
return RET_OK;
|
||||
} else {
|
||||
delete base_kernel;
|
||||
}
|
||||
}
|
||||
return RET_ERROR;
|
||||
|
|
|
@ -38,8 +38,8 @@ bool RegisterKernelInterface::CheckReg(const lite::Model::Node *node, std::set<s
|
|||
return lite::KernelInterfaceRegistry::Instance()->CheckReg(node, std::forward<std::set<std::string>>(providers));
|
||||
}
|
||||
|
||||
KernelInterface *RegisterKernelInterface::GetKernelInterface(const std::string &provider,
|
||||
const schema::Primitive *primitive) {
|
||||
std::shared_ptr<kernel::KernelInterface> RegisterKernelInterface::GetKernelInterface(
|
||||
const std::string &provider, const schema::Primitive *primitive) {
|
||||
return lite::KernelInterfaceRegistry::Instance()->GetKernelInterface(provider, primitive);
|
||||
}
|
||||
} // namespace kernel
|
||||
|
|
|
@ -14,12 +14,13 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_SRC_KERNEL_DEV_DELEGATE_H_
|
||||
#define MINDSPORE_LITE_SRC_KERNEL_DEV_DELEGATE_H_
|
||||
#ifndef MINDSPORE_LITE_SRC_REGISTRY_KERNEL_INTERFACE_H_
|
||||
#define MINDSPORE_LITE_SRC_REGISTRY_KERNEL_INTERFACE_H_
|
||||
|
||||
#include <set>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include "include/model.h"
|
||||
#include "include/ms_tensor.h"
|
||||
#include "schema/model_generated.h"
|
||||
|
@ -44,7 +45,7 @@ class MS_API KernelInterface {
|
|||
return 0;
|
||||
}
|
||||
};
|
||||
typedef KernelInterface *(*KernelInterfaceCreator)();
|
||||
MS_API typedef std::shared_ptr<KernelInterface> (*KernelInterfaceCreator)();
|
||||
|
||||
class MS_API RegisterKernelInterface {
|
||||
public:
|
||||
|
@ -52,7 +53,8 @@ class MS_API RegisterKernelInterface {
|
|||
int CustomReg(const std::string &provider, const std::string &op_type, KernelInterfaceCreator creator);
|
||||
int Reg(const std::string &provider, int op_type, KernelInterfaceCreator creator);
|
||||
bool CheckReg(const lite::Model::Node *node, std::set<std::string> &&providers);
|
||||
KernelInterface *GetKernelInterface(const std::string &provider, const schema::Primitive *primitive);
|
||||
std::shared_ptr<kernel::KernelInterface> GetKernelInterface(const std::string &provider,
|
||||
const schema::Primitive *primitive);
|
||||
virtual ~RegisterKernelInterface() = default;
|
||||
|
||||
private:
|
||||
|
@ -72,11 +74,11 @@ class MS_API KernelInterfaceReg {
|
|||
};
|
||||
|
||||
#define REGISTER_KERNEL_INTERFACE(provider, op_type, creator) \
|
||||
static KernelInterfaceReg g_##provider##op_type##_inter_reg(provider, op_type, creator);
|
||||
static KernelInterfaceReg g_##provider##op_type##_inter_reg(#provider, op_type, creator);
|
||||
|
||||
#define REGISTER_CUSTOM_KERNEL_INTERFACE(provider, op_type, creator) \
|
||||
static KernelInterfaceReg g_##provider##op_type##_custom_inter_reg(provider, op_type, creator);
|
||||
static KernelInterfaceReg g_##provider##op_type##_custom_inter_reg(#provider, #op_type, creator);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_LITE_SRC_KERNEL_DEV_DELEGATE_H_
|
||||
#endif // MINDSPORE_LITE_SRC_REGISTRY_KERNEL_INTERFACE_H_
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
#include "src/registry/kernel_interface_registry.h"
|
||||
#include <memory>
|
||||
#include "src/registry/kernel_interface.h"
|
||||
#include "include/errorcode.h"
|
||||
#include "src/common/log_adapter.h"
|
||||
|
@ -46,10 +47,7 @@ bool KernelInterfaceRegistry::CheckReg(const lite::Model::Node *node, std::set<s
|
|||
auto op_type = primitive->value_type();
|
||||
if (op_type == schema::PrimitiveType_Custom) {
|
||||
auto &&custom_type = GetCustomType(primitive);
|
||||
return std::any_of(custom_creators_.begin(), custom_creators_.end(), [&custom_type, &providers](auto &&item) {
|
||||
if (providers.find(item.first) == providers.end()) {
|
||||
return false;
|
||||
}
|
||||
return std::any_of(custom_creators_.begin(), custom_creators_.end(), [&custom_type](auto &&item) {
|
||||
if (item.second[custom_type] != nullptr) {
|
||||
return true;
|
||||
}
|
||||
|
@ -57,13 +55,17 @@ bool KernelInterfaceRegistry::CheckReg(const lite::Model::Node *node, std::set<s
|
|||
});
|
||||
}
|
||||
|
||||
return std::any_of(kernel_creators_.begin(), kernel_creators_.end(), [op_type, &mutex = this->mutex_](auto &&item) {
|
||||
std::unique_lock<std::mutex> lock(mutex);
|
||||
if (item.second[op_type] != nullptr) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
});
|
||||
return std::any_of(kernel_creators_.begin(), kernel_creators_.end(),
|
||||
[op_type, &providers, &mutex = this->mutex_](auto &&item) {
|
||||
std::unique_lock<std::mutex> lock(mutex);
|
||||
if (providers.find(item.first) == providers.end()) {
|
||||
return false;
|
||||
}
|
||||
if (item.second[op_type] != nullptr) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
});
|
||||
}
|
||||
|
||||
int KernelInterfaceRegistry::CustomReg(const std::string &provider, const std::string &type,
|
||||
|
@ -72,7 +74,8 @@ int KernelInterfaceRegistry::CustomReg(const std::string &provider, const std::s
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
kernel::KernelInterface *KernelInterfaceRegistry::GetCacheInterface(const std::string &provider, int op_type) {
|
||||
std::shared_ptr<kernel::KernelInterface> KernelInterfaceRegistry::GetCacheInterface(const std::string &provider,
|
||||
int op_type) {
|
||||
auto provider_iter = kernel_interfaces_.find(provider);
|
||||
if (provider_iter != kernel_interfaces_.end()) {
|
||||
auto kernel_iter = provider_iter->second.find(op_type);
|
||||
|
@ -83,8 +86,8 @@ kernel::KernelInterface *KernelInterfaceRegistry::GetCacheInterface(const std::s
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
kernel::KernelInterface *KernelInterfaceRegistry::GetCustomCacheInterface(const std::string &provider,
|
||||
const std::string &type) {
|
||||
std::shared_ptr<kernel::KernelInterface> KernelInterfaceRegistry::GetCustomCacheInterface(const std::string &provider,
|
||||
const std::string &type) {
|
||||
auto provider_iter = custom_kernels_.find(provider);
|
||||
if (provider_iter == custom_kernels_.end()) {
|
||||
return nullptr;
|
||||
|
@ -96,13 +99,13 @@ kernel::KernelInterface *KernelInterfaceRegistry::GetCustomCacheInterface(const
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
kernel::KernelInterface *KernelInterfaceRegistry::GetKernelInterface(const std::string &provider,
|
||||
const schema::Primitive *primitive) {
|
||||
std::shared_ptr<kernel::KernelInterface> KernelInterfaceRegistry::GetCustomKernelInterface(
|
||||
const schema::Primitive *primitive) {
|
||||
MS_ASSERT(primitive != nullptr);
|
||||
int op_type = primitive->value_type();
|
||||
std::unique_lock<std::mutex> lock(mutex_);
|
||||
if (op_type == schema::PrimitiveType_Custom) {
|
||||
auto &&type = GetCustomType(primitive);
|
||||
auto &&type = GetCustomType(primitive);
|
||||
for (auto &&item : custom_creators_) {
|
||||
auto &&provider = item.first;
|
||||
auto kernel = GetCustomCacheInterface(provider, type);
|
||||
if (kernel != nullptr) {
|
||||
return kernel;
|
||||
|
@ -117,8 +120,19 @@ kernel::KernelInterface *KernelInterfaceRegistry::GetKernelInterface(const std::
|
|||
custom_kernels_[provider][type] = kernel;
|
||||
return kernel;
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
std::shared_ptr<kernel::KernelInterface> KernelInterfaceRegistry::GetKernelInterface(
|
||||
const std::string &provider, const schema::Primitive *primitive) {
|
||||
MS_ASSERT(primitive != nullptr);
|
||||
int op_type = primitive->value_type();
|
||||
if (op_type == schema::PrimitiveType_Custom) {
|
||||
return GetCustomKernelInterface(primitive);
|
||||
}
|
||||
std::unique_lock<std::mutex> lock(mutex_);
|
||||
auto kernel = GetCacheInterface(provider, op_type);
|
||||
if (kernel != nullptr) {
|
||||
return kernel;
|
||||
|
@ -164,18 +178,6 @@ KernelInterfaceRegistry::~KernelInterfaceRegistry() {
|
|||
free(item.second);
|
||||
item.second = nullptr;
|
||||
}
|
||||
for (auto &&i : kernel_interfaces_) {
|
||||
for (auto &&j : i.second) {
|
||||
delete (j.second);
|
||||
j.second = nullptr;
|
||||
}
|
||||
}
|
||||
for (auto &&i : custom_kernels_) {
|
||||
for (auto &&j : i.second) {
|
||||
delete (j.second);
|
||||
j.second = nullptr;
|
||||
}
|
||||
}
|
||||
}
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -14,11 +14,12 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_SRC_KERNEL_DEV_DELEGATE_REGISTRY_H_
|
||||
#define MINDSPORE_LITE_SRC_KERNEL_DEV_DELEGATE_REGISTRY_H_
|
||||
#ifndef MINDSPORE_LITE_SRC_REGISTRY_KERNEL_INTERFACE_REGISTRY_H_
|
||||
#define MINDSPORE_LITE_SRC_REGISTRY_KERNEL_INTERFACE_REGISTRY_H_
|
||||
|
||||
#include <string>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <mutex>
|
||||
#include <set>
|
||||
#include "src/registry/kernel_interface.h"
|
||||
|
@ -33,25 +34,28 @@ class KernelInterfaceRegistry {
|
|||
return &instance;
|
||||
}
|
||||
bool CheckReg(const lite::Model::Node *node, std::set<std::string> &&providers);
|
||||
kernel::KernelInterface *GetKernelInterface(const std::string &provider, const schema::Primitive *primitive);
|
||||
std::shared_ptr<kernel::KernelInterface> GetKernelInterface(const std::string &provider,
|
||||
const schema::Primitive *primitive);
|
||||
int CustomReg(const std::string &provider, const std::string &op_type, kernel::KernelInterfaceCreator creator);
|
||||
int Reg(const std::string &provider, int op_type, kernel::KernelInterfaceCreator creator);
|
||||
virtual ~KernelInterfaceRegistry();
|
||||
|
||||
private:
|
||||
KernelInterfaceRegistry() = default;
|
||||
kernel::KernelInterface *GetCacheInterface(const std::string &provider, int op_type);
|
||||
kernel::KernelInterface *GetCustomCacheInterface(const std::string &provider, const std::string &type);
|
||||
std::shared_ptr<kernel::KernelInterface> GetCacheInterface(const std::string &provider, int op_type);
|
||||
std::shared_ptr<kernel::KernelInterface> GetCustomCacheInterface(const std::string &provider,
|
||||
const std::string &type);
|
||||
std::shared_ptr<kernel::KernelInterface> GetCustomKernelInterface(const schema::Primitive *primitive);
|
||||
|
||||
std::mutex mutex_;
|
||||
// key: provider
|
||||
std::map<std::string, kernel::KernelInterfaceCreator *> kernel_creators_;
|
||||
std::map<std::string, std::map<int, kernel::KernelInterface *>> kernel_interfaces_;
|
||||
std::map<std::string, std::map<int, std::shared_ptr<kernel::KernelInterface>>> kernel_interfaces_;
|
||||
// key: provider key: custom type
|
||||
std::map<std::string, std::map<std::string, kernel::KernelInterfaceCreator>> custom_creators_;
|
||||
std::map<std::string, std::map<std::string, kernel::KernelInterface *>> custom_kernels_;
|
||||
std::map<std::string, std::map<std::string, std::shared_ptr<kernel::KernelInterface>>> custom_kernels_;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_LITE_SRC_KERNEL_DEV_DELEGATE_REGISTRY_H_
|
||||
#endif // MINDSPORE_LITE_SRC_REGISTRY_KERNEL_INTERFACE_REGISTRY_H_
|
||||
|
|
|
@ -20,6 +20,7 @@
|
|||
#include <set>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include "schema/model_generated.h"
|
||||
#include "include/context.h"
|
||||
#include "include/ms_tensor.h"
|
||||
|
@ -46,9 +47,9 @@ struct MS_API KernelDesc {
|
|||
}
|
||||
};
|
||||
|
||||
typedef kernel::Kernel *(*CreateKernel)(const std::vector<tensor::MSTensor *> &inputs,
|
||||
const std::vector<tensor::MSTensor *> &outputs,
|
||||
const schema::Primitive *primitive, const lite::Context *ctx);
|
||||
typedef std::shared_ptr<kernel::Kernel> (*CreateKernel)(const std::vector<tensor::MSTensor *> &inputs,
|
||||
const std::vector<tensor::MSTensor *> &outputs,
|
||||
const schema::Primitive *primitive, const lite::Context *ctx);
|
||||
class MS_API RegisterKernel {
|
||||
public:
|
||||
static RegisterKernel *GetInstance();
|
||||
|
@ -73,10 +74,10 @@ class MS_API KernelReg {
|
|||
};
|
||||
|
||||
#define REGISTER_KERNEL(arch, provider, data_type, op_type, creator) \
|
||||
static KernelReg g_##arch##provider##data_type##op_type##kernelReg(arch, provider, data_type, op_type, creator);
|
||||
static KernelReg g_##arch##provider##data_type##op_type##kernelReg(#arch, #provider, data_type, op_type, creator);
|
||||
|
||||
#define REGISTER_CUSTOM_KERNEL(arch, provider, data_type, op_type, creator) \
|
||||
static KernelReg g_##arch##provider##data_type##op_type##kernelReg(arch, provider, data_type, op_type, creator);
|
||||
static KernelReg g_##arch##provider##data_type##op_type##kernelReg(#arch, #provider, data_type, #op_type, creator);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -98,9 +98,8 @@ int RegistryKernelImpl::RegKernel(const std::string &arch, const std::string &pr
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
kernel::CreateKernel RegistryKernelImpl::GetProviderCreator(const KernelDesc &desc,
|
||||
kernel::CreateKernel RegistryKernelImpl::GetProviderCreator(const kernel::KernelDesc &desc,
|
||||
const schema::Primitive *primitive) {
|
||||
MS_ASSERT(primitive != nullptr);
|
||||
kernel::CreateKernel creator = nullptr;
|
||||
std::unique_lock<std::mutex> lock(lock_);
|
||||
if (desc.type == schema::PrimitiveType_Custom) {
|
||||
|
@ -111,19 +110,14 @@ kernel::CreateKernel RegistryKernelImpl::GetProviderCreator(const KernelDesc &de
|
|||
auto param = primitive->value_as_Custom();
|
||||
MS_ASSERT(param != nullptr);
|
||||
auto custom_type = param->type()->str();
|
||||
auto archs = custom_kernel_creators_[desc.provider];
|
||||
if (desc.arch.empty()) {
|
||||
for (auto &&providers : custom_kernel_creators_) {
|
||||
auto archs = providers.second;
|
||||
auto archs_iter = std::find_if(archs.begin(), archs.end(), [custom_type, data_type_index](auto &&item) {
|
||||
return item.second[custom_type] != nullptr && item.second[custom_type][data_type_index] != nullptr;
|
||||
});
|
||||
if (archs_iter != archs.end()) {
|
||||
return archs_iter->second[custom_type][data_type_index];
|
||||
}
|
||||
} else {
|
||||
auto find_arch_it = archs.find(desc.arch);
|
||||
if (find_arch_it != archs.end()) {
|
||||
return find_arch_it->second[custom_type][data_type_index];
|
||||
}
|
||||
}
|
||||
|
||||
return nullptr;
|
||||
|
|
|
@ -30,25 +30,31 @@ namespace mindspore {
|
|||
namespace lite {
|
||||
int KernelInferShape(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs,
|
||||
const void *primitive, std::set<std::string> &&providers) {
|
||||
std::vector<tensor::MSTensor *> in_tensors;
|
||||
std::copy(inputs.begin(), inputs.end(), std::back_inserter(in_tensors));
|
||||
std::vector<tensor::MSTensor *> out_tensors;
|
||||
std::copy(outputs.begin(), outputs.end(), std::back_inserter(out_tensors));
|
||||
for (auto &&provider : providers) {
|
||||
auto kernel_interface = kernel::RegisterKernelInterface::Instance()->GetKernelInterface(
|
||||
provider, static_cast<const schema::Primitive *>(primitive));
|
||||
if (kernel_interface == nullptr) {
|
||||
continue;
|
||||
std::vector<tensor::MSTensor *> in_tensors(inputs.begin(), inputs.end());
|
||||
std::vector<tensor::MSTensor *> out_tensors(outputs.begin(), outputs.end());
|
||||
auto prim_type = GetPrimitiveType(primitive);
|
||||
std::shared_ptr<kernel::KernelInterface> kernel_interface = nullptr;
|
||||
if (prim_type == schema::PrimitiveType_Custom) {
|
||||
kernel_interface = kernel::RegisterKernelInterface::Instance()->GetKernelInterface(
|
||||
"", static_cast<const schema::Primitive *>(primitive));
|
||||
} else {
|
||||
for (auto &&provider : providers) {
|
||||
kernel_interface = kernel::RegisterKernelInterface::Instance()->GetKernelInterface(
|
||||
provider, static_cast<const schema::Primitive *>(primitive));
|
||||
if (kernel_interface != nullptr) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
auto ret = kernel_interface->Infer(in_tensors, out_tensors, static_cast<const schema::Primitive *>(primitive));
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "provider: " << provider << ", op_type: " << PrimitiveTypeName(GetPrimitiveType(primitive))
|
||||
<< " infer fail!";
|
||||
return ret;
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
if (kernel_interface == nullptr) {
|
||||
MS_LOG(ERROR) << "Can't find kernel_interface!op_type: " << PrimitiveTypeName(prim_type);
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto ret = kernel_interface->Infer(in_tensors, out_tensors, static_cast<const schema::Primitive *>(primitive));
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "op_type: " << PrimitiveTypeName(prim_type) << " infer fail!";
|
||||
return ret;
|
||||
}
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
|
|
|
@ -445,6 +445,16 @@ int Scheduler::FindProviderKernel(const std::vector<Tensor *> &in_tensors, const
|
|||
const Model::Node *node, TypeId data_type, kernel::LiteKernel **kernel) {
|
||||
MS_ASSERT(kernel != nullptr);
|
||||
int ret = RET_NOT_SUPPORT;
|
||||
auto prim_type = GetPrimitiveType(node->primitive_);
|
||||
if (prim_type == schema::PrimitiveType_Custom) {
|
||||
kernel::KernelKey desc{kCPU, data_type, prim_type, "", ""};
|
||||
ret = KernelRegistry::GetInstance()->GetKernel(in_tensors, out_tensors, context_, desc, nullptr, kernel,
|
||||
node->primitive_);
|
||||
if (ret == RET_OK && *kernel != nullptr) {
|
||||
return ret;
|
||||
}
|
||||
return RET_NOT_SUPPORT;
|
||||
}
|
||||
if (!context_->IsProviderEnabled()) {
|
||||
return ret;
|
||||
}
|
||||
|
@ -453,8 +463,7 @@ int Scheduler::FindProviderKernel(const std::vector<Tensor *> &in_tensors, const
|
|||
}
|
||||
for (auto &&device : context_->device_list_) {
|
||||
if (!device.provider_.empty()) {
|
||||
kernel::KernelKey desc{kCPU, data_type, GetPrimitiveType(node->primitive_), device.provider_device_,
|
||||
device.provider_};
|
||||
kernel::KernelKey desc{kCPU, data_type, prim_type, device.provider_device_, device.provider_};
|
||||
ret = KernelRegistry::GetInstance()->GetKernel(in_tensors, out_tensors, context_, desc, nullptr, kernel,
|
||||
node->primitive_);
|
||||
if (ret == RET_OK && *kernel != nullptr) {
|
||||
|
|
Loading…
Reference in New Issue