forked from mindspore-Ecosystem/mindspore
!25918 add Function jvp and vjp
Merge pull request !25918 from chenzhuo/jvp
This commit is contained in:
commit
e773f0621e
|
@ -31,6 +31,15 @@ class _FirstGrad(Cell):
|
|||
return self.first_grad_op(self.fn)(*first_grad_input, u)
|
||||
|
||||
|
||||
class _JvpFirstGrad(Cell):
|
||||
def __init__(self):
|
||||
super(_JvpFirstGrad, self).__init__()
|
||||
self.first_grad_op = C.GradOperation(sens_param=True, get_all=True)
|
||||
|
||||
def construct(self, u, fn, first_grad_input):
|
||||
return self.first_grad_op(fn)(*first_grad_input, u)
|
||||
|
||||
|
||||
class _FirstGradSingleValue(Cell):
|
||||
def __init__(self, fn):
|
||||
super(_FirstGradSingleValue, self).__init__()
|
||||
|
@ -41,6 +50,16 @@ class _FirstGradSingleValue(Cell):
|
|||
return self.first_grad_single_value_op(self.fn)(*first_grad_single_value_input, u)
|
||||
|
||||
|
||||
class _JvpFirstGradSingleValue(Cell):
|
||||
def __init__(self):
|
||||
super(_JvpFirstGradSingleValue, self).__init__()
|
||||
self.first_grad_single_value_op = C.GradOperation(sens_param=True)
|
||||
|
||||
def construct(self, u, fn, first_grad_single_value_input):
|
||||
return self.first_grad_single_value_op(fn)(*first_grad_single_value_input, u)
|
||||
|
||||
|
||||
|
||||
class Jvp(Cell):
|
||||
"""
|
||||
Compute the jacobian-vector-product of the given network. Jvp is equivalent to forward mode autodiff.
|
||||
|
@ -112,6 +131,46 @@ class Jvp(Cell):
|
|||
return output, gradient_output
|
||||
|
||||
|
||||
class _JvpInner(Cell):
|
||||
"""
|
||||
Compute the jacobian-vector-product of the given network. Jvp is equivalent to forward mode autodiff.
|
||||
This class implements the inner process of function jvp.
|
||||
"""
|
||||
def __init__(self):
|
||||
super(_JvpInner, self).__init__()
|
||||
self.oneslike = P.OnesLike()
|
||||
self.first_grad = _JvpFirstGrad()
|
||||
self.first_grad.add_flags(enable_tuple_grad=True)
|
||||
self.first_grad_single_value = _JvpFirstGradSingleValue()
|
||||
self.first_grad_single_value.add_flags(enable_tuple_grad=True)
|
||||
self.second_grad_op = C.GradOperation(sens_param=True)
|
||||
self.issubclass_ = P.IsSubClass()
|
||||
self.typeof = Primitive('typeof')
|
||||
self.make_tuple = Primitive('MakeTuple')
|
||||
self.tuple_len = Primitive("tuple_len")
|
||||
|
||||
def construct(self, *args):
|
||||
fn = args[0]
|
||||
v = args[1]
|
||||
jvp_input = args[2:]
|
||||
output = fn(*jvp_input)
|
||||
|
||||
if self.issubclass_(self.typeof(output), mstype.tuple_):
|
||||
u = self.make_tuple()
|
||||
for i in range(self.tuple_len(output)):
|
||||
u = u + self.make_tuple(self.oneslike(output[i]))
|
||||
else:
|
||||
u = self.oneslike(output)
|
||||
|
||||
if self.tuple_len(jvp_input) == 1:
|
||||
second_gradient_net = self.second_grad_op(self.first_grad_single_value)
|
||||
gradient_output = second_gradient_net(u, fn, jvp_input, v)
|
||||
else:
|
||||
second_gradient_net = self.second_grad_op(self.first_grad)
|
||||
gradient_output = second_gradient_net(u, fn, jvp_input, v)
|
||||
return output, gradient_output
|
||||
|
||||
|
||||
class Vjp(Cell):
|
||||
"""
|
||||
Computes the dot product between a vector `v` and the Jacobian of the given network at the point
|
||||
|
@ -167,3 +226,27 @@ class Vjp(Cell):
|
|||
else:
|
||||
gradient_output = self.grad(self.fn)(*args)
|
||||
return output, gradient_output
|
||||
|
||||
|
||||
class _VjpInner(Cell):
|
||||
"""
|
||||
Computes the dot product between a vector `v` and the Jacobian of the given network at the point
|
||||
given by the inputs. This class implements the inner process of function vjp.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super(_VjpInner, self).__init__()
|
||||
self.grad = C.GradOperation(get_all=True, sens_param=True)
|
||||
self.grad_single_value = C.GradOperation(sens_param=True)
|
||||
self.tuple_len = Primitive("tuple_len")
|
||||
|
||||
def construct(self, *args):
|
||||
fn = args[0]
|
||||
front_input = args[1:-1]
|
||||
input_with_v = args[1:]
|
||||
output = fn(*front_input)
|
||||
if self.tuple_len(front_input) == 1:
|
||||
gradient_output = self.grad_single_value(fn)(*input_with_v)
|
||||
else:
|
||||
gradient_output = self.grad(fn)(*input_with_v)
|
||||
return output, gradient_output
|
||||
|
|
|
@ -18,7 +18,12 @@
|
|||
"""The names of functional part are summarized here."""
|
||||
|
||||
from mindspore.common._register_for_tensor import tensor_operator_registry
|
||||
from mindspore.common import ms_function
|
||||
from mindspore.common import Tensor
|
||||
from mindspore.nn.grad.cell_grad import _JvpInner
|
||||
from mindspore.nn.grad.cell_grad import _VjpInner
|
||||
from mindspore.ops import _constants
|
||||
from mindspore.ops.primitive import constexpr
|
||||
from .primitive import Primitive
|
||||
from . import operations as P
|
||||
from .operations import _grad_ops
|
||||
|
@ -171,6 +176,114 @@ def grad(fn, grad_position=0):
|
|||
return grad_by_position(fn, None, grad_position)
|
||||
|
||||
|
||||
def jvp(fn, inputs, v):
|
||||
"""
|
||||
Compute the jacobian-vector-product of the given network.
|
||||
|
||||
Args:
|
||||
fn (Function or Cell): The function or net that takes Tensor inputs and returns a tensor or tuple of Tensors.
|
||||
inputs (Tensor or tuple or list): The inputs to `fn`.
|
||||
v (Tensor or tuple or list): The shape and type of v should be the same as inputs.
|
||||
|
||||
Returns:
|
||||
Tuple, tuple of output and jvp.
|
||||
- netout(Tensors or Tuple of Tensors), the output of "fn(inputs)".
|
||||
- jvp(Tensors or Tuple of Tensors), the result of the dot product.
|
||||
|
||||
Raises:
|
||||
TypeError: If the input is not a tensor or tuple or list of tensors.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> from mindspore.ops import functional as F
|
||||
>>> from mindspore import Tensor
|
||||
>>> class Net(nn.Cell):
|
||||
... def construct(self, x, y):
|
||||
... return x**3 + y
|
||||
>>> x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
|
||||
>>> y = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
|
||||
>>> v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
|
||||
>>> output = F.jvp(Net(), (x, y), (v, v))
|
||||
>>> print(output[0])
|
||||
[[ 2. 10.]
|
||||
[30. 68.]]
|
||||
>>> print(output[1])
|
||||
[[ 4. 13.]
|
||||
[28. 49.]]
|
||||
"""
|
||||
jvp_inner = _JvpInner()
|
||||
@ms_function
|
||||
def _wrap_container(*arg):
|
||||
args = arg[1:]
|
||||
vectors = arg[0]
|
||||
return jvp_inner(fn, vectors, *args)
|
||||
if not isinstance(inputs, (Tensor, tuple, list)):
|
||||
_raise_type_error()
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
return _wrap_container(v, *inputs)
|
||||
return _wrap_container(v, inputs)
|
||||
|
||||
|
||||
def vjp(fn, inputs, v):
|
||||
"""
|
||||
Compute the vector-jacobian-product of the given network.
|
||||
|
||||
Args:
|
||||
fn (Function or Cell): The function or net that takes Tensor inputs and returns a tensor or tuple of Tensors.
|
||||
inputs (Tensor or tuple or list): The inputs to `fn`.
|
||||
v (Tensor or tuple or list): The shape and type of v should be the same as outputs.
|
||||
|
||||
Returns:
|
||||
Tuple, tuple of output and jvp.
|
||||
- netout(Tensors or Tuple of Tensors), the output of "fn(inputs)".
|
||||
- vjp(Tensors or Tuple of Tensors), the result of the dot product.
|
||||
|
||||
Raises:
|
||||
TypeError: If the input is not a tensor or tuple or list of tensors.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> from mindspore.ops import functional as F
|
||||
>>> from mindspore import Tensor
|
||||
>>> class Net(nn.Cell):
|
||||
... def construct(self, x, y):
|
||||
... return x**3 + y
|
||||
>>> x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
|
||||
>>> y = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
|
||||
>>> v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
|
||||
>>> output = F.vjp(Net(), (x, y), v)
|
||||
>>> print(output[0])
|
||||
[[ 2. 10.]
|
||||
[30. 68.]]
|
||||
>>> print(output[1])
|
||||
(Tensor(shape=[2, 2], dtype=Float32, value=
|
||||
[[ 3.00000000e+00, 1.20000000e+01],
|
||||
[ 2.70000000e+01, 4.80000000e+01]]), Tensor(shape=[2, 2], dtype=Float32, value=
|
||||
[[ 1.00000000e+00, 1.00000000e+00],
|
||||
[ 1.00000000e+00, 1.00000000e+00]]))
|
||||
"""
|
||||
vjp_inner = _VjpInner()
|
||||
@ms_function
|
||||
def wrap_container(*arg):
|
||||
args = arg[:-1]
|
||||
vectors = arg[-1]
|
||||
return vjp_inner(fn, *args, vectors)
|
||||
if not isinstance(inputs, (Tensor, tuple, list)):
|
||||
_raise_type_error()
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
return wrap_container(*inputs, v)
|
||||
return wrap_container(inputs, v)
|
||||
|
||||
|
||||
@constexpr
|
||||
def _raise_type_error():
|
||||
raise TypeError("The inputs type should be a Tensor, tuple or list of Tensor.")
|
||||
|
||||
|
||||
tuple_setitem = Primitive('tuple_setitem')
|
||||
tuple_getitem = Primitive(_constants.kTupleGetItem)
|
||||
list_getitem = Primitive('list_getitem')
|
||||
|
|
|
@ -0,0 +1,311 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""test function jvp in graph mode"""
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import mindspore.nn as nn
|
||||
import mindspore.context as context
|
||||
from mindspore import Tensor
|
||||
from mindspore import ms_function
|
||||
from mindspore.ops.functional import jvp
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
|
||||
|
||||
class SingleInputSingleOutputNet(nn.Cell):
|
||||
def construct(self, x):
|
||||
return x**3
|
||||
|
||||
|
||||
class SingleInputMultipleOutputNet(nn.Cell):
|
||||
def construct(self, x):
|
||||
return x**3, 2*x
|
||||
|
||||
|
||||
class MultipleInputSingleOutputNet(nn.Cell):
|
||||
def construct(self, x, y):
|
||||
return 2*x + 3*y
|
||||
|
||||
|
||||
class MultipleInputMultipleOutputNet(nn.Cell):
|
||||
def construct(self, x, y):
|
||||
return 2*x, y**3
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_jvp_single_input_single_output_default_v_graph():
|
||||
"""
|
||||
Features: Function jvp
|
||||
Description: Test jvp with single input, single output and default v in graph mode.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
|
||||
v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
|
||||
net = SingleInputSingleOutputNet()
|
||||
expect_primal = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
|
||||
expect_grad = Tensor(np.array([[3, 12], [27, 48]]).astype(np.float32))
|
||||
primal, grad = jvp(net, x, v)
|
||||
assert np.allclose(primal.asnumpy(), expect_primal.asnumpy())
|
||||
assert np.allclose(grad.asnumpy(), expect_grad.asnumpy())
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_jvp_single_input_single_output_custom_v_graph():
|
||||
"""
|
||||
Features: Function jvp
|
||||
Description: Test jvp with single input, single output and custom v in graph mode.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
|
||||
v = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
|
||||
net = SingleInputSingleOutputNet()
|
||||
expect_primal = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
|
||||
expect_grad = Tensor(np.array([[3, 24], [81, 192]]).astype(np.float32))
|
||||
primal, grad = jvp(net, x, v)
|
||||
assert np.allclose(primal.asnumpy(), expect_primal.asnumpy())
|
||||
assert np.allclose(grad.asnumpy(), expect_grad.asnumpy())
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_jvp_single_input_multiple_outputs_default_v_graph():
|
||||
"""
|
||||
Features: Function jvp
|
||||
Description: Test jvp with single input, multiple outputs and default v in graph mode.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
|
||||
v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
|
||||
net = SingleInputMultipleOutputNet()
|
||||
expect_primal_0 = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
|
||||
expect_primal_1 = Tensor(np.array([[2, 4], [6, 8]]).astype(np.float32))
|
||||
expect_grad_0 = Tensor(np.array([[3, 12], [27, 48]]).astype(np.float32))
|
||||
expect_grad_1 = Tensor(np.array([[2, 2], [2, 2]]).astype(np.float32))
|
||||
primal, grad = jvp(net, x, v)
|
||||
assert isinstance(primal, tuple)
|
||||
assert len(primal) == 2
|
||||
assert np.allclose(primal[0].asnumpy(), expect_primal_0.asnumpy())
|
||||
assert np.allclose(primal[1].asnumpy(), expect_primal_1.asnumpy())
|
||||
assert isinstance(grad, tuple)
|
||||
assert len(grad) == 2
|
||||
assert np.allclose(grad[0].asnumpy(), expect_grad_0.asnumpy())
|
||||
assert np.allclose(grad[1].asnumpy(), expect_grad_1.asnumpy())
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_jvp_single_input_multiple_outputs_custom_v_graph():
|
||||
"""
|
||||
Features: Function jvp
|
||||
Description: Test jvp with single input, multiple outputs and custom v in graph mode.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
|
||||
v = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
|
||||
net = SingleInputMultipleOutputNet()
|
||||
expect_primal_0 = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
|
||||
expect_primal_1 = Tensor(np.array([[2, 4], [6, 8]]).astype(np.float32))
|
||||
expect_grad_0 = Tensor(np.array([[3, 24], [81, 192]]).astype(np.float32))
|
||||
expect_grad_1 = Tensor(np.array([[2, 4], [6, 8]]).astype(np.float32))
|
||||
primal, grad = jvp(net, x, v)
|
||||
assert isinstance(primal, tuple)
|
||||
assert len(primal) == 2
|
||||
assert np.allclose(primal[0].asnumpy(), expect_primal_0.asnumpy())
|
||||
assert np.allclose(primal[1].asnumpy(), expect_primal_1.asnumpy())
|
||||
assert isinstance(grad, tuple)
|
||||
assert len(grad) == 2
|
||||
assert np.allclose(grad[0].asnumpy(), expect_grad_0.asnumpy())
|
||||
assert np.allclose(grad[1].asnumpy(), expect_grad_1.asnumpy())
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_jvp_multiple_inputs_single_output_default_v_graph():
|
||||
"""
|
||||
Features: Function jvp
|
||||
Description: Test jvp with multiple inputs, single output and default v in graph mode.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
|
||||
y = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
|
||||
v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
|
||||
net = MultipleInputSingleOutputNet()
|
||||
expect_primal = Tensor(np.array([[5, 10], [15, 20]]).astype(np.float32))
|
||||
expect_grad = Tensor(np.array([[5, 5], [5, 5]]).astype(np.float32))
|
||||
primal, grad = jvp(net, (x, y), (v, v))
|
||||
assert np.allclose(primal.asnumpy(), expect_primal.asnumpy())
|
||||
assert np.allclose(grad.asnumpy(), expect_grad.asnumpy())
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_jvp_multiple_inputs_single_output_custom_v_graph():
|
||||
"""
|
||||
Features: Function jvp
|
||||
Description: Test jvp with multiple inputs, single output and custom v in graph mode.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
|
||||
y = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
|
||||
v1 = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
|
||||
v2 = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
|
||||
net = MultipleInputSingleOutputNet()
|
||||
expect_primal = Tensor(np.array([[5, 10], [15, 20]]).astype(np.float32))
|
||||
expect_grad = Tensor(np.array([[5, 8], [11, 14]]).astype(np.float32))
|
||||
primal, grad = jvp(net, (x, y), (v1, v2))
|
||||
assert np.allclose(primal.asnumpy(), expect_primal.asnumpy())
|
||||
assert np.allclose(grad.asnumpy(), expect_grad.asnumpy())
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_jvp_multiple_inputs_multiple_outputs_default_v_graph():
|
||||
"""
|
||||
Features: Function jvp
|
||||
Description: Test jvp with multiple inputs, multiple outputs and default v in graph mode.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
|
||||
y = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
|
||||
v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
|
||||
net = MultipleInputMultipleOutputNet()
|
||||
expect_primal_0 = Tensor(np.array([[2, 4], [6, 8]]).astype(np.float32))
|
||||
expect_primal_1 = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
|
||||
expect_grad_0 = Tensor(np.array([[2, 2], [2, 2]]).astype(np.float32))
|
||||
expect_grad_1 = Tensor(np.array([[3, 12], [27, 48]]).astype(np.float32))
|
||||
primal, grad = jvp(net, (x, y), (v, v))
|
||||
assert isinstance(primal, tuple)
|
||||
assert len(primal) == 2
|
||||
assert np.allclose(primal[0].asnumpy(), expect_primal_0.asnumpy())
|
||||
assert np.allclose(primal[1].asnumpy(), expect_primal_1.asnumpy())
|
||||
assert isinstance(grad, tuple)
|
||||
assert len(grad) == 2
|
||||
assert np.allclose(grad[0].asnumpy(), expect_grad_0.asnumpy())
|
||||
assert np.allclose(grad[1].asnumpy(), expect_grad_1.asnumpy())
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_jvp_multiple_inputs_multiple_outputs_custom_v_graph():
|
||||
"""
|
||||
Features: Function jvp
|
||||
Description: Test jvp with multiple inputs, multiple outputs and custom v in graph mode.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
|
||||
y = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
|
||||
v1 = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
|
||||
v2 = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
|
||||
net = MultipleInputMultipleOutputNet()
|
||||
expect_primal_0 = Tensor(np.array([[2, 4], [6, 8]]).astype(np.float32))
|
||||
expect_primal_1 = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
|
||||
expect_grad_0 = Tensor(np.array([[2, 2], [2, 2]]).astype(np.float32))
|
||||
expect_grad_1 = Tensor(np.array([[3, 24], [81, 192]]).astype(np.float32))
|
||||
primal, grad = jvp(net, (x, y), (v1, v2))
|
||||
assert isinstance(primal, tuple)
|
||||
assert len(primal) == 2
|
||||
assert np.allclose(primal[0].asnumpy(), expect_primal_0.asnumpy())
|
||||
assert np.allclose(primal[1].asnumpy(), expect_primal_1.asnumpy())
|
||||
assert isinstance(grad, tuple)
|
||||
assert len(grad) == 2
|
||||
assert np.allclose(grad[0].asnumpy(), expect_grad_0.asnumpy())
|
||||
assert np.allclose(grad[1].asnumpy(), expect_grad_1.asnumpy())
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_jvp_ms_function_single_input_single_output_default_v_graph():
|
||||
"""
|
||||
Features: Function jvp
|
||||
Description: Test jvp with ms_function, single input, single output and default v in graph mode.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
|
||||
v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
|
||||
net = SingleInputSingleOutputNet()
|
||||
|
||||
@ms_function
|
||||
def jvp_with_ms_function(inputs, vectors):
|
||||
output, jvp_grad = jvp(net, inputs, vectors)
|
||||
return output, jvp_grad
|
||||
|
||||
primal, grad = jvp_with_ms_function(x, v)
|
||||
expect_primal = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
|
||||
expect_grad = Tensor(np.array([[3, 12], [27, 48]]).astype(np.float32))
|
||||
assert np.allclose(primal.asnumpy(), expect_primal.asnumpy())
|
||||
assert np.allclose(grad.asnumpy(), expect_grad.asnumpy())
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_jvp_input_function_single_input_single_output_default_v_graph():
|
||||
"""
|
||||
Features: Function jvp
|
||||
Description: Test jvp with function, single input, single output and default v in graph mode.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
|
||||
v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
|
||||
|
||||
def test_function(inputs):
|
||||
return inputs**3
|
||||
|
||||
primal, grad = jvp(test_function, x, v)
|
||||
expect_primal = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
|
||||
expect_grad = Tensor(np.array([[3, 12], [27, 48]]).astype(np.float32))
|
||||
assert np.allclose(primal.asnumpy(), expect_primal.asnumpy())
|
||||
assert np.allclose(grad.asnumpy(), expect_grad.asnumpy())
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_jvp_construct_single_input_single_output_default_v_graph():
|
||||
"""
|
||||
Features: Function jvp
|
||||
Description: Test jvp with Cell construct, single input, single output and default v in graph mode.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
|
||||
v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
|
||||
|
||||
class Net(nn.Cell):
|
||||
def __init__(self, network):
|
||||
super(Net, self).__init__()
|
||||
self.net = network
|
||||
|
||||
def construct(self, inputs, vectors):
|
||||
net_out, jvp_out = jvp(self.net, inputs, vectors)
|
||||
return net_out, jvp_out
|
||||
|
||||
test_net = Net(SingleInputSingleOutputNet())
|
||||
primal, grad = test_net(x, v)
|
||||
expect_primal = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
|
||||
expect_grad = Tensor(np.array([[3, 12], [27, 48]]).astype(np.float32))
|
||||
assert np.allclose(primal.asnumpy(), expect_primal.asnumpy())
|
||||
assert np.allclose(grad.asnumpy(), expect_grad.asnumpy())
|
|
@ -0,0 +1,310 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""test function jvp in pynative mode """
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import mindspore.nn as nn
|
||||
import mindspore.context as context
|
||||
from mindspore import Tensor
|
||||
from mindspore import ms_function
|
||||
from mindspore.ops.functional import jvp
|
||||
|
||||
context.set_context(mode=context.PYNATIVE_MODE)
|
||||
|
||||
class SingleInputSingleOutputNet(nn.Cell):
|
||||
def construct(self, x):
|
||||
return x**3
|
||||
|
||||
|
||||
class SingleInputMultipleOutputNet(nn.Cell):
|
||||
def construct(self, x):
|
||||
return x**3, 2*x
|
||||
|
||||
|
||||
class MultipleInputSingleOutputNet(nn.Cell):
|
||||
def construct(self, x, y):
|
||||
return 2*x + 3*y
|
||||
|
||||
|
||||
class MultipleInputMultipleOutputNet(nn.Cell):
|
||||
def construct(self, x, y):
|
||||
return 2*x, y**3
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_jvp_single_input_single_output_default_v_pynative():
|
||||
"""
|
||||
Features: Function jvp
|
||||
Description: Test jvp with single input, single output and default v in pynative mode.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
|
||||
v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
|
||||
net = SingleInputSingleOutputNet()
|
||||
expect_primal = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
|
||||
expect_grad = Tensor(np.array([[3, 12], [27, 48]]).astype(np.float32))
|
||||
primal, grad = jvp(net, x, v)
|
||||
assert np.allclose(primal.asnumpy(), expect_primal.asnumpy())
|
||||
assert np.allclose(grad.asnumpy(), expect_grad.asnumpy())
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_jvp_single_input_single_output_custom_v_pynative():
|
||||
"""
|
||||
Features: Function jvp
|
||||
Description: Test jvp with single input, single output and custom v in pynative mode.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
|
||||
v = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
|
||||
net = SingleInputSingleOutputNet()
|
||||
expect_primal = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
|
||||
expect_grad = Tensor(np.array([[3, 24], [81, 192]]).astype(np.float32))
|
||||
primal, grad = jvp(net, x, v)
|
||||
assert np.allclose(primal.asnumpy(), expect_primal.asnumpy())
|
||||
assert np.allclose(grad.asnumpy(), expect_grad.asnumpy())
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_jvp_single_input_multiple_outputs_default_v_pynative():
|
||||
"""
|
||||
Features: Function jvp
|
||||
Description: Test jvp with single input, multiple outputs and default v in pynative mode.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
|
||||
v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
|
||||
net = SingleInputMultipleOutputNet()
|
||||
expect_primal_0 = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
|
||||
expect_primal_1 = Tensor(np.array([[2, 4], [6, 8]]).astype(np.float32))
|
||||
expect_grad_0 = Tensor(np.array([[3, 12], [27, 48]]).astype(np.float32))
|
||||
expect_grad_1 = Tensor(np.array([[2, 2], [2, 2]]).astype(np.float32))
|
||||
primal, grad = jvp(net, x, v)
|
||||
assert isinstance(primal, tuple)
|
||||
assert len(primal) == 2
|
||||
assert np.allclose(primal[0].asnumpy(), expect_primal_0.asnumpy())
|
||||
assert np.allclose(primal[1].asnumpy(), expect_primal_1.asnumpy())
|
||||
assert isinstance(grad, tuple)
|
||||
assert len(grad) == 2
|
||||
assert np.allclose(grad[0].asnumpy(), expect_grad_0.asnumpy())
|
||||
assert np.allclose(grad[1].asnumpy(), expect_grad_1.asnumpy())
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_jvp_single_input_multiple_outputs_custom_v_pynative():
|
||||
"""
|
||||
Features: Function jvp
|
||||
Description: Test jvp with single input, multiple outputs and custom v in pynative mode.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
|
||||
v = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
|
||||
net = SingleInputMultipleOutputNet()
|
||||
expect_primal_0 = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
|
||||
expect_primal_1 = Tensor(np.array([[2, 4], [6, 8]]).astype(np.float32))
|
||||
expect_grad_0 = Tensor(np.array([[3, 24], [81, 192]]).astype(np.float32))
|
||||
expect_grad_1 = Tensor(np.array([[2, 4], [6, 8]]).astype(np.float32))
|
||||
primal, grad = jvp(net, x, v)
|
||||
assert isinstance(primal, tuple)
|
||||
assert len(primal) == 2
|
||||
assert np.allclose(primal[0].asnumpy(), expect_primal_0.asnumpy())
|
||||
assert np.allclose(primal[1].asnumpy(), expect_primal_1.asnumpy())
|
||||
assert isinstance(grad, tuple)
|
||||
assert len(grad) == 2
|
||||
assert np.allclose(grad[0].asnumpy(), expect_grad_0.asnumpy())
|
||||
assert np.allclose(grad[1].asnumpy(), expect_grad_1.asnumpy())
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_jvp_multiple_inputs_multiple_outputs_default_v_pynative():
|
||||
"""
|
||||
Features: Function jvp
|
||||
Description: Test jvp with multiple inputs, multiple outputs and default v in pynative mode.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
|
||||
y = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
|
||||
v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
|
||||
net = MultipleInputMultipleOutputNet()
|
||||
expect_primal_0 = Tensor(np.array([[2, 4], [6, 8]]).astype(np.float32))
|
||||
expect_primal_1 = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
|
||||
expect_grad_0 = Tensor(np.array([[2, 2], [2, 2]]).astype(np.float32))
|
||||
expect_grad_1 = Tensor(np.array([[3, 12], [27, 48]]).astype(np.float32))
|
||||
primal, grad = jvp(net, (x, y), (v, v))
|
||||
assert isinstance(primal, tuple)
|
||||
assert len(primal) == 2
|
||||
assert np.allclose(primal[0].asnumpy(), expect_primal_0.asnumpy())
|
||||
assert np.allclose(primal[1].asnumpy(), expect_primal_1.asnumpy())
|
||||
assert isinstance(grad, tuple)
|
||||
assert len(grad) == 2
|
||||
assert np.allclose(grad[0].asnumpy(), expect_grad_0.asnumpy())
|
||||
assert np.allclose(grad[1].asnumpy(), expect_grad_1.asnumpy())
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_jvp_multiple_inputs_multiple_outputs_custom_v_pynative():
|
||||
"""
|
||||
Features: Function jvp
|
||||
Description: Test jvp with multiple inputs, multiple outputs and custom v in pynative mode.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
|
||||
y = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
|
||||
v1 = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
|
||||
v2 = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
|
||||
net = MultipleInputMultipleOutputNet()
|
||||
expect_primal_0 = Tensor(np.array([[2, 4], [6, 8]]).astype(np.float32))
|
||||
expect_primal_1 = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
|
||||
expect_grad_0 = Tensor(np.array([[2, 2], [2, 2]]).astype(np.float32))
|
||||
expect_grad_1 = Tensor(np.array([[3, 24], [81, 192]]).astype(np.float32))
|
||||
primal, grad = jvp(net, (x, y), (v1, v2))
|
||||
assert isinstance(primal, tuple)
|
||||
assert len(primal) == 2
|
||||
assert np.allclose(primal[0].asnumpy(), expect_primal_0.asnumpy())
|
||||
assert np.allclose(primal[1].asnumpy(), expect_primal_1.asnumpy())
|
||||
assert isinstance(grad, tuple)
|
||||
assert len(grad) == 2
|
||||
assert np.allclose(grad[0].asnumpy(), expect_grad_0.asnumpy())
|
||||
assert np.allclose(grad[1].asnumpy(), expect_grad_1.asnumpy())
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_jvp_multiple_inputs_single_output_default_v_pynative():
|
||||
"""
|
||||
Features: Function jvp
|
||||
Description: Test jvp with multiple inputs, single output and default v in pynative mode.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
|
||||
y = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
|
||||
v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
|
||||
net = MultipleInputSingleOutputNet()
|
||||
expect_primal = Tensor(np.array([[5, 10], [15, 20]]).astype(np.float32))
|
||||
expect_grad = Tensor(np.array([[5, 5], [5, 5]]).astype(np.float32))
|
||||
primal, grad = jvp(net, (x, y), (v, v))
|
||||
assert np.allclose(primal.asnumpy(), expect_primal.asnumpy())
|
||||
assert np.allclose(grad.asnumpy(), expect_grad.asnumpy())
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_jvp_multiple_inputs_single_output_custom_v_pynative():
|
||||
"""
|
||||
Features: Function jvp
|
||||
Description: Test jvp with multiple inputs, single output and custom v in pynative mode.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
|
||||
y = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
|
||||
v1 = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
|
||||
v2 = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
|
||||
net = MultipleInputSingleOutputNet()
|
||||
expect_primal = Tensor(np.array([[5, 10], [15, 20]]).astype(np.float32))
|
||||
expect_grad = Tensor(np.array([[5, 8], [11, 14]]).astype(np.float32))
|
||||
primal, grad = jvp(net, (x, y), (v1, v2))
|
||||
assert np.allclose(primal.asnumpy(), expect_primal.asnumpy())
|
||||
assert np.allclose(grad.asnumpy(), expect_grad.asnumpy())
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_jvp_ms_function_single_input_single_output_default_v_pynative():
|
||||
"""
|
||||
Features: Function jvp
|
||||
Description: Test jvp with ms_function, single input, single output and default v in pynative mode.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
|
||||
v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
|
||||
net = SingleInputSingleOutputNet()
|
||||
|
||||
@ms_function
|
||||
def jvp_with_ms_function(inputs, vectors):
|
||||
output, jvp_grad = jvp(net, inputs, vectors)
|
||||
return output, jvp_grad
|
||||
|
||||
primal, grad = jvp_with_ms_function(x, v)
|
||||
expect_primal = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
|
||||
expect_grad = Tensor(np.array([[3, 12], [27, 48]]).astype(np.float32))
|
||||
assert np.allclose(primal.asnumpy(), expect_primal.asnumpy())
|
||||
assert np.allclose(grad.asnumpy(), expect_grad.asnumpy())
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_jvp_input_function_single_input_single_output_default_v_pynative():
|
||||
"""
|
||||
Features: Function jvp
|
||||
Description: Test jvp with function, single input, single output and default v in pynative mode.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
|
||||
v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
|
||||
|
||||
def test_function(inputs):
|
||||
return inputs**3
|
||||
|
||||
primal, grad = jvp(test_function, x, v)
|
||||
expect_primal = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
|
||||
expect_grad = Tensor(np.array([[3, 12], [27, 48]]).astype(np.float32))
|
||||
assert np.allclose(primal.asnumpy(), expect_primal.asnumpy())
|
||||
assert np.allclose(grad.asnumpy(), expect_grad.asnumpy())
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_jvp_construct_single_input_single_output_default_v_pynative():
|
||||
"""
|
||||
Features: Function jvp
|
||||
Description: Test jvp with Cell construct, single input, single output and default v in pynative mode.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
|
||||
v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
|
||||
|
||||
class Net(nn.Cell):
|
||||
def __init__(self, network):
|
||||
super(Net, self).__init__()
|
||||
self.net = network
|
||||
|
||||
def construct(self, inputs, vectors):
|
||||
net_out, jvp_out = jvp(self.net, inputs, vectors)
|
||||
return net_out, jvp_out
|
||||
|
||||
test_net = Net(SingleInputSingleOutputNet())
|
||||
primal, grad = test_net(x, v)
|
||||
expect_primal = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
|
||||
expect_grad = Tensor(np.array([[3, 12], [27, 48]]).astype(np.float32))
|
||||
assert np.allclose(primal.asnumpy(), expect_primal.asnumpy())
|
||||
assert np.allclose(grad.asnumpy(), expect_grad.asnumpy())
|
|
@ -0,0 +1,148 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""test vjp in graph mode"""
|
||||
import numpy as np
|
||||
import pytest
|
||||
import mindspore.nn as nn
|
||||
import mindspore.context as context
|
||||
from mindspore import Tensor
|
||||
from mindspore import ms_function
|
||||
from mindspore.ops.functional import vjp
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
|
||||
|
||||
class SingleInputNet(nn.Cell):
|
||||
def construct(self, x):
|
||||
return x**3
|
||||
|
||||
|
||||
class MultipleInputsOutputNet(nn.Cell):
|
||||
def construct(self, x, y):
|
||||
return 2*x, y**3
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_vjp_single_input_graph():
|
||||
x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
|
||||
v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
|
||||
net = SingleInputNet()
|
||||
expect_primal = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
|
||||
expect_grad = Tensor(np.array([[3, 12], [27, 48]]).astype(np.float32))
|
||||
primal, grad = vjp(net, x, v)
|
||||
assert np.allclose(primal.asnumpy(), expect_primal.asnumpy())
|
||||
assert np.allclose(grad.asnumpy(), expect_grad.asnumpy())
|
||||
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_vjp_multiple_inputs_default_v_graph():
|
||||
x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
|
||||
y = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
|
||||
v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
|
||||
net = MultipleInputsOutputNet()
|
||||
expect_primal_0 = Tensor(np.array([[2, 4], [6, 8]]).astype(np.float32))
|
||||
expect_primal_1 = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
|
||||
expect_grad_0 = Tensor(np.array([[2, 2], [2, 2]]).astype(np.float32))
|
||||
expect_grad_1 = Tensor(np.array([[3, 12], [27, 48]]).astype(np.float32))
|
||||
primal, grad = vjp(net, (x, y), (v, v))
|
||||
assert isinstance(primal, tuple)
|
||||
assert len(primal) == 2
|
||||
assert np.allclose(primal[0].asnumpy(), expect_primal_0.asnumpy())
|
||||
assert np.allclose(primal[1].asnumpy(), expect_primal_1.asnumpy())
|
||||
assert isinstance(grad, tuple)
|
||||
assert len(grad) == 2
|
||||
assert np.allclose(grad[0].asnumpy(), expect_grad_0.asnumpy())
|
||||
assert np.allclose(grad[1].asnumpy(), expect_grad_1.asnumpy())
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_vjp_ms_function_single_input_single_output_default_v_graph():
|
||||
"""
|
||||
Features: Function vjp
|
||||
Description: Test vjp with ms_function, single input, single output and default v in graph mode.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
|
||||
v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
|
||||
net = SingleInputNet()
|
||||
|
||||
@ms_function
|
||||
def vjp_with_ms_function(inputs, vectors):
|
||||
output, vjp_grad = vjp(net, inputs, vectors)
|
||||
return output, vjp_grad
|
||||
|
||||
primal, grad = vjp_with_ms_function(x, v)
|
||||
expect_primal = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
|
||||
expect_grad = Tensor(np.array([[3, 12], [27, 48]]).astype(np.float32))
|
||||
assert np.allclose(primal.asnumpy(), expect_primal.asnumpy())
|
||||
assert np.allclose(grad.asnumpy(), expect_grad.asnumpy())
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_vjp_input_function_single_input_single_output_default_v_graph():
|
||||
"""
|
||||
Features: Function vjp
|
||||
Description: Test vjp with function, single input, single output and default v in graph mode.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
|
||||
v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
|
||||
|
||||
def test_function(inputs):
|
||||
return inputs**3
|
||||
|
||||
primal, grad = vjp(test_function, x, v)
|
||||
expect_primal = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
|
||||
expect_grad = Tensor(np.array([[3, 12], [27, 48]]).astype(np.float32))
|
||||
assert np.allclose(primal.asnumpy(), expect_primal.asnumpy())
|
||||
assert np.allclose(grad.asnumpy(), expect_grad.asnumpy())
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_vjp_construct_single_input_single_output_default_v_graph():
|
||||
"""
|
||||
Features: Function vjp
|
||||
Description: Test vjp with function, single input, single output and default v in graph mode.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
|
||||
v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
|
||||
|
||||
class Net(nn.Cell):
|
||||
def __init__(self, network):
|
||||
super(Net, self).__init__()
|
||||
self.net = network
|
||||
|
||||
def construct(self, inputs, vectors):
|
||||
net_out, vjp_out = vjp(self.net, inputs, vectors)
|
||||
return net_out, vjp_out
|
||||
|
||||
test_net = Net(SingleInputNet())
|
||||
primal, grad = test_net(x, v)
|
||||
expect_primal = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
|
||||
expect_grad = Tensor(np.array([[3, 12], [27, 48]]).astype(np.float32))
|
||||
assert np.allclose(primal.asnumpy(), expect_primal.asnumpy())
|
||||
assert np.allclose(grad.asnumpy(), expect_grad.asnumpy())
|
|
@ -0,0 +1,147 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""test vjp in pynative mode"""
|
||||
import numpy as np
|
||||
import pytest
|
||||
import mindspore.nn as nn
|
||||
import mindspore.context as context
|
||||
from mindspore import Tensor
|
||||
from mindspore import ms_function
|
||||
from mindspore.ops.functional import vjp
|
||||
|
||||
context.set_context(mode=context.PYNATIVE_MODE)
|
||||
|
||||
|
||||
class SingleInputNet(nn.Cell):
|
||||
def construct(self, x):
|
||||
return x**3
|
||||
|
||||
|
||||
class MultipleInputsOutputNet(nn.Cell):
|
||||
def construct(self, x, y):
|
||||
return 2*x, y**3
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_vjp_single_input_graph():
|
||||
x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
|
||||
v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
|
||||
net = SingleInputNet()
|
||||
expect_primal = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
|
||||
expect_grad = Tensor(np.array([[3, 12], [27, 48]]).astype(np.float32))
|
||||
primal, grad = vjp(net, x, v)
|
||||
assert np.allclose(primal.asnumpy(), expect_primal.asnumpy())
|
||||
assert np.allclose(grad.asnumpy(), expect_grad.asnumpy())
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_vjp_multiple_inputs_default_v_graph():
|
||||
x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
|
||||
y = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
|
||||
v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
|
||||
net = MultipleInputsOutputNet()
|
||||
expect_primal_0 = Tensor(np.array([[2, 4], [6, 8]]).astype(np.float32))
|
||||
expect_primal_1 = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
|
||||
expect_grad_0 = Tensor(np.array([[2, 2], [2, 2]]).astype(np.float32))
|
||||
expect_grad_1 = Tensor(np.array([[3, 12], [27, 48]]).astype(np.float32))
|
||||
primal, grad = vjp(net, (x, y), (v, v))
|
||||
assert isinstance(primal, tuple)
|
||||
assert len(primal) == 2
|
||||
assert np.allclose(primal[0].asnumpy(), expect_primal_0.asnumpy())
|
||||
assert np.allclose(primal[1].asnumpy(), expect_primal_1.asnumpy())
|
||||
assert isinstance(grad, tuple)
|
||||
assert len(grad) == 2
|
||||
assert np.allclose(grad[0].asnumpy(), expect_grad_0.asnumpy())
|
||||
assert np.allclose(grad[1].asnumpy(), expect_grad_1.asnumpy())
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_vjp_ms_function_single_input_single_output_default_v_graph():
|
||||
"""
|
||||
Features: Function vjp
|
||||
Description: Test vjp with ms_function, single input, single output and default v in graph mode.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
|
||||
v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
|
||||
net = SingleInputNet()
|
||||
|
||||
@ms_function
|
||||
def vjp_with_ms_function(inputs, vectors):
|
||||
output, vjp_grad = vjp(net, inputs, vectors)
|
||||
return output, vjp_grad
|
||||
|
||||
primal, grad = vjp_with_ms_function(x, v)
|
||||
expect_primal = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
|
||||
expect_grad = Tensor(np.array([[3, 12], [27, 48]]).astype(np.float32))
|
||||
assert np.allclose(primal.asnumpy(), expect_primal.asnumpy())
|
||||
assert np.allclose(grad.asnumpy(), expect_grad.asnumpy())
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_vjp_input_function_single_input_single_output_default_v_graph():
|
||||
"""
|
||||
Features: Function vjp
|
||||
Description: Test vjp with function, single input, single output and default v in graph mode.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
|
||||
v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
|
||||
|
||||
def test_function(inputs):
|
||||
return inputs**3
|
||||
|
||||
primal, grad = vjp(test_function, x, v)
|
||||
expect_primal = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
|
||||
expect_grad = Tensor(np.array([[3, 12], [27, 48]]).astype(np.float32))
|
||||
assert np.allclose(primal.asnumpy(), expect_primal.asnumpy())
|
||||
assert np.allclose(grad.asnumpy(), expect_grad.asnumpy())
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_vjp_construct_single_input_single_output_default_v_graph():
|
||||
"""
|
||||
Features: Function vjp
|
||||
Description: Test vjp with function, single input, single output and default v in graph mode.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
|
||||
v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
|
||||
|
||||
class Net(nn.Cell):
|
||||
def __init__(self, network):
|
||||
super(Net, self).__init__()
|
||||
self.net = network
|
||||
|
||||
def construct(self, inputs, vectors):
|
||||
net_out, vjp_out = vjp(self.net, inputs, vectors)
|
||||
return net_out, vjp_out
|
||||
|
||||
test_net = Net(SingleInputNet())
|
||||
primal, grad = test_net(x, v)
|
||||
expect_primal = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
|
||||
expect_grad = Tensor(np.array([[3, 12], [27, 48]]).astype(np.float32))
|
||||
assert np.allclose(primal.asnumpy(), expect_primal.asnumpy())
|
||||
assert np.allclose(grad.asnumpy(), expect_grad.asnumpy())
|
|
@ -0,0 +1,145 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""test function jvp in graph mode"""
|
||||
|
||||
import numpy as np
|
||||
import mindspore.nn as nn
|
||||
import mindspore.context as context
|
||||
from mindspore import Tensor
|
||||
from mindspore.ops.functional import jvp
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
|
||||
|
||||
class SingleInputSingleOutputNet(nn.Cell):
|
||||
def construct(self, x):
|
||||
return x**3
|
||||
|
||||
|
||||
class SingleInputMultipleOutputNet(nn.Cell):
|
||||
def construct(self, x):
|
||||
return x**3, 2*x
|
||||
|
||||
|
||||
class MultipleInputSingleOutputNet(nn.Cell):
|
||||
def construct(self, x, y):
|
||||
return 2*x + 3*y
|
||||
|
||||
|
||||
class MultipleInputMultipleOutputNet(nn.Cell):
|
||||
def construct(self, x, y):
|
||||
return 2*x, y**3
|
||||
|
||||
|
||||
def test_jvp_single_input_single_output_default_v_graph():
|
||||
"""
|
||||
Features: Function jvp
|
||||
Description: Test jvp with single input, single output and default v in graph mode.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
|
||||
v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
|
||||
net = SingleInputSingleOutputNet()
|
||||
jvp(net, x, v)
|
||||
|
||||
|
||||
def test_jvp_single_input_single_output_custom_v_graph():
|
||||
"""
|
||||
Features: Function jvp
|
||||
Description: Test jvp with single input, single output and custom v in graph mode.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
|
||||
v = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
|
||||
net = SingleInputSingleOutputNet()
|
||||
jvp(net, x, v)
|
||||
|
||||
|
||||
def test_jvp_single_input_multiple_outputs_default_v_graph():
|
||||
"""
|
||||
Features: Function jvp
|
||||
Description: Test jvp with single input, multiple outputs and default v in graph mode.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
|
||||
v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
|
||||
net = SingleInputMultipleOutputNet()
|
||||
jvp(net, x, v)
|
||||
|
||||
|
||||
def test_jvp_single_input_multiple_outputs_custom_v_graph():
|
||||
"""
|
||||
Features: Function jvp
|
||||
Description: Test jvp with single input, multiple outputs and custom v in graph mode.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
|
||||
v = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
|
||||
net = SingleInputMultipleOutputNet()
|
||||
jvp(net, x, v)
|
||||
|
||||
|
||||
def test_jvp_multiple_inputs_single_output_default_v_graph():
|
||||
"""
|
||||
Features: Function jvp
|
||||
Description: Test jvp with multiple inputs, single output and default v in graph mode.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
|
||||
y = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
|
||||
v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
|
||||
net = MultipleInputSingleOutputNet()
|
||||
jvp(net, (x, y), (v, v))
|
||||
|
||||
|
||||
def test_jvp_multiple_inputs_single_output_custom_v_graph():
|
||||
"""
|
||||
Features: Function jvp
|
||||
Description: Test jvp with multiple inputs, single output and custom v in graph mode.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
|
||||
y = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
|
||||
v1 = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
|
||||
v2 = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
|
||||
net = MultipleInputSingleOutputNet()
|
||||
jvp(net, (x, y), (v1, v2))
|
||||
|
||||
|
||||
def test_jvp_multiple_inputs_multiple_outputs_default_v_graph():
|
||||
"""
|
||||
Features: Function jvp
|
||||
Description: Test jvp with multiple inputs, multiple outputs and default v in graph mode.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
|
||||
y = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
|
||||
v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
|
||||
net = MultipleInputMultipleOutputNet()
|
||||
jvp(net, (x, y), (v, v))
|
||||
|
||||
|
||||
def test_jvp_multiple_inputs_multiple_outputs_custom_v_graph():
|
||||
"""
|
||||
Features: Function jvp
|
||||
Description: Test jvp with multiple inputs, multiple outputs and custom v in graph mode.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
|
||||
y = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
|
||||
v1 = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
|
||||
v2 = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
|
||||
net = MultipleInputMultipleOutputNet()
|
||||
jvp(net, (x, y), (v1, v2))
|
|
@ -0,0 +1,144 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""test function jvp in pynative mode """
|
||||
|
||||
import numpy as np
|
||||
import mindspore.nn as nn
|
||||
import mindspore.context as context
|
||||
from mindspore import Tensor
|
||||
from mindspore.ops.functional import jvp
|
||||
|
||||
context.set_context(mode=context.PYNATIVE_MODE)
|
||||
|
||||
class SingleInputSingleOutputNet(nn.Cell):
|
||||
def construct(self, x):
|
||||
return x**3
|
||||
|
||||
|
||||
class SingleInputMultipleOutputNet(nn.Cell):
|
||||
def construct(self, x):
|
||||
return x**3, 2*x
|
||||
|
||||
|
||||
class MultipleInputSingleOutputNet(nn.Cell):
|
||||
def construct(self, x, y):
|
||||
return 2*x + 3*y
|
||||
|
||||
|
||||
class MultipleInputMultipleOutputNet(nn.Cell):
|
||||
def construct(self, x, y):
|
||||
return 2*x, y**3
|
||||
|
||||
|
||||
def test_jvp_single_input_single_output_default_v_pynative():
|
||||
"""
|
||||
Features: Function jvp
|
||||
Description: Test jvp with single input, single output and default v in pynative mode.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
|
||||
v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
|
||||
net = SingleInputSingleOutputNet()
|
||||
jvp(net, x, v)
|
||||
|
||||
|
||||
def test_jvp_single_input_single_output_custom_v_pynative():
|
||||
"""
|
||||
Features: Function jvp
|
||||
Description: Test jvp with single input, single output and custom v in pynative mode.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
|
||||
v = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
|
||||
net = SingleInputSingleOutputNet()
|
||||
jvp(net, x, v)
|
||||
|
||||
|
||||
def test_jvp_single_input_multiple_outputs_default_v_pynative():
|
||||
"""
|
||||
Features: Function jvp
|
||||
Description: Test jvp with single input, multiple outputs and default v in pynative mode.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
|
||||
v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
|
||||
net = SingleInputMultipleOutputNet()
|
||||
jvp(net, x, v)
|
||||
|
||||
|
||||
def test_jvp_single_input_multiple_outputs_custom_v_pynative():
|
||||
"""
|
||||
Features: Function jvp
|
||||
Description: Test jvp with single input, multiple outputs and custom v in pynative mode.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
|
||||
v = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
|
||||
net = SingleInputMultipleOutputNet()
|
||||
jvp(net, x, v)
|
||||
|
||||
|
||||
def test_jvp_multiple_inputs_multiple_outputs_default_v_pynative():
|
||||
"""
|
||||
Features: Function jvp
|
||||
Description: Test jvp with multiple inputs, multiple outputs and default v in pynative mode.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
|
||||
y = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
|
||||
v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
|
||||
net = MultipleInputMultipleOutputNet()
|
||||
jvp(net, (x, y), (v, v))
|
||||
|
||||
|
||||
def test_jvp_multiple_inputs_multiple_outputs_custom_v_pynative():
|
||||
"""
|
||||
Features: Function jvp
|
||||
Description: Test jvp with multiple inputs, multiple outputs and custom v in pynative mode.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
|
||||
y = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
|
||||
v1 = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
|
||||
v2 = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
|
||||
net = MultipleInputMultipleOutputNet()
|
||||
jvp(net, (x, y), (v1, v2))
|
||||
|
||||
|
||||
def test_jvp_multiple_inputs_single_output_default_v_pynative():
|
||||
"""
|
||||
Features: Function jvp
|
||||
Description: Test jvp with multiple inputs, single output and default v in pynative mode.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
|
||||
y = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
|
||||
v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
|
||||
net = MultipleInputSingleOutputNet()
|
||||
jvp(net, (x, y), (v, v))
|
||||
|
||||
|
||||
def test_jvp_multiple_inputs_single_output_custom_v_pynative():
|
||||
"""
|
||||
Features: Function jvp
|
||||
Description: Test jvp with multiple inputs, single output and custom v in pynative mode.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
|
||||
y = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
|
||||
v1 = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
|
||||
v2 = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
|
||||
net = MultipleInputSingleOutputNet()
|
||||
jvp(net, (x, y), (v1, v2))
|
|
@ -0,0 +1,57 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""test vjp in graph mode"""
|
||||
import numpy as np
|
||||
import mindspore.nn as nn
|
||||
import mindspore.context as context
|
||||
from mindspore import Tensor
|
||||
from mindspore.ops.functional import vjp
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
|
||||
|
||||
class SingleInputNet(nn.Cell):
|
||||
def construct(self, x):
|
||||
return x**3
|
||||
|
||||
|
||||
class MultipleInputsOutputNet(nn.Cell):
|
||||
def construct(self, x, y):
|
||||
return 2*x, y**3
|
||||
|
||||
|
||||
def test_vjp_single_input_graph():
|
||||
"""
|
||||
Features: Function vjp
|
||||
Description: Test vjp with single input, single output and default v in graph mode.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
|
||||
v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
|
||||
net = SingleInputNet()
|
||||
vjp(net, x, v)
|
||||
|
||||
|
||||
def test_vjp_multiple_inputs_default_v_graph():
|
||||
"""
|
||||
Features: Function vjp
|
||||
Description: Test vjp with multiple input, single output and default v in graph mode.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
|
||||
y = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
|
||||
v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
|
||||
net = MultipleInputsOutputNet()
|
||||
vjp(net, (x, y), (v, v))
|
|
@ -0,0 +1,57 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""test vjp in pynative mode"""
|
||||
import numpy as np
|
||||
import mindspore.nn as nn
|
||||
import mindspore.context as context
|
||||
from mindspore import Tensor
|
||||
from mindspore.ops.functional import vjp
|
||||
|
||||
context.set_context(mode=context.PYNATIVE_MODE)
|
||||
|
||||
|
||||
class SingleInputNet(nn.Cell):
|
||||
def construct(self, x):
|
||||
return x**3
|
||||
|
||||
|
||||
class MultipleInputsOutputNet(nn.Cell):
|
||||
def construct(self, x, y):
|
||||
return 2*x, y**3
|
||||
|
||||
|
||||
def test_vjp_single_input_pynative():
|
||||
"""
|
||||
Features: Function vjp
|
||||
Description: Test vjp with single input, single output and default v in pynative mode.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
|
||||
v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
|
||||
net = SingleInputNet()
|
||||
vjp(net, x, v)
|
||||
|
||||
|
||||
def test_vjp_multiple_inputs_default_v_pynative():
|
||||
"""
|
||||
Features: Function vjp
|
||||
Description: Test vjp with multiple input, single output and default v in pynative mode.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
|
||||
y = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
|
||||
v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
|
||||
net = MultipleInputsOutputNet()
|
||||
vjp(net, (x, y), (v, v))
|
Loading…
Reference in New Issue