change namespace mindspore::registry to mindspore::registry

This commit is contained in:
xuanyue 2021-08-13 16:05:13 +08:00
parent 62a780f20f
commit df3db5196b
27 changed files with 181 additions and 211 deletions

View File

@ -28,17 +28,19 @@ class CustomAddInfer : public kernel::KernelInterface {
CustomAddInfer() = default; CustomAddInfer() = default;
~CustomAddInfer() = default; ~CustomAddInfer() = default;
int Infer(std::vector<mindspore::MSTensor> *inputs, std::vector<mindspore::MSTensor> *outputs, Status Infer(std::vector<mindspore::MSTensor> *inputs, std::vector<mindspore::MSTensor> *outputs,
const schema::Primitive *primitive) override { const schema::Primitive *primitive) override {
(*outputs)[0].SetFormat((*inputs)[0].format()); (*outputs)[0].SetFormat((*inputs)[0].format());
(*outputs)[0].SetDataType((*inputs)[0].DataType()); (*outputs)[0].SetDataType((*inputs)[0].DataType());
auto ret = common::CheckInputs(*inputs); auto ret = common::CheckInputs(*inputs);
if (ret != lite::RET_OK) { if (ret == lite::RET_INFER_INVALID) {
(*outputs)[0].SetShape({-1}); // shape{-1} shows that shape need to be inferred when running. (*outputs)[0].SetShape({-1}); // shape{-1} shows that shape need to be inferred when running.
return ret; return kLiteInferInvalid;
} else if (ret != lite::RET_OK) {
return kLiteError;
} }
(*outputs)[0].SetShape((*inputs)[0].Shape()); (*outputs)[0].SetShape((*inputs)[0].Shape());
return lite::RET_OK; return kSuccess;
} }
}; };
std::shared_ptr<kernel::KernelInterface> CustomAddInferCreator() { return std::make_shared<CustomAddInfer>(); } std::shared_ptr<kernel::KernelInterface> CustomAddInferCreator() { return std::make_shared<CustomAddInfer>(); }

View File

@ -98,7 +98,7 @@ bool PassTutorial::Run(const FuncGraphPtr &func_graph) {
namespace lite { namespace lite {
// register customed Pass // register customed Pass
using mindspore::lite::registry::POSITION_BEGIN; using mindspore::registry::POSITION_BEGIN;
REG_PASS(PassTutorial, opt::PassTutorial) REG_PASS(PassTutorial, opt::PassTutorial)
REG_SCHEDULED_PASS(POSITION_BEGIN, {"PassTutorial"}) REG_SCHEDULED_PASS(POSITION_BEGIN, {"PassTutorial"})
} // namespace lite } // namespace lite

View File

@ -28,17 +28,19 @@ class CustomAddInfer : public kernel::KernelInterface {
CustomAddInfer() = default; CustomAddInfer() = default;
~CustomAddInfer() = default; ~CustomAddInfer() = default;
int Infer(std::vector<mindspore::MSTensor> *inputs, std::vector<mindspore::MSTensor> *outputs, Status Infer(std::vector<mindspore::MSTensor> *inputs, std::vector<mindspore::MSTensor> *outputs,
const schema::Primitive *primitive) override { const schema::Primitive *primitive) override {
(*outputs)[0].SetFormat((*inputs)[0].format()); (*outputs)[0].SetFormat((*inputs)[0].format());
(*outputs)[0].SetDataType((*inputs)[0].DataType()); (*outputs)[0].SetDataType((*inputs)[0].DataType());
auto ret = common::CheckInputs(*inputs); auto ret = common::CheckInputs(*inputs);
if (ret != lite::RET_OK) { if (ret == lite::RET_INFER_INVALID) {
(*outputs)[0].SetShape({-1}); // shape{-1} shows that shape need to be inferred when running. (*outputs)[0].SetShape({-1}); // shape{-1} shows that shape need to be inferred when running.
return ret; return kLiteInferInvalid;
} else if (ret != lite::RET_OK) {
return kLiteError;
} }
(*outputs)[0].SetShape((*inputs)[0].Shape()); (*outputs)[0].SetShape((*inputs)[0].Shape());
return lite::RET_OK; return kSuccess;
} }
}; };
std::shared_ptr<kernel::KernelInterface> CustomAddInferCreator() { return std::make_shared<CustomAddInfer>(); } std::shared_ptr<kernel::KernelInterface> CustomAddInferCreator() { return std::make_shared<CustomAddInfer>(); }

View File

@ -60,13 +60,13 @@ class CustomAddKernel : public Kernel {
// if output shape exists value -1, need to be inferred before applying memory for output tensor. // if output shape exists value -1, need to be inferred before applying memory for output tensor.
int PreProcess() { int PreProcess() {
if (common::CheckOutputs(outputs_) != lite::RET_OK) { if (common::CheckOutputs(outputs_) != lite::RET_OK) {
auto ret = lite::registry::RegisterKernelInterface::GetKernelInterface({}, primitive_) auto status =
->Infer(&inputs_, &outputs_, primitive_); registry::RegisterKernelInterface::GetKernelInterface({}, primitive_)->Infer(&inputs_, &outputs_, primitive_);
if (ret != lite::RET_OK) { if (status != kSuccess) {
std::cerr << "infer failed." << std::endl; std::cerr << "infer failed." << std::endl;
return lite::RET_ERROR; return lite::RET_ERROR;
} }
ret = ReSize(); auto ret = ReSize();
if (ret != lite::RET_OK) { if (ret != lite::RET_OK) {
std::cerr << "resize failed." << std::endl; std::cerr << "resize failed." << std::endl;
return ret; return ret;

View File

@ -19,6 +19,7 @@
#include <memory> #include <memory>
#include <vector> #include <vector>
#include "include/api/types.h" #include "include/api/types.h"
#include "include/api/status.h"
#include "include/lite_utils.h" #include "include/lite_utils.h"
#include "schema/model_generated.h" #include "schema/model_generated.h"
@ -36,10 +37,10 @@ class MS_API KernelInterface {
/// \param[in] outputs Define the output tensors of op. /// \param[in] outputs Define the output tensors of op.
/// \param[in] primitive Define the attributes of op. /// \param[in] primitive Define the attributes of op.
/// ///
/// \return STATUS as an error code of inferring, STATUS is defined in errorcode.h.. /// \return Status as a status identification of inferring.
virtual int Infer(std::vector<mindspore::MSTensor> *inputs, std::vector<mindspore::MSTensor> *outputs, virtual Status Infer(std::vector<mindspore::MSTensor> *inputs, std::vector<mindspore::MSTensor> *outputs,
const schema::Primitive *primitive) { const schema::Primitive *primitive) {
return 0; return kSuccess;
} }
}; };
} // namespace kernel } // namespace kernel

View File

@ -23,7 +23,6 @@
using mindspore::converter::FmkType; using mindspore::converter::FmkType;
namespace mindspore { namespace mindspore {
namespace lite {
namespace registry { namespace registry {
/// \brief ModelParserCreator defined function pointer to get a ModelParser class. /// \brief ModelParserCreator defined function pointer to get a ModelParser class.
typedef converter::ModelParser *(*ModelParserCreator)(); typedef converter::ModelParser *(*ModelParserCreator)();
@ -53,9 +52,8 @@ class MS_API ModelParserRegistry {
/// \param[in] fmk Define identification of a certain framework. /// \param[in] fmk Define identification of a certain framework.
/// \param[in] parserCreator Define function pointer of creating ModelParser. /// \param[in] parserCreator Define function pointer of creating ModelParser.
#define REG_MODEL_PARSER(fmk, parserCreator) \ #define REG_MODEL_PARSER(fmk, parserCreator) \
static mindspore::lite::registry::ModelParserRegistry g_##type##fmk##ModelParserReg(fmk, parserCreator); static mindspore::registry::ModelParserRegistry g_##type##fmk##ModelParserReg(fmk, parserCreator);
} // namespace registry } // namespace registry
} // namespace lite
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_LITE_INCLUDE_REGISTRY_MODEL_PARSER_REGISTRY_H #endif // MINDSPORE_LITE_INCLUDE_REGISTRY_MODEL_PARSER_REGISTRY_H

View File

@ -32,7 +32,6 @@ class MS_API Pass;
using PassPtr = std::shared_ptr<Pass>; using PassPtr = std::shared_ptr<Pass>;
} // namespace opt } // namespace opt
namespace lite {
namespace registry { namespace registry {
/// \brief PassPosition defined where to plae user's pass. /// \brief PassPosition defined where to plae user's pass.
enum MS_API PassPosition { POSITION_BEGIN = 0, POSITION_END = 1 }; enum MS_API PassPosition { POSITION_BEGIN = 0, POSITION_END = 1 };
@ -42,35 +41,48 @@ class MS_API PassRegistry {
public: public:
/// \brief Constructor of PassRegistry to register pass. /// \brief Constructor of PassRegistry to register pass.
/// ///
/// \param[in] pos Define where to replace the pass. /// \param[in] pass_name Define the name of the pass, a string which should guarantee uniqueness.
/// \param[in] pass Define user's defined pass. /// \param[in] pass Define pass instance.
PassRegistry(const std::string &pass_name, const opt::PassPtr &pass); PassRegistry(const std::string &pass_name, const opt::PassPtr &pass);
/// \brief Constructor of PassRegistry to assign which passes are required for external extension. /// \brief Constructor of PassRegistry to assign which passes are required for external extension.
/// ///
/// \param[in position Define the place where assigned passes will run. /// \param[in] position Define the place where assigned passes will run.
/// \param[in] assigned Define the name of passes assigned by user. /// \param[in] assigned Define the names of the passes.
PassRegistry(PassPosition position, const std::vector<std::string> &assigned); PassRegistry(PassPosition position, const std::vector<std::string> &assigned);
/// \brief Destructor of PassRegistrar. /// \brief Destructor of PassRegistrar.
~PassRegistry() = default; ~PassRegistry() = default;
/// \brief Static method to obtain external scheduling task assigned by user.
///
/// \param[in] position Define the place where assigned passes will run.
///
/// \return Passes' Name Vector.
static std::vector<std::string> GetOuterScheduleTask(PassPosition position);
/// \brief Static method to obtain pass instance according to passes' name.
///
/// \param[in] pass_names Define the name of passes.
///
/// \return Pass Instance Vector.
static std::vector<opt::PassPtr> GetPassFromStoreRoom(const std::vector<std::string> &pass_names);
}; };
/// \brief Defined registering macro to register Pass, which called by user directly. /// \brief Defined registering macro to register Pass, which called by user directly.
/// ///
/// \param[in] name Define name of user's pass, which is a string. /// \param[in] name Define the name of the pass, a string which should guarantee uniqueness.
/// \param[in] pass Define user's defined pass. /// \param[in] pass Define pass instance.
#define REG_PASS(name, pass) \ #define REG_PASS(name, pass) \
static mindspore::lite::registry::PassRegistry g_##name##PassReg(#name, std::make_shared<pass>()); static mindspore::registry::PassRegistry g_##name##PassReg(#name, std::make_shared<pass>());
/// \brief Defined assigning macro to assign Passes, which called by user directly. /// \brief Defined assigning macro to assign Passes, which called by user directly.
/// ///
/// \param[in] position Define the place where assigned passes will run. /// \param[in] position Define the place where assigned passes will run.
/// \param[in] assigned Define the name of passes assigned by user. /// \param[in] assigned Define the names of the passes.
#define REG_SCHEDULED_PASS(position, assigned) \ #define REG_SCHEDULED_PASS(position, assigned) \
static mindspore::lite::registry::PassRegistry g_##position(position, assigned); static mindspore::registry::PassRegistry g_##position(position, assigned);
} // namespace registry } // namespace registry
} // namespace lite
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_LITE_INCLUDE_REGISTRY_PASS_REGISTRY_H_ #endif // MINDSPORE_LITE_INCLUDE_REGISTRY_PASS_REGISTRY_H_

View File

@ -26,9 +26,9 @@
#include "include/api/types.h" #include "include/api/types.h"
#include "include/api/kernel.h" #include "include/api/kernel.h"
#include "include/api/data_type.h" #include "include/api/data_type.h"
#include "include/api/status.h"
namespace mindspore { namespace mindspore {
namespace lite {
namespace registry { namespace registry {
/// \brief KernelDesc defined kernel's basic attribute. /// \brief KernelDesc defined kernel's basic attribute.
struct KernelDesc { struct KernelDesc {
@ -61,9 +61,9 @@ class MS_API RegisterKernel {
/// \param[in] type Define the ordinary op type. /// \param[in] type Define the ordinary op type.
/// \param[in] creator Define a function pointer to create a kernel. /// \param[in] creator Define a function pointer to create a kernel.
/// ///
/// \return STATUS as an error code of registering, STATUS is defined in errorcode.h. /// \return Status as a status identification of registering.
static int RegKernel(const std::string &arch, const std::string &provider, DataType data_type, int type, static Status RegKernel(const std::string &arch, const std::string &provider, DataType data_type, int type,
CreateKernel creator); CreateKernel creator);
/// \brief Static method to register kernel which is corresponding to custom op. /// \brief Static method to register kernel which is corresponding to custom op.
/// ///
@ -73,9 +73,9 @@ class MS_API RegisterKernel {
/// \param[in] type Define the concrete type of a custom op. /// \param[in] type Define the concrete type of a custom op.
/// \param[in] creator Define a function pointer to create a kernel. /// \param[in] creator Define a function pointer to create a kernel.
/// ///
/// \return STATUS as an error code of registering, STATUS is defined in errorcode.h. /// \return Status as a status identification of registering.
static int RegCustomKernel(const std::string &arch, const std::string &provider, DataType data_type, static Status RegCustomKernel(const std::string &arch, const std::string &provider, DataType data_type,
const std::string &type, CreateKernel creator); const std::string &type, CreateKernel creator);
/// \brief Static methon to get a kernel's create function. /// \brief Static methon to get a kernel's create function.
/// ///
@ -124,11 +124,10 @@ class MS_API KernelReg {
/// \param[in] data_type Define kernel's input data type. /// \param[in] data_type Define kernel's input data type.
/// \param[in] op_type Define the ordinary op type. /// \param[in] op_type Define the ordinary op type.
/// \param[in] creator Define a function pointer to create a kernel. /// \param[in] creator Define a function pointer to create a kernel.
#define REGISTER_KERNEL(arch, provider, data_type, op_type, creator) \ #define REGISTER_KERNEL(arch, provider, data_type, op_type, creator) \
namespace { \ namespace { \
static mindspore::lite::registry::KernelReg g_##arch##provider##data_type##op_type##kernelReg(#arch, #provider, \ static mindspore::registry::KernelReg g_##arch##provider##data_type##op_type##kernelReg(#arch, #provider, data_type, \
data_type, op_type, \ op_type, creator); \
creator); \
} // namespace } // namespace
/// \brief Defined registering macro to register custom op kernel, which called by user directly. /// \brief Defined registering macro to register custom op kernel, which called by user directly.
@ -138,14 +137,12 @@ class MS_API KernelReg {
/// \param[in] data_type Define kernel's input data type. /// \param[in] data_type Define kernel's input data type.
/// \param[in] op_type Define the concrete type of a custom op. /// \param[in] op_type Define the concrete type of a custom op.
/// \param[in] creator Define a function pointer to create a kernel. /// \param[in] creator Define a function pointer to create a kernel.
#define REGISTER_CUSTOM_KERNEL(arch, provider, data_type, op_type, creator) \ #define REGISTER_CUSTOM_KERNEL(arch, provider, data_type, op_type, creator) \
namespace { \ namespace { \
static mindspore::lite::registry::KernelReg g_##arch##provider##data_type##op_type##kernelReg(#arch, #provider, \ static mindspore::registry::KernelReg g_##arch##provider##data_type##op_type##kernelReg(#arch, #provider, data_type, \
data_type, #op_type, \ #op_type, creator); \
creator); \
} // namespace } // namespace
} // namespace registry } // namespace registry
} // namespace lite
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_LITE_INCLUDE_REGISTRY_REGISTER_KERNEL_H_ #endif // MINDSPORE_LITE_INCLUDE_REGISTRY_REGISTER_KERNEL_H_

View File

@ -25,7 +25,6 @@
#include "schema/model_generated.h" #include "schema/model_generated.h"
namespace mindspore { namespace mindspore {
namespace lite {
namespace registry { namespace registry {
/// \brief KernelInterfaceCreator defined a functor to create KernelInterface. /// \brief KernelInterfaceCreator defined a functor to create KernelInterface.
using KernelInterfaceCreator = std::function<std::shared_ptr<kernel::KernelInterface>()>; using KernelInterfaceCreator = std::function<std::shared_ptr<kernel::KernelInterface>()>;
@ -39,8 +38,8 @@ class MS_API RegisterKernelInterface {
/// \param[in] op_type Define the concrete type of a custom op. /// \param[in] op_type Define the concrete type of a custom op.
/// \param[in] creator Define the KernelInterface create function. /// \param[in] creator Define the KernelInterface create function.
/// ///
/// \return STATUS as an error code of registering, STATUS is defined in errorcode.h. /// \return Status as a status identification of registering.
static int CustomReg(const std::string &provider, const std::string &op_type, KernelInterfaceCreator creator); static Status CustomReg(const std::string &provider, const std::string &op_type, KernelInterfaceCreator creator);
/// \brief Static method to register op whose primitive type is ordinary. /// \brief Static method to register op whose primitive type is ordinary.
/// ///
@ -48,8 +47,8 @@ class MS_API RegisterKernelInterface {
/// \param[in] op_type Define the ordinary op type. /// \param[in] op_type Define the ordinary op type.
/// \param[in] creator Define the KernelInterface create function. /// \param[in] creator Define the KernelInterface create function.
/// ///
/// \return STATUS as an error code of registering, STATUS is defined in errorcode.h. /// \return Status as a status identification of registering.
static int Reg(const std::string &provider, int op_type, KernelInterfaceCreator creator); static Status Reg(const std::string &provider, int op_type, KernelInterfaceCreator creator);
/// \brief Static method to get registration of a certain op. /// \brief Static method to get registration of a certain op.
/// ///
@ -88,9 +87,9 @@ class MS_API KernelInterfaceReg {
/// \param[in] provider Define the identification of user. /// \param[in] provider Define the identification of user.
/// \param[in] op_type Define the ordinary op type. /// \param[in] op_type Define the ordinary op type.
/// \param[in] creator Define the KernelInterface create function. /// \param[in] creator Define the KernelInterface create function.
#define REGISTER_KERNEL_INTERFACE(provider, op_type, creator) \ #define REGISTER_KERNEL_INTERFACE(provider, op_type, creator) \
namespace { \ namespace { \
static mindspore::lite::registry::KernelInterfaceReg g_##provider##op_type##_inter_reg(#provider, op_type, creator); \ static mindspore::registry::KernelInterfaceReg g_##provider##op_type##_inter_reg(#provider, op_type, creator); \
} // namespace } // namespace
/// \brief Defined registering macro to register custom op, which called by user directly. /// \brief Defined registering macro to register custom op, which called by user directly.
@ -98,13 +97,12 @@ class MS_API KernelInterfaceReg {
/// \param[in] provider Define the identification of user. /// \param[in] provider Define the identification of user.
/// \param[in] op_type Define the concrete type of a custom op. /// \param[in] op_type Define the concrete type of a custom op.
/// \param[in] creator Define the KernelInterface create function. /// \param[in] creator Define the KernelInterface create function.
#define REGISTER_CUSTOM_KERNEL_INTERFACE(provider, op_type, creator) \ #define REGISTER_CUSTOM_KERNEL_INTERFACE(provider, op_type, creator) \
namespace { \ namespace { \
static mindspore::lite::registry::KernelInterfaceReg g_##provider##op_type##_custom_inter_reg(#provider, #op_type, \ static mindspore::registry::KernelInterfaceReg g_##provider##op_type##_custom_inter_reg(#provider, #op_type, \
creator); \ creator); \
} // namespace } // namespace
} // namespace registry } // namespace registry
} // namespace lite
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_LITE_INCLUDE_REGISTRY_KERNEL_INTERFACE_H_ #endif // MINDSPORE_LITE_INCLUDE_REGISTRY_KERNEL_INTERFACE_H_

View File

@ -39,8 +39,8 @@ using mindspore::kernel::KERNEL_ARCH;
using mindspore::kernel::KernelCreator; using mindspore::kernel::KernelCreator;
using mindspore::kernel::KernelKey; using mindspore::kernel::KernelKey;
#ifdef ENABLE_CUSTOM_KERNEL_REGISTRY #ifdef ENABLE_CUSTOM_KERNEL_REGISTRY
using mindspore::lite::registry::CreateKernel; using mindspore::registry::CreateKernel;
using mindspore::lite::registry::KernelDesc; using mindspore::registry::KernelDesc;
#endif #endif
namespace mindspore::lite { namespace mindspore::lite {

View File

@ -21,11 +21,11 @@
#include "src/common/version_manager.h" #include "src/common/version_manager.h"
#include "schema/model_generated.h" #include "schema/model_generated.h"
using mindspore::lite::registry::KernelInterfaceCreator; using mindspore::registry::KernelInterfaceCreator;
using mindspore::schema::PrimitiveType_MAX; using mindspore::schema::PrimitiveType_MAX;
using mindspore::schema::PrimitiveType_MIN; using mindspore::schema::PrimitiveType_MIN;
namespace mindspore { namespace mindspore {
namespace lite { namespace registry {
namespace { namespace {
static const auto kMaxKernelNum = PrimitiveType_MAX - PrimitiveType_MIN; static const auto kMaxKernelNum = PrimitiveType_MAX - PrimitiveType_MIN;
std::string GetCustomType(const schema::Primitive *primitive) { std::string GetCustomType(const schema::Primitive *primitive) {
@ -35,10 +35,10 @@ std::string GetCustomType(const schema::Primitive *primitive) {
} }
} // namespace } // namespace
int KernelInterfaceRegistry::CustomReg(const std::string &provider, const std::string &type, Status KernelInterfaceRegistry::CustomReg(const std::string &provider, const std::string &type,
KernelInterfaceCreator creator) { KernelInterfaceCreator creator) {
custom_creators_[provider][type] = creator; custom_creators_[provider][type] = creator;
return RET_OK; return kSuccess;
} }
std::shared_ptr<kernel::KernelInterface> KernelInterfaceRegistry::GetCacheInterface(const std::string &provider, std::shared_ptr<kernel::KernelInterface> KernelInterfaceRegistry::GetCacheInterface(const std::string &provider,
@ -124,10 +124,10 @@ std::shared_ptr<kernel::KernelInterface> KernelInterfaceRegistry::GetKernelInter
return nullptr; return nullptr;
} }
int KernelInterfaceRegistry::Reg(const std::string &provider, int op_type, KernelInterfaceCreator creator) { Status KernelInterfaceRegistry::Reg(const std::string &provider, int op_type, KernelInterfaceCreator creator) {
if (op_type < PrimitiveType_MIN || op_type > kMaxKernelNum) { if (op_type < PrimitiveType_MIN || op_type > kMaxKernelNum) {
MS_LOG(ERROR) << "reg op_type invalid!op_type: " << op_type << ", max value: " << kMaxKernelNum; MS_LOG(ERROR) << "reg op_type invalid!op_type: " << op_type << ", max value: " << kMaxKernelNum;
return RET_ERROR; return kLiteError;
} }
std::unique_lock<std::mutex> lock(mutex_); std::unique_lock<std::mutex> lock(mutex_);
@ -137,12 +137,12 @@ int KernelInterfaceRegistry::Reg(const std::string &provider, int op_type, Kerne
reinterpret_cast<KernelInterfaceCreator *>(calloc(kMaxKernelNum, sizeof(KernelInterfaceCreator))); reinterpret_cast<KernelInterfaceCreator *>(calloc(kMaxKernelNum, sizeof(KernelInterfaceCreator)));
if (kernel_creators_[provider] == nullptr) { if (kernel_creators_[provider] == nullptr) {
MS_LOG(ERROR) << "malloc kernel dev delegate creator fail!"; MS_LOG(ERROR) << "malloc kernel dev delegate creator fail!";
return RET_ERROR; return kLiteError;
} }
} }
kernel_creators_[provider][op_type] = creator; kernel_creators_[provider][op_type] = creator;
return RET_OK; return kSuccess;
} }
KernelInterfaceRegistry::~KernelInterfaceRegistry() { KernelInterfaceRegistry::~KernelInterfaceRegistry() {
@ -151,5 +151,5 @@ KernelInterfaceRegistry::~KernelInterfaceRegistry() {
item.second = nullptr; item.second = nullptr;
} }
} }
} // namespace lite } // namespace registry
} // namespace mindspore } // namespace mindspore

View File

@ -26,7 +26,7 @@
#include "include/model.h" #include "include/model.h"
namespace mindspore { namespace mindspore {
namespace lite { namespace registry {
class KernelInterfaceRegistry { class KernelInterfaceRegistry {
public: public:
static KernelInterfaceRegistry *Instance() { static KernelInterfaceRegistry *Instance() {
@ -36,8 +36,8 @@ class KernelInterfaceRegistry {
std::shared_ptr<kernel::KernelInterface> GetKernelInterface(const std::string &provider, std::shared_ptr<kernel::KernelInterface> GetKernelInterface(const std::string &provider,
const schema::Primitive *primitive); const schema::Primitive *primitive);
int CustomReg(const std::string &provider, const std::string &op_type, registry::KernelInterfaceCreator creator); Status CustomReg(const std::string &provider, const std::string &op_type, registry::KernelInterfaceCreator creator);
int Reg(const std::string &provider, int op_type, registry::KernelInterfaceCreator creator); Status Reg(const std::string &provider, int op_type, registry::KernelInterfaceCreator creator);
virtual ~KernelInterfaceRegistry(); virtual ~KernelInterfaceRegistry();
private: private:
@ -55,7 +55,7 @@ class KernelInterfaceRegistry {
std::map<std::string, std::map<std::string, registry::KernelInterfaceCreator>> custom_creators_; std::map<std::string, std::map<std::string, registry::KernelInterfaceCreator>> custom_creators_;
std::map<std::string, std::map<std::string, std::shared_ptr<kernel::KernelInterface>>> custom_kernels_; std::map<std::string, std::map<std::string, std::shared_ptr<kernel::KernelInterface>>> custom_kernels_;
}; };
} // namespace lite } // namespace registry
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_LITE_SRC_REGISTRY_KERNEL_INTERFACE_REGISTRY_H_ #endif // MINDSPORE_LITE_SRC_REGISTRY_KERNEL_INTERFACE_REGISTRY_H_

View File

@ -21,22 +21,21 @@
#include "src/registry/register_kernel_impl.h" #include "src/registry/register_kernel_impl.h"
namespace mindspore { namespace mindspore {
namespace lite {
namespace registry { namespace registry {
int RegisterKernel::RegCustomKernel(const std::string &arch, const std::string &provider, DataType data_type, Status RegisterKernel::RegCustomKernel(const std::string &arch, const std::string &provider, DataType data_type,
const std::string &type, CreateKernel creator) { const std::string &type, CreateKernel creator) {
#ifdef ENABLE_CUSTOM_KERNEL_REGISTRY #ifdef ENABLE_CUSTOM_KERNEL_REGISTRY
return lite::RegistryKernelImpl::GetInstance()->RegCustomKernel(arch, provider, data_type, type, creator); return RegistryKernelImpl::GetInstance()->RegCustomKernel(arch, provider, data_type, type, creator);
#else #else
MS_LOG(ERROR) << unsupport_custom_kernel_register_log; MS_LOG(ERROR) << unsupport_custom_kernel_register_log;
return lite::RET_NOT_SUPPORT; return lite::RET_NOT_SUPPORT;
#endif #endif
} }
int RegisterKernel::RegKernel(const std::string &arch, const std::string &provider, DataType data_type, int op_type, Status RegisterKernel::RegKernel(const std::string &arch, const std::string &provider, DataType data_type, int op_type,
CreateKernel creator) { CreateKernel creator) {
#ifdef ENABLE_CUSTOM_KERNEL_REGISTRY #ifdef ENABLE_CUSTOM_KERNEL_REGISTRY
return lite::RegistryKernelImpl::GetInstance()->RegKernel(arch, provider, data_type, op_type, creator); return RegistryKernelImpl::GetInstance()->RegKernel(arch, provider, data_type, op_type, creator);
#else #else
MS_LOG(ERROR) << unsupport_custom_kernel_register_log; MS_LOG(ERROR) << unsupport_custom_kernel_register_log;
return lite::RET_NOT_SUPPORT; return lite::RET_NOT_SUPPORT;
@ -45,12 +44,11 @@ int RegisterKernel::RegKernel(const std::string &arch, const std::string &provid
CreateKernel RegisterKernel::GetCreator(const schema::Primitive *primitive, KernelDesc *desc) { CreateKernel RegisterKernel::GetCreator(const schema::Primitive *primitive, KernelDesc *desc) {
#ifdef ENABLE_CUSTOM_KERNEL_REGISTRY #ifdef ENABLE_CUSTOM_KERNEL_REGISTRY
return lite::RegistryKernelImpl::GetInstance()->GetProviderCreator(primitive, desc); return RegistryKernelImpl::GetInstance()->GetProviderCreator(primitive, desc);
#else #else
MS_LOG(ERROR) << unsuppor_custom_kernel_register_log; MS_LOG(ERROR) << unsuppor_custom_kernel_register_log;
return lite::RET_NOT_SUPPORT; return lite::RET_NOT_SUPPORT;
#endif #endif
} }
} // namespace registry } // namespace registry
} // namespace lite
} // namespace mindspore } // namespace mindspore

View File

@ -18,12 +18,12 @@
#include "include/errorcode.h" #include "include/errorcode.h"
#include "src/common/version_manager.h" #include "src/common/version_manager.h"
#include "src/common/log_adapter.h" #include "src/common/log_adapter.h"
using mindspore::lite::registry::CreateKernel;
using mindspore::lite::registry::KernelDesc; using mindspore::registry::CreateKernel;
using mindspore::registry::KernelDesc;
using mindspore::schema::PrimitiveType_MAX; using mindspore::schema::PrimitiveType_MAX;
using mindspore::schema::PrimitiveType_MIN; using mindspore::schema::PrimitiveType_MIN;
namespace mindspore::registry {
namespace mindspore::lite {
namespace { namespace {
static const auto kKernelMaxNum = static const auto kKernelMaxNum =
(static_cast<int>(DataType::kNumberTypeEnd) - static_cast<int>(DataType::kNumberTypeBegin) - 1) * (static_cast<int>(DataType::kNumberTypeEnd) - static_cast<int>(DataType::kNumberTypeBegin) - 1) *
@ -44,11 +44,11 @@ int RegistryKernelImpl::GetFuncIndex(const KernelDesc &desc) {
return data_type_index * kOpTypeLen + desc.type; return data_type_index * kOpTypeLen + desc.type;
} }
int RegistryKernelImpl::RegCustomKernel(const std::string &arch, const std::string &provider, DataType data_type, Status RegistryKernelImpl::RegCustomKernel(const std::string &arch, const std::string &provider, DataType data_type,
const std::string &type, CreateKernel creator) { const std::string &type, CreateKernel creator) {
if (data_type >= DataType::kNumberTypeEnd) { if (data_type >= DataType::kNumberTypeEnd) {
MS_LOG(ERROR) << "invalid data_type: " << static_cast<int>(data_type) << "!provider: " << provider; MS_LOG(ERROR) << "invalid data_type: " << static_cast<int>(data_type) << "!provider: " << provider;
return RET_ERROR; return kLiteError;
} }
std::unique_lock<std::mutex> lock(lock_); std::unique_lock<std::mutex> lock(lock_);
if (custom_kernel_creators_[provider][arch][type] == nullptr) { if (custom_kernel_creators_[provider][arch][type] == nullptr) {
@ -56,28 +56,28 @@ int RegistryKernelImpl::RegCustomKernel(const std::string &arch, const std::stri
reinterpret_cast<CreateKernel *>(calloc(kDataTypeLen, sizeof(CreateKernel))); reinterpret_cast<CreateKernel *>(calloc(kDataTypeLen, sizeof(CreateKernel)));
if (custom_kernel_creators_[provider][arch][type] == nullptr) { if (custom_kernel_creators_[provider][arch][type] == nullptr) {
MS_LOG(ERROR) << "malloc custom kernel creator fail!provider: " << provider << ", arch: " << arch; MS_LOG(ERROR) << "malloc custom kernel creator fail!provider: " << provider << ", arch: " << arch;
return RET_ERROR; return kLiteError;
} }
} }
int data_type_index = static_cast<int>(data_type) - static_cast<int>(DataType::kNumberTypeBegin) - 1; int data_type_index = static_cast<int>(data_type) - static_cast<int>(DataType::kNumberTypeBegin) - 1;
if (data_type_index < 0 || data_type_index >= kDataTypeLen) { if (data_type_index < 0 || data_type_index >= kDataTypeLen) {
MS_LOG(ERROR) << "invalid data_type: " << static_cast<int>(data_type) << "!provider: " << provider; MS_LOG(ERROR) << "invalid data_type: " << static_cast<int>(data_type) << "!provider: " << provider;
return RET_ERROR; return kLiteError;
} }
custom_kernel_creators_[provider][arch][type][data_type_index] = creator; custom_kernel_creators_[provider][arch][type][data_type_index] = creator;
return RET_OK; return kSuccess;
} }
int RegistryKernelImpl::RegKernel(const std::string &arch, const std::string &provider, DataType data_type, int type, Status RegistryKernelImpl::RegKernel(const std::string &arch, const std::string &provider, DataType data_type, int type,
registry::CreateKernel creator) { registry::CreateKernel creator) {
std::unique_lock<std::mutex> lock(lock_); std::unique_lock<std::mutex> lock(lock_);
auto iter = kernel_creators_.find(provider); auto iter = kernel_creators_.find(provider);
if (iter == kernel_creators_.end()) { if (iter == kernel_creators_.end()) {
kernel_creators_[provider][arch] = reinterpret_cast<CreateKernel *>(calloc(kKernelMaxNum, sizeof(CreateKernel))); kernel_creators_[provider][arch] = reinterpret_cast<CreateKernel *>(calloc(kKernelMaxNum, sizeof(CreateKernel)));
if (kernel_creators_[provider][arch] == nullptr) { if (kernel_creators_[provider][arch] == nullptr) {
MS_LOG(ERROR) << "malloc kernel creator buffer fail! provider: " << provider << ",arch:" << arch; MS_LOG(ERROR) << "malloc kernel creator buffer fail! provider: " << provider << ",arch:" << arch;
return RET_ERROR; return kLiteError;
} }
} else { } else {
auto iter_arch = iter->second.find(arch); auto iter_arch = iter->second.find(arch);
@ -85,7 +85,7 @@ int RegistryKernelImpl::RegKernel(const std::string &arch, const std::string &pr
iter->second[arch] = reinterpret_cast<CreateKernel *>(calloc(kKernelMaxNum, sizeof(CreateKernel))); iter->second[arch] = reinterpret_cast<CreateKernel *>(calloc(kKernelMaxNum, sizeof(CreateKernel)));
if (iter->second[arch] == nullptr) { if (iter->second[arch] == nullptr) {
MS_LOG(ERROR) << "malloc kernel creator buffer fail! provider: " << provider << ",arch:" << arch; MS_LOG(ERROR) << "malloc kernel creator buffer fail! provider: " << provider << ",arch:" << arch;
return RET_ERROR; return kLiteError;
} }
} }
} }
@ -95,11 +95,11 @@ int RegistryKernelImpl::RegKernel(const std::string &arch, const std::string &pr
if (index >= kKernelMaxNum || index < 0) { if (index >= kKernelMaxNum || index < 0) {
MS_LOG(ERROR) << "invalid kernel key, arch " << arch << ", data_type" << static_cast<int>(data_type) << ",op type " MS_LOG(ERROR) << "invalid kernel key, arch " << arch << ", data_type" << static_cast<int>(data_type) << ",op type "
<< type; << type;
return RET_ERROR; return kLiteError;
} }
kernel_creators_[provider][arch][index] = creator; kernel_creators_[provider][arch][index] = creator;
return RET_OK; return kSuccess;
} }
registry::CreateKernel RegistryKernelImpl::GetCustomKernelCreator(const schema::Primitive *primitive, registry::CreateKernel RegistryKernelImpl::GetCustomKernelCreator(const schema::Primitive *primitive,
@ -179,4 +179,4 @@ RegistryKernelImpl::~RegistryKernelImpl() {
} }
} }
} }
} // namespace mindspore::lite } // namespace mindspore::registry

View File

@ -25,7 +25,7 @@
#include <set> #include <set>
#include "include/registry/register_kernel.h" #include "include/registry/register_kernel.h"
namespace mindspore::lite { namespace mindspore::registry {
class RegistryKernelImpl { class RegistryKernelImpl {
public: public:
RegistryKernelImpl() = default; RegistryKernelImpl() = default;
@ -36,11 +36,11 @@ class RegistryKernelImpl {
return &instance; return &instance;
} }
int RegCustomKernel(const std::string &arch, const std::string &provider, DataType data_type, const std::string &type, Status RegCustomKernel(const std::string &arch, const std::string &provider, DataType data_type,
registry::CreateKernel creator); const std::string &type, registry::CreateKernel creator);
int RegKernel(const std::string &arch, const std::string &provider, DataType data_type, int type, Status RegKernel(const std::string &arch, const std::string &provider, DataType data_type, int type,
registry::CreateKernel creator); registry::CreateKernel creator);
virtual registry::CreateKernel GetProviderCreator(const schema::Primitive *primitive, registry::KernelDesc *desc); virtual registry::CreateKernel GetProviderCreator(const schema::Primitive *primitive, registry::KernelDesc *desc);
@ -60,6 +60,6 @@ class RegistryKernelImpl {
registry::CreateKernel GetCustomKernelCreator(const schema::Primitive *primitive, registry::KernelDesc *desc); registry::CreateKernel GetCustomKernelCreator(const schema::Primitive *primitive, registry::KernelDesc *desc);
int GetFuncIndex(const registry::KernelDesc &desc); int GetFuncIndex(const registry::KernelDesc &desc);
}; };
} // namespace mindspore::lite } // namespace mindspore::registry
#endif // MINDSPORE_LITE_SRC_REGISTRY_REGISTER_KERNEL_IMPL_H_ #endif // MINDSPORE_LITE_SRC_REGISTRY_REGISTER_KERNEL_IMPL_H_

View File

@ -21,21 +21,20 @@
#include "src/registry/kernel_interface_registry.h" #include "src/registry/kernel_interface_registry.h"
namespace mindspore { namespace mindspore {
namespace lite {
namespace registry { namespace registry {
int RegisterKernelInterface::Reg(const std::string &provider, int op_type, KernelInterfaceCreator creator) { Status RegisterKernelInterface::Reg(const std::string &provider, int op_type, KernelInterfaceCreator creator) {
#ifdef ENABLE_CUSTOM_KERNEL_REGISTRY #ifdef ENABLE_CUSTOM_KERNEL_REGISTRY
return lite::KernelInterfaceRegistry::Instance()->Reg(provider, op_type, creator); return KernelInterfaceRegistry::Instance()->Reg(provider, op_type, creator);
#else #else
MS_LOG(ERROR) << unsupport_custom_kernel_register_log; MS_LOG(ERROR) << unsupport_custom_kernel_register_log;
return lite::RET_NOT_SUPPORT; return lite::RET_NOT_SUPPORT;
#endif #endif
} }
int RegisterKernelInterface::CustomReg(const std::string &provider, const std::string &op_type, Status RegisterKernelInterface::CustomReg(const std::string &provider, const std::string &op_type,
KernelInterfaceCreator creator) { KernelInterfaceCreator creator) {
#ifdef ENABLE_CUSTOM_KERNEL_REGISTRY #ifdef ENABLE_CUSTOM_KERNEL_REGISTRY
return lite::KernelInterfaceRegistry::Instance()->CustomReg(provider, op_type, creator); return KernelInterfaceRegistry::Instance()->CustomReg(provider, op_type, creator);
#else #else
MS_LOG(ERROR) << unsupport_custom_kernel_register_log; MS_LOG(ERROR) << unsupport_custom_kernel_register_log;
return lite::RET_NOT_SUPPORT; return lite::RET_NOT_SUPPORT;
@ -45,12 +44,11 @@ int RegisterKernelInterface::CustomReg(const std::string &provider, const std::s
std::shared_ptr<kernel::KernelInterface> RegisterKernelInterface::GetKernelInterface( std::shared_ptr<kernel::KernelInterface> RegisterKernelInterface::GetKernelInterface(
const std::string &provider, const schema::Primitive *primitive) { const std::string &provider, const schema::Primitive *primitive) {
#ifdef ENABLE_CUSTOM_KERNEL_REGISTRY #ifdef ENABLE_CUSTOM_KERNEL_REGISTRY
return lite::KernelInterfaceRegistry::Instance()->GetKernelInterface(provider, primitive); return KernelInterfaceRegistry::Instance()->GetKernelInterface(provider, primitive);
#else #else
MS_LOG(ERROR) << unsupport_custom_kernel_register_log; MS_LOG(ERROR) << unsupport_custom_kernel_register_log;
return nullptr; return nullptr;
#endif #endif
} }
} // namespace registry } // namespace registry
} // namespace lite
} // namespace mindspore } // namespace mindspore

View File

@ -60,9 +60,12 @@ int KernelInferShape(const std::vector<lite::Tensor *> &inputs, const std::vecto
std::transform(outputs.begin(), outputs.end(), std::back_inserter(out_tensors), std::transform(outputs.begin(), outputs.end(), std::back_inserter(out_tensors),
[](lite::Tensor *tensor) { return mindspore::MSTensor(std::make_shared<MSTensor::Impl>(tensor)); }); [](lite::Tensor *tensor) { return mindspore::MSTensor(std::make_shared<MSTensor::Impl>(tensor)); });
auto ret = kernel_interface->Infer(&in_tensors, &out_tensors, static_cast<const schema::Primitive *>(primitive)); auto ret = kernel_interface->Infer(&in_tensors, &out_tensors, static_cast<const schema::Primitive *>(primitive));
if (ret != RET_OK) { if (ret == kLiteInferInvalid) {
return RET_INFER_INVALID;
}
if (ret != kSuccess) {
MS_LOG(ERROR) << "op_type: " << PrimitiveTypeName(prim_type) << " infer fail!ret: " << ret; MS_LOG(ERROR) << "op_type: " << PrimitiveTypeName(prim_type) << " infer fail!ret: " << ret;
return ret; return RET_ERROR;
} }
return RET_OK; return RET_OK;
} }

View File

@ -111,12 +111,12 @@ class TestCustomOpInfer : public KernelInterface {
public: public:
TestCustomOpInfer() = default; TestCustomOpInfer() = default;
~TestCustomOpInfer() = default; ~TestCustomOpInfer() = default;
int Infer(std::vector<mindspore::MSTensor> *inputs, std::vector<mindspore::MSTensor> *outputs, Status Infer(std::vector<mindspore::MSTensor> *inputs, std::vector<mindspore::MSTensor> *outputs,
const schema::Primitive *primitive) override { const schema::Primitive *primitive) override {
(*outputs)[0].SetFormat((*inputs)[0].format()); (*outputs)[0].SetFormat((*inputs)[0].format());
(*outputs)[0].SetDataType((*inputs)[0].DataType()); (*outputs)[0].SetDataType((*inputs)[0].DataType());
(*outputs)[0].SetShape((*inputs)[0].Shape()); (*outputs)[0].SetShape((*inputs)[0].Shape());
return RET_OK; return kSuccess;
} }
}; };

View File

@ -80,12 +80,12 @@ class TestCustomAddInfer : public KernelInterface {
public: public:
TestCustomAddInfer() = default; TestCustomAddInfer() = default;
~TestCustomAddInfer() = default; ~TestCustomAddInfer() = default;
int Infer(std::vector<mindspore::MSTensor> *inputs, std::vector<mindspore::MSTensor> *outputs, Status Infer(std::vector<mindspore::MSTensor> *inputs, std::vector<mindspore::MSTensor> *outputs,
const schema::Primitive *primitive) override { const schema::Primitive *primitive) override {
(*outputs)[0].SetFormat((*inputs)[0].format()); (*outputs)[0].SetFormat((*inputs)[0].format());
(*outputs)[0].SetDataType((*inputs)[0].DataType()); (*outputs)[0].SetDataType((*inputs)[0].DataType());
(*outputs)[0].SetShape((*inputs)[0].Shape()); (*outputs)[0].SetShape((*inputs)[0].Shape());
return RET_OK; return kSuccess;
} }
}; };

View File

@ -35,7 +35,7 @@ TEST_F(ModelParserRegistryTest, TestRegistry) {
ASSERT_NE(proposal_parser, nullptr); ASSERT_NE(proposal_parser, nullptr);
REG_MODEL_PARSER(FmkType_CAFFE, REG_MODEL_PARSER(FmkType_CAFFE,
TestModelParserCreator); // register test model parser creator, which will overwrite existing. TestModelParserCreator); // register test model parser creator, which will overwrite existing.
auto model_parser = lite::registry::ModelParserRegistry::GetModelParser(FmkType_CAFFE); auto model_parser = registry::ModelParserRegistry::GetModelParser(FmkType_CAFFE);
ASSERT_NE(model_parser, nullptr); ASSERT_NE(model_parser, nullptr);
ConverterParameters converter_parameters; ConverterParameters converter_parameters;
auto func_graph = model_parser->Parse(converter_parameters); auto func_graph = model_parser->Parse(converter_parameters);

View File

@ -25,20 +25,19 @@
#include "ops/addn.h" #include "ops/addn.h"
#include "ops/custom.h" #include "ops/custom.h"
#include "tools/converter/model_parser.h" #include "tools/converter/model_parser.h"
#include "tools/converter/registry/pass_content.h"
#include "tools/optimizer/common/gllo_utils.h" #include "tools/optimizer/common/gllo_utils.h"
#include "ut/tools/converter/registry/model_parser_test.h" #include "ut/tools/converter/registry/model_parser_test.h"
using mindspore::converter::ConverterParameters; using mindspore::converter::ConverterParameters;
using mindspore::converter::FmkType_CAFFE; using mindspore::converter::FmkType_CAFFE;
using mindspore::lite::registry::POSITION_BEGIN; using mindspore::registry::POSITION_BEGIN;
namespace mindspore { namespace mindspore {
class PassRegistryTest : public mindspore::CommonTest { class PassRegistryTest : public mindspore::CommonTest {
public: public:
PassRegistryTest() = default; PassRegistryTest() = default;
void SetUp() override { void SetUp() override {
REG_MODEL_PARSER(FmkType_CAFFE, TestModelParserCreator); REG_MODEL_PARSER(FmkType_CAFFE, TestModelParserCreator);
auto model_parser = lite::registry::ModelParserRegistry::GetModelParser(FmkType_CAFFE); auto model_parser = registry::ModelParserRegistry::GetModelParser(FmkType_CAFFE);
if (model_parser == nullptr) { if (model_parser == nullptr) {
return; return;
} }
@ -214,12 +213,11 @@ REG_SCHEDULED_PASS(POSITION_BEGIN, {"TestFusion"})
} // namespace opt } // namespace opt
TEST_F(PassRegistryTest, TestRegistry) { TEST_F(PassRegistryTest, TestRegistry) {
auto &passes = lite::PassStoreRoomInfo(); auto schedule_task = registry::PassRegistry::GetOuterScheduleTask(POSITION_BEGIN);
auto &assigned_passes = lite::ExternalAssignedPassesInfo(); ASSERT_EQ(schedule_task.size(), 1);
ASSERT_EQ(assigned_passes.size(), 1); auto passes = registry::PassRegistry::GetPassFromStoreRoom(schedule_task);
auto pass_names = assigned_passes[POSITION_BEGIN]; ASSERT_EQ(passes.size(), 1);
ASSERT_EQ(pass_names.size(), 1); auto begin_pass = passes.front();
auto begin_pass = passes[pass_names.front()];
ASSERT_NE(begin_pass, nullptr); ASSERT_NE(begin_pass, nullptr);
auto begin_pass_test = std::dynamic_pointer_cast<opt::TestFusion>(begin_pass); auto begin_pass_test = std::dynamic_pointer_cast<opt::TestFusion>(begin_pass);
ASSERT_NE(begin_pass_test, nullptr); ASSERT_NE(begin_pass_test, nullptr);

View File

@ -326,7 +326,7 @@ FuncGraphPtr AnfTransform::TransformFuncGraph(const FuncGraphPtr &old_graph, con
return nullptr; return nullptr;
} }
if (!RunExternalPass(old_graph, registry::PassPosition::POSITION_BEGIN)) { if (!RunExternalPass(old_graph, registry::POSITION_BEGIN)) {
MS_LOG(ERROR) << "Run external pass failed, place is BEGIN"; MS_LOG(ERROR) << "Run external pass failed, place is BEGIN";
return nullptr; return nullptr;
} }
@ -361,7 +361,7 @@ FuncGraphPtr AnfTransform::TransformFuncGraph(const FuncGraphPtr &old_graph, con
return nullptr; return nullptr;
} }
if (!RunExternalPass(old_graph, registry::PassPosition::POSITION_END)) { if (!RunExternalPass(old_graph, registry::POSITION_END)) {
MS_LOG(ERROR) << "Run external pass failed, place is END"; MS_LOG(ERROR) << "Run external pass failed, place is END";
return nullptr; return nullptr;
} }

View File

@ -18,25 +18,26 @@
#include <string> #include <string>
#include <vector> #include <vector>
#include "backend/optimizer/common/pass.h" #include "backend/optimizer/common/pass.h"
#include "tools/converter/registry/pass_content.h"
namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
bool RunOptimizerPass(const FuncGraphPtr &func_graph, std::vector<std::string> pass_names) { bool RunOptimizerPass(const FuncGraphPtr &func_graph, const std::vector<std::string> &pass_names) {
if (func_graph == nullptr) { if (func_graph == nullptr) {
MS_LOG(ERROR) << "func graph is nullptr."; MS_LOG(ERROR) << "func graph is nullptr.";
return false; return false;
} }
auto &passes_info = PassStoreRoomInfo(); auto schedule_passes = registry::PassRegistry::GetPassFromStoreRoom(pass_names);
for (auto &name : pass_names) { if (schedule_passes.size() != pass_names.size()) {
if (passes_info.find(name) == passes_info.end()) { MS_LOG(ERROR) << "exited pass cannot be obtained.";
MS_LOG(ERROR) << "cannot find required pass."; return false;
return false; }
} int index = 0;
if (!passes_info[name]->Run(func_graph)) { for (auto &pass : schedule_passes) {
MS_LOG(ERROR) << "run pass failed, pass name is " << name; if (!pass->Run(func_graph)) {
MS_LOG(ERROR) << "run pass failed, pass name is " << pass_names[index];
return false; return false;
} }
++index;
} }
return true; return true;
} }
@ -46,21 +47,10 @@ bool RunExternalPass(const FuncGraphPtr &func_graph, registry::PassPosition posi
MS_LOG(ERROR) << "func graph is nullptr."; MS_LOG(ERROR) << "func graph is nullptr.";
return false; return false;
} }
auto &external_assigned = ExternalAssignedPassesInfo(); auto schedule_task = registry::PassRegistry::GetOuterScheduleTask(position);
if (external_assigned.find(position) == external_assigned.end()) { if (!RunOptimizerPass(func_graph, schedule_task)) {
MS_LOG(DEBUG) << "there is no external pass in current position, position is " << position; MS_LOG(ERROR) << "run external scheduled task failed.";
return true; return false;
}
auto &passes_info = PassStoreRoomInfo();
for (auto &name : external_assigned[position]) {
if (passes_info.find(name) == passes_info.end()) {
MS_LOG(ERROR) << "cannot find required pass.";
return false;
}
if (!passes_info[name]->Run(func_graph)) {
MS_LOG(ERROR) << "run pass failed, pass name is " << name;
return false;
}
} }
return true; return true;
} }

View File

@ -24,7 +24,7 @@
namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
bool RunOptimizerPass(const FuncGraphPtr &func_graph, std::vector<std::string> pass_names); bool RunOptimizerPass(const FuncGraphPtr &func_graph, const std::vector<std::string> &pass_names);
bool RunExternalPass(const FuncGraphPtr &func_graph, registry::PassPosition position); bool RunExternalPass(const FuncGraphPtr &func_graph, registry::PassPosition position);
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore

View File

@ -19,7 +19,6 @@
#include "src/common/log_adapter.h" #include "src/common/log_adapter.h"
namespace mindspore { namespace mindspore {
namespace lite {
namespace registry { namespace registry {
namespace { namespace {
std::map<FmkType, ModelParserCreator> model_parser_room; std::map<FmkType, ModelParserCreator> model_parser_room;
@ -42,5 +41,4 @@ converter::ModelParser *ModelParserRegistry::GetModelParser(FmkType fmk) {
return nullptr; return nullptr;
} }
} // namespace registry } // namespace registry
} // namespace lite
} // namespace mindspore } // namespace mindspore

View File

@ -1,32 +0,0 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_REGISTRY_PASS_CONTENT_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_REGISTRY_PASS_CONTENT_H
#include <map>
#include <string>
#include <vector>
#include "include/registry/pass_registry.h"
namespace mindspore {
namespace lite {
std::map<std::string, opt::PassPtr> &MS_API PassStoreRoomInfo();
std::map<registry::PassPosition, std::vector<std::string>> &MS_API ExternalAssignedPassesInfo();
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_REGISTRY_PASS_CONTENT_H

View File

@ -19,11 +19,10 @@
#include <mutex> #include <mutex>
#include <string> #include <string>
#include <vector> #include <vector>
#include "tools/converter/registry/pass_content.h"
#include "src/common/log_adapter.h" #include "src/common/log_adapter.h"
namespace mindspore { namespace mindspore {
namespace lite { namespace registry {
namespace { namespace {
std::map<std::string, opt::PassPtr> pass_store_room; std::map<std::string, opt::PassPtr> pass_store_room;
std::map<registry::PassPosition, std::vector<std::string>> external_assigned_passes; std::map<registry::PassPosition, std::vector<std::string>> external_assigned_passes;
@ -38,19 +37,27 @@ void RegPass(const std::string &pass_name, const opt::PassPtr &pass) {
} }
} // namespace } // namespace
registry::PassRegistry::PassRegistry(const std::string &pass_name, const opt::PassPtr &pass) { PassRegistry::PassRegistry(const std::string &pass_name, const opt::PassPtr &pass) { RegPass(pass_name, pass); }
RegPass(pass_name, pass);
}
registry::PassRegistry::PassRegistry(PassPosition position, const std::vector<std::string> &assigned) { PassRegistry::PassRegistry(PassPosition position, const std::vector<std::string> &assigned) {
std::unique_lock<std::mutex> lock(pass_mutex); std::unique_lock<std::mutex> lock(pass_mutex);
external_assigned_passes[position] = assigned; external_assigned_passes[position] = assigned;
} }
std::map<std::string, opt::PassPtr> &PassStoreRoomInfo() { return pass_store_room; } std::vector<std::string> PassRegistry::GetOuterScheduleTask(PassPosition position) {
return external_assigned_passes[position];
std::map<registry::PassPosition, std::vector<std::string>> &ExternalAssignedPassesInfo() {
return external_assigned_passes;
} }
} // namespace lite
std::vector<opt::PassPtr> PassRegistry::GetPassFromStoreRoom(const std::vector<std::string> &pass_names) {
std::vector<opt::PassPtr> schedule_passes;
for (auto &name : pass_names) {
auto iter = pass_store_room.find(name);
if (iter == pass_store_room.end()) {
continue;
}
schedule_passes.push_back(iter->second);
}
return schedule_passes;
}
} // namespace registry
} // namespace mindspore } // namespace mindspore