From 50ee325b968b872cb566119a226ce57d0bb474b9 Mon Sep 17 00:00:00 2001 From: buxue Date: Fri, 12 Mar 2021 11:10:28 +0800 Subject: [PATCH] support item negative index bprop --- .../irpass/item_tuple_or_list_eliminate.h | 7 +++++-- mindspore/common/_decorator.py | 5 +++-- .../parse/test_tuple_index_by_negative.py | 17 +++++++++++++++-- 3 files changed, 23 insertions(+), 6 deletions(-) diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/item_tuple_or_list_eliminate.h b/mindspore/ccsrc/frontend/optimizer/irpass/item_tuple_or_list_eliminate.h index a91d6348263..5a0dafbff2b 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass/item_tuple_or_list_eliminate.h +++ b/mindspore/ccsrc/frontend/optimizer/irpass/item_tuple_or_list_eliminate.h @@ -47,8 +47,11 @@ class ConvertItemIndexToPositive : public AnfVisitor { AnfVisitor::Match(prim::kPrimTupleSetItem, {IsCNode, IsVNode, IsNode})(node); AnfVisitor::Match(prim::kPrimListSetItem, {IsCNode, IsVNode, IsNode})(node); - if (is_match_) { - node->cast()->set_input(2, NewValueNode(id_)); + FuncGraphPtr fg = node->func_graph(); + if (is_match_ && fg != nullptr) { + auto inputs = node->cast()->inputs(); + inputs[2] = NewValueNode(id_); + return fg->NewCNode(inputs); } return nullptr; } diff --git a/mindspore/common/_decorator.py b/mindspore/common/_decorator.py index 7f75fd17ae2..f398b915d8b 100644 --- a/mindspore/common/_decorator.py +++ b/mindspore/common/_decorator.py @@ -14,6 +14,7 @@ # ============================================================================ """Providing decorators.""" +from mindspore import log def deprecated(version, substitute, use_substitute_name=False): """deprecated warning @@ -28,8 +29,8 @@ def deprecated(version, substitute, use_substitute_name=False): def wrapper(*args, **kwargs): cls = getattr(args[0], "__class__", None) if args else None name = cls.__name__ if cls else func.__name__ - print(f"WARNING: '{name}' is deprecated from version {version} and will be removed in a future version, " - f"use '{substitute}' instead.") + log.warning(f"'{name}' is deprecated from version {version} and " + f"will be removed in a future version, use '{substitute}' instead.") if cls and use_substitute_name: cls.substitute_name = substitute ret = func(*args, **kwargs) diff --git a/tests/ut/python/pipeline/parse/test_tuple_index_by_negative.py b/tests/ut/python/pipeline/parse/test_tuple_index_by_negative.py index 79afb872e07..ac4b0a61356 100644 --- a/tests/ut/python/pipeline/parse/test_tuple_index_by_negative.py +++ b/tests/ut/python/pipeline/parse/test_tuple_index_by_negative.py @@ -20,8 +20,9 @@ from mindspore import nn from mindspore import Tensor from mindspore import context from mindspore.ops import operations as P +from mindspore.ops import composite as C -context.set_context(mode=context.GRAPH_MODE, save_graphs=True) +context.set_context(mode=context.GRAPH_MODE) def test_tuple_index_by_negative_number(): @@ -37,12 +38,24 @@ def test_tuple_index_by_negative_number(): ret[-1] = 100 return ret + class GradNet(nn.Cell): + def __init__(self, net, get_all): + super(GradNet, self).__init__() + self.forward_net = net + self.sens = Tensor(np.ones((2, 2), np.float32) * 5) + self.grad_all = C.GradOperation(get_all=get_all) + + def construct(self, x): + return self.grad_all(self.forward_net)(x) + net = Net() + grad_net = GradNet(net, True) x = Tensor(np.ones((4, 2, 3))) net(x) + grad_net(x) -def Ttest_tuple_index_by_negative_number_out_bound(): +def test_tuple_index_by_negative_number_out_bound(): class Net(nn.Cell): def __init__(self): super(Net, self).__init__()