forked from mindspore-Ecosystem/mindspore
!27665 [MSLITE] fix resize parse bug
Merge pull request !27665 from Liu_Xuu/bug_1214_converter
This commit is contained in:
commit
0922e772a7
|
@ -22,6 +22,7 @@
|
|||
#include "ops/resize.h"
|
||||
#include "include/errorcode.h"
|
||||
#include "nnacl/op_base.h"
|
||||
#include "tools/optimizer/common/gllo_utils.h"
|
||||
|
||||
namespace mindspore::lite {
|
||||
namespace {
|
||||
|
@ -232,6 +233,7 @@ STATUS ReplaceTransposeWithGraphInput(const FuncGraphPtr &func_graph, const CNod
|
|||
|
||||
STATUS AdjustStridedSlice(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
|
||||
MS_ASSERT(func_graph != nullptr);
|
||||
auto manager = func_graph->manager();
|
||||
MS_ASSERT(cnode != nullptr);
|
||||
if (!opt::CheckInputs(cnode)) {
|
||||
MS_LOG(ERROR) << "input is invalid.";
|
||||
|
@ -270,7 +272,6 @@ STATUS AdjustStridedSlice(const FuncGraphPtr &func_graph, const CNodePtr &cnode)
|
|||
}
|
||||
break;
|
||||
}
|
||||
auto inputs = cnode->inputs();
|
||||
switch (cnode->inputs().size()) {
|
||||
case opt::kInputSizeFour: {
|
||||
std::vector<int32_t> axes;
|
||||
|
@ -281,7 +282,7 @@ STATUS AdjustStridedSlice(const FuncGraphPtr &func_graph, const CNodePtr &cnode)
|
|||
if (new_param_node == nullptr) {
|
||||
MS_LOG(ERROR) << "new a parameter node failed.";
|
||||
}
|
||||
inputs.push_back(new_param_node);
|
||||
manager->AddEdge(cnode, new_param_node);
|
||||
// fall through
|
||||
}
|
||||
case opt::kInputSizeFive: {
|
||||
|
@ -293,19 +294,19 @@ STATUS AdjustStridedSlice(const FuncGraphPtr &func_graph, const CNodePtr &cnode)
|
|||
if (new_param_node == nullptr) {
|
||||
MS_LOG(ERROR) << "new a parameter node failed.";
|
||||
}
|
||||
inputs.push_back(new_param_node);
|
||||
manager->AddEdge(cnode, new_param_node);
|
||||
break;
|
||||
}
|
||||
default:
|
||||
MS_LOG(DEBUG) << "no need to adjust.";
|
||||
return lite::RET_NO_CHANGE;
|
||||
}
|
||||
cnode->set_inputs(inputs);
|
||||
return lite::RET_OK;
|
||||
}
|
||||
|
||||
STATUS AdjustResize(const CNodePtr &cnode) {
|
||||
STATUS AdjustResize(bool *need_update_manager, const CNodePtr &cnode) {
|
||||
MS_ASSERT(cnode != nullptr);
|
||||
MS_ASSERT(func_graph != nullptr);
|
||||
MS_CHECK_TRUE_RET(!cnode->inputs().empty(), lite::RET_ERROR);
|
||||
auto node = cnode->input(0);
|
||||
MS_ASSERT(node != nullptr);
|
||||
|
@ -321,6 +322,7 @@ STATUS AdjustResize(const CNodePtr &cnode) {
|
|||
auto new_input = cnode->inputs();
|
||||
new_input.erase(new_input.begin() + opt::kInputIndexTwo);
|
||||
cnode->set_inputs(new_input);
|
||||
*need_update_manager = true;
|
||||
} else if (cnode->inputs().size() > opt::kInputSizeFour) {
|
||||
std::vector<AnfNodePtr> new_resize_inputs;
|
||||
new_resize_inputs.push_back(cnode->inputs()[0]);
|
||||
|
@ -360,6 +362,7 @@ STATUS AdjustResize(const CNodePtr &cnode) {
|
|||
}
|
||||
new_resize_inputs.push_back(cnode->inputs()[shape_index]);
|
||||
cnode->set_inputs(new_resize_inputs);
|
||||
*need_update_manager = true;
|
||||
}
|
||||
return lite::RET_OK;
|
||||
}
|
||||
|
@ -374,6 +377,7 @@ bool OnnxInputAdjust::Adjust(const FuncGraphPtr &func_graph) {
|
|||
}
|
||||
auto node_list = TopoSort(func_graph->get_return());
|
||||
int status = RET_OK;
|
||||
bool need_update_manager = false;
|
||||
for (auto &node : node_list) {
|
||||
if (utils::isa<ParameterPtr>(node)) {
|
||||
auto param_node = node->cast<ParameterPtr>();
|
||||
|
@ -395,7 +399,7 @@ bool OnnxInputAdjust::Adjust(const FuncGraphPtr &func_graph) {
|
|||
} else if (opt::CheckPrimitiveType(node, prim::kPrimStridedSlice)) {
|
||||
status = AdjustStridedSlice(func_graph, cnode);
|
||||
} else if (opt::CheckPrimitiveType(node, prim::kPrimResize)) {
|
||||
status = AdjustResize(cnode);
|
||||
status = AdjustResize(&need_update_manager, cnode);
|
||||
} else {
|
||||
continue;
|
||||
}
|
||||
|
@ -404,6 +408,9 @@ bool OnnxInputAdjust::Adjust(const FuncGraphPtr &func_graph) {
|
|||
return false;
|
||||
}
|
||||
}
|
||||
if (need_update_manager) {
|
||||
mindspore::opt::UpdateManager(func_graph);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
} // namespace mindspore::lite
|
||||
|
|
|
@ -20,6 +20,7 @@
|
|||
#include <unordered_map>
|
||||
#include <functional>
|
||||
#include <string>
|
||||
#include <set>
|
||||
#include "base/float16.h"
|
||||
#include "ops/fusion/conv2d_fusion.h"
|
||||
#include "ops/transpose.h"
|
||||
|
@ -31,6 +32,7 @@
|
|||
#include "tools/converter/quant_param_holder.h"
|
||||
#include "nnacl/op_base.h"
|
||||
#include "src/common/log_util.h"
|
||||
#include "tools/converter/parser/parser_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
|
@ -1263,5 +1265,19 @@ bool IsQuantParameterNode(const PrimitiveCPtr &prim) {
|
|||
}
|
||||
return false;
|
||||
}
|
||||
void UpdateManager(const FuncGraphPtr &func_graph) {
|
||||
auto manager = func_graph->manager();
|
||||
if (manager == nullptr) {
|
||||
manager = Manage(func_graph, true);
|
||||
} else {
|
||||
manager->Clear();
|
||||
manager->AddFuncGraph(func_graph, true);
|
||||
}
|
||||
std::set<FuncGraphPtr> all_func_graphs;
|
||||
mindspore::lite::GetAllFuncGraph(func_graph, &all_func_graphs);
|
||||
for (auto &one_func_graph : all_func_graphs) {
|
||||
manager->AddFuncGraph(one_func_graph);
|
||||
}
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -128,6 +128,8 @@ int GetDataTypeFromAnfNode(const AnfNodePtr &anf_node, TypeId *type_id);
|
|||
|
||||
bool IsQuantParameterNode(const PrimitiveCPtr &prim);
|
||||
|
||||
void UpdateManager(const FuncGraphPtr &func_graph);
|
||||
|
||||
template <const PrimitivePtr *prim = nullptr>
|
||||
inline bool IsSpecifiedNode(const BaseRef &n) {
|
||||
if (utils::isa<AnfNodePtr>(n)) {
|
||||
|
|
Loading…
Reference in New Issue