dump graph function adjustment
This commit is contained in:
parent
8572c4b110
commit
1bc9620f06
|
@ -35,6 +35,7 @@
|
|||
# Lite
|
||||
"mindspore/mindspore/core/mindrt/src/thread/" "useStlAlgorithm"
|
||||
"mindspore/mindspore/lite/test/" "syntaxError"
|
||||
"mindspore/mindspore/lite/test/ut/tools/converter/registry/pass_registry_test.cc" "unknownMacro"
|
||||
"mindspore/mindspore/lite/src/ops/unsqueeze.cc" "useStlAlgorithm"
|
||||
"mindspore/mindspore/lite/tools/common/flag_parser.cc" "useStlAlgorithm"
|
||||
"mindspore/mindspore/lite/tools/common/tensor_util.cc" "useStlAlgorithm"
|
||||
|
|
|
@ -330,8 +330,6 @@ elseif(WIN32)
|
|||
DESTINATION ${CONVERTER_ROOT_DIR}/include COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||
install(FILES ${TOP_DIR}/mindspore/lite/tools/converter/model_parser.h
|
||||
DESTINATION ${CONVERTER_ROOT_DIR}/include COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||
install(FILES ${TOP_DIR}/mindspore/lite/tools/converter/dump_graph.h
|
||||
DESTINATION ${CONVERTER_ROOT_DIR}/include COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||
install(FILES ${TOP_DIR}/mindspore/lite/tools/converter/ops/ops_def.h
|
||||
DESTINATION ${CONVERTER_ROOT_DIR}/include COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||
install(DIRECTORY ${TOP_DIR}/build/mindspore/schema/ DESTINATION ${CONVERTER_ROOT_DIR}/include/schema
|
||||
|
@ -462,8 +460,6 @@ else()
|
|||
DESTINATION ${CONVERTER_ROOT_DIR}/include COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||
install(FILES ${TOP_DIR}/mindspore/lite/tools/converter/model_parser.h
|
||||
DESTINATION ${CONVERTER_ROOT_DIR}/include COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||
install(FILES ${TOP_DIR}/mindspore/lite/tools/converter/dump_graph.h
|
||||
DESTINATION ${CONVERTER_ROOT_DIR}/include COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||
install(FILES ${TOP_DIR}/mindspore/lite/tools/converter/ops/ops_def.h
|
||||
DESTINATION ${CONVERTER_ROOT_DIR}/include COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||
install(DIRECTORY ${TOP_DIR}/mindspore/lite/build/schema/ DESTINATION ${CONVERTER_ROOT_DIR}/include/schema
|
||||
|
|
|
@ -48,8 +48,8 @@ class MS_API PassRegistry {
|
|||
/// \brief Constructor of PassRegistry to assign which passes are required for external extension.
|
||||
///
|
||||
/// \param[in] position Define the place where assigned passes will run.
|
||||
/// \param[in] assigned Define the names of the passes.
|
||||
PassRegistry(PassPosition position, const std::vector<std::string> &assigned);
|
||||
/// \param[in] names Define the names of the passes.
|
||||
PassRegistry(PassPosition position, const std::vector<std::string> &names);
|
||||
|
||||
/// \brief Destructor of PassRegistrar.
|
||||
~PassRegistry() = default;
|
||||
|
@ -79,9 +79,8 @@ class MS_API PassRegistry {
|
|||
/// \brief Defined assigning macro to assign Passes, which called by user directly.
|
||||
///
|
||||
/// \param[in] position Define the place where assigned passes will run.
|
||||
/// \param[in] assigned Define the names of the passes.
|
||||
#define REG_SCHEDULED_PASS(position, assigned) \
|
||||
static mindspore::registry::PassRegistry g_##position(position, assigned);
|
||||
/// \param[in] names Define the names of the passes.
|
||||
#define REG_SCHEDULED_PASS(position, names) static mindspore::registry::PassRegistry g_##position(position, names);
|
||||
} // namespace registry
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -201,7 +201,6 @@ if(MSLITE_ENABLE_CONVERTER)
|
|||
${LITE_DIR}/tools/converter/converter_flags.cc
|
||||
${LITE_DIR}/tools/converter/converter.cc
|
||||
${LITE_DIR}/tools/converter/export_model.cc
|
||||
${LITE_DIR}/tools/converter/dump_graph.cc
|
||||
${LITE_DIR}/tools/converter/optimizer_manager.cc
|
||||
${LITE_DIR}/tools/converter/parser/parser_utils.cc
|
||||
${LITE_DIR}/tools/optimizer/common/node_pass_extends.cc
|
||||
|
|
|
@ -51,7 +51,7 @@ namespace opt {
|
|||
// fuse add and add to addn.
|
||||
class Test1Fusion : public Pass {
|
||||
public:
|
||||
Test1Fusion() : Pass("test1_fusion") {}
|
||||
Test1Fusion() : Pass("Test1Fusion") {}
|
||||
bool CanFusion(const CNodePtr &cnode) {
|
||||
if (cnode == nullptr) {
|
||||
return false;
|
||||
|
@ -94,7 +94,7 @@ class Test1Fusion : public Pass {
|
|||
if (func_graph == nullptr) {
|
||||
return false;
|
||||
}
|
||||
auto manager = func_graph->manager();
|
||||
auto manager = Manage(func_graph);
|
||||
if (manager == nullptr) {
|
||||
return false;
|
||||
}
|
||||
|
@ -132,9 +132,9 @@ class Test1Fusion : public Pass {
|
|||
// convert addn to custom op
|
||||
class Test2Fusion : public Pass {
|
||||
public:
|
||||
Test2Fusion() : Pass("test2_fusion") {}
|
||||
Test2Fusion() : Pass("Test2Fusion") {}
|
||||
AnfNodePtr CreateCustomOp(const FuncGraphPtr func_graph, const CNodePtr &cnode) {
|
||||
if (cnode == nullptr) {
|
||||
if (func_graph == nullptr || cnode == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
auto primc = std::make_shared<ops::Custom>();
|
||||
|
@ -143,7 +143,7 @@ class Test2Fusion : public Pass {
|
|||
}
|
||||
primc->set_type("Custom_AddN");
|
||||
std::map<std::string, std::vector<uint8_t>> custom_attrs;
|
||||
std::string input_num = std::to_string(3);
|
||||
std::string input_num = std::to_string(cnode->size() - 1);
|
||||
std::vector<uint8_t> input_num_attr(input_num.begin(), input_num.end());
|
||||
custom_attrs["input_num"] = input_num_attr;
|
||||
std::string op_kind = "custom op";
|
||||
|
@ -162,7 +162,7 @@ class Test2Fusion : public Pass {
|
|||
if (func_graph == nullptr) {
|
||||
return false;
|
||||
}
|
||||
auto manager = func_graph->manager();
|
||||
auto manager = Manage(func_graph);
|
||||
if (manager == nullptr) {
|
||||
return false;
|
||||
}
|
||||
|
@ -185,45 +185,22 @@ class Test2Fusion : public Pass {
|
|||
}
|
||||
};
|
||||
|
||||
class TestFusion : public Pass {
|
||||
public:
|
||||
TestFusion() : Pass("test_fusion") {}
|
||||
bool Run(const FuncGraphPtr &func_graph) override {
|
||||
if (func_graph == nullptr) {
|
||||
return false;
|
||||
}
|
||||
auto manager = Manage(func_graph, true);
|
||||
if (manager == nullptr) {
|
||||
return false;
|
||||
}
|
||||
auto test1_fusion = std::make_shared<Test1Fusion>();
|
||||
if (!test1_fusion->Run(func_graph)) {
|
||||
return false;
|
||||
}
|
||||
auto test2_fusion = std::make_shared<Test2Fusion>();
|
||||
if (!test2_fusion->Run(func_graph)) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
};
|
||||
|
||||
REG_PASS(TestFusion, TestFusion)
|
||||
REG_SCHEDULED_PASS(POSITION_BEGIN, {"TestFusion"})
|
||||
REG_PASS(Test1Fusion, Test1Fusion)
|
||||
REG_PASS(Test2Fusion, Test2Fusion)
|
||||
const std::vector<std::string> schedule = {"Test1Fusion", "Test2Fusion"};
|
||||
REG_SCHEDULED_PASS(POSITION_BEGIN, schedule)
|
||||
} // namespace opt
|
||||
|
||||
TEST_F(PassRegistryTest, TestRegistry) {
|
||||
auto schedule_task = registry::PassRegistry::GetOuterScheduleTask(POSITION_BEGIN);
|
||||
ASSERT_EQ(schedule_task.size(), 1);
|
||||
ASSERT_EQ(schedule_task.size(), 2);
|
||||
auto passes = registry::PassRegistry::GetPassFromStoreRoom(schedule_task);
|
||||
ASSERT_EQ(passes.size(), 1);
|
||||
auto begin_pass = passes.front();
|
||||
ASSERT_NE(begin_pass, nullptr);
|
||||
auto begin_pass_test = std::dynamic_pointer_cast<opt::TestFusion>(begin_pass);
|
||||
ASSERT_NE(begin_pass_test, nullptr);
|
||||
ASSERT_EQ(passes.size(), 2);
|
||||
ASSERT_NE(func_graph_, nullptr);
|
||||
auto res = begin_pass_test->Run(func_graph_);
|
||||
ASSERT_EQ(res, true);
|
||||
for (auto &pass : passes) {
|
||||
auto ret = pass->Run(func_graph_);
|
||||
ASSERT_EQ(ret, true);
|
||||
}
|
||||
auto cnode_list = func_graph_->GetOrderedCnodes();
|
||||
ASSERT_EQ(cnode_list.size(), 2);
|
||||
bool is_custom = opt::CheckPrimitiveType(cnode_list.front(), prim::kPrimCustom);
|
||||
|
|
|
@ -59,6 +59,7 @@
|
|||
#include "tools/optimizer/graph/split_one_pass.h"
|
||||
#include "tools/optimizer/graph/decrease_transpose_algo.h"
|
||||
#include "tools/optimizer/graph/specify_graph_input_format.h"
|
||||
#include "tools/optimizer/graph/dump_graph.h"
|
||||
#include "tools/converter/quantizer/post_training_quantizer.h"
|
||||
#include "tools/converter/quantizer/quant_cast.h"
|
||||
#include "tools/converter/quantizer/weight_quantizer.h"
|
||||
|
@ -401,6 +402,7 @@ void AnfTransform::AppendPassToStoreRoom(const converter::Flags *config) {
|
|||
registry::PassRegistry("ToNHWCFormat", std::make_shared<opt::ToNHWCFormat>(fmk, is_train));
|
||||
registry::PassRegistry("SpecifyGraphInputFormat",
|
||||
std::make_shared<opt::SpecifyGraphInputFormat>(config->graphInputFormat));
|
||||
registry::PassRegistry("DumpGraph", std::make_shared<opt::DumpGraph>(config));
|
||||
}
|
||||
|
||||
FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &main_graph, const converter::Flags *config) {
|
||||
|
|
|
@ -26,7 +26,6 @@
|
|||
#include "src/train/train_populate_parameter.h"
|
||||
#include "include/registry/model_parser_registry.h"
|
||||
#include "src/common/dynamic_library_loader.h"
|
||||
#include "tools/converter/export_model.h"
|
||||
#include "tools/converter/parser/parser_utils.h"
|
||||
#include "tools/converter/import/mindspore_importer.h"
|
||||
namespace mindspore {
|
||||
|
@ -149,8 +148,6 @@ int RunConverter(int argc, const char **argv) {
|
|||
}
|
||||
return status;
|
||||
}
|
||||
// Init dump graph func
|
||||
ExportModelInit(flags.get());
|
||||
// Load graph
|
||||
MS_LOG(DEBUG) << "start reading model file";
|
||||
Converter cvt;
|
||||
|
|
|
@ -1,35 +0,0 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tools/converter/dump_graph.h"
|
||||
#include "tools/converter/dump_graph_init.h"
|
||||
#include "include/errorcode.h"
|
||||
#include "src/common/log_adapter.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
static GraphDumpFunc graph_dump_interface = nullptr;
|
||||
void InitDumpGraphFunc(const GraphDumpFunc &graph_dump_func) { graph_dump_interface = graph_dump_func; }
|
||||
|
||||
int DumpGraph(const FuncGraphPtr &func_graph) {
|
||||
if (graph_dump_interface == nullptr) {
|
||||
MS_LOG(ERROR) << "graph_dump_interface is nullptr, which is not init.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
return graph_dump_interface(func_graph);
|
||||
}
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
|
@ -1,31 +0,0 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_DUMP_GRAPH_H_
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_DUMP_GRAPH_H_
|
||||
|
||||
#include <memory>
|
||||
#include "include/lite_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
class FuncGraph;
|
||||
using FuncGraphPtr = std::shared_ptr<FuncGraph>;
|
||||
namespace lite {
|
||||
using GraphDumpFunc = std::function<int(const FuncGraphPtr &)>;
|
||||
int MS_API DumpGraph(const FuncGraphPtr &func_graph);
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_DUMP_GRAPH_H_
|
|
@ -1,28 +0,0 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_DUMP_GRAPH_INIT_H
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_DUMP_GRAPH_INIT_H
|
||||
|
||||
#include "tools/converter/dump_graph.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
void MS_API InitDumpGraphFunc(const GraphDumpFunc &graph_dump_func);
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_DUMP_GRAPH_INIT_H
|
|
@ -26,7 +26,6 @@
|
|||
#include "ir/func_graph.h"
|
||||
#include "tools/anf_exporter/anf_exporter.h"
|
||||
#include "tools/converter/graphdef_transform.h"
|
||||
#include "tools/converter/dump_graph_init.h"
|
||||
#include "tools/converter/optimizer_manager.h"
|
||||
#include "tools/optimizer/graph/control_flow_pass.h"
|
||||
|
||||
|
@ -34,9 +33,6 @@ namespace mindspore {
|
|||
namespace lite {
|
||||
namespace {
|
||||
using NodesMap = std::map<std::string, std::vector<AnfNodePtr>>;
|
||||
}
|
||||
static converter::Flags *flags = nullptr;
|
||||
|
||||
void CloneGraphInputs(const FuncGraphPtr &origin, const FuncGraphPtr &mirror, NodesMap *origin_map,
|
||||
NodesMap *mirror_map) {
|
||||
MS_ASSERT(origin != nullptr && mirror != nullptr);
|
||||
|
@ -53,7 +49,8 @@ void CloneGraphInputs(const FuncGraphPtr &origin, const FuncGraphPtr &mirror, No
|
|||
}
|
||||
}
|
||||
|
||||
AnfNodePtr CloneParameterAndValueNode(const CNodePtr &cnode, size_t index, const FuncGraphPtr &mirror_graph) {
|
||||
AnfNodePtr CloneParameterAndValueNode(const CNodePtr &cnode, size_t index, const FuncGraphPtr &mirror_graph,
|
||||
const converter::Flags *flags) {
|
||||
MS_ASSERT(cnode != nullptr && mirror_graph != nullptr);
|
||||
if (index >= cnode->size()) {
|
||||
MS_LOG(ERROR) << "input index out of range.";
|
||||
|
@ -131,7 +128,7 @@ PrimitivePtr ClonePrimitive(const CNodePtr &cnode) {
|
|||
return prim;
|
||||
}
|
||||
|
||||
FuncGraphPtr CloneFuncGraph(const FuncGraphPtr &graph) {
|
||||
FuncGraphPtr CloneFuncGraph(const FuncGraphPtr &graph, const converter::Flags *flags) {
|
||||
MS_ASSERT(graph != nullptr);
|
||||
auto mirror_graph = std::make_shared<FuncGraph>();
|
||||
mirror_graph->set_attrs(graph->attrs());
|
||||
|
@ -157,10 +154,10 @@ FuncGraphPtr CloneFuncGraph(const FuncGraphPtr &graph) {
|
|||
if (mirror_input == nullptr) {
|
||||
if (IsValueNode<FuncGraph>(origin_input)) {
|
||||
auto sub_func_graph = GetValueNode<FuncGraphPtr>(origin_input);
|
||||
auto mirror_sub_graph = CloneFuncGraph(sub_func_graph);
|
||||
auto mirror_sub_graph = CloneFuncGraph(sub_func_graph, flags);
|
||||
mirror_input = NewValueNode(mirror_sub_graph);
|
||||
} else {
|
||||
mirror_input = CloneParameterAndValueNode(cnode, i, mirror_graph);
|
||||
mirror_input = CloneParameterAndValueNode(cnode, i, mirror_graph, flags);
|
||||
}
|
||||
if (mirror_input == nullptr) {
|
||||
MS_LOG(ERROR) << "node input cannot be found.";
|
||||
|
@ -184,10 +181,11 @@ FuncGraphPtr CloneFuncGraph(const FuncGraphPtr &graph) {
|
|||
}
|
||||
return mirror_graph;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
STATUS ExportModel(const FuncGraphPtr &graph) {
|
||||
STATUS ExportModel(const FuncGraphPtr &graph, const converter::Flags *flags) {
|
||||
MS_ASSERT(graph != nullptr && flags != nullptr);
|
||||
auto mirror_graph = CloneFuncGraph(graph);
|
||||
auto mirror_graph = CloneFuncGraph(graph, flags);
|
||||
if (mirror_graph == nullptr) {
|
||||
MS_LOG(ERROR) << "Clone funcGraph failed.";
|
||||
return RET_ERROR;
|
||||
|
@ -233,11 +231,5 @@ STATUS ExportModel(const FuncGraphPtr &graph) {
|
|||
delete meta_graph;
|
||||
return status;
|
||||
}
|
||||
|
||||
void ExportModelInit(converter::Flags *flag) {
|
||||
MS_ASSERT(flag != nullptr);
|
||||
flags = flag;
|
||||
InitDumpGraphFunc(ExportModel);
|
||||
}
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -18,10 +18,11 @@
|
|||
#define MINDSPORE_LITE_TOOLS_CONVERTER_EXPORT_MODEL_H
|
||||
|
||||
#include "tools/converter/converter_flags.h"
|
||||
#include "ir/func_graph.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
void ExportModelInit(converter::Flags *flag);
|
||||
STATUS ExportModel(const FuncGraphPtr &graph, const converter::Flags *flags);
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -16,8 +16,7 @@ set(REG_SRC ${CONVERT_REG_SRC}
|
|||
${CORE_DIR}/utils/log_adapter.cc
|
||||
${CORE_DIR}/utils/status.cc
|
||||
${CORE_DIR}/gvar/log_adapter_common.cc
|
||||
${CORE_DIR}/gvar/logging_level.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../dump_graph.cc)
|
||||
${CORE_DIR}/gvar/logging_level.cc)
|
||||
set_property(SOURCE ${REG_SRC} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_LITE)
|
||||
add_library(mslite_converter_plugin SHARED ${REG_SRC})
|
||||
target_link_libraries(mslite_converter_plugin mindspore::glog)
|
||||
|
|
|
@ -39,9 +39,9 @@ void RegPass(const std::string &pass_name, const opt::PassPtr &pass) {
|
|||
|
||||
PassRegistry::PassRegistry(const std::string &pass_name, const opt::PassPtr &pass) { RegPass(pass_name, pass); }
|
||||
|
||||
PassRegistry::PassRegistry(PassPosition position, const std::vector<std::string> &assigned) {
|
||||
PassRegistry::PassRegistry(PassPosition position, const std::vector<std::string> &names) {
|
||||
std::unique_lock<std::mutex> lock(pass_mutex);
|
||||
external_assigned_passes[position] = assigned;
|
||||
external_assigned_passes[position] = names;
|
||||
}
|
||||
|
||||
std::vector<std::string> PassRegistry::GetOuterScheduleTask(PassPosition position) {
|
||||
|
|
|
@ -0,0 +1,43 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_DUMP_GRAPH_H_
|
||||
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_DUMP_GRAPH_H_
|
||||
|
||||
#include "backend/optimizer/common/pass.h"
|
||||
#include "tools/converter/export_model.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
class DumpGraph : public Pass {
|
||||
public:
|
||||
explicit DumpGraph(const converter::Flags *flags = nullptr) : Pass("DumpGraph"), flags_(flags) {}
|
||||
~DumpGraph() = default;
|
||||
bool Run(const FuncGraphPtr &graph) override {
|
||||
MS_ASSERT(graph != nullptr);
|
||||
if (lite::ExportModel(graph, flags_) != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "dump graph failed.";
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
private:
|
||||
const converter::Flags *flags_{nullptr};
|
||||
};
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_DUMP_GRAPH_H_
|
Loading…
Reference in New Issue