Add Vjp and fix Jvp

This commit is contained in:
l00591931 2021-08-27 17:41:51 +08:00
parent 12ced9e89b
commit 4f4a344149
6 changed files with 134 additions and 131 deletions

View File

@@ -18,7 +18,7 @@ Grad
 Cells of grad function. Calculate the gradient of input network or function.
 """
-from .cell_grad import Jvp
+from .cell_grad import Jvp, Vjp
-__all__ = ['Jvp']
+__all__ = ['Jvp', 'Vjp']

View File

@@ -15,8 +15,8 @@
 """cell grad"""
 from ..cell import Cell
 from ...ops import composite as C
-from ...ops.primitive import Primitive
 from ...ops import operations as P
+from ...ops.primitive import Primitive
 from ...common import dtype as mstype
 from ...common.api import ms_function
@@ -73,9 +73,9 @@ class Jvp(Cell):
         self.tuple_len = Primitive("tuple_len")
     @ms_function
-    def construct(self, *total_input):
-        jvp_input = total_input[0:-1]
-        v = total_input[-1]
+    def construct(self, *args):
+        jvp_input = args[0:-1]
+        v = args[-1]
         output = self.fn(*jvp_input)
         if self.issubclass_(self.typeof(output), mstype.tuple_):
@@ -92,3 +92,42 @@ class Jvp(Cell):
         second_gradient_net = self.second_grad_op(self.first_grad)
         gradient_output = second_gradient_net(u, jvp_input, v)
         return output, gradient_output
+
+
+class Vjp(Cell):
+    """
+    Computes the dot product between a vector `v` and the Jacobian of the given network at the point
+    given by the inputs.
+
+    Args:
+        network (Cell): The network that takes Tensor inputs and returns a tuple of Tensors or a Tensor.
+
+    Inputs:
+        - **inputs** (Tensors) - The inputs to `network`, passed positionally.
+        - **v** (tuple of Tensors or Tensor) - The vector for which the vector-Jacobian product is
+          computed. Must have the same size as the output of `network`.
+
+    Outputs:
+        A tuple with:
+
+        - **net_output** (Tuple(Tensor...)) - The output of `network(inputs)`.
+        - **vjp** (Tuple(Tensor...)) - The result of the dot product.
+    """
+
+    def __init__(self, fn):
+        super(Vjp, self).__init__()
+        self.fn = fn
+        self.grad = C.GradOperation(get_all=True, sens_param=True)
+        self.grad_single_value = C.GradOperation(sens_param=True)
+        self.issubclass_ = P.IsSubClass()
+        self.typeof = Primitive('typeof')
+        self.tuple_len = Primitive("tuple_len")
+
+    @ms_function
+    def construct(self, *args):
+        front_input = args[0:-1]
+        output = self.fn(*front_input)
+        if self.tuple_len(front_input) == 1:
+            gradient_output = self.grad_single_value(self.fn)(*args)
+        else:
+            gradient_output = self.grad(self.fn)(*args)
+        return output, gradient_output

View File

@@ -20,8 +20,6 @@
 from mindspore.common._register_for_tensor import tensor_operator_registry
 from mindspore.ops import _constants
 from .primitive import Primitive
-from ..common import Tensor
-from ..nn.grad import Jvp
 from . import operations as P
 from .operations import _grad_ops
 from .composite import GradOperation
@@ -156,63 +154,6 @@ identity = P.identity()
 grad_first_parameter = GradOperation(get_all=False, get_by_list=False, sens_param=False)
 grad_all_parameters = GradOperation(get_all=True, get_by_list=False, sens_param=False)
-
-
-def _to_tuple(inp, arg_name):
-    """
-    Check whether input to jvp is valid and convert the input to tuple.
-    """
-    if isinstance(inp, list):
-        inp = tuple(inp)
-    inp_is_tuple = True
-    if not isinstance(inp, tuple):
-        inp = (inp,)
-        inp_is_tuple = False
-    for index, value in enumerate(inp):
-        if not isinstance(value, Tensor):
-            if inp_is_tuple:
-                raise TypeError(
-                    "The value at the index {} of {} must be a Tensor. But it's type is {}".format(index, arg_name,
-                                                                                                   type(value)))
-            raise TypeError("The {} must be a Tensor. But it's type is {}".format(arg_name, type(value)))
-    return inp
-
-
-def jvp(fn, jvp_input, v=None):
-    """
-    Function to compute the jacobian-vector-product of the given network. jvp is equivalent to forward mode autodiff.
-
-    Inputs:
-        - **network** (Cell): The network that takes Tensor inputs and returns a tuple of Tensors or a Tensor.
-        - **inputs** (Tensor or tuple/list of Tensors) - The inputs to `net`.
-        - **v** (tuple/list of Tensors or Tensor) - The vector for which the Jacobian vector product is computed.
-          Must have the same size as the input of `network`.
-
-    Outputs:
-        A tuple with:
-        net_output (Tuple(Tensor...)) - The output of `network(inputs)`.
-        jvp (Tuple(Tensor...)) - The result of the jacobian vector product.
-    """
-    inputs_tuple = _to_tuple(jvp_input, "input")
-    if v is not None:
-        v_tuple = _to_tuple(v, "v")
-        if len(v_tuple) != len(inputs_tuple):
-            raise ValueError("v is not the same size as the function inputs. The length of v is {}, while the length"
-                             "of function inputs is {}".format(len(v_tuple), len(inputs_tuple)))
-        for index, (v_value, inputs_value) in enumerate(zip(v_tuple, inputs_tuple)):
-            if v_value.shape != inputs_value.shape:
-                raise ValueError("The tensor shape at the index {} of v is not the same as the function inputs, it"
-                                 "should be {}, but got {}".format(index, inputs_value.shape, v_value.shape))
-    else:
-        if len(inputs_tuple) != 1 or inputs_tuple[0].shape != (1,):
-            raise ValueError("The vector v can only be None if the input is a Tensor with single element")
-        v = Tensor([1], dtype=inputs_tuple[0].dtype)
-    if len(inputs_tuple) != 1:
-        total_input = (*inputs_tuple, v_tuple)
-    else:
-        total_input = (jvp_input, v)
-    return Jvp(fn)(*total_input)
-
-
 def grad(fn, grad_first_param=False):
     """
     A wrapper function to generate the gradient function for the input function.
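With the functional wrapper removed, callers construct the Cell directly. A sketch of the before/after call sites, taken from the test updates below (net, x, y, v, v1, v2 as defined there):

# Before (functional API, deleted above):
#     primal, grad = jvp(net, x, v)               # single input
#     primal, grad = jvp(net, (x, y), (v1, v2))   # multiple inputs
# After (Cell API): inputs are passed positionally, with v last:
primal, grad = Jvp(net)(x, v)
primal, grad = Jvp(net)(x, y, (v1, v2))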

View File

@@ -19,7 +19,7 @@ import pytest
 import mindspore.nn as nn
 import mindspore.context as context
 from mindspore import Tensor
-from mindspore.ops.functional import jvp
+from mindspore.nn.grad import Jvp
 context.set_context(mode=context.GRAPH_MODE)
@@ -53,7 +53,7 @@ def test_jvp_single_input_single_output_default_v_graph():
     net = SingleInputSingleOutputNet()
     expect_primal = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
     expect_grad = Tensor(np.array([[3, 12], [27, 48]]).astype(np.float32))
-    primal, grad = jvp(net, x, v)
+    primal, grad = Jvp(net)(x, v)
     assert np.allclose(primal.asnumpy(), expect_primal.asnumpy())
     assert np.allclose(grad.asnumpy(), expect_grad.asnumpy())
@@ -67,7 +67,7 @@ def test_jvp_single_input_single_output_custom_v_graph():
     net = SingleInputSingleOutputNet()
     expect_primal = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
     expect_grad = Tensor(np.array([[3, 24], [81, 192]]).astype(np.float32))
-    primal, grad = jvp(net, x, v)
+    primal, grad = Jvp(net)(x, v)
     assert np.allclose(primal.asnumpy(), expect_primal.asnumpy())
     assert np.allclose(grad.asnumpy(), expect_grad.asnumpy())
@@ -83,7 +83,7 @@ def test_jvp_single_input_multiple_outputs_default_v_graph():
     expect_primal_1 = Tensor(np.array([[2, 4], [6, 8]]).astype(np.float32))
     expect_grad_0 = Tensor(np.array([[3, 12], [27, 48]]).astype(np.float32))
     expect_grad_1 = Tensor(np.array([[2, 2], [2, 2]]).astype(np.float32))
-    primal, grad = jvp(net, x, v)
+    primal, grad = Jvp(net)(x, v)
     assert isinstance(primal, tuple)
     assert len(primal) == 2
     assert np.allclose(primal[0].asnumpy(), expect_primal_0.asnumpy())
@@ -105,7 +105,7 @@ def test_jvp_single_input_multiple_outputs_custom_v_graph():
     expect_primal_1 = Tensor(np.array([[2, 4], [6, 8]]).astype(np.float32))
     expect_grad_0 = Tensor(np.array([[3, 24], [81, 192]]).astype(np.float32))
     expect_grad_1 = Tensor(np.array([[2, 4], [6, 8]]).astype(np.float32))
-    primal, grad = jvp(net, x, v)
+    primal, grad = Jvp(net)(x, v)
     assert isinstance(primal, tuple)
     assert len(primal) == 2
     assert np.allclose(primal[0].asnumpy(), expect_primal_0.asnumpy())
@@ -126,7 +126,7 @@ def test_jvp_multiple_inputs_single_output_default_v_graph():
     net = MultipleInputSingleOutputNet()
     expect_primal = Tensor(np.array([[5, 10], [15, 20]]).astype(np.float32))
     expect_grad = Tensor(np.array([[5, 5], [5, 5]]).astype(np.float32))
-    primal, grad = jvp(net, (x, y), (v, v))
+    primal, grad = Jvp(net)(x, y, (v, v))
     assert np.allclose(primal.asnumpy(), expect_primal.asnumpy())
     assert np.allclose(grad.asnumpy(), expect_grad.asnumpy())
@@ -142,7 +142,7 @@ def test_jvp_multiple_inputs_single_output_custom_v_graph():
     net = MultipleInputSingleOutputNet()
     expect_primal = Tensor(np.array([[5, 10], [15, 20]]).astype(np.float32))
     expect_grad = Tensor(np.array([[5, 8], [11, 14]]).astype(np.float32))
-    primal, grad = jvp(net, (x, y), (v1, v2))
+    primal, grad = Jvp(net)(x, y, (v1, v2))
     assert np.allclose(primal.asnumpy(), expect_primal.asnumpy())
     assert np.allclose(grad.asnumpy(), expect_grad.asnumpy())
@@ -159,7 +159,7 @@ def test_jvp_multiple_inputs_multiple_outputs_default_v_graph():
     expect_primal_1 = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
     expect_grad_0 = Tensor(np.array([[2, 2], [2, 2]]).astype(np.float32))
     expect_grad_1 = Tensor(np.array([[3, 12], [27, 48]]).astype(np.float32))
-    primal, grad = jvp(net, (x, y), (v, v))
+    primal, grad = Jvp(net)(x, y, (v, v))
     assert isinstance(primal, tuple)
     assert len(primal) == 2
     assert np.allclose(primal[0].asnumpy(), expect_primal_0.asnumpy())
@@ -183,7 +183,7 @@ def test_jvp_multiple_inputs_multiple_outputs_custom_v_graph():
     expect_primal_1 = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
     expect_grad_0 = Tensor(np.array([[2, 2], [2, 2]]).astype(np.float32))
     expect_grad_1 = Tensor(np.array([[3, 24], [81, 192]]).astype(np.float32))
-    primal, grad = jvp(net, (x, y), (v1, v2))
+    primal, grad = Jvp(net)(x, y, (v1, v2))
     assert isinstance(primal, tuple)
     assert len(primal) == 2
     assert np.allclose(primal[0].asnumpy(), expect_primal_0.asnumpy())
@@ -192,51 +192,3 @@ def test_jvp_multiple_inputs_multiple_outputs_custom_v_graph():
     assert len(grad) == 2
     assert np.allclose(grad[0].asnumpy(), expect_grad_0.asnumpy())
     assert np.allclose(grad[1].asnumpy(), expect_grad_1.asnumpy())
-
-
-@pytest.mark.level0
-@pytest.mark.platform_x86_cpu
-@pytest.mark.env_onecard
-def test_jvp_single_input_single_tensor_default_v():
-    x = Tensor(np.array([1]).astype(np.float32))
-    net = SingleInputSingleOutputNet()
-    expect_primal = Tensor(np.array([1]).astype(np.float32))
-    expect_grad = Tensor(np.array([3]).astype(np.float32))
-    primal, grad = jvp(net, x)
-    assert np.allclose(primal.asnumpy(), expect_primal.asnumpy())
-    assert np.allclose(grad.asnumpy(), expect_grad.asnumpy())
-
-
-@pytest.mark.level0
-@pytest.mark.platform_x86_cpu
-@pytest.mark.env_onecard
-def test_jvp_single_input_multiple_tensors_default_v():
-    x = Tensor(np.array([1, 2]).astype(np.float32))
-    net = SingleInputSingleOutputNet()
-    with pytest.raises(ValueError) as ex:
-        jvp(net, x)
-    assert "The vector v can only be None if the input is " in str(ex.value)
-
-
-@pytest.mark.level0
-@pytest.mark.platform_x86_cpu
-@pytest.mark.env_onecard
-def test_jvp_single_input_multiple_tensors_wrong_shape_v():
-    x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
-    v = Tensor(np.array([[1], [4]]).astype(np.float32))
-    net = SingleInputSingleOutputNet()
-    with pytest.raises(ValueError) as ex:
-        jvp(net, x, v)
-    assert "The tensor shape at the index 0 of v" in str(ex.value)
-
-
-@pytest.mark.level0
-@pytest.mark.platform_x86_cpu
-@pytest.mark.env_onecard
-def test_jvp_single_input_multiple_tensors_wrong_tuple_v():
-    x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
-    v = (Tensor(np.array([[1], [4]]).astype(np.float32)), Tensor(np.array([[1], [4]]).astype(np.float32)))
-    net = SingleInputSingleOutputNet()
-    with pytest.raises(ValueError) as ex:
-        jvp(net, x, v)
-    assert "v is not the same size as the function inputs." in str(ex.value)

View File

@@ -19,7 +19,7 @@ import pytest
 import mindspore.nn as nn
 import mindspore.context as context
 from mindspore import Tensor
-from mindspore.ops.functional import jvp
+from mindspore.nn.grad import Jvp
 context.set_context(mode=context.PYNATIVE_MODE)
@@ -52,7 +52,7 @@ def test_jvp_single_input_single_output_default_v_pynative():
     net = SingleInputSingleOutputNet()
     expect_primal = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
     expect_grad = Tensor(np.array([[3, 12], [27, 48]]).astype(np.float32))
-    primal, grad = jvp(net, x, v)
+    primal, grad = Jvp(net)(x, v)
     assert np.allclose(primal.asnumpy(), expect_primal.asnumpy())
     assert np.allclose(grad.asnumpy(), expect_grad.asnumpy())
@@ -66,7 +66,7 @@ def test_jvp_single_input_single_output_custom_v_pynative():
     net = SingleInputSingleOutputNet()
     expect_primal = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
     expect_grad = Tensor(np.array([[3, 24], [81, 192]]).astype(np.float32))
-    primal, grad = jvp(net, x, v)
+    primal, grad = Jvp(net)(x, v)
     assert np.allclose(primal.asnumpy(), expect_primal.asnumpy())
     assert np.allclose(grad.asnumpy(), expect_grad.asnumpy())
@@ -82,7 +82,7 @@ def test_jvp_single_input_multiple_outputs_default_v_pynative():
     expect_primal_1 = Tensor(np.array([[2, 4], [6, 8]]).astype(np.float32))
     expect_grad_0 = Tensor(np.array([[3, 12], [27, 48]]).astype(np.float32))
     expect_grad_1 = Tensor(np.array([[2, 2], [2, 2]]).astype(np.float32))
-    primal, grad = jvp(net, x, v)
+    primal, grad = Jvp(net)(x, v)
     assert isinstance(primal, tuple)
     assert len(primal) == 2
     assert np.allclose(primal[0].asnumpy(), expect_primal_0.asnumpy())
@@ -104,7 +104,7 @@ def test_jvp_single_input_multiple_outputs_custom_v_pynative():
     expect_primal_1 = Tensor(np.array([[2, 4], [6, 8]]).astype(np.float32))
     expect_grad_0 = Tensor(np.array([[3, 24], [81, 192]]).astype(np.float32))
     expect_grad_1 = Tensor(np.array([[2, 4], [6, 8]]).astype(np.float32))
-    primal, grad = jvp(net, x, v)
+    primal, grad = Jvp(net)(x, v)
     assert isinstance(primal, tuple)
     assert len(primal) == 2
     assert np.allclose(primal[0].asnumpy(), expect_primal_0.asnumpy())
@@ -127,7 +127,7 @@ def test_jvp_multiple_inputs_multiple_outputs_default_v_pynative():
     expect_primal_1 = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
     expect_grad_0 = Tensor(np.array([[2, 2], [2, 2]]).astype(np.float32))
     expect_grad_1 = Tensor(np.array([[3, 12], [27, 48]]).astype(np.float32))
-    primal, grad = jvp(net, (x, y), (v, v))
+    primal, grad = Jvp(net)(x, y, (v, v))
     assert isinstance(primal, tuple)
     assert len(primal) == 2
     assert np.allclose(primal[0].asnumpy(), expect_primal_0.asnumpy())
@@ -151,7 +151,7 @@ def test_jvp_multiple_inputs_multiple_outputs_custom_v_pynative():
     expect_primal_1 = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
     expect_grad_0 = Tensor(np.array([[2, 2], [2, 2]]).astype(np.float32))
     expect_grad_1 = Tensor(np.array([[3, 24], [81, 192]]).astype(np.float32))
-    primal, grad = jvp(net, (x, y), (v1, v2))
+    primal, grad = Jvp(net)(x, y, (v1, v2))
     assert isinstance(primal, tuple)
     assert len(primal) == 2
     assert np.allclose(primal[0].asnumpy(), expect_primal_0.asnumpy())
@@ -172,7 +172,7 @@ def test_jvp_multiple_inputs_single_output_default_v_pynative():
     net = MultipleInputSingleOutputNet()
     expect_primal = Tensor(np.array([[5, 10], [15, 20]]).astype(np.float32))
     expect_grad = Tensor(np.array([[5, 5], [5, 5]]).astype(np.float32))
-    primal, grad = jvp(net, (x, y), (v, v))
+    primal, grad = Jvp(net)(x, y, (v, v))
     assert np.allclose(primal.asnumpy(), expect_primal.asnumpy())
     assert np.allclose(grad.asnumpy(), expect_grad.asnumpy())
@@ -188,6 +188,6 @@ def test_jvp_multiple_inputs_single_output_custom_v_pynative():
     net = MultipleInputSingleOutputNet()
     expect_primal = Tensor(np.array([[5, 10], [15, 20]]).astype(np.float32))
     expect_grad = Tensor(np.array([[5, 8], [11, 14]]).astype(np.float32))
-    primal, grad = jvp(net, (x, y), (v1, v2))
+    primal, grad = Jvp(net)(x, y, (v1, v2))
     assert np.allclose(primal.asnumpy(), expect_primal.asnumpy())
     assert np.allclose(grad.asnumpy(), expect_grad.asnumpy())

View File

@@ -0,0 +1,71 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""test vjp in graph mode"""
+import numpy as np
+import pytest
+import mindspore.nn as nn
+import mindspore.context as context
+from mindspore import Tensor
+from mindspore.nn.grad import Vjp
+
+context.set_context(mode=context.GRAPH_MODE)
+
+
+class SingleInputNet(nn.Cell):
+    def construct(self, x):
+        return x**3
+
+
+class MultipleInputsOutputNet(nn.Cell):
+    def construct(self, x, y):
+        return 2*x, y**3
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_vjp_single_input_graph():
+    x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
+    v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
+    net = SingleInputNet()
+    expect_primal = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
+    expect_grad = Tensor(np.array([[3, 12], [27, 48]]).astype(np.float32))
+    primal, grad = Vjp(net)(x, v)
+    assert np.allclose(primal.asnumpy(), expect_primal.asnumpy())
+    assert np.allclose(grad.asnumpy(), expect_grad.asnumpy())
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_vjp_multiple_inputs_default_v_graph():
+    x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
+    y = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
+    v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
+    net = MultipleInputsOutputNet()
+    expect_primal_0 = Tensor(np.array([[2, 4], [6, 8]]).astype(np.float32))
+    expect_primal_1 = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
+    expect_grad_0 = Tensor(np.array([[2, 2], [2, 2]]).astype(np.float32))
+    expect_grad_1 = Tensor(np.array([[3, 12], [27, 48]]).astype(np.float32))
+    primal, grad = Vjp(net)(x, y, (v, v))
+    assert isinstance(primal, tuple)
+    assert len(primal) == 2
+    assert np.allclose(primal[0].asnumpy(), expect_primal_0.asnumpy())
+    assert np.allclose(primal[1].asnumpy(), expect_primal_1.asnumpy())
+    assert isinstance(grad, tuple)
+    assert len(grad) == 2
+    assert np.allclose(grad[0].asnumpy(), expect_grad_0.asnumpy())
+    assert np.allclose(grad[1].asnumpy(), expect_grad_1.asnumpy())
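As a sanity check on the expected values in these tests, under the usual definition of the vector-Jacobian product for an elementwise map:

\[
\operatorname{vjp}(f, x, v) = v^{\top} J_f(x), \qquad
f(x) = x^{3} \;\Rightarrow\; \operatorname{vjp}(f, x, v) = v \odot 3x^{2},
\]

so for x = [[1, 2], [3, 4]] and v of all ones the product is [[3, 12], [27, 48]], which matches expect_grad in test_vjp_single_input_graph.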