change namespace mindspore::registry to mindspore::registry

This commit is contained in:
xuanyue 2021-08-13 16:05:13 +08:00
parent 62a780f20f
commit df3db5196b
27 changed files with 181 additions and 211 deletions

View File

@ -28,17 +28,19 @@ class CustomAddInfer : public kernel::KernelInterface {
CustomAddInfer() = default;
~CustomAddInfer() = default;
int Infer(std::vector<mindspore::MSTensor> *inputs, std::vector<mindspore::MSTensor> *outputs,
const schema::Primitive *primitive) override {
Status Infer(std::vector<mindspore::MSTensor> *inputs, std::vector<mindspore::MSTensor> *outputs,
const schema::Primitive *primitive) override {
(*outputs)[0].SetFormat((*inputs)[0].format());
(*outputs)[0].SetDataType((*inputs)[0].DataType());
auto ret = common::CheckInputs(*inputs);
if (ret != lite::RET_OK) {
if (ret == lite::RET_INFER_INVALID) {
(*outputs)[0].SetShape({-1}); // shape{-1} shows that shape need to be inferred when running.
return ret;
return kLiteInferInvalid;
} else if (ret != lite::RET_OK) {
return kLiteError;
}
(*outputs)[0].SetShape((*inputs)[0].Shape());
return lite::RET_OK;
return kSuccess;
}
};
std::shared_ptr<kernel::KernelInterface> CustomAddInferCreator() { return std::make_shared<CustomAddInfer>(); }

View File

@ -98,7 +98,7 @@ bool PassTutorial::Run(const FuncGraphPtr &func_graph) {
namespace lite {
// register customed Pass
using mindspore::lite::registry::POSITION_BEGIN;
using mindspore::registry::POSITION_BEGIN;
REG_PASS(PassTutorial, opt::PassTutorial)
REG_SCHEDULED_PASS(POSITION_BEGIN, {"PassTutorial"})
} // namespace lite

View File

@ -28,17 +28,19 @@ class CustomAddInfer : public kernel::KernelInterface {
CustomAddInfer() = default;
~CustomAddInfer() = default;
int Infer(std::vector<mindspore::MSTensor> *inputs, std::vector<mindspore::MSTensor> *outputs,
const schema::Primitive *primitive) override {
Status Infer(std::vector<mindspore::MSTensor> *inputs, std::vector<mindspore::MSTensor> *outputs,
const schema::Primitive *primitive) override {
(*outputs)[0].SetFormat((*inputs)[0].format());
(*outputs)[0].SetDataType((*inputs)[0].DataType());
auto ret = common::CheckInputs(*inputs);
if (ret != lite::RET_OK) {
if (ret == lite::RET_INFER_INVALID) {
(*outputs)[0].SetShape({-1}); // shape{-1} shows that shape need to be inferred when running.
return ret;
return kLiteInferInvalid;
} else if (ret != lite::RET_OK) {
return kLiteError;
}
(*outputs)[0].SetShape((*inputs)[0].Shape());
return lite::RET_OK;
return kSuccess;
}
};
std::shared_ptr<kernel::KernelInterface> CustomAddInferCreator() { return std::make_shared<CustomAddInfer>(); }

View File

@ -60,13 +60,13 @@ class CustomAddKernel : public Kernel {
// if output shape exists value -1, need to be inferred before applying memory for output tensor.
int PreProcess() {
if (common::CheckOutputs(outputs_) != lite::RET_OK) {
auto ret = lite::registry::RegisterKernelInterface::GetKernelInterface({}, primitive_)
->Infer(&inputs_, &outputs_, primitive_);
if (ret != lite::RET_OK) {
auto status =
registry::RegisterKernelInterface::GetKernelInterface({}, primitive_)->Infer(&inputs_, &outputs_, primitive_);
if (status != kSuccess) {
std::cerr << "infer failed." << std::endl;
return lite::RET_ERROR;
}
ret = ReSize();
auto ret = ReSize();
if (ret != lite::RET_OK) {
std::cerr << "resize failed." << std::endl;
return ret;

View File

@ -19,6 +19,7 @@
#include <memory>
#include <vector>
#include "include/api/types.h"
#include "include/api/status.h"
#include "include/lite_utils.h"
#include "schema/model_generated.h"
@ -36,10 +37,10 @@ class MS_API KernelInterface {
/// \param[in] outputs Define the output tensors of op.
/// \param[in] primitive Define the attributes of op.
///
/// \return STATUS as an error code of inferring, STATUS is defined in errorcode.h..
virtual int Infer(std::vector<mindspore::MSTensor> *inputs, std::vector<mindspore::MSTensor> *outputs,
const schema::Primitive *primitive) {
return 0;
/// \return Status as a status identification of inferring.
virtual Status Infer(std::vector<mindspore::MSTensor> *inputs, std::vector<mindspore::MSTensor> *outputs,
const schema::Primitive *primitive) {
return kSuccess;
}
};
} // namespace kernel

View File

@ -23,7 +23,6 @@
using mindspore::converter::FmkType;
namespace mindspore {
namespace lite {
namespace registry {
/// \brief ModelParserCreator defined function pointer to get a ModelParser class.
typedef converter::ModelParser *(*ModelParserCreator)();
@ -53,9 +52,8 @@ class MS_API ModelParserRegistry {
/// \param[in] fmk Define identification of a certain framework.
/// \param[in] parserCreator Define function pointer of creating ModelParser.
#define REG_MODEL_PARSER(fmk, parserCreator) \
static mindspore::lite::registry::ModelParserRegistry g_##type##fmk##ModelParserReg(fmk, parserCreator);
static mindspore::registry::ModelParserRegistry g_##type##fmk##ModelParserReg(fmk, parserCreator);
} // namespace registry
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_INCLUDE_REGISTRY_MODEL_PARSER_REGISTRY_H

View File

@ -32,7 +32,6 @@ class MS_API Pass;
using PassPtr = std::shared_ptr<Pass>;
} // namespace opt
namespace lite {
namespace registry {
/// \brief PassPosition defined where to plae user's pass.
enum MS_API PassPosition { POSITION_BEGIN = 0, POSITION_END = 1 };
@ -42,35 +41,48 @@ class MS_API PassRegistry {
public:
/// \brief Constructor of PassRegistry to register pass.
///
/// \param[in] pos Define where to replace the pass.
/// \param[in] pass Define user's defined pass.
/// \param[in] pass_name Define the name of the pass, a string which should guarantee uniqueness.
/// \param[in] pass Define pass instance.
PassRegistry(const std::string &pass_name, const opt::PassPtr &pass);
/// \brief Constructor of PassRegistry to assign which passes are required for external extension.
///
/// \param[in position Define the place where assigned passes will run.
/// \param[in] assigned Define the name of passes assigned by user.
/// \param[in] position Define the place where assigned passes will run.
/// \param[in] assigned Define the names of the passes.
PassRegistry(PassPosition position, const std::vector<std::string> &assigned);
/// \brief Destructor of PassRegistrar.
~PassRegistry() = default;
/// \brief Static method to obtain external scheduling task assigned by user.
///
/// \param[in] position Define the place where assigned passes will run.
///
/// \return Passes' Name Vector.
static std::vector<std::string> GetOuterScheduleTask(PassPosition position);
/// \brief Static method to obtain pass instance according to passes' name.
///
/// \param[in] pass_names Define the name of passes.
///
/// \return Pass Instance Vector.
static std::vector<opt::PassPtr> GetPassFromStoreRoom(const std::vector<std::string> &pass_names);
};
/// \brief Defined registering macro to register Pass, which called by user directly.
///
/// \param[in] name Define name of user's pass, which is a string.
/// \param[in] pass Define user's defined pass.
/// \param[in] name Define the name of the pass, a string which should guarantee uniqueness.
/// \param[in] pass Define pass instance.
#define REG_PASS(name, pass) \
static mindspore::lite::registry::PassRegistry g_##name##PassReg(#name, std::make_shared<pass>());
static mindspore::registry::PassRegistry g_##name##PassReg(#name, std::make_shared<pass>());
/// \brief Defined assigning macro to assign Passes, which called by user directly.
///
/// \param[in] position Define the place where assigned passes will run.
/// \param[in] assigned Define the name of passes assigned by user.
/// \param[in] assigned Define the names of the passes.
#define REG_SCHEDULED_PASS(position, assigned) \
static mindspore::lite::registry::PassRegistry g_##position(position, assigned);
static mindspore::registry::PassRegistry g_##position(position, assigned);
} // namespace registry
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_INCLUDE_REGISTRY_PASS_REGISTRY_H_

View File

@ -26,9 +26,9 @@
#include "include/api/types.h"
#include "include/api/kernel.h"
#include "include/api/data_type.h"
#include "include/api/status.h"
namespace mindspore {
namespace lite {
namespace registry {
/// \brief KernelDesc defined kernel's basic attribute.
struct KernelDesc {
@ -61,9 +61,9 @@ class MS_API RegisterKernel {
/// \param[in] type Define the ordinary op type.
/// \param[in] creator Define a function pointer to create a kernel.
///
/// \return STATUS as an error code of registering, STATUS is defined in errorcode.h.
static int RegKernel(const std::string &arch, const std::string &provider, DataType data_type, int type,
CreateKernel creator);
/// \return Status as a status identification of registering.
static Status RegKernel(const std::string &arch, const std::string &provider, DataType data_type, int type,
CreateKernel creator);
/// \brief Static method to register kernel which is corresponding to custom op.
///
@ -73,9 +73,9 @@ class MS_API RegisterKernel {
/// \param[in] type Define the concrete type of a custom op.
/// \param[in] creator Define a function pointer to create a kernel.
///
/// \return STATUS as an error code of registering, STATUS is defined in errorcode.h.
static int RegCustomKernel(const std::string &arch, const std::string &provider, DataType data_type,
const std::string &type, CreateKernel creator);
/// \return Status as a status identification of registering.
static Status RegCustomKernel(const std::string &arch, const std::string &provider, DataType data_type,
const std::string &type, CreateKernel creator);
/// \brief Static methon to get a kernel's create function.
///
@ -124,11 +124,10 @@ class MS_API KernelReg {
/// \param[in] data_type Define kernel's input data type.
/// \param[in] op_type Define the ordinary op type.
/// \param[in] creator Define a function pointer to create a kernel.
#define REGISTER_KERNEL(arch, provider, data_type, op_type, creator) \
namespace { \
static mindspore::lite::registry::KernelReg g_##arch##provider##data_type##op_type##kernelReg(#arch, #provider, \
data_type, op_type, \
creator); \
#define REGISTER_KERNEL(arch, provider, data_type, op_type, creator) \
namespace { \
static mindspore::registry::KernelReg g_##arch##provider##data_type##op_type##kernelReg(#arch, #provider, data_type, \
op_type, creator); \
} // namespace
/// \brief Defined registering macro to register custom op kernel, which called by user directly.
@ -138,14 +137,12 @@ class MS_API KernelReg {
/// \param[in] data_type Define kernel's input data type.
/// \param[in] op_type Define the concrete type of a custom op.
/// \param[in] creator Define a function pointer to create a kernel.
#define REGISTER_CUSTOM_KERNEL(arch, provider, data_type, op_type, creator) \
namespace { \
static mindspore::lite::registry::KernelReg g_##arch##provider##data_type##op_type##kernelReg(#arch, #provider, \
data_type, #op_type, \
creator); \
#define REGISTER_CUSTOM_KERNEL(arch, provider, data_type, op_type, creator) \
namespace { \
static mindspore::registry::KernelReg g_##arch##provider##data_type##op_type##kernelReg(#arch, #provider, data_type, \
#op_type, creator); \
} // namespace
} // namespace registry
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_INCLUDE_REGISTRY_REGISTER_KERNEL_H_

View File

@ -25,7 +25,6 @@
#include "schema/model_generated.h"
namespace mindspore {
namespace lite {
namespace registry {
/// \brief KernelInterfaceCreator defined a functor to create KernelInterface.
using KernelInterfaceCreator = std::function<std::shared_ptr<kernel::KernelInterface>()>;
@ -39,8 +38,8 @@ class MS_API RegisterKernelInterface {
/// \param[in] op_type Define the concrete type of a custom op.
/// \param[in] creator Define the KernelInterface create function.
///
/// \return STATUS as an error code of registering, STATUS is defined in errorcode.h.
static int CustomReg(const std::string &provider, const std::string &op_type, KernelInterfaceCreator creator);
/// \return Status as a status identification of registering.
static Status CustomReg(const std::string &provider, const std::string &op_type, KernelInterfaceCreator creator);
/// \brief Static method to register op whose primitive type is ordinary.
///
@ -48,8 +47,8 @@ class MS_API RegisterKernelInterface {
/// \param[in] op_type Define the ordinary op type.
/// \param[in] creator Define the KernelInterface create function.
///
/// \return STATUS as an error code of registering, STATUS is defined in errorcode.h.
static int Reg(const std::string &provider, int op_type, KernelInterfaceCreator creator);
/// \return Status as a status identification of registering.
static Status Reg(const std::string &provider, int op_type, KernelInterfaceCreator creator);
/// \brief Static method to get registration of a certain op.
///
@ -88,9 +87,9 @@ class MS_API KernelInterfaceReg {
/// \param[in] provider Define the identification of user.
/// \param[in] op_type Define the ordinary op type.
/// \param[in] creator Define the KernelInterface create function.
#define REGISTER_KERNEL_INTERFACE(provider, op_type, creator) \
namespace { \
static mindspore::lite::registry::KernelInterfaceReg g_##provider##op_type##_inter_reg(#provider, op_type, creator); \
#define REGISTER_KERNEL_INTERFACE(provider, op_type, creator) \
namespace { \
static mindspore::registry::KernelInterfaceReg g_##provider##op_type##_inter_reg(#provider, op_type, creator); \
} // namespace
/// \brief Defined registering macro to register custom op, which called by user directly.
@ -98,13 +97,12 @@ class MS_API KernelInterfaceReg {
/// \param[in] provider Define the identification of user.
/// \param[in] op_type Define the concrete type of a custom op.
/// \param[in] creator Define the KernelInterface create function.
#define REGISTER_CUSTOM_KERNEL_INTERFACE(provider, op_type, creator) \
namespace { \
static mindspore::lite::registry::KernelInterfaceReg g_##provider##op_type##_custom_inter_reg(#provider, #op_type, \
creator); \
#define REGISTER_CUSTOM_KERNEL_INTERFACE(provider, op_type, creator) \
namespace { \
static mindspore::registry::KernelInterfaceReg g_##provider##op_type##_custom_inter_reg(#provider, #op_type, \
creator); \
} // namespace
} // namespace registry
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_INCLUDE_REGISTRY_KERNEL_INTERFACE_H_

View File

@ -39,8 +39,8 @@ using mindspore::kernel::KERNEL_ARCH;
using mindspore::kernel::KernelCreator;
using mindspore::kernel::KernelKey;
#ifdef ENABLE_CUSTOM_KERNEL_REGISTRY
using mindspore::lite::registry::CreateKernel;
using mindspore::lite::registry::KernelDesc;
using mindspore::registry::CreateKernel;
using mindspore::registry::KernelDesc;
#endif
namespace mindspore::lite {

View File

@ -21,11 +21,11 @@
#include "src/common/version_manager.h"
#include "schema/model_generated.h"
using mindspore::lite::registry::KernelInterfaceCreator;
using mindspore::registry::KernelInterfaceCreator;
using mindspore::schema::PrimitiveType_MAX;
using mindspore::schema::PrimitiveType_MIN;
namespace mindspore {
namespace lite {
namespace registry {
namespace {
static const auto kMaxKernelNum = PrimitiveType_MAX - PrimitiveType_MIN;
std::string GetCustomType(const schema::Primitive *primitive) {
@ -35,10 +35,10 @@ std::string GetCustomType(const schema::Primitive *primitive) {
}
} // namespace
int KernelInterfaceRegistry::CustomReg(const std::string &provider, const std::string &type,
KernelInterfaceCreator creator) {
Status KernelInterfaceRegistry::CustomReg(const std::string &provider, const std::string &type,
KernelInterfaceCreator creator) {
custom_creators_[provider][type] = creator;
return RET_OK;
return kSuccess;
}
std::shared_ptr<kernel::KernelInterface> KernelInterfaceRegistry::GetCacheInterface(const std::string &provider,
@ -124,10 +124,10 @@ std::shared_ptr<kernel::KernelInterface> KernelInterfaceRegistry::GetKernelInter
return nullptr;
}
int KernelInterfaceRegistry::Reg(const std::string &provider, int op_type, KernelInterfaceCreator creator) {
Status KernelInterfaceRegistry::Reg(const std::string &provider, int op_type, KernelInterfaceCreator creator) {
if (op_type < PrimitiveType_MIN || op_type > kMaxKernelNum) {
MS_LOG(ERROR) << "reg op_type invalid!op_type: " << op_type << ", max value: " << kMaxKernelNum;
return RET_ERROR;
return kLiteError;
}
std::unique_lock<std::mutex> lock(mutex_);
@ -137,12 +137,12 @@ int KernelInterfaceRegistry::Reg(const std::string &provider, int op_type, Kerne
reinterpret_cast<KernelInterfaceCreator *>(calloc(kMaxKernelNum, sizeof(KernelInterfaceCreator)));
if (kernel_creators_[provider] == nullptr) {
MS_LOG(ERROR) << "malloc kernel dev delegate creator fail!";
return RET_ERROR;
return kLiteError;
}
}
kernel_creators_[provider][op_type] = creator;
return RET_OK;
return kSuccess;
}
KernelInterfaceRegistry::~KernelInterfaceRegistry() {
@ -151,5 +151,5 @@ KernelInterfaceRegistry::~KernelInterfaceRegistry() {
item.second = nullptr;
}
}
} // namespace lite
} // namespace registry
} // namespace mindspore

View File

@ -26,7 +26,7 @@
#include "include/model.h"
namespace mindspore {
namespace lite {
namespace registry {
class KernelInterfaceRegistry {
public:
static KernelInterfaceRegistry *Instance() {
@ -36,8 +36,8 @@ class KernelInterfaceRegistry {
std::shared_ptr<kernel::KernelInterface> GetKernelInterface(const std::string &provider,
const schema::Primitive *primitive);
int CustomReg(const std::string &provider, const std::string &op_type, registry::KernelInterfaceCreator creator);
int Reg(const std::string &provider, int op_type, registry::KernelInterfaceCreator creator);
Status CustomReg(const std::string &provider, const std::string &op_type, registry::KernelInterfaceCreator creator);
Status Reg(const std::string &provider, int op_type, registry::KernelInterfaceCreator creator);
virtual ~KernelInterfaceRegistry();
private:
@ -55,7 +55,7 @@ class KernelInterfaceRegistry {
std::map<std::string, std::map<std::string, registry::KernelInterfaceCreator>> custom_creators_;
std::map<std::string, std::map<std::string, std::shared_ptr<kernel::KernelInterface>>> custom_kernels_;
};
} // namespace lite
} // namespace registry
} // namespace mindspore
#endif // MINDSPORE_LITE_SRC_REGISTRY_KERNEL_INTERFACE_REGISTRY_H_

View File

@ -21,22 +21,21 @@
#include "src/registry/register_kernel_impl.h"
namespace mindspore {
namespace lite {
namespace registry {
int RegisterKernel::RegCustomKernel(const std::string &arch, const std::string &provider, DataType data_type,
const std::string &type, CreateKernel creator) {
Status RegisterKernel::RegCustomKernel(const std::string &arch, const std::string &provider, DataType data_type,
const std::string &type, CreateKernel creator) {
#ifdef ENABLE_CUSTOM_KERNEL_REGISTRY
return lite::RegistryKernelImpl::GetInstance()->RegCustomKernel(arch, provider, data_type, type, creator);
return RegistryKernelImpl::GetInstance()->RegCustomKernel(arch, provider, data_type, type, creator);
#else
MS_LOG(ERROR) << unsupport_custom_kernel_register_log;
return lite::RET_NOT_SUPPORT;
#endif
}
int RegisterKernel::RegKernel(const std::string &arch, const std::string &provider, DataType data_type, int op_type,
CreateKernel creator) {
Status RegisterKernel::RegKernel(const std::string &arch, const std::string &provider, DataType data_type, int op_type,
CreateKernel creator) {
#ifdef ENABLE_CUSTOM_KERNEL_REGISTRY
return lite::RegistryKernelImpl::GetInstance()->RegKernel(arch, provider, data_type, op_type, creator);
return RegistryKernelImpl::GetInstance()->RegKernel(arch, provider, data_type, op_type, creator);
#else
MS_LOG(ERROR) << unsupport_custom_kernel_register_log;
return lite::RET_NOT_SUPPORT;
@ -45,12 +44,11 @@ int RegisterKernel::RegKernel(const std::string &arch, const std::string &provid
CreateKernel RegisterKernel::GetCreator(const schema::Primitive *primitive, KernelDesc *desc) {
#ifdef ENABLE_CUSTOM_KERNEL_REGISTRY
return lite::RegistryKernelImpl::GetInstance()->GetProviderCreator(primitive, desc);
return RegistryKernelImpl::GetInstance()->GetProviderCreator(primitive, desc);
#else
MS_LOG(ERROR) << unsuppor_custom_kernel_register_log;
return lite::RET_NOT_SUPPORT;
#endif
}
} // namespace registry
} // namespace lite
} // namespace mindspore

View File

@ -18,12 +18,12 @@
#include "include/errorcode.h"
#include "src/common/version_manager.h"
#include "src/common/log_adapter.h"
using mindspore::lite::registry::CreateKernel;
using mindspore::lite::registry::KernelDesc;
using mindspore::registry::CreateKernel;
using mindspore::registry::KernelDesc;
using mindspore::schema::PrimitiveType_MAX;
using mindspore::schema::PrimitiveType_MIN;
namespace mindspore::lite {
namespace mindspore::registry {
namespace {
static const auto kKernelMaxNum =
(static_cast<int>(DataType::kNumberTypeEnd) - static_cast<int>(DataType::kNumberTypeBegin) - 1) *
@ -44,11 +44,11 @@ int RegistryKernelImpl::GetFuncIndex(const KernelDesc &desc) {
return data_type_index * kOpTypeLen + desc.type;
}
int RegistryKernelImpl::RegCustomKernel(const std::string &arch, const std::string &provider, DataType data_type,
const std::string &type, CreateKernel creator) {
Status RegistryKernelImpl::RegCustomKernel(const std::string &arch, const std::string &provider, DataType data_type,
const std::string &type, CreateKernel creator) {
if (data_type >= DataType::kNumberTypeEnd) {
MS_LOG(ERROR) << "invalid data_type: " << static_cast<int>(data_type) << "!provider: " << provider;
return RET_ERROR;
return kLiteError;
}
std::unique_lock<std::mutex> lock(lock_);
if (custom_kernel_creators_[provider][arch][type] == nullptr) {
@ -56,28 +56,28 @@ int RegistryKernelImpl::RegCustomKernel(const std::string &arch, const std::stri
reinterpret_cast<CreateKernel *>(calloc(kDataTypeLen, sizeof(CreateKernel)));
if (custom_kernel_creators_[provider][arch][type] == nullptr) {
MS_LOG(ERROR) << "malloc custom kernel creator fail!provider: " << provider << ", arch: " << arch;
return RET_ERROR;
return kLiteError;
}
}
int data_type_index = static_cast<int>(data_type) - static_cast<int>(DataType::kNumberTypeBegin) - 1;
if (data_type_index < 0 || data_type_index >= kDataTypeLen) {
MS_LOG(ERROR) << "invalid data_type: " << static_cast<int>(data_type) << "!provider: " << provider;
return RET_ERROR;
return kLiteError;
}
custom_kernel_creators_[provider][arch][type][data_type_index] = creator;
return RET_OK;
return kSuccess;
}
int RegistryKernelImpl::RegKernel(const std::string &arch, const std::string &provider, DataType data_type, int type,
registry::CreateKernel creator) {
Status RegistryKernelImpl::RegKernel(const std::string &arch, const std::string &provider, DataType data_type, int type,
registry::CreateKernel creator) {
std::unique_lock<std::mutex> lock(lock_);
auto iter = kernel_creators_.find(provider);
if (iter == kernel_creators_.end()) {
kernel_creators_[provider][arch] = reinterpret_cast<CreateKernel *>(calloc(kKernelMaxNum, sizeof(CreateKernel)));
if (kernel_creators_[provider][arch] == nullptr) {
MS_LOG(ERROR) << "malloc kernel creator buffer fail! provider: " << provider << ",arch:" << arch;
return RET_ERROR;
return kLiteError;
}
} else {
auto iter_arch = iter->second.find(arch);
@ -85,7 +85,7 @@ int RegistryKernelImpl::RegKernel(const std::string &arch, const std::string &pr
iter->second[arch] = reinterpret_cast<CreateKernel *>(calloc(kKernelMaxNum, sizeof(CreateKernel)));
if (iter->second[arch] == nullptr) {
MS_LOG(ERROR) << "malloc kernel creator buffer fail! provider: " << provider << ",arch:" << arch;
return RET_ERROR;
return kLiteError;
}
}
}
@ -95,11 +95,11 @@ int RegistryKernelImpl::RegKernel(const std::string &arch, const std::string &pr
if (index >= kKernelMaxNum || index < 0) {
MS_LOG(ERROR) << "invalid kernel key, arch " << arch << ", data_type" << static_cast<int>(data_type) << ",op type "
<< type;
return RET_ERROR;
return kLiteError;
}
kernel_creators_[provider][arch][index] = creator;
return RET_OK;
return kSuccess;
}
registry::CreateKernel RegistryKernelImpl::GetCustomKernelCreator(const schema::Primitive *primitive,
@ -179,4 +179,4 @@ RegistryKernelImpl::~RegistryKernelImpl() {
}
}
}
} // namespace mindspore::lite
} // namespace mindspore::registry

View File

@ -25,7 +25,7 @@
#include <set>
#include "include/registry/register_kernel.h"
namespace mindspore::lite {
namespace mindspore::registry {
class RegistryKernelImpl {
public:
RegistryKernelImpl() = default;
@ -36,11 +36,11 @@ class RegistryKernelImpl {
return &instance;
}
int RegCustomKernel(const std::string &arch, const std::string &provider, DataType data_type, const std::string &type,
registry::CreateKernel creator);
Status RegCustomKernel(const std::string &arch, const std::string &provider, DataType data_type,
const std::string &type, registry::CreateKernel creator);
int RegKernel(const std::string &arch, const std::string &provider, DataType data_type, int type,
registry::CreateKernel creator);
Status RegKernel(const std::string &arch, const std::string &provider, DataType data_type, int type,
registry::CreateKernel creator);
virtual registry::CreateKernel GetProviderCreator(const schema::Primitive *primitive, registry::KernelDesc *desc);
@ -60,6 +60,6 @@ class RegistryKernelImpl {
registry::CreateKernel GetCustomKernelCreator(const schema::Primitive *primitive, registry::KernelDesc *desc);
int GetFuncIndex(const registry::KernelDesc &desc);
};
} // namespace mindspore::lite
} // namespace mindspore::registry
#endif // MINDSPORE_LITE_SRC_REGISTRY_REGISTER_KERNEL_IMPL_H_

View File

@ -21,21 +21,20 @@
#include "src/registry/kernel_interface_registry.h"
namespace mindspore {
namespace lite {
namespace registry {
int RegisterKernelInterface::Reg(const std::string &provider, int op_type, KernelInterfaceCreator creator) {
Status RegisterKernelInterface::Reg(const std::string &provider, int op_type, KernelInterfaceCreator creator) {
#ifdef ENABLE_CUSTOM_KERNEL_REGISTRY
return lite::KernelInterfaceRegistry::Instance()->Reg(provider, op_type, creator);
return KernelInterfaceRegistry::Instance()->Reg(provider, op_type, creator);
#else
MS_LOG(ERROR) << unsupport_custom_kernel_register_log;
return lite::RET_NOT_SUPPORT;
#endif
}
int RegisterKernelInterface::CustomReg(const std::string &provider, const std::string &op_type,
KernelInterfaceCreator creator) {
Status RegisterKernelInterface::CustomReg(const std::string &provider, const std::string &op_type,
KernelInterfaceCreator creator) {
#ifdef ENABLE_CUSTOM_KERNEL_REGISTRY
return lite::KernelInterfaceRegistry::Instance()->CustomReg(provider, op_type, creator);
return KernelInterfaceRegistry::Instance()->CustomReg(provider, op_type, creator);
#else
MS_LOG(ERROR) << unsupport_custom_kernel_register_log;
return lite::RET_NOT_SUPPORT;
@ -45,12 +44,11 @@ int RegisterKernelInterface::CustomReg(const std::string &provider, const std::s
std::shared_ptr<kernel::KernelInterface> RegisterKernelInterface::GetKernelInterface(
const std::string &provider, const schema::Primitive *primitive) {
#ifdef ENABLE_CUSTOM_KERNEL_REGISTRY
return lite::KernelInterfaceRegistry::Instance()->GetKernelInterface(provider, primitive);
return KernelInterfaceRegistry::Instance()->GetKernelInterface(provider, primitive);
#else
MS_LOG(ERROR) << unsupport_custom_kernel_register_log;
return nullptr;
#endif
}
} // namespace registry
} // namespace lite
} // namespace mindspore

View File

@ -60,9 +60,12 @@ int KernelInferShape(const std::vector<lite::Tensor *> &inputs, const std::vecto
std::transform(outputs.begin(), outputs.end(), std::back_inserter(out_tensors),
[](lite::Tensor *tensor) { return mindspore::MSTensor(std::make_shared<MSTensor::Impl>(tensor)); });
auto ret = kernel_interface->Infer(&in_tensors, &out_tensors, static_cast<const schema::Primitive *>(primitive));
if (ret != RET_OK) {
if (ret == kLiteInferInvalid) {
return RET_INFER_INVALID;
}
if (ret != kSuccess) {
MS_LOG(ERROR) << "op_type: " << PrimitiveTypeName(prim_type) << " infer fail!ret: " << ret;
return ret;
return RET_ERROR;
}
return RET_OK;
}

View File

@ -111,12 +111,12 @@ class TestCustomOpInfer : public KernelInterface {
public:
TestCustomOpInfer() = default;
~TestCustomOpInfer() = default;
int Infer(std::vector<mindspore::MSTensor> *inputs, std::vector<mindspore::MSTensor> *outputs,
const schema::Primitive *primitive) override {
Status Infer(std::vector<mindspore::MSTensor> *inputs, std::vector<mindspore::MSTensor> *outputs,
const schema::Primitive *primitive) override {
(*outputs)[0].SetFormat((*inputs)[0].format());
(*outputs)[0].SetDataType((*inputs)[0].DataType());
(*outputs)[0].SetShape((*inputs)[0].Shape());
return RET_OK;
return kSuccess;
}
};

View File

@ -80,12 +80,12 @@ class TestCustomAddInfer : public KernelInterface {
public:
TestCustomAddInfer() = default;
~TestCustomAddInfer() = default;
int Infer(std::vector<mindspore::MSTensor> *inputs, std::vector<mindspore::MSTensor> *outputs,
const schema::Primitive *primitive) override {
Status Infer(std::vector<mindspore::MSTensor> *inputs, std::vector<mindspore::MSTensor> *outputs,
const schema::Primitive *primitive) override {
(*outputs)[0].SetFormat((*inputs)[0].format());
(*outputs)[0].SetDataType((*inputs)[0].DataType());
(*outputs)[0].SetShape((*inputs)[0].Shape());
return RET_OK;
return kSuccess;
}
};

View File

@ -35,7 +35,7 @@ TEST_F(ModelParserRegistryTest, TestRegistry) {
ASSERT_NE(proposal_parser, nullptr);
REG_MODEL_PARSER(FmkType_CAFFE,
TestModelParserCreator); // register test model parser creator, which will overwrite existing.
auto model_parser = lite::registry::ModelParserRegistry::GetModelParser(FmkType_CAFFE);
auto model_parser = registry::ModelParserRegistry::GetModelParser(FmkType_CAFFE);
ASSERT_NE(model_parser, nullptr);
ConverterParameters converter_parameters;
auto func_graph = model_parser->Parse(converter_parameters);

View File

@ -25,20 +25,19 @@
#include "ops/addn.h"
#include "ops/custom.h"
#include "tools/converter/model_parser.h"
#include "tools/converter/registry/pass_content.h"
#include "tools/optimizer/common/gllo_utils.h"
#include "ut/tools/converter/registry/model_parser_test.h"
using mindspore::converter::ConverterParameters;
using mindspore::converter::FmkType_CAFFE;
using mindspore::lite::registry::POSITION_BEGIN;
using mindspore::registry::POSITION_BEGIN;
namespace mindspore {
class PassRegistryTest : public mindspore::CommonTest {
public:
PassRegistryTest() = default;
void SetUp() override {
REG_MODEL_PARSER(FmkType_CAFFE, TestModelParserCreator);
auto model_parser = lite::registry::ModelParserRegistry::GetModelParser(FmkType_CAFFE);
auto model_parser = registry::ModelParserRegistry::GetModelParser(FmkType_CAFFE);
if (model_parser == nullptr) {
return;
}
@ -214,12 +213,11 @@ REG_SCHEDULED_PASS(POSITION_BEGIN, {"TestFusion"})
} // namespace opt
TEST_F(PassRegistryTest, TestRegistry) {
auto &passes = lite::PassStoreRoomInfo();
auto &assigned_passes = lite::ExternalAssignedPassesInfo();
ASSERT_EQ(assigned_passes.size(), 1);
auto pass_names = assigned_passes[POSITION_BEGIN];
ASSERT_EQ(pass_names.size(), 1);
auto begin_pass = passes[pass_names.front()];
auto schedule_task = registry::PassRegistry::GetOuterScheduleTask(POSITION_BEGIN);
ASSERT_EQ(schedule_task.size(), 1);
auto passes = registry::PassRegistry::GetPassFromStoreRoom(schedule_task);
ASSERT_EQ(passes.size(), 1);
auto begin_pass = passes.front();
ASSERT_NE(begin_pass, nullptr);
auto begin_pass_test = std::dynamic_pointer_cast<opt::TestFusion>(begin_pass);
ASSERT_NE(begin_pass_test, nullptr);

View File

@ -326,7 +326,7 @@ FuncGraphPtr AnfTransform::TransformFuncGraph(const FuncGraphPtr &old_graph, con
return nullptr;
}
if (!RunExternalPass(old_graph, registry::PassPosition::POSITION_BEGIN)) {
if (!RunExternalPass(old_graph, registry::POSITION_BEGIN)) {
MS_LOG(ERROR) << "Run external pass failed, place is BEGIN";
return nullptr;
}
@ -361,7 +361,7 @@ FuncGraphPtr AnfTransform::TransformFuncGraph(const FuncGraphPtr &old_graph, con
return nullptr;
}
if (!RunExternalPass(old_graph, registry::PassPosition::POSITION_END)) {
if (!RunExternalPass(old_graph, registry::POSITION_END)) {
MS_LOG(ERROR) << "Run external pass failed, place is END";
return nullptr;
}

View File

@ -18,25 +18,26 @@
#include <string>
#include <vector>
#include "backend/optimizer/common/pass.h"
#include "tools/converter/registry/pass_content.h"
namespace mindspore {
namespace lite {
bool RunOptimizerPass(const FuncGraphPtr &func_graph, std::vector<std::string> pass_names) {
bool RunOptimizerPass(const FuncGraphPtr &func_graph, const std::vector<std::string> &pass_names) {
if (func_graph == nullptr) {
MS_LOG(ERROR) << "func graph is nullptr.";
return false;
}
auto &passes_info = PassStoreRoomInfo();
for (auto &name : pass_names) {
if (passes_info.find(name) == passes_info.end()) {
MS_LOG(ERROR) << "cannot find required pass.";
return false;
}
if (!passes_info[name]->Run(func_graph)) {
MS_LOG(ERROR) << "run pass failed, pass name is " << name;
auto schedule_passes = registry::PassRegistry::GetPassFromStoreRoom(pass_names);
if (schedule_passes.size() != pass_names.size()) {
MS_LOG(ERROR) << "exited pass cannot be obtained.";
return false;
}
int index = 0;
for (auto &pass : schedule_passes) {
if (!pass->Run(func_graph)) {
MS_LOG(ERROR) << "run pass failed, pass name is " << pass_names[index];
return false;
}
++index;
}
return true;
}
@ -46,21 +47,10 @@ bool RunExternalPass(const FuncGraphPtr &func_graph, registry::PassPosition posi
MS_LOG(ERROR) << "func graph is nullptr.";
return false;
}
auto &external_assigned = ExternalAssignedPassesInfo();
if (external_assigned.find(position) == external_assigned.end()) {
MS_LOG(DEBUG) << "there is no external pass in current position, position is " << position;
return true;
}
auto &passes_info = PassStoreRoomInfo();
for (auto &name : external_assigned[position]) {
if (passes_info.find(name) == passes_info.end()) {
MS_LOG(ERROR) << "cannot find required pass.";
return false;
}
if (!passes_info[name]->Run(func_graph)) {
MS_LOG(ERROR) << "run pass failed, pass name is " << name;
return false;
}
auto schedule_task = registry::PassRegistry::GetOuterScheduleTask(position);
if (!RunOptimizerPass(func_graph, schedule_task)) {
MS_LOG(ERROR) << "run external scheduled task failed.";
return false;
}
return true;
}

View File

@ -24,7 +24,7 @@
namespace mindspore {
namespace lite {
bool RunOptimizerPass(const FuncGraphPtr &func_graph, std::vector<std::string> pass_names);
bool RunOptimizerPass(const FuncGraphPtr &func_graph, const std::vector<std::string> &pass_names);
bool RunExternalPass(const FuncGraphPtr &func_graph, registry::PassPosition position);
} // namespace lite
} // namespace mindspore

View File

@ -19,7 +19,6 @@
#include "src/common/log_adapter.h"
namespace mindspore {
namespace lite {
namespace registry {
namespace {
std::map<FmkType, ModelParserCreator> model_parser_room;
@ -42,5 +41,4 @@ converter::ModelParser *ModelParserRegistry::GetModelParser(FmkType fmk) {
return nullptr;
}
} // namespace registry
} // namespace lite
} // namespace mindspore

View File

@ -1,32 +0,0 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_REGISTRY_PASS_CONTENT_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_REGISTRY_PASS_CONTENT_H
#include <map>
#include <string>
#include <vector>
#include "include/registry/pass_registry.h"
namespace mindspore {
namespace lite {
std::map<std::string, opt::PassPtr> &MS_API PassStoreRoomInfo();
std::map<registry::PassPosition, std::vector<std::string>> &MS_API ExternalAssignedPassesInfo();
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_REGISTRY_PASS_CONTENT_H

View File

@ -19,11 +19,10 @@
#include <mutex>
#include <string>
#include <vector>
#include "tools/converter/registry/pass_content.h"
#include "src/common/log_adapter.h"
namespace mindspore {
namespace lite {
namespace registry {
namespace {
std::map<std::string, opt::PassPtr> pass_store_room;
std::map<registry::PassPosition, std::vector<std::string>> external_assigned_passes;
@ -38,19 +37,27 @@ void RegPass(const std::string &pass_name, const opt::PassPtr &pass) {
}
} // namespace
registry::PassRegistry::PassRegistry(const std::string &pass_name, const opt::PassPtr &pass) {
RegPass(pass_name, pass);
}
PassRegistry::PassRegistry(const std::string &pass_name, const opt::PassPtr &pass) { RegPass(pass_name, pass); }
registry::PassRegistry::PassRegistry(PassPosition position, const std::vector<std::string> &assigned) {
PassRegistry::PassRegistry(PassPosition position, const std::vector<std::string> &assigned) {
std::unique_lock<std::mutex> lock(pass_mutex);
external_assigned_passes[position] = assigned;
}
std::map<std::string, opt::PassPtr> &PassStoreRoomInfo() { return pass_store_room; }
std::map<registry::PassPosition, std::vector<std::string>> &ExternalAssignedPassesInfo() {
return external_assigned_passes;
std::vector<std::string> PassRegistry::GetOuterScheduleTask(PassPosition position) {
return external_assigned_passes[position];
}
} // namespace lite
std::vector<opt::PassPtr> PassRegistry::GetPassFromStoreRoom(const std::vector<std::string> &pass_names) {
std::vector<opt::PassPtr> schedule_passes;
for (auto &name : pass_names) {
auto iter = pass_store_room.find(name);
if (iter == pass_store_room.end()) {
continue;
}
schedule_passes.push_back(iter->second);
}
return schedule_passes;
}
} // namespace registry
} // namespace mindspore