From 610300335f313967cd123686db256e10936389d1 Mon Sep 17 00:00:00 2001 From: buxue Date: Wed, 24 Feb 2021 17:46:23 +0800 Subject: [PATCH] support index by negative number --- mindspore/ccsrc/frontend/optimizer/irpass.cc | 2 +- .../irpass/item_tuple_or_list_eliminate.h | 131 ++++++++++++++---- .../parse/test_tuple_index_by_negative.py | 60 ++++++++ 3 files changed, 165 insertions(+), 28 deletions(-) create mode 100644 tests/ut/python/pipeline/parse/test_tuple_index_by_negative.py diff --git a/mindspore/ccsrc/frontend/optimizer/irpass.cc b/mindspore/ccsrc/frontend/optimizer/irpass.cc index 8b718f73a31..077dfe38f88 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass.cc +++ b/mindspore/ccsrc/frontend/optimizer/irpass.cc @@ -72,7 +72,7 @@ OptimizeIRPassLib::OptimizeIRPassLib() { // ops eliminate item_tuple_or_list_eliminate_ = MakeSubstitution( - std::make_shared(), "item_tuple_or_list_eliminate", + std::make_shared(), "item_tuple_or_list_eliminate", {prim::kPrimTupleGetItem, prim::kPrimTupleSetItem, prim::kPrimListGetItem, prim::kPrimListSetItem}); tile_eliminate_ = MakeSubstitution(std::make_shared(), "tile_eliminate", prim::kPrimTile); cast_eliminate_ = MakeSubstitution(std::make_shared(), "cast_eliminate", prim::kPrimCast); diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/item_tuple_or_list_eliminate.h b/mindspore/ccsrc/frontend/optimizer/irpass/item_tuple_or_list_eliminate.h index d09cf1cf002..a91d6348263 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass/item_tuple_or_list_eliminate.h +++ b/mindspore/ccsrc/frontend/optimizer/irpass/item_tuple_or_list_eliminate.h @@ -30,11 +30,70 @@ namespace mindspore { namespace opt { namespace irpass { +// (a, b, c, ...)[-1] => (a, b, c, ...)[length-1] +// [a, b, c, ...][-1] => [a, b, c, ...][length-1] +// {prim::kPrimTupleGetItem, T, N} +// {prim::kPrimListGetItem, L, N} +// setitem((a, b, c, ...), -1, z) => setitem((a, b, c, ...), length - 1, z) +// setitem([a, b, c, ...], -1, z) => setitem([a, b, c, ...], length - 1, z) +// {prim::kPrimTupleSetItem, T, N, Z} +// {prim::kPrimListSetItem, L, N, Z} +class ConvertItemIndexToPositive : public AnfVisitor { + public: + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { + Reset(); + AnfVisitor::Match(prim::kPrimTupleGetItem, {IsCNode, IsVNode})(node); + AnfVisitor::Match(prim::kPrimListGetItem, {IsCNode, IsVNode})(node); + AnfVisitor::Match(prim::kPrimTupleSetItem, {IsCNode, IsVNode, IsNode})(node); + AnfVisitor::Match(prim::kPrimListSetItem, {IsCNode, IsVNode, IsNode})(node); + + if (is_match_) { + node->cast()->set_input(2, NewValueNode(id_)); + } + return nullptr; + } + + void Visit(const AnfNodePtr &node) override { + if (is_match_) { + return; + } + + AnfVisitor::Visit(node); + } + + void Visit(const CNodePtr &cnode) override { sequeue_ = cnode; } + + void Visit(const ValueNodePtr &vnode) override { + if (sequeue_ != nullptr && IsValueNode(vnode)) { + auto idx = GetValue(vnode->value()); + if (idx < 0) { + auto sequeue_abstract = sequeue_->abstract()->cast(); + if (sequeue_abstract == nullptr) { + return; + } + id_ = idx + sequeue_abstract->size(); + is_match_ = true; + } + } + } + + void Reset() { + id_ = 0; + sequeue_ = nullptr; + is_match_ = false; + } + + private: + bool is_match_{false}; + int64_t id_{0}; + CNodePtr sequeue_{nullptr}; +}; + // (a, b, c, ...)[0] => a // (a, b, c, ...)[1] => b // {prim::kPrimTupleGetItem, {prim::kPrimMakeTuple, Xs}, C} // {prim::kPrimListGetItem, {prim::kPrimMakeList, Xs}, C} -class GetitemEliminater : public AnfVisitor { +class GetitemEliminator : public AnfVisitor { public: AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { Reset(); @@ -82,7 +141,7 @@ class GetitemEliminater : public AnfVisitor { // (a, b, c, ...)[1] => b // {prim::kPrimTupleGetItem, C1, C} // {prim::kPrimListGetItem, C1, C} -class GetitemConstEliminater : public AnfVisitor { +class GetitemConstEliminator : public AnfVisitor { public: AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { Reset(); @@ -103,8 +162,12 @@ class GetitemConstEliminater : public AnfVisitor { has_new_value_ = vnode->has_new_value(); } if (tuple_ != nullptr && IsValueNode(vnode)) { - id_ = LongToSize(GetValue(vnode->value())); - if (tuple_->size() > id_) { + auto idx = GetValue(vnode->value()); + if (idx < 0) { + idx = idx + tuple_->size(); + } + id_ = LongToSize(idx); + if (id_ < tuple_->size()) { is_match_ = true; } } @@ -127,7 +190,7 @@ class GetitemConstEliminater : public AnfVisitor { // setitem((a, b, c, ...), 1, z) => (a, z, c, ...) // {prim::kPrimTupleSetItem, {prim::kPrimMakeTuple, Xs}, C, Z} // {prim::kPrimListSetItem, {prim::kPrimMakeList, Xs}, C, Z} -class SetitemEliminater : public AnfVisitor { +class SetitemEliminator : public AnfVisitor { public: AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { Reset(); @@ -159,8 +222,12 @@ class SetitemEliminater : public AnfVisitor { } void Visit(const ValueNodePtr &vnode) override { - if (args_.size() > 0 && IsValueNode(vnode)) { - id_ = LongToSize(GetValue(vnode->value()) + 1); + if (!args_.empty() && IsValueNode(vnode)) { + auto idx = GetValue(vnode->value()); + if (idx < 0) { + idx = idx + args_.size() - 1; + } + id_ = LongToSize(idx + 1); if (id_ < args_.size()) { is_match_ = true; } @@ -183,7 +250,7 @@ class SetitemEliminater : public AnfVisitor { // {prim::kPrimTupleGetItem, {prim::kPrimTupleSetItem, Y, C1, X}, C2} // {prim::kPrimListGetItem, {prim::kPrimListSetItem, Y, C1, X}, C2} -class GetSetitemEliminater : public AnfVisitor { +class GetSetitemEliminator : public AnfVisitor { public: AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { Reset(); @@ -217,8 +284,15 @@ class GetSetitemEliminater : public AnfVisitor { } void Visit(const ValueNodePtr &vnode) override { - if (IsValueNode(vnode)) { + if (tuple_ != nullptr && IsValueNode(vnode)) { auto key = GetValue(vnode->value()); + if (key < 0) { + auto sequeue_abstract = tuple_->abstract()->cast(); + if (sequeue_abstract == nullptr) { + return; + } + key = key + sequeue_abstract->size(); + } if (is_in_set_) { key1_ = key; } else { @@ -282,26 +356,28 @@ class GetitemDependReorder : public AnfVisitor { AnfNodePtr x_{nullptr}, y_{nullptr}, c_{nullptr}; }; -class ItemTupleOrListEliminater : public OptimizerCaller { +class ItemTupleOrListEliminator : public OptimizerCaller { public: - ItemTupleOrListEliminater() - : get_item_eliminater_(std::make_shared()), - get_item_const_eliminater_(std::make_shared()), - set_item_eliminater_(std::make_shared()), - get_set_item_eliminater_(std::make_shared()), - get_item_depend_reorder_(std::make_shared()) { - eliminaters_.emplace_back(get_item_eliminater_); - eliminaters_.emplace_back(get_item_const_eliminater_); - eliminaters_.emplace_back(set_item_eliminater_); - eliminaters_.emplace_back(get_set_item_eliminater_); - eliminaters_.emplace_back(get_item_depend_reorder_); + ItemTupleOrListEliminator() + : get_item_eliminator_(std::make_shared()), + get_item_const_eliminator_(std::make_shared()), + set_item_eliminator_(std::make_shared()), + get_set_item_eliminator_(std::make_shared()), + get_item_depend_reorder_(std::make_shared()), + convert_item_index_to_positive_(std::make_shared()) { + eliminators_.emplace_back(get_item_eliminator_); + eliminators_.emplace_back(get_item_const_eliminator_); + eliminators_.emplace_back(set_item_eliminator_); + eliminators_.emplace_back(get_set_item_eliminator_); + eliminators_.emplace_back(get_item_depend_reorder_); + eliminators_.emplace_back(convert_item_index_to_positive_); } - ~ItemTupleOrListEliminater() = default; + ~ItemTupleOrListEliminator() = default; AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override { AnfNodePtr new_node; - for (auto &eliminater : eliminaters_) { - new_node = (*eliminater)(optimizer, node); + for (auto &eliminator : eliminators_) { + new_node = (*eliminator)(optimizer, node); if (new_node != nullptr) { return new_node; } @@ -310,10 +386,11 @@ class ItemTupleOrListEliminater : public OptimizerCaller { } private: - OptimizerCallerPtr get_item_eliminater_, get_item_const_eliminater_, set_item_eliminater_, get_set_item_eliminater_, - get_item_depend_reorder_; - std::vector eliminaters_{}; + OptimizerCallerPtr get_item_eliminator_, get_item_const_eliminator_, set_item_eliminator_, get_set_item_eliminator_, + get_item_depend_reorder_, convert_item_index_to_positive_; + std::vector eliminators_{}; }; + } // namespace irpass } // namespace opt } // namespace mindspore diff --git a/tests/ut/python/pipeline/parse/test_tuple_index_by_negative.py b/tests/ut/python/pipeline/parse/test_tuple_index_by_negative.py new file mode 100644 index 00000000000..79afb872e07 --- /dev/null +++ b/tests/ut/python/pipeline/parse/test_tuple_index_by_negative.py @@ -0,0 +1,60 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" test tuple index by negative number""" +import numpy as np +import pytest + +from mindspore import nn +from mindspore import Tensor +from mindspore import context +from mindspore.ops import operations as P + +context.set_context(mode=context.GRAPH_MODE, save_graphs=True) + + +def test_tuple_index_by_negative_number(): + class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.index = -1 + self.split = P.Split(axis=0, output_num=4) + + def construct(self, x): + out = self.split(x) + ret = [out[-1], out[-2], out[-3], out[-4]] + ret[-1] = 100 + return ret + + net = Net() + x = Tensor(np.ones((4, 2, 3))) + net(x) + + +def Ttest_tuple_index_by_negative_number_out_bound(): + class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.index = -1 + self.split = P.Split(axis=0, output_num=2) + + def construct(self, x): + out = self.split(x) + return out[-1], out[-2], out[-3] + + net = Net() + x = Tensor(np.ones((2, 2, 3))) + with pytest.raises(IndexError) as err: + net(x) + assert "TupleGetItem evaluator index should be in range[-2, 2), but got -3" in str(err.value)