From fc3b3306c6ca22a5d95ba9e50d62a3489a67e69d Mon Sep 17 00:00:00 2001 From: xuanyue Date: Thu, 10 Jun 2021 13:35:59 +0800 Subject: [PATCH] add test ut and rename convert plugin lib --- cmake/package_lite.cmake | 4 +- .../train_lenet/model/prepare_model.sh | 2 +- .../transfer_learning/model/prepare_model.sh | 2 +- mindspore/lite/test/CMakeLists.txt | 5 +- mindspore/lite/test/runtest.sh | 2 + .../registry/model_parser_registry_test.cc | 53 ++++ .../converter/registry/pass_registry_test.cc | 53 ++++ mindspore/lite/tools/converter/CMakeLists.txt | 3 +- mindspore/lite/tools/converter/dump_graph.cc | 35 +++ mindspore/lite/tools/converter/dump_graph.h | 31 +++ .../lite/tools/converter/dump_graph_init.h | 28 ++ .../lite/tools/converter/export_model.cc | 249 ++++++++++++++++++ mindspore/lite/tools/converter/export_model.h | 28 ++ .../tools/converter/registry/CMakeLists.txt | 11 +- 14 files changed, 495 insertions(+), 11 deletions(-) create mode 100644 mindspore/lite/test/ut/tools/converter/registry/model_parser_registry_test.cc create mode 100644 mindspore/lite/test/ut/tools/converter/registry/pass_registry_test.cc create mode 100644 mindspore/lite/tools/converter/dump_graph.cc create mode 100644 mindspore/lite/tools/converter/dump_graph.h create mode 100644 mindspore/lite/tools/converter/dump_graph_init.h create mode 100644 mindspore/lite/tools/converter/export_model.cc create mode 100644 mindspore/lite/tools/converter/export_model.h diff --git a/cmake/package_lite.cmake b/cmake/package_lite.cmake index d1af305b959..326bff736ad 100644 --- a/cmake/package_lite.cmake +++ b/cmake/package_lite.cmake @@ -277,7 +277,7 @@ elseif(WIN32) install(DIRECTORY ${TOP_DIR}/mindspore/lite/include/registry/ DESTINATION ${CONVERTER_ROOT_DIR}/include/registry COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h") install(FILES ${LIB_LIST} DESTINATION ${CONVERTER_ROOT_DIR}/lib COMPONENT ${RUNTIME_COMPONENT_NAME}) - install(FILES ${TOP_DIR}/build/mindspore/tools/converter/registry/libmslite_converter_plugin_reg.dll + install(FILES ${TOP_DIR}/build/mindspore/tools/converter/registry/libmslite_converter_plugin.dll DESTINATION ${CONVERTER_ROOT_DIR}/lib COMPONENT ${RUNTIME_COMPONENT_NAME}) install(FILES ${glog_LIBPATH}/../bin/libglog.dll DESTINATION ${CONVERTER_ROOT_DIR}/lib COMPONENT ${RUNTIME_COMPONENT_NAME}) @@ -349,7 +349,7 @@ else() COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h") install(TARGETS converter_lite RUNTIME DESTINATION ${CONVERTER_ROOT_DIR}/converter COMPONENT ${RUNTIME_COMPONENT_NAME}) - install(FILES ${TOP_DIR}/mindspore/lite/build/tools/converter/registry/libmslite_converter_plugin_reg.so + install(FILES ${TOP_DIR}/mindspore/lite/build/tools/converter/registry/libmslite_converter_plugin.so DESTINATION ${CONVERTER_ROOT_DIR}/lib COMPONENT ${RUNTIME_COMPONENT_NAME}) install(FILES ${glog_LIBPATH}/libglog.so.0.4.0 DESTINATION ${CONVERTER_ROOT_DIR}/lib RENAME libglog.so.0 COMPONENT ${RUNTIME_COMPONENT_NAME}) diff --git a/mindspore/lite/examples/train_lenet/model/prepare_model.sh b/mindspore/lite/examples/train_lenet/model/prepare_model.sh index 527cc8b2dea..34364aa35fb 100755 --- a/mindspore/lite/examples/train_lenet/model/prepare_model.sh +++ b/mindspore/lite/examples/train_lenet/model/prepare_model.sh @@ -14,7 +14,7 @@ CONVERTER="../../../build/tools/converter/converter_lite" if [ ! -f "$CONVERTER" ]; then if ! command -v converter_lite &> /dev/null then - tar -xzf ../../../../../output/mindspore-lite-*-train-linux-x64.tar.gz --strip-components 4 --wildcards --no-anchored converter_lite libglog.so.0 libmslite_converter_plugin_reg.so + tar -xzf ../../../../../output/mindspore-lite-*-train-linux-x64.tar.gz --strip-components 4 --wildcards --no-anchored converter_lite libglog.so.0 libmslite_converter_plugin.so if [ -f ./converter_lite ]; then CONVERTER=./converter_lite else diff --git a/mindspore/lite/examples/transfer_learning/model/prepare_model.sh b/mindspore/lite/examples/transfer_learning/model/prepare_model.sh index b19b600f443..5a0eb011d2b 100755 --- a/mindspore/lite/examples/transfer_learning/model/prepare_model.sh +++ b/mindspore/lite/examples/transfer_learning/model/prepare_model.sh @@ -19,7 +19,7 @@ CONVERTER="../../../build/tools/converter/converter_lite" if [ ! -f "$CONVERTER" ]; then if ! command -v converter_lite &> /dev/null then - tar -xzf ../../../../../output/mindspore-lite-*-train-linux-x64.tar.gz --strip-components 4 --wildcards --no-anchored converter_lite libglog.so.0 libmslite_converter_plugin_reg.so + tar -xzf ../../../../../output/mindspore-lite-*-train-linux-x64.tar.gz --strip-components 4 --wildcards --no-anchored converter_lite libglog.so.0 libmslite_converter_plugin.so if [ -f ./converter_lite ]; then CONVERTER=./converter_lite else diff --git a/mindspore/lite/test/CMakeLists.txt b/mindspore/lite/test/CMakeLists.txt index 2cc7b973a53..36f60854758 100644 --- a/mindspore/lite/test/CMakeLists.txt +++ b/mindspore/lite/test/CMakeLists.txt @@ -169,6 +169,7 @@ if(MSLITE_ENABLE_CONVERTER) add_definitions(-DPRIMITIVE_WRITEABLE) add_definitions(-DUSE_GLOG) file(GLOB_RECURSE TEST_CASE_TFLITE_PARSERS_SRC + ${TEST_DIR}/ut/tools/converter/registry/*.cc ${TEST_DIR}/ut/tools/converter/parser/tflite/*.cc ) set(TEST_LITE_SRC @@ -182,6 +183,8 @@ if(MSLITE_ENABLE_CONVERTER) ${LITE_DIR}/tools/converter/graphdef_transform.cc ${LITE_DIR}/tools/converter/converter_flags.cc ${LITE_DIR}/tools/converter/converter.cc + ${LITE_DIR}/tools/converter/export_model.cc + ${LITE_DIR}/tools/converter/dump_graph.cc ${LITE_DIR}/tools/converter/parser/parser_utils.cc ${LITE_DIR}/tools/optimizer/common/node_pass_extends.cc ${LITE_DIR}/tools/optimizer/common/pass_manager_extends.cc @@ -388,7 +391,7 @@ if(MSLITE_ENABLE_CONVERTER) add_dependencies(lite-test fbs_inner_src) target_link_libraries(lite-test anf_exporter_mid - mslite_converter_plugin_reg + mslite_converter_plugin tflite_parser_mid caffe_parser_mid onnx_parser_mid diff --git a/mindspore/lite/test/runtest.sh b/mindspore/lite/test/runtest.sh index a097718c5ee..080adde0bb4 100755 --- a/mindspore/lite/test/runtest.sh +++ b/mindspore/lite/test/runtest.sh @@ -57,6 +57,8 @@ echo 'run common ut tests' ./lite-test --gtest_filter="TestFullConnectionOpenCL*" ./lite-test --gtest_filter="TestResizeOpenCL*" ./lite-test --gtest_filter="TestSwishOpenCLCI.Fp32CI" +./lite-test --gtest_filter="ModelParserRegistryTest.TestRegistry" +./lite-test --gtest_filter="PassRegistryTest.TestRegistry" # test cases specific for train if [[ $1 == train ]]; then diff --git a/mindspore/lite/test/ut/tools/converter/registry/model_parser_registry_test.cc b/mindspore/lite/test/ut/tools/converter/registry/model_parser_registry_test.cc new file mode 100644 index 00000000000..12d7f5a6f42 --- /dev/null +++ b/mindspore/lite/test/ut/tools/converter/registry/model_parser_registry_test.cc @@ -0,0 +1,53 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "common/common_test.h" +#include "include/registry/model_parser_registry.h" +#include "tools/converter/model_parser.h" +#include "tools/converter/converter_flags.h" + +using mindspore::lite::ModelRegistrar; +using mindspore::lite::converter::Flags; +namespace mindspore { +class ModelParserRegistryTest : public mindspore::CommonTest { + public: + ModelParserRegistryTest() = default; +}; + +class ModelParserTest : public lite::ModelParser { + public: + ModelParserTest() = default; +}; + +lite::ModelParser *TestModelParserCreator() { + auto *parser = new (std::nothrow) ModelParserTest(); + if (parser == nullptr) { + MS_LOG(ERROR) << "new model parser failed"; + return nullptr; + } + return parser; +} +REG_MODEL_PARSER(TEST, TestModelParserCreator); + +TEST_F(ModelParserRegistryTest, TestRegistry) { + auto model_parser = lite::ModelParserRegistry::GetInstance()->GetModelParser("TEST"); + ASSERT_NE(model_parser, nullptr); + Flags flags; + auto func_graph = model_parser->Parse(flags); + ASSERT_EQ(func_graph, nullptr); +} +} // namespace mindspore diff --git a/mindspore/lite/test/ut/tools/converter/registry/pass_registry_test.cc b/mindspore/lite/test/ut/tools/converter/registry/pass_registry_test.cc new file mode 100644 index 00000000000..6872b3143df --- /dev/null +++ b/mindspore/lite/test/ut/tools/converter/registry/pass_registry_test.cc @@ -0,0 +1,53 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "common/common_test.h" +#include "backend/optimizer/common/pass.h" +#include "include/registry/pass_registry.h" + +namespace mindspore { +class PassRegistryTest : public mindspore::CommonTest { + public: + PassRegistryTest() = default; +}; + +namespace opt { +class TestFusion : public Pass { + public: + TestFusion() : Pass("test_fusion") {} + bool Run(const FuncGraphPtr &func_graph) override { return true; } +}; +REG_PASS(POSITION_BEGIN, TestFusion) +REG_PASS(POSITION_END, TestFusion) +} // namespace opt + +TEST_F(PassRegistryTest, TestRegistry) { + auto passes = opt::PassRegistry::GetInstance()->GetPasses(); + ASSERT_EQ(passes.size(), 2); + auto begin_pass = passes[opt::POSITION_BEGIN]; + ASSERT_NE(begin_pass, nullptr); + auto begin_pass_test = std::dynamic_pointer_cast(begin_pass); + ASSERT_NE(begin_pass_test, nullptr); + auto res = begin_pass_test->Run(nullptr); + ASSERT_EQ(res, true); + auto end_pass = passes[opt::POSITION_END]; + ASSERT_NE(end_pass, nullptr); + auto end_pass_test = std::dynamic_pointer_cast(end_pass); + ASSERT_NE(end_pass_test, nullptr); + res = end_pass_test->Run(nullptr); + ASSERT_EQ(res, true); +} +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/CMakeLists.txt b/mindspore/lite/tools/converter/CMakeLists.txt index 53d030baa48..27dc5c09f83 100644 --- a/mindspore/lite/tools/converter/CMakeLists.txt +++ b/mindspore/lite/tools/converter/CMakeLists.txt @@ -18,6 +18,7 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} ${CMAKE_CURRENT_SOURCE_DIR}/anf_transform.cc ${CMAKE_CURRENT_SOURCE_DIR}/graphdef_transform.cc ${CMAKE_CURRENT_SOURCE_DIR}/optimizer.cc + ${CMAKE_CURRENT_SOURCE_DIR}/export_model.cc ${CMAKE_CURRENT_SOURCE_DIR}/../../src/common/file_utils.cc ${CMAKE_CURRENT_SOURCE_DIR}/../../src/common/quant_utils.cc ${CMAKE_CURRENT_SOURCE_DIR}/../common/graph_util.cc @@ -176,7 +177,7 @@ target_link_libraries(converter_lite PRIVATE cpu_ops_mid nnacl_mid cpu_kernel_mid - mslite_converter_plugin_reg + mslite_converter_plugin tflite_parser_mid tf_parser_mid caffe_parser_mid diff --git a/mindspore/lite/tools/converter/dump_graph.cc b/mindspore/lite/tools/converter/dump_graph.cc new file mode 100644 index 00000000000..1f71452590b --- /dev/null +++ b/mindspore/lite/tools/converter/dump_graph.cc @@ -0,0 +1,35 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tools/converter/dump_graph.h" +#include "tools/converter/dump_graph_init.h" +#include "include/errorcode.h" +#include "src/common/log_adapter.h" + +namespace mindspore { +namespace lite { +static GraphDumpFunc graph_dump_interface = nullptr; +void InitDumpGraphFunc(const GraphDumpFunc &graph_dump_func) { graph_dump_interface = graph_dump_func; } + +int DumpGraph(const FuncGraphPtr &func_graph) { + if (graph_dump_interface == nullptr) { + MS_LOG(ERROR) << "graph_dump_interface is nullptr, which is not init."; + return RET_ERROR; + } + return graph_dump_interface(func_graph); +} +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/dump_graph.h b/mindspore/lite/tools/converter/dump_graph.h new file mode 100644 index 00000000000..98ee8bdf494 --- /dev/null +++ b/mindspore/lite/tools/converter/dump_graph.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_DUMP_GRAPH_H_ +#define MINDSPORE_LITE_TOOLS_CONVERTER_DUMP_GRAPH_H_ + +#include +#include "include/lite_utils.h" + +namespace mindspore { +class FuncGraph; +using FuncGraphPtr = std::shared_ptr; +namespace lite { +using GraphDumpFunc = std::function; +int MS_API DumpGraph(const FuncGraphPtr &func_graph); +} // namespace lite +} // namespace mindspore +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_DUMP_GRAPH_H_ diff --git a/mindspore/lite/tools/converter/dump_graph_init.h b/mindspore/lite/tools/converter/dump_graph_init.h new file mode 100644 index 00000000000..84ac21719ff --- /dev/null +++ b/mindspore/lite/tools/converter/dump_graph_init.h @@ -0,0 +1,28 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_DUMP_GRAPH_INIT_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_DUMP_GRAPH_INIT_H + +#include "tools/converter/dump_graph.h" + +namespace mindspore { +namespace lite { +void MS_API InitDumpGraphFunc(const GraphDumpFunc &graph_dump_func); +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_DUMP_GRAPH_INIT_H diff --git a/mindspore/lite/tools/converter/export_model.cc b/mindspore/lite/tools/converter/export_model.cc new file mode 100644 index 00000000000..a5ff69a7ade --- /dev/null +++ b/mindspore/lite/tools/converter/export_model.cc @@ -0,0 +1,249 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tools/converter/export_model.h" +#include +#include +#include +#include +#include +#include "include/errorcode.h" +#include "include/version.h" +#include "ir/func_graph.h" +#include "tools/anf_exporter/anf_exporter.h" +#include "tools/converter/graphdef_transform.h" +#include "tools/converter/dump_graph_init.h" +#include "tools/optimizer/graph/unify_format_pass.h" +#include "tools/optimizer/graph/while_pass.h" +#include "tools/optimizer/graph/if_pass.h" + +namespace mindspore { +namespace lite { +namespace { +using NodesMap = std::map>; +} +static converter::Flags *flags = nullptr; + +void CloneGraphInputs(const FuncGraphPtr &origin, const FuncGraphPtr &mirror, NodesMap *origin_map, + NodesMap *mirror_map) { + MS_ASSERT(origin != nullptr && mirror != nullptr); + MS_ASSERT(origin_map != nullptr && mirror_map != nullptr); + auto origin_inputs = origin->get_inputs(); + for (auto &input : origin_inputs) { + auto mirror_input = mirror->add_parameter(); + if (input->abstract() != nullptr) { + mirror_input->set_abstract(input->abstract()->Clone()); + } + mirror_input->set_name(input->fullname_with_scope()); + (*origin_map)[input->fullname_with_scope()].push_back(input); + (*mirror_map)[input->fullname_with_scope()].push_back(mirror_input); + } +} + +AnfNodePtr CloneParameterAndValueNode(const CNodePtr &cnode, size_t index, const FuncGraphPtr &mirror_graph) { + MS_ASSERT(cnode != nullptr && mirror_graph != nullptr); + if (index >= cnode->size()) { + MS_LOG(ERROR) << "input index out of range."; + return nullptr; + } + auto node = cnode->input(index); + if (utils::isa(node)) { + MS_LOG(ERROR) << "this func cannot copy cnode."; + return nullptr; + } + if (utils::isa(node)) { + auto value_node = node->cast(); + auto value_ptr = value_node->value(); + MS_ASSERT(value_ptr != nullptr); + if (utils::isa(value_ptr)) { + std::shared_ptr mirror_monad; + if (utils::isa(value_ptr)) { + mirror_monad = std::make_shared(); + } else { + mirror_monad = std::make_shared(); + } + auto monad_abs = mirror_monad->ToAbstract(); + auto mirror_value_node = NewValueNode(mirror_monad); + mirror_value_node->set_abstract(monad_abs); + return mirror_value_node; + } + } + DataInfo data_info; + STATUS status; + if (utils::isa(node)) { + status = FetchDataFromParameterNode(cnode, index, flags->fmk, flags->trainModel, &data_info); + } else if (utils::isa(node)) { + status = FetchDataFromValueNode(cnode, index, flags->fmk, flags->trainModel, &data_info); + } else { + status = RET_ERROR; + } + if (status != RET_OK && status != RET_NO_CHANGE) { + MS_LOG(ERROR) << "fetch data failed."; + return nullptr; + } + if (opt::CheckPrimitiveType(cnode, prim::kPrimTupleGetItem) && !data_info.data_.empty()) { + return NewValueNode(MakeValue(*reinterpret_cast(data_info.data_.data()))); + } + ShapeVector shape_vec(data_info.shape_.begin(), data_info.shape_.end()); + auto tensor_info = std::make_shared(static_cast(data_info.data_type_), shape_vec); + if (!data_info.data_.empty()) { + auto tensor_data = reinterpret_cast(tensor_info->data_c()); + if (memcpy_s(tensor_data, tensor_info->data().nbytes(), data_info.data_.data(), data_info.data_.size()) != EOK) { + MS_LOG(ERROR) << "memcpy_s failed"; + return nullptr; + } + } + auto mirror_parameter = mirror_graph->add_parameter(); + if (node->abstract() != nullptr) { + mirror_parameter->set_abstract(node->abstract()->Clone()); + } + mirror_parameter->set_name(node->fullname_with_scope()); + mirror_parameter->set_default_param(tensor_info); + return mirror_parameter; +} + +PrimitivePtr ClonePrimitive(const CNodePtr &cnode) { + MS_ASSERT(cnode != nullptr); + auto origin_prim = GetValueNode(cnode->input(0)); + MS_ASSERT(origin_prim != nullptr); + PrimitivePtr prim; + auto op_primc_fns = ops::OpPrimCRegister::GetInstance().GetPrimCMap(); + if (op_primc_fns.find(origin_prim->name()) != op_primc_fns.end()) { + prim = op_primc_fns[origin_prim->name()](); + } else { + prim = std::make_shared(origin_prim->name()); + prim->set_instance_name(origin_prim->name()); + } + prim->SetAttrs(origin_prim->attrs()); + return prim; +} + +FuncGraphPtr CloneFuncGraph(const FuncGraphPtr &graph) { + MS_ASSERT(graph != nullptr); + auto mirror_graph = std::make_shared(); + mirror_graph->set_attrs(graph->attrs()); + NodesMap origin_nodes; + NodesMap mirror_nodes; + CloneGraphInputs(graph, mirror_graph, &origin_nodes, &mirror_nodes); + auto node_list = TopoSort(graph->get_return()); + for (auto &node : node_list) { + if (!utils::isa(node)) { + continue; + } + auto cnode = node->cast(); + auto mirrro_prim = ClonePrimitive(cnode); + std::vector node_inputs; + for (size_t i = 1; i < cnode->size(); ++i) { + auto origin_input = cnode->input(i); + AnfNodePtr mirror_input = nullptr; + auto value = origin_nodes[origin_input->fullname_with_scope()]; + auto iter = std::find(value.begin(), value.end(), origin_input); + if (iter != value.end()) { + mirror_input = mirror_nodes[origin_input->fullname_with_scope()][iter - value.begin()]; + } + if (mirror_input == nullptr) { + if (IsValueNode(origin_input)) { + auto sub_func_graph = GetValueNode(origin_input); + auto mirror_sub_graph = CloneFuncGraph(sub_func_graph); + mirror_input = NewValueNode(mirror_sub_graph); + } else { + mirror_input = CloneParameterAndValueNode(cnode, i, mirror_graph); + } + if (mirror_input == nullptr) { + MS_LOG(ERROR) << "node input cannot be found."; + return nullptr; + } + origin_nodes[origin_input->fullname_with_scope()].push_back(origin_input); + mirror_nodes[origin_input->fullname_with_scope()].push_back(mirror_input); + } + node_inputs.push_back(mirror_input); + } + auto mirror_cnode = mirror_graph->NewCNode(mirrro_prim, node_inputs); + mirror_cnode->set_fullname_with_scope(cnode->fullname_with_scope()); + if (cnode->abstract() != nullptr) { + mirror_cnode->set_abstract(cnode->abstract()->Clone()); + } + origin_nodes[cnode->fullname_with_scope()].push_back(cnode); + mirror_nodes[cnode->fullname_with_scope()].push_back(mirror_cnode); + if (opt::CheckPrimitiveType(cnode, prim::kPrimReturn)) { + mirror_graph->set_return(mirror_cnode); + } + } + return mirror_graph; +} + +STATUS ExportModel(const FuncGraphPtr &graph) { + MS_ASSERT(graph != nullptr && flags != nullptr); + auto mirror_graph = CloneFuncGraph(graph); + if (mirror_graph == nullptr) { + MS_LOG(ERROR) << "Clone funcGraph failed."; + return RET_ERROR; + } + (void)Manage(mirror_graph, true); + auto format_pass = std::make_shared(); + format_pass->Init(flags->fmk, flags->trainModel); + if (!format_pass->Run(mirror_graph)) { + MS_LOG(ERROR) << "Run format pass failed."; + return RET_ERROR; + } + auto optimizer = std::make_shared(); + auto graph_pm = std::make_shared("anf graph pass manager", true); + if (flags->fmk == lite::converter::FmkType_TFLITE || flags->fmk == lite::converter::FmkType_TF || + flags->fmk == lite::converter::FmkType_ONNX) { + graph_pm->AddPass(std::make_shared()); + graph_pm->AddPass(std::make_shared()); + } + optimizer->AddPassManager(graph_pm); + if (optimizer->Optimize(mirror_graph) == nullptr) { + MS_LOG(ERROR) << "run graph pass failed."; + return RET_ERROR; + } + auto meta_graph = Export(mirror_graph); + if (meta_graph == nullptr) { + MS_LOG(ERROR) << "Export to meta graph return nullptr"; + return RET_ERROR; + } + auto metagraph_transform = std::make_unique(); + metagraph_transform->SetGraphDef(meta_graph); + auto status = metagraph_transform->Transform(*flags); + if (status != RET_OK) { + MS_LOG(ERROR) << "Transform meta graph failed " << status; + return RET_ERROR; + } + meta_graph->version = Version(); + status = Storage::Save(*meta_graph, "model"); + std::ostringstream oss; + if (status != RET_OK) { + oss << "SAVE GRAPH FAILED:" << status << " " << lite::GetErrorInfo(status); + MS_LOG(ERROR) << oss.str(); + std::cout << oss.str() << std::endl; + return status; + } + + delete meta_graph; + oss << "CONVERT RESULT SUCCESS:" << status; + MS_LOG(INFO) << oss.str(); + std::cout << oss.str() << std::endl; + return status; +} + +void ExportModelInit(converter::Flags *flag) { + MS_ASSERT(flag != nullptr); + flags = flag; + InitDumpGraphFunc(ExportModel); +} +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/export_model.h b/mindspore/lite/tools/converter/export_model.h new file mode 100644 index 00000000000..46ab469e6b9 --- /dev/null +++ b/mindspore/lite/tools/converter/export_model.h @@ -0,0 +1,28 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_EXPORT_MODEL_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_EXPORT_MODEL_H + +#include "tools/converter/converter_flags.h" + +namespace mindspore { +namespace lite { +void ExportModelInit(lite::converter::Flags *flag); +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_EXPORT_MODEL_H diff --git a/mindspore/lite/tools/converter/registry/CMakeLists.txt b/mindspore/lite/tools/converter/registry/CMakeLists.txt index d344a8d3e51..8439386ce14 100644 --- a/mindspore/lite/tools/converter/registry/CMakeLists.txt +++ b/mindspore/lite/tools/converter/registry/CMakeLists.txt @@ -8,9 +8,10 @@ set(REG_SRC ${CONVERT_REG_SRC} ${KERNEL_REG_SRC} ${CORE_DIR}/utils/log_adapter.cc ${CORE_DIR}/gvar/log_adapter_common.cc - ${CORE_DIR}/gvar/logging_level.cc) + ${CORE_DIR}/gvar/logging_level.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../dump_graph.cc) set_property(SOURCE ${REG_SRC} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_LITE) -add_library(mslite_converter_plugin_reg SHARED ${REG_SRC}) -target_link_libraries(mslite_converter_plugin_reg mindspore::glog) -add_dependencies(mslite_converter_plugin_reg fbs_src) -add_dependencies(mslite_converter_plugin_reg fbs_inner_src) +add_library(mslite_converter_plugin SHARED ${REG_SRC}) +target_link_libraries(mslite_converter_plugin mindspore::glog) +add_dependencies(mslite_converter_plugin fbs_src) +add_dependencies(mslite_converter_plugin fbs_inner_src)