forked from mindspore-Ecosystem/mindspore
dump graph function adjustment
This commit is contained in:
parent
8572c4b110
commit
1bc9620f06
|
@ -35,6 +35,7 @@
|
||||||
# Lite
|
# Lite
|
||||||
"mindspore/mindspore/core/mindrt/src/thread/" "useStlAlgorithm"
|
"mindspore/mindspore/core/mindrt/src/thread/" "useStlAlgorithm"
|
||||||
"mindspore/mindspore/lite/test/" "syntaxError"
|
"mindspore/mindspore/lite/test/" "syntaxError"
|
||||||
|
"mindspore/mindspore/lite/test/ut/tools/converter/registry/pass_registry_test.cc" "unknownMacro"
|
||||||
"mindspore/mindspore/lite/src/ops/unsqueeze.cc" "useStlAlgorithm"
|
"mindspore/mindspore/lite/src/ops/unsqueeze.cc" "useStlAlgorithm"
|
||||||
"mindspore/mindspore/lite/tools/common/flag_parser.cc" "useStlAlgorithm"
|
"mindspore/mindspore/lite/tools/common/flag_parser.cc" "useStlAlgorithm"
|
||||||
"mindspore/mindspore/lite/tools/common/tensor_util.cc" "useStlAlgorithm"
|
"mindspore/mindspore/lite/tools/common/tensor_util.cc" "useStlAlgorithm"
|
||||||
|
|
|
@ -330,8 +330,6 @@ elseif(WIN32)
|
||||||
DESTINATION ${CONVERTER_ROOT_DIR}/include COMPONENT ${RUNTIME_COMPONENT_NAME})
|
DESTINATION ${CONVERTER_ROOT_DIR}/include COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||||
install(FILES ${TOP_DIR}/mindspore/lite/tools/converter/model_parser.h
|
install(FILES ${TOP_DIR}/mindspore/lite/tools/converter/model_parser.h
|
||||||
DESTINATION ${CONVERTER_ROOT_DIR}/include COMPONENT ${RUNTIME_COMPONENT_NAME})
|
DESTINATION ${CONVERTER_ROOT_DIR}/include COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||||
install(FILES ${TOP_DIR}/mindspore/lite/tools/converter/dump_graph.h
|
|
||||||
DESTINATION ${CONVERTER_ROOT_DIR}/include COMPONENT ${RUNTIME_COMPONENT_NAME})
|
|
||||||
install(FILES ${TOP_DIR}/mindspore/lite/tools/converter/ops/ops_def.h
|
install(FILES ${TOP_DIR}/mindspore/lite/tools/converter/ops/ops_def.h
|
||||||
DESTINATION ${CONVERTER_ROOT_DIR}/include COMPONENT ${RUNTIME_COMPONENT_NAME})
|
DESTINATION ${CONVERTER_ROOT_DIR}/include COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||||
install(DIRECTORY ${TOP_DIR}/build/mindspore/schema/ DESTINATION ${CONVERTER_ROOT_DIR}/include/schema
|
install(DIRECTORY ${TOP_DIR}/build/mindspore/schema/ DESTINATION ${CONVERTER_ROOT_DIR}/include/schema
|
||||||
|
@ -462,8 +460,6 @@ else()
|
||||||
DESTINATION ${CONVERTER_ROOT_DIR}/include COMPONENT ${RUNTIME_COMPONENT_NAME})
|
DESTINATION ${CONVERTER_ROOT_DIR}/include COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||||
install(FILES ${TOP_DIR}/mindspore/lite/tools/converter/model_parser.h
|
install(FILES ${TOP_DIR}/mindspore/lite/tools/converter/model_parser.h
|
||||||
DESTINATION ${CONVERTER_ROOT_DIR}/include COMPONENT ${RUNTIME_COMPONENT_NAME})
|
DESTINATION ${CONVERTER_ROOT_DIR}/include COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||||
install(FILES ${TOP_DIR}/mindspore/lite/tools/converter/dump_graph.h
|
|
||||||
DESTINATION ${CONVERTER_ROOT_DIR}/include COMPONENT ${RUNTIME_COMPONENT_NAME})
|
|
||||||
install(FILES ${TOP_DIR}/mindspore/lite/tools/converter/ops/ops_def.h
|
install(FILES ${TOP_DIR}/mindspore/lite/tools/converter/ops/ops_def.h
|
||||||
DESTINATION ${CONVERTER_ROOT_DIR}/include COMPONENT ${RUNTIME_COMPONENT_NAME})
|
DESTINATION ${CONVERTER_ROOT_DIR}/include COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||||
install(DIRECTORY ${TOP_DIR}/mindspore/lite/build/schema/ DESTINATION ${CONVERTER_ROOT_DIR}/include/schema
|
install(DIRECTORY ${TOP_DIR}/mindspore/lite/build/schema/ DESTINATION ${CONVERTER_ROOT_DIR}/include/schema
|
||||||
|
|
|
@ -48,8 +48,8 @@ class MS_API PassRegistry {
|
||||||
/// \brief Constructor of PassRegistry to assign which passes are required for external extension.
|
/// \brief Constructor of PassRegistry to assign which passes are required for external extension.
|
||||||
///
|
///
|
||||||
/// \param[in] position Define the place where assigned passes will run.
|
/// \param[in] position Define the place where assigned passes will run.
|
||||||
/// \param[in] assigned Define the names of the passes.
|
/// \param[in] names Define the names of the passes.
|
||||||
PassRegistry(PassPosition position, const std::vector<std::string> &assigned);
|
PassRegistry(PassPosition position, const std::vector<std::string> &names);
|
||||||
|
|
||||||
/// \brief Destructor of PassRegistrar.
|
/// \brief Destructor of PassRegistrar.
|
||||||
~PassRegistry() = default;
|
~PassRegistry() = default;
|
||||||
|
@ -79,9 +79,8 @@ class MS_API PassRegistry {
|
||||||
/// \brief Defined assigning macro to assign Passes, which called by user directly.
|
/// \brief Defined assigning macro to assign Passes, which called by user directly.
|
||||||
///
|
///
|
||||||
/// \param[in] position Define the place where assigned passes will run.
|
/// \param[in] position Define the place where assigned passes will run.
|
||||||
/// \param[in] assigned Define the names of the passes.
|
/// \param[in] names Define the names of the passes.
|
||||||
#define REG_SCHEDULED_PASS(position, assigned) \
|
#define REG_SCHEDULED_PASS(position, names) static mindspore::registry::PassRegistry g_##position(position, names);
|
||||||
static mindspore::registry::PassRegistry g_##position(position, assigned);
|
|
||||||
} // namespace registry
|
} // namespace registry
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
||||||
|
|
|
@ -201,7 +201,6 @@ if(MSLITE_ENABLE_CONVERTER)
|
||||||
${LITE_DIR}/tools/converter/converter_flags.cc
|
${LITE_DIR}/tools/converter/converter_flags.cc
|
||||||
${LITE_DIR}/tools/converter/converter.cc
|
${LITE_DIR}/tools/converter/converter.cc
|
||||||
${LITE_DIR}/tools/converter/export_model.cc
|
${LITE_DIR}/tools/converter/export_model.cc
|
||||||
${LITE_DIR}/tools/converter/dump_graph.cc
|
|
||||||
${LITE_DIR}/tools/converter/optimizer_manager.cc
|
${LITE_DIR}/tools/converter/optimizer_manager.cc
|
||||||
${LITE_DIR}/tools/converter/parser/parser_utils.cc
|
${LITE_DIR}/tools/converter/parser/parser_utils.cc
|
||||||
${LITE_DIR}/tools/optimizer/common/node_pass_extends.cc
|
${LITE_DIR}/tools/optimizer/common/node_pass_extends.cc
|
||||||
|
|
|
@ -51,7 +51,7 @@ namespace opt {
|
||||||
// fuse add and add to addn.
|
// fuse add and add to addn.
|
||||||
class Test1Fusion : public Pass {
|
class Test1Fusion : public Pass {
|
||||||
public:
|
public:
|
||||||
Test1Fusion() : Pass("test1_fusion") {}
|
Test1Fusion() : Pass("Test1Fusion") {}
|
||||||
bool CanFusion(const CNodePtr &cnode) {
|
bool CanFusion(const CNodePtr &cnode) {
|
||||||
if (cnode == nullptr) {
|
if (cnode == nullptr) {
|
||||||
return false;
|
return false;
|
||||||
|
@ -94,7 +94,7 @@ class Test1Fusion : public Pass {
|
||||||
if (func_graph == nullptr) {
|
if (func_graph == nullptr) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
auto manager = func_graph->manager();
|
auto manager = Manage(func_graph);
|
||||||
if (manager == nullptr) {
|
if (manager == nullptr) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
@ -132,9 +132,9 @@ class Test1Fusion : public Pass {
|
||||||
// convert addn to custom op
|
// convert addn to custom op
|
||||||
class Test2Fusion : public Pass {
|
class Test2Fusion : public Pass {
|
||||||
public:
|
public:
|
||||||
Test2Fusion() : Pass("test2_fusion") {}
|
Test2Fusion() : Pass("Test2Fusion") {}
|
||||||
AnfNodePtr CreateCustomOp(const FuncGraphPtr func_graph, const CNodePtr &cnode) {
|
AnfNodePtr CreateCustomOp(const FuncGraphPtr func_graph, const CNodePtr &cnode) {
|
||||||
if (cnode == nullptr) {
|
if (func_graph == nullptr || cnode == nullptr) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
auto primc = std::make_shared<ops::Custom>();
|
auto primc = std::make_shared<ops::Custom>();
|
||||||
|
@ -143,7 +143,7 @@ class Test2Fusion : public Pass {
|
||||||
}
|
}
|
||||||
primc->set_type("Custom_AddN");
|
primc->set_type("Custom_AddN");
|
||||||
std::map<std::string, std::vector<uint8_t>> custom_attrs;
|
std::map<std::string, std::vector<uint8_t>> custom_attrs;
|
||||||
std::string input_num = std::to_string(3);
|
std::string input_num = std::to_string(cnode->size() - 1);
|
||||||
std::vector<uint8_t> input_num_attr(input_num.begin(), input_num.end());
|
std::vector<uint8_t> input_num_attr(input_num.begin(), input_num.end());
|
||||||
custom_attrs["input_num"] = input_num_attr;
|
custom_attrs["input_num"] = input_num_attr;
|
||||||
std::string op_kind = "custom op";
|
std::string op_kind = "custom op";
|
||||||
|
@ -162,7 +162,7 @@ class Test2Fusion : public Pass {
|
||||||
if (func_graph == nullptr) {
|
if (func_graph == nullptr) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
auto manager = func_graph->manager();
|
auto manager = Manage(func_graph);
|
||||||
if (manager == nullptr) {
|
if (manager == nullptr) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
@ -185,45 +185,22 @@ class Test2Fusion : public Pass {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
class TestFusion : public Pass {
|
REG_PASS(Test1Fusion, Test1Fusion)
|
||||||
public:
|
REG_PASS(Test2Fusion, Test2Fusion)
|
||||||
TestFusion() : Pass("test_fusion") {}
|
const std::vector<std::string> schedule = {"Test1Fusion", "Test2Fusion"};
|
||||||
bool Run(const FuncGraphPtr &func_graph) override {
|
REG_SCHEDULED_PASS(POSITION_BEGIN, schedule)
|
||||||
if (func_graph == nullptr) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
auto manager = Manage(func_graph, true);
|
|
||||||
if (manager == nullptr) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
auto test1_fusion = std::make_shared<Test1Fusion>();
|
|
||||||
if (!test1_fusion->Run(func_graph)) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
auto test2_fusion = std::make_shared<Test2Fusion>();
|
|
||||||
if (!test2_fusion->Run(func_graph)) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
REG_PASS(TestFusion, TestFusion)
|
|
||||||
REG_SCHEDULED_PASS(POSITION_BEGIN, {"TestFusion"})
|
|
||||||
} // namespace opt
|
} // namespace opt
|
||||||
|
|
||||||
TEST_F(PassRegistryTest, TestRegistry) {
|
TEST_F(PassRegistryTest, TestRegistry) {
|
||||||
auto schedule_task = registry::PassRegistry::GetOuterScheduleTask(POSITION_BEGIN);
|
auto schedule_task = registry::PassRegistry::GetOuterScheduleTask(POSITION_BEGIN);
|
||||||
ASSERT_EQ(schedule_task.size(), 1);
|
ASSERT_EQ(schedule_task.size(), 2);
|
||||||
auto passes = registry::PassRegistry::GetPassFromStoreRoom(schedule_task);
|
auto passes = registry::PassRegistry::GetPassFromStoreRoom(schedule_task);
|
||||||
ASSERT_EQ(passes.size(), 1);
|
ASSERT_EQ(passes.size(), 2);
|
||||||
auto begin_pass = passes.front();
|
|
||||||
ASSERT_NE(begin_pass, nullptr);
|
|
||||||
auto begin_pass_test = std::dynamic_pointer_cast<opt::TestFusion>(begin_pass);
|
|
||||||
ASSERT_NE(begin_pass_test, nullptr);
|
|
||||||
ASSERT_NE(func_graph_, nullptr);
|
ASSERT_NE(func_graph_, nullptr);
|
||||||
auto res = begin_pass_test->Run(func_graph_);
|
for (auto &pass : passes) {
|
||||||
ASSERT_EQ(res, true);
|
auto ret = pass->Run(func_graph_);
|
||||||
|
ASSERT_EQ(ret, true);
|
||||||
|
}
|
||||||
auto cnode_list = func_graph_->GetOrderedCnodes();
|
auto cnode_list = func_graph_->GetOrderedCnodes();
|
||||||
ASSERT_EQ(cnode_list.size(), 2);
|
ASSERT_EQ(cnode_list.size(), 2);
|
||||||
bool is_custom = opt::CheckPrimitiveType(cnode_list.front(), prim::kPrimCustom);
|
bool is_custom = opt::CheckPrimitiveType(cnode_list.front(), prim::kPrimCustom);
|
||||||
|
|
|
@ -59,6 +59,7 @@
|
||||||
#include "tools/optimizer/graph/split_one_pass.h"
|
#include "tools/optimizer/graph/split_one_pass.h"
|
||||||
#include "tools/optimizer/graph/decrease_transpose_algo.h"
|
#include "tools/optimizer/graph/decrease_transpose_algo.h"
|
||||||
#include "tools/optimizer/graph/specify_graph_input_format.h"
|
#include "tools/optimizer/graph/specify_graph_input_format.h"
|
||||||
|
#include "tools/optimizer/graph/dump_graph.h"
|
||||||
#include "tools/converter/quantizer/post_training_quantizer.h"
|
#include "tools/converter/quantizer/post_training_quantizer.h"
|
||||||
#include "tools/converter/quantizer/quant_cast.h"
|
#include "tools/converter/quantizer/quant_cast.h"
|
||||||
#include "tools/converter/quantizer/weight_quantizer.h"
|
#include "tools/converter/quantizer/weight_quantizer.h"
|
||||||
|
@ -401,6 +402,7 @@ void AnfTransform::AppendPassToStoreRoom(const converter::Flags *config) {
|
||||||
registry::PassRegistry("ToNHWCFormat", std::make_shared<opt::ToNHWCFormat>(fmk, is_train));
|
registry::PassRegistry("ToNHWCFormat", std::make_shared<opt::ToNHWCFormat>(fmk, is_train));
|
||||||
registry::PassRegistry("SpecifyGraphInputFormat",
|
registry::PassRegistry("SpecifyGraphInputFormat",
|
||||||
std::make_shared<opt::SpecifyGraphInputFormat>(config->graphInputFormat));
|
std::make_shared<opt::SpecifyGraphInputFormat>(config->graphInputFormat));
|
||||||
|
registry::PassRegistry("DumpGraph", std::make_shared<opt::DumpGraph>(config));
|
||||||
}
|
}
|
||||||
|
|
||||||
FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &main_graph, const converter::Flags *config) {
|
FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &main_graph, const converter::Flags *config) {
|
||||||
|
|
|
@ -26,7 +26,6 @@
|
||||||
#include "src/train/train_populate_parameter.h"
|
#include "src/train/train_populate_parameter.h"
|
||||||
#include "include/registry/model_parser_registry.h"
|
#include "include/registry/model_parser_registry.h"
|
||||||
#include "src/common/dynamic_library_loader.h"
|
#include "src/common/dynamic_library_loader.h"
|
||||||
#include "tools/converter/export_model.h"
|
|
||||||
#include "tools/converter/parser/parser_utils.h"
|
#include "tools/converter/parser/parser_utils.h"
|
||||||
#include "tools/converter/import/mindspore_importer.h"
|
#include "tools/converter/import/mindspore_importer.h"
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
|
@ -149,8 +148,6 @@ int RunConverter(int argc, const char **argv) {
|
||||||
}
|
}
|
||||||
return status;
|
return status;
|
||||||
}
|
}
|
||||||
// Init dump graph func
|
|
||||||
ExportModelInit(flags.get());
|
|
||||||
// Load graph
|
// Load graph
|
||||||
MS_LOG(DEBUG) << "start reading model file";
|
MS_LOG(DEBUG) << "start reading model file";
|
||||||
Converter cvt;
|
Converter cvt;
|
||||||
|
|
|
@ -1,35 +0,0 @@
|
||||||
/**
|
|
||||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
#include "tools/converter/dump_graph.h"
|
|
||||||
#include "tools/converter/dump_graph_init.h"
|
|
||||||
#include "include/errorcode.h"
|
|
||||||
#include "src/common/log_adapter.h"
|
|
||||||
|
|
||||||
namespace mindspore {
|
|
||||||
namespace lite {
|
|
||||||
static GraphDumpFunc graph_dump_interface = nullptr;
|
|
||||||
void InitDumpGraphFunc(const GraphDumpFunc &graph_dump_func) { graph_dump_interface = graph_dump_func; }
|
|
||||||
|
|
||||||
int DumpGraph(const FuncGraphPtr &func_graph) {
|
|
||||||
if (graph_dump_interface == nullptr) {
|
|
||||||
MS_LOG(ERROR) << "graph_dump_interface is nullptr, which is not init.";
|
|
||||||
return RET_ERROR;
|
|
||||||
}
|
|
||||||
return graph_dump_interface(func_graph);
|
|
||||||
}
|
|
||||||
} // namespace lite
|
|
||||||
} // namespace mindspore
|
|
|
@ -1,31 +0,0 @@
|
||||||
/**
|
|
||||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_DUMP_GRAPH_H_
|
|
||||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_DUMP_GRAPH_H_
|
|
||||||
|
|
||||||
#include <memory>
|
|
||||||
#include "include/lite_utils.h"
|
|
||||||
|
|
||||||
namespace mindspore {
|
|
||||||
class FuncGraph;
|
|
||||||
using FuncGraphPtr = std::shared_ptr<FuncGraph>;
|
|
||||||
namespace lite {
|
|
||||||
using GraphDumpFunc = std::function<int(const FuncGraphPtr &)>;
|
|
||||||
int MS_API DumpGraph(const FuncGraphPtr &func_graph);
|
|
||||||
} // namespace lite
|
|
||||||
} // namespace mindspore
|
|
||||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_DUMP_GRAPH_H_
|
|
|
@ -1,28 +0,0 @@
|
||||||
/**
|
|
||||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_DUMP_GRAPH_INIT_H
|
|
||||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_DUMP_GRAPH_INIT_H
|
|
||||||
|
|
||||||
#include "tools/converter/dump_graph.h"
|
|
||||||
|
|
||||||
namespace mindspore {
|
|
||||||
namespace lite {
|
|
||||||
void MS_API InitDumpGraphFunc(const GraphDumpFunc &graph_dump_func);
|
|
||||||
} // namespace lite
|
|
||||||
} // namespace mindspore
|
|
||||||
|
|
||||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_DUMP_GRAPH_INIT_H
|
|
|
@ -26,7 +26,6 @@
|
||||||
#include "ir/func_graph.h"
|
#include "ir/func_graph.h"
|
||||||
#include "tools/anf_exporter/anf_exporter.h"
|
#include "tools/anf_exporter/anf_exporter.h"
|
||||||
#include "tools/converter/graphdef_transform.h"
|
#include "tools/converter/graphdef_transform.h"
|
||||||
#include "tools/converter/dump_graph_init.h"
|
|
||||||
#include "tools/converter/optimizer_manager.h"
|
#include "tools/converter/optimizer_manager.h"
|
||||||
#include "tools/optimizer/graph/control_flow_pass.h"
|
#include "tools/optimizer/graph/control_flow_pass.h"
|
||||||
|
|
||||||
|
@ -34,9 +33,6 @@ namespace mindspore {
|
||||||
namespace lite {
|
namespace lite {
|
||||||
namespace {
|
namespace {
|
||||||
using NodesMap = std::map<std::string, std::vector<AnfNodePtr>>;
|
using NodesMap = std::map<std::string, std::vector<AnfNodePtr>>;
|
||||||
}
|
|
||||||
static converter::Flags *flags = nullptr;
|
|
||||||
|
|
||||||
void CloneGraphInputs(const FuncGraphPtr &origin, const FuncGraphPtr &mirror, NodesMap *origin_map,
|
void CloneGraphInputs(const FuncGraphPtr &origin, const FuncGraphPtr &mirror, NodesMap *origin_map,
|
||||||
NodesMap *mirror_map) {
|
NodesMap *mirror_map) {
|
||||||
MS_ASSERT(origin != nullptr && mirror != nullptr);
|
MS_ASSERT(origin != nullptr && mirror != nullptr);
|
||||||
|
@ -53,7 +49,8 @@ void CloneGraphInputs(const FuncGraphPtr &origin, const FuncGraphPtr &mirror, No
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
AnfNodePtr CloneParameterAndValueNode(const CNodePtr &cnode, size_t index, const FuncGraphPtr &mirror_graph) {
|
AnfNodePtr CloneParameterAndValueNode(const CNodePtr &cnode, size_t index, const FuncGraphPtr &mirror_graph,
|
||||||
|
const converter::Flags *flags) {
|
||||||
MS_ASSERT(cnode != nullptr && mirror_graph != nullptr);
|
MS_ASSERT(cnode != nullptr && mirror_graph != nullptr);
|
||||||
if (index >= cnode->size()) {
|
if (index >= cnode->size()) {
|
||||||
MS_LOG(ERROR) << "input index out of range.";
|
MS_LOG(ERROR) << "input index out of range.";
|
||||||
|
@ -131,7 +128,7 @@ PrimitivePtr ClonePrimitive(const CNodePtr &cnode) {
|
||||||
return prim;
|
return prim;
|
||||||
}
|
}
|
||||||
|
|
||||||
FuncGraphPtr CloneFuncGraph(const FuncGraphPtr &graph) {
|
FuncGraphPtr CloneFuncGraph(const FuncGraphPtr &graph, const converter::Flags *flags) {
|
||||||
MS_ASSERT(graph != nullptr);
|
MS_ASSERT(graph != nullptr);
|
||||||
auto mirror_graph = std::make_shared<FuncGraph>();
|
auto mirror_graph = std::make_shared<FuncGraph>();
|
||||||
mirror_graph->set_attrs(graph->attrs());
|
mirror_graph->set_attrs(graph->attrs());
|
||||||
|
@ -157,10 +154,10 @@ FuncGraphPtr CloneFuncGraph(const FuncGraphPtr &graph) {
|
||||||
if (mirror_input == nullptr) {
|
if (mirror_input == nullptr) {
|
||||||
if (IsValueNode<FuncGraph>(origin_input)) {
|
if (IsValueNode<FuncGraph>(origin_input)) {
|
||||||
auto sub_func_graph = GetValueNode<FuncGraphPtr>(origin_input);
|
auto sub_func_graph = GetValueNode<FuncGraphPtr>(origin_input);
|
||||||
auto mirror_sub_graph = CloneFuncGraph(sub_func_graph);
|
auto mirror_sub_graph = CloneFuncGraph(sub_func_graph, flags);
|
||||||
mirror_input = NewValueNode(mirror_sub_graph);
|
mirror_input = NewValueNode(mirror_sub_graph);
|
||||||
} else {
|
} else {
|
||||||
mirror_input = CloneParameterAndValueNode(cnode, i, mirror_graph);
|
mirror_input = CloneParameterAndValueNode(cnode, i, mirror_graph, flags);
|
||||||
}
|
}
|
||||||
if (mirror_input == nullptr) {
|
if (mirror_input == nullptr) {
|
||||||
MS_LOG(ERROR) << "node input cannot be found.";
|
MS_LOG(ERROR) << "node input cannot be found.";
|
||||||
|
@ -184,10 +181,11 @@ FuncGraphPtr CloneFuncGraph(const FuncGraphPtr &graph) {
|
||||||
}
|
}
|
||||||
return mirror_graph;
|
return mirror_graph;
|
||||||
}
|
}
|
||||||
|
} // namespace
|
||||||
|
|
||||||
STATUS ExportModel(const FuncGraphPtr &graph) {
|
STATUS ExportModel(const FuncGraphPtr &graph, const converter::Flags *flags) {
|
||||||
MS_ASSERT(graph != nullptr && flags != nullptr);
|
MS_ASSERT(graph != nullptr && flags != nullptr);
|
||||||
auto mirror_graph = CloneFuncGraph(graph);
|
auto mirror_graph = CloneFuncGraph(graph, flags);
|
||||||
if (mirror_graph == nullptr) {
|
if (mirror_graph == nullptr) {
|
||||||
MS_LOG(ERROR) << "Clone funcGraph failed.";
|
MS_LOG(ERROR) << "Clone funcGraph failed.";
|
||||||
return RET_ERROR;
|
return RET_ERROR;
|
||||||
|
@ -233,11 +231,5 @@ STATUS ExportModel(const FuncGraphPtr &graph) {
|
||||||
delete meta_graph;
|
delete meta_graph;
|
||||||
return status;
|
return status;
|
||||||
}
|
}
|
||||||
|
|
||||||
void ExportModelInit(converter::Flags *flag) {
|
|
||||||
MS_ASSERT(flag != nullptr);
|
|
||||||
flags = flag;
|
|
||||||
InitDumpGraphFunc(ExportModel);
|
|
||||||
}
|
|
||||||
} // namespace lite
|
} // namespace lite
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -18,10 +18,11 @@
|
||||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_EXPORT_MODEL_H
|
#define MINDSPORE_LITE_TOOLS_CONVERTER_EXPORT_MODEL_H
|
||||||
|
|
||||||
#include "tools/converter/converter_flags.h"
|
#include "tools/converter/converter_flags.h"
|
||||||
|
#include "ir/func_graph.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace lite {
|
namespace lite {
|
||||||
void ExportModelInit(converter::Flags *flag);
|
STATUS ExportModel(const FuncGraphPtr &graph, const converter::Flags *flags);
|
||||||
} // namespace lite
|
} // namespace lite
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
||||||
|
|
|
@ -16,8 +16,7 @@ set(REG_SRC ${CONVERT_REG_SRC}
|
||||||
${CORE_DIR}/utils/log_adapter.cc
|
${CORE_DIR}/utils/log_adapter.cc
|
||||||
${CORE_DIR}/utils/status.cc
|
${CORE_DIR}/utils/status.cc
|
||||||
${CORE_DIR}/gvar/log_adapter_common.cc
|
${CORE_DIR}/gvar/log_adapter_common.cc
|
||||||
${CORE_DIR}/gvar/logging_level.cc
|
${CORE_DIR}/gvar/logging_level.cc)
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/../dump_graph.cc)
|
|
||||||
set_property(SOURCE ${REG_SRC} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_LITE)
|
set_property(SOURCE ${REG_SRC} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_LITE)
|
||||||
add_library(mslite_converter_plugin SHARED ${REG_SRC})
|
add_library(mslite_converter_plugin SHARED ${REG_SRC})
|
||||||
target_link_libraries(mslite_converter_plugin mindspore::glog)
|
target_link_libraries(mslite_converter_plugin mindspore::glog)
|
||||||
|
|
|
@ -39,9 +39,9 @@ void RegPass(const std::string &pass_name, const opt::PassPtr &pass) {
|
||||||
|
|
||||||
PassRegistry::PassRegistry(const std::string &pass_name, const opt::PassPtr &pass) { RegPass(pass_name, pass); }
|
PassRegistry::PassRegistry(const std::string &pass_name, const opt::PassPtr &pass) { RegPass(pass_name, pass); }
|
||||||
|
|
||||||
PassRegistry::PassRegistry(PassPosition position, const std::vector<std::string> &assigned) {
|
PassRegistry::PassRegistry(PassPosition position, const std::vector<std::string> &names) {
|
||||||
std::unique_lock<std::mutex> lock(pass_mutex);
|
std::unique_lock<std::mutex> lock(pass_mutex);
|
||||||
external_assigned_passes[position] = assigned;
|
external_assigned_passes[position] = names;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<std::string> PassRegistry::GetOuterScheduleTask(PassPosition position) {
|
std::vector<std::string> PassRegistry::GetOuterScheduleTask(PassPosition position) {
|
||||||
|
|
|
@ -0,0 +1,43 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_DUMP_GRAPH_H_
|
||||||
|
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_DUMP_GRAPH_H_
|
||||||
|
|
||||||
|
#include "backend/optimizer/common/pass.h"
|
||||||
|
#include "tools/converter/export_model.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace opt {
|
||||||
|
class DumpGraph : public Pass {
|
||||||
|
public:
|
||||||
|
explicit DumpGraph(const converter::Flags *flags = nullptr) : Pass("DumpGraph"), flags_(flags) {}
|
||||||
|
~DumpGraph() = default;
|
||||||
|
bool Run(const FuncGraphPtr &graph) override {
|
||||||
|
MS_ASSERT(graph != nullptr);
|
||||||
|
if (lite::ExportModel(graph, flags_) != lite::RET_OK) {
|
||||||
|
MS_LOG(ERROR) << "dump graph failed.";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
const converter::Flags *flags_{nullptr};
|
||||||
|
};
|
||||||
|
} // namespace opt
|
||||||
|
} // namespace mindspore
|
||||||
|
|
||||||
|
#endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_DUMP_GRAPH_H_
|
Loading…
Reference in New Issue