forked from mindspore-Ecosystem/mindspore
support key ward way to pass arg for outermost net in graph mode
This commit is contained in:
parent
2880720ad7
commit
1364b5c301
|
@ -282,19 +282,19 @@ class Cell(Cell_):
|
|||
return tuple(res)
|
||||
|
||||
def __call__(self, *inputs, **kwargs):
|
||||
if kwargs:
|
||||
bound_args = inspect.signature(self.construct).bind(*inputs, **kwargs)
|
||||
inputs = bound_args.args
|
||||
kwargs = bound_args.kwargs
|
||||
if context.get_context("mode") == context.GRAPH_MODE:
|
||||
if kwargs:
|
||||
raise ValueError("For 'graph' mode, the outermost network does not support passing "
|
||||
"key-value pair parameters and variable key-value pair parameters.")
|
||||
"variable key-value pair parameters.")
|
||||
if self.enable_hook:
|
||||
raise ValueError("The graph mode does not support hook function.")
|
||||
out = self.compile_and_run(*inputs)
|
||||
return out
|
||||
|
||||
if kwargs:
|
||||
bound_args = inspect.signature(self.construct).bind(*inputs, **kwargs)
|
||||
inputs = bound_args.args
|
||||
kwargs = bound_args.kwargs
|
||||
for item in inputs:
|
||||
if isinstance(item, numpy.ndarray):
|
||||
raise TypeError("cell inputs should not be numpy array.")
|
||||
|
|
|
@ -22,7 +22,6 @@ from mindspore.ops import operations as P
|
|||
|
||||
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
|
||||
|
||||
|
||||
grad_all = C.GradOperation(get_all=True)
|
||||
grad_all_with_sens = C.GradOperation(sens_param=True)
|
||||
|
||||
|
@ -285,3 +284,31 @@ def test_mixed_precision_const_parameter():
|
|||
y = Tensor(np.ones((1, 3, 14, 14), np.float32))
|
||||
z = Tensor(np.ones((1, 3, 28, 28), np.float32))
|
||||
_ = net(x, y, z)
|
||||
|
||||
|
||||
def test_pass_args_by_key_ward_way():
|
||||
class KeyWardNet(Cell):
|
||||
def __init__(self):
|
||||
super(KeyWardNet, self).__init__()
|
||||
|
||||
def construct(self, x, y, z):
|
||||
return x + y - z
|
||||
|
||||
class GradNet(Cell):
|
||||
def __init__(self, net):
|
||||
super(GradNet, self).__init__()
|
||||
self.grad = C.GradOperation(get_all=True, sens_param=True)
|
||||
self.net = net
|
||||
self.sens = Tensor(np.ones((3, 3, 4), np.float32))
|
||||
|
||||
def construct(self, x, y, z, sens):
|
||||
return self.grad(self.net)(x, y, z, sens)
|
||||
|
||||
x = Tensor(np.ones((1, 3, 4), np.float32))
|
||||
y = Tensor(np.ones((1, 3, 4), np.float32))
|
||||
z = Tensor(np.ones((3, 3, 4), np.float32))
|
||||
net = KeyWardNet()
|
||||
net(x, z=z, y=y)
|
||||
grad_net = GradNet(net)
|
||||
sens = Tensor(np.ones((3, 3, 4), np.float32))
|
||||
grad_net(x, y=y, z=z, sens=sens)
|
||||
|
|
Loading…
Reference in New Issue