From 1bc9620f06399c4ad78762827aea7a23461e1689 Mon Sep 17 00:00:00 2001 From: xuanyue Date: Tue, 17 Aug 2021 15:22:21 +0800 Subject: [PATCH] dump graph function adjustment --- .jenkins/check/config/filter_cppcheck.txt | 1 + cmake/package_lite.cmake | 4 -- .../lite/include/registry/pass_registry.h | 9 ++- mindspore/lite/test/CMakeLists.txt | 1 - .../converter/registry/pass_registry_test.cc | 55 ++++++------------- .../lite/tools/converter/anf_transform.cc | 2 + mindspore/lite/tools/converter/converter.cc | 3 - mindspore/lite/tools/converter/dump_graph.cc | 35 ------------ mindspore/lite/tools/converter/dump_graph.h | 31 ----------- .../lite/tools/converter/dump_graph_init.h | 28 ---------- .../lite/tools/converter/export_model.cc | 24 +++----- mindspore/lite/tools/converter/export_model.h | 3 +- .../tools/converter/registry/CMakeLists.txt | 3 +- .../tools/converter/registry/pass_registry.cc | 4 +- .../lite/tools/optimizer/graph/dump_graph.h | 43 +++++++++++++++ 15 files changed, 79 insertions(+), 167 deletions(-) delete mode 100644 mindspore/lite/tools/converter/dump_graph.cc delete mode 100644 mindspore/lite/tools/converter/dump_graph.h delete mode 100644 mindspore/lite/tools/converter/dump_graph_init.h create mode 100644 mindspore/lite/tools/optimizer/graph/dump_graph.h diff --git a/.jenkins/check/config/filter_cppcheck.txt b/.jenkins/check/config/filter_cppcheck.txt index 7c3bb48bbf9..f88ee6c4b53 100644 --- a/.jenkins/check/config/filter_cppcheck.txt +++ b/.jenkins/check/config/filter_cppcheck.txt @@ -35,6 +35,7 @@ # Lite "mindspore/mindspore/core/mindrt/src/thread/" "useStlAlgorithm" "mindspore/mindspore/lite/test/" "syntaxError" +"mindspore/mindspore/lite/test/ut/tools/converter/registry/pass_registry_test.cc" "unknownMacro" "mindspore/mindspore/lite/src/ops/unsqueeze.cc" "useStlAlgorithm" "mindspore/mindspore/lite/tools/common/flag_parser.cc" "useStlAlgorithm" "mindspore/mindspore/lite/tools/common/tensor_util.cc" "useStlAlgorithm" diff --git a/cmake/package_lite.cmake b/cmake/package_lite.cmake index 4b6d97cafd4..fff35b85b26 100644 --- a/cmake/package_lite.cmake +++ b/cmake/package_lite.cmake @@ -330,8 +330,6 @@ elseif(WIN32) DESTINATION ${CONVERTER_ROOT_DIR}/include COMPONENT ${RUNTIME_COMPONENT_NAME}) install(FILES ${TOP_DIR}/mindspore/lite/tools/converter/model_parser.h DESTINATION ${CONVERTER_ROOT_DIR}/include COMPONENT ${RUNTIME_COMPONENT_NAME}) - install(FILES ${TOP_DIR}/mindspore/lite/tools/converter/dump_graph.h - DESTINATION ${CONVERTER_ROOT_DIR}/include COMPONENT ${RUNTIME_COMPONENT_NAME}) install(FILES ${TOP_DIR}/mindspore/lite/tools/converter/ops/ops_def.h DESTINATION ${CONVERTER_ROOT_DIR}/include COMPONENT ${RUNTIME_COMPONENT_NAME}) install(DIRECTORY ${TOP_DIR}/build/mindspore/schema/ DESTINATION ${CONVERTER_ROOT_DIR}/include/schema @@ -462,8 +460,6 @@ else() DESTINATION ${CONVERTER_ROOT_DIR}/include COMPONENT ${RUNTIME_COMPONENT_NAME}) install(FILES ${TOP_DIR}/mindspore/lite/tools/converter/model_parser.h DESTINATION ${CONVERTER_ROOT_DIR}/include COMPONENT ${RUNTIME_COMPONENT_NAME}) - install(FILES ${TOP_DIR}/mindspore/lite/tools/converter/dump_graph.h - DESTINATION ${CONVERTER_ROOT_DIR}/include COMPONENT ${RUNTIME_COMPONENT_NAME}) install(FILES ${TOP_DIR}/mindspore/lite/tools/converter/ops/ops_def.h DESTINATION ${CONVERTER_ROOT_DIR}/include COMPONENT ${RUNTIME_COMPONENT_NAME}) install(DIRECTORY ${TOP_DIR}/mindspore/lite/build/schema/ DESTINATION ${CONVERTER_ROOT_DIR}/include/schema diff --git a/mindspore/lite/include/registry/pass_registry.h b/mindspore/lite/include/registry/pass_registry.h index f6b0ed1f45d..3ed83e95e02 100644 --- a/mindspore/lite/include/registry/pass_registry.h +++ b/mindspore/lite/include/registry/pass_registry.h @@ -48,8 +48,8 @@ class MS_API PassRegistry { /// \brief Constructor of PassRegistry to assign which passes are required for external extension. /// /// \param[in] position Define the place where assigned passes will run. - /// \param[in] assigned Define the names of the passes. - PassRegistry(PassPosition position, const std::vector &assigned); + /// \param[in] names Define the names of the passes. + PassRegistry(PassPosition position, const std::vector &names); /// \brief Destructor of PassRegistrar. ~PassRegistry() = default; @@ -79,9 +79,8 @@ class MS_API PassRegistry { /// \brief Defined assigning macro to assign Passes, which called by user directly. /// /// \param[in] position Define the place where assigned passes will run. -/// \param[in] assigned Define the names of the passes. -#define REG_SCHEDULED_PASS(position, assigned) \ - static mindspore::registry::PassRegistry g_##position(position, assigned); +/// \param[in] names Define the names of the passes. +#define REG_SCHEDULED_PASS(position, names) static mindspore::registry::PassRegistry g_##position(position, names); } // namespace registry } // namespace mindspore diff --git a/mindspore/lite/test/CMakeLists.txt b/mindspore/lite/test/CMakeLists.txt index d86004e7557..e5a88a12f5b 100644 --- a/mindspore/lite/test/CMakeLists.txt +++ b/mindspore/lite/test/CMakeLists.txt @@ -201,7 +201,6 @@ if(MSLITE_ENABLE_CONVERTER) ${LITE_DIR}/tools/converter/converter_flags.cc ${LITE_DIR}/tools/converter/converter.cc ${LITE_DIR}/tools/converter/export_model.cc - ${LITE_DIR}/tools/converter/dump_graph.cc ${LITE_DIR}/tools/converter/optimizer_manager.cc ${LITE_DIR}/tools/converter/parser/parser_utils.cc ${LITE_DIR}/tools/optimizer/common/node_pass_extends.cc diff --git a/mindspore/lite/test/ut/tools/converter/registry/pass_registry_test.cc b/mindspore/lite/test/ut/tools/converter/registry/pass_registry_test.cc index f0211314221..082218ac709 100644 --- a/mindspore/lite/test/ut/tools/converter/registry/pass_registry_test.cc +++ b/mindspore/lite/test/ut/tools/converter/registry/pass_registry_test.cc @@ -51,7 +51,7 @@ namespace opt { // fuse add and add to addn. class Test1Fusion : public Pass { public: - Test1Fusion() : Pass("test1_fusion") {} + Test1Fusion() : Pass("Test1Fusion") {} bool CanFusion(const CNodePtr &cnode) { if (cnode == nullptr) { return false; @@ -94,7 +94,7 @@ class Test1Fusion : public Pass { if (func_graph == nullptr) { return false; } - auto manager = func_graph->manager(); + auto manager = Manage(func_graph); if (manager == nullptr) { return false; } @@ -132,9 +132,9 @@ class Test1Fusion : public Pass { // convert addn to custom op class Test2Fusion : public Pass { public: - Test2Fusion() : Pass("test2_fusion") {} + Test2Fusion() : Pass("Test2Fusion") {} AnfNodePtr CreateCustomOp(const FuncGraphPtr func_graph, const CNodePtr &cnode) { - if (cnode == nullptr) { + if (func_graph == nullptr || cnode == nullptr) { return nullptr; } auto primc = std::make_shared(); @@ -143,7 +143,7 @@ class Test2Fusion : public Pass { } primc->set_type("Custom_AddN"); std::map> custom_attrs; - std::string input_num = std::to_string(3); + std::string input_num = std::to_string(cnode->size() - 1); std::vector input_num_attr(input_num.begin(), input_num.end()); custom_attrs["input_num"] = input_num_attr; std::string op_kind = "custom op"; @@ -162,7 +162,7 @@ class Test2Fusion : public Pass { if (func_graph == nullptr) { return false; } - auto manager = func_graph->manager(); + auto manager = Manage(func_graph); if (manager == nullptr) { return false; } @@ -185,45 +185,22 @@ class Test2Fusion : public Pass { } }; -class TestFusion : public Pass { - public: - TestFusion() : Pass("test_fusion") {} - bool Run(const FuncGraphPtr &func_graph) override { - if (func_graph == nullptr) { - return false; - } - auto manager = Manage(func_graph, true); - if (manager == nullptr) { - return false; - } - auto test1_fusion = std::make_shared(); - if (!test1_fusion->Run(func_graph)) { - return false; - } - auto test2_fusion = std::make_shared(); - if (!test2_fusion->Run(func_graph)) { - return false; - } - return true; - } -}; - -REG_PASS(TestFusion, TestFusion) -REG_SCHEDULED_PASS(POSITION_BEGIN, {"TestFusion"}) +REG_PASS(Test1Fusion, Test1Fusion) +REG_PASS(Test2Fusion, Test2Fusion) +const std::vector schedule = {"Test1Fusion", "Test2Fusion"}; +REG_SCHEDULED_PASS(POSITION_BEGIN, schedule) } // namespace opt TEST_F(PassRegistryTest, TestRegistry) { auto schedule_task = registry::PassRegistry::GetOuterScheduleTask(POSITION_BEGIN); - ASSERT_EQ(schedule_task.size(), 1); + ASSERT_EQ(schedule_task.size(), 2); auto passes = registry::PassRegistry::GetPassFromStoreRoom(schedule_task); - ASSERT_EQ(passes.size(), 1); - auto begin_pass = passes.front(); - ASSERT_NE(begin_pass, nullptr); - auto begin_pass_test = std::dynamic_pointer_cast(begin_pass); - ASSERT_NE(begin_pass_test, nullptr); + ASSERT_EQ(passes.size(), 2); ASSERT_NE(func_graph_, nullptr); - auto res = begin_pass_test->Run(func_graph_); - ASSERT_EQ(res, true); + for (auto &pass : passes) { + auto ret = pass->Run(func_graph_); + ASSERT_EQ(ret, true); + } auto cnode_list = func_graph_->GetOrderedCnodes(); ASSERT_EQ(cnode_list.size(), 2); bool is_custom = opt::CheckPrimitiveType(cnode_list.front(), prim::kPrimCustom); diff --git a/mindspore/lite/tools/converter/anf_transform.cc b/mindspore/lite/tools/converter/anf_transform.cc index e196743a0ef..8f88ff60f0c 100644 --- a/mindspore/lite/tools/converter/anf_transform.cc +++ b/mindspore/lite/tools/converter/anf_transform.cc @@ -59,6 +59,7 @@ #include "tools/optimizer/graph/split_one_pass.h" #include "tools/optimizer/graph/decrease_transpose_algo.h" #include "tools/optimizer/graph/specify_graph_input_format.h" +#include "tools/optimizer/graph/dump_graph.h" #include "tools/converter/quantizer/post_training_quantizer.h" #include "tools/converter/quantizer/quant_cast.h" #include "tools/converter/quantizer/weight_quantizer.h" @@ -401,6 +402,7 @@ void AnfTransform::AppendPassToStoreRoom(const converter::Flags *config) { registry::PassRegistry("ToNHWCFormat", std::make_shared(fmk, is_train)); registry::PassRegistry("SpecifyGraphInputFormat", std::make_shared(config->graphInputFormat)); + registry::PassRegistry("DumpGraph", std::make_shared(config)); } FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &main_graph, const converter::Flags *config) { diff --git a/mindspore/lite/tools/converter/converter.cc b/mindspore/lite/tools/converter/converter.cc index f977f367270..b804fa61a50 100644 --- a/mindspore/lite/tools/converter/converter.cc +++ b/mindspore/lite/tools/converter/converter.cc @@ -26,7 +26,6 @@ #include "src/train/train_populate_parameter.h" #include "include/registry/model_parser_registry.h" #include "src/common/dynamic_library_loader.h" -#include "tools/converter/export_model.h" #include "tools/converter/parser/parser_utils.h" #include "tools/converter/import/mindspore_importer.h" namespace mindspore { @@ -149,8 +148,6 @@ int RunConverter(int argc, const char **argv) { } return status; } - // Init dump graph func - ExportModelInit(flags.get()); // Load graph MS_LOG(DEBUG) << "start reading model file"; Converter cvt; diff --git a/mindspore/lite/tools/converter/dump_graph.cc b/mindspore/lite/tools/converter/dump_graph.cc deleted file mode 100644 index 1f71452590b..00000000000 --- a/mindspore/lite/tools/converter/dump_graph.cc +++ /dev/null @@ -1,35 +0,0 @@ -/** - * Copyright 2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "tools/converter/dump_graph.h" -#include "tools/converter/dump_graph_init.h" -#include "include/errorcode.h" -#include "src/common/log_adapter.h" - -namespace mindspore { -namespace lite { -static GraphDumpFunc graph_dump_interface = nullptr; -void InitDumpGraphFunc(const GraphDumpFunc &graph_dump_func) { graph_dump_interface = graph_dump_func; } - -int DumpGraph(const FuncGraphPtr &func_graph) { - if (graph_dump_interface == nullptr) { - MS_LOG(ERROR) << "graph_dump_interface is nullptr, which is not init."; - return RET_ERROR; - } - return graph_dump_interface(func_graph); -} -} // namespace lite -} // namespace mindspore diff --git a/mindspore/lite/tools/converter/dump_graph.h b/mindspore/lite/tools/converter/dump_graph.h deleted file mode 100644 index 98ee8bdf494..00000000000 --- a/mindspore/lite/tools/converter/dump_graph.h +++ /dev/null @@ -1,31 +0,0 @@ -/** - * Copyright 2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_DUMP_GRAPH_H_ -#define MINDSPORE_LITE_TOOLS_CONVERTER_DUMP_GRAPH_H_ - -#include -#include "include/lite_utils.h" - -namespace mindspore { -class FuncGraph; -using FuncGraphPtr = std::shared_ptr; -namespace lite { -using GraphDumpFunc = std::function; -int MS_API DumpGraph(const FuncGraphPtr &func_graph); -} // namespace lite -} // namespace mindspore -#endif // MINDSPORE_LITE_TOOLS_CONVERTER_DUMP_GRAPH_H_ diff --git a/mindspore/lite/tools/converter/dump_graph_init.h b/mindspore/lite/tools/converter/dump_graph_init.h deleted file mode 100644 index 84ac21719ff..00000000000 --- a/mindspore/lite/tools/converter/dump_graph_init.h +++ /dev/null @@ -1,28 +0,0 @@ -/** - * Copyright 2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_DUMP_GRAPH_INIT_H -#define MINDSPORE_LITE_TOOLS_CONVERTER_DUMP_GRAPH_INIT_H - -#include "tools/converter/dump_graph.h" - -namespace mindspore { -namespace lite { -void MS_API InitDumpGraphFunc(const GraphDumpFunc &graph_dump_func); -} // namespace lite -} // namespace mindspore - -#endif // MINDSPORE_LITE_TOOLS_CONVERTER_DUMP_GRAPH_INIT_H diff --git a/mindspore/lite/tools/converter/export_model.cc b/mindspore/lite/tools/converter/export_model.cc index 298287fd3c7..27e481d8588 100644 --- a/mindspore/lite/tools/converter/export_model.cc +++ b/mindspore/lite/tools/converter/export_model.cc @@ -26,7 +26,6 @@ #include "ir/func_graph.h" #include "tools/anf_exporter/anf_exporter.h" #include "tools/converter/graphdef_transform.h" -#include "tools/converter/dump_graph_init.h" #include "tools/converter/optimizer_manager.h" #include "tools/optimizer/graph/control_flow_pass.h" @@ -34,9 +33,6 @@ namespace mindspore { namespace lite { namespace { using NodesMap = std::map>; -} -static converter::Flags *flags = nullptr; - void CloneGraphInputs(const FuncGraphPtr &origin, const FuncGraphPtr &mirror, NodesMap *origin_map, NodesMap *mirror_map) { MS_ASSERT(origin != nullptr && mirror != nullptr); @@ -53,7 +49,8 @@ void CloneGraphInputs(const FuncGraphPtr &origin, const FuncGraphPtr &mirror, No } } -AnfNodePtr CloneParameterAndValueNode(const CNodePtr &cnode, size_t index, const FuncGraphPtr &mirror_graph) { +AnfNodePtr CloneParameterAndValueNode(const CNodePtr &cnode, size_t index, const FuncGraphPtr &mirror_graph, + const converter::Flags *flags) { MS_ASSERT(cnode != nullptr && mirror_graph != nullptr); if (index >= cnode->size()) { MS_LOG(ERROR) << "input index out of range."; @@ -131,7 +128,7 @@ PrimitivePtr ClonePrimitive(const CNodePtr &cnode) { return prim; } -FuncGraphPtr CloneFuncGraph(const FuncGraphPtr &graph) { +FuncGraphPtr CloneFuncGraph(const FuncGraphPtr &graph, const converter::Flags *flags) { MS_ASSERT(graph != nullptr); auto mirror_graph = std::make_shared(); mirror_graph->set_attrs(graph->attrs()); @@ -157,10 +154,10 @@ FuncGraphPtr CloneFuncGraph(const FuncGraphPtr &graph) { if (mirror_input == nullptr) { if (IsValueNode(origin_input)) { auto sub_func_graph = GetValueNode(origin_input); - auto mirror_sub_graph = CloneFuncGraph(sub_func_graph); + auto mirror_sub_graph = CloneFuncGraph(sub_func_graph, flags); mirror_input = NewValueNode(mirror_sub_graph); } else { - mirror_input = CloneParameterAndValueNode(cnode, i, mirror_graph); + mirror_input = CloneParameterAndValueNode(cnode, i, mirror_graph, flags); } if (mirror_input == nullptr) { MS_LOG(ERROR) << "node input cannot be found."; @@ -184,10 +181,11 @@ FuncGraphPtr CloneFuncGraph(const FuncGraphPtr &graph) { } return mirror_graph; } +} // namespace -STATUS ExportModel(const FuncGraphPtr &graph) { +STATUS ExportModel(const FuncGraphPtr &graph, const converter::Flags *flags) { MS_ASSERT(graph != nullptr && flags != nullptr); - auto mirror_graph = CloneFuncGraph(graph); + auto mirror_graph = CloneFuncGraph(graph, flags); if (mirror_graph == nullptr) { MS_LOG(ERROR) << "Clone funcGraph failed."; return RET_ERROR; @@ -233,11 +231,5 @@ STATUS ExportModel(const FuncGraphPtr &graph) { delete meta_graph; return status; } - -void ExportModelInit(converter::Flags *flag) { - MS_ASSERT(flag != nullptr); - flags = flag; - InitDumpGraphFunc(ExportModel); -} } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/export_model.h b/mindspore/lite/tools/converter/export_model.h index 7268ebe5997..03ab259522b 100644 --- a/mindspore/lite/tools/converter/export_model.h +++ b/mindspore/lite/tools/converter/export_model.h @@ -18,10 +18,11 @@ #define MINDSPORE_LITE_TOOLS_CONVERTER_EXPORT_MODEL_H #include "tools/converter/converter_flags.h" +#include "ir/func_graph.h" namespace mindspore { namespace lite { -void ExportModelInit(converter::Flags *flag); +STATUS ExportModel(const FuncGraphPtr &graph, const converter::Flags *flags); } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/registry/CMakeLists.txt b/mindspore/lite/tools/converter/registry/CMakeLists.txt index ca6c0ddb445..2aa1b1e583c 100644 --- a/mindspore/lite/tools/converter/registry/CMakeLists.txt +++ b/mindspore/lite/tools/converter/registry/CMakeLists.txt @@ -16,8 +16,7 @@ set(REG_SRC ${CONVERT_REG_SRC} ${CORE_DIR}/utils/log_adapter.cc ${CORE_DIR}/utils/status.cc ${CORE_DIR}/gvar/log_adapter_common.cc - ${CORE_DIR}/gvar/logging_level.cc - ${CMAKE_CURRENT_SOURCE_DIR}/../dump_graph.cc) + ${CORE_DIR}/gvar/logging_level.cc) set_property(SOURCE ${REG_SRC} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_LITE) add_library(mslite_converter_plugin SHARED ${REG_SRC}) target_link_libraries(mslite_converter_plugin mindspore::glog) diff --git a/mindspore/lite/tools/converter/registry/pass_registry.cc b/mindspore/lite/tools/converter/registry/pass_registry.cc index f86c30c384a..6e2c0dc6ad6 100644 --- a/mindspore/lite/tools/converter/registry/pass_registry.cc +++ b/mindspore/lite/tools/converter/registry/pass_registry.cc @@ -39,9 +39,9 @@ void RegPass(const std::string &pass_name, const opt::PassPtr &pass) { PassRegistry::PassRegistry(const std::string &pass_name, const opt::PassPtr &pass) { RegPass(pass_name, pass); } -PassRegistry::PassRegistry(PassPosition position, const std::vector &assigned) { +PassRegistry::PassRegistry(PassPosition position, const std::vector &names) { std::unique_lock lock(pass_mutex); - external_assigned_passes[position] = assigned; + external_assigned_passes[position] = names; } std::vector PassRegistry::GetOuterScheduleTask(PassPosition position) { diff --git a/mindspore/lite/tools/optimizer/graph/dump_graph.h b/mindspore/lite/tools/optimizer/graph/dump_graph.h new file mode 100644 index 00000000000..ad1d50828dd --- /dev/null +++ b/mindspore/lite/tools/optimizer/graph/dump_graph.h @@ -0,0 +1,43 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_DUMP_GRAPH_H_ +#define MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_DUMP_GRAPH_H_ + +#include "backend/optimizer/common/pass.h" +#include "tools/converter/export_model.h" + +namespace mindspore { +namespace opt { +class DumpGraph : public Pass { + public: + explicit DumpGraph(const converter::Flags *flags = nullptr) : Pass("DumpGraph"), flags_(flags) {} + ~DumpGraph() = default; + bool Run(const FuncGraphPtr &graph) override { + MS_ASSERT(graph != nullptr); + if (lite::ExportModel(graph, flags_) != lite::RET_OK) { + MS_LOG(ERROR) << "dump graph failed."; + return false; + } + return true; + } + + private: + const converter::Flags *flags_{nullptr}; +}; +} // namespace opt +} // namespace mindspore + +#endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_DUMP_GRAPH_H_