From b0776fd0043536d83555cd4787443d50b2abcd46 Mon Sep 17 00:00:00 2001 From: sunsuodong Date: Tue, 19 Apr 2022 19:39:04 -0700 Subject: [PATCH] split converter so --- cmake/package_lite.cmake | 13 +- .../lite/test/st/mindrt_parallel_test.cc | 24 +- .../fusion/activation_fusion_test.cc | 6 +- .../fusion/add_concat_act_fusion_test.cc | 2 +- .../fusion/conv_activation_fusion_test.cc | 6 +- .../fusion/conv_biasadd_fusion_test.cc | 6 +- .../optimizer/fusion/conv_bn_fusion_test.cc | 4 +- .../fusion/conv_scale_fusion_test.cc | 4 +- .../matmul_act_fusion_inout_test.cc | 6 +- .../fusion/matmul_mul_fusion_test.cc | 2 +- .../fusion/trans_matmul_fusion_test.cc | 6 +- mindspore/lite/tools/converter/CMakeLists.txt | 38 +- .../tools/converter/adapter/acl/acl_pass.cc | 4 +- .../tools/converter/adapter/acl/acl_pass.h | 4 +- .../adapter/acl/src/acl_pass_impl.cc | 6 +- .../converter/adapter/acl/src/acl_pass_impl.h | 4 +- .../lite/tools/converter/anf_transform.cc | 105 +++-- .../lite/tools/converter/anf_transform.h | 19 +- mindspore/lite/tools/converter/converter.cc | 367 ++++++++++++++---- mindspore/lite/tools/converter/converter.h | 37 +- .../converter/converter_lite/CMakeLists.txt | 6 +- .../{ => converter_lite}/converter_flags.cc | 274 ++----------- .../{ => converter_lite}/converter_flags.h | 76 +--- .../tools/converter/converter_lite/main.cc | 82 ++++ .../lite/tools/converter/cxx_api/converter.cc | 4 + .../lite/tools/converter/export_model.cc | 26 +- mindspore/lite/tools/converter/export_model.h | 8 +- .../tools/converter/graphdef_transform.cc | 28 +- .../lite/tools/converter/graphdef_transform.h | 3 +- .../tools/converter/import/mindir_adjust.h | 2 +- .../import/mindir_control_flow_adjust.h | 2 +- .../converter/import/mindspore_importer.cc | 39 +- .../converter/import/mindspore_importer.h | 12 +- .../tools/converter/import/primitive_adjust.h | 2 +- .../legacy_optimizer/graph/infershape_pass.cc | 1 - .../legacy_optimizer/graph/infershape_pass.h | 6 +- .../set_unused_quant_param_to_default_pass.cc | 2 +- .../set_unused_quant_param_to_default_pass.h | 7 +- mindspore/lite/tools/converter/main.cc | 57 --- .../converter/parser/tf/functionalize_cond.h | 2 - .../parser/tf/functionalize_control_op_pass.h | 2 +- .../converter/parser/tf/functionalize_while.h | 2 - .../converter/parser/tflite/CMakeLists.txt | 1 + .../quantizer/bias_correction_strategy.cc | 7 +- .../quantizer/bias_correction_strategy.h | 6 +- .../tools/converter/quantizer/cle_strategy.h | 4 +- .../converter/quantizer/debug_info_manager.cc | 28 +- .../converter/quantizer/debug_info_manager.h | 6 +- .../converter/quantizer/dynamic_quantizer.cc | 8 +- .../converter/quantizer/dynamic_quantizer.h | 4 +- .../quantizer/full_quant_quantizer.cc | 30 +- .../quantizer/full_quant_quantizer.h | 4 +- .../converter/quantizer/parameter_tunner.cc | 68 ++-- .../converter/quantizer/parameter_tunner.h | 15 +- .../quantizer/quantization_optimizer.cc | 71 ++-- .../quantizer/quantization_optimizer.h | 7 +- .../converter/quantizer/quantize_util.cc | 8 +- .../tools/converter/quantizer/quantize_util.h | 7 +- .../tools/converter/quantizer/quantizer.h | 6 +- .../converter/quantizer/weight_quantizer.cc | 18 +- .../converter/quantizer/weight_quantizer.h | 6 +- .../lite/tools/lite_exporter/fetch_content.h | 2 +- .../mindir_exporter/mindir_serializer.cc | 10 +- .../tools/mindir_exporter/mindir_serializer.h | 6 +- .../optimizer/fusion/conv_transform_fusion.h | 2 +- .../fusion/matmul_activation_fusion.cc | 4 +- .../fusion/matmul_activation_fusion.h | 11 +- .../graph/clip_convert_activation_pass.h | 3 - .../lite/tools/optimizer/graph/dump_graph.h | 7 +- .../optimizer/graph/slice_prepose_pass.h | 2 +- .../graph/unused_transpose_node_remove_pass.h | 2 +- 71 files changed, 822 insertions(+), 837 deletions(-) rename mindspore/lite/tools/converter/{ => converter_lite}/converter_flags.cc (58%) rename mindspore/lite/tools/converter/{ => converter_lite}/converter_flags.h (57%) create mode 100644 mindspore/lite/tools/converter/converter_lite/main.cc delete mode 100644 mindspore/lite/tools/converter/main.cc diff --git a/cmake/package_lite.cmake b/cmake/package_lite.cmake index 15c6efb3121..c8b21af45da 100644 --- a/cmake/package_lite.cmake +++ b/cmake/package_lite.cmake @@ -5,6 +5,11 @@ set(RUNTIME_PKG_NAME ${PKG_NAME_PREFIX}-${RUNTIME_COMPONENT_NAME}) set(CONVERTER_ROOT_DIR ${RUNTIME_PKG_NAME}/tools/converter) set(OBFUSCATOR_ROOT_DIR ${RUNTIME_PKG_NAME}/tools/obfuscator) set(CROPPER_ROOT_DIR ${RUNTIME_PKG_NAME}/tools/cropper) +if(WIN32) + set(BUILD_DIR ${TOP_DIR}/build) +else() + set(BUILD_DIR ${TOP_DIR}/mindspore/lite/build) +endif() set(TEST_CASE_DIR ${TOP_DIR}/mindspore/lite/test/build) set(RUNTIME_DIR ${RUNTIME_PKG_NAME}/runtime) @@ -483,6 +488,8 @@ if(PLATFORM_ARM64) COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h") install(TARGETS converter_lite RUNTIME DESTINATION ${CONVERTER_ROOT_DIR}/converter COMPONENT ${RUNTIME_COMPONENT_NAME}) + install(FILES ${BUILD_DIR}/tools/converter/libmindspore_converter.so + DESTINATION ${CONVERTER_ROOT_DIR}/lib COMPONENT ${RUNTIME_COMPONENT_NAME}) install(FILES ${TOP_DIR}/mindspore/lite/build/tools/converter/registry/libmslite_converter_plugin.so DESTINATION ${CONVERTER_ROOT_DIR}/lib COMPONENT ${RUNTIME_COMPONENT_NAME}) install(FILES ${glog_LIBPATH}/libglog.so.0.4.0 DESTINATION ${CONVERTER_ROOT_DIR}/lib RENAME libglog.so.0 @@ -713,8 +720,10 @@ elseif(WIN32) file(GLOB LIB_LIST ${CXX_DIR}/libstdc++-6.dll ${CXX_DIR}/libwinpthread-1.dll ${CXX_DIR}/libssp-0.dll ${CXX_DIR}/libgcc_s_*-1.dll) if(MSLITE_ENABLE_CONVERTER) - install(FILES ${TOP_DIR}/build/mindspore/tools/converter/converter_lite.exe + install(FILES ${TOP_DIR}/build/mindspore/tools/converter/converter_lite/converter_lite.exe DESTINATION ${CONVERTER_ROOT_DIR}/converter COMPONENT ${RUNTIME_COMPONENT_NAME}) + install(FILES ${TOP_DIR}/build/mindspore/tools/converter/libmindspore_converter.dll + DESTINATION ${CONVERTER_ROOT_DIR}/lib COMPONENT ${RUNTIME_COMPONENT_NAME}) install(FILES ${LIB_LIST} DESTINATION ${CONVERTER_ROOT_DIR}/lib COMPONENT ${RUNTIME_COMPONENT_NAME}) install(FILES ${TOP_DIR}/build/mindspore/tools/converter/registry/libmslite_converter_plugin.dll DESTINATION ${CONVERTER_ROOT_DIR}/lib COMPONENT ${RUNTIME_COMPONENT_NAME}) @@ -867,6 +876,8 @@ else() COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h") install(TARGETS converter_lite RUNTIME DESTINATION ${CONVERTER_ROOT_DIR}/converter COMPONENT ${RUNTIME_COMPONENT_NAME}) + install(FILES ${BUILD_DIR}/tools/converter/libmindspore_converter.so + DESTINATION ${CONVERTER_ROOT_DIR}/lib COMPONENT ${RUNTIME_COMPONENT_NAME}) install(FILES ${TOP_DIR}/mindspore/lite/build/tools/converter/registry/libmslite_converter_plugin.so DESTINATION ${CONVERTER_ROOT_DIR}/lib COMPONENT ${RUNTIME_COMPONENT_NAME}) install(FILES ${glog_LIBPATH}/libglog.so.0.4.0 DESTINATION ${CONVERTER_ROOT_DIR}/lib RENAME libglog.so.0 diff --git a/mindspore/lite/test/st/mindrt_parallel_test.cc b/mindspore/lite/test/st/mindrt_parallel_test.cc index 99ea0a0a9c5..5f71f7f5492 100644 --- a/mindspore/lite/test/st/mindrt_parallel_test.cc +++ b/mindspore/lite/test/st/mindrt_parallel_test.cc @@ -24,6 +24,7 @@ #include "src/runtime/lite_session.h" #include "src/runtime/kernel_exec.h" #include "src/common/file_utils.h" +#include "include/converter.h" namespace mindspore { class MindrtParallelTest : public mindspore::CommonTest { @@ -96,12 +97,12 @@ int CheckRuntime1(lite::LiteSession *session) { } TEST_F(MindrtParallelTest, offline1) { - const char *converter_argv[] = {"./converter", "--fmk=TFLITE", - "--modelFile=./mindrtParallel/mindrt_parallel_model.tflite", - "--outputFile=./mindrtParallel/mindrt_parallel_model_split", - "--configFile=./mindrtParallel/mindrt_parallel_model.config"}; - int converter_ret = mindspore::lite::RunConverter(5, converter_argv); - ASSERT_EQ(converter_ret, lite::RET_OK); + mindspore::Converter converter(converter::kFmkTypeTflite, "./mindrtParallel/mindrt_parallel_model.tflite", + "./mindrtParallel/mindrt_parallel_model_split"); + converter.SetConfigFile("./mindrtParallel/mindrt_parallel_model.config"); + + auto status = converter.Convert(); + ASSERT_EQ(status, kSuccess); size_t size = 0; char *graph_buf = lite::ReadFile("./mindrtParallel/mindrt_parallel_model_split.ms", &size); @@ -135,11 +136,12 @@ TEST_F(MindrtParallelTest, offline1) { } TEST_F(MindrtParallelTest, runtime1) { - const char *converter_argv[] = {"./converter", "--fmk=TFLITE", - "--modelFile=./mindrtParallel/mindrt_parallel_model.tflite", - "--outputFile=./mindrtParallel/mindrt_parallel_model"}; - int converter_ret = mindspore::lite::RunConverter(4, converter_argv); - ASSERT_EQ(converter_ret, lite::RET_OK); + mindspore::Converter converter(converter::kFmkTypeTflite, "./mindrtParallel/mindrt_parallel_model.tflite", + "./mindrtParallel/mindrt_parallel_model"); + converter.SetConfigFile("./mindrtParallel/mindrt_parallel_model.config"); + + auto status = converter.Convert(); + ASSERT_EQ(status, kSuccess); size_t size = 0; char *graph_buf = lite::ReadFile("./mindrtParallel/mindrt_parallel_model.ms", &size); diff --git a/mindspore/lite/test/ut/tools/optimizer/fusion/activation_fusion_test.cc b/mindspore/lite/test/ut/tools/optimizer/fusion/activation_fusion_test.cc index 3db1d5bcf66..00249b44eaf 100644 --- a/mindspore/lite/test/ut/tools/optimizer/fusion/activation_fusion_test.cc +++ b/mindspore/lite/test/ut/tools/optimizer/fusion/activation_fusion_test.cc @@ -112,7 +112,7 @@ TEST_F(ActivationFusionTest, TestHardTanhReluNode) { auto meta_graph = BuildGraph(schema::ActivationType_HARD_TANH, schema::ActivationType_RELU); auto func_graph = lite::AnfImporterFromMetaGraphT::Fb2Anf(meta_graph.get()); auto anf_transform = new lite::AnfTransform(); - auto new_graph = anf_transform->Transform(func_graph); + auto new_graph = anf_transform->Transform(func_graph, nullptr); ASSERT_NE(nullptr, new_graph); auto new_meta_graph = lite::Export(new_graph); ASSERT_EQ(new_meta_graph->nodes.size(), 1); @@ -126,7 +126,7 @@ TEST_F(ActivationFusionTest, TestRelu6HardTanhNode) { auto meta_graph = BuildGraph(schema::ActivationType_RELU6, schema::ActivationType_HARD_TANH); auto func_graph = lite::AnfImporterFromMetaGraphT::Fb2Anf(meta_graph.get()); auto anf_transform = new lite::AnfTransform(); - auto new_graph = anf_transform->Transform(func_graph); + auto new_graph = anf_transform->Transform(func_graph, nullptr); ASSERT_NE(nullptr, new_graph); auto new_meta_graph = lite::Export(new_graph); ASSERT_EQ(new_meta_graph->nodes.size(), 1); @@ -140,7 +140,7 @@ TEST_F(ActivationFusionTest, TestBadCase_ReluSigmoid) { auto meta_graph = BuildGraph(schema::ActivationType_RELU, schema::ActivationType_SIGMOID); auto func_graph = lite::AnfImporterFromMetaGraphT::Fb2Anf(meta_graph.get()); auto anf_transform = new lite::AnfTransform(); - auto new_graph = anf_transform->Transform(func_graph); + auto new_graph = anf_transform->Transform(func_graph, nullptr); ASSERT_NE(nullptr, new_graph); auto new_meta_graph = lite::Export(new_graph); ASSERT_EQ(new_meta_graph->nodes.size(), 2); diff --git a/mindspore/lite/test/ut/tools/optimizer/fusion/add_concat_act_fusion_test.cc b/mindspore/lite/test/ut/tools/optimizer/fusion/add_concat_act_fusion_test.cc index 8f88cbe502a..4e2b390308b 100644 --- a/mindspore/lite/test/ut/tools/optimizer/fusion/add_concat_act_fusion_test.cc +++ b/mindspore/lite/test/ut/tools/optimizer/fusion/add_concat_act_fusion_test.cc @@ -157,7 +157,7 @@ TEST_F(AddConcatActivationFusionTest, TestAddConcatReluNode) { auto meta_graph = BuildGraph(schema::ActivationType_RELU6); auto func_graph = lite::AnfImporterFromMetaGraphT::Fb2Anf(meta_graph.get()); auto anf_transform = new lite::AnfTransform(); - auto new_graph = anf_transform->Transform(func_graph); + auto new_graph = anf_transform->Transform(func_graph, nullptr); ASSERT_NE(nullptr, new_graph); auto new_meta_graph = lite::Export(new_graph); ASSERT_EQ(new_meta_graph->nodes.size(), kGraphNodeSize); diff --git a/mindspore/lite/test/ut/tools/optimizer/fusion/conv_activation_fusion_test.cc b/mindspore/lite/test/ut/tools/optimizer/fusion/conv_activation_fusion_test.cc index 73502d8f3b4..c69f9a3d9f8 100644 --- a/mindspore/lite/test/ut/tools/optimizer/fusion/conv_activation_fusion_test.cc +++ b/mindspore/lite/test/ut/tools/optimizer/fusion/conv_activation_fusion_test.cc @@ -136,7 +136,7 @@ TEST_F(ConvActivationFusionTest, TestConvReluNode) { auto meta_graph = BuildGraph(schema::PrimitiveType_Conv2DFusion, schema::ActivationType_RELU); auto func_graph = lite::AnfImporterFromMetaGraphT::Fb2Anf(meta_graph.get()); auto anf_transform = new lite::AnfTransform(); - auto new_graph = anf_transform->Transform(func_graph); + auto new_graph = anf_transform->Transform(func_graph, nullptr); ASSERT_NE(nullptr, new_graph); auto new_meta_graph = lite::Export(new_graph); ASSERT_EQ(new_meta_graph->nodes.size(), 1); @@ -149,7 +149,7 @@ TEST_F(ConvActivationFusionTest, TestConvRelu6Node) { auto meta_graph = BuildGraph(schema::PrimitiveType_Conv2DFusion, schema::ActivationType_RELU6); auto func_graph = lite::AnfImporterFromMetaGraphT::Fb2Anf(meta_graph.get()); auto anf_transform = new lite::AnfTransform(); - auto new_graph = anf_transform->Transform(func_graph); + auto new_graph = anf_transform->Transform(func_graph, nullptr); ASSERT_NE(nullptr, new_graph); auto new_meta_graph = lite::Export(new_graph); ASSERT_EQ(new_meta_graph->nodes.size(), 1); @@ -162,7 +162,7 @@ TEST_F(ConvActivationFusionTest, TestBadCase_ConvRelu) { auto meta_graph = BuildGraph(schema::PrimitiveType_Conv2DFusion, schema::ActivationType_LEAKY_RELU); auto func_graph = lite::AnfImporterFromMetaGraphT::Fb2Anf(meta_graph.get()); auto anf_transform = new lite::AnfTransform(); - auto new_graph = anf_transform->Transform(func_graph); + auto new_graph = anf_transform->Transform(func_graph, nullptr); ASSERT_NE(nullptr, new_graph); auto new_meta_graph = lite::Export(new_graph); ASSERT_EQ(new_meta_graph->nodes.size(), 2); diff --git a/mindspore/lite/test/ut/tools/optimizer/fusion/conv_biasadd_fusion_test.cc b/mindspore/lite/test/ut/tools/optimizer/fusion/conv_biasadd_fusion_test.cc index 9c47b44d68d..3d1984c3ad7 100644 --- a/mindspore/lite/test/ut/tools/optimizer/fusion/conv_biasadd_fusion_test.cc +++ b/mindspore/lite/test/ut/tools/optimizer/fusion/conv_biasadd_fusion_test.cc @@ -145,7 +145,7 @@ TEST_F(ConvBiasAddFusionTest, TestConvAddNode) { auto meta_graph = BuildGraph(schema::PrimitiveType_Conv2DFusion, schema::PrimitiveType_BiasAdd); auto func_graph = lite::AnfImporterFromMetaGraphT::Fb2Anf(meta_graph.get()); auto anf_transform = new lite::AnfTransform(); - auto new_graph = anf_transform->Transform(func_graph); + auto new_graph = anf_transform->Transform(func_graph, nullptr); ASSERT_NE(nullptr, new_graph); auto new_meta_graph = lite::Export(new_graph); ASSERT_EQ(new_meta_graph->nodes.size(), 1); @@ -156,7 +156,7 @@ TEST_F(ConvBiasAddFusionTest, TestDeptiwiseConvAddNode) { auto meta_graph = BuildGraph(schema::PrimitiveType_Conv2DFusion, schema::PrimitiveType_AddFusion); auto func_graph = lite::AnfImporterFromMetaGraphT::Fb2Anf(meta_graph.get()); auto anf_transform = new lite::AnfTransform(); - auto new_graph = anf_transform->Transform(func_graph); + auto new_graph = anf_transform->Transform(func_graph, nullptr); ASSERT_NE(nullptr, new_graph); auto new_meta_graph = lite::Export(new_graph); ASSERT_EQ(new_meta_graph->nodes.size(), 1); @@ -166,7 +166,7 @@ TEST_F(ConvBiasAddFusionTest, TestBadCase_ConvAdd) { auto meta_graph = BuildGraph(schema::PrimitiveType_Conv2DFusion, schema::PrimitiveType_MatMulFusion); auto func_graph = lite::AnfImporterFromMetaGraphT::Fb2Anf(meta_graph.get()); auto anf_transform = new lite::AnfTransform(); - auto new_graph = anf_transform->Transform(func_graph); + auto new_graph = anf_transform->Transform(func_graph, nullptr); ASSERT_NE(nullptr, new_graph); auto new_meta_graph = lite::Export(new_graph); ASSERT_EQ(new_meta_graph->nodes.size(), 2); diff --git a/mindspore/lite/test/ut/tools/optimizer/fusion/conv_bn_fusion_test.cc b/mindspore/lite/test/ut/tools/optimizer/fusion/conv_bn_fusion_test.cc index 5cf8b279ff0..3edcdeb30fa 100644 --- a/mindspore/lite/test/ut/tools/optimizer/fusion/conv_bn_fusion_test.cc +++ b/mindspore/lite/test/ut/tools/optimizer/fusion/conv_bn_fusion_test.cc @@ -262,7 +262,7 @@ TEST_F(ConvBNFusionTest, TestConvAddNode) { auto meta_graph = BuildCaffeGraph(schema::PrimitiveType_Conv2DFusion); auto func_graph = lite::AnfImporterFromMetaGraphT::Fb2Anf(meta_graph.get()); auto anf_transform = new lite::AnfTransform(); - auto new_graph = anf_transform->Transform(func_graph); + auto new_graph = anf_transform->Transform(func_graph, nullptr); ASSERT_NE(nullptr, new_graph); auto new_meta_graph = lite::Export(new_graph); ASSERT_EQ(new_meta_graph->nodes.size(), 1); @@ -272,7 +272,7 @@ TEST_F(ConvBNFusionTest, TestDeptiwiseConvAddNode) { auto meta_graph = BuildTFGraph(schema::PrimitiveType_Conv2DFusion); auto func_graph = lite::AnfImporterFromMetaGraphT::Fb2Anf(meta_graph.get()); auto anf_transform = new lite::AnfTransform(); - auto new_graph = anf_transform->Transform(func_graph); + auto new_graph = anf_transform->Transform(func_graph, nullptr); ASSERT_NE(nullptr, new_graph); auto new_meta_graph = lite::Export(new_graph); ASSERT_EQ(new_meta_graph->nodes.size(), 1); diff --git a/mindspore/lite/test/ut/tools/optimizer/fusion/conv_scale_fusion_test.cc b/mindspore/lite/test/ut/tools/optimizer/fusion/conv_scale_fusion_test.cc index 9b9a250de46..2593bd06423 100644 --- a/mindspore/lite/test/ut/tools/optimizer/fusion/conv_scale_fusion_test.cc +++ b/mindspore/lite/test/ut/tools/optimizer/fusion/conv_scale_fusion_test.cc @@ -187,7 +187,7 @@ TEST_F(ConvScaleFusionTest, TestConvScaleNode) { auto meta_graph = BuildGraph(schema::PrimitiveType_Conv2DFusion, true); auto func_graph = lite::AnfImporterFromMetaGraphT::Fb2Anf(meta_graph.get()); auto anf_transform = new lite::AnfTransform(); - auto new_graph = anf_transform->Transform(func_graph); + auto new_graph = anf_transform->Transform(func_graph, nullptr); ASSERT_NE(nullptr, new_graph); auto new_meta_graph = lite::Export(new_graph); ASSERT_EQ(new_meta_graph->nodes.size(), 1); @@ -198,7 +198,7 @@ TEST_F(ConvScaleFusionTest, TestDeptiwiseConvScaleNode) { auto meta_graph = BuildGraph(schema::PrimitiveType_Conv2DFusion, false); auto func_graph = lite::AnfImporterFromMetaGraphT::Fb2Anf(meta_graph.get()); auto anf_transform = new lite::AnfTransform(); - auto new_graph = anf_transform->Transform(func_graph); + auto new_graph = anf_transform->Transform(func_graph, nullptr); ASSERT_NE(nullptr, new_graph); auto new_meta_graph = lite::Export(new_graph); ASSERT_EQ(new_meta_graph->nodes.size(), 1); diff --git a/mindspore/lite/test/ut/tools/optimizer/fusion/fusion_inout_test/matmul_act_fusion_inout_test.cc b/mindspore/lite/test/ut/tools/optimizer/fusion/fusion_inout_test/matmul_act_fusion_inout_test.cc index d5f36521822..498aeb6ed6d 100644 --- a/mindspore/lite/test/ut/tools/optimizer/fusion/fusion_inout_test/matmul_act_fusion_inout_test.cc +++ b/mindspore/lite/test/ut/tools/optimizer/fusion/fusion_inout_test/matmul_act_fusion_inout_test.cc @@ -21,7 +21,6 @@ #include "plugin/device/cpu/kernel/nnacl/op_base.h" #include "ops/fusion/mat_mul_fusion.h" #include "ops/fusion/activation.h" -#include "tools/converter/converter_flags.h" namespace mindspore { class MatMulActivationFusionInoutTest : public FusionInoutTest { @@ -29,10 +28,7 @@ class MatMulActivationFusionInoutTest : public FusionInoutTest { MatMulActivationFusionInoutTest() = default; protected: - void InitPass() override { - converter::Flags ctx; - this->pass_ = std::make_shared(ctx); - } + void InitPass() override { this->pass_ = std::make_shared(nullptr); } void InitGraph() override { this->graph_ = std::make_shared(); diff --git a/mindspore/lite/test/ut/tools/optimizer/fusion/matmul_mul_fusion_test.cc b/mindspore/lite/test/ut/tools/optimizer/fusion/matmul_mul_fusion_test.cc index 825c3147813..8737c7d0daa 100644 --- a/mindspore/lite/test/ut/tools/optimizer/fusion/matmul_mul_fusion_test.cc +++ b/mindspore/lite/test/ut/tools/optimizer/fusion/matmul_mul_fusion_test.cc @@ -136,7 +136,7 @@ TEST_F(MatMulAddFusionTest, TestMatMulMulNode) { auto meta_graph = BuildGraph(); auto func_graph = lite::AnfImporterFromMetaGraphT::Fb2Anf(meta_graph.get()); auto anf_transform = new lite::AnfTransform(); - auto new_graph = anf_transform->Transform(func_graph); + auto new_graph = anf_transform->Transform(func_graph, nullptr); ASSERT_NE(nullptr, new_graph); auto new_meta_graph = lite::Export(new_graph); ASSERT_EQ(new_meta_graph->nodes.size(), 1); diff --git a/mindspore/lite/test/ut/tools/optimizer/fusion/trans_matmul_fusion_test.cc b/mindspore/lite/test/ut/tools/optimizer/fusion/trans_matmul_fusion_test.cc index 65c85a0c946..442037cc687 100644 --- a/mindspore/lite/test/ut/tools/optimizer/fusion/trans_matmul_fusion_test.cc +++ b/mindspore/lite/test/ut/tools/optimizer/fusion/trans_matmul_fusion_test.cc @@ -158,7 +158,7 @@ TEST_F(TransMatMulFusionTest, TestTransMatMulNode1) { auto meta_graph = BuildGraph(trans_param_a, trans_param_b, output_dims); auto func_graph = lite::AnfImporterFromMetaGraphT::Fb2Anf(meta_graph.get()); auto anf_transform = new lite::AnfTransform(); - auto new_graph = anf_transform->Transform(func_graph); + auto new_graph = anf_transform->Transform(func_graph, nullptr); ASSERT_NE(nullptr, new_graph); auto new_meta_graph = lite::Export(new_graph); ASSERT_EQ(new_meta_graph->nodes.size(), 1); @@ -180,7 +180,7 @@ TEST_F(TransMatMulFusionTest, TestTransMatMulNode2) { auto meta_graph = BuildGraph(trans_param_a, trans_param_b, output_dims); auto func_graph = lite::AnfImporterFromMetaGraphT::Fb2Anf(meta_graph.get()); auto anf_transform = new lite::AnfTransform(); - auto new_graph = anf_transform->Transform(func_graph); + auto new_graph = anf_transform->Transform(func_graph, nullptr); ASSERT_NE(nullptr, new_graph); auto new_meta_graph = lite::Export(new_graph); ASSERT_EQ(new_meta_graph->nodes.size(), 1); @@ -202,7 +202,7 @@ TEST_F(TransMatMulFusionTest, TestBadCase_TransMatMul) { auto meta_graph = BuildGraph(trans_param_a, trans_param_b, output_dims); auto func_graph = lite::AnfImporterFromMetaGraphT::Fb2Anf(meta_graph.get()); auto anf_transform = new lite::AnfTransform(); - auto new_graph = anf_transform->Transform(func_graph); + auto new_graph = anf_transform->Transform(func_graph, nullptr); ASSERT_NE(nullptr, new_graph); auto new_meta_graph = lite::Export(new_graph); ASSERT_EQ(new_meta_graph->nodes.size(), 3); diff --git a/mindspore/lite/tools/converter/CMakeLists.txt b/mindspore/lite/tools/converter/CMakeLists.txt index 339b2ab5216..37d6fb79069 100644 --- a/mindspore/lite/tools/converter/CMakeLists.txt +++ b/mindspore/lite/tools/converter/CMakeLists.txt @@ -4,6 +4,7 @@ set(USE_GLOG on) if(MSLITE_ENABLE_MODEL_ENCRYPTION) add_compile_definitions(ENABLE_OPENSSL) endif() + set(SRC_DIR ${CMAKE_CURRENT_SOURCE_DIR}/../../src) set(TOOLS_DIR ${CMAKE_CURRENT_SOURCE_DIR}/..) @@ -38,7 +39,6 @@ include_directories(${TOP_DIR}/mindspore/ccsrc/plugin/device/cpu/kernel) file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} ${CMAKE_CURRENT_SOURCE_DIR}/ops/*.cc ${CMAKE_CURRENT_SOURCE_DIR}/converter.cc - ${CMAKE_CURRENT_SOURCE_DIR}/converter_flags.cc ${CMAKE_CURRENT_SOURCE_DIR}/anf_transform.cc ${CMAKE_CURRENT_SOURCE_DIR}/graphdef_transform.cc ${CMAKE_CURRENT_SOURCE_DIR}/optimizer.cc @@ -72,8 +72,6 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} ../optimizer/*.cc ) -list(REMOVE_ITEM CONVERTER_SRC cxx_api/converter.cc) - if((NOT WIN32) AND MSLITE_ENABLE_DPICO_ATC_ADAPTER) add_subdirectory(adapter/dpico) endif() @@ -297,7 +295,8 @@ set_property(SOURCE ${LITE_SRC} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindsp add_library(converter_runtime_mid OBJECT ${LITE_SRC}) add_dependencies(converter_runtime_mid fbs_src fbs_inner_src) -target_compile_options(converter_runtime_mid PRIVATE "-Wno-stringop-overflow") +add_library(mindspore_converter SHARED $) +target_compile_options(mindspore_converter PRIVATE "-Wno-stringop-overflow") add_library(converter_src_mid OBJECT ${CONVERTER_SRC}) add_dependencies(converter_src_mid fbs_src fbs_inner_src) @@ -307,32 +306,27 @@ add_dependencies(ccsrc_src_mid fbs_src fbs_inner_src) target_compile_definitions(ccsrc_src_mid PRIVATE BACKEND_DLL) target_compile_definitions(converter_src_mid PRIVATE BACKEND_DLL) -add_executable(converter_lite - main.cc - ) -add_dependencies(converter_lite fbs_src fbs_inner_src) if(NOT ENABLE_CLOUD_AND_LITE) - add_dependencies(converter_lite nnacl_mid) + add_dependencies(mindspore_converter nnacl_mid) endif() if((NOT WIN32) AND MSLITE_ENABLE_DPICO_ATC_ADAPTER) - add_dependencies(converter_lite dpico_atc_adapter) + add_dependencies(mindspore_converter dpico_atc_adapter) endif() if(MSLITE_GPU_BACKEND STREQUAL opencl) include_directories(${SRC_DIR}/runtime/kernel/opencl) - target_link_libraries(converter_lite PRIVATE opencl_kernel_mid) + target_link_libraries(mindspore_converter opencl_kernel_mid) endif() if(MSLITE_ENABLE_FP16) - target_link_libraries(converter_lite PRIVATE + target_link_libraries(mindspore_converter nnacl_fp16_mid ) endif() -target_link_libraries(converter_lite PRIVATE +target_link_libraries(mindspore_converter ccsrc_src_mid converter_src_mid - converter_runtime_mid cpu_ops_mid nnacl_mid cpu_kernel_mid @@ -359,34 +353,36 @@ target_link_libraries(converter_lite PRIVATE ) if(SUPPORT_TRAIN) - target_link_libraries(converter_lite PRIVATE train_cpu_kernel_mid) + target_link_libraries(mindspore_converter train_cpu_kernel_mid) endif() if(ENABLE_CONVERT_PYTORCH_MODEL) - target_link_libraries(converter_lite PRIVATE pytorch_parser_mid) + target_link_libraries(mindspore_converter pytorch_parser_mid) endif() if(NOT ENABLE_CLOUD_AND_LITE) - target_link_libraries(converter_lite PRIVATE + target_link_libraries(mindspore_converter ccsrc_debug_common_mid_ mindir_proto_mid _mindspore_transform_express_ir_obj) endif() if(MSLITE_ENABLE_ACL) - target_link_libraries(converter_lite PRIVATE + target_link_libraries(mindspore_converter lite_acl_mid ascend_kernel_mid) endif() if(NOT MSVC) - target_link_libraries(converter_lite PRIVATE pthread) + target_link_libraries(mindspore_converter pthread) endif() if(NOT WIN32) - target_link_libraries(converter_lite PRIVATE dl) + target_link_libraries(mindspore_converter dl) endif() if(ENABLE_MODEL_OBF) - target_link_libraries(converter_lite PRIVATE + target_link_libraries(mindspore_converter ${OBF_LIB_DIR}/libmsdeobfuscator-lite.so) endif() + +add_subdirectory(converter_lite) diff --git a/mindspore/lite/tools/converter/adapter/acl/acl_pass.cc b/mindspore/lite/tools/converter/adapter/acl/acl_pass.cc index 0dc9c86e857..3fb2ebe5c76 100644 --- a/mindspore/lite/tools/converter/adapter/acl/acl_pass.cc +++ b/mindspore/lite/tools/converter/adapter/acl/acl_pass.cc @@ -21,9 +21,9 @@ namespace mindspore { namespace opt { -AclPass::AclPass(const converter::Flags &config) : Pass("ACL") { +AclPass::AclPass(const std::shared_ptr ¶m) : Pass("ACL") { #ifdef ENABLE_LITE_ACL - impl_ = std::make_shared(config); + impl_ = std::make_shared(param); #endif } diff --git a/mindspore/lite/tools/converter/adapter/acl/acl_pass.h b/mindspore/lite/tools/converter/adapter/acl/acl_pass.h index 35ecd091d9c..12bb79dd7bd 100644 --- a/mindspore/lite/tools/converter/adapter/acl/acl_pass.h +++ b/mindspore/lite/tools/converter/adapter/acl/acl_pass.h @@ -20,7 +20,7 @@ #define USE_DEPRECATED_API #include #include "backend/common/optimizer/pass.h" -#include "tools/converter/converter_flags.h" +#include "tools/converter/cxx_api/converter_para.h" namespace mindspore { namespace opt { @@ -29,7 +29,7 @@ using AclPassImplPtr = std::shared_ptr; class AclPass : public Pass { public: - explicit AclPass(const converter::Flags &config); + explicit AclPass(const std::shared_ptr ¶m); ~AclPass() override = default; bool Run(const FuncGraphPtr &func_graph) override; diff --git a/mindspore/lite/tools/converter/adapter/acl/src/acl_pass_impl.cc b/mindspore/lite/tools/converter/adapter/acl/src/acl_pass_impl.cc index 0847bc1e858..7b507941e3b 100644 --- a/mindspore/lite/tools/converter/adapter/acl/src/acl_pass_impl.cc +++ b/mindspore/lite/tools/converter/adapter/acl/src/acl_pass_impl.cc @@ -101,9 +101,9 @@ STATUS PreProcForOnnx(const FuncGraphPtr &func_graph) { } } // namespace -AclPassImpl::AclPassImpl(const converter::Flags &config) - : fmk_type_(config.fmk), - user_options_cfg_(std::move(config.aclModelOptionCfgParam)), +AclPassImpl::AclPassImpl(const std::shared_ptr ¶m) + : fmk_type_(param->fmk_type), + user_options_cfg_(std::move(param->aclModelOptionCfgParam)), om_parameter_(nullptr), custom_node_(nullptr) {} diff --git a/mindspore/lite/tools/converter/adapter/acl/src/acl_pass_impl.h b/mindspore/lite/tools/converter/adapter/acl/src/acl_pass_impl.h index b4d84de25c0..7b3f98b9e76 100644 --- a/mindspore/lite/tools/converter/adapter/acl/src/acl_pass_impl.h +++ b/mindspore/lite/tools/converter/adapter/acl/src/acl_pass_impl.h @@ -27,8 +27,8 @@ #include "include/registry/converter_context.h" #include "cxx_api/model/acl/acl_model_options.h" #include "tools/converter/adapter/acl/common/acl_types.h" -#include "tools/converter/converter_flags.h" #include "ops/custom.h" +#include "tools/converter/cxx_api/converter_para.h" namespace mindspore { namespace opt { @@ -37,7 +37,7 @@ using mindspore::lite::STATUS; class AclPassImpl { public: - explicit AclPassImpl(const converter::Flags &config); + explicit AclPassImpl(const std::shared_ptr ¶m); ~AclPassImpl() = default; bool Run(const FuncGraphPtr &func_graph); diff --git a/mindspore/lite/tools/converter/anf_transform.cc b/mindspore/lite/tools/converter/anf_transform.cc index 25b000c229c..abe46d30268 100644 --- a/mindspore/lite/tools/converter/anf_transform.cc +++ b/mindspore/lite/tools/converter/anf_transform.cc @@ -171,13 +171,12 @@ STATUS AnfTransform::MarkTrainOp(const FuncGraphPtr &func_graph) { return RET_OK; } -int AnfTransform::RunFusionPass(const FuncGraphPtr &old_graph, const converter::Flags *config) { +int AnfTransform::RunFusionPass(const FuncGraphPtr &old_graph, const std::shared_ptr ¶m) { auto status = MarkTrainOp(old_graph); if (status != RET_OK) { MS_LOG(ERROR) << "MarkTrainOp failed."; return RET_ERROR; } - CHECK_NULL_RETURN(config); auto optimizer = std::make_shared(); CHECK_NULL_RETURN(optimizer); auto fusion_pm = std::make_shared("anf fusion pass manager", false); @@ -190,8 +189,8 @@ int AnfTransform::RunFusionPass(const FuncGraphPtr &old_graph, const converter:: fusion_pm->AddPass(std::make_shared()); fusion_pm->AddPass(std::make_shared()); fusion_pm->AddPass(std::make_shared()); - fusion_pm->AddPass(std::make_shared(config->fmk)); - fusion_pm->AddPass(std::make_shared(config->fmk)); + fusion_pm->AddPass(std::make_shared(param->fmk_type)); + fusion_pm->AddPass(std::make_shared(param->fmk_type)); fusion_pm->AddPass(std::make_shared()); fusion_pm->AddPass(std::make_shared()); fusion_pm->AddPass(std::make_shared()); @@ -200,7 +199,7 @@ int AnfTransform::RunFusionPass(const FuncGraphPtr &old_graph, const converter:: fusion_pm->AddPass(std::make_shared()); fusion_pm->AddPass(std::make_shared()); fusion_pm->AddPass(std::make_shared()); - if (config->fullQuantParam.target_device != quant::NVGPU) { + if (param->fullQuantParam.target_device != quant::NVGPU) { fusion_pm->AddPass(std::make_shared()); } fusion_pm->AddPass(std::make_shared()); @@ -212,7 +211,7 @@ int AnfTransform::RunFusionPass(const FuncGraphPtr &old_graph, const converter:: fusion_pm->AddPass(std::make_shared()); fusion_pm->AddPass(std::make_shared()); fusion_pm->AddPass(std::make_shared()); - fusion_pm->AddPass(std::make_shared(config->fmk, config->trainModel)); + fusion_pm->AddPass(std::make_shared(param->fmk_type, param->train_model)); fusion_pm->AddPass(std::make_shared()); fusion_pm->AddPass(std::make_shared()); fusion_pm->AddPass(std::make_shared()); @@ -226,7 +225,7 @@ int AnfTransform::RunFusionPass(const FuncGraphPtr &old_graph, const converter:: fusion_pm->AddPass(std::make_shared()); fusion_pm->AddPass(std::make_shared()); fusion_pm->AddPass(std::make_shared()); - fusion_pm->AddPass(std::make_shared(*config)); + fusion_pm->AddPass(std::make_shared(param)); optimizer->AddPassManager(fusion_pm); if (optimizer->Optimize(old_graph) == nullptr) { MS_LOG(ERROR) << "run op fusion failed."; @@ -235,12 +234,12 @@ int AnfTransform::RunFusionPass(const FuncGraphPtr &old_graph, const converter:: return RET_OK; } -int AnfTransform::RunParallelPass(const FuncGraphPtr &old_graph, const converter::Flags *config) { +int AnfTransform::RunParallelPass(const FuncGraphPtr &old_graph, const std::shared_ptr ¶m) { MS_LOG(DEBUG) << "Run ParallelPass start"; - if (config->trainModel || config->parallel_split_config_.parallel_split_type_ == converter::SplitNo) { + if (param->train_model || param->parallel_split_config.parallel_split_type_ == SplitNo) { return RET_OK; } - if (config->parallel_split_config_.parallel_split_type_ == converter::SplitByUserRatio) { + if (param->parallel_split_config.parallel_split_type_ == SplitByUserRatio) { auto optimizer = std::make_shared(); CHECK_NULL_RETURN(optimizer); auto graph_inputs = old_graph->get_inputs(); @@ -261,9 +260,8 @@ int AnfTransform::RunParallelPass(const FuncGraphPtr &old_graph, const converter } } // 1. deal with split strategy - std::unordered_map split_strategys = - opt::ParserSplitStrategy(config->parallel_split_config_.parallel_compute_rates_, - config->parallel_split_config_.parallel_devices_, split_mode); + std::unordered_map split_strategys = opt::ParserSplitStrategy( + param->parallel_split_config.parallel_compute_rates_, param->parallel_split_config.parallel_devices_, split_mode); if (split_strategys.empty()) { MS_LOG(WARNING) << "No valid split_strategy. Run convert without split"; return RET_OK; @@ -279,7 +277,7 @@ int AnfTransform::RunParallelPass(const FuncGraphPtr &old_graph, const converter // we do not deal with single conv node for (int match_number = max_match_number; match_number > opt::kDefaultBatch; --match_number) { // 3. multi_conv parallel pass - parallel_pm->AddPass(std::make_shared(split_strategys, config->fmk, match_number)); + parallel_pm->AddPass(std::make_shared(split_strategys, param->fmk_type, match_number)); parallel_pm->AddPass(std::make_shared()); parallel_pm->AddPass(std::make_shared()); } @@ -293,18 +291,18 @@ int AnfTransform::RunParallelPass(const FuncGraphPtr &old_graph, const converter return RET_OK; } -int AnfTransform::RunGraphPass(const FuncGraphPtr &old_graph, const converter::Flags *config) { +int AnfTransform::RunGraphPass(const FuncGraphPtr &old_graph, const std::shared_ptr ¶m) { auto optimizer = std::make_shared(); CHECK_NULL_RETURN(optimizer); auto graph_pm = std::make_shared("anf graph pass manager", true); CHECK_NULL_RETURN(graph_pm); - if (config->fmk == converter::kFmkTypeTflite || config->fmk == converter::kFmkTypeTf || - config->fmk == converter::kFmkTypeOnnx) { + if (param->fmk_type == converter::kFmkTypeTflite || param->fmk_type == converter::kFmkTypeTf || + param->fmk_type == converter::kFmkTypeOnnx) { graph_pm->AddPass(std::make_shared()); } auto slice_prepose_pass = std::make_shared(); CHECK_NULL_RETURN(slice_prepose_pass); - slice_prepose_pass->SetFmkType(config->fmk); + slice_prepose_pass->SetFmkType(param->fmk_type); graph_pm->AddPass(slice_prepose_pass); optimizer->AddPassManager(graph_pm); if (optimizer->Optimize(old_graph) == nullptr) { @@ -314,8 +312,8 @@ int AnfTransform::RunGraphPass(const FuncGraphPtr &old_graph, const converter::F return RET_OK; } -int AnfTransform::RunConvertPass(const FuncGraphPtr &old_graph, const converter::Flags *config) { - auto acl_pass = std::make_shared(*config); +int AnfTransform::RunConvertPass(const FuncGraphPtr &old_graph, const std::shared_ptr ¶m) { + auto acl_pass = std::make_shared(param); CHECK_NULL_RETURN(acl_pass); if (!acl_pass->Run(old_graph)) { MS_LOG(ERROR) << "Acl pass failed."; @@ -326,8 +324,8 @@ int AnfTransform::RunConvertPass(const FuncGraphPtr &old_graph, const converter: CHECK_NULL_RETURN(optimizer); auto convert_pm = std::make_shared("anf graph convert pass manager", true); CHECK_NULL_RETURN(convert_pm); - convert_pm->AddPass(std::make_shared(config->trainModel)); - convert_pm->AddPass(std::make_shared(config->fmk, config->trainModel)); + convert_pm->AddPass(std::make_shared(param->train_model)); + convert_pm->AddPass(std::make_shared(param->fmk_type, param->train_model)); convert_pm->AddPass(std::make_shared()); optimizer->AddPassManager(convert_pm); if (optimizer->Optimize(old_graph) == nullptr) { @@ -337,14 +335,14 @@ int AnfTransform::RunConvertPass(const FuncGraphPtr &old_graph, const converter: return RET_OK; } -int AnfTransform::RunConstFoldPass(const FuncGraphPtr &old_graph, const converter::Flags *config) { +int AnfTransform::RunConstFoldPass(const FuncGraphPtr &old_graph, const std::shared_ptr ¶m) { auto optimizer = std::make_shared(); auto const_fold_pm = std::make_shared("const fold fusion pass manager", false); CHECK_NULL_RETURN(optimizer); CHECK_NULL_RETURN(const_fold_pm); - const_fold_pm->AddPass(std::make_shared(config->fmk, config->trainModel)); - if (!config->trainModel) { - const_fold_pm->AddPass(std::make_shared(config->fmk, config->trainModel)); + const_fold_pm->AddPass(std::make_shared(param->fmk_type, param->train_model)); + if (!param->train_model) { + const_fold_pm->AddPass(std::make_shared(param->fmk_type, param->train_model)); } const_fold_pm->AddPass(std::make_shared()); const_fold_pm->AddPass(std::make_shared()); @@ -356,8 +354,8 @@ int AnfTransform::RunConstFoldPass(const FuncGraphPtr &old_graph, const converte return RET_OK; } -int AnfTransform::DoQuantize(const FuncGraphPtr &old_graph, converter::Flags *config) { - quant::QuantizationOptimizer optimizer(config); +int AnfTransform::DoQuantize(const FuncGraphPtr &old_graph, const std::shared_ptr ¶m) { + quant::QuantizationOptimizer optimizer(param); auto ret = optimizer.Run(old_graph); if (ret != RET_OK) { MS_LOG(ERROR) << "Post training quantization failed."; @@ -366,8 +364,8 @@ int AnfTransform::DoQuantize(const FuncGraphPtr &old_graph, converter::Flags *co return RET_OK; } -bool RunEliminateRedundantPass(const FuncGraphPtr &old_graph, const converter::Flags *config) { - auto eliminate_cast_pass = std::make_shared(config->fmk, config->trainModel); +bool RunEliminateRedundantPass(const FuncGraphPtr &old_graph, const std::shared_ptr ¶m) { + auto eliminate_cast_pass = std::make_shared(param->fmk_type, param->train_model); MS_CHECK_TRUE_RET(eliminate_cast_pass != nullptr, false); if (!eliminate_cast_pass->Run(old_graph)) { MS_LOG(ERROR) << "Run cast elimination pass failed."; @@ -390,11 +388,12 @@ bool RunEliminateRedundantPass(const FuncGraphPtr &old_graph, const converter::F return true; } -FuncGraphPtr AnfTransform::TransformFuncGraph(const FuncGraphPtr &old_graph, const converter::Flags *config) { +FuncGraphPtr AnfTransform::TransformFuncGraph(const FuncGraphPtr &old_graph, + const std::shared_ptr ¶m) { MS_ASSERT(old_graph != nullptr); - MS_ASSERT(config != nullptr); + MS_ASSERT(param != nullptr); - auto status = RunConvertPass(old_graph, config); + auto status = RunConvertPass(old_graph, param); if (status != RET_OK) { MS_LOG(ERROR) << "Run convert pass failed."; return nullptr; @@ -405,7 +404,7 @@ FuncGraphPtr AnfTransform::TransformFuncGraph(const FuncGraphPtr &old_graph, con return nullptr; } - status = RunConstFoldPass(old_graph, config); + status = RunConstFoldPass(old_graph, param); if (status != RET_OK) { MS_LOG(ERROR) << "Run const fold pass failed."; return nullptr; @@ -420,13 +419,13 @@ FuncGraphPtr AnfTransform::TransformFuncGraph(const FuncGraphPtr &old_graph, con } } - if (!RunEliminateRedundantPass(old_graph, config)) { + if (!RunEliminateRedundantPass(old_graph, param)) { MS_LOG(ERROR) << "Run elimination of redundant pass failed."; return nullptr; } - if (!config->disableFusion) { - status = RunFusionPass(old_graph, config); + if (!param->no_fusion) { + status = RunFusionPass(old_graph, param); if (status != RET_OK) { MS_LOG(ERROR) << "Run fusion pass failed."; return nullptr; @@ -451,23 +450,23 @@ FuncGraphPtr AnfTransform::TransformFuncGraph(const FuncGraphPtr &old_graph, con } } - status = RunGraphPass(old_graph, config); + status = RunGraphPass(old_graph, param); if (status != RET_OK) { MS_LOG(ERROR) << "Run convert pass failed."; return nullptr; } - status = RunParallelPass(old_graph, config); + status = RunParallelPass(old_graph, param); if (status != RET_OK) { MS_LOG(ERROR) << "Run convert pass failed."; return nullptr; } - if (!config->pluginsPath.empty() && config->commonQuantParam.quant_type != schema::QuantType_QUANT_NONE) { + if (!param->plugins_path.empty() && param->commonQuantParam.quant_type != schema::QuantType_QUANT_NONE) { MS_LOG(ERROR) << "Unsupported external extension with quantization."; return nullptr; } - status = DoQuantize(old_graph, const_cast(config)); + status = DoQuantize(old_graph, param); if (status != RET_OK) { MS_LOG(ERROR) << "Do Quantize failed."; return nullptr; @@ -476,17 +475,17 @@ FuncGraphPtr AnfTransform::TransformFuncGraph(const FuncGraphPtr &old_graph, con return old_graph; } -bool AnfTransform::StoreBuiltinPass(const converter::Flags *config) { - if (config == nullptr) { +bool AnfTransform::StoreBuiltinPass(const std::shared_ptr ¶m) { + if (param == nullptr) { MS_LOG(ERROR) << "config is nullptr"; return false; } - auto fmk = config->fmk; - auto is_train = config->trainModel; + auto fmk = param->fmk_type; + auto is_train = param->train_model; // pass_name, pass and boolean value to indicate whether can be called by external extension, std::vector> pass_infos = { - {"DumpGraph", std::make_shared(config), true}, - {"RemoveRedundantOpPass", std::make_shared(config->trainModel), false}, + {"DumpGraph", std::make_shared(param), true}, + {"RemoveRedundantOpPass", std::make_shared(param->train_model), false}, {"ToNCHWFormat", std::make_shared(fmk, is_train), true}, {"ToNHWCFormat", std::make_shared(fmk, is_train), true}, {"ConstFoldPass", std::make_shared(fmk, is_train), true}, @@ -494,25 +493,25 @@ bool AnfTransform::StoreBuiltinPass(const converter::Flags *config) { {"DeleteRedundantTranspose", std::make_shared(), false}, {"SpecialNodePostProcess", std::make_shared(), false}, {"DecreaseTransposeAlgo", std::make_shared(fmk, is_train), true}, - {"SpecifyGraphInputFormat", std::make_shared(config->graphInputFormat), false}}; + {"SpecifyGraphInputFormat", std::make_shared(param->input_format), false}}; for (const auto &pass_info : pass_infos) { MS_CHECK_TRUE_RET(std::get<1>(pass_info) != nullptr, false); PassStorage::StorePass(std::get<0>(pass_info), std::get<1>(pass_info), std::get(pass_info)); } - auto dump_graph_outer = std::make_shared(config); + auto dump_graph_outer = std::make_shared(param); MS_CHECK_TRUE_MSG(dump_graph_outer != nullptr, false, "dumpGraph object is a nullptr."); registry::PassRegistry("DumpGraph", dump_graph_outer); return true; } -FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &main_graph, const converter::Flags *config) { +FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &main_graph, const std::shared_ptr ¶m) { MS_CHECK_TRUE_MSG(main_graph != nullptr, nullptr, "Input func_graph is nullptr"); - MS_CHECK_TRUE_MSG(config != nullptr, nullptr, "Input converter config is nullptr"); - if (!StoreBuiltinPass(config)) { + MS_CHECK_TRUE_MSG(param != nullptr, nullptr, "Input converter param is nullptr"); + if (!StoreBuiltinPass(param)) { MS_LOG(ERROR) << "store pass failed."; return nullptr; } - auto new_graph = TransformFuncGraph(main_graph, config); + auto new_graph = TransformFuncGraph(main_graph, param); if (new_graph == nullptr) { MS_LOG(ERROR) << "optimizer failed."; ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_NULL_PTR); diff --git a/mindspore/lite/tools/converter/anf_transform.h b/mindspore/lite/tools/converter/anf_transform.h index 76f628e1ec8..24825b83ba1 100644 --- a/mindspore/lite/tools/converter/anf_transform.h +++ b/mindspore/lite/tools/converter/anf_transform.h @@ -23,7 +23,6 @@ #include "backend/common/optimizer/optimizer.h" #include "schema/inner/model_generated.h" #include "tools/common/meta_graph_serializer.h" -#include "tools/converter/converter_flags.h" #include "ir/anf.h" #include "tools/converter/quantizer/quantizer.h" #include "tools/converter/converter_context.h" @@ -34,24 +33,24 @@ class AnfTransform { public: AnfTransform(); virtual ~AnfTransform(); - FuncGraphPtr Transform(const FuncGraphPtr &old_graph, const converter::Flags *config = nullptr); + FuncGraphPtr Transform(const FuncGraphPtr &old_graph, const std::shared_ptr ¶m); private: - FuncGraphPtr TransformFuncGraph(const FuncGraphPtr &old_graph, const converter::Flags *config = nullptr); + FuncGraphPtr TransformFuncGraph(const FuncGraphPtr &old_graph, const std::shared_ptr ¶m); - static int RunFusionPass(const FuncGraphPtr &old_graph, const converter::Flags *config); + static int RunFusionPass(const FuncGraphPtr &old_graph, const std::shared_ptr ¶m); - static int RunGraphPass(const FuncGraphPtr &old_graph, const converter::Flags *config); + static int RunGraphPass(const FuncGraphPtr &old_graph, const std::shared_ptr ¶m); - static int RunConvertPass(const FuncGraphPtr &old_graph, const converter::Flags *config); + static int RunConvertPass(const FuncGraphPtr &old_graph, const std::shared_ptr ¶m); - static int RunConstFoldPass(const FuncGraphPtr &olde_graph, const converter::Flags *config); + static int RunConstFoldPass(const FuncGraphPtr &olde_graph, const std::shared_ptr ¶m); - static int RunParallelPass(const FuncGraphPtr &old_graph, const converter::Flags *config); + static int RunParallelPass(const FuncGraphPtr &old_graph, const std::shared_ptr ¶m); - static int DoQuantize(const FuncGraphPtr &old_graph, converter::Flags *config); + static int DoQuantize(const FuncGraphPtr &old_graph, const std::shared_ptr ¶m); - static bool StoreBuiltinPass(const converter::Flags *config); + static bool StoreBuiltinPass(const std::shared_ptr ¶m); static STATUS MarkTrainInputOp(const FuncGraphPtr &func_graph, const CNodePtr &cnode); diff --git a/mindspore/lite/tools/converter/converter.cc b/mindspore/lite/tools/converter/converter.cc index eabfef38149..90abdd78e05 100644 --- a/mindspore/lite/tools/converter/converter.cc +++ b/mindspore/lite/tools/converter/converter.cc @@ -19,7 +19,7 @@ #include #include #include -#include "tools/converter/converter_flags.h" +#include #include "src/common/log_adapter.h" #include "tools/common/meta_graph_serializer.h" #include "tools/lite_exporter/anf_exporter.h" @@ -42,43 +42,52 @@ #include "include/api/model.h" #include "tools/mindir_exporter/mindir_serializer.h" #include "src/common/primitive_t_utils.h" +#include "tools/converter/config_parser/acl_option_param_parser.h" +#include "tools/converter/config_parser/micro_param_parser.h" +#include "tools/converter/config_parser/preprocess_parser.h" +#include "tools/converter/config_parser/quant_param_parser.h" +#include "tools/common/string_util.h" +#include "src/common/file_utils.h" namespace mindspore { +extern "C" { +void common_log_init(); +} namespace lite { namespace { constexpr size_t kMaxNum1024 = 1024; -void InitConverterParameters(const converter::Flags &flag, converter::ConverterParameters *converter_parameters) { - MS_ASSERT(converter_parameters != nullptr); - converter_parameters->fmk = flag.fmk; - converter_parameters->model_file = flag.modelFile; - converter_parameters->weight_file = flag.weightFile; -} +constexpr size_t kPluginPathMaxNum = 10; +constexpr int kPathLengthUpperLimit = 1024; +constexpr size_t kEncMaxLen = 16; + FuncGraphPtr ConvertGraph(const api::FuncGraphPtr &func_graph) { auto impl = func_graph->impl(); return std::dynamic_pointer_cast(impl); } } // namespace -FuncGraphPtr Converter::BuildFuncGraph(const converter::Flags &flag) { +FuncGraphPtr ConverterImpl::BuildFuncGraph(const std::shared_ptr ¶m) { api::FuncGraphPtr func_graph_base = nullptr; - if (flag.fmk == converter::FmkType::kFmkTypeMs) { + if (param->fmk_type == converter::FmkType::kFmkTypeMs) { #ifdef SUPPORT_TRAIN kernel::PopulateTrainParameters(); #endif MindsporeImporter ms_import; - func_graph_base = api::MakeShared(ms_import.ImportMindIR(flag)); + func_graph_base = api::MakeShared(ms_import.ImportMindIR(param)); } else { - model_parser_ = registry::ModelParserRegistry::GetModelParser(flag.fmk); + model_parser_ = registry::ModelParserRegistry::GetModelParser(param->fmk_type); if (model_parser_ == nullptr) { - MS_LOG(ERROR) << "Unsupported to converter models with fmk: " << flag.fmkIn; + MS_LOG(ERROR) << "Unsupported to converter models with fmk: " << param->fmk_type; return nullptr; } converter::ConverterParameters converter_parameters; - InitConverterParameters(flag, &converter_parameters); + converter_parameters.fmk = param->fmk_type; + converter_parameters.model_file = param->model_file; + converter_parameters.weight_file = param->weight_file; func_graph_base = model_parser_->Parse(converter_parameters); } if (func_graph_base == nullptr) { - MS_LOG(ERROR) << "Get funcGraph failed for fmk: " << flag.fmkIn; + MS_LOG(ERROR) << "Get funcGraph failed for fmk: " << param->fmk_type; ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_NOT_SUPPORT); return nullptr; } @@ -94,9 +103,10 @@ FuncGraphPtr Converter::BuildFuncGraph(const converter::Flags &flag) { return func_graph; } -FuncGraphPtr Converter::BuildFuncGraph(const converter::Flags &flag, const void *buf, const size_t &size) { +FuncGraphPtr ConverterImpl::BuildFuncGraph(const std::shared_ptr ¶m, const void *buf, + const size_t &size) { MindsporeImporter ms_import; - FuncGraphPtr func_graph = ms_import.ImportMindIR(flag, buf, size); + FuncGraphPtr func_graph = ms_import.ImportMindIR(param, buf, size); if (func_graph == nullptr) { MS_LOG(ERROR) << "Get funcGraph failed."; return nullptr; @@ -110,39 +120,50 @@ FuncGraphPtr Converter::BuildFuncGraph(const converter::Flags &flag, const void return func_graph; } -schema::MetaGraphT *Converter::Convert(const std::unique_ptr &flag, const void *buf, - const size_t &size) { - if (flag == nullptr || buf == nullptr) { - MS_LOG(ERROR) << "Input flag is nullptr"; +schema::MetaGraphT *ConverterImpl::Convert(const std::shared_ptr ¶m, const void *buf, + const size_t &size) { + if (param == nullptr || buf == nullptr) { + MS_LOG(ERROR) << "Input param is nullptr"; return nullptr; } - auto graph = BuildFuncGraph(*flag, buf, size); + auto graph = BuildFuncGraph(param, buf, size); if (graph == nullptr) { MS_LOG(ERROR) << "Parser/Import model return nullptr"; return nullptr; } MS_CHECK_TRUE_MSG(funcgraph_transform_ != nullptr, nullptr, "funcgraph_transform init failed."); // funcgraph_transform - graph = funcgraph_transform_->Transform(graph, flag.get()); + graph = funcgraph_transform_->Transform(graph, param); MS_CHECK_TRUE_MSG(graph != nullptr, nullptr, "Transform anf graph return nullptr."); // export protobuf - auto status = MindIRSerialize(flag, graph); + auto status = MindIRSerialize(param, graph); if (status != RET_OK) { MS_LOG(WARNING) << "Export to mindir proto return nullptr."; } - return TransferFuncGraph(flag, graph); + return TransferFuncGraph(param, graph); } -schema::MetaGraphT *Converter::Convert(const std::unique_ptr &flag) { - if (flag == nullptr) { - MS_LOG(ERROR) << "Input flag is nullptr"; +schema::MetaGraphT *ConverterImpl::Convert(const std::shared_ptr ¶m) { + if (param == nullptr) { + MS_LOG(ERROR) << "Input param is nullptr"; return nullptr; } + param->aclModelOptionCfgParam.om_file_path = param->output_file; + param->aclModelOptionCfgParam.offline = true; + + if (!param->config_file.empty()) { + auto ret = InitConfigFile(param); + if (ret != RET_OK) { + std::cerr << "Init config file failed." << std::endl; + return nullptr; + } + } + // load plugin static std::vector> dl_loaders; - if (!flag->pluginsPath.empty()) { - for (auto &path : flag->pluginsPath) { + if (!param->plugins_path.empty()) { + for (auto &path : param->plugins_path) { auto dl_loader = std::make_shared(); MS_CHECK_TRUE_RET(dl_loader != nullptr, nullptr); auto status = dl_loader->Open(path); @@ -154,7 +175,7 @@ schema::MetaGraphT *Converter::Convert(const std::unique_ptr & } } - auto graph = BuildFuncGraph(*flag); + auto graph = BuildFuncGraph(param); if (graph == nullptr) { MS_LOG(ERROR) << "Parser/Import model return nullptr"; return nullptr; @@ -162,23 +183,23 @@ schema::MetaGraphT *Converter::Convert(const std::unique_ptr & MS_CHECK_TRUE_MSG(funcgraph_transform_ != nullptr, nullptr, "funcgraph_transform init failed"); // funcgraph transform - graph = funcgraph_transform_->Transform(graph, flag.get()); + graph = funcgraph_transform_->Transform(graph, param); if (graph == nullptr) { MS_LOG(ERROR) << "Transform anf graph return nullptr"; return nullptr; } // export protobuf - auto status = MindIRSerialize(flag, graph); + auto status = MindIRSerialize(param, graph); if (status != RET_OK) { MS_LOG(WARNING) << "Export to mindir proto return nullptr."; } - return TransferFuncGraph(flag, graph); + return TransferFuncGraph(param, graph); } -schema::MetaGraphT *Converter::TransferFuncGraph(const std::unique_ptr &flag, - FuncGraphPtr func_graph) { +schema::MetaGraphT *ConverterImpl::TransferFuncGraph(const std::shared_ptr ¶m, + FuncGraphPtr func_graph) { MS_CHECK_TRUE_MSG(metagraph_transform_ != nullptr, nullptr, "metagraph_transform_ init failed"); #ifdef MSLITE_ENABLE_GRAPH_KERNEL if (graphkernel::GraphKernelFlags::GetInstance().IsEnableGraphKernel()) { @@ -187,7 +208,7 @@ schema::MetaGraphT *Converter::TransferFuncGraph(const std::unique_ptr flatbuffer - auto meta_graph = Export(func_graph, false, false, flag->trainModel); + auto meta_graph = Export(func_graph, false, false, param->train_model); if (meta_graph == nullptr) { MS_LOG(ERROR) << "Export to meta graph return nullptr"; return nullptr; @@ -195,7 +216,7 @@ schema::MetaGraphT *Converter::TransferFuncGraph(const std::unique_ptrSetGraphDef(meta_graph); - auto status = metagraph_transform_->Transform(*flag); + auto status = metagraph_transform_->Transform(param); if (status != RET_OK) { MS_LOG(ERROR) << "Transform meta graph failed " << status; ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); @@ -233,8 +254,8 @@ int CheckExistCustomOps(const schema::MetaGraphT *meta_graph, bool *exist_custom return RET_OK; } -int PreInference(const schema::MetaGraphT &meta_graph, const std::unique_ptr &flags) { - if (flags->trainModel) { +int PreInference(const schema::MetaGraphT &meta_graph, bool train_model) { + if (train_model) { MS_LOG(WARNING) << "train model dont support pre-infer."; return RET_OK; } @@ -298,32 +319,213 @@ int PreInference(const schema::MetaGraphT &meta_graph, const std::unique_ptr(); - if (flags == nullptr) { - oss.clear(); - oss << "NEW FLAGS ERROR:" << RET_MEMORY_FAILED << " " << GetErrorInfo(RET_MEMORY_FAILED); - MS_LOG(ERROR) << oss.str(); - std::cout << oss.str() << std::endl; - return RET_MEMORY_FAILED; +int ConverterImpl::InitConfigFile(const std::shared_ptr ¶m) { + lite::ConfigFileParser config_parser; + auto ret = config_parser.ParseConfigFile(param->config_file); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Parse config file failed."; + return ret; } - auto status = flags->Init(argc, argv); - if (status != RET_OK) { - if (status != RET_SUCCESS_EXIT) { - oss.clear(); - oss << "CONVERTER::FLAGS INIT FAILED:" << status << " " << GetErrorInfo(status); - MS_LOG(ERROR) << oss.str(); - std::cout << oss.str() << std::endl; + ret = lite::PreprocessParser::ParsePreprocess(config_parser.GetDataPreProcessString(), ¶m->dataPreProcessParam); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Parse preprocess failed."; + return ret; + } + ret = lite::QuantParamParser::ParseCommonQuant(config_parser.GetCommonQuantString(), ¶m->commonQuantParam); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Parse common quant param failed."; + return ret; + } + ret = lite::QuantParamParser::ParseFullQuant(config_parser.GetFullQuantString(), ¶m->fullQuantParam); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Parse full quant param failed."; + return ret; + } + ret = lite::QuantParamParser::ParseMixedBitWeightQuant(config_parser.GetMixedBitWeightQuantString(), + ¶m->mixedBitWeightQuantParam); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Parse mixed bit weight quant param failed."; + return ret; + } + ret = InitExtendedIntegrationInfo(param, config_parser); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Parse extended integration info failed."; + return ret; + } + + lite::AclOptionParamParser acl_param_parser; + ret = acl_param_parser.ParseAclOptionCfg(config_parser.GetAclOptionCfgString(), ¶m->aclModelOptionCfgParam); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Parse acl option param failed."; + return ret; + } + (void)CheckOfflineParallelConfig(param->config_file, ¶m->parallel_split_config); + + lite::MicroParamParser micro_param_parser; + ret = micro_param_parser.ParseMicroParam(config_parser.GetMicroParamString(), ¶m->microParam); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Parse micro param failed."; + return ret; + } + return RET_OK; +} + +int ConverterImpl::InitExtendedIntegrationInfo(const std::shared_ptr ¶m, + const lite::ConfigFileParser &config_parser) { + auto extended_info = config_parser.GetRegistryInfoString(); + if (!extended_info.plugin_path.empty()) { + const char delimiter = ';'; + auto relative_path = lite::SplitStringToVector(extended_info.plugin_path, delimiter); + if (relative_path.size() > kPluginPathMaxNum) { + MS_LOG(ERROR) << "extended plugin library's num is too big, which shouldn't be larger than " << kPluginPathMaxNum; + return RET_INPUT_PARAM_INVALID; + } + for (auto &i : relative_path) { + param->plugins_path.push_back(lite::RealPath(i.c_str())); } - return status; } - // Load graph - MS_LOG(DEBUG) << "start reading model file"; - Converter cvt; - auto meta_graph = cvt.Convert(flags); + + if (!extended_info.disable_fusion.empty()) { + if (extended_info.disable_fusion == "on") { + param->no_fusion = true; + } else if (extended_info.disable_fusion == "off") { + param->no_fusion = false; + } else { + std::cerr << "CONFIG SETTING ILLEGAL: disable_fusion should be on/off" << std::endl; + return RET_INPUT_PARAM_INVALID; + } + } + return RET_OK; +} + +bool ConverterImpl::CheckOfflineParallelConfig(const std::string &file, ParallelSplitConfig *parallel_split_config) { + // device: [device0 device1] ---> {cpu, gpu} + // computeRate: [x: y] x >=0 && y >=0 && x/y < 10 + MS_ASSERT(parallel_split_config != nullptr); + std::vector config_devices = {"cpu", "gpu", "npu"}; + auto compute_rate_result = GetStrFromConfigFile(file, kComputeRate); + if (compute_rate_result.empty()) { + return false; + } + std::string device0_result = GetStrFromConfigFile(file, kSplitDevice0); + if (device0_result.empty()) { + return false; + } + std::string device1_result = GetStrFromConfigFile(file, kSplitDevice1); + if (device1_result.empty()) { + return false; + } + bool device0_flag = false; + bool device1_flag = false; + for (const auto &device : config_devices) { + if (device == device0_result) { + device0_flag = true; + } + if (device == device1_result) { + device1_flag = true; + } + } + if (!device0_flag || !device1_flag) { + return false; + } + const char delimiter = ';'; + std::vector device_rates = lite::SplitStringToVector(compute_rate_result, delimiter); + const char colon = ':'; + for (const auto &device : device_rates) { + std::vector rate = lite::SplitStringToVector(device, colon); + int64_t compute_rate = 0; + try { + compute_rate = std::stoi(rate.back()); + } catch (const std::exception &e) { + MS_LOG(ERROR) << "Get compute rate failed: " << e.what(); + return false; + } + parallel_split_config->parallel_compute_rates_.push_back(compute_rate); + } + const size_t support_rates_num = 2; + if (parallel_split_config->parallel_compute_rates_.size() != support_rates_num) { + return false; + } + int64_t bigger_rate = INT32_MIN; + int64_t smaller_rate = INT32_MAX; + for (const auto &rate : parallel_split_config->parallel_compute_rates_) { + if (rate <= 0 || rate > INT32_MAX) { + return false; + } + bigger_rate = std::max(rate, bigger_rate); + smaller_rate = std::min(rate, smaller_rate); + } + parallel_split_config->parallel_devices_.push_back(device0_result); + parallel_split_config->parallel_devices_.push_back(device1_result); + // parall_split_type will extend by other user's attr + parallel_split_config->parallel_split_type_ = SplitByUserRatio; + if (smaller_rate == 0) { + MS_LOG(ERROR) << "smaller_rate is zero"; + return false; + } + return bigger_rate / smaller_rate <= kMaxSplitRatio; +} + +std::string ConverterImpl::GetStrFromConfigFile(const std::string &file, const std::string &target_key) { + std::string res; + if (file.empty()) { + MS_LOG(ERROR) << "file is nullptr"; + return res; + } + auto resolved_path = std::make_unique(PATH_MAX); + if (resolved_path == nullptr) { + MS_LOG(ERROR) << "new resolved_path failed"; + return ""; + } + +#ifdef _WIN32 + auto *real_path = _fullpath(resolved_path.get(), file.c_str(), kPathLengthUpperLimit); +#else + char *real_path = realpath(file.c_str(), resolved_path.get()); +#endif + if (real_path == nullptr || strlen(real_path) == 0) { + MS_LOG(ERROR) << "file path is not valid : " << file; + return ""; + } + std::ifstream ifs(resolved_path.get()); + if (!ifs.good()) { + MS_LOG(ERROR) << "file: " << real_path << " is not exist"; + return res; + } + if (!ifs.is_open()) { + MS_LOG(ERROR) << "file: " << real_path << "open failed"; + return res; + } + std::string line; + while (std::getline(ifs, line)) { + lite::Trim(&line); + if (line.empty() || line.at(0) == '#' || line.at(0) == '[') { + continue; + } + auto index = line.find('='); + if (index == std::string::npos) { + MS_LOG(ERROR) << "the config file is invalid, can not find '=', please check"; + return ""; + } + auto key = line.substr(0, index); + auto value = line.substr(index + 1); + lite::Trim(&key); + lite::Trim(&value); + if (key == target_key) { + return value; + } + } + return res; +} + +int RunConverter(const std::shared_ptr ¶m) { + mindspore::common_log_init(); + + ConverterImpl converter_impl; + auto meta_graph = converter_impl.Convert(param); NotSupportOp::GetInstance()->PrintOps(); - status = ReturnCode::GetSingleReturnCode()->status_code(); + int status = ReturnCode::GetSingleReturnCode()->status_code(); + std::ostringstream oss; if (meta_graph == nullptr) { oss.clear(); oss << "CONVERT RESULT FAILED:" << status << " " << GetErrorInfo(status); @@ -335,8 +537,8 @@ int RunConverter(int argc, const char **argv) { // save graph to file meta_graph->version = Version(); - if (flags->infer) { - status = PreInference(*meta_graph, flags); + if (param->pre_infer) { + status = PreInference(*meta_graph, param->train_model); if (status != RET_OK) { oss.clear(); oss << "PRE INFERENCE FAILED:" << status << " " << GetErrorInfo(status); @@ -347,10 +549,10 @@ int RunConverter(int argc, const char **argv) { } } - if (flags->microParam.enable_micro) { - status = micro::Coder::MicroSourceCodeGeneration(*meta_graph, flags->outputFile, flags->microParam.codegen_mode, - flags->microParam.target, flags->microParam.support_parallel, - flags->microParam.debug_mode); + if (param->microParam.enable_micro) { + status = micro::Coder::MicroSourceCodeGeneration(*meta_graph, param->output_file, param->microParam.codegen_mode, + param->microParam.target, param->microParam.support_parallel, + param->microParam.debug_mode); if (status != RET_OK) { delete meta_graph; oss.clear(); @@ -360,7 +562,27 @@ int RunConverter(int argc, const char **argv) { return status; } } else { - status = MetaGraphSerializer::Save(*meta_graph, flags->outputFile, flags->encKey, flags->keyLen, flags->encMode); + unsigned char encKey[kEncMaxLen] = {0}; + size_t keyLen = 0; + if (param->enable_encryption) { + if (!param->encrypt_key.empty()) { + keyLen = lite::Hex2ByteArray(param->encrypt_key, encKey, kEncMaxLen); + if (keyLen != kEncMaxLen) { + MS_LOG(ERROR) << "enc_key must expressed in hexadecimal characters " + << " and only support AES-GCM method and the key length is 16."; + return RET_INPUT_PARAM_INVALID; + } + } else { + MS_LOG(ERROR) << "If you don't need to use model encryption, please set --encryption=false."; + return RET_INPUT_PARAM_INVALID; + } + } + status = MetaGraphSerializer::Save(*meta_graph, param->output_file, encKey, keyLen, param->encrypt_mode); + if (memset_s(encKey, kEncMaxLen, 0, kEncMaxLen) != EOK) { + MS_LOG(ERROR) << "memset failed."; + delete meta_graph; + return RET_ERROR; + } if (status != RET_OK) { delete meta_graph; oss.clear(); @@ -370,15 +592,6 @@ int RunConverter(int argc, const char **argv) { return status; } } - // clear key - flags->dec_key.clear(); - flags->encKeyStr.clear(); - status = memset_s(flags->encKey, converter::kEncMaxLen, 0, converter::kEncMaxLen); - if (status != EOK) { - MS_LOG(ERROR) << "memset failed."; - delete meta_graph; - return RET_ERROR; - } delete meta_graph; oss.clear(); oss << "CONVERT RESULT SUCCESS:" << status; diff --git a/mindspore/lite/tools/converter/converter.h b/mindspore/lite/tools/converter/converter.h index 76060b29e10..b77c9ec742d 100644 --- a/mindspore/lite/tools/converter/converter.h +++ b/mindspore/lite/tools/converter/converter.h @@ -20,39 +20,54 @@ #define USE_DEPRECATED_API #include #include +#include "include/converter.h" #include "include/registry/model_parser.h" #include "schema/inner/model_generated.h" #include "tools/converter/graphdef_transform.h" #include "include/registry/model_parser_registry.h" -#include "tools/converter/converter_flags.h" #include "tools/converter/anf_transform.h" #include "tools/converter/converter_context.h" #include "tools/common/graph_util.h" +#include "tools/converter/preprocess/preprocess_param.h" +#include "tools/converter/quantizer/quant_params.h" +#include "tools/converter/adapter/acl/common/acl_types.h" +#include "micro/coder/config.h" +#include "tools/converter/cxx_api/converter_para.h" +#include "tools/converter/config_parser/config_file_parser.h" namespace mindspore { namespace lite { -class Converter { +constexpr auto kMaxSplitRatio = 10; +constexpr auto kComputeRate = "computeRate"; +constexpr auto kSplitDevice0 = "device0"; +constexpr auto kSplitDevice1 = "device1"; + +class ConverterImpl { public: - Converter() = default; - ~Converter() { + ConverterImpl() = default; + ~ConverterImpl() { delete model_parser_; this->model_parser_ = nullptr; } - schema::MetaGraphT *Convert(const std::unique_ptr &flag); - schema::MetaGraphT *Convert(const std::unique_ptr &flag, const void *buf, const size_t &size); + schema::MetaGraphT *Convert(const std::shared_ptr ¶m); + schema::MetaGraphT *Convert(const std::shared_ptr ¶m, const void *buf, const size_t &size); private: - FuncGraphPtr BuildFuncGraph(const converter::Flags &flag); - FuncGraphPtr BuildFuncGraph(const converter::Flags &flag, const void *buf, const size_t &size); - schema::MetaGraphT *TransferFuncGraph(const std::unique_ptr &flag, FuncGraphPtr func_graph); + FuncGraphPtr BuildFuncGraph(const std::shared_ptr ¶m); + FuncGraphPtr BuildFuncGraph(const std::shared_ptr ¶m, const void *buf, const size_t &size); + schema::MetaGraphT *TransferFuncGraph(const std::shared_ptr ¶m, FuncGraphPtr func_graph); + + int InitConfigFile(const std::shared_ptr ¶m); + int InitExtendedIntegrationInfo(const std::shared_ptr ¶m, + const lite::ConfigFileParser &config_file_parser); + bool CheckOfflineParallelConfig(const std::string &file, ParallelSplitConfig *parallel_split_config); + std::string GetStrFromConfigFile(const std::string &file, const std::string &target_key); protected: converter::ModelParser *model_parser_ = nullptr; std::unique_ptr metagraph_transform_ = std::make_unique(); std::unique_ptr funcgraph_transform_ = std::make_unique(); }; - -int RunConverter(int argc, const char **argv); } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/converter_lite/CMakeLists.txt b/mindspore/lite/tools/converter/converter_lite/CMakeLists.txt index 53d5bc5cfb8..1e9509e6f10 100644 --- a/mindspore/lite/tools/converter/converter_lite/CMakeLists.txt +++ b/mindspore/lite/tools/converter/converter_lite/CMakeLists.txt @@ -1,8 +1,12 @@ remove_definitions(-DUSE_GLOG) +link_directories(${opencv_INC}/../lib) + add_executable(converter_lite main.cc converter_flags.cc ${TOP_DIR}/mindspore/lite/src/common/log.cc + ${TOP_DIR}/mindspore/lite/src/common/utils.cc ${TOP_DIR}/mindspore/core/utils/status.cc ${CMAKE_CURRENT_SOURCE_DIR}/../../common/flag_parser.cc ${CMAKE_CURRENT_SOURCE_DIR}/../../common/string_util.cc) -target_link_libraries(converter_lite PRIVATE mindspore-converter) + +target_link_libraries(converter_lite mindspore_converter) diff --git a/mindspore/lite/tools/converter/converter_flags.cc b/mindspore/lite/tools/converter/converter_lite/converter_flags.cc similarity index 58% rename from mindspore/lite/tools/converter/converter_flags.cc rename to mindspore/lite/tools/converter/converter_lite/converter_flags.cc index 9cd677d3a70..14762deed15 100644 --- a/mindspore/lite/tools/converter/converter_flags.cc +++ b/mindspore/lite/tools/converter/converter_lite/converter_flags.cc @@ -13,8 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - -#include "tools/converter/converter_flags.h" +#include "tools/converter/converter_lite/converter_flags.h" #include #include #include @@ -22,27 +21,13 @@ #include #include #include -#include "ir/dtype/type_id.h" -#include "common/file_utils.h" -#include "tools/common/string_util.h" -#include "common/log_util.h" -#include "tools/converter/converter_context.h" -#include "tools/converter/config_parser/config_file_parser.h" -#include "tools/converter/config_parser/preprocess_parser.h" -#include "tools/converter/config_parser/quant_param_parser.h" -#include "tools/converter/config_parser/acl_option_param_parser.h" -#include "tools/converter/config_parser/micro_param_parser.h" -namespace mindspore { -namespace converter { +#include "tools/common/string_util.h" + +namespace mindspore::converter { using mindspore::lite::RET_INPUT_PARAM_INVALID; using mindspore::lite::RET_OK; -namespace { -constexpr size_t kPluginPathMaxNum = 10; -constexpr int kQuantBitNumInt16 = 16; -constexpr int kPathLengthUpperLimit = 1024; -constexpr int kMinShapeSizeInStr = 2; -} // namespace + Flags::Flags() { AddFlag(&Flags::fmkIn, "fmk", "Input model framework type. TF | TFLITE | CAFFE | MINDIR | ONNX", ""); AddFlag(&Flags::modelFile, "modelFile", @@ -106,13 +91,13 @@ Flags::Flags() { int Flags::InitInputOutputDataType() { if (this->inputDataTypeStr == "FLOAT") { - this->inputDataType = TypeId::kNumberTypeFloat32; + this->inputDataType = DataType::kNumberTypeFloat32; } else if (this->inputDataTypeStr == "INT8") { - this->inputDataType = TypeId::kNumberTypeInt8; + this->inputDataType = DataType::kNumberTypeInt8; } else if (this->inputDataTypeStr == "UINT8") { - this->inputDataType = TypeId::kNumberTypeUInt8; + this->inputDataType = DataType::kNumberTypeUInt8; } else if (this->inputDataTypeStr == "DEFAULT") { - this->inputDataType = TypeId::kTypeUnknown; + this->inputDataType = DataType::kTypeUnknown; } else { std::cerr << "INPUT INVALID: inputDataType is invalid: %s, supported inputDataType: FLOAT | INT8 | UINT8 | DEFAULT, got: " @@ -121,13 +106,13 @@ int Flags::InitInputOutputDataType() { } if (this->outputDataTypeStr == "FLOAT") { - this->outputDataType = TypeId::kNumberTypeFloat32; + this->outputDataType = DataType::kNumberTypeFloat32; } else if (this->outputDataTypeStr == "INT8") { - this->outputDataType = TypeId::kNumberTypeInt8; + this->outputDataType = DataType::kNumberTypeInt8; } else if (this->outputDataTypeStr == "UINT8") { - this->outputDataType = TypeId::kNumberTypeUInt8; + this->outputDataType = DataType::kNumberTypeUInt8; } else if (this->outputDataTypeStr == "DEFAULT") { - this->outputDataType = TypeId::kTypeUnknown; + this->outputDataType = DataType::kTypeUnknown; } else { std::cerr << "INPUT INVALID: outputDataType is invalid: %s, supported outputDataType: FLOAT | INT8 | UINT8 | DEFAULT, got: " @@ -177,11 +162,11 @@ int Flags::InitTrainModel() { std::cerr << "INPUT ILLEGAL: train model converter supporting only MINDIR format" << std::endl; return RET_INPUT_PARAM_INVALID; } - if ((this->inputDataType != TypeId::kNumberTypeFloat32) && (this->inputDataType != TypeId::kTypeUnknown)) { + if ((this->inputDataType != DataType::kNumberTypeFloat32) && (this->inputDataType != DataType::kTypeUnknown)) { std::cerr << "INPUT ILLEGAL: train model converter supporting only FP32 input tensors" << std::endl; return RET_INPUT_PARAM_INVALID; } - if ((this->outputDataType != TypeId::kNumberTypeFloat32) && (this->outputDataType != TypeId::kTypeUnknown)) { + if ((this->outputDataType != DataType::kNumberTypeFloat32) && (this->outputDataType != DataType::kTypeUnknown)) { std::cerr << "INPUT ILLEGAL: train model converter supporting only FP32 output tensors" << std::endl; return RET_INPUT_PARAM_INVALID; } @@ -202,7 +187,11 @@ int Flags::InitInTensorShape() const { } shape.clear(); auto string_split = lite::StrSplit(shape_str, std::string(":")); - CHECK_LESS_RETURN(string_split.size(), kMinShapeSizeInStr); + constexpr int kMinShapeSizeInStr = 2; + if (string_split.size() < kMinShapeSizeInStr) { + MS_LOG(ERROR) << "shape size must not be less than " << kMinShapeSizeInStr; + return mindspore::lite::RET_ERROR; + } auto name = string_split[0]; for (size_t i = 1; i < string_split.size() - 1; ++i) { name += ":" + string_split[i]; @@ -236,7 +225,7 @@ int Flags::InitInTensorShape() const { shape.push_back(dim_value); } } - lite::ConverterInnerContext::GetInstance()->UpdateGraphInputTensorShape(name, shape); + graph_input_shape_map[name] = shape; } return RET_OK; } @@ -253,89 +242,6 @@ int Flags::InitGraphInputFormat() { return RET_OK; } -int Flags::InitExtendedIntegrationInfo(const lite::ConfigFileParser &config_file_parser) { - auto extended_info = config_file_parser.GetRegistryInfoString(); - if (!extended_info.plugin_path.empty()) { - const char delimiter = ';'; - auto relative_path = lite::SplitStringToVector(extended_info.plugin_path, delimiter); - if (relative_path.size() > kPluginPathMaxNum) { - MS_LOG(ERROR) << "extended plugin library's num is too big, which shouldn't be larger than " << kPluginPathMaxNum; - return RET_INPUT_PARAM_INVALID; - } - for (auto &i : relative_path) { - this->pluginsPath.push_back(lite::RealPath(i.c_str())); - } - } - - if (!extended_info.disable_fusion.empty()) { - if (extended_info.disable_fusion == "on") { - this->disableFusion = true; - } else if (extended_info.disable_fusion == "off") { - this->disableFusion = false; - } else { - std::cerr << "CONFIG SETTING ILLEGAL: disable_fusion should be on/off" << std::endl; - return RET_INPUT_PARAM_INVALID; - } - } - return RET_OK; -} - -void Flags::InitAclDefaultOption() { - this->aclModelOptionCfgParam.om_file_path = this->outputFile; - this->aclModelOptionCfgParam.offline = true; -} - -int Flags::InitConfigFile() { - lite::ConfigFileParser config_file_parser; - auto ret = config_file_parser.ParseConfigFile(this->configFile); - if (ret != RET_OK) { - MS_LOG(ERROR) << "Parse config file failed."; - return ret; - } - ret = - lite::PreprocessParser::ParsePreprocess(config_file_parser.GetDataPreProcessString(), &this->dataPreProcessParam); - if (ret != RET_OK) { - MS_LOG(ERROR) << "Parse preprocess failed."; - return ret; - } - ret = lite::QuantParamParser::ParseCommonQuant(config_file_parser.GetCommonQuantString(), &this->commonQuantParam); - if (ret != RET_OK) { - MS_LOG(ERROR) << "Parse common quant param failed."; - return ret; - } - ret = lite::QuantParamParser::ParseFullQuant(config_file_parser.GetFullQuantString(), &this->fullQuantParam); - if (ret != RET_OK) { - MS_LOG(ERROR) << "Parse full quant param failed."; - return ret; - } - ret = lite::QuantParamParser::ParseMixedBitWeightQuant(config_file_parser.GetMixedBitWeightQuantString(), - &this->mixedBitWeightQuantParam); - if (ret != RET_OK) { - MS_LOG(ERROR) << "Parse mixed bit weight quant param failed."; - return ret; - } - ret = InitExtendedIntegrationInfo(config_file_parser); - if (ret != RET_OK) { - MS_LOG(ERROR) << "Parse extended integration info failed."; - return ret; - } - lite::AclOptionParamParser acl_param_parser; - ret = acl_param_parser.ParseAclOptionCfg(config_file_parser.GetAclOptionCfgString(), &this->aclModelOptionCfgParam); - if (ret != RET_OK) { - MS_LOG(ERROR) << "Parse acl option param failed."; - return ret; - } - (void)CheckOfflineParallelConfig(this->configFile, ¶llel_split_config_); - - lite::MicroParamParser micro_param_parser; - ret = micro_param_parser.ParseMicroParam(config_file_parser.GetMicroParamString(), &this->microParam); - if (ret != RET_OK) { - MS_LOG(ERROR) << "Parse micro param failed."; - return ret; - } - return RET_OK; -} - int Flags::InitSaveFP16() { if (saveFP16Str == "on") { saveFP16 = true; @@ -398,13 +304,6 @@ int Flags::InitEncrypt() { MS_LOG(ERROR) << "If you don't need to use model encryption, please set --encryption=false."; return RET_INPUT_PARAM_INVALID; } - keyLen = lite::Hex2ByteArray(encKeyStr, encKey, kEncMaxLen); - if (keyLen != kEncMaxLen) { - MS_LOG(ERROR) << "enc_key " << encKeyStr << " must expressed in hexadecimal characters " - << " and only support AES-GCM method and the key length is 16."; - return RET_INPUT_PARAM_INVALID; - } - encKeyStr.clear(); } return RET_OK; } @@ -412,7 +311,7 @@ int Flags::InitEncrypt() { int Flags::PreInit(int argc, const char **argv) { if (argc == 1) { std::cout << this->Usage() << std::endl; - return lite::RET_SUCCESS_EXIT; + return lite::RET_OK; } lite::Option err = this->ParseFlags(argc, argv); @@ -424,7 +323,7 @@ int Flags::PreInit(int argc, const char **argv) { if (this->help) { std::cout << this->Usage() << std::endl; - return lite::RET_SUCCESS_EXIT; + return lite::RET_OK; } if (this->modelFile.empty()) { std::cerr << "INPUT MISSING: model file path is necessary" << std::endl; @@ -450,15 +349,6 @@ int Flags::PreInit(int argc, const char **argv) { return RET_INPUT_PARAM_INVALID; } - if (!this->configFile.empty()) { - auto ret = InitConfigFile(); - if (ret != RET_OK) { - std::cerr << "Init config file failed." << std::endl; - return RET_INPUT_PARAM_INVALID; - } - } - - InitAclDefaultOption(); return RET_OK; } @@ -526,120 +416,4 @@ int Flags::Init(int argc, const char **argv) { } return RET_OK; } - -bool CheckOfflineParallelConfig(const std::string &file, ParallelSplitConfig *parallel_split_config) { - // device: [device0 device1] ---> {cpu, gpu} - // computeRate: [x: y] x >=0 && y >=0 && x/y < 10 - MS_ASSERT(parallel_split_config != nullptr); - std::vector config_devices = {"cpu", "gpu", "npu"}; - auto compute_rate_result = GetStrFromConfigFile(file, kComputeRate); - if (compute_rate_result.empty()) { - return false; - } - std::string device0_result = GetStrFromConfigFile(file, kSplitDevice0); - if (device0_result.empty()) { - return false; - } - std::string device1_result = GetStrFromConfigFile(file, kSplitDevice1); - if (device1_result.empty()) { - return false; - } - bool device0_flag = false; - bool device1_flag = false; - for (const auto &device : config_devices) { - if (device == device0_result) { - device0_flag = true; - } - if (device == device1_result) { - device1_flag = true; - } - } - if (!device0_flag || !device1_flag) { - return false; - } - const char delimiter = ';'; - std::vector device_rates = lite::SplitStringToVector(compute_rate_result, delimiter); - const char colon = ':'; - for (const auto &device : device_rates) { - std::vector rate = lite::SplitStringToVector(device, colon); - int64_t compute_rate = 0; - try { - compute_rate = std::stoi(rate.back()); - } catch (const std::exception &e) { - MS_LOG(ERROR) << "Get compute rate failed: " << e.what(); - return false; - } - parallel_split_config->parallel_compute_rates_.push_back(compute_rate); - } - if (parallel_split_config->parallel_compute_rates_.size() != 2) { - return false; - } - int64_t bigger_rate = INT32_MIN; - int64_t smaller_rate = INT32_MAX; - for (const auto &rate : parallel_split_config->parallel_compute_rates_) { - if (rate <= 0 || rate > INT32_MAX) { - return false; - } - bigger_rate = std::max(rate, bigger_rate); - smaller_rate = std::min(rate, smaller_rate); - } - parallel_split_config->parallel_devices_.push_back(device0_result); - parallel_split_config->parallel_devices_.push_back(device1_result); - // parall_split_type will extend by other user's attr - parallel_split_config->parallel_split_type_ = SplitByUserRatio; - return bigger_rate / smaller_rate <= kMaxSplitRatio; -} - -std::string GetStrFromConfigFile(const std::string &file, const std::string &target_key) { - std::string res; - if (file.empty()) { - MS_LOG(ERROR) << "file is nullptr"; - return res; - } - auto resolved_path = std::make_unique(PATH_MAX); - if (resolved_path == nullptr) { - MS_LOG(ERROR) << "new resolved_path failed"; - return ""; - } - -#ifdef _WIN32 - auto *real_path = _fullpath(resolved_path.get(), file.c_str(), kPathLengthUpperLimit); -#else - char *real_path = realpath(file.c_str(), resolved_path.get()); -#endif - if (real_path == nullptr || strlen(real_path) == 0) { - MS_LOG(ERROR) << "file path is not valid : " << file; - return ""; - } - std::ifstream ifs(resolved_path.get()); - if (!ifs.good()) { - MS_LOG(ERROR) << "file: " << real_path << " is not exist"; - return res; - } - if (!ifs.is_open()) { - MS_LOG(ERROR) << "file: " << real_path << "open failed"; - return res; - } - std::string line; - while (std::getline(ifs, line)) { - lite::Trim(&line); - if (line.empty() || line.at(0) == '#' || line.at(0) == '[') { - continue; - } - auto index = line.find('='); - if (index == std::string::npos) { - MS_LOG(ERROR) << "the config file is invalid, can not find '=', please check"; - return ""; - } - auto key = line.substr(0, index); - auto value = line.substr(index + 1); - lite::Trim(&key); - lite::Trim(&value); - if (key == target_key) { - return value; - } - } - return res; -} -} // namespace converter -} // namespace mindspore +} // namespace mindspore::converter diff --git a/mindspore/lite/tools/converter/converter_flags.h b/mindspore/lite/tools/converter/converter_lite/converter_flags.h similarity index 57% rename from mindspore/lite/tools/converter/converter_flags.h rename to mindspore/lite/tools/converter/converter_lite/converter_flags.h index e0a4a9bf5da..3d037d8e199 100644 --- a/mindspore/lite/tools/converter/converter_flags.h +++ b/mindspore/lite/tools/converter/converter_lite/converter_flags.h @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2020-2022 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,104 +13,65 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_CONVERTER_FLAGS_H_ #define MINDSPORE_LITE_TOOLS_CONVERTER_CONVERTER_FLAGS_H_ #include #include +#include #include "include/api/format.h" +#include "include/api/data_type.h" #include "include/registry/converter_context.h" #include "tools/common/flag_parser.h" -#include "ir/dtype/type_id.h" -#include "schema/inner/model_generated.h" -#include "tools/converter/preprocess/preprocess_param.h" -#include "tools/converter/quantizer/quant_params.h" -#include "tools/converter/adapter/acl/common/acl_types.h" -#include "micro/coder/config.h" namespace mindspore { -namespace lite { -class ConfigFileParser; -} // namespace lite namespace converter { -using mindspore::schema::QuantType; -enum ParallelSplitType { SplitNo = 0, SplitByUserRatio = 1, SplitByUserAttr = 2 }; -constexpr auto kMaxSplitRatio = 10; -constexpr auto kComputeRate = "computeRate"; -constexpr auto kSplitDevice0 = "device0"; -constexpr auto kSplitDevice1 = "device1"; -constexpr size_t kEncMaxLen = 16; -struct ParallelSplitConfig { - ParallelSplitType parallel_split_type_ = SplitNo; - std::vector parallel_compute_rates_; - std::vector parallel_devices_; -}; - class Flags : public virtual mindspore::lite::FlagParser { public: Flags(); - ~Flags() = default; - int InitInputOutputDataType(); - int InitFmk(); - int InitTrainModel(); - int InitConfigFile(); - int InitInTensorShape() const; - int InitGraphInputFormat(); - - int InitExtendedIntegrationInfo(const lite::ConfigFileParser &config_file_parser); - int InitEncrypt(); - int InitPreInference(); - int InitSaveFP16(); - int InitNoFusion(); - - void InitAclDefaultOption(); - int InitExportMindIR(); - int Init(int argc, const char **argv); - int PreInit(int argc, const char **argv); - std::string modelFile; - std::string outputFile; std::string fmkIn; FmkType fmk; + std::string modelFile; + std::string outputFile; std::string weightFile; - TypeId inputDataType; - TypeId outputDataType; std::string saveFP16Str = "off"; bool saveFP16 = false; std::string noFusionStr = "false"; + bool disableFusion = false; std::string inputDataTypeStr; + DataType inputDataType; std::string outputDataTypeStr; - ParallelSplitConfig parallel_split_config_{}; + DataType outputDataType; std::string configFile; std::string trainModelIn; bool trainModel = false; - std::vector pluginsPath; - bool disableFusion = false; std::string inTensorShape; + mutable std::map> graph_input_shape_map; std::string dec_key = ""; std::string dec_mode = "AES-GCM"; std::string graphInputFormatStr; - std::string device; mindspore::Format graphInputFormat = mindspore::NHWC; std::string encKeyStr; std::string encMode = "AES-GCM"; std::string inferStr; + bool infer = false; std::string exportMindIR; + bool export_mindir = false; #ifdef ENABLE_OPENSSL std::string encryptionStr = "true"; bool encryption = true; @@ -118,22 +79,7 @@ class Flags : public virtual mindspore::lite::FlagParser { std::string encryptionStr = "false"; bool encryption = false; #endif - bool infer = false; - unsigned char encKey[kEncMaxLen]; - size_t keyLen = 0; - bool export_mindir = false; - - lite::quant::CommonQuantParam commonQuantParam; - lite::quant::MixedBitWeightQuantParam mixedBitWeightQuantParam; - lite::quant::FullQuantParam fullQuantParam; - lite::preprocess::DataPreProcessParam dataPreProcessParam; - lite::acl::AclModelOptionCfg aclModelOptionCfgParam; - lite::micro::MicroParam microParam; }; - -bool CheckOfflineParallelConfig(const std::string &file, ParallelSplitConfig *parallel_split_config); - -std::string GetStrFromConfigFile(const std::string &file, const std::string &target_key); } // namespace converter } // namespace mindspore diff --git a/mindspore/lite/tools/converter/converter_lite/main.cc b/mindspore/lite/tools/converter/converter_lite/main.cc new file mode 100644 index 00000000000..2c3e9a9635d --- /dev/null +++ b/mindspore/lite/tools/converter/converter_lite/main.cc @@ -0,0 +1,82 @@ +/** + * Copyright 2020-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#if defined(__linux__) && !defined(Debug) +#include +#endif +#define USE_DEPRECATED_API +#include +#include "include/converter.h" +#include "include/api/status.h" +#include "tools/converter/converter_lite/converter_flags.h" + +#if defined(__linux__) && !defined(Debug) +void SignalHandler(int sig) { + printf("encounter an unknown error, please verify the input model file or build the debug version\n"); + exit(1); +} +#endif + +int main(int argc, const char **argv) { +#if defined(__linux__) && !defined(Debug) + signal(SIGSEGV, SignalHandler); + signal(SIGABRT, SignalHandler); + signal(SIGFPE, SignalHandler); + signal(SIGBUS, SignalHandler); +#endif + +#ifndef Debug + try { +#endif + mindspore::converter::Flags flags; + auto ret = flags.Init(argc, argv); + if (ret != mindspore::kSuccess) { + MS_LOG(ERROR) << "Flags Init failed. Ret: " << ret; + std::cout << "Flags Init failed. Ret: " << ret << std::endl; + return ret; + } + mindspore::Converter converter(flags.fmk, flags.modelFile, flags.outputFile, flags.weightFile); + converter.SetConfigFile(flags.configFile); + converter.SetWeightFp16(flags.saveFP16); + converter.SetInputShape(flags.graph_input_shape_map); + converter.SetInputFormat(flags.graphInputFormat); + converter.SetInputDataType(flags.inputDataType); + converter.SetOutputDataType(flags.outputDataType); + converter.SetExportMindIR(flags.export_mindir); + converter.SetDecryptKey(flags.dec_key); + flags.dec_key.clear(); + converter.SetDecryptMode(flags.dec_mode); + converter.SetEnableEncryption(flags.encryption); + converter.SetEncryptKey(flags.encKeyStr); + flags.encKeyStr.clear(); + converter.SetInfer(flags.infer); + converter.SetTrainModel(flags.trainModel); + converter.SetNoFusion(flags.disableFusion); + + auto status = converter.Convert(); + if (status != mindspore::kSuccess) { + MS_LOG(ERROR) << "Convert failed. Ret: " << status; + std::cout << "Convert failed. Ret: " << status << std::endl; + } + return status.StatusCode(); +#ifndef Debug + } catch (const std::exception &e) { + std::cout << e.what() << std::endl; + std::cout << "encounter an unknown error, please verify the input model file or build the debug version\n"; + return mindspore::kLiteError; + } +#endif +} diff --git a/mindspore/lite/tools/converter/cxx_api/converter.cc b/mindspore/lite/tools/converter/cxx_api/converter.cc index a21d3c8306e..2e55fec9de0 100644 --- a/mindspore/lite/tools/converter/cxx_api/converter.cc +++ b/mindspore/lite/tools/converter/cxx_api/converter.cc @@ -16,6 +16,7 @@ #include "include/converter.h" #include "include/api/data_type.h" #include "tools/converter/cxx_api/converter_para.h" +#include "tools/converter/converter_context.h" #include "src/common/log_adapter.h" namespace mindspore { @@ -65,6 +66,9 @@ bool Converter::GetWeightFp16() const { void Converter::SetInputShape(const std::map> &input_shape) { if (data_ != nullptr) { + for (auto &it : input_shape) { + lite::ConverterInnerContext::GetInstance()->UpdateGraphInputTensorShape(it.first, it.second); + } data_->input_shape = input_shape; } } diff --git a/mindspore/lite/tools/converter/export_model.cc b/mindspore/lite/tools/converter/export_model.cc index a767b8f7abd..dcc3daab85a 100644 --- a/mindspore/lite/tools/converter/export_model.cc +++ b/mindspore/lite/tools/converter/export_model.cc @@ -59,7 +59,7 @@ void CloneGraphInputs(const FuncGraphPtr &origin, const FuncGraphPtr &mirror, No } AnfNodePtr CloneParameterAndValueNode(const CNodePtr &cnode, size_t index, const FuncGraphPtr &mirror_graph, - const converter::Flags *flags) { + const std::shared_ptr ¶m) { MS_ASSERT(cnode != nullptr && mirror_graph != nullptr); if (index >= cnode->size()) { MS_LOG(ERROR) << "input index out of range."; @@ -92,9 +92,9 @@ AnfNodePtr CloneParameterAndValueNode(const CNodePtr &cnode, size_t index, const DataInfo data_info; STATUS status; if (utils::isa(node)) { - status = FetchDataFromParameterNode(cnode, index, flags->fmk, &data_info, true); + status = FetchDataFromParameterNode(cnode, index, param->fmk_type, &data_info, true); } else if (utils::isa(node)) { - status = FetchDataFromValueNode(cnode, index, flags->fmk, flags->trainModel, &data_info, true); + status = FetchDataFromValueNode(cnode, index, param->fmk_type, param->train_model, &data_info, true); } else { status = RET_ERROR; } @@ -152,10 +152,10 @@ PrimitivePtr ClonePrimitive(const CNodePtr &cnode) { } } // namespace -FuncGraphPtr CloneFuncGraph(const FuncGraphPtr &graph, const converter::Flags *flags, +FuncGraphPtr CloneFuncGraph(const FuncGraphPtr &graph, const std::shared_ptr ¶m, std::map *cloned_func_graph) { MS_ASSERT(graph != nullptr); - MS_ASSERT(flags != nullptr); + MS_ASSERT(param != nullptr); MS_ASSERT(cloned_func_graph != nullptr); auto cloned_func_graph_iter = cloned_func_graph->find(graph); if (cloned_func_graph_iter != cloned_func_graph->end()) { @@ -196,10 +196,10 @@ FuncGraphPtr CloneFuncGraph(const FuncGraphPtr &graph, const converter::Flags *f if (mirror_input == nullptr) { if (IsValueNode(origin_input)) { auto sub_func_graph = GetValueNode(origin_input); - auto mirror_sub_graph = CloneFuncGraph(sub_func_graph, flags, cloned_func_graph); + auto mirror_sub_graph = CloneFuncGraph(sub_func_graph, param, cloned_func_graph); mirror_input = NewValueNode(mirror_sub_graph); } else { - mirror_input = CloneParameterAndValueNode(cnode, i, mirror_graph, flags); + mirror_input = CloneParameterAndValueNode(cnode, i, mirror_graph, param); } if (mirror_input == nullptr) { MS_LOG(ERROR) << "node input cannot be found."; @@ -226,11 +226,11 @@ FuncGraphPtr CloneFuncGraph(const FuncGraphPtr &graph, const converter::Flags *f return mirror_graph; } -STATUS ExportModel(const FuncGraphPtr &graph, const converter::Flags *flags) { +STATUS ExportModel(const FuncGraphPtr &graph, const std::shared_ptr ¶m) { CHECK_NULL_RETURN(graph); - CHECK_NULL_RETURN(flags); + CHECK_NULL_RETURN(param); std::map cloned_func_graph; - auto mirror_graph = CloneFuncGraph(graph, flags, &cloned_func_graph); + auto mirror_graph = CloneFuncGraph(graph, param, &cloned_func_graph); if (mirror_graph == nullptr) { MS_LOG(ERROR) << "Clone funcGraph failed."; return RET_ERROR; @@ -253,8 +253,8 @@ STATUS ExportModel(const FuncGraphPtr &graph, const converter::Flags *flags) { CHECK_NULL_RETURN(optimizer); auto graph_pm = std::make_shared("anf graph pass manager", true); CHECK_NULL_RETURN(graph_pm); - if (flags->fmk == converter::kFmkTypeTflite || flags->fmk == converter::kFmkTypeTf || - flags->fmk == converter::kFmkTypeOnnx) { + if (param->fmk_type == converter::kFmkTypeTflite || param->fmk_type == converter::kFmkTypeTf || + param->fmk_type == converter::kFmkTypeOnnx) { graph_pm->AddPass(std::make_shared()); } optimizer->AddPassManager(graph_pm); @@ -274,7 +274,7 @@ STATUS ExportModel(const FuncGraphPtr &graph, const converter::Flags *flags) { return RET_ERROR; } metagraph_transform->SetGraphDef(meta_graph); - auto status = metagraph_transform->Transform(*flags); + auto status = metagraph_transform->Transform(param); if (status != RET_OK) { MS_LOG(ERROR) << "Transform meta graph failed " << status; delete meta_graph; diff --git a/mindspore/lite/tools/converter/export_model.h b/mindspore/lite/tools/converter/export_model.h index 3192594ce4f..08734124e8e 100644 --- a/mindspore/lite/tools/converter/export_model.h +++ b/mindspore/lite/tools/converter/export_model.h @@ -18,14 +18,16 @@ #define MINDSPORE_LITE_TOOLS_CONVERTER_EXPORT_MODEL_H #include -#include "tools/converter/converter_flags.h" +#include +#include "include/errorcode.h" #include "ir/func_graph.h" +#include "tools/converter/cxx_api/converter_para.h" namespace mindspore { namespace lite { -FuncGraphPtr CloneFuncGraph(const FuncGraphPtr &graph, const converter::Flags *flags, +FuncGraphPtr CloneFuncGraph(const FuncGraphPtr &graph, const std::shared_ptr ¶m, std::map *cloned_func_graph); -STATUS ExportModel(const FuncGraphPtr &graph, const converter::Flags *flags); +STATUS ExportModel(const FuncGraphPtr &graph, const std::shared_ptr ¶m); } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/graphdef_transform.cc b/mindspore/lite/tools/converter/graphdef_transform.cc index bf8a68a48c6..538b1ab1b3d 100644 --- a/mindspore/lite/tools/converter/graphdef_transform.cc +++ b/mindspore/lite/tools/converter/graphdef_transform.cc @@ -19,7 +19,6 @@ #include #include "schema/model_generated.h" #include "src/common/log_adapter.h" -#include "tools/converter/converter_flags.h" #include "tools/converter/legacy_optimizer/graph/dtype_trans_pass.h" #include "tools/converter/legacy_optimizer/fusion/quant_cast_fusion_pass.h" #include "tools/converter/legacy_optimizer/graph/infershape_pass.h" @@ -51,11 +50,11 @@ std::vector GetGraphNodes(const schema::MetaGraphT &graph_defT return old_nodes; } -int QuantTransform(const converter::Flags &ctx, schema::MetaGraphT *graph_defT) { +int QuantTransform(const std::shared_ptr ¶m, schema::MetaGraphT *graph_defT) { MS_ASSERT(graph_defT != nullptr); // quantization - if (ctx.commonQuantParam.quant_type == schema::QuantType_QUANT_NONE || - ctx.commonQuantParam.quant_type == schema::QuantType_QUANT_WEIGHT) { + if (param->commonQuantParam.quant_type == schema::QuantType_QUANT_NONE || + param->commonQuantParam.quant_type == schema::QuantType_QUANT_WEIGHT) { { // quantization // init old node indices @@ -63,7 +62,7 @@ int QuantTransform(const converter::Flags &ctx, schema::MetaGraphT *graph_defT) Optimizer tensor_quant_optimizer; tensor_quant_optimizer.AddPass(new (std::nothrow) TopologicalSortPass()); tensor_quant_optimizer.AddPass(new (std::nothrow) InferQuantParamPass()); - tensor_quant_optimizer.AddPass(new (std::nothrow) InferShapePass(ctx.fmk)); + tensor_quant_optimizer.AddPass(new (std::nothrow) InferShapePass(param->fmk_type)); tensor_quant_optimizer.AddPass(new (std::nothrow) TensorQuantPass()); tensor_quant_optimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes)); auto status = tensor_quant_optimizer.Run(graph_defT); @@ -78,8 +77,9 @@ int QuantTransform(const converter::Flags &ctx, schema::MetaGraphT *graph_defT) Optimizer quant_node_optimizer; quant_node_optimizer.AddPass(new (std::nothrow) TopologicalSortPass()); auto old_nodes = GetGraphNodes(*graph_defT); - quant_node_optimizer.AddPass(new (std::nothrow) InferShapePass(ctx.fmk)); - quant_node_optimizer.AddPass(new (std::nothrow) DTypeTransPass(ctx.inputDataType, ctx.outputDataType)); + quant_node_optimizer.AddPass(new (std::nothrow) InferShapePass(param->fmk_type)); + quant_node_optimizer.AddPass(new (std::nothrow) DTypeTransPass(static_cast(param->input_data_type), + static_cast(param->output_data_type))); quant_node_optimizer.AddPass(new (std::nothrow) QuantCastFusionPass()); quant_node_optimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass()); quant_node_optimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes)); @@ -94,12 +94,12 @@ int QuantTransform(const converter::Flags &ctx, schema::MetaGraphT *graph_defT) } } // namespace -int GraphDefTransform::Transform(const converter::Flags &ctx) { +int GraphDefTransform::Transform(const std::shared_ptr ¶m) { STATUS status; { auto old_nodes = GetGraphNodes(*graph_defT_); Optimizer unused_op_remove_optimizer; - if (!ctx.trainModel) { + if (!param->train_model) { unused_op_remove_optimizer.AddPass(new (std::nothrow) DropoutNodeRemovePass()); } unused_op_remove_optimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass()); @@ -116,7 +116,7 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) { // init old node indices auto old_nodes = GetGraphNodes(*graph_defT_); Optimizer format_trans_optimizer; - if (!ctx.trainModel && ctx.fmk != converter::kFmkTypeOnnx) { + if (!param->train_model && param->fmk_type != converter::kFmkTypeOnnx) { format_trans_optimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass()); format_trans_optimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes)); } @@ -127,7 +127,7 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) { } } - auto ret = QuantTransform(ctx, graph_defT_); + auto ret = QuantTransform(param, graph_defT_); if (ret != RET_OK && status != RET_NO_CHANGE) { return status; } @@ -149,10 +149,10 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) { { Optimizer forming_model_optimizer; - forming_model_optimizer.AddPass(new (std::nothrow) InferShapePass(ctx.fmk)); - forming_model_optimizer.AddPass(new (std::nothrow) SetUnusedQuantParamToDefaultPass(ctx)); + forming_model_optimizer.AddPass(new (std::nothrow) InferShapePass(param->fmk_type)); + forming_model_optimizer.AddPass(new (std::nothrow) SetUnusedQuantParamToDefaultPass(param)); forming_model_optimizer.AddPass(new (std::nothrow) TensorNamePass()); - forming_model_optimizer.AddPass(new (std::nothrow) ConvertFP32ToFP16Pass(ctx.saveFP16)); + forming_model_optimizer.AddPass(new (std::nothrow) ConvertFP32ToFP16Pass(param->weight_fp16)); status = forming_model_optimizer.Run(graph_defT_); if (status != RET_OK) { MS_LOG(ERROR) << "Run InferShapeOptimizer graphPasses Failed."; diff --git a/mindspore/lite/tools/converter/graphdef_transform.h b/mindspore/lite/tools/converter/graphdef_transform.h index dad205df468..19acfbe5cb7 100644 --- a/mindspore/lite/tools/converter/graphdef_transform.h +++ b/mindspore/lite/tools/converter/graphdef_transform.h @@ -23,7 +23,6 @@ #include "tools/converter/quantizer/quantizer.h" #include "schema/inner/model_generated.h" #include "tools/common/meta_graph_serializer.h" -#include "tools/converter/converter_flags.h" namespace mindspore { namespace lite { @@ -35,7 +34,7 @@ class GraphDefTransform { public: GraphDefTransform(); virtual ~GraphDefTransform(); - virtual int Transform(const converter::Flags &ctx); + virtual int Transform(const std::shared_ptr ¶m); void SetGraphDef(schema::MetaGraphT *dst_def); protected: diff --git a/mindspore/lite/tools/converter/import/mindir_adjust.h b/mindspore/lite/tools/converter/import/mindir_adjust.h index 0ec545c12cf..30e0569d18a 100644 --- a/mindspore/lite/tools/converter/import/mindir_adjust.h +++ b/mindspore/lite/tools/converter/import/mindir_adjust.h @@ -19,8 +19,8 @@ #include #include -#include "tools/converter/converter_flags.h" #include "tools/optimizer/common/gllo_utils.h" +#include "include/registry/converter_context.h" using mindspore::converter::FmkType; using mindspore::schema::QuantType; diff --git a/mindspore/lite/tools/converter/import/mindir_control_flow_adjust.h b/mindspore/lite/tools/converter/import/mindir_control_flow_adjust.h index b824b7772c8..c6663a6c6b3 100644 --- a/mindspore/lite/tools/converter/import/mindir_control_flow_adjust.h +++ b/mindspore/lite/tools/converter/import/mindir_control_flow_adjust.h @@ -20,8 +20,8 @@ #include #include #include -#include "tools/converter/converter_flags.h" #include "tools/optimizer/common/gllo_utils.h" +#include "include/registry/converter_context.h" using mindspore::converter::FmkType; using mindspore::schema::QuantType; diff --git a/mindspore/lite/tools/converter/import/mindspore_importer.cc b/mindspore/lite/tools/converter/import/mindspore_importer.cc index d411f48ffe3..70637b66044 100644 --- a/mindspore/lite/tools/converter/import/mindspore_importer.cc +++ b/mindspore/lite/tools/converter/import/mindspore_importer.cc @@ -42,11 +42,12 @@ constexpr size_t kDependInputNum = 3; constexpr size_t kDependFirstInputIdx = 1; constexpr size_t kTupleGetItemFirstInputIdx = 1; } // namespace -STATUS MindsporeImporter::Mindir2AnfAdjust(const FuncGraphPtr &func_graph, const converter::Flags &flag) { +STATUS MindsporeImporter::Mindir2AnfAdjust(const FuncGraphPtr &func_graph, + const std::shared_ptr ¶m) { MS_ASSERT(func_graph != nullptr); auto primitive_adjust_pass = std::make_shared(); MS_CHECK_TRUE_MSG(primitive_adjust_pass != nullptr, RET_NULL_PTR, "primitive_adjust_pass is nullptr."); - primitive_adjust_pass->SetFmkType(flag.fmk); + primitive_adjust_pass->SetFmkType(param->fmk_type); if (!primitive_adjust_pass->Run(func_graph)) { MS_LOG(ERROR) << "primitive adjust failed."; ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR); @@ -54,14 +55,14 @@ STATUS MindsporeImporter::Mindir2AnfAdjust(const FuncGraphPtr &func_graph, const } auto mindir_adjust_pass = std::make_shared(); MS_CHECK_TRUE_MSG(mindir_adjust_pass != nullptr, RET_NULL_PTR, "mindir_adjust_pass is nullptr."); - mindir_adjust_pass->SetFmkType(flag.fmk); - mindir_adjust_pass->SetTrainFlag(flag.trainModel); + mindir_adjust_pass->SetFmkType(param->fmk_type); + mindir_adjust_pass->SetTrainFlag(param->train_model); if (!mindir_adjust_pass->Run(func_graph)) { MS_LOG(ERROR) << "MindIr adjust failed."; ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR); return RET_ERROR; } - if (!flag.trainModel) { + if (!param->train_model) { auto cast_op_adjust = std::make_shared(); MS_CHECK_TRUE_MSG(cast_op_adjust != nullptr, RET_NULL_PTR, "cast_op_adjust is nullptr."); if (!cast_op_adjust->Run(func_graph)) { @@ -72,7 +73,7 @@ STATUS MindsporeImporter::Mindir2AnfAdjust(const FuncGraphPtr &func_graph, const } auto mindir_control_flow_adjust = std::make_shared(); MS_CHECK_TRUE_MSG(mindir_control_flow_adjust != nullptr, RET_NULL_PTR, "mindir_control_flow_adjust is nullptr."); - mindir_control_flow_adjust->SetFmkType(flag.fmk); + mindir_control_flow_adjust->SetFmkType(param->fmk_type); if (!mindir_control_flow_adjust->Run(func_graph)) { MS_LOG(ERROR) << "MindIR control flow adjust failed."; ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR); @@ -221,35 +222,37 @@ void MindsporeImporter::RemoveUnusedGraphInput(const FuncGraphPtr &func_graph) { } } -FuncGraphPtr MindsporeImporter::ImportMindIR(const converter::Flags &flag, const void *buff, const size_t &size) { +FuncGraphPtr MindsporeImporter::ImportMindIR(const std::shared_ptr ¶m, const void *buff, + const size_t &size) { MindIRLoader mindir_loader; auto func_graph = mindir_loader.LoadMindIR(buff, size); - return CheckAndUpdateFuncGraph(flag, func_graph); + return CheckAndUpdateFuncGraph(param, func_graph); } -FuncGraphPtr MindsporeImporter::ImportMindIR(const converter::Flags &flag) { +FuncGraphPtr MindsporeImporter::ImportMindIR(const std::shared_ptr ¶m) { FuncGraphPtr func_graph; - if (!flag.dec_key.empty()) { + if (!param->decrypt_key.empty()) { unsigned char key[32]; - const size_t key_len = Hex2ByteArray(flag.dec_key, key, 32); + const size_t key_len = Hex2ByteArray(param->decrypt_key, key, 32); if (key_len == 0) { return nullptr; } - MindIRLoader mindir_loader(false, key, key_len, flag.dec_mode, false); - func_graph = mindir_loader.LoadMindIR(flag.modelFile); + MindIRLoader mindir_loader(false, key, key_len, param->decrypt_mode, false); + func_graph = mindir_loader.LoadMindIR(param->model_file); auto ret = memset_s(key, sizeof(key), 0, key_len); if (ret != 0) { MS_LOG(EXCEPTION) << "memset_s error"; } } else { MindIRLoader mindir_loader; - func_graph = mindir_loader.LoadMindIR(flag.modelFile); + func_graph = mindir_loader.LoadMindIR(param->model_file); } - return CheckAndUpdateFuncGraph(flag, func_graph); + return CheckAndUpdateFuncGraph(param, func_graph); } -FuncGraphPtr MindsporeImporter::CheckAndUpdateFuncGraph(const converter::Flags &flag, FuncGraphPtr func_graph) { +FuncGraphPtr MindsporeImporter::CheckAndUpdateFuncGraph(const std::shared_ptr ¶m, + FuncGraphPtr func_graph) { if (func_graph == nullptr) { MS_LOG(ERROR) << "get funcGraph failed for fmk:MINDIR"; MS_LOG(ERROR) @@ -299,12 +302,12 @@ FuncGraphPtr MindsporeImporter::CheckAndUpdateFuncGraph(const converter::Flags & MS_LOG(INFO) << "There is no need to adjust and pass graph when in Ascend."; return func_graph; #endif - if ((status = Mindir2AnfAdjust(func_graph, flag)) != RET_OK) { + if ((status = Mindir2AnfAdjust(func_graph, param)) != RET_OK) { MS_LOG(ERROR) << "Mindir2AnfAdjust failed."; ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); return nullptr; } - auto unify_format = std::make_shared(converter::kFmkTypeMs, flag.trainModel); + auto unify_format = std::make_shared(converter::kFmkTypeMs, param->train_model); MS_CHECK_TRUE_MSG(unify_format != nullptr, nullptr, "unify_format is nullptr."); if (!unify_format->Run(func_graph)) { MS_LOG(ERROR) << "Run insert transpose failed."; diff --git a/mindspore/lite/tools/converter/import/mindspore_importer.h b/mindspore/lite/tools/converter/import/mindspore_importer.h index 3c9e75bed7c..3899cf05d4c 100644 --- a/mindspore/lite/tools/converter/import/mindspore_importer.h +++ b/mindspore/lite/tools/converter/import/mindspore_importer.h @@ -20,23 +20,25 @@ #include #include #include -#include "tools/converter/converter_flags.h" +#include #include "load_mindir/load_model.h" +#include "tools/converter/cxx_api/converter_para.h" +#include "include/errorcode.h" namespace mindspore::lite { class MindsporeImporter { public: MindsporeImporter() = default; ~MindsporeImporter() = default; - FuncGraphPtr ImportMindIR(const converter::Flags &flag); - FuncGraphPtr ImportMindIR(const converter::Flags &flag, const void *buff, const size_t &size); + FuncGraphPtr ImportMindIR(const std::shared_ptr ¶m); + FuncGraphPtr ImportMindIR(const std::shared_ptr ¶m, const void *buff, const size_t &size); private: static void RemoveUnusedGraphInput(const FuncGraphPtr &func_graph); STATUS GetFuncGraphOutputName(const CNodePtr &cnode); STATUS TraceOutput(const AnfNodePtr &node); - FuncGraphPtr CheckAndUpdateFuncGraph(const converter::Flags &flag, FuncGraphPtr func_graph); - STATUS Mindir2AnfAdjust(const FuncGraphPtr &func_graph, const converter::Flags &flag); + FuncGraphPtr CheckAndUpdateFuncGraph(const std::shared_ptr ¶m, FuncGraphPtr func_graph); + STATUS Mindir2AnfAdjust(const FuncGraphPtr &func_graph, const std::shared_ptr ¶m); std::vector output_tensor_name_; }; diff --git a/mindspore/lite/tools/converter/import/primitive_adjust.h b/mindspore/lite/tools/converter/import/primitive_adjust.h index 349a252cf71..98bcb02b9df 100644 --- a/mindspore/lite/tools/converter/import/primitive_adjust.h +++ b/mindspore/lite/tools/converter/import/primitive_adjust.h @@ -21,8 +21,8 @@ #include #include #include "backend/common/optimizer/pass.h" -#include "tools/converter/converter_flags.h" #include "tools/optimizer/common/gllo_utils.h" +#include "include/registry/converter_context.h" using mindspore::converter::FmkType; namespace mindspore { diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/infershape_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/graph/infershape_pass.cc index 78e0c83584e..3b7f26a0d84 100644 --- a/mindspore/lite/tools/converter/legacy_optimizer/graph/infershape_pass.cc +++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/infershape_pass.cc @@ -28,7 +28,6 @@ #include "src/runtime/infer_manager.h" #include "src/common/primitive_t_utils.h" #include "tools/common/node_util.h" -#include "tools/converter/converter_flags.h" #include "src/common/string_utils.h" #include "src/common/log_util.h" #include "nnacl/op_base.h" diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/infershape_pass.h b/mindspore/lite/tools/converter/legacy_optimizer/graph/infershape_pass.h index 986c2307571..0d8aeacb556 100644 --- a/mindspore/lite/tools/converter/legacy_optimizer/graph/infershape_pass.h +++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/infershape_pass.h @@ -25,10 +25,8 @@ #include #include "tools/common/graph_util.h" #include "tools/converter/optimizer.h" -#include "tools/converter/converter_flags.h" +#include "include/registry/converter_context.h" -using mindspore::converter::kFmkTypeTf; -using mindspore::schema::TensorT; namespace mindspore { namespace lite { struct InferTensor { @@ -61,7 +59,7 @@ class InferShapePass : public GraphPass { void InitInferTensor(MetaGraphT *graph); int InferSubgraph(const int64_t &subgraph_index, MetaGraphT *graph); - converter::FmkType fmk_type_ = kFmkTypeTf; + converter::FmkType fmk_type_ = converter::kFmkTypeTf; std::vector tensors_ = {}; std::set partial_cnode_inferred_{}; }; diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/set_unused_quant_param_to_default_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/graph/set_unused_quant_param_to_default_pass.cc index 3bc1675464e..f5e8023eaed 100644 --- a/mindspore/lite/tools/converter/legacy_optimizer/graph/set_unused_quant_param_to_default_pass.cc +++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/set_unused_quant_param_to_default_pass.cc @@ -23,7 +23,7 @@ STATUS SetUnusedQuantParamToDefaultPass::Run(schema::MetaGraphT *graph) { for (auto &tensor : graph->allTensors) { bool has_quant_param = false; for (auto &quant_param : tensor->quantParams) { - if (ctx_.fullQuantParam.target_device != quant::NVGPU) { + if (param_->fullQuantParam.target_device != quant::NVGPU) { quant_param->min = 0.0; quant_param->max = 0.0; } diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/set_unused_quant_param_to_default_pass.h b/mindspore/lite/tools/converter/legacy_optimizer/graph/set_unused_quant_param_to_default_pass.h index ec397421aee..28da8677e61 100644 --- a/mindspore/lite/tools/converter/legacy_optimizer/graph/set_unused_quant_param_to_default_pass.h +++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/set_unused_quant_param_to_default_pass.h @@ -17,21 +17,22 @@ #define LITE_UNUSED_QUANT_PARAM_DATA_REMOVE_PASS_H #include #include "tools/converter/optimizer.h" -#include "tools/converter/converter_flags.h" #include "tools/common/graph_util.h" +#include "tools/converter/cxx_api/converter_para.h" + namespace mindspore { namespace lite { class SetUnusedQuantParamToDefaultPass : public GraphPass { public: SetUnusedQuantParamToDefaultPass() {} - explicit SetUnusedQuantParamToDefaultPass(const converter::Flags &ctx) : ctx_(ctx) {} + explicit SetUnusedQuantParamToDefaultPass(const std::shared_ptr ¶m) : param_(param) {} ~SetUnusedQuantParamToDefaultPass() override = default; STATUS Run(schema::MetaGraphT *graph) override; private: - converter::Flags ctx_; + const std::shared_ptr param_; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/main.cc b/mindspore/lite/tools/converter/main.cc deleted file mode 100644 index 25f8c9c7f51..00000000000 --- a/mindspore/lite/tools/converter/main.cc +++ /dev/null @@ -1,57 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#if defined(__linux__) && !defined(Debug) -#include -#endif -#define USE_DEPRECATED_API -#include "tools/converter/converter.h" - -#if defined(__linux__) && !defined(Debug) -void SignalHandler(int sig) { - printf("encounter an unknown error, please verify the input model file or build the debug version\n"); - exit(1); -} -#endif - -namespace mindspore { -extern "C" { -extern void common_log_init(); -} -} // namespace mindspore - -int main(int argc, const char **argv) { -#if defined(__linux__) && !defined(Debug) - signal(SIGSEGV, SignalHandler); - signal(SIGABRT, SignalHandler); - signal(SIGFPE, SignalHandler); - signal(SIGBUS, SignalHandler); -#endif - int ret = 0; -#ifndef Debug - try { -#endif - mindspore::common_log_init(); - ret = mindspore::lite::RunConverter(argc, argv); -#ifndef Debug - } catch (const std::exception &e) { - ret = mindspore::lite::RET_ERROR; - std::cout << e.what() << std::endl; - std::cout << "encounter an unknown error, please verify the input model file or build the debug version\n"; - } -#endif - return ret; -} diff --git a/mindspore/lite/tools/converter/parser/tf/functionalize_cond.h b/mindspore/lite/tools/converter/parser/tf/functionalize_cond.h index ca05bc6324d..3b52497ff85 100644 --- a/mindspore/lite/tools/converter/parser/tf/functionalize_cond.h +++ b/mindspore/lite/tools/converter/parser/tf/functionalize_cond.h @@ -22,11 +22,9 @@ #include #include #include "backend/common/optimizer/pass.h" -#include "tools/converter/converter_flags.h" #include "tools/optimizer/common/gllo_utils.h" #include "tools/converter/parser/tf/functionalize_control_op_pass.h" -using mindspore::converter::FmkType; namespace mindspore::opt { typedef enum { kThenBranch = 0, kElseBranch = 1 } BranchType; diff --git a/mindspore/lite/tools/converter/parser/tf/functionalize_control_op_pass.h b/mindspore/lite/tools/converter/parser/tf/functionalize_control_op_pass.h index 1258b7fa4f4..c3dbd034bf5 100644 --- a/mindspore/lite/tools/converter/parser/tf/functionalize_control_op_pass.h +++ b/mindspore/lite/tools/converter/parser/tf/functionalize_control_op_pass.h @@ -22,9 +22,9 @@ #include #include #include "backend/common/optimizer/pass.h" -#include "tools/converter/converter_flags.h" #include "tools/converter/ops/ops_def.h" #include "tools/optimizer/common/gllo_utils.h" +#include "include/registry/converter_context.h" using mindspore::converter::FmkType; namespace mindspore::opt { diff --git a/mindspore/lite/tools/converter/parser/tf/functionalize_while.h b/mindspore/lite/tools/converter/parser/tf/functionalize_while.h index 100ccfced6b..cb413003363 100644 --- a/mindspore/lite/tools/converter/parser/tf/functionalize_while.h +++ b/mindspore/lite/tools/converter/parser/tf/functionalize_while.h @@ -21,11 +21,9 @@ #include #include #include "backend/common/optimizer/pass.h" -#include "tools/converter/converter_flags.h" #include "tools/optimizer/common/gllo_utils.h" #include "tools/converter/parser/tf/functionalize_control_op_pass.h" -using mindspore::converter::FmkType; namespace mindspore::opt { constexpr const int POS_INVALID = -1; diff --git a/mindspore/lite/tools/converter/parser/tflite/CMakeLists.txt b/mindspore/lite/tools/converter/parser/tflite/CMakeLists.txt index d7ddd976574..141aec657b0 100644 --- a/mindspore/lite/tools/converter/parser/tflite/CMakeLists.txt +++ b/mindspore/lite/tools/converter/parser/tflite/CMakeLists.txt @@ -3,6 +3,7 @@ merge_parser(${CMAKE_CURRENT_SOURCE_DIR} ${CMAKE_BINARY_DIR}/tools/converter/par file(GLOB_RECURSE TFLITE_SRC_LIST ${CMAKE_BINARY_DIR}/tools/converter/parser/tflite/tflite_op_parser.cc) set_property(SOURCE ${TFLITE_SRC_LIST} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_LITE) +set_property(SOURCE ${TFLITE_SRC_LIST} PROPERTY COMPILE_DEFINITIONS FLATBUFFERS_LOCALE_INDEPENDENT=0) add_library(tflite_parser_mid OBJECT ${TFLITE_SRC_LIST} ) diff --git a/mindspore/lite/tools/converter/quantizer/bias_correction_strategy.cc b/mindspore/lite/tools/converter/quantizer/bias_correction_strategy.cc index 339afd2c29b..79cef79e0fe 100644 --- a/mindspore/lite/tools/converter/quantizer/bias_correction_strategy.cc +++ b/mindspore/lite/tools/converter/quantizer/bias_correction_strategy.cc @@ -313,7 +313,7 @@ int BiasCorrectionStrategy::Int8Inference(const MSKernelCallBack &before_call_ba MS_LOG(ERROR) << "New model failed."; return RET_ERROR; } - auto ret = BuildModelByFuncGraph(int8_model, quant_func_graph, flags_); + auto ret = BuildModelByFuncGraph(int8_model, quant_func_graph, param_); if (ret != kSuccess) { MS_LOG(ERROR) << "Build error."; return RET_ERROR; @@ -637,7 +637,7 @@ MSKernelCallBack BiasCorrectionStrategy::GetNVGPUInt8AfterCallBack() { int BiasCorrectionStrategy::DoBiasCorrection(const FuncGraphPtr &quant_func_graph) { int status; - switch (this->flags_.fullQuantParam.target_device) { + switch (param_->fullQuantParam.target_device) { case CPU: status = DoCPUBiasCorrection(quant_func_graph); break; @@ -645,8 +645,7 @@ int BiasCorrectionStrategy::DoBiasCorrection(const FuncGraphPtr &quant_func_grap status = DoNVGPUBiasCorrection(quant_func_graph); break; default: - MS_LOG(ERROR) << "Unsupported target device " << this->flags_.fullQuantParam.target_device - << " for bias correction."; + MS_LOG(ERROR) << "Unsupported target device " << param_->fullQuantParam.target_device << " for bias correction."; return RET_ERROR; } if (status != RET_OK) { diff --git a/mindspore/lite/tools/converter/quantizer/bias_correction_strategy.h b/mindspore/lite/tools/converter/quantizer/bias_correction_strategy.h index 3960be92589..3916bd58346 100644 --- a/mindspore/lite/tools/converter/quantizer/bias_correction_strategy.h +++ b/mindspore/lite/tools/converter/quantizer/bias_correction_strategy.h @@ -39,10 +39,10 @@ enum CallBackType { class BiasCorrectionStrategy { public: - BiasCorrectionStrategy(const converter::Flags &flags, const std::shared_ptr &calibrator, + BiasCorrectionStrategy(const std::shared_ptr ¶m, const std::shared_ptr &calibrator, const std::shared_ptr &quant_strategy, std::shared_ptr fp32_ms_model, int activation_q_min, int activation_q_max) - : flags_(flags), + : param_(param), calibrator_(calibrator), quant_strategy_(quant_strategy), fp32_ms_model_(fp32_ms_model), @@ -133,7 +133,7 @@ class BiasCorrectionStrategy { } private: - converter::Flags flags_; + const std::shared_ptr param_; std::shared_ptr calibrator_{nullptr}; std::shared_ptr quant_strategy_{nullptr}; std::shared_ptr fp32_ms_model_{nullptr}; diff --git a/mindspore/lite/tools/converter/quantizer/cle_strategy.h b/mindspore/lite/tools/converter/quantizer/cle_strategy.h index 9c961c3896e..9f0f3d5a6be 100644 --- a/mindspore/lite/tools/converter/quantizer/cle_strategy.h +++ b/mindspore/lite/tools/converter/quantizer/cle_strategy.h @@ -28,8 +28,7 @@ namespace mindspore::lite::quant { class CLEStrategy { public: - explicit CLEStrategy(const FuncGraphPtr &func_graph, const converter::Flags &flags) - : func_graph_(func_graph), flags_(flags) {} + explicit CLEStrategy(const FuncGraphPtr &func_graph) : func_graph_(func_graph) {} ~CLEStrategy(); @@ -54,7 +53,6 @@ class CLEStrategy { private: FuncGraphPtr func_graph_ = nullptr; - converter::Flags flags_; CLEPattern *cle_pattern_ = nullptr; }; } // namespace mindspore::lite::quant diff --git a/mindspore/lite/tools/converter/quantizer/debug_info_manager.cc b/mindspore/lite/tools/converter/quantizer/debug_info_manager.cc index 86933b7b59a..7dec64b14d7 100644 --- a/mindspore/lite/tools/converter/quantizer/debug_info_manager.cc +++ b/mindspore/lite/tools/converter/quantizer/debug_info_manager.cc @@ -663,11 +663,11 @@ int DebugInfoManager::SaveOutputInfo(const std::string &file_path) { int DebugInfoManager::StatisticsDataPerRound( const std::shared_ptr &origin, const std::shared_ptr &quant, - const std::map &op_parameters, const converter::Flags &config, + const std::map &op_parameters, const std::shared_ptr ¶m, const std::map &origin_input_tensor_map, const std::map &quant_input_tensor_map, const int &round) { int ret; - auto data_preprocess = config.dataPreProcessParam; + auto data_preprocess = param->dataPreProcessParam; for (auto tensor : origin->GetInputs()) { if (data_preprocess.calibrate_size > 0) { ret = preprocess::PreProcess(data_preprocess, tensor.Name(), round, &tensor); @@ -681,8 +681,8 @@ int DebugInfoManager::StatisticsDataPerRound( } std::cout << "Statistics the original data distribution. Round " << round << std::endl; auto origin_before_callBack = - GetBeforeCallBack(origin_input_tensor_map, op_parameters, true, config.commonQuantParam.debug_mode); - auto origin_after_callBack = GetAfterCallBack(op_parameters, true, config.commonQuantParam.debug_mode); + GetBeforeCallBack(origin_input_tensor_map, op_parameters, true, param->commonQuantParam.debug_mode); + auto origin_after_callBack = GetAfterCallBack(op_parameters, true, param->commonQuantParam.debug_mode); auto origin_outputs = origin->GetOutputs(); auto status = origin->Predict(origin->GetInputs(), &origin_outputs, origin_before_callBack, origin_after_callBack); if (status != kSuccess) { @@ -692,8 +692,8 @@ int DebugInfoManager::StatisticsDataPerRound( std::cout << "Statistics the quant data distribution. Round " << round << std::endl; auto quant_before_callBack = - GetBeforeCallBack(quant_input_tensor_map, op_parameters, false, config.commonQuantParam.debug_mode); - auto quant_after_callBack = GetAfterCallBack(op_parameters, false, config.commonQuantParam.debug_mode); + GetBeforeCallBack(quant_input_tensor_map, op_parameters, false, param->commonQuantParam.debug_mode); + auto quant_after_callBack = GetAfterCallBack(op_parameters, false, param->commonQuantParam.debug_mode); for (auto &tensor : quant->GetInputs()) { auto tensor_data = tensor.MutableData(); CHECK_NULL_RETURN(tensor_data); @@ -722,7 +722,7 @@ std::string DebugInfoManager::CreateFilePath(const std::string &dir_path, const int DebugInfoManager::CompareOriginWithQuant(const std::shared_ptr &origin, const std::shared_ptr &quant, const std::map &op_parameters, - const converter::Flags &config, + const std::shared_ptr ¶m, const mindspore::lite::LiteModel &origin_lite_model, const mindspore::lite::LiteModel &quant_lite_model) { auto begin = GetTimeUs(); @@ -732,27 +732,27 @@ int DebugInfoManager::CompareOriginWithQuant(const std::shared_ptrdataPreProcessParam; // When the calibration data set does not exist, use 1 round of random numbers for comparison int rounds = data_preprocess.calibrate_size > 0 ? data_preprocess.calibrate_size : 1; for (int round = 0; round < rounds; round++) { - ret = StatisticsDataPerRound(origin, quant, op_parameters, config, origin_input_tensor_map, quant_input_tensor_map, + ret = StatisticsDataPerRound(origin, quant, op_parameters, param, origin_input_tensor_map, quant_input_tensor_map, round); if (ret != RET_OK) { MS_LOG(ERROR) << "Statistics Data failed for round: " << round; FreeBuffer(); return RET_ERROR; } - ret = GetClipAndCos(config.commonQuantParam.debug_mode); + ret = GetClipAndCos(param->commonQuantParam.debug_mode); if (ret != RET_OK) { MS_LOG(ERROR) << "Get clip and cos failed."; FreeBuffer(); return ret; } GetOutputInfo(); - if (config.commonQuantParam.debug_mode == quant::DETAIL) { + if (param->commonQuantParam.debug_mode == quant::DETAIL) { auto file_name = "round_" + std::to_string(round) + ".csv"; - auto file_path = CreateFilePath(config.commonQuantParam.debug_info_save_path, file_name); + auto file_path = CreateFilePath(param->commonQuantParam.debug_info_save_path, file_name); ret = SaveInfo(file_path); if (ret != RET_OK) { MS_LOG(ERROR) << "Failed to save debug info to " + file_path; @@ -765,7 +765,7 @@ int DebugInfoManager::CompareOriginWithQuant(const std::shared_ptrcommonQuantParam.debug_info_save_path, file_name); ret = SaveQuantParam(quant_param_save_path); if (ret != RET_OK) { MS_LOG(ERROR) << "Failed to save quant param to " + quant_param_save_path; @@ -773,7 +773,7 @@ int DebugInfoManager::CompareOriginWithQuant(const std::shared_ptrcommonQuantParam.debug_info_save_path, file_name); ret = SaveOutputInfo(output_param_save_path); if (ret != RET_OK) { MS_LOG(ERROR) << "Failed to save output param to " + output_param_save_path; diff --git a/mindspore/lite/tools/converter/quantizer/debug_info_manager.h b/mindspore/lite/tools/converter/quantizer/debug_info_manager.h index 45141aeb72b..180f538a3e9 100644 --- a/mindspore/lite/tools/converter/quantizer/debug_info_manager.h +++ b/mindspore/lite/tools/converter/quantizer/debug_info_manager.h @@ -87,7 +87,8 @@ class DebugInfoManager { public: int CompareOriginWithQuant(const std::shared_ptr &origin, const std::shared_ptr &quant, - const std::map &op_parameters, const converter::Flags &config, + const std::map &op_parameters, + const std::shared_ptr ¶m, const mindspore::lite::LiteModel &origin_lite_model, const mindspore::lite::LiteModel &quant_lite_model); @@ -158,7 +159,8 @@ class DebugInfoManager { int StatisticsDataPerRound(const std::shared_ptr &origin, const std::shared_ptr &quant, - const std::map &op_parameters, const converter::Flags &config, + const std::map &op_parameters, + const std::shared_ptr ¶m, const std::map &origin_input_tensor_map, const std::map &quant_input_tensor_map, const int &round); diff --git a/mindspore/lite/tools/converter/quantizer/dynamic_quantizer.cc b/mindspore/lite/tools/converter/quantizer/dynamic_quantizer.cc index ef2211754b2..51d3d9927dd 100644 --- a/mindspore/lite/tools/converter/quantizer/dynamic_quantizer.cc +++ b/mindspore/lite/tools/converter/quantizer/dynamic_quantizer.cc @@ -22,9 +22,9 @@ namespace mindspore::lite::quant { int DynamicQuantizer::DoQuantize(FuncGraphPtr func_graph) { // Dynamic dont support filters. - flags_.commonQuantParam.min_quant_weight_channel = 0; - flags_.commonQuantParam.min_quant_weight_size = 0; - auto quantizer = WeightQuantizer(flags_); + param_->commonQuantParam.min_quant_weight_channel = 0; + param_->commonQuantParam.min_quant_weight_size = 0; + auto quantizer = WeightQuantizer(param_); const std::set support_weight_quant_nodes = {prim::kPrimMatMulFusion, prim::kPrimGather}; const std::set symmetric_nodes = {prim::kPrimMatMulFusion}; auto ret = quantizer.WeightQuant(func_graph, support_weight_quant_nodes, {}, symmetric_nodes); @@ -36,7 +36,7 @@ int DynamicQuantizer::DoQuantize(FuncGraphPtr func_graph) { const std::set support_dynamic_quant_ops = { prim::kPrimMatMulFusion, }; - ret = manager.InsertDynamicQuantNode(func_graph, support_dynamic_quant_ops, flags_.commonQuantParam.skip_quant_node); + ret = manager.InsertDynamicQuantNode(func_graph, support_dynamic_quant_ops, param_->commonQuantParam.skip_quant_node); if (ret != RET_OK) { MS_LOG(ERROR) << "Insert dynamic quant failed."; return ret; diff --git a/mindspore/lite/tools/converter/quantizer/dynamic_quantizer.h b/mindspore/lite/tools/converter/quantizer/dynamic_quantizer.h index ee2690a9fdc..8a172e7bc07 100644 --- a/mindspore/lite/tools/converter/quantizer/dynamic_quantizer.h +++ b/mindspore/lite/tools/converter/quantizer/dynamic_quantizer.h @@ -41,8 +41,8 @@ namespace mindspore::lite::quant { class DynamicQuantizer : public Quantizer { public: - explicit DynamicQuantizer(const converter::Flags &flags) : Quantizer(flags) { - bit_num_ = flags.commonQuantParam.bit_num; + explicit DynamicQuantizer(const std::shared_ptr ¶m) : Quantizer(param) { + bit_num_ = param->commonQuantParam.bit_num; } ~DynamicQuantizer() = default; diff --git a/mindspore/lite/tools/converter/quantizer/full_quant_quantizer.cc b/mindspore/lite/tools/converter/quantizer/full_quant_quantizer.cc index 95300b229a8..a04ea1c3593 100644 --- a/mindspore/lite/tools/converter/quantizer/full_quant_quantizer.cc +++ b/mindspore/lite/tools/converter/quantizer/full_quant_quantizer.cc @@ -180,7 +180,7 @@ int FullQuantQuantizer::DoParameterNodeQuant(const CNodePtr &cnode, const Parame MS_LOG(ERROR) << op_name << " Do bias quant failed."; return ret; } - } else if (flags_.fullQuantParam.per_channel && CheckNodeInSet(cnode, per_channel_ops_)) { + } else if (param_->fullQuantParam.per_channel && CheckNodeInSet(cnode, per_channel_ops_)) { ret = DoParameterWeightQuant(cnode, input_node, primitive, input_index, true); if (ret != RET_OK) { MS_LOG(ERROR) << op_name << " Do bias quant failed."; @@ -443,7 +443,7 @@ void FullQuantQuantizer::InitKirinConfig() { weight_channel_symmetric_ = true; weight_layer_symmetric_ = false; support_int8_ops_ = {prim::kPrimConv2DFusion, prim::kPrimFullConnection}; - flags_.fullQuantParam.bias_correction = false; + param_->fullQuantParam.bias_correction = false; per_channel_ops_ = {prim::kPrimConv2DFusion}; } @@ -457,7 +457,7 @@ void FullQuantQuantizer::InitNvGpuConfig() { support_int8_ops_ = {prim::kPrimConv2DFusion, prim::kPrimMatMul, prim::kPrimActivation, prim::kPrimConv2dTransposeFusion}; per_channel_ops_ = {}; - flags_.fullQuantParam.bias_correction = false; + param_->fullQuantParam.bias_correction = false; } void FullQuantQuantizer::InitQMinMax() { @@ -509,7 +509,7 @@ int FullQuantQuantizer::MarkQuantNode(const FuncGraphPtr &func_graph) { } int FullQuantQuantizer::InitDeviceConfig(const FuncGraphPtr &func_graph) { - switch (flags_.fullQuantParam.target_device) { + switch (param_->fullQuantParam.target_device) { case CPU: InitCpuConfig(); break; @@ -520,18 +520,18 @@ int FullQuantQuantizer::InitDeviceConfig(const FuncGraphPtr &func_graph) { InitNvGpuConfig(); break; default: - MS_LOG(ERROR) << " Unsupported device " << flags_.fullQuantParam.target_device; + MS_LOG(ERROR) << " Unsupported device " << param_->fullQuantParam.target_device; return RET_ERROR; break; } InitQMinMax(); calibrator_ = std::make_shared(this->bit_num_, activation_q_max_, activation_q_min_, - this->flags_.fullQuantParam.activation_quant_method, - this->flags_.dataPreProcessParam, activation_symmetric_); + param_->fullQuantParam.activation_quant_method, + param_->dataPreProcessParam, activation_symmetric_); MSLITE_CHECK_PTR(calibrator_); - quant_strategy_ = std::make_unique(flags_.commonQuantParam.min_quant_weight_size, - flags_.commonQuantParam.min_quant_weight_channel, - flags_.commonQuantParam.skip_quant_node); + quant_strategy_ = std::make_unique(param_->commonQuantParam.min_quant_weight_size, + param_->commonQuantParam.min_quant_weight_channel, + param_->commonQuantParam.skip_quant_node); CHECK_NULL_RETURN(quant_strategy_); auto ret = MarkQuantNode(func_graph); if (ret != RET_OK) { @@ -593,7 +593,7 @@ int FullQuantQuantizer::DoInference(CollectType collect_type) { int FullQuantQuantizer::DoQuantize(FuncGraphPtr func_graph) { MS_ASSERT(func_graph != nullptr); MS_LOG(INFO) << "start to parse config file"; - if (flags_.dataPreProcessParam.calibrate_path.empty()) { + if (param_->dataPreProcessParam.calibrate_path.empty()) { MS_LOG(ERROR) << "calibrate path must pass. The format is input_name_1:input_1_dir,input_name_2:input_2_dir."; return RET_INPUT_PARAM_INVALID; } @@ -611,7 +611,7 @@ int FullQuantQuantizer::DoQuantize(FuncGraphPtr func_graph) { MS_LOG(ERROR) << "New model failed."; return RET_ERROR; } - auto ret = BuildModelByFuncGraph(fp32_ms_model_, func_graph, flags_); + auto ret = BuildModelByFuncGraph(fp32_ms_model_, func_graph, param_); if (ret != mindspore::kSuccess) { MS_LOG(ERROR) << "Build model failed."; return RET_ERROR; @@ -623,7 +623,7 @@ int FullQuantQuantizer::DoQuantize(FuncGraphPtr func_graph) { return status; } - if (flags_.fullQuantParam.activation_quant_method == KL) { + if (param_->fullQuantParam.activation_quant_method == KL) { status = QuantWithKL(); if (status != RET_OK) { MS_LOG(ERROR) << "Quant with KL failed."; @@ -649,9 +649,9 @@ int FullQuantQuantizer::DoQuantize(FuncGraphPtr func_graph) { } } - if (this->flags_.fullQuantParam.bias_correction) { + if (param_->fullQuantParam.bias_correction) { MS_LOG(INFO) << "do bias correction"; - BiasCorrectionStrategy strategy(flags_, calibrator_, quant_strategy_, fp32_ms_model_, activation_q_min_, + BiasCorrectionStrategy strategy(param_, calibrator_, quant_strategy_, fp32_ms_model_, activation_q_min_, activation_q_max_); status = strategy.DoBiasCorrection(func_graph); if (status != RET_OK) { diff --git a/mindspore/lite/tools/converter/quantizer/full_quant_quantizer.h b/mindspore/lite/tools/converter/quantizer/full_quant_quantizer.h index 830fdaf417e..bb4409deeb7 100644 --- a/mindspore/lite/tools/converter/quantizer/full_quant_quantizer.h +++ b/mindspore/lite/tools/converter/quantizer/full_quant_quantizer.h @@ -39,8 +39,8 @@ namespace mindspore::lite::quant { class FullQuantQuantizer : public Quantizer { public: - explicit FullQuantQuantizer(const converter::Flags &flags) : Quantizer(flags) { - bit_num_ = flags.commonQuantParam.bit_num; + explicit FullQuantQuantizer(const std::shared_ptr ¶m) : Quantizer(param) { + bit_num_ = param_->commonQuantParam.bit_num; } ~FullQuantQuantizer() override; diff --git a/mindspore/lite/tools/converter/quantizer/parameter_tunner.cc b/mindspore/lite/tools/converter/quantizer/parameter_tunner.cc index 0fe18db3776..719c607aa1f 100644 --- a/mindspore/lite/tools/converter/quantizer/parameter_tunner.cc +++ b/mindspore/lite/tools/converter/quantizer/parameter_tunner.cc @@ -44,12 +44,12 @@ MinMax ParameterOptimizer::GetFineTuneRange(std::vector *candidate_scales return min_max; } -int ParameterOptimizer::CloneFuncGraph(const FuncGraphPtr &func_graph, converter::Flags *flags, +int ParameterOptimizer::CloneFuncGraph(const FuncGraphPtr &func_graph, const std::shared_ptr ¶m, FuncGraphPtr *func_graph_bak) { CHECK_NULL_RETURN(func_graph_bak); - CHECK_NULL_RETURN(flags); + CHECK_NULL_RETURN(param); std::map cloned_func_graph; - *func_graph_bak = lite::CloneFuncGraph(func_graph, flags, &cloned_func_graph); + *func_graph_bak = lite::CloneFuncGraph(func_graph, param, &cloned_func_graph); CHECK_NULL_RETURN(*func_graph_bak); static auto root_func_manager = Manage(*func_graph_bak); std::set all_func_graphs = {}; @@ -60,11 +60,12 @@ int ParameterOptimizer::CloneFuncGraph(const FuncGraphPtr &func_graph, converter return RET_OK; } -int ParameterOptimizer::WeightQuantModelInference(const FuncGraphPtr &func_graph, converter::Flags *flags, +int ParameterOptimizer::WeightQuantModelInference(const FuncGraphPtr &func_graph, + const std::shared_ptr ¶m, std::shared_ptr origin_model, int origin_model_size, - const InferenceParam ¶m, double *init_scale, + const InferenceParam &infer_param, double *init_scale, std::vector *candidate_scales, bool is_run_all) { - CHECK_NULL_RETURN(flags); + CHECK_NULL_RETURN(param); CHECK_NULL_RETURN(origin_model); CHECK_NULL_RETURN(init_scale); CHECK_NULL_RETURN(candidate_scales); @@ -75,18 +76,18 @@ int ParameterOptimizer::WeightQuantModelInference(const FuncGraphPtr &func_graph float best_compress_cos_sim = 0.0f; int best_compress_model_size = 0; size_t over_error_count = 0; - for (size_t round = 0; round < param.rounds; round++) { - auto scale = param.start_scale + round * param.step; - flags->commonQuantParam.quant_type = schema::QuantType_QUANT_WEIGHT; + for (size_t round = 0; round < infer_param.rounds; round++) { + auto scale = infer_param.start_scale + round * infer_param.step; + param->commonQuantParam.quant_type = schema::QuantType_QUANT_WEIGHT; FuncGraphPtr func_graph_bak; - auto ret = CloneFuncGraph(func_graph, flags, &func_graph_bak); + auto ret = CloneFuncGraph(func_graph, param, &func_graph_bak); if (ret != RET_OK) { MS_LOG(ERROR) << "Clone FuncGraph failed."; return ret; } // quant - auto quantizer = std::make_unique(*flags); + auto quantizer = std::make_unique(param); CHECK_NULL_RETURN(quantizer); auto status = quantizer->DoQuantize(func_graph_bak, scale); if (status != RET_OK) { @@ -98,7 +99,7 @@ int ParameterOptimizer::WeightQuantModelInference(const FuncGraphPtr &func_graph int weight_quant_size; auto weight_quant_model = std::make_shared(); CHECK_NULL_RETURN(weight_quant_model); - auto build_status = BuildModelByFuncGraph(weight_quant_model, func_graph_bak, *flags, &weight_quant_size); + auto build_status = BuildModelByFuncGraph(weight_quant_model, func_graph_bak, param, &weight_quant_size); if (build_status != kSuccess) { MS_LOG(WARNING) << "build model failed!"; continue; @@ -155,28 +156,29 @@ int ParameterOptimizer::WeightQuantModelInference(const FuncGraphPtr &func_graph return RET_OK; } -int ParameterOptimizer::OriginModelInference(const FuncGraphPtr &func_graph, converter::Flags *flags, +int ParameterOptimizer::OriginModelInference(const FuncGraphPtr &func_graph, + const std::shared_ptr ¶m, std::shared_ptr origin_model, int *origin_model_size) { - CHECK_NULL_RETURN(flags); + CHECK_NULL_RETURN(param); CHECK_NULL_RETURN(origin_model); CHECK_NULL_RETURN(origin_model_size); FuncGraphPtr func_graph_bak; - auto ret = CloneFuncGraph(func_graph, flags, &func_graph_bak); + auto ret = CloneFuncGraph(func_graph, param, &func_graph_bak); if (ret != RET_OK) { MS_LOG(ERROR) << "Clone FuncGraph failed."; return RET_ERROR; } - flags->commonQuantParam.quant_type = schema::QuantType_QUANT_NONE; + param->commonQuantParam.quant_type = schema::QuantType_QUANT_NONE; *origin_model_size = 0; - auto status = BuildModelByFuncGraph(origin_model, func_graph_bak, *flags, origin_model_size); + auto status = BuildModelByFuncGraph(origin_model, func_graph_bak, param, origin_model_size); if (status != kSuccess) { MS_LOG(ERROR) << "build model failed!"; return RET_ERROR; } auto origin_inputs = origin_model->GetInputs(); for (auto input : origin_inputs) { - if (flags->dataPreProcessParam.calibrate_size > 0) { - ret = preprocess::PreProcess(flags->dataPreProcessParam, input.Name(), 0, &input); + if (param->dataPreProcessParam.calibrate_size > 0) { + ret = preprocess::PreProcess(param->dataPreProcessParam, input.Name(), 0, &input); } else { ret = GenerateRandomData(&input); } @@ -195,16 +197,16 @@ int ParameterOptimizer::OriginModelInference(const FuncGraphPtr &func_graph, con return RET_OK; } -int ParameterOptimizer::GridSearchForScale(const FuncGraphPtr &func_graph, converter::Flags *flags, +int ParameterOptimizer::GridSearchForScale(const FuncGraphPtr &func_graph, const std::shared_ptr ¶m, double *init_scale) { - CHECK_NULL_RETURN(flags); + CHECK_NULL_RETURN(param); CHECK_NULL_RETURN(init_scale); double default_init_scale = *init_scale; auto origin_model = std::make_shared(); int origin_model_size; - auto ret = OriginModelInference(func_graph, flags, origin_model, &origin_model_size); + auto ret = OriginModelInference(func_graph, param, origin_model, &origin_model_size); if (ret != RET_OK) { MS_LOG(ERROR) << "Origin Model Inference failed."; return ret; @@ -215,14 +217,14 @@ int ParameterOptimizer::GridSearchForScale(const FuncGraphPtr &func_graph, conve float step = 0.005f; std::vector candidate_scales; - InferenceParam param{}; - param.rounds = giant_rounds; - param.start_scale = start_scale; - param.step = step; - param.thread_num = flags->commonQuantParam.thread_num; + InferenceParam infer_param{}; + infer_param.rounds = giant_rounds; + infer_param.start_scale = start_scale; + infer_param.step = step; + infer_param.thread_num = param->commonQuantParam.thread_num; std::cout << "==========Search with giant step==============\n"; - ret = WeightQuantModelInference(func_graph, flags, origin_model, origin_model_size, param, init_scale, + ret = WeightQuantModelInference(func_graph, param, origin_model, origin_model_size, infer_param, init_scale, &candidate_scales, false); if (ret != RET_OK) { MS_LOG(ERROR) << "Weight quant graph inference failed."; @@ -240,12 +242,12 @@ int ParameterOptimizer::GridSearchForScale(const FuncGraphPtr &func_graph, conve const int baby_step_rounds = 25; step = (min_max.max - min_max.min) / baby_step_rounds; - param.rounds = baby_step_rounds; - param.start_scale = start_scale; - param.step = step; - param.thread_num = flags->commonQuantParam.thread_num; + infer_param.rounds = baby_step_rounds; + infer_param.start_scale = start_scale; + infer_param.step = step; + infer_param.thread_num = param->commonQuantParam.thread_num; std::cout << "==========Search with baby step==============\n"; - ret = WeightQuantModelInference(func_graph, flags, origin_model, origin_model_size, param, init_scale, + ret = WeightQuantModelInference(func_graph, param, origin_model, origin_model_size, infer_param, init_scale, &candidate_scales, true); if (ret != RET_OK) { MS_LOG(ERROR) << "Weight quant graph inference failed."; diff --git a/mindspore/lite/tools/converter/quantizer/parameter_tunner.h b/mindspore/lite/tools/converter/quantizer/parameter_tunner.h index 5685d3499af..7b211ae3a2d 100644 --- a/mindspore/lite/tools/converter/quantizer/parameter_tunner.h +++ b/mindspore/lite/tools/converter/quantizer/parameter_tunner.h @@ -28,7 +28,6 @@ #include "tools/converter/parser/parser_utils.h" #include "include/model.h" #include "base/base.h" -#include "tools/converter/converter_flags.h" namespace mindspore::lite::quant { struct InferenceParam { @@ -44,19 +43,21 @@ class ParameterOptimizer { ~ParameterOptimizer() = default; - int GridSearchForScale(const FuncGraphPtr &func_graph, converter::Flags *flags, double *init_scale); + int GridSearchForScale(const FuncGraphPtr &func_graph, const std::shared_ptr ¶m, + double *init_scale); private: MinMax GetFineTuneRange(std::vector *candidate_scales); - int CloneFuncGraph(const FuncGraphPtr &func_graph, converter::Flags *flags, FuncGraphPtr *func_graph_bak); + int CloneFuncGraph(const FuncGraphPtr &func_graph, const std::shared_ptr ¶m, + FuncGraphPtr *func_graph_bak); - int WeightQuantModelInference(const FuncGraphPtr &func_graph, converter::Flags *flags, + int WeightQuantModelInference(const FuncGraphPtr &func_graph, const std::shared_ptr ¶m, std::shared_ptr origin_model, int origin_model_size, - const InferenceParam ¶m, double *init_scale, std::vector *candidate_scales, - bool is_run_all); + const InferenceParam &infer_param, double *init_scale, + std::vector *candidate_scales, bool is_run_all); - int OriginModelInference(const FuncGraphPtr &func_graph, converter::Flags *flags, + int OriginModelInference(const FuncGraphPtr &func_graph, const std::shared_ptr ¶m, std::shared_ptr origin_model, int *origin_model_size); }; } // namespace mindspore::lite::quant diff --git a/mindspore/lite/tools/converter/quantizer/quantization_optimizer.cc b/mindspore/lite/tools/converter/quantizer/quantization_optimizer.cc index 92ba7d24c92..b4bd4d88e56 100644 --- a/mindspore/lite/tools/converter/quantizer/quantization_optimizer.cc +++ b/mindspore/lite/tools/converter/quantizer/quantization_optimizer.cc @@ -59,8 +59,8 @@ void GetFuncGraphs(const FuncGraphPtr &func_graph, std::set *all_f } } -int DoFullQuant(const FuncGraphPtr &old_graph, const converter::Flags *config) { - auto quantizer = std::make_unique(*config); +int DoFullQuant(const FuncGraphPtr &old_graph, const std::shared_ptr ¶m) { + auto quantizer = std::make_unique(param); if (quantizer == nullptr) { MS_LOG(ERROR) << "New FullQuantQuantizer failed"; return RET_ERROR; @@ -73,16 +73,16 @@ int DoFullQuant(const FuncGraphPtr &old_graph, const converter::Flags *config) { return RET_OK; } -int DoWeightQuant(const FuncGraphPtr &old_graph, const converter::Flags *config) { - double init_scale = config->mixedBitWeightQuantParam.init_scale; - if (config->commonQuantParam.bit_num == 0 && config->mixedBitWeightQuantParam.auto_tune) { +int DoWeightQuant(const FuncGraphPtr &old_graph, const std::shared_ptr ¶m) { + double init_scale = param->mixedBitWeightQuantParam.init_scale; + if (param->commonQuantParam.bit_num == 0 && param->mixedBitWeightQuantParam.auto_tune) { ParameterOptimizer optimizer; - auto status = optimizer.GridSearchForScale(old_graph, const_cast(config), &init_scale); + auto status = optimizer.GridSearchForScale(old_graph, param, &init_scale); if (status != RET_OK) { MS_LOG(ERROR) << "Grid search with scale failed."; return status; } - auto quantizer = std::make_unique(*config); + auto quantizer = std::make_unique(param); if (quantizer == nullptr) { MS_LOG(ERROR) << "New WeightQuantizer failed"; return RET_ERROR; @@ -93,7 +93,7 @@ int DoWeightQuant(const FuncGraphPtr &old_graph, const converter::Flags *config) return RET_ERROR; } } else { - auto quantizer = std::make_unique(*config); + auto quantizer = std::make_unique(param); if (quantizer == nullptr) { MS_LOG(ERROR) << "New WeightQuantizer failed"; return RET_ERROR; @@ -107,8 +107,8 @@ int DoWeightQuant(const FuncGraphPtr &old_graph, const converter::Flags *config) return RET_OK; } -int DoDynamicQuant(const FuncGraphPtr &old_graph, const converter::Flags *config) { - auto quantizer = std::make_unique(*config); +int DoDynamicQuant(const FuncGraphPtr &old_graph, const std::shared_ptr ¶m) { + auto quantizer = std::make_unique(param); if (quantizer == nullptr) { MS_LOG(ERROR) << "New DynamicQuantizer failed"; return RET_ERROR; @@ -121,7 +121,7 @@ int DoDynamicQuant(const FuncGraphPtr &old_graph, const converter::Flags *config return RET_OK; } -lite::LiteModel *ParseLiteModel(const FuncGraphPtr &func_graph, const converter::Flags &flags) { +lite::LiteModel *ParseLiteModel(const FuncGraphPtr &func_graph, const std::shared_ptr ¶m) { auto meta_graph = Export(func_graph, true, true); if (meta_graph == nullptr) { MS_LOG(ERROR) << "Export to meta_graph failed"; @@ -131,7 +131,7 @@ lite::LiteModel *ParseLiteModel(const FuncGraphPtr &func_graph, const converter: // transform GraphDefTransform fb_transform; fb_transform.SetGraphDef(meta_graph); - auto status = fb_transform.Transform(flags); + auto status = fb_transform.Transform(param); if (status != RET_OK) { MS_LOG(ERROR) << "FBTransform model failed"; delete meta_graph; @@ -152,11 +152,11 @@ lite::LiteModel *ParseLiteModel(const FuncGraphPtr &func_graph, const converter: return static_cast(LiteModel::Import((const char *)content, size)); } -int DoQuantDebug(const FuncGraphPtr &old_graph, const converter::Flags *config, +int DoQuantDebug(const FuncGraphPtr &old_graph, const std::shared_ptr ¶m, const std::shared_ptr &origin_model, mindspore::lite::LiteModel *origin_lite_model) { auto quant_model = std::make_shared(); CHECK_NULL_RETURN(quant_model); - auto ret = BuildModelByFuncGraph(quant_model, old_graph, *config); + auto ret = BuildModelByFuncGraph(quant_model, old_graph, param); if (ret != kSuccess) { MS_LOG(ERROR) << "Build model failed"; return RET_ERROR; @@ -165,12 +165,12 @@ int DoQuantDebug(const FuncGraphPtr &old_graph, const converter::Flags *config, FetchOpParameterFromFuncGraph(old_graph, &op_parameters); DebugInfoManager manager; - auto quant_lite_model = ParseLiteModel(old_graph, *config); + auto quant_lite_model = ParseLiteModel(old_graph, param); if (quant_lite_model == nullptr) { MS_LOG(ERROR) << "Parse lite model failed"; return RET_ERROR; } - auto status = manager.CompareOriginWithQuant(origin_model, quant_model, op_parameters, *config, *origin_lite_model, + auto status = manager.CompareOriginWithQuant(origin_model, quant_model, op_parameters, param, *origin_lite_model, *quant_lite_model); auto free_buffer = [&] { for (auto parameter : op_parameters) { @@ -190,17 +190,17 @@ int DoQuantDebug(const FuncGraphPtr &old_graph, const converter::Flags *config, return RET_OK; } -int DoSingleGraphQuantize(const FuncGraphPtr &old_graph, const converter::Flags *config) { - CHECK_NULL_RETURN(config); - if (config->commonQuantParam.quant_type == schema::QuantType_QUANT_NONE) { +int DoSingleGraphQuantize(const FuncGraphPtr &old_graph, const std::shared_ptr ¶m) { + CHECK_NULL_RETURN(param); + if (param->commonQuantParam.quant_type == schema::QuantType_QUANT_NONE) { return RET_OK; } int status; bool per_layer = - config->commonQuantParam.quant_type == schema::QuantType_QUANT_ALL && !config->fullQuantParam.per_channel; + param->commonQuantParam.quant_type == schema::QuantType_QUANT_ALL && !param->fullQuantParam.per_channel; if (per_layer) { - CLEStrategy cle_strategy(old_graph, *config); + CLEStrategy cle_strategy(old_graph); status = cle_strategy.Run(); if (status != RET_OK) { MS_LOG(ERROR) << "do pre process failed!"; @@ -210,43 +210,44 @@ int DoSingleGraphQuantize(const FuncGraphPtr &old_graph, const converter::Flags std::shared_ptr origin; lite::LiteModel *origin_lite_model = nullptr; - if (config->commonQuantParam.is_debug) { // Bak fp32 model for debug - converter::Flags new_flag = *config; - new_flag.commonQuantParam.quant_type = schema::QuantType_QUANT_NONE; + if (param->commonQuantParam.is_debug) { // Bak fp32 model for debug + auto quant_type = param->commonQuantParam.quant_type; + param->commonQuantParam.quant_type = schema::QuantType_QUANT_NONE; origin = std::make_shared(); CHECK_NULL_RETURN(origin); - auto ret = BuildModelByFuncGraph(origin, old_graph, new_flag); + auto ret = BuildModelByFuncGraph(origin, old_graph, param); + param->commonQuantParam.quant_type = quant_type; if (ret != kSuccess) { MS_LOG(ERROR) << "Build model failed"; return RET_ERROR; } - origin_lite_model = ParseLiteModel(old_graph, *config); + origin_lite_model = ParseLiteModel(old_graph, param); if (origin_lite_model == nullptr) { MS_LOG(ERROR) << "Parse lite model failed."; return RET_ERROR; } } - if (config->commonQuantParam.quant_type == schema::QuantType_QUANT_ALL) { // Full Quantization - status = DoFullQuant(old_graph, config); + if (param->commonQuantParam.quant_type == schema::QuantType_QUANT_ALL) { // Full Quantization + status = DoFullQuant(old_graph, param); if (status != RET_OK) { MS_LOG(ERROR) << "Do full quant failed."; return status; } - } else if (config->commonQuantParam.quant_type == schema::QuantType_QUANT_WEIGHT) { // Weight Quantization - status = DoWeightQuant(old_graph, config); + } else if (param->commonQuantParam.quant_type == schema::QuantType_QUANT_WEIGHT) { // Weight Quantization + status = DoWeightQuant(old_graph, param); if (status != RET_OK) { MS_LOG(ERROR) << "Do weight quant failed."; return status; } - } else if (config->commonQuantParam.quant_type == schema::QuantType_QUANT_DYNAMIC) { // Dynamic Quantization - status = DoDynamicQuant(old_graph, config); + } else if (param->commonQuantParam.quant_type == schema::QuantType_QUANT_DYNAMIC) { // Dynamic Quantization + status = DoDynamicQuant(old_graph, param); if (status != RET_OK) { MS_LOG(ERROR) << "Do dynamic quant failed."; return status; } } - if (config->commonQuantParam.is_debug) { - status = DoQuantDebug(old_graph, config, origin, origin_lite_model); + if (param->commonQuantParam.is_debug) { + status = DoQuantDebug(old_graph, param, origin, origin_lite_model); if (status != RET_OK) { MS_LOG(ERROR) << "Do quant debug failed."; return status; @@ -260,7 +261,7 @@ int QuantizationOptimizer::Run(const mindspore::FuncGraphPtr &func_graph) { GetFuncGraphs(func_graph, &all_func_graphs); // Support for multi-subgraph models for (auto &item : all_func_graphs) { - auto status = DoSingleGraphQuantize(item, flags_); + auto status = DoSingleGraphQuantize(item, param_); if (status != RET_OK) { MS_LOG(ERROR) << "Do Quantize failed."; return status; diff --git a/mindspore/lite/tools/converter/quantizer/quantization_optimizer.h b/mindspore/lite/tools/converter/quantizer/quantization_optimizer.h index 8f6a2e8da8c..4d7dd83bdd3 100644 --- a/mindspore/lite/tools/converter/quantizer/quantization_optimizer.h +++ b/mindspore/lite/tools/converter/quantizer/quantization_optimizer.h @@ -18,19 +18,20 @@ #define MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANTIZATION_OPTIMIZER_H #include #include +#include #include #include "backend/common/optimizer/pass.h" -#include "tools/converter/converter_flags.h" +#include "tools/converter/cxx_api/converter_para.h" namespace mindspore::lite::quant { class QuantizationOptimizer { public: - explicit QuantizationOptimizer(converter::Flags *flags) : flags_(flags) {} + explicit QuantizationOptimizer(const std::shared_ptr ¶m) : param_(param) {} ~QuantizationOptimizer() = default; int Run(const FuncGraphPtr &func_graph); private: - converter::Flags *flags_; + const std::shared_ptr ¶m_; }; } // namespace mindspore::lite::quant #endif // MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANTIZATION_OPTIMIZER_H diff --git a/mindspore/lite/tools/converter/quantizer/quantize_util.cc b/mindspore/lite/tools/converter/quantizer/quantize_util.cc index cdaf09be5c1..cc80afabd0e 100644 --- a/mindspore/lite/tools/converter/quantizer/quantize_util.cc +++ b/mindspore/lite/tools/converter/quantizer/quantize_util.cc @@ -255,13 +255,13 @@ std::string NodePrimitiveType(const CNodePtr &cnode) { } Status BuildModelByFuncGraph(const std::shared_ptr &model, const FuncGraphPtr &func_graph, - const converter::Flags &flags) { + const std::shared_ptr ¶m) { int size = 0; - return BuildModelByFuncGraph(model, func_graph, flags, &size); + return BuildModelByFuncGraph(model, func_graph, param, &size); } Status BuildModelByFuncGraph(const std::shared_ptr &model, const FuncGraphPtr &func_graph, - const converter::Flags &flags, int *size) { + const std::shared_ptr ¶m, int *size) { auto meta_graph = Export(func_graph, true, true); if (meta_graph == nullptr) { MS_LOG(ERROR) << "Export to meta_graph failed"; @@ -271,7 +271,7 @@ Status BuildModelByFuncGraph(const std::shared_ptr &model, con // transform GraphDefTransform fb_transform; fb_transform.SetGraphDef(meta_graph); - auto status = fb_transform.Transform(flags); + auto status = fb_transform.Transform(param); if (status != RET_OK) { MS_LOG(ERROR) << "FBTransform model failed"; delete meta_graph; diff --git a/mindspore/lite/tools/converter/quantizer/quantize_util.h b/mindspore/lite/tools/converter/quantizer/quantize_util.h index ca476e0dd73..a28661d2ea2 100644 --- a/mindspore/lite/tools/converter/quantizer/quantize_util.h +++ b/mindspore/lite/tools/converter/quantizer/quantize_util.h @@ -18,9 +18,7 @@ #define MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANTIZE_UTIL_H_ #ifndef _MSC_VER - #include - #endif #include @@ -54,6 +52,7 @@ #include "src/common/file_utils.h" #include "src/common/quant_utils.h" #include "include/api/model.h" +#include "tools/converter/cxx_api/converter_para.h" namespace mindspore::lite::quant { enum WeightQuantType { @@ -188,10 +187,10 @@ int FixedBitQuantFilter(const AnfNodePtr ¶meter_node, const tensor::TensorPt std::string NodePrimitiveType(const CNodePtr &cnode); Status BuildModelByFuncGraph(const std::shared_ptr &model, const FuncGraphPtr &func_graph, - const converter::Flags &flags); + const std::shared_ptr ¶m); Status BuildModelByFuncGraph(const std::shared_ptr &model, const FuncGraphPtr &func_graph, - const converter::Flags &flags, int *size); + const std::shared_ptr ¶m, int *size); mindspore::lite::Tensor *MSTensorToLiteTensor(const mindspore::MSTensor &tensor); diff --git a/mindspore/lite/tools/converter/quantizer/quantizer.h b/mindspore/lite/tools/converter/quantizer/quantizer.h index 5f516aa0319..63554aa833c 100644 --- a/mindspore/lite/tools/converter/quantizer/quantizer.h +++ b/mindspore/lite/tools/converter/quantizer/quantizer.h @@ -24,20 +24,20 @@ #include "ir/func_graph.h" #include "ir/anf.h" #include "base/base.h" -#include "tools/converter/converter_flags.h" #include "tools/converter/quant_param_holder.h" +#include "tools/converter/cxx_api/converter_para.h" namespace mindspore::lite::quant { class Quantizer { public: - explicit Quantizer(const converter::Flags &config) : flags_(config) {} + explicit Quantizer(const std::shared_ptr ¶m) : param_(param) {} virtual ~Quantizer() = default; virtual int DoQuantize(FuncGraphPtr func_graph) = 0; protected: - converter::Flags flags_; + const std::shared_ptr param_; }; } // namespace mindspore::lite::quant #endif diff --git a/mindspore/lite/tools/converter/quantizer/weight_quantizer.cc b/mindspore/lite/tools/converter/quantizer/weight_quantizer.cc index 769cfc24721..072ac82c82e 100644 --- a/mindspore/lite/tools/converter/quantizer/weight_quantizer.cc +++ b/mindspore/lite/tools/converter/quantizer/weight_quantizer.cc @@ -37,7 +37,7 @@ int WeightQuantizer::WeightQuant(const FuncGraphPtr &func_graph, continue; } auto op_name = cnode->fullname_with_scope(); - if (flags_.commonQuantParam.skip_quant_node.find(op_name) != flags_.commonQuantParam.skip_quant_node.end()) { + if (param_->commonQuantParam.skip_quant_node.find(op_name) != param_->commonQuantParam.skip_quant_node.end()) { MS_LOG(INFO) << op_name << " is skip dynamic quant."; continue; } @@ -95,9 +95,9 @@ int WeightQuantizer::DoCNodeWeightQuant(const FuncGraphPtr &func_graph, const CN MS_LOG(INFO) << "This op " << cnode->fullname_with_scope() << " can not quant weight"; continue; } - auto quant_strategy = std::make_unique(flags_.commonQuantParam.min_quant_weight_size, - flags_.commonQuantParam.min_quant_weight_channel, - flags_.commonQuantParam.skip_quant_node); + auto quant_strategy = std::make_unique(param_->commonQuantParam.min_quant_weight_size, + param_->commonQuantParam.min_quant_weight_channel, + param_->commonQuantParam.skip_quant_node); CHECK_NULL_RETURN(quant_strategy); int preferred_dim = GetPreferredDim(cnode, idx - 1, ConvertShapeVectorToInt32(tensor_info->shape())); if (!quant_strategy->CanTensorQuantized(cnode, input, preferred_dim)) { @@ -114,15 +114,15 @@ int WeightQuantizer::DoCNodeWeightQuant(const FuncGraphPtr &func_graph, const CN } auto status = RET_ERROR; if (is_mixed_bit_) { - status = MixedBitQuantFilter(parameter, tensor_info, primitive, flags_.commonQuantParam.quant_type, + status = MixedBitQuantFilter(parameter, tensor_info, primitive, param_->commonQuantParam.quant_type, WeightQuantType::MIXED_BIT_PER_LAYER, type_id_, mixed_bit_init_scale_, idx - 1, preferred_dim, symmetric); } else if (type_id_ == kNumberTypeInt8) { - status = - FixedBitQuantFilter(parameter, tensor_info, primitive, flags_.commonQuantParam.quant_type, q_max, q_min, - bit_num_, tmp_weight_quant_type, type_id_, idx - 1, preferred_dim, symmetric); + status = FixedBitQuantFilter(parameter, tensor_info, primitive, param_->commonQuantParam.quant_type, + q_max, q_min, bit_num_, tmp_weight_quant_type, type_id_, idx - 1, + preferred_dim, symmetric); } else if (type_id_ == kNumberTypeInt16) { - status = FixedBitQuantFilter(parameter, tensor_info, primitive, flags_.commonQuantParam.quant_type, + status = FixedBitQuantFilter(parameter, tensor_info, primitive, param_->commonQuantParam.quant_type, q_max, q_min, bit_num_, tmp_weight_quant_type, type_id_, idx - 1, preferred_dim, symmetric); } diff --git a/mindspore/lite/tools/converter/quantizer/weight_quantizer.h b/mindspore/lite/tools/converter/quantizer/weight_quantizer.h index b393b5e1e7b..8aebd88e373 100644 --- a/mindspore/lite/tools/converter/quantizer/weight_quantizer.h +++ b/mindspore/lite/tools/converter/quantizer/weight_quantizer.h @@ -41,12 +41,12 @@ namespace mindspore::lite::quant { class WeightQuantizer : public Quantizer { public: - explicit WeightQuantizer(const converter::Flags &flags) : Quantizer(flags) { - bit_num_ = flags.commonQuantParam.bit_num; + explicit WeightQuantizer(const std::shared_ptr ¶m) : Quantizer(param) { + bit_num_ = param_->commonQuantParam.bit_num; if (bit_num_ == 0) { type_id_ = kNumberTypeInt16; is_mixed_bit_ = true; - mixed_bit_init_scale_ = flags.mixedBitWeightQuantParam.init_scale; + mixed_bit_init_scale_ = param_->mixedBitWeightQuantParam.init_scale; } // parse param for fixed bit quant. if (!is_mixed_bit_) { diff --git a/mindspore/lite/tools/lite_exporter/fetch_content.h b/mindspore/lite/tools/lite_exporter/fetch_content.h index 10aeba10f1c..086746a49a8 100644 --- a/mindspore/lite/tools/lite_exporter/fetch_content.h +++ b/mindspore/lite/tools/lite_exporter/fetch_content.h @@ -23,8 +23,8 @@ #include "ir/primitive.h" #include "ir/func_graph.h" #include "src/common/utils.h" -#include "tools/converter/converter_flags.h" #include "nnacl/op_base.h" +#include "include/registry/converter_context.h" namespace mindspore { namespace lite { diff --git a/mindspore/lite/tools/mindir_exporter/mindir_serializer.cc b/mindspore/lite/tools/mindir_exporter/mindir_serializer.cc index a906cefa204..30347940f2b 100644 --- a/mindspore/lite/tools/mindir_exporter/mindir_serializer.cc +++ b/mindspore/lite/tools/mindir_exporter/mindir_serializer.cc @@ -93,12 +93,12 @@ int MindIRSerializer::RemoveQuantParameterHolder(FuncGraphPtr func_graph) { return RET_OK; } -int MindIRSerializer::Save(const std::unique_ptr &flag, const FuncGraphPtr &func_graph) { +int MindIRSerializer::Save(const std::shared_ptr ¶m, const FuncGraphPtr &func_graph) { if (func_graph == nullptr) { MS_LOG(ERROR) << "func_graph is nullptr."; return RET_NULL_PTR; } - auto output_file = flag->outputFile; + auto output_file = param->output_file; auto ret = ParserPath(output_file); if (ret != RET_OK) { MS_LOG(ERROR) << "parse path failed."; @@ -403,13 +403,13 @@ int MindIRSerializer::SaveProtoToFile(mind_ir::ModelProto *model_proto, const st } #endif -int MindIRSerialize(const std::unique_ptr &flag, const FuncGraphPtr &func_graph) { +int MindIRSerialize(const std::shared_ptr ¶m, const FuncGraphPtr &func_graph) { #ifndef ENABLE_CLOUD_AND_LITE - if (!flag->export_mindir) { + if (!param->export_mindir) { return RET_OK; } mindspore::lite::MindIRSerializer serializer; - return serializer.Save(flag, func_graph); + return serializer.Save(param, func_graph); #else MS_LOG(INFO) << "No need to serialize mindir when load model online."; return RET_OK; diff --git a/mindspore/lite/tools/mindir_exporter/mindir_serializer.h b/mindspore/lite/tools/mindir_exporter/mindir_serializer.h index a067e1a41f5..70e7fd0e910 100644 --- a/mindspore/lite/tools/mindir_exporter/mindir_serializer.h +++ b/mindspore/lite/tools/mindir_exporter/mindir_serializer.h @@ -24,9 +24,9 @@ #include #include "mindspore/core/ir/func_graph.h" #include "tools/converter/converter_context.h" -#include "tools/converter/converter_flags.h" #include "proto/mind_ir.pb.h" #include "mindspore/core/utils/system/env.h" +#include "tools/converter/cxx_api/converter_para.h" namespace mindspore::lite { #ifndef ENABLE_CLOUD_AND_LITE @@ -40,7 +40,7 @@ class MindIRSerializer { data_fs_ = nullptr; } } - int Save(const std::unique_ptr &flag, const FuncGraphPtr &func_graph); + int Save(const std::shared_ptr ¶m, const FuncGraphPtr &func_graph); private: int ParserPath(const std::string &output_path); @@ -74,6 +74,6 @@ class MindIRSerializer { }; #endif // export func_graph -int MindIRSerialize(const std::unique_ptr &flag, const FuncGraphPtr &func_graph); +int MindIRSerialize(const std::shared_ptr ¶m, const FuncGraphPtr &func_graph); } // namespace mindspore::lite #endif // MINDSPORE_LITE_TOOLS_MINDIR_SERIALIZER_MINDIR_SERIALIZER_H_ diff --git a/mindspore/lite/tools/optimizer/fusion/conv_transform_fusion.h b/mindspore/lite/tools/optimizer/fusion/conv_transform_fusion.h index 8c9d3d37dc9..3ab87948196 100644 --- a/mindspore/lite/tools/optimizer/fusion/conv_transform_fusion.h +++ b/mindspore/lite/tools/optimizer/fusion/conv_transform_fusion.h @@ -19,7 +19,7 @@ #include #include "backend/common/optimizer/optimizer.h" -#include "tools/converter/converter_flags.h" +#include "include/registry/converter_context.h" using mindspore::converter::FmkType; namespace mindspore::opt { diff --git a/mindspore/lite/tools/optimizer/fusion/matmul_activation_fusion.cc b/mindspore/lite/tools/optimizer/fusion/matmul_activation_fusion.cc index 70c78d0cfd1..d37eaee42d4 100644 --- a/mindspore/lite/tools/optimizer/fusion/matmul_activation_fusion.cc +++ b/mindspore/lite/tools/optimizer/fusion/matmul_activation_fusion.cc @@ -35,8 +35,8 @@ const BaseRef MatMulActivationFusion::DefinePattern() const { const AnfNodePtr MatMulActivationFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &) const { // Int8 MatMul Kernel dont support matmul+activation - if (ctx_.commonQuantParam.quant_type == schema::QuantType_QUANT_ALL || - ctx_.commonQuantParam.quant_type == schema::QuantType_QUANT_DYNAMIC) { + if (param_->commonQuantParam.quant_type == schema::QuantType_QUANT_ALL || + param_->commonQuantParam.quant_type == schema::QuantType_QUANT_DYNAMIC) { return nullptr; } if (func_graph == nullptr || node == nullptr) { diff --git a/mindspore/lite/tools/optimizer/fusion/matmul_activation_fusion.h b/mindspore/lite/tools/optimizer/fusion/matmul_activation_fusion.h index 9d2cb08f61b..cf0fc30dd29 100644 --- a/mindspore/lite/tools/optimizer/fusion/matmul_activation_fusion.h +++ b/mindspore/lite/tools/optimizer/fusion/matmul_activation_fusion.h @@ -17,19 +17,18 @@ #ifndef MINDSPORE_LITE_SRC_PASS_FUSION_MATMUL_ACTIVATION_FUSION_H_ #define MINDSPORE_LITE_SRC_PASS_FUSION_MATMUL_ACTIVATION_FUSION_H_ +#include #include #include "backend/common/optimizer/optimizer.h" #include "tools/converter/converter_context.h" -#include "tools/converter/converter_flags.h" +#include "tools/converter/cxx_api/converter_para.h" namespace mindspore { namespace opt { class MatMulActivationFusion : public PatternProcessPass { public: - explicit MatMulActivationFusion(const converter::Flags &ctx, bool multigraph = true) - : PatternProcessPass("MatMulActivationFusion", multigraph) { - ctx_ = ctx; - } + explicit MatMulActivationFusion(const std::shared_ptr ¶m, bool multigraph = true) + : PatternProcessPass("MatMulActivationFusion", multigraph), param_(param) {} ~MatMulActivationFusion() = default; private: @@ -37,7 +36,7 @@ class MatMulActivationFusion : public PatternProcessPass { const AnfNodePtr Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &) const override; private: - converter::Flags ctx_; + const std::shared_ptr param_; }; } // namespace opt } // namespace mindspore diff --git a/mindspore/lite/tools/optimizer/graph/clip_convert_activation_pass.h b/mindspore/lite/tools/optimizer/graph/clip_convert_activation_pass.h index fff2fb36349..7e93f9a84be 100644 --- a/mindspore/lite/tools/optimizer/graph/clip_convert_activation_pass.h +++ b/mindspore/lite/tools/optimizer/graph/clip_convert_activation_pass.h @@ -17,11 +17,8 @@ #ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_CLIP_CONVERT_ACTIVATION_PASS_H_ #define MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_CLIP_CONVERT_ACTIVATION_PASS_H_ #include -#include "tools/converter/converter_flags.h" #include "backend/common/optimizer/pass.h" -using mindspore::converter::FmkType; -using mindspore::schema::QuantType; namespace mindspore::opt { class ClipConvertActivationPass : public Pass { public: diff --git a/mindspore/lite/tools/optimizer/graph/dump_graph.h b/mindspore/lite/tools/optimizer/graph/dump_graph.h index c9657087404..9e6b6f974fb 100644 --- a/mindspore/lite/tools/optimizer/graph/dump_graph.h +++ b/mindspore/lite/tools/optimizer/graph/dump_graph.h @@ -16,6 +16,7 @@ #ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_DUMP_GRAPH_H_ #define MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_DUMP_GRAPH_H_ +#include #include "backend/common/optimizer/pass.h" #include "tools/converter/export_model.h" #include "include/registry/pass_base.h" @@ -25,11 +26,11 @@ namespace mindspore { namespace opt { class DumpGraph : public registry::PassBase, public Pass { public: - explicit DumpGraph(const converter::Flags *flags = nullptr) : Pass("DumpGraph"), flags_(flags) {} + explicit DumpGraph(const std::shared_ptr ¶m) : Pass("DumpGraph"), param_(param) {} ~DumpGraph() = default; bool Run(const FuncGraphPtr &graph) override { MS_CHECK_TRUE_MSG(graph != nullptr, false, "funcGraph is a nullptr."); - if (lite::ExportModel(graph, flags_) != lite::RET_OK) { + if (lite::ExportModel(graph, param_) != lite::RET_OK) { MS_LOG(ERROR) << "dump graph failed."; return false; } @@ -45,7 +46,7 @@ class DumpGraph : public registry::PassBase, public Pass { } private: - const converter::Flags *flags_{nullptr}; + const std::shared_ptr param_; }; } // namespace opt } // namespace mindspore diff --git a/mindspore/lite/tools/optimizer/graph/slice_prepose_pass.h b/mindspore/lite/tools/optimizer/graph/slice_prepose_pass.h index cc3d8a49dcc..86ae477375e 100644 --- a/mindspore/lite/tools/optimizer/graph/slice_prepose_pass.h +++ b/mindspore/lite/tools/optimizer/graph/slice_prepose_pass.h @@ -20,10 +20,10 @@ #include #include #include -#include "tools/converter/converter_flags.h" #include "backend/common/optimizer/pass.h" #include "include/errorcode.h" #include "mindspore/core/ir/manager.h" +#include "include/registry/converter_context.h" using mindspore::converter::FmkType; namespace mindspore::opt { diff --git a/mindspore/lite/tools/optimizer/graph/unused_transpose_node_remove_pass.h b/mindspore/lite/tools/optimizer/graph/unused_transpose_node_remove_pass.h index 10b87e4d11c..6b661da3eae 100644 --- a/mindspore/lite/tools/optimizer/graph/unused_transpose_node_remove_pass.h +++ b/mindspore/lite/tools/optimizer/graph/unused_transpose_node_remove_pass.h @@ -18,7 +18,7 @@ #define MINDSPORE_LITE_SRC_PASS_REMOVE_UNUSED_TRANSPOSE_PASS_H_ #include #include "backend/common/optimizer/pass.h" -#include "tools/converter/converter_flags.h" +#include "include/registry/converter_context.h" using mindspore::converter::FmkType; namespace mindspore::opt {