!4370 support kw and kwargs for cell in Pynative
Merge pull request !4370 from zhangbuxue/support_kw_and_kwargs_for_cell_in_pynative
commit b4b6e5c8ed
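
This change lets a Cell be called with keyword and variable keyword arguments in PyNative mode (graph mode still rejects them at the outermost network, as the cell.py hunk below shows). A minimal sketch of the new calling convention, adapted from the test added in this PR (the network, tensor shapes, and keyword names are illustrative):

    import numpy as np
    import mindspore.nn as nn
    from mindspore import Tensor, context

    class Net(nn.Cell):
        def construct(self, x, y, *args, w, **kwargs):
            return x + y + args[0] + w + kwargs['c']

    context.set_context(mode=context.PYNATIVE_MODE)
    x = Tensor(np.ones([3, 4, 5], np.float32))
    net = Net()
    # positional, variadic, keyword-only and variable keyword arguments all flow through
    out = net(x, x, x, w=x, c=x)
    assert (out.asnumpy() == 5 * np.ones([3, 4, 5], np.float32)).all()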
@@ -291,14 +291,14 @@ class _PynativeExecutor:
     def __init__(self):
         self._executor = PynativeExecutor_.get_instance()
 
-    def new_graph(self, obj, *args):
-        self._executor.new_graph(obj, *args)
+    def new_graph(self, obj, *args, **kwargs):
+        self._executor.new_graph(obj, *args, *(kwargs.values()))
 
-    def end_graph(self, obj, output, *args):
-        self._executor.end_graph(obj, output, *args)
+    def end_graph(self, obj, output, *args, **kwargs):
+        self._executor.end_graph(obj, output, *args, *(kwargs.values()))
 
-    def grad(self, grad, obj, weights, *args):
-        self._executor.grad_net(grad, obj, weights, *args)
+    def grad(self, grad, obj, weights, *args, **kwargs):
+        self._executor.grad_net(grad, obj, weights, *args, *(kwargs.values()))
 
     def clear(self, flag=""):
         self._executor.clear(flag)
@@ -306,7 +306,8 @@ class _PynativeExecutor:
     def set_grad_flag(self, flag):
         self._executor.set_grad_flag(flag)
 
-    def __call__(self, *args):
+    def __call__(self, *args, **kwargs):
+        args = args + tuple(kwargs.values())
         return self._executor(args, "")
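
A note on the plumbing above: keyword values are forwarded to the executor by flattening them into the positional tuple via *args, *(kwargs.values()), which suggests the underlying PynativeExecutor_ binding accepts only positional inputs. The flattening itself is plain Python; a standalone sketch of the equivalence:

    def flatten(*args, **kwargs):
        # keyword values are appended after the positional ones,
        # in kwargs insertion order (guaranteed since Python 3.7)
        return args + tuple(kwargs.values())

    assert flatten(1, 2, c=3, d=4) == (1, 2, 3, 4)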
@@ -13,6 +13,7 @@
 # limitations under the License.
 # ============================================================================
 """cell"""
+import inspect
 import time
 import gc
 from collections import OrderedDict
@@ -222,19 +223,27 @@ class Cell:
         else:
             object.__delattr__(self, name)
 
-    def __call__(self, *inputs):
+    def __call__(self, *inputs, **kwargs):
         if context.get_context("mode") == context.GRAPH_MODE:
+            if kwargs:
+                raise ValueError("For 'graph' mode, the outermost network does not support passing "
+                                 "key-value pair parameters and variable key-value pair parameters.")
             out = self.compile_and_run(*inputs)
             return out
+
+        if kwargs:
+            bound_args = inspect.signature(self.construct).bind(*inputs, **kwargs)
+            inputs = bound_args.args
+            kwargs = bound_args.kwargs
         for item in inputs:
             if isinstance(item, numpy.ndarray):
                 raise TypeError("cell inputs should not be numpy array.")
-        orign_grad = []
+        origin_grad = []
         if self.requires_grad is True:
             _pynative_exec.set_grad_flag(True)
-            _pynative_exec.new_graph(self, *inputs)
+            _pynative_exec.new_graph(self, *inputs, **kwargs)
             for cell in self.cells():
-                orign_grad.append(cell.requires_grad)
+                origin_grad.append(cell.requires_grad)
                 cell.set_grad(True)
         else:
             _pynative_exec.set_grad_flag(False)
@@ -251,15 +260,15 @@ class Cell:
         else:
             cast_inputs = inputs
         if self.enable_hook:
-            output = self._hook_construct(*cast_inputs)
+            output = self._hook_construct(*cast_inputs, **kwargs)
         else:
-            output = self.construct(*cast_inputs)
+            output = self.construct(*cast_inputs, **kwargs)
         if isinstance(output, Parameter):
             output = output.data
         if self.requires_grad is True:
-            _pynative_exec.end_graph(self, output, *inputs)
+            _pynative_exec.end_graph(self, output, *inputs, **kwargs)
             for i, cell in enumerate(self.cells()):
-                cell.set_grad(orign_grad[i])
+                cell.set_grad(origin_grad[i])
         self._already_run = True
         return output
 
@@ -400,7 +409,6 @@ class Cell:
 
     def _get_construct_inputs_number_and_name(self):
         """Compute self._construct_inputs_names and self._construct_inputs_num"""
-        import inspect
         from mindspore._extends.parse.parser import get_parse_method_of_class
 
         fn = get_parse_method_of_class(self)
@@ -517,7 +525,7 @@ class Cell:
             raise TypeError("Child cell type is incorrect.")
         self._cells[child_name] = child
 
-    def construct(self, *inputs):
+    def construct(self, *inputs, **kwargs):
         """
         Defines the computation to be performed.
 
@@ -878,7 +886,7 @@ class Cell:
         self.add_flags(auto_parallel=True)
         self._get_construct_inputs_number_and_name()
 
-    def _hook_construct(self, *inputs):
+    def _hook_construct(self, *inputs, **kwargs):
         """Hook construct method to replace original construct method when hook function enabled."""
         inputs = self._backward_hook(*inputs)
         inputs = self.construct(inputs)
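
Why Cell.__call__ now calls inspect.signature(self.construct).bind(...): bind moves keyword arguments that name positional parameters back into the positional tuple, so only genuine keyword-only and variable keyword arguments remain in kwargs when they are later flattened for the executor. A standalone illustration using only the stdlib (the construct signature mirrors the one in the new test; the values are illustrative):

    import inspect

    def construct(x, y, *arg, w, **kwargs):
        ...

    sig = inspect.signature(construct)

    bound = sig.bind(1, 2, 3, w=4, c=5)
    assert bound.args == (1, 2, 3)            # positionals, including *arg
    assert bound.kwargs == {'w': 4, 'c': 5}   # keyword-only plus var-keyword

    # a keyword that names a positional parameter is normalized back to positional
    bound = sig.bind(1, y=2, w=4, c=5)
    assert bound.args == (1, 2)
    assert bound.kwargs == {'w': 4, 'c': 5}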
@@ -116,7 +116,7 @@ class GradOperation(GradOperation_):
         self.fn = None
         self.need_forward = False
 
-    def _pynative_forward_run(self, args, fn):
+    def _pynative_forward_run(self, args, kwargs, fn):
         """ Pynative forward run to build grad graph. """
         if self.sens_param:
             args = args[:-1]
@@ -125,9 +125,9 @@ class GradOperation(GradOperation_):
                 raise TypeError("grad inputs should be tensor in pynative mode")
         if isinstance(fn, FunctionType):
             _pynative_exec.set_grad_flag(True)
-            _pynative_exec.new_graph(fn, *args)
-            output = fn(*args)
-            _pynative_exec.end_graph(fn, output, *args)
+            _pynative_exec.new_graph(fn, *args, **kwargs)
+            output = fn(*args, **kwargs)
+            _pynative_exec.end_graph(fn, output, *args, **kwargs)
         else:
             if fn.already_run and not fn.requires_grad:
                 raise ValueError("obj must set_grad.")
@@ -135,7 +135,7 @@ class GradOperation(GradOperation_):
                 self.need_forward = True
             if self.need_forward:
                 fn.set_grad()
-                fn(*args)
+                fn(*args, **kwargs)
                 fn.already_run = False
 
     def __call__(self, fn, weights=None):
@@ -152,10 +152,10 @@ class GradOperation(GradOperation_):
                 return grad_(fn)(*args)
         else:
            @_wrap_func
-            def after_grad(*args):
-                self._pynative_forward_run(args, fn)
-                _pynative_exec.grad(grad_, fn, weights, *args)
-                out = _pynative_exec(*args)
+            def after_grad(*args, **kwargs):
+                self._pynative_forward_run(args, kwargs, fn)
+                _pynative_exec.grad(grad_, fn, weights, *args, **kwargs)
+                out = _pynative_exec(*args, **kwargs)
                 _pynative_exec.clear()
                 return out
             self.grad_fn = after_grad
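
With kwargs threaded through after_grad and _pynative_forward_run, gradients can be taken of a cell that was run with keyword arguments. A minimal sketch of the flow, mirroring the new test file below (the cell, shapes, and the 'grad_all' name string are illustrative; with get_all=True one gradient is returned per input, keyword inputs included):

    import numpy as np
    import mindspore.nn as nn
    from mindspore import Tensor, context
    from mindspore.ops.composite import base as C

    class KwNet(nn.Cell):
        def construct(self, x, **kwargs):
            return 2 * x + 5 * kwargs['v']

    context.set_context(mode=context.PYNATIVE_MODE)
    x = Tensor(np.ones([2], np.float32))
    v = Tensor(np.ones([2], np.float32))
    net = KwNet()
    net.set_grad(True)
    net(x, v=v)  # forward pass records the graph
    dx, dv = C.GradOperation('grad_all', get_all=True)(net)(x, v=v)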
@@ -30,6 +30,7 @@ from tests.mindspore_test_framework.pipeline.forward.compile_forward \
 
+context.set_context(mode=context.GRAPH_MODE)
 
 
 def test_list_equal():
     class Net(nn.Cell):
         def __init__(self, z: list):
@@ -156,8 +157,10 @@ def test_class_member_not_defined():
 
     z = [[1, 2], 3]
     net = Net(z)
+    x = Tensor(np.ones([6, 8, 10], np.int32))
+    y = Tensor(np.zeros([3, 4, 5], np.int32))
     with pytest.raises(TypeError) as ex:
-        net()
+        net(x, y)
     assert "'self.x' was not defined in the class '__init__' function." in str(ex.value)
 
 
@@ -181,7 +184,7 @@ def test_change_list_element():
 
 
 class ListOperate(nn.Cell):
-    def __init__(self,):
+    def __init__(self):
         super(ListOperate, self).__init__()
 
     def construct(self, t, l):
@@ -201,7 +204,7 @@ class ListOperate(nn.Cell):
 
 
 class InListNet(nn.Cell):
-    def __init__(self,):
+    def __init__(self):
         super(InListNet, self).__init__()
         self.list_ = [1, 2, 3, 4, 5, "ok"]
 
@@ -0,0 +1,139 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
""" test dtype and shape as attr"""
+import numpy as np
+
+import mindspore.nn as nn
+from mindspore import Tensor
+from mindspore import context
+from mindspore import dtype as mstype
+from mindspore.ops.composite import base as C
+
+
+def test_kw_nested():
+    class NetKeyValueArg(nn.Cell):
+        def __init__(self):
+            super().__init__()
+
+        def construct(self, x, y, *arg, w, **kwargs):
+            return x + y + arg[0] + w + kwargs['c']
+
+    class NetOut(nn.Cell):
+        def __init__(self, net):
+            super().__init__()
+            self.in_net = net
+
+        def construct(self, x, y, z):
+            ret = self.in_net(x, y, z, w=x, a=x, b=y, c=z) + x
+            return ret
+
+    in_net = NetKeyValueArg()
+    out_net = NetOut(in_net)
+    x = Tensor(np.ones([3, 4, 5], np.float32))
+    y = Tensor(np.zeros([3, 4, 5], np.int32))
+    z = Tensor(np.ones([3, 4, 5], np.float64))
+    context.set_context(mode=context.PYNATIVE_MODE)
+
+    ret = out_net(x, y, z)
+    assert ret.dtype == mstype.float64
+    assert ret.shape == (3, 4, 5)
+    assert (ret.asnumpy() == np.ones([3, 4, 5], np.float64) * 5).all()
+
+
+def test_kw_grad():
+    class KwNet(nn.Cell):
+        def __init__(self):
+            super(KwNet, self).__init__()
+
+        def construct(self, x, y, *arg, **kwargs):
+            return 2 * x + 3 * y + 4 * arg[0] + 5 * kwargs['v']
+
+    class GradKwNet(nn.Cell):
+        def __init__(self, net):
+            super(GradKwNet, self).__init__()
+            self.net = net
+            self.grad_all_wit_sense = C.GradOperation('grad_all_with_sens', get_all=True, sens_param=True)
+
+        def construct(self, x, y, *arg, **kwargs):
+            return self.grad_all_wit_sense(self.net)(x, y, *arg, **kwargs)
+
+    kw_net = KwNet()
+    x = Tensor(np.ones([1, 2, 3], np.int32))
+    y = Tensor(np.ones([1, 2, 3], np.float32))
+    z = Tensor(np.ones([1, 2, 3], np.float64))
+    u = Tensor(np.ones([1, 2, 3], np.float16))
+    v = Tensor(np.ones([1, 2, 3], np.int32))
+    w = Tensor(np.ones([1, 2, 3], np.float64))
+    sens = Tensor(np.ones([1, 2, 3], np.float64))
+    context.set_context(mode=context.PYNATIVE_MODE)
+
+    kw_net.set_grad(True)
+    ret = kw_net(x, y, z, u=u, v=v, w=w)
+    assert (ret.asnumpy() == np.ones([1, 2, 3], np.float64) * 14).all()
+
+    grad_kw_net = GradKwNet(kw_net)
+    ret_grad = grad_kw_net(x, y, z, u=u, v=v, w=w, sens=sens)
+    assert len(ret_grad) == 6
+    assert (ret_grad[0].asnumpy() == np.ones([1, 2, 3]) * 2).all()
+    assert ret_grad[0].dtype == mstype.int32
+    assert (ret_grad[1].asnumpy() == np.ones([1, 2, 3]) * 3).all()
+    assert ret_grad[1].dtype == mstype.float32
+    assert (ret_grad[2].asnumpy() == np.ones([1, 2, 3]) * 4).all()
+    assert ret_grad[2].dtype == mstype.float64
+    assert (ret_grad[3].asnumpy() == np.zeros([1, 2, 3])).all()
+    assert ret_grad[3].dtype == mstype.float16
+    assert (ret_grad[4].asnumpy() == np.ones([1, 2, 3]) * 5).all()
+    assert ret_grad[4].dtype == mstype.int32
+    assert (ret_grad[5].asnumpy() == np.zeros([1, 2, 3])).all()
+    assert ret_grad[5].dtype == mstype.float64
+
+
+def test_grad():
+    class Net(nn.Cell):
+        def __init__(self):
+            super(Net, self).__init__()
+
+        def construct(self, x, y, z):
+            return 2 * x + 3 * y + 4 * z
+
+    class GradNet(nn.Cell):
+        def __init__(self, net):
+            super(GradNet, self).__init__()
+            self.net = net
+            self.grad_all_wit_sense = C.GradOperation('grad_all_with_sens', get_all=True, sens_param=True)
+
+        def construct(self, x, y, z, sens):
+            return self.grad_all_wit_sense(self.net)(x, y, z, sens)
+
+    net = Net()
+    x = Tensor(np.ones([1, 2, 3], np.int32))
+    y = Tensor(np.ones([1, 2, 3], np.float32))
+    z = Tensor(np.ones([1, 2, 3], np.float16))
+    sens = Tensor(np.ones([1, 2, 3], np.float32))
+    context.set_context(mode=context.PYNATIVE_MODE)
+
+    net.set_grad(True)
+    ret = net(x, y, z)
+    assert (ret.asnumpy() == np.ones([1, 2, 3], np.float64) * 9).all()
+
+    grad_net = GradNet(net)
+    ret_grad = grad_net(x, y, z, sens)
+    assert len(ret_grad) == 3
+    assert (ret_grad[0].asnumpy() == np.ones([1, 2, 3]) * 2).all()
+    assert ret_grad[0].dtype == mstype.int32
+    assert (ret_grad[1].asnumpy() == np.ones([1, 2, 3]) * 3).all()
+    assert ret_grad[1].dtype == mstype.float32
+    assert (ret_grad[2].asnumpy() == np.ones([1, 2, 3]) * 4).all()
+    assert ret_grad[2].dtype == mstype.float16