forked from mindspore-Ecosystem/mindspore
!40686 [bug] fix grad aux in pynative mode
Merge pull request !40686 from chenzhuo/master_aux
This commit is contained in:
commit
de4b5b238e
|
@ -538,23 +538,20 @@ class _Grad(GradOperation_):
|
|||
return out
|
||||
else:
|
||||
grad_.pynative_ = True
|
||||
fn_ = fn
|
||||
if self.has_aux:
|
||||
fn_ = aux_fn
|
||||
# after_grad of this branch can't use @ms_function, just directly call grad_
|
||||
if self.get_by_position:
|
||||
def after_grad(*args, **kwargs):
|
||||
if self.has_aux:
|
||||
return grad_(aux_fn, weights, grad_position)(*args, **kwargs)
|
||||
return grad_(fn, weights, grad_position)(*args, **kwargs)
|
||||
return grad_(fn_, weights, grad_position)(*args, **kwargs)
|
||||
else:
|
||||
if self.get_by_list:
|
||||
def after_grad(*args, **kwargs):
|
||||
if self.has_aux:
|
||||
return grad_(aux_fn, weights)(*args, **kwargs)
|
||||
return grad_(fn, weights)(*args, **kwargs)
|
||||
return grad_(fn_, weights)(*args, **kwargs)
|
||||
else:
|
||||
def after_grad(*args, **kwargs):
|
||||
if self.has_aux:
|
||||
return grad_(aux_fn)(*args, **kwargs)
|
||||
return grad_(fn)(*args, **kwargs)
|
||||
return grad_(fn_)(*args, **kwargs)
|
||||
|
||||
self.grad_fn = after_grad
|
||||
self.fn = fn
|
||||
|
|
|
@ -20,7 +20,7 @@ import mindspore.context as context
|
|||
from mindspore import Tensor
|
||||
from mindspore import ms_function
|
||||
from mindspore.ops import composite as C
|
||||
from mindspore.ops.functional import grad, value_and_grad
|
||||
from mindspore.ops import grad, value_and_grad, vmap
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore import Parameter, ParameterTuple
|
||||
|
||||
|
@ -186,6 +186,35 @@ def test_grad_wrap_with_msfunction_pynative():
|
|||
assert np.allclose(real_grad.asnumpy(), expect_grad.asnumpy())
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_grad_vmap_pynative():
|
||||
"""
|
||||
Features: Function grad.
|
||||
Description: Test F.grad vmap in pynative mode.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
|
||||
def fn(x):
|
||||
return x * x
|
||||
|
||||
class VmapNet(nn.Cell):
|
||||
def __init__(self, net):
|
||||
super(VmapNet, self).__init__()
|
||||
self.grad_net = grad(net)
|
||||
|
||||
def construct(self, x):
|
||||
res = vmap(self.grad_net, 0, 0)(x)
|
||||
return res
|
||||
|
||||
x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
|
||||
ms_net = VmapNet(fn)
|
||||
outputs = ms_net(x)
|
||||
expect_value = np.array([[2, 4], [6, 8]]).astype(np.float32)
|
||||
assert np.allclose(outputs.asnumpy(), expect_value)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
|
|
Loading…
Reference in New Issue