forked from mindspore-Ecosystem/mindspore
!18129 [lite]add test ut and rename convert plugin lib
Merge pull request !18129 from 徐安越/master_core
This commit is contained in:
commit
599429f17b
|
@ -277,7 +277,7 @@ elseif(WIN32)
|
|||
install(DIRECTORY ${TOP_DIR}/mindspore/lite/include/registry/ DESTINATION ${CONVERTER_ROOT_DIR}/include/registry
|
||||
COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h")
|
||||
install(FILES ${LIB_LIST} DESTINATION ${CONVERTER_ROOT_DIR}/lib COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||
install(FILES ${TOP_DIR}/build/mindspore/tools/converter/registry/libmslite_converter_plugin_reg.dll
|
||||
install(FILES ${TOP_DIR}/build/mindspore/tools/converter/registry/libmslite_converter_plugin.dll
|
||||
DESTINATION ${CONVERTER_ROOT_DIR}/lib COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||
install(FILES ${glog_LIBPATH}/../bin/libglog.dll DESTINATION ${CONVERTER_ROOT_DIR}/lib
|
||||
COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||
|
@ -349,7 +349,7 @@ else()
|
|||
COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h")
|
||||
install(TARGETS converter_lite RUNTIME DESTINATION ${CONVERTER_ROOT_DIR}/converter
|
||||
COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||
install(FILES ${TOP_DIR}/mindspore/lite/build/tools/converter/registry/libmslite_converter_plugin_reg.so
|
||||
install(FILES ${TOP_DIR}/mindspore/lite/build/tools/converter/registry/libmslite_converter_plugin.so
|
||||
DESTINATION ${CONVERTER_ROOT_DIR}/lib COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||
install(FILES ${glog_LIBPATH}/libglog.so.0.4.0 DESTINATION ${CONVERTER_ROOT_DIR}/lib RENAME libglog.so.0
|
||||
COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||
|
|
|
@ -14,7 +14,7 @@ CONVERTER="../../../build/tools/converter/converter_lite"
|
|||
if [ ! -f "$CONVERTER" ]; then
|
||||
if ! command -v converter_lite &> /dev/null
|
||||
then
|
||||
tar -xzf ../../../../../output/mindspore-lite-*-train-linux-x64.tar.gz --strip-components 4 --wildcards --no-anchored converter_lite libglog.so.0 libmslite_converter_plugin_reg.so
|
||||
tar -xzf ../../../../../output/mindspore-lite-*-train-linux-x64.tar.gz --strip-components 4 --wildcards --no-anchored converter_lite libglog.so.0 libmslite_converter_plugin.so
|
||||
if [ -f ./converter_lite ]; then
|
||||
CONVERTER=./converter_lite
|
||||
else
|
||||
|
|
|
@ -19,7 +19,7 @@ CONVERTER="../../../build/tools/converter/converter_lite"
|
|||
if [ ! -f "$CONVERTER" ]; then
|
||||
if ! command -v converter_lite &> /dev/null
|
||||
then
|
||||
tar -xzf ../../../../../output/mindspore-lite-*-train-linux-x64.tar.gz --strip-components 4 --wildcards --no-anchored converter_lite libglog.so.0 libmslite_converter_plugin_reg.so
|
||||
tar -xzf ../../../../../output/mindspore-lite-*-train-linux-x64.tar.gz --strip-components 4 --wildcards --no-anchored converter_lite libglog.so.0 libmslite_converter_plugin.so
|
||||
if [ -f ./converter_lite ]; then
|
||||
CONVERTER=./converter_lite
|
||||
else
|
||||
|
|
|
@ -169,6 +169,7 @@ if(MSLITE_ENABLE_CONVERTER)
|
|||
add_definitions(-DPRIMITIVE_WRITEABLE)
|
||||
add_definitions(-DUSE_GLOG)
|
||||
file(GLOB_RECURSE TEST_CASE_TFLITE_PARSERS_SRC
|
||||
${TEST_DIR}/ut/tools/converter/registry/*.cc
|
||||
${TEST_DIR}/ut/tools/converter/parser/tflite/*.cc
|
||||
)
|
||||
set(TEST_LITE_SRC
|
||||
|
@ -182,6 +183,8 @@ if(MSLITE_ENABLE_CONVERTER)
|
|||
${LITE_DIR}/tools/converter/graphdef_transform.cc
|
||||
${LITE_DIR}/tools/converter/converter_flags.cc
|
||||
${LITE_DIR}/tools/converter/converter.cc
|
||||
${LITE_DIR}/tools/converter/export_model.cc
|
||||
${LITE_DIR}/tools/converter/dump_graph.cc
|
||||
${LITE_DIR}/tools/converter/parser/parser_utils.cc
|
||||
${LITE_DIR}/tools/optimizer/common/node_pass_extends.cc
|
||||
${LITE_DIR}/tools/optimizer/common/pass_manager_extends.cc
|
||||
|
@ -388,7 +391,7 @@ if(MSLITE_ENABLE_CONVERTER)
|
|||
add_dependencies(lite-test fbs_inner_src)
|
||||
target_link_libraries(lite-test
|
||||
anf_exporter_mid
|
||||
mslite_converter_plugin_reg
|
||||
mslite_converter_plugin
|
||||
tflite_parser_mid
|
||||
caffe_parser_mid
|
||||
onnx_parser_mid
|
||||
|
|
|
@ -57,6 +57,8 @@ echo 'run common ut tests'
|
|||
./lite-test --gtest_filter="TestFullConnectionOpenCL*"
|
||||
./lite-test --gtest_filter="TestResizeOpenCL*"
|
||||
./lite-test --gtest_filter="TestSwishOpenCLCI.Fp32CI"
|
||||
./lite-test --gtest_filter="ModelParserRegistryTest.TestRegistry"
|
||||
./lite-test --gtest_filter="PassRegistryTest.TestRegistry"
|
||||
|
||||
# test cases specific for train
|
||||
if [[ $1 == train ]]; then
|
||||
|
|
|
@ -0,0 +1,53 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <functional>
|
||||
#include "common/common_test.h"
|
||||
#include "include/registry/model_parser_registry.h"
|
||||
#include "tools/converter/model_parser.h"
|
||||
#include "tools/converter/converter_flags.h"
|
||||
|
||||
using mindspore::lite::ModelRegistrar;
|
||||
using mindspore::lite::converter::Flags;
|
||||
namespace mindspore {
|
||||
class ModelParserRegistryTest : public mindspore::CommonTest {
|
||||
public:
|
||||
ModelParserRegistryTest() = default;
|
||||
};
|
||||
|
||||
class ModelParserTest : public lite::ModelParser {
|
||||
public:
|
||||
ModelParserTest() = default;
|
||||
};
|
||||
|
||||
lite::ModelParser *TestModelParserCreator() {
|
||||
auto *parser = new (std::nothrow) ModelParserTest();
|
||||
if (parser == nullptr) {
|
||||
MS_LOG(ERROR) << "new model parser failed";
|
||||
return nullptr;
|
||||
}
|
||||
return parser;
|
||||
}
|
||||
REG_MODEL_PARSER(TEST, TestModelParserCreator);
|
||||
|
||||
TEST_F(ModelParserRegistryTest, TestRegistry) {
|
||||
auto model_parser = lite::ModelParserRegistry::GetInstance()->GetModelParser("TEST");
|
||||
ASSERT_NE(model_parser, nullptr);
|
||||
Flags flags;
|
||||
auto func_graph = model_parser->Parse(flags);
|
||||
ASSERT_EQ(func_graph, nullptr);
|
||||
}
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,53 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "common/common_test.h"
|
||||
#include "backend/optimizer/common/pass.h"
|
||||
#include "include/registry/pass_registry.h"
|
||||
|
||||
namespace mindspore {
|
||||
class PassRegistryTest : public mindspore::CommonTest {
|
||||
public:
|
||||
PassRegistryTest() = default;
|
||||
};
|
||||
|
||||
namespace opt {
|
||||
class TestFusion : public Pass {
|
||||
public:
|
||||
TestFusion() : Pass("test_fusion") {}
|
||||
bool Run(const FuncGraphPtr &func_graph) override { return true; }
|
||||
};
|
||||
REG_PASS(POSITION_BEGIN, TestFusion)
|
||||
REG_PASS(POSITION_END, TestFusion)
|
||||
} // namespace opt
|
||||
|
||||
TEST_F(PassRegistryTest, TestRegistry) {
|
||||
auto passes = opt::PassRegistry::GetInstance()->GetPasses();
|
||||
ASSERT_EQ(passes.size(), 2);
|
||||
auto begin_pass = passes[opt::POSITION_BEGIN];
|
||||
ASSERT_NE(begin_pass, nullptr);
|
||||
auto begin_pass_test = std::dynamic_pointer_cast<opt::TestFusion>(begin_pass);
|
||||
ASSERT_NE(begin_pass_test, nullptr);
|
||||
auto res = begin_pass_test->Run(nullptr);
|
||||
ASSERT_EQ(res, true);
|
||||
auto end_pass = passes[opt::POSITION_END];
|
||||
ASSERT_NE(end_pass, nullptr);
|
||||
auto end_pass_test = std::dynamic_pointer_cast<opt::TestFusion>(end_pass);
|
||||
ASSERT_NE(end_pass_test, nullptr);
|
||||
res = end_pass_test->Run(nullptr);
|
||||
ASSERT_EQ(res, true);
|
||||
}
|
||||
} // namespace mindspore
|
|
@ -18,6 +18,7 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
|
|||
${CMAKE_CURRENT_SOURCE_DIR}/anf_transform.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/graphdef_transform.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/optimizer.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/export_model.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../../src/common/file_utils.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../../src/common/quant_utils.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../common/graph_util.cc
|
||||
|
@ -176,7 +177,7 @@ target_link_libraries(converter_lite PRIVATE
|
|||
cpu_ops_mid
|
||||
nnacl_mid
|
||||
cpu_kernel_mid
|
||||
mslite_converter_plugin_reg
|
||||
mslite_converter_plugin
|
||||
tflite_parser_mid
|
||||
tf_parser_mid
|
||||
caffe_parser_mid
|
||||
|
|
|
@ -0,0 +1,35 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tools/converter/dump_graph.h"
|
||||
#include "tools/converter/dump_graph_init.h"
|
||||
#include "include/errorcode.h"
|
||||
#include "src/common/log_adapter.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
static GraphDumpFunc graph_dump_interface = nullptr;
|
||||
void InitDumpGraphFunc(const GraphDumpFunc &graph_dump_func) { graph_dump_interface = graph_dump_func; }
|
||||
|
||||
int DumpGraph(const FuncGraphPtr &func_graph) {
|
||||
if (graph_dump_interface == nullptr) {
|
||||
MS_LOG(ERROR) << "graph_dump_interface is nullptr, which is not init.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
return graph_dump_interface(func_graph);
|
||||
}
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,31 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_DUMP_GRAPH_H_
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_DUMP_GRAPH_H_
|
||||
|
||||
#include <memory>
|
||||
#include "include/lite_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
class FuncGraph;
|
||||
using FuncGraphPtr = std::shared_ptr<FuncGraph>;
|
||||
namespace lite {
|
||||
using GraphDumpFunc = std::function<int(const FuncGraphPtr &)>;
|
||||
int MS_API DumpGraph(const FuncGraphPtr &func_graph);
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_DUMP_GRAPH_H_
|
|
@ -0,0 +1,28 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_DUMP_GRAPH_INIT_H
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_DUMP_GRAPH_INIT_H
|
||||
|
||||
#include "tools/converter/dump_graph.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
void MS_API InitDumpGraphFunc(const GraphDumpFunc &graph_dump_func);
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_DUMP_GRAPH_INIT_H
|
|
@ -0,0 +1,249 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tools/converter/export_model.h"
|
||||
#include <fstream>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "include/errorcode.h"
|
||||
#include "include/version.h"
|
||||
#include "ir/func_graph.h"
|
||||
#include "tools/anf_exporter/anf_exporter.h"
|
||||
#include "tools/converter/graphdef_transform.h"
|
||||
#include "tools/converter/dump_graph_init.h"
|
||||
#include "tools/optimizer/graph/unify_format_pass.h"
|
||||
#include "tools/optimizer/graph/while_pass.h"
|
||||
#include "tools/optimizer/graph/if_pass.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
namespace {
|
||||
using NodesMap = std::map<std::string, std::vector<AnfNodePtr>>;
|
||||
}
|
||||
static converter::Flags *flags = nullptr;
|
||||
|
||||
void CloneGraphInputs(const FuncGraphPtr &origin, const FuncGraphPtr &mirror, NodesMap *origin_map,
|
||||
NodesMap *mirror_map) {
|
||||
MS_ASSERT(origin != nullptr && mirror != nullptr);
|
||||
MS_ASSERT(origin_map != nullptr && mirror_map != nullptr);
|
||||
auto origin_inputs = origin->get_inputs();
|
||||
for (auto &input : origin_inputs) {
|
||||
auto mirror_input = mirror->add_parameter();
|
||||
if (input->abstract() != nullptr) {
|
||||
mirror_input->set_abstract(input->abstract()->Clone());
|
||||
}
|
||||
mirror_input->set_name(input->fullname_with_scope());
|
||||
(*origin_map)[input->fullname_with_scope()].push_back(input);
|
||||
(*mirror_map)[input->fullname_with_scope()].push_back(mirror_input);
|
||||
}
|
||||
}
|
||||
|
||||
AnfNodePtr CloneParameterAndValueNode(const CNodePtr &cnode, size_t index, const FuncGraphPtr &mirror_graph) {
|
||||
MS_ASSERT(cnode != nullptr && mirror_graph != nullptr);
|
||||
if (index >= cnode->size()) {
|
||||
MS_LOG(ERROR) << "input index out of range.";
|
||||
return nullptr;
|
||||
}
|
||||
auto node = cnode->input(index);
|
||||
if (utils::isa<mindspore::CNode>(node)) {
|
||||
MS_LOG(ERROR) << "this func cannot copy cnode.";
|
||||
return nullptr;
|
||||
}
|
||||
if (utils::isa<ValueNode>(node)) {
|
||||
auto value_node = node->cast<ValueNodePtr>();
|
||||
auto value_ptr = value_node->value();
|
||||
MS_ASSERT(value_ptr != nullptr);
|
||||
if (utils::isa<Monad>(value_ptr)) {
|
||||
std::shared_ptr<Monad> mirror_monad;
|
||||
if (utils::isa<UMonad>(value_ptr)) {
|
||||
mirror_monad = std::make_shared<UMonad>();
|
||||
} else {
|
||||
mirror_monad = std::make_shared<IOMonad>();
|
||||
}
|
||||
auto monad_abs = mirror_monad->ToAbstract();
|
||||
auto mirror_value_node = NewValueNode(mirror_monad);
|
||||
mirror_value_node->set_abstract(monad_abs);
|
||||
return mirror_value_node;
|
||||
}
|
||||
}
|
||||
DataInfo data_info;
|
||||
STATUS status;
|
||||
if (utils::isa<Parameter>(node)) {
|
||||
status = FetchDataFromParameterNode(cnode, index, flags->fmk, flags->trainModel, &data_info);
|
||||
} else if (utils::isa<ValueNode>(node)) {
|
||||
status = FetchDataFromValueNode(cnode, index, flags->fmk, flags->trainModel, &data_info);
|
||||
} else {
|
||||
status = RET_ERROR;
|
||||
}
|
||||
if (status != RET_OK && status != RET_NO_CHANGE) {
|
||||
MS_LOG(ERROR) << "fetch data failed.";
|
||||
return nullptr;
|
||||
}
|
||||
if (opt::CheckPrimitiveType(cnode, prim::kPrimTupleGetItem) && !data_info.data_.empty()) {
|
||||
return NewValueNode(MakeValue<int>(*reinterpret_cast<int *>(data_info.data_.data())));
|
||||
}
|
||||
ShapeVector shape_vec(data_info.shape_.begin(), data_info.shape_.end());
|
||||
auto tensor_info = std::make_shared<tensor::Tensor>(static_cast<TypeId>(data_info.data_type_), shape_vec);
|
||||
if (!data_info.data_.empty()) {
|
||||
auto tensor_data = reinterpret_cast<uint8_t *>(tensor_info->data_c());
|
||||
if (memcpy_s(tensor_data, tensor_info->data().nbytes(), data_info.data_.data(), data_info.data_.size()) != EOK) {
|
||||
MS_LOG(ERROR) << "memcpy_s failed";
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
auto mirror_parameter = mirror_graph->add_parameter();
|
||||
if (node->abstract() != nullptr) {
|
||||
mirror_parameter->set_abstract(node->abstract()->Clone());
|
||||
}
|
||||
mirror_parameter->set_name(node->fullname_with_scope());
|
||||
mirror_parameter->set_default_param(tensor_info);
|
||||
return mirror_parameter;
|
||||
}
|
||||
|
||||
PrimitivePtr ClonePrimitive(const CNodePtr &cnode) {
|
||||
MS_ASSERT(cnode != nullptr);
|
||||
auto origin_prim = GetValueNode<PrimitivePtr>(cnode->input(0));
|
||||
MS_ASSERT(origin_prim != nullptr);
|
||||
PrimitivePtr prim;
|
||||
auto op_primc_fns = ops::OpPrimCRegister::GetInstance().GetPrimCMap();
|
||||
if (op_primc_fns.find(origin_prim->name()) != op_primc_fns.end()) {
|
||||
prim = op_primc_fns[origin_prim->name()]();
|
||||
} else {
|
||||
prim = std::make_shared<PrimitiveC>(origin_prim->name());
|
||||
prim->set_instance_name(origin_prim->name());
|
||||
}
|
||||
prim->SetAttrs(origin_prim->attrs());
|
||||
return prim;
|
||||
}
|
||||
|
||||
FuncGraphPtr CloneFuncGraph(const FuncGraphPtr &graph) {
|
||||
MS_ASSERT(graph != nullptr);
|
||||
auto mirror_graph = std::make_shared<FuncGraph>();
|
||||
mirror_graph->set_attrs(graph->attrs());
|
||||
NodesMap origin_nodes;
|
||||
NodesMap mirror_nodes;
|
||||
CloneGraphInputs(graph, mirror_graph, &origin_nodes, &mirror_nodes);
|
||||
auto node_list = TopoSort(graph->get_return());
|
||||
for (auto &node : node_list) {
|
||||
if (!utils::isa<mindspore::CNode>(node)) {
|
||||
continue;
|
||||
}
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
auto mirrro_prim = ClonePrimitive(cnode);
|
||||
std::vector<AnfNodePtr> node_inputs;
|
||||
for (size_t i = 1; i < cnode->size(); ++i) {
|
||||
auto origin_input = cnode->input(i);
|
||||
AnfNodePtr mirror_input = nullptr;
|
||||
auto value = origin_nodes[origin_input->fullname_with_scope()];
|
||||
auto iter = std::find(value.begin(), value.end(), origin_input);
|
||||
if (iter != value.end()) {
|
||||
mirror_input = mirror_nodes[origin_input->fullname_with_scope()][iter - value.begin()];
|
||||
}
|
||||
if (mirror_input == nullptr) {
|
||||
if (IsValueNode<FuncGraph>(origin_input)) {
|
||||
auto sub_func_graph = GetValueNode<FuncGraphPtr>(origin_input);
|
||||
auto mirror_sub_graph = CloneFuncGraph(sub_func_graph);
|
||||
mirror_input = NewValueNode(mirror_sub_graph);
|
||||
} else {
|
||||
mirror_input = CloneParameterAndValueNode(cnode, i, mirror_graph);
|
||||
}
|
||||
if (mirror_input == nullptr) {
|
||||
MS_LOG(ERROR) << "node input cannot be found.";
|
||||
return nullptr;
|
||||
}
|
||||
origin_nodes[origin_input->fullname_with_scope()].push_back(origin_input);
|
||||
mirror_nodes[origin_input->fullname_with_scope()].push_back(mirror_input);
|
||||
}
|
||||
node_inputs.push_back(mirror_input);
|
||||
}
|
||||
auto mirror_cnode = mirror_graph->NewCNode(mirrro_prim, node_inputs);
|
||||
mirror_cnode->set_fullname_with_scope(cnode->fullname_with_scope());
|
||||
if (cnode->abstract() != nullptr) {
|
||||
mirror_cnode->set_abstract(cnode->abstract()->Clone());
|
||||
}
|
||||
origin_nodes[cnode->fullname_with_scope()].push_back(cnode);
|
||||
mirror_nodes[cnode->fullname_with_scope()].push_back(mirror_cnode);
|
||||
if (opt::CheckPrimitiveType(cnode, prim::kPrimReturn)) {
|
||||
mirror_graph->set_return(mirror_cnode);
|
||||
}
|
||||
}
|
||||
return mirror_graph;
|
||||
}
|
||||
|
||||
STATUS ExportModel(const FuncGraphPtr &graph) {
|
||||
MS_ASSERT(graph != nullptr && flags != nullptr);
|
||||
auto mirror_graph = CloneFuncGraph(graph);
|
||||
if (mirror_graph == nullptr) {
|
||||
MS_LOG(ERROR) << "Clone funcGraph failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
(void)Manage(mirror_graph, true);
|
||||
auto format_pass = std::make_shared<opt::UnifyFormatPass>();
|
||||
format_pass->Init(flags->fmk, flags->trainModel);
|
||||
if (!format_pass->Run(mirror_graph)) {
|
||||
MS_LOG(ERROR) << "Run format pass failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto optimizer = std::make_shared<opt::GraphOptimizer>();
|
||||
auto graph_pm = std::make_shared<opt::PassManager>("anf graph pass manager", true);
|
||||
if (flags->fmk == lite::converter::FmkType_TFLITE || flags->fmk == lite::converter::FmkType_TF ||
|
||||
flags->fmk == lite::converter::FmkType_ONNX) {
|
||||
graph_pm->AddPass(std::make_shared<opt::WhilePass>());
|
||||
graph_pm->AddPass(std::make_shared<opt::IfPass>());
|
||||
}
|
||||
optimizer->AddPassManager(graph_pm);
|
||||
if (optimizer->Optimize(mirror_graph) == nullptr) {
|
||||
MS_LOG(ERROR) << "run graph pass failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto meta_graph = Export(mirror_graph);
|
||||
if (meta_graph == nullptr) {
|
||||
MS_LOG(ERROR) << "Export to meta graph return nullptr";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto metagraph_transform = std::make_unique<GraphDefTransform>();
|
||||
metagraph_transform->SetGraphDef(meta_graph);
|
||||
auto status = metagraph_transform->Transform(*flags);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "Transform meta graph failed " << status;
|
||||
return RET_ERROR;
|
||||
}
|
||||
meta_graph->version = Version();
|
||||
status = Storage::Save(*meta_graph, "model");
|
||||
std::ostringstream oss;
|
||||
if (status != RET_OK) {
|
||||
oss << "SAVE GRAPH FAILED:" << status << " " << lite::GetErrorInfo(status);
|
||||
MS_LOG(ERROR) << oss.str();
|
||||
std::cout << oss.str() << std::endl;
|
||||
return status;
|
||||
}
|
||||
|
||||
delete meta_graph;
|
||||
oss << "CONVERT RESULT SUCCESS:" << status;
|
||||
MS_LOG(INFO) << oss.str();
|
||||
std::cout << oss.str() << std::endl;
|
||||
return status;
|
||||
}
|
||||
|
||||
void ExportModelInit(converter::Flags *flag) {
|
||||
MS_ASSERT(flag != nullptr);
|
||||
flags = flag;
|
||||
InitDumpGraphFunc(ExportModel);
|
||||
}
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,28 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_EXPORT_MODEL_H
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_EXPORT_MODEL_H
|
||||
|
||||
#include "tools/converter/converter_flags.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
void ExportModelInit(lite::converter::Flags *flag);
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_EXPORT_MODEL_H
|
|
@ -8,9 +8,10 @@ set(REG_SRC ${CONVERT_REG_SRC}
|
|||
${KERNEL_REG_SRC}
|
||||
${CORE_DIR}/utils/log_adapter.cc
|
||||
${CORE_DIR}/gvar/log_adapter_common.cc
|
||||
${CORE_DIR}/gvar/logging_level.cc)
|
||||
${CORE_DIR}/gvar/logging_level.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../dump_graph.cc)
|
||||
set_property(SOURCE ${REG_SRC} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_LITE)
|
||||
add_library(mslite_converter_plugin_reg SHARED ${REG_SRC})
|
||||
target_link_libraries(mslite_converter_plugin_reg mindspore::glog)
|
||||
add_dependencies(mslite_converter_plugin_reg fbs_src)
|
||||
add_dependencies(mslite_converter_plugin_reg fbs_inner_src)
|
||||
add_library(mslite_converter_plugin SHARED ${REG_SRC})
|
||||
target_link_libraries(mslite_converter_plugin mindspore::glog)
|
||||
add_dependencies(mslite_converter_plugin fbs_src)
|
||||
add_dependencies(mslite_converter_plugin fbs_inner_src)
|
||||
|
|
Loading…
Reference in New Issue