From 6f738d2e4a22cf98530494f8a14e4b44d152316d Mon Sep 17 00:00:00 2001 From: xuanyue Date: Wed, 15 Sep 2021 17:08:40 +0800 Subject: [PATCH] draw out model_parser basic class --- cmake/package_lite.cmake | 4 -- mindspore/lite/build_lite.sh | 4 +- .../registry}/model_parser.h | 45 ++++++++----------- .../include/registry/model_parser_registry.h | 3 ++ .../lite/include/registry/parser_context.h | 3 -- .../converter/registry/model_parser_test.h | 2 +- .../converter/registry/pass_registry_test.cc | 1 - mindspore/lite/tools/converter/converter.h | 2 +- .../parser/caffe/caffe_model_parser.cc | 2 +- .../parser/caffe/caffe_model_parser.h | 2 +- .../parser/onnx/onnx_model_parser.cc | 2 +- .../converter/parser/onnx/onnx_model_parser.h | 2 +- .../tools/converter/parser/parser_utils.h | 11 +++++ .../converter/parser/tf/tf_model_parser.cc | 2 +- .../converter/parser/tf/tf_model_parser.h | 2 +- .../parser/tflite/tflite_model_parser.cc | 2 +- .../parser/tflite/tflite_model_parser.h | 2 +- .../optimizer/fusion/batchmatmul_fusion.cc | 4 +- .../fusion/tf_bidirection_gru_fusion.cc | 9 ++-- .../graph/update_conv2d_param_pass.cc | 2 +- 20 files changed, 53 insertions(+), 53 deletions(-) rename mindspore/lite/{tools/converter => include/registry}/model_parser.h (55%) diff --git a/cmake/package_lite.cmake b/cmake/package_lite.cmake index a3dbbd004c2..a2165aa8805 100644 --- a/cmake/package_lite.cmake +++ b/cmake/package_lite.cmake @@ -334,8 +334,6 @@ elseif(WIN32) COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h") install(DIRECTORY ${TOP_DIR}/mindspore/core/utils/ DESTINATION ${CONVERTER_ROOT_DIR}/include/core/utils COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h") - install(FILES ${TOP_DIR}/mindspore/lite/tools/converter/model_parser.h - DESTINATION ${CONVERTER_ROOT_DIR}/include COMPONENT ${RUNTIME_COMPONENT_NAME}) install(FILES ${TOP_DIR}/mindspore/lite/tools/converter/ops/ops_def.h DESTINATION ${CONVERTER_ROOT_DIR}/include COMPONENT ${RUNTIME_COMPONENT_NAME}) install(DIRECTORY ${TOP_DIR}/build/mindspore/schema/ DESTINATION ${CONVERTER_ROOT_DIR}/include/schema @@ -477,8 +475,6 @@ else() COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h") install(DIRECTORY ${TOP_DIR}/mindspore/core/utils/ DESTINATION ${CONVERTER_ROOT_DIR}/include/core/utils COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h") - install(FILES ${TOP_DIR}/mindspore/lite/tools/converter/model_parser.h - DESTINATION ${CONVERTER_ROOT_DIR}/include COMPONENT ${RUNTIME_COMPONENT_NAME}) install(FILES ${TOP_DIR}/mindspore/lite/tools/converter/ops/ops_def.h DESTINATION ${CONVERTER_ROOT_DIR}/include COMPONENT ${RUNTIME_COMPONENT_NAME}) install(DIRECTORY ${TOP_DIR}/mindspore/lite/build/schema/ DESTINATION ${CONVERTER_ROOT_DIR}/include/schema diff --git a/mindspore/lite/build_lite.sh b/mindspore/lite/build_lite.sh index daf73b06929..8b34d93885a 100755 --- a/mindspore/lite/build_lite.sh +++ b/mindspore/lite/build_lite.sh @@ -237,13 +237,13 @@ build_lite() { compile_nnie_script=${BASEPATH}/mindspore/lite/tools/providers/NNIE/Hi3516D/compile_nnie.sh cd ${BASEPATH}/../ if [[ "${local_lite_platform}" == "x86_64" ]]; then - sh ${compile_nnie_script} -I x86_64 -b nnie_3516_master -j $THREAD_NUM + sh ${compile_nnie_script} -I x86_64 -b nnie_3516_master_dev -j $THREAD_NUM if [[ $? -ne 0 ]]; then echo "compile x86_64 for nnie failed." exit 1 fi elif [[ "${local_lite_platform}" == "arm32" ]]; then - sh ${compile_nnie_script} -I arm32 -b nnie_3516_master -j $THREAD_NUM + sh ${compile_nnie_script} -I arm32 -b nnie_3516_master_dev -j $THREAD_NUM if [[ $? -ne 0 ]]; then echo "compile arm32 for nnie failed." exit 1 diff --git a/mindspore/lite/tools/converter/model_parser.h b/mindspore/lite/include/registry/model_parser.h similarity index 55% rename from mindspore/lite/tools/converter/model_parser.h rename to mindspore/lite/include/registry/model_parser.h index 15b1f5fdba0..b8d8b40c6c7 100644 --- a/mindspore/lite/tools/converter/model_parser.h +++ b/mindspore/lite/include/registry/model_parser.h @@ -14,41 +14,34 @@ * limitations under the License. */ -#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_MODEL_PARSER_H -#define MINDSPORE_LITE_TOOLS_CONVERTER_MODEL_PARSER_H -#include -#include -#include -#include "schema/inner/model_generated.h" -#include "ir/anf.h" -#include "api/ir/func_graph.h" -#include "include/registry/model_parser_registry.h" -#include "utils/log_adapter.h" +#ifndef MINDSPORE_LITE_INCLUDE_REGISTRY_MODEL_PARSER_H_ +#define MINDSPORE_LITE_INCLUDE_REGISTRY_MODEL_PARSER_H_ -namespace mindspore::converter { -class ModelParser { +#include "api/ir/func_graph.h" +#include "include/registry/parser_context.h" + +namespace mindspore { +namespace converter { +/// \brief ModelParser defined a base class to parse model. +class MS_API ModelParser { public: + /// \brief Constructor. ModelParser() = default; + /// \brief Destructor. virtual ~ModelParser() = default; + /// \brief Method to parse model, which must be onnx/caffe/tf/tflite. + /// + /// \param[in] flags Define the basic parameters when converting, which defined in parser_context.h. + /// + /// \return FuncGraph Pointer, which contains all information about the model. virtual api::FuncGraphPtr Parse(const converter::ConverterParameters &flags) { return this->res_graph_; } protected: api::FuncGraphPtr res_graph_ = nullptr; }; +} // namespace converter +} // namespace mindspore -typedef ModelParser *(*ModelParserCreator)(); - -template -ModelParser *LiteModelParserCreator() { - auto *parser = new (std::nothrow) T(); - if (parser == nullptr) { - MS_LOG(ERROR) << "new model parser failed"; - return nullptr; - } - return parser; -} -} // namespace mindspore::converter - -#endif +#endif // MINDSPORE_LITE_INCLUDE_REGISTRY_MODEL_PARSER_H_ diff --git a/mindspore/lite/include/registry/model_parser_registry.h b/mindspore/lite/include/registry/model_parser_registry.h index 5b6a0b5899a..c13a280e909 100644 --- a/mindspore/lite/include/registry/model_parser_registry.h +++ b/mindspore/lite/include/registry/model_parser_registry.h @@ -23,6 +23,9 @@ using mindspore::converter::FmkType; namespace mindspore { +namespace converter { +class ModelParser; +} // namespace converter namespace registry { /// \brief ModelParserCreator defined function pointer to get a ModelParser class. typedef converter::ModelParser *(*ModelParserCreator)(); diff --git a/mindspore/lite/include/registry/parser_context.h b/mindspore/lite/include/registry/parser_context.h index 3971599417e..0ccb3f1602e 100644 --- a/mindspore/lite/include/registry/parser_context.h +++ b/mindspore/lite/include/registry/parser_context.h @@ -39,9 +39,6 @@ struct MS_API ConverterParameters { std::string weight_file; std::map attrs; }; - -/// \brief ModelParser defined a base class of model parser -class MS_API ModelParser; } // namespace converter } // namespace mindspore diff --git a/mindspore/lite/test/ut/tools/converter/registry/model_parser_test.h b/mindspore/lite/test/ut/tools/converter/registry/model_parser_test.h index 33fa2565392..8c18659b80a 100644 --- a/mindspore/lite/test/ut/tools/converter/registry/model_parser_test.h +++ b/mindspore/lite/test/ut/tools/converter/registry/model_parser_test.h @@ -20,9 +20,9 @@ #include #include #include +#include "include/registry/model_parser.h" #include "include/registry/model_parser_registry.h" #include "ut/tools/converter/registry/node_parser_test.h" -#include "tools/converter/model_parser.h" namespace mindspore { class ModelParserTest : public converter::ModelParser { diff --git a/mindspore/lite/test/ut/tools/converter/registry/pass_registry_test.cc b/mindspore/lite/test/ut/tools/converter/registry/pass_registry_test.cc index a6b5b0745fb..5c145026804 100644 --- a/mindspore/lite/test/ut/tools/converter/registry/pass_registry_test.cc +++ b/mindspore/lite/test/ut/tools/converter/registry/pass_registry_test.cc @@ -24,7 +24,6 @@ #include "ops/fusion/add_fusion.h" #include "ops/addn.h" #include "ops/custom.h" -#include "tools/converter/model_parser.h" #include "tools/converter/optimizer_manager.h" #include "tools/optimizer/common/gllo_utils.h" #include "ut/tools/converter/registry/model_parser_test.h" diff --git a/mindspore/lite/tools/converter/converter.h b/mindspore/lite/tools/converter/converter.h index 96a0dd74957..1bd701f8de5 100644 --- a/mindspore/lite/tools/converter/converter.h +++ b/mindspore/lite/tools/converter/converter.h @@ -19,9 +19,9 @@ #include #include +#include "include/registry/model_parser.h" #include "schema/inner/model_generated.h" #include "tools/converter/graphdef_transform.h" -#include "tools/converter/model_parser.h" #include "include/registry/model_parser_registry.h" #include "tools/converter/converter_flags.h" #include "tools/converter/anf_transform.h" diff --git a/mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.cc b/mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.cc index f033512ee44..b148e0c41c9 100644 --- a/mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.cc +++ b/mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.cc @@ -576,5 +576,5 @@ std::string CaffeModelParser::GetOriginLayerName(const std::string &layer_name) } return layer.name(); } -REG_MODEL_PARSER(kFmkTypeCaffe, converter::LiteModelParserCreator) +REG_MODEL_PARSER(kFmkTypeCaffe, LiteModelParserCreator) } // namespace mindspore::lite diff --git a/mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.h b/mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.h index ac1cd4012d1..21116d1bd3a 100644 --- a/mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.h +++ b/mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.h @@ -21,7 +21,7 @@ #include #include #include -#include "tools/converter/model_parser.h" +#include "include/registry/model_parser.h" #include "include/registry/model_parser_registry.h" #include "proto/caffe.pb.h" #include "ops/primitive_c.h" diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc index 0ff445046e5..5c273435285 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc @@ -1248,6 +1248,6 @@ STATUS OnnxModelParser::BuildParameterNodeForQuantParam(const void *data, const return RET_OK; } -REG_MODEL_PARSER(kFmkTypeOnnx, converter::LiteModelParserCreator) +REG_MODEL_PARSER(kFmkTypeOnnx, LiteModelParserCreator) } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.h index 26d3add75b9..f92cdd2adcf 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.h +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.h @@ -27,7 +27,7 @@ #include #include #include "securec/include/securec.h" -#include "tools/converter/model_parser.h" +#include "include/registry/model_parser.h" #include "include/registry/model_parser_registry.h" #include "tools/converter/parser/onnx/onnx_node_parser_registry.h" #include "proto/onnx.pb.h" diff --git a/mindspore/lite/tools/converter/parser/parser_utils.h b/mindspore/lite/tools/converter/parser/parser_utils.h index aa2215161cf..1e48e05f988 100644 --- a/mindspore/lite/tools/converter/parser/parser_utils.h +++ b/mindspore/lite/tools/converter/parser/parser_utils.h @@ -19,6 +19,7 @@ #include #include +#include "include/registry/model_parser.h" #include "ir/anf.h" #include "ir/func_graph.h" #include "src/common/log_adapter.h" @@ -39,6 +40,16 @@ int UnifyConstConvWeight(const FuncGraphPtr &graph, const AnfNodePtr &weight_nod schema::Format dst_format, std::set *has_visited); int HandleConstConvWeightShared(const FuncGraphPtr &graph, const AnfNodePtr &weight_node, schema::Format src_format, schema::Format dst_format, std::set *has_visited); + +template +converter::ModelParser *LiteModelParserCreator() { + auto *parser = new (std::nothrow) T(); + if (parser == nullptr) { + MS_LOG(ERROR) << "new model parser failed"; + return nullptr; + } + return parser; +} } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tf/tf_model_parser.cc b/mindspore/lite/tools/converter/parser/tf/tf_model_parser.cc index b59e7cb6760..8adf8b44add 100644 --- a/mindspore/lite/tools/converter/parser/tf/tf_model_parser.cc +++ b/mindspore/lite/tools/converter/parser/tf/tf_model_parser.cc @@ -1158,6 +1158,6 @@ int TFModelParser::TF2AnfAdjust(const std::set &all_func_graphs) { return RET_OK; } -REG_MODEL_PARSER(kFmkTypeTf, converter::LiteModelParserCreator) +REG_MODEL_PARSER(kFmkTypeTf, LiteModelParserCreator) } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tf/tf_model_parser.h b/mindspore/lite/tools/converter/parser/tf/tf_model_parser.h index 6b448e12034..7c9a853b8b7 100644 --- a/mindspore/lite/tools/converter/parser/tf/tf_model_parser.h +++ b/mindspore/lite/tools/converter/parser/tf/tf_model_parser.h @@ -29,7 +29,7 @@ #include "schema/inner/model_generated.h" #include "securec/include/securec.h" #include "tools/common/tensor_util.h" -#include "tools/converter/model_parser.h" +#include "include/registry/model_parser.h" #include "include/registry/model_parser_registry.h" #include "ops/primitive_c.h" diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc index 81905594522..52fe1422f38 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc @@ -734,5 +734,5 @@ int TfliteModelParser::Tflite2AnfAdjust(const std::set &all_func_g return RET_OK; } -REG_MODEL_PARSER(kFmkTypeTflite, converter::LiteModelParserCreator) +REG_MODEL_PARSER(kFmkTypeTflite, LiteModelParserCreator) } // namespace mindspore::lite diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.h index f80ce143307..5fb4f99a8fe 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.h @@ -23,7 +23,7 @@ #include #include #include -#include "tools/converter/model_parser.h" +#include "include/registry/model_parser.h" #include "include/registry/model_parser_registry.h" #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" #include "tools/common/tensor_util.h" diff --git a/mindspore/lite/tools/optimizer/fusion/batchmatmul_fusion.cc b/mindspore/lite/tools/optimizer/fusion/batchmatmul_fusion.cc index 04341615baa..ddb33db9eeb 100644 --- a/mindspore/lite/tools/optimizer/fusion/batchmatmul_fusion.cc +++ b/mindspore/lite/tools/optimizer/fusion/batchmatmul_fusion.cc @@ -68,7 +68,7 @@ STATUS GetRightMatmulInputParamter(const CNodePtr &stack_node, const ParameterPt return RET_ERROR; } for (size_t i = 1; i < joint_fullconnect_size + 1; i++) { - auto tensor_addr = GetInputAddr(stack_node->input(i), 2); + auto tensor_addr = GetInputAddr(stack_node->input(i), kInputIndexTwo); if (tensor_addr == nullptr) { MS_LOG(ERROR) << "input tensor addr nullptr"; return RET_ERROR; @@ -178,7 +178,7 @@ const AnfNodePtr BatchMatMulFusion::Process(const FuncGraphPtr &func_graph, cons if (IsMarkedTrainOp(fullconnect_cnode)) { return nullptr; } - MS_CHECK_TRUE_RET(fullconnect_cnode->inputs().size() == 3, nullptr); + MS_CHECK_TRUE_RET(fullconnect_cnode->inputs().size() == kInputSizeThree, nullptr); auto left_slice_node = fullconnect_cnode->input(1); auto left_slice_cnode = left_slice_node->cast(); if (IsMarkedTrainOp(left_slice_cnode)) { diff --git a/mindspore/lite/tools/optimizer/fusion/tf_bidirection_gru_fusion.cc b/mindspore/lite/tools/optimizer/fusion/tf_bidirection_gru_fusion.cc index 2ded271c44a..1008d28eefb 100644 --- a/mindspore/lite/tools/optimizer/fusion/tf_bidirection_gru_fusion.cc +++ b/mindspore/lite/tools/optimizer/fusion/tf_bidirection_gru_fusion.cc @@ -31,6 +31,7 @@ namespace mindspore { namespace opt { namespace { +constexpr int kOffsetTwo = 2; constexpr size_t kCondNodesNum = 12; constexpr size_t kCondCNodesNum = 4; constexpr size_t kBodyNodesNum = 69; @@ -162,7 +163,7 @@ const VectorRef TfBidirectionGruFusion::DefineFowardPattern() const { MS_CHECK_TRUE_RET(is_param6 != nullptr, {}); auto fw_while = VectorRef({is_while, fw_vars_[0], fw_vars_[1], is_param5, fw_stride, is_param6, fw_reserve, fw_init_state_, fw_min, fw_from_tensor, input_length_}); - fw_while.insert(fw_while.end(), fw_vars_.begin() + 2, fw_vars_.end()); + fw_while.insert(fw_while.end(), fw_vars_.begin() + kOffsetTwo, fw_vars_.end()); auto is_var1 = std::make_shared(); MS_CHECK_TRUE_RET(is_var1 != nullptr, {}); fw_while.emplace_back(is_var1); @@ -232,7 +233,7 @@ const VectorRef TfBidirectionGruFusion::DefinebackwardPattern() const { MS_CHECK_TRUE_RET(is_param6 != nullptr, {}); auto bw_while = VectorRef({is_while, bw_vars_[0], bw_vars_[1], is_param5, bw_stride, is_param6, bw_reserve, bw_init_state_, bw_min, bw_from_tensor, input_length_}); - bw_while.insert(bw_while.end(), bw_vars_.begin() + 2, bw_vars_.end()); + bw_while.insert(bw_while.end(), bw_vars_.begin() + kOffsetTwo, bw_vars_.end()); auto is_var2 = std::make_shared(); MS_CHECK_TRUE_RET(is_var2 != nullptr, {}); bw_while.emplace_back(is_var2); @@ -400,7 +401,7 @@ STATUS TfBidirectionGruFusion::GetInputAndHiddenSize(const AnfNodePtr &fw_cand_k return RET_ERROR; } auto fw_cand_kernel_shape = fw_cand_kernel_value->shape(); - if (fw_cand_kernel_shape.size() != 2) { + if (fw_cand_kernel_shape.size() != kInputSizeTwo) { return RET_ERROR; } auto bw_cand_kernel_value = GetDefaultTensorInfo(bw_cand_kernel_anf); @@ -408,7 +409,7 @@ STATUS TfBidirectionGruFusion::GetInputAndHiddenSize(const AnfNodePtr &fw_cand_k return RET_ERROR; } auto bw_cand_kernel_shape = bw_cand_kernel_value->shape(); - if (bw_cand_kernel_shape.size() != 2) { + if (bw_cand_kernel_shape.size() != kInputSizeTwo) { return RET_ERROR; } if (fw_cand_kernel_shape != bw_cand_kernel_shape) { diff --git a/mindspore/lite/tools/optimizer/graph/update_conv2d_param_pass.cc b/mindspore/lite/tools/optimizer/graph/update_conv2d_param_pass.cc index 93ba17ba4fe..b19a329a8a0 100644 --- a/mindspore/lite/tools/optimizer/graph/update_conv2d_param_pass.cc +++ b/mindspore/lite/tools/optimizer/graph/update_conv2d_param_pass.cc @@ -45,7 +45,7 @@ void SetConvAttr(const PrimitivePtr &prim, const std::vector &kernel_si STATUS UpdateConv2DParamPass::UpdateConv2DAttr(const CNodePtr &cnode) { MS_ASSERT(cnode != nullptr); if (cnode->size() < kInputSizeThree) { - MS_LOG(ERROR) << "conv2d's input size is invalid, now is " << cnode->size() - 1; + MS_LOG(ERROR) << "conv2d's input size is invalid, now is " << (cnode->size() - 1); return lite::RET_ERROR; } auto weight = cnode->input(kInputIndexTwo);