!40686 [bug] fix grad aux in pynative mode

Merge pull request !40686 from chenzhuo/master_aux
This commit is contained in:
i-robot 2022-08-24 01:05:01 +00:00 committed by Gitee
commit de4b5b238e
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
2 changed files with 36 additions and 10 deletions

View File

@ -538,23 +538,20 @@ class _Grad(GradOperation_):
return out
else:
grad_.pynative_ = True
fn_ = fn
if self.has_aux:
fn_ = aux_fn
# after_grad of this branch can't use @ms_function, just directly call grad_
if self.get_by_position:
def after_grad(*args, **kwargs):
if self.has_aux:
return grad_(aux_fn, weights, grad_position)(*args, **kwargs)
return grad_(fn, weights, grad_position)(*args, **kwargs)
return grad_(fn_, weights, grad_position)(*args, **kwargs)
else:
if self.get_by_list:
def after_grad(*args, **kwargs):
if self.has_aux:
return grad_(aux_fn, weights)(*args, **kwargs)
return grad_(fn, weights)(*args, **kwargs)
return grad_(fn_, weights)(*args, **kwargs)
else:
def after_grad(*args, **kwargs):
if self.has_aux:
return grad_(aux_fn)(*args, **kwargs)
return grad_(fn)(*args, **kwargs)
return grad_(fn_)(*args, **kwargs)
self.grad_fn = after_grad
self.fn = fn

View File

@ -20,7 +20,7 @@ import mindspore.context as context
from mindspore import Tensor
from mindspore import ms_function
from mindspore.ops import composite as C
from mindspore.ops.functional import grad, value_and_grad
from mindspore.ops import grad, value_and_grad, vmap
from mindspore.common import dtype as mstype
from mindspore import Parameter, ParameterTuple
@ -186,6 +186,35 @@ def test_grad_wrap_with_msfunction_pynative():
assert np.allclose(real_grad.asnumpy(), expect_grad.asnumpy())
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_grad_vmap_pynative():
"""
Features: Function grad.
Description: Test F.grad vmap in pynative mode.
Expectation: No exception.
"""
def fn(x):
return x * x
class VmapNet(nn.Cell):
def __init__(self, net):
super(VmapNet, self).__init__()
self.grad_net = grad(net)
def construct(self, x):
res = vmap(self.grad_net, 0, 0)(x)
return res
x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
ms_net = VmapNet(fn)
outputs = ms_net(x)
expect_value = np.array([[2, 4], [6, 8]]).astype(np.float32)
assert np.allclose(outputs.asnumpy(), expect_value)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard