From dc658a002fba3f168c9e1aa204ffd8d3530aada3 Mon Sep 17 00:00:00 2001
From: liangcanli
Date: Tue, 17 Jan 2023 14:50:20 +0800
Subject: [PATCH] fix get_grad docs and add testcases

---
 .../mindspore/ops/function/grad/grad_func.py | 14 +++---
 tests/st/gradient/test_grad_graph.py         | 48 ++++++++++++++++++-
 tests/st/gradient/test_grad_pynative.py      | 42 ++++++++++++++++
 3 files changed, 96 insertions(+), 8 deletions(-)

diff --git a/mindspore/python/mindspore/ops/function/grad/grad_func.py b/mindspore/python/mindspore/ops/function/grad/grad_func.py
index ad7bee852d5..25d289178a3 100644
--- a/mindspore/python/mindspore/ops/function/grad/grad_func.py
+++ b/mindspore/python/mindspore/ops/function/grad/grad_func.py
@@ -170,7 +170,7 @@ def grad(fn, grad_position=0, weights=None, has_aux=False, return_ids=False):
         >>> print(aux)
         (Tensor(shape=[2], dtype=Float32, value= [ 5.00000000e+00, 5.00000000e+00]),)
         >>>
-        >>> # For given network to be differentiated with both inputs and weights, there are 3 cases.
+        >>> # For given network to be differentiated with both inputs and weights, there are 4 cases.
         >>> net = nn.Dense(10, 1)
         >>> loss_fn = nn.MSELoss()
         >>> def forward(inputs, labels):
@@ -366,9 +366,9 @@ def get_grad(gradients, identifier):
     As for gradient, two typical cases are included:
 
     1. gradient with respect to inputs. In this case, use the return value of ops.grad as the first input and
-        the position of the tensor as the second input.
+       the position of the tensor as the second input.
     2. gradient with respect to weights. In this case, use the return value of ops.grad as the first input and
-        the parameter as the second input.
+       the parameter as the second input.
 
     Args:
         gradients (Union[tuple[int, Tensor], tuple[tuple, tuple]]): The return value of mindspore.grad when return_ids
@@ -392,17 +392,17 @@
         >>> import mindspore
         >>> import mindspore.nn as nn
         >>> from mindspore import Tensor, ops
-        >>> from mindspore import grad
+        >>> from mindspore import grad, get_grad
         >>>
         >>> # Cell object to be differentiated
         >>> class Net(nn.Cell):
-        >>>     def construct(self, x, y, z):
-        >>>         return x * y * z
+        ...     def construct(self, x, y, z):
+        ...         return x * y * z
         >>> x = Tensor([1, 2], mindspore.float32)
         >>> y = Tensor([-2, 3], mindspore.float32)
         >>> z = Tensor([0, 3], mindspore.float32)
         >>> net = Net()
-        >>> out_grad = grad(net, grad_position=(1, 2), return_ids = True)(x, y, z)
+        >>> out_grad = grad(net, grad_position=(1, 2), return_ids=True)(x, y, z)
         >>> output = get_grad(out_grad, 1)
         >>> print(output)
         Tensor(shape=[2], dtype=Float32, value=[0.00000000e+00, 6.00000000e+00])
diff --git a/tests/st/gradient/test_grad_graph.py b/tests/st/gradient/test_grad_graph.py
index df236a19e2d..efd44b2f7b2 100644
--- a/tests/st/gradient/test_grad_graph.py
+++ b/tests/st/gradient/test_grad_graph.py
@@ -1049,7 +1049,7 @@ def test_construct_get_grad_not_found_from_empty_tuple():
 @pytest.mark.env_onecard
 def test_get_grad_wrap_with_msfunction_graph():
     """
-    Features: Function grad.
+    Features: Function get_grad.
     Description: Test get_grad wrapped with @jit decorated function in graph mode.
     Expectation: No exception.
     """
@@ -1087,3 +1087,49 @@ def test_grad_primal_graph_call_others():
     expected = Tensor(np.array([4, 5]).astype(np.float32))
     output = net(x, y)
     assert np.allclose(output.asnumpy(), expected.asnumpy())
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_get_grad_outer_list_weight():
+    """
+    Features: Function get_grad.
+    Description: Test get_grad with a list of parameters as input to the network in graph mode.
+    Expectation: No exception.
+    """
+    class InnerNet(nn.Cell):
+        def __init__(self):
+            super().__init__()
+            self.w = Parameter([1, 2], name='w')
+            self.b = Parameter([1, 2], name='b')
+
+        def construct(self, x, y):
+            out = self.w * x + self.b + y
+            return out
+
+    class GradNet(nn.Cell):
+        def __init__(self, net, pos, param, get):
+            super().__init__()
+            self.net = net
+            self.pos = pos
+            self.param = param
+            self.get = get
+
+        def construct(self, x, y):
+            grad_net = grad(self.net, self.pos, self.param, return_ids=True)
+            out_grad = grad_net(x, y)
+            out = []
+            for i in self.get:
+                out.append(get_grad(out_grad, i))
+            return out
+
+    net = InnerNet()
+    grad_net = GradNet(net, (0, 1), (net.w, net.b), (0, net.w))
+    x = Tensor([1, 2], mstype.float32)
+    y = Tensor([1, 2], mstype.float32)
+    out = grad_net(x, y)
+    expect_value0 = Tensor([1, 2], mstype.float32)
+    expect_value1 = Tensor([1, 2], mstype.int64)
+    assert np.allclose(out[0].asnumpy(), expect_value0.asnumpy())
+    assert np.allclose(out[1].asnumpy(), expect_value1.asnumpy())
diff --git a/tests/st/gradient/test_grad_pynative.py b/tests/st/gradient/test_grad_pynative.py
index 5ce1197395e..4b96fb70962 100644
--- a/tests/st/gradient/test_grad_pynative.py
+++ b/tests/st/gradient/test_grad_pynative.py
@@ -810,3 +810,45 @@ def test_get_grad_not_found_pynative():
     res = grad(net, 0, weights, return_ids=True)(x)
     with pytest.raises(ValueError):
         get_grad(res, 1)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_grad_outer_list_weight():
+    """
+    Features: Function get_grad.
+    Description: Test get_grad to retrieve the gradient of a single weight in pynative mode.
+    Expectation: No exception.
+    """
+    class InnerNet(nn.Cell):
+        def __init__(self):
+            super().__init__()
+            self.w = Parameter([1, 2], name='w')
+            self.b = Parameter([1, 2], name='b')
+
+        def construct(self, x, y):
+            out = self.w * x + self.b + y
+            return out
+
+    class GradNet(nn.Cell):
+        def __init__(self, net, pos, param, get):
+            super().__init__()
+            self.net = net
+            self.pos = pos
+            self.param = param
+            self.get = get
+
+        def construct(self, x, y):
+            grad_net = grad(self.net, self.pos, self.param, return_ids=True)
+            out_grad = grad_net(x, y)
+            out = get_grad(out_grad, self.net.w)
+            return out
+
+    net = InnerNet()
+    grad_net = GradNet(net, (0, 1), (net.w, net.b), (0, net.w))
+    x = Tensor([2, 2], mstype.float32)
+    y = Tensor([1, 2], mstype.float32)
+    out = grad_net(x, y)
+    expect_value = Tensor([2, 2], mstype.int64)
+    assert np.allclose(out.asnumpy(), expect_value.asnumpy())
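
Below the patch, a minimal standalone sketch of the return_ids/get_grad flow this change documents and tests. It reuses the tensor values from the docstring example above; the only assumption is a MindSpore build where ops.grad accepts return_ids and mindspore exports get_grad, as the patched import line implies.

import mindspore
import mindspore.nn as nn
from mindspore import Tensor
from mindspore import grad, get_grad

class Net(nn.Cell):
    def construct(self, x, y, z):
        return x * y * z

x = Tensor([1, 2], mindspore.float32)
y = Tensor([-2, 3], mindspore.float32)
z = Tensor([0, 3], mindspore.float32)

net = Net()
# With return_ids=True, each returned gradient is paired with the input
# position (or parameter) it was taken against, so a single gradient can
# be fetched by identifier via get_grad instead of by tuple order.
out_grad = grad(net, grad_position=(1, 2), return_ids=True)(x, y, z)
print(get_grad(out_grad, 1))  # d(x*y*z)/dy = x*z, expected [0., 6.]
print(get_grad(out_grad, 2))  # d(x*y*z)/dz = x*y, expected [-2., 6.]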