From 626c40fdfda8e412cd0eceacbcf3495c1252c993 Mon Sep 17 00:00:00 2001 From: Margaret_wangrui Date: Wed, 15 May 2024 09:31:03 +0800 Subject: [PATCH] Fixed the bug of ListToTupleEliminator and TupleToListEliminator. --- .../irpass/seqence_to_sequence_op_eliminate.h | 40 +++++-------------- 1 file changed, 10 insertions(+), 30 deletions(-) diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/seqence_to_sequence_op_eliminate.h b/mindspore/ccsrc/frontend/optimizer/irpass/seqence_to_sequence_op_eliminate.h index a44511648ae..bb2d7a85915 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass/seqence_to_sequence_op_eliminate.h +++ b/mindspore/ccsrc/frontend/optimizer/irpass/seqence_to_sequence_op_eliminate.h @@ -37,13 +37,14 @@ namespace irpass { class ListToTupleEliminator : public AnfVisitor { public: AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { - AnfVisitor::Match(prim::kPrimListToTuple, {IsCNode})(node); + if (!IsPrimitiveCNode(node, prim::kPrimListToTuple)) { + return nullptr; + } auto fg = node->func_graph(); if (fg != nullptr) { + auto real_node = node->cast()->input(1); + MS_EXCEPTION_IF_NULL(real_node); std::vector args_{NewValueNode(prim::kPrimMakeTuple)}; - if (real_node == nullptr) { - return nullptr; - } MS_EXCEPTION_IF_NULL(real_node->abstract()); auto input_abs = real_node->abstract()->cast(); MS_EXCEPTION_IF_NULL(input_abs); @@ -60,30 +61,20 @@ class ListToTupleEliminator : public AnfVisitor { } return nullptr; } - - void Visit(const CNodePtr &cnode) override { - real_node = cnode; - while (IsPrimitiveCNode(real_node, prim::kPrimDepend)) { - auto depend = real_node->cast(); - real_node = depend->input(1)->cast(); - } - } - - private: - CNodePtr real_node{nullptr}; }; // {prim::kPrimTupleToList, data} => {prim::kPrimMakeList, {prim::kPrimTupleGetItem, data, 0}, ...} class TupleToListEliminator : public AnfVisitor { public: AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { - AnfVisitor::Match(prim::kPrimTupleToList, {IsCNode})(node); + if (!IsPrimitiveCNode(node, prim::kPrimTupleToList)) { + return nullptr; + } auto fg = node->func_graph(); if (fg != nullptr) { + auto real_node = node->cast()->input(1); + MS_EXCEPTION_IF_NULL(real_node); std::vector args_{NewValueNode(prim::kPrimMakeList)}; - if (real_node == nullptr) { - return nullptr; - } MS_EXCEPTION_IF_NULL(real_node->abstract()); auto input_abs = real_node->abstract()->cast(); MS_EXCEPTION_IF_NULL(input_abs); @@ -100,17 +91,6 @@ class TupleToListEliminator : public AnfVisitor { } return nullptr; } - - void Visit(const CNodePtr &cnode) override { - real_node = cnode; - while (IsPrimitiveCNode(real_node, prim::kPrimDepend)) { - auto depend = real_node->cast(); - real_node = depend->input(1)->cast(); - } - } - - private: - CNodePtr real_node{nullptr}; }; } // namespace irpass } // namespace opt