Vjp and fix jvp

This commit is contained in:
l00591931 2021-08-27 17:41:51 +08:00
parent 12ced9e89b
commit 4f4a344149
6 changed files with 134 additions and 131 deletions

View File

@ -18,7 +18,7 @@ Grad
Cells of grad function. Calculate the gradient of input network or function.
"""
from .cell_grad import Jvp
from .cell_grad import Jvp, Vjp
__all__ = ['Jvp']
__all__ = ['Jvp', 'Vjp']

View File

@ -15,8 +15,8 @@
"""cell grad"""
from ..cell import Cell
from ...ops import composite as C
from ...ops.primitive import Primitive
from ...ops import operations as P
from ...ops.primitive import Primitive
from ...common import dtype as mstype
from ...common.api import ms_function
@ -73,9 +73,9 @@ class Jvp(Cell):
self.tuple_len = Primitive("tuple_len")
@ms_function
def construct(self, *total_input):
jvp_input = total_input[0:-1]
v = total_input[-1]
def construct(self, *args):
jvp_input = args[0:-1]
v = args[-1]
output = self.fn(*jvp_input)
if self.issubclass_(self.typeof(output), mstype.tuple_):
@ -92,3 +92,42 @@ class Jvp(Cell):
second_gradient_net = self.second_grad_op(self.first_grad)
gradient_output = second_gradient_net(u, jvp_input, v)
return output, gradient_output
class Vjp(Cell):
"""
Computes the dot product between a vector `v` and the Jacobian of the given network at the point
given by the inputs.
Args:
network (Cell): The network that takes Tensor inputs and returns a tuple of Tensors or a Tensor.
Inputs:
- **inputs** (Tensors) - The inputs to `net`. Must be a tuple or a list.
- **v** (tuple of Tensors or Tensor) - The vector for which the vector Jacobian product is computed.
Must have the same size as the output of `network`.
Outputs:
A tuple with:
net_output (Tuple(Tensor...)) - The output of `network(inputs)`.
vjp (Tuple(Tensor...)) - The result of the dot product.
"""
def __init__(self, fn):
super(Vjp, self).__init__()
self.fn = fn
self.grad = C.GradOperation(get_all=True, sens_param=True)
self.grad_single_value = C.GradOperation(sens_param=True)
self.issubclass_ = P.IsSubClass()
self.typeof = Primitive('typeof')
self.tuple_len = Primitive("tuple_len")
@ms_function
def construct(self, *args):
front_input = args[0:-1]
output = self.fn(*front_input)
if self.tuple_len(front_input) == 1:
gradient_output = self.grad_single_value(self.fn)(*args)
else:
gradient_output = self.grad(self.fn)(*args)
return output, gradient_output

View File

@ -20,8 +20,6 @@
from mindspore.common._register_for_tensor import tensor_operator_registry
from mindspore.ops import _constants
from .primitive import Primitive
from ..common import Tensor
from ..nn.grad import Jvp
from . import operations as P
from .operations import _grad_ops
from .composite import GradOperation
@ -156,63 +154,6 @@ identity = P.identity()
grad_first_parameter = GradOperation(get_all=False, get_by_list=False, sens_param=False)
grad_all_parameters = GradOperation(get_all=True, get_by_list=False, sens_param=False)
def _to_tuple(inp, arg_name):
"""
Check whether input to jvp is valid and convert the input to tuple.
"""
if isinstance(inp, list):
inp = tuple(inp)
inp_is_tuple = True
if not isinstance(inp, tuple):
inp = (inp,)
inp_is_tuple = False
for index, value in enumerate(inp):
if not isinstance(value, Tensor):
if inp_is_tuple:
raise TypeError(
"The value at the index {} of {} must be a Tensor. But it's type is {}".format(index, arg_name,
type(value)))
raise TypeError("The {} must be a Tensor. But it's type is {}".format(arg_name, type(value)))
return inp
def jvp(fn, jvp_input, v=None):
"""
Function to compute the jacobian-vector-product of the given network. jvp is equivalent to forward mode autodiff.
Inputs:
- **network** (Cell): The network that takes Tensor inputs and returns a tuple of Tensors or a Tensor.
- **inputs** (Tensor or tuple/list of Tensors) - The inputs to `net`.
- **v** (tuple/list of Tensors or Tensor) - The vector for which the Jacobian vector product is computed.
Must have the same size as the input of `network`.
Outputs:
A tuple with:
net_output (Tuple(Tensor...)) - The output of `network(inputs)`.
jvp (Tuple(Tensor...)) - The result of the jacobian vector product.
"""
inputs_tuple = _to_tuple(jvp_input, "input")
if v is not None:
v_tuple = _to_tuple(v, "v")
if len(v_tuple) != len(inputs_tuple):
raise ValueError("v is not the same size as the function inputs. The length of v is {}, while the length"
"of function inputs is {}".format(len(v_tuple), len(inputs_tuple)))
for index, (v_value, inputs_value) in enumerate(zip(v_tuple, inputs_tuple)):
if v_value.shape != inputs_value.shape:
raise ValueError("The tensor shape at the index {} of v is not the same as the function inputs, it"
"should be {}, but got {}".format(index, inputs_value.shape, v_value.shape))
else:
if len(inputs_tuple) != 1 or inputs_tuple[0].shape != (1,):
raise ValueError("The vector v can only be None if the input is a Tensor with single element")
v = Tensor([1], dtype=inputs_tuple[0].dtype)
if len(inputs_tuple) != 1:
total_input = (*inputs_tuple, v_tuple)
else:
total_input = (jvp_input, v)
return Jvp(fn)(*total_input)
def grad(fn, grad_first_param=False):
"""
A wrapper function to generate the gradient function for the input function.

View File

@ -19,7 +19,7 @@ import pytest
import mindspore.nn as nn
import mindspore.context as context
from mindspore import Tensor
from mindspore.ops.functional import jvp
from mindspore.nn.grad import Jvp
context.set_context(mode=context.GRAPH_MODE)
@ -53,7 +53,7 @@ def test_jvp_single_input_single_output_default_v_graph():
net = SingleInputSingleOutputNet()
expect_primal = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
expect_grad = Tensor(np.array([[3, 12], [27, 48]]).astype(np.float32))
primal, grad = jvp(net, x, v)
primal, grad = Jvp(net)(x, v)
assert np.allclose(primal.asnumpy(), expect_primal.asnumpy())
assert np.allclose(grad.asnumpy(), expect_grad.asnumpy())
@ -67,7 +67,7 @@ def test_jvp_single_input_single_output_custom_v_graph():
net = SingleInputSingleOutputNet()
expect_primal = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
expect_grad = Tensor(np.array([[3, 24], [81, 192]]).astype(np.float32))
primal, grad = jvp(net, x, v)
primal, grad = Jvp(net)(x, v)
assert np.allclose(primal.asnumpy(), expect_primal.asnumpy())
assert np.allclose(grad.asnumpy(), expect_grad.asnumpy())
@ -83,7 +83,7 @@ def test_jvp_single_input_multiple_outputs_default_v_graph():
expect_primal_1 = Tensor(np.array([[2, 4], [6, 8]]).astype(np.float32))
expect_grad_0 = Tensor(np.array([[3, 12], [27, 48]]).astype(np.float32))
expect_grad_1 = Tensor(np.array([[2, 2], [2, 2]]).astype(np.float32))
primal, grad = jvp(net, x, v)
primal, grad = Jvp(net)(x, v)
assert isinstance(primal, tuple)
assert len(primal) == 2
assert np.allclose(primal[0].asnumpy(), expect_primal_0.asnumpy())
@ -105,7 +105,7 @@ def test_jvp_single_input_multiple_outputs_custom_v_graph():
expect_primal_1 = Tensor(np.array([[2, 4], [6, 8]]).astype(np.float32))
expect_grad_0 = Tensor(np.array([[3, 24], [81, 192]]).astype(np.float32))
expect_grad_1 = Tensor(np.array([[2, 4], [6, 8]]).astype(np.float32))
primal, grad = jvp(net, x, v)
primal, grad = Jvp(net)(x, v)
assert isinstance(primal, tuple)
assert len(primal) == 2
assert np.allclose(primal[0].asnumpy(), expect_primal_0.asnumpy())
@ -126,7 +126,7 @@ def test_jvp_multiple_inputs_single_output_default_v_graph():
net = MultipleInputSingleOutputNet()
expect_primal = Tensor(np.array([[5, 10], [15, 20]]).astype(np.float32))
expect_grad = Tensor(np.array([[5, 5], [5, 5]]).astype(np.float32))
primal, grad = jvp(net, (x, y), (v, v))
primal, grad = Jvp(net)(x, y, (v, v))
assert np.allclose(primal.asnumpy(), expect_primal.asnumpy())
assert np.allclose(grad.asnumpy(), expect_grad.asnumpy())
@ -142,7 +142,7 @@ def test_jvp_multiple_inputs_single_output_custom_v_graph():
net = MultipleInputSingleOutputNet()
expect_primal = Tensor(np.array([[5, 10], [15, 20]]).astype(np.float32))
expect_grad = Tensor(np.array([[5, 8], [11, 14]]).astype(np.float32))
primal, grad = jvp(net, (x, y), (v1, v2))
primal, grad = Jvp(net)(x, y, (v1, v2))
assert np.allclose(primal.asnumpy(), expect_primal.asnumpy())
assert np.allclose(grad.asnumpy(), expect_grad.asnumpy())
@ -159,7 +159,7 @@ def test_jvp_multiple_inputs_multiple_outputs_default_v_graph():
expect_primal_1 = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
expect_grad_0 = Tensor(np.array([[2, 2], [2, 2]]).astype(np.float32))
expect_grad_1 = Tensor(np.array([[3, 12], [27, 48]]).astype(np.float32))
primal, grad = jvp(net, (x, y), (v, v))
primal, grad = Jvp(net)(x, y, (v, v))
assert isinstance(primal, tuple)
assert len(primal) == 2
assert np.allclose(primal[0].asnumpy(), expect_primal_0.asnumpy())
@ -183,7 +183,7 @@ def test_jvp_multiple_inputs_multiple_outputs_custom_v_graph():
expect_primal_1 = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
expect_grad_0 = Tensor(np.array([[2, 2], [2, 2]]).astype(np.float32))
expect_grad_1 = Tensor(np.array([[3, 24], [81, 192]]).astype(np.float32))
primal, grad = jvp(net, (x, y), (v1, v2))
primal, grad = Jvp(net)(x, y, (v1, v2))
assert isinstance(primal, tuple)
assert len(primal) == 2
assert np.allclose(primal[0].asnumpy(), expect_primal_0.asnumpy())
@ -192,51 +192,3 @@ def test_jvp_multiple_inputs_multiple_outputs_custom_v_graph():
assert len(grad) == 2
assert np.allclose(grad[0].asnumpy(), expect_grad_0.asnumpy())
assert np.allclose(grad[1].asnumpy(), expect_grad_1.asnumpy())
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_jvp_single_input_single_tensor_default_v():
x = Tensor(np.array([1]).astype(np.float32))
net = SingleInputSingleOutputNet()
expect_primal = Tensor(np.array([1]).astype(np.float32))
expect_grad = Tensor(np.array([3]).astype(np.float32))
primal, grad = jvp(net, x)
assert np.allclose(primal.asnumpy(), expect_primal.asnumpy())
assert np.allclose(grad.asnumpy(), expect_grad.asnumpy())
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_jvp_single_input_multiple_tensors_default_v():
x = Tensor(np.array([1, 2]).astype(np.float32))
net = SingleInputSingleOutputNet()
with pytest.raises(ValueError) as ex:
jvp(net, x)
assert "The vector v can only be None if the input is " in str(ex.value)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_jvp_single_input_multiple_tensors_wrong_shape_v():
x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
v = Tensor(np.array([[1], [4]]).astype(np.float32))
net = SingleInputSingleOutputNet()
with pytest.raises(ValueError) as ex:
jvp(net, x, v)
assert "The tensor shape at the index 0 of v" in str(ex.value)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_jvp_single_input_multiple_tensors_wrong_tuple_v():
x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
v = (Tensor(np.array([[1], [4]]).astype(np.float32)), Tensor(np.array([[1], [4]]).astype(np.float32)))
net = SingleInputSingleOutputNet()
with pytest.raises(ValueError) as ex:
jvp(net, x, v)
assert "v is not the same size as the function inputs." in str(ex.value)

View File

@ -19,7 +19,7 @@ import pytest
import mindspore.nn as nn
import mindspore.context as context
from mindspore import Tensor
from mindspore.ops.functional import jvp
from mindspore.nn.grad import Jvp
context.set_context(mode=context.PYNATIVE_MODE)
@ -52,7 +52,7 @@ def test_jvp_single_input_single_output_default_v_pynative():
net = SingleInputSingleOutputNet()
expect_primal = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
expect_grad = Tensor(np.array([[3, 12], [27, 48]]).astype(np.float32))
primal, grad = jvp(net, x, v)
primal, grad = Jvp(net)(x, v)
assert np.allclose(primal.asnumpy(), expect_primal.asnumpy())
assert np.allclose(grad.asnumpy(), expect_grad.asnumpy())
@ -66,7 +66,7 @@ def test_jvp_single_input_single_output_custom_v_pynative():
net = SingleInputSingleOutputNet()
expect_primal = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
expect_grad = Tensor(np.array([[3, 24], [81, 192]]).astype(np.float32))
primal, grad = jvp(net, x, v)
primal, grad = Jvp(net)(x, v)
assert np.allclose(primal.asnumpy(), expect_primal.asnumpy())
assert np.allclose(grad.asnumpy(), expect_grad.asnumpy())
@ -82,7 +82,7 @@ def test_jvp_single_input_multiple_outputs_default_v_pynative():
expect_primal_1 = Tensor(np.array([[2, 4], [6, 8]]).astype(np.float32))
expect_grad_0 = Tensor(np.array([[3, 12], [27, 48]]).astype(np.float32))
expect_grad_1 = Tensor(np.array([[2, 2], [2, 2]]).astype(np.float32))
primal, grad = jvp(net, x, v)
primal, grad = Jvp(net)(x, v)
assert isinstance(primal, tuple)
assert len(primal) == 2
assert np.allclose(primal[0].asnumpy(), expect_primal_0.asnumpy())
@ -104,7 +104,7 @@ def test_jvp_single_input_multiple_outputs_custom_v_pynative():
expect_primal_1 = Tensor(np.array([[2, 4], [6, 8]]).astype(np.float32))
expect_grad_0 = Tensor(np.array([[3, 24], [81, 192]]).astype(np.float32))
expect_grad_1 = Tensor(np.array([[2, 4], [6, 8]]).astype(np.float32))
primal, grad = jvp(net, x, v)
primal, grad = Jvp(net)(x, v)
assert isinstance(primal, tuple)
assert len(primal) == 2
assert np.allclose(primal[0].asnumpy(), expect_primal_0.asnumpy())
@ -127,7 +127,7 @@ def test_jvp_multiple_inputs_multiple_outputs_default_v_pynative():
expect_primal_1 = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
expect_grad_0 = Tensor(np.array([[2, 2], [2, 2]]).astype(np.float32))
expect_grad_1 = Tensor(np.array([[3, 12], [27, 48]]).astype(np.float32))
primal, grad = jvp(net, (x, y), (v, v))
primal, grad = Jvp(net)(x, y, (v, v))
assert isinstance(primal, tuple)
assert len(primal) == 2
assert np.allclose(primal[0].asnumpy(), expect_primal_0.asnumpy())
@ -151,7 +151,7 @@ def test_jvp_multiple_inputs_multiple_outputs_custom_v_pynative():
expect_primal_1 = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
expect_grad_0 = Tensor(np.array([[2, 2], [2, 2]]).astype(np.float32))
expect_grad_1 = Tensor(np.array([[3, 24], [81, 192]]).astype(np.float32))
primal, grad = jvp(net, (x, y), (v1, v2))
primal, grad = Jvp(net)(x, y, (v1, v2))
assert isinstance(primal, tuple)
assert len(primal) == 2
assert np.allclose(primal[0].asnumpy(), expect_primal_0.asnumpy())
@ -172,7 +172,7 @@ def test_jvp_multiple_inputs_single_output_default_v_pynative():
net = MultipleInputSingleOutputNet()
expect_primal = Tensor(np.array([[5, 10], [15, 20]]).astype(np.float32))
expect_grad = Tensor(np.array([[5, 5], [5, 5]]).astype(np.float32))
primal, grad = jvp(net, (x, y), (v, v))
primal, grad = Jvp(net)(x, y, (v, v))
assert np.allclose(primal.asnumpy(), expect_primal.asnumpy())
assert np.allclose(grad.asnumpy(), expect_grad.asnumpy())
@ -188,6 +188,6 @@ def test_jvp_multiple_inputs_single_output_custom_v_pynative():
net = MultipleInputSingleOutputNet()
expect_primal = Tensor(np.array([[5, 10], [15, 20]]).astype(np.float32))
expect_grad = Tensor(np.array([[5, 8], [11, 14]]).astype(np.float32))
primal, grad = jvp(net, (x, y), (v1, v2))
primal, grad = Jvp(net)(x, y, (v1, v2))
assert np.allclose(primal.asnumpy(), expect_primal.asnumpy())
assert np.allclose(grad.asnumpy(), expect_grad.asnumpy())

View File

@ -0,0 +1,71 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""test jvp in graph mode"""
import numpy as np
import pytest
import mindspore.nn as nn
import mindspore.context as context
from mindspore import Tensor
from mindspore.nn.grad import Vjp
context.set_context(mode=context.GRAPH_MODE)
class SingleInputNet(nn.Cell):
def construct(self, x):
return x**3
class MultipleInputsOutputNet(nn.Cell):
def construct(self, x, y):
return 2*x, y**3
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_vjp_single_input_graph():
x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
net = SingleInputNet()
expect_primal = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
expect_grad = Tensor(np.array([[3, 12], [27, 48]]).astype(np.float32))
primal, grad = Vjp(net)(x, v)
assert np.allclose(primal.asnumpy(), expect_primal.asnumpy())
assert np.allclose(grad.asnumpy(), expect_grad.asnumpy())
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_vjp_multiple_inputs_default_v_graph():
x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
y = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
net = MultipleInputsOutputNet()
expect_primal_0 = Tensor(np.array([[2, 4], [6, 8]]).astype(np.float32))
expect_primal_1 = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
expect_grad_0 = Tensor(np.array([[2, 2], [2, 2]]).astype(np.float32))
expect_grad_1 = Tensor(np.array([[3, 12], [27, 48]]).astype(np.float32))
primal, grad = Vjp(net)(x, y, (v, v))
assert isinstance(primal, tuple)
assert len(primal) == 2
assert np.allclose(primal[0].asnumpy(), expect_primal_0.asnumpy())
assert np.allclose(primal[1].asnumpy(), expect_primal_1.asnumpy())
assert isinstance(grad, tuple)
assert len(grad) == 2
assert np.allclose(grad[0].asnumpy(), expect_grad_0.asnumpy())
assert np.allclose(grad[1].asnumpy(), expect_grad_1.asnumpy())