!47966 fix docs and add testcases

Merge pull request !47966 from 李良灿/master
i-robot 2023-01-29 05:51:52 +00:00 committed by Gitee
commit 9839e1ba11
3 changed files with 96 additions and 8 deletions

View File

@@ -170,7 +170,7 @@ def grad(fn, grad_position=0, weights=None, has_aux=False, return_ids=False):
>>> print(aux)
(Tensor(shape=[2], dtype=Float32, value= [ 5.00000000e+00, 5.00000000e+00]),)
>>>
- >>> # For given network to be differentiated with both inputs and weights, there are 3 cases.
+ >>> # For given network to be differentiated with both inputs and weights, there are 4 cases.
>>> net = nn.Dense(10, 1)
>>> loss_fn = nn.MSELoss()
>>> def forward(inputs, labels):
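For context on the "4 cases" mentioned in this hunk, the following is a minimal, self-contained sketch of differentiating a network with respect to inputs, weights, or both. It relies only on the grad(fn, grad_position, weights, has_aux, return_ids) signature visible above; the forward/net/loss_fn names mirror the docstring, and everything else (shapes, data) is illustrative rather than part of the patch.

import numpy as np
import mindspore as ms
import mindspore.nn as nn

net = nn.Dense(10, 1)
loss_fn = nn.MSELoss()

def forward(inputs, labels):
    logits = net(inputs)
    loss = loss_fn(logits, labels)
    return loss

inputs = ms.Tensor(np.random.randn(16, 10), ms.float32)
labels = ms.Tensor(np.random.randn(16, 1), ms.float32)
params = net.trainable_params()

# Gradients with respect to the first input only.
inputs_grad = ms.grad(forward, grad_position=0)(inputs, labels)

# Gradients with respect to both the input and the weights:
# the result is a pair (input gradients, weight gradients).
inputs_grad, params_grad = ms.grad(forward, grad_position=0, weights=params)(inputs, labels)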
@@ -366,9 +366,9 @@ def get_grad(gradients, identifier):
As for gradient, three typical cases are included:
1. gradient with respect to inputs. In this case, use return value of ops.grad as the first input and
- the position of the tensor as the second input.
+ the position of the tensor as the second input.
2. gradient with respect to weights. In this case, use return value of ops.grad as the first input and
- the parameter as the second input.
+ the parameter as the second input.
Args:
gradients (Union[tuple[int, Tensor], tuple[tuple, tuple]]): The return value of mindspore.grad when return_ids
@@ -392,17 +392,17 @@ def get_grad(gradients, identifier):
>>> import mindspore
>>> import mindspore.nn as nn
>>> from mindspore import Tensor, ops
- >>> from mindspore import grad
+ >>> from mindspore import grad, get_grad
>>>
>>> # Cell object to be differentiated
>>> class Net(nn.Cell):
- >>> def construct(self, x, y, z):
- >>> return x * y * z
+ ...     def construct(self, x, y, z):
+ ...         return x * y * z
>>> x = Tensor([1, 2], mindspore.float32)
>>> y = Tensor([-2, 3], mindspore.float32)
>>> z = Tensor([0, 3], mindspore.float32)
>>> net = Net()
- >>> out_grad = grad(net, grad_position=(1, 2), return_ids = True)(x, y, z)
+ >>> out_grad = grad(net, grad_position=(1, 2), return_ids=True)(x, y, z)
>>> output = get_grad(out_grad, 1)
>>> print(output)
Tensor(shape=[2], dtype=Float32, value=[0.00000000e+00, 6.00000000e+00])
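Putting the two hunks above together, the sketch below shows how get_grad is intended to be used with return_ids=True, retrieving one gradient by input position and one by parameter. The network, parameter name, and input values are illustrative and not taken from the patch.

import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor, grad, get_grad

class Net(nn.Cell):
    def __init__(self):
        super().__init__()
        self.w = ms.Parameter(Tensor([1.0, 2.0], ms.float32), name='w')

    def construct(self, x, y):
        return self.w * x * y

net = Net()
x = Tensor([1, 2], ms.float32)
y = Tensor([-2, 3], ms.float32)

# return_ids=True makes grad return (identifier, gradient) pairs instead of bare tensors.
out_grad = grad(net, grad_position=(0, 1), weights=net.trainable_params(), return_ids=True)(x, y)

dx = get_grad(out_grad, 0)       # gradient with respect to the input at position 0
dw = get_grad(out_grad, net.w)   # gradient with respect to the parameter itself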

View File

@@ -1049,7 +1049,7 @@ def test_construct_get_grad_not_found_from_empty_tuple():
@pytest.mark.env_onecard
def test_get_grad_wrap_with_msfunction_graph():
"""
- Features: Function grad.
+ Features: Function get_grad.
Description: Test get_grad wrapped with @jit decorated function in graph mode.
Expectation: No exception.
"""
@@ -1087,3 +1087,49 @@ def test_grad_primal_graph_call_others():
    expected = Tensor(np.array([4, 5]).astype(np.float32))
    output = net(x, y)
    assert np.allclose(output.asnumpy(), expected.asnumpy())


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_get_grad_outer_list_weight():
    """
    Features: Function get_grad.
    Description: Test get_grad with a list of parameters as input of the network in graph mode.
    Expectation: No exception.
    """
    class InnerNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.w = Parameter([1, 2], name='w')
            self.b = Parameter([1, 2], name='b')

        def construct(self, x, y):
            out = self.w * x + self.b + y
            return out

    class GradNet(nn.Cell):
        def __init__(self, net, pos, param, get):
            super().__init__()
            self.net = net
            self.pos = pos
            self.param = param
            self.get = get

        def construct(self, x, y):
            grad_net = grad(self.net, self.pos, self.param, return_ids=True)
            out_grad = grad_net(x, y)
            out = []
            for i in self.get:
                out.append(get_grad(out_grad, i))
            return out

    net = InnerNet()
    grad_net = GradNet(net, (0, 1), (net.w, net.b), (0, net.w))
    x = Tensor([1, 2], mstype.float32)
    y = Tensor([1, 2], mstype.float32)
    out = grad_net(x, y)
    expect_value0 = Tensor([1, 2], mstype.float32)
    expect_value1 = Tensor([1, 2], mstype.int64)
    assert np.allclose(out[0].asnumpy(), expect_value0.asnumpy())
    assert np.allclose(out[1].asnumpy(), expect_value1.asnumpy())

View File

@@ -810,3 +810,45 @@ def test_get_grad_not_found_pynative():
    res = grad(net, 0, weights, return_ids=True)(x)
    with pytest.raises(ValueError):
        get_grad(res, 1)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_grad_outer_list_weight():
    """
    Features: Function get_grad.
    Description: Test get_grad with one weight and no position in pynative mode.
    Expectation: No exception.
    """
    class InnerNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.w = Parameter([1, 2], name='w')
            self.b = Parameter([1, 2], name='b')

        def construct(self, x, y):
            out = self.w * x + self.b + y
            return out

    class GradNet(nn.Cell):
        def __init__(self, net, pos, param, get):
            super().__init__()
            self.net = net
            self.pos = pos
            self.param = param
            self.get = get

        def construct(self, x, y):
            grad_net = grad(self.net, self.pos, self.param, return_ids=True)
            out_grad = grad_net(x, y)
            out = get_grad(out_grad, self.net.w)
            return out

    net = InnerNet()
    grad_net = GradNet(net, (0, 1), (net.w, net.b), (0, net.w))
    x = Tensor([2, 2], mstype.float32)
    y = Tensor([1, 2], mstype.float32)
    out = grad_net(x, y)
    expect_value = Tensor([2, 2], mstype.int64)
    assert np.allclose(out[0].asnumpy(), expect_value.asnumpy())