From 4f4a34414920e70b9f631e6c3ae8961489408e79 Mon Sep 17 00:00:00 2001
From: l00591931
Date: Fri, 27 Aug 2021 17:41:51 +0800
Subject: [PATCH] Vjp and fix jvp

---
 mindspore/nn/grad/__init__.py          |  4 +-
 mindspore/nn/grad/cell_grad.py         | 47 +++++++++++++++--
 mindspore/ops/functional.py            | 59 ---------------------
 tests/st/gradient/test_jvp_graph.py    | 66 ++++-------------------
 tests/st/gradient/test_jvp_pynative.py | 18 +++----
 tests/st/gradient/test_vjp_graph.py    | 71 ++++++++++++++++++++++++++
 6 files changed, 134 insertions(+), 131 deletions(-)
 create mode 100644 tests/st/gradient/test_vjp_graph.py

diff --git a/mindspore/nn/grad/__init__.py b/mindspore/nn/grad/__init__.py
index 8b1a892f435..218e1461f02 100644
--- a/mindspore/nn/grad/__init__.py
+++ b/mindspore/nn/grad/__init__.py
@@ -18,7 +18,7 @@ Grad Cells of grad function.
 
 Calculate the gradient of input network or function.
 """
-from .cell_grad import Jvp
+from .cell_grad import Jvp, Vjp
 
-__all__ = ['Jvp']
+__all__ = ['Jvp', 'Vjp']
 
diff --git a/mindspore/nn/grad/cell_grad.py b/mindspore/nn/grad/cell_grad.py
index da918e9bc40..ad2ba5b6a97 100644
--- a/mindspore/nn/grad/cell_grad.py
+++ b/mindspore/nn/grad/cell_grad.py
@@ -15,8 +15,8 @@
 """cell grad"""
 from ..cell import Cell
 from ...ops import composite as C
-from ...ops.primitive import Primitive
 from ...ops import operations as P
+from ...ops.primitive import Primitive
 from ...common import dtype as mstype
 from ...common.api import ms_function
 
@@ -73,9 +73,9 @@
         self.tuple_len = Primitive("tuple_len")
 
     @ms_function
-    def construct(self, *total_input):
-        jvp_input = total_input[0:-1]
-        v = total_input[-1]
+    def construct(self, *args):
+        jvp_input = args[0:-1]
+        v = args[-1]
         output = self.fn(*jvp_input)
 
         if self.issubclass_(self.typeof(output), mstype.tuple_):
@@ -92,3 +92,42 @@
         second_gradient_net = self.second_grad_op(self.first_grad)
         gradient_output = second_gradient_net(u, jvp_input, v)
         return output, gradient_output
+
+
+class Vjp(Cell):
+    """
+    Computes the dot product between a vector `v` and the Jacobian of the given network at the point
+    given by the inputs.
+
+    Args:
+        network (Cell): The network that takes Tensor inputs and returns a tuple of Tensors or a Tensor.
+
+    Inputs:
+        - **inputs** (Tensors) - The inputs to `network`. Must be a tuple or a list.
+        - **v** (tuple of Tensors or Tensor) - The vector for which the vector-Jacobian product is computed.
+          Must have the same size as the output of `network`.
+
+    Outputs:
+        A tuple with:
+        net_output (Tuple(Tensor...)) - The output of `network(inputs)`.
+        vjp (Tuple(Tensor...)) - The result of the dot product.
+    """
+
+    def __init__(self, fn):
+        super(Vjp, self).__init__()
+        self.fn = fn
+        self.grad = C.GradOperation(get_all=True, sens_param=True)
+        self.grad_single_value = C.GradOperation(sens_param=True)
+        self.issubclass_ = P.IsSubClass()
+        self.typeof = Primitive('typeof')
+        self.tuple_len = Primitive("tuple_len")
+
+    @ms_function
+    def construct(self, *args):
+        front_input = args[0:-1]
+        output = self.fn(*front_input)
+        if self.tuple_len(front_input) == 1:
+            gradient_output = self.grad_single_value(self.fn)(*args)
+        else:
+            gradient_output = self.grad(self.fn)(*args)
+        return output, gradient_output
diff --git a/mindspore/ops/functional.py b/mindspore/ops/functional.py
index 91f05006353..22173821784 100644
--- a/mindspore/ops/functional.py
+++ b/mindspore/ops/functional.py
@@ -20,8 +20,6 @@
 from mindspore.common._register_for_tensor import tensor_operator_registry
 from mindspore.ops import _constants
 from .primitive import Primitive
-from ..common import Tensor
-from ..nn.grad import Jvp
 from . import operations as P
 from .operations import _grad_ops
 from .composite import GradOperation
@@ -156,63 +154,6 @@ identity = P.identity()
 
 grad_first_parameter = GradOperation(get_all=False, get_by_list=False, sens_param=False)
 grad_all_parameters = GradOperation(get_all=True, get_by_list=False, sens_param=False)
 
-def _to_tuple(inp, arg_name):
-    """
-    Check whether input to jvp is valid and convert the input to tuple.
-    """
-    if isinstance(inp, list):
-        inp = tuple(inp)
-    inp_is_tuple = True
-    if not isinstance(inp, tuple):
-        inp = (inp,)
-        inp_is_tuple = False
-    for index, value in enumerate(inp):
-        if not isinstance(value, Tensor):
-            if inp_is_tuple:
-                raise TypeError(
-                    "The value at the index {} of {} must be a Tensor. But it's type is {}".format(index, arg_name,
-                                                                                                   type(value)))
-            raise TypeError("The {} must be a Tensor. But it's type is {}".format(arg_name, type(value)))
-    return inp
-
-
-def jvp(fn, jvp_input, v=None):
-    """
-    Function to compute the jacobian-vector-product of the given network. jvp is equivalent to forward mode autodiff.
-
-    Inputs:
-        - **network** (Cell): The network that takes Tensor inputs and returns a tuple of Tensors or a Tensor.
-        - **inputs** (Tensor or tuple/list of Tensors) - The inputs to `net`.
-        - **v** (tuple/list of Tensors or Tensor) - The vector for which the Jacobian vector product is computed.
-          Must have the same size as the input of `network`.
-
-    Outputs:
-        A tuple with:
-        net_output (Tuple(Tensor...)) - The output of `network(inputs)`.
-        jvp (Tuple(Tensor...)) - The result of the jacobian vector product.
-    """
-    inputs_tuple = _to_tuple(jvp_input, "input")
-    if v is not None:
-        v_tuple = _to_tuple(v, "v")
-        if len(v_tuple) != len(inputs_tuple):
-            raise ValueError("v is not the same size as the function inputs. The length of v is {}, while the length"
-                             "of function inputs is {}".format(len(v_tuple), len(inputs_tuple)))
-        for index, (v_value, inputs_value) in enumerate(zip(v_tuple, inputs_tuple)):
-            if v_value.shape != inputs_value.shape:
-                raise ValueError("The tensor shape at the index {} of v is not the same as the function inputs, it"
-                                 "should be {}, but got {}".format(index, inputs_value.shape, v_value.shape))
-    else:
-        if len(inputs_tuple) != 1 or inputs_tuple[0].shape != (1,):
-            raise ValueError("The vector v can only be None if the input is a Tensor with single element")
-        v = Tensor([1], dtype=inputs_tuple[0].dtype)
-
-    if len(inputs_tuple) != 1:
-        total_input = (*inputs_tuple, v_tuple)
-    else:
-        total_input = (jvp_input, v)
-    return Jvp(fn)(*total_input)
-
 
 def grad(fn, grad_first_param=False):
     """
     A wrapper function to generate the gradient function for the input function.
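
With the functional wrapper gone, call sites change shape: the old `jvp` free function took the network inputs as one tuple, while the `Jvp` cell takes them positionally with the tangent vector last. A before/after sketch, not part of the patch, matching the test updates below:

# --- migration sketch, not part of the patch ---
# Before: primal, grad = jvp(net, (x, y), (v, v))
# After:
primal, grad = Jvp(net)(x, y, (v, v))
# --- end sketch ---
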
diff --git a/tests/st/gradient/test_jvp_graph.py b/tests/st/gradient/test_jvp_graph.py
index 12bd1c478db..9ff8ab96dc0 100644
--- a/tests/st/gradient/test_jvp_graph.py
+++ b/tests/st/gradient/test_jvp_graph.py
@@ -19,7 +19,7 @@ import pytest
 import mindspore.nn as nn
 import mindspore.context as context
 from mindspore import Tensor
-from mindspore.ops.functional import jvp
+from mindspore.nn.grad import Jvp
 
 context.set_context(mode=context.GRAPH_MODE)
 
@@ -53,7 +53,7 @@ def test_jvp_single_input_single_output_default_v_graph():
     net = SingleInputSingleOutputNet()
     expect_primal = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
     expect_grad = Tensor(np.array([[3, 12], [27, 48]]).astype(np.float32))
-    primal, grad = jvp(net, x, v)
+    primal, grad = Jvp(net)(x, v)
     assert np.allclose(primal.asnumpy(), expect_primal.asnumpy())
     assert np.allclose(grad.asnumpy(), expect_grad.asnumpy())
 
@@ -67,7 +67,7 @@ def test_jvp_single_input_single_output_custom_v_graph():
     net = SingleInputSingleOutputNet()
     expect_primal = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
     expect_grad = Tensor(np.array([[3, 24], [81, 192]]).astype(np.float32))
-    primal, grad = jvp(net, x, v)
+    primal, grad = Jvp(net)(x, v)
     assert np.allclose(primal.asnumpy(), expect_primal.asnumpy())
     assert np.allclose(grad.asnumpy(), expect_grad.asnumpy())
 
@@ -83,7 +83,7 @@ def test_jvp_single_input_multiple_outputs_default_v_graph():
     expect_primal_1 = Tensor(np.array([[2, 4], [6, 8]]).astype(np.float32))
     expect_grad_0 = Tensor(np.array([[3, 12], [27, 48]]).astype(np.float32))
     expect_grad_1 = Tensor(np.array([[2, 2], [2, 2]]).astype(np.float32))
-    primal, grad = jvp(net, x, v)
+    primal, grad = Jvp(net)(x, v)
     assert isinstance(primal, tuple)
     assert len(primal) == 2
     assert np.allclose(primal[0].asnumpy(), expect_primal_0.asnumpy())
@@ -105,7 +105,7 @@ def test_jvp_single_input_multiple_outputs_custom_v_graph():
     expect_primal_1 = Tensor(np.array([[2, 4], [6, 8]]).astype(np.float32))
     expect_grad_0 = Tensor(np.array([[3, 24], [81, 192]]).astype(np.float32))
     expect_grad_1 = Tensor(np.array([[2, 4], [6, 8]]).astype(np.float32))
-    primal, grad = jvp(net, x, v)
+    primal, grad = Jvp(net)(x, v)
     assert isinstance(primal, tuple)
     assert len(primal) == 2
     assert np.allclose(primal[0].asnumpy(), expect_primal_0.asnumpy())
@@ -126,7 +126,7 @@ def test_jvp_multiple_inputs_single_output_default_v_graph():
     net = MultipleInputSingleOutputNet()
     expect_primal = Tensor(np.array([[5, 10], [15, 20]]).astype(np.float32))
     expect_grad = Tensor(np.array([[5, 5], [5, 5]]).astype(np.float32))
-    primal, grad = jvp(net, (x, y), (v, v))
+    primal, grad = Jvp(net)(x, y, (v, v))
     assert np.allclose(primal.asnumpy(), expect_primal.asnumpy())
     assert np.allclose(grad.asnumpy(), expect_grad.asnumpy())
 
@@ -142,7 +142,7 @@ def test_jvp_multiple_inputs_single_output_custom_v_graph():
     net = MultipleInputSingleOutputNet()
     expect_primal = Tensor(np.array([[5, 10], [15, 20]]).astype(np.float32))
     expect_grad = Tensor(np.array([[5, 8], [11, 14]]).astype(np.float32))
-    primal, grad = jvp(net, (x, y), (v1, v2))
+    primal, grad = Jvp(net)(x, y, (v1, v2))
     assert np.allclose(primal.asnumpy(), expect_primal.asnumpy())
     assert np.allclose(grad.asnumpy(), expect_grad.asnumpy())
 
@@ -159,7 +159,7 @@ def test_jvp_multiple_inputs_multiple_outputs_default_v_graph():
     expect_primal_1 = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
     expect_grad_0 = Tensor(np.array([[2, 2], [2, 2]]).astype(np.float32))
     expect_grad_1 = Tensor(np.array([[3, 12], [27, 48]]).astype(np.float32))
-    primal, grad = jvp(net, (x, y), (v, v))
+    primal, grad = Jvp(net)(x, y, (v, v))
     assert isinstance(primal, tuple)
     assert len(primal) == 2
     assert np.allclose(primal[0].asnumpy(), expect_primal_0.asnumpy())
@@ -183,7 +183,7 @@ def test_jvp_multiple_inputs_multiple_outputs_custom_v_graph():
     expect_primal_1 = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
     expect_grad_0 = Tensor(np.array([[2, 2], [2, 2]]).astype(np.float32))
     expect_grad_1 = Tensor(np.array([[3, 24], [81, 192]]).astype(np.float32))
-    primal, grad = jvp(net, (x, y), (v1, v2))
+    primal, grad = Jvp(net)(x, y, (v1, v2))
     assert isinstance(primal, tuple)
     assert len(primal) == 2
     assert np.allclose(primal[0].asnumpy(), expect_primal_0.asnumpy())
@@ -192,51 +192,3 @@ def test_jvp_multiple_inputs_multiple_outputs_custom_v_graph():
     assert len(grad) == 2
     assert np.allclose(grad[0].asnumpy(), expect_grad_0.asnumpy())
     assert np.allclose(grad[1].asnumpy(), expect_grad_1.asnumpy())
-
-
-@pytest.mark.level0
-@pytest.mark.platform_x86_cpu
-@pytest.mark.env_onecard
-def test_jvp_single_input_single_tensor_default_v():
-    x = Tensor(np.array([1]).astype(np.float32))
-    net = SingleInputSingleOutputNet()
-    expect_primal = Tensor(np.array([1]).astype(np.float32))
-    expect_grad = Tensor(np.array([3]).astype(np.float32))
-    primal, grad = jvp(net, x)
-    assert np.allclose(primal.asnumpy(), expect_primal.asnumpy())
-    assert np.allclose(grad.asnumpy(), expect_grad.asnumpy())
-
-
-@pytest.mark.level0
-@pytest.mark.platform_x86_cpu
-@pytest.mark.env_onecard
-def test_jvp_single_input_multiple_tensors_default_v():
-    x = Tensor(np.array([1, 2]).astype(np.float32))
-    net = SingleInputSingleOutputNet()
-    with pytest.raises(ValueError) as ex:
-        jvp(net, x)
-    assert "The vector v can only be None if the input is " in str(ex.value)
-
-
-@pytest.mark.level0
-@pytest.mark.platform_x86_cpu
-@pytest.mark.env_onecard
-def test_jvp_single_input_multiple_tensors_wrong_shape_v():
-    x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
-    v = Tensor(np.array([[1], [4]]).astype(np.float32))
-    net = SingleInputSingleOutputNet()
-    with pytest.raises(ValueError) as ex:
-        jvp(net, x, v)
-    assert "The tensor shape at the index 0 of v" in str(ex.value)
-
-
-@pytest.mark.level0
-@pytest.mark.platform_x86_cpu
-@pytest.mark.env_onecard
-def test_jvp_single_input_multiple_tensors_wrong_tuple_v():
-    x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
-    v = (Tensor(np.array([[1], [4]]).astype(np.float32)), Tensor(np.array([[1], [4]]).astype(np.float32)))
-    net = SingleInputSingleOutputNet()
-    with pytest.raises(ValueError) as ex:
-        jvp(net, x, v)
-    assert "v is not the same size as the function inputs." in str(ex.value)
diff --git a/tests/st/gradient/test_jvp_pynative.py b/tests/st/gradient/test_jvp_pynative.py
index d4588b4f461..7372d4b8c04 100644
--- a/tests/st/gradient/test_jvp_pynative.py
+++ b/tests/st/gradient/test_jvp_pynative.py
@@ -19,7 +19,7 @@ import pytest
 import mindspore.nn as nn
 import mindspore.context as context
 from mindspore import Tensor
-from mindspore.ops.functional import jvp
+from mindspore.nn.grad import Jvp
 
 context.set_context(mode=context.PYNATIVE_MODE)
 
@@ -52,7 +52,7 @@ def test_jvp_single_input_single_output_default_v_pynative():
     net = SingleInputSingleOutputNet()
     expect_primal = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
     expect_grad = Tensor(np.array([[3, 12], [27, 48]]).astype(np.float32))
-    primal, grad = jvp(net, x, v)
+    primal, grad = Jvp(net)(x, v)
     assert np.allclose(primal.asnumpy(), expect_primal.asnumpy())
     assert np.allclose(grad.asnumpy(), expect_grad.asnumpy())
 
@@ -66,7 +66,7 @@ def test_jvp_single_input_single_output_custom_v_pynative():
     net = SingleInputSingleOutputNet()
     expect_primal = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
     expect_grad = Tensor(np.array([[3, 24], [81, 192]]).astype(np.float32))
-    primal, grad = jvp(net, x, v)
+    primal, grad = Jvp(net)(x, v)
     assert np.allclose(primal.asnumpy(), expect_primal.asnumpy())
     assert np.allclose(grad.asnumpy(), expect_grad.asnumpy())
 
@@ -82,7 +82,7 @@ def test_jvp_single_input_multiple_outputs_default_v_pynative():
     expect_primal_1 = Tensor(np.array([[2, 4], [6, 8]]).astype(np.float32))
     expect_grad_0 = Tensor(np.array([[3, 12], [27, 48]]).astype(np.float32))
     expect_grad_1 = Tensor(np.array([[2, 2], [2, 2]]).astype(np.float32))
-    primal, grad = jvp(net, x, v)
+    primal, grad = Jvp(net)(x, v)
     assert isinstance(primal, tuple)
     assert len(primal) == 2
     assert np.allclose(primal[0].asnumpy(), expect_primal_0.asnumpy())
@@ -104,7 +104,7 @@ def test_jvp_single_input_multiple_outputs_custom_v_pynative():
     expect_primal_1 = Tensor(np.array([[2, 4], [6, 8]]).astype(np.float32))
     expect_grad_0 = Tensor(np.array([[3, 24], [81, 192]]).astype(np.float32))
     expect_grad_1 = Tensor(np.array([[2, 4], [6, 8]]).astype(np.float32))
-    primal, grad = jvp(net, x, v)
+    primal, grad = Jvp(net)(x, v)
     assert isinstance(primal, tuple)
     assert len(primal) == 2
     assert np.allclose(primal[0].asnumpy(), expect_primal_0.asnumpy())
@@ -127,7 +127,7 @@ def test_jvp_multiple_inputs_multiple_outputs_default_v_pynative():
     expect_primal_1 = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
     expect_grad_0 = Tensor(np.array([[2, 2], [2, 2]]).astype(np.float32))
     expect_grad_1 = Tensor(np.array([[3, 12], [27, 48]]).astype(np.float32))
-    primal, grad = jvp(net, (x, y), (v, v))
+    primal, grad = Jvp(net)(x, y, (v, v))
     assert isinstance(primal, tuple)
     assert len(primal) == 2
     assert np.allclose(primal[0].asnumpy(), expect_primal_0.asnumpy())
@@ -151,7 +151,7 @@ def test_jvp_multiple_inputs_multiple_outputs_custom_v_pynative():
     expect_primal_1 = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
     expect_grad_0 = Tensor(np.array([[2, 2], [2, 2]]).astype(np.float32))
     expect_grad_1 = Tensor(np.array([[3, 24], [81, 192]]).astype(np.float32))
-    primal, grad = jvp(net, (x, y), (v1, v2))
+    primal, grad = Jvp(net)(x, y, (v1, v2))
     assert isinstance(primal, tuple)
     assert len(primal) == 2
     assert np.allclose(primal[0].asnumpy(), expect_primal_0.asnumpy())
@@ -172,7 +172,7 @@ def test_jvp_multiple_inputs_single_output_default_v_pynative():
     net = MultipleInputSingleOutputNet()
     expect_primal = Tensor(np.array([[5, 10], [15, 20]]).astype(np.float32))
     expect_grad = Tensor(np.array([[5, 5], [5, 5]]).astype(np.float32))
-    primal, grad = jvp(net, (x, y), (v, v))
+    primal, grad = Jvp(net)(x, y, (v, v))
     assert np.allclose(primal.asnumpy(), expect_primal.asnumpy())
     assert np.allclose(grad.asnumpy(), expect_grad.asnumpy())
 
@@ -188,6 +188,6 @@ def test_jvp_multiple_inputs_single_output_custom_v_pynative():
     net = MultipleInputSingleOutputNet()
     expect_primal = Tensor(np.array([[5, 10], [15, 20]]).astype(np.float32))
     expect_grad = Tensor(np.array([[5, 8], [11, 14]]).astype(np.float32))
-    primal, grad = jvp(net, (x, y), (v1, v2))
+    primal, grad = Jvp(net)(x, y, (v1, v2))
     assert np.allclose(primal.asnumpy(), expect_primal.asnumpy())
     assert np.allclose(grad.asnumpy(), expect_grad.asnumpy())
diff --git a/tests/st/gradient/test_vjp_graph.py b/tests/st/gradient/test_vjp_graph.py
new file mode 100644
index 00000000000..71edbbf72a5
--- /dev/null
+++ b/tests/st/gradient/test_vjp_graph.py
@@ -0,0 +1,71 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""test vjp in graph mode"""
+import numpy as np
+import pytest
+import mindspore.nn as nn
+import mindspore.context as context
+from mindspore import Tensor
+from mindspore.nn.grad import Vjp
+
+context.set_context(mode=context.GRAPH_MODE)
+
+
+class SingleInputNet(nn.Cell):
+    def construct(self, x):
+        return x**3
+
+
+class MultipleInputsOutputNet(nn.Cell):
+    def construct(self, x, y):
+        return 2*x, y**3
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_vjp_single_input_graph():
+    x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
+    v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
+    net = SingleInputNet()
+    expect_primal = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
+    expect_grad = Tensor(np.array([[3, 12], [27, 48]]).astype(np.float32))
+    primal, grad = Vjp(net)(x, v)
+    assert np.allclose(primal.asnumpy(), expect_primal.asnumpy())
+    assert np.allclose(grad.asnumpy(), expect_grad.asnumpy())
+
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_vjp_multiple_inputs_default_v_graph():
+    x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
+    y = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
+    v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
+    net = MultipleInputsOutputNet()
+    expect_primal_0 = Tensor(np.array([[2, 4], [6, 8]]).astype(np.float32))
+    expect_primal_1 = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
+    expect_grad_0 = Tensor(np.array([[2, 2], [2, 2]]).astype(np.float32))
+    expect_grad_1 = Tensor(np.array([[3, 12], [27, 48]]).astype(np.float32))
+    primal, grad = Vjp(net)(x, y, (v, v))
+    assert isinstance(primal, tuple)
+    assert len(primal) == 2
+    assert np.allclose(primal[0].asnumpy(), expect_primal_0.asnumpy())
+    assert np.allclose(primal[1].asnumpy(), expect_primal_1.asnumpy())
+    assert isinstance(grad, tuple)
+    assert len(grad) == 2
+    assert np.allclose(grad[0].asnumpy(), expect_grad_0.asnumpy())
+    assert np.allclose(grad[1].asnumpy(), expect_grad_1.asnumpy())
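
As a sanity check on the expected values in the new Vjp tests, a NumPy-only sketch, not part of the patch: for the elementwise f(x) = x**3 the Jacobian is diagonal with entries 3*x**2, so both the jvp and the vjp against an all-ones vector reduce to 3 * x**2.

# --- numeric check, not part of the patch ---
import numpy as np

x = np.array([[1, 2], [3, 4]], dtype=np.float32)
v = np.ones((2, 2), dtype=np.float32)

primal = x ** 3        # [[1, 8], [27, 64]]
grad = 3 * x ** 2 * v  # [[3, 12], [27, 48]]

assert np.allclose(primal, [[1, 8], [27, 64]])
assert np.allclose(grad, [[3, 12], [27, 48]])
# --- end check ---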