From 36cc436b070338eba813f30e8f6b1025f08e4c18 Mon Sep 17 00:00:00 2001 From: Liu_Xuu Date: Tue, 14 Dec 2021 14:22:46 +0800 Subject: [PATCH] [MSLITE] fix resize parse bug 1215_06 --- .../parser/onnx/onnx_inputs_adjust.cc | 19 +++++++++++++------ .../lite/tools/optimizer/common/gllo_utils.cc | 16 ++++++++++++++++ .../lite/tools/optimizer/common/gllo_utils.h | 2 ++ 3 files changed, 31 insertions(+), 6 deletions(-) diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_inputs_adjust.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_inputs_adjust.cc index df355a1a3c5..43dbe41ef35 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_inputs_adjust.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_inputs_adjust.cc @@ -22,6 +22,7 @@ #include "ops/resize.h" #include "include/errorcode.h" #include "nnacl/op_base.h" +#include "tools/optimizer/common/gllo_utils.h" namespace mindspore::lite { namespace { @@ -232,6 +233,7 @@ STATUS ReplaceTransposeWithGraphInput(const FuncGraphPtr &func_graph, const CNod STATUS AdjustStridedSlice(const FuncGraphPtr &func_graph, const CNodePtr &cnode) { MS_ASSERT(func_graph != nullptr); + auto manager = func_graph->manager(); MS_ASSERT(cnode != nullptr); if (!opt::CheckInputs(cnode)) { MS_LOG(ERROR) << "input is invalid."; @@ -270,7 +272,6 @@ STATUS AdjustStridedSlice(const FuncGraphPtr &func_graph, const CNodePtr &cnode) } break; } - auto inputs = cnode->inputs(); switch (cnode->inputs().size()) { case opt::kInputSizeFour: { std::vector axes; @@ -281,7 +282,7 @@ STATUS AdjustStridedSlice(const FuncGraphPtr &func_graph, const CNodePtr &cnode) if (new_param_node == nullptr) { MS_LOG(ERROR) << "new a parameter node failed."; } - inputs.push_back(new_param_node); + manager->AddEdge(cnode, new_param_node); // fall through } case opt::kInputSizeFive: { @@ -293,19 +294,19 @@ STATUS AdjustStridedSlice(const FuncGraphPtr &func_graph, const CNodePtr &cnode) if (new_param_node == nullptr) { MS_LOG(ERROR) << "new a parameter node failed."; } - inputs.push_back(new_param_node); + manager->AddEdge(cnode, new_param_node); break; } default: MS_LOG(DEBUG) << "no need to adjust."; return lite::RET_NO_CHANGE; } - cnode->set_inputs(inputs); return lite::RET_OK; } -STATUS AdjustResize(const CNodePtr &cnode) { +STATUS AdjustResize(bool *need_update_manager, const CNodePtr &cnode) { MS_ASSERT(cnode != nullptr); + MS_ASSERT(func_graph != nullptr); MS_CHECK_TRUE_RET(!cnode->inputs().empty(), lite::RET_ERROR); auto node = cnode->input(0); MS_ASSERT(node != nullptr); @@ -321,6 +322,7 @@ STATUS AdjustResize(const CNodePtr &cnode) { auto new_input = cnode->inputs(); new_input.erase(new_input.begin() + opt::kInputIndexTwo); cnode->set_inputs(new_input); + *need_update_manager = true; } else if (cnode->inputs().size() > opt::kInputSizeFour) { std::vector new_resize_inputs; new_resize_inputs.push_back(cnode->inputs()[0]); @@ -360,6 +362,7 @@ STATUS AdjustResize(const CNodePtr &cnode) { } new_resize_inputs.push_back(cnode->inputs()[shape_index]); cnode->set_inputs(new_resize_inputs); + *need_update_manager = true; } return lite::RET_OK; } @@ -374,6 +377,7 @@ bool OnnxInputAdjust::Adjust(const FuncGraphPtr &func_graph) { } auto node_list = TopoSort(func_graph->get_return()); int status = RET_OK; + bool need_update_manager = false; for (auto &node : node_list) { if (utils::isa(node)) { auto param_node = node->cast(); @@ -395,7 +399,7 @@ bool OnnxInputAdjust::Adjust(const FuncGraphPtr &func_graph) { } else if (opt::CheckPrimitiveType(node, prim::kPrimStridedSlice)) { status = AdjustStridedSlice(func_graph, cnode); } else if (opt::CheckPrimitiveType(node, prim::kPrimResize)) { - status = AdjustResize(cnode); + status = AdjustResize(&need_update_manager, cnode); } else { continue; } @@ -404,6 +408,9 @@ bool OnnxInputAdjust::Adjust(const FuncGraphPtr &func_graph) { return false; } } + if (need_update_manager) { + mindspore::opt::UpdateManager(func_graph); + } return true; } } // namespace mindspore::lite diff --git a/mindspore/lite/tools/optimizer/common/gllo_utils.cc b/mindspore/lite/tools/optimizer/common/gllo_utils.cc index db275cd895e..26db7644f12 100644 --- a/mindspore/lite/tools/optimizer/common/gllo_utils.cc +++ b/mindspore/lite/tools/optimizer/common/gllo_utils.cc @@ -20,6 +20,7 @@ #include #include #include +#include #include "base/float16.h" #include "ops/fusion/conv2d_fusion.h" #include "ops/transpose.h" @@ -31,6 +32,7 @@ #include "tools/converter/quant_param_holder.h" #include "nnacl/op_base.h" #include "src/common/log_util.h" +#include "tools/converter/parser/parser_utils.h" namespace mindspore { namespace opt { @@ -1263,5 +1265,19 @@ bool IsQuantParameterNode(const PrimitiveCPtr &prim) { } return false; } +void UpdateManager(const FuncGraphPtr &func_graph) { + auto manager = func_graph->manager(); + if (manager == nullptr) { + manager = Manage(func_graph, true); + } else { + manager->Clear(); + manager->AddFuncGraph(func_graph, true); + } + std::set all_func_graphs; + mindspore::lite::GetAllFuncGraph(func_graph, &all_func_graphs); + for (auto &one_func_graph : all_func_graphs) { + manager->AddFuncGraph(one_func_graph); + } +} } // namespace opt } // namespace mindspore diff --git a/mindspore/lite/tools/optimizer/common/gllo_utils.h b/mindspore/lite/tools/optimizer/common/gllo_utils.h index 65bdd74c034..003c001d836 100644 --- a/mindspore/lite/tools/optimizer/common/gllo_utils.h +++ b/mindspore/lite/tools/optimizer/common/gllo_utils.h @@ -128,6 +128,8 @@ int GetDataTypeFromAnfNode(const AnfNodePtr &anf_node, TypeId *type_id); bool IsQuantParameterNode(const PrimitiveCPtr &prim); +void UpdateManager(const FuncGraphPtr &func_graph); + template inline bool IsSpecifiedNode(const BaseRef &n) { if (utils::isa(n)) {