registry: add ModelParser and Pass registration

This commit is contained in:
liuyu 2021-05-11 17:06:34 +08:00
parent 2b8083915e
commit 3869a02fc3
35 changed files with 501 additions and 284 deletions

View File

@ -217,6 +217,8 @@ elseif(WIN32)
install(FILES ${LIB_LIST} DESTINATION ${CONVERTER_ROOT_DIR}/lib COMPONENT ${RUNTIME_COMPONENT_NAME})
install(FILES ${TOP_DIR}/build/mindspore/tools/converter/mindspore_core/gvar/libmindspore_gvar.dll
DESTINATION ${CONVERTER_ROOT_DIR}/lib COMPONENT ${RUNTIME_COMPONENT_NAME})
install(FILES ${TOP_DIR}/build/mindspore/tools/converter/registry/libmslite_converter_plugin_reg.dll
DESTINATION ${CONVERTER_ROOT_DIR}/lib COMPONENT ${RUNTIME_COMPONENT_NAME})
install(FILES ${glog_LIBPATH}/../bin/libglog.dll DESTINATION ${CONVERTER_ROOT_DIR}/lib
COMPONENT ${RUNTIME_COMPONENT_NAME})
install(FILES ${NNACL_FILES} DESTINATION ${CODEGEN_ROOT_DIR}/include/nnacl COMPONENT ${RUNTIME_COMPONENT_NAME})
@ -286,6 +288,8 @@ else()
COMPONENT ${RUNTIME_COMPONENT_NAME})
install(FILES ${TOP_DIR}/mindspore/lite/build/tools/converter/mindspore_core/gvar/libmindspore_gvar.so
DESTINATION ${CONVERTER_ROOT_DIR}/lib COMPONENT ${RUNTIME_COMPONENT_NAME})
install(FILES ${TOP_DIR}/mindspore/lite/build/tools/converter/registry/libmslite_converter_plugin_reg.so
DESTINATION ${CONVERTER_ROOT_DIR}/lib COMPONENT ${RUNTIME_COMPONENT_NAME})
install(FILES ${glog_LIBPATH}/libglog.so.0.4.0
DESTINATION ${CONVERTER_ROOT_DIR}/third_party/glog/lib RENAME libglog.so.0
COMPONENT ${RUNTIME_COMPONENT_NAME})

View File

@ -46,6 +46,7 @@ set(LITE_SRC
${CMAKE_CURRENT_SOURCE_DIR}/common/file_utils.cc
${CMAKE_CURRENT_SOURCE_DIR}/common/utils.cc
${CMAKE_CURRENT_SOURCE_DIR}/common/graph_util.cc
${CMAKE_CURRENT_SOURCE_DIR}/common/dynamic_library_loader.cc
${CMAKE_CURRENT_SOURCE_DIR}/common/log_adapter.cc
${CMAKE_CURRENT_SOURCE_DIR}/common/string_util.cc
${CMAKE_CURRENT_SOURCE_DIR}/common/prim_util.cc

View File

@ -91,8 +91,8 @@ int SimpleTile(void *cdata, int task_id) {
void TileCPUKernel::FillOneDimTileParam() {
// check if tile exact one dim
int large_one_multiple_count = 0;
int multiple;
int mul_index;
int multiple = 0;
int mul_index = 0;
for (auto i = 0; i < tile_parameter_->in_dim_; ++i) {
if (tile_parameter_->multiples_[i] > 1) {
large_one_multiple_count++;

View File

@ -430,6 +430,7 @@ if(ENABLE_CONVERTER)
add_dependencies(lite-test fbs_inner_src)
target_link_libraries(lite-test
anf_exporter_mid
mslite_converter_plugin_reg
tflite_parser_mid
caffe_parser_mid
onnx_parser_mid

View File

@ -22,7 +22,6 @@
#include "include/context.h"
#include "include/errorcode.h"
#include "src/common/log_adapter.h"
#include "tools/converter/model_parser.h"
#include "tools/converter/anf_transform.h"
#include "tools/optimizer/fusion/constant_folding_fusion.h"
#include "tools/anf_exporter/anf_exporter.h"

View File

@ -22,7 +22,6 @@
#include "include/context.h"
#include "include/errorcode.h"
#include "src/common/log_adapter.h"
#include "tools/converter/model_parser.h"
#include "tools/converter/anf_transform.h"
#include "tools/anf_exporter/anf_exporter.h"
#include "test/common/import_from_meta_graphT.h"

View File

@ -22,7 +22,6 @@
#include "include/context.h"
#include "include/errorcode.h"
#include "src/common/log_adapter.h"
#include "tools/converter/model_parser.h"
#include "tools/converter/anf_transform.h"
#include "tools/anf_exporter/anf_exporter.h"
#include "test/common/import_from_meta_graphT.h"

View File

@ -22,7 +22,6 @@
#include "include/context.h"
#include "include/errorcode.h"
#include "src/common/log_adapter.h"
#include "tools/converter/model_parser.h"
#include "tools/converter/anf_transform.h"
#include "tools/anf_exporter/anf_exporter.h"
#include "test/common/import_from_meta_graphT.h"

View File

@ -22,7 +22,6 @@
#include "include/context.h"
#include "include/errorcode.h"
#include "src/common/log_adapter.h"
#include "tools/converter/model_parser.h"
#include "tools/converter/anf_transform.h"
#include "tools/anf_exporter/anf_exporter.h"
#include "test/common/import_from_meta_graphT.h"

View File

@ -32,6 +32,7 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
${CMAKE_CURRENT_SOURCE_DIR}/../common/storage.cc
${CMAKE_CURRENT_SOURCE_DIR}/../../src/ir/primitive_t_value.cc
${CMAKE_CURRENT_SOURCE_DIR}/../../src/common/utils.cc
${CMAKE_CURRENT_SOURCE_DIR}/../../src/common/dynamic_library_loader.cc
../optimizer/common/node_pass_extends.cc
../optimizer/common/pass_manager_extends.cc
@ -107,6 +108,7 @@ add_subdirectory(parser/onnx)
add_subdirectory(parser/tf)
add_subdirectory(legacy_optimizer)
add_subdirectory(quantizer)
add_subdirectory(registry)
add_subdirectory(${CORE_DIR} mindspore_core)
set(SRC_DIR ${CMAKE_CURRENT_SOURCE_DIR}/../../src)
@ -216,6 +218,7 @@ add_dependencies(converter_lite fbs_src)
add_dependencies(converter_lite fbs_inner_src)
target_link_libraries(converter_lite PRIVATE
mslite_converter_plugin_reg
tflite_parser_mid
tf_parser_mid
caffe_parser_mid
@ -234,3 +237,7 @@ target_link_libraries(converter_lite PRIVATE
mindspore::flatbuffers
pthread
)
if(NOT WIN32)
target_link_libraries(converter_lite PRIVATE dl)
endif()

View File

@ -70,6 +70,7 @@
#include "tools/optimizer/fisson/iter_node_outputs.h"
#include "tools/optimizer/fisson/node_out_shapes.h"
#include "tools/optimizer/parallel/parallel_pass.h"
#include "tools/converter/registry/pass_registry.h"
using std::string;
namespace mindspore::lite {
@ -323,6 +324,22 @@ int AnfTransform::RunPrecedingPass(const FuncGraphPtr &old_graph, const converte
return RET_OK;
}
// Runs the externally-registered plugin pass bound to `position` (if any)
// over `old_graph`. Returns RET_OK when no pass is registered for the
// position or the pass succeeds; RET_ERROR (also recorded in the global
// ReturnCode) when the pass reports failure.
STATUS AnfTransform::RunPluginPass(const FuncGraphPtr &old_graph, int position) {
  auto instance = opt::PassRegistry::GetInstance();
  // Bind by const reference to avoid copying the pass table, and look the
  // position up once instead of find() followed by at().
  const auto &plugin_passes = instance->GetPasses();
  auto iter = plugin_passes.find(position);
  if (iter == plugin_passes.end()) {
    MS_LOG(DEBUG) << "there is no plugin pass in current position.";
    return RET_OK;
  }
  if (!iter->second->Run(old_graph)) {
    ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
    return RET_ERROR;
  }
  return RET_OK;
}
int AnfTransform::DoQuantize(const FuncGraphPtr &old_graph, const converter::Flags *config) {
// quant
if (config->quantType == schema::QuantType_PostTraining) {
@ -383,11 +400,28 @@ FuncGraphPtr AnfTransform::TransformFuncGraph(const FuncGraphPtr &old_graph, con
MS_LOG(ERROR) << "Run convert pass failed.";
return nullptr;
}
}
status = RunFusionPass(fg, config);
if (status != RET_OK) {
MS_LOG(ERROR) << "Run fusion pass failed.";
return nullptr;
auto format_pass = std::make_shared<opt::UnifyFormatPass>();
format_pass->Init(config->fmk, config->trainModel);
if (!format_pass->RunOnlyForShape(old_graph)) {
MS_LOG(ERROR) << "Run format pass failed.";
return nullptr;
}
status = RunPluginPass(old_graph, opt::POSITION_BEGIN);
if (status != RET_OK) {
MS_LOG(ERROR) << "Run plugin pass failed.";
return nullptr;
}
for (auto &fg : func_graphs_) {
if (!config->disableFusion) {
status = RunFusionPass(fg, config);
if (status != RET_OK) {
MS_LOG(ERROR) << "Run fusion pass failed.";
return nullptr;
}
}
status = RunConv1DAdjustPass(fg, config);
@ -397,13 +431,19 @@ FuncGraphPtr AnfTransform::TransformFuncGraph(const FuncGraphPtr &old_graph, con
}
}
auto format_pass = std::make_shared<opt::UnifyFormatPass>();
format_pass = std::make_shared<opt::UnifyFormatPass>();
format_pass->Init(config->fmk, config->trainModel);
if (!format_pass->Run(old_graph)) {
MS_LOG(ERROR) << "Run format pass failed.";
return nullptr;
}
status = RunPluginPass(old_graph, opt::POSITION_END);
if (status != RET_OK) {
MS_LOG(ERROR) << "Run plugin pass failed.";
return nullptr;
}
for (auto &fg : func_graphs_) {
status = RunGraphPass(fg, config);
if (status != RET_OK) {

View File

@ -63,6 +63,8 @@ class AnfTransform {
static int RunParallelPass(const FuncGraphPtr &old_graph, const converter::Flags *config);
static STATUS RunPluginPass(const FuncGraphPtr &old_graph, int position);
int DoQuantize(const FuncGraphPtr &old_graph, const converter::Flags *config);
void GetAllFuncGraph(const FuncGraphPtr &func_graph);

View File

@ -20,33 +20,37 @@
#include "tools/converter/converter_flags.h"
#include "src/common/log_adapter.h"
#include "tools/common/storage.h"
#include "parser/caffe/caffe_converter.h"
#include "parser/tflite/tflite_converter.h"
#include "parser/onnx/onnx_converter.h"
#include "parser/tf/tf_converter.h"
#include "tools/anf_exporter/anf_exporter.h"
#include "include/version.h"
#include "src/train/train_populate_parameter.h"
#include "tools/converter/registry/model_parser_registry.h"
#include "src/common/dynamic_library_loader.h"
namespace mindspore {
namespace lite {
using FmkType = converter::FmkType;
FuncGraphPtr Converter::BuildFuncGraph(const converter::Flags &flag) {
if (flag.fmkIn == "MINDIR") {
kernel::PopulateTrainParameters();
auto func_graph = LoadMindIR(flag.modelFile);
if (func_graph == nullptr) {
MS_LOG(ERROR) << "get funcgraph failed.";
return nullptr;
}
func_graph->set_attr("graph_name", MakeValue("main_graph"));
func_graph->set_attr("fmk", MakeValue(static_cast<int>(converter::FmkType_MS)));
MindsporeImporter::MindsporeImporter() { kernel::PopulateTrainParameters(); }
std::unique_ptr<Converter> Converter::CreateConverter(converter::FmkType fmk) {
switch (fmk) {
case FmkType::FmkType_MS:
return std::make_unique<MindsporeImporter>();
case FmkType::FmkType_CAFFE:
return std::make_unique<CaffeConverter>();
case FmkType::FmkType_TFLITE:
return std::make_unique<TfliteConverter>();
case FmkType::FmkType_ONNX:
return std::make_unique<OnnxConverter>();
case FmkType::FmkType_TF:
return std::make_unique<TFConverter>();
default: {
auto status = UpdateFuncGraphInputsAndOutputsDtype(func_graph);
if (RET_OK != status) {
MS_LOG(ERROR) << "update graph inputs and outputs dtype failed.";
return nullptr;
}
return func_graph;
} else {
model_parser_ = ModelParserRegistry::GetInstance()->GetModelParser(flag.fmkIn);
if (model_parser_ != nullptr) {
return model_parser_->Parse(flag.modelFile, flag.weightFile);
} else {
MS_LOG(ERROR) << "get funcGraph failed for fmk:" << flag.fmkIn;
return nullptr;
}
}
@ -57,7 +61,21 @@ schema::MetaGraphT *Converter::Convert(const std::unique_ptr<converter::Flags> &
MS_LOG(ERROR) << "Input flag is nullptr";
return nullptr;
}
auto graph = BuildFuncGraph(flag->modelFile, flag->weightFile, flag->quantType);
// load plugin
if (!flag->pluginsPath.empty()) {
DynamicLibraryLoader dynamic_library_loader{};
for (auto &path : flag->pluginsPath) {
auto status = dynamic_library_loader.Open(path.c_str());
if (status != RET_OK) {
MS_LOG(ERROR) << "open dynamic library failed.";
return nullptr;
}
dynamic_library_loader.Close();
}
}
auto graph = BuildFuncGraph(*flag);
if (graph == nullptr) {
MS_LOG(ERROR) << "Parser/Import model return nullptr";
return nullptr;
@ -110,16 +128,8 @@ int RunConverter(int argc, const char **argv) {
}
// Load graph
MS_LOG(DEBUG) << "start reading model file";
auto converter = Converter::CreateConverter(flags->fmk);
if (converter == nullptr) {
oss.clear();
oss << "UNSUPPORTED FMKTYPE " << flags->fmk << ":" << RET_INPUT_PARAM_INVALID << " "
<< GetErrorInfo(RET_INPUT_PARAM_INVALID);
MS_LOG(ERROR) << oss.str();
std::cout << oss.str() << std::endl;
return RET_INPUT_PARAM_INVALID;
}
auto meta_graph = converter->Convert(flags);
Converter cvt;
auto meta_graph = cvt.Convert(flags);
NotSupportOp::GetInstance()->PrintOps();
status = ReturnCode::GetSingleReturnCode()->status_code();
if (meta_graph == nullptr) {

View File

@ -22,6 +22,7 @@
#include "schema/inner/model_generated.h"
#include "tools/converter/graphdef_transform.h"
#include "tools/converter/model_parser.h"
#include "tools/converter/registry/model_parser_registry.h"
#include "tools/converter/converter_flags.h"
#include "tools/converter/anf_transform.h"
#include "tools/converter/converter_context.h"
@ -32,47 +33,17 @@ namespace mindspore {
namespace lite {
class Converter {
public:
static std::unique_ptr<Converter> CreateConverter(converter::FmkType fmk);
virtual ~Converter() = default;
virtual schema::MetaGraphT *Convert(const std::unique_ptr<converter::Flags> &flag);
virtual FuncGraphPtr BuildFuncGraph(const std::string &model_file, const std::string &weight_file,
schema::QuantType quant_type) = 0;
Converter() = default;
~Converter() { delete model_parser_; }
schema::MetaGraphT *Convert(const std::unique_ptr<converter::Flags> &flag);
FuncGraphPtr BuildFuncGraph(const converter::Flags &flag);
protected:
Converter() = default;
ModelParser *model_parser_ = nullptr;
std::unique_ptr<GraphDefTransform> metagraph_transform_ = std::make_unique<GraphDefTransform>();
std::unique_ptr<AnfTransform> funcgraph_transform_ = std::make_unique<AnfTransform>();
};
class MindsporeImporter : public Converter {
public:
MindsporeImporter();
~MindsporeImporter() override = default;
FuncGraphPtr BuildFuncGraph(const std::string &model_file, const std::string &weight_file,
schema::QuantType quant_type) override {
auto func_graph = LoadMindIR(model_file);
if (func_graph == nullptr) {
MS_LOG(ERROR) << "get funcgraph failed.";
return nullptr;
}
func_graph->set_attr("graph_name", MakeValue("main_graph"));
func_graph->set_attr("fmk", MakeValue(static_cast<int>(converter::FmkType_MS)));
auto status = UpdateFuncGraphInputsAndOutputsDtype(func_graph);
if (RET_OK != status) {
MS_LOG(ERROR) << "update graph inputs and outputs dtype failed.";
return nullptr;
}
return func_graph;
}
};
int RunConverter(int argc, const char **argv);
} // namespace lite
} // namespace mindspore

View File

@ -15,9 +15,13 @@
*/
#include "tools/converter/converter_flags.h"
#include <regex>
#include <climits>
#include <cstdlib>
#include <string>
#include <algorithm>
#include <fstream>
#include <vector>
#include <memory>
#include "ir/dtype/type_id.h"
namespace mindspore {
@ -180,6 +184,28 @@ int Flags::InitTrainModel() {
return RET_OK;
}
int Flags::InitConfigFile() {
auto plugins_path_str = GetStrFromConfigFile(this->configFile, "plugin_path");
if (!plugins_path_str.empty()) {
const char *delimiter = ";";
this->pluginsPath = SplitStringToVector(plugins_path_str, *delimiter);
}
auto disable_fusion_flag = GetStrFromConfigFile(this->configFile, "disable_fusion");
if (!disable_fusion_flag.empty()) {
if (disable_fusion_flag == "on") {
this->disableFusion = true;
} else if (disable_fusion_flag == "off") {
this->disableFusion = false;
} else {
std::cerr << "CONFIG SETTING ILLEGAL: disable_fusion should be on/off";
return RET_INPUT_PARAM_INVALID;
}
}
return RET_OK;
}
int Flags::Init(int argc, const char **argv) {
int ret;
if (argc == 1) {
@ -222,6 +248,14 @@ int Flags::Init(int argc, const char **argv) {
return RET_INPUT_PARAM_INVALID;
}
if (!this->configFile.empty()) {
ret = InitConfigFile();
if (ret != RET_OK) {
std::cerr << "Init config file failed.";
return RET_INPUT_PARAM_INVALID;
}
}
ret = InitInputOutputDataType();
if (ret != RET_OK) {
std::cerr << "Init input output datatype failed.";
@ -248,6 +282,80 @@ int Flags::Init(int argc, const char **argv) {
return RET_OK;
}
// Reads `file` line by line looking for a `key=value` entry whose key equals
// `target_key`. Returns the trimmed value, or an empty string when the path
// is invalid, the file is unreadable/malformed, or the key is absent.
std::string GetStrFromConfigFile(const std::string &file, const std::string &target_key) {
  std::string res;
  if (file.empty()) {
    MS_LOG(ERROR) << "file is nullptr";
    return res;
  }
  auto resolved_path = std::make_unique<char[]>(PATH_MAX);
  if (resolved_path == nullptr) {
    MS_LOG(ERROR) << "new resolved_path failed";
    return "";
  }
#ifdef _WIN32
  // FIX: pass the real buffer capacity. The buffer is PATH_MAX bytes, so
  // telling _fullpath it holds 1024 was inconsistent (and would overflow if
  // PATH_MAX were ever smaller than 1024).
  char *real_path = _fullpath(resolved_path.get(), file.c_str(), PATH_MAX);
#else
  char *real_path = realpath(file.c_str(), resolved_path.get());
#endif
  if (real_path == nullptr || strlen(real_path) == 0) {
    MS_LOG(ERROR) << "file path is not valid : " << file;
    return "";
  }
  std::ifstream ifs(resolved_path.get());
  if (!ifs.good()) {
    MS_LOG(ERROR) << "file: " << real_path << " is not exist";
    return res;
  }
  if (!ifs.is_open()) {
    MS_LOG(ERROR) << "file: " << real_path << "open failed";
    return res;
  }
  std::string line;
  while (std::getline(ifs, line)) {
    lite::Trim(&line);
    if (line.empty()) {
      continue;  // skip blank lines
    }
    auto index = line.find('=');
    if (index == std::string::npos) {
      // Any non-blank line without '=' makes the whole file invalid.
      MS_LOG(ERROR) << "the config file is invalid, can not find '=', please check";
      return "";
    }
    auto key = line.substr(0, index);
    auto value = line.substr(index + 1);
    lite::Trim(&key);
    lite::Trim(&value);
    if (key == target_key) {
      return value;
    }
  }
  return res;  // key not found: empty string
}
// Splits `raw_str` on `delimiter`. A trailing delimiter does not produce an
// empty tail element (a leading one does produce an empty head, matching the
// original behavior). Empty input is reported and yields an empty vector.
std::vector<std::string> SplitStringToVector(const std::string &raw_str, const char &delimiter) {
  if (raw_str.empty()) {
    MS_LOG(ERROR) << "input string is empty.";
    return {};
  }
  std::vector<std::string> res;
  std::string::size_type last_pos = 0;
  auto cur_pos = raw_str.find(delimiter);
  while (cur_pos != std::string::npos) {
    res.push_back(raw_str.substr(last_pos, cur_pos - last_pos));
    last_pos = cur_pos + 1;  // resume just past the delimiter
    cur_pos = raw_str.find(delimiter, last_pos);
  }
  if (last_pos < raw_str.size()) {
    // FIX: the original requested `size() - last_pos + 1` characters — one
    // past the remaining tail; substr() clamped it, but the length was wrong.
    // Omitting the length takes exactly the rest of the string.
    res.push_back(raw_str.substr(last_pos));
  }
  return res;
}
} // namespace converter
} // namespace lite
} // namespace mindspore

View File

@ -18,6 +18,7 @@
#define MINDSPORE_LITE_TOOLS_CONVERTER_CONVERTER_FLAGS_H
#include <string>
#include <vector>
#include "tools/common/flag_parser.h"
#include "ir/dtype/type_id.h"
#include "schema/inner/model_generated.h"
@ -57,6 +58,8 @@ class Flags : public virtual mindspore::lite::FlagParser {
int InitTrainModel();
int InitConfigFile();
int Init(int argc, const char **argv);
public:
@ -83,7 +86,13 @@ class Flags : public virtual mindspore::lite::FlagParser {
int quantWeightChannel;
std::string trainModelIn;
bool trainModel = false;
std::vector<std::string> pluginsPath;
bool disableFusion = false;
};
std::string GetStrFromConfigFile(const std::string &file, const std::string &target_key);
std::vector<std::string> SplitStringToVector(const std::string &raw_str, const char &delimiter);
} // namespace converter
} // namespace lite
} // namespace mindspore

View File

@ -34,8 +34,8 @@ class ModelParser {
virtual ~ModelParser() = default;
FuncGraphPtr Parse(const std::string &model_file, const std::string &weight_file, const QuantType &quant_type) {
auto ret = ParseToFuncGraph(model_file, weight_file, quant_type);
FuncGraphPtr Parse(const std::string &model_file, const std::string &weight_file) {
auto ret = ParseToFuncGraph(model_file, weight_file);
if (ret != RET_OK) {
MS_LOG(ERROR) << "Parse to func graph failed : " << ret;
return nullptr;
@ -49,14 +49,25 @@ class ModelParser {
}
protected:
virtual int ParseToFuncGraph(const std::string &model_file, const std::string &weight_file,
const QuantType &quant_type) = 0;
virtual int ParseToFuncGraph(const std::string &model_file, const std::string &weight_file) = 0;
virtual int PostAdjust() = 0;
protected:
FuncGraphPtr res_graph_ = nullptr;
};
typedef ModelParser *(*ModelParserCreator)();
// Factory helper used by REG_MODEL_PARSER: instantiates the concrete
// ModelParser subclass T. Returns nullptr (after logging) when the
// allocation fails instead of throwing.
template <class T>
ModelParser *LiteModelParserCreator() {
  ModelParser *created = new (std::nothrow) T();
  if (created != nullptr) {
    return created;
  }
  MS_LOG(ERROR) << "new model parser failed";
  return nullptr;
}
} // namespace mindspore::lite
#endif

View File

@ -1,41 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_CAFFE_CAFFE_CONVERTER_H_
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_CAFFE_CAFFE_CONVERTER_H_
#include <string>
#include <memory>
#include "tools/converter/converter.h"
#include "tools/converter/graphdef_transform.h"
#include "tools/converter/parser/caffe/caffe_model_parser.h"
namespace mindspore::lite {
// Converter specialization for Caffe models: builds the FuncGraph by
// delegating to CaffeModelParser. (Removed by this commit in favor of the
// ModelParserRegistry mechanism.)
class CaffeConverter : public Converter {
public:
CaffeConverter() = default;
~CaffeConverter() override = default;
// Parses `model_file` (prototxt) plus `weight_file` (caffemodel) with the
// given quant type and returns the resulting graph, or nullptr on failure.
FuncGraphPtr BuildFuncGraph(const std::string &model_file, const std::string &weight_file,
schema::QuantType quant_type) override {
CaffeModelParser parser;
return parser.Parse(model_file, weight_file, quant_type);
}
};
} // namespace mindspore::lite
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_CAFFE_CAFFE_CONVERTER_H_

View File

@ -26,6 +26,8 @@
#include "tools/converter/ops/ops_def.h"
#include "ir/func_graph.h"
#include "tools/converter/converter_flags.h"
#include "tools/converter/converter_context.h"
#include "tools/converter/quant_param_holder.h"
namespace mindspore::lite {
bool IsSkipedLayer(const caffe::LayerParameter &layer) {
@ -56,8 +58,7 @@ CaffeModelParser::CaffeModelParser() = default;
CaffeModelParser::~CaffeModelParser() = default;
int CaffeModelParser::ParseToFuncGraph(const std::string &model_file, const std::string &weight_file,
const QuantType &quant_type) {
int CaffeModelParser::ParseToFuncGraph(const std::string &model_file, const std::string &weight_file) {
STATUS status = InitOriginModel(model_file, weight_file);
if (status != RET_OK) {
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
@ -493,4 +494,6 @@ std::string CaffeModelParser::GetOriginLayerName(const std::string &layer_name)
}
int CaffeModelParser::PostAdjust() { return RET_OK; }
REG_MODEL_PARSER(CAFFE, LiteModelParserCreator<CaffeModelParser>)
} // namespace mindspore::lite

View File

@ -22,9 +22,11 @@
#include <set>
#include <unordered_map>
#include "tools/converter/model_parser.h"
#include "tools/converter/registry/model_parser_registry.h"
#include "proto/caffe.pb.h"
#include "ops/primitive_c.h"
using STATUS = int;
namespace mindspore::lite {
class CaffeModelParser : public ModelParser {
public:
@ -32,8 +34,7 @@ class CaffeModelParser : public ModelParser {
~CaffeModelParser() override;
int ParseToFuncGraph(const std::string &model_file, const std::string &weight_file,
const QuantType &quant_type) override;
int ParseToFuncGraph(const std::string &model_file, const std::string &weight_file) override;
int PostAdjust() override;

View File

@ -34,7 +34,7 @@ STATUS OnnxConstantParser::AddDataInfoAttr(const onnx::TensorProto &onnx_const_t
std::vector<int64_t> shape_vector(onnx_const_tensor.dims().begin(), onnx_const_tensor.dims().end());
auto tensor_info = std::make_shared<tensor::Tensor>(data_type, shape_vector);
if (tensor_info == nullptr) {
MS_LOG(ERROR) << "new a paramValueLite failed.";
MS_LOG(ERROR) << "new a tensor::Tensor failed.";
return RET_ERROR;
}
if (OnnxModelParser::CopyOnnxTensorData(onnx_const_tensor, tensor_info) != RET_OK) {

View File

@ -1,43 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_CONVERTER_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_CONVERTER_H
#include <string>
#include <memory>
#include "tools/converter/converter.h"
#include "tools/converter/graphdef_transform.h"
#include "tools/converter/parser/onnx/onnx_model_parser.h"
namespace mindspore {
namespace lite {
// Converter specialization for ONNX models: builds the FuncGraph by
// delegating to OnnxModelParser. (Removed by this commit in favor of the
// ModelParserRegistry mechanism.)
class OnnxConverter : public Converter {
public:
OnnxConverter() = default;
~OnnxConverter() override = default;
// Parses the ONNX `model_file` (weight_file unused by ONNX) with the given
// quant type and returns the resulting graph, or nullptr on failure.
FuncGraphPtr BuildFuncGraph(const std::string &model_file, const std::string &weight_file,
schema::QuantType quant_type) override {
OnnxModelParser parser;
return parser.Parse(model_file, weight_file, quant_type);
}
};
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_CONVERTER_H

View File

@ -21,7 +21,6 @@
#include <unordered_map>
#include <utility>
#include "tools/optimizer/common/gllo_utils.h"
#include "src/common/utils.h"
#include "tools/common/graph_util.h"
#include "tools/common/protobuf_utils.h"
#include "tools/common/tensor_util.h"
@ -29,6 +28,8 @@
#include "ops/tensor_list_stack.h"
#include "ir/func_graph.h"
#include "tools/converter/converter_flags.h"
#include "tools/converter/quant_param_holder.h"
#include "tools/converter/converter_context.h"
namespace mindspore {
namespace lite {
@ -43,8 +44,7 @@ static const std::unordered_map<int, mindspore::TypeId> TYPE_MAP = {
{onnx::TensorProto_DataType_FLOAT, mindspore::kNumberTypeFloat32},
{onnx::TensorProto_DataType_BOOL, mindspore::kNumberTypeBool}};
int OnnxModelParser::ParseToFuncGraph(const std::string &model_file, const std::string &weight_file,
const QuantType &quant_type) {
int OnnxModelParser::ParseToFuncGraph(const std::string &model_file, const std::string &weight_file) {
NotSupportOp::GetInstance()->set_fmk_type("ONNX");
res_graph_ = std::make_shared<FuncGraph>();
auto status = InitOriginModel(model_file);
@ -723,7 +723,7 @@ ParameterPtr CreateConstParamter(const FuncGraphPtr &anf_graph, int val) {
return const_node;
}
ValueNodePtr CreateValueNode(const PrimitiveType &op_type) {
ValueNodePtr CreateValueNode(const schema::PrimitiveType &op_type) {
auto node_type = schema::EnumNamePrimitiveType(op_type);
auto op_primc_fns = ops::OpPrimCRegister::GetInstance().GetPrimCMap();
if (op_primc_fns.find(node_type) == op_primc_fns.end()) {
@ -1178,5 +1178,7 @@ TypeId OnnxModelParser::GetDataTypeFromOnnx(onnx::TensorProto_DataType onnx_type
}
int OnnxModelParser::PostAdjust() { return 0; }
REG_MODEL_PARSER(ONNX, LiteModelParserCreator<OnnxModelParser>)
} // namespace lite
} // namespace mindspore

View File

@ -29,8 +29,10 @@
#include <unordered_map>
#include "securec/include/securec.h"
#include "tools/converter/model_parser.h"
#include "tools/converter/registry/model_parser_registry.h"
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
#include "proto/onnx.pb.h"
#include "schema/inner/model_generated.h"
namespace mindspore {
namespace lite {
@ -40,8 +42,7 @@ class OnnxModelParser : public ModelParser {
~OnnxModelParser() override = default;
int ParseToFuncGraph(const std::string &model_file, const std::string &weight_file,
const QuantType &quant_type) override;
int ParseToFuncGraph(const std::string &model_file, const std::string &weight_file) override;
int PostAdjust() override;
@ -73,9 +74,9 @@ class OnnxModelParser : public ModelParser {
std::unordered_map<std::string, AnfNodePtr> *anf_nodes_map, const CNodePtr &cnode);
STATUS ConvertOpQuantParams(const onnx::NodeProto &onnx_node, ops::PrimitiveC *primitive_c);
STATUS ParseQuantParam(const onnx::NodeProto &onnx_node);
STATUS SetTensorQuantParam(const std::string &tensor_name, std::vector<QuantParamT> *quant_params);
STATUS SetTensorQuantParamFromNode(const std::string &tensor_name, std::vector<QuantParamT> *quant_params);
STATUS CopyTensorQuantParam(const std::string &tensor_name, QuantParamT *quant_param, bool scale_or_not);
STATUS SetTensorQuantParam(const std::string &tensor_name, std::vector<schema::QuantParamT> *quant_params);
STATUS SetTensorQuantParamFromNode(const std::string &tensor_name, std::vector<schema::QuantParamT> *quant_params);
STATUS CopyTensorQuantParam(const std::string &tensor_name, schema::QuantParamT *quant_param, bool scale_or_not);
STATUS ConvertLoopOnnxNode(const onnx::NodeProto &onnx_node,
std::unordered_map<std::string, AnfNodePtr> *anf_nodes_map,
const std::string &root_node_name);

View File

@ -1,40 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_CONVERTER_H_
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_CONVERTER_H_
#include <string>
#include <memory>
#include "tools/converter/converter.h"
#include "tools/converter/parser/tf/tf_model_parser.h"
namespace mindspore {
namespace lite {
// Converter specialization for TensorFlow models: builds the FuncGraph by
// delegating to TFModelParser. (Removed by this commit in favor of the
// ModelParserRegistry mechanism.)
class TFConverter : public Converter {
public:
TFConverter() = default;
~TFConverter() override = default;
// Parses the TF `model_file` (.pb) with the given quant type and returns
// the resulting graph, or nullptr on failure.
FuncGraphPtr BuildFuncGraph(const std::string &model_file, const std::string &weight_file,
schema::QuantType quant_type) override {
TFModelParser parser;
return parser.Parse(model_file, weight_file, quant_type);
}
};
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_CONVERTER_H_

View File

@ -28,6 +28,7 @@
#include "ir/anf.h"
#include "abstract/utils.h"
#include "tools/converter/converter_flags.h"
#include "tools/converter/quant_param_holder.h"
namespace mindspore {
namespace lite {
@ -472,8 +473,7 @@ STATUS TFModelParser::ConvertGraphInputsAndConsts(
return RET_OK;
}
int TFModelParser::ParseToFuncGraph(const std::string &modelFile, const std::string &weightFile,
const QuantType &quantType) {
int TFModelParser::ParseToFuncGraph(const std::string &modelFile, const std::string &weightFile) {
NotSupportOp::GetInstance()->set_fmk_type("TF");
auto status = ValidateFileStr(modelFile, ".pb");
if (status != RET_OK) {
@ -559,8 +559,8 @@ int TFModelParser::ParseToFuncGraph(const std::string &modelFile, const std::str
STATUS TFModelParser::ConvertSubgraphInputs(std::map<std::string, const tensorflow::NodeDef *> *tf_sub_node_map,
std::unordered_map<std::string, AnfNodePtr> *anf_sub_node_map,
const tensorflow::FunctionDef &tf_sub_fuction, CNodePtr cnode,
FuncGraphPtr sub_func_graph) {
const tensorflow::FunctionDef &tf_sub_fuction, const CNodePtr &cnode,
const FuncGraphPtr &sub_func_graph) {
std::vector<ParameterPtr> sub_graph_inputs;
auto &tf_sub_signature = tf_sub_fuction.signature();
auto &sub_graph_name = tf_sub_signature.name();
@ -598,7 +598,7 @@ STATUS TFModelParser::ConvertSubgraphInputs(std::map<std::string, const tensorfl
STATUS TFModelParser::ConvertSubgraphOutputs(std::map<std::string, const tensorflow::NodeDef *> *tf_sub_node_map,
const std::unordered_map<std::string, AnfNodePtr> &anf_sub_node_map,
const tensorflow::FunctionDef &tf_sub_fuction,
FuncGraphPtr sub_func_graph) {
const FuncGraphPtr &sub_func_graph) {
auto &tf_sub_signature = tf_sub_fuction.signature();
auto &sub_graph_name = tf_sub_signature.name();
@ -949,7 +949,7 @@ STATUS TFModelParser::ConvertOps(const tensorflow::NodeDef &node_def,
return status;
}
STATUS TFModelParser::ProcessControlFlowOp(CNodePtr anf_node, const string op_type,
STATUS TFModelParser::ProcessControlFlowOp(const CNodePtr &anf_node, const string &op_type,
const tensorflow::NodeDef &node_def) {
if (op_type == "StatelessWhile" || op_type == "While") {
MS_LOG(INFO) << "find while node:" << node_def.name();
@ -1077,5 +1077,7 @@ STATUS TFModelParser::MakeAnfGraphOutputs(std::vector<AnfNodePtr> *output_nodes,
}
int TFModelParser::PostAdjust() { return 0; }
REG_MODEL_PARSER(TF, LiteModelParserCreator<TFModelParser>)
} // namespace lite
} // namespace mindspore

View File

@ -29,6 +29,7 @@
#include "securec/include/securec.h"
#include "tools/common/tensor_util.h"
#include "tools/converter/model_parser.h"
#include "tools/converter/registry/model_parser_registry.h"
#include "ops/primitive_c.h"
namespace mindspore {
@ -38,14 +39,15 @@ class TFModelParser : public ModelParser {
TFModelParser() = default;
~TFModelParser() override = default;
int ParseToFuncGraph(const std::string &modelFile, const std::string &weightFile, const QuantType &quantType);
int ParseToFuncGraph(const std::string &modelFile, const std::string &weightFile) override;
int PostAdjust() override;
private:
static STATUS ConvertConstVariant(const tensorflow::TensorProto &tensor_proto, tensor::TensorPtr *tensor_info);
STATUS ConvertConstTensor(const tensorflow::NodeDef &node_def, const tensorflow::AttrValue &attr_value,
const TypeId &type, const ParameterPtr &parameter, std::vector<int64_t> *shape_vector);
static STATUS ConvertConstTensor(const tensorflow::NodeDef &node_def, const tensorflow::AttrValue &attr_value,
const TypeId &type, const ParameterPtr &parameter,
std::vector<int64_t> *shape_vector);
static STATUS SetTensorInfoFromType(const tensorflow::TensorProto &tensor_proto, tensor::TensorPtr *tensor_info);
STATUS ConvertParameter(const tensorflow::NodeDef &node, const ParameterPtr &parameter,
std::unordered_map<std::string, AnfNodePtr> *anf_node_map);
@ -63,7 +65,7 @@ class TFModelParser : public ModelParser {
const std::map<std::string, const tensorflow::NodeDef *> &tf_node_map,
const FuncGraphPtr &func_graph_ptr, std::unordered_map<std::string, AnfNodePtr> *anf_node_map);
STATUS ProcessControlFlowOp(CNodePtr anf_node, const string op_type, const tensorflow::NodeDef &node_def);
STATUS ProcessControlFlowOp(const CNodePtr &anf_node, const string &op_type, const tensorflow::NodeDef &node_def);
STATUS ConvertRootGraphOutputs();
@ -71,17 +73,18 @@ class TFModelParser : public ModelParser {
STATUS ConvertSubgraphInputs(std::map<std::string, const tensorflow::NodeDef *> *tf_sub_node_map,
std::unordered_map<std::string, AnfNodePtr> *anf_sub_node_map,
const tensorflow::FunctionDef &tf_sub_fuction, CNodePtr cnode,
FuncGraphPtr sub_func_graph);
const tensorflow::FunctionDef &tf_sub_fuction, const CNodePtr &cnode,
const FuncGraphPtr &sub_func_graph);
static STATUS ConvertSubgraphOutputs(std::map<std::string, const tensorflow::NodeDef *> *tf_sub_node_map,
const std::unordered_map<std::string, AnfNodePtr> &anf_sub_node_map,
const tensorflow::FunctionDef &tf_sub_fuction, FuncGraphPtr sub_func_graph);
const tensorflow::FunctionDef &tf_sub_fuction,
const FuncGraphPtr &sub_func_graph);
STATUS ControlFlowNodePostProcess(const std::map<CNodePtr, FuncGraphPtr> &first_func_map,
const std::map<CNodePtr, FuncGraphPtr> &second_func_map);
STATUS ConvertQuantParams(const size_t &input_size, const size_t &output_size, ops::PrimitiveC *primitive_c);
static STATUS ConvertQuantParams(const size_t &input_size, const size_t &output_size, ops::PrimitiveC *primitive_c);
static STATUS MakeAnfGraphOutputs(std::vector<AnfNodePtr> *output_nodes, const FuncGraphPtr &anf_graph);

View File

@ -1,42 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_CONVERTER_H_
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_CONVERTER_H_
#include <string>
#include <memory>
#include <map>
#include "tools/converter/converter.h"
#include "tools/converter/graphdef_transform.h"
#include "tools/converter/parser/tflite/tflite_model_parser.h"
namespace mindspore::lite {
// Converter specialization for the TFLite model format.
// Builds the initial FuncGraph by delegating to TfliteModelParser.
class TfliteConverter : public Converter {
 public:
  TfliteConverter() = default;
  ~TfliteConverter() override = default;
  // Parses `model_file` (.tflite flatbuffer) into a FuncGraph.
  // `weight_file` is forwarded to the parser unchanged; TFLite models carry
  // their weights inside the model file, so it is presumably unused -- confirm.
  FuncGraphPtr BuildFuncGraph(const std::string &model_file, const std::string &weight_file,
                              schema::QuantType quant_type) override {
    TfliteModelParser parser;
    return parser.Parse(model_file, weight_file, quant_type);
  }
};
} // namespace mindspore::lite
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_CONVERTER_H_

View File

@ -19,12 +19,14 @@
#include <memory>
#include <algorithm>
#include <utility>
#include "tools/converter/converter_flags.h"
#include "src/common/file_utils.h"
#include "tools/converter/ops/ops_def.h"
#include "ops/primitive_c.h"
#include "ir/func_graph.h"
#include "src/common/file_utils.h"
#include "tools/converter/ops/ops_def.h"
#include "tools/common/graph_util.h"
#include "tools/converter/quant_param_holder.h"
#include "tools/converter/converter_context.h"
#include "tools/converter/converter_flags.h"
namespace mindspore::lite {
std::unique_ptr<tflite::ModelT> TfliteModelParser::ReadTfliteModel(const char *model_path) {
@ -42,8 +44,7 @@ std::unique_ptr<tflite::ModelT> TfliteModelParser::ReadTfliteModel(const char *m
return tflite::UnPackModel(tflite_model_buf_);
}
int TfliteModelParser::ParseToFuncGraph(const std::string &model_file, const std::string &weight_file,
const QuantType &quant_type) {
int TfliteModelParser::ParseToFuncGraph(const std::string &model_file, const std::string &weight_file) {
// load graph
tflite_model_ = ReadTfliteModel(model_file.c_str());
if (tflite_model_ == nullptr) {
@ -489,4 +490,6 @@ STATUS TfliteModelParser::ConvertOutputTensor(const tflite::OperatorT *op, const
}
int TfliteModelParser::PostAdjust() { return 0; }
REG_MODEL_PARSER(TFLITE, LiteModelParserCreator<TfliteModelParser>)
} // namespace mindspore::lite

View File

@ -21,6 +21,7 @@
#include <memory>
#include <vector>
#include "tools/converter/model_parser.h"
#include "tools/converter/registry/model_parser_registry.h"
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"
#include "tools/common/tensor_util.h"
@ -32,8 +33,7 @@ class TfliteModelParser : public ModelParser {
~TfliteModelParser() override = default;
int ParseToFuncGraph(const std::string &model_file, const std::string &weight_file,
const QuantType &quant_type) override;
int ParseToFuncGraph(const std::string &model_file, const std::string &weight_file) override;
int PostAdjust() override;

View File

@ -0,0 +1,10 @@
# Sources of the converter registry plugin (model-parser and pass registries).
# Listed explicitly rather than with file(GLOB): CMake's documentation advises
# against GLOB for source collection because newly added files are not picked
# up until CMake is re-run.
set(CONVERT_REG_SRC
    ${CMAKE_CURRENT_SOURCE_DIR}/model_parser_registry.cc
    ${CMAKE_CURRENT_SOURCE_DIR}/pass_registry.cc
    )
set_property(SOURCE ${CONVERT_REG_SRC} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_LITE)
add_library(mslite_converter_plugin_reg SHARED ${CONVERT_REG_SRC})
# Generated flatbuffer schema headers must exist before these sources compile.
add_dependencies(mslite_converter_plugin_reg fbs_src)
add_dependencies(mslite_converter_plugin_reg fbs_inner_src)

View File

@ -0,0 +1,43 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/registry/model_parser_registry.h"
#include <string>
#include <unordered_map>
namespace mindspore {
namespace lite {
// Returns the process-wide registry.  Meyers singleton: the function-local
// static is initialized on first use, thread-safely under C++11 rules.
ModelParserRegistry *ModelParserRegistry::GetInstance() {
  static ModelParserRegistry instance;
  return &instance;
}
// Looks up the creator registered under framework name `fmk` and invokes it.
// Returns a freshly created parser (ownership passes to the caller), or
// nullptr when no creator has been registered for that framework.
ModelParser *ModelParserRegistry::GetModelParser(const std::string &fmk) {
  auto iter = parsers_.find(fmk);
  if (iter == parsers_.end()) {
    return nullptr;
  }
  return iter->second();
}
void ModelParserRegistry::RegParser(const std::string &fmk, ModelParserCreator creator) {
auto instance = ModelParserRegistry::GetInstance();
instance->parsers_[fmk] = creator;
}
} // namespace lite
} // namespace mindspore

View File

@ -0,0 +1,51 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_MODEL_PARSER_REGISTRY_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_MODEL_PARSER_REGISTRY_H
#include <string>
#include <memory>
#include <unordered_map>
#include "include/lite_utils.h"
namespace mindspore::lite {
class MS_API ModelParser;
typedef ModelParser *(*ModelParserCreator)();
// Process-wide registry mapping framework names (e.g. "TF", "TFLITE") to
// model-parser factory functions.  Populated during static initialization
// through the REG_MODEL_PARSER macro / ModelRegistrar helper.
class MS_API ModelParserRegistry {
 public:
  ModelParserRegistry() = default;
  ~ModelParserRegistry() = default;
  // Returns the singleton instance (thread-safe magic static).
  static ModelParserRegistry *GetInstance();
  // Creates and returns a new parser for `fmk` (caller owns it),
  // or nullptr when nothing is registered under that name.
  ModelParser *GetModelParser(const std::string &fmk);
  // Registers (or replaces) the creator for `fmk`.
  void RegParser(const std::string &fmk, ModelParserCreator creator);
  // NOTE(review): only touched by the member functions above, so this map
  // could be made private; left public to avoid breaking unseen callers.
  std::unordered_map<std::string, ModelParserCreator> parsers_;
};
// Registrar helper: constructing one registers `creator` under `fmk` in the
// global ModelParserRegistry.  Intended to be instantiated as a file-scope
// static through REG_MODEL_PARSER so registration runs at static-init time.
class MS_API ModelRegistrar {
 public:
  ModelRegistrar(const std::string &fmk, ModelParserCreator creator) {
    ModelParserRegistry::GetInstance()->RegParser(fmk, creator);
  }
  ~ModelRegistrar() = default;
};
// Registers a parser creator for the given framework at static-init time.
// The original expansion was g_##type##fmk##ModelParserReg, but `type` is not
// a macro parameter, so the literal token "type" was pasted into the variable
// name (g_typeTFModelParserReg).  The name is derived from `fmk` alone, which
// keeps it unique per framework and removes the meaningless token.
#define REG_MODEL_PARSER(fmk, parserCreator) static ModelRegistrar g_##fmk##ModelParserReg(#fmk, parserCreator);
} // namespace mindspore::lite
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_MODEL_PARSER_REGISTRY_H

View File

@ -0,0 +1,45 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/registry/pass_registry.h"
#include <iostream>
#include <unordered_map>
namespace mindspore {
namespace opt {
// Returns the process-wide pass registry.  Meyers singleton: the local static
// is initialized on first use, thread-safely under C++11; the private
// constructor (see header) makes this the only way to obtain an instance.
PassRegistry *PassRegistry::GetInstance() {
  static PassRegistry instance;
  return &instance;
}
void PassRegistry::RegPass(int position, const PassPtr &pass) {
if (pass == nullptr) {
std::cout << "pass is nullptr" << std::endl;
return;
}
auto instance = PassRegistry::GetInstance();
std::unique_lock<std::mutex> lock(instance->mutex_);
instance->passes_[position] = pass;
}
// Returns the registered passes keyed by position.
// GetInstance() is used instead of `this` because locking the member mutex_
// directly inside a const member function would require it to be mutable.
// NOTE(review): the lock is released when this function returns, yet the
// returned reference lets callers read passes_ afterwards with no
// synchronization -- the mutex does not actually protect readers racing with
// RegPass.  A real fix (returning a copy) changes the signature; flagged for
// follow-up rather than changed here.
const std::unordered_map<int, PassPtr> &PassRegistry::GetPasses() const {
  auto instance = PassRegistry::GetInstance();
  std::unique_lock<std::mutex> lock(instance->mutex_);
  return instance->passes_;
}
} // namespace opt
} // namespace mindspore

View File

@ -0,0 +1,60 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_REGISTRY_PASS_REGISTRY_H_
#define MINDSPORE_LITE_TOOLS_CONVERTER_REGISTRY_PASS_REGISTRY_H_
#include <vector>
#include <string>
#include <utility>
#include <mutex>
#include <memory>
#include <unordered_map>
#include "include/lite_utils.h"
namespace mindspore {
namespace opt {
// Hook slots for externally registered passes: POSITION_BEGIN presumably runs
// before the built-in optimization passes and POSITION_END after them --
// confirm exact scheduling against the pass-runner call site.
enum MS_API PassPosition { POSITION_BEGIN = 0, POSITION_END = 1 };
// Pass is defined elsewhere in the optimizer; only the pointer type is needed.
class MS_API Pass;
using PassPtr = std::shared_ptr<Pass>;
// Singleton registry of externally supplied optimization passes, keyed by the
// slot (PassPosition) they should run in.  The constructor is private, so
// GetInstance() is the only way to obtain an instance.
class MS_API PassRegistry {
 public:
  virtual ~PassRegistry() = default;
  // Returns the process-wide instance (thread-safe magic static).
  static PassRegistry *GetInstance();
  // Registers `pass` at `position`, replacing any pass already there.
  // Null passes are rejected.
  void RegPass(int position, const PassPtr &pass);
  // Returns a reference to the internal pass map.
  // NOTE(review): callers holding this reference race with concurrent
  // RegPass calls despite the internal mutex; see the implementation.
  const std::unordered_map<int, PassPtr> &GetPasses() const;

 private:
  PassRegistry() = default;

 private:
  std::unordered_map<int, PassPtr> passes_;
  // Guards passes_ during registration.
  std::mutex mutex_;
};
// Registrar helper: constructing one registers `pass` at slot `pos` in the
// global PassRegistry.  Instantiate as a file-scope static via REG_PASS so
// registration happens during static initialization.
class MS_API PassRegistrar {
 public:
  PassRegistrar(int pos, const PassPtr &pass) { PassRegistry::GetInstance()->RegPass(pos, pass); }
  ~PassRegistrar() = default;
};
// Registers a default-constructible Pass subclass at `position`.
// `position` must be a single identifier or literal: it is token-pasted into
// the registrar variable name, so an expression will not compile.
#define REG_PASS(position, pass) static PassRegistrar g_##position##PassReg(position, std::make_shared<pass>());
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_REGISTRY_PASS_REGISTRY_H_