diff --git a/mindspore/ccsrc/frontend/optimizer/irpass.cc b/mindspore/ccsrc/frontend/optimizer/irpass.cc index efc3795a4cc..4d74e38c842 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass.cc +++ b/mindspore/ccsrc/frontend/optimizer/irpass.cc @@ -64,7 +64,7 @@ OptimizeIRPassLib::OptimizeIRPassLib() { // ops eliminate item_tuple_eliminate_ = MakeSubstitution(std::make_shared(), "item_tuple_eliminate", - {prim::kPrimTupleGetItem, prim::kPrimTupleSetItem}); + {prim::kPrimTupleGetItem, prim::kPrimTupleSetItem, prim::kPrimListGetItem}); tile_eliminate_ = MakeSubstitution(std::make_shared(), "tile_eliminate", prim::kPrimTile); cast_eliminate_ = MakeSubstitution(std::make_shared(), "cast_eliminate", prim::kPrimCast); reshape_eliminate_ = MakeSubstitution(std::make_shared(), "reshape_eliminate", prim::kPrimReshape); diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/item_tuple_eliminate.h b/mindspore/ccsrc/frontend/optimizer/irpass/item_tuple_eliminate.h index acd6844ee74..6ae41eaa2a9 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass/item_tuple_eliminate.h +++ b/mindspore/ccsrc/frontend/optimizer/irpass/item_tuple_eliminate.h @@ -38,6 +38,7 @@ class GetitemEliminater : public AnfVisitor { AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { Reset(); AnfVisitor::Match(prim::kPrimTupleGetItem, {IsCNode, IsVNode})(node); + AnfVisitor::Match(prim::kPrimListGetItem, {IsCNode, IsVNode})(node); if (is_match_) { return tuple_->input(id_); @@ -46,14 +47,18 @@ class GetitemEliminater : public AnfVisitor { } void Visit(const CNodePtr &cnode) override { - if (IsPrimitiveCNode(cnode, prim::kPrimMakeTuple)) { + if (IsPrimitiveCNode(cnode, prim::kPrimMakeTuple) || IsPrimitiveCNode(cnode, prim::kPrimMakeList)) { tuple_ = cnode; } } void Visit(const ValueNodePtr &vnode) override { if (tuple_ != nullptr && IsValueNode(vnode)) { - id_ = IntToSize(GetValue(vnode->value()) + 1); + int idx = GetValue(vnode->value()); + if (idx < 0) { + idx = idx + tuple_->size() - 1; + } + id_ = IntToSize(idx + 1); if (tuple_->size() > id_) { is_match_ = true; } @@ -80,6 +85,7 @@ class GetitemConstEliminater : public AnfVisitor { AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { Reset(); AnfVisitor::Match(prim::kPrimTupleGetItem, {IsVNode, IsVNode})(node); + AnfVisitor::Match(prim::kPrimListGetItem, {IsVNode, IsVNode})(node); if (is_match_) { return NewValueNode((*tuple_)[id_]); @@ -138,7 +144,7 @@ class SetitemEliminater : public AnfVisitor { } void Visit(const CNodePtr &cnode) override { - if (IsPrimitiveCNode(cnode, prim::kPrimMakeTuple)) { + if (IsPrimitiveCNode(cnode, prim::kPrimMakeTuple) || IsPrimitiveCNode(cnode, prim::kPrimMakeList)) { auto &inputs = cnode->inputs(); (void)std::copy(inputs.begin(), inputs.end(), std::back_inserter(args_)); } @@ -234,6 +240,7 @@ class GetitemDependReorder : public AnfVisitor { AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { Reset(); AnfVisitor::Match(prim::kPrimTupleGetItem, {IsCNode, IsValueNode})(node); + AnfVisitor::Match(prim::kPrimListGetItem, {IsCNode, IsValueNode})(node); if (x_ == nullptr) { return nullptr; }