forked from mindspore-Ecosystem/mindspore
change namespace mindspore::registry to mindspore::registry
This commit is contained in:
parent
62a780f20f
commit
df3db5196b
|
@ -28,17 +28,19 @@ class CustomAddInfer : public kernel::KernelInterface {
|
|||
CustomAddInfer() = default;
|
||||
~CustomAddInfer() = default;
|
||||
|
||||
int Infer(std::vector<mindspore::MSTensor> *inputs, std::vector<mindspore::MSTensor> *outputs,
|
||||
const schema::Primitive *primitive) override {
|
||||
Status Infer(std::vector<mindspore::MSTensor> *inputs, std::vector<mindspore::MSTensor> *outputs,
|
||||
const schema::Primitive *primitive) override {
|
||||
(*outputs)[0].SetFormat((*inputs)[0].format());
|
||||
(*outputs)[0].SetDataType((*inputs)[0].DataType());
|
||||
auto ret = common::CheckInputs(*inputs);
|
||||
if (ret != lite::RET_OK) {
|
||||
if (ret == lite::RET_INFER_INVALID) {
|
||||
(*outputs)[0].SetShape({-1}); // shape{-1} shows that shape need to be inferred when running.
|
||||
return ret;
|
||||
return kLiteInferInvalid;
|
||||
} else if (ret != lite::RET_OK) {
|
||||
return kLiteError;
|
||||
}
|
||||
(*outputs)[0].SetShape((*inputs)[0].Shape());
|
||||
return lite::RET_OK;
|
||||
return kSuccess;
|
||||
}
|
||||
};
|
||||
std::shared_ptr<kernel::KernelInterface> CustomAddInferCreator() { return std::make_shared<CustomAddInfer>(); }
|
||||
|
|
|
@ -98,7 +98,7 @@ bool PassTutorial::Run(const FuncGraphPtr &func_graph) {
|
|||
|
||||
namespace lite {
|
||||
// register customed Pass
|
||||
using mindspore::lite::registry::POSITION_BEGIN;
|
||||
using mindspore::registry::POSITION_BEGIN;
|
||||
REG_PASS(PassTutorial, opt::PassTutorial)
|
||||
REG_SCHEDULED_PASS(POSITION_BEGIN, {"PassTutorial"})
|
||||
} // namespace lite
|
||||
|
|
|
@ -28,17 +28,19 @@ class CustomAddInfer : public kernel::KernelInterface {
|
|||
CustomAddInfer() = default;
|
||||
~CustomAddInfer() = default;
|
||||
|
||||
int Infer(std::vector<mindspore::MSTensor> *inputs, std::vector<mindspore::MSTensor> *outputs,
|
||||
const schema::Primitive *primitive) override {
|
||||
Status Infer(std::vector<mindspore::MSTensor> *inputs, std::vector<mindspore::MSTensor> *outputs,
|
||||
const schema::Primitive *primitive) override {
|
||||
(*outputs)[0].SetFormat((*inputs)[0].format());
|
||||
(*outputs)[0].SetDataType((*inputs)[0].DataType());
|
||||
auto ret = common::CheckInputs(*inputs);
|
||||
if (ret != lite::RET_OK) {
|
||||
if (ret == lite::RET_INFER_INVALID) {
|
||||
(*outputs)[0].SetShape({-1}); // shape{-1} shows that shape need to be inferred when running.
|
||||
return ret;
|
||||
return kLiteInferInvalid;
|
||||
} else if (ret != lite::RET_OK) {
|
||||
return kLiteError;
|
||||
}
|
||||
(*outputs)[0].SetShape((*inputs)[0].Shape());
|
||||
return lite::RET_OK;
|
||||
return kSuccess;
|
||||
}
|
||||
};
|
||||
std::shared_ptr<kernel::KernelInterface> CustomAddInferCreator() { return std::make_shared<CustomAddInfer>(); }
|
||||
|
|
|
@ -60,13 +60,13 @@ class CustomAddKernel : public Kernel {
|
|||
// if output shape exists value -1, need to be inferred before applying memory for output tensor.
|
||||
int PreProcess() {
|
||||
if (common::CheckOutputs(outputs_) != lite::RET_OK) {
|
||||
auto ret = lite::registry::RegisterKernelInterface::GetKernelInterface({}, primitive_)
|
||||
->Infer(&inputs_, &outputs_, primitive_);
|
||||
if (ret != lite::RET_OK) {
|
||||
auto status =
|
||||
registry::RegisterKernelInterface::GetKernelInterface({}, primitive_)->Infer(&inputs_, &outputs_, primitive_);
|
||||
if (status != kSuccess) {
|
||||
std::cerr << "infer failed." << std::endl;
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
ret = ReSize();
|
||||
auto ret = ReSize();
|
||||
if (ret != lite::RET_OK) {
|
||||
std::cerr << "resize failed." << std::endl;
|
||||
return ret;
|
||||
|
|
|
@ -19,6 +19,7 @@
|
|||
#include <memory>
|
||||
#include <vector>
|
||||
#include "include/api/types.h"
|
||||
#include "include/api/status.h"
|
||||
#include "include/lite_utils.h"
|
||||
#include "schema/model_generated.h"
|
||||
|
||||
|
@ -36,10 +37,10 @@ class MS_API KernelInterface {
|
|||
/// \param[in] outputs Define the output tensors of op.
|
||||
/// \param[in] primitive Define the attributes of op.
|
||||
///
|
||||
/// \return STATUS as an error code of inferring, STATUS is defined in errorcode.h..
|
||||
virtual int Infer(std::vector<mindspore::MSTensor> *inputs, std::vector<mindspore::MSTensor> *outputs,
|
||||
const schema::Primitive *primitive) {
|
||||
return 0;
|
||||
/// \return Status as a status identification of inferring.
|
||||
virtual Status Infer(std::vector<mindspore::MSTensor> *inputs, std::vector<mindspore::MSTensor> *outputs,
|
||||
const schema::Primitive *primitive) {
|
||||
return kSuccess;
|
||||
}
|
||||
};
|
||||
} // namespace kernel
|
||||
|
|
|
@ -23,7 +23,6 @@
|
|||
|
||||
using mindspore::converter::FmkType;
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
namespace registry {
|
||||
/// \brief ModelParserCreator defined function pointer to get a ModelParser class.
|
||||
typedef converter::ModelParser *(*ModelParserCreator)();
|
||||
|
@ -53,9 +52,8 @@ class MS_API ModelParserRegistry {
|
|||
/// \param[in] fmk Define identification of a certain framework.
|
||||
/// \param[in] parserCreator Define function pointer of creating ModelParser.
|
||||
#define REG_MODEL_PARSER(fmk, parserCreator) \
|
||||
static mindspore::lite::registry::ModelParserRegistry g_##type##fmk##ModelParserReg(fmk, parserCreator);
|
||||
static mindspore::registry::ModelParserRegistry g_##type##fmk##ModelParserReg(fmk, parserCreator);
|
||||
} // namespace registry
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_LITE_INCLUDE_REGISTRY_MODEL_PARSER_REGISTRY_H
|
||||
|
|
|
@ -32,7 +32,6 @@ class MS_API Pass;
|
|||
using PassPtr = std::shared_ptr<Pass>;
|
||||
} // namespace opt
|
||||
|
||||
namespace lite {
|
||||
namespace registry {
|
||||
/// \brief PassPosition defined where to plae user's pass.
|
||||
enum MS_API PassPosition { POSITION_BEGIN = 0, POSITION_END = 1 };
|
||||
|
@ -42,35 +41,48 @@ class MS_API PassRegistry {
|
|||
public:
|
||||
/// \brief Constructor of PassRegistry to register pass.
|
||||
///
|
||||
/// \param[in] pos Define where to replace the pass.
|
||||
/// \param[in] pass Define user's defined pass.
|
||||
/// \param[in] pass_name Define the name of the pass, a string which should guarantee uniqueness.
|
||||
/// \param[in] pass Define pass instance.
|
||||
PassRegistry(const std::string &pass_name, const opt::PassPtr &pass);
|
||||
|
||||
/// \brief Constructor of PassRegistry to assign which passes are required for external extension.
|
||||
///
|
||||
/// \param[in position Define the place where assigned passes will run.
|
||||
/// \param[in] assigned Define the name of passes assigned by user.
|
||||
/// \param[in] position Define the place where assigned passes will run.
|
||||
/// \param[in] assigned Define the names of the passes.
|
||||
PassRegistry(PassPosition position, const std::vector<std::string> &assigned);
|
||||
|
||||
/// \brief Destructor of PassRegistrar.
|
||||
~PassRegistry() = default;
|
||||
|
||||
/// \brief Static method to obtain external scheduling task assigned by user.
|
||||
///
|
||||
/// \param[in] position Define the place where assigned passes will run.
|
||||
///
|
||||
/// \return Passes' Name Vector.
|
||||
static std::vector<std::string> GetOuterScheduleTask(PassPosition position);
|
||||
|
||||
/// \brief Static method to obtain pass instance according to passes' name.
|
||||
///
|
||||
/// \param[in] pass_names Define the name of passes.
|
||||
///
|
||||
/// \return Pass Instance Vector.
|
||||
static std::vector<opt::PassPtr> GetPassFromStoreRoom(const std::vector<std::string> &pass_names);
|
||||
};
|
||||
|
||||
/// \brief Defined registering macro to register Pass, which called by user directly.
|
||||
///
|
||||
/// \param[in] name Define name of user's pass, which is a string.
|
||||
/// \param[in] pass Define user's defined pass.
|
||||
/// \param[in] name Define the name of the pass, a string which should guarantee uniqueness.
|
||||
/// \param[in] pass Define pass instance.
|
||||
#define REG_PASS(name, pass) \
|
||||
static mindspore::lite::registry::PassRegistry g_##name##PassReg(#name, std::make_shared<pass>());
|
||||
static mindspore::registry::PassRegistry g_##name##PassReg(#name, std::make_shared<pass>());
|
||||
|
||||
/// \brief Defined assigning macro to assign Passes, which called by user directly.
|
||||
///
|
||||
/// \param[in] position Define the place where assigned passes will run.
|
||||
/// \param[in] assigned Define the name of passes assigned by user.
|
||||
/// \param[in] assigned Define the names of the passes.
|
||||
#define REG_SCHEDULED_PASS(position, assigned) \
|
||||
static mindspore::lite::registry::PassRegistry g_##position(position, assigned);
|
||||
static mindspore::registry::PassRegistry g_##position(position, assigned);
|
||||
} // namespace registry
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_LITE_INCLUDE_REGISTRY_PASS_REGISTRY_H_
|
||||
|
|
|
@ -26,9 +26,9 @@
|
|||
#include "include/api/types.h"
|
||||
#include "include/api/kernel.h"
|
||||
#include "include/api/data_type.h"
|
||||
#include "include/api/status.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
namespace registry {
|
||||
/// \brief KernelDesc defined kernel's basic attribute.
|
||||
struct KernelDesc {
|
||||
|
@ -61,9 +61,9 @@ class MS_API RegisterKernel {
|
|||
/// \param[in] type Define the ordinary op type.
|
||||
/// \param[in] creator Define a function pointer to create a kernel.
|
||||
///
|
||||
/// \return STATUS as an error code of registering, STATUS is defined in errorcode.h.
|
||||
static int RegKernel(const std::string &arch, const std::string &provider, DataType data_type, int type,
|
||||
CreateKernel creator);
|
||||
/// \return Status as a status identification of registering.
|
||||
static Status RegKernel(const std::string &arch, const std::string &provider, DataType data_type, int type,
|
||||
CreateKernel creator);
|
||||
|
||||
/// \brief Static method to register kernel which is corresponding to custom op.
|
||||
///
|
||||
|
@ -73,9 +73,9 @@ class MS_API RegisterKernel {
|
|||
/// \param[in] type Define the concrete type of a custom op.
|
||||
/// \param[in] creator Define a function pointer to create a kernel.
|
||||
///
|
||||
/// \return STATUS as an error code of registering, STATUS is defined in errorcode.h.
|
||||
static int RegCustomKernel(const std::string &arch, const std::string &provider, DataType data_type,
|
||||
const std::string &type, CreateKernel creator);
|
||||
/// \return Status as a status identification of registering.
|
||||
static Status RegCustomKernel(const std::string &arch, const std::string &provider, DataType data_type,
|
||||
const std::string &type, CreateKernel creator);
|
||||
|
||||
/// \brief Static methon to get a kernel's create function.
|
||||
///
|
||||
|
@ -124,11 +124,10 @@ class MS_API KernelReg {
|
|||
/// \param[in] data_type Define kernel's input data type.
|
||||
/// \param[in] op_type Define the ordinary op type.
|
||||
/// \param[in] creator Define a function pointer to create a kernel.
|
||||
#define REGISTER_KERNEL(arch, provider, data_type, op_type, creator) \
|
||||
namespace { \
|
||||
static mindspore::lite::registry::KernelReg g_##arch##provider##data_type##op_type##kernelReg(#arch, #provider, \
|
||||
data_type, op_type, \
|
||||
creator); \
|
||||
#define REGISTER_KERNEL(arch, provider, data_type, op_type, creator) \
|
||||
namespace { \
|
||||
static mindspore::registry::KernelReg g_##arch##provider##data_type##op_type##kernelReg(#arch, #provider, data_type, \
|
||||
op_type, creator); \
|
||||
} // namespace
|
||||
|
||||
/// \brief Defined registering macro to register custom op kernel, which called by user directly.
|
||||
|
@ -138,14 +137,12 @@ class MS_API KernelReg {
|
|||
/// \param[in] data_type Define kernel's input data type.
|
||||
/// \param[in] op_type Define the concrete type of a custom op.
|
||||
/// \param[in] creator Define a function pointer to create a kernel.
|
||||
#define REGISTER_CUSTOM_KERNEL(arch, provider, data_type, op_type, creator) \
|
||||
namespace { \
|
||||
static mindspore::lite::registry::KernelReg g_##arch##provider##data_type##op_type##kernelReg(#arch, #provider, \
|
||||
data_type, #op_type, \
|
||||
creator); \
|
||||
#define REGISTER_CUSTOM_KERNEL(arch, provider, data_type, op_type, creator) \
|
||||
namespace { \
|
||||
static mindspore::registry::KernelReg g_##arch##provider##data_type##op_type##kernelReg(#arch, #provider, data_type, \
|
||||
#op_type, creator); \
|
||||
} // namespace
|
||||
} // namespace registry
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_LITE_INCLUDE_REGISTRY_REGISTER_KERNEL_H_
|
||||
|
|
|
@ -25,7 +25,6 @@
|
|||
#include "schema/model_generated.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
namespace registry {
|
||||
/// \brief KernelInterfaceCreator defined a functor to create KernelInterface.
|
||||
using KernelInterfaceCreator = std::function<std::shared_ptr<kernel::KernelInterface>()>;
|
||||
|
@ -39,8 +38,8 @@ class MS_API RegisterKernelInterface {
|
|||
/// \param[in] op_type Define the concrete type of a custom op.
|
||||
/// \param[in] creator Define the KernelInterface create function.
|
||||
///
|
||||
/// \return STATUS as an error code of registering, STATUS is defined in errorcode.h.
|
||||
static int CustomReg(const std::string &provider, const std::string &op_type, KernelInterfaceCreator creator);
|
||||
/// \return Status as a status identification of registering.
|
||||
static Status CustomReg(const std::string &provider, const std::string &op_type, KernelInterfaceCreator creator);
|
||||
|
||||
/// \brief Static method to register op whose primitive type is ordinary.
|
||||
///
|
||||
|
@ -48,8 +47,8 @@ class MS_API RegisterKernelInterface {
|
|||
/// \param[in] op_type Define the ordinary op type.
|
||||
/// \param[in] creator Define the KernelInterface create function.
|
||||
///
|
||||
/// \return STATUS as an error code of registering, STATUS is defined in errorcode.h.
|
||||
static int Reg(const std::string &provider, int op_type, KernelInterfaceCreator creator);
|
||||
/// \return Status as a status identification of registering.
|
||||
static Status Reg(const std::string &provider, int op_type, KernelInterfaceCreator creator);
|
||||
|
||||
/// \brief Static method to get registration of a certain op.
|
||||
///
|
||||
|
@ -88,9 +87,9 @@ class MS_API KernelInterfaceReg {
|
|||
/// \param[in] provider Define the identification of user.
|
||||
/// \param[in] op_type Define the ordinary op type.
|
||||
/// \param[in] creator Define the KernelInterface create function.
|
||||
#define REGISTER_KERNEL_INTERFACE(provider, op_type, creator) \
|
||||
namespace { \
|
||||
static mindspore::lite::registry::KernelInterfaceReg g_##provider##op_type##_inter_reg(#provider, op_type, creator); \
|
||||
#define REGISTER_KERNEL_INTERFACE(provider, op_type, creator) \
|
||||
namespace { \
|
||||
static mindspore::registry::KernelInterfaceReg g_##provider##op_type##_inter_reg(#provider, op_type, creator); \
|
||||
} // namespace
|
||||
|
||||
/// \brief Defined registering macro to register custom op, which called by user directly.
|
||||
|
@ -98,13 +97,12 @@ class MS_API KernelInterfaceReg {
|
|||
/// \param[in] provider Define the identification of user.
|
||||
/// \param[in] op_type Define the concrete type of a custom op.
|
||||
/// \param[in] creator Define the KernelInterface create function.
|
||||
#define REGISTER_CUSTOM_KERNEL_INTERFACE(provider, op_type, creator) \
|
||||
namespace { \
|
||||
static mindspore::lite::registry::KernelInterfaceReg g_##provider##op_type##_custom_inter_reg(#provider, #op_type, \
|
||||
creator); \
|
||||
#define REGISTER_CUSTOM_KERNEL_INTERFACE(provider, op_type, creator) \
|
||||
namespace { \
|
||||
static mindspore::registry::KernelInterfaceReg g_##provider##op_type##_custom_inter_reg(#provider, #op_type, \
|
||||
creator); \
|
||||
} // namespace
|
||||
} // namespace registry
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_LITE_INCLUDE_REGISTRY_KERNEL_INTERFACE_H_
|
||||
|
|
|
@ -39,8 +39,8 @@ using mindspore::kernel::KERNEL_ARCH;
|
|||
using mindspore::kernel::KernelCreator;
|
||||
using mindspore::kernel::KernelKey;
|
||||
#ifdef ENABLE_CUSTOM_KERNEL_REGISTRY
|
||||
using mindspore::lite::registry::CreateKernel;
|
||||
using mindspore::lite::registry::KernelDesc;
|
||||
using mindspore::registry::CreateKernel;
|
||||
using mindspore::registry::KernelDesc;
|
||||
#endif
|
||||
|
||||
namespace mindspore::lite {
|
||||
|
|
|
@ -21,11 +21,11 @@
|
|||
#include "src/common/version_manager.h"
|
||||
#include "schema/model_generated.h"
|
||||
|
||||
using mindspore::lite::registry::KernelInterfaceCreator;
|
||||
using mindspore::registry::KernelInterfaceCreator;
|
||||
using mindspore::schema::PrimitiveType_MAX;
|
||||
using mindspore::schema::PrimitiveType_MIN;
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
namespace registry {
|
||||
namespace {
|
||||
static const auto kMaxKernelNum = PrimitiveType_MAX - PrimitiveType_MIN;
|
||||
std::string GetCustomType(const schema::Primitive *primitive) {
|
||||
|
@ -35,10 +35,10 @@ std::string GetCustomType(const schema::Primitive *primitive) {
|
|||
}
|
||||
} // namespace
|
||||
|
||||
int KernelInterfaceRegistry::CustomReg(const std::string &provider, const std::string &type,
|
||||
KernelInterfaceCreator creator) {
|
||||
Status KernelInterfaceRegistry::CustomReg(const std::string &provider, const std::string &type,
|
||||
KernelInterfaceCreator creator) {
|
||||
custom_creators_[provider][type] = creator;
|
||||
return RET_OK;
|
||||
return kSuccess;
|
||||
}
|
||||
|
||||
std::shared_ptr<kernel::KernelInterface> KernelInterfaceRegistry::GetCacheInterface(const std::string &provider,
|
||||
|
@ -124,10 +124,10 @@ std::shared_ptr<kernel::KernelInterface> KernelInterfaceRegistry::GetKernelInter
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
int KernelInterfaceRegistry::Reg(const std::string &provider, int op_type, KernelInterfaceCreator creator) {
|
||||
Status KernelInterfaceRegistry::Reg(const std::string &provider, int op_type, KernelInterfaceCreator creator) {
|
||||
if (op_type < PrimitiveType_MIN || op_type > kMaxKernelNum) {
|
||||
MS_LOG(ERROR) << "reg op_type invalid!op_type: " << op_type << ", max value: " << kMaxKernelNum;
|
||||
return RET_ERROR;
|
||||
return kLiteError;
|
||||
}
|
||||
|
||||
std::unique_lock<std::mutex> lock(mutex_);
|
||||
|
@ -137,12 +137,12 @@ int KernelInterfaceRegistry::Reg(const std::string &provider, int op_type, Kerne
|
|||
reinterpret_cast<KernelInterfaceCreator *>(calloc(kMaxKernelNum, sizeof(KernelInterfaceCreator)));
|
||||
if (kernel_creators_[provider] == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc kernel dev delegate creator fail!";
|
||||
return RET_ERROR;
|
||||
return kLiteError;
|
||||
}
|
||||
}
|
||||
|
||||
kernel_creators_[provider][op_type] = creator;
|
||||
return RET_OK;
|
||||
return kSuccess;
|
||||
}
|
||||
|
||||
KernelInterfaceRegistry::~KernelInterfaceRegistry() {
|
||||
|
@ -151,5 +151,5 @@ KernelInterfaceRegistry::~KernelInterfaceRegistry() {
|
|||
item.second = nullptr;
|
||||
}
|
||||
}
|
||||
} // namespace lite
|
||||
} // namespace registry
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -26,7 +26,7 @@
|
|||
#include "include/model.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
namespace registry {
|
||||
class KernelInterfaceRegistry {
|
||||
public:
|
||||
static KernelInterfaceRegistry *Instance() {
|
||||
|
@ -36,8 +36,8 @@ class KernelInterfaceRegistry {
|
|||
|
||||
std::shared_ptr<kernel::KernelInterface> GetKernelInterface(const std::string &provider,
|
||||
const schema::Primitive *primitive);
|
||||
int CustomReg(const std::string &provider, const std::string &op_type, registry::KernelInterfaceCreator creator);
|
||||
int Reg(const std::string &provider, int op_type, registry::KernelInterfaceCreator creator);
|
||||
Status CustomReg(const std::string &provider, const std::string &op_type, registry::KernelInterfaceCreator creator);
|
||||
Status Reg(const std::string &provider, int op_type, registry::KernelInterfaceCreator creator);
|
||||
virtual ~KernelInterfaceRegistry();
|
||||
|
||||
private:
|
||||
|
@ -55,7 +55,7 @@ class KernelInterfaceRegistry {
|
|||
std::map<std::string, std::map<std::string, registry::KernelInterfaceCreator>> custom_creators_;
|
||||
std::map<std::string, std::map<std::string, std::shared_ptr<kernel::KernelInterface>>> custom_kernels_;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace registry
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_LITE_SRC_REGISTRY_KERNEL_INTERFACE_REGISTRY_H_
|
||||
|
|
|
@ -21,22 +21,21 @@
|
|||
#include "src/registry/register_kernel_impl.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
namespace registry {
|
||||
int RegisterKernel::RegCustomKernel(const std::string &arch, const std::string &provider, DataType data_type,
|
||||
const std::string &type, CreateKernel creator) {
|
||||
Status RegisterKernel::RegCustomKernel(const std::string &arch, const std::string &provider, DataType data_type,
|
||||
const std::string &type, CreateKernel creator) {
|
||||
#ifdef ENABLE_CUSTOM_KERNEL_REGISTRY
|
||||
return lite::RegistryKernelImpl::GetInstance()->RegCustomKernel(arch, provider, data_type, type, creator);
|
||||
return RegistryKernelImpl::GetInstance()->RegCustomKernel(arch, provider, data_type, type, creator);
|
||||
#else
|
||||
MS_LOG(ERROR) << unsupport_custom_kernel_register_log;
|
||||
return lite::RET_NOT_SUPPORT;
|
||||
#endif
|
||||
}
|
||||
|
||||
int RegisterKernel::RegKernel(const std::string &arch, const std::string &provider, DataType data_type, int op_type,
|
||||
CreateKernel creator) {
|
||||
Status RegisterKernel::RegKernel(const std::string &arch, const std::string &provider, DataType data_type, int op_type,
|
||||
CreateKernel creator) {
|
||||
#ifdef ENABLE_CUSTOM_KERNEL_REGISTRY
|
||||
return lite::RegistryKernelImpl::GetInstance()->RegKernel(arch, provider, data_type, op_type, creator);
|
||||
return RegistryKernelImpl::GetInstance()->RegKernel(arch, provider, data_type, op_type, creator);
|
||||
#else
|
||||
MS_LOG(ERROR) << unsupport_custom_kernel_register_log;
|
||||
return lite::RET_NOT_SUPPORT;
|
||||
|
@ -45,12 +44,11 @@ int RegisterKernel::RegKernel(const std::string &arch, const std::string &provid
|
|||
|
||||
CreateKernel RegisterKernel::GetCreator(const schema::Primitive *primitive, KernelDesc *desc) {
|
||||
#ifdef ENABLE_CUSTOM_KERNEL_REGISTRY
|
||||
return lite::RegistryKernelImpl::GetInstance()->GetProviderCreator(primitive, desc);
|
||||
return RegistryKernelImpl::GetInstance()->GetProviderCreator(primitive, desc);
|
||||
#else
|
||||
MS_LOG(ERROR) << unsuppor_custom_kernel_register_log;
|
||||
return lite::RET_NOT_SUPPORT;
|
||||
#endif
|
||||
}
|
||||
} // namespace registry
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -18,12 +18,12 @@
|
|||
#include "include/errorcode.h"
|
||||
#include "src/common/version_manager.h"
|
||||
#include "src/common/log_adapter.h"
|
||||
using mindspore::lite::registry::CreateKernel;
|
||||
using mindspore::lite::registry::KernelDesc;
|
||||
|
||||
using mindspore::registry::CreateKernel;
|
||||
using mindspore::registry::KernelDesc;
|
||||
using mindspore::schema::PrimitiveType_MAX;
|
||||
using mindspore::schema::PrimitiveType_MIN;
|
||||
|
||||
namespace mindspore::lite {
|
||||
namespace mindspore::registry {
|
||||
namespace {
|
||||
static const auto kKernelMaxNum =
|
||||
(static_cast<int>(DataType::kNumberTypeEnd) - static_cast<int>(DataType::kNumberTypeBegin) - 1) *
|
||||
|
@ -44,11 +44,11 @@ int RegistryKernelImpl::GetFuncIndex(const KernelDesc &desc) {
|
|||
return data_type_index * kOpTypeLen + desc.type;
|
||||
}
|
||||
|
||||
int RegistryKernelImpl::RegCustomKernel(const std::string &arch, const std::string &provider, DataType data_type,
|
||||
const std::string &type, CreateKernel creator) {
|
||||
Status RegistryKernelImpl::RegCustomKernel(const std::string &arch, const std::string &provider, DataType data_type,
|
||||
const std::string &type, CreateKernel creator) {
|
||||
if (data_type >= DataType::kNumberTypeEnd) {
|
||||
MS_LOG(ERROR) << "invalid data_type: " << static_cast<int>(data_type) << "!provider: " << provider;
|
||||
return RET_ERROR;
|
||||
return kLiteError;
|
||||
}
|
||||
std::unique_lock<std::mutex> lock(lock_);
|
||||
if (custom_kernel_creators_[provider][arch][type] == nullptr) {
|
||||
|
@ -56,28 +56,28 @@ int RegistryKernelImpl::RegCustomKernel(const std::string &arch, const std::stri
|
|||
reinterpret_cast<CreateKernel *>(calloc(kDataTypeLen, sizeof(CreateKernel)));
|
||||
if (custom_kernel_creators_[provider][arch][type] == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc custom kernel creator fail!provider: " << provider << ", arch: " << arch;
|
||||
return RET_ERROR;
|
||||
return kLiteError;
|
||||
}
|
||||
}
|
||||
|
||||
int data_type_index = static_cast<int>(data_type) - static_cast<int>(DataType::kNumberTypeBegin) - 1;
|
||||
if (data_type_index < 0 || data_type_index >= kDataTypeLen) {
|
||||
MS_LOG(ERROR) << "invalid data_type: " << static_cast<int>(data_type) << "!provider: " << provider;
|
||||
return RET_ERROR;
|
||||
return kLiteError;
|
||||
}
|
||||
custom_kernel_creators_[provider][arch][type][data_type_index] = creator;
|
||||
return RET_OK;
|
||||
return kSuccess;
|
||||
}
|
||||
|
||||
int RegistryKernelImpl::RegKernel(const std::string &arch, const std::string &provider, DataType data_type, int type,
|
||||
registry::CreateKernel creator) {
|
||||
Status RegistryKernelImpl::RegKernel(const std::string &arch, const std::string &provider, DataType data_type, int type,
|
||||
registry::CreateKernel creator) {
|
||||
std::unique_lock<std::mutex> lock(lock_);
|
||||
auto iter = kernel_creators_.find(provider);
|
||||
if (iter == kernel_creators_.end()) {
|
||||
kernel_creators_[provider][arch] = reinterpret_cast<CreateKernel *>(calloc(kKernelMaxNum, sizeof(CreateKernel)));
|
||||
if (kernel_creators_[provider][arch] == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc kernel creator buffer fail! provider: " << provider << ",arch:" << arch;
|
||||
return RET_ERROR;
|
||||
return kLiteError;
|
||||
}
|
||||
} else {
|
||||
auto iter_arch = iter->second.find(arch);
|
||||
|
@ -85,7 +85,7 @@ int RegistryKernelImpl::RegKernel(const std::string &arch, const std::string &pr
|
|||
iter->second[arch] = reinterpret_cast<CreateKernel *>(calloc(kKernelMaxNum, sizeof(CreateKernel)));
|
||||
if (iter->second[arch] == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc kernel creator buffer fail! provider: " << provider << ",arch:" << arch;
|
||||
return RET_ERROR;
|
||||
return kLiteError;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -95,11 +95,11 @@ int RegistryKernelImpl::RegKernel(const std::string &arch, const std::string &pr
|
|||
if (index >= kKernelMaxNum || index < 0) {
|
||||
MS_LOG(ERROR) << "invalid kernel key, arch " << arch << ", data_type" << static_cast<int>(data_type) << ",op type "
|
||||
<< type;
|
||||
return RET_ERROR;
|
||||
return kLiteError;
|
||||
}
|
||||
|
||||
kernel_creators_[provider][arch][index] = creator;
|
||||
return RET_OK;
|
||||
return kSuccess;
|
||||
}
|
||||
|
||||
registry::CreateKernel RegistryKernelImpl::GetCustomKernelCreator(const schema::Primitive *primitive,
|
||||
|
@ -179,4 +179,4 @@ RegistryKernelImpl::~RegistryKernelImpl() {
|
|||
}
|
||||
}
|
||||
}
|
||||
} // namespace mindspore::lite
|
||||
} // namespace mindspore::registry
|
||||
|
|
|
@ -25,7 +25,7 @@
|
|||
#include <set>
|
||||
#include "include/registry/register_kernel.h"
|
||||
|
||||
namespace mindspore::lite {
|
||||
namespace mindspore::registry {
|
||||
class RegistryKernelImpl {
|
||||
public:
|
||||
RegistryKernelImpl() = default;
|
||||
|
@ -36,11 +36,11 @@ class RegistryKernelImpl {
|
|||
return &instance;
|
||||
}
|
||||
|
||||
int RegCustomKernel(const std::string &arch, const std::string &provider, DataType data_type, const std::string &type,
|
||||
registry::CreateKernel creator);
|
||||
Status RegCustomKernel(const std::string &arch, const std::string &provider, DataType data_type,
|
||||
const std::string &type, registry::CreateKernel creator);
|
||||
|
||||
int RegKernel(const std::string &arch, const std::string &provider, DataType data_type, int type,
|
||||
registry::CreateKernel creator);
|
||||
Status RegKernel(const std::string &arch, const std::string &provider, DataType data_type, int type,
|
||||
registry::CreateKernel creator);
|
||||
|
||||
virtual registry::CreateKernel GetProviderCreator(const schema::Primitive *primitive, registry::KernelDesc *desc);
|
||||
|
||||
|
@ -60,6 +60,6 @@ class RegistryKernelImpl {
|
|||
registry::CreateKernel GetCustomKernelCreator(const schema::Primitive *primitive, registry::KernelDesc *desc);
|
||||
int GetFuncIndex(const registry::KernelDesc &desc);
|
||||
};
|
||||
} // namespace mindspore::lite
|
||||
} // namespace mindspore::registry
|
||||
|
||||
#endif // MINDSPORE_LITE_SRC_REGISTRY_REGISTER_KERNEL_IMPL_H_
|
||||
|
|
|
@ -21,21 +21,20 @@
|
|||
#include "src/registry/kernel_interface_registry.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
namespace registry {
|
||||
int RegisterKernelInterface::Reg(const std::string &provider, int op_type, KernelInterfaceCreator creator) {
|
||||
Status RegisterKernelInterface::Reg(const std::string &provider, int op_type, KernelInterfaceCreator creator) {
|
||||
#ifdef ENABLE_CUSTOM_KERNEL_REGISTRY
|
||||
return lite::KernelInterfaceRegistry::Instance()->Reg(provider, op_type, creator);
|
||||
return KernelInterfaceRegistry::Instance()->Reg(provider, op_type, creator);
|
||||
#else
|
||||
MS_LOG(ERROR) << unsupport_custom_kernel_register_log;
|
||||
return lite::RET_NOT_SUPPORT;
|
||||
#endif
|
||||
}
|
||||
|
||||
int RegisterKernelInterface::CustomReg(const std::string &provider, const std::string &op_type,
|
||||
KernelInterfaceCreator creator) {
|
||||
Status RegisterKernelInterface::CustomReg(const std::string &provider, const std::string &op_type,
|
||||
KernelInterfaceCreator creator) {
|
||||
#ifdef ENABLE_CUSTOM_KERNEL_REGISTRY
|
||||
return lite::KernelInterfaceRegistry::Instance()->CustomReg(provider, op_type, creator);
|
||||
return KernelInterfaceRegistry::Instance()->CustomReg(provider, op_type, creator);
|
||||
#else
|
||||
MS_LOG(ERROR) << unsupport_custom_kernel_register_log;
|
||||
return lite::RET_NOT_SUPPORT;
|
||||
|
@ -45,12 +44,11 @@ int RegisterKernelInterface::CustomReg(const std::string &provider, const std::s
|
|||
std::shared_ptr<kernel::KernelInterface> RegisterKernelInterface::GetKernelInterface(
|
||||
const std::string &provider, const schema::Primitive *primitive) {
|
||||
#ifdef ENABLE_CUSTOM_KERNEL_REGISTRY
|
||||
return lite::KernelInterfaceRegistry::Instance()->GetKernelInterface(provider, primitive);
|
||||
return KernelInterfaceRegistry::Instance()->GetKernelInterface(provider, primitive);
|
||||
#else
|
||||
MS_LOG(ERROR) << unsupport_custom_kernel_register_log;
|
||||
return nullptr;
|
||||
#endif
|
||||
}
|
||||
} // namespace registry
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -60,9 +60,12 @@ int KernelInferShape(const std::vector<lite::Tensor *> &inputs, const std::vecto
|
|||
std::transform(outputs.begin(), outputs.end(), std::back_inserter(out_tensors),
|
||||
[](lite::Tensor *tensor) { return mindspore::MSTensor(std::make_shared<MSTensor::Impl>(tensor)); });
|
||||
auto ret = kernel_interface->Infer(&in_tensors, &out_tensors, static_cast<const schema::Primitive *>(primitive));
|
||||
if (ret != RET_OK) {
|
||||
if (ret == kLiteInferInvalid) {
|
||||
return RET_INFER_INVALID;
|
||||
}
|
||||
if (ret != kSuccess) {
|
||||
MS_LOG(ERROR) << "op_type: " << PrimitiveTypeName(prim_type) << " infer fail!ret: " << ret;
|
||||
return ret;
|
||||
return RET_ERROR;
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
|
|
@ -111,12 +111,12 @@ class TestCustomOpInfer : public KernelInterface {
|
|||
public:
|
||||
TestCustomOpInfer() = default;
|
||||
~TestCustomOpInfer() = default;
|
||||
int Infer(std::vector<mindspore::MSTensor> *inputs, std::vector<mindspore::MSTensor> *outputs,
|
||||
const schema::Primitive *primitive) override {
|
||||
Status Infer(std::vector<mindspore::MSTensor> *inputs, std::vector<mindspore::MSTensor> *outputs,
|
||||
const schema::Primitive *primitive) override {
|
||||
(*outputs)[0].SetFormat((*inputs)[0].format());
|
||||
(*outputs)[0].SetDataType((*inputs)[0].DataType());
|
||||
(*outputs)[0].SetShape((*inputs)[0].Shape());
|
||||
return RET_OK;
|
||||
return kSuccess;
|
||||
}
|
||||
};
|
||||
|
||||
|
|
|
@ -80,12 +80,12 @@ class TestCustomAddInfer : public KernelInterface {
|
|||
public:
|
||||
TestCustomAddInfer() = default;
|
||||
~TestCustomAddInfer() = default;
|
||||
int Infer(std::vector<mindspore::MSTensor> *inputs, std::vector<mindspore::MSTensor> *outputs,
|
||||
const schema::Primitive *primitive) override {
|
||||
Status Infer(std::vector<mindspore::MSTensor> *inputs, std::vector<mindspore::MSTensor> *outputs,
|
||||
const schema::Primitive *primitive) override {
|
||||
(*outputs)[0].SetFormat((*inputs)[0].format());
|
||||
(*outputs)[0].SetDataType((*inputs)[0].DataType());
|
||||
(*outputs)[0].SetShape((*inputs)[0].Shape());
|
||||
return RET_OK;
|
||||
return kSuccess;
|
||||
}
|
||||
};
|
||||
|
||||
|
|
|
@ -35,7 +35,7 @@ TEST_F(ModelParserRegistryTest, TestRegistry) {
|
|||
ASSERT_NE(proposal_parser, nullptr);
|
||||
REG_MODEL_PARSER(FmkType_CAFFE,
|
||||
TestModelParserCreator); // register test model parser creator, which will overwrite existing.
|
||||
auto model_parser = lite::registry::ModelParserRegistry::GetModelParser(FmkType_CAFFE);
|
||||
auto model_parser = registry::ModelParserRegistry::GetModelParser(FmkType_CAFFE);
|
||||
ASSERT_NE(model_parser, nullptr);
|
||||
ConverterParameters converter_parameters;
|
||||
auto func_graph = model_parser->Parse(converter_parameters);
|
||||
|
|
|
@ -25,20 +25,19 @@
|
|||
#include "ops/addn.h"
|
||||
#include "ops/custom.h"
|
||||
#include "tools/converter/model_parser.h"
|
||||
#include "tools/converter/registry/pass_content.h"
|
||||
#include "tools/optimizer/common/gllo_utils.h"
|
||||
#include "ut/tools/converter/registry/model_parser_test.h"
|
||||
|
||||
using mindspore::converter::ConverterParameters;
|
||||
using mindspore::converter::FmkType_CAFFE;
|
||||
using mindspore::lite::registry::POSITION_BEGIN;
|
||||
using mindspore::registry::POSITION_BEGIN;
|
||||
namespace mindspore {
|
||||
class PassRegistryTest : public mindspore::CommonTest {
|
||||
public:
|
||||
PassRegistryTest() = default;
|
||||
void SetUp() override {
|
||||
REG_MODEL_PARSER(FmkType_CAFFE, TestModelParserCreator);
|
||||
auto model_parser = lite::registry::ModelParserRegistry::GetModelParser(FmkType_CAFFE);
|
||||
auto model_parser = registry::ModelParserRegistry::GetModelParser(FmkType_CAFFE);
|
||||
if (model_parser == nullptr) {
|
||||
return;
|
||||
}
|
||||
|
@ -214,12 +213,11 @@ REG_SCHEDULED_PASS(POSITION_BEGIN, {"TestFusion"})
|
|||
} // namespace opt
|
||||
|
||||
TEST_F(PassRegistryTest, TestRegistry) {
|
||||
auto &passes = lite::PassStoreRoomInfo();
|
||||
auto &assigned_passes = lite::ExternalAssignedPassesInfo();
|
||||
ASSERT_EQ(assigned_passes.size(), 1);
|
||||
auto pass_names = assigned_passes[POSITION_BEGIN];
|
||||
ASSERT_EQ(pass_names.size(), 1);
|
||||
auto begin_pass = passes[pass_names.front()];
|
||||
auto schedule_task = registry::PassRegistry::GetOuterScheduleTask(POSITION_BEGIN);
|
||||
ASSERT_EQ(schedule_task.size(), 1);
|
||||
auto passes = registry::PassRegistry::GetPassFromStoreRoom(schedule_task);
|
||||
ASSERT_EQ(passes.size(), 1);
|
||||
auto begin_pass = passes.front();
|
||||
ASSERT_NE(begin_pass, nullptr);
|
||||
auto begin_pass_test = std::dynamic_pointer_cast<opt::TestFusion>(begin_pass);
|
||||
ASSERT_NE(begin_pass_test, nullptr);
|
||||
|
|
|
@ -326,7 +326,7 @@ FuncGraphPtr AnfTransform::TransformFuncGraph(const FuncGraphPtr &old_graph, con
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
if (!RunExternalPass(old_graph, registry::PassPosition::POSITION_BEGIN)) {
|
||||
if (!RunExternalPass(old_graph, registry::POSITION_BEGIN)) {
|
||||
MS_LOG(ERROR) << "Run external pass failed, place is BEGIN";
|
||||
return nullptr;
|
||||
}
|
||||
|
@ -361,7 +361,7 @@ FuncGraphPtr AnfTransform::TransformFuncGraph(const FuncGraphPtr &old_graph, con
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
if (!RunExternalPass(old_graph, registry::PassPosition::POSITION_END)) {
|
||||
if (!RunExternalPass(old_graph, registry::POSITION_END)) {
|
||||
MS_LOG(ERROR) << "Run external pass failed, place is END";
|
||||
return nullptr;
|
||||
}
|
||||
|
|
|
@ -18,25 +18,26 @@
|
|||
#include <string>
|
||||
#include <vector>
|
||||
#include "backend/optimizer/common/pass.h"
|
||||
#include "tools/converter/registry/pass_content.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
bool RunOptimizerPass(const FuncGraphPtr &func_graph, std::vector<std::string> pass_names) {
|
||||
bool RunOptimizerPass(const FuncGraphPtr &func_graph, const std::vector<std::string> &pass_names) {
|
||||
if (func_graph == nullptr) {
|
||||
MS_LOG(ERROR) << "func graph is nullptr.";
|
||||
return false;
|
||||
}
|
||||
auto &passes_info = PassStoreRoomInfo();
|
||||
for (auto &name : pass_names) {
|
||||
if (passes_info.find(name) == passes_info.end()) {
|
||||
MS_LOG(ERROR) << "cannot find required pass.";
|
||||
return false;
|
||||
}
|
||||
if (!passes_info[name]->Run(func_graph)) {
|
||||
MS_LOG(ERROR) << "run pass failed, pass name is " << name;
|
||||
auto schedule_passes = registry::PassRegistry::GetPassFromStoreRoom(pass_names);
|
||||
if (schedule_passes.size() != pass_names.size()) {
|
||||
MS_LOG(ERROR) << "exited pass cannot be obtained.";
|
||||
return false;
|
||||
}
|
||||
int index = 0;
|
||||
for (auto &pass : schedule_passes) {
|
||||
if (!pass->Run(func_graph)) {
|
||||
MS_LOG(ERROR) << "run pass failed, pass name is " << pass_names[index];
|
||||
return false;
|
||||
}
|
||||
++index;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
@ -46,21 +47,10 @@ bool RunExternalPass(const FuncGraphPtr &func_graph, registry::PassPosition posi
|
|||
MS_LOG(ERROR) << "func graph is nullptr.";
|
||||
return false;
|
||||
}
|
||||
auto &external_assigned = ExternalAssignedPassesInfo();
|
||||
if (external_assigned.find(position) == external_assigned.end()) {
|
||||
MS_LOG(DEBUG) << "there is no external pass in current position, position is " << position;
|
||||
return true;
|
||||
}
|
||||
auto &passes_info = PassStoreRoomInfo();
|
||||
for (auto &name : external_assigned[position]) {
|
||||
if (passes_info.find(name) == passes_info.end()) {
|
||||
MS_LOG(ERROR) << "cannot find required pass.";
|
||||
return false;
|
||||
}
|
||||
if (!passes_info[name]->Run(func_graph)) {
|
||||
MS_LOG(ERROR) << "run pass failed, pass name is " << name;
|
||||
return false;
|
||||
}
|
||||
auto schedule_task = registry::PassRegistry::GetOuterScheduleTask(position);
|
||||
if (!RunOptimizerPass(func_graph, schedule_task)) {
|
||||
MS_LOG(ERROR) << "run external scheduled task failed.";
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
|
|
@ -24,7 +24,7 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
bool RunOptimizerPass(const FuncGraphPtr &func_graph, std::vector<std::string> pass_names);
|
||||
bool RunOptimizerPass(const FuncGraphPtr &func_graph, const std::vector<std::string> &pass_names);
|
||||
bool RunExternalPass(const FuncGraphPtr &func_graph, registry::PassPosition position);
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -19,7 +19,6 @@
|
|||
#include "src/common/log_adapter.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
namespace registry {
|
||||
namespace {
|
||||
std::map<FmkType, ModelParserCreator> model_parser_room;
|
||||
|
@ -42,5 +41,4 @@ converter::ModelParser *ModelParserRegistry::GetModelParser(FmkType fmk) {
|
|||
return nullptr;
|
||||
}
|
||||
} // namespace registry
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -1,32 +0,0 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_REGISTRY_PASS_CONTENT_H
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_REGISTRY_PASS_CONTENT_H
|
||||
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "include/registry/pass_registry.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
std::map<std::string, opt::PassPtr> &MS_API PassStoreRoomInfo();
|
||||
std::map<registry::PassPosition, std::vector<std::string>> &MS_API ExternalAssignedPassesInfo();
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_REGISTRY_PASS_CONTENT_H
|
|
@ -19,11 +19,10 @@
|
|||
#include <mutex>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "tools/converter/registry/pass_content.h"
|
||||
#include "src/common/log_adapter.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
namespace registry {
|
||||
namespace {
|
||||
std::map<std::string, opt::PassPtr> pass_store_room;
|
||||
std::map<registry::PassPosition, std::vector<std::string>> external_assigned_passes;
|
||||
|
@ -38,19 +37,27 @@ void RegPass(const std::string &pass_name, const opt::PassPtr &pass) {
|
|||
}
|
||||
} // namespace
|
||||
|
||||
registry::PassRegistry::PassRegistry(const std::string &pass_name, const opt::PassPtr &pass) {
|
||||
RegPass(pass_name, pass);
|
||||
}
|
||||
PassRegistry::PassRegistry(const std::string &pass_name, const opt::PassPtr &pass) { RegPass(pass_name, pass); }
|
||||
|
||||
registry::PassRegistry::PassRegistry(PassPosition position, const std::vector<std::string> &assigned) {
|
||||
PassRegistry::PassRegistry(PassPosition position, const std::vector<std::string> &assigned) {
|
||||
std::unique_lock<std::mutex> lock(pass_mutex);
|
||||
external_assigned_passes[position] = assigned;
|
||||
}
|
||||
|
||||
std::map<std::string, opt::PassPtr> &PassStoreRoomInfo() { return pass_store_room; }
|
||||
|
||||
std::map<registry::PassPosition, std::vector<std::string>> &ExternalAssignedPassesInfo() {
|
||||
return external_assigned_passes;
|
||||
std::vector<std::string> PassRegistry::GetOuterScheduleTask(PassPosition position) {
|
||||
return external_assigned_passes[position];
|
||||
}
|
||||
} // namespace lite
|
||||
|
||||
std::vector<opt::PassPtr> PassRegistry::GetPassFromStoreRoom(const std::vector<std::string> &pass_names) {
|
||||
std::vector<opt::PassPtr> schedule_passes;
|
||||
for (auto &name : pass_names) {
|
||||
auto iter = pass_store_room.find(name);
|
||||
if (iter == pass_store_room.end()) {
|
||||
continue;
|
||||
}
|
||||
schedule_passes.push_back(iter->second);
|
||||
}
|
||||
return schedule_passes;
|
||||
}
|
||||
} // namespace registry
|
||||
} // namespace mindspore
|
||||
|
|
Loading…
Reference in New Issue