From e21273e836ea6196fece79be25d8526f5c91346a Mon Sep 17 00:00:00 2001 From: xuanyue Date: Sat, 26 Jun 2021 10:52:16 +0800 Subject: [PATCH] open flags when converting --- cmake/package_lite.cmake | 5 +++-- .../include/registry/model_parser_registry.h | 18 ++++++++++++++++-- .../registry/model_parser_registry_test.cc | 7 +++---- .../converter/registry/model_parser_test.cc | 2 +- .../converter/registry/model_parser_test.h | 3 +-- .../converter/registry/pass_registry_test.cc | 7 +++---- mindspore/lite/tools/converter/converter.cc | 13 ++++++++++++- mindspore/lite/tools/converter/model_parser.h | 7 +++---- .../parser/caffe/caffe_model_parser.cc | 8 ++++---- .../parser/caffe/caffe_model_parser.h | 2 +- .../converter/parser/onnx/onnx_model_parser.cc | 6 +++--- .../converter/parser/onnx/onnx_model_parser.h | 2 +- .../converter/parser/tf/tf_model_parser.cc | 7 ++++--- .../converter/parser/tf/tf_model_parser.h | 2 +- .../parser/tflite/tflite_model_parser.cc | 6 +++--- .../parser/tflite/tflite_model_parser.h | 2 +- 16 files changed, 60 insertions(+), 37 deletions(-) diff --git a/cmake/package_lite.cmake b/cmake/package_lite.cmake index 968d949c8b7..aa21349fcd3 100644 --- a/cmake/package_lite.cmake +++ b/cmake/package_lite.cmake @@ -323,7 +323,8 @@ elseif(WIN32) else() if(SUPPORT_TRAIN) install(DIRECTORY ${TOP_DIR}/mindspore/lite/include/ DESTINATION ${RUNTIME_INC_DIR} - COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h" PATTERN "*registry.h" EXCLUDE) + COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h" PATTERN "*registry.h" EXCLUDE + PATTERN "framework.h" EXCLUDE) install(FILES ${TOP_DIR}/mindspore/lite/build/src/${MINDSPORE_LITE_TRAIN_LIB_NAME}.so DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) install(FILES ${TOP_DIR}/mindspore/lite/build/src/${MINDSPORE_LITE_TRAIN_LIB_NAME}.a DESTINATION @@ -331,7 +332,7 @@ else() else() install(DIRECTORY ${TOP_DIR}/mindspore/lite/include/ DESTINATION ${RUNTIME_INC_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h" PATTERN "train*" EXCLUDE - PATTERN "*registry.h" EXCLUDE) + PATTERN "*registry.h" EXCLUDE PATTERN "framework.h" EXCLUDE) endif() install(FILES ${TOP_DIR}/mindspore/core/ir/dtype/type_id.h DESTINATION ${RUNTIME_INC_DIR}/ir/dtype COMPONENT ${RUNTIME_COMPONENT_NAME}) diff --git a/mindspore/lite/include/registry/model_parser_registry.h b/mindspore/lite/include/registry/model_parser_registry.h index fcb15230ee7..ea9e081dc44 100644 --- a/mindspore/lite/include/registry/model_parser_registry.h +++ b/mindspore/lite/include/registry/model_parser_registry.h @@ -16,13 +16,27 @@ #ifndef MINDSPORE_LITE_INCLUDE_REGISTRY_MODEL_PARSER_REGISTRY_H #define MINDSPORE_LITE_INCLUDE_REGISTRY_MODEL_PARSER_REGISTRY_H + +#include #include -#include +#include #include "include/lite_utils.h" #include "include/registry/framework.h" +#include "schema/inner/model_generated.h" using mindspore::lite::converter::FmkType; namespace mindspore::lite { +namespace converter { +/// \brief ConverterParameters defined read-only converter parameters used by users in ModelParser. +struct MS_API ConverterParameters { + FmkType fmk_; + schema::QuantType quant_type_; + std::string model_file_; + std::string weight_file_; + std::map attrs_; +}; +} // namespace converter + /// \brief ModelParser defined a model parser class MS_API ModelParser; @@ -56,7 +70,7 @@ class MS_API ModelParserRegistry { /// \param[in] creator Define function pointer of creating ModelParser. int RegParser(const FmkType fmk, ModelParserCreator creator); - std::unordered_map parsers_; + std::map parsers_; }; /// \brief ModelRegistrar defined registration class of ModelParser. diff --git a/mindspore/lite/test/ut/tools/converter/registry/model_parser_registry_test.cc b/mindspore/lite/test/ut/tools/converter/registry/model_parser_registry_test.cc index 0342affcbdd..d823b613af4 100644 --- a/mindspore/lite/test/ut/tools/converter/registry/model_parser_registry_test.cc +++ b/mindspore/lite/test/ut/tools/converter/registry/model_parser_registry_test.cc @@ -18,10 +18,9 @@ #include "common/common_test.h" #include "ut/tools/converter/registry/model_parser_test.h" #include "tools/optimizer/common/gllo_utils.h" -#include "tools/converter/converter_flags.h" using mindspore::lite::ModelRegistrar; -using mindspore::lite::converter::Flags; +using mindspore::lite::converter::ConverterParameters; using mindspore::lite::converter::FmkType_CAFFE; namespace mindspore { class ModelParserRegistryTest : public mindspore::CommonTest { @@ -39,8 +38,8 @@ TEST_F(ModelParserRegistryTest, TestRegistry) { TestModelParserCreator); // register test model parser creator, which will overwrite existing. auto model_parser = lite::ModelParserRegistry::GetInstance()->GetModelParser(FmkType_CAFFE); ASSERT_NE(model_parser, nullptr); - Flags flags; - auto func_graph = model_parser->Parse(flags); + ConverterParameters converter_parameters; + auto func_graph = model_parser->Parse(converter_parameters); ASSERT_NE(func_graph, nullptr); auto node_list = func_graph->GetOrderedCnodes(); ASSERT_EQ(node_list.size(), 3); diff --git a/mindspore/lite/test/ut/tools/converter/registry/model_parser_test.cc b/mindspore/lite/test/ut/tools/converter/registry/model_parser_test.cc index d87ec0a7249..31e28cf275d 100644 --- a/mindspore/lite/test/ut/tools/converter/registry/model_parser_test.cc +++ b/mindspore/lite/test/ut/tools/converter/registry/model_parser_test.cc @@ -21,7 +21,7 @@ #include "include/registry/model_parser_registry.h" namespace mindspore { -FuncGraphPtr ModelParserTest::Parse(const lite::converter::Flags &flag) { +FuncGraphPtr ModelParserTest::Parse(const lite::converter::ConverterParameters &flag) { // construct funcgraph res_graph_ = std::make_shared(); auto ret = InitOriginModelStructure(); diff --git a/mindspore/lite/test/ut/tools/converter/registry/model_parser_test.h b/mindspore/lite/test/ut/tools/converter/registry/model_parser_test.h index 72a13b9cec2..c3804324e62 100644 --- a/mindspore/lite/test/ut/tools/converter/registry/model_parser_test.h +++ b/mindspore/lite/test/ut/tools/converter/registry/model_parser_test.h @@ -23,13 +23,12 @@ #include "include/registry/model_parser_registry.h" #include "ut/tools/converter/registry/node_parser_test.h" #include "tools/converter/model_parser.h" -#include "tools/converter/converter_flags.h" namespace mindspore { class ModelParserTest : public lite::ModelParser { public: ModelParserTest() = default; - FuncGraphPtr Parse(const lite::converter::Flags &flag) override; + FuncGraphPtr Parse(const lite::converter::ConverterParameters &flag) override; private: int InitOriginModelStructure(); diff --git a/mindspore/lite/test/ut/tools/converter/registry/pass_registry_test.cc b/mindspore/lite/test/ut/tools/converter/registry/pass_registry_test.cc index 3649c3edc1f..aa5cdfbe7a0 100644 --- a/mindspore/lite/test/ut/tools/converter/registry/pass_registry_test.cc +++ b/mindspore/lite/test/ut/tools/converter/registry/pass_registry_test.cc @@ -24,13 +24,12 @@ #include "ops/fusion/add_fusion.h" #include "ops/addn.h" #include "ops/custom.h" -#include "tools/converter/converter_flags.h" #include "tools/converter/model_parser.h" #include "tools/optimizer/common/gllo_utils.h" #include "ut/tools/converter/registry/model_parser_test.h" using mindspore::lite::ModelRegistrar; -using mindspore::lite::converter::Flags; +using mindspore::lite::converter::ConverterParameters; using mindspore::lite::converter::FmkType_CAFFE; namespace mindspore { class PassRegistryTest : public mindspore::CommonTest { @@ -42,8 +41,8 @@ class PassRegistryTest : public mindspore::CommonTest { if (model_parser == nullptr) { return; } - Flags flags; - func_graph_ = model_parser->Parse(flags); + ConverterParameters converter_parameters; + func_graph_ = model_parser->Parse(converter_parameters); } FuncGraphPtr func_graph_ = nullptr; }; diff --git a/mindspore/lite/tools/converter/converter.cc b/mindspore/lite/tools/converter/converter.cc index 9b035a678bc..3c886ec1f4a 100644 --- a/mindspore/lite/tools/converter/converter.cc +++ b/mindspore/lite/tools/converter/converter.cc @@ -31,6 +31,15 @@ #include "tools/converter/import/mindspore_importer.h" namespace mindspore { namespace lite { +namespace { +void InitConverterParameters(const converter::Flags &flag, converter::ConverterParameters *converter_parameters) { + converter_parameters->fmk_ = flag.fmk; + converter_parameters->quant_type_ = flag.quantType; + converter_parameters->model_file_ = flag.modelFile; + converter_parameters->weight_file_ = flag.weightFile; +} +} // namespace + FuncGraphPtr Converter::BuildFuncGraph(const converter::Flags &flag) { FuncGraphPtr func_graph = nullptr; if (flag.fmk == converter::FmkType::FmkType_MS) { @@ -45,7 +54,9 @@ FuncGraphPtr Converter::BuildFuncGraph(const converter::Flags &flag) { if (model_parser_ == nullptr) { return nullptr; } - func_graph = model_parser_->Parse(flag); + converter::ConverterParameters converter_parameters; + InitConverterParameters(flag, &converter_parameters); + func_graph = model_parser_->Parse(converter_parameters); } if (func_graph == nullptr) { MS_LOG(ERROR) << "Get funcGraph failed for fmk: " << flag.fmkIn; diff --git a/mindspore/lite/tools/converter/model_parser.h b/mindspore/lite/tools/converter/model_parser.h index 05491aea2fd..11f3be07e43 100644 --- a/mindspore/lite/tools/converter/model_parser.h +++ b/mindspore/lite/tools/converter/model_parser.h @@ -22,9 +22,8 @@ #include "schema/inner/model_generated.h" #include "ir/anf.h" #include "ir/func_graph.h" -#include "tools/converter/converter_context.h" -#include "tools/converter/converter_flags.h" -#include "tools/converter/quant_param_holder.h" +#include "include/registry/model_parser_registry.h" +#include "utils/log_adapter.h" namespace mindspore::lite { using namespace schema; @@ -34,7 +33,7 @@ class ModelParser { virtual ~ModelParser() = default; - virtual FuncGraphPtr Parse(const converter::Flags &flag) { return this->res_graph_; } + virtual FuncGraphPtr Parse(const converter::ConverterParameters &flags) { return this->res_graph_; } protected: FuncGraphPtr res_graph_ = nullptr; diff --git a/mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.cc b/mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.cc index f6f2f1e0b76..e8fcd9b3307 100644 --- a/mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.cc +++ b/mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.cc @@ -67,10 +67,10 @@ CaffeModelParser::CaffeModelParser() = default; CaffeModelParser::~CaffeModelParser() = default; -FuncGraphPtr CaffeModelParser::Parse(const converter::Flags &flag) { - auto model_file = flag.modelFile; - auto weight_file = flag.weightFile; - quant_type_ = flag.quantType; +FuncGraphPtr CaffeModelParser::Parse(const converter::ConverterParameters &flag) { + auto model_file = flag.model_file_; + auto weight_file = flag.weight_file_; + quant_type_ = flag.quant_type_; STATUS status = InitOriginModel(model_file, weight_file); if (status != RET_OK) { ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); diff --git a/mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.h b/mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.h index d864789aefc..4228bda1d4b 100644 --- a/mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.h +++ b/mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.h @@ -34,7 +34,7 @@ class CaffeModelParser : public ModelParser { ~CaffeModelParser() override; - FuncGraphPtr Parse(const converter::Flags &flag) override; + FuncGraphPtr Parse(const converter::ConverterParameters &flag) override; private: STATUS InitOriginModel(const std::string &model_file, const std::string &weight_file); diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc index 6c610b26aad..a3cfbf40b4e 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc @@ -53,9 +53,9 @@ std::unordered_map TYPE_MAP = { {onnx::TensorProto_DataType_FLOAT, mindspore::kNumberTypeFloat32}, {onnx::TensorProto_DataType_BOOL, mindspore::kNumberTypeBool}}; -FuncGraphPtr OnnxModelParser::Parse(const converter::Flags &flag) { - string model_file = flag.modelFile; - quant_type_ = flag.quantType; +FuncGraphPtr OnnxModelParser::Parse(const converter::ConverterParameters &flag) { + string model_file = flag.model_file_; + quant_type_ = flag.quant_type_; NotSupportOp::GetInstance()->set_fmk_type("ONNX"); res_graph_ = std::make_shared(); auto status = InitOriginModel(model_file); diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.h index da213b8672e..d4a170069ae 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.h +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.h @@ -42,7 +42,7 @@ class OnnxModelParser : public ModelParser { ~OnnxModelParser() override = default; - FuncGraphPtr Parse(const converter::Flags &flag) override; + FuncGraphPtr Parse(const converter::ConverterParameters &flag) override; static int Onnx2AnfAdjust(const std::set &all_func_graphs); diff --git a/mindspore/lite/tools/converter/parser/tf/tf_model_parser.cc b/mindspore/lite/tools/converter/parser/tf/tf_model_parser.cc index 9551b7834d9..92aae3fca64 100644 --- a/mindspore/lite/tools/converter/parser/tf/tf_model_parser.cc +++ b/mindspore/lite/tools/converter/parser/tf/tf_model_parser.cc @@ -22,6 +22,7 @@ #include "src/common/utils.h" #include "tools/common/graph_util.h" #include "tools/common/protobuf_utils.h" +#include "tools/converter/converter_context.h" #include "tools/converter/parser/tf/tf_node_parser_registry.h" #include "tools/optimizer/common/gllo_utils.h" #include "tools/converter/ops/ops_def.h" @@ -478,9 +479,9 @@ STATUS TFModelParser::ConvertGraphInputsAndConsts( return RET_OK; } -FuncGraphPtr TFModelParser::Parse(const converter::Flags &flag) { - auto modelFile = flag.modelFile; - quant_type_ = flag.quantType; +FuncGraphPtr TFModelParser::Parse(const converter::ConverterParameters &flag) { + auto modelFile = flag.model_file_; + quant_type_ = flag.quant_type_; NotSupportOp::GetInstance()->set_fmk_type("TF"); auto status = ValidateFileStr(modelFile, ".pb"); if (status != RET_OK) { diff --git a/mindspore/lite/tools/converter/parser/tf/tf_model_parser.h b/mindspore/lite/tools/converter/parser/tf/tf_model_parser.h index f78ef636f3b..2a63210d61f 100644 --- a/mindspore/lite/tools/converter/parser/tf/tf_model_parser.h +++ b/mindspore/lite/tools/converter/parser/tf/tf_model_parser.h @@ -40,7 +40,7 @@ class TFModelParser : public ModelParser { TFModelParser() = default; ~TFModelParser() override = default; - FuncGraphPtr Parse(const converter::Flags &flag) override; + FuncGraphPtr Parse(const converter::ConverterParameters &flag) override; static int TF2AnfAdjust(const std::set &all_func_graphs); diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc index a0430ddfc0f..caa6968fb26 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc @@ -51,9 +51,9 @@ std::unique_ptr TfliteModelParser::ReadTfliteModel(const std::st return tflite::UnPackModel(tflite_model_buf_); } -FuncGraphPtr TfliteModelParser::Parse(const converter::Flags &flag) { - auto model_file = flag.modelFile; - quant_type_ = flag.quantType; +FuncGraphPtr TfliteModelParser::Parse(const converter::ConverterParameters &flag) { + auto model_file = flag.model_file_; + quant_type_ = flag.quant_type_; // load graph tflite_model_ = ReadTfliteModel(model_file); if (tflite_model_ == nullptr) { diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.h index 96d44707951..e88143bb942 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.h @@ -34,7 +34,7 @@ class TfliteModelParser : public ModelParser { ~TfliteModelParser() override = default; - FuncGraphPtr Parse(const converter::Flags &flag) override; + FuncGraphPtr Parse(const converter::ConverterParameters &flag) override; static int Tflite2AnfAdjust(const std::set &all_func_graphs);