!27665 [MSLITE] fix resize parse bug

Merge pull request !27665 from Liu_Xuu/bug_1214_converter
This commit is contained in:
i-robot 2021-12-17 02:42:54 +00:00 committed by Gitee
commit 0922e772a7
3 changed files with 31 additions and 6 deletions

View File

@ -22,6 +22,7 @@
#include "ops/resize.h"
#include "include/errorcode.h"
#include "nnacl/op_base.h"
#include "tools/optimizer/common/gllo_utils.h"
namespace mindspore::lite {
namespace {
@ -232,6 +233,7 @@ STATUS ReplaceTransposeWithGraphInput(const FuncGraphPtr &func_graph, const CNod
STATUS AdjustStridedSlice(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
MS_ASSERT(func_graph != nullptr);
auto manager = func_graph->manager();
MS_ASSERT(cnode != nullptr);
if (!opt::CheckInputs(cnode)) {
MS_LOG(ERROR) << "input is invalid.";
@ -270,7 +272,6 @@ STATUS AdjustStridedSlice(const FuncGraphPtr &func_graph, const CNodePtr &cnode)
}
break;
}
auto inputs = cnode->inputs();
switch (cnode->inputs().size()) {
case opt::kInputSizeFour: {
std::vector<int32_t> axes;
@ -281,7 +282,7 @@ STATUS AdjustStridedSlice(const FuncGraphPtr &func_graph, const CNodePtr &cnode)
if (new_param_node == nullptr) {
MS_LOG(ERROR) << "new a parameter node failed.";
}
inputs.push_back(new_param_node);
manager->AddEdge(cnode, new_param_node);
// fall through
}
case opt::kInputSizeFive: {
@ -293,19 +294,19 @@ STATUS AdjustStridedSlice(const FuncGraphPtr &func_graph, const CNodePtr &cnode)
if (new_param_node == nullptr) {
MS_LOG(ERROR) << "new a parameter node failed.";
}
inputs.push_back(new_param_node);
manager->AddEdge(cnode, new_param_node);
break;
}
default:
MS_LOG(DEBUG) << "no need to adjust.";
return lite::RET_NO_CHANGE;
}
cnode->set_inputs(inputs);
return lite::RET_OK;
}
STATUS AdjustResize(const CNodePtr &cnode) {
STATUS AdjustResize(bool *need_update_manager, const CNodePtr &cnode) {
MS_ASSERT(cnode != nullptr);
MS_ASSERT(func_graph != nullptr);
MS_CHECK_TRUE_RET(!cnode->inputs().empty(), lite::RET_ERROR);
auto node = cnode->input(0);
MS_ASSERT(node != nullptr);
@ -321,6 +322,7 @@ STATUS AdjustResize(const CNodePtr &cnode) {
auto new_input = cnode->inputs();
new_input.erase(new_input.begin() + opt::kInputIndexTwo);
cnode->set_inputs(new_input);
*need_update_manager = true;
} else if (cnode->inputs().size() > opt::kInputSizeFour) {
std::vector<AnfNodePtr> new_resize_inputs;
new_resize_inputs.push_back(cnode->inputs()[0]);
@ -360,6 +362,7 @@ STATUS AdjustResize(const CNodePtr &cnode) {
}
new_resize_inputs.push_back(cnode->inputs()[shape_index]);
cnode->set_inputs(new_resize_inputs);
*need_update_manager = true;
}
return lite::RET_OK;
}
@ -374,6 +377,7 @@ bool OnnxInputAdjust::Adjust(const FuncGraphPtr &func_graph) {
}
auto node_list = TopoSort(func_graph->get_return());
int status = RET_OK;
bool need_update_manager = false;
for (auto &node : node_list) {
if (utils::isa<ParameterPtr>(node)) {
auto param_node = node->cast<ParameterPtr>();
@ -395,7 +399,7 @@ bool OnnxInputAdjust::Adjust(const FuncGraphPtr &func_graph) {
} else if (opt::CheckPrimitiveType(node, prim::kPrimStridedSlice)) {
status = AdjustStridedSlice(func_graph, cnode);
} else if (opt::CheckPrimitiveType(node, prim::kPrimResize)) {
status = AdjustResize(cnode);
status = AdjustResize(&need_update_manager, cnode);
} else {
continue;
}
@ -404,6 +408,9 @@ bool OnnxInputAdjust::Adjust(const FuncGraphPtr &func_graph) {
return false;
}
}
if (need_update_manager) {
mindspore::opt::UpdateManager(func_graph);
}
return true;
}
} // namespace mindspore::lite

View File

@ -20,6 +20,7 @@
#include <unordered_map>
#include <functional>
#include <string>
#include <set>
#include "base/float16.h"
#include "ops/fusion/conv2d_fusion.h"
#include "ops/transpose.h"
@ -31,6 +32,7 @@
#include "tools/converter/quant_param_holder.h"
#include "nnacl/op_base.h"
#include "src/common/log_util.h"
#include "tools/converter/parser/parser_utils.h"
namespace mindspore {
namespace opt {
@ -1263,5 +1265,19 @@ bool IsQuantParameterNode(const PrimitiveCPtr &prim) {
}
return false;
}
void UpdateManager(const FuncGraphPtr &func_graph) {
auto manager = func_graph->manager();
if (manager == nullptr) {
manager = Manage(func_graph, true);
} else {
manager->Clear();
manager->AddFuncGraph(func_graph, true);
}
std::set<FuncGraphPtr> all_func_graphs;
mindspore::lite::GetAllFuncGraph(func_graph, &all_func_graphs);
for (auto &one_func_graph : all_func_graphs) {
manager->AddFuncGraph(one_func_graph);
}
}
} // namespace opt
} // namespace mindspore

View File

@ -128,6 +128,8 @@ int GetDataTypeFromAnfNode(const AnfNodePtr &anf_node, TypeId *type_id);
bool IsQuantParameterNode(const PrimitiveCPtr &prim);
void UpdateManager(const FuncGraphPtr &func_graph);
template <const PrimitivePtr *prim = nullptr>
inline bool IsSpecifiedNode(const BaseRef &n) {
if (utils::isa<AnfNodePtr>(n)) {