diff --git a/cmake/package_lite.cmake b/cmake/package_lite.cmake index 835689b2309..adf5104806c 100644 --- a/cmake/package_lite.cmake +++ b/cmake/package_lite.cmake @@ -217,6 +217,8 @@ elseif(WIN32) install(FILES ${LIB_LIST} DESTINATION ${CONVERTER_ROOT_DIR}/lib COMPONENT ${RUNTIME_COMPONENT_NAME}) install(FILES ${TOP_DIR}/build/mindspore/tools/converter/mindspore_core/gvar/libmindspore_gvar.dll DESTINATION ${CONVERTER_ROOT_DIR}/lib COMPONENT ${RUNTIME_COMPONENT_NAME}) + install(FILES ${TOP_DIR}/build/mindspore/tools/converter/registry/libmslite_converter_plugin_reg.dll + DESTINATION ${CONVERTER_ROOT_DIR}/lib COMPONENT ${RUNTIME_COMPONENT_NAME}) install(FILES ${glog_LIBPATH}/../bin/libglog.dll DESTINATION ${CONVERTER_ROOT_DIR}/lib COMPONENT ${RUNTIME_COMPONENT_NAME}) install(FILES ${NNACL_FILES} DESTINATION ${CODEGEN_ROOT_DIR}/include/nnacl COMPONENT ${RUNTIME_COMPONENT_NAME}) @@ -286,6 +288,8 @@ else() COMPONENT ${RUNTIME_COMPONENT_NAME}) install(FILES ${TOP_DIR}/mindspore/lite/build/tools/converter/mindspore_core/gvar/libmindspore_gvar.so DESTINATION ${CONVERTER_ROOT_DIR}/lib COMPONENT ${RUNTIME_COMPONENT_NAME}) + install(FILES ${TOP_DIR}/mindspore/lite/build/tools/converter/registry/libmslite_converter_plugin_reg.so + DESTINATION ${CONVERTER_ROOT_DIR}/lib COMPONENT ${RUNTIME_COMPONENT_NAME}) install(FILES ${glog_LIBPATH}/libglog.so.0.4.0 DESTINATION ${CONVERTER_ROOT_DIR}/third_party/glog/lib RENAME libglog.so.0 COMPONENT ${RUNTIME_COMPONENT_NAME}) diff --git a/mindspore/lite/src/CMakeLists.txt b/mindspore/lite/src/CMakeLists.txt index a626711bbe0..77f7cbe6abb 100644 --- a/mindspore/lite/src/CMakeLists.txt +++ b/mindspore/lite/src/CMakeLists.txt @@ -46,6 +46,7 @@ set(LITE_SRC ${CMAKE_CURRENT_SOURCE_DIR}/common/file_utils.cc ${CMAKE_CURRENT_SOURCE_DIR}/common/utils.cc ${CMAKE_CURRENT_SOURCE_DIR}/common/graph_util.cc + ${CMAKE_CURRENT_SOURCE_DIR}/common/dynamic_library_loader.cc ${CMAKE_CURRENT_SOURCE_DIR}/common/log_adapter.cc ${CMAKE_CURRENT_SOURCE_DIR}/common/string_util.cc ${CMAKE_CURRENT_SOURCE_DIR}/common/prim_util.cc diff --git a/mindspore/lite/src/runtime/kernel/arm/base/tile_base.cc b/mindspore/lite/src/runtime/kernel/arm/base/tile_base.cc index 52981daf560..54548699d0e 100644 --- a/mindspore/lite/src/runtime/kernel/arm/base/tile_base.cc +++ b/mindspore/lite/src/runtime/kernel/arm/base/tile_base.cc @@ -91,8 +91,8 @@ int SimpleTile(void *cdata, int task_id) { void TileCPUKernel::FillOneDimTileParam() { // check if tile exact one dim int large_one_multiple_count = 0; - int multiple; - int mul_index; + int multiple = 0; + int mul_index = 0; for (auto i = 0; i < tile_parameter_->in_dim_; ++i) { if (tile_parameter_->multiples_[i] > 1) { large_one_multiple_count++; diff --git a/mindspore/lite/test/CMakeLists.txt b/mindspore/lite/test/CMakeLists.txt index b79bfebe4f0..9015f33435e 100644 --- a/mindspore/lite/test/CMakeLists.txt +++ b/mindspore/lite/test/CMakeLists.txt @@ -430,6 +430,7 @@ if(ENABLE_CONVERTER) add_dependencies(lite-test fbs_inner_src) target_link_libraries(lite-test anf_exporter_mid + mslite_converter_plugin_reg tflite_parser_mid caffe_parser_mid onnx_parser_mid diff --git a/mindspore/lite/test/ut/tools/optimizer/fusion/constant_folding_fusion_test.cc b/mindspore/lite/test/ut/tools/optimizer/fusion/constant_folding_fusion_test.cc index 2451a4bb1db..c1e50b67a19 100644 --- a/mindspore/lite/test/ut/tools/optimizer/fusion/constant_folding_fusion_test.cc +++ b/mindspore/lite/test/ut/tools/optimizer/fusion/constant_folding_fusion_test.cc @@ -22,7 +22,6 @@ #include "include/context.h" #include "include/errorcode.h" #include "src/common/log_adapter.h" -#include "tools/converter/model_parser.h" #include "tools/converter/anf_transform.h" #include "tools/optimizer/fusion/constant_folding_fusion.h" #include "tools/anf_exporter/anf_exporter.h" diff --git a/mindspore/lite/test/ut/tools/optimizer/fusion/conv_activation_fusion_test.cc b/mindspore/lite/test/ut/tools/optimizer/fusion/conv_activation_fusion_test.cc index 45abc045c0c..09a83ef7fdc 100644 --- a/mindspore/lite/test/ut/tools/optimizer/fusion/conv_activation_fusion_test.cc +++ b/mindspore/lite/test/ut/tools/optimizer/fusion/conv_activation_fusion_test.cc @@ -22,7 +22,6 @@ #include "include/context.h" #include "include/errorcode.h" #include "src/common/log_adapter.h" -#include "tools/converter/model_parser.h" #include "tools/converter/anf_transform.h" #include "tools/anf_exporter/anf_exporter.h" #include "test/common/import_from_meta_graphT.h" diff --git a/mindspore/lite/test/ut/tools/optimizer/fusion/conv_biasadd_fusion_test.cc b/mindspore/lite/test/ut/tools/optimizer/fusion/conv_biasadd_fusion_test.cc index ea729b6d6fc..7c28dea60f7 100644 --- a/mindspore/lite/test/ut/tools/optimizer/fusion/conv_biasadd_fusion_test.cc +++ b/mindspore/lite/test/ut/tools/optimizer/fusion/conv_biasadd_fusion_test.cc @@ -22,7 +22,6 @@ #include "include/context.h" #include "include/errorcode.h" #include "src/common/log_adapter.h" -#include "tools/converter/model_parser.h" #include "tools/converter/anf_transform.h" #include "tools/anf_exporter/anf_exporter.h" #include "test/common/import_from_meta_graphT.h" diff --git a/mindspore/lite/test/ut/tools/optimizer/fusion/conv_bn_fusion_test.cc b/mindspore/lite/test/ut/tools/optimizer/fusion/conv_bn_fusion_test.cc index 5c2c9aa5d36..fafbc6e54a2 100644 --- a/mindspore/lite/test/ut/tools/optimizer/fusion/conv_bn_fusion_test.cc +++ b/mindspore/lite/test/ut/tools/optimizer/fusion/conv_bn_fusion_test.cc @@ -22,7 +22,6 @@ #include "include/context.h" #include "include/errorcode.h" #include "src/common/log_adapter.h" -#include "tools/converter/model_parser.h" #include "tools/converter/anf_transform.h" #include "tools/anf_exporter/anf_exporter.h" #include "test/common/import_from_meta_graphT.h" diff --git a/mindspore/lite/test/ut/tools/optimizer/fusion/conv_scale_fusion_test.cc b/mindspore/lite/test/ut/tools/optimizer/fusion/conv_scale_fusion_test.cc index 7ddaedc3f8b..701cdcc6961 100644 --- a/mindspore/lite/test/ut/tools/optimizer/fusion/conv_scale_fusion_test.cc +++ b/mindspore/lite/test/ut/tools/optimizer/fusion/conv_scale_fusion_test.cc @@ -22,7 +22,6 @@ #include "include/context.h" #include "include/errorcode.h" #include "src/common/log_adapter.h" -#include "tools/converter/model_parser.h" #include "tools/converter/anf_transform.h" #include "tools/anf_exporter/anf_exporter.h" #include "test/common/import_from_meta_graphT.h" diff --git a/mindspore/lite/tools/converter/CMakeLists.txt b/mindspore/lite/tools/converter/CMakeLists.txt index 9f0fdd5bc92..ffac7fe6a06 100644 --- a/mindspore/lite/tools/converter/CMakeLists.txt +++ b/mindspore/lite/tools/converter/CMakeLists.txt @@ -32,6 +32,7 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} ${CMAKE_CURRENT_SOURCE_DIR}/../common/storage.cc ${CMAKE_CURRENT_SOURCE_DIR}/../../src/ir/primitive_t_value.cc ${CMAKE_CURRENT_SOURCE_DIR}/../../src/common/utils.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../src/common/dynamic_library_loader.cc ../optimizer/common/node_pass_extends.cc ../optimizer/common/pass_manager_extends.cc @@ -107,6 +108,7 @@ add_subdirectory(parser/onnx) add_subdirectory(parser/tf) add_subdirectory(legacy_optimizer) add_subdirectory(quantizer) +add_subdirectory(registry) add_subdirectory(${CORE_DIR} mindspore_core) set(SRC_DIR ${CMAKE_CURRENT_SOURCE_DIR}/../../src) @@ -216,6 +218,7 @@ add_dependencies(converter_lite fbs_src) add_dependencies(converter_lite fbs_inner_src) target_link_libraries(converter_lite PRIVATE + mslite_converter_plugin_reg tflite_parser_mid tf_parser_mid caffe_parser_mid @@ -234,3 +237,7 @@ target_link_libraries(converter_lite PRIVATE mindspore::flatbuffers pthread ) + +if(NOT WIN32) + target_link_libraries(converter_lite PRIVATE dl) +endif() diff --git a/mindspore/lite/tools/converter/anf_transform.cc b/mindspore/lite/tools/converter/anf_transform.cc index b35209d1bc6..82b1fea7b0d 100644 --- a/mindspore/lite/tools/converter/anf_transform.cc +++ b/mindspore/lite/tools/converter/anf_transform.cc @@ -70,6 +70,7 @@ #include "tools/optimizer/fisson/iter_node_outputs.h" #include "tools/optimizer/fisson/node_out_shapes.h" #include "tools/optimizer/parallel/parallel_pass.h" +#include "tools/converter/registry/pass_registry.h" using std::string; namespace mindspore::lite { @@ -323,6 +324,22 @@ int AnfTransform::RunPrecedingPass(const FuncGraphPtr &old_graph, const converte return RET_OK; } +STATUS AnfTransform::RunPluginPass(const FuncGraphPtr &old_graph, int position) { + auto instance = opt::PassRegistry::GetInstance(); + auto plugin_passes = instance->GetPasses(); + if (plugin_passes.find(position) == plugin_passes.end()) { + MS_LOG(DEBUG) << "there is no plugin pass in current position."; + return RET_OK; + } + + auto plugin_pass = plugin_passes.at(position); + if (!plugin_pass->Run(old_graph)) { + ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR); + return RET_ERROR; + } + return RET_OK; +} + int AnfTransform::DoQuantize(const FuncGraphPtr &old_graph, const converter::Flags *config) { // quant if (config->quantType == schema::QuantType_PostTraining) { @@ -383,11 +400,28 @@ FuncGraphPtr AnfTransform::TransformFuncGraph(const FuncGraphPtr &old_graph, con MS_LOG(ERROR) << "Run convert pass failed."; return nullptr; } + } - status = RunFusionPass(fg, config); - if (status != RET_OK) { - MS_LOG(ERROR) << "Run fusion pass failed."; - return nullptr; + auto format_pass = std::make_shared(); + format_pass->Init(config->fmk, config->trainModel); + if (!format_pass->RunOnlyForShape(old_graph)) { + MS_LOG(ERROR) << "Run format pass failed."; + return nullptr; + } + + status = RunPluginPass(old_graph, opt::POSITION_BEGIN); + if (status != RET_OK) { + MS_LOG(ERROR) << "Run plugin pass failed."; + return nullptr; + } + + for (auto &fg : func_graphs_) { + if (!config->disableFusion) { + status = RunFusionPass(fg, config); + if (status != RET_OK) { + MS_LOG(ERROR) << "Run fusion pass failed."; + return nullptr; + } } status = RunConv1DAdjustPass(fg, config); @@ -397,13 +431,19 @@ FuncGraphPtr AnfTransform::TransformFuncGraph(const FuncGraphPtr &old_graph, con } } - auto format_pass = std::make_shared(); + format_pass = std::make_shared(); format_pass->Init(config->fmk, config->trainModel); if (!format_pass->Run(old_graph)) { MS_LOG(ERROR) << "Run format pass failed."; return nullptr; } + status = RunPluginPass(old_graph, opt::POSITION_END); + if (status != RET_OK) { + MS_LOG(ERROR) << "Run plugin pass failed."; + return nullptr; + } + for (auto &fg : func_graphs_) { status = RunGraphPass(fg, config); if (status != RET_OK) { diff --git a/mindspore/lite/tools/converter/anf_transform.h b/mindspore/lite/tools/converter/anf_transform.h index 6f65cf900a5..a7afd39f94e 100644 --- a/mindspore/lite/tools/converter/anf_transform.h +++ b/mindspore/lite/tools/converter/anf_transform.h @@ -63,6 +63,8 @@ class AnfTransform { static int RunParallelPass(const FuncGraphPtr &old_graph, const converter::Flags *config); + static STATUS RunPluginPass(const FuncGraphPtr &old_graph, int position); + int DoQuantize(const FuncGraphPtr &old_graph, const converter::Flags *config); void GetAllFuncGraph(const FuncGraphPtr &func_graph); diff --git a/mindspore/lite/tools/converter/converter.cc b/mindspore/lite/tools/converter/converter.cc index 29501afd07f..1203ea63a7c 100644 --- a/mindspore/lite/tools/converter/converter.cc +++ b/mindspore/lite/tools/converter/converter.cc @@ -20,33 +20,37 @@ #include "tools/converter/converter_flags.h" #include "src/common/log_adapter.h" #include "tools/common/storage.h" -#include "parser/caffe/caffe_converter.h" -#include "parser/tflite/tflite_converter.h" -#include "parser/onnx/onnx_converter.h" -#include "parser/tf/tf_converter.h" #include "tools/anf_exporter/anf_exporter.h" #include "include/version.h" #include "src/train/train_populate_parameter.h" +#include "tools/converter/registry/model_parser_registry.h" +#include "src/common/dynamic_library_loader.h" namespace mindspore { namespace lite { -using FmkType = converter::FmkType; +FuncGraphPtr Converter::BuildFuncGraph(const converter::Flags &flag) { + if (flag.fmkIn == "MINDIR") { + kernel::PopulateTrainParameters(); + auto func_graph = LoadMindIR(flag.modelFile); + if (func_graph == nullptr) { + MS_LOG(ERROR) << "get funcgraph failed."; + return nullptr; + } + func_graph->set_attr("graph_name", MakeValue("main_graph")); + func_graph->set_attr("fmk", MakeValue(static_cast(converter::FmkType_MS))); -MindsporeImporter::MindsporeImporter() { kernel::PopulateTrainParameters(); } - -std::unique_ptr Converter::CreateConverter(converter::FmkType fmk) { - switch (fmk) { - case FmkType::FmkType_MS: - return std::make_unique(); - case FmkType::FmkType_CAFFE: - return std::make_unique(); - case FmkType::FmkType_TFLITE: - return std::make_unique(); - case FmkType::FmkType_ONNX: - return std::make_unique(); - case FmkType::FmkType_TF: - return std::make_unique(); - default: { + auto status = UpdateFuncGraphInputsAndOutputsDtype(func_graph); + if (RET_OK != status) { + MS_LOG(ERROR) << "update graph inputs and outputs dtype failed."; + return nullptr; + } + return func_graph; + } else { + model_parser_ = ModelParserRegistry::GetInstance()->GetModelParser(flag.fmkIn); + if (model_parser_ != nullptr) { + return model_parser_->Parse(flag.modelFile, flag.weightFile); + } else { + MS_LOG(ERROR) << "get funcGraph failed for fmk:" << flag.fmkIn; return nullptr; } } @@ -57,7 +61,21 @@ schema::MetaGraphT *Converter::Convert(const std::unique_ptr & MS_LOG(ERROR) << "Input flag is nullptr"; return nullptr; } - auto graph = BuildFuncGraph(flag->modelFile, flag->weightFile, flag->quantType); + + // load plugin + if (!flag->pluginsPath.empty()) { + DynamicLibraryLoader dynamic_library_loader{}; + for (auto &path : flag->pluginsPath) { + auto status = dynamic_library_loader.Open(path.c_str()); + if (status != RET_OK) { + MS_LOG(ERROR) << "open dynamic library failed."; + return nullptr; + } + dynamic_library_loader.Close(); + } + } + + auto graph = BuildFuncGraph(*flag); if (graph == nullptr) { MS_LOG(ERROR) << "Parser/Import model return nullptr"; return nullptr; @@ -110,16 +128,8 @@ int RunConverter(int argc, const char **argv) { } // Load graph MS_LOG(DEBUG) << "start reading model file"; - auto converter = Converter::CreateConverter(flags->fmk); - if (converter == nullptr) { - oss.clear(); - oss << "UNSUPPORTED FMKTYPE " << flags->fmk << ":" << RET_INPUT_PARAM_INVALID << " " - << GetErrorInfo(RET_INPUT_PARAM_INVALID); - MS_LOG(ERROR) << oss.str(); - std::cout << oss.str() << std::endl; - return RET_INPUT_PARAM_INVALID; - } - auto meta_graph = converter->Convert(flags); + Converter cvt; + auto meta_graph = cvt.Convert(flags); NotSupportOp::GetInstance()->PrintOps(); status = ReturnCode::GetSingleReturnCode()->status_code(); if (meta_graph == nullptr) { diff --git a/mindspore/lite/tools/converter/converter.h b/mindspore/lite/tools/converter/converter.h index 2a6dbd525a4..fe6a5851d13 100644 --- a/mindspore/lite/tools/converter/converter.h +++ b/mindspore/lite/tools/converter/converter.h @@ -22,6 +22,7 @@ #include "schema/inner/model_generated.h" #include "tools/converter/graphdef_transform.h" #include "tools/converter/model_parser.h" +#include "tools/converter/registry/model_parser_registry.h" #include "tools/converter/converter_flags.h" #include "tools/converter/anf_transform.h" #include "tools/converter/converter_context.h" @@ -32,47 +33,17 @@ namespace mindspore { namespace lite { class Converter { public: - static std::unique_ptr CreateConverter(converter::FmkType fmk); - - virtual ~Converter() = default; - - virtual schema::MetaGraphT *Convert(const std::unique_ptr &flag); - - virtual FuncGraphPtr BuildFuncGraph(const std::string &model_file, const std::string &weight_file, - schema::QuantType quant_type) = 0; + Converter() = default; + ~Converter() { delete model_parser_; } + schema::MetaGraphT *Convert(const std::unique_ptr &flag); + FuncGraphPtr BuildFuncGraph(const converter::Flags &flag); protected: - Converter() = default; - + ModelParser *model_parser_ = nullptr; std::unique_ptr metagraph_transform_ = std::make_unique(); std::unique_ptr funcgraph_transform_ = std::make_unique(); }; -class MindsporeImporter : public Converter { - public: - MindsporeImporter(); - - ~MindsporeImporter() override = default; - - FuncGraphPtr BuildFuncGraph(const std::string &model_file, const std::string &weight_file, - schema::QuantType quant_type) override { - auto func_graph = LoadMindIR(model_file); - if (func_graph == nullptr) { - MS_LOG(ERROR) << "get funcgraph failed."; - return nullptr; - } - func_graph->set_attr("graph_name", MakeValue("main_graph")); - func_graph->set_attr("fmk", MakeValue(static_cast(converter::FmkType_MS))); - - auto status = UpdateFuncGraphInputsAndOutputsDtype(func_graph); - if (RET_OK != status) { - MS_LOG(ERROR) << "update graph inputs and outputs dtype failed."; - return nullptr; - } - return func_graph; - } -}; - int RunConverter(int argc, const char **argv); } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/converter_flags.cc b/mindspore/lite/tools/converter/converter_flags.cc index 725e228a1be..8ae562b2eb8 100644 --- a/mindspore/lite/tools/converter/converter_flags.cc +++ b/mindspore/lite/tools/converter/converter_flags.cc @@ -15,9 +15,13 @@ */ #include "tools/converter/converter_flags.h" -#include +#include +#include #include #include +#include +#include +#include #include "ir/dtype/type_id.h" namespace mindspore { @@ -180,6 +184,28 @@ int Flags::InitTrainModel() { return RET_OK; } +int Flags::InitConfigFile() { + auto plugins_path_str = GetStrFromConfigFile(this->configFile, "plugin_path"); + if (!plugins_path_str.empty()) { + const char *delimiter = ";"; + this->pluginsPath = SplitStringToVector(plugins_path_str, *delimiter); + } + + auto disable_fusion_flag = GetStrFromConfigFile(this->configFile, "disable_fusion"); + if (!disable_fusion_flag.empty()) { + if (disable_fusion_flag == "on") { + this->disableFusion = true; + } else if (disable_fusion_flag == "off") { + this->disableFusion = false; + } else { + std::cerr << "CONFIG SETTING ILLEGAL: disable_fusion should be on/off"; + return RET_INPUT_PARAM_INVALID; + } + } + + return RET_OK; +} + int Flags::Init(int argc, const char **argv) { int ret; if (argc == 1) { @@ -222,6 +248,14 @@ int Flags::Init(int argc, const char **argv) { return RET_INPUT_PARAM_INVALID; } + if (!this->configFile.empty()) { + ret = InitConfigFile(); + if (ret != RET_OK) { + std::cerr << "Init config file failed."; + return RET_INPUT_PARAM_INVALID; + } + } + ret = InitInputOutputDataType(); if (ret != RET_OK) { std::cerr << "Init input output datatype failed."; @@ -248,6 +282,80 @@ int Flags::Init(int argc, const char **argv) { return RET_OK; } + +std::string GetStrFromConfigFile(const std::string &file, const std::string &target_key) { + std::string res; + if (file.empty()) { + MS_LOG(ERROR) << "file is nullptr"; + return res; + } + + auto resolved_path = std::make_unique(PATH_MAX); + if (resolved_path == nullptr) { + MS_LOG(ERROR) << "new resolved_path failed"; + return ""; + } + +#ifdef _WIN32 + char *real_path = _fullpath(resolved_path.get(), file.c_str(), 1024); +#else + char *real_path = realpath(file.c_str(), resolved_path.get()); +#endif + if (real_path == nullptr || strlen(real_path) == 0) { + MS_LOG(ERROR) << "file path is not valid : " << file; + return ""; + } + std::ifstream ifs(resolved_path.get()); + if (!ifs.good()) { + MS_LOG(ERROR) << "file: " << real_path << " is not exist"; + return res; + } + if (!ifs.is_open()) { + MS_LOG(ERROR) << "file: " << real_path << "open failed"; + return res; + } + std::string line; + while (std::getline(ifs, line)) { + lite::Trim(&line); + if (line.empty()) { + continue; + } + auto index = line.find('='); + if (index == std::string::npos) { + MS_LOG(ERROR) << "the config file is invalid, can not find '=', please check"; + return ""; + } + auto key = line.substr(0, index); + auto value = line.substr(index + 1); + lite::Trim(&key); + lite::Trim(&value); + if (key == target_key) { + return value; + } + } + return res; +} + +std::vector SplitStringToVector(const std::string &raw_str, const char &delimiter) { + if (raw_str.empty()) { + MS_LOG(ERROR) << "input string is empty."; + return {}; + } + std::vector res; + std::string::size_type last_pos = 0; + auto cur_pos = raw_str.find(delimiter); + while (cur_pos != std::string::npos) { + res.push_back(raw_str.substr(last_pos, cur_pos - last_pos)); + cur_pos++; + last_pos = cur_pos; + cur_pos = raw_str.find(delimiter, cur_pos); + } + if (last_pos < raw_str.size()) { + res.push_back(raw_str.substr(last_pos, raw_str.size() - last_pos + 1)); + } + return res; +} + } // namespace converter } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/converter_flags.h b/mindspore/lite/tools/converter/converter_flags.h index 8593efde30f..a16269b5069 100644 --- a/mindspore/lite/tools/converter/converter_flags.h +++ b/mindspore/lite/tools/converter/converter_flags.h @@ -18,6 +18,7 @@ #define MINDSPORE_LITE_TOOLS_CONVERTER_CONVERTER_FLAGS_H #include +#include #include "tools/common/flag_parser.h" #include "ir/dtype/type_id.h" #include "schema/inner/model_generated.h" @@ -57,6 +58,8 @@ class Flags : public virtual mindspore::lite::FlagParser { int InitTrainModel(); + int InitConfigFile(); + int Init(int argc, const char **argv); public: @@ -83,7 +86,13 @@ class Flags : public virtual mindspore::lite::FlagParser { int quantWeightChannel; std::string trainModelIn; bool trainModel = false; + std::vector pluginsPath; + bool disableFusion = false; }; + +std::string GetStrFromConfigFile(const std::string &file, const std::string &target_key); + +std::vector SplitStringToVector(const std::string &raw_str, const char &delimiter); } // namespace converter } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/model_parser.h b/mindspore/lite/tools/converter/model_parser.h index 308855958e1..578d4a94ed9 100644 --- a/mindspore/lite/tools/converter/model_parser.h +++ b/mindspore/lite/tools/converter/model_parser.h @@ -34,8 +34,8 @@ class ModelParser { virtual ~ModelParser() = default; - FuncGraphPtr Parse(const std::string &model_file, const std::string &weight_file, const QuantType &quant_type) { - auto ret = ParseToFuncGraph(model_file, weight_file, quant_type); + FuncGraphPtr Parse(const std::string &model_file, const std::string &weight_file) { + auto ret = ParseToFuncGraph(model_file, weight_file); if (ret != RET_OK) { MS_LOG(ERROR) << "Parse to func graph failed : " << ret; return nullptr; @@ -49,14 +49,25 @@ class ModelParser { } protected: - virtual int ParseToFuncGraph(const std::string &model_file, const std::string &weight_file, - const QuantType &quant_type) = 0; + virtual int ParseToFuncGraph(const std::string &model_file, const std::string &weight_file) = 0; virtual int PostAdjust() = 0; protected: FuncGraphPtr res_graph_ = nullptr; }; + +typedef ModelParser *(*ModelParserCreator)(); + +template +ModelParser *LiteModelParserCreator() { + auto *parser = new (std::nothrow) T(); + if (parser == nullptr) { + MS_LOG(ERROR) << "new model parser failed"; + return nullptr; + } + return parser; +} } // namespace mindspore::lite #endif diff --git a/mindspore/lite/tools/converter/parser/caffe/caffe_converter.h b/mindspore/lite/tools/converter/parser/caffe/caffe_converter.h deleted file mode 100644 index 11b823609e8..00000000000 --- a/mindspore/lite/tools/converter/parser/caffe/caffe_converter.h +++ /dev/null @@ -1,41 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_CAFFE_CAFFE_CONVERTER_H_ -#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_CAFFE_CAFFE_CONVERTER_H_ - -#include -#include -#include "tools/converter/converter.h" -#include "tools/converter/graphdef_transform.h" -#include "tools/converter/parser/caffe/caffe_model_parser.h" - -namespace mindspore::lite { -class CaffeConverter : public Converter { - public: - CaffeConverter() = default; - - ~CaffeConverter() override = default; - - FuncGraphPtr BuildFuncGraph(const std::string &model_file, const std::string &weight_file, - schema::QuantType quant_type) override { - CaffeModelParser parser; - return parser.Parse(model_file, weight_file, quant_type); - } -}; -} // namespace mindspore::lite - -#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_CAFFE_CAFFE_CONVERTER_H_ diff --git a/mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.cc b/mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.cc index afcf47d862e..b09d9cf9e1b 100644 --- a/mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.cc +++ b/mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.cc @@ -26,6 +26,8 @@ #include "tools/converter/ops/ops_def.h" #include "ir/func_graph.h" #include "tools/converter/converter_flags.h" +#include "tools/converter/converter_context.h" +#include "tools/converter/quant_param_holder.h" namespace mindspore::lite { bool IsSkipedLayer(const caffe::LayerParameter &layer) { @@ -56,8 +58,7 @@ CaffeModelParser::CaffeModelParser() = default; CaffeModelParser::~CaffeModelParser() = default; -int CaffeModelParser::ParseToFuncGraph(const std::string &model_file, const std::string &weight_file, - const QuantType &quant_type) { +int CaffeModelParser::ParseToFuncGraph(const std::string &model_file, const std::string &weight_file) { STATUS status = InitOriginModel(model_file, weight_file); if (status != RET_OK) { ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); @@ -493,4 +494,6 @@ std::string CaffeModelParser::GetOriginLayerName(const std::string &layer_name) } int CaffeModelParser::PostAdjust() { return RET_OK; } + +REG_MODEL_PARSER(CAFFE, LiteModelParserCreator) } // namespace mindspore::lite diff --git a/mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.h b/mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.h index 1faf493a5bc..9454c02864f 100644 --- a/mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.h +++ b/mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.h @@ -22,9 +22,11 @@ #include #include #include "tools/converter/model_parser.h" +#include "tools/converter/registry/model_parser_registry.h" #include "proto/caffe.pb.h" #include "ops/primitive_c.h" +using STATUS = int; namespace mindspore::lite { class CaffeModelParser : public ModelParser { public: @@ -32,8 +34,7 @@ class CaffeModelParser : public ModelParser { ~CaffeModelParser() override; - int ParseToFuncGraph(const std::string &model_file, const std::string &weight_file, - const QuantType &quant_type) override; + int ParseToFuncGraph(const std::string &model_file, const std::string &weight_file) override; int PostAdjust() override; diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_constant_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_constant_parser.cc index c05351af036..e4909330155 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_constant_parser.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_constant_parser.cc @@ -34,7 +34,7 @@ STATUS OnnxConstantParser::AddDataInfoAttr(const onnx::TensorProto &onnx_const_t std::vector shape_vector(onnx_const_tensor.dims().begin(), onnx_const_tensor.dims().end()); auto tensor_info = std::make_shared(data_type, shape_vector); if (tensor_info == nullptr) { - MS_LOG(ERROR) << "new a paramValueLite failed."; + MS_LOG(ERROR) << "new a tensor::Tensor failed."; return RET_ERROR; } if (OnnxModelParser::CopyOnnxTensorData(onnx_const_tensor, tensor_info) != RET_OK) { diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_converter.h b/mindspore/lite/tools/converter/parser/onnx/onnx_converter.h deleted file mode 100644 index 4e253ed3085..00000000000 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_converter.h +++ /dev/null @@ -1,43 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_CONVERTER_H -#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_CONVERTER_H - -#include -#include -#include "tools/converter/converter.h" -#include "tools/converter/graphdef_transform.h" -#include "tools/converter/parser/onnx/onnx_model_parser.h" - -namespace mindspore { -namespace lite { -class OnnxConverter : public Converter { - public: - OnnxConverter() = default; - - ~OnnxConverter() override = default; - - FuncGraphPtr BuildFuncGraph(const std::string &model_file, const std::string &weight_file, - schema::QuantType quant_type) override { - OnnxModelParser parser; - return parser.Parse(model_file, weight_file, quant_type); - } -}; -} // namespace lite -} // namespace mindspore - -#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_CONVERTER_H diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc index 6a582beea4c..111adc9fc15 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc @@ -21,7 +21,6 @@ #include #include #include "tools/optimizer/common/gllo_utils.h" -#include "src/common/utils.h" #include "tools/common/graph_util.h" #include "tools/common/protobuf_utils.h" #include "tools/common/tensor_util.h" @@ -29,6 +28,8 @@ #include "ops/tensor_list_stack.h" #include "ir/func_graph.h" #include "tools/converter/converter_flags.h" +#include "tools/converter/quant_param_holder.h" +#include "tools/converter/converter_context.h" namespace mindspore { namespace lite { @@ -43,8 +44,7 @@ static const std::unordered_map TYPE_MAP = { {onnx::TensorProto_DataType_FLOAT, mindspore::kNumberTypeFloat32}, {onnx::TensorProto_DataType_BOOL, mindspore::kNumberTypeBool}}; -int OnnxModelParser::ParseToFuncGraph(const std::string &model_file, const std::string &weight_file, - const QuantType &quant_type) { +int OnnxModelParser::ParseToFuncGraph(const std::string &model_file, const std::string &weight_file) { NotSupportOp::GetInstance()->set_fmk_type("ONNX"); res_graph_ = std::make_shared(); auto status = InitOriginModel(model_file); @@ -723,7 +723,7 @@ ParameterPtr CreateConstParamter(const FuncGraphPtr &anf_graph, int val) { return const_node; } -ValueNodePtr CreateValueNode(const PrimitiveType &op_type) { +ValueNodePtr CreateValueNode(const schema::PrimitiveType &op_type) { auto node_type = schema::EnumNamePrimitiveType(op_type); auto op_primc_fns = ops::OpPrimCRegister::GetInstance().GetPrimCMap(); if (op_primc_fns.find(node_type) == op_primc_fns.end()) { @@ -1178,5 +1178,7 @@ TypeId OnnxModelParser::GetDataTypeFromOnnx(onnx::TensorProto_DataType onnx_type } int OnnxModelParser::PostAdjust() { return 0; } + +REG_MODEL_PARSER(ONNX, LiteModelParserCreator) } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.h index 463a6dfefad..79e8604271e 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.h +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.h @@ -29,8 +29,10 @@ #include #include "securec/include/securec.h" #include "tools/converter/model_parser.h" +#include "tools/converter/registry/model_parser_registry.h" #include "tools/converter/parser/onnx/onnx_node_parser_registry.h" #include "proto/onnx.pb.h" +#include "schema/inner/model_generated.h" namespace mindspore { namespace lite { @@ -40,8 +42,7 @@ class OnnxModelParser : public ModelParser { ~OnnxModelParser() override = default; - int ParseToFuncGraph(const std::string &model_file, const std::string &weight_file, - const QuantType &quant_type) override; + int ParseToFuncGraph(const std::string &model_file, const std::string &weight_file) override; int PostAdjust() override; @@ -73,9 +74,9 @@ class OnnxModelParser : public ModelParser { std::unordered_map *anf_nodes_map, const CNodePtr &cnode); STATUS ConvertOpQuantParams(const onnx::NodeProto &onnx_node, ops::PrimitiveC *primitive_c); STATUS ParseQuantParam(const onnx::NodeProto &onnx_node); - STATUS SetTensorQuantParam(const std::string &tensor_name, std::vector *quant_params); - STATUS SetTensorQuantParamFromNode(const std::string &tensor_name, std::vector *quant_params); - STATUS CopyTensorQuantParam(const std::string &tensor_name, QuantParamT *quant_param, bool scale_or_not); + STATUS SetTensorQuantParam(const std::string &tensor_name, std::vector *quant_params); + STATUS SetTensorQuantParamFromNode(const std::string &tensor_name, std::vector *quant_params); + STATUS CopyTensorQuantParam(const std::string &tensor_name, schema::QuantParamT *quant_param, bool scale_or_not); STATUS ConvertLoopOnnxNode(const onnx::NodeProto &onnx_node, std::unordered_map *anf_nodes_map, const std::string &root_node_name); diff --git a/mindspore/lite/tools/converter/parser/tf/tf_converter.h b/mindspore/lite/tools/converter/parser/tf/tf_converter.h deleted file mode 100644 index 108d65f6eda..00000000000 --- a/mindspore/lite/tools/converter/parser/tf/tf_converter.h +++ /dev/null @@ -1,40 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_CONVERTER_H_ -#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_CONVERTER_H_ -#include -#include -#include "tools/converter/converter.h" -#include "tools/converter/parser/tf/tf_model_parser.h" - -namespace mindspore { -namespace lite { -class TFConverter : public Converter { - public: - TFConverter() = default; - - ~TFConverter() override = default; - - FuncGraphPtr BuildFuncGraph(const std::string &model_file, const std::string &weight_file, - schema::QuantType quant_type) override { - TFModelParser parser; - return parser.Parse(model_file, weight_file, quant_type); - } -}; -} // namespace lite -} // namespace mindspore - -#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_CONVERTER_H_ diff --git a/mindspore/lite/tools/converter/parser/tf/tf_model_parser.cc b/mindspore/lite/tools/converter/parser/tf/tf_model_parser.cc index 8c9dc156189..d2cf2927e5b 100644 --- a/mindspore/lite/tools/converter/parser/tf/tf_model_parser.cc +++ b/mindspore/lite/tools/converter/parser/tf/tf_model_parser.cc @@ -28,6 +28,7 @@ #include "ir/anf.h" #include "abstract/utils.h" #include "tools/converter/converter_flags.h" +#include "tools/converter/quant_param_holder.h" namespace mindspore { namespace lite { @@ -472,8 +473,7 @@ STATUS TFModelParser::ConvertGraphInputsAndConsts( return RET_OK; } -int TFModelParser::ParseToFuncGraph(const std::string &modelFile, const std::string &weightFile, - const QuantType &quantType) { +int TFModelParser::ParseToFuncGraph(const std::string &modelFile, const std::string &weightFile) { NotSupportOp::GetInstance()->set_fmk_type("TF"); auto status = ValidateFileStr(modelFile, ".pb"); if (status != RET_OK) { @@ -559,8 +559,8 @@ int TFModelParser::ParseToFuncGraph(const std::string &modelFile, const std::str STATUS TFModelParser::ConvertSubgraphInputs(std::map *tf_sub_node_map, std::unordered_map *anf_sub_node_map, - const tensorflow::FunctionDef &tf_sub_fuction, CNodePtr cnode, - FuncGraphPtr sub_func_graph) { + const tensorflow::FunctionDef &tf_sub_fuction, const CNodePtr &cnode, + const FuncGraphPtr &sub_func_graph) { std::vector sub_graph_inputs; auto &tf_sub_signature = tf_sub_fuction.signature(); auto &sub_graph_name = tf_sub_signature.name(); @@ -598,7 +598,7 @@ STATUS TFModelParser::ConvertSubgraphInputs(std::map *tf_sub_node_map, const std::unordered_map &anf_sub_node_map, const tensorflow::FunctionDef &tf_sub_fuction, - FuncGraphPtr sub_func_graph) { + const FuncGraphPtr &sub_func_graph) { auto &tf_sub_signature = tf_sub_fuction.signature(); auto &sub_graph_name = tf_sub_signature.name(); @@ -949,7 +949,7 @@ STATUS TFModelParser::ConvertOps(const tensorflow::NodeDef &node_def, return status; } -STATUS TFModelParser::ProcessControlFlowOp(CNodePtr anf_node, const string op_type, +STATUS TFModelParser::ProcessControlFlowOp(const CNodePtr &anf_node, const string &op_type, const tensorflow::NodeDef &node_def) { if (op_type == "StatelessWhile" || op_type == "While") { MS_LOG(INFO) << "find while node:" << node_def.name(); @@ -1077,5 +1077,7 @@ STATUS TFModelParser::MakeAnfGraphOutputs(std::vector *output_nodes, } int TFModelParser::PostAdjust() { return 0; } + +REG_MODEL_PARSER(TF, LiteModelParserCreator) } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tf/tf_model_parser.h b/mindspore/lite/tools/converter/parser/tf/tf_model_parser.h index c76b107d322..c3a2fa713c7 100644 --- a/mindspore/lite/tools/converter/parser/tf/tf_model_parser.h +++ b/mindspore/lite/tools/converter/parser/tf/tf_model_parser.h @@ -29,6 +29,7 @@ #include "securec/include/securec.h" #include "tools/common/tensor_util.h" #include "tools/converter/model_parser.h" +#include "tools/converter/registry/model_parser_registry.h" #include "ops/primitive_c.h" namespace mindspore { @@ -38,14 +39,15 @@ class TFModelParser : public ModelParser { TFModelParser() = default; ~TFModelParser() override = default; - int ParseToFuncGraph(const std::string &modelFile, const std::string &weightFile, const QuantType &quantType); + int ParseToFuncGraph(const std::string &modelFile, const std::string &weightFile) override; int PostAdjust() override; private: static STATUS ConvertConstVariant(const tensorflow::TensorProto &tensor_proto, tensor::TensorPtr *tensor_info); - STATUS ConvertConstTensor(const tensorflow::NodeDef &node_def, const tensorflow::AttrValue &attr_value, - const TypeId &type, const ParameterPtr ¶meter, std::vector *shape_vector); + static STATUS ConvertConstTensor(const tensorflow::NodeDef &node_def, const tensorflow::AttrValue &attr_value, + const TypeId &type, const ParameterPtr ¶meter, + std::vector *shape_vector); static STATUS SetTensorInfoFromType(const tensorflow::TensorProto &tensor_proto, tensor::TensorPtr *tensor_info); STATUS ConvertParameter(const tensorflow::NodeDef &node, const ParameterPtr ¶meter, std::unordered_map *anf_node_map); @@ -63,7 +65,7 @@ class TFModelParser : public ModelParser { const std::map &tf_node_map, const FuncGraphPtr &func_graph_ptr, std::unordered_map *anf_node_map); - STATUS ProcessControlFlowOp(CNodePtr anf_node, const string op_type, const tensorflow::NodeDef &node_def); + STATUS ProcessControlFlowOp(const CNodePtr &anf_node, const string &op_type, const tensorflow::NodeDef &node_def); STATUS ConvertRootGraphOutputs(); @@ -71,17 +73,18 @@ class TFModelParser : public ModelParser { STATUS ConvertSubgraphInputs(std::map *tf_sub_node_map, std::unordered_map *anf_sub_node_map, - const tensorflow::FunctionDef &tf_sub_fuction, CNodePtr cnode, - FuncGraphPtr sub_func_graph); + const tensorflow::FunctionDef &tf_sub_fuction, const CNodePtr &cnode, + const FuncGraphPtr &sub_func_graph); static STATUS ConvertSubgraphOutputs(std::map *tf_sub_node_map, const std::unordered_map &anf_sub_node_map, - const tensorflow::FunctionDef &tf_sub_fuction, FuncGraphPtr sub_func_graph); + const tensorflow::FunctionDef &tf_sub_fuction, + const FuncGraphPtr &sub_func_graph); STATUS ControlFlowNodePostProcess(const std::map &first_func_map, const std::map &second_func_map); - STATUS ConvertQuantParams(const size_t &input_size, const size_t &output_size, ops::PrimitiveC *primitive_c); + static STATUS ConvertQuantParams(const size_t &input_size, const size_t &output_size, ops::PrimitiveC *primitive_c); static STATUS MakeAnfGraphOutputs(std::vector *output_nodes, const FuncGraphPtr &anf_graph); diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_converter.h b/mindspore/lite/tools/converter/parser/tflite/tflite_converter.h deleted file mode 100644 index cb87ea9cc4e..00000000000 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_converter.h +++ /dev/null @@ -1,42 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_CONVERTER_H_ -#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_CONVERTER_H_ - -#include -#include -#include -#include "tools/converter/converter.h" -#include "tools/converter/graphdef_transform.h" -#include "tools/converter/parser/tflite/tflite_model_parser.h" - -namespace mindspore::lite { -class TfliteConverter : public Converter { - public: - TfliteConverter() = default; - - ~TfliteConverter() override = default; - - FuncGraphPtr BuildFuncGraph(const std::string &model_file, const std::string &weight_file, - schema::QuantType quant_type) override { - TfliteModelParser parser; - return parser.Parse(model_file, weight_file, quant_type); - } -}; -} // namespace mindspore::lite - -#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_CONVERTER_H_ diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc index 50acd8e1ee0..533f347493b 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc @@ -19,12 +19,14 @@ #include #include #include -#include "tools/converter/converter_flags.h" -#include "src/common/file_utils.h" -#include "tools/converter/ops/ops_def.h" #include "ops/primitive_c.h" #include "ir/func_graph.h" +#include "src/common/file_utils.h" +#include "tools/converter/ops/ops_def.h" #include "tools/common/graph_util.h" +#include "tools/converter/quant_param_holder.h" +#include "tools/converter/converter_context.h" +#include "tools/converter/converter_flags.h" namespace mindspore::lite { std::unique_ptr TfliteModelParser::ReadTfliteModel(const char *model_path) { @@ -42,8 +44,7 @@ std::unique_ptr TfliteModelParser::ReadTfliteModel(const char *m return tflite::UnPackModel(tflite_model_buf_); } -int TfliteModelParser::ParseToFuncGraph(const std::string &model_file, const std::string &weight_file, - const QuantType &quant_type) { +int TfliteModelParser::ParseToFuncGraph(const std::string &model_file, const std::string &weight_file) { // load graph tflite_model_ = ReadTfliteModel(model_file.c_str()); if (tflite_model_ == nullptr) { @@ -489,4 +490,6 @@ STATUS TfliteModelParser::ConvertOutputTensor(const tflite::OperatorT *op, const } int TfliteModelParser::PostAdjust() { return 0; } + +REG_MODEL_PARSER(TFLITE, LiteModelParserCreator) } // namespace mindspore::lite diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.h index e88a830cd72..5c3598b2714 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.h @@ -21,6 +21,7 @@ #include #include #include "tools/converter/model_parser.h" +#include "tools/converter/registry/model_parser_registry.h" #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" #include "tools/common/tensor_util.h" @@ -32,8 +33,7 @@ class TfliteModelParser : public ModelParser { ~TfliteModelParser() override = default; - int ParseToFuncGraph(const std::string &model_file, const std::string &weight_file, - const QuantType &quant_type) override; + int ParseToFuncGraph(const std::string &model_file, const std::string &weight_file) override; int PostAdjust() override; diff --git a/mindspore/lite/tools/converter/registry/CMakeLists.txt b/mindspore/lite/tools/converter/registry/CMakeLists.txt new file mode 100644 index 00000000000..a15d42ffab5 --- /dev/null +++ b/mindspore/lite/tools/converter/registry/CMakeLists.txt @@ -0,0 +1,10 @@ +file(GLOB CONVERT_REG_SRC + pass_registry.cc + model_parser_registry.cc + ) + +set_property(SOURCE ${CONVERT_REG_SRC} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_LITE) +add_library(mslite_converter_plugin_reg SHARED ${CONVERT_REG_SRC}) + +add_dependencies(mslite_converter_plugin_reg fbs_src) +add_dependencies(mslite_converter_plugin_reg fbs_inner_src) diff --git a/mindspore/lite/tools/converter/registry/model_parser_registry.cc b/mindspore/lite/tools/converter/registry/model_parser_registry.cc new file mode 100644 index 00000000000..05efeb153e2 --- /dev/null +++ b/mindspore/lite/tools/converter/registry/model_parser_registry.cc @@ -0,0 +1,43 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tools/converter/registry/model_parser_registry.h" +#include +#include + +namespace mindspore { +namespace lite { +ModelParserRegistry *ModelParserRegistry::GetInstance() { + static ModelParserRegistry instance; + return &instance; +} + +ModelParser *ModelParserRegistry::GetModelParser(const std::string &fmk) { + auto it = parsers_.find(fmk); + if (it != parsers_.end()) { + auto creator = it->second; + return creator(); + } + return nullptr; +} + +void ModelParserRegistry::RegParser(const std::string &fmk, ModelParserCreator creator) { + auto instance = ModelParserRegistry::GetInstance(); + instance->parsers_[fmk] = creator; +} + +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/registry/model_parser_registry.h b/mindspore/lite/tools/converter/registry/model_parser_registry.h new file mode 100644 index 00000000000..0ad73fae2af --- /dev/null +++ b/mindspore/lite/tools/converter/registry/model_parser_registry.h @@ -0,0 +1,51 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_MODEL_PARSER_REGISTRY_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_MODEL_PARSER_REGISTRY_H +#include +#include +#include +#include "include/lite_utils.h" + +namespace mindspore::lite { +class MS_API ModelParser; +typedef ModelParser *(*ModelParserCreator)(); + +class MS_API ModelParserRegistry { + public: + ModelParserRegistry() = default; + ~ModelParserRegistry() = default; + + static ModelParserRegistry *GetInstance(); + ModelParser *GetModelParser(const std::string &fmk); + void RegParser(const std::string &fmk, ModelParserCreator creator); + + std::unordered_map parsers_; +}; + +class MS_API ModelRegistrar { + public: + ModelRegistrar(const std::string &fmk, ModelParserCreator creator) { + ModelParserRegistry::GetInstance()->RegParser(fmk, creator); + } + ~ModelRegistrar() = default; +}; + +#define REG_MODEL_PARSER(fmk, parserCreator) static ModelRegistrar g_##type##fmk##ModelParserReg(#fmk, parserCreator); +} // namespace mindspore::lite + +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_MODEL_PARSER_REGISTRY_H diff --git a/mindspore/lite/tools/converter/registry/pass_registry.cc b/mindspore/lite/tools/converter/registry/pass_registry.cc new file mode 100644 index 00000000000..8a5ae8f97cd --- /dev/null +++ b/mindspore/lite/tools/converter/registry/pass_registry.cc @@ -0,0 +1,45 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tools/converter/registry/pass_registry.h" +#include +#include + +namespace mindspore { +namespace opt { +PassRegistry *PassRegistry::GetInstance() { + static PassRegistry instance; + return &instance; +} + +void PassRegistry::RegPass(int position, const PassPtr &pass) { + if (pass == nullptr) { + std::cout << "pass is nullptr" << std::endl; + return; + } + auto instance = PassRegistry::GetInstance(); + std::unique_lock lock(instance->mutex_); + instance->passes_[position] = pass; +} + +const std::unordered_map &PassRegistry::GetPasses() const { + auto instance = PassRegistry::GetInstance(); + std::unique_lock lock(instance->mutex_); + return instance->passes_; +} + +} // namespace opt +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/registry/pass_registry.h b/mindspore/lite/tools/converter/registry/pass_registry.h new file mode 100644 index 00000000000..d2c0632d5bf --- /dev/null +++ b/mindspore/lite/tools/converter/registry/pass_registry.h @@ -0,0 +1,60 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_REGISTRY_PASS_REGISTRY_H_ +#define MINDSPORE_LITE_TOOLS_CONVERTER_REGISTRY_PASS_REGISTRY_H_ + +#include +#include +#include +#include +#include +#include +#include "include/lite_utils.h" + +namespace mindspore { +namespace opt { +enum MS_API PassPosition { POSITION_BEGIN = 0, POSITION_END = 1 }; + +class MS_API Pass; +using PassPtr = std::shared_ptr; +class MS_API PassRegistry { + public: + virtual ~PassRegistry() = default; + static PassRegistry *GetInstance(); + void RegPass(int position, const PassPtr &pass); + const std::unordered_map &GetPasses() const; + + private: + PassRegistry() = default; + + private: + std::unordered_map passes_; + std::mutex mutex_; +}; + +class MS_API PassRegistrar { + public: + PassRegistrar(int pos, const PassPtr &pass) { PassRegistry::GetInstance()->RegPass(pos, pass); } + ~PassRegistrar() = default; +}; + +#define REG_PASS(position, pass) static PassRegistrar g_##position##PassReg(position, std::make_shared()); + +} // namespace opt +} // namespace mindspore + +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_REGISTRY_PASS_REGISTRY_H_