support index by negative number

This commit is contained in:
buxue 2021-02-24 17:46:23 +08:00
parent 2184bcff36
commit 610300335f
3 changed files with 165 additions and 28 deletions

View File

@ -72,7 +72,7 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
// ops eliminate
item_tuple_or_list_eliminate_ = MakeSubstitution(
std::make_shared<ItemTupleOrListEliminater>(), "item_tuple_or_list_eliminate",
std::make_shared<ItemTupleOrListEliminator>(), "item_tuple_or_list_eliminate",
{prim::kPrimTupleGetItem, prim::kPrimTupleSetItem, prim::kPrimListGetItem, prim::kPrimListSetItem});
tile_eliminate_ = MakeSubstitution(std::make_shared<TileEliminater>(), "tile_eliminate", prim::kPrimTile);
cast_eliminate_ = MakeSubstitution(std::make_shared<CastEliminater>(), "cast_eliminate", prim::kPrimCast);

View File

@ -30,11 +30,70 @@
namespace mindspore {
namespace opt {
namespace irpass {
// (a, b, c, ...)[-1] => (a, b, c, ...)[length-1]
// [a, b, c, ...][-1] => [a, b, c, ...][length-1]
// {prim::kPrimTupleGetItem, T, N}
// {prim::kPrimListGetItem, L, N}
// setitem((a, b, c, ...), -1, z) => setitem((a, b, c, ...), length - 1, z)
// setitem([a, b, c, ...], -1, z) => setitem([a, b, c, ...], length - 1, z)
// {prim::kPrimTupleSetItem, T, N, Z}
// {prim::kPrimListSetItem, L, N, Z}
class ConvertItemIndexToPositive : public AnfVisitor {
public:
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
Reset();
AnfVisitor::Match(prim::kPrimTupleGetItem, {IsCNode, IsVNode})(node);
AnfVisitor::Match(prim::kPrimListGetItem, {IsCNode, IsVNode})(node);
AnfVisitor::Match(prim::kPrimTupleSetItem, {IsCNode, IsVNode, IsNode})(node);
AnfVisitor::Match(prim::kPrimListSetItem, {IsCNode, IsVNode, IsNode})(node);
if (is_match_) {
node->cast<CNodePtr>()->set_input(2, NewValueNode(id_));
}
return nullptr;
}
void Visit(const AnfNodePtr &node) override {
if (is_match_) {
return;
}
AnfVisitor::Visit(node);
}
void Visit(const CNodePtr &cnode) override { sequeue_ = cnode; }
void Visit(const ValueNodePtr &vnode) override {
if (sequeue_ != nullptr && IsValueNode<Int64Imm>(vnode)) {
auto idx = GetValue<int64_t>(vnode->value());
if (idx < 0) {
auto sequeue_abstract = sequeue_->abstract()->cast<abstract::AbstractSequeuePtr>();
if (sequeue_abstract == nullptr) {
return;
}
id_ = idx + sequeue_abstract->size();
is_match_ = true;
}
}
}
void Reset() {
id_ = 0;
sequeue_ = nullptr;
is_match_ = false;
}
private:
bool is_match_{false};
int64_t id_{0};
CNodePtr sequeue_{nullptr};
};
// (a, b, c, ...)[0] => a
// (a, b, c, ...)[1] => b
// {prim::kPrimTupleGetItem, {prim::kPrimMakeTuple, Xs}, C}
// {prim::kPrimListGetItem, {prim::kPrimMakeList, Xs}, C}
class GetitemEliminater : public AnfVisitor {
class GetitemEliminator : public AnfVisitor {
public:
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
Reset();
@ -82,7 +141,7 @@ class GetitemEliminater : public AnfVisitor {
// (a, b, c, ...)[1] => b
// {prim::kPrimTupleGetItem, C1, C}
// {prim::kPrimListGetItem, C1, C}
class GetitemConstEliminater : public AnfVisitor {
class GetitemConstEliminator : public AnfVisitor {
public:
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
Reset();
@ -103,8 +162,12 @@ class GetitemConstEliminater : public AnfVisitor {
has_new_value_ = vnode->has_new_value();
}
if (tuple_ != nullptr && IsValueNode<Int64Imm>(vnode)) {
id_ = LongToSize(GetValue<int64_t>(vnode->value()));
if (tuple_->size() > id_) {
auto idx = GetValue<int64_t>(vnode->value());
if (idx < 0) {
idx = idx + tuple_->size();
}
id_ = LongToSize(idx);
if (id_ < tuple_->size()) {
is_match_ = true;
}
}
@ -127,7 +190,7 @@ class GetitemConstEliminater : public AnfVisitor {
// setitem((a, b, c, ...), 1, z) => (a, z, c, ...)
// {prim::kPrimTupleSetItem, {prim::kPrimMakeTuple, Xs}, C, Z}
// {prim::kPrimListSetItem, {prim::kPrimMakeList, Xs}, C, Z}
class SetitemEliminater : public AnfVisitor {
class SetitemEliminator : public AnfVisitor {
public:
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
Reset();
@ -159,8 +222,12 @@ class SetitemEliminater : public AnfVisitor {
}
void Visit(const ValueNodePtr &vnode) override {
if (args_.size() > 0 && IsValueNode<Int64Imm>(vnode)) {
id_ = LongToSize(GetValue<int64_t>(vnode->value()) + 1);
if (!args_.empty() && IsValueNode<Int64Imm>(vnode)) {
auto idx = GetValue<int64_t>(vnode->value());
if (idx < 0) {
idx = idx + args_.size() - 1;
}
id_ = LongToSize(idx + 1);
if (id_ < args_.size()) {
is_match_ = true;
}
@ -183,7 +250,7 @@ class SetitemEliminater : public AnfVisitor {
// {prim::kPrimTupleGetItem, {prim::kPrimTupleSetItem, Y, C1, X}, C2}
// {prim::kPrimListGetItem, {prim::kPrimListSetItem, Y, C1, X}, C2}
class GetSetitemEliminater : public AnfVisitor {
class GetSetitemEliminator : public AnfVisitor {
public:
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
Reset();
@ -217,8 +284,15 @@ class GetSetitemEliminater : public AnfVisitor {
}
void Visit(const ValueNodePtr &vnode) override {
if (IsValueNode<Int64Imm>(vnode)) {
if (tuple_ != nullptr && IsValueNode<Int64Imm>(vnode)) {
auto key = GetValue<int64_t>(vnode->value());
if (key < 0) {
auto sequeue_abstract = tuple_->abstract()->cast<abstract::AbstractSequeuePtr>();
if (sequeue_abstract == nullptr) {
return;
}
key = key + sequeue_abstract->size();
}
if (is_in_set_) {
key1_ = key;
} else {
@ -282,26 +356,28 @@ class GetitemDependReorder : public AnfVisitor {
AnfNodePtr x_{nullptr}, y_{nullptr}, c_{nullptr};
};
class ItemTupleOrListEliminater : public OptimizerCaller {
class ItemTupleOrListEliminator : public OptimizerCaller {
public:
ItemTupleOrListEliminater()
: get_item_eliminater_(std::make_shared<GetitemEliminater>()),
get_item_const_eliminater_(std::make_shared<GetitemConstEliminater>()),
set_item_eliminater_(std::make_shared<SetitemEliminater>()),
get_set_item_eliminater_(std::make_shared<GetSetitemEliminater>()),
get_item_depend_reorder_(std::make_shared<GetitemDependReorder>()) {
eliminaters_.emplace_back(get_item_eliminater_);
eliminaters_.emplace_back(get_item_const_eliminater_);
eliminaters_.emplace_back(set_item_eliminater_);
eliminaters_.emplace_back(get_set_item_eliminater_);
eliminaters_.emplace_back(get_item_depend_reorder_);
ItemTupleOrListEliminator()
: get_item_eliminator_(std::make_shared<GetitemEliminator>()),
get_item_const_eliminator_(std::make_shared<GetitemConstEliminator>()),
set_item_eliminator_(std::make_shared<SetitemEliminator>()),
get_set_item_eliminator_(std::make_shared<GetSetitemEliminator>()),
get_item_depend_reorder_(std::make_shared<GetitemDependReorder>()),
convert_item_index_to_positive_(std::make_shared<ConvertItemIndexToPositive>()) {
eliminators_.emplace_back(get_item_eliminator_);
eliminators_.emplace_back(get_item_const_eliminator_);
eliminators_.emplace_back(set_item_eliminator_);
eliminators_.emplace_back(get_set_item_eliminator_);
eliminators_.emplace_back(get_item_depend_reorder_);
eliminators_.emplace_back(convert_item_index_to_positive_);
}
~ItemTupleOrListEliminater() = default;
~ItemTupleOrListEliminator() = default;
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override {
AnfNodePtr new_node;
for (auto &eliminater : eliminaters_) {
new_node = (*eliminater)(optimizer, node);
for (auto &eliminator : eliminators_) {
new_node = (*eliminator)(optimizer, node);
if (new_node != nullptr) {
return new_node;
}
@ -310,10 +386,11 @@ class ItemTupleOrListEliminater : public OptimizerCaller {
}
private:
OptimizerCallerPtr get_item_eliminater_, get_item_const_eliminater_, set_item_eliminater_, get_set_item_eliminater_,
get_item_depend_reorder_;
std::vector<OptimizerCallerPtr> eliminaters_{};
OptimizerCallerPtr get_item_eliminator_, get_item_const_eliminator_, set_item_eliminator_, get_set_item_eliminator_,
get_item_depend_reorder_, convert_item_index_to_positive_;
std::vector<OptimizerCallerPtr> eliminators_{};
};
} // namespace irpass
} // namespace opt
} // namespace mindspore

View File

@ -0,0 +1,60 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test tuple index by negative number"""
import numpy as np
import pytest
from mindspore import nn
from mindspore import Tensor
from mindspore import context
from mindspore.ops import operations as P
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
def test_tuple_index_by_negative_number():
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.index = -1
self.split = P.Split(axis=0, output_num=4)
def construct(self, x):
out = self.split(x)
ret = [out[-1], out[-2], out[-3], out[-4]]
ret[-1] = 100
return ret
net = Net()
x = Tensor(np.ones((4, 2, 3)))
net(x)
def Ttest_tuple_index_by_negative_number_out_bound():
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.index = -1
self.split = P.Split(axis=0, output_num=2)
def construct(self, x):
out = self.split(x)
return out[-1], out[-2], out[-3]
net = Net()
x = Tensor(np.ones((2, 2, 3)))
with pytest.raises(IndexError) as err:
net(x)
assert "TupleGetItem evaluator index should be in range[-2, 2), but got -3" in str(err.value)