forked from mindspore-Ecosystem/mindspore
support item negative index bprop
This commit is contained in:
parent
8eb3e396e5
commit
50ee325b96
|
@ -47,8 +47,11 @@ class ConvertItemIndexToPositive : public AnfVisitor {
|
|||
AnfVisitor::Match(prim::kPrimTupleSetItem, {IsCNode, IsVNode, IsNode})(node);
|
||||
AnfVisitor::Match(prim::kPrimListSetItem, {IsCNode, IsVNode, IsNode})(node);
|
||||
|
||||
if (is_match_) {
|
||||
node->cast<CNodePtr>()->set_input(2, NewValueNode(id_));
|
||||
FuncGraphPtr fg = node->func_graph();
|
||||
if (is_match_ && fg != nullptr) {
|
||||
auto inputs = node->cast<CNodePtr>()->inputs();
|
||||
inputs[2] = NewValueNode(id_);
|
||||
return fg->NewCNode(inputs);
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
# ============================================================================
|
||||
"""Providing decorators."""
|
||||
|
||||
from mindspore import log
|
||||
|
||||
def deprecated(version, substitute, use_substitute_name=False):
|
||||
"""deprecated warning
|
||||
|
@ -28,8 +29,8 @@ def deprecated(version, substitute, use_substitute_name=False):
|
|||
def wrapper(*args, **kwargs):
|
||||
cls = getattr(args[0], "__class__", None) if args else None
|
||||
name = cls.__name__ if cls else func.__name__
|
||||
print(f"WARNING: '{name}' is deprecated from version {version} and will be removed in a future version, "
|
||||
f"use '{substitute}' instead.")
|
||||
log.warning(f"'{name}' is deprecated from version {version} and "
|
||||
f"will be removed in a future version, use '{substitute}' instead.")
|
||||
if cls and use_substitute_name:
|
||||
cls.substitute_name = substitute
|
||||
ret = func(*args, **kwargs)
|
||||
|
|
|
@ -20,8 +20,9 @@ from mindspore import nn
|
|||
from mindspore import Tensor
|
||||
from mindspore import context
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops import composite as C
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
|
||||
|
||||
def test_tuple_index_by_negative_number():
|
||||
|
@ -37,12 +38,24 @@ def test_tuple_index_by_negative_number():
|
|||
ret[-1] = 100
|
||||
return ret
|
||||
|
||||
class GradNet(nn.Cell):
|
||||
def __init__(self, net, get_all):
|
||||
super(GradNet, self).__init__()
|
||||
self.forward_net = net
|
||||
self.sens = Tensor(np.ones((2, 2), np.float32) * 5)
|
||||
self.grad_all = C.GradOperation(get_all=get_all)
|
||||
|
||||
def construct(self, x):
|
||||
return self.grad_all(self.forward_net)(x)
|
||||
|
||||
net = Net()
|
||||
grad_net = GradNet(net, True)
|
||||
x = Tensor(np.ones((4, 2, 3)))
|
||||
net(x)
|
||||
grad_net(x)
|
||||
|
||||
|
||||
def Ttest_tuple_index_by_negative_number_out_bound():
|
||||
def test_tuple_index_by_negative_number_out_bound():
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
|
|
Loading…
Reference in New Issue